diff --git a/.config/hakari.toml b/.config/hakari.toml index 9991cd92b0..dcbc44cc33 100644 --- a/.config/hakari.toml +++ b/.config/hakari.toml @@ -21,13 +21,14 @@ platforms = [ # "x86_64-apple-darwin", # "x86_64-pc-windows-msvc", ] - [final-excludes] workspace-members = [ # vm_monitor benefits from the same Cargo.lock as the rest of our artifacts, but # it is built primarly in separate repo neondatabase/autoscaling and thus is excluded # from depending on workspace-hack because most of the dependencies are not used. "vm_monitor", + # subzero-core is a stub crate that should be excluded from workspace-hack + "subzero-core", # All of these exist in libs and are not usually built independently. # Putting workspace hack there adds a bottleneck for cargo builds. "compute_api", diff --git a/.github/actions/prepare-for-subzero/action.yml b/.github/actions/prepare-for-subzero/action.yml new file mode 100644 index 0000000000..11beb11880 --- /dev/null +++ b/.github/actions/prepare-for-subzero/action.yml @@ -0,0 +1,28 @@ +name: 'Prepare current job for subzero' +description: > + Set git token to access `neondatabase/subzero` from cargo build, + and set `CARGO_NET_GIT_FETCH_WITH_CLI=true` env variable to use git CLI + +inputs: + token: + description: 'GitHub token with access to neondatabase/subzero' + required: true + +runs: + using: "composite" + + steps: + - name: Set git token for neondatabase/subzero + uses: pyTooling/Actions/with-post-step@2307b526df64d55e95884e072e49aac2a00a9afa # v5.1.0 + env: + SUBZERO_ACCESS_TOKEN: ${{ inputs.token }} + with: + main: | + git config --global url."https://x-access-token:${SUBZERO_ACCESS_TOKEN}@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero" + cargo add -p proxy subzero-core --git https://github.com/neondatabase/subzero --rev 396264617e78e8be428682f87469bb25429af88a + post: | + git config --global --unset url."https://x-access-token:${SUBZERO_ACCESS_TOKEN}@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero" + + - name: Set `CARGO_NET_GIT_FETCH_WITH_CLI=true` env variable + shell: bash -euxo pipefail {0} + run: echo "CARGO_NET_GIT_FETCH_WITH_CLI=true" >> ${GITHUB_ENV} diff --git a/.github/workflows/_build-and-test-locally.yml b/.github/workflows/_build-and-test-locally.yml index 94115572df..1b03dc9c03 100644 --- a/.github/workflows/_build-and-test-locally.yml +++ b/.github/workflows/_build-and-test-locally.yml @@ -86,6 +86,10 @@ jobs: with: submodules: true + - uses: ./.github/actions/prepare-for-subzero + with: + token: ${{ secrets.CI_ACCESS_TOKEN }} + - name: Set pg 14 revision for caching id: pg_v14_rev run: echo pg_rev=$(git rev-parse HEAD:vendor/postgres-v14) >> $GITHUB_OUTPUT @@ -116,7 +120,7 @@ jobs: ARCH: ${{ inputs.arch }} SANITIZERS: ${{ inputs.sanitizers }} run: | - CARGO_FLAGS="--locked --features testing" + CARGO_FLAGS="--locked --features testing,rest_broker" if [[ $BUILD_TYPE == "debug" && $ARCH == 'x64' ]]; then cov_prefix="scripts/coverage --profraw-prefix=$GITHUB_JOB --dir=/tmp/coverage run" CARGO_PROFILE="" diff --git a/.github/workflows/_check-codestyle-rust.yml b/.github/workflows/_check-codestyle-rust.yml index 4f844b0bf6..af29e10e97 100644 --- a/.github/workflows/_check-codestyle-rust.yml +++ b/.github/workflows/_check-codestyle-rust.yml @@ -46,6 +46,10 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: true + + - uses: ./.github/actions/prepare-for-subzero + with: + token: ${{ secrets.CI_ACCESS_TOKEN }} - name: Cache 
cargo deps uses: tespkg/actions-cache@b7bf5fcc2f98a52ac6080eb0fd282c2f752074b1 # v1.8.0 diff --git a/.github/workflows/build-macos.yml b/.github/workflows/build-macos.yml index 2296807d2d..e43eec1133 100644 --- a/.github/workflows/build-macos.yml +++ b/.github/workflows/build-macos.yml @@ -54,6 +54,10 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: true + + - uses: ./.github/actions/prepare-for-subzero + with: + token: ${{ secrets.CI_ACCESS_TOKEN }} - name: Install build dependencies run: | diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 2977f642bc..f237a991cc 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -632,6 +632,8 @@ jobs: BUILD_TAG=${{ needs.meta.outputs.release-tag || needs.meta.outputs.build-tag }} TAG=${{ needs.build-build-tools-image.outputs.image-tag }}-bookworm DEBIAN_VERSION=bookworm + secrets: | + SUBZERO_ACCESS_TOKEN=${{ secrets.CI_ACCESS_TOKEN }} provenance: false push: true pull: true diff --git a/.github/workflows/neon_extra_builds.yml b/.github/workflows/neon_extra_builds.yml index 3e81183687..10ca1a1591 100644 --- a/.github/workflows/neon_extra_builds.yml +++ b/.github/workflows/neon_extra_builds.yml @@ -72,6 +72,7 @@ jobs: check-macos-build: needs: [ check-permissions, files-changed ] uses: ./.github/workflows/build-macos.yml + secrets: inherit with: pg_versions: ${{ needs.files-changed.outputs.postgres_changes }} rebuild_rust_code: ${{ fromJSON(needs.files-changed.outputs.rebuild_rust_code) }} diff --git a/.gitignore b/.gitignore index 32e8fcf798..43204bcd46 100644 --- a/.gitignore +++ b/.gitignore @@ -27,9 +27,14 @@ docker-compose/docker-compose-parallel.yml *.o *.so *.Po +*.pid # pgindent typedef lists *.list # Node **/node_modules/ + +# various files for local testing +/proxy/.subzero +local_proxy.json diff --git a/Cargo.lock b/Cargo.lock index 9cb48b51be..68c9dadbe0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -52,6 +52,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "aliasable" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "250f629c0161ad8107cf89319e990051fae62832fd343083bea452d93e2205fd" + [[package]] name = "aligned-vec" version = "0.6.1" @@ -501,7 +507,7 @@ dependencies = [ "hex", "hmac", "http 0.2.9", - "http 1.1.0", + "http 1.3.1", "once_cell", "p256 0.11.1", "percent-encoding", @@ -642,7 +648,7 @@ dependencies = [ "aws-smithy-types", "bytes", "http 0.2.9", - "http 1.1.0", + "http 1.3.1", "pin-project-lite", "tokio", "tracing", @@ -660,7 +666,7 @@ dependencies = [ "bytes-utils", "futures-core", "http 0.2.9", - "http 1.1.0", + "http 1.3.1", "http-body 0.4.5", "http-body 1.0.0", "http-body-util", @@ -698,51 +704,24 @@ dependencies = [ "tracing", ] -[[package]] -name = "axum" -version = "0.7.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" -dependencies = [ - "async-trait", - "axum-core 0.4.5", - "bytes", - "futures-util", - "http 1.1.0", - "http-body 1.0.0", - "http-body-util", - "itoa", - "matchit 0.7.3", - "memchr", - "mime", - "percent-encoding", - "pin-project-lite", - "rustversion", - "serde", - "sync_wrapper 1.0.1", - "tower 0.5.2", - "tower-layer", - "tower-service", -] - [[package]] name = "axum" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8" dependencies = [ - "axum-core 0.5.0", + "axum-core", "base64 0.22.1", "bytes", "form_urlencoded", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", - "hyper 1.6.0", + "hyper 1.4.1", "hyper-util", "itoa", - "matchit 0.8.4", + "matchit", "memchr", "mime", "percent-encoding", @@ -762,26 +741,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "axum-core" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" -dependencies = [ - "async-trait", - "bytes", - "futures-util", - "http 1.1.0", - "http-body 1.0.0", - "http-body-util", - "mime", - "pin-project-lite", - "rustversion", - "sync_wrapper 1.0.1", - "tower-layer", - "tower-service", -] - [[package]] name = "axum-core" version = "0.5.0" @@ -790,7 +749,7 @@ checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733" dependencies = [ "bytes", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "mime", @@ -808,13 +767,13 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "460fc6f625a1f7705c6cf62d0d070794e94668988b1c38111baeec177c715f7b" dependencies = [ - "axum 0.8.1", - "axum-core 0.5.0", + "axum", + "axum-core", "bytes", "form_urlencoded", "futures-util", "headers", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "mime", @@ -1148,8 +1107,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "975982cdb7ad6a142be15bdf84aea7ec6a9e5d4d797c004d43185b24cfe4e684" dependencies = [ "clap", - "heck", - "indexmap 2.9.0", + "heck 0.5.0", + "indexmap 2.10.0", "log", "proc-macro2", "quote", @@ -1286,7 +1245,7 @@ version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "syn 2.0.100", @@ -1349,12 +1308,13 @@ name = "communicator" version = "0.0.0" dependencies = [ "atomic_enum", - "axum 0.8.1", + "axum", "bytes", "cbindgen", "clashmap", - "http 1.1.0", + "http 1.3.1", "libc", + "measured", "metrics", "neon-shmem", "nix 0.30.1", @@ -1366,7 +1326,7 @@ dependencies = [ "thiserror 1.0.69", "tokio", "tokio-pipe", - "tonic 0.12.3", + "tonic", "tracing", "tracing-subscriber", "uring-common", @@ -1380,7 +1340,7 @@ version = "0.1.0" dependencies = [ "anyhow", "chrono", - "indexmap 2.9.0", + "indexmap 2.10.0", "jsonwebtoken", "regex", "remote_storage", @@ -1400,7 +1360,7 @@ dependencies = [ "aws-sdk-kms", "aws-sdk-s3", "aws-smithy-types", - "axum 0.8.1", + "axum", "axum-extra", "base64 0.22.1", "bytes", @@ -1413,8 +1373,11 @@ dependencies = [ "flate2", "futures", "hostname-validator", - "http 1.1.0", - "indexmap 2.9.0", + "http 1.3.1", + "http-body-util", + "hyper 1.4.1", + "hyper-util", + "indexmap 2.10.0", "itertools 0.10.5", "jsonwebtoken", "metrics", @@ -1436,6 +1399,7 @@ dependencies = [ "ring", "rlimit", "rust-ini", + "scopeguard", "serde", "serde_json", "serde_with", @@ -1446,7 +1410,7 @@ dependencies = [ "tokio-postgres", "tokio-stream", "tokio-util", - "tonic 0.13.1", + "tonic", "tower 0.5.2", "tower-http", "tower-otel", @@ -1524,7 +1488,7 @@ name = "consumption_metrics" version = "0.1.0" dependencies = [ "chrono", - "rand 0.8.5", + "rand 0.9.1", "serde", ] @@ -1927,7 +1891,7 @@ dependencies = [ "bytes", "hex", "parking_lot 0.12.1", - 
"rand 0.8.5", + "rand 0.9.1", "smallvec", "tracing", "utils", @@ -2048,7 +2012,7 @@ checksum = "0892a17df262a24294c382f0d5997571006e7a4348b4327557c4ff1cd4a8bccc" dependencies = [ "darling", "either", - "heck", + "heck 0.5.0", "proc-macro2", "quote", "syn 2.0.100", @@ -2162,7 +2126,7 @@ name = "endpoint_storage" version = "0.0.1" dependencies = [ "anyhow", - "axum 0.8.1", + "axum", "axum-extra", "camino", "camino-tempfile", @@ -2172,7 +2136,7 @@ dependencies = [ "itertools 0.10.5", "jsonwebtoken", "prometheus", - "rand 0.8.5", + "rand 0.9.1", "remote_storage", "serde", "serde_json", @@ -2449,7 +2413,7 @@ dependencies = [ "futures-core", "futures-sink", "http-body-util", - "hyper 1.6.0", + "hyper 1.4.1", "hyper-util", "pin-project", "rand 0.8.5", @@ -2728,7 +2692,7 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.9", - "indexmap 2.9.0", + "indexmap 2.10.0", "slab", "tokio", "tokio-util", @@ -2746,8 +2710,8 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http 1.1.0", - "indexmap 2.9.0", + "http 1.3.1", + "indexmap 2.10.0", "slab", "tokio", "tokio-util", @@ -2838,7 +2802,7 @@ dependencies = [ "base64 0.21.7", "bytes", "headers-core", - "http 1.1.0", + "http 1.3.1", "httpdate", "mime", "sha1", @@ -2850,9 +2814,15 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" dependencies = [ - "http 1.1.0", + "http 1.3.1", ] +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "heck" version = "0.5.0" @@ -2928,9 +2898,9 @@ dependencies = [ [[package]] name = "http" -version = "1.1.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" dependencies = [ "bytes", "fnv", @@ -2955,7 +2925,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" dependencies = [ "bytes", - "http 1.1.0", + "http 1.3.1", ] [[package]] @@ -2966,7 +2936,7 @@ checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" dependencies = [ "bytes", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "pin-project-lite", ] @@ -3010,7 +2980,7 @@ dependencies = [ "pprof", "regex", "routerify", - "rustls 0.23.27", + "rustls 0.23.29", "rustls-pemfile 2.1.1", "serde", "serde_json", @@ -3030,9 +3000,9 @@ dependencies = [ [[package]] name = "httparse" -version = "1.10.1" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" +checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" [[package]] name = "httpdate" @@ -3082,15 +3052,15 @@ dependencies = [ [[package]] name = "hyper" -version = "1.6.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" +checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05" dependencies = [ "bytes", "futures-channel", "futures-util", "h2 0.4.4", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "httparse", "httpdate", @@ -3123,8 +3093,8 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "a0bea761b46ae2b24eb4aef630d8d1c398157b6fc29e6350ecf090a0b70c952c" dependencies = [ "futures-util", - "http 1.1.0", - "hyper 1.6.0", + "http 1.3.1", + "hyper 1.4.1", "hyper-util", "rustls 0.22.4", "rustls-pki-types", @@ -3139,7 +3109,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3203a961e5c83b6f5498933e78b6b263e208c197b63e9c6c53cc82ffd3f63793" dependencies = [ - "hyper 1.6.0", + "hyper 1.4.1", "hyper-util", "pin-project-lite", "tokio", @@ -3148,21 +3118,20 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.14" +version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc2fdfdbff08affe55bb779f33b053aa1fe5dd5b54c257343c17edfa55711bdb" +checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9" dependencies = [ "bytes", "futures-channel", - "futures-core", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", - "hyper 1.6.0", - "libc", + "hyper 1.4.1", "pin-project-lite", "socket2", "tokio", + "tower 0.4.13", "tower-service", "tracing", ] @@ -3348,9 +3317,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.9.0" +version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" +checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" dependencies = [ "equivalent", "hashbrown 0.15.2", @@ -3376,7 +3345,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "232929e1d75fe899576a3d5c7416ad0d88dbfbb3c3d6aa00873a7408a50ddb88" dependencies = [ "ahash", - "indexmap 2.9.0", + "indexmap 2.10.0", "is-terminal", "itoa", "log", @@ -3399,7 +3368,7 @@ dependencies = [ "crossbeam-utils", "dashmap 6.1.0", "env_logger", - "indexmap 2.9.0", + "indexmap 2.10.0", "itoa", "log", "num-format", @@ -3760,12 +3729,6 @@ dependencies = [ "regex-automata 0.1.10", ] -[[package]] -name = "matchit" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" - [[package]] name = "matchit" version = "0.8.4" @@ -3811,7 +3774,7 @@ version = "0.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9e6777fc80a575f9503d908c8b498782a6c3ee88a06cb416dc3941401e43b94" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "syn 2.0.100", @@ -3872,8 +3835,8 @@ dependencies = [ "once_cell", "procfs", "prometheus", - "rand 0.8.5", - "rand_distr 0.4.3", + "rand 0.9.1", + "rand_distr", "twox-hash 1.6.3", ] @@ -3968,7 +3931,7 @@ dependencies = [ "lock_api", "nix 0.30.1", "rand 0.9.1", - "rand_distr 0.5.1", + "rand_distr", "rustc-hash 2.1.1", "seahash", "tempfile", @@ -3984,7 +3947,7 @@ version = "0.1.0" dependencies = [ "crossbeam-utils", "rand 0.9.1", - "rand_distr 0.5.1", + "rand_distr", "spin", "tracing", ] @@ -4259,86 +4222,81 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "opentelemetry" -version = "0.27.1" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab70038c28ed37b97d8ed414b6429d343a8bbf44c9f79ec854f3a643029ba6d7" +checksum = "aaf416e4cb72756655126f7dd7bb0af49c674f4c1b9903e80c009e0c37e552e6" dependencies = [ "futures-core", "futures-sink", "js-sys", "pin-project-lite", - "thiserror 1.0.69", + "thiserror 2.0.11", "tracing", ] 
[[package]] name = "opentelemetry-http" -version = "0.27.0" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10a8a7f5f6ba7c1b286c2fbca0454eaba116f63bbe69ed250b642d36fbb04d80" +checksum = "50f6639e842a97dbea8886e3439710ae463120091e2e064518ba8e716e6ac36d" dependencies = [ "async-trait", "bytes", - "http 1.1.0", + "http 1.3.1", "opentelemetry", "reqwest", ] [[package]] name = "opentelemetry-otlp" -version = "0.27.0" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91cf61a1868dacc576bf2b2a1c3e9ab150af7272909e80085c3173384fe11f76" +checksum = "dbee664a43e07615731afc539ca60c6d9f1a9425e25ca09c57bc36c87c55852b" dependencies = [ - "async-trait", - "futures-core", - "http 1.1.0", + "http 1.3.1", "opentelemetry", "opentelemetry-http", "opentelemetry-proto", "opentelemetry_sdk", "prost 0.13.5", "reqwest", - "thiserror 1.0.69", + "thiserror 2.0.11", ] [[package]] name = "opentelemetry-proto" -version = "0.27.0" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6e05acbfada5ec79023c85368af14abd0b307c015e9064d249b2a950ef459a6" +checksum = "2e046fd7660710fe5a05e8748e70d9058dc15c94ba914e7c4faa7c728f0e8ddc" dependencies = [ "opentelemetry", "opentelemetry_sdk", "prost 0.13.5", - "tonic 0.12.3", + "tonic", ] [[package]] name = "opentelemetry-semantic-conventions" -version = "0.27.0" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc1b6902ff63b32ef6c489e8048c5e253e2e4a803ea3ea7e783914536eb15c52" +checksum = "83d059a296a47436748557a353c5e6c5705b9470ef6c95cfc52c21a8814ddac2" [[package]] name = "opentelemetry_sdk" -version = "0.27.1" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "231e9d6ceef9b0b2546ddf52335785ce41252bc7474ee8ba05bfad277be13ab8" +checksum = "11f644aa9e5e31d11896e024305d7e3c98a88884d9f8919dbf37a9991bc47a4b" dependencies = [ - "async-trait", "futures-channel", "futures-executor", "futures-util", - "glob", "opentelemetry", "percent-encoding", - "rand 0.8.5", + "rand 0.9.1", "serde_json", - "thiserror 1.0.69", + "thiserror 2.0.11", "tokio", "tokio-stream", - "tracing", ] [[package]] @@ -4371,6 +4329,30 @@ dependencies = [ "winapi", ] +[[package]] +name = "ouroboros" +version = "0.18.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e0f050db9c44b97a94723127e6be766ac5c340c48f2c4bb3ffa11713744be59" +dependencies = [ + "aliasable", + "ouroboros_macro", + "static_assertions", +] + +[[package]] +name = "ouroboros_macro" +version = "0.18.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c7028bdd3d43083f6d8d4d5187680d0d3560d54df4cc9d752005268b41e64d0" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "proc-macro2-diagnostics", + "quote", + "syn 2.0.100", +] + [[package]] name = "outref" version = "0.5.1" @@ -4422,13 +4404,13 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", - "axum 0.8.1", + "axum", "bytes", "camino", "clap", "futures", "hdrhistogram", - "http 1.1.0", + "http 1.3.1", "humantime", "humantime-serde", "metrics", @@ -4437,14 +4419,14 @@ dependencies = [ "pageserver_client_grpc", "pageserver_page_api", "pprof", - "rand 0.8.5", + "rand 0.9.1", "reqwest", "serde", "serde_json", "tokio", "tokio-stream", "tokio-util", - "tonic 0.13.1", + "tonic", "tracing", "url", "utils", @@ -4503,7 +4485,7 @@ dependencies = [ "hashlink", "hex", "hex-literal", - "http 1.1.0", + "http 1.3.1", 
"http-utils", "humantime", "humantime-serde", @@ -4536,14 +4518,14 @@ dependencies = [ "pq_proto", "procfs", "prost 0.13.5", - "rand 0.8.5", + "rand 0.9.1", "range-set-blaze", "regex", "remote_storage", "reqwest", "rpds", "rstest", - "rustls 0.23.27", + "rustls 0.23.29", "scopeguard", "send-future", "serde", @@ -4567,7 +4549,7 @@ dependencies = [ "tokio-tar", "tokio-util", "toml_edit", - "tonic 0.13.1", + "tonic", "tonic-reflection", "tower 0.5.2", "tracing", @@ -4603,7 +4585,7 @@ dependencies = [ "postgres_ffi_types", "postgres_versioninfo", "posthog_client_lite", - "rand 0.8.5", + "rand 0.9.1", "remote_storage", "reqwest", "serde", @@ -4653,7 +4635,7 @@ dependencies = [ "tokio", "tokio-stream", "tokio-util", - "tonic 0.13.1", + "tonic", "tracing", "utils", "workspace_hack", @@ -4673,7 +4655,7 @@ dependencies = [ "once_cell", "pageserver_api", "pin-project-lite", - "rand 0.8.5", + "rand 0.9.1", "svg_fmt", "tokio", "tracing", @@ -4698,7 +4680,7 @@ dependencies = [ "thiserror 1.0.69", "tokio", "tokio-util", - "tonic 0.13.1", + "tonic", "tonic-build", "utils", "workspace_hack", @@ -5055,7 +5037,7 @@ dependencies = [ "fallible-iterator", "hmac", "memchr", - "rand 0.8.5", + "rand 0.9.1", "sha2", "stringprep", "tokio", @@ -5089,7 +5071,7 @@ dependencies = [ "bytes", "once_cell", "pq_proto", - "rustls 0.23.27", + "rustls 0.23.29", "rustls-pemfile 2.1.1", "serde", "thiserror 1.0.69", @@ -5247,7 +5229,7 @@ dependencies = [ "bytes", "itertools 0.10.5", "postgres-protocol", - "rand 0.8.5", + "rand 0.9.1", "serde", "thiserror 1.0.69", "tokio", @@ -5281,6 +5263,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "proc-macro2-diagnostics" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", + "version_check", + "yansi", +] + [[package]] name = "procfs" version = "0.16.0" @@ -5350,7 +5345,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" dependencies = [ "bytes", - "heck", + "heck 0.5.0", "itertools 0.12.1", "log", "multimap", @@ -5371,7 +5366,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15" dependencies = [ "bytes", - "heck", + "heck 0.5.0", "itertools 0.12.1", "log", "multimap", @@ -5467,15 +5462,15 @@ dependencies = [ "hex", "hmac", "hostname", - "http 1.1.0", + "http 1.3.1", "http-body-util", "http-utils", "humantime", "humantime-serde", "hyper 0.14.30", - "hyper 1.6.0", + "hyper 1.4.1", "hyper-util", - "indexmap 2.9.0", + "indexmap 2.10.0", "ipnet", "itertools 0.10.5", "itoa", @@ -5487,6 +5482,7 @@ dependencies = [ "metrics", "once_cell", "opentelemetry", + "ouroboros", "p256 0.13.2", "papaya", "parking_lot 0.12.1", @@ -5497,8 +5493,9 @@ dependencies = [ "postgres-protocol2", "postgres_backend", "pq_proto", - "rand 0.8.5", - "rand_distr 0.4.3", + "rand 0.9.1", + "rand_core 0.6.4", + "rand_distr", "rcgen", "redis", "regex", @@ -5510,7 +5507,7 @@ dependencies = [ "rsa", "rstest", "rustc-hash 2.1.1", - "rustls 0.23.27", + "rustls 0.23.29", "rustls-native-certs 0.8.0", "rustls-pemfile 2.1.1", "scopeguard", @@ -5523,6 +5520,7 @@ dependencies = [ "socket2", "strum_macros", "subtle", + "subzero-core", "thiserror 1.0.69", "tikv-jemalloc-ctl", "tikv-jemallocator", @@ -5699,16 +5697,6 @@ dependencies = [ "getrandom 0.3.3", ] 
-[[package]] -name = "rand_distr" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" -dependencies = [ - "num-traits", - "rand 0.8.5", -] - [[package]] name = "rand_distr" version = "0.5.1" @@ -5798,7 +5786,7 @@ dependencies = [ "num-bigint", "percent-encoding", "pin-project-lite", - "rustls 0.23.27", + "rustls 0.23.29", "rustls-native-certs 0.8.0", "ryu", "sha1_smol", @@ -5838,14 +5826,14 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.2" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.3", - "regex-syntax 0.8.2", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", ] [[package]] @@ -5859,13 +5847,13 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.3" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.2", + "regex-syntax 0.8.5", ] [[package]] @@ -5882,9 +5870,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "regex-syntax" -version = "0.8.2" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "relative-path" @@ -5917,12 +5905,12 @@ dependencies = [ "http-body-util", "http-types", "humantime-serde", - "hyper 1.6.0", + "hyper 1.4.1", "itertools 0.10.5", "metrics", "once_cell", "pin-project-lite", - "rand 0.8.5", + "rand 0.9.1", "reqwest", "scopeguard", "serde", @@ -5954,10 +5942,10 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", - "hyper 1.6.0", + "hyper 1.4.1", "hyper-rustls 0.26.0", "hyper-util", "ipnet", @@ -5996,7 +5984,7 @@ checksum = "d1ccd3b55e711f91a9885a2fa6fbbb2e39db1776420b062efc058c6410f7e5e3" dependencies = [ "anyhow", "async-trait", - "http 1.1.0", + "http 1.3.1", "reqwest", "serde", "thiserror 1.0.69", @@ -6013,8 +6001,8 @@ dependencies = [ "async-trait", "futures", "getrandom 0.2.11", - "http 1.1.0", - "hyper 1.6.0", + "http 1.3.1", + "hyper 1.4.1", "parking_lot 0.11.2", "reqwest", "reqwest-middleware", @@ -6027,15 +6015,15 @@ dependencies = [ [[package]] name = "reqwest-tracing" -version = "0.5.5" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73e6153390585f6961341b50e5a1931d6be6dee4292283635903c26ef9d980d2" +checksum = "d70ea85f131b2ee9874f0b160ac5976f8af75f3c9badfe0d955880257d10bd83" dependencies = [ "anyhow", "async-trait", "getrandom 0.2.11", - "http 1.1.0", - "matchit 0.8.4", + "http 1.3.1", + "matchit", "opentelemetry", "reqwest", "reqwest-middleware", @@ -6254,15 +6242,15 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.27" +version = "0.23.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321" +checksum = 
"2491382039b29b9b11ff08b76ff6c97cf287671dbb74f0be44bda389fffe9bd1" dependencies = [ "log", "once_cell", "ring", "rustls-pki-types", - "rustls-webpki 0.103.3", + "rustls-webpki 0.103.4", "subtle", "zeroize", ] @@ -6326,9 +6314,12 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" +checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" +dependencies = [ + "zeroize", +] [[package]] name = "rustls-webpki" @@ -6353,9 +6344,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.103.3" +version = "0.103.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4a72fe2bcf7a6ac6fd7d0b9e5cb68aeb7d4c0a0271730218b3e92d43b4eb435" +checksum = "0a17884ae0c1b773f1ccd2bd4a8c72f16da897310a98b0e84bf349ad5ead92fc" dependencies = [ "ring", "rustls-pki-types", @@ -6393,7 +6384,7 @@ dependencies = [ "fail", "futures", "hex", - "http 1.1.0", + "http 1.3.1", "http-utils", "humantime", "hyper 0.14.30", @@ -6412,11 +6403,11 @@ dependencies = [ "postgres_versioninfo", "pprof", "pq_proto", - "rand 0.8.5", + "rand 0.9.1", "regex", "remote_storage", "reqwest", - "rustls 0.23.27", + "rustls 0.23.29", "safekeeper_api", "safekeeper_client", "scopeguard", @@ -6612,7 +6603,7 @@ checksum = "255914a8e53822abd946e2ce8baa41d4cded6b8e938913b7f7b9da5b7ab44335" dependencies = [ "httpdate", "reqwest", - "rustls 0.23.27", + "rustls 0.23.29", "sentry-backtrace", "sentry-contexts", "sentry-core", @@ -6744,7 +6735,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d2de91cf02bbc07cde38891769ccd5d4f073d22a40683aa4bc7a95781aaa2c4" dependencies = [ "form_urlencoded", - "indexmap 2.9.0", + "indexmap 2.10.0", "itoa", "ryu", "serde", @@ -6825,7 +6816,7 @@ dependencies = [ "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.9.0", + "indexmap 2.10.0", "serde", "serde_derive", "serde_json", @@ -6991,12 +6982,12 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.10" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.48.0", ] [[package]] @@ -7065,16 +7056,16 @@ dependencies = [ "http-body-util", "http-utils", "humantime", - "hyper 1.6.0", + "hyper 1.4.1", "hyper-util", "metrics", "once_cell", "parking_lot 0.12.1", "prost 0.13.5", - "rustls 0.23.27", + "rustls 0.23.29", "tokio", "tokio-rustls 0.26.2", - "tonic 0.13.1", + "tonic", "tonic-build", "tracing", "utils", @@ -7115,11 +7106,11 @@ dependencies = [ "pageserver_client", "postgres_connection", "posthog_client_lite", - "rand 0.8.5", + "rand 0.9.1", "regex", "reqwest", "routerify", - "rustls 0.23.27", + "rustls 0.23.29", "rustls-native-certs 0.8.0", "safekeeper_api", "safekeeper_client", @@ -7173,7 +7164,7 @@ dependencies = [ "postgres_ffi", "remote_storage", "reqwest", - "rustls 0.23.27", + "rustls 0.23.29", "rustls-native-certs 0.8.0", "serde", "serde_json", @@ -7251,7 +7242,7 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "rustversion", @@ -7264,6 +7255,10 @@ version = 
"2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" +[[package]] +name = "subzero-core" +version = "3.0.1" + [[package]] name = "svg_fmt" version = "0.4.3" @@ -7718,7 +7713,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04fb792ccd6bbcd4bba408eb8a292f70fc4a3589e5d793626f45190e6454b6ab" dependencies = [ "ring", - "rustls 0.23.27", + "rustls 0.23.29", "tokio", "tokio-postgres", "tokio-rustls 0.26.2", @@ -7769,7 +7764,7 @@ version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" dependencies = [ - "rustls 0.23.27", + "rustls 0.23.29", "tokio", ] @@ -7868,43 +7863,13 @@ version = "0.22.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f21c7aaf97f1bd9ca9d4f9e73b0a6c74bd5afef56f2bc931943a6e1c37e04e38" dependencies = [ - "indexmap 2.9.0", + "indexmap 2.10.0", "serde", "serde_spanned", "toml_datetime", "winnow", ] -[[package]] -name = "tonic" -version = "0.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52" -dependencies = [ - "async-stream", - "async-trait", - "axum 0.7.9", - "base64 0.22.1", - "bytes", - "h2 0.4.4", - "http 1.1.0", - "http-body 1.0.0", - "http-body-util", - "hyper 1.6.0", - "hyper-timeout", - "hyper-util", - "percent-encoding", - "pin-project", - "prost 0.13.5", - "socket2", - "tokio", - "tokio-stream", - "tower 0.4.13", - "tower-layer", - "tower-service", - "tracing", -] - [[package]] name = "tonic" version = "0.13.1" @@ -7912,15 +7877,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7e581ba15a835f4d9ea06c55ab1bd4dce26fc53752c69a04aac00703bfb49ba9" dependencies = [ "async-trait", - "axum 0.8.1", + "axum", "base64 0.22.1", "bytes", "flate2", "h2 0.4.4", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", - "hyper 1.6.0", + "hyper 1.4.1", "hyper-timeout", "hyper-util", "percent-encoding", @@ -7962,7 +7927,7 @@ dependencies = [ "prost-types 0.13.5", "tokio", "tokio-stream", - "tonic 0.13.1", + "tonic", ] [[package]] @@ -7973,16 +7938,11 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" dependencies = [ "futures-core", "futures-util", - "indexmap 1.9.3", "pin-project", "pin-project-lite", - "rand 0.8.5", - "slab", "tokio", - "tokio-util", "tower-layer", "tower-service", - "tracing", ] [[package]] @@ -7993,7 +7953,7 @@ checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" dependencies = [ "futures-core", "futures-util", - "indexmap 2.9.0", + "indexmap 2.10.0", "pin-project-lite", "slab", "sync_wrapper 1.0.1", @@ -8013,7 +7973,7 @@ dependencies = [ "base64 0.22.1", "bitflags 2.8.0", "bytes", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "mime", "pin-project-lite", @@ -8031,10 +7991,14 @@ checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" [[package]] name = "tower-otel" -version = "0.2.0" -source = "git+https://github.com/mattiapenati/tower-otel?rev=56a7321053bcb72443888257b622ba0d43a11fcd#56a7321053bcb72443888257b622ba0d43a11fcd" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "345000ea5ae33222624a8ccfdd88892c30db4d413a39c2d4bd714b77e0a4b23c" dependencies = [ - "http 1.1.0", + "axum", + "cfg-if", + "http 1.3.1", + "http-body 1.0.0", 
"opentelemetry", "pin-project", "tower-layer", @@ -8116,9 +8080,9 @@ dependencies = [ [[package]] name = "tracing-opentelemetry" -version = "0.28.0" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97a971f6058498b5c0f1affa23e7ea202057a7301dbff68e968b2d578bcbd053" +checksum = "ddcf5959f39507d0d04d6413119c04f33b623f4f951ebcbdddddfad2d0623a9c" dependencies = [ "js-sys", "once_cell", @@ -8215,7 +8179,7 @@ dependencies = [ "byteorder", "bytes", "data-encoding", - "http 1.1.0", + "http 1.3.1", "httparse", "log", "rand 0.8.5", @@ -8234,7 +8198,7 @@ dependencies = [ "byteorder", "bytes", "data-encoding", - "http 1.1.0", + "http 1.3.1", "httparse", "log", "rand 0.8.5", @@ -8335,7 +8299,7 @@ dependencies = [ "base64 0.22.1", "log", "once_cell", - "rustls 0.23.27", + "rustls 0.23.29", "rustls-pki-types", "url", "webpki-roots", @@ -8425,7 +8389,7 @@ dependencies = [ "postgres_connection", "pprof", "pq_proto", - "rand 0.8.5", + "rand 0.9.1", "regex", "scopeguard", "sentry", @@ -8476,7 +8440,7 @@ name = "vm_monitor" version = "0.1.0" dependencies = [ "anyhow", - "axum 0.8.1", + "axum", "cgroups-rs", "clap", "futures", @@ -8970,8 +8934,8 @@ dependencies = [ "ahash", "anstream", "anyhow", - "axum 0.8.1", - "axum-core 0.5.0", + "axum", + "axum-core", "base64 0.21.7", "base64ct", "bytes", @@ -9005,9 +8969,9 @@ dependencies = [ "hex", "hmac", "hyper 0.14.30", - "hyper 1.6.0", + "hyper 1.4.1", "hyper-util", - "indexmap 2.9.0", + "indexmap 2.10.0", "itertools 0.12.1", "lazy_static", "libc", @@ -9030,14 +8994,14 @@ dependencies = [ "proc-macro2", "prost 0.13.5", "quote", - "rand 0.8.5", + "rand 0.9.1", "regex", - "regex-automata 0.4.3", - "regex-syntax 0.8.2", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", "reqwest", - "rustls 0.23.27", + "rustls 0.23.29", "rustls-pki-types", - "rustls-webpki 0.103.3", + "rustls-webpki 0.103.4", "scopeguard", "sec1 0.7.3", "serde", @@ -9050,6 +9014,7 @@ dependencies = [ "subtle", "syn 2.0.100", "sync_wrapper 0.1.2", + "thiserror 2.0.11", "tikv-jemalloc-ctl", "tikv-jemalloc-sys", "time", @@ -9059,6 +9024,7 @@ dependencies = [ "tokio-stream", "tokio-util", "toml_edit", + "tonic", "tower 0.5.2", "tracing", "tracing-core", @@ -9135,6 +9101,12 @@ version = "0.8.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fdd20c5420375476fbd4394763288da7eb0cc0b8c11deed431a91562af7335d3" +[[package]] +name = "yansi" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" + [[package]] name = "yasna" version = "0.5.2" diff --git a/Cargo.toml b/Cargo.toml index 76a1a57aa9..fba334d614 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,6 +50,7 @@ members = [ "libs/proxy/tokio-postgres2", "endpoint_storage", "pgxn/neon/communicator", + "proxy/subzero_core", ] [workspace.package] @@ -144,10 +145,10 @@ notify = "6.0.0" num_cpus = "1.15" num-traits = "0.2.19" once_cell = "1.13" -opentelemetry = "0.27" -opentelemetry_sdk = "0.27" -opentelemetry-otlp = { version = "0.27", default-features = false, features = ["http-proto", "trace", "http", "reqwest-client"] } -opentelemetry-semantic-conventions = "0.27" +opentelemetry = "0.30" +opentelemetry_sdk = "0.30" +opentelemetry-otlp = { version = "0.30", default-features = false, features = ["http-proto", "trace", "http", "reqwest-client"] } +opentelemetry-semantic-conventions = "0.30" parking_lot = "0.12" parquet = { version = "53", default-features = false, features = ["zstd"] 
} parquet_derive = "53" @@ -160,11 +161,13 @@ procfs = "0.16" prometheus = {version = "0.13", default-features=false, features = ["process"]} # removes protobuf dependency prost = "0.13.5" prost-types = "0.13.5" -rand = "0.8" +rand = "0.9" +# Remove after p256 is updated to 0.14. +rand_core = "=0.6" redis = { version = "0.29.2", features = ["tokio-rustls-comp", "keep-alive"] } regex = "1.10.2" reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] } -reqwest-tracing = { version = "0.5", features = ["opentelemetry_0_27"] } +reqwest-tracing = { version = "0.5", features = ["opentelemetry_0_30"] } reqwest-middleware = "0.4" reqwest-retry = "0.7" routerify = "3" @@ -214,15 +217,12 @@ tonic = { version = "0.13.1", default-features = false, features = ["channel", " tonic-reflection = { version = "0.13.1", features = ["server"] } tower = { version = "0.5.2", default-features = false } tower-http = { version = "0.6.2", features = ["auth", "request-id", "trace"] } - -# This revision uses opentelemetry 0.27. There's no tag for it. -tower-otel = { git = "https://github.com/mattiapenati/tower-otel", rev = "56a7321053bcb72443888257b622ba0d43a11fcd" } - +tower-otel = { version = "0.6", features = ["axum"] } tower-service = "0.3.3" tracing = "0.1" tracing-error = "0.2" tracing-log = "0.2" -tracing-opentelemetry = "0.28" +tracing-opentelemetry = "0.31" tracing-serde = "0.2.0" tracing-subscriber = { version = "0.3", default-features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] } try-lock = "0.2.5" diff --git a/Dockerfile b/Dockerfile index 55b87d4012..654ae72e56 100644 --- a/Dockerfile +++ b/Dockerfile @@ -63,7 +63,14 @@ WORKDIR /home/nonroot COPY --chown=nonroot . . -RUN cargo chef prepare --recipe-path recipe.json +RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \ + set -e \ + && if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \ + export CARGO_NET_GIT_FETCH_WITH_CLI=true && \ + git config --global url."https://$(cat /run/secrets/SUBZERO_ACCESS_TOKEN)@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero" && \ + cargo add -p proxy subzero-core --git https://github.com/neondatabase/subzero --rev 396264617e78e8be428682f87469bb25429af88a; \ + fi \ + && cargo chef prepare --recipe-path recipe.json # Main build image FROM $REPOSITORY/$IMAGE:$TAG AS build @@ -71,20 +78,33 @@ WORKDIR /home/nonroot ARG GIT_VERSION=local ARG BUILD_TAG ARG ADDITIONAL_RUSTFLAGS="" +ENV CARGO_FEATURES="default" # 3. Build cargo dependencies. Note that this step doesn't depend on anything else than # `recipe.json`, so the layer can be reused as long as none of the dependencies change. COPY --from=plan /home/nonroot/recipe.json recipe.json -RUN set -e \ +RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \ + set -e \ + && if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \ + export CARGO_NET_GIT_FETCH_WITH_CLI=true && \ + git config --global url."https://$(cat /run/secrets/SUBZERO_ACCESS_TOKEN)@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero"; \ + fi \ && RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo chef cook --locked --release --recipe-path recipe.json # Perform the main build. We reuse the Postgres build artifacts from the intermediate 'pg-build' # layer, and the cargo dependencies built in the previous step. 
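The `rand = "0.8"` → `"0.9"` bump in the `Cargo.toml` hunk above is what drives the many `rand 0.8.5` → `rand 0.9.1` edits in `Cargo.lock`, and the pinned `rand_core = "=0.6"` exists only for `p256` compatibility, per the in-diff comment. For orientation, a minimal sketch of the call-site renames rand 0.9 requires (illustrative, not taken from this diff):

```rust
use rand::Rng;

fn main() {
    // rand 0.8 spelled these rand::thread_rng(), rng.gen::<u32>() and
    // rng.gen_range(0..10); rand 0.9 renamed them as follows.
    let mut rng = rand::rng();
    let n: u32 = rng.random();
    let d = rng.random_range(0..10);
    println!("{n} {d}");
}
```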
COPY --chown=nonroot --from=pg-build /home/nonroot/pg_install/ pg_install
COPY --chown=nonroot . .
+COPY --chown=nonroot --from=plan /home/nonroot/proxy/Cargo.toml proxy/Cargo.toml
+COPY --chown=nonroot --from=plan /home/nonroot/Cargo.lock Cargo.lock

-RUN set -e \
+RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \
+    set -e \
+    && if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \
+        export CARGO_FEATURES="rest_broker"; \
+    fi \
     && RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo build \
+    --features $CARGO_FEATURES \
     --bin pg_sni_router \
     --bin pageserver \
     --bin pagectl \
diff --git a/compute_tools/Cargo.toml b/compute_tools/Cargo.toml
index 910bae3bda..496471acc7 100644
--- a/compute_tools/Cargo.toml
+++ b/compute_tools/Cargo.toml
@@ -27,7 +27,10 @@ fail.workspace = true
 flate2.workspace = true
 futures.workspace = true
 http.workspace = true
+http-body-util.workspace = true
 hostname-validator = "1.1"
+hyper.workspace = true
+hyper-util.workspace = true
 indexmap.workspace = true
 itertools.workspace = true
 jsonwebtoken.workspace = true
@@ -44,6 +47,7 @@ postgres.workspace = true
 regex.workspace = true
 reqwest = { workspace = true, features = ["json"] }
 ring = "0.17"
+scopeguard.workspace = true
 serde.workspace = true
 serde_with.workspace = true
 serde_json.workspace = true
diff --git a/compute_tools/src/bin/compute_ctl.rs b/compute_tools/src/bin/compute_ctl.rs
index 78e2c6308f..04723d6f3d 100644
--- a/compute_tools/src/bin/compute_ctl.rs
+++ b/compute_tools/src/bin/compute_ctl.rs
@@ -138,6 +138,12 @@ struct Cli {
     /// Run in development mode, skipping VM-specific operations like process termination
     #[arg(long, action = clap::ArgAction::SetTrue)]
     pub dev: bool,
+
+    #[arg(long)]
+    pub pg_init_timeout: Option<u64>,
+
+    #[arg(long, default_value_t = false, action = clap::ArgAction::Set)]
+    pub lakebase_mode: bool,
 }
 
 impl Cli {
@@ -188,7 +194,7 @@ fn main() -> Result<()> {
         .build()?;
     let _rt_guard = runtime.enter();
 
-    runtime.block_on(init(cli.dev))?;
+    let tracing_provider = init(cli.dev)?;
 
     // enable core dumping for all child processes
     setrlimit(Resource::CORE, rlimit::INFINITY, rlimit::INFINITY)?;
@@ -219,6 +225,8 @@
             installed_extensions_collection_interval: Arc::new(AtomicU64::new(
                 cli.installed_extensions_collection_interval,
             )),
+            pg_init_timeout: cli.pg_init_timeout.map(Duration::from_secs),
+            lakebase_mode: cli.lakebase_mode,
         },
         config,
     )?;
@@ -227,11 +235,11 @@
 
     scenario.teardown();
 
-    deinit_and_exit(exit_code);
+    deinit_and_exit(tracing_provider, exit_code);
 }
 
-async fn init(dev_mode: bool) -> Result<()> {
-    init_tracing_and_logging(DEFAULT_LOG_LEVEL).await?;
+fn init(dev_mode: bool) -> Result<Option<SdkTracerProvider>> {
+    let provider = init_tracing_and_logging(DEFAULT_LOG_LEVEL)?;
 
     let mut signals = Signals::new([SIGINT, SIGTERM, SIGQUIT])?;
     thread::spawn(move || {
@@ -242,7 +250,7 @@
 
     info!("compute build_tag: {}", &BUILD_TAG.to_string());
 
-    Ok(())
+    Ok(provider)
 }
 
 fn get_config(cli: &Cli) -> Result<ComputeConfig> {
@@ -267,25 +275,27 @@
     }
 }
 
-fn deinit_and_exit(exit_code: Option<i32>) -> ! {
-    // Shutdown trace pipeline gracefully, so that it has a chance to send any
-    // pending traces before we exit. Shutting down OTEL tracing provider may
-    // hang for quite some time, see, for example:
-    // - https://github.com/open-telemetry/opentelemetry-rust/issues/868
-    // - and our problems with staging https://github.com/neondatabase/cloud/issues/3707#issuecomment-1493983636
-    //
-    // Yet, we want computes to shut down fast enough, as we may need a new one
-    // for the same timeline ASAP. So wait no longer than 2s for the shutdown to
-    // complete, then just error out and exit the main thread.
-    info!("shutting down tracing");
-    let (sender, receiver) = mpsc::channel();
-    let _ = thread::spawn(move || {
-        tracing_utils::shutdown_tracing();
-        sender.send(()).ok()
-    });
-    let shutdown_res = receiver.recv_timeout(Duration::from_millis(2000));
-    if shutdown_res.is_err() {
-        error!("timed out while shutting down tracing, exiting anyway");
+fn deinit_and_exit(tracing_provider: Option<SdkTracerProvider>, exit_code: Option<i32>) -> ! {
+    if let Some(p) = tracing_provider {
+        // Shutdown trace pipeline gracefully, so that it has a chance to send any
+        // pending traces before we exit. Shutting down OTEL tracing provider may
+        // hang for quite some time, see, for example:
+        // - https://github.com/open-telemetry/opentelemetry-rust/issues/868
+        // - and our problems with staging https://github.com/neondatabase/cloud/issues/3707#issuecomment-1493983636
+        //
+        // Yet, we want computes to shut down fast enough, as we may need a new one
+        // for the same timeline ASAP. So wait no longer than 2s for the shutdown to
+        // complete, then just error out and exit the main thread.
+        info!("shutting down tracing");
+        let (sender, receiver) = mpsc::channel();
+        let _ = thread::spawn(move || {
+            _ = p.shutdown();
+            sender.send(()).ok()
+        });
+        let shutdown_res = receiver.recv_timeout(Duration::from_millis(2000));
+        if shutdown_res.is_err() {
+            error!("timed out while shutting down tracing, exiting anyway");
+        }
     }
 
     info!("shutting down");
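Returning the tracer provider from `init()` and handing it to `deinit_and_exit()` matches the newer `opentelemetry_sdk` design: there is no global shutdown call anymore, so the provider handle itself must be kept and shut down explicitly to flush pending spans. A minimal sketch of the pattern, assuming `opentelemetry_sdk` 0.30 and not the exact `tracing_utils` wiring:

```rust
use opentelemetry_sdk::trace::SdkTracerProvider;

fn main() {
    // Real code would attach an OTLP exporter before build().
    let provider = SdkTracerProvider::builder().build();

    // ... application runs; spans are exported in the background ...

    // Explicit shutdown flushes anything still buffered.
    if let Err(err) = provider.shutdown() {
        eprintln!("tracer provider shutdown failed: {err}");
    }
}
```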
diff --git a/compute_tools/src/communicator_socket_client.rs b/compute_tools/src/communicator_socket_client.rs
new file mode 100644
index 0000000000..806e0a21e3
--- /dev/null
+++ b/compute_tools/src/communicator_socket_client.rs
@@ -0,0 +1,98 @@
+//! Client for making requests to a running Postgres server's communicator control socket.
+//!
+//! The storage communicator process that runs inside Postgres exposes an HTTP endpoint in
+//! a Unix Domain Socket in the Postgres data directory. This provides access to it.
+
+use std::path::Path;
+
+use anyhow::Context;
+use hyper::client::conn::http1::SendRequest;
+use hyper_util::rt::TokioIo;
+
+/// Name of the socket within the Postgres data directory. This better match that in
+/// `pgxn/neon/communicator/src/lib.rs`.
+const NEON_COMMUNICATOR_SOCKET_NAME: &str = "neon-communicator.socket";
+
+/// Open a connection to the communicator's control socket, prepare to send requests to it
+/// with hyper.
+pub async fn connect_communicator_socket<B>(pgdata: &Path) -> anyhow::Result<SendRequest<B>>
+where
+    B: hyper::body::Body + 'static + Send,
+    B::Data: Send,
+    B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
+{
+    let socket_path = pgdata.join(NEON_COMMUNICATOR_SOCKET_NAME);
+    let socket_path_len = socket_path.display().to_string().len();
+
+    // There is a limit of around 100 bytes (108 on Linux?) on the length of the path to a
+    // Unix Domain socket. The limit is on the connect(2) function used to open the
+    // socket, not on the absolute path itself. Postgres changes the current directory to
+    // the data directory and uses a relative path to bind to the socket, and the relative
+    // path "./neon-communicator.socket" is always short, but when compute_ctl needs to
+    // open the socket, we need to use a full path, which can be arbitrarily long.
+    //
+    // There are a few ways we could work around this:
+    //
+    // 1. Change the current directory to the Postgres data directory and use a relative
+    //    path in the connect(2) call. That's problematic because the current directory
+    //    applies to the whole process. We could change the current directory early in
+    //    compute_ctl startup, and that might be a good idea anyway for other reasons too:
+    //    it would be more robust if the data directory is moved around or unlinked for
+    //    some reason, and you would be less likely to accidentally litter other parts of
+    //    the filesystem with e.g. temporary files. However, that's a pretty invasive
+    //    change.
+    //
+    // 2. On Linux, you could open() the data directory, and refer to the socket
+    //    inside it as "/proc/self/fd/<fd>/neon-communicator.socket". But that's
+    //    Linux-only.
+    //
+    // 3. Create a symbolic link to the socket with a shorter path, and use that.
+    //
+    // We use the symbolic link approach here. Hopefully the paths we use in production
+    // are shorter, so that we can open the socket directly, so that this hack is needed
+    // only in development.
+    let connect_result = if socket_path_len < 100 {
+        // We can open the path directly with no hacks.
+        tokio::net::UnixStream::connect(socket_path).await
+    } else {
+        // The path to the socket is too long. Create a symlink to it with a shorter path.
+        let short_path = std::env::temp_dir().join(format!(
+            "compute_ctl.short-socket.{}.{}",
+            std::process::id(),
+            tokio::task::id()
+        ));
+        std::os::unix::fs::symlink(&socket_path, &short_path)?;
+
+        // Delete the symlink as soon as we have connected to it. There's a small chance
+        // of leaking if the process dies before we remove it, so try to keep that window
+        // as small as possible.
+        scopeguard::defer! {
+            if let Err(err) = std::fs::remove_file(&short_path) {
+                tracing::warn!("could not remove symlink \"{}\" created for socket: {}",
+                               short_path.display(), err);
+            }
+        }
+
+        tracing::info!(
+            "created symlink \"{}\" for socket \"{}\", opening it now",
+            short_path.display(),
+            socket_path.display()
+        );
+
+        tokio::net::UnixStream::connect(&short_path).await
+    };
+
+    let stream = connect_result.context("connecting to communicator control socket")?;
+
+    let io = TokioIo::new(stream);
+    let (request_sender, connection) = hyper::client::conn::http1::handshake(io).await?;
+
+    // spawn a task to poll the connection and drive the HTTP state
+    tokio::spawn(async move {
+        if let Err(err) = connection.await {
+            eprintln!("Error in connection: {err}");
+        }
+    });
+
+    Ok(request_sender)
+}
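The comment above lists a Linux-only alternative (option 2) that the module deliberately does not use. For comparison, it would look roughly like this sketch, where the path stays valid only while the directory file descriptor remains open:

```rust
use std::fs::File;
use std::os::fd::AsRawFd;
use std::path::Path;

// Sketch of workaround (2) from the comment above; Linux-only and
// intentionally not what connect_communicator_socket() does.
fn proc_fd_socket_path(pgdata: &Path) -> std::io::Result<(File, String)> {
    let dir = File::open(pgdata)?; // must stay open while the path is used
    let path = format!("/proc/self/fd/{}/neon-communicator.socket", dir.as_raw_fd());
    Ok((dir, path))
}
```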
diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs
index d7ec37cc0a..dac17cf6c9 100644
--- a/compute_tools/src/compute.rs
+++ b/compute_tools/src/compute.rs
@@ -114,6 +114,11 @@ pub struct ComputeNodeParams {
 
     /// Interval for installed extensions collection
     pub installed_extensions_collection_interval: Arc<AtomicU64>,
+
+    /// Timeout of PG compute startup in the Init state.
+    pub pg_init_timeout: Option<Duration>,
+
+    pub lakebase_mode: bool,
 }
 
 type TaskHandle = Mutex<Option<JoinHandle<()>>>;
@@ -155,6 +160,7 @@ pub struct RemoteExtensionMetrics {
 #[derive(Clone, Debug)]
 pub struct ComputeState {
     pub start_time: DateTime<Utc>,
+    pub pg_start_time: Option<DateTime<Utc>>,
     pub status: ComputeStatus,
     /// Timestamp of the last Postgres activity. It could be `None` if
     /// compute wasn't used since start.
@@ -192,6 +198,7 @@ impl ComputeState {
     pub fn new() -> Self {
         Self {
             start_time: Utc::now(),
+            pg_start_time: None,
             status: ComputeStatus::Empty,
             last_active: None,
             error: None,
@@ -737,6 +744,9 @@
             };
             _this_entered = start_compute_span.enter();
 
+            // Hadron: Record postgres start time (used to enforce pg_init_timeout).
+            state_guard.pg_start_time.replace(Utc::now());
+
             state_guard.set_status(ComputeStatus::Init, &self.state_changed);
             compute_state = state_guard.clone()
         }
@@ -1544,7 +1554,7 @@
             .with_context(|| format!("failed to get basebackup@{lsn}"))?;
 
         // Update pg_hba.conf received with basebackup.
-        update_pg_hba(pgdata_path)?;
+        update_pg_hba(pgdata_path, None)?;
 
         // Place pg_dynshmem under /dev/shm. This allows us to use
        // 'dynamic_shared_memory_type = mmap' so that the files are placed in
@@ -1849,6 +1859,7 @@
         }
 
         // Run migrations separately to not hold up cold starts
+        let lakebase_mode = self.params.lakebase_mode;
        let params = self.params.clone();
         tokio::spawn(async move {
             let mut conf = conf.as_ref().clone();
@@ -1861,7 +1872,7 @@
                     eprintln!("connection error: {e}");
                 }
             });
-            if let Err(e) = handle_migrations(params, &mut client).await {
+            if let Err(e) = handle_migrations(params, &mut client, lakebase_mode).await {
                 error!("Failed to run migrations: {}", e);
             }
         }
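Since clap derives kebab-case long options from the field names added to `Cli` above, the new knobs surface on the command line roughly as follows (invocation is illustrative, not from this diff):

```
compute_ctl ... --pg-init-timeout 60 --lakebase-mode true
```

`--pg-init-timeout` is in seconds (`main()` maps it through `Duration::from_secs`), and `--lakebase-mode` takes an explicit boolean value because it is declared with `ArgAction::Set`.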
diff --git a/compute_tools/src/http/routes/metrics.rs b/compute_tools/src/http/routes/metrics.rs
index da8d8b20a5..96b464fd12 100644
--- a/compute_tools/src/http/routes/metrics.rs
+++ b/compute_tools/src/http/routes/metrics.rs
@@ -1,10 +1,18 @@
+use std::path::Path;
+use std::sync::Arc;
+
+use anyhow::Context;
 use axum::body::Body;
+use axum::extract::State;
 use axum::response::Response;
-use http::StatusCode;
 use http::header::CONTENT_TYPE;
+use http_body_util::BodyExt;
+use hyper::{Request, StatusCode};
 use metrics::proto::MetricFamily;
 use metrics::{Encoder, TextEncoder};
 
+use crate::communicator_socket_client::connect_communicator_socket;
+use crate::compute::ComputeNode;
 use crate::http::JsonResponse;
 use crate::metrics::collect;
 
@@ -31,3 +39,42 @@
         .body(Body::from(buffer))
         .unwrap()
 }
+
+/// Fetch and forward metrics from the Postgres neon extension's metrics
+/// exporter that are used by autoscaling-agent.
+///
+/// The neon extension exposes these metrics over a Unix domain socket
+/// in the data directory. That's not accessible directly from the outside
+/// world, so we have this endpoint in compute_ctl to expose it.
+pub(in crate::http) async fn get_autoscaling_metrics(
+    State(compute): State<Arc<ComputeNode>>,
+) -> Result<Response, Response> {
+    let pgdata = Path::new(&compute.params.pgdata);
+
+    // Connect to the communicator process's metrics socket
+    let mut metrics_client = connect_communicator_socket(pgdata)
+        .await
+        .map_err(|e| JsonResponse::error(StatusCode::INTERNAL_SERVER_ERROR, format!("{e:#}")))?;
+
+    // Make a request for /autoscaling_metrics
+    let request = Request::builder()
+        .method("GET")
+        .uri("/autoscaling_metrics")
+        .header("Host", "localhost") // hyper requires Host, even though the server won't care
+        .body(Body::from(""))
+        .unwrap();
+    let resp = metrics_client
+        .send_request(request)
+        .await
+        .context("fetching metrics from Postgres metrics service")
+        .map_err(|e| JsonResponse::error(StatusCode::INTERNAL_SERVER_ERROR, format!("{e:#}")))?;
+
+    // Build a response that just forwards the response we got.
+    let mut response = Response::builder();
+    response = response.status(resp.status());
+    if let Some(content_type) = resp.headers().get(CONTENT_TYPE) {
+        response = response.header(CONTENT_TYPE, content_type);
+    }
+    let body = tonic::service::AxumBody::from_stream(resp.into_body().into_data_stream());
+    Ok(response.body(body).unwrap())
+}
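With the route registered in the `server.rs` hunk below, the forwarded metrics can be probed like any other unauthenticated compute_ctl endpoint. A hypothetical client-side check (host and port are assumptions, not taken from this diff):

```rust
#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    // Assumes compute_ctl's external HTTP server is reachable on localhost:3080.
    let metrics = reqwest::get("http://localhost:3080/autoscaling_metrics")
        .await?
        .text()
        .await?;
    print!("{metrics}");
    Ok(())
}
```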
diff --git a/compute_tools/src/http/server.rs b/compute_tools/src/http/server.rs
index 17939e39d4..f0fbca8263 100644
--- a/compute_tools/src/http/server.rs
+++ b/compute_tools/src/http/server.rs
@@ -81,8 +81,12 @@ impl From<&Server> for Router<Arc<ComputeNode>> {
             Server::External {
                 config, compute_id, ..
             } => {
-                let unauthenticated_router =
-                    Router::<Arc<ComputeNode>>::new().route("/metrics", get(metrics::get_metrics));
+                let unauthenticated_router = Router::<Arc<ComputeNode>>::new()
+                    .route("/metrics", get(metrics::get_metrics))
+                    .route(
+                        "/autoscaling_metrics",
+                        get(metrics::get_autoscaling_metrics),
+                    );
 
                 let authenticated_router = Router::<Arc<ComputeNode>>::new()
                     .route("/lfc/prewarm", get(lfc::prewarm_state).post(lfc::prewarm))
diff --git a/compute_tools/src/lib.rs b/compute_tools/src/lib.rs
index 2d5d4565b7..4d0a7dca05 100644
--- a/compute_tools/src/lib.rs
+++ b/compute_tools/src/lib.rs
@@ -4,6 +4,7 @@
 #![deny(clippy::undocumented_unsafe_blocks)]
 
 pub mod checker;
+pub mod communicator_socket_client;
 pub mod config;
 pub mod configurator;
 pub mod http;
diff --git a/compute_tools/src/logger.rs b/compute_tools/src/logger.rs
index c36f302f99..cd076472a6 100644
--- a/compute_tools/src/logger.rs
+++ b/compute_tools/src/logger.rs
@@ -13,7 +13,9 @@ use tracing_subscriber::prelude::*;
 /// set `OTEL_EXPORTER_OTLP_ENDPOINT=http://jaeger:4318`. See
 /// `tracing-utils` package description.
 ///
-pub async fn init_tracing_and_logging(default_log_level: &str) -> anyhow::Result<()> {
+pub fn init_tracing_and_logging(
+    default_log_level: &str,
+) -> anyhow::Result<Option<SdkTracerProvider>> {
     // Initialize Logging
     let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
         .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(default_log_level));
@@ -24,8 +26,9 @@
         .with_writer(std::io::stderr);
 
     // Initialize OpenTelemetry
-    let otlp_layer =
-        tracing_utils::init_tracing("compute_ctl", tracing_utils::ExportConfig::default()).await;
+    let provider =
+        tracing_utils::init_tracing("compute_ctl", tracing_utils::ExportConfig::default());
+    let otlp_layer = provider.as_ref().map(tracing_utils::layer);
 
     // Put it all together
     tracing_subscriber::registry()
@@ -37,7 +40,7 @@
 
     utils::logging::replace_panic_hook_with_tracing_panic_hook().forget();
 
-    Ok(())
+    Ok(provider)
 }
 
 /// Replace all newline characters with a special character to make it
///
-pub async fn init_tracing_and_logging(default_log_level: &str) -> anyhow::Result<()> {
+pub fn init_tracing_and_logging(
+    default_log_level: &str,
+) -> anyhow::Result<Option<tracing_utils::Provider>> {
     // Initialize Logging
     let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
         .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(default_log_level));
@@ -24,8 +26,9 @@
         .with_writer(std::io::stderr);
 
     // Initialize OpenTelemetry
-    let otlp_layer =
-        tracing_utils::init_tracing("compute_ctl", tracing_utils::ExportConfig::default()).await;
+    let provider =
+        tracing_utils::init_tracing("compute_ctl", tracing_utils::ExportConfig::default());
+    let otlp_layer = provider.as_ref().map(tracing_utils::layer);
 
     // Put it all together
     tracing_subscriber::registry()
@@ -37,7 +40,7 @@
     utils::logging::replace_panic_hook_with_tracing_panic_hook().forget();
 
-    Ok(())
+    Ok(provider)
 }
 
 /// Replace all newline characters with a special character to make it
diff --git a/compute_tools/src/migration.rs b/compute_tools/src/migration.rs
index c5e05822c0..88d870df97 100644
--- a/compute_tools/src/migration.rs
+++ b/compute_tools/src/migration.rs
@@ -9,15 +9,20 @@ use crate::metrics::DB_MIGRATION_FAILED;
 
 pub(crate) struct MigrationRunner<'m> {
     client: &'m mut Client,
     migrations: &'m [&'m str],
+    lakebase_mode: bool,
 }
 
 impl<'m> MigrationRunner<'m> {
     /// Create a new migration runner
-    pub fn new(client: &'m mut Client, migrations: &'m [&'m str]) -> Self {
+    pub fn new(client: &'m mut Client, migrations: &'m [&'m str], lakebase_mode: bool) -> Self {
         // The neon_migration.migration_id::id column is a bigint, which is equivalent to an i64
         assert!(migrations.len() + 1 < i64::MAX as usize);
 
-        Self { client, migrations }
+        Self {
+            client,
+            migrations,
+            lakebase_mode,
+        }
     }
 
     /// Get the current value neon_migration.migration_id
@@ -130,8 +135,13 @@
             // ID is also the next index
             let migration_id = (current_migration + 1) as i64;
             let migration = self.migrations[current_migration];
+            let migration = if self.lakebase_mode {
+                migration.replace("neon_superuser", "databricks_superuser")
+            } else {
+                migration.to_string()
+            };
 
-            match Self::run_migration(self.client, migration_id, migration).await {
+            match Self::run_migration(self.client, migration_id, &migration).await {
                 Ok(_) => {
                     info!("Finished migration id={}", migration_id);
                 }
diff --git a/compute_tools/src/monitor.rs b/compute_tools/src/monitor.rs
index fa01545856..e164f15dba 100644
--- a/compute_tools/src/monitor.rs
+++ b/compute_tools/src/monitor.rs
@@ -11,6 +11,7 @@ use tracing::{Level, error, info, instrument, span};
 use crate::compute::ComputeNode;
 use crate::metrics::{PG_CURR_DOWNTIME_MS, PG_TOTAL_DOWNTIME_MS};
 
+const PG_DEFAULT_INIT_TIMEOUT: Duration = Duration::from_secs(60);
 const MONITOR_CHECK_INTERVAL: Duration = Duration::from_millis(500);
 
 /// Struct to store runtime state of the compute monitor thread.
@@ -352,13 +353,47 @@ impl ComputeMonitor {
 // Hang on condition variable waiting until the compute status is `Running`.
 fn wait_for_postgres_start(compute: &ComputeNode) {
     let mut state = compute.state.lock().unwrap();
+    let pg_init_timeout = compute
+        .params
+        .pg_init_timeout
+        .unwrap_or(PG_DEFAULT_INIT_TIMEOUT);
+
     while state.status != ComputeStatus::Running {
         info!("compute is not running, waiting before monitoring activity");
-        state = compute.state_changed.wait(state).unwrap();
+        if !compute.params.lakebase_mode {
+            state = compute.state_changed.wait(state).unwrap();
 
-        if state.status == ComputeStatus::Running {
-            break;
+            if state.status == ComputeStatus::Running {
+                break;
+            }
+            continue;
         }
+
+        if state.pg_start_time.is_some()
+            && Utc::now()
+                .signed_duration_since(state.pg_start_time.unwrap())
+                .to_std()
+                .unwrap_or_default()
+                > pg_init_timeout
+        {
+            // If Postgres isn't up and running with working PS/SK connections within pg_init_timeout,
+            // it is possible that we started Postgres with a wrong spec (so it is talking to the wrong
+            // PS/SK nodes). To prevent dead ends we simply exit (panic) the compute node so it can
+            // restart with the latest spec.
+            //
+            // NB: We skip this check if we have not attempted to start PG yet (indicated by
+            // state.pg_start_time == None). This is to make sure the more appropriate errors are
+            // surfaced if we encounter issues before we even attempt to start PG (e.g., if we can't
+            // pull the spec, can't sync safekeepers, or can't get the basebackup).
+            error!(
+                "compute did not enter Running state in {} seconds, exiting",
+                pg_init_timeout.as_secs()
+            );
+            std::process::exit(1);
+        }
+        state = compute
+            .state_changed
+            .wait_timeout(state, Duration::from_secs(5))
+            .unwrap()
+            .0;
     }
 }
diff --git a/compute_tools/src/pg_helpers.rs b/compute_tools/src/pg_helpers.rs
index 0a3ceed2fa..09bbe89b41 100644
--- a/compute_tools/src/pg_helpers.rs
+++ b/compute_tools/src/pg_helpers.rs
@@ -11,7 +11,9 @@ use std::time::{Duration, Instant};
 
 use anyhow::{Result, bail};
 use compute_api::responses::TlsConfig;
-use compute_api::spec::{Database, GenericOption, GenericOptions, PgIdent, Role};
+use compute_api::spec::{
+    Database, DatabricksSettings, GenericOption, GenericOptions, PgIdent, Role,
+};
 use futures::StreamExt;
 use indexmap::IndexMap;
 use ini::Ini;
@@ -184,6 +186,42 @@ impl DatabaseExt for Database {
     }
 }
 
+pub trait DatabricksSettingsExt {
+    fn as_pg_settings(&self) -> String;
+}
+
+impl DatabricksSettingsExt for DatabricksSettings {
+    fn as_pg_settings(&self) -> String {
+        // Postgres GUCs rendered from DatabricksSettings
+        vec![
+            // ssl_ca_file
+            Some(format!(
+                "ssl_ca_file = '{}'",
+                self.pg_compute_tls_settings.ca_file
+            )),
+            // [Optional] databricks.workspace_url
+            Some(format!(
+                "databricks.workspace_url = '{}'",
+                &self.databricks_workspace_host
+            )),
+            // todo(vikas.jain): these are not required anymore as they are moved to static
+            // conf but keeping these to avoid image mismatch between hcc and pg.
+            // Once hcc and pg are in sync, we can remove these.
+            //
+            // databricks.enable_databricks_identity_login
+            Some("databricks.enable_databricks_identity_login = true".to_string()),
+            // databricks.enable_sql_restrictions
+            Some("databricks.enable_sql_restrictions = true".to_string()),
+        ]
+        .into_iter()
+        // Removes `None`s
+        .flatten()
+        .collect::<Vec<String>>()
+        .join("\n")
+            + "\n"
+    }
+}
+
 /// Generic trait used to provide quoting / encoding for strings used in the
 /// Postgres SQL queries and DATABASE_URL.
 pub trait Escaping {
diff --git a/compute_tools/src/spec.rs b/compute_tools/src/spec.rs
index 4525a0e831..d00f86a2c0 100644
--- a/compute_tools/src/spec.rs
+++ b/compute_tools/src/spec.rs
@@ -1,4 +1,6 @@
 use std::fs::File;
+use std::fs::{self, Permissions};
+use std::os::unix::fs::PermissionsExt;
 use std::path::Path;
 
 use anyhow::{Result, anyhow, bail};
@@ -133,10 +135,25 @@ pub fn get_config_from_control_plane(base_uri: &str, compute_id: &str) -> Result
 }
 
 /// Check `pg_hba.conf` and update if needed to allow external connections.
-pub fn update_pg_hba(pgdata_path: &Path) -> Result<()> {
+pub fn update_pg_hba(pgdata_path: &Path, databricks_pg_hba: Option<&String>) -> Result<()> {
     // XXX: consider making it a part of config.json
     let pghba_path = pgdata_path.join("pg_hba.conf");
 
+    // Update pg_hba to contain Databricks-specific settings before adding the neon settings.
+    // PG uses the first record that matches to perform authentication, so we need to have
+    // our rules before the default ones from neon.
+    // See https://www.postgresql.org/docs/16/auth-pg-hba-conf.html
+    if let Some(databricks_pg_hba) = databricks_pg_hba {
+        if config::line_in_file(
+            &pghba_path,
+            &format!("include_if_exists {}\n", *databricks_pg_hba),
+        )? {
+            info!("updated pg_hba.conf to include databricks_pg_hba.conf");
+        } else {
+            info!("pg_hba.conf already included databricks_pg_hba.conf");
+        }
+    }
+
     if config::line_in_file(&pghba_path, PG_HBA_ALL_MD5)? {
         info!("updated pg_hba.conf to allow external connections");
     } else {
@@ -146,6 +163,59 @@
     Ok(())
 }
 
+/// Check `pg_ident.conf` and update if needed to allow databricks config.
+pub fn update_pg_ident(pgdata_path: &Path, databricks_pg_ident: Option<&String>) -> Result<()> {
+    info!("checking pg_ident.conf");
+    let pgident_path = pgdata_path.join("pg_ident.conf");
+
+    // Update pg_ident to contain Databricks-specific settings.
+    if let Some(databricks_pg_ident) = databricks_pg_ident {
+        if config::line_in_file(
+            &pgident_path,
+            &format!("include_if_exists {}\n", *databricks_pg_ident),
+        )? {
+            info!("updated pg_ident.conf to include databricks_pg_ident.conf");
+        } else {
+            info!("pg_ident.conf already included databricks_pg_ident.conf");
+        }
+    }
+
+    Ok(())
+}
+
+/// Copy TLS key_file and cert_file from the k8s secret mount directory
+/// to pgdata and set private key file permissions as expected by Postgres.
+/// See this doc for expected permission
+/// K8s secret mounts on a dblet do not honor the permission and ownership
+/// specified in the Volume or VolumeMount, so we need to explicitly copy the
+/// files and set the permissions.
+pub fn copy_tls_certificates(
+    key_file: &String,
+    cert_file: &String,
+    pgdata_path: &Path,
+) -> Result<()> {
+    let files = [cert_file, key_file];
+    for file in files.iter() {
+        let source = Path::new(file);
+        let dest = pgdata_path.join(source.file_name().unwrap());
+        if !dest.exists() {
+            std::fs::copy(source, &dest)?;
+            info!(
+                "Copying tls file: {} to {}",
+                &source.display(),
+                &dest.display()
+            );
+        }
+        if *file == key_file {
+            // Postgres requires the private key to be readable only by the owner, i.e.
+            // chmod 600 permissions.
+ let permissions = Permissions::from_mode(0o600); + fs::set_permissions(&dest, permissions)?; + info!("Setting permission on {}.", &dest.display()); + } + } + Ok(()) +} + /// Create a standby.signal file pub fn add_standby_signal(pgdata_path: &Path) -> Result<()> { // XXX: consider making it a part of config.json @@ -170,7 +240,11 @@ pub async fn handle_neon_extension_upgrade(client: &mut Client) -> Result<()> { } #[instrument(skip_all)] -pub async fn handle_migrations(params: ComputeNodeParams, client: &mut Client) -> Result<()> { +pub async fn handle_migrations( + params: ComputeNodeParams, + client: &mut Client, + lakebase_mode: bool, +) -> Result<()> { info!("handle migrations"); // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! @@ -234,7 +308,7 @@ pub async fn handle_migrations(params: ComputeNodeParams, client: &mut Client) - ), ]; - MigrationRunner::new(client, &migrations) + MigrationRunner::new(client, &migrations, lakebase_mode) .run_migrations() .await?; diff --git a/compute_tools/src/spec_apply.rs b/compute_tools/src/spec_apply.rs index ec7e75922b..47bf61ae1b 100644 --- a/compute_tools/src/spec_apply.rs +++ b/compute_tools/src/spec_apply.rs @@ -411,7 +411,8 @@ impl ComputeNode { .map(|limit| match limit { 0..10 => limit, 10..30 => 10, - 30.. => limit / 3, + 30..300 => limit / 3, + 300.. => 100, }) // If we didn't find max_connections, default to 10 concurrent connections. .unwrap_or(10) diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index 6da3223024..8cd923fc72 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -411,6 +411,12 @@ struct StorageControllerStartCmdArgs { help = "Base port for the storage controller instance idenfified by instance-id (defaults to pageserver cplane api)" )] base_port: Option, + + #[clap( + long, + help = "Whether the storage controller should handle pageserver-reported local disk loss events." + )] + handle_ps_local_disk_loss: Option, } #[derive(clap::Args)] @@ -1800,6 +1806,7 @@ async fn handle_storage_controller( instance_id: args.instance_id, base_port: args.base_port, start_timeout: args.start_timeout, + handle_ps_local_disk_loss: args.handle_ps_local_disk_loss, }; if let Err(e) = svc.start(start_args).await { diff --git a/control_plane/src/endpoint.rs b/control_plane/src/endpoint.rs index 58a419b965..149ea07a6b 100644 --- a/control_plane/src/endpoint.rs +++ b/control_plane/src/endpoint.rs @@ -728,12 +728,9 @@ impl Endpoint { // For the sake of backwards-compatibility, also fill in 'pageserver_connstring' // - // XXX: I believe this is not really needed, except to make - // test_forward_compatibility happy. - // // Use a closure so that we can conviniently return None in the middle of the // loop. 
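Backing up to the spec_apply.rs hunk above: the new `300..` arm caps spec-apply concurrency at 100 connections instead of letting `limit / 3` grow without bound. A worked sketch of the bucketing (a hypothetical free function; the diff expresses it as a closure over an `Option`):

```rust
/// Concurrency used while applying the spec, derived from max_connections.
fn apply_concurrency(limit: u32) -> u32 {
    match limit {
        0..10 => limit,       // tiny instances: use everything available
        10..30 => 10,         // small: flat cap of 10
        30..300 => limit / 3, // medium: a third of the pool
        300.. => 100,         // large: new hard cap of 100
    }
}

// apply_concurrency(8) == 8, (20) == 10, (90) == 30, (1000) == 100
```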
- let pageserver_connstring = (|| { + let pageserver_connstring: Option = (|| { let num_shards = if args.pageserver_conninfo.shard_count.is_unsharded() { 1 } else { @@ -749,22 +746,24 @@ impl Endpoint { .pageserver_conninfo .shards .get(&shard_index) - .expect(&format!( - "shard {} not found in pageserver_connection_info", - shard_index - )); + .ok_or_else(|| { + anyhow!( + "shard {} not found in pageserver_connection_info", + shard_index + ) + })?; let pageserver = shard .pageservers .first() - .expect("must have at least one pageserver"); + .ok_or(anyhow!("must have at least one pageserver"))?; if let Some(libpq_url) = &pageserver.libpq_url { connstrings.push(libpq_url.clone()); } else { - return None; + return Ok::<_, anyhow::Error>(None); } } - Some(connstrings.join(",")) - })(); + Ok(Some(connstrings.join(","))) + })()?; // Create config file let config = { diff --git a/control_plane/src/storage_controller.rs b/control_plane/src/storage_controller.rs index f996f39967..35a197112e 100644 --- a/control_plane/src/storage_controller.rs +++ b/control_plane/src/storage_controller.rs @@ -56,6 +56,7 @@ pub struct NeonStorageControllerStartArgs { pub instance_id: u8, pub base_port: Option, pub start_timeout: humantime::Duration, + pub handle_ps_local_disk_loss: Option, } impl NeonStorageControllerStartArgs { @@ -64,6 +65,7 @@ impl NeonStorageControllerStartArgs { instance_id: 1, base_port: None, start_timeout, + handle_ps_local_disk_loss: None, } } } @@ -669,6 +671,10 @@ impl StorageController { println!("Starting storage controller at {scheme}://{host}:{listen_port}"); + if start_args.handle_ps_local_disk_loss.unwrap_or_default() { + args.push("--handle-ps-local-disk-loss".to_string()); + } + background_process::start_process( COMMAND, &instance_dir, diff --git a/deny.toml b/deny.toml index be1c6a2f2c..7afd05a837 100644 --- a/deny.toml +++ b/deny.toml @@ -35,6 +35,7 @@ reason = "The paste crate is a build-only dependency with no runtime components. # More documentation for the licenses section can be found here: # https://embarkstudios.github.io/cargo-deny/checks/licenses/cfg.html [licenses] +version = 2 allow = [ "0BSD", "Apache-2.0", diff --git a/endpoint_storage/src/app.rs b/endpoint_storage/src/app.rs index a7a18743ef..64c21cc8b9 100644 --- a/endpoint_storage/src/app.rs +++ b/endpoint_storage/src/app.rs @@ -233,7 +233,7 @@ mod tests { .unwrap() .as_millis(); use rand::Rng; - let random = rand::thread_rng().r#gen::(); + let random = rand::rng().random::(); let s3_config = remote_storage::S3Config { bucket_name: var(REAL_S3_BUCKET).unwrap(), diff --git a/libs/compute_api/src/spec.rs b/libs/compute_api/src/spec.rs index f7ffcd6444..8cfd6b974a 100644 --- a/libs/compute_api/src/spec.rs +++ b/libs/compute_api/src/spec.rs @@ -460,6 +460,32 @@ pub struct GenericOption { pub vartype: String, } +/// Postgres compute TLS settings. +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] +pub struct PgComputeTlsSettings { + // Absolute path to the certificate file for server-side TLS. + pub cert_file: String, + // Absolute path to the private key file for server-side TLS. + pub key_file: String, + // Absolute path to the certificate authority file for verifying client certificates. + pub ca_file: String, +} + +/// Databricks specific options for compute instance. +/// This is used to store any other settings that needs to be propagate to Compute +/// but should not be persisted to ComputeSpec in the database. 
+#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] +pub struct DatabricksSettings { + pub pg_compute_tls_settings: PgComputeTlsSettings, + // Absolute file path to databricks_pg_hba.conf file. + pub databricks_pg_hba: String, + // Absolute file path to databricks_pg_ident.conf file. + pub databricks_pg_ident: String, + // Hostname portion of the Databricks workspace URL of the endpoint, or empty string if not known. + // A valid hostname is required for the compute instance to support PAT logins. + pub databricks_workspace_host: String, +} + /// Optional collection of `GenericOption`'s. Type alias allows us to /// declare a `trait` on it. pub type GenericOptions = Option>; diff --git a/libs/consumption_metrics/src/lib.rs b/libs/consumption_metrics/src/lib.rs index 448134f31a..aeb33bdfc2 100644 --- a/libs/consumption_metrics/src/lib.rs +++ b/libs/consumption_metrics/src/lib.rs @@ -90,7 +90,7 @@ impl<'a> IdempotencyKey<'a> { IdempotencyKey { now: Utc::now(), node_id, - nonce: rand::thread_rng().gen_range(0..=9999), + nonce: rand::rng().random_range(0..=9999), } } diff --git a/libs/desim/src/node_os.rs b/libs/desim/src/node_os.rs index e0cde7b284..6517c2001e 100644 --- a/libs/desim/src/node_os.rs +++ b/libs/desim/src/node_os.rs @@ -41,7 +41,7 @@ impl NodeOs { /// Generate a random number in range [0, max). pub fn random(&self, max: u64) -> u64 { - self.internal.rng.lock().gen_range(0..max) + self.internal.rng.lock().random_range(0..max) } /// Append a new event to the world event log. diff --git a/libs/desim/src/options.rs b/libs/desim/src/options.rs index 9b1a42fd28..d5da008ef1 100644 --- a/libs/desim/src/options.rs +++ b/libs/desim/src/options.rs @@ -32,10 +32,10 @@ impl Delay { /// Generate a random delay in range [min, max]. Return None if the /// message should be dropped. pub fn delay(&self, rng: &mut StdRng) -> Option { - if rng.gen_bool(self.fail_prob) { + if rng.random_bool(self.fail_prob) { return None; } - Some(rng.gen_range(self.min..=self.max)) + Some(rng.random_range(self.min..=self.max)) } } diff --git a/libs/desim/src/world.rs b/libs/desim/src/world.rs index 576ba89cd7..690d45f373 100644 --- a/libs/desim/src/world.rs +++ b/libs/desim/src/world.rs @@ -69,7 +69,7 @@ impl World { /// Create a new random number generator. pub fn new_rng(&self) -> StdRng { let mut rng = self.rng.lock(); - StdRng::from_rng(rng.deref_mut()).unwrap() + StdRng::from_rng(rng.deref_mut()) } /// Create a new node. 
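The desim and endpoint_storage hunks above are part of a rand 0.8 to 0.9 migration that runs through the rest of this diff (`rand::distributions` is now `rand::distr`). The renames, condensed into one sketch:

```rust
use rand::prelude::*;

fn rand_09_tour() {
    let mut rng = rand::rng();                // was rand::thread_rng()
    let _int = rng.random_range(0..100u64);   // was gen_range(0..100)
    let _coin = rng.random_bool(0.5);         // was gen_bool(0.5)
    let _raw: f64 = rng.random();             // was r#gen()
    let _seeded = StdRng::from_rng(&mut rng); // now infallible, no .unwrap()
}
```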
diff --git a/libs/metrics/Cargo.toml b/libs/metrics/Cargo.toml index f87e7b8e3a..1718ddfae2 100644 --- a/libs/metrics/Cargo.toml +++ b/libs/metrics/Cargo.toml @@ -17,5 +17,5 @@ procfs.workspace = true measured-process.workspace = true [dev-dependencies] -rand = "0.8" -rand_distr = "0.4.3" +rand.workspace = true +rand_distr = "0.5" diff --git a/libs/metrics/src/hll.rs b/libs/metrics/src/hll.rs index 1a7d7a7e44..81e5bafbdf 100644 --- a/libs/metrics/src/hll.rs +++ b/libs/metrics/src/hll.rs @@ -260,7 +260,7 @@ mod tests { #[test] fn test_cardinality_small() { - let (actual, estimate) = test_cardinality(100, Zipf::new(100, 1.2f64).unwrap()); + let (actual, estimate) = test_cardinality(100, Zipf::new(100.0, 1.2f64).unwrap()); assert_eq!(actual, [46, 30, 32]); assert!(51.3 < estimate[0] && estimate[0] < 51.4); @@ -270,7 +270,7 @@ mod tests { #[test] fn test_cardinality_medium() { - let (actual, estimate) = test_cardinality(10000, Zipf::new(10000, 1.2f64).unwrap()); + let (actual, estimate) = test_cardinality(10000, Zipf::new(10000.0, 1.2f64).unwrap()); assert_eq!(actual, [2529, 1618, 1629]); assert!(2309.1 < estimate[0] && estimate[0] < 2309.2); @@ -280,7 +280,8 @@ mod tests { #[test] fn test_cardinality_large() { - let (actual, estimate) = test_cardinality(1_000_000, Zipf::new(1_000_000, 1.2f64).unwrap()); + let (actual, estimate) = + test_cardinality(1_000_000, Zipf::new(1_000_000.0, 1.2f64).unwrap()); assert_eq!(actual, [129077, 79579, 79630]); assert!(126067.2 < estimate[0] && estimate[0] < 126067.3); @@ -290,7 +291,7 @@ mod tests { #[test] fn test_cardinality_small2() { - let (actual, estimate) = test_cardinality(100, Zipf::new(200, 0.8f64).unwrap()); + let (actual, estimate) = test_cardinality(100, Zipf::new(200.0, 0.8f64).unwrap()); assert_eq!(actual, [92, 58, 60]); assert!(116.1 < estimate[0] && estimate[0] < 116.2); @@ -300,7 +301,7 @@ mod tests { #[test] fn test_cardinality_medium2() { - let (actual, estimate) = test_cardinality(10000, Zipf::new(20000, 0.8f64).unwrap()); + let (actual, estimate) = test_cardinality(10000, Zipf::new(20000.0, 0.8f64).unwrap()); assert_eq!(actual, [8201, 5131, 5051]); assert!(6846.4 < estimate[0] && estimate[0] < 6846.5); @@ -310,7 +311,8 @@ mod tests { #[test] fn test_cardinality_large2() { - let (actual, estimate) = test_cardinality(1_000_000, Zipf::new(2_000_000, 0.8f64).unwrap()); + let (actual, estimate) = + test_cardinality(1_000_000, Zipf::new(2_000_000.0, 0.8f64).unwrap()); assert_eq!(actual, [777847, 482069, 482246]); assert!(699437.4 < estimate[0] && estimate[0] < 699437.5); diff --git a/libs/pageserver_api/src/config.rs b/libs/pageserver_api/src/config.rs index f01c65d1bd..2a8d05f51e 100644 --- a/libs/pageserver_api/src/config.rs +++ b/libs/pageserver_api/src/config.rs @@ -394,7 +394,7 @@ impl From<&OtelExporterConfig> for tracing_utils::ExportConfig { tracing_utils::ExportConfig { endpoint: Some(val.endpoint.clone()), protocol: val.protocol.into(), - timeout: val.timeout, + timeout: Some(val.timeout), } } } diff --git a/libs/pageserver_api/src/controller_api.rs b/libs/pageserver_api/src/controller_api.rs index 8f86b03f72..1248be0b5c 100644 --- a/libs/pageserver_api/src/controller_api.rs +++ b/libs/pageserver_api/src/controller_api.rs @@ -596,6 +596,7 @@ pub struct TimelineImportRequest { pub timeline_id: TimelineId, pub start_lsn: Lsn, pub sk_set: Vec, + pub force_upsert: bool, } #[derive(serde::Serialize, serde::Deserialize, Clone)] diff --git a/libs/pageserver_api/src/key.rs b/libs/pageserver_api/src/key.rs index 102bbee879..4e8fabfa72 100644 
--- a/libs/pageserver_api/src/key.rs +++ b/libs/pageserver_api/src/key.rs @@ -981,12 +981,12 @@ mod tests { let mut rng = rand::rngs::StdRng::seed_from_u64(42); let key = Key { - field1: rng.r#gen(), - field2: rng.r#gen(), - field3: rng.r#gen(), - field4: rng.r#gen(), - field5: rng.r#gen(), - field6: rng.r#gen(), + field1: rng.random(), + field2: rng.random(), + field3: rng.random(), + field4: rng.random(), + field5: rng.random(), + field6: rng.random(), }; assert_eq!(key, Key::from_str(&format!("{key}")).unwrap()); diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index 11e02a8550..7c7c65fb70 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -443,9 +443,9 @@ pub struct ImportPgdataIdempotencyKey(pub String); impl ImportPgdataIdempotencyKey { pub fn random() -> Self { use rand::Rng; - use rand::distributions::Alphanumeric; + use rand::distr::Alphanumeric; Self( - rand::thread_rng() + rand::rng() .sample_iter(&Alphanumeric) .take(20) .map(char::from) diff --git a/libs/pageserver_api/src/shard.rs b/libs/pageserver_api/src/shard.rs index d6f4cd5e66..74f5f14f87 100644 --- a/libs/pageserver_api/src/shard.rs +++ b/libs/pageserver_api/src/shard.rs @@ -69,22 +69,6 @@ impl Hash for ShardIdentity { } } -/// Stripe size in number of pages -#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Debug)] -pub struct ShardStripeSize(pub u32); - -impl Default for ShardStripeSize { - fn default() -> Self { - DEFAULT_STRIPE_SIZE - } -} - -impl std::fmt::Display for ShardStripeSize { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - self.0.fmt(f) - } -} - /// Layout version: for future upgrades where we might change how the key->shard mapping works #[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Hash, Debug)] pub struct ShardLayout(u8); diff --git a/libs/pageserver_api/src/upcall_api.rs b/libs/pageserver_api/src/upcall_api.rs index 07cada2eb1..fa2c896edb 100644 --- a/libs/pageserver_api/src/upcall_api.rs +++ b/libs/pageserver_api/src/upcall_api.rs @@ -21,6 +21,14 @@ pub struct ReAttachRequest { /// if the node already has a node_id set. #[serde(skip_serializing_if = "Option::is_none", default)] pub register: Option, + + /// Hadron: Optional flag to indicate whether the node is starting with an empty local disk. + /// Will be set to true if the node couldn't find any local tenant data on startup, could be + /// due to the node starting for the first time or due to a local SSD failure/disk wipe event. + /// The flag may be used by the storage controller to update its observed state of the world + /// to make sure that it sends explicit location_config calls to the node following the + /// re-attach request. 
+ pub empty_local_disk: Option, } #[derive(Serialize, Deserialize, Debug)] diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index 482dd9a298..5ecb4badf1 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -203,12 +203,12 @@ impl fmt::Display for CancelKeyData { } } -use rand::distributions::{Distribution, Standard}; -impl Distribution for Standard { +use rand::distr::{Distribution, StandardUniform}; +impl Distribution for StandardUniform { fn sample(&self, rng: &mut R) -> CancelKeyData { CancelKeyData { - backend_pid: rng.r#gen(), - cancel_key: rng.r#gen(), + backend_pid: rng.random(), + cancel_key: rng.random(), } } } diff --git a/libs/proxy/postgres-protocol2/src/authentication/sasl.rs b/libs/proxy/postgres-protocol2/src/authentication/sasl.rs index 274c81c500..cfa59a34f4 100644 --- a/libs/proxy/postgres-protocol2/src/authentication/sasl.rs +++ b/libs/proxy/postgres-protocol2/src/authentication/sasl.rs @@ -155,10 +155,10 @@ pub struct ScramSha256 { fn nonce() -> String { // rand 0.5's ThreadRng is cryptographically secure - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); (0..NONCE_LENGTH) .map(|_| { - let mut v = rng.gen_range(0x21u8..0x7e); + let mut v = rng.random_range(0x21u8..0x7e); if v == 0x2c { v = 0x7e } diff --git a/libs/proxy/postgres-protocol2/src/message/backend.rs b/libs/proxy/postgres-protocol2/src/message/backend.rs index 3fc9a9335c..b1728ef37d 100644 --- a/libs/proxy/postgres-protocol2/src/message/backend.rs +++ b/libs/proxy/postgres-protocol2/src/message/backend.rs @@ -74,7 +74,6 @@ impl Header { } /// An enum representing Postgres backend messages. -#[non_exhaustive] pub enum Message { AuthenticationCleartextPassword, AuthenticationGss, @@ -145,16 +144,7 @@ impl Message { PARSE_COMPLETE_TAG => Message::ParseComplete, BIND_COMPLETE_TAG => Message::BindComplete, CLOSE_COMPLETE_TAG => Message::CloseComplete, - NOTIFICATION_RESPONSE_TAG => { - let process_id = buf.read_i32::()?; - let channel = buf.read_cstr()?; - let message = buf.read_cstr()?; - Message::NotificationResponse(NotificationResponseBody { - process_id, - channel, - message, - }) - } + NOTIFICATION_RESPONSE_TAG => Message::NotificationResponse(NotificationResponseBody {}), COPY_DONE_TAG => Message::CopyDone, COMMAND_COMPLETE_TAG => { let tag = buf.read_cstr()?; @@ -543,28 +533,7 @@ impl NoticeResponseBody { } } -pub struct NotificationResponseBody { - process_id: i32, - channel: Bytes, - message: Bytes, -} - -impl NotificationResponseBody { - #[inline] - pub fn process_id(&self) -> i32 { - self.process_id - } - - #[inline] - pub fn channel(&self) -> io::Result<&str> { - get_str(&self.channel) - } - - #[inline] - pub fn message(&self) -> io::Result<&str> { - get_str(&self.message) - } -} +pub struct NotificationResponseBody {} pub struct ParameterDescriptionBody { storage: Bytes, diff --git a/libs/proxy/postgres-protocol2/src/password/mod.rs b/libs/proxy/postgres-protocol2/src/password/mod.rs index e00ca1e34c..8926710225 100644 --- a/libs/proxy/postgres-protocol2/src/password/mod.rs +++ b/libs/proxy/postgres-protocol2/src/password/mod.rs @@ -28,7 +28,7 @@ const SCRAM_DEFAULT_SALT_LEN: usize = 16; /// special characters that would require escaping in an SQL command. 
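With the `Distribution<CancelKeyData>` impl for `StandardUniform` above, generating a random cancel key stays a one-liner under rand 0.9; roughly:

```rust
use pq_proto::CancelKeyData;
use rand::Rng;

fn random_cancel_key() -> CancelKeyData {
    rand::rng().random()
}
```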
pub async fn scram_sha_256(password: &[u8]) -> String { let mut salt: [u8; SCRAM_DEFAULT_SALT_LEN] = [0; SCRAM_DEFAULT_SALT_LEN]; - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); rng.fill_bytes(&mut salt); scram_sha_256_salt(password, salt).await } diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs index 828884ffd8..068566e955 100644 --- a/libs/proxy/tokio-postgres2/src/client.rs +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -13,7 +13,7 @@ use serde::{Deserialize, Serialize}; use tokio::sync::mpsc; use crate::cancel_token::RawCancelToken; -use crate::codec::{BackendMessages, FrontendMessage}; +use crate::codec::{BackendMessages, FrontendMessage, RecordNotices}; use crate::config::{Host, SslMode}; use crate::query::RowStream; use crate::simple_query::SimpleQueryStream; @@ -221,6 +221,18 @@ impl Client { &mut self.inner } + pub fn record_notices(&mut self, limit: usize) -> mpsc::UnboundedReceiver> { + let (tx, rx) = mpsc::unbounded_channel(); + + let notices = RecordNotices { sender: tx, limit }; + self.inner + .sender + .send(FrontendMessage::RecordNotices(notices)) + .ok(); + + rx + } + /// Pass text directly to the Postgres backend to allow it to sort out typing itself and /// to save a roundtrip pub async fn query_raw_txt( diff --git a/libs/proxy/tokio-postgres2/src/codec.rs b/libs/proxy/tokio-postgres2/src/codec.rs index daa5371426..813faa0e35 100644 --- a/libs/proxy/tokio-postgres2/src/codec.rs +++ b/libs/proxy/tokio-postgres2/src/codec.rs @@ -3,10 +3,17 @@ use std::io; use bytes::{Bytes, BytesMut}; use fallible_iterator::FallibleIterator; use postgres_protocol2::message::backend; +use tokio::sync::mpsc::UnboundedSender; use tokio_util::codec::{Decoder, Encoder}; pub enum FrontendMessage { Raw(Bytes), + RecordNotices(RecordNotices), +} + +pub struct RecordNotices { + pub sender: UnboundedSender>, + pub limit: usize, } pub enum BackendMessage { @@ -33,14 +40,11 @@ impl FallibleIterator for BackendMessages { pub struct PostgresCodec; -impl Encoder for PostgresCodec { +impl Encoder for PostgresCodec { type Error = io::Error; - fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> { - match item { - FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf), - } - + fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> io::Result<()> { + dst.extend_from_slice(&item); Ok(()) } } diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs index 4a07eccf9a..2f718e1e7d 100644 --- a/libs/proxy/tokio-postgres2/src/connect.rs +++ b/libs/proxy/tokio-postgres2/src/connect.rs @@ -1,11 +1,9 @@ use std::net::IpAddr; -use postgres_protocol2::message::backend::Message; use tokio::net::TcpStream; use tokio::sync::mpsc; use crate::client::SocketConfig; -use crate::codec::BackendMessage; use crate::config::Host; use crate::connect_raw::connect_raw; use crate::connect_socket::connect_socket; @@ -48,8 +46,8 @@ where let stream = connect_tls(socket, config.ssl_mode, tls).await?; let RawConnection { stream, - parameters, - delayed_notice, + parameters: _, + delayed_notice: _, process_id, secret_key, } = connect_raw(stream, config).await?; @@ -72,13 +70,7 @@ where secret_key, ); - // delayed notices are always sent as "Async" messages. 
- let delayed = delayed_notice - .into_iter() - .map(|m| BackendMessage::Async(Message::NoticeResponse(m))) - .collect(); - - let connection = Connection::new(stream, delayed, parameters, conn_tx, conn_rx); + let connection = Connection::new(stream, conn_tx, conn_rx); Ok((client, connection)) } diff --git a/libs/proxy/tokio-postgres2/src/connect_raw.rs b/libs/proxy/tokio-postgres2/src/connect_raw.rs index b89a600a2e..462e1be1aa 100644 --- a/libs/proxy/tokio-postgres2/src/connect_raw.rs +++ b/libs/proxy/tokio-postgres2/src/connect_raw.rs @@ -3,7 +3,7 @@ use std::io; use std::pin::Pin; use std::task::{Context, Poll}; -use bytes::BytesMut; +use bytes::{Bytes, BytesMut}; use fallible_iterator::FallibleIterator; use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready}; use postgres_protocol2::authentication::sasl; @@ -14,7 +14,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio_util::codec::Framed; use crate::Error; -use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; +use crate::codec::{BackendMessage, BackendMessages, PostgresCodec}; use crate::config::{self, AuthKeys, Config}; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::TlsStream; @@ -25,7 +25,7 @@ pub struct StartupStream { delayed_notice: Vec, } -impl Sink for StartupStream +impl Sink for StartupStream where S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin, @@ -36,7 +36,7 @@ where Pin::new(&mut self.inner).poll_ready(cx) } - fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> { + fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> io::Result<()> { Pin::new(&mut self.inner).start_send(item) } @@ -120,10 +120,7 @@ where let mut buf = BytesMut::new(); frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?; - stream - .send(FrontendMessage::Raw(buf.freeze())) - .await - .map_err(Error::io) + stream.send(buf.freeze()).await.map_err(Error::io) } async fn authenticate(stream: &mut StartupStream, config: &Config) -> Result<(), Error> @@ -191,10 +188,7 @@ where let mut buf = BytesMut::new(); frontend::password_message(password, &mut buf).map_err(Error::encode)?; - stream - .send(FrontendMessage::Raw(buf.freeze())) - .await - .map_err(Error::io) + stream.send(buf.freeze()).await.map_err(Error::io) } async fn authenticate_sasl( @@ -253,10 +247,7 @@ where let mut buf = BytesMut::new(); frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?; - stream - .send(FrontendMessage::Raw(buf.freeze())) - .await - .map_err(Error::io)?; + stream.send(buf.freeze()).await.map_err(Error::io)?; let body = match stream.try_next().await.map_err(Error::io)? { Some(Message::AuthenticationSaslContinue(body)) => body, @@ -272,10 +263,7 @@ where let mut buf = BytesMut::new(); frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?; - stream - .send(FrontendMessage::Raw(buf.freeze())) - .await - .map_err(Error::io)?; + stream.send(buf.freeze()).await.map_err(Error::io)?; let body = match stream.try_next().await.map_err(Error::io)? 
{ Some(Message::AuthenticationSaslFinal(body)) => body, diff --git a/libs/proxy/tokio-postgres2/src/connection.rs b/libs/proxy/tokio-postgres2/src/connection.rs index fe0372b266..c43a22ffe7 100644 --- a/libs/proxy/tokio-postgres2/src/connection.rs +++ b/libs/proxy/tokio-postgres2/src/connection.rs @@ -1,22 +1,23 @@ -use std::collections::{HashMap, VecDeque}; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; use bytes::BytesMut; -use futures_util::{Sink, Stream, ready}; -use postgres_protocol2::message::backend::Message; +use fallible_iterator::FallibleIterator; +use futures_util::{Sink, StreamExt, ready}; +use postgres_protocol2::message::backend::{Message, NoticeResponseBody}; use postgres_protocol2::message::frontend; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::sync::mpsc; use tokio_util::codec::Framed; use tokio_util::sync::PollSender; -use tracing::{info, trace}; +use tracing::trace; -use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; -use crate::error::DbError; +use crate::Error; +use crate::codec::{ + BackendMessage, BackendMessages, FrontendMessage, PostgresCodec, RecordNotices, +}; use crate::maybe_tls_stream::MaybeTlsStream; -use crate::{AsyncMessage, Error, Notification}; #[derive(PartialEq, Debug)] enum State { @@ -33,18 +34,18 @@ enum State { /// occurred, or because its associated `Client` has dropped and all outstanding work has completed. #[must_use = "futures do nothing unless polled"] pub struct Connection { - /// HACK: we need this in the Neon Proxy. - pub stream: Framed, PostgresCodec>, - /// HACK: we need this in the Neon Proxy to forward params. - pub parameters: HashMap, + stream: Framed, PostgresCodec>, sender: PollSender, receiver: mpsc::UnboundedReceiver, + notices: Option, - pending_responses: VecDeque, + pending_response: Option, state: State, } +pub enum Never {} + impl Connection where S: AsyncRead + AsyncWrite + Unpin, @@ -52,70 +53,42 @@ where { pub(crate) fn new( stream: Framed, PostgresCodec>, - pending_responses: VecDeque, - parameters: HashMap, sender: mpsc::Sender, receiver: mpsc::UnboundedReceiver, ) -> Connection { Connection { stream, - parameters, sender: PollSender::new(sender), receiver, - pending_responses, + notices: None, + pending_response: None, state: State::Active, } } - fn poll_response( - &mut self, - cx: &mut Context<'_>, - ) -> Poll>> { - if let Some(message) = self.pending_responses.pop_front() { - trace!("retrying pending response"); - return Poll::Ready(Some(Ok(message))); - } - - Pin::new(&mut self.stream) - .poll_next(cx) - .map(|o| o.map(|r| r.map_err(Error::io))) - } - /// Read and process messages from the connection to postgres. /// client <- postgres - fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll> { loop { - let message = match self.poll_response(cx)? 
{ - Poll::Ready(Some(message)) => message, - Poll::Ready(None) => return Poll::Ready(Err(Error::closed())), - Poll::Pending => { - trace!("poll_read: waiting on response"); - return Poll::Pending; - } - }; - - let messages = match message { - BackendMessage::Async(Message::NoticeResponse(body)) => { - let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?; - return Poll::Ready(Ok(AsyncMessage::Notice(error))); - } - BackendMessage::Async(Message::NotificationResponse(body)) => { - let notification = Notification { - process_id: body.process_id(), - channel: body.channel().map_err(Error::parse)?.to_string(), - payload: body.message().map_err(Error::parse)?.to_string(), + let messages = match self.pending_response.take() { + Some(messages) => messages, + None => { + let message = match self.stream.poll_next_unpin(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => return Poll::Ready(Err(Error::closed())), + Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(Error::io(e))), + Poll::Ready(Some(Ok(message))) => message, }; - return Poll::Ready(Ok(AsyncMessage::Notification(notification))); + + match message { + BackendMessage::Async(Message::NoticeResponse(body)) => { + self.handle_notice(body)?; + continue; + } + BackendMessage::Async(_) => continue, + BackendMessage::Normal { messages } => messages, + } } - BackendMessage::Async(Message::ParameterStatus(body)) => { - self.parameters.insert( - body.name().map_err(Error::parse)?.to_string(), - body.value().map_err(Error::parse)?.to_string(), - ); - continue; - } - BackendMessage::Async(_) => unreachable!(), - BackendMessage::Normal { messages } => messages, }; match self.sender.poll_reserve(cx) { @@ -126,8 +99,7 @@ where return Poll::Ready(Err(Error::closed())); } Poll::Pending => { - self.pending_responses - .push_back(BackendMessage::Normal { messages }); + self.pending_response = Some(messages); trace!("poll_read: waiting on sender"); return Poll::Pending; } @@ -135,6 +107,31 @@ where } } + fn handle_notice(&mut self, body: NoticeResponseBody) -> Result<(), Error> { + let Some(notices) = &mut self.notices else { + return Ok(()); + }; + + let mut fields = body.fields(); + while let Some(field) = fields.next().map_err(Error::parse)? { + // loop until we find the message field + if field.type_() == b'M' { + // if the message field is within the limit, send it. + if let Some(new_limit) = notices.limit.checked_sub(field.value().len()) { + match notices.sender.send(field.value().into()) { + // set the new limit. + Ok(()) => notices.limit = new_limit, + // closed. + Err(_) => self.notices = None, + } + } + break; + } + } + + Ok(()) + } + /// Fetch the next client request and enqueue the response sender. fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll> { if self.receiver.is_closed() { @@ -168,21 +165,23 @@ where match self.poll_request(cx) { // send the message to postgres - Poll::Ready(Some(request)) => { + Poll::Ready(Some(FrontendMessage::Raw(request))) => { Pin::new(&mut self.stream) .start_send(request) .map_err(Error::io)?; } + Poll::Ready(Some(FrontendMessage::RecordNotices(notices))) => { + self.notices = Some(notices) + } // No more messages from the client, and no more responses to wait for. 
// Send a terminate message to postgres Poll::Ready(None) => { trace!("poll_write: at eof, terminating"); let mut request = BytesMut::new(); frontend::terminate(&mut request); - let request = FrontendMessage::Raw(request.freeze()); Pin::new(&mut self.stream) - .start_send(request) + .start_send(request.freeze()) .map_err(Error::io)?; trace!("poll_write: sent eof, closing"); @@ -231,34 +230,17 @@ where } } - /// Returns the value of a runtime parameter for this connection. - pub fn parameter(&self, name: &str) -> Option<&str> { - self.parameters.get(name).map(|s| &**s) - } - - /// Polls for asynchronous messages from the server. - /// - /// The server can send notices as well as notifications asynchronously to the client. Applications that wish to - /// examine those messages should use this method to drive the connection rather than its `Future` implementation. - pub fn poll_message( - &mut self, - cx: &mut Context<'_>, - ) -> Poll>> { + fn poll_message(&mut self, cx: &mut Context<'_>) -> Poll>> { if self.state != State::Closing { // if the state is still active, try read from and write to postgres. - let message = self.poll_read(cx)?; - let closing = self.poll_write(cx)?; - if let Poll::Ready(()) = closing { + let Poll::Pending = self.poll_read(cx)?; + if self.poll_write(cx)?.is_ready() { self.state = State::Closing; } - if let Poll::Ready(message) = message { - return Poll::Ready(Some(Ok(message))); - } - // poll_read returned Pending. - // poll_write returned Pending or Ready(WriteReady::WaitingOnRead). - // if poll_write returned Ready(WriteReady::WaitingOnRead), then we are waiting to read more data from postgres. + // poll_write returned Pending or Ready(()). + // if poll_write returned Ready(()), then we are waiting to read more data from postgres. if self.state != State::Closing { return Poll::Pending; } @@ -280,11 +262,9 @@ where type Output = Result<(), Error>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - while let Some(message) = ready!(self.poll_message(cx)?) { - if let AsyncMessage::Notice(notice) = message { - info!("{}: {}", notice.severity(), notice.message()); - } + match self.poll_message(cx)? { + Poll::Ready(None) => Poll::Ready(Ok(())), + Poll::Pending => Poll::Pending, } - Poll::Ready(Ok(())) } } diff --git a/libs/proxy/tokio-postgres2/src/lib.rs b/libs/proxy/tokio-postgres2/src/lib.rs index 791c93b972..e3dd6d9261 100644 --- a/libs/proxy/tokio-postgres2/src/lib.rs +++ b/libs/proxy/tokio-postgres2/src/lib.rs @@ -8,7 +8,6 @@ pub use crate::client::{Client, SocketConfig}; pub use crate::config::Config; pub use crate::connect_raw::RawConnection; pub use crate::connection::Connection; -use crate::error::DbError; pub use crate::error::Error; pub use crate::generic_client::GenericClient; pub use crate::query::RowStream; @@ -93,21 +92,6 @@ impl Notification { } } -/// An asynchronous message from the server. -#[allow(clippy::large_enum_variant)] -#[derive(Debug, Clone)] -#[non_exhaustive] -pub enum AsyncMessage { - /// A notice. - /// - /// Notices use the same format as errors, but aren't "errors" per-se. - Notice(DbError), - /// A notification. - /// - /// Connections can subscribe to notifications with the `LISTEN` command. - Notification(Notification), -} - /// Message returned by the `SimpleQuery` stream. 
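With `AsyncMessage` gone, notice capture goes through the `record_notices` channel added to the client above. A hypothetical caller (the receiver's item type is assumed to be `Box<[u8]>`, matching the `field.value().into()` in `handle_notice`):

```rust
use tokio_postgres2::Client;

/// Sketch: capture up to 8 KiB of NOTICE message bodies from a session.
async fn with_notice_capture(client: &mut Client) {
    let mut notices = client.record_notices(8192);

    // ... run queries on `client`; the Connection task forwards the message
    // field of each NOTICE until the byte limit is exhausted ...

    while let Ok(message) = notices.try_recv() {
        tracing::info!("postgres notice: {}", String::from_utf8_lossy(&message));
    }
}
```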
#[derive(Debug)] #[non_exhaustive] diff --git a/libs/remote_storage/Cargo.toml b/libs/remote_storage/Cargo.toml index 0ae13552b8..ea06725cfd 100644 --- a/libs/remote_storage/Cargo.toml +++ b/libs/remote_storage/Cargo.toml @@ -43,7 +43,7 @@ itertools.workspace = true sync_wrapper = { workspace = true, features = ["futures"] } byteorder = "1.4" -rand = "0.8.5" +rand.workspace = true [dev-dependencies] camino-tempfile.workspace = true diff --git a/libs/remote_storage/src/simulate_failures.rs b/libs/remote_storage/src/simulate_failures.rs index e895380192..f35d2a3081 100644 --- a/libs/remote_storage/src/simulate_failures.rs +++ b/libs/remote_storage/src/simulate_failures.rs @@ -81,7 +81,7 @@ impl UnreliableWrapper { /// fn attempt(&self, op: RemoteOp) -> anyhow::Result { let mut attempts = self.attempts.lock().unwrap(); - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); match attempts.entry(op) { Entry::Occupied(mut e) => { @@ -94,7 +94,7 @@ impl UnreliableWrapper { /* BEGIN_HADRON */ // If there are more attempts to fail, fail the request by probability. if (attempts_before_this < self.attempts_to_fail) - && (rng.gen_range(0..=100) < self.attempt_failure_probability) + && (rng.random_range(0..=100) < self.attempt_failure_probability) { let error = anyhow::anyhow!("simulated failure of remote operation {:?}", e.key()); diff --git a/libs/remote_storage/tests/test_real_azure.rs b/libs/remote_storage/tests/test_real_azure.rs index 4d7caabd39..949035b8c3 100644 --- a/libs/remote_storage/tests/test_real_azure.rs +++ b/libs/remote_storage/tests/test_real_azure.rs @@ -208,7 +208,7 @@ async fn create_azure_client( .as_millis(); // because nanos can be the same for two threads so can millis, add randomness - let random = rand::thread_rng().r#gen::(); + let random = rand::rng().random::(); let remote_storage_config = RemoteStorageConfig { storage: RemoteStorageKind::AzureContainer(AzureConfig { diff --git a/libs/remote_storage/tests/test_real_s3.rs b/libs/remote_storage/tests/test_real_s3.rs index 6b893edf75..f5c81bf45d 100644 --- a/libs/remote_storage/tests/test_real_s3.rs +++ b/libs/remote_storage/tests/test_real_s3.rs @@ -385,7 +385,7 @@ async fn create_s3_client( .as_millis(); // because nanos can be the same for two threads so can millis, add randomness - let random = rand::thread_rng().r#gen::(); + let random = rand::rng().random::(); let remote_storage_config = RemoteStorageConfig { storage: RemoteStorageKind::AwsS3(S3Config { diff --git a/libs/tracing-utils/src/lib.rs b/libs/tracing-utils/src/lib.rs index 0893aa173b..76782339da 100644 --- a/libs/tracing-utils/src/lib.rs +++ b/libs/tracing-utils/src/lib.rs @@ -1,11 +1,5 @@ //! Helper functions to set up OpenTelemetry tracing. //! -//! This comes in two variants, depending on whether you have a Tokio runtime available. -//! If you do, call `init_tracing()`. It sets up the trace processor and exporter to use -//! the current tokio runtime. If you don't have a runtime available, or you don't want -//! to share the runtime with the tracing tasks, call `init_tracing_without_runtime()` -//! instead. It sets up a dedicated single-threaded Tokio runtime for the tracing tasks. -//! //! Example: //! //! ```rust,no_run @@ -21,7 +15,8 @@ //! .with_writer(std::io::stderr); //! //! // Initialize OpenTelemetry. Exports tracing spans as OpenTelemetry traces -//! let otlp_layer = tracing_utils::init_tracing("my_application", tracing_utils::ExportConfig::default()).await; +//! 
let provider = tracing_utils::init_tracing("my_application", tracing_utils::ExportConfig::default()); +//! let otlp_layer = provider.as_ref().map(tracing_utils::layer); //! //! // Put it all together //! tracing_subscriber::registry() @@ -36,16 +31,18 @@ pub mod http; pub mod perf_span; -use opentelemetry::KeyValue; use opentelemetry::trace::TracerProvider; use opentelemetry_otlp::WithExportConfig; pub use opentelemetry_otlp::{ExportConfig, Protocol}; +use opentelemetry_sdk::trace::SdkTracerProvider; use tracing::level_filters::LevelFilter; use tracing::{Dispatch, Subscriber}; use tracing_subscriber::Layer; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::registry::LookupSpan; +pub type Provider = SdkTracerProvider; + /// Set up OpenTelemetry exporter, using configuration from environment variables. /// /// `service_name` is set as the OpenTelemetry 'service.name' resource (see @@ -70,16 +67,7 @@ use tracing_subscriber::registry::LookupSpan; /// If you need some other setting, please test if it works first. And perhaps /// add a comment in the list above to save the effort of testing for the next /// person. -/// -/// This doesn't block, but is marked as 'async' to hint that this must be called in -/// asynchronous execution context. -pub async fn init_tracing( - service_name: &str, - export_config: ExportConfig, -) -> Option> -where - S: Subscriber + for<'span> LookupSpan<'span>, -{ +pub fn init_tracing(service_name: &str, export_config: ExportConfig) -> Option { if std::env::var("OTEL_SDK_DISABLED") == Ok("true".to_string()) { return None; }; @@ -89,52 +77,14 @@ where )) } -/// Like `init_tracing`, but creates a separate tokio Runtime for the tracing -/// tasks. -pub fn init_tracing_without_runtime( - service_name: &str, - export_config: ExportConfig, -) -> Option> +pub fn layer(p: &Provider) -> impl Layer where S: Subscriber + for<'span> LookupSpan<'span>, { - if std::env::var("OTEL_SDK_DISABLED") == Ok("true".to_string()) { - return None; - }; - - // The opentelemetry batch processor and the OTLP exporter needs a Tokio - // runtime. Create a dedicated runtime for them. One thread should be - // enough. - // - // (Alternatively, instead of batching, we could use the "simple - // processor", which doesn't need Tokio, and use "reqwest-blocking" - // feature for the OTLP exporter, which also doesn't need Tokio. However, - // batching is considered best practice, and also I have the feeling that - // the non-Tokio codepaths in the opentelemetry crate are less used and - // might be more buggy, so better to stay on the well-beaten path.) - // - // We leak the runtime so that it keeps running after we exit the - // function. - let runtime = Box::leak(Box::new( - tokio::runtime::Builder::new_multi_thread() - .enable_all() - .thread_name("otlp runtime thread") - .worker_threads(1) - .build() - .unwrap(), - )); - let _guard = runtime.enter(); - - Some(init_tracing_internal( - service_name.to_string(), - export_config, - )) + tracing_opentelemetry::layer().with_tracer(p.tracer("global")) } -fn init_tracing_internal(service_name: String, export_config: ExportConfig) -> impl Layer -where - S: Subscriber + for<'span> LookupSpan<'span>, -{ +fn init_tracing_internal(service_name: String, export_config: ExportConfig) -> Provider { // Sets up exporter from the provided [`ExportConfig`] parameter. // If the endpoint is not specified, it is loaded from the // OTEL_EXPORTER_OTLP_ENDPOINT environment variable. 
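For downstream callers, the now-synchronous API composes as below (a sketch of usage, not code in this diff; compute_tools/src/logger.rs above follows the same shape):

```rust
use tracing_subscriber::prelude::*;

/// Sketch: wire the provider-based API into a subscriber.
fn init(service: &str) -> Option<tracing_utils::Provider> {
    // No Tokio runtime required anymore; None is returned if OTEL_SDK_DISABLED=true.
    let provider = tracing_utils::init_tracing(service, tracing_utils::ExportConfig::default());
    let otlp_layer = provider.as_ref().map(tracing_utils::layer);
    tracing_subscriber::registry().with(otlp_layer).init();
    // Hold on to the provider: provider.shutdown() at exit flushes pending
    // spans, replacing the removed global shutdown_tracing().
    provider
}
```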
@@ -153,22 +103,14 @@ where opentelemetry_sdk::propagation::TraceContextPropagator::new(), ); - let tracer = opentelemetry_sdk::trace::TracerProvider::builder() - .with_batch_exporter(exporter, opentelemetry_sdk::runtime::Tokio) - .with_resource(opentelemetry_sdk::Resource::new(vec![KeyValue::new( - opentelemetry_semantic_conventions::resource::SERVICE_NAME, - service_name, - )])) + Provider::builder() + .with_batch_exporter(exporter) + .with_resource( + opentelemetry_sdk::Resource::builder() + .with_service_name(service_name) + .build(), + ) .build() - .tracer("global"); - - tracing_opentelemetry::layer().with_tracer(tracer) -} - -// Shutdown trace pipeline gracefully, so that it has a chance to send any -// pending traces before we exit. -pub fn shutdown_tracing() { - opentelemetry::global::shutdown_tracer_provider(); } pub enum OtelEnablement { @@ -176,17 +118,17 @@ pub enum OtelEnablement { Enabled { service_name: String, export_config: ExportConfig, - runtime: &'static tokio::runtime::Runtime, }, } pub struct OtelGuard { + provider: Provider, pub dispatch: Dispatch, } impl Drop for OtelGuard { fn drop(&mut self) { - shutdown_tracing(); + _ = self.provider.shutdown(); } } @@ -199,22 +141,19 @@ impl Drop for OtelGuard { /// The lifetime of the guard should match taht of the application. On drop, it tears down the /// OTEL infra. pub fn init_performance_tracing(otel_enablement: OtelEnablement) -> Option { - let otel_subscriber = match otel_enablement { + match otel_enablement { OtelEnablement::Disabled => None, OtelEnablement::Enabled { service_name, export_config, - runtime, } => { - let otel_layer = runtime - .block_on(init_tracing(&service_name, export_config)) - .with_filter(LevelFilter::INFO); + let provider = init_tracing(&service_name, export_config)?; + + let otel_layer = layer(&provider).with_filter(LevelFilter::INFO); let otel_subscriber = tracing_subscriber::registry().with(otel_layer); - let otel_dispatch = Dispatch::new(otel_subscriber); + let dispatch = Dispatch::new(otel_subscriber); - Some(otel_dispatch) + Some(OtelGuard { dispatch, provider }) } - }; - - otel_subscriber.map(|dispatch| OtelGuard { dispatch }) + } } diff --git a/libs/utils/src/id.rs b/libs/utils/src/id.rs index e3037aec21..d63bba75a3 100644 --- a/libs/utils/src/id.rs +++ b/libs/utils/src/id.rs @@ -104,7 +104,7 @@ impl Id { pub fn generate() -> Self { let mut tli_buf = [0u8; 16]; - rand::thread_rng().fill(&mut tli_buf); + rand::rng().fill(&mut tli_buf); Id::from(tli_buf) } diff --git a/libs/utils/src/lsn.rs b/libs/utils/src/lsn.rs index 31e1dda23d..1abb63817b 100644 --- a/libs/utils/src/lsn.rs +++ b/libs/utils/src/lsn.rs @@ -364,42 +364,37 @@ impl MonotonicCounter for RecordLsn { } } -/// Implements [`rand::distributions::uniform::UniformSampler`] so we can sample [`Lsn`]s. +/// Implements [`rand::distr::uniform::UniformSampler`] so we can sample [`Lsn`]s. /// /// This is used by the `pagebench` pageserver benchmarking tool. 
-pub struct LsnSampler(::Sampler); +pub struct LsnSampler(::Sampler); -impl rand::distributions::uniform::SampleUniform for Lsn { +impl rand::distr::uniform::SampleUniform for Lsn { type Sampler = LsnSampler; } -impl rand::distributions::uniform::UniformSampler for LsnSampler { +impl rand::distr::uniform::UniformSampler for LsnSampler { type X = Lsn; - fn new(low: B1, high: B2) -> Self + fn new(low: B1, high: B2) -> Result where - B1: rand::distributions::uniform::SampleBorrow + Sized, - B2: rand::distributions::uniform::SampleBorrow + Sized, + B1: rand::distr::uniform::SampleBorrow + Sized, + B2: rand::distr::uniform::SampleBorrow + Sized, { - Self( - ::Sampler::new( - low.borrow().0, - high.borrow().0, - ), - ) + ::Sampler::new(low.borrow().0, high.borrow().0) + .map(Self) } - fn new_inclusive(low: B1, high: B2) -> Self + fn new_inclusive(low: B1, high: B2) -> Result where - B1: rand::distributions::uniform::SampleBorrow + Sized, - B2: rand::distributions::uniform::SampleBorrow + Sized, + B1: rand::distr::uniform::SampleBorrow + Sized, + B2: rand::distr::uniform::SampleBorrow + Sized, { - Self( - ::Sampler::new_inclusive( - low.borrow().0, - high.borrow().0, - ), + ::Sampler::new_inclusive( + low.borrow().0, + high.borrow().0, ) + .map(Self) } fn sample(&self, rng: &mut R) -> Self::X { diff --git a/libs/utils/src/shard.rs b/libs/utils/src/shard.rs index 3345549819..90323f7762 100644 --- a/libs/utils/src/shard.rs +++ b/libs/utils/src/shard.rs @@ -25,6 +25,12 @@ pub struct ShardIndex { pub shard_count: ShardCount, } +/// Stripe size as number of pages. +/// +/// NB: don't implement Default, so callers don't lazily use it by mistake. See DEFAULT_STRIPE_SIZE. +#[derive(Clone, Copy, Serialize, Deserialize, Eq, PartialEq, Debug)] +pub struct ShardStripeSize(pub u32); + /// Formatting helper, for generating the `shard_id` label in traces. 
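Back in lsn.rs above, the sampler constructors are now fallible because rand 0.9's `UniformSampler::new` returns `Result`. Drawing a random `Lsn` becomes, roughly:

```rust
use rand::Rng;
use rand::distr::uniform::Uniform;
use utils::lsn::Lsn;

fn random_lsn(lo: Lsn, hi: Lsn) -> anyhow::Result<Lsn> {
    // Uniform::new is fallible in rand 0.9 (it rejects an empty range, for example).
    let dist = Uniform::new(lo, hi)?;
    Ok(rand::rng().sample(dist))
}
```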
pub struct ShardSlug<'a>(&'a TenantShardId); @@ -181,6 +187,12 @@ impl std::fmt::Display for ShardCount { } } +impl std::fmt::Display for ShardStripeSize { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + impl std::fmt::Display for ShardSlug<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( diff --git a/pageserver/benches/bench_layer_map.rs b/pageserver/benches/bench_layer_map.rs index e1444778b8..284cc4d67d 100644 --- a/pageserver/benches/bench_layer_map.rs +++ b/pageserver/benches/bench_layer_map.rs @@ -11,7 +11,8 @@ use pageserver::tenant::layer_map::LayerMap; use pageserver::tenant::storage_layer::{LayerName, PersistentLayerDesc}; use pageserver_api::key::Key; use pageserver_api::shard::TenantShardId; -use rand::prelude::{SeedableRng, SliceRandom, StdRng}; +use rand::prelude::{SeedableRng, StdRng}; +use rand::seq::IndexedRandom; use utils::id::{TenantId, TimelineId}; use utils::lsn::Lsn; diff --git a/pageserver/client_grpc/src/client.rs b/pageserver/client_grpc/src/client.rs index 3a9edc7092..e6a90fb582 100644 --- a/pageserver/client_grpc/src/client.rs +++ b/pageserver/client_grpc/src/client.rs @@ -16,10 +16,9 @@ use crate::pool::{ChannelPool, ClientGuard, ClientPool, StreamGuard, StreamPool} use crate::retry::Retry; use crate::split::GetPageSplitter; use compute_api::spec::PageserverProtocol; -use pageserver_api::shard::ShardStripeSize; use pageserver_page_api as page_api; use utils::id::{TenantId, TimelineId}; -use utils::shard::{ShardCount, ShardIndex, ShardNumber}; +use utils::shard::{ShardCount, ShardIndex, ShardNumber, ShardStripeSize}; /// Max number of concurrent clients per channel (i.e. TCP connection). New channels will be spun up /// when full. @@ -141,8 +140,8 @@ impl PageserverClient { if !old.count.is_unsharded() && shard_spec.stripe_size != old.stripe_size { return Err(anyhow!( "can't change stripe size from {} to {}", - old.stripe_size, - shard_spec.stripe_size + old.stripe_size.expect("always Some when sharded"), + shard_spec.stripe_size.expect("always Some when sharded") )); } @@ -157,23 +156,6 @@ impl PageserverClient { Ok(()) } - /// Returns whether a relation exists. - #[instrument(skip_all, fields(rel=%req.rel, lsn=%req.read_lsn))] - pub async fn check_rel_exists( - &self, - req: page_api::CheckRelExistsRequest, - ) -> tonic::Result { - debug!("sending request: {req:?}"); - let resp = Self::with_retries(CALL_TIMEOUT, async |_| { - // Relation metadata is only available on shard 0. - let mut client = self.shards.load_full().get_zero().client().await?; - Self::with_timeout(REQUEST_TIMEOUT, client.check_rel_exists(req)).await - }) - .await?; - debug!("received response: {resp:?}"); - Ok(resp) - } - /// Returns the total size of a database, as # of bytes. #[instrument(skip_all, fields(db_oid=%req.db_oid, lsn=%req.read_lsn))] pub async fn get_db_size( @@ -249,13 +231,15 @@ impl PageserverClient { // Fast path: request is for a single shard. if let Some(shard_id) = GetPageSplitter::for_single_shard(&req, shards.count, shards.stripe_size) + .map_err(|err| tonic::Status::internal(err.to_string()))? { return Self::get_page_with_shard(req, shards.get(shard_id)?).await; } // Request spans multiple shards. Split it, dispatch concurrent per-shard requests, and // reassemble the responses. 
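        // NB: the splitter reports anyhow errors (see split.rs below); this call
        // site maps them to tonic::Status::internal so that gRPC concerns stay
        // out of the splitting logic.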
-        let mut splitter = GetPageSplitter::split(req, shards.count, shards.stripe_size);
+        let mut splitter = GetPageSplitter::split(req, shards.count, shards.stripe_size)
+            .map_err(|err| tonic::Status::internal(err.to_string()))?;
 
         let mut shard_requests = FuturesUnordered::new();
         for (shard_id, shard_req) in splitter.drain_requests() {
@@ -265,10 +249,14 @@ impl PageserverClient {
         }
 
         while let Some((shard_id, shard_response)) = shard_requests.next().await.transpose()? {
-            splitter.add_response(shard_id, shard_response)?;
+            splitter
+                .add_response(shard_id, shard_response)
+                .map_err(|err| tonic::Status::internal(err.to_string()))?;
         }
 
-        splitter.get_response()
+        splitter
+            .get_response()
+            .map_err(|err| tonic::Status::internal(err.to_string()))
     }
 
     /// Fetches pages on the given shard. Does not retry internally.
@@ -396,12 +384,14 @@ pub struct ShardSpec {
     /// NB: this is 0 for unsharded tenants, following `ShardIndex::unsharded()` convention.
     count: ShardCount,
     /// The stripe size for these shards.
-    stripe_size: ShardStripeSize,
+    ///
+    /// INVARIANT: None for unsharded tenants, Some for sharded.
+    stripe_size: Option<ShardStripeSize>,
 }
 
 impl ShardSpec {
     /// Creates a new shard spec with the given URLs and stripe size. All shards must be given.
-    /// The stripe size may be omitted for unsharded tenants.
+    /// The stripe size must be Some for sharded tenants, or None for unsharded tenants.
     pub fn new(
         urls: HashMap<ShardIndex, String>,
         stripe_size: Option<ShardStripeSize>,
     )
@@ -414,11 +404,13 @@ impl ShardSpec {
             n => ShardCount::new(n as u8),
         };
 
-        // Determine the stripe size. It doesn't matter for unsharded tenants.
+        // Validate the stripe size.
         if stripe_size.is_none() && !count.is_unsharded() {
             return Err(anyhow!("stripe size must be given for sharded tenants"));
         }
-        let stripe_size = stripe_size.unwrap_or_default();
+        if stripe_size.is_some() && count.is_unsharded() {
+            return Err(anyhow!("stripe size can't be given for unsharded tenants"));
+        }
 
         // Validate the shard spec.
         for (shard_id, url) in &urls {
@@ -458,8 +450,10 @@ struct Shards {
     ///
     /// NB: this is 0 for unsharded tenants, following `ShardIndex::unsharded()` convention.
     count: ShardCount,
-    /// The stripe size. Only used for sharded tenants.
-    stripe_size: ShardStripeSize,
+    /// The stripe size.
+    ///
+    /// INVARIANT: None for unsharded tenants, Some for sharded.
+    stripe_size: Option<ShardStripeSize>,
 }
 
 impl Shards {
diff --git a/pageserver/client_grpc/src/split.rs b/pageserver/client_grpc/src/split.rs
index b7539b900c..8631638686 100644
--- a/pageserver/client_grpc/src/split.rs
+++ b/pageserver/client_grpc/src/split.rs
@@ -1,11 +1,12 @@
 use std::collections::HashMap;
 
+use anyhow::anyhow;
 use bytes::Bytes;
 use pageserver_api::key::rel_block_to_key;
-use pageserver_api::shard::{ShardStripeSize, key_to_shard_number};
+use pageserver_api::shard::key_to_shard_number;
 use pageserver_page_api as page_api;
-use utils::shard::{ShardCount, ShardIndex, ShardNumber};
+use utils::shard::{ShardCount, ShardIndex, ShardStripeSize};
 
 /// Splits GetPageRequests that straddle shard boundaries and assembles the responses.
 /// TODO: add tests for this.
@@ -25,43 +26,54 @@ impl GetPageSplitter {
     pub fn for_single_shard(
         req: &page_api::GetPageRequest,
         count: ShardCount,
-        stripe_size: ShardStripeSize,
-    ) -> Option<ShardIndex> {
+        stripe_size: Option<ShardStripeSize>,
+    ) -> anyhow::Result<Option<ShardIndex>> {
         // Fast path: unsharded tenant.
         if count.is_unsharded() {
-            return Some(ShardIndex::unsharded());
+            return Ok(Some(ShardIndex::unsharded()));
         }
 
-        // Find the first page's shard, for comparison.
If there are no pages, just return the first
-        // shard (caller likely checked already, otherwise the server will reject it).
+        let Some(stripe_size) = stripe_size else {
+            return Err(anyhow!("stripe size must be given for sharded tenants"));
+        };
+
+        // Find the first page's shard, for comparison.
         let Some(&first_page) = req.block_numbers.first() else {
-            return Some(ShardIndex::new(ShardNumber(0), count));
+            return Err(anyhow!("no block numbers in request"));
         };
         let key = rel_block_to_key(req.rel, first_page);
         let shard_number = key_to_shard_number(count, stripe_size, &key);
 
-        req.block_numbers
+        Ok(req
+            .block_numbers
             .iter()
             .skip(1) // computed above
             .all(|&blkno| {
                 let key = rel_block_to_key(req.rel, blkno);
                 key_to_shard_number(count, stripe_size, &key) == shard_number
             })
-            .then_some(ShardIndex::new(shard_number, count))
+            .then_some(ShardIndex::new(shard_number, count)))
     }
 
     /// Splits the given request.
     pub fn split(
         req: page_api::GetPageRequest,
         count: ShardCount,
-        stripe_size: ShardStripeSize,
-    ) -> Self {
+        stripe_size: Option<ShardStripeSize>,
+    ) -> anyhow::Result<Self> {
         // The caller should make sure we don't split requests unnecessarily.
         debug_assert!(
-            Self::for_single_shard(&req, count, stripe_size).is_none(),
+            Self::for_single_shard(&req, count, stripe_size)?.is_none(),
             "unnecessary request split"
         );
 
+        if count.is_unsharded() {
+            return Err(anyhow!("unsharded tenant, no point in splitting request"));
+        }
+        let Some(stripe_size) = stripe_size else {
+            return Err(anyhow!("stripe size must be given for sharded tenants"));
+        };
+
         // Split the requests by shard index.
         let mut requests = HashMap::with_capacity(2); // common case
         let mut block_shards = Vec::with_capacity(req.block_numbers.len());
@@ -103,11 +115,11 @@ impl GetPageSplitter {
                 .collect(),
         };
 
-        Self {
+        Ok(Self {
             requests,
             response,
             block_shards,
-        }
+        })
     }
 
     /// Drains the per-shard requests, moving them out of the splitter to avoid extra allocations.
@@ -124,21 +136,22 @@ impl GetPageSplitter {
         &mut self,
         shard_id: ShardIndex,
         response: page_api::GetPageResponse,
-    ) -> tonic::Result<()> {
+    ) -> anyhow::Result<()> {
         // The caller should already have converted status codes into tonic::Status.
         if response.status_code != page_api::GetPageStatusCode::Ok {
-            return Err(tonic::Status::internal(format!(
+            return Err(anyhow!(
                 "unexpected non-OK response for shard {shard_id}: {} {}",
                 response.status_code,
                 response.reason.unwrap_or_default()
-            )));
+            ));
         }
 
         if response.request_id != self.response.request_id {
-            return Err(tonic::Status::internal(format!(
+            return Err(anyhow!(
                 "response ID mismatch for shard {shard_id}: expected {}, got {}",
-                self.response.request_id, response.request_id
-            )));
+                self.response.request_id,
+                response.request_id
+            ));
         }
 
         // Place the shard response pages into the assembled response, in request order.
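[Editor's note: the hunks above and below rework `GetPageSplitter` to return `anyhow::Result` and push the `tonic::Status` conversion up into the client. A condensed sketch of the resulting call flow, using the items modified in this diff (`Shards`, `GetPageSplitter`, `get_page_with_shard`); it abbreviates error plumbing and is not compilable standalone:]

```rust
use futures::stream::{FuturesUnordered, StreamExt};
use pageserver_page_api as page_api;

// Sketch of PageserverClient::get_page after this change, assuming the
// signatures introduced above.
async fn get_page_sketch(
    shards: &Shards,
    req: page_api::GetPageRequest,
) -> tonic::Result<page_api::GetPageResponse> {
    let internal = |err: anyhow::Error| tonic::Status::internal(err.to_string());

    // Fast path: every requested block maps to the same shard.
    if let Some(shard_id) =
        GetPageSplitter::for_single_shard(&req, shards.count, shards.stripe_size).map_err(internal)?
    {
        return get_page_with_shard(req, shards.get(shard_id)?).await;
    }

    // Slow path: split per shard, fan out concurrently, then reassemble.
    let mut splitter =
        GetPageSplitter::split(req, shards.count, shards.stripe_size).map_err(internal)?;
    let mut shard_requests = FuturesUnordered::new();
    for (shard_id, shard_req) in splitter.drain_requests() {
        shard_requests.push(async move {
            get_page_with_shard(shard_req, shards.get(shard_id)?)
                .await
                .map(|resp| (shard_id, resp))
        });
    }
    while let Some((shard_id, resp)) = shard_requests.next().await.transpose()? {
        splitter.add_response(shard_id, resp).map_err(internal)?;
    }
    splitter.get_response().map_err(internal)
}
```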
@@ -150,27 +171,26 @@
         }
 
         let Some(slot) = self.response.pages.get_mut(i) else {
-            return Err(tonic::Status::internal(format!(
-                "no block_shards slot {i} for shard {shard_id}"
-            )));
+            return Err(anyhow!("no block_shards slot {i} for shard {shard_id}"));
         };
 
         let Some(page) = pages.next() else {
-            return Err(tonic::Status::internal(format!(
+            return Err(anyhow!(
                 "missing page {} in shard {shard_id} response",
                 slot.block_number
-            )));
+            ));
         };
 
         if page.block_number != slot.block_number {
-            return Err(tonic::Status::internal(format!(
+            return Err(anyhow!(
                 "shard {shard_id} returned wrong page at index {i}, expected {} got {}",
-                slot.block_number, page.block_number
-            )));
+                slot.block_number,
+                page.block_number
+            ));
         }
 
         if !slot.image.is_empty() {
-            return Err(tonic::Status::internal(format!(
+            return Err(anyhow!(
                 "shard {shard_id} returned duplicate page {} at index {i}",
                 slot.block_number
-            )));
+            ));
         }
 
         *slot = page;
@@ -178,10 +198,10 @@
 
         // Make sure we've consumed all pages from the shard response.
         if let Some(extra_page) = pages.next() {
-            return Err(tonic::Status::internal(format!(
+            return Err(anyhow!(
                 "shard {shard_id} returned extra page: {}",
                 extra_page.block_number
-            )));
+            ));
         }
 
         Ok(())
@@ -189,18 +209,18 @@
 
     /// Fetches the final, assembled response.
     #[allow(clippy::result_large_err)]
-    pub fn get_response(self) -> tonic::Result<page_api::GetPageResponse> {
+    pub fn get_response(self) -> anyhow::Result<page_api::GetPageResponse> {
         // Check that the response is complete.
         for (i, page) in self.response.pages.iter().enumerate() {
             if page.image.is_empty() {
-                return Err(tonic::Status::internal(format!(
+                return Err(anyhow!(
                     "missing page {} for shard {}",
                     page.block_number,
                     self.block_shards
                         .get(i)
                         .map(|s| s.to_string())
                         .unwrap_or_else(|| "?".to_string())
-                )));
+                ));
             }
         }
 
diff --git a/pageserver/compaction/src/bin/compaction-simulator.rs b/pageserver/compaction/src/bin/compaction-simulator.rs
index dd35417333..6211b86809 100644
--- a/pageserver/compaction/src/bin/compaction-simulator.rs
+++ b/pageserver/compaction/src/bin/compaction-simulator.rs
@@ -89,7 +89,7 @@ async fn simulate(cmd: &SimulateCmd, results_path: &Path) -> anyhow::Result<()>
     let cold_key_range = splitpoint..key_range.end;
 
     for i in 0..cmd.num_records {
-        let chosen_range = if rand::thread_rng().gen_bool(0.9) {
+        let chosen_range = if rand::rng().random_bool(0.9) {
             &hot_key_range
         } else {
             &cold_key_range
diff --git a/pageserver/compaction/src/simulator.rs b/pageserver/compaction/src/simulator.rs
index bf9f6f2658..44507c335b 100644
--- a/pageserver/compaction/src/simulator.rs
+++ b/pageserver/compaction/src/simulator.rs
@@ -300,9 +300,9 @@ impl MockTimeline {
         key_range: &Range<Key>,
     ) -> anyhow::Result<()> {
         crate::helpers::union_to_keyspace(&mut self.keyspace, vec![key_range.clone()]);
-        let mut rng = rand::thread_rng();
+        let mut rng = rand::rng();
         for _ in 0..num_records {
-            self.ingest_record(rng.gen_range(key_range.clone()), len);
+            self.ingest_record(rng.random_range(key_range.clone()), len);
             self.wal_ingested += len;
         }
         Ok(())
diff --git a/pageserver/ctl/src/key.rs b/pageserver/ctl/src/key.rs
index c4daafdfd0..75bab94757 100644
--- a/pageserver/ctl/src/key.rs
+++ b/pageserver/ctl/src/key.rs
@@ -4,7 +4,7 @@ use anyhow::Context;
 use clap::Parser;
 use pageserver_api::key::Key;
 use pageserver_api::reltag::{BlockNumber, RelTag, SlruKind};
-use pageserver_api::shard::{ShardCount, ShardStripeSize};
+use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardCount, ShardStripeSize};
 
 #[derive(Parser)]
 pub(super) struct 
DescribeKeyCommand { @@ -128,7 +128,9 @@ impl DescribeKeyCommand { // seeing the sharding placement might be confusing, so leave it out unless shard // count was given. - let stripe_size = stripe_size.map(ShardStripeSize).unwrap_or_default(); + let stripe_size = stripe_size + .map(ShardStripeSize) + .unwrap_or(DEFAULT_STRIPE_SIZE); println!( "# placement with shard_count: {} and stripe_size: {}:", shard_count.0, stripe_size.0 diff --git a/pageserver/page_api/proto/page_service.proto b/pageserver/page_api/proto/page_service.proto index d113a04a42..aaccbd5ef0 100644 --- a/pageserver/page_api/proto/page_service.proto +++ b/pageserver/page_api/proto/page_service.proto @@ -17,11 +17,11 @@ // grpcurl \ // -plaintext \ // -H "neon-tenant-id: 7c4a1f9e3bd6470c8f3e21a65bd2e980" \ -// -H "neon-shard-id: 0b10" \ +// -H "neon-shard-id: 0000" \ // -H "neon-timeline-id: f08c4e9a2d5f76b1e3a7c2d8910f4b3e" \ // -H "authorization: Bearer $JWT" \ -// -d '{"read_lsn": {"request_lsn": 1234567890}, "rel": {"spc_oid": 1663, "db_oid": 1234, "rel_number": 5678, "fork_number": 0}}' -// localhost:51051 page_api.PageService/CheckRelExists +// -d '{"read_lsn": {"request_lsn": 100000000, "not_modified_since_lsn": 1}, "db_oid": 1}' \ +// localhost:51051 page_api.PageService/GetDbSize // ``` // // TODO: consider adding neon-compute-mode ("primary", "static", "replica"). @@ -38,8 +38,8 @@ package page_api; import "google/protobuf/timestamp.proto"; service PageService { - // Returns whether a relation exists. - rpc CheckRelExists(CheckRelExistsRequest) returns (CheckRelExistsResponse); + // NB: unlike libpq, there is no CheckRelExists in gRPC, at the compute team's request. Instead, + // use GetRelSize with allow_missing=true to check existence. // Fetches a base backup. rpc GetBaseBackup (GetBaseBackupRequest) returns (stream GetBaseBackupResponseChunk); @@ -97,17 +97,6 @@ message RelTag { uint32 fork_number = 4; } -// Checks whether a relation exists, at the given LSN. Only valid on shard 0, -// other shards will error. -message CheckRelExistsRequest { - ReadLsn read_lsn = 1; - RelTag rel = 2; -} - -message CheckRelExistsResponse { - bool exists = 1; -} - // Requests a base backup. message GetBaseBackupRequest { // The LSN to fetch the base backup at. 0 or absent means the latest LSN known to the Pageserver. @@ -260,10 +249,15 @@ enum GetPageStatusCode { message GetRelSizeRequest { ReadLsn read_lsn = 1; RelTag rel = 2; + // If true, return missing=true for missing relations instead of a NotFound error. + bool allow_missing = 3; } message GetRelSizeResponse { + // The number of blocks in the relation. uint32 num_blocks = 1; + // If allow_missing=true, this is true for missing relations. + bool missing = 2; } // Requests an SLRU segment. Only valid on shard 0, other shards will error. diff --git a/pageserver/page_api/src/client.rs b/pageserver/page_api/src/client.rs index f70d0e7b28..fc27ea448b 100644 --- a/pageserver/page_api/src/client.rs +++ b/pageserver/page_api/src/client.rs @@ -69,16 +69,6 @@ impl Client { Ok(Self { inner }) } - /// Returns whether a relation exists. - pub async fn check_rel_exists( - &mut self, - req: CheckRelExistsRequest, - ) -> tonic::Result { - let req = proto::CheckRelExistsRequest::from(req); - let resp = self.inner.check_rel_exists(req).await?.into_inner(); - Ok(resp.into()) - } - /// Fetches a base backup. 
    pub async fn get_base_backup(
        &mut self,
@@ -114,7 +104,8 @@ impl Client {
         Ok(resps.and_then(|resp| ready(GetPageResponse::try_from(resp).map_err(|err| err.into()))))
     }
 
-    /// Returns the size of a relation, as # of blocks.
+    /// Returns the size of a relation as # of blocks, or None if allow_missing=true and the
+    /// relation does not exist.
     pub async fn get_rel_size(
         &mut self,
         req: GetRelSizeRequest,
diff --git a/pageserver/page_api/src/model.rs b/pageserver/page_api/src/model.rs
index 7df7de6fc6..02a59acd83 100644
--- a/pageserver/page_api/src/model.rs
+++ b/pageserver/page_api/src/model.rs
@@ -141,50 +141,6 @@ impl From<RelTag> for proto::RelTag {
     }
 }
 
-/// Checks whether a relation exists, at the given LSN. Only valid on shard 0, other shards error.
-#[derive(Clone, Copy, Debug)]
-pub struct CheckRelExistsRequest {
-    pub read_lsn: ReadLsn,
-    pub rel: RelTag,
-}
-
-impl TryFrom<proto::CheckRelExistsRequest> for CheckRelExistsRequest {
-    type Error = ProtocolError;
-
-    fn try_from(pb: proto::CheckRelExistsRequest) -> Result<Self, Self::Error> {
-        Ok(Self {
-            read_lsn: pb
-                .read_lsn
-                .ok_or(ProtocolError::Missing("read_lsn"))?
-                .try_into()?,
-            rel: pb.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?,
-        })
-    }
-}
-
-impl From<CheckRelExistsRequest> for proto::CheckRelExistsRequest {
-    fn from(request: CheckRelExistsRequest) -> Self {
-        Self {
-            read_lsn: Some(request.read_lsn.into()),
-            rel: Some(request.rel.into()),
-        }
-    }
-}
-
-pub type CheckRelExistsResponse = bool;
-
-impl From<proto::CheckRelExistsResponse> for CheckRelExistsResponse {
-    fn from(pb: proto::CheckRelExistsResponse) -> Self {
-        pb.exists
-    }
-}
-
-impl From<CheckRelExistsResponse> for proto::CheckRelExistsResponse {
-    fn from(exists: CheckRelExistsResponse) -> Self {
-        Self { exists }
-    }
-}
-
 /// Requests a base backup.
 #[derive(Clone, Copy, Debug)]
 pub struct GetBaseBackupRequest {
@@ -709,6 +665,8 @@ impl From<GetPageStatusCode> for tonic::Code {
 pub struct GetRelSizeRequest {
     pub read_lsn: ReadLsn,
     pub rel: RelTag,
+    /// If true, return missing=true for missing relations instead of a NotFound error.
+    pub allow_missing: bool,
 }
 
 impl TryFrom<proto::GetRelSizeRequest> for GetRelSizeRequest {
@@ -721,6 +679,7 @@
                 .ok_or(ProtocolError::Missing("read_lsn"))?
                 .try_into()?,
             rel: proto.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?,
+            allow_missing: proto.allow_missing,
         })
     }
 }
@@ -730,21 +689,29 @@ impl From<GetRelSizeRequest> for proto::GetRelSizeRequest {
     fn from(request: GetRelSizeRequest) -> Self {
         Self {
             read_lsn: Some(request.read_lsn.into()),
             rel: Some(request.rel.into()),
+            allow_missing: request.allow_missing,
         }
     }
 }
 
-pub type GetRelSizeResponse = u32;
+/// The size of a relation as number of blocks, or None if `allow_missing=true` and the relation
+/// does not exist.
+///
+/// INVARIANT: never None if `allow_missing=false` (returns `NotFound` error instead).
+pub type GetRelSizeResponse = Option<u32>;
 
 impl From<proto::GetRelSizeResponse> for GetRelSizeResponse {
-    fn from(proto: proto::GetRelSizeResponse) -> Self {
-        proto.num_blocks
+    fn from(pb: proto::GetRelSizeResponse) -> Self {
+        (!pb.missing).then_some(pb.num_blocks)
     }
 }
 
 impl From<GetRelSizeResponse> for proto::GetRelSizeResponse {
-    fn from(num_blocks: GetRelSizeResponse) -> Self {
-        Self { num_blocks }
+    fn from(resp: GetRelSizeResponse) -> Self {
+        Self {
+            num_blocks: resp.unwrap_or_default(),
+            missing: resp.is_none(),
+        }
     }
 }
 
diff --git a/pageserver/pagebench/src/cmd/basebackup.rs b/pageserver/pagebench/src/cmd/basebackup.rs
index c14bb73136..01875f74b9 100644
--- a/pageserver/pagebench/src/cmd/basebackup.rs
+++ b/pageserver/pagebench/src/cmd/basebackup.rs
@@ -188,9 +188,9 @@ async fn main_impl(
     start_work_barrier.wait().await;
     loop {
         let (timeline, work) = {
-            let mut rng = rand::thread_rng();
+            let mut rng = rand::rng();
             let target = all_targets.choose(&mut rng).unwrap();
-            let lsn = target.lsn_range.clone().map(|r| rng.gen_range(r));
+            let lsn = target.lsn_range.clone().map(|r| rng.random_range(r));
             (target.timeline, Work { lsn })
         };
         let sender = work_senders.get(&timeline).unwrap();
diff --git a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs
index f458f4efe4..48dfa6c56a 100644
--- a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs
+++ b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs
@@ -354,8 +354,7 @@ async fn main_impl(
         .cloned()
         .collect();
     let weights =
-        rand::distributions::weighted::WeightedIndex::new(ranges.iter().map(|v| v.len()))
-            .unwrap();
+        rand::distr::weighted::WeightedIndex::new(ranges.iter().map(|v| v.len())).unwrap();
 
     Box::pin(async move {
         let scheme = match Url::parse(&args.page_service_connstring) {
@@ -455,7 +454,7 @@ async fn run_worker(
     cancel: CancellationToken,
     rps_period: Option<Duration>,
     ranges: Vec<KeyRange>,
-    weights: rand::distributions::weighted::WeightedIndex<usize>,
+    weights: rand::distr::weighted::WeightedIndex<usize>,
 ) {
     shared_state.start_work_barrier.wait().await;
     let client_start = Instant::now();
@@ -497,9 +496,9 @@ async fn run_worker(
         }
 
         // Pick a random page from a random relation.
-        let mut rng = rand::thread_rng();
+        let mut rng = rand::rng();
         let r = &ranges[weights.sample(&mut rng)];
-        let key: i128 = rng.gen_range(r.start..r.end);
+        let key: i128 = rng.random_range(r.start..r.end);
         let (rel_tag, block_no) = key_to_block(key);
 
         let mut blks = VecDeque::with_capacity(batch_size);
@@ -530,7 +529,7 @@ async fn run_worker(
         // We assume that the entire batch can fit within the relation.
assert_eq!(blks.len(), batch_size, "incomplete batch"); - let req_lsn = if rng.gen_bool(args.req_latest_probability) { + let req_lsn = if rng.random_bool(args.req_latest_probability) { Lsn::MAX } else { r.timeline_lsn diff --git a/pageserver/pagebench/src/cmd/ondemand_download_churn.rs b/pageserver/pagebench/src/cmd/ondemand_download_churn.rs index 9ff1e638c4..8fbb452140 100644 --- a/pageserver/pagebench/src/cmd/ondemand_download_churn.rs +++ b/pageserver/pagebench/src/cmd/ondemand_download_churn.rs @@ -7,7 +7,7 @@ use std::time::{Duration, Instant}; use pageserver_api::models::HistoricLayerInfo; use pageserver_api::shard::TenantShardId; use pageserver_client::mgmt_api; -use rand::seq::SliceRandom; +use rand::seq::IndexedMutRandom; use tokio::sync::{OwnedSemaphorePermit, mpsc}; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; @@ -260,7 +260,7 @@ async fn timeline_actor( loop { let layer_tx = { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); timeline.layers.choose_mut(&mut rng).expect("no layers") }; match layer_tx.try_send(permit.take().unwrap()) { diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index dfb8b437c3..855af7009c 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -126,7 +126,6 @@ fn main() -> anyhow::Result<()> { Some(cfg) => tracing_utils::OtelEnablement::Enabled { service_name: "pageserver".to_string(), export_config: (&cfg.export_config).into(), - runtime: *COMPUTE_REQUEST_RUNTIME, }, None => tracing_utils::OtelEnablement::Disabled, }; diff --git a/pageserver/src/controller_upcall_client.rs b/pageserver/src/controller_upcall_client.rs index be1de43d18..96829bd6ea 100644 --- a/pageserver/src/controller_upcall_client.rs +++ b/pageserver/src/controller_upcall_client.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::net::IpAddr; use futures::Future; use pageserver_api::config::NodeMetadata; @@ -16,7 +17,7 @@ use tokio_util::sync::CancellationToken; use url::Url; use utils::generation::Generation; use utils::id::{NodeId, TimelineId}; -use utils::{backoff, failpoint_support}; +use utils::{backoff, failpoint_support, ip_address}; use crate::config::PageServerConf; use crate::virtual_file::on_fatal_io_error; @@ -27,6 +28,7 @@ pub struct StorageControllerUpcallClient { http_client: reqwest::Client, base_url: Url, node_id: NodeId, + node_ip_addr: Option, cancel: CancellationToken, } @@ -40,6 +42,7 @@ pub trait StorageControllerUpcallApi { fn re_attach( &self, conf: &PageServerConf, + empty_local_disk: bool, ) -> impl Future< Output = Result, RetryForeverError>, > + Send; @@ -91,11 +94,18 @@ impl StorageControllerUpcallClient { ); } + // Intentionally panics if we encountered any errors parsing or reading the IP address. + // Note that if the required environment variable is not set, `read_node_ip_addr_from_env` returns `Ok(None)` + // instead of an error. 
+ let node_ip_addr = + ip_address::read_node_ip_addr_from_env().expect("Error reading node IP address."); + Self { http_client: client.build().expect("Failed to construct HTTP client"), base_url: url, node_id: conf.id, cancel: cancel.clone(), + node_ip_addr, } } @@ -146,6 +156,7 @@ impl StorageControllerUpcallApi for StorageControllerUpcallClient { async fn re_attach( &self, conf: &PageServerConf, + empty_local_disk: bool, ) -> Result, RetryForeverError> { let url = self .base_url @@ -193,8 +204,8 @@ impl StorageControllerUpcallApi for StorageControllerUpcallClient { listen_http_addr: m.http_host, listen_http_port: m.http_port, listen_https_port: m.https_port, + node_ip_addr: self.node_ip_addr, availability_zone_id: az_id.expect("Checked above"), - node_ip_addr: None, }) } Err(e) => { @@ -217,6 +228,7 @@ impl StorageControllerUpcallApi for StorageControllerUpcallClient { let request = ReAttachRequest { node_id: self.node_id, register: register.clone(), + empty_local_disk: Some(empty_local_disk), }; let response: ReAttachResponse = self diff --git a/pageserver/src/deletion_queue.rs b/pageserver/src/deletion_queue.rs index 7854fd9e36..51581ccc2c 100644 --- a/pageserver/src/deletion_queue.rs +++ b/pageserver/src/deletion_queue.rs @@ -768,6 +768,7 @@ mod test { async fn re_attach( &self, _conf: &PageServerConf, + _empty_local_disk: bool, ) -> Result, RetryForeverError> { unimplemented!() } diff --git a/pageserver/src/feature_resolver.rs b/pageserver/src/feature_resolver.rs index f0178fd9b3..11b0e972b4 100644 --- a/pageserver/src/feature_resolver.rs +++ b/pageserver/src/feature_resolver.rs @@ -155,7 +155,7 @@ impl FeatureResolver { ); let tenant_properties = PerTenantProperties { - remote_size_mb: Some(rand::thread_rng().gen_range(100.0..1000000.00)), + remote_size_mb: Some(rand::rng().random_range(100.0..1000000.00)), } .into_posthog_properties(); diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 23146ac40e..26a23da66f 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -1636,9 +1636,10 @@ impl PageServerHandler { let (shard, ctx) = upgrade_handle_and_set_context!(shard); ( vec![ - Self::handle_get_nblocks_request(&shard, &req, &ctx) + Self::handle_get_nblocks_request(&shard, &req, false, &ctx) .instrument(span.clone()) .await + .map(|msg| msg.expect("allow_missing=false")) .map(|msg| (PagestreamBeMessage::Nblocks(msg), timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], @@ -2303,12 +2304,16 @@ impl PageServerHandler { Ok(PagestreamExistsResponse { req: *req, exists }) } + /// If `allow_missing` is true, returns None instead of Err on missing relations. Otherwise, + /// never returns None. It is only supported by the gRPC protocol, so we pass it separately to + /// avoid changing the libpq protocol types. 
     #[instrument(skip_all, fields(shard_id))]
     async fn handle_get_nblocks_request(
         timeline: &Timeline,
         req: &PagestreamNblocksRequest,
+        allow_missing: bool,
         ctx: &RequestContext,
-    ) -> Result<PagestreamNblocksResponse, PageStreamError> {
+    ) -> Result<Option<PagestreamNblocksResponse>, PageStreamError> {
         let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn();
         let lsn = Self::wait_or_get_last_lsn(
             timeline,
@@ -2320,20 +2325,25 @@ impl PageServerHandler {
         .await?;
 
         let n_blocks = timeline
-            .get_rel_size(
+            .get_rel_size_in_reldir(
                 req.rel,
                 Version::LsnRange(LsnRange {
                     effective_lsn: lsn,
                     request_lsn: req.hdr.request_lsn,
                 }),
+                None,
+                allow_missing,
                 ctx,
             )
             .await?;
+        let Some(n_blocks) = n_blocks else {
+            return Ok(None);
+        };
 
-        Ok(PagestreamNblocksResponse {
+        Ok(Some(PagestreamNblocksResponse {
             req: *req,
             n_blocks,
-        })
+        }))
     }
 
     #[instrument(skip_all, fields(shard_id))]
@@ -3525,8 +3535,8 @@ impl GrpcPageServiceHandler {
 
 /// Implements the gRPC page service.
 ///
-/// Tonic will drop the request handler futures if the client goes away (e.g. due to a timeout or
-/// cancellation), so the read path must be cancellation-safe. On shutdown, Tonic will wait for
+/// On client disconnect (e.g. timeout or client shutdown), Tonic will drop the request handler
+/// futures, so the read path must be cancellation-safe. On server shutdown, Tonic will wait for
 /// in-flight requests to complete.
 ///
 /// TODO: when the libpq impl is removed, remove the Pagestream types and inline the handler code.
@@ -3539,39 +3549,6 @@ impl proto::PageService for GrpcPageServiceHandler {
     type GetPagesStream = Pin<Box<dyn Stream<Item = Result<proto::GetPageResponse, tonic::Status>> + Send>>;
 
-    #[instrument(skip_all, fields(rel, lsn))]
-    async fn check_rel_exists(
-        &self,
-        req: tonic::Request<proto::CheckRelExistsRequest>,
-    ) -> Result<tonic::Response<proto::CheckRelExistsResponse>, tonic::Status> {
-        let received_at = extract::(&req).0;
-        let timeline = self.get_request_timeline(&req).await?;
-        let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
-
-        // Validate the request, decorate the span, and convert it to a Pagestream request.
-        Self::ensure_shard_zero(&timeline)?;
-        let req: page_api::CheckRelExistsRequest = req.into_inner().try_into()?;
-
-        span_record!(rel=%req.rel, lsn=%req.read_lsn);
-
-        let req = PagestreamExistsRequest {
-            hdr: Self::make_hdr(req.read_lsn, None),
-            rel: req.rel,
-        };
-
-        // Execute the request and convert the response.
-        let _timer = Self::record_op_start_and_throttle(
-            &timeline,
-            metrics::SmgrQueryType::GetRelExists,
-            received_at,
-        )
-        .await?;
-
-        let resp = PageServerHandler::handle_get_rel_exists_request(&timeline, &req, &ctx).await?;
-        let resp: page_api::CheckRelExistsResponse = resp.exists;
-        Ok(tonic::Response::new(resp.into()))
-    }
-
     #[instrument(skip_all, fields(lsn))]
     async fn get_base_backup(
         &self,
@@ -3766,12 +3743,14 @@ impl proto::PageService for GrpcPageServiceHandler {
             // NB: Tonic considers the entire stream to be an in-flight request and will wait
             // for it to complete before shutting down. React to cancellation between requests.
             let req = tokio::select! {
+                biased;
+                _ = cancel.cancelled() => Err(tonic::Status::unavailable("shutting down")),
+
                 result = reqs.message() => match result {
                     Ok(Some(req)) => Ok(req),
                     Ok(None) => break, // client closed the stream
                     Err(err) => Err(err),
                 },
-                _ = cancel.cancelled() => Err(tonic::Status::unavailable("shutting down")),
             }?;
             let req_id = req.request_id.map(page_api::RequestID::from).unwrap_or_default();
             let result = Self::get_page(&ctx, &timeline, req, io_concurrency.clone())
@@ -3796,7 +3775,7 @@ impl proto::PageService for GrpcPageServiceHandler {
         Ok(tonic::Response::new(Box::pin(resps)))
     }
 
-    #[instrument(skip_all, fields(rel, lsn))]
+    #[instrument(skip_all, fields(rel, lsn, allow_missing))]
     async fn get_rel_size(
         &self,
         req: tonic::Request<proto::GetRelSizeRequest>,
     ) -> Result<tonic::Response<proto::GetRelSizeResponse>, tonic::Status> {
@@ -3808,8 +3787,9 @@ impl proto::PageService for GrpcPageServiceHandler {
         // Validate the request, decorate the span, and convert it to a Pagestream request.
         Self::ensure_shard_zero(&timeline)?;
         let req: page_api::GetRelSizeRequest = req.into_inner().try_into()?;
+        let allow_missing = req.allow_missing;
 
-        span_record!(rel=%req.rel, lsn=%req.read_lsn);
+        span_record!(rel=%req.rel, lsn=%req.read_lsn, allow_missing=%req.allow_missing);
 
         let req = PagestreamNblocksRequest {
             hdr: Self::make_hdr(req.read_lsn, None),
@@ -3824,8 +3804,11 @@ impl proto::PageService for GrpcPageServiceHandler {
         )
         .await?;
 
-        let resp = PageServerHandler::handle_get_nblocks_request(&timeline, &req, &ctx).await?;
-        let resp: page_api::GetRelSizeResponse = resp.n_blocks;
+        let resp =
+            PageServerHandler::handle_get_nblocks_request(&timeline, &req, allow_missing, &ctx)
+                .await?;
+        let resp: page_api::GetRelSizeResponse = resp.map(|resp| resp.n_blocks);
+
         Ok(tonic::Response::new(resp.into()))
     }
 
diff --git a/pageserver/src/pgdatadir_mapping.rs b/pageserver/src/pgdatadir_mapping.rs
index 8b76d980fc..ab9cc88e5f 100644
--- a/pageserver/src/pgdatadir_mapping.rs
+++ b/pageserver/src/pgdatadir_mapping.rs
@@ -504,8 +504,9 @@ impl Timeline {
 
         for rel in rels {
             let n_blocks = self
-                .get_rel_size_in_reldir(rel, version, Some((reldir_key, &reldir)), ctx)
-                .await?;
+                .get_rel_size_in_reldir(rel, version, Some((reldir_key, &reldir)), false, ctx)
+                .await?
+                .expect("allow_missing=false");
             total_blocks += n_blocks as usize;
         }
         Ok(total_blocks)
@@ -521,10 +522,16 @@ impl Timeline {
         version: Version<'_>,
         ctx: &RequestContext,
     ) -> Result<u32, PageReconstructError> {
-        self.get_rel_size_in_reldir(tag, version, None, ctx).await
+        Ok(self
+            .get_rel_size_in_reldir(tag, version, None, false, ctx)
+            .await?
+            .expect("allow_missing=false"))
     }
 
-    /// Get size of a relation file. The relation must exist, otherwise an error is returned.
+    /// Get size of a relation file. If `allow_missing` is true, returns None for missing relations,
+    /// otherwise errors.
+    ///
+    /// INVARIANT: never returns None if `allow_missing=false`.
     ///
     /// See [`Self::get_rel_exists_in_reldir`] on why we need `deserialized_reldir_v1`.
     pub(crate) async fn get_rel_size_in_reldir(
@@ -532,8 +539,9 @@ impl Timeline {
         tag: RelTag,
         version: Version<'_>,
         deserialized_reldir_v1: Option<(Key, &RelDirectory)>,
+        allow_missing: bool,
         ctx: &RequestContext,
-    ) -> Result<u32, PageReconstructError> {
+    ) -> Result<Option<u32>, PageReconstructError> {
         if tag.relnode == 0 {
             return Err(PageReconstructError::Other(
                 RelationError::InvalidRelnode.into(),
@@ -541,7 +549,15 @@ impl Timeline {
         }
 
         if let Some(nblocks) = self.get_cached_rel_size(&tag, version) {
-            return Ok(nblocks);
+            return Ok(Some(nblocks));
+        }
+
+        if allow_missing
+            && !self
+                .get_rel_exists_in_reldir(tag, version, deserialized_reldir_v1, ctx)
+                .await?
+ { + return Ok(None); } if (tag.forknum == FSM_FORKNUM || tag.forknum == VISIBILITYMAP_FORKNUM) @@ -553,7 +569,7 @@ impl Timeline { // FSM, and smgrnblocks() on it immediately afterwards, // without extending it. Tolerate that by claiming that // any non-existent FSM fork has size 0. - return Ok(0); + return Ok(Some(0)); } let key = rel_size_to_key(tag); @@ -562,7 +578,7 @@ impl Timeline { self.update_cached_rel_size(tag, version, nblocks); - Ok(nblocks) + Ok(Some(nblocks)) } /// Does the relation exist? @@ -2912,9 +2928,8 @@ static ZERO_PAGE: Bytes = Bytes::from_static(&[0u8; BLCKSZ as usize]); mod tests { use hex_literal::hex; use pageserver_api::models::ShardParameters; - use pageserver_api::shard::ShardStripeSize; use utils::id::TimelineId; - use utils::shard::{ShardCount, ShardNumber}; + use utils::shard::{ShardCount, ShardNumber, ShardStripeSize}; use super::*; use crate::DEFAULT_PG_VERSION; diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 3d66ae4719..4c8856c386 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -6161,11 +6161,11 @@ mod tests { use pageserver_api::keyspace::KeySpaceRandomAccum; use pageserver_api::models::{CompactionAlgorithm, CompactionAlgorithmSettings, LsnLease}; use pageserver_compaction::helpers::overlaps_with; + use rand::Rng; #[cfg(feature = "testing")] use rand::SeedableRng; #[cfg(feature = "testing")] use rand::rngs::StdRng; - use rand::{Rng, thread_rng}; #[cfg(feature = "testing")] use std::ops::Range; use storage_layer::{IoConcurrency, PersistentLayerKey}; @@ -6286,8 +6286,8 @@ mod tests { while lsn < lsn_range.end { let mut key = key_range.start; while key < key_range.end { - let gap = random.gen_range(1..=100) <= spec.gap_chance; - let will_init = random.gen_range(1..=100) <= spec.will_init_chance; + let gap = random.random_range(1..=100) <= spec.gap_chance; + let will_init = random.random_range(1..=100) <= spec.will_init_chance; if gap { continue; @@ -6330,8 +6330,8 @@ mod tests { while lsn < lsn_range.end { let mut key = key_range.start; while key < key_range.end { - let gap = random.gen_range(1..=100) <= spec.gap_chance; - let will_init = random.gen_range(1..=100) <= spec.will_init_chance; + let gap = random.random_range(1..=100) <= spec.gap_chance; + let will_init = random.random_range(1..=100) <= spec.will_init_chance; if gap { continue; @@ -7808,7 +7808,7 @@ mod tests { for _ in 0..50 { for _ in 0..NUM_KEYS { lsn = Lsn(lsn.0 + 0x10); - let blknum = thread_rng().gen_range(0..NUM_KEYS); + let blknum = rand::rng().random_range(0..NUM_KEYS); test_key.field6 = blknum as u32; let mut writer = tline.writer().await; writer @@ -7897,7 +7897,7 @@ mod tests { for _ in 0..NUM_KEYS { lsn = Lsn(lsn.0 + 0x10); - let blknum = thread_rng().gen_range(0..NUM_KEYS); + let blknum = rand::rng().random_range(0..NUM_KEYS); test_key.field6 = blknum as u32; let mut writer = tline.writer().await; writer @@ -7965,7 +7965,7 @@ mod tests { for _ in 0..NUM_KEYS { lsn = Lsn(lsn.0 + 0x10); - let blknum = thread_rng().gen_range(0..NUM_KEYS); + let blknum = rand::rng().random_range(0..NUM_KEYS); test_key.field6 = blknum as u32; let mut writer = tline.writer().await; writer @@ -8229,7 +8229,7 @@ mod tests { for _ in 0..NUM_KEYS { lsn = Lsn(lsn.0 + 0x10); - let blknum = thread_rng().gen_range(0..NUM_KEYS); + let blknum = rand::rng().random_range(0..NUM_KEYS); test_key.field6 = (blknum * STEP) as u32; let mut writer = tline.writer().await; writer @@ -8502,7 +8502,7 @@ mod tests { for iter in 1..=10 { for _ in 0..NUM_KEYS { lsn = Lsn(lsn.0 + 
0x10); - let blknum = thread_rng().gen_range(0..NUM_KEYS); + let blknum = rand::rng().random_range(0..NUM_KEYS); test_key.field6 = (blknum * STEP) as u32; let mut writer = tline.writer().await; writer @@ -11291,10 +11291,10 @@ mod tests { #[cfg(feature = "testing")] #[tokio::test] async fn test_read_path() -> anyhow::Result<()> { - use rand::seq::SliceRandom; + use rand::seq::IndexedRandom; let seed = if cfg!(feature = "fuzz-read-path") { - let seed: u64 = thread_rng().r#gen(); + let seed: u64 = rand::rng().random(); seed } else { // Use a hard-coded seed when not in fuzzing mode. @@ -11308,8 +11308,8 @@ mod tests { let (queries, will_init_chance, gap_chance) = if cfg!(feature = "fuzz-read-path") { const QUERIES: u64 = 5000; - let will_init_chance: u8 = random.gen_range(0..=10); - let gap_chance: u8 = random.gen_range(0..=50); + let will_init_chance: u8 = random.random_range(0..=10); + let gap_chance: u8 = random.random_range(0..=50); (QUERIES, will_init_chance, gap_chance) } else { @@ -11410,7 +11410,8 @@ mod tests { while used_keys.len() < tenant.conf.max_get_vectored_keys.get() { let selected_lsn = interesting_lsns.choose(&mut random).expect("not empty"); - let mut selected_key = start_key.add(random.gen_range(0..KEY_DIMENSION_SIZE)); + let mut selected_key = + start_key.add(random.random_range(0..KEY_DIMENSION_SIZE)); while used_keys.len() < tenant.conf.max_get_vectored_keys.get() { if used_keys.contains(&selected_key) @@ -11425,7 +11426,7 @@ mod tests { .add_key(selected_key); used_keys.insert(selected_key); - let pick_next = random.gen_range(0..=100) <= PICK_NEXT_CHANCE; + let pick_next = random.random_range(0..=100) <= PICK_NEXT_CHANCE; if pick_next { selected_key = selected_key.next(); } else { diff --git a/pageserver/src/tenant/blob_io.rs b/pageserver/src/tenant/blob_io.rs index ed541c4f12..29320f088c 100644 --- a/pageserver/src/tenant/blob_io.rs +++ b/pageserver/src/tenant/blob_io.rs @@ -535,8 +535,8 @@ pub(crate) mod tests { } pub(crate) fn random_array(len: usize) -> Vec { - let mut rng = rand::thread_rng(); - (0..len).map(|_| rng.r#gen()).collect::<_>() + let mut rng = rand::rng(); + (0..len).map(|_| rng.random()).collect::<_>() } #[tokio::test] @@ -588,9 +588,9 @@ pub(crate) mod tests { let mut rng = rand::rngs::StdRng::seed_from_u64(42); let blobs = (0..1024) .map(|_| { - let mut sz: u16 = rng.r#gen(); + let mut sz: u16 = rng.random(); // Make 50% of the arrays small - if rng.r#gen() { + if rng.random() { sz &= 63; } random_array(sz.into()) diff --git a/pageserver/src/tenant/disk_btree.rs b/pageserver/src/tenant/disk_btree.rs index 419befa41b..40f405307c 100644 --- a/pageserver/src/tenant/disk_btree.rs +++ b/pageserver/src/tenant/disk_btree.rs @@ -1090,7 +1090,7 @@ pub(crate) mod tests { const NUM_KEYS: usize = 100000; let mut all_data: BTreeMap = BTreeMap::new(); for idx in 0..NUM_KEYS { - let u: f64 = rand::thread_rng().gen_range(0.0..1.0); + let u: f64 = rand::rng().random_range(0.0..1.0); let t = -(f64::ln(u)); let key_int = (t * 1000000.0) as u128; @@ -1116,7 +1116,7 @@ pub(crate) mod tests { // Test get() operations on random keys, most of which will not exist for _ in 0..100000 { - let key_int = rand::thread_rng().r#gen::(); + let key_int = rand::rng().random::(); let search_key = u128::to_be_bytes(key_int); assert!(reader.get(&search_key, &ctx).await? 
== all_data.get(&key_int).cloned()); } diff --git a/pageserver/src/tenant/ephemeral_file.rs b/pageserver/src/tenant/ephemeral_file.rs index 203b5bf592..f2be129090 100644 --- a/pageserver/src/tenant/ephemeral_file.rs +++ b/pageserver/src/tenant/ephemeral_file.rs @@ -508,8 +508,8 @@ mod tests { let write_nbytes = cap * 2 + cap / 2; - let content: Vec = rand::thread_rng() - .sample_iter(rand::distributions::Standard) + let content: Vec = rand::rng() + .sample_iter(rand::distr::StandardUniform) .take(write_nbytes) .collect(); @@ -565,8 +565,8 @@ mod tests { let cap = writer.mutable().capacity(); drop(writer); - let content: Vec = rand::thread_rng() - .sample_iter(rand::distributions::Standard) + let content: Vec = rand::rng() + .sample_iter(rand::distr::StandardUniform) .take(cap * 2 + cap / 2) .collect(); @@ -614,8 +614,8 @@ mod tests { let cap = mutable.capacity(); let align = mutable.align(); drop(writer); - let content: Vec = rand::thread_rng() - .sample_iter(rand::distributions::Standard) + let content: Vec = rand::rng() + .sample_iter(rand::distr::StandardUniform) .take(cap * 2 + cap / 2) .collect(); diff --git a/pageserver/src/tenant/mgr.rs b/pageserver/src/tenant/mgr.rs index 52f67abde5..b47bab16d8 100644 --- a/pageserver/src/tenant/mgr.rs +++ b/pageserver/src/tenant/mgr.rs @@ -19,7 +19,7 @@ use pageserver_api::shard::{ }; use pageserver_api::upcall_api::ReAttachResponseTenant; use rand::Rng; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use remote_storage::TimeoutOrCancel; use sysinfo::SystemExt; use tokio::fs; @@ -218,7 +218,7 @@ async fn safe_rename_tenant_dir(path: impl AsRef) -> std::io::Result TenantStartupMode::Attached(( alc.attach_mode, alc.generation, - ShardStripeSize::default(), + lc.shard.stripe_size, )), LocationMode::Secondary(_) => TenantStartupMode::Secondary, }, @@ -352,7 +352,8 @@ async fn init_load_generations( let client = StorageControllerUpcallClient::new(conf, cancel); info!("Calling {} API to re-attach tenants", client.base_url()); // If we are configured to use the control plane API, then it is the source of truth for what tenants to load. - match client.re_attach(conf).await { + let empty_local_disk = tenant_confs.is_empty(); + match client.re_attach(conf, empty_local_disk).await { Ok(tenants) => tenants .into_iter() .flat_map(|(id, rart)| { diff --git a/pageserver/src/tenant/remote_timeline_client/manifest.rs b/pageserver/src/tenant/remote_timeline_client/manifest.rs index 7dba4508e2..41e9647d8f 100644 --- a/pageserver/src/tenant/remote_timeline_client/manifest.rs +++ b/pageserver/src/tenant/remote_timeline_client/manifest.rs @@ -1,8 +1,8 @@ use chrono::NaiveDateTime; -use pageserver_api::shard::ShardStripeSize; use serde::{Deserialize, Serialize}; use utils::id::TimelineId; use utils::lsn::Lsn; +use utils::shard::ShardStripeSize; /// Tenant shard manifest, stored in remote storage. Contains offloaded timelines and other tenant /// shard-wide information that must be persisted in remote storage. 
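[Editor's note: most of the remaining hunks in this change are the mechanical rand 0.8 to 0.9 migration. For reference, a minimal sketch of the rename pattern, using the rand 0.9 API that these hunks target:]

```rust
use rand::Rng;
use rand::seq::IndexedRandom; // rand 0.9: slice `choose` moved here from SliceRandom

fn rand_09_renames() {
    // rand 0.8 -> rand 0.9:
    //   rand::thread_rng()     -> rand::rng()
    //   Rng::gen()             -> Rng::random()
    //   Rng::gen_range(..)     -> Rng::random_range(..)
    //   Rng::gen_bool(p)       -> Rng::random_bool(p)
    //   rand::distributions::* -> rand::distr::*
    let mut rng = rand::rng();
    let byte: u8 = rng.random();
    let n = rng.random_range(0..100);
    let coin = rng.random_bool(0.5);
    let pick = [1, 2, 3].choose(&mut rng).copied();
    let _ = (byte, n, coin, pick);
}
```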
diff --git a/pageserver/src/tenant/secondary/scheduler.rs b/pageserver/src/tenant/secondary/scheduler.rs index 62ca527bbc..8dc1d57b5d 100644 --- a/pageserver/src/tenant/secondary/scheduler.rs +++ b/pageserver/src/tenant/secondary/scheduler.rs @@ -25,7 +25,7 @@ pub(super) fn period_jitter(d: Duration, pct: u32) -> Duration { if d == Duration::ZERO { d } else { - rand::thread_rng().gen_range((d * (100 - pct)) / 100..(d * (100 + pct)) / 100) + rand::rng().random_range((d * (100 - pct)) / 100..(d * (100 + pct)) / 100) } } @@ -35,7 +35,7 @@ pub(super) fn period_warmup(period: Duration) -> Duration { if period == Duration::ZERO { period } else { - rand::thread_rng().gen_range(Duration::ZERO..period) + rand::rng().random_range(Duration::ZERO..period) } } diff --git a/pageserver/src/tenant/storage_layer/delta_layer.rs b/pageserver/src/tenant/storage_layer/delta_layer.rs index c2f76c859c..f963fdac92 100644 --- a/pageserver/src/tenant/storage_layer/delta_layer.rs +++ b/pageserver/src/tenant/storage_layer/delta_layer.rs @@ -1634,7 +1634,8 @@ pub(crate) mod test { use bytes::Bytes; use itertools::MinMaxResult; use postgres_ffi::PgMajorVersion; - use rand::prelude::{SeedableRng, SliceRandom, StdRng}; + use rand::prelude::{SeedableRng, StdRng}; + use rand::seq::IndexedRandom; use rand::{Rng, RngCore}; /// Construct an index for a fictional delta layer and and then @@ -1788,14 +1789,14 @@ pub(crate) mod test { let mut entries = Vec::new(); for _ in 0..constants::KEY_COUNT { - let count = rng.gen_range(1..constants::MAX_ENTRIES_PER_KEY); + let count = rng.random_range(1..constants::MAX_ENTRIES_PER_KEY); let mut lsns_iter = std::iter::successors(Some(Lsn(constants::LSN_OFFSET.0 + 0x08)), |lsn| { Some(Lsn(lsn.0 + 0x08)) }); let mut lsns = Vec::new(); while lsns.len() < count as usize { - let take = rng.gen_bool(0.5); + let take = rng.random_bool(0.5); let lsn = lsns_iter.next().unwrap(); if take { lsns.push(lsn); @@ -1869,12 +1870,13 @@ pub(crate) mod test { for _ in 0..constants::RANGES_COUNT { let mut range: Option> = Option::default(); while range.is_none() || keyspace.overlaps(range.as_ref().unwrap()) { - let range_start = rng.gen_range(start..end); + let range_start = rng.random_range(start..end); let range_end_offset = range_start + constants::MIN_RANGE_SIZE; if range_end_offset >= end { range = Some(Key::from_i128(range_start)..Key::from_i128(end)); } else { - let range_end = rng.gen_range((range_start + constants::MIN_RANGE_SIZE)..end); + let range_end = + rng.random_range((range_start + constants::MIN_RANGE_SIZE)..end); range = Some(Key::from_i128(range_start)..Key::from_i128(range_end)); } } diff --git a/pageserver/src/tenant/storage_layer/inmemory_layer/vectored_dio_read.rs b/pageserver/src/tenant/storage_layer/inmemory_layer/vectored_dio_read.rs index 27fbc6f5fb..84f4386087 100644 --- a/pageserver/src/tenant/storage_layer/inmemory_layer/vectored_dio_read.rs +++ b/pageserver/src/tenant/storage_layer/inmemory_layer/vectored_dio_read.rs @@ -440,8 +440,8 @@ mod tests { impl InMemoryFile { fn new_random(len: usize) -> Self { Self { - content: rand::thread_rng() - .sample_iter(rand::distributions::Standard) + content: rand::rng() + .sample_iter(rand::distr::StandardUniform) .take(len) .collect(), } @@ -498,7 +498,7 @@ mod tests { len } }; - rand::Rng::fill(&mut rand::thread_rng(), &mut dst_slice[nread..]); // to discover bugs + rand::Rng::fill(&mut rand::rng(), &mut dst_slice[nread..]); // to discover bugs Ok((dst, nread)) } } @@ -763,7 +763,7 @@ mod tests { let len = std::cmp::min(dst.bytes_total(), 
mocked_bytes.len()); let dst_slice: &mut [u8] = dst.as_mut_rust_slice_full_zeroed(); dst_slice[..len].copy_from_slice(&mocked_bytes[..len]); - rand::Rng::fill(&mut rand::thread_rng(), &mut dst_slice[len..]); // to discover bugs + rand::Rng::fill(&mut rand::rng(), &mut dst_slice[len..]); // to discover bugs Ok((dst, len)) } Err(e) => Err(std::io::Error::other(e)), diff --git a/pageserver/src/tenant/tasks.rs b/pageserver/src/tenant/tasks.rs index 08fc7d61a5..676b39e55b 100644 --- a/pageserver/src/tenant/tasks.rs +++ b/pageserver/src/tenant/tasks.rs @@ -515,7 +515,7 @@ pub(crate) async fn sleep_random_range( interval: RangeInclusive, cancel: &CancellationToken, ) -> Result { - let delay = rand::thread_rng().gen_range(interval); + let delay = rand::rng().random_range(interval); if delay == Duration::ZERO { return Ok(delay); } diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 06e02a7386..7f6173db3f 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -448,6 +448,7 @@ pub struct Timeline { /// A channel to send async requests to prepare a basebackup for the basebackup cache. basebackup_cache: Arc, + #[expect(dead_code)] feature_resolver: Arc, } @@ -2826,7 +2827,7 @@ impl Timeline { if r.numerator == 0 { false } else { - rand::thread_rng().gen_range(0..r.denominator) < r.numerator + rand::rng().random_range(0..r.denominator) < r.numerator } } None => false, @@ -3908,7 +3909,7 @@ impl Timeline { // 1hour base (60_i64 * 60_i64) // 10min jitter - + rand::thread_rng().gen_range(-10 * 60..10 * 60), + + rand::rng().random_range(-10 * 60..10 * 60), ) .expect("10min < 1hour"), ); diff --git a/pageserver/src/tenant/timeline/compaction.rs b/pageserver/src/tenant/timeline/compaction.rs index f76ef502dc..9bca952a46 100644 --- a/pageserver/src/tenant/timeline/compaction.rs +++ b/pageserver/src/tenant/timeline/compaction.rs @@ -1326,13 +1326,7 @@ impl Timeline { .max() }; - let (partition_mode, partition_lsn) = if cfg!(test) - || cfg!(feature = "testing") - || self - .feature_resolver - .evaluate_boolean("image-compaction-boundary") - .is_ok() - { + let (partition_mode, partition_lsn) = { let last_repartition_lsn = self.partitioning.read().1; let lsn = match l0_l1_boundary_lsn { Some(boundary) => gc_cutoff @@ -1348,8 +1342,6 @@ impl Timeline { } else { ("l0_l1_boundary", lsn) } - } else { - ("latest_record", self.get_last_record_lsn()) }; // 2. 
Repartition and create image layers if necessary diff --git a/pageserver/src/tenant/timeline/handle.rs b/pageserver/src/tenant/timeline/handle.rs index 7bca66190f..0b118dd65d 100644 --- a/pageserver/src/tenant/timeline/handle.rs +++ b/pageserver/src/tenant/timeline/handle.rs @@ -654,7 +654,7 @@ mod tests { use pageserver_api::key::{DBDIR_KEY, Key, rel_block_to_key}; use pageserver_api::models::ShardParameters; use pageserver_api::reltag::RelTag; - use pageserver_api::shard::ShardStripeSize; + use pageserver_api::shard::DEFAULT_STRIPE_SIZE; use utils::shard::ShardCount; use utils::sync::gate::GateGuard; @@ -955,7 +955,7 @@ mod tests { }); let child_params = ShardParameters { count: ShardCount(2), - stripe_size: ShardStripeSize::default(), + stripe_size: DEFAULT_STRIPE_SIZE, }; let child0 = Arc::new_cyclic(|myself| StubTimeline { gate: Default::default(), diff --git a/pageserver/src/virtual_file.rs b/pageserver/src/virtual_file.rs index 45b6e44c54..a7f0c5914a 100644 --- a/pageserver/src/virtual_file.rs +++ b/pageserver/src/virtual_file.rs @@ -1275,8 +1275,8 @@ mod tests { use std::sync::Arc; use owned_buffers_io::io_buf_ext::IoBufExt; + use rand::Rng; use rand::seq::SliceRandom; - use rand::{Rng, thread_rng}; use super::*; use crate::context::DownloadBehavior; @@ -1358,7 +1358,7 @@ mod tests { // Check that all the other FDs still work too. Use them in random order for // good measure. - file_b_dupes.as_mut_slice().shuffle(&mut thread_rng()); + file_b_dupes.as_mut_slice().shuffle(&mut rand::rng()); for vfile in file_b_dupes.iter_mut() { assert_first_512_eq(vfile, b"content_b").await; } @@ -1413,9 +1413,8 @@ mod tests { let ctx = ctx.detached_child(TaskKind::UnitTest, DownloadBehavior::Error); let hdl = rt.spawn(async move { let mut buf = IoBufferMut::with_capacity_zeroed(SIZE); - let mut rng = rand::rngs::OsRng; for _ in 1..1000 { - let f = &files[rng.gen_range(0..files.len())]; + let f = &files[rand::rng().random_range(0..files.len())]; buf = f .read_exact_at(buf.slice_full(), 0, &ctx) .await diff --git a/pgxn/neon/Makefile b/pgxn/neon/Makefile index 3ea7a946cf..04a06fcb63 100644 --- a/pgxn/neon/Makefile +++ b/pgxn/neon/Makefile @@ -6,6 +6,7 @@ OBJS = \ $(WIN32RES) \ communicator.o \ communicator_new.o \ + communicator_process.o \ extension_server.o \ file_cache.o \ hll.o \ @@ -65,6 +66,8 @@ WALPROP_OBJS = \ # libcommunicator.a is built by cargo from the Rust sources under communicator/ # subdirectory. `cargo build` also generates communicator_bindings.h. 
 communicator_new.o: communicator/communicator_bindings.h
+communicator_process.o: communicator/communicator_bindings.h
+file_cache.o: communicator/communicator_bindings.h
 
 $(NEON_CARGO_ARTIFACT_TARGET_DIR)/libcommunicator.a communicator/communicator_bindings.h &:
 	(cd $(srcdir)/communicator && cargo build $(CARGO_BUILD_FLAGS) $(CARGO_PROFILE))
diff --git a/pgxn/neon/communicator.c b/pgxn/neon/communicator.c
index 158b8940a3..5a08b3e331 100644
--- a/pgxn/neon/communicator.c
+++ b/pgxn/neon/communicator.c
@@ -1820,12 +1820,12 @@ nm_to_string(NeonMessage *msg)
 			}
 		case T_NeonGetPageResponse:
 			{
-#if 0
 				NeonGetPageResponse *msg_resp = (NeonGetPageResponse *) msg;
-#endif
 
 				appendStringInfoString(&s, "{\"type\": \"NeonGetPageResponse\"");
-				appendStringInfo(&s, ", \"page\": \"XXX\"}");
+				appendStringInfo(&s, ", \"rinfo\": %u/%u/%u", RelFileInfoFmt(msg_resp->req.rinfo));
+				appendStringInfo(&s, ", \"forknum\": %d", msg_resp->req.forknum);
+				appendStringInfo(&s, ", \"blkno\": %u", msg_resp->req.blkno);
 				appendStringInfoChar(&s, '}');
 				break;
 			}
diff --git a/pgxn/neon/communicator/Cargo.toml b/pgxn/neon/communicator/Cargo.toml
index 677e6ed388..0901b66428 100644
--- a/pgxn/neon/communicator/Cargo.toml
+++ b/pgxn/neon/communicator/Cargo.toml
@@ -7,6 +7,9 @@ edition.workspace = true
 # 'testing' feature is currently unused in the communicator, but we accept it for convenience of
 # calling build scripts, so that you can pass the same feature to all packages.
 testing = []
+# 'rest_broker' feature is currently unused in the communicator, but we accept it for convenience of
+# calling build scripts, so that you can pass the same feature to all packages.
+rest_broker = []
 
 [lib]
 crate-type = ["staticlib"]
@@ -19,12 +22,13 @@ http.workspace = true
 libc.workspace = true
 nix.workspace = true
 atomic_enum = "0.3.0"
+measured.workspace = true
 prometheus.workspace = true
 prost.workspace = true
-tonic = { version = "0.12.0", default-features = false, features=["codegen", "prost", "transport"] }
-tokio = { version = "1.43.1", features = ["macros", "net", "io-util", "rt", "rt-multi-thread"] }
-tokio-pipe = { version = "0.2.12" }
 thiserror.workspace = true
+tonic = { workspace = true, default-features = false, features=["codegen", "prost", "transport"] }
+tokio = { workspace = true, features = ["macros", "net", "io-util", "rt", "rt-multi-thread"] }
+tokio-pipe = { version = "0.2.12" }
 tracing.workspace = true
 tracing-subscriber.workspace = true
 
diff --git a/pgxn/neon/communicator/README.md b/pgxn/neon/communicator/README.md
index a18f64c9f6..0644495496 100644
--- a/pgxn/neon/communicator/README.md
+++ b/pgxn/neon/communicator/README.md
@@ -1,11 +1,15 @@
 # Communicator
 
 This package provides the so-called "compute-pageserver communicator",
-or just "communicator" in short. It runs in a PostgreSQL server, as
-part of the neon extension, and handles the communication with the
-pageservers. On the PostgreSQL side, the glue code in pgxn/neon/ uses
-the communicator to implement the PostgreSQL Storage Manager (SMGR)
-interface.
+or just "communicator" in short. The communicator is a separate
+background worker process that runs in the PostgreSQL server. It's
+part of the neon extension.
+
+The communicator handles the communication with the pageservers, and
+also provides an HTTP endpoint for metrics over a local Unix Domain
+socket (aka. the "communicator control socket"). On the PostgreSQL
+side, the glue code in pgxn/neon/ uses the communicator to implement
+the PostgreSQL Storage Manager (SMGR) interface.
## Design criteria @@ -14,9 +18,14 @@ interface. ## Source code view +pgxn/neon/communicator_process.c + Contains code needed to start up the communicator process, and + the glue that interacts with PostgreSQL code and the Rust + code in the communicator process. + pgxn/neon/communicator_new.c - Contains the glue that interact with PostgreSQL code and the Rust - communicator code. + Contains the backend code that interacts with the communicator + process. pgxn/neon/communicator/src/backend_interface.rs The entry point for calls from each backend. @@ -24,9 +33,6 @@ pgxn/neon/communicator/src/backend_interface.rs pgxn/neon/communicator/src/init.rs Initialization at server startup -pgxn/neon/communicator/src/worker_process/ - Worker process main loop and glue code - At compilation time, pgxn/neon/communicator/ produces a static library, libcommunicator.a. It is linked to the neon.so extension library. diff --git a/pgxn/neon/communicator/src/backend_interface.rs b/pgxn/neon/communicator/src/backend_interface.rs index abc982193e..f31cbda20e 100644 --- a/pgxn/neon/communicator/src/backend_interface.rs +++ b/pgxn/neon/communicator/src/backend_interface.rs @@ -215,11 +215,17 @@ pub struct FileCacheIterator { /// Iterate over LFC contents #[unsafe(no_mangle)] -pub extern "C" fn bcomm_cache_iterate_begin(_bs: &mut CommunicatorBackendStruct, iter: *mut FileCacheIterator) { +pub extern "C" fn bcomm_cache_iterate_begin( + _bs: &mut CommunicatorBackendStruct, + iter: *mut FileCacheIterator, +) { unsafe { (*iter).next_bucket = 0 }; } #[unsafe(no_mangle)] -pub extern "C" fn bcomm_cache_iterate_next(bs: &mut CommunicatorBackendStruct, iter: *mut FileCacheIterator) -> bool { +pub extern "C" fn bcomm_cache_iterate_next( + bs: &mut CommunicatorBackendStruct, + iter: *mut FileCacheIterator, +) -> bool { use crate::integrated_cache::GetBucketResult; loop { let next_bucket = unsafe { (*iter).next_bucket } as usize; @@ -235,7 +241,7 @@ pub extern "C" fn bcomm_cache_iterate_next(bs: &mut CommunicatorBackendStruct, i (*iter).next_bucket += 1; } break true; - }, + } GetBucketResult::Vacant => { unsafe { (*iter).next_bucket += 1; diff --git a/pgxn/neon/communicator/src/integrated_cache.rs b/pgxn/neon/communicator/src/integrated_cache.rs index e43e76b1b5..c2850dd961 100644 --- a/pgxn/neon/communicator/src/integrated_cache.rs +++ b/pgxn/neon/communicator/src/integrated_cache.rs @@ -759,7 +759,6 @@ impl<'t> IntegratedCacheReadAccess<'t> { Some((key, _)) => GetBucketResult::Occupied(key.rel, key.block_number), } } - } pub struct BackendCacheReadOp<'t> { diff --git a/pgxn/neon/communicator/src/lib.rs b/pgxn/neon/communicator/src/lib.rs index d0c5b758da..2bcc20f054 100644 --- a/pgxn/neon/communicator/src/lib.rs +++ b/pgxn/neon/communicator/src/lib.rs @@ -21,5 +21,9 @@ mod worker_process; mod global_allocator; +/// Name of the Unix Domain Socket that serves the metrics, and other APIs in the +/// future. This is within the Postgres data directory. 
+const NEON_COMMUNICATOR_SOCKET_NAME: &str = "neon-communicator.socket"; + // FIXME: get this from postgres headers somehow pub const BLCKSZ: usize = 8192; diff --git a/pgxn/neon/communicator/src/neon_request.rs b/pgxn/neon/communicator/src/neon_request.rs index f777256a5f..d40e7484f9 100644 --- a/pgxn/neon/communicator/src/neon_request.rs +++ b/pgxn/neon/communicator/src/neon_request.rs @@ -14,6 +14,8 @@ pub type COid = u32; // This conveniently matches PG_IOV_MAX pub const MAX_GETPAGEV_PAGES: usize = 32; +pub const INVALID_BLOCK_NUMBER: u32 = u32::MAX; + use std::ffi::CStr; use pageserver_page_api::{self as page_api, SlruKind}; @@ -27,7 +29,6 @@ pub enum NeonIORequest { // Read requests. These are C-friendly variants of the corresponding structs in // pageserver_page_api. - RelExists(CRelExistsRequest), RelSize(CRelSizeRequest), GetPageV(CGetPageVRequest), ReadSlruSegment(CReadSlruSegmentRequest), @@ -51,7 +52,7 @@ pub enum NeonIORequest { #[derive(Copy, Clone, Debug)] pub enum NeonIOResult { Empty, - RelExists(bool), + /// InvalidBlockNumber == 0xffffffff means "rel does not exist" RelSize(u32), /// the result pages are written to the shared memory addresses given in the request @@ -85,7 +86,6 @@ impl NeonIORequest { use NeonIORequest::*; match self { Empty => 0, - RelExists(req) => req.request_id, RelSize(req) => req.request_id, GetPageV(req) => req.request_id, ReadSlruSegment(req) => req.request_id, @@ -164,16 +164,6 @@ impl ShmemBuf { } } -#[repr(C)] -#[derive(Copy, Clone, Debug)] -pub struct CRelExistsRequest { - pub request_id: u64, - pub spc_oid: COid, - pub db_oid: COid, - pub rel_number: u32, - pub fork_number: u8, -} - #[repr(C)] #[derive(Copy, Clone, Debug)] pub struct CRelSizeRequest { @@ -182,6 +172,7 @@ pub struct CRelSizeRequest { pub db_oid: COid, pub rel_number: u32, pub fork_number: u8, + pub allow_missing: bool, } #[repr(C)] @@ -317,17 +308,6 @@ pub struct CRelUnlinkRequest { pub lsn: CLsn, } -impl CRelExistsRequest { - pub fn reltag(&self) -> page_api::RelTag { - page_api::RelTag { - spcnode: self.spc_oid, - dbnode: self.db_oid, - relnode: self.rel_number, - forknum: self.fork_number, - } - } -} - impl CRelSizeRequest { pub fn reltag(&self) -> page_api::RelTag { page_api::RelTag { diff --git a/pgxn/neon/communicator/src/worker_process/callbacks.rs b/pgxn/neon/communicator/src/worker_process/callbacks.rs index c3b3a8e3b5..d10605ce4e 100644 --- a/pgxn/neon/communicator/src/worker_process/callbacks.rs +++ b/pgxn/neon/communicator/src/worker_process/callbacks.rs @@ -1,16 +1,38 @@ -//! C callbacks to PostgreSQL facilities that the neon extension needs -//! to provide. These are implemented in `neon/pgxn/communicator_new.c`. -//! The function signatures better match! +//! C callbacks to PostgreSQL facilities that the neon extension needs to provide. These +//! are implemented in `neon/pgxn/communicator_process.c`. The function signatures better +//! match! //! -//! These are called from the communicator threads! Careful what you do, most -//! Postgres functions are not safe to call in that context. - +//! These are called from the communicator threads! Careful what you do, most Postgres +//! functions are not safe to call in that context. use utils::lsn::Lsn; +#[cfg(not(test))] unsafe extern "C" { pub fn notify_proc_unsafe(procno: std::ffi::c_int); pub fn callback_set_my_latch_unsafe(); pub fn callback_get_request_lsn_unsafe() -> u64; + pub fn callback_get_lfc_metrics_unsafe() -> LfcMetrics; +} + +// Compile unit tests with dummy versions of the functions. 
Unit tests cannot call back
+// into the C code. (As of this writing, no unit tests even exist in the communicator
+// package, but the code coverage build still builds these and tries to link with the
+// external C code.)
+#[cfg(test)]
+unsafe fn notify_proc_unsafe(_procno: std::ffi::c_int) {
+    panic!("not usable in unit tests");
+}
+#[cfg(test)]
+unsafe fn callback_set_my_latch_unsafe() {
+    panic!("not usable in unit tests");
+}
+#[cfg(test)]
+unsafe fn callback_get_request_lsn_unsafe() -> u64 {
+    panic!("not usable in unit tests");
+}
+#[cfg(test)]
+unsafe fn callback_get_lfc_metrics_unsafe() -> LfcMetrics {
+    panic!("not usable in unit tests");
+}

// safe wrappers
@@ -26,3 +48,23 @@ pub(super) fn callback_set_my_latch() {
pub(super) fn get_request_lsn() -> Lsn {
    Lsn(unsafe { callback_get_request_lsn_unsafe() })
}
+
+pub(super) fn callback_get_lfc_metrics() -> LfcMetrics {
+    unsafe { callback_get_lfc_metrics_unsafe() }
+}
+
+/// Return type of the callback_get_lfc_metrics() function.
+#[repr(C)]
+pub struct LfcMetrics {
+    pub lfc_cache_size_limit: i64,
+    pub lfc_hits: i64,
+    pub lfc_misses: i64,
+    pub lfc_used: i64,
+    pub lfc_writes: i64,
+
+    // Working set size looking back 1..60 minutes.
+    //
+    // Index 0 is the size of the working set accessed within the last 1 minute,
+    // index 59 is the size of the working set accessed within the last 60 minutes.
+    pub lfc_approximate_working_set_size_windows: [i64; 60],
+}
diff --git a/pgxn/neon/communicator/src/worker_process/control_socket.rs b/pgxn/neon/communicator/src/worker_process/control_socket.rs
new file mode 100644
index 0000000000..f1d052778b
--- /dev/null
+++ b/pgxn/neon/communicator/src/worker_process/control_socket.rs
@@ -0,0 +1,118 @@
+//! Communicator control socket.
+//!
+//! Currently, the control socket is used to provide information about the communicator
+//! process, file cache etc. as Prometheus metrics. In the future, it can be used to
+//! expose more things.
+//!
+//! The exporter speaks HTTP and listens on a Unix Domain Socket under the Postgres
+//! data directory. For debugging, you can access it with curl:
+//!
+//! ```sh
+//! curl --unix-socket neon-communicator.socket http://localhost/metrics
+//! ```
+//!
+use axum::Router;
+use axum::body::Body;
+use axum::extract::State;
+use axum::response::Response;
+use http::StatusCode;
+use http::header::CONTENT_TYPE;
+
+use measured::MetricGroup;
+use measured::text::BufferedTextEncoder;
+
+use std::io::ErrorKind;
+
+use tokio::net::UnixListener;
+
+use crate::NEON_COMMUNICATOR_SOCKET_NAME;
+use crate::worker_process::main_loop::CommunicatorWorkerProcessStruct;
+
+impl<'a> CommunicatorWorkerProcessStruct<'a> {
+    /// Launch the listener
+    pub(crate) async fn launch_control_socket_listener(
+        &'static self,
+    ) -> Result<(), std::io::Error> {
+        use axum::routing::get;
+        let app = Router::new()
+            .route("/metrics", get(get_metrics))
+            .route("/autoscaling_metrics", get(get_autoscaling_metrics))
+            .route("/debug/panic", get(handle_debug_panic))
+            .route("/debug/dump_cache_map", get(dump_cache_map))
+            .with_state(self);
+
+        // If the server is restarted, there might be an old socket still
+        // lying around. Remove it first.
+        match std::fs::remove_file(NEON_COMMUNICATOR_SOCKET_NAME) {
+            Ok(()) => {
+                tracing::warn!("removed stale control socket");
+            }
+            Err(e) if e.kind() == ErrorKind::NotFound => {}
+            Err(e) => {
+                tracing::error!("could not remove stale control socket: {e:#}");
+                // Try to proceed anyway. It will likely fail below though.
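+                // (If the stale socket file is still there, the UnixListener::bind()
+                // call below will fail with "address already in use".)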
+ } + }; + + // Create the unix domain socket and start listening on it + let listener = UnixListener::bind(NEON_COMMUNICATOR_SOCKET_NAME)?; + + tokio::spawn(async { + tracing::info!("control socket listener spawned"); + axum::serve(listener, app) + .await + .expect("axum::serve never returns") + }); + + Ok(()) + } +} + +/// Expose all Prometheus metrics. +async fn get_metrics(State(state): State<&CommunicatorWorkerProcessStruct<'_>>) -> Response { + tracing::trace!("/metrics requested"); + metrics_to_response(&state).await +} + +/// Expose Prometheus metrics, for use by the autoscaling agent. +/// +/// This is a subset of all the metrics. +async fn get_autoscaling_metrics( + State(state): State<&CommunicatorWorkerProcessStruct<'_>>, +) -> Response { + tracing::trace!("/metrics requested"); + metrics_to_response(&state.lfc_metrics).await +} + +async fn handle_debug_panic( + State(_state): State<&CommunicatorWorkerProcessStruct<'_>>, +) -> Response { + panic!("test HTTP handler task panic"); +} + +/// Helper function to convert prometheus metrics to a text response +async fn metrics_to_response(metrics: &(dyn MetricGroup + Sync)) -> Response { + let mut enc = BufferedTextEncoder::new(); + metrics + .collect_group_into(&mut enc) + .unwrap_or_else(|never| match never {}); + + Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "application/text") + .body(Body::from(enc.finish())) + .unwrap() +} + +async fn dump_cache_map( + State(state): State<&CommunicatorWorkerProcessStruct<'static>>, +) -> Response { + let mut buf: Vec = Vec::new(); + state.cache.dump_map(&mut buf); + + Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, "application/text") + .body(Body::from(buf)) + .unwrap() +} diff --git a/pgxn/neon/communicator/src/worker_process/lfc_metrics.rs b/pgxn/neon/communicator/src/worker_process/lfc_metrics.rs new file mode 100644 index 0000000000..fcb291c71b --- /dev/null +++ b/pgxn/neon/communicator/src/worker_process/lfc_metrics.rs @@ -0,0 +1,83 @@ +use measured::{ + FixedCardinalityLabel, Gauge, GaugeVec, LabelGroup, MetricGroup, + label::{LabelName, LabelValue, StaticLabelSet}, + metric::{MetricEncoding, gauge::GaugeState, group::Encoding}, +}; + +use super::callbacks::callback_get_lfc_metrics; + +pub(crate) struct LfcMetricsCollector; + +#[derive(MetricGroup)] +#[metric(new())] +struct LfcMetricsGroup { + /// LFC cache size limit in bytes + lfc_cache_size_limit: Gauge, + /// LFC cache hits + lfc_hits: Gauge, + /// LFC cache misses + lfc_misses: Gauge, + /// LFC chunks used (chunk = 1MB) + lfc_used: Gauge, + /// LFC cache writes + lfc_writes: Gauge, + /// Approximate working set size in pages of 8192 bytes + #[metric(init = GaugeVec::dense())] + lfc_approximate_working_set_size_windows: GaugeVec>, +} + +impl MetricGroup for LfcMetricsCollector +where + GaugeState: MetricEncoding, +{ + fn collect_group_into(&self, enc: &mut T) -> Result<(), ::Err> { + let g = LfcMetricsGroup::new(); + + let lfc_metrics = callback_get_lfc_metrics(); + + g.lfc_cache_size_limit.set(lfc_metrics.lfc_cache_size_limit); + g.lfc_hits.set(lfc_metrics.lfc_hits); + g.lfc_misses.set(lfc_metrics.lfc_misses); + g.lfc_used.set(lfc_metrics.lfc_used); + g.lfc_writes.set(lfc_metrics.lfc_writes); + + for i in 0..60 { + let val = lfc_metrics.lfc_approximate_working_set_size_windows[i]; + g.lfc_approximate_working_set_size_windows + .set(MinuteAsSeconds(i), val); + } + + g.collect_group_into(enc) + } +} + +/// This stores the values in range 0..60, +/// encodes them as seconds (60, 120, 180, 
..., 3600) +#[derive(Clone, Copy)] +struct MinuteAsSeconds(usize); + +impl FixedCardinalityLabel for MinuteAsSeconds { + fn cardinality() -> usize { + 60 + } + + fn encode(&self) -> usize { + self.0 + } + + fn decode(value: usize) -> Self { + Self(value) + } +} + +impl LabelValue for MinuteAsSeconds { + fn visit(&self, v: V) -> V::Output { + v.write_int((self.0 + 1) as i64 * 60) + } +} + +impl LabelGroup for MinuteAsSeconds { + fn visit_values(&self, v: &mut impl measured::label::LabelGroupVisitor) { + v.write_value(LabelName::from_str("duration_seconds"), self); + } +} diff --git a/pgxn/neon/communicator/src/worker_process/logging.rs b/pgxn/neon/communicator/src/worker_process/logging.rs index 43f51cd332..1ae31cd0dd 100644 --- a/pgxn/neon/communicator/src/worker_process/logging.rs +++ b/pgxn/neon/communicator/src/worker_process/logging.rs @@ -48,7 +48,7 @@ pub extern "C" fn communicator_worker_configure_logging() -> Box Self { - SimpleFormatter {} - } -} diff --git a/pgxn/neon/communicator/src/worker_process/main_loop.rs b/pgxn/neon/communicator/src/worker_process/main_loop.rs index 0b2f9da366..00a684e91a 100644 --- a/pgxn/neon/communicator/src/worker_process/main_loop.rs +++ b/pgxn/neon/communicator/src/worker_process/main_loop.rs @@ -10,8 +10,9 @@ use crate::global_allocator::MyAllocatorCollector; use crate::init::CommunicatorInitStruct; use crate::integrated_cache::{CacheResult, IntegratedCacheWriteAccess}; use crate::neon_request::{CGetPageVRequest, CPrefetchVRequest}; -use crate::neon_request::{NeonIORequest, NeonIOResult}; +use crate::neon_request::{NeonIORequest, NeonIOResult, INVALID_BLOCK_NUMBER}; use crate::worker_process::in_progress_ios::{RequestInProgressKey, RequestInProgressTable}; +use crate::worker_process::lfc_metrics::LfcMetricsCollector; use pageserver_client_grpc::{PageserverClient, ShardSpec, ShardStripeSize}; use pageserver_page_api as page_api; @@ -20,6 +21,11 @@ use metrics::{IntCounter, IntCounterVec}; use tokio::io::AsyncReadExt; use tokio_pipe::PipeRead; use uring_common::buf::IoBuf; + +use measured::MetricGroup; +use measured::metric::MetricEncoding; +use measured::metric::gauge::GaugeState; +use measured::metric::group::Encoding; use utils::id::{TenantId, TimelineId}; use super::callbacks::{get_request_lsn, notify_proc}; @@ -30,10 +36,10 @@ use utils::lsn::Lsn; pub struct CommunicatorWorkerProcessStruct<'a> { /// Tokio runtime that the main loop and any other related tasks runs in. - runtime: tokio::runtime::Handle, + runtime: tokio::runtime::Runtime, /// Client to communicate with the pageserver - client: PageserverClient, + client: Option, /// Request slots that backends use to send IO requests to the communicator. 
neon_request_slots: &'a [NeonIORequestSlot], @@ -54,8 +60,9 @@ pub struct CommunicatorWorkerProcessStruct<'a> { stripe_size: Option, /*** Metrics ***/ + pub(crate) lfc_metrics: LfcMetricsCollector, + request_counters: IntCounterVec, - request_rel_exists_counter: IntCounter, request_rel_size_counter: IntCounter, request_get_pagev_counter: IntCounter, request_read_slru_segment_counter: IntCounter, @@ -79,17 +86,35 @@ pub struct CommunicatorWorkerProcessStruct<'a> { allocator_metrics: MyAllocatorCollector, } -pub(super) async fn init( +/// Launch the communicator process's Rust subsystems +pub(super) fn init( cis: Box, - tenant_id: String, - timeline_id: String, - auth_token: Option, + tenant_id: Option<&str>, + timeline_id: Option<&str>, + auth_token: Option<&str>, shard_map: HashMap, stripe_size: Option, initial_file_cache_size: u64, file_cache_path: Option, -) -> CommunicatorWorkerProcessStruct<'static> { - info!("Test log message"); +) -> Result<&'static CommunicatorWorkerProcessStruct<'static>, String> { + // The caller validated these already + let tenant_id = tenant_id + .map(TenantId::from_str) + .transpose() + .map_err(|e| format!("invalid tenant ID: {e}"))?; + let timeline_id = timeline_id + .map(TimelineId::from_str) + .transpose() + .map_err(|e| format!("invalid timeline ID: {e}"))?; + let shard_spec = + ShardSpec::new(shard_map, stripe_size).map_err(|e| format!("invalid shard spec: {e}:"))?; + + let runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .thread_name("communicator thread") + .build() + .unwrap(); + let last_lsn = get_request_lsn(); let file_cache = if let Some(path) = file_cache_path { @@ -109,11 +134,21 @@ pub(super) async fn init( debug!("Initialised integrated cache: {cache:?}"); - let tenant_id = TenantId::from_str(&tenant_id).expect("invalid tenant ID"); - let timeline_id = TimelineId::from_str(&timeline_id).expect("invalid timeline ID"); - let shard_spec = ShardSpec::new(shard_map, stripe_size).expect("invalid shard spec"); - let client = PageserverClient::new(tenant_id, timeline_id, shard_spec, auth_token, None) - .expect("could not create client"); + let client = if let (Some(tenant_id), Some(timeline_id)) = (tenant_id, timeline_id) { + let _guard = runtime.enter(); + Some( + PageserverClient::new( + tenant_id, + timeline_id, + shard_spec, + auth_token.map(|s| s.to_string()), + None, + ) + .expect("could not create client"), + ) + } else { + None + }; let request_counters = IntCounterVec::new( metrics::core::Opts::new( @@ -123,7 +158,6 @@ pub(super) async fn init( &["request_kind"], ) .unwrap(); - let request_rel_exists_counter = request_counters.with_label_values(&["rel_exists"]); let request_rel_size_counter = request_counters.with_label_values(&["rel_size"]); let request_get_pagev_counter = request_counters.with_label_values(&["get_pagev"]); let request_read_slru_segment_counter = @@ -164,8 +198,10 @@ pub(super) async fn init( let request_rel_zero_extend_nblocks_counter = request_nblocks_counters.with_label_values(&["rel_zero_extend"]); - CommunicatorWorkerProcessStruct { - runtime: tokio::runtime::Handle::current(), + let worker_struct = CommunicatorWorkerProcessStruct { + // Note: it's important to not drop the runtime, or all the tasks are dropped + // too. Including it in the returned struct is one way to keep it around. 
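+        // (The previous code kept the runtime alive by leaking it with Box::leak instead.)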
+ runtime, stripe_size, neon_request_slots: cis.neon_request_slots, client, @@ -174,8 +210,9 @@ pub(super) async fn init( in_progress_table: RequestInProgressTable::new(), // metrics + lfc_metrics: LfcMetricsCollector, + request_counters, - request_rel_exists_counter, request_rel_size_counter, request_get_pagev_counter, request_read_slru_segment_counter, @@ -197,7 +234,23 @@ pub(super) async fn init( request_rel_zero_extend_nblocks_counter, allocator_metrics: MyAllocatorCollector::new(), - } + }; + + let worker_struct = Box::leak(Box::new(worker_struct)); + + let main_loop_handle = worker_struct.runtime.spawn(worker_struct.run()); + worker_struct.runtime.spawn(async { + let err = main_loop_handle.await.unwrap_err(); + error!("error: {err:?}"); + }); + + // Start the listener on the control socket + worker_struct + .runtime + .block_on(worker_struct.launch_control_socket_listener()) + .map_err(|e| e.to_string())?; + + Ok(worker_struct) } impl<'t> CommunicatorWorkerProcessStruct<'t> { @@ -206,12 +259,13 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { &self, new_shard_map: HashMap, ) { + let client = self.client.as_ref().unwrap(); let shard_spec = ShardSpec::new(new_shard_map, self.stripe_size.clone()).expect("invalid shard spec"); { let _in_runtime = self.runtime.enter(); - if let Err(err) = self.client.update_shards(shard_spec) { + if let Err(err) = client.update_shards(shard_spec) { tracing::error!("could not update shard map: {err:?}"); } } @@ -336,42 +390,15 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { /// Handle one IO request async fn handle_request(&'static self, req: &'_ NeonIORequest) -> NeonIOResult { + let client = self + .client + .as_ref() + .expect("cannot handle requests without client"); match req { NeonIORequest::Empty => { error!("unexpected Empty IO request"); NeonIOResult::Error(0) } - NeonIORequest::RelExists(req) => { - self.request_rel_exists_counter.inc(); - let rel = req.reltag(); - - let _in_progress_guard = self - .in_progress_table - .lock(RequestInProgressKey::Rel(rel), req.request_id) - .await; - - // Check the cache first - let not_modified_since = match self.cache.get_rel_exists(&rel) { - CacheResult::Found(exists) => return NeonIOResult::RelExists(exists), - CacheResult::NotFound(lsn) => lsn, - }; - - match self - .client - .check_rel_exists(page_api::CheckRelExistsRequest { - read_lsn: self.request_lsns(not_modified_since), - rel, - }) - .await - { - Ok(exists) => NeonIOResult::RelExists(exists), - Err(err) => { - info!("tonic error: {err:?}"); - NeonIOResult::Error(0) - } - } - } - NeonIORequest::RelSize(req) => { self.request_rel_size_counter.inc(); let rel = req.reltag(); @@ -387,16 +414,21 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { tracing::trace!("found relsize for {:?} in cache: {}", rel, nblocks); return NeonIOResult::RelSize(nblocks); } + // XXX: we don't cache negative entries, so if there's no entry in the cache, it could mean + // that the relation doesn't exist or that we don't have it cached. 
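+                // Either way, fall through and ask the pageserver: with allow_missing
+                // set, a missing relation is reported as Ok(None) below rather than as
+                // an error.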
CacheResult::NotFound(lsn) => lsn, }; let read_lsn = self.request_lsns(not_modified_since); - match self - .client - .get_rel_size(page_api::GetRelSizeRequest { read_lsn, rel }) + match client + .get_rel_size(page_api::GetRelSizeRequest { + read_lsn, + rel, + allow_missing: req.allow_missing, + }) .await { - Ok(nblocks) => { + Ok(Some(nblocks)) => { // update the cache tracing::info!( "updated relsize for {:?} in cache: {}, lsn {}", @@ -409,6 +441,10 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { NeonIOResult::RelSize(nblocks) } + Ok(None) => { + // TODO: cache negative entry? + NeonIOResult::RelSize(INVALID_BLOCK_NUMBER) + } Err(err) => { info!("tonic error: {err:?}"); NeonIOResult::Error(0) @@ -429,8 +465,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { let lsn = Lsn(req.request_lsn); let file_path = req.destination_file_path(); - match self - .client + match client .get_slru_segment(page_api::GetSlruSegmentRequest { read_lsn: self.request_lsns(lsn), kind: req.slru_kind, @@ -478,8 +513,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { CacheResult::NotFound(lsn) => lsn, }; - match self - .client + match client .get_db_size(page_api::GetDbSizeRequest { read_lsn: self.request_lsns(not_modified_since), db_oid: req.db_oid, @@ -585,6 +619,10 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { /// Subroutine to handle a GetPageV request, since it's a little more complicated than /// others. async fn handle_get_pagev_request(&'t self, req: &CGetPageVRequest) -> Result<(), i32> { + let client = self + .client + .as_ref() + .expect("cannot handle requests without client"); let rel = req.reltag(); // Check the cache first @@ -643,8 +681,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { "sending getpage request for blocks {:?} in rel {:?} lsns {}", block_numbers, rel, read_lsn ); - match self - .client + match client .get_page(page_api::GetPageRequest { request_id: req.request_id.into(), request_class: page_api::GetPageClass::Normal, @@ -704,6 +741,10 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { /// /// This is very similar to a GetPageV request, but the results are only stored in the cache. async fn handle_prefetchv_request(&'static self, req: &CPrefetchVRequest) -> Result<(), i32> { + let client = self + .client + .as_ref() + .expect("cannot handle requests without client"); let rel = req.reltag(); // Check the cache first @@ -744,8 +785,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { // TODO: spawn separate tasks for these. Use the integrated cache to keep track of the // in-flight requests - match self - .client + match client .get_page(page_api::GetPageRequest { request_id: req.request_id.into(), request_class: page_api::GetPageClass::Prefetch, @@ -818,3 +858,13 @@ impl<'t> metrics::core::Collector for CommunicatorWorkerProcessStruct<'t> { values } } + +impl MetricGroup for CommunicatorWorkerProcessStruct<'_> +where + T: Encoding, + GaugeState: MetricEncoding, +{ + fn collect_group_into(&self, enc: &mut T) -> Result<(), T::Err> { + self.lfc_metrics.collect_group_into(enc) + } +} diff --git a/pgxn/neon/communicator/src/worker_process/metrics_exporter.rs b/pgxn/neon/communicator/src/worker_process/metrics_exporter.rs index 9b0891b5aa..e69de29bb2 100644 --- a/pgxn/neon/communicator/src/worker_process/metrics_exporter.rs +++ b/pgxn/neon/communicator/src/worker_process/metrics_exporter.rs @@ -1,82 +0,0 @@ -//! Export information about Postgres, the communicator process, file cache etc. as -//! prometheus metrics. 
- -use axum::Router; -use axum::body::Body; -use axum::extract::State; -use axum::response::Response; -use http::StatusCode; -use http::header::CONTENT_TYPE; - -use metrics::proto::MetricFamily; -use metrics::{Encoder, TextEncoder}; - -use std::path::PathBuf; - -use tokio::net::UnixListener; - -use crate::worker_process::main_loop::CommunicatorWorkerProcessStruct; - -impl<'a> CommunicatorWorkerProcessStruct<'a> { - pub(crate) async fn launch_exporter_task(&'static self) { - use axum::routing::get; - let app = Router::new() - .route("/metrics", get(get_metrics)) - .route("/dump_cache_map", get(dump_cache_map)) - .with_state(self); - - // Listen on unix domain socket, in the data directory. That should be unique. - let path = PathBuf::from(".metrics.socket"); - - let listener = UnixListener::bind(path.clone()).unwrap(); - - tokio::spawn(async { - tracing::info!("metrics listener spawned"); - axum::serve(listener, app).await.unwrap() - }); - } -} - -async fn dump_cache_map( - State(state): State<&CommunicatorWorkerProcessStruct<'static>>, -) -> Response { - let mut buf: Vec = Vec::new(); - state.cache.dump_map(&mut buf); - - Response::builder() - .status(StatusCode::OK) - .header(CONTENT_TYPE, "application/text") - .body(Body::from(buf)) - .unwrap() -} - -/// Expose Prometheus metrics. -async fn get_metrics(State(state): State<&CommunicatorWorkerProcessStruct<'static>>) -> Response { - use metrics::core::Collector; - let metrics = state.collect(); - - // When we call TextEncoder::encode() below, it will immediately return an - // error if a metric family has no metrics, so we need to preemptively - // filter out metric families with no metrics. - let metrics = metrics - .into_iter() - .filter(|m| !m.get_metric().is_empty()) - .collect::>(); - - let encoder = TextEncoder::new(); - let mut buffer = vec![]; - - if let Err(e) = encoder.encode(&metrics, &mut buffer) { - Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .header(CONTENT_TYPE, "application/text") - .body(Body::from(e.to_string())) - .unwrap() - } else { - Response::builder() - .status(StatusCode::OK) - .header(CONTENT_TYPE, encoder.format_type()) - .body(Body::from(buffer)) - .unwrap() - } -} diff --git a/pgxn/neon/communicator/src/worker_process/mod.rs b/pgxn/neon/communicator/src/worker_process/mod.rs index 064d106d4c..31e6731abf 100644 --- a/pgxn/neon/communicator/src/worker_process/mod.rs +++ b/pgxn/neon/communicator/src/worker_process/mod.rs @@ -1,14 +1,13 @@ //! This code runs in the communicator worker process. This provides //! the glue code to: //! -//! - launch the 'processor', -//! - receive IO requests from backends and pass them to the processor, +//! - launch the main loop, +//! - receive IO requests from backends and process them, //! - write results back to backends. - mod callbacks; +mod control_socket; +mod in_progress_ios; +mod lfc_metrics; mod logging; mod main_loop; -mod metrics_exporter; mod worker_interface; - -mod in_progress_ios; diff --git a/pgxn/neon/communicator/src/worker_process/worker_interface.rs b/pgxn/neon/communicator/src/worker_process/worker_interface.rs index ff9b1ba699..4d49c76549 100644 --- a/pgxn/neon/communicator/src/worker_process/worker_interface.rs +++ b/pgxn/neon/communicator/src/worker_process/worker_interface.rs @@ -1,11 +1,9 @@ //! 
Functions called from the C code in the worker process

use std::collections::HashMap;
-use std::ffi::{CStr, c_char};
+use std::ffi::{CStr, CString, c_char};
use std::path::PathBuf;

-use tracing::error;
-
use crate::init::CommunicatorInitStruct;
use crate::worker_process::main_loop;
use crate::worker_process::main_loop::CommunicatorWorkerProcessStruct;
@@ -14,10 +12,23 @@ use pageserver_client_grpc::ShardStripeSize;

/// Launch the communicator's tokio tasks, which do most of the work.
///
-/// The caller has initialized the process as a regular PostgreSQL
-/// background worker process. The shared memory segment used to
-/// communicate with the backends has been allocated and initialized
-/// earlier, at postmaster startup, in rcommunicator_shmem_init().
+/// The caller has initialized the process as a regular PostgreSQL background worker
+/// process. The shared memory segment used to communicate with the backends has been
+/// allocated and initialized earlier, at postmaster startup, in
+/// rcommunicator_shmem_init().
+///
+/// Inputs:
+/// `tenant_id` and `timeline_id` can be NULL, if we've been launched in "non-Neon" mode,
+/// where we use local storage instead of connecting to remote Neon storage. That's
+/// currently only used in some unit tests.
+///
+/// Result:
+/// Returns a pointer to CommunicatorWorkerProcessStruct, which is a handle to the
+/// running Rust tasks. The C code can use it to interact with the Rust parts. On
+/// failure, returns None/NULL, and an error message is returned in *error_p.
+///
+/// This is called only once in the process, so the returned struct, and the error
+/// message in case of failure, are simply leaked.
#[unsafe(no_mangle)]
pub extern "C" fn communicator_worker_process_launch(
    cis: Box<CommunicatorInitStruct>,
@@ -29,20 +40,27 @@ pub extern "C" fn communicator_worker_process_launch(
    stripe_size: u32,
    file_cache_path: *const c_char,
    initial_file_cache_size: u64,
-) -> &'static CommunicatorWorkerProcessStruct<'static> {
+    error_p: *mut *const c_char,
+) -> Option<&'static CommunicatorWorkerProcessStruct<'static>> {
    tracing::warn!("starting threads in rust code");

    // Convert the arguments into more convenient Rust types
-    let tenant_id = unsafe { CStr::from_ptr(tenant_id) }.to_str().unwrap();
-    let timeline_id = unsafe { CStr::from_ptr(timeline_id) }.to_str().unwrap();
+    let tenant_id = if tenant_id.is_null() {
+        None
+    } else {
+        let cstr = unsafe { CStr::from_ptr(tenant_id) };
+        Some(cstr.to_str().expect("assume UTF-8"))
+    };
+    let timeline_id = if timeline_id.is_null() {
+        None
+    } else {
+        let cstr = unsafe { CStr::from_ptr(timeline_id) };
+        Some(cstr.to_str().expect("assume UTF-8"))
+    };
    let auth_token = if auth_token.is_null() {
        None
    } else {
-        Some(
-            unsafe { CStr::from_ptr(auth_token) }
-                .to_str()
-                .unwrap()
-                .to_string(),
-        )
+        let cstr = unsafe { CStr::from_ptr(auth_token) };
+        Some(cstr.to_str().expect("assume UTF-8"))
    };
    let file_cache_path = {
        if file_cache_path.is_null() {
@@ -53,43 +71,39 @@ pub extern "C" fn communicator_worker_process_launch(
        }
    };
    let shard_map = shard_map_to_hash(nshards, shard_map);
+    // FIXME: distinguish between unsharded, and sharded with 1 shard.
+    // Also, we might go from unsharded to sharded while the system
+    // is running.
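+    // For now, only pass a stripe size when there is more than one shard; everything
+    // else is treated as unsharded.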
+ let stripe_size = if stripe_size > 0 && nshards > 1 { + Some(ShardStripeSize(stripe_size)) + } else { + None + }; - // start main loop - let runtime = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .thread_name("communicator thread") - .build() - .unwrap(); - - let worker_struct = runtime.block_on(main_loop::init( + // The `init` function does all the work. + let result = main_loop::init( cis, - tenant_id.to_string(), - timeline_id.to_string(), + tenant_id, + timeline_id, auth_token, shard_map, - if stripe_size > 0 { - Some(ShardStripeSize(stripe_size)) - } else { - None - }, + stripe_size, initial_file_cache_size, file_cache_path, - )); - let worker_struct = Box::leak(Box::new(worker_struct)); + ); - let main_loop_handle = runtime.spawn(worker_struct.run()); + // On failure, return the error message to the C caller in *error_p. + match result { + Ok(worker_struct) => Some(worker_struct), + Err(errmsg) => { + let errmsg = CString::new(errmsg).expect("no nuls within error message"); + let errmsg = Box::leak(errmsg.into_boxed_c_str()); + let p: *const c_char = errmsg.as_ptr(); - runtime.spawn(async { - let err = main_loop_handle.await.unwrap_err(); - error!("error: {err:?}"); - }); - - runtime.block_on(worker_struct.launch_exporter_task()); - - // keep the runtime running after we exit this function - Box::leak(Box::new(runtime)); - - worker_struct + unsafe { *error_p = p }; + None + } + } } /// Convert the "shard map" from an array of C strings, indexed by shard no to a rust HashMap diff --git a/pgxn/neon/communicator_new.c b/pgxn/neon/communicator_new.c index cb0bbc5ee0..f0663411fe 100644 --- a/pgxn/neon/communicator_new.c +++ b/pgxn/neon/communicator_new.c @@ -42,6 +42,7 @@ #include "bitmap.h" #include "communicator_new.h" +#include "communicator_process.h" #include "hll.h" #include "neon.h" #include "neon_perf_counters.h" @@ -62,7 +63,6 @@ extern char *lfc_path; #define MaxProcs (MaxBackends + NUM_AUXILIARY_PROCS) -static CommunicatorInitStruct *cis; static CommunicatorBackendStruct *my_bs; static File cache_file = 0; @@ -143,8 +143,6 @@ static bool bounce_needed(void *buffer); static void *bounce_buf(void); static void *bounce_write_if_needed(void *buffer); -static void pump_logging(struct LoggingReceiver *logging); -PGDLLEXPORT void communicator_new_bgworker_main(Datum main_arg); static void communicator_new_backend_exit(int code, Datum arg); static char *print_neon_io_request(NeonIORequest *request); @@ -181,44 +179,12 @@ assign_request_id(void) /**** Initialization functions. 
These run in postmaster ****/ -void -pg_init_communicator_new(void) -{ - BackgroundWorker bgw; - - if (!neon_use_communicator_worker) - return; - - if (pageserver_connstring[0] == '\0' && pageserver_grpc_urls[0] == '\0') - { - /* running with local storage */ - return; - } - - /* Initialize the background worker process */ - memset(&bgw, 0, sizeof(bgw)); - bgw.bgw_flags = BGWORKER_SHMEM_ACCESS; - bgw.bgw_start_time = BgWorkerStart_PostmasterStart; - snprintf(bgw.bgw_library_name, BGW_MAXLEN, "neon"); - snprintf(bgw.bgw_function_name, BGW_MAXLEN, "communicator_new_bgworker_main"); - snprintf(bgw.bgw_name, BGW_MAXLEN, "Storage communicator process"); - snprintf(bgw.bgw_type, BGW_MAXLEN, "Storage communicator process"); - bgw.bgw_restart_time = 5; - bgw.bgw_notify_pid = 0; - bgw.bgw_main_arg = (Datum) 0; - - RegisterBackgroundWorker(&bgw); -} - static size_t communicator_new_shmem_size(void) { size_t size = 0; int num_request_slots; - if (!neon_use_communicator_worker) - return 0; - size += MAXALIGN( offsetof(CommunicatorShmemData, backends) + MaxProcs * sizeof(CommunicatorShmemPerBackendData) @@ -235,9 +201,6 @@ communicator_new_shmem_size(void) void CommunicatorNewShmemRequest(void) { - if (!neon_use_communicator_worker) - return; - RequestAddinShmemSpace(communicator_new_shmem_size()); } @@ -253,8 +216,7 @@ CommunicatorNewShmemInit(void) uint64 initial_file_cache_size; uint64 max_file_cache_size; - if (!neon_use_communicator_worker) - return; + /* FIXME: much of this could be skipped if !neon_use_communicator_worker */ rc = pipe(pipefd); if (rc != 0) @@ -302,203 +264,6 @@ CommunicatorNewShmemInit(void) /**** Worker process functions. These run in the communicator worker process ****/ -/* Entry point for the communicator bgworker process */ -void -communicator_new_bgworker_main(Datum main_arg) -{ - char **connstrings; - ShardMap shard_map; - uint64 file_cache_size; - struct LoggingReceiver *logging; - const struct CommunicatorWorkerProcessStruct *proc_handle; - - /* - * Pretend that this process is a WAL sender. That affects the shutdown - * sequence: WAL senders are shut down last, after the final checkpoint - * has been written. That's what we want for the communicator process too - */ - am_walsender = true; - MarkPostmasterChildWalSender(); - - /* lfc_size_limit is in MBs */ - file_cache_size = lfc_size_limit * (1024 * 1024 / BLCKSZ); - if (file_cache_size < 100) - file_cache_size = 100; - - /* Establish signal handlers. */ - pqsignal(SIGUSR1, procsignal_sigusr1_handler); - /* - * Postmaster sends us SIGUSR2 when all regular backends and bgworkers - * have exited, and it's time for us to exit too - */ - pqsignal(SIGUSR2, die); - pqsignal(SIGHUP, SignalHandlerForConfigReload); - pqsignal(SIGTERM, die); - - BackgroundWorkerUnblockSignals(); - - if (!parse_shard_map(pageserver_grpc_urls, &shard_map)) - { - /* shouldn't happen, as the GUC was verified already */ - elog(FATAL, "could not parse neon.pageserver_grpcs_urls"); - } - connstrings = palloc(shard_map.num_shards * sizeof(char *)); - for (int i = 0; i < shard_map.num_shards; i++) - connstrings[i] = shard_map.connstring[i]; - - /* - * By default, INFO messages are not printed to the log. We want - * `tracing::info!` messages emitted from the communicator to be printed, - * however, so increase the log level. - * - * XXX: This overrides any user-set value from the config file. That's not - * great, but on the other hand, there should be little reason for user to - * control the verbosity of the communicator. 
It's not too verbose by - * default. - */ - SetConfigOption("log_min_messages", "INFO", PGC_SUSET, PGC_S_OVERRIDE); - - logging = communicator_worker_configure_logging(); - - elog(LOG, "launching worker process threads"); - proc_handle = communicator_worker_process_launch( - cis, - neon_tenant, - neon_timeline, - neon_auth_token, - connstrings, - shard_map.num_shards, - neon_stripe_size, - lfc_path, - file_cache_size); - pfree(connstrings); - cis = NULL; - if (proc_handle == NULL) - { - /* - * Something went wrong. Before exiting, forward any log messages that - * might've been generated during the failed launch. - */ - pump_logging(logging); - - elog(PANIC, "failure launching threads"); - } - - elog(LOG, "communicator threads started"); - for (;;) - { - ResetLatch(MyLatch); - - /* - * Forward any log messages from the Rust threads into the normal - * Postgres logging facility. - */ - pump_logging(logging); - - CHECK_FOR_INTERRUPTS(); - - if (ConfigReloadPending) - { - ConfigReloadPending = false; - ProcessConfigFile(PGC_SIGHUP); - - /* lfc_size_limit is in MBs */ - file_cache_size = lfc_size_limit * (1024 * 1024 / BLCKSZ); - if (file_cache_size < 100) - file_cache_size = 100; - - /* Reload pageserver URLs */ - if (!parse_shard_map(pageserver_grpc_urls, &shard_map)) - { - /* shouldn't happen, as the GUC was verified already */ - elog(FATAL, "could not parse neon.pageserver_grpcs_urls"); - } - connstrings = palloc(shard_map.num_shards * sizeof(char *)); - for (int i = 0; i < shard_map.num_shards; i++) - connstrings[i] = shard_map.connstring[i]; - - communicator_worker_config_reload(proc_handle, - file_cache_size, - connstrings, - shard_map.num_shards); - pfree(connstrings); - } - - (void) WaitLatch(MyLatch, - WL_LATCH_SET | WL_EXIT_ON_PM_DEATH, - 0, - PG_WAIT_EXTENSION); - } -} - -static void -pump_logging(struct LoggingReceiver *logging) -{ - char errbuf[1000]; - int elevel; - int32 rc; - static uint64_t last_dropped_event_count = 0; - uint64_t dropped_event_count; - uint64_t dropped_now; - - for (;;) - { - rc = communicator_worker_poll_logging(logging, - errbuf, - sizeof(errbuf), - &elevel, - &dropped_event_count); - if (rc == 0) - { - /* nothing to do */ - break; - } - else if (rc == 1) - { - /* Because we don't want to exit on error */ - - if (message_level_is_interesting(elevel)) - { - /* - * Prevent interrupts while cleaning up. - * - * (Not sure if this is required, but all the error handlers - * in Postgres that are installed as sigsetjmp() targets do - * this, so let's follow the example) - */ - HOLD_INTERRUPTS(); - - errstart(elevel, TEXTDOMAIN); - errmsg_internal("[COMMUNICATOR] %s", errbuf); - EmitErrorReport(); - FlushErrorState(); - - /* Now we can allow interrupts again */ - RESUME_INTERRUPTS(); - } - } - else if (rc == -1) - { - elog(ERROR, "logging channel was closed unexpectedly"); - } - } - - /* - * If the queue was full at any time since the last time we reported it, - * report how many messages were lost. We do this outside the loop, so - * that if the logging system is clogged, we don't exacerbate it by - * printing lots of warnings about dropped messages. - */ - dropped_now = dropped_event_count - last_dropped_event_count; - if (dropped_now != 0) - { - elog(WARNING, "%lu communicator log messages were dropped because the log buffer was full", - (unsigned long) dropped_now); - last_dropped_event_count = dropped_event_count; - } -} - - /* * Callbacks from the rust code, in the communicator process. 
* @@ -514,45 +279,6 @@ notify_proc_unsafe(int procno) } -void -callback_set_my_latch_unsafe(void) -{ - SetLatch(MyLatch); -} - -/* - * FIXME: The logic from neon_get_request_lsns() needs to go here, except for - * the last-written LSN cache stuff, which is managed by the rust code now. - */ -uint64_t -callback_get_request_lsn_unsafe(void) -{ - /* - * NB: be very careful with what you do here! This is called from tokio - * threads, so anything tha tries to take LWLocks is unsafe, for example. - * - * RecoveryInProgress() is OK - */ - if (RecoveryInProgress()) - { - XLogRecPtr replay_lsn = GetXLogReplayRecPtr(NULL); - - return replay_lsn; - } - else - { - XLogRecPtr flushlsn; - -#if PG_VERSION_NUM >= 150000 - flushlsn = GetFlushRecPtr(NULL); -#else - flushlsn = GetFlushRecPtr(); -#endif - - return flushlsn; - } -} - /**** Backend functions. These run in each backend ****/ /* Initialize per-backend private state */ @@ -779,6 +505,7 @@ start_request(NeonIORequest *request, struct NeonIOResult *immediate_result_p) my_next_slot_idx++; if (my_next_slot_idx == my_end_slot_idx) my_next_slot_idx = my_start_slot_idx; + inflight_requests[num_inflight_requests] = request_idx; num_inflight_requests++; @@ -856,13 +583,14 @@ bool communicator_new_rel_exists(NRelFileInfo rinfo, ForkNumber forkNum) { NeonIORequest request = { - .tag = NeonIORequest_RelExists, - .rel_exists = { + .tag = NeonIORequest_RelSize, + .rel_size = { .request_id = assign_request_id(), .spc_oid = NInfoGetSpcOid(rinfo), .db_oid = NInfoGetDbOid(rinfo), .rel_number = NInfoGetRelNumber(rinfo), .fork_number = forkNum, + .allow_missing = true, } }; NeonIOResult result; @@ -870,8 +598,8 @@ communicator_new_rel_exists(NRelFileInfo rinfo, ForkNumber forkNum) perform_request(&request, &result); switch (result.tag) { - case NeonIOResult_RelExists: - return result.rel_exists; + case NeonIOResult_RelSize: + return result.rel_size != InvalidBlockNumber; case NeonIOResult_Error: ereport(ERROR, (errcode_for_file_access(), @@ -879,7 +607,7 @@ communicator_new_rel_exists(NRelFileInfo rinfo, ForkNumber forkNum) RelFileInfoFmt(rinfo), forkNum, pg_strerror(result.error)))); break; default: - elog(ERROR, "unexpected result for RelExists operation: %d", result.tag); + elog(ERROR, "unexpected result for RelSize operation: %d", result.tag); break; } } @@ -1067,6 +795,7 @@ communicator_new_rel_nblocks(NRelFileInfo rinfo, ForkNumber forkNum) .db_oid = NInfoGetDbOid(rinfo), .rel_number = NInfoGetRelNumber(rinfo), .fork_number = forkNum, + .allow_missing = false, } }; NeonIOResult result; @@ -1450,22 +1179,14 @@ print_neon_io_request(NeonIORequest *request) case NeonIORequest_Empty: snprintf(buf, sizeof(buf), "Empty"); return buf; - case NeonIORequest_RelExists: - { - CRelExistsRequest *r = &request->rel_exists; - - snprintf(buf, sizeof(buf), "RelExists: req " UINT64_FORMAT " rel %u/%u/%u.%u", - r->request_id, - r->spc_oid, r->db_oid, r->rel_number, r->fork_number); - return buf; - } case NeonIORequest_RelSize: { CRelSizeRequest *r = &request->rel_size; - snprintf(buf, sizeof(buf), "RelSize: req " UINT64_FORMAT " rel %u/%u/%u.%u", - r->request_id, - r->spc_oid, r->db_oid, r->rel_number, r->fork_number); + snprintf(buf, sizeof(buf), "RelSize: req " UINT64_FORMAT " rel %u/%u/%u.%u allow_missing: %d", + r->request_id, + r->spc_oid, r->db_oid, r->rel_number, r->fork_number, + r->allow_missing); return buf; } case NeonIORequest_GetPageV: diff --git a/pgxn/neon/communicator_new.h b/pgxn/neon/communicator_new.h index 8de2fab57a..ec5d9aad07 100644 --- 
a/pgxn/neon/communicator_new.h +++ b/pgxn/neon/communicator_new.h @@ -20,7 +20,6 @@ #include "pagestore_client.h" /* initialization at postmaster startup */ -extern void pg_init_communicator_new(void); extern void CommunicatorNewShmemRequest(void); extern void CommunicatorNewShmemInit(void); diff --git a/pgxn/neon/communicator_process.c b/pgxn/neon/communicator_process.c new file mode 100644 index 0000000000..0d3342cd7c --- /dev/null +++ b/pgxn/neon/communicator_process.c @@ -0,0 +1,359 @@ +/*------------------------------------------------------------------------- + * + * communicator_process.c + * Functions for starting up the communicator background worker process. + * + * Currently, the communicator process only functions as a metrics + * exporter. It provides an HTTP endpoint for polling a limited set of + * metrics. TODO: In the future, it will do much more, i.e. handle all + * the communications with the pageservers. + * + * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + *------------------------------------------------------------------------- + */ +#include "postgres.h" + +#include + +#include "miscadmin.h" +#if PG_VERSION_NUM >= 150000 +#include "access/xlogrecovery.h" +#endif +#include "postmaster/bgworker.h" +#include "postmaster/interrupt.h" +#include "postmaster/postmaster.h" +#include "replication/walsender.h" +#include "storage/ipc.h" +#include "storage/latch.h" +#include "storage/pmsignal.h" +#include "storage/procsignal.h" +#include "tcop/tcopprot.h" +#include "utils/timestamp.h" + +#include "communicator_process.h" +#include "file_cache.h" +#include "neon.h" +#include "neon_perf_counters.h" +#include "pagestore_client.h" + +/* the rust bindings, generated by cbindgen */ +#include "communicator/communicator_bindings.h" + +struct CommunicatorInitStruct *cis; + +static void pump_logging(struct LoggingReceiver *logging); +PGDLLEXPORT void communicator_new_bgworker_main(Datum main_arg); + +/**** Initialization functions. These run in postmaster ****/ + +void +pg_init_communicator_process(void) +{ + BackgroundWorker bgw; + + /* Initialize the background worker process */ + memset(&bgw, 0, sizeof(bgw)); + bgw.bgw_flags = BGWORKER_SHMEM_ACCESS; + bgw.bgw_start_time = BgWorkerStart_PostmasterStart; + snprintf(bgw.bgw_library_name, BGW_MAXLEN, "neon"); + snprintf(bgw.bgw_function_name, BGW_MAXLEN, "communicator_new_bgworker_main"); + snprintf(bgw.bgw_name, BGW_MAXLEN, "Storage communicator process"); + snprintf(bgw.bgw_type, BGW_MAXLEN, "Storage communicator process"); + bgw.bgw_restart_time = 5; + bgw.bgw_notify_pid = 0; + bgw.bgw_main_arg = (Datum) 0; + + RegisterBackgroundWorker(&bgw); +} + +/**** Worker process functions. These run in the communicator worker process ****/ + +/* + * Entry point for the communicator bgworker process + */ +void +communicator_new_bgworker_main(Datum main_arg) +{ + char **connstrings; + ShardMap shard_map; + uint64 file_cache_size; + struct LoggingReceiver *logging; + const char *errmsg = NULL; + const struct CommunicatorWorkerProcessStruct *proc_handle; + + /* + * Pretend that this process is a WAL sender. That affects the shutdown + * sequence: WAL senders are shut down last, after the final checkpoint + * has been written. That's what we want for the communicator process too. + */ + am_walsender = true; + MarkPostmasterChildWalSender(); + + /* Establish signal handlers. 
*/ + pqsignal(SIGUSR1, procsignal_sigusr1_handler); + /* + * Postmaster sends us SIGUSR2 when all regular backends and bgworkers + * have exited, and it's time for us to exit too + */ + pqsignal(SIGUSR2, die); + pqsignal(SIGHUP, SignalHandlerForConfigReload); + pqsignal(SIGTERM, die); + + BackgroundWorkerUnblockSignals(); + + /* lfc_size_limit is in MBs */ + file_cache_size = lfc_size_limit * (1024 * 1024 / BLCKSZ); + if (file_cache_size < 100) + file_cache_size = 100; + + if (!parse_shard_map(pageserver_grpc_urls, &shard_map)) + { + /* shouldn't happen, as the GUC was verified already */ + elog(FATAL, "could not parse neon.pageserver_grpcs_urls"); + } + connstrings = palloc(shard_map.num_shards * sizeof(char *)); + for (int i = 0; i < shard_map.num_shards; i++) + connstrings[i] = shard_map.connstring[i]; + + /* + * By default, INFO messages are not printed to the log. We want + * `tracing::info!` messages emitted from the communicator to be printed, + * however, so increase the log level. + * + * XXX: This overrides any user-set value from the config file. That's not + * great, but on the other hand, there should be little reason for user to + * control the verbosity of the communicator. It's not too verbose by + * default. + */ + SetConfigOption("log_min_messages", "INFO", PGC_SUSET, PGC_S_OVERRIDE); + + logging = communicator_worker_configure_logging(); + + Assert(cis != NULL); + proc_handle = communicator_worker_process_launch( + cis, + neon_tenant[0] == '\0' ? NULL : neon_tenant, + neon_timeline[0] == '\0' ? NULL : neon_timeline, + neon_auth_token, + connstrings, + shard_map.num_shards, + neon_stripe_size, + lfc_path, + file_cache_size, + &errmsg); + pfree(connstrings); + cis = NULL; + if (proc_handle == NULL) + { + /* + * Something went wrong. Before exiting, forward any log messages that + * might've been generated during the failed launch. + */ + pump_logging(logging); + + elog(PANIC, "%s", errmsg); + } + + /* + * The Rust tokio runtime has been launched, and it's running in the + * background now. This loop in the main thread handles any interactions + * we need with the rest of PostgreSQL. + * + * NB: This process is now multi-threaded! The Rust threads do not call + * into any Postgres functions, but it's not entirely clear which Postgres + * functions are safe to call from this main thread either. Be very + * careful about adding anything non-trivial here. + * + * Also note that we try to react quickly to any log messages arriving + * from the Rust thread. Be careful to not do anything too expensive here + * that might cause delays. + */ + elog(LOG, "communicator threads started"); + for (;;) + { + TimestampTz before; + long duration; + + ResetLatch(MyLatch); + + /* + * Forward any log messages from the Rust threads into the normal + * Postgres logging facility. + */ + pump_logging(logging); + + /* + * Check interrupts like system shutdown or config reload + * + * We mustn't block for too long within this loop, or we risk the log + * queue to fill up and messages to be lost. Also, even if we can keep + * up, if there's a long delay between sending a message and printing + * it to the log, the timestamps on the messages get skewed, which is + * confusing. + * + * We expect processing interrupts to happen fast enough that it's OK, + * but measure it just in case, and print a warning if it takes longer + * than 100 ms. 
+ */ +#define LOG_SKEW_WARNING_MS 100 + before = GetCurrentTimestamp(); + + CHECK_FOR_INTERRUPTS(); + if (ConfigReloadPending) + { + ConfigReloadPending = false; + ProcessConfigFile(PGC_SIGHUP); + + /* lfc_size_limit is in MBs */ + file_cache_size = lfc_size_limit * (1024 * 1024 / BLCKSZ); + if (file_cache_size < 100) + file_cache_size = 100; + + /* Reload pageserver URLs */ + if (!parse_shard_map(pageserver_grpc_urls, &shard_map)) + { + /* shouldn't happen, as the GUC was verified already */ + elog(FATAL, "could not parse neon.pageserver_grpcs_urls"); + } + connstrings = palloc(shard_map.num_shards * sizeof(char *)); + for (int i = 0; i < shard_map.num_shards; i++) + connstrings[i] = shard_map.connstring[i]; + + communicator_worker_config_reload(proc_handle, + file_cache_size, + connstrings, + shard_map.num_shards); + pfree(connstrings); + } + + duration = TimestampDifferenceMilliseconds(before, GetCurrentTimestamp()); + if (duration > LOG_SKEW_WARNING_MS) + elog(WARNING, "handling interrupts took %ld ms, communicator log timestamps might be skewed", duration); + + /* + * Wait until we are woken up. The rust threads will set the latch + * when there's a log message to forward. + */ + (void) WaitLatch(MyLatch, + WL_LATCH_SET | WL_EXIT_ON_PM_DEATH, + 0, + PG_WAIT_EXTENSION); + } +} + +static void +pump_logging(struct LoggingReceiver *logging) +{ + char errbuf[1000]; + int elevel; + int32 rc; + static uint64_t last_dropped_event_count = 0; + uint64_t dropped_event_count; + uint64_t dropped_now; + + for (;;) + { + rc = communicator_worker_poll_logging(logging, + errbuf, + sizeof(errbuf), + &elevel, + &dropped_event_count); + if (rc == 0) + { + /* nothing to do */ + break; + } + else if (rc == 1) + { + /* Because we don't want to exit on error */ + + if (message_level_is_interesting(elevel)) + { + /* + * Prevent interrupts while cleaning up. + * + * (Not sure if this is required, but all the error handlers + * in Postgres that are installed as sigsetjmp() targets do + * this, so let's follow the example) + */ + HOLD_INTERRUPTS(); + + errstart(elevel, TEXTDOMAIN); + errmsg_internal("[COMMUNICATOR] %s", errbuf); + EmitErrorReport(); + FlushErrorState(); + + /* Now we can allow interrupts again */ + RESUME_INTERRUPTS(); + } + } + else if (rc == -1) + { + elog(ERROR, "logging channel was closed unexpectedly"); + } + } + + /* + * If the queue was full at any time since the last time we reported it, + * report how many messages were lost. We do this outside the loop, so + * that if the logging system is clogged, we don't exacerbate it by + * printing lots of warnings about dropped messages. + */ + dropped_now = dropped_event_count - last_dropped_event_count; + if (dropped_now != 0) + { + elog(WARNING, "%lu communicator log messages were dropped because the log buffer was full", + (unsigned long) dropped_now); + last_dropped_event_count = dropped_event_count; + } +} + +/**** + * Callbacks from the rust code, in the communicator process. + * + * NOTE: These must be thread-safe! It's very limited which PostgreSQL + * functions you can use!!! + * + * The signatures of these need to match those in the Rust code. + */ + +void +callback_set_my_latch_unsafe(void) +{ + SetLatch(MyLatch); +} + +/* + * FIXME: The logic from neon_get_request_lsns() needs to go here, except for + * the last-written LSN cache stuff, which is managed by the rust code now. + */ +uint64_t +callback_get_request_lsn_unsafe(void) +{ + /* + * NB: be very careful with what you do here! 
This is called from tokio
+     * threads, so anything that tries to take LWLocks is unsafe, for example.
+     *
+     * RecoveryInProgress() is OK
+     */
+    if (RecoveryInProgress())
+    {
+        XLogRecPtr  replay_lsn = GetXLogReplayRecPtr(NULL);
+
+        return replay_lsn;
+    }
+    else
+    {
+        XLogRecPtr  flushlsn;
+
+#if PG_VERSION_NUM >= 150000
+        flushlsn = GetFlushRecPtr(NULL);
+#else
+        flushlsn = GetFlushRecPtr();
+#endif
+
+        return flushlsn;
+    }
+}
diff --git a/pgxn/neon/communicator_process.h b/pgxn/neon/communicator_process.h
new file mode 100644
index 0000000000..58c8919d1b
--- /dev/null
+++ b/pgxn/neon/communicator_process.h
@@ -0,0 +1,20 @@
+/*-------------------------------------------------------------------------
+ *
+ * communicator_process.h
+ *    Communicator process
+ *
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ *-------------------------------------------------------------------------
+ */
+#ifndef COMMUNICATOR_PROCESS_H
+#define COMMUNICATOR_PROCESS_H
+
+extern struct CommunicatorInitStruct *cis;
+
+/* initialization early at postmaster startup */
+extern void pg_init_communicator_process(void);
+
+#endif   /* COMMUNICATOR_PROCESS_H */
diff --git a/pgxn/neon/file_cache.c b/pgxn/neon/file_cache.c
index 7c408c82da..0370da3fbd 100644
--- a/pgxn/neon/file_cache.c
+++ b/pgxn/neon/file_cache.c
@@ -52,6 +52,8 @@
 #include "pagestore_client.h"
 #include "communicator.h"

+#include "communicator/communicator_bindings.h"
+
 #define CriticalAssert(cond) do if (!(cond)) elog(PANIC, "LFC: assertion %s failed at %s:%d: ", #cond, __FILE__, __LINE__); while (0)

 /*
@@ -1862,3 +1864,34 @@ lfc_approximate_working_set_size_seconds(time_t duration, bool reset)
 	memset(lfc_ctl->wss_estimation.regs, 0, sizeof lfc_ctl->wss_estimation.regs);
 	return dc;
 }
+
+/*
+ * Get metrics, for the built-in metrics exporter that's part of the communicator
+ * process.
+ *
+ * NB: This is called from a Rust tokio task inside the communicator process.
+ * Acquiring lwlocks, elog(), allocating memory or anything else non-trivial
+ * is strictly prohibited here!
+ */
+struct LfcMetrics
+callback_get_lfc_metrics_unsafe(void)
+{
+    struct LfcMetrics result = {
+        .lfc_cache_size_limit = (int64) lfc_size_limit * 1024 * 1024,
+        .lfc_hits = lfc_ctl ? lfc_ctl->hits : 0,
+        .lfc_misses = lfc_ctl ? lfc_ctl->misses : 0,
+        .lfc_used = lfc_ctl ? lfc_ctl->used : 0,
+        .lfc_writes = lfc_ctl ?
lfc_ctl->writes : 0, + }; + + if (lfc_ctl) + { + for (int minutes = 1; minutes <= 60; minutes++) + { + result.lfc_approximate_working_set_size_windows[minutes - 1] = + lfc_approximate_working_set_size_seconds(minutes * 60, false); + } + } + + return result; +} diff --git a/pgxn/neon/lfc_prewarm.c b/pgxn/neon/lfc_prewarm.c index 2acb805f9d..680272fb8a 100644 --- a/pgxn/neon/lfc_prewarm.c +++ b/pgxn/neon/lfc_prewarm.c @@ -541,7 +541,7 @@ lfc_prewarm_with_async_requests(FileCacheState *fcs) request_startblkno = request_endblkno = InvalidBlockNumber; } - Assert(n_sent == n_received || prewarm_ctl->prewarm_canceled); + Assert(prewarm_ctl->prewarm_canceled); elog(LOG, "LFC: complete prewarming: loaded %lu pages", (unsigned long) prewarm_ctl->prewarmed_pages); prewarm_ctl->completed = GetCurrentTimestamp(); diff --git a/pgxn/neon/neon.c b/pgxn/neon/neon.c index 59ecd9ab1c..655fdc6faa 100644 --- a/pgxn/neon/neon.c +++ b/pgxn/neon/neon.c @@ -33,6 +33,7 @@ #include "communicator.h" #include "communicator_new.h" +#include "communicator_process.h" #include "extension_server.h" #include "file_cache.h" #include "neon.h" @@ -46,9 +47,6 @@ #include "storage/ipc.h" #endif -/* the rust bindings, generated by cbindgen */ -#include "communicator/communicator_bindings.h" - PG_MODULE_MAGIC; void _PG_init(void); @@ -507,8 +505,9 @@ _PG_init(void) pg_init_walproposer(); init_lwlsncache(); + pg_init_communicator_process(); + pg_init_communicator(); - pg_init_communicator_new(); Custom_XLogReaderRoutines = NeonOnDemandXLogReaderRoutines; diff --git a/pgxn/neon/neon_pgversioncompat.h b/pgxn/neon/neon_pgversioncompat.h index 85646a6dc5..dbe0e5aa3d 100644 --- a/pgxn/neon/neon_pgversioncompat.h +++ b/pgxn/neon/neon_pgversioncompat.h @@ -90,8 +90,7 @@ InitBufferTag(BufferTag *tag, const RelFileNode *rnode, #define InvalidRelFileNumber InvalidOid -#define SMgrRelGetRelInfo(reln) \ - (reln->smgr_rnode.node) +#define SMgrRelGetRelInfo(reln) ((reln)->smgr_rnode.node) #define DropRelationAllLocalBuffers DropRelFileNodeAllLocalBuffers @@ -146,8 +145,7 @@ InitBufferTag(BufferTag *tag, const RelFileNode *rnode, (tag).relNumber = (rel_number); \ } while (false) -#define SMgrRelGetRelInfo(reln) \ - ((reln)->smgr_rlocator) +#define SMgrRelGetRelInfo(reln) ((reln)->smgr_rlocator) #define DropRelationAllLocalBuffers DropRelationAllLocalBuffers #endif diff --git a/pgxn/neon/walproposer_pg.c b/pgxn/neon/walproposer_pg.c index 9ed8d0d2d2..93807be8c2 100644 --- a/pgxn/neon/walproposer_pg.c +++ b/pgxn/neon/walproposer_pg.c @@ -400,6 +400,14 @@ static uint64 backpressure_lag_impl(void) { struct WalproposerShmemState* state = NULL; + + /* BEGIN_HADRON */ + if(max_cluster_size < 0){ + // if max cluster size is not set, then we don't apply backpressure because we're reconfiguring PG + return 0; + } + /* END_HADRON */ + if (max_replication_apply_lag > 0 || max_replication_flush_lag > 0 || max_replication_write_lag > 0) { XLogRecPtr writePtr; diff --git a/poetry.lock b/poetry.lock index b2072bf1bc..a920833fbf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. 
[[package]] name = "aiohappyeyeballs" @@ -3068,6 +3068,21 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "requests-unixsocket" +version = "0.4.1" +description = "Use requests to talk HTTP via a UNIX domain socket" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "requests_unixsocket-0.4.1-py3-none-any.whl", hash = "sha256:60c4942e9dbecc2f64d611039fb1dfc25da382083c6434ac0316dca3ff908f4d"}, + {file = "requests_unixsocket-0.4.1.tar.gz", hash = "sha256:b2596158c356ecee68d27ba469a52211230ac6fb0cde8b66afb19f0ed47a1995"}, +] + +[package.dependencies] +requests = ">=1.1" + [[package]] name = "responses" version = "0.25.3" @@ -3844,4 +3859,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.1" python-versions = "^3.11" -content-hash = "6a1e8ba06b8194bf28d87fd5e184e2ddc2b4a19dffcbe3953b26da3d55c9212f" +content-hash = "b08aba407631b0341d2ef8bf9acffd733bfc7d32b12d344717ab4c7fef697625" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 82fe6818e3..3c3f93c8e3 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -7,6 +7,7 @@ license.workspace = true [features] default = [] testing = ["dep:tokio-postgres"] +rest_broker = ["dep:subzero-core", "dep:ouroboros"] [dependencies] ahash.workspace = true @@ -65,6 +66,7 @@ postgres-client = { package = "tokio-postgres2", path = "../libs/proxy/tokio-pos postgres-protocol = { package = "postgres-protocol2", path = "../libs/proxy/postgres-protocol2" } pq_proto.workspace = true rand.workspace = true +rand_core.workspace = true regex.workspace = true remote_storage = { version = "0.1", path = "../libs/remote_storage/" } reqwest = { workspace = true, features = ["rustls-tls-native-roots"] } @@ -105,6 +107,11 @@ uuid.workspace = true x509-cert.workspace = true redis.workspace = true zerocopy.workspace = true +# uncomment this to use the real subzero-core crate +# subzero-core = { git = "https://github.com/neondatabase/subzero", rev = "396264617e78e8be428682f87469bb25429af88a", features = ["postgresql"], optional = true } +# this is a stub for the subzero-core crate +subzero-core = { path = "./subzero_core", features = ["postgresql"], optional = true} +ouroboros = { version = "0.18", optional = true } # jwt stuff jose-jwa = "0.1.2" @@ -127,6 +134,6 @@ pbkdf2 = { workspace = true, features = ["simple", "std"] } rcgen.workspace = true rstest.workspace = true walkdir.workspace = true -rand_distr = "0.4" +rand_distr = "0.5" tokio-postgres.workspace = true tracing-test = "0.2" diff --git a/proxy/README.md b/proxy/README.md index ff48f9f323..ce957b90af 100644 --- a/proxy/README.md +++ b/proxy/README.md @@ -178,16 +178,24 @@ Create a configuration file called `local_proxy.json` in the root of the repo (u Start the local proxy: ```sh -cargo run --bin local_proxy -- \ - --disable_pg_session_jwt true \ +cargo run --bin local_proxy --features testing -- \ + --disable-pg-session-jwt \ --http 0.0.0.0:7432 ``` -Start the auth broker: +Start the auth/rest broker: + +Note: to enable the rest broker you need to replace the stub subzero-core crate with the real one. 
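+
+The `cargo add` invocation below swaps the stubbed `subzero-core` dependency in
+`proxy/Cargo.toml` for the real crate; it needs git access to the `neondatabase/subzero`
+repository. To switch back to the stub afterwards, restore the manifest:
+
+```sh
+git checkout -- proxy/Cargo.toml Cargo.lock
+```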
+ ```sh -LOGFMT=text OTEL_SDK_DISABLED=true cargo run --bin proxy --features testing -- \ +cargo add -p proxy subzero-core --git https://github.com/neondatabase/subzero --rev 396264617e78e8be428682f87469bb25429af88a +``` + +```sh +LOGFMT=text OTEL_SDK_DISABLED=true cargo run --bin proxy --features testing,rest_broker -- \ -c server.crt -k server.key \ --is-auth-broker true \ + --is-rest-broker true \ --wss 0.0.0.0:8080 \ --http 0.0.0.0:7002 \ --auth-backend local @@ -205,3 +213,9 @@ curl -k "https://foo.local.neon.build:8080/sql" \ -H "neon-connection-string: postgresql://authenticator@foo.local.neon.build/database" \ -d '{"query":"select 1","params":[]}' ``` + +Make a rest request against the auth broker (rest broker): +```sh +curl -k "https://foo.local.neon.build:8080/database/rest/v1/items?select=id,name&id=eq.1" \ +-H "Authorization: Bearer $NEON_JWT" +``` diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index f561df9202..b06ed3a0ae 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -180,8 +180,6 @@ async fn authenticate( return Err(auth::AuthError::NetworkNotAllowed); } - client.write_message(BeMessage::NoticeResponse("Connecting to database.")); - // Backwards compatibility. pg_sni_proxy uses "--" in domain names // while direct connections do not. Once we migrate to pg_sni_proxy // everywhere, we can remove this. diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index a716890a00..6eba869870 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -803,7 +803,7 @@ mod tests { use http_body_util::Full; use hyper::service::service_fn; use hyper_util::rt::TokioIo; - use rand::rngs::OsRng; + use rand_core::OsRng; use rsa::pkcs8::DecodePrivateKey; use serde::Serialize; use serde_json::json; diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs index 401203d48c..7b9012dc69 100644 --- a/proxy/src/binary/local_proxy.rs +++ b/proxy/src/binary/local_proxy.rs @@ -1,3 +1,4 @@ +use std::env; use std::net::SocketAddr; use std::pin::pin; use std::sync::Arc; @@ -20,6 +21,8 @@ use crate::auth::backend::jwt::JwkCache; use crate::auth::backend::local::LocalBackend; use crate::auth::{self}; use crate::cancellation::CancellationHandler; +#[cfg(feature = "rest_broker")] +use crate::config::RestConfig; use crate::config::{ self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig, refresh_config_loop, @@ -262,6 +265,14 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig timeout: Duration::from_secs(2), }; + let greetings = env::var_os("NEON_MOTD").map_or(String::new(), |s| match s.into_string() { + Ok(s) => s, + Err(_) => { + debug!("NEON_MOTD environment variable is not valid UTF-8"); + String::new() + } + }); + Ok(Box::leak(Box::new(ProxyConfig { tls_config: ArcSwapOption::from(None), metric_collection: None, @@ -276,11 +287,19 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig accept_jwts: true, console_redirect_confirmation_timeout: Duration::ZERO, }, + #[cfg(feature = "rest_broker")] + rest_config: RestConfig { + is_rest_broker: false, + db_schema_cache: None, + max_schema_size: 0, + hostname_prefix: String::new(), + }, proxy_protocol_v2: config::ProxyProtocolV2::Rejected, handshake_timeout: Duration::from_secs(10), wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?, connect_compute_locks, 
connect_to_compute: compute_config, + greetings, #[cfg(feature = "testing")] disable_pg_session_jwt: args.disable_pg_session_jwt, }))) diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs index 4ac8b6a995..f3782312dc 100644 --- a/proxy/src/binary/pg_sni_router.rs +++ b/proxy/src/binary/pg_sni_router.rs @@ -76,7 +76,7 @@ fn cli() -> clap::Command { } pub async fn run() -> anyhow::Result<()> { - let _logging_guard = crate::logging::init().await?; + let _logging_guard = crate::logging::init()?; let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook(); let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]); diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 16a7dc7b67..4148f4bc62 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -1,4 +1,3 @@ -#[cfg(any(test, feature = "testing"))] use std::env; use std::net::SocketAddr; use std::path::PathBuf; @@ -14,14 +13,14 @@ use arc_swap::ArcSwapOption; use camino::Utf8PathBuf; use futures::future::Either; use itertools::{Itertools, Position}; -use rand::{Rng, thread_rng}; +use rand::Rng; use remote_storage::RemoteStorageConfig; use tokio::net::TcpListener; #[cfg(any(test, feature = "testing"))] use tokio::sync::Notify; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; -use tracing::{error, info, warn}; +use tracing::{debug, error, info, warn}; use utils::sentry_init::init_sentry; use utils::{project_build_tag, project_git_version}; @@ -31,6 +30,8 @@ use crate::auth::backend::local::LocalBackend; use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned}; use crate::batch::BatchQueue; use crate::cancellation::{CancellationHandler, CancellationProcessor}; +#[cfg(feature = "rest_broker")] +use crate::config::RestConfig; #[cfg(any(test, feature = "testing"))] use crate::config::refresh_config_loop; use crate::config::{ @@ -47,6 +48,8 @@ use crate::redis::{elasticache, notifications}; use crate::scram::threadpool::ThreadPool; use crate::serverless::GlobalConnPoolOptions; use crate::serverless::cancel_set::CancelSet; +#[cfg(feature = "rest_broker")] +use crate::serverless::rest::DbSchemaCache; use crate::tls::client_config::compute_client_config_with_root_certs; #[cfg(any(test, feature = "testing"))] use crate::url::ApiUrl; @@ -246,11 +249,23 @@ struct ProxyCliArgs { /// if this is not local proxy, this toggles whether we accept Postgres REST requests #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] + #[cfg(feature = "rest_broker")] is_rest_broker: bool, /// cache for `db_schema_cache` introspection (use `size=0` to disable) #[clap(long, default_value = "size=1000,ttl=1h")] + #[cfg(feature = "rest_broker")] db_schema_cache: String, + + /// Maximum size allowed for schema in bytes + #[clap(long, default_value_t = 5 * 1024 * 1024)] // 5MB + #[cfg(feature = "rest_broker")] + max_schema_size: usize, + + /// Hostname prefix to strip from request hostname to get database hostname + #[clap(long, default_value = "apirest.")] + #[cfg(feature = "rest_broker")] + hostname_prefix: String, } #[derive(clap::Args, Clone, Copy, Debug)] @@ -319,7 +334,7 @@ struct PgSniRouterArgs { } pub async fn run() -> anyhow::Result<()> { - let _logging_guard = crate::logging::init().await?; + let _logging_guard = crate::logging::init()?; let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook(); let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), 
&[]); @@ -517,6 +532,17 @@ pub async fn run() -> anyhow::Result<()> { )); maintenance_tasks.spawn(control_plane::mgmt::task_main(mgmt_listener)); + // add a task to flush the db_schema cache every 10 minutes + #[cfg(feature = "rest_broker")] + if let Some(db_schema_cache) = &config.rest_config.db_schema_cache { + maintenance_tasks.spawn(async move { + loop { + tokio::time::sleep(Duration::from_secs(600)).await; + db_schema_cache.flush(); + } + }); + } + if let Some(metrics_config) = &config.metric_collection { // TODO: Add gc regardles of the metric collection being enabled. maintenance_tasks.spawn(usage_metrics::task_main(metrics_config)); @@ -547,7 +573,7 @@ pub async fn run() -> anyhow::Result<()> { attempt.into_inner() ); } - let jitter = thread_rng().gen_range(0..100); + let jitter = rand::rng().random_range(0..100); tokio::time::sleep(Duration::from_millis(1000 + jitter)).await; } } @@ -679,6 +705,49 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { timeout: Duration::from_secs(2), }; + #[cfg(feature = "rest_broker")] + let rest_config = { + let db_schema_cache_config: CacheOptions = args.db_schema_cache.parse()?; + info!("Using DbSchemaCache with options={db_schema_cache_config:?}"); + + let db_schema_cache = if args.is_rest_broker { + Some(DbSchemaCache::new( + "db_schema_cache", + db_schema_cache_config.size, + db_schema_cache_config.ttl, + true, + )) + } else { + None + }; + + RestConfig { + is_rest_broker: args.is_rest_broker, + db_schema_cache, + max_schema_size: args.max_schema_size, + hostname_prefix: args.hostname_prefix.clone(), + } + }; + + let mut greetings = env::var_os("NEON_MOTD").map_or(String::new(), |s| match s.into_string() { + Ok(s) => s, + Err(_) => { + debug!("NEON_MOTD environment variable is not valid UTF-8"); + String::new() + } + }); + + match &args.auth_backend { + AuthBackendType::ControlPlane => {} + #[cfg(any(test, feature = "testing"))] + AuthBackendType::Postgres => {} + #[cfg(any(test, feature = "testing"))] + AuthBackendType::Local => {} + AuthBackendType::ConsoleRedirect => { + greetings = "Connected to database".to_string(); + } + } + let config = ProxyConfig { tls_config, metric_collection, @@ -689,8 +758,11 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?, connect_compute_locks, connect_to_compute: compute_config, + greetings, #[cfg(feature = "testing")] disable_pg_session_jwt: false, + #[cfg(feature = "rest_broker")] + rest_config, }; let config = Box::leak(Box::new(config)); diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 0ef09a8a9a..a589dd175b 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -5,7 +5,7 @@ use std::time::Duration; use async_trait::async_trait; use clashmap::ClashMap; use clashmap::mapref::one::Ref; -use rand::{Rng, thread_rng}; +use rand::Rng; use tokio::time::Instant; use tracing::{debug, info}; @@ -343,7 +343,7 @@ impl ProjectInfoCacheImpl { } fn gc(&self) { - let shard = thread_rng().gen_range(0..self.project2ep.shards().len()); + let shard = rand::rng().random_range(0..self.project2ep.shards().len()); debug!(shard, "project_info_cache: performing epoch reclamation"); // acquire a random shard lock diff --git a/proxy/src/cache/timed_lru.rs b/proxy/src/cache/timed_lru.rs index e87cf53ab9..0a7fb40b0c 100644 --- a/proxy/src/cache/timed_lru.rs +++ b/proxy/src/cache/timed_lru.rs @@ -204,6 +204,11 @@ impl TimedLru { 
self.insert_raw_ttl(key, value, ttl, false);
 }
 
+    #[cfg(feature = "rest_broker")]
+    pub(crate) fn insert(&self, key: K, value: V) {
+        self.insert_raw_ttl(key, value, self.ttl, self.update_ttl_on_retrieval);
+    }
+
 pub(crate) fn insert_unit(&self, key: K, value: V) -> (Option<V>, Cached<&Self, ()>) {
 let (_, old) = self.insert_raw(key.clone(), value);
 
@@ -214,6 +219,29 @@ impl TimedLru {
 (old, cached)
 }
+
+    #[cfg(feature = "rest_broker")]
+    pub(crate) fn flush(&self) {
+        let now = Instant::now();
+        let mut cache = self.cache.lock();
+
+        // Collect keys of expired entries first
+        let expired_keys: Vec<_> = cache
+            .iter()
+            .filter_map(|(key, entry)| {
+                if entry.expires_at <= now {
+                    Some(key.clone())
+                } else {
+                    None
+                }
+            })
+            .collect();
+
+        // Remove expired entries
+        for key in expired_keys {
+            cache.remove(&key);
+        }
+    }
 }
 
 impl TimedLru {
diff --git a/proxy/src/config.rs b/proxy/src/config.rs
index 6157dc8a6a..16b1dff5f4 100644
--- a/proxy/src/config.rs
+++ b/proxy/src/config.rs
@@ -22,6 +22,8 @@ use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig};
 use crate::scram::threadpool::ThreadPool;
 use crate::serverless::GlobalConnPoolOptions;
 use crate::serverless::cancel_set::CancelSet;
+#[cfg(feature = "rest_broker")]
+use crate::serverless::rest::DbSchemaCache;
 pub use crate::tls::server_config::{TlsConfig, configure_tls};
 use crate::types::{Host, RoleName};
 
@@ -30,11 +32,14 @@ pub struct ProxyConfig {
 pub metric_collection: Option<MetricCollectionConfig>,
 pub http_config: HttpConfig,
 pub authentication_config: AuthenticationConfig,
+    #[cfg(feature = "rest_broker")]
+    pub rest_config: RestConfig,
 pub proxy_protocol_v2: ProxyProtocolV2,
 pub handshake_timeout: Duration,
 pub wake_compute_retry_config: RetryConfig,
 pub connect_compute_locks: ApiLocks<Host>,
 pub connect_to_compute: ComputeConfig,
+    pub greetings: String, // Greeting message sent to the client after connection establishment; it contains the session_id.
#[cfg(feature = "testing")]
 pub disable_pg_session_jwt: bool,
}
@@ -80,6 +85,14 @@ pub struct AuthenticationConfig {
 pub console_redirect_confirmation_timeout: tokio::time::Duration,
}
 
+#[cfg(feature = "rest_broker")]
+pub struct RestConfig {
+    pub is_rest_broker: bool,
+    pub db_schema_cache: Option<DbSchemaCache>,
+    pub max_schema_size: usize,
+    pub hostname_prefix: String,
+}
+
#[derive(Debug)]
pub struct MetricBackupCollectionConfig {
 pub remote_storage_config: Option<RemoteStorageConfig>,
diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs
index 041a56e032..014317d823 100644
--- a/proxy/src/console_redirect_proxy.rs
+++ b/proxy/src/console_redirect_proxy.rs
@@ -233,7 +233,13 @@ pub(crate) async fn handle_client(
 
 let session = cancellation_handler.get_key();
 
-    finish_client_init(&pg_settings, *session.key(), &mut stream);
+    finish_client_init(
+        ctx,
+        &pg_settings,
+        *session.key(),
+        &mut stream,
+        &config.greetings,
+    );
 let stream = stream.flush_and_into_inner().await?;
 
 let session_id = ctx.session_id();
diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs
index 4d8df19476..715b818b98 100644
--- a/proxy/src/context/parquet.rs
+++ b/proxy/src/context/parquet.rs
@@ -523,29 +523,29 @@ mod tests {
 
 fn generate_request_data(rng: &mut impl Rng) -> RequestData {
 RequestData {
-            session_id: uuid::Builder::from_random_bytes(rng.r#gen()).into_uuid(),
-            peer_addr: Ipv4Addr::from(rng.r#gen::<[u8; 4]>()).to_string(),
+            session_id: uuid::Builder::from_random_bytes(rng.random()).into_uuid(),
+            peer_addr: Ipv4Addr::from(rng.random::<[u8; 4]>()).to_string(),
 timestamp: chrono::DateTime::from_timestamp_millis(
-                rng.gen_range(1703862754..1803862754),
+                rng.random_range(1703862754..1803862754),
 )
 .unwrap()
 .naive_utc(),
 application_name: Some("test".to_owned()),
 user_agent: Some("test-user-agent".to_owned()),
-            username: Some(hex::encode(rng.r#gen::<[u8; 4]>())),
-            endpoint_id: Some(hex::encode(rng.r#gen::<[u8; 16]>())),
-            database: Some(hex::encode(rng.r#gen::<[u8; 16]>())),
-            project: Some(hex::encode(rng.r#gen::<[u8; 16]>())),
-            branch: Some(hex::encode(rng.r#gen::<[u8; 16]>())),
+            username: Some(hex::encode(rng.random::<[u8; 4]>())),
+            endpoint_id: Some(hex::encode(rng.random::<[u8; 16]>())),
+            database: Some(hex::encode(rng.random::<[u8; 16]>())),
+            project: Some(hex::encode(rng.random::<[u8; 16]>())),
+            branch: Some(hex::encode(rng.random::<[u8; 16]>())),
 pg_options: None,
 auth_method: None,
 jwt_issuer: None,
-            protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)],
+            protocol: ["tcp", "ws", "http"][rng.random_range(0..3)],
 region: String::new(),
 error: None,
-            success: rng.r#gen(),
+            success: rng.random(),
 cold_start_info: "no",
-            duration_us: rng.gen_range(0..30_000_000),
+            duration_us: rng.random_range(0..30_000_000),
 disconnect_timestamp: None,
 }
 }
@@ -622,15 +622,15 @@ mod tests {
 assert_eq!(
 file_stats,
 [
-                (1313953, 3, 6000),
-                (1313942, 3, 6000),
-                (1314001, 3, 6000),
-                (1313958, 3, 6000),
-                (1314094, 3, 6000),
-                (1313931, 3, 6000),
-                (1313725, 3, 6000),
-                (1313960, 3, 6000),
-                (438318, 1, 2000)
+                (1313878, 3, 6000),
+                (1313891, 3, 6000),
+                (1314058, 3, 6000),
+                (1313914, 3, 6000),
+                (1313760, 3, 6000),
+                (1314084, 3, 6000),
+                (1313965, 3, 6000),
+                (1313911, 3, 6000),
+                (438290, 1, 2000)
 ]
 );
 
@@ -662,11 +662,11 @@ mod tests {
 assert_eq!(
 file_stats,
 [
-                (1205810, 5, 10000),
-                (1205534, 5, 10000),
-                (1205835, 5, 10000),
-                (1205820, 5, 10000),
-                (1206074, 5, 10000)
+                (1206039, 5, 10000),
+                (1205798, 5, 10000),
+                (1205776, 5, 10000),
+                (1206051, 5, 10000),
+                (1205746, 5,
10000)
 ]
 );
 
@@ -691,15 +691,15 @@
 assert_eq!(
 file_stats,
 [
-                (1313953, 3, 6000),
-                (1313942, 3, 6000),
-                (1314001, 3, 6000),
-                (1313958, 3, 6000),
-                (1314094, 3, 6000),
-                (1313931, 3, 6000),
-                (1313725, 3, 6000),
-                (1313960, 3, 6000),
-                (438318, 1, 2000)
+                (1313878, 3, 6000),
+                (1313891, 3, 6000),
+                (1314058, 3, 6000),
+                (1313914, 3, 6000),
+                (1313760, 3, 6000),
+                (1314084, 3, 6000),
+                (1313965, 3, 6000),
+                (1313911, 3, 6000),
+                (438290, 1, 2000)
 ]
 );
 
@@ -736,7 +736,7 @@ mod tests {
 // files are smaller than the size threshold, but they took too long to fill so were flushed early
 assert_eq!(
 file_stats,
-            [(658584, 2, 3001), (658298, 2, 3000), (658094, 2, 2999)]
+            [(658552, 2, 3001), (658265, 2, 3000), (658061, 2, 2999)]
 );
 
 tmpdir.close().unwrap();
diff --git a/proxy/src/intern.rs b/proxy/src/intern.rs
index d7e39ebaf4..825f2d1049 100644
--- a/proxy/src/intern.rs
+++ b/proxy/src/intern.rs
@@ -247,7 +247,7 @@ mod tests {
 use rand::{Rng, SeedableRng};
 use rand_distr::Zipf;
 
-        let endpoint_dist = Zipf::new(500000, 0.8).unwrap();
+        let endpoint_dist = Zipf::new(500000.0, 0.8).unwrap();
 let endpoints = StdRng::seed_from_u64(272488357).sample_iter(endpoint_dist);
 
 let interner = MyId::get_interner();
diff --git a/proxy/src/logging.rs b/proxy/src/logging.rs
index d4fd826c13..0abb500608 100644
--- a/proxy/src/logging.rs
+++ b/proxy/src/logging.rs
@@ -26,7 +26,7 @@ use crate::metrics::Metrics;
 /// configuration from environment variables. For example, to change the
 /// destination, set `OTEL_EXPORTER_OTLP_ENDPOINT=http://jaeger:4318`.
 /// See
-pub async fn init() -> anyhow::Result<LoggingGuard> {
+pub fn init() -> anyhow::Result<LoggingGuard> {
 let logfmt = LogFormat::from_env()?;
 
 let env_filter = EnvFilter::builder()
@@ -43,8 +43,8 @@
 .expect("this should be a valid filter directive"),
 );
 
-    let otlp_layer =
-        tracing_utils::init_tracing("proxy", tracing_utils::ExportConfig::default()).await;
+    let provider = tracing_utils::init_tracing("proxy", tracing_utils::ExportConfig::default());
+    let otlp_layer = provider.as_ref().map(tracing_utils::layer);
 
 let json_log_layer = if logfmt == LogFormat::Json {
 Some(JsonLoggingLayer::new(
@@ -76,7 +76,7 @@
 .with(text_log_layer)
 .try_init()?;
 
-    Ok(LoggingGuard)
+    Ok(LoggingGuard(provider))
 }
 
 /// Initialize logging for local_proxy with log prefix and no opentelemetry.
@@ -97,7 +97,7 @@
 .with(fmt_layer)
 .try_init()?;
 
-    Ok(LoggingGuard)
+    Ok(LoggingGuard(None))
 }
 
 pub struct LocalProxyFormatter(Format);
@@ -118,14 +118,16 @@
where
 }
}
 
-pub struct LoggingGuard;
+pub struct LoggingGuard(Option);
 
 impl Drop for LoggingGuard {
 fn drop(&mut self) {
-        // Shutdown trace pipeline gracefully, so that it has a chance to send any
-        // pending traces before we exit.
-        tracing::info!("shutting down the tracing machinery");
-        tracing_utils::shutdown_tracing();
+        if let Some(p) = &self.0 {
+            // Shutdown trace pipeline gracefully, so that it has a chance to send any
+            // pending traces before we exit.
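+            // (any error returned by shutdown() is deliberately ignored via
+            // drop() below; there is nowhere useful to report it while exiting)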
+            tracing::info!("shutting down the tracing machinery");
+            drop(p.shutdown());
+        }
 }
}
diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs
index 916604e2ec..7524133093 100644
--- a/proxy/src/metrics.rs
+++ b/proxy/src/metrics.rs
@@ -385,10 +385,10 @@ pub enum RedisMsgKind {
 
 #[derive(Default, Clone)]
 pub struct LatencyAccumulated {
-    cplane: time::Duration,
-    client: time::Duration,
-    compute: time::Duration,
-    retry: time::Duration,
+    pub cplane: time::Duration,
+    pub client: time::Duration,
+    pub compute: time::Duration,
+    pub retry: time::Duration,
}
 
 impl std::fmt::Display for LatencyAccumulated {
diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs
index ad99eecda5..680a23c435 100644
--- a/proxy/src/pqproto.rs
+++ b/proxy/src/pqproto.rs
@@ -7,7 +7,7 @@ use std::io::{self, Cursor};
 
 use bytes::{Buf, BufMut};
 use itertools::Itertools;
-use rand::distributions::{Distribution, Standard};
+use rand::distr::{Distribution, StandardUniform};
 use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
 use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian};
 
@@ -458,9 +458,9 @@ impl fmt::Display for CancelKeyData {
 .finish()
 }
}
-impl Distribution<CancelKeyData> for Standard {
+impl Distribution<CancelKeyData> for StandardUniform {
 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> CancelKeyData {
-        id_to_cancel_key(rng.r#gen())
+        id_to_cancel_key(rng.random())
 }
}
diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs
index 02651109e0..8b7c4ff55d 100644
--- a/proxy/src/proxy/mod.rs
+++ b/proxy/src/proxy/mod.rs
@@ -145,7 +145,7 @@ pub(crate) async fn handle_client(
 
 let session = cancellation_handler.get_key();
 
-    finish_client_init(&pg_settings, *session.key(), client);
+    finish_client_init(ctx, &pg_settings, *session.key(), client, &config.greetings);
 
 let session_id = ctx.session_id();
 let (cancel_on_shutdown, cancel) = oneshot::channel();
@@ -165,9 +165,11 @@
 
 /// Finish client connection initialization: confirm auth success, send params, etc.
 pub(crate) fn finish_client_init(
+    ctx: &RequestContext,
 settings: &compute::PostgresSettings,
 cancel_key_data: CancelKeyData,
 client: &mut PqStream,
+    greetings: &str,
) {
 // Forward all deferred notices to the client.
 for notice in &settings.delayed_notice {
 });
 }
 
+    // Expose session_id to clients if we have a greeting message.
+    if !greetings.is_empty() {
+        let session_msg = format!("{}, session_id: {}", greetings, ctx.session_id());
+        client.write_message(BeMessage::NoticeResponse(session_msg.as_str()));
+    }
+
 // Forward all postgres connection params to the client.
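 // (typically entries such as server_version, client_encoding and TimeZone,
 // as reported by the compute node during startup)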
for (name, value) in &settings.params { client.write_message(BeMessage::ParameterStatus { @@ -184,6 +192,36 @@ pub(crate) fn finish_client_init( }); } + // Forward recorded latencies for probing requests + if let Some(testodrome_id) = ctx.get_testodrome_id() { + client.write_message(BeMessage::ParameterStatus { + name: "neon.testodrome_id".as_bytes(), + value: testodrome_id.as_bytes(), + }); + + let latency_measured = ctx.get_proxy_latency(); + + client.write_message(BeMessage::ParameterStatus { + name: "neon.cplane_latency".as_bytes(), + value: latency_measured.cplane.as_micros().to_string().as_bytes(), + }); + + client.write_message(BeMessage::ParameterStatus { + name: "neon.client_latency".as_bytes(), + value: latency_measured.client.as_micros().to_string().as_bytes(), + }); + + client.write_message(BeMessage::ParameterStatus { + name: "neon.compute_latency".as_bytes(), + value: latency_measured.compute.as_micros().to_string().as_bytes(), + }); + + client.write_message(BeMessage::ParameterStatus { + name: "neon.retry_latency".as_bytes(), + value: latency_measured.retry.as_micros().to_string().as_bytes(), + }); + } + client.write_message(BeMessage::BackendKeyData(cancel_key_data)); client.write_message(BeMessage::ReadyForQuery); } diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index dd89b05426..f8bff450e1 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -338,8 +338,8 @@ async fn scram_auth_mock() -> anyhow::Result<()> { let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), Scram::mock())); use rand::Rng; - use rand::distributions::Alphanumeric; - let password: String = rand::thread_rng() + use rand::distr::Alphanumeric; + let password: String = rand::rng() .sample_iter(&Alphanumeric) .take(rand::random::() as usize) .map(char::from) diff --git a/proxy/src/rate_limiter/leaky_bucket.rs b/proxy/src/rate_limiter/leaky_bucket.rs index 12b4bda0c0..9de82e922c 100644 --- a/proxy/src/rate_limiter/leaky_bucket.rs +++ b/proxy/src/rate_limiter/leaky_bucket.rs @@ -3,7 +3,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use ahash::RandomState; use clashmap::ClashMap; -use rand::{Rng, thread_rng}; +use rand::Rng; use tokio::time::Instant; use tracing::info; use utils::leaky_bucket::LeakyBucketState; @@ -61,7 +61,7 @@ impl LeakyBucketRateLimiter { self.map.len() ); let n = self.map.shards().len(); - let shard = thread_rng().gen_range(0..n); + let shard = rand::rng().random_range(0..n); self.map.shards()[shard] .write() .retain(|(_, value)| !value.bucket_is_empty(now)); diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index fd1b2af023..2b3d745a0e 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -147,7 +147,7 @@ impl RateBucketInfo { impl BucketRateLimiter { pub fn new(info: impl Into>) -> Self { - Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new()) + Self::new_with_rand_and_hasher(info, StdRng::from_os_rng(), RandomState::new()) } } @@ -216,7 +216,7 @@ impl BucketRateLimiter { let n = self.map.shards().len(); // this lock is ok as the periodic cycle of do_gc makes this very unlikely to collide // (impossible, infact, unless we have 2048 threads) - let shard = self.rand.lock_propagate_poison().gen_range(0..n); + let shard = self.rand.lock_propagate_poison().random_range(0..n); self.map.shards()[shard].write().clear(); } } diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index a6d376562b..88d5550fff 
100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -10,6 +10,7 @@ use super::connection_with_credentials_provider::ConnectionWithCredentialsProvid use crate::cache::project_info::ProjectInfoCache; use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; use crate::metrics::{Metrics, RedisErrors, RedisEventsCount}; +use crate::util::deserialize_json_string; const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20); @@ -121,15 +122,6 @@ struct InvalidateRole { role_name: RoleNameInt, } -fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result -where - T: for<'de2> serde::Deserialize<'de2>, - D: serde::Deserializer<'de>, -{ - let s = String::deserialize(deserializer)?; - serde_json::from_str(&s).map_err(::custom) -} - // https://github.com/serde-rs/serde/issues/1714 fn deserialize_unknown_topic<'de, D>(deserializer: D) -> Result<(), D::Error> where diff --git a/proxy/src/scram/countmin.rs b/proxy/src/scram/countmin.rs index 9d56c465ec..d64895f8f5 100644 --- a/proxy/src/scram/countmin.rs +++ b/proxy/src/scram/countmin.rs @@ -86,11 +86,11 @@ mod tests { for _ in 0..n { // number to insert at once - let n = rng.gen_range(1..4096); + let n = rng.random_range(1..4096); // number of insert operations - let m = rng.gen_range(1..100); + let m = rng.random_range(1..100); - let id = uuid::Builder::from_random_bytes(rng.r#gen()).into_uuid(); + let id = uuid::Builder::from_random_bytes(rng.random()).into_uuid(); ids.push((id, n, m)); // N = sum(actual) @@ -140,8 +140,8 @@ mod tests { // probably numbers are too small to truly represent the probabilities. assert_eq!(eval_precision(100, 4096.0, 0.90), 100); assert_eq!(eval_precision(1000, 4096.0, 0.90), 1000); - assert_eq!(eval_precision(100, 4096.0, 0.1), 96); - assert_eq!(eval_precision(1000, 4096.0, 0.1), 988); + assert_eq!(eval_precision(100, 4096.0, 0.1), 100); + assert_eq!(eval_precision(1000, 4096.0, 0.1), 978); } // returns memory usage in bytes, and the time complexity per insert. diff --git a/proxy/src/scram/threadpool.rs b/proxy/src/scram/threadpool.rs index 1aa402227f..ea2e29ede9 100644 --- a/proxy/src/scram/threadpool.rs +++ b/proxy/src/scram/threadpool.rs @@ -51,7 +51,7 @@ impl ThreadPool { *state = Some(ThreadRt { pool: pool.clone(), id: ThreadPoolWorkerId(worker_id.fetch_add(1, Ordering::Relaxed)), - rng: SmallRng::from_entropy(), + rng: SmallRng::from_os_rng(), // used to determine whether we should temporarily skip tasks for fairness. // 99% of estimates will overcount by no more than 4096 samples countmin: CountMinSketch::with_params( @@ -120,7 +120,7 @@ impl ThreadRt { // in which case the SKETCH_RESET_INTERVAL represents 1 second. Thus, the rates above // are in requests per second. 
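 // Worked example: when rate == 0 the expression below gives
 // ln(P) / ln(P) == 1.0, so the task always proceeds; as `rate` grows,
 // ln(P) / ln(P + rate) decays toward zero, so hot keys are skipped
 // with increasing probability for fairness.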
let probability = P.ln() / (P + rate as f64).ln(); - self.rng.gen_bool(probability) + self.rng.random_bool(probability) } } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index daa6429039..59e4b09bc9 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -8,7 +8,7 @@ use ed25519_dalek::SigningKey; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use jose_jwk::jose_b64; use postgres_client::config::SslMode; -use rand::rngs::OsRng; +use rand_core::OsRng; use rustls::pki_types::{DnsName, ServerName}; use tokio::net::{TcpStream, lookup_host}; use tokio_rustls::TlsConnector; diff --git a/proxy/src/serverless/cancel_set.rs b/proxy/src/serverless/cancel_set.rs index ba8945afc5..142dc3b3d5 100644 --- a/proxy/src/serverless/cancel_set.rs +++ b/proxy/src/serverless/cancel_set.rs @@ -6,7 +6,7 @@ use std::time::Duration; use indexmap::IndexMap; use parking_lot::Mutex; -use rand::{Rng, thread_rng}; +use rand::distr::uniform::{UniformSampler, UniformUsize}; use rustc_hash::FxHasher; use tokio::time::Instant; use tokio_util::sync::CancellationToken; @@ -39,8 +39,9 @@ impl CancelSet { } pub(crate) fn take(&self) -> Option { + let dist = UniformUsize::new_inclusive(0, usize::MAX).expect("valid bounds"); for _ in 0..4 { - if let Some(token) = self.take_raw(thread_rng().r#gen()) { + if let Some(token) = self.take_raw(dist.sample(&mut rand::rng())) { return Some(token); } tracing::trace!("failed to get cancel token"); @@ -48,7 +49,7 @@ impl CancelSet { None } - pub(crate) fn take_raw(&self, rng: usize) -> Option { + fn take_raw(&self, rng: usize) -> Option { NonZeroUsize::new(self.shards.len()) .and_then(|len| self.shards[rng % len].lock().take(rng / len)) } diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 672e59f81f..015c46f787 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -3,15 +3,14 @@ use std::pin::pin; use std::sync::{Arc, Weak}; use std::task::{Poll, ready}; -use futures::Future; use futures::future::poll_fn; -use postgres_client::AsyncMessage; +use futures::{Future, FutureExt}; use postgres_client::tls::MakeTlsConnect; use smallvec::SmallVec; use tokio::net::TcpStream; use tokio::time::Instant; use tokio_util::sync::CancellationToken; -use tracing::{Instrument, error, info, info_span, warn}; +use tracing::{error, info, info_span}; #[cfg(test)] use { super::conn_pool_lib::GlobalConnPoolOptions, @@ -85,16 +84,17 @@ pub(crate) fn poll_client( let cancel = CancellationToken::new(); let cancelled = cancel.clone().cancelled_owned(); - tokio::spawn( - async move { + tokio::spawn(async move { let _conn_gauge = conn_gauge; let mut idle_timeout = pin!(tokio::time::sleep(idle)); let mut cancelled = pin!(cancelled); poll_fn(move |cx| { + let _instrument = span.enter(); + if cancelled.as_mut().poll(cx).is_ready() { info!("connection dropped"); - return Poll::Ready(()) + return Poll::Ready(()); } match rx.has_changed() { @@ -105,7 +105,7 @@ pub(crate) fn poll_client( } Err(_) => { info!("connection dropped"); - return Poll::Ready(()) + return Poll::Ready(()); } _ => {} } @@ -123,41 +123,22 @@ pub(crate) fn poll_client( } } - loop { - let message = ready!(connection.poll_message(cx)); - - match message { - Some(Ok(AsyncMessage::Notice(notice))) => { - info!(%session_id, "notice: {}", notice); - } - Some(Ok(AsyncMessage::Notification(notif))) => { - warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received"); - } - Some(Ok(_)) 
=> { - warn!(%session_id, "unknown message"); - } - Some(Err(e)) => { - error!(%session_id, "connection error: {}", e); - break - } - None => { - info!("connection closed"); - break - } - } + match ready!(connection.poll_unpin(cx)) { + Err(e) => error!(%session_id, "connection error: {}", e), + Ok(()) => info!("connection closed"), } // remove from connection pool if let Some(pool) = pool.clone().upgrade() - && pool.write().remove_client(db_user.clone(), conn_id) { - info!("closed connection removed"); - } + && pool.write().remove_client(db_user.clone(), conn_id) + { + info!("closed connection removed"); + } Poll::Ready(()) - }).await; - - } - .instrument(span)); + }) + .await; + }); let inner = ClientInnerCommon { inner: client, aux, diff --git a/proxy/src/serverless/conn_pool_lib.rs b/proxy/src/serverless/conn_pool_lib.rs index 42a3ea17a2..ed5cc0ea03 100644 --- a/proxy/src/serverless/conn_pool_lib.rs +++ b/proxy/src/serverless/conn_pool_lib.rs @@ -428,7 +428,7 @@ where loop { interval.tick().await; - let shard = rng.gen_range(0..self.global_pool.shards().len()); + let shard = rng.random_range(0..self.global_pool.shards().len()); self.gc(shard); } } diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs index e4cbd02bfe..f63d84d66b 100644 --- a/proxy/src/serverless/local_conn_pool.rs +++ b/proxy/src/serverless/local_conn_pool.rs @@ -19,18 +19,17 @@ use std::time::Duration; use base64::Engine as _; use base64::prelude::BASE64_URL_SAFE_NO_PAD; use ed25519_dalek::{Signature, Signer, SigningKey}; -use futures::Future; use futures::future::poll_fn; +use futures::{Future, FutureExt}; use indexmap::IndexMap; use jose_jwk::jose_b64::base64ct::{Base64UrlUnpadded, Encoding}; use parking_lot::RwLock; -use postgres_client::AsyncMessage; use postgres_client::tls::NoTlsStream; use serde_json::value::RawValue; use tokio::net::TcpStream; use tokio::time::Instant; use tokio_util::sync::CancellationToken; -use tracing::{Instrument, debug, error, info, info_span, warn}; +use tracing::{debug, error, info, info_span}; use super::backend::HttpConnError; use super::conn_pool_lib::{ @@ -186,16 +185,17 @@ pub(crate) fn poll_client( let cancel = CancellationToken::new(); let cancelled = cancel.clone().cancelled_owned(); - tokio::spawn( - async move { + tokio::spawn(async move { let _conn_gauge = conn_gauge; let mut idle_timeout = pin!(tokio::time::sleep(idle)); let mut cancelled = pin!(cancelled); poll_fn(move |cx| { + let _instrument = span.enter(); + if cancelled.as_mut().poll(cx).is_ready() { info!("connection dropped"); - return Poll::Ready(()) + return Poll::Ready(()); } match rx.has_changed() { @@ -206,7 +206,7 @@ pub(crate) fn poll_client( } Err(_) => { info!("connection dropped"); - return Poll::Ready(()) + return Poll::Ready(()); } _ => {} } @@ -218,47 +218,35 @@ pub(crate) fn poll_client( if let Some(pool) = pool.clone().upgrade() { // remove client from pool - should close the connection if it's idle. 
// does nothing if the client is currently checked-out and in-use - if pool.global_pool.write().remove_client(db_user.clone(), conn_id) { + if pool + .global_pool + .write() + .remove_client(db_user.clone(), conn_id) + { info!("idle connection removed"); } } } - loop { - let message = ready!(connection.poll_message(cx)); - - match message { - Some(Ok(AsyncMessage::Notice(notice))) => { - info!(%session_id, "notice: {}", notice); - } - Some(Ok(AsyncMessage::Notification(notif))) => { - warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received"); - } - Some(Ok(_)) => { - warn!(%session_id, "unknown message"); - } - Some(Err(e)) => { - error!(%session_id, "connection error: {}", e); - break - } - None => { - info!("connection closed"); - break - } - } + match ready!(connection.poll_unpin(cx)) { + Err(e) => error!(%session_id, "connection error: {}", e), + Ok(()) => info!("connection closed"), } // remove from connection pool if let Some(pool) = pool.clone().upgrade() - && pool.global_pool.write().remove_client(db_user.clone(), conn_id) { - info!("closed connection removed"); - } + && pool + .global_pool + .write() + .remove_client(db_user.clone(), conn_id) + { + info!("closed connection removed"); + } Poll::Ready(()) - }).await; - - } - .instrument(span)); + }) + .await; + }); let inner = ClientInnerCommon { inner: client, diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 5b7289c53d..13f9ee2782 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -11,6 +11,8 @@ mod http_conn_pool; mod http_util; mod json; mod local_conn_pool; +#[cfg(feature = "rest_broker")] +pub mod rest; mod sql_over_http; mod websocket; @@ -75,7 +77,7 @@ pub async fn task_main( { let conn_pool = Arc::clone(&conn_pool); tokio::spawn(async move { - conn_pool.gc_worker(StdRng::from_entropy()).await; + conn_pool.gc_worker(StdRng::from_os_rng()).await; }); } @@ -95,7 +97,7 @@ pub async fn task_main( { let http_conn_pool = Arc::clone(&http_conn_pool); tokio::spawn(async move { - http_conn_pool.gc_worker(StdRng::from_entropy()).await; + http_conn_pool.gc_worker(StdRng::from_os_rng()).await; }); } @@ -487,6 +489,42 @@ async fn request_handler( .body(Empty::new().map_err(|x| match x {}).boxed()) .map_err(|e| ApiError::InternalServerError(e.into())) } else { - json_response(StatusCode::BAD_REQUEST, "query is not supported") + #[cfg(feature = "rest_broker")] + { + if config.rest_config.is_rest_broker + // we are testing for the path to be /database_name/rest/... 
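+            // e.g. a request for "/neondb/rest/v1/items" splits into
+            // ["", "neondb", "rest", "v1", "items"], so nth(2) is "rest"
+            // (the path here is illustrative)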
+ && request + .uri() + .path() + .split('/') + .nth(2) + .is_some_and(|part| part.starts_with("rest")) + { + let ctx = + RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http); + let span = ctx.span(); + + let testodrome_id = request + .headers() + .get("X-Neon-Query-ID") + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + + if let Some(query_id) = testodrome_id { + info!(parent: &span, "testodrome query ID: {query_id}"); + ctx.set_testodrome_id(query_id.into()); + } + + rest::handle(config, ctx, request, backend, http_cancellation_token) + .instrument(span) + .await + } else { + json_response(StatusCode::BAD_REQUEST, "query is not supported") + } + } + #[cfg(not(feature = "rest_broker"))] + { + json_response(StatusCode::BAD_REQUEST, "query is not supported") + } } } diff --git a/proxy/src/serverless/rest.rs b/proxy/src/serverless/rest.rs new file mode 100644 index 0000000000..173c2629f7 --- /dev/null +++ b/proxy/src/serverless/rest.rs @@ -0,0 +1,1165 @@ +use std::borrow::Cow; +use std::collections::HashMap; +use std::sync::Arc; + +use bytes::Bytes; +use http::Method; +use http::header::{AUTHORIZATION, CONTENT_TYPE, HOST}; +use http_body_util::combinators::BoxBody; +use http_body_util::{BodyExt, Full}; +use http_utils::error::ApiError; +use hyper::body::Incoming; +use hyper::http::{HeaderName, HeaderValue}; +use hyper::{Request, Response, StatusCode}; +use indexmap::IndexMap; +use ouroboros::self_referencing; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Deserializer}; +use serde_json::Value as JsonValue; +use serde_json::value::RawValue; +use subzero_core::api::ContentType::{ApplicationJSON, Other, SingularJSON, TextCSV}; +use subzero_core::api::QueryNode::{Delete, FunctionCall, Insert, Update}; +use subzero_core::api::Resolution::{IgnoreDuplicates, MergeDuplicates}; +use subzero_core::api::{ApiResponse, ListVal, Payload, Preferences, Representation, SingleVal}; +use subzero_core::config::{db_allowed_select_functions, db_schemas, role_claim_key}; +use subzero_core::dynamic_statement::{JoinIterator, param, sql}; +use subzero_core::error::Error::{ + self as SubzeroCoreError, ContentTypeError, GucHeadersError, GucStatusError, InternalError, + JsonDeserialize, JwtTokenInvalid, NotFound, +}; +use subzero_core::error::pg_error_to_status_code; +use subzero_core::formatter::Param::{LV, PL, SV, Str, StrOwned}; +use subzero_core::formatter::postgresql::{fmt_main_query, generate}; +use subzero_core::formatter::{Param, Snippet, SqlParam}; +use subzero_core::parser::postgrest::parse; +use subzero_core::permissions::{check_safe_functions, replace_select_star}; +use subzero_core::schema::{ + DbSchema, POSTGRESQL_INTROSPECTION_SQL, get_postgresql_configuration_query, +}; +use subzero_core::{content_range_header, content_range_status}; +use tokio_util::sync::CancellationToken; +use tracing::{error, info}; +use typed_json::json; +use url::form_urlencoded; + +use super::backend::{HttpConnError, LocalProxyConnError, PoolingBackend}; +use super::conn_pool::AuthData; +use super::conn_pool_lib::ConnInfo; +use super::error::{ConnInfoError, Credentials, HttpCodeError, ReadPayloadError}; +use super::http_conn_pool::{self, Send}; +use super::http_util::{ + ALLOW_POOL, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_ISOLATION_LEVEL, TXN_READ_ONLY, + get_conn_info, json_response, uuid_to_header_value, +}; +use super::json::JsonConversionError; +use crate::auth::backend::ComputeCredentialKeys; +use crate::cache::{Cached, TimedLru}; +use 
crate::config::ProxyConfig;
+use crate::context::RequestContext;
+use crate::error::{ErrorKind, ReportableError, UserFacingError};
+use crate::http::read_body_with_limit;
+use crate::metrics::Metrics;
+use crate::serverless::sql_over_http::HEADER_VALUE_TRUE;
+use crate::types::EndpointCacheKey;
+use crate::util::deserialize_json_string;
+
+static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#;
+const INTROSPECTION_SQL: &str = POSTGRESQL_INTROSPECTION_SQL;
+
+// A wrapper around the DbSchema that allows for self-referencing
+#[self_referencing]
+pub struct DbSchemaOwned {
+    schema_string: String,
+    #[covariant]
+    #[borrows(schema_string)]
+    schema: DbSchema<'this>,
+}
+
+impl<'de> Deserialize<'de> for DbSchemaOwned {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let s = String::deserialize(deserializer)?;
+        DbSchemaOwned::try_new(s, |s| serde_json::from_str(s))
+            .map_err(<D::Error as serde::de::Error>::custom)
+    }
+}
+
+fn split_comma_separated(s: &str) -> Vec<String> {
+    s.split(',').map(|s| s.trim().to_string()).collect()
+}
+
+fn deserialize_comma_separated<'de, D>(deserializer: D) -> Result<Vec<String>, D::Error>
+where
+    D: Deserializer<'de>,
+{
+    let s = String::deserialize(deserializer)?;
+    Ok(split_comma_separated(&s))
+}
+
+fn deserialize_comma_separated_option<'de, D>(
+    deserializer: D,
+) -> Result<Option<Vec<String>>, D::Error>
+where
+    D: Deserializer<'de>,
+{
+    let opt = Option::<String>::deserialize(deserializer)?;
+    if let Some(s) = &opt {
+        let trimmed = s.trim();
+        if trimmed.is_empty() {
+            return Ok(None);
+        }
+        return Ok(Some(split_comma_separated(trimmed)));
+    }
+    Ok(None)
+}
+
+// The ApiConfig is the configuration for the API per endpoint
+// The configuration is read from the database and cached in the DbSchemaCache
+#[derive(Deserialize, Debug)]
+pub struct ApiConfig {
+    #[serde(
+        default = "db_schemas",
+        deserialize_with = "deserialize_comma_separated"
+    )]
+    pub db_schemas: Vec<String>,
+    pub db_anon_role: Option<String>,
+    pub db_max_rows: Option<String>,
+    #[serde(default = "db_allowed_select_functions")]
+    pub db_allowed_select_functions: Vec<String>,
+    // #[serde(deserialize_with = "to_tuple", default)]
+    // pub db_pre_request: Option<(String, String)>,
+    #[allow(dead_code)]
+    #[serde(default = "role_claim_key")]
+    pub role_claim_key: String,
+    #[serde(default, deserialize_with = "deserialize_comma_separated_option")]
+    pub db_extra_search_path: Option<Vec<String>>,
+}
+
+// The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint
+pub(crate) type DbSchemaCache = TimedLru<EndpointCacheKey, Arc<(ApiConfig, DbSchemaOwned)>>;
+impl DbSchemaCache {
+    pub async fn get_cached_or_remote(
+        &self,
+        endpoint_id: &EndpointCacheKey,
+        auth_header: &HeaderValue,
+        connection_string: &str,
+        client: &mut http_conn_pool::Client<Send>,
+        ctx: &RequestContext,
+        config: &'static ProxyConfig,
+    ) -> Result<Arc<(ApiConfig, DbSchemaOwned)>, RestError> {
+        match self.get_with_created_at(endpoint_id) {
+            Some(Cached { value: (v, _), ..
}) => Ok(v),
+            None => {
+                info!("db_schema cache miss for endpoint: {:?}", endpoint_id);
+                let remote_value = self
+                    .get_remote(auth_header, connection_string, client, ctx, config)
+                    .await;
+                let (api_config, schema_owned) = match remote_value {
+                    Ok((api_config, schema_owned)) => (api_config, schema_owned),
+                    Err(e @ RestError::SchemaTooLarge) => {
+                        // for the case where the schema is too large, we cache an empty dummy value
+                        // all the other requests will fail without triggering the introspection query
+                        let schema_owned = serde_json::from_str::<DbSchemaOwned>(EMPTY_JSON_SCHEMA)
+                            .map_err(|e| JsonDeserialize { source: e })?;
+
+                        let api_config = ApiConfig {
+                            db_schemas: vec![],
+                            db_anon_role: None,
+                            db_max_rows: None,
+                            db_allowed_select_functions: vec![],
+                            role_claim_key: String::new(),
+                            db_extra_search_path: None,
+                        };
+                        let value = Arc::new((api_config, schema_owned));
+                        self.insert(endpoint_id.clone(), value);
+                        return Err(e);
+                    }
+                    Err(e) => {
+                        return Err(e);
+                    }
+                };
+                let value = Arc::new((api_config, schema_owned));
+                self.insert(endpoint_id.clone(), value.clone());
+                Ok(value)
+            }
+        }
+    }
+    pub async fn get_remote(
+        &self,
+        auth_header: &HeaderValue,
+        connection_string: &str,
+        client: &mut http_conn_pool::Client<Send>,
+        ctx: &RequestContext,
+        config: &'static ProxyConfig,
+    ) -> Result<(ApiConfig, DbSchemaOwned), RestError> {
+        #[derive(Deserialize)]
+        struct SingleRow<Row> {
+            rows: [Row; 1],
+        }
+
+        #[derive(Deserialize)]
+        struct ConfigRow {
+            #[serde(deserialize_with = "deserialize_json_string")]
+            config: ApiConfig,
+        }
+
+        #[derive(Deserialize)]
+        struct SchemaRow {
+            json_schema: DbSchemaOwned,
+        }
+
+        let headers = vec![
+            (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())),
+            (
+                &CONN_STRING,
+                HeaderValue::from_str(connection_string).expect(
+                    "connection string came from a header, so it must be a valid headervalue",
+                ),
+            ),
+            (&AUTHORIZATION, auth_header.clone()),
+            (&RAW_TEXT_OUTPUT, HEADER_VALUE_TRUE),
+        ];
+
+        let query = get_postgresql_configuration_query(Some("pgrst.pre_config"));
+        let SingleRow {
+            rows: [ConfigRow { config: api_config }],
+        } = make_local_proxy_request(
+            client,
+            headers.iter().cloned(),
+            QueryData {
+                query: Cow::Owned(query),
+                params: vec![],
+            },
+            config.rest_config.max_schema_size,
+        )
+        .await
+        .map_err(|e| match e {
+            RestError::ReadPayload(ReadPayloadError::BodyTooLarge { .. }) => {
+                RestError::SchemaTooLarge
+            }
+            e => e,
+        })?;
+
+        // now that we have the api_config let's run the second INTROSPECTION_SQL query
+        let SingleRow {
+            rows: [SchemaRow { json_schema }],
+        } = make_local_proxy_request(
+            client,
+            headers,
+            QueryData {
+                query: INTROSPECTION_SQL.into(),
+                params: vec![
+                    serde_json::to_value(&api_config.db_schemas)
+                        .expect("Vec<String> is always valid to encode as JSON"),
+                    JsonValue::Bool(false), // include_roles_with_login
+                    JsonValue::Bool(false), // use_internal_permissions
+                ],
+            },
+            config.rest_config.max_schema_size,
+        )
+        .await
+        .map_err(|e| match e {
+            RestError::ReadPayload(ReadPayloadError::BodyTooLarge { ..
}) => {
+                RestError::SchemaTooLarge
+            }
+            e => e,
+        })?;
+
+        Ok((api_config, json_schema))
+    }
+}
+
+// A type to represent PostgreSQL errors
+// we use our own type (instead of postgres_client::Error) because we get the error from the json response
+#[derive(Debug, thiserror::Error, Deserialize)]
+pub(crate) struct PostgresError {
+    pub code: String,
+    pub message: String,
+    pub detail: Option<String>,
+    pub hint: Option<String>,
+}
+impl HttpCodeError for PostgresError {
+    fn get_http_status_code(&self) -> StatusCode {
+        let status = pg_error_to_status_code(&self.code, true);
+        StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
+    }
+}
+impl ReportableError for PostgresError {
+    fn get_error_kind(&self) -> ErrorKind {
+        ErrorKind::User
+    }
+}
+impl UserFacingError for PostgresError {
+    fn to_string_client(&self) -> String {
+        if self.code.starts_with("PT") {
+            "Postgres error".to_string()
+        } else {
+            self.message.clone()
+        }
+    }
+}
+impl std::fmt::Display for PostgresError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}", self.message)
+    }
+}
+
+// A type to represent errors that can occur in the rest broker
+#[derive(Debug, thiserror::Error)]
+pub(crate) enum RestError {
+    #[error(transparent)]
+    ReadPayload(#[from] ReadPayloadError),
+    #[error(transparent)]
+    ConnectCompute(#[from] HttpConnError),
+    #[error(transparent)]
+    ConnInfo(#[from] ConnInfoError),
+    #[error(transparent)]
+    Postgres(#[from] PostgresError),
+    #[error(transparent)]
+    JsonConversion(#[from] JsonConversionError),
+    #[error(transparent)]
+    SubzeroCore(#[from] SubzeroCoreError),
+    #[error("schema is too large")]
+    SchemaTooLarge,
+}
+impl ReportableError for RestError {
+    fn get_error_kind(&self) -> ErrorKind {
+        match self {
+            RestError::ReadPayload(e) => e.get_error_kind(),
+            RestError::ConnectCompute(e) => e.get_error_kind(),
+            RestError::ConnInfo(e) => e.get_error_kind(),
+            RestError::Postgres(_) => ErrorKind::Postgres,
+            RestError::JsonConversion(_) => ErrorKind::Postgres,
+            RestError::SubzeroCore(_) => ErrorKind::User,
+            RestError::SchemaTooLarge => ErrorKind::User,
+        }
+    }
+}
+impl UserFacingError for RestError {
+    fn to_string_client(&self) -> String {
+        match self {
+            RestError::ReadPayload(p) => p.to_string(),
+            RestError::ConnectCompute(c) => c.to_string_client(),
+            RestError::ConnInfo(c) => c.to_string_client(),
+            RestError::SchemaTooLarge => self.to_string(),
+            RestError::Postgres(p) => p.to_string_client(),
+            RestError::JsonConversion(_) => "could not parse postgres response".to_string(),
+            RestError::SubzeroCore(s) => {
+                // TODO: this is a hack to get the message from the json body
+                let json = s.json_body();
+                let default_message = "Unknown error".to_string();
+
+                json.get("message")
+                    .map_or(default_message.clone(), |m| match m {
+                        JsonValue::String(s) => s.clone(),
+                        _ => default_message,
+                    })
+            }
+        }
+    }
+}
+impl HttpCodeError for RestError {
+    fn get_http_status_code(&self) -> StatusCode {
+        match self {
+            RestError::ReadPayload(e) => e.get_http_status_code(),
+            RestError::ConnectCompute(h) => match h.get_error_kind() {
+                ErrorKind::User => StatusCode::BAD_REQUEST,
+                _ => StatusCode::INTERNAL_SERVER_ERROR,
+            },
+            RestError::ConnInfo(_) => StatusCode::BAD_REQUEST,
+            RestError::Postgres(e) => e.get_http_status_code(),
+            RestError::JsonConversion(_) => StatusCode::INTERNAL_SERVER_ERROR,
+            RestError::SchemaTooLarge => StatusCode::INTERNAL_SERVER_ERROR,
+            RestError::SubzeroCore(e) => {
+                let status = e.status_code();
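+                // subzero-core reports an HTTP-style status for its errors;
+                // anything that is not a valid status code falls back to 500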
StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
+            }
+        }
+    }
+}
+
+// Helper functions for the rest broker
+
+fn fmt_env_query<'a>(env: &'a HashMap<&'a str, &'a str>) -> Snippet<'a> {
+    "select "
+        + if env.is_empty() {
+            sql("null")
+        } else {
+            env.iter()
+                .map(|(k, v)| {
+                    "set_config(" + param(k as &SqlParam) + ", " + param(v as &SqlParam) + ", true)"
+                })
+                .join(",")
+        }
+}
+
+// TODO: see about removing the need for cloning the values (inner things are &Cow already)
+fn to_sql_param(p: &Param) -> JsonValue {
+    match p {
+        SV(SingleVal(v, ..)) => JsonValue::String(v.to_string()),
+        Str(v) => JsonValue::String((*v).to_string()),
+        StrOwned(v) => JsonValue::String((*v).clone()),
+        PL(Payload(v, ..)) => JsonValue::String(v.clone().into_owned()),
+        LV(ListVal(v, ..)) => {
+            if v.is_empty() {
+                JsonValue::String(r"{}".to_string())
+            } else {
+                JsonValue::String(format!(
+                    "{{\"{}\"}}",
+                    v.iter()
+                        .map(|e| e.replace('\\', "\\\\").replace('\"', "\\\""))
+                        .collect::<Vec<_>>()
+                        .join("\",\"")
+                ))
+            }
+        }
+    }
+}
+
+#[derive(serde::Serialize)]
+struct QueryData<'a> {
+    query: Cow<'a, str>,
+    params: Vec<JsonValue>,
+}
+
+#[derive(serde::Serialize)]
+struct BatchQueryData<'a> {
+    queries: Vec<QueryData<'a>>,
+}
+
+async fn make_local_proxy_request<S: DeserializeOwned>(
+    client: &mut http_conn_pool::Client<Send>,
+    headers: impl IntoIterator<Item = (&'static HeaderName, HeaderValue)>,
+    body: QueryData<'_>,
+    max_len: usize,
+) -> Result<S, RestError> {
+    let body_string = serde_json::to_string(&body)
+        .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?;
+
+    let response = make_raw_local_proxy_request(client, headers, body_string).await?;
+
+    let response_status = response.status();
+
+    if response_status != StatusCode::OK {
+        return Err(RestError::SubzeroCore(InternalError {
+            message: "Failed to get endpoint schema".to_string(),
+        }));
+    }
+
+    // Capture the response body
+    let response_body = crate::http::read_body_with_limit(response.into_body(), max_len)
+        .await
+        .map_err(ReadPayloadError::from)?;
+
+    // Parse the JSON response
+    let response_json: S = serde_json::from_slice(&response_body)
+        .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
+
+    Ok(response_json)
+}
+
+async fn make_raw_local_proxy_request(
+    client: &mut http_conn_pool::Client<Send>,
+    headers: impl IntoIterator<Item = (&'static HeaderName, HeaderValue)>,
+    body: String,
+) -> Result<Response<Incoming>, RestError> {
+    let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
+    let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
+    let req_headers = req.headers_mut().expect("failed to get headers");
+    // Add all provided headers to the request
+    for (header_name, header_value) in headers {
+        req_headers.insert(header_name, header_value.clone());
+    }
+
+    let body_boxed = Full::new(Bytes::from(body))
+        .map_err(|never| match never {}) // Convert Infallible to hyper::Error
+        .boxed();
+
+    let req = req.body(body_boxed).map_err(|_| {
+        RestError::SubzeroCore(InternalError {
+            message: "Failed to build request".to_string(),
+        })
+    })?;
+
+    // Send the request to the local proxy
+    client
+        .inner
+        .inner
+        .send_request(req)
+        .await
+        .map_err(LocalProxyConnError::from)
+        .map_err(HttpConnError::from)
+        .map_err(RestError::from)
+}
+
+pub(crate) async fn handle(
+    config: &'static ProxyConfig,
+    ctx: RequestContext,
+    request: Request<Incoming>,
+    backend: Arc<PoolingBackend>,
+    cancel: CancellationToken,
+) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
+    let result = handle_inner(cancel, config, &ctx, request, backend).await;
+
+    let mut response = match result {
+        Ok(r) => {
+            ctx.set_success();
+
+            // Handle error responses from local_proxy here
+            if r.status().is_server_error() {
+                let status = r.status();
+
+                let body_bytes = r
+                    .collect()
+                    .await
+                    .map_err(|e| {
+                        ApiError::InternalServerError(anyhow::Error::msg(format!(
+                            "could not collect http body: {e}"
+                        )))
+                    })?
+                    .to_bytes();
+
+                if let Ok(mut json_map) =
+                    serde_json::from_slice::<IndexMap<String, Box<RawValue>>>(&body_bytes)
+                {
+                    let message = json_map.get("message");
+                    if let Some(message) = message {
+                        let msg: String = match serde_json::from_str(message.get()) {
+                            Ok(msg) => msg,
+                            Err(_) => {
+                                "Unable to parse the response message from server".to_string()
+                            }
+                        };
+
+                        error!("Error response from local_proxy: {status} {msg}");
+
+                        json_map.retain(|key, _| !key.starts_with("neon:")); // remove all the neon-related keys
+
+                        let resp_json = serde_json::to_string(&json_map)
+                            .unwrap_or("failed to serialize the response message".to_string());
+
+                        return json_response(status, resp_json);
+                    }
+                }
+
+                error!("Unable to parse the response message from local_proxy");
+                return json_response(
+                    status,
+                    json!({ "message": "Unable to parse the response message from server".to_string() }),
+                );
+            }
+            r
+        }
+        Err(e @ RestError::SubzeroCore(_)) => {
+            let error_kind = e.get_error_kind();
+            ctx.set_error_kind(error_kind);
+
+            tracing::info!(
+                kind=error_kind.to_metric_label(),
+                error=%e,
+                msg="subzero core error",
+                "forwarding error to user"
+            );
+
+            let RestError::SubzeroCore(subzero_err) = e else {
+                panic!("expected subzero core error")
+            };
+
+            let json_body = subzero_err.json_body();
+            let status_code = StatusCode::from_u16(subzero_err.status_code())
+                .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
+
+            json_response(status_code, json_body)?
+        }
+        Err(e) => {
+            let error_kind = e.get_error_kind();
+            ctx.set_error_kind(error_kind);
+
+            let message = e.to_string_client();
+            let status_code = e.get_http_status_code();
+
+            tracing::info!(
+                kind=error_kind.to_metric_label(),
+                error=%e,
+                msg=message,
+                "forwarding error to user"
+            );
+
+            let (code, detail, hint) = match e {
+                RestError::Postgres(e) => (
+                    if e.code.starts_with("PT") {
+                        None
+                    } else {
+                        Some(e.code)
+                    },
+                    e.detail,
+                    e.hint,
+                ),
+                _ => (None, None, None),
+            };
+
+            json_response(
+                status_code,
+                json!({
+                    "message": message,
+                    "code": code,
+                    "detail": detail,
+                    "hint": hint,
+                }),
+            )?
+        }
+    };
+
+    response
+        .headers_mut()
+        .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
+    Ok(response)
+}
+
+async fn handle_inner(
+    _cancel: CancellationToken,
+    config: &'static ProxyConfig,
+    ctx: &RequestContext,
+    request: Request<Incoming>,
+    backend: Arc<PoolingBackend>,
+) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
+    let _request_gauge = Metrics::get()
+        .proxy
+        .connection_requests
+        .guard(ctx.protocol());
+    info!(
+        protocol = %ctx.protocol(),
+        "handling interactive connection from client"
+    );
+
+    // Read host from Host, then URI host as fallback
+    // TODO: will this be a problem if behind a load balancer?
+    // TODO: can we use the x-forwarded-host header?
+    let host = request
+        .headers()
+        .get(HOST)
+        .and_then(|v| v.to_str().ok())
+        .unwrap_or_else(|| request.uri().host().unwrap_or(""));
+
+    // a valid path is /database/rest/v1/... so splitting should be ["", "database", "rest", "v1", ...]
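+    // e.g. "/neondb/rest/v1/items" -> nth(1) == "neondb" (illustrative path)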
+
+#[derive(serde::Serialize)]
+struct QueryData<'a> {
+    query: Cow<'a, str>,
+    params: Vec<JsonValue>,
+}
+
+#[derive(serde::Serialize)]
+struct BatchQueryData<'a> {
+    queries: Vec<QueryData<'a>>,
+}
+
+async fn make_local_proxy_request<S: serde::de::DeserializeOwned>(
+    client: &mut http_conn_pool::Client,
+    headers: impl IntoIterator<Item = (&'static HeaderName, HeaderValue)>,
+    body: QueryData<'_>,
+    max_len: usize,
+) -> Result<S, RestError> {
+    let body_string = serde_json::to_string(&body)
+        .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?;
+
+    let response = make_raw_local_proxy_request(client, headers, body_string).await?;
+
+    let response_status = response.status();
+
+    if response_status != StatusCode::OK {
+        return Err(RestError::SubzeroCore(InternalError {
+            message: "Failed to get endpoint schema".to_string(),
+        }));
+    }
+
+    // Capture the response body
+    let response_body = crate::http::read_body_with_limit(response.into_body(), max_len)
+        .await
+        .map_err(ReadPayloadError::from)?;
+
+    // Parse the JSON response
+    let response_json: S = serde_json::from_slice(&response_body)
+        .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
+
+    Ok(response_json)
+}
+
+async fn make_raw_local_proxy_request(
+    client: &mut http_conn_pool::Client,
+    headers: impl IntoIterator<Item = (&'static HeaderName, HeaderValue)>,
+    body: String,
+) -> Result<Response<Incoming>, RestError> {
+    let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql");
+    let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri);
+    let req_headers = req.headers_mut().expect("failed to get headers");
+    // Add all provided headers to the request
+    for (header_name, header_value) in headers {
+        req_headers.insert(header_name, header_value.clone());
+    }
+
+    let body_boxed = Full::new(Bytes::from(body))
+        .map_err(|never| match never {}) // Convert Infallible to hyper::Error
+        .boxed();
+
+    let req = req.body(body_boxed).map_err(|_| {
+        RestError::SubzeroCore(InternalError {
+            message: "Failed to build request".to_string(),
+        })
+    })?;
+
+    // Send the request to the local proxy
+    client
+        .inner
+        .inner
+        .send_request(req)
+        .await
+        .map_err(LocalProxyConnError::from)
+        .map_err(HttpConnError::from)
+        .map_err(RestError::from)
+}
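// Editor's sketch (not part of the patch): a tiny test pinning down the wire
// shape that the structs above serialize to; the local proxy expects the
// {"queries":[{"query":...,"params":[...]}, ...]} layout produced here:
#[cfg(test)]
mod wire_shape_sketch {
    use super::*;

    #[test]
    fn batch_body_shape() {
        let body = BatchQueryData {
            queries: vec![QueryData {
                query: "select 1".into(),
                params: vec![],
            }],
        };
        assert_eq!(
            serde_json::to_string(&body).unwrap(),
            r#"{"queries":[{"query":"select 1","params":[]}]}"#
        );
    }
}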
+
+pub(crate) async fn handle(
+    config: &'static ProxyConfig,
+    ctx: RequestContext,
+    request: Request<Incoming>,
+    backend: Arc<PoolingBackend>,
+    cancel: CancellationToken,
+) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ApiError> {
+    let result = handle_inner(cancel, config, &ctx, request, backend).await;
+
+    let mut response = match result {
+        Ok(r) => {
+            ctx.set_success();
+
+            // Handling the error response from local proxy here
+            if r.status().is_server_error() {
+                let status = r.status();
+
+                let body_bytes = r
+                    .collect()
+                    .await
+                    .map_err(|e| {
+                        ApiError::InternalServerError(anyhow::Error::msg(format!(
+                            "could not collect http body: {e}"
+                        )))
+                    })?
+                    .to_bytes();
+
+                if let Ok(mut json_map) =
+                    serde_json::from_slice::<HashMap<String, Box<RawValue>>>(&body_bytes)
+                {
+                    let message = json_map.get("message");
+                    if let Some(message) = message {
+                        let msg: String = match serde_json::from_str(message.get()) {
+                            Ok(msg) => msg,
+                            Err(_) => {
+                                "Unable to parse the response message from server".to_string()
+                            }
+                        };
+
+                        error!("Error response from local_proxy: {status} {msg}");
+
+                        json_map.retain(|key, _| !key.starts_with("neon:")); // remove all the neon-related keys
+
+                        let resp_json = serde_json::to_string(&json_map)
+                            .unwrap_or("failed to serialize the response message".to_string());
+
+                        return json_response(status, resp_json);
+                    }
+                }
+
+                error!("Unable to parse the response message from local_proxy");
+                return json_response(
+                    status,
+                    json!({ "message": "Unable to parse the response message from server".to_string() }),
+                );
+            }
+            r
+        }
+        Err(e @ RestError::SubzeroCore(_)) => {
+            let error_kind = e.get_error_kind();
+            ctx.set_error_kind(error_kind);
+
+            tracing::info!(
+                kind=error_kind.to_metric_label(),
+                error=%e,
+                msg="subzero core error",
+                "forwarding error to user"
+            );
+
+            let RestError::SubzeroCore(subzero_err) = e else {
+                panic!("expected subzero core error")
+            };
+
+            let json_body = subzero_err.json_body();
+            let status_code = StatusCode::from_u16(subzero_err.status_code())
+                .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
+
+            json_response(status_code, json_body)?
+        }
+        Err(e) => {
+            let error_kind = e.get_error_kind();
+            ctx.set_error_kind(error_kind);
+
+            let message = e.to_string_client();
+            let status_code = e.get_http_status_code();
+
+            tracing::info!(
+                kind=error_kind.to_metric_label(),
+                error=%e,
+                msg=message,
+                "forwarding error to user"
+            );
+
+            let (code, detail, hint) = match e {
+                RestError::Postgres(e) => (
+                    if e.code.starts_with("PT") {
+                        None
+                    } else {
+                        Some(e.code)
+                    },
+                    e.detail,
+                    e.hint,
+                ),
+                _ => (None, None, None),
+            };
+
+            json_response(
+                status_code,
+                json!({
+                    "message": message,
+                    "code": code,
+                    "detail": detail,
+                    "hint": hint,
+                }),
+            )?
+        }
+    };
+
+    response
+        .headers_mut()
+        .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
+    Ok(response)
+}
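// Editor's note: for a Postgres error the fallback arm above produces a body
// of this shape (values invented; "PT"-prefixed codes are suppressed by the
// match above):
#[allow(dead_code)]
fn _example_error_body_sketch() -> serde_json::Value {
    json!({
        "message": "duplicate key value violates unique constraint \"items_pkey\"",
        "code": "23505",
        "detail": "Key (id)=(1) already exists.",
        "hint": null,
    })
}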
+
+async fn handle_inner(
+    _cancel: CancellationToken,
+    config: &'static ProxyConfig,
+    ctx: &RequestContext,
+    request: Request<Incoming>,
+    backend: Arc<PoolingBackend>,
+) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
+    let _request_gauge = Metrics::get()
+        .proxy
+        .connection_requests
+        .guard(ctx.protocol());
+    info!(
+        protocol = %ctx.protocol(),
+        "handling interactive connection from client"
+    );
+
+    // Read host from Host, then URI host as fallback
+    // TODO: will this be a problem if behind a load balancer?
+    // TODO: can we use the x-forwarded-host header?
+    let host = request
+        .headers()
+        .get(HOST)
+        .and_then(|v| v.to_str().ok())
+        .unwrap_or_else(|| request.uri().host().unwrap_or(""));
+
+    // a valid path is /database/rest/v1/... so splitting should be ["", "database", "rest", "v1", ...]
+    let database_name = request
+        .uri()
+        .path()
+        .split('/')
+        .nth(1)
+        .ok_or(RestError::SubzeroCore(NotFound {
+            target: request.uri().path().to_string(),
+        }))?;
+
+    // we always use the authenticator role to connect to the database
+    let authenticator_role = "authenticator";
+
+    // Strip the hostname prefix from the host to get the database hostname
+    let database_host = host.replace(&config.rest_config.hostname_prefix, "");
+
+    let connection_string =
+        format!("postgresql://{authenticator_role}@{database_host}/{database_name}");
+
+    let conn_info = get_conn_info(
+        &config.authentication_config,
+        ctx,
+        Some(&connection_string),
+        request.headers(),
+    )?;
+    info!(
+        user = conn_info.conn_info.user_info.user.as_str(),
+        "credentials"
+    );
+
+    match conn_info.auth {
+        AuthData::Jwt(jwt) => {
+            let api_prefix = format!("/{database_name}/rest/v1/");
+            handle_rest_inner(
+                config,
+                ctx,
+                &api_prefix,
+                request,
+                &connection_string,
+                conn_info.conn_info,
+                jwt,
+                backend,
+            )
+            .await
+        }
+        AuthData::Password(_) => Err(RestError::ConnInfo(ConnInfoError::MissingCredentials(
+            Credentials::BearerJwt,
+        ))),
+    }
+}
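// Editor's note (invented values): assuming rest_config.hostname_prefix were
// "apirest.", a request to
//   https://apirest.ep-example-123456.us-east-2.aws.neon.tech/neondb/rest/v1/items
// would resolve above to database_name = "neondb" and the connection string:
const _EXAMPLE_CONN_STRING: &str =
    "postgresql://authenticator@ep-example-123456.us-east-2.aws.neon.tech/neondb";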
+
+#[allow(clippy::too_many_arguments)]
+async fn handle_rest_inner(
+    config: &'static ProxyConfig,
+    ctx: &RequestContext,
+    api_prefix: &str,
+    request: Request<Incoming>,
+    connection_string: &str,
+    conn_info: ConnInfo,
+    jwt: String,
+    backend: Arc<PoolingBackend>,
+) -> Result<Response<BoxBody<Bytes, hyper::Error>>, RestError> {
+    // validate the jwt token
+    let jwt_parsed = backend
+        .authenticate_with_jwt(ctx, &conn_info.user_info, jwt)
+        .await
+        .map_err(HttpConnError::from)?;
+
+    let db_schema_cache =
+        config
+            .rest_config
+            .db_schema_cache
+            .as_ref()
+            .ok_or(RestError::SubzeroCore(InternalError {
+                message: "DB schema cache is not configured".to_string(),
+            }))?;
+
+    let endpoint_cache_key = conn_info
+        .endpoint_cache_key()
+        .ok_or(RestError::SubzeroCore(InternalError {
+            message: "Failed to get endpoint cache key".to_string(),
+        }))?;
+
+    let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?;
+
+    let (parts, original_body) = request.into_parts();
+
+    let auth_header = parts
+        .headers
+        .get(AUTHORIZATION)
+        .ok_or(RestError::SubzeroCore(InternalError {
+            message: "Authorization header is required".to_string(),
+        }))?;
+
+    let entry = db_schema_cache
+        .get_cached_or_remote(
+            &endpoint_cache_key,
+            auth_header,
+            connection_string,
+            &mut client,
+            ctx,
+            config,
+        )
+        .await?;
+    let (api_config, db_schema_owned) = entry.as_ref();
+    let db_schema = db_schema_owned.borrow_schema();
+
+    let db_schemas = &api_config.db_schemas; // list of schemas available for the api
+    let db_extra_search_path = &api_config.db_extra_search_path;
+    // TODO: use this when we get a replacement for jsonpath_lib
+    // let role_claim_key = &api_config.role_claim_key;
+    // let role_claim_path = format!("${role_claim_key}");
+    let db_anon_role = &api_config.db_anon_role;
+    let max_rows = api_config.db_max_rows.as_deref();
+    let db_allowed_select_functions = api_config
+        .db_allowed_select_functions
+        .iter()
+        .map(|s| s.as_str())
+        .collect::<Vec<_>>();
+
+    // extract the jwt claims (we'll need them later to set the role and env)
+    let jwt_claims = match jwt_parsed.keys {
+        ComputeCredentialKeys::JwtPayload(payload_bytes) => {
+            // `payload_bytes` contains the raw JWT payload as Vec<u8>
+            // You can deserialize it back to JSON or parse specific claims
+            let payload: serde_json::Value = serde_json::from_slice(&payload_bytes)
+                .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
+            Some(payload)
+        }
+        ComputeCredentialKeys::AuthKeys(_) => None,
+    };
+
+    // read the role from the jwt claims (and set it to the "anon" role if not present)
+    let (role, authenticated) = match &jwt_claims {
+        Some(claims) => match claims.get("role") {
+            Some(JsonValue::String(r)) => (Some(r), true),
+            _ => (db_anon_role.as_ref(), true),
+        },
+        None => (db_anon_role.as_ref(), false),
+    };
+
+    // do not allow unauthenticated requests when there is no anonymous role setup
+    if let (None, false) = (role, authenticated) {
+        return Err(RestError::SubzeroCore(JwtTokenInvalid {
+            message: "unauthenticated requests not allowed".to_string(),
+        }));
+    }
+
+    // start deconstructing the request because subzero core mostly works with &str
+    let method = parts.method;
+    let method_str = method.as_str();
+    let path = parts.uri.path_and_query().map_or("/", |pq| pq.as_str());
+
+    // this is actually the table name (or rpc/function_name)
+    // TODO: rename this to something more descriptive
+    let root = match parts.uri.path().strip_prefix(api_prefix) {
+        Some(p) => Ok(p),
+        None => Err(RestError::SubzeroCore(NotFound {
+            target: parts.uri.path().to_string(),
+        })),
+    }?;
+
+    // pick the current schema from the headers (or the first one from config)
+    let schema_name = &DbSchema::pick_current_schema(db_schemas, method_str, &parts.headers)?;
+
+    // add the content-profile header to the response
+    let mut response_headers = vec![];
+    if db_schemas.len() > 1 {
+        response_headers.push(("Content-Profile".to_string(), schema_name.clone()));
+    }
+
+    // parse the query string into a Vec<(&str, &str)>
+    let query = match parts.uri.query() {
+        Some(q) => form_urlencoded::parse(q.as_bytes()).collect(),
+        None => vec![],
+    };
+    let get: Vec<(&str, &str)> = query.iter().map(|(k, v)| (&**k, &**v)).collect();
+
+    // convert the headers map to a HashMap<&str, &str>
+    let headers: HashMap<&str, &str> = parts
+        .headers
+        .iter()
+        .map(|(k, v)| (k.as_str(), v.to_str().unwrap_or("__BAD_HEADER__")))
+        .collect();
+
+    let cookies = HashMap::new(); // TODO: add cookies
+
+    // Read the request body (skip for GET requests)
+    let body_as_string: Option<String> = if method == Method::GET {
+        None
+    } else {
+        let body_bytes =
+            read_body_with_limit(original_body, config.http_config.max_request_size_bytes)
+                .await
+                .map_err(ReadPayloadError::from)?;
+        if body_bytes.is_empty() {
+            None
+        } else {
+            Some(String::from_utf8_lossy(&body_bytes).into_owned())
+        }
+    };
+
+    // parse the request into an ApiRequest struct
+    let mut api_request = parse(
+        schema_name,
+        root,
+        db_schema,
+        method_str,
+        path,
+        get,
+        body_as_string.as_deref(),
+        headers,
+        cookies,
+        max_rows,
+    )
+    .map_err(RestError::SubzeroCore)?;
+
+    let role_str = match role {
+        Some(r) => r,
+        None => "",
+    };
+
+    replace_select_star(db_schema, schema_name, role_str, &mut api_request.query)?;
+
+    // TODO: this is not relevant when acting as PostgREST but will be useful
+    // in the context of DBX where they need internal permissions
+    // if !disable_internal_permissions {
+    //     check_privileges(db_schema, schema_name, role_str, &api_request)?;
+    // }
+
+    check_safe_functions(&api_request, &db_allowed_select_functions)?;
+
+    // TODO: this is not relevant when acting as PostgREST but will be useful
+    // in the context of DBX where they need internal permissions
+    // if !disable_internal_permissions {
+    //     insert_policy_conditions(db_schema, schema_name, role_str, &mut api_request.query)?;
+    // }
+
+    let env_role = Some(role_str);
+
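    // Editor's note (illustrative, not from the patch): every key/value placed
    // into `env` below is installed transaction-locally via set_config(..., true)
    // by env_statement, so SQL evaluated by the main statement can read it back,
    // e.g. a row-level security policy along these lines:
    //
    //   create policy owner_only on items using (
    //       owner_id = (current_setting('request.jwt.claims', true)::json ->> 'sub')
    //   );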
+    // construct the env (passed in to the sql context as GUCs)
+    let empty_json = "{}".to_string();
+    let headers_env = serde_json::to_string(&api_request.headers).unwrap_or(empty_json.clone());
+    let cookies_env = serde_json::to_string(&api_request.cookies).unwrap_or(empty_json.clone());
+    let get_env = serde_json::to_string(&api_request.get).unwrap_or(empty_json.clone());
+    let jwt_claims_env = jwt_claims
+        .as_ref()
+        .map(|v| serde_json::to_string(v).unwrap_or(empty_json.clone()))
+        .unwrap_or(if let Some(r) = env_role {
+            let claims: HashMap<&str, &str> = HashMap::from([("role", r)]);
+            serde_json::to_string(&claims).unwrap_or(empty_json.clone())
+        } else {
+            empty_json.clone()
+        });
+    let mut search_path = vec![api_request.schema_name];
+    if let Some(extra) = &db_extra_search_path {
+        search_path.extend(extra.iter().map(|s| s.as_str()));
+    }
+    let search_path_str = search_path
+        .into_iter()
+        .filter(|s| !s.is_empty())
+        .collect::<Vec<_>>()
+        .join(",");
+    let mut env: HashMap<&str, &str> = HashMap::from([
+        ("request.method", api_request.method),
+        ("request.path", api_request.path),
+        ("request.headers", &headers_env),
+        ("request.cookies", &cookies_env),
+        ("request.get", &get_env),
+        ("request.jwt.claims", &jwt_claims_env),
+        ("search_path", &search_path_str),
+    ]);
+    if let Some(r) = env_role {
+        env.insert("role", r);
+    }
+
+    // generate the sql statements
+    let (env_statement, env_parameters, _) = generate(fmt_env_query(&env));
+    let (main_statement, main_parameters, _) = generate(fmt_main_query(
+        db_schema,
+        api_request.schema_name,
+        &api_request,
+        &env,
+    )?);
+
+    let mut headers = vec![
+        (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())),
+        (
+            &CONN_STRING,
+            HeaderValue::from_str(connection_string).expect("invalid connection string"),
+        ),
+        (&AUTHORIZATION, auth_header.clone()),
+        (
+            &TXN_ISOLATION_LEVEL,
+            HeaderValue::from_static("ReadCommitted"),
+        ),
+        (&ALLOW_POOL, HEADER_VALUE_TRUE),
+    ];
+
+    if api_request.read_only {
+        headers.push((&TXN_READ_ONLY, HEADER_VALUE_TRUE));
+    }
+
+    // convert the parameters from subzero core representation to the local proxy repr.
+    let req_body = serde_json::to_string(&BatchQueryData {
+        queries: vec![
+            QueryData {
+                query: env_statement.into(),
+                params: env_parameters
+                    .iter()
+                    .map(|p| to_sql_param(&p.to_param()))
+                    .collect(),
+            },
+            QueryData {
+                query: main_statement.into(),
+                params: main_parameters
+                    .iter()
+                    .map(|p| to_sql_param(&p.to_param()))
+                    .collect(),
+            },
+        ],
+    })
+    .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?;
+
+    // todo: map body to count egress
+    let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly?
+
+    // send the request to the local proxy
+    let response = make_raw_local_proxy_request(&mut client, headers, req_body).await?;
+    let (parts, body) = response.into_parts();
+
+    let max_response = config.http_config.max_response_size_bytes;
+    let bytes = read_body_with_limit(body, max_response)
+        .await
+        .map_err(ReadPayloadError::from)?;
+
+    // if the response status is greater than 399, then it is an error
+    // FIXME: check if there are other error codes or shapes of the response
+    if parts.status.as_u16() > 399 {
+        // turn this postgres error from the json into PostgresError
+        let postgres_error = serde_json::from_slice(&bytes)
+            .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
+
+        return Err(RestError::Postgres(postgres_error));
+    }
+
+    #[derive(Deserialize)]
+    struct QueryResults {
+        /// we run two queries, so we want only two results.
+        results: (EnvRows, MainRows),
+    }
+
+    /// `env_statement` returns nothing of interest to us
+    #[derive(Deserialize)]
+    struct EnvRows {}
+
+    #[derive(Deserialize)]
+    struct MainRows {
+        /// `main_statement` only returns a single row.
+        rows: [MainRow; 1],
+    }
+
+    #[derive(Deserialize)]
+    struct MainRow {
+        body: String,
+        page_total: Option<String>,
+        total_result_set: Option<String>,
+        response_headers: Option<String>,
+        response_status: Option<String>,
+    }
+
+    let results: QueryResults = serde_json::from_slice(&bytes)
+        .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
+
+    let QueryResults {
+        results: (_, MainRows { rows: [row] }),
+    } = results;
+
+    // build the intermediate response object
+    let api_response = ApiResponse {
+        page_total: row.page_total.map_or(0, |v| v.parse::<u64>().unwrap_or(0)),
+        total_result_set: row.total_result_set.map(|v| v.parse::<u64>().unwrap_or(0)),
+        top_level_offset: 0, // FIXME: check why this is 0
+        response_headers: row.response_headers,
+        response_status: row.response_status,
+        body: row.body,
+    };
+
+    // TODO: rollback the transaction if the page_total is not 1 and the accept_content_type is SingularJSON
+    // we cannot do this in the context of proxy for now
+    // if api_request.accept_content_type == SingularJSON && api_response.page_total != 1 {
+    //     // rollback the transaction here
+    //     return Err(RestError::SubzeroCore(SingularityError {
+    //         count: api_response.page_total,
+    //         content_type: "application/vnd.pgrst.object+json".to_string(),
+    //     }));
+    // }
+
+    // TODO: rollback the transaction if the page_total is not 1 and the method is PUT
+    // we cannot do this in the context of proxy for now
+    // if api_request.method == Method::PUT && api_response.page_total != 1 {
+    //     // Makes sure the querystring pk matches the payload pk
+    //     // e.g. PUT /items?id=eq.1 { "id" : 1, .. } is accepted,
+    //     // PUT /items?id=eq.14 { "id" : 2, .. } is rejected.
+    //     // If this condition is not satisfied then nothing is inserted,
+    //     // rollback the transaction here
+    //     return Err(RestError::SubzeroCore(PutMatchingPkError));
+    // }
+
+    // create and return the response to the client
+    // this section mostly deals with setting the right headers according to PostgREST specs
+    let page_total = api_response.page_total;
+    let total_result_set = api_response.total_result_set;
+    let top_level_offset = api_response.top_level_offset;
+    let response_content_type = match (&api_request.accept_content_type, &api_request.query.node) {
+        (SingularJSON, _)
+        | (
+            _,
+            FunctionCall {
+                returns_single: true,
+                is_scalar: false,
+                ..
+            },
+        ) => SingularJSON,
+        (TextCSV, _) => TextCSV,
+        _ => ApplicationJSON,
+    };
+
+    // check if the SQL env set some response headers (happens when we called a rpc function)
+    if let Some(response_headers_str) = api_response.response_headers {
+        let Ok(headers_json) =
+            serde_json::from_str::<Vec<HashMap<String, String>>>(response_headers_str.as_str())
+        else {
+            return Err(RestError::SubzeroCore(GucHeadersError));
+        };
+
+        response_headers.extend(headers_json.into_iter().flatten());
+    }
+
+    // calculate and set the content range header
+    let lower = top_level_offset as i64;
+    let upper = top_level_offset as i64 + page_total as i64 - 1;
+    let total = total_result_set.map(|t| t as i64);
+    let content_range = match (&method, &api_request.query.node) {
+        (&Method::POST, Insert { .. }) => content_range_header(1, 0, total),
+        (&Method::DELETE, Delete { .. }) => content_range_header(1, upper, total),
+        _ => content_range_header(lower, upper, total),
+    };
+    response_headers.push(("Content-Range".to_string(), content_range));
+
+    // calculate the status code
+    #[rustfmt::skip]
+    let mut status = match (&method, &api_request.query.node, page_total, &api_request.preferences) {
+        (&Method::POST, Insert { .. }, ..) => 201,
+        (&Method::DELETE, Delete { .. }, _, Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
+        (&Method::DELETE, Delete { .. }, ..) => 204,
+        (&Method::PATCH, Update { columns, .. }, 0, _) if !columns.is_empty() => 404,
+        (&Method::PATCH, Update { .. }, _,Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
+        (&Method::PATCH, Update { .. }, ..) => 204,
+        (&Method::PUT, Insert { .. },_,Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
+        (&Method::PUT, Insert { .. }, ..) => 204,
+        _ => content_range_status(lower, upper, total),
+    };
+
+    // add the preference-applied header
+    if let Some(Preferences {
+        resolution: Some(r),
+        ..
+    }) = api_request.preferences
+    {
+        response_headers.push((
+            "Preference-Applied".to_string(),
+            match r {
+                MergeDuplicates => "resolution=merge-duplicates".to_string(),
+                IgnoreDuplicates => "resolution=ignore-duplicates".to_string(),
+            },
+        ));
+    }
+
+    // check if the SQL env set some response status (happens when we called a rpc function)
+    if let Some(response_status_str) = api_response.response_status {
+        status = response_status_str
+            .parse::<u16>()
+            .map_err(|_| RestError::SubzeroCore(GucStatusError))?;
+    }
+
+    // set the content type header
+    // TODO: move this to a subzero function
+    // as_header_value(&self) -> Option<&str>
+    let http_content_type = match response_content_type {
+        SingularJSON => Ok("application/vnd.pgrst.object+json"),
+        TextCSV => Ok("text/csv"),
+        ApplicationJSON => Ok("application/json"),
+        Other(t) => Err(RestError::SubzeroCore(ContentTypeError {
+            message: format!("None of these Content-Types are available: {t}"),
+        })),
+    }?;
+
+    // build the response body
+    let response_body = Full::new(Bytes::from(api_response.body))
+        .map_err(|never| match never {})
+        .boxed();
+
+    // build the response
+    let mut response = Response::builder()
+        .status(StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
+        .header(CONTENT_TYPE, http_content_type);
+
+    // Add all headers from response_headers vector
+    for (header_name, header_value) in response_headers {
+        response = response.header(header_name, header_value);
+    }
+
+    // add the body and return the response
+    response.body(response_body).map_err(|_| {
+        RestError::SubzeroCore(InternalError {
+            message: "Failed to build response".to_string(),
+        })
+    })
+}
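// Editor's note: a self-contained sketch of the PostgREST-style Content-Range
// arithmetic used above (assumed semantics; the real content_range_header and
// content_range_status helpers live in subzero-core):
fn _content_range_sketch(lower: i64, upper: i64, total: Option<i64>) -> String {
    let total = total.map_or("*".to_string(), |t| t.to_string());
    if upper < lower {
        // empty range, e.g. "*/100"
        format!("*/{total}")
    } else {
        // e.g. rows 0..=24 of 100 => "0-24/100"
        format!("{lower}-{upper}/{total}")
    }
}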
diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs
index 8a14f804b6..f254b41b5b 100644
--- a/proxy/src/serverless/sql_over_http.rs
+++ b/proxy/src/serverless/sql_over_http.rs
@@ -64,7 +64,7 @@ enum Payload {
     Batch(BatchQueryData),
 }
 
-static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
+pub(super) const HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
 
 fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
 where
diff --git a/proxy/src/util.rs b/proxy/src/util.rs
index 0291216d94..c89ebab008 100644
--- a/proxy/src/util.rs
+++ b/proxy/src/util.rs
@@ -20,3 +20,13 @@ pub async fn run_until(
         Either::Right((f2, _)) => Err(f2),
     }
 }
+
+pub fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
+where
+    T: for<'de2> serde::Deserialize<'de2>,
+    D: serde::Deserializer<'de>,
+{
+    use serde::Deserialize;
+    let s = String::deserialize(deserializer)?;
+    serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
+}
diff --git a/proxy/subzero_core/.gitignore b/proxy/subzero_core/.gitignore
new file mode 100644
index 0000000000..f2f9e58ec3
--- /dev/null
+++ b/proxy/subzero_core/.gitignore
@@ -0,0 +1,2 @@
+target
+Cargo.lock
\ No newline at end of file
diff --git a/proxy/subzero_core/Cargo.toml b/proxy/subzero_core/Cargo.toml
new file mode 100644
index 0000000000..13185873d0
--- /dev/null
+++ b/proxy/subzero_core/Cargo.toml
@@ -0,0 +1,12 @@
+# This is a stub for the subzero-core crate.
+[package]
+name = "subzero-core"
+version = "3.0.1"
+edition = "2024"
+publish = false # "private"!
+
+[features]
+default = []
+postgresql = []
+
+[dependencies]
diff --git a/proxy/subzero_core/src/lib.rs b/proxy/subzero_core/src/lib.rs
new file mode 100644
index 0000000000..b99246b98b
--- /dev/null
+++ b/proxy/subzero_core/src/lib.rs
@@ -0,0 +1 @@
+// This is a stub for the subzero-core crate.
diff --git a/pyproject.toml b/pyproject.toml
index e992e81fe7..7631a05942 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -50,6 +50,7 @@ types-pyyaml = "^6.0.12.20240917"
 testcontainers = "^4.9.0"
 # Install a release candidate of `jsonnet`, as it supports Python 3.13
 jsonnet = "^0.21.0-rc2"
+requests-unixsocket = "^0.4.1"
 
 [tool.poetry.group.dev.dependencies]
 mypy = "==1.13.0"
diff --git a/safekeeper/src/rate_limit.rs b/safekeeper/src/rate_limit.rs
index 72373b5786..0e697ade57 100644
--- a/safekeeper/src/rate_limit.rs
+++ b/safekeeper/src/rate_limit.rs
@@ -44,6 +44,6 @@ impl RateLimiter {
 
 /// Generate a random duration that is a fraction of the given duration.
 pub fn rand_duration(duration: &std::time::Duration) -> std::time::Duration {
-    let randf64 = rand::thread_rng().gen_range(0.0..1.0);
+    let randf64 = rand::rng().random_range(0.0..1.0);
     duration.mul_f64(randf64)
 }
diff --git a/safekeeper/src/send_interpreted_wal.rs b/safekeeper/src/send_interpreted_wal.rs
index 72a436e25f..671798298b 100644
--- a/safekeeper/src/send_interpreted_wal.rs
+++ b/safekeeper/src/send_interpreted_wal.rs
@@ -742,7 +742,7 @@ mod tests {
     use std::str::FromStr;
     use std::time::Duration;
 
-    use pageserver_api::shard::{ShardIdentity, ShardStripeSize};
+    use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardIdentity};
     use postgres_ffi::{MAX_SEND_SIZE, PgMajorVersion};
     use tokio::sync::mpsc::error::TryRecvError;
     use utils::id::{NodeId, TenantTimelineId};
@@ -786,19 +786,13 @@
         MAX_SEND_SIZE,
     );
 
-        let shard_0 = ShardIdentity::new(
-            ShardNumber(0),
-            ShardCount(SHARD_COUNT),
-            ShardStripeSize::default(),
-        )
-        .unwrap();
+        let shard_0 =
+
ShardIdentity::new(ShardNumber(0), ShardCount(SHARD_COUNT), DEFAULT_STRIPE_SIZE) + .unwrap(); struct Sender { tx: Option>, @@ -1088,19 +1079,13 @@ mod tests { WAL_READER_BATCH_SIZE, ); - let shard_0 = ShardIdentity::new( - ShardNumber(0), - ShardCount(SHARD_COUNT), - ShardStripeSize::default(), - ) - .unwrap(); + let shard_0 = + ShardIdentity::new(ShardNumber(0), ShardCount(SHARD_COUNT), DEFAULT_STRIPE_SIZE) + .unwrap(); - let shard_1 = ShardIdentity::new( - ShardNumber(1), - ShardCount(SHARD_COUNT), - ShardStripeSize::default(), - ) - .unwrap(); + let shard_1 = + ShardIdentity::new(ShardNumber(1), ShardCount(SHARD_COUNT), DEFAULT_STRIPE_SIZE) + .unwrap(); let mut shards = HashMap::new(); @@ -1108,7 +1093,7 @@ mod tests { let shard_id = ShardIdentity::new( ShardNumber(shard_number), ShardCount(SHARD_COUNT), - ShardStripeSize::default(), + DEFAULT_STRIPE_SIZE, ) .unwrap(); let (tx, rx) = tokio::sync::mpsc::channel::(MSG_COUNT * 2); diff --git a/safekeeper/src/timeline.rs b/safekeeper/src/timeline.rs index a1a0aab9fd..b8774b30ea 100644 --- a/safekeeper/src/timeline.rs +++ b/safekeeper/src/timeline.rs @@ -427,6 +427,9 @@ impl From for ApiError { TimelineError::NotFound(ttid) => { ApiError::NotFound(anyhow!("timeline {} not found", ttid).into()) } + TimelineError::Deleted(ttid) => { + ApiError::NotFound(anyhow!("timeline {} deleted", ttid).into()) + } _ => ApiError::InternalServerError(anyhow!("{}", te)), } } diff --git a/safekeeper/src/wal_backup.rs b/safekeeper/src/wal_backup.rs index 0e8dfd64c3..03c8f7e84a 100644 --- a/safekeeper/src/wal_backup.rs +++ b/safekeeper/src/wal_backup.rs @@ -8,7 +8,7 @@ use std::time::Duration; use anyhow::{Context, Result}; use camino::{Utf8Path, Utf8PathBuf}; use futures::StreamExt; -use futures::stream::FuturesOrdered; +use futures::stream::{self, FuturesOrdered}; use postgres_ffi::v14::xlog_utils::XLogSegNoOffsetToRecPtr; use postgres_ffi::{PG_TLI, XLogFileName, XLogSegNo}; use remote_storage::{ @@ -723,8 +723,6 @@ pub async fn copy_s3_segments( from_segment: XLogSegNo, to_segment: XLogSegNo, ) -> Result<()> { - const SEGMENTS_PROGRESS_REPORT_INTERVAL: u64 = 1024; - let remote_dst_path = remote_timeline_path(dst_ttid)?; let cancel = CancellationToken::new(); @@ -744,27 +742,69 @@ pub async fn copy_s3_segments( .filter_map(|o| o.key.object_name().map(ToOwned::to_owned)) .collect::>(); - debug!( + info!( "these segments have already been uploaded: {:?}", uploaded_segments ); - for segno in from_segment..to_segment { - if segno % SEGMENTS_PROGRESS_REPORT_INTERVAL == 0 { - info!("copied all segments from {} until {}", from_segment, segno); - } + /* BEGIN_HADRON */ + // Copying multiple segments async. 
+ let mut copy_stream = stream::iter(from_segment..to_segment) + .map(|segno| { + let segment_name = XLogFileName(PG_TLI, segno, wal_seg_size); + let remote_dst_path = remote_dst_path.clone(); + let cancel = cancel.clone(); - let segment_name = XLogFileName(PG_TLI, segno, wal_seg_size); - if uploaded_segments.contains(&segment_name) { - continue; - } - debug!("copying segment {}", segment_name); + async move { + if uploaded_segments.contains(&segment_name) { + return Ok(()); + } - let from = remote_timeline_path(src_ttid)?.join(&segment_name); - let to = remote_dst_path.join(&segment_name); + if segno % 1000 == 0 { + info!("copying segment {} {}", segno, segment_name); + } - storage.copy_object(&from, &to, &cancel).await?; + let from = remote_timeline_path(src_ttid)?.join(&segment_name); + let to = remote_dst_path.join(&segment_name); + + // Retry logic: retry up to 10 times with 1 second delay + let mut retry_count = 0; + const MAX_RETRIES: u32 = 10; + + loop { + match storage.copy_object(&from, &to, &cancel).await { + Ok(()) => return Ok(()), + Err(e) => { + if cancel.is_cancelled() { + // Don't retry if cancellation was requested + return Err(e); + } + + retry_count += 1; + if retry_count >= MAX_RETRIES { + error!( + "Failed to copy segment {} after {} retries: {}", + segment_name, MAX_RETRIES, e + ); + return Err(e); + } + warn!( + "Failed to copy segment {} (attempt {}/{}): {}, retrying...", + segment_name, retry_count, MAX_RETRIES, e + ); + tokio::time::sleep(Duration::from_secs(1)).await; + } + } + } + } + }) + .buffer_unordered(32); // Limit to 32 concurrent uploads + + // Process results, stopping on first error + while let Some(result) = copy_stream.next().await { + result?; } + /* END_HADRON */ info!( "finished copying segments from {} until {}", diff --git a/safekeeper/tests/random_test.rs b/safekeeper/tests/random_test.rs index e29b58836a..7e7d2390e9 100644 --- a/safekeeper/tests/random_test.rs +++ b/safekeeper/tests/random_test.rs @@ -16,7 +16,7 @@ fn test_random_schedules() -> anyhow::Result<()> { let mut config = TestConfig::new(Some(clock)); for _ in 0..500 { - let seed: u64 = rand::thread_rng().r#gen(); + let seed: u64 = rand::rng().random(); config.network = generate_network_opts(seed); let test = config.start(seed); diff --git a/safekeeper/tests/walproposer_sim/simulation.rs b/safekeeper/tests/walproposer_sim/simulation.rs index edd3bf2d9e..595cc7ab64 100644 --- a/safekeeper/tests/walproposer_sim/simulation.rs +++ b/safekeeper/tests/walproposer_sim/simulation.rs @@ -394,13 +394,13 @@ pub fn generate_schedule(seed: u64) -> Schedule { let mut schedule = Vec::new(); let mut time = 0; - let cnt = rng.gen_range(1..100); + let cnt = rng.random_range(1..100); for _ in 0..cnt { - time += rng.gen_range(0..500); - let action = match rng.gen_range(0..3) { - 0 => TestAction::WriteTx(rng.gen_range(1..10)), - 1 => TestAction::RestartSafekeeper(rng.gen_range(0..3)), + time += rng.random_range(0..500); + let action = match rng.random_range(0..3) { + 0 => TestAction::WriteTx(rng.random_range(1..10)), + 1 => TestAction::RestartSafekeeper(rng.random_range(0..3)), 2 => TestAction::RestartWalProposer, _ => unreachable!(), }; @@ -413,13 +413,13 @@ pub fn generate_schedule(seed: u64) -> Schedule { pub fn generate_network_opts(seed: u64) -> NetworkOptions { let mut rng = rand::rngs::StdRng::seed_from_u64(seed); - let timeout = rng.gen_range(100..2000); - let max_delay = rng.gen_range(1..2 * timeout); - let min_delay = rng.gen_range(1..=max_delay); + let timeout = rng.random_range(100..2000); + 
let max_delay = rng.random_range(1..2 * timeout);
+    let min_delay = rng.random_range(1..=max_delay);
 
-    let max_fail_prob = rng.gen_range(0.0..0.9);
-    let connect_fail_prob = rng.gen_range(0.0..max_fail_prob);
-    let send_fail_prob = rng.gen_range(0.0..connect_fail_prob);
+    let max_fail_prob = rng.random_range(0.0..0.9);
+    let connect_fail_prob = rng.random_range(0.0..max_fail_prob);
+    let send_fail_prob = rng.random_range(0.0..connect_fail_prob);
 
     NetworkOptions {
         keepalive_timeout: Some(timeout),
diff --git a/storage_controller/migrations/2025-07-08-114340_sk_set_notified_generation/down.sql b/storage_controller/migrations/2025-07-08-114340_sk_set_notified_generation/down.sql
new file mode 100644
index 0000000000..27d6048cd3
--- /dev/null
+++ b/storage_controller/migrations/2025-07-08-114340_sk_set_notified_generation/down.sql
@@ -0,0 +1 @@
+ALTER TABLE timelines DROP sk_set_notified_generation;
diff --git a/storage_controller/migrations/2025-07-08-114340_sk_set_notified_generation/up.sql b/storage_controller/migrations/2025-07-08-114340_sk_set_notified_generation/up.sql
new file mode 100644
index 0000000000..50178ab6a3
--- /dev/null
+++ b/storage_controller/migrations/2025-07-08-114340_sk_set_notified_generation/up.sql
@@ -0,0 +1 @@
+ALTER TABLE timelines ADD sk_set_notified_generation INTEGER NOT NULL DEFAULT 1;
diff --git a/storage_controller/src/hadron_utils.rs b/storage_controller/src/hadron_utils.rs
index 871e21c367..8bfbe8e575 100644
--- a/storage_controller/src/hadron_utils.rs
+++ b/storage_controller/src/hadron_utils.rs
@@ -8,10 +8,10 @@ static CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz01
 /// Generate a random string of `length` that can be used as a password. The generated string
 /// contains alphanumeric characters and special characters (!@#$%^&*())
 pub fn generate_random_password(length: usize) -> String {
-    let mut rng = rand::thread_rng();
+    let mut rng = rand::rng();
     (0..length)
         .map(|_| {
-            let idx = rng.gen_range(0..CHARSET.len());
+            let idx = rng.random_range(0..CHARSET.len());
             CHARSET[idx] as char
         })
         .collect()
diff --git a/storage_controller/src/main.rs b/storage_controller/src/main.rs
index 5d21feeb10..34d4ac6fba 100644
--- a/storage_controller/src/main.rs
+++ b/storage_controller/src/main.rs
@@ -225,6 +225,10 @@ struct Cli {
     #[arg(long)]
     shard_split_request_timeout: Option<humantime::Duration>,
+
+    /// **Feature Flag** Whether the storage controller should act to rectify pageserver-reported local disk loss.
+    #[arg(long, default_value = "false")]
+    handle_ps_local_disk_loss: bool,
 }
 
 enum StrictMode {
@@ -477,6 +481,7 @@ async fn async_main() -> anyhow::Result<()> {
             .shard_split_request_timeout
             .map(humantime::Duration::into)
             .unwrap_or(Duration::MAX),
+        handle_ps_local_disk_loss: args.handle_ps_local_disk_loss,
     };
 
     // Validate that we can connect to the database
diff --git a/storage_controller/src/persistence.rs b/storage_controller/src/persistence.rs
index ed9a268064..619b5f69b8 100644
--- a/storage_controller/src/persistence.rs
+++ b/storage_controller/src/persistence.rs
@@ -129,7 +129,10 @@ pub(crate) enum DatabaseOperation {
     UpdateLeader,
     SetPreferredAzs,
     InsertTimeline,
+    UpdateTimeline,
     UpdateTimelineMembership,
+    UpdateCplaneNotifiedGeneration,
+    UpdateSkSetNotifiedGeneration,
     GetTimeline,
     InsertTimelineReconcile,
     RemoveTimelineReconcile,
@@ -1463,9 +1466,41 @@ impl Persistence {
         .await
     }
 
+    /// Update an already present timeline.
+    /// VERY UNSAFE FUNCTION: this overrides in-progress migrations. Don't use this unless necessary.
+    pub(crate) async fn update_timeline_unsafe(
+        &self,
+        entry: TimelineUpdate,
+    ) -> DatabaseResult<bool> {
+        use crate::schema::timelines;
+
+        let entry = &entry;
+        self.with_measured_conn(DatabaseOperation::UpdateTimeline, move |conn| {
+            Box::pin(async move {
+                let inserted_updated = diesel::update(timelines::table)
+                    .filter(timelines::tenant_id.eq(&entry.tenant_id))
+                    .filter(timelines::timeline_id.eq(&entry.timeline_id))
+                    .set(entry)
+                    .execute(conn)
+                    .await?;
+
+                match inserted_updated {
+                    0 => Ok(false),
+                    1 => Ok(true),
+                    _ => Err(DatabaseError::Logical(format!(
+                        "unexpected number of rows ({inserted_updated})"
+                    ))),
+                }
+            })
+        })
+        .await
+    }
+
     /// Update timeline membership configuration in the database.
     /// Perform a compare-and-swap (CAS) operation on the timeline's generation.
     /// The `new_generation` must be the next (+1) generation after the one in the database.
+    /// Also inserts reconcile_requests to safekeeper_timeline_pending_ops table in the same
+    /// transaction.
     pub(crate) async fn update_timeline_membership(
         &self,
         tenant_id: TenantId,
@@ -1473,8 +1508,11 @@
         new_generation: SafekeeperGeneration,
         sk_set: &[NodeId],
         new_sk_set: Option<&[NodeId]>,
+        reconcile_requests: &[TimelinePendingOpPersistence],
     ) -> DatabaseResult<()> {
-        use crate::schema::timelines::dsl;
+        use crate::schema::safekeeper_timeline_pending_ops as stpo;
+        use crate::schema::timelines;
+        use diesel::query_dsl::methods::FilterDsl;
 
         let prev_generation = new_generation.previous().unwrap();
 
         let timeline_id = &timeline_id;
         self.with_measured_conn(DatabaseOperation::UpdateTimelineMembership, move |conn| {
             Box::pin(async move {
-                let updated = diesel::update(dsl::timelines)
-                    .filter(dsl::tenant_id.eq(&tenant_id.to_string()))
-                    .filter(dsl::timeline_id.eq(&timeline_id.to_string()))
-                    .filter(dsl::generation.eq(prev_generation.into_inner() as i32))
+                let updated = diesel::update(timelines::table)
+                    .filter(timelines::tenant_id.eq(&tenant_id.to_string()))
+                    .filter(timelines::timeline_id.eq(&timeline_id.to_string()))
+                    .filter(timelines::generation.eq(prev_generation.into_inner() as i32))
                     .set((
-                        dsl::generation.eq(new_generation.into_inner() as i32),
-                        dsl::sk_set.eq(sk_set.iter().map(|id| id.0 as i64).collect::<Vec<i64>>()),
-                        dsl::new_sk_set.eq(new_sk_set
+                        timelines::generation.eq(new_generation.into_inner() as i32),
+                        timelines::sk_set
+                            .eq(sk_set.iter().map(|id| id.0 as i64).collect::<Vec<i64>>()),
+                        timelines::new_sk_set.eq(new_sk_set
                             .map(|set| set.iter().map(|id| id.0 as i64).collect::<Vec<i64>>())),
                     ))
                     .execute(conn)
                     .await?;
@@ -1499,20 +1538,123 @@
                 match updated {
                     0 => {
                         // TODO(diko): It makes sense to select the current generation
                         // and include it in the error message for better debuggability.
- Err(DatabaseError::Cas( + return Err(DatabaseError::Cas( "Failed to update membership configuration".to_string(), - )) + )); + } + 1 => {} + _ => { + return Err(DatabaseError::Logical(format!( + "unexpected number of rows ({updated})" + ))); + } + }; + + for req in reconcile_requests { + let inserted_updated = diesel::insert_into(stpo::table) + .values(req) + .on_conflict((stpo::tenant_id, stpo::timeline_id, stpo::sk_id)) + .do_update() + .set(req) + .filter(stpo::generation.lt(req.generation)) + .execute(conn) + .await?; + + if inserted_updated > 1 { + return Err(DatabaseError::Logical(format!( + "unexpected number of rows ({inserted_updated})" + ))); } - 1 => Ok(()), - _ => Err(DatabaseError::Logical(format!( - "unexpected number of rows ({updated})" - ))), } + + Ok(()) }) }) .await } + /// Update the cplane notified generation for a timeline. + /// Perform a compare-and-swap (CAS) operation on the timeline's cplane notified generation. + /// The update will fail if the specified generation is less than the cplane notified generation + /// in the database. + pub(crate) async fn update_cplane_notified_generation( + &self, + tenant_id: TenantId, + timeline_id: TimelineId, + generation: SafekeeperGeneration, + ) -> DatabaseResult<()> { + use crate::schema::timelines::dsl; + + let tenant_id = &tenant_id; + let timeline_id = &timeline_id; + self.with_measured_conn( + DatabaseOperation::UpdateCplaneNotifiedGeneration, + move |conn| { + Box::pin(async move { + let updated = diesel::update(dsl::timelines) + .filter(dsl::tenant_id.eq(&tenant_id.to_string())) + .filter(dsl::timeline_id.eq(&timeline_id.to_string())) + .filter(dsl::cplane_notified_generation.le(generation.into_inner() as i32)) + .set(dsl::cplane_notified_generation.eq(generation.into_inner() as i32)) + .execute(conn) + .await?; + + match updated { + 0 => Err(DatabaseError::Cas( + "Failed to update cplane notified generation".to_string(), + )), + 1 => Ok(()), + _ => Err(DatabaseError::Logical(format!( + "unexpected number of rows ({updated})" + ))), + } + }) + }, + ) + .await + } + + /// Update the sk set notified generation for a timeline. + /// Perform a compare-and-swap (CAS) operation on the timeline's sk set notified generation. + /// The update will fail if the specified generation is less than the sk set notified generation + /// in the database. + pub(crate) async fn update_sk_set_notified_generation( + &self, + tenant_id: TenantId, + timeline_id: TimelineId, + generation: SafekeeperGeneration, + ) -> DatabaseResult<()> { + use crate::schema::timelines::dsl; + + let tenant_id = &tenant_id; + let timeline_id = &timeline_id; + self.with_measured_conn( + DatabaseOperation::UpdateSkSetNotifiedGeneration, + move |conn| { + Box::pin(async move { + let updated = diesel::update(dsl::timelines) + .filter(dsl::tenant_id.eq(&tenant_id.to_string())) + .filter(dsl::timeline_id.eq(&timeline_id.to_string())) + .filter(dsl::sk_set_notified_generation.le(generation.into_inner() as i32)) + .set(dsl::sk_set_notified_generation.eq(generation.into_inner() as i32)) + .execute(conn) + .await?; + + match updated { + 0 => Err(DatabaseError::Cas( + "Failed to update sk set notified generation".to_string(), + )), + 1 => Ok(()), + _ => Err(DatabaseError::Logical(format!( + "unexpected number of rows ({updated})" + ))), + } + }) + }, + ) + .await + } + /// Load timeline from db. Returns `None` if not present. 
    pub(crate) async fn get_timeline(
        &self,
@@ -2462,6 +2604,7 @@ pub(crate) struct TimelinePersistence {
     pub(crate) new_sk_set: Option<Vec<i64>>,
     pub(crate) cplane_notified_generation: i32,
     pub(crate) deleted_at: Option<chrono::DateTime<chrono::Utc>>,
+    pub(crate) sk_set_notified_generation: i32,
 }
 
 /// This is separate from [TimelinePersistence] only because postgres allows NULLs
@@ -2480,6 +2623,7 @@ pub(crate) struct TimelineFromDb {
     pub(crate) new_sk_set: Option<Vec<Option<i64>>>,
     pub(crate) cplane_notified_generation: i32,
     pub(crate) deleted_at: Option<chrono::DateTime<chrono::Utc>>,
+    pub(crate) sk_set_notified_generation: i32,
 }
 
 impl TimelineFromDb {
@@ -2499,10 +2643,23 @@
             new_sk_set,
             cplane_notified_generation: self.cplane_notified_generation,
             deleted_at: self.deleted_at,
+            sk_set_notified_generation: self.sk_set_notified_generation,
         }
     }
 }
 
+// This is separate from TimelinePersistence because we don't want to touch generation and deleted_at values for the update.
+#[derive(AsChangeset)]
+#[diesel(table_name = crate::schema::timelines)]
+#[diesel(treat_none_as_null = true)]
+pub(crate) struct TimelineUpdate {
+    pub(crate) tenant_id: String,
+    pub(crate) timeline_id: String,
+    pub(crate) start_lsn: LsnWrapper,
+    pub(crate) sk_set: Vec<i64>,
+    pub(crate) new_sk_set: Option<Vec<i64>>,
+}
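// Editor's note: a hedged sketch of what #[diesel(treat_none_as_null = true)]
// implies for updates built from TimelineUpdate (assumed diesel semantics:
// None maps to SQL NULL rather than "leave the column unchanged"); the
// generated statement is roughly:
//
//   UPDATE timelines
//      SET tenant_id = $1, timeline_id = $2, start_lsn = $3,
//          sk_set = $4, new_sk_set = NULL
//    WHERE tenant_id = $5 AND timeline_id = $6;
//
// so update_timeline_unsafe really does clear an in-progress migration's
// new_sk_set, which is why it is documented as unsafe.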
+
 #[derive(Insertable, AsChangeset, Queryable, Selectable, Clone)]
 #[diesel(table_name = crate::schema::safekeeper_timeline_pending_ops)]
 pub(crate) struct TimelinePendingOpPersistence {
diff --git a/storage_controller/src/scheduler.rs b/storage_controller/src/scheduler.rs
index b86b4dfab1..23f002d32a 100644
--- a/storage_controller/src/scheduler.rs
+++ b/storage_controller/src/scheduler.rs
@@ -981,7 +981,7 @@ mod tests {
     use pageserver_api::models::utilization::test_utilization;
     use pageserver_api::shard::ShardIdentity;
     use utils::id::TenantId;
-    use utils::shard::{ShardCount, ShardNumber, TenantShardId};
+    use utils::shard::{ShardCount, ShardNumber, ShardStripeSize, TenantShardId};
 
     use super::*;
     use crate::tenant_shard::IntentState;
@@ -1337,7 +1337,7 @@
         let shard_identity = ShardIdentity::new(
             tenant_shard_id.shard_number,
             tenant_shard_id.shard_count,
-            pageserver_api::shard::ShardStripeSize(1),
+            ShardStripeSize(1),
         )
         .unwrap();
         let mut shard = TenantShard::new(
@@ -1411,7 +1411,7 @@
         let shard_identity = ShardIdentity::new(
             tenant_shard_id.shard_number,
             tenant_shard_id.shard_count,
-            pageserver_api::shard::ShardStripeSize(1),
+            ShardStripeSize(1),
         )
         .unwrap();
         let mut shard = TenantShard::new(
@@ -1573,7 +1573,7 @@
         let shard_identity = ShardIdentity::new(
             tenant_shard_id.shard_number,
             tenant_shard_id.shard_count,
-            pageserver_api::shard::ShardStripeSize(1),
+            ShardStripeSize(1),
         )
         .unwrap();
         // 1 attached and 1 secondary.
diff --git a/storage_controller/src/schema.rs b/storage_controller/src/schema.rs
index f3dcdaf798..def519c168 100644
--- a/storage_controller/src/schema.rs
+++ b/storage_controller/src/schema.rs
@@ -118,6 +118,7 @@ diesel::table! {
         new_sk_set -> Nullable<Array<Nullable<Int8>>>,
         cplane_notified_generation -> Int4,
         deleted_at -> Nullable<Timestamptz>,
+        sk_set_notified_generation -> Int4,
     }
 }
diff --git a/storage_controller/src/service.rs b/storage_controller/src/service.rs
index 71186076ec..8f5efe8ac4 100644
--- a/storage_controller/src/service.rs
+++ b/storage_controller/src/service.rs
@@ -487,6 +487,9 @@ pub struct Config {
     /// Timeout used for HTTP client of split requests. [`Duration::MAX`] if None.
     pub shard_split_request_timeout: Duration,
+
+    // Feature flag: Whether the storage controller should act to rectify pageserver-reported local disk loss.
+    pub handle_ps_local_disk_loss: bool,
 }
 
 impl From for ApiError {
@@ -2388,6 +2391,33 @@ impl Service {
             tenants: Vec::new(),
         };
 
+        // [Hadron] If the pageserver reports in the reattach message that it has an empty disk, it's possible that it just
+        // recovered from a local disk failure. The response of the reattach request will contain a list of tenants but it
+        // will not be honored by the pageserver in this case (disk failure). We should make sure we clear any observed
+        // locations of tenants attached to the node so that the reconciler will discover the discrepancy and reconfigure the
+        // missing tenants on the node properly.
+        if self.config.handle_ps_local_disk_loss && reattach_req.empty_local_disk.unwrap_or(false) {
+            tracing::info!(
+                "Pageserver {node_id} reports empty local disk, clearing observed locations referencing the pageserver for all tenants",
+                node_id = reattach_req.node_id
+            );
+            let mut num_tenant_shards_affected = 0;
+            for (tenant_shard_id, shard) in tenants.iter_mut() {
+                if shard
+                    .observed
+                    .locations
+                    .remove(&reattach_req.node_id)
+                    .is_some()
+                {
+                    tracing::info!("Cleared observed location for tenant shard {tenant_shard_id}");
+                    num_tenant_shards_affected += 1;
+                }
+            }
+            tracing::info!(
+                "Cleared observed locations for {num_tenant_shards_affected} tenant shards"
+            );
+        }
+
         // TODO: cancel/restart any running reconciliation for this tenant, it might be trying
         // to call location_conf API with an old generation. Wait for cancellation to complete
         // before responding to this request. Requires well implemented CancellationToken logic
diff --git a/storage_controller/src/service/chaos_injector.rs b/storage_controller/src/service/chaos_injector.rs
index 4087de200a..0efeef4e80 100644
--- a/storage_controller/src/service/chaos_injector.rs
+++ b/storage_controller/src/service/chaos_injector.rs
@@ -3,8 +3,8 @@ use std::sync::Arc;
 use std::time::Duration;
 
 use pageserver_api::controller_api::ShardSchedulingPolicy;
-use rand::seq::SliceRandom;
-use rand::{Rng, thread_rng};
+use rand::Rng;
+use rand::seq::{IndexedRandom, SliceRandom};
 use tokio_util::sync::CancellationToken;
 use utils::id::NodeId;
 use utils::shard::TenantShardId;
@@ -72,7 +72,7 @@ impl ChaosInjector {
         let cron_interval = self.get_cron_interval_sleep_future();
         let chaos_type = tokio::select! {
             _ = interval.tick() => {
-                if thread_rng().gen_bool(0.5) {
+                if rand::rng().random_bool(0.5) {
                     ChaosEvent::MigrationsToSecondary
                 } else {
                     ChaosEvent::GracefulMigrationsAnywhere
@@ -134,7 +134,7 @@
         let Some(new_location) = shard
             .intent
             .get_secondary()
-            .choose(&mut thread_rng())
+            .choose(&mut rand::rng())
             .cloned()
         else {
             tracing::info!(
@@ -190,7 +190,7 @@
         // Pick our victims: use a hand-rolled loop rather than choose_multiple() because we want
         // to take the mutable refs from our candidates rather than ref'ing them.
         while !candidates.is_empty() && victims.len() < batch_size {
-            let i = thread_rng().gen_range(0..candidates.len());
+            let i = rand::rng().random_range(0..candidates.len());
             victims.push(candidates.swap_remove(i));
         }
@@ -210,7 +210,7 @@
             })
             .collect::<Vec<_>>();
 
-        let Some(victim_node) = candidate_nodes.choose(&mut thread_rng()) else {
+        let Some(victim_node) = candidate_nodes.choose(&mut rand::rng()) else {
             // This can happen if e.g. we are in a small region with only one pageserver per AZ.
tracing::info!( "no candidate nodes found for migrating shard {tenant_shard_id} within its home AZ", @@ -264,7 +264,7 @@ impl ChaosInjector { out_of_home_az.len() ); - out_of_home_az.shuffle(&mut thread_rng()); + out_of_home_az.shuffle(&mut rand::rng()); victims.extend(out_of_home_az.into_iter().take(batch_size)); } else { tracing::info!( @@ -274,7 +274,7 @@ impl ChaosInjector { ); victims.extend(out_of_home_az); - in_home_az.shuffle(&mut thread_rng()); + in_home_az.shuffle(&mut rand::rng()); victims.extend(in_home_az.into_iter().take(batch_size - victims.len())); } diff --git a/storage_controller/src/service/safekeeper_service.rs b/storage_controller/src/service/safekeeper_service.rs index 7521d7bd86..bc77a1a6b8 100644 --- a/storage_controller/src/service/safekeeper_service.rs +++ b/storage_controller/src/service/safekeeper_service.rs @@ -10,6 +10,7 @@ use crate::id_lock_map::trace_shared_lock; use crate::metrics; use crate::persistence::{ DatabaseError, SafekeeperTimelineOpKind, TimelinePendingOpPersistence, TimelinePersistence, + TimelineUpdate, }; use crate::safekeeper::Safekeeper; use crate::safekeeper_client::SafekeeperClient; @@ -311,6 +312,7 @@ impl Service { new_sk_set: None, cplane_notified_generation: 0, deleted_at: None, + sk_set_notified_generation: 0, }; let inserted = self .persistence @@ -454,19 +456,34 @@ impl Service { let persistence = TimelinePersistence { tenant_id: req.tenant_id.to_string(), timeline_id: req.timeline_id.to_string(), - start_lsn: Lsn::INVALID.into(), + start_lsn: req.start_lsn.into(), generation: 1, sk_set: req.sk_set.iter().map(|sk_id| sk_id.0 as i64).collect(), new_sk_set: None, cplane_notified_generation: 1, deleted_at: None, + sk_set_notified_generation: 1, }; - let inserted = self.persistence.insert_timeline(persistence).await?; + let inserted = self + .persistence + .insert_timeline(persistence.clone()) + .await?; if inserted { tracing::info!("imported timeline into db"); - } else { - tracing::info!("didn't import timeline into db, as it is already present in db"); + return Ok(()); } + tracing::info!("timeline already present in db, updating"); + + let update = TimelineUpdate { + tenant_id: persistence.tenant_id, + timeline_id: persistence.timeline_id, + start_lsn: persistence.start_lsn, + sk_set: persistence.sk_set, + new_sk_set: persistence.new_sk_set, + }; + self.persistence.update_timeline_unsafe(update).await?; + tracing::info!("timeline updated"); + Ok(()) } @@ -879,17 +896,21 @@ impl Service { /// If min_position is not None, validates that majority of safekeepers /// reached at least min_position. /// + /// If update_notified_generation is set, also updates sk_set_notified_generation + /// in the timelines table. + /// /// Return responses from safekeepers in the input order. async fn tenant_timeline_set_membership_quorum( self: &Arc, tenant_id: TenantId, timeline_id: TimelineId, safekeepers: &[Safekeeper], - config: &membership::Configuration, + mconf: &membership::Configuration, min_position: Option<(Term, Lsn)>, + update_notified_generation: bool, ) -> Result>, ApiError> { let req = TimelineMembershipSwitchRequest { - mconf: config.clone(), + mconf: mconf.clone(), }; const SK_SET_MEM_TIMELINE_RECONCILE_TIMEOUT: Duration = Duration::from_secs(30); @@ -930,28 +951,34 @@ impl Service { .await?; for res in results.iter().flatten() { - if res.current_conf.generation > config.generation { + if res.current_conf.generation > mconf.generation { // Antoher switch_membership raced us. 
                return Err(ApiError::Conflict(format!(
                    "received configuration with generation {} from safekeeper, but expected {}",
-                    res.current_conf.generation, config.generation
+                    res.current_conf.generation, mconf.generation
                )));
-            } else if res.current_conf.generation < config.generation {
+            } else if res.current_conf.generation < mconf.generation {
                 // Note: should never happen.
                 // If we get a response, it should be at least the sent generation.
                 tracing::error!(
                     "received configuration with generation {} from safekeeper, but expected {}",
                     res.current_conf.generation,
-                    config.generation
+                    mconf.generation
                 );
                 return Err(ApiError::InternalServerError(anyhow::anyhow!(
                     "received configuration with generation {} from safekeeper, but expected {}",
                     res.current_conf.generation,
-                    config.generation
+                    mconf.generation
                 )));
             }
         }
 
+        if update_notified_generation {
+            self.persistence
+                .update_sk_set_notified_generation(tenant_id, timeline_id, mconf.generation)
+                .await?;
+        }
+
         Ok(results)
     }
 
@@ -1020,17 +1047,22 @@
     }
 
     /// Exclude a timeline from safekeepers in parallel with retries.
-    /// If an exclude request is unsuccessful, it will be added to
-    /// the reconciler, and after that the function will succeed.
-    async fn tenant_timeline_safekeeper_exclude(
+    ///
+    /// Assumes that the exclude requests are already persisted in the database.
+    ///
+    /// The function makes a best effort: if an exclude request is unsuccessful,
+    /// it will be added to the in-memory reconciler, and the function will succeed anyway.
+    ///
+    /// Might fail if there is an error accessing the database.
+    async fn tenant_timeline_safekeeper_exclude_reconcile(
         self: &Arc<Self>,
         tenant_id: TenantId,
         timeline_id: TimelineId,
         safekeepers: &[Safekeeper],
-        config: &membership::Configuration,
+        mconf: &membership::Configuration,
     ) -> Result<(), ApiError> {
         let req = TimelineMembershipSwitchRequest {
-            mconf: config.clone(),
+            mconf: mconf.clone(),
         };
 
         const SK_EXCLUDE_TIMELINE_TIMEOUT: Duration = Duration::from_secs(30);
@@ -1048,25 +1080,32 @@
         let mut reconcile_requests = Vec::new();
 
-        for (idx, res) in results.iter().enumerate() {
-            if res.is_err() {
-                let sk_id = safekeepers[idx].skp.id;
-                let pending_op = TimelinePendingOpPersistence {
-                    tenant_id: tenant_id.to_string(),
-                    timeline_id: timeline_id.to_string(),
-                    generation: config.generation.into_inner() as i32,
-                    op_kind: SafekeeperTimelineOpKind::Exclude,
-                    sk_id,
-                };
-                tracing::info!("writing pending exclude op for sk id {sk_id}");
-                self.persistence.insert_pending_op(pending_op).await?;
+        fail::fail_point!("sk-migration-step-9-mid-exclude", |_| {
+            Err(ApiError::BadRequest(anyhow::anyhow!(
+                "failpoint sk-migration-step-9-mid-exclude"
+            )))
+        });
 
+        for (idx, res) in results.iter().enumerate() {
+            let sk_id = safekeepers[idx].skp.id;
+            let generation = mconf.generation.into_inner();
+
+            if res.is_ok() {
+                self.persistence
+                    .remove_pending_op(
+                        tenant_id,
+                        Some(timeline_id),
+                        NodeId(sk_id as u64),
+                        generation,
+                    )
+                    .await?;
+            } else {
                 let req = ScheduleRequest {
                     safekeeper: Box::new(safekeepers[idx].clone()),
                     host_list: Vec::new(),
                     tenant_id,
                     timeline_id: Some(timeline_id),
-                    generation: config.generation.into_inner(),
+                    generation,
                     kind: SafekeeperTimelineOpKind::Exclude,
                 };
                 reconcile_requests.push(req);
@@ -1193,6 +1232,22 @@
             }
             // It it is the same new_sk_set, we can continue the migration (retry).
} else { + let prev_finished = timeline.cplane_notified_generation == timeline.generation + && timeline.sk_set_notified_generation == timeline.generation; + + if !prev_finished { + // The previous migration is committed, but the finish step failed. + // Safekeepers/cplane might not know about the last membership configuration. + // Retry the finish step to ensure smooth migration. + self.finish_safekeeper_migration_retry(tenant_id, timeline_id, &timeline) + .await?; + } + + if cur_sk_set == new_sk_set { + tracing::info!("timeline is already at the desired safekeeper set"); + return Ok(()); + } + // 3. No active migration yet. // Increment current generation and put desired_set to new_sk_set. generation = generation.next(); @@ -1204,8 +1259,15 @@ impl Service { generation, &cur_sk_set, Some(&new_sk_set), + &[], ) .await?; + + fail::fail_point!("sk-migration-after-step-3", |_| { + Err(ApiError::BadRequest(anyhow::anyhow!( + "failpoint sk-migration-after-step-3" + ))) + }); } let cur_safekeepers = self.get_safekeepers(&timeline.sk_set)?; @@ -1234,6 +1296,7 @@ impl Service { &cur_safekeepers, &joint_config, None, // no min position + true, // update notified generation ) .await?; @@ -1251,6 +1314,12 @@ impl Service { "safekeepers set membership updated", ); + fail::fail_point!("sk-migration-after-step-4", |_| { + Err(ApiError::BadRequest(anyhow::anyhow!( + "failpoint sk-migration-after-step-4" + ))) + }); + // 5. Initialize timeline on safekeeper(s) from new_sk_set where it doesn't exist yet // by doing pull_timeline from the majority of the current set. @@ -1270,6 +1339,12 @@ impl Service { ) .await?; + fail::fail_point!("sk-migration-after-step-5", |_| { + Err(ApiError::BadRequest(anyhow::anyhow!( + "failpoint sk-migration-after-step-5" + ))) + }); + // 6. Call POST bump_term(sync_term) on safekeepers from the new set. Success on majority is enough. // TODO(diko): do we need to bump timeline term? @@ -1285,9 +1360,16 @@ impl Service { &new_safekeepers, &joint_config, Some(sync_position), + false, // we're just waiting for sync position, don't update notified generation ) .await?; + fail::fail_point!("sk-migration-after-step-7", |_| { + Err(ApiError::BadRequest(anyhow::anyhow!( + "failpoint sk-migration-after-step-7" + ))) + }); + // 8. Create new_conf: Configuration incrementing joint_conf generation and // having new safekeeper set as sk_set and None new_sk_set. @@ -1299,45 +1381,55 @@ impl Service { new_members: None, }; - self.persistence - .update_timeline_membership(tenant_id, timeline_id, generation, &new_sk_set, None) - .await?; - - // TODO(diko): at this point we have already updated the timeline in the database, - // but we still need to notify safekeepers and cplane about the new configuration, - // and put delition of the timeline from the old safekeepers into the reconciler. - // Ideally it should be done atomically, but now it's not. - // Worst case: the timeline is not deleted from old safekeepers, - // the compute may require both quorums till the migration is retried and completed. 
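// Editor's note (schematic, field names as in membership::Configuration):
// step 8 collapses the joint configuration into the final one, roughly:
//   joint: Configuration { generation: g,   members: old_set, new_members: Some(new_set) }
//   final: Configuration { generation: g+1, members: new_set, new_members: None }
// Only after the final configuration is committed is the migration no longer abortable.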
-        self.persistence
-            .update_timeline_membership(tenant_id, timeline_id, generation, &new_sk_set, None)
-            .await?;
-
-        // TODO(diko): at this point we have already updated the timeline in the database,
-        // but we still need to notify safekeepers and cplane about the new configuration,
-        // and put delition of the timeline from the old safekeepers into the reconciler.
-        // Ideally it should be done atomically, but now it's not.
-        // Worst case: the timeline is not deleted from old safekeepers,
-        // the compute may require both quorums till the migration is retried and completed.
-
-        self.tenant_timeline_set_membership_quorum(
-            tenant_id,
-            timeline_id,
-            &new_safekeepers,
-            &new_conf,
-            None, // no min position
-        )
-        .await?;
-
         let new_ids: HashSet<NodeId> = new_safekeepers.iter().map(|sk| sk.get_id()).collect();
         let exclude_safekeepers = cur_safekeepers
             .into_iter()
             .filter(|sk| !new_ids.contains(&sk.get_id()))
             .collect::<Vec<_>>();
 
-        self.tenant_timeline_safekeeper_exclude(
+        let exclude_requests = exclude_safekeepers
+            .iter()
+            .map(|sk| TimelinePendingOpPersistence {
+                sk_id: sk.skp.id,
+                tenant_id: tenant_id.to_string(),
+                timeline_id: timeline_id.to_string(),
+                generation: generation.into_inner() as i32,
+                op_kind: SafekeeperTimelineOpKind::Exclude,
+            })
+            .collect::<Vec<_>>();
+
+        self.persistence
+            .update_timeline_membership(
+                tenant_id,
+                timeline_id,
+                generation,
+                &new_sk_set,
+                None,
+                &exclude_requests,
+            )
+            .await?;
+
+        fail::fail_point!("sk-migration-after-step-8", |_| {
+            Err(ApiError::BadRequest(anyhow::anyhow!(
+                "failpoint sk-migration-after-step-8"
+            )))
+        });
+
+        // At this point we have already updated the timeline in the database, so the final
+        // membership configuration is committed and the migration is not abortable anymore.
+        // But safekeepers and cplane/compute still need to be notified about the new configuration.
+        // The [`Self::finish_safekeeper_migration`] does exactly that: notifies everyone about
+        // the new configuration and reconciles excluded safekeepers.
+        // If it fails, the safekeeper migration call should be retried.
+
+        self.finish_safekeeper_migration(
             tenant_id,
             timeline_id,
-            &exclude_safekeepers,
+            &new_safekeepers,
             &new_conf,
+            &exclude_safekeepers,
         )
         .await?;
 
-        // Notify cplane/compute about the membership change AFTER changing the membership on safekeepers.
-        // This way the compute will stop talking to excluded safekeepers only after we stop requiring to
-        // collect a quorum from them.
-        self.cplane_notify_safekeepers(tenant_id, timeline_id, &new_conf)
-            .await?;
-
         Ok(())
     }
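// Editor's note: a hedged sketch of how the failpoints introduced in this
// change can be driven through the `fail` crate's public API (failpoint name
// real, wiring illustrative; real tests may configure failpoints over the
// storage controller's HTTP API instead):
#[cfg(test)]
fn _trip_migration_failpoint_sketch() {
    // fail the migration once, right after step 8 commits the final mconf ...
    fail::cfg("sk-migration-after-step-8", "1*return").unwrap();
    // ... exercise the migration here, assert the failure, then clear it:
    fail::cfg("sk-migration-after-step-8", "off").unwrap();
}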
+
+        self.tenant_timeline_set_membership_quorum(
+            tenant_id,
+            timeline_id,
+            new_safekeepers,
+            new_conf,
+            None, // no min position
+            true, // update notified generation
+        )
+        .await?;
+
+        fail::fail_point!("sk-migration-step-9-after-set-membership", |_| {
+            Err(ApiError::BadRequest(anyhow::anyhow!(
+                "failpoint sk-migration-step-9-after-set-membership"
+            )))
+        });
+
+        self.tenant_timeline_safekeeper_exclude_reconcile(
+            tenant_id,
+            timeline_id,
+            exclude_safekeepers,
+            new_conf,
+        )
+        .await?;
+
+        fail::fail_point!("sk-migration-step-9-after-exclude", |_| {
+            Err(ApiError::BadRequest(anyhow::anyhow!(
+                "failpoint sk-migration-step-9-after-exclude"
+            )))
+        });
+
+        // Notify cplane/compute about the membership change AFTER changing the membership on safekeepers.
+        // This way the compute will stop talking to excluded safekeepers only after we stop requiring to
+        // collect a quorum from them.
+        self.cplane_notify_safekeepers(tenant_id, timeline_id, new_conf)
+            .await?;
+
+        fail::fail_point!("sk-migration-after-step-9", |_| {
+            Err(ApiError::BadRequest(anyhow::anyhow!(
+                "failpoint sk-migration-after-step-9"
+            )))
+        });
+
+        Ok(())
+    }
+
+    /// Same as [`Self::finish_safekeeper_migration`], but restores the migration state from the database.
+    /// It's used when the migration failed during the finish step and we need to retry it.
+    async fn finish_safekeeper_migration_retry(
+        self: &Arc<Self>,
+        tenant_id: TenantId,
+        timeline_id: TimelineId,
+        timeline: &TimelinePersistence,
+    ) -> Result<(), ApiError> {
+        if timeline.new_sk_set.is_some() {
+            // Logical error, should never happen.
+            return Err(ApiError::InternalServerError(anyhow::anyhow!(
+                "can't finish timeline migration for {tenant_id}/{timeline_id}: new_sk_set is not None"
+            )));
+        }
+
+        let cur_safekeepers = self.get_safekeepers(&timeline.sk_set)?;
+        let cur_sk_member_set =
+            Self::make_member_set(&cur_safekeepers).map_err(ApiError::InternalServerError)?;
+
+        let mconf = membership::Configuration {
+            generation: SafekeeperGeneration::new(timeline.generation as u32),
+            members: cur_sk_member_set,
+            new_members: None,
+        };
+
+        // We might have failed between committing reconciliation requests and adding them to the in-memory reconciler.
+        // Reload them from the database.
+ let pending_ops = self + .persistence + .list_pending_ops_for_timeline(tenant_id, timeline_id) + .await?; + + let mut exclude_sk_ids = Vec::new(); + + for op in pending_ops { + if op.op_kind == SafekeeperTimelineOpKind::Exclude + && op.generation == timeline.generation + { + exclude_sk_ids.push(op.sk_id); + } + } + + let exclude_safekeepers = self.get_safekeepers(&exclude_sk_ids)?; + + self.finish_safekeeper_migration( + tenant_id, + timeline_id, + &cur_safekeepers, + &mconf, + &exclude_safekeepers, + ) + .await?; + + Ok(()) } } diff --git a/test_runner/fixtures/endpoint/http.py b/test_runner/fixtures/endpoint/http.py index 1d278095ce..c43445e89d 100644 --- a/test_runner/fixtures/endpoint/http.py +++ b/test_runner/fixtures/endpoint/http.py @@ -66,6 +66,12 @@ class EndpointHttpClient(requests.Session): res.raise_for_status() return res.json() + def autoscaling_metrics(self): + res = self.get(f"http://localhost:{self.external_port}/autoscaling_metrics") + res.raise_for_status() + log.debug("raw compute metrics: %s", res.text) + return res.text + def prewarm_lfc_status(self) -> dict[str, str]: res = self.get(self.prewarm_url) res.raise_for_status() diff --git a/test_runner/fixtures/neon_api.py b/test_runner/fixtures/neon_api.py index bb618325e0..b26bcb286c 100644 --- a/test_runner/fixtures/neon_api.py +++ b/test_runner/fixtures/neon_api.py @@ -24,6 +24,7 @@ def connection_parameters_to_env(params: dict[str, str]) -> dict[str, str]: # Some API calls not yet implemented. # You may want to copy not-yet-implemented methods from the PR https://github.com/neondatabase/neon/pull/11305 +@final class NeonAPI: def __init__(self, neon_api_key: str, neon_api_base_url: str): self.__neon_api_key = neon_api_key @@ -170,7 +171,7 @@ class NeonAPI: protected: bool | None = None, archived: bool | None = None, init_source: str | None = None, - add_endpoint=True, + add_endpoint: bool = True, ) -> dict[str, Any]: data: dict[str, Any] = {} if add_endpoint: diff --git a/test_runner/fixtures/neon_cli.py b/test_runner/fixtures/neon_cli.py index f33d4a0d22..5ad00d155e 100644 --- a/test_runner/fixtures/neon_cli.py +++ b/test_runner/fixtures/neon_cli.py @@ -400,6 +400,7 @@ class NeonLocalCli(AbstractNeonCli): timeout_in_seconds: int | None = None, instance_id: int | None = None, base_port: int | None = None, + handle_ps_local_disk_loss: bool | None = None, ): cmd = ["storage_controller", "start"] if timeout_in_seconds is not None: @@ -408,6 +409,10 @@ class NeonLocalCli(AbstractNeonCli): cmd.append(f"--instance-id={instance_id}") if base_port is not None: cmd.append(f"--base-port={base_port}") + if handle_ps_local_disk_loss is not None: + cmd.append( + f"--handle-ps-local-disk-loss={'true' if handle_ps_local_disk_loss else 'false'}" + ) return self.raw_cli(cmd) def storage_controller_stop(self, immediate: bool, instance_id: int | None = None): diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index fc33fb45c1..ee0a2f4fe9 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -1938,9 +1938,12 @@ class NeonStorageController(MetricsGetter, LogUtils): timeout_in_seconds: int | None = None, instance_id: int | None = None, base_port: int | None = None, + handle_ps_local_disk_loss: bool | None = None, ) -> Self: assert not self.running - self.env.neon_cli.storage_controller_start(timeout_in_seconds, instance_id, base_port) + self.env.neon_cli.storage_controller_start( + timeout_in_seconds, instance_id, base_port, handle_ps_local_disk_loss + ) 
self.running = True return self @@ -2838,10 +2841,13 @@ class NeonProxiedStorageController(NeonStorageController): timeout_in_seconds: int | None = None, instance_id: int | None = None, base_port: int | None = None, + handle_ps_local_disk_loss: bool | None = None, ) -> Self: assert instance_id is not None and base_port is not None - self.env.neon_cli.storage_controller_start(timeout_in_seconds, instance_id, base_port) + self.env.neon_cli.storage_controller_start( + timeout_in_seconds, instance_id, base_port, handle_ps_local_disk_loss + ) self.instances[instance_id] = {"running": True} self.running = True @@ -4121,6 +4127,294 @@ class NeonAuthBroker: self._popen.kill() +class NeonLocalProxy(LogUtils): + """ + An object managing a local_proxy instance for rest broker testing. + The local_proxy serves as a direct connection to VanillaPostgres. + """ + + def __init__( + self, + neon_binpath: Path, + test_output_dir: Path, + http_port: int, + metrics_port: int, + vanilla_pg: VanillaPostgres, + config_path: Path | None = None, + ): + self.neon_binpath = neon_binpath + self.test_output_dir = test_output_dir + self.http_port = http_port + self.metrics_port = metrics_port + self.vanilla_pg = vanilla_pg + self.config_path = config_path or (test_output_dir / "local_proxy.json") + self.host = "127.0.0.1" + self.running = False + self.logfile = test_output_dir / "local_proxy.log" + self._popen: subprocess.Popen[bytes] | None = None + super().__init__(logfile=self.logfile) + + def start(self) -> Self: + assert self._popen is None + assert not self.running + + # Ensure vanilla_pg is running + if not self.vanilla_pg.is_running(): + self.vanilla_pg.start() + + args = [ + str(self.neon_binpath / "local_proxy"), + "--http", + f"{self.host}:{self.http_port}", + "--metrics", + f"{self.host}:{self.metrics_port}", + "--postgres", + f"127.0.0.1:{self.vanilla_pg.default_options['port']}", + "--config-path", + str(self.config_path), + "--disable-pg-session-jwt", + ] + + logfile = open(self.logfile, "w") + self._popen = subprocess.Popen(args, stdout=logfile, stderr=logfile) + self.running = True + self._wait_until_ready() + return self + + def stop(self) -> Self: + if self._popen is not None and self.running: + self._popen.terminate() + try: + self._popen.wait(timeout=5) + except subprocess.TimeoutExpired: + log.warning("failed to gracefully terminate local_proxy; killing") + self._popen.kill() + self.running = False + return self + + def get_binary_version(self) -> str: + """Get the version string of the local_proxy binary""" + try: + result = subprocess.run( + [str(self.neon_binpath / "local_proxy"), "--version"], + capture_output=True, + text=True, + timeout=10, + ) + return result.stdout.strip() + except (subprocess.TimeoutExpired, subprocess.CalledProcessError): + return "" + + @backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=10) + def _wait_until_ready(self): + assert self._popen and self._popen.poll() is None, ( + "Local proxy exited unexpectedly. Check test log." 
+ ) + requests.get(f"http://{self.host}:{self.http_port}/metrics") + + def get_metrics(self) -> str: + response = requests.get(f"http://{self.host}:{self.metrics_port}/metrics") + return response.text + + def assert_no_errors(self): + # Define allowed error patterns for local_proxy + allowed_errors = [ + # Add patterns as needed + ] + not_allowed = [ + "error", + "panic", + "failed", + ] + + for na in not_allowed: + if na not in allowed_errors: + assert not self.log_contains(na), f"Found disallowed error pattern: {na}" + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ): + self.stop() + + +class NeonRestBrokerProxy(LogUtils): + """ + An object managing a proxy instance configured as both auth broker and rest broker. + This is the main proxy binary with --is-auth-broker and --is-rest-broker flags. + """ + + def __init__( + self, + neon_binpath: Path, + test_output_dir: Path, + wss_port: int, + http_port: int, + mgmt_port: int, + config_path: Path | None = None, + ): + self.neon_binpath = neon_binpath + self.test_output_dir = test_output_dir + self.wss_port = wss_port + self.http_port = http_port + self.mgmt_port = mgmt_port + self.config_path = config_path or (test_output_dir / "rest_broker_proxy.json") + self.host = "127.0.0.1" + self.running = False + self.logfile = test_output_dir / "rest_broker_proxy.log" + self._popen: subprocess.Popen[Any] | None = None + + def start(self) -> Self: + if self.running: + return self + + # Generate self-signed TLS certificates + cert_path = self.test_output_dir / "server.crt" + key_path = self.test_output_dir / "server.key" + + if not cert_path.exists() or not key_path.exists(): + import subprocess + + log.info("Generating self-signed TLS certificate for rest broker") + subprocess.run( + [ + "openssl", + "req", + "-new", + "-x509", + "-days", + "365", + "-nodes", + "-text", + "-out", + str(cert_path), + "-keyout", + str(key_path), + "-subj", + "/CN=*.local.neon.build", + ], + check=True, + ) + + log.info( + f"Starting rest broker proxy on WSS port {self.wss_port}, HTTP port {self.http_port}" + ) + + cmd = [ + str(self.neon_binpath / "proxy"), + "-c", + str(cert_path), + "-k", + str(key_path), + "--is-auth-broker", + "true", + "--is-rest-broker", + "true", + "--wss", + f"{self.host}:{self.wss_port}", + "--http", + f"{self.host}:{self.http_port}", + "--mgmt", + f"{self.host}:{self.mgmt_port}", + "--auth-backend", + "local", + "--config-path", + str(self.config_path), + ] + + log.info(f"Starting rest broker proxy with command: {' '.join(cmd)}") + + with open(self.logfile, "w") as logfile: + self._popen = subprocess.Popen( + cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + cwd=self.test_output_dir, + env={ + **os.environ, + "RUST_LOG": "info", + "LOGFMT": "text", + "OTEL_SDK_DISABLED": "true", + }, + ) + + self.running = True + self._wait_until_ready() + return self + + def stop(self) -> Self: + if not self.running: + return self + + log.info("Stopping rest broker proxy") + + if self._popen is not None: + self._popen.terminate() + try: + self._popen.wait(timeout=10) + except subprocess.TimeoutExpired: + log.warning("failed to gracefully terminate rest broker proxy; killing") + self._popen.kill() + + self.running = False + return self + + def get_binary_version(self) -> str: + cmd = [str(self.neon_binpath / "proxy"), "--version"] + res = subprocess.run(cmd, capture_output=True, text=True, check=True) + return 
res.stdout.strip() + + @backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=10) + def _wait_until_ready(self): + # Check if the WSS port is ready using a simple HTTPS request + # REST API is served on the WSS port with HTTPS + requests.get(f"https://{self.host}:{self.wss_port}/", timeout=1, verify=False) + # Any response (even error) means the server is up - we just need to connect + + def get_metrics(self) -> str: + # Metrics are still on the HTTP port + response = requests.get(f"http://{self.host}:{self.http_port}/metrics", timeout=5) + response.raise_for_status() + return response.text + + def assert_no_errors(self): + # Define allowed error patterns for rest broker proxy + allowed_errors = [ + "connection closed before message completed", + "connection reset by peer", + "broken pipe", + "client disconnected", + "Authentication failed", + "connection timed out", + "no connection available", + "Pool dropped", + ] + + with open(self.logfile) as f: + for line in f: + if "ERROR" in line or "FATAL" in line: + if not any(allowed in line for allowed in allowed_errors): + raise AssertionError( + f"Found error in rest broker proxy log: {line.strip()}" + ) + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ): + self.stop() + + @pytest.fixture(scope="function") def link_proxy( port_distributor: PortDistributor, neon_binpath: Path, test_output_dir: Path @@ -4203,6 +4497,81 @@ def static_proxy( yield proxy +@pytest.fixture(scope="function") +def local_proxy( + vanilla_pg: VanillaPostgres, + port_distributor: PortDistributor, + neon_binpath: Path, + test_output_dir: Path, +) -> Iterator[NeonLocalProxy]: + """Local proxy that connects directly to vanilla postgres for rest broker testing.""" + + # Start vanilla_pg without database bootstrapping + vanilla_pg.start() + + http_port = port_distributor.get_port() + metrics_port = port_distributor.get_port() + + with NeonLocalProxy( + neon_binpath=neon_binpath, + test_output_dir=test_output_dir, + http_port=http_port, + metrics_port=metrics_port, + vanilla_pg=vanilla_pg, + ) as proxy: + proxy.start() + yield proxy + + +@pytest.fixture(scope="function") +def local_proxy_fixed_port( + vanilla_pg: VanillaPostgres, + neon_binpath: Path, + test_output_dir: Path, +) -> Iterator[NeonLocalProxy]: + """Local proxy that connects directly to vanilla postgres on the hardcoded port 7432.""" + + # Start vanilla_pg without database bootstrapping + vanilla_pg.start() + + # Use the hardcoded port that the rest broker proxy expects + http_port = 7432 + metrics_port = 7433 # Use a different port for metrics + + with NeonLocalProxy( + neon_binpath=neon_binpath, + test_output_dir=test_output_dir, + http_port=http_port, + metrics_port=metrics_port, + vanilla_pg=vanilla_pg, + ) as proxy: + proxy.start() + yield proxy + + +@pytest.fixture(scope="function") +def rest_broker_proxy( + port_distributor: PortDistributor, + neon_binpath: Path, + test_output_dir: Path, +) -> Iterator[NeonRestBrokerProxy]: + """Rest broker proxy that handles both auth broker and rest broker functionality.""" + + wss_port = port_distributor.get_port() + http_port = port_distributor.get_port() + mgmt_port = port_distributor.get_port() + + with NeonRestBrokerProxy( + neon_binpath=neon_binpath, + test_output_dir=test_output_dir, + wss_port=wss_port, + http_port=http_port, + mgmt_port=mgmt_port, + ) as proxy: + proxy.start() + yield proxy + + 
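The fixtures above wire the pieces together: `rest_broker_proxy` terminates TLS and JWT auth on its WSS port and forwards REST calls to `local_proxy_fixed_port`, which listens on the hardcoded port 7432 and talks to vanilla postgres. A minimal smoke test composed from just these fixtures could look like the sketch below; it only checks liveness, and the full JWKS/JWT flow is exercised in test_rest_broker.py further down.

import requests

def test_rest_stack_smoke(local_proxy_fixed_port, rest_broker_proxy):
    # The rest broker serves HTTPS on the WSS port with a self-signed certificate.
    resp = requests.get(
        f"https://127.0.0.1:{rest_broker_proxy.wss_port}/", verify=False, timeout=5
    )
    assert resp.status_code, "rest broker proxy is not answering on its WSS port"

    # The local proxy exposes prometheus metrics on its dedicated metrics port.
    assert local_proxy_fixed_port.get_metrics(), "local_proxy returned no metrics"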
@pytest.fixture(scope="function") def neon_authorize_jwk() -> jwk.JWK: kid = str(uuid.uuid4()) @@ -5435,7 +5804,7 @@ SKIP_FILES = frozenset( "postmaster.pid", "pg_control", "pg_dynshmem", - ".metrics.socket", + "neon-communicator.socket", ) ) diff --git a/test_runner/fixtures/pageserver/allowed_errors.py b/test_runner/fixtures/pageserver/allowed_errors.py index 59249f31ad..007f80ee5e 100755 --- a/test_runner/fixtures/pageserver/allowed_errors.py +++ b/test_runner/fixtures/pageserver/allowed_errors.py @@ -152,6 +152,8 @@ DEFAULT_STORAGE_CONTROLLER_ALLOWED_ERRORS = [ ".*reconciler.*neon_local error.*", # Tenant rate limits may fire in tests that submit lots of API requests. ".*tenant \\S+ is rate limited.*", + # Reconciliations may get stuck/delayed e.g. in chaos tests. + ".*background_reconcile: Shard reconciliation is stuck.*", ] diff --git a/test_runner/fixtures/utils.py b/test_runner/fixtures/utils.py index 0d7345cc82..1f80c2a290 100644 --- a/test_runner/fixtures/utils.py +++ b/test_runner/fixtures/utils.py @@ -741,3 +741,29 @@ def shared_buffers_for_max_cu(max_cu: float) -> str: sharedBuffersMb = int(max(128, (1023 + maxBackends * 256) / 1024)) sharedBuffers = int(sharedBuffersMb * 1024 / 8) return str(sharedBuffers) + + +def skip_if_proxy_lacks_rest_broker(reason: str = "proxy was built without 'rest_broker' feature"): + # Determine the binary path using the same logic as neon_binpath fixture + def has_rest_broker_feature(): + # Find the neon binaries + if env_neon_bin := os.environ.get("NEON_BIN"): + binpath = Path(env_neon_bin) + else: + base_dir = Path(__file__).parents[2] # Same as BASE_DIR in paths.py + build_type = os.environ.get("BUILD_TYPE", "debug") + binpath = base_dir / "target" / build_type + + proxy_bin = binpath / "proxy" + if not proxy_bin.exists(): + return False + + try: + cmd = [str(proxy_bin), "--help"] + result = subprocess.run(cmd, capture_output=True, text=True, check=True, timeout=10) + help_output = result.stdout + return "--is-rest-broker" in help_output + except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError): + return False + + return pytest.mark.skipif(not has_rest_broker_feature(), reason=reason) diff --git a/test_runner/regress/test_communicator_metrics_exporter.py b/test_runner/regress/test_communicator_metrics_exporter.py new file mode 100644 index 0000000000..0e3e76910a --- /dev/null +++ b/test_runner/regress/test_communicator_metrics_exporter.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING + +import pytest +import requests +import requests_unixsocket # type: ignore [import-untyped] +from fixtures.metrics import parse_metrics + +if TYPE_CHECKING: + from fixtures.neon_fixtures import NeonEnv + +NEON_COMMUNICATOR_SOCKET_NAME = "neon-communicator.socket" + + +def test_communicator_metrics(neon_simple_env: NeonEnv): + """ + Test the communicator's built-in HTTP prometheus exporter + """ + env = neon_simple_env + + endpoint = env.endpoints.create("main") + endpoint.start() + + # Change current directory to the data directory, so that we can use + # a short relative path to refer to the socket. (There's a 100 char + # limitation on the path.) + os.chdir(str(endpoint.pgdata_dir)) + session = requests_unixsocket.Session() + r = session.get(f"http+unix://{NEON_COMMUNICATOR_SOCKET_NAME}/metrics") + assert r.status_code == 200, f"got response {r.status_code}: {r.text}" + + # quick test that the endpoint returned something expected. 
(We don't validate
+    # that the metrics returned are sensible.)
+    m = parse_metrics(r.text)
+    m.query_one("lfc_hits")
+    m.query_one("lfc_misses")
+
+    # Test panic handling. The /debug/panic endpoint raises a Rust panic. It's
+    # expected to unwind and drop the HTTP connection without a response, but not
+    # kill the process or the server.
+    with pytest.raises(
+        requests.ConnectionError, match="Remote end closed connection without response"
+    ):
+        session.get(f"http+unix://{NEON_COMMUNICATOR_SOCKET_NAME}/debug/panic")
+        # (no response object to assert on: the request above raises before returning)
+
+    # Test that subsequent requests after the panic still work.
+    r = session.get(f"http+unix://{NEON_COMMUNICATOR_SOCKET_NAME}/metrics")
+    assert r.status_code == 200, f"got response {r.status_code}: {r.text}"
+    m = parse_metrics(r.text)
+    m.query_one("lfc_hits")
+    m.query_one("lfc_misses")
diff --git a/test_runner/regress/test_compatibility.py b/test_runner/regress/test_compatibility.py
index a3a20cdc62..734887c5b3 100644
--- a/test_runner/regress/test_compatibility.py
+++ b/test_runner/regress/test_compatibility.py
@@ -197,7 +197,7 @@ def test_create_snapshot(
     shutil.copytree(
         test_output_dir,
         new_compatibility_snapshot_dir,
-        ignore=shutil.ignore_patterns("pg_dynshmem"),
+        ignore=shutil.ignore_patterns("pg_dynshmem", "neon-communicator.socket"),
     )
 
     log.info(f"Copied new compatibility snapshot dir to: {new_compatibility_snapshot_dir}")
diff --git a/test_runner/regress/test_hcc_handling_ps_data_loss.py b/test_runner/regress/test_hcc_handling_ps_data_loss.py
new file mode 100644
index 0000000000..35d3b72923
--- /dev/null
+++ b/test_runner/regress/test_hcc_handling_ps_data_loss.py
@@ -0,0 +1,47 @@
+import shutil
+
+from fixtures.neon_fixtures import NeonEnvBuilder
+from fixtures.utils import query_scalar
+
+
+def test_hcc_handling_ps_data_loss(
+    neon_env_builder: NeonEnvBuilder,
+):
+    """
+    Test that following a pageserver local data loss event, the system can recover automatically (i.e.
+    rehydrating the restarted pageserver from remote storage) without manual intervention. The
+    pageserver indicates to the storage controller that it has restarted without any local tenant
+    data in its "reattach" request, and the storage controller uses this information to detect the
+    data loss condition and reconfigure the pageserver as necessary.
+    """
+    env = neon_env_builder.init_configs()
+    env.broker.start()
+    env.storage_controller.start(handle_ps_local_disk_loss=True)
+    env.pageserver.start()
+    for sk in env.safekeepers:
+        sk.start()
+
+    # create a new tenant
+    tenant_id, _ = env.create_tenant(shard_count=4)
+
+    endpoint = env.endpoints.create_start("main", tenant_id=tenant_id)
+    with endpoint.cursor() as cur:
+        cur.execute("SELECT pg_logical_emit_message(false, 'neon-test', 'between inserts')")
+        cur.execute("CREATE DATABASE testdb")
+
+    with endpoint.cursor(dbname="testdb") as cur:
+        cur.execute("CREATE TABLE tbl_one_hundred_rows AS SELECT generate_series(1,100)")
+    endpoint.stop()
+
+    # Kill the pageserver, remove the `tenants/` directory, and restart. This simulates a pageserver
+    # that restarted with the same ID but has lost all its local disk data.
+    env.pageserver.stop(immediate=True)
+    shutil.rmtree(env.pageserver.tenant_dir())
+    env.pageserver.start()
+
+    # Test that the endpoint can start and query the database after the pageserver restarts. This
+    # indirectly tests that the pageserver was able to rehydrate the tenant data it lost from remote
+    # storage automatically.
+ endpoint.start() + with endpoint.cursor(dbname="testdb") as cur: + assert query_scalar(cur, "SELECT count(*) FROM tbl_one_hundred_rows") == 100 diff --git a/test_runner/regress/test_lfc_working_set_approximation.py b/test_runner/regress/test_lfc_working_set_approximation.py index a28bc3d047..2ee15b60fd 100644 --- a/test_runner/regress/test_lfc_working_set_approximation.py +++ b/test_runner/regress/test_lfc_working_set_approximation.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING import pytest from fixtures.log_helper import log +from fixtures.metrics import parse_metrics from fixtures.utils import USE_LFC, query_scalar if TYPE_CHECKING: @@ -75,10 +76,24 @@ WITH (fillfactor='100'); cur.execute("SELECT abalance FROM pgbench_accounts WHERE aid = 104242") cur.execute("SELECT abalance FROM pgbench_accounts WHERE aid = 204242") # verify working set size after some index access of a few select pages only - blocks = query_scalar(cur, "select approximate_working_set_size(true)") + blocks = query_scalar(cur, "select approximate_working_set_size(false)") log.info(f"working set size after some index access of a few select pages only {blocks}") assert blocks < 20 + # Also test the metrics from the /autoscaling_metrics endpoint + autoscaling_metrics = endpoint.http_client().autoscaling_metrics() + log.debug(f"Raw metrics: {autoscaling_metrics}") + m = parse_metrics(autoscaling_metrics) + + http_estimate = m.query_one( + "lfc_approximate_working_set_size_windows", + { + "duration_seconds": "60", + }, + ).value + log.info(f"http estimate: {http_estimate}, blocks: {blocks}") + assert http_estimate > 0 and http_estimate < 20 + @pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping") def test_sliding_working_set_approximation(neon_simple_env: NeonEnv): diff --git a/test_runner/regress/test_pg_regress.py b/test_runner/regress/test_pg_regress.py index 728241b465..dd9c5437ad 100644 --- a/test_runner/regress/test_pg_regress.py +++ b/test_runner/regress/test_pg_regress.py @@ -3,6 +3,7 @@ # from __future__ import annotations +import time from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Any, cast @@ -356,6 +357,81 @@ def test_sql_regress( post_checks(env, test_output_dir, DBNAME, endpoint) +def test_max_wal_rate(neon_simple_env: NeonEnv): + """ + Test the databricks.max_wal_mb_per_second GUC and how it affects WAL rate + limiting. + """ + env = neon_simple_env + + DBNAME = "regression" + superuser_name = "databricks_superuser" + + # Connect to postgres and create a database called "regression". + endpoint = env.endpoints.create_start( + "main", + config_lines=[ + # we need this option because default max_cluster_size < 0 will disable throttling completely + "neon.max_cluster_size=10GB", + ], + ) + + endpoint.safe_psql_many( + [ + f"CREATE ROLE {superuser_name}", + f"CREATE DATABASE {DBNAME}", + "CREATE EXTENSION neon", + ] + ) + + endpoint.safe_psql("CREATE TABLE usertable (YCSB_KEY INT, FIELD0 TEXT);", dbname=DBNAME) + + # Write ~1 MB data. + with endpoint.cursor(dbname=DBNAME) as cur: + for _ in range(0, 1000): + cur.execute("INSERT INTO usertable SELECT random(), repeat('a', 1000);") + + # No backpressure + tuples = endpoint.safe_psql("SELECT backpressure_throttling_time();") + assert tuples[0][0] == 0, "Backpressure throttling detected" + + # 0 MB/s max_wal_rate. WAL proposer can still push some WALs but will be super slow. 
+    endpoint.safe_psql_many(
+        [
+            "ALTER SYSTEM SET databricks.max_wal_mb_per_second = 0;",
+            "SELECT pg_reload_conf();",
+        ]
+    )
+
+    # Writing ~10 KB of data should hit backpressure.
+    with endpoint.cursor(dbname=DBNAME) as cur:
+        cur.execute("SET databricks.max_wal_mb_per_second = 0;")
+        for _ in range(0, 10):
+            cur.execute("INSERT INTO usertable SELECT random(), repeat('a', 1000);")
+
+    tuples = endpoint.safe_psql("SELECT backpressure_throttling_time();")
+    assert tuples[0][0] > 0, "No backpressure throttling detected"
+
+    # 1 MB/s max_wal_rate.
+    endpoint.safe_psql_many(
+        [
+            "ALTER SYSTEM SET databricks.max_wal_mb_per_second = 1;",
+            "SELECT pg_reload_conf();",
+        ]
+    )
+
+    # Write 10 MB of data.
+    with endpoint.cursor(dbname=DBNAME) as cur:
+        start = int(time.time())
+        for _ in range(0, 10000):
+            cur.execute("INSERT INTO usertable SELECT random(), repeat('a', 1000);")
+
+    end = int(time.time())
+    assert end - start >= 10, (
+        "Throttling should cause the previous inserts to take at least 10 seconds"
+    )
+
+
 @skip_in_debug_build("only run with release build")
 @pytest.mark.parametrize("reldir_type", ["v1", "v2"])
 def test_tx_abort_with_many_relations(
diff --git a/test_runner/regress/test_rest_broker.py b/test_runner/regress/test_rest_broker.py
new file mode 100644
index 0000000000..60b04655d3
--- /dev/null
+++ b/test_runner/regress/test_rest_broker.py
@@ -0,0 +1,137 @@
+import json
+import signal
+import time
+
+import requests
+from fixtures.utils import skip_if_proxy_lacks_rest_broker
+from jwcrypto import jwt
+
+
+@skip_if_proxy_lacks_rest_broker()
+def test_rest_broker_happy(
+    local_proxy_fixed_port, rest_broker_proxy, vanilla_pg, neon_authorize_jwk, httpserver
+):
+    """Test REST API endpoint using local_proxy and rest_broker_proxy."""
+
+    # Use the fixed-port local proxy
+    local_proxy = local_proxy_fixed_port
+
+    # Create the required roles for PostgREST authentication
+    vanilla_pg.safe_psql("CREATE ROLE authenticator LOGIN")
+    vanilla_pg.safe_psql("CREATE ROLE authenticated")
+    vanilla_pg.safe_psql("CREATE ROLE anon")
+    vanilla_pg.safe_psql("GRANT authenticated TO authenticator")
+    vanilla_pg.safe_psql("GRANT anon TO authenticator")
+
+    # Create the pgrst schema and configuration function required by the rest broker
+    vanilla_pg.safe_psql("CREATE SCHEMA IF NOT EXISTS pgrst")
+    vanilla_pg.safe_psql("""
+        CREATE OR REPLACE FUNCTION pgrst.pre_config()
+        RETURNS VOID AS $$
+          SELECT
+              set_config('pgrst.db_schemas', 'test', true)
+            , set_config('pgrst.db_aggregates_enabled', 'true', true)
+            , set_config('pgrst.db_anon_role', 'anon', true)
+            , set_config('pgrst.jwt_aud', '', true)
+            , set_config('pgrst.jwt_secret', '', true)
+            , set_config('pgrst.jwt_role_claim_key', '."role"', true)
+
+        $$ LANGUAGE SQL;
+    """)
+    vanilla_pg.safe_psql("GRANT USAGE ON SCHEMA pgrst TO authenticator")
+    vanilla_pg.safe_psql("GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA pgrst TO authenticator")
+
+    # Bootstrap the database with test data
+    vanilla_pg.safe_psql("CREATE SCHEMA IF NOT EXISTS test")
+    vanilla_pg.safe_psql("""
+        CREATE TABLE IF NOT EXISTS test.items (
+            id SERIAL PRIMARY KEY,
+            name TEXT NOT NULL
+        )
+    """)
+    vanilla_pg.safe_psql("INSERT INTO test.items (name) VALUES ('test_item')")
+
+    # Grant access to the test schema for the authenticated role
+    vanilla_pg.safe_psql("GRANT USAGE ON SCHEMA test TO authenticated")
+    vanilla_pg.safe_psql("GRANT SELECT ON ALL TABLES IN SCHEMA test TO authenticated")
+
+    # Set up HTTP server to serve JWKS (like static_auth_broker)
+    # Generate public key from
the JWK
+    public_key = neon_authorize_jwk.export_public(as_dict=True)
+
+    # Set up the httpserver to serve the JWKS
+    httpserver.expect_request("/.well-known/jwks.json").respond_with_json({"keys": [public_key]})
+
+    # Create JWKS configuration for the rest broker proxy
+    jwks_config = {
+        "jwks": [
+            {
+                "id": "1",
+                "role_names": ["authenticator", "authenticated", "anon"],
+                "jwks_url": httpserver.url_for("/.well-known/jwks.json"),
+                "provider_name": "foo",
+                "jwt_audience": None,
+            }
+        ]
+    }
+
+    # Write the JWKS config to the config file that rest_broker_proxy expects
+    config_file = rest_broker_proxy.config_path
+    with open(config_file, "w") as f:
+        json.dump(jwks_config, f)
+
+    # Write the same config to the local_proxy config file
+    local_config_file = local_proxy.config_path
+    with open(local_config_file, "w") as f:
+        json.dump(jwks_config, f)
+
+    # Signal both proxies to reload their config
+    if rest_broker_proxy._popen is not None:
+        rest_broker_proxy._popen.send_signal(signal.SIGHUP)
+    if local_proxy._popen is not None:
+        local_proxy._popen.send_signal(signal.SIGHUP)
+    # Wait a bit for the config to reload
+    time.sleep(0.5)
+
+    # Generate a proper JWT token using the JWK (similar to test_auth_broker.py)
+    token = jwt.JWT(
+        header={"kid": neon_authorize_jwk.key_id, "alg": "RS256"},
+        claims={
+            "sub": "user",
+            "role": "authenticated",  # role that's in role_names
+            "exp": 9999999999,  # expires far in the future
+            "iat": 1000000000,  # issued at
+        },
+    )
+    token.make_signed_token(neon_authorize_jwk)
+
+    # Debug: print the JWT claims and config for troubleshooting
+    print(f"JWT claims: {token.claims}")
+    print(f"JWT header: {token.header}")
+    print(f"Config file contains: {jwks_config}")
+    print(f"Public key kid: {public_key.get('kid')}")
+
+    # Test a REST API call, following the SUBZERO.md pattern.
+    # The REST API is served on the WSS port over HTTPS, and the URL includes the database
+    # name (cf. a production host like ep-purple-glitter-adqior4l-pooler.c-2.us-east-1.aws.neon.tech).
+    url = f"https://foo.apirest.c-2.local.neon.build:{rest_broker_proxy.wss_port}/postgres/rest/v1/items"
+
+    response = requests.get(
+        url,
+        headers={
+            "Authorization": f"Bearer {token.serialize()}",
+        },
+        params={"id": "eq.1", "select": "name"},
+        verify=False,  # Skip SSL verification for self-signed certs
+    )
+
+    print(f"Response status: {response.status_code}")
+    print(f"Response headers: {response.headers}")
+    print(f"Response body: {response.text}")
+
+    # A 200 response means the whole chain worked:
+    # rest broker proxy -> local_proxy -> vanilla postgres.
+    assert response.status_code == 200, f"unexpected status {response.status_code}"
+
+    # check the response body
+    assert response.json() == [{"name": "test_item"}]
diff --git a/test_runner/regress/test_safekeeper_migration.py b/test_runner/regress/test_safekeeper_migration.py
index 170c1a3650..371bec0c62 100644
--- a/test_runner/regress/test_safekeeper_migration.py
+++ b/test_runner/regress/test_safekeeper_migration.py
@@ -3,11 +3,22 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING
 
 import pytest
+import requests
+from fixtures.log_helper import log
 from fixtures.neon_fixtures import StorageControllerApiException
 
 if TYPE_CHECKING:
     from fixtures.neon_fixtures import NeonEnvBuilder
 
+# TODO(diko): pageserver spams with various errors during safekeeper migration.
+# Fix the code so it handles the migration better.
+ALLOWED_PAGESERVER_ERRORS = [
+    ".*Timeline .* was cancelled and cannot be used anymore.*",
+    ".*Timeline .* has been deleted.*",
+    ".*Timeline .* was not found in global map.*",
+    ".*wal receiver task finished with an error.*",
+]
+
 
 def test_safekeeper_migration_simple(neon_env_builder: NeonEnvBuilder):
     """
@@ -24,16 +35,7 @@ def test_safekeeper_migration_simple(neon_env_builder: NeonEnvBuilder):
         "timeline_safekeeper_count": 1,
     }
     env = neon_env_builder.init_start()
-    # TODO(diko): pageserver spams with various errors during safekeeper migration.
-    # Fix the code so it handles the migration better.
-    env.pageserver.allowed_errors.extend(
-        [
-            ".*Timeline .* was cancelled and cannot be used anymore.*",
-            ".*Timeline .* has been deleted.*",
-            ".*Timeline .* was not found in global map.*",
-            ".*wal receiver task finished with an error.*",
-        ]
-    )
+    env.pageserver.allowed_errors.extend(ALLOWED_PAGESERVER_ERRORS)
 
     ep = env.endpoints.create("main", tenant_id=env.initial_tenant)
 
@@ -42,15 +44,23 @@ def test_safekeeper_migration_simple(neon_env_builder: NeonEnvBuilder):
     assert len(mconf["sk_set"]) == 1
     assert mconf["generation"] == 1
 
+    current_sk = mconf["sk_set"][0]
+
     ep.start(safekeeper_generation=1, safekeepers=mconf["sk_set"])
     ep.safe_psql("CREATE EXTENSION neon_test_utils;")
     ep.safe_psql("CREATE TABLE t(a int)")
 
+    expected_gen = 1
+
     for active_sk in range(1, 4):
         env.storage_controller.migrate_safekeepers(
             env.initial_tenant, env.initial_timeline, [active_sk]
         )
 
+        if active_sk != current_sk:
+            expected_gen += 2
+            current_sk = active_sk
+
         other_sks = [sk for sk in range(1, 4) if sk != active_sk]
 
         for sk in other_sks:
@@ -65,9 +75,6 @@ def test_safekeeper_migration_simple(neon_env_builder: NeonEnvBuilder):
 
     assert ep.safe_psql("SELECT * FROM t") == [(i,) for i in range(1, 4)]
 
-    # 1 initial generation + 2 migrations on each loop iteration.
-    expected_gen = 1 + 2 * 3
-
     mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline)
     assert mconf["generation"] == expected_gen
 
@@ -113,3 +120,79 @@ def test_new_sk_set_validation(neon_env_builder: NeonEnvBuilder):
 
         env.storage_controller.safekeeper_scheduling_policy(decom_sk, "Decomissioned")
 
         expect_fail([sk_set[0], decom_sk], "decomissioned")
+
+
+def test_safekeeper_migration_common_set_failpoints(neon_env_builder: NeonEnvBuilder):
+    """
+    Test that safekeeper migration handles failures well.
+
+    Two main conditions are checked:
+    1. the safekeeper migration handler can be retried after different failures.
+    2. writes do not get stuck if sk_set and new_sk_set have a quorum in common.
+ """ + neon_env_builder.num_safekeepers = 4 + neon_env_builder.storage_controller_config = { + "timelines_onto_safekeepers": True, + "timeline_safekeeper_count": 3, + } + env = neon_env_builder.init_start() + env.pageserver.allowed_errors.extend(ALLOWED_PAGESERVER_ERRORS) + + mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline) + assert len(mconf["sk_set"]) == 3 + assert mconf["generation"] == 1 + + ep = env.endpoints.create("main", tenant_id=env.initial_tenant) + ep.start(safekeeper_generation=1, safekeepers=mconf["sk_set"]) + ep.safe_psql("CREATE EXTENSION neon_test_utils;") + ep.safe_psql("CREATE TABLE t(a int)") + + excluded_sk = mconf["sk_set"][-1] + added_sk = [sk.id for sk in env.safekeepers if sk.id not in mconf["sk_set"]][0] + new_sk_set = mconf["sk_set"][:-1] + [added_sk] + log.info(f"migrating sk set from {mconf['sk_set']} to {new_sk_set}") + + failpoints = [ + "sk-migration-after-step-3", + "sk-migration-after-step-4", + "sk-migration-after-step-5", + "sk-migration-after-step-7", + "sk-migration-after-step-8", + "sk-migration-step-9-after-set-membership", + "sk-migration-step-9-mid-exclude", + "sk-migration-step-9-after-exclude", + "sk-migration-after-step-9", + ] + + for i, fp in enumerate(failpoints): + env.storage_controller.configure_failpoints((fp, "return(1)")) + + with pytest.raises(StorageControllerApiException, match=f"failpoint {fp}"): + env.storage_controller.migrate_safekeepers( + env.initial_tenant, env.initial_timeline, new_sk_set + ) + ep.safe_psql(f"INSERT INTO t VALUES ({i})") + + env.storage_controller.configure_failpoints((fp, "off")) + + # No failpoints, migration should succeed. + env.storage_controller.migrate_safekeepers(env.initial_tenant, env.initial_timeline, new_sk_set) + + mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline) + assert mconf["new_sk_set"] is None + assert mconf["sk_set"] == new_sk_set + assert mconf["generation"] == 3 + + ep.clear_buffers() + assert ep.safe_psql("SELECT * FROM t") == [(i,) for i in range(len(failpoints))] + assert ep.safe_psql("SHOW neon.safekeepers")[0][0].startswith("g#3:") + + # Check that we didn't forget to remove the timeline on the excluded safekeeper. + with pytest.raises(requests.exceptions.HTTPError) as exc: + env.safekeepers[excluded_sk - 1].http_client().timeline_status( + env.initial_tenant, env.initial_timeline + ) + assert exc.value.response.status_code == 404 + assert ( + f"timeline {env.initial_tenant}/{env.initial_timeline} deleted" in exc.value.response.text + ) diff --git a/test_runner/regress/test_sharding.py b/test_runner/regress/test_sharding.py index 5549105188..2252c098c7 100644 --- a/test_runner/regress/test_sharding.py +++ b/test_runner/regress/test_sharding.py @@ -1810,6 +1810,8 @@ def test_sharding_backpressure(neon_env_builder: NeonEnvBuilder): "config_lines": [ # Tip: set to 100MB to make the test fail "max_replication_write_lag=1MB", + # Hadron: Need to set max_cluster_size to some value to enable any backpressure at all. + "neon.max_cluster_size=1GB", ], # We need `neon` extension for calling backpressure functions, # this flag instructs `compute_ctl` to pre-install it. 
diff --git a/vendor/postgres-v14 b/vendor/postgres-v14 index 1cb207d1c9..4cacada8bd 160000 --- a/vendor/postgres-v14 +++ b/vendor/postgres-v14 @@ -1 +1 @@ -Subproject commit 1cb207d1c9efb1f6c6f864a47bf45e992a7f0eb0 +Subproject commit 4cacada8bde7f6424751a0727a657783c6a1d20b diff --git a/vendor/postgres-v15 b/vendor/postgres-v15 index 9d19780350..e5ee23d998 160000 --- a/vendor/postgres-v15 +++ b/vendor/postgres-v15 @@ -1 +1 @@ -Subproject commit 9d19780350c0c7b536312dc3b891ade55628bc7b +Subproject commit e5ee23d99874ea9f5b62f8acc7d076162ae95d6c diff --git a/vendor/postgres-v16 b/vendor/postgres-v16 index 1486f919d4..ad2b69b582 160000 --- a/vendor/postgres-v16 +++ b/vendor/postgres-v16 @@ -1 +1 @@ -Subproject commit 1486f919d4dc21637407ee7ed203497bb5bd516a +Subproject commit ad2b69b58230290fc44c08fbe0c97981c64f6c7d diff --git a/vendor/postgres-v17 b/vendor/postgres-v17 index 160d0b52d6..ba750903a9 160000 --- a/vendor/postgres-v17 +++ b/vendor/postgres-v17 @@ -1 +1 @@ -Subproject commit 160d0b52d66f4a5d21251a2912a50561bf600333 +Subproject commit ba750903a90dded8098f2f56d0b2a9012e6166af diff --git a/vendor/revisions.json b/vendor/revisions.json index 69e7559c67..d62f8e5736 100644 --- a/vendor/revisions.json +++ b/vendor/revisions.json @@ -1,18 +1,18 @@ { "v17": [ "17.5", - "160d0b52d66f4a5d21251a2912a50561bf600333" + "ba750903a90dded8098f2f56d0b2a9012e6166af" ], "v16": [ "16.9", - "1486f919d4dc21637407ee7ed203497bb5bd516a" + "ad2b69b58230290fc44c08fbe0c97981c64f6c7d" ], "v15": [ "15.13", - "9d19780350c0c7b536312dc3b891ade55628bc7b" + "e5ee23d99874ea9f5b62f8acc7d076162ae95d6c" ], "v14": [ "14.18", - "1cb207d1c9efb1f6c6f864a47bf45e992a7f0eb0" + "4cacada8bde7f6424751a0727a657783c6a1d20b" ] } diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index d6d64a2045..f5984d3ac3 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -74,7 +74,7 @@ once_cell = { version = "1" } p256 = { version = "0.13", features = ["jwk"] } parquet = { version = "53", default-features = false, features = ["zstd"] } prost = { version = "0.13", features = ["no-recursion-limit", "prost-derive"] } -rand = { version = "0.8", features = ["small_rng"] } +rand = { version = "0.9" } regex = { version = "1" } regex-automata = { version = "0.4", default-features = false, features = ["dfa-onepass", "hybrid", "meta", "nfa-backtrack", "perf-inline", "perf-literal", "unicode"] } regex-syntax = { version = "0.8" } @@ -93,6 +93,7 @@ spki = { version = "0.7", default-features = false, features = ["pem", "std"] } stable_deref_trait = { version = "1" } subtle = { version = "2" } sync_wrapper = { version = "0.1", default-features = false, features = ["futures"] } +thiserror = { version = "2" } tikv-jemalloc-ctl = { version = "0.6", features = ["stats", "use_std"] } tikv-jemalloc-sys = { version = "0.6", features = ["profiling", "stats", "unprefixed_malloc_on_supported_platforms"] } time = { version = "0.3", features = ["macros", "serde-well-known"] } @@ -101,6 +102,7 @@ tokio-rustls = { version = "0.26", default-features = false, features = ["loggin tokio-stream = { version = "0.1", features = ["net", "sync"] } tokio-util = { version = "0.7", features = ["codec", "compat", "io-util", "rt"] } toml_edit = { version = "0.22", features = ["serde"] } +tonic = { version = "0.13", default-features = false, features = ["codegen", "gzip", "prost", "router", "server", "tls-native-roots", "tls-ring", "zstd"] } tower = { version = "0.5", default-features = false, features = ["balance", "buffer", "limit", "log"] } 
tracing = { version = "0.1", features = ["log"] } tracing-core = { version = "0.1" }
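The vendor submodule bumps and vendor/revisions.json above must move in lockstep: each entry maps a major version to its Postgres release and the pinned submodule commit. A small consistency check one could run locally (a hypothetical helper, assuming only the repo layout visible in this diff):

import json
import subprocess
from pathlib import Path


def check_vendor_revisions(repo_root: Path) -> None:
    # Each revisions.json entry is [postgres_release, pinned_commit]; the pinned
    # commit must match what the corresponding git submodule points at.
    revisions = json.loads((repo_root / "vendor" / "revisions.json").read_text())
    for version, (_release, pinned) in revisions.items():
        actual = subprocess.run(
            ["git", "rev-parse", f"HEAD:vendor/postgres-{version}"],
            cwd=repo_root,
            capture_output=True,
            text=True,
            check=True,
        ).stdout.strip()
        assert actual == pinned, f"postgres-{version}: submodule at {actual}, json pins {pinned}"


# Usage: check_vendor_revisions(Path("."))  # from the repository root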