diff --git a/.config/hakari.toml b/.config/hakari.toml index 9991cd92b0..dcbc44cc33 100644 --- a/.config/hakari.toml +++ b/.config/hakari.toml @@ -21,13 +21,14 @@ platforms = [ # "x86_64-apple-darwin", # "x86_64-pc-windows-msvc", ] - [final-excludes] workspace-members = [ # vm_monitor benefits from the same Cargo.lock as the rest of our artifacts, but # it is built primarly in separate repo neondatabase/autoscaling and thus is excluded # from depending on workspace-hack because most of the dependencies are not used. "vm_monitor", + # subzero-core is a stub crate that should be excluded from workspace-hack + "subzero-core", # All of these exist in libs and are not usually built independently. # Putting workspace hack there adds a bottleneck for cargo builds. "compute_api", diff --git a/.github/actions/prepare-for-subzero/action.yml b/.github/actions/prepare-for-subzero/action.yml new file mode 100644 index 0000000000..11beb11880 --- /dev/null +++ b/.github/actions/prepare-for-subzero/action.yml @@ -0,0 +1,28 @@ +name: 'Prepare current job for subzero' +description: > + Set git token to access `neondatabase/subzero` from cargo build, + and set `CARGO_NET_GIT_FETCH_WITH_CLI=true` env variable to use git CLI + +inputs: + token: + description: 'GitHub token with access to neondatabase/subzero' + required: true + +runs: + using: "composite" + + steps: + - name: Set git token for neondatabase/subzero + uses: pyTooling/Actions/with-post-step@2307b526df64d55e95884e072e49aac2a00a9afa # v5.1.0 + env: + SUBZERO_ACCESS_TOKEN: ${{ inputs.token }} + with: + main: | + git config --global url."https://x-access-token:${SUBZERO_ACCESS_TOKEN}@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero" + cargo add -p proxy subzero-core --git https://github.com/neondatabase/subzero --rev 396264617e78e8be428682f87469bb25429af88a + post: | + git config --global --unset url."https://x-access-token:${SUBZERO_ACCESS_TOKEN}@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero" + + - name: Set `CARGO_NET_GIT_FETCH_WITH_CLI=true` env variable + shell: bash -euxo pipefail {0} + run: echo "CARGO_NET_GIT_FETCH_WITH_CLI=true" >> ${GITHUB_ENV} diff --git a/.github/workflows/_build-and-test-locally.yml b/.github/workflows/_build-and-test-locally.yml index 94115572df..1b03dc9c03 100644 --- a/.github/workflows/_build-and-test-locally.yml +++ b/.github/workflows/_build-and-test-locally.yml @@ -86,6 +86,10 @@ jobs: with: submodules: true + - uses: ./.github/actions/prepare-for-subzero + with: + token: ${{ secrets.CI_ACCESS_TOKEN }} + - name: Set pg 14 revision for caching id: pg_v14_rev run: echo pg_rev=$(git rev-parse HEAD:vendor/postgres-v14) >> $GITHUB_OUTPUT @@ -116,7 +120,7 @@ jobs: ARCH: ${{ inputs.arch }} SANITIZERS: ${{ inputs.sanitizers }} run: | - CARGO_FLAGS="--locked --features testing" + CARGO_FLAGS="--locked --features testing,rest_broker" if [[ $BUILD_TYPE == "debug" && $ARCH == 'x64' ]]; then cov_prefix="scripts/coverage --profraw-prefix=$GITHUB_JOB --dir=/tmp/coverage run" CARGO_PROFILE="" diff --git a/.github/workflows/_check-codestyle-rust.yml b/.github/workflows/_check-codestyle-rust.yml index 4f844b0bf6..af29e10e97 100644 --- a/.github/workflows/_check-codestyle-rust.yml +++ b/.github/workflows/_check-codestyle-rust.yml @@ -46,6 +46,10 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: true + + - uses: ./.github/actions/prepare-for-subzero + with: + token: ${{ secrets.CI_ACCESS_TOKEN }} - name: Cache 
cargo deps uses: tespkg/actions-cache@b7bf5fcc2f98a52ac6080eb0fd282c2f752074b1 # v1.8.0 diff --git a/.github/workflows/build-macos.yml b/.github/workflows/build-macos.yml index 2296807d2d..e43eec1133 100644 --- a/.github/workflows/build-macos.yml +++ b/.github/workflows/build-macos.yml @@ -54,6 +54,10 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: true + + - uses: ./.github/actions/prepare-for-subzero + with: + token: ${{ secrets.CI_ACCESS_TOKEN }} - name: Install build dependencies run: | diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 2977f642bc..f237a991cc 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -632,6 +632,8 @@ jobs: BUILD_TAG=${{ needs.meta.outputs.release-tag || needs.meta.outputs.build-tag }} TAG=${{ needs.build-build-tools-image.outputs.image-tag }}-bookworm DEBIAN_VERSION=bookworm + secrets: | + SUBZERO_ACCESS_TOKEN=${{ secrets.CI_ACCESS_TOKEN }} provenance: false push: true pull: true diff --git a/.github/workflows/neon_extra_builds.yml b/.github/workflows/neon_extra_builds.yml index 3e81183687..10ca1a1591 100644 --- a/.github/workflows/neon_extra_builds.yml +++ b/.github/workflows/neon_extra_builds.yml @@ -72,6 +72,7 @@ jobs: check-macos-build: needs: [ check-permissions, files-changed ] uses: ./.github/workflows/build-macos.yml + secrets: inherit with: pg_versions: ${{ needs.files-changed.outputs.postgres_changes }} rebuild_rust_code: ${{ fromJSON(needs.files-changed.outputs.rebuild_rust_code) }} diff --git a/.gitignore b/.gitignore index 835cceb123..1e1c2316af 100644 --- a/.gitignore +++ b/.gitignore @@ -26,9 +26,14 @@ docker-compose/docker-compose-parallel.yml *.o *.so *.Po +*.pid # pgindent typedef lists *.list # Node **/node_modules/ + +# various files for local testing +/proxy/.subzero +local_proxy.json diff --git a/Cargo.lock b/Cargo.lock index 3d89359465..b72911c7a8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -52,6 +52,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "aliasable" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "250f629c0161ad8107cf89319e990051fae62832fd343083bea452d93e2205fd" + [[package]] name = "aligned-vec" version = "0.6.1" @@ -490,7 +496,7 @@ dependencies = [ "hex", "hmac", "http 0.2.9", - "http 1.1.0", + "http 1.3.1", "once_cell", "p256 0.11.1", "percent-encoding", @@ -631,7 +637,7 @@ dependencies = [ "aws-smithy-types", "bytes", "http 0.2.9", - "http 1.1.0", + "http 1.3.1", "pin-project-lite", "tokio", "tracing", @@ -649,7 +655,7 @@ dependencies = [ "bytes-utils", "futures-core", "http 0.2.9", - "http 1.1.0", + "http 1.3.1", "http-body 0.4.5", "http-body 1.0.0", "http-body-util", @@ -698,7 +704,7 @@ dependencies = [ "bytes", "form_urlencoded", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "hyper 1.4.1", @@ -732,7 +738,7 @@ checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733" dependencies = [ "bytes", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "mime", @@ -756,7 +762,7 @@ dependencies = [ "form_urlencoded", "futures-util", "headers", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "mime", @@ -1090,7 +1096,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "975982cdb7ad6a142be15bdf84aea7ec6a9e5d4d797c004d43185b24cfe4e684" dependencies = [ "clap", - "heck", + "heck 0.5.0", "indexmap 2.9.0", 
"log", "proc-macro2", @@ -1228,7 +1234,7 @@ version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "syn 2.0.100", @@ -1334,7 +1340,7 @@ dependencies = [ "flate2", "futures", "hostname-validator", - "http 1.1.0", + "http 1.3.1", "indexmap 2.9.0", "itertools 0.10.5", "jsonwebtoken", @@ -1445,7 +1451,7 @@ name = "consumption_metrics" version = "0.1.0" dependencies = [ "chrono", - "rand 0.8.5", + "rand 0.9.1", "serde", ] @@ -1848,7 +1854,7 @@ dependencies = [ "bytes", "hex", "parking_lot 0.12.1", - "rand 0.8.5", + "rand 0.9.1", "smallvec", "tracing", "utils", @@ -1969,7 +1975,7 @@ checksum = "0892a17df262a24294c382f0d5997571006e7a4348b4327557c4ff1cd4a8bccc" dependencies = [ "darling", "either", - "heck", + "heck 0.5.0", "proc-macro2", "quote", "syn 2.0.100", @@ -2093,7 +2099,7 @@ dependencies = [ "itertools 0.10.5", "jsonwebtoken", "prometheus", - "rand 0.8.5", + "rand 0.9.1", "remote_storage", "serde", "serde_json", @@ -2661,7 +2667,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http 1.1.0", + "http 1.3.1", "indexmap 2.9.0", "slab", "tokio", @@ -2743,7 +2749,7 @@ dependencies = [ "base64 0.21.7", "bytes", "headers-core", - "http 1.1.0", + "http 1.3.1", "httpdate", "mime", "sha1", @@ -2755,9 +2761,15 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" dependencies = [ - "http 1.1.0", + "http 1.3.1", ] +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "heck" version = "0.5.0" @@ -2833,9 +2845,9 @@ dependencies = [ [[package]] name = "http" -version = "1.1.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" dependencies = [ "bytes", "fnv", @@ -2860,7 +2872,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" dependencies = [ "bytes", - "http 1.1.0", + "http 1.3.1", ] [[package]] @@ -2871,7 +2883,7 @@ checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" dependencies = [ "bytes", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "pin-project-lite", ] @@ -2995,7 +3007,7 @@ dependencies = [ "futures-channel", "futures-util", "h2 0.4.4", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "httparse", "httpdate", @@ -3028,7 +3040,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a0bea761b46ae2b24eb4aef630d8d1c398157b6fc29e6350ecf090a0b70c952c" dependencies = [ "futures-util", - "http 1.1.0", + "http 1.3.1", "hyper 1.4.1", "hyper-util", "rustls 0.22.4", @@ -3060,7 +3072,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "hyper 1.4.1", "pin-project-lite", @@ -3709,7 +3721,7 @@ version = "0.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9e6777fc80a575f9503d908c8b498782a6c3ee88a06cb416dc3941401e43b94" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "syn 2.0.100", @@ 
-3770,8 +3782,8 @@ dependencies = [ "once_cell", "procfs", "prometheus", - "rand 0.8.5", - "rand_distr 0.4.3", + "rand 0.9.1", + "rand_distr", "twox-hash", ] @@ -3863,7 +3875,7 @@ dependencies = [ "lock_api", "nix 0.30.1", "rand 0.9.1", - "rand_distr 0.5.1", + "rand_distr", "rustc-hash 2.1.1", "tempfile", "thiserror 1.0.69", @@ -4160,7 +4172,7 @@ checksum = "10a8a7f5f6ba7c1b286c2fbca0454eaba116f63bbe69ed250b642d36fbb04d80" dependencies = [ "async-trait", "bytes", - "http 1.1.0", + "http 1.3.1", "opentelemetry", "reqwest", ] @@ -4173,7 +4185,7 @@ checksum = "91cf61a1868dacc576bf2b2a1c3e9ab150af7272909e80085c3173384fe11f76" dependencies = [ "async-trait", "futures-core", - "http 1.1.0", + "http 1.3.1", "opentelemetry", "opentelemetry-http", "opentelemetry-proto", @@ -4252,6 +4264,30 @@ dependencies = [ "winapi", ] +[[package]] +name = "ouroboros" +version = "0.18.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e0f050db9c44b97a94723127e6be766ac5c340c48f2c4bb3ffa11713744be59" +dependencies = [ + "aliasable", + "ouroboros_macro", + "static_assertions", +] + +[[package]] +name = "ouroboros_macro" +version = "0.18.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c7028bdd3d43083f6d8d4d5187680d0d3560d54df4cc9d752005268b41e64d0" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "proc-macro2-diagnostics", + "quote", + "syn 2.0.100", +] + [[package]] name = "outref" version = "0.5.1" @@ -4315,7 +4351,7 @@ dependencies = [ "pageserver_client_grpc", "pageserver_page_api", "pprof", - "rand 0.8.5", + "rand 0.9.1", "reqwest", "serde", "serde_json", @@ -4381,7 +4417,7 @@ dependencies = [ "hashlink", "hex", "hex-literal", - "http 1.1.0", + "http 1.3.1", "http-utils", "humantime", "humantime-serde", @@ -4412,7 +4448,7 @@ dependencies = [ "pprof", "pq_proto", "procfs", - "rand 0.8.5", + "rand 0.9.1", "range-set-blaze", "regex", "remote_storage", @@ -4479,7 +4515,7 @@ dependencies = [ "postgres_ffi_types", "postgres_versioninfo", "posthog_client_lite", - "rand 0.8.5", + "rand 0.9.1", "remote_storage", "reqwest", "serde", @@ -4549,7 +4585,7 @@ dependencies = [ "once_cell", "pageserver_api", "pin-project-lite", - "rand 0.8.5", + "rand 0.9.1", "svg_fmt", "tokio", "tracing", @@ -4923,7 +4959,7 @@ dependencies = [ "fallible-iterator", "hmac", "memchr", - "rand 0.8.5", + "rand 0.9.1", "sha2", "stringprep", "tokio", @@ -5115,7 +5151,7 @@ dependencies = [ "bytes", "itertools 0.10.5", "postgres-protocol", - "rand 0.8.5", + "rand 0.9.1", "serde", "thiserror 1.0.69", "tokio", @@ -5149,6 +5185,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "proc-macro2-diagnostics" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", + "version_check", + "yansi", +] + [[package]] name = "procfs" version = "0.16.0" @@ -5218,7 +5267,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" dependencies = [ "bytes", - "heck", + "heck 0.5.0", "itertools 0.12.1", "log", "multimap", @@ -5239,7 +5288,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15" dependencies = [ "bytes", - "heck", + "heck 0.5.0", "itertools 0.12.1", "log", "multimap", @@ -5335,7 +5384,7 @@ dependencies = [ "hex", "hmac", "hostname", 
- "http 1.1.0", + "http 1.3.1", "http-body-util", "http-utils", "humantime", @@ -5355,6 +5404,7 @@ dependencies = [ "metrics", "once_cell", "opentelemetry", + "ouroboros", "p256 0.13.2", "papaya", "parking_lot 0.12.1", @@ -5365,8 +5415,9 @@ dependencies = [ "postgres-protocol2", "postgres_backend", "pq_proto", - "rand 0.8.5", - "rand_distr 0.4.3", + "rand 0.9.1", + "rand_core 0.6.4", + "rand_distr", "rcgen", "redis", "regex", @@ -5391,6 +5442,7 @@ dependencies = [ "socket2", "strum_macros", "subtle", + "subzero-core", "thiserror 1.0.69", "tikv-jemalloc-ctl", "tikv-jemallocator", @@ -5567,16 +5619,6 @@ dependencies = [ "getrandom 0.3.3", ] -[[package]] -name = "rand_distr" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" -dependencies = [ - "num-traits", - "rand 0.8.5", -] - [[package]] name = "rand_distr" version = "0.5.1" @@ -5706,14 +5748,14 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.2" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.3", - "regex-syntax 0.8.2", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", ] [[package]] @@ -5727,13 +5769,13 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.3" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.2", + "regex-syntax 0.8.5", ] [[package]] @@ -5750,9 +5792,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "regex-syntax" -version = "0.8.2" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "relative-path" @@ -5790,7 +5832,7 @@ dependencies = [ "metrics", "once_cell", "pin-project-lite", - "rand 0.8.5", + "rand 0.9.1", "reqwest", "scopeguard", "serde", @@ -5822,7 +5864,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "hyper 1.4.1", @@ -5864,7 +5906,7 @@ checksum = "d1ccd3b55e711f91a9885a2fa6fbbb2e39db1776420b062efc058c6410f7e5e3" dependencies = [ "anyhow", "async-trait", - "http 1.1.0", + "http 1.3.1", "reqwest", "serde", "thiserror 1.0.69", @@ -5881,7 +5923,7 @@ dependencies = [ "async-trait", "futures", "getrandom 0.2.11", - "http 1.1.0", + "http 1.3.1", "hyper 1.4.1", "parking_lot 0.11.2", "reqwest", @@ -5902,7 +5944,7 @@ dependencies = [ "anyhow", "async-trait", "getrandom 0.2.11", - "http 1.1.0", + "http 1.3.1", "matchit", "opentelemetry", "reqwest", @@ -6261,7 +6303,7 @@ dependencies = [ "fail", "futures", "hex", - "http 1.1.0", + "http 1.3.1", "http-utils", "humantime", "hyper 0.14.30", @@ -6280,7 +6322,7 @@ dependencies = [ "postgres_versioninfo", "pprof", "pq_proto", - "rand 0.8.5", + "rand 0.9.1", "regex", "remote_storage", "reqwest", @@ -6974,7 +7016,7 @@ dependencies = [ "pageserver_client", "postgres_connection", 
"posthog_client_lite", - "rand 0.8.5", + "rand 0.9.1", "regex", "reqwest", "routerify", @@ -7110,7 +7152,7 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "rustversion", @@ -7123,6 +7165,10 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" +[[package]] +name = "subzero-core" +version = "3.0.1" + [[package]] name = "svg_fmt" version = "0.4.3" @@ -7733,7 +7779,7 @@ dependencies = [ "async-trait", "base64 0.22.1", "bytes", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "percent-encoding", @@ -7757,7 +7803,7 @@ dependencies = [ "bytes", "flate2", "h2 0.4.4", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "hyper 1.4.1", @@ -7848,7 +7894,7 @@ dependencies = [ "base64 0.22.1", "bitflags 2.8.0", "bytes", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "mime", "pin-project-lite", @@ -7869,7 +7915,7 @@ name = "tower-otel" version = "0.2.0" source = "git+https://github.com/mattiapenati/tower-otel?rev=56a7321053bcb72443888257b622ba0d43a11fcd#56a7321053bcb72443888257b622ba0d43a11fcd" dependencies = [ - "http 1.1.0", + "http 1.3.1", "opentelemetry", "pin-project", "tower-layer", @@ -8050,7 +8096,7 @@ dependencies = [ "byteorder", "bytes", "data-encoding", - "http 1.1.0", + "http 1.3.1", "httparse", "log", "rand 0.8.5", @@ -8069,7 +8115,7 @@ dependencies = [ "byteorder", "bytes", "data-encoding", - "http 1.1.0", + "http 1.3.1", "httparse", "log", "rand 0.8.5", @@ -8251,7 +8297,7 @@ dependencies = [ "postgres_connection", "pprof", "pq_proto", - "rand 0.8.5", + "rand 0.9.1", "regex", "scopeguard", "sentry", @@ -8858,8 +8904,8 @@ dependencies = [ "quote", "rand 0.8.5", "regex", - "regex-automata 0.4.3", - "regex-syntax 0.8.2", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", "reqwest", "rustls 0.23.27", "rustls-pki-types", @@ -8955,6 +9001,12 @@ version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4d25c75bf9ea12c4040a97f829154768bbbce366287e2dc044af160cd79a13fd" +[[package]] +name = "yansi" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" + [[package]] name = "yasna" version = "0.5.2" diff --git a/Cargo.toml b/Cargo.toml index 6d91262882..3a57976cd8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,6 +49,7 @@ members = [ "libs/proxy/tokio-postgres2", "endpoint_storage", "pgxn/neon/communicator", + "proxy/subzero_core", ] [workspace.package] @@ -157,7 +158,9 @@ procfs = "0.16" prometheus = {version = "0.13", default-features=false, features = ["process"]} # removes protobuf dependency prost = "0.13.5" prost-types = "0.13.5" -rand = "0.8" +rand = "0.9" +# Remove after p256 is updated to 0.14. +rand_core = "=0.6" redis = { version = "0.29.2", features = ["tokio-rustls-comp", "keep-alive"] } regex = "1.10.2" reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"] } diff --git a/Dockerfile b/Dockerfile index 55b87d4012..654ae72e56 100644 --- a/Dockerfile +++ b/Dockerfile @@ -63,7 +63,14 @@ WORKDIR /home/nonroot COPY --chown=nonroot . . 
-RUN cargo chef prepare --recipe-path recipe.json +RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \ + set -e \ + && if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \ + export CARGO_NET_GIT_FETCH_WITH_CLI=true && \ + git config --global url."https://$(cat /run/secrets/SUBZERO_ACCESS_TOKEN)@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero" && \ + cargo add -p proxy subzero-core --git https://github.com/neondatabase/subzero --rev 396264617e78e8be428682f87469bb25429af88a; \ + fi \ + && cargo chef prepare --recipe-path recipe.json # Main build image FROM $REPOSITORY/$IMAGE:$TAG AS build @@ -71,20 +78,33 @@ WORKDIR /home/nonroot ARG GIT_VERSION=local ARG BUILD_TAG ARG ADDITIONAL_RUSTFLAGS="" +ENV CARGO_FEATURES="default" # 3. Build cargo dependencies. Note that this step doesn't depend on anything else than # `recipe.json`, so the layer can be reused as long as none of the dependencies change. COPY --from=plan /home/nonroot/recipe.json recipe.json -RUN set -e \ +RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \ + set -e \ + && if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \ + export CARGO_NET_GIT_FETCH_WITH_CLI=true && \ + git config --global url."https://$(cat /run/secrets/SUBZERO_ACCESS_TOKEN)@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero"; \ + fi \ && RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo chef cook --locked --release --recipe-path recipe.json # Perform the main build. We reuse the Postgres build artifacts from the intermediate 'pg-build' # layer, and the cargo dependencies built in the previous step. COPY --chown=nonroot --from=pg-build /home/nonroot/pg_install/ pg_install COPY --chown=nonroot . . +COPY --chown=nonroot --from=plan /home/nonroot/proxy/Cargo.toml proxy/Cargo.toml +COPY --chown=nonroot --from=plan /home/nonroot/Cargo.lock Cargo.lock -RUN set -e \ +RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \ + set -e \ + && if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \ + export CARGO_FEATURES="rest_broker"; \ + fi \ && RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo build \ + --features $CARGO_FEATURES \ --bin pg_sni_router \ --bin pageserver \ --bin pagectl \ diff --git a/compute_tools/src/spec_apply.rs b/compute_tools/src/spec_apply.rs index ec7e75922b..47bf61ae1b 100644 --- a/compute_tools/src/spec_apply.rs +++ b/compute_tools/src/spec_apply.rs @@ -411,7 +411,8 @@ impl ComputeNode { .map(|limit| match limit { 0..10 => limit, 10..30 => 10, - 30.. => limit / 3, + 30..300 => limit / 3, + 300.. => 100, }) // If we didn't find max_connections, default to 10 concurrent connections. .unwrap_or(10) diff --git a/deny.toml b/deny.toml index be1c6a2f2c..7afd05a837 100644 --- a/deny.toml +++ b/deny.toml @@ -35,6 +35,7 @@ reason = "The paste crate is a build-only dependency with no runtime components. 
# More documentation for the licenses section can be found here: # https://embarkstudios.github.io/cargo-deny/checks/licenses/cfg.html [licenses] +version = 2 allow = [ "0BSD", "Apache-2.0", diff --git a/endpoint_storage/src/app.rs b/endpoint_storage/src/app.rs index a7a18743ef..64c21cc8b9 100644 --- a/endpoint_storage/src/app.rs +++ b/endpoint_storage/src/app.rs @@ -233,7 +233,7 @@ mod tests { .unwrap() .as_millis(); use rand::Rng; - let random = rand::thread_rng().r#gen::(); + let random = rand::rng().random::(); let s3_config = remote_storage::S3Config { bucket_name: var(REAL_S3_BUCKET).unwrap(), diff --git a/libs/consumption_metrics/src/lib.rs b/libs/consumption_metrics/src/lib.rs index 448134f31a..aeb33bdfc2 100644 --- a/libs/consumption_metrics/src/lib.rs +++ b/libs/consumption_metrics/src/lib.rs @@ -90,7 +90,7 @@ impl<'a> IdempotencyKey<'a> { IdempotencyKey { now: Utc::now(), node_id, - nonce: rand::thread_rng().gen_range(0..=9999), + nonce: rand::rng().random_range(0..=9999), } } diff --git a/libs/desim/src/node_os.rs b/libs/desim/src/node_os.rs index e0cde7b284..6517c2001e 100644 --- a/libs/desim/src/node_os.rs +++ b/libs/desim/src/node_os.rs @@ -41,7 +41,7 @@ impl NodeOs { /// Generate a random number in range [0, max). pub fn random(&self, max: u64) -> u64 { - self.internal.rng.lock().gen_range(0..max) + self.internal.rng.lock().random_range(0..max) } /// Append a new event to the world event log. diff --git a/libs/desim/src/options.rs b/libs/desim/src/options.rs index 9b1a42fd28..d5da008ef1 100644 --- a/libs/desim/src/options.rs +++ b/libs/desim/src/options.rs @@ -32,10 +32,10 @@ impl Delay { /// Generate a random delay in range [min, max]. Return None if the /// message should be dropped. pub fn delay(&self, rng: &mut StdRng) -> Option { - if rng.gen_bool(self.fail_prob) { + if rng.random_bool(self.fail_prob) { return None; } - Some(rng.gen_range(self.min..=self.max)) + Some(rng.random_range(self.min..=self.max)) } } diff --git a/libs/desim/src/world.rs b/libs/desim/src/world.rs index 576ba89cd7..690d45f373 100644 --- a/libs/desim/src/world.rs +++ b/libs/desim/src/world.rs @@ -69,7 +69,7 @@ impl World { /// Create a new random number generator. pub fn new_rng(&self) -> StdRng { let mut rng = self.rng.lock(); - StdRng::from_rng(rng.deref_mut()).unwrap() + StdRng::from_rng(rng.deref_mut()) } /// Create a new node. 
diff --git a/libs/metrics/Cargo.toml b/libs/metrics/Cargo.toml index f87e7b8e3a..1718ddfae2 100644 --- a/libs/metrics/Cargo.toml +++ b/libs/metrics/Cargo.toml @@ -17,5 +17,5 @@ procfs.workspace = true measured-process.workspace = true [dev-dependencies] -rand = "0.8" -rand_distr = "0.4.3" +rand.workspace = true +rand_distr = "0.5" diff --git a/libs/metrics/src/hll.rs b/libs/metrics/src/hll.rs index 1a7d7a7e44..81e5bafbdf 100644 --- a/libs/metrics/src/hll.rs +++ b/libs/metrics/src/hll.rs @@ -260,7 +260,7 @@ mod tests { #[test] fn test_cardinality_small() { - let (actual, estimate) = test_cardinality(100, Zipf::new(100, 1.2f64).unwrap()); + let (actual, estimate) = test_cardinality(100, Zipf::new(100.0, 1.2f64).unwrap()); assert_eq!(actual, [46, 30, 32]); assert!(51.3 < estimate[0] && estimate[0] < 51.4); @@ -270,7 +270,7 @@ mod tests { #[test] fn test_cardinality_medium() { - let (actual, estimate) = test_cardinality(10000, Zipf::new(10000, 1.2f64).unwrap()); + let (actual, estimate) = test_cardinality(10000, Zipf::new(10000.0, 1.2f64).unwrap()); assert_eq!(actual, [2529, 1618, 1629]); assert!(2309.1 < estimate[0] && estimate[0] < 2309.2); @@ -280,7 +280,8 @@ mod tests { #[test] fn test_cardinality_large() { - let (actual, estimate) = test_cardinality(1_000_000, Zipf::new(1_000_000, 1.2f64).unwrap()); + let (actual, estimate) = + test_cardinality(1_000_000, Zipf::new(1_000_000.0, 1.2f64).unwrap()); assert_eq!(actual, [129077, 79579, 79630]); assert!(126067.2 < estimate[0] && estimate[0] < 126067.3); @@ -290,7 +291,7 @@ mod tests { #[test] fn test_cardinality_small2() { - let (actual, estimate) = test_cardinality(100, Zipf::new(200, 0.8f64).unwrap()); + let (actual, estimate) = test_cardinality(100, Zipf::new(200.0, 0.8f64).unwrap()); assert_eq!(actual, [92, 58, 60]); assert!(116.1 < estimate[0] && estimate[0] < 116.2); @@ -300,7 +301,7 @@ mod tests { #[test] fn test_cardinality_medium2() { - let (actual, estimate) = test_cardinality(10000, Zipf::new(20000, 0.8f64).unwrap()); + let (actual, estimate) = test_cardinality(10000, Zipf::new(20000.0, 0.8f64).unwrap()); assert_eq!(actual, [8201, 5131, 5051]); assert!(6846.4 < estimate[0] && estimate[0] < 6846.5); @@ -310,7 +311,8 @@ mod tests { #[test] fn test_cardinality_large2() { - let (actual, estimate) = test_cardinality(1_000_000, Zipf::new(2_000_000, 0.8f64).unwrap()); + let (actual, estimate) = + test_cardinality(1_000_000, Zipf::new(2_000_000.0, 0.8f64).unwrap()); assert_eq!(actual, [777847, 482069, 482246]); assert!(699437.4 < estimate[0] && estimate[0] < 699437.5); diff --git a/libs/neon-shmem/Cargo.toml b/libs/neon-shmem/Cargo.toml index 7ed991502e..1cdc9c0c67 100644 --- a/libs/neon-shmem/Cargo.toml +++ b/libs/neon-shmem/Cargo.toml @@ -16,5 +16,5 @@ rustc-hash.workspace = true tempfile = "3.14.0" [dev-dependencies] -rand = "0.9" +rand.workspace = true rand_distr = "0.5.1" diff --git a/libs/pageserver_api/src/controller_api.rs b/libs/pageserver_api/src/controller_api.rs index 8f86b03f72..1248be0b5c 100644 --- a/libs/pageserver_api/src/controller_api.rs +++ b/libs/pageserver_api/src/controller_api.rs @@ -596,6 +596,7 @@ pub struct TimelineImportRequest { pub timeline_id: TimelineId, pub start_lsn: Lsn, pub sk_set: Vec, + pub force_upsert: bool, } #[derive(serde::Serialize, serde::Deserialize, Clone)] diff --git a/libs/pageserver_api/src/key.rs b/libs/pageserver_api/src/key.rs index 102bbee879..4e8fabfa72 100644 --- a/libs/pageserver_api/src/key.rs +++ b/libs/pageserver_api/src/key.rs @@ -981,12 +981,12 @@ mod tests { let mut rng = 
rand::rngs::StdRng::seed_from_u64(42); let key = Key { - field1: rng.r#gen(), - field2: rng.r#gen(), - field3: rng.r#gen(), - field4: rng.r#gen(), - field5: rng.r#gen(), - field6: rng.r#gen(), + field1: rng.random(), + field2: rng.random(), + field3: rng.random(), + field4: rng.random(), + field5: rng.random(), + field6: rng.random(), }; assert_eq!(key, Key::from_str(&format!("{key}")).unwrap()); diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index 8759eaac19..f3eb5b1212 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -446,9 +446,9 @@ pub struct ImportPgdataIdempotencyKey(pub String); impl ImportPgdataIdempotencyKey { pub fn random() -> Self { use rand::Rng; - use rand::distributions::Alphanumeric; + use rand::distr::Alphanumeric; Self( - rand::thread_rng() + rand::rng() .sample_iter(&Alphanumeric) .take(20) .map(char::from) diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index 482dd9a298..5ecb4badf1 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -203,12 +203,12 @@ impl fmt::Display for CancelKeyData { } } -use rand::distributions::{Distribution, Standard}; -impl Distribution for Standard { +use rand::distr::{Distribution, StandardUniform}; +impl Distribution for StandardUniform { fn sample(&self, rng: &mut R) -> CancelKeyData { CancelKeyData { - backend_pid: rng.r#gen(), - cancel_key: rng.r#gen(), + backend_pid: rng.random(), + cancel_key: rng.random(), } } } diff --git a/libs/proxy/postgres-protocol2/src/authentication/sasl.rs b/libs/proxy/postgres-protocol2/src/authentication/sasl.rs index 274c81c500..cfa59a34f4 100644 --- a/libs/proxy/postgres-protocol2/src/authentication/sasl.rs +++ b/libs/proxy/postgres-protocol2/src/authentication/sasl.rs @@ -155,10 +155,10 @@ pub struct ScramSha256 { fn nonce() -> String { // rand 0.5's ThreadRng is cryptographically secure - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); (0..NONCE_LENGTH) .map(|_| { - let mut v = rng.gen_range(0x21u8..0x7e); + let mut v = rng.random_range(0x21u8..0x7e); if v == 0x2c { v = 0x7e } diff --git a/libs/proxy/postgres-protocol2/src/password/mod.rs b/libs/proxy/postgres-protocol2/src/password/mod.rs index e00ca1e34c..8926710225 100644 --- a/libs/proxy/postgres-protocol2/src/password/mod.rs +++ b/libs/proxy/postgres-protocol2/src/password/mod.rs @@ -28,7 +28,7 @@ const SCRAM_DEFAULT_SALT_LEN: usize = 16; /// special characters that would require escaping in an SQL command. 
pub async fn scram_sha_256(password: &[u8]) -> String { let mut salt: [u8; SCRAM_DEFAULT_SALT_LEN] = [0; SCRAM_DEFAULT_SALT_LEN]; - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); rng.fill_bytes(&mut salt); scram_sha_256_salt(password, salt).await } diff --git a/libs/remote_storage/Cargo.toml b/libs/remote_storage/Cargo.toml index 0ae13552b8..ea06725cfd 100644 --- a/libs/remote_storage/Cargo.toml +++ b/libs/remote_storage/Cargo.toml @@ -43,7 +43,7 @@ itertools.workspace = true sync_wrapper = { workspace = true, features = ["futures"] } byteorder = "1.4" -rand = "0.8.5" +rand.workspace = true [dev-dependencies] camino-tempfile.workspace = true diff --git a/libs/remote_storage/src/simulate_failures.rs b/libs/remote_storage/src/simulate_failures.rs index e895380192..f35d2a3081 100644 --- a/libs/remote_storage/src/simulate_failures.rs +++ b/libs/remote_storage/src/simulate_failures.rs @@ -81,7 +81,7 @@ impl UnreliableWrapper { /// fn attempt(&self, op: RemoteOp) -> anyhow::Result { let mut attempts = self.attempts.lock().unwrap(); - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); match attempts.entry(op) { Entry::Occupied(mut e) => { @@ -94,7 +94,7 @@ impl UnreliableWrapper { /* BEGIN_HADRON */ // If there are more attempts to fail, fail the request by probability. if (attempts_before_this < self.attempts_to_fail) - && (rng.gen_range(0..=100) < self.attempt_failure_probability) + && (rng.random_range(0..=100) < self.attempt_failure_probability) { let error = anyhow::anyhow!("simulated failure of remote operation {:?}", e.key()); diff --git a/libs/remote_storage/tests/test_real_azure.rs b/libs/remote_storage/tests/test_real_azure.rs index 4d7caabd39..949035b8c3 100644 --- a/libs/remote_storage/tests/test_real_azure.rs +++ b/libs/remote_storage/tests/test_real_azure.rs @@ -208,7 +208,7 @@ async fn create_azure_client( .as_millis(); // because nanos can be the same for two threads so can millis, add randomness - let random = rand::thread_rng().r#gen::(); + let random = rand::rng().random::(); let remote_storage_config = RemoteStorageConfig { storage: RemoteStorageKind::AzureContainer(AzureConfig { diff --git a/libs/remote_storage/tests/test_real_s3.rs b/libs/remote_storage/tests/test_real_s3.rs index 6b893edf75..f5c81bf45d 100644 --- a/libs/remote_storage/tests/test_real_s3.rs +++ b/libs/remote_storage/tests/test_real_s3.rs @@ -385,7 +385,7 @@ async fn create_s3_client( .as_millis(); // because nanos can be the same for two threads so can millis, add randomness - let random = rand::thread_rng().r#gen::(); + let random = rand::rng().random::(); let remote_storage_config = RemoteStorageConfig { storage: RemoteStorageKind::AwsS3(S3Config { diff --git a/libs/utils/src/id.rs b/libs/utils/src/id.rs index e3037aec21..d63bba75a3 100644 --- a/libs/utils/src/id.rs +++ b/libs/utils/src/id.rs @@ -104,7 +104,7 @@ impl Id { pub fn generate() -> Self { let mut tli_buf = [0u8; 16]; - rand::thread_rng().fill(&mut tli_buf); + rand::rng().fill(&mut tli_buf); Id::from(tli_buf) } diff --git a/libs/utils/src/lsn.rs b/libs/utils/src/lsn.rs index 31e1dda23d..1abb63817b 100644 --- a/libs/utils/src/lsn.rs +++ b/libs/utils/src/lsn.rs @@ -364,42 +364,37 @@ impl MonotonicCounter for RecordLsn { } } -/// Implements [`rand::distributions::uniform::UniformSampler`] so we can sample [`Lsn`]s. +/// Implements [`rand::distr::uniform::UniformSampler`] so we can sample [`Lsn`]s. /// /// This is used by the `pagebench` pageserver benchmarking tool. 
-pub struct LsnSampler(::Sampler); +pub struct LsnSampler(::Sampler); -impl rand::distributions::uniform::SampleUniform for Lsn { +impl rand::distr::uniform::SampleUniform for Lsn { type Sampler = LsnSampler; } -impl rand::distributions::uniform::UniformSampler for LsnSampler { +impl rand::distr::uniform::UniformSampler for LsnSampler { type X = Lsn; - fn new(low: B1, high: B2) -> Self + fn new(low: B1, high: B2) -> Result where - B1: rand::distributions::uniform::SampleBorrow + Sized, - B2: rand::distributions::uniform::SampleBorrow + Sized, + B1: rand::distr::uniform::SampleBorrow + Sized, + B2: rand::distr::uniform::SampleBorrow + Sized, { - Self( - ::Sampler::new( - low.borrow().0, - high.borrow().0, - ), - ) + ::Sampler::new(low.borrow().0, high.borrow().0) + .map(Self) } - fn new_inclusive(low: B1, high: B2) -> Self + fn new_inclusive(low: B1, high: B2) -> Result where - B1: rand::distributions::uniform::SampleBorrow + Sized, - B2: rand::distributions::uniform::SampleBorrow + Sized, + B1: rand::distr::uniform::SampleBorrow + Sized, + B2: rand::distr::uniform::SampleBorrow + Sized, { - Self( - ::Sampler::new_inclusive( - low.borrow().0, - high.borrow().0, - ), + ::Sampler::new_inclusive( + low.borrow().0, + high.borrow().0, ) + .map(Self) } fn sample(&self, rng: &mut R) -> Self::X { diff --git a/pageserver/benches/bench_layer_map.rs b/pageserver/benches/bench_layer_map.rs index e1444778b8..284cc4d67d 100644 --- a/pageserver/benches/bench_layer_map.rs +++ b/pageserver/benches/bench_layer_map.rs @@ -11,7 +11,8 @@ use pageserver::tenant::layer_map::LayerMap; use pageserver::tenant::storage_layer::{LayerName, PersistentLayerDesc}; use pageserver_api::key::Key; use pageserver_api::shard::TenantShardId; -use rand::prelude::{SeedableRng, SliceRandom, StdRng}; +use rand::prelude::{SeedableRng, StdRng}; +use rand::seq::IndexedRandom; use utils::id::{TenantId, TimelineId}; use utils::lsn::Lsn; diff --git a/pageserver/compaction/src/bin/compaction-simulator.rs b/pageserver/compaction/src/bin/compaction-simulator.rs index dd35417333..6211b86809 100644 --- a/pageserver/compaction/src/bin/compaction-simulator.rs +++ b/pageserver/compaction/src/bin/compaction-simulator.rs @@ -89,7 +89,7 @@ async fn simulate(cmd: &SimulateCmd, results_path: &Path) -> anyhow::Result<()> let cold_key_range = splitpoint..key_range.end; for i in 0..cmd.num_records { - let chosen_range = if rand::thread_rng().gen_bool(0.9) { + let chosen_range = if rand::rng().random_bool(0.9) { &hot_key_range } else { &cold_key_range diff --git a/pageserver/compaction/src/simulator.rs b/pageserver/compaction/src/simulator.rs index bf9f6f2658..44507c335b 100644 --- a/pageserver/compaction/src/simulator.rs +++ b/pageserver/compaction/src/simulator.rs @@ -300,9 +300,9 @@ impl MockTimeline { key_range: &Range, ) -> anyhow::Result<()> { crate::helpers::union_to_keyspace(&mut self.keyspace, vec![key_range.clone()]); - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for _ in 0..num_records { - self.ingest_record(rng.gen_range(key_range.clone()), len); + self.ingest_record(rng.random_range(key_range.clone()), len); self.wal_ingested += len; } Ok(()) diff --git a/pageserver/pagebench/src/cmd/basebackup.rs b/pageserver/pagebench/src/cmd/basebackup.rs index c14bb73136..01875f74b9 100644 --- a/pageserver/pagebench/src/cmd/basebackup.rs +++ b/pageserver/pagebench/src/cmd/basebackup.rs @@ -188,9 +188,9 @@ async fn main_impl( start_work_barrier.wait().await; loop { let (timeline, work) = { - let mut rng = rand::thread_rng(); + let mut 
rng = rand::rng(); let target = all_targets.choose(&mut rng).unwrap(); - let lsn = target.lsn_range.clone().map(|r| rng.gen_range(r)); + let lsn = target.lsn_range.clone().map(|r| rng.random_range(r)); (target.timeline, Work { lsn }) }; let sender = work_senders.get(&timeline).unwrap(); diff --git a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs index 30b30d36f6..ed7fe9c4ea 100644 --- a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs +++ b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs @@ -326,8 +326,7 @@ async fn main_impl( .cloned() .collect(); let weights = - rand::distributions::weighted::WeightedIndex::new(ranges.iter().map(|v| v.len())) - .unwrap(); + rand::distr::weighted::WeightedIndex::new(ranges.iter().map(|v| v.len())).unwrap(); Box::pin(async move { let scheme = match Url::parse(&args.page_service_connstring) { @@ -427,7 +426,7 @@ async fn run_worker( cancel: CancellationToken, rps_period: Option, ranges: Vec, - weights: rand::distributions::weighted::WeightedIndex, + weights: rand::distr::weighted::WeightedIndex, ) { shared_state.start_work_barrier.wait().await; let client_start = Instant::now(); @@ -469,9 +468,9 @@ async fn run_worker( } // Pick a random page from a random relation. - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); let r = &ranges[weights.sample(&mut rng)]; - let key: i128 = rng.gen_range(r.start..r.end); + let key: i128 = rng.random_range(r.start..r.end); let (rel_tag, block_no) = key_to_block(key); let mut blks = VecDeque::with_capacity(batch_size); @@ -502,7 +501,7 @@ async fn run_worker( // We assume that the entire batch can fit within the relation. assert_eq!(blks.len(), batch_size, "incomplete batch"); - let req_lsn = if rng.gen_bool(args.req_latest_probability) { + let req_lsn = if rng.random_bool(args.req_latest_probability) { Lsn::MAX } else { r.timeline_lsn diff --git a/pageserver/pagebench/src/cmd/ondemand_download_churn.rs b/pageserver/pagebench/src/cmd/ondemand_download_churn.rs index 9ff1e638c4..8fbb452140 100644 --- a/pageserver/pagebench/src/cmd/ondemand_download_churn.rs +++ b/pageserver/pagebench/src/cmd/ondemand_download_churn.rs @@ -7,7 +7,7 @@ use std::time::{Duration, Instant}; use pageserver_api::models::HistoricLayerInfo; use pageserver_api::shard::TenantShardId; use pageserver_client::mgmt_api; -use rand::seq::SliceRandom; +use rand::seq::IndexedMutRandom; use tokio::sync::{OwnedSemaphorePermit, mpsc}; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; @@ -260,7 +260,7 @@ async fn timeline_actor( loop { let layer_tx = { - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); timeline.layers.choose_mut(&mut rng).expect("no layers") }; match layer_tx.try_send(permit.take().unwrap()) { diff --git a/pageserver/src/feature_resolver.rs b/pageserver/src/feature_resolver.rs index f0178fd9b3..11b0e972b4 100644 --- a/pageserver/src/feature_resolver.rs +++ b/pageserver/src/feature_resolver.rs @@ -155,7 +155,7 @@ impl FeatureResolver { ); let tenant_properties = PerTenantProperties { - remote_size_mb: Some(rand::thread_rng().gen_range(100.0..1000000.00)), + remote_size_mb: Some(rand::rng().random_range(100.0..1000000.00)), } .into_posthog_properties(); diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 317deadfae..ebcfcfbf64 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -6214,11 +6214,11 @@ mod tests { use pageserver_api::keyspace::KeySpaceRandomAccum; use 
pageserver_api::models::{CompactionAlgorithm, CompactionAlgorithmSettings, LsnLease}; use pageserver_compaction::helpers::overlaps_with; + use rand::Rng; #[cfg(feature = "testing")] use rand::SeedableRng; #[cfg(feature = "testing")] use rand::rngs::StdRng; - use rand::{Rng, thread_rng}; #[cfg(feature = "testing")] use std::ops::Range; use storage_layer::{IoConcurrency, PersistentLayerKey}; @@ -6340,8 +6340,8 @@ mod tests { while lsn < lsn_range.end { let mut key = key_range.start; while key < key_range.end { - let gap = random.gen_range(1..=100) <= spec.gap_chance; - let will_init = random.gen_range(1..=100) <= spec.will_init_chance; + let gap = random.random_range(1..=100) <= spec.gap_chance; + let will_init = random.random_range(1..=100) <= spec.will_init_chance; if gap { continue; @@ -6384,8 +6384,8 @@ mod tests { while lsn < lsn_range.end { let mut key = key_range.start; while key < key_range.end { - let gap = random.gen_range(1..=100) <= spec.gap_chance; - let will_init = random.gen_range(1..=100) <= spec.will_init_chance; + let gap = random.random_range(1..=100) <= spec.gap_chance; + let will_init = random.random_range(1..=100) <= spec.will_init_chance; if gap { continue; @@ -7862,7 +7862,7 @@ mod tests { for _ in 0..50 { for _ in 0..NUM_KEYS { lsn = Lsn(lsn.0 + 0x10); - let blknum = thread_rng().gen_range(0..NUM_KEYS); + let blknum = rand::rng().random_range(0..NUM_KEYS); test_key.field6 = blknum as u32; let mut writer = tline.writer().await; writer @@ -7951,7 +7951,7 @@ mod tests { for _ in 0..NUM_KEYS { lsn = Lsn(lsn.0 + 0x10); - let blknum = thread_rng().gen_range(0..NUM_KEYS); + let blknum = rand::rng().random_range(0..NUM_KEYS); test_key.field6 = blknum as u32; let mut writer = tline.writer().await; writer @@ -8019,7 +8019,7 @@ mod tests { for _ in 0..NUM_KEYS { lsn = Lsn(lsn.0 + 0x10); - let blknum = thread_rng().gen_range(0..NUM_KEYS); + let blknum = rand::rng().random_range(0..NUM_KEYS); test_key.field6 = blknum as u32; let mut writer = tline.writer().await; writer @@ -8283,7 +8283,7 @@ mod tests { for _ in 0..NUM_KEYS { lsn = Lsn(lsn.0 + 0x10); - let blknum = thread_rng().gen_range(0..NUM_KEYS); + let blknum = rand::rng().random_range(0..NUM_KEYS); test_key.field6 = (blknum * STEP) as u32; let mut writer = tline.writer().await; writer @@ -8556,7 +8556,7 @@ mod tests { for iter in 1..=10 { for _ in 0..NUM_KEYS { lsn = Lsn(lsn.0 + 0x10); - let blknum = thread_rng().gen_range(0..NUM_KEYS); + let blknum = rand::rng().random_range(0..NUM_KEYS); test_key.field6 = (blknum * STEP) as u32; let mut writer = tline.writer().await; writer @@ -11680,10 +11680,10 @@ mod tests { #[cfg(feature = "testing")] #[tokio::test] async fn test_read_path() -> anyhow::Result<()> { - use rand::seq::SliceRandom; + use rand::seq::IndexedRandom; let seed = if cfg!(feature = "fuzz-read-path") { - let seed: u64 = thread_rng().r#gen(); + let seed: u64 = rand::rng().random(); seed } else { // Use a hard-coded seed when not in fuzzing mode. 
@@ -11697,8 +11697,8 @@ mod tests { let (queries, will_init_chance, gap_chance) = if cfg!(feature = "fuzz-read-path") { const QUERIES: u64 = 5000; - let will_init_chance: u8 = random.gen_range(0..=10); - let gap_chance: u8 = random.gen_range(0..=50); + let will_init_chance: u8 = random.random_range(0..=10); + let gap_chance: u8 = random.random_range(0..=50); (QUERIES, will_init_chance, gap_chance) } else { @@ -11799,7 +11799,8 @@ mod tests { while used_keys.len() < tenant.conf.max_get_vectored_keys.get() { let selected_lsn = interesting_lsns.choose(&mut random).expect("not empty"); - let mut selected_key = start_key.add(random.gen_range(0..KEY_DIMENSION_SIZE)); + let mut selected_key = + start_key.add(random.random_range(0..KEY_DIMENSION_SIZE)); while used_keys.len() < tenant.conf.max_get_vectored_keys.get() { if used_keys.contains(&selected_key) @@ -11814,7 +11815,7 @@ mod tests { .add_key(selected_key); used_keys.insert(selected_key); - let pick_next = random.gen_range(0..=100) <= PICK_NEXT_CHANCE; + let pick_next = random.random_range(0..=100) <= PICK_NEXT_CHANCE; if pick_next { selected_key = selected_key.next(); } else { diff --git a/pageserver/src/tenant/blob_io.rs b/pageserver/src/tenant/blob_io.rs index ed541c4f12..29320f088c 100644 --- a/pageserver/src/tenant/blob_io.rs +++ b/pageserver/src/tenant/blob_io.rs @@ -535,8 +535,8 @@ pub(crate) mod tests { } pub(crate) fn random_array(len: usize) -> Vec { - let mut rng = rand::thread_rng(); - (0..len).map(|_| rng.r#gen()).collect::<_>() + let mut rng = rand::rng(); + (0..len).map(|_| rng.random()).collect::<_>() } #[tokio::test] @@ -588,9 +588,9 @@ pub(crate) mod tests { let mut rng = rand::rngs::StdRng::seed_from_u64(42); let blobs = (0..1024) .map(|_| { - let mut sz: u16 = rng.r#gen(); + let mut sz: u16 = rng.random(); // Make 50% of the arrays small - if rng.r#gen() { + if rng.random() { sz &= 63; } random_array(sz.into()) diff --git a/pageserver/src/tenant/disk_btree.rs b/pageserver/src/tenant/disk_btree.rs index 419befa41b..40f405307c 100644 --- a/pageserver/src/tenant/disk_btree.rs +++ b/pageserver/src/tenant/disk_btree.rs @@ -1090,7 +1090,7 @@ pub(crate) mod tests { const NUM_KEYS: usize = 100000; let mut all_data: BTreeMap = BTreeMap::new(); for idx in 0..NUM_KEYS { - let u: f64 = rand::thread_rng().gen_range(0.0..1.0); + let u: f64 = rand::rng().random_range(0.0..1.0); let t = -(f64::ln(u)); let key_int = (t * 1000000.0) as u128; @@ -1116,7 +1116,7 @@ pub(crate) mod tests { // Test get() operations on random keys, most of which will not exist for _ in 0..100000 { - let key_int = rand::thread_rng().r#gen::(); + let key_int = rand::rng().random::(); let search_key = u128::to_be_bytes(key_int); assert!(reader.get(&search_key, &ctx).await? 
== all_data.get(&key_int).cloned()); } diff --git a/pageserver/src/tenant/ephemeral_file.rs b/pageserver/src/tenant/ephemeral_file.rs index 203b5bf592..f2be129090 100644 --- a/pageserver/src/tenant/ephemeral_file.rs +++ b/pageserver/src/tenant/ephemeral_file.rs @@ -508,8 +508,8 @@ mod tests { let write_nbytes = cap * 2 + cap / 2; - let content: Vec = rand::thread_rng() - .sample_iter(rand::distributions::Standard) + let content: Vec = rand::rng() + .sample_iter(rand::distr::StandardUniform) .take(write_nbytes) .collect(); @@ -565,8 +565,8 @@ mod tests { let cap = writer.mutable().capacity(); drop(writer); - let content: Vec = rand::thread_rng() - .sample_iter(rand::distributions::Standard) + let content: Vec = rand::rng() + .sample_iter(rand::distr::StandardUniform) .take(cap * 2 + cap / 2) .collect(); @@ -614,8 +614,8 @@ mod tests { let cap = mutable.capacity(); let align = mutable.align(); drop(writer); - let content: Vec = rand::thread_rng() - .sample_iter(rand::distributions::Standard) + let content: Vec = rand::rng() + .sample_iter(rand::distr::StandardUniform) .take(cap * 2 + cap / 2) .collect(); diff --git a/pageserver/src/tenant/mgr.rs b/pageserver/src/tenant/mgr.rs index 01db09ed59..9b196ae393 100644 --- a/pageserver/src/tenant/mgr.rs +++ b/pageserver/src/tenant/mgr.rs @@ -19,7 +19,7 @@ use pageserver_api::shard::{ }; use pageserver_api::upcall_api::ReAttachResponseTenant; use rand::Rng; -use rand::distributions::Alphanumeric; +use rand::distr::Alphanumeric; use remote_storage::TimeoutOrCancel; use sysinfo::SystemExt; use tokio::fs; @@ -218,7 +218,7 @@ async fn safe_rename_tenant_dir(path: impl AsRef) -> std::io::Result Duration { if d == Duration::ZERO { d } else { - rand::thread_rng().gen_range((d * (100 - pct)) / 100..(d * (100 + pct)) / 100) + rand::rng().random_range((d * (100 - pct)) / 100..(d * (100 + pct)) / 100) } } @@ -35,7 +35,7 @@ pub(super) fn period_warmup(period: Duration) -> Duration { if period == Duration::ZERO { period } else { - rand::thread_rng().gen_range(Duration::ZERO..period) + rand::rng().random_range(Duration::ZERO..period) } } diff --git a/pageserver/src/tenant/storage_layer/delta_layer.rs b/pageserver/src/tenant/storage_layer/delta_layer.rs index c2f76c859c..f963fdac92 100644 --- a/pageserver/src/tenant/storage_layer/delta_layer.rs +++ b/pageserver/src/tenant/storage_layer/delta_layer.rs @@ -1634,7 +1634,8 @@ pub(crate) mod test { use bytes::Bytes; use itertools::MinMaxResult; use postgres_ffi::PgMajorVersion; - use rand::prelude::{SeedableRng, SliceRandom, StdRng}; + use rand::prelude::{SeedableRng, StdRng}; + use rand::seq::IndexedRandom; use rand::{Rng, RngCore}; /// Construct an index for a fictional delta layer and and then @@ -1788,14 +1789,14 @@ pub(crate) mod test { let mut entries = Vec::new(); for _ in 0..constants::KEY_COUNT { - let count = rng.gen_range(1..constants::MAX_ENTRIES_PER_KEY); + let count = rng.random_range(1..constants::MAX_ENTRIES_PER_KEY); let mut lsns_iter = std::iter::successors(Some(Lsn(constants::LSN_OFFSET.0 + 0x08)), |lsn| { Some(Lsn(lsn.0 + 0x08)) }); let mut lsns = Vec::new(); while lsns.len() < count as usize { - let take = rng.gen_bool(0.5); + let take = rng.random_bool(0.5); let lsn = lsns_iter.next().unwrap(); if take { lsns.push(lsn); @@ -1869,12 +1870,13 @@ pub(crate) mod test { for _ in 0..constants::RANGES_COUNT { let mut range: Option> = Option::default(); while range.is_none() || keyspace.overlaps(range.as_ref().unwrap()) { - let range_start = rng.gen_range(start..end); + let range_start = 
rng.random_range(start..end); let range_end_offset = range_start + constants::MIN_RANGE_SIZE; if range_end_offset >= end { range = Some(Key::from_i128(range_start)..Key::from_i128(end)); } else { - let range_end = rng.gen_range((range_start + constants::MIN_RANGE_SIZE)..end); + let range_end = + rng.random_range((range_start + constants::MIN_RANGE_SIZE)..end); range = Some(Key::from_i128(range_start)..Key::from_i128(range_end)); } } diff --git a/pageserver/src/tenant/storage_layer/inmemory_layer/vectored_dio_read.rs b/pageserver/src/tenant/storage_layer/inmemory_layer/vectored_dio_read.rs index 27fbc6f5fb..84f4386087 100644 --- a/pageserver/src/tenant/storage_layer/inmemory_layer/vectored_dio_read.rs +++ b/pageserver/src/tenant/storage_layer/inmemory_layer/vectored_dio_read.rs @@ -440,8 +440,8 @@ mod tests { impl InMemoryFile { fn new_random(len: usize) -> Self { Self { - content: rand::thread_rng() - .sample_iter(rand::distributions::Standard) + content: rand::rng() + .sample_iter(rand::distr::StandardUniform) .take(len) .collect(), } @@ -498,7 +498,7 @@ mod tests { len } }; - rand::Rng::fill(&mut rand::thread_rng(), &mut dst_slice[nread..]); // to discover bugs + rand::Rng::fill(&mut rand::rng(), &mut dst_slice[nread..]); // to discover bugs Ok((dst, nread)) } } @@ -763,7 +763,7 @@ mod tests { let len = std::cmp::min(dst.bytes_total(), mocked_bytes.len()); let dst_slice: &mut [u8] = dst.as_mut_rust_slice_full_zeroed(); dst_slice[..len].copy_from_slice(&mocked_bytes[..len]); - rand::Rng::fill(&mut rand::thread_rng(), &mut dst_slice[len..]); // to discover bugs + rand::Rng::fill(&mut rand::rng(), &mut dst_slice[len..]); // to discover bugs Ok((dst, len)) } Err(e) => Err(std::io::Error::other(e)), diff --git a/pageserver/src/tenant/tasks.rs b/pageserver/src/tenant/tasks.rs index 08fc7d61a5..676b39e55b 100644 --- a/pageserver/src/tenant/tasks.rs +++ b/pageserver/src/tenant/tasks.rs @@ -515,7 +515,7 @@ pub(crate) async fn sleep_random_range( interval: RangeInclusive, cancel: &CancellationToken, ) -> Result { - let delay = rand::thread_rng().gen_range(interval); + let delay = rand::rng().random_range(interval); if delay == Duration::ZERO { return Ok(delay); } diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 97c605ed74..371d335285 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -2873,7 +2873,7 @@ impl Timeline { if r.numerator == 0 { false } else { - rand::thread_rng().gen_range(0..r.denominator) < r.numerator + rand::rng().random_range(0..r.denominator) < r.numerator } } None => false, @@ -3955,7 +3955,7 @@ impl Timeline { // 1hour base (60_i64 * 60_i64) // 10min jitter - + rand::thread_rng().gen_range(-10 * 60..10 * 60), + + rand::rng().random_range(-10 * 60..10 * 60), ) .expect("10min < 1hour"), ); diff --git a/pageserver/src/virtual_file.rs b/pageserver/src/virtual_file.rs index 45b6e44c54..a7f0c5914a 100644 --- a/pageserver/src/virtual_file.rs +++ b/pageserver/src/virtual_file.rs @@ -1275,8 +1275,8 @@ mod tests { use std::sync::Arc; use owned_buffers_io::io_buf_ext::IoBufExt; + use rand::Rng; use rand::seq::SliceRandom; - use rand::{Rng, thread_rng}; use super::*; use crate::context::DownloadBehavior; @@ -1358,7 +1358,7 @@ mod tests { // Check that all the other FDs still work too. Use them in random order for // good measure. 
- file_b_dupes.as_mut_slice().shuffle(&mut thread_rng()); + file_b_dupes.as_mut_slice().shuffle(&mut rand::rng()); for vfile in file_b_dupes.iter_mut() { assert_first_512_eq(vfile, b"content_b").await; } @@ -1413,9 +1413,8 @@ mod tests { let ctx = ctx.detached_child(TaskKind::UnitTest, DownloadBehavior::Error); let hdl = rt.spawn(async move { let mut buf = IoBufferMut::with_capacity_zeroed(SIZE); - let mut rng = rand::rngs::OsRng; for _ in 1..1000 { - let f = &files[rng.gen_range(0..files.len())]; + let f = &files[rand::rng().random_range(0..files.len())]; buf = f .read_exact_at(buf.slice_full(), 0, &ctx) .await diff --git a/pgxn/neon/communicator/Cargo.toml b/pgxn/neon/communicator/Cargo.toml index e95a269d90..b5ce389297 100644 --- a/pgxn/neon/communicator/Cargo.toml +++ b/pgxn/neon/communicator/Cargo.toml @@ -11,6 +11,9 @@ crate-type = ["staticlib"] # 'testing' feature is currently unused in the communicator, but we accept it for convenience of # calling build scripts, so that you can pass the same feature to all packages. testing = [] +# 'rest_broker' feature is currently unused in the communicator, but we accept it for convenience of +# calling build scripts, so that you can pass the same feature to all packages. +rest_broker = [] [dependencies] neon-shmem.workspace = true diff --git a/pgxn/neon/walproposer_pg.c b/pgxn/neon/walproposer_pg.c index 9ed8d0d2d2..93807be8c2 100644 --- a/pgxn/neon/walproposer_pg.c +++ b/pgxn/neon/walproposer_pg.c @@ -400,6 +400,14 @@ static uint64 backpressure_lag_impl(void) { struct WalproposerShmemState* state = NULL; + + /* BEGIN_HADRON */ + if(max_cluster_size < 0){ + // if max cluster size is not set, then we don't apply backpressure because we're reconfiguring PG + return 0; + } + /* END_HADRON */ + if (max_replication_apply_lag > 0 || max_replication_flush_lag > 0 || max_replication_write_lag > 0) { XLogRecPtr writePtr; diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 82fe6818e3..3c3f93c8e3 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -7,6 +7,7 @@ license.workspace = true [features] default = [] testing = ["dep:tokio-postgres"] +rest_broker = ["dep:subzero-core", "dep:ouroboros"] [dependencies] ahash.workspace = true @@ -65,6 +66,7 @@ postgres-client = { package = "tokio-postgres2", path = "../libs/proxy/tokio-pos postgres-protocol = { package = "postgres-protocol2", path = "../libs/proxy/postgres-protocol2" } pq_proto.workspace = true rand.workspace = true +rand_core.workspace = true regex.workspace = true remote_storage = { version = "0.1", path = "../libs/remote_storage/" } reqwest = { workspace = true, features = ["rustls-tls-native-roots"] } @@ -105,6 +107,11 @@ uuid.workspace = true x509-cert.workspace = true redis.workspace = true zerocopy.workspace = true +# uncomment this to use the real subzero-core crate +# subzero-core = { git = "https://github.com/neondatabase/subzero", rev = "396264617e78e8be428682f87469bb25429af88a", features = ["postgresql"], optional = true } +# this is a stub for the subzero-core crate +subzero-core = { path = "./subzero_core", features = ["postgresql"], optional = true} +ouroboros = { version = "0.18", optional = true } # jwt stuff jose-jwa = "0.1.2" @@ -127,6 +134,6 @@ pbkdf2 = { workspace = true, features = ["simple", "std"] } rcgen.workspace = true rstest.workspace = true walkdir.workspace = true -rand_distr = "0.4" +rand_distr = "0.5" tokio-postgres.workspace = true tracing-test = "0.2" diff --git a/proxy/README.md b/proxy/README.md index ff48f9f323..ce957b90af 100644 --- a/proxy/README.md +++ 
b/proxy/README.md @@ -178,16 +178,24 @@ Create a configuration file called `local_proxy.json` in the root of the repo (u Start the local proxy: ```sh -cargo run --bin local_proxy -- \ - --disable_pg_session_jwt true \ +cargo run --bin local_proxy --features testing -- \ + --disable-pg-session-jwt \ --http 0.0.0.0:7432 ``` -Start the auth broker: +Start the auth/rest broker: + +Note: to enable the rest broker you need to replace the stub subzero-core crate with the real one. + ```sh -LOGFMT=text OTEL_SDK_DISABLED=true cargo run --bin proxy --features testing -- \ +cargo add -p proxy subzero-core --git https://github.com/neondatabase/subzero --rev 396264617e78e8be428682f87469bb25429af88a +``` + +```sh +LOGFMT=text OTEL_SDK_DISABLED=true cargo run --bin proxy --features testing,rest_broker -- \ -c server.crt -k server.key \ --is-auth-broker true \ + --is-rest-broker true \ --wss 0.0.0.0:8080 \ --http 0.0.0.0:7002 \ --auth-backend local @@ -205,3 +213,9 @@ curl -k "https://foo.local.neon.build:8080/sql" \ -H "neon-connection-string: postgresql://authenticator@foo.local.neon.build/database" \ -d '{"query":"select 1","params":[]}' ``` + +Make a rest request against the auth broker (rest broker): +```sh +curl -k "https://foo.local.neon.build:8080/database/rest/v1/items?select=id,name&id=eq.1" \ +-H "Authorization: Bearer $NEON_JWT" +``` diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index f561df9202..b06ed3a0ae 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -180,8 +180,6 @@ async fn authenticate( return Err(auth::AuthError::NetworkNotAllowed); } - client.write_message(BeMessage::NoticeResponse("Connecting to database.")); - // Backwards compatibility. pg_sni_proxy uses "--" in domain names // while direct connections do not. Once we migrate to pg_sni_proxy // everywhere, we can remove this. 
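(Editorial note: the hard-coded "Connecting to database." notice removed here is superseded later in this patch by a config-driven greeting; see the `greetings` field on `ProxyConfig` and the `finish_client_init` changes below. A rough sketch of the replacement behaviour, using names from those later hunks and simplified for illustration.)

```rust
// Sketch only: the greeting string now lives in ProxyConfig (populated from the
// NEON_MOTD env var, or defaulted for the console-redirect backend) and is sent
// together with the session id instead of a fixed notice.
if !config.greetings.is_empty() {
    let msg = format!("{}, session_id: {}", config.greetings, ctx.session_id());
    client.write_message(BeMessage::NoticeResponse(msg.as_str()));
}
```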
diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index a716890a00..6eba869870 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -803,7 +803,7 @@ mod tests { use http_body_util::Full; use hyper::service::service_fn; use hyper_util::rt::TokioIo; - use rand::rngs::OsRng; + use rand_core::OsRng; use rsa::pkcs8::DecodePrivateKey; use serde::Serialize; use serde_json::json; diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs index 401203d48c..7b9012dc69 100644 --- a/proxy/src/binary/local_proxy.rs +++ b/proxy/src/binary/local_proxy.rs @@ -1,3 +1,4 @@ +use std::env; use std::net::SocketAddr; use std::pin::pin; use std::sync::Arc; @@ -20,6 +21,8 @@ use crate::auth::backend::jwt::JwkCache; use crate::auth::backend::local::LocalBackend; use crate::auth::{self}; use crate::cancellation::CancellationHandler; +#[cfg(feature = "rest_broker")] +use crate::config::RestConfig; use crate::config::{ self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig, refresh_config_loop, @@ -262,6 +265,14 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig timeout: Duration::from_secs(2), }; + let greetings = env::var_os("NEON_MOTD").map_or(String::new(), |s| match s.into_string() { + Ok(s) => s, + Err(_) => { + debug!("NEON_MOTD environment variable is not valid UTF-8"); + String::new() + } + }); + Ok(Box::leak(Box::new(ProxyConfig { tls_config: ArcSwapOption::from(None), metric_collection: None, @@ -276,11 +287,19 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig accept_jwts: true, console_redirect_confirmation_timeout: Duration::ZERO, }, + #[cfg(feature = "rest_broker")] + rest_config: RestConfig { + is_rest_broker: false, + db_schema_cache: None, + max_schema_size: 0, + hostname_prefix: String::new(), + }, proxy_protocol_v2: config::ProxyProtocolV2::Rejected, handshake_timeout: Duration::from_secs(10), wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?, connect_compute_locks, connect_to_compute: compute_config, + greetings, #[cfg(feature = "testing")] disable_pg_session_jwt: args.disable_pg_session_jwt, }))) diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 16a7dc7b67..255f6fbbee 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -1,4 +1,3 @@ -#[cfg(any(test, feature = "testing"))] use std::env; use std::net::SocketAddr; use std::path::PathBuf; @@ -14,14 +13,14 @@ use arc_swap::ArcSwapOption; use camino::Utf8PathBuf; use futures::future::Either; use itertools::{Itertools, Position}; -use rand::{Rng, thread_rng}; +use rand::Rng; use remote_storage::RemoteStorageConfig; use tokio::net::TcpListener; #[cfg(any(test, feature = "testing"))] use tokio::sync::Notify; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; -use tracing::{error, info, warn}; +use tracing::{debug, error, info, warn}; use utils::sentry_init::init_sentry; use utils::{project_build_tag, project_git_version}; @@ -31,6 +30,8 @@ use crate::auth::backend::local::LocalBackend; use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned}; use crate::batch::BatchQueue; use crate::cancellation::{CancellationHandler, CancellationProcessor}; +#[cfg(feature = "rest_broker")] +use crate::config::RestConfig; #[cfg(any(test, feature = "testing"))] use crate::config::refresh_config_loop; use crate::config::{ @@ -47,6 +48,8 @@ use crate::redis::{elasticache, notifications}; use 
crate::scram::threadpool::ThreadPool; use crate::serverless::GlobalConnPoolOptions; use crate::serverless::cancel_set::CancelSet; +#[cfg(feature = "rest_broker")] +use crate::serverless::rest::DbSchemaCache; use crate::tls::client_config::compute_client_config_with_root_certs; #[cfg(any(test, feature = "testing"))] use crate::url::ApiUrl; @@ -246,11 +249,23 @@ struct ProxyCliArgs { /// if this is not local proxy, this toggles whether we accept Postgres REST requests #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] + #[cfg(feature = "rest_broker")] is_rest_broker: bool, /// cache for `db_schema_cache` introspection (use `size=0` to disable) #[clap(long, default_value = "size=1000,ttl=1h")] + #[cfg(feature = "rest_broker")] db_schema_cache: String, + + /// Maximum size allowed for schema in bytes + #[clap(long, default_value_t = 5 * 1024 * 1024)] // 5MB + #[cfg(feature = "rest_broker")] + max_schema_size: usize, + + /// Hostname prefix to strip from request hostname to get database hostname + #[clap(long, default_value = "apirest.")] + #[cfg(feature = "rest_broker")] + hostname_prefix: String, } #[derive(clap::Args, Clone, Copy, Debug)] @@ -517,6 +532,17 @@ pub async fn run() -> anyhow::Result<()> { )); maintenance_tasks.spawn(control_plane::mgmt::task_main(mgmt_listener)); + // add a task to flush the db_schema cache every 10 minutes + #[cfg(feature = "rest_broker")] + if let Some(db_schema_cache) = &config.rest_config.db_schema_cache { + maintenance_tasks.spawn(async move { + loop { + tokio::time::sleep(Duration::from_secs(600)).await; + db_schema_cache.flush(); + } + }); + } + if let Some(metrics_config) = &config.metric_collection { // TODO: Add gc regardles of the metric collection being enabled. 
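(Editorial note: the maintenance task spawned above pairs with the `TimedLru::flush` method added further down in timed_lru.rs. A condensed sketch of the pattern, assuming the cache handle is `'static` as it is after `Box::leak` in `build_config`; illustrative only.)

```rust
// Sketch: periodically evict expired entries from the schema cache.
// `DbSchemaCache::flush` (added below) removes every entry whose
// `expires_at` is in the past; nothing else is touched.
async fn flush_schema_cache_loop(cache: &'static DbSchemaCache) {
    loop {
        tokio::time::sleep(std::time::Duration::from_secs(600)).await;
        cache.flush();
    }
}
```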
maintenance_tasks.spawn(usage_metrics::task_main(metrics_config)); @@ -547,7 +573,7 @@ pub async fn run() -> anyhow::Result<()> { attempt.into_inner() ); } - let jitter = thread_rng().gen_range(0..100); + let jitter = rand::rng().random_range(0..100); tokio::time::sleep(Duration::from_millis(1000 + jitter)).await; } } @@ -679,6 +705,49 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { timeout: Duration::from_secs(2), }; + #[cfg(feature = "rest_broker")] + let rest_config = { + let db_schema_cache_config: CacheOptions = args.db_schema_cache.parse()?; + info!("Using DbSchemaCache with options={db_schema_cache_config:?}"); + + let db_schema_cache = if args.is_rest_broker { + Some(DbSchemaCache::new( + "db_schema_cache", + db_schema_cache_config.size, + db_schema_cache_config.ttl, + true, + )) + } else { + None + }; + + RestConfig { + is_rest_broker: args.is_rest_broker, + db_schema_cache, + max_schema_size: args.max_schema_size, + hostname_prefix: args.hostname_prefix.clone(), + } + }; + + let mut greetings = env::var_os("NEON_MOTD").map_or(String::new(), |s| match s.into_string() { + Ok(s) => s, + Err(_) => { + debug!("NEON_MOTD environment variable is not valid UTF-8"); + String::new() + } + }); + + match &args.auth_backend { + AuthBackendType::ControlPlane => {} + #[cfg(any(test, feature = "testing"))] + AuthBackendType::Postgres => {} + #[cfg(any(test, feature = "testing"))] + AuthBackendType::Local => {} + AuthBackendType::ConsoleRedirect => { + greetings = "Connected to database".to_string(); + } + } + let config = ProxyConfig { tls_config, metric_collection, @@ -689,8 +758,11 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?, connect_compute_locks, connect_to_compute: compute_config, + greetings, #[cfg(feature = "testing")] disable_pg_session_jwt: false, + #[cfg(feature = "rest_broker")] + rest_config, }; let config = Box::leak(Box::new(config)); diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 0ef09a8a9a..a589dd175b 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -5,7 +5,7 @@ use std::time::Duration; use async_trait::async_trait; use clashmap::ClashMap; use clashmap::mapref::one::Ref; -use rand::{Rng, thread_rng}; +use rand::Rng; use tokio::time::Instant; use tracing::{debug, info}; @@ -343,7 +343,7 @@ impl ProjectInfoCacheImpl { } fn gc(&self) { - let shard = thread_rng().gen_range(0..self.project2ep.shards().len()); + let shard = rand::rng().random_range(0..self.project2ep.shards().len()); debug!(shard, "project_info_cache: performing epoch reclamation"); // acquire a random shard lock diff --git a/proxy/src/cache/timed_lru.rs b/proxy/src/cache/timed_lru.rs index e87cf53ab9..0a7fb40b0c 100644 --- a/proxy/src/cache/timed_lru.rs +++ b/proxy/src/cache/timed_lru.rs @@ -204,6 +204,11 @@ impl TimedLru { self.insert_raw_ttl(key, value, ttl, false); } + #[cfg(feature = "rest_broker")] + pub(crate) fn insert(&self, key: K, value: V) { + self.insert_raw_ttl(key, value, self.ttl, self.update_ttl_on_retrieval); + } + pub(crate) fn insert_unit(&self, key: K, value: V) -> (Option, Cached<&Self, ()>) { let (_, old) = self.insert_raw(key.clone(), value); @@ -214,6 +219,29 @@ impl TimedLru { (old, cached) } + + #[cfg(feature = "rest_broker")] + pub(crate) fn flush(&self) { + let now = Instant::now(); + let mut cache = self.cache.lock(); + + // Collect keys of expired entries first + 
let expired_keys: Vec<_> = cache + .iter() + .filter_map(|(key, entry)| { + if entry.expires_at <= now { + Some(key.clone()) + } else { + None + } + }) + .collect(); + + // Remove expired entries + for key in expired_keys { + cache.remove(&key); + } + } } impl TimedLru { diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 6157dc8a6a..16b1dff5f4 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -22,6 +22,8 @@ use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig}; use crate::scram::threadpool::ThreadPool; use crate::serverless::GlobalConnPoolOptions; use crate::serverless::cancel_set::CancelSet; +#[cfg(feature = "rest_broker")] +use crate::serverless::rest::DbSchemaCache; pub use crate::tls::server_config::{TlsConfig, configure_tls}; use crate::types::{Host, RoleName}; @@ -30,11 +32,14 @@ pub struct ProxyConfig { pub metric_collection: Option, pub http_config: HttpConfig, pub authentication_config: AuthenticationConfig, + #[cfg(feature = "rest_broker")] + pub rest_config: RestConfig, pub proxy_protocol_v2: ProxyProtocolV2, pub handshake_timeout: Duration, pub wake_compute_retry_config: RetryConfig, pub connect_compute_locks: ApiLocks, pub connect_to_compute: ComputeConfig, + pub greetings: String, // Greeting message sent to the client after connection establishment and contains session_id. #[cfg(feature = "testing")] pub disable_pg_session_jwt: bool, } @@ -80,6 +85,14 @@ pub struct AuthenticationConfig { pub console_redirect_confirmation_timeout: tokio::time::Duration, } +#[cfg(feature = "rest_broker")] +pub struct RestConfig { + pub is_rest_broker: bool, + pub db_schema_cache: Option, + pub max_schema_size: usize, + pub hostname_prefix: String, +} + #[derive(Debug)] pub struct MetricBackupCollectionConfig { pub remote_storage_config: Option, diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 041a56e032..014317d823 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -233,7 +233,13 @@ pub(crate) async fn handle_client( let session = cancellation_handler.get_key(); - finish_client_init(&pg_settings, *session.key(), &mut stream); + finish_client_init( + ctx, + &pg_settings, + *session.key(), + &mut stream, + &config.greetings, + ); let stream = stream.flush_and_into_inner().await?; let session_id = ctx.session_id(); diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index 4d8df19476..715b818b98 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -523,29 +523,29 @@ mod tests { fn generate_request_data(rng: &mut impl Rng) -> RequestData { RequestData { - session_id: uuid::Builder::from_random_bytes(rng.r#gen()).into_uuid(), - peer_addr: Ipv4Addr::from(rng.r#gen::<[u8; 4]>()).to_string(), + session_id: uuid::Builder::from_random_bytes(rng.random()).into_uuid(), + peer_addr: Ipv4Addr::from(rng.random::<[u8; 4]>()).to_string(), timestamp: chrono::DateTime::from_timestamp_millis( - rng.gen_range(1703862754..1803862754), + rng.random_range(1703862754..1803862754), ) .unwrap() .naive_utc(), application_name: Some("test".to_owned()), user_agent: Some("test-user-agent".to_owned()), - username: Some(hex::encode(rng.r#gen::<[u8; 4]>())), - endpoint_id: Some(hex::encode(rng.r#gen::<[u8; 16]>())), - database: Some(hex::encode(rng.r#gen::<[u8; 16]>())), - project: Some(hex::encode(rng.r#gen::<[u8; 16]>())), - branch: Some(hex::encode(rng.r#gen::<[u8; 16]>())), + username: Some(hex::encode(rng.random::<[u8; 4]>())), + endpoint_id: 
Some(hex::encode(rng.random::<[u8; 16]>())), + database: Some(hex::encode(rng.random::<[u8; 16]>())), + project: Some(hex::encode(rng.random::<[u8; 16]>())), + branch: Some(hex::encode(rng.random::<[u8; 16]>())), pg_options: None, auth_method: None, jwt_issuer: None, - protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)], + protocol: ["tcp", "ws", "http"][rng.random_range(0..3)], region: String::new(), error: None, - success: rng.r#gen(), + success: rng.random(), cold_start_info: "no", - duration_us: rng.gen_range(0..30_000_000), + duration_us: rng.random_range(0..30_000_000), disconnect_timestamp: None, } } @@ -622,15 +622,15 @@ mod tests { assert_eq!( file_stats, [ - (1313953, 3, 6000), - (1313942, 3, 6000), - (1314001, 3, 6000), - (1313958, 3, 6000), - (1314094, 3, 6000), - (1313931, 3, 6000), - (1313725, 3, 6000), - (1313960, 3, 6000), - (438318, 1, 2000) + (1313878, 3, 6000), + (1313891, 3, 6000), + (1314058, 3, 6000), + (1313914, 3, 6000), + (1313760, 3, 6000), + (1314084, 3, 6000), + (1313965, 3, 6000), + (1313911, 3, 6000), + (438290, 1, 2000) ] ); @@ -662,11 +662,11 @@ mod tests { assert_eq!( file_stats, [ - (1205810, 5, 10000), - (1205534, 5, 10000), - (1205835, 5, 10000), - (1205820, 5, 10000), - (1206074, 5, 10000) + (1206039, 5, 10000), + (1205798, 5, 10000), + (1205776, 5, 10000), + (1206051, 5, 10000), + (1205746, 5, 10000) ] ); @@ -691,15 +691,15 @@ mod tests { assert_eq!( file_stats, [ - (1313953, 3, 6000), - (1313942, 3, 6000), - (1314001, 3, 6000), - (1313958, 3, 6000), - (1314094, 3, 6000), - (1313931, 3, 6000), - (1313725, 3, 6000), - (1313960, 3, 6000), - (438318, 1, 2000) + (1313878, 3, 6000), + (1313891, 3, 6000), + (1314058, 3, 6000), + (1313914, 3, 6000), + (1313760, 3, 6000), + (1314084, 3, 6000), + (1313965, 3, 6000), + (1313911, 3, 6000), + (438290, 1, 2000) ] ); @@ -736,7 +736,7 @@ mod tests { // files are smaller than the size threshold, but they took too long to fill so were flushed early assert_eq!( file_stats, - [(658584, 2, 3001), (658298, 2, 3000), (658094, 2, 2999)] + [(658552, 2, 3001), (658265, 2, 3000), (658061, 2, 2999)] ); tmpdir.close().unwrap(); diff --git a/proxy/src/intern.rs b/proxy/src/intern.rs index d7e39ebaf4..825f2d1049 100644 --- a/proxy/src/intern.rs +++ b/proxy/src/intern.rs @@ -247,7 +247,7 @@ mod tests { use rand::{Rng, SeedableRng}; use rand_distr::Zipf; - let endpoint_dist = Zipf::new(500000, 0.8).unwrap(); + let endpoint_dist = Zipf::new(500000.0, 0.8).unwrap(); let endpoints = StdRng::seed_from_u64(272488357).sample_iter(endpoint_dist); let interner = MyId::get_interner(); diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 916604e2ec..7524133093 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -385,10 +385,10 @@ pub enum RedisMsgKind { #[derive(Default, Clone)] pub struct LatencyAccumulated { - cplane: time::Duration, - client: time::Duration, - compute: time::Duration, - retry: time::Duration, + pub cplane: time::Duration, + pub client: time::Duration, + pub compute: time::Duration, + pub retry: time::Duration, } impl std::fmt::Display for LatencyAccumulated { diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs index ad99eecda5..680a23c435 100644 --- a/proxy/src/pqproto.rs +++ b/proxy/src/pqproto.rs @@ -7,7 +7,7 @@ use std::io::{self, Cursor}; use bytes::{Buf, BufMut}; use itertools::Itertools; -use rand::distributions::{Distribution, Standard}; +use rand::distr::{Distribution, StandardUniform}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use zerocopy::{FromBytes, 
Immutable, IntoBytes, big_endian}; @@ -458,9 +458,9 @@ impl fmt::Display for CancelKeyData { .finish() } } -impl Distribution for Standard { +impl Distribution for StandardUniform { fn sample(&self, rng: &mut R) -> CancelKeyData { - id_to_cancel_key(rng.r#gen()) + id_to_cancel_key(rng.random()) } } diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 02651109e0..8b7c4ff55d 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -145,7 +145,7 @@ pub(crate) async fn handle_client( let session = cancellation_handler.get_key(); - finish_client_init(&pg_settings, *session.key(), client); + finish_client_init(ctx, &pg_settings, *session.key(), client, &config.greetings); let session_id = ctx.session_id(); let (cancel_on_shutdown, cancel) = oneshot::channel(); @@ -165,9 +165,11 @@ pub(crate) async fn handle_client( /// Finish client connection initialization: confirm auth success, send params, etc. pub(crate) fn finish_client_init( + ctx: &RequestContext, settings: &compute::PostgresSettings, cancel_key_data: CancelKeyData, client: &mut PqStream, + greetings: &String, ) { // Forward all deferred notices to the client. for notice in &settings.delayed_notice { @@ -176,6 +178,12 @@ pub(crate) fn finish_client_init( }); } + // Expose session_id to clients if we have a greeting message. + if !greetings.is_empty() { + let session_msg = format!("{}, session_id: {}", greetings, ctx.session_id()); + client.write_message(BeMessage::NoticeResponse(session_msg.as_str())); + } + // Forward all postgres connection params to the client. for (name, value) in &settings.params { client.write_message(BeMessage::ParameterStatus { @@ -184,6 +192,36 @@ pub(crate) fn finish_client_init( }); } + // Forward recorded latencies for probing requests + if let Some(testodrome_id) = ctx.get_testodrome_id() { + client.write_message(BeMessage::ParameterStatus { + name: "neon.testodrome_id".as_bytes(), + value: testodrome_id.as_bytes(), + }); + + let latency_measured = ctx.get_proxy_latency(); + + client.write_message(BeMessage::ParameterStatus { + name: "neon.cplane_latency".as_bytes(), + value: latency_measured.cplane.as_micros().to_string().as_bytes(), + }); + + client.write_message(BeMessage::ParameterStatus { + name: "neon.client_latency".as_bytes(), + value: latency_measured.client.as_micros().to_string().as_bytes(), + }); + + client.write_message(BeMessage::ParameterStatus { + name: "neon.compute_latency".as_bytes(), + value: latency_measured.compute.as_micros().to_string().as_bytes(), + }); + + client.write_message(BeMessage::ParameterStatus { + name: "neon.retry_latency".as_bytes(), + value: latency_measured.retry.as_micros().to_string().as_bytes(), + }); + } + client.write_message(BeMessage::BackendKeyData(cancel_key_data)); client.write_message(BeMessage::ReadyForQuery); } diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index dd89b05426..f8bff450e1 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -338,8 +338,8 @@ async fn scram_auth_mock() -> anyhow::Result<()> { let proxy = tokio::spawn(dummy_proxy(client, Some(server_config), Scram::mock())); use rand::Rng; - use rand::distributions::Alphanumeric; - let password: String = rand::thread_rng() + use rand::distr::Alphanumeric; + let password: String = rand::rng() .sample_iter(&Alphanumeric) .take(rand::random::() as usize) .map(char::from) diff --git a/proxy/src/rate_limiter/leaky_bucket.rs b/proxy/src/rate_limiter/leaky_bucket.rs index 12b4bda0c0..9de82e922c 100644 --- 
a/proxy/src/rate_limiter/leaky_bucket.rs +++ b/proxy/src/rate_limiter/leaky_bucket.rs @@ -3,7 +3,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use ahash::RandomState; use clashmap::ClashMap; -use rand::{Rng, thread_rng}; +use rand::Rng; use tokio::time::Instant; use tracing::info; use utils::leaky_bucket::LeakyBucketState; @@ -61,7 +61,7 @@ impl LeakyBucketRateLimiter { self.map.len() ); let n = self.map.shards().len(); - let shard = thread_rng().gen_range(0..n); + let shard = rand::rng().random_range(0..n); self.map.shards()[shard] .write() .retain(|(_, value)| !value.bucket_is_empty(now)); diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index fd1b2af023..2b3d745a0e 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -147,7 +147,7 @@ impl RateBucketInfo { impl BucketRateLimiter { pub fn new(info: impl Into>) -> Self { - Self::new_with_rand_and_hasher(info, StdRng::from_entropy(), RandomState::new()) + Self::new_with_rand_and_hasher(info, StdRng::from_os_rng(), RandomState::new()) } } @@ -216,7 +216,7 @@ impl BucketRateLimiter { let n = self.map.shards().len(); // this lock is ok as the periodic cycle of do_gc makes this very unlikely to collide // (impossible, infact, unless we have 2048 threads) - let shard = self.rand.lock_propagate_poison().gen_range(0..n); + let shard = self.rand.lock_propagate_poison().random_range(0..n); self.map.shards()[shard].write().clear(); } } diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index a6d376562b..88d5550fff 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -10,6 +10,7 @@ use super::connection_with_credentials_provider::ConnectionWithCredentialsProvid use crate::cache::project_info::ProjectInfoCache; use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; use crate::metrics::{Metrics, RedisErrors, RedisEventsCount}; +use crate::util::deserialize_json_string; const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20); @@ -121,15 +122,6 @@ struct InvalidateRole { role_name: RoleNameInt, } -fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result -where - T: for<'de2> serde::Deserialize<'de2>, - D: serde::Deserializer<'de>, -{ - let s = String::deserialize(deserializer)?; - serde_json::from_str(&s).map_err(::custom) -} - // https://github.com/serde-rs/serde/issues/1714 fn deserialize_unknown_topic<'de, D>(deserializer: D) -> Result<(), D::Error> where diff --git a/proxy/src/scram/countmin.rs b/proxy/src/scram/countmin.rs index 9d56c465ec..d64895f8f5 100644 --- a/proxy/src/scram/countmin.rs +++ b/proxy/src/scram/countmin.rs @@ -86,11 +86,11 @@ mod tests { for _ in 0..n { // number to insert at once - let n = rng.gen_range(1..4096); + let n = rng.random_range(1..4096); // number of insert operations - let m = rng.gen_range(1..100); + let m = rng.random_range(1..100); - let id = uuid::Builder::from_random_bytes(rng.r#gen()).into_uuid(); + let id = uuid::Builder::from_random_bytes(rng.random()).into_uuid(); ids.push((id, n, m)); // N = sum(actual) @@ -140,8 +140,8 @@ mod tests { // probably numbers are too small to truly represent the probabilities. 
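(Editorial note: the `deserialize_json_string` helper removed from notifications.rs above reappears in proxy/src/util.rs at the end of this patch and is reused by rest.rs. A self-contained sketch of what it does, reconstructed from the removed lines; the exact trait bounds and error-mapping call may differ.)

```rust
use serde::Deserialize;

// Deserializes a field that arrives as a JSON *string* containing JSON.
pub fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
    T: serde::de::DeserializeOwned,
    D: serde::Deserializer<'de>,
{
    let s = String::deserialize(deserializer)?;
    serde_json::from_str(&s).map_err(serde::de::Error::custom)
}

// Usage in the style of rest.rs: `config` arrives as "{\"db_schemas\": \"public\"}".
#[derive(Deserialize)]
struct ConfigRow {
    #[serde(deserialize_with = "deserialize_json_string")]
    config: serde_json::Value,
}
```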
assert_eq!(eval_precision(100, 4096.0, 0.90), 100); assert_eq!(eval_precision(1000, 4096.0, 0.90), 1000); - assert_eq!(eval_precision(100, 4096.0, 0.1), 96); - assert_eq!(eval_precision(1000, 4096.0, 0.1), 988); + assert_eq!(eval_precision(100, 4096.0, 0.1), 100); + assert_eq!(eval_precision(1000, 4096.0, 0.1), 978); } // returns memory usage in bytes, and the time complexity per insert. diff --git a/proxy/src/scram/threadpool.rs b/proxy/src/scram/threadpool.rs index 1aa402227f..ea2e29ede9 100644 --- a/proxy/src/scram/threadpool.rs +++ b/proxy/src/scram/threadpool.rs @@ -51,7 +51,7 @@ impl ThreadPool { *state = Some(ThreadRt { pool: pool.clone(), id: ThreadPoolWorkerId(worker_id.fetch_add(1, Ordering::Relaxed)), - rng: SmallRng::from_entropy(), + rng: SmallRng::from_os_rng(), // used to determine whether we should temporarily skip tasks for fairness. // 99% of estimates will overcount by no more than 4096 samples countmin: CountMinSketch::with_params( @@ -120,7 +120,7 @@ impl ThreadRt { // in which case the SKETCH_RESET_INTERVAL represents 1 second. Thus, the rates above // are in requests per second. let probability = P.ln() / (P + rate as f64).ln(); - self.rng.gen_bool(probability) + self.rng.random_bool(probability) } } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index daa6429039..59e4b09bc9 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -8,7 +8,7 @@ use ed25519_dalek::SigningKey; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use jose_jwk::jose_b64; use postgres_client::config::SslMode; -use rand::rngs::OsRng; +use rand_core::OsRng; use rustls::pki_types::{DnsName, ServerName}; use tokio::net::{TcpStream, lookup_host}; use tokio_rustls::TlsConnector; diff --git a/proxy/src/serverless/cancel_set.rs b/proxy/src/serverless/cancel_set.rs index ba8945afc5..142dc3b3d5 100644 --- a/proxy/src/serverless/cancel_set.rs +++ b/proxy/src/serverless/cancel_set.rs @@ -6,7 +6,7 @@ use std::time::Duration; use indexmap::IndexMap; use parking_lot::Mutex; -use rand::{Rng, thread_rng}; +use rand::distr::uniform::{UniformSampler, UniformUsize}; use rustc_hash::FxHasher; use tokio::time::Instant; use tokio_util::sync::CancellationToken; @@ -39,8 +39,9 @@ impl CancelSet { } pub(crate) fn take(&self) -> Option { + let dist = UniformUsize::new_inclusive(0, usize::MAX).expect("valid bounds"); for _ in 0..4 { - if let Some(token) = self.take_raw(thread_rng().r#gen()) { + if let Some(token) = self.take_raw(dist.sample(&mut rand::rng())) { return Some(token); } tracing::trace!("failed to get cancel token"); @@ -48,7 +49,7 @@ impl CancelSet { None } - pub(crate) fn take_raw(&self, rng: usize) -> Option { + fn take_raw(&self, rng: usize) -> Option { NonZeroUsize::new(self.shards.len()) .and_then(|len| self.shards[rng % len].lock().take(rng / len)) } diff --git a/proxy/src/serverless/conn_pool_lib.rs b/proxy/src/serverless/conn_pool_lib.rs index 42a3ea17a2..ed5cc0ea03 100644 --- a/proxy/src/serverless/conn_pool_lib.rs +++ b/proxy/src/serverless/conn_pool_lib.rs @@ -428,7 +428,7 @@ where loop { interval.tick().await; - let shard = rng.gen_range(0..self.global_pool.shards().len()); + let shard = rng.random_range(0..self.global_pool.shards().len()); self.gc(shard); } } diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 5b7289c53d..13f9ee2782 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -11,6 +11,8 @@ mod http_conn_pool; mod http_util; mod json; mod local_conn_pool; 
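(Editorial note: the `pub mod rest` and routing change that follow gate REST handling on the request path having the shape `/<database>/rest/...`. A tiny sketch of that path test; the helper name is invented for illustration.)

```rust
fn looks_like_rest_path(path: &str) -> bool {
    // "/mydb/rest/v1/items" splits into ["", "mydb", "rest", "v1", "items"],
    // so segment index 2 must start with "rest" to route to the REST broker.
    path.split('/').nth(2).is_some_and(|seg| seg.starts_with("rest"))
}

#[test]
fn rest_path_shape() {
    assert!(looks_like_rest_path("/mydb/rest/v1/items"));
    assert!(!looks_like_rest_path("/sql"));
}
```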
+#[cfg(feature = "rest_broker")] +pub mod rest; mod sql_over_http; mod websocket; @@ -75,7 +77,7 @@ pub async fn task_main( { let conn_pool = Arc::clone(&conn_pool); tokio::spawn(async move { - conn_pool.gc_worker(StdRng::from_entropy()).await; + conn_pool.gc_worker(StdRng::from_os_rng()).await; }); } @@ -95,7 +97,7 @@ pub async fn task_main( { let http_conn_pool = Arc::clone(&http_conn_pool); tokio::spawn(async move { - http_conn_pool.gc_worker(StdRng::from_entropy()).await; + http_conn_pool.gc_worker(StdRng::from_os_rng()).await; }); } @@ -487,6 +489,42 @@ async fn request_handler( .body(Empty::new().map_err(|x| match x {}).boxed()) .map_err(|e| ApiError::InternalServerError(e.into())) } else { - json_response(StatusCode::BAD_REQUEST, "query is not supported") + #[cfg(feature = "rest_broker")] + { + if config.rest_config.is_rest_broker + // we are testing for the path to be /database_name/rest/... + && request + .uri() + .path() + .split('/') + .nth(2) + .is_some_and(|part| part.starts_with("rest")) + { + let ctx = + RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http); + let span = ctx.span(); + + let testodrome_id = request + .headers() + .get("X-Neon-Query-ID") + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + + if let Some(query_id) = testodrome_id { + info!(parent: &span, "testodrome query ID: {query_id}"); + ctx.set_testodrome_id(query_id.into()); + } + + rest::handle(config, ctx, request, backend, http_cancellation_token) + .instrument(span) + .await + } else { + json_response(StatusCode::BAD_REQUEST, "query is not supported") + } + } + #[cfg(not(feature = "rest_broker"))] + { + json_response(StatusCode::BAD_REQUEST, "query is not supported") + } } } diff --git a/proxy/src/serverless/rest.rs b/proxy/src/serverless/rest.rs new file mode 100644 index 0000000000..173c2629f7 --- /dev/null +++ b/proxy/src/serverless/rest.rs @@ -0,0 +1,1165 @@ +use std::borrow::Cow; +use std::collections::HashMap; +use std::sync::Arc; + +use bytes::Bytes; +use http::Method; +use http::header::{AUTHORIZATION, CONTENT_TYPE, HOST}; +use http_body_util::combinators::BoxBody; +use http_body_util::{BodyExt, Full}; +use http_utils::error::ApiError; +use hyper::body::Incoming; +use hyper::http::{HeaderName, HeaderValue}; +use hyper::{Request, Response, StatusCode}; +use indexmap::IndexMap; +use ouroboros::self_referencing; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Deserializer}; +use serde_json::Value as JsonValue; +use serde_json::value::RawValue; +use subzero_core::api::ContentType::{ApplicationJSON, Other, SingularJSON, TextCSV}; +use subzero_core::api::QueryNode::{Delete, FunctionCall, Insert, Update}; +use subzero_core::api::Resolution::{IgnoreDuplicates, MergeDuplicates}; +use subzero_core::api::{ApiResponse, ListVal, Payload, Preferences, Representation, SingleVal}; +use subzero_core::config::{db_allowed_select_functions, db_schemas, role_claim_key}; +use subzero_core::dynamic_statement::{JoinIterator, param, sql}; +use subzero_core::error::Error::{ + self as SubzeroCoreError, ContentTypeError, GucHeadersError, GucStatusError, InternalError, + JsonDeserialize, JwtTokenInvalid, NotFound, +}; +use subzero_core::error::pg_error_to_status_code; +use subzero_core::formatter::Param::{LV, PL, SV, Str, StrOwned}; +use subzero_core::formatter::postgresql::{fmt_main_query, generate}; +use subzero_core::formatter::{Param, Snippet, SqlParam}; +use subzero_core::parser::postgrest::parse; +use subzero_core::permissions::{check_safe_functions, 
replace_select_star}; +use subzero_core::schema::{ + DbSchema, POSTGRESQL_INTROSPECTION_SQL, get_postgresql_configuration_query, +}; +use subzero_core::{content_range_header, content_range_status}; +use tokio_util::sync::CancellationToken; +use tracing::{error, info}; +use typed_json::json; +use url::form_urlencoded; + +use super::backend::{HttpConnError, LocalProxyConnError, PoolingBackend}; +use super::conn_pool::AuthData; +use super::conn_pool_lib::ConnInfo; +use super::error::{ConnInfoError, Credentials, HttpCodeError, ReadPayloadError}; +use super::http_conn_pool::{self, Send}; +use super::http_util::{ + ALLOW_POOL, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_ISOLATION_LEVEL, TXN_READ_ONLY, + get_conn_info, json_response, uuid_to_header_value, +}; +use super::json::JsonConversionError; +use crate::auth::backend::ComputeCredentialKeys; +use crate::cache::{Cached, TimedLru}; +use crate::config::ProxyConfig; +use crate::context::RequestContext; +use crate::error::{ErrorKind, ReportableError, UserFacingError}; +use crate::http::read_body_with_limit; +use crate::metrics::Metrics; +use crate::serverless::sql_over_http::HEADER_VALUE_TRUE; +use crate::types::EndpointCacheKey; +use crate::util::deserialize_json_string; + +static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#; +const INTROSPECTION_SQL: &str = POSTGRESQL_INTROSPECTION_SQL; + +// A wrapper around the DbSchema that allows for self-referencing +#[self_referencing] +pub struct DbSchemaOwned { + schema_string: String, + #[covariant] + #[borrows(schema_string)] + schema: DbSchema<'this>, +} + +impl<'de> Deserialize<'de> for DbSchemaOwned { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + DbSchemaOwned::try_new(s, |s| serde_json::from_str(s)) + .map_err(::custom) + } +} + +fn split_comma_separated(s: &str) -> Vec { + s.split(',').map(|s| s.trim().to_string()).collect() +} + +fn deserialize_comma_separated<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let s = String::deserialize(deserializer)?; + Ok(split_comma_separated(&s)) +} + +fn deserialize_comma_separated_option<'de, D>( + deserializer: D, +) -> Result>, D::Error> +where + D: Deserializer<'de>, +{ + let opt = Option::::deserialize(deserializer)?; + if let Some(s) = &opt { + let trimmed = s.trim(); + if trimmed.is_empty() { + return Ok(None); + } + return Ok(Some(split_comma_separated(trimmed))); + } + Ok(None) +} + +// The ApiConfig is the configuration for the API per endpoint +// The configuration is read from the database and cached in the DbSchemaCache +#[derive(Deserialize, Debug)] +pub struct ApiConfig { + #[serde( + default = "db_schemas", + deserialize_with = "deserialize_comma_separated" + )] + pub db_schemas: Vec, + pub db_anon_role: Option, + pub db_max_rows: Option, + #[serde(default = "db_allowed_select_functions")] + pub db_allowed_select_functions: Vec, + // #[serde(deserialize_with = "to_tuple", default)] + // pub db_pre_request: Option<(String, String)>, + #[allow(dead_code)] + #[serde(default = "role_claim_key")] + pub role_claim_key: String, + #[serde(default, deserialize_with = "deserialize_comma_separated_option")] + pub db_extra_search_path: Option>, +} + +// The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint +pub(crate) type DbSchemaCache = TimedLru>; +impl DbSchemaCache { + pub async fn get_cached_or_remote( + &self, + endpoint_id: &EndpointCacheKey, + auth_header: &HeaderValue, + connection_string: &str, + 
client: &mut http_conn_pool::Client, + ctx: &RequestContext, + config: &'static ProxyConfig, + ) -> Result, RestError> { + match self.get_with_created_at(endpoint_id) { + Some(Cached { value: (v, _), .. }) => Ok(v), + None => { + info!("db_schema cache miss for endpoint: {:?}", endpoint_id); + let remote_value = self + .get_remote(auth_header, connection_string, client, ctx, config) + .await; + let (api_config, schema_owned) = match remote_value { + Ok((api_config, schema_owned)) => (api_config, schema_owned), + Err(e @ RestError::SchemaTooLarge) => { + // for the case where the schema is too large, we cache an empty dummy value + // all the other requests will fail without triggering the introspection query + let schema_owned = serde_json::from_str::(EMPTY_JSON_SCHEMA) + .map_err(|e| JsonDeserialize { source: e })?; + + let api_config = ApiConfig { + db_schemas: vec![], + db_anon_role: None, + db_max_rows: None, + db_allowed_select_functions: vec![], + role_claim_key: String::new(), + db_extra_search_path: None, + }; + let value = Arc::new((api_config, schema_owned)); + self.insert(endpoint_id.clone(), value); + return Err(e); + } + Err(e) => { + return Err(e); + } + }; + let value = Arc::new((api_config, schema_owned)); + self.insert(endpoint_id.clone(), value.clone()); + Ok(value) + } + } + } + pub async fn get_remote( + &self, + auth_header: &HeaderValue, + connection_string: &str, + client: &mut http_conn_pool::Client, + ctx: &RequestContext, + config: &'static ProxyConfig, + ) -> Result<(ApiConfig, DbSchemaOwned), RestError> { + #[derive(Deserialize)] + struct SingleRow { + rows: [Row; 1], + } + + #[derive(Deserialize)] + struct ConfigRow { + #[serde(deserialize_with = "deserialize_json_string")] + config: ApiConfig, + } + + #[derive(Deserialize)] + struct SchemaRow { + json_schema: DbSchemaOwned, + } + + let headers = vec![ + (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())), + ( + &CONN_STRING, + HeaderValue::from_str(connection_string).expect( + "connection string came from a header, so it must be a valid headervalue", + ), + ), + (&AUTHORIZATION, auth_header.clone()), + (&RAW_TEXT_OUTPUT, HEADER_VALUE_TRUE), + ]; + + let query = get_postgresql_configuration_query(Some("pgrst.pre_config")); + let SingleRow { + rows: [ConfigRow { config: api_config }], + } = make_local_proxy_request( + client, + headers.iter().cloned(), + QueryData { + query: Cow::Owned(query), + params: vec![], + }, + config.rest_config.max_schema_size, + ) + .await + .map_err(|e| match e { + RestError::ReadPayload(ReadPayloadError::BodyTooLarge { .. }) => { + RestError::SchemaTooLarge + } + e => e, + })?; + + // now that we have the api_config let's run the second INTROSPECTION_SQL query + let SingleRow { + rows: [SchemaRow { json_schema }], + } = make_local_proxy_request( + client, + headers, + QueryData { + query: INTROSPECTION_SQL.into(), + params: vec![ + serde_json::to_value(&api_config.db_schemas) + .expect("Vec is always valid to encode as JSON"), + JsonValue::Bool(false), // include_roles_with_login + JsonValue::Bool(false), // use_internal_permissions + ], + }, + config.rest_config.max_schema_size, + ) + .await + .map_err(|e| match e { + RestError::ReadPayload(ReadPayloadError::BodyTooLarge { .. 
}) => { + RestError::SchemaTooLarge + } + e => e, + })?; + + Ok((api_config, json_schema)) + } +} + +// A type to represent a postgresql errors +// we use our own type (instead of postgres_client::Error) because we get the error from the json response +#[derive(Debug, thiserror::Error, Deserialize)] +pub(crate) struct PostgresError { + pub code: String, + pub message: String, + pub detail: Option, + pub hint: Option, +} +impl HttpCodeError for PostgresError { + fn get_http_status_code(&self) -> StatusCode { + let status = pg_error_to_status_code(&self.code, true); + StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) + } +} +impl ReportableError for PostgresError { + fn get_error_kind(&self) -> ErrorKind { + ErrorKind::User + } +} +impl UserFacingError for PostgresError { + fn to_string_client(&self) -> String { + if self.code.starts_with("PT") { + "Postgres error".to_string() + } else { + self.message.clone() + } + } +} +impl std::fmt::Display for PostgresError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.message) + } +} + +// A type to represent errors that can occur in the rest broker +#[derive(Debug, thiserror::Error)] +pub(crate) enum RestError { + #[error(transparent)] + ReadPayload(#[from] ReadPayloadError), + #[error(transparent)] + ConnectCompute(#[from] HttpConnError), + #[error(transparent)] + ConnInfo(#[from] ConnInfoError), + #[error(transparent)] + Postgres(#[from] PostgresError), + #[error(transparent)] + JsonConversion(#[from] JsonConversionError), + #[error(transparent)] + SubzeroCore(#[from] SubzeroCoreError), + #[error("schema is too large")] + SchemaTooLarge, +} +impl ReportableError for RestError { + fn get_error_kind(&self) -> ErrorKind { + match self { + RestError::ReadPayload(e) => e.get_error_kind(), + RestError::ConnectCompute(e) => e.get_error_kind(), + RestError::ConnInfo(e) => e.get_error_kind(), + RestError::Postgres(_) => ErrorKind::Postgres, + RestError::JsonConversion(_) => ErrorKind::Postgres, + RestError::SubzeroCore(_) => ErrorKind::User, + RestError::SchemaTooLarge => ErrorKind::User, + } + } +} +impl UserFacingError for RestError { + fn to_string_client(&self) -> String { + match self { + RestError::ReadPayload(p) => p.to_string(), + RestError::ConnectCompute(c) => c.to_string_client(), + RestError::ConnInfo(c) => c.to_string_client(), + RestError::SchemaTooLarge => self.to_string(), + RestError::Postgres(p) => p.to_string_client(), + RestError::JsonConversion(_) => "could not parse postgres response".to_string(), + RestError::SubzeroCore(s) => { + // TODO: this is a hack to get the message from the json body + let json = s.json_body(); + let default_message = "Unknown error".to_string(); + + json.get("message") + .map_or(default_message.clone(), |m| match m { + JsonValue::String(s) => s.clone(), + _ => default_message, + }) + } + } + } +} +impl HttpCodeError for RestError { + fn get_http_status_code(&self) -> StatusCode { + match self { + RestError::ReadPayload(e) => e.get_http_status_code(), + RestError::ConnectCompute(h) => match h.get_error_kind() { + ErrorKind::User => StatusCode::BAD_REQUEST, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }, + RestError::ConnInfo(_) => StatusCode::BAD_REQUEST, + RestError::Postgres(e) => e.get_http_status_code(), + RestError::JsonConversion(_) => StatusCode::INTERNAL_SERVER_ERROR, + RestError::SchemaTooLarge => StatusCode::INTERNAL_SERVER_ERROR, + RestError::SubzeroCore(e) => { + let status = e.status_code(); + 
StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) + } + } + } +} + +// Helper functions for the rest broker + +fn fmt_env_query<'a>(env: &'a HashMap<&'a str, &'a str>) -> Snippet<'a> { + "select " + + if env.is_empty() { + sql("null") + } else { + env.iter() + .map(|(k, v)| { + "set_config(" + param(k as &SqlParam) + ", " + param(v as &SqlParam) + ", true)" + }) + .join(",") + } +} + +// TODO: see about removing the need for cloning the values (inner things are &Cow already) +fn to_sql_param(p: &Param) -> JsonValue { + match p { + SV(SingleVal(v, ..)) => JsonValue::String(v.to_string()), + Str(v) => JsonValue::String((*v).to_string()), + StrOwned(v) => JsonValue::String((*v).clone()), + PL(Payload(v, ..)) => JsonValue::String(v.clone().into_owned()), + LV(ListVal(v, ..)) => { + if v.is_empty() { + JsonValue::String(r"{}".to_string()) + } else { + JsonValue::String(format!( + "{{\"{}\"}}", + v.iter() + .map(|e| e.replace('\\', "\\\\").replace('\"', "\\\"")) + .collect::>() + .join("\",\"") + )) + } + } + } +} + +#[derive(serde::Serialize)] +struct QueryData<'a> { + query: Cow<'a, str>, + params: Vec, +} + +#[derive(serde::Serialize)] +struct BatchQueryData<'a> { + queries: Vec>, +} + +async fn make_local_proxy_request( + client: &mut http_conn_pool::Client, + headers: impl IntoIterator, + body: QueryData<'_>, + max_len: usize, +) -> Result { + let body_string = serde_json::to_string(&body) + .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?; + + let response = make_raw_local_proxy_request(client, headers, body_string).await?; + + let response_status = response.status(); + + if response_status != StatusCode::OK { + return Err(RestError::SubzeroCore(InternalError { + message: "Failed to get endpoint schema".to_string(), + })); + } + + // Capture the response body + let response_body = crate::http::read_body_with_limit(response.into_body(), max_len) + .await + .map_err(ReadPayloadError::from)?; + + // Parse the JSON response + let response_json: S = serde_json::from_slice(&response_body) + .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?; + + Ok(response_json) +} + +async fn make_raw_local_proxy_request( + client: &mut http_conn_pool::Client, + headers: impl IntoIterator, + body: String, +) -> Result, RestError> { + let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql"); + let mut req = Request::builder().method(Method::POST).uri(local_proxy_uri); + let req_headers = req.headers_mut().expect("failed to get headers"); + // Add all provided headers to the request + for (header_name, header_value) in headers { + req_headers.insert(header_name, header_value.clone()); + } + + let body_boxed = Full::new(Bytes::from(body)) + .map_err(|never| match never {}) // Convert Infallible to hyper::Error + .boxed(); + + let req = req.body(body_boxed).map_err(|_| { + RestError::SubzeroCore(InternalError { + message: "Failed to build request".to_string(), + }) + })?; + + // Send the request to the local proxy + client + .inner + .inner + .send_request(req) + .await + .map_err(LocalProxyConnError::from) + .map_err(HttpConnError::from) + .map_err(RestError::from) +} + +pub(crate) async fn handle( + config: &'static ProxyConfig, + ctx: RequestContext, + request: Request, + backend: Arc, + cancel: CancellationToken, +) -> Result>, ApiError> { + let result = handle_inner(cancel, config, &ctx, request, backend).await; + + let mut response = match result { + Ok(r) => { + ctx.set_success(); + + // Handling the error response 
from local proxy here + if r.status().is_server_error() { + let status = r.status(); + + let body_bytes = r + .collect() + .await + .map_err(|e| { + ApiError::InternalServerError(anyhow::Error::msg(format!( + "could not collect http body: {e}" + ))) + })? + .to_bytes(); + + if let Ok(mut json_map) = + serde_json::from_slice::>(&body_bytes) + { + let message = json_map.get("message"); + if let Some(message) = message { + let msg: String = match serde_json::from_str(message.get()) { + Ok(msg) => msg, + Err(_) => { + "Unable to parse the response message from server".to_string() + } + }; + + error!("Error response from local_proxy: {status} {msg}"); + + json_map.retain(|key, _| !key.starts_with("neon:")); // remove all the neon-related keys + + let resp_json = serde_json::to_string(&json_map) + .unwrap_or("failed to serialize the response message".to_string()); + + return json_response(status, resp_json); + } + } + + error!("Unable to parse the response message from local_proxy"); + return json_response( + status, + json!({ "message": "Unable to parse the response message from server".to_string() }), + ); + } + r + } + Err(e @ RestError::SubzeroCore(_)) => { + let error_kind = e.get_error_kind(); + ctx.set_error_kind(error_kind); + + tracing::info!( + kind=error_kind.to_metric_label(), + error=%e, + msg="subzero core error", + "forwarding error to user" + ); + + let RestError::SubzeroCore(subzero_err) = e else { + panic!("expected subzero core error") + }; + + let json_body = subzero_err.json_body(); + let status_code = StatusCode::from_u16(subzero_err.status_code()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + + json_response(status_code, json_body)? + } + Err(e) => { + let error_kind = e.get_error_kind(); + ctx.set_error_kind(error_kind); + + let message = e.to_string_client(); + let status_code = e.get_http_status_code(); + + tracing::info!( + kind=error_kind.to_metric_label(), + error=%e, + msg=message, + "forwarding error to user" + ); + + let (code, detail, hint) = match e { + RestError::Postgres(e) => ( + if e.code.starts_with("PT") { + None + } else { + Some(e.code) + }, + e.detail, + e.hint, + ), + _ => (None, None, None), + }; + + json_response( + status_code, + json!({ + "message": message, + "code": code, + "detail": detail, + "hint": hint, + }), + )? + } + }; + + response + .headers_mut() + .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*")); + Ok(response) +} + +async fn handle_inner( + _cancel: CancellationToken, + config: &'static ProxyConfig, + ctx: &RequestContext, + request: Request, + backend: Arc, +) -> Result>, RestError> { + let _requeset_gauge = Metrics::get() + .proxy + .connection_requests + .guard(ctx.protocol()); + info!( + protocol = %ctx.protocol(), + "handling interactive connection from client" + ); + + // Read host from Host, then URI host as fallback + // TODO: will this be a problem if behind a load balancer? + // TODO: can we use the x-forwarded-host header? + let host = request + .headers() + .get(HOST) + .and_then(|v| v.to_str().ok()) + .unwrap_or_else(|| request.uri().host().unwrap_or("")); + + // a valid path is /database/rest/v1/... so splitting should be ["", "database", "rest", "v1", ...] 
+ let database_name = request + .uri() + .path() + .split('/') + .nth(1) + .ok_or(RestError::SubzeroCore(NotFound { + target: request.uri().path().to_string(), + }))?; + + // we always use the authenticator role to connect to the database + let authenticator_role = "authenticator"; + + // Strip the hostname prefix from the host to get the database hostname + let database_host = host.replace(&config.rest_config.hostname_prefix, ""); + + let connection_string = + format!("postgresql://{authenticator_role}@{database_host}/{database_name}"); + + let conn_info = get_conn_info( + &config.authentication_config, + ctx, + Some(&connection_string), + request.headers(), + )?; + info!( + user = conn_info.conn_info.user_info.user.as_str(), + "credentials" + ); + + match conn_info.auth { + AuthData::Jwt(jwt) => { + let api_prefix = format!("/{database_name}/rest/v1/"); + handle_rest_inner( + config, + ctx, + &api_prefix, + request, + &connection_string, + conn_info.conn_info, + jwt, + backend, + ) + .await + } + AuthData::Password(_) => Err(RestError::ConnInfo(ConnInfoError::MissingCredentials( + Credentials::BearerJwt, + ))), + } +} + +#[allow(clippy::too_many_arguments)] +async fn handle_rest_inner( + config: &'static ProxyConfig, + ctx: &RequestContext, + api_prefix: &str, + request: Request, + connection_string: &str, + conn_info: ConnInfo, + jwt: String, + backend: Arc, +) -> Result>, RestError> { + // validate the jwt token + let jwt_parsed = backend + .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) + .await + .map_err(HttpConnError::from)?; + + let db_schema_cache = + config + .rest_config + .db_schema_cache + .as_ref() + .ok_or(RestError::SubzeroCore(InternalError { + message: "DB schema cache is not configured".to_string(), + }))?; + + let endpoint_cache_key = conn_info + .endpoint_cache_key() + .ok_or(RestError::SubzeroCore(InternalError { + message: "Failed to get endpoint cache key".to_string(), + }))?; + + let mut client = backend.connect_to_local_proxy(ctx, conn_info).await?; + + let (parts, originial_body) = request.into_parts(); + + let auth_header = parts + .headers + .get(AUTHORIZATION) + .ok_or(RestError::SubzeroCore(InternalError { + message: "Authorization header is required".to_string(), + }))?; + + let entry = db_schema_cache + .get_cached_or_remote( + &endpoint_cache_key, + auth_header, + connection_string, + &mut client, + ctx, + config, + ) + .await?; + let (api_config, db_schema_owned) = entry.as_ref(); + let db_schema = db_schema_owned.borrow_schema(); + + let db_schemas = &api_config.db_schemas; // list of schemas available for the api + let db_extra_search_path = &api_config.db_extra_search_path; + // TODO: use this when we get a replacement for jsonpath_lib + // let role_claim_key = &api_config.role_claim_key; + // let role_claim_path = format!("${role_claim_key}"); + let db_anon_role = &api_config.db_anon_role; + let max_rows = api_config.db_max_rows.as_deref(); + let db_allowed_select_functions = api_config + .db_allowed_select_functions + .iter() + .map(|s| s.as_str()) + .collect::>(); + + // extract the jwt claims (we'll need them later to set the role and env) + let jwt_claims = match jwt_parsed.keys { + ComputeCredentialKeys::JwtPayload(payload_bytes) => { + // `payload_bytes` contains the raw JWT payload as Vec + // You can deserialize it back to JSON or parse specific claims + let payload: serde_json::Value = serde_json::from_slice(&payload_bytes) + .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?; + Some(payload) + } + 
ComputeCredentialKeys::AuthKeys(_) => None, + }; + + // read the role from the jwt claims (and set it to the "anon" role if not present) + let (role, authenticated) = match &jwt_claims { + Some(claims) => match claims.get("role") { + Some(JsonValue::String(r)) => (Some(r), true), + _ => (db_anon_role.as_ref(), true), + }, + None => (db_anon_role.as_ref(), false), + }; + + // do not allow unauthenticated requests when there is no anonymous role setup + if let (None, false) = (role, authenticated) { + return Err(RestError::SubzeroCore(JwtTokenInvalid { + message: "unauthenticated requests not allowed".to_string(), + })); + } + + // start deconstructing the request because subzero core mostly works with &str + let method = parts.method; + let method_str = method.as_str(); + let path = parts.uri.path_and_query().map_or("/", |pq| pq.as_str()); + + // this is actually the table name (or rpc/function_name) + // TODO: rename this to something more descriptive + let root = match parts.uri.path().strip_prefix(api_prefix) { + Some(p) => Ok(p), + None => Err(RestError::SubzeroCore(NotFound { + target: parts.uri.path().to_string(), + })), + }?; + + // pick the current schema from the headers (or the first one from config) + let schema_name = &DbSchema::pick_current_schema(db_schemas, method_str, &parts.headers)?; + + // add the content-profile header to the response + let mut response_headers = vec![]; + if db_schemas.len() > 1 { + response_headers.push(("Content-Profile".to_string(), schema_name.clone())); + } + + // parse the query string into a Vec<(&str, &str)> + let query = match parts.uri.query() { + Some(q) => form_urlencoded::parse(q.as_bytes()).collect(), + None => vec![], + }; + let get: Vec<(&str, &str)> = query.iter().map(|(k, v)| (&**k, &**v)).collect(); + + // convert the headers map to a HashMap<&str, &str> + let headers: HashMap<&str, &str> = parts + .headers + .iter() + .map(|(k, v)| (k.as_str(), v.to_str().unwrap_or("__BAD_HEADER__"))) + .collect(); + + let cookies = HashMap::new(); // TODO: add cookies + + // Read the request body (skip for GET requests) + let body_as_string: Option = if method == Method::GET { + None + } else { + let body_bytes = + read_body_with_limit(originial_body, config.http_config.max_request_size_bytes) + .await + .map_err(ReadPayloadError::from)?; + if body_bytes.is_empty() { + None + } else { + Some(String::from_utf8_lossy(&body_bytes).into_owned()) + } + }; + + // parse the request into an ApiRequest struct + let mut api_request = parse( + schema_name, + root, + db_schema, + method_str, + path, + get, + body_as_string.as_deref(), + headers, + cookies, + max_rows, + ) + .map_err(RestError::SubzeroCore)?; + + let role_str = match role { + Some(r) => r, + None => "", + }; + + replace_select_star(db_schema, schema_name, role_str, &mut api_request.query)?; + + // TODO: this is not relevant when acting as PostgREST but will be useful + // in the context of DBX where they need internal permissions + // if !disable_internal_permissions { + // check_privileges(db_schema, schema_name, role_str, &api_request)?; + // } + + check_safe_functions(&api_request, &db_allowed_select_functions)?; + + // TODO: this is not relevant when acting as PostgREST but will be useful + // in the context of DBX where they need internal permissions + // if !disable_internal_permissions { + // insert_policy_conditions(db_schema, schema_name, role_str, &mut api_request.query)?; + // } + + let env_role = Some(role_str); + + // construct the env (passed in to the sql context as GUCs) + let 
empty_json = "{}".to_string(); + let headers_env = serde_json::to_string(&api_request.headers).unwrap_or(empty_json.clone()); + let cookies_env = serde_json::to_string(&api_request.cookies).unwrap_or(empty_json.clone()); + let get_env = serde_json::to_string(&api_request.get).unwrap_or(empty_json.clone()); + let jwt_claims_env = jwt_claims + .as_ref() + .map(|v| serde_json::to_string(v).unwrap_or(empty_json.clone())) + .unwrap_or(if let Some(r) = env_role { + let claims: HashMap<&str, &str> = HashMap::from([("role", r)]); + serde_json::to_string(&claims).unwrap_or(empty_json.clone()) + } else { + empty_json.clone() + }); + let mut search_path = vec![api_request.schema_name]; + if let Some(extra) = &db_extra_search_path { + search_path.extend(extra.iter().map(|s| s.as_str())); + } + let search_path_str = search_path + .into_iter() + .filter(|s| !s.is_empty()) + .collect::>() + .join(","); + let mut env: HashMap<&str, &str> = HashMap::from([ + ("request.method", api_request.method), + ("request.path", api_request.path), + ("request.headers", &headers_env), + ("request.cookies", &cookies_env), + ("request.get", &get_env), + ("request.jwt.claims", &jwt_claims_env), + ("search_path", &search_path_str), + ]); + if let Some(r) = env_role { + env.insert("role", r); + } + + // generate the sql statements + let (env_statement, env_parameters, _) = generate(fmt_env_query(&env)); + let (main_statement, main_parameters, _) = generate(fmt_main_query( + db_schema, + api_request.schema_name, + &api_request, + &env, + )?); + + let mut headers = vec![ + (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())), + ( + &CONN_STRING, + HeaderValue::from_str(connection_string).expect("invalid connection string"), + ), + (&AUTHORIZATION, auth_header.clone()), + ( + &TXN_ISOLATION_LEVEL, + HeaderValue::from_static("ReadCommitted"), + ), + (&ALLOW_POOL, HEADER_VALUE_TRUE), + ]; + + if api_request.read_only { + headers.push((&TXN_READ_ONLY, HEADER_VALUE_TRUE)); + } + + // convert the parameters from subzero core representation to the local proxy repr. + let req_body = serde_json::to_string(&BatchQueryData { + queries: vec![ + QueryData { + query: env_statement.into(), + params: env_parameters + .iter() + .map(|p| to_sql_param(&p.to_param())) + .collect(), + }, + QueryData { + query: main_statement.into(), + params: main_parameters + .iter() + .map(|p| to_sql_param(&p.to_param())) + .collect(), + }, + ], + }) + .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?; + + // todo: map body to count egress + let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly? + + // send the request to the local proxy + let response = make_raw_local_proxy_request(&mut client, headers, req_body).await?; + let (parts, body) = response.into_parts(); + + let max_response = config.http_config.max_response_size_bytes; + let bytes = read_body_with_limit(body, max_response) + .await + .map_err(ReadPayloadError::from)?; + + // if the response status is greater than 399, then it is an error + // FIXME: check if there are other error codes or shapes of the response + if parts.status.as_u16() > 399 { + // turn this postgres error from the json into PostgresError + let postgres_error = serde_json::from_slice(&bytes) + .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?; + + return Err(RestError::Postgres(postgres_error)); + } + + #[derive(Deserialize)] + struct QueryResults { + /// we run two queries, so we want only two results. 
+    #[derive(Deserialize)]
+    struct QueryResults {
+        /// we run two queries, so we want only two results.
+        results: (EnvRows, MainRows),
+    }
+
+    /// `env_statement` returns nothing of interest to us
+    #[derive(Deserialize)]
+    struct EnvRows {}
+
+    #[derive(Deserialize)]
+    struct MainRows {
+        /// `main_statement` only returns a single row.
+        rows: [MainRow; 1],
+    }
+
+    #[derive(Deserialize)]
+    struct MainRow {
+        body: String,
+        page_total: Option<String>,
+        total_result_set: Option<String>,
+        response_headers: Option<String>,
+        response_status: Option<String>,
+    }
+
+    let results: QueryResults = serde_json::from_slice(&bytes)
+        .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?;
+
+    let QueryResults {
+        results: (_, MainRows { rows: [row] }),
+    } = results;
+
+    // build the intermediate response object
+    let api_response = ApiResponse {
+        page_total: row.page_total.map_or(0, |v| v.parse::<u64>().unwrap_or(0)),
+        total_result_set: row.total_result_set.map(|v| v.parse::<u64>().unwrap_or(0)),
+        top_level_offset: 0, // FIXME: check why this is 0
+        response_headers: row.response_headers,
+        response_status: row.response_status,
+        body: row.body,
+    };
+
+    // TODO: rollback the transaction if the page_total is not 1 and the accept_content_type is SingularJSON
+    // we cannot do this in the context of proxy for now
+    // if api_request.accept_content_type == SingularJSON && api_response.page_total != 1 {
+    //     // rollback the transaction here
+    //     return Err(RestError::SubzeroCore(SingularityError {
+    //         count: api_response.page_total,
+    //         content_type: "application/vnd.pgrst.object+json".to_string(),
+    //     }));
+    // }
+
+    // TODO: rollback the transaction if the page_total is not 1 and the method is PUT
+    // we cannot do this in the context of proxy for now
+    // if api_request.method == Method::PUT && api_response.page_total != 1 {
+    //     // Makes sure the querystring pk matches the payload pk
+    //     // e.g. PUT /items?id=eq.1 { "id" : 1, .. } is accepted,
+    //     //      PUT /items?id=eq.14 { "id" : 2, .. } is rejected.
+    //     // If this condition is not satisfied then nothing is inserted,
+    //     // rollback the transaction here
+    //     return Err(RestError::SubzeroCore(PutMatchingPkError));
+    // }
+
+    // create and return the response to the client
+    // this section mostly deals with setting the right headers according to PostgREST specs
+    let page_total = api_response.page_total;
+    let total_result_set = api_response.total_result_set;
+    let top_level_offset = api_response.top_level_offset;
+    let response_content_type = match (&api_request.accept_content_type, &api_request.query.node) {
+        (SingularJSON, _)
+        | (
+            _,
+            FunctionCall {
+                returns_single: true,
+                is_scalar: false,
+                ..
+            },
+        ) => SingularJSON,
+        (TextCSV, _) => TextCSV,
+        _ => ApplicationJSON,
+    };
+
+    // check if the SQL env set some response headers (happens when we called a rpc function)
+    if let Some(response_headers_str) = api_response.response_headers {
+        let Ok(headers_json) =
+            serde_json::from_str::<Vec<HashMap<String, String>>>(response_headers_str.as_str())
+        else {
+            return Err(RestError::SubzeroCore(GucHeadersError));
+        };
+
+        response_headers.extend(headers_json.into_iter().flatten());
+    }
+
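+    // Content-Range follows the PostgREST convention "<lower>-<upper>/<total>"; for example, an
+    // offset of 0 returning 5 rows out of 20 counted rows yields "0-4/20", and the total part is
+    // "*" when no exact count was requested.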
+    // calculate and set the content range header
+    let lower = top_level_offset as i64;
+    let upper = top_level_offset as i64 + page_total as i64 - 1;
+    let total = total_result_set.map(|t| t as i64);
+    let content_range = match (&method, &api_request.query.node) {
+        (&Method::POST, Insert { .. }) => content_range_header(1, 0, total),
+        (&Method::DELETE, Delete { .. }) => content_range_header(1, upper, total),
+        _ => content_range_header(lower, upper, total),
+    };
+    response_headers.push(("Content-Range".to_string(), content_range));
+
+    // calculate the status code
+    #[rustfmt::skip]
+    let mut status = match (&method, &api_request.query.node, page_total, &api_request.preferences) {
+        (&Method::POST, Insert { .. }, ..) => 201,
+        (&Method::DELETE, Delete { .. }, _, Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
+        (&Method::DELETE, Delete { .. }, ..) => 204,
+        (&Method::PATCH, Update { columns, .. }, 0, _) if !columns.is_empty() => 404,
+        (&Method::PATCH, Update { .. }, _,Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
+        (&Method::PATCH, Update { .. }, ..) => 204,
+        (&Method::PUT, Insert { .. },_,Some(Preferences {representation: Some(Representation::Full),..}),) => 200,
+        (&Method::PUT, Insert { .. }, ..) => 204,
+        _ => content_range_status(lower, upper, total),
+    };
+
+    // add the preference-applied header
+    if let Some(Preferences {
+        resolution: Some(r),
+        ..
+    }) = api_request.preferences
+    {
+        response_headers.push((
+            "Preference-Applied".to_string(),
+            match r {
+                MergeDuplicates => "resolution=merge-duplicates".to_string(),
+                IgnoreDuplicates => "resolution=ignore-duplicates".to_string(),
+            },
+        ));
+    }
+
+    // check if the SQL env set some response status (happens when we called a rpc function)
+    if let Some(response_status_str) = api_response.response_status {
+        status = response_status_str
+            .parse::<u16>()
+            .map_err(|_| RestError::SubzeroCore(GucStatusError))?;
+    }
+
+    // set the content type header
+    // TODO: move this to a subzero function
+    // as_header_value(&self) -> Option<&str>
+    let http_content_type = match response_content_type {
+        SingularJSON => Ok("application/vnd.pgrst.object+json"),
+        TextCSV => Ok("text/csv"),
+        ApplicationJSON => Ok("application/json"),
+        Other(t) => Err(RestError::SubzeroCore(ContentTypeError {
+            message: format!("None of these Content-Types are available: {t}"),
+        })),
+    }?;
+
+    // build the response body
+    let response_body = Full::new(Bytes::from(api_response.body))
+        .map_err(|never| match never {})
+        .boxed();
+
+    // build the response
+    let mut response = Response::builder()
+        .status(StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR))
+        .header(CONTENT_TYPE, http_content_type);
+
+    // Add all headers from response_headers vector
+    for (header_name, header_value) in response_headers {
+        response = response.header(header_name, header_value);
+    }
+
+    // add the body and return the response
+    response.body(response_body).map_err(|_| {
+        RestError::SubzeroCore(InternalError {
+            message: "Failed to build response".to_string(),
+        })
+    })
+}
diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs
index 8a14f804b6..f254b41b5b 100644
--- a/proxy/src/serverless/sql_over_http.rs
+++ b/proxy/src/serverless/sql_over_http.rs
@@ -64,7 +64,7 @@ enum Payload {
     Batch(BatchQueryData),
 }
 
-static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
+pub(super) const HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
 
 fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
 where
diff --git a/proxy/src/util.rs b/proxy/src/util.rs
index 0291216d94..c89ebab008 100644
--- a/proxy/src/util.rs
+++ b/proxy/src/util.rs
@@ -20,3 +20,13 @@ pub async fn run_until(
         Either::Right((f2, _)) => Err(f2),
     }
 }
+
+pub fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result<T, D::Error>
+where
+    T: for<'de2> serde::Deserialize<'de2>,
+    D: serde::Deserializer<'de>,
+{
+    use serde::Deserialize;
+    let s = String::deserialize(deserializer)?;
+    serde_json::from_str(&s).map_err(<D::Error as serde::de::Error>::custom)
+}
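+
+// Illustrative use (the struct and field here are hypothetical): deserialize a column that
+// arrives as JSON text inside a string field, e.g.
+//
+//     #[derive(serde::Deserialize)]
+//     struct Row {
+//         #[serde(deserialize_with = "crate::util::deserialize_json_string")]
+//         tags: Vec<String>,
+//     }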
diff --git a/proxy/subzero_core/.gitignore b/proxy/subzero_core/.gitignore
new file mode 100644
index 0000000000..f2f9e58ec3
--- /dev/null
+++ b/proxy/subzero_core/.gitignore
@@ -0,0 +1,2 @@
+target
+Cargo.lock
\ No newline at end of file
diff --git a/proxy/subzero_core/Cargo.toml b/proxy/subzero_core/Cargo.toml
new file mode 100644
index 0000000000..13185873d0
--- /dev/null
+++ b/proxy/subzero_core/Cargo.toml
@@ -0,0 +1,12 @@
+# This is a stub for the subzero-core crate.
+[package]
+name = "subzero-core"
+version = "3.0.1"
+edition = "2024"
+publish = false # "private"!
+
+[features]
+default = []
+postgresql = []
+
+[dependencies]
diff --git a/proxy/subzero_core/src/lib.rs b/proxy/subzero_core/src/lib.rs
new file mode 100644
index 0000000000..b99246b98b
--- /dev/null
+++ b/proxy/subzero_core/src/lib.rs
@@ -0,0 +1 @@
+// This is a stub for the subzero-core crate.
diff --git a/safekeeper/src/rate_limit.rs b/safekeeper/src/rate_limit.rs
index 72373b5786..0e697ade57 100644
--- a/safekeeper/src/rate_limit.rs
+++ b/safekeeper/src/rate_limit.rs
@@ -44,6 +44,6 @@ impl RateLimiter {
 
 /// Generate a random duration that is a fraction of the given duration.
 pub fn rand_duration(duration: &std::time::Duration) -> std::time::Duration {
-    let randf64 = rand::thread_rng().gen_range(0.0..1.0);
+    let randf64 = rand::rng().random_range(0.0..1.0);
     duration.mul_f64(randf64)
 }
diff --git a/safekeeper/tests/random_test.rs b/safekeeper/tests/random_test.rs
index e29b58836a..7e7d2390e9 100644
--- a/safekeeper/tests/random_test.rs
+++ b/safekeeper/tests/random_test.rs
@@ -16,7 +16,7 @@ fn test_random_schedules() -> anyhow::Result<()> {
     let mut config = TestConfig::new(Some(clock));
 
     for _ in 0..500 {
-        let seed: u64 = rand::thread_rng().r#gen();
+        let seed: u64 = rand::rng().random();
         config.network = generate_network_opts(seed);
 
         let test = config.start(seed);
diff --git a/safekeeper/tests/walproposer_sim/simulation.rs b/safekeeper/tests/walproposer_sim/simulation.rs
index edd3bf2d9e..595cc7ab64 100644
--- a/safekeeper/tests/walproposer_sim/simulation.rs
+++ b/safekeeper/tests/walproposer_sim/simulation.rs
@@ -394,13 +394,13 @@ pub fn generate_schedule(seed: u64) -> Schedule {
     let mut schedule = Vec::new();
     let mut time = 0;
-    let cnt = rng.gen_range(1..100);
+    let cnt = rng.random_range(1..100);
 
     for _ in 0..cnt {
-        time += rng.gen_range(0..500);
-        let action = match rng.gen_range(0..3) {
-            0 => TestAction::WriteTx(rng.gen_range(1..10)),
-            1 => TestAction::RestartSafekeeper(rng.gen_range(0..3)),
+        time += rng.random_range(0..500);
+        let action = match rng.random_range(0..3) {
+            0 => TestAction::WriteTx(rng.random_range(1..10)),
+            1 => TestAction::RestartSafekeeper(rng.random_range(0..3)),
             2 => TestAction::RestartWalProposer,
             _ => unreachable!(),
         };
@@ -413,13 +413,13 @@ pub fn generate_schedule(seed: u64) -> Schedule {
 pub fn generate_network_opts(seed: u64) -> NetworkOptions {
     let mut rng = rand::rngs::StdRng::seed_from_u64(seed);
 
-    let timeout = rng.gen_range(100..2000);
-    let max_delay = rng.gen_range(1..2 * timeout);
-    let min_delay = rng.gen_range(1..=max_delay);
+    let timeout = rng.random_range(100..2000);
+    let max_delay = rng.random_range(1..2 * timeout);
+    let min_delay = rng.random_range(1..=max_delay);
 
-    let max_fail_prob = rng.gen_range(0.0..0.9);
-    let connect_fail_prob = rng.gen_range(0.0..max_fail_prob);
-    let send_fail_prob = rng.gen_range(0.0..connect_fail_prob);
+    let max_fail_prob = rng.random_range(0.0..0.9);
+    let connect_fail_prob = rng.random_range(0.0..max_fail_prob);
+    let send_fail_prob = rng.random_range(0.0..connect_fail_prob);
 
     NetworkOptions {
         keepalive_timeout: Some(timeout),
diff --git a/storage_controller/src/hadron_utils.rs b/storage_controller/src/hadron_utils.rs
index 871e21c367..8bfbe8e575 100644
--- a/storage_controller/src/hadron_utils.rs
+++ b/storage_controller/src/hadron_utils.rs
@@ -8,10 +8,10 @@ static CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz01
 /// Generate a random string of `length` that can be used as a password. The generated string
 /// contains alphanumeric characters and special characters (!@#$%^&*())
 pub fn generate_random_password(length: usize) -> String {
-    let mut rng = rand::thread_rng();
+    let mut rng = rand::rng();
     (0..length)
         .map(|_| {
-            let idx = rng.gen_range(0..CHARSET.len());
+            let idx = rng.random_range(0..CHARSET.len());
             CHARSET[idx] as char
         })
         .collect()
 }
diff --git a/storage_controller/src/persistence.rs b/storage_controller/src/persistence.rs
index ed9a268064..2e3f8c6908 100644
--- a/storage_controller/src/persistence.rs
+++ b/storage_controller/src/persistence.rs
@@ -129,6 +129,7 @@ pub(crate) enum DatabaseOperation {
     UpdateLeader,
     SetPreferredAzs,
     InsertTimeline,
+    UpdateTimeline,
     UpdateTimelineMembership,
     GetTimeline,
     InsertTimelineReconcile,
@@ -1463,6 +1464,36 @@ impl Persistence {
         .await
     }
 
+    /// Update an already present timeline.
+    /// VERY UNSAFE FUNCTION: this overrides in-progress migrations. Don't use this unless necessary.
+    pub(crate) async fn update_timeline_unsafe(
+        &self,
+        entry: TimelineUpdate,
+    ) -> DatabaseResult<bool> {
+        use crate::schema::timelines;
+
+        let entry = &entry;
+        self.with_measured_conn(DatabaseOperation::UpdateTimeline, move |conn| {
+            Box::pin(async move {
+                let inserted_updated = diesel::update(timelines::table)
+                    .filter(timelines::tenant_id.eq(&entry.tenant_id))
+                    .filter(timelines::timeline_id.eq(&entry.timeline_id))
+                    .set(entry)
+                    .execute(conn)
+                    .await?;
+
+                match inserted_updated {
+                    0 => Ok(false),
+                    1 => Ok(true),
+                    _ => Err(DatabaseError::Logical(format!(
+                        "unexpected number of rows ({inserted_updated})"
+                    ))),
+                }
+            })
+        })
+        .await
+    }
+
     /// Update timeline membership configuration in the database.
     /// Perform a compare-and-swap (CAS) operation on the timeline's generation.
     /// The `new_generation` must be the next (+1) generation after the one in the database.
@@ -2503,6 +2534,18 @@ impl TimelineFromDb {
     }
 }
 
+// This is separate from TimelinePersistence because we don't want to touch generation and deleted_at values for the update.
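+// Note: `treat_none_as_null = true` below makes a `None` field update its column to SQL NULL rather
+// than leaving it unchanged, so writing `new_sk_set: None` also clears any in-progress migration.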
+#[derive(AsChangeset)]
+#[diesel(table_name = crate::schema::timelines)]
+#[diesel(treat_none_as_null = true)]
+pub(crate) struct TimelineUpdate {
+    pub(crate) tenant_id: String,
+    pub(crate) timeline_id: String,
+    pub(crate) start_lsn: LsnWrapper,
+    pub(crate) sk_set: Vec<i64>,
+    pub(crate) new_sk_set: Option<Vec<i64>>,
+}
+
 #[derive(Insertable, AsChangeset, Queryable, Selectable, Clone)]
 #[diesel(table_name = crate::schema::safekeeper_timeline_pending_ops)]
 pub(crate) struct TimelinePendingOpPersistence {
diff --git a/storage_controller/src/service/chaos_injector.rs b/storage_controller/src/service/chaos_injector.rs
index 4087de200a..0efeef4e80 100644
--- a/storage_controller/src/service/chaos_injector.rs
+++ b/storage_controller/src/service/chaos_injector.rs
@@ -3,8 +3,8 @@ use std::sync::Arc;
 use std::time::Duration;
 
 use pageserver_api::controller_api::ShardSchedulingPolicy;
-use rand::seq::SliceRandom;
-use rand::{Rng, thread_rng};
+use rand::Rng;
+use rand::seq::{IndexedRandom, SliceRandom};
 use tokio_util::sync::CancellationToken;
 use utils::id::NodeId;
 use utils::shard::TenantShardId;
@@ -72,7 +72,7 @@ impl ChaosInjector {
         let cron_interval = self.get_cron_interval_sleep_future();
         let chaos_type = tokio::select! {
             _ = interval.tick() => {
-                if thread_rng().gen_bool(0.5) {
+                if rand::rng().random_bool(0.5) {
                     ChaosEvent::MigrationsToSecondary
                 } else {
                     ChaosEvent::GracefulMigrationsAnywhere
@@ -134,7 +134,7 @@ impl ChaosInjector {
         let Some(new_location) = shard
             .intent
             .get_secondary()
-            .choose(&mut thread_rng())
+            .choose(&mut rand::rng())
             .cloned()
         else {
             tracing::info!(
@@ -190,7 +190,7 @@ impl ChaosInjector {
         // Pick our victims: use a hand-rolled loop rather than choose_multiple() because we want
         // to take the mutable refs from our candidates rather than ref'ing them.
         while !candidates.is_empty() && victims.len() < batch_size {
-            let i = thread_rng().gen_range(0..candidates.len());
+            let i = rand::rng().random_range(0..candidates.len());
             victims.push(candidates.swap_remove(i));
         }
@@ -210,7 +210,7 @@ impl ChaosInjector {
             })
             .collect::<Vec<_>>();
 
-        let Some(victim_node) = candidate_nodes.choose(&mut thread_rng()) else {
+        let Some(victim_node) = candidate_nodes.choose(&mut rand::rng()) else {
            // This can happen if e.g. we are in a small region with only one pageserver per AZ.
            tracing::info!(
                "no candidate nodes found for migrating shard {tenant_shard_id} within its home AZ",
@@ -264,7 +264,7 @@ impl ChaosInjector {
                 out_of_home_az.len()
             );
 
-            out_of_home_az.shuffle(&mut thread_rng());
+            out_of_home_az.shuffle(&mut rand::rng());
             victims.extend(out_of_home_az.into_iter().take(batch_size));
         } else {
             tracing::info!(
@@ -274,7 +274,7 @@ impl ChaosInjector {
             );
             victims.extend(out_of_home_az);
 
-            in_home_az.shuffle(&mut thread_rng());
+            in_home_az.shuffle(&mut rand::rng());
             victims.extend(in_home_az.into_iter().take(batch_size - victims.len()));
         }
diff --git a/storage_controller/src/service/safekeeper_service.rs b/storage_controller/src/service/safekeeper_service.rs
index 7521d7bd86..28c70e203a 100644
--- a/storage_controller/src/service/safekeeper_service.rs
+++ b/storage_controller/src/service/safekeeper_service.rs
@@ -10,6 +10,7 @@ use crate::id_lock_map::trace_shared_lock;
 use crate::metrics;
 use crate::persistence::{
     DatabaseError, SafekeeperTimelineOpKind, TimelinePendingOpPersistence, TimelinePersistence,
+    TimelineUpdate,
 };
 use crate::safekeeper::Safekeeper;
 use crate::safekeeper_client::SafekeeperClient;
@@ -454,19 +455,33 @@ impl Service {
         let persistence = TimelinePersistence {
             tenant_id: req.tenant_id.to_string(),
             timeline_id: req.timeline_id.to_string(),
-            start_lsn: Lsn::INVALID.into(),
+            start_lsn: req.start_lsn.into(),
             generation: 1,
             sk_set: req.sk_set.iter().map(|sk_id| sk_id.0 as i64).collect(),
             new_sk_set: None,
             cplane_notified_generation: 1,
             deleted_at: None,
         };
-        let inserted = self.persistence.insert_timeline(persistence).await?;
+        let inserted = self
+            .persistence
+            .insert_timeline(persistence.clone())
+            .await?;
         if inserted {
             tracing::info!("imported timeline into db");
-        } else {
-            tracing::info!("didn't import timeline into db, as it is already present in db");
+            return Ok(());
         }
+        tracing::info!("timeline already present in db, updating");
+
+        let update = TimelineUpdate {
+            tenant_id: persistence.tenant_id,
+            timeline_id: persistence.timeline_id,
+            start_lsn: persistence.start_lsn,
+            sk_set: persistence.sk_set,
+            new_sk_set: persistence.new_sk_set,
+        };
+        self.persistence.update_timeline_unsafe(update).await?;
+        tracing::info!("timeline updated");
+        Ok(())
     }
diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py
index c9d225909c..e402af1129 100644
--- a/test_runner/fixtures/neon_fixtures.py
+++ b/test_runner/fixtures/neon_fixtures.py
@@ -4123,6 +4123,294 @@ class NeonAuthBroker:
         self._popen.kill()
 
+
+class NeonLocalProxy(LogUtils):
+    """
+    An object managing a local_proxy instance for rest broker testing.
+    The local_proxy serves as a direct connection to VanillaPostgres.
+ """ + + def __init__( + self, + neon_binpath: Path, + test_output_dir: Path, + http_port: int, + metrics_port: int, + vanilla_pg: VanillaPostgres, + config_path: Path | None = None, + ): + self.neon_binpath = neon_binpath + self.test_output_dir = test_output_dir + self.http_port = http_port + self.metrics_port = metrics_port + self.vanilla_pg = vanilla_pg + self.config_path = config_path or (test_output_dir / "local_proxy.json") + self.host = "127.0.0.1" + self.running = False + self.logfile = test_output_dir / "local_proxy.log" + self._popen: subprocess.Popen[bytes] | None = None + super().__init__(logfile=self.logfile) + + def start(self) -> Self: + assert self._popen is None + assert not self.running + + # Ensure vanilla_pg is running + if not self.vanilla_pg.is_running(): + self.vanilla_pg.start() + + args = [ + str(self.neon_binpath / "local_proxy"), + "--http", + f"{self.host}:{self.http_port}", + "--metrics", + f"{self.host}:{self.metrics_port}", + "--postgres", + f"127.0.0.1:{self.vanilla_pg.default_options['port']}", + "--config-path", + str(self.config_path), + "--disable-pg-session-jwt", + ] + + logfile = open(self.logfile, "w") + self._popen = subprocess.Popen(args, stdout=logfile, stderr=logfile) + self.running = True + self._wait_until_ready() + return self + + def stop(self) -> Self: + if self._popen is not None and self.running: + self._popen.terminate() + try: + self._popen.wait(timeout=5) + except subprocess.TimeoutExpired: + log.warning("failed to gracefully terminate local_proxy; killing") + self._popen.kill() + self.running = False + return self + + def get_binary_version(self) -> str: + """Get the version string of the local_proxy binary""" + try: + result = subprocess.run( + [str(self.neon_binpath / "local_proxy"), "--version"], + capture_output=True, + text=True, + timeout=10, + ) + return result.stdout.strip() + except (subprocess.TimeoutExpired, subprocess.CalledProcessError): + return "" + + @backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=10) + def _wait_until_ready(self): + assert self._popen and self._popen.poll() is None, ( + "Local proxy exited unexpectedly. Check test log." + ) + requests.get(f"http://{self.host}:{self.http_port}/metrics") + + def get_metrics(self) -> str: + response = requests.get(f"http://{self.host}:{self.metrics_port}/metrics") + return response.text + + def assert_no_errors(self): + # Define allowed error patterns for local_proxy + allowed_errors = [ + # Add patterns as needed + ] + not_allowed = [ + "error", + "panic", + "failed", + ] + + for na in not_allowed: + if na not in allowed_errors: + assert not self.log_contains(na), f"Found disallowed error pattern: {na}" + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ): + self.stop() + + +class NeonRestBrokerProxy(LogUtils): + """ + An object managing a proxy instance configured as both auth broker and rest broker. + This is the main proxy binary with --is-auth-broker and --is-rest-broker flags. 
+ """ + + def __init__( + self, + neon_binpath: Path, + test_output_dir: Path, + wss_port: int, + http_port: int, + mgmt_port: int, + config_path: Path | None = None, + ): + self.neon_binpath = neon_binpath + self.test_output_dir = test_output_dir + self.wss_port = wss_port + self.http_port = http_port + self.mgmt_port = mgmt_port + self.config_path = config_path or (test_output_dir / "rest_broker_proxy.json") + self.host = "127.0.0.1" + self.running = False + self.logfile = test_output_dir / "rest_broker_proxy.log" + self._popen: subprocess.Popen[Any] | None = None + + def start(self) -> Self: + if self.running: + return self + + # Generate self-signed TLS certificates + cert_path = self.test_output_dir / "server.crt" + key_path = self.test_output_dir / "server.key" + + if not cert_path.exists() or not key_path.exists(): + import subprocess + + log.info("Generating self-signed TLS certificate for rest broker") + subprocess.run( + [ + "openssl", + "req", + "-new", + "-x509", + "-days", + "365", + "-nodes", + "-text", + "-out", + str(cert_path), + "-keyout", + str(key_path), + "-subj", + "/CN=*.local.neon.build", + ], + check=True, + ) + + log.info( + f"Starting rest broker proxy on WSS port {self.wss_port}, HTTP port {self.http_port}" + ) + + cmd = [ + str(self.neon_binpath / "proxy"), + "-c", + str(cert_path), + "-k", + str(key_path), + "--is-auth-broker", + "true", + "--is-rest-broker", + "true", + "--wss", + f"{self.host}:{self.wss_port}", + "--http", + f"{self.host}:{self.http_port}", + "--mgmt", + f"{self.host}:{self.mgmt_port}", + "--auth-backend", + "local", + "--config-path", + str(self.config_path), + ] + + log.info(f"Starting rest broker proxy with command: {' '.join(cmd)}") + + with open(self.logfile, "w") as logfile: + self._popen = subprocess.Popen( + cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + cwd=self.test_output_dir, + env={ + **os.environ, + "RUST_LOG": "info", + "LOGFMT": "text", + "OTEL_SDK_DISABLED": "true", + }, + ) + + self.running = True + self._wait_until_ready() + return self + + def stop(self) -> Self: + if not self.running: + return self + + log.info("Stopping rest broker proxy") + + if self._popen is not None: + self._popen.terminate() + try: + self._popen.wait(timeout=10) + except subprocess.TimeoutExpired: + log.warning("failed to gracefully terminate rest broker proxy; killing") + self._popen.kill() + + self.running = False + return self + + def get_binary_version(self) -> str: + cmd = [str(self.neon_binpath / "proxy"), "--version"] + res = subprocess.run(cmd, capture_output=True, text=True, check=True) + return res.stdout.strip() + + @backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=10) + def _wait_until_ready(self): + # Check if the WSS port is ready using a simple HTTPS request + # REST API is served on the WSS port with HTTPS + requests.get(f"https://{self.host}:{self.wss_port}/", timeout=1, verify=False) + # Any response (even error) means the server is up - we just need to connect + + def get_metrics(self) -> str: + # Metrics are still on the HTTP port + response = requests.get(f"http://{self.host}:{self.http_port}/metrics", timeout=5) + response.raise_for_status() + return response.text + + def assert_no_errors(self): + # Define allowed error patterns for rest broker proxy + allowed_errors = [ + "connection closed before message completed", + "connection reset by peer", + "broken pipe", + "client disconnected", + "Authentication failed", + "connection timed out", + "no connection available", + "Pool dropped", + 
+        ]
+
+        with open(self.logfile) as f:
+            for line in f:
+                if "ERROR" in line or "FATAL" in line:
+                    if not any(allowed in line for allowed in allowed_errors):
+                        raise AssertionError(
+                            f"Found error in rest broker proxy log: {line.strip()}"
+                        )
+
+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_value: BaseException | None,
+        traceback: TracebackType | None,
+    ):
+        self.stop()
+
+
 @pytest.fixture(scope="function")
 def link_proxy(
     port_distributor: PortDistributor, neon_binpath: Path, test_output_dir: Path
@@ -4205,6 +4493,81 @@ def static_proxy(
     yield proxy
 
 
+@pytest.fixture(scope="function")
+def local_proxy(
+    vanilla_pg: VanillaPostgres,
+    port_distributor: PortDistributor,
+    neon_binpath: Path,
+    test_output_dir: Path,
+) -> Iterator[NeonLocalProxy]:
+    """Local proxy that connects directly to vanilla postgres for rest broker testing."""
+
+    # Start vanilla_pg without database bootstrapping
+    vanilla_pg.start()
+
+    http_port = port_distributor.get_port()
+    metrics_port = port_distributor.get_port()
+
+    with NeonLocalProxy(
+        neon_binpath=neon_binpath,
+        test_output_dir=test_output_dir,
+        http_port=http_port,
+        metrics_port=metrics_port,
+        vanilla_pg=vanilla_pg,
+    ) as proxy:
+        proxy.start()
+        yield proxy
+
+
+@pytest.fixture(scope="function")
+def local_proxy_fixed_port(
+    vanilla_pg: VanillaPostgres,
+    neon_binpath: Path,
+    test_output_dir: Path,
+) -> Iterator[NeonLocalProxy]:
+    """Local proxy that connects directly to vanilla postgres on the hardcoded port 7432."""
+
+    # Start vanilla_pg without database bootstrapping
+    vanilla_pg.start()
+
+    # Use the hardcoded port that the rest broker proxy expects
+    http_port = 7432
+    metrics_port = 7433  # Use a different port for metrics
+
+    with NeonLocalProxy(
+        neon_binpath=neon_binpath,
+        test_output_dir=test_output_dir,
+        http_port=http_port,
+        metrics_port=metrics_port,
+        vanilla_pg=vanilla_pg,
+    ) as proxy:
+        proxy.start()
+        yield proxy
+
+
+@pytest.fixture(scope="function")
+def rest_broker_proxy(
+    port_distributor: PortDistributor,
+    neon_binpath: Path,
+    test_output_dir: Path,
+) -> Iterator[NeonRestBrokerProxy]:
+    """Rest broker proxy that handles both auth broker and rest broker functionality."""
+
+    wss_port = port_distributor.get_port()
+    http_port = port_distributor.get_port()
+    mgmt_port = port_distributor.get_port()
+
+    with NeonRestBrokerProxy(
+        neon_binpath=neon_binpath,
+        test_output_dir=test_output_dir,
+        wss_port=wss_port,
+        http_port=http_port,
+        mgmt_port=mgmt_port,
+    ) as proxy:
+        proxy.start()
+        yield proxy
+
+
 @pytest.fixture(scope="function")
 def neon_authorize_jwk() -> jwk.JWK:
     kid = str(uuid.uuid4())
diff --git a/test_runner/fixtures/utils.py b/test_runner/fixtures/utils.py
index 0d7345cc82..1f80c2a290 100644
--- a/test_runner/fixtures/utils.py
+++ b/test_runner/fixtures/utils.py
@@ -741,3 +741,29 @@ def shared_buffers_for_max_cu(max_cu: float) -> str:
     sharedBuffersMb = int(max(128, (1023 + maxBackends * 256) / 1024))
     sharedBuffers = int(sharedBuffersMb * 1024 / 8)
     return str(sharedBuffers)
+
+
+def skip_if_proxy_lacks_rest_broker(reason: str = "proxy was built without 'rest_broker' feature"):
+    # Determine the binary path using the same logic as neon_binpath fixture
+    def has_rest_broker_feature():
+        # Find the neon binaries
+        if env_neon_bin := os.environ.get("NEON_BIN"):
+            binpath = Path(env_neon_bin)
+        else:
+            base_dir = Path(__file__).parents[2]  # Same as BASE_DIR in paths.py
+            build_type = os.environ.get("BUILD_TYPE", "debug")
/ "target" / build_type + + proxy_bin = binpath / "proxy" + if not proxy_bin.exists(): + return False + + try: + cmd = [str(proxy_bin), "--help"] + result = subprocess.run(cmd, capture_output=True, text=True, check=True, timeout=10) + help_output = result.stdout + return "--is-rest-broker" in help_output + except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError): + return False + + return pytest.mark.skipif(not has_rest_broker_feature(), reason=reason) diff --git a/test_runner/regress/test_pg_regress.py b/test_runner/regress/test_pg_regress.py index a240071a7f..dd9c5437ad 100644 --- a/test_runner/regress/test_pg_regress.py +++ b/test_runner/regress/test_pg_regress.py @@ -368,7 +368,14 @@ def test_max_wal_rate(neon_simple_env: NeonEnv): superuser_name = "databricks_superuser" # Connect to postgres and create a database called "regression". - endpoint = env.endpoints.create_start("main") + endpoint = env.endpoints.create_start( + "main", + config_lines=[ + # we need this option because default max_cluster_size < 0 will disable throttling completely + "neon.max_cluster_size=10GB", + ], + ) + endpoint.safe_psql_many( [ f"CREATE ROLE {superuser_name}", diff --git a/test_runner/regress/test_rest_broker.py b/test_runner/regress/test_rest_broker.py new file mode 100644 index 0000000000..60b04655d3 --- /dev/null +++ b/test_runner/regress/test_rest_broker.py @@ -0,0 +1,137 @@ +import json +import signal +import time + +import requests +from fixtures.utils import skip_if_proxy_lacks_rest_broker +from jwcrypto import jwt + + +@skip_if_proxy_lacks_rest_broker() +def test_rest_broker_happy( + local_proxy_fixed_port, rest_broker_proxy, vanilla_pg, neon_authorize_jwk, httpserver +): + """Test REST API endpoint using local_proxy and rest_broker_proxy.""" + + # Use the fixed port local proxy + local_proxy = local_proxy_fixed_port + + # Create the required roles for PostgREST authentication + vanilla_pg.safe_psql("CREATE ROLE authenticator LOGIN") + vanilla_pg.safe_psql("CREATE ROLE authenticated") + vanilla_pg.safe_psql("CREATE ROLE anon") + vanilla_pg.safe_psql("GRANT authenticated TO authenticator") + vanilla_pg.safe_psql("GRANT anon TO authenticator") + + # Create the pgrst schema and configuration function required by the rest broker + vanilla_pg.safe_psql("CREATE SCHEMA IF NOT EXISTS pgrst") + vanilla_pg.safe_psql(""" + CREATE OR REPLACE FUNCTION pgrst.pre_config() + RETURNS VOID AS $$ + SELECT + set_config('pgrst.db_schemas', 'test', true) + , set_config('pgrst.db_aggregates_enabled', 'true', true) + , set_config('pgrst.db_anon_role', 'anon', true) + , set_config('pgrst.jwt_aud', '', true) + , set_config('pgrst.jwt_secret', '', true) + , set_config('pgrst.jwt_role_claim_key', '."role"', true) + + $$ LANGUAGE SQL; + """) + vanilla_pg.safe_psql("GRANT USAGE ON SCHEMA pgrst TO authenticator") + vanilla_pg.safe_psql("GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA pgrst TO authenticator") + + # Bootstrap the database with test data + vanilla_pg.safe_psql("CREATE SCHEMA IF NOT EXISTS test") + vanilla_pg.safe_psql(""" + CREATE TABLE IF NOT EXISTS test.items ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL + ) + """) + vanilla_pg.safe_psql("INSERT INTO test.items (name) VALUES ('test_item')") + + # Grant access to the test schema for the authenticated role + vanilla_pg.safe_psql("GRANT USAGE ON SCHEMA test TO authenticated") + vanilla_pg.safe_psql("GRANT SELECT ON ALL TABLES IN SCHEMA test TO authenticated") + + # Set up HTTP server to serve JWKS (like static_auth_broker) + # Generate 
+    # Set up HTTP server to serve JWKS (like static_auth_broker)
+    # Generate public key from the JWK
+    public_key = neon_authorize_jwk.export_public(as_dict=True)
+
+    # Set up the httpserver to serve the JWKS
+    httpserver.expect_request("/.well-known/jwks.json").respond_with_json({"keys": [public_key]})
+
+    # Create JWKS configuration for the rest broker proxy
+    jwks_config = {
+        "jwks": [
+            {
+                "id": "1",
+                "role_names": ["authenticator", "authenticated", "anon"],
+                "jwks_url": httpserver.url_for("/.well-known/jwks.json"),
+                "provider_name": "foo",
+                "jwt_audience": None,
+            }
+        ]
+    }
+
+    # Write the JWKS config to the config file that rest_broker_proxy expects
+    config_file = rest_broker_proxy.config_path
+    with open(config_file, "w") as f:
+        json.dump(jwks_config, f)
+
+    # Write the same config to the local_proxy config file
+    local_config_file = local_proxy.config_path
+    with open(local_config_file, "w") as f:
+        json.dump(jwks_config, f)
+
+    # Signal both proxies to reload their config
+    if rest_broker_proxy._popen is not None:
+        rest_broker_proxy._popen.send_signal(signal.SIGHUP)
+    if local_proxy._popen is not None:
+        local_proxy._popen.send_signal(signal.SIGHUP)
+    # Wait a bit for config to reload
+    time.sleep(0.5)
+
+    # Generate a proper JWT token using the JWK (similar to test_auth_broker.py)
+    token = jwt.JWT(
+        header={"kid": neon_authorize_jwk.key_id, "alg": "RS256"},
+        claims={
+            "sub": "user",
+            "role": "authenticated",  # role that's in role_names
+            "exp": 9999999999,  # expires far in the future
+            "iat": 1000000000,  # issued at
+        },
+    )
+    token.make_signed_token(neon_authorize_jwk)
+
+    # Debug: Print the JWT claims and config for troubleshooting
+    print(f"JWT claims: {token.claims}")
+    print(f"JWT header: {token.header}")
+    print(f"Config file contains: {jwks_config}")
+    print(f"Public key kid: {public_key.get('kid')}")
+
+    # Test REST API call - following SUBZERO.md pattern
+    # REST API is served on the WSS port with HTTPS and includes database name
+    # ep-purple-glitter-adqior4l-pooler.c-2.us-east-1.aws.neon.tech
+    url = f"https://foo.apirest.c-2.local.neon.build:{rest_broker_proxy.wss_port}/postgres/rest/v1/items"
+
+    response = requests.get(
+        url,
+        headers={
+            "Authorization": f"Bearer {token.serialize()}",
+        },
+        params={"id": "eq.1", "select": "name"},
+        verify=False,  # Skip SSL verification for self-signed certs
+    )
+
+    print(f"Response status: {response.status_code}")
+    print(f"Response headers: {response.headers}")
+    print(f"Response body: {response.text}")
+
+    # The request must succeed end-to-end through both proxies
+    assert response.status_code == 200
+
+    # check the response body
+    assert response.json() == [{"name": "test_item"}]
diff --git a/test_runner/regress/test_sharding.py b/test_runner/regress/test_sharding.py
index 5549105188..2252c098c7 100644
--- a/test_runner/regress/test_sharding.py
+++ b/test_runner/regress/test_sharding.py
@@ -1810,6 +1810,8 @@ def test_sharding_backpressure(neon_env_builder: NeonEnvBuilder):
             "config_lines": [
                 # Tip: set to 100MB to make the test fail
                 "max_replication_write_lag=1MB",
+                # Hadron: Need to set max_cluster_size to some value to enable any backpressure at all.
+                "neon.max_cluster_size=1GB",
             ],
             # We need `neon` extension for calling backpressure functions,
             # this flag instructs `compute_ctl` to pre-install it.