diff --git a/.config/hakari.toml b/.config/hakari.toml index 9991cd92b0..dcbc44cc33 100644 --- a/.config/hakari.toml +++ b/.config/hakari.toml @@ -21,13 +21,14 @@ platforms = [ # "x86_64-apple-darwin", # "x86_64-pc-windows-msvc", ] - [final-excludes] workspace-members = [ # vm_monitor benefits from the same Cargo.lock as the rest of our artifacts, but # it is built primarly in separate repo neondatabase/autoscaling and thus is excluded # from depending on workspace-hack because most of the dependencies are not used. "vm_monitor", + # subzero-core is a stub crate that should be excluded from workspace-hack + "subzero-core", # All of these exist in libs and are not usually built independently. # Putting workspace hack there adds a bottleneck for cargo builds. "compute_api", diff --git a/.github/actions/prepare-for-subzero/action.yml b/.github/actions/prepare-for-subzero/action.yml new file mode 100644 index 0000000000..11beb11880 --- /dev/null +++ b/.github/actions/prepare-for-subzero/action.yml @@ -0,0 +1,28 @@ +name: 'Prepare current job for subzero' +description: > + Set git token to access `neondatabase/subzero` from cargo build, + and set `CARGO_NET_GIT_FETCH_WITH_CLI=true` env variable to use git CLI + +inputs: + token: + description: 'GitHub token with access to neondatabase/subzero' + required: true + +runs: + using: "composite" + + steps: + - name: Set git token for neondatabase/subzero + uses: pyTooling/Actions/with-post-step@2307b526df64d55e95884e072e49aac2a00a9afa # v5.1.0 + env: + SUBZERO_ACCESS_TOKEN: ${{ inputs.token }} + with: + main: | + git config --global url."https://x-access-token:${SUBZERO_ACCESS_TOKEN}@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero" + cargo add -p proxy subzero-core --git https://github.com/neondatabase/subzero --rev 396264617e78e8be428682f87469bb25429af88a + post: | + git config --global --unset url."https://x-access-token:${SUBZERO_ACCESS_TOKEN}@github.com/neondatabase/subzero".insteadOf 
"https://github.com/neondatabase/subzero" + + - name: Set `CARGO_NET_GIT_FETCH_WITH_CLI=true` env variable + shell: bash -euxo pipefail {0} + run: echo "CARGO_NET_GIT_FETCH_WITH_CLI=true" >> ${GITHUB_ENV} diff --git a/.github/workflows/_build-and-test-locally.yml b/.github/workflows/_build-and-test-locally.yml index 94115572df..1b03dc9c03 100644 --- a/.github/workflows/_build-and-test-locally.yml +++ b/.github/workflows/_build-and-test-locally.yml @@ -86,6 +86,10 @@ jobs: with: submodules: true + - uses: ./.github/actions/prepare-for-subzero + with: + token: ${{ secrets.CI_ACCESS_TOKEN }} + - name: Set pg 14 revision for caching id: pg_v14_rev run: echo pg_rev=$(git rev-parse HEAD:vendor/postgres-v14) >> $GITHUB_OUTPUT @@ -116,7 +120,7 @@ jobs: ARCH: ${{ inputs.arch }} SANITIZERS: ${{ inputs.sanitizers }} run: | - CARGO_FLAGS="--locked --features testing" + CARGO_FLAGS="--locked --features testing,rest_broker" if [[ $BUILD_TYPE == "debug" && $ARCH == 'x64' ]]; then cov_prefix="scripts/coverage --profraw-prefix=$GITHUB_JOB --dir=/tmp/coverage run" CARGO_PROFILE="" diff --git a/.github/workflows/_check-codestyle-rust.yml b/.github/workflows/_check-codestyle-rust.yml index 4f844b0bf6..af29e10e97 100644 --- a/.github/workflows/_check-codestyle-rust.yml +++ b/.github/workflows/_check-codestyle-rust.yml @@ -46,6 +46,10 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: true + + - uses: ./.github/actions/prepare-for-subzero + with: + token: ${{ secrets.CI_ACCESS_TOKEN }} - name: Cache cargo deps uses: tespkg/actions-cache@b7bf5fcc2f98a52ac6080eb0fd282c2f752074b1 # v1.8.0 diff --git a/.github/workflows/build-macos.yml b/.github/workflows/build-macos.yml index 2296807d2d..e43eec1133 100644 --- a/.github/workflows/build-macos.yml +++ b/.github/workflows/build-macos.yml @@ -54,6 +54,10 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: submodules: true + + - uses: 
./.github/actions/prepare-for-subzero + with: + token: ${{ secrets.CI_ACCESS_TOKEN }} - name: Install build dependencies run: | diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 2977f642bc..f237a991cc 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -632,6 +632,8 @@ jobs: BUILD_TAG=${{ needs.meta.outputs.release-tag || needs.meta.outputs.build-tag }} TAG=${{ needs.build-build-tools-image.outputs.image-tag }}-bookworm DEBIAN_VERSION=bookworm + secrets: | + SUBZERO_ACCESS_TOKEN=${{ secrets.CI_ACCESS_TOKEN }} provenance: false push: true pull: true diff --git a/.github/workflows/neon_extra_builds.yml b/.github/workflows/neon_extra_builds.yml index 3e81183687..10ca1a1591 100644 --- a/.github/workflows/neon_extra_builds.yml +++ b/.github/workflows/neon_extra_builds.yml @@ -72,6 +72,7 @@ jobs: check-macos-build: needs: [ check-permissions, files-changed ] uses: ./.github/workflows/build-macos.yml + secrets: inherit with: pg_versions: ${{ needs.files-changed.outputs.postgres_changes }} rebuild_rust_code: ${{ fromJSON(needs.files-changed.outputs.rebuild_rust_code) }} diff --git a/.gitignore b/.gitignore index 835cceb123..1e1c2316af 100644 --- a/.gitignore +++ b/.gitignore @@ -26,9 +26,14 @@ docker-compose/docker-compose-parallel.yml *.o *.so *.Po +*.pid # pgindent typedef lists *.list # Node **/node_modules/ + +# various files for local testing +/proxy/.subzero +local_proxy.json diff --git a/Cargo.lock b/Cargo.lock index 137b883a6d..32ae30a765 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -52,6 +52,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "aliasable" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "250f629c0161ad8107cf89319e990051fae62832fd343083bea452d93e2205fd" + [[package]] name = "aligned-vec" version = "0.6.1" @@ -490,7 +496,7 @@ dependencies = [ "hex", "hmac", "http 0.2.9", - "http 1.1.0", + "http 1.3.1", 
"once_cell", "p256 0.11.1", "percent-encoding", @@ -631,7 +637,7 @@ dependencies = [ "aws-smithy-types", "bytes", "http 0.2.9", - "http 1.1.0", + "http 1.3.1", "pin-project-lite", "tokio", "tracing", @@ -649,7 +655,7 @@ dependencies = [ "bytes-utils", "futures-core", "http 0.2.9", - "http 1.1.0", + "http 1.3.1", "http-body 0.4.5", "http-body 1.0.0", "http-body-util", @@ -698,7 +704,7 @@ dependencies = [ "bytes", "form_urlencoded", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "hyper 1.4.1", @@ -732,7 +738,7 @@ checksum = "df1362f362fd16024ae199c1970ce98f9661bf5ef94b9808fee734bc3698b733" dependencies = [ "bytes", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "mime", @@ -756,7 +762,7 @@ dependencies = [ "form_urlencoded", "futures-util", "headers", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "mime", @@ -1090,7 +1096,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "975982cdb7ad6a142be15bdf84aea7ec6a9e5d4d797c004d43185b24cfe4e684" dependencies = [ "clap", - "heck", + "heck 0.5.0", "indexmap 2.9.0", "log", "proc-macro2", @@ -1228,7 +1234,7 @@ version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "syn 2.0.100", @@ -1334,7 +1340,7 @@ dependencies = [ "flate2", "futures", "hostname-validator", - "http 1.1.0", + "http 1.3.1", "indexmap 2.9.0", "itertools 0.10.5", "jsonwebtoken", @@ -1969,7 +1975,7 @@ checksum = "0892a17df262a24294c382f0d5997571006e7a4348b4327557c4ff1cd4a8bccc" dependencies = [ "darling", "either", - "heck", + "heck 0.5.0", "proc-macro2", "quote", "syn 2.0.100", @@ -2661,7 +2667,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http 1.1.0", + "http 1.3.1", "indexmap 2.9.0", "slab", "tokio", @@ -2743,7 +2749,7 @@ dependencies = [ 
"base64 0.21.7", "bytes", "headers-core", - "http 1.1.0", + "http 1.3.1", "httpdate", "mime", "sha1", @@ -2755,9 +2761,15 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4" dependencies = [ - "http 1.1.0", + "http 1.3.1", ] +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "heck" version = "0.5.0" @@ -2833,9 +2845,9 @@ dependencies = [ [[package]] name = "http" -version = "1.1.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" dependencies = [ "bytes", "fnv", @@ -2860,7 +2872,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" dependencies = [ "bytes", - "http 1.1.0", + "http 1.3.1", ] [[package]] @@ -2871,7 +2883,7 @@ checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" dependencies = [ "bytes", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "pin-project-lite", ] @@ -2995,7 +3007,7 @@ dependencies = [ "futures-channel", "futures-util", "h2 0.4.4", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "httparse", "httpdate", @@ -3028,7 +3040,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a0bea761b46ae2b24eb4aef630d8d1c398157b6fc29e6350ecf090a0b70c952c" dependencies = [ "futures-util", - "http 1.1.0", + "http 1.3.1", "hyper 1.4.1", "hyper-util", "rustls 0.22.4", @@ -3060,7 +3072,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "hyper 1.4.1", "pin-project-lite", @@ 
-3709,7 +3721,7 @@ version = "0.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9e6777fc80a575f9503d908c8b498782a6c3ee88a06cb416dc3941401e43b94" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "syn 2.0.100", @@ -4160,7 +4172,7 @@ checksum = "10a8a7f5f6ba7c1b286c2fbca0454eaba116f63bbe69ed250b642d36fbb04d80" dependencies = [ "async-trait", "bytes", - "http 1.1.0", + "http 1.3.1", "opentelemetry", "reqwest", ] @@ -4173,7 +4185,7 @@ checksum = "91cf61a1868dacc576bf2b2a1c3e9ab150af7272909e80085c3173384fe11f76" dependencies = [ "async-trait", "futures-core", - "http 1.1.0", + "http 1.3.1", "opentelemetry", "opentelemetry-http", "opentelemetry-proto", @@ -4252,6 +4264,30 @@ dependencies = [ "winapi", ] +[[package]] +name = "ouroboros" +version = "0.18.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e0f050db9c44b97a94723127e6be766ac5c340c48f2c4bb3ffa11713744be59" +dependencies = [ + "aliasable", + "ouroboros_macro", + "static_assertions", +] + +[[package]] +name = "ouroboros_macro" +version = "0.18.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c7028bdd3d43083f6d8d4d5187680d0d3560d54df4cc9d752005268b41e64d0" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "proc-macro2-diagnostics", + "quote", + "syn 2.0.100", +] + [[package]] name = "outref" version = "0.5.1" @@ -4381,7 +4417,7 @@ dependencies = [ "hashlink", "hex", "hex-literal", - "http 1.1.0", + "http 1.3.1", "http-utils", "humantime", "humantime-serde", @@ -5148,6 +5184,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "proc-macro2-diagnostics" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", + "version_check", + "yansi", +] + [[package]] name = "procfs" version = "0.16.0" @@ -5217,7 +5266,7 @@ source 
= "registry+https://github.com/rust-lang/crates.io-index" checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4" dependencies = [ "bytes", - "heck", + "heck 0.5.0", "itertools 0.12.1", "log", "multimap", @@ -5238,7 +5287,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15" dependencies = [ "bytes", - "heck", + "heck 0.5.0", "itertools 0.12.1", "log", "multimap", @@ -5334,7 +5383,7 @@ dependencies = [ "hex", "hmac", "hostname", - "http 1.1.0", + "http 1.3.1", "http-body-util", "http-utils", "humantime", @@ -5354,6 +5403,7 @@ dependencies = [ "metrics", "once_cell", "opentelemetry", + "ouroboros", "p256 0.13.2", "papaya", "parking_lot 0.12.1", @@ -5390,6 +5440,7 @@ dependencies = [ "socket2", "strum_macros", "subtle", + "subzero-core", "thiserror 1.0.69", "tikv-jemalloc-ctl", "tikv-jemallocator", @@ -5705,14 +5756,14 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.2" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.3", - "regex-syntax 0.8.2", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", ] [[package]] @@ -5726,13 +5777,13 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.3" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.2", + "regex-syntax 0.8.5", ] [[package]] @@ -5749,9 +5800,9 @@ checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "regex-syntax" -version = 
"0.8.2" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "relative-path" @@ -5821,7 +5872,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "hyper 1.4.1", @@ -5863,7 +5914,7 @@ checksum = "d1ccd3b55e711f91a9885a2fa6fbbb2e39db1776420b062efc058c6410f7e5e3" dependencies = [ "anyhow", "async-trait", - "http 1.1.0", + "http 1.3.1", "reqwest", "serde", "thiserror 1.0.69", @@ -5880,7 +5931,7 @@ dependencies = [ "async-trait", "futures", "getrandom 0.2.11", - "http 1.1.0", + "http 1.3.1", "hyper 1.4.1", "parking_lot 0.11.2", "reqwest", @@ -5901,7 +5952,7 @@ dependencies = [ "anyhow", "async-trait", "getrandom 0.2.11", - "http 1.1.0", + "http 1.3.1", "matchit", "opentelemetry", "reqwest", @@ -6260,7 +6311,7 @@ dependencies = [ "fail", "futures", "hex", - "http 1.1.0", + "http 1.3.1", "http-utils", "humantime", "hyper 0.14.30", @@ -7109,7 +7160,7 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2", "quote", "rustversion", @@ -7122,6 +7173,10 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" +[[package]] +name = "subzero-core" +version = "3.0.1" + [[package]] name = "svg_fmt" version = "0.4.3" @@ -7732,7 +7787,7 @@ dependencies = [ "async-trait", "base64 0.22.1", "bytes", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", "percent-encoding", @@ -7756,7 +7811,7 @@ dependencies = [ "bytes", "flate2", "h2 0.4.4", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "http-body-util", 
"hyper 1.4.1", @@ -7847,7 +7902,7 @@ dependencies = [ "base64 0.22.1", "bitflags 2.8.0", "bytes", - "http 1.1.0", + "http 1.3.1", "http-body 1.0.0", "mime", "pin-project-lite", @@ -7868,7 +7923,7 @@ name = "tower-otel" version = "0.2.0" source = "git+https://github.com/mattiapenati/tower-otel?rev=56a7321053bcb72443888257b622ba0d43a11fcd#56a7321053bcb72443888257b622ba0d43a11fcd" dependencies = [ - "http 1.1.0", + "http 1.3.1", "opentelemetry", "pin-project", "tower-layer", @@ -8049,7 +8104,7 @@ dependencies = [ "byteorder", "bytes", "data-encoding", - "http 1.1.0", + "http 1.3.1", "httparse", "log", "rand 0.8.5", @@ -8068,7 +8123,7 @@ dependencies = [ "byteorder", "bytes", "data-encoding", - "http 1.1.0", + "http 1.3.1", "httparse", "log", "rand 0.8.5", @@ -8857,8 +8912,8 @@ dependencies = [ "quote", "rand 0.8.5", "regex", - "regex-automata 0.4.3", - "regex-syntax 0.8.2", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", "reqwest", "rustls 0.23.27", "rustls-pki-types", @@ -8954,6 +9009,12 @@ version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4d25c75bf9ea12c4040a97f829154768bbbce366287e2dc044af160cd79a13fd" +[[package]] +name = "yansi" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" + [[package]] name = "yasna" version = "0.5.2" diff --git a/Cargo.toml b/Cargo.toml index 6d91262882..fe647828fc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,6 +49,7 @@ members = [ "libs/proxy/tokio-postgres2", "endpoint_storage", "pgxn/neon/communicator", + "proxy/subzero_core", ] [workspace.package] diff --git a/Dockerfile b/Dockerfile index 55b87d4012..654ae72e56 100644 --- a/Dockerfile +++ b/Dockerfile @@ -63,7 +63,14 @@ WORKDIR /home/nonroot COPY --chown=nonroot . . 
-RUN cargo chef prepare --recipe-path recipe.json +RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \ + set -e \ + && if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \ + export CARGO_NET_GIT_FETCH_WITH_CLI=true && \ + git config --global url."https://$(cat /run/secrets/SUBZERO_ACCESS_TOKEN)@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero" && \ + cargo add -p proxy subzero-core --git https://github.com/neondatabase/subzero --rev 396264617e78e8be428682f87469bb25429af88a; \ + fi \ + && cargo chef prepare --recipe-path recipe.json # Main build image FROM $REPOSITORY/$IMAGE:$TAG AS build @@ -71,20 +78,33 @@ WORKDIR /home/nonroot ARG GIT_VERSION=local ARG BUILD_TAG ARG ADDITIONAL_RUSTFLAGS="" +ENV CARGO_FEATURES="default" # 3. Build cargo dependencies. Note that this step doesn't depend on anything else than # `recipe.json`, so the layer can be reused as long as none of the dependencies change. COPY --from=plan /home/nonroot/recipe.json recipe.json -RUN set -e \ +RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \ + set -e \ + && if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \ + export CARGO_NET_GIT_FETCH_WITH_CLI=true && \ + git config --global url."https://$(cat /run/secrets/SUBZERO_ACCESS_TOKEN)@github.com/neondatabase/subzero".insteadOf "https://github.com/neondatabase/subzero"; \ + fi \ && RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo chef cook --locked --release --recipe-path recipe.json # Perform the main build. We reuse the Postgres build artifacts from the intermediate 'pg-build' # layer, and the cargo dependencies built in the previous step. COPY --chown=nonroot --from=pg-build /home/nonroot/pg_install/ pg_install COPY --chown=nonroot . . 
+COPY --chown=nonroot --from=plan /home/nonroot/proxy/Cargo.toml proxy/Cargo.toml +COPY --chown=nonroot --from=plan /home/nonroot/Cargo.lock Cargo.lock -RUN set -e \ +RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \ + set -e \ + && if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \ + export CARGO_FEATURES="rest_broker"; \ + fi \ && RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo build \ + --features $CARGO_FEATURES \ --bin pg_sni_router \ --bin pageserver \ --bin pagectl \ diff --git a/deny.toml b/deny.toml index be1c6a2f2c..7afd05a837 100644 --- a/deny.toml +++ b/deny.toml @@ -35,6 +35,7 @@ reason = "The paste crate is a build-only dependency with no runtime components. # More documentation for the licenses section can be found here: # https://embarkstudios.github.io/cargo-deny/checks/licenses/cfg.html [licenses] +version = 2 allow = [ "0BSD", "Apache-2.0", diff --git a/pgxn/neon/communicator/Cargo.toml b/pgxn/neon/communicator/Cargo.toml index e95a269d90..b5ce389297 100644 --- a/pgxn/neon/communicator/Cargo.toml +++ b/pgxn/neon/communicator/Cargo.toml @@ -11,6 +11,9 @@ crate-type = ["staticlib"] # 'testing' feature is currently unused in the communicator, but we accept it for convenience of # calling build scripts, so that you can pass the same feature to all packages. testing = [] +# 'rest_broker' feature is currently unused in the communicator, but we accept it for convenience of +# calling build scripts, so that you can pass the same feature to all packages. 
+rest_broker = [] [dependencies] neon-shmem.workspace = true diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 82fe6818e3..8392046839 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -7,6 +7,7 @@ license.workspace = true [features] default = [] testing = ["dep:tokio-postgres"] +rest_broker = ["dep:subzero-core", "dep:ouroboros"] [dependencies] ahash.workspace = true @@ -105,6 +106,11 @@ uuid.workspace = true x509-cert.workspace = true redis.workspace = true zerocopy.workspace = true +# uncomment this to use the real subzero-core crate +# subzero-core = { git = "https://github.com/neondatabase/subzero", rev = "396264617e78e8be428682f87469bb25429af88a", features = ["postgresql"], optional = true } +# this is a stub for the subzero-core crate +subzero-core = { path = "./subzero_core", features = ["postgresql"], optional = true} +ouroboros = { version = "0.18", optional = true } # jwt stuff jose-jwa = "0.1.2" diff --git a/proxy/README.md b/proxy/README.md index ff48f9f323..ce957b90af 100644 --- a/proxy/README.md +++ b/proxy/README.md @@ -178,16 +178,24 @@ Create a configuration file called `local_proxy.json` in the root of the repo (u Start the local proxy: ```sh -cargo run --bin local_proxy -- \ - --disable_pg_session_jwt true \ +cargo run --bin local_proxy --features testing -- \ + --disable-pg-session-jwt \ --http 0.0.0.0:7432 ``` -Start the auth broker: +Start the auth/rest broker: + +Note: to enable the rest broker you need to replace the stub subzero-core crate with the real one. 
+ ```sh -LOGFMT=text OTEL_SDK_DISABLED=true cargo run --bin proxy --features testing -- \ +cargo add -p proxy subzero-core --git https://github.com/neondatabase/subzero --rev 396264617e78e8be428682f87469bb25429af88a +``` + +```sh +LOGFMT=text OTEL_SDK_DISABLED=true cargo run --bin proxy --features testing,rest_broker -- \ -c server.crt -k server.key \ --is-auth-broker true \ + --is-rest-broker true \ --wss 0.0.0.0:8080 \ --http 0.0.0.0:7002 \ --auth-backend local @@ -205,3 +213,9 @@ curl -k "https://foo.local.neon.build:8080/sql" \ -H "neon-connection-string: postgresql://authenticator@foo.local.neon.build/database" \ -d '{"query":"select 1","params":[]}' ``` + +Make a rest request against the auth broker (rest broker): +```sh +curl -k "https://foo.local.neon.build:8080/database/rest/v1/items?select=id,name&id=eq.1" \ +-H "Authorization: Bearer $NEON_JWT" +``` diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs index 401203d48c..e3f7ba4c15 100644 --- a/proxy/src/binary/local_proxy.rs +++ b/proxy/src/binary/local_proxy.rs @@ -20,6 +20,8 @@ use crate::auth::backend::jwt::JwkCache; use crate::auth::backend::local::LocalBackend; use crate::auth::{self}; use crate::cancellation::CancellationHandler; +#[cfg(feature = "rest_broker")] +use crate::config::RestConfig; use crate::config::{ self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig, refresh_config_loop, @@ -276,6 +278,13 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig accept_jwts: true, console_redirect_confirmation_timeout: Duration::ZERO, }, + #[cfg(feature = "rest_broker")] + rest_config: RestConfig { + is_rest_broker: false, + db_schema_cache: None, + max_schema_size: 0, + hostname_prefix: String::new(), + }, proxy_protocol_v2: config::ProxyProtocolV2::Rejected, handshake_timeout: Duration::from_secs(10), wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?, diff --git 
a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 16a7dc7b67..194a1ed34c 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -31,6 +31,8 @@ use crate::auth::backend::local::LocalBackend; use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned}; use crate::batch::BatchQueue; use crate::cancellation::{CancellationHandler, CancellationProcessor}; +#[cfg(feature = "rest_broker")] +use crate::config::RestConfig; #[cfg(any(test, feature = "testing"))] use crate::config::refresh_config_loop; use crate::config::{ @@ -47,6 +49,8 @@ use crate::redis::{elasticache, notifications}; use crate::scram::threadpool::ThreadPool; use crate::serverless::GlobalConnPoolOptions; use crate::serverless::cancel_set::CancelSet; +#[cfg(feature = "rest_broker")] +use crate::serverless::rest::DbSchemaCache; use crate::tls::client_config::compute_client_config_with_root_certs; #[cfg(any(test, feature = "testing"))] use crate::url::ApiUrl; @@ -246,11 +250,23 @@ struct ProxyCliArgs { /// if this is not local proxy, this toggles whether we accept Postgres REST requests #[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)] + #[cfg(feature = "rest_broker")] is_rest_broker: bool, /// cache for `db_schema_cache` introspection (use `size=0` to disable) #[clap(long, default_value = "size=1000,ttl=1h")] + #[cfg(feature = "rest_broker")] db_schema_cache: String, + + /// Maximum size allowed for schema in bytes + #[clap(long, default_value_t = 5 * 1024 * 1024)] // 5MB + #[cfg(feature = "rest_broker")] + max_schema_size: usize, + + /// Hostname prefix to strip from request hostname to get database hostname + #[clap(long, default_value = "apirest.")] + #[cfg(feature = "rest_broker")] + hostname_prefix: String, } #[derive(clap::Args, Clone, Copy, Debug)] @@ -517,6 +533,17 @@ pub async fn run() -> anyhow::Result<()> { )); maintenance_tasks.spawn(control_plane::mgmt::task_main(mgmt_listener)); + // 
add a task to flush the db_schema cache every 10 minutes + #[cfg(feature = "rest_broker")] + if let Some(db_schema_cache) = &config.rest_config.db_schema_cache { + maintenance_tasks.spawn(async move { + loop { + tokio::time::sleep(Duration::from_secs(600)).await; + db_schema_cache.flush(); + } + }); + } + if let Some(metrics_config) = &config.metric_collection { // TODO: Add gc regardles of the metric collection being enabled. maintenance_tasks.spawn(usage_metrics::task_main(metrics_config)); @@ -679,6 +706,30 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { timeout: Duration::from_secs(2), }; + #[cfg(feature = "rest_broker")] + let rest_config = { + let db_schema_cache_config: CacheOptions = args.db_schema_cache.parse()?; + info!("Using DbSchemaCache with options={db_schema_cache_config:?}"); + + let db_schema_cache = if args.is_rest_broker { + Some(DbSchemaCache::new( + "db_schema_cache", + db_schema_cache_config.size, + db_schema_cache_config.ttl, + true, + )) + } else { + None + }; + + RestConfig { + is_rest_broker: args.is_rest_broker, + db_schema_cache, + max_schema_size: args.max_schema_size, + hostname_prefix: args.hostname_prefix.clone(), + } + }; + let config = ProxyConfig { tls_config, metric_collection, @@ -691,6 +742,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { connect_to_compute: compute_config, #[cfg(feature = "testing")] disable_pg_session_jwt: false, + #[cfg(feature = "rest_broker")] + rest_config, }; let config = Box::leak(Box::new(config)); diff --git a/proxy/src/cache/timed_lru.rs b/proxy/src/cache/timed_lru.rs index e87cf53ab9..0a7fb40b0c 100644 --- a/proxy/src/cache/timed_lru.rs +++ b/proxy/src/cache/timed_lru.rs @@ -204,6 +204,11 @@ impl TimedLru { self.insert_raw_ttl(key, value, ttl, false); } + #[cfg(feature = "rest_broker")] + pub(crate) fn insert(&self, key: K, value: V) { + self.insert_raw_ttl(key, value, self.ttl, self.update_ttl_on_retrieval); + } + pub(crate) fn 
insert_unit(&self, key: K, value: V) -> (Option, Cached<&Self, ()>) { let (_, old) = self.insert_raw(key.clone(), value); @@ -214,6 +219,29 @@ impl TimedLru { (old, cached) } + + #[cfg(feature = "rest_broker")] + pub(crate) fn flush(&self) { + let now = Instant::now(); + let mut cache = self.cache.lock(); + + // Collect keys of expired entries first + let expired_keys: Vec<_> = cache + .iter() + .filter_map(|(key, entry)| { + if entry.expires_at <= now { + Some(key.clone()) + } else { + None + } + }) + .collect(); + + // Remove expired entries + for key in expired_keys { + cache.remove(&key); + } + } } impl TimedLru { diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 6157dc8a6a..20bbfd77d8 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -22,6 +22,8 @@ use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig}; use crate::scram::threadpool::ThreadPool; use crate::serverless::GlobalConnPoolOptions; use crate::serverless::cancel_set::CancelSet; +#[cfg(feature = "rest_broker")] +use crate::serverless::rest::DbSchemaCache; pub use crate::tls::server_config::{TlsConfig, configure_tls}; use crate::types::{Host, RoleName}; @@ -30,6 +32,8 @@ pub struct ProxyConfig { pub metric_collection: Option, pub http_config: HttpConfig, pub authentication_config: AuthenticationConfig, + #[cfg(feature = "rest_broker")] + pub rest_config: RestConfig, pub proxy_protocol_v2: ProxyProtocolV2, pub handshake_timeout: Duration, pub wake_compute_retry_config: RetryConfig, @@ -80,6 +84,14 @@ pub struct AuthenticationConfig { pub console_redirect_confirmation_timeout: tokio::time::Duration, } +#[cfg(feature = "rest_broker")] +pub struct RestConfig { + pub is_rest_broker: bool, + pub db_schema_cache: Option, + pub max_schema_size: usize, + pub hostname_prefix: String, +} + #[derive(Debug)] pub struct MetricBackupCollectionConfig { pub remote_storage_config: Option, diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 
a6d376562b..88d5550fff 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -10,6 +10,7 @@ use super::connection_with_credentials_provider::ConnectionWithCredentialsProvid use crate::cache::project_info::ProjectInfoCache; use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; use crate::metrics::{Metrics, RedisErrors, RedisEventsCount}; +use crate::util::deserialize_json_string; const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; const RECONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(20); @@ -121,15 +122,6 @@ struct InvalidateRole { role_name: RoleNameInt, } -fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result -where - T: for<'de2> serde::Deserialize<'de2>, - D: serde::Deserializer<'de>, -{ - let s = String::deserialize(deserializer)?; - serde_json::from_str(&s).map_err(::custom) -} - // https://github.com/serde-rs/serde/issues/1714 fn deserialize_unknown_topic<'de, D>(deserializer: D) -> Result<(), D::Error> where diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 5b7289c53d..18cdc39ac7 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -11,6 +11,8 @@ mod http_conn_pool; mod http_util; mod json; mod local_conn_pool; +#[cfg(feature = "rest_broker")] +pub mod rest; mod sql_over_http; mod websocket; @@ -487,6 +489,42 @@ async fn request_handler( .body(Empty::new().map_err(|x| match x {}).boxed()) .map_err(|e| ApiError::InternalServerError(e.into())) } else { - json_response(StatusCode::BAD_REQUEST, "query is not supported") + #[cfg(feature = "rest_broker")] + { + if config.rest_config.is_rest_broker + // we are testing for the path to be /database_name/rest/... 
+ && request + .uri() + .path() + .split('/') + .nth(2) + .is_some_and(|part| part.starts_with("rest")) + { + let ctx = + RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http); + let span = ctx.span(); + + let testodrome_id = request + .headers() + .get("X-Neon-Query-ID") + .and_then(|value| value.to_str().ok()) + .map(|s| s.to_string()); + + if let Some(query_id) = testodrome_id { + info!(parent: &span, "testodrome query ID: {query_id}"); + ctx.set_testodrome_id(query_id.into()); + } + + rest::handle(config, ctx, request, backend, http_cancellation_token) + .instrument(span) + .await + } else { + json_response(StatusCode::BAD_REQUEST, "query is not supported") + } + } + #[cfg(not(feature = "rest_broker"))] + { + json_response(StatusCode::BAD_REQUEST, "query is not supported") + } } } diff --git a/proxy/src/serverless/rest.rs b/proxy/src/serverless/rest.rs new file mode 100644 index 0000000000..173c2629f7 --- /dev/null +++ b/proxy/src/serverless/rest.rs @@ -0,0 +1,1165 @@ +use std::borrow::Cow; +use std::collections::HashMap; +use std::sync::Arc; + +use bytes::Bytes; +use http::Method; +use http::header::{AUTHORIZATION, CONTENT_TYPE, HOST}; +use http_body_util::combinators::BoxBody; +use http_body_util::{BodyExt, Full}; +use http_utils::error::ApiError; +use hyper::body::Incoming; +use hyper::http::{HeaderName, HeaderValue}; +use hyper::{Request, Response, StatusCode}; +use indexmap::IndexMap; +use ouroboros::self_referencing; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Deserializer}; +use serde_json::Value as JsonValue; +use serde_json::value::RawValue; +use subzero_core::api::ContentType::{ApplicationJSON, Other, SingularJSON, TextCSV}; +use subzero_core::api::QueryNode::{Delete, FunctionCall, Insert, Update}; +use subzero_core::api::Resolution::{IgnoreDuplicates, MergeDuplicates}; +use subzero_core::api::{ApiResponse, ListVal, Payload, Preferences, Representation, SingleVal}; +use 
subzero_core::config::{db_allowed_select_functions, db_schemas, role_claim_key}; +use subzero_core::dynamic_statement::{JoinIterator, param, sql}; +use subzero_core::error::Error::{ + self as SubzeroCoreError, ContentTypeError, GucHeadersError, GucStatusError, InternalError, + JsonDeserialize, JwtTokenInvalid, NotFound, +}; +use subzero_core::error::pg_error_to_status_code; +use subzero_core::formatter::Param::{LV, PL, SV, Str, StrOwned}; +use subzero_core::formatter::postgresql::{fmt_main_query, generate}; +use subzero_core::formatter::{Param, Snippet, SqlParam}; +use subzero_core::parser::postgrest::parse; +use subzero_core::permissions::{check_safe_functions, replace_select_star}; +use subzero_core::schema::{ + DbSchema, POSTGRESQL_INTROSPECTION_SQL, get_postgresql_configuration_query, +}; +use subzero_core::{content_range_header, content_range_status}; +use tokio_util::sync::CancellationToken; +use tracing::{error, info}; +use typed_json::json; +use url::form_urlencoded; + +use super::backend::{HttpConnError, LocalProxyConnError, PoolingBackend}; +use super::conn_pool::AuthData; +use super::conn_pool_lib::ConnInfo; +use super::error::{ConnInfoError, Credentials, HttpCodeError, ReadPayloadError}; +use super::http_conn_pool::{self, Send}; +use super::http_util::{ + ALLOW_POOL, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_ISOLATION_LEVEL, TXN_READ_ONLY, + get_conn_info, json_response, uuid_to_header_value, +}; +use super::json::JsonConversionError; +use crate::auth::backend::ComputeCredentialKeys; +use crate::cache::{Cached, TimedLru}; +use crate::config::ProxyConfig; +use crate::context::RequestContext; +use crate::error::{ErrorKind, ReportableError, UserFacingError}; +use crate::http::read_body_with_limit; +use crate::metrics::Metrics; +use crate::serverless::sql_over_http::HEADER_VALUE_TRUE; +use crate::types::EndpointCacheKey; +use crate::util::deserialize_json_string; + +static EMPTY_JSON_SCHEMA: &str = r#"{"schemas":[]}"#; +const INTROSPECTION_SQL: 
/// Split a comma-separated list into trimmed, owned segments.
///
/// Note: an empty input yields one empty segment (`str::split` semantics);
/// callers that want "empty -> none" handle that themselves (see
/// `deserialize_comma_separated_option`).
fn split_comma_separated(s: &str) -> Vec<String> {
    let mut parts = Vec::new();
    for piece in s.split(',') {
        parts.push(piece.trim().to_owned());
    }
    parts
}
DbSchemaOwned for each endpoint +pub(crate) type DbSchemaCache = TimedLru>; +impl DbSchemaCache { + pub async fn get_cached_or_remote( + &self, + endpoint_id: &EndpointCacheKey, + auth_header: &HeaderValue, + connection_string: &str, + client: &mut http_conn_pool::Client, + ctx: &RequestContext, + config: &'static ProxyConfig, + ) -> Result, RestError> { + match self.get_with_created_at(endpoint_id) { + Some(Cached { value: (v, _), .. }) => Ok(v), + None => { + info!("db_schema cache miss for endpoint: {:?}", endpoint_id); + let remote_value = self + .get_remote(auth_header, connection_string, client, ctx, config) + .await; + let (api_config, schema_owned) = match remote_value { + Ok((api_config, schema_owned)) => (api_config, schema_owned), + Err(e @ RestError::SchemaTooLarge) => { + // for the case where the schema is too large, we cache an empty dummy value + // all the other requests will fail without triggering the introspection query + let schema_owned = serde_json::from_str::(EMPTY_JSON_SCHEMA) + .map_err(|e| JsonDeserialize { source: e })?; + + let api_config = ApiConfig { + db_schemas: vec![], + db_anon_role: None, + db_max_rows: None, + db_allowed_select_functions: vec![], + role_claim_key: String::new(), + db_extra_search_path: None, + }; + let value = Arc::new((api_config, schema_owned)); + self.insert(endpoint_id.clone(), value); + return Err(e); + } + Err(e) => { + return Err(e); + } + }; + let value = Arc::new((api_config, schema_owned)); + self.insert(endpoint_id.clone(), value.clone()); + Ok(value) + } + } + } + pub async fn get_remote( + &self, + auth_header: &HeaderValue, + connection_string: &str, + client: &mut http_conn_pool::Client, + ctx: &RequestContext, + config: &'static ProxyConfig, + ) -> Result<(ApiConfig, DbSchemaOwned), RestError> { + #[derive(Deserialize)] + struct SingleRow { + rows: [Row; 1], + } + + #[derive(Deserialize)] + struct ConfigRow { + #[serde(deserialize_with = "deserialize_json_string")] + config: ApiConfig, + } + + 
#[derive(Deserialize)] + struct SchemaRow { + json_schema: DbSchemaOwned, + } + + let headers = vec![ + (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())), + ( + &CONN_STRING, + HeaderValue::from_str(connection_string).expect( + "connection string came from a header, so it must be a valid headervalue", + ), + ), + (&AUTHORIZATION, auth_header.clone()), + (&RAW_TEXT_OUTPUT, HEADER_VALUE_TRUE), + ]; + + let query = get_postgresql_configuration_query(Some("pgrst.pre_config")); + let SingleRow { + rows: [ConfigRow { config: api_config }], + } = make_local_proxy_request( + client, + headers.iter().cloned(), + QueryData { + query: Cow::Owned(query), + params: vec![], + }, + config.rest_config.max_schema_size, + ) + .await + .map_err(|e| match e { + RestError::ReadPayload(ReadPayloadError::BodyTooLarge { .. }) => { + RestError::SchemaTooLarge + } + e => e, + })?; + + // now that we have the api_config let's run the second INTROSPECTION_SQL query + let SingleRow { + rows: [SchemaRow { json_schema }], + } = make_local_proxy_request( + client, + headers, + QueryData { + query: INTROSPECTION_SQL.into(), + params: vec![ + serde_json::to_value(&api_config.db_schemas) + .expect("Vec is always valid to encode as JSON"), + JsonValue::Bool(false), // include_roles_with_login + JsonValue::Bool(false), // use_internal_permissions + ], + }, + config.rest_config.max_schema_size, + ) + .await + .map_err(|e| match e { + RestError::ReadPayload(ReadPayloadError::BodyTooLarge { .. 
}) => { + RestError::SchemaTooLarge + } + e => e, + })?; + + Ok((api_config, json_schema)) + } +} + +// A type to represent a postgresql errors +// we use our own type (instead of postgres_client::Error) because we get the error from the json response +#[derive(Debug, thiserror::Error, Deserialize)] +pub(crate) struct PostgresError { + pub code: String, + pub message: String, + pub detail: Option, + pub hint: Option, +} +impl HttpCodeError for PostgresError { + fn get_http_status_code(&self) -> StatusCode { + let status = pg_error_to_status_code(&self.code, true); + StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) + } +} +impl ReportableError for PostgresError { + fn get_error_kind(&self) -> ErrorKind { + ErrorKind::User + } +} +impl UserFacingError for PostgresError { + fn to_string_client(&self) -> String { + if self.code.starts_with("PT") { + "Postgres error".to_string() + } else { + self.message.clone() + } + } +} +impl std::fmt::Display for PostgresError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.message) + } +} + +// A type to represent errors that can occur in the rest broker +#[derive(Debug, thiserror::Error)] +pub(crate) enum RestError { + #[error(transparent)] + ReadPayload(#[from] ReadPayloadError), + #[error(transparent)] + ConnectCompute(#[from] HttpConnError), + #[error(transparent)] + ConnInfo(#[from] ConnInfoError), + #[error(transparent)] + Postgres(#[from] PostgresError), + #[error(transparent)] + JsonConversion(#[from] JsonConversionError), + #[error(transparent)] + SubzeroCore(#[from] SubzeroCoreError), + #[error("schema is too large")] + SchemaTooLarge, +} +impl ReportableError for RestError { + fn get_error_kind(&self) -> ErrorKind { + match self { + RestError::ReadPayload(e) => e.get_error_kind(), + RestError::ConnectCompute(e) => e.get_error_kind(), + RestError::ConnInfo(e) => e.get_error_kind(), + RestError::Postgres(_) => ErrorKind::Postgres, + 
RestError::JsonConversion(_) => ErrorKind::Postgres, + RestError::SubzeroCore(_) => ErrorKind::User, + RestError::SchemaTooLarge => ErrorKind::User, + } + } +} +impl UserFacingError for RestError { + fn to_string_client(&self) -> String { + match self { + RestError::ReadPayload(p) => p.to_string(), + RestError::ConnectCompute(c) => c.to_string_client(), + RestError::ConnInfo(c) => c.to_string_client(), + RestError::SchemaTooLarge => self.to_string(), + RestError::Postgres(p) => p.to_string_client(), + RestError::JsonConversion(_) => "could not parse postgres response".to_string(), + RestError::SubzeroCore(s) => { + // TODO: this is a hack to get the message from the json body + let json = s.json_body(); + let default_message = "Unknown error".to_string(); + + json.get("message") + .map_or(default_message.clone(), |m| match m { + JsonValue::String(s) => s.clone(), + _ => default_message, + }) + } + } + } +} +impl HttpCodeError for RestError { + fn get_http_status_code(&self) -> StatusCode { + match self { + RestError::ReadPayload(e) => e.get_http_status_code(), + RestError::ConnectCompute(h) => match h.get_error_kind() { + ErrorKind::User => StatusCode::BAD_REQUEST, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }, + RestError::ConnInfo(_) => StatusCode::BAD_REQUEST, + RestError::Postgres(e) => e.get_http_status_code(), + RestError::JsonConversion(_) => StatusCode::INTERNAL_SERVER_ERROR, + RestError::SchemaTooLarge => StatusCode::INTERNAL_SERVER_ERROR, + RestError::SubzeroCore(e) => { + let status = e.status_code(); + StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) + } + } + } +} + +// Helper functions for the rest broker + +fn fmt_env_query<'a>(env: &'a HashMap<&'a str, &'a str>) -> Snippet<'a> { + "select " + + if env.is_empty() { + sql("null") + } else { + env.iter() + .map(|(k, v)| { + "set_config(" + param(k as &SqlParam) + ", " + param(v as &SqlParam) + ", true)" + }) + .join(",") + } +} + +// TODO: see about removing the need for 
cloning the values (inner things are &Cow already) +fn to_sql_param(p: &Param) -> JsonValue { + match p { + SV(SingleVal(v, ..)) => JsonValue::String(v.to_string()), + Str(v) => JsonValue::String((*v).to_string()), + StrOwned(v) => JsonValue::String((*v).clone()), + PL(Payload(v, ..)) => JsonValue::String(v.clone().into_owned()), + LV(ListVal(v, ..)) => { + if v.is_empty() { + JsonValue::String(r"{}".to_string()) + } else { + JsonValue::String(format!( + "{{\"{}\"}}", + v.iter() + .map(|e| e.replace('\\', "\\\\").replace('\"', "\\\"")) + .collect::>() + .join("\",\"") + )) + } + } + } +} + +#[derive(serde::Serialize)] +struct QueryData<'a> { + query: Cow<'a, str>, + params: Vec, +} + +#[derive(serde::Serialize)] +struct BatchQueryData<'a> { + queries: Vec>, +} + +async fn make_local_proxy_request( + client: &mut http_conn_pool::Client, + headers: impl IntoIterator, + body: QueryData<'_>, + max_len: usize, +) -> Result { + let body_string = serde_json::to_string(&body) + .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?; + + let response = make_raw_local_proxy_request(client, headers, body_string).await?; + + let response_status = response.status(); + + if response_status != StatusCode::OK { + return Err(RestError::SubzeroCore(InternalError { + message: "Failed to get endpoint schema".to_string(), + })); + } + + // Capture the response body + let response_body = crate::http::read_body_with_limit(response.into_body(), max_len) + .await + .map_err(ReadPayloadError::from)?; + + // Parse the JSON response + let response_json: S = serde_json::from_slice(&response_body) + .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?; + + Ok(response_json) +} + +async fn make_raw_local_proxy_request( + client: &mut http_conn_pool::Client, + headers: impl IntoIterator, + body: String, +) -> Result, RestError> { + let local_proxy_uri = ::http::Uri::from_static("http://proxy.local/sql"); + let mut req = 
Request::builder().method(Method::POST).uri(local_proxy_uri); + let req_headers = req.headers_mut().expect("failed to get headers"); + // Add all provided headers to the request + for (header_name, header_value) in headers { + req_headers.insert(header_name, header_value.clone()); + } + + let body_boxed = Full::new(Bytes::from(body)) + .map_err(|never| match never {}) // Convert Infallible to hyper::Error + .boxed(); + + let req = req.body(body_boxed).map_err(|_| { + RestError::SubzeroCore(InternalError { + message: "Failed to build request".to_string(), + }) + })?; + + // Send the request to the local proxy + client + .inner + .inner + .send_request(req) + .await + .map_err(LocalProxyConnError::from) + .map_err(HttpConnError::from) + .map_err(RestError::from) +} + +pub(crate) async fn handle( + config: &'static ProxyConfig, + ctx: RequestContext, + request: Request, + backend: Arc, + cancel: CancellationToken, +) -> Result>, ApiError> { + let result = handle_inner(cancel, config, &ctx, request, backend).await; + + let mut response = match result { + Ok(r) => { + ctx.set_success(); + + // Handling the error response from local proxy here + if r.status().is_server_error() { + let status = r.status(); + + let body_bytes = r + .collect() + .await + .map_err(|e| { + ApiError::InternalServerError(anyhow::Error::msg(format!( + "could not collect http body: {e}" + ))) + })? 
+ .to_bytes(); + + if let Ok(mut json_map) = + serde_json::from_slice::>(&body_bytes) + { + let message = json_map.get("message"); + if let Some(message) = message { + let msg: String = match serde_json::from_str(message.get()) { + Ok(msg) => msg, + Err(_) => { + "Unable to parse the response message from server".to_string() + } + }; + + error!("Error response from local_proxy: {status} {msg}"); + + json_map.retain(|key, _| !key.starts_with("neon:")); // remove all the neon-related keys + + let resp_json = serde_json::to_string(&json_map) + .unwrap_or("failed to serialize the response message".to_string()); + + return json_response(status, resp_json); + } + } + + error!("Unable to parse the response message from local_proxy"); + return json_response( + status, + json!({ "message": "Unable to parse the response message from server".to_string() }), + ); + } + r + } + Err(e @ RestError::SubzeroCore(_)) => { + let error_kind = e.get_error_kind(); + ctx.set_error_kind(error_kind); + + tracing::info!( + kind=error_kind.to_metric_label(), + error=%e, + msg="subzero core error", + "forwarding error to user" + ); + + let RestError::SubzeroCore(subzero_err) = e else { + panic!("expected subzero core error") + }; + + let json_body = subzero_err.json_body(); + let status_code = StatusCode::from_u16(subzero_err.status_code()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + + json_response(status_code, json_body)? 
+ } + Err(e) => { + let error_kind = e.get_error_kind(); + ctx.set_error_kind(error_kind); + + let message = e.to_string_client(); + let status_code = e.get_http_status_code(); + + tracing::info!( + kind=error_kind.to_metric_label(), + error=%e, + msg=message, + "forwarding error to user" + ); + + let (code, detail, hint) = match e { + RestError::Postgres(e) => ( + if e.code.starts_with("PT") { + None + } else { + Some(e.code) + }, + e.detail, + e.hint, + ), + _ => (None, None, None), + }; + + json_response( + status_code, + json!({ + "message": message, + "code": code, + "detail": detail, + "hint": hint, + }), + )? + } + }; + + response + .headers_mut() + .insert("Access-Control-Allow-Origin", HeaderValue::from_static("*")); + Ok(response) +} + +async fn handle_inner( + _cancel: CancellationToken, + config: &'static ProxyConfig, + ctx: &RequestContext, + request: Request, + backend: Arc, +) -> Result>, RestError> { + let _requeset_gauge = Metrics::get() + .proxy + .connection_requests + .guard(ctx.protocol()); + info!( + protocol = %ctx.protocol(), + "handling interactive connection from client" + ); + + // Read host from Host, then URI host as fallback + // TODO: will this be a problem if behind a load balancer? + // TODO: can we use the x-forwarded-host header? + let host = request + .headers() + .get(HOST) + .and_then(|v| v.to_str().ok()) + .unwrap_or_else(|| request.uri().host().unwrap_or("")); + + // a valid path is /database/rest/v1/... so splitting should be ["", "database", "rest", "v1", ...] 
+ let database_name = request + .uri() + .path() + .split('/') + .nth(1) + .ok_or(RestError::SubzeroCore(NotFound { + target: request.uri().path().to_string(), + }))?; + + // we always use the authenticator role to connect to the database + let authenticator_role = "authenticator"; + + // Strip the hostname prefix from the host to get the database hostname + let database_host = host.replace(&config.rest_config.hostname_prefix, ""); + + let connection_string = + format!("postgresql://{authenticator_role}@{database_host}/{database_name}"); + + let conn_info = get_conn_info( + &config.authentication_config, + ctx, + Some(&connection_string), + request.headers(), + )?; + info!( + user = conn_info.conn_info.user_info.user.as_str(), + "credentials" + ); + + match conn_info.auth { + AuthData::Jwt(jwt) => { + let api_prefix = format!("/{database_name}/rest/v1/"); + handle_rest_inner( + config, + ctx, + &api_prefix, + request, + &connection_string, + conn_info.conn_info, + jwt, + backend, + ) + .await + } + AuthData::Password(_) => Err(RestError::ConnInfo(ConnInfoError::MissingCredentials( + Credentials::BearerJwt, + ))), + } +} + +#[allow(clippy::too_many_arguments)] +async fn handle_rest_inner( + config: &'static ProxyConfig, + ctx: &RequestContext, + api_prefix: &str, + request: Request, + connection_string: &str, + conn_info: ConnInfo, + jwt: String, + backend: Arc, +) -> Result>, RestError> { + // validate the jwt token + let jwt_parsed = backend + .authenticate_with_jwt(ctx, &conn_info.user_info, jwt) + .await + .map_err(HttpConnError::from)?; + + let db_schema_cache = + config + .rest_config + .db_schema_cache + .as_ref() + .ok_or(RestError::SubzeroCore(InternalError { + message: "DB schema cache is not configured".to_string(), + }))?; + + let endpoint_cache_key = conn_info + .endpoint_cache_key() + .ok_or(RestError::SubzeroCore(InternalError { + message: "Failed to get endpoint cache key".to_string(), + }))?; + + let mut client = backend.connect_to_local_proxy(ctx, 
conn_info).await?; + + let (parts, originial_body) = request.into_parts(); + + let auth_header = parts + .headers + .get(AUTHORIZATION) + .ok_or(RestError::SubzeroCore(InternalError { + message: "Authorization header is required".to_string(), + }))?; + + let entry = db_schema_cache + .get_cached_or_remote( + &endpoint_cache_key, + auth_header, + connection_string, + &mut client, + ctx, + config, + ) + .await?; + let (api_config, db_schema_owned) = entry.as_ref(); + let db_schema = db_schema_owned.borrow_schema(); + + let db_schemas = &api_config.db_schemas; // list of schemas available for the api + let db_extra_search_path = &api_config.db_extra_search_path; + // TODO: use this when we get a replacement for jsonpath_lib + // let role_claim_key = &api_config.role_claim_key; + // let role_claim_path = format!("${role_claim_key}"); + let db_anon_role = &api_config.db_anon_role; + let max_rows = api_config.db_max_rows.as_deref(); + let db_allowed_select_functions = api_config + .db_allowed_select_functions + .iter() + .map(|s| s.as_str()) + .collect::>(); + + // extract the jwt claims (we'll need them later to set the role and env) + let jwt_claims = match jwt_parsed.keys { + ComputeCredentialKeys::JwtPayload(payload_bytes) => { + // `payload_bytes` contains the raw JWT payload as Vec + // You can deserialize it back to JSON or parse specific claims + let payload: serde_json::Value = serde_json::from_slice(&payload_bytes) + .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?; + Some(payload) + } + ComputeCredentialKeys::AuthKeys(_) => None, + }; + + // read the role from the jwt claims (and set it to the "anon" role if not present) + let (role, authenticated) = match &jwt_claims { + Some(claims) => match claims.get("role") { + Some(JsonValue::String(r)) => (Some(r), true), + _ => (db_anon_role.as_ref(), true), + }, + None => (db_anon_role.as_ref(), false), + }; + + // do not allow unauthenticated requests when there is no anonymous role setup + if let 
(None, false) = (role, authenticated) { + return Err(RestError::SubzeroCore(JwtTokenInvalid { + message: "unauthenticated requests not allowed".to_string(), + })); + } + + // start deconstructing the request because subzero core mostly works with &str + let method = parts.method; + let method_str = method.as_str(); + let path = parts.uri.path_and_query().map_or("/", |pq| pq.as_str()); + + // this is actually the table name (or rpc/function_name) + // TODO: rename this to something more descriptive + let root = match parts.uri.path().strip_prefix(api_prefix) { + Some(p) => Ok(p), + None => Err(RestError::SubzeroCore(NotFound { + target: parts.uri.path().to_string(), + })), + }?; + + // pick the current schema from the headers (or the first one from config) + let schema_name = &DbSchema::pick_current_schema(db_schemas, method_str, &parts.headers)?; + + // add the content-profile header to the response + let mut response_headers = vec![]; + if db_schemas.len() > 1 { + response_headers.push(("Content-Profile".to_string(), schema_name.clone())); + } + + // parse the query string into a Vec<(&str, &str)> + let query = match parts.uri.query() { + Some(q) => form_urlencoded::parse(q.as_bytes()).collect(), + None => vec![], + }; + let get: Vec<(&str, &str)> = query.iter().map(|(k, v)| (&**k, &**v)).collect(); + + // convert the headers map to a HashMap<&str, &str> + let headers: HashMap<&str, &str> = parts + .headers + .iter() + .map(|(k, v)| (k.as_str(), v.to_str().unwrap_or("__BAD_HEADER__"))) + .collect(); + + let cookies = HashMap::new(); // TODO: add cookies + + // Read the request body (skip for GET requests) + let body_as_string: Option = if method == Method::GET { + None + } else { + let body_bytes = + read_body_with_limit(originial_body, config.http_config.max_request_size_bytes) + .await + .map_err(ReadPayloadError::from)?; + if body_bytes.is_empty() { + None + } else { + Some(String::from_utf8_lossy(&body_bytes).into_owned()) + } + }; + + // parse the request 
into an ApiRequest struct + let mut api_request = parse( + schema_name, + root, + db_schema, + method_str, + path, + get, + body_as_string.as_deref(), + headers, + cookies, + max_rows, + ) + .map_err(RestError::SubzeroCore)?; + + let role_str = match role { + Some(r) => r, + None => "", + }; + + replace_select_star(db_schema, schema_name, role_str, &mut api_request.query)?; + + // TODO: this is not relevant when acting as PostgREST but will be useful + // in the context of DBX where they need internal permissions + // if !disable_internal_permissions { + // check_privileges(db_schema, schema_name, role_str, &api_request)?; + // } + + check_safe_functions(&api_request, &db_allowed_select_functions)?; + + // TODO: this is not relevant when acting as PostgREST but will be useful + // in the context of DBX where they need internal permissions + // if !disable_internal_permissions { + // insert_policy_conditions(db_schema, schema_name, role_str, &mut api_request.query)?; + // } + + let env_role = Some(role_str); + + // construct the env (passed in to the sql context as GUCs) + let empty_json = "{}".to_string(); + let headers_env = serde_json::to_string(&api_request.headers).unwrap_or(empty_json.clone()); + let cookies_env = serde_json::to_string(&api_request.cookies).unwrap_or(empty_json.clone()); + let get_env = serde_json::to_string(&api_request.get).unwrap_or(empty_json.clone()); + let jwt_claims_env = jwt_claims + .as_ref() + .map(|v| serde_json::to_string(v).unwrap_or(empty_json.clone())) + .unwrap_or(if let Some(r) = env_role { + let claims: HashMap<&str, &str> = HashMap::from([("role", r)]); + serde_json::to_string(&claims).unwrap_or(empty_json.clone()) + } else { + empty_json.clone() + }); + let mut search_path = vec![api_request.schema_name]; + if let Some(extra) = &db_extra_search_path { + search_path.extend(extra.iter().map(|s| s.as_str())); + } + let search_path_str = search_path + .into_iter() + .filter(|s| !s.is_empty()) + .collect::>() + .join(","); + let 
mut env: HashMap<&str, &str> = HashMap::from([ + ("request.method", api_request.method), + ("request.path", api_request.path), + ("request.headers", &headers_env), + ("request.cookies", &cookies_env), + ("request.get", &get_env), + ("request.jwt.claims", &jwt_claims_env), + ("search_path", &search_path_str), + ]); + if let Some(r) = env_role { + env.insert("role", r); + } + + // generate the sql statements + let (env_statement, env_parameters, _) = generate(fmt_env_query(&env)); + let (main_statement, main_parameters, _) = generate(fmt_main_query( + db_schema, + api_request.schema_name, + &api_request, + &env, + )?); + + let mut headers = vec![ + (&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id())), + ( + &CONN_STRING, + HeaderValue::from_str(connection_string).expect("invalid connection string"), + ), + (&AUTHORIZATION, auth_header.clone()), + ( + &TXN_ISOLATION_LEVEL, + HeaderValue::from_static("ReadCommitted"), + ), + (&ALLOW_POOL, HEADER_VALUE_TRUE), + ]; + + if api_request.read_only { + headers.push((&TXN_READ_ONLY, HEADER_VALUE_TRUE)); + } + + // convert the parameters from subzero core representation to the local proxy repr. + let req_body = serde_json::to_string(&BatchQueryData { + queries: vec![ + QueryData { + query: env_statement.into(), + params: env_parameters + .iter() + .map(|p| to_sql_param(&p.to_param())) + .collect(), + }, + QueryData { + query: main_statement.into(), + params: main_parameters + .iter() + .map(|p| to_sql_param(&p.to_param())) + .collect(), + }, + ], + }) + .map_err(|e| RestError::JsonConversion(JsonConversionError::ParseJsonError(e)))?; + + // todo: map body to count egress + let _metrics = client.metrics(ctx); // FIXME: is everything in the context set correctly? 
+ + // send the request to the local proxy + let response = make_raw_local_proxy_request(&mut client, headers, req_body).await?; + let (parts, body) = response.into_parts(); + + let max_response = config.http_config.max_response_size_bytes; + let bytes = read_body_with_limit(body, max_response) + .await + .map_err(ReadPayloadError::from)?; + + // if the response status is greater than 399, then it is an error + // FIXME: check if there are other error codes or shapes of the response + if parts.status.as_u16() > 399 { + // turn this postgres error from the json into PostgresError + let postgres_error = serde_json::from_slice(&bytes) + .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?; + + return Err(RestError::Postgres(postgres_error)); + } + + #[derive(Deserialize)] + struct QueryResults { + /// we run two queries, so we want only two results. + results: (EnvRows, MainRows), + } + + /// `env_statement` returns nothing of interest to us + #[derive(Deserialize)] + struct EnvRows {} + + #[derive(Deserialize)] + struct MainRows { + /// `main_statement` only returns a single row. 
+ rows: [MainRow; 1], + } + + #[derive(Deserialize)] + struct MainRow { + body: String, + page_total: Option, + total_result_set: Option, + response_headers: Option, + response_status: Option, + } + + let results: QueryResults = serde_json::from_slice(&bytes) + .map_err(|e| RestError::SubzeroCore(JsonDeserialize { source: e }))?; + + let QueryResults { + results: (_, MainRows { rows: [row] }), + } = results; + + // build the intermediate response object + let api_response = ApiResponse { + page_total: row.page_total.map_or(0, |v| v.parse::().unwrap_or(0)), + total_result_set: row.total_result_set.map(|v| v.parse::().unwrap_or(0)), + top_level_offset: 0, // FIXME: check why this is 0 + response_headers: row.response_headers, + response_status: row.response_status, + body: row.body, + }; + + // TODO: rollback the transaction if the page_total is not 1 and the accept_content_type is SingularJSON + // we can not do this in the context of proxy for now + // if api_request.accept_content_type == SingularJSON && api_response.page_total != 1 { + // // rollback the transaction here + // return Err(RestError::SubzeroCore(SingularityError { + // count: api_response.page_total, + // content_type: "application/vnd.pgrst.object+json".to_string(), + // })); + // } + + // TODO: rollback the transaction if the page_total is not 1 and the method is PUT + // we can not do this in the context of proxy for now + // if api_request.method == Method::PUT && api_response.page_total != 1 { + // // Makes sure the querystring pk matches the payload pk + // // e.g. PUT /items?id=eq.1 { "id" : 1, .. } is accepted, + // // PUT /items?id=eq.14 { "id" : 2, .. } is rejected. 
+ // // If this condition is not satisfied then nothing is inserted, + // // rollback the transaction here + // return Err(RestError::SubzeroCore(PutMatchingPkError)); + // } + + // create and return the response to the client + // this section mostly deals with setting the right headers according to PostgREST specs + let page_total = api_response.page_total; + let total_result_set = api_response.total_result_set; + let top_level_offset = api_response.top_level_offset; + let response_content_type = match (&api_request.accept_content_type, &api_request.query.node) { + (SingularJSON, _) + | ( + _, + FunctionCall { + returns_single: true, + is_scalar: false, + .. + }, + ) => SingularJSON, + (TextCSV, _) => TextCSV, + _ => ApplicationJSON, + }; + + // check if the SQL env set some response headers (happens when we called a rpc function) + if let Some(response_headers_str) = api_response.response_headers { + let Ok(headers_json) = + serde_json::from_str::>>(response_headers_str.as_str()) + else { + return Err(RestError::SubzeroCore(GucHeadersError)); + }; + + response_headers.extend(headers_json.into_iter().flatten()); + } + + // calculate and set the content range header + let lower = top_level_offset as i64; + let upper = top_level_offset as i64 + page_total as i64 - 1; + let total = total_result_set.map(|t| t as i64); + let content_range = match (&method, &api_request.query.node) { + (&Method::POST, Insert { .. }) => content_range_header(1, 0, total), + (&Method::DELETE, Delete { .. }) => content_range_header(1, upper, total), + _ => content_range_header(lower, upper, total), + }; + response_headers.push(("Content-Range".to_string(), content_range)); + + // calculate the status code + #[rustfmt::skip] + let mut status = match (&method, &api_request.query.node, page_total, &api_request.preferences) { + (&Method::POST, Insert { .. }, ..) => 201, + (&Method::DELETE, Delete { .. 
}, _, Some(Preferences {representation: Some(Representation::Full),..}),) => 200, + (&Method::DELETE, Delete { .. }, ..) => 204, + (&Method::PATCH, Update { columns, .. }, 0, _) if !columns.is_empty() => 404, + (&Method::PATCH, Update { .. }, _,Some(Preferences {representation: Some(Representation::Full),..}),) => 200, + (&Method::PATCH, Update { .. }, ..) => 204, + (&Method::PUT, Insert { .. },_,Some(Preferences {representation: Some(Representation::Full),..}),) => 200, + (&Method::PUT, Insert { .. }, ..) => 204, + _ => content_range_status(lower, upper, total), + }; + + // add the preference-applied header + if let Some(Preferences { + resolution: Some(r), + .. + }) = api_request.preferences + { + response_headers.push(( + "Preference-Applied".to_string(), + match r { + MergeDuplicates => "resolution=merge-duplicates".to_string(), + IgnoreDuplicates => "resolution=ignore-duplicates".to_string(), + }, + )); + } + + // check if the SQL env set some response status (happens when we called a rpc function) + if let Some(response_status_str) = api_response.response_status { + status = response_status_str + .parse::() + .map_err(|_| RestError::SubzeroCore(GucStatusError))?; + } + + // set the content type header + // TODO: move this to a subzero function + // as_header_value(&self) -> Option<&str> + let http_content_type = match response_content_type { + SingularJSON => Ok("application/vnd.pgrst.object+json"), + TextCSV => Ok("text/csv"), + ApplicationJSON => Ok("application/json"), + Other(t) => Err(RestError::SubzeroCore(ContentTypeError { + message: format!("None of these Content-Types are available: {t}"), + })), + }?; + + // build the response body + let response_body = Full::new(Bytes::from(api_response.body)) + .map_err(|never| match never {}) + .boxed(); + + // build the response + let mut response = Response::builder() + .status(StatusCode::from_u16(status).unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)) + .header(CONTENT_TYPE, http_content_type); + + // Add all 
headers from response_headers vector + for (header_name, header_value) in response_headers { + response = response.header(header_name, header_value); + } + + // add the body and return the response + response.body(response_body).map_err(|_| { + RestError::SubzeroCore(InternalError { + message: "Failed to build response".to_string(), + }) + }) +} diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 8a14f804b6..f254b41b5b 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -64,7 +64,7 @@ enum Payload { Batch(BatchQueryData), } -static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true"); +pub(super) const HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true"); fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result>, D::Error> where diff --git a/proxy/src/util.rs b/proxy/src/util.rs index 0291216d94..c89ebab008 100644 --- a/proxy/src/util.rs +++ b/proxy/src/util.rs @@ -20,3 +20,13 @@ pub async fn run_until( Either::Right((f2, _)) => Err(f2), } } + +pub fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result +where + T: for<'de2> serde::Deserialize<'de2>, + D: serde::Deserializer<'de>, +{ + use serde::Deserialize; + let s = String::deserialize(deserializer)?; + serde_json::from_str(&s).map_err(::custom) +} diff --git a/proxy/subzero_core/.gitignore b/proxy/subzero_core/.gitignore new file mode 100644 index 0000000000..f2f9e58ec3 --- /dev/null +++ b/proxy/subzero_core/.gitignore @@ -0,0 +1,2 @@ +target +Cargo.lock \ No newline at end of file diff --git a/proxy/subzero_core/Cargo.toml b/proxy/subzero_core/Cargo.toml new file mode 100644 index 0000000000..13185873d0 --- /dev/null +++ b/proxy/subzero_core/Cargo.toml @@ -0,0 +1,12 @@ +# This is a stub for the subzero-core crate. +[package] +name = "subzero-core" +version = "3.0.1" +edition = "2024" +publish = false # "private"! 
+ +[features] +default = [] +postgresql = [] + +[dependencies] diff --git a/proxy/subzero_core/src/lib.rs b/proxy/subzero_core/src/lib.rs new file mode 100644 index 0000000000..b99246b98b --- /dev/null +++ b/proxy/subzero_core/src/lib.rs @@ -0,0 +1 @@ +// This is a stub for the subzero-core crate. diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 86ffa9e4d4..1ce34a2c4e 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -4121,6 +4121,294 @@ class NeonAuthBroker: self._popen.kill() +class NeonLocalProxy(LogUtils): + """ + An object managing a local_proxy instance for rest broker testing. + The local_proxy serves as a direct connection to VanillaPostgres. + """ + + def __init__( + self, + neon_binpath: Path, + test_output_dir: Path, + http_port: int, + metrics_port: int, + vanilla_pg: VanillaPostgres, + config_path: Path | None = None, + ): + self.neon_binpath = neon_binpath + self.test_output_dir = test_output_dir + self.http_port = http_port + self.metrics_port = metrics_port + self.vanilla_pg = vanilla_pg + self.config_path = config_path or (test_output_dir / "local_proxy.json") + self.host = "127.0.0.1" + self.running = False + self.logfile = test_output_dir / "local_proxy.log" + self._popen: subprocess.Popen[bytes] | None = None + super().__init__(logfile=self.logfile) + + def start(self) -> Self: + assert self._popen is None + assert not self.running + + # Ensure vanilla_pg is running + if not self.vanilla_pg.is_running(): + self.vanilla_pg.start() + + args = [ + str(self.neon_binpath / "local_proxy"), + "--http", + f"{self.host}:{self.http_port}", + "--metrics", + f"{self.host}:{self.metrics_port}", + "--postgres", + f"127.0.0.1:{self.vanilla_pg.default_options['port']}", + "--config-path", + str(self.config_path), + "--disable-pg-session-jwt", + ] + + logfile = open(self.logfile, "w") + self._popen = subprocess.Popen(args, stdout=logfile, stderr=logfile) + 
self.running = True + self._wait_until_ready() + return self + + def stop(self) -> Self: + if self._popen is not None and self.running: + self._popen.terminate() + try: + self._popen.wait(timeout=5) + except subprocess.TimeoutExpired: + log.warning("failed to gracefully terminate local_proxy; killing") + self._popen.kill() + self.running = False + return self + + def get_binary_version(self) -> str: + """Get the version string of the local_proxy binary""" + try: + result = subprocess.run( + [str(self.neon_binpath / "local_proxy"), "--version"], + capture_output=True, + text=True, + timeout=10, + ) + return result.stdout.strip() + except (subprocess.TimeoutExpired, subprocess.CalledProcessError): + return "" + + @backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=10) + def _wait_until_ready(self): + assert self._popen and self._popen.poll() is None, ( + "Local proxy exited unexpectedly. Check test log." + ) + requests.get(f"http://{self.host}:{self.http_port}/metrics") + + def get_metrics(self) -> str: + response = requests.get(f"http://{self.host}:{self.metrics_port}/metrics") + return response.text + + def assert_no_errors(self): + # Define allowed error patterns for local_proxy + allowed_errors = [ + # Add patterns as needed + ] + not_allowed = [ + "error", + "panic", + "failed", + ] + + for na in not_allowed: + if na not in allowed_errors: + assert not self.log_contains(na), f"Found disallowed error pattern: {na}" + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ): + self.stop() + + +class NeonRestBrokerProxy(LogUtils): + """ + An object managing a proxy instance configured as both auth broker and rest broker. + This is the main proxy binary with --is-auth-broker and --is-rest-broker flags. 
+ """ + + def __init__( + self, + neon_binpath: Path, + test_output_dir: Path, + wss_port: int, + http_port: int, + mgmt_port: int, + config_path: Path | None = None, + ): + self.neon_binpath = neon_binpath + self.test_output_dir = test_output_dir + self.wss_port = wss_port + self.http_port = http_port + self.mgmt_port = mgmt_port + self.config_path = config_path or (test_output_dir / "rest_broker_proxy.json") + self.host = "127.0.0.1" + self.running = False + self.logfile = test_output_dir / "rest_broker_proxy.log" + self._popen: subprocess.Popen[Any] | None = None + + def start(self) -> Self: + if self.running: + return self + + # Generate self-signed TLS certificates + cert_path = self.test_output_dir / "server.crt" + key_path = self.test_output_dir / "server.key" + + if not cert_path.exists() or not key_path.exists(): + import subprocess + + log.info("Generating self-signed TLS certificate for rest broker") + subprocess.run( + [ + "openssl", + "req", + "-new", + "-x509", + "-days", + "365", + "-nodes", + "-text", + "-out", + str(cert_path), + "-keyout", + str(key_path), + "-subj", + "/CN=*.local.neon.build", + ], + check=True, + ) + + log.info( + f"Starting rest broker proxy on WSS port {self.wss_port}, HTTP port {self.http_port}" + ) + + cmd = [ + str(self.neon_binpath / "proxy"), + "-c", + str(cert_path), + "-k", + str(key_path), + "--is-auth-broker", + "true", + "--is-rest-broker", + "true", + "--wss", + f"{self.host}:{self.wss_port}", + "--http", + f"{self.host}:{self.http_port}", + "--mgmt", + f"{self.host}:{self.mgmt_port}", + "--auth-backend", + "local", + "--config-path", + str(self.config_path), + ] + + log.info(f"Starting rest broker proxy with command: {' '.join(cmd)}") + + with open(self.logfile, "w") as logfile: + self._popen = subprocess.Popen( + cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + cwd=self.test_output_dir, + env={ + **os.environ, + "RUST_LOG": "info", + "LOGFMT": "text", + "OTEL_SDK_DISABLED": "true", + }, + ) + + self.running = 
True + self._wait_until_ready() + return self + + def stop(self) -> Self: + if not self.running: + return self + + log.info("Stopping rest broker proxy") + + if self._popen is not None: + self._popen.terminate() + try: + self._popen.wait(timeout=10) + except subprocess.TimeoutExpired: + log.warning("failed to gracefully terminate rest broker proxy; killing") + self._popen.kill() + + self.running = False + return self + + def get_binary_version(self) -> str: + cmd = [str(self.neon_binpath / "proxy"), "--version"] + res = subprocess.run(cmd, capture_output=True, text=True, check=True) + return res.stdout.strip() + + @backoff.on_exception(backoff.expo, requests.exceptions.RequestException, max_time=10) + def _wait_until_ready(self): + # Check if the WSS port is ready using a simple HTTPS request + # REST API is served on the WSS port with HTTPS + requests.get(f"https://{self.host}:{self.wss_port}/", timeout=1, verify=False) + # Any response (even error) means the server is up - we just need to connect + + def get_metrics(self) -> str: + # Metrics are still on the HTTP port + response = requests.get(f"http://{self.host}:{self.http_port}/metrics", timeout=5) + response.raise_for_status() + return response.text + + def assert_no_errors(self): + # Define allowed error patterns for rest broker proxy + allowed_errors = [ + "connection closed before message completed", + "connection reset by peer", + "broken pipe", + "client disconnected", + "Authentication failed", + "connection timed out", + "no connection available", + "Pool dropped", + ] + + with open(self.logfile) as f: + for line in f: + if "ERROR" in line or "FATAL" in line: + if not any(allowed in line for allowed in allowed_errors): + raise AssertionError( + f"Found error in rest broker proxy log: {line.strip()}" + ) + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ): + self.stop() 
+ + @pytest.fixture(scope="function") def link_proxy( port_distributor: PortDistributor, neon_binpath: Path, test_output_dir: Path @@ -4203,6 +4491,81 @@ def static_proxy( yield proxy +@pytest.fixture(scope="function") +def local_proxy( + vanilla_pg: VanillaPostgres, + port_distributor: PortDistributor, + neon_binpath: Path, + test_output_dir: Path, +) -> Iterator[NeonLocalProxy]: + """Local proxy that connects directly to vanilla postgres for rest broker testing.""" + + # Start vanilla_pg without database bootstrapping + vanilla_pg.start() + + http_port = port_distributor.get_port() + metrics_port = port_distributor.get_port() + + with NeonLocalProxy( + neon_binpath=neon_binpath, + test_output_dir=test_output_dir, + http_port=http_port, + metrics_port=metrics_port, + vanilla_pg=vanilla_pg, + ) as proxy: + proxy.start() + yield proxy + + +@pytest.fixture(scope="function") +def local_proxy_fixed_port( + vanilla_pg: VanillaPostgres, + neon_binpath: Path, + test_output_dir: Path, +) -> Iterator[NeonLocalProxy]: + """Local proxy that connects directly to vanilla postgres on the hardcoded port 7432.""" + + # Start vanilla_pg without database bootstrapping + vanilla_pg.start() + + # Use the hardcoded port that the rest broker proxy expects + http_port = 7432 + metrics_port = 7433 # Use a different port for metrics + + with NeonLocalProxy( + neon_binpath=neon_binpath, + test_output_dir=test_output_dir, + http_port=http_port, + metrics_port=metrics_port, + vanilla_pg=vanilla_pg, + ) as proxy: + proxy.start() + yield proxy + + +@pytest.fixture(scope="function") +def rest_broker_proxy( + port_distributor: PortDistributor, + neon_binpath: Path, + test_output_dir: Path, +) -> Iterator[NeonRestBrokerProxy]: + """Rest broker proxy that handles both auth broker and rest broker functionality.""" + + wss_port = port_distributor.get_port() + http_port = port_distributor.get_port() + mgmt_port = port_distributor.get_port() + + with NeonRestBrokerProxy( + neon_binpath=neon_binpath, + 
test_output_dir=test_output_dir, + wss_port=wss_port, + http_port=http_port, + mgmt_port=mgmt_port, + ) as proxy: + proxy.start() + yield proxy + + @pytest.fixture(scope="function") def neon_authorize_jwk() -> jwk.JWK: kid = str(uuid.uuid4()) diff --git a/test_runner/fixtures/utils.py b/test_runner/fixtures/utils.py index 0d7345cc82..1f80c2a290 100644 --- a/test_runner/fixtures/utils.py +++ b/test_runner/fixtures/utils.py @@ -741,3 +741,29 @@ def shared_buffers_for_max_cu(max_cu: float) -> str: sharedBuffersMb = int(max(128, (1023 + maxBackends * 256) / 1024)) sharedBuffers = int(sharedBuffersMb * 1024 / 8) return str(sharedBuffers) + + +def skip_if_proxy_lacks_rest_broker(reason: str = "proxy was built without 'rest_broker' feature"): + # Determine the binary path using the same logic as neon_binpath fixture + def has_rest_broker_feature(): + # Find the neon binaries + if env_neon_bin := os.environ.get("NEON_BIN"): + binpath = Path(env_neon_bin) + else: + base_dir = Path(__file__).parents[2] # Same as BASE_DIR in paths.py + build_type = os.environ.get("BUILD_TYPE", "debug") + binpath = base_dir / "target" / build_type + + proxy_bin = binpath / "proxy" + if not proxy_bin.exists(): + return False + + try: + cmd = [str(proxy_bin), "--help"] + result = subprocess.run(cmd, capture_output=True, text=True, check=True, timeout=10) + help_output = result.stdout + return "--is-rest-broker" in help_output + except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError): + return False + + return pytest.mark.skipif(not has_rest_broker_feature(), reason=reason) diff --git a/test_runner/regress/test_rest_broker.py b/test_runner/regress/test_rest_broker.py new file mode 100644 index 0000000000..60b04655d3 --- /dev/null +++ b/test_runner/regress/test_rest_broker.py @@ -0,0 +1,137 @@ +import json +import signal +import time + +import requests +from fixtures.utils import skip_if_proxy_lacks_rest_broker +from jwcrypto import jwt + + 
+@skip_if_proxy_lacks_rest_broker() +def test_rest_broker_happy( + local_proxy_fixed_port, rest_broker_proxy, vanilla_pg, neon_authorize_jwk, httpserver +): + """Test REST API endpoint using local_proxy and rest_broker_proxy.""" + + # Use the fixed port local proxy + local_proxy = local_proxy_fixed_port + + # Create the required roles for PostgREST authentication + vanilla_pg.safe_psql("CREATE ROLE authenticator LOGIN") + vanilla_pg.safe_psql("CREATE ROLE authenticated") + vanilla_pg.safe_psql("CREATE ROLE anon") + vanilla_pg.safe_psql("GRANT authenticated TO authenticator") + vanilla_pg.safe_psql("GRANT anon TO authenticator") + + # Create the pgrst schema and configuration function required by the rest broker + vanilla_pg.safe_psql("CREATE SCHEMA IF NOT EXISTS pgrst") + vanilla_pg.safe_psql(""" + CREATE OR REPLACE FUNCTION pgrst.pre_config() + RETURNS VOID AS $$ + SELECT + set_config('pgrst.db_schemas', 'test', true) + , set_config('pgrst.db_aggregates_enabled', 'true', true) + , set_config('pgrst.db_anon_role', 'anon', true) + , set_config('pgrst.jwt_aud', '', true) + , set_config('pgrst.jwt_secret', '', true) + , set_config('pgrst.jwt_role_claim_key', '."role"', true) + + $$ LANGUAGE SQL; + """) + vanilla_pg.safe_psql("GRANT USAGE ON SCHEMA pgrst TO authenticator") + vanilla_pg.safe_psql("GRANT EXECUTE ON ALL FUNCTIONS IN SCHEMA pgrst TO authenticator") + + # Bootstrap the database with test data + vanilla_pg.safe_psql("CREATE SCHEMA IF NOT EXISTS test") + vanilla_pg.safe_psql(""" + CREATE TABLE IF NOT EXISTS test.items ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL + ) + """) + vanilla_pg.safe_psql("INSERT INTO test.items (name) VALUES ('test_item')") + + # Grant access to the test schema for the authenticated role + vanilla_pg.safe_psql("GRANT USAGE ON SCHEMA test TO authenticated") + vanilla_pg.safe_psql("GRANT SELECT ON ALL TABLES IN SCHEMA test TO authenticated") + + # Set up HTTP server to serve JWKS (like static_auth_broker) + # Generate public key from 
the JWK + public_key = neon_authorize_jwk.export_public(as_dict=True) + + # Set up the httpserver to serve the JWKS + httpserver.expect_request("/.well-known/jwks.json").respond_with_json({"keys": [public_key]}) + + # Create JWKS configuration for the rest broker proxy + jwks_config = { + "jwks": [ + { + "id": "1", + "role_names": ["authenticator", "authenticated", "anon"], + "jwks_url": httpserver.url_for("/.well-known/jwks.json"), + "provider_name": "foo", + "jwt_audience": None, + } + ] + } + + # Write the JWKS config to the config file that rest_broker_proxy expects + config_file = rest_broker_proxy.config_path + with open(config_file, "w") as f: + json.dump(jwks_config, f) + + # Write the same config to the local_proxy config file + local_config_file = local_proxy.config_path + with open(local_config_file, "w") as f: + json.dump(jwks_config, f) + + # Signal both proxies to reload their config + if rest_broker_proxy._popen is not None: + rest_broker_proxy._popen.send_signal(signal.SIGHUP) + if local_proxy._popen is not None: + local_proxy._popen.send_signal(signal.SIGHUP) + # Wait a bit for config to reload + time.sleep(0.5) + + # Generate a proper JWT token using the JWK (similar to test_auth_broker.py) + token = jwt.JWT( + header={"kid": neon_authorize_jwk.key_id, "alg": "RS256"}, + claims={ + "sub": "user", + "role": "authenticated", # role that's in role_names + "exp": 9999999999, # expires far in the future + "iat": 1000000000, # issued at + }, + ) + token.make_signed_token(neon_authorize_jwk) + + # Debug: Print the JWT claims and config for troubleshooting + print(f"JWT claims: {token.claims}") + print(f"JWT header: {token.header}") + print(f"Config file contains: {jwks_config}") + print(f"Public key kid: {public_key.get('kid')}") + + # Test REST API call - following SUBZERO.md pattern + # REST API is served on the WSS port with HTTPS and includes database name + # ep-purple-glitter-adqior4l-pooler.c-2.us-east-1.aws.neon.tech + url = 
f"https://foo.apirest.c-2.local.neon.build:{rest_broker_proxy.wss_port}/postgres/rest/v1/items" + + response = requests.get( + url, + headers={ + "Authorization": f"Bearer {token.serialize()}", + }, + params={"id": "eq.1", "select": "name"}, + verify=False, # Skip SSL verification for self-signed certs + ) + + print(f"Response status: {response.status_code}") + print(f"Response headers: {response.headers}") + print(f"Response body: {response.text}") + + # For now, let's just check that we get some response + # We can refine the assertions once we see what the actual response looks like + assert response.status_code in [200] # Any response means the proxies are working + + # check the response body + assert response.json() == [{"name": "test_item"}]