From 0dbe55180288bc87e159cfd3f0b8fca8aaea6204 Mon Sep 17 00:00:00 2001 From: Ruslan Talpa Date: Mon, 21 Jul 2025 21:16:28 +0300 Subject: [PATCH 1/6] proxy: subzero integration in auth-broker (embedded data-api) (#12474) ## Problem We want to have the data-api served by the proxy directly instead of relying on a 3rd party to run a deployment for each project/endpoint. ## Summary of changes With the changes below, the proxy (auth-broker) becomes also a "rest-broker", that can be thought of as a "Multi-tenant" data-api which provides an automated REST api for all the databases in the region. The core of the implementation (that leverages the subzero library) is in proxy/src/serverless/rest.rs and this is the only place that has "new logic". --------- Co-authored-by: Ruslan Talpa Co-authored-by: Alexander Bayandin Co-authored-by: Conrad Ludgate --- .config/hakari.toml | 3 +- .../actions/prepare-for-subzero/action.yml | 28 + .github/workflows/_build-and-test-locally.yml | 6 +- .github/workflows/_check-codestyle-rust.yml | 4 + .github/workflows/build-macos.yml | 4 + .github/workflows/build_and_test.yml | 2 + .github/workflows/neon_extra_builds.yml | 1 + .gitignore | 5 + Cargo.lock | 161 ++- Cargo.toml | 1 + Dockerfile | 26 +- deny.toml | 1 + pgxn/neon/communicator/Cargo.toml | 3 + proxy/Cargo.toml | 6 + proxy/README.md | 22 +- proxy/src/binary/local_proxy.rs | 9 + proxy/src/binary/proxy.rs | 53 + proxy/src/cache/timed_lru.rs | 28 + proxy/src/config.rs | 12 + proxy/src/redis/notifications.rs | 10 +- proxy/src/serverless/mod.rs | 40 +- proxy/src/serverless/rest.rs | 1165 +++++++++++++++++ proxy/src/serverless/sql_over_http.rs | 2 +- proxy/src/util.rs | 10 + proxy/subzero_core/.gitignore | 2 + proxy/subzero_core/Cargo.toml | 12 + proxy/subzero_core/src/lib.rs | 1 + test_runner/fixtures/neon_fixtures.py | 363 +++++ test_runner/fixtures/utils.py | 26 + test_runner/regress/test_rest_broker.py | 137 ++ 30 files changed, 2073 insertions(+), 70 deletions(-) create mode 100644 .github/actions/prepare-for-subzero/action.yml create mode 100644 proxy/src/serverless/rest.rs create mode 100644 proxy/subzero_core/.gitignore create mode 100644 proxy/subzero_core/Cargo.toml create mode 100644 proxy/subzero_core/src/lib.rs create mode 100644 test_runner/regress/test_rest_broker.py diff --git a/.config/hakari.toml b/.config/hakari.toml index 9991cd92b0..dcbc44cc33 100644 --- a/.config/hakari.toml +++ b/.config/hakari.toml @@ -21,13 +21,14 @@ platforms = [ # "x86_64-apple-darwin", # "x86_64-pc-windows-msvc", ] - [final-excludes] workspace-members = [ # vm_monitor benefits from the same Cargo.lock as the rest of our artifacts, but # it is built primarly in separate repo neondatabase/autoscaling and thus is excluded # from depending on workspace-hack because most of the dependencies are not used. "vm_monitor", + # subzero-core is a stub crate that should be excluded from workspace-hack + "subzero-core", # All of these exist in libs and are not usually built independently. # Putting workspace hack there adds a bottleneck for cargo builds. "compute_api", diff --git a/.github/actions/prepare-for-subzero/action.yml b/.github/actions/prepare-for-subzero/action.yml new file mode 100644 index 0000000000..11beb11880 --- /dev/null +++ b/.github/actions/prepare-for-subzero/action.yml @@ -0,0 +1,28 @@ +name: 'Prepare current job for subzero' +description: > + Set git token to access `neondatabase/subzero` from cargo build, + and set `CARGO_NET_GIT_FETCH_WITH_CLI=true` env variable to use git CLI + +inputs: + token: + description: 'GitHub token with access to neondatabase/subzero' + required: true + +runs: + using: "composite" + + steps: + - name: Set git token for neondatabase/subzero + uses: pyTooling/Actions/with-post-step@2307b526df64d55e95884e072e49aac2a00a9afa # v5.1.0 + env: + SUBZERO_ACCESS_TOKEN: ${{ inputs.token }} + with: + main: | + git config --global url."https://x-access-token:${SUBZERO_ACCESS_TOKEN}@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero" + cargo add -p proxy subzero-core --git https://github.com/neondatabase/subzero --rev 396264617e78e8be428682f87469bb25429af88a + post: | + git config --global --unset url."https://x-access-token:${SUBZERO_ACCESS_TOKEN}@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero" + + - name: Set `CARGO_NET_GIT_FETCH_WITH_CLI=true` env variable + shell: bash -euxo pipefail {0} + run: echo "CARGO_NET_GIT_FETCH_WITH_CLI=true" >> ${GITHUB_ENV} diff --git a/.github/workflows/_build-and-test-locally.yml b/.github/workflows/_build-and-test-locally.yml index 94115572df..1b03dc9c03 100644 --- a/.github/workflows/_build-and-test-locally.yml +++ b/.github/workflows/_build-and-test-locally.yml @@ -86,6 +86,10 @@ jobs: with: submodules: true + - uses: ./.github/actions/prepare-for-subzero + with: + token: ${{ secrets.CI_ACCESS_TOKEN }} + - name: Set pg 14 revision for caching id: pg_v14_rev run: echo pg_rev=$(git rev-parse HEAD:vendor/postgres-v14) >> $GITHUB_OUTPUT @@ -116,7 +120,7 @@ jobs: ARCH: ${{ inputs.arch }} SANITIZERS: ${{ inputs.sanitizers }} run: | - CARGO_FLAGS="--locked --features testing" + CARGO_FLAGS="--locked --features testing,rest_broker" if [[ $BUILD_TYPE == "debug" && $ARCH == 'x64' ]]; then cov_prefix="scripts/coverage --profraw-prefix=$GITHUB_JOB --dir=/tmp/coverage run" CARGO_PROFILE="" diff --git a/.github/workflows/_check-codestyle-rust.yml b/.github/workflows/_check-codestyle-rust.yml index 4f844b0bf6..af29e10e97 100644 --- a/.github/workflows/_check-codestyle-rust.yml +++ b/.github/workflows/_check-codestyle-rust.yml @@ -46,6 +46,10 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: true + + - uses: ./.github/actions/prepare-for-subzero + with: + token: ${{ secrets.CI_ACCESS_TOKEN }} - name: Cache cargo deps uses: tespkg/actions-cache@b7bf5fcc2f98a52ac6080eb0fd282c2f752074b1 # v1.8.0 diff --git a/.github/workflows/build-macos.yml b/.github/workflows/build-macos.yml index 2296807d2d..e43eec1133 100644 --- a/.github/workflows/build-macos.yml +++ b/.github/workflows/build-macos.yml @@ -54,6 +54,10 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: true + + - uses: ./.github/actions/prepare-for-subzero + with: + token: ${{ secrets.CI_ACCESS_TOKEN }} - name: Install build dependencies run: | diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 2977f642bc..f237a991cc 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -632,6 +632,8 @@ jobs: BUILD_TAG=${{ needs.meta.outputs.release-tag || needs.meta.outputs.build-tag }} TAG=${{ needs.build-build-tools-image.outputs.image-tag }}-bookworm DEBIAN_VERSION=bookworm + secrets: | + SUBZERO_ACCESS_TOKEN=${{ secrets.CI_ACCESS_TOKEN }} provenance: false push: true pull: true diff --git a/.github/workflows/neon_extra_builds.yml b/.github/workflows/neon_extra_builds.yml index 3e81183687..10ca1a1591 100644 --- a/.github/workflows/neon_extra_builds.yml +++ b/.github/workflows/neon_extra_builds.yml @@ -72,6 +72,7 @@ jobs: check-macos-build: needs: [ check-permissions, files-changed ] uses: ./.github/workflows/build-macos.yml + secrets: inherit with: pg_versions: ${{ needs.files-changed.outputs.postgres_changes }} rebuild_rust_code: ${{ fromJSON(needs.files-changed.outputs.rebuild_rust_code) }} diff --git a/.gitignore b/.gitignore index 835cceb123..1e1c2316af 100644 --- a/.gitignore +++ b/.gitignore @@ -26,9 +26,14 @@ docker-compose/docker-compose-parallel.yml *.o *.so *.Po +*.pid # pgindent typedef lists *.list # Node **/node_modules/ + +# various files for local testing +/proxy/.subzero +local_proxy.json diff --git a/Cargo.lock b/Cargo.lock index 137b883a6d..32ae30a765 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -52,6 +52,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "aliasable" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "250f629c0161ad8107cf89319e990051fae62832fd343083bea452d93e2205fd" + [[package]] name = "aligned-vec" version = "0.6.1" @@ -490,7 +496,7 @@ dependencies = [ "hex", "hmac", "http 0.2.9", - "http 1.1.0", + "http 1.3.1", "once_cell", "p256 0.11.1", "percent-encoding", @@ -631,7 +637,7 @@ dependencies = [ "aws-smithy-types", "bytes", "http 0.2.9", - "http 1.1.0", + "http 1.3.1", "pin-project-lite", "tokio", "tracing", @@ -649,7 +655,7 @@ dependencies = [ "bytes-utils", "futures-core", "http 0.2.9", - "http 1.1.0", + "http 1.3.1", "http-body 0.4.5", "http-body 1.0.0", "http-body-util", @@ -698,7 +704,7 @@ dependencies = [ "bytes", "form_urlencoded", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "hyper 1.4.1", @@ -732,7 +738,7 @@ checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733" dependencies = [ "bytes", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "mime", @@ -756,7 +762,7 @@ dependencies = [ "form_urlencoded", "futures-util", "headers", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "mime", @@ -1090,7 +1096,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "975982cdb7ad6a142be15bdf84aea7ec6a9e5d4d797c004d43185b24cfe4e684" dependencies = [ "clap", - "heck", + "heck 0.5.0", "indexmap 2.9.0", "log", "proc-macro2", @@ -1228,7 +1234,7 @@ version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "syn 2.0.100", @@ -1334,7 +1340,7 @@ dependencies = [ "flate2", "futures", "hostname-validator", - "http 1.1.0", + "http 1.3.1", "indexmap 2.9.0", "itertools 0.10.5", "jsonwebtoken", @@ -1969,7 +1975,7 @@ checksum = "0892a17df262a24294c382f0d5997571006e7a4348b4327557c4ff1cd4a8bccc" dependencies = [ "darling", "either", - "heck", + "heck 0.5.0", "proc-macro2", "quote", "syn 2.0.100", @@ -2661,7 +2667,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http 1.1.0", + "http 1.3.1", "indexmap 2.9.0", "slab", "tokio", @@ -2743,7 +2749,7 @@ dependencies = [ "base64 0.21.7", "bytes", "headers-core", - "http 1.1.0", + "http 1.3.1", "httpdate", "mime", "sha1", @@ -2755,9 +2761,15 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" dependencies = [ - "http 1.1.0", + "http 1.3.1", ] +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "heck" version = "0.5.0" @@ -2833,9 +2845,9 @@ dependencies = [ [[package]] name = "http" -version = "1.1.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" dependencies = [ "bytes", "fnv", @@ -2860,7 +2872,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" dependencies = [ "bytes", - "http 1.1.0", + "http 1.3.1", ] [[package]] @@ -2871,7 +2883,7 @@ checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" dependencies = [ "bytes", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "pin-project-lite", ] @@ -2995,7 +3007,7 @@ dependencies = [ "futures-channel", "futures-util", "h2 0.4.4", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "httparse", "httpdate", @@ -3028,7 +3040,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a0bea761b46ae2b24eb4aef630d8d1c398157b6fc29e6350ecf090a0b70c952c" dependencies = [ "futures-util", - "http 1.1.0", + "http 1.3.1", "hyper 1.4.1", "hyper-util", "rustls 0.22.4", @@ -3060,7 +3072,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "hyper 1.4.1", "pin-project-lite", @@ -3709,7 +3721,7 @@ version = "0.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9e6777fc80a575f9503d908c8b498782a6c3ee88a06cb416dc3941401e43b94" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "syn 2.0.100", @@ -4160,7 +4172,7 @@ checksum = "10a8a7f5f6ba7c1b286c2fbca0454eaba116f63bbe69ed250b642d36fbb04d80" dependencies = [ "async-trait", "bytes", - "http 1.1.0", + "http 1.3.1", "opentelemetry", "reqwest", ] @@ -4173,7 +4185,7 @@ checksum = "91cf61a1868dacc576bf2b2a1c3e9ab150af7272909e80085c3173384fe11f76" dependencies = [ "async-trait", "futures-core", - "http 1.1.0", + "http 1.3.1", "opentelemetry", "opentelemetry-http", "opentelemetry-proto", @@ -4252,6 +4264,30 @@ dependencies = [ "winapi", ] +[[package]] +name = "ouroboros" +version = "0.18.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e0f050db9c44b97a94723127e6be766ac5c340c48f2c4bb3ffa11713744be59" +dependencies = [ + "aliasable", + "ouroboros_macro", + "static_assertions", +] + +[[package]] +name = "ouroboros_macro" +version = "0.18.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c7028bdd3d43083f6d8d4d5187680d0d3560d54df4cc9d752005268b41e64d0" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "proc-macro2-diagnostics", + "quote", + "syn 2.0.100", +] + [[package]] name = "outref" version = "0.5.1" @@ -4381,7 +4417,7 @@ dependencies = [ "hashlink", "hex", "hex-literal", - "http 1.1.0", + "http 1.3.1", "http-utils", "humantime", "humantime-serde", @@ -5148,6 +5184,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "proc-macro2-diagnostics" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", + "version_check", + "yansi", +] + [[package]] name = "procfs" version = "0.16.0" @@ -5217,7 +5266,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" dependencies = [ "bytes", - "heck", + "heck 0.5.0", "itertools 0.12.1", "log", "multimap", @@ -5238,7 +5287,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15" dependencies = [ "bytes", - "heck", + "heck 0.5.0", "itertools 0.12.1", "log", "multimap", @@ -5334,7 +5383,7 @@ dependencies = [ "hex", "hmac", "hostname", - "http 1.1.0", + "http 1.3.1", "http-body-util", "http-utils", "humantime", @@ -5354,6 +5403,7 @@ dependencies = [ "metrics", "once_cell", "opentelemetry", + "ouroboros", "p256 0.13.2", "papaya", "parking_lot 0.12.1", @@ -5390,6 +5440,7 @@ dependencies = [ "socket2", "strum_macros", "subtle", + "subzero-core", "thiserror 1.0.69", "tikv-jemalloc-ctl", "tikv-jemallocator", @@ -5705,14 +5756,14 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.2" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.3", - "regex-syntax 0.8.2", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", ] [[package]] @@ -5726,13 +5777,13 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.3" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.2", + "regex-syntax 0.8.5", ] [[package]] @@ -5749,9 +5800,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "regex-syntax" -version = "0.8.2" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "relative-path" @@ -5821,7 +5872,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "hyper 1.4.1", @@ -5863,7 +5914,7 @@ checksum = "d1ccd3b55e711f91a9885a2fa6fbbb2e39db1776420b062efc058c6410f7e5e3" dependencies = [ "anyhow", "async-trait", - "http 1.1.0", + "http 1.3.1", "reqwest", "serde", "thiserror 1.0.69", @@ -5880,7 +5931,7 @@ dependencies = [ "async-trait", "futures", "getrandom 0.2.11", - "http 1.1.0", + "http 1.3.1", "hyper 1.4.1", "parking_lot 0.11.2", "reqwest", @@ -5901,7 +5952,7 @@ dependencies = [ "anyhow", "async-trait", "getrandom 0.2.11", - "http 1.1.0", + "http 1.3.1", "matchit", "opentelemetry", "reqwest", @@ -6260,7 +6311,7 @@ dependencies = [ "fail", "futures", "hex", - "http 1.1.0", + "http 1.3.1", "http-utils", "humantime", "hyper 0.14.30", @@ -7109,7 +7160,7 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "rustversion", @@ -7122,6 +7173,10 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" +[[package]] +name = "subzero-core" +version = "3.0.1" + [[package]] name = "svg_fmt" version = "0.4.3" @@ -7732,7 +7787,7 @@ dependencies = [ "async-trait", "base64 0.22.1", "bytes", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "percent-encoding", @@ -7756,7 +7811,7 @@ dependencies = [ "bytes", "flate2", "h2 0.4.4", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "hyper 1.4.1", @@ -7847,7 +7902,7 @@ dependencies = [ "base64 0.22.1", "bitflags 2.8.0", "bytes", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "mime", "pin-project-lite", @@ -7868,7 +7923,7 @@ name = "tower-otel" version = "0.2.0" source = "git+https://github.com/mattiapenati/tower-otel?rev=56a7321053bcb72443888257b622ba0d43a11fcd#56a7321053bcb72443888257b622ba0d43a11fcd" dependencies = [ - "http 1.1.0", + "http 1.3.1", "opentelemetry", "pin-project", "tower-layer", @@ -8049,7 +8104,7 @@ dependencies = [ "byteorder", "bytes", "data-encoding", - "http 1.1.0", + "http 1.3.1", "httparse", "log", "rand 0.8.5", @@ -8068,7 +8123,7 @@ dependencies = [ "byteorder", "bytes", "data-encoding", - "http 1.1.0", + "http 1.3.1", "httparse", "log", "rand 0.8.5", @@ -8857,8 +8912,8 @@ dependencies = [ "quote", "rand 0.8.5", "regex", - "regex-automata 0.4.3", - "regex-syntax 0.8.2", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", "reqwest", "rustls 0.23.27", "rustls-pki-types", @@ -8954,6 +9009,12 @@ version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4d25c75bf9ea12c4040a97f829154768bbbce366287e2dc044af160cd79a13fd" +[[package]] +name = "yansi" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" + [[package]] name = "yasna" version = "0.5.2" diff --git a/Cargo.toml b/Cargo.toml index 6d91262882..fe647828fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,6 +49,7 @@ members = [ "libs/proxy/tokio-postgres2", "endpoint_storage", "pgxn/neon/communicator", + "proxy/subzero_core", ] [workspace.package] diff --git a/Dockerfile b/Dockerfile index 55b87d4012..654ae72e56 100644 --- a/Dockerfile +++ b/Dockerfile @@ -63,7 +63,14 @@ WORKDIR /home/nonroot COPY --chown=nonroot . . -RUN cargo chef prepare --recipe-path recipe.json +RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \ + set -e \ + && if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \ + export CARGO_NET_GIT_FETCH_WITH_CLI=true && \ + git config --global url."https://$(cat /run/secrets/SUBZERO_ACCESS_TOKEN)@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero" && \ + cargo add -p proxy subzero-core --git https://github.com/neondatabase/subzero --rev 396264617e78e8be428682f87469bb25429af88a; \ + fi \ + && cargo chef prepare --recipe-path recipe.json # Main build image FROM $REPOSITORY/$IMAGE:$TAG AS build @@ -71,20 +78,33 @@ WORKDIR /home/nonroot ARG GIT_VERSION=local ARG BUILD_TAG ARG ADDITIONAL_RUSTFLAGS="" +ENV CARGO_FEATURES="default" # 3. Build cargo dependencies. Note that this step doesn't depend on anything else than # `recipe.json`, so the layer can be reused as long as none of the dependencies change. COPY --from=plan /home/nonroot/recipe.json recipe.json -RUN set -e \ +RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \ + set -e \ + && if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \ + export CARGO_NET_GIT_FETCH_WITH_CLI=true && \ + git config --global url."https://$(cat /run/secrets/SUBZERO_ACCESS_TOKEN)@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero"; \ + fi \ && RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo chef cook --locked --release --recipe-path recipe.json # Perform the main build. We reuse the Postgres build artifacts from the intermediate 'pg-build' # layer, and the cargo dependencies built in the previous step. COPY --chown=nonroot --from=pg-build /home/nonroot/pg_install/ pg_install COPY --chown=nonroot . . +COPY --chown=nonroot --from=plan /home/nonroot/proxy/Cargo.toml proxy/Cargo.toml +COPY --chown=nonroot --from=plan /home/nonroot/Cargo.lock Cargo.lock -RUN set -e \ +RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \ + set -e \ + && if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \ + export CARGO_FEATURES="rest_broker"; \ + fi \ && RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo build \ + --features $CARGO_FEATURES \ --bin pg_sni_router \ --bin pageserver \ --bin pagectl \ diff --git a/deny.toml b/deny.toml index be1c6a2f2c..7afd05a837 100644 --- a/deny.toml +++ b/deny.toml @@ -35,6 +35,7 @@ reason = "The paste crate is a build-only dependency with no runtime components. # More documentation for the licenses section can be found here: # https://embarkstudios.github.io/cargo-deny/checks/licenses/cfg.html [licenses] +version = 2 allow = [ "0BSD", "Apache-2.0", diff --git a/pgxn/neon/communicator/Cargo.toml b/pgxn/neon/communicator/Cargo.toml index e95a269d90..b5ce389297 100644 --- a/pgxn/neon/communicator/Cargo.toml +++ b/pgxn/neon/communicator/Cargo.toml @@ -11,6 +11,9 @@ crate-type = ["staticlib"] # 'testing' feature is currently unused in the communicator, but we accept it for convenience of # calling build scripts, so that you can pass the same feature to all packages. testing = [] +# 'rest_broker' feature is currently unused in the communicator, but we accept it for convenience of +# calling build scripts, so that you can pass the same feature to all packages. +rest_broker = [] [dependencies] neon-shmem.workspace = true diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 82fe6818e3..8392046839 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -7,6 +7,7 @@ license.workspace = true [features] default = [] testing = ["dep:tokio-postgres"] +rest_broker = ["dep:subzero-core", "dep:ouroboros"] [dependencies] ahash.workspace = true @@ -105,6 +106,11 @@ uuid.workspace = true x509-cert.workspace = true redis.workspace = true zerocopy.workspace = true +# uncomment this to use the real subzero-core crate +# subzero-core = { git = "https://github.com/neondatabase/subzero", rev = "396264617e78e8be428682f87469bb25429af88a", features = ["postgresql"], optional = true } +# this is a stub for the subzero-core crate +subzero-core = { path = "./subzero_core", features = ["postgresql"], optional = true} +ouroboros = { version = "0.18", optional = true } # jwt stuff jose-jwa = "0.1.2" diff --git a/proxy/README.md b/proxy/README.md index ff48f9f323..ce957b90af 100644 --- a/proxy/README.md +++ b/proxy/README.md @@ -178,16 +178,24 @@ Create a configuration file called `local_proxy.json` in the root of the repo (u Start the local proxy: ```sh -cargo run --bin local_proxy -- \ - --disable_pg_session_jwt true \ +cargo run --bin local_proxy --features testing -- \ + --disable-pg-session-jwt \ --http 0.0.0.0:7432 ``` -Start the auth broker: +Start the auth/rest broker: + +Note: to enable the rest broker you need to replace the stub subzero-core crate with the real one. + ```sh -LOGFMT=text OTEL_SDK_DISABLED=true cargo run --bin proxy --features testing -- \ +cargo add -p proxy subzero-core --git https://github.com/neondatabase/subzero --rev 396264617e78e8be428682f87469bb25429af88a +``` + +```sh +LOGFMT=text OTEL_SDK_DISABLED=true cargo run --bin proxy --features testing,rest_broker -- \ -c server.crt -k server.key \ --is-auth-broker true \ + --is-rest-broker true \ --wss 0.0.0.0:8080 \ --http 0.0.0.0:7002 \ --auth-backend local @@ -205,3 +213,9 @@ curl -k "https://foo.local.neon.build:8080/sql" \ -H "neon-connection-string: postgresql://authenticator@foo.local.neon.build/database" \ -d '{"query":"select 1","params":[]}' ``` + +Make a rest request against the auth broker (rest broker): +```sh +curl -k "https://foo.local.neon.build:8080/database/rest/v1/items?select=id,name&id=eq.1" \ +-H "Authorization: Bearer $NEON_JWT" +``` diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs index 401203d48c..e3f7ba4c15 100644 --- a/proxy/src/binary/local_proxy.rs +++ b/proxy/src/binary/local_proxy.rs @@ -20,6 +20,8 @@ use crate::auth::backend::jwt::JwkCache; use crate::auth::backend::local::LocalBackend; use crate::auth::{self}; use crate::cancellation::CancellationHandler; +#[cfg(feature = "rest_broker")] +use crate::config::RestConfig; use crate::config::{ self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig, refresh_config_loop, @@ -276,6 +278,13 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig accept_jwts: true, console_redirect_confirmation_timeout: Duration::ZERO, }, + #[cfg(feature = "rest_broker")] + rest_config: RestConfig { + is_rest_broker: false, + db_schema_cache: None, + max_schema_size: 0, + hostname_prefix: String::new(), + }, proxy_protocol_v2: config::ProxyProtocolV2::Rejected, handshake_timeout: Duration::from_secs(10), wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?, diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 16a7dc7b67..194a1ed34c 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -31,6 +31,8 @@ use crate::auth::backend::local::LocalBackend; use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned}; use crate::batch::BatchQueue; use crate::cancellation::{CancellationHandler, CancellationProcessor}; +#[cfg(feature = "rest_broker")] +use crate::config::RestConfig; #[cfg(any(test, feature = "testing"))] use crate::config::refresh_config_loop; use crate::config::{ @@ -47,6 +49,8 @@ use crate::redis::{elasticache, notifications}; use crate::scram::threadpool::ThreadPool; use crate::serverless::GlobalConnPoolOptions; use crate::serverless::cancel_set::CancelSet; +#[cfg(feature = "rest_broker")] +use crate::serverless::rest::DbSchemaCache; use crate::tls::client_config::compute_client_config_with_root_certs; #[cfg(any(test, feature = "testing"))] use crate::url::ApiUrl; @@ -246,11 +250,23 @@ struct ProxyCliArgs { /// if this is not local proxy, this toggles whether we accept Postgres REST requests #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] + #[cfg(feature = "rest_broker")] is_rest_broker: bool, /// cache for `db_schema_cache` introspection (use `size=0` to disable) #[clap(long, default_value = "size=1000,ttl=1h")] + #[cfg(feature = "rest_broker")] db_schema_cache: String, + + /// Maximum size allowed for schema in bytes + #[clap(long, default_value_t = 5 * 1024 * 1024)] // 5MB + #[cfg(feature = "rest_broker")] + max_schema_size: usize, + + /// Hostname prefix to strip from request hostname to get database hostname + #[clap(long, default_value = "apirest.")] + #[cfg(feature = "rest_broker")] + hostname_prefix: String, } #[derive(clap::Args, Clone, Copy, Debug)] @@ -517,6 +533,17 @@ pub async fn run() -> anyhow::Result<()> { )); maintenance_tasks.spawn(control_plane::mgmt::task_main(mgmt_listener)); + // add a task to flush the db_schema cache every 10 minutes + #[cfg(feature = "rest_broker")] + if let Some(db_schema_cache) = &config.rest_config.db_schema_cache { + maintenance_tasks.spawn(async move { + loop { + tokio::time::sleep(Duration::from_secs(600)).await; + db_schema_cache.flush(); + } + }); + } + if let Some(metrics_config) = &config.metric_collection { // TODO: Add gc regardles of the metric collection being enabled. maintenance_tasks.spawn(usage_metrics::task_main(metrics_config)); @@ -679,6 +706,30 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { timeout: Duration::from_secs(2), }; + #[cfg(feature = "rest_broker")] + let rest_config = { + let db_schema_cache_config: CacheOptions = args.db_schema_cache.parse()?; + info!("Using DbSchemaCache with options={db_schema_cache_config:?}"); + + let db_schema_cache = if args.is_rest_broker { + Some(DbSchemaCache::new( + "db_schema_cache", + db_schema_cache_config.size, + db_schema_cache_config.ttl, + true, + )) + } else { + None + }; + + RestConfig { + is_rest_broker: args.is_rest_broker, + db_schema_cache, + max_schema_size: args.max_schema_size, + hostname_prefix: args.hostname_prefix.clone(), + } + }; + let config = ProxyConfig { tls_config, metric_collection, @@ -691,6 +742,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { connect_to_compute: compute_config, #[cfg(feature = "testing")] disable_pg_session_jwt: false, + #[cfg(feature = "rest_broker")] + rest_config, }; let config = Box::leak(Box::new(config)); diff --git a/proxy/src/cache/timed_lru.rs b/proxy/src/cache/timed_lru.rs index e87cf53ab9..0a7fb40b0c 100644 --- a/proxy/src/cache/timed_lru.rs +++ b/proxy/src/cache/timed_lru.rs @@ -204,6 +204,11 @@ impl TimedLru { self.insert_raw_ttl(key, value, ttl, false); } + #[cfg(feature = "rest_broker")] + pub(crate) fn insert(&self, key: K, value: V) { + self.insert_raw_ttl(key, value, self.ttl, self.update_ttl_on_retrieval); + } + pub(crate) fn insert_unit(&self, key: K, value: V) -> (Option, Cached<&Self, ()>) { let (_, old) = self.insert_raw(key.clone(), value); @@ -214,6 +219,29 @@ impl TimedLru { (old, cached) } + + #[cfg(feature = "rest_broker")] + pub(crate) fn flush(&self) { + let now = Instant::now(); + let mut cache = self.cache.lock(); + + // Collect keys of expired entries first + let expired_keys: Vec<_> = cache + .iter() + .filter_map(|(key, entry)| { + if entry.expires_at <= now { + Some(key.clone()) + } else { + None + } + }) + .collect(); + + // Remove expired entries + for key in expired_keys { + cache.remove(&key); + } + } } impl TimedLru { diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 6157dc8a6a..20bbfd77d8 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -22,6 +22,8 @@ use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig}; use crate::scram::threadpool::ThreadPool; use crate::serverless::GlobalConnPoolOptions; use crate::serverless::cancel_set::CancelSet; +#[cfg(feature = "rest_broker")] +use crate::serverless::rest::DbSchemaCache; pub use crate::tls::server_config::{TlsConfig, configure_tls}; use crate::types::{Host, RoleName}; @@ -30,6 +32,8 @@ pub struct ProxyConfig { pub metric_collection: Option, pub http_config: HttpConfig, pub authentication_config: AuthenticationConfig, + #[cfg(feature = "rest_broker")] + pub rest_config: RestConfig, pub proxy_protocol_v2: ProxyProtocolV2, pub handshake_timeout: Duration, pub wake_compute_retry_config: RetryConfig, @@ -80,6 +84,14 @@ pub struct AuthenticationConfig { pub console_redirect_confirmation_timeout: tokio::time::Duration, } +#[cfg(feature = "rest_broker")] +pub struct RestConfig { + pub is_rest_broker: bool, + pub db_schema_cache: Option, + pub max_schema_size: usize, + pub hostname_prefix: String, +} + #[derive(Debug)] pub struct MetricBackupCollectionConfig { pub remote_storage_config: Option, diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index a6d376562b..88d5550fff 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -10,6 +10,7 @@ use super::connection_with_credentials_provider::ConnectionWithCredentialsProvid use crate::cache::project_info::ProjectInfoCache; use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; use crate::metrics::{Metrics, RedisErrors, RedisEventsCount}; +use crate::util::deserialize_json_string; const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20); @@ -121,15 +122,6 @@ struct InvalidateRole { role_name: RoleNameInt, } -fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result -where - T: for<'de2> serde::Deserialize<'de2>, - D: serde::Deserializer<'de>, -{ - let s = String::deserialize(deserializer)?; - serde_json::from_str(&s).map_err(::custom) -} - // https://github.com/serde-rs/serde/issues/1714 fn deserialize_unknown_topic<'de, D>(deserializer: D) -> Result<(), D::Error> where diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 5b7289c53d..18cdc39ac7 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -11,6 +11,8 @@ mod http_conn_pool; mod http_util; mod json; mod local_conn_pool; +#[cfg(feature = "rest_broker")] +pub mod rest; mod sql_over_http; mod websocket; @@ -487,6 +489,42 @@ async fn request_handler( .body(Empty::new().map_err(|x| match x {}).boxed()) .map_err(|e| ApiError::InternalServerError(e.into())) } else { - json_response(StatusCode::BAD_REQUEST, "query is not supported") + #[cfg(feature = "rest_broker")] + { + if config.rest_config.is_rest_broker + // we are testing for the path to be /database_name/rest/... + && request + .uri() + .path() + .split('/') + .nth(2) + .is_some_and(|part| part.starts_with("rest")) + { + let ctx = + RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http); + let span = ctx.span(); + + let testodrome_id = request + .headers() + .get("X-Neon-Query-ID") + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + + if let Some(query_id) = testodrome_id { + info!(parent: &span, "testodrome query ID: {query_id}"); + ctx.set_testodrome_id(query_id.into()); + } + + rest::handle(config, ctx, request, backend, http_cancellation_token) + .instrument(span) + .await + } else { + json_response(StatusCode::BAD_REQUEST, "query is not supported") + } + } + #[cfg(not(feature = "rest_broker"))] + { + json_response(StatusCode::BAD_REQUEST, "query is not supported") + } } } diff --git a/proxy/src/serverless/rest.rs b/proxy/src/serverless/rest.rs new file mode 100644 index 0000000000..173c2629f7 --- /dev/null +++ b/proxy/src/serverless/rest.rs @@ -0,0 +1,1165 @@ +use std::borrow::Cow; +use std::collections::HashMap; +use std::sync::Arc; + +use bytes::Bytes; +use http::Method; +use http::header::{AUTHORIZATION, CONTENT_TYPE, HOST}; +use http_body_util::combinators::BoxBody; +use http_body_util::{BodyExt, Full}; +use http_utils::error::ApiError; +use hyper::body::Incoming; +use hyper::http::{HeaderName, HeaderValue}; +use hyper::{Request, Response, StatusCode}; +use indexmap::IndexMap; +use ouroboros::self_referencing; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Deserializer}; +use serde_json::Value as JsonValue; +use serde_json::value::RawValue; +use subzero_core::api::ContentType::{ApplicationJSON, Other, SingularJSON, TextCSV}; +use subzero_core::api::QueryNode::{Delete, FunctionCall, Insert, Update}; +use subzero_core::api::Resolution::{IgnoreDuplicates, MergeDuplicates}; +use subzero_core::api::{ApiResponse, ListVal, Payload, Preferences, Representation, SingleVal}; +use subzero_core::config::{db_allowed_select_functions, db_schemas, role_claim_key}; +use subzero_core::dynamic_statement::{JoinIterator, param, sql}; +use subzero_core::error::Error::{ + self as SubzeroCoreError, ContentTypeError, GucHeadersError, GucStatusError, InternalError, + JsonDeserialize, JwtTokenInvalid, NotFound, +}; +use subzero_core::error::pg_error_to_status_code; +use subzero_core::formatter::Param::{LV, PL, SV, Str, StrOwned}; +use subzero_core::formatter::postgresql::{fmt_main_query, generate}; +use subzero_core::formatter::{Param, Snippet, SqlParam}; +use subzero_core::parser::postgrest::parse; +use subzero_core::permissions::{check_safe_functions, replace_select_star}; +use subzero_core::schema::{ + DbSchema, POSTGRESQL_INTROSPECTION_SQL, get_postgresql_configuration_query, +}; +use subzero_core::{content_range_header, content_range_status}; +use tokio_util::sync::CancellationToken; +use tracing::{error, info}; +use typed_json::json; +use url::form_urlencoded; + +use super::backend::{HttpConnError, LocalProxyConnError, PoolingBackend}; +use super::conn_pool::AuthData; +use super::conn_pool_lib::ConnInfo; +use super::error::{ConnInfoError, Credentials, HttpCodeError, ReadPayloadError}; +use super::http_conn_pool::{self, Send}; +use super::http_util::{ + ALLOW_POOL, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_ISOLATION_LEVEL, TXN_READ_ONLY, + get_conn_info, json_response, uuid_to_header_value, +}; +use super::json::JsonConversionError; +use crate::auth::backend::ComputeCredentialKeys; +use crate::cache::{Cached, TimedLru}; +use crate::config::ProxyConfig; +use crate::context::RequestContext; +use crate::error::{ErrorKind, ReportableError, UserFacingError}; +use crate::http::read_body_with_limit; +use crate::metrics::Metrics; +use crate::serverless::sql_over_http::HEADER_VALUE_TRUE; +use crate::types::EndpointCacheKey; +use crate::util::deserialize_json_string; + +static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#; +const INTROSPECTION_SQL: &str = POSTGRESQL_INTROSPECTION_SQL; + +// A wrapper around the DbSchema that allows for self-referencing +#[self_referencing] +pub struct DbSchemaOwned { + schema_string: String, + #[covariant] + #[borrows(schema_string)] + schema: DbSchema<'this>, +} + +impl<'de> Deserialize<'de> for DbSchemaOwned { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + DbSchemaOwned::try_new(s, |s| serde_json::from_str(s)) + .map_err(::custom) + } +} + +fn split_comma_separated(s: &str) -> Vec { + s.split(',').map(|s| s.trim().to_string()).collect() +} + +fn deserialize_comma_separated<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let s = String::deserialize(deserializer)?; + Ok(split_comma_separated(&s)) +} + +fn deserialize_comma_separated_option<'de, D>( + deserializer: D, +) -> Result>, D::Error> +where + D: Deserializer<'de>, +{ + let opt = Option::::deserialize(deserializer)?; + if let Some(s) = &opt { + let trimmed = s.trim(); + if trimmed.is_empty() { + return Ok(None); + } + return Ok(Some(split_comma_separated(trimmed))); + } + Ok(None) +} + +// The ApiConfig is the configuration for the API per endpoint +// The configuration is read from the database and cached in the DbSchemaCache +#[derive(Deserialize, Debug)] +pub struct ApiConfig { + #[serde( + default = "db_schemas", + deserialize_with = "deserialize_comma_separated" + )] + pub db_schemas: Vec, + pub db_anon_role: Option, + pub db_max_rows: Option, + #[serde(default = "db_allowed_select_functions")] + pub db_allowed_select_functions: Vec, + // #[serde(deserialize_with = "to_tuple", default)] + // pub db_pre_request: Option<(String, String)>, + #[allow(dead_code)] + #[serde(default = "role_claim_key")] + pub role_claim_key: String, + #[serde(default, deserialize_with = "deserialize_comma_separated_option")] + pub db_extra_search_path: Option>, +} + +// The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint +pub(crate) type DbSchemaCache = TimedLru>; +impl DbSchemaCache { + pub async fn get_cached_or_remote( + &self, + endpoint_id: &EndpointCacheKey, + auth_header: &HeaderValue, + connection_string: &str, + client: &mut http_conn_pool::Client, + ctx: &RequestContext, + config: &'static ProxyConfig, + ) -> Result, RestError> { + match self.get_with_created_at(endpoint_id) { + Some(Cached { value: (v, _), .. }) => Ok(v), + None => { + info!("db_schema cache miss for endpoint: {:?}", endpoint_id); + let remote_value = self + .get_remote(auth_header, connection_string, client, ctx, config) + .await; + let (api_config, schema_owned) = match remote_value { + Ok((api_config, schema_owned)) => (api_config, schema_owned), + Err(e @ RestError::SchemaTooLarge) => { + // for the case where the schema is too large, we cache an empty dummy value + // all the other requests will fail without triggering the introspection query + let schema_owned = serde_json::from_str::(EMPTY_JSON_SCHEMA) + .map_err(|e| JsonDeserialize { source: e })?; + + let api_config = ApiConfig { + db_schemas: vec![], + db_anon_role: None, + db_max_rows: None, + db_allowed_select_functions: vec![], + role_claim_key: String::new(), + db_extra_search_path: None, + }; + let value = Arc::new((api_config, schema_owned)); + self.insert(endpoint_id.clone(), value); + return Err(e); + } + Err(e) => { + return Err(e); + } + }; + let value = Arc::new((api_config, schema_owned)); + self.insert(endpoint_id.clone(), value.clone()); + Ok(value) + } + } + } + pub async fn get_remote( + &self, + auth_header: &HeaderValue, + connection_string: &str, + client: &mut http_conn_pool::Client, + ctx: &RequestContext, + config: &'static ProxyConfig, + ) -> Result<(ApiConfig, DbSchemaOwned), RestError> { + #[derive(Deserialize)] + struct SingleRow { + rows: [Row; 1], + } + + #[derive(Deserialize)] + struct ConfigRow { + #[serde(deserialize_with = "deserialize_json_string")] + config: ApiConfig, + } + + #[derive(Deserialize)] + struct SchemaRow { + json_schema: DbSchemaOwned, + } + + let headers = vec![ + (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())), + ( + &CONN_STRING, + HeaderValue::from_str(connection_string).expect( + "connection string came from a header, so it must be a valid headervalue", + ), + ), + (&AUTHORIZATION, auth_header.clone()), + (&RAW_TEXT_OUTPUT, HEADER_VALUE_TRUE), + ]; + + let query = get_postgresql_configuration_query(Some("pgrst.pre_config")); + let SingleRow { + rows: [ConfigRow { config: api_config }], + } = make_local_proxy_request( + client, + headers.iter().cloned(), + QueryData { + query: Cow::Owned(query), + params: vec![], + }, + config.rest_config.max_schema_size, + ) + .await + .map_err(|e| match e { + RestError::ReadPayload(ReadPayloadError::BodyTooLarge { .. }) => { + RestError::SchemaTooLarge + } + e => e, + })?; + + // now that we have the api_config let's run the second INTROSPECTION_SQL query + let SingleRow { + rows: [SchemaRow { json_schema }], + } = make_local_proxy_request( + client, + headers, + QueryData { + query: INTROSPECTION_SQL.into(), + params: vec![ + serde_json::to_value(&api_config.db_schemas) + .expect("Vec is always valid to encode as JSON"), + JsonValue::Bool(false), // include_roles_with_login + JsonValue::Bool(false), // use_internal_permissions + ], + }, + config.rest_config.max_schema_size, + ) + .await + .map_err(|e| match e { + RestError::ReadPayload(ReadPayloadError::BodyTooLarge { .. }) => { + RestError::SchemaTooLarge + } + e => e, + })?; + + Ok((api_config, json_schema)) + } +} + +// A type to represent a postgresql errors +// we use our own type (instead of postgres_client::Error) because we get the error from the json response +#[derive(Debug, thiserror::Error, Deserialize)] +pub(crate) struct PostgresError { + pub code: String, + pub message: String, + pub detail: Option, + pub hint: Option, +} +impl HttpCodeError for PostgresError { + fn get_http_status_code(&self) -> StatusCode { + let status = pg_error_to_status_code(&self.code, true); + StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) + } +} +impl ReportableError for PostgresError { + fn get_error_kind(&self) -> ErrorKind { + ErrorKind::User + } +} +impl UserFacingError for PostgresError { + fn to_string_client(&self) -> String { + if self.code.starts_with("PT") { + "Postgres error".to_string() + } else { + self.message.clone() + } + } +} +impl std::fmt::Display for PostgresError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.message) + } +} + +// A type to represent errors that can occur in the rest broker +#[derive(Debug, thiserror::Error)] +pub(crate) enum RestError { + #[error(transparent)] + ReadPayload(#[from] ReadPayloadError), + #[error(transparent)] + ConnectCompute(#[from] HttpConnError), + #[error(transparent)] + ConnInfo(#[from] ConnInfoError), + #[error(transparent)] + Postgres(#[from] PostgresError), + #[error(transparent)] + JsonConversion(#[from] JsonConversionError), + #[error(transparent)] + SubzeroCore(#[from] SubzeroCoreError), + #[error("schema is too large")] + SchemaTooLarge, +} +impl ReportableError for RestError { + fn get_error_kind(&self) -> ErrorKind { + match self { + RestError::ReadPayload(e) => e.get_error_kind(), + RestError::ConnectCompute(e) => e.get_error_kind(), + RestError::ConnInfo(e) => e.get_error_kind(), + RestError::Postgres(_) => ErrorKind::Postgres, + RestError::JsonConversion(_) => ErrorKind::Postgres, + RestError::SubzeroCore(_) => ErrorKind::User, + RestError::SchemaTooLarge => ErrorKind::User, + } + } +} +impl UserFacingError for RestError { + fn to_string_client(&self) -> String { + match self { + RestError::ReadPayload(p) => p.to_string(), + RestError::ConnectCompute(c) => c.to_string_client(), + RestError::ConnInfo(c) => c.to_string_client(), + RestError::SchemaTooLarge => self.to_string(), + RestError::Postgres(p) => p.to_string_client(), + RestError::JsonConversion(_) => "could not parse postgres response".to_string(), + RestError::SubzeroCore(s) => { + // TODO: this is a hack to get the message from the json body + let json = s.json_body(); + let default_message = "Unknown error".to_string(); + + json.get("message") + .map_or(default_message.clone(), |m| match m { + JsonValue::String(s) => s.clone(), + _ => default_message, + }) + } + } + } +} +impl HttpCodeError for RestError { + fn get_http_status_code(&self) -> StatusCode { + match self { + RestError::ReadPayload(e) => e.get_http_status_code(), + RestError::ConnectCompute(h) => match h.get_error_kind() { + ErrorKind::User => StatusCode::BAD_REQUEST, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }, + RestError::ConnInfo(_) => StatusCode::BAD_REQUEST, + RestError::Postgres(e) => e.get_http_status_code(), + RestError::JsonConversion(_) => StatusCode::INTERNAL_SERVER_ERROR, + RestError::SchemaTooLarge => StatusCode::INTERNAL_SERVER_ERROR, + RestError::SubzeroCore(e) => { + let status = e.status_code(); + StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) + } + } + } +} + +// Helper functions for the rest broker + +fn fmt_env_query<'a>(env: &'a HashMap<&'a str, &'a str>) -> Snippet<'a> { + "select " + + if env.is_empty() { + sql("null") + } else { + env.iter() + .map(|(k, v)| { + "set_config(" + param(k as &SqlParam) + ", " + param(v as &SqlParam) + ", true)" + }) + .join(",") + } +} + +// TODO: see about removing the need for cloning the values (inner things are &Cow already) +fn to_sql_param(p: &Param) -> JsonValue { + match p { + SV(SingleVal(v, ..)) => JsonValue::String(v.to_string()), + Str(v) => JsonValue::String((*v).to_string()), + StrOwned(v) => JsonValue::String((*v).clone()), + PL(Payload(v, ..)) => JsonValue::String(v.clone().into_owned()), + LV(ListVal(v, ..)) => { + if v.is_empty() { + JsonValue::String(r"{}".to_string()) + } else { + JsonValue::String(format!( + "{{\"{}\"}}", + v.iter() + .map(|e| e.replace('\\', "\\\\").replace('\"', "\\\"")) + .collect::>() + .join("\",\"") + )) + } + } + } +} + +#[derive(serde::Serialize)] +struct QueryData<'a> { + query: Cow<'a, str>, + params: Vec, +} + +#[derive(serde::Serialize)] +struct BatchQueryData<'a> { + queries: Vec>, +} + +async fn make_local_proxy_request( + client: &mut http_conn_pool::Client, + headers: impl IntoIterator, + body: QueryData<'_>, + max_len: usize, +) -> Result { + let body_string = serde_json::to_string(&body) + .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?; + + let response = make_raw_local_proxy_request(client, headers, body_string).await?; + + let response_status = response.status(); + + if response_status != StatusCode::OK { + return Err(RestError::SubzeroCore(InternalError { + message: "Failed to get endpoint schema".to_string(), + })); + } + + // Capture the response body + let response_body = crate::http::read_body_with_limit(response.into_body(), max_len) + .await + .map_err(ReadPayloadError::from)?; + + // Parse the JSON response + let response_json: S = serde_json::from_slice(&response_body) + .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?; + + Ok(response_json) +} + +async fn make_raw_local_proxy_request( + client: &mut http_conn_pool::Client, + headers: impl IntoIterator, + body: String, +) -> Result, RestError> { + let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql"); + let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri); + let req_headers = req.headers_mut().expect("failed to get headers"); + // Add all provided headers to the request + for (header_name, header_value) in headers { + req_headers.insert(header_name, header_value.clone()); + } + + let body_boxed = Full::new(Bytes::from(body)) + .map_err(|never| match never {}) // Convert Infallible to hyper::Error + .boxed(); + + let req = req.body(body_boxed).map_err(|_| { + RestError::SubzeroCore(InternalError { + message: "Failed to build request".to_string(), + }) + })?; + + // Send the request to the local proxy + client + .inner + .inner + .send_request(req) + .await + .map_err(LocalProxyConnError::from) + .map_err(HttpConnError::from) + .map_err(RestError::from) +} + +pub(crate) async fn handle( + config: &'static ProxyConfig, + ctx: RequestContext, + request: Request, + backend: Arc, + cancel: CancellationToken, +) -> Result>, ApiError> { + let result = handle_inner(cancel, config, &ctx, request, backend).await; + + let mut response = match result { + Ok(r) => { + ctx.set_success(); + + // Handling the error response from local proxy here + if r.status().is_server_error() { + let status = r.status(); + + let body_bytes = r + .collect() + .await + .map_err(|e| { + ApiError::InternalServerError(anyhow::Error::msg(format!( + "could not collect http body: {e}" + ))) + })? + .to_bytes(); + + if let Ok(mut json_map) = + serde_json::from_slice::>(&body_bytes) + { + let message = json_map.get("message"); + if let Some(message) = message { + let msg: String = match serde_json::from_str(message.get()) { + Ok(msg) => msg, + Err(_) => { + "Unable to parse the response message from server".to_string() + } + }; + + error!("Error response from local_proxy: {status} {msg}"); + + json_map.retain(|key, _| !key.starts_with("neon:")); // remove all the neon-related keys + + let resp_json = serde_json::to_string(&json_map) + .unwrap_or("failed to serialize the response message".to_string()); + + return json_response(status, resp_json); + } + } + + error!("Unable to parse the response message from local_proxy"); + return json_response( + status, + json!({ "message": "Unable to parse the response message from server".to_string() }), + ); + } + r + } + Err(e @ RestError::SubzeroCore(_)) => { + let error_kind = e.get_error_kind(); + ctx.set_error_kind(error_kind); + + tracing::info!( + kind=error_kind.to_metric_label(), + error=%e, + msg="subzero core error", + "forwarding error to user" + ); + + let RestError::SubzeroCore(subzero_err) = e else { + panic!("expected subzero core error") + }; + + let json_body = subzero_err.json_body(); + let status_code = StatusCode::from_u16(subzero_err.status_code()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + + json_response(status_code, json_body)? + } + Err(e) => { + let error_kind = e.get_error_kind(); + ctx.set_error_kind(error_kind); + + let message = e.to_string_client(); + let status_code = e.get_http_status_code(); + + tracing::info!( + kind=error_kind.to_metric_label(), + error=%e, + msg=message, + "forwarding error to user" + ); + + let (code, detail, hint) = match e { + RestError::Postgres(e) => ( + if e.code.starts_with("PT") { + None + } else { + Some(e.code) + }, + e.detail, + e.hint, + ), + _ => (None, None, None), + }; + + json_response( + status_code, + json!({ + "message": message, + "code": code, + "detail": detail, + "hint": hint, + }), + )? + } + }; + + response + .headers_mut() + .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*")); + Ok(response) +} + +async fn handle_inner( + _cancel: CancellationToken, + config: &'static ProxyConfig, + ctx: &RequestContext, + request: Request, + backend: Arc, +) -> Result>, RestError> { + let _requeset_gauge = Metrics::get() + .proxy + .connection_requests + .guard(ctx.protocol()); + info!( + protocol = %ctx.protocol(), + "handling interactive connection from client" + ); + + // Read host from Host, then URI host as fallback + // TODO: will this be a problem if behind a load balancer? + // TODO: can we use the x-forwarded-host header? + let host = request + .headers() + .get(HOST) + .and_then(|v| v.to_str().ok()) + .unwrap_or_else(|| request.uri().host().unwrap_or("")); + + // a valid path is /database/rest/v1/... so splitting should be ["", "database", "rest", "v1", ...] + let database_name = request + .uri() + .path() + .split('/') + .nth(1) + .ok_or(RestError::SubzeroCore(NotFound { + target: request.uri().path().to_string(), + }))?; + + // we always use the authenticator role to connect to the database + let authenticator_role = "authenticator"; + + // Strip the hostname prefix from the host to get the database hostname + let database_host = host.replace(&config.rest_config.hostname_prefix, ""); + + let connection_string = + format!("postgresql://{authenticator_role}@{database_host}/{database_name}"); + + let conn_info = get_conn_info( + &config.authentication_config, + ctx, + Some(&connection_string), + request.headers(), + )?; + info!( + user = conn_info.conn_info.user_info.user.as_str(), + "credentials" + ); + + match conn_info.auth { + AuthData::Jwt(jwt) => { + let api_prefix = format!("/{database_name}/rest/v1/"); + handle_rest_inner( + config, + ctx, + &api_prefix, + request, + &connection_string, + conn_info.conn_info, + jwt, + backend, + ) + .await + } + AuthData::Password(_) => Err(RestError::ConnInfo(ConnInfoError::MissingCredentials( + Credentials::BearerJwt, + ))), + } +} + +#[allow(clippy::too_many_arguments)] +async fn handle_rest_inner( + config: &'static ProxyConfig, + ctx: &RequestContext, + api_prefix: &str, + request: Request, + connection_string: &str, + conn_info: ConnInfo, + jwt: String, + backend: Arc, +) -> Result>, RestError> { + // validate the jwt token + let jwt_parsed = backend + .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) + .await + .map_err(HttpConnError::from)?; + + let db_schema_cache = + config + .rest_config + .db_schema_cache + .as_ref() + .ok_or(RestError::SubzeroCore(InternalError { + message: "DB schema cache is not configured".to_string(), + }))?; + + let endpoint_cache_key = conn_info + .endpoint_cache_key() + .ok_or(RestError::SubzeroCore(InternalError { + message: "Failed to get endpoint cache key".to_string(), + }))?; + + let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?; + + let (parts, originial_body) = request.into_parts(); + + let auth_header = parts + .headers + .get(AUTHORIZATION) + .ok_or(RestError::SubzeroCore(InternalError { + message: "Authorization header is required".to_string(), + }))?; + + let entry = db_schema_cache + .get_cached_or_remote( + &endpoint_cache_key, + auth_header, + connection_string, + &mut client, + ctx, + config, + ) + .await?; + let (api_config, db_schema_owned) = entry.as_ref(); + let db_schema = db_schema_owned.borrow_schema(); + + let db_schemas = &api_config.db_schemas; // list of schemas available for the api + let db_extra_search_path = &api_config.db_extra_search_path; + // TODO: use this when we get a replacement for jsonpath_lib + // let role_claim_key = &api_config.role_claim_key; + // let role_claim_path = format!("${role_claim_key}"); + let db_anon_role = &api_config.db_anon_role; + let max_rows = api_config.db_max_rows.as_deref(); + let db_allowed_select_functions = api_config + .db_allowed_select_functions + .iter() + .map(|s| s.as_str()) + .collect::>(); + + // extract the jwt claims (we'll need them later to set the role and env) + let jwt_claims = match jwt_parsed.keys { + ComputeCredentialKeys::JwtPayload(payload_bytes) => { + // `payload_bytes` contains the raw JWT payload as Vec + // You can deserialize it back to JSON or parse specific claims + let payload: serde_json::Value = serde_json::from_slice(&payload_bytes) + .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?; + Some(payload) + } + ComputeCredentialKeys::AuthKeys(_) => None, + }; + + // read the role from the jwt claims (and set it to the "anon" role if not present) + let (role, authenticated) = match &jwt_claims { + Some(claims) => match claims.get("role") { + Some(JsonValue::String(r)) => (Some(r), true), + _ => (db_anon_role.as_ref(), true), + }, + None => (db_anon_role.as_ref(), false), + }; + + // do not allow unauthenticated requests when there is no anonymous role setup + if let (None, false) = (role, authenticated) { + return Err(RestError::SubzeroCore(JwtTokenInvalid { + message: "unauthenticated requests not allowed".to_string(), + })); + } + + // start deconstructing the request because subzero core mostly works with &str + let method = parts.method; + let method_str = method.as_str(); + let path = parts.uri.path_and_query().map_or("/", |pq| pq.as_str()); + + // this is actually the table name (or rpc/function_name) + // TODO: rename this to something more descriptive + let root = match parts.uri.path().strip_prefix(api_prefix) { + Some(p) => Ok(p), + None => Err(RestError::SubzeroCore(NotFound { + target: parts.uri.path().to_string(), + })), + }?; + + // pick the current schema from the headers (or the first one from config) + let schema_name = &DbSchema::pick_current_schema(db_schemas, method_str, &parts.headers)?; + + // add the content-profile header to the response + let mut response_headers = vec![]; + if db_schemas.len() > 1 { + response_headers.push(("Content-Profile".to_string(), schema_name.clone())); + } + + // parse the query string into a Vec<(&str, &str)> + let query = match parts.uri.query() { + Some(q) => form_urlencoded::parse(q.as_bytes()).collect(), + None => vec![], + }; + let get: Vec<(&str, &str)> = query.iter().map(|(k, v)| (&**k, &**v)).collect(); + + // convert the headers map to a HashMap<&str, &str> + let headers: HashMap<&str, &str> = parts + .headers + .iter() + .map(|(k, v)| (k.as_str(), v.to_str().unwrap_or("__BAD_HEADER__"))) + .collect(); + + let cookies = HashMap::new(); // TODO: add cookies + + // Read the request body (skip for GET requests) + let body_as_string: Option = if method == Method::GET { + None + } else { + let body_bytes = + read_body_with_limit(originial_body, config.http_config.max_request_size_bytes) + .await + .map_err(ReadPayloadError::from)?; + if body_bytes.is_empty() { + None + } else { + Some(String::from_utf8_lossy(&body_bytes).into_owned()) + } + }; + + // parse the request into an ApiRequest struct + let mut api_request = parse( + schema_name, + root, + db_schema, + method_str, + path, + get, + body_as_string.as_deref(), + headers, + cookies, + max_rows, + ) + .map_err(RestError::SubzeroCore)?; + + let role_str = match role { + Some(r) => r, + None => "", + }; + + replace_select_star(db_schema, schema_name, role_str, &mut api_request.query)?; + + // TODO: this is not relevant when acting as PostgREST but will be useful + // in the context of DBX where they need internal permissions + // if !disable_internal_permissions { + // check_privileges(db_schema, schema_name, role_str, &api_request)?; + // } + + check_safe_functions(&api_request, &db_allowed_select_functions)?; + + // TODO: this is not relevant when acting as PostgREST but will be useful + // in the context of DBX where they need internal permissions + // if !disable_internal_permissions { + // insert_policy_conditions(db_schema, schema_name, role_str, &mut api_request.query)?; + // } + + let env_role = Some(role_str); + + // construct the env (passed in to the sql context as GUCs) + let empty_json = "{}".to_string(); + let headers_env = serde_json::to_string(&api_request.headers).unwrap_or(empty_json.clone()); + let cookies_env = serde_json::to_string(&api_request.cookies).unwrap_or(empty_json.clone()); + let get_env = serde_json::to_string(&api_request.get).unwrap_or(empty_json.clone()); + let jwt_claims_env = jwt_claims + .as_ref() + .map(|v| serde_json::to_string(v).unwrap_or(empty_json.clone())) + .unwrap_or(if let Some(r) = env_role { + let claims: HashMap<&str, &str> = HashMap::from([("role", r)]); + serde_json::to_string(&claims).unwrap_or(empty_json.clone()) + } else { + empty_json.clone() + }); + let mut search_path = vec![api_request.schema_name]; + if let Some(extra) = &db_extra_search_path { + search_path.extend(extra.iter().map(|s| s.as_str())); + } + let search_path_str = search_path + .into_iter() + .filter(|s| !s.is_empty()) + .collect::>() + .join(","); + let mut env: HashMap<&str, &str> = HashMap::from([ + ("request.method", api_request.method), + ("request.path", api_request.path), + ("request.headers", &headers_env), + ("request.cookies", &cookies_env), + ("request.get", &get_env), + ("request.jwt.claims", &jwt_claims_env), + ("search_path", &search_path_str), + ]); + if let Some(r) = env_role { + env.insert("role", r); + } + + // generate the sql statements + let (env_statement, env_parameters, _) = generate(fmt_env_query(&env)); + let (main_statement, main_parameters, _) = generate(fmt_main_query( + db_schema, + api_request.schema_name, + &api_request, + &env, + )?); + + let mut headers = vec![ + (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())), + ( + &CONN_STRING, + HeaderValue::from_str(connection_string).expect("invalid connection string"), + ), + (&AUTHORIZATION, auth_header.clone()), + ( + &TXN_ISOLATION_LEVEL, + HeaderValue::from_static("ReadCommitted"), + ), + (&ALLOW_POOL, HEADER_VALUE_TRUE), + ]; + + if api_request.read_only { + headers.push((&TXN_READ_ONLY, HEADER_VALUE_TRUE)); + } + + // convert the parameters from subzero core representation to the local proxy repr. + let req_body = serde_json::to_string(&BatchQueryData { + queries: vec![ + QueryData { + query: env_statement.into(), + params: env_parameters + .iter() + .map(|p| to_sql_param(&p.to_param())) + .collect(), + }, + QueryData { + query: main_statement.into(), + params: main_parameters + .iter() + .map(|p| to_sql_param(&p.to_param())) + .collect(), + }, + ], + }) + .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?; + + // todo: map body to count egress + let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly? + + // send the request to the local proxy + let response = make_raw_local_proxy_request(&mut client, headers, req_body).await?; + let (parts, body) = response.into_parts(); + + let max_response = config.http_config.max_response_size_bytes; + let bytes = read_body_with_limit(body, max_response) + .await + .map_err(ReadPayloadError::from)?; + + // if the response status is greater than 399, then it is an error + // FIXME: check if there are other error codes or shapes of the response + if parts.status.as_u16() > 399 { + // turn this postgres error from the json into PostgresError + let postgres_error = serde_json::from_slice(&bytes) + .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?; + + return Err(RestError::Postgres(postgres_error)); + } + + #[derive(Deserialize)] + struct QueryResults { + /// we run two queries, so we want only two results. + results: (EnvRows, MainRows), + } + + /// `env_statement` returns nothing of interest to us + #[derive(Deserialize)] + struct EnvRows {} + + #[derive(Deserialize)] + struct MainRows { + /// `main_statement` only returns a single row. + rows: [MainRow; 1], + } + + #[derive(Deserialize)] + struct MainRow { + body: String, + page_total: Option, + total_result_set: Option, + response_headers: Option, + response_status: Option, + } + + let results: QueryResults = serde_json::from_slice(&bytes) + .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?; + + let QueryResults { + results: (_, MainRows { rows: [row] }), + } = results; + + // build the intermediate response object + let api_response = ApiResponse { + page_total: row.page_total.map_or(0, |v| v.parse::().unwrap_or(0)), + total_result_set: row.total_result_set.map(|v| v.parse::().unwrap_or(0)), + top_level_offset: 0, // FIXME: check why this is 0 + response_headers: row.response_headers, + response_status: row.response_status, + body: row.body, + }; + + // TODO: rollback the transaction if the page_total is not 1 and the accept_content_type is SingularJSON + // we can not do this in the context of proxy for now + // if api_request.accept_content_type == SingularJSON && api_response.page_total != 1 { + // // rollback the transaction here + // return Err(RestError::SubzeroCore(SingularityError { + // count: api_response.page_total, + // content_type: "application/vnd.pgrst.object+json".to_string(), + // })); + // } + + // TODO: rollback the transaction if the page_total is not 1 and the method is PUT + // we can not do this in the context of proxy for now + // if api_request.method == Method::PUT && api_response.page_total != 1 { + // // Makes sure the querystring pk matches the payload pk + // // e.g. PUT /items?id=eq.1 { "id" : 1, .. } is accepted, + // // PUT /items?id=eq.14 { "id" : 2, .. } is rejected. + // // If this condition is not satisfied then nothing is inserted, + // // rollback the transaction here + // return Err(RestError::SubzeroCore(PutMatchingPkError)); + // } + + // create and return the response to the client + // this section mostly deals with setting the right headers according to PostgREST specs + let page_total = api_response.page_total; + let total_result_set = api_response.total_result_set; + let top_level_offset = api_response.top_level_offset; + let response_content_type = match (&api_request.accept_content_type, &api_request.query.node) { + (SingularJSON, _) + | ( + _, + FunctionCall { + returns_single: true, + is_scalar: false, + .. + }, + ) => SingularJSON, + (TextCSV, _) => TextCSV, + _ => ApplicationJSON, + }; + + // check if the SQL env set some response headers (happens when we called a rpc function) + if let Some(response_headers_str) = api_response.response_headers { + let Ok(headers_json) = + serde_json::from_str::>>(response_headers_str.as_str()) + else { + return Err(RestError::SubzeroCore(GucHeadersError)); + }; + + response_headers.extend(headers_json.into_iter().flatten()); + } + + // calculate and set the content range header + let lower = top_level_offset as i64; + let upper = top_level_offset as i64 + page_total as i64 - 1; + let total = total_result_set.map(|t| t as i64); + let content_range = match (&method, &api_request.query.node) { + (&Method::POST, Insert { .. }) => content_range_header(1, 0, total), + (&Method::DELETE, Delete { .. }) => content_range_header(1, upper, total), + _ => content_range_header(lower, upper, total), + }; + response_headers.push(("Content-Range".to_string(), content_range)); + + // calculate the status code + #[rustfmt::skip] + let mut status = match (&method, &api_request.query.node, page_total, &api_request.preferences) { + (&Method::POST, Insert { .. }, ..) => 201, + (&Method::DELETE, Delete { .. }, _, Some(Preferences {representation: Some(Representation::Full),..}),) => 200, + (&Method::DELETE, Delete { .. }, ..) => 204, + (&Method::PATCH, Update { columns, .. }, 0, _) if !columns.is_empty() => 404, + (&Method::PATCH, Update { .. }, _,Some(Preferences {representation: Some(Representation::Full),..}),) => 200, + (&Method::PATCH, Update { .. }, ..) => 204, + (&Method::PUT, Insert { .. },_,Some(Preferences {representation: Some(Representation::Full),..}),) => 200, + (&Method::PUT, Insert { .. }, ..) => 204, + _ => content_range_status(lower, upper, total), + }; + + // add the preference-applied header + if let Some(Preferences { + resolution: Some(r), + .. + }) = api_request.preferences + { + response_headers.push(( + "Preference-Applied".to_string(), + match r { + MergeDuplicates => "resolution=merge-duplicates".to_string(), + IgnoreDuplicates => "resolution=ignore-duplicates".to_string(), + }, + )); + } + + // check if the SQL env set some response status (happens when we called a rpc function) + if let Some(response_status_str) = api_response.response_status { + status = response_status_str + .parse::() + .map_err(|_| RestError::SubzeroCore(GucStatusError))?; + } + + // set the content type header + // TODO: move this to a subzero function + // as_header_value(&self) -> Option<&str> + let http_content_type = match response_content_type { + SingularJSON => Ok("application/vnd.pgrst.object+json"), + TextCSV => Ok("text/csv"), + ApplicationJSON => Ok("application/json"), + Other(t) => Err(RestError::SubzeroCore(ContentTypeError { + message: format!("None of these Content-Types are available: {t}"), + })), + }?; + + // build the response body + let response_body = Full::new(Bytes::from(api_response.body)) + .map_err(|never| match never {}) + .boxed(); + + // build the response + let mut response = Response::builder() + .status(StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) + .header(CONTENT_TYPE, http_content_type); + + // Add all headers from response_headers vector + for (header_name, header_value) in response_headers { + response = response.header(header_name, header_value); + } + + // add the body and return the response + response.body(response_body).map_err(|_| { + RestError::SubzeroCore(InternalError { + message: "Failed to build response".to_string(), + }) + }) +} diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 8a14f804b6..f254b41b5b 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -64,7 +64,7 @@ enum Payload { Batch(BatchQueryData), } -static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true"); +pub(super) const HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true"); fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result>, D::Error> where diff --git a/proxy/src/util.rs b/proxy/src/util.rs index 0291216d94..c89ebab008 100644 --- a/proxy/src/util.rs +++ b/proxy/src/util.rs @@ -20,3 +20,13 @@ pub async fn run_until( Either::Right((f2, _)) => Err(f2), } } + +pub fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result +where + T: for<'de2> serde::Deserialize<'de2>, + D: serde::Deserializer<'de>, +{ + use serde::Deserialize; + let s = String::deserialize(deserializer)?; + serde_json::from_str(&s).map_err(::custom) +} diff --git a/proxy/subzero_core/.gitignore b/proxy/subzero_core/.gitignore new file mode 100644 index 0000000000..f2f9e58ec3 --- /dev/null +++ b/proxy/subzero_core/.gitignore @@ -0,0 +1,2 @@ +target +Cargo.lock \ No newline at end of file diff --git a/proxy/subzero_core/Cargo.toml b/proxy/subzero_core/Cargo.toml new file mode 100644 index 0000000000..13185873d0 --- /dev/null +++ b/proxy/subzero_core/Cargo.toml @@ -0,0 +1,12 @@ +# This is a stub for the subzero-core crate. +[package] +name = "subzero-core" +version = "3.0.1" +edition = "2024" +publish = false # "private"! + +[features] +default = [] +postgresql = [] + +[dependencies] diff --git a/proxy/subzero_core/src/lib.rs b/proxy/subzero_core/src/lib.rs new file mode 100644 index 0000000000..b99246b98b --- /dev/null +++ b/proxy/subzero_core/src/lib.rs @@ -0,0 +1 @@ +// This is a stub for the subzero-core crate. diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 86ffa9e4d4..1ce34a2c4e 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -4121,6 +4121,294 @@ class NeonAuthBroker: self._popen.kill() +class NeonLocalProxy(LogUtils): + """ + An object managing a local_proxy instance for rest broker testing. + The local_proxy serves as a direct connection to VanillaPostgres. + """ + + def __init__( + self, + neon_binpath: Path, + test_output_dir: Path, + http_port: int, + metrics_port: int, + vanilla_pg: VanillaPostgres, + config_path: Path | None = None, + ): + self.neon_binpath = neon_binpath + self.test_output_dir = test_output_dir + self.http_port = http_port + self.metrics_port = metrics_port + self.vanilla_pg = vanilla_pg + self.config_path = config_path or (test_output_dir / "local_proxy.json") + self.host = "127.0.0.1" + self.running = False + self.logfile = test_output_dir / "local_proxy.log" + self._popen: subprocess.Popen[bytes] | None = None + super().__init__(logfile=self.logfile) + + def start(self) -> Self: + assert self._popen is None + assert not self.running + + # Ensure vanilla_pg is running + if not self.vanilla_pg.is_running(): + self.vanilla_pg.start() + + args = [ + str(self.neon_binpath / "local_proxy"), + "--http", + f"{self.host}:{self.http_port}", + "--metrics", + f"{self.host}:{self.metrics_port}", + "--postgres", + f"127.0.0.1:{self.vanilla_pg.default_options['port']}", + "--config-path", + str(self.config_path), + "--disable-pg-session-jwt", + ] + + logfile = open(self.logfile, "w") + self._popen = subprocess.Popen(args, stdout=logfile, stderr=logfile) + self.running = True + self._wait_until_ready() + return self + + def stop(self) -> Self: + if self._popen is not None and self.running: + self._popen.terminate() + try: + self._popen.wait(timeout=5) + except subprocess.TimeoutExpired: + log.warning("failed to gracefully terminate local_proxy; killing") + self._popen.kill() + self.running = False + return self + + def get_binary_version(self) -> str: + """Get the version string of the local_proxy binary""" + try: + result = subprocess.run( + [str(self.neon_binpath / "local_proxy"), "--version"], + capture_output=True, + text=True, + timeout=10, + ) + return result.stdout.strip() + except (subprocess.TimeoutExpired, subprocess.CalledProcessError): + return "" + + @backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=10) + def _wait_until_ready(self): + assert self._popen and self._popen.poll() is None, ( + "Local proxy exited unexpectedly. Check test log." + ) + requests.get(f"http://{self.host}:{self.http_port}/metrics") + + def get_metrics(self) -> str: + response = requests.get(f"http://{self.host}:{self.metrics_port}/metrics") + return response.text + + def assert_no_errors(self): + # Define allowed error patterns for local_proxy + allowed_errors = [ + # Add patterns as needed + ] + not_allowed = [ + "error", + "panic", + "failed", + ] + + for na in not_allowed: + if na not in allowed_errors: + assert not self.log_contains(na), f"Found disallowed error pattern: {na}" + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ): + self.stop() + + +class NeonRestBrokerProxy(LogUtils): + """ + An object managing a proxy instance configured as both auth broker and rest broker. + This is the main proxy binary with --is-auth-broker and --is-rest-broker flags. + """ + + def __init__( + self, + neon_binpath: Path, + test_output_dir: Path, + wss_port: int, + http_port: int, + mgmt_port: int, + config_path: Path | None = None, + ): + self.neon_binpath = neon_binpath + self.test_output_dir = test_output_dir + self.wss_port = wss_port + self.http_port = http_port + self.mgmt_port = mgmt_port + self.config_path = config_path or (test_output_dir / "rest_broker_proxy.json") + self.host = "127.0.0.1" + self.running = False + self.logfile = test_output_dir / "rest_broker_proxy.log" + self._popen: subprocess.Popen[Any] | None = None + + def start(self) -> Self: + if self.running: + return self + + # Generate self-signed TLS certificates + cert_path = self.test_output_dir / "server.crt" + key_path = self.test_output_dir / "server.key" + + if not cert_path.exists() or not key_path.exists(): + import subprocess + + log.info("Generating self-signed TLS certificate for rest broker") + subprocess.run( + [ + "openssl", + "req", + "-new", + "-x509", + "-days", + "365", + "-nodes", + "-text", + "-out", + str(cert_path), + "-keyout", + str(key_path), + "-subj", + "/CN=*.local.neon.build", + ], + check=True, + ) + + log.info( + f"Starting rest broker proxy on WSS port {self.wss_port}, HTTP port {self.http_port}" + ) + + cmd = [ + str(self.neon_binpath / "proxy"), + "-c", + str(cert_path), + "-k", + str(key_path), + "--is-auth-broker", + "true", + "--is-rest-broker", + "true", + "--wss", + f"{self.host}:{self.wss_port}", + "--http", + f"{self.host}:{self.http_port}", + "--mgmt", + f"{self.host}:{self.mgmt_port}", + "--auth-backend", + "local", + "--config-path", + str(self.config_path), + ] + + log.info(f"Starting rest broker proxy with command: {' '.join(cmd)}") + + with open(self.logfile, "w") as logfile: + self._popen = subprocess.Popen( + cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + cwd=self.test_output_dir, + env={ + **os.environ, + "RUST_LOG": "info", + "LOGFMT": "text", + "OTEL_SDK_DISABLED": "true", + }, + ) + + self.running = True + self._wait_until_ready() + return self + + def stop(self) -> Self: + if not self.running: + return self + + log.info("Stopping rest broker proxy") + + if self._popen is not None: + self._popen.terminate() + try: + self._popen.wait(timeout=10) + except subprocess.TimeoutExpired: + log.warning("failed to gracefully terminate rest broker proxy; killing") + self._popen.kill() + + self.running = False + return self + + def get_binary_version(self) -> str: + cmd = [str(self.neon_binpath / "proxy"), "--version"] + res = subprocess.run(cmd, capture_output=True, text=True, check=True) + return res.stdout.strip() + + @backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=10) + def _wait_until_ready(self): + # Check if the WSS port is ready using a simple HTTPS request + # REST API is served on the WSS port with HTTPS + requests.get(f"https://{self.host}:{self.wss_port}/", timeout=1, verify=False) + # Any response (even error) means the server is up - we just need to connect + + def get_metrics(self) -> str: + # Metrics are still on the HTTP port + response = requests.get(f"http://{self.host}:{self.http_port}/metrics", timeout=5) + response.raise_for_status() + return response.text + + def assert_no_errors(self): + # Define allowed error patterns for rest broker proxy + allowed_errors = [ + "connection closed before message completed", + "connection reset by peer", + "broken pipe", + "client disconnected", + "Authentication failed", + "connection timed out", + "no connection available", + "Pool dropped", + ] + + with open(self.logfile) as f: + for line in f: + if "ERROR" in line or "FATAL" in line: + if not any(allowed in line for allowed in allowed_errors): + raise AssertionError( + f"Found error in rest broker proxy log: {line.strip()}" + ) + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ): + self.stop() + + @pytest.fixture(scope="function") def link_proxy( port_distributor: PortDistributor, neon_binpath: Path, test_output_dir: Path @@ -4203,6 +4491,81 @@ def static_proxy( yield proxy +@pytest.fixture(scope="function") +def local_proxy( + vanilla_pg: VanillaPostgres, + port_distributor: PortDistributor, + neon_binpath: Path, + test_output_dir: Path, +) -> Iterator[NeonLocalProxy]: + """Local proxy that connects directly to vanilla postgres for rest broker testing.""" + + # Start vanilla_pg without database bootstrapping + vanilla_pg.start() + + http_port = port_distributor.get_port() + metrics_port = port_distributor.get_port() + + with NeonLocalProxy( + neon_binpath=neon_binpath, + test_output_dir=test_output_dir, + http_port=http_port, + metrics_port=metrics_port, + vanilla_pg=vanilla_pg, + ) as proxy: + proxy.start() + yield proxy + + +@pytest.fixture(scope="function") +def local_proxy_fixed_port( + vanilla_pg: VanillaPostgres, + neon_binpath: Path, + test_output_dir: Path, +) -> Iterator[NeonLocalProxy]: + """Local proxy that connects directly to vanilla postgres on the hardcoded port 7432.""" + + # Start vanilla_pg without database bootstrapping + vanilla_pg.start() + + # Use the hardcoded port that the rest broker proxy expects + http_port = 7432 + metrics_port = 7433 # Use a different port for metrics + + with NeonLocalProxy( + neon_binpath=neon_binpath, + test_output_dir=test_output_dir, + http_port=http_port, + metrics_port=metrics_port, + vanilla_pg=vanilla_pg, + ) as proxy: + proxy.start() + yield proxy + + +@pytest.fixture(scope="function") +def rest_broker_proxy( + port_distributor: PortDistributor, + neon_binpath: Path, + test_output_dir: Path, +) -> Iterator[NeonRestBrokerProxy]: + """Rest broker proxy that handles both auth broker and rest broker functionality.""" + + wss_port = port_distributor.get_port() + http_port = port_distributor.get_port() + mgmt_port = port_distributor.get_port() + + with NeonRestBrokerProxy( + neon_binpath=neon_binpath, + test_output_dir=test_output_dir, + wss_port=wss_port, + http_port=http_port, + mgmt_port=mgmt_port, + ) as proxy: + proxy.start() + yield proxy + + @pytest.fixture(scope="function") def neon_authorize_jwk() -> jwk.JWK: kid = str(uuid.uuid4()) diff --git a/test_runner/fixtures/utils.py b/test_runner/fixtures/utils.py index 0d7345cc82..1f80c2a290 100644 --- a/test_runner/fixtures/utils.py +++ b/test_runner/fixtures/utils.py @@ -741,3 +741,29 @@ def shared_buffers_for_max_cu(max_cu: float) -> str: sharedBuffersMb = int(max(128, (1023 + maxBackends * 256) / 1024)) sharedBuffers = int(sharedBuffersMb * 1024 / 8) return str(sharedBuffers) + + +def skip_if_proxy_lacks_rest_broker(reason: str = "proxy was built without 'rest_broker' feature"): + # Determine the binary path using the same logic as neon_binpath fixture + def has_rest_broker_feature(): + # Find the neon binaries + if env_neon_bin := os.environ.get("NEON_BIN"): + binpath = Path(env_neon_bin) + else: + base_dir = Path(__file__).parents[2] # Same as BASE_DIR in paths.py + build_type = os.environ.get("BUILD_TYPE", "debug") + binpath = base_dir / "target" / build_type + + proxy_bin = binpath / "proxy" + if not proxy_bin.exists(): + return False + + try: + cmd = [str(proxy_bin), "--help"] + result = subprocess.run(cmd, capture_output=True, text=True, check=True, timeout=10) + help_output = result.stdout + return "--is-rest-broker" in help_output + except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError): + return False + + return pytest.mark.skipif(not has_rest_broker_feature(), reason=reason) diff --git a/test_runner/regress/test_rest_broker.py b/test_runner/regress/test_rest_broker.py new file mode 100644 index 0000000000..60b04655d3 --- /dev/null +++ b/test_runner/regress/test_rest_broker.py @@ -0,0 +1,137 @@ +import json +import signal +import time + +import requests +from fixtures.utils import skip_if_proxy_lacks_rest_broker +from jwcrypto import jwt + + +@skip_if_proxy_lacks_rest_broker() +def test_rest_broker_happy( + local_proxy_fixed_port, rest_broker_proxy, vanilla_pg, neon_authorize_jwk, httpserver +): + """Test REST API endpoint using local_proxy and rest_broker_proxy.""" + + # Use the fixed port local proxy + local_proxy = local_proxy_fixed_port + + # Create the required roles for PostgREST authentication + vanilla_pg.safe_psql("CREATE ROLE authenticator LOGIN") + vanilla_pg.safe_psql("CREATE ROLE authenticated") + vanilla_pg.safe_psql("CREATE ROLE anon") + vanilla_pg.safe_psql("GRANT authenticated TO authenticator") + vanilla_pg.safe_psql("GRANT anon TO authenticator") + + # Create the pgrst schema and configuration function required by the rest broker + vanilla_pg.safe_psql("CREATE SCHEMA IF NOT EXISTS pgrst") + vanilla_pg.safe_psql(""" + CREATE OR REPLACE FUNCTION pgrst.pre_config() + RETURNS VOID AS $$ + SELECT + set_config('pgrst.db_schemas', 'test', true) + , set_config('pgrst.db_aggregates_enabled', 'true', true) + , set_config('pgrst.db_anon_role', 'anon', true) + , set_config('pgrst.jwt_aud', '', true) + , set_config('pgrst.jwt_secret', '', true) + , set_config('pgrst.jwt_role_claim_key', '."role"', true) + + $$ LANGUAGE SQL; + """) + vanilla_pg.safe_psql("GRANT USAGE ON SCHEMA pgrst TO authenticator") + vanilla_pg.safe_psql("GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA pgrst TO authenticator") + + # Bootstrap the database with test data + vanilla_pg.safe_psql("CREATE SCHEMA IF NOT EXISTS test") + vanilla_pg.safe_psql(""" + CREATE TABLE IF NOT EXISTS test.items ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL + ) + """) + vanilla_pg.safe_psql("INSERT INTO test.items (name) VALUES ('test_item')") + + # Grant access to the test schema for the authenticated role + vanilla_pg.safe_psql("GRANT USAGE ON SCHEMA test TO authenticated") + vanilla_pg.safe_psql("GRANT SELECT ON ALL TABLES IN SCHEMA test TO authenticated") + + # Set up HTTP server to serve JWKS (like static_auth_broker) + # Generate public key from the JWK + public_key = neon_authorize_jwk.export_public(as_dict=True) + + # Set up the httpserver to serve the JWKS + httpserver.expect_request("/.well-known/jwks.json").respond_with_json({"keys": [public_key]}) + + # Create JWKS configuration for the rest broker proxy + jwks_config = { + "jwks": [ + { + "id": "1", + "role_names": ["authenticator", "authenticated", "anon"], + "jwks_url": httpserver.url_for("/.well-known/jwks.json"), + "provider_name": "foo", + "jwt_audience": None, + } + ] + } + + # Write the JWKS config to the config file that rest_broker_proxy expects + config_file = rest_broker_proxy.config_path + with open(config_file, "w") as f: + json.dump(jwks_config, f) + + # Write the same config to the local_proxy config file + local_config_file = local_proxy.config_path + with open(local_config_file, "w") as f: + json.dump(jwks_config, f) + + # Signal both proxies to reload their config + if rest_broker_proxy._popen is not None: + rest_broker_proxy._popen.send_signal(signal.SIGHUP) + if local_proxy._popen is not None: + local_proxy._popen.send_signal(signal.SIGHUP) + # Wait a bit for config to reload + time.sleep(0.5) + + # Generate a proper JWT token using the JWK (similar to test_auth_broker.py) + token = jwt.JWT( + header={"kid": neon_authorize_jwk.key_id, "alg": "RS256"}, + claims={ + "sub": "user", + "role": "authenticated", # role that's in role_names + "exp": 9999999999, # expires far in the future + "iat": 1000000000, # issued at + }, + ) + token.make_signed_token(neon_authorize_jwk) + + # Debug: Print the JWT claims and config for troubleshooting + print(f"JWT claims: {token.claims}") + print(f"JWT header: {token.header}") + print(f"Config file contains: {jwks_config}") + print(f"Public key kid: {public_key.get('kid')}") + + # Test REST API call - following SUBZERO.md pattern + # REST API is served on the WSS port with HTTPS and includes database name + # ep-purple-glitter-adqior4l-pooler.c-2.us-east-1.aws.neon.tech + url = f"https://foo.apirest.c-2.local.neon.build:{rest_broker_proxy.wss_port}/postgres/rest/v1/items" + + response = requests.get( + url, + headers={ + "Authorization": f"Bearer {token.serialize()}", + }, + params={"id": "eq.1", "select": "name"}, + verify=False, # Skip SSL verification for self-signed certs + ) + + print(f"Response status: {response.status_code}") + print(f"Response headers: {response.headers}") + print(f"Response body: {response.text}") + + # For now, let's just check that we get some response + # We can refine the assertions once we see what the actual response looks like + assert response.status_code in [200] # Any response means the proxies are working + + # check the response body + assert response.json() == [{"name": "test_item"}] From 050c9f704f94b3434fa3dc0a602f3597acf7ac0d Mon Sep 17 00:00:00 2001 From: Ivan Efremov Date: Mon, 21 Jul 2025 23:27:15 +0300 Subject: [PATCH 2/6] proxy: expose session_id to clients and proxy latency to probes (#12656) Implements #8728 --- proxy/src/auth/backend/console_redirect.rs | 2 -- proxy/src/binary/local_proxy.rs | 10 ++++++ proxy/src/binary/proxy.rs | 23 +++++++++++-- proxy/src/config.rs | 1 + proxy/src/console_redirect_proxy.rs | 8 ++++- proxy/src/metrics.rs | 8 ++--- proxy/src/proxy/mod.rs | 40 +++++++++++++++++++++- 7 files changed, 82 insertions(+), 10 deletions(-) diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index f561df9202..b06ed3a0ae 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -180,8 +180,6 @@ async fn authenticate( return Err(auth::AuthError::NetworkNotAllowed); } - client.write_message(BeMessage::NoticeResponse("Connecting to database.")); - // Backwards compatibility. pg_sni_proxy uses "--" in domain names // while direct connections do not. Once we migrate to pg_sni_proxy // everywhere, we can remove this. diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs index e3f7ba4c15..7b9012dc69 100644 --- a/proxy/src/binary/local_proxy.rs +++ b/proxy/src/binary/local_proxy.rs @@ -1,3 +1,4 @@ +use std::env; use std::net::SocketAddr; use std::pin::pin; use std::sync::Arc; @@ -264,6 +265,14 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig timeout: Duration::from_secs(2), }; + let greetings = env::var_os("NEON_MOTD").map_or(String::new(), |s| match s.into_string() { + Ok(s) => s, + Err(_) => { + debug!("NEON_MOTD environment variable is not valid UTF-8"); + String::new() + } + }); + Ok(Box::leak(Box::new(ProxyConfig { tls_config: ArcSwapOption::from(None), metric_collection: None, @@ -290,6 +299,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?, connect_compute_locks, connect_to_compute: compute_config, + greetings, #[cfg(feature = "testing")] disable_pg_session_jwt: args.disable_pg_session_jwt, }))) diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 194a1ed34c..d1dd2fef2a 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -1,4 +1,3 @@ -#[cfg(any(test, feature = "testing"))] use std::env; use std::net::SocketAddr; use std::path::PathBuf; @@ -21,7 +20,7 @@ use tokio::net::TcpListener; use tokio::sync::Notify; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; -use tracing::{error, info, warn}; +use tracing::{debug, error, info, warn}; use utils::sentry_init::init_sentry; use utils::{project_build_tag, project_git_version}; @@ -730,6 +729,25 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { } }; + let mut greetings = env::var_os("NEON_MOTD").map_or(String::new(), |s| match s.into_string() { + Ok(s) => s, + Err(_) => { + debug!("NEON_MOTD environment variable is not valid UTF-8"); + String::new() + } + }); + + match &args.auth_backend { + AuthBackendType::ControlPlane => {} + #[cfg(any(test, feature = "testing"))] + AuthBackendType::Postgres => {} + #[cfg(any(test, feature = "testing"))] + AuthBackendType::Local => {} + AuthBackendType::ConsoleRedirect => { + greetings = "Connected to database".to_string(); + } + } + let config = ProxyConfig { tls_config, metric_collection, @@ -740,6 +758,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?, connect_compute_locks, connect_to_compute: compute_config, + greetings, #[cfg(feature = "testing")] disable_pg_session_jwt: false, #[cfg(feature = "rest_broker")] diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 20bbfd77d8..16b1dff5f4 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -39,6 +39,7 @@ pub struct ProxyConfig { pub wake_compute_retry_config: RetryConfig, pub connect_compute_locks: ApiLocks, pub connect_to_compute: ComputeConfig, + pub greetings: String, // Greeting message sent to the client after connection establishment and contains session_id. #[cfg(feature = "testing")] pub disable_pg_session_jwt: bool, } diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 041a56e032..014317d823 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -233,7 +233,13 @@ pub(crate) async fn handle_client( let session = cancellation_handler.get_key(); - finish_client_init(&pg_settings, *session.key(), &mut stream); + finish_client_init( + ctx, + &pg_settings, + *session.key(), + &mut stream, + &config.greetings, + ); let stream = stream.flush_and_into_inner().await?; let session_id = ctx.session_id(); diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 916604e2ec..7524133093 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -385,10 +385,10 @@ pub enum RedisMsgKind { #[derive(Default, Clone)] pub struct LatencyAccumulated { - cplane: time::Duration, - client: time::Duration, - compute: time::Duration, - retry: time::Duration, + pub cplane: time::Duration, + pub client: time::Duration, + pub compute: time::Duration, + pub retry: time::Duration, } impl std::fmt::Display for LatencyAccumulated { diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 02651109e0..8b7c4ff55d 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -145,7 +145,7 @@ pub(crate) async fn handle_client( let session = cancellation_handler.get_key(); - finish_client_init(&pg_settings, *session.key(), client); + finish_client_init(ctx, &pg_settings, *session.key(), client, &config.greetings); let session_id = ctx.session_id(); let (cancel_on_shutdown, cancel) = oneshot::channel(); @@ -165,9 +165,11 @@ pub(crate) async fn handle_client( /// Finish client connection initialization: confirm auth success, send params, etc. pub(crate) fn finish_client_init( + ctx: &RequestContext, settings: &compute::PostgresSettings, cancel_key_data: CancelKeyData, client: &mut PqStream, + greetings: &String, ) { // Forward all deferred notices to the client. for notice in &settings.delayed_notice { @@ -176,6 +178,12 @@ pub(crate) fn finish_client_init( }); } + // Expose session_id to clients if we have a greeting message. + if !greetings.is_empty() { + let session_msg = format!("{}, session_id: {}", greetings, ctx.session_id()); + client.write_message(BeMessage::NoticeResponse(session_msg.as_str())); + } + // Forward all postgres connection params to the client. for (name, value) in &settings.params { client.write_message(BeMessage::ParameterStatus { @@ -184,6 +192,36 @@ pub(crate) fn finish_client_init( }); } + // Forward recorded latencies for probing requests + if let Some(testodrome_id) = ctx.get_testodrome_id() { + client.write_message(BeMessage::ParameterStatus { + name: "neon.testodrome_id".as_bytes(), + value: testodrome_id.as_bytes(), + }); + + let latency_measured = ctx.get_proxy_latency(); + + client.write_message(BeMessage::ParameterStatus { + name: "neon.cplane_latency".as_bytes(), + value: latency_measured.cplane.as_micros().to_string().as_bytes(), + }); + + client.write_message(BeMessage::ParameterStatus { + name: "neon.client_latency".as_bytes(), + value: latency_measured.client.as_micros().to_string().as_bytes(), + }); + + client.write_message(BeMessage::ParameterStatus { + name: "neon.compute_latency".as_bytes(), + value: latency_measured.compute.as_micros().to_string().as_bytes(), + }); + + client.write_message(BeMessage::ParameterStatus { + name: "neon.retry_latency".as_bytes(), + value: latency_measured.retry.as_micros().to_string().as_bytes(), + }); + } + client.write_message(BeMessage::BackendKeyData(cancel_key_data)); client.write_message(BeMessage::ReadyForQuery); } From b7bc3ce61e6d8b0260f2e6dee299f281168f1bd9 Mon Sep 17 00:00:00 2001 From: Tristan Partin Date: Mon, 21 Jul 2025 15:50:02 -0500 Subject: [PATCH 3/6] Skip PG throttle during configuration (#12670) ## Problem While running tenant split tests I ran into a situation where PG got stuck completely. This seems to be a general problem that was not found in the previous chaos testing fixes. What happened is that if PG gets throttled by PS, and SC decided to move some tenant away, then PG reconfiguration could be blocked forever because it cannot talk to the old PS anymore to refresh the throttling stats, and reconfiguration cannot proceed because it's being throttled. Neon has considered the case that configuration could be blocked if the PG storage is full, but forgot the backpressure case. ## Summary of changes The PR fixes this problem by simply skipping throttling while PS is being configured, i.e., `max_cluster_size < 0`. An alternative fix is to set those throttle knobs to -1 (e.g., max_replication_apply_lag), however these knobs were labeled with PGC_POSTMASTER so their values cannot be changed unless we restart PG. ## How is this tested? Tested manually. Co-authored-by: Chen Luo --- pgxn/neon/walproposer_pg.c | 8 ++++++++ test_runner/regress/test_pg_regress.py | 9 ++++++++- test_runner/regress/test_sharding.py | 2 ++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/pgxn/neon/walproposer_pg.c b/pgxn/neon/walproposer_pg.c index 9ed8d0d2d2..93807be8c2 100644 --- a/pgxn/neon/walproposer_pg.c +++ b/pgxn/neon/walproposer_pg.c @@ -400,6 +400,14 @@ static uint64 backpressure_lag_impl(void) { struct WalproposerShmemState* state = NULL; + + /* BEGIN_HADRON */ + if(max_cluster_size < 0){ + // if max cluster size is not set, then we don't apply backpressure because we're reconfiguring PG + return 0; + } + /* END_HADRON */ + if (max_replication_apply_lag > 0 || max_replication_flush_lag > 0 || max_replication_write_lag > 0) { XLogRecPtr writePtr; diff --git a/test_runner/regress/test_pg_regress.py b/test_runner/regress/test_pg_regress.py index a240071a7f..dd9c5437ad 100644 --- a/test_runner/regress/test_pg_regress.py +++ b/test_runner/regress/test_pg_regress.py @@ -368,7 +368,14 @@ def test_max_wal_rate(neon_simple_env: NeonEnv): superuser_name = "databricks_superuser" # Connect to postgres and create a database called "regression". - endpoint = env.endpoints.create_start("main") + endpoint = env.endpoints.create_start( + "main", + config_lines=[ + # we need this option because default max_cluster_size < 0 will disable throttling completely + "neon.max_cluster_size=10GB", + ], + ) + endpoint.safe_psql_many( [ f"CREATE ROLE {superuser_name}", diff --git a/test_runner/regress/test_sharding.py b/test_runner/regress/test_sharding.py index 5549105188..2252c098c7 100644 --- a/test_runner/regress/test_sharding.py +++ b/test_runner/regress/test_sharding.py @@ -1810,6 +1810,8 @@ def test_sharding_backpressure(neon_env_builder: NeonEnvBuilder): "config_lines": [ # Tip: set to 100MB to make the test fail "max_replication_write_lag=1MB", + # Hadron: Need to set max_cluster_size to some value to enable any backpressure at all. + "neon.max_cluster_size=1GB", ], # We need `neon` extension for calling backpressure functions, # this flag instructs `compute_ctl` to pre-install it. From 80baeaa084e603af3a6e86e41ea45975dafddf74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arpad=20M=C3=BCller?= Date: Mon, 21 Jul 2025 23:14:15 +0200 Subject: [PATCH 4/6] storcon: add force_upsert flag to timeline_import endpoint (#12622) It is useful to have ability to update an existing timeline entry, as a way to mirror legacy migrations to the storcon managed table. --- libs/pageserver_api/src/controller_api.rs | 1 + storage_controller/src/persistence.rs | 43 +++++++++++++++++++ .../src/service/safekeeper_service.rs | 23 ++++++++-- 3 files changed, 63 insertions(+), 4 deletions(-) diff --git a/libs/pageserver_api/src/controller_api.rs b/libs/pageserver_api/src/controller_api.rs index 8f86b03f72..1248be0b5c 100644 --- a/libs/pageserver_api/src/controller_api.rs +++ b/libs/pageserver_api/src/controller_api.rs @@ -596,6 +596,7 @@ pub struct TimelineImportRequest { pub timeline_id: TimelineId, pub start_lsn: Lsn, pub sk_set: Vec, + pub force_upsert: bool, } #[derive(serde::Serialize, serde::Deserialize, Clone)] diff --git a/storage_controller/src/persistence.rs b/storage_controller/src/persistence.rs index ed9a268064..2e3f8c6908 100644 --- a/storage_controller/src/persistence.rs +++ b/storage_controller/src/persistence.rs @@ -129,6 +129,7 @@ pub(crate) enum DatabaseOperation { UpdateLeader, SetPreferredAzs, InsertTimeline, + UpdateTimeline, UpdateTimelineMembership, GetTimeline, InsertTimelineReconcile, @@ -1463,6 +1464,36 @@ impl Persistence { .await } + /// Update an already present timeline. + /// VERY UNSAFE FUNCTION: this overrides in-progress migrations. Don't use this unless neccessary. + pub(crate) async fn update_timeline_unsafe( + &self, + entry: TimelineUpdate, + ) -> DatabaseResult { + use crate::schema::timelines; + + let entry = &entry; + self.with_measured_conn(DatabaseOperation::UpdateTimeline, move |conn| { + Box::pin(async move { + let inserted_updated = diesel::update(timelines::table) + .filter(timelines::tenant_id.eq(&entry.tenant_id)) + .filter(timelines::timeline_id.eq(&entry.timeline_id)) + .set(entry) + .execute(conn) + .await?; + + match inserted_updated { + 0 => Ok(false), + 1 => Ok(true), + _ => Err(DatabaseError::Logical(format!( + "unexpected number of rows ({inserted_updated})" + ))), + } + }) + }) + .await + } + /// Update timeline membership configuration in the database. /// Perform a compare-and-swap (CAS) operation on the timeline's generation. /// The `new_generation` must be the next (+1) generation after the one in the database. @@ -2503,6 +2534,18 @@ impl TimelineFromDb { } } +// This is separate from TimelinePersistence because we don't want to touch generation and deleted_at values for the update. +#[derive(AsChangeset)] +#[diesel(table_name = crate::schema::timelines)] +#[diesel(treat_none_as_null = true)] +pub(crate) struct TimelineUpdate { + pub(crate) tenant_id: String, + pub(crate) timeline_id: String, + pub(crate) start_lsn: LsnWrapper, + pub(crate) sk_set: Vec, + pub(crate) new_sk_set: Option>, +} + #[derive(Insertable, AsChangeset, Queryable, Selectable, Clone)] #[diesel(table_name = crate::schema::safekeeper_timeline_pending_ops)] pub(crate) struct TimelinePendingOpPersistence { diff --git a/storage_controller/src/service/safekeeper_service.rs b/storage_controller/src/service/safekeeper_service.rs index 7521d7bd86..28c70e203a 100644 --- a/storage_controller/src/service/safekeeper_service.rs +++ b/storage_controller/src/service/safekeeper_service.rs @@ -10,6 +10,7 @@ use crate::id_lock_map::trace_shared_lock; use crate::metrics; use crate::persistence::{ DatabaseError, SafekeeperTimelineOpKind, TimelinePendingOpPersistence, TimelinePersistence, + TimelineUpdate, }; use crate::safekeeper::Safekeeper; use crate::safekeeper_client::SafekeeperClient; @@ -454,19 +455,33 @@ impl Service { let persistence = TimelinePersistence { tenant_id: req.tenant_id.to_string(), timeline_id: req.timeline_id.to_string(), - start_lsn: Lsn::INVALID.into(), + start_lsn: req.start_lsn.into(), generation: 1, sk_set: req.sk_set.iter().map(|sk_id| sk_id.0 as i64).collect(), new_sk_set: None, cplane_notified_generation: 1, deleted_at: None, }; - let inserted = self.persistence.insert_timeline(persistence).await?; + let inserted = self + .persistence + .insert_timeline(persistence.clone()) + .await?; if inserted { tracing::info!("imported timeline into db"); - } else { - tracing::info!("didn't import timeline into db, as it is already present in db"); + return Ok(()); } + tracing::info!("timeline already present in db, updating"); + + let update = TimelineUpdate { + tenant_id: persistence.tenant_id, + timeline_id: persistence.timeline_id, + start_lsn: persistence.start_lsn, + sk_set: persistence.sk_set, + new_sk_set: persistence.new_sk_set, + }; + self.persistence.update_timeline_unsafe(update).await?; + tracing::info!("timeline updated"); + Ok(()) } From 5464552020ecc18d2367b4396363abbcb7e8ea02 Mon Sep 17 00:00:00 2001 From: Konstantin Knizhnik Date: Tue, 22 Jul 2025 07:39:54 +0300 Subject: [PATCH 5/6] Limit number of parallel config apply connections to 100 (#12663) ## Problem See https://databricks.slack.com/archives/C092W8NBXC0/p1752924508578339 In case of larger number of databases and large `max_connections` we can open too many connection for parallel apply config which may cause `Too many open files` error. ## Summary of changes Limit maximal number of parallel config apply connections by 100. --------- Co-authored-by: Kosntantin Knizhnik --- compute_tools/src/spec_apply.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/compute_tools/src/spec_apply.rs b/compute_tools/src/spec_apply.rs index ec7e75922b..47bf61ae1b 100644 --- a/compute_tools/src/spec_apply.rs +++ b/compute_tools/src/spec_apply.rs @@ -411,7 +411,8 @@ impl ComputeNode { .map(|limit| match limit { 0..10 => limit, 10..30 => 10, - 30.. => limit / 3, + 30..300 => limit / 3, + 300.. => 100, }) // If we didn't find max_connections, default to 10 concurrent connections. .unwrap_or(10) From 9c0efba91ea94203d72289e9ae3ecb62c1d20e51 Mon Sep 17 00:00:00 2001 From: Folke Behrens Date: Tue, 22 Jul 2025 11:31:39 +0200 Subject: [PATCH 6/6] Bump rand crate to 0.9 (#12674) --- Cargo.lock | 47 +++++-------- Cargo.toml | 4 +- endpoint_storage/src/app.rs | 2 +- libs/consumption_metrics/src/lib.rs | 2 +- libs/desim/src/node_os.rs | 2 +- libs/desim/src/options.rs | 4 +- libs/desim/src/world.rs | 2 +- libs/metrics/Cargo.toml | 4 +- libs/metrics/src/hll.rs | 14 ++-- libs/neon-shmem/Cargo.toml | 2 +- libs/pageserver_api/src/key.rs | 12 ++-- libs/pageserver_api/src/models.rs | 4 +- libs/pq_proto/src/lib.rs | 8 +-- .../src/authentication/sasl.rs | 4 +- .../postgres-protocol2/src/password/mod.rs | 2 +- libs/remote_storage/Cargo.toml | 2 +- libs/remote_storage/src/simulate_failures.rs | 4 +- libs/remote_storage/tests/test_real_azure.rs | 2 +- libs/remote_storage/tests/test_real_s3.rs | 2 +- libs/utils/src/id.rs | 2 +- libs/utils/src/lsn.rs | 37 +++++----- pageserver/benches/bench_layer_map.rs | 3 +- .../src/bin/compaction-simulator.rs | 2 +- pageserver/compaction/src/simulator.rs | 4 +- pageserver/pagebench/src/cmd/basebackup.rs | 4 +- .../pagebench/src/cmd/getpage_latest_lsn.rs | 11 ++- .../src/cmd/ondemand_download_churn.rs | 4 +- pageserver/src/feature_resolver.rs | 2 +- pageserver/src/tenant.rs | 33 ++++----- pageserver/src/tenant/blob_io.rs | 8 +-- pageserver/src/tenant/disk_btree.rs | 4 +- pageserver/src/tenant/ephemeral_file.rs | 12 ++-- pageserver/src/tenant/mgr.rs | 4 +- pageserver/src/tenant/secondary/scheduler.rs | 4 +- .../src/tenant/storage_layer/delta_layer.rs | 12 ++-- .../inmemory_layer/vectored_dio_read.rs | 8 +-- pageserver/src/tenant/tasks.rs | 2 +- pageserver/src/tenant/timeline.rs | 4 +- pageserver/src/virtual_file.rs | 7 +- proxy/Cargo.toml | 3 +- proxy/src/auth/backend/jwt.rs | 2 +- proxy/src/binary/proxy.rs | 4 +- proxy/src/cache/project_info.rs | 4 +- proxy/src/context/parquet.rs | 70 +++++++++---------- proxy/src/intern.rs | 2 +- proxy/src/pqproto.rs | 6 +- proxy/src/proxy/tests/mod.rs | 4 +- proxy/src/rate_limiter/leaky_bucket.rs | 4 +- proxy/src/rate_limiter/limiter.rs | 4 +- proxy/src/scram/countmin.rs | 10 +-- proxy/src/scram/threadpool.rs | 4 +- proxy/src/serverless/backend.rs | 2 +- proxy/src/serverless/cancel_set.rs | 7 +- proxy/src/serverless/conn_pool_lib.rs | 2 +- proxy/src/serverless/mod.rs | 4 +- safekeeper/src/rate_limit.rs | 2 +- safekeeper/tests/random_test.rs | 2 +- .../tests/walproposer_sim/simulation.rs | 22 +++--- storage_controller/src/hadron_utils.rs | 4 +- .../src/service/chaos_injector.rs | 16 ++--- 60 files changed, 231 insertions(+), 237 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 32ae30a765..f9dd33725a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1451,7 +1451,7 @@ name = "consumption_metrics" version = "0.1.0" dependencies = [ "chrono", - "rand 0.8.5", + "rand 0.9.1", "serde", ] @@ -1854,7 +1854,7 @@ dependencies = [ "bytes", "hex", "parking_lot 0.12.1", - "rand 0.8.5", + "rand 0.9.1", "smallvec", "tracing", "utils", @@ -2099,7 +2099,7 @@ dependencies = [ "itertools 0.10.5", "jsonwebtoken", "prometheus", - "rand 0.8.5", + "rand 0.9.1", "remote_storage", "serde", "serde_json", @@ -3782,8 +3782,8 @@ dependencies = [ "once_cell", "procfs", "prometheus", - "rand 0.8.5", - "rand_distr 0.4.3", + "rand 0.9.1", + "rand_distr", "twox-hash", ] @@ -3875,7 +3875,7 @@ dependencies = [ "lock_api", "nix 0.30.1", "rand 0.9.1", - "rand_distr 0.5.1", + "rand_distr", "rustc-hash 2.1.1", "tempfile", "thiserror 1.0.69", @@ -4351,7 +4351,7 @@ dependencies = [ "pageserver_client_grpc", "pageserver_page_api", "pprof", - "rand 0.8.5", + "rand 0.9.1", "reqwest", "serde", "serde_json", @@ -4448,7 +4448,7 @@ dependencies = [ "pprof", "pq_proto", "procfs", - "rand 0.8.5", + "rand 0.9.1", "range-set-blaze", "regex", "remote_storage", @@ -4515,7 +4515,7 @@ dependencies = [ "postgres_ffi_types", "postgres_versioninfo", "posthog_client_lite", - "rand 0.8.5", + "rand 0.9.1", "remote_storage", "reqwest", "serde", @@ -4585,7 +4585,7 @@ dependencies = [ "once_cell", "pageserver_api", "pin-project-lite", - "rand 0.8.5", + "rand 0.9.1", "svg_fmt", "tokio", "tracing", @@ -4958,7 +4958,7 @@ dependencies = [ "fallible-iterator", "hmac", "memchr", - "rand 0.8.5", + "rand 0.9.1", "sha2", "stringprep", "tokio", @@ -5150,7 +5150,7 @@ dependencies = [ "bytes", "itertools 0.10.5", "postgres-protocol", - "rand 0.8.5", + "rand 0.9.1", "serde", "thiserror 1.0.69", "tokio", @@ -5414,8 +5414,9 @@ dependencies = [ "postgres-protocol2", "postgres_backend", "pq_proto", - "rand 0.8.5", - "rand_distr 0.4.3", + "rand 0.9.1", + "rand_core 0.6.4", + "rand_distr", "rcgen", "redis", "regex", @@ -5617,16 +5618,6 @@ dependencies = [ "getrandom 0.3.3", ] -[[package]] -name = "rand_distr" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" -dependencies = [ - "num-traits", - "rand 0.8.5", -] - [[package]] name = "rand_distr" version = "0.5.1" @@ -5840,7 +5831,7 @@ dependencies = [ "metrics", "once_cell", "pin-project-lite", - "rand 0.8.5", + "rand 0.9.1", "reqwest", "scopeguard", "serde", @@ -6330,7 +6321,7 @@ dependencies = [ "postgres_versioninfo", "pprof", "pq_proto", - "rand 0.8.5", + "rand 0.9.1", "regex", "remote_storage", "reqwest", @@ -7024,7 +7015,7 @@ dependencies = [ "pageserver_client", "postgres_connection", "posthog_client_lite", - "rand 0.8.5", + "rand 0.9.1", "regex", "reqwest", "routerify", @@ -8305,7 +8296,7 @@ dependencies = [ "postgres_connection", "pprof", "pq_proto", - "rand 0.8.5", + "rand 0.9.1", "regex", "scopeguard", "sentry", diff --git a/Cargo.toml b/Cargo.toml index fe647828fc..3a57976cd8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -158,7 +158,9 @@ procfs = "0.16" prometheus = {version = "0.13", default-features=false, features = ["process"]} # removes protobuf dependency prost = "0.13.5" prost-types = "0.13.5" -rand = "0.8" +rand = "0.9" +# Remove after p256 is updated to 0.14. +rand_core = "=0.6" redis = { version = "0.29.2", features = ["tokio-rustls-comp", "keep-alive"] } regex = "1.10.2" reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] } diff --git a/endpoint_storage/src/app.rs b/endpoint_storage/src/app.rs index a7a18743ef..64c21cc8b9 100644 --- a/endpoint_storage/src/app.rs +++ b/endpoint_storage/src/app.rs @@ -233,7 +233,7 @@ mod tests { .unwrap() .as_millis(); use rand::Rng; - let random = rand::thread_rng().r#gen::(); + let random = rand::rng().random::(); let s3_config = remote_storage::S3Config { bucket_name: var(REAL_S3_BUCKET).unwrap(), diff --git a/libs/consumption_metrics/src/lib.rs b/libs/consumption_metrics/src/lib.rs index 448134f31a..aeb33bdfc2 100644 --- a/libs/consumption_metrics/src/lib.rs +++ b/libs/consumption_metrics/src/lib.rs @@ -90,7 +90,7 @@ impl<'a> IdempotencyKey<'a> { IdempotencyKey { now: Utc::now(), node_id, - nonce: rand::thread_rng().gen_range(0..=9999), + nonce: rand::rng().random_range(0..=9999), } } diff --git a/libs/desim/src/node_os.rs b/libs/desim/src/node_os.rs index e0cde7b284..6517c2001e 100644 --- a/libs/desim/src/node_os.rs +++ b/libs/desim/src/node_os.rs @@ -41,7 +41,7 @@ impl NodeOs { /// Generate a random number in range [0, max). pub fn random(&self, max: u64) -> u64 { - self.internal.rng.lock().gen_range(0..max) + self.internal.rng.lock().random_range(0..max) } /// Append a new event to the world event log. diff --git a/libs/desim/src/options.rs b/libs/desim/src/options.rs index 9b1a42fd28..d5da008ef1 100644 --- a/libs/desim/src/options.rs +++ b/libs/desim/src/options.rs @@ -32,10 +32,10 @@ impl Delay { /// Generate a random delay in range [min, max]. Return None if the /// message should be dropped. pub fn delay(&self, rng: &mut StdRng) -> Option { - if rng.gen_bool(self.fail_prob) { + if rng.random_bool(self.fail_prob) { return None; } - Some(rng.gen_range(self.min..=self.max)) + Some(rng.random_range(self.min..=self.max)) } } diff --git a/libs/desim/src/world.rs b/libs/desim/src/world.rs index 576ba89cd7..690d45f373 100644 --- a/libs/desim/src/world.rs +++ b/libs/desim/src/world.rs @@ -69,7 +69,7 @@ impl World { /// Create a new random number generator. pub fn new_rng(&self) -> StdRng { let mut rng = self.rng.lock(); - StdRng::from_rng(rng.deref_mut()).unwrap() + StdRng::from_rng(rng.deref_mut()) } /// Create a new node. diff --git a/libs/metrics/Cargo.toml b/libs/metrics/Cargo.toml index f87e7b8e3a..1718ddfae2 100644 --- a/libs/metrics/Cargo.toml +++ b/libs/metrics/Cargo.toml @@ -17,5 +17,5 @@ procfs.workspace = true measured-process.workspace = true [dev-dependencies] -rand = "0.8" -rand_distr = "0.4.3" +rand.workspace = true +rand_distr = "0.5" diff --git a/libs/metrics/src/hll.rs b/libs/metrics/src/hll.rs index 1a7d7a7e44..81e5bafbdf 100644 --- a/libs/metrics/src/hll.rs +++ b/libs/metrics/src/hll.rs @@ -260,7 +260,7 @@ mod tests { #[test] fn test_cardinality_small() { - let (actual, estimate) = test_cardinality(100, Zipf::new(100, 1.2f64).unwrap()); + let (actual, estimate) = test_cardinality(100, Zipf::new(100.0, 1.2f64).unwrap()); assert_eq!(actual, [46, 30, 32]); assert!(51.3 < estimate[0] && estimate[0] < 51.4); @@ -270,7 +270,7 @@ mod tests { #[test] fn test_cardinality_medium() { - let (actual, estimate) = test_cardinality(10000, Zipf::new(10000, 1.2f64).unwrap()); + let (actual, estimate) = test_cardinality(10000, Zipf::new(10000.0, 1.2f64).unwrap()); assert_eq!(actual, [2529, 1618, 1629]); assert!(2309.1 < estimate[0] && estimate[0] < 2309.2); @@ -280,7 +280,8 @@ mod tests { #[test] fn test_cardinality_large() { - let (actual, estimate) = test_cardinality(1_000_000, Zipf::new(1_000_000, 1.2f64).unwrap()); + let (actual, estimate) = + test_cardinality(1_000_000, Zipf::new(1_000_000.0, 1.2f64).unwrap()); assert_eq!(actual, [129077, 79579, 79630]); assert!(126067.2 < estimate[0] && estimate[0] < 126067.3); @@ -290,7 +291,7 @@ mod tests { #[test] fn test_cardinality_small2() { - let (actual, estimate) = test_cardinality(100, Zipf::new(200, 0.8f64).unwrap()); + let (actual, estimate) = test_cardinality(100, Zipf::new(200.0, 0.8f64).unwrap()); assert_eq!(actual, [92, 58, 60]); assert!(116.1 < estimate[0] && estimate[0] < 116.2); @@ -300,7 +301,7 @@ mod tests { #[test] fn test_cardinality_medium2() { - let (actual, estimate) = test_cardinality(10000, Zipf::new(20000, 0.8f64).unwrap()); + let (actual, estimate) = test_cardinality(10000, Zipf::new(20000.0, 0.8f64).unwrap()); assert_eq!(actual, [8201, 5131, 5051]); assert!(6846.4 < estimate[0] && estimate[0] < 6846.5); @@ -310,7 +311,8 @@ mod tests { #[test] fn test_cardinality_large2() { - let (actual, estimate) = test_cardinality(1_000_000, Zipf::new(2_000_000, 0.8f64).unwrap()); + let (actual, estimate) = + test_cardinality(1_000_000, Zipf::new(2_000_000.0, 0.8f64).unwrap()); assert_eq!(actual, [777847, 482069, 482246]); assert!(699437.4 < estimate[0] && estimate[0] < 699437.5); diff --git a/libs/neon-shmem/Cargo.toml b/libs/neon-shmem/Cargo.toml index 7ed991502e..1cdc9c0c67 100644 --- a/libs/neon-shmem/Cargo.toml +++ b/libs/neon-shmem/Cargo.toml @@ -16,5 +16,5 @@ rustc-hash.workspace = true tempfile = "3.14.0" [dev-dependencies] -rand = "0.9" +rand.workspace = true rand_distr = "0.5.1" diff --git a/libs/pageserver_api/src/key.rs b/libs/pageserver_api/src/key.rs index 102bbee879..4e8fabfa72 100644 --- a/libs/pageserver_api/src/key.rs +++ b/libs/pageserver_api/src/key.rs @@ -981,12 +981,12 @@ mod tests { let mut rng = rand::rngs::StdRng::seed_from_u64(42); let key = Key { - field1: rng.r#gen(), - field2: rng.r#gen(), - field3: rng.r#gen(), - field4: rng.r#gen(), - field5: rng.r#gen(), - field6: rng.r#gen(), + field1: rng.random(), + field2: rng.random(), + field3: rng.random(), + field4: rng.random(), + field5: rng.random(), + field6: rng.random(), }; assert_eq!(key, Key::from_str(&format!("{key}")).unwrap()); diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index 11e02a8550..7c7c65fb70 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -443,9 +443,9 @@ pub struct ImportPgdataIdempotencyKey(pub String); impl ImportPgdataIdempotencyKey { pub fn random() -> Self { use rand::Rng; - use rand::distributions::Alphanumeric; + use rand::distr::Alphanumeric; Self( - rand::thread_rng() + rand::rng() .sample_iter(&Alphanumeric) .take(20) .map(char::from) diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index 482dd9a298..5ecb4badf1 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -203,12 +203,12 @@ impl fmt::Display for CancelKeyData { } } -use rand::distributions::{Distribution, Standard}; -impl Distribution for Standard { +use rand::distr::{Distribution, StandardUniform}; +impl Distribution for StandardUniform { fn sample(&self, rng: &mut R) -> CancelKeyData { CancelKeyData { - backend_pid: rng.r#gen(), - cancel_key: rng.r#gen(), + backend_pid: rng.random(), + cancel_key: rng.random(), } } } diff --git a/libs/proxy/postgres-protocol2/src/authentication/sasl.rs b/libs/proxy/postgres-protocol2/src/authentication/sasl.rs index 274c81c500..cfa59a34f4 100644 --- a/libs/proxy/postgres-protocol2/src/authentication/sasl.rs +++ b/libs/proxy/postgres-protocol2/src/authentication/sasl.rs @@ -155,10 +155,10 @@ pub struct ScramSha256 { fn nonce() -> String { // rand 0.5's ThreadRng is cryptographically secure - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); (0..NONCE_LENGTH) .map(|_| { - let mut v = rng.gen_range(0x21u8..0x7e); + let mut v = rng.random_range(0x21u8..0x7e); if v == 0x2c { v = 0x7e } diff --git a/libs/proxy/postgres-protocol2/src/password/mod.rs b/libs/proxy/postgres-protocol2/src/password/mod.rs index e00ca1e34c..8926710225 100644 --- a/libs/proxy/postgres-protocol2/src/password/mod.rs +++ b/libs/proxy/postgres-protocol2/src/password/mod.rs @@ -28,7 +28,7 @@ const SCRAM_DEFAULT_SALT_LEN: usize = 16; /// special characters that would require escaping in an SQL command. pub async fn scram_sha_256(password: &[u8]) -> String { let mut salt: [u8; SCRAM_DEFAULT_SALT_LEN] = [0; SCRAM_DEFAULT_SALT_LEN]; - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); rng.fill_bytes(&mut salt); scram_sha_256_salt(password, salt).await } diff --git a/libs/remote_storage/Cargo.toml b/libs/remote_storage/Cargo.toml index 0ae13552b8..ea06725cfd 100644 --- a/libs/remote_storage/Cargo.toml +++ b/libs/remote_storage/Cargo.toml @@ -43,7 +43,7 @@ itertools.workspace = true sync_wrapper = { workspace = true, features = ["futures"] } byteorder = "1.4" -rand = "0.8.5" +rand.workspace = true [dev-dependencies] camino-tempfile.workspace = true diff --git a/libs/remote_storage/src/simulate_failures.rs b/libs/remote_storage/src/simulate_failures.rs index e895380192..f35d2a3081 100644 --- a/libs/remote_storage/src/simulate_failures.rs +++ b/libs/remote_storage/src/simulate_failures.rs @@ -81,7 +81,7 @@ impl UnreliableWrapper { /// fn attempt(&self, op: RemoteOp) -> anyhow::Result { let mut attempts = self.attempts.lock().unwrap(); - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); match attempts.entry(op) { Entry::Occupied(mut e) => { @@ -94,7 +94,7 @@ impl UnreliableWrapper { /* BEGIN_HADRON */ // If there are more attempts to fail, fail the request by probability. if (attempts_before_this < self.attempts_to_fail) - && (rng.gen_range(0..=100) < self.attempt_failure_probability) + && (rng.random_range(0..=100) < self.attempt_failure_probability) { let error = anyhow::anyhow!("simulated failure of remote operation {:?}", e.key()); diff --git a/libs/remote_storage/tests/test_real_azure.rs b/libs/remote_storage/tests/test_real_azure.rs index 4d7caabd39..949035b8c3 100644 --- a/libs/remote_storage/tests/test_real_azure.rs +++ b/libs/remote_storage/tests/test_real_azure.rs @@ -208,7 +208,7 @@ async fn create_azure_client( .as_millis(); // because nanos can be the same for two threads so can millis, add randomness - let random = rand::thread_rng().r#gen::(); + let random = rand::rng().random::(); let remote_storage_config = RemoteStorageConfig { storage: RemoteStorageKind::AzureContainer(AzureConfig { diff --git a/libs/remote_storage/tests/test_real_s3.rs b/libs/remote_storage/tests/test_real_s3.rs index 6b893edf75..f5c81bf45d 100644 --- a/libs/remote_storage/tests/test_real_s3.rs +++ b/libs/remote_storage/tests/test_real_s3.rs @@ -385,7 +385,7 @@ async fn create_s3_client( .as_millis(); // because nanos can be the same for two threads so can millis, add randomness - let random = rand::thread_rng().r#gen::(); + let random = rand::rng().random::(); let remote_storage_config = RemoteStorageConfig { storage: RemoteStorageKind::AwsS3(S3Config { diff --git a/libs/utils/src/id.rs b/libs/utils/src/id.rs index e3037aec21..d63bba75a3 100644 --- a/libs/utils/src/id.rs +++ b/libs/utils/src/id.rs @@ -104,7 +104,7 @@ impl Id { pub fn generate() -> Self { let mut tli_buf = [0u8; 16]; - rand::thread_rng().fill(&mut tli_buf); + rand::rng().fill(&mut tli_buf); Id::from(tli_buf) } diff --git a/libs/utils/src/lsn.rs b/libs/utils/src/lsn.rs index 31e1dda23d..1abb63817b 100644 --- a/libs/utils/src/lsn.rs +++ b/libs/utils/src/lsn.rs @@ -364,42 +364,37 @@ impl MonotonicCounter for RecordLsn { } } -/// Implements [`rand::distributions::uniform::UniformSampler`] so we can sample [`Lsn`]s. +/// Implements [`rand::distr::uniform::UniformSampler`] so we can sample [`Lsn`]s. /// /// This is used by the `pagebench` pageserver benchmarking tool. -pub struct LsnSampler(::Sampler); +pub struct LsnSampler(::Sampler); -impl rand::distributions::uniform::SampleUniform for Lsn { +impl rand::distr::uniform::SampleUniform for Lsn { type Sampler = LsnSampler; } -impl rand::distributions::uniform::UniformSampler for LsnSampler { +impl rand::distr::uniform::UniformSampler for LsnSampler { type X = Lsn; - fn new(low: B1, high: B2) -> Self + fn new(low: B1, high: B2) -> Result where - B1: rand::distributions::uniform::SampleBorrow + Sized, - B2: rand::distributions::uniform::SampleBorrow + Sized, + B1: rand::distr::uniform::SampleBorrow + Sized, + B2: rand::distr::uniform::SampleBorrow + Sized, { - Self( - ::Sampler::new( - low.borrow().0, - high.borrow().0, - ), - ) + ::Sampler::new(low.borrow().0, high.borrow().0) + .map(Self) } - fn new_inclusive(low: B1, high: B2) -> Self + fn new_inclusive(low: B1, high: B2) -> Result where - B1: rand::distributions::uniform::SampleBorrow + Sized, - B2: rand::distributions::uniform::SampleBorrow + Sized, + B1: rand::distr::uniform::SampleBorrow + Sized, + B2: rand::distr::uniform::SampleBorrow + Sized, { - Self( - ::Sampler::new_inclusive( - low.borrow().0, - high.borrow().0, - ), + ::Sampler::new_inclusive( + low.borrow().0, + high.borrow().0, ) + .map(Self) } fn sample(&self, rng: &mut R) -> Self::X { diff --git a/pageserver/benches/bench_layer_map.rs b/pageserver/benches/bench_layer_map.rs index e1444778b8..284cc4d67d 100644 --- a/pageserver/benches/bench_layer_map.rs +++ b/pageserver/benches/bench_layer_map.rs @@ -11,7 +11,8 @@ use pageserver::tenant::layer_map::LayerMap; use pageserver::tenant::storage_layer::{LayerName, PersistentLayerDesc}; use pageserver_api::key::Key; use pageserver_api::shard::TenantShardId; -use rand::prelude::{SeedableRng, SliceRandom, StdRng}; +use rand::prelude::{SeedableRng, StdRng}; +use rand::seq::IndexedRandom; use utils::id::{TenantId, TimelineId}; use utils::lsn::Lsn; diff --git a/pageserver/compaction/src/bin/compaction-simulator.rs b/pageserver/compaction/src/bin/compaction-simulator.rs index dd35417333..6211b86809 100644 --- a/pageserver/compaction/src/bin/compaction-simulator.rs +++ b/pageserver/compaction/src/bin/compaction-simulator.rs @@ -89,7 +89,7 @@ async fn simulate(cmd: &SimulateCmd, results_path: &Path) -> anyhow::Result<()> let cold_key_range = splitpoint..key_range.end; for i in 0..cmd.num_records { - let chosen_range = if rand::thread_rng().gen_bool(0.9) { + let chosen_range = if rand::rng().random_bool(0.9) { &hot_key_range } else { &cold_key_range diff --git a/pageserver/compaction/src/simulator.rs b/pageserver/compaction/src/simulator.rs index bf9f6f2658..44507c335b 100644 --- a/pageserver/compaction/src/simulator.rs +++ b/pageserver/compaction/src/simulator.rs @@ -300,9 +300,9 @@ impl MockTimeline { key_range: &Range, ) -> anyhow::Result<()> { crate::helpers::union_to_keyspace(&mut self.keyspace, vec![key_range.clone()]); - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..num_records { - self.ingest_record(rng.gen_range(key_range.clone()), len); + self.ingest_record(rng.random_range(key_range.clone()), len); self.wal_ingested += len; } Ok(()) diff --git a/pageserver/pagebench/src/cmd/basebackup.rs b/pageserver/pagebench/src/cmd/basebackup.rs index c14bb73136..01875f74b9 100644 --- a/pageserver/pagebench/src/cmd/basebackup.rs +++ b/pageserver/pagebench/src/cmd/basebackup.rs @@ -188,9 +188,9 @@ async fn main_impl( start_work_barrier.wait().await; loop { let (timeline, work) = { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let target = all_targets.choose(&mut rng).unwrap(); - let lsn = target.lsn_range.clone().map(|r| rng.gen_range(r)); + let lsn = target.lsn_range.clone().map(|r| rng.random_range(r)); (target.timeline, Work { lsn }) }; let sender = work_senders.get(&timeline).unwrap(); diff --git a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs index 30b30d36f6..ed7fe9c4ea 100644 --- a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs +++ b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs @@ -326,8 +326,7 @@ async fn main_impl( .cloned() .collect(); let weights = - rand::distributions::weighted::WeightedIndex::new(ranges.iter().map(|v| v.len())) - .unwrap(); + rand::distr::weighted::WeightedIndex::new(ranges.iter().map(|v| v.len())).unwrap(); Box::pin(async move { let scheme = match Url::parse(&args.page_service_connstring) { @@ -427,7 +426,7 @@ async fn run_worker( cancel: CancellationToken, rps_period: Option, ranges: Vec, - weights: rand::distributions::weighted::WeightedIndex, + weights: rand::distr::weighted::WeightedIndex, ) { shared_state.start_work_barrier.wait().await; let client_start = Instant::now(); @@ -469,9 +468,9 @@ async fn run_worker( } // Pick a random page from a random relation. - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let r = &ranges[weights.sample(&mut rng)]; - let key: i128 = rng.gen_range(r.start..r.end); + let key: i128 = rng.random_range(r.start..r.end); let (rel_tag, block_no) = key_to_block(key); let mut blks = VecDeque::with_capacity(batch_size); @@ -502,7 +501,7 @@ async fn run_worker( // We assume that the entire batch can fit within the relation. assert_eq!(blks.len(), batch_size, "incomplete batch"); - let req_lsn = if rng.gen_bool(args.req_latest_probability) { + let req_lsn = if rng.random_bool(args.req_latest_probability) { Lsn::MAX } else { r.timeline_lsn diff --git a/pageserver/pagebench/src/cmd/ondemand_download_churn.rs b/pageserver/pagebench/src/cmd/ondemand_download_churn.rs index 9ff1e638c4..8fbb452140 100644 --- a/pageserver/pagebench/src/cmd/ondemand_download_churn.rs +++ b/pageserver/pagebench/src/cmd/ondemand_download_churn.rs @@ -7,7 +7,7 @@ use std::time::{Duration, Instant}; use pageserver_api::models::HistoricLayerInfo; use pageserver_api::shard::TenantShardId; use pageserver_client::mgmt_api; -use rand::seq::SliceRandom; +use rand::seq::IndexedMutRandom; use tokio::sync::{OwnedSemaphorePermit, mpsc}; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; @@ -260,7 +260,7 @@ async fn timeline_actor( loop { let layer_tx = { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); timeline.layers.choose_mut(&mut rng).expect("no layers") }; match layer_tx.try_send(permit.take().unwrap()) { diff --git a/pageserver/src/feature_resolver.rs b/pageserver/src/feature_resolver.rs index f0178fd9b3..11b0e972b4 100644 --- a/pageserver/src/feature_resolver.rs +++ b/pageserver/src/feature_resolver.rs @@ -155,7 +155,7 @@ impl FeatureResolver { ); let tenant_properties = PerTenantProperties { - remote_size_mb: Some(rand::thread_rng().gen_range(100.0..1000000.00)), + remote_size_mb: Some(rand::rng().random_range(100.0..1000000.00)), } .into_posthog_properties(); diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 3d66ae4719..4c8856c386 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -6161,11 +6161,11 @@ mod tests { use pageserver_api::keyspace::KeySpaceRandomAccum; use pageserver_api::models::{CompactionAlgorithm, CompactionAlgorithmSettings, LsnLease}; use pageserver_compaction::helpers::overlaps_with; + use rand::Rng; #[cfg(feature = "testing")] use rand::SeedableRng; #[cfg(feature = "testing")] use rand::rngs::StdRng; - use rand::{Rng, thread_rng}; #[cfg(feature = "testing")] use std::ops::Range; use storage_layer::{IoConcurrency, PersistentLayerKey}; @@ -6286,8 +6286,8 @@ mod tests { while lsn < lsn_range.end { let mut key = key_range.start; while key < key_range.end { - let gap = random.gen_range(1..=100) <= spec.gap_chance; - let will_init = random.gen_range(1..=100) <= spec.will_init_chance; + let gap = random.random_range(1..=100) <= spec.gap_chance; + let will_init = random.random_range(1..=100) <= spec.will_init_chance; if gap { continue; @@ -6330,8 +6330,8 @@ mod tests { while lsn < lsn_range.end { let mut key = key_range.start; while key < key_range.end { - let gap = random.gen_range(1..=100) <= spec.gap_chance; - let will_init = random.gen_range(1..=100) <= spec.will_init_chance; + let gap = random.random_range(1..=100) <= spec.gap_chance; + let will_init = random.random_range(1..=100) <= spec.will_init_chance; if gap { continue; @@ -7808,7 +7808,7 @@ mod tests { for _ in 0..50 { for _ in 0..NUM_KEYS { lsn = Lsn(lsn.0 + 0x10); - let blknum = thread_rng().gen_range(0..NUM_KEYS); + let blknum = rand::rng().random_range(0..NUM_KEYS); test_key.field6 = blknum as u32; let mut writer = tline.writer().await; writer @@ -7897,7 +7897,7 @@ mod tests { for _ in 0..NUM_KEYS { lsn = Lsn(lsn.0 + 0x10); - let blknum = thread_rng().gen_range(0..NUM_KEYS); + let blknum = rand::rng().random_range(0..NUM_KEYS); test_key.field6 = blknum as u32; let mut writer = tline.writer().await; writer @@ -7965,7 +7965,7 @@ mod tests { for _ in 0..NUM_KEYS { lsn = Lsn(lsn.0 + 0x10); - let blknum = thread_rng().gen_range(0..NUM_KEYS); + let blknum = rand::rng().random_range(0..NUM_KEYS); test_key.field6 = blknum as u32; let mut writer = tline.writer().await; writer @@ -8229,7 +8229,7 @@ mod tests { for _ in 0..NUM_KEYS { lsn = Lsn(lsn.0 + 0x10); - let blknum = thread_rng().gen_range(0..NUM_KEYS); + let blknum = rand::rng().random_range(0..NUM_KEYS); test_key.field6 = (blknum * STEP) as u32; let mut writer = tline.writer().await; writer @@ -8502,7 +8502,7 @@ mod tests { for iter in 1..=10 { for _ in 0..NUM_KEYS { lsn = Lsn(lsn.0 + 0x10); - let blknum = thread_rng().gen_range(0..NUM_KEYS); + let blknum = rand::rng().random_range(0..NUM_KEYS); test_key.field6 = (blknum * STEP) as u32; let mut writer = tline.writer().await; writer @@ -11291,10 +11291,10 @@ mod tests { #[cfg(feature = "testing")] #[tokio::test] async fn test_read_path() -> anyhow::Result<()> { - use rand::seq::SliceRandom; + use rand::seq::IndexedRandom; let seed = if cfg!(feature = "fuzz-read-path") { - let seed: u64 = thread_rng().r#gen(); + let seed: u64 = rand::rng().random(); seed } else { // Use a hard-coded seed when not in fuzzing mode. @@ -11308,8 +11308,8 @@ mod tests { let (queries, will_init_chance, gap_chance) = if cfg!(feature = "fuzz-read-path") { const QUERIES: u64 = 5000; - let will_init_chance: u8 = random.gen_range(0..=10); - let gap_chance: u8 = random.gen_range(0..=50); + let will_init_chance: u8 = random.random_range(0..=10); + let gap_chance: u8 = random.random_range(0..=50); (QUERIES, will_init_chance, gap_chance) } else { @@ -11410,7 +11410,8 @@ mod tests { while used_keys.len() < tenant.conf.max_get_vectored_keys.get() { let selected_lsn = interesting_lsns.choose(&mut random).expect("not empty"); - let mut selected_key = start_key.add(random.gen_range(0..KEY_DIMENSION_SIZE)); + let mut selected_key = + start_key.add(random.random_range(0..KEY_DIMENSION_SIZE)); while used_keys.len() < tenant.conf.max_get_vectored_keys.get() { if used_keys.contains(&selected_key) @@ -11425,7 +11426,7 @@ mod tests { .add_key(selected_key); used_keys.insert(selected_key); - let pick_next = random.gen_range(0..=100) <= PICK_NEXT_CHANCE; + let pick_next = random.random_range(0..=100) <= PICK_NEXT_CHANCE; if pick_next { selected_key = selected_key.next(); } else { diff --git a/pageserver/src/tenant/blob_io.rs b/pageserver/src/tenant/blob_io.rs index ed541c4f12..29320f088c 100644 --- a/pageserver/src/tenant/blob_io.rs +++ b/pageserver/src/tenant/blob_io.rs @@ -535,8 +535,8 @@ pub(crate) mod tests { } pub(crate) fn random_array(len: usize) -> Vec { - let mut rng = rand::thread_rng(); - (0..len).map(|_| rng.r#gen()).collect::<_>() + let mut rng = rand::rng(); + (0..len).map(|_| rng.random()).collect::<_>() } #[tokio::test] @@ -588,9 +588,9 @@ pub(crate) mod tests { let mut rng = rand::rngs::StdRng::seed_from_u64(42); let blobs = (0..1024) .map(|_| { - let mut sz: u16 = rng.r#gen(); + let mut sz: u16 = rng.random(); // Make 50% of the arrays small - if rng.r#gen() { + if rng.random() { sz &= 63; } random_array(sz.into()) diff --git a/pageserver/src/tenant/disk_btree.rs b/pageserver/src/tenant/disk_btree.rs index 419befa41b..40f405307c 100644 --- a/pageserver/src/tenant/disk_btree.rs +++ b/pageserver/src/tenant/disk_btree.rs @@ -1090,7 +1090,7 @@ pub(crate) mod tests { const NUM_KEYS: usize = 100000; let mut all_data: BTreeMap = BTreeMap::new(); for idx in 0..NUM_KEYS { - let u: f64 = rand::thread_rng().gen_range(0.0..1.0); + let u: f64 = rand::rng().random_range(0.0..1.0); let t = -(f64::ln(u)); let key_int = (t * 1000000.0) as u128; @@ -1116,7 +1116,7 @@ pub(crate) mod tests { // Test get() operations on random keys, most of which will not exist for _ in 0..100000 { - let key_int = rand::thread_rng().r#gen::(); + let key_int = rand::rng().random::(); let search_key = u128::to_be_bytes(key_int); assert!(reader.get(&search_key, &ctx).await? == all_data.get(&key_int).cloned()); } diff --git a/pageserver/src/tenant/ephemeral_file.rs b/pageserver/src/tenant/ephemeral_file.rs index 203b5bf592..f2be129090 100644 --- a/pageserver/src/tenant/ephemeral_file.rs +++ b/pageserver/src/tenant/ephemeral_file.rs @@ -508,8 +508,8 @@ mod tests { let write_nbytes = cap * 2 + cap / 2; - let content: Vec = rand::thread_rng() - .sample_iter(rand::distributions::Standard) + let content: Vec = rand::rng() + .sample_iter(rand::distr::StandardUniform) .take(write_nbytes) .collect(); @@ -565,8 +565,8 @@ mod tests { let cap = writer.mutable().capacity(); drop(writer); - let content: Vec = rand::thread_rng() - .sample_iter(rand::distributions::Standard) + let content: Vec = rand::rng() + .sample_iter(rand::distr::StandardUniform) .take(cap * 2 + cap / 2) .collect(); @@ -614,8 +614,8 @@ mod tests { let cap = mutable.capacity(); let align = mutable.align(); drop(writer); - let content: Vec = rand::thread_rng() - .sample_iter(rand::distributions::Standard) + let content: Vec = rand::rng() + .sample_iter(rand::distr::StandardUniform) .take(cap * 2 + cap / 2) .collect(); diff --git a/pageserver/src/tenant/mgr.rs b/pageserver/src/tenant/mgr.rs index 01db09ed59..9b196ae393 100644 --- a/pageserver/src/tenant/mgr.rs +++ b/pageserver/src/tenant/mgr.rs @@ -19,7 +19,7 @@ use pageserver_api::shard::{ }; use pageserver_api::upcall_api::ReAttachResponseTenant; use rand::Rng; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use remote_storage::TimeoutOrCancel; use sysinfo::SystemExt; use tokio::fs; @@ -218,7 +218,7 @@ async fn safe_rename_tenant_dir(path: impl AsRef) -> std::io::Result Duration { if d == Duration::ZERO { d } else { - rand::thread_rng().gen_range((d * (100 - pct)) / 100..(d * (100 + pct)) / 100) + rand::rng().random_range((d * (100 - pct)) / 100..(d * (100 + pct)) / 100) } } @@ -35,7 +35,7 @@ pub(super) fn period_warmup(period: Duration) -> Duration { if period == Duration::ZERO { period } else { - rand::thread_rng().gen_range(Duration::ZERO..period) + rand::rng().random_range(Duration::ZERO..period) } } diff --git a/pageserver/src/tenant/storage_layer/delta_layer.rs b/pageserver/src/tenant/storage_layer/delta_layer.rs index c2f76c859c..f963fdac92 100644 --- a/pageserver/src/tenant/storage_layer/delta_layer.rs +++ b/pageserver/src/tenant/storage_layer/delta_layer.rs @@ -1634,7 +1634,8 @@ pub(crate) mod test { use bytes::Bytes; use itertools::MinMaxResult; use postgres_ffi::PgMajorVersion; - use rand::prelude::{SeedableRng, SliceRandom, StdRng}; + use rand::prelude::{SeedableRng, StdRng}; + use rand::seq::IndexedRandom; use rand::{Rng, RngCore}; /// Construct an index for a fictional delta layer and and then @@ -1788,14 +1789,14 @@ pub(crate) mod test { let mut entries = Vec::new(); for _ in 0..constants::KEY_COUNT { - let count = rng.gen_range(1..constants::MAX_ENTRIES_PER_KEY); + let count = rng.random_range(1..constants::MAX_ENTRIES_PER_KEY); let mut lsns_iter = std::iter::successors(Some(Lsn(constants::LSN_OFFSET.0 + 0x08)), |lsn| { Some(Lsn(lsn.0 + 0x08)) }); let mut lsns = Vec::new(); while lsns.len() < count as usize { - let take = rng.gen_bool(0.5); + let take = rng.random_bool(0.5); let lsn = lsns_iter.next().unwrap(); if take { lsns.push(lsn); @@ -1869,12 +1870,13 @@ pub(crate) mod test { for _ in 0..constants::RANGES_COUNT { let mut range: Option> = Option::default(); while range.is_none() || keyspace.overlaps(range.as_ref().unwrap()) { - let range_start = rng.gen_range(start..end); + let range_start = rng.random_range(start..end); let range_end_offset = range_start + constants::MIN_RANGE_SIZE; if range_end_offset >= end { range = Some(Key::from_i128(range_start)..Key::from_i128(end)); } else { - let range_end = rng.gen_range((range_start + constants::MIN_RANGE_SIZE)..end); + let range_end = + rng.random_range((range_start + constants::MIN_RANGE_SIZE)..end); range = Some(Key::from_i128(range_start)..Key::from_i128(range_end)); } } diff --git a/pageserver/src/tenant/storage_layer/inmemory_layer/vectored_dio_read.rs b/pageserver/src/tenant/storage_layer/inmemory_layer/vectored_dio_read.rs index 27fbc6f5fb..84f4386087 100644 --- a/pageserver/src/tenant/storage_layer/inmemory_layer/vectored_dio_read.rs +++ b/pageserver/src/tenant/storage_layer/inmemory_layer/vectored_dio_read.rs @@ -440,8 +440,8 @@ mod tests { impl InMemoryFile { fn new_random(len: usize) -> Self { Self { - content: rand::thread_rng() - .sample_iter(rand::distributions::Standard) + content: rand::rng() + .sample_iter(rand::distr::StandardUniform) .take(len) .collect(), } @@ -498,7 +498,7 @@ mod tests { len } }; - rand::Rng::fill(&mut rand::thread_rng(), &mut dst_slice[nread..]); // to discover bugs + rand::Rng::fill(&mut rand::rng(), &mut dst_slice[nread..]); // to discover bugs Ok((dst, nread)) } } @@ -763,7 +763,7 @@ mod tests { let len = std::cmp::min(dst.bytes_total(), mocked_bytes.len()); let dst_slice: &mut [u8] = dst.as_mut_rust_slice_full_zeroed(); dst_slice[..len].copy_from_slice(&mocked_bytes[..len]); - rand::Rng::fill(&mut rand::thread_rng(), &mut dst_slice[len..]); // to discover bugs + rand::Rng::fill(&mut rand::rng(), &mut dst_slice[len..]); // to discover bugs Ok((dst, len)) } Err(e) => Err(std::io::Error::other(e)), diff --git a/pageserver/src/tenant/tasks.rs b/pageserver/src/tenant/tasks.rs index 08fc7d61a5..676b39e55b 100644 --- a/pageserver/src/tenant/tasks.rs +++ b/pageserver/src/tenant/tasks.rs @@ -515,7 +515,7 @@ pub(crate) async fn sleep_random_range( interval: RangeInclusive, cancel: &CancellationToken, ) -> Result { - let delay = rand::thread_rng().gen_range(interval); + let delay = rand::rng().random_range(interval); if delay == Duration::ZERO { return Ok(delay); } diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 06e02a7386..0207a1f45b 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -2826,7 +2826,7 @@ impl Timeline { if r.numerator == 0 { false } else { - rand::thread_rng().gen_range(0..r.denominator) < r.numerator + rand::rng().random_range(0..r.denominator) < r.numerator } } None => false, @@ -3908,7 +3908,7 @@ impl Timeline { // 1hour base (60_i64 * 60_i64) // 10min jitter - + rand::thread_rng().gen_range(-10 * 60..10 * 60), + + rand::rng().random_range(-10 * 60..10 * 60), ) .expect("10min < 1hour"), ); diff --git a/pageserver/src/virtual_file.rs b/pageserver/src/virtual_file.rs index 45b6e44c54..a7f0c5914a 100644 --- a/pageserver/src/virtual_file.rs +++ b/pageserver/src/virtual_file.rs @@ -1275,8 +1275,8 @@ mod tests { use std::sync::Arc; use owned_buffers_io::io_buf_ext::IoBufExt; + use rand::Rng; use rand::seq::SliceRandom; - use rand::{Rng, thread_rng}; use super::*; use crate::context::DownloadBehavior; @@ -1358,7 +1358,7 @@ mod tests { // Check that all the other FDs still work too. Use them in random order for // good measure. - file_b_dupes.as_mut_slice().shuffle(&mut thread_rng()); + file_b_dupes.as_mut_slice().shuffle(&mut rand::rng()); for vfile in file_b_dupes.iter_mut() { assert_first_512_eq(vfile, b"content_b").await; } @@ -1413,9 +1413,8 @@ mod tests { let ctx = ctx.detached_child(TaskKind::UnitTest, DownloadBehavior::Error); let hdl = rt.spawn(async move { let mut buf = IoBufferMut::with_capacity_zeroed(SIZE); - let mut rng = rand::rngs::OsRng; for _ in 1..1000 { - let f = &files[rng.gen_range(0..files.len())]; + let f = &files[rand::rng().random_range(0..files.len())]; buf = f .read_exact_at(buf.slice_full(), 0, &ctx) .await diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 8392046839..3c3f93c8e3 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -66,6 +66,7 @@ postgres-client = { package = "tokio-postgres2", path = "../libs/proxy/tokio-pos postgres-protocol = { package = "postgres-protocol2", path = "../libs/proxy/postgres-protocol2" } pq_proto.workspace = true rand.workspace = true +rand_core.workspace = true regex.workspace = true remote_storage = { version = "0.1", path = "../libs/remote_storage/" } reqwest = { workspace = true, features = ["rustls-tls-native-roots"] } @@ -133,6 +134,6 @@ pbkdf2 = { workspace = true, features = ["simple", "std"] } rcgen.workspace = true rstest.workspace = true walkdir.workspace = true -rand_distr = "0.4" +rand_distr = "0.5" tokio-postgres.workspace = true tracing-test = "0.2" diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index a716890a00..6eba869870 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -803,7 +803,7 @@ mod tests { use http_body_util::Full; use hyper::service::service_fn; use hyper_util::rt::TokioIo; - use rand::rngs::OsRng; + use rand_core::OsRng; use rsa::pkcs8::DecodePrivateKey; use serde::Serialize; use serde_json::json; diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index d1dd2fef2a..255f6fbbee 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -13,7 +13,7 @@ use arc_swap::ArcSwapOption; use camino::Utf8PathBuf; use futures::future::Either; use itertools::{Itertools, Position}; -use rand::{Rng, thread_rng}; +use rand::Rng; use remote_storage::RemoteStorageConfig; use tokio::net::TcpListener; #[cfg(any(test, feature = "testing"))] @@ -573,7 +573,7 @@ pub async fn run() -> anyhow::Result<()> { attempt.into_inner() ); } - let jitter = thread_rng().gen_range(0..100); + let jitter = rand::rng().random_range(0..100); tokio::time::sleep(Duration::from_millis(1000 + jitter)).await; } } diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 0ef09a8a9a..a589dd175b 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -5,7 +5,7 @@ use std::time::Duration; use async_trait::async_trait; use clashmap::ClashMap; use clashmap::mapref::one::Ref; -use rand::{Rng, thread_rng}; +use rand::Rng; use tokio::time::Instant; use tracing::{debug, info}; @@ -343,7 +343,7 @@ impl ProjectInfoCacheImpl { } fn gc(&self) { - let shard = thread_rng().gen_range(0..self.project2ep.shards().len()); + let shard = rand::rng().random_range(0..self.project2ep.shards().len()); debug!(shard, "project_info_cache: performing epoch reclamation"); // acquire a random shard lock diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index 4d8df19476..715b818b98 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -523,29 +523,29 @@ mod tests { fn generate_request_data(rng: &mut impl Rng) -> RequestData { RequestData { - session_id: uuid::Builder::from_random_bytes(rng.r#gen()).into_uuid(), - peer_addr: Ipv4Addr::from(rng.r#gen::<[u8; 4]>()).to_string(), + session_id: uuid::Builder::from_random_bytes(rng.random()).into_uuid(), + peer_addr: Ipv4Addr::from(rng.random::<[u8; 4]>()).to_string(), timestamp: chrono::DateTime::from_timestamp_millis( - rng.gen_range(1703862754..1803862754), + rng.random_range(1703862754..1803862754), ) .unwrap() .naive_utc(), application_name: Some("test".to_owned()), user_agent: Some("test-user-agent".to_owned()), - username: Some(hex::encode(rng.r#gen::<[u8; 4]>())), - endpoint_id: Some(hex::encode(rng.r#gen::<[u8; 16]>())), - database: Some(hex::encode(rng.r#gen::<[u8; 16]>())), - project: Some(hex::encode(rng.r#gen::<[u8; 16]>())), - branch: Some(hex::encode(rng.r#gen::<[u8; 16]>())), + username: Some(hex::encode(rng.random::<[u8; 4]>())), + endpoint_id: Some(hex::encode(rng.random::<[u8; 16]>())), + database: Some(hex::encode(rng.random::<[u8; 16]>())), + project: Some(hex::encode(rng.random::<[u8; 16]>())), + branch: Some(hex::encode(rng.random::<[u8; 16]>())), pg_options: None, auth_method: None, jwt_issuer: None, - protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)], + protocol: ["tcp", "ws", "http"][rng.random_range(0..3)], region: String::new(), error: None, - success: rng.r#gen(), + success: rng.random(), cold_start_info: "no", - duration_us: rng.gen_range(0..30_000_000), + duration_us: rng.random_range(0..30_000_000), disconnect_timestamp: None, } } @@ -622,15 +622,15 @@ mod tests { assert_eq!( file_stats, [ - (1313953, 3, 6000), - (1313942, 3, 6000), - (1314001, 3, 6000), - (1313958, 3, 6000), - (1314094, 3, 6000), - (1313931, 3, 6000), - (1313725, 3, 6000), - (1313960, 3, 6000), - (438318, 1, 2000) + (1313878, 3, 6000), + (1313891, 3, 6000), + (1314058, 3, 6000), + (1313914, 3, 6000), + (1313760, 3, 6000), + (1314084, 3, 6000), + (1313965, 3, 6000), + (1313911, 3, 6000), + (438290, 1, 2000) ] ); @@ -662,11 +662,11 @@ mod tests { assert_eq!( file_stats, [ - (1205810, 5, 10000), - (1205534, 5, 10000), - (1205835, 5, 10000), - (1205820, 5, 10000), - (1206074, 5, 10000) + (1206039, 5, 10000), + (1205798, 5, 10000), + (1205776, 5, 10000), + (1206051, 5, 10000), + (1205746, 5, 10000) ] ); @@ -691,15 +691,15 @@ mod tests { assert_eq!( file_stats, [ - (1313953, 3, 6000), - (1313942, 3, 6000), - (1314001, 3, 6000), - (1313958, 3, 6000), - (1314094, 3, 6000), - (1313931, 3, 6000), - (1313725, 3, 6000), - (1313960, 3, 6000), - (438318, 1, 2000) + (1313878, 3, 6000), + (1313891, 3, 6000), + (1314058, 3, 6000), + (1313914, 3, 6000), + (1313760, 3, 6000), + (1314084, 3, 6000), + (1313965, 3, 6000), + (1313911, 3, 6000), + (438290, 1, 2000) ] ); @@ -736,7 +736,7 @@ mod tests { // files are smaller than the size threshold, but they took too long to fill so were flushed early assert_eq!( file_stats, - [(658584, 2, 3001), (658298, 2, 3000), (658094, 2, 2999)] + [(658552, 2, 3001), (658265, 2, 3000), (658061, 2, 2999)] ); tmpdir.close().unwrap(); diff --git a/proxy/src/intern.rs b/proxy/src/intern.rs index d7e39ebaf4..825f2d1049 100644 --- a/proxy/src/intern.rs +++ b/proxy/src/intern.rs @@ -247,7 +247,7 @@ mod tests { use rand::{Rng, SeedableRng}; use rand_distr::Zipf; - let endpoint_dist = Zipf::new(500000, 0.8).unwrap(); + let endpoint_dist = Zipf::new(500000.0, 0.8).unwrap(); let endpoints = StdRng::seed_from_u64(272488357).sample_iter(endpoint_dist); let interner = MyId::get_interner(); diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs index ad99eecda5..680a23c435 100644 --- a/proxy/src/pqproto.rs +++ b/proxy/src/pqproto.rs @@ -7,7 +7,7 @@ use std::io::{self, Cursor}; use bytes::{Buf, BufMut}; use itertools::Itertools; -use rand::distributions::{Distribution, Standard}; +use rand::distr::{Distribution, StandardUniform}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian}; @@ -458,9 +458,9 @@ impl fmt::Display for CancelKeyData { .finish() } } -impl Distribution for Standard { +impl Distribution for StandardUniform { fn sample(&self, rng: &mut R) -> CancelKeyData { - id_to_cancel_key(rng.r#gen()) + id_to_cancel_key(rng.random()) } } diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index dd89b05426..f8bff450e1 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -338,8 +338,8 @@ async fn scram_auth_mock() -> anyhow::Result<()> { let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), Scram::mock())); use rand::Rng; - use rand::distributions::Alphanumeric; - let password: String = rand::thread_rng() + use rand::distr::Alphanumeric; + let password: String = rand::rng() .sample_iter(&Alphanumeric) .take(rand::random::() as usize) .map(char::from) diff --git a/proxy/src/rate_limiter/leaky_bucket.rs b/proxy/src/rate_limiter/leaky_bucket.rs index 12b4bda0c0..9de82e922c 100644 --- a/proxy/src/rate_limiter/leaky_bucket.rs +++ b/proxy/src/rate_limiter/leaky_bucket.rs @@ -3,7 +3,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use ahash::RandomState; use clashmap::ClashMap; -use rand::{Rng, thread_rng}; +use rand::Rng; use tokio::time::Instant; use tracing::info; use utils::leaky_bucket::LeakyBucketState; @@ -61,7 +61,7 @@ impl LeakyBucketRateLimiter { self.map.len() ); let n = self.map.shards().len(); - let shard = thread_rng().gen_range(0..n); + let shard = rand::rng().random_range(0..n); self.map.shards()[shard] .write() .retain(|(_, value)| !value.bucket_is_empty(now)); diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index fd1b2af023..2b3d745a0e 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -147,7 +147,7 @@ impl RateBucketInfo { impl BucketRateLimiter { pub fn new(info: impl Into>) -> Self { - Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new()) + Self::new_with_rand_and_hasher(info, StdRng::from_os_rng(), RandomState::new()) } } @@ -216,7 +216,7 @@ impl BucketRateLimiter { let n = self.map.shards().len(); // this lock is ok as the periodic cycle of do_gc makes this very unlikely to collide // (impossible, infact, unless we have 2048 threads) - let shard = self.rand.lock_propagate_poison().gen_range(0..n); + let shard = self.rand.lock_propagate_poison().random_range(0..n); self.map.shards()[shard].write().clear(); } } diff --git a/proxy/src/scram/countmin.rs b/proxy/src/scram/countmin.rs index 9d56c465ec..d64895f8f5 100644 --- a/proxy/src/scram/countmin.rs +++ b/proxy/src/scram/countmin.rs @@ -86,11 +86,11 @@ mod tests { for _ in 0..n { // number to insert at once - let n = rng.gen_range(1..4096); + let n = rng.random_range(1..4096); // number of insert operations - let m = rng.gen_range(1..100); + let m = rng.random_range(1..100); - let id = uuid::Builder::from_random_bytes(rng.r#gen()).into_uuid(); + let id = uuid::Builder::from_random_bytes(rng.random()).into_uuid(); ids.push((id, n, m)); // N = sum(actual) @@ -140,8 +140,8 @@ mod tests { // probably numbers are too small to truly represent the probabilities. assert_eq!(eval_precision(100, 4096.0, 0.90), 100); assert_eq!(eval_precision(1000, 4096.0, 0.90), 1000); - assert_eq!(eval_precision(100, 4096.0, 0.1), 96); - assert_eq!(eval_precision(1000, 4096.0, 0.1), 988); + assert_eq!(eval_precision(100, 4096.0, 0.1), 100); + assert_eq!(eval_precision(1000, 4096.0, 0.1), 978); } // returns memory usage in bytes, and the time complexity per insert. diff --git a/proxy/src/scram/threadpool.rs b/proxy/src/scram/threadpool.rs index 1aa402227f..ea2e29ede9 100644 --- a/proxy/src/scram/threadpool.rs +++ b/proxy/src/scram/threadpool.rs @@ -51,7 +51,7 @@ impl ThreadPool { *state = Some(ThreadRt { pool: pool.clone(), id: ThreadPoolWorkerId(worker_id.fetch_add(1, Ordering::Relaxed)), - rng: SmallRng::from_entropy(), + rng: SmallRng::from_os_rng(), // used to determine whether we should temporarily skip tasks for fairness. // 99% of estimates will overcount by no more than 4096 samples countmin: CountMinSketch::with_params( @@ -120,7 +120,7 @@ impl ThreadRt { // in which case the SKETCH_RESET_INTERVAL represents 1 second. Thus, the rates above // are in requests per second. let probability = P.ln() / (P + rate as f64).ln(); - self.rng.gen_bool(probability) + self.rng.random_bool(probability) } } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index daa6429039..59e4b09bc9 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -8,7 +8,7 @@ use ed25519_dalek::SigningKey; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use jose_jwk::jose_b64; use postgres_client::config::SslMode; -use rand::rngs::OsRng; +use rand_core::OsRng; use rustls::pki_types::{DnsName, ServerName}; use tokio::net::{TcpStream, lookup_host}; use tokio_rustls::TlsConnector; diff --git a/proxy/src/serverless/cancel_set.rs b/proxy/src/serverless/cancel_set.rs index ba8945afc5..142dc3b3d5 100644 --- a/proxy/src/serverless/cancel_set.rs +++ b/proxy/src/serverless/cancel_set.rs @@ -6,7 +6,7 @@ use std::time::Duration; use indexmap::IndexMap; use parking_lot::Mutex; -use rand::{Rng, thread_rng}; +use rand::distr::uniform::{UniformSampler, UniformUsize}; use rustc_hash::FxHasher; use tokio::time::Instant; use tokio_util::sync::CancellationToken; @@ -39,8 +39,9 @@ impl CancelSet { } pub(crate) fn take(&self) -> Option { + let dist = UniformUsize::new_inclusive(0, usize::MAX).expect("valid bounds"); for _ in 0..4 { - if let Some(token) = self.take_raw(thread_rng().r#gen()) { + if let Some(token) = self.take_raw(dist.sample(&mut rand::rng())) { return Some(token); } tracing::trace!("failed to get cancel token"); @@ -48,7 +49,7 @@ impl CancelSet { None } - pub(crate) fn take_raw(&self, rng: usize) -> Option { + fn take_raw(&self, rng: usize) -> Option { NonZeroUsize::new(self.shards.len()) .and_then(|len| self.shards[rng % len].lock().take(rng / len)) } diff --git a/proxy/src/serverless/conn_pool_lib.rs b/proxy/src/serverless/conn_pool_lib.rs index 42a3ea17a2..ed5cc0ea03 100644 --- a/proxy/src/serverless/conn_pool_lib.rs +++ b/proxy/src/serverless/conn_pool_lib.rs @@ -428,7 +428,7 @@ where loop { interval.tick().await; - let shard = rng.gen_range(0..self.global_pool.shards().len()); + let shard = rng.random_range(0..self.global_pool.shards().len()); self.gc(shard); } } diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 18cdc39ac7..13f9ee2782 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -77,7 +77,7 @@ pub async fn task_main( { let conn_pool = Arc::clone(&conn_pool); tokio::spawn(async move { - conn_pool.gc_worker(StdRng::from_entropy()).await; + conn_pool.gc_worker(StdRng::from_os_rng()).await; }); } @@ -97,7 +97,7 @@ pub async fn task_main( { let http_conn_pool = Arc::clone(&http_conn_pool); tokio::spawn(async move { - http_conn_pool.gc_worker(StdRng::from_entropy()).await; + http_conn_pool.gc_worker(StdRng::from_os_rng()).await; }); } diff --git a/safekeeper/src/rate_limit.rs b/safekeeper/src/rate_limit.rs index 72373b5786..0e697ade57 100644 --- a/safekeeper/src/rate_limit.rs +++ b/safekeeper/src/rate_limit.rs @@ -44,6 +44,6 @@ impl RateLimiter { /// Generate a random duration that is a fraction of the given duration. pub fn rand_duration(duration: &std::time::Duration) -> std::time::Duration { - let randf64 = rand::thread_rng().gen_range(0.0..1.0); + let randf64 = rand::rng().random_range(0.0..1.0); duration.mul_f64(randf64) } diff --git a/safekeeper/tests/random_test.rs b/safekeeper/tests/random_test.rs index e29b58836a..7e7d2390e9 100644 --- a/safekeeper/tests/random_test.rs +++ b/safekeeper/tests/random_test.rs @@ -16,7 +16,7 @@ fn test_random_schedules() -> anyhow::Result<()> { let mut config = TestConfig::new(Some(clock)); for _ in 0..500 { - let seed: u64 = rand::thread_rng().r#gen(); + let seed: u64 = rand::rng().random(); config.network = generate_network_opts(seed); let test = config.start(seed); diff --git a/safekeeper/tests/walproposer_sim/simulation.rs b/safekeeper/tests/walproposer_sim/simulation.rs index edd3bf2d9e..595cc7ab64 100644 --- a/safekeeper/tests/walproposer_sim/simulation.rs +++ b/safekeeper/tests/walproposer_sim/simulation.rs @@ -394,13 +394,13 @@ pub fn generate_schedule(seed: u64) -> Schedule { let mut schedule = Vec::new(); let mut time = 0; - let cnt = rng.gen_range(1..100); + let cnt = rng.random_range(1..100); for _ in 0..cnt { - time += rng.gen_range(0..500); - let action = match rng.gen_range(0..3) { - 0 => TestAction::WriteTx(rng.gen_range(1..10)), - 1 => TestAction::RestartSafekeeper(rng.gen_range(0..3)), + time += rng.random_range(0..500); + let action = match rng.random_range(0..3) { + 0 => TestAction::WriteTx(rng.random_range(1..10)), + 1 => TestAction::RestartSafekeeper(rng.random_range(0..3)), 2 => TestAction::RestartWalProposer, _ => unreachable!(), }; @@ -413,13 +413,13 @@ pub fn generate_schedule(seed: u64) -> Schedule { pub fn generate_network_opts(seed: u64) -> NetworkOptions { let mut rng = rand::rngs::StdRng::seed_from_u64(seed); - let timeout = rng.gen_range(100..2000); - let max_delay = rng.gen_range(1..2 * timeout); - let min_delay = rng.gen_range(1..=max_delay); + let timeout = rng.random_range(100..2000); + let max_delay = rng.random_range(1..2 * timeout); + let min_delay = rng.random_range(1..=max_delay); - let max_fail_prob = rng.gen_range(0.0..0.9); - let connect_fail_prob = rng.gen_range(0.0..max_fail_prob); - let send_fail_prob = rng.gen_range(0.0..connect_fail_prob); + let max_fail_prob = rng.random_range(0.0..0.9); + let connect_fail_prob = rng.random_range(0.0..max_fail_prob); + let send_fail_prob = rng.random_range(0.0..connect_fail_prob); NetworkOptions { keepalive_timeout: Some(timeout), diff --git a/storage_controller/src/hadron_utils.rs b/storage_controller/src/hadron_utils.rs index 871e21c367..8bfbe8e575 100644 --- a/storage_controller/src/hadron_utils.rs +++ b/storage_controller/src/hadron_utils.rs @@ -8,10 +8,10 @@ static CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz01 /// Generate a random string of `length` that can be used as a password. The generated string /// contains alphanumeric characters and special characters (!@#$%^&*()) pub fn generate_random_password(length: usize) -> String { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); (0..length) .map(|_| { - let idx = rng.gen_range(0..CHARSET.len()); + let idx = rng.random_range(0..CHARSET.len()); CHARSET[idx] as char }) .collect() diff --git a/storage_controller/src/service/chaos_injector.rs b/storage_controller/src/service/chaos_injector.rs index 4087de200a..0efeef4e80 100644 --- a/storage_controller/src/service/chaos_injector.rs +++ b/storage_controller/src/service/chaos_injector.rs @@ -3,8 +3,8 @@ use std::sync::Arc; use std::time::Duration; use pageserver_api::controller_api::ShardSchedulingPolicy; -use rand::seq::SliceRandom; -use rand::{Rng, thread_rng}; +use rand::Rng; +use rand::seq::{IndexedRandom, SliceRandom}; use tokio_util::sync::CancellationToken; use utils::id::NodeId; use utils::shard::TenantShardId; @@ -72,7 +72,7 @@ impl ChaosInjector { let cron_interval = self.get_cron_interval_sleep_future(); let chaos_type = tokio::select! { _ = interval.tick() => { - if thread_rng().gen_bool(0.5) { + if rand::rng().random_bool(0.5) { ChaosEvent::MigrationsToSecondary } else { ChaosEvent::GracefulMigrationsAnywhere @@ -134,7 +134,7 @@ impl ChaosInjector { let Some(new_location) = shard .intent .get_secondary() - .choose(&mut thread_rng()) + .choose(&mut rand::rng()) .cloned() else { tracing::info!( @@ -190,7 +190,7 @@ impl ChaosInjector { // Pick our victims: use a hand-rolled loop rather than choose_multiple() because we want // to take the mutable refs from our candidates rather than ref'ing them. while !candidates.is_empty() && victims.len() < batch_size { - let i = thread_rng().gen_range(0..candidates.len()); + let i = rand::rng().random_range(0..candidates.len()); victims.push(candidates.swap_remove(i)); } @@ -210,7 +210,7 @@ impl ChaosInjector { }) .collect::>(); - let Some(victim_node) = candidate_nodes.choose(&mut thread_rng()) else { + let Some(victim_node) = candidate_nodes.choose(&mut rand::rng()) else { // This can happen if e.g. we are in a small region with only one pageserver per AZ. tracing::info!( "no candidate nodes found for migrating shard {tenant_shard_id} within its home AZ", @@ -264,7 +264,7 @@ impl ChaosInjector { out_of_home_az.len() ); - out_of_home_az.shuffle(&mut thread_rng()); + out_of_home_az.shuffle(&mut rand::rng()); victims.extend(out_of_home_az.into_iter().take(batch_size)); } else { tracing::info!( @@ -274,7 +274,7 @@ impl ChaosInjector { ); victims.extend(out_of_home_az); - in_home_az.shuffle(&mut thread_rng()); + in_home_az.shuffle(&mut rand::rng()); victims.extend(in_home_az.into_iter().take(batch_size - victims.len())); }