diff --git a/.config/hakari.toml b/.config/hakari.toml index b5990d090e..3b6d9d8822 100644 --- a/.config/hakari.toml +++ b/.config/hakari.toml @@ -46,6 +46,9 @@ workspace-members = [ "utils", "wal_craft", "walproposer", + "postgres-protocol2", + "postgres-types2", + "tokio-postgres2", ] # Write out exact versions rather than a semver range. (Defaults to false.) diff --git a/.github/actions/run-python-test-set/action.yml b/.github/actions/run-python-test-set/action.yml index 275f161019..1159627302 100644 --- a/.github/actions/run-python-test-set/action.yml +++ b/.github/actions/run-python-test-set/action.yml @@ -36,8 +36,8 @@ inputs: description: 'Region name for real s3 tests' required: false default: '' - rerun_flaky: - description: 'Whether to rerun flaky tests' + rerun_failed: + description: 'Whether to rerun failed tests' required: false default: 'false' pg_version: @@ -108,7 +108,7 @@ runs: COMPATIBILITY_SNAPSHOT_DIR: /tmp/compatibility_snapshot_pg${{ inputs.pg_version }} ALLOW_BACKWARD_COMPATIBILITY_BREAKAGE: contains(github.event.pull_request.labels.*.name, 'backward compatibility breakage') ALLOW_FORWARD_COMPATIBILITY_BREAKAGE: contains(github.event.pull_request.labels.*.name, 'forward compatibility breakage') - RERUN_FLAKY: ${{ inputs.rerun_flaky }} + RERUN_FAILED: ${{ inputs.rerun_failed }} PG_VERSION: ${{ inputs.pg_version }} shell: bash -euxo pipefail {0} run: | @@ -154,15 +154,8 @@ runs: EXTRA_PARAMS="--out-dir $PERF_REPORT_DIR $EXTRA_PARAMS" fi - if [ "${RERUN_FLAKY}" == "true" ]; then - mkdir -p $TEST_OUTPUT - poetry run ./scripts/flaky_tests.py "${TEST_RESULT_CONNSTR}" \ - --days 7 \ - --output "$TEST_OUTPUT/flaky.json" \ - --pg-version "${DEFAULT_PG_VERSION}" \ - --build-type "${BUILD_TYPE}" - - EXTRA_PARAMS="--flaky-tests-json $TEST_OUTPUT/flaky.json $EXTRA_PARAMS" + if [ "${RERUN_FAILED}" == "true" ]; then + EXTRA_PARAMS="--reruns 2 $EXTRA_PARAMS" fi # We use pytest-split plugin to run benchmarks in parallel on different CI runners diff --git a/.github/workflows/_build-and-test-locally.yml b/.github/workflows/_build-and-test-locally.yml index 8e28049888..42c32a23e3 100644 --- a/.github/workflows/_build-and-test-locally.yml +++ b/.github/workflows/_build-and-test-locally.yml @@ -19,8 +19,8 @@ on: description: 'debug or release' required: true type: string - pg-versions: - description: 'a json array of postgres versions to run regression tests on' + test-cfg: + description: 'a json object of postgres versions and lfc states to run regression tests on' required: true type: string @@ -276,14 +276,14 @@ jobs: options: --init --shm-size=512mb --ulimit memlock=67108864:67108864 strategy: fail-fast: false - matrix: - pg_version: ${{ fromJson(inputs.pg-versions) }} + matrix: ${{ fromJSON(format('{{"include":{0}}}', inputs.test-cfg)) }} steps: - uses: actions/checkout@v4 with: submodules: true - name: Pytest regression tests + continue-on-error: ${{ matrix.lfc_state == 'with-lfc' }} uses: ./.github/actions/run-python-test-set timeout-minutes: 60 with: @@ -293,13 +293,14 @@ jobs: run_with_real_s3: true real_s3_bucket: neon-github-ci-tests real_s3_region: eu-central-1 - rerun_flaky: true + rerun_failed: true pg_version: ${{ matrix.pg_version }} env: TEST_RESULT_CONNSTR: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }} CHECK_ONDISK_DATA_COMPATIBILITY: nonempty BUILD_TAG: ${{ inputs.build-tag }} PAGESERVER_VIRTUAL_FILE_IO_ENGINE: tokio-epoll-uring + USE_LFC: ${{ matrix.lfc_state == 'with-lfc' && 'true' || 'false' }} # Temporary disable this step until we figure out why it's so flaky # Ref https://github.com/neondatabase/neon/issues/4540 diff --git a/.github/workflows/benchmarking.yml b/.github/workflows/benchmarking.yml index acea859b4d..ea8fee80c2 100644 --- a/.github/workflows/benchmarking.yml +++ b/.github/workflows/benchmarking.yml @@ -541,7 +541,7 @@ jobs: runs-on: ${{ matrix.RUNNER }} container: - image: neondatabase/build-tools:pinned + image: neondatabase/build-tools:pinned-bookworm credentials: username: ${{ secrets.NEON_DOCKERHUB_USERNAME }} password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }} @@ -558,12 +558,12 @@ jobs: arch=$(uname -m | sed 's/x86_64/amd64/g' | sed 's/aarch64/arm64/g') cd /home/nonroot - wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-17/libpq5_17.1-1.pgdg110+1_${arch}.deb" - wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-client-16_16.5-1.pgdg110+1_${arch}.deb" - wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-16_16.5-1.pgdg110+1_${arch}.deb" - dpkg -x libpq5_17.1-1.pgdg110+1_${arch}.deb pg - dpkg -x postgresql-16_16.5-1.pgdg110+1_${arch}.deb pg - dpkg -x postgresql-client-16_16.5-1.pgdg110+1_${arch}.deb pg + wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-17/libpq5_17.2-1.pgdg120+1_${arch}.deb" + wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-client-16_16.6-1.pgdg120+1_${arch}.deb" + wget -q "https://apt.postgresql.org/pub/repos/apt/pool/main/p/postgresql-16/postgresql-16_16.6-1.pgdg120+1_${arch}.deb" + dpkg -x libpq5_17.2-1.pgdg120+1_${arch}.deb pg + dpkg -x postgresql-16_16.6-1.pgdg120+1_${arch}.deb pg + dpkg -x postgresql-client-16_16.6-1.pgdg120+1_${arch}.deb pg mkdir -p /tmp/neon/pg_install/v16/bin ln -s /home/nonroot/pg/usr/lib/postgresql/16/bin/pgbench /tmp/neon/pg_install/v16/bin/pgbench diff --git a/.github/workflows/build-build-tools-image.yml b/.github/workflows/build-build-tools-image.yml index 9e7be76901..0a7f0cd7a0 100644 --- a/.github/workflows/build-build-tools-image.yml +++ b/.github/workflows/build-build-tools-image.yml @@ -2,6 +2,17 @@ name: Build build-tools image on: workflow_call: + inputs: + archs: + description: "Json array of architectures to build" + # Default values are set in `check-image` job, `set-variables` step + type: string + required: false + debians: + description: "Json array of Debian versions to build" + # Default values are set in `check-image` job, `set-variables` step + type: string + required: false outputs: image-tag: description: "build-tools tag" @@ -32,25 +43,37 @@ jobs: check-image: runs-on: ubuntu-22.04 outputs: - tag: ${{ steps.get-build-tools-tag.outputs.image-tag }} - found: ${{ steps.check-image.outputs.found }} + archs: ${{ steps.set-variables.outputs.archs }} + debians: ${{ steps.set-variables.outputs.debians }} + tag: ${{ steps.set-variables.outputs.image-tag }} + everything: ${{ steps.set-more-variables.outputs.everything }} + found: ${{ steps.set-more-variables.outputs.found }} steps: - uses: actions/checkout@v4 - - name: Get build-tools image tag for the current commit - id: get-build-tools-tag + - name: Set variables + id: set-variables env: + ARCHS: ${{ inputs.archs || '["x64","arm64"]' }} + DEBIANS: ${{ inputs.debians || '["bullseye","bookworm"]' }} IMAGE_TAG: | ${{ hashFiles('build-tools.Dockerfile', '.github/workflows/build-build-tools-image.yml') }} run: | - echo "image-tag=${IMAGE_TAG}" | tee -a $GITHUB_OUTPUT + echo "archs=${ARCHS}" | tee -a ${GITHUB_OUTPUT} + echo "debians=${DEBIANS}" | tee -a ${GITHUB_OUTPUT} + echo "image-tag=${IMAGE_TAG}" | tee -a ${GITHUB_OUTPUT} - - name: Check if such tag found in the registry - id: check-image + - name: Set more variables + id: set-more-variables env: - IMAGE_TAG: ${{ steps.get-build-tools-tag.outputs.image-tag }} + IMAGE_TAG: ${{ steps.set-variables.outputs.image-tag }} + EVERYTHING: | + ${{ contains(fromJson(steps.set-variables.outputs.archs), 'x64') && + contains(fromJson(steps.set-variables.outputs.archs), 'arm64') && + contains(fromJson(steps.set-variables.outputs.debians), 'bullseye') && + contains(fromJson(steps.set-variables.outputs.debians), 'bookworm') }} run: | if docker manifest inspect neondatabase/build-tools:${IMAGE_TAG}; then found=true @@ -58,8 +81,8 @@ jobs: found=false fi - echo "found=${found}" | tee -a $GITHUB_OUTPUT - + echo "everything=${EVERYTHING}" | tee -a ${GITHUB_OUTPUT} + echo "found=${found}" | tee -a ${GITHUB_OUTPUT} build-image: needs: [ check-image ] @@ -67,8 +90,8 @@ jobs: strategy: matrix: - debian-version: [ bullseye, bookworm ] - arch: [ x64, arm64 ] + arch: ${{ fromJson(needs.check-image.outputs.archs) }} + debian: ${{ fromJson(needs.check-image.outputs.debians) }} runs-on: ${{ fromJson(format('["self-hosted", "{0}"]', matrix.arch == 'arm64' && 'large-arm64' || 'large')) }} @@ -99,11 +122,11 @@ jobs: push: true pull: true build-args: | - DEBIAN_VERSION=${{ matrix.debian-version }} - cache-from: type=registry,ref=cache.neon.build/build-tools:cache-${{ matrix.debian-version }}-${{ matrix.arch }} - cache-to: ${{ github.ref_name == 'main' && format('type=registry,ref=cache.neon.build/build-tools:cache-{0}-{1},mode=max', matrix.debian-version, matrix.arch) || '' }} + DEBIAN_VERSION=${{ matrix.debian }} + cache-from: type=registry,ref=cache.neon.build/build-tools:cache-${{ matrix.debian }}-${{ matrix.arch }} + cache-to: ${{ github.ref_name == 'main' && format('type=registry,ref=cache.neon.build/build-tools:cache-{0}-{1},mode=max', matrix.debian, matrix.arch) || '' }} tags: | - neondatabase/build-tools:${{ needs.check-image.outputs.tag }}-${{ matrix.debian-version }}-${{ matrix.arch }} + neondatabase/build-tools:${{ needs.check-image.outputs.tag }}-${{ matrix.debian }}-${{ matrix.arch }} merge-images: needs: [ check-image, build-image ] @@ -117,16 +140,22 @@ jobs: - name: Create multi-arch image env: - DEFAULT_DEBIAN_VERSION: bullseye + DEFAULT_DEBIAN_VERSION: bookworm + ARCHS: ${{ join(fromJson(needs.check-image.outputs.archs), ' ') }} + DEBIANS: ${{ join(fromJson(needs.check-image.outputs.debians), ' ') }} + EVERYTHING: ${{ needs.check-image.outputs.everything }} IMAGE_TAG: ${{ needs.check-image.outputs.tag }} run: | - for debian_version in bullseye bookworm; do - tags=("-t" "neondatabase/build-tools:${IMAGE_TAG}-${debian_version}") - if [ "${debian_version}" == "${DEFAULT_DEBIAN_VERSION}" ]; then + for debian in ${DEBIANS}; do + tags=("-t" "neondatabase/build-tools:${IMAGE_TAG}-${debian}") + + if [ "${EVERYTHING}" == "true" ] && [ "${debian}" == "${DEFAULT_DEBIAN_VERSION}" ]; then tags+=("-t" "neondatabase/build-tools:${IMAGE_TAG}") fi - docker buildx imagetools create "${tags[@]}" \ - neondatabase/build-tools:${IMAGE_TAG}-${debian_version}-x64 \ - neondatabase/build-tools:${IMAGE_TAG}-${debian_version}-arm64 + for arch in ${ARCHS}; do + tags+=("neondatabase/build-tools:${IMAGE_TAG}-${debian}-${arch}") + done + + docker buildx imagetools create "${tags[@]}" done diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 89fd2d0d17..9830c2a0c9 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -253,7 +253,14 @@ jobs: build-tag: ${{ needs.tag.outputs.build-tag }} build-type: ${{ matrix.build-type }} # Run tests on all Postgres versions in release builds and only on the latest version in debug builds - pg-versions: ${{ matrix.build-type == 'release' && '["v14", "v15", "v16", "v17"]' || '["v17"]' }} + # run without LFC on v17 release only + test-cfg: | + ${{ matrix.build-type == 'release' && '[{"pg_version":"v14", "lfc_state": "without-lfc"}, + {"pg_version":"v15", "lfc_state": "without-lfc"}, + {"pg_version":"v16", "lfc_state": "without-lfc"}, + {"pg_version":"v17", "lfc_state": "without-lfc"}, + {"pg_version":"v17", "lfc_state": "with-lfc"}]' + || '[{"pg_version":"v17", "lfc_state": "without-lfc"}]' }} secrets: inherit # Keep `benchmarks` job outside of `build-and-test-locally` workflow to make job failures non-blocking diff --git a/.github/workflows/periodic_pagebench.yml b/.github/workflows/periodic_pagebench.yml index 1cce348ae2..6b98bc873f 100644 --- a/.github/workflows/periodic_pagebench.yml +++ b/.github/workflows/periodic_pagebench.yml @@ -29,7 +29,7 @@ jobs: trigger_bench_on_ec2_machine_in_eu_central_1: runs-on: [ self-hosted, small ] container: - image: neondatabase/build-tools:pinned + image: neondatabase/build-tools:pinned-bookworm credentials: username: ${{ secrets.NEON_DOCKERHUB_USERNAME }} password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }} diff --git a/.github/workflows/pin-build-tools-image.yml b/.github/workflows/pin-build-tools-image.yml index c196d07d3e..5b43d97de6 100644 --- a/.github/workflows/pin-build-tools-image.yml +++ b/.github/workflows/pin-build-tools-image.yml @@ -94,7 +94,7 @@ jobs: - name: Tag build-tools with `${{ env.TO_TAG }}` in Docker Hub, ECR, and ACR env: - DEFAULT_DEBIAN_VERSION: bullseye + DEFAULT_DEBIAN_VERSION: bookworm run: | for debian_version in bullseye bookworm; do tags=() diff --git a/.github/workflows/pre-merge-checks.yml b/.github/workflows/pre-merge-checks.yml index e1cec6d33d..d2f9d8a666 100644 --- a/.github/workflows/pre-merge-checks.yml +++ b/.github/workflows/pre-merge-checks.yml @@ -23,6 +23,8 @@ jobs: id: python-src with: files: | + .github/workflows/_check-codestyle-python.yml + .github/workflows/build-build-tools-image.yml .github/workflows/pre-merge-checks.yml **/**.py poetry.lock @@ -38,6 +40,10 @@ jobs: if: needs.get-changed-files.outputs.python-changed == 'true' needs: [ get-changed-files ] uses: ./.github/workflows/build-build-tools-image.yml + with: + # Build only one combination to save time + archs: '["x64"]' + debians: '["bookworm"]' secrets: inherit check-codestyle-python: @@ -45,7 +51,8 @@ jobs: needs: [ get-changed-files, build-build-tools-image ] uses: ./.github/workflows/_check-codestyle-python.yml with: - build-tools-image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm + # `-bookworm-x64` suffix should match the combination in `build-build-tools-image` + build-tools-image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm-x64 secrets: inherit # To get items from the merge queue merged into main we need to satisfy "Status checks that are required". diff --git a/.github/workflows/report-workflow-stats-batch.yml b/.github/workflows/report-workflow-stats-batch.yml index 98e394a3c2..2ed044b780 100644 --- a/.github/workflows/report-workflow-stats-batch.yml +++ b/.github/workflows/report-workflow-stats-batch.yml @@ -4,10 +4,12 @@ on: schedule: - cron: '*/15 * * * *' - cron: '25 0 * * *' + - cron: '25 1 * * 6' jobs: - gh-workflow-stats-batch: - name: GitHub Workflow Stats Batch + gh-workflow-stats-batch-2h: + name: GitHub Workflow Stats Batch 2 hours + if: github.event.schedule == '*/15 * * * *' runs-on: ubuntu-22.04 permissions: actions: read @@ -16,14 +18,36 @@ jobs: uses: neondatabase/gh-workflow-stats-action@v0.2.1 with: db_uri: ${{ secrets.GH_REPORT_STATS_DB_RW_CONNSTR }} - db_table: "gh_workflow_stats_batch_neon" + db_table: "gh_workflow_stats_neon" gh_token: ${{ secrets.GITHUB_TOKEN }} duration: '2h' - - name: Export Workflow Run for the past 24 hours - if: github.event.schedule == '25 0 * * *' + + gh-workflow-stats-batch-48h: + name: GitHub Workflow Stats Batch 48 hours + if: github.event.schedule == '25 0 * * *' + runs-on: ubuntu-22.04 + permissions: + actions: read + steps: + - name: Export Workflow Run for the past 48 hours uses: neondatabase/gh-workflow-stats-action@v0.2.1 with: db_uri: ${{ secrets.GH_REPORT_STATS_DB_RW_CONNSTR }} - db_table: "gh_workflow_stats_batch_neon" + db_table: "gh_workflow_stats_neon" gh_token: ${{ secrets.GITHUB_TOKEN }} - duration: '24h' + duration: '48h' + + gh-workflow-stats-batch-30d: + name: GitHub Workflow Stats Batch 30 days + if: github.event.schedule == '25 1 * * 6' + runs-on: ubuntu-22.04 + permissions: + actions: read + steps: + - name: Export Workflow Run for the past 30 days + uses: neondatabase/gh-workflow-stats-action@v0.2.1 + with: + db_uri: ${{ secrets.GH_REPORT_STATS_DB_RW_CONNSTR }} + db_table: "gh_workflow_stats_neon" + gh_token: ${{ secrets.GITHUB_TOKEN }} + duration: '720h' diff --git a/.github/workflows/report-workflow-stats.yml b/.github/workflows/report-workflow-stats.yml deleted file mode 100644 index 15e446bcd7..0000000000 --- a/.github/workflows/report-workflow-stats.yml +++ /dev/null @@ -1,41 +0,0 @@ -name: Report Workflow Stats - -on: - workflow_run: - workflows: - - Add `external` label to issues and PRs created by external users - - Benchmarking - - Build and Test - - Build and Test Locally - - Build build-tools image - - Check Permissions - - Check neon with extra platform builds - - Cloud Regression Test - - Create Release Branch - - Handle `approved-for-ci-run` label - - Lint GitHub Workflows - - Notify Slack channel about upcoming release - - Periodic pagebench performance test on dedicated EC2 machine in eu-central-1 region - - Pin build-tools image - - Prepare benchmarking databases by restoring dumps - - Push images to ACR - - Test Postgres client libraries - - Trigger E2E Tests - - cleanup caches by a branch - - Pre-merge checks - types: [completed] - -jobs: - gh-workflow-stats: - name: Github Workflow Stats - runs-on: ubuntu-22.04 - permissions: - actions: read - steps: - - name: Export GH Workflow Stats - uses: neondatabase/gh-workflow-stats-action@v0.1.4 - with: - DB_URI: ${{ secrets.GH_REPORT_STATS_DB_RW_CONNSTR }} - DB_TABLE: "gh_workflow_stats_neon" - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - GH_RUN_ID: ${{ github.event.workflow_run.id }} diff --git a/Cargo.lock b/Cargo.lock index c7af140f7d..b104a35bf5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -46,6 +46,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "aligned-vec" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e0966165eaf052580bd70eb1b32cb3d6245774c0104d1b2793e9650bf83b52a" +dependencies = [ + "equator", +] + [[package]] name = "allocator-api2" version = "0.2.16" @@ -146,6 +155,12 @@ dependencies = [ "static_assertions", ] +[[package]] +name = "arrayvec" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" + [[package]] name = "asn1-rs" version = "0.6.2" @@ -359,6 +374,28 @@ dependencies = [ "tracing", ] +[[package]] +name = "aws-sdk-kms" +version = "1.47.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "564a597a3c71a957d60a2e4c62c93d78ee5a0d636531e15b760acad983a5c18e" +dependencies = [ + "aws-credential-types", + "aws-runtime", + "aws-smithy-async", + "aws-smithy-http", + "aws-smithy-json", + "aws-smithy-runtime", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "bytes", + "http 0.2.9", + "once_cell", + "regex-lite", + "tracing", +] + [[package]] name = "aws-sdk-s3" version = "1.52.0" @@ -575,9 +612,9 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.7.1" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1ce695746394772e7000b39fe073095db6d45a862d0767dd5ad0ac0d7f8eb87" +checksum = "a065c0fe6fdbdf9f11817eb68582b2ab4aff9e9c39e986ae48f7ec576c6322db" dependencies = [ "aws-smithy-async", "aws-smithy-http", @@ -742,7 +779,7 @@ dependencies = [ "once_cell", "paste", "pin-project", - "quick-xml", + "quick-xml 0.31.0", "rand 0.8.5", "reqwest 0.11.19", "rustc_version", @@ -1220,6 +1257,10 @@ name = "compute_tools" version = "0.1.0" dependencies = [ "anyhow", + "aws-config", + "aws-sdk-kms", + "aws-sdk-s3", + "base64 0.13.1", "bytes", "camino", "cfg-if", @@ -1237,13 +1278,16 @@ dependencies = [ "opentelemetry", "opentelemetry_sdk", "postgres", + "postgres_initdb", "prometheus", "regex", "remote_storage", "reqwest 0.12.4", "rlimit", "rust-ini", + "serde", "serde_json", + "serde_with", "signal-hook", "tar", "thiserror", @@ -1381,6 +1425,15 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" +[[package]] +name = "cpp_demangle" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96e58d342ad113c2b878f16d5d034c03be492ae460cdbc02b7f0f2284d310c7d" +dependencies = [ + "cfg-if", +] + [[package]] name = "cpufeatures" version = "0.2.9" @@ -1664,6 +1717,12 @@ dependencies = [ "utils", ] +[[package]] +name = "diatomic-waker" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab03c107fafeb3ee9f5925686dbb7a73bc76e3932abb0d2b365cb64b169cf04c" + [[package]] name = "diesel" version = "2.2.3" @@ -1904,6 +1963,26 @@ dependencies = [ "termcolor", ] +[[package]] +name = "equator" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c35da53b5a021d2484a7cc49b2ac7f2d840f8236a286f84202369bd338d761ea" +dependencies = [ + "equator-macro", +] + +[[package]] +name = "equator-macro" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3bf679796c0322556351f287a51b49e48f7c4986e727b5dd78c972d30e2e16cc" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.52", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -2011,6 +2090,18 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "findshlibs" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40b9e59cd0f7e0806cca4be089683ecb6434e602038df21fe6bf6711b2f07f64" +dependencies = [ + "cc", + "lazy_static", + "libc", + "winapi", +] + [[package]] name = "fixedbitset" version = "0.4.2" @@ -2089,9 +2180,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", "futures-sink", @@ -2099,9 +2190,9 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-executor" @@ -2116,9 +2207,9 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" [[package]] name = "futures-lite" @@ -2137,9 +2228,9 @@ dependencies = [ [[package]] name = "futures-macro" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", @@ -2148,15 +2239,15 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" [[package]] name = "futures-task" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" [[package]] name = "futures-timer" @@ -2166,9 +2257,9 @@ checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" [[package]] name = "futures-util" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-channel", "futures-core", @@ -2714,6 +2805,24 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "64e9829a50b42bb782c1df523f78d332fe371b10c661e78b7a3c34b0198e9fac" +[[package]] +name = "inferno" +version = "0.11.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "232929e1d75fe899576a3d5c7416ad0d88dbfbb3c3d6aa00873a7408a50ddb88" +dependencies = [ + "ahash", + "indexmap 2.0.1", + "is-terminal", + "itoa", + "log", + "num-format", + "once_cell", + "quick-xml 0.26.0", + "rgb", + "str_stack", +] + [[package]] name = "inotify" version = "0.9.6" @@ -2764,9 +2873,9 @@ dependencies = [ [[package]] name = "ipnet" -version = "2.9.0" +version = "2.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" +checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" [[package]] name = "is-terminal" @@ -3053,6 +3162,15 @@ version = "2.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" +[[package]] +name = "memmap2" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45fd3a57831bf88bc63f8cebc0cf956116276e97fef3966103e96416209f7c92" +dependencies = [ + "libc", +] + [[package]] name = "memoffset" version = "0.7.1" @@ -3278,6 +3396,16 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +[[package]] +name = "num-format" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a652d9771a63711fd3c3deb670acfbe5c30a4072e664d7a3bf5a9e1056ac72c3" +dependencies = [ + "arrayvec", + "itoa", +] + [[package]] name = "num-integer" version = "0.1.45" @@ -3619,6 +3747,7 @@ dependencies = [ "num_cpus", "once_cell", "pageserver_api", + "pageserver_client", "pageserver_compaction", "pin-project-lite", "postgres", @@ -3627,6 +3756,7 @@ dependencies = [ "postgres_backend", "postgres_connection", "postgres_ffi", + "postgres_initdb", "pq_proto", "procfs", "rand 0.8.5", @@ -4009,7 +4139,7 @@ dependencies = [ [[package]] name = "postgres" version = "0.19.4" -source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#a130197713830a0ea0004b539b1f51a66b4c3e18" +source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#00940fcdb57a8e99e805297b75839e7c4c7b1796" dependencies = [ "bytes", "fallible-iterator", @@ -4022,7 +4152,7 @@ dependencies = [ [[package]] name = "postgres-protocol" version = "0.6.4" -source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#a130197713830a0ea0004b539b1f51a66b4c3e18" +source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#00940fcdb57a8e99e805297b75839e7c4c7b1796" dependencies = [ "base64 0.20.0", "byteorder", @@ -4038,16 +4168,40 @@ dependencies = [ "tokio", ] +[[package]] +name = "postgres-protocol2" +version = "0.1.0" +dependencies = [ + "base64 0.20.0", + "byteorder", + "bytes", + "fallible-iterator", + "hmac", + "md-5", + "memchr", + "rand 0.8.5", + "sha2", + "stringprep", + "tokio", +] + [[package]] name = "postgres-types" version = "0.2.4" -source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#a130197713830a0ea0004b539b1f51a66b4c3e18" +source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#00940fcdb57a8e99e805297b75839e7c4c7b1796" dependencies = [ "bytes", "fallible-iterator", "postgres-protocol", - "serde", - "serde_json", +] + +[[package]] +name = "postgres-types2" +version = "0.1.0" +dependencies = [ + "bytes", + "fallible-iterator", + "postgres-protocol2", ] [[package]] @@ -4058,7 +4212,7 @@ dependencies = [ "bytes", "once_cell", "pq_proto", - "rustls 0.23.16", + "rustls 0.23.18", "rustls-pemfile 2.1.1", "serde", "thiserror", @@ -4102,12 +4256,48 @@ dependencies = [ "utils", ] +[[package]] +name = "postgres_initdb" +version = "0.1.0" +dependencies = [ + "anyhow", + "camino", + "thiserror", + "tokio", + "workspace_hack", +] + [[package]] name = "powerfmt" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" +[[package]] +name = "pprof" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebbe2f8898beba44815fdc9e5a4ae9c929e21c5dc29b0c774a15555f7f58d6d0" +dependencies = [ + "aligned-vec", + "backtrace", + "cfg-if", + "criterion", + "findshlibs", + "inferno", + "libc", + "log", + "nix 0.26.4", + "once_cell", + "parking_lot 0.12.1", + "protobuf", + "protobuf-codegen-pure", + "smallvec", + "symbolic-demangle", + "tempfile", + "thiserror", +] + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -4260,6 +4450,31 @@ dependencies = [ "prost", ] +[[package]] +name = "protobuf" +version = "2.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94" + +[[package]] +name = "protobuf-codegen" +version = "2.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "033460afb75cf755fcfc16dfaed20b86468082a2ea24e05ac35ab4a099a017d6" +dependencies = [ + "protobuf", +] + +[[package]] +name = "protobuf-codegen-pure" +version = "2.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95a29399fc94bcd3eeaa951c715f7bea69409b2445356b00519740bcd6ddd865" +dependencies = [ + "protobuf", + "protobuf-codegen", +] + [[package]] name = "proxy" version = "0.1.0" @@ -4316,7 +4531,7 @@ dependencies = [ "parquet_derive", "pbkdf2", "pin-project-lite", - "postgres-protocol", + "postgres-protocol2", "postgres_backend", "pq_proto", "prometheus", @@ -4333,7 +4548,7 @@ dependencies = [ "rsa", "rstest", "rustc-hash", - "rustls 0.23.16", + "rustls 0.23.18", "rustls-native-certs 0.8.0", "rustls-pemfile 2.1.1", "scopeguard", @@ -4351,8 +4566,7 @@ dependencies = [ "tikv-jemalloc-ctl", "tikv-jemallocator", "tokio", - "tokio-postgres", - "tokio-postgres-rustls", + "tokio-postgres2", "tokio-rustls 0.26.0", "tokio-tungstenite", "tokio-util", @@ -4371,6 +4585,15 @@ dependencies = [ "zerocopy", ] +[[package]] +name = "quick-xml" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f50b1c63b38611e7d4d7f68b82d3ad0cc71a2ad2e7f61fc10f1328d917c93cd" +dependencies = [ + "memchr", +] + [[package]] name = "quick-xml" version = "0.31.0" @@ -4853,6 +5076,15 @@ dependencies = [ "subtle", ] +[[package]] +name = "rgb" +version = "0.8.50" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57397d16646700483b67d2dd6511d79318f9d057fdbd21a4066aeac8b41d310a" +dependencies = [ + "bytemuck", +] + [[package]] name = "ring" version = "0.17.6" @@ -5028,9 +5260,9 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.16" +version = "0.23.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eee87ff5d9b36712a58574e12e9f0ea80f915a5b0ac518d322b24a465617925e" +checksum = "9c9cc1d47e243d655ace55ed38201c19ae02c148ae56412ab8750e8f0166ab7f" dependencies = [ "log", "once_cell", @@ -5161,11 +5393,13 @@ dependencies = [ "itertools 0.10.5", "metrics", "once_cell", + "pageserver_api", "parking_lot 0.12.1", "postgres", "postgres-protocol", "postgres_backend", "postgres_ffi", + "pprof", "pq_proto", "rand 0.8.5", "regex", @@ -5191,6 +5425,7 @@ dependencies = [ "tracing-subscriber", "url", "utils", + "wal_decoder", "walproposer", "workspace_hack", ] @@ -5712,6 +5947,12 @@ dependencies = [ "der 0.7.8", ] +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + [[package]] name = "static_assertions" version = "1.1.0" @@ -5738,7 +5979,7 @@ dependencies = [ "once_cell", "parking_lot 0.12.1", "prost", - "rustls 0.23.16", + "rustls 0.23.18", "tokio", "tonic", "tonic-build", @@ -5821,7 +6062,7 @@ dependencies = [ "postgres_ffi", "remote_storage", "reqwest 0.12.4", - "rustls 0.23.16", + "rustls 0.23.18", "rustls-native-certs 0.8.0", "serde", "serde_json", @@ -5858,6 +6099,12 @@ dependencies = [ "workspace_hack", ] +[[package]] +name = "str_stack" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9091b6114800a5f2141aee1d1b9d6ca3592ac062dc5decb3764ec5895a47b4eb" + [[package]] name = "stringprep" version = "0.1.2" @@ -5905,6 +6152,29 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20e16a0f46cf5fd675563ef54f26e83e20f2366bcf027bcb3cc3ed2b98aaf2ca" +[[package]] +name = "symbolic-common" +version = "12.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "366f1b4c6baf6cfefc234bbd4899535fca0b06c74443039a73f6dfb2fad88d77" +dependencies = [ + "debugid", + "memmap2", + "stable_deref_trait", + "uuid", +] + +[[package]] +name = "symbolic-demangle" +version = "12.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aba05ba5b9962ea5617baf556293720a8b2d0a282aa14ee4bf10e22efc7da8c8" +dependencies = [ + "cpp_demangle", + "rustc-demangle", + "symbolic-common", +] + [[package]] name = "syn" version = "1.0.109" @@ -6180,6 +6450,7 @@ dependencies = [ "libc", "mio", "num_cpus", + "parking_lot 0.12.1", "pin-project-lite", "signal-hook-registry", "socket2", @@ -6227,7 +6498,7 @@ dependencies = [ [[package]] name = "tokio-postgres" version = "0.7.7" -source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#a130197713830a0ea0004b539b1f51a66b4c3e18" +source = "git+https://github.com/neondatabase/rust-postgres.git?branch=neon#00940fcdb57a8e99e805297b75839e7c4c7b1796" dependencies = [ "async-trait", "byteorder", @@ -6254,13 +6525,33 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04fb792ccd6bbcd4bba408eb8a292f70fc4a3589e5d793626f45190e6454b6ab" dependencies = [ "ring", - "rustls 0.23.16", + "rustls 0.23.18", "tokio", "tokio-postgres", "tokio-rustls 0.26.0", "x509-certificate", ] +[[package]] +name = "tokio-postgres2" +version = "0.1.0" +dependencies = [ + "async-trait", + "byteorder", + "bytes", + "fallible-iterator", + "futures-util", + "log", + "parking_lot 0.12.1", + "percent-encoding", + "phf", + "pin-project-lite", + "postgres-protocol2", + "postgres-types2", + "tokio", + "tokio-util", +] + [[package]] name = "tokio-rustls" version = "0.24.0" @@ -6288,7 +6579,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.16", + "rustls 0.23.18", "rustls-pki-types", "tokio", ] @@ -6697,7 +6988,7 @@ dependencies = [ "base64 0.22.1", "log", "once_cell", - "rustls 0.23.16", + "rustls 0.23.18", "rustls-pki-types", "url", "webpki-roots 0.26.1", @@ -6759,6 +7050,7 @@ dependencies = [ "chrono", "const_format", "criterion", + "diatomic-waker", "fail", "futures", "git-version", @@ -6772,6 +7064,7 @@ dependencies = [ "once_cell", "pin-project-lite", "postgres_connection", + "pprof", "pq_proto", "rand 0.8.5", "regex", @@ -6781,6 +7074,7 @@ dependencies = [ "serde_assert", "serde_json", "serde_path_to_error", + "serde_with", "signal-hook", "strum", "strum_macros", @@ -6877,10 +7171,16 @@ name = "wal_decoder" version = "0.1.0" dependencies = [ "anyhow", + "async-compression", "bytes", "pageserver_api", "postgres_ffi", + "prost", "serde", + "thiserror", + "tokio", + "tonic", + "tonic-build", "tracing", "utils", "workspace_hack", @@ -7306,6 +7606,7 @@ dependencies = [ "anyhow", "axum", "axum-core", + "base64 0.13.1", "base64 0.21.1", "base64ct", "bytes", @@ -7340,13 +7641,13 @@ dependencies = [ "libc", "log", "memchr", + "nix 0.26.4", "nom", "num-bigint", "num-integer", "num-traits", "once_cell", "parquet", - "postgres-types", "prettyplease", "proc-macro2", "prost", @@ -7356,7 +7657,7 @@ dependencies = [ "regex-automata 0.4.3", "regex-syntax 0.8.2", "reqwest 0.12.4", - "rustls 0.23.16", + "rustls 0.23.18", "scopeguard", "serde", "serde_json", @@ -7371,7 +7672,6 @@ dependencies = [ "time", "time-macros", "tokio", - "tokio-postgres", "tokio-rustls 0.26.0", "tokio-stream", "tokio-util", diff --git a/Cargo.toml b/Cargo.toml index dbda930535..64c384f17a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,10 @@ members = [ "libs/vm_monitor", "libs/walproposer", "libs/wal_decoder", + "libs/postgres_initdb", + "libs/proxy/postgres-protocol2", + "libs/proxy/postgres-types2", + "libs/proxy/tokio-postgres2", ] [workspace.package] @@ -57,6 +61,7 @@ async-trait = "0.1" aws-config = { version = "1.5", default-features = false, features=["rustls", "sso"] } aws-sdk-s3 = "1.52" aws-sdk-iam = "1.46.0" +aws-sdk-kms = "1.47.0" aws-smithy-async = { version = "1.2.1", default-features = false, features=["rt-tokio"] } aws-smithy-types = "1.2" aws-credential-types = "1.2.0" @@ -73,11 +78,12 @@ bytes = "1.0" camino = "1.1.6" cfg-if = "1.0.0" chrono = { version = "0.4", default-features = false, features = ["clock"] } -clap = { version = "4.0", features = ["derive"] } +clap = { version = "4.0", features = ["derive", "env"] } comfy-table = "7.1" const_format = "0.2" crc32c = "0.6" dashmap = { version = "5.5.0", features = ["raw-api"] } +diatomic-waker = { version = "0.2.3" } either = "1.8" enum-map = "2.4.2" enumset = "1.0.12" @@ -106,7 +112,7 @@ hyper-util = "0.1" tokio-tungstenite = "0.21.0" indexmap = "2" indoc = "2" -ipnet = "2.9.0" +ipnet = "2.10.0" itertools = "0.10" itoa = "1.0.11" jsonwebtoken = "9" @@ -130,6 +136,7 @@ parquet = { version = "53", default-features = false, features = ["zstd"] } parquet_derive = "53" pbkdf2 = { version = "0.12.1", features = ["simple", "std"] } pin-project-lite = "0.2" +pprof = { version = "0.14", features = ["criterion", "flamegraph", "protobuf", "protobuf-codec"] } procfs = "0.16" prometheus = {version = "0.13", default-features=false, features = ["process"]} # removes protobuf dependency prost = "0.13" @@ -153,7 +160,7 @@ sentry = { version = "0.32", default-features = false, features = ["backtrace", serde = { version = "1.0", features = ["derive"] } serde_json = "1" serde_path_to_error = "0.1" -serde_with = "2.0" +serde_with = { version = "2.0", features = [ "base64" ] } serde_assert = "0.5.0" sha2 = "0.10.2" signal-hook = "0.3" @@ -212,12 +219,14 @@ tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", br compute_api = { version = "0.1", path = "./libs/compute_api/" } consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" } metrics = { version = "0.1", path = "./libs/metrics/" } +pageserver = { path = "./pageserver" } pageserver_api = { version = "0.1", path = "./libs/pageserver_api/" } pageserver_client = { path = "./pageserver/client" } pageserver_compaction = { version = "0.1", path = "./pageserver/compaction/" } postgres_backend = { version = "0.1", path = "./libs/postgres_backend/" } postgres_connection = { version = "0.1", path = "./libs/postgres_connection/" } postgres_ffi = { version = "0.1", path = "./libs/postgres_ffi/" } +postgres_initdb = { path = "./libs/postgres_initdb" } pq_proto = { version = "0.1", path = "./libs/pq_proto/" } remote_storage = { version = "0.1", path = "./libs/remote_storage/" } safekeeper_api = { version = "0.1", path = "./libs/safekeeper_api" } diff --git a/Dockerfile b/Dockerfile index 785dd4598e..e888efbae2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,7 +7,7 @@ ARG IMAGE=build-tools ARG TAG=pinned ARG DEFAULT_PG_VERSION=17 ARG STABLE_PG_VERSION=16 -ARG DEBIAN_VERSION=bullseye +ARG DEBIAN_VERSION=bookworm ARG DEBIAN_FLAVOR=${DEBIAN_VERSION}-slim # Build Postgres diff --git a/Makefile b/Makefile index 8e3b755112..9cffc74508 100644 --- a/Makefile +++ b/Makefile @@ -38,6 +38,7 @@ ifeq ($(UNAME_S),Linux) # Seccomp BPF is only available for Linux PG_CONFIGURE_OPTS += --with-libseccomp else ifeq ($(UNAME_S),Darwin) + PG_CFLAGS += -DUSE_PREFETCH ifndef DISABLE_HOMEBREW # macOS with brew-installed openssl requires explicit paths # It can be configured with OPENSSL_PREFIX variable @@ -146,6 +147,8 @@ postgres-%: postgres-configure-% \ $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pg_prewarm install +@echo "Compiling pg_buffercache $*" $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pg_buffercache install + +@echo "Compiling pg_visibility $*" + $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pg_visibility install +@echo "Compiling pageinspect $*" $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pageinspect install +@echo "Compiling amcheck $*" diff --git a/README.md b/README.md index e68ef70bdf..1417d6b9e7 100644 --- a/README.md +++ b/README.md @@ -132,7 +132,7 @@ make -j`sysctl -n hw.logicalcpu` -s To run the `psql` client, install the `postgresql-client` package or modify `PATH` and `LD_LIBRARY_PATH` to include `pg_install/bin` and `pg_install/lib`, respectively. To run the integration tests or Python scripts (not required to use the code), install -Python (3.9 or higher), and install the python3 packages using `./scripts/pysync` (requires [poetry>=1.8](https://python-poetry.org/)) in the project directory. +Python (3.11 or higher), and install the python3 packages using `./scripts/pysync` (requires [poetry>=1.8](https://python-poetry.org/)) in the project directory. #### Running neon database diff --git a/build-tools.Dockerfile b/build-tools.Dockerfile index c1190b13f4..2671702697 100644 --- a/build-tools.Dockerfile +++ b/build-tools.Dockerfile @@ -1,4 +1,4 @@ -ARG DEBIAN_VERSION=bullseye +ARG DEBIAN_VERSION=bookworm FROM debian:bookworm-slim AS pgcopydb_builder ARG DEBIAN_VERSION @@ -57,9 +57,9 @@ RUN mkdir -p /pgcopydb/bin && \ mkdir -p /pgcopydb/lib && \ chmod -R 755 /pgcopydb && \ chown -R nonroot:nonroot /pgcopydb - -COPY --from=pgcopydb_builder /usr/lib/postgresql/16/bin/pgcopydb /pgcopydb/bin/pgcopydb -COPY --from=pgcopydb_builder /pgcopydb/lib/libpq.so.5 /pgcopydb/lib/libpq.so.5 + +COPY --from=pgcopydb_builder /usr/lib/postgresql/16/bin/pgcopydb /pgcopydb/bin/pgcopydb +COPY --from=pgcopydb_builder /pgcopydb/lib/libpq.so.5 /pgcopydb/lib/libpq.so.5 # System deps # @@ -234,7 +234,7 @@ USER nonroot:nonroot WORKDIR /home/nonroot # Python -ENV PYTHON_VERSION=3.9.19 \ +ENV PYTHON_VERSION=3.11.10 \ PYENV_ROOT=/home/nonroot/.pyenv \ PATH=/home/nonroot/.pyenv/shims:/home/nonroot/.pyenv/bin:/home/nonroot/.poetry/bin:$PATH RUN set -e \ @@ -258,14 +258,14 @@ WORKDIR /home/nonroot # Rust # Please keep the version of llvm (installed above) in sync with rust llvm (`rustc --version --verbose | grep LLVM`) -ENV RUSTC_VERSION=1.82.0 +ENV RUSTC_VERSION=1.83.0 ENV RUSTUP_HOME="/home/nonroot/.rustup" ENV PATH="/home/nonroot/.cargo/bin:${PATH}" ARG RUSTFILT_VERSION=0.2.1 -ARG CARGO_HAKARI_VERSION=0.9.30 -ARG CARGO_DENY_VERSION=0.16.1 -ARG CARGO_HACK_VERSION=0.6.31 -ARG CARGO_NEXTEST_VERSION=0.9.72 +ARG CARGO_HAKARI_VERSION=0.9.33 +ARG CARGO_DENY_VERSION=0.16.2 +ARG CARGO_HACK_VERSION=0.6.33 +ARG CARGO_NEXTEST_VERSION=0.9.85 RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux-gnu/rustup-init && whoami && \ chmod +x rustup-init && \ ./rustup-init -y --default-toolchain ${RUSTC_VERSION} && \ @@ -289,7 +289,7 @@ RUN whoami \ && cargo --version --verbose \ && rustup --version --verbose \ && rustc --version --verbose \ - && clang --version + && clang --version RUN if [ "${DEBIAN_VERSION}" = "bookworm" ]; then \ LD_LIBRARY_PATH=/pgcopydb/lib /pgcopydb/bin/pgcopydb --version; \ diff --git a/compute/compute-node.Dockerfile b/compute/compute-node.Dockerfile index 32405ece86..2fcd9985bc 100644 --- a/compute/compute-node.Dockerfile +++ b/compute/compute-node.Dockerfile @@ -3,7 +3,7 @@ ARG REPOSITORY=neondatabase ARG IMAGE=build-tools ARG TAG=pinned ARG BUILD_TAG -ARG DEBIAN_VERSION=bullseye +ARG DEBIAN_VERSION=bookworm ARG DEBIAN_FLAVOR=${DEBIAN_VERSION}-slim ######################################################################################### @@ -1243,7 +1243,7 @@ RUN make -j $(getconf _NPROCESSORS_ONLN) \ ######################################################################################### # -# Compile and run the Neon-specific `compute_ctl` binary +# Compile and run the Neon-specific `compute_ctl` and `fast_import` binaries # ######################################################################################### FROM $REPOSITORY/$IMAGE:$TAG AS compute-tools @@ -1264,6 +1264,7 @@ RUN cd compute_tools && mold -run cargo build --locked --profile release-line-de FROM debian:$DEBIAN_FLAVOR AS compute-tools-image COPY --from=compute-tools /home/nonroot/target/release-line-debug-size-lto/compute_ctl /usr/local/bin/compute_ctl +COPY --from=compute-tools /home/nonroot/target/release-line-debug-size-lto/fast_import /usr/local/bin/fast_import ######################################################################################### # @@ -1458,6 +1459,7 @@ RUN mkdir /var/db && useradd -m -d /var/db/postgres postgres && \ COPY --from=postgres-cleanup-layer --chown=postgres /usr/local/pgsql /usr/local COPY --from=compute-tools --chown=postgres /home/nonroot/target/release-line-debug-size-lto/compute_ctl /usr/local/bin/compute_ctl +COPY --from=compute-tools --chown=postgres /home/nonroot/target/release-line-debug-size-lto/fast_import /usr/local/bin/fast_import # pgbouncer and its config COPY --from=pgbouncer /usr/local/pgbouncer/bin/pgbouncer /usr/local/bin/pgbouncer @@ -1533,6 +1535,25 @@ RUN apt update && \ rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* && \ localedef -i en_US -c -f UTF-8 -A /usr/share/locale/locale.alias en_US.UTF-8 +# s5cmd 2.2.2 from https://github.com/peak/s5cmd/releases/tag/v2.2.2 +# used by fast_import +ARG TARGETARCH +ADD https://github.com/peak/s5cmd/releases/download/v2.2.2/s5cmd_2.2.2_linux_$TARGETARCH.deb /tmp/s5cmd.deb +RUN set -ex; \ + \ + # Determine the expected checksum based on TARGETARCH + if [ "${TARGETARCH}" = "amd64" ]; then \ + CHECKSUM="392c385320cd5ffa435759a95af77c215553d967e4b1c0fffe52e4f14c29cf85"; \ + elif [ "${TARGETARCH}" = "arm64" ]; then \ + CHECKSUM="939bee3cf4b5604ddb00e67f8c157b91d7c7a5b553d1fbb6890fad32894b7b46"; \ + else \ + echo "Unsupported architecture: ${TARGETARCH}"; exit 1; \ + fi; \ + \ + # Compute and validate the checksum + echo "${CHECKSUM} /tmp/s5cmd.deb" | sha256sum -c - +RUN dpkg -i /tmp/s5cmd.deb && rm /tmp/s5cmd.deb + ENV LANG=en_US.utf8 USER postgres ENTRYPOINT ["/usr/local/bin/compute_ctl"] diff --git a/compute_tools/Cargo.toml b/compute_tools/Cargo.toml index 0bf4ed53d6..c0c390caef 100644 --- a/compute_tools/Cargo.toml +++ b/compute_tools/Cargo.toml @@ -10,6 +10,10 @@ default = [] testing = [] [dependencies] +base64.workspace = true +aws-config.workspace = true +aws-sdk-s3.workspace = true +aws-sdk-kms.workspace = true anyhow.workspace = true camino.workspace = true chrono.workspace = true @@ -27,6 +31,8 @@ opentelemetry.workspace = true opentelemetry_sdk.workspace = true postgres.workspace = true regex.workspace = true +serde.workspace = true +serde_with.workspace = true serde_json.workspace = true signal-hook.workspace = true tar.workspace = true @@ -43,6 +49,7 @@ thiserror.workspace = true url.workspace = true prometheus.workspace = true +postgres_initdb.workspace = true compute_api.workspace = true utils.workspace = true workspace_hack.workspace = true diff --git a/compute_tools/src/bin/compute_ctl.rs b/compute_tools/src/bin/compute_ctl.rs index 4689cc2b83..6b670de2ea 100644 --- a/compute_tools/src/bin/compute_ctl.rs +++ b/compute_tools/src/bin/compute_ctl.rs @@ -58,7 +58,7 @@ use compute_tools::compute::{ forward_termination_signal, ComputeNode, ComputeState, ParsedSpec, PG_PID, }; use compute_tools::configurator::launch_configurator; -use compute_tools::extension_server::get_pg_version; +use compute_tools::extension_server::get_pg_version_string; use compute_tools::http::api::launch_http_server; use compute_tools::logger::*; use compute_tools::monitor::launch_monitor; @@ -326,7 +326,7 @@ fn wait_spec( connstr: Url::parse(connstr).context("cannot parse connstr as a URL")?, pgdata: pgdata.to_string(), pgbin: pgbin.to_string(), - pgversion: get_pg_version(pgbin), + pgversion: get_pg_version_string(pgbin), live_config_allowed, state: Mutex::new(new_state), state_changed: Condvar::new(), diff --git a/compute_tools/src/bin/fast_import.rs b/compute_tools/src/bin/fast_import.rs new file mode 100644 index 0000000000..6716cc6234 --- /dev/null +++ b/compute_tools/src/bin/fast_import.rs @@ -0,0 +1,345 @@ +//! This program dumps a remote Postgres database into a local Postgres database +//! and uploads the resulting PGDATA into object storage for import into a Timeline. +//! +//! # Context, Architecture, Design +//! +//! See cloud.git Fast Imports RFC () +//! for the full picture. +//! The RFC describing the storage pieces of importing the PGDATA dump into a Timeline +//! is publicly accessible at . +//! +//! # This is a Prototype! +//! +//! This program is part of a prototype feature and not yet used in production. +//! +//! The cloud.git RFC contains lots of suggestions for improving e2e throughput +//! of this step of the timeline import process. +//! +//! # Local Testing +//! +//! - Comment out most of the pgxns in The Dockerfile.compute-tools to speed up the build. +//! - Build the image with the following command: +//! +//! ```bash +//! docker buildx build --build-arg DEBIAN_FLAVOR=bullseye-slim --build-arg GIT_VERSION=local --build-arg PG_VERSION=v14 --build-arg BUILD_TAG="$(date --iso-8601=s -u)" -t localhost:3030/localregistry/compute-node-v14:latest -f compute/Dockerfile.com +//! docker push localhost:3030/localregistry/compute-node-v14:latest +//! ``` + +use anyhow::Context; +use aws_config::BehaviorVersion; +use camino::{Utf8Path, Utf8PathBuf}; +use clap::Parser; +use compute_tools::extension_server::{get_pg_version, PostgresMajorVersion}; +use nix::unistd::Pid; +use tracing::{info, info_span, warn, Instrument}; +use utils::fs_ext::is_directory_empty; + +#[path = "fast_import/child_stdio_to_log.rs"] +mod child_stdio_to_log; +#[path = "fast_import/s3_uri.rs"] +mod s3_uri; +#[path = "fast_import/s5cmd.rs"] +mod s5cmd; + +#[derive(clap::Parser)] +struct Args { + #[clap(long)] + working_directory: Utf8PathBuf, + #[clap(long, env = "NEON_IMPORTER_S3_PREFIX")] + s3_prefix: s3_uri::S3Uri, + #[clap(long)] + pg_bin_dir: Utf8PathBuf, + #[clap(long)] + pg_lib_dir: Utf8PathBuf, +} + +#[serde_with::serde_as] +#[derive(serde::Deserialize)] +struct Spec { + encryption_secret: EncryptionSecret, + #[serde_as(as = "serde_with::base64::Base64")] + source_connstring_ciphertext_base64: Vec, +} + +#[derive(serde::Deserialize)] +enum EncryptionSecret { + #[allow(clippy::upper_case_acronyms)] + KMS { key_id: String }, +} + +#[tokio::main] +pub(crate) async fn main() -> anyhow::Result<()> { + utils::logging::init( + utils::logging::LogFormat::Plain, + utils::logging::TracingErrorLayerEnablement::EnableWithRustLogFilter, + utils::logging::Output::Stdout, + )?; + + info!("starting"); + + let Args { + working_directory, + s3_prefix, + pg_bin_dir, + pg_lib_dir, + } = Args::parse(); + + let aws_config = aws_config::load_defaults(BehaviorVersion::v2024_03_28()).await; + + let spec: Spec = { + let spec_key = s3_prefix.append("/spec.json"); + let s3_client = aws_sdk_s3::Client::new(&aws_config); + let object = s3_client + .get_object() + .bucket(&spec_key.bucket) + .key(spec_key.key) + .send() + .await + .context("get spec from s3")? + .body + .collect() + .await + .context("download spec body")?; + serde_json::from_slice(&object.into_bytes()).context("parse spec as json")? + }; + + match tokio::fs::create_dir(&working_directory).await { + Ok(()) => {} + Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => { + if !is_directory_empty(&working_directory) + .await + .context("check if working directory is empty")? + { + anyhow::bail!("working directory is not empty"); + } else { + // ok + } + } + Err(e) => return Err(anyhow::Error::new(e).context("create working directory")), + } + + let pgdata_dir = working_directory.join("pgdata"); + tokio::fs::create_dir(&pgdata_dir) + .await + .context("create pgdata directory")?; + + // + // Setup clients + // + let aws_config = aws_config::load_defaults(BehaviorVersion::v2024_03_28()).await; + let kms_client = aws_sdk_kms::Client::new(&aws_config); + + // + // Initialize pgdata + // + let pg_version = match get_pg_version(pg_bin_dir.as_str()) { + PostgresMajorVersion::V14 => 14, + PostgresMajorVersion::V15 => 15, + PostgresMajorVersion::V16 => 16, + PostgresMajorVersion::V17 => 17, + }; + let superuser = "cloud_admin"; // XXX: this shouldn't be hard-coded + postgres_initdb::do_run_initdb(postgres_initdb::RunInitdbArgs { + superuser, + locale: "en_US.UTF-8", // XXX: this shouldn't be hard-coded, + pg_version, + initdb_bin: pg_bin_dir.join("initdb").as_ref(), + library_search_path: &pg_lib_dir, // TODO: is this right? Prob works in compute image, not sure about neon_local. + pgdata: &pgdata_dir, + }) + .await + .context("initdb")?; + + let nproc = num_cpus::get(); + + // + // Launch postgres process + // + let mut postgres_proc = tokio::process::Command::new(pg_bin_dir.join("postgres")) + .arg("-D") + .arg(&pgdata_dir) + .args(["-c", "wal_level=minimal"]) + .args(["-c", "shared_buffers=10GB"]) + .args(["-c", "max_wal_senders=0"]) + .args(["-c", "fsync=off"]) + .args(["-c", "full_page_writes=off"]) + .args(["-c", "synchronous_commit=off"]) + .args(["-c", "maintenance_work_mem=8388608"]) + .args(["-c", &format!("max_parallel_maintenance_workers={nproc}")]) + .args(["-c", &format!("max_parallel_workers={nproc}")]) + .args(["-c", &format!("max_parallel_workers_per_gather={nproc}")]) + .args(["-c", &format!("max_worker_processes={nproc}")]) + .args(["-c", "effective_io_concurrency=100"]) + .env_clear() + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .spawn() + .context("spawn postgres")?; + + info!("spawned postgres, waiting for it to become ready"); + tokio::spawn( + child_stdio_to_log::relay_process_output( + postgres_proc.stdout.take(), + postgres_proc.stderr.take(), + ) + .instrument(info_span!("postgres")), + ); + let restore_pg_connstring = + format!("host=localhost port=5432 user={superuser} dbname=postgres"); + loop { + let res = tokio_postgres::connect(&restore_pg_connstring, tokio_postgres::NoTls).await; + if res.is_ok() { + info!("postgres is ready, could connect to it"); + break; + } + } + + // + // Decrypt connection string + // + let source_connection_string = { + match spec.encryption_secret { + EncryptionSecret::KMS { key_id } => { + let mut output = kms_client + .decrypt() + .key_id(key_id) + .ciphertext_blob(aws_sdk_s3::primitives::Blob::new( + spec.source_connstring_ciphertext_base64, + )) + .send() + .await + .context("decrypt source connection string")?; + let plaintext = output + .plaintext + .take() + .context("get plaintext source connection string")?; + String::from_utf8(plaintext.into_inner()) + .context("parse source connection string as utf8")? + } + } + }; + + // + // Start the work + // + + let dumpdir = working_directory.join("dumpdir"); + + let common_args = [ + // schema mapping (prob suffices to specify them on one side) + "--no-owner".to_string(), + "--no-privileges".to_string(), + "--no-publications".to_string(), + "--no-security-labels".to_string(), + "--no-subscriptions".to_string(), + "--no-tablespaces".to_string(), + // format + "--format".to_string(), + "directory".to_string(), + // concurrency + "--jobs".to_string(), + num_cpus::get().to_string(), + // progress updates + "--verbose".to_string(), + ]; + + info!("dump into the working directory"); + { + let mut pg_dump = tokio::process::Command::new(pg_bin_dir.join("pg_dump")) + .args(&common_args) + .arg("-f") + .arg(&dumpdir) + .arg("--no-sync") + // POSITIONAL args + // source db (db name included in connection string) + .arg(&source_connection_string) + // how we run it + .env_clear() + .kill_on_drop(true) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .spawn() + .context("spawn pg_dump")?; + + info!(pid=%pg_dump.id().unwrap(), "spawned pg_dump"); + + tokio::spawn( + child_stdio_to_log::relay_process_output(pg_dump.stdout.take(), pg_dump.stderr.take()) + .instrument(info_span!("pg_dump")), + ); + + let st = pg_dump.wait().await.context("wait for pg_dump")?; + info!(status=?st, "pg_dump exited"); + if !st.success() { + warn!(status=%st, "pg_dump failed, restore will likely fail as well"); + } + } + + // TODO: do it in a streaming way, plenty of internal research done on this already + // TODO: do the unlogged table trick + + info!("restore from working directory into vanilla postgres"); + { + let mut pg_restore = tokio::process::Command::new(pg_bin_dir.join("pg_restore")) + .args(&common_args) + .arg("-d") + .arg(&restore_pg_connstring) + // POSITIONAL args + .arg(&dumpdir) + // how we run it + .env_clear() + .kill_on_drop(true) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .spawn() + .context("spawn pg_restore")?; + + info!(pid=%pg_restore.id().unwrap(), "spawned pg_restore"); + tokio::spawn( + child_stdio_to_log::relay_process_output( + pg_restore.stdout.take(), + pg_restore.stderr.take(), + ) + .instrument(info_span!("pg_restore")), + ); + let st = pg_restore.wait().await.context("wait for pg_restore")?; + info!(status=?st, "pg_restore exited"); + if !st.success() { + warn!(status=%st, "pg_restore failed, restore will likely fail as well"); + } + } + + info!("shutdown postgres"); + { + nix::sys::signal::kill( + Pid::from_raw( + i32::try_from(postgres_proc.id().unwrap()).expect("convert child pid to i32"), + ), + nix::sys::signal::SIGTERM, + ) + .context("signal postgres to shut down")?; + postgres_proc + .wait() + .await + .context("wait for postgres to shut down")?; + } + + info!("upload pgdata"); + s5cmd::sync(Utf8Path::new(&pgdata_dir), &s3_prefix.append("/")) + .await + .context("sync dump directory to destination")?; + + info!("write status"); + { + let status_dir = working_directory.join("status"); + std::fs::create_dir(&status_dir).context("create status directory")?; + let status_file = status_dir.join("status"); + std::fs::write(&status_file, serde_json::json!({"done": true}).to_string()) + .context("write status file")?; + s5cmd::sync(&status_file, &s3_prefix.append("/status/pgdata")) + .await + .context("sync status directory to destination")?; + } + + Ok(()) +} diff --git a/compute_tools/src/bin/fast_import/child_stdio_to_log.rs b/compute_tools/src/bin/fast_import/child_stdio_to_log.rs new file mode 100644 index 0000000000..6724ef9bed --- /dev/null +++ b/compute_tools/src/bin/fast_import/child_stdio_to_log.rs @@ -0,0 +1,35 @@ +use tokio::io::{AsyncBufReadExt, BufReader}; +use tokio::process::{ChildStderr, ChildStdout}; +use tracing::info; + +/// Asynchronously relays the output from a child process's `stdout` and `stderr` to the tracing log. +/// Each line is read and logged individually, with lossy UTF-8 conversion. +/// +/// # Arguments +/// +/// * `stdout`: An `Option` from the child process. +/// * `stderr`: An `Option` from the child process. +/// +pub(crate) async fn relay_process_output(stdout: Option, stderr: Option) { + let stdout_fut = async { + if let Some(stdout) = stdout { + let reader = BufReader::new(stdout); + let mut lines = reader.lines(); + while let Ok(Some(line)) = lines.next_line().await { + info!(fd = "stdout", "{}", line); + } + } + }; + + let stderr_fut = async { + if let Some(stderr) = stderr { + let reader = BufReader::new(stderr); + let mut lines = reader.lines(); + while let Ok(Some(line)) = lines.next_line().await { + info!(fd = "stderr", "{}", line); + } + } + }; + + tokio::join!(stdout_fut, stderr_fut); +} diff --git a/compute_tools/src/bin/fast_import/s3_uri.rs b/compute_tools/src/bin/fast_import/s3_uri.rs new file mode 100644 index 0000000000..52bbef420f --- /dev/null +++ b/compute_tools/src/bin/fast_import/s3_uri.rs @@ -0,0 +1,75 @@ +use anyhow::Result; +use std::str::FromStr; + +/// Struct to hold parsed S3 components +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct S3Uri { + pub bucket: String, + pub key: String, +} + +impl FromStr for S3Uri { + type Err = anyhow::Error; + + /// Parse an S3 URI into a bucket and key + fn from_str(uri: &str) -> Result { + // Ensure the URI starts with "s3://" + if !uri.starts_with("s3://") { + return Err(anyhow::anyhow!("Invalid S3 URI scheme")); + } + + // Remove the "s3://" prefix + let stripped_uri = &uri[5..]; + + // Split the remaining string into bucket and key parts + if let Some((bucket, key)) = stripped_uri.split_once('/') { + Ok(S3Uri { + bucket: bucket.to_string(), + key: key.to_string(), + }) + } else { + Err(anyhow::anyhow!( + "Invalid S3 URI format, missing bucket or key" + )) + } + } +} + +impl S3Uri { + pub fn append(&self, suffix: &str) -> Self { + Self { + bucket: self.bucket.clone(), + key: format!("{}{}", self.key, suffix), + } + } +} + +impl std::fmt::Display for S3Uri { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "s3://{}/{}", self.bucket, self.key) + } +} + +impl clap::builder::TypedValueParser for S3Uri { + type Value = Self; + + fn parse_ref( + &self, + _cmd: &clap::Command, + _arg: Option<&clap::Arg>, + value: &std::ffi::OsStr, + ) -> Result { + let value_str = value.to_str().ok_or_else(|| { + clap::Error::raw( + clap::error::ErrorKind::InvalidUtf8, + "Invalid UTF-8 sequence", + ) + })?; + S3Uri::from_str(value_str).map_err(|e| { + clap::Error::raw( + clap::error::ErrorKind::InvalidValue, + format!("Failed to parse S3 URI: {}", e), + ) + }) + } +} diff --git a/compute_tools/src/bin/fast_import/s5cmd.rs b/compute_tools/src/bin/fast_import/s5cmd.rs new file mode 100644 index 0000000000..d2d9a79736 --- /dev/null +++ b/compute_tools/src/bin/fast_import/s5cmd.rs @@ -0,0 +1,27 @@ +use anyhow::Context; +use camino::Utf8Path; + +use super::s3_uri::S3Uri; + +pub(crate) async fn sync(local: &Utf8Path, remote: &S3Uri) -> anyhow::Result<()> { + let mut builder = tokio::process::Command::new("s5cmd"); + // s5cmd uses aws-sdk-go v1, hence doesn't support AWS_ENDPOINT_URL + if let Some(val) = std::env::var_os("AWS_ENDPOINT_URL") { + builder.arg("--endpoint-url").arg(val); + } + builder + .arg("sync") + .arg(local.as_str()) + .arg(remote.to_string()); + let st = builder + .spawn() + .context("spawn s5cmd")? + .wait() + .await + .context("wait for s5cmd")?; + if st.success() { + Ok(()) + } else { + Err(anyhow::anyhow!("s5cmd failed")) + } +} diff --git a/compute_tools/src/catalog.rs b/compute_tools/src/catalog.rs index 2f6f82dd39..08ae8bf44d 100644 --- a/compute_tools/src/catalog.rs +++ b/compute_tools/src/catalog.rs @@ -1,4 +1,3 @@ -use compute_api::responses::CatalogObjects; use futures::Stream; use postgres::NoTls; use std::{path::Path, process::Stdio, result::Result, sync::Arc}; @@ -13,7 +12,8 @@ use tokio_util::codec::{BytesCodec, FramedRead}; use tracing::warn; use crate::compute::ComputeNode; -use crate::pg_helpers::{get_existing_dbs_async, get_existing_roles_async}; +use crate::pg_helpers::{get_existing_dbs_async, get_existing_roles_async, postgres_conf_for_db}; +use compute_api::responses::CatalogObjects; pub async fn get_dbs_and_roles(compute: &Arc) -> anyhow::Result { let connstr = compute.connstr.clone(); @@ -43,6 +43,8 @@ pub enum SchemaDumpError { DatabaseDoesNotExist, #[error("Failed to execute pg_dump.")] IO(#[from] std::io::Error), + #[error("Unexpected error.")] + Unexpected, } // It uses the pg_dump utility to dump the schema of the specified database. @@ -60,11 +62,38 @@ pub async fn get_database_schema( let pgbin = &compute.pgbin; let basepath = Path::new(pgbin).parent().unwrap(); let pgdump = basepath.join("pg_dump"); - let mut connstr = compute.connstr.clone(); - connstr.set_path(dbname); + + // Replace the DB in the connection string and disable it to parts. + // This is the only option to handle DBs with special characters. + let conf = + postgres_conf_for_db(&compute.connstr, dbname).map_err(|_| SchemaDumpError::Unexpected)?; + let host = conf + .get_hosts() + .first() + .ok_or(SchemaDumpError::Unexpected)?; + let host = match host { + tokio_postgres::config::Host::Tcp(ip) => ip.to_string(), + #[cfg(unix)] + tokio_postgres::config::Host::Unix(path) => path.to_string_lossy().to_string(), + }; + let port = conf + .get_ports() + .first() + .ok_or(SchemaDumpError::Unexpected)?; + let user = conf.get_user().ok_or(SchemaDumpError::Unexpected)?; + let dbname = conf.get_dbname().ok_or(SchemaDumpError::Unexpected)?; + let mut cmd = Command::new(pgdump) + // XXX: this seems to be the only option to deal with DBs with `=` in the name + // See + .env("PGDATABASE", dbname) + .arg("--host") + .arg(host) + .arg("--port") + .arg(port.to_string()) + .arg("--username") + .arg(user) .arg("--schema-only") - .arg(connstr.as_str()) .stdout(Stdio::piped()) .stderr(Stdio::piped()) .kill_on_drop(true) diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index 4f67425ba8..1a026a4014 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -34,9 +34,8 @@ use utils::measured_stream::MeasuredReader; use nix::sys::signal::{kill, Signal}; use remote_storage::{DownloadError, RemotePath}; use tokio::spawn; -use url::Url; -use crate::installed_extensions::get_installed_extensions_sync; +use crate::installed_extensions::get_installed_extensions; use crate::local_proxy; use crate::pg_helpers::*; use crate::spec::*; @@ -816,30 +815,32 @@ impl ComputeNode { Ok(()) } - async fn get_maintenance_client(url: &Url) -> Result { - let mut connstr = url.clone(); + async fn get_maintenance_client( + conf: &tokio_postgres::Config, + ) -> Result { + let mut conf = conf.clone(); - connstr - .query_pairs_mut() - .append_pair("application_name", "apply_config"); + conf.application_name("apply_config"); - let (client, conn) = match tokio_postgres::connect(connstr.as_str(), NoTls).await { + let (client, conn) = match conf.connect(NoTls).await { + // If connection fails, it may be the old node with `zenith_admin` superuser. + // + // In this case we need to connect with old `zenith_admin` name + // and create new user. We cannot simply rename connected user, + // but we can create a new one and grant it all privileges. Err(e) => match e.code() { Some(&SqlState::INVALID_PASSWORD) | Some(&SqlState::INVALID_AUTHORIZATION_SPECIFICATION) => { - // connect with zenith_admin if cloud_admin could not authenticate + // Connect with zenith_admin if cloud_admin could not authenticate info!( "cannot connect to postgres: {}, retrying with `zenith_admin` username", e ); - let mut zenith_admin_connstr = connstr.clone(); - - zenith_admin_connstr - .set_username("zenith_admin") - .map_err(|_| anyhow::anyhow!("invalid connstr"))?; + let mut zenith_admin_conf = postgres::config::Config::from(conf.clone()); + zenith_admin_conf.user("zenith_admin"); let mut client = - Client::connect(zenith_admin_connstr.as_str(), NoTls) + zenith_admin_conf.connect(NoTls) .context("broken cloud_admin credential: tried connecting with cloud_admin but could not authenticate, and zenith_admin does not work either")?; // Disable forwarding so that users don't get a cloud_admin role @@ -853,8 +854,8 @@ impl ComputeNode { drop(client); - // reconnect with connstring with expected name - tokio_postgres::connect(connstr.as_str(), NoTls).await? + // Reconnect with connstring with expected name + conf.connect(NoTls).await? } _ => return Err(e.into()), }, @@ -885,7 +886,7 @@ impl ComputeNode { pub fn apply_spec_sql( &self, spec: Arc, - url: Arc, + conf: Arc, concurrency: usize, ) -> Result<()> { let rt = tokio::runtime::Builder::new_multi_thread() @@ -897,7 +898,7 @@ impl ComputeNode { rt.block_on(async { // Proceed with post-startup configuration. Note, that order of operations is important. - let client = Self::get_maintenance_client(&url).await?; + let client = Self::get_maintenance_client(&conf).await?; let spec = spec.clone(); let databases = get_existing_dbs_async(&client).await?; @@ -931,7 +932,7 @@ impl ComputeNode { RenameAndDeleteDatabases, CreateAndAlterDatabases, ] { - debug!("Applying phase {:?}", &phase); + info!("Applying phase {:?}", &phase); apply_operations( spec.clone(), ctx.clone(), @@ -942,6 +943,7 @@ impl ComputeNode { .await?; } + info!("Applying RunInEachDatabase phase"); let concurrency_token = Arc::new(tokio::sync::Semaphore::new(concurrency)); let db_processes = spec @@ -955,7 +957,7 @@ impl ComputeNode { let spec = spec.clone(); let ctx = ctx.clone(); let jwks_roles = jwks_roles.clone(); - let mut url = url.as_ref().clone(); + let mut conf = conf.as_ref().clone(); let concurrency_token = concurrency_token.clone(); let db = db.clone(); @@ -964,14 +966,14 @@ impl ComputeNode { match &db { DB::SystemDB => {} DB::UserDB(db) => { - url.set_path(db.name.as_str()); + conf.dbname(db.name.as_str()); } } - let url = Arc::new(url); + let conf = Arc::new(conf); let fut = Self::apply_spec_sql_db( spec.clone(), - url, + conf, ctx.clone(), jwks_roles.clone(), concurrency_token.clone(), @@ -1017,7 +1019,7 @@ impl ComputeNode { /// semaphore. The caller has to make sure the semaphore isn't exhausted. async fn apply_spec_sql_db( spec: Arc, - url: Arc, + conf: Arc, ctx: Arc>, jwks_roles: Arc>, concurrency_token: Arc, @@ -1046,7 +1048,7 @@ impl ComputeNode { // that database. || async { if client_conn.is_none() { - let db_client = Self::get_maintenance_client(&url).await?; + let db_client = Self::get_maintenance_client(&conf).await?; client_conn.replace(db_client); } let client = client_conn.as_ref().unwrap(); @@ -1061,34 +1063,16 @@ impl ComputeNode { Ok::<(), anyhow::Error>(()) } - /// Do initial configuration of the already started Postgres. - #[instrument(skip_all)] - pub fn apply_config(&self, compute_state: &ComputeState) -> Result<()> { - // If connection fails, - // it may be the old node with `zenith_admin` superuser. - // - // In this case we need to connect with old `zenith_admin` name - // and create new user. We cannot simply rename connected user, - // but we can create a new one and grant it all privileges. - let mut url = self.connstr.clone(); - url.query_pairs_mut() - .append_pair("application_name", "apply_config"); - - let url = Arc::new(url); - let spec = Arc::new( - compute_state - .pspec - .as_ref() - .expect("spec must be set") - .spec - .clone(), - ); - - // Choose how many concurrent connections to use for applying the spec changes. - // If the cluster is not currently Running we don't have to deal with user connections, + /// Choose how many concurrent connections to use for applying the spec changes. + pub fn max_service_connections( + &self, + compute_state: &ComputeState, + spec: &ComputeSpec, + ) -> usize { + // If the cluster is in Init state we don't have to deal with user connections, // and can thus use all `max_connections` connection slots. However, that's generally not // very efficient, so we generally still limit it to a smaller number. - let max_concurrent_connections = if compute_state.status != ComputeStatus::Running { + if compute_state.status == ComputeStatus::Init { // If the settings contain 'max_connections', use that as template if let Some(config) = spec.cluster.settings.find("max_connections") { config.parse::().ok() @@ -1144,10 +1128,29 @@ impl ComputeNode { .map(|val| if val > 1 { val - 1 } else { 1 }) .last() .unwrap_or(3) - }; + } + } + + /// Do initial configuration of the already started Postgres. + #[instrument(skip_all)] + pub fn apply_config(&self, compute_state: &ComputeState) -> Result<()> { + let mut conf = tokio_postgres::Config::from_str(self.connstr.as_str()).unwrap(); + conf.application_name("apply_config"); + + let conf = Arc::new(conf); + let spec = Arc::new( + compute_state + .pspec + .as_ref() + .expect("spec must be set") + .spec + .clone(), + ); + + let max_concurrent_connections = self.max_service_connections(compute_state, &spec); // Merge-apply spec & changes to PostgreSQL state. - self.apply_spec_sql(spec.clone(), url.clone(), max_concurrent_connections)?; + self.apply_spec_sql(spec.clone(), conf.clone(), max_concurrent_connections)?; if let Some(ref local_proxy) = &spec.clone().local_proxy_config { info!("configuring local_proxy"); @@ -1156,12 +1159,11 @@ impl ComputeNode { // Run migrations separately to not hold up cold starts thread::spawn(move || { - let mut connstr = url.as_ref().clone(); - connstr - .query_pairs_mut() - .append_pair("application_name", "migrations"); + let conf = conf.as_ref().clone(); + let mut conf = postgres::config::Config::from(conf); + conf.application_name("migrations"); - let mut client = Client::connect(connstr.as_str(), NoTls)?; + let mut client = conf.connect(NoTls)?; handle_migrations(&mut client).context("apply_config handle_migrations") }); @@ -1222,21 +1224,28 @@ impl ComputeNode { let pgdata_path = Path::new(&self.pgdata); let postgresql_conf_path = pgdata_path.join("postgresql.conf"); config::write_postgres_conf(&postgresql_conf_path, &spec, None)?; - // temporarily reset max_cluster_size in config + + // TODO(ololobus): We need a concurrency during reconfiguration as well, + // but DB is already running and used by user. We can easily get out of + // `max_connections` limit, and the current code won't handle that. + // let compute_state = self.state.lock().unwrap().clone(); + // let max_concurrent_connections = self.max_service_connections(&compute_state, &spec); + let max_concurrent_connections = 1; + + // Temporarily reset max_cluster_size in config // to avoid the possibility of hitting the limit, while we are reconfiguring: - // creating new extensions, roles, etc... + // creating new extensions, roles, etc. config::with_compute_ctl_tmp_override(pgdata_path, "neon.max_cluster_size=-1", || { self.pg_reload_conf()?; if spec.mode == ComputeMode::Primary { - let mut url = self.connstr.clone(); - url.query_pairs_mut() - .append_pair("application_name", "apply_config"); - let url = Arc::new(url); + let mut conf = tokio_postgres::Config::from_str(self.connstr.as_str()).unwrap(); + conf.application_name("apply_config"); + let conf = Arc::new(conf); let spec = Arc::new(spec.clone()); - self.apply_spec_sql(spec, url, 1)?; + self.apply_spec_sql(spec, conf, max_concurrent_connections)?; } Ok(()) @@ -1362,7 +1371,17 @@ impl ComputeNode { let connstr = self.connstr.clone(); thread::spawn(move || { - get_installed_extensions_sync(connstr).context("get_installed_extensions") + let res = get_installed_extensions(&connstr); + match res { + Ok(extensions) => { + info!( + "[NEON_EXT_STAT] {}", + serde_json::to_string(&extensions) + .expect("failed to serialize extensions list") + ); + } + Err(err) => error!("could not get installed extensions: {err:?}"), + } }); } diff --git a/compute_tools/src/config.rs b/compute_tools/src/config.rs index d4e413034e..d65fe73194 100644 --- a/compute_tools/src/config.rs +++ b/compute_tools/src/config.rs @@ -116,7 +116,7 @@ pub fn write_postgres_conf( vartype: "enum".to_owned(), }; - write!(file, "{}", opt.to_pg_setting())?; + writeln!(file, "{}", opt.to_pg_setting())?; } } diff --git a/compute_tools/src/extension_server.rs b/compute_tools/src/extension_server.rs index da2d107b54..f13b2308e7 100644 --- a/compute_tools/src/extension_server.rs +++ b/compute_tools/src/extension_server.rs @@ -103,14 +103,33 @@ fn get_pg_config(argument: &str, pgbin: &str) -> String { .to_string() } -pub fn get_pg_version(pgbin: &str) -> String { +pub fn get_pg_version(pgbin: &str) -> PostgresMajorVersion { // pg_config --version returns a (platform specific) human readable string // such as "PostgreSQL 15.4". We parse this to v14/v15/v16 etc. let human_version = get_pg_config("--version", pgbin); - parse_pg_version(&human_version).to_string() + parse_pg_version(&human_version) } -fn parse_pg_version(human_version: &str) -> &str { +pub fn get_pg_version_string(pgbin: &str) -> String { + match get_pg_version(pgbin) { + PostgresMajorVersion::V14 => "v14", + PostgresMajorVersion::V15 => "v15", + PostgresMajorVersion::V16 => "v16", + PostgresMajorVersion::V17 => "v17", + } + .to_owned() +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum PostgresMajorVersion { + V14, + V15, + V16, + V17, +} + +fn parse_pg_version(human_version: &str) -> PostgresMajorVersion { + use PostgresMajorVersion::*; // Normal releases have version strings like "PostgreSQL 15.4". But there // are also pre-release versions like "PostgreSQL 17devel" or "PostgreSQL // 16beta2" or "PostgreSQL 17rc1". And with the --with-extra-version @@ -121,10 +140,10 @@ fn parse_pg_version(human_version: &str) -> &str { .captures(human_version) { Some(captures) if captures.len() == 2 => match &captures["major"] { - "14" => return "v14", - "15" => return "v15", - "16" => return "v16", - "17" => return "v17", + "14" => return V14, + "15" => return V15, + "16" => return V16, + "17" => return V17, _ => {} }, _ => {} @@ -263,24 +282,25 @@ mod tests { #[test] fn test_parse_pg_version() { - assert_eq!(parse_pg_version("PostgreSQL 15.4"), "v15"); - assert_eq!(parse_pg_version("PostgreSQL 15.14"), "v15"); + use super::PostgresMajorVersion::*; + assert_eq!(parse_pg_version("PostgreSQL 15.4"), V15); + assert_eq!(parse_pg_version("PostgreSQL 15.14"), V15); assert_eq!( parse_pg_version("PostgreSQL 15.4 (Ubuntu 15.4-0ubuntu0.23.04.1)"), - "v15" + V15 ); - assert_eq!(parse_pg_version("PostgreSQL 14.15"), "v14"); - assert_eq!(parse_pg_version("PostgreSQL 14.0"), "v14"); + assert_eq!(parse_pg_version("PostgreSQL 14.15"), V14); + assert_eq!(parse_pg_version("PostgreSQL 14.0"), V14); assert_eq!( parse_pg_version("PostgreSQL 14.9 (Debian 14.9-1.pgdg120+1"), - "v14" + V14 ); - assert_eq!(parse_pg_version("PostgreSQL 16devel"), "v16"); - assert_eq!(parse_pg_version("PostgreSQL 16beta1"), "v16"); - assert_eq!(parse_pg_version("PostgreSQL 16rc2"), "v16"); - assert_eq!(parse_pg_version("PostgreSQL 16extra"), "v16"); + assert_eq!(parse_pg_version("PostgreSQL 16devel"), V16); + assert_eq!(parse_pg_version("PostgreSQL 16beta1"), V16); + assert_eq!(parse_pg_version("PostgreSQL 16rc2"), V16); + assert_eq!(parse_pg_version("PostgreSQL 16extra"), V16); } #[test] diff --git a/compute_tools/src/http/api.rs b/compute_tools/src/http/api.rs index 3677582c11..a6c6cff20a 100644 --- a/compute_tools/src/http/api.rs +++ b/compute_tools/src/http/api.rs @@ -20,6 +20,7 @@ use anyhow::Result; use hyper::header::CONTENT_TYPE; use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Method, Request, Response, Server, StatusCode}; +use metrics::proto::MetricFamily; use metrics::Encoder; use metrics::TextEncoder; use tokio::task; @@ -72,10 +73,22 @@ async fn routes(req: Request, compute: &Arc) -> Response { debug!("serving /metrics GET request"); - let mut buffer = vec![]; - let metrics = installed_extensions::collect(); + // When we call TextEncoder::encode() below, it will immediately + // return an error if a metric family has no metrics, so we need to + // preemptively filter out metric families with no metrics. + let metrics = installed_extensions::collect() + .into_iter() + .filter(|m| !m.get_metric().is_empty()) + .collect::>(); + let encoder = TextEncoder::new(); - encoder.encode(&metrics, &mut buffer).unwrap(); + let mut buffer = vec![]; + + if let Err(err) = encoder.encode(&metrics, &mut buffer) { + let msg = format!("error handling /metrics request: {err}"); + error!(msg); + return render_json_error(&msg, StatusCode::INTERNAL_SERVER_ERROR); + } match Response::builder() .status(StatusCode::OK) @@ -283,7 +296,12 @@ async fn routes(req: Request, compute: &Arc) -> Response render_json(Body::from(serde_json::to_string(&res).unwrap())), Err(e) => render_json_error( diff --git a/compute_tools/src/installed_extensions.rs b/compute_tools/src/installed_extensions.rs index 6dd55855db..f473c29a55 100644 --- a/compute_tools/src/installed_extensions.rs +++ b/compute_tools/src/installed_extensions.rs @@ -2,17 +2,16 @@ use compute_api::responses::{InstalledExtension, InstalledExtensions}; use metrics::proto::MetricFamily; use std::collections::HashMap; use std::collections::HashSet; -use tracing::info; -use url::Url; use anyhow::Result; use postgres::{Client, NoTls}; -use tokio::task; use metrics::core::Collector; use metrics::{register_uint_gauge_vec, UIntGaugeVec}; use once_cell::sync::Lazy; +use crate::pg_helpers::postgres_conf_for_db; + /// We don't reuse get_existing_dbs() just for code clarity /// and to make database listing query here more explicit. /// @@ -42,80 +41,56 @@ fn list_dbs(client: &mut Client) -> Result> { /// /// Same extension can be installed in multiple databases with different versions, /// we only keep the highest and lowest version across all databases. -pub async fn get_installed_extensions(connstr: Url) -> Result { - let mut connstr = connstr.clone(); +pub fn get_installed_extensions(connstr: &url::Url) -> Result { + let mut client = Client::connect(connstr.as_str(), NoTls)?; + let databases: Vec = list_dbs(&mut client)?; - task::spawn_blocking(move || { - let mut client = Client::connect(connstr.as_str(), NoTls)?; - let databases: Vec = list_dbs(&mut client)?; + let mut extensions_map: HashMap = HashMap::new(); + for db in databases.iter() { + let config = postgres_conf_for_db(connstr, db)?; + let mut db_client = config.connect(NoTls)?; + let extensions: Vec<(String, String)> = db_client + .query( + "SELECT extname, extversion FROM pg_catalog.pg_extension;", + &[], + )? + .iter() + .map(|row| (row.get("extname"), row.get("extversion"))) + .collect(); - let mut extensions_map: HashMap = HashMap::new(); - for db in databases.iter() { - connstr.set_path(db); - let mut db_client = Client::connect(connstr.as_str(), NoTls)?; - let extensions: Vec<(String, String)> = db_client - .query( - "SELECT extname, extversion FROM pg_catalog.pg_extension;", - &[], - )? - .iter() - .map(|row| (row.get("extname"), row.get("extversion"))) - .collect(); + for (extname, v) in extensions.iter() { + let version = v.to_string(); - for (extname, v) in extensions.iter() { - let version = v.to_string(); + // increment the number of databases where the version of extension is installed + INSTALLED_EXTENSIONS + .with_label_values(&[extname, &version]) + .inc(); - // increment the number of databases where the version of extension is installed - INSTALLED_EXTENSIONS - .with_label_values(&[extname, &version]) - .inc(); - - extensions_map - .entry(extname.to_string()) - .and_modify(|e| { - e.versions.insert(version.clone()); - // count the number of databases where the extension is installed - e.n_databases += 1; - }) - .or_insert(InstalledExtension { - extname: extname.to_string(), - versions: HashSet::from([version.clone()]), - n_databases: 1, - }); - } + extensions_map + .entry(extname.to_string()) + .and_modify(|e| { + e.versions.insert(version.clone()); + // count the number of databases where the extension is installed + e.n_databases += 1; + }) + .or_insert(InstalledExtension { + extname: extname.to_string(), + versions: HashSet::from([version.clone()]), + n_databases: 1, + }); } + } - let res = InstalledExtensions { - extensions: extensions_map.values().cloned().collect(), - }; + let res = InstalledExtensions { + extensions: extensions_map.values().cloned().collect(), + }; - Ok(res) - }) - .await? -} - -// Gather info about installed extensions -pub fn get_installed_extensions_sync(connstr: Url) -> Result<()> { - let rt = tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("failed to create runtime"); - let result = rt - .block_on(crate::installed_extensions::get_installed_extensions( - connstr, - )) - .expect("failed to get installed extensions"); - - info!( - "[NEON_EXT_STAT] {}", - serde_json::to_string(&result).expect("failed to serialize extensions list") - ); - Ok(()) + Ok(res) } static INSTALLED_EXTENSIONS: Lazy = Lazy::new(|| { register_uint_gauge_vec!( - "installed_extensions", + "compute_installed_extensions", "Number of databases where the version of extension is installed", &["extension_name", "version"] ) diff --git a/compute_tools/src/pg_helpers.rs b/compute_tools/src/pg_helpers.rs index 4a1e5ee0e8..e03b410699 100644 --- a/compute_tools/src/pg_helpers.rs +++ b/compute_tools/src/pg_helpers.rs @@ -6,6 +6,7 @@ use std::io::{BufRead, BufReader}; use std::os::unix::fs::PermissionsExt; use std::path::Path; use std::process::Child; +use std::str::FromStr; use std::thread::JoinHandle; use std::time::{Duration, Instant}; @@ -13,8 +14,10 @@ use anyhow::{bail, Result}; use futures::StreamExt; use ini::Ini; use notify::{RecursiveMode, Watcher}; +use postgres::config::Config; use tokio::io::AsyncBufReadExt; use tokio::time::timeout; +use tokio_postgres; use tokio_postgres::NoTls; use tracing::{debug, error, info, instrument}; @@ -542,3 +545,11 @@ async fn handle_postgres_logs_async(stderr: tokio::process::ChildStderr) -> Resu Ok(()) } + +/// `Postgres::config::Config` handles database names with whitespaces +/// and special characters properly. +pub fn postgres_conf_for_db(connstr: &url::Url, dbname: &str) -> Result { + let mut conf = Config::from_str(connstr.as_str())?; + conf.dbname(dbname); + Ok(conf) +} diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index c4063bbd1a..1ea443b026 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -1153,6 +1153,7 @@ async fn handle_timeline(cmd: &TimelineCmd, env: &mut local_env::LocalEnv) -> Re timeline_info.timeline_id ); } + // TODO: rename to import-basebackup-plus-wal TimelineCmd::Import(args) => { let tenant_id = get_tenant_id(args.tenant_id, env)?; let timeline_id = args.timeline_id; diff --git a/control_plane/src/pageserver.rs b/control_plane/src/pageserver.rs index ae5e22ddc6..1d1455b95b 100644 --- a/control_plane/src/pageserver.rs +++ b/control_plane/src/pageserver.rs @@ -415,6 +415,11 @@ impl PageServerNode { .map(|x| x.parse::()) .transpose() .context("Failed to parse 'timeline_offloading' as bool")?, + wal_receiver_protocol_override: settings + .remove("wal_receiver_protocol_override") + .map(serde_json::from_str) + .transpose() + .context("parse `wal_receiver_protocol_override` from json")?, }; if !settings.is_empty() { bail!("Unrecognized tenant settings: {settings:?}") diff --git a/deny.toml b/deny.toml index 8bf643f4ba..7a1eecac99 100644 --- a/deny.toml +++ b/deny.toml @@ -33,7 +33,6 @@ reason = "the marvin attack only affects private key decryption, not public key [licenses] allow = [ "Apache-2.0", - "Artistic-2.0", "BSD-2-Clause", "BSD-3-Clause", "CC0-1.0", @@ -67,7 +66,7 @@ registries = [] # More documentation about the 'bans' section can be found here: # https://embarkstudios.github.io/cargo-deny/checks/bans/cfg.html [bans] -multiple-versions = "warn" +multiple-versions = "allow" wildcards = "allow" highlight = "all" workspace-default-features = "allow" diff --git a/docs/sourcetree.md b/docs/sourcetree.md index 3732bfdab2..1f7e913c07 100644 --- a/docs/sourcetree.md +++ b/docs/sourcetree.md @@ -113,21 +113,21 @@ so manual installation of dependencies is not recommended. A single virtual environment with all dependencies is described in the single `Pipfile`. ### Prerequisites -- Install Python 3.9 (the minimal supported version) or greater. +- Install Python 3.11 (the minimal supported version) or greater. - Our setup with poetry should work with newer python versions too. So feel free to open an issue with a `c/test-runner` label if something doesn't work as expected. - - If you have some trouble with other version you can resolve it by installing Python 3.9 separately, via [pyenv](https://github.com/pyenv/pyenv) or via system package manager e.g.: + - If you have some trouble with other version you can resolve it by installing Python 3.11 separately, via [pyenv](https://github.com/pyenv/pyenv) or via system package manager e.g.: ```bash # In Ubuntu sudo add-apt-repository ppa:deadsnakes/ppa sudo apt update - sudo apt install python3.9 + sudo apt install python3.11 ``` - Install `poetry` - Exact version of `poetry` is not important, see installation instructions available at poetry's [website](https://python-poetry.org/docs/#installation). - Install dependencies via `./scripts/pysync`. - Note that CI uses specific Python version (look for `PYTHON_VERSION` [here](https://github.com/neondatabase/docker-images/blob/main/rust/Dockerfile)) so if you have different version some linting tools can yield different result locally vs in the CI. - - You can explicitly specify which Python to use by running `poetry env use /path/to/python`, e.g. `poetry env use python3.9`. + - You can explicitly specify which Python to use by running `poetry env use /path/to/python`, e.g. `poetry env use python3.11`. This may also disable the `The currently activated Python version X.Y.Z is not supported by the project` warning. Run `poetry shell` to activate the virtual environment. diff --git a/libs/pageserver_api/Cargo.toml b/libs/pageserver_api/Cargo.toml index 8710904cec..79da05da6c 100644 --- a/libs/pageserver_api/Cargo.toml +++ b/libs/pageserver_api/Cargo.toml @@ -33,6 +33,7 @@ remote_storage.workspace = true postgres_backend.workspace = true nix = {workspace = true, optional = true} reqwest.workspace = true +rand.workspace = true [dev-dependencies] bincode.workspace = true diff --git a/libs/pageserver_api/src/config.rs b/libs/pageserver_api/src/config.rs index a0a6dedcdd..50e0e9e504 100644 --- a/libs/pageserver_api/src/config.rs +++ b/libs/pageserver_api/src/config.rs @@ -18,7 +18,7 @@ use std::{ str::FromStr, time::Duration, }; -use utils::logging::LogFormat; +use utils::{logging::LogFormat, postgres_client::PostgresClientProtocol}; use crate::models::ImageCompressionAlgorithm; use crate::models::LsnLease; @@ -97,6 +97,15 @@ pub struct ConfigToml { pub control_plane_api: Option, pub control_plane_api_token: Option, pub control_plane_emergency_mode: bool, + /// Unstable feature: subject to change or removal without notice. + /// See . + pub import_pgdata_upcall_api: Option, + /// Unstable feature: subject to change or removal without notice. + /// See . + pub import_pgdata_upcall_api_token: Option, + /// Unstable feature: subject to change or removal without notice. + /// See . + pub import_pgdata_aws_endpoint_url: Option, pub heatmap_upload_concurrency: usize, pub secondary_download_concurrency: usize, pub virtual_file_io_engine: Option, @@ -109,7 +118,8 @@ pub struct ConfigToml { pub virtual_file_io_mode: Option, #[serde(skip_serializing_if = "Option::is_none")] pub no_sync: Option, - pub page_service_pipelining: Option, + pub wal_receiver_protocol: PostgresClientProtocol, + pub page_service_pipelining: PageServicePipeliningConfig, } #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -127,16 +137,23 @@ pub struct DiskUsageEvictionTaskConfig { } #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[serde(tag = "mode", rename_all = "kebab-case")] #[serde(deny_unknown_fields)] -pub struct PageServicePipeliningConfig { +pub enum PageServicePipeliningConfig { + Serial, + Pipelined(PageServicePipeliningConfigPipelined), +} +#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[serde(deny_unknown_fields)] +pub struct PageServicePipeliningConfigPipelined { /// Causes runtime errors if larger than max get_vectored batch size. pub max_batch_size: NonZeroUsize, - pub protocol_pipelining_mode: PageServiceProtocolPipeliningMode, + pub execution: PageServiceProtocolPipelinedExecutionStrategy, } #[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] #[serde(rename_all = "kebab-case")] -pub enum PageServiceProtocolPipeliningMode { +pub enum PageServiceProtocolPipelinedExecutionStrategy { ConcurrentFutures, Tasks, } @@ -282,6 +299,8 @@ pub struct TenantConfigToml { /// Enable auto-offloading of timelines. /// (either this flag or the pageserver-global one need to be set) pub timeline_offloading: bool, + + pub wal_receiver_protocol_override: Option, } pub mod defaults { @@ -333,6 +352,9 @@ pub mod defaults { pub const DEFAULT_EPHEMERAL_BYTES_PER_MEMORY_KB: usize = 0; pub const DEFAULT_IO_BUFFER_ALIGNMENT: usize = 512; + + pub const DEFAULT_WAL_RECEIVER_PROTOCOL: utils::postgres_client::PostgresClientProtocol = + utils::postgres_client::PostgresClientProtocol::Vanilla; } impl Default for ConfigToml { @@ -398,6 +420,10 @@ impl Default for ConfigToml { control_plane_api_token: (None), control_plane_emergency_mode: (false), + import_pgdata_upcall_api: (None), + import_pgdata_upcall_api_token: (None), + import_pgdata_aws_endpoint_url: (None), + heatmap_upload_concurrency: (DEFAULT_HEATMAP_UPLOAD_CONCURRENCY), secondary_download_concurrency: (DEFAULT_SECONDARY_DOWNLOAD_CONCURRENCY), @@ -415,10 +441,13 @@ impl Default for ConfigToml { virtual_file_io_mode: None, tenant_config: TenantConfigToml::default(), no_sync: None, - page_service_pipelining: Some(PageServicePipeliningConfig { - max_batch_size: NonZeroUsize::new(32).unwrap(), - protocol_pipelining_mode: PageServiceProtocolPipeliningMode::ConcurrentFutures, - }), + wal_receiver_protocol: DEFAULT_WAL_RECEIVER_PROTOCOL, + page_service_pipelining: PageServicePipeliningConfig::Pipelined( + PageServicePipeliningConfigPipelined { + max_batch_size: NonZeroUsize::new(32).unwrap(), + execution: PageServiceProtocolPipelinedExecutionStrategy::ConcurrentFutures, + }, + ), } } } @@ -506,6 +535,7 @@ impl Default for TenantConfigToml { lsn_lease_length: LsnLease::DEFAULT_LENGTH, lsn_lease_length_for_ts: LsnLease::DEFAULT_LENGTH_FOR_TS, timeline_offloading: false, + wal_receiver_protocol_override: None, } } } diff --git a/libs/pageserver_api/src/key.rs b/libs/pageserver_api/src/key.rs index 4505101ea6..523d143381 100644 --- a/libs/pageserver_api/src/key.rs +++ b/libs/pageserver_api/src/key.rs @@ -229,6 +229,18 @@ impl Key { } } +impl CompactKey { + pub fn raw(&self) -> i128 { + self.0 + } +} + +impl From for CompactKey { + fn from(value: i128) -> Self { + Self(value) + } +} + impl fmt::Display for Key { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( diff --git a/libs/pageserver_api/src/keyspace.rs b/libs/pageserver_api/src/keyspace.rs index 401887d362..c55b9e9484 100644 --- a/libs/pageserver_api/src/keyspace.rs +++ b/libs/pageserver_api/src/keyspace.rs @@ -48,7 +48,7 @@ pub struct ShardedRange<'a> { // Calculate the size of a range within the blocks of the same relation, or spanning only the // top page in the previous relation's space. -fn contiguous_range_len(range: &Range) -> u32 { +pub fn contiguous_range_len(range: &Range) -> u32 { debug_assert!(is_contiguous_range(range)); if range.start.field6 == 0xffffffff { range.end.field6 + 1 @@ -67,7 +67,7 @@ fn contiguous_range_len(range: &Range) -> u32 { /// This matters, because: /// - Within such ranges, keys are used contiguously. Outside such ranges it is sparse. /// - Within such ranges, we may calculate distances using simple subtraction of field6. -fn is_contiguous_range(range: &Range) -> bool { +pub fn is_contiguous_range(range: &Range) -> bool { range.start.field1 == range.end.field1 && range.start.field2 == range.end.field2 && range.start.field3 == range.end.field3 diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index 0dfa1ba817..42c5d10c05 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -2,6 +2,8 @@ pub mod detach_ancestor; pub mod partitioning; pub mod utilization; +#[cfg(feature = "testing")] +use camino::Utf8PathBuf; pub use utilization::PageserverUtilization; use std::{ @@ -21,6 +23,7 @@ use utils::{ completion, id::{NodeId, TenantId, TimelineId}, lsn::Lsn, + postgres_client::PostgresClientProtocol, serde_system_time, }; @@ -227,6 +230,9 @@ pub enum TimelineCreateRequestMode { // we continue to accept it by having it here. pg_version: Option, }, + ImportPgdata { + import_pgdata: TimelineCreateRequestModeImportPgdata, + }, // NB: Bootstrap is all-optional, and thus the serde(untagged) will cause serde to stop at Bootstrap. // (serde picks the first matching enum variant, in declaration order). Bootstrap { @@ -236,6 +242,42 @@ pub enum TimelineCreateRequestMode { }, } +#[derive(Serialize, Deserialize, Clone)] +pub struct TimelineCreateRequestModeImportPgdata { + pub location: ImportPgdataLocation, + pub idempotency_key: ImportPgdataIdempotencyKey, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub enum ImportPgdataLocation { + #[cfg(feature = "testing")] + LocalFs { path: Utf8PathBuf }, + AwsS3 { + region: String, + bucket: String, + /// A better name for this would be `prefix`; changing requires coordination with cplane. + /// See . + key: String, + }, +} + +#[derive(Serialize, Deserialize, Clone)] +#[serde(transparent)] +pub struct ImportPgdataIdempotencyKey(pub String); + +impl ImportPgdataIdempotencyKey { + pub fn random() -> Self { + use rand::{distributions::Alphanumeric, Rng}; + Self( + rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(20) + .map(char::from) + .collect(), + ) + } +} + #[derive(Serialize, Deserialize, Clone)] pub struct LsnLeaseRequest { pub lsn: Lsn, @@ -311,6 +353,7 @@ pub struct TenantConfig { pub lsn_lease_length: Option, pub lsn_lease_length_for_ts: Option, pub timeline_offloading: Option, + pub wal_receiver_protocol_override: Option, } /// The policy for the aux file storage. diff --git a/libs/postgres_initdb/Cargo.toml b/libs/postgres_initdb/Cargo.toml new file mode 100644 index 0000000000..1605279bce --- /dev/null +++ b/libs/postgres_initdb/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "postgres_initdb" +version = "0.1.0" +edition.workspace = true +license.workspace = true + +[dependencies] +anyhow.workspace = true +tokio.workspace = true +camino.workspace = true +thiserror.workspace = true +workspace_hack = { version = "0.1", path = "../../workspace_hack" } diff --git a/libs/postgres_initdb/src/lib.rs b/libs/postgres_initdb/src/lib.rs new file mode 100644 index 0000000000..2f072354fb --- /dev/null +++ b/libs/postgres_initdb/src/lib.rs @@ -0,0 +1,103 @@ +//! The canonical way we run `initdb` in Neon. +//! +//! initdb has implicit defaults that are dependent on the environment, e.g., locales & collations. +//! +//! This module's job is to eliminate the environment-dependence as much as possible. + +use std::fmt; + +use camino::Utf8Path; + +pub struct RunInitdbArgs<'a> { + pub superuser: &'a str, + pub locale: &'a str, + pub initdb_bin: &'a Utf8Path, + pub pg_version: u32, + pub library_search_path: &'a Utf8Path, + pub pgdata: &'a Utf8Path, +} + +#[derive(thiserror::Error, Debug)] +pub enum Error { + Spawn(std::io::Error), + Failed { + status: std::process::ExitStatus, + stderr: Vec, + }, + WaitOutput(std::io::Error), + Other(anyhow::Error), +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Error::Spawn(e) => write!(f, "Error spawning command: {:?}", e), + Error::Failed { status, stderr } => write!( + f, + "Command failed with status {:?}: {}", + status, + String::from_utf8_lossy(stderr) + ), + Error::WaitOutput(e) => write!(f, "Error waiting for command output: {:?}", e), + Error::Other(e) => write!(f, "Error: {:?}", e), + } + } +} + +pub async fn do_run_initdb(args: RunInitdbArgs<'_>) -> Result<(), Error> { + let RunInitdbArgs { + superuser, + locale, + initdb_bin: initdb_bin_path, + pg_version, + library_search_path, + pgdata, + } = args; + let mut initdb_command = tokio::process::Command::new(initdb_bin_path); + initdb_command + .args(["--pgdata", pgdata.as_ref()]) + .args(["--username", superuser]) + .args(["--encoding", "utf8"]) + .args(["--locale", locale]) + .arg("--no-instructions") + .arg("--no-sync") + .env_clear() + .env("LD_LIBRARY_PATH", library_search_path) + .env("DYLD_LIBRARY_PATH", library_search_path) + .stdin(std::process::Stdio::null()) + // stdout invocation produces the same output every time, we don't need it + .stdout(std::process::Stdio::null()) + // we would be interested in the stderr output, if there was any + .stderr(std::process::Stdio::piped()); + + // Before version 14, only the libc provide was available. + if pg_version > 14 { + // Version 17 brought with it a builtin locale provider which only provides + // C and C.UTF-8. While being safer for collation purposes since it is + // guaranteed to be consistent throughout a major release, it is also more + // performant. + let locale_provider = if pg_version >= 17 { "builtin" } else { "libc" }; + + initdb_command.args(["--locale-provider", locale_provider]); + } + + let initdb_proc = initdb_command.spawn().map_err(Error::Spawn)?; + + // Ideally we'd select here with the cancellation token, but the problem is that + // we can't safely terminate initdb: it launches processes of its own, and killing + // initdb doesn't kill them. After we return from this function, we want the target + // directory to be able to be cleaned up. + // See https://github.com/neondatabase/neon/issues/6385 + let initdb_output = initdb_proc + .wait_with_output() + .await + .map_err(Error::WaitOutput)?; + if !initdb_output.status.success() { + return Err(Error::Failed { + status: initdb_output.status, + stderr: initdb_output.stderr, + }); + } + + Ok(()) +} diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs index b9e5387d86..4b0331999d 100644 --- a/libs/pq_proto/src/lib.rs +++ b/libs/pq_proto/src/lib.rs @@ -184,9 +184,8 @@ pub struct CancelKeyData { impl fmt::Display for CancelKeyData { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - // TODO: this is producing strange results, with 0xffffffff........ always in the logs. let hi = (self.backend_pid as u64) << 32; - let lo = self.cancel_key as u64; + let lo = (self.cancel_key as u64) & 0xffffffff; let id = hi | lo; // This format is more compact and might work better for logs. @@ -563,6 +562,9 @@ pub enum BeMessage<'a> { options: &'a [&'a str], }, KeepAlive(WalSndKeepAlive), + /// Batch of interpreted, shard filtered WAL records, + /// ready for the pageserver to ingest + InterpretedWalRecords(InterpretedWalRecordsBody<'a>), } /// Common shorthands. @@ -673,6 +675,22 @@ pub struct WalSndKeepAlive { pub request_reply: bool, } +/// Batch of interpreted WAL records used in the interpreted +/// safekeeper to pageserver protocol. +/// +/// Note that the pageserver uses the RawInterpretedWalRecordsBody +/// counterpart of this from the neondatabase/rust-postgres repo. +/// If you're changing this struct, you likely need to change its +/// twin as well. +#[derive(Debug)] +pub struct InterpretedWalRecordsBody<'a> { + /// End of raw WAL in [`Self::data`] + pub streaming_lsn: u64, + /// Current end of WAL on the server + pub commit_lsn: u64, + pub data: &'a [u8], +} + pub static HELLO_WORLD_ROW: BeMessage = BeMessage::DataRow(&[Some(b"hello world")]); // single text column @@ -997,6 +1015,19 @@ impl BeMessage<'_> { Ok(()) })? } + + BeMessage::InterpretedWalRecords(rec) => { + // We use the COPY_DATA_TAG for our custom message + // since this tag is interpreted as raw bytes. + buf.put_u8(b'd'); + write_body(buf, |buf| { + buf.put_u8(b'0'); // matches INTERPRETED_WAL_RECORD_TAG in postgres-protocol + // dependency + buf.put_u64(rec.streaming_lsn); + buf.put_u64(rec.commit_lsn); + buf.put_slice(rec.data); + }); + } } Ok(()) } @@ -1047,4 +1078,13 @@ mod tests { let data = [0, 0, 0, 7, 0, 0, 0, 0]; FeStartupPacket::parse(&mut BytesMut::from_iter(data)).unwrap_err(); } + + #[test] + fn cancel_key_data() { + let key = CancelKeyData { + backend_pid: -1817212860, + cancel_key: -1183897012, + }; + assert_eq!(format!("{key}"), "CancelKeyData(93af8844b96f2a4c)"); + } } diff --git a/libs/proxy/README.md b/libs/proxy/README.md new file mode 100644 index 0000000000..2ae6210e46 --- /dev/null +++ b/libs/proxy/README.md @@ -0,0 +1,6 @@ +This directory contains libraries that are specific for proxy. + +Currently, it contains a signficant fork/refactoring of rust-postgres that no longer reflects the API +of the original library. Since it was so significant, it made sense to upgrade it to it's own set of libraries. + +Proxy needs unique access to the protocol, which explains why such heavy modifications were necessary. diff --git a/libs/proxy/postgres-protocol2/Cargo.toml b/libs/proxy/postgres-protocol2/Cargo.toml new file mode 100644 index 0000000000..284a632954 --- /dev/null +++ b/libs/proxy/postgres-protocol2/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "postgres-protocol2" +version = "0.1.0" +edition = "2018" +license = "MIT/Apache-2.0" + +[dependencies] +base64 = "0.20" +byteorder.workspace = true +bytes.workspace = true +fallible-iterator.workspace = true +hmac.workspace = true +md-5 = "0.10" +memchr = "2.0" +rand.workspace = true +sha2.workspace = true +stringprep = "0.1" +tokio = { workspace = true, features = ["rt"] } + +[dev-dependencies] +tokio = { workspace = true, features = ["full"] } diff --git a/libs/proxy/postgres-protocol2/src/authentication/mod.rs b/libs/proxy/postgres-protocol2/src/authentication/mod.rs new file mode 100644 index 0000000000..71afa4b9b6 --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/authentication/mod.rs @@ -0,0 +1,37 @@ +//! Authentication protocol support. +use md5::{Digest, Md5}; + +pub mod sasl; + +/// Hashes authentication information in a way suitable for use in response +/// to an `AuthenticationMd5Password` message. +/// +/// The resulting string should be sent back to the database in a +/// `PasswordMessage` message. +#[inline] +pub fn md5_hash(username: &[u8], password: &[u8], salt: [u8; 4]) -> String { + let mut md5 = Md5::new(); + md5.update(password); + md5.update(username); + let output = md5.finalize_reset(); + md5.update(format!("{:x}", output)); + md5.update(salt); + format!("md5{:x}", md5.finalize()) +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn md5() { + let username = b"md5_user"; + let password = b"password"; + let salt = [0x2a, 0x3d, 0x8f, 0xe0]; + + assert_eq!( + md5_hash(username, password, salt), + "md562af4dd09bbb41884907a838a3233294" + ); + } +} diff --git a/libs/proxy/postgres-protocol2/src/authentication/sasl.rs b/libs/proxy/postgres-protocol2/src/authentication/sasl.rs new file mode 100644 index 0000000000..19aa3c1e9a --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/authentication/sasl.rs @@ -0,0 +1,516 @@ +//! SASL-based authentication support. + +use hmac::{Hmac, Mac}; +use rand::{self, Rng}; +use sha2::digest::FixedOutput; +use sha2::{Digest, Sha256}; +use std::fmt::Write; +use std::io; +use std::iter; +use std::mem; +use std::str; +use tokio::task::yield_now; + +const NONCE_LENGTH: usize = 24; + +/// The identifier of the SCRAM-SHA-256 SASL authentication mechanism. +pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256"; +/// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism. +pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS"; + +// since postgres passwords are not required to exclude saslprep-prohibited +// characters or even be valid UTF8, we run saslprep if possible and otherwise +// return the raw password. +fn normalize(pass: &[u8]) -> Vec { + let pass = match str::from_utf8(pass) { + Ok(pass) => pass, + Err(_) => return pass.to_vec(), + }; + + match stringprep::saslprep(pass) { + Ok(pass) => pass.into_owned().into_bytes(), + Err(_) => pass.as_bytes().to_vec(), + } +} + +pub(crate) async fn hi(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] { + let mut hmac = + Hmac::::new_from_slice(str).expect("HMAC is able to accept all key sizes"); + hmac.update(salt); + hmac.update(&[0, 0, 0, 1]); + let mut prev = hmac.finalize().into_bytes(); + + let mut hi = prev; + + for i in 1..iterations { + let mut hmac = Hmac::::new_from_slice(str).expect("already checked above"); + hmac.update(&prev); + prev = hmac.finalize().into_bytes(); + + for (hi, prev) in hi.iter_mut().zip(prev) { + *hi ^= prev; + } + // yield every ~250us + // hopefully reduces tail latencies + if i % 1024 == 0 { + yield_now().await + } + } + + hi.into() +} + +enum ChannelBindingInner { + Unrequested, + Unsupported, + TlsServerEndPoint(Vec), +} + +/// The channel binding configuration for a SCRAM authentication exchange. +pub struct ChannelBinding(ChannelBindingInner); + +impl ChannelBinding { + /// The server did not request channel binding. + pub fn unrequested() -> ChannelBinding { + ChannelBinding(ChannelBindingInner::Unrequested) + } + + /// The server requested channel binding but the client is unable to provide it. + pub fn unsupported() -> ChannelBinding { + ChannelBinding(ChannelBindingInner::Unsupported) + } + + /// The server requested channel binding and the client will use the `tls-server-end-point` + /// method. + pub fn tls_server_end_point(signature: Vec) -> ChannelBinding { + ChannelBinding(ChannelBindingInner::TlsServerEndPoint(signature)) + } + + fn gs2_header(&self) -> &'static str { + match self.0 { + ChannelBindingInner::Unrequested => "y,,", + ChannelBindingInner::Unsupported => "n,,", + ChannelBindingInner::TlsServerEndPoint(_) => "p=tls-server-end-point,,", + } + } + + fn cbind_data(&self) -> &[u8] { + match self.0 { + ChannelBindingInner::Unrequested | ChannelBindingInner::Unsupported => &[], + ChannelBindingInner::TlsServerEndPoint(ref buf) => buf, + } + } +} + +/// A pair of keys for the SCRAM-SHA-256 mechanism. +/// See for details. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct ScramKeys { + /// Used by server to authenticate client. + pub client_key: [u8; N], + /// Used by client to verify server's signature. + pub server_key: [u8; N], +} + +/// Password or keys which were derived from it. +enum Credentials { + /// A regular password as a vector of bytes. + Password(Vec), + /// A precomputed pair of keys. + Keys(Box>), +} + +enum State { + Update { + nonce: String, + password: Credentials<32>, + channel_binding: ChannelBinding, + }, + Finish { + server_key: [u8; 32], + auth_message: String, + }, + Done, +} + +/// A type which handles the client side of the SCRAM-SHA-256/SCRAM-SHA-256-PLUS authentication +/// process. +/// +/// During the authentication process, if the backend sends an `AuthenticationSASL` message which +/// includes `SCRAM-SHA-256` as an authentication mechanism, this type can be used. +/// +/// After a `ScramSha256` is constructed, the buffer returned by the `message()` method should be +/// sent to the backend in a `SASLInitialResponse` message along with the mechanism name. +/// +/// The server will reply with an `AuthenticationSASLContinue` message. Its contents should be +/// passed to the `update()` method, after which the buffer returned by the `message()` method +/// should be sent to the backend in a `SASLResponse` message. +/// +/// The server will reply with an `AuthenticationSASLFinal` message. Its contents should be passed +/// to the `finish()` method, after which the authentication process is complete. +pub struct ScramSha256 { + message: String, + state: State, +} + +fn nonce() -> String { + // rand 0.5's ThreadRng is cryptographically secure + let mut rng = rand::thread_rng(); + (0..NONCE_LENGTH) + .map(|_| { + let mut v = rng.gen_range(0x21u8..0x7e); + if v == 0x2c { + v = 0x7e + } + v as char + }) + .collect() +} + +impl ScramSha256 { + /// Constructs a new instance which will use the provided password for authentication. + pub fn new(password: &[u8], channel_binding: ChannelBinding) -> ScramSha256 { + let password = Credentials::Password(normalize(password)); + ScramSha256::new_inner(password, channel_binding, nonce()) + } + + /// Constructs a new instance which will use the provided key pair for authentication. + pub fn new_with_keys(keys: ScramKeys<32>, channel_binding: ChannelBinding) -> ScramSha256 { + let password = Credentials::Keys(keys.into()); + ScramSha256::new_inner(password, channel_binding, nonce()) + } + + fn new_inner( + password: Credentials<32>, + channel_binding: ChannelBinding, + nonce: String, + ) -> ScramSha256 { + ScramSha256 { + message: format!("{}n=,r={}", channel_binding.gs2_header(), nonce), + state: State::Update { + nonce, + password, + channel_binding, + }, + } + } + + /// Returns the message which should be sent to the backend in an `SASLResponse` message. + pub fn message(&self) -> &[u8] { + if let State::Done = self.state { + panic!("invalid SCRAM state"); + } + self.message.as_bytes() + } + + /// Updates the state machine with the response from the backend. + /// + /// This should be called when an `AuthenticationSASLContinue` message is received. + pub async fn update(&mut self, message: &[u8]) -> io::Result<()> { + let (client_nonce, password, channel_binding) = + match mem::replace(&mut self.state, State::Done) { + State::Update { + nonce, + password, + channel_binding, + } => (nonce, password, channel_binding), + _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")), + }; + + let message = + str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; + + let parsed = Parser::new(message).server_first_message()?; + + if !parsed.nonce.starts_with(&client_nonce) { + return Err(io::Error::new(io::ErrorKind::InvalidInput, "invalid nonce")); + } + + let (client_key, server_key) = match password { + Credentials::Password(password) => { + let salt = match base64::decode(parsed.salt) { + Ok(salt) => salt, + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), + }; + + let salted_password = hi(&password, &salt, parsed.iteration_count).await; + + let make_key = |name| { + let mut hmac = Hmac::::new_from_slice(&salted_password) + .expect("HMAC is able to accept all key sizes"); + hmac.update(name); + + let mut key = [0u8; 32]; + key.copy_from_slice(hmac.finalize().into_bytes().as_slice()); + key + }; + + (make_key(b"Client Key"), make_key(b"Server Key")) + } + Credentials::Keys(keys) => (keys.client_key, keys.server_key), + }; + + let mut hash = Sha256::default(); + hash.update(client_key); + let stored_key = hash.finalize_fixed(); + + let mut cbind_input = vec![]; + cbind_input.extend(channel_binding.gs2_header().as_bytes()); + cbind_input.extend(channel_binding.cbind_data()); + let cbind_input = base64::encode(&cbind_input); + + self.message.clear(); + write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap(); + + let auth_message = format!("n=,r={},{},{}", client_nonce, message, self.message); + + let mut hmac = Hmac::::new_from_slice(&stored_key) + .expect("HMAC is able to accept all key sizes"); + hmac.update(auth_message.as_bytes()); + let client_signature = hmac.finalize().into_bytes(); + + let mut client_proof = client_key; + for (proof, signature) in client_proof.iter_mut().zip(client_signature) { + *proof ^= signature; + } + + write!(&mut self.message, ",p={}", base64::encode(client_proof)).unwrap(); + + self.state = State::Finish { + server_key, + auth_message, + }; + Ok(()) + } + + /// Finalizes the authentication process. + /// + /// This should be called when the backend sends an `AuthenticationSASLFinal` message. + /// Authentication has only succeeded if this method returns `Ok(())`. + pub fn finish(&mut self, message: &[u8]) -> io::Result<()> { + let (server_key, auth_message) = match mem::replace(&mut self.state, State::Done) { + State::Finish { + server_key, + auth_message, + } => (server_key, auth_message), + _ => return Err(io::Error::new(io::ErrorKind::Other, "invalid SCRAM state")), + }; + + let message = + str::from_utf8(message).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?; + + let parsed = Parser::new(message).server_final_message()?; + + let verifier = match parsed { + ServerFinalMessage::Error(e) => { + return Err(io::Error::new( + io::ErrorKind::Other, + format!("SCRAM error: {}", e), + )); + } + ServerFinalMessage::Verifier(verifier) => verifier, + }; + + let verifier = match base64::decode(verifier) { + Ok(verifier) => verifier, + Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), + }; + + let mut hmac = Hmac::::new_from_slice(&server_key) + .expect("HMAC is able to accept all key sizes"); + hmac.update(auth_message.as_bytes()); + hmac.verify_slice(&verifier) + .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "SCRAM verification error")) + } +} + +struct Parser<'a> { + s: &'a str, + it: iter::Peekable>, +} + +impl<'a> Parser<'a> { + fn new(s: &'a str) -> Parser<'a> { + Parser { + s, + it: s.char_indices().peekable(), + } + } + + fn eat(&mut self, target: char) -> io::Result<()> { + match self.it.next() { + Some((_, c)) if c == target => Ok(()), + Some((i, c)) => { + let m = format!( + "unexpected character at byte {}: expected `{}` but got `{}", + i, target, c + ); + Err(io::Error::new(io::ErrorKind::InvalidInput, m)) + } + None => Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "unexpected EOF", + )), + } + } + + fn take_while(&mut self, f: F) -> io::Result<&'a str> + where + F: Fn(char) -> bool, + { + let start = match self.it.peek() { + Some(&(i, _)) => i, + None => return Ok(""), + }; + + loop { + match self.it.peek() { + Some(&(_, c)) if f(c) => { + self.it.next(); + } + Some(&(i, _)) => return Ok(&self.s[start..i]), + None => return Ok(&self.s[start..]), + } + } + } + + fn printable(&mut self) -> io::Result<&'a str> { + self.take_while(|c| matches!(c, '\x21'..='\x2b' | '\x2d'..='\x7e')) + } + + fn nonce(&mut self) -> io::Result<&'a str> { + self.eat('r')?; + self.eat('=')?; + self.printable() + } + + fn base64(&mut self) -> io::Result<&'a str> { + self.take_while(|c| matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '/' | '+' | '=')) + } + + fn salt(&mut self) -> io::Result<&'a str> { + self.eat('s')?; + self.eat('=')?; + self.base64() + } + + fn posit_number(&mut self) -> io::Result { + let n = self.take_while(|c| c.is_ascii_digit())?; + n.parse() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e)) + } + + fn iteration_count(&mut self) -> io::Result { + self.eat('i')?; + self.eat('=')?; + self.posit_number() + } + + fn eof(&mut self) -> io::Result<()> { + match self.it.peek() { + Some(&(i, _)) => Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unexpected trailing data at byte {}", i), + )), + None => Ok(()), + } + } + + fn server_first_message(&mut self) -> io::Result> { + let nonce = self.nonce()?; + self.eat(',')?; + let salt = self.salt()?; + self.eat(',')?; + let iteration_count = self.iteration_count()?; + self.eof()?; + + Ok(ServerFirstMessage { + nonce, + salt, + iteration_count, + }) + } + + fn value(&mut self) -> io::Result<&'a str> { + self.take_while(|c| matches!(c, '\0' | '=' | ',')) + } + + fn server_error(&mut self) -> io::Result> { + match self.it.peek() { + Some(&(_, 'e')) => {} + _ => return Ok(None), + } + + self.eat('e')?; + self.eat('=')?; + self.value().map(Some) + } + + fn verifier(&mut self) -> io::Result<&'a str> { + self.eat('v')?; + self.eat('=')?; + self.base64() + } + + fn server_final_message(&mut self) -> io::Result> { + let message = match self.server_error()? { + Some(error) => ServerFinalMessage::Error(error), + None => ServerFinalMessage::Verifier(self.verifier()?), + }; + self.eof()?; + Ok(message) + } +} + +struct ServerFirstMessage<'a> { + nonce: &'a str, + salt: &'a str, + iteration_count: u32, +} + +enum ServerFinalMessage<'a> { + Error(&'a str), + Verifier(&'a str), +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn parse_server_first_message() { + let message = "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096"; + let message = Parser::new(message).server_first_message().unwrap(); + assert_eq!(message.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j"); + assert_eq!(message.salt, "QSXCR+Q6sek8bf92"); + assert_eq!(message.iteration_count, 4096); + } + + // recorded auth exchange from psql + #[tokio::test] + async fn exchange() { + let password = "foobar"; + let nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB"; + + let client_first = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB"; + let server_first = + "r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i\ + =4096"; + let client_final = + "c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS3\ + 1NTlQYNs5BTeQjdHdk7lOflDo5re2an8="; + let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw="; + + let mut scram = ScramSha256::new_inner( + Credentials::Password(normalize(password.as_bytes())), + ChannelBinding::unsupported(), + nonce.to_string(), + ); + assert_eq!(str::from_utf8(scram.message()).unwrap(), client_first); + + scram.update(server_first.as_bytes()).await.unwrap(); + assert_eq!(str::from_utf8(scram.message()).unwrap(), client_final); + + scram.finish(server_final.as_bytes()).unwrap(); + } +} diff --git a/libs/proxy/postgres-protocol2/src/escape/mod.rs b/libs/proxy/postgres-protocol2/src/escape/mod.rs new file mode 100644 index 0000000000..0ba7efdcac --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/escape/mod.rs @@ -0,0 +1,93 @@ +//! Provides functions for escaping literals and identifiers for use +//! in SQL queries. +//! +//! Prefer parameterized queries where possible. Do not escape +//! parameters in a parameterized query. + +#[cfg(test)] +mod test; + +/// Escape a literal and surround result with single quotes. Not +/// recommended in most cases. +/// +/// If input contains backslashes, result will be of the form ` +/// E'...'` so it is safe to use regardless of the setting of +/// standard_conforming_strings. +pub fn escape_literal(input: &str) -> String { + escape_internal(input, false) +} + +/// Escape an identifier and surround result with double quotes. +pub fn escape_identifier(input: &str) -> String { + escape_internal(input, true) +} + +// Translation of PostgreSQL libpq's PQescapeInternal(). Does not +// require a connection because input string is known to be valid +// UTF-8. +// +// Escape arbitrary strings. If as_ident is true, we escape the +// result as an identifier; if false, as a literal. The result is +// returned in a newly allocated buffer. If we fail due to an +// encoding violation or out of memory condition, we return NULL, +// storing an error message into conn. +fn escape_internal(input: &str, as_ident: bool) -> String { + let mut num_backslashes = 0; + let mut num_quotes = 0; + let quote_char = if as_ident { '"' } else { '\'' }; + + // Scan the string for characters that must be escaped. + for ch in input.chars() { + if ch == quote_char { + num_quotes += 1; + } else if ch == '\\' { + num_backslashes += 1; + } + } + + // Allocate output String. + let mut result_size = input.len() + num_quotes + 3; // two quotes, plus a NUL + if !as_ident && num_backslashes > 0 { + result_size += num_backslashes + 2; + } + + let mut output = String::with_capacity(result_size); + + // If we are escaping a literal that contains backslashes, we use + // the escape string syntax so that the result is correct under + // either value of standard_conforming_strings. We also emit a + // leading space in this case, to guard against the possibility + // that the result might be interpolated immediately following an + // identifier. + if !as_ident && num_backslashes > 0 { + output.push(' '); + output.push('E'); + } + + // Opening quote. + output.push(quote_char); + + // Use fast path if possible. + // + // We've already verified that the input string is well-formed in + // the current encoding. If it contains no quotes and, in the + // case of literal-escaping, no backslashes, then we can just copy + // it directly to the output buffer, adding the necessary quotes. + // + // If not, we must rescan the input and process each character + // individually. + if num_quotes == 0 && (num_backslashes == 0 || as_ident) { + output.push_str(input); + } else { + for ch in input.chars() { + if ch == quote_char || (!as_ident && ch == '\\') { + output.push(ch); + } + output.push(ch); + } + } + + output.push(quote_char); + + output +} diff --git a/libs/proxy/postgres-protocol2/src/escape/test.rs b/libs/proxy/postgres-protocol2/src/escape/test.rs new file mode 100644 index 0000000000..4816a103b7 --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/escape/test.rs @@ -0,0 +1,17 @@ +use crate::escape::{escape_identifier, escape_literal}; + +#[test] +fn test_escape_idenifier() { + assert_eq!(escape_identifier("foo"), String::from("\"foo\"")); + assert_eq!(escape_identifier("f\\oo"), String::from("\"f\\oo\"")); + assert_eq!(escape_identifier("f'oo"), String::from("\"f'oo\"")); + assert_eq!(escape_identifier("f\"oo"), String::from("\"f\"\"oo\"")); +} + +#[test] +fn test_escape_literal() { + assert_eq!(escape_literal("foo"), String::from("'foo'")); + assert_eq!(escape_literal("f\\oo"), String::from(" E'f\\\\oo'")); + assert_eq!(escape_literal("f'oo"), String::from("'f''oo'")); + assert_eq!(escape_literal("f\"oo"), String::from("'f\"oo'")); +} diff --git a/libs/proxy/postgres-protocol2/src/lib.rs b/libs/proxy/postgres-protocol2/src/lib.rs new file mode 100644 index 0000000000..947f2f835d --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/lib.rs @@ -0,0 +1,78 @@ +//! Low level Postgres protocol APIs. +//! +//! This crate implements the low level components of Postgres's communication +//! protocol, including message and value serialization and deserialization. +//! It is designed to be used as a building block by higher level APIs such as +//! `rust-postgres`, and should not typically be used directly. +//! +//! # Note +//! +//! This library assumes that the `client_encoding` backend parameter has been +//! set to `UTF8`. It will most likely not behave properly if that is not the case. +#![doc(html_root_url = "https://docs.rs/postgres-protocol/0.6")] +#![warn(missing_docs, rust_2018_idioms, clippy::all)] + +use byteorder::{BigEndian, ByteOrder}; +use bytes::{BufMut, BytesMut}; +use std::io; + +pub mod authentication; +pub mod escape; +pub mod message; +pub mod password; +pub mod types; + +/// A Postgres OID. +pub type Oid = u32; + +/// A Postgres Log Sequence Number (LSN). +pub type Lsn = u64; + +/// An enum indicating if a value is `NULL` or not. +pub enum IsNull { + /// The value is `NULL`. + Yes, + /// The value is not `NULL`. + No, +} + +fn write_nullable(serializer: F, buf: &mut BytesMut) -> Result<(), E> +where + F: FnOnce(&mut BytesMut) -> Result, + E: From, +{ + let base = buf.len(); + buf.put_i32(0); + let size = match serializer(buf)? { + IsNull::No => i32::from_usize(buf.len() - base - 4)?, + IsNull::Yes => -1, + }; + BigEndian::write_i32(&mut buf[base..], size); + + Ok(()) +} + +trait FromUsize: Sized { + fn from_usize(x: usize) -> Result; +} + +macro_rules! from_usize { + ($t:ty) => { + impl FromUsize for $t { + #[inline] + fn from_usize(x: usize) -> io::Result<$t> { + if x > <$t>::MAX as usize { + Err(io::Error::new( + io::ErrorKind::InvalidInput, + "value too large to transmit", + )) + } else { + Ok(x as $t) + } + } + } + }; +} + +from_usize!(i16); +from_usize!(i32); diff --git a/libs/proxy/postgres-protocol2/src/message/backend.rs b/libs/proxy/postgres-protocol2/src/message/backend.rs new file mode 100644 index 0000000000..356d142f3f --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/message/backend.rs @@ -0,0 +1,766 @@ +#![allow(missing_docs)] + +use byteorder::{BigEndian, ByteOrder, ReadBytesExt}; +use bytes::{Bytes, BytesMut}; +use fallible_iterator::FallibleIterator; +use memchr::memchr; +use std::cmp; +use std::io::{self, Read}; +use std::ops::Range; +use std::str; + +use crate::Oid; + +// top-level message tags +const PARSE_COMPLETE_TAG: u8 = b'1'; +const BIND_COMPLETE_TAG: u8 = b'2'; +const CLOSE_COMPLETE_TAG: u8 = b'3'; +pub const NOTIFICATION_RESPONSE_TAG: u8 = b'A'; +const COPY_DONE_TAG: u8 = b'c'; +const COMMAND_COMPLETE_TAG: u8 = b'C'; +const COPY_DATA_TAG: u8 = b'd'; +const DATA_ROW_TAG: u8 = b'D'; +const ERROR_RESPONSE_TAG: u8 = b'E'; +const COPY_IN_RESPONSE_TAG: u8 = b'G'; +const COPY_OUT_RESPONSE_TAG: u8 = b'H'; +const COPY_BOTH_RESPONSE_TAG: u8 = b'W'; +const EMPTY_QUERY_RESPONSE_TAG: u8 = b'I'; +const BACKEND_KEY_DATA_TAG: u8 = b'K'; +pub const NO_DATA_TAG: u8 = b'n'; +pub const NOTICE_RESPONSE_TAG: u8 = b'N'; +const AUTHENTICATION_TAG: u8 = b'R'; +const PORTAL_SUSPENDED_TAG: u8 = b's'; +pub const PARAMETER_STATUS_TAG: u8 = b'S'; +const PARAMETER_DESCRIPTION_TAG: u8 = b't'; +const ROW_DESCRIPTION_TAG: u8 = b'T'; +pub const READY_FOR_QUERY_TAG: u8 = b'Z'; + +#[derive(Debug, Copy, Clone)] +pub struct Header { + tag: u8, + len: i32, +} + +#[allow(clippy::len_without_is_empty)] +impl Header { + #[inline] + pub fn parse(buf: &[u8]) -> io::Result> { + if buf.len() < 5 { + return Ok(None); + } + + let tag = buf[0]; + let len = BigEndian::read_i32(&buf[1..]); + + if len < 4 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid message length: header length < 4", + )); + } + + Ok(Some(Header { tag, len })) + } + + #[inline] + pub fn tag(self) -> u8 { + self.tag + } + + #[inline] + pub fn len(self) -> i32 { + self.len + } +} + +/// An enum representing Postgres backend messages. +#[non_exhaustive] +pub enum Message { + AuthenticationCleartextPassword, + AuthenticationGss, + AuthenticationKerberosV5, + AuthenticationMd5Password(AuthenticationMd5PasswordBody), + AuthenticationOk, + AuthenticationScmCredential, + AuthenticationSspi, + AuthenticationGssContinue, + AuthenticationSasl(AuthenticationSaslBody), + AuthenticationSaslContinue(AuthenticationSaslContinueBody), + AuthenticationSaslFinal(AuthenticationSaslFinalBody), + BackendKeyData(BackendKeyDataBody), + BindComplete, + CloseComplete, + CommandComplete(CommandCompleteBody), + CopyData, + CopyDone, + CopyInResponse, + CopyOutResponse, + CopyBothResponse, + DataRow(DataRowBody), + EmptyQueryResponse, + ErrorResponse(ErrorResponseBody), + NoData, + NoticeResponse(NoticeResponseBody), + NotificationResponse(NotificationResponseBody), + ParameterDescription(ParameterDescriptionBody), + ParameterStatus(ParameterStatusBody), + ParseComplete, + PortalSuspended, + ReadyForQuery(ReadyForQueryBody), + RowDescription(RowDescriptionBody), +} + +impl Message { + #[inline] + pub fn parse(buf: &mut BytesMut) -> io::Result> { + if buf.len() < 5 { + let to_read = 5 - buf.len(); + buf.reserve(to_read); + return Ok(None); + } + + let tag = buf[0]; + let len = (&buf[1..5]).read_u32::().unwrap(); + + if len < 4 { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: parsing u32", + )); + } + + let total_len = len as usize + 1; + if buf.len() < total_len { + let to_read = total_len - buf.len(); + buf.reserve(to_read); + return Ok(None); + } + + let mut buf = Buffer { + bytes: buf.split_to(total_len).freeze(), + idx: 5, + }; + + let message = match tag { + PARSE_COMPLETE_TAG => Message::ParseComplete, + BIND_COMPLETE_TAG => Message::BindComplete, + CLOSE_COMPLETE_TAG => Message::CloseComplete, + NOTIFICATION_RESPONSE_TAG => { + let process_id = buf.read_i32::()?; + let channel = buf.read_cstr()?; + let message = buf.read_cstr()?; + Message::NotificationResponse(NotificationResponseBody { + process_id, + channel, + message, + }) + } + COPY_DONE_TAG => Message::CopyDone, + COMMAND_COMPLETE_TAG => { + let tag = buf.read_cstr()?; + Message::CommandComplete(CommandCompleteBody { tag }) + } + COPY_DATA_TAG => Message::CopyData, + DATA_ROW_TAG => { + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::DataRow(DataRowBody { storage, len }) + } + ERROR_RESPONSE_TAG => { + let storage = buf.read_all(); + Message::ErrorResponse(ErrorResponseBody { storage }) + } + COPY_IN_RESPONSE_TAG => Message::CopyInResponse, + COPY_OUT_RESPONSE_TAG => Message::CopyOutResponse, + COPY_BOTH_RESPONSE_TAG => Message::CopyBothResponse, + EMPTY_QUERY_RESPONSE_TAG => Message::EmptyQueryResponse, + BACKEND_KEY_DATA_TAG => { + let process_id = buf.read_i32::()?; + let secret_key = buf.read_i32::()?; + Message::BackendKeyData(BackendKeyDataBody { + process_id, + secret_key, + }) + } + NO_DATA_TAG => Message::NoData, + NOTICE_RESPONSE_TAG => { + let storage = buf.read_all(); + Message::NoticeResponse(NoticeResponseBody { storage }) + } + AUTHENTICATION_TAG => match buf.read_i32::()? { + 0 => Message::AuthenticationOk, + 2 => Message::AuthenticationKerberosV5, + 3 => Message::AuthenticationCleartextPassword, + 5 => { + let mut salt = [0; 4]; + buf.read_exact(&mut salt)?; + Message::AuthenticationMd5Password(AuthenticationMd5PasswordBody { salt }) + } + 6 => Message::AuthenticationScmCredential, + 7 => Message::AuthenticationGss, + 8 => Message::AuthenticationGssContinue, + 9 => Message::AuthenticationSspi, + 10 => { + let storage = buf.read_all(); + Message::AuthenticationSasl(AuthenticationSaslBody(storage)) + } + 11 => { + let storage = buf.read_all(); + Message::AuthenticationSaslContinue(AuthenticationSaslContinueBody(storage)) + } + 12 => { + let storage = buf.read_all(); + Message::AuthenticationSaslFinal(AuthenticationSaslFinalBody(storage)) + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown authentication tag `{}`", tag), + )); + } + }, + PORTAL_SUSPENDED_TAG => Message::PortalSuspended, + PARAMETER_STATUS_TAG => { + let name = buf.read_cstr()?; + let value = buf.read_cstr()?; + Message::ParameterStatus(ParameterStatusBody { name, value }) + } + PARAMETER_DESCRIPTION_TAG => { + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::ParameterDescription(ParameterDescriptionBody { storage, len }) + } + ROW_DESCRIPTION_TAG => { + let len = buf.read_u16::()?; + let storage = buf.read_all(); + Message::RowDescription(RowDescriptionBody { storage, len }) + } + READY_FOR_QUERY_TAG => { + let status = buf.read_u8()?; + Message::ReadyForQuery(ReadyForQueryBody { status }) + } + tag => { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("unknown message tag `{}`", tag), + )); + } + }; + + if !buf.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: expected buffer to be empty", + )); + } + + Ok(Some(message)) + } +} + +struct Buffer { + bytes: Bytes, + idx: usize, +} + +impl Buffer { + #[inline] + fn slice(&self) -> &[u8] { + &self.bytes[self.idx..] + } + + #[inline] + fn is_empty(&self) -> bool { + self.slice().is_empty() + } + + #[inline] + fn read_cstr(&mut self) -> io::Result { + match memchr(0, self.slice()) { + Some(pos) => { + let start = self.idx; + let end = start + pos; + let cstr = self.bytes.slice(start..end); + self.idx = end + 1; + Ok(cstr) + } + None => Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "unexpected EOF", + )), + } + } + + #[inline] + fn read_all(&mut self) -> Bytes { + let buf = self.bytes.slice(self.idx..); + self.idx = self.bytes.len(); + buf + } +} + +impl Read for Buffer { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let len = { + let slice = self.slice(); + let len = cmp::min(slice.len(), buf.len()); + buf[..len].copy_from_slice(&slice[..len]); + len + }; + self.idx += len; + Ok(len) + } +} + +pub struct AuthenticationMd5PasswordBody { + salt: [u8; 4], +} + +impl AuthenticationMd5PasswordBody { + #[inline] + pub fn salt(&self) -> [u8; 4] { + self.salt + } +} + +pub struct AuthenticationSaslBody(Bytes); + +impl AuthenticationSaslBody { + #[inline] + pub fn mechanisms(&self) -> SaslMechanisms<'_> { + SaslMechanisms(&self.0) + } +} + +pub struct SaslMechanisms<'a>(&'a [u8]); + +impl<'a> FallibleIterator for SaslMechanisms<'a> { + type Item = &'a str; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result> { + let value_end = find_null(self.0, 0)?; + if value_end == 0 { + if self.0.len() != 1 { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "invalid message length: expected to be at end of iterator for sasl", + )); + } + Ok(None) + } else { + let value = get_str(&self.0[..value_end])?; + self.0 = &self.0[value_end + 1..]; + Ok(Some(value)) + } + } +} + +pub struct AuthenticationSaslContinueBody(Bytes); + +impl AuthenticationSaslContinueBody { + #[inline] + pub fn data(&self) -> &[u8] { + &self.0 + } +} + +pub struct AuthenticationSaslFinalBody(Bytes); + +impl AuthenticationSaslFinalBody { + #[inline] + pub fn data(&self) -> &[u8] { + &self.0 + } +} + +pub struct BackendKeyDataBody { + process_id: i32, + secret_key: i32, +} + +impl BackendKeyDataBody { + #[inline] + pub fn process_id(&self) -> i32 { + self.process_id + } + + #[inline] + pub fn secret_key(&self) -> i32 { + self.secret_key + } +} + +pub struct CommandCompleteBody { + tag: Bytes, +} + +impl CommandCompleteBody { + #[inline] + pub fn tag(&self) -> io::Result<&str> { + get_str(&self.tag) + } +} + +#[derive(Debug)] +pub struct DataRowBody { + storage: Bytes, + len: u16, +} + +impl DataRowBody { + #[inline] + pub fn ranges(&self) -> DataRowRanges<'_> { + DataRowRanges { + buf: &self.storage, + len: self.storage.len(), + remaining: self.len, + } + } + + #[inline] + pub fn buffer(&self) -> &[u8] { + &self.storage + } +} + +pub struct DataRowRanges<'a> { + buf: &'a [u8], + len: usize, + remaining: u16, +} + +impl FallibleIterator for DataRowRanges<'_> { + type Item = Option>; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result>>> { + if self.remaining == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: datarowrange is not empty", + )); + } + } + + self.remaining -= 1; + let len = self.buf.read_i32::()?; + if len < 0 { + Ok(Some(None)) + } else { + let len = len as usize; + if self.buf.len() < len { + return Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "unexpected EOF", + )); + } + let base = self.len - self.buf.len(); + self.buf = &self.buf[len..]; + Ok(Some(Some(base..base + len))) + } + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.remaining as usize; + (len, Some(len)) + } +} + +pub struct ErrorResponseBody { + storage: Bytes, +} + +impl ErrorResponseBody { + #[inline] + pub fn fields(&self) -> ErrorFields<'_> { + ErrorFields { buf: &self.storage } + } +} + +pub struct ErrorFields<'a> { + buf: &'a [u8], +} + +impl<'a> FallibleIterator for ErrorFields<'a> { + type Item = ErrorField<'a>; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result>> { + let type_ = self.buf.read_u8()?; + if type_ == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: error fields is not drained", + )); + } + } + + let value_end = find_null(self.buf, 0)?; + let value = get_str(&self.buf[..value_end])?; + self.buf = &self.buf[value_end + 1..]; + + Ok(Some(ErrorField { type_, value })) + } +} + +pub struct ErrorField<'a> { + type_: u8, + value: &'a str, +} + +impl ErrorField<'_> { + #[inline] + pub fn type_(&self) -> u8 { + self.type_ + } + + #[inline] + pub fn value(&self) -> &str { + self.value + } +} + +pub struct NoticeResponseBody { + storage: Bytes, +} + +impl NoticeResponseBody { + #[inline] + pub fn fields(&self) -> ErrorFields<'_> { + ErrorFields { buf: &self.storage } + } +} + +pub struct NotificationResponseBody { + process_id: i32, + channel: Bytes, + message: Bytes, +} + +impl NotificationResponseBody { + #[inline] + pub fn process_id(&self) -> i32 { + self.process_id + } + + #[inline] + pub fn channel(&self) -> io::Result<&str> { + get_str(&self.channel) + } + + #[inline] + pub fn message(&self) -> io::Result<&str> { + get_str(&self.message) + } +} + +pub struct ParameterDescriptionBody { + storage: Bytes, + len: u16, +} + +impl ParameterDescriptionBody { + #[inline] + pub fn parameters(&self) -> Parameters<'_> { + Parameters { + buf: &self.storage, + remaining: self.len, + } + } +} + +pub struct Parameters<'a> { + buf: &'a [u8], + remaining: u16, +} + +impl FallibleIterator for Parameters<'_> { + type Item = Oid; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result> { + if self.remaining == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: parameters is not drained", + )); + } + } + + self.remaining -= 1; + self.buf.read_u32::().map(Some) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.remaining as usize; + (len, Some(len)) + } +} + +pub struct ParameterStatusBody { + name: Bytes, + value: Bytes, +} + +impl ParameterStatusBody { + #[inline] + pub fn name(&self) -> io::Result<&str> { + get_str(&self.name) + } + + #[inline] + pub fn value(&self) -> io::Result<&str> { + get_str(&self.value) + } +} + +pub struct ReadyForQueryBody { + status: u8, +} + +impl ReadyForQueryBody { + #[inline] + pub fn status(&self) -> u8 { + self.status + } +} + +pub struct RowDescriptionBody { + storage: Bytes, + len: u16, +} + +impl RowDescriptionBody { + #[inline] + pub fn fields(&self) -> Fields<'_> { + Fields { + buf: &self.storage, + remaining: self.len, + } + } +} + +pub struct Fields<'a> { + buf: &'a [u8], + remaining: u16, +} + +impl<'a> FallibleIterator for Fields<'a> { + type Item = Field<'a>; + type Error = io::Error; + + #[inline] + fn next(&mut self) -> io::Result>> { + if self.remaining == 0 { + if self.buf.is_empty() { + return Ok(None); + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "invalid message length: field is not drained", + )); + } + } + + self.remaining -= 1; + let name_end = find_null(self.buf, 0)?; + let name = get_str(&self.buf[..name_end])?; + self.buf = &self.buf[name_end + 1..]; + let table_oid = self.buf.read_u32::()?; + let column_id = self.buf.read_i16::()?; + let type_oid = self.buf.read_u32::()?; + let type_size = self.buf.read_i16::()?; + let type_modifier = self.buf.read_i32::()?; + let format = self.buf.read_i16::()?; + + Ok(Some(Field { + name, + table_oid, + column_id, + type_oid, + type_size, + type_modifier, + format, + })) + } +} + +pub struct Field<'a> { + name: &'a str, + table_oid: Oid, + column_id: i16, + type_oid: Oid, + type_size: i16, + type_modifier: i32, + format: i16, +} + +impl<'a> Field<'a> { + #[inline] + pub fn name(&self) -> &'a str { + self.name + } + + #[inline] + pub fn table_oid(&self) -> Oid { + self.table_oid + } + + #[inline] + pub fn column_id(&self) -> i16 { + self.column_id + } + + #[inline] + pub fn type_oid(&self) -> Oid { + self.type_oid + } + + #[inline] + pub fn type_size(&self) -> i16 { + self.type_size + } + + #[inline] + pub fn type_modifier(&self) -> i32 { + self.type_modifier + } + + #[inline] + pub fn format(&self) -> i16 { + self.format + } +} + +#[inline] +fn find_null(buf: &[u8], start: usize) -> io::Result { + match memchr(0, &buf[start..]) { + Some(pos) => Ok(pos + start), + None => Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "unexpected EOF", + )), + } +} + +#[inline] +fn get_str(buf: &[u8]) -> io::Result<&str> { + str::from_utf8(buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e)) +} diff --git a/libs/proxy/postgres-protocol2/src/message/frontend.rs b/libs/proxy/postgres-protocol2/src/message/frontend.rs new file mode 100644 index 0000000000..5d0a8ff8c8 --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/message/frontend.rs @@ -0,0 +1,297 @@ +//! Frontend message serialization. +#![allow(missing_docs)] + +use byteorder::{BigEndian, ByteOrder}; +use bytes::{Buf, BufMut, BytesMut}; +use std::convert::TryFrom; +use std::error::Error; +use std::io; +use std::marker; + +use crate::{write_nullable, FromUsize, IsNull, Oid}; + +#[inline] +fn write_body(buf: &mut BytesMut, f: F) -> Result<(), E> +where + F: FnOnce(&mut BytesMut) -> Result<(), E>, + E: From, +{ + let base = buf.len(); + buf.extend_from_slice(&[0; 4]); + + f(buf)?; + + let size = i32::from_usize(buf.len() - base)?; + BigEndian::write_i32(&mut buf[base..], size); + Ok(()) +} + +pub enum BindError { + Conversion(Box), + Serialization(io::Error), +} + +impl From> for BindError { + #[inline] + fn from(e: Box) -> BindError { + BindError::Conversion(e) + } +} + +impl From for BindError { + #[inline] + fn from(e: io::Error) -> BindError { + BindError::Serialization(e) + } +} + +#[inline] +pub fn bind( + portal: &str, + statement: &str, + formats: I, + values: J, + mut serializer: F, + result_formats: K, + buf: &mut BytesMut, +) -> Result<(), BindError> +where + I: IntoIterator, + J: IntoIterator, + F: FnMut(T, &mut BytesMut) -> Result>, + K: IntoIterator, +{ + buf.put_u8(b'B'); + + write_body(buf, |buf| { + write_cstr(portal.as_bytes(), buf)?; + write_cstr(statement.as_bytes(), buf)?; + write_counted( + formats, + |f, buf| { + buf.put_i16(f); + Ok::<_, io::Error>(()) + }, + buf, + )?; + write_counted( + values, + |v, buf| write_nullable(|buf| serializer(v, buf), buf), + buf, + )?; + write_counted( + result_formats, + |f, buf| { + buf.put_i16(f); + Ok::<_, io::Error>(()) + }, + buf, + )?; + + Ok(()) + }) +} + +#[inline] +fn write_counted(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E> +where + I: IntoIterator, + F: FnMut(T, &mut BytesMut) -> Result<(), E>, + E: From, +{ + let base = buf.len(); + buf.extend_from_slice(&[0; 2]); + let mut count = 0; + for item in items { + serializer(item, buf)?; + count += 1; + } + let count = i16::from_usize(count)?; + BigEndian::write_i16(&mut buf[base..], count); + + Ok(()) +} + +#[inline] +pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) { + write_body(buf, |buf| { + buf.put_i32(80_877_102); + buf.put_i32(process_id); + buf.put_i32(secret_key); + Ok::<_, io::Error>(()) + }) + .unwrap(); +} + +#[inline] +pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'C'); + write_body(buf, |buf| { + buf.put_u8(variant); + write_cstr(name.as_bytes(), buf) + }) +} + +pub struct CopyData { + buf: T, + len: i32, +} + +impl CopyData +where + T: Buf, +{ + pub fn new(buf: T) -> io::Result> { + let len = buf + .remaining() + .checked_add(4) + .and_then(|l| i32::try_from(l).ok()) + .ok_or_else(|| { + io::Error::new(io::ErrorKind::InvalidInput, "message length overflow") + })?; + + Ok(CopyData { buf, len }) + } + + pub fn write(self, out: &mut BytesMut) { + out.put_u8(b'd'); + out.put_i32(self.len); + out.put(self.buf); + } +} + +#[inline] +pub fn copy_done(buf: &mut BytesMut) { + buf.put_u8(b'c'); + write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); +} + +#[inline] +pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'f'); + write_body(buf, |buf| write_cstr(message.as_bytes(), buf)) +} + +#[inline] +pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'D'); + write_body(buf, |buf| { + buf.put_u8(variant); + write_cstr(name.as_bytes(), buf) + }) +} + +#[inline] +pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'E'); + write_body(buf, |buf| { + write_cstr(portal.as_bytes(), buf)?; + buf.put_i32(max_rows); + Ok(()) + }) +} + +#[inline] +pub fn parse(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()> +where + I: IntoIterator, +{ + buf.put_u8(b'P'); + write_body(buf, |buf| { + write_cstr(name.as_bytes(), buf)?; + write_cstr(query.as_bytes(), buf)?; + write_counted( + param_types, + |t, buf| { + buf.put_u32(t); + Ok::<_, io::Error>(()) + }, + buf, + )?; + Ok(()) + }) +} + +#[inline] +pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'p'); + write_body(buf, |buf| write_cstr(password, buf)) +} + +#[inline] +pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'Q'); + write_body(buf, |buf| write_cstr(query.as_bytes(), buf)) +} + +#[inline] +pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'p'); + write_body(buf, |buf| { + write_cstr(mechanism.as_bytes(), buf)?; + let len = i32::from_usize(data.len())?; + buf.put_i32(len); + buf.put_slice(data); + Ok(()) + }) +} + +#[inline] +pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> { + buf.put_u8(b'p'); + write_body(buf, |buf| { + buf.put_slice(data); + Ok(()) + }) +} + +#[inline] +pub fn ssl_request(buf: &mut BytesMut) { + write_body(buf, |buf| { + buf.put_i32(80_877_103); + Ok::<_, io::Error>(()) + }) + .unwrap(); +} + +#[inline] +pub fn startup_message<'a, I>(parameters: I, buf: &mut BytesMut) -> io::Result<()> +where + I: IntoIterator, +{ + write_body(buf, |buf| { + // postgres protocol version 3.0(196608) in bigger-endian + buf.put_i32(0x00_03_00_00); + for (key, value) in parameters { + write_cstr(key.as_bytes(), buf)?; + write_cstr(value.as_bytes(), buf)?; + } + buf.put_u8(0); + Ok(()) + }) +} + +#[inline] +pub fn sync(buf: &mut BytesMut) { + buf.put_u8(b'S'); + write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); +} + +#[inline] +pub fn terminate(buf: &mut BytesMut) { + buf.put_u8(b'X'); + write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); +} + +#[inline] +fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> { + if s.contains(&0) { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "string contains embedded null", + )); + } + buf.put_slice(s); + buf.put_u8(0); + Ok(()) +} diff --git a/libs/proxy/postgres-protocol2/src/message/mod.rs b/libs/proxy/postgres-protocol2/src/message/mod.rs new file mode 100644 index 0000000000..9e5d997548 --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/message/mod.rs @@ -0,0 +1,8 @@ +//! Postgres message protocol support. +//! +//! See [Postgres's documentation][docs] for more information on message flow. +//! +//! [docs]: https://www.postgresql.org/docs/9.5/static/protocol-flow.html + +pub mod backend; +pub mod frontend; diff --git a/libs/proxy/postgres-protocol2/src/password/mod.rs b/libs/proxy/postgres-protocol2/src/password/mod.rs new file mode 100644 index 0000000000..e669e80f3f --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/password/mod.rs @@ -0,0 +1,107 @@ +//! Functions to encrypt a password in the client. +//! +//! This is intended to be used by client applications that wish to +//! send commands like `ALTER USER joe PASSWORD 'pwd'`. The password +//! need not be sent in cleartext if it is encrypted on the client +//! side. This is good because it ensures the cleartext password won't +//! end up in logs pg_stat displays, etc. + +use crate::authentication::sasl; +use hmac::{Hmac, Mac}; +use md5::Md5; +use rand::RngCore; +use sha2::digest::FixedOutput; +use sha2::{Digest, Sha256}; + +#[cfg(test)] +mod test; + +const SCRAM_DEFAULT_ITERATIONS: u32 = 4096; +const SCRAM_DEFAULT_SALT_LEN: usize = 16; + +/// Hash password using SCRAM-SHA-256 with a randomly-generated +/// salt. +/// +/// The client may assume the returned string doesn't contain any +/// special characters that would require escaping in an SQL command. +pub async fn scram_sha_256(password: &[u8]) -> String { + let mut salt: [u8; SCRAM_DEFAULT_SALT_LEN] = [0; SCRAM_DEFAULT_SALT_LEN]; + let mut rng = rand::thread_rng(); + rng.fill_bytes(&mut salt); + scram_sha_256_salt(password, salt).await +} + +// Internal implementation of scram_sha_256 with a caller-provided +// salt. This is useful for testing. +pub(crate) async fn scram_sha_256_salt( + password: &[u8], + salt: [u8; SCRAM_DEFAULT_SALT_LEN], +) -> String { + // Prepare the password, per [RFC + // 4013](https://tools.ietf.org/html/rfc4013), if possible. + // + // Postgres treats passwords as byte strings (without embedded NUL + // bytes), but SASL expects passwords to be valid UTF-8. + // + // Follow the behavior of libpq's PQencryptPasswordConn(), and + // also the backend. If the password is not valid UTF-8, or if it + // contains prohibited characters (such as non-ASCII whitespace), + // just skip the SASLprep step and use the original byte + // sequence. + let prepared: Vec = match std::str::from_utf8(password) { + Ok(password_str) => { + match stringprep::saslprep(password_str) { + Ok(p) => p.into_owned().into_bytes(), + // contains invalid characters; skip saslprep + Err(_) => Vec::from(password), + } + } + // not valid UTF-8; skip saslprep + Err(_) => Vec::from(password), + }; + + // salt password + let salted_password = sasl::hi(&prepared, &salt, SCRAM_DEFAULT_ITERATIONS).await; + + // client key + let mut hmac = Hmac::::new_from_slice(&salted_password) + .expect("HMAC is able to accept all key sizes"); + hmac.update(b"Client Key"); + let client_key = hmac.finalize().into_bytes(); + + // stored key + let mut hash = Sha256::default(); + hash.update(client_key.as_slice()); + let stored_key = hash.finalize_fixed(); + + // server key + let mut hmac = Hmac::::new_from_slice(&salted_password) + .expect("HMAC is able to accept all key sizes"); + hmac.update(b"Server Key"); + let server_key = hmac.finalize().into_bytes(); + + format!( + "SCRAM-SHA-256${}:{}${}:{}", + SCRAM_DEFAULT_ITERATIONS, + base64::encode(salt), + base64::encode(stored_key), + base64::encode(server_key) + ) +} + +/// **Not recommended, as MD5 is not considered to be secure.** +/// +/// Hash password using MD5 with the username as the salt. +/// +/// The client may assume the returned string doesn't contain any +/// special characters that would require escaping. +pub fn md5(password: &[u8], username: &str) -> String { + // salt password with username + let mut salted_password = Vec::from(password); + salted_password.extend_from_slice(username.as_bytes()); + + let mut hash = Md5::new(); + hash.update(&salted_password); + let digest = hash.finalize(); + format!("md5{:x}", digest) +} diff --git a/libs/proxy/postgres-protocol2/src/password/test.rs b/libs/proxy/postgres-protocol2/src/password/test.rs new file mode 100644 index 0000000000..c9d340f09d --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/password/test.rs @@ -0,0 +1,19 @@ +use crate::password; + +#[tokio::test] +async fn test_encrypt_scram_sha_256() { + // Specify the salt to make the test deterministic. Any bytes will do. + let salt: [u8; 16] = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]; + assert_eq!( + password::scram_sha_256_salt(b"secret", salt).await, + "SCRAM-SHA-256$4096:AQIDBAUGBwgJCgsMDQ4PEA==$8rrDg00OqaiWXJ7p+sCgHEIaBSHY89ZJl3mfIsf32oY=:05L1f+yZbiN8O0AnO40Og85NNRhvzTS57naKRWCcsIA=" + ); +} + +#[test] +fn test_encrypt_md5() { + assert_eq!( + password::md5(b"secret", "foo"), + "md54ab2c5d00339c4b2a4e921d2dc4edec7" + ); +} diff --git a/libs/proxy/postgres-protocol2/src/types/mod.rs b/libs/proxy/postgres-protocol2/src/types/mod.rs new file mode 100644 index 0000000000..78131c05bf --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/types/mod.rs @@ -0,0 +1,294 @@ +//! Conversions to and from Postgres's binary format for various types. +use byteorder::{BigEndian, ReadBytesExt}; +use bytes::{BufMut, BytesMut}; +use fallible_iterator::FallibleIterator; +use std::boxed::Box as StdBox; +use std::error::Error; +use std::str; + +use crate::Oid; + +#[cfg(test)] +mod test; + +/// Serializes a `TEXT`, `VARCHAR`, `CHAR(n)`, `NAME`, or `CITEXT` value. +#[inline] +pub fn text_to_sql(v: &str, buf: &mut BytesMut) { + buf.put_slice(v.as_bytes()); +} + +/// Deserializes a `TEXT`, `VARCHAR`, `CHAR(n)`, `NAME`, or `CITEXT` value. +#[inline] +pub fn text_from_sql(buf: &[u8]) -> Result<&str, StdBox> { + Ok(str::from_utf8(buf)?) +} + +/// Deserializes a `"char"` value. +#[inline] +pub fn char_from_sql(mut buf: &[u8]) -> Result> { + let v = buf.read_i8()?; + if !buf.is_empty() { + return Err("invalid buffer size".into()); + } + Ok(v) +} + +/// Serializes an `OID` value. +#[inline] +pub fn oid_to_sql(v: Oid, buf: &mut BytesMut) { + buf.put_u32(v); +} + +/// Deserializes an `OID` value. +#[inline] +pub fn oid_from_sql(mut buf: &[u8]) -> Result> { + let v = buf.read_u32::()?; + if !buf.is_empty() { + return Err("invalid buffer size".into()); + } + Ok(v) +} + +/// A fallible iterator over `HSTORE` entries. +pub struct HstoreEntries<'a> { + remaining: i32, + buf: &'a [u8], +} + +impl<'a> FallibleIterator for HstoreEntries<'a> { + type Item = (&'a str, Option<&'a str>); + type Error = StdBox; + + #[inline] + #[allow(clippy::type_complexity)] + fn next( + &mut self, + ) -> Result)>, StdBox> { + if self.remaining == 0 { + if !self.buf.is_empty() { + return Err("invalid buffer size".into()); + } + return Ok(None); + } + + self.remaining -= 1; + + let key_len = self.buf.read_i32::()?; + if key_len < 0 { + return Err("invalid key length".into()); + } + let (key, buf) = self.buf.split_at(key_len as usize); + let key = str::from_utf8(key)?; + self.buf = buf; + + let value_len = self.buf.read_i32::()?; + let value = if value_len < 0 { + None + } else { + let (value, buf) = self.buf.split_at(value_len as usize); + let value = str::from_utf8(value)?; + self.buf = buf; + Some(value) + }; + + Ok(Some((key, value))) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.remaining as usize; + (len, Some(len)) + } +} + +/// Deserializes an array value. +#[inline] +pub fn array_from_sql(mut buf: &[u8]) -> Result, StdBox> { + let dimensions = buf.read_i32::()?; + if dimensions < 0 { + return Err("invalid dimension count".into()); + } + + let mut r = buf; + let mut elements = 1i32; + for _ in 0..dimensions { + let len = r.read_i32::()?; + if len < 0 { + return Err("invalid dimension size".into()); + } + let _lower_bound = r.read_i32::()?; + elements = match elements.checked_mul(len) { + Some(elements) => elements, + None => return Err("too many array elements".into()), + }; + } + + if dimensions == 0 { + elements = 0; + } + + Ok(Array { + dimensions, + elements, + buf, + }) +} + +/// A Postgres array. +pub struct Array<'a> { + dimensions: i32, + elements: i32, + buf: &'a [u8], +} + +impl<'a> Array<'a> { + /// Returns an iterator over the dimensions of the array. + #[inline] + pub fn dimensions(&self) -> ArrayDimensions<'a> { + ArrayDimensions(&self.buf[..self.dimensions as usize * 8]) + } + + /// Returns an iterator over the values of the array. + #[inline] + pub fn values(&self) -> ArrayValues<'a> { + ArrayValues { + remaining: self.elements, + buf: &self.buf[self.dimensions as usize * 8..], + } + } +} + +/// An iterator over the dimensions of an array. +pub struct ArrayDimensions<'a>(&'a [u8]); + +impl FallibleIterator for ArrayDimensions<'_> { + type Item = ArrayDimension; + type Error = StdBox; + + #[inline] + fn next(&mut self) -> Result, StdBox> { + if self.0.is_empty() { + return Ok(None); + } + + let len = self.0.read_i32::()?; + let lower_bound = self.0.read_i32::()?; + + Ok(Some(ArrayDimension { len, lower_bound })) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + let len = self.0.len() / 8; + (len, Some(len)) + } +} + +/// Information about a dimension of an array. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct ArrayDimension { + /// The length of this dimension. + pub len: i32, + + /// The base value used to index into this dimension. + pub lower_bound: i32, +} + +/// An iterator over the values of an array, in row-major order. +pub struct ArrayValues<'a> { + remaining: i32, + buf: &'a [u8], +} + +impl<'a> FallibleIterator for ArrayValues<'a> { + type Item = Option<&'a [u8]>; + type Error = StdBox; + + #[inline] + fn next(&mut self) -> Result>, StdBox> { + if self.remaining == 0 { + if !self.buf.is_empty() { + return Err("invalid message length: arrayvalue not drained".into()); + } + return Ok(None); + } + self.remaining -= 1; + + let len = self.buf.read_i32::()?; + let val = if len < 0 { + None + } else { + if self.buf.len() < len as usize { + return Err("invalid value length".into()); + } + + let (val, buf) = self.buf.split_at(len as usize); + self.buf = buf; + Some(val) + }; + + Ok(Some(val)) + } + + fn size_hint(&self) -> (usize, Option) { + let len = self.remaining as usize; + (len, Some(len)) + } +} + +/// Serializes a Postgres ltree string +#[inline] +pub fn ltree_to_sql(v: &str, buf: &mut BytesMut) { + // A version number is prepended to an ltree string per spec + buf.put_u8(1); + // Append the rest of the query + buf.put_slice(v.as_bytes()); +} + +/// Deserialize a Postgres ltree string +#[inline] +pub fn ltree_from_sql(buf: &[u8]) -> Result<&str, StdBox> { + match buf { + // Remove the version number from the front of the ltree per spec + [1u8, rest @ ..] => Ok(str::from_utf8(rest)?), + _ => Err("ltree version 1 only supported".into()), + } +} + +/// Serializes a Postgres lquery string +#[inline] +pub fn lquery_to_sql(v: &str, buf: &mut BytesMut) { + // A version number is prepended to an lquery string per spec + buf.put_u8(1); + // Append the rest of the query + buf.put_slice(v.as_bytes()); +} + +/// Deserialize a Postgres lquery string +#[inline] +pub fn lquery_from_sql(buf: &[u8]) -> Result<&str, StdBox> { + match buf { + // Remove the version number from the front of the lquery per spec + [1u8, rest @ ..] => Ok(str::from_utf8(rest)?), + _ => Err("lquery version 1 only supported".into()), + } +} + +/// Serializes a Postgres ltxtquery string +#[inline] +pub fn ltxtquery_to_sql(v: &str, buf: &mut BytesMut) { + // A version number is prepended to an ltxtquery string per spec + buf.put_u8(1); + // Append the rest of the query + buf.put_slice(v.as_bytes()); +} + +/// Deserialize a Postgres ltxtquery string +#[inline] +pub fn ltxtquery_from_sql(buf: &[u8]) -> Result<&str, StdBox> { + match buf { + // Remove the version number from the front of the ltxtquery per spec + [1u8, rest @ ..] => Ok(str::from_utf8(rest)?), + _ => Err("ltxtquery version 1 only supported".into()), + } +} diff --git a/libs/proxy/postgres-protocol2/src/types/test.rs b/libs/proxy/postgres-protocol2/src/types/test.rs new file mode 100644 index 0000000000..96cc055bc3 --- /dev/null +++ b/libs/proxy/postgres-protocol2/src/types/test.rs @@ -0,0 +1,87 @@ +use bytes::{Buf, BytesMut}; + +use super::*; + +#[test] +fn ltree_sql() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + let mut buf = BytesMut::new(); + + ltree_to_sql("A.B.C", &mut buf); + + assert_eq!(query.as_slice(), buf.chunk()); +} + +#[test] +fn ltree_str() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + assert!(ltree_from_sql(query.as_slice()).is_ok()) +} + +#[test] +fn ltree_wrong_version() { + let mut query = vec![2u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + assert!(ltree_from_sql(query.as_slice()).is_err()) +} + +#[test] +fn lquery_sql() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + let mut buf = BytesMut::new(); + + lquery_to_sql("A.B.C", &mut buf); + + assert_eq!(query.as_slice(), buf.chunk()); +} + +#[test] +fn lquery_str() { + let mut query = vec![1u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + assert!(lquery_from_sql(query.as_slice()).is_ok()) +} + +#[test] +fn lquery_wrong_version() { + let mut query = vec![2u8]; + query.extend_from_slice("A.B.C".as_bytes()); + + assert!(lquery_from_sql(query.as_slice()).is_err()) +} + +#[test] +fn ltxtquery_sql() { + let mut query = vec![1u8]; + query.extend_from_slice("a & b*".as_bytes()); + + let mut buf = BytesMut::new(); + + ltree_to_sql("a & b*", &mut buf); + + assert_eq!(query.as_slice(), buf.chunk()); +} + +#[test] +fn ltxtquery_str() { + let mut query = vec![1u8]; + query.extend_from_slice("a & b*".as_bytes()); + + assert!(ltree_from_sql(query.as_slice()).is_ok()) +} + +#[test] +fn ltxtquery_wrong_version() { + let mut query = vec![2u8]; + query.extend_from_slice("a & b*".as_bytes()); + + assert!(ltree_from_sql(query.as_slice()).is_err()) +} diff --git a/libs/proxy/postgres-types2/Cargo.toml b/libs/proxy/postgres-types2/Cargo.toml new file mode 100644 index 0000000000..58cfb5571f --- /dev/null +++ b/libs/proxy/postgres-types2/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "postgres-types2" +version = "0.1.0" +edition = "2018" +license = "MIT/Apache-2.0" + +[dependencies] +bytes.workspace = true +fallible-iterator.workspace = true +postgres-protocol2 = { path = "../postgres-protocol2" } diff --git a/libs/proxy/postgres-types2/src/lib.rs b/libs/proxy/postgres-types2/src/lib.rs new file mode 100644 index 0000000000..18ba032151 --- /dev/null +++ b/libs/proxy/postgres-types2/src/lib.rs @@ -0,0 +1,477 @@ +//! Conversions to and from Postgres types. +//! +//! This crate is used by the `tokio-postgres` and `postgres` crates. You normally don't need to depend directly on it +//! unless you want to define your own `ToSql` or `FromSql` definitions. +#![doc(html_root_url = "https://docs.rs/postgres-types/0.2")] +#![warn(clippy::all, rust_2018_idioms, missing_docs)] + +use fallible_iterator::FallibleIterator; +use postgres_protocol2::types; +use std::any::type_name; +use std::error::Error; +use std::fmt; +use std::sync::Arc; + +use crate::type_gen::{Inner, Other}; + +#[doc(inline)] +pub use postgres_protocol2::Oid; + +use bytes::BytesMut; + +/// Generates a simple implementation of `ToSql::accepts` which accepts the +/// types passed to it. +macro_rules! accepts { + ($($expected:ident),+) => ( + fn accepts(ty: &$crate::Type) -> bool { + matches!(*ty, $($crate::Type::$expected)|+) + } + ) +} + +/// Generates an implementation of `ToSql::to_sql_checked`. +/// +/// All `ToSql` implementations should use this macro. +macro_rules! to_sql_checked { + () => { + fn to_sql_checked( + &self, + ty: &$crate::Type, + out: &mut $crate::private::BytesMut, + ) -> ::std::result::Result< + $crate::IsNull, + Box, + > { + $crate::__to_sql_checked(self, ty, out) + } + }; +} + +// WARNING: this function is not considered part of this crate's public API. +// It is subject to change at any time. +#[doc(hidden)] +pub fn __to_sql_checked( + v: &T, + ty: &Type, + out: &mut BytesMut, +) -> Result> +where + T: ToSql, +{ + if !T::accepts(ty) { + return Err(Box::new(WrongType::new::(ty.clone()))); + } + v.to_sql(ty, out) +} + +// mod pg_lsn; +#[doc(hidden)] +pub mod private; +// mod special; +mod type_gen; + +/// A Postgres type. +#[derive(PartialEq, Eq, Clone, Hash)] +pub struct Type(Inner); + +impl fmt::Debug for Type { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Debug::fmt(&self.0, fmt) + } +} + +impl fmt::Display for Type { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match self.schema() { + "public" | "pg_catalog" => {} + schema => write!(fmt, "{}.", schema)?, + } + fmt.write_str(self.name()) + } +} + +impl Type { + /// Creates a new `Type`. + pub fn new(name: String, oid: Oid, kind: Kind, schema: String) -> Type { + Type(Inner::Other(Arc::new(Other { + name, + oid, + kind, + schema, + }))) + } + + /// Returns the `Type` corresponding to the provided `Oid` if it + /// corresponds to a built-in type. + pub fn from_oid(oid: Oid) -> Option { + Inner::from_oid(oid).map(Type) + } + + /// Returns the OID of the `Type`. + pub fn oid(&self) -> Oid { + self.0.oid() + } + + /// Returns the kind of this type. + pub fn kind(&self) -> &Kind { + self.0.kind() + } + + /// Returns the schema of this type. + pub fn schema(&self) -> &str { + match self.0 { + Inner::Other(ref u) => &u.schema, + _ => "pg_catalog", + } + } + + /// Returns the name of this type. + pub fn name(&self) -> &str { + self.0.name() + } +} + +/// Represents the kind of a Postgres type. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum Kind { + /// A simple type like `VARCHAR` or `INTEGER`. + Simple, + /// An enumerated type along with its variants. + Enum(Vec), + /// A pseudo-type. + Pseudo, + /// An array type along with the type of its elements. + Array(Type), + /// A range type along with the type of its elements. + Range(Type), + /// A multirange type along with the type of its elements. + Multirange(Type), + /// A domain type along with its underlying type. + Domain(Type), + /// A composite type along with information about its fields. + Composite(Vec), +} + +/// Information about a field of a composite type. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Field { + name: String, + type_: Type, +} + +impl Field { + /// Creates a new `Field`. + pub fn new(name: String, type_: Type) -> Field { + Field { name, type_ } + } + + /// Returns the name of the field. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the type of the field. + pub fn type_(&self) -> &Type { + &self.type_ + } +} + +/// An error indicating that a `NULL` Postgres value was passed to a `FromSql` +/// implementation that does not support `NULL` values. +#[derive(Debug, Clone, Copy)] +pub struct WasNull; + +impl fmt::Display for WasNull { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.write_str("a Postgres value was `NULL`") + } +} + +impl Error for WasNull {} + +/// An error indicating that a conversion was attempted between incompatible +/// Rust and Postgres types. +#[derive(Debug)] +pub struct WrongType { + postgres: Type, + rust: &'static str, +} + +impl fmt::Display for WrongType { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "cannot convert between the Rust type `{}` and the Postgres type `{}`", + self.rust, self.postgres, + ) + } +} + +impl Error for WrongType {} + +impl WrongType { + /// Creates a new `WrongType` error. + pub fn new(ty: Type) -> WrongType { + WrongType { + postgres: ty, + rust: type_name::(), + } + } +} + +/// An error indicating that a as_text conversion was attempted on a binary +/// result. +#[derive(Debug)] +pub struct WrongFormat {} + +impl Error for WrongFormat {} + +impl fmt::Display for WrongFormat { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "cannot read column as text while it is in binary format" + ) + } +} + +/// A trait for types that can be created from a Postgres value. +pub trait FromSql<'a>: Sized { + /// Creates a new value of this type from a buffer of data of the specified + /// Postgres `Type` in its binary format. + /// + /// The caller of this method is responsible for ensuring that this type + /// is compatible with the Postgres `Type`. + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result>; + + /// Creates a new value of this type from a `NULL` SQL value. + /// + /// The caller of this method is responsible for ensuring that this type + /// is compatible with the Postgres `Type`. + /// + /// The default implementation returns `Err(Box::new(WasNull))`. + #[allow(unused_variables)] + fn from_sql_null(ty: &Type) -> Result> { + Err(Box::new(WasNull)) + } + + /// A convenience function that delegates to `from_sql` and `from_sql_null` depending on the + /// value of `raw`. + fn from_sql_nullable( + ty: &Type, + raw: Option<&'a [u8]>, + ) -> Result> { + match raw { + Some(raw) => Self::from_sql(ty, raw), + None => Self::from_sql_null(ty), + } + } + + /// Determines if a value of this type can be created from the specified + /// Postgres `Type`. + fn accepts(ty: &Type) -> bool; +} + +/// A trait for types which can be created from a Postgres value without borrowing any data. +/// +/// This is primarily useful for trait bounds on functions. +pub trait FromSqlOwned: for<'a> FromSql<'a> {} + +impl FromSqlOwned for T where T: for<'a> FromSql<'a> {} + +impl<'a, T: FromSql<'a>> FromSql<'a> for Option { + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result, Box> { + ::from_sql(ty, raw).map(Some) + } + + fn from_sql_null(_: &Type) -> Result, Box> { + Ok(None) + } + + fn accepts(ty: &Type) -> bool { + ::accepts(ty) + } +} + +impl<'a, T: FromSql<'a>> FromSql<'a> for Vec { + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result, Box> { + let member_type = match *ty.kind() { + Kind::Array(ref member) => member, + _ => panic!("expected array type"), + }; + + let array = types::array_from_sql(raw)?; + if array.dimensions().count()? > 1 { + return Err("array contains too many dimensions".into()); + } + + array + .values() + .map(|v| T::from_sql_nullable(member_type, v)) + .collect() + } + + fn accepts(ty: &Type) -> bool { + match *ty.kind() { + Kind::Array(ref inner) => T::accepts(inner), + _ => false, + } + } +} + +impl<'a> FromSql<'a> for String { + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result> { + <&str as FromSql>::from_sql(ty, raw).map(ToString::to_string) + } + + fn accepts(ty: &Type) -> bool { + <&str as FromSql>::accepts(ty) + } +} + +impl<'a> FromSql<'a> for &'a str { + fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<&'a str, Box> { + match *ty { + ref ty if ty.name() == "ltree" => types::ltree_from_sql(raw), + ref ty if ty.name() == "lquery" => types::lquery_from_sql(raw), + ref ty if ty.name() == "ltxtquery" => types::ltxtquery_from_sql(raw), + _ => types::text_from_sql(raw), + } + } + + fn accepts(ty: &Type) -> bool { + match *ty { + Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true, + ref ty + if (ty.name() == "citext" + || ty.name() == "ltree" + || ty.name() == "lquery" + || ty.name() == "ltxtquery") => + { + true + } + _ => false, + } + } +} + +macro_rules! simple_from { + ($t:ty, $f:ident, $($expected:ident),+) => { + impl<'a> FromSql<'a> for $t { + fn from_sql(_: &Type, raw: &'a [u8]) -> Result<$t, Box> { + types::$f(raw) + } + + accepts!($($expected),+); + } + } +} + +simple_from!(i8, char_from_sql, CHAR); +simple_from!(u32, oid_from_sql, OID); + +/// An enum representing the nullability of a Postgres value. +pub enum IsNull { + /// The value is NULL. + Yes, + /// The value is not NULL. + No, +} + +/// A trait for types that can be converted into Postgres values. +pub trait ToSql: fmt::Debug { + /// Converts the value of `self` into the binary format of the specified + /// Postgres `Type`, appending it to `out`. + /// + /// The caller of this method is responsible for ensuring that this type + /// is compatible with the Postgres `Type`. + /// + /// The return value indicates if this value should be represented as + /// `NULL`. If this is the case, implementations **must not** write + /// anything to `out`. + fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result> + where + Self: Sized; + + /// Determines if a value of this type can be converted to the specified + /// Postgres `Type`. + fn accepts(ty: &Type) -> bool + where + Self: Sized; + + /// An adaptor method used internally by Rust-Postgres. + /// + /// *All* implementations of this method should be generated by the + /// `to_sql_checked!()` macro. + fn to_sql_checked( + &self, + ty: &Type, + out: &mut BytesMut, + ) -> Result>; + + /// Specify the encode format + fn encode_format(&self, _ty: &Type) -> Format { + Format::Binary + } +} + +/// Supported Postgres message format types +/// +/// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8` +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum Format { + /// Text format (UTF-8) + Text, + /// Compact, typed binary format + Binary, +} + +impl ToSql for &str { + fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { + match *ty { + ref ty if ty.name() == "ltree" => types::ltree_to_sql(self, w), + ref ty if ty.name() == "lquery" => types::lquery_to_sql(self, w), + ref ty if ty.name() == "ltxtquery" => types::ltxtquery_to_sql(self, w), + _ => types::text_to_sql(self, w), + } + Ok(IsNull::No) + } + + fn accepts(ty: &Type) -> bool { + match *ty { + Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true, + ref ty + if (ty.name() == "citext" + || ty.name() == "ltree" + || ty.name() == "lquery" + || ty.name() == "ltxtquery") => + { + true + } + _ => false, + } + } + + to_sql_checked!(); +} + +macro_rules! simple_to { + ($t:ty, $f:ident, $($expected:ident),+) => { + impl ToSql for $t { + fn to_sql(&self, + _: &Type, + w: &mut BytesMut) + -> Result> { + types::$f(*self, w); + Ok(IsNull::No) + } + + accepts!($($expected),+); + + to_sql_checked!(); + } + } +} + +simple_to!(u32, oid_to_sql, OID); diff --git a/libs/proxy/postgres-types2/src/private.rs b/libs/proxy/postgres-types2/src/private.rs new file mode 100644 index 0000000000..774f9a301c --- /dev/null +++ b/libs/proxy/postgres-types2/src/private.rs @@ -0,0 +1,34 @@ +use crate::{FromSql, Type}; +pub use bytes::BytesMut; +use std::error::Error; + +pub fn read_be_i32(buf: &mut &[u8]) -> Result> { + if buf.len() < 4 { + return Err("invalid buffer size".into()); + } + let mut bytes = [0; 4]; + bytes.copy_from_slice(&buf[..4]); + *buf = &buf[4..]; + Ok(i32::from_be_bytes(bytes)) +} + +pub fn read_value<'a, T>( + type_: &Type, + buf: &mut &'a [u8], +) -> Result> +where + T: FromSql<'a>, +{ + let len = read_be_i32(buf)?; + let value = if len < 0 { + None + } else { + if len as usize > buf.len() { + return Err("invalid buffer size".into()); + } + let (head, tail) = buf.split_at(len as usize); + *buf = tail; + Some(head) + }; + T::from_sql_nullable(type_, value) +} diff --git a/libs/proxy/postgres-types2/src/type_gen.rs b/libs/proxy/postgres-types2/src/type_gen.rs new file mode 100644 index 0000000000..a1bc3f85c0 --- /dev/null +++ b/libs/proxy/postgres-types2/src/type_gen.rs @@ -0,0 +1,1524 @@ +// Autogenerated file - DO NOT EDIT +use std::sync::Arc; + +use crate::{Kind, Oid, Type}; + +#[derive(PartialEq, Eq, Debug, Hash)] +pub struct Other { + pub name: String, + pub oid: Oid, + pub kind: Kind, + pub schema: String, +} + +#[derive(PartialEq, Eq, Clone, Debug, Hash)] +pub enum Inner { + Bool, + Bytea, + Char, + Name, + Int8, + Int2, + Int2Vector, + Int4, + Regproc, + Text, + Oid, + Tid, + Xid, + Cid, + OidVector, + PgDdlCommand, + Json, + Xml, + XmlArray, + PgNodeTree, + JsonArray, + TableAmHandler, + Xid8Array, + IndexAmHandler, + Point, + Lseg, + Path, + Box, + Polygon, + Line, + LineArray, + Cidr, + CidrArray, + Float4, + Float8, + Unknown, + Circle, + CircleArray, + Macaddr8, + Macaddr8Array, + Money, + MoneyArray, + Macaddr, + Inet, + BoolArray, + ByteaArray, + CharArray, + NameArray, + Int2Array, + Int2VectorArray, + Int4Array, + RegprocArray, + TextArray, + TidArray, + XidArray, + CidArray, + OidVectorArray, + BpcharArray, + VarcharArray, + Int8Array, + PointArray, + LsegArray, + PathArray, + BoxArray, + Float4Array, + Float8Array, + PolygonArray, + OidArray, + Aclitem, + AclitemArray, + MacaddrArray, + InetArray, + Bpchar, + Varchar, + Date, + Time, + Timestamp, + TimestampArray, + DateArray, + TimeArray, + Timestamptz, + TimestamptzArray, + Interval, + IntervalArray, + NumericArray, + CstringArray, + Timetz, + TimetzArray, + Bit, + BitArray, + Varbit, + VarbitArray, + Numeric, + Refcursor, + RefcursorArray, + Regprocedure, + Regoper, + Regoperator, + Regclass, + Regtype, + RegprocedureArray, + RegoperArray, + RegoperatorArray, + RegclassArray, + RegtypeArray, + Record, + Cstring, + Any, + Anyarray, + Void, + Trigger, + LanguageHandler, + Internal, + Anyelement, + RecordArray, + Anynonarray, + TxidSnapshotArray, + Uuid, + UuidArray, + TxidSnapshot, + FdwHandler, + PgLsn, + PgLsnArray, + TsmHandler, + PgNdistinct, + PgDependencies, + Anyenum, + TsVector, + Tsquery, + GtsVector, + TsVectorArray, + GtsVectorArray, + TsqueryArray, + Regconfig, + RegconfigArray, + Regdictionary, + RegdictionaryArray, + Jsonb, + JsonbArray, + AnyRange, + EventTrigger, + Int4Range, + Int4RangeArray, + NumRange, + NumRangeArray, + TsRange, + TsRangeArray, + TstzRange, + TstzRangeArray, + DateRange, + DateRangeArray, + Int8Range, + Int8RangeArray, + Jsonpath, + JsonpathArray, + Regnamespace, + RegnamespaceArray, + Regrole, + RegroleArray, + Regcollation, + RegcollationArray, + Int4multiRange, + NummultiRange, + TsmultiRange, + TstzmultiRange, + DatemultiRange, + Int8multiRange, + AnymultiRange, + AnycompatiblemultiRange, + PgBrinBloomSummary, + PgBrinMinmaxMultiSummary, + PgMcvList, + PgSnapshot, + PgSnapshotArray, + Xid8, + Anycompatible, + Anycompatiblearray, + Anycompatiblenonarray, + AnycompatibleRange, + Int4multiRangeArray, + NummultiRangeArray, + TsmultiRangeArray, + TstzmultiRangeArray, + DatemultiRangeArray, + Int8multiRangeArray, + Other(Arc), +} + +impl Inner { + pub fn from_oid(oid: Oid) -> Option { + match oid { + 16 => Some(Inner::Bool), + 17 => Some(Inner::Bytea), + 18 => Some(Inner::Char), + 19 => Some(Inner::Name), + 20 => Some(Inner::Int8), + 21 => Some(Inner::Int2), + 22 => Some(Inner::Int2Vector), + 23 => Some(Inner::Int4), + 24 => Some(Inner::Regproc), + 25 => Some(Inner::Text), + 26 => Some(Inner::Oid), + 27 => Some(Inner::Tid), + 28 => Some(Inner::Xid), + 29 => Some(Inner::Cid), + 30 => Some(Inner::OidVector), + 32 => Some(Inner::PgDdlCommand), + 114 => Some(Inner::Json), + 142 => Some(Inner::Xml), + 143 => Some(Inner::XmlArray), + 194 => Some(Inner::PgNodeTree), + 199 => Some(Inner::JsonArray), + 269 => Some(Inner::TableAmHandler), + 271 => Some(Inner::Xid8Array), + 325 => Some(Inner::IndexAmHandler), + 600 => Some(Inner::Point), + 601 => Some(Inner::Lseg), + 602 => Some(Inner::Path), + 603 => Some(Inner::Box), + 604 => Some(Inner::Polygon), + 628 => Some(Inner::Line), + 629 => Some(Inner::LineArray), + 650 => Some(Inner::Cidr), + 651 => Some(Inner::CidrArray), + 700 => Some(Inner::Float4), + 701 => Some(Inner::Float8), + 705 => Some(Inner::Unknown), + 718 => Some(Inner::Circle), + 719 => Some(Inner::CircleArray), + 774 => Some(Inner::Macaddr8), + 775 => Some(Inner::Macaddr8Array), + 790 => Some(Inner::Money), + 791 => Some(Inner::MoneyArray), + 829 => Some(Inner::Macaddr), + 869 => Some(Inner::Inet), + 1000 => Some(Inner::BoolArray), + 1001 => Some(Inner::ByteaArray), + 1002 => Some(Inner::CharArray), + 1003 => Some(Inner::NameArray), + 1005 => Some(Inner::Int2Array), + 1006 => Some(Inner::Int2VectorArray), + 1007 => Some(Inner::Int4Array), + 1008 => Some(Inner::RegprocArray), + 1009 => Some(Inner::TextArray), + 1010 => Some(Inner::TidArray), + 1011 => Some(Inner::XidArray), + 1012 => Some(Inner::CidArray), + 1013 => Some(Inner::OidVectorArray), + 1014 => Some(Inner::BpcharArray), + 1015 => Some(Inner::VarcharArray), + 1016 => Some(Inner::Int8Array), + 1017 => Some(Inner::PointArray), + 1018 => Some(Inner::LsegArray), + 1019 => Some(Inner::PathArray), + 1020 => Some(Inner::BoxArray), + 1021 => Some(Inner::Float4Array), + 1022 => Some(Inner::Float8Array), + 1027 => Some(Inner::PolygonArray), + 1028 => Some(Inner::OidArray), + 1033 => Some(Inner::Aclitem), + 1034 => Some(Inner::AclitemArray), + 1040 => Some(Inner::MacaddrArray), + 1041 => Some(Inner::InetArray), + 1042 => Some(Inner::Bpchar), + 1043 => Some(Inner::Varchar), + 1082 => Some(Inner::Date), + 1083 => Some(Inner::Time), + 1114 => Some(Inner::Timestamp), + 1115 => Some(Inner::TimestampArray), + 1182 => Some(Inner::DateArray), + 1183 => Some(Inner::TimeArray), + 1184 => Some(Inner::Timestamptz), + 1185 => Some(Inner::TimestamptzArray), + 1186 => Some(Inner::Interval), + 1187 => Some(Inner::IntervalArray), + 1231 => Some(Inner::NumericArray), + 1263 => Some(Inner::CstringArray), + 1266 => Some(Inner::Timetz), + 1270 => Some(Inner::TimetzArray), + 1560 => Some(Inner::Bit), + 1561 => Some(Inner::BitArray), + 1562 => Some(Inner::Varbit), + 1563 => Some(Inner::VarbitArray), + 1700 => Some(Inner::Numeric), + 1790 => Some(Inner::Refcursor), + 2201 => Some(Inner::RefcursorArray), + 2202 => Some(Inner::Regprocedure), + 2203 => Some(Inner::Regoper), + 2204 => Some(Inner::Regoperator), + 2205 => Some(Inner::Regclass), + 2206 => Some(Inner::Regtype), + 2207 => Some(Inner::RegprocedureArray), + 2208 => Some(Inner::RegoperArray), + 2209 => Some(Inner::RegoperatorArray), + 2210 => Some(Inner::RegclassArray), + 2211 => Some(Inner::RegtypeArray), + 2249 => Some(Inner::Record), + 2275 => Some(Inner::Cstring), + 2276 => Some(Inner::Any), + 2277 => Some(Inner::Anyarray), + 2278 => Some(Inner::Void), + 2279 => Some(Inner::Trigger), + 2280 => Some(Inner::LanguageHandler), + 2281 => Some(Inner::Internal), + 2283 => Some(Inner::Anyelement), + 2287 => Some(Inner::RecordArray), + 2776 => Some(Inner::Anynonarray), + 2949 => Some(Inner::TxidSnapshotArray), + 2950 => Some(Inner::Uuid), + 2951 => Some(Inner::UuidArray), + 2970 => Some(Inner::TxidSnapshot), + 3115 => Some(Inner::FdwHandler), + 3220 => Some(Inner::PgLsn), + 3221 => Some(Inner::PgLsnArray), + 3310 => Some(Inner::TsmHandler), + 3361 => Some(Inner::PgNdistinct), + 3402 => Some(Inner::PgDependencies), + 3500 => Some(Inner::Anyenum), + 3614 => Some(Inner::TsVector), + 3615 => Some(Inner::Tsquery), + 3642 => Some(Inner::GtsVector), + 3643 => Some(Inner::TsVectorArray), + 3644 => Some(Inner::GtsVectorArray), + 3645 => Some(Inner::TsqueryArray), + 3734 => Some(Inner::Regconfig), + 3735 => Some(Inner::RegconfigArray), + 3769 => Some(Inner::Regdictionary), + 3770 => Some(Inner::RegdictionaryArray), + 3802 => Some(Inner::Jsonb), + 3807 => Some(Inner::JsonbArray), + 3831 => Some(Inner::AnyRange), + 3838 => Some(Inner::EventTrigger), + 3904 => Some(Inner::Int4Range), + 3905 => Some(Inner::Int4RangeArray), + 3906 => Some(Inner::NumRange), + 3907 => Some(Inner::NumRangeArray), + 3908 => Some(Inner::TsRange), + 3909 => Some(Inner::TsRangeArray), + 3910 => Some(Inner::TstzRange), + 3911 => Some(Inner::TstzRangeArray), + 3912 => Some(Inner::DateRange), + 3913 => Some(Inner::DateRangeArray), + 3926 => Some(Inner::Int8Range), + 3927 => Some(Inner::Int8RangeArray), + 4072 => Some(Inner::Jsonpath), + 4073 => Some(Inner::JsonpathArray), + 4089 => Some(Inner::Regnamespace), + 4090 => Some(Inner::RegnamespaceArray), + 4096 => Some(Inner::Regrole), + 4097 => Some(Inner::RegroleArray), + 4191 => Some(Inner::Regcollation), + 4192 => Some(Inner::RegcollationArray), + 4451 => Some(Inner::Int4multiRange), + 4532 => Some(Inner::NummultiRange), + 4533 => Some(Inner::TsmultiRange), + 4534 => Some(Inner::TstzmultiRange), + 4535 => Some(Inner::DatemultiRange), + 4536 => Some(Inner::Int8multiRange), + 4537 => Some(Inner::AnymultiRange), + 4538 => Some(Inner::AnycompatiblemultiRange), + 4600 => Some(Inner::PgBrinBloomSummary), + 4601 => Some(Inner::PgBrinMinmaxMultiSummary), + 5017 => Some(Inner::PgMcvList), + 5038 => Some(Inner::PgSnapshot), + 5039 => Some(Inner::PgSnapshotArray), + 5069 => Some(Inner::Xid8), + 5077 => Some(Inner::Anycompatible), + 5078 => Some(Inner::Anycompatiblearray), + 5079 => Some(Inner::Anycompatiblenonarray), + 5080 => Some(Inner::AnycompatibleRange), + 6150 => Some(Inner::Int4multiRangeArray), + 6151 => Some(Inner::NummultiRangeArray), + 6152 => Some(Inner::TsmultiRangeArray), + 6153 => Some(Inner::TstzmultiRangeArray), + 6155 => Some(Inner::DatemultiRangeArray), + 6157 => Some(Inner::Int8multiRangeArray), + _ => None, + } + } + + pub fn oid(&self) -> Oid { + match *self { + Inner::Bool => 16, + Inner::Bytea => 17, + Inner::Char => 18, + Inner::Name => 19, + Inner::Int8 => 20, + Inner::Int2 => 21, + Inner::Int2Vector => 22, + Inner::Int4 => 23, + Inner::Regproc => 24, + Inner::Text => 25, + Inner::Oid => 26, + Inner::Tid => 27, + Inner::Xid => 28, + Inner::Cid => 29, + Inner::OidVector => 30, + Inner::PgDdlCommand => 32, + Inner::Json => 114, + Inner::Xml => 142, + Inner::XmlArray => 143, + Inner::PgNodeTree => 194, + Inner::JsonArray => 199, + Inner::TableAmHandler => 269, + Inner::Xid8Array => 271, + Inner::IndexAmHandler => 325, + Inner::Point => 600, + Inner::Lseg => 601, + Inner::Path => 602, + Inner::Box => 603, + Inner::Polygon => 604, + Inner::Line => 628, + Inner::LineArray => 629, + Inner::Cidr => 650, + Inner::CidrArray => 651, + Inner::Float4 => 700, + Inner::Float8 => 701, + Inner::Unknown => 705, + Inner::Circle => 718, + Inner::CircleArray => 719, + Inner::Macaddr8 => 774, + Inner::Macaddr8Array => 775, + Inner::Money => 790, + Inner::MoneyArray => 791, + Inner::Macaddr => 829, + Inner::Inet => 869, + Inner::BoolArray => 1000, + Inner::ByteaArray => 1001, + Inner::CharArray => 1002, + Inner::NameArray => 1003, + Inner::Int2Array => 1005, + Inner::Int2VectorArray => 1006, + Inner::Int4Array => 1007, + Inner::RegprocArray => 1008, + Inner::TextArray => 1009, + Inner::TidArray => 1010, + Inner::XidArray => 1011, + Inner::CidArray => 1012, + Inner::OidVectorArray => 1013, + Inner::BpcharArray => 1014, + Inner::VarcharArray => 1015, + Inner::Int8Array => 1016, + Inner::PointArray => 1017, + Inner::LsegArray => 1018, + Inner::PathArray => 1019, + Inner::BoxArray => 1020, + Inner::Float4Array => 1021, + Inner::Float8Array => 1022, + Inner::PolygonArray => 1027, + Inner::OidArray => 1028, + Inner::Aclitem => 1033, + Inner::AclitemArray => 1034, + Inner::MacaddrArray => 1040, + Inner::InetArray => 1041, + Inner::Bpchar => 1042, + Inner::Varchar => 1043, + Inner::Date => 1082, + Inner::Time => 1083, + Inner::Timestamp => 1114, + Inner::TimestampArray => 1115, + Inner::DateArray => 1182, + Inner::TimeArray => 1183, + Inner::Timestamptz => 1184, + Inner::TimestamptzArray => 1185, + Inner::Interval => 1186, + Inner::IntervalArray => 1187, + Inner::NumericArray => 1231, + Inner::CstringArray => 1263, + Inner::Timetz => 1266, + Inner::TimetzArray => 1270, + Inner::Bit => 1560, + Inner::BitArray => 1561, + Inner::Varbit => 1562, + Inner::VarbitArray => 1563, + Inner::Numeric => 1700, + Inner::Refcursor => 1790, + Inner::RefcursorArray => 2201, + Inner::Regprocedure => 2202, + Inner::Regoper => 2203, + Inner::Regoperator => 2204, + Inner::Regclass => 2205, + Inner::Regtype => 2206, + Inner::RegprocedureArray => 2207, + Inner::RegoperArray => 2208, + Inner::RegoperatorArray => 2209, + Inner::RegclassArray => 2210, + Inner::RegtypeArray => 2211, + Inner::Record => 2249, + Inner::Cstring => 2275, + Inner::Any => 2276, + Inner::Anyarray => 2277, + Inner::Void => 2278, + Inner::Trigger => 2279, + Inner::LanguageHandler => 2280, + Inner::Internal => 2281, + Inner::Anyelement => 2283, + Inner::RecordArray => 2287, + Inner::Anynonarray => 2776, + Inner::TxidSnapshotArray => 2949, + Inner::Uuid => 2950, + Inner::UuidArray => 2951, + Inner::TxidSnapshot => 2970, + Inner::FdwHandler => 3115, + Inner::PgLsn => 3220, + Inner::PgLsnArray => 3221, + Inner::TsmHandler => 3310, + Inner::PgNdistinct => 3361, + Inner::PgDependencies => 3402, + Inner::Anyenum => 3500, + Inner::TsVector => 3614, + Inner::Tsquery => 3615, + Inner::GtsVector => 3642, + Inner::TsVectorArray => 3643, + Inner::GtsVectorArray => 3644, + Inner::TsqueryArray => 3645, + Inner::Regconfig => 3734, + Inner::RegconfigArray => 3735, + Inner::Regdictionary => 3769, + Inner::RegdictionaryArray => 3770, + Inner::Jsonb => 3802, + Inner::JsonbArray => 3807, + Inner::AnyRange => 3831, + Inner::EventTrigger => 3838, + Inner::Int4Range => 3904, + Inner::Int4RangeArray => 3905, + Inner::NumRange => 3906, + Inner::NumRangeArray => 3907, + Inner::TsRange => 3908, + Inner::TsRangeArray => 3909, + Inner::TstzRange => 3910, + Inner::TstzRangeArray => 3911, + Inner::DateRange => 3912, + Inner::DateRangeArray => 3913, + Inner::Int8Range => 3926, + Inner::Int8RangeArray => 3927, + Inner::Jsonpath => 4072, + Inner::JsonpathArray => 4073, + Inner::Regnamespace => 4089, + Inner::RegnamespaceArray => 4090, + Inner::Regrole => 4096, + Inner::RegroleArray => 4097, + Inner::Regcollation => 4191, + Inner::RegcollationArray => 4192, + Inner::Int4multiRange => 4451, + Inner::NummultiRange => 4532, + Inner::TsmultiRange => 4533, + Inner::TstzmultiRange => 4534, + Inner::DatemultiRange => 4535, + Inner::Int8multiRange => 4536, + Inner::AnymultiRange => 4537, + Inner::AnycompatiblemultiRange => 4538, + Inner::PgBrinBloomSummary => 4600, + Inner::PgBrinMinmaxMultiSummary => 4601, + Inner::PgMcvList => 5017, + Inner::PgSnapshot => 5038, + Inner::PgSnapshotArray => 5039, + Inner::Xid8 => 5069, + Inner::Anycompatible => 5077, + Inner::Anycompatiblearray => 5078, + Inner::Anycompatiblenonarray => 5079, + Inner::AnycompatibleRange => 5080, + Inner::Int4multiRangeArray => 6150, + Inner::NummultiRangeArray => 6151, + Inner::TsmultiRangeArray => 6152, + Inner::TstzmultiRangeArray => 6153, + Inner::DatemultiRangeArray => 6155, + Inner::Int8multiRangeArray => 6157, + Inner::Other(ref u) => u.oid, + } + } + + pub fn kind(&self) -> &Kind { + match *self { + Inner::Bool => &Kind::Simple, + Inner::Bytea => &Kind::Simple, + Inner::Char => &Kind::Simple, + Inner::Name => &Kind::Simple, + Inner::Int8 => &Kind::Simple, + Inner::Int2 => &Kind::Simple, + Inner::Int2Vector => &Kind::Array(Type(Inner::Int2)), + Inner::Int4 => &Kind::Simple, + Inner::Regproc => &Kind::Simple, + Inner::Text => &Kind::Simple, + Inner::Oid => &Kind::Simple, + Inner::Tid => &Kind::Simple, + Inner::Xid => &Kind::Simple, + Inner::Cid => &Kind::Simple, + Inner::OidVector => &Kind::Array(Type(Inner::Oid)), + Inner::PgDdlCommand => &Kind::Pseudo, + Inner::Json => &Kind::Simple, + Inner::Xml => &Kind::Simple, + Inner::XmlArray => &Kind::Array(Type(Inner::Xml)), + Inner::PgNodeTree => &Kind::Simple, + Inner::JsonArray => &Kind::Array(Type(Inner::Json)), + Inner::TableAmHandler => &Kind::Pseudo, + Inner::Xid8Array => &Kind::Array(Type(Inner::Xid8)), + Inner::IndexAmHandler => &Kind::Pseudo, + Inner::Point => &Kind::Simple, + Inner::Lseg => &Kind::Simple, + Inner::Path => &Kind::Simple, + Inner::Box => &Kind::Simple, + Inner::Polygon => &Kind::Simple, + Inner::Line => &Kind::Simple, + Inner::LineArray => &Kind::Array(Type(Inner::Line)), + Inner::Cidr => &Kind::Simple, + Inner::CidrArray => &Kind::Array(Type(Inner::Cidr)), + Inner::Float4 => &Kind::Simple, + Inner::Float8 => &Kind::Simple, + Inner::Unknown => &Kind::Simple, + Inner::Circle => &Kind::Simple, + Inner::CircleArray => &Kind::Array(Type(Inner::Circle)), + Inner::Macaddr8 => &Kind::Simple, + Inner::Macaddr8Array => &Kind::Array(Type(Inner::Macaddr8)), + Inner::Money => &Kind::Simple, + Inner::MoneyArray => &Kind::Array(Type(Inner::Money)), + Inner::Macaddr => &Kind::Simple, + Inner::Inet => &Kind::Simple, + Inner::BoolArray => &Kind::Array(Type(Inner::Bool)), + Inner::ByteaArray => &Kind::Array(Type(Inner::Bytea)), + Inner::CharArray => &Kind::Array(Type(Inner::Char)), + Inner::NameArray => &Kind::Array(Type(Inner::Name)), + Inner::Int2Array => &Kind::Array(Type(Inner::Int2)), + Inner::Int2VectorArray => &Kind::Array(Type(Inner::Int2Vector)), + Inner::Int4Array => &Kind::Array(Type(Inner::Int4)), + Inner::RegprocArray => &Kind::Array(Type(Inner::Regproc)), + Inner::TextArray => &Kind::Array(Type(Inner::Text)), + Inner::TidArray => &Kind::Array(Type(Inner::Tid)), + Inner::XidArray => &Kind::Array(Type(Inner::Xid)), + Inner::CidArray => &Kind::Array(Type(Inner::Cid)), + Inner::OidVectorArray => &Kind::Array(Type(Inner::OidVector)), + Inner::BpcharArray => &Kind::Array(Type(Inner::Bpchar)), + Inner::VarcharArray => &Kind::Array(Type(Inner::Varchar)), + Inner::Int8Array => &Kind::Array(Type(Inner::Int8)), + Inner::PointArray => &Kind::Array(Type(Inner::Point)), + Inner::LsegArray => &Kind::Array(Type(Inner::Lseg)), + Inner::PathArray => &Kind::Array(Type(Inner::Path)), + Inner::BoxArray => &Kind::Array(Type(Inner::Box)), + Inner::Float4Array => &Kind::Array(Type(Inner::Float4)), + Inner::Float8Array => &Kind::Array(Type(Inner::Float8)), + Inner::PolygonArray => &Kind::Array(Type(Inner::Polygon)), + Inner::OidArray => &Kind::Array(Type(Inner::Oid)), + Inner::Aclitem => &Kind::Simple, + Inner::AclitemArray => &Kind::Array(Type(Inner::Aclitem)), + Inner::MacaddrArray => &Kind::Array(Type(Inner::Macaddr)), + Inner::InetArray => &Kind::Array(Type(Inner::Inet)), + Inner::Bpchar => &Kind::Simple, + Inner::Varchar => &Kind::Simple, + Inner::Date => &Kind::Simple, + Inner::Time => &Kind::Simple, + Inner::Timestamp => &Kind::Simple, + Inner::TimestampArray => &Kind::Array(Type(Inner::Timestamp)), + Inner::DateArray => &Kind::Array(Type(Inner::Date)), + Inner::TimeArray => &Kind::Array(Type(Inner::Time)), + Inner::Timestamptz => &Kind::Simple, + Inner::TimestamptzArray => &Kind::Array(Type(Inner::Timestamptz)), + Inner::Interval => &Kind::Simple, + Inner::IntervalArray => &Kind::Array(Type(Inner::Interval)), + Inner::NumericArray => &Kind::Array(Type(Inner::Numeric)), + Inner::CstringArray => &Kind::Array(Type(Inner::Cstring)), + Inner::Timetz => &Kind::Simple, + Inner::TimetzArray => &Kind::Array(Type(Inner::Timetz)), + Inner::Bit => &Kind::Simple, + Inner::BitArray => &Kind::Array(Type(Inner::Bit)), + Inner::Varbit => &Kind::Simple, + Inner::VarbitArray => &Kind::Array(Type(Inner::Varbit)), + Inner::Numeric => &Kind::Simple, + Inner::Refcursor => &Kind::Simple, + Inner::RefcursorArray => &Kind::Array(Type(Inner::Refcursor)), + Inner::Regprocedure => &Kind::Simple, + Inner::Regoper => &Kind::Simple, + Inner::Regoperator => &Kind::Simple, + Inner::Regclass => &Kind::Simple, + Inner::Regtype => &Kind::Simple, + Inner::RegprocedureArray => &Kind::Array(Type(Inner::Regprocedure)), + Inner::RegoperArray => &Kind::Array(Type(Inner::Regoper)), + Inner::RegoperatorArray => &Kind::Array(Type(Inner::Regoperator)), + Inner::RegclassArray => &Kind::Array(Type(Inner::Regclass)), + Inner::RegtypeArray => &Kind::Array(Type(Inner::Regtype)), + Inner::Record => &Kind::Pseudo, + Inner::Cstring => &Kind::Pseudo, + Inner::Any => &Kind::Pseudo, + Inner::Anyarray => &Kind::Pseudo, + Inner::Void => &Kind::Pseudo, + Inner::Trigger => &Kind::Pseudo, + Inner::LanguageHandler => &Kind::Pseudo, + Inner::Internal => &Kind::Pseudo, + Inner::Anyelement => &Kind::Pseudo, + Inner::RecordArray => &Kind::Pseudo, + Inner::Anynonarray => &Kind::Pseudo, + Inner::TxidSnapshotArray => &Kind::Array(Type(Inner::TxidSnapshot)), + Inner::Uuid => &Kind::Simple, + Inner::UuidArray => &Kind::Array(Type(Inner::Uuid)), + Inner::TxidSnapshot => &Kind::Simple, + Inner::FdwHandler => &Kind::Pseudo, + Inner::PgLsn => &Kind::Simple, + Inner::PgLsnArray => &Kind::Array(Type(Inner::PgLsn)), + Inner::TsmHandler => &Kind::Pseudo, + Inner::PgNdistinct => &Kind::Simple, + Inner::PgDependencies => &Kind::Simple, + Inner::Anyenum => &Kind::Pseudo, + Inner::TsVector => &Kind::Simple, + Inner::Tsquery => &Kind::Simple, + Inner::GtsVector => &Kind::Simple, + Inner::TsVectorArray => &Kind::Array(Type(Inner::TsVector)), + Inner::GtsVectorArray => &Kind::Array(Type(Inner::GtsVector)), + Inner::TsqueryArray => &Kind::Array(Type(Inner::Tsquery)), + Inner::Regconfig => &Kind::Simple, + Inner::RegconfigArray => &Kind::Array(Type(Inner::Regconfig)), + Inner::Regdictionary => &Kind::Simple, + Inner::RegdictionaryArray => &Kind::Array(Type(Inner::Regdictionary)), + Inner::Jsonb => &Kind::Simple, + Inner::JsonbArray => &Kind::Array(Type(Inner::Jsonb)), + Inner::AnyRange => &Kind::Pseudo, + Inner::EventTrigger => &Kind::Pseudo, + Inner::Int4Range => &Kind::Range(Type(Inner::Int4)), + Inner::Int4RangeArray => &Kind::Array(Type(Inner::Int4Range)), + Inner::NumRange => &Kind::Range(Type(Inner::Numeric)), + Inner::NumRangeArray => &Kind::Array(Type(Inner::NumRange)), + Inner::TsRange => &Kind::Range(Type(Inner::Timestamp)), + Inner::TsRangeArray => &Kind::Array(Type(Inner::TsRange)), + Inner::TstzRange => &Kind::Range(Type(Inner::Timestamptz)), + Inner::TstzRangeArray => &Kind::Array(Type(Inner::TstzRange)), + Inner::DateRange => &Kind::Range(Type(Inner::Date)), + Inner::DateRangeArray => &Kind::Array(Type(Inner::DateRange)), + Inner::Int8Range => &Kind::Range(Type(Inner::Int8)), + Inner::Int8RangeArray => &Kind::Array(Type(Inner::Int8Range)), + Inner::Jsonpath => &Kind::Simple, + Inner::JsonpathArray => &Kind::Array(Type(Inner::Jsonpath)), + Inner::Regnamespace => &Kind::Simple, + Inner::RegnamespaceArray => &Kind::Array(Type(Inner::Regnamespace)), + Inner::Regrole => &Kind::Simple, + Inner::RegroleArray => &Kind::Array(Type(Inner::Regrole)), + Inner::Regcollation => &Kind::Simple, + Inner::RegcollationArray => &Kind::Array(Type(Inner::Regcollation)), + Inner::Int4multiRange => &Kind::Multirange(Type(Inner::Int4)), + Inner::NummultiRange => &Kind::Multirange(Type(Inner::Numeric)), + Inner::TsmultiRange => &Kind::Multirange(Type(Inner::Timestamp)), + Inner::TstzmultiRange => &Kind::Multirange(Type(Inner::Timestamptz)), + Inner::DatemultiRange => &Kind::Multirange(Type(Inner::Date)), + Inner::Int8multiRange => &Kind::Multirange(Type(Inner::Int8)), + Inner::AnymultiRange => &Kind::Pseudo, + Inner::AnycompatiblemultiRange => &Kind::Pseudo, + Inner::PgBrinBloomSummary => &Kind::Simple, + Inner::PgBrinMinmaxMultiSummary => &Kind::Simple, + Inner::PgMcvList => &Kind::Simple, + Inner::PgSnapshot => &Kind::Simple, + Inner::PgSnapshotArray => &Kind::Array(Type(Inner::PgSnapshot)), + Inner::Xid8 => &Kind::Simple, + Inner::Anycompatible => &Kind::Pseudo, + Inner::Anycompatiblearray => &Kind::Pseudo, + Inner::Anycompatiblenonarray => &Kind::Pseudo, + Inner::AnycompatibleRange => &Kind::Pseudo, + Inner::Int4multiRangeArray => &Kind::Array(Type(Inner::Int4multiRange)), + Inner::NummultiRangeArray => &Kind::Array(Type(Inner::NummultiRange)), + Inner::TsmultiRangeArray => &Kind::Array(Type(Inner::TsmultiRange)), + Inner::TstzmultiRangeArray => &Kind::Array(Type(Inner::TstzmultiRange)), + Inner::DatemultiRangeArray => &Kind::Array(Type(Inner::DatemultiRange)), + Inner::Int8multiRangeArray => &Kind::Array(Type(Inner::Int8multiRange)), + Inner::Other(ref u) => &u.kind, + } + } + + pub fn name(&self) -> &str { + match *self { + Inner::Bool => "bool", + Inner::Bytea => "bytea", + Inner::Char => "char", + Inner::Name => "name", + Inner::Int8 => "int8", + Inner::Int2 => "int2", + Inner::Int2Vector => "int2vector", + Inner::Int4 => "int4", + Inner::Regproc => "regproc", + Inner::Text => "text", + Inner::Oid => "oid", + Inner::Tid => "tid", + Inner::Xid => "xid", + Inner::Cid => "cid", + Inner::OidVector => "oidvector", + Inner::PgDdlCommand => "pg_ddl_command", + Inner::Json => "json", + Inner::Xml => "xml", + Inner::XmlArray => "_xml", + Inner::PgNodeTree => "pg_node_tree", + Inner::JsonArray => "_json", + Inner::TableAmHandler => "table_am_handler", + Inner::Xid8Array => "_xid8", + Inner::IndexAmHandler => "index_am_handler", + Inner::Point => "point", + Inner::Lseg => "lseg", + Inner::Path => "path", + Inner::Box => "box", + Inner::Polygon => "polygon", + Inner::Line => "line", + Inner::LineArray => "_line", + Inner::Cidr => "cidr", + Inner::CidrArray => "_cidr", + Inner::Float4 => "float4", + Inner::Float8 => "float8", + Inner::Unknown => "unknown", + Inner::Circle => "circle", + Inner::CircleArray => "_circle", + Inner::Macaddr8 => "macaddr8", + Inner::Macaddr8Array => "_macaddr8", + Inner::Money => "money", + Inner::MoneyArray => "_money", + Inner::Macaddr => "macaddr", + Inner::Inet => "inet", + Inner::BoolArray => "_bool", + Inner::ByteaArray => "_bytea", + Inner::CharArray => "_char", + Inner::NameArray => "_name", + Inner::Int2Array => "_int2", + Inner::Int2VectorArray => "_int2vector", + Inner::Int4Array => "_int4", + Inner::RegprocArray => "_regproc", + Inner::TextArray => "_text", + Inner::TidArray => "_tid", + Inner::XidArray => "_xid", + Inner::CidArray => "_cid", + Inner::OidVectorArray => "_oidvector", + Inner::BpcharArray => "_bpchar", + Inner::VarcharArray => "_varchar", + Inner::Int8Array => "_int8", + Inner::PointArray => "_point", + Inner::LsegArray => "_lseg", + Inner::PathArray => "_path", + Inner::BoxArray => "_box", + Inner::Float4Array => "_float4", + Inner::Float8Array => "_float8", + Inner::PolygonArray => "_polygon", + Inner::OidArray => "_oid", + Inner::Aclitem => "aclitem", + Inner::AclitemArray => "_aclitem", + Inner::MacaddrArray => "_macaddr", + Inner::InetArray => "_inet", + Inner::Bpchar => "bpchar", + Inner::Varchar => "varchar", + Inner::Date => "date", + Inner::Time => "time", + Inner::Timestamp => "timestamp", + Inner::TimestampArray => "_timestamp", + Inner::DateArray => "_date", + Inner::TimeArray => "_time", + Inner::Timestamptz => "timestamptz", + Inner::TimestamptzArray => "_timestamptz", + Inner::Interval => "interval", + Inner::IntervalArray => "_interval", + Inner::NumericArray => "_numeric", + Inner::CstringArray => "_cstring", + Inner::Timetz => "timetz", + Inner::TimetzArray => "_timetz", + Inner::Bit => "bit", + Inner::BitArray => "_bit", + Inner::Varbit => "varbit", + Inner::VarbitArray => "_varbit", + Inner::Numeric => "numeric", + Inner::Refcursor => "refcursor", + Inner::RefcursorArray => "_refcursor", + Inner::Regprocedure => "regprocedure", + Inner::Regoper => "regoper", + Inner::Regoperator => "regoperator", + Inner::Regclass => "regclass", + Inner::Regtype => "regtype", + Inner::RegprocedureArray => "_regprocedure", + Inner::RegoperArray => "_regoper", + Inner::RegoperatorArray => "_regoperator", + Inner::RegclassArray => "_regclass", + Inner::RegtypeArray => "_regtype", + Inner::Record => "record", + Inner::Cstring => "cstring", + Inner::Any => "any", + Inner::Anyarray => "anyarray", + Inner::Void => "void", + Inner::Trigger => "trigger", + Inner::LanguageHandler => "language_handler", + Inner::Internal => "internal", + Inner::Anyelement => "anyelement", + Inner::RecordArray => "_record", + Inner::Anynonarray => "anynonarray", + Inner::TxidSnapshotArray => "_txid_snapshot", + Inner::Uuid => "uuid", + Inner::UuidArray => "_uuid", + Inner::TxidSnapshot => "txid_snapshot", + Inner::FdwHandler => "fdw_handler", + Inner::PgLsn => "pg_lsn", + Inner::PgLsnArray => "_pg_lsn", + Inner::TsmHandler => "tsm_handler", + Inner::PgNdistinct => "pg_ndistinct", + Inner::PgDependencies => "pg_dependencies", + Inner::Anyenum => "anyenum", + Inner::TsVector => "tsvector", + Inner::Tsquery => "tsquery", + Inner::GtsVector => "gtsvector", + Inner::TsVectorArray => "_tsvector", + Inner::GtsVectorArray => "_gtsvector", + Inner::TsqueryArray => "_tsquery", + Inner::Regconfig => "regconfig", + Inner::RegconfigArray => "_regconfig", + Inner::Regdictionary => "regdictionary", + Inner::RegdictionaryArray => "_regdictionary", + Inner::Jsonb => "jsonb", + Inner::JsonbArray => "_jsonb", + Inner::AnyRange => "anyrange", + Inner::EventTrigger => "event_trigger", + Inner::Int4Range => "int4range", + Inner::Int4RangeArray => "_int4range", + Inner::NumRange => "numrange", + Inner::NumRangeArray => "_numrange", + Inner::TsRange => "tsrange", + Inner::TsRangeArray => "_tsrange", + Inner::TstzRange => "tstzrange", + Inner::TstzRangeArray => "_tstzrange", + Inner::DateRange => "daterange", + Inner::DateRangeArray => "_daterange", + Inner::Int8Range => "int8range", + Inner::Int8RangeArray => "_int8range", + Inner::Jsonpath => "jsonpath", + Inner::JsonpathArray => "_jsonpath", + Inner::Regnamespace => "regnamespace", + Inner::RegnamespaceArray => "_regnamespace", + Inner::Regrole => "regrole", + Inner::RegroleArray => "_regrole", + Inner::Regcollation => "regcollation", + Inner::RegcollationArray => "_regcollation", + Inner::Int4multiRange => "int4multirange", + Inner::NummultiRange => "nummultirange", + Inner::TsmultiRange => "tsmultirange", + Inner::TstzmultiRange => "tstzmultirange", + Inner::DatemultiRange => "datemultirange", + Inner::Int8multiRange => "int8multirange", + Inner::AnymultiRange => "anymultirange", + Inner::AnycompatiblemultiRange => "anycompatiblemultirange", + Inner::PgBrinBloomSummary => "pg_brin_bloom_summary", + Inner::PgBrinMinmaxMultiSummary => "pg_brin_minmax_multi_summary", + Inner::PgMcvList => "pg_mcv_list", + Inner::PgSnapshot => "pg_snapshot", + Inner::PgSnapshotArray => "_pg_snapshot", + Inner::Xid8 => "xid8", + Inner::Anycompatible => "anycompatible", + Inner::Anycompatiblearray => "anycompatiblearray", + Inner::Anycompatiblenonarray => "anycompatiblenonarray", + Inner::AnycompatibleRange => "anycompatiblerange", + Inner::Int4multiRangeArray => "_int4multirange", + Inner::NummultiRangeArray => "_nummultirange", + Inner::TsmultiRangeArray => "_tsmultirange", + Inner::TstzmultiRangeArray => "_tstzmultirange", + Inner::DatemultiRangeArray => "_datemultirange", + Inner::Int8multiRangeArray => "_int8multirange", + Inner::Other(ref u) => &u.name, + } + } +} +impl Type { + /// BOOL - boolean, 'true'/'false' + pub const BOOL: Type = Type(Inner::Bool); + + /// BYTEA - variable-length string, binary values escaped + pub const BYTEA: Type = Type(Inner::Bytea); + + /// CHAR - single character + pub const CHAR: Type = Type(Inner::Char); + + /// NAME - 63-byte type for storing system identifiers + pub const NAME: Type = Type(Inner::Name); + + /// INT8 - ~18 digit integer, 8-byte storage + pub const INT8: Type = Type(Inner::Int8); + + /// INT2 - -32 thousand to 32 thousand, 2-byte storage + pub const INT2: Type = Type(Inner::Int2); + + /// INT2VECTOR - array of int2, used in system tables + pub const INT2_VECTOR: Type = Type(Inner::Int2Vector); + + /// INT4 - -2 billion to 2 billion integer, 4-byte storage + pub const INT4: Type = Type(Inner::Int4); + + /// REGPROC - registered procedure + pub const REGPROC: Type = Type(Inner::Regproc); + + /// TEXT - variable-length string, no limit specified + pub const TEXT: Type = Type(Inner::Text); + + /// OID - object identifier(oid), maximum 4 billion + pub const OID: Type = Type(Inner::Oid); + + /// TID - (block, offset), physical location of tuple + pub const TID: Type = Type(Inner::Tid); + + /// XID - transaction id + pub const XID: Type = Type(Inner::Xid); + + /// CID - command identifier type, sequence in transaction id + pub const CID: Type = Type(Inner::Cid); + + /// OIDVECTOR - array of oids, used in system tables + pub const OID_VECTOR: Type = Type(Inner::OidVector); + + /// PG_DDL_COMMAND - internal type for passing CollectedCommand + pub const PG_DDL_COMMAND: Type = Type(Inner::PgDdlCommand); + + /// JSON - JSON stored as text + pub const JSON: Type = Type(Inner::Json); + + /// XML - XML content + pub const XML: Type = Type(Inner::Xml); + + /// XML[] + pub const XML_ARRAY: Type = Type(Inner::XmlArray); + + /// PG_NODE_TREE - string representing an internal node tree + pub const PG_NODE_TREE: Type = Type(Inner::PgNodeTree); + + /// JSON[] + pub const JSON_ARRAY: Type = Type(Inner::JsonArray); + + /// TABLE_AM_HANDLER + pub const TABLE_AM_HANDLER: Type = Type(Inner::TableAmHandler); + + /// XID8[] + pub const XID8_ARRAY: Type = Type(Inner::Xid8Array); + + /// INDEX_AM_HANDLER - pseudo-type for the result of an index AM handler function + pub const INDEX_AM_HANDLER: Type = Type(Inner::IndexAmHandler); + + /// POINT - geometric point '(x, y)' + pub const POINT: Type = Type(Inner::Point); + + /// LSEG - geometric line segment '(pt1,pt2)' + pub const LSEG: Type = Type(Inner::Lseg); + + /// PATH - geometric path '(pt1,...)' + pub const PATH: Type = Type(Inner::Path); + + /// BOX - geometric box '(lower left,upper right)' + pub const BOX: Type = Type(Inner::Box); + + /// POLYGON - geometric polygon '(pt1,...)' + pub const POLYGON: Type = Type(Inner::Polygon); + + /// LINE - geometric line + pub const LINE: Type = Type(Inner::Line); + + /// LINE[] + pub const LINE_ARRAY: Type = Type(Inner::LineArray); + + /// CIDR - network IP address/netmask, network address + pub const CIDR: Type = Type(Inner::Cidr); + + /// CIDR[] + pub const CIDR_ARRAY: Type = Type(Inner::CidrArray); + + /// FLOAT4 - single-precision floating point number, 4-byte storage + pub const FLOAT4: Type = Type(Inner::Float4); + + /// FLOAT8 - double-precision floating point number, 8-byte storage + pub const FLOAT8: Type = Type(Inner::Float8); + + /// UNKNOWN - pseudo-type representing an undetermined type + pub const UNKNOWN: Type = Type(Inner::Unknown); + + /// CIRCLE - geometric circle '(center,radius)' + pub const CIRCLE: Type = Type(Inner::Circle); + + /// CIRCLE[] + pub const CIRCLE_ARRAY: Type = Type(Inner::CircleArray); + + /// MACADDR8 - XX:XX:XX:XX:XX:XX:XX:XX, MAC address + pub const MACADDR8: Type = Type(Inner::Macaddr8); + + /// MACADDR8[] + pub const MACADDR8_ARRAY: Type = Type(Inner::Macaddr8Array); + + /// MONEY - monetary amounts, $d,ddd.cc + pub const MONEY: Type = Type(Inner::Money); + + /// MONEY[] + pub const MONEY_ARRAY: Type = Type(Inner::MoneyArray); + + /// MACADDR - XX:XX:XX:XX:XX:XX, MAC address + pub const MACADDR: Type = Type(Inner::Macaddr); + + /// INET - IP address/netmask, host address, netmask optional + pub const INET: Type = Type(Inner::Inet); + + /// BOOL[] + pub const BOOL_ARRAY: Type = Type(Inner::BoolArray); + + /// BYTEA[] + pub const BYTEA_ARRAY: Type = Type(Inner::ByteaArray); + + /// CHAR[] + pub const CHAR_ARRAY: Type = Type(Inner::CharArray); + + /// NAME[] + pub const NAME_ARRAY: Type = Type(Inner::NameArray); + + /// INT2[] + pub const INT2_ARRAY: Type = Type(Inner::Int2Array); + + /// INT2VECTOR[] + pub const INT2_VECTOR_ARRAY: Type = Type(Inner::Int2VectorArray); + + /// INT4[] + pub const INT4_ARRAY: Type = Type(Inner::Int4Array); + + /// REGPROC[] + pub const REGPROC_ARRAY: Type = Type(Inner::RegprocArray); + + /// TEXT[] + pub const TEXT_ARRAY: Type = Type(Inner::TextArray); + + /// TID[] + pub const TID_ARRAY: Type = Type(Inner::TidArray); + + /// XID[] + pub const XID_ARRAY: Type = Type(Inner::XidArray); + + /// CID[] + pub const CID_ARRAY: Type = Type(Inner::CidArray); + + /// OIDVECTOR[] + pub const OID_VECTOR_ARRAY: Type = Type(Inner::OidVectorArray); + + /// BPCHAR[] + pub const BPCHAR_ARRAY: Type = Type(Inner::BpcharArray); + + /// VARCHAR[] + pub const VARCHAR_ARRAY: Type = Type(Inner::VarcharArray); + + /// INT8[] + pub const INT8_ARRAY: Type = Type(Inner::Int8Array); + + /// POINT[] + pub const POINT_ARRAY: Type = Type(Inner::PointArray); + + /// LSEG[] + pub const LSEG_ARRAY: Type = Type(Inner::LsegArray); + + /// PATH[] + pub const PATH_ARRAY: Type = Type(Inner::PathArray); + + /// BOX[] + pub const BOX_ARRAY: Type = Type(Inner::BoxArray); + + /// FLOAT4[] + pub const FLOAT4_ARRAY: Type = Type(Inner::Float4Array); + + /// FLOAT8[] + pub const FLOAT8_ARRAY: Type = Type(Inner::Float8Array); + + /// POLYGON[] + pub const POLYGON_ARRAY: Type = Type(Inner::PolygonArray); + + /// OID[] + pub const OID_ARRAY: Type = Type(Inner::OidArray); + + /// ACLITEM - access control list + pub const ACLITEM: Type = Type(Inner::Aclitem); + + /// ACLITEM[] + pub const ACLITEM_ARRAY: Type = Type(Inner::AclitemArray); + + /// MACADDR[] + pub const MACADDR_ARRAY: Type = Type(Inner::MacaddrArray); + + /// INET[] + pub const INET_ARRAY: Type = Type(Inner::InetArray); + + /// BPCHAR - char(length), blank-padded string, fixed storage length + pub const BPCHAR: Type = Type(Inner::Bpchar); + + /// VARCHAR - varchar(length), non-blank-padded string, variable storage length + pub const VARCHAR: Type = Type(Inner::Varchar); + + /// DATE - date + pub const DATE: Type = Type(Inner::Date); + + /// TIME - time of day + pub const TIME: Type = Type(Inner::Time); + + /// TIMESTAMP - date and time + pub const TIMESTAMP: Type = Type(Inner::Timestamp); + + /// TIMESTAMP[] + pub const TIMESTAMP_ARRAY: Type = Type(Inner::TimestampArray); + + /// DATE[] + pub const DATE_ARRAY: Type = Type(Inner::DateArray); + + /// TIME[] + pub const TIME_ARRAY: Type = Type(Inner::TimeArray); + + /// TIMESTAMPTZ - date and time with time zone + pub const TIMESTAMPTZ: Type = Type(Inner::Timestamptz); + + /// TIMESTAMPTZ[] + pub const TIMESTAMPTZ_ARRAY: Type = Type(Inner::TimestamptzArray); + + /// INTERVAL - @ <number> <units>, time interval + pub const INTERVAL: Type = Type(Inner::Interval); + + /// INTERVAL[] + pub const INTERVAL_ARRAY: Type = Type(Inner::IntervalArray); + + /// NUMERIC[] + pub const NUMERIC_ARRAY: Type = Type(Inner::NumericArray); + + /// CSTRING[] + pub const CSTRING_ARRAY: Type = Type(Inner::CstringArray); + + /// TIMETZ - time of day with time zone + pub const TIMETZ: Type = Type(Inner::Timetz); + + /// TIMETZ[] + pub const TIMETZ_ARRAY: Type = Type(Inner::TimetzArray); + + /// BIT - fixed-length bit string + pub const BIT: Type = Type(Inner::Bit); + + /// BIT[] + pub const BIT_ARRAY: Type = Type(Inner::BitArray); + + /// VARBIT - variable-length bit string + pub const VARBIT: Type = Type(Inner::Varbit); + + /// VARBIT[] + pub const VARBIT_ARRAY: Type = Type(Inner::VarbitArray); + + /// NUMERIC - numeric(precision, decimal), arbitrary precision number + pub const NUMERIC: Type = Type(Inner::Numeric); + + /// REFCURSOR - reference to cursor (portal name) + pub const REFCURSOR: Type = Type(Inner::Refcursor); + + /// REFCURSOR[] + pub const REFCURSOR_ARRAY: Type = Type(Inner::RefcursorArray); + + /// REGPROCEDURE - registered procedure (with args) + pub const REGPROCEDURE: Type = Type(Inner::Regprocedure); + + /// REGOPER - registered operator + pub const REGOPER: Type = Type(Inner::Regoper); + + /// REGOPERATOR - registered operator (with args) + pub const REGOPERATOR: Type = Type(Inner::Regoperator); + + /// REGCLASS - registered class + pub const REGCLASS: Type = Type(Inner::Regclass); + + /// REGTYPE - registered type + pub const REGTYPE: Type = Type(Inner::Regtype); + + /// REGPROCEDURE[] + pub const REGPROCEDURE_ARRAY: Type = Type(Inner::RegprocedureArray); + + /// REGOPER[] + pub const REGOPER_ARRAY: Type = Type(Inner::RegoperArray); + + /// REGOPERATOR[] + pub const REGOPERATOR_ARRAY: Type = Type(Inner::RegoperatorArray); + + /// REGCLASS[] + pub const REGCLASS_ARRAY: Type = Type(Inner::RegclassArray); + + /// REGTYPE[] + pub const REGTYPE_ARRAY: Type = Type(Inner::RegtypeArray); + + /// RECORD - pseudo-type representing any composite type + pub const RECORD: Type = Type(Inner::Record); + + /// CSTRING - C-style string + pub const CSTRING: Type = Type(Inner::Cstring); + + /// ANY - pseudo-type representing any type + pub const ANY: Type = Type(Inner::Any); + + /// ANYARRAY - pseudo-type representing a polymorphic array type + pub const ANYARRAY: Type = Type(Inner::Anyarray); + + /// VOID - pseudo-type for the result of a function with no real result + pub const VOID: Type = Type(Inner::Void); + + /// TRIGGER - pseudo-type for the result of a trigger function + pub const TRIGGER: Type = Type(Inner::Trigger); + + /// LANGUAGE_HANDLER - pseudo-type for the result of a language handler function + pub const LANGUAGE_HANDLER: Type = Type(Inner::LanguageHandler); + + /// INTERNAL - pseudo-type representing an internal data structure + pub const INTERNAL: Type = Type(Inner::Internal); + + /// ANYELEMENT - pseudo-type representing a polymorphic base type + pub const ANYELEMENT: Type = Type(Inner::Anyelement); + + /// RECORD[] + pub const RECORD_ARRAY: Type = Type(Inner::RecordArray); + + /// ANYNONARRAY - pseudo-type representing a polymorphic base type that is not an array + pub const ANYNONARRAY: Type = Type(Inner::Anynonarray); + + /// TXID_SNAPSHOT[] + pub const TXID_SNAPSHOT_ARRAY: Type = Type(Inner::TxidSnapshotArray); + + /// UUID - UUID datatype + pub const UUID: Type = Type(Inner::Uuid); + + /// UUID[] + pub const UUID_ARRAY: Type = Type(Inner::UuidArray); + + /// TXID_SNAPSHOT - txid snapshot + pub const TXID_SNAPSHOT: Type = Type(Inner::TxidSnapshot); + + /// FDW_HANDLER - pseudo-type for the result of an FDW handler function + pub const FDW_HANDLER: Type = Type(Inner::FdwHandler); + + /// PG_LSN - PostgreSQL LSN datatype + pub const PG_LSN: Type = Type(Inner::PgLsn); + + /// PG_LSN[] + pub const PG_LSN_ARRAY: Type = Type(Inner::PgLsnArray); + + /// TSM_HANDLER - pseudo-type for the result of a tablesample method function + pub const TSM_HANDLER: Type = Type(Inner::TsmHandler); + + /// PG_NDISTINCT - multivariate ndistinct coefficients + pub const PG_NDISTINCT: Type = Type(Inner::PgNdistinct); + + /// PG_DEPENDENCIES - multivariate dependencies + pub const PG_DEPENDENCIES: Type = Type(Inner::PgDependencies); + + /// ANYENUM - pseudo-type representing a polymorphic base type that is an enum + pub const ANYENUM: Type = Type(Inner::Anyenum); + + /// TSVECTOR - text representation for text search + pub const TS_VECTOR: Type = Type(Inner::TsVector); + + /// TSQUERY - query representation for text search + pub const TSQUERY: Type = Type(Inner::Tsquery); + + /// GTSVECTOR - GiST index internal text representation for text search + pub const GTS_VECTOR: Type = Type(Inner::GtsVector); + + /// TSVECTOR[] + pub const TS_VECTOR_ARRAY: Type = Type(Inner::TsVectorArray); + + /// GTSVECTOR[] + pub const GTS_VECTOR_ARRAY: Type = Type(Inner::GtsVectorArray); + + /// TSQUERY[] + pub const TSQUERY_ARRAY: Type = Type(Inner::TsqueryArray); + + /// REGCONFIG - registered text search configuration + pub const REGCONFIG: Type = Type(Inner::Regconfig); + + /// REGCONFIG[] + pub const REGCONFIG_ARRAY: Type = Type(Inner::RegconfigArray); + + /// REGDICTIONARY - registered text search dictionary + pub const REGDICTIONARY: Type = Type(Inner::Regdictionary); + + /// REGDICTIONARY[] + pub const REGDICTIONARY_ARRAY: Type = Type(Inner::RegdictionaryArray); + + /// JSONB - Binary JSON + pub const JSONB: Type = Type(Inner::Jsonb); + + /// JSONB[] + pub const JSONB_ARRAY: Type = Type(Inner::JsonbArray); + + /// ANYRANGE - pseudo-type representing a range over a polymorphic base type + pub const ANY_RANGE: Type = Type(Inner::AnyRange); + + /// EVENT_TRIGGER - pseudo-type for the result of an event trigger function + pub const EVENT_TRIGGER: Type = Type(Inner::EventTrigger); + + /// INT4RANGE - range of integers + pub const INT4_RANGE: Type = Type(Inner::Int4Range); + + /// INT4RANGE[] + pub const INT4_RANGE_ARRAY: Type = Type(Inner::Int4RangeArray); + + /// NUMRANGE - range of numerics + pub const NUM_RANGE: Type = Type(Inner::NumRange); + + /// NUMRANGE[] + pub const NUM_RANGE_ARRAY: Type = Type(Inner::NumRangeArray); + + /// TSRANGE - range of timestamps without time zone + pub const TS_RANGE: Type = Type(Inner::TsRange); + + /// TSRANGE[] + pub const TS_RANGE_ARRAY: Type = Type(Inner::TsRangeArray); + + /// TSTZRANGE - range of timestamps with time zone + pub const TSTZ_RANGE: Type = Type(Inner::TstzRange); + + /// TSTZRANGE[] + pub const TSTZ_RANGE_ARRAY: Type = Type(Inner::TstzRangeArray); + + /// DATERANGE - range of dates + pub const DATE_RANGE: Type = Type(Inner::DateRange); + + /// DATERANGE[] + pub const DATE_RANGE_ARRAY: Type = Type(Inner::DateRangeArray); + + /// INT8RANGE - range of bigints + pub const INT8_RANGE: Type = Type(Inner::Int8Range); + + /// INT8RANGE[] + pub const INT8_RANGE_ARRAY: Type = Type(Inner::Int8RangeArray); + + /// JSONPATH - JSON path + pub const JSONPATH: Type = Type(Inner::Jsonpath); + + /// JSONPATH[] + pub const JSONPATH_ARRAY: Type = Type(Inner::JsonpathArray); + + /// REGNAMESPACE - registered namespace + pub const REGNAMESPACE: Type = Type(Inner::Regnamespace); + + /// REGNAMESPACE[] + pub const REGNAMESPACE_ARRAY: Type = Type(Inner::RegnamespaceArray); + + /// REGROLE - registered role + pub const REGROLE: Type = Type(Inner::Regrole); + + /// REGROLE[] + pub const REGROLE_ARRAY: Type = Type(Inner::RegroleArray); + + /// REGCOLLATION - registered collation + pub const REGCOLLATION: Type = Type(Inner::Regcollation); + + /// REGCOLLATION[] + pub const REGCOLLATION_ARRAY: Type = Type(Inner::RegcollationArray); + + /// INT4MULTIRANGE - multirange of integers + pub const INT4MULTI_RANGE: Type = Type(Inner::Int4multiRange); + + /// NUMMULTIRANGE - multirange of numerics + pub const NUMMULTI_RANGE: Type = Type(Inner::NummultiRange); + + /// TSMULTIRANGE - multirange of timestamps without time zone + pub const TSMULTI_RANGE: Type = Type(Inner::TsmultiRange); + + /// TSTZMULTIRANGE - multirange of timestamps with time zone + pub const TSTZMULTI_RANGE: Type = Type(Inner::TstzmultiRange); + + /// DATEMULTIRANGE - multirange of dates + pub const DATEMULTI_RANGE: Type = Type(Inner::DatemultiRange); + + /// INT8MULTIRANGE - multirange of bigints + pub const INT8MULTI_RANGE: Type = Type(Inner::Int8multiRange); + + /// ANYMULTIRANGE - pseudo-type representing a polymorphic base type that is a multirange + pub const ANYMULTI_RANGE: Type = Type(Inner::AnymultiRange); + + /// ANYCOMPATIBLEMULTIRANGE - pseudo-type representing a multirange over a polymorphic common type + pub const ANYCOMPATIBLEMULTI_RANGE: Type = Type(Inner::AnycompatiblemultiRange); + + /// PG_BRIN_BLOOM_SUMMARY - BRIN bloom summary + pub const PG_BRIN_BLOOM_SUMMARY: Type = Type(Inner::PgBrinBloomSummary); + + /// PG_BRIN_MINMAX_MULTI_SUMMARY - BRIN minmax-multi summary + pub const PG_BRIN_MINMAX_MULTI_SUMMARY: Type = Type(Inner::PgBrinMinmaxMultiSummary); + + /// PG_MCV_LIST - multivariate MCV list + pub const PG_MCV_LIST: Type = Type(Inner::PgMcvList); + + /// PG_SNAPSHOT - snapshot + pub const PG_SNAPSHOT: Type = Type(Inner::PgSnapshot); + + /// PG_SNAPSHOT[] + pub const PG_SNAPSHOT_ARRAY: Type = Type(Inner::PgSnapshotArray); + + /// XID8 - full transaction id + pub const XID8: Type = Type(Inner::Xid8); + + /// ANYCOMPATIBLE - pseudo-type representing a polymorphic common type + pub const ANYCOMPATIBLE: Type = Type(Inner::Anycompatible); + + /// ANYCOMPATIBLEARRAY - pseudo-type representing an array of polymorphic common type elements + pub const ANYCOMPATIBLEARRAY: Type = Type(Inner::Anycompatiblearray); + + /// ANYCOMPATIBLENONARRAY - pseudo-type representing a polymorphic common type that is not an array + pub const ANYCOMPATIBLENONARRAY: Type = Type(Inner::Anycompatiblenonarray); + + /// ANYCOMPATIBLERANGE - pseudo-type representing a range over a polymorphic common type + pub const ANYCOMPATIBLE_RANGE: Type = Type(Inner::AnycompatibleRange); + + /// INT4MULTIRANGE[] + pub const INT4MULTI_RANGE_ARRAY: Type = Type(Inner::Int4multiRangeArray); + + /// NUMMULTIRANGE[] + pub const NUMMULTI_RANGE_ARRAY: Type = Type(Inner::NummultiRangeArray); + + /// TSMULTIRANGE[] + pub const TSMULTI_RANGE_ARRAY: Type = Type(Inner::TsmultiRangeArray); + + /// TSTZMULTIRANGE[] + pub const TSTZMULTI_RANGE_ARRAY: Type = Type(Inner::TstzmultiRangeArray); + + /// DATEMULTIRANGE[] + pub const DATEMULTI_RANGE_ARRAY: Type = Type(Inner::DatemultiRangeArray); + + /// INT8MULTIRANGE[] + pub const INT8MULTI_RANGE_ARRAY: Type = Type(Inner::Int8multiRangeArray); +} diff --git a/libs/proxy/tokio-postgres2/Cargo.toml b/libs/proxy/tokio-postgres2/Cargo.toml new file mode 100644 index 0000000000..7130c1b726 --- /dev/null +++ b/libs/proxy/tokio-postgres2/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "tokio-postgres2" +version = "0.1.0" +edition = "2018" +license = "MIT/Apache-2.0" + +[dependencies] +async-trait.workspace = true +bytes.workspace = true +byteorder.workspace = true +fallible-iterator.workspace = true +futures-util = { workspace = true, features = ["sink"] } +log = "0.4" +parking_lot.workspace = true +percent-encoding = "2.0" +pin-project-lite.workspace = true +phf = "0.11" +postgres-protocol2 = { path = "../postgres-protocol2" } +postgres-types2 = { path = "../postgres-types2" } +tokio = { workspace = true, features = ["io-util", "time", "net"] } +tokio-util = { workspace = true, features = ["codec"] } diff --git a/libs/proxy/tokio-postgres2/src/cancel_query.rs b/libs/proxy/tokio-postgres2/src/cancel_query.rs new file mode 100644 index 0000000000..cddbf16336 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/cancel_query.rs @@ -0,0 +1,40 @@ +use tokio::net::TcpStream; + +use crate::client::SocketConfig; +use crate::config::{Host, SslMode}; +use crate::tls::MakeTlsConnect; +use crate::{cancel_query_raw, connect_socket, Error}; +use std::io; + +pub(crate) async fn cancel_query( + config: Option, + ssl_mode: SslMode, + mut tls: T, + process_id: i32, + secret_key: i32, +) -> Result<(), Error> +where + T: MakeTlsConnect, +{ + let config = match config { + Some(config) => config, + None => { + return Err(Error::connect(io::Error::new( + io::ErrorKind::InvalidInput, + "unknown host", + ))) + } + }; + + let hostname = match &config.host { + Host::Tcp(host) => &**host, + }; + let tls = tls + .make_tls_connect(hostname) + .map_err(|e| Error::tls(e.into()))?; + + let socket = + connect_socket::connect_socket(&config.host, config.port, config.connect_timeout).await?; + + cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, process_id, secret_key).await +} diff --git a/libs/proxy/tokio-postgres2/src/cancel_query_raw.rs b/libs/proxy/tokio-postgres2/src/cancel_query_raw.rs new file mode 100644 index 0000000000..8c08296435 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/cancel_query_raw.rs @@ -0,0 +1,29 @@ +use crate::config::SslMode; +use crate::tls::TlsConnect; +use crate::{connect_tls, Error}; +use bytes::BytesMut; +use postgres_protocol2::message::frontend; +use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; + +pub async fn cancel_query_raw( + stream: S, + mode: SslMode, + tls: T, + process_id: i32, + secret_key: i32, +) -> Result<(), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, +{ + let mut stream = connect_tls::connect_tls(stream, mode, tls).await?; + + let mut buf = BytesMut::new(); + frontend::cancel_request(process_id, secret_key, &mut buf); + + stream.write_all(&buf).await.map_err(Error::io)?; + stream.flush().await.map_err(Error::io)?; + stream.shutdown().await.map_err(Error::io)?; + + Ok(()) +} diff --git a/libs/proxy/tokio-postgres2/src/cancel_token.rs b/libs/proxy/tokio-postgres2/src/cancel_token.rs new file mode 100644 index 0000000000..b949bf358f --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/cancel_token.rs @@ -0,0 +1,62 @@ +use crate::config::SslMode; +use crate::tls::TlsConnect; + +use crate::{cancel_query, client::SocketConfig, tls::MakeTlsConnect}; +use crate::{cancel_query_raw, Error}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpStream; + +/// The capability to request cancellation of in-progress queries on a +/// connection. +#[derive(Clone)] +pub struct CancelToken { + pub(crate) socket_config: Option, + pub(crate) ssl_mode: SslMode, + pub(crate) process_id: i32, + pub(crate) secret_key: i32, +} + +impl CancelToken { + /// Attempts to cancel the in-progress query on the connection associated + /// with this `CancelToken`. + /// + /// The server provides no information about whether a cancellation attempt was successful or not. An error will + /// only be returned if the client was unable to connect to the database. + /// + /// Cancellation is inherently racy. There is no guarantee that the + /// cancellation request will reach the server before the query terminates + /// normally, or that the connection associated with this token is still + /// active. + /// + /// Requires the `runtime` Cargo feature (enabled by default). + pub async fn cancel_query(&self, tls: T) -> Result<(), Error> + where + T: MakeTlsConnect, + { + cancel_query::cancel_query( + self.socket_config.clone(), + self.ssl_mode, + tls, + self.process_id, + self.secret_key, + ) + .await + } + + /// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new + /// connection itself. + pub async fn cancel_query_raw(&self, stream: S, tls: T) -> Result<(), Error> + where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, + { + cancel_query_raw::cancel_query_raw( + stream, + self.ssl_mode, + tls, + self.process_id, + self.secret_key, + ) + .await + } +} diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs new file mode 100644 index 0000000000..96200b71e7 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -0,0 +1,439 @@ +use crate::codec::{BackendMessages, FrontendMessage}; + +use crate::config::Host; +use crate::config::SslMode; +use crate::connection::{Request, RequestMessages}; + +use crate::query::RowStream; +use crate::simple_query::SimpleQueryStream; + +use crate::types::{Oid, ToSql, Type}; + +use crate::{ + prepare, query, simple_query, slice_iter, CancelToken, Error, ReadyForQueryStatus, Row, + SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder, +}; +use bytes::BytesMut; +use fallible_iterator::FallibleIterator; +use futures_util::{future, ready, TryStreamExt}; +use parking_lot::Mutex; +use postgres_protocol2::message::{backend::Message, frontend}; +use std::collections::HashMap; +use std::fmt; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::sync::mpsc; + +use std::time::Duration; + +pub struct Responses { + receiver: mpsc::Receiver, + cur: BackendMessages, +} + +impl Responses { + pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll> { + loop { + match self.cur.next().map_err(Error::parse)? { + Some(Message::ErrorResponse(body)) => return Poll::Ready(Err(Error::db(body))), + Some(message) => return Poll::Ready(Ok(message)), + None => {} + } + + match ready!(self.receiver.poll_recv(cx)) { + Some(messages) => self.cur = messages, + None => return Poll::Ready(Err(Error::closed())), + } + } + } + + pub async fn next(&mut self) -> Result { + future::poll_fn(|cx| self.poll_next(cx)).await + } +} + +/// A cache of type info and prepared statements for fetching type info +/// (corresponding to the queries in the [prepare] module). +#[derive(Default)] +struct CachedTypeInfo { + /// A statement for basic information for a type from its + /// OID. Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_QUERY) (or its + /// fallback). + typeinfo: Option, + /// A statement for getting information for a composite type from its OID. + /// Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY). + typeinfo_composite: Option, + /// A statement for getting information for a composite type from its OID. + /// Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_COMPOSITE_QUERY) (or + /// its fallback). + typeinfo_enum: Option, + + /// Cache of types already looked up. + types: HashMap, +} + +pub struct InnerClient { + sender: mpsc::UnboundedSender, + cached_typeinfo: Mutex, + + /// A buffer to use when writing out postgres commands. + buffer: Mutex, +} + +impl InnerClient { + pub fn send(&self, messages: RequestMessages) -> Result { + let (sender, receiver) = mpsc::channel(1); + let request = Request { messages, sender }; + self.sender.send(request).map_err(|_| Error::closed())?; + + Ok(Responses { + receiver, + cur: BackendMessages::empty(), + }) + } + + pub fn typeinfo(&self) -> Option { + self.cached_typeinfo.lock().typeinfo.clone() + } + + pub fn set_typeinfo(&self, statement: &Statement) { + self.cached_typeinfo.lock().typeinfo = Some(statement.clone()); + } + + pub fn typeinfo_composite(&self) -> Option { + self.cached_typeinfo.lock().typeinfo_composite.clone() + } + + pub fn set_typeinfo_composite(&self, statement: &Statement) { + self.cached_typeinfo.lock().typeinfo_composite = Some(statement.clone()); + } + + pub fn typeinfo_enum(&self) -> Option { + self.cached_typeinfo.lock().typeinfo_enum.clone() + } + + pub fn set_typeinfo_enum(&self, statement: &Statement) { + self.cached_typeinfo.lock().typeinfo_enum = Some(statement.clone()); + } + + pub fn type_(&self, oid: Oid) -> Option { + self.cached_typeinfo.lock().types.get(&oid).cloned() + } + + pub fn set_type(&self, oid: Oid, type_: &Type) { + self.cached_typeinfo.lock().types.insert(oid, type_.clone()); + } + + /// Call the given function with a buffer to be used when writing out + /// postgres commands. + pub fn with_buf(&self, f: F) -> R + where + F: FnOnce(&mut BytesMut) -> R, + { + let mut buffer = self.buffer.lock(); + let r = f(&mut buffer); + buffer.clear(); + r + } +} + +#[derive(Clone)] +pub(crate) struct SocketConfig { + pub host: Host, + pub port: u16, + pub connect_timeout: Option, + // pub keepalive: Option, +} + +/// An asynchronous PostgreSQL client. +/// +/// The client is one half of what is returned when a connection is established. Users interact with the database +/// through this client object. +pub struct Client { + inner: Arc, + + socket_config: Option, + ssl_mode: SslMode, + process_id: i32, + secret_key: i32, +} + +impl Client { + pub(crate) fn new( + sender: mpsc::UnboundedSender, + ssl_mode: SslMode, + process_id: i32, + secret_key: i32, + ) -> Client { + Client { + inner: Arc::new(InnerClient { + sender, + cached_typeinfo: Default::default(), + buffer: Default::default(), + }), + + socket_config: None, + ssl_mode, + process_id, + secret_key, + } + } + + /// Returns process_id. + pub fn get_process_id(&self) -> i32 { + self.process_id + } + + pub(crate) fn inner(&self) -> &Arc { + &self.inner + } + + pub(crate) fn set_socket_config(&mut self, socket_config: SocketConfig) { + self.socket_config = Some(socket_config); + } + + /// Creates a new prepared statement. + /// + /// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc), + /// which are set when executed. Prepared statements can only be used with the connection that created them. + pub async fn prepare(&self, query: &str) -> Result { + self.prepare_typed(query, &[]).await + } + + /// Like `prepare`, but allows the types of query parameters to be explicitly specified. + /// + /// The list of types may be smaller than the number of parameters - the types of the remaining parameters will be + /// inferred. For example, `client.prepare_typed(query, &[])` is equivalent to `client.prepare(query)`. + pub async fn prepare_typed( + &self, + query: &str, + parameter_types: &[Type], + ) -> Result { + prepare::prepare(&self.inner, query, parameter_types).await + } + + /// Executes a statement, returning a vector of the resulting rows. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + /// + /// # Panics + /// + /// Panics if the number of parameters provided does not match the number expected. + pub async fn query( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result, Error> + where + T: ?Sized + ToStatement, + { + self.query_raw(statement, slice_iter(params)) + .await? + .try_collect() + .await + } + + /// The maximally flexible version of [`query`]. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + /// + /// # Panics + /// + /// Panics if the number of parameters provided does not match the number expected. + /// + /// [`query`]: #method.query + pub async fn query_raw<'a, T, I>(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let statement = statement.__convert().into_statement(self).await?; + query::query(&self.inner, statement, params).await + } + + /// Pass text directly to the Postgres backend to allow it to sort out typing itself and + /// to save a roundtrip + pub async fn query_raw_txt(&self, statement: &str, params: I) -> Result + where + S: AsRef, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, + { + query::query_txt(&self.inner, statement, params).await + } + + /// Executes a statement, returning the number of rows modified. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + /// + /// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned. + /// + /// # Panics + /// + /// Panics if the number of parameters provided does not match the number expected. + pub async fn execute( + &self, + statement: &T, + params: &[&(dyn ToSql + Sync)], + ) -> Result + where + T: ?Sized + ToStatement, + { + self.execute_raw(statement, slice_iter(params)).await + } + + /// The maximally flexible version of [`execute`]. + /// + /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list + /// provided, 1-indexed. + /// + /// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be + /// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front + /// with the `prepare` method. + /// + /// # Panics + /// + /// Panics if the number of parameters provided does not match the number expected. + /// + /// [`execute`]: #method.execute + pub async fn execute_raw<'a, T, I>(&self, statement: &T, params: I) -> Result + where + T: ?Sized + ToStatement, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let statement = statement.__convert().into_statement(self).await?; + query::execute(self.inner(), statement, params).await + } + + /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows. + /// + /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that + /// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings, + /// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning a list of the + /// rows, this method returns a list of an enum which indicates either the completion of one of the commands, + /// or a row of data. This preserves the framing between the separate statements in the request. + /// + /// # Warning + /// + /// Prepared statements should be use for any query which contains user-specified data, as they provided the + /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass + /// them to this method! + pub async fn simple_query(&self, query: &str) -> Result, Error> { + self.simple_query_raw(query).await?.try_collect().await + } + + pub(crate) async fn simple_query_raw(&self, query: &str) -> Result { + simple_query::simple_query(self.inner(), query).await + } + + /// Executes a sequence of SQL statements using the simple query protocol. + /// + /// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that + /// point. This is intended for use when, for example, initializing a database schema. + /// + /// # Warning + /// + /// Prepared statements should be use for any query which contains user-specified data, as they provided the + /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass + /// them to this method! + pub async fn batch_execute(&self, query: &str) -> Result { + simple_query::batch_execute(self.inner(), query).await + } + + /// Begins a new database transaction. + /// + /// The transaction will roll back by default - use the `commit` method to commit it. + pub async fn transaction(&mut self) -> Result, Error> { + struct RollbackIfNotDone<'me> { + client: &'me Client, + done: bool, + } + + impl Drop for RollbackIfNotDone<'_> { + fn drop(&mut self) { + if self.done { + return; + } + + let buf = self.client.inner().with_buf(|buf| { + frontend::query("ROLLBACK", buf).unwrap(); + buf.split().freeze() + }); + let _ = self + .client + .inner() + .send(RequestMessages::Single(FrontendMessage::Raw(buf))); + } + } + + // This is done, as `Future` created by this method can be dropped after + // `RequestMessages` is synchronously send to the `Connection` by + // `batch_execute()`, but before `Responses` is asynchronously polled to + // completion. In that case `Transaction` won't be created and thus + // won't be rolled back. + { + let mut cleaner = RollbackIfNotDone { + client: self, + done: false, + }; + self.batch_execute("BEGIN").await?; + cleaner.done = true; + } + + Ok(Transaction::new(self)) + } + + /// Returns a builder for a transaction with custom settings. + /// + /// Unlike the `transaction` method, the builder can be used to control the transaction's isolation level and other + /// attributes. + pub fn build_transaction(&mut self) -> TransactionBuilder<'_> { + TransactionBuilder::new(self) + } + + /// Constructs a cancellation token that can later be used to request cancellation of a query running on the + /// connection associated with this client. + pub fn cancel_token(&self) -> CancelToken { + CancelToken { + socket_config: self.socket_config.clone(), + ssl_mode: self.ssl_mode, + process_id: self.process_id, + secret_key: self.secret_key, + } + } + + /// Query for type information + pub async fn get_type(&self, oid: Oid) -> Result { + crate::prepare::get_type(&self.inner, oid).await + } + + /// Determines if the connection to the server has already closed. + /// + /// In that case, all future queries will fail. + pub fn is_closed(&self) -> bool { + self.inner.sender.is_closed() + } +} + +impl fmt::Debug for Client { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Client").finish() + } +} diff --git a/libs/proxy/tokio-postgres2/src/codec.rs b/libs/proxy/tokio-postgres2/src/codec.rs new file mode 100644 index 0000000000..7412db785b --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/codec.rs @@ -0,0 +1,109 @@ +use bytes::{Buf, Bytes, BytesMut}; +use fallible_iterator::FallibleIterator; +use postgres_protocol2::message::backend; +use postgres_protocol2::message::frontend::CopyData; +use std::io; +use tokio_util::codec::{Decoder, Encoder}; + +pub enum FrontendMessage { + Raw(Bytes), + CopyData(CopyData>), +} + +pub enum BackendMessage { + Normal { + messages: BackendMessages, + request_complete: bool, + }, + Async(backend::Message), +} + +pub struct BackendMessages(BytesMut); + +impl BackendMessages { + pub fn empty() -> BackendMessages { + BackendMessages(BytesMut::new()) + } +} + +impl FallibleIterator for BackendMessages { + type Item = backend::Message; + type Error = io::Error; + + fn next(&mut self) -> io::Result> { + backend::Message::parse(&mut self.0) + } +} + +pub struct PostgresCodec { + pub max_message_size: Option, +} + +impl Encoder for PostgresCodec { + type Error = io::Error; + + fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> { + match item { + FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf), + FrontendMessage::CopyData(data) => data.write(dst), + } + + Ok(()) + } +} + +impl Decoder for PostgresCodec { + type Item = BackendMessage; + type Error = io::Error; + + fn decode(&mut self, src: &mut BytesMut) -> Result, io::Error> { + let mut idx = 0; + let mut request_complete = false; + + while let Some(header) = backend::Header::parse(&src[idx..])? { + let len = header.len() as usize + 1; + if src[idx..].len() < len { + break; + } + + if let Some(max) = self.max_message_size { + if len > max { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "message too large", + )); + } + } + + match header.tag() { + backend::NOTICE_RESPONSE_TAG + | backend::NOTIFICATION_RESPONSE_TAG + | backend::PARAMETER_STATUS_TAG => { + if idx == 0 { + let message = backend::Message::parse(src)?.unwrap(); + return Ok(Some(BackendMessage::Async(message))); + } else { + break; + } + } + _ => {} + } + + idx += len; + + if header.tag() == backend::READY_FOR_QUERY_TAG { + request_complete = true; + break; + } + } + + if idx == 0 { + Ok(None) + } else { + Ok(Some(BackendMessage::Normal { + messages: BackendMessages(src.split_to(idx)), + request_complete, + })) + } + } +} diff --git a/libs/proxy/tokio-postgres2/src/config.rs b/libs/proxy/tokio-postgres2/src/config.rs new file mode 100644 index 0000000000..969c20ba47 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/config.rs @@ -0,0 +1,897 @@ +//! Connection configuration. + +use crate::connect::connect; +use crate::connect_raw::connect_raw; +use crate::tls::MakeTlsConnect; +use crate::tls::TlsConnect; +use crate::{Client, Connection, Error}; +use std::borrow::Cow; +use std::str; +use std::str::FromStr; +use std::time::Duration; +use std::{error, fmt, iter, mem}; +use tokio::io::{AsyncRead, AsyncWrite}; + +pub use postgres_protocol2::authentication::sasl::ScramKeys; +use tokio::net::TcpStream; + +/// Properties required of a session. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum TargetSessionAttrs { + /// No special properties are required. + Any, + /// The session must allow writes. + ReadWrite, +} + +/// TLS configuration. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum SslMode { + /// Do not use TLS. + Disable, + /// Attempt to connect with TLS but allow sessions without. + Prefer, + /// Require the use of TLS. + Require, +} + +/// Channel binding configuration. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum ChannelBinding { + /// Do not use channel binding. + Disable, + /// Attempt to use channel binding but allow sessions without. + Prefer, + /// Require the use of channel binding. + Require, +} + +/// Replication mode configuration. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[non_exhaustive] +pub enum ReplicationMode { + /// Physical replication. + Physical, + /// Logical replication. + Logical, +} + +/// A host specification. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Host { + /// A TCP hostname. + Tcp(String), +} + +/// Precomputed keys which may override password during auth. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum AuthKeys { + /// A `ClientKey` & `ServerKey` pair for `SCRAM-SHA-256`. + ScramSha256(ScramKeys<32>), +} + +/// Connection configuration. +/// +/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats: +/// +/// # Key-Value +/// +/// This format consists of space-separated key-value pairs. Values which are either the empty string or contain +/// whitespace should be wrapped in `'`. `'` and `\` characters should be backslash-escaped. +/// +/// ## Keys +/// +/// * `user` - The username to authenticate with. Required. +/// * `password` - The password to authenticate with. +/// * `dbname` - The name of the database to connect to. Defaults to the username. +/// * `options` - Command line options used to configure the server. +/// * `application_name` - Sets the `application_name` parameter on the server. +/// * `sslmode` - Controls usage of TLS. If set to `disable`, TLS will not be used. If set to `prefer`, TLS will be used +/// if available, but not used otherwise. If set to `require`, TLS will be forced to be used. Defaults to `prefer`. +/// * `host` - The host to connect to. On Unix platforms, if the host starts with a `/` character it is treated as the +/// path to the directory containing Unix domain sockets. Otherwise, it is treated as a hostname. Multiple hosts +/// can be specified, separated by commas. Each host will be tried in turn when connecting. Required if connecting +/// with the `connect` method. +/// * `port` - The port to connect to. Multiple ports can be specified, separated by commas. The number of ports must be +/// either 1, in which case it will be used for all hosts, or the same as the number of hosts. Defaults to 5432 if +/// omitted or the empty string. +/// * `connect_timeout` - The time limit in seconds applied to each socket-level connection attempt. Note that hostnames +/// can resolve to multiple IP addresses, and this limit is applied to each address. Defaults to no timeout. +/// * `target_session_attrs` - Specifies requirements of the session. If set to `read-write`, the client will check that +/// the `transaction_read_write` session parameter is set to `on`. This can be used to connect to the primary server +/// in a database cluster as opposed to the secondary read-only mirrors. Defaults to `all`. +/// * `channel_binding` - Controls usage of channel binding in the authentication process. If set to `disable`, channel +/// binding will not be used. If set to `prefer`, channel binding will be used if available, but not used otherwise. +/// If set to `require`, the authentication process will fail if channel binding is not used. Defaults to `prefer`. +/// +/// ## Examples +/// +/// ```not_rust +/// host=localhost user=postgres connect_timeout=10 keepalives=0 +/// ``` +/// +/// ```not_rust +/// host=/var/lib/postgresql,localhost port=1234 user=postgres password='password with spaces' +/// ``` +/// +/// ```not_rust +/// host=host1,host2,host3 port=1234,,5678 user=postgres target_session_attrs=read-write +/// ``` +/// +/// # Url +/// +/// This format resembles a URL with a scheme of either `postgres://` or `postgresql://`. All components are optional, +/// and the format accepts query parameters for all of the key-value pairs described in the section above. Multiple +/// host/port pairs can be comma-separated. Unix socket paths in the host section of the URL should be percent-encoded, +/// as the path component of the URL specifies the database name. +/// +/// ## Examples +/// +/// ```not_rust +/// postgresql://user@localhost +/// ``` +/// +/// ```not_rust +/// postgresql://user:password@%2Fvar%2Flib%2Fpostgresql/mydb?connect_timeout=10 +/// ``` +/// +/// ```not_rust +/// postgresql://user@host1:1234,host2,host3:5678?target_session_attrs=read-write +/// ``` +/// +/// ```not_rust +/// postgresql:///mydb?user=user&host=/var/lib/postgresql +/// ``` +#[derive(Clone, PartialEq, Eq)] +pub struct Config { + pub(crate) user: Option, + pub(crate) password: Option>, + pub(crate) auth_keys: Option>, + pub(crate) dbname: Option, + pub(crate) options: Option, + pub(crate) application_name: Option, + pub(crate) ssl_mode: SslMode, + pub(crate) host: Vec, + pub(crate) port: Vec, + pub(crate) connect_timeout: Option, + pub(crate) target_session_attrs: TargetSessionAttrs, + pub(crate) channel_binding: ChannelBinding, + pub(crate) replication_mode: Option, + pub(crate) max_backend_message_size: Option, +} + +impl Default for Config { + fn default() -> Config { + Config::new() + } +} + +impl Config { + /// Creates a new configuration. + pub fn new() -> Config { + Config { + user: None, + password: None, + auth_keys: None, + dbname: None, + options: None, + application_name: None, + ssl_mode: SslMode::Prefer, + host: vec![], + port: vec![], + connect_timeout: None, + target_session_attrs: TargetSessionAttrs::Any, + channel_binding: ChannelBinding::Prefer, + replication_mode: None, + max_backend_message_size: None, + } + } + + /// Sets the user to authenticate with. + /// + /// Required. + pub fn user(&mut self, user: &str) -> &mut Config { + self.user = Some(user.to_string()); + self + } + + /// Gets the user to authenticate with, if one has been configured with + /// the `user` method. + pub fn get_user(&self) -> Option<&str> { + self.user.as_deref() + } + + /// Sets the password to authenticate with. + pub fn password(&mut self, password: T) -> &mut Config + where + T: AsRef<[u8]>, + { + self.password = Some(password.as_ref().to_vec()); + self + } + + /// Gets the password to authenticate with, if one has been configured with + /// the `password` method. + pub fn get_password(&self) -> Option<&[u8]> { + self.password.as_deref() + } + + /// Sets precomputed protocol-specific keys to authenticate with. + /// When set, this option will override `password`. + /// See [`AuthKeys`] for more information. + pub fn auth_keys(&mut self, keys: AuthKeys) -> &mut Config { + self.auth_keys = Some(Box::new(keys)); + self + } + + /// Gets precomputed protocol-specific keys to authenticate with. + /// if one has been configured with the `auth_keys` method. + pub fn get_auth_keys(&self) -> Option { + self.auth_keys.as_deref().copied() + } + + /// Sets the name of the database to connect to. + /// + /// Defaults to the user. + pub fn dbname(&mut self, dbname: &str) -> &mut Config { + self.dbname = Some(dbname.to_string()); + self + } + + /// Gets the name of the database to connect to, if one has been configured + /// with the `dbname` method. + pub fn get_dbname(&self) -> Option<&str> { + self.dbname.as_deref() + } + + /// Sets command line options used to configure the server. + pub fn options(&mut self, options: &str) -> &mut Config { + self.options = Some(options.to_string()); + self + } + + /// Gets the command line options used to configure the server, if the + /// options have been set with the `options` method. + pub fn get_options(&self) -> Option<&str> { + self.options.as_deref() + } + + /// Sets the value of the `application_name` runtime parameter. + pub fn application_name(&mut self, application_name: &str) -> &mut Config { + self.application_name = Some(application_name.to_string()); + self + } + + /// Gets the value of the `application_name` runtime parameter, if it has + /// been set with the `application_name` method. + pub fn get_application_name(&self) -> Option<&str> { + self.application_name.as_deref() + } + + /// Sets the SSL configuration. + /// + /// Defaults to `prefer`. + pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config { + self.ssl_mode = ssl_mode; + self + } + + /// Gets the SSL configuration. + pub fn get_ssl_mode(&self) -> SslMode { + self.ssl_mode + } + + /// Adds a host to the configuration. + /// + /// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. + pub fn host(&mut self, host: &str) -> &mut Config { + self.host.push(Host::Tcp(host.to_string())); + self + } + + /// Gets the hosts that have been added to the configuration with `host`. + pub fn get_hosts(&self) -> &[Host] { + &self.host + } + + /// Adds a port to the configuration. + /// + /// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which + /// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports + /// as hosts. + pub fn port(&mut self, port: u16) -> &mut Config { + self.port.push(port); + self + } + + /// Gets the ports that have been added to the configuration with `port`. + pub fn get_ports(&self) -> &[u16] { + &self.port + } + + /// Sets the timeout applied to socket-level connection attempts. + /// + /// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each + /// host separately. Defaults to no limit. + pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config { + self.connect_timeout = Some(connect_timeout); + self + } + + /// Gets the connection timeout, if one has been set with the + /// `connect_timeout` method. + pub fn get_connect_timeout(&self) -> Option<&Duration> { + self.connect_timeout.as_ref() + } + + /// Sets the requirements of the session. + /// + /// This can be used to connect to the primary server in a clustered database rather than one of the read-only + /// secondary servers. Defaults to `Any`. + pub fn target_session_attrs( + &mut self, + target_session_attrs: TargetSessionAttrs, + ) -> &mut Config { + self.target_session_attrs = target_session_attrs; + self + } + + /// Gets the requirements of the session. + pub fn get_target_session_attrs(&self) -> TargetSessionAttrs { + self.target_session_attrs + } + + /// Sets the channel binding behavior. + /// + /// Defaults to `prefer`. + pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config { + self.channel_binding = channel_binding; + self + } + + /// Gets the channel binding behavior. + pub fn get_channel_binding(&self) -> ChannelBinding { + self.channel_binding + } + + /// Set replication mode. + pub fn replication_mode(&mut self, replication_mode: ReplicationMode) -> &mut Config { + self.replication_mode = Some(replication_mode); + self + } + + /// Get replication mode. + pub fn get_replication_mode(&self) -> Option { + self.replication_mode + } + + /// Set limit for backend messages size. + pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config { + self.max_backend_message_size = Some(max_backend_message_size); + self + } + + /// Get limit for backend messages size. + pub fn get_max_backend_message_size(&self) -> Option { + self.max_backend_message_size + } + + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { + match key { + "user" => { + self.user(value); + } + "password" => { + self.password(value); + } + "dbname" => { + self.dbname(value); + } + "options" => { + self.options(value); + } + "application_name" => { + self.application_name(value); + } + "sslmode" => { + let mode = match value { + "disable" => SslMode::Disable, + "prefer" => SslMode::Prefer, + "require" => SslMode::Require, + _ => return Err(Error::config_parse(Box::new(InvalidValue("sslmode")))), + }; + self.ssl_mode(mode); + } + "host" => { + for host in value.split(',') { + self.host(host); + } + } + "port" => { + for port in value.split(',') { + let port = if port.is_empty() { + 5432 + } else { + port.parse() + .map_err(|_| Error::config_parse(Box::new(InvalidValue("port"))))? + }; + self.port(port); + } + } + "connect_timeout" => { + let timeout = value + .parse::() + .map_err(|_| Error::config_parse(Box::new(InvalidValue("connect_timeout"))))?; + if timeout > 0 { + self.connect_timeout(Duration::from_secs(timeout as u64)); + } + } + "target_session_attrs" => { + let target_session_attrs = match value { + "any" => TargetSessionAttrs::Any, + "read-write" => TargetSessionAttrs::ReadWrite, + _ => { + return Err(Error::config_parse(Box::new(InvalidValue( + "target_session_attrs", + )))); + } + }; + self.target_session_attrs(target_session_attrs); + } + "channel_binding" => { + let channel_binding = match value { + "disable" => ChannelBinding::Disable, + "prefer" => ChannelBinding::Prefer, + "require" => ChannelBinding::Require, + _ => { + return Err(Error::config_parse(Box::new(InvalidValue( + "channel_binding", + )))) + } + }; + self.channel_binding(channel_binding); + } + "max_backend_message_size" => { + let limit = value.parse::().map_err(|_| { + Error::config_parse(Box::new(InvalidValue("max_backend_message_size"))) + })?; + if limit > 0 { + self.max_backend_message_size(limit); + } + } + key => { + return Err(Error::config_parse(Box::new(UnknownOption( + key.to_string(), + )))); + } + } + + Ok(()) + } + + /// Opens a connection to a PostgreSQL database. + /// + /// Requires the `runtime` Cargo feature (enabled by default). + pub async fn connect( + &self, + tls: T, + ) -> Result<(Client, Connection), Error> + where + T: MakeTlsConnect, + { + connect(tls, self).await + } + + /// Connects to a PostgreSQL database over an arbitrary stream. + /// + /// All of the settings other than `user`, `password`, `dbname`, `options`, and `application_name` name are ignored. + pub async fn connect_raw( + &self, + stream: S, + tls: T, + ) -> Result<(Client, Connection), Error> + where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, + { + connect_raw(stream, tls, self).await + } +} + +impl FromStr for Config { + type Err = Error; + + fn from_str(s: &str) -> Result { + match UrlParser::parse(s)? { + Some(config) => Ok(config), + None => Parser::parse(s), + } + } +} + +// Omit password from debug output +impl fmt::Debug for Config { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct Redaction {} + impl fmt::Debug for Redaction { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "_") + } + } + + f.debug_struct("Config") + .field("user", &self.user) + .field("password", &self.password.as_ref().map(|_| Redaction {})) + .field("dbname", &self.dbname) + .field("options", &self.options) + .field("application_name", &self.application_name) + .field("ssl_mode", &self.ssl_mode) + .field("host", &self.host) + .field("port", &self.port) + .field("connect_timeout", &self.connect_timeout) + .field("target_session_attrs", &self.target_session_attrs) + .field("channel_binding", &self.channel_binding) + .field("replication", &self.replication_mode) + .finish() + } +} + +#[derive(Debug)] +struct UnknownOption(String); + +impl fmt::Display for UnknownOption { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "unknown option `{}`", self.0) + } +} + +impl error::Error for UnknownOption {} + +#[derive(Debug)] +struct InvalidValue(&'static str); + +impl fmt::Display for InvalidValue { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "invalid value for option `{}`", self.0) + } +} + +impl error::Error for InvalidValue {} + +struct Parser<'a> { + s: &'a str, + it: iter::Peekable>, +} + +impl<'a> Parser<'a> { + fn parse(s: &'a str) -> Result { + let mut parser = Parser { + s, + it: s.char_indices().peekable(), + }; + + let mut config = Config::new(); + + while let Some((key, value)) = parser.parameter()? { + config.param(key, &value)?; + } + + Ok(config) + } + + fn skip_ws(&mut self) { + self.take_while(char::is_whitespace); + } + + fn take_while(&mut self, f: F) -> &'a str + where + F: Fn(char) -> bool, + { + let start = match self.it.peek() { + Some(&(i, _)) => i, + None => return "", + }; + + loop { + match self.it.peek() { + Some(&(_, c)) if f(c) => { + self.it.next(); + } + Some(&(i, _)) => return &self.s[start..i], + None => return &self.s[start..], + } + } + } + + fn eat(&mut self, target: char) -> Result<(), Error> { + match self.it.next() { + Some((_, c)) if c == target => Ok(()), + Some((i, c)) => { + let m = format!( + "unexpected character at byte {}: expected `{}` but got `{}`", + i, target, c + ); + Err(Error::config_parse(m.into())) + } + None => Err(Error::config_parse("unexpected EOF".into())), + } + } + + fn eat_if(&mut self, target: char) -> bool { + match self.it.peek() { + Some(&(_, c)) if c == target => { + self.it.next(); + true + } + _ => false, + } + } + + fn keyword(&mut self) -> Option<&'a str> { + let s = self.take_while(|c| match c { + c if c.is_whitespace() => false, + '=' => false, + _ => true, + }); + + if s.is_empty() { + None + } else { + Some(s) + } + } + + fn value(&mut self) -> Result { + let value = if self.eat_if('\'') { + let value = self.quoted_value()?; + self.eat('\'')?; + value + } else { + self.simple_value()? + }; + + Ok(value) + } + + fn simple_value(&mut self) -> Result { + let mut value = String::new(); + + while let Some(&(_, c)) = self.it.peek() { + if c.is_whitespace() { + break; + } + + self.it.next(); + if c == '\\' { + if let Some((_, c2)) = self.it.next() { + value.push(c2); + } + } else { + value.push(c); + } + } + + if value.is_empty() { + return Err(Error::config_parse("unexpected EOF".into())); + } + + Ok(value) + } + + fn quoted_value(&mut self) -> Result { + let mut value = String::new(); + + while let Some(&(_, c)) = self.it.peek() { + if c == '\'' { + return Ok(value); + } + + self.it.next(); + if c == '\\' { + if let Some((_, c2)) = self.it.next() { + value.push(c2); + } + } else { + value.push(c); + } + } + + Err(Error::config_parse( + "unterminated quoted connection parameter value".into(), + )) + } + + fn parameter(&mut self) -> Result, Error> { + self.skip_ws(); + let keyword = match self.keyword() { + Some(keyword) => keyword, + None => return Ok(None), + }; + self.skip_ws(); + self.eat('=')?; + self.skip_ws(); + let value = self.value()?; + + Ok(Some((keyword, value))) + } +} + +// This is a pretty sloppy "URL" parser, but it matches the behavior of libpq, where things really aren't very strict +struct UrlParser<'a> { + s: &'a str, + config: Config, +} + +impl<'a> UrlParser<'a> { + fn parse(s: &'a str) -> Result, Error> { + let s = match Self::remove_url_prefix(s) { + Some(s) => s, + None => return Ok(None), + }; + + let mut parser = UrlParser { + s, + config: Config::new(), + }; + + parser.parse_credentials()?; + parser.parse_host()?; + parser.parse_path()?; + parser.parse_params()?; + + Ok(Some(parser.config)) + } + + fn remove_url_prefix(s: &str) -> Option<&str> { + for prefix in &["postgres://", "postgresql://"] { + if let Some(stripped) = s.strip_prefix(prefix) { + return Some(stripped); + } + } + + None + } + + fn take_until(&mut self, end: &[char]) -> Option<&'a str> { + match self.s.find(end) { + Some(pos) => { + let (head, tail) = self.s.split_at(pos); + self.s = tail; + Some(head) + } + None => None, + } + } + + fn take_all(&mut self) -> &'a str { + mem::take(&mut self.s) + } + + fn eat_byte(&mut self) { + self.s = &self.s[1..]; + } + + fn parse_credentials(&mut self) -> Result<(), Error> { + let creds = match self.take_until(&['@']) { + Some(creds) => creds, + None => return Ok(()), + }; + self.eat_byte(); + + let mut it = creds.splitn(2, ':'); + let user = self.decode(it.next().unwrap())?; + self.config.user(&user); + + if let Some(password) = it.next() { + let password = Cow::from(percent_encoding::percent_decode(password.as_bytes())); + self.config.password(password); + } + + Ok(()) + } + + fn parse_host(&mut self) -> Result<(), Error> { + let host = match self.take_until(&['/', '?']) { + Some(host) => host, + None => self.take_all(), + }; + + if host.is_empty() { + return Ok(()); + } + + for chunk in host.split(',') { + let (host, port) = if chunk.starts_with('[') { + let idx = match chunk.find(']') { + Some(idx) => idx, + None => return Err(Error::config_parse(InvalidValue("host").into())), + }; + + let host = &chunk[1..idx]; + let remaining = &chunk[idx + 1..]; + let port = if let Some(port) = remaining.strip_prefix(':') { + Some(port) + } else if remaining.is_empty() { + None + } else { + return Err(Error::config_parse(InvalidValue("host").into())); + }; + + (host, port) + } else { + let mut it = chunk.splitn(2, ':'); + (it.next().unwrap(), it.next()) + }; + + self.host_param(host)?; + let port = self.decode(port.unwrap_or("5432"))?; + self.config.param("port", &port)?; + } + + Ok(()) + } + + fn parse_path(&mut self) -> Result<(), Error> { + if !self.s.starts_with('/') { + return Ok(()); + } + self.eat_byte(); + + let dbname = match self.take_until(&['?']) { + Some(dbname) => dbname, + None => self.take_all(), + }; + + if !dbname.is_empty() { + self.config.dbname(&self.decode(dbname)?); + } + + Ok(()) + } + + fn parse_params(&mut self) -> Result<(), Error> { + if !self.s.starts_with('?') { + return Ok(()); + } + self.eat_byte(); + + while !self.s.is_empty() { + let key = match self.take_until(&['=']) { + Some(key) => self.decode(key)?, + None => return Err(Error::config_parse("unterminated parameter".into())), + }; + self.eat_byte(); + + let value = match self.take_until(&['&']) { + Some(value) => { + self.eat_byte(); + value + } + None => self.take_all(), + }; + + if key == "host" { + self.host_param(value)?; + } else { + let value = self.decode(value)?; + self.config.param(&key, &value)?; + } + } + + Ok(()) + } + + fn host_param(&mut self, s: &str) -> Result<(), Error> { + let s = self.decode(s)?; + self.config.param("host", &s) + } + + fn decode(&self, s: &'a str) -> Result, Error> { + percent_encoding::percent_decode(s.as_bytes()) + .decode_utf8() + .map_err(|e| Error::config_parse(e.into())) + } +} diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs new file mode 100644 index 0000000000..7517fe0cde --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/connect.rs @@ -0,0 +1,112 @@ +use crate::client::SocketConfig; +use crate::config::{Host, TargetSessionAttrs}; +use crate::connect_raw::connect_raw; +use crate::connect_socket::connect_socket; +use crate::tls::{MakeTlsConnect, TlsConnect}; +use crate::{Client, Config, Connection, Error, SimpleQueryMessage}; +use futures_util::{future, pin_mut, Future, FutureExt, Stream}; +use std::io; +use std::task::Poll; +use tokio::net::TcpStream; + +pub async fn connect( + mut tls: T, + config: &Config, +) -> Result<(Client, Connection), Error> +where + T: MakeTlsConnect, +{ + if config.host.is_empty() { + return Err(Error::config("host missing".into())); + } + + if config.port.len() > 1 && config.port.len() != config.host.len() { + return Err(Error::config("invalid number of ports".into())); + } + + let mut error = None; + for (i, host) in config.host.iter().enumerate() { + let port = config + .port + .get(i) + .or_else(|| config.port.first()) + .copied() + .unwrap_or(5432); + + let hostname = match host { + Host::Tcp(host) => host.as_str(), + }; + + let tls = tls + .make_tls_connect(hostname) + .map_err(|e| Error::tls(e.into()))?; + + match connect_once(host, port, tls, config).await { + Ok((client, connection)) => return Ok((client, connection)), + Err(e) => error = Some(e), + } + } + + Err(error.unwrap()) +} + +async fn connect_once( + host: &Host, + port: u16, + tls: T, + config: &Config, +) -> Result<(Client, Connection), Error> +where + T: TlsConnect, +{ + let socket = connect_socket(host, port, config.connect_timeout).await?; + let (mut client, mut connection) = connect_raw(socket, tls, config).await?; + + if let TargetSessionAttrs::ReadWrite = config.target_session_attrs { + let rows = client.simple_query_raw("SHOW transaction_read_only"); + pin_mut!(rows); + + let rows = future::poll_fn(|cx| { + if connection.poll_unpin(cx)?.is_ready() { + return Poll::Ready(Err(Error::closed())); + } + + rows.as_mut().poll(cx) + }) + .await?; + pin_mut!(rows); + + loop { + let next = future::poll_fn(|cx| { + if connection.poll_unpin(cx)?.is_ready() { + return Poll::Ready(Some(Err(Error::closed()))); + } + + rows.as_mut().poll_next(cx) + }); + + match next.await.transpose()? { + Some(SimpleQueryMessage::Row(row)) => { + if row.try_get(0)? == Some("on") { + return Err(Error::connect(io::Error::new( + io::ErrorKind::PermissionDenied, + "database does not allow writes", + ))); + } else { + break; + } + } + Some(_) => {} + None => return Err(Error::unexpected_message()), + } + } + } + + client.set_socket_config(SocketConfig { + host: host.clone(), + port, + connect_timeout: config.connect_timeout, + }); + + Ok((client, connection)) +} diff --git a/libs/proxy/tokio-postgres2/src/connect_raw.rs b/libs/proxy/tokio-postgres2/src/connect_raw.rs new file mode 100644 index 0000000000..80677af969 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/connect_raw.rs @@ -0,0 +1,359 @@ +use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; +use crate::config::{self, AuthKeys, Config, ReplicationMode}; +use crate::connect_tls::connect_tls; +use crate::maybe_tls_stream::MaybeTlsStream; +use crate::tls::{TlsConnect, TlsStream}; +use crate::{Client, Connection, Error}; +use bytes::BytesMut; +use fallible_iterator::FallibleIterator; +use futures_util::{ready, Sink, SinkExt, Stream, TryStreamExt}; +use postgres_protocol2::authentication; +use postgres_protocol2::authentication::sasl; +use postgres_protocol2::authentication::sasl::ScramSha256; +use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message}; +use postgres_protocol2::message::frontend; +use std::collections::{HashMap, VecDeque}; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::sync::mpsc; +use tokio_util::codec::Framed; + +pub struct StartupStream { + inner: Framed, PostgresCodec>, + buf: BackendMessages, + delayed: VecDeque, +} + +impl Sink for StartupStream +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + type Error = io::Error; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_ready(cx) + } + + fn start_send(mut self: Pin<&mut Self>, item: FrontendMessage) -> io::Result<()> { + Pin::new(&mut self.inner).start_send(item) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.inner).poll_close(cx) + } +} + +impl Stream for StartupStream +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + type Item = io::Result; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { + loop { + match self.buf.next() { + Ok(Some(message)) => return Poll::Ready(Some(Ok(message))), + Ok(None) => {} + Err(e) => return Poll::Ready(Some(Err(e))), + } + + match ready!(Pin::new(&mut self.inner).poll_next(cx)) { + Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages, + Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))), + Some(Err(e)) => return Poll::Ready(Some(Err(e))), + None => return Poll::Ready(None), + } + } + } +} + +pub async fn connect_raw( + stream: S, + tls: T, + config: &Config, +) -> Result<(Client, Connection), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, +{ + let stream = connect_tls(stream, config.ssl_mode, tls).await?; + + let mut stream = StartupStream { + inner: Framed::new( + stream, + PostgresCodec { + max_message_size: config.max_backend_message_size, + }, + ), + buf: BackendMessages::empty(), + delayed: VecDeque::new(), + }; + + startup(&mut stream, config).await?; + authenticate(&mut stream, config).await?; + let (process_id, secret_key, parameters) = read_info(&mut stream).await?; + + let (sender, receiver) = mpsc::unbounded_channel(); + let client = Client::new(sender, config.ssl_mode, process_id, secret_key); + let connection = Connection::new(stream.inner, stream.delayed, parameters, receiver); + + Ok((client, connection)) +} + +async fn startup(stream: &mut StartupStream, config: &Config) -> Result<(), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + let mut params = vec![("client_encoding", "UTF8")]; + if let Some(user) = &config.user { + params.push(("user", &**user)); + } + if let Some(dbname) = &config.dbname { + params.push(("database", &**dbname)); + } + if let Some(options) = &config.options { + params.push(("options", &**options)); + } + if let Some(application_name) = &config.application_name { + params.push(("application_name", &**application_name)); + } + if let Some(replication_mode) = &config.replication_mode { + match replication_mode { + ReplicationMode::Physical => params.push(("replication", "true")), + ReplicationMode::Logical => params.push(("replication", "database")), + } + } + + let mut buf = BytesMut::new(); + frontend::startup_message(params, &mut buf).map_err(Error::encode)?; + + stream + .send(FrontendMessage::Raw(buf.freeze())) + .await + .map_err(Error::io) +} + +async fn authenticate(stream: &mut StartupStream, config: &Config) -> Result<(), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsStream + Unpin, +{ + match stream.try_next().await.map_err(Error::io)? { + Some(Message::AuthenticationOk) => { + can_skip_channel_binding(config)?; + return Ok(()); + } + Some(Message::AuthenticationCleartextPassword) => { + can_skip_channel_binding(config)?; + + let pass = config + .password + .as_ref() + .ok_or_else(|| Error::config("password missing".into()))?; + + authenticate_password(stream, pass).await?; + } + Some(Message::AuthenticationMd5Password(body)) => { + can_skip_channel_binding(config)?; + + let user = config + .user + .as_ref() + .ok_or_else(|| Error::config("user missing".into()))?; + let pass = config + .password + .as_ref() + .ok_or_else(|| Error::config("password missing".into()))?; + + let output = authentication::md5_hash(user.as_bytes(), pass, body.salt()); + authenticate_password(stream, output.as_bytes()).await?; + } + Some(Message::AuthenticationSasl(body)) => { + authenticate_sasl(stream, body, config).await?; + } + Some(Message::AuthenticationKerberosV5) + | Some(Message::AuthenticationScmCredential) + | Some(Message::AuthenticationGss) + | Some(Message::AuthenticationSspi) => { + return Err(Error::authentication( + "unsupported authentication method".into(), + )) + } + Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), + Some(_) => return Err(Error::unexpected_message()), + None => return Err(Error::closed()), + } + + match stream.try_next().await.map_err(Error::io)? { + Some(Message::AuthenticationOk) => Ok(()), + Some(Message::ErrorResponse(body)) => Err(Error::db(body)), + Some(_) => Err(Error::unexpected_message()), + None => Err(Error::closed()), + } +} + +fn can_skip_channel_binding(config: &Config) -> Result<(), Error> { + match config.channel_binding { + config::ChannelBinding::Disable | config::ChannelBinding::Prefer => Ok(()), + config::ChannelBinding::Require => Err(Error::authentication( + "server did not use channel binding".into(), + )), + } +} + +async fn authenticate_password( + stream: &mut StartupStream, + password: &[u8], +) -> Result<(), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + let mut buf = BytesMut::new(); + frontend::password_message(password, &mut buf).map_err(Error::encode)?; + + stream + .send(FrontendMessage::Raw(buf.freeze())) + .await + .map_err(Error::io) +} + +async fn authenticate_sasl( + stream: &mut StartupStream, + body: AuthenticationSaslBody, + config: &Config, +) -> Result<(), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsStream + Unpin, +{ + let mut has_scram = false; + let mut has_scram_plus = false; + let mut mechanisms = body.mechanisms(); + while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? { + match mechanism { + sasl::SCRAM_SHA_256 => has_scram = true, + sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true, + _ => {} + } + } + + let channel_binding = stream + .inner + .get_ref() + .channel_binding() + .tls_server_end_point + .filter(|_| config.channel_binding != config::ChannelBinding::Disable) + .map(sasl::ChannelBinding::tls_server_end_point); + + let (channel_binding, mechanism) = if has_scram_plus { + match channel_binding { + Some(channel_binding) => (channel_binding, sasl::SCRAM_SHA_256_PLUS), + None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256), + } + } else if has_scram { + match channel_binding { + Some(_) => (sasl::ChannelBinding::unrequested(), sasl::SCRAM_SHA_256), + None => (sasl::ChannelBinding::unsupported(), sasl::SCRAM_SHA_256), + } + } else { + return Err(Error::authentication("unsupported SASL mechanism".into())); + }; + + if mechanism != sasl::SCRAM_SHA_256_PLUS { + can_skip_channel_binding(config)?; + } + + let mut scram = if let Some(AuthKeys::ScramSha256(keys)) = config.get_auth_keys() { + ScramSha256::new_with_keys(keys, channel_binding) + } else if let Some(password) = config.get_password() { + ScramSha256::new(password, channel_binding) + } else { + return Err(Error::config("password or auth keys missing".into())); + }; + + let mut buf = BytesMut::new(); + frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?; + stream + .send(FrontendMessage::Raw(buf.freeze())) + .await + .map_err(Error::io)?; + + let body = match stream.try_next().await.map_err(Error::io)? { + Some(Message::AuthenticationSaslContinue(body)) => body, + Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), + Some(_) => return Err(Error::unexpected_message()), + None => return Err(Error::closed()), + }; + + scram + .update(body.data()) + .await + .map_err(|e| Error::authentication(e.into()))?; + + let mut buf = BytesMut::new(); + frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?; + stream + .send(FrontendMessage::Raw(buf.freeze())) + .await + .map_err(Error::io)?; + + let body = match stream.try_next().await.map_err(Error::io)? { + Some(Message::AuthenticationSaslFinal(body)) => body, + Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), + Some(_) => return Err(Error::unexpected_message()), + None => return Err(Error::closed()), + }; + + scram + .finish(body.data()) + .map_err(|e| Error::authentication(e.into()))?; + + Ok(()) +} + +async fn read_info( + stream: &mut StartupStream, +) -> Result<(i32, i32, HashMap), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + let mut process_id = 0; + let mut secret_key = 0; + let mut parameters = HashMap::new(); + + loop { + match stream.try_next().await.map_err(Error::io)? { + Some(Message::BackendKeyData(body)) => { + process_id = body.process_id(); + secret_key = body.secret_key(); + } + Some(Message::ParameterStatus(body)) => { + parameters.insert( + body.name().map_err(Error::parse)?.to_string(), + body.value().map_err(Error::parse)?.to_string(), + ); + } + Some(msg @ Message::NoticeResponse(_)) => { + stream.delayed.push_back(BackendMessage::Async(msg)) + } + Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)), + Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), + Some(_) => return Err(Error::unexpected_message()), + None => return Err(Error::closed()), + } + } +} diff --git a/libs/proxy/tokio-postgres2/src/connect_socket.rs b/libs/proxy/tokio-postgres2/src/connect_socket.rs new file mode 100644 index 0000000000..336a13317f --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/connect_socket.rs @@ -0,0 +1,65 @@ +use crate::config::Host; +use crate::Error; +use std::future::Future; +use std::io; +use std::time::Duration; +use tokio::net::{self, TcpStream}; +use tokio::time; + +pub(crate) async fn connect_socket( + host: &Host, + port: u16, + connect_timeout: Option, +) -> Result { + match host { + Host::Tcp(host) => { + let addrs = net::lookup_host((&**host, port)) + .await + .map_err(Error::connect)?; + + let mut last_err = None; + + for addr in addrs { + let stream = + match connect_with_timeout(TcpStream::connect(addr), connect_timeout).await { + Ok(stream) => stream, + Err(e) => { + last_err = Some(e); + continue; + } + }; + + stream.set_nodelay(true).map_err(Error::connect)?; + + return Ok(stream); + } + + Err(last_err.unwrap_or_else(|| { + Error::connect(io::Error::new( + io::ErrorKind::InvalidInput, + "could not resolve any addresses", + )) + })) + } + } +} + +async fn connect_with_timeout(connect: F, timeout: Option) -> Result +where + F: Future>, +{ + match timeout { + Some(timeout) => match time::timeout(timeout, connect).await { + Ok(Ok(socket)) => Ok(socket), + Ok(Err(e)) => Err(Error::connect(e)), + Err(_) => Err(Error::connect(io::Error::new( + io::ErrorKind::TimedOut, + "connection timed out", + ))), + }, + None => match connect.await { + Ok(socket) => Ok(socket), + Err(e) => Err(Error::connect(e)), + }, + } +} diff --git a/libs/proxy/tokio-postgres2/src/connect_tls.rs b/libs/proxy/tokio-postgres2/src/connect_tls.rs new file mode 100644 index 0000000000..64b0b68abc --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/connect_tls.rs @@ -0,0 +1,48 @@ +use crate::config::SslMode; +use crate::maybe_tls_stream::MaybeTlsStream; +use crate::tls::private::ForcePrivateApi; +use crate::tls::TlsConnect; +use crate::Error; +use bytes::BytesMut; +use postgres_protocol2::message::frontend; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; + +pub async fn connect_tls( + mut stream: S, + mode: SslMode, + tls: T, +) -> Result, Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsConnect, +{ + match mode { + SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)), + SslMode::Prefer if !tls.can_connect(ForcePrivateApi) => { + return Ok(MaybeTlsStream::Raw(stream)) + } + SslMode::Prefer | SslMode::Require => {} + } + + let mut buf = BytesMut::new(); + frontend::ssl_request(&mut buf); + stream.write_all(&buf).await.map_err(Error::io)?; + + let mut buf = [0]; + stream.read_exact(&mut buf).await.map_err(Error::io)?; + + if buf[0] != b'S' { + if SslMode::Require == mode { + return Err(Error::tls("server does not support TLS".into())); + } else { + return Ok(MaybeTlsStream::Raw(stream)); + } + } + + let stream = tls + .connect(stream) + .await + .map_err(|e| Error::tls(e.into()))?; + + Ok(MaybeTlsStream::Tls(stream)) +} diff --git a/libs/proxy/tokio-postgres2/src/connection.rs b/libs/proxy/tokio-postgres2/src/connection.rs new file mode 100644 index 0000000000..0aa5c77e22 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/connection.rs @@ -0,0 +1,323 @@ +use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec}; +use crate::error::DbError; +use crate::maybe_tls_stream::MaybeTlsStream; +use crate::{AsyncMessage, Error, Notification}; +use bytes::BytesMut; +use fallible_iterator::FallibleIterator; +use futures_util::{ready, Sink, Stream}; +use log::{info, trace}; +use postgres_protocol2::message::backend::Message; +use postgres_protocol2::message::frontend; +use std::collections::{HashMap, VecDeque}; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::sync::mpsc; +use tokio_util::codec::Framed; +use tokio_util::sync::PollSender; + +pub enum RequestMessages { + Single(FrontendMessage), +} + +pub struct Request { + pub messages: RequestMessages, + pub sender: mpsc::Sender, +} + +pub struct Response { + sender: PollSender, +} + +#[derive(PartialEq, Debug)] +enum State { + Active, + Terminating, + Closing, +} + +/// A connection to a PostgreSQL database. +/// +/// This is one half of what is returned when a new connection is established. It performs the actual IO with the +/// server, and should generally be spawned off onto an executor to run in the background. +/// +/// `Connection` implements `Future`, and only resolves when the connection is closed, either because a fatal error has +/// occurred, or because its associated `Client` has dropped and all outstanding work has completed. +#[must_use = "futures do nothing unless polled"] +pub struct Connection { + /// HACK: we need this in the Neon Proxy. + pub stream: Framed, PostgresCodec>, + /// HACK: we need this in the Neon Proxy to forward params. + pub parameters: HashMap, + receiver: mpsc::UnboundedReceiver, + pending_request: Option, + pending_responses: VecDeque, + responses: VecDeque, + state: State, +} + +impl Connection +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + pub(crate) fn new( + stream: Framed, PostgresCodec>, + pending_responses: VecDeque, + parameters: HashMap, + receiver: mpsc::UnboundedReceiver, + ) -> Connection { + Connection { + stream, + parameters, + receiver, + pending_request: None, + pending_responses, + responses: VecDeque::new(), + state: State::Active, + } + } + + fn poll_response( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + if let Some(message) = self.pending_responses.pop_front() { + trace!("retrying pending response"); + return Poll::Ready(Some(Ok(message))); + } + + Pin::new(&mut self.stream) + .poll_next(cx) + .map(|o| o.map(|r| r.map_err(Error::io))) + } + + fn poll_read(&mut self, cx: &mut Context<'_>) -> Result, Error> { + if self.state != State::Active { + trace!("poll_read: done"); + return Ok(None); + } + + loop { + let message = match self.poll_response(cx)? { + Poll::Ready(Some(message)) => message, + Poll::Ready(None) => return Err(Error::closed()), + Poll::Pending => { + trace!("poll_read: waiting on response"); + return Ok(None); + } + }; + + let (mut messages, request_complete) = match message { + BackendMessage::Async(Message::NoticeResponse(body)) => { + let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?; + return Ok(Some(AsyncMessage::Notice(error))); + } + BackendMessage::Async(Message::NotificationResponse(body)) => { + let notification = Notification { + process_id: body.process_id(), + channel: body.channel().map_err(Error::parse)?.to_string(), + payload: body.message().map_err(Error::parse)?.to_string(), + }; + return Ok(Some(AsyncMessage::Notification(notification))); + } + BackendMessage::Async(Message::ParameterStatus(body)) => { + self.parameters.insert( + body.name().map_err(Error::parse)?.to_string(), + body.value().map_err(Error::parse)?.to_string(), + ); + continue; + } + BackendMessage::Async(_) => unreachable!(), + BackendMessage::Normal { + messages, + request_complete, + } => (messages, request_complete), + }; + + let mut response = match self.responses.pop_front() { + Some(response) => response, + None => match messages.next().map_err(Error::parse)? { + Some(Message::ErrorResponse(error)) => return Err(Error::db(error)), + _ => return Err(Error::unexpected_message()), + }, + }; + + match response.sender.poll_reserve(cx) { + Poll::Ready(Ok(())) => { + let _ = response.sender.send_item(messages); + if !request_complete { + self.responses.push_front(response); + } + } + Poll::Ready(Err(_)) => { + // we need to keep paging through the rest of the messages even if the receiver's hung up + if !request_complete { + self.responses.push_front(response); + } + } + Poll::Pending => { + self.responses.push_front(response); + self.pending_responses.push_back(BackendMessage::Normal { + messages, + request_complete, + }); + trace!("poll_read: waiting on sender"); + return Ok(None); + } + } + } + } + + fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Some(messages) = self.pending_request.take() { + trace!("retrying pending request"); + return Poll::Ready(Some(messages)); + } + + if self.receiver.is_closed() { + return Poll::Ready(None); + } + + match self.receiver.poll_recv(cx) { + Poll::Ready(Some(request)) => { + trace!("polled new request"); + self.responses.push_back(Response { + sender: PollSender::new(request.sender), + }); + Poll::Ready(Some(request.messages)) + } + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } + + fn poll_write(&mut self, cx: &mut Context<'_>) -> Result { + loop { + if self.state == State::Closing { + trace!("poll_write: done"); + return Ok(false); + } + + if Pin::new(&mut self.stream) + .poll_ready(cx) + .map_err(Error::io)? + .is_pending() + { + trace!("poll_write: waiting on socket"); + return Ok(false); + } + + let request = match self.poll_request(cx) { + Poll::Ready(Some(request)) => request, + Poll::Ready(None) if self.responses.is_empty() && self.state == State::Active => { + trace!("poll_write: at eof, terminating"); + self.state = State::Terminating; + let mut request = BytesMut::new(); + frontend::terminate(&mut request); + RequestMessages::Single(FrontendMessage::Raw(request.freeze())) + } + Poll::Ready(None) => { + trace!( + "poll_write: at eof, pending responses {}", + self.responses.len() + ); + return Ok(true); + } + Poll::Pending => { + trace!("poll_write: waiting on request"); + return Ok(true); + } + }; + + match request { + RequestMessages::Single(request) => { + Pin::new(&mut self.stream) + .start_send(request) + .map_err(Error::io)?; + if self.state == State::Terminating { + trace!("poll_write: sent eof, closing"); + self.state = State::Closing; + } + } + } + } + } + + fn poll_flush(&mut self, cx: &mut Context<'_>) -> Result<(), Error> { + match Pin::new(&mut self.stream) + .poll_flush(cx) + .map_err(Error::io)? + { + Poll::Ready(()) => trace!("poll_flush: flushed"), + Poll::Pending => trace!("poll_flush: waiting on socket"), + } + Ok(()) + } + + fn poll_shutdown(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.state != State::Closing { + return Poll::Pending; + } + + match Pin::new(&mut self.stream) + .poll_close(cx) + .map_err(Error::io)? + { + Poll::Ready(()) => { + trace!("poll_shutdown: complete"); + Poll::Ready(Ok(())) + } + Poll::Pending => { + trace!("poll_shutdown: waiting on socket"); + Poll::Pending + } + } + } + + /// Returns the value of a runtime parameter for this connection. + pub fn parameter(&self, name: &str) -> Option<&str> { + self.parameters.get(name).map(|s| &**s) + } + + /// Polls for asynchronous messages from the server. + /// + /// The server can send notices as well as notifications asynchronously to the client. Applications that wish to + /// examine those messages should use this method to drive the connection rather than its `Future` implementation. + pub fn poll_message( + &mut self, + cx: &mut Context<'_>, + ) -> Poll>> { + let message = self.poll_read(cx)?; + let want_flush = self.poll_write(cx)?; + if want_flush { + self.poll_flush(cx)?; + } + match message { + Some(message) => Poll::Ready(Some(Ok(message))), + None => match self.poll_shutdown(cx) { + Poll::Ready(Ok(())) => Poll::Ready(None), + Poll::Ready(Err(e)) => Poll::Ready(Some(Err(e))), + Poll::Pending => Poll::Pending, + }, + } + } +} + +impl Future for Connection +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + type Output = Result<(), Error>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + while let Some(message) = ready!(self.poll_message(cx)?) { + if let AsyncMessage::Notice(notice) = message { + info!("{}: {}", notice.severity(), notice.message()); + } + } + Poll::Ready(Ok(())) + } +} diff --git a/libs/proxy/tokio-postgres2/src/error/mod.rs b/libs/proxy/tokio-postgres2/src/error/mod.rs new file mode 100644 index 0000000000..6514322250 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/error/mod.rs @@ -0,0 +1,501 @@ +//! Errors. + +use fallible_iterator::FallibleIterator; +use postgres_protocol2::message::backend::{ErrorFields, ErrorResponseBody}; +use std::error::{self, Error as _Error}; +use std::fmt; +use std::io; + +pub use self::sqlstate::*; + +#[allow(clippy::unreadable_literal)] +mod sqlstate; + +/// The severity of a Postgres error or notice. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum Severity { + /// PANIC + Panic, + /// FATAL + Fatal, + /// ERROR + Error, + /// WARNING + Warning, + /// NOTICE + Notice, + /// DEBUG + Debug, + /// INFO + Info, + /// LOG + Log, +} + +impl fmt::Display for Severity { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match *self { + Severity::Panic => "PANIC", + Severity::Fatal => "FATAL", + Severity::Error => "ERROR", + Severity::Warning => "WARNING", + Severity::Notice => "NOTICE", + Severity::Debug => "DEBUG", + Severity::Info => "INFO", + Severity::Log => "LOG", + }; + fmt.write_str(s) + } +} + +impl Severity { + fn from_str(s: &str) -> Option { + match s { + "PANIC" => Some(Severity::Panic), + "FATAL" => Some(Severity::Fatal), + "ERROR" => Some(Severity::Error), + "WARNING" => Some(Severity::Warning), + "NOTICE" => Some(Severity::Notice), + "DEBUG" => Some(Severity::Debug), + "INFO" => Some(Severity::Info), + "LOG" => Some(Severity::Log), + _ => None, + } + } +} + +/// A Postgres error or notice. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DbError { + severity: String, + parsed_severity: Option, + code: SqlState, + message: String, + detail: Option, + hint: Option, + position: Option, + where_: Option, + schema: Option, + table: Option, + column: Option, + datatype: Option, + constraint: Option, + file: Option, + line: Option, + routine: Option, +} + +impl DbError { + pub(crate) fn parse(fields: &mut ErrorFields<'_>) -> io::Result { + let mut severity = None; + let mut parsed_severity = None; + let mut code = None; + let mut message = None; + let mut detail = None; + let mut hint = None; + let mut normal_position = None; + let mut internal_position = None; + let mut internal_query = None; + let mut where_ = None; + let mut schema = None; + let mut table = None; + let mut column = None; + let mut datatype = None; + let mut constraint = None; + let mut file = None; + let mut line = None; + let mut routine = None; + + while let Some(field) = fields.next()? { + match field.type_() { + b'S' => severity = Some(field.value().to_owned()), + b'C' => code = Some(SqlState::from_code(field.value())), + b'M' => message = Some(field.value().to_owned()), + b'D' => detail = Some(field.value().to_owned()), + b'H' => hint = Some(field.value().to_owned()), + b'P' => { + normal_position = Some(field.value().parse::().map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "`P` field did not contain an integer", + ) + })?); + } + b'p' => { + internal_position = Some(field.value().parse::().map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "`p` field did not contain an integer", + ) + })?); + } + b'q' => internal_query = Some(field.value().to_owned()), + b'W' => where_ = Some(field.value().to_owned()), + b's' => schema = Some(field.value().to_owned()), + b't' => table = Some(field.value().to_owned()), + b'c' => column = Some(field.value().to_owned()), + b'd' => datatype = Some(field.value().to_owned()), + b'n' => constraint = Some(field.value().to_owned()), + b'F' => file = Some(field.value().to_owned()), + b'L' => { + line = Some(field.value().parse::().map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "`L` field did not contain an integer", + ) + })?); + } + b'R' => routine = Some(field.value().to_owned()), + b'V' => { + parsed_severity = Some(Severity::from_str(field.value()).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "`V` field contained an invalid value", + ) + })?); + } + _ => {} + } + } + + Ok(DbError { + severity: severity + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`S` field missing"))?, + parsed_severity, + code: code + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`C` field missing"))?, + message: message + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "`M` field missing"))?, + detail, + hint, + position: match normal_position { + Some(position) => Some(ErrorPosition::Original(position)), + None => match internal_position { + Some(position) => Some(ErrorPosition::Internal { + position, + query: internal_query.ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "`q` field missing but `p` field present", + ) + })?, + }), + None => None, + }, + }, + where_, + schema, + table, + column, + datatype, + constraint, + file, + line, + routine, + }) + } + + /// The field contents are ERROR, FATAL, or PANIC (in an error message), + /// or WARNING, NOTICE, DEBUG, INFO, or LOG (in a notice message), or a + /// localized translation of one of these. + pub fn severity(&self) -> &str { + &self.severity + } + + /// A parsed, nonlocalized version of `severity`. (PostgreSQL 9.6+) + pub fn parsed_severity(&self) -> Option { + self.parsed_severity + } + + /// The SQLSTATE code for the error. + pub fn code(&self) -> &SqlState { + &self.code + } + + /// The primary human-readable error message. + /// + /// This should be accurate but terse (typically one line). + pub fn message(&self) -> &str { + &self.message + } + + /// An optional secondary error message carrying more detail about the + /// problem. + /// + /// Might run to multiple lines. + pub fn detail(&self) -> Option<&str> { + self.detail.as_deref() + } + + /// An optional suggestion what to do about the problem. + /// + /// This is intended to differ from `detail` in that it offers advice + /// (potentially inappropriate) rather than hard facts. Might run to + /// multiple lines. + pub fn hint(&self) -> Option<&str> { + self.hint.as_deref() + } + + /// An optional error cursor position into either the original query string + /// or an internally generated query. + pub fn position(&self) -> Option<&ErrorPosition> { + self.position.as_ref() + } + + /// An indication of the context in which the error occurred. + /// + /// Presently this includes a call stack traceback of active procedural + /// language functions and internally-generated queries. The trace is one + /// entry per line, most recent first. + pub fn where_(&self) -> Option<&str> { + self.where_.as_deref() + } + + /// If the error was associated with a specific database object, the name + /// of the schema containing that object, if any. (PostgreSQL 9.3+) + pub fn schema(&self) -> Option<&str> { + self.schema.as_deref() + } + + /// If the error was associated with a specific table, the name of the + /// table. (Refer to the schema name field for the name of the table's + /// schema.) (PostgreSQL 9.3+) + pub fn table(&self) -> Option<&str> { + self.table.as_deref() + } + + /// If the error was associated with a specific table column, the name of + /// the column. + /// + /// (Refer to the schema and table name fields to identify the table.) + /// (PostgreSQL 9.3+) + pub fn column(&self) -> Option<&str> { + self.column.as_deref() + } + + /// If the error was associated with a specific data type, the name of the + /// data type. (Refer to the schema name field for the name of the data + /// type's schema.) (PostgreSQL 9.3+) + pub fn datatype(&self) -> Option<&str> { + self.datatype.as_deref() + } + + /// If the error was associated with a specific constraint, the name of the + /// constraint. + /// + /// Refer to fields listed above for the associated table or domain. + /// (For this purpose, indexes are treated as constraints, even if they + /// weren't created with constraint syntax.) (PostgreSQL 9.3+) + pub fn constraint(&self) -> Option<&str> { + self.constraint.as_deref() + } + + /// The file name of the source-code location where the error was reported. + pub fn file(&self) -> Option<&str> { + self.file.as_deref() + } + + /// The line number of the source-code location where the error was + /// reported. + pub fn line(&self) -> Option { + self.line + } + + /// The name of the source-code routine reporting the error. + pub fn routine(&self) -> Option<&str> { + self.routine.as_deref() + } +} + +impl fmt::Display for DbError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(fmt, "{}: {}", self.severity, self.message)?; + if let Some(detail) = &self.detail { + write!(fmt, "\nDETAIL: {}", detail)?; + } + if let Some(hint) = &self.hint { + write!(fmt, "\nHINT: {}", hint)?; + } + Ok(()) + } +} + +impl error::Error for DbError {} + +/// Represents the position of an error in a query. +#[derive(Clone, PartialEq, Eq, Debug)] +pub enum ErrorPosition { + /// A position in the original query. + Original(u32), + /// A position in an internally generated query. + Internal { + /// The byte position. + position: u32, + /// A query generated by the Postgres server. + query: String, + }, +} + +#[derive(Debug, PartialEq)] +enum Kind { + Io, + UnexpectedMessage, + Tls, + ToSql(usize), + FromSql(usize), + Column(String), + Closed, + Db, + Parse, + Encode, + Authentication, + ConfigParse, + Config, + Connect, + Timeout, +} + +struct ErrorInner { + kind: Kind, + cause: Option>, +} + +/// An error communicating with the Postgres server. +pub struct Error(Box); + +impl fmt::Debug for Error { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Error") + .field("kind", &self.0.kind) + .field("cause", &self.0.cause) + .finish() + } +} + +impl fmt::Display for Error { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0.kind { + Kind::Io => fmt.write_str("error communicating with the server")?, + Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?, + Kind::Tls => fmt.write_str("error performing TLS handshake")?, + Kind::ToSql(idx) => write!(fmt, "error serializing parameter {}", idx)?, + Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?, + Kind::Column(column) => write!(fmt, "invalid column `{}`", column)?, + Kind::Closed => fmt.write_str("connection closed")?, + Kind::Db => fmt.write_str("db error")?, + Kind::Parse => fmt.write_str("error parsing response from server")?, + Kind::Encode => fmt.write_str("error encoding message to server")?, + Kind::Authentication => fmt.write_str("authentication error")?, + Kind::ConfigParse => fmt.write_str("invalid connection string")?, + Kind::Config => fmt.write_str("invalid configuration")?, + Kind::Connect => fmt.write_str("error connecting to server")?, + Kind::Timeout => fmt.write_str("timeout waiting for server")?, + }; + if let Some(ref cause) = self.0.cause { + write!(fmt, ": {}", cause)?; + } + Ok(()) + } +} + +impl error::Error for Error { + fn source(&self) -> Option<&(dyn error::Error + 'static)> { + self.0.cause.as_ref().map(|e| &**e as _) + } +} + +impl Error { + /// Consumes the error, returning its cause. + pub fn into_source(self) -> Option> { + self.0.cause + } + + /// Returns the source of this error if it was a `DbError`. + /// + /// This is a simple convenience method. + pub fn as_db_error(&self) -> Option<&DbError> { + self.source().and_then(|e| e.downcast_ref::()) + } + + /// Determines if the error was associated with closed connection. + pub fn is_closed(&self) -> bool { + self.0.kind == Kind::Closed + } + + /// Returns the SQLSTATE error code associated with the error. + /// + /// This is a convenience method that downcasts the cause to a `DbError` and returns its code. + pub fn code(&self) -> Option<&SqlState> { + self.as_db_error().map(DbError::code) + } + + fn new(kind: Kind, cause: Option>) -> Error { + Error(Box::new(ErrorInner { kind, cause })) + } + + pub(crate) fn closed() -> Error { + Error::new(Kind::Closed, None) + } + + pub(crate) fn unexpected_message() -> Error { + Error::new(Kind::UnexpectedMessage, None) + } + + #[allow(clippy::needless_pass_by_value)] + pub(crate) fn db(error: ErrorResponseBody) -> Error { + match DbError::parse(&mut error.fields()) { + Ok(e) => Error::new(Kind::Db, Some(Box::new(e))), + Err(e) => Error::new(Kind::Parse, Some(Box::new(e))), + } + } + + pub(crate) fn parse(e: io::Error) -> Error { + Error::new(Kind::Parse, Some(Box::new(e))) + } + + pub(crate) fn encode(e: io::Error) -> Error { + Error::new(Kind::Encode, Some(Box::new(e))) + } + + #[allow(clippy::wrong_self_convention)] + pub(crate) fn to_sql(e: Box, idx: usize) -> Error { + Error::new(Kind::ToSql(idx), Some(e)) + } + + pub(crate) fn from_sql(e: Box, idx: usize) -> Error { + Error::new(Kind::FromSql(idx), Some(e)) + } + + pub(crate) fn column(column: String) -> Error { + Error::new(Kind::Column(column), None) + } + + pub(crate) fn tls(e: Box) -> Error { + Error::new(Kind::Tls, Some(e)) + } + + pub(crate) fn io(e: io::Error) -> Error { + Error::new(Kind::Io, Some(Box::new(e))) + } + + pub(crate) fn authentication(e: Box) -> Error { + Error::new(Kind::Authentication, Some(e)) + } + + pub(crate) fn config_parse(e: Box) -> Error { + Error::new(Kind::ConfigParse, Some(e)) + } + + pub(crate) fn config(e: Box) -> Error { + Error::new(Kind::Config, Some(e)) + } + + pub(crate) fn connect(e: io::Error) -> Error { + Error::new(Kind::Connect, Some(Box::new(e))) + } + + #[doc(hidden)] + pub fn __private_api_timeout() -> Error { + Error::new(Kind::Timeout, None) + } +} diff --git a/libs/proxy/tokio-postgres2/src/error/sqlstate.rs b/libs/proxy/tokio-postgres2/src/error/sqlstate.rs new file mode 100644 index 0000000000..13a1d75f95 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/error/sqlstate.rs @@ -0,0 +1,1670 @@ +// Autogenerated file - DO NOT EDIT + +/// A SQLSTATE error code +#[derive(PartialEq, Eq, Clone, Debug)] +pub struct SqlState(Inner); + +impl SqlState { + /// Creates a `SqlState` from its error code. + pub fn from_code(s: &str) -> SqlState { + match SQLSTATE_MAP.get(s) { + Some(state) => state.clone(), + None => SqlState(Inner::Other(s.into())), + } + } + + /// Returns the error code corresponding to the `SqlState`. + pub fn code(&self) -> &str { + match &self.0 { + Inner::E00000 => "00000", + Inner::E01000 => "01000", + Inner::E0100C => "0100C", + Inner::E01008 => "01008", + Inner::E01003 => "01003", + Inner::E01007 => "01007", + Inner::E01006 => "01006", + Inner::E01004 => "01004", + Inner::E01P01 => "01P01", + Inner::E02000 => "02000", + Inner::E02001 => "02001", + Inner::E03000 => "03000", + Inner::E08000 => "08000", + Inner::E08003 => "08003", + Inner::E08006 => "08006", + Inner::E08001 => "08001", + Inner::E08004 => "08004", + Inner::E08007 => "08007", + Inner::E08P01 => "08P01", + Inner::E09000 => "09000", + Inner::E0A000 => "0A000", + Inner::E0B000 => "0B000", + Inner::E0F000 => "0F000", + Inner::E0F001 => "0F001", + Inner::E0L000 => "0L000", + Inner::E0LP01 => "0LP01", + Inner::E0P000 => "0P000", + Inner::E0Z000 => "0Z000", + Inner::E0Z002 => "0Z002", + Inner::E20000 => "20000", + Inner::E21000 => "21000", + Inner::E22000 => "22000", + Inner::E2202E => "2202E", + Inner::E22021 => "22021", + Inner::E22008 => "22008", + Inner::E22012 => "22012", + Inner::E22005 => "22005", + Inner::E2200B => "2200B", + Inner::E22022 => "22022", + Inner::E22015 => "22015", + Inner::E2201E => "2201E", + Inner::E22014 => "22014", + Inner::E22016 => "22016", + Inner::E2201F => "2201F", + Inner::E2201G => "2201G", + Inner::E22018 => "22018", + Inner::E22007 => "22007", + Inner::E22019 => "22019", + Inner::E2200D => "2200D", + Inner::E22025 => "22025", + Inner::E22P06 => "22P06", + Inner::E22010 => "22010", + Inner::E22023 => "22023", + Inner::E22013 => "22013", + Inner::E2201B => "2201B", + Inner::E2201W => "2201W", + Inner::E2201X => "2201X", + Inner::E2202H => "2202H", + Inner::E2202G => "2202G", + Inner::E22009 => "22009", + Inner::E2200C => "2200C", + Inner::E2200G => "2200G", + Inner::E22004 => "22004", + Inner::E22002 => "22002", + Inner::E22003 => "22003", + Inner::E2200H => "2200H", + Inner::E22026 => "22026", + Inner::E22001 => "22001", + Inner::E22011 => "22011", + Inner::E22027 => "22027", + Inner::E22024 => "22024", + Inner::E2200F => "2200F", + Inner::E22P01 => "22P01", + Inner::E22P02 => "22P02", + Inner::E22P03 => "22P03", + Inner::E22P04 => "22P04", + Inner::E22P05 => "22P05", + Inner::E2200L => "2200L", + Inner::E2200M => "2200M", + Inner::E2200N => "2200N", + Inner::E2200S => "2200S", + Inner::E2200T => "2200T", + Inner::E22030 => "22030", + Inner::E22031 => "22031", + Inner::E22032 => "22032", + Inner::E22033 => "22033", + Inner::E22034 => "22034", + Inner::E22035 => "22035", + Inner::E22036 => "22036", + Inner::E22037 => "22037", + Inner::E22038 => "22038", + Inner::E22039 => "22039", + Inner::E2203A => "2203A", + Inner::E2203B => "2203B", + Inner::E2203C => "2203C", + Inner::E2203D => "2203D", + Inner::E2203E => "2203E", + Inner::E2203F => "2203F", + Inner::E2203G => "2203G", + Inner::E23000 => "23000", + Inner::E23001 => "23001", + Inner::E23502 => "23502", + Inner::E23503 => "23503", + Inner::E23505 => "23505", + Inner::E23514 => "23514", + Inner::E23P01 => "23P01", + Inner::E24000 => "24000", + Inner::E25000 => "25000", + Inner::E25001 => "25001", + Inner::E25002 => "25002", + Inner::E25008 => "25008", + Inner::E25003 => "25003", + Inner::E25004 => "25004", + Inner::E25005 => "25005", + Inner::E25006 => "25006", + Inner::E25007 => "25007", + Inner::E25P01 => "25P01", + Inner::E25P02 => "25P02", + Inner::E25P03 => "25P03", + Inner::E26000 => "26000", + Inner::E27000 => "27000", + Inner::E28000 => "28000", + Inner::E28P01 => "28P01", + Inner::E2B000 => "2B000", + Inner::E2BP01 => "2BP01", + Inner::E2D000 => "2D000", + Inner::E2F000 => "2F000", + Inner::E2F005 => "2F005", + Inner::E2F002 => "2F002", + Inner::E2F003 => "2F003", + Inner::E2F004 => "2F004", + Inner::E34000 => "34000", + Inner::E38000 => "38000", + Inner::E38001 => "38001", + Inner::E38002 => "38002", + Inner::E38003 => "38003", + Inner::E38004 => "38004", + Inner::E39000 => "39000", + Inner::E39001 => "39001", + Inner::E39004 => "39004", + Inner::E39P01 => "39P01", + Inner::E39P02 => "39P02", + Inner::E39P03 => "39P03", + Inner::E3B000 => "3B000", + Inner::E3B001 => "3B001", + Inner::E3D000 => "3D000", + Inner::E3F000 => "3F000", + Inner::E40000 => "40000", + Inner::E40002 => "40002", + Inner::E40001 => "40001", + Inner::E40003 => "40003", + Inner::E40P01 => "40P01", + Inner::E42000 => "42000", + Inner::E42601 => "42601", + Inner::E42501 => "42501", + Inner::E42846 => "42846", + Inner::E42803 => "42803", + Inner::E42P20 => "42P20", + Inner::E42P19 => "42P19", + Inner::E42830 => "42830", + Inner::E42602 => "42602", + Inner::E42622 => "42622", + Inner::E42939 => "42939", + Inner::E42804 => "42804", + Inner::E42P18 => "42P18", + Inner::E42P21 => "42P21", + Inner::E42P22 => "42P22", + Inner::E42809 => "42809", + Inner::E428C9 => "428C9", + Inner::E42703 => "42703", + Inner::E42883 => "42883", + Inner::E42P01 => "42P01", + Inner::E42P02 => "42P02", + Inner::E42704 => "42704", + Inner::E42701 => "42701", + Inner::E42P03 => "42P03", + Inner::E42P04 => "42P04", + Inner::E42723 => "42723", + Inner::E42P05 => "42P05", + Inner::E42P06 => "42P06", + Inner::E42P07 => "42P07", + Inner::E42712 => "42712", + Inner::E42710 => "42710", + Inner::E42702 => "42702", + Inner::E42725 => "42725", + Inner::E42P08 => "42P08", + Inner::E42P09 => "42P09", + Inner::E42P10 => "42P10", + Inner::E42611 => "42611", + Inner::E42P11 => "42P11", + Inner::E42P12 => "42P12", + Inner::E42P13 => "42P13", + Inner::E42P14 => "42P14", + Inner::E42P15 => "42P15", + Inner::E42P16 => "42P16", + Inner::E42P17 => "42P17", + Inner::E44000 => "44000", + Inner::E53000 => "53000", + Inner::E53100 => "53100", + Inner::E53200 => "53200", + Inner::E53300 => "53300", + Inner::E53400 => "53400", + Inner::E54000 => "54000", + Inner::E54001 => "54001", + Inner::E54011 => "54011", + Inner::E54023 => "54023", + Inner::E55000 => "55000", + Inner::E55006 => "55006", + Inner::E55P02 => "55P02", + Inner::E55P03 => "55P03", + Inner::E55P04 => "55P04", + Inner::E57000 => "57000", + Inner::E57014 => "57014", + Inner::E57P01 => "57P01", + Inner::E57P02 => "57P02", + Inner::E57P03 => "57P03", + Inner::E57P04 => "57P04", + Inner::E57P05 => "57P05", + Inner::E58000 => "58000", + Inner::E58030 => "58030", + Inner::E58P01 => "58P01", + Inner::E58P02 => "58P02", + Inner::E72000 => "72000", + Inner::EF0000 => "F0000", + Inner::EF0001 => "F0001", + Inner::EHV000 => "HV000", + Inner::EHV005 => "HV005", + Inner::EHV002 => "HV002", + Inner::EHV010 => "HV010", + Inner::EHV021 => "HV021", + Inner::EHV024 => "HV024", + Inner::EHV007 => "HV007", + Inner::EHV008 => "HV008", + Inner::EHV004 => "HV004", + Inner::EHV006 => "HV006", + Inner::EHV091 => "HV091", + Inner::EHV00B => "HV00B", + Inner::EHV00C => "HV00C", + Inner::EHV00D => "HV00D", + Inner::EHV090 => "HV090", + Inner::EHV00A => "HV00A", + Inner::EHV009 => "HV009", + Inner::EHV014 => "HV014", + Inner::EHV001 => "HV001", + Inner::EHV00P => "HV00P", + Inner::EHV00J => "HV00J", + Inner::EHV00K => "HV00K", + Inner::EHV00Q => "HV00Q", + Inner::EHV00R => "HV00R", + Inner::EHV00L => "HV00L", + Inner::EHV00M => "HV00M", + Inner::EHV00N => "HV00N", + Inner::EP0000 => "P0000", + Inner::EP0001 => "P0001", + Inner::EP0002 => "P0002", + Inner::EP0003 => "P0003", + Inner::EP0004 => "P0004", + Inner::EXX000 => "XX000", + Inner::EXX001 => "XX001", + Inner::EXX002 => "XX002", + Inner::Other(code) => code, + } + } + + /// 00000 + pub const SUCCESSFUL_COMPLETION: SqlState = SqlState(Inner::E00000); + + /// 01000 + pub const WARNING: SqlState = SqlState(Inner::E01000); + + /// 0100C + pub const WARNING_DYNAMIC_RESULT_SETS_RETURNED: SqlState = SqlState(Inner::E0100C); + + /// 01008 + pub const WARNING_IMPLICIT_ZERO_BIT_PADDING: SqlState = SqlState(Inner::E01008); + + /// 01003 + pub const WARNING_NULL_VALUE_ELIMINATED_IN_SET_FUNCTION: SqlState = SqlState(Inner::E01003); + + /// 01007 + pub const WARNING_PRIVILEGE_NOT_GRANTED: SqlState = SqlState(Inner::E01007); + + /// 01006 + pub const WARNING_PRIVILEGE_NOT_REVOKED: SqlState = SqlState(Inner::E01006); + + /// 01004 + pub const WARNING_STRING_DATA_RIGHT_TRUNCATION: SqlState = SqlState(Inner::E01004); + + /// 01P01 + pub const WARNING_DEPRECATED_FEATURE: SqlState = SqlState(Inner::E01P01); + + /// 02000 + pub const NO_DATA: SqlState = SqlState(Inner::E02000); + + /// 02001 + pub const NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED: SqlState = SqlState(Inner::E02001); + + /// 03000 + pub const SQL_STATEMENT_NOT_YET_COMPLETE: SqlState = SqlState(Inner::E03000); + + /// 08000 + pub const CONNECTION_EXCEPTION: SqlState = SqlState(Inner::E08000); + + /// 08003 + pub const CONNECTION_DOES_NOT_EXIST: SqlState = SqlState(Inner::E08003); + + /// 08006 + pub const CONNECTION_FAILURE: SqlState = SqlState(Inner::E08006); + + /// 08001 + pub const SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION: SqlState = SqlState(Inner::E08001); + + /// 08004 + pub const SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION: SqlState = SqlState(Inner::E08004); + + /// 08007 + pub const TRANSACTION_RESOLUTION_UNKNOWN: SqlState = SqlState(Inner::E08007); + + /// 08P01 + pub const PROTOCOL_VIOLATION: SqlState = SqlState(Inner::E08P01); + + /// 09000 + pub const TRIGGERED_ACTION_EXCEPTION: SqlState = SqlState(Inner::E09000); + + /// 0A000 + pub const FEATURE_NOT_SUPPORTED: SqlState = SqlState(Inner::E0A000); + + /// 0B000 + pub const INVALID_TRANSACTION_INITIATION: SqlState = SqlState(Inner::E0B000); + + /// 0F000 + pub const LOCATOR_EXCEPTION: SqlState = SqlState(Inner::E0F000); + + /// 0F001 + pub const L_E_INVALID_SPECIFICATION: SqlState = SqlState(Inner::E0F001); + + /// 0L000 + pub const INVALID_GRANTOR: SqlState = SqlState(Inner::E0L000); + + /// 0LP01 + pub const INVALID_GRANT_OPERATION: SqlState = SqlState(Inner::E0LP01); + + /// 0P000 + pub const INVALID_ROLE_SPECIFICATION: SqlState = SqlState(Inner::E0P000); + + /// 0Z000 + pub const DIAGNOSTICS_EXCEPTION: SqlState = SqlState(Inner::E0Z000); + + /// 0Z002 + pub const STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER: SqlState = + SqlState(Inner::E0Z002); + + /// 20000 + pub const CASE_NOT_FOUND: SqlState = SqlState(Inner::E20000); + + /// 21000 + pub const CARDINALITY_VIOLATION: SqlState = SqlState(Inner::E21000); + + /// 22000 + pub const DATA_EXCEPTION: SqlState = SqlState(Inner::E22000); + + /// 2202E + pub const ARRAY_ELEMENT_ERROR: SqlState = SqlState(Inner::E2202E); + + /// 2202E + pub const ARRAY_SUBSCRIPT_ERROR: SqlState = SqlState(Inner::E2202E); + + /// 22021 + pub const CHARACTER_NOT_IN_REPERTOIRE: SqlState = SqlState(Inner::E22021); + + /// 22008 + pub const DATETIME_FIELD_OVERFLOW: SqlState = SqlState(Inner::E22008); + + /// 22008 + pub const DATETIME_VALUE_OUT_OF_RANGE: SqlState = SqlState(Inner::E22008); + + /// 22012 + pub const DIVISION_BY_ZERO: SqlState = SqlState(Inner::E22012); + + /// 22005 + pub const ERROR_IN_ASSIGNMENT: SqlState = SqlState(Inner::E22005); + + /// 2200B + pub const ESCAPE_CHARACTER_CONFLICT: SqlState = SqlState(Inner::E2200B); + + /// 22022 + pub const INDICATOR_OVERFLOW: SqlState = SqlState(Inner::E22022); + + /// 22015 + pub const INTERVAL_FIELD_OVERFLOW: SqlState = SqlState(Inner::E22015); + + /// 2201E + pub const INVALID_ARGUMENT_FOR_LOG: SqlState = SqlState(Inner::E2201E); + + /// 22014 + pub const INVALID_ARGUMENT_FOR_NTILE: SqlState = SqlState(Inner::E22014); + + /// 22016 + pub const INVALID_ARGUMENT_FOR_NTH_VALUE: SqlState = SqlState(Inner::E22016); + + /// 2201F + pub const INVALID_ARGUMENT_FOR_POWER_FUNCTION: SqlState = SqlState(Inner::E2201F); + + /// 2201G + pub const INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION: SqlState = SqlState(Inner::E2201G); + + /// 22018 + pub const INVALID_CHARACTER_VALUE_FOR_CAST: SqlState = SqlState(Inner::E22018); + + /// 22007 + pub const INVALID_DATETIME_FORMAT: SqlState = SqlState(Inner::E22007); + + /// 22019 + pub const INVALID_ESCAPE_CHARACTER: SqlState = SqlState(Inner::E22019); + + /// 2200D + pub const INVALID_ESCAPE_OCTET: SqlState = SqlState(Inner::E2200D); + + /// 22025 + pub const INVALID_ESCAPE_SEQUENCE: SqlState = SqlState(Inner::E22025); + + /// 22P06 + pub const NONSTANDARD_USE_OF_ESCAPE_CHARACTER: SqlState = SqlState(Inner::E22P06); + + /// 22010 + pub const INVALID_INDICATOR_PARAMETER_VALUE: SqlState = SqlState(Inner::E22010); + + /// 22023 + pub const INVALID_PARAMETER_VALUE: SqlState = SqlState(Inner::E22023); + + /// 22013 + pub const INVALID_PRECEDING_OR_FOLLOWING_SIZE: SqlState = SqlState(Inner::E22013); + + /// 2201B + pub const INVALID_REGULAR_EXPRESSION: SqlState = SqlState(Inner::E2201B); + + /// 2201W + pub const INVALID_ROW_COUNT_IN_LIMIT_CLAUSE: SqlState = SqlState(Inner::E2201W); + + /// 2201X + pub const INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE: SqlState = SqlState(Inner::E2201X); + + /// 2202H + pub const INVALID_TABLESAMPLE_ARGUMENT: SqlState = SqlState(Inner::E2202H); + + /// 2202G + pub const INVALID_TABLESAMPLE_REPEAT: SqlState = SqlState(Inner::E2202G); + + /// 22009 + pub const INVALID_TIME_ZONE_DISPLACEMENT_VALUE: SqlState = SqlState(Inner::E22009); + + /// 2200C + pub const INVALID_USE_OF_ESCAPE_CHARACTER: SqlState = SqlState(Inner::E2200C); + + /// 2200G + pub const MOST_SPECIFIC_TYPE_MISMATCH: SqlState = SqlState(Inner::E2200G); + + /// 22004 + pub const NULL_VALUE_NOT_ALLOWED: SqlState = SqlState(Inner::E22004); + + /// 22002 + pub const NULL_VALUE_NO_INDICATOR_PARAMETER: SqlState = SqlState(Inner::E22002); + + /// 22003 + pub const NUMERIC_VALUE_OUT_OF_RANGE: SqlState = SqlState(Inner::E22003); + + /// 2200H + pub const SEQUENCE_GENERATOR_LIMIT_EXCEEDED: SqlState = SqlState(Inner::E2200H); + + /// 22026 + pub const STRING_DATA_LENGTH_MISMATCH: SqlState = SqlState(Inner::E22026); + + /// 22001 + pub const STRING_DATA_RIGHT_TRUNCATION: SqlState = SqlState(Inner::E22001); + + /// 22011 + pub const SUBSTRING_ERROR: SqlState = SqlState(Inner::E22011); + + /// 22027 + pub const TRIM_ERROR: SqlState = SqlState(Inner::E22027); + + /// 22024 + pub const UNTERMINATED_C_STRING: SqlState = SqlState(Inner::E22024); + + /// 2200F + pub const ZERO_LENGTH_CHARACTER_STRING: SqlState = SqlState(Inner::E2200F); + + /// 22P01 + pub const FLOATING_POINT_EXCEPTION: SqlState = SqlState(Inner::E22P01); + + /// 22P02 + pub const INVALID_TEXT_REPRESENTATION: SqlState = SqlState(Inner::E22P02); + + /// 22P03 + pub const INVALID_BINARY_REPRESENTATION: SqlState = SqlState(Inner::E22P03); + + /// 22P04 + pub const BAD_COPY_FILE_FORMAT: SqlState = SqlState(Inner::E22P04); + + /// 22P05 + pub const UNTRANSLATABLE_CHARACTER: SqlState = SqlState(Inner::E22P05); + + /// 2200L + pub const NOT_AN_XML_DOCUMENT: SqlState = SqlState(Inner::E2200L); + + /// 2200M + pub const INVALID_XML_DOCUMENT: SqlState = SqlState(Inner::E2200M); + + /// 2200N + pub const INVALID_XML_CONTENT: SqlState = SqlState(Inner::E2200N); + + /// 2200S + pub const INVALID_XML_COMMENT: SqlState = SqlState(Inner::E2200S); + + /// 2200T + pub const INVALID_XML_PROCESSING_INSTRUCTION: SqlState = SqlState(Inner::E2200T); + + /// 22030 + pub const DUPLICATE_JSON_OBJECT_KEY_VALUE: SqlState = SqlState(Inner::E22030); + + /// 22031 + pub const INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION: SqlState = SqlState(Inner::E22031); + + /// 22032 + pub const INVALID_JSON_TEXT: SqlState = SqlState(Inner::E22032); + + /// 22033 + pub const INVALID_SQL_JSON_SUBSCRIPT: SqlState = SqlState(Inner::E22033); + + /// 22034 + pub const MORE_THAN_ONE_SQL_JSON_ITEM: SqlState = SqlState(Inner::E22034); + + /// 22035 + pub const NO_SQL_JSON_ITEM: SqlState = SqlState(Inner::E22035); + + /// 22036 + pub const NON_NUMERIC_SQL_JSON_ITEM: SqlState = SqlState(Inner::E22036); + + /// 22037 + pub const NON_UNIQUE_KEYS_IN_A_JSON_OBJECT: SqlState = SqlState(Inner::E22037); + + /// 22038 + pub const SINGLETON_SQL_JSON_ITEM_REQUIRED: SqlState = SqlState(Inner::E22038); + + /// 22039 + pub const SQL_JSON_ARRAY_NOT_FOUND: SqlState = SqlState(Inner::E22039); + + /// 2203A + pub const SQL_JSON_MEMBER_NOT_FOUND: SqlState = SqlState(Inner::E2203A); + + /// 2203B + pub const SQL_JSON_NUMBER_NOT_FOUND: SqlState = SqlState(Inner::E2203B); + + /// 2203C + pub const SQL_JSON_OBJECT_NOT_FOUND: SqlState = SqlState(Inner::E2203C); + + /// 2203D + pub const TOO_MANY_JSON_ARRAY_ELEMENTS: SqlState = SqlState(Inner::E2203D); + + /// 2203E + pub const TOO_MANY_JSON_OBJECT_MEMBERS: SqlState = SqlState(Inner::E2203E); + + /// 2203F + pub const SQL_JSON_SCALAR_REQUIRED: SqlState = SqlState(Inner::E2203F); + + /// 2203G + pub const SQL_JSON_ITEM_CANNOT_BE_CAST_TO_TARGET_TYPE: SqlState = SqlState(Inner::E2203G); + + /// 23000 + pub const INTEGRITY_CONSTRAINT_VIOLATION: SqlState = SqlState(Inner::E23000); + + /// 23001 + pub const RESTRICT_VIOLATION: SqlState = SqlState(Inner::E23001); + + /// 23502 + pub const NOT_NULL_VIOLATION: SqlState = SqlState(Inner::E23502); + + /// 23503 + pub const FOREIGN_KEY_VIOLATION: SqlState = SqlState(Inner::E23503); + + /// 23505 + pub const UNIQUE_VIOLATION: SqlState = SqlState(Inner::E23505); + + /// 23514 + pub const CHECK_VIOLATION: SqlState = SqlState(Inner::E23514); + + /// 23P01 + pub const EXCLUSION_VIOLATION: SqlState = SqlState(Inner::E23P01); + + /// 24000 + pub const INVALID_CURSOR_STATE: SqlState = SqlState(Inner::E24000); + + /// 25000 + pub const INVALID_TRANSACTION_STATE: SqlState = SqlState(Inner::E25000); + + /// 25001 + pub const ACTIVE_SQL_TRANSACTION: SqlState = SqlState(Inner::E25001); + + /// 25002 + pub const BRANCH_TRANSACTION_ALREADY_ACTIVE: SqlState = SqlState(Inner::E25002); + + /// 25008 + pub const HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL: SqlState = SqlState(Inner::E25008); + + /// 25003 + pub const INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION: SqlState = SqlState(Inner::E25003); + + /// 25004 + pub const INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION: SqlState = + SqlState(Inner::E25004); + + /// 25005 + pub const NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION: SqlState = SqlState(Inner::E25005); + + /// 25006 + pub const READ_ONLY_SQL_TRANSACTION: SqlState = SqlState(Inner::E25006); + + /// 25007 + pub const SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED: SqlState = SqlState(Inner::E25007); + + /// 25P01 + pub const NO_ACTIVE_SQL_TRANSACTION: SqlState = SqlState(Inner::E25P01); + + /// 25P02 + pub const IN_FAILED_SQL_TRANSACTION: SqlState = SqlState(Inner::E25P02); + + /// 25P03 + pub const IDLE_IN_TRANSACTION_SESSION_TIMEOUT: SqlState = SqlState(Inner::E25P03); + + /// 26000 + pub const INVALID_SQL_STATEMENT_NAME: SqlState = SqlState(Inner::E26000); + + /// 26000 + pub const UNDEFINED_PSTATEMENT: SqlState = SqlState(Inner::E26000); + + /// 27000 + pub const TRIGGERED_DATA_CHANGE_VIOLATION: SqlState = SqlState(Inner::E27000); + + /// 28000 + pub const INVALID_AUTHORIZATION_SPECIFICATION: SqlState = SqlState(Inner::E28000); + + /// 28P01 + pub const INVALID_PASSWORD: SqlState = SqlState(Inner::E28P01); + + /// 2B000 + pub const DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST: SqlState = SqlState(Inner::E2B000); + + /// 2BP01 + pub const DEPENDENT_OBJECTS_STILL_EXIST: SqlState = SqlState(Inner::E2BP01); + + /// 2D000 + pub const INVALID_TRANSACTION_TERMINATION: SqlState = SqlState(Inner::E2D000); + + /// 2F000 + pub const SQL_ROUTINE_EXCEPTION: SqlState = SqlState(Inner::E2F000); + + /// 2F005 + pub const S_R_E_FUNCTION_EXECUTED_NO_RETURN_STATEMENT: SqlState = SqlState(Inner::E2F005); + + /// 2F002 + pub const S_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E2F002); + + /// 2F003 + pub const S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED: SqlState = SqlState(Inner::E2F003); + + /// 2F004 + pub const S_R_E_READING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E2F004); + + /// 34000 + pub const INVALID_CURSOR_NAME: SqlState = SqlState(Inner::E34000); + + /// 34000 + pub const UNDEFINED_CURSOR: SqlState = SqlState(Inner::E34000); + + /// 38000 + pub const EXTERNAL_ROUTINE_EXCEPTION: SqlState = SqlState(Inner::E38000); + + /// 38001 + pub const E_R_E_CONTAINING_SQL_NOT_PERMITTED: SqlState = SqlState(Inner::E38001); + + /// 38002 + pub const E_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E38002); + + /// 38003 + pub const E_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED: SqlState = SqlState(Inner::E38003); + + /// 38004 + pub const E_R_E_READING_SQL_DATA_NOT_PERMITTED: SqlState = SqlState(Inner::E38004); + + /// 39000 + pub const EXTERNAL_ROUTINE_INVOCATION_EXCEPTION: SqlState = SqlState(Inner::E39000); + + /// 39001 + pub const E_R_I_E_INVALID_SQLSTATE_RETURNED: SqlState = SqlState(Inner::E39001); + + /// 39004 + pub const E_R_I_E_NULL_VALUE_NOT_ALLOWED: SqlState = SqlState(Inner::E39004); + + /// 39P01 + pub const E_R_I_E_TRIGGER_PROTOCOL_VIOLATED: SqlState = SqlState(Inner::E39P01); + + /// 39P02 + pub const E_R_I_E_SRF_PROTOCOL_VIOLATED: SqlState = SqlState(Inner::E39P02); + + /// 39P03 + pub const E_R_I_E_EVENT_TRIGGER_PROTOCOL_VIOLATED: SqlState = SqlState(Inner::E39P03); + + /// 3B000 + pub const SAVEPOINT_EXCEPTION: SqlState = SqlState(Inner::E3B000); + + /// 3B001 + pub const S_E_INVALID_SPECIFICATION: SqlState = SqlState(Inner::E3B001); + + /// 3D000 + pub const INVALID_CATALOG_NAME: SqlState = SqlState(Inner::E3D000); + + /// 3D000 + pub const UNDEFINED_DATABASE: SqlState = SqlState(Inner::E3D000); + + /// 3F000 + pub const INVALID_SCHEMA_NAME: SqlState = SqlState(Inner::E3F000); + + /// 3F000 + pub const UNDEFINED_SCHEMA: SqlState = SqlState(Inner::E3F000); + + /// 40000 + pub const TRANSACTION_ROLLBACK: SqlState = SqlState(Inner::E40000); + + /// 40002 + pub const T_R_INTEGRITY_CONSTRAINT_VIOLATION: SqlState = SqlState(Inner::E40002); + + /// 40001 + pub const T_R_SERIALIZATION_FAILURE: SqlState = SqlState(Inner::E40001); + + /// 40003 + pub const T_R_STATEMENT_COMPLETION_UNKNOWN: SqlState = SqlState(Inner::E40003); + + /// 40P01 + pub const T_R_DEADLOCK_DETECTED: SqlState = SqlState(Inner::E40P01); + + /// 42000 + pub const SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION: SqlState = SqlState(Inner::E42000); + + /// 42601 + pub const SYNTAX_ERROR: SqlState = SqlState(Inner::E42601); + + /// 42501 + pub const INSUFFICIENT_PRIVILEGE: SqlState = SqlState(Inner::E42501); + + /// 42846 + pub const CANNOT_COERCE: SqlState = SqlState(Inner::E42846); + + /// 42803 + pub const GROUPING_ERROR: SqlState = SqlState(Inner::E42803); + + /// 42P20 + pub const WINDOWING_ERROR: SqlState = SqlState(Inner::E42P20); + + /// 42P19 + pub const INVALID_RECURSION: SqlState = SqlState(Inner::E42P19); + + /// 42830 + pub const INVALID_FOREIGN_KEY: SqlState = SqlState(Inner::E42830); + + /// 42602 + pub const INVALID_NAME: SqlState = SqlState(Inner::E42602); + + /// 42622 + pub const NAME_TOO_LONG: SqlState = SqlState(Inner::E42622); + + /// 42939 + pub const RESERVED_NAME: SqlState = SqlState(Inner::E42939); + + /// 42804 + pub const DATATYPE_MISMATCH: SqlState = SqlState(Inner::E42804); + + /// 42P18 + pub const INDETERMINATE_DATATYPE: SqlState = SqlState(Inner::E42P18); + + /// 42P21 + pub const COLLATION_MISMATCH: SqlState = SqlState(Inner::E42P21); + + /// 42P22 + pub const INDETERMINATE_COLLATION: SqlState = SqlState(Inner::E42P22); + + /// 42809 + pub const WRONG_OBJECT_TYPE: SqlState = SqlState(Inner::E42809); + + /// 428C9 + pub const GENERATED_ALWAYS: SqlState = SqlState(Inner::E428C9); + + /// 42703 + pub const UNDEFINED_COLUMN: SqlState = SqlState(Inner::E42703); + + /// 42883 + pub const UNDEFINED_FUNCTION: SqlState = SqlState(Inner::E42883); + + /// 42P01 + pub const UNDEFINED_TABLE: SqlState = SqlState(Inner::E42P01); + + /// 42P02 + pub const UNDEFINED_PARAMETER: SqlState = SqlState(Inner::E42P02); + + /// 42704 + pub const UNDEFINED_OBJECT: SqlState = SqlState(Inner::E42704); + + /// 42701 + pub const DUPLICATE_COLUMN: SqlState = SqlState(Inner::E42701); + + /// 42P03 + pub const DUPLICATE_CURSOR: SqlState = SqlState(Inner::E42P03); + + /// 42P04 + pub const DUPLICATE_DATABASE: SqlState = SqlState(Inner::E42P04); + + /// 42723 + pub const DUPLICATE_FUNCTION: SqlState = SqlState(Inner::E42723); + + /// 42P05 + pub const DUPLICATE_PSTATEMENT: SqlState = SqlState(Inner::E42P05); + + /// 42P06 + pub const DUPLICATE_SCHEMA: SqlState = SqlState(Inner::E42P06); + + /// 42P07 + pub const DUPLICATE_TABLE: SqlState = SqlState(Inner::E42P07); + + /// 42712 + pub const DUPLICATE_ALIAS: SqlState = SqlState(Inner::E42712); + + /// 42710 + pub const DUPLICATE_OBJECT: SqlState = SqlState(Inner::E42710); + + /// 42702 + pub const AMBIGUOUS_COLUMN: SqlState = SqlState(Inner::E42702); + + /// 42725 + pub const AMBIGUOUS_FUNCTION: SqlState = SqlState(Inner::E42725); + + /// 42P08 + pub const AMBIGUOUS_PARAMETER: SqlState = SqlState(Inner::E42P08); + + /// 42P09 + pub const AMBIGUOUS_ALIAS: SqlState = SqlState(Inner::E42P09); + + /// 42P10 + pub const INVALID_COLUMN_REFERENCE: SqlState = SqlState(Inner::E42P10); + + /// 42611 + pub const INVALID_COLUMN_DEFINITION: SqlState = SqlState(Inner::E42611); + + /// 42P11 + pub const INVALID_CURSOR_DEFINITION: SqlState = SqlState(Inner::E42P11); + + /// 42P12 + pub const INVALID_DATABASE_DEFINITION: SqlState = SqlState(Inner::E42P12); + + /// 42P13 + pub const INVALID_FUNCTION_DEFINITION: SqlState = SqlState(Inner::E42P13); + + /// 42P14 + pub const INVALID_PSTATEMENT_DEFINITION: SqlState = SqlState(Inner::E42P14); + + /// 42P15 + pub const INVALID_SCHEMA_DEFINITION: SqlState = SqlState(Inner::E42P15); + + /// 42P16 + pub const INVALID_TABLE_DEFINITION: SqlState = SqlState(Inner::E42P16); + + /// 42P17 + pub const INVALID_OBJECT_DEFINITION: SqlState = SqlState(Inner::E42P17); + + /// 44000 + pub const WITH_CHECK_OPTION_VIOLATION: SqlState = SqlState(Inner::E44000); + + /// 53000 + pub const INSUFFICIENT_RESOURCES: SqlState = SqlState(Inner::E53000); + + /// 53100 + pub const DISK_FULL: SqlState = SqlState(Inner::E53100); + + /// 53200 + pub const OUT_OF_MEMORY: SqlState = SqlState(Inner::E53200); + + /// 53300 + pub const TOO_MANY_CONNECTIONS: SqlState = SqlState(Inner::E53300); + + /// 53400 + pub const CONFIGURATION_LIMIT_EXCEEDED: SqlState = SqlState(Inner::E53400); + + /// 54000 + pub const PROGRAM_LIMIT_EXCEEDED: SqlState = SqlState(Inner::E54000); + + /// 54001 + pub const STATEMENT_TOO_COMPLEX: SqlState = SqlState(Inner::E54001); + + /// 54011 + pub const TOO_MANY_COLUMNS: SqlState = SqlState(Inner::E54011); + + /// 54023 + pub const TOO_MANY_ARGUMENTS: SqlState = SqlState(Inner::E54023); + + /// 55000 + pub const OBJECT_NOT_IN_PREREQUISITE_STATE: SqlState = SqlState(Inner::E55000); + + /// 55006 + pub const OBJECT_IN_USE: SqlState = SqlState(Inner::E55006); + + /// 55P02 + pub const CANT_CHANGE_RUNTIME_PARAM: SqlState = SqlState(Inner::E55P02); + + /// 55P03 + pub const LOCK_NOT_AVAILABLE: SqlState = SqlState(Inner::E55P03); + + /// 55P04 + pub const UNSAFE_NEW_ENUM_VALUE_USAGE: SqlState = SqlState(Inner::E55P04); + + /// 57000 + pub const OPERATOR_INTERVENTION: SqlState = SqlState(Inner::E57000); + + /// 57014 + pub const QUERY_CANCELED: SqlState = SqlState(Inner::E57014); + + /// 57P01 + pub const ADMIN_SHUTDOWN: SqlState = SqlState(Inner::E57P01); + + /// 57P02 + pub const CRASH_SHUTDOWN: SqlState = SqlState(Inner::E57P02); + + /// 57P03 + pub const CANNOT_CONNECT_NOW: SqlState = SqlState(Inner::E57P03); + + /// 57P04 + pub const DATABASE_DROPPED: SqlState = SqlState(Inner::E57P04); + + /// 57P05 + pub const IDLE_SESSION_TIMEOUT: SqlState = SqlState(Inner::E57P05); + + /// 58000 + pub const SYSTEM_ERROR: SqlState = SqlState(Inner::E58000); + + /// 58030 + pub const IO_ERROR: SqlState = SqlState(Inner::E58030); + + /// 58P01 + pub const UNDEFINED_FILE: SqlState = SqlState(Inner::E58P01); + + /// 58P02 + pub const DUPLICATE_FILE: SqlState = SqlState(Inner::E58P02); + + /// 72000 + pub const SNAPSHOT_TOO_OLD: SqlState = SqlState(Inner::E72000); + + /// F0000 + pub const CONFIG_FILE_ERROR: SqlState = SqlState(Inner::EF0000); + + /// F0001 + pub const LOCK_FILE_EXISTS: SqlState = SqlState(Inner::EF0001); + + /// HV000 + pub const FDW_ERROR: SqlState = SqlState(Inner::EHV000); + + /// HV005 + pub const FDW_COLUMN_NAME_NOT_FOUND: SqlState = SqlState(Inner::EHV005); + + /// HV002 + pub const FDW_DYNAMIC_PARAMETER_VALUE_NEEDED: SqlState = SqlState(Inner::EHV002); + + /// HV010 + pub const FDW_FUNCTION_SEQUENCE_ERROR: SqlState = SqlState(Inner::EHV010); + + /// HV021 + pub const FDW_INCONSISTENT_DESCRIPTOR_INFORMATION: SqlState = SqlState(Inner::EHV021); + + /// HV024 + pub const FDW_INVALID_ATTRIBUTE_VALUE: SqlState = SqlState(Inner::EHV024); + + /// HV007 + pub const FDW_INVALID_COLUMN_NAME: SqlState = SqlState(Inner::EHV007); + + /// HV008 + pub const FDW_INVALID_COLUMN_NUMBER: SqlState = SqlState(Inner::EHV008); + + /// HV004 + pub const FDW_INVALID_DATA_TYPE: SqlState = SqlState(Inner::EHV004); + + /// HV006 + pub const FDW_INVALID_DATA_TYPE_DESCRIPTORS: SqlState = SqlState(Inner::EHV006); + + /// HV091 + pub const FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER: SqlState = SqlState(Inner::EHV091); + + /// HV00B + pub const FDW_INVALID_HANDLE: SqlState = SqlState(Inner::EHV00B); + + /// HV00C + pub const FDW_INVALID_OPTION_INDEX: SqlState = SqlState(Inner::EHV00C); + + /// HV00D + pub const FDW_INVALID_OPTION_NAME: SqlState = SqlState(Inner::EHV00D); + + /// HV090 + pub const FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH: SqlState = SqlState(Inner::EHV090); + + /// HV00A + pub const FDW_INVALID_STRING_FORMAT: SqlState = SqlState(Inner::EHV00A); + + /// HV009 + pub const FDW_INVALID_USE_OF_NULL_POINTER: SqlState = SqlState(Inner::EHV009); + + /// HV014 + pub const FDW_TOO_MANY_HANDLES: SqlState = SqlState(Inner::EHV014); + + /// HV001 + pub const FDW_OUT_OF_MEMORY: SqlState = SqlState(Inner::EHV001); + + /// HV00P + pub const FDW_NO_SCHEMAS: SqlState = SqlState(Inner::EHV00P); + + /// HV00J + pub const FDW_OPTION_NAME_NOT_FOUND: SqlState = SqlState(Inner::EHV00J); + + /// HV00K + pub const FDW_REPLY_HANDLE: SqlState = SqlState(Inner::EHV00K); + + /// HV00Q + pub const FDW_SCHEMA_NOT_FOUND: SqlState = SqlState(Inner::EHV00Q); + + /// HV00R + pub const FDW_TABLE_NOT_FOUND: SqlState = SqlState(Inner::EHV00R); + + /// HV00L + pub const FDW_UNABLE_TO_CREATE_EXECUTION: SqlState = SqlState(Inner::EHV00L); + + /// HV00M + pub const FDW_UNABLE_TO_CREATE_REPLY: SqlState = SqlState(Inner::EHV00M); + + /// HV00N + pub const FDW_UNABLE_TO_ESTABLISH_CONNECTION: SqlState = SqlState(Inner::EHV00N); + + /// P0000 + pub const PLPGSQL_ERROR: SqlState = SqlState(Inner::EP0000); + + /// P0001 + pub const RAISE_EXCEPTION: SqlState = SqlState(Inner::EP0001); + + /// P0002 + pub const NO_DATA_FOUND: SqlState = SqlState(Inner::EP0002); + + /// P0003 + pub const TOO_MANY_ROWS: SqlState = SqlState(Inner::EP0003); + + /// P0004 + pub const ASSERT_FAILURE: SqlState = SqlState(Inner::EP0004); + + /// XX000 + pub const INTERNAL_ERROR: SqlState = SqlState(Inner::EXX000); + + /// XX001 + pub const DATA_CORRUPTED: SqlState = SqlState(Inner::EXX001); + + /// XX002 + pub const INDEX_CORRUPTED: SqlState = SqlState(Inner::EXX002); +} + +#[derive(PartialEq, Eq, Clone, Debug)] +#[allow(clippy::upper_case_acronyms)] +enum Inner { + E00000, + E01000, + E0100C, + E01008, + E01003, + E01007, + E01006, + E01004, + E01P01, + E02000, + E02001, + E03000, + E08000, + E08003, + E08006, + E08001, + E08004, + E08007, + E08P01, + E09000, + E0A000, + E0B000, + E0F000, + E0F001, + E0L000, + E0LP01, + E0P000, + E0Z000, + E0Z002, + E20000, + E21000, + E22000, + E2202E, + E22021, + E22008, + E22012, + E22005, + E2200B, + E22022, + E22015, + E2201E, + E22014, + E22016, + E2201F, + E2201G, + E22018, + E22007, + E22019, + E2200D, + E22025, + E22P06, + E22010, + E22023, + E22013, + E2201B, + E2201W, + E2201X, + E2202H, + E2202G, + E22009, + E2200C, + E2200G, + E22004, + E22002, + E22003, + E2200H, + E22026, + E22001, + E22011, + E22027, + E22024, + E2200F, + E22P01, + E22P02, + E22P03, + E22P04, + E22P05, + E2200L, + E2200M, + E2200N, + E2200S, + E2200T, + E22030, + E22031, + E22032, + E22033, + E22034, + E22035, + E22036, + E22037, + E22038, + E22039, + E2203A, + E2203B, + E2203C, + E2203D, + E2203E, + E2203F, + E2203G, + E23000, + E23001, + E23502, + E23503, + E23505, + E23514, + E23P01, + E24000, + E25000, + E25001, + E25002, + E25008, + E25003, + E25004, + E25005, + E25006, + E25007, + E25P01, + E25P02, + E25P03, + E26000, + E27000, + E28000, + E28P01, + E2B000, + E2BP01, + E2D000, + E2F000, + E2F005, + E2F002, + E2F003, + E2F004, + E34000, + E38000, + E38001, + E38002, + E38003, + E38004, + E39000, + E39001, + E39004, + E39P01, + E39P02, + E39P03, + E3B000, + E3B001, + E3D000, + E3F000, + E40000, + E40002, + E40001, + E40003, + E40P01, + E42000, + E42601, + E42501, + E42846, + E42803, + E42P20, + E42P19, + E42830, + E42602, + E42622, + E42939, + E42804, + E42P18, + E42P21, + E42P22, + E42809, + E428C9, + E42703, + E42883, + E42P01, + E42P02, + E42704, + E42701, + E42P03, + E42P04, + E42723, + E42P05, + E42P06, + E42P07, + E42712, + E42710, + E42702, + E42725, + E42P08, + E42P09, + E42P10, + E42611, + E42P11, + E42P12, + E42P13, + E42P14, + E42P15, + E42P16, + E42P17, + E44000, + E53000, + E53100, + E53200, + E53300, + E53400, + E54000, + E54001, + E54011, + E54023, + E55000, + E55006, + E55P02, + E55P03, + E55P04, + E57000, + E57014, + E57P01, + E57P02, + E57P03, + E57P04, + E57P05, + E58000, + E58030, + E58P01, + E58P02, + E72000, + EF0000, + EF0001, + EHV000, + EHV005, + EHV002, + EHV010, + EHV021, + EHV024, + EHV007, + EHV008, + EHV004, + EHV006, + EHV091, + EHV00B, + EHV00C, + EHV00D, + EHV090, + EHV00A, + EHV009, + EHV014, + EHV001, + EHV00P, + EHV00J, + EHV00K, + EHV00Q, + EHV00R, + EHV00L, + EHV00M, + EHV00N, + EP0000, + EP0001, + EP0002, + EP0003, + EP0004, + EXX000, + EXX001, + EXX002, + Other(Box), +} + +#[rustfmt::skip] +static SQLSTATE_MAP: phf::Map<&'static str, SqlState> = +::phf::Map { + key: 12913932095322966823, + disps: &[ + (0, 24), + (0, 12), + (0, 74), + (0, 109), + (0, 11), + (0, 9), + (0, 0), + (4, 38), + (3, 155), + (0, 6), + (1, 242), + (0, 66), + (0, 53), + (5, 180), + (3, 221), + (7, 230), + (0, 125), + (1, 46), + (0, 11), + (1, 2), + (0, 5), + (0, 13), + (0, 171), + (0, 15), + (0, 4), + (0, 22), + (1, 85), + (0, 75), + (2, 0), + (1, 25), + (7, 47), + (0, 45), + (0, 35), + (0, 7), + (7, 124), + (0, 0), + (14, 104), + (1, 183), + (61, 50), + (3, 76), + (0, 12), + (0, 7), + (4, 189), + (0, 1), + (64, 102), + (0, 0), + (16, 192), + (24, 19), + (0, 5), + (0, 87), + (0, 89), + (0, 14), + ], + entries: &[ + ("2F000", SqlState::SQL_ROUTINE_EXCEPTION), + ("01008", SqlState::WARNING_IMPLICIT_ZERO_BIT_PADDING), + ("42501", SqlState::INSUFFICIENT_PRIVILEGE), + ("22000", SqlState::DATA_EXCEPTION), + ("0100C", SqlState::WARNING_DYNAMIC_RESULT_SETS_RETURNED), + ("2200N", SqlState::INVALID_XML_CONTENT), + ("40001", SqlState::T_R_SERIALIZATION_FAILURE), + ("28P01", SqlState::INVALID_PASSWORD), + ("38000", SqlState::EXTERNAL_ROUTINE_EXCEPTION), + ("25006", SqlState::READ_ONLY_SQL_TRANSACTION), + ("2203D", SqlState::TOO_MANY_JSON_ARRAY_ELEMENTS), + ("42P09", SqlState::AMBIGUOUS_ALIAS), + ("F0000", SqlState::CONFIG_FILE_ERROR), + ("42P18", SqlState::INDETERMINATE_DATATYPE), + ("40002", SqlState::T_R_INTEGRITY_CONSTRAINT_VIOLATION), + ("22009", SqlState::INVALID_TIME_ZONE_DISPLACEMENT_VALUE), + ("42P08", SqlState::AMBIGUOUS_PARAMETER), + ("08000", SqlState::CONNECTION_EXCEPTION), + ("25P01", SqlState::NO_ACTIVE_SQL_TRANSACTION), + ("22024", SqlState::UNTERMINATED_C_STRING), + ("55000", SqlState::OBJECT_NOT_IN_PREREQUISITE_STATE), + ("25001", SqlState::ACTIVE_SQL_TRANSACTION), + ("03000", SqlState::SQL_STATEMENT_NOT_YET_COMPLETE), + ("42710", SqlState::DUPLICATE_OBJECT), + ("2D000", SqlState::INVALID_TRANSACTION_TERMINATION), + ("2200G", SqlState::MOST_SPECIFIC_TYPE_MISMATCH), + ("22022", SqlState::INDICATOR_OVERFLOW), + ("55006", SqlState::OBJECT_IN_USE), + ("53200", SqlState::OUT_OF_MEMORY), + ("22012", SqlState::DIVISION_BY_ZERO), + ("P0002", SqlState::NO_DATA_FOUND), + ("XX001", SqlState::DATA_CORRUPTED), + ("22P05", SqlState::UNTRANSLATABLE_CHARACTER), + ("40003", SqlState::T_R_STATEMENT_COMPLETION_UNKNOWN), + ("22021", SqlState::CHARACTER_NOT_IN_REPERTOIRE), + ("25000", SqlState::INVALID_TRANSACTION_STATE), + ("42P15", SqlState::INVALID_SCHEMA_DEFINITION), + ("0B000", SqlState::INVALID_TRANSACTION_INITIATION), + ("22004", SqlState::NULL_VALUE_NOT_ALLOWED), + ("42804", SqlState::DATATYPE_MISMATCH), + ("42803", SqlState::GROUPING_ERROR), + ("02001", SqlState::NO_ADDITIONAL_DYNAMIC_RESULT_SETS_RETURNED), + ("25002", SqlState::BRANCH_TRANSACTION_ALREADY_ACTIVE), + ("28000", SqlState::INVALID_AUTHORIZATION_SPECIFICATION), + ("HV009", SqlState::FDW_INVALID_USE_OF_NULL_POINTER), + ("22P01", SqlState::FLOATING_POINT_EXCEPTION), + ("2B000", SqlState::DEPENDENT_PRIVILEGE_DESCRIPTORS_STILL_EXIST), + ("42723", SqlState::DUPLICATE_FUNCTION), + ("21000", SqlState::CARDINALITY_VIOLATION), + ("0Z002", SqlState::STACKED_DIAGNOSTICS_ACCESSED_WITHOUT_ACTIVE_HANDLER), + ("23505", SqlState::UNIQUE_VIOLATION), + ("HV00J", SqlState::FDW_OPTION_NAME_NOT_FOUND), + ("23P01", SqlState::EXCLUSION_VIOLATION), + ("39P03", SqlState::E_R_I_E_EVENT_TRIGGER_PROTOCOL_VIOLATED), + ("42P10", SqlState::INVALID_COLUMN_REFERENCE), + ("2202H", SqlState::INVALID_TABLESAMPLE_ARGUMENT), + ("55P04", SqlState::UNSAFE_NEW_ENUM_VALUE_USAGE), + ("P0000", SqlState::PLPGSQL_ERROR), + ("2F005", SqlState::S_R_E_FUNCTION_EXECUTED_NO_RETURN_STATEMENT), + ("HV00M", SqlState::FDW_UNABLE_TO_CREATE_REPLY), + ("0A000", SqlState::FEATURE_NOT_SUPPORTED), + ("24000", SqlState::INVALID_CURSOR_STATE), + ("25008", SqlState::HELD_CURSOR_REQUIRES_SAME_ISOLATION_LEVEL), + ("01003", SqlState::WARNING_NULL_VALUE_ELIMINATED_IN_SET_FUNCTION), + ("42712", SqlState::DUPLICATE_ALIAS), + ("HV014", SqlState::FDW_TOO_MANY_HANDLES), + ("58030", SqlState::IO_ERROR), + ("2201W", SqlState::INVALID_ROW_COUNT_IN_LIMIT_CLAUSE), + ("22033", SqlState::INVALID_SQL_JSON_SUBSCRIPT), + ("2BP01", SqlState::DEPENDENT_OBJECTS_STILL_EXIST), + ("HV005", SqlState::FDW_COLUMN_NAME_NOT_FOUND), + ("25004", SqlState::INAPPROPRIATE_ISOLATION_LEVEL_FOR_BRANCH_TRANSACTION), + ("54000", SqlState::PROGRAM_LIMIT_EXCEEDED), + ("20000", SqlState::CASE_NOT_FOUND), + ("2203G", SqlState::SQL_JSON_ITEM_CANNOT_BE_CAST_TO_TARGET_TYPE), + ("22038", SqlState::SINGLETON_SQL_JSON_ITEM_REQUIRED), + ("22007", SqlState::INVALID_DATETIME_FORMAT), + ("08004", SqlState::SQLSERVER_REJECTED_ESTABLISHMENT_OF_SQLCONNECTION), + ("2200H", SqlState::SEQUENCE_GENERATOR_LIMIT_EXCEEDED), + ("HV00D", SqlState::FDW_INVALID_OPTION_NAME), + ("P0004", SqlState::ASSERT_FAILURE), + ("22018", SqlState::INVALID_CHARACTER_VALUE_FOR_CAST), + ("0L000", SqlState::INVALID_GRANTOR), + ("22P04", SqlState::BAD_COPY_FILE_FORMAT), + ("22031", SqlState::INVALID_ARGUMENT_FOR_SQL_JSON_DATETIME_FUNCTION), + ("01P01", SqlState::WARNING_DEPRECATED_FEATURE), + ("0LP01", SqlState::INVALID_GRANT_OPERATION), + ("58P02", SqlState::DUPLICATE_FILE), + ("26000", SqlState::INVALID_SQL_STATEMENT_NAME), + ("54001", SqlState::STATEMENT_TOO_COMPLEX), + ("22010", SqlState::INVALID_INDICATOR_PARAMETER_VALUE), + ("HV00C", SqlState::FDW_INVALID_OPTION_INDEX), + ("22008", SqlState::DATETIME_FIELD_OVERFLOW), + ("42P06", SqlState::DUPLICATE_SCHEMA), + ("25007", SqlState::SCHEMA_AND_DATA_STATEMENT_MIXING_NOT_SUPPORTED), + ("42P20", SqlState::WINDOWING_ERROR), + ("HV091", SqlState::FDW_INVALID_DESCRIPTOR_FIELD_IDENTIFIER), + ("HV021", SqlState::FDW_INCONSISTENT_DESCRIPTOR_INFORMATION), + ("42702", SqlState::AMBIGUOUS_COLUMN), + ("02000", SqlState::NO_DATA), + ("54011", SqlState::TOO_MANY_COLUMNS), + ("HV004", SqlState::FDW_INVALID_DATA_TYPE), + ("01006", SqlState::WARNING_PRIVILEGE_NOT_REVOKED), + ("42701", SqlState::DUPLICATE_COLUMN), + ("08P01", SqlState::PROTOCOL_VIOLATION), + ("42622", SqlState::NAME_TOO_LONG), + ("P0003", SqlState::TOO_MANY_ROWS), + ("22003", SqlState::NUMERIC_VALUE_OUT_OF_RANGE), + ("42P03", SqlState::DUPLICATE_CURSOR), + ("23001", SqlState::RESTRICT_VIOLATION), + ("57000", SqlState::OPERATOR_INTERVENTION), + ("22027", SqlState::TRIM_ERROR), + ("42P12", SqlState::INVALID_DATABASE_DEFINITION), + ("3B000", SqlState::SAVEPOINT_EXCEPTION), + ("2201B", SqlState::INVALID_REGULAR_EXPRESSION), + ("22030", SqlState::DUPLICATE_JSON_OBJECT_KEY_VALUE), + ("2F004", SqlState::S_R_E_READING_SQL_DATA_NOT_PERMITTED), + ("428C9", SqlState::GENERATED_ALWAYS), + ("2200S", SqlState::INVALID_XML_COMMENT), + ("22039", SqlState::SQL_JSON_ARRAY_NOT_FOUND), + ("42809", SqlState::WRONG_OBJECT_TYPE), + ("2201X", SqlState::INVALID_ROW_COUNT_IN_RESULT_OFFSET_CLAUSE), + ("39001", SqlState::E_R_I_E_INVALID_SQLSTATE_RETURNED), + ("25P02", SqlState::IN_FAILED_SQL_TRANSACTION), + ("0P000", SqlState::INVALID_ROLE_SPECIFICATION), + ("HV00N", SqlState::FDW_UNABLE_TO_ESTABLISH_CONNECTION), + ("53100", SqlState::DISK_FULL), + ("42601", SqlState::SYNTAX_ERROR), + ("23000", SqlState::INTEGRITY_CONSTRAINT_VIOLATION), + ("HV006", SqlState::FDW_INVALID_DATA_TYPE_DESCRIPTORS), + ("HV00B", SqlState::FDW_INVALID_HANDLE), + ("HV00Q", SqlState::FDW_SCHEMA_NOT_FOUND), + ("01000", SqlState::WARNING), + ("42883", SqlState::UNDEFINED_FUNCTION), + ("57P01", SqlState::ADMIN_SHUTDOWN), + ("22037", SqlState::NON_UNIQUE_KEYS_IN_A_JSON_OBJECT), + ("00000", SqlState::SUCCESSFUL_COMPLETION), + ("55P03", SqlState::LOCK_NOT_AVAILABLE), + ("42P01", SqlState::UNDEFINED_TABLE), + ("42830", SqlState::INVALID_FOREIGN_KEY), + ("22005", SqlState::ERROR_IN_ASSIGNMENT), + ("22025", SqlState::INVALID_ESCAPE_SEQUENCE), + ("XX002", SqlState::INDEX_CORRUPTED), + ("42P16", SqlState::INVALID_TABLE_DEFINITION), + ("55P02", SqlState::CANT_CHANGE_RUNTIME_PARAM), + ("22019", SqlState::INVALID_ESCAPE_CHARACTER), + ("P0001", SqlState::RAISE_EXCEPTION), + ("72000", SqlState::SNAPSHOT_TOO_OLD), + ("42P11", SqlState::INVALID_CURSOR_DEFINITION), + ("40P01", SqlState::T_R_DEADLOCK_DETECTED), + ("57P02", SqlState::CRASH_SHUTDOWN), + ("HV00A", SqlState::FDW_INVALID_STRING_FORMAT), + ("2F002", SqlState::S_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED), + ("23503", SqlState::FOREIGN_KEY_VIOLATION), + ("40000", SqlState::TRANSACTION_ROLLBACK), + ("22032", SqlState::INVALID_JSON_TEXT), + ("2202E", SqlState::ARRAY_ELEMENT_ERROR), + ("42P19", SqlState::INVALID_RECURSION), + ("42611", SqlState::INVALID_COLUMN_DEFINITION), + ("42P13", SqlState::INVALID_FUNCTION_DEFINITION), + ("25003", SqlState::INAPPROPRIATE_ACCESS_MODE_FOR_BRANCH_TRANSACTION), + ("39P02", SqlState::E_R_I_E_SRF_PROTOCOL_VIOLATED), + ("XX000", SqlState::INTERNAL_ERROR), + ("08006", SqlState::CONNECTION_FAILURE), + ("57P04", SqlState::DATABASE_DROPPED), + ("42P07", SqlState::DUPLICATE_TABLE), + ("22P03", SqlState::INVALID_BINARY_REPRESENTATION), + ("22035", SqlState::NO_SQL_JSON_ITEM), + ("42P14", SqlState::INVALID_PSTATEMENT_DEFINITION), + ("01007", SqlState::WARNING_PRIVILEGE_NOT_GRANTED), + ("38004", SqlState::E_R_E_READING_SQL_DATA_NOT_PERMITTED), + ("42P21", SqlState::COLLATION_MISMATCH), + ("0Z000", SqlState::DIAGNOSTICS_EXCEPTION), + ("HV001", SqlState::FDW_OUT_OF_MEMORY), + ("0F000", SqlState::LOCATOR_EXCEPTION), + ("22013", SqlState::INVALID_PRECEDING_OR_FOLLOWING_SIZE), + ("2201E", SqlState::INVALID_ARGUMENT_FOR_LOG), + ("22011", SqlState::SUBSTRING_ERROR), + ("42602", SqlState::INVALID_NAME), + ("01004", SqlState::WARNING_STRING_DATA_RIGHT_TRUNCATION), + ("42P02", SqlState::UNDEFINED_PARAMETER), + ("2203C", SqlState::SQL_JSON_OBJECT_NOT_FOUND), + ("HV002", SqlState::FDW_DYNAMIC_PARAMETER_VALUE_NEEDED), + ("0F001", SqlState::L_E_INVALID_SPECIFICATION), + ("58P01", SqlState::UNDEFINED_FILE), + ("38001", SqlState::E_R_E_CONTAINING_SQL_NOT_PERMITTED), + ("42703", SqlState::UNDEFINED_COLUMN), + ("57P05", SqlState::IDLE_SESSION_TIMEOUT), + ("57P03", SqlState::CANNOT_CONNECT_NOW), + ("HV007", SqlState::FDW_INVALID_COLUMN_NAME), + ("22014", SqlState::INVALID_ARGUMENT_FOR_NTILE), + ("22P06", SqlState::NONSTANDARD_USE_OF_ESCAPE_CHARACTER), + ("2203F", SqlState::SQL_JSON_SCALAR_REQUIRED), + ("2200F", SqlState::ZERO_LENGTH_CHARACTER_STRING), + ("09000", SqlState::TRIGGERED_ACTION_EXCEPTION), + ("2201F", SqlState::INVALID_ARGUMENT_FOR_POWER_FUNCTION), + ("08003", SqlState::CONNECTION_DOES_NOT_EXIST), + ("38002", SqlState::E_R_E_MODIFYING_SQL_DATA_NOT_PERMITTED), + ("F0001", SqlState::LOCK_FILE_EXISTS), + ("42P22", SqlState::INDETERMINATE_COLLATION), + ("2200C", SqlState::INVALID_USE_OF_ESCAPE_CHARACTER), + ("2203E", SqlState::TOO_MANY_JSON_OBJECT_MEMBERS), + ("23514", SqlState::CHECK_VIOLATION), + ("22P02", SqlState::INVALID_TEXT_REPRESENTATION), + ("54023", SqlState::TOO_MANY_ARGUMENTS), + ("2200T", SqlState::INVALID_XML_PROCESSING_INSTRUCTION), + ("22016", SqlState::INVALID_ARGUMENT_FOR_NTH_VALUE), + ("25P03", SqlState::IDLE_IN_TRANSACTION_SESSION_TIMEOUT), + ("3B001", SqlState::S_E_INVALID_SPECIFICATION), + ("08001", SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION), + ("22036", SqlState::NON_NUMERIC_SQL_JSON_ITEM), + ("3F000", SqlState::INVALID_SCHEMA_NAME), + ("39P01", SqlState::E_R_I_E_TRIGGER_PROTOCOL_VIOLATED), + ("22026", SqlState::STRING_DATA_LENGTH_MISMATCH), + ("42P17", SqlState::INVALID_OBJECT_DEFINITION), + ("22034", SqlState::MORE_THAN_ONE_SQL_JSON_ITEM), + ("HV000", SqlState::FDW_ERROR), + ("2200B", SqlState::ESCAPE_CHARACTER_CONFLICT), + ("HV008", SqlState::FDW_INVALID_COLUMN_NUMBER), + ("34000", SqlState::INVALID_CURSOR_NAME), + ("2201G", SqlState::INVALID_ARGUMENT_FOR_WIDTH_BUCKET_FUNCTION), + ("44000", SqlState::WITH_CHECK_OPTION_VIOLATION), + ("HV010", SqlState::FDW_FUNCTION_SEQUENCE_ERROR), + ("39004", SqlState::E_R_I_E_NULL_VALUE_NOT_ALLOWED), + ("22001", SqlState::STRING_DATA_RIGHT_TRUNCATION), + ("3D000", SqlState::INVALID_CATALOG_NAME), + ("25005", SqlState::NO_ACTIVE_SQL_TRANSACTION_FOR_BRANCH_TRANSACTION), + ("2200L", SqlState::NOT_AN_XML_DOCUMENT), + ("27000", SqlState::TRIGGERED_DATA_CHANGE_VIOLATION), + ("HV090", SqlState::FDW_INVALID_STRING_LENGTH_OR_BUFFER_LENGTH), + ("42939", SqlState::RESERVED_NAME), + ("58000", SqlState::SYSTEM_ERROR), + ("2200M", SqlState::INVALID_XML_DOCUMENT), + ("HV00L", SqlState::FDW_UNABLE_TO_CREATE_EXECUTION), + ("57014", SqlState::QUERY_CANCELED), + ("23502", SqlState::NOT_NULL_VIOLATION), + ("22002", SqlState::NULL_VALUE_NO_INDICATOR_PARAMETER), + ("HV00R", SqlState::FDW_TABLE_NOT_FOUND), + ("HV00P", SqlState::FDW_NO_SCHEMAS), + ("38003", SqlState::E_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED), + ("39000", SqlState::EXTERNAL_ROUTINE_INVOCATION_EXCEPTION), + ("22015", SqlState::INTERVAL_FIELD_OVERFLOW), + ("HV00K", SqlState::FDW_REPLY_HANDLE), + ("HV024", SqlState::FDW_INVALID_ATTRIBUTE_VALUE), + ("2200D", SqlState::INVALID_ESCAPE_OCTET), + ("08007", SqlState::TRANSACTION_RESOLUTION_UNKNOWN), + ("2F003", SqlState::S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED), + ("42725", SqlState::AMBIGUOUS_FUNCTION), + ("2203A", SqlState::SQL_JSON_MEMBER_NOT_FOUND), + ("42846", SqlState::CANNOT_COERCE), + ("42P04", SqlState::DUPLICATE_DATABASE), + ("42000", SqlState::SYNTAX_ERROR_OR_ACCESS_RULE_VIOLATION), + ("2203B", SqlState::SQL_JSON_NUMBER_NOT_FOUND), + ("42P05", SqlState::DUPLICATE_PSTATEMENT), + ("53300", SqlState::TOO_MANY_CONNECTIONS), + ("53400", SqlState::CONFIGURATION_LIMIT_EXCEEDED), + ("42704", SqlState::UNDEFINED_OBJECT), + ("2202G", SqlState::INVALID_TABLESAMPLE_REPEAT), + ("22023", SqlState::INVALID_PARAMETER_VALUE), + ("53000", SqlState::INSUFFICIENT_RESOURCES), + ], +}; diff --git a/libs/proxy/tokio-postgres2/src/generic_client.rs b/libs/proxy/tokio-postgres2/src/generic_client.rs new file mode 100644 index 0000000000..768213f8ed --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/generic_client.rs @@ -0,0 +1,64 @@ +use crate::query::RowStream; +use crate::types::Type; +use crate::{Client, Error, Transaction}; +use async_trait::async_trait; +use postgres_protocol2::Oid; + +mod private { + pub trait Sealed {} +} + +/// A trait allowing abstraction over connections and transactions. +/// +/// This trait is "sealed", and cannot be implemented outside of this crate. +#[async_trait] +pub trait GenericClient: private::Sealed { + /// Like `Client::query_raw_txt`. + async fn query_raw_txt(&self, statement: &str, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send; + + /// Query for type information + async fn get_type(&self, oid: Oid) -> Result; +} + +impl private::Sealed for Client {} + +#[async_trait] +impl GenericClient for Client { + async fn query_raw_txt(&self, statement: &str, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send, + { + self.query_raw_txt(statement, params).await + } + + /// Query for type information + async fn get_type(&self, oid: Oid) -> Result { + self.get_type(oid).await + } +} + +impl private::Sealed for Transaction<'_> {} + +#[async_trait] +#[allow(clippy::needless_lifetimes)] +impl GenericClient for Transaction<'_> { + async fn query_raw_txt(&self, statement: &str, params: I) -> Result + where + S: AsRef + Sync + Send, + I: IntoIterator> + Sync + Send, + I::IntoIter: ExactSizeIterator + Sync + Send, + { + self.query_raw_txt(statement, params).await + } + + /// Query for type information + async fn get_type(&self, oid: Oid) -> Result { + self.client().get_type(oid).await + } +} diff --git a/libs/proxy/tokio-postgres2/src/lib.rs b/libs/proxy/tokio-postgres2/src/lib.rs new file mode 100644 index 0000000000..72ba8172b2 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/lib.rs @@ -0,0 +1,148 @@ +//! An asynchronous, pipelined, PostgreSQL client. +#![warn(rust_2018_idioms, clippy::all, missing_docs)] + +pub use crate::cancel_token::CancelToken; +pub use crate::client::Client; +pub use crate::config::Config; +pub use crate::connection::Connection; +use crate::error::DbError; +pub use crate::error::Error; +pub use crate::generic_client::GenericClient; +pub use crate::query::RowStream; +pub use crate::row::{Row, SimpleQueryRow}; +pub use crate::simple_query::SimpleQueryStream; +pub use crate::statement::{Column, Statement}; +use crate::tls::MakeTlsConnect; +pub use crate::tls::NoTls; +pub use crate::to_statement::ToStatement; +pub use crate::transaction::Transaction; +pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder}; +use crate::types::ToSql; +use postgres_protocol2::message::backend::ReadyForQueryBody; +use tokio::net::TcpStream; + +/// After executing a query, the connection will be in one of these states +#[derive(Clone, Copy, Debug, PartialEq)] +#[repr(u8)] +pub enum ReadyForQueryStatus { + /// Connection state is unknown + Unknown, + /// Connection is idle (no transactions) + Idle = b'I', + /// Connection is in a transaction block + Transaction = b'T', + /// Connection is in a failed transaction block + FailedTransaction = b'E', +} + +impl From for ReadyForQueryStatus { + fn from(value: ReadyForQueryBody) -> Self { + match value.status() { + b'I' => Self::Idle, + b'T' => Self::Transaction, + b'E' => Self::FailedTransaction, + _ => Self::Unknown, + } + } +} + +mod cancel_query; +mod cancel_query_raw; +mod cancel_token; +mod client; +mod codec; +pub mod config; +mod connect; +mod connect_raw; +mod connect_socket; +mod connect_tls; +mod connection; +pub mod error; +mod generic_client; +pub mod maybe_tls_stream; +mod prepare; +mod query; +pub mod row; +mod simple_query; +mod statement; +pub mod tls; +mod to_statement; +mod transaction; +mod transaction_builder; +pub mod types; + +/// A convenience function which parses a connection string and connects to the database. +/// +/// See the documentation for [`Config`] for details on the connection string format. +/// +/// Requires the `runtime` Cargo feature (enabled by default). +/// +/// [`Config`]: config/struct.Config.html +pub async fn connect( + config: &str, + tls: T, +) -> Result<(Client, Connection), Error> +where + T: MakeTlsConnect, +{ + let config = config.parse::()?; + config.connect(tls).await +} + +/// An asynchronous notification. +#[derive(Clone, Debug)] +pub struct Notification { + process_id: i32, + channel: String, + payload: String, +} + +impl Notification { + /// The process ID of the notifying backend process. + pub fn process_id(&self) -> i32 { + self.process_id + } + + /// The name of the channel that the notify has been raised on. + pub fn channel(&self) -> &str { + &self.channel + } + + /// The "payload" string passed from the notifying process. + pub fn payload(&self) -> &str { + &self.payload + } +} + +/// An asynchronous message from the server. +#[allow(clippy::large_enum_variant)] +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum AsyncMessage { + /// A notice. + /// + /// Notices use the same format as errors, but aren't "errors" per-se. + Notice(DbError), + /// A notification. + /// + /// Connections can subscribe to notifications with the `LISTEN` command. + Notification(Notification), +} + +/// Message returned by the `SimpleQuery` stream. +#[derive(Debug)] +#[non_exhaustive] +pub enum SimpleQueryMessage { + /// A row of data. + Row(SimpleQueryRow), + /// A statement in the query has completed. + /// + /// The number of rows modified or selected is returned. + CommandComplete(u64), +} + +fn slice_iter<'a>( + s: &'a [&'a (dyn ToSql + Sync)], +) -> impl ExactSizeIterator + 'a { + s.iter().map(|s| *s as _) +} diff --git a/libs/proxy/tokio-postgres2/src/maybe_tls_stream.rs b/libs/proxy/tokio-postgres2/src/maybe_tls_stream.rs new file mode 100644 index 0000000000..9a7e248997 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/maybe_tls_stream.rs @@ -0,0 +1,77 @@ +//! MaybeTlsStream. +//! +//! Represents a stream that may or may not be encrypted with TLS. +use crate::tls::{ChannelBinding, TlsStream}; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +/// A stream that may or may not be encrypted with TLS. +pub enum MaybeTlsStream { + /// An unencrypted stream. + Raw(S), + /// An encrypted stream. + Tls(T), +} + +impl AsyncRead for MaybeTlsStream +where + S: AsyncRead + Unpin, + T: AsyncRead + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match &mut *self { + MaybeTlsStream::Raw(s) => Pin::new(s).poll_read(cx, buf), + MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for MaybeTlsStream +where + S: AsyncWrite + Unpin, + T: AsyncWrite + Unpin, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match &mut *self { + MaybeTlsStream::Raw(s) => Pin::new(s).poll_write(cx, buf), + MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + MaybeTlsStream::Raw(s) => Pin::new(s).poll_flush(cx), + MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + MaybeTlsStream::Raw(s) => Pin::new(s).poll_shutdown(cx), + MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx), + } + } +} + +impl TlsStream for MaybeTlsStream +where + S: AsyncRead + AsyncWrite + Unpin, + T: TlsStream + Unpin, +{ + fn channel_binding(&self) -> ChannelBinding { + match self { + MaybeTlsStream::Raw(_) => ChannelBinding::none(), + MaybeTlsStream::Tls(s) => s.channel_binding(), + } + } +} diff --git a/libs/proxy/tokio-postgres2/src/prepare.rs b/libs/proxy/tokio-postgres2/src/prepare.rs new file mode 100644 index 0000000000..da0c755c5b --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/prepare.rs @@ -0,0 +1,262 @@ +use crate::client::InnerClient; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::error::SqlState; +use crate::types::{Field, Kind, Oid, Type}; +use crate::{query, slice_iter}; +use crate::{Column, Error, Statement}; +use bytes::Bytes; +use fallible_iterator::FallibleIterator; +use futures_util::{pin_mut, TryStreamExt}; +use log::debug; +use postgres_protocol2::message::backend::Message; +use postgres_protocol2::message::frontend; +use std::future::Future; +use std::pin::Pin; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; + +pub(crate) const TYPEINFO_QUERY: &str = "\ +SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid +FROM pg_catalog.pg_type t +LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid +INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid +WHERE t.oid = $1 +"; + +// Range types weren't added until Postgres 9.2, so pg_range may not exist +const TYPEINFO_FALLBACK_QUERY: &str = "\ +SELECT t.typname, t.typtype, t.typelem, NULL::OID, t.typbasetype, n.nspname, t.typrelid +FROM pg_catalog.pg_type t +INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid +WHERE t.oid = $1 +"; + +const TYPEINFO_ENUM_QUERY: &str = "\ +SELECT enumlabel +FROM pg_catalog.pg_enum +WHERE enumtypid = $1 +ORDER BY enumsortorder +"; + +// Postgres 9.0 didn't have enumsortorder +const TYPEINFO_ENUM_FALLBACK_QUERY: &str = "\ +SELECT enumlabel +FROM pg_catalog.pg_enum +WHERE enumtypid = $1 +ORDER BY oid +"; + +pub(crate) const TYPEINFO_COMPOSITE_QUERY: &str = "\ +SELECT attname, atttypid +FROM pg_catalog.pg_attribute +WHERE attrelid = $1 +AND NOT attisdropped +AND attnum > 0 +ORDER BY attnum +"; + +static NEXT_ID: AtomicUsize = AtomicUsize::new(0); + +pub async fn prepare( + client: &Arc, + query: &str, + types: &[Type], +) -> Result { + let name = format!("s{}", NEXT_ID.fetch_add(1, Ordering::SeqCst)); + let buf = encode(client, &name, query, types)?; + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + match responses.next().await? { + Message::ParseComplete => {} + _ => return Err(Error::unexpected_message()), + } + + let parameter_description = match responses.next().await? { + Message::ParameterDescription(body) => body, + _ => return Err(Error::unexpected_message()), + }; + + let row_description = match responses.next().await? { + Message::RowDescription(body) => Some(body), + Message::NoData => None, + _ => return Err(Error::unexpected_message()), + }; + + let mut parameters = vec![]; + let mut it = parameter_description.parameters(); + while let Some(oid) = it.next().map_err(Error::parse)? { + let type_ = get_type(client, oid).await?; + parameters.push(type_); + } + + let mut columns = vec![]; + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? { + let type_ = get_type(client, field.type_oid()).await?; + let column = Column::new(field.name().to_string(), type_, field); + columns.push(column); + } + } + + Ok(Statement::new(client, name, parameters, columns)) +} + +fn prepare_rec<'a>( + client: &'a Arc, + query: &'a str, + types: &'a [Type], +) -> Pin> + 'a + Send>> { + Box::pin(prepare(client, query, types)) +} + +fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result { + if types.is_empty() { + debug!("preparing query {}: {}", name, query); + } else { + debug!("preparing query {} with types {:?}: {}", name, types, query); + } + + client.with_buf(|buf| { + frontend::parse(name, query, types.iter().map(Type::oid), buf).map_err(Error::encode)?; + frontend::describe(b'S', name, buf).map_err(Error::encode)?; + frontend::sync(buf); + Ok(buf.split().freeze()) + }) +} + +pub async fn get_type(client: &Arc, oid: Oid) -> Result { + if let Some(type_) = Type::from_oid(oid) { + return Ok(type_); + } + + if let Some(type_) = client.type_(oid) { + return Ok(type_); + } + + let stmt = typeinfo_statement(client).await?; + + let rows = query::query(client, stmt, slice_iter(&[&oid])).await?; + pin_mut!(rows); + + let row = match rows.try_next().await? { + Some(row) => row, + None => return Err(Error::unexpected_message()), + }; + + let name: String = row.try_get(0)?; + let type_: i8 = row.try_get(1)?; + let elem_oid: Oid = row.try_get(2)?; + let rngsubtype: Option = row.try_get(3)?; + let basetype: Oid = row.try_get(4)?; + let schema: String = row.try_get(5)?; + let relid: Oid = row.try_get(6)?; + + let kind = if type_ == b'e' as i8 { + let variants = get_enum_variants(client, oid).await?; + Kind::Enum(variants) + } else if type_ == b'p' as i8 { + Kind::Pseudo + } else if basetype != 0 { + let type_ = get_type_rec(client, basetype).await?; + Kind::Domain(type_) + } else if elem_oid != 0 { + let type_ = get_type_rec(client, elem_oid).await?; + Kind::Array(type_) + } else if relid != 0 { + let fields = get_composite_fields(client, relid).await?; + Kind::Composite(fields) + } else if let Some(rngsubtype) = rngsubtype { + let type_ = get_type_rec(client, rngsubtype).await?; + Kind::Range(type_) + } else { + Kind::Simple + }; + + let type_ = Type::new(name, oid, kind, schema); + client.set_type(oid, &type_); + + Ok(type_) +} + +fn get_type_rec<'a>( + client: &'a Arc, + oid: Oid, +) -> Pin> + Send + 'a>> { + Box::pin(get_type(client, oid)) +} + +async fn typeinfo_statement(client: &Arc) -> Result { + if let Some(stmt) = client.typeinfo() { + return Ok(stmt); + } + + let stmt = match prepare_rec(client, TYPEINFO_QUERY, &[]).await { + Ok(stmt) => stmt, + Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_TABLE) => { + prepare_rec(client, TYPEINFO_FALLBACK_QUERY, &[]).await? + } + Err(e) => return Err(e), + }; + + client.set_typeinfo(&stmt); + Ok(stmt) +} + +async fn get_enum_variants(client: &Arc, oid: Oid) -> Result, Error> { + let stmt = typeinfo_enum_statement(client).await?; + + query::query(client, stmt, slice_iter(&[&oid])) + .await? + .and_then(|row| async move { row.try_get(0) }) + .try_collect() + .await +} + +async fn typeinfo_enum_statement(client: &Arc) -> Result { + if let Some(stmt) = client.typeinfo_enum() { + return Ok(stmt); + } + + let stmt = match prepare_rec(client, TYPEINFO_ENUM_QUERY, &[]).await { + Ok(stmt) => stmt, + Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_COLUMN) => { + prepare_rec(client, TYPEINFO_ENUM_FALLBACK_QUERY, &[]).await? + } + Err(e) => return Err(e), + }; + + client.set_typeinfo_enum(&stmt); + Ok(stmt) +} + +async fn get_composite_fields(client: &Arc, oid: Oid) -> Result, Error> { + let stmt = typeinfo_composite_statement(client).await?; + + let rows = query::query(client, stmt, slice_iter(&[&oid])) + .await? + .try_collect::>() + .await?; + + let mut fields = vec![]; + for row in rows { + let name = row.try_get(0)?; + let oid = row.try_get(1)?; + let type_ = get_type_rec(client, oid).await?; + fields.push(Field::new(name, type_)); + } + + Ok(fields) +} + +async fn typeinfo_composite_statement(client: &Arc) -> Result { + if let Some(stmt) = client.typeinfo_composite() { + return Ok(stmt); + } + + let stmt = prepare_rec(client, TYPEINFO_COMPOSITE_QUERY, &[]).await?; + + client.set_typeinfo_composite(&stmt); + Ok(stmt) +} diff --git a/libs/proxy/tokio-postgres2/src/query.rs b/libs/proxy/tokio-postgres2/src/query.rs new file mode 100644 index 0000000000..534195a707 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/query.rs @@ -0,0 +1,340 @@ +use crate::client::{InnerClient, Responses}; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::types::IsNull; +use crate::{Column, Error, ReadyForQueryStatus, Row, Statement}; +use bytes::{BufMut, Bytes, BytesMut}; +use fallible_iterator::FallibleIterator; +use futures_util::{ready, Stream}; +use log::{debug, log_enabled, Level}; +use pin_project_lite::pin_project; +use postgres_protocol2::message::backend::Message; +use postgres_protocol2::message::frontend; +use postgres_types2::{Format, ToSql, Type}; +use std::fmt; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +struct BorrowToSqlParamsDebug<'a>(&'a [&'a (dyn ToSql + Sync)]); + +impl fmt::Debug for BorrowToSqlParamsDebug<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.0.iter()).finish() + } +} + +pub async fn query<'a, I>( + client: &InnerClient, + statement: Statement, + params: I, +) -> Result +where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ + let buf = if log_enabled!(Level::Debug) { + let params = params.into_iter().collect::>(); + debug!( + "executing statement {} with parameters: {:?}", + statement.name(), + BorrowToSqlParamsDebug(params.as_slice()), + ); + encode(client, &statement, params)? + } else { + encode(client, &statement, params)? + }; + let responses = start(client, buf).await?; + Ok(RowStream { + statement, + responses, + command_tag: None, + status: ReadyForQueryStatus::Unknown, + output_format: Format::Binary, + _p: PhantomPinned, + }) +} + +pub async fn query_txt( + client: &Arc, + query: &str, + params: I, +) -> Result +where + S: AsRef, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, +{ + let params = params.into_iter(); + + let buf = client.with_buf(|buf| { + frontend::parse( + "", // unnamed prepared statement + query, // query to parse + std::iter::empty(), // give no type info + buf, + ) + .map_err(Error::encode)?; + frontend::describe(b'S', "", buf).map_err(Error::encode)?; + // Bind, pass params as text, retrieve as binary + match frontend::bind( + "", // empty string selects the unnamed portal + "", // unnamed prepared statement + std::iter::empty(), // all parameters use the default format (text) + params, + |param, buf| match param { + Some(param) => { + buf.put_slice(param.as_ref().as_bytes()); + Ok(postgres_protocol2::IsNull::No) + } + None => Ok(postgres_protocol2::IsNull::Yes), + }, + Some(0), // all text + buf, + ) { + Ok(()) => Ok(()), + Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)), + Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), + }?; + + // Execute + frontend::execute("", 0, buf).map_err(Error::encode)?; + // Sync + frontend::sync(buf); + + Ok(buf.split().freeze()) + })?; + + // now read the responses + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + match responses.next().await? { + Message::ParseComplete => {} + _ => return Err(Error::unexpected_message()), + } + + let parameter_description = match responses.next().await? { + Message::ParameterDescription(body) => body, + _ => return Err(Error::unexpected_message()), + }; + + let row_description = match responses.next().await? { + Message::RowDescription(body) => Some(body), + Message::NoData => None, + _ => return Err(Error::unexpected_message()), + }; + + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } + + let mut parameters = vec![]; + let mut it = parameter_description.parameters(); + while let Some(oid) = it.next().map_err(Error::parse)? { + let type_ = Type::from_oid(oid).unwrap_or(Type::UNKNOWN); + parameters.push(type_); + } + + let mut columns = vec![]; + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? { + let type_ = Type::from_oid(field.type_oid()).unwrap_or(Type::UNKNOWN); + let column = Column::new(field.name().to_string(), type_, field); + columns.push(column); + } + } + + Ok(RowStream { + statement: Statement::new_anonymous(parameters, columns), + responses, + command_tag: None, + status: ReadyForQueryStatus::Unknown, + output_format: Format::Text, + _p: PhantomPinned, + }) +} + +pub async fn execute<'a, I>( + client: &InnerClient, + statement: Statement, + params: I, +) -> Result +where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ + let buf = if log_enabled!(Level::Debug) { + let params = params.into_iter().collect::>(); + debug!( + "executing statement {} with parameters: {:?}", + statement.name(), + BorrowToSqlParamsDebug(params.as_slice()), + ); + encode(client, &statement, params)? + } else { + encode(client, &statement, params)? + }; + let mut responses = start(client, buf).await?; + + let mut rows = 0; + loop { + match responses.next().await? { + Message::DataRow(_) => {} + Message::CommandComplete(body) => { + rows = body + .tag() + .map_err(Error::parse)? + .rsplit(' ') + .next() + .unwrap() + .parse() + .unwrap_or(0); + } + Message::EmptyQueryResponse => rows = 0, + Message::ReadyForQuery(_) => return Ok(rows), + _ => return Err(Error::unexpected_message()), + } + } +} + +async fn start(client: &InnerClient, buf: Bytes) -> Result { + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } + + Ok(responses) +} + +pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result +where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ + client.with_buf(|buf| { + encode_bind(statement, params, "", buf)?; + frontend::execute("", 0, buf).map_err(Error::encode)?; + frontend::sync(buf); + Ok(buf.split().freeze()) + }) +} + +pub fn encode_bind<'a, I>( + statement: &Statement, + params: I, + portal: &str, + buf: &mut BytesMut, +) -> Result<(), Error> +where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, +{ + let param_types = statement.params(); + let params = params.into_iter(); + + assert!( + param_types.len() == params.len(), + "expected {} parameters but got {}", + param_types.len(), + params.len() + ); + + let (param_formats, params): (Vec<_>, Vec<_>) = params + .zip(param_types.iter()) + .map(|(p, ty)| (p.encode_format(ty) as i16, p)) + .unzip(); + + let params = params.into_iter(); + + let mut error_idx = 0; + let r = frontend::bind( + portal, + statement.name(), + param_formats, + params.zip(param_types).enumerate(), + |(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) { + Ok(IsNull::No) => Ok(postgres_protocol2::IsNull::No), + Ok(IsNull::Yes) => Ok(postgres_protocol2::IsNull::Yes), + Err(e) => { + error_idx = idx; + Err(e) + } + }, + Some(1), + buf, + ); + match r { + Ok(()) => Ok(()), + Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)), + Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), + } +} + +pin_project! { + /// A stream of table rows. + pub struct RowStream { + statement: Statement, + responses: Responses, + command_tag: Option, + output_format: Format, + status: ReadyForQueryStatus, + #[pin] + _p: PhantomPinned, + } +} + +impl Stream for RowStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + loop { + match ready!(this.responses.poll_next(cx)?) { + Message::DataRow(body) => { + return Poll::Ready(Some(Ok(Row::new( + this.statement.clone(), + body, + *this.output_format, + )?))) + } + Message::EmptyQueryResponse | Message::PortalSuspended => {} + Message::CommandComplete(body) => { + if let Ok(tag) = body.tag() { + *this.command_tag = Some(tag.to_string()); + } + } + Message::ReadyForQuery(status) => { + *this.status = status.into(); + return Poll::Ready(None); + } + _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), + } + } + } +} + +impl RowStream { + /// Returns information about the columns of data in the row. + pub fn columns(&self) -> &[Column] { + self.statement.columns() + } + + /// Returns the command tag of this query. + /// + /// This is only available after the stream has been exhausted. + pub fn command_tag(&self) -> Option { + self.command_tag.clone() + } + + /// Returns if the connection is ready for querying, with the status of the connection. + /// + /// This might be available only after the stream has been exhausted. + pub fn ready_status(&self) -> ReadyForQueryStatus { + self.status + } +} diff --git a/libs/proxy/tokio-postgres2/src/row.rs b/libs/proxy/tokio-postgres2/src/row.rs new file mode 100644 index 0000000000..10e130707d --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/row.rs @@ -0,0 +1,300 @@ +//! Rows. + +use crate::row::sealed::{AsName, Sealed}; +use crate::simple_query::SimpleColumn; +use crate::statement::Column; +use crate::types::{FromSql, Type, WrongType}; +use crate::{Error, Statement}; +use fallible_iterator::FallibleIterator; +use postgres_protocol2::message::backend::DataRowBody; +use postgres_types2::{Format, WrongFormat}; +use std::fmt; +use std::ops::Range; +use std::str; +use std::sync::Arc; + +mod sealed { + pub trait Sealed {} + + pub trait AsName { + fn as_name(&self) -> &str; + } +} + +impl AsName for Column { + fn as_name(&self) -> &str { + self.name() + } +} + +impl AsName for String { + fn as_name(&self) -> &str { + self + } +} + +/// A trait implemented by types that can index into columns of a row. +/// +/// This cannot be implemented outside of this crate. +pub trait RowIndex: Sealed { + #[doc(hidden)] + fn __idx(&self, columns: &[T]) -> Option + where + T: AsName; +} + +impl Sealed for usize {} + +impl RowIndex for usize { + #[inline] + fn __idx(&self, columns: &[T]) -> Option + where + T: AsName, + { + if *self >= columns.len() { + None + } else { + Some(*self) + } + } +} + +impl Sealed for str {} + +impl RowIndex for str { + #[inline] + fn __idx(&self, columns: &[T]) -> Option + where + T: AsName, + { + if let Some(idx) = columns.iter().position(|d| d.as_name() == self) { + return Some(idx); + }; + + // FIXME ASCII-only case insensitivity isn't really the right thing to + // do. Postgres itself uses a dubious wrapper around tolower and JDBC + // uses the US locale. + columns + .iter() + .position(|d| d.as_name().eq_ignore_ascii_case(self)) + } +} + +impl Sealed for &T where T: ?Sized + Sealed {} + +impl RowIndex for &T +where + T: ?Sized + RowIndex, +{ + #[inline] + fn __idx(&self, columns: &[U]) -> Option + where + U: AsName, + { + T::__idx(*self, columns) + } +} + +/// A row of data returned from the database by a query. +pub struct Row { + statement: Statement, + output_format: Format, + body: DataRowBody, + ranges: Vec>>, +} + +impl fmt::Debug for Row { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Row") + .field("columns", &self.columns()) + .finish() + } +} + +impl Row { + pub(crate) fn new( + statement: Statement, + body: DataRowBody, + output_format: Format, + ) -> Result { + let ranges = body.ranges().collect().map_err(Error::parse)?; + Ok(Row { + statement, + body, + ranges, + output_format, + }) + } + + /// Returns information about the columns of data in the row. + pub fn columns(&self) -> &[Column] { + self.statement.columns() + } + + /// Determines if the row contains no values. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the number of values in the row. + pub fn len(&self) -> usize { + self.columns().len() + } + + /// Deserializes a value from the row. + /// + /// The value can be specified either by its numeric index in the row, or by its column name. + /// + /// # Panics + /// + /// Panics if the index is out of bounds or if the value cannot be converted to the specified type. + pub fn get<'a, I, T>(&'a self, idx: I) -> T + where + I: RowIndex + fmt::Display, + T: FromSql<'a>, + { + match self.get_inner(&idx) { + Ok(ok) => ok, + Err(err) => panic!("error retrieving column {}: {}", idx, err), + } + } + + /// Like `Row::get`, but returns a `Result` rather than panicking. + pub fn try_get<'a, I, T>(&'a self, idx: I) -> Result + where + I: RowIndex + fmt::Display, + T: FromSql<'a>, + { + self.get_inner(&idx) + } + + fn get_inner<'a, I, T>(&'a self, idx: &I) -> Result + where + I: RowIndex + fmt::Display, + T: FromSql<'a>, + { + let idx = match idx.__idx(self.columns()) { + Some(idx) => idx, + None => return Err(Error::column(idx.to_string())), + }; + + let ty = self.columns()[idx].type_(); + if !T::accepts(ty) { + return Err(Error::from_sql( + Box::new(WrongType::new::(ty.clone())), + idx, + )); + } + + FromSql::from_sql_nullable(ty, self.col_buffer(idx)).map_err(|e| Error::from_sql(e, idx)) + } + + /// Get the raw bytes for the column at the given index. + fn col_buffer(&self, idx: usize) -> Option<&[u8]> { + let range = self.ranges.get(idx)?.to_owned()?; + Some(&self.body.buffer()[range]) + } + + /// Interpret the column at the given index as text + /// + /// Useful when using query_raw_txt() which sets text transfer mode + pub fn as_text(&self, idx: usize) -> Result, Error> { + if self.output_format == Format::Text { + match self.col_buffer(idx) { + Some(raw) => { + FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx)) + } + None => Ok(None), + } + } else { + Err(Error::from_sql(Box::new(WrongFormat {}), idx)) + } + } + + /// Row byte size + pub fn body_len(&self) -> usize { + self.body.buffer().len() + } +} + +impl AsName for SimpleColumn { + fn as_name(&self) -> &str { + self.name() + } +} + +/// A row of data returned from the database by a simple query. +#[derive(Debug)] +pub struct SimpleQueryRow { + columns: Arc<[SimpleColumn]>, + body: DataRowBody, + ranges: Vec>>, +} + +impl SimpleQueryRow { + #[allow(clippy::new_ret_no_self)] + pub(crate) fn new( + columns: Arc<[SimpleColumn]>, + body: DataRowBody, + ) -> Result { + let ranges = body.ranges().collect().map_err(Error::parse)?; + Ok(SimpleQueryRow { + columns, + body, + ranges, + }) + } + + /// Returns information about the columns of data in the row. + pub fn columns(&self) -> &[SimpleColumn] { + &self.columns + } + + /// Determines if the row contains no values. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the number of values in the row. + pub fn len(&self) -> usize { + self.columns.len() + } + + /// Returns a value from the row. + /// + /// The value can be specified either by its numeric index in the row, or by its column name. + /// + /// # Panics + /// + /// Panics if the index is out of bounds or if the value cannot be converted to the specified type. + pub fn get(&self, idx: I) -> Option<&str> + where + I: RowIndex + fmt::Display, + { + match self.get_inner(&idx) { + Ok(ok) => ok, + Err(err) => panic!("error retrieving column {}: {}", idx, err), + } + } + + /// Like `SimpleQueryRow::get`, but returns a `Result` rather than panicking. + pub fn try_get(&self, idx: I) -> Result, Error> + where + I: RowIndex + fmt::Display, + { + self.get_inner(&idx) + } + + fn get_inner(&self, idx: &I) -> Result, Error> + where + I: RowIndex + fmt::Display, + { + let idx = match idx.__idx(&self.columns) { + Some(idx) => idx, + None => return Err(Error::column(idx.to_string())), + }; + + let buf = self.ranges[idx].clone().map(|r| &self.body.buffer()[r]); + FromSql::from_sql_nullable(&Type::TEXT, buf).map_err(|e| Error::from_sql(e, idx)) + } +} diff --git a/libs/proxy/tokio-postgres2/src/simple_query.rs b/libs/proxy/tokio-postgres2/src/simple_query.rs new file mode 100644 index 0000000000..fb2550377b --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/simple_query.rs @@ -0,0 +1,142 @@ +use crate::client::{InnerClient, Responses}; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::{Error, ReadyForQueryStatus, SimpleQueryMessage, SimpleQueryRow}; +use bytes::Bytes; +use fallible_iterator::FallibleIterator; +use futures_util::{ready, Stream}; +use log::debug; +use pin_project_lite::pin_project; +use postgres_protocol2::message::backend::Message; +use postgres_protocol2::message::frontend; +use std::marker::PhantomPinned; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +/// Information about a column of a single query row. +#[derive(Debug)] +pub struct SimpleColumn { + name: String, +} + +impl SimpleColumn { + pub(crate) fn new(name: String) -> SimpleColumn { + SimpleColumn { name } + } + + /// Returns the name of the column. + pub fn name(&self) -> &str { + &self.name + } +} + +pub async fn simple_query(client: &InnerClient, query: &str) -> Result { + debug!("executing simple query: {}", query); + + let buf = encode(client, query)?; + let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + Ok(SimpleQueryStream { + responses, + columns: None, + status: ReadyForQueryStatus::Unknown, + _p: PhantomPinned, + }) +} + +pub async fn batch_execute( + client: &InnerClient, + query: &str, +) -> Result { + debug!("executing statement batch: {}", query); + + let buf = encode(client, query)?; + let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + loop { + match responses.next().await? { + Message::ReadyForQuery(status) => return Ok(status.into()), + Message::CommandComplete(_) + | Message::EmptyQueryResponse + | Message::RowDescription(_) + | Message::DataRow(_) => {} + _ => return Err(Error::unexpected_message()), + } + } +} + +pub(crate) fn encode(client: &InnerClient, query: &str) -> Result { + client.with_buf(|buf| { + frontend::query(query, buf).map_err(Error::encode)?; + Ok(buf.split().freeze()) + }) +} + +pin_project! { + /// A stream of simple query results. + pub struct SimpleQueryStream { + responses: Responses, + columns: Option>, + status: ReadyForQueryStatus, + #[pin] + _p: PhantomPinned, + } +} + +impl SimpleQueryStream { + /// Returns if the connection is ready for querying, with the status of the connection. + /// + /// This might be available only after the stream has been exhausted. + pub fn ready_status(&self) -> ReadyForQueryStatus { + self.status + } +} + +impl Stream for SimpleQueryStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.project(); + loop { + match ready!(this.responses.poll_next(cx)?) { + Message::CommandComplete(body) => { + let rows = body + .tag() + .map_err(Error::parse)? + .rsplit(' ') + .next() + .unwrap() + .parse() + .unwrap_or(0); + return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(rows)))); + } + Message::EmptyQueryResponse => { + return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(0)))); + } + Message::RowDescription(body) => { + let columns = body + .fields() + .map(|f| Ok(SimpleColumn::new(f.name().to_string()))) + .collect::>() + .map_err(Error::parse)? + .into(); + + *this.columns = Some(columns); + } + Message::DataRow(body) => { + let row = match &this.columns { + Some(columns) => SimpleQueryRow::new(columns.clone(), body)?, + None => return Poll::Ready(Some(Err(Error::unexpected_message()))), + }; + return Poll::Ready(Some(Ok(SimpleQueryMessage::Row(row)))); + } + Message::ReadyForQuery(s) => { + *this.status = s.into(); + return Poll::Ready(None); + } + _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), + } + } + } +} diff --git a/libs/proxy/tokio-postgres2/src/statement.rs b/libs/proxy/tokio-postgres2/src/statement.rs new file mode 100644 index 0000000000..22e160fc05 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/statement.rs @@ -0,0 +1,157 @@ +use crate::client::InnerClient; +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::types::Type; +use postgres_protocol2::{ + message::{backend::Field, frontend}, + Oid, +}; +use std::{ + fmt, + sync::{Arc, Weak}, +}; + +struct StatementInner { + client: Weak, + name: String, + params: Vec, + columns: Vec, +} + +impl Drop for StatementInner { + fn drop(&mut self) { + if let Some(client) = self.client.upgrade() { + let buf = client.with_buf(|buf| { + frontend::close(b'S', &self.name, buf).unwrap(); + frontend::sync(buf); + buf.split().freeze() + }); + let _ = client.send(RequestMessages::Single(FrontendMessage::Raw(buf))); + } + } +} + +/// A prepared statement. +/// +/// Prepared statements can only be used with the connection that created them. +#[derive(Clone)] +pub struct Statement(Arc); + +impl Statement { + pub(crate) fn new( + inner: &Arc, + name: String, + params: Vec, + columns: Vec, + ) -> Statement { + Statement(Arc::new(StatementInner { + client: Arc::downgrade(inner), + name, + params, + columns, + })) + } + + pub(crate) fn new_anonymous(params: Vec, columns: Vec) -> Statement { + Statement(Arc::new(StatementInner { + client: Weak::new(), + name: String::new(), + params, + columns, + })) + } + + pub(crate) fn name(&self) -> &str { + &self.0.name + } + + /// Returns the expected types of the statement's parameters. + pub fn params(&self) -> &[Type] { + &self.0.params + } + + /// Returns information about the columns returned when the statement is queried. + pub fn columns(&self) -> &[Column] { + &self.0.columns + } +} + +/// Information about a column of a query. +pub struct Column { + name: String, + type_: Type, + + // raw fields from RowDescription + table_oid: Oid, + column_id: i16, + format: i16, + + // that better be stored in self.type_, but that is more radical refactoring + type_oid: Oid, + type_size: i16, + type_modifier: i32, +} + +impl Column { + pub(crate) fn new(name: String, type_: Type, raw_field: Field<'_>) -> Column { + Column { + name, + type_, + table_oid: raw_field.table_oid(), + column_id: raw_field.column_id(), + format: raw_field.format(), + type_oid: raw_field.type_oid(), + type_size: raw_field.type_size(), + type_modifier: raw_field.type_modifier(), + } + } + + /// Returns the name of the column. + pub fn name(&self) -> &str { + &self.name + } + + /// Returns the type of the column. + pub fn type_(&self) -> &Type { + &self.type_ + } + + /// Returns the table OID of the column. + pub fn table_oid(&self) -> Oid { + self.table_oid + } + + /// Returns the column ID of the column. + pub fn column_id(&self) -> i16 { + self.column_id + } + + /// Returns the format of the column. + pub fn format(&self) -> i16 { + self.format + } + + /// Returns the type OID of the column. + pub fn type_oid(&self) -> Oid { + self.type_oid + } + + /// Returns the type size of the column. + pub fn type_size(&self) -> i16 { + self.type_size + } + + /// Returns the type modifier of the column. + pub fn type_modifier(&self) -> i32 { + self.type_modifier + } +} + +impl fmt::Debug for Column { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("Column") + .field("name", &self.name) + .field("type", &self.type_) + .finish() + } +} diff --git a/libs/proxy/tokio-postgres2/src/tls.rs b/libs/proxy/tokio-postgres2/src/tls.rs new file mode 100644 index 0000000000..dc8140719f --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/tls.rs @@ -0,0 +1,162 @@ +//! TLS support. + +use std::error::Error; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use std::{fmt, io}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + +pub(crate) mod private { + pub struct ForcePrivateApi; +} + +/// Channel binding information returned from a TLS handshake. +pub struct ChannelBinding { + pub(crate) tls_server_end_point: Option>, +} + +impl ChannelBinding { + /// Creates a `ChannelBinding` containing no information. + pub fn none() -> ChannelBinding { + ChannelBinding { + tls_server_end_point: None, + } + } + + /// Creates a `ChannelBinding` containing `tls-server-end-point` channel binding information. + pub fn tls_server_end_point(tls_server_end_point: Vec) -> ChannelBinding { + ChannelBinding { + tls_server_end_point: Some(tls_server_end_point), + } + } +} + +/// A constructor of `TlsConnect`ors. +/// +/// Requires the `runtime` Cargo feature (enabled by default). +pub trait MakeTlsConnect { + /// The stream type created by the `TlsConnect` implementation. + type Stream: TlsStream + Unpin; + /// The `TlsConnect` implementation created by this type. + type TlsConnect: TlsConnect; + /// The error type returned by the `TlsConnect` implementation. + type Error: Into>; + + /// Creates a new `TlsConnect`or. + /// + /// The domain name is provided for certificate verification and SNI. + fn make_tls_connect(&mut self, domain: &str) -> Result; +} + +/// An asynchronous function wrapping a stream in a TLS session. +pub trait TlsConnect { + /// The stream returned by the future. + type Stream: TlsStream + Unpin; + /// The error returned by the future. + type Error: Into>; + /// The future returned by the connector. + type Future: Future>; + + /// Returns a future performing a TLS handshake over the stream. + fn connect(self, stream: S) -> Self::Future; + + #[doc(hidden)] + fn can_connect(&self, _: private::ForcePrivateApi) -> bool { + true + } +} + +/// A TLS-wrapped connection to a PostgreSQL database. +pub trait TlsStream: AsyncRead + AsyncWrite { + /// Returns channel binding information for the session. + fn channel_binding(&self) -> ChannelBinding; +} + +/// A `MakeTlsConnect` and `TlsConnect` implementation which simply returns an error. +/// +/// This can be used when `sslmode` is `none` or `prefer`. +#[derive(Debug, Copy, Clone)] +pub struct NoTls; + +impl MakeTlsConnect for NoTls { + type Stream = NoTlsStream; + type TlsConnect = NoTls; + type Error = NoTlsError; + + fn make_tls_connect(&mut self, _: &str) -> Result { + Ok(NoTls) + } +} + +impl TlsConnect for NoTls { + type Stream = NoTlsStream; + type Error = NoTlsError; + type Future = NoTlsFuture; + + fn connect(self, _: S) -> NoTlsFuture { + NoTlsFuture(()) + } + + fn can_connect(&self, _: private::ForcePrivateApi) -> bool { + false + } +} + +/// The future returned by `NoTls`. +pub struct NoTlsFuture(()); + +impl Future for NoTlsFuture { + type Output = Result; + + fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll { + Poll::Ready(Err(NoTlsError(()))) + } +} + +/// The TLS "stream" type produced by the `NoTls` connector. +/// +/// Since `NoTls` doesn't support TLS, this type is uninhabited. +pub enum NoTlsStream {} + +impl AsyncRead for NoTlsStream { + fn poll_read( + self: Pin<&mut Self>, + _: &mut Context<'_>, + _: &mut ReadBuf<'_>, + ) -> Poll> { + match *self {} + } +} + +impl AsyncWrite for NoTlsStream { + fn poll_write(self: Pin<&mut Self>, _: &mut Context<'_>, _: &[u8]) -> Poll> { + match *self {} + } + + fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + match *self {} + } + + fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + match *self {} + } +} + +impl TlsStream for NoTlsStream { + fn channel_binding(&self) -> ChannelBinding { + match *self {} + } +} + +/// The error returned by `NoTls`. +#[derive(Debug)] +pub struct NoTlsError(()); + +impl fmt::Display for NoTlsError { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.write_str("no TLS implementation configured") + } +} + +impl Error for NoTlsError {} diff --git a/libs/proxy/tokio-postgres2/src/to_statement.rs b/libs/proxy/tokio-postgres2/src/to_statement.rs new file mode 100644 index 0000000000..427f77dd79 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/to_statement.rs @@ -0,0 +1,57 @@ +use crate::to_statement::private::{Sealed, ToStatementType}; +use crate::Statement; + +mod private { + use crate::{Client, Error, Statement}; + + pub trait Sealed {} + + pub enum ToStatementType<'a> { + Statement(&'a Statement), + Query(&'a str), + } + + impl<'a> ToStatementType<'a> { + pub async fn into_statement(self, client: &Client) -> Result { + match self { + ToStatementType::Statement(s) => Ok(s.clone()), + ToStatementType::Query(s) => client.prepare(s).await, + } + } + } +} + +/// A trait abstracting over prepared and unprepared statements. +/// +/// Many methods are generic over this bound, so that they support both a raw query string as well as a statement which +/// was prepared previously. +/// +/// This trait is "sealed" and cannot be implemented by anything outside this crate. +pub trait ToStatement: Sealed { + #[doc(hidden)] + fn __convert(&self) -> ToStatementType<'_>; +} + +impl ToStatement for Statement { + fn __convert(&self) -> ToStatementType<'_> { + ToStatementType::Statement(self) + } +} + +impl Sealed for Statement {} + +impl ToStatement for str { + fn __convert(&self) -> ToStatementType<'_> { + ToStatementType::Query(self) + } +} + +impl Sealed for str {} + +impl ToStatement for String { + fn __convert(&self) -> ToStatementType<'_> { + ToStatementType::Query(self) + } +} + +impl Sealed for String {} diff --git a/libs/proxy/tokio-postgres2/src/transaction.rs b/libs/proxy/tokio-postgres2/src/transaction.rs new file mode 100644 index 0000000000..03a57e4947 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/transaction.rs @@ -0,0 +1,74 @@ +use crate::codec::FrontendMessage; +use crate::connection::RequestMessages; +use crate::query::RowStream; +use crate::{CancelToken, Client, Error, ReadyForQueryStatus}; +use postgres_protocol2::message::frontend; + +/// A representation of a PostgreSQL database transaction. +/// +/// Transactions will implicitly roll back when dropped. Use the `commit` method to commit the changes made in the +/// transaction. Transactions can be nested, with inner transactions implemented via safepoints. +pub struct Transaction<'a> { + client: &'a mut Client, + done: bool, +} + +impl Drop for Transaction<'_> { + fn drop(&mut self) { + if self.done { + return; + } + + let buf = self.client.inner().with_buf(|buf| { + frontend::query("ROLLBACK", buf).unwrap(); + buf.split().freeze() + }); + let _ = self + .client + .inner() + .send(RequestMessages::Single(FrontendMessage::Raw(buf))); + } +} + +impl<'a> Transaction<'a> { + pub(crate) fn new(client: &'a mut Client) -> Transaction<'a> { + Transaction { + client, + done: false, + } + } + + /// Consumes the transaction, committing all changes made within it. + pub async fn commit(mut self) -> Result { + self.done = true; + self.client.batch_execute("COMMIT").await + } + + /// Rolls the transaction back, discarding all changes made within it. + /// + /// This is equivalent to `Transaction`'s `Drop` implementation, but provides any error encountered to the caller. + pub async fn rollback(mut self) -> Result { + self.done = true; + self.client.batch_execute("ROLLBACK").await + } + + /// Like `Client::query_raw_txt`. + pub async fn query_raw_txt(&self, statement: &str, params: I) -> Result + where + S: AsRef, + I: IntoIterator>, + I::IntoIter: ExactSizeIterator, + { + self.client.query_raw_txt(statement, params).await + } + + /// Like `Client::cancel_token`. + pub fn cancel_token(&self) -> CancelToken { + self.client.cancel_token() + } + + /// Returns a reference to the underlying `Client`. + pub fn client(&self) -> &Client { + self.client + } +} diff --git a/libs/proxy/tokio-postgres2/src/transaction_builder.rs b/libs/proxy/tokio-postgres2/src/transaction_builder.rs new file mode 100644 index 0000000000..9718ac588c --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/transaction_builder.rs @@ -0,0 +1,113 @@ +use crate::{Client, Error, Transaction}; + +/// The isolation level of a database transaction. +#[derive(Debug, Copy, Clone)] +#[non_exhaustive] +pub enum IsolationLevel { + /// Equivalent to `ReadCommitted`. + ReadUncommitted, + + /// An individual statement in the transaction will see rows committed before it began. + ReadCommitted, + + /// All statements in the transaction will see the same view of rows committed before the first query in the + /// transaction. + RepeatableRead, + + /// The reads and writes in this transaction must be able to be committed as an atomic "unit" with respect to reads + /// and writes of all other concurrent serializable transactions without interleaving. + Serializable, +} + +/// A builder for database transactions. +pub struct TransactionBuilder<'a> { + client: &'a mut Client, + isolation_level: Option, + read_only: Option, + deferrable: Option, +} + +impl<'a> TransactionBuilder<'a> { + pub(crate) fn new(client: &'a mut Client) -> TransactionBuilder<'a> { + TransactionBuilder { + client, + isolation_level: None, + read_only: None, + deferrable: None, + } + } + + /// Sets the isolation level of the transaction. + pub fn isolation_level(mut self, isolation_level: IsolationLevel) -> Self { + self.isolation_level = Some(isolation_level); + self + } + + /// Sets the access mode of the transaction. + pub fn read_only(mut self, read_only: bool) -> Self { + self.read_only = Some(read_only); + self + } + + /// Sets the deferrability of the transaction. + /// + /// If the transaction is also serializable and read only, creation of the transaction may block, but when it + /// completes the transaction is able to run with less overhead and a guarantee that it will not be aborted due to + /// serialization failure. + pub fn deferrable(mut self, deferrable: bool) -> Self { + self.deferrable = Some(deferrable); + self + } + + /// Begins the transaction. + /// + /// The transaction will roll back by default - use the `commit` method to commit it. + pub async fn start(self) -> Result, Error> { + let mut query = "START TRANSACTION".to_string(); + let mut first = true; + + if let Some(level) = self.isolation_level { + first = false; + + query.push_str(" ISOLATION LEVEL "); + let level = match level { + IsolationLevel::ReadUncommitted => "READ UNCOMMITTED", + IsolationLevel::ReadCommitted => "READ COMMITTED", + IsolationLevel::RepeatableRead => "REPEATABLE READ", + IsolationLevel::Serializable => "SERIALIZABLE", + }; + query.push_str(level); + } + + if let Some(read_only) = self.read_only { + if !first { + query.push(','); + } + first = false; + + let s = if read_only { + " READ ONLY" + } else { + " READ WRITE" + }; + query.push_str(s); + } + + if let Some(deferrable) = self.deferrable { + if !first { + query.push(','); + } + + let s = if deferrable { + " DEFERRABLE" + } else { + " NOT DEFERRABLE" + }; + query.push_str(s); + } + + self.client.batch_execute(&query).await?; + + Ok(Transaction::new(self.client)) + } +} diff --git a/libs/proxy/tokio-postgres2/src/types.rs b/libs/proxy/tokio-postgres2/src/types.rs new file mode 100644 index 0000000000..e571d7ee00 --- /dev/null +++ b/libs/proxy/tokio-postgres2/src/types.rs @@ -0,0 +1,6 @@ +//! Types. +//! +//! This module is a reexport of the `postgres_types` crate. + +#[doc(inline)] +pub use postgres_types2::*; diff --git a/libs/remote_storage/src/azure_blob.rs b/libs/remote_storage/src/azure_blob.rs index 1c0d43d479..840917ef68 100644 --- a/libs/remote_storage/src/azure_blob.rs +++ b/libs/remote_storage/src/azure_blob.rs @@ -24,6 +24,7 @@ use azure_storage_blobs::{blob::operations::GetBlobBuilder, prelude::ContainerCl use bytes::Bytes; use futures::future::Either; use futures::stream::Stream; +use futures::FutureExt; use futures_util::StreamExt; use futures_util::TryStreamExt; use http_types::{StatusCode, Url}; @@ -31,6 +32,7 @@ use scopeguard::ScopeGuard; use tokio_util::sync::CancellationToken; use tracing::debug; use utils::backoff; +use utils::backoff::exponential_backoff_duration_seconds; use crate::metrics::{start_measuring_requests, AttemptOutcome, RequestKind}; use crate::{ @@ -218,6 +220,11 @@ impl AzureBlobStorage { let started_at = ScopeGuard::into_inner(started_at); let outcome = match &download { Ok(_) => AttemptOutcome::Ok, + // At this level in the stack 404 and 304 responses do not indicate an error. + // There's expected cases when a blob may not exist or hasn't been modified since + // the last get (e.g. probing for timeline indices and heatmap downloads). + // Callers should handle errors if they are unexpected. + Err(DownloadError::NotFound | DownloadError::Unmodified) => AttemptOutcome::Ok, Err(_) => AttemptOutcome::Err, }; crate::metrics::BUCKET_METRICS @@ -302,40 +309,59 @@ impl RemoteStorage for AzureBlobStorage { let mut next_marker = None; + let mut timeout_try_cnt = 1; + 'outer: loop { let mut builder = builder.clone(); if let Some(marker) = next_marker.clone() { builder = builder.marker(marker); } - let response = builder.into_stream(); - let response = response.into_stream().map_err(to_download_error); - let response = tokio_stream::StreamExt::timeout(response, self.timeout); - let response = response.map(|res| match res { - Ok(res) => res, - Err(_elapsed) => Err(DownloadError::Timeout), + // Azure Blob Rust SDK does not expose the list blob API directly. Users have to use + // their pageable iterator wrapper that returns all keys as a stream. We want to have + // full control of paging, and therefore we only take the first item from the stream. + let mut response_stream = builder.into_stream(); + let response = response_stream.next(); + // Timeout mechanism: Azure client will sometimes stuck on a request, but retrying that request + // would immediately succeed. Therefore, we use exponential backoff timeout to retry the request. + // (Usually, exponential backoff is used to determine the sleep time between two retries.) We + // start with 10.0 second timeout, and double the timeout for each failure, up to 5 failures. + // timeout = min(5 * (1.0+1.0)^n, self.timeout). + let this_timeout = (5.0 * exponential_backoff_duration_seconds(timeout_try_cnt, 1.0, self.timeout.as_secs_f64())).min(self.timeout.as_secs_f64()); + let response = tokio::time::timeout(Duration::from_secs_f64(this_timeout), response); + let response = response.map(|res| { + match res { + Ok(Some(Ok(res))) => Ok(Some(res)), + Ok(Some(Err(e))) => Err(to_download_error(e)), + Ok(None) => Ok(None), + Err(_elasped) => Err(DownloadError::Timeout), + } }); - - let mut response = std::pin::pin!(response); - let mut max_keys = max_keys.map(|mk| mk.get()); let next_item = tokio::select! { - op = response.next() => Ok(op), + op = response => op, _ = cancel.cancelled() => Err(DownloadError::Cancelled), - }?; + }; + + if let Err(DownloadError::Timeout) = &next_item { + timeout_try_cnt += 1; + if timeout_try_cnt <= 5 { + continue; + } + } + + let next_item = next_item?; + + if timeout_try_cnt >= 2 { + tracing::warn!("Azure Blob Storage list timed out and succeeded after {} tries", timeout_try_cnt); + } + timeout_try_cnt = 1; + let Some(entry) = next_item else { // The list is complete, so yield it. break; }; let mut res = Listing::default(); - let entry = match entry { - Ok(entry) => entry, - Err(e) => { - // The error is potentially retryable, so we must rewind the loop after yielding. - yield Err(e); - continue; - } - }; next_marker = entry.continuation(); let prefix_iter = entry .blobs @@ -351,7 +377,7 @@ impl RemoteStorage for AzureBlobStorage { last_modified: k.properties.last_modified.into(), size: k.properties.content_length, } - ); + ); for key in blob_iter { res.keys.push(key); diff --git a/libs/remote_storage/src/local_fs.rs b/libs/remote_storage/src/local_fs.rs index 553153826e..ee2fc9d6e2 100644 --- a/libs/remote_storage/src/local_fs.rs +++ b/libs/remote_storage/src/local_fs.rs @@ -360,7 +360,12 @@ impl RemoteStorage for LocalFs { let mut objects = Vec::with_capacity(keys.len()); for key in keys { let path = key.with_base(&self.storage_root); - let metadata = file_metadata(&path).await?; + let metadata = file_metadata(&path).await; + if let Err(DownloadError::NotFound) = metadata { + // Race: if the file is deleted between listing and metadata check, ignore it. + continue; + } + let metadata = metadata?; if metadata.is_dir() { continue; } diff --git a/libs/remote_storage/src/metrics.rs b/libs/remote_storage/src/metrics.rs index f1aa4c433b..48c121fbc8 100644 --- a/libs/remote_storage/src/metrics.rs +++ b/libs/remote_storage/src/metrics.rs @@ -176,7 +176,9 @@ pub(crate) struct BucketMetrics { impl Default for BucketMetrics { fn default() -> Self { - let buckets = [0.01, 0.10, 0.5, 1.0, 5.0, 10.0, 50.0, 100.0]; + // first bucket 100 microseconds to count requests that do not need to wait at all + // and get a permit immediately + let buckets = [0.0001, 0.01, 0.10, 0.5, 1.0, 5.0, 10.0, 50.0, 100.0]; let req_seconds = register_histogram_vec!( "remote_storage_s3_request_seconds", diff --git a/libs/utils/Cargo.toml b/libs/utils/Cargo.toml index 545317f958..a52d953d66 100644 --- a/libs/utils/Cargo.toml +++ b/libs/utils/Cargo.toml @@ -19,6 +19,7 @@ bincode.workspace = true bytes.workspace = true camino.workspace = true chrono.workspace = true +diatomic-waker.workspace = true git-version.workspace = true hex = { workspace = true, features = ["serde"] } humantime.workspace = true @@ -29,9 +30,11 @@ jsonwebtoken.workspace = true nix.workspace = true once_cell.workspace = true pin-project-lite.workspace = true +pprof.workspace = true regex.workspace = true routerify.workspace = true serde.workspace = true +serde_with.workspace = true serde_json.workspace = true signal-hook.workspace = true thiserror.workspace = true diff --git a/libs/utils/src/http/endpoint.rs b/libs/utils/src/http/endpoint.rs index 8ee5abd434..6a85f0ddeb 100644 --- a/libs/utils/src/http/endpoint.rs +++ b/libs/utils/src/http/endpoint.rs @@ -1,7 +1,8 @@ use crate::auth::{AuthError, Claims, SwappableJwtAuth}; use crate::http::error::{api_error_handler, route_error_handler, ApiError}; -use anyhow::Context; -use hyper::header::{HeaderName, AUTHORIZATION}; +use crate::http::request::{get_query_param, parse_query_param}; +use anyhow::{anyhow, Context}; +use hyper::header::{HeaderName, AUTHORIZATION, CONTENT_DISPOSITION}; use hyper::http::HeaderValue; use hyper::Method; use hyper::{header::CONTENT_TYPE, Body, Request, Response}; @@ -12,11 +13,13 @@ use routerify::{Middleware, RequestInfo, Router, RouterBuilder}; use tracing::{debug, info, info_span, warn, Instrument}; use std::future::Future; +use std::io::Write as _; use std::str::FromStr; +use std::time::Duration; use bytes::{Bytes, BytesMut}; -use std::io::Write as _; -use tokio::sync::mpsc; +use pprof::protos::Message as _; +use tokio::sync::{mpsc, Mutex}; use tokio_stream::wrappers::ReceiverStream; static SERVE_METRICS_COUNT: Lazy = Lazy::new(|| { @@ -328,6 +331,82 @@ pub async fn prometheus_metrics_handler(_req: Request) -> Result) -> Result, ApiError> { + enum Format { + Pprof, + Svg, + } + + // Parameters. + let format = match get_query_param(&req, "format")?.as_deref() { + None => Format::Pprof, + Some("pprof") => Format::Pprof, + Some("svg") => Format::Svg, + Some(format) => return Err(ApiError::BadRequest(anyhow!("invalid format {format}"))), + }; + let seconds = match parse_query_param(&req, "seconds")? { + None => 5, + Some(seconds @ 1..=30) => seconds, + Some(_) => return Err(ApiError::BadRequest(anyhow!("duration must be 1-30 secs"))), + }; + let frequency_hz = match parse_query_param(&req, "frequency")? { + None => 99, + Some(1001..) => return Err(ApiError::BadRequest(anyhow!("frequency must be <=1000 Hz"))), + Some(frequency) => frequency, + }; + + // Only allow one profiler at a time. + static PROFILE_LOCK: Lazy> = Lazy::new(|| Mutex::new(())); + let _lock = PROFILE_LOCK + .try_lock() + .map_err(|_| ApiError::Conflict("profiler already running".into()))?; + + // Take the profile. + let report = tokio::task::spawn_blocking(move || { + let guard = pprof::ProfilerGuardBuilder::default() + .frequency(frequency_hz) + .blocklist(&["libc", "libgcc", "pthread", "vdso"]) + .build()?; + std::thread::sleep(Duration::from_secs(seconds)); + guard.report().build() + }) + .await + .map_err(|join_err| ApiError::InternalServerError(join_err.into()))? + .map_err(|pprof_err| ApiError::InternalServerError(pprof_err.into()))?; + + // Return the report in the requested format. + match format { + Format::Pprof => { + let mut body = Vec::new(); + report + .pprof() + .map_err(|err| ApiError::InternalServerError(err.into()))? + .write_to_vec(&mut body) + .map_err(|err| ApiError::InternalServerError(err.into()))?; + + Response::builder() + .status(200) + .header(CONTENT_TYPE, "application/octet-stream") + .header(CONTENT_DISPOSITION, "attachment; filename=\"profile.pb\"") + .body(Body::from(body)) + .map_err(|err| ApiError::InternalServerError(err.into())) + } + + Format::Svg => { + let mut body = Vec::new(); + report + .flamegraph(&mut body) + .map_err(|err| ApiError::InternalServerError(err.into()))?; + Response::builder() + .status(200) + .header(CONTENT_TYPE, "image/svg+xml") + .body(Body::from(body)) + .map_err(|err| ApiError::InternalServerError(err.into())) + } + } +} + pub fn add_request_id_middleware( ) -> Middleware { Middleware::pre(move |req| async move { diff --git a/libs/utils/src/http/request.rs b/libs/utils/src/http/request.rs index 8b8ed5a67f..7ea71685ec 100644 --- a/libs/utils/src/http/request.rs +++ b/libs/utils/src/http/request.rs @@ -30,7 +30,7 @@ pub fn parse_request_param( } } -fn get_query_param<'a>( +pub fn get_query_param<'a>( request: &'a Request, param_name: &str, ) -> Result>, ApiError> { diff --git a/libs/utils/src/postgres_client.rs b/libs/utils/src/postgres_client.rs index dba74f5b0b..a62568202b 100644 --- a/libs/utils/src/postgres_client.rs +++ b/libs/utils/src/postgres_client.rs @@ -7,29 +7,88 @@ use postgres_connection::{parse_host_port, PgConnectionConfig}; use crate::id::TenantTimelineId; +#[derive(Copy, Clone, PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum InterpretedFormat { + Bincode, + Protobuf, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum Compression { + Zstd { level: i8 }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +#[serde(tag = "type", content = "args")] +#[serde(rename_all = "kebab-case")] +pub enum PostgresClientProtocol { + /// Usual Postgres replication protocol + Vanilla, + /// Custom shard-aware protocol that replicates interpreted records. + /// Used to send wal from safekeeper to pageserver. + Interpreted { + format: InterpretedFormat, + compression: Option, + }, +} + +pub struct ConnectionConfigArgs<'a> { + pub protocol: PostgresClientProtocol, + + pub ttid: TenantTimelineId, + pub shard_number: Option, + pub shard_count: Option, + pub shard_stripe_size: Option, + + pub listen_pg_addr_str: &'a str, + + pub auth_token: Option<&'a str>, + pub availability_zone: Option<&'a str>, +} + +impl<'a> ConnectionConfigArgs<'a> { + fn options(&'a self) -> Vec { + let mut options = vec![ + "-c".to_owned(), + format!("timeline_id={}", self.ttid.timeline_id), + format!("tenant_id={}", self.ttid.tenant_id), + format!( + "protocol={}", + serde_json::to_string(&self.protocol).unwrap() + ), + ]; + + if self.shard_number.is_some() { + assert!(self.shard_count.is_some()); + assert!(self.shard_stripe_size.is_some()); + + options.push(format!("shard_count={}", self.shard_count.unwrap())); + options.push(format!("shard_number={}", self.shard_number.unwrap())); + options.push(format!( + "shard_stripe_size={}", + self.shard_stripe_size.unwrap() + )); + } + + options + } +} + /// Create client config for fetching WAL from safekeeper on particular timeline. /// listen_pg_addr_str is in form host:\[port\]. pub fn wal_stream_connection_config( - TenantTimelineId { - tenant_id, - timeline_id, - }: TenantTimelineId, - listen_pg_addr_str: &str, - auth_token: Option<&str>, - availability_zone: Option<&str>, + args: ConnectionConfigArgs, ) -> anyhow::Result { let (host, port) = - parse_host_port(listen_pg_addr_str).context("Unable to parse listen_pg_addr_str")?; + parse_host_port(args.listen_pg_addr_str).context("Unable to parse listen_pg_addr_str")?; let port = port.unwrap_or(5432); let mut connstr = PgConnectionConfig::new_host_port(host, port) - .extend_options([ - "-c".to_owned(), - format!("timeline_id={}", timeline_id), - format!("tenant_id={}", tenant_id), - ]) - .set_password(auth_token.map(|s| s.to_owned())); + .extend_options(args.options()) + .set_password(args.auth_token.map(|s| s.to_owned())); - if let Some(availability_zone) = availability_zone { + if let Some(availability_zone) = args.availability_zone { connstr = connstr.extend_options([format!("availability_zone={}", availability_zone)]); } diff --git a/libs/utils/src/seqwait.rs b/libs/utils/src/seqwait.rs index 375b227b99..d99dc25769 100644 --- a/libs/utils/src/seqwait.rs +++ b/libs/utils/src/seqwait.rs @@ -83,7 +83,9 @@ where } wake_these.push(self.heap.pop().unwrap().wake_channel); } - self.update_status(); + if !wake_these.is_empty() { + self.update_status(); + } wake_these } diff --git a/libs/utils/src/sync.rs b/libs/utils/src/sync.rs index 2ee8f35449..7aa26e24bc 100644 --- a/libs/utils/src/sync.rs +++ b/libs/utils/src/sync.rs @@ -1,3 +1,5 @@ pub mod heavier_once_cell; pub mod gate; + +pub mod spsc_fold; diff --git a/libs/utils/src/sync/spsc_fold.rs b/libs/utils/src/sync/spsc_fold.rs new file mode 100644 index 0000000000..a33f8097fc --- /dev/null +++ b/libs/utils/src/sync/spsc_fold.rs @@ -0,0 +1,404 @@ +use core::{future::poll_fn, task::Poll}; +use std::sync::{Arc, Mutex}; + +use diatomic_waker::DiatomicWaker; + +pub struct Sender { + state: Arc>, +} + +pub struct Receiver { + state: Arc>, +} + +struct Inner { + wake_receiver: DiatomicWaker, + wake_sender: DiatomicWaker, + value: Mutex>, +} + +enum State { + NoData, + HasData(T), + TryFoldFailed, // transient state + SenderWaitsForReceiverToConsume(T), + SenderGone(Option), + ReceiverGone, + AllGone, + SenderDropping, // transient state + ReceiverDropping, // transient state +} + +pub fn channel() -> (Sender, Receiver) { + let inner = Inner { + wake_receiver: DiatomicWaker::new(), + wake_sender: DiatomicWaker::new(), + value: Mutex::new(State::NoData), + }; + + let state = Arc::new(inner); + ( + Sender { + state: state.clone(), + }, + Receiver { state }, + ) +} + +#[derive(Debug, thiserror::Error)] +pub enum SendError { + #[error("receiver is gone")] + ReceiverGone, +} + +impl Sender { + /// # Panics + /// + /// If `try_fold` panics, any subsequent call to `send` panic. + pub async fn send(&mut self, value: T, try_fold: F) -> Result<(), SendError> + where + F: Fn(&mut T, T) -> Result<(), T>, + { + let mut value = Some(value); + poll_fn(|cx| { + let mut guard = self.state.value.lock().unwrap(); + match &mut *guard { + State::NoData => { + *guard = State::HasData(value.take().unwrap()); + self.state.wake_receiver.notify(); + Poll::Ready(Ok(())) + } + State::HasData(_) => { + let State::HasData(acc_mut) = &mut *guard else { + unreachable!("this match arm guarantees that the guard is HasData"); + }; + match try_fold(acc_mut, value.take().unwrap()) { + Ok(()) => { + // no need to wake receiver, if it was waiting it already + // got a wake-up when we transitioned from NoData to HasData + Poll::Ready(Ok(())) + } + Err(unfoldable_value) => { + value = Some(unfoldable_value); + let State::HasData(acc) = + std::mem::replace(&mut *guard, State::TryFoldFailed) + else { + unreachable!("this match arm guarantees that the guard is HasData"); + }; + *guard = State::SenderWaitsForReceiverToConsume(acc); + // SAFETY: send is single threaded due to `&mut self` requirement, + // therefore register is not concurrent. + unsafe { + self.state.wake_sender.register(cx.waker()); + } + Poll::Pending + } + } + } + State::SenderWaitsForReceiverToConsume(_data) => { + // Really, we shouldn't be polled until receiver has consumed and wakes us. + Poll::Pending + } + State::ReceiverGone => Poll::Ready(Err(SendError::ReceiverGone)), + State::SenderGone(_) + | State::AllGone + | State::SenderDropping + | State::ReceiverDropping + | State::TryFoldFailed => { + unreachable!(); + } + } + }) + .await + } +} + +impl Drop for Sender { + fn drop(&mut self) { + let Ok(mut guard) = self.state.value.lock() else { + return; + }; + *guard = match std::mem::replace(&mut *guard, State::SenderDropping) { + State::NoData => State::SenderGone(None), + State::HasData(data) | State::SenderWaitsForReceiverToConsume(data) => { + State::SenderGone(Some(data)) + } + State::ReceiverGone => State::AllGone, + State::TryFoldFailed + | State::SenderGone(_) + | State::AllGone + | State::SenderDropping + | State::ReceiverDropping => { + unreachable!("unreachable state {:?}", guard.discriminant_str()) + } + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum RecvError { + #[error("sender is gone")] + SenderGone, +} + +impl Receiver { + pub async fn recv(&mut self) -> Result { + poll_fn(|cx| { + let mut guard = self.state.value.lock().unwrap(); + match &mut *guard { + State::NoData => { + // SAFETY: recv is single threaded due to `&mut self` requirement, + // therefore register is not concurrent. + unsafe { + self.state.wake_receiver.register(cx.waker()); + } + Poll::Pending + } + guard @ State::HasData(_) + | guard @ State::SenderWaitsForReceiverToConsume(_) + | guard @ State::SenderGone(Some(_)) => { + let data = guard + .take_data() + .expect("in these states, data is guaranteed to be present"); + self.state.wake_sender.notify(); + Poll::Ready(Ok(data)) + } + State::SenderGone(None) => Poll::Ready(Err(RecvError::SenderGone)), + State::ReceiverGone + | State::AllGone + | State::SenderDropping + | State::ReceiverDropping + | State::TryFoldFailed => { + unreachable!("unreachable state {:?}", guard.discriminant_str()); + } + } + }) + .await + } +} + +impl Drop for Receiver { + fn drop(&mut self) { + let Ok(mut guard) = self.state.value.lock() else { + return; + }; + *guard = match std::mem::replace(&mut *guard, State::ReceiverDropping) { + State::NoData => State::ReceiverGone, + State::HasData(_) | State::SenderWaitsForReceiverToConsume(_) => State::ReceiverGone, + State::SenderGone(_) => State::AllGone, + State::TryFoldFailed + | State::ReceiverGone + | State::AllGone + | State::SenderDropping + | State::ReceiverDropping => { + unreachable!("unreachable state {:?}", guard.discriminant_str()) + } + } + } +} + +impl State { + fn take_data(&mut self) -> Option { + match self { + State::HasData(_) => { + let State::HasData(data) = std::mem::replace(self, State::NoData) else { + unreachable!("this match arm guarantees that the state is HasData"); + }; + Some(data) + } + State::SenderWaitsForReceiverToConsume(_) => { + let State::SenderWaitsForReceiverToConsume(data) = + std::mem::replace(self, State::NoData) + else { + unreachable!( + "this match arm guarantees that the state is SenderWaitsForReceiverToConsume" + ); + }; + Some(data) + } + State::SenderGone(data) => Some(data.take().unwrap()), + State::NoData + | State::TryFoldFailed + | State::ReceiverGone + | State::AllGone + | State::SenderDropping + | State::ReceiverDropping => None, + } + } + fn discriminant_str(&self) -> &'static str { + match self { + State::NoData => "NoData", + State::HasData(_) => "HasData", + State::TryFoldFailed => "TryFoldFailed", + State::SenderWaitsForReceiverToConsume(_) => "SenderWaitsForReceiverToConsume", + State::SenderGone(_) => "SenderGone", + State::ReceiverGone => "ReceiverGone", + State::AllGone => "AllGone", + State::SenderDropping => "SenderDropping", + State::ReceiverDropping => "ReceiverDropping", + } + } +} + +#[cfg(test)] +mod tests { + + use super::*; + + const FOREVER: std::time::Duration = std::time::Duration::from_secs(u64::MAX); + + #[tokio::test] + async fn test_send_recv() { + let (mut sender, mut receiver) = channel(); + + sender + .send(42, |acc, val| { + *acc += val; + Ok(()) + }) + .await + .unwrap(); + + let received = receiver.recv().await.unwrap(); + assert_eq!(received, 42); + } + + #[tokio::test] + async fn test_send_recv_with_fold() { + let (mut sender, mut receiver) = channel(); + + sender + .send(1, |acc, val| { + *acc += val; + Ok(()) + }) + .await + .unwrap(); + sender + .send(2, |acc, val| { + *acc += val; + Ok(()) + }) + .await + .unwrap(); + + let received = receiver.recv().await.unwrap(); + assert_eq!(received, 3); + } + + #[tokio::test(start_paused = true)] + async fn test_sender_waits_for_receiver_if_try_fold_fails() { + let (mut sender, mut receiver) = channel(); + + sender.send(23, |_, _| panic!("first send")).await.unwrap(); + + let send_fut = sender.send(42, |_, val| Err(val)); + let mut send_fut = std::pin::pin!(send_fut); + + tokio::select! { + _ = tokio::time::sleep(FOREVER) => {}, + _ = &mut send_fut => { + panic!("send should not complete"); + }, + } + + let val = receiver.recv().await.unwrap(); + assert_eq!(val, 23); + + tokio::select! { + _ = tokio::time::sleep(FOREVER) => { + panic!("receiver should have consumed the value"); + }, + _ = &mut send_fut => { }, + } + + let val = receiver.recv().await.unwrap(); + assert_eq!(val, 42); + } + + #[tokio::test(start_paused = true)] + async fn test_sender_errors_if_waits_for_receiver_and_receiver_drops() { + let (mut sender, receiver) = channel(); + + sender.send(23, |_, _| unreachable!()).await.unwrap(); + + let send_fut = sender.send(42, |_, val| Err(val)); + let send_fut = std::pin::pin!(send_fut); + + drop(receiver); + + let result = send_fut.await; + assert!(matches!(result, Err(SendError::ReceiverGone))); + } + + #[tokio::test(start_paused = true)] + async fn test_receiver_errors_if_waits_for_sender_and_sender_drops() { + let (sender, mut receiver) = channel::<()>(); + + let recv_fut = receiver.recv(); + let recv_fut = std::pin::pin!(recv_fut); + + drop(sender); + + let result = recv_fut.await; + assert!(matches!(result, Err(RecvError::SenderGone))); + } + + #[tokio::test(start_paused = true)] + async fn test_receiver_errors_if_waits_for_sender_and_sender_drops_with_data() { + let (mut sender, mut receiver) = channel(); + + sender.send(42, |_, _| unreachable!()).await.unwrap(); + + { + let recv_fut = receiver.recv(); + let recv_fut = std::pin::pin!(recv_fut); + + drop(sender); + + let val = recv_fut.await.unwrap(); + assert_eq!(val, 42); + } + + let result = receiver.recv().await; + assert!(matches!(result, Err(RecvError::SenderGone))); + } + + #[tokio::test(start_paused = true)] + async fn test_receiver_waits_for_sender_if_no_data() { + let (mut sender, mut receiver) = channel(); + + let recv_fut = receiver.recv(); + let mut recv_fut = std::pin::pin!(recv_fut); + + tokio::select! { + _ = tokio::time::sleep(FOREVER) => {}, + _ = &mut recv_fut => { + panic!("recv should not complete"); + }, + } + + sender.send(42, |_, _| Ok(())).await.unwrap(); + + let val = recv_fut.await.unwrap(); + assert_eq!(val, 42); + } + + #[tokio::test] + async fn test_receiver_gone_while_nodata() { + let (mut sender, receiver) = channel(); + drop(receiver); + + let result = sender.send(42, |_, _| Ok(())).await; + assert!(matches!(result, Err(SendError::ReceiverGone))); + } + + #[tokio::test] + async fn test_sender_gone_while_nodata() { + let (sender, mut receiver) = super::channel::(); + drop(sender); + + let result = receiver.recv().await; + assert!(matches!(result, Err(RecvError::SenderGone))); + } +} diff --git a/libs/vm_monitor/src/cgroup.rs b/libs/vm_monitor/src/cgroup.rs index 3223765016..1d70cedcf9 100644 --- a/libs/vm_monitor/src/cgroup.rs +++ b/libs/vm_monitor/src/cgroup.rs @@ -218,7 +218,7 @@ impl MemoryStatus { fn debug_slice(slice: &[Self]) -> impl '_ + Debug { struct DS<'a>(&'a [MemoryStatus]); - impl<'a> Debug for DS<'a> { + impl Debug for DS<'_> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { f.debug_struct("[MemoryStatus]") .field( @@ -233,7 +233,7 @@ impl MemoryStatus { struct Fields<'a, F>(&'a [MemoryStatus], F); - impl<'a, F: Fn(&MemoryStatus) -> T, T: Debug> Debug for Fields<'a, F> { + impl T, T: Debug> Debug for Fields<'_, F> { fn fmt(&self, f: &mut Formatter) -> fmt::Result { f.debug_list().entries(self.0.iter().map(&self.1)).finish() } diff --git a/libs/wal_decoder/Cargo.toml b/libs/wal_decoder/Cargo.toml index c8c0f4c990..8fac4e38ca 100644 --- a/libs/wal_decoder/Cargo.toml +++ b/libs/wal_decoder/Cargo.toml @@ -8,11 +8,19 @@ license.workspace = true testing = ["pageserver_api/testing"] [dependencies] +async-compression.workspace = true anyhow.workspace = true bytes.workspace = true pageserver_api.workspace = true +prost.workspace = true postgres_ffi.workspace = true serde.workspace = true +thiserror.workspace = true +tokio = { workspace = true, features = ["io-util"] } +tonic.workspace = true tracing.workspace = true utils.workspace = true workspace_hack = { version = "0.1", path = "../../workspace_hack" } + +[build-dependencies] +tonic-build.workspace = true diff --git a/libs/wal_decoder/build.rs b/libs/wal_decoder/build.rs new file mode 100644 index 0000000000..d5b7ad02ad --- /dev/null +++ b/libs/wal_decoder/build.rs @@ -0,0 +1,11 @@ +fn main() -> Result<(), Box> { + // Generate rust code from .proto protobuf. + // + // Note: we previously tried to use deterministic location at proto/ for + // easy location, but apparently interference with cachepot sometimes fails + // the build then. Anyway, per cargo docs build script shouldn't output to + // anywhere but $OUT_DIR. + tonic_build::compile_protos("proto/interpreted_wal.proto") + .unwrap_or_else(|e| panic!("failed to compile protos {:?}", e)); + Ok(()) +} diff --git a/libs/wal_decoder/proto/interpreted_wal.proto b/libs/wal_decoder/proto/interpreted_wal.proto new file mode 100644 index 0000000000..0393392c1a --- /dev/null +++ b/libs/wal_decoder/proto/interpreted_wal.proto @@ -0,0 +1,43 @@ +syntax = "proto3"; + +package interpreted_wal; + +message InterpretedWalRecords { + repeated InterpretedWalRecord records = 1; + optional uint64 next_record_lsn = 2; +} + +message InterpretedWalRecord { + optional bytes metadata_record = 1; + SerializedValueBatch batch = 2; + uint64 next_record_lsn = 3; + bool flush_uncommitted = 4; + uint32 xid = 5; +} + +message SerializedValueBatch { + bytes raw = 1; + repeated ValueMeta metadata = 2; + uint64 max_lsn = 3; + uint64 len = 4; +} + +enum ValueMetaType { + Serialized = 0; + Observed = 1; +} + +message ValueMeta { + ValueMetaType type = 1; + CompactKey key = 2; + uint64 lsn = 3; + optional uint64 batch_offset = 4; + optional uint64 len = 5; + optional bool will_init = 6; +} + +message CompactKey { + int64 high = 1; + int64 low = 2; +} + diff --git a/libs/wal_decoder/src/decoder.rs b/libs/wal_decoder/src/decoder.rs index 1895f25bfc..36c4b19266 100644 --- a/libs/wal_decoder/src/decoder.rs +++ b/libs/wal_decoder/src/decoder.rs @@ -4,6 +4,7 @@ use crate::models::*; use crate::serialized_batch::SerializedValueBatch; use bytes::{Buf, Bytes}; +use pageserver_api::key::rel_block_to_key; use pageserver_api::reltag::{RelTag, SlruKind}; use pageserver_api::shard::ShardIdentity; use postgres_ffi::pg_constants; @@ -32,7 +33,8 @@ impl InterpretedWalRecord { FlushUncommittedRecords::No }; - let metadata_record = MetadataRecord::from_decoded(&decoded, next_record_lsn, pg_version)?; + let metadata_record = + MetadataRecord::from_decoded_filtered(&decoded, shard, next_record_lsn, pg_version)?; let batch = SerializedValueBatch::from_decoded_filtered( decoded, shard, @@ -51,8 +53,13 @@ impl InterpretedWalRecord { } impl MetadataRecord { - fn from_decoded( + /// Builds a metadata record for this WAL record, if any. + /// + /// Only metadata records relevant for the given shard are emitted. Currently, most metadata + /// records are broadcast to all shards for simplicity, but this should be improved. + fn from_decoded_filtered( decoded: &DecodedWALRecord, + shard: &ShardIdentity, next_record_lsn: Lsn, pg_version: u32, ) -> anyhow::Result> { @@ -61,26 +68,27 @@ impl MetadataRecord { let mut buf = decoded.record.clone(); buf.advance(decoded.main_data_offset); - match decoded.xl_rmid { + // First, generate metadata records from the decoded WAL record. + let mut metadata_record = match decoded.xl_rmid { pg_constants::RM_HEAP_ID | pg_constants::RM_HEAP2_ID => { - Self::decode_heapam_record(&mut buf, decoded, pg_version) + Self::decode_heapam_record(&mut buf, decoded, pg_version)? } - pg_constants::RM_NEON_ID => Self::decode_neonmgr_record(&mut buf, decoded, pg_version), + pg_constants::RM_NEON_ID => Self::decode_neonmgr_record(&mut buf, decoded, pg_version)?, // Handle other special record types - pg_constants::RM_SMGR_ID => Self::decode_smgr_record(&mut buf, decoded), - pg_constants::RM_DBASE_ID => Self::decode_dbase_record(&mut buf, decoded, pg_version), + pg_constants::RM_SMGR_ID => Self::decode_smgr_record(&mut buf, decoded)?, + pg_constants::RM_DBASE_ID => Self::decode_dbase_record(&mut buf, decoded, pg_version)?, pg_constants::RM_TBLSPC_ID => { tracing::trace!("XLOG_TBLSPC_CREATE/DROP is not handled yet"); - Ok(None) + None } - pg_constants::RM_CLOG_ID => Self::decode_clog_record(&mut buf, decoded, pg_version), + pg_constants::RM_CLOG_ID => Self::decode_clog_record(&mut buf, decoded, pg_version)?, pg_constants::RM_XACT_ID => { - Self::decode_xact_record(&mut buf, decoded, next_record_lsn) + Self::decode_xact_record(&mut buf, decoded, next_record_lsn)? } pg_constants::RM_MULTIXACT_ID => { - Self::decode_multixact_record(&mut buf, decoded, pg_version) + Self::decode_multixact_record(&mut buf, decoded, pg_version)? } - pg_constants::RM_RELMAP_ID => Self::decode_relmap_record(&mut buf, decoded), + pg_constants::RM_RELMAP_ID => Self::decode_relmap_record(&mut buf, decoded)?, // This is an odd duck. It needs to go to all shards. // Since it uses the checkpoint image (that's initialized from CHECKPOINT_KEY // in WalIngest::new), we have to send the whole DecodedWalRecord::record to @@ -89,19 +97,48 @@ impl MetadataRecord { // Alternatively, one can make the checkpoint part of the subscription protocol // to the pageserver. This should work fine, but can be done at a later point. pg_constants::RM_XLOG_ID => { - Self::decode_xlog_record(&mut buf, decoded, next_record_lsn) + Self::decode_xlog_record(&mut buf, decoded, next_record_lsn)? } pg_constants::RM_LOGICALMSG_ID => { - Self::decode_logical_message_record(&mut buf, decoded) + Self::decode_logical_message_record(&mut buf, decoded)? } - pg_constants::RM_STANDBY_ID => Self::decode_standby_record(&mut buf, decoded), - pg_constants::RM_REPLORIGIN_ID => Self::decode_replorigin_record(&mut buf, decoded), + pg_constants::RM_STANDBY_ID => Self::decode_standby_record(&mut buf, decoded)?, + pg_constants::RM_REPLORIGIN_ID => Self::decode_replorigin_record(&mut buf, decoded)?, _unexpected => { // TODO: consider failing here instead of blindly doing something without // understanding the protocol - Ok(None) + None + } + }; + + // Next, filter the metadata record by shard. + + // Route VM page updates to the shards that own them. VM pages are stored in the VM fork + // of the main relation. These are sharded and managed just like regular relation pages. + // See: https://github.com/neondatabase/neon/issues/9855 + if let Some( + MetadataRecord::Heapam(HeapamRecord::ClearVmBits(ref mut clear_vm_bits)) + | MetadataRecord::Neonrmgr(NeonrmgrRecord::ClearVmBits(ref mut clear_vm_bits)), + ) = metadata_record + { + let is_local_vm_page = |heap_blk| { + let vm_blk = pg_constants::HEAPBLK_TO_MAPBLOCK(heap_blk); + shard.is_key_local(&rel_block_to_key(clear_vm_bits.vm_rel, vm_blk)) + }; + // Send the old and new VM page updates to their respective shards. + clear_vm_bits.old_heap_blkno = clear_vm_bits + .old_heap_blkno + .filter(|&blkno| is_local_vm_page(blkno)); + clear_vm_bits.new_heap_blkno = clear_vm_bits + .new_heap_blkno + .filter(|&blkno| is_local_vm_page(blkno)); + // If neither VM page belongs to this shard, discard the record. + if clear_vm_bits.old_heap_blkno.is_none() && clear_vm_bits.new_heap_blkno.is_none() { + metadata_record = None } } + + Ok(metadata_record) } fn decode_heapam_record( diff --git a/libs/wal_decoder/src/lib.rs b/libs/wal_decoder/src/lib.rs index a8a26956e6..96b717021f 100644 --- a/libs/wal_decoder/src/lib.rs +++ b/libs/wal_decoder/src/lib.rs @@ -1,3 +1,4 @@ pub mod decoder; pub mod models; pub mod serialized_batch; +pub mod wire_format; diff --git a/libs/wal_decoder/src/models.rs b/libs/wal_decoder/src/models.rs index c69f8c869a..af22de5d95 100644 --- a/libs/wal_decoder/src/models.rs +++ b/libs/wal_decoder/src/models.rs @@ -37,12 +37,32 @@ use utils::lsn::Lsn; use crate::serialized_batch::SerializedValueBatch; +// Code generated by protobuf. +pub mod proto { + // Tonic does derives as `#[derive(Clone, PartialEq, ::prost::Message)]` + // we don't use these types for anything but broker data transmission, + // so it's ok to ignore this one. + #![allow(clippy::derive_partial_eq_without_eq)] + // The generated ValueMeta has a `len` method generate for its `len` field. + #![allow(clippy::len_without_is_empty)] + tonic::include_proto!("interpreted_wal"); +} + #[derive(Serialize, Deserialize)] pub enum FlushUncommittedRecords { Yes, No, } +/// A batch of interpreted WAL records +#[derive(Serialize, Deserialize)] +pub struct InterpretedWalRecords { + pub records: Vec, + // Start LSN of the next record after the batch. + // Note that said record may not belong to the current shard. + pub next_record_lsn: Option, +} + /// An interpreted Postgres WAL record, ready to be handled by the pageserver #[derive(Serialize, Deserialize)] pub struct InterpretedWalRecord { @@ -65,6 +85,18 @@ pub struct InterpretedWalRecord { pub xid: TransactionId, } +impl InterpretedWalRecord { + /// Checks if the WAL record is empty + /// + /// An empty interpreted WAL record has no data or metadata and does not have to be sent to the + /// pageserver. + pub fn is_empty(&self) -> bool { + self.batch.is_empty() + && self.metadata_record.is_none() + && matches!(self.flush_uncommitted, FlushUncommittedRecords::No) + } +} + /// The interpreted part of the Postgres WAL record which requires metadata /// writes to the underlying storage engine. #[derive(Serialize, Deserialize)] diff --git a/libs/wal_decoder/src/serialized_batch.rs b/libs/wal_decoder/src/serialized_batch.rs index 9c0708ebbe..41294da7a0 100644 --- a/libs/wal_decoder/src/serialized_batch.rs +++ b/libs/wal_decoder/src/serialized_batch.rs @@ -496,11 +496,16 @@ impl SerializedValueBatch { } } - /// Checks if the batch is empty - /// - /// A batch is empty when it contains no serialized values. - /// Note that it may still contain observed values. + /// Checks if the batch contains any serialized or observed values pub fn is_empty(&self) -> bool { + !self.has_data() && self.metadata.is_empty() + } + + /// Checks if the batch contains data + /// + /// Note that if this returns false, it may still contain observed values or + /// a metadata record. + pub fn has_data(&self) -> bool { let empty = self.raw.is_empty(); if cfg!(debug_assertions) && empty { @@ -510,7 +515,7 @@ impl SerializedValueBatch { .all(|meta| matches!(meta, ValueMeta::Observed(_)))); } - empty + !empty } /// Returns the number of values serialized in the batch diff --git a/libs/wal_decoder/src/wire_format.rs b/libs/wal_decoder/src/wire_format.rs new file mode 100644 index 0000000000..5a343054c3 --- /dev/null +++ b/libs/wal_decoder/src/wire_format.rs @@ -0,0 +1,356 @@ +use bytes::{BufMut, Bytes, BytesMut}; +use pageserver_api::key::CompactKey; +use prost::{DecodeError, EncodeError, Message}; +use tokio::io::AsyncWriteExt; +use utils::bin_ser::{BeSer, DeserializeError, SerializeError}; +use utils::lsn::Lsn; +use utils::postgres_client::{Compression, InterpretedFormat}; + +use crate::models::{ + FlushUncommittedRecords, InterpretedWalRecord, InterpretedWalRecords, MetadataRecord, +}; + +use crate::serialized_batch::{ + ObservedValueMeta, SerializedValueBatch, SerializedValueMeta, ValueMeta, +}; + +use crate::models::proto; + +#[derive(Debug, thiserror::Error)] +pub enum ToWireFormatError { + #[error("{0}")] + Bincode(#[from] SerializeError), + #[error("{0}")] + Protobuf(#[from] ProtobufSerializeError), + #[error("{0}")] + Compression(#[from] std::io::Error), +} + +#[derive(Debug, thiserror::Error)] +pub enum ProtobufSerializeError { + #[error("{0}")] + MetadataRecord(#[from] SerializeError), + #[error("{0}")] + Encode(#[from] EncodeError), +} + +#[derive(Debug, thiserror::Error)] +pub enum FromWireFormatError { + #[error("{0}")] + Bincode(#[from] DeserializeError), + #[error("{0}")] + Protobuf(#[from] ProtobufDeserializeError), + #[error("{0}")] + Decompress(#[from] std::io::Error), +} + +#[derive(Debug, thiserror::Error)] +pub enum ProtobufDeserializeError { + #[error("{0}")] + Transcode(#[from] TranscodeError), + #[error("{0}")] + Decode(#[from] DecodeError), +} + +#[derive(Debug, thiserror::Error)] +pub enum TranscodeError { + #[error("{0}")] + BadInput(String), + #[error("{0}")] + MetadataRecord(#[from] DeserializeError), +} + +pub trait ToWireFormat { + fn to_wire( + self, + format: InterpretedFormat, + compression: Option, + ) -> impl std::future::Future> + Send; +} + +pub trait FromWireFormat { + type T; + fn from_wire( + buf: &Bytes, + format: InterpretedFormat, + compression: Option, + ) -> impl std::future::Future> + Send; +} + +impl ToWireFormat for InterpretedWalRecords { + async fn to_wire( + self, + format: InterpretedFormat, + compression: Option, + ) -> Result { + use async_compression::tokio::write::ZstdEncoder; + use async_compression::Level; + + let encode_res: Result = match format { + InterpretedFormat::Bincode => { + let buf = BytesMut::new(); + let mut buf = buf.writer(); + self.ser_into(&mut buf)?; + Ok(buf.into_inner().freeze()) + } + InterpretedFormat::Protobuf => { + let proto: proto::InterpretedWalRecords = self.try_into()?; + let mut buf = BytesMut::new(); + proto + .encode(&mut buf) + .map_err(|e| ToWireFormatError::Protobuf(e.into()))?; + + Ok(buf.freeze()) + } + }; + + let buf = encode_res?; + let compressed_buf = match compression { + Some(Compression::Zstd { level }) => { + let mut encoder = ZstdEncoder::with_quality( + Vec::with_capacity(buf.len() / 4), + Level::Precise(level as i32), + ); + encoder.write_all(&buf).await?; + encoder.shutdown().await?; + Bytes::from(encoder.into_inner()) + } + None => buf, + }; + + Ok(compressed_buf) + } +} + +impl FromWireFormat for InterpretedWalRecords { + type T = Self; + + async fn from_wire( + buf: &Bytes, + format: InterpretedFormat, + compression: Option, + ) -> Result { + let decompressed_buf = match compression { + Some(Compression::Zstd { .. }) => { + use async_compression::tokio::write::ZstdDecoder; + let mut decoded_buf = Vec::with_capacity(buf.len()); + let mut decoder = ZstdDecoder::new(&mut decoded_buf); + decoder.write_all(buf).await?; + decoder.flush().await?; + Bytes::from(decoded_buf) + } + None => buf.clone(), + }; + + match format { + InterpretedFormat::Bincode => { + InterpretedWalRecords::des(&decompressed_buf).map_err(FromWireFormatError::Bincode) + } + InterpretedFormat::Protobuf => { + let proto = proto::InterpretedWalRecords::decode(decompressed_buf) + .map_err(|e| FromWireFormatError::Protobuf(e.into()))?; + InterpretedWalRecords::try_from(proto) + .map_err(|e| FromWireFormatError::Protobuf(e.into())) + } + } + } +} + +impl TryFrom for proto::InterpretedWalRecords { + type Error = SerializeError; + + fn try_from(value: InterpretedWalRecords) -> Result { + let records = value + .records + .into_iter() + .map(proto::InterpretedWalRecord::try_from) + .collect::, _>>()?; + Ok(proto::InterpretedWalRecords { + records, + next_record_lsn: value.next_record_lsn.map(|l| l.0), + }) + } +} + +impl TryFrom for proto::InterpretedWalRecord { + type Error = SerializeError; + + fn try_from(value: InterpretedWalRecord) -> Result { + let metadata_record = value + .metadata_record + .map(|meta_rec| -> Result, Self::Error> { + let mut buf = Vec::new(); + meta_rec.ser_into(&mut buf)?; + Ok(buf) + }) + .transpose()?; + + Ok(proto::InterpretedWalRecord { + metadata_record, + batch: Some(proto::SerializedValueBatch::from(value.batch)), + next_record_lsn: value.next_record_lsn.0, + flush_uncommitted: matches!(value.flush_uncommitted, FlushUncommittedRecords::Yes), + xid: value.xid, + }) + } +} + +impl From for proto::SerializedValueBatch { + fn from(value: SerializedValueBatch) -> Self { + proto::SerializedValueBatch { + raw: value.raw, + metadata: value + .metadata + .into_iter() + .map(proto::ValueMeta::from) + .collect(), + max_lsn: value.max_lsn.0, + len: value.len as u64, + } + } +} + +impl From for proto::ValueMeta { + fn from(value: ValueMeta) -> Self { + match value { + ValueMeta::Observed(obs) => proto::ValueMeta { + r#type: proto::ValueMetaType::Observed.into(), + key: Some(proto::CompactKey::from(obs.key)), + lsn: obs.lsn.0, + batch_offset: None, + len: None, + will_init: None, + }, + ValueMeta::Serialized(ser) => proto::ValueMeta { + r#type: proto::ValueMetaType::Serialized.into(), + key: Some(proto::CompactKey::from(ser.key)), + lsn: ser.lsn.0, + batch_offset: Some(ser.batch_offset), + len: Some(ser.len as u64), + will_init: Some(ser.will_init), + }, + } + } +} + +impl From for proto::CompactKey { + fn from(value: CompactKey) -> Self { + proto::CompactKey { + high: (value.raw() >> 64) as i64, + low: value.raw() as i64, + } + } +} + +impl TryFrom for InterpretedWalRecords { + type Error = TranscodeError; + + fn try_from(value: proto::InterpretedWalRecords) -> Result { + let records = value + .records + .into_iter() + .map(InterpretedWalRecord::try_from) + .collect::>()?; + + Ok(InterpretedWalRecords { + records, + next_record_lsn: value.next_record_lsn.map(Lsn::from), + }) + } +} + +impl TryFrom for InterpretedWalRecord { + type Error = TranscodeError; + + fn try_from(value: proto::InterpretedWalRecord) -> Result { + let metadata_record = value + .metadata_record + .map(|mrec| -> Result<_, DeserializeError> { MetadataRecord::des(&mrec) }) + .transpose()?; + + let batch = { + let batch = value.batch.ok_or_else(|| { + TranscodeError::BadInput("InterpretedWalRecord::batch missing".to_string()) + })?; + + SerializedValueBatch::try_from(batch)? + }; + + Ok(InterpretedWalRecord { + metadata_record, + batch, + next_record_lsn: Lsn(value.next_record_lsn), + flush_uncommitted: if value.flush_uncommitted { + FlushUncommittedRecords::Yes + } else { + FlushUncommittedRecords::No + }, + xid: value.xid, + }) + } +} + +impl TryFrom for SerializedValueBatch { + type Error = TranscodeError; + + fn try_from(value: proto::SerializedValueBatch) -> Result { + let metadata = value + .metadata + .into_iter() + .map(ValueMeta::try_from) + .collect::, _>>()?; + + Ok(SerializedValueBatch { + raw: value.raw, + metadata, + max_lsn: Lsn(value.max_lsn), + len: value.len as usize, + }) + } +} + +impl TryFrom for ValueMeta { + type Error = TranscodeError; + + fn try_from(value: proto::ValueMeta) -> Result { + match proto::ValueMetaType::try_from(value.r#type) { + Ok(proto::ValueMetaType::Serialized) => { + Ok(ValueMeta::Serialized(SerializedValueMeta { + key: value + .key + .ok_or_else(|| { + TranscodeError::BadInput("ValueMeta::key missing".to_string()) + })? + .into(), + lsn: Lsn(value.lsn), + batch_offset: value.batch_offset.ok_or_else(|| { + TranscodeError::BadInput("ValueMeta::batch_offset missing".to_string()) + })?, + len: value.len.ok_or_else(|| { + TranscodeError::BadInput("ValueMeta::len missing".to_string()) + })? as usize, + will_init: value.will_init.ok_or_else(|| { + TranscodeError::BadInput("ValueMeta::will_init missing".to_string()) + })?, + })) + } + Ok(proto::ValueMetaType::Observed) => Ok(ValueMeta::Observed(ObservedValueMeta { + key: value + .key + .ok_or_else(|| TranscodeError::BadInput("ValueMeta::key missing".to_string()))? + .into(), + lsn: Lsn(value.lsn), + })), + Err(_) => Err(TranscodeError::BadInput(format!( + "Unexpected ValueMeta::type {}", + value.r#type + ))), + } + } +} + +impl From for CompactKey { + fn from(value: proto::CompactKey) -> Self { + (((value.high as i128) << 64) | (value.low as i128)).into() + } +} diff --git a/pageserver/Cargo.toml b/pageserver/Cargo.toml index 143d8236df..140b287ccc 100644 --- a/pageserver/Cargo.toml +++ b/pageserver/Cargo.toml @@ -43,6 +43,7 @@ postgres.workspace = true postgres_backend.workspace = true postgres-protocol.workspace = true postgres-types.workspace = true +postgres_initdb.workspace = true rand.workspace = true range-set-blaze = { version = "0.1.16", features = ["alloc"] } regex.workspace = true @@ -68,6 +69,7 @@ url.workspace = true walkdir.workspace = true metrics.workspace = true pageserver_api.workspace = true +pageserver_client.workspace = true # for ResponseErrorMessageExt TOOD refactor that pageserver_compaction.workspace = true postgres_connection.workspace = true postgres_ffi.workspace = true diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 033a9a4619..a8c2c2e992 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -126,6 +126,7 @@ fn main() -> anyhow::Result<()> { // after setting up logging, log the effective IO engine choice and read path implementations info!(?conf.virtual_file_io_engine, "starting with virtual_file IO engine"); info!(?conf.virtual_file_io_mode, "starting with virtual_file IO mode"); + info!(?conf.wal_receiver_protocol, "starting with WAL receiver protocol"); // The tenants directory contains all the pageserver local disk state. // Create if not exists and make sure all the contents are durable before proceeding. diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index 86c3621cf0..1651db8500 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -14,6 +14,7 @@ use remote_storage::{RemotePath, RemoteStorageConfig}; use std::env; use storage_broker::Uri; use utils::logging::SecretString; +use utils::postgres_client::PostgresClientProtocol; use once_cell::sync::OnceCell; use reqwest::Url; @@ -144,6 +145,10 @@ pub struct PageServerConf { /// JWT token for use with the control plane API. pub control_plane_api_token: Option, + pub import_pgdata_upcall_api: Option, + pub import_pgdata_upcall_api_token: Option, + pub import_pgdata_aws_endpoint_url: Option, + /// If true, pageserver will make best-effort to operate without a control plane: only /// for use in major incidents. pub control_plane_emergency_mode: bool, @@ -183,7 +188,9 @@ pub struct PageServerConf { /// Optionally disable disk syncs (unsafe!) pub no_sync: bool, - pub page_service_pipelining: Option, + pub wal_receiver_protocol: PostgresClientProtocol, + + pub page_service_pipelining: pageserver_api::config::PageServicePipeliningConfig, } /// Token for authentication to safekeepers @@ -326,6 +333,9 @@ impl PageServerConf { control_plane_api, control_plane_api_token, control_plane_emergency_mode, + import_pgdata_upcall_api, + import_pgdata_upcall_api_token, + import_pgdata_aws_endpoint_url, heatmap_upload_concurrency, secondary_download_concurrency, ingest_batch_size, @@ -340,6 +350,7 @@ impl PageServerConf { virtual_file_io_engine, tenant_config, no_sync, + wal_receiver_protocol, page_service_pipelining, } = config_toml; @@ -380,6 +391,10 @@ impl PageServerConf { image_compression, timeline_offloading, ephemeral_bytes_per_memory_kb, + import_pgdata_upcall_api, + import_pgdata_upcall_api_token: import_pgdata_upcall_api_token.map(SecretString::from), + import_pgdata_aws_endpoint_url, + wal_receiver_protocol, page_service_pipelining, // ------------------------------------------------------------ diff --git a/pageserver/src/deletion_queue.rs b/pageserver/src/deletion_queue.rs index 37fa300467..e74c8ecf5a 100644 --- a/pageserver/src/deletion_queue.rs +++ b/pageserver/src/deletion_queue.rs @@ -1144,18 +1144,24 @@ pub(crate) mod mock { rx: tokio::sync::mpsc::UnboundedReceiver, executor_rx: tokio::sync::mpsc::Receiver, cancel: CancellationToken, + executed: Arc, } impl ConsumerState { - async fn consume(&mut self, remote_storage: &GenericRemoteStorage) -> usize { - let mut executed = 0; - + async fn consume(&mut self, remote_storage: &GenericRemoteStorage) { info!("Executing all pending deletions"); // Transform all executor messages to generic frontend messages - while let Ok(msg) = self.executor_rx.try_recv() { + loop { + use either::Either; + let msg = tokio::select! { + left = self.executor_rx.recv() => Either::Left(left), + right = self.rx.recv() => Either::Right(right), + }; match msg { - DeleterMessage::Delete(objects) => { + Either::Left(None) => break, + Either::Right(None) => break, + Either::Left(Some(DeleterMessage::Delete(objects))) => { for path in objects { match remote_storage.delete(&path, &self.cancel).await { Ok(_) => { @@ -1165,18 +1171,13 @@ pub(crate) mod mock { error!("Failed to delete {path}, leaking object! ({e})"); } } - executed += 1; + self.executed.fetch_add(1, Ordering::Relaxed); } } - DeleterMessage::Flush(flush_op) => { + Either::Left(Some(DeleterMessage::Flush(flush_op))) => { flush_op.notify(); } - } - } - - while let Ok(msg) = self.rx.try_recv() { - match msg { - ListWriterQueueMessage::Delete(op) => { + Either::Right(Some(ListWriterQueueMessage::Delete(op))) => { let mut objects = op.objects; for (layer, meta) in op.layers { objects.push(remote_layer_path( @@ -1198,33 +1199,27 @@ pub(crate) mod mock { error!("Failed to delete {path}, leaking object! ({e})"); } } - executed += 1; + self.executed.fetch_add(1, Ordering::Relaxed); } } - ListWriterQueueMessage::Flush(op) => { + Either::Right(Some(ListWriterQueueMessage::Flush(op))) => { op.notify(); } - ListWriterQueueMessage::FlushExecute(op) => { + Either::Right(Some(ListWriterQueueMessage::FlushExecute(op))) => { // We have already executed all prior deletions because mock does them inline op.notify(); } - ListWriterQueueMessage::Recover(_) => { + Either::Right(Some(ListWriterQueueMessage::Recover(_))) => { // no-op in mock } } - info!("All pending deletions have been executed"); } - - executed } } pub struct MockDeletionQueue { tx: tokio::sync::mpsc::UnboundedSender, executor_tx: tokio::sync::mpsc::Sender, - executed: Arc, - remote_storage: Option, - consumer: std::sync::Mutex, lsn_table: Arc>, } @@ -1235,29 +1230,34 @@ pub(crate) mod mock { let executed = Arc::new(AtomicUsize::new(0)); + let mut consumer = ConsumerState { + rx, + executor_rx, + cancel: CancellationToken::new(), + executed: executed.clone(), + }; + + tokio::spawn(async move { + if let Some(remote_storage) = &remote_storage { + consumer.consume(remote_storage).await; + } + }); + Self { tx, executor_tx, - executed, - remote_storage, - consumer: std::sync::Mutex::new(ConsumerState { - rx, - executor_rx, - cancel: CancellationToken::new(), - }), lsn_table: Arc::new(std::sync::RwLock::new(VisibleLsnUpdates::new())), } } #[allow(clippy::await_holding_lock)] pub async fn pump(&self) { - if let Some(remote_storage) = &self.remote_storage { - // Permit holding mutex across await, because this is only ever - // called once at a time in tests. - let mut locked = self.consumer.lock().unwrap(); - let count = locked.consume(remote_storage).await; - self.executed.fetch_add(count, Ordering::Relaxed); - } + let (tx, rx) = tokio::sync::oneshot::channel(); + self.executor_tx + .send(DeleterMessage::Flush(FlushOp { tx })) + .await + .expect("Failed to send flush message"); + rx.await.ok(); } pub(crate) fn new_client(&self) -> DeletionQueueClient { diff --git a/pageserver/src/deletion_queue/deleter.rs b/pageserver/src/deletion_queue/deleter.rs index 1f04bc0410..3d02387c98 100644 --- a/pageserver/src/deletion_queue/deleter.rs +++ b/pageserver/src/deletion_queue/deleter.rs @@ -15,6 +15,7 @@ use tokio_util::sync::CancellationToken; use tracing::info; use tracing::warn; use utils::backoff; +use utils::pausable_failpoint; use crate::metrics; @@ -90,6 +91,7 @@ impl Deleter { /// Block until everything in accumulator has been executed async fn flush(&mut self) -> Result<(), DeletionQueueError> { while !self.accumulator.is_empty() && !self.cancel.is_cancelled() { + pausable_failpoint!("deletion-queue-before-execute-pause"); match self.remote_delete().await { Ok(()) => { // Note: we assume that the remote storage layer returns Ok(()) if some diff --git a/pageserver/src/http/openapi_spec.yml b/pageserver/src/http/openapi_spec.yml index 2bc7f5ad39..7fb9247feb 100644 --- a/pageserver/src/http/openapi_spec.yml +++ b/pageserver/src/http/openapi_spec.yml @@ -623,6 +623,8 @@ paths: existing_initdb_timeline_id: type: string format: hex + import_pgdata: + $ref: "#/components/schemas/TimelineCreateRequestImportPgdata" responses: "201": description: Timeline was created, or already existed with matching parameters @@ -979,6 +981,34 @@ components: $ref: "#/components/schemas/TenantConfig" effective_config: $ref: "#/components/schemas/TenantConfig" + TimelineCreateRequestImportPgdata: + type: object + required: + - location + - idempotency_key + properties: + idempotency_key: + type: string + location: + $ref: "#/components/schemas/TimelineCreateRequestImportPgdataLocation" + TimelineCreateRequestImportPgdataLocation: + type: object + properties: + AwsS3: + $ref: "#/components/schemas/TimelineCreateRequestImportPgdataLocationAwsS3" + TimelineCreateRequestImportPgdataLocationAwsS3: + type: object + properties: + region: + type: string + bucket: + type: string + key: + type: string + required: + - region + - bucket + - key TimelineInfo: type: object required: diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index 306b0f35ab..ceb1c3b012 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -40,6 +40,7 @@ use pageserver_api::models::TenantSorting; use pageserver_api::models::TenantState; use pageserver_api::models::TimelineArchivalConfigRequest; use pageserver_api::models::TimelineCreateRequestMode; +use pageserver_api::models::TimelineCreateRequestModeImportPgdata; use pageserver_api::models::TimelinesInfoAndOffloaded; use pageserver_api::models::TopTenantShardItem; use pageserver_api::models::TopTenantShardsRequest; @@ -55,6 +56,7 @@ use tokio_util::sync::CancellationToken; use tracing::*; use utils::auth::JwtAuth; use utils::failpoint_support::failpoints_handler; +use utils::http::endpoint::profile_cpu_handler; use utils::http::endpoint::prometheus_metrics_handler; use utils::http::endpoint::request_span; use utils::http::request::must_parse_query_param; @@ -80,6 +82,7 @@ use crate::tenant::secondary::SecondaryController; use crate::tenant::size::ModelInputs; use crate::tenant::storage_layer::LayerAccessStatsReset; use crate::tenant::storage_layer::LayerName; +use crate::tenant::timeline::import_pgdata; use crate::tenant::timeline::offload::offload_timeline; use crate::tenant::timeline::offload::OffloadError; use crate::tenant::timeline::CompactFlags; @@ -125,7 +128,7 @@ pub struct State { conf: &'static PageServerConf, tenant_manager: Arc, auth: Option>, - allowlist_routes: Vec, + allowlist_routes: &'static [&'static str], remote_storage: GenericRemoteStorage, broker_client: storage_broker::BrokerClientChannel, disk_usage_eviction_state: Arc, @@ -146,10 +149,13 @@ impl State { deletion_queue_client: DeletionQueueClient, secondary_controller: SecondaryController, ) -> anyhow::Result { - let allowlist_routes = ["/v1/status", "/v1/doc", "/swagger.yml", "/metrics"] - .iter() - .map(|v| v.parse().unwrap()) - .collect::>(); + let allowlist_routes = &[ + "/v1/status", + "/v1/doc", + "/swagger.yml", + "/metrics", + "/profile/cpu", + ]; Ok(Self { conf, tenant_manager, @@ -576,6 +582,35 @@ async fn timeline_create_handler( ancestor_timeline_id, ancestor_start_lsn, }), + TimelineCreateRequestMode::ImportPgdata { + import_pgdata: + TimelineCreateRequestModeImportPgdata { + location, + idempotency_key, + }, + } => tenant::CreateTimelineParams::ImportPgdata(tenant::CreateTimelineParamsImportPgdata { + idempotency_key: import_pgdata::index_part_format::IdempotencyKey::new( + idempotency_key.0, + ), + new_timeline_id, + location: { + use import_pgdata::index_part_format::Location; + use pageserver_api::models::ImportPgdataLocation; + match location { + #[cfg(feature = "testing")] + ImportPgdataLocation::LocalFs { path } => Location::LocalFs { path }, + ImportPgdataLocation::AwsS3 { + region, + bucket, + key, + } => Location::AwsS3 { + region, + bucket, + key, + }, + } + }, + }), }; let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Error); @@ -3148,7 +3183,7 @@ pub fn make_router( if auth.is_some() { router = router.middleware(auth_middleware(|request| { let state = get_state(request); - if state.allowlist_routes.contains(request.uri()) { + if state.allowlist_routes.contains(&request.uri().path()) { None } else { state.auth.as_deref() @@ -3167,6 +3202,7 @@ pub fn make_router( Ok(router .data(state) .get("/metrics", |r| request_span(r, prometheus_metrics_handler)) + .get("/profile/cpu", |r| request_span(r, profile_cpu_handler)) .get("/v1/status", |r| api_handler(r, status_handler)) .put("/v1/failpoints", |r| { testing_api_handler("manage failpoints", r, failpoints_handler) diff --git a/pageserver/src/lib.rs b/pageserver/src/lib.rs index ef6711397a..ff6af3566c 100644 --- a/pageserver/src/lib.rs +++ b/pageserver/src/lib.rs @@ -356,6 +356,25 @@ async fn timed( } } +/// Like [`timed`], but the warning timeout only starts after `cancel` has been cancelled. +async fn timed_after_cancellation( + fut: Fut, + name: &str, + warn_at: std::time::Duration, + cancel: &CancellationToken, +) -> ::Output { + let mut fut = std::pin::pin!(fut); + + tokio::select! { + _ = cancel.cancelled() => { + timed(fut, name, warn_at).await + } + ret = &mut fut => { + ret + } + } +} + #[cfg(test)] mod timed_tests { use super::timed; diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index 92f7c4e2e6..277d89dfe5 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -3,7 +3,7 @@ use metrics::{ register_counter_vec, register_gauge_vec, register_histogram, register_histogram_vec, register_int_counter, register_int_counter_pair_vec, register_int_counter_vec, register_int_gauge, register_int_gauge_vec, register_uint_gauge, register_uint_gauge_vec, - Counter, CounterVec, GaugeVec, Histogram, HistogramVec, IntCounter, IntCounterPair, + Counter, CounterVec, Gauge, GaugeVec, Histogram, HistogramVec, IntCounter, IntCounterPair, IntCounterPairVec, IntCounterVec, IntGauge, IntGaugeVec, UIntGauge, UIntGaugeVec, }; use once_cell::sync::Lazy; @@ -457,6 +457,15 @@ pub(crate) static WAIT_LSN_TIME: Lazy = Lazy::new(|| { .expect("failed to define a metric") }); +static FLUSH_WAIT_UPLOAD_TIME: Lazy = Lazy::new(|| { + register_gauge_vec!( + "pageserver_flush_wait_upload_seconds", + "Time spent waiting for preceding uploads during layer flush", + &["tenant_id", "shard_id", "timeline_id"] + ) + .expect("failed to define a metric") +}); + static LAST_RECORD_LSN: Lazy = Lazy::new(|| { register_int_gauge_vec!( "pageserver_last_record_lsn", @@ -653,6 +662,35 @@ pub(crate) static COMPRESSION_IMAGE_OUTPUT_BYTES: Lazy = Lazy::new(| .expect("failed to define a metric") }); +pub(crate) static RELSIZE_CACHE_ENTRIES: Lazy = Lazy::new(|| { + register_uint_gauge!( + "pageserver_relsize_cache_entries", + "Number of entries in the relation size cache", + ) + .expect("failed to define a metric") +}); + +pub(crate) static RELSIZE_CACHE_HITS: Lazy = Lazy::new(|| { + register_int_counter!("pageserver_relsize_cache_hits", "Relation size cache hits",) + .expect("failed to define a metric") +}); + +pub(crate) static RELSIZE_CACHE_MISSES: Lazy = Lazy::new(|| { + register_int_counter!( + "pageserver_relsize_cache_misses", + "Relation size cache misses", + ) + .expect("failed to define a metric") +}); + +pub(crate) static RELSIZE_CACHE_MISSES_OLD: Lazy = Lazy::new(|| { + register_int_counter!( + "pageserver_relsize_cache_misses_old", + "Relation size cache misses where the lookup LSN is older than the last relation update" + ) + .expect("failed to define a metric") +}); + pub(crate) mod initial_logical_size { use metrics::{register_int_counter, register_int_counter_vec, IntCounter, IntCounterVec}; use once_cell::sync::Lazy; @@ -2097,6 +2135,7 @@ pub(crate) struct WalIngestMetrics { pub(crate) records_committed: IntCounter, pub(crate) records_filtered: IntCounter, pub(crate) gap_blocks_zeroed_on_rel_extend: IntCounter, + pub(crate) clear_vm_bits_unknown: IntCounterVec, } pub(crate) static WAL_INGEST: Lazy = Lazy::new(|| WalIngestMetrics { @@ -2125,6 +2164,12 @@ pub(crate) static WAL_INGEST: Lazy = Lazy::new(|| WalIngestMet "Total number of zero gap blocks written on relation extends" ) .expect("failed to define a metric"), + clear_vm_bits_unknown: register_int_counter_vec!( + "pageserver_wal_ingest_clear_vm_bits_unknown", + "Number of ignored ClearVmBits operations due to unknown pages/relations", + &["entity"], + ) + .expect("failed to define a metric"), }); pub(crate) static WAL_REDO_TIME: Lazy = Lazy::new(|| { @@ -2327,6 +2372,7 @@ pub(crate) struct TimelineMetrics { shard_id: String, timeline_id: String, pub flush_time_histo: StorageTimeMetrics, + pub flush_wait_upload_time_gauge: Gauge, pub compact_time_histo: StorageTimeMetrics, pub create_images_time_histo: StorageTimeMetrics, pub logical_size_histo: StorageTimeMetrics, @@ -2370,6 +2416,9 @@ impl TimelineMetrics { &shard_id, &timeline_id, ); + let flush_wait_upload_time_gauge = FLUSH_WAIT_UPLOAD_TIME + .get_metric_with_label_values(&[&tenant_id, &shard_id, &timeline_id]) + .unwrap(); let compact_time_histo = StorageTimeMetrics::new( StorageTimeOperation::Compact, &tenant_id, @@ -2507,6 +2556,7 @@ impl TimelineMetrics { shard_id, timeline_id, flush_time_histo, + flush_wait_upload_time_gauge, compact_time_histo, create_images_time_histo, logical_size_histo, @@ -2554,6 +2604,14 @@ impl TimelineMetrics { self.resident_physical_size_gauge.get() } + pub(crate) fn flush_wait_upload_time_gauge_add(&self, duration: f64) { + self.flush_wait_upload_time_gauge.add(duration); + crate::metrics::FLUSH_WAIT_UPLOAD_TIME + .get_metric_with_label_values(&[&self.tenant_id, &self.shard_id, &self.timeline_id]) + .unwrap() + .add(duration); + } + pub(crate) fn shutdown(&self) { let was_shutdown = self .shutdown @@ -2570,6 +2628,7 @@ impl TimelineMetrics { let timeline_id = &self.timeline_id; let shard_id = &self.shard_id; let _ = LAST_RECORD_LSN.remove_label_values(&[tenant_id, shard_id, timeline_id]); + let _ = FLUSH_WAIT_UPLOAD_TIME.remove_label_values(&[tenant_id, shard_id, timeline_id]); let _ = STANDBY_HORIZON.remove_label_values(&[tenant_id, shard_id, timeline_id]); { RESIDENT_PHYSICAL_SIZE_GLOBAL.sub(self.resident_physical_size_get()); diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 0316948a7d..bd5dd88f04 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -7,7 +7,10 @@ use bytes::Buf; use futures::FutureExt; use itertools::Itertools; use once_cell::sync::OnceCell; -use pageserver_api::config::{PageServicePipeliningConfig, PageServiceProtocolPipeliningMode}; +use pageserver_api::config::{ + PageServicePipeliningConfig, PageServicePipeliningConfigPipelined, + PageServiceProtocolPipelinedExecutionStrategy, +}; use pageserver_api::models::{self, TenantState}; use pageserver_api::models::{ PagestreamBeMessage, PagestreamDbSizeRequest, PagestreamDbSizeResponse, @@ -36,6 +39,7 @@ use tokio::io::{AsyncWriteExt, BufWriter}; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use tracing::*; +use utils::sync::spsc_fold; use utils::{ auth::{Claims, Scope, SwappableJwtAuth}, id::{TenantId, TimelineId}, @@ -44,7 +48,6 @@ use utils::{ }; use crate::auth::check_permission; -use crate::basebackup; use crate::basebackup::BasebackupError; use crate::config::PageServerConf; use crate::context::{DownloadBehavior, RequestContext}; @@ -62,6 +65,7 @@ use crate::tenant::timeline::{self, WaitLsnError}; use crate::tenant::GetTimelineError; use crate::tenant::PageReconstructError; use crate::tenant::Timeline; +use crate::{basebackup, timed_after_cancellation}; use pageserver_api::key::rel_block_to_key; use pageserver_api::reltag::{BlockNumber, RelTag, SlruKind}; use postgres_ffi::pg_constants::DEFAULTTABLESPACE_OID; @@ -158,7 +162,7 @@ pub async fn libpq_listener_main( auth: Option>, listener: tokio::net::TcpListener, auth_type: AuthType, - pipelining_config: Option, + pipelining_config: PageServicePipeliningConfig, listener_ctx: RequestContext, listener_cancel: CancellationToken, ) -> Connections { @@ -217,7 +221,7 @@ async fn page_service_conn_main( auth: Option>, socket: tokio::net::TcpStream, auth_type: AuthType, - pipelining_config: Option, + pipelining_config: PageServicePipeliningConfig, connection_ctx: RequestContext, cancel: CancellationToken, ) -> ConnectionHandlerResult { @@ -319,7 +323,7 @@ struct PageServerHandler { /// None only while pagestream protocol is being processed. timeline_handles: Option, - pipelining_config: Option, + pipelining_config: PageServicePipeliningConfig, } struct TimelineHandles { @@ -574,7 +578,7 @@ impl PageServerHandler { pub fn new( tenant_manager: Arc, auth: Option>, - pipelining_config: Option, + pipelining_config: PageServicePipeliningConfig, connection_ctx: RequestContext, cancel: CancellationToken, ) -> Self { @@ -620,7 +624,7 @@ impl PageServerHandler { cancel: &CancellationToken, ctx: &RequestContext, parent_span: Span, - ) -> Result>, QueryError> + ) -> Result, QueryError> where IO: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, { @@ -722,7 +726,7 @@ impl PageServerHandler { span, error: $error, }; - Ok(Some(Box::new(error))) + Ok(Some(error)) }}; } @@ -773,7 +777,7 @@ impl PageServerHandler { } } }; - Ok(Some(Box::new(batched_msg))) + Ok(Some(batched_msg)) } /// Post-condition: `batch` is Some() @@ -781,26 +785,25 @@ impl PageServerHandler { #[allow(clippy::boxed_local)] fn pagestream_do_batch( max_batch_size: NonZeroUsize, - batch: &mut Option>, - this_msg: Box, - ) -> Option> { + batch: &mut Result, + this_msg: Result, + ) -> Result<(), Result> { debug_assert_current_span_has_tenant_and_timeline_id_no_shard_id(); - match (batch.as_deref_mut(), *this_msg) { - // nothing batched yet - (None, this_msg) => { - *batch = Some(Box::new(this_msg)); - None - } + let this_msg = match this_msg { + Ok(this_msg) => this_msg, + Err(e) => return Err(Err(e)), + }; + + match (&mut *batch, this_msg) { // something batched already, let's see if we can add this message to the batch ( - Some(BatchedFeMessage::GetPage { + Ok(BatchedFeMessage::GetPage { span: _, shard: accum_shard, pages: ref mut accum_pages, effective_request_lsn: accum_lsn, }), - // would be nice to have box pattern here BatchedFeMessage::GetPage { span: _, shard: this_shard, @@ -833,12 +836,12 @@ impl PageServerHandler { { // ok to batch accum_pages.extend(this_pages); - None + Ok(()) } // something batched already but this message is unbatchable - (Some(_), this_msg) => { + (_, this_msg) => { // by default, don't continue batching - Some(Box::new(this_msg)) // TODO: avoid re-box + Err(Ok(this_msg)) } } } @@ -848,6 +851,7 @@ impl PageServerHandler { &mut self, pgb_writer: &mut PostgresBackend, batch: BatchedFeMessage, + cancel: &CancellationToken, ctx: &RequestContext, ) -> Result<(), QueryError> where @@ -984,7 +988,7 @@ impl PageServerHandler { } tokio::select! { biased; - _ = self.cancel.cancelled() => { + _ = cancel.cancelled() => { // We were requested to shut down. info!("shutdown request received in page handler"); return Err(QueryError::Shutdown) @@ -1041,8 +1045,8 @@ impl PageServerHandler { .expect("implementation error: timeline_handles should not be locked"); let request_span = info_span!("request", shard_id = tracing::field::Empty); - let (pgb_reader, timeline_handles) = match self.pipelining_config.clone() { - Some(pipelining_config) => { + let ((pgb_reader, timeline_handles), result) = match self.pipelining_config.clone() { + PageServicePipeliningConfig::Pipelined(pipelining_config) => { self.handle_pagerequests_pipelined( pgb, pgb_reader, @@ -1055,7 +1059,7 @@ impl PageServerHandler { ) .await } - None => { + PageServicePipeliningConfig::Serial => { self.handle_pagerequests_serial( pgb, pgb_reader, @@ -1067,7 +1071,7 @@ impl PageServerHandler { ) .await } - }?; + }; debug!("pagestream subprotocol shut down cleanly"); @@ -1077,7 +1081,7 @@ impl PageServerHandler { let replaced = self.timeline_handles.replace(timeline_handles); assert!(replaced.is_none()); - Ok(()) + result } #[allow(clippy::too_many_arguments)] @@ -1090,33 +1094,50 @@ impl PageServerHandler { mut timeline_handles: TimelineHandles, request_span: Span, ctx: &RequestContext, - ) -> Result<(PostgresBackendReader, TimelineHandles), QueryError> + ) -> ( + (PostgresBackendReader, TimelineHandles), + Result<(), QueryError>, + ) where IO: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, { - loop { + let cancel = self.cancel.clone(); + let err = loop { let msg = Self::pagestream_read_message( &mut pgb_reader, tenant_id, timeline_id, &mut timeline_handles, - &self.cancel, + &cancel, ctx, request_span.clone(), ) - .await?; + .await; + let msg = match msg { + Ok(msg) => msg, + Err(e) => break e, + }; let msg = match msg { Some(msg) => msg, None => { debug!("pagestream subprotocol end observed"); - return Ok((pgb_reader, timeline_handles)); + return ((pgb_reader, timeline_handles), Ok(())); } }; - self.pagesteam_handle_batched_message(pgb_writer, *msg, ctx) - .await?; - } + let err = self + .pagesteam_handle_batched_message(pgb_writer, msg, &cancel, ctx) + .await; + match err { + Ok(()) => {} + Err(e) => break e, + } + }; + ((pgb_reader, timeline_handles), Err(err)) } + /// # Cancel-Safety + /// + /// May leak tokio tasks if not polled to completion. #[allow(clippy::too_many_arguments)] async fn handle_pagerequests_pipelined( &mut self, @@ -1126,185 +1147,175 @@ impl PageServerHandler { timeline_id: TimelineId, mut timeline_handles: TimelineHandles, request_span: Span, - pipelining_config: PageServicePipeliningConfig, + pipelining_config: PageServicePipeliningConfigPipelined, ctx: &RequestContext, - ) -> Result<(PostgresBackendReader, TimelineHandles), QueryError> + ) -> ( + (PostgresBackendReader, TimelineHandles), + Result<(), QueryError>, + ) where IO: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, { - let PageServicePipeliningConfig { + // + // Pipelined pagestream handling consists of + // - a Batcher that reads requests off the wire and + // and batches them if possible, + // - an Executor that processes the batched requests. + // + // The batch is built up inside an `spsc_fold` channel, + // shared betwen Batcher (Sender) and Executor (Receiver). + // + // The Batcher continously folds client requests into the batch, + // while the Executor can at any time take out what's in the batch + // in order to process it. + // This means the next batch builds up while the Executor + // executes the last batch. + // + // CANCELLATION + // + // We run both Batcher and Executor futures to completion before + // returning from this function. + // + // If Executor exits first, it signals cancellation to the Batcher + // via a CancellationToken that is child of `self.cancel`. + // If Batcher exits first, it signals cancellation to the Executor + // by dropping the spsc_fold channel Sender. + // + // CLEAN SHUTDOWN + // + // Clean shutdown means that the client ends the COPYBOTH session. + // In response to such a client message, the Batcher exits. + // The Executor continues to run, draining the spsc_fold channel. + // Once drained, the spsc_fold recv will fail with a distinct error + // indicating that the sender disconnected. + // The Executor exits with Ok(()) in response to that error. + // + // Server initiated shutdown is not clean shutdown, but instead + // is an error Err(QueryError::Shutdown) that is propagated through + // error propagation. + // + // ERROR PROPAGATION + // + // When the Batcher encounter an error, it sends it as a value + // through the spsc_fold channel and exits afterwards. + // When the Executor observes such an error in the channel, + // it exits returning that error value. + // + // This design ensures that the Executor stage will still process + // the batch that was in flight when the Batcher encountered an error, + // thereby beahving identical to a serial implementation. + + let PageServicePipeliningConfigPipelined { max_batch_size, - protocol_pipelining_mode, + execution, } = pipelining_config; - let (requests_tx, mut requests_rx) = tokio::sync::mpsc::channel(1); - let read_messages = { - let cancel = self.cancel.child_token(); + // Macro to _define_ a pipeline stage. + macro_rules! pipeline_stage { + ($name:literal, $cancel:expr, $make_fut:expr) => {{ + let cancel: CancellationToken = $cancel; + let stage_fut = $make_fut(cancel.clone()); + async move { + scopeguard::defer! { + debug!("exiting"); + } + timed_after_cancellation(stage_fut, $name, Duration::from_millis(100), &cancel) + .await + } + .instrument(tracing::info_span!($name)) + }}; + } + + // + // Batcher + // + + let cancel_batcher = self.cancel.child_token(); + let (mut batch_tx, mut batch_rx) = spsc_fold::channel(); + let read_messages = pipeline_stage!( + "read_messages", + cancel_batcher.clone(), + move |cancel_batcher| { + let ctx = ctx.attached_child(); + async move { + let mut pgb_reader = pgb_reader; + let mut exit = false; + while !exit { + let read_res = Self::pagestream_read_message( + &mut pgb_reader, + tenant_id, + timeline_id, + &mut timeline_handles, + &cancel_batcher, + &ctx, + request_span.clone(), + ) + .await; + let Some(read_res) = read_res.transpose() else { + debug!("client-initiated shutdown"); + break; + }; + exit |= read_res.is_err(); + let could_send = batch_tx + .send(read_res, |batch, res| { + Self::pagestream_do_batch(max_batch_size, batch, res) + }) + .await; + exit |= could_send.is_err(); + } + (pgb_reader, timeline_handles) + } + } + ); + + // + // Executor + // + + let executor = pipeline_stage!("executor", self.cancel.clone(), move |cancel| { let ctx = ctx.attached_child(); async move { - scopeguard::defer! { - debug!("exiting"); - } - let mut pgb_reader = pgb_reader; + let _cancel_batcher = cancel_batcher.drop_guard(); loop { - let msg = Self::pagestream_read_message( - &mut pgb_reader, - tenant_id, - timeline_id, - &mut timeline_handles, - &cancel, - &ctx, - request_span.clone(), - ) - .await?; - let msg = match msg { - Some(msg) => msg, - None => { - debug!("pagestream subprotocol end observed"); - break; + let maybe_batch = batch_rx.recv().await; + let batch = match maybe_batch { + Ok(batch) => batch, + Err(spsc_fold::RecvError::SenderGone) => { + debug!("upstream gone"); + return Ok(()); } }; - match requests_tx.send(msg).await { - Ok(()) => {} - Err(tokio::sync::mpsc::error::SendError(_)) => { - debug!("downstream is gone"); - break; + let batch = match batch { + Ok(batch) => batch, + Err(e) => { + return Err(e); } - } - } - Ok((pgb_reader, timeline_handles)) - } - } - .instrument(tracing::info_span!("read_messages")); - - enum BatchState { - Building(Option>), - UpstreamDead(Option>), - } - impl BatchState { - fn must_building_mut(&mut self) -> &mut Option> { - match self { - Self::Building(maybe_batch) => maybe_batch, - Self::UpstreamDead(_) => panic!("upstream dead"), - } - } - } - let (batch_tx, mut batch_rx) = tokio::sync::watch::channel(Arc::new( - std::sync::Mutex::new(BatchState::Building(None)), - )); - let notify_batcher = Arc::new(tokio::sync::Notify::new()); - let batcher = { - let notify_batcher = notify_batcher.clone(); - async move { - scopeguard::defer! { - debug!("exiting"); - } - loop { - let maybe_req = requests_rx.recv().await; - let Some(req) = maybe_req else { - batch_tx.send_modify(|pending_batch| { - let mut guard = pending_batch.lock().unwrap(); - match &mut *guard { - BatchState::Building(batch) => { - *guard = BatchState::UpstreamDead(batch.take()); - } - BatchState::UpstreamDead(_) => panic!("twice"), - } - }); - break; }; - // don't read new requests before this one has been processed - let mut req = Some(req); - loop { - let mut wait_notified = None; - let batched = batch_tx.send_if_modified(|pending_batch| { - let mut guard = pending_batch.lock().unwrap(); - let building = guard.must_building_mut(); - match Self::pagestream_do_batch( - max_batch_size, - building, - req.take().unwrap(), - ) { - Some(req_was_not_batched) => { - req.replace(req_was_not_batched); - wait_notified = Some(notify_batcher.notified()); - false - } - None => true, - } - }); - if batched { - break; - } else { - wait_notified.unwrap().await; - } - } + self.pagesteam_handle_batched_message(pgb_writer, batch, &cancel, &ctx) + .await?; } } - } - .instrument(tracing::info_span!("batcher")); + }); - let executor = async { - let mut stop = false; - while !stop { - match batch_rx.changed().await { - Ok(()) => {} - Err(_) => { - debug!("batch_rx observed disconnection of batcher"); - } - }; - let maybe_batch = { - let borrow = batch_rx.borrow(); - let mut guard = borrow.lock().unwrap(); - match &mut *guard { - BatchState::Building(maybe_batch) => maybe_batch.take(), - BatchState::UpstreamDead(maybe_batch) => { - debug!("upstream dead"); - stop = true; - maybe_batch.take() - } - } - }; - let Some(batch) = maybe_batch else { - break; - }; - notify_batcher.notify_one(); - debug!("processing batch"); - self.pagesteam_handle_batched_message(pgb_writer, *batch, ctx) - .await?; - } - Ok(()) - }; + // + // Execute the stages. + // - let read_messages_res; - let executor_res; - match protocol_pipelining_mode { - PageServiceProtocolPipeliningMode::ConcurrentFutures => { - (read_messages_res, _, executor_res) = - tokio::join!(read_messages, batcher, executor); + match execution { + PageServiceProtocolPipelinedExecutionStrategy::ConcurrentFutures => { + tokio::join!(read_messages, executor) } - PageServiceProtocolPipeliningMode::Tasks => { - // cancelled via sensitivity to self.cancel - let read_messages_task = tokio::task::spawn(read_messages); - // cancelled when it observes read_messages_task disconnect the channel - let batcher_task = tokio::task::spawn(batcher); - executor_res = executor.await; - read_messages_res = read_messages_task - .await - .context("read_messages task panicked, check logs for details")?; - let _: () = batcher_task - .await - .context("batcher task panicked, check logs for details")?; + PageServiceProtocolPipelinedExecutionStrategy::Tasks => { + // These tasks are not tracked anywhere. + let read_messages_task = tokio::spawn(read_messages); + let (read_messages_task_res, executor_res_) = + tokio::join!(read_messages_task, executor,); + ( + read_messages_task_res.expect("propagated panic from read_messages"), + executor_res_, + ) } } - - match (read_messages_res, executor_res) { - (Err(e), _) | (_, Err(e)) => { - let e: QueryError = e; - Err(e) - } - (Ok((pgb_reader, timeline_handles)), Ok(())) => Ok((pgb_reader, timeline_handles)), - } } /// Helper function to handle the LSN from client request. @@ -1349,21 +1360,26 @@ impl PageServerHandler { )); } - if request_lsn < **latest_gc_cutoff_lsn { + // Check explicitly for INVALID just to get a less scary error message if the request is obviously bogus + if request_lsn == Lsn::INVALID { + return Err(PageStreamError::BadRequest( + "invalid LSN(0) in request".into(), + )); + } + + // Clients should only read from recent LSNs on their timeline, or from locations holding an LSN lease. + // + // We may have older data available, but we make a best effort to detect this case and return an error, + // to distinguish a misbehaving client (asking for old LSN) from a storage issue (data missing at a legitimate LSN). + if request_lsn < **latest_gc_cutoff_lsn && !timeline.is_gc_blocked_by_lsn_lease_deadline() { let gc_info = &timeline.gc_info.read().unwrap(); if !gc_info.leases.contains_key(&request_lsn) { - // The requested LSN is below gc cutoff and is not guarded by a lease. - - // Check explicitly for INVALID just to get a less scary error message if the - // request is obviously bogus - return Err(if request_lsn == Lsn::INVALID { - PageStreamError::BadRequest("invalid LSN(0) in request".into()) - } else { + return Err( PageStreamError::BadRequest(format!( "tried to request a page version that was garbage collected. requested at {} gc cutoff {}", request_lsn, **latest_gc_cutoff_lsn ).into()) - }); + ); } } diff --git a/pageserver/src/pgdatadir_mapping.rs b/pageserver/src/pgdatadir_mapping.rs index 9d21652e00..3174d93bb8 100644 --- a/pageserver/src/pgdatadir_mapping.rs +++ b/pageserver/src/pgdatadir_mapping.rs @@ -10,6 +10,9 @@ use super::tenant::{PageReconstructError, Timeline}; use crate::aux_file; use crate::context::RequestContext; use crate::keyspace::{KeySpace, KeySpaceAccum}; +use crate::metrics::{ + RELSIZE_CACHE_ENTRIES, RELSIZE_CACHE_HITS, RELSIZE_CACHE_MISSES, RELSIZE_CACHE_MISSES_OLD, +}; use crate::span::{ debug_assert_current_span_has_tenant_and_timeline_id, debug_assert_current_span_has_tenant_and_timeline_id_no_shard_id, @@ -389,7 +392,9 @@ impl Timeline { result } - // Get size of a database in blocks + /// Get size of a database in blocks. This is only accurate on shard 0. It will undercount on + /// other shards, by only accounting for relations the shard has pages for, and only accounting + /// for pages up to the highest page number it has stored. pub(crate) async fn get_db_size( &self, spcnode: Oid, @@ -408,7 +413,10 @@ impl Timeline { Ok(total_blocks) } - /// Get size of a relation file + /// Get size of a relation file. The relation must exist, otherwise an error is returned. + /// + /// This is only accurate on shard 0. On other shards, it will return the size up to the highest + /// page number stored in the shard. pub(crate) async fn get_rel_size( &self, tag: RelTag, @@ -444,7 +452,10 @@ impl Timeline { Ok(nblocks) } - /// Does relation exist? + /// Does the relation exist? + /// + /// Only shard 0 has a full view of the relations. Other shards only know about relations that + /// the shard stores pages for. pub(crate) async fn get_rel_exists( &self, tag: RelTag, @@ -478,6 +489,9 @@ impl Timeline { /// Get a list of all existing relations in given tablespace and database. /// + /// Only shard 0 has a full view of the relations. Other shards only know about relations that + /// the shard stores pages for. + /// /// # Cancel-Safety /// /// This method is cancellation-safe. @@ -1129,9 +1143,12 @@ impl Timeline { let rel_size_cache = self.rel_size_cache.read().unwrap(); if let Some((cached_lsn, nblocks)) = rel_size_cache.map.get(tag) { if lsn >= *cached_lsn { + RELSIZE_CACHE_HITS.inc(); return Some(*nblocks); } + RELSIZE_CACHE_MISSES_OLD.inc(); } + RELSIZE_CACHE_MISSES.inc(); None } @@ -1156,6 +1173,7 @@ impl Timeline { } hash_map::Entry::Vacant(entry) => { entry.insert((lsn, nblocks)); + RELSIZE_CACHE_ENTRIES.inc(); } } } @@ -1163,13 +1181,17 @@ impl Timeline { /// Store cached relation size pub fn set_cached_rel_size(&self, tag: RelTag, lsn: Lsn, nblocks: BlockNumber) { let mut rel_size_cache = self.rel_size_cache.write().unwrap(); - rel_size_cache.map.insert(tag, (lsn, nblocks)); + if rel_size_cache.map.insert(tag, (lsn, nblocks)).is_none() { + RELSIZE_CACHE_ENTRIES.inc(); + } } /// Remove cached relation size pub fn remove_cached_rel_size(&self, tag: &RelTag) { let mut rel_size_cache = self.rel_size_cache.write().unwrap(); - rel_size_cache.map.remove(tag); + if rel_size_cache.map.remove(tag).is_some() { + RELSIZE_CACHE_ENTRIES.dec(); + } } } @@ -1229,10 +1251,9 @@ impl<'a> DatadirModification<'a> { } pub(crate) fn has_dirty_data(&self) -> bool { - !self - .pending_data_batch + self.pending_data_batch .as_ref() - .map_or(true, |b| b.is_empty()) + .map_or(false, |b| b.has_data()) } /// Set the current lsn @@ -1408,7 +1429,7 @@ impl<'a> DatadirModification<'a> { Some(pending_batch) => { pending_batch.extend(batch); } - None if !batch.is_empty() => { + None if batch.has_data() => { self.pending_data_batch = Some(batch); } None => { @@ -2276,9 +2297,9 @@ impl<'a> Version<'a> { //--- Metadata structs stored in key-value pairs in the repository. #[derive(Debug, Serialize, Deserialize)] -struct DbDirectory { +pub(crate) struct DbDirectory { // (spcnode, dbnode) -> (do relmapper and PG_VERSION files exist) - dbdirs: HashMap<(Oid, Oid), bool>, + pub(crate) dbdirs: HashMap<(Oid, Oid), bool>, } // The format of TwoPhaseDirectory changed in PostgreSQL v17, because the filenames of @@ -2287,8 +2308,8 @@ struct DbDirectory { // "pg_twophsae/0000000A000002E4". #[derive(Debug, Serialize, Deserialize)] -struct TwoPhaseDirectory { - xids: HashSet, +pub(crate) struct TwoPhaseDirectory { + pub(crate) xids: HashSet, } #[derive(Debug, Serialize, Deserialize)] @@ -2297,12 +2318,12 @@ struct TwoPhaseDirectoryV17 { } #[derive(Debug, Serialize, Deserialize, Default)] -struct RelDirectory { +pub(crate) struct RelDirectory { // Set of relations that exist. (relfilenode, forknum) // // TODO: Store it as a btree or radix tree or something else that spans multiple // key-value pairs, if you have a lot of relations - rels: HashSet<(Oid, u8)>, + pub(crate) rels: HashSet<(Oid, u8)>, } #[derive(Debug, Serialize, Deserialize)] @@ -2311,9 +2332,9 @@ struct RelSizeEntry { } #[derive(Debug, Serialize, Deserialize, Default)] -struct SlruSegmentDirectory { +pub(crate) struct SlruSegmentDirectory { // Set of SLRU segments that exist. - segments: HashSet, + pub(crate) segments: HashSet, } #[derive(Copy, Clone, PartialEq, Eq, Debug, enum_map::Enum)] diff --git a/pageserver/src/task_mgr.rs b/pageserver/src/task_mgr.rs index 6a4e90dd55..622738022a 100644 --- a/pageserver/src/task_mgr.rs +++ b/pageserver/src/task_mgr.rs @@ -381,6 +381,8 @@ pub enum TaskKind { UnitTest, DetachAncestor, + + ImportPgdata, } #[derive(Default)] diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 2e5f69e3c9..339a3ca1bb 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -43,7 +43,9 @@ use std::sync::atomic::AtomicBool; use std::sync::Weak; use std::time::SystemTime; use storage_broker::BrokerClientChannel; +use timeline::import_pgdata; use timeline::offload::offload_timeline; +use timeline::ShutdownMode; use tokio::io::BufReader; use tokio::sync::watch; use tokio::task::JoinSet; @@ -373,7 +375,6 @@ pub struct Tenant { l0_flush_global_state: L0FlushGlobalState, } - impl std::fmt::Debug for Tenant { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{} ({})", self.tenant_shard_id, self.current_state()) @@ -860,6 +861,7 @@ impl Debug for SetStoppingError { pub(crate) enum CreateTimelineParams { Bootstrap(CreateTimelineParamsBootstrap), Branch(CreateTimelineParamsBranch), + ImportPgdata(CreateTimelineParamsImportPgdata), } #[derive(Debug)] @@ -877,7 +879,14 @@ pub(crate) struct CreateTimelineParamsBranch { pub(crate) ancestor_start_lsn: Option, } -/// What is used to determine idempotency of a [`Tenant::create_timeline`] call in [`Tenant::start_creating_timeline`]. +#[derive(Debug)] +pub(crate) struct CreateTimelineParamsImportPgdata { + pub(crate) new_timeline_id: TimelineId, + pub(crate) location: import_pgdata::index_part_format::Location, + pub(crate) idempotency_key: import_pgdata::index_part_format::IdempotencyKey, +} + +/// What is used to determine idempotency of a [`Tenant::create_timeline`] call in [`Tenant::start_creating_timeline`] in [`Tenant::start_creating_timeline`]. /// /// Each [`Timeline`] object holds [`Self`] as an immutable property in [`Timeline::create_idempotency`]. /// @@ -907,19 +916,50 @@ pub(crate) enum CreateTimelineIdempotency { ancestor_timeline_id: TimelineId, ancestor_start_lsn: Lsn, }, + ImportPgdata(CreatingTimelineIdempotencyImportPgdata), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct CreatingTimelineIdempotencyImportPgdata { + idempotency_key: import_pgdata::index_part_format::IdempotencyKey, } /// What is returned by [`Tenant::start_creating_timeline`]. #[must_use] -enum StartCreatingTimelineResult<'t> { - CreateGuard(TimelineCreateGuard<'t>), +enum StartCreatingTimelineResult { + CreateGuard(TimelineCreateGuard), Idempotent(Arc), } +enum TimelineInitAndSyncResult { + ReadyToActivate(Arc), + NeedsSpawnImportPgdata(TimelineInitAndSyncNeedsSpawnImportPgdata), +} + +impl TimelineInitAndSyncResult { + fn ready_to_activate(self) -> Option> { + match self { + Self::ReadyToActivate(timeline) => Some(timeline), + _ => None, + } + } +} + +#[must_use] +struct TimelineInitAndSyncNeedsSpawnImportPgdata { + timeline: Arc, + import_pgdata: import_pgdata::index_part_format::Root, + guard: TimelineCreateGuard, +} + /// What is returned by [`Tenant::create_timeline`]. enum CreateTimelineResult { Created(Arc), Idempotent(Arc), + /// IMPORTANT: This [`Arc`] object is not in [`Tenant::timelines`] when + /// we return this result, nor will this concrete object ever be added there. + /// Cf method comment on [`Tenant::create_timeline_import_pgdata`]. + ImportSpawned(Arc), } impl CreateTimelineResult { @@ -927,18 +967,19 @@ impl CreateTimelineResult { match self { Self::Created(_) => "Created", Self::Idempotent(_) => "Idempotent", + Self::ImportSpawned(_) => "ImportSpawned", } } fn timeline(&self) -> &Arc { match self { - Self::Created(t) | Self::Idempotent(t) => t, + Self::Created(t) | Self::Idempotent(t) | Self::ImportSpawned(t) => t, } } /// Unit test timelines aren't activated, test has to do it if it needs to. #[cfg(test)] fn into_timeline_for_test(self) -> Arc { match self { - Self::Created(t) | Self::Idempotent(t) => t, + Self::Created(t) | Self::Idempotent(t) | Self::ImportSpawned(t) => t, } } } @@ -962,33 +1003,13 @@ pub enum CreateTimelineError { } #[derive(thiserror::Error, Debug)] -enum InitdbError { - Other(anyhow::Error), +pub enum InitdbError { + #[error("Operation was cancelled")] Cancelled, - Spawn(std::io::Result<()>), - Failed(std::process::ExitStatus, Vec), -} - -impl fmt::Display for InitdbError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - InitdbError::Cancelled => write!(f, "Operation was cancelled"), - InitdbError::Spawn(e) => write!(f, "Spawn error: {:?}", e), - InitdbError::Failed(status, stderr) => write!( - f, - "Command failed with status {:?}: {}", - status, - String::from_utf8_lossy(stderr) - ), - InitdbError::Other(e) => write!(f, "Error: {:?}", e), - } - } -} - -impl From for InitdbError { - fn from(error: std::io::Error) -> Self { - InitdbError::Spawn(Err(error)) - } + #[error(transparent)] + Other(anyhow::Error), + #[error(transparent)] + Inner(postgres_initdb::Error), } enum CreateTimelineCause { @@ -996,6 +1017,15 @@ enum CreateTimelineCause { Delete, } +enum LoadTimelineCause { + Attach, + Unoffload, + ImportPgdata { + create_guard: TimelineCreateGuard, + activate: ActivateTimelineArgs, + }, +} + #[derive(thiserror::Error, Debug)] pub(crate) enum GcError { // The tenant is shutting down @@ -1072,24 +1102,35 @@ impl Tenant { /// it is marked as Active. #[allow(clippy::too_many_arguments)] async fn timeline_init_and_sync( - &self, + self: &Arc, timeline_id: TimelineId, resources: TimelineResources, - index_part: IndexPart, + mut index_part: IndexPart, metadata: TimelineMetadata, ancestor: Option>, - _ctx: &RequestContext, - ) -> anyhow::Result<()> { + cause: LoadTimelineCause, + ctx: &RequestContext, + ) -> anyhow::Result { let tenant_id = self.tenant_shard_id; - let idempotency = if metadata.ancestor_timeline().is_none() { - CreateTimelineIdempotency::Bootstrap { - pg_version: metadata.pg_version(), + let import_pgdata = index_part.import_pgdata.take(); + let idempotency = match &import_pgdata { + Some(import_pgdata) => { + CreateTimelineIdempotency::ImportPgdata(CreatingTimelineIdempotencyImportPgdata { + idempotency_key: import_pgdata.idempotency_key().clone(), + }) } - } else { - CreateTimelineIdempotency::Branch { - ancestor_timeline_id: metadata.ancestor_timeline().unwrap(), - ancestor_start_lsn: metadata.ancestor_lsn(), + None => { + if metadata.ancestor_timeline().is_none() { + CreateTimelineIdempotency::Bootstrap { + pg_version: metadata.pg_version(), + } + } else { + CreateTimelineIdempotency::Branch { + ancestor_timeline_id: metadata.ancestor_timeline().unwrap(), + ancestor_start_lsn: metadata.ancestor_lsn(), + } + } } }; @@ -1121,39 +1162,91 @@ impl Tenant { format!("Failed to load layermap for timeline {tenant_id}/{timeline_id}") })?; - { - // avoiding holding it across awaits - let mut timelines_accessor = self.timelines.lock().unwrap(); - match timelines_accessor.entry(timeline_id) { - // We should never try and load the same timeline twice during startup - Entry::Occupied(_) => { - unreachable!( - "Timeline {tenant_id}/{timeline_id} already exists in the tenant map" - ); + match import_pgdata { + Some(import_pgdata) if !import_pgdata.is_done() => { + match cause { + LoadTimelineCause::Attach | LoadTimelineCause::Unoffload => (), + LoadTimelineCause::ImportPgdata { .. } => { + unreachable!("ImportPgdata should not be reloading timeline import is done and persisted as such in s3") + } } - Entry::Vacant(v) => { - v.insert(Arc::clone(&timeline)); - timeline.maybe_spawn_flush_loop(); + let mut guard = self.timelines_creating.lock().unwrap(); + if !guard.insert(timeline_id) { + // We should never try and load the same timeline twice during startup + unreachable!("Timeline {tenant_id}/{timeline_id} is already being created") } + let timeline_create_guard = TimelineCreateGuard { + _tenant_gate_guard: self.gate.enter()?, + owning_tenant: self.clone(), + timeline_id, + idempotency, + // The users of this specific return value don't need the timline_path in there. + timeline_path: timeline + .conf + .timeline_path(&timeline.tenant_shard_id, &timeline.timeline_id), + }; + Ok(TimelineInitAndSyncResult::NeedsSpawnImportPgdata( + TimelineInitAndSyncNeedsSpawnImportPgdata { + timeline, + import_pgdata, + guard: timeline_create_guard, + }, + )) } - }; + Some(_) | None => { + { + let mut timelines_accessor = self.timelines.lock().unwrap(); + match timelines_accessor.entry(timeline_id) { + // We should never try and load the same timeline twice during startup + Entry::Occupied(_) => { + unreachable!( + "Timeline {tenant_id}/{timeline_id} already exists in the tenant map" + ); + } + Entry::Vacant(v) => { + v.insert(Arc::clone(&timeline)); + timeline.maybe_spawn_flush_loop(); + } + } + } - // Sanity check: a timeline should have some content. - anyhow::ensure!( - ancestor.is_some() - || timeline - .layers - .read() - .await - .layer_map() - .expect("currently loading, layer manager cannot be shutdown already") - .iter_historic_layers() - .next() - .is_some(), - "Timeline has no ancestor and no layer files" - ); + // Sanity check: a timeline should have some content. + anyhow::ensure!( + ancestor.is_some() + || timeline + .layers + .read() + .await + .layer_map() + .expect("currently loading, layer manager cannot be shutdown already") + .iter_historic_layers() + .next() + .is_some(), + "Timeline has no ancestor and no layer files" + ); - Ok(()) + match cause { + LoadTimelineCause::Attach | LoadTimelineCause::Unoffload => (), + LoadTimelineCause::ImportPgdata { + create_guard, + activate, + } => { + // TODO: see the comment in the task code above how I'm not so certain + // it is safe to activate here because of concurrent shutdowns. + match activate { + ActivateTimelineArgs::Yes { broker_client } => { + info!("activating timeline after reload from pgdata import task"); + timeline.activate(self.clone(), broker_client, None, ctx); + } + ActivateTimelineArgs::No => (), + } + drop(create_guard); + } + } + + Ok(TimelineInitAndSyncResult::ReadyToActivate(timeline)) + } + } } /// Attach a tenant that's available in cloud storage. @@ -1578,24 +1671,46 @@ impl Tenant { } // TODO again handle early failure - self.load_remote_timeline( - timeline_id, - index_part, - remote_metadata, - TimelineResources { - remote_client, - timeline_get_throttle: self.timeline_get_throttle.clone(), - l0_flush_global_state: self.l0_flush_global_state.clone(), - }, - ctx, - ) - .await - .with_context(|| { - format!( - "failed to load remote timeline {} for tenant {}", - timeline_id, self.tenant_shard_id + let effect = self + .load_remote_timeline( + timeline_id, + index_part, + remote_metadata, + TimelineResources { + remote_client, + timeline_get_throttle: self.timeline_get_throttle.clone(), + l0_flush_global_state: self.l0_flush_global_state.clone(), + }, + LoadTimelineCause::Attach, + ctx, ) - })?; + .await + .with_context(|| { + format!( + "failed to load remote timeline {} for tenant {}", + timeline_id, self.tenant_shard_id + ) + })?; + + match effect { + TimelineInitAndSyncResult::ReadyToActivate(_) => { + // activation happens later, on Tenant::activate + } + TimelineInitAndSyncResult::NeedsSpawnImportPgdata( + TimelineInitAndSyncNeedsSpawnImportPgdata { + timeline, + import_pgdata, + guard, + }, + ) => { + tokio::task::spawn(self.clone().create_timeline_import_pgdata_task( + timeline, + import_pgdata, + ActivateTimelineArgs::No, + guard, + )); + } + } } // Walk through deleted timelines, resume deletion @@ -1719,13 +1834,14 @@ impl Tenant { #[instrument(skip_all, fields(timeline_id=%timeline_id))] async fn load_remote_timeline( - &self, + self: &Arc, timeline_id: TimelineId, index_part: IndexPart, remote_metadata: TimelineMetadata, resources: TimelineResources, + cause: LoadTimelineCause, ctx: &RequestContext, - ) -> anyhow::Result<()> { + ) -> anyhow::Result { span::debug_assert_current_span_has_tenant_id(); info!("downloading index file for timeline {}", timeline_id); @@ -1752,6 +1868,7 @@ impl Tenant { index_part, remote_metadata, ancestor, + cause, ctx, ) .await @@ -1938,6 +2055,7 @@ impl Tenant { TimelineArchivalError::Other(anyhow::anyhow!("Timeline already exists")) } TimelineExclusionError::Other(e) => TimelineArchivalError::Other(e), + TimelineExclusionError::ShuttingDown => TimelineArchivalError::Cancelled, })?; let timeline_preload = self @@ -1976,6 +2094,7 @@ impl Tenant { index_part, remote_metadata, timeline_resources, + LoadTimelineCause::Unoffload, &ctx, ) .await @@ -2213,7 +2332,7 @@ impl Tenant { /// /// Tests should use `Tenant::create_test_timeline` to set up the minimum required metadata keys. pub(crate) async fn create_empty_timeline( - &self, + self: &Arc, new_timeline_id: TimelineId, initdb_lsn: Lsn, pg_version: u32, @@ -2263,7 +2382,7 @@ impl Tenant { // Our current tests don't need the background loops. #[cfg(test)] pub async fn create_test_timeline( - &self, + self: &Arc, new_timeline_id: TimelineId, initdb_lsn: Lsn, pg_version: u32, @@ -2302,7 +2421,7 @@ impl Tenant { #[cfg(test)] #[allow(clippy::too_many_arguments)] pub async fn create_test_timeline_with_layers( - &self, + self: &Arc, new_timeline_id: TimelineId, initdb_lsn: Lsn, pg_version: u32, @@ -2439,6 +2558,16 @@ impl Tenant { self.branch_timeline(&ancestor_timeline, new_timeline_id, ancestor_start_lsn, ctx) .await? } + CreateTimelineParams::ImportPgdata(params) => { + self.create_timeline_import_pgdata( + params, + ActivateTimelineArgs::Yes { + broker_client: broker_client.clone(), + }, + ctx, + ) + .await? + } }; // At this point we have dropped our guard on [`Self::timelines_creating`], and @@ -2481,11 +2610,202 @@ impl Tenant { ); timeline } + CreateTimelineResult::ImportSpawned(timeline) => { + info!("import task spawned, timeline will become visible and activated once the import is done"); + timeline + } }; Ok(activated_timeline) } + /// The returned [`Arc`] is NOT in the [`Tenant::timelines`] map until the import + /// completes in the background. A DIFFERENT [`Arc`] will be inserted into the + /// [`Tenant::timelines`] map when the import completes. + /// We only return an [`Arc`] here so the API handler can create a [`pageserver_api::models::TimelineInfo`] + /// for the response. + async fn create_timeline_import_pgdata( + self: &Arc, + params: CreateTimelineParamsImportPgdata, + activate: ActivateTimelineArgs, + ctx: &RequestContext, + ) -> Result { + let CreateTimelineParamsImportPgdata { + new_timeline_id, + location, + idempotency_key, + } = params; + + let started_at = chrono::Utc::now().naive_utc(); + + // + // There's probably a simpler way to upload an index part, but, remote_timeline_client + // is the canonical way we do it. + // - create an empty timeline in-memory + // - use its remote_timeline_client to do the upload + // - dispose of the uninit timeline + // - keep the creation guard alive + + let timeline_create_guard = match self + .start_creating_timeline( + new_timeline_id, + CreateTimelineIdempotency::ImportPgdata(CreatingTimelineIdempotencyImportPgdata { + idempotency_key: idempotency_key.clone(), + }), + ) + .await? + { + StartCreatingTimelineResult::CreateGuard(guard) => guard, + StartCreatingTimelineResult::Idempotent(timeline) => { + return Ok(CreateTimelineResult::Idempotent(timeline)) + } + }; + + let mut uninit_timeline = { + let this = &self; + let initdb_lsn = Lsn(0); + let _ctx = ctx; + async move { + let new_metadata = TimelineMetadata::new( + // Initialize disk_consistent LSN to 0, The caller must import some data to + // make it valid, before calling finish_creation() + Lsn(0), + None, + None, + Lsn(0), + initdb_lsn, + initdb_lsn, + 15, + ); + this.prepare_new_timeline( + new_timeline_id, + &new_metadata, + timeline_create_guard, + initdb_lsn, + None, + ) + .await + } + } + .await?; + + let in_progress = import_pgdata::index_part_format::InProgress { + idempotency_key, + location, + started_at, + }; + let index_part = import_pgdata::index_part_format::Root::V1( + import_pgdata::index_part_format::V1::InProgress(in_progress), + ); + uninit_timeline + .raw_timeline() + .unwrap() + .remote_client + .schedule_index_upload_for_import_pgdata_state_update(Some(index_part.clone()))?; + + // wait_completion happens in caller + + let (timeline, timeline_create_guard) = uninit_timeline.finish_creation_myself(); + + tokio::spawn(self.clone().create_timeline_import_pgdata_task( + timeline.clone(), + index_part, + activate, + timeline_create_guard, + )); + + // NB: the timeline doesn't exist in self.timelines at this point + Ok(CreateTimelineResult::ImportSpawned(timeline)) + } + + #[instrument(skip_all, fields(tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(), timeline_id=%timeline.timeline_id))] + async fn create_timeline_import_pgdata_task( + self: Arc, + timeline: Arc, + index_part: import_pgdata::index_part_format::Root, + activate: ActivateTimelineArgs, + timeline_create_guard: TimelineCreateGuard, + ) { + debug_assert_current_span_has_tenant_and_timeline_id(); + info!("starting"); + scopeguard::defer! {info!("exiting")}; + + let res = self + .create_timeline_import_pgdata_task_impl( + timeline, + index_part, + activate, + timeline_create_guard, + ) + .await; + if let Err(err) = &res { + error!(?err, "task failed"); + // TODO sleep & retry, sensitive to tenant shutdown + // TODO: allow timeline deletion requests => should cancel the task + } + } + + async fn create_timeline_import_pgdata_task_impl( + self: Arc, + timeline: Arc, + index_part: import_pgdata::index_part_format::Root, + activate: ActivateTimelineArgs, + timeline_create_guard: TimelineCreateGuard, + ) -> Result<(), anyhow::Error> { + let ctx = RequestContext::new(TaskKind::ImportPgdata, DownloadBehavior::Warn); + + info!("importing pgdata"); + import_pgdata::doit(&timeline, index_part, &ctx, self.cancel.clone()) + .await + .context("import")?; + info!("import done"); + + // + // Reload timeline from remote. + // This proves that the remote state is attachable, and it reuses the code. + // + // TODO: think about whether this is safe to do with concurrent Tenant::shutdown. + // timeline_create_guard hols the tenant gate open, so, shutdown cannot _complete_ until we exit. + // But our activate() call might launch new background tasks after Tenant::shutdown + // already went past shutting down the Tenant::timelines, which this timeline here is no part of. + // I think the same problem exists with the bootstrap & branch mgmt API tasks (tenant shutting + // down while bootstrapping/branching + activating), but, the race condition is much more likely + // to manifest because of the long runtime of this import task. + + // in theory this shouldn't even .await anything except for coop yield + info!("shutting down timeline"); + timeline.shutdown(ShutdownMode::Hard).await; + info!("timeline shut down, reloading from remote"); + // TODO: we can't do the following check because create_timeline_import_pgdata must return an Arc + // let Some(timeline) = Arc::into_inner(timeline) else { + // anyhow::bail!("implementation error: timeline that we shut down was still referenced from somewhere"); + // }; + let timeline_id = timeline.timeline_id; + + // load from object storage like Tenant::attach does + let resources = self.build_timeline_resources(timeline_id); + let index_part = resources + .remote_client + .download_index_file(&self.cancel) + .await?; + let index_part = match index_part { + MaybeDeletedIndexPart::Deleted(_) => { + // likely concurrent delete call, cplane should prevent this + anyhow::bail!("index part says deleted but we are not done creating yet, this should not happen but") + } + MaybeDeletedIndexPart::IndexPart(p) => p, + }; + let metadata = index_part.metadata.clone(); + self + .load_remote_timeline(timeline_id, index_part, metadata, resources, LoadTimelineCause::ImportPgdata{ + create_guard: timeline_create_guard, activate, }, &ctx) + .await? + .ready_to_activate() + .context("implementation error: reloaded timeline still needs import after import reported success")?; + + anyhow::Ok(()) + } + pub(crate) async fn delete_timeline( self: Arc, timeline_id: TimelineId, @@ -2895,6 +3215,18 @@ impl Tenant { } } + if let ShutdownMode::Reload = shutdown_mode { + tracing::info!("Flushing deletion queue"); + if let Err(e) = self.deletion_queue_client.flush().await { + match e { + DeletionQueueError::ShuttingDown => { + // This is the only error we expect for now. In the future, if more error + // variants are added, we should handle them here. + } + } + } + } + // We cancel the Tenant's cancellation token _after_ the timelines have all shut down. This permits // them to continue to do work during their shutdown methods, e.g. flushing data. tracing::debug!("Cancelling CancellationToken"); @@ -3337,6 +3669,13 @@ where Ok(result) } +enum ActivateTimelineArgs { + Yes { + broker_client: storage_broker::BrokerClientChannel, + }, + No, +} + impl Tenant { pub fn tenant_specific_overrides(&self) -> TenantConfOpt { self.tenant_conf.load().tenant_conf.clone() @@ -3520,6 +3859,7 @@ impl Tenant { /// `validate_ancestor == false` is used when a timeline is created for deletion /// and we might not have the ancestor present anymore which is fine for to be /// deleted timelines. + #[allow(clippy::too_many_arguments)] fn create_timeline_struct( &self, new_timeline_id: TimelineId, @@ -4283,16 +4623,17 @@ impl Tenant { /// If the timeline was already created in the meantime, we check whether this /// request conflicts or is idempotent , based on `state`. async fn start_creating_timeline( - &self, + self: &Arc, new_timeline_id: TimelineId, idempotency: CreateTimelineIdempotency, - ) -> Result, CreateTimelineError> { + ) -> Result { let allow_offloaded = false; match self.create_timeline_create_guard(new_timeline_id, idempotency, allow_offloaded) { Ok(create_guard) => { pausable_failpoint!("timeline-creation-after-uninit"); Ok(StartCreatingTimelineResult::CreateGuard(create_guard)) } + Err(TimelineExclusionError::ShuttingDown) => Err(CreateTimelineError::ShuttingDown), Err(TimelineExclusionError::AlreadyCreating) => { // Creation is in progress, we cannot create it again, and we cannot // check if this request matches the existing one, so caller must try @@ -4582,7 +4923,7 @@ impl Tenant { &'a self, new_timeline_id: TimelineId, new_metadata: &TimelineMetadata, - create_guard: TimelineCreateGuard<'a>, + create_guard: TimelineCreateGuard, start_lsn: Lsn, ancestor: Option>, ) -> anyhow::Result> { @@ -4642,7 +4983,7 @@ impl Tenant { /// The `allow_offloaded` parameter controls whether to tolerate the existence of /// offloaded timelines or not. fn create_timeline_create_guard( - &self, + self: &Arc, timeline_id: TimelineId, idempotency: CreateTimelineIdempotency, allow_offloaded: bool, @@ -4902,48 +5243,16 @@ async fn run_initdb( let _permit = INIT_DB_SEMAPHORE.acquire().await; - let mut initdb_command = tokio::process::Command::new(&initdb_bin_path); - initdb_command - .args(["--pgdata", initdb_target_dir.as_ref()]) - .args(["--username", &conf.superuser]) - .args(["--encoding", "utf8"]) - .args(["--locale", &conf.locale]) - .arg("--no-instructions") - .arg("--no-sync") - .env_clear() - .env("LD_LIBRARY_PATH", &initdb_lib_dir) - .env("DYLD_LIBRARY_PATH", &initdb_lib_dir) - .stdin(std::process::Stdio::null()) - // stdout invocation produces the same output every time, we don't need it - .stdout(std::process::Stdio::null()) - // we would be interested in the stderr output, if there was any - .stderr(std::process::Stdio::piped()); - - // Before version 14, only the libc provide was available. - if pg_version > 14 { - // Version 17 brought with it a builtin locale provider which only provides - // C and C.UTF-8. While being safer for collation purposes since it is - // guaranteed to be consistent throughout a major release, it is also more - // performant. - let locale_provider = if pg_version >= 17 { "builtin" } else { "libc" }; - - initdb_command.args(["--locale-provider", locale_provider]); - } - - let initdb_proc = initdb_command.spawn()?; - - // Ideally we'd select here with the cancellation token, but the problem is that - // we can't safely terminate initdb: it launches processes of its own, and killing - // initdb doesn't kill them. After we return from this function, we want the target - // directory to be able to be cleaned up. - // See https://github.com/neondatabase/neon/issues/6385 - let initdb_output = initdb_proc.wait_with_output().await?; - if !initdb_output.status.success() { - return Err(InitdbError::Failed( - initdb_output.status, - initdb_output.stderr, - )); - } + let res = postgres_initdb::do_run_initdb(postgres_initdb::RunInitdbArgs { + superuser: &conf.superuser, + locale: &conf.locale, + initdb_bin: &initdb_bin_path, + pg_version, + library_search_path: &initdb_lib_dir, + pgdata: initdb_target_dir, + }) + .await + .map_err(InitdbError::Inner); // This isn't true cancellation support, see above. Still return an error to // excercise the cancellation code path. @@ -4951,7 +5260,7 @@ async fn run_initdb( return Err(InitdbError::Cancelled); } - Ok(()) + res } /// Dump contents of a layer file to stdout. @@ -5047,6 +5356,7 @@ pub(crate) mod harness { lsn_lease_length: Some(tenant_conf.lsn_lease_length), lsn_lease_length_for_ts: Some(tenant_conf.lsn_lease_length_for_ts), timeline_offloading: Some(tenant_conf.timeline_offloading), + wal_receiver_protocol_override: tenant_conf.wal_receiver_protocol_override, } } } diff --git a/pageserver/src/tenant/config.rs b/pageserver/src/tenant/config.rs index 4d6176bfd9..5d3ac5a8e3 100644 --- a/pageserver/src/tenant/config.rs +++ b/pageserver/src/tenant/config.rs @@ -19,6 +19,7 @@ use serde_json::Value; use std::num::NonZeroU64; use std::time::Duration; use utils::generation::Generation; +use utils::postgres_client::PostgresClientProtocol; #[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, Eq)] pub(crate) enum AttachmentMode { @@ -353,6 +354,9 @@ pub struct TenantConfOpt { #[serde(skip_serializing_if = "Option::is_none")] #[serde(default)] pub timeline_offloading: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub wal_receiver_protocol_override: Option, } impl TenantConfOpt { @@ -418,6 +422,9 @@ impl TenantConfOpt { timeline_offloading: self .lazy_slru_download .unwrap_or(global_conf.timeline_offloading), + wal_receiver_protocol_override: self + .wal_receiver_protocol_override + .or(global_conf.wal_receiver_protocol_override), } } } @@ -472,6 +479,7 @@ impl From for models::TenantConfig { lsn_lease_length: value.lsn_lease_length.map(humantime), lsn_lease_length_for_ts: value.lsn_lease_length_for_ts.map(humantime), timeline_offloading: value.timeline_offloading, + wal_receiver_protocol_override: value.wal_receiver_protocol_override, } } } diff --git a/pageserver/src/tenant/mgr.rs b/pageserver/src/tenant/mgr.rs index 92b2200542..eb8191e43e 100644 --- a/pageserver/src/tenant/mgr.rs +++ b/pageserver/src/tenant/mgr.rs @@ -1960,7 +1960,7 @@ impl TenantManager { attempt.before_reset_tenant(); let (_guard, progress) = utils::completion::channel(); - match tenant.shutdown(progress, ShutdownMode::Flush).await { + match tenant.shutdown(progress, ShutdownMode::Reload).await { Ok(()) => { slot_guard.drop_old_value().expect("it was just shutdown"); } diff --git a/pageserver/src/tenant/remote_timeline_client.rs b/pageserver/src/tenant/remote_timeline_client.rs index 377bc23542..007bd3eef0 100644 --- a/pageserver/src/tenant/remote_timeline_client.rs +++ b/pageserver/src/tenant/remote_timeline_client.rs @@ -199,7 +199,7 @@ use utils::backoff::{ use utils::pausable_failpoint; use utils::shard::ShardNumber; -use std::collections::{HashMap, VecDeque}; +use std::collections::{HashMap, HashSet, VecDeque}; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::{Arc, Mutex, OnceLock}; use std::time::Duration; @@ -223,7 +223,7 @@ use crate::task_mgr::shutdown_token; use crate::tenant::debug_assert_current_span_has_tenant_and_timeline_id; use crate::tenant::remote_timeline_client::download::download_retry; use crate::tenant::storage_layer::AsLayerDesc; -use crate::tenant::upload_queue::{Delete, UploadQueueStoppedDeletable}; +use crate::tenant::upload_queue::{Delete, OpType, UploadQueueStoppedDeletable}; use crate::tenant::TIMELINES_SEGMENT_NAME; use crate::{ config::PageServerConf, @@ -244,6 +244,7 @@ use self::index::IndexPart; use super::config::AttachedLocationConfig; use super::metadata::MetadataUpdate; use super::storage_layer::{Layer, LayerName, ResidentLayer}; +use super::timeline::import_pgdata; use super::upload_queue::{NotInitialized, SetDeletedFlagProgress}; use super::{DeleteTimelineError, Generation}; @@ -813,6 +814,18 @@ impl RemoteTimelineClient { Ok(need_wait) } + /// Launch an index-file upload operation in the background, setting `import_pgdata` field. + pub(crate) fn schedule_index_upload_for_import_pgdata_state_update( + self: &Arc, + state: Option, + ) -> anyhow::Result<()> { + let mut guard = self.upload_queue.lock().unwrap(); + let upload_queue = guard.initialized_mut()?; + upload_queue.dirty.import_pgdata = state; + self.schedule_index_upload(upload_queue)?; + Ok(()) + } + /// /// Launch an index-file upload operation in the background, if necessary. /// @@ -1090,7 +1103,7 @@ impl RemoteTimelineClient { "scheduled layer file upload {layer}", ); - let op = UploadOp::UploadLayer(layer, metadata); + let op = UploadOp::UploadLayer(layer, metadata, None); self.metric_begin(&op); upload_queue.queued_operations.push_back(op); } @@ -1805,7 +1818,7 @@ impl RemoteTimelineClient { // have finished. upload_queue.inprogress_tasks.is_empty() } - UploadOp::Delete(_) => { + UploadOp::Delete(..) => { // Wait for preceding uploads to finish. Concurrent deletions are OK, though. upload_queue.num_inprogress_deletions == upload_queue.inprogress_tasks.len() } @@ -1833,19 +1846,32 @@ impl RemoteTimelineClient { } // We can launch this task. Remove it from the queue first. - let next_op = upload_queue.queued_operations.pop_front().unwrap(); + let mut next_op = upload_queue.queued_operations.pop_front().unwrap(); debug!("starting op: {}", next_op); - // Update the counters - match next_op { - UploadOp::UploadLayer(_, _) => { + // Update the counters and prepare + match &mut next_op { + UploadOp::UploadLayer(layer, meta, mode) => { + if upload_queue + .recently_deleted + .remove(&(layer.layer_desc().layer_name().clone(), meta.generation)) + { + *mode = Some(OpType::FlushDeletion); + } else { + *mode = Some(OpType::MayReorder) + } upload_queue.num_inprogress_layer_uploads += 1; } UploadOp::UploadMetadata { .. } => { upload_queue.num_inprogress_metadata_uploads += 1; } - UploadOp::Delete(_) => { + UploadOp::Delete(Delete { layers }) => { + for (name, meta) in layers { + upload_queue + .recently_deleted + .insert((name.clone(), meta.generation)); + } upload_queue.num_inprogress_deletions += 1; } UploadOp::Barrier(sender) => { @@ -1921,7 +1947,66 @@ impl RemoteTimelineClient { } let upload_result: anyhow::Result<()> = match &task.op { - UploadOp::UploadLayer(ref layer, ref layer_metadata) => { + UploadOp::UploadLayer(ref layer, ref layer_metadata, mode) => { + if let Some(OpType::FlushDeletion) = mode { + if self.config.read().unwrap().block_deletions { + // Of course, this is not efficient... but usually the queue should be empty. + let mut queue_locked = self.upload_queue.lock().unwrap(); + let mut detected = false; + if let Ok(queue) = queue_locked.initialized_mut() { + for list in queue.blocked_deletions.iter_mut() { + list.layers.retain(|(name, meta)| { + if name == &layer.layer_desc().layer_name() + && meta.generation == layer_metadata.generation + { + detected = true; + // remove the layer from deletion queue + false + } else { + // keep the layer + true + } + }); + } + } + if detected { + info!( + "cancelled blocked deletion of layer {} at gen {:?}", + layer.layer_desc().layer_name(), + layer_metadata.generation + ); + } + } else { + // TODO: we did not guarantee that upload task starts after deletion task, so there could be possibly race conditions + // that we still get the layer deleted. But this only happens if someone creates a layer immediately after it's deleted, + // which is not possible in the current system. + info!( + "waiting for deletion queue flush to complete before uploading layer {} at gen {:?}", + layer.layer_desc().layer_name(), + layer_metadata.generation + ); + { + // We are going to flush, we can clean up the recently deleted list. + let mut queue_locked = self.upload_queue.lock().unwrap(); + if let Ok(queue) = queue_locked.initialized_mut() { + queue.recently_deleted.clear(); + } + } + if let Err(e) = self.deletion_queue_client.flush_execute().await { + warn!( + "failed to flush the deletion queue before uploading layer {} at gen {:?}, still proceeding to upload: {e:#} ", + layer.layer_desc().layer_name(), + layer_metadata.generation + ); + } else { + info!( + "done flushing deletion queue before uploading layer {} at gen {:?}", + layer.layer_desc().layer_name(), + layer_metadata.generation + ); + } + } + } let local_path = layer.local_path(); // We should only be uploading layers created by this `Tenant`'s lifetime, so @@ -2085,7 +2170,7 @@ impl RemoteTimelineClient { upload_queue.inprogress_tasks.remove(&task.task_id); let lsn_update = match task.op { - UploadOp::UploadLayer(_, _) => { + UploadOp::UploadLayer(_, _, _) => { upload_queue.num_inprogress_layer_uploads -= 1; None } @@ -2162,7 +2247,7 @@ impl RemoteTimelineClient { )> { use RemoteTimelineClientMetricsCallTrackSize::DontTrackSize; let res = match op { - UploadOp::UploadLayer(_, m) => ( + UploadOp::UploadLayer(_, m, _) => ( RemoteOpFileKind::Layer, RemoteOpKind::Upload, RemoteTimelineClientMetricsCallTrackSize::Bytes(m.file_size), @@ -2259,6 +2344,7 @@ impl RemoteTimelineClient { blocked_deletions: Vec::new(), shutting_down: false, shutdown_ready: Arc::new(tokio::sync::Semaphore::new(0)), + recently_deleted: HashSet::new(), }; let upload_queue = std::mem::replace( diff --git a/pageserver/src/tenant/remote_timeline_client/download.rs b/pageserver/src/tenant/remote_timeline_client/download.rs index efcd20d1bf..d632e595ad 100644 --- a/pageserver/src/tenant/remote_timeline_client/download.rs +++ b/pageserver/src/tenant/remote_timeline_client/download.rs @@ -706,7 +706,7 @@ where .and_then(|x| x) } -async fn download_retry_forever( +pub(crate) async fn download_retry_forever( op: O, description: &str, cancel: &CancellationToken, diff --git a/pageserver/src/tenant/remote_timeline_client/index.rs b/pageserver/src/tenant/remote_timeline_client/index.rs index d8a881a2c4..506990fb2f 100644 --- a/pageserver/src/tenant/remote_timeline_client/index.rs +++ b/pageserver/src/tenant/remote_timeline_client/index.rs @@ -12,6 +12,7 @@ use utils::id::TimelineId; use crate::tenant::metadata::TimelineMetadata; use crate::tenant::storage_layer::LayerName; +use crate::tenant::timeline::import_pgdata; use crate::tenant::Generation; use pageserver_api::shard::ShardIndex; @@ -37,6 +38,13 @@ pub struct IndexPart { #[serde(skip_serializing_if = "Option::is_none")] pub archived_at: Option, + /// This field supports import-from-pgdata ("fast imports" platform feature). + /// We don't currently use fast imports, so, this field is None for all production timelines. + /// See for more information. + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub import_pgdata: Option, + /// Per layer file name metadata, which can be present for a present or missing layer file. /// /// Older versions of `IndexPart` will not have this property or have only a part of metadata @@ -90,10 +98,11 @@ impl IndexPart { /// - 7: metadata_bytes is no longer written, but still read /// - 8: added `archived_at` /// - 9: +gc_blocking - const LATEST_VERSION: usize = 9; + /// - 10: +import_pgdata + const LATEST_VERSION: usize = 10; // Versions we may see when reading from a bucket. - pub const KNOWN_VERSIONS: &'static [usize] = &[1, 2, 3, 4, 5, 6, 7, 8, 9]; + pub const KNOWN_VERSIONS: &'static [usize] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; pub const FILE_NAME: &'static str = "index_part.json"; @@ -108,6 +117,7 @@ impl IndexPart { lineage: Default::default(), gc_blocking: None, last_aux_file_policy: None, + import_pgdata: None, } } @@ -381,6 +391,7 @@ mod tests { lineage: Lineage::default(), gc_blocking: None, last_aux_file_policy: None, + import_pgdata: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -425,6 +436,7 @@ mod tests { lineage: Lineage::default(), gc_blocking: None, last_aux_file_policy: None, + import_pgdata: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -470,6 +482,7 @@ mod tests { lineage: Lineage::default(), gc_blocking: None, last_aux_file_policy: None, + import_pgdata: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -518,6 +531,7 @@ mod tests { lineage: Lineage::default(), gc_blocking: None, last_aux_file_policy: None, + import_pgdata: None, }; let empty_layers_parsed = IndexPart::from_json_bytes(empty_layers_json.as_bytes()).unwrap(); @@ -561,6 +575,7 @@ mod tests { lineage: Lineage::default(), gc_blocking: None, last_aux_file_policy: None, + import_pgdata: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -607,6 +622,7 @@ mod tests { }, gc_blocking: None, last_aux_file_policy: None, + import_pgdata: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -658,6 +674,7 @@ mod tests { }, gc_blocking: None, last_aux_file_policy: Some(AuxFilePolicy::V2), + import_pgdata: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -714,6 +731,7 @@ mod tests { lineage: Default::default(), gc_blocking: None, last_aux_file_policy: Default::default(), + import_pgdata: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -771,6 +789,7 @@ mod tests { lineage: Default::default(), gc_blocking: None, last_aux_file_policy: Default::default(), + import_pgdata: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -833,6 +852,83 @@ mod tests { }), last_aux_file_policy: Default::default(), archived_at: None, + import_pgdata: None, + }; + + let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); + assert_eq!(part, expected); + } + + #[test] + fn v10_importpgdata_is_parsed() { + let example = r#"{ + "version": 10, + "layer_metadata":{ + "000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9": { "file_size": 25600000 }, + "000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51": { "file_size": 9007199254741001 } + }, + "disk_consistent_lsn":"0/16960E8", + "metadata": { + "disk_consistent_lsn": "0/16960E8", + "prev_record_lsn": "0/1696070", + "ancestor_timeline": "e45a7f37d3ee2ff17dc14bf4f4e3f52e", + "ancestor_lsn": "0/0", + "latest_gc_cutoff_lsn": "0/1696070", + "initdb_lsn": "0/1696070", + "pg_version": 14 + }, + "gc_blocking": { + "started_at": "2024-07-19T09:00:00.123", + "reasons": ["DetachAncestor"] + }, + "import_pgdata": { + "V1": { + "Done": { + "idempotency_key": "specified-by-client-218a5213-5044-4562-a28d-d024c5f057f5", + "started_at": "2024-11-13T09:23:42.123", + "finished_at": "2024-11-13T09:42:23.123" + } + } + } + }"#; + + let expected = IndexPart { + version: 10, + layer_metadata: HashMap::from([ + ("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9".parse().unwrap(), LayerFileMetadata { + file_size: 25600000, + generation: Generation::none(), + shard: ShardIndex::unsharded() + }), + ("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51".parse().unwrap(), LayerFileMetadata { + file_size: 9007199254741001, + generation: Generation::none(), + shard: ShardIndex::unsharded() + }) + ]), + disk_consistent_lsn: "0/16960E8".parse::().unwrap(), + metadata: TimelineMetadata::new( + Lsn::from_str("0/16960E8").unwrap(), + Some(Lsn::from_str("0/1696070").unwrap()), + Some(TimelineId::from_str("e45a7f37d3ee2ff17dc14bf4f4e3f52e").unwrap()), + Lsn::INVALID, + Lsn::from_str("0/1696070").unwrap(), + Lsn::from_str("0/1696070").unwrap(), + 14, + ).with_recalculated_checksum().unwrap(), + deleted_at: None, + lineage: Default::default(), + gc_blocking: Some(GcBlocking { + started_at: parse_naive_datetime("2024-07-19T09:00:00.123000000"), + reasons: enumset::EnumSet::from_iter([GcBlockingReason::DetachAncestor]), + }), + last_aux_file_policy: Default::default(), + archived_at: None, + import_pgdata: Some(import_pgdata::index_part_format::Root::V1(import_pgdata::index_part_format::V1::Done(import_pgdata::index_part_format::Done{ + started_at: parse_naive_datetime("2024-11-13T09:23:42.123000000"), + finished_at: parse_naive_datetime("2024-11-13T09:42:23.123000000"), + idempotency_key: import_pgdata::index_part_format::IdempotencyKey::new("specified-by-client-218a5213-5044-4562-a28d-d024c5f057f5".to_string()), + }))) }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 95864af4d0..730477a7f4 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -4,6 +4,7 @@ pub mod delete; pub(crate) mod detach_ancestor; mod eviction_task; pub(crate) mod handle; +pub(crate) mod import_pgdata; mod init; pub mod layer_manager; pub(crate) mod logical_size; @@ -49,6 +50,7 @@ use tokio_util::sync::CancellationToken; use tracing::*; use utils::{ fs_ext, pausable_failpoint, + postgres_client::PostgresClientProtocol, sync::gate::{Gate, GateGuard}, }; use wal_decoder::serialized_batch::SerializedValueBatch; @@ -892,10 +894,11 @@ pub(crate) enum ShutdownMode { /// While we are flushing, we continue to accept read I/O for LSNs ingested before /// the call to [`Timeline::shutdown`]. FreezeAndFlush, - /// Only flush the layers to the remote storage without freezing any open layers. This is the - /// mode used by ancestor detach and any other operations that reloads a tenant but not increasing - /// the generation number. - Flush, + /// Only flush the layers to the remote storage without freezing any open layers. Flush the deletion + /// queue. This is the mode used by ancestor detach and any other operations that reloads a tenant + /// but not increasing the generation number. Note that this mode cannot be used at tenant shutdown, + /// as flushing the deletion queue at that time will cause shutdown-in-progress errors. + Reload, /// Shut down immediately, without waiting for any open layers to flush. Hard, } @@ -1816,7 +1819,7 @@ impl Timeline { } } - if let ShutdownMode::Flush = mode { + if let ShutdownMode::Reload = mode { // drain the upload queue self.remote_client.shutdown().await; if !self.remote_client.no_pending_work() { @@ -2085,6 +2088,11 @@ impl Timeline { .unwrap_or(self.conf.default_tenant_conf.lsn_lease_length_for_ts) } + pub(crate) fn is_gc_blocked_by_lsn_lease_deadline(&self) -> bool { + let tenant_conf = self.tenant_conf.load(); + tenant_conf.is_gc_blocked_by_lsn_lease_deadline() + } + pub(crate) fn get_lazy_slru_download(&self) -> bool { let tenant_conf = self.tenant_conf.load(); tenant_conf @@ -2172,6 +2180,21 @@ impl Timeline { ) } + /// Resolve the effective WAL receiver protocol to use for this tenant. + /// + /// Priority order is: + /// 1. Tenant config override + /// 2. Default value for tenant config override + /// 3. Pageserver config override + /// 4. Pageserver config default + pub fn resolve_wal_receiver_protocol(&self) -> PostgresClientProtocol { + let tenant_conf = self.tenant_conf.load().tenant_conf.clone(); + tenant_conf + .wal_receiver_protocol_override + .or(self.conf.default_tenant_conf.wal_receiver_protocol_override) + .unwrap_or(self.conf.wal_receiver_protocol) + } + pub(super) fn tenant_conf_updated(&self, new_conf: &AttachedTenantConf) { // NB: Most tenant conf options are read by background loops, so, // changes will automatically be picked up. @@ -2464,6 +2487,7 @@ impl Timeline { *guard = Some(WalReceiver::start( Arc::clone(self), WalReceiverConf { + protocol: self.resolve_wal_receiver_protocol(), wal_connect_timeout, lagging_wal_timeout, max_lsn_wal_lag, @@ -2647,6 +2671,7 @@ impl Timeline { // // NB: generation numbers naturally protect against this because they disambiguate // (1) and (4) + // TODO: this is basically a no-op now, should we remove it? self.remote_client.schedule_barrier()?; // Tenant::create_timeline will wait for these uploads to happen before returning, or // on retry. @@ -2702,20 +2727,23 @@ impl Timeline { { Some(cancel) => cancel.cancel(), None => { - let state = self.current_state(); - if matches!( - state, - TimelineState::Broken { .. } | TimelineState::Stopping - ) { - - // Can happen when timeline detail endpoint is used when deletion is ongoing (or its broken). - // Don't make noise. - } else { - warn!("unexpected: cancel_wait_for_background_loop_concurrency_limit_semaphore not set, priority-boosting of logical size calculation will not work"); - debug_assert!(false); + match self.current_state() { + TimelineState::Broken { .. } | TimelineState::Stopping => { + // Can happen when timeline detail endpoint is used when deletion is ongoing (or its broken). + // Don't make noise. + } + TimelineState::Loading => { + // Import does not return an activated timeline. + info!("discarding priority boost for logical size calculation because timeline is not yet active"); + } + TimelineState::Active => { + // activation should be setting the once cell + warn!("unexpected: cancel_wait_for_background_loop_concurrency_limit_semaphore not set, priority-boosting of logical size calculation will not work"); + debug_assert!(false); + } } } - }; + } } } @@ -3819,7 +3847,8 @@ impl Timeline { }; // Backpressure mechanism: wait with continuation of the flush loop until we have uploaded all layer files. - // This makes us refuse ingest until the new layers have been persisted to the remote. + // This makes us refuse ingest until the new layers have been persisted to the remote + let start = Instant::now(); self.remote_client .wait_completion() .await @@ -3832,6 +3861,8 @@ impl Timeline { FlushLayerError::Other(anyhow!(e).into()) } })?; + let duration = start.elapsed().as_secs_f64(); + self.metrics.flush_wait_upload_time_gauge_add(duration); // FIXME: between create_delta_layer and the scheduling of the upload in `update_metadata_file`, // a compaction can delete the file and then it won't be available for uploads any more. @@ -5886,7 +5917,7 @@ impl<'a> TimelineWriter<'a> { batch: SerializedValueBatch, ctx: &RequestContext, ) -> anyhow::Result<()> { - if batch.is_empty() { + if !batch.has_data() { return Ok(()); } diff --git a/pageserver/src/tenant/timeline/import_pgdata.rs b/pageserver/src/tenant/timeline/import_pgdata.rs new file mode 100644 index 0000000000..de56468580 --- /dev/null +++ b/pageserver/src/tenant/timeline/import_pgdata.rs @@ -0,0 +1,218 @@ +use std::sync::Arc; + +use anyhow::{bail, Context}; +use remote_storage::RemotePath; +use tokio_util::sync::CancellationToken; +use tracing::{info, info_span, Instrument}; +use utils::lsn::Lsn; + +use crate::{context::RequestContext, tenant::metadata::TimelineMetadata}; + +use super::Timeline; + +mod flow; +mod importbucket_client; +mod importbucket_format; +pub(crate) mod index_part_format; +pub(crate) mod upcall_api; + +pub async fn doit( + timeline: &Arc, + index_part: index_part_format::Root, + ctx: &RequestContext, + cancel: CancellationToken, +) -> anyhow::Result<()> { + let index_part_format::Root::V1(v1) = index_part; + let index_part_format::InProgress { + location, + idempotency_key, + started_at, + } = match v1 { + index_part_format::V1::Done(_) => return Ok(()), + index_part_format::V1::InProgress(in_progress) => in_progress, + }; + + let storage = importbucket_client::new(timeline.conf, &location, cancel.clone()).await?; + + info!("get spec early so we know we'll be able to upcall when done"); + let Some(spec) = storage.get_spec().await? else { + bail!("spec not found") + }; + + let upcall_client = + upcall_api::Client::new(timeline.conf, cancel.clone()).context("create upcall client")?; + + // + // send an early progress update to clean up k8s job early and generate potentially useful logs + // + info!("send early progress update"); + upcall_client + .send_progress_until_success(&spec) + .instrument(info_span!("early_progress_update")) + .await?; + + let status_prefix = RemotePath::from_string("status").unwrap(); + + // + // See if shard is done. + // TODO: incorporate generations into status key for split brain safety. Figure out together with checkpointing. + // + let shard_status_key = + status_prefix.join(format!("shard-{}", timeline.tenant_shard_id.shard_slug())); + let shard_status: Option = + storage.get_json(&shard_status_key).await?; + info!(?shard_status, "peeking shard status"); + if shard_status.map(|st| st.done).unwrap_or(false) { + info!("shard status indicates that the shard is done, skipping import"); + } else { + // TODO: checkpoint the progress into the IndexPart instead of restarting + // from the beginning. + + // + // Wipe the slate clean - the flow does not allow resuming. + // We can implement resuming in the future by checkpointing the progress into the IndexPart. + // + info!("wipe the slate clean"); + { + // TODO: do we need to hold GC lock for this? + let mut guard = timeline.layers.write().await; + assert!( + guard.layer_map()?.open_layer.is_none(), + "while importing, there should be no in-memory layer" // this just seems like a good place to assert it + ); + let all_layers_keys = guard.all_persistent_layers(); + let all_layers: Vec<_> = all_layers_keys + .iter() + .map(|key| guard.get_from_key(key)) + .collect(); + let open = guard.open_mut().context("open_mut")?; + + timeline.remote_client.schedule_gc_update(&all_layers)?; + open.finish_gc_timeline(&all_layers); + } + + // + // Wait for pgdata to finish uploading + // + info!("wait for pgdata to reach status 'done'"); + let pgdata_status_key = status_prefix.join("pgdata"); + loop { + let res = async { + let pgdata_status: Option = storage + .get_json(&pgdata_status_key) + .await + .context("get pgdata status")?; + info!(?pgdata_status, "peeking pgdata status"); + if pgdata_status.map(|st| st.done).unwrap_or(false) { + Ok(()) + } else { + Err(anyhow::anyhow!("pgdata not done yet")) + } + } + .await; + match res { + Ok(_) => break, + Err(err) => { + info!(?err, "indefintely waiting for pgdata to finish"); + if tokio::time::timeout(std::time::Duration::from_secs(10), cancel.cancelled()) + .await + .is_ok() + { + bail!("cancelled while waiting for pgdata"); + } + } + } + } + + // + // Do the import + // + info!("do the import"); + let control_file = storage.get_control_file().await?; + let base_lsn = control_file.base_lsn(); + + info!("update TimelineMetadata based on LSNs from control file"); + { + let pg_version = control_file.pg_version(); + let _ctx: &RequestContext = ctx; + async move { + // FIXME: The 'disk_consistent_lsn' should be the LSN at the *end* of the + // checkpoint record, and prev_record_lsn should point to its beginning. + // We should read the real end of the record from the WAL, but here we + // just fake it. + let disk_consistent_lsn = Lsn(base_lsn.0 + 8); + let prev_record_lsn = base_lsn; + let metadata = TimelineMetadata::new( + disk_consistent_lsn, + Some(prev_record_lsn), + None, // no ancestor + Lsn(0), // no ancestor lsn + base_lsn, // latest_gc_cutoff_lsn + base_lsn, // initdb_lsn + pg_version, + ); + + let _start_lsn = disk_consistent_lsn + 1; + + timeline + .remote_client + .schedule_index_upload_for_full_metadata_update(&metadata)?; + + timeline.remote_client.wait_completion().await?; + + anyhow::Ok(()) + } + } + .await?; + + flow::run( + timeline.clone(), + base_lsn, + control_file, + storage.clone(), + ctx, + ) + .await?; + + // + // Communicate that shard is done. + // + storage + .put_json( + &shard_status_key, + &importbucket_format::ShardStatus { done: true }, + ) + .await + .context("put shard status")?; + } + + // + // Ensure at-least-once deliver of the upcall to cplane + // before we mark the task as done and never come here again. + // + info!("send final progress update"); + upcall_client + .send_progress_until_success(&spec) + .instrument(info_span!("final_progress_update")) + .await?; + + // + // Mark as done in index_part. + // This makes subsequent timeline loads enter the normal load code path + // instead of spawning the import task and calling this here function. + // + info!("mark import as complete in index part"); + timeline + .remote_client + .schedule_index_upload_for_import_pgdata_state_update(Some(index_part_format::Root::V1( + index_part_format::V1::Done(index_part_format::Done { + idempotency_key, + started_at, + finished_at: chrono::Utc::now().naive_utc(), + }), + )))?; + + timeline.remote_client.wait_completion().await?; + + Ok(()) +} diff --git a/pageserver/src/tenant/timeline/import_pgdata/flow.rs b/pageserver/src/tenant/timeline/import_pgdata/flow.rs new file mode 100644 index 0000000000..cbd4168c06 --- /dev/null +++ b/pageserver/src/tenant/timeline/import_pgdata/flow.rs @@ -0,0 +1,798 @@ +//! Import a PGDATA directory into an empty root timeline. +//! +//! This module is adapted hackathon code by Heikki and Stas. +//! Other code in the parent module was written by Christian as part of a customer PoC. +//! +//! The hackathon code was producing image layer files as a free-standing program. +//! +//! It has been modified to +//! - run inside a running Pageserver, within the proper lifecycles of Timeline -> Tenant(Shard) +//! - => sharding-awareness: produce image layers with only the data relevant for this shard +//! - => S3 as the source for the PGDATA instead of local filesystem +//! +//! TODOs before productionization: +//! - ChunkProcessingJob size / ImportJob::total_size does not account for sharding. +//! => produced image layers likely too small. +//! - ChunkProcessingJob should cut up an ImportJob to hit exactly target image layer size. +//! - asserts / unwraps need to be replaced with errors +//! - don't trust remote objects will be small (=prevent OOMs in those cases) +//! - limit all in-memory buffers in size, or download to disk and read from there +//! - limit task concurrency +//! - generally play nice with other tenants in the system +//! - importbucket is different bucket than main pageserver storage, so, should be fine wrt S3 rate limits +//! - but concerns like network bandwidth, local disk write bandwidth, local disk capacity, etc +//! - integrate with layer eviction system +//! - audit for Tenant::cancel nor Timeline::cancel responsivity +//! - audit for Tenant/Timeline gate holding (we spawn tokio tasks during this flow!) +//! +//! An incomplete set of TODOs from the Hackathon: +//! - version-specific CheckPointData (=> pgv abstraction, already exists for regular walingest) + +use std::sync::Arc; + +use anyhow::{bail, ensure}; +use bytes::Bytes; + +use itertools::Itertools; +use pageserver_api::{ + key::{rel_block_to_key, rel_dir_to_key, rel_size_to_key, relmap_file_key, DBDIR_KEY}, + reltag::RelTag, + shard::ShardIdentity, +}; +use postgres_ffi::{pg_constants, relfile_utils::parse_relfilename, BLCKSZ}; +use tokio::task::JoinSet; +use tracing::{debug, info_span, instrument, Instrument}; + +use crate::{ + assert_u64_eq_usize::UsizeIsU64, + pgdatadir_mapping::{SlruSegmentDirectory, TwoPhaseDirectory}, +}; +use crate::{ + context::{DownloadBehavior, RequestContext}, + pgdatadir_mapping::{DbDirectory, RelDirectory}, + task_mgr::TaskKind, + tenant::storage_layer::{ImageLayerWriter, Layer}, +}; + +use pageserver_api::key::Key; +use pageserver_api::key::{ + slru_block_to_key, slru_dir_to_key, slru_segment_size_to_key, CHECKPOINT_KEY, CONTROLFILE_KEY, + TWOPHASEDIR_KEY, +}; +use pageserver_api::keyspace::singleton_range; +use pageserver_api::keyspace::{contiguous_range_len, is_contiguous_range}; +use pageserver_api::reltag::SlruKind; +use utils::bin_ser::BeSer; +use utils::lsn::Lsn; + +use std::collections::HashSet; +use std::ops::Range; + +use super::{ + importbucket_client::{ControlFile, RemoteStorageWrapper}, + Timeline, +}; + +use remote_storage::RemotePath; + +pub async fn run( + timeline: Arc, + pgdata_lsn: Lsn, + control_file: ControlFile, + storage: RemoteStorageWrapper, + ctx: &RequestContext, +) -> anyhow::Result<()> { + Flow { + timeline, + pgdata_lsn, + control_file, + tasks: Vec::new(), + storage, + } + .run(ctx) + .await +} + +struct Flow { + timeline: Arc, + pgdata_lsn: Lsn, + control_file: ControlFile, + tasks: Vec, + storage: RemoteStorageWrapper, +} + +impl Flow { + /// Perform the ingestion into [`Self::timeline`]. + /// Assumes the timeline is empty (= no layers). + pub async fn run(mut self, ctx: &RequestContext) -> anyhow::Result<()> { + let pgdata_lsn = Lsn(self.control_file.control_file_data().checkPoint).align(); + + self.pgdata_lsn = pgdata_lsn; + + let datadir = PgDataDir::new(&self.storage).await?; + + // Import dbdir (00:00:00 keyspace) + // This is just constructed here, but will be written to the image layer in the first call to import_db() + let dbdir_buf = Bytes::from(DbDirectory::ser(&DbDirectory { + dbdirs: datadir + .dbs + .iter() + .map(|db| ((db.spcnode, db.dboid), true)) + .collect(), + })?); + self.tasks + .push(ImportSingleKeyTask::new(DBDIR_KEY, dbdir_buf).into()); + + // Import databases (00:spcnode:dbnode keyspace for each db) + for db in datadir.dbs { + self.import_db(&db).await?; + } + + // Import SLRUs + + // pg_xact (01:00 keyspace) + self.import_slru(SlruKind::Clog, &self.storage.pgdata().join("pg_xact")) + .await?; + // pg_multixact/members (01:01 keyspace) + self.import_slru( + SlruKind::MultiXactMembers, + &self.storage.pgdata().join("pg_multixact/members"), + ) + .await?; + // pg_multixact/offsets (01:02 keyspace) + self.import_slru( + SlruKind::MultiXactOffsets, + &self.storage.pgdata().join("pg_multixact/offsets"), + ) + .await?; + + // Import pg_twophase. + // TODO: as empty + let twophasedir_buf = TwoPhaseDirectory::ser(&TwoPhaseDirectory { + xids: HashSet::new(), + })?; + self.tasks + .push(AnyImportTask::SingleKey(ImportSingleKeyTask::new( + TWOPHASEDIR_KEY, + Bytes::from(twophasedir_buf), + ))); + + // Controlfile, checkpoint + self.tasks + .push(AnyImportTask::SingleKey(ImportSingleKeyTask::new( + CONTROLFILE_KEY, + self.control_file.control_file_buf().clone(), + ))); + + let checkpoint_buf = self + .control_file + .control_file_data() + .checkPointCopy + .encode()?; + self.tasks + .push(AnyImportTask::SingleKey(ImportSingleKeyTask::new( + CHECKPOINT_KEY, + checkpoint_buf, + ))); + + // Assigns parts of key space to later parallel jobs + let mut last_end_key = Key::MIN; + let mut current_chunk = Vec::new(); + let mut current_chunk_size: usize = 0; + let mut parallel_jobs = Vec::new(); + for task in std::mem::take(&mut self.tasks).into_iter() { + if current_chunk_size + task.total_size() > 1024 * 1024 * 1024 { + let key_range = last_end_key..task.key_range().start; + parallel_jobs.push(ChunkProcessingJob::new( + key_range.clone(), + std::mem::take(&mut current_chunk), + &self, + )); + last_end_key = key_range.end; + current_chunk_size = 0; + } + current_chunk_size += task.total_size(); + current_chunk.push(task); + } + parallel_jobs.push(ChunkProcessingJob::new( + last_end_key..Key::MAX, + current_chunk, + &self, + )); + + // Start all jobs simultaneosly + let mut work = JoinSet::new(); + // TODO: semaphore? + for job in parallel_jobs { + let ctx: RequestContext = + ctx.detached_child(TaskKind::ImportPgdata, DownloadBehavior::Error); + work.spawn(async move { job.run(&ctx).await }.instrument(info_span!("parallel_job"))); + } + let mut results = Vec::new(); + while let Some(result) = work.join_next().await { + match result { + Ok(res) => { + results.push(res); + } + Err(_joinset_err) => { + results.push(Err(anyhow::anyhow!( + "parallel job panicked or cancelled, check pageserver logs" + ))); + } + } + } + + if results.iter().all(|r| r.is_ok()) { + Ok(()) + } else { + let mut msg = String::new(); + for result in results { + if let Err(err) = result { + msg.push_str(&format!("{err:?}\n\n")); + } + } + bail!("Some parallel jobs failed:\n\n{msg}"); + } + } + + #[instrument(level = tracing::Level::DEBUG, skip_all, fields(dboid=%db.dboid, tablespace=%db.spcnode, path=%db.path))] + async fn import_db(&mut self, db: &PgDataDirDb) -> anyhow::Result<()> { + debug!("start"); + scopeguard::defer! { + debug!("return"); + } + + // Import relmap (00:spcnode:dbnode:00:*:00) + let relmap_key = relmap_file_key(db.spcnode, db.dboid); + debug!("Constructing relmap entry, key {relmap_key}"); + let relmap_path = db.path.join("pg_filenode.map"); + let relmap_buf = self.storage.get(&relmap_path).await?; + self.tasks + .push(AnyImportTask::SingleKey(ImportSingleKeyTask::new( + relmap_key, relmap_buf, + ))); + + // Import reldir (00:spcnode:dbnode:00:*:01) + let reldir_key = rel_dir_to_key(db.spcnode, db.dboid); + debug!("Constructing reldirs entry, key {reldir_key}"); + let reldir_buf = RelDirectory::ser(&RelDirectory { + rels: db + .files + .iter() + .map(|f| (f.rel_tag.relnode, f.rel_tag.forknum)) + .collect(), + })?; + self.tasks + .push(AnyImportTask::SingleKey(ImportSingleKeyTask::new( + reldir_key, + Bytes::from(reldir_buf), + ))); + + // Import data (00:spcnode:dbnode:reloid:fork:blk) and set sizes for each last + // segment in a given relation (00:spcnode:dbnode:reloid:fork:ff) + for file in &db.files { + debug!(%file.path, %file.filesize, "importing file"); + let len = file.filesize; + ensure!(len % 8192 == 0); + let start_blk: u32 = file.segno * (1024 * 1024 * 1024 / 8192); + let start_key = rel_block_to_key(file.rel_tag, start_blk); + let end_key = rel_block_to_key(file.rel_tag, start_blk + (len / 8192) as u32); + self.tasks + .push(AnyImportTask::RelBlocks(ImportRelBlocksTask::new( + *self.timeline.get_shard_identity(), + start_key..end_key, + &file.path, + self.storage.clone(), + ))); + + // Set relsize for the last segment (00:spcnode:dbnode:reloid:fork:ff) + if let Some(nblocks) = file.nblocks { + let size_key = rel_size_to_key(file.rel_tag); + //debug!("Setting relation size (path={path}, rel_tag={rel_tag}, segno={segno}) to {nblocks}, key {size_key}"); + let buf = nblocks.to_le_bytes(); + self.tasks + .push(AnyImportTask::SingleKey(ImportSingleKeyTask::new( + size_key, + Bytes::from(buf.to_vec()), + ))); + } + } + + Ok(()) + } + + async fn import_slru(&mut self, kind: SlruKind, path: &RemotePath) -> anyhow::Result<()> { + let segments = self.storage.listfilesindir(path).await?; + let segments: Vec<(String, u32, usize)> = segments + .into_iter() + .filter_map(|(path, size)| { + let filename = path.object_name()?; + let segno = u32::from_str_radix(filename, 16).ok()?; + Some((filename.to_string(), segno, size)) + }) + .collect(); + + // Write SlruDir + let slrudir_key = slru_dir_to_key(kind); + let segnos: HashSet = segments + .iter() + .map(|(_path, segno, _size)| *segno) + .collect(); + let slrudir = SlruSegmentDirectory { segments: segnos }; + let slrudir_buf = SlruSegmentDirectory::ser(&slrudir)?; + self.tasks + .push(AnyImportTask::SingleKey(ImportSingleKeyTask::new( + slrudir_key, + Bytes::from(slrudir_buf), + ))); + + for (segpath, segno, size) in segments { + // SlruSegBlocks for each segment + let p = path.join(&segpath); + let file_size = size; + ensure!(file_size % 8192 == 0); + let nblocks = u32::try_from(file_size / 8192)?; + let start_key = slru_block_to_key(kind, segno, 0); + let end_key = slru_block_to_key(kind, segno, nblocks); + debug!(%p, segno=%segno, %size, %start_key, %end_key, "scheduling SLRU segment"); + self.tasks + .push(AnyImportTask::SlruBlocks(ImportSlruBlocksTask::new( + *self.timeline.get_shard_identity(), + start_key..end_key, + &p, + self.storage.clone(), + ))); + + // Followed by SlruSegSize + let segsize_key = slru_segment_size_to_key(kind, segno); + let segsize_buf = nblocks.to_le_bytes(); + self.tasks + .push(AnyImportTask::SingleKey(ImportSingleKeyTask::new( + segsize_key, + Bytes::copy_from_slice(&segsize_buf), + ))); + } + Ok(()) + } +} + +// +// dbdir iteration tools +// + +struct PgDataDir { + pub dbs: Vec, // spcnode, dboid, path +} + +struct PgDataDirDb { + pub spcnode: u32, + pub dboid: u32, + pub path: RemotePath, + pub files: Vec, +} + +struct PgDataDirDbFile { + pub path: RemotePath, + pub rel_tag: RelTag, + pub segno: u32, + pub filesize: usize, + // Cummulative size of the given fork, set only for the last segment of that fork + pub nblocks: Option, +} + +impl PgDataDir { + async fn new(storage: &RemoteStorageWrapper) -> anyhow::Result { + let datadir_path = storage.pgdata(); + // Import ordinary databases, DEFAULTTABLESPACE_OID is smaller than GLOBALTABLESPACE_OID, so import them first + // Traverse database in increasing oid order + + let basedir = &datadir_path.join("base"); + let db_oids: Vec<_> = storage + .listdir(basedir) + .await? + .into_iter() + .filter_map(|path| path.object_name().and_then(|name| name.parse::().ok())) + .sorted() + .collect(); + debug!(?db_oids, "found databases"); + let mut databases = Vec::new(); + for dboid in db_oids { + databases.push( + PgDataDirDb::new( + storage, + &basedir.join(dboid.to_string()), + pg_constants::DEFAULTTABLESPACE_OID, + dboid, + &datadir_path, + ) + .await?, + ); + } + + // special case for global catalogs + databases.push( + PgDataDirDb::new( + storage, + &datadir_path.join("global"), + postgres_ffi::pg_constants::GLOBALTABLESPACE_OID, + 0, + &datadir_path, + ) + .await?, + ); + + databases.sort_by_key(|db| (db.spcnode, db.dboid)); + + Ok(Self { dbs: databases }) + } +} + +impl PgDataDirDb { + #[instrument(level = tracing::Level::DEBUG, skip_all, fields(%dboid, %db_path))] + async fn new( + storage: &RemoteStorageWrapper, + db_path: &RemotePath, + spcnode: u32, + dboid: u32, + datadir_path: &RemotePath, + ) -> anyhow::Result { + let mut files: Vec = storage + .listfilesindir(db_path) + .await? + .into_iter() + .filter_map(|(path, size)| { + debug!(%path, %size, "found file in dbdir"); + path.object_name().and_then(|name| { + // returns (relnode, forknum, segno) + parse_relfilename(name).ok().map(|x| (size, x)) + }) + }) + .sorted_by_key(|(_, relfilename)| *relfilename) + .map(|(filesize, (relnode, forknum, segno))| { + let rel_tag = RelTag { + spcnode, + dbnode: dboid, + relnode, + forknum, + }; + + let path = datadir_path.join(rel_tag.to_segfile_name(segno)); + assert!(filesize % BLCKSZ as usize == 0); // TODO: this should result in an error + let nblocks = filesize / BLCKSZ as usize; + + PgDataDirDbFile { + path, + filesize, + rel_tag, + segno, + nblocks: Some(nblocks), // first non-cummulative sizes + } + }) + .collect(); + + // Set cummulative sizes. Do all of that math here, so that later we could easier + // parallelize over segments and know with which segments we need to write relsize + // entry. + let mut cumulative_nblocks: usize = 0; + let mut prev_rel_tag: Option = None; + for i in 0..files.len() { + if prev_rel_tag == Some(files[i].rel_tag) { + cumulative_nblocks += files[i].nblocks.unwrap(); + } else { + cumulative_nblocks = files[i].nblocks.unwrap(); + } + + files[i].nblocks = if i == files.len() - 1 || files[i + 1].rel_tag != files[i].rel_tag { + Some(cumulative_nblocks) + } else { + None + }; + + prev_rel_tag = Some(files[i].rel_tag); + } + + Ok(PgDataDirDb { + files, + path: db_path.clone(), + spcnode, + dboid, + }) + } +} + +trait ImportTask { + fn key_range(&self) -> Range; + + fn total_size(&self) -> usize { + // TODO: revisit this + if is_contiguous_range(&self.key_range()) { + contiguous_range_len(&self.key_range()) as usize * 8192 + } else { + u32::MAX as usize + } + } + + async fn doit( + self, + layer_writer: &mut ImageLayerWriter, + ctx: &RequestContext, + ) -> anyhow::Result; +} + +struct ImportSingleKeyTask { + key: Key, + buf: Bytes, +} + +impl ImportSingleKeyTask { + fn new(key: Key, buf: Bytes) -> Self { + ImportSingleKeyTask { key, buf } + } +} + +impl ImportTask for ImportSingleKeyTask { + fn key_range(&self) -> Range { + singleton_range(self.key) + } + + async fn doit( + self, + layer_writer: &mut ImageLayerWriter, + ctx: &RequestContext, + ) -> anyhow::Result { + layer_writer.put_image(self.key, self.buf, ctx).await?; + Ok(1) + } +} + +struct ImportRelBlocksTask { + shard_identity: ShardIdentity, + key_range: Range, + path: RemotePath, + storage: RemoteStorageWrapper, +} + +impl ImportRelBlocksTask { + fn new( + shard_identity: ShardIdentity, + key_range: Range, + path: &RemotePath, + storage: RemoteStorageWrapper, + ) -> Self { + ImportRelBlocksTask { + shard_identity, + key_range, + path: path.clone(), + storage, + } + } +} + +impl ImportTask for ImportRelBlocksTask { + fn key_range(&self) -> Range { + self.key_range.clone() + } + + #[instrument(level = tracing::Level::DEBUG, skip_all, fields(%self.path))] + async fn doit( + self, + layer_writer: &mut ImageLayerWriter, + ctx: &RequestContext, + ) -> anyhow::Result { + debug!("Importing relation file"); + + let (rel_tag, start_blk) = self.key_range.start.to_rel_block()?; + let (rel_tag_end, end_blk) = self.key_range.end.to_rel_block()?; + assert_eq!(rel_tag, rel_tag_end); + + let ranges = (start_blk..end_blk) + .enumerate() + .filter_map(|(i, blknum)| { + let key = rel_block_to_key(rel_tag, blknum); + if self.shard_identity.is_key_disposable(&key) { + return None; + } + let file_offset = i.checked_mul(8192).unwrap(); + Some(( + vec![key], + file_offset, + file_offset.checked_add(8192).unwrap(), + )) + }) + .coalesce(|(mut acc, acc_start, acc_end), (mut key, start, end)| { + assert_eq!(key.len(), 1); + assert!(!acc.is_empty()); + assert!(acc_end > acc_start); + if acc_end == start /* TODO additional max range check here, to limit memory consumption per task to X */ { + acc.push(key.pop().unwrap()); + Ok((acc, acc_start, end)) + } else { + Err(((acc, acc_start, acc_end), (key, start, end))) + } + }); + + let mut nimages = 0; + for (keys, range_start, range_end) in ranges { + let range_buf = self + .storage + .get_range(&self.path, range_start.into_u64(), range_end.into_u64()) + .await?; + let mut buf = Bytes::from(range_buf); + // TODO: batched writes + for key in keys { + let image = buf.split_to(8192); + layer_writer.put_image(key, image, ctx).await?; + nimages += 1; + } + } + + Ok(nimages) + } +} + +struct ImportSlruBlocksTask { + shard_identity: ShardIdentity, + key_range: Range, + path: RemotePath, + storage: RemoteStorageWrapper, +} + +impl ImportSlruBlocksTask { + fn new( + shard_identity: ShardIdentity, + key_range: Range, + path: &RemotePath, + storage: RemoteStorageWrapper, + ) -> Self { + ImportSlruBlocksTask { + shard_identity, + key_range, + path: path.clone(), + storage, + } + } +} + +impl ImportTask for ImportSlruBlocksTask { + fn key_range(&self) -> Range { + self.key_range.clone() + } + + async fn doit( + self, + layer_writer: &mut ImageLayerWriter, + ctx: &RequestContext, + ) -> anyhow::Result { + debug!("Importing SLRU segment file {}", self.path); + let buf = self.storage.get(&self.path).await?; + + let (kind, segno, start_blk) = self.key_range.start.to_slru_block()?; + let (_kind, _segno, end_blk) = self.key_range.end.to_slru_block()?; + let mut blknum = start_blk; + let mut nimages = 0; + let mut file_offset = 0; + while blknum < end_blk { + let key = slru_block_to_key(kind, segno, blknum); + assert!( + !self.shard_identity.is_key_disposable(&key), + "SLRU keys need to go into every shard" + ); + let buf = &buf[file_offset..(file_offset + 8192)]; + file_offset += 8192; + layer_writer + .put_image(key, Bytes::copy_from_slice(buf), ctx) + .await?; + blknum += 1; + nimages += 1; + } + Ok(nimages) + } +} + +enum AnyImportTask { + SingleKey(ImportSingleKeyTask), + RelBlocks(ImportRelBlocksTask), + SlruBlocks(ImportSlruBlocksTask), +} + +impl ImportTask for AnyImportTask { + fn key_range(&self) -> Range { + match self { + Self::SingleKey(t) => t.key_range(), + Self::RelBlocks(t) => t.key_range(), + Self::SlruBlocks(t) => t.key_range(), + } + } + /// returns the number of images put into the `layer_writer` + async fn doit( + self, + layer_writer: &mut ImageLayerWriter, + ctx: &RequestContext, + ) -> anyhow::Result { + match self { + Self::SingleKey(t) => t.doit(layer_writer, ctx).await, + Self::RelBlocks(t) => t.doit(layer_writer, ctx).await, + Self::SlruBlocks(t) => t.doit(layer_writer, ctx).await, + } + } +} + +impl From for AnyImportTask { + fn from(t: ImportSingleKeyTask) -> Self { + Self::SingleKey(t) + } +} + +impl From for AnyImportTask { + fn from(t: ImportRelBlocksTask) -> Self { + Self::RelBlocks(t) + } +} + +impl From for AnyImportTask { + fn from(t: ImportSlruBlocksTask) -> Self { + Self::SlruBlocks(t) + } +} + +struct ChunkProcessingJob { + timeline: Arc, + range: Range, + tasks: Vec, + + pgdata_lsn: Lsn, +} + +impl ChunkProcessingJob { + fn new(range: Range, tasks: Vec, env: &Flow) -> Self { + assert!(env.pgdata_lsn.is_valid()); + Self { + timeline: env.timeline.clone(), + range, + tasks, + pgdata_lsn: env.pgdata_lsn, + } + } + + async fn run(self, ctx: &RequestContext) -> anyhow::Result<()> { + let mut writer = ImageLayerWriter::new( + self.timeline.conf, + self.timeline.timeline_id, + self.timeline.tenant_shard_id, + &self.range, + self.pgdata_lsn, + ctx, + ) + .await?; + + let mut nimages = 0; + for task in self.tasks { + nimages += task.doit(&mut writer, ctx).await?; + } + + let resident_layer = if nimages > 0 { + let (desc, path) = writer.finish(ctx).await?; + Layer::finish_creating(self.timeline.conf, &self.timeline, desc, &path)? + } else { + // dropping the writer cleans up + return Ok(()); + }; + + // this is sharing the same code as create_image_layers + let mut guard = self.timeline.layers.write().await; + guard + .open_mut()? + .track_new_image_layers(&[resident_layer.clone()], &self.timeline.metrics); + crate::tenant::timeline::drop_wlock(guard); + + // Schedule the layer for upload but don't add barriers such as + // wait for completion or index upload, so we don't inhibit upload parallelism. + // TODO: limit upload parallelism somehow (e.g. by limiting concurrency of jobs?) + // TODO: or regulate parallelism by upload queue depth? Prob should happen at a higher level. + self.timeline + .remote_client + .schedule_layer_file_upload(resident_layer)?; + + Ok(()) + } +} diff --git a/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs b/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs new file mode 100644 index 0000000000..8d5ab1780f --- /dev/null +++ b/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs @@ -0,0 +1,315 @@ +use std::{ops::Bound, sync::Arc}; + +use anyhow::Context; +use bytes::Bytes; +use postgres_ffi::ControlFileData; +use remote_storage::{ + Download, DownloadError, DownloadOpts, GenericRemoteStorage, Listing, ListingObject, RemotePath, +}; +use serde::de::DeserializeOwned; +use tokio_util::sync::CancellationToken; +use tracing::{debug, info, instrument}; +use utils::lsn::Lsn; + +use crate::{assert_u64_eq_usize::U64IsUsize, config::PageServerConf}; + +use super::{importbucket_format, index_part_format}; + +pub async fn new( + conf: &'static PageServerConf, + location: &index_part_format::Location, + cancel: CancellationToken, +) -> Result { + // FIXME: we probably want some timeout, and we might be able to assume the max file + // size on S3 is 1GiB (postgres segment size). But the problem is that the individual + // downloaders don't know enough about concurrent downloads to make a guess on the + // expected bandwidth and resulting best timeout. + let timeout = std::time::Duration::from_secs(24 * 60 * 60); + let location_storage = match location { + #[cfg(feature = "testing")] + index_part_format::Location::LocalFs { path } => { + GenericRemoteStorage::LocalFs(remote_storage::LocalFs::new(path.clone(), timeout)?) + } + index_part_format::Location::AwsS3 { + region, + bucket, + key, + } => { + // TODO: think about security implications of letting the client specify the bucket & prefix. + // It's the most flexible right now, but, possibly we want to move bucket name into PS conf + // and force the timeline_id into the prefix? + GenericRemoteStorage::AwsS3(Arc::new( + remote_storage::S3Bucket::new( + &remote_storage::S3Config { + bucket_name: bucket.clone(), + prefix_in_bucket: Some(key.clone()), + bucket_region: region.clone(), + endpoint: conf + .import_pgdata_aws_endpoint_url + .clone() + .map(|url| url.to_string()), // by specifying None here, remote_storage/aws-sdk-rust will infer from env + concurrency_limit: 100.try_into().unwrap(), // TODO: think about this + max_keys_per_list_response: Some(1000), // TODO: think about this + upload_storage_class: None, // irrelevant + }, + timeout, + ) + .await + .context("setup s3 bucket")?, + )) + } + }; + let storage_wrapper = RemoteStorageWrapper::new(location_storage, cancel); + Ok(storage_wrapper) +} + +/// Wrap [`remote_storage`] APIs to make it look a bit more like a filesystem API +/// such as [`tokio::fs`], which was used in the original implementation of the import code. +#[derive(Clone)] +pub struct RemoteStorageWrapper { + storage: GenericRemoteStorage, + cancel: CancellationToken, +} + +impl RemoteStorageWrapper { + pub fn new(storage: GenericRemoteStorage, cancel: CancellationToken) -> Self { + Self { storage, cancel } + } + + #[instrument(level = tracing::Level::DEBUG, skip_all, fields(%path))] + pub async fn listfilesindir( + &self, + path: &RemotePath, + ) -> Result, DownloadError> { + assert!( + path.object_name().is_some(), + "must specify dirname, without trailing slash" + ); + let path = path.add_trailing_slash(); + + let res = crate::tenant::remote_timeline_client::download::download_retry_forever( + || async { + let Listing { keys, prefixes: _ } = self + .storage + .list( + Some(&path), + remote_storage::ListingMode::WithDelimiter, + None, + &self.cancel, + ) + .await?; + let res = keys + .into_iter() + .map(|ListingObject { key, size, .. }| (key, size.into_usize())) + .collect(); + Ok(res) + }, + &format!("listfilesindir {path:?}"), + &self.cancel, + ) + .await; + debug!(?res, "returning"); + res + } + + #[instrument(level = tracing::Level::DEBUG, skip_all, fields(%path))] + pub async fn listdir(&self, path: &RemotePath) -> Result, DownloadError> { + assert!( + path.object_name().is_some(), + "must specify dirname, without trailing slash" + ); + let path = path.add_trailing_slash(); + + let res = crate::tenant::remote_timeline_client::download::download_retry_forever( + || async { + let Listing { keys, prefixes } = self + .storage + .list( + Some(&path), + remote_storage::ListingMode::WithDelimiter, + None, + &self.cancel, + ) + .await?; + let res = keys + .into_iter() + .map(|ListingObject { key, .. }| key) + .chain(prefixes.into_iter()) + .collect(); + Ok(res) + }, + &format!("listdir {path:?}"), + &self.cancel, + ) + .await; + debug!(?res, "returning"); + res + } + + #[instrument(level = tracing::Level::DEBUG, skip_all, fields(%path))] + pub async fn get(&self, path: &RemotePath) -> Result { + let res = crate::tenant::remote_timeline_client::download::download_retry_forever( + || async { + let Download { + download_stream, .. + } = self + .storage + .download(path, &DownloadOpts::default(), &self.cancel) + .await?; + let mut reader = tokio_util::io::StreamReader::new(download_stream); + + // XXX optimize this, can we get the capacity hint from somewhere? + let mut buf = Vec::new(); + tokio::io::copy_buf(&mut reader, &mut buf).await?; + Ok(Bytes::from(buf)) + }, + &format!("download {path:?}"), + &self.cancel, + ) + .await; + debug!(len = res.as_ref().ok().map(|buf| buf.len()), "done"); + res + } + + pub async fn get_spec(&self) -> Result, anyhow::Error> { + self.get_json(&RemotePath::from_string("spec.json").unwrap()) + .await + .context("get spec") + } + + #[instrument(level = tracing::Level::DEBUG, skip_all, fields(%path))] + pub async fn get_json( + &self, + path: &RemotePath, + ) -> Result, DownloadError> { + let buf = match self.get(path).await { + Ok(buf) => buf, + Err(DownloadError::NotFound) => return Ok(None), + Err(err) => return Err(err), + }; + let res = serde_json::from_slice(&buf) + .context("serialize") + // TODO: own error type + .map_err(DownloadError::Other)?; + Ok(Some(res)) + } + + #[instrument(level = tracing::Level::DEBUG, skip_all, fields(%path))] + pub async fn put_json(&self, path: &RemotePath, value: &T) -> anyhow::Result<()> + where + T: serde::Serialize, + { + let buf = serde_json::to_vec(value)?; + let bytes = Bytes::from(buf); + utils::backoff::retry( + || async { + let size = bytes.len(); + let bytes = futures::stream::once(futures::future::ready(Ok(bytes.clone()))); + self.storage + .upload_storage_object(bytes, size, path, &self.cancel) + .await + }, + remote_storage::TimeoutOrCancel::caused_by_cancel, + 1, + u32::MAX, + &format!("put json {path}"), + &self.cancel, + ) + .await + .expect("practically infinite retries") + } + + #[instrument(level = tracing::Level::DEBUG, skip_all, fields(%path))] + pub async fn get_range( + &self, + path: &RemotePath, + start_inclusive: u64, + end_exclusive: u64, + ) -> Result, DownloadError> { + let len = end_exclusive + .checked_sub(start_inclusive) + .unwrap() + .into_usize(); + let res = crate::tenant::remote_timeline_client::download::download_retry_forever( + || async { + let Download { + download_stream, .. + } = self + .storage + .download( + path, + &DownloadOpts { + etag: None, + byte_start: Bound::Included(start_inclusive), + byte_end: Bound::Excluded(end_exclusive) + }, + &self.cancel) + .await?; + let mut reader = tokio_util::io::StreamReader::new(download_stream); + + let mut buf = Vec::with_capacity(len); + tokio::io::copy_buf(&mut reader, &mut buf).await?; + Ok(buf) + }, + &format!("download range len=0x{len:x} [0x{start_inclusive:x},0x{end_exclusive:x}) from {path:?}"), + &self.cancel, + ) + .await; + debug!(len = res.as_ref().ok().map(|buf| buf.len()), "done"); + res + } + + pub fn pgdata(&self) -> RemotePath { + RemotePath::from_string("pgdata").unwrap() + } + + pub async fn get_control_file(&self) -> Result { + let control_file_path = self.pgdata().join("global/pg_control"); + info!("get control file from {control_file_path}"); + let control_file_buf = self.get(&control_file_path).await?; + ControlFile::new(control_file_buf) + } +} + +pub struct ControlFile { + control_file_data: ControlFileData, + control_file_buf: Bytes, +} + +impl ControlFile { + pub(crate) fn new(control_file_buf: Bytes) -> Result { + // XXX ControlFileData is version-specific, we're always using v14 here. v17 had changes. + let control_file_data = ControlFileData::decode(&control_file_buf)?; + let control_file = ControlFile { + control_file_data, + control_file_buf, + }; + control_file.try_pg_version()?; // so that we can offer infallible pg_version() + Ok(control_file) + } + pub(crate) fn base_lsn(&self) -> Lsn { + Lsn(self.control_file_data.checkPoint).align() + } + pub(crate) fn pg_version(&self) -> u32 { + self.try_pg_version() + .expect("prepare() checks that try_pg_version doesn't error") + } + pub(crate) fn control_file_data(&self) -> &ControlFileData { + &self.control_file_data + } + pub(crate) fn control_file_buf(&self) -> &Bytes { + &self.control_file_buf + } + fn try_pg_version(&self) -> anyhow::Result { + Ok(match self.control_file_data.catalog_version_no { + // thesea are from catversion.h + 202107181 => 14, + 202209061 => 15, + 202307071 => 16, + /* XXX pg17 */ + catversion => { + anyhow::bail!("unrecognized catalog version {catversion}") + } + }) + } +} diff --git a/pageserver/src/tenant/timeline/import_pgdata/importbucket_format.rs b/pageserver/src/tenant/timeline/import_pgdata/importbucket_format.rs new file mode 100644 index 0000000000..04ba3c6f1f --- /dev/null +++ b/pageserver/src/tenant/timeline/import_pgdata/importbucket_format.rs @@ -0,0 +1,20 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)] +pub struct PgdataStatus { + pub done: bool, + // TODO: remaining fields +} + +#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)] +pub struct ShardStatus { + pub done: bool, + // TODO: remaining fields +} + +// TODO: dedupe with fast_import code +#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)] +pub struct Spec { + pub project_id: String, + pub branch_id: String, +} diff --git a/pageserver/src/tenant/timeline/import_pgdata/index_part_format.rs b/pageserver/src/tenant/timeline/import_pgdata/index_part_format.rs new file mode 100644 index 0000000000..310d97a6a9 --- /dev/null +++ b/pageserver/src/tenant/timeline/import_pgdata/index_part_format.rs @@ -0,0 +1,68 @@ +use serde::{Deserialize, Serialize}; + +#[cfg(feature = "testing")] +use camino::Utf8PathBuf; + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub enum Root { + V1(V1), +} +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub enum V1 { + InProgress(InProgress), + Done(Done), +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +#[serde(transparent)] +pub struct IdempotencyKey(String); + +impl IdempotencyKey { + pub fn new(s: String) -> Self { + Self(s) + } +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct InProgress { + pub idempotency_key: IdempotencyKey, + pub location: Location, + pub started_at: chrono::NaiveDateTime, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub struct Done { + pub idempotency_key: IdempotencyKey, + pub started_at: chrono::NaiveDateTime, + pub finished_at: chrono::NaiveDateTime, +} + +#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)] +pub enum Location { + #[cfg(feature = "testing")] + LocalFs { path: Utf8PathBuf }, + AwsS3 { + region: String, + bucket: String, + key: String, + }, +} + +impl Root { + pub fn is_done(&self) -> bool { + match self { + Root::V1(v1) => match v1 { + V1::Done(_) => true, + V1::InProgress(_) => false, + }, + } + } + pub fn idempotency_key(&self) -> &IdempotencyKey { + match self { + Root::V1(v1) => match v1 { + V1::InProgress(in_progress) => &in_progress.idempotency_key, + V1::Done(done) => &done.idempotency_key, + }, + } + } +} diff --git a/pageserver/src/tenant/timeline/import_pgdata/upcall_api.rs b/pageserver/src/tenant/timeline/import_pgdata/upcall_api.rs new file mode 100644 index 0000000000..c5210f9a30 --- /dev/null +++ b/pageserver/src/tenant/timeline/import_pgdata/upcall_api.rs @@ -0,0 +1,119 @@ +//! FIXME: most of this is copy-paste from mgmt_api.rs ; dedupe into a `reqwest_utils::Client` crate. +use pageserver_client::mgmt_api::{Error, ResponseErrorMessageExt}; +use serde::{Deserialize, Serialize}; +use tokio_util::sync::CancellationToken; +use tracing::error; + +use crate::config::PageServerConf; +use reqwest::Method; + +use super::importbucket_format::Spec; + +pub struct Client { + base_url: String, + authorization_header: Option, + client: reqwest::Client, + cancel: CancellationToken, +} + +pub type Result = std::result::Result; + +#[derive(Serialize, Deserialize, Debug)] +struct ImportProgressRequest { + // no fields yet, not sure if there every will be any +} + +#[derive(Serialize, Deserialize, Debug)] +struct ImportProgressResponse { + // we don't care +} + +impl Client { + pub fn new(conf: &PageServerConf, cancel: CancellationToken) -> anyhow::Result { + let Some(ref base_url) = conf.import_pgdata_upcall_api else { + anyhow::bail!("import_pgdata_upcall_api is not configured") + }; + Ok(Self { + base_url: base_url.to_string(), + client: reqwest::Client::new(), + cancel, + authorization_header: conf + .import_pgdata_upcall_api_token + .as_ref() + .map(|secret_string| secret_string.get_contents()) + .map(|jwt| format!("Bearer {jwt}")), + }) + } + + fn start_request( + &self, + method: Method, + uri: U, + ) -> reqwest::RequestBuilder { + let req = self.client.request(method, uri); + if let Some(value) = &self.authorization_header { + req.header(reqwest::header::AUTHORIZATION, value) + } else { + req + } + } + + async fn request_noerror( + &self, + method: Method, + uri: U, + body: B, + ) -> Result { + self.start_request(method, uri) + .json(&body) + .send() + .await + .map_err(Error::ReceiveBody) + } + + async fn request( + &self, + method: Method, + uri: U, + body: B, + ) -> Result { + let res = self.request_noerror(method, uri, body).await?; + let response = res.error_from_body().await?; + Ok(response) + } + + pub async fn send_progress_once(&self, spec: &Spec) -> Result<()> { + let url = format!( + "{}/projects/{}/branches/{}/import_progress", + self.base_url, spec.project_id, spec.branch_id + ); + let ImportProgressResponse {} = self + .request(Method::POST, url, &ImportProgressRequest {}) + .await? + .json() + .await + .map_err(Error::ReceiveBody)?; + Ok(()) + } + + pub async fn send_progress_until_success(&self, spec: &Spec) -> anyhow::Result<()> { + loop { + match self.send_progress_once(spec).await { + Ok(()) => return Ok(()), + Err(Error::Cancelled) => return Err(anyhow::anyhow!("cancelled")), + Err(err) => { + error!(?err, "error sending progress, retrying"); + if tokio::time::timeout( + std::time::Duration::from_secs(10), + self.cancel.cancelled(), + ) + .await + .is_ok() + { + anyhow::bail!("cancelled while sending early progress update"); + } + } + } + } + } +} diff --git a/pageserver/src/tenant/timeline/offload.rs b/pageserver/src/tenant/timeline/offload.rs index 3595d743bc..3bfbfb5061 100644 --- a/pageserver/src/tenant/timeline/offload.rs +++ b/pageserver/src/tenant/timeline/offload.rs @@ -58,7 +58,7 @@ pub(crate) async fn offload_timeline( } // Now that the Timeline is in Stopping state, request all the related tasks to shut down. - timeline.shutdown(super::ShutdownMode::Flush).await; + timeline.shutdown(super::ShutdownMode::Reload).await; // TODO extend guard mechanism above with method // to make deletions possible while offloading is in progress diff --git a/pageserver/src/tenant/timeline/uninit.rs b/pageserver/src/tenant/timeline/uninit.rs index a93bdde3f8..80a09b4840 100644 --- a/pageserver/src/tenant/timeline/uninit.rs +++ b/pageserver/src/tenant/timeline/uninit.rs @@ -3,7 +3,7 @@ use std::{collections::hash_map::Entry, fs, sync::Arc}; use anyhow::Context; use camino::Utf8PathBuf; use tracing::{error, info, info_span}; -use utils::{fs_ext, id::TimelineId, lsn::Lsn}; +use utils::{fs_ext, id::TimelineId, lsn::Lsn, sync::gate::GateGuard}; use crate::{ context::RequestContext, @@ -23,14 +23,14 @@ use super::Timeline; pub struct UninitializedTimeline<'t> { pub(crate) owning_tenant: &'t Tenant, timeline_id: TimelineId, - raw_timeline: Option<(Arc, TimelineCreateGuard<'t>)>, + raw_timeline: Option<(Arc, TimelineCreateGuard)>, } impl<'t> UninitializedTimeline<'t> { pub(crate) fn new( owning_tenant: &'t Tenant, timeline_id: TimelineId, - raw_timeline: Option<(Arc, TimelineCreateGuard<'t>)>, + raw_timeline: Option<(Arc, TimelineCreateGuard)>, ) -> Self { Self { owning_tenant, @@ -87,6 +87,10 @@ impl<'t> UninitializedTimeline<'t> { } } + pub(crate) fn finish_creation_myself(&mut self) -> (Arc, TimelineCreateGuard) { + self.raw_timeline.take().expect("already checked") + } + /// Prepares timeline data by loading it from the basebackup archive. pub(crate) async fn import_basebackup_from_tar( self, @@ -167,9 +171,10 @@ pub(crate) fn cleanup_timeline_directory(create_guard: TimelineCreateGuard) { /// A guard for timeline creations in process: as long as this object exists, the timeline ID /// is kept in `[Tenant::timelines_creating]` to exclude concurrent attempts to create the same timeline. #[must_use] -pub(crate) struct TimelineCreateGuard<'t> { - owning_tenant: &'t Tenant, - timeline_id: TimelineId, +pub(crate) struct TimelineCreateGuard { + pub(crate) _tenant_gate_guard: GateGuard, + pub(crate) owning_tenant: Arc, + pub(crate) timeline_id: TimelineId, pub(crate) timeline_path: Utf8PathBuf, pub(crate) idempotency: CreateTimelineIdempotency, } @@ -184,20 +189,27 @@ pub(crate) enum TimelineExclusionError { }, #[error("Already creating")] AlreadyCreating, + #[error("Shutting down")] + ShuttingDown, // e.g. I/O errors, or some failure deep in postgres initdb #[error(transparent)] Other(#[from] anyhow::Error), } -impl<'t> TimelineCreateGuard<'t> { +impl TimelineCreateGuard { pub(crate) fn new( - owning_tenant: &'t Tenant, + owning_tenant: &Arc, timeline_id: TimelineId, timeline_path: Utf8PathBuf, idempotency: CreateTimelineIdempotency, allow_offloaded: bool, ) -> Result { + let _tenant_gate_guard = owning_tenant + .gate + .enter() + .map_err(|_| TimelineExclusionError::ShuttingDown)?; + // Lock order: this is the only place we take both locks. During drop() we only // lock creating_timelines let timelines = owning_tenant.timelines.lock().unwrap(); @@ -225,8 +237,12 @@ impl<'t> TimelineCreateGuard<'t> { return Err(TimelineExclusionError::AlreadyCreating); } creating_timelines.insert(timeline_id); + drop(creating_timelines); + drop(timelines_offloaded); + drop(timelines); Ok(Self { - owning_tenant, + _tenant_gate_guard, + owning_tenant: Arc::clone(owning_tenant), timeline_id, timeline_path, idempotency, @@ -234,7 +250,7 @@ impl<'t> TimelineCreateGuard<'t> { } } -impl Drop for TimelineCreateGuard<'_> { +impl Drop for TimelineCreateGuard { fn drop(&mut self) { self.owning_tenant .timelines_creating diff --git a/pageserver/src/tenant/timeline/walreceiver.rs b/pageserver/src/tenant/timeline/walreceiver.rs index 4a3a5c621b..f831f5e48a 100644 --- a/pageserver/src/tenant/timeline/walreceiver.rs +++ b/pageserver/src/tenant/timeline/walreceiver.rs @@ -38,6 +38,7 @@ use storage_broker::BrokerClientChannel; use tokio::sync::watch; use tokio_util::sync::CancellationToken; use tracing::*; +use utils::postgres_client::PostgresClientProtocol; use self::connection_manager::ConnectionManagerStatus; @@ -45,6 +46,7 @@ use super::Timeline; #[derive(Clone)] pub struct WalReceiverConf { + pub protocol: PostgresClientProtocol, /// The timeout on the connection to safekeeper for WAL streaming. pub wal_connect_timeout: Duration, /// The timeout to use to determine when the current connection is "stale" and reconnect to the other one. diff --git a/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs b/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs index de50f217d8..583d6309ab 100644 --- a/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs +++ b/pageserver/src/tenant/timeline/walreceiver/connection_manager.rs @@ -36,7 +36,9 @@ use postgres_connection::PgConnectionConfig; use utils::backoff::{ exponential_backoff, DEFAULT_BASE_BACKOFF_SECONDS, DEFAULT_MAX_BACKOFF_SECONDS, }; -use utils::postgres_client::wal_stream_connection_config; +use utils::postgres_client::{ + wal_stream_connection_config, ConnectionConfigArgs, PostgresClientProtocol, +}; use utils::{ id::{NodeId, TenantTimelineId}, lsn::Lsn, @@ -533,6 +535,7 @@ impl ConnectionManagerState { let node_id = new_sk.safekeeper_id; let connect_timeout = self.conf.wal_connect_timeout; let ingest_batch_size = self.conf.ingest_batch_size; + let protocol = self.conf.protocol; let timeline = Arc::clone(&self.timeline); let ctx = ctx.detached_child( TaskKind::WalReceiverConnectionHandler, @@ -546,6 +549,7 @@ impl ConnectionManagerState { let res = super::walreceiver_connection::handle_walreceiver_connection( timeline, + protocol, new_sk.wal_source_connconf, events_sender, cancellation.clone(), @@ -984,15 +988,33 @@ impl ConnectionManagerState { if info.safekeeper_connstr.is_empty() { return None; // no connection string, ignore sk } - match wal_stream_connection_config( - self.id, - info.safekeeper_connstr.as_ref(), - match &self.conf.auth_token { - None => None, - Some(x) => Some(x), + + let (shard_number, shard_count, shard_stripe_size) = match self.conf.protocol { + PostgresClientProtocol::Vanilla => { + (None, None, None) }, - self.conf.availability_zone.as_deref(), - ) { + PostgresClientProtocol::Interpreted { .. } => { + let shard_identity = self.timeline.get_shard_identity(); + ( + Some(shard_identity.number.0), + Some(shard_identity.count.0), + Some(shard_identity.stripe_size.0), + ) + } + }; + + let connection_conf_args = ConnectionConfigArgs { + protocol: self.conf.protocol, + ttid: self.id, + shard_number, + shard_count, + shard_stripe_size, + listen_pg_addr_str: info.safekeeper_connstr.as_ref(), + auth_token: self.conf.auth_token.as_ref().map(|t| t.as_str()), + availability_zone: self.conf.availability_zone.as_deref() + }; + + match wal_stream_connection_config(connection_conf_args) { Ok(connstr) => Some((*sk_id, info, connstr)), Err(e) => { error!("Failed to create wal receiver connection string from broker data of safekeeper node {}: {e:#}", sk_id); @@ -1096,6 +1118,7 @@ impl ReconnectReason { mod tests { use super::*; use crate::tenant::harness::{TenantHarness, TIMELINE_ID}; + use pageserver_api::config::defaults::DEFAULT_WAL_RECEIVER_PROTOCOL; use url::Host; fn dummy_broker_sk_timeline( @@ -1532,6 +1555,7 @@ mod tests { timeline, cancel: CancellationToken::new(), conf: WalReceiverConf { + protocol: DEFAULT_WAL_RECEIVER_PROTOCOL, wal_connect_timeout: Duration::from_secs(1), lagging_wal_timeout: Duration::from_secs(1), max_lsn_wal_lag: NonZeroU64::new(1024 * 1024).unwrap(), diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index 6ac6920d47..d90ffbfa2c 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -22,7 +22,10 @@ use tokio::{select, sync::watch, time}; use tokio_postgres::{replication::ReplicationStream, Client}; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, trace, warn, Instrument}; -use wal_decoder::models::{FlushUncommittedRecords, InterpretedWalRecord}; +use wal_decoder::{ + models::{FlushUncommittedRecords, InterpretedWalRecord, InterpretedWalRecords}, + wire_format::FromWireFormat, +}; use super::TaskStateUpdate; use crate::{ @@ -36,7 +39,7 @@ use crate::{ use postgres_backend::is_expected_io_error; use postgres_connection::PgConnectionConfig; use postgres_ffi::waldecoder::WalStreamDecoder; -use utils::{id::NodeId, lsn::Lsn}; +use utils::{id::NodeId, lsn::Lsn, postgres_client::PostgresClientProtocol}; use utils::{pageserver_feedback::PageserverFeedback, sync::gate::GateError}; /// Status of the connection. @@ -109,6 +112,7 @@ impl From for WalReceiverError { #[allow(clippy::too_many_arguments)] pub(super) async fn handle_walreceiver_connection( timeline: Arc, + protocol: PostgresClientProtocol, wal_source_connconf: PgConnectionConfig, events_sender: watch::Sender>, cancellation: CancellationToken, @@ -260,6 +264,14 @@ pub(super) async fn handle_walreceiver_connection( let mut walingest = WalIngest::new(timeline.as_ref(), startpoint, &ctx).await?; + let interpreted_proto_config = match protocol { + PostgresClientProtocol::Vanilla => None, + PostgresClientProtocol::Interpreted { + format, + compression, + } => Some((format, compression)), + }; + while let Some(replication_message) = { select! { _ = cancellation.cancelled() => { @@ -291,6 +303,15 @@ pub(super) async fn handle_walreceiver_connection( connection_status.latest_connection_update = now; connection_status.commit_lsn = Some(Lsn::from(keepalive.wal_end())); } + ReplicationMessage::RawInterpretedWalRecords(raw) => { + connection_status.latest_connection_update = now; + if !raw.data().is_empty() { + connection_status.latest_wal_update = now; + } + + connection_status.commit_lsn = Some(Lsn::from(raw.commit_lsn())); + connection_status.streaming_lsn = Some(Lsn::from(raw.streaming_lsn())); + } &_ => {} }; if let Err(e) = events_sender.send(TaskStateUpdate::Progress(connection_status)) { @@ -298,7 +319,148 @@ pub(super) async fn handle_walreceiver_connection( return Ok(()); } + async fn commit( + modification: &mut DatadirModification<'_>, + uncommitted: &mut u64, + filtered: &mut u64, + ctx: &RequestContext, + ) -> anyhow::Result<()> { + WAL_INGEST + .records_committed + .inc_by(*uncommitted - *filtered); + modification.commit(ctx).await?; + *uncommitted = 0; + *filtered = 0; + Ok(()) + } + let status_update = match replication_message { + ReplicationMessage::RawInterpretedWalRecords(raw) => { + WAL_INGEST.bytes_received.inc_by(raw.data().len() as u64); + + let mut uncommitted_records = 0; + let mut filtered_records = 0; + + // This is the end LSN of the raw WAL from which the records + // were interpreted. + let streaming_lsn = Lsn::from(raw.streaming_lsn()); + + let (format, compression) = interpreted_proto_config.unwrap(); + let batch = InterpretedWalRecords::from_wire(raw.data(), format, compression) + .await + .with_context(|| { + anyhow::anyhow!( + "Failed to deserialize interpreted records ending at LSN {streaming_lsn}" + ) + })?; + + let InterpretedWalRecords { + records, + next_record_lsn, + } = batch; + + tracing::debug!( + "Received WAL up to {} with next_record_lsn={:?}", + streaming_lsn, + next_record_lsn + ); + + // We start the modification at 0 because each interpreted record + // advances it to its end LSN. 0 is just an initialization placeholder. + let mut modification = timeline.begin_modification(Lsn(0)); + + for interpreted in records { + if matches!(interpreted.flush_uncommitted, FlushUncommittedRecords::Yes) + && uncommitted_records > 0 + { + commit( + &mut modification, + &mut uncommitted_records, + &mut filtered_records, + &ctx, + ) + .await?; + } + + let local_next_record_lsn = interpreted.next_record_lsn; + let ingested = walingest + .ingest_record(interpreted, &mut modification, &ctx) + .await + .with_context(|| { + format!("could not ingest record at {local_next_record_lsn}") + })?; + + if !ingested { + tracing::debug!( + "ingest: filtered out record @ LSN {local_next_record_lsn}" + ); + WAL_INGEST.records_filtered.inc(); + filtered_records += 1; + } + + uncommitted_records += 1; + + // FIXME: this cannot be made pausable_failpoint without fixing the + // failpoint library; in tests, the added amount of debugging will cause us + // to timeout the tests. + fail_point!("walreceiver-after-ingest"); + + // Commit every ingest_batch_size records. Even if we filtered out + // all records, we still need to call commit to advance the LSN. + if uncommitted_records >= ingest_batch_size + || modification.approx_pending_bytes() + > DatadirModification::MAX_PENDING_BYTES + { + commit( + &mut modification, + &mut uncommitted_records, + &mut filtered_records, + &ctx, + ) + .await?; + } + } + + // Records might have been filtered out on the safekeeper side, but we still + // need to advance last record LSN on all shards. If we've not ingested the latest + // record, then set the LSN of the modification past it. This way all shards + // advance their last record LSN at the same time. + let needs_last_record_lsn_advance = match next_record_lsn.map(Lsn::from) { + Some(lsn) if lsn > modification.get_lsn() => { + modification.set_lsn(lsn).unwrap(); + true + } + _ => false, + }; + + if uncommitted_records > 0 || needs_last_record_lsn_advance { + // Commit any uncommitted records + commit( + &mut modification, + &mut uncommitted_records, + &mut filtered_records, + &ctx, + ) + .await?; + } + + if !caught_up && streaming_lsn >= end_of_wal { + info!("caught up at LSN {streaming_lsn}"); + caught_up = true; + } + + tracing::debug!( + "Ingested WAL up to {streaming_lsn}. Last record LSN is {}", + timeline.get_last_record_lsn() + ); + + if let Some(lsn) = next_record_lsn { + last_rec_lsn = lsn; + } + + Some(streaming_lsn) + } + ReplicationMessage::XLogData(xlog_data) => { // Pass the WAL data to the decoder, and see if we can decode // more records as a result. @@ -316,21 +478,6 @@ pub(super) async fn handle_walreceiver_connection( let mut uncommitted_records = 0; let mut filtered_records = 0; - async fn commit( - modification: &mut DatadirModification<'_>, - uncommitted: &mut u64, - filtered: &mut u64, - ctx: &RequestContext, - ) -> anyhow::Result<()> { - WAL_INGEST - .records_committed - .inc_by(*uncommitted - *filtered); - modification.commit(ctx).await?; - *uncommitted = 0; - *filtered = 0; - Ok(()) - } - while let Some((next_record_lsn, recdata)) = waldecoder.poll_decode()? { // It is important to deal with the aligned records as lsn in getPage@LSN is // aligned and can be several bytes bigger. Without this alignment we are diff --git a/pageserver/src/tenant/upload_queue.rs b/pageserver/src/tenant/upload_queue.rs index f14bf2f8c3..ef3aa759f3 100644 --- a/pageserver/src/tenant/upload_queue.rs +++ b/pageserver/src/tenant/upload_queue.rs @@ -3,6 +3,7 @@ use super::storage_layer::ResidentLayer; use crate::tenant::metadata::TimelineMetadata; use crate::tenant::remote_timeline_client::index::IndexPart; use crate::tenant::remote_timeline_client::index::LayerFileMetadata; +use std::collections::HashSet; use std::collections::{HashMap, VecDeque}; use std::fmt::Debug; @@ -14,7 +15,6 @@ use utils::lsn::AtomicLsn; use std::sync::atomic::AtomicU32; use utils::lsn::Lsn; -#[cfg(feature = "testing")] use utils::generation::Generation; // clippy warns that Uninitialized is much smaller than Initialized, which wastes @@ -38,6 +38,12 @@ impl UploadQueue { } } +#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)] +pub(crate) enum OpType { + MayReorder, + FlushDeletion, +} + /// This keeps track of queued and in-progress tasks. pub(crate) struct UploadQueueInitialized { /// Counter to assign task IDs @@ -88,6 +94,9 @@ pub(crate) struct UploadQueueInitialized { #[cfg(feature = "testing")] pub(crate) dangling_files: HashMap, + /// Ensure we order file operations correctly. + pub(crate) recently_deleted: HashSet<(LayerName, Generation)>, + /// Deletions that are blocked by the tenant configuration pub(crate) blocked_deletions: Vec, @@ -183,6 +192,7 @@ impl UploadQueue { queued_operations: VecDeque::new(), #[cfg(feature = "testing")] dangling_files: HashMap::new(), + recently_deleted: HashSet::new(), blocked_deletions: Vec::new(), shutting_down: false, shutdown_ready: Arc::new(tokio::sync::Semaphore::new(0)), @@ -224,6 +234,7 @@ impl UploadQueue { queued_operations: VecDeque::new(), #[cfg(feature = "testing")] dangling_files: HashMap::new(), + recently_deleted: HashSet::new(), blocked_deletions: Vec::new(), shutting_down: false, shutdown_ready: Arc::new(tokio::sync::Semaphore::new(0)), @@ -282,8 +293,8 @@ pub(crate) struct Delete { #[derive(Debug)] pub(crate) enum UploadOp { - /// Upload a layer file - UploadLayer(ResidentLayer, LayerFileMetadata), + /// Upload a layer file. The last field indicates the last operation for thie file. + UploadLayer(ResidentLayer, LayerFileMetadata, Option), /// Upload a index_part.json file UploadMetadata { @@ -305,11 +316,11 @@ pub(crate) enum UploadOp { impl std::fmt::Display for UploadOp { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { match self { - UploadOp::UploadLayer(layer, metadata) => { + UploadOp::UploadLayer(layer, metadata, mode) => { write!( f, - "UploadLayer({}, size={:?}, gen={:?})", - layer, metadata.file_size, metadata.generation + "UploadLayer({}, size={:?}, gen={:?}, mode={:?})", + layer, metadata.file_size, metadata.generation, mode ) } UploadOp::UploadMetadata { uploaded, .. } => { diff --git a/pageserver/src/virtual_file/owned_buffers_io/aligned_buffer/slice.rs b/pageserver/src/virtual_file/owned_buffers_io/aligned_buffer/slice.rs index 6cecf34c1c..1952b82578 100644 --- a/pageserver/src/virtual_file/owned_buffers_io/aligned_buffer/slice.rs +++ b/pageserver/src/virtual_file/owned_buffers_io/aligned_buffer/slice.rs @@ -19,7 +19,7 @@ impl<'a, const N: usize, const A: usize> AlignedSlice<'a, N, ConstAlign> { } } -impl<'a, const N: usize, A: Alignment> Deref for AlignedSlice<'a, N, A> { +impl Deref for AlignedSlice<'_, N, A> { type Target = [u8; N]; fn deref(&self) -> &Self::Target { @@ -27,13 +27,13 @@ impl<'a, const N: usize, A: Alignment> Deref for AlignedSlice<'a, N, A> { } } -impl<'a, const N: usize, A: Alignment> DerefMut for AlignedSlice<'a, N, A> { +impl DerefMut for AlignedSlice<'_, N, A> { fn deref_mut(&mut self) -> &mut Self::Target { self.buf } } -impl<'a, const N: usize, A: Alignment> AsRef<[u8; N]> for AlignedSlice<'a, N, A> { +impl AsRef<[u8; N]> for AlignedSlice<'_, N, A> { fn as_ref(&self) -> &[u8; N] { self.buf } diff --git a/pageserver/src/walingest.rs b/pageserver/src/walingest.rs index ad6ccbc854..d568da596a 100644 --- a/pageserver/src/walingest.rs +++ b/pageserver/src/walingest.rs @@ -334,14 +334,32 @@ impl WalIngest { // replaying it would fail to find the previous image of the page, because // it doesn't exist. So check if the VM page(s) exist, and skip the WAL // record if it doesn't. - let vm_size = get_relsize(modification, vm_rel, ctx).await?; + // + // TODO: analyze the metrics and tighten this up accordingly. This logic + // implicitly assumes that VM pages see explicit WAL writes before + // implicit ClearVmBits, and will otherwise silently drop updates. + let Some(vm_size) = get_relsize(modification, vm_rel, ctx).await? else { + WAL_INGEST + .clear_vm_bits_unknown + .with_label_values(&["relation"]) + .inc(); + return Ok(()); + }; if let Some(blknum) = new_vm_blk { if blknum >= vm_size { + WAL_INGEST + .clear_vm_bits_unknown + .with_label_values(&["new_page"]) + .inc(); new_vm_blk = None; } } if let Some(blknum) = old_vm_blk { if blknum >= vm_size { + WAL_INGEST + .clear_vm_bits_unknown + .with_label_values(&["old_page"]) + .inc(); old_vm_blk = None; } } @@ -572,7 +590,8 @@ impl WalIngest { modification.put_rel_page_image_zero(rel, fsm_physical_page_no)?; fsm_physical_page_no += 1; } - let nblocks = get_relsize(modification, rel, ctx).await?; + // TODO: re-examine the None case here wrt. sharding; should we error? + let nblocks = get_relsize(modification, rel, ctx).await?.unwrap_or(0); if nblocks > fsm_physical_page_no { // check if something to do: FSM is larger than truncate position self.put_rel_truncation(modification, rel, fsm_physical_page_no, ctx) @@ -612,7 +631,8 @@ impl WalIngest { )?; vm_page_no += 1; } - let nblocks = get_relsize(modification, rel, ctx).await?; + // TODO: re-examine the None case here wrt. sharding; should we error? + let nblocks = get_relsize(modification, rel, ctx).await?.unwrap_or(0); if nblocks > vm_page_no { // check if something to do: VM is larger than truncate position self.put_rel_truncation(modification, rel, vm_page_no, ctx) @@ -1430,24 +1450,27 @@ impl WalIngest { } } +/// Returns the size of the relation as of this modification, or None if the relation doesn't exist. +/// +/// This is only accurate on shard 0. On other shards, it will return the size up to the highest +/// page number stored in the shard, or None if the shard does not have any pages for it. async fn get_relsize( modification: &DatadirModification<'_>, rel: RelTag, ctx: &RequestContext, -) -> Result { - let nblocks = if !modification +) -> Result, PageReconstructError> { + if !modification .tline .get_rel_exists(rel, Version::Modified(modification), ctx) .await? { - 0 - } else { - modification - .tline - .get_rel_size(rel, Version::Modified(modification), ctx) - .await? - }; - Ok(nblocks) + return Ok(None); + } + modification + .tline + .get_rel_size(rel, Version::Modified(modification), ctx) + .await + .map(Some) } #[allow(clippy::bool_assert_comparison)] diff --git a/pgxn/neon/logical_replication_monitor.c b/pgxn/neon/logical_replication_monitor.c index 1badbbed21..5eee5a1679 100644 --- a/pgxn/neon/logical_replication_monitor.c +++ b/pgxn/neon/logical_replication_monitor.c @@ -20,7 +20,7 @@ #define LS_MONITOR_CHECK_INTERVAL 10000 /* ms */ -static int logical_replication_max_snap_files = 300; +static int logical_replication_max_snap_files = 10000; /* * According to Chi (shyzh), the pageserver _should_ be good with 10 MB worth of @@ -184,7 +184,7 @@ InitLogicalReplicationMonitor(void) "Maximum allowed logical replication .snap files. When exceeded, slots are dropped until the limit is met. -1 disables the limit.", NULL, &logical_replication_max_snap_files, - 300, -1, INT_MAX, + 10000, -1, INT_MAX, PGC_SIGHUP, 0, NULL, NULL, NULL); diff --git a/poetry.lock b/poetry.lock index 6171f92391..59ae5cf1ca 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -114,7 +114,6 @@ files = [ [package.dependencies] aiohappyeyeballs = ">=2.3.0" aiosignal = ">=1.1.2" -async-timeout = {version = ">=4.0,<6.0", markers = "python_version < \"3.11\""} attrs = ">=17.3.0" frozenlist = ">=1.1.1" multidict = ">=4.5,<7.0" @@ -219,10 +218,8 @@ files = [ ] [package.dependencies] -exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} idna = ">=2.8" sniffio = ">=1.1" -typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} [package.extras] doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] @@ -737,10 +734,7 @@ files = [ [package.dependencies] jmespath = ">=0.7.1,<2.0.0" python-dateutil = ">=2.1,<3.0.0" -urllib3 = [ - {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, - {version = ">=1.25.4,<2.1", markers = "python_version >= \"3.10\""}, -] +urllib3 = {version = ">=1.25.4,<2.1", markers = "python_version >= \"3.10\""} [package.extras] crt = ["awscrt (==0.19.19)"] @@ -1069,20 +1063,6 @@ docs = ["myst-parser (==0.18.0)", "sphinx (==5.1.1)"] ssh = ["paramiko (>=2.4.3)"] websockets = ["websocket-client (>=1.3.0)"] -[[package]] -name = "exceptiongroup" -version = "1.1.1" -description = "Backport of PEP 654 (exception groups)" -optional = false -python-versions = ">=3.7" -files = [ - {file = "exceptiongroup-1.1.1-py3-none-any.whl", hash = "sha256:232c37c63e4f682982c8b6459f33a8981039e5fb8756b2074364e5055c498c9e"}, - {file = "exceptiongroup-1.1.1.tar.gz", hash = "sha256:d484c3090ba2889ae2928419117447a14daf3c1231d5e30d0aae34f354f01785"}, -] - -[package.extras] -test = ["pytest (>=6)"] - [[package]] name = "execnet" version = "1.9.0" @@ -1110,7 +1090,6 @@ files = [ [package.dependencies] click = ">=8.0" -importlib-metadata = {version = ">=3.6.0", markers = "python_version < \"3.10\""} itsdangerous = ">=2.0" Jinja2 = ">=3.0" Werkzeug = ">=2.2.2" @@ -1319,25 +1298,6 @@ files = [ {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, ] -[[package]] -name = "importlib-metadata" -version = "4.12.0" -description = "Read metadata from Python packages" -optional = false -python-versions = ">=3.7" -files = [ - {file = "importlib_metadata-4.12.0-py3-none-any.whl", hash = "sha256:7401a975809ea1fdc658c3aa4f78cc2195a0e019c5cbc4c06122884e9ae80c23"}, - {file = "importlib_metadata-4.12.0.tar.gz", hash = "sha256:637245b8bab2b6502fcbc752cc4b7a6f6243bb02b31c5c26156ad103d3d45670"}, -] - -[package.dependencies] -zipp = ">=0.5" - -[package.extras] -docs = ["jaraco.packaging (>=9)", "rst.linker (>=1.9)", "sphinx"] -perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)"] - [[package]] name = "iniconfig" version = "1.1.1" @@ -1898,48 +1858,54 @@ files = [ [[package]] name = "mypy" -version = "1.3.0" +version = "1.13.0" description = "Optional static typing for Python" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "mypy-1.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c1eb485cea53f4f5284e5baf92902cd0088b24984f4209e25981cc359d64448d"}, - {file = "mypy-1.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4c99c3ecf223cf2952638da9cd82793d8f3c0c5fa8b6ae2b2d9ed1e1ff51ba85"}, - {file = "mypy-1.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:550a8b3a19bb6589679a7c3c31f64312e7ff482a816c96e0cecec9ad3a7564dd"}, - {file = "mypy-1.3.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:cbc07246253b9e3d7d74c9ff948cd0fd7a71afcc2b77c7f0a59c26e9395cb152"}, - {file = "mypy-1.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:a22435632710a4fcf8acf86cbd0d69f68ac389a3892cb23fbad176d1cddaf228"}, - {file = "mypy-1.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6e33bb8b2613614a33dff70565f4c803f889ebd2f859466e42b46e1df76018dd"}, - {file = "mypy-1.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7d23370d2a6b7a71dc65d1266f9a34e4cde9e8e21511322415db4b26f46f6b8c"}, - {file = "mypy-1.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:658fe7b674769a0770d4b26cb4d6f005e88a442fe82446f020be8e5f5efb2fae"}, - {file = "mypy-1.3.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6e42d29e324cdda61daaec2336c42512e59c7c375340bd202efa1fe0f7b8f8ca"}, - {file = "mypy-1.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:d0b6c62206e04061e27009481cb0ec966f7d6172b5b936f3ead3d74f29fe3dcf"}, - {file = "mypy-1.3.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:76ec771e2342f1b558c36d49900dfe81d140361dd0d2df6cd71b3db1be155409"}, - {file = "mypy-1.3.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ebc95f8386314272bbc817026f8ce8f4f0d2ef7ae44f947c4664efac9adec929"}, - {file = "mypy-1.3.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:faff86aa10c1aa4a10e1a301de160f3d8fc8703b88c7e98de46b531ff1276a9a"}, - {file = "mypy-1.3.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8c5979d0deb27e0f4479bee18ea0f83732a893e81b78e62e2dda3e7e518c92ee"}, - {file = "mypy-1.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c5d2cc54175bab47011b09688b418db71403aefad07cbcd62d44010543fc143f"}, - {file = "mypy-1.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:87df44954c31d86df96c8bd6e80dfcd773473e877ac6176a8e29898bfb3501cb"}, - {file = "mypy-1.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:473117e310febe632ddf10e745a355714e771ffe534f06db40702775056614c4"}, - {file = "mypy-1.3.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:74bc9b6e0e79808bf8678d7678b2ae3736ea72d56eede3820bd3849823e7f305"}, - {file = "mypy-1.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:44797d031a41516fcf5cbfa652265bb994e53e51994c1bd649ffcd0c3a7eccbf"}, - {file = "mypy-1.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ddae0f39ca146972ff6bb4399f3b2943884a774b8771ea0a8f50e971f5ea5ba8"}, - {file = "mypy-1.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1c4c42c60a8103ead4c1c060ac3cdd3ff01e18fddce6f1016e08939647a0e703"}, - {file = "mypy-1.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e86c2c6852f62f8f2b24cb7a613ebe8e0c7dc1402c61d36a609174f63e0ff017"}, - {file = "mypy-1.3.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:f9dca1e257d4cc129517779226753dbefb4f2266c4eaad610fc15c6a7e14283e"}, - {file = "mypy-1.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:95d8d31a7713510685b05fbb18d6ac287a56c8f6554d88c19e73f724a445448a"}, - {file = "mypy-1.3.0-py3-none-any.whl", hash = "sha256:a8763e72d5d9574d45ce5881962bc8e9046bf7b375b0abf031f3e6811732a897"}, - {file = "mypy-1.3.0.tar.gz", hash = "sha256:e1f4d16e296f5135624b34e8fb741eb0eadedca90862405b1f1fde2040b9bd11"}, + {file = "mypy-1.13.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6607e0f1dd1fb7f0aca14d936d13fd19eba5e17e1cd2a14f808fa5f8f6d8f60a"}, + {file = "mypy-1.13.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8a21be69bd26fa81b1f80a61ee7ab05b076c674d9b18fb56239d72e21d9f4c80"}, + {file = "mypy-1.13.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7b2353a44d2179846a096e25691d54d59904559f4232519d420d64da6828a3a7"}, + {file = "mypy-1.13.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0730d1c6a2739d4511dc4253f8274cdd140c55c32dfb0a4cf8b7a43f40abfa6f"}, + {file = "mypy-1.13.0-cp310-cp310-win_amd64.whl", hash = "sha256:c5fc54dbb712ff5e5a0fca797e6e0aa25726c7e72c6a5850cfd2adbc1eb0a372"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:581665e6f3a8a9078f28d5502f4c334c0c8d802ef55ea0e7276a6e409bc0d82d"}, + {file = "mypy-1.13.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3ddb5b9bf82e05cc9a627e84707b528e5c7caaa1c55c69e175abb15a761cec2d"}, + {file = "mypy-1.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:20c7ee0bc0d5a9595c46f38beb04201f2620065a93755704e141fcac9f59db2b"}, + {file = "mypy-1.13.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3790ded76f0b34bc9c8ba4def8f919dd6a46db0f5a6610fb994fe8efdd447f73"}, + {file = "mypy-1.13.0-cp311-cp311-win_amd64.whl", hash = "sha256:51f869f4b6b538229c1d1bcc1dd7d119817206e2bc54e8e374b3dfa202defcca"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5c7051a3461ae84dfb5dd15eff5094640c61c5f22257c8b766794e6dd85e72d5"}, + {file = "mypy-1.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:39bb21c69a5d6342f4ce526e4584bc5c197fd20a60d14a8624d8743fffb9472e"}, + {file = "mypy-1.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:164f28cb9d6367439031f4c81e84d3ccaa1e19232d9d05d37cb0bd880d3f93c2"}, + {file = "mypy-1.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a4c1bfcdbce96ff5d96fc9b08e3831acb30dc44ab02671eca5953eadad07d6d0"}, + {file = "mypy-1.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:a0affb3a79a256b4183ba09811e3577c5163ed06685e4d4b46429a271ba174d2"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a7b44178c9760ce1a43f544e595d35ed61ac2c3de306599fa59b38a6048e1aa7"}, + {file = "mypy-1.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5d5092efb8516d08440e36626f0153b5006d4088c1d663d88bf79625af3d1d62"}, + {file = "mypy-1.13.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:de2904956dac40ced10931ac967ae63c5089bd498542194b436eb097a9f77bc8"}, + {file = "mypy-1.13.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:7bfd8836970d33c2105562650656b6846149374dc8ed77d98424b40b09340ba7"}, + {file = "mypy-1.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:9f73dba9ec77acb86457a8fc04b5239822df0c14a082564737833d2963677dbc"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:100fac22ce82925f676a734af0db922ecfea991e1d7ec0ceb1e115ebe501301a"}, + {file = "mypy-1.13.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7bcb0bb7f42a978bb323a7c88f1081d1b5dee77ca86f4100735a6f541299d8fb"}, + {file = "mypy-1.13.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bde31fc887c213e223bbfc34328070996061b0833b0a4cfec53745ed61f3519b"}, + {file = "mypy-1.13.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:07de989f89786f62b937851295ed62e51774722e5444a27cecca993fc3f9cd74"}, + {file = "mypy-1.13.0-cp38-cp38-win_amd64.whl", hash = "sha256:4bde84334fbe19bad704b3f5b78c4abd35ff1026f8ba72b29de70dda0916beb6"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0246bcb1b5de7f08f2826451abd947bf656945209b140d16ed317f65a17dc7dc"}, + {file = "mypy-1.13.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7f5b7deae912cf8b77e990b9280f170381fdfbddf61b4ef80927edd813163732"}, + {file = "mypy-1.13.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7029881ec6ffb8bc233a4fa364736789582c738217b133f1b55967115288a2bc"}, + {file = "mypy-1.13.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3e38b980e5681f28f033f3be86b099a247b13c491f14bb8b1e1e134d23bb599d"}, + {file = "mypy-1.13.0-cp39-cp39-win_amd64.whl", hash = "sha256:a6789be98a2017c912ae6ccb77ea553bbaf13d27605d2ca20a76dfbced631b24"}, + {file = "mypy-1.13.0-py3-none-any.whl", hash = "sha256:9c250883f9fd81d212e0952c92dbfcc96fc237f4b7c92f56ac81fd48460b3e5a"}, + {file = "mypy-1.13.0.tar.gz", hash = "sha256:0291a61b6fbf3e6673e3405cfcc0e7650bebc7939659fdca2702958038bd835e"}, ] [package.dependencies] mypy-extensions = ">=1.0.0" -tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -typing-extensions = ">=3.10" +typing-extensions = ">=4.6.0" [package.extras] dmypy = ["psutil (>=4.0)"] +faster-cache = ["orjson"] install-types = ["pip"] -python2 = ["typed-ast (>=1.4.0,<2)"] +mypyc = ["setuptools (>=50)"] reports = ["lxml"] [[package]] @@ -2514,11 +2480,9 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "sys_platform == \"win32\""} -exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" pluggy = ">=0.12,<2.0" -tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] @@ -2581,10 +2545,7 @@ files = [ ] [package.dependencies] -pytest = [ - {version = ">=5.0", markers = "python_version < \"3.10\""}, - {version = ">=6.2.4", markers = "python_version >= \"3.10\""}, -] +pytest = {version = ">=6.2.4", markers = "python_version >= \"3.10\""} [[package]] name = "pytest-repeat" @@ -2602,18 +2563,18 @@ pytest = "*" [[package]] name = "pytest-rerunfailures" -version = "13.0" +version = "15.0" description = "pytest plugin to re-run tests to eliminate flaky failures" optional = false -python-versions = ">=3.7" +python-versions = ">=3.9" files = [ - {file = "pytest-rerunfailures-13.0.tar.gz", hash = "sha256:e132dbe420bc476f544b96e7036edd0a69707574209b6677263c950d19b09199"}, - {file = "pytest_rerunfailures-13.0-py3-none-any.whl", hash = "sha256:34919cb3fcb1f8e5d4b940aa75ccdea9661bade925091873b7c6fa5548333069"}, + {file = "pytest-rerunfailures-15.0.tar.gz", hash = "sha256:2d9ac7baf59f4c13ac730b47f6fa80e755d1ba0581da45ce30b72fb3542b4474"}, + {file = "pytest_rerunfailures-15.0-py3-none-any.whl", hash = "sha256:dd150c4795c229ef44320adc9a0c0532c51b78bb7a6843a8c53556b9a611df1a"}, ] [package.dependencies] packaging = ">=17.1" -pytest = ">=7" +pytest = ">=7.4,<8.2.2 || >8.2.2" [[package]] name = "pytest-split" @@ -3092,17 +3053,6 @@ files = [ {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, ] -[[package]] -name = "tomli" -version = "2.0.1" -description = "A lil' TOML parser" -optional = false -python-versions = ">=3.7" -files = [ - {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, - {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, -] - [[package]] name = "types-jwcrypto" version = "1.5.0.20240925" @@ -3359,16 +3309,6 @@ files = [ {file = "wrapt-1.14.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8ad85f7f4e20964db4daadcab70b47ab05c7c1cf2a7c1e51087bfaa83831854c"}, {file = "wrapt-1.14.1-cp310-cp310-win32.whl", hash = "sha256:a9a52172be0b5aae932bef82a79ec0a0ce87288c7d132946d645eba03f0ad8a8"}, {file = "wrapt-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:6d323e1554b3d22cfc03cd3243b5bb815a51f5249fdcbb86fda4bf62bab9e164"}, - {file = "wrapt-1.14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ecee4132c6cd2ce5308e21672015ddfed1ff975ad0ac8d27168ea82e71413f55"}, - {file = "wrapt-1.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2020f391008ef874c6d9e208b24f28e31bcb85ccff4f335f15a3251d222b92d9"}, - {file = "wrapt-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2feecf86e1f7a86517cab34ae6c2f081fd2d0dac860cb0c0ded96d799d20b335"}, - {file = "wrapt-1.14.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:240b1686f38ae665d1b15475966fe0472f78e71b1b4903c143a842659c8e4cb9"}, - {file = "wrapt-1.14.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9008dad07d71f68487c91e96579c8567c98ca4c3881b9b113bc7b33e9fd78b8"}, - {file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6447e9f3ba72f8e2b985a1da758767698efa72723d5b59accefd716e9e8272bf"}, - {file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:acae32e13a4153809db37405f5eba5bac5fbe2e2ba61ab227926a22901051c0a"}, - {file = "wrapt-1.14.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49ef582b7a1152ae2766557f0550a9fcbf7bbd76f43fbdc94dd3bf07cc7168be"}, - {file = "wrapt-1.14.1-cp311-cp311-win32.whl", hash = "sha256:358fe87cc899c6bb0ddc185bf3dbfa4ba646f05b1b0b9b5a27c2cb92c2cea204"}, - {file = "wrapt-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:26046cd03936ae745a502abf44dac702a5e6880b2b01c29aea8ddf3353b68224"}, {file = "wrapt-1.14.1-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:43ca3bbbe97af00f49efb06e352eae40434ca9d915906f77def219b88e85d907"}, {file = "wrapt-1.14.1-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:6b1a564e6cb69922c7fe3a678b9f9a3c54e72b469875aa8018f18b4d1dd1adf3"}, {file = "wrapt-1.14.1-cp35-cp35m-manylinux2010_i686.whl", hash = "sha256:00b6d4ea20a906c0ca56d84f93065b398ab74b927a7a3dbd470f6fc503f95dc3"}, @@ -3523,21 +3463,6 @@ idna = ">=2.0" multidict = ">=4.0" propcache = ">=0.2.0" -[[package]] -name = "zipp" -version = "3.19.1" -description = "Backport of pathlib-compatible object wrapper for zip files" -optional = false -python-versions = ">=3.8" -files = [ - {file = "zipp-3.19.1-py3-none-any.whl", hash = "sha256:2828e64edb5386ea6a52e7ba7cdb17bb30a73a858f5eb6eb93d8d36f5ea26091"}, - {file = "zipp-3.19.1.tar.gz", hash = "sha256:35427f6d5594f4acf82d25541438348c26736fa9b3afa2754bcd63cdb99d8e8f"}, -] - -[package.extras] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] - [[package]] name = "zstandard" version = "0.21.0" @@ -3598,5 +3523,5 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" -python-versions = "^3.9" -content-hash = "8cb9c38d83eec441391c0528ac2fbefde18c734373b2399e07c69382044e8ced" +python-versions = "^3.11" +content-hash = "426c385df93f578ba3537c40a269535e27fbcca1978b3cf266096ecbc298c6a9" diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 1665d6361a..0d774d529d 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -55,6 +55,7 @@ parquet.workspace = true parquet_derive.workspace = true pin-project-lite.workspace = true postgres_backend.workspace = true +postgres-protocol = { package = "postgres-protocol2", path = "../libs/proxy/postgres-protocol2" } pq_proto.workspace = true prometheus.workspace = true rand.workspace = true @@ -80,8 +81,7 @@ subtle.workspace = true thiserror.workspace = true tikv-jemallocator.workspace = true tikv-jemalloc-ctl = { workspace = true, features = ["use_std"] } -tokio-postgres = { workspace = true, features = ["with-serde_json-1"] } -tokio-postgres-rustls.workspace = true +tokio-postgres = { package = "tokio-postgres2", path = "../libs/proxy/tokio-postgres2" } tokio-rustls.workspace = true tokio-util.workspace = true tokio = { workspace = true, features = ["signal"] } @@ -96,7 +96,6 @@ utils.workspace = true uuid.workspace = true rustls-native-certs.workspace = true x509-parser.workspace = true -postgres-protocol.workspace = true redis.workspace = true zerocopy.workspace = true @@ -117,6 +116,5 @@ tokio-tungstenite.workspace = true pbkdf2 = { workspace = true, features = ["simple", "std"] } rcgen.workspace = true rstest.workspace = true -tokio-postgres-rustls.workspace = true walkdir.workspace = true rand_distr = "0.4" diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index 5772471486..bf7a1cb070 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -6,6 +6,7 @@ use tokio_postgres::config::SslMode; use tracing::{info, info_span}; use super::ComputeCredentialKeys; +use crate::auth::IpPattern; use crate::cache::Cached; use crate::config::AuthenticationConfig; use crate::context::RequestContext; @@ -74,10 +75,10 @@ impl ConsoleRedirectBackend { ctx: &RequestContext, auth_config: &'static AuthenticationConfig, client: &mut PqStream, - ) -> auth::Result { + ) -> auth::Result<(ConsoleRedirectNodeInfo, Option>)> { authenticate(ctx, auth_config, &self.console_uri, client) .await - .map(ConsoleRedirectNodeInfo) + .map(|(node_info, ip_allowlist)| (ConsoleRedirectNodeInfo(node_info), ip_allowlist)) } } @@ -102,7 +103,7 @@ async fn authenticate( auth_config: &'static AuthenticationConfig, link_uri: &reqwest::Url, client: &mut PqStream, -) -> auth::Result { +) -> auth::Result<(NodeInfo, Option>)> { ctx.set_auth_method(crate::context::AuthMethod::ConsoleRedirect); // registering waiter can fail if we get unlucky with rng. @@ -176,9 +177,12 @@ async fn authenticate( config.password(password.as_ref()); } - Ok(NodeInfo { - config, - aux: db_info.aux, - allow_self_signed_compute: false, // caller may override - }) + Ok(( + NodeInfo { + config, + aux: db_info.aux, + allow_self_signed_compute: false, // caller may override + }, + db_info.allowed_ips, + )) } diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index f721d81aa2..517d4fd34b 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -132,6 +132,93 @@ struct JwkSet<'a> { keys: Vec<&'a RawValue>, } +/// Given a jwks_url, fetch the JWKS and parse out all the signing JWKs. +/// Returns `None` and log a warning if there are any errors. +async fn fetch_jwks( + client: &reqwest_middleware::ClientWithMiddleware, + jwks_url: url::Url, +) -> Option { + let req = client.get(jwks_url.clone()); + // TODO(conrad): We need to filter out URLs that point to local resources. Public internet only. + let resp = req.send().await.and_then(|r| { + r.error_for_status() + .map_err(reqwest_middleware::Error::Reqwest) + }); + + let resp = match resp { + Ok(r) => r, + // TODO: should we re-insert JWKs if we want to keep this JWKs URL? + // I expect these failures would be quite sparse. + Err(e) => { + tracing::warn!(url=?jwks_url, error=?e, "could not fetch JWKs"); + return None; + } + }; + + let resp: http::Response = resp.into(); + + let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE).await { + Ok(bytes) => bytes, + Err(e) => { + tracing::warn!(url=?jwks_url, error=?e, "could not decode JWKs"); + return None; + } + }; + + let jwks = match serde_json::from_slice::(&bytes) { + Ok(jwks) => jwks, + Err(e) => { + tracing::warn!(url=?jwks_url, error=?e, "could not decode JWKs"); + return None; + } + }; + + // `jose_jwk::Jwk` is quite large (288 bytes). Let's not pre-allocate for what we don't need. + // + // Even though we limit our responses to 64KiB, we could still receive a payload like + // `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}`. Parsing this as `RawValue` uses 468KiB. + // Pre-allocating the corresponding `Vec::::with_capacity(30000)` uses 8.2MiB. + let mut keys = vec![]; + + let mut failed = 0; + for key in jwks.keys { + let key = match serde_json::from_str::(key.get()) { + Ok(key) => key, + Err(e) => { + tracing::debug!(url=?jwks_url, failed=?e, "could not decode JWK"); + failed += 1; + continue; + } + }; + + // if `use` (called `cls` in rust) is specified to be something other than signing, + // we can skip storing it. + if key + .prm + .cls + .as_ref() + .is_some_and(|c| *c != jose_jwk::Class::Signing) + { + continue; + } + + keys.push(key); + } + + keys.shrink_to_fit(); + + if failed > 0 { + tracing::warn!(url=?jwks_url, failed, "could not decode JWKs"); + } + + if keys.is_empty() { + tracing::warn!(url=?jwks_url, "no valid JWKs found inside the response body"); + return None; + } + + Some(jose_jwk::JwkSet { keys }) +} + impl JwkCacheEntryLock { async fn acquire_permit<'a>(self: &'a Arc) -> JwkRenewalPermit<'a> { JwkRenewalPermit::acquire_permit(self).await @@ -166,87 +253,15 @@ impl JwkCacheEntryLock { // TODO(conrad): run concurrently // TODO(conrad): strip the JWKs urls (should be checked by cplane as well - cloud#16284) for rule in rules { - let req = client.get(rule.jwks_url.clone()); - // TODO(conrad): eventually switch to using reqwest_middleware/`new_client_with_timeout`. - // TODO(conrad): We need to filter out URLs that point to local resources. Public internet only. - match req.send().await.and_then(|r| { - r.error_for_status() - .map_err(reqwest_middleware::Error::Reqwest) - }) { - // todo: should we re-insert JWKs if we want to keep this JWKs URL? - // I expect these failures would be quite sparse. - Err(e) => tracing::warn!(url=?rule.jwks_url, error=?e, "could not fetch JWKs"), - Ok(r) => { - let resp: http::Response = r.into(); - - let bytes = match read_body_with_limit(resp.into_body(), MAX_JWK_BODY_SIZE) - .await - { - Ok(bytes) => bytes, - Err(e) => { - tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs"); - continue; - } - }; - - match serde_json::from_slice::(&bytes) { - Err(e) => { - tracing::warn!(url=?rule.jwks_url, error=?e, "could not decode JWKs"); - } - Ok(jwks) => { - // size_of::<&RawValue>() == 16 - // size_of::() == 288 - // better to not pre-allocate this as it might be pretty large - especially if it has many - // keys we don't want or need. - // trivial 'attack': `{"keys":[` + repeat(`0`).take(30000).join(`,`) + `]}` - // this would consume 8MiB just like that! - let mut keys = vec![]; - let mut failed = 0; - for key in jwks.keys { - match serde_json::from_str::(key.get()) { - Ok(key) => { - // if `use` (called `cls` in rust) is specified to be something other than signing, - // we can skip storing it. - if key - .prm - .cls - .as_ref() - .is_some_and(|c| *c != jose_jwk::Class::Signing) - { - continue; - } - - keys.push(key); - } - Err(e) => { - tracing::debug!(url=?rule.jwks_url, failed=?e, "could not decode JWK"); - failed += 1; - } - } - } - keys.shrink_to_fit(); - - if failed > 0 { - tracing::warn!(url=?rule.jwks_url, failed, "could not decode JWKs"); - } - - if keys.is_empty() { - tracing::warn!(url=?rule.jwks_url, "no valid JWKs found inside the response body"); - continue; - } - - let jwks = jose_jwk::JwkSet { keys }; - key_sets.insert( - rule.id, - KeySet { - jwks, - audience: rule.audience, - role_names: rule.role_names, - }, - ); - } - }; - } + if let Some(jwks) = fetch_jwks(client, rule.jwks_url).await { + key_sets.insert( + rule.id, + KeySet { + jwks, + audience: rule.audience, + role_names: rule.role_names, + }, + ); } } diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index 57ecd5e499..7e1b26a11a 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -6,7 +6,6 @@ pub mod local; use std::net::IpAddr; use std::sync::Arc; -use std::time::Duration; pub use console_redirect::ConsoleRedirectBackend; pub(crate) use console_redirect::ConsoleRedirectError; @@ -30,7 +29,7 @@ use crate::intern::EndpointIdInt; use crate::metrics::Metrics; use crate::proxy::connect_compute::ComputeConnectBackend; use crate::proxy::NeonOptions; -use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter, RateBucketInfo}; +use crate::rate_limiter::{BucketRateLimiter, EndpointRateLimiter}; use crate::stream::Stream; use crate::types::{EndpointCacheKey, EndpointId, RoleName}; use crate::{scram, stream}; @@ -192,21 +191,6 @@ impl MaskedIp { // This can't be just per IP because that would limit some PaaS that share IP addresses pub type AuthRateLimiter = BucketRateLimiter<(EndpointIdInt, MaskedIp)>; -impl RateBucketInfo { - /// All of these are per endpoint-maskedip pair. - /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus). - /// - /// First bucket: 1000mcpus total per endpoint-ip pair - /// * 4096000 requests per second with 1 hash rounds. - /// * 1000 requests per second with 4096 hash rounds. - /// * 6.8 requests per second with 600000 hash rounds. - pub const DEFAULT_AUTH_SET: [Self; 3] = [ - Self::new(1000 * 4096, Duration::from_secs(1)), - Self::new(600 * 4096, Duration::from_secs(60)), - Self::new(300 * 4096, Duration::from_secs(600)), - ]; -} - impl AuthenticationConfig { pub(crate) fn check_rate_limit( &self, diff --git a/proxy/src/bin/proxy.rs b/proxy/src/bin/proxy.rs index 45fbe4a398..a935378162 100644 --- a/proxy/src/bin/proxy.rs +++ b/proxy/src/bin/proxy.rs @@ -428,8 +428,9 @@ async fn main() -> anyhow::Result<()> { )?))), None => None, }; + let cancellation_handler = Arc::new(CancellationHandler::< - Option>>, + Option>>, >::new( cancel_map.clone(), redis_publisher, diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index 3ad2d55b53..91e198bf88 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -1,7 +1,8 @@ -use std::net::SocketAddr; +use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use dashmap::DashMap; +use ipnet::{IpNet, Ipv4Net, Ipv6Net}; use pq_proto::CancelKeyData; use thiserror::Error; use tokio::net::TcpStream; @@ -10,8 +11,10 @@ use tokio_postgres::{CancelToken, NoTls}; use tracing::{debug, info}; use uuid::Uuid; +use crate::auth::{check_peer_addr_is_in_list, IpPattern}; use crate::error::ReportableError; use crate::metrics::{CancellationRequest, CancellationSource, Metrics}; +use crate::rate_limiter::LeakyBucketRateLimiter; use crate::redis::cancellation_publisher::{ CancellationPublisher, CancellationPublisherMut, RedisPublisherClient, }; @@ -20,6 +23,8 @@ pub type CancelMap = Arc>>; pub type CancellationHandlerMain = CancellationHandler>>>; pub(crate) type CancellationHandlerMainInternal = Option>>; +type IpSubnetKey = IpNet; + /// Enables serving `CancelRequest`s. /// /// If `CancellationPublisher` is available, cancel request will be used to publish the cancellation key to other proxy instances. @@ -29,14 +34,23 @@ pub struct CancellationHandler

{ /// This field used for the monitoring purposes. /// Represents the source of the cancellation request. from: CancellationSource, + // rate limiter of cancellation requests + limiter: Arc>>, } #[derive(Debug, Error)] pub(crate) enum CancelError { #[error("{0}")] IO(#[from] std::io::Error), + #[error("{0}")] Postgres(#[from] tokio_postgres::Error), + + #[error("rate limit exceeded")] + RateLimit, + + #[error("IP is not allowed")] + IpNotAllowed, } impl ReportableError for CancelError { @@ -47,6 +61,8 @@ impl ReportableError for CancelError { crate::error::ErrorKind::Postgres } CancelError::Postgres(_) => crate::error::ErrorKind::Compute, + CancelError::RateLimit => crate::error::ErrorKind::RateLimit, + CancelError::IpNotAllowed => crate::error::ErrorKind::User, } } } @@ -79,13 +95,37 @@ impl CancellationHandler

{ cancellation_handler: self, } } + /// Try to cancel a running query for the corresponding connection. /// If the cancellation key is not found, it will be published to Redis. + /// check_allowed - if true, check if the IP is allowed to cancel the query + /// return Result primarily for tests pub(crate) async fn cancel_session( &self, key: CancelKeyData, session_id: Uuid, + peer_addr: IpAddr, + check_allowed: bool, ) -> Result<(), CancelError> { + // TODO: check for unspecified address is only for backward compatibility, should be removed + if !peer_addr.is_unspecified() { + let subnet_key = match peer_addr { + IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new_assert(ip, 24).trunc()), // use defaut mask here + IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()), + }; + if !self.limiter.lock().unwrap().check(subnet_key, 1) { + tracing::debug!("Rate limit exceeded. Skipping cancellation message"); + Metrics::get() + .proxy + .cancellation_requests_total + .inc(CancellationRequest { + source: self.from, + kind: crate::metrics::CancellationOutcome::RateLimitExceeded, + }); + return Err(CancelError::RateLimit); + } + } + // NB: we should immediately release the lock after cloning the token. let Some(cancel_closure) = self.map.get(&key).and_then(|x| x.clone()) else { tracing::warn!("query cancellation key not found: {key}"); @@ -96,9 +136,17 @@ impl CancellationHandler

{ source: self.from, kind: crate::metrics::CancellationOutcome::NotFound, }); - match self.client.try_publish(key, session_id).await { + + if session_id == Uuid::nil() { + // was already published, do not publish it again + return Ok(()); + } + + match self.client.try_publish(key, session_id, peer_addr).await { Ok(()) => {} // do nothing Err(e) => { + // log it here since cancel_session could be spawned in a task + tracing::error!("failed to publish cancellation key: {key}, error: {e}"); return Err(CancelError::IO(std::io::Error::new( std::io::ErrorKind::Other, e.to_string(), @@ -107,6 +155,15 @@ impl CancellationHandler

{ } return Ok(()); }; + + if check_allowed + && !check_peer_addr_is_in_list(&peer_addr, cancel_closure.ip_allowlist.as_slice()) + { + // log it here since cancel_session could be spawned in a task + tracing::warn!("IP is not allowed to cancel the query: {key}"); + return Err(CancelError::IpNotAllowed); + } + Metrics::get() .proxy .cancellation_requests_total @@ -135,13 +192,29 @@ impl CancellationHandler<()> { map, client: (), from, + limiter: Arc::new(std::sync::Mutex::new( + LeakyBucketRateLimiter::::new_with_shards( + LeakyBucketRateLimiter::::DEFAULT, + 64, + ), + )), } } } impl CancellationHandler>>> { pub fn new(map: CancelMap, client: Option>>, from: CancellationSource) -> Self { - Self { map, client, from } + Self { + map, + client, + from, + limiter: Arc::new(std::sync::Mutex::new( + LeakyBucketRateLimiter::::new_with_shards( + LeakyBucketRateLimiter::::DEFAULT, + 64, + ), + )), + } } } @@ -152,13 +225,19 @@ impl CancellationHandler>>> { pub struct CancelClosure { socket_addr: SocketAddr, cancel_token: CancelToken, + ip_allowlist: Vec, } impl CancelClosure { - pub(crate) fn new(socket_addr: SocketAddr, cancel_token: CancelToken) -> Self { + pub(crate) fn new( + socket_addr: SocketAddr, + cancel_token: CancelToken, + ip_allowlist: Vec, + ) -> Self { Self { socket_addr, cancel_token, + ip_allowlist, } } /// Cancels the query running on user's compute node. @@ -168,6 +247,9 @@ impl CancelClosure { debug!("query was cancelled"); Ok(()) } + pub(crate) fn set_ip_allowlist(&mut self, ip_allowlist: Vec) { + self.ip_allowlist = ip_allowlist; + } } /// Helper for registering query cancellation tokens. @@ -229,6 +311,8 @@ mod tests { cancel_key: 0, }, Uuid::new_v4(), + "127.0.0.1".parse().unwrap(), + true, ) .await .unwrap(); diff --git a/proxy/src/compute.rs b/proxy/src/compute.rs index e7fbe9ab47..2abe88ac88 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute.rs @@ -13,7 +13,6 @@ use rustls::pki_types::InvalidDnsNameError; use thiserror::Error; use tokio::net::TcpStream; use tokio_postgres::tls::MakeTlsConnect; -use tokio_postgres_rustls::MakeRustlsConnect; use tracing::{debug, error, info, warn}; use crate::auth::parse_endpoint_param; @@ -24,6 +23,7 @@ use crate::control_plane::errors::WakeComputeError; use crate::control_plane::messages::MetricsAuxInfo; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumDbConnectionsGuard}; +use crate::postgres_rustls::MakeRustlsConnect; use crate::proxy::neon_option; use crate::types::Host; @@ -244,7 +244,6 @@ impl ConnCfg { let port = ports.get(i).or_else(|| ports.first()).unwrap_or(&5432); let host = match host { Host::Tcp(host) => host.as_str(), - Host::Unix(_) => continue, // unix sockets are not welcome here }; match connect_once(host, *port).await { @@ -315,7 +314,7 @@ impl ConnCfg { }; let client_config = client_config.with_no_client_auth(); - let mut mk_tls = tokio_postgres_rustls::MakeRustlsConnect::new(client_config); + let mut mk_tls = crate::postgres_rustls::MakeRustlsConnect::new(client_config); let tls = >::make_tls_connect( &mut mk_tls, host, @@ -342,7 +341,7 @@ impl ConnCfg { // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw. // Yet another reason to rework the connection establishing code. - let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token()); + let cancel_closure = CancelClosure::new(socket_addr, client.cancel_token(), vec![]); let connection = PostgresConnection { stream, diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index c88b2936db..8f78df1964 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use futures::TryFutureExt; +use futures::{FutureExt, TryFutureExt}; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, Instrument}; @@ -35,6 +35,7 @@ pub async fn task_main( socket2::SockRef::from(&listener).set_keepalive(true)?; let connections = tokio_util::task::task_tracker::TaskTracker::new(); + let cancellations = tokio_util::task::task_tracker::TaskTracker::new(); while let Some(accept_result) = run_until_cancelled(listener.accept(), &cancellation_token).await @@ -48,6 +49,7 @@ pub async fn task_main( let session_id = uuid::Uuid::new_v4(); let cancellation_handler = Arc::clone(&cancellation_handler); + let cancellations = cancellations.clone(); debug!(protocol = "tcp", %session_id, "accepted new TCP connection"); @@ -88,40 +90,38 @@ pub async fn task_main( crate::metrics::Protocol::Tcp, &config.region, ); - let span = ctx.span(); - let startup = Box::pin( - handle_client( - config, - backend, - &ctx, - cancellation_handler, - socket, - conn_gauge, - ) - .instrument(span.clone()), - ); - let res = startup.await; + let res = handle_client( + config, + backend, + &ctx, + cancellation_handler, + socket, + conn_gauge, + cancellations, + ) + .instrument(ctx.span()) + .boxed() + .await; match res { Err(e) => { - // todo: log and push to ctx the error kind ctx.set_error_kind(e.get_error_kind()); - error!(parent: &span, "per-client task finished with an error: {e:#}"); + error!(parent: &ctx.span(), "per-client task finished with an error: {e:#}"); } Ok(None) => { ctx.set_success(); } Ok(Some(p)) => { ctx.set_success(); - ctx.log_connect(); - match p.proxy_pass().instrument(span.clone()).await { + let _disconnect = ctx.log_connect(); + match p.proxy_pass().await { Ok(()) => {} Err(ErrorSource::Client(e)) => { - error!(parent: &span, "per-client task finished with an IO error from the client: {e:#}"); + error!(?session_id, "per-client task finished with an IO error from the client: {e:#}"); } Err(ErrorSource::Compute(e)) => { - error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}"); + error!(?session_id, "per-client task finished with an IO error from the compute: {e:#}"); } } } @@ -130,10 +130,12 @@ pub async fn task_main( } connections.close(); + cancellations.close(); drop(listener); // Drain connections connections.wait().await; + cancellations.wait().await; Ok(()) } @@ -145,6 +147,7 @@ pub(crate) async fn handle_client( cancellation_handler: Arc, stream: S, conn_gauge: NumClientConnectionsGuard<'static>, + cancellations: tokio_util::task::task_tracker::TaskTracker, ) -> Result>, ClientRequestError> { debug!( protocol = %ctx.protocol(), @@ -156,25 +159,41 @@ pub(crate) async fn handle_client( let request_gauge = metrics.connection_requests.guard(proto); let tls = config.tls_config.as_ref(); - let record_handshake_error = !ctx.has_private_peer_addr(); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); let do_handshake = handshake(ctx, stream, tls, record_handshake_error); + let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? { HandshakeData::Startup(stream, params) => (stream, params), HandshakeData::Cancel(cancel_key_data) => { - return Ok(cancellation_handler - .cancel_session(cancel_key_data, ctx.session_id()) - .await - .map(|()| None)?) + // spawn a task to cancel the session, but don't wait for it + cancellations.spawn({ + let cancellation_handler_clone = Arc::clone(&cancellation_handler); + let session_id = ctx.session_id(); + let peer_ip = ctx.peer_addr(); + async move { + drop( + cancellation_handler_clone + .cancel_session( + cancel_key_data, + session_id, + peer_ip, + config.authentication_config.ip_allowlist_check_enabled, + ) + .await, + ); + } + }); + + return Ok(None); } }; drop(pause); ctx.set_db_options(params.clone()); - let user_info = match backend + let (user_info, ip_allowlist) = match backend .authenticate(ctx, &config.authentication_config, &mut stream) .await { @@ -198,6 +217,8 @@ pub(crate) async fn handle_client( .or_else(|e| stream.throw_error(e)) .await?; + node.cancel_closure + .set_ip_allowlist(ip_allowlist.unwrap_or_default()); let session = cancellation_handler.get_session(); prepare_client_connection(&node, &session, &mut stream).await?; @@ -212,6 +233,7 @@ pub(crate) async fn handle_client( client: stream, aux: node.aux.clone(), compute: node, + session_id: ctx.session_id(), _req: request_gauge, _conn: conn_gauge, _cancel: session, diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs index 6d2d2d51ce..4a063a5faa 100644 --- a/proxy/src/context/mod.rs +++ b/proxy/src/context/mod.rs @@ -8,7 +8,7 @@ use pq_proto::StartupMessageParams; use smol_str::SmolStr; use tokio::sync::mpsc; use tracing::field::display; -use tracing::{debug, info_span, Span}; +use tracing::{debug, error, info_span, Span}; use try_lock::TryLock; use uuid::Uuid; @@ -272,11 +272,14 @@ impl RequestContext { this.success = true; } - pub fn log_connect(&self) { - self.0 - .try_lock() - .expect("should not deadlock") - .log_connect(); + pub fn log_connect(self) -> DisconnectLogger { + let mut this = self.0.into_inner(); + this.log_connect(); + + // close current span. + this.span = Span::none(); + + DisconnectLogger(this) } pub(crate) fn protocol(&self) -> Protocol { @@ -411,10 +414,13 @@ impl RequestContextInner { outcome, }); } + if let Some(tx) = self.sender.take() { - tx.send(RequestData::from(&*self)) - .inspect_err(|e| debug!("tx send failed: {e}")) - .ok(); + // If type changes, this error handling needs to be updated. + let tx: mpsc::UnboundedSender = tx; + if let Err(e) = tx.send(RequestData::from(&*self)) { + error!("log_connect channel send failed: {e}"); + } } } @@ -423,9 +429,11 @@ impl RequestContextInner { // Here we log the length of the session. self.disconnect_timestamp = Some(Utc::now()); if let Some(tx) = self.disconnect_sender.take() { - tx.send(RequestData::from(&*self)) - .inspect_err(|e| debug!("tx send failed: {e}")) - .ok(); + // If type changes, this error handling needs to be updated. + let tx: mpsc::UnboundedSender = tx; + if let Err(e) = tx.send(RequestData::from(&*self)) { + error!("log_disconnect channel send failed: {e}"); + } } } } @@ -434,8 +442,14 @@ impl Drop for RequestContextInner { fn drop(&mut self) { if self.sender.is_some() { self.log_connect(); - } else { - self.log_disconnect(); } } } + +pub struct DisconnectLogger(RequestContextInner); + +impl Drop for DisconnectLogger { + fn drop(&mut self) { + self.0.log_disconnect(); + } +} diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index 9bf3a275bb..e328c6de79 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -398,7 +398,7 @@ async fn upload_parquet( .err(); if let Some(err) = maybe_err { - tracing::warn!(%id, %err, "failed to upload request data"); + tracing::error!(%id, error = ?err, "failed to upload request data"); } Ok(buffer.writer()) diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index 500acad50f..9537d717a1 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -114,7 +114,7 @@ impl MockControlPlane { Ok((secret, allowed_ips)) } - .map_err(crate::error::log_error::) + .inspect_err(|e: &GetAuthInfoError| tracing::error!("{e}")) .instrument(info_span!("postgres", url = self.endpoint.as_str())) .await?; Ok(AuthInfo { diff --git a/proxy/src/control_plane/client/neon.rs b/proxy/src/control_plane/client/neon.rs index 757ea6720a..2cad981d01 100644 --- a/proxy/src/control_plane/client/neon.rs +++ b/proxy/src/control_plane/client/neon.rs @@ -134,8 +134,8 @@ impl NeonControlPlaneClient { project_id: body.project_id, }) } - .map_err(crate::error::log_error) - .instrument(info_span!("http", id = request_id)) + .inspect_err(|e| tracing::debug!(error = ?e)) + .instrument(info_span!("do_get_auth_info")) .await } @@ -193,8 +193,8 @@ impl NeonControlPlaneClient { Ok(rules) } - .map_err(crate::error::log_error) - .instrument(info_span!("http", id = request_id)) + .inspect_err(|e| tracing::debug!(error = ?e)) + .instrument(info_span!("do_get_endpoint_jwks")) .await } @@ -252,9 +252,8 @@ impl NeonControlPlaneClient { Ok(node) } - .map_err(crate::error::log_error) - // TODO: redo this span stuff - .instrument(info_span!("http", id = request_id)) + .inspect_err(|e| tracing::debug!(error = ?e)) + .instrument(info_span!("do_wake_compute")) .await } } diff --git a/proxy/src/error.rs b/proxy/src/error.rs index 7b693a7418..2221aac407 100644 --- a/proxy/src/error.rs +++ b/proxy/src/error.rs @@ -10,12 +10,6 @@ pub(crate) fn io_error(e: impl Into>) -> io::Err io::Error::new(io::ErrorKind::Other, e) } -/// A small combinator for pluggable error logging. -pub(crate) fn log_error(e: E) -> E { - tracing::error!("{e}"); - e -} - /// Marks errors that may be safely shown to a client. /// This trait can be seen as a specialized version of [`ToString`]. /// diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index ad7e1d2771..ba69f9cf2d 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -88,6 +88,7 @@ pub mod jemalloc; pub mod logging; pub mod metrics; pub mod parse; +pub mod postgres_rustls; pub mod protocol2; pub mod proxy; pub mod rate_limiter; diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index f91fcd4120..659c57c865 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -351,6 +351,7 @@ pub enum CancellationSource { pub enum CancellationOutcome { NotFound, Found, + RateLimitExceeded, } #[derive(LabelGroup)] diff --git a/proxy/src/postgres_rustls/mod.rs b/proxy/src/postgres_rustls/mod.rs new file mode 100644 index 0000000000..31e7915e89 --- /dev/null +++ b/proxy/src/postgres_rustls/mod.rs @@ -0,0 +1,158 @@ +use std::convert::TryFrom; +use std::sync::Arc; + +use rustls::pki_types::ServerName; +use rustls::ClientConfig; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio_postgres::tls::MakeTlsConnect; + +mod private { + use std::future::Future; + use std::io; + use std::pin::Pin; + use std::task::{Context, Poll}; + + use rustls::pki_types::ServerName; + use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; + use tokio_postgres::tls::{ChannelBinding, TlsConnect}; + use tokio_rustls::client::TlsStream; + use tokio_rustls::TlsConnector; + + use crate::config::TlsServerEndPoint; + + pub struct TlsConnectFuture { + inner: tokio_rustls::Connect, + } + + impl Future for TlsConnectFuture + where + S: AsyncRead + AsyncWrite + Unpin, + { + type Output = io::Result>; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.inner).poll(cx).map_ok(RustlsStream) + } + } + + pub struct RustlsConnect(pub RustlsConnectData); + + pub struct RustlsConnectData { + pub hostname: ServerName<'static>, + pub connector: TlsConnector, + } + + impl TlsConnect for RustlsConnect + where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, + { + type Stream = RustlsStream; + type Error = io::Error; + type Future = TlsConnectFuture; + + fn connect(self, stream: S) -> Self::Future { + TlsConnectFuture { + inner: self.0.connector.connect(self.0.hostname, stream), + } + } + } + + pub struct RustlsStream(TlsStream); + + impl tokio_postgres::tls::TlsStream for RustlsStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn channel_binding(&self) -> ChannelBinding { + let (_, session) = self.0.get_ref(); + match session.peer_certificates() { + Some([cert, ..]) => TlsServerEndPoint::new(cert) + .ok() + .and_then(|cb| match cb { + TlsServerEndPoint::Sha256(hash) => Some(hash), + TlsServerEndPoint::Undefined => None, + }) + .map_or_else(ChannelBinding::none, |hash| { + ChannelBinding::tls_server_end_point(hash.to_vec()) + }), + _ => ChannelBinding::none(), + } + } + } + + impl AsyncRead for RustlsStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_read(cx, buf) + } + } + + impl AsyncWrite for RustlsStream + where + S: AsyncRead + AsyncWrite + Unpin, + { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.0).poll_write(cx, buf) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_flush(cx) + } + + fn poll_shutdown( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + Pin::new(&mut self.0).poll_shutdown(cx) + } + } +} + +/// A `MakeTlsConnect` implementation using `rustls`. +/// +/// That way you can connect to PostgreSQL using `rustls` as the TLS stack. +#[derive(Clone)] +pub struct MakeRustlsConnect { + config: Arc, +} + +impl MakeRustlsConnect { + /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`. + #[must_use] + pub fn new(config: ClientConfig) -> Self { + Self { + config: Arc::new(config), + } + } +} + +impl MakeTlsConnect for MakeRustlsConnect +where + S: AsyncRead + AsyncWrite + Unpin + Send + 'static, +{ + type Stream = private::RustlsStream; + type TlsConnect = private::RustlsConnect; + type Error = rustls::pki_types::InvalidDnsNameError; + + fn make_tls_connect(&mut self, hostname: &str) -> Result { + ServerName::try_from(hostname).map(|dns_name| { + private::RustlsConnect(private::RustlsConnectData { + hostname: dns_name.to_owned(), + connector: Arc::clone(&self.config).into(), + }) + }) + } +} diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 9415b54a4a..956036d29d 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -10,7 +10,7 @@ pub(crate) mod wake_compute; use std::sync::Arc; pub use copy_bidirectional::{copy_bidirectional_client_compute, ErrorSource}; -use futures::TryFutureExt; +use futures::{FutureExt, TryFutureExt}; use itertools::Itertools; use once_cell::sync::OnceCell; use pq_proto::{BeMessage as Be, StartupMessageParams}; @@ -69,6 +69,7 @@ pub async fn task_main( socket2::SockRef::from(&listener).set_keepalive(true)?; let connections = tokio_util::task::task_tracker::TaskTracker::new(); + let cancellations = tokio_util::task::task_tracker::TaskTracker::new(); while let Some(accept_result) = run_until_cancelled(listener.accept(), &cancellation_token).await @@ -82,6 +83,7 @@ pub async fn task_main( let session_id = uuid::Uuid::new_v4(); let cancellation_handler = Arc::clone(&cancellation_handler); + let cancellations = cancellations.clone(); debug!(protocol = "tcp", %session_id, "accepted new TCP connection"); let endpoint_rate_limiter2 = endpoint_rate_limiter.clone(); @@ -123,42 +125,40 @@ pub async fn task_main( crate::metrics::Protocol::Tcp, &config.region, ); - let span = ctx.span(); - let startup = Box::pin( - handle_client( - config, - auth_backend, - &ctx, - cancellation_handler, - socket, - ClientMode::Tcp, - endpoint_rate_limiter2, - conn_gauge, - ) - .instrument(span.clone()), - ); - let res = startup.await; + let res = handle_client( + config, + auth_backend, + &ctx, + cancellation_handler, + socket, + ClientMode::Tcp, + endpoint_rate_limiter2, + conn_gauge, + cancellations, + ) + .instrument(ctx.span()) + .boxed() + .await; match res { Err(e) => { - // todo: log and push to ctx the error kind ctx.set_error_kind(e.get_error_kind()); - warn!(parent: &span, "per-client task finished with an error: {e:#}"); + warn!(parent: &ctx.span(), "per-client task finished with an error: {e:#}"); } Ok(None) => { ctx.set_success(); } Ok(Some(p)) => { ctx.set_success(); - ctx.log_connect(); - match p.proxy_pass().instrument(span.clone()).await { + let _disconnect = ctx.log_connect(); + match p.proxy_pass().await { Ok(()) => {} Err(ErrorSource::Client(e)) => { - warn!(parent: &span, "per-client task finished with an IO error from the client: {e:#}"); + warn!(?session_id, "per-client task finished with an IO error from the client: {e:#}"); } Err(ErrorSource::Compute(e)) => { - error!(parent: &span, "per-client task finished with an IO error from the compute: {e:#}"); + error!(?session_id, "per-client task finished with an IO error from the compute: {e:#}"); } } } @@ -167,10 +167,12 @@ pub async fn task_main( } connections.close(); + cancellations.close(); drop(listener); // Drain connections connections.wait().await; + cancellations.wait().await; Ok(()) } @@ -253,6 +255,7 @@ pub(crate) async fn handle_client( mode: ClientMode, endpoint_rate_limiter: Arc, conn_gauge: NumClientConnectionsGuard<'static>, + cancellations: tokio_util::task::task_tracker::TaskTracker, ) -> Result>, ClientRequestError> { debug!( protocol = %ctx.protocol(), @@ -268,14 +271,31 @@ pub(crate) async fn handle_client( let record_handshake_error = !ctx.has_private_peer_addr(); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client); let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error); + let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? { HandshakeData::Startup(stream, params) => (stream, params), HandshakeData::Cancel(cancel_key_data) => { - return Ok(cancellation_handler - .cancel_session(cancel_key_data, ctx.session_id()) - .await - .map(|()| None)?) + // spawn a task to cancel the session, but don't wait for it + cancellations.spawn({ + let cancellation_handler_clone = Arc::clone(&cancellation_handler); + let session_id = ctx.session_id(); + let peer_ip = ctx.peer_addr(); + async move { + drop( + cancellation_handler_clone + .cancel_session( + cancel_key_data, + session_id, + peer_ip, + config.authentication_config.ip_allowlist_check_enabled, + ) + .await, + ); + } + }); + + return Ok(None); } }; drop(pause); @@ -346,6 +366,7 @@ pub(crate) async fn handle_client( client: stream, aux: node.aux.clone(), compute: node, + session_id: ctx.session_id(), _req: request_gauge, _conn: conn_gauge, _cancel: session, diff --git a/proxy/src/proxy/passthrough.rs b/proxy/src/proxy/passthrough.rs index 5e07c8eeae..dcaa81e5cd 100644 --- a/proxy/src/proxy/passthrough.rs +++ b/proxy/src/proxy/passthrough.rs @@ -59,6 +59,7 @@ pub(crate) struct ProxyPassthrough { pub(crate) client: Stream, pub(crate) compute: PostgresConnection, pub(crate) aux: MetricsAuxInfo, + pub(crate) session_id: uuid::Uuid, pub(crate) _req: NumConnectionRequestsGuard<'static>, pub(crate) _conn: NumClientConnectionsGuard<'static>, @@ -69,7 +70,7 @@ impl ProxyPassthrough { pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> { let res = proxy_pass(self.client, self.compute.stream, self.aux).await; if let Err(err) = self.compute.cancel_closure.try_cancel_query().await { - tracing::warn!(?err, "could not cancel the query in the database"); + tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database"); } res } diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index 3de8ca8736..2c2c2964b6 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -14,7 +14,6 @@ use rustls::pki_types; use tokio::io::DuplexStream; use tokio_postgres::config::SslMode; use tokio_postgres::tls::{MakeTlsConnect, NoTls}; -use tokio_postgres_rustls::MakeRustlsConnect; use super::connect_compute::ConnectMechanism; use super::retry::CouldRetry; @@ -29,6 +28,7 @@ use crate::control_plane::{ self, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo, NodeInfoCache, }; use crate::error::ErrorKind; +use crate::postgres_rustls::MakeRustlsConnect; use crate::types::{BranchId, EndpointId, ProjectId}; use crate::{sasl, scram}; diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs index 8a672d48dc..4e9206feff 100644 --- a/proxy/src/proxy/wake_compute.rs +++ b/proxy/src/proxy/wake_compute.rs @@ -1,9 +1,9 @@ -use tracing::{error, info, warn}; +use tracing::{error, info}; use super::connect_compute::ComputeConnectBackend; use crate::config::RetryConfig; use crate::context::RequestContext; -use crate::control_plane::errors::WakeComputeError; +use crate::control_plane::errors::{ControlPlaneError, WakeComputeError}; use crate::control_plane::CachedNodeInfo; use crate::error::ReportableError; use crate::metrics::{ @@ -11,6 +11,18 @@ use crate::metrics::{ }; use crate::proxy::retry::{retry_after, should_retry}; +// Use macro to retain original callsite. +macro_rules! log_wake_compute_error { + (error = ?$error:expr, $num_retries:expr, retriable = $retriable:literal) => { + match $error { + WakeComputeError::ControlPlane(ControlPlaneError::Message(_)) => { + info!(error = ?$error, num_retries = $num_retries, retriable = $retriable, "couldn't wake compute node") + } + _ => error!(error = ?$error, num_retries = $num_retries, retriable = $retriable, "couldn't wake compute node"), + } + }; +} + pub(crate) async fn wake_compute( num_retries: &mut u32, ctx: &RequestContext, @@ -20,7 +32,7 @@ pub(crate) async fn wake_compute( loop { match api.wake_compute(ctx).await { Err(e) if !should_retry(&e, *num_retries, config) => { - error!(error = ?e, num_retries, retriable = false, "couldn't wake compute node"); + log_wake_compute_error!(error = ?e, num_retries, retriable = false); report_error(&e, false); Metrics::get().proxy.retries_metric.observe( RetriesMetricGroup { @@ -32,7 +44,7 @@ pub(crate) async fn wake_compute( return Err(e); } Err(e) => { - warn!(error = ?e, num_retries, retriable = true, "couldn't wake compute node"); + log_wake_compute_error!(error = ?e, num_retries, retriable = true); report_error(&e, true); } Ok(n) => { diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index 4259fd04f4..a048721e77 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -14,13 +14,13 @@ use tracing::info; use crate::intern::EndpointIdInt; -pub(crate) struct GlobalRateLimiter { +pub struct GlobalRateLimiter { data: Vec, info: Vec, } impl GlobalRateLimiter { - pub(crate) fn new(info: Vec) -> Self { + pub fn new(info: Vec) -> Self { Self { data: vec![ RateBucket { @@ -34,7 +34,7 @@ impl GlobalRateLimiter { } /// Check that number of connections is below `max_rps` rps. - pub(crate) fn check(&mut self) -> bool { + pub fn check(&mut self) -> bool { let now = Instant::now(); let should_allow_request = self @@ -137,6 +137,19 @@ impl RateBucketInfo { Self::new(200, Duration::from_secs(600)), ]; + /// All of these are per endpoint-maskedip pair. + /// Context: 4096 rounds of pbkdf2 take about 1ms of cpu time to execute (1 milli-cpu-second or 1mcpus). + /// + /// First bucket: 1000mcpus total per endpoint-ip pair + /// * 4096000 requests per second with 1 hash rounds. + /// * 1000 requests per second with 4096 hash rounds. + /// * 6.8 requests per second with 600000 hash rounds. + pub const DEFAULT_AUTH_SET: [Self; 3] = [ + Self::new(1000 * 4096, Duration::from_secs(1)), + Self::new(600 * 4096, Duration::from_secs(60)), + Self::new(300 * 4096, Duration::from_secs(600)), + ]; + pub fn rps(&self) -> f64 { (self.max_rpi as f64) / self.interval.as_secs_f64() } diff --git a/proxy/src/rate_limiter/mod.rs b/proxy/src/rate_limiter/mod.rs index 3ae2ecaf8f..5f90102da3 100644 --- a/proxy/src/rate_limiter/mod.rs +++ b/proxy/src/rate_limiter/mod.rs @@ -8,5 +8,4 @@ pub(crate) use limit_algorithm::aimd::Aimd; pub(crate) use limit_algorithm::{ DynamicLimiter, Outcome, RateLimitAlgorithm, RateLimiterConfig, Token, }; -pub(crate) use limiter::GlobalRateLimiter; -pub use limiter::{BucketRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; +pub use limiter::{BucketRateLimiter, GlobalRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; diff --git a/proxy/src/redis/cancellation_publisher.rs b/proxy/src/redis/cancellation_publisher.rs index 7392b0d316..228dbb7f64 100644 --- a/proxy/src/redis/cancellation_publisher.rs +++ b/proxy/src/redis/cancellation_publisher.rs @@ -1,3 +1,4 @@ +use core::net::IpAddr; use std::sync::Arc; use pq_proto::CancelKeyData; @@ -15,6 +16,7 @@ pub trait CancellationPublisherMut: Send + Sync + 'static { &mut self, cancel_key_data: CancelKeyData, session_id: Uuid, + peer_addr: IpAddr, ) -> anyhow::Result<()>; } @@ -24,6 +26,7 @@ pub trait CancellationPublisher: Send + Sync + 'static { &self, cancel_key_data: CancelKeyData, session_id: Uuid, + peer_addr: IpAddr, ) -> anyhow::Result<()>; } @@ -32,6 +35,7 @@ impl CancellationPublisher for () { &self, _cancel_key_data: CancelKeyData, _session_id: Uuid, + _peer_addr: IpAddr, ) -> anyhow::Result<()> { Ok(()) } @@ -42,8 +46,10 @@ impl CancellationPublisherMut for P { &mut self, cancel_key_data: CancelKeyData, session_id: Uuid, + peer_addr: IpAddr, ) -> anyhow::Result<()> { -

::try_publish(self, cancel_key_data, session_id).await +

::try_publish(self, cancel_key_data, session_id, peer_addr) + .await } } @@ -52,9 +58,10 @@ impl CancellationPublisher for Option

{ &self, cancel_key_data: CancelKeyData, session_id: Uuid, + peer_addr: IpAddr, ) -> anyhow::Result<()> { if let Some(p) = self { - p.try_publish(cancel_key_data, session_id).await + p.try_publish(cancel_key_data, session_id, peer_addr).await } else { Ok(()) } @@ -66,10 +73,11 @@ impl CancellationPublisher for Arc> { &self, cancel_key_data: CancelKeyData, session_id: Uuid, + peer_addr: IpAddr, ) -> anyhow::Result<()> { self.lock() .await - .try_publish(cancel_key_data, session_id) + .try_publish(cancel_key_data, session_id, peer_addr) .await } } @@ -97,11 +105,13 @@ impl RedisPublisherClient { &mut self, cancel_key_data: CancelKeyData, session_id: Uuid, + peer_addr: IpAddr, ) -> anyhow::Result<()> { let payload = serde_json::to_string(&Notification::Cancel(CancelSession { region_id: Some(self.region_id.clone()), cancel_key_data, session_id, + peer_addr: Some(peer_addr), }))?; let _: () = self.client.publish(PROXY_CHANNEL_NAME, payload).await?; Ok(()) @@ -120,13 +130,14 @@ impl RedisPublisherClient { &mut self, cancel_key_data: CancelKeyData, session_id: Uuid, + peer_addr: IpAddr, ) -> anyhow::Result<()> { // TODO: review redundant error duplication logs. if !self.limiter.check() { tracing::info!("Rate limit exceeded. Skipping cancellation message"); return Err(anyhow::anyhow!("Rate limit exceeded")); } - match self.publish(cancel_key_data, session_id).await { + match self.publish(cancel_key_data, session_id, peer_addr).await { Ok(()) => return Ok(()), Err(e) => { tracing::error!("failed to publish a message: {e}"); @@ -134,7 +145,7 @@ impl RedisPublisherClient { } tracing::info!("Publisher is disconnected. Reconnectiong..."); self.try_connect().await?; - self.publish(cancel_key_data, session_id).await + self.publish(cancel_key_data, session_id, peer_addr).await } } @@ -143,9 +154,13 @@ impl CancellationPublisherMut for RedisPublisherClient { &mut self, cancel_key_data: CancelKeyData, session_id: Uuid, + peer_addr: IpAddr, ) -> anyhow::Result<()> { tracing::info!("publishing cancellation key to Redis"); - match self.try_publish_internal(cancel_key_data, session_id).await { + match self + .try_publish_internal(cancel_key_data, session_id, peer_addr) + .await + { Ok(()) => { tracing::debug!("cancellation key successfuly published to Redis"); Ok(()) diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 62e7b1b565..9ac07b7e90 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -60,6 +60,7 @@ pub(crate) struct CancelSession { pub(crate) region_id: Option, pub(crate) cancel_key_data: CancelKeyData, pub(crate) session_id: Uuid, + pub(crate) peer_addr: Option, } fn deserialize_json_string<'de, D, T>(deserializer: D) -> Result @@ -137,10 +138,20 @@ impl MessageHandler { return Ok(()); } } + + // TODO: Remove unspecified peer_addr after the complete migration to the new format + let peer_addr = cancel_session + .peer_addr + .unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED)); // This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message. match self .cancellation_handler - .cancel_session(cancel_session.cancel_key_data, uuid::Uuid::nil()) + .cancel_session( + cancel_session.cancel_key_data, + uuid::Uuid::nil(), + peer_addr, + cancel_session.peer_addr.is_some(), + ) .await { Ok(()) => {} @@ -335,6 +346,7 @@ mod tests { cancel_key_data, region_id: None, session_id: uuid, + peer_addr: None, }); let text = serde_json::to_string(&msg)?; let result: Notification = serde_json::from_str(&text)?; @@ -344,6 +356,7 @@ mod tests { cancel_key_data, region_id: Some("region".to_string()), session_id: uuid, + peer_addr: None, }); let text = serde_json::to_string(&msg)?; let result: Notification = serde_json::from_str(&text)?; diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 3037e20888..75909f3358 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -333,7 +333,7 @@ impl PoolingBackend { debug!("setting up backend session state"); // initiates the auth session - if let Err(e) = client.query("select auth.init()", &[]).await { + if let Err(e) = client.execute("select auth.init()", &[]).await { discard.discard(); return Err(e.into()); } diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index bd262f45ed..c302eac568 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -6,9 +6,10 @@ use std::task::{ready, Poll}; use futures::future::poll_fn; use futures::Future; use smallvec::SmallVec; +use tokio::net::TcpStream; use tokio::time::Instant; use tokio_postgres::tls::NoTlsStream; -use tokio_postgres::{AsyncMessage, Socket}; +use tokio_postgres::AsyncMessage; use tokio_util::sync::CancellationToken; use tracing::{error, info, info_span, warn, Instrument}; #[cfg(test)] @@ -57,7 +58,7 @@ pub(crate) fn poll_client( ctx: &RequestContext, conn_info: ConnInfo, client: C, - mut connection: tokio_postgres::Connection, + mut connection: tokio_postgres::Connection, conn_id: uuid::Uuid, aux: MetricsAuxInfo, ) -> Client { diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs index 9abe35db08..db9ac49dae 100644 --- a/proxy/src/serverless/local_conn_pool.rs +++ b/proxy/src/serverless/local_conn_pool.rs @@ -24,10 +24,11 @@ use p256::ecdsa::{Signature, SigningKey}; use parking_lot::RwLock; use serde_json::value::RawValue; use signature::Signer; +use tokio::net::TcpStream; use tokio::time::Instant; use tokio_postgres::tls::NoTlsStream; use tokio_postgres::types::ToSql; -use tokio_postgres::{AsyncMessage, Socket}; +use tokio_postgres::AsyncMessage; use tokio_util::sync::CancellationToken; use tracing::{debug, error, info, info_span, warn, Instrument}; @@ -163,7 +164,7 @@ pub(crate) fn poll_client( ctx: &RequestContext, conn_info: ConnInfo, client: C, - mut connection: tokio_postgres::Connection, + mut connection: tokio_postgres::Connection, key: SigningKey, conn_id: uuid::Uuid, aux: MetricsAuxInfo, @@ -286,11 +287,11 @@ impl ClientInnerCommon { let token = resign_jwt(&local_data.key, payload, local_data.jti)?; // initiates the auth session - self.inner.simple_query("discard all").await?; + self.inner.batch_execute("discard all").await?; self.inner - .query( + .execute( "select auth.jwt_session_init($1)", - &[&token as &(dyn ToSql + Sync)], + &[&&*token as &(dyn ToSql + Sync)], ) .await?; diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index 77025f419d..80b42f9e55 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -132,6 +132,7 @@ pub async fn task_main( let connections = tokio_util::task::task_tracker::TaskTracker::new(); connections.close(); // allows `connections.wait to complete` + let cancellations = tokio_util::task::task_tracker::TaskTracker::new(); while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await { let (conn, peer_addr) = res.context("could not accept TCP stream")?; if let Err(e) = conn.set_nodelay(true) { @@ -160,6 +161,7 @@ pub async fn task_main( let connections2 = connections.clone(); let cancellation_handler = cancellation_handler.clone(); let endpoint_rate_limiter = endpoint_rate_limiter.clone(); + let cancellations = cancellations.clone(); connections.spawn( async move { let conn_token2 = conn_token.clone(); @@ -188,6 +190,7 @@ pub async fn task_main( config, backend, connections2, + cancellations, cancellation_handler, endpoint_rate_limiter, conn_token, @@ -313,6 +316,7 @@ async fn connection_handler( config: &'static ProxyConfig, backend: Arc, connections: TaskTracker, + cancellations: TaskTracker, cancellation_handler: Arc, endpoint_rate_limiter: Arc, cancellation_token: CancellationToken, @@ -353,6 +357,7 @@ async fn connection_handler( // `request_handler` is not cancel safe. It expects to be cancelled only at specific times. // By spawning the future, we ensure it never gets cancelled until it decides to. + let cancellations = cancellations.clone(); let handler = connections.spawn( request_handler( req, @@ -364,6 +369,7 @@ async fn connection_handler( conn_info2.clone(), http_request_token, endpoint_rate_limiter.clone(), + cancellations, ) .in_current_span() .map_ok_or_else(api_error_into_response, |r| r), @@ -411,6 +417,7 @@ async fn request_handler( // used to cancel in-flight HTTP requests. not used to cancel websockets http_cancellation_token: CancellationToken, endpoint_rate_limiter: Arc, + cancellations: TaskTracker, ) -> Result>, ApiError> { let host = request .headers() @@ -436,6 +443,7 @@ async fn request_handler( let (response, websocket) = framed_websockets::upgrade::upgrade(&mut request) .map_err(|e| ApiError::BadRequest(e.into()))?; + let cancellations = cancellations.clone(); ws_connections.spawn( async move { if let Err(e) = websocket::serve_websocket( @@ -446,6 +454,7 @@ async fn request_handler( cancellation_handler, endpoint_rate_limiter, host, + cancellations, ) .await { diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 03b37bccd5..afd93d02f0 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -14,7 +14,7 @@ use hyper::{header, HeaderMap, Request, Response, StatusCode}; use pq_proto::StartupMessageParamsBuilder; use serde::Serialize; use serde_json::Value; -use tokio::time; +use tokio::time::{self, Instant}; use tokio_postgres::error::{DbError, ErrorPosition, SqlState}; use tokio_postgres::{GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, Transaction}; use tokio_util::sync::CancellationToken; @@ -980,10 +980,11 @@ async fn query_to_json( current_size: &mut usize, parsed_headers: HttpHeaders, ) -> Result<(ReadyForQueryStatus, impl Serialize), SqlOverHttpError> { - info!("executing query"); + let query_start = Instant::now(); + let query_params = data.params; let mut row_stream = std::pin::pin!(client.query_raw_txt(&data.query, query_params).await?); - info!("finished executing query"); + let query_acknowledged = Instant::now(); // Manually drain the stream into a vector to leave row_stream hanging // around to get a command tag. Also check that the response is not too @@ -1002,6 +1003,7 @@ async fn query_to_json( } } + let query_resp_end = Instant::now(); let ready = row_stream.ready_status(); // grab the command tag and number of rows affected @@ -1021,7 +1023,9 @@ async fn query_to_json( rows = rows.len(), ?ready, command_tag, - "finished reading rows" + acknowledgement = ?(query_acknowledged - query_start), + response = ?(query_resp_end - query_start), + "finished executing query" ); let columns_len = row_stream.columns().len(); diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 4088fea835..bdb83fe6be 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -123,6 +123,7 @@ impl AsyncBufRead for WebSocketRw { } } +#[allow(clippy::too_many_arguments)] pub(crate) async fn serve_websocket( config: &'static ProxyConfig, auth_backend: &'static crate::auth::Backend<'static, ()>, @@ -131,6 +132,7 @@ pub(crate) async fn serve_websocket( cancellation_handler: Arc, endpoint_rate_limiter: Arc, hostname: Option, + cancellations: tokio_util::task::task_tracker::TaskTracker, ) -> anyhow::Result<()> { let websocket = websocket.await?; let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket)); @@ -149,6 +151,7 @@ pub(crate) async fn serve_websocket( ClientMode::Websockets { hostname }, endpoint_rate_limiter, conn_gauge, + cancellations, )) .await; diff --git a/pyproject.toml b/pyproject.toml index 197946fff8..01d15ee6bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ authors = [] package-mode = false [tool.poetry.dependencies] -python = "^3.9" +python = "^3.11" pytest = "^7.4.4" psycopg2-binary = "^2.9.10" typing-extensions = "^4.6.1" @@ -33,7 +33,7 @@ types-psutil = "^5.9.5.12" types-toml = "^0.10.8.6" pytest-httpserver = "^1.0.8" aiohttp = "3.10.11" -pytest-rerunfailures = "^13.0" +pytest-rerunfailures = "^15.0" types-pytest-lazy-fixture = "^0.6.3.3" pytest-split = "^0.8.1" zstandard = "^0.21.0" @@ -51,7 +51,7 @@ testcontainers = "^4.8.1" jsonnet = "^0.20.0" [tool.poetry.group.dev.dependencies] -mypy = "==1.3.0" +mypy = "==1.13.0" ruff = "^0.7.0" [build-system] @@ -89,7 +89,7 @@ module = [ ignore_missing_imports = true [tool.ruff] -target-version = "py39" +target-version = "py311" extend-exclude = [ "vendor/", "target/", @@ -108,6 +108,3 @@ select = [ "B", # bugbear "UP", # pyupgrade ] - -[tool.ruff.lint.pyupgrade] -keep-runtime-typing = true # Remove this stanza when we require Python 3.10 diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 92b7929c7f..f0661a32e0 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "1.82.0" +channel = "1.83.0" profile = "default" # The default profile includes rustc, rust-std, cargo, rust-docs, rustfmt and clippy. # https://rust-lang.github.io/rustup/concepts/profiles.html diff --git a/safekeeper/Cargo.toml b/safekeeper/Cargo.toml index 85561e4aff..635a9222e1 100644 --- a/safekeeper/Cargo.toml +++ b/safekeeper/Cargo.toml @@ -28,8 +28,10 @@ hyper0.workspace = true futures.workspace = true once_cell.workspace = true parking_lot.workspace = true +pageserver_api.workspace = true postgres.workspace = true postgres-protocol.workspace = true +pprof.workspace = true rand.workspace = true regex.workspace = true scopeguard.workspace = true @@ -57,6 +59,7 @@ sd-notify.workspace = true storage_broker.workspace = true tokio-stream.workspace = true utils.workspace = true +wal_decoder.workspace = true workspace_hack.workspace = true diff --git a/safekeeper/benches/README.md b/safekeeper/benches/README.md index 4119cc8d6e..d73fbccf05 100644 --- a/safekeeper/benches/README.md +++ b/safekeeper/benches/README.md @@ -14,6 +14,10 @@ cargo bench --package safekeeper --bench receive_wal process_msg/fsync=false # List available benchmarks. cargo bench --package safekeeper --benches -- --list + +# Generate flamegraph profiles using pprof-rs, profiling for 10 seconds. +# Output in target/criterion/*/profile/flamegraph.svg. +cargo bench --package safekeeper --bench receive_wal process_msg/fsync=false --profile-time 10 ``` Additional charts and statistics are available in `target/criterion/report/index.html`. diff --git a/safekeeper/benches/receive_wal.rs b/safekeeper/benches/receive_wal.rs index e32d7526ca..c637b4fb24 100644 --- a/safekeeper/benches/receive_wal.rs +++ b/safekeeper/benches/receive_wal.rs @@ -10,6 +10,7 @@ use camino_tempfile::tempfile; use criterion::{criterion_group, criterion_main, BatchSize, Bencher, Criterion}; use itertools::Itertools as _; use postgres_ffi::v17::wal_generator::{LogicalMessageGenerator, WalGenerator}; +use pprof::criterion::{Output, PProfProfiler}; use safekeeper::receive_wal::{self, WalAcceptor}; use safekeeper::safekeeper::{ AcceptorProposerMessage, AppendRequest, AppendRequestHeader, ProposerAcceptorMessage, @@ -24,8 +25,9 @@ const GB: usize = 1024 * MB; // Register benchmarks with Criterion. criterion_group!( - benches, - bench_process_msg, + name = benches; + config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))); + targets = bench_process_msg, bench_wal_acceptor, bench_wal_acceptor_throughput, bench_file_write diff --git a/safekeeper/src/handler.rs b/safekeeper/src/handler.rs index 3f00b69cde..8dd2929a03 100644 --- a/safekeeper/src/handler.rs +++ b/safekeeper/src/handler.rs @@ -2,11 +2,15 @@ //! protocol commands. use anyhow::Context; +use pageserver_api::models::ShardParameters; +use pageserver_api::shard::{ShardIdentity, ShardStripeSize}; use std::future::Future; use std::str::{self, FromStr}; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, info, info_span, Instrument}; +use utils::postgres_client::PostgresClientProtocol; +use utils::shard::{ShardCount, ShardNumber}; use crate::auth::check_permission; use crate::json_ctrl::{handle_json_ctrl, AppendLogicalMessage}; @@ -35,6 +39,8 @@ pub struct SafekeeperPostgresHandler { pub tenant_id: Option, pub timeline_id: Option, pub ttid: TenantTimelineId, + pub shard: Option, + pub protocol: Option, /// Unique connection id is logged in spans for observability. pub conn_id: ConnectionId, /// Auth scope allowed on the connections and public key used to check auth tokens. None if auth is not configured. @@ -107,11 +113,21 @@ impl postgres_backend::Handler ) -> Result<(), QueryError> { if let FeStartupPacket::StartupMessage { params, .. } = sm { if let Some(options) = params.options_raw() { + let mut shard_count: Option = None; + let mut shard_number: Option = None; + let mut shard_stripe_size: Option = None; + for opt in options { // FIXME `ztenantid` and `ztimelineid` left for compatibility during deploy, // remove these after the PR gets deployed: // https://github.com/neondatabase/neon/pull/2433#discussion_r970005064 match opt.split_once('=') { + Some(("protocol", value)) => { + self.protocol = + Some(serde_json::from_str(value).with_context(|| { + format!("Failed to parse {value} as protocol") + })?); + } Some(("ztenantid", value)) | Some(("tenant_id", value)) => { self.tenant_id = Some(value.parse().with_context(|| { format!("Failed to parse {value} as tenant id") @@ -127,9 +143,54 @@ impl postgres_backend::Handler metrics.set_client_az(client_az) } } + Some(("shard_count", value)) => { + shard_count = Some(value.parse::().with_context(|| { + format!("Failed to parse {value} as shard count") + })?); + } + Some(("shard_number", value)) => { + shard_number = Some(value.parse::().with_context(|| { + format!("Failed to parse {value} as shard number") + })?); + } + Some(("shard_stripe_size", value)) => { + shard_stripe_size = Some(value.parse::().with_context(|| { + format!("Failed to parse {value} as shard stripe size") + })?); + } _ => continue, } } + + match self.protocol() { + PostgresClientProtocol::Vanilla => { + if shard_count.is_some() + || shard_number.is_some() + || shard_stripe_size.is_some() + { + return Err(QueryError::Other(anyhow::anyhow!( + "Shard params specified for vanilla protocol" + ))); + } + } + PostgresClientProtocol::Interpreted { .. } => { + match (shard_count, shard_number, shard_stripe_size) { + (Some(count), Some(number), Some(stripe_size)) => { + let params = ShardParameters { + count: ShardCount(count), + stripe_size: ShardStripeSize(stripe_size), + }; + self.shard = + Some(ShardIdentity::from_params(ShardNumber(number), ¶ms)); + } + _ => { + return Err(QueryError::Other(anyhow::anyhow!( + "Shard params were not specified" + ))); + } + } + } + } } if let Some(app_name) = params.get("application_name") { @@ -150,6 +211,12 @@ impl postgres_backend::Handler tracing::field::debug(self.appname.clone()), ); + if let Some(shard) = self.shard.as_ref() { + if let Some(slug) = shard.shard_slug().strip_prefix("-") { + tracing::Span::current().record("shard", tracing::field::display(slug)); + } + } + Ok(()) } else { Err(QueryError::Other(anyhow::anyhow!( @@ -258,6 +325,8 @@ impl SafekeeperPostgresHandler { tenant_id: None, timeline_id: None, ttid: TenantTimelineId::empty(), + shard: None, + protocol: None, conn_id, claims: None, auth, @@ -265,6 +334,10 @@ impl SafekeeperPostgresHandler { } } + pub fn protocol(&self) -> PostgresClientProtocol { + self.protocol.unwrap_or(PostgresClientProtocol::Vanilla) + } + // when accessing management api supply None as an argument // when using to authorize tenant pass corresponding tenant id fn check_permission(&self, tenant_id: Option) -> Result<(), QueryError> { diff --git a/safekeeper/src/http/routes.rs b/safekeeper/src/http/routes.rs index df68f8a68e..28294abdb9 100644 --- a/safekeeper/src/http/routes.rs +++ b/safekeeper/src/http/routes.rs @@ -1,7 +1,6 @@ -use hyper::{Body, Request, Response, StatusCode, Uri}; -use once_cell::sync::Lazy; +use hyper::{Body, Request, Response, StatusCode}; use serde::{Deserialize, Serialize}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::fmt; use std::io::Write as _; use std::str::FromStr; @@ -14,7 +13,9 @@ use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::CancellationToken; use tracing::{info_span, Instrument}; use utils::failpoint_support::failpoints_handler; -use utils::http::endpoint::{prometheus_metrics_handler, request_span, ChannelWriter}; +use utils::http::endpoint::{ + profile_cpu_handler, prometheus_metrics_handler, request_span, ChannelWriter, +}; use utils::http::request::parse_query_param; use postgres_ffi::WAL_SEGMENT_SIZE; @@ -572,14 +573,8 @@ pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder let mut router = endpoint::make_router(); if conf.http_auth.is_some() { router = router.middleware(auth_middleware(|request| { - #[allow(clippy::mutable_key_type)] - static ALLOWLIST_ROUTES: Lazy> = Lazy::new(|| { - ["/v1/status", "/metrics"] - .iter() - .map(|v| v.parse().unwrap()) - .collect() - }); - if ALLOWLIST_ROUTES.contains(request.uri()) { + const ALLOWLIST_ROUTES: &[&str] = &["/v1/status", "/metrics", "/profile/cpu"]; + if ALLOWLIST_ROUTES.contains(&request.uri().path()) { None } else { // Option> is always provided as data below, hence unwrap(). @@ -598,6 +593,7 @@ pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder .data(Arc::new(conf)) .data(auth) .get("/metrics", |r| request_span(r, prometheus_metrics_handler)) + .get("/profile/cpu", |r| request_span(r, profile_cpu_handler)) .get("/v1/status", |r| request_span(r, status_handler)) .put("/v1/failpoints", |r| { request_span(r, move |r| async { diff --git a/safekeeper/src/lib.rs b/safekeeper/src/lib.rs index 6d68b6b59b..abe6e00a66 100644 --- a/safekeeper/src/lib.rs +++ b/safekeeper/src/lib.rs @@ -29,6 +29,7 @@ pub mod receive_wal; pub mod recovery; pub mod remove_wal; pub mod safekeeper; +pub mod send_interpreted_wal; pub mod send_wal; pub mod state; pub mod timeline; @@ -38,6 +39,7 @@ pub mod timeline_manager; pub mod timelines_set; pub mod wal_backup; pub mod wal_backup_partial; +pub mod wal_reader_stream; pub mod wal_service; pub mod wal_storage; diff --git a/safekeeper/src/recovery.rs b/safekeeper/src/recovery.rs index 9c4149d8f1..7b87166aa0 100644 --- a/safekeeper/src/recovery.rs +++ b/safekeeper/src/recovery.rs @@ -17,6 +17,7 @@ use tokio::{ use tokio_postgres::replication::ReplicationStream; use tokio_postgres::types::PgLsn; use tracing::*; +use utils::postgres_client::{ConnectionConfigArgs, PostgresClientProtocol}; use utils::{id::NodeId, lsn::Lsn, postgres_client::wal_stream_connection_config}; use crate::receive_wal::{WalAcceptor, REPLY_QUEUE_SIZE}; @@ -325,7 +326,17 @@ async fn recovery_stream( conf: &SafeKeeperConf, ) -> anyhow::Result { // TODO: pass auth token - let cfg = wal_stream_connection_config(tli.ttid, &donor.pg_connstr, None, None)?; + let connection_conf_args = ConnectionConfigArgs { + protocol: PostgresClientProtocol::Vanilla, + ttid: tli.ttid, + shard_number: None, + shard_count: None, + shard_stripe_size: None, + listen_pg_addr_str: &donor.pg_connstr, + auth_token: None, + availability_zone: None, + }; + let cfg = wal_stream_connection_config(connection_conf_args)?; let mut cfg = cfg.to_tokio_postgres_config(); // It will make safekeeper give out not committed WAL (up to flush_lsn). cfg.application_name(&format!("safekeeper_{}", conf.my_id)); diff --git a/safekeeper/src/send_interpreted_wal.rs b/safekeeper/src/send_interpreted_wal.rs new file mode 100644 index 0000000000..2589030422 --- /dev/null +++ b/safekeeper/src/send_interpreted_wal.rs @@ -0,0 +1,148 @@ +use std::time::Duration; + +use anyhow::Context; +use futures::StreamExt; +use pageserver_api::shard::ShardIdentity; +use postgres_backend::{CopyStreamHandlerEnd, PostgresBackend}; +use postgres_ffi::MAX_SEND_SIZE; +use postgres_ffi::{get_current_timestamp, waldecoder::WalStreamDecoder}; +use pq_proto::{BeMessage, InterpretedWalRecordsBody, WalSndKeepAlive}; +use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::time::MissedTickBehavior; +use utils::lsn::Lsn; +use utils::postgres_client::Compression; +use utils::postgres_client::InterpretedFormat; +use wal_decoder::models::{InterpretedWalRecord, InterpretedWalRecords}; +use wal_decoder::wire_format::ToWireFormat; + +use crate::send_wal::EndWatchView; +use crate::wal_reader_stream::{WalBytes, WalReaderStreamBuilder}; + +/// Shard-aware interpreted record sender. +/// This is used for sending WAL to the pageserver. Said WAL +/// is pre-interpreted and filtered for the shard. +pub(crate) struct InterpretedWalSender<'a, IO> { + pub(crate) format: InterpretedFormat, + pub(crate) compression: Option, + pub(crate) pgb: &'a mut PostgresBackend, + pub(crate) wal_stream_builder: WalReaderStreamBuilder, + pub(crate) end_watch_view: EndWatchView, + pub(crate) shard: ShardIdentity, + pub(crate) pg_version: u32, + pub(crate) appname: Option, +} + +struct Batch { + wal_end_lsn: Lsn, + available_wal_end_lsn: Lsn, + records: InterpretedWalRecords, +} + +impl InterpretedWalSender<'_, IO> { + /// Send interpreted WAL to a receiver. + /// Stops when an error occurs or the receiver is caught up and there's no active compute. + /// + /// Err(CopyStreamHandlerEnd) is always returned; Result is used only for ? + /// convenience. + pub(crate) async fn run(self) -> Result<(), CopyStreamHandlerEnd> { + let mut wal_position = self.wal_stream_builder.start_pos(); + let mut wal_decoder = + WalStreamDecoder::new(self.wal_stream_builder.start_pos(), self.pg_version); + + let stream = self.wal_stream_builder.build(MAX_SEND_SIZE).await?; + let mut stream = std::pin::pin!(stream); + + let mut keepalive_ticker = tokio::time::interval(Duration::from_secs(1)); + keepalive_ticker.set_missed_tick_behavior(MissedTickBehavior::Skip); + keepalive_ticker.reset(); + + let (tx, mut rx) = tokio::sync::mpsc::channel::(2); + + loop { + tokio::select! { + // Get some WAL from the stream and then: decode, interpret and push it down the + // pipeline. + wal = stream.next(), if tx.capacity() > 0 => { + let WalBytes { wal, wal_start_lsn: _, wal_end_lsn, available_wal_end_lsn } = match wal { + Some(some) => some?, + None => { break; } + }; + + wal_position = wal_end_lsn; + wal_decoder.feed_bytes(&wal); + + let mut records = Vec::new(); + let mut max_next_record_lsn = None; + while let Some((next_record_lsn, recdata)) = wal_decoder + .poll_decode() + .with_context(|| "Failed to decode WAL")? + { + assert!(next_record_lsn.is_aligned()); + max_next_record_lsn = Some(next_record_lsn); + + // Deserialize and interpret WAL record + let interpreted = InterpretedWalRecord::from_bytes_filtered( + recdata, + &self.shard, + next_record_lsn, + self.pg_version, + ) + .with_context(|| "Failed to interpret WAL")?; + + if !interpreted.is_empty() { + records.push(interpreted); + } + } + + let batch = InterpretedWalRecords { + records, + next_record_lsn: max_next_record_lsn + }; + + tx.send(Batch {wal_end_lsn, available_wal_end_lsn, records: batch}).await.unwrap(); + }, + // For a previously interpreted batch, serialize it and push it down the wire. + batch = rx.recv() => { + let batch = match batch { + Some(b) => b, + None => { break; } + }; + + let buf = batch + .records + .to_wire(self.format, self.compression) + .await + .with_context(|| "Failed to serialize interpreted WAL") + .map_err(CopyStreamHandlerEnd::from)?; + + // Reset the keep alive ticker since we are sending something + // over the wire now. + keepalive_ticker.reset(); + + self.pgb + .write_message(&BeMessage::InterpretedWalRecords(InterpretedWalRecordsBody { + streaming_lsn: batch.wal_end_lsn.0, + commit_lsn: batch.available_wal_end_lsn.0, + data: &buf, + })).await?; + } + // Send a periodic keep alive when the connection has been idle for a while. + _ = keepalive_ticker.tick() => { + self.pgb + .write_message(&BeMessage::KeepAlive(WalSndKeepAlive { + wal_end: self.end_watch_view.get().0, + timestamp: get_current_timestamp(), + request_reply: true, + })) + .await?; + } + } + } + + // The loop above ends when the receiver is caught up and there's no more WAL to send. + Err(CopyStreamHandlerEnd::ServerInitiated(format!( + "ending streaming to {:?} at {}, receiver is caughtup and there is no computes", + self.appname, wal_position, + ))) + } +} diff --git a/safekeeper/src/send_wal.rs b/safekeeper/src/send_wal.rs index aa65ec851b..225b7f4c05 100644 --- a/safekeeper/src/send_wal.rs +++ b/safekeeper/src/send_wal.rs @@ -5,12 +5,15 @@ use crate::handler::SafekeeperPostgresHandler; use crate::metrics::RECEIVED_PS_FEEDBACKS; use crate::receive_wal::WalReceivers; use crate::safekeeper::{Term, TermLsn}; +use crate::send_interpreted_wal::InterpretedWalSender; use crate::timeline::WalResidentTimeline; +use crate::wal_reader_stream::WalReaderStreamBuilder; use crate::wal_service::ConnectionId; use crate::wal_storage::WalReader; use crate::GlobalTimelines; use anyhow::{bail, Context as AnyhowContext}; use bytes::Bytes; +use futures::future::Either; use parking_lot::Mutex; use postgres_backend::PostgresBackend; use postgres_backend::{CopyStreamHandlerEnd, PostgresBackendReader, QueryError}; @@ -22,6 +25,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use utils::failpoint_support; use utils::id::TenantTimelineId; use utils::pageserver_feedback::PageserverFeedback; +use utils::postgres_client::PostgresClientProtocol; use std::cmp::{max, min}; use std::net::SocketAddr; @@ -226,7 +230,7 @@ impl WalSenders { /// Get remote_consistent_lsn reported by the pageserver. Returns None if /// client is not pageserver. - fn get_ws_remote_consistent_lsn(self: &Arc, id: WalSenderId) -> Option { + pub fn get_ws_remote_consistent_lsn(self: &Arc, id: WalSenderId) -> Option { let shared = self.mutex.lock(); let slot = shared.get_slot(id); match slot.feedback { @@ -370,6 +374,16 @@ pub struct WalSenderGuard { walsenders: Arc, } +impl WalSenderGuard { + pub fn id(&self) -> WalSenderId { + self.id + } + + pub fn walsenders(&self) -> &Arc { + &self.walsenders + } +} + impl Drop for WalSenderGuard { fn drop(&mut self) { self.walsenders.unregister(self.id); @@ -440,11 +454,12 @@ impl SafekeeperPostgresHandler { } info!( - "starting streaming from {:?}, available WAL ends at {}, recovery={}, appname={:?}", + "starting streaming from {:?}, available WAL ends at {}, recovery={}, appname={:?}, protocol={:?}", start_pos, end_pos, matches!(end_watch, EndWatch::Flush(_)), - appname + appname, + self.protocol(), ); // switch to copy @@ -456,21 +471,56 @@ impl SafekeeperPostgresHandler { // not synchronized with sends, so this avoids deadlocks. let reader = pgb.split().context("START_REPLICATION split")?; + let send_fut = match self.protocol() { + PostgresClientProtocol::Vanilla => { + let sender = WalSender { + pgb, + // should succeed since we're already holding another guard + tli: tli.wal_residence_guard().await?, + appname, + start_pos, + end_pos, + term, + end_watch, + ws_guard: ws_guard.clone(), + wal_reader, + send_buf: vec![0u8; MAX_SEND_SIZE], + }; + + Either::Left(sender.run()) + } + PostgresClientProtocol::Interpreted { + format, + compression, + } => { + let pg_version = tli.tli.get_state().await.1.server.pg_version / 10000; + let end_watch_view = end_watch.view(); + let wal_stream_builder = WalReaderStreamBuilder { + tli: tli.wal_residence_guard().await?, + start_pos, + end_pos, + term, + end_watch, + wal_sender_guard: ws_guard.clone(), + }; + + let sender = InterpretedWalSender { + format, + compression, + pgb, + wal_stream_builder, + end_watch_view, + shard: self.shard.unwrap(), + pg_version, + appname, + }; + + Either::Right(sender.run()) + } + }; + let tli_cancel = tli.cancel.clone(); - let mut sender = WalSender { - pgb, - // should succeed since we're already holding another guard - tli: tli.wal_residence_guard().await?, - appname, - start_pos, - end_pos, - term, - end_watch, - ws_guard: ws_guard.clone(), - wal_reader, - send_buf: vec![0u8; MAX_SEND_SIZE], - }; let mut reply_reader = ReplyReader { reader, ws_guard: ws_guard.clone(), @@ -479,7 +529,7 @@ impl SafekeeperPostgresHandler { let res = tokio::select! { // todo: add read|write .context to these errors - r = sender.run() => r, + r = send_fut => r, r = reply_reader.run() => r, _ = tli_cancel.cancelled() => { return Err(CopyStreamHandlerEnd::Cancelled); @@ -504,16 +554,22 @@ impl SafekeeperPostgresHandler { } } +/// TODO(vlad): maybe lift this instead /// Walsender streams either up to commit_lsn (normally) or flush_lsn in the /// given term (recovery by walproposer or peer safekeeper). -enum EndWatch { +#[derive(Clone)] +pub(crate) enum EndWatch { Commit(Receiver), Flush(Receiver), } impl EndWatch { + pub(crate) fn view(&self) -> EndWatchView { + EndWatchView(self.clone()) + } + /// Get current end of WAL. - fn get(&self) -> Lsn { + pub(crate) fn get(&self) -> Lsn { match self { EndWatch::Commit(r) => *r.borrow(), EndWatch::Flush(r) => r.borrow().lsn, @@ -521,15 +577,44 @@ impl EndWatch { } /// Wait for the update. - async fn changed(&mut self) -> anyhow::Result<()> { + pub(crate) async fn changed(&mut self) -> anyhow::Result<()> { match self { EndWatch::Commit(r) => r.changed().await?, EndWatch::Flush(r) => r.changed().await?, } Ok(()) } + + pub(crate) async fn wait_for_lsn( + &mut self, + lsn: Lsn, + client_term: Option, + ) -> anyhow::Result { + loop { + let end_pos = self.get(); + if end_pos > lsn { + return Ok(end_pos); + } + if let EndWatch::Flush(rx) = &self { + let curr_term = rx.borrow().term; + if let Some(client_term) = client_term { + if curr_term != client_term { + bail!("term changed: requested {}, now {}", client_term, curr_term); + } + } + } + self.changed().await?; + } + } } +pub(crate) struct EndWatchView(EndWatch); + +impl EndWatchView { + pub(crate) fn get(&self) -> Lsn { + self.0.get() + } +} /// A half driving sending WAL. struct WalSender<'a, IO> { pgb: &'a mut PostgresBackend, @@ -566,7 +651,7 @@ impl WalSender<'_, IO> { /// /// Err(CopyStreamHandlerEnd) is always returned; Result is used only for ? /// convenience. - async fn run(&mut self) -> Result<(), CopyStreamHandlerEnd> { + async fn run(mut self) -> Result<(), CopyStreamHandlerEnd> { loop { // Wait for the next portion if it is not there yet, or just // update our end of WAL available for sending value, we diff --git a/safekeeper/src/wal_reader_stream.rs b/safekeeper/src/wal_reader_stream.rs new file mode 100644 index 0000000000..f8c0c502cd --- /dev/null +++ b/safekeeper/src/wal_reader_stream.rs @@ -0,0 +1,149 @@ +use std::sync::Arc; + +use async_stream::try_stream; +use bytes::Bytes; +use futures::Stream; +use postgres_backend::CopyStreamHandlerEnd; +use std::time::Duration; +use tokio::time::timeout; +use utils::lsn::Lsn; + +use crate::{ + safekeeper::Term, + send_wal::{EndWatch, WalSenderGuard}, + timeline::WalResidentTimeline, +}; + +pub(crate) struct WalReaderStreamBuilder { + pub(crate) tli: WalResidentTimeline, + pub(crate) start_pos: Lsn, + pub(crate) end_pos: Lsn, + pub(crate) term: Option, + pub(crate) end_watch: EndWatch, + pub(crate) wal_sender_guard: Arc, +} + +impl WalReaderStreamBuilder { + pub(crate) fn start_pos(&self) -> Lsn { + self.start_pos + } +} + +pub(crate) struct WalBytes { + /// Raw PG WAL + pub(crate) wal: Bytes, + /// Start LSN of [`Self::wal`] + #[allow(dead_code)] + pub(crate) wal_start_lsn: Lsn, + /// End LSN of [`Self::wal`] + pub(crate) wal_end_lsn: Lsn, + /// End LSN of WAL available on the safekeeper. + /// + /// For pagservers this will be commit LSN, + /// while for the compute it will be the flush LSN. + pub(crate) available_wal_end_lsn: Lsn, +} + +impl WalReaderStreamBuilder { + /// Builds a stream of Postgres WAL starting from [`Self::start_pos`]. + /// The stream terminates when the receiver (pageserver) is fully caught up + /// and there's no active computes. + pub(crate) async fn build( + self, + buffer_size: usize, + ) -> anyhow::Result>> { + // TODO(vlad): The code below duplicates functionality from [`crate::send_wal`]. + // We can make the raw WAL sender use this stream too and remove the duplication. + let Self { + tli, + mut start_pos, + mut end_pos, + term, + mut end_watch, + wal_sender_guard, + } = self; + let mut wal_reader = tli.get_walreader(start_pos).await?; + let mut buffer = vec![0; buffer_size]; + + const POLL_STATE_TIMEOUT: Duration = Duration::from_secs(1); + + Ok(try_stream! { + loop { + let have_something_to_send = end_pos > start_pos; + + if !have_something_to_send { + // wait for lsn + let res = timeout(POLL_STATE_TIMEOUT, end_watch.wait_for_lsn(start_pos, term)).await; + match res { + Ok(ok) => { + end_pos = ok?; + }, + Err(_) => { + if let EndWatch::Commit(_) = end_watch { + if let Some(remote_consistent_lsn) = wal_sender_guard + .walsenders() + .get_ws_remote_consistent_lsn(wal_sender_guard.id()) + { + if tli.should_walsender_stop(remote_consistent_lsn).await { + // Stop streaming if the receivers are caught up and + // there's no active compute. This causes the loop in + // [`crate::send_interpreted_wal::InterpretedWalSender::run`] + // to exit and terminate the WAL stream. + return; + } + } + } + + continue; + } + } + } + + + assert!( + end_pos > start_pos, + "nothing to send after waiting for WAL" + ); + + // try to send as much as available, capped by the buffer size + let mut chunk_end_pos = start_pos + buffer_size as u64; + // if we went behind available WAL, back off + if chunk_end_pos >= end_pos { + chunk_end_pos = end_pos; + } else { + // If sending not up to end pos, round down to page boundary to + // avoid breaking WAL record not at page boundary, as protocol + // demands. See walsender.c (XLogSendPhysical). + chunk_end_pos = chunk_end_pos + .checked_sub(chunk_end_pos.block_offset()) + .unwrap(); + } + let send_size = (chunk_end_pos.0 - start_pos.0) as usize; + let buffer = &mut buffer[..send_size]; + let send_size: usize; + { + // If uncommitted part is being pulled, check that the term is + // still the expected one. + let _term_guard = if let Some(t) = term { + Some(tli.acquire_term(t).await?) + } else { + None + }; + // Read WAL into buffer. send_size can be additionally capped to + // segment boundary here. + send_size = wal_reader.read(buffer).await? + }; + let wal = Bytes::copy_from_slice(&buffer[..send_size]); + + yield WalBytes { + wal, + wal_start_lsn: start_pos, + wal_end_lsn: start_pos + send_size as u64, + available_wal_end_lsn: end_pos + }; + + start_pos += send_size as u64; + } + }) + } +} diff --git a/safekeeper/src/wal_service.rs b/safekeeper/src/wal_service.rs index 1ab54d4cce..5248d545db 100644 --- a/safekeeper/src/wal_service.rs +++ b/safekeeper/src/wal_service.rs @@ -44,7 +44,7 @@ pub async fn task_main( error!("connection handler exited: {}", err); } } - .instrument(info_span!("", cid = %conn_id, ttid = field::Empty, application_name = field::Empty)), + .instrument(info_span!("", cid = %conn_id, ttid = field::Empty, application_name = field::Empty, shard = field::Empty)), ); } } diff --git a/scripts/flaky_tests.py b/scripts/flaky_tests.py deleted file mode 100755 index 9312f8b3e7..0000000000 --- a/scripts/flaky_tests.py +++ /dev/null @@ -1,147 +0,0 @@ -#! /usr/bin/env python3 - -from __future__ import annotations - -import argparse -import json -import logging -import os -from collections import defaultdict -from typing import TYPE_CHECKING - -import psycopg2 -import psycopg2.extras -import toml - -if TYPE_CHECKING: - from typing import Any, Optional - -FLAKY_TESTS_QUERY = """ - SELECT - DISTINCT parent_suite, suite, name - FROM results - WHERE - started_at > CURRENT_DATE - INTERVAL '%s' day - AND ( - (status IN ('failed', 'broken') AND reference = 'refs/heads/main') - OR flaky - ) - ; -""" - - -def main(args: argparse.Namespace): - connstr = args.connstr - interval_days = args.days - output = args.output - - build_type = args.build_type - pg_version = args.pg_version - - res: defaultdict[str, defaultdict[str, dict[str, bool]]] - res = defaultdict(lambda: defaultdict(dict)) - - try: - logging.info("connecting to the database...") - with psycopg2.connect(connstr, connect_timeout=30) as conn: - with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur: - logging.info("fetching flaky tests...") - cur.execute(FLAKY_TESTS_QUERY, (interval_days,)) - rows = cur.fetchall() - except psycopg2.OperationalError as exc: - logging.error("cannot fetch flaky tests from the DB due to an error", exc) - rows = [] - - # If a test run has non-default PAGESERVER_VIRTUAL_FILE_IO_ENGINE (i.e. not empty, not tokio-epoll-uring), - # use it to parametrize test name along with build_type and pg_version - # - # See test_runner/fixtures/parametrize.py for details - if (io_engine := os.getenv("PAGESERVER_VIRTUAL_FILE_IO_ENGINE", "")) not in ( - "", - "tokio-epoll-uring", - ): - pageserver_virtual_file_io_engine_parameter = f"-{io_engine}" - else: - pageserver_virtual_file_io_engine_parameter = "" - - # re-use existing records of flaky tests from before parametrization by compaction_algorithm - def get_pageserver_default_tenant_config_compaction_algorithm() -> Optional[dict[str, Any]]: - """Duplicated from parametrize.py""" - toml_table = os.getenv("PAGESERVER_DEFAULT_TENANT_CONFIG_COMPACTION_ALGORITHM") - if toml_table is None: - return None - v = toml.loads(toml_table) - assert isinstance(v, dict) - return v - - pageserver_default_tenant_config_compaction_algorithm_parameter = "" - if ( - explicit_default := get_pageserver_default_tenant_config_compaction_algorithm() - ) is not None: - pageserver_default_tenant_config_compaction_algorithm_parameter = ( - f"-{explicit_default['kind']}" - ) - - for row in rows: - # We don't want to automatically rerun tests in a performance suite - if row["parent_suite"] != "test_runner.regress": - continue - - if row["name"].endswith("]"): - parametrized_test = row["name"].replace( - "[", - f"[{build_type}-pg{pg_version}{pageserver_virtual_file_io_engine_parameter}{pageserver_default_tenant_config_compaction_algorithm_parameter}-", - ) - else: - parametrized_test = f"{row['name']}[{build_type}-pg{pg_version}{pageserver_virtual_file_io_engine_parameter}{pageserver_default_tenant_config_compaction_algorithm_parameter}]" - - res[row["parent_suite"]][row["suite"]][parametrized_test] = True - - logging.info( - f"\t{row['parent_suite'].replace('.', '/')}/{row['suite']}.py::{parametrized_test}" - ) - - logging.info(f"saving results to {output.name}") - json.dump(res, output, indent=2) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Detect flaky tests in the last N days") - parser.add_argument( - "--output", - type=argparse.FileType("w"), - default="flaky.json", - help="path to output json file (default: flaky.json)", - ) - parser.add_argument( - "--days", - required=False, - default=10, - type=int, - help="how many days to look back for flaky tests (default: 10)", - ) - parser.add_argument( - "--build-type", - required=True, - type=str, - help="for which build type to create list of flaky tests (debug or release)", - ) - parser.add_argument( - "--pg-version", - required=True, - type=int, - help="for which Postgres version to create list of flaky tests (14, 15, etc.)", - ) - parser.add_argument( - "connstr", - help="connection string to the test results database", - ) - args = parser.parse_args() - - level = logging.INFO - logging.basicConfig( - format="%(message)s", - level=level, - ) - - main(args) diff --git a/scripts/force_layer_download.py b/scripts/force_layer_download.py index a4fd3f6132..835e28c5d6 100644 --- a/scripts/force_layer_download.py +++ b/scripts/force_layer_download.py @@ -194,9 +194,11 @@ async def main_impl(args, report_out, client: Client): tenant_ids = await client.get_tenant_ids() get_timeline_id_coros = [client.get_timeline_ids(tenant_id) for tenant_id in tenant_ids] gathered = await asyncio.gather(*get_timeline_id_coros, return_exceptions=True) - assert len(tenant_ids) == len(gathered) tenant_and_timline_ids = [] - for tid, tlids in zip(tenant_ids, gathered): + for tid, tlids in zip(tenant_ids, gathered, strict=True): + # TODO: add error handling if tlids isinstance(Exception) + assert isinstance(tlids, list) + for tlid in tlids: tenant_and_timline_ids.append((tid, tlid)) elif len(comps) == 1: diff --git a/scripts/ingest_regress_test_result-new-format.py b/scripts/ingest_regress_test_result-new-format.py index e0dd0a7189..064c516718 100644 --- a/scripts/ingest_regress_test_result-new-format.py +++ b/scripts/ingest_regress_test_result-new-format.py @@ -11,7 +11,7 @@ import re import sys from contextlib import contextmanager from dataclasses import dataclass -from datetime import datetime, timezone +from datetime import UTC, datetime from pathlib import Path import backoff @@ -31,6 +31,7 @@ CREATE TABLE IF NOT EXISTS results ( duration INT NOT NULL, flaky BOOLEAN NOT NULL, arch arch DEFAULT 'X64', + lfc BOOLEAN DEFAULT false NOT NULL, build_type TEXT NOT NULL, pg_version INT NOT NULL, run_id BIGINT NOT NULL, @@ -54,6 +55,7 @@ class Row: duration: int flaky: bool arch: str + lfc: bool build_type: str pg_version: int run_id: int @@ -132,6 +134,7 @@ def ingest_test_result( if p["name"].startswith("__") } arch = parameters.get("arch", "UNKNOWN").strip("'") + lfc = parameters.get("lfc", "False") == "True" build_type, pg_version, unparametrized_name = parse_test_name(test["name"]) labels = {label["name"]: label["value"] for label in test["labels"]} @@ -140,11 +143,12 @@ def ingest_test_result( suite=labels["suite"], name=unparametrized_name, status=test["status"], - started_at=datetime.fromtimestamp(test["time"]["start"] / 1000, tz=timezone.utc), - stopped_at=datetime.fromtimestamp(test["time"]["stop"] / 1000, tz=timezone.utc), + started_at=datetime.fromtimestamp(test["time"]["start"] / 1000, tz=UTC), + stopped_at=datetime.fromtimestamp(test["time"]["stop"] / 1000, tz=UTC), duration=test["time"]["duration"], flaky=test["flaky"] or test["retriesStatusChange"], arch=arch, + lfc=lfc, build_type=build_type, pg_version=pg_version, run_id=run_id, diff --git a/storage_scrubber/src/checks.rs b/storage_scrubber/src/checks.rs index 525f412b56..8d855d263c 100644 --- a/storage_scrubber/src/checks.rs +++ b/storage_scrubber/src/checks.rs @@ -128,7 +128,7 @@ pub(crate) async fn branch_cleanup_and_check_errors( let layer_names = index_part.layer_metadata.keys().cloned().collect_vec(); if let Some(err) = check_valid_layermap(&layer_names) { - result.errors.push(format!( + result.warnings.push(format!( "index_part.json contains invalid layer map structure: {err}" )); } diff --git a/test_runner/README.md b/test_runner/README.md index 55d8d2faa9..f342ef8aaa 100644 --- a/test_runner/README.md +++ b/test_runner/README.md @@ -113,7 +113,7 @@ The test suite has a Python enum with equal name but different meaning: ```python @enum.unique -class RemoteStorageKind(str, enum.Enum): +class RemoteStorageKind(StrEnum): LOCAL_FS = "local_fs" MOCK_S3 = "mock_s3" REAL_S3 = "real_s3" diff --git a/test_runner/conftest.py b/test_runner/conftest.py index 84eda52d33..887bfef478 100644 --- a/test_runner/conftest.py +++ b/test_runner/conftest.py @@ -13,5 +13,5 @@ pytest_plugins = ( "fixtures.pg_stats", "fixtures.compare_fixtures", "fixtures.slow", - "fixtures.flaky", + "fixtures.reruns", ) diff --git a/test_runner/fixtures/auth_tokens.py b/test_runner/fixtures/auth_tokens.py index be16be81de..8382ce20b3 100644 --- a/test_runner/fixtures/auth_tokens.py +++ b/test_runner/fixtures/auth_tokens.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from enum import Enum +from enum import StrEnum from typing import Any import jwt @@ -37,8 +37,7 @@ class AuthKeys: return self.generate_token(scope=TokenScope.TENANT, tenant_id=str(tenant_id)) -# TODO: Replace with `StrEnum` when we upgrade to python 3.11 -class TokenScope(str, Enum): +class TokenScope(StrEnum): ADMIN = "admin" PAGE_SERVER_API = "pageserverapi" GENERATIONS_API = "generations_api" diff --git a/test_runner/fixtures/benchmark_fixture.py b/test_runner/fixtures/benchmark_fixture.py index 8e68775471..bb8e75902e 100644 --- a/test_runner/fixtures/benchmark_fixture.py +++ b/test_runner/fixtures/benchmark_fixture.py @@ -9,6 +9,7 @@ import re import timeit from contextlib import contextmanager from datetime import datetime +from enum import StrEnum from pathlib import Path from typing import TYPE_CHECKING @@ -24,8 +25,7 @@ from fixtures.log_helper import log from fixtures.neon_fixtures import NeonPageserver if TYPE_CHECKING: - from collections.abc import Iterator, Mapping - from typing import Callable, Optional + from collections.abc import Callable, Iterator, Mapping """ @@ -61,7 +61,7 @@ class PgBenchRunResult: number_of_threads: int number_of_transactions_actually_processed: int latency_average: float - latency_stddev: Optional[float] + latency_stddev: float | None tps: float run_duration: float run_start_timestamp: int @@ -171,14 +171,14 @@ _PGBENCH_INIT_EXTRACTORS: Mapping[str, re.Pattern[str]] = { @dataclasses.dataclass class PgBenchInitResult: - total: Optional[float] - drop_tables: Optional[float] - create_tables: Optional[float] - client_side_generate: Optional[float] - server_side_generate: Optional[float] - vacuum: Optional[float] - primary_keys: Optional[float] - foreign_keys: Optional[float] + total: float | None + drop_tables: float | None + create_tables: float | None + client_side_generate: float | None + server_side_generate: float | None + vacuum: float | None + primary_keys: float | None + foreign_keys: float | None duration: float start_timestamp: int end_timestamp: int @@ -196,7 +196,7 @@ class PgBenchInitResult: last_line = stderr.splitlines()[-1] - timings: dict[str, Optional[float]] = {} + timings: dict[str, float | None] = {} last_line_items = re.split(r"\(|\)|,", last_line) for item in last_line_items: for key, regex in _PGBENCH_INIT_EXTRACTORS.items(): @@ -227,7 +227,7 @@ class PgBenchInitResult: @enum.unique -class MetricReport(str, enum.Enum): # str is a hack to make it json serializable +class MetricReport(StrEnum): # str is a hack to make it json serializable # this means that this is a constant test parameter # like number of transactions, or number of clients TEST_PARAM = "test_param" @@ -256,9 +256,8 @@ class NeonBenchmarker: metric_value: float, unit: str, report: MetricReport, - labels: Optional[ - dict[str, str] - ] = None, # use this to associate additional key/value pairs in json format for associated Neon object IDs like project ID with the metric + # use this to associate additional key/value pairs in json format for associated Neon object IDs like project ID with the metric + labels: dict[str, str] | None = None, ): """ Record a benchmark result. @@ -412,7 +411,7 @@ class NeonBenchmarker: self, pageserver: NeonPageserver, metric_name: str, - label_filters: Optional[dict[str, str]] = None, + label_filters: dict[str, str] | None = None, ) -> int: """Fetch the value of given int counter from pageserver metrics.""" all_metrics = pageserver.http_client().get_metrics() diff --git a/test_runner/fixtures/common_types.py b/test_runner/fixtures/common_types.py index 0ea7148f50..6c22b31e00 100644 --- a/test_runner/fixtures/common_types.py +++ b/test_runner/fixtures/common_types.py @@ -2,14 +2,14 @@ from __future__ import annotations import random from dataclasses import dataclass -from enum import Enum +from enum import StrEnum from functools import total_ordering from typing import TYPE_CHECKING, TypeVar from typing_extensions import override if TYPE_CHECKING: - from typing import Any, Union + from typing import Any T = TypeVar("T", bound="Id") @@ -24,7 +24,7 @@ class Lsn: representation is like "1/0123abcd". See also pg_lsn datatype in Postgres """ - def __init__(self, x: Union[int, str]): + def __init__(self, x: int | str): if isinstance(x, int): self.lsn_int = x else: @@ -67,7 +67,7 @@ class Lsn: return NotImplemented return self.lsn_int - other.lsn_int - def __add__(self, other: Union[int, Lsn]) -> Lsn: + def __add__(self, other: int | Lsn) -> Lsn: if isinstance(other, int): return Lsn(self.lsn_int + other) elif isinstance(other, Lsn): @@ -190,8 +190,23 @@ class TenantTimelineId: ) -# Workaround for compat with python 3.9, which does not have `typing.Self` -TTenantShardId = TypeVar("TTenantShardId", bound="TenantShardId") +@dataclass +class ShardIndex: + shard_number: int + shard_count: int + + # cf impl Display for ShardIndex + @override + def __str__(self) -> str: + return f"{self.shard_number:02x}{self.shard_count:02x}" + + @classmethod + def parse(cls: type[ShardIndex], input: str) -> ShardIndex: + assert len(input) == 4 + return cls( + shard_number=int(input[0:2], 16), + shard_count=int(input[2:4], 16), + ) class TenantShardId: @@ -202,7 +217,7 @@ class TenantShardId: assert self.shard_number < self.shard_count or self.shard_count == 0 @classmethod - def parse(cls: type[TTenantShardId], input: str) -> TTenantShardId: + def parse(cls: type[TenantShardId], input: str) -> TenantShardId: if len(input) == 32: return cls( tenant_id=TenantId(input), @@ -226,6 +241,10 @@ class TenantShardId: # Unsharded case: equivalent of Rust TenantShardId::unsharded(tenant_id) return str(self.tenant_id) + @property + def shard_index(self) -> ShardIndex: + return ShardIndex(self.shard_number, self.shard_count) + @override def __repr__(self): return self.__str__() @@ -249,7 +268,6 @@ class TenantShardId: return hash(self._tuple()) -# TODO: Replace with `StrEnum` when we upgrade to python 3.11 -class TimelineArchivalState(str, Enum): +class TimelineArchivalState(StrEnum): ARCHIVED = "Archived" UNARCHIVED = "Unarchived" diff --git a/test_runner/fixtures/compare_fixtures.py b/test_runner/fixtures/compare_fixtures.py index 85b6e7a3b8..c0892399bd 100644 --- a/test_runner/fixtures/compare_fixtures.py +++ b/test_runner/fixtures/compare_fixtures.py @@ -99,7 +99,7 @@ class PgCompare(ABC): assert row is not None assert len(row) == len(pg_stat.columns) - for col, val in zip(pg_stat.columns, row): + for col, val in zip(pg_stat.columns, row, strict=False): results[f"{pg_stat.table}.{col}"] = int(val) return results diff --git a/test_runner/fixtures/compute_reconfigure.py b/test_runner/fixtures/compute_reconfigure.py index 6354b7f833..33f01f80fb 100644 --- a/test_runner/fixtures/compute_reconfigure.py +++ b/test_runner/fixtures/compute_reconfigure.py @@ -12,7 +12,8 @@ from fixtures.common_types import TenantId from fixtures.log_helper import log if TYPE_CHECKING: - from typing import Any, Callable, Optional + from collections.abc import Callable + from typing import Any class ComputeReconfigure: @@ -20,12 +21,12 @@ class ComputeReconfigure: self.server = server self.control_plane_compute_hook_api = f"http://{server.host}:{server.port}/notify-attach" self.workloads: dict[TenantId, Any] = {} - self.on_notify: Optional[Callable[[Any], None]] = None + self.on_notify: Callable[[Any], None] | None = None def register_workload(self, workload: Any): self.workloads[workload.tenant_id] = workload - def register_on_notify(self, fn: Optional[Callable[[Any], None]]): + def register_on_notify(self, fn: Callable[[Any], None] | None): """ Add some extra work during a notification, like sleeping to slow things down, or logging what was notified. @@ -68,7 +69,7 @@ def compute_reconfigure_listener(make_httpserver: HTTPServer): # This causes the endpoint to query storage controller for its location, which # is redundant since we already have it here, but this avoids extending the # neon_local CLI to take full lists of locations - reconfigure_threads.submit(lambda workload=workload: workload.reconfigure()) # type: ignore[no-any-return] + reconfigure_threads.submit(lambda workload=workload: workload.reconfigure()) # type: ignore[misc] return Response(status=200) diff --git a/test_runner/fixtures/endpoint/http.py b/test_runner/fixtures/endpoint/http.py index db3723b7cc..1cd9158c68 100644 --- a/test_runner/fixtures/endpoint/http.py +++ b/test_runner/fixtures/endpoint/http.py @@ -1,5 +1,7 @@ from __future__ import annotations +import urllib.parse + import requests from requests.adapters import HTTPAdapter @@ -20,7 +22,9 @@ class EndpointHttpClient(requests.Session): return res.json() def database_schema(self, database: str): - res = self.get(f"http://localhost:{self.port}/database_schema?database={database}") + res = self.get( + f"http://localhost:{self.port}/database_schema?database={urllib.parse.quote(database, safe='')}" + ) res.raise_for_status() return res.text diff --git a/test_runner/fixtures/flaky.py b/test_runner/fixtures/flaky.py deleted file mode 100644 index 01634a29c5..0000000000 --- a/test_runner/fixtures/flaky.py +++ /dev/null @@ -1,78 +0,0 @@ -from __future__ import annotations - -import json -from collections.abc import MutableMapping -from pathlib import Path -from typing import TYPE_CHECKING, cast - -import pytest -from _pytest.config import Config -from _pytest.config.argparsing import Parser -from allure_commons.types import LabelType -from allure_pytest.utils import allure_name, allure_suite_labels - -from fixtures.log_helper import log - -if TYPE_CHECKING: - from collections.abc import MutableMapping - from typing import Any - - -""" -The plugin reruns flaky tests. -It uses `pytest.mark.flaky` provided by `pytest-rerunfailures` plugin and flaky tests detected by `scripts/flaky_tests.py` - -Note: the logic of getting flaky tests is extracted to a separate script to avoid running it for each of N xdist workers -""" - - -def pytest_addoption(parser: Parser): - parser.addoption( - "--flaky-tests-json", - action="store", - type=Path, - help="Path to json file with flaky tests generated by scripts/flaky_tests.py", - ) - - -def pytest_collection_modifyitems(config: Config, items: list[pytest.Item]): - if not config.getoption("--flaky-tests-json"): - return - - # Any error with getting flaky tests aren't critical, so just do not rerun any tests - flaky_json = config.getoption("--flaky-tests-json") - if not flaky_json.exists(): - return - - content = flaky_json.read_text() - try: - flaky_tests = json.loads(content) - except ValueError: - log.error(f"Can't parse {content} as json") - return - - for item in items: - # Use the same logic for constructing test name as Allure does (we store allure-provided data in DB) - # Ref https://github.com/allure-framework/allure-python/blob/2.13.1/allure-pytest/src/listener.py#L98-L100 - allure_labels = dict(allure_suite_labels(item)) - parent_suite = str(allure_labels.get(LabelType.PARENT_SUITE)) - suite = str(allure_labels.get(LabelType.SUITE)) - params = item.callspec.params if hasattr(item, "callspec") else {} - name = allure_name(item, params) - - if flaky_tests.get(parent_suite, {}).get(suite, {}).get(name, False): - # Rerun 3 times = 1 original run + 2 reruns - log.info(f"Marking {item.nodeid} as flaky. It will be rerun up to 3 times") - item.add_marker(pytest.mark.flaky(reruns=2)) - - # pytest-rerunfailures is not compatible with pytest-timeout (timeout is not set for reruns), - # we can workaround it by setting `timeout_func_only` to True[1]. - # Unfortunately, setting `timeout_func_only = True` globally in pytest.ini is broken[2], - # but we still can do it using pytest marker. - # - # - [1] https://github.com/pytest-dev/pytest-rerunfailures/issues/99 - # - [2] https://github.com/pytest-dev/pytest-timeout/issues/142 - timeout_marker = item.get_closest_marker("timeout") - if timeout_marker is not None: - kwargs = cast("MutableMapping[str, Any]", timeout_marker.kwargs) - kwargs["func_only"] = True diff --git a/test_runner/fixtures/h2server.py b/test_runner/fixtures/h2server.py index e890b2bcf1..3e35af3b5b 100644 --- a/test_runner/fixtures/h2server.py +++ b/test_runner/fixtures/h2server.py @@ -31,7 +31,7 @@ from h2.settings import SettingCodes from typing_extensions import override if TYPE_CHECKING: - from typing import Any, Optional + from typing import Any RequestData = collections.namedtuple("RequestData", ["headers", "data"]) @@ -49,7 +49,7 @@ class H2Protocol(asyncio.Protocol): def __init__(self): config = H2Configuration(client_side=False, header_encoding="utf-8") self.conn = H2Connection(config=config) - self.transport: Optional[asyncio.Transport] = None + self.transport: asyncio.Transport | None = None self.stream_data: dict[int, RequestData] = {} self.flow_control_futures: dict[int, asyncio.Future[Any]] = {} @@ -61,7 +61,7 @@ class H2Protocol(asyncio.Protocol): self.transport.write(self.conn.data_to_send()) @override - def connection_lost(self, exc: Optional[Exception]): + def connection_lost(self, exc: Exception | None): for future in self.flow_control_futures.values(): future.cancel() self.flow_control_futures = {} diff --git a/test_runner/fixtures/metrics.py b/test_runner/fixtures/metrics.py index 39c8f70a9c..3f90c233a6 100644 --- a/test_runner/fixtures/metrics.py +++ b/test_runner/fixtures/metrics.py @@ -1,16 +1,12 @@ from __future__ import annotations from collections import defaultdict -from typing import TYPE_CHECKING from prometheus_client.parser import text_string_to_metric_families from prometheus_client.samples import Sample from fixtures.log_helper import log -if TYPE_CHECKING: - from typing import Optional - class Metrics: metrics: dict[str, list[Sample]] @@ -20,7 +16,7 @@ class Metrics: self.metrics = defaultdict(list) self.name = name - def query_all(self, name: str, filter: Optional[dict[str, str]] = None) -> list[Sample]: + def query_all(self, name: str, filter: dict[str, str] | None = None) -> list[Sample]: filter = filter or {} res: list[Sample] = [] @@ -32,7 +28,7 @@ class Metrics: pass return res - def query_one(self, name: str, filter: Optional[dict[str, str]] = None) -> Sample: + def query_one(self, name: str, filter: dict[str, str] | None = None) -> Sample: res = self.query_all(name, filter or {}) assert len(res) == 1, f"expected single sample for {name} {filter}, found {res}" return res[0] @@ -47,9 +43,7 @@ class MetricsGetter: def get_metrics(self) -> Metrics: raise NotImplementedError() - def get_metric_value( - self, name: str, filter: Optional[dict[str, str]] = None - ) -> Optional[float]: + def get_metric_value(self, name: str, filter: dict[str, str] | None = None) -> float | None: metrics = self.get_metrics() results = metrics.query_all(name, filter=filter) if not results: @@ -59,7 +53,7 @@ class MetricsGetter: return results[0].value def get_metrics_values( - self, names: list[str], filter: Optional[dict[str, str]] = None, absence_ok: bool = False + self, names: list[str], filter: dict[str, str] | None = None, absence_ok: bool = False ) -> dict[str, float]: """ When fetching multiple named metrics, it is more efficient to use this @@ -174,6 +168,7 @@ PAGESERVER_PER_TENANT_METRICS: tuple[str, ...] = ( "pageserver_evictions_with_low_residence_duration_total", "pageserver_aux_file_estimated_size", "pageserver_valid_lsn_lease_count", + "pageserver_flush_wait_upload_seconds", counter("pageserver_tenant_throttling_count_accounted_start"), counter("pageserver_tenant_throttling_count_accounted_finish"), counter("pageserver_tenant_throttling_wait_usecs_sum"), diff --git a/test_runner/fixtures/neon_api.py b/test_runner/fixtures/neon_api.py index 9de6681beb..df80f0683c 100644 --- a/test_runner/fixtures/neon_api.py +++ b/test_runner/fixtures/neon_api.py @@ -8,7 +8,7 @@ import requests from fixtures.log_helper import log if TYPE_CHECKING: - from typing import Any, Literal, Optional + from typing import Any, Literal from fixtures.pg_version import PgVersion @@ -40,11 +40,11 @@ class NeonAPI: def create_project( self, - pg_version: Optional[PgVersion] = None, - name: Optional[str] = None, - branch_name: Optional[str] = None, - branch_role_name: Optional[str] = None, - branch_database_name: Optional[str] = None, + pg_version: PgVersion | None = None, + name: str | None = None, + branch_name: str | None = None, + branch_role_name: str | None = None, + branch_database_name: str | None = None, ) -> dict[str, Any]: data: dict[str, Any] = { "project": { @@ -179,8 +179,8 @@ class NeonAPI: def get_connection_uri( self, project_id: str, - branch_id: Optional[str] = None, - endpoint_id: Optional[str] = None, + branch_id: str | None = None, + endpoint_id: str | None = None, database_name: str = "neondb", role_name: str = "neondb_owner", pooled: bool = True, @@ -249,7 +249,7 @@ class NeonAPI: @final class NeonApiEndpoint: - def __init__(self, neon_api: NeonAPI, pg_version: PgVersion, project_id: Optional[str]): + def __init__(self, neon_api: NeonAPI, pg_version: PgVersion, project_id: str | None): self.neon_api = neon_api self.project_id: str self.endpoint_id: str diff --git a/test_runner/fixtures/neon_cli.py b/test_runner/fixtures/neon_cli.py index d220ea57a2..a85a191455 100644 --- a/test_runner/fixtures/neon_cli.py +++ b/test_runner/fixtures/neon_cli.py @@ -20,13 +20,9 @@ from fixtures.pg_version import PgVersion if TYPE_CHECKING: from typing import ( Any, - Optional, - TypeVar, cast, ) - T = TypeVar("T") - # Used to be an ABC. abc.ABC removed due to linter without name change. class AbstractNeonCli: @@ -36,7 +32,7 @@ class AbstractNeonCli: Do not use directly, use specific subclasses instead. """ - def __init__(self, extra_env: Optional[dict[str, str]], binpath: Path): + def __init__(self, extra_env: dict[str, str] | None, binpath: Path): self.extra_env = extra_env self.binpath = binpath @@ -45,7 +41,7 @@ class AbstractNeonCli: def raw_cli( self, arguments: list[str], - extra_env_vars: Optional[dict[str, str]] = None, + extra_env_vars: dict[str, str] | None = None, check_return_code=True, timeout=None, ) -> subprocess.CompletedProcess[str]: @@ -173,7 +169,7 @@ class NeonLocalCli(AbstractNeonCli): def __init__( self, - extra_env: Optional[dict[str, str]], + extra_env: dict[str, str] | None, binpath: Path, repo_dir: Path, pg_distrib_dir: Path, @@ -195,10 +191,10 @@ class NeonLocalCli(AbstractNeonCli): tenant_id: TenantId, timeline_id: TimelineId, pg_version: PgVersion, - conf: Optional[dict[str, Any]] = None, - shard_count: Optional[int] = None, - shard_stripe_size: Optional[int] = None, - placement_policy: Optional[str] = None, + conf: dict[str, Any] | None = None, + shard_count: int | None = None, + shard_stripe_size: int | None = None, + placement_policy: str | None = None, set_default: bool = False, ): """ @@ -302,8 +298,8 @@ class NeonLocalCli(AbstractNeonCli): tenant_id: TenantId, timeline_id: TimelineId, new_branch_name, - ancestor_branch_name: Optional[str] = None, - ancestor_start_lsn: Optional[Lsn] = None, + ancestor_branch_name: str | None = None, + ancestor_start_lsn: Lsn | None = None, ): cmd = [ "timeline", @@ -331,8 +327,8 @@ class NeonLocalCli(AbstractNeonCli): base_lsn: Lsn, base_tarfile: Path, pg_version: PgVersion, - end_lsn: Optional[Lsn] = None, - wal_tarfile: Optional[Path] = None, + end_lsn: Lsn | None = None, + wal_tarfile: Path | None = None, ): cmd = [ "timeline", @@ -380,7 +376,7 @@ class NeonLocalCli(AbstractNeonCli): def init( self, init_config: dict[str, Any], - force: Optional[str] = None, + force: str | None = None, ) -> subprocess.CompletedProcess[str]: with tempfile.NamedTemporaryFile(mode="w+") as init_config_tmpfile: init_config_tmpfile.write(toml.dumps(init_config)) @@ -400,9 +396,9 @@ class NeonLocalCli(AbstractNeonCli): def storage_controller_start( self, - timeout_in_seconds: Optional[int] = None, - instance_id: Optional[int] = None, - base_port: Optional[int] = None, + timeout_in_seconds: int | None = None, + instance_id: int | None = None, + base_port: int | None = None, ): cmd = ["storage_controller", "start"] if timeout_in_seconds is not None: @@ -413,7 +409,7 @@ class NeonLocalCli(AbstractNeonCli): cmd.append(f"--base-port={base_port}") return self.raw_cli(cmd) - def storage_controller_stop(self, immediate: bool, instance_id: Optional[int] = None): + def storage_controller_stop(self, immediate: bool, instance_id: int | None = None): cmd = ["storage_controller", "stop"] if immediate: cmd.extend(["-m", "immediate"]) @@ -424,8 +420,8 @@ class NeonLocalCli(AbstractNeonCli): def pageserver_start( self, id: int, - extra_env_vars: Optional[dict[str, str]] = None, - timeout_in_seconds: Optional[int] = None, + extra_env_vars: dict[str, str] | None = None, + timeout_in_seconds: int | None = None, ) -> subprocess.CompletedProcess[str]: start_args = ["pageserver", "start", f"--id={id}"] if timeout_in_seconds is not None: @@ -442,9 +438,9 @@ class NeonLocalCli(AbstractNeonCli): def safekeeper_start( self, id: int, - extra_opts: Optional[list[str]] = None, - extra_env_vars: Optional[dict[str, str]] = None, - timeout_in_seconds: Optional[int] = None, + extra_opts: list[str] | None = None, + extra_env_vars: dict[str, str] | None = None, + timeout_in_seconds: int | None = None, ) -> subprocess.CompletedProcess[str]: if extra_opts is not None: extra_opts = [f"-e={opt}" for opt in extra_opts] @@ -457,7 +453,7 @@ class NeonLocalCli(AbstractNeonCli): ) def safekeeper_stop( - self, id: Optional[int] = None, immediate=False + self, id: int | None = None, immediate=False ) -> subprocess.CompletedProcess[str]: args = ["safekeeper", "stop"] if id is not None: @@ -467,7 +463,7 @@ class NeonLocalCli(AbstractNeonCli): return self.raw_cli(args) def storage_broker_start( - self, timeout_in_seconds: Optional[int] = None + self, timeout_in_seconds: int | None = None ) -> subprocess.CompletedProcess[str]: cmd = ["storage_broker", "start"] if timeout_in_seconds is not None: @@ -485,10 +481,10 @@ class NeonLocalCli(AbstractNeonCli): http_port: int, tenant_id: TenantId, pg_version: PgVersion, - endpoint_id: Optional[str] = None, + endpoint_id: str | None = None, hot_standby: bool = False, - lsn: Optional[Lsn] = None, - pageserver_id: Optional[int] = None, + lsn: Lsn | None = None, + pageserver_id: int | None = None, allow_multiple=False, ) -> subprocess.CompletedProcess[str]: args = [ @@ -523,11 +519,11 @@ class NeonLocalCli(AbstractNeonCli): def endpoint_start( self, endpoint_id: str, - safekeepers: Optional[list[int]] = None, - remote_ext_config: Optional[str] = None, - pageserver_id: Optional[int] = None, + safekeepers: list[int] | None = None, + remote_ext_config: str | None = None, + pageserver_id: int | None = None, allow_multiple=False, - basebackup_request_tries: Optional[int] = None, + basebackup_request_tries: int | None = None, ) -> subprocess.CompletedProcess[str]: args = [ "endpoint", @@ -555,9 +551,9 @@ class NeonLocalCli(AbstractNeonCli): def endpoint_reconfigure( self, endpoint_id: str, - tenant_id: Optional[TenantId] = None, - pageserver_id: Optional[int] = None, - safekeepers: Optional[list[int]] = None, + tenant_id: TenantId | None = None, + pageserver_id: int | None = None, + safekeepers: list[int] | None = None, check_return_code=True, ) -> subprocess.CompletedProcess[str]: args = ["endpoint", "reconfigure", endpoint_id] @@ -574,7 +570,7 @@ class NeonLocalCli(AbstractNeonCli): endpoint_id: str, destroy=False, check_return_code=True, - mode: Optional[str] = None, + mode: str | None = None, ) -> subprocess.CompletedProcess[str]: args = [ "endpoint", diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index 205a47a9d5..9bcfffeb9c 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -17,7 +17,7 @@ from collections.abc import Iterable, Iterator from contextlib import closing, contextmanager from dataclasses import dataclass from datetime import datetime -from enum import Enum +from enum import StrEnum from functools import cached_property from pathlib import Path from types import TracebackType @@ -90,10 +90,12 @@ from fixtures.safekeeper.utils import wait_walreceivers_absent from fixtures.utils import ( ATTACHMENT_NAME_REGEX, COMPONENT_BINARIES, + USE_LFC, allure_add_grafana_links, assert_no_errors, get_dir_size, print_gc_result, + size_to_bytes, subprocess_capture, wait_until, ) @@ -101,13 +103,8 @@ from fixtures.utils import ( from .neon_api import NeonAPI, NeonApiEndpoint if TYPE_CHECKING: - from typing import ( - Any, - Callable, - Optional, - TypeVar, - Union, - ) + from collections.abc import Callable + from typing import Any, Self, TypeVar from fixtures.paths import SnapshotDirLocked @@ -313,6 +310,31 @@ class PgProtocol: return self.safe_psql(query, log_query=log_query)[0][0] +class PageserverWalReceiverProtocol(StrEnum): + VANILLA = "vanilla" + INTERPRETED = "interpreted" + + @staticmethod + def to_config_key_value(proto) -> tuple[str, dict[str, Any]]: + if proto == PageserverWalReceiverProtocol.VANILLA: + return ( + "wal_receiver_protocol", + { + "type": "vanilla", + }, + ) + elif proto == PageserverWalReceiverProtocol.INTERPRETED: + return ( + "wal_receiver_protocol", + { + "type": "interpreted", + "args": {"format": "protobuf", "compression": {"zstd": {"level": 1}}}, + }, + ) + else: + raise ValueError(f"Unknown protocol type: {proto}") + + class NeonEnvBuilder: """ Builder object to create a Neon runtime environment @@ -338,10 +360,10 @@ class NeonEnvBuilder: top_output_dir: Path, test_output_dir: Path, combination, - test_overlay_dir: Optional[Path] = None, - pageserver_remote_storage: Optional[RemoteStorage] = None, + test_overlay_dir: Path | None = None, + pageserver_remote_storage: RemoteStorage | None = None, # toml that will be decomposed into `--config-override` flags during `pageserver --init` - pageserver_config_override: Optional[str | Callable[[dict[str, Any]], None]] = None, + pageserver_config_override: str | Callable[[dict[str, Any]], None] | None = None, num_safekeepers: int = 1, num_pageservers: int = 1, # Use non-standard SK ids to check for various parsing bugs @@ -349,16 +371,17 @@ class NeonEnvBuilder: # fsync is disabled by default to make the tests go faster safekeepers_enable_fsync: bool = False, auth_enabled: bool = False, - rust_log_override: Optional[str] = None, + rust_log_override: str | None = None, default_branch_name: str = DEFAULT_BRANCH_NAME, preserve_database_files: bool = False, - initial_tenant: Optional[TenantId] = None, - initial_timeline: Optional[TimelineId] = None, - pageserver_virtual_file_io_engine: Optional[str] = None, - pageserver_default_tenant_config_compaction_algorithm: Optional[dict[str, Any]] = None, - safekeeper_extra_opts: Optional[list[str]] = None, - storage_controller_port_override: Optional[int] = None, - pageserver_virtual_file_io_mode: Optional[str] = None, + initial_tenant: TenantId | None = None, + initial_timeline: TimelineId | None = None, + pageserver_virtual_file_io_engine: str | None = None, + pageserver_default_tenant_config_compaction_algorithm: dict[str, Any] | None = None, + safekeeper_extra_opts: list[str] | None = None, + storage_controller_port_override: int | None = None, + pageserver_virtual_file_io_mode: str | None = None, + pageserver_wal_receiver_protocol: PageserverWalReceiverProtocol | None = None, ): self.repo_dir = repo_dir self.rust_log_override = rust_log_override @@ -367,7 +390,7 @@ class NeonEnvBuilder: # Pageserver remote storage self.pageserver_remote_storage = pageserver_remote_storage # Safekeepers remote storage - self.safekeepers_remote_storage: Optional[RemoteStorage] = None + self.safekeepers_remote_storage: RemoteStorage | None = None self.run_id = run_id self.mock_s3_server: MockS3Server = mock_s3_server @@ -378,7 +401,7 @@ class NeonEnvBuilder: self.safekeepers_enable_fsync = safekeepers_enable_fsync self.auth_enabled = auth_enabled self.default_branch_name = default_branch_name - self.env: Optional[NeonEnv] = None + self.env: NeonEnv | None = None self.keep_remote_storage_contents: bool = True self.neon_binpath = neon_binpath self.neon_local_binpath = neon_binpath @@ -391,14 +414,14 @@ class NeonEnvBuilder: self.test_output_dir = test_output_dir self.test_overlay_dir = test_overlay_dir self.overlay_mounts_created_by_us: list[tuple[str, Path]] = [] - self.config_init_force: Optional[str] = None + self.config_init_force: str | None = None self.top_output_dir = top_output_dir - self.control_plane_compute_hook_api: Optional[str] = None - self.storage_controller_config: Optional[dict[Any, Any]] = None + self.control_plane_compute_hook_api: str | None = None + self.storage_controller_config: dict[Any, Any] | None = None - self.pageserver_virtual_file_io_engine: Optional[str] = pageserver_virtual_file_io_engine + self.pageserver_virtual_file_io_engine: str | None = pageserver_virtual_file_io_engine - self.pageserver_default_tenant_config_compaction_algorithm: Optional[dict[str, Any]] = ( + self.pageserver_default_tenant_config_compaction_algorithm: dict[str, Any] | None = ( pageserver_default_tenant_config_compaction_algorithm ) if self.pageserver_default_tenant_config_compaction_algorithm is not None: @@ -412,6 +435,8 @@ class NeonEnvBuilder: self.pageserver_virtual_file_io_mode = pageserver_virtual_file_io_mode + self.pageserver_wal_receiver_protocol = pageserver_wal_receiver_protocol + assert test_name.startswith( "test_" ), "Unexpectedly instantiated from outside a test function" @@ -440,10 +465,10 @@ class NeonEnvBuilder: def init_start( self, - initial_tenant_conf: Optional[dict[str, Any]] = None, + initial_tenant_conf: dict[str, Any] | None = None, default_remote_storage_if_missing: bool = True, - initial_tenant_shard_count: Optional[int] = None, - initial_tenant_shard_stripe_size: Optional[int] = None, + initial_tenant_shard_count: int | None = None, + initial_tenant_shard_stripe_size: int | None = None, ) -> NeonEnv: """ Default way to create and start NeonEnv. Also creates the initial_tenant with root initial_timeline. @@ -781,8 +806,8 @@ class NeonEnvBuilder: self, kind: RemoteStorageKind, user: RemoteStorageUser, - bucket_name: Optional[str] = None, - bucket_region: Optional[str] = None, + bucket_name: str | None = None, + bucket_region: str | None = None, ) -> RemoteStorage: ret = kind.configure( self.repo_dir, @@ -840,14 +865,14 @@ class NeonEnvBuilder: if isinstance(x, S3Storage): x.do_cleanup() - def __enter__(self) -> NeonEnvBuilder: + def __enter__(self) -> Self: return self def __exit__( self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ): # Stop all the nodes. if self.env: @@ -1026,6 +1051,7 @@ class NeonEnv: self.pageserver_virtual_file_io_engine = config.pageserver_virtual_file_io_engine self.pageserver_virtual_file_io_mode = config.pageserver_virtual_file_io_mode + self.pageserver_wal_receiver_protocol = config.pageserver_wal_receiver_protocol # Create the neon_local's `NeonLocalInitConf` cfg: dict[str, Any] = { @@ -1095,6 +1121,13 @@ class NeonEnv: if self.pageserver_virtual_file_io_mode is not None: ps_cfg["virtual_file_io_mode"] = self.pageserver_virtual_file_io_mode + if self.pageserver_wal_receiver_protocol is not None: + key, value = PageserverWalReceiverProtocol.to_config_key_value( + self.pageserver_wal_receiver_protocol + ) + if key not in ps_cfg: + ps_cfg[key] = value + # Create a corresponding NeonPageserver object self.pageservers.append( NeonPageserver(self, ps_id, port=pageserver_port, az_id=ps_cfg["availability_zone"]) @@ -1136,7 +1169,7 @@ class NeonEnv: force=config.config_init_force, ) - def start(self, timeout_in_seconds: Optional[int] = None): + def start(self, timeout_in_seconds: int | None = None): # Storage controller starts first, so that pageserver /re-attach calls don't # bounce through retries on startup self.storage_controller.start(timeout_in_seconds=timeout_in_seconds) @@ -1150,21 +1183,19 @@ class NeonEnv: with concurrent.futures.ThreadPoolExecutor( max_workers=2 + len(self.pageservers) + len(self.safekeepers) ) as executor: - futs.append( - executor.submit(lambda: self.broker.start() or None) - ) # The `or None` is for the linter + futs.append(executor.submit(lambda: self.broker.start())) for pageserver in self.pageservers: futs.append( executor.submit( - lambda ps=pageserver: ps.start(timeout_in_seconds=timeout_in_seconds) + lambda ps=pageserver: ps.start(timeout_in_seconds=timeout_in_seconds) # type: ignore[misc] ) ) for safekeeper in self.safekeepers: futs.append( executor.submit( - lambda sk=safekeeper: sk.start(timeout_in_seconds=timeout_in_seconds) + lambda sk=safekeeper: sk.start(timeout_in_seconds=timeout_in_seconds) # type: ignore[misc] ) ) @@ -1237,7 +1268,7 @@ class NeonEnv: ), "env.pageserver must only be used with single pageserver NeonEnv" return self.pageservers[0] - def get_pageserver(self, id: Optional[int]) -> NeonPageserver: + def get_pageserver(self, id: int | None) -> NeonPageserver: """ Look up a pageserver by its node ID. @@ -1254,7 +1285,7 @@ class NeonEnv: raise RuntimeError(f"Pageserver with ID {id} not found") - def get_tenant_pageserver(self, tenant_id: Union[TenantId, TenantShardId]): + def get_tenant_pageserver(self, tenant_id: TenantId | TenantShardId): """ Get the NeonPageserver where this tenant shard is currently attached, according to the storage controller. @@ -1316,12 +1347,12 @@ class NeonEnv: def create_tenant( self, - tenant_id: Optional[TenantId] = None, - timeline_id: Optional[TimelineId] = None, - conf: Optional[dict[str, Any]] = None, - shard_count: Optional[int] = None, - shard_stripe_size: Optional[int] = None, - placement_policy: Optional[str] = None, + tenant_id: TenantId | None = None, + timeline_id: TimelineId | None = None, + conf: dict[str, Any] | None = None, + shard_count: int | None = None, + shard_stripe_size: int | None = None, + placement_policy: str | None = None, set_default: bool = False, ) -> tuple[TenantId, TimelineId]: """ @@ -1343,7 +1374,7 @@ class NeonEnv: return tenant_id, timeline_id - def config_tenant(self, tenant_id: Optional[TenantId], conf: dict[str, str]): + def config_tenant(self, tenant_id: TenantId | None, conf: dict[str, str]): """ Update tenant config. """ @@ -1353,10 +1384,10 @@ class NeonEnv: def create_branch( self, new_branch_name: str = DEFAULT_BRANCH_NAME, - tenant_id: Optional[TenantId] = None, - ancestor_branch_name: Optional[str] = None, - ancestor_start_lsn: Optional[Lsn] = None, - new_timeline_id: Optional[TimelineId] = None, + tenant_id: TenantId | None = None, + ancestor_branch_name: str | None = None, + ancestor_start_lsn: Lsn | None = None, + new_timeline_id: TimelineId | None = None, ) -> TimelineId: new_timeline_id = new_timeline_id or TimelineId.generate() tenant_id = tenant_id or self.initial_tenant @@ -1370,8 +1401,8 @@ class NeonEnv: def create_timeline( self, new_branch_name: str, - tenant_id: Optional[TenantId] = None, - timeline_id: Optional[TimelineId] = None, + tenant_id: TenantId | None = None, + timeline_id: TimelineId | None = None, ) -> TimelineId: timeline_id = timeline_id or TimelineId.generate() tenant_id = tenant_id or self.initial_tenant @@ -1396,8 +1427,8 @@ def neon_simple_env( compatibility_pg_distrib_dir: Path, pg_version: PgVersion, pageserver_virtual_file_io_engine: str, - pageserver_default_tenant_config_compaction_algorithm: Optional[dict[str, Any]], - pageserver_virtual_file_io_mode: Optional[str], + pageserver_default_tenant_config_compaction_algorithm: dict[str, Any] | None, + pageserver_virtual_file_io_mode: str | None, ) -> Iterator[NeonEnv]: """ Simple Neon environment, with 1 safekeeper and 1 pageserver. No authentication, no fsync. @@ -1453,9 +1484,9 @@ def neon_env_builder( test_overlay_dir: Path, top_output_dir: Path, pageserver_virtual_file_io_engine: str, - pageserver_default_tenant_config_compaction_algorithm: Optional[dict[str, Any]], + pageserver_default_tenant_config_compaction_algorithm: dict[str, Any] | None, record_property: Callable[[str, object], None], - pageserver_virtual_file_io_mode: Optional[str], + pageserver_virtual_file_io_mode: str | None, ) -> Iterator[NeonEnvBuilder]: """ Fixture to create a Neon environment for test. @@ -1530,7 +1561,7 @@ class LogUtils: def log_contains( self, pattern: str, offset: None | LogCursor = None - ) -> Optional[tuple[str, LogCursor]]: + ) -> tuple[str, LogCursor] | None: """Check that the log contains a line that matches the given regex""" logfile = self.logfile if not logfile.exists(): @@ -1569,14 +1600,13 @@ class StorageControllerApiException(Exception): # See libs/pageserver_api/src/controller_api.rs # for the rust definitions of the enums below -# TODO: Replace with `StrEnum` when we upgrade to python 3.11 -class PageserverAvailability(str, Enum): +class PageserverAvailability(StrEnum): ACTIVE = "Active" UNAVAILABLE = "Unavailable" OFFLINE = "Offline" -class PageserverSchedulingPolicy(str, Enum): +class PageserverSchedulingPolicy(StrEnum): ACTIVE = "Active" DRAINING = "Draining" FILLING = "Filling" @@ -1584,7 +1614,7 @@ class PageserverSchedulingPolicy(str, Enum): PAUSE_FOR_RESTART = "PauseForRestart" -class StorageControllerLeadershipStatus(str, Enum): +class StorageControllerLeadershipStatus(StrEnum): LEADER = "leader" STEPPED_DOWN = "stepped_down" CANDIDATE = "candidate" @@ -1602,16 +1632,16 @@ class NeonStorageController(MetricsGetter, LogUtils): def start( self, - timeout_in_seconds: Optional[int] = None, - instance_id: Optional[int] = None, - base_port: Optional[int] = None, - ): + timeout_in_seconds: int | None = None, + instance_id: int | None = None, + base_port: int | None = None, + ) -> Self: assert not self.running self.env.neon_cli.storage_controller_start(timeout_in_seconds, instance_id, base_port) self.running = True return self - def stop(self, immediate: bool = False) -> NeonStorageController: + def stop(self, immediate: bool = False) -> Self: if self.running: self.env.neon_cli.storage_controller_stop(immediate) self.running = False @@ -1673,7 +1703,7 @@ class NeonStorageController(MetricsGetter, LogUtils): return resp - def headers(self, scope: Optional[TokenScope]) -> dict[str, str]: + def headers(self, scope: TokenScope | None) -> dict[str, str]: headers = {} if self.auth_enabled and scope is not None: jwt_token = self.env.auth_keys.generate_token(scope=scope) @@ -1711,9 +1741,9 @@ class NeonStorageController(MetricsGetter, LogUtils): def attach_hook_issue( self, - tenant_shard_id: Union[TenantId, TenantShardId], + tenant_shard_id: TenantId | TenantShardId, pageserver_id: int, - generation_override: Optional[int] = None, + generation_override: int | None = None, ) -> int: body = {"tenant_shard_id": str(tenant_shard_id), "node_id": pageserver_id} if generation_override is not None: @@ -1729,7 +1759,7 @@ class NeonStorageController(MetricsGetter, LogUtils): assert isinstance(gen, int) return gen - def attach_hook_drop(self, tenant_shard_id: Union[TenantId, TenantShardId]): + def attach_hook_drop(self, tenant_shard_id: TenantId | TenantShardId): self.request( "POST", f"{self.api}/debug/v1/attach-hook", @@ -1737,7 +1767,7 @@ class NeonStorageController(MetricsGetter, LogUtils): headers=self.headers(TokenScope.ADMIN), ) - def inspect(self, tenant_shard_id: Union[TenantId, TenantShardId]) -> Optional[tuple[int, int]]: + def inspect(self, tenant_shard_id: TenantId | TenantShardId) -> tuple[int, int] | None: """ :return: 2-tuple of (generation, pageserver id), or None if unknown """ @@ -1857,10 +1887,10 @@ class NeonStorageController(MetricsGetter, LogUtils): def tenant_create( self, tenant_id: TenantId, - shard_count: Optional[int] = None, - shard_stripe_size: Optional[int] = None, - tenant_config: Optional[dict[Any, Any]] = None, - placement_policy: Optional[Union[dict[Any, Any], str]] = None, + shard_count: int | None = None, + shard_stripe_size: int | None = None, + tenant_config: dict[Any, Any] | None = None, + placement_policy: dict[Any, Any] | str | None = None, ): """ Use this rather than pageserver_api() when you need to include shard parameters @@ -1891,6 +1921,20 @@ class NeonStorageController(MetricsGetter, LogUtils): response.raise_for_status() log.info(f"tenant_create success: {response.json()}") + def timeline_create( + self, + tenant_id: TenantId, + body: dict[str, Any], + ): + response = self.request( + "POST", + f"{self.api}/v1/tenant/{tenant_id}/timeline", + json=body, + headers=self.headers(TokenScope.PAGE_SERVER_API), + ) + response.raise_for_status() + log.info(f"timeline_create success: {response.json()}") + def locate(self, tenant_id: TenantId) -> list[dict[str, Any]]: """ :return: list of {"shard_id": "", "node_id": int, "listen_pg_addr": str, "listen_pg_port": int, "listen_http_addr": str, "listen_http_port": int} @@ -1941,7 +1985,7 @@ class NeonStorageController(MetricsGetter, LogUtils): return response.json() def tenant_shard_split( - self, tenant_id: TenantId, shard_count: int, shard_stripe_size: Optional[int] = None + self, tenant_id: TenantId, shard_count: int, shard_stripe_size: int | None = None ) -> list[TenantShardId]: response = self.request( "PUT", @@ -2039,8 +2083,8 @@ class NeonStorageController(MetricsGetter, LogUtils): def poll_node_status( self, node_id: int, - desired_availability: Optional[PageserverAvailability], - desired_scheduling_policy: Optional[PageserverSchedulingPolicy], + desired_availability: PageserverAvailability | None, + desired_scheduling_policy: PageserverSchedulingPolicy | None, max_attempts: int, backoff: float, ): @@ -2259,7 +2303,7 @@ class NeonStorageController(MetricsGetter, LogUtils): json=body, ) - def get_safekeeper(self, id: int) -> Optional[dict[str, Any]]: + def get_safekeeper(self, id: int) -> dict[str, Any] | None: try: response = self.request( "GET", @@ -2285,14 +2329,14 @@ class NeonStorageController(MetricsGetter, LogUtils): response.raise_for_status() return [TenantShardId.parse(tid) for tid in response.json()["updated"]] - def __enter__(self) -> NeonStorageController: + def __enter__(self) -> Self: return self def __exit__( self, - exc_type: Optional[type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, ): self.stop(immediate=True) @@ -2304,10 +2348,10 @@ class NeonProxiedStorageController(NeonStorageController): def start( self, - timeout_in_seconds: Optional[int] = None, - instance_id: Optional[int] = None, - base_port: Optional[int] = None, - ): + timeout_in_seconds: int | None = None, + instance_id: int | None = None, + base_port: int | None = None, + ) -> Self: assert instance_id is not None and base_port is not None self.env.neon_cli.storage_controller_start(timeout_in_seconds, instance_id, base_port) @@ -2317,7 +2361,7 @@ class NeonProxiedStorageController(NeonStorageController): return self def stop_instance( - self, immediate: bool = False, instance_id: Optional[int] = None + self, immediate: bool = False, instance_id: int | None = None ) -> NeonStorageController: assert instance_id in self.instances if self.instances[instance_id]["running"]: @@ -2327,7 +2371,7 @@ class NeonProxiedStorageController(NeonStorageController): self.running = any(meta["running"] for meta in self.instances.values()) return self - def stop(self, immediate: bool = False) -> NeonStorageController: + def stop(self, immediate: bool = False) -> Self: for iid, details in self.instances.items(): if details["running"]: self.env.neon_cli.storage_controller_stop(immediate, iid) @@ -2346,7 +2390,7 @@ class NeonProxiedStorageController(NeonStorageController): def log_contains( self, pattern: str, offset: None | LogCursor = None - ) -> Optional[tuple[str, LogCursor]]: + ) -> tuple[str, LogCursor] | None: raise NotImplementedError() @@ -2393,8 +2437,8 @@ class NeonPageserver(PgProtocol, LogUtils): def timeline_dir( self, - tenant_shard_id: Union[TenantId, TenantShardId], - timeline_id: Optional[TimelineId] = None, + tenant_shard_id: TenantId | TenantShardId, + timeline_id: TimelineId | None = None, ) -> Path: """Get a timeline directory's path based on the repo directory of the test environment""" if timeline_id is None: @@ -2403,7 +2447,7 @@ class NeonPageserver(PgProtocol, LogUtils): def tenant_dir( self, - tenant_shard_id: Optional[Union[TenantId, TenantShardId]] = None, + tenant_shard_id: TenantId | TenantShardId | None = None, ) -> Path: """Get a tenant directory's path based on the repo directory of the test environment""" if tenant_shard_id is None: @@ -2447,9 +2491,9 @@ class NeonPageserver(PgProtocol, LogUtils): def start( self, - extra_env_vars: Optional[dict[str, str]] = None, - timeout_in_seconds: Optional[int] = None, - ) -> NeonPageserver: + extra_env_vars: dict[str, str] | None = None, + timeout_in_seconds: int | None = None, + ) -> Self: """ Start the page server. `overrides` allows to add some config to this pageserver start. @@ -2484,7 +2528,7 @@ class NeonPageserver(PgProtocol, LogUtils): return self - def stop(self, immediate: bool = False) -> NeonPageserver: + def stop(self, immediate: bool = False) -> Self: """ Stop the page server. Returns self. @@ -2497,7 +2541,7 @@ class NeonPageserver(PgProtocol, LogUtils): def restart( self, immediate: bool = False, - timeout_in_seconds: Optional[int] = None, + timeout_in_seconds: int | None = None, ): """ High level wrapper for restart: restarts the process, and waits for @@ -2532,14 +2576,14 @@ class NeonPageserver(PgProtocol, LogUtils): wait_until(20, 0.5, complete) - def __enter__(self) -> NeonPageserver: + def __enter__(self) -> Self: return self def __exit__( self, - exc_type: Optional[type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, ): self.stop(immediate=True) @@ -2548,7 +2592,7 @@ class NeonPageserver(PgProtocol, LogUtils): pytest.skip("pageserver was built without 'testing' feature") def http_client( - self, auth_token: Optional[str] = None, retries: Optional[Retry] = None + self, auth_token: str | None = None, retries: Retry | None = None ) -> PageserverHttpClient: return PageserverHttpClient( port=self.service_port.http, @@ -2585,7 +2629,7 @@ class NeonPageserver(PgProtocol, LogUtils): self, tenant_id: TenantId, config: None | dict[str, Any] = None, - generation: Optional[int] = None, + generation: int | None = None, override_storage_controller_generation: bool = False, ): """ @@ -2619,7 +2663,7 @@ class NeonPageserver(PgProtocol, LogUtils): return client.tenant_location_conf(tenant_id, config, **kwargs) def read_tenant_location_conf( - self, tenant_shard_id: Union[TenantId, TenantShardId] + self, tenant_shard_id: TenantId | TenantShardId ) -> dict[str, Any]: path = self.tenant_dir(tenant_shard_id) / "config-v1" log.info(f"Reading location conf from {path}") @@ -2634,9 +2678,9 @@ class NeonPageserver(PgProtocol, LogUtils): def tenant_create( self, tenant_id: TenantId, - conf: Optional[dict[str, Any]] = None, - auth_token: Optional[str] = None, - generation: Optional[int] = None, + conf: dict[str, Any] | None = None, + auth_token: str | None = None, + generation: int | None = None, ) -> TenantId: if generation is None: generation = self.env.storage_controller.attach_hook_issue(tenant_id, self.id) @@ -2656,7 +2700,7 @@ class NeonPageserver(PgProtocol, LogUtils): return tenant_id def list_layers( - self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId + self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId ) -> list[Path]: """ Inspect local storage on a pageserver to discover which layer files are present. @@ -2749,7 +2793,7 @@ class PgBin: if "/" not in str(command[0]): command[0] = str(self.pg_bin_path / command[0]) - def _build_env(self, env_add: Optional[Env]) -> Env: + def _build_env(self, env_add: Env | None) -> Env: if env_add is None: return self.env env = self.env.copy() @@ -2766,8 +2810,8 @@ class PgBin: def run_nonblocking( self, command: list[str], - env: Optional[Env] = None, - cwd: Optional[Union[str, Path]] = None, + env: Env | None = None, + cwd: str | Path | None = None, ) -> subprocess.Popen[Any]: """ Run one of the postgres binaries, not waiting for it to finish @@ -2790,8 +2834,8 @@ class PgBin: def run( self, command: list[str], - env: Optional[Env] = None, - cwd: Optional[Union[str, Path]] = None, + env: Env | None = None, + cwd: str | Path | None = None, ) -> None: """ Run one of the postgres binaries, waiting for it to finish @@ -2813,8 +2857,8 @@ class PgBin: def run_capture( self, command: list[str], - env: Optional[Env] = None, - cwd: Optional[str] = None, + env: Env | None = None, + cwd: str | None = None, with_command_header=True, **popen_kwargs: Any, ) -> str: @@ -2941,7 +2985,7 @@ class VanillaPostgres(PgProtocol): conf_file.write("\n".join(hba) + "\n") conf_file.write(data) - def start(self, log_path: Optional[str] = None): + def start(self, log_path: str | None = None): assert not self.running self.running = True @@ -2960,14 +3004,14 @@ class VanillaPostgres(PgProtocol): """Return size of pgdatadir subdirectory in bytes.""" return get_dir_size(self.pgdatadir / subdir) - def __enter__(self) -> VanillaPostgres: + def __enter__(self) -> Self: return self def __exit__( self, - exc_type: Optional[type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, ): if self.running: self.stop() @@ -3009,14 +3053,14 @@ class RemotePostgres(PgProtocol): # See https://www.postgresql.org/docs/14/functions-admin.html#FUNCTIONS-ADMIN-GENFILE raise Exception("cannot get size of a Postgres instance") - def __enter__(self) -> RemotePostgres: + def __enter__(self) -> Self: return self def __exit__( self, - exc_type: Optional[type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, ): # do nothing pass @@ -3092,7 +3136,7 @@ class PSQL: self.path = full_path self.database_url = f"postgres://{host}:{port}/main?options=project%3Dgeneric-project-name" - async def run(self, query: Optional[str] = None) -> asyncio.subprocess.Process: + async def run(self, query: str | None = None) -> asyncio.subprocess.Process: run_args = [self.path, "--no-psqlrc", "--quiet", "--tuples-only", self.database_url] if query is not None: run_args += ["--command", query] @@ -3138,7 +3182,7 @@ class NeonProxy(PgProtocol): """All auth backends must inherit from this class""" @property - def default_conn_url(self) -> Optional[str]: + def default_conn_url(self) -> str | None: return None @abc.abstractmethod @@ -3155,7 +3199,7 @@ class NeonProxy(PgProtocol): ] class Console(AuthBackend): - def __init__(self, endpoint: str, fixed_rate_limit: Optional[int] = None): + def __init__(self, endpoint: str, fixed_rate_limit: int | None = None): self.endpoint = endpoint self.fixed_rate_limit = fixed_rate_limit @@ -3183,7 +3227,7 @@ class NeonProxy(PgProtocol): pg_conn_url: str @property - def default_conn_url(self) -> Optional[str]: + def default_conn_url(self) -> str | None: return self.pg_conn_url def extra_args(self) -> list[str]: @@ -3202,8 +3246,8 @@ class NeonProxy(PgProtocol): mgmt_port: int, external_http_port: int, auth_backend: NeonProxy.AuthBackend, - metric_collection_endpoint: Optional[str] = None, - metric_collection_interval: Optional[str] = None, + metric_collection_endpoint: str | None = None, + metric_collection_interval: str | None = None, ): host = "127.0.0.1" domain = "proxy.localtest.me" # resolves to 127.0.0.1 @@ -3221,9 +3265,9 @@ class NeonProxy(PgProtocol): self.metric_collection_endpoint = metric_collection_endpoint self.metric_collection_interval = metric_collection_interval self.http_timeout_seconds = 15 - self._popen: Optional[subprocess.Popen[bytes]] = None + self._popen: subprocess.Popen[bytes] | None = None - def start(self) -> NeonProxy: + def start(self) -> Self: assert self._popen is None # generate key of it doesn't exist @@ -3351,14 +3395,14 @@ class NeonProxy(PgProtocol): log.info(f"SUCCESS, found auth url: {line}") return line - def __enter__(self) -> NeonProxy: + def __enter__(self) -> Self: return self def __exit__( self, - exc_type: Optional[type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, ): if self._popen is not None: self._popen.terminate() @@ -3439,9 +3483,9 @@ class NeonAuthBroker: self.mgmt_port = mgmt_port self.auth_backend = auth_backend self.http_timeout_seconds = 15 - self._popen: Optional[subprocess.Popen[bytes]] = None + self._popen: subprocess.Popen[bytes] | None = None - def start(self) -> NeonAuthBroker: + def start(self) -> Self: assert self._popen is None # generate key of it doesn't exist @@ -3510,14 +3554,14 @@ class NeonAuthBroker: request_result = requests.get(f"http://{self.host}:{self.http_port}/metrics") return request_result.text - def __enter__(self) -> NeonAuthBroker: + def __enter__(self) -> Self: return self def __exit__( self, - _exc_type: Optional[type[BaseException]], - _exc_value: Optional[BaseException], - _traceback: Optional[TracebackType], + _exc_type: type[BaseException] | None, + _exc_value: BaseException | None, + _traceback: TracebackType | None, ): if self._popen is not None: self._popen.terminate() @@ -3673,9 +3717,9 @@ class Endpoint(PgProtocol, LogUtils): ): super().__init__(host="localhost", port=pg_port, user="cloud_admin", dbname="postgres") self.env = env - self.branch_name: Optional[str] = None # dubious - self.endpoint_id: Optional[str] = None # dubious, see asserts below - self.pgdata_dir: Optional[Path] = None # Path to computenode PGDATA + self.branch_name: str | None = None # dubious + self.endpoint_id: str | None = None # dubious, see asserts below + self.pgdata_dir: Path | None = None # Path to computenode PGDATA self.tenant_id = tenant_id self.pg_port = pg_port self.http_port = http_port @@ -3692,7 +3736,7 @@ class Endpoint(PgProtocol, LogUtils): self._running = threading.Semaphore(0) def http_client( - self, auth_token: Optional[str] = None, retries: Optional[Retry] = None + self, auth_token: str | None = None, retries: Retry | None = None ) -> EndpointHttpClient: return EndpointHttpClient( port=self.http_port, @@ -3701,13 +3745,13 @@ class Endpoint(PgProtocol, LogUtils): def create( self, branch_name: str, - endpoint_id: Optional[str] = None, + endpoint_id: str | None = None, hot_standby: bool = False, - lsn: Optional[Lsn] = None, - config_lines: Optional[list[str]] = None, - pageserver_id: Optional[int] = None, + lsn: Lsn | None = None, + config_lines: list[str] | None = None, + pageserver_id: int | None = None, allow_multiple: bool = False, - ) -> Endpoint: + ) -> Self: """ Create a new Postgres endpoint. Returns self. @@ -3736,24 +3780,58 @@ class Endpoint(PgProtocol, LogUtils): self.pgdata_dir = self.env.repo_dir / path self.logfile = self.endpoint_path() / "compute.log" - config_lines = config_lines or [] - # set small 'max_replication_write_lag' to enable backpressure # and make tests more stable. config_lines = ["max_replication_write_lag=15MB"] + config_lines + # Delete file cache if it exists (and we're recreating the endpoint) + if USE_LFC: + if (lfc_path := Path(self.lfc_path())).exists(): + lfc_path.unlink() + else: + lfc_path.parent.mkdir(parents=True, exist_ok=True) + for line in config_lines: + if ( + line.find("neon.max_file_cache_size") > -1 + or line.find("neon.file_cache_size_limit") > -1 + ): + m = re.search(r"=\s*(\S+)", line) + assert m is not None, f"malformed config line {line}" + size = m.group(1) + assert size_to_bytes(size) >= size_to_bytes( + "1MB" + ), "LFC size cannot be set less than 1MB" + # shared_buffers = 512kB to make postgres use LFC intensively + # neon.max_file_cache_size and neon.file_cache size limit are + # set to 1MB because small LFC is better for testing (helps to find more problems) + lfc_path_escaped = str(lfc_path).replace("'", "''") + config_lines = [ + "shared_buffers = 512kB", + f"neon.file_cache_path = '{lfc_path_escaped}'", + "neon.max_file_cache_size = 1MB", + "neon.file_cache_size_limit = 1MB", + ] + config_lines + else: + for line in config_lines: + assert ( + line.find("neon.max_file_cache_size") == -1 + ), "Setting LFC parameters is not allowed when LFC is disabled" + assert ( + line.find("neon.file_cache_size_limit") == -1 + ), "Setting LFC parameters is not allowed when LFC is disabled" + self.config(config_lines) return self def start( self, - remote_ext_config: Optional[str] = None, - pageserver_id: Optional[int] = None, - safekeepers: Optional[list[int]] = None, + remote_ext_config: str | None = None, + pageserver_id: int | None = None, + safekeepers: list[int] | None = None, allow_multiple: bool = False, - basebackup_request_tries: Optional[int] = None, - ) -> Endpoint: + basebackup_request_tries: int | None = None, + ) -> Self: """ Start the Postgres instance. Returns self. @@ -3775,6 +3853,9 @@ class Endpoint(PgProtocol, LogUtils): basebackup_request_tries=basebackup_request_tries, ) self._running.release(1) + self.log_config_value("shared_buffers") + self.log_config_value("neon.max_file_cache_size") + self.log_config_value("neon.file_cache_size_limit") return self @@ -3800,7 +3881,11 @@ class Endpoint(PgProtocol, LogUtils): """Path to the postgresql.conf in the endpoint directory (not the one in pgdata)""" return self.endpoint_path() / "postgresql.conf" - def config(self, lines: list[str]) -> Endpoint: + def lfc_path(self) -> Path: + """Path to the lfc file""" + return self.endpoint_path() / "file_cache" / "file.cache" + + def config(self, lines: list[str]) -> Self: """ Add lines to postgresql.conf. Lines should be an array of valid postgresql.conf rows. @@ -3828,9 +3913,7 @@ class Endpoint(PgProtocol, LogUtils): def is_running(self): return self._running._value > 0 - def reconfigure( - self, pageserver_id: Optional[int] = None, safekeepers: Optional[list[int]] = None - ): + def reconfigure(self, pageserver_id: int | None = None, safekeepers: list[int] | None = None): assert self.endpoint_id is not None # If `safekeepers` is not None, they are remember them as active and use # in the following commands. @@ -3852,6 +3935,35 @@ class Endpoint(PgProtocol, LogUtils): log.info(json.dumps(dict(data_dict, **kwargs))) json.dump(dict(data_dict, **kwargs), file, indent=4) + def respec_deep(self, **kwargs: Any) -> None: + """ + Update the endpoint.json file taking into account nested keys. + It does one level deep update. Should enough for most cases. + Distinct method from respec() to do not break existing functionality. + NOTE: This method also updates the spec.json file, not endpoint.json. + We need it because neon_local also writes to spec.json, so intended + use-case is i) start endpoint with some config, ii) respec_deep(), + iii) call reconfigure() to apply the changes. + """ + config_path = os.path.join(self.endpoint_path(), "spec.json") + with open(config_path) as f: + data_dict: dict[str, Any] = json.load(f) + + log.info("Current compute spec: %s", json.dumps(data_dict, indent=4)) + + for key, value in kwargs.items(): + if isinstance(value, dict): + if key not in data_dict: + data_dict[key] = value + else: + data_dict[key] = {**data_dict[key], **value} + else: + data_dict[key] = value + + with open(config_path, "w") as file: + log.info("Updating compute spec to: %s", json.dumps(data_dict, indent=4)) + json.dump(data_dict, file, indent=4) + # Please note: Migrations only run if pg_skip_catalog_updates is false def wait_for_migrations(self, num_migrations: int = 11): with self.cursor() as cur: @@ -3877,8 +3989,8 @@ class Endpoint(PgProtocol, LogUtils): def stop( self, mode: str = "fast", - sks_wait_walreceiver_gone: Optional[tuple[list[Safekeeper], TimelineId]] = None, - ) -> Endpoint: + sks_wait_walreceiver_gone: tuple[list[Safekeeper], TimelineId] | None = None, + ) -> Self: """ Stop the Postgres instance if it's running. @@ -3912,7 +4024,7 @@ class Endpoint(PgProtocol, LogUtils): return self - def stop_and_destroy(self, mode: str = "immediate") -> Endpoint: + def stop_and_destroy(self, mode: str = "immediate") -> Self: """ Stop the Postgres instance, then destroy the endpoint. Returns self. @@ -3931,15 +4043,15 @@ class Endpoint(PgProtocol, LogUtils): def create_start( self, branch_name: str, - endpoint_id: Optional[str] = None, + endpoint_id: str | None = None, hot_standby: bool = False, - lsn: Optional[Lsn] = None, - config_lines: Optional[list[str]] = None, - remote_ext_config: Optional[str] = None, - pageserver_id: Optional[int] = None, + lsn: Lsn | None = None, + config_lines: list[str] | None = None, + remote_ext_config: str | None = None, + pageserver_id: int | None = None, allow_multiple: bool = False, - basebackup_request_tries: Optional[int] = None, - ) -> Endpoint: + basebackup_request_tries: int | None = None, + ) -> Self: """ Create an endpoint, apply config, and start Postgres. Returns self. @@ -3962,14 +4074,14 @@ class Endpoint(PgProtocol, LogUtils): return self - def __enter__(self) -> Endpoint: + def __enter__(self) -> Self: return self def __exit__( self, - exc_type: Optional[type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, ): self.stop() @@ -3980,16 +4092,46 @@ class Endpoint(PgProtocol, LogUtils): assert self.pgdata_dir is not None # please mypy return get_dir_size(self.pgdata_dir / "pg_wal") / 1024 / 1024 - def clear_shared_buffers(self, cursor: Optional[Any] = None): + def clear_buffers(self, cursor: Any | None = None): """ Best-effort way to clear postgres buffers. Pinned buffers will not be 'cleared.' - - Might also clear LFC. + It clears LFC as well by setting neon.file_cache_size_limit to 0 and then returning it to the previous value, + if LFC is enabled """ if cursor is not None: cursor.execute("select clear_buffer_cache()") + if not USE_LFC: + return + cursor.execute("SHOW neon.file_cache_size_limit") + res = cursor.fetchone() + assert res, "Cannot get neon.file_cache_size_limit" + file_cache_size_limit = res[0] + if file_cache_size_limit == 0: + return + cursor.execute("ALTER SYSTEM SET neon.file_cache_size_limit=0") + cursor.execute("SELECT pg_reload_conf()") + cursor.execute(f"ALTER SYSTEM SET neon.file_cache_size_limit='{file_cache_size_limit}'") + cursor.execute("SELECT pg_reload_conf()") else: self.safe_psql("select clear_buffer_cache()") + if not USE_LFC: + return + file_cache_size_limit = self.safe_psql_scalar( + "SHOW neon.file_cache_size_limit", log_query=False + ) + if file_cache_size_limit == 0: + return + self.safe_psql("ALTER SYSTEM SET neon.file_cache_size_limit=0") + self.safe_psql("SELECT pg_reload_conf()") + self.safe_psql(f"ALTER SYSTEM SET neon.file_cache_size_limit='{file_cache_size_limit}'") + self.safe_psql("SELECT pg_reload_conf()") + + def log_config_value(self, param): + """ + Writes the config value param to log + """ + res = self.safe_psql_scalar(f"SHOW {param}", log_query=False) + log.info("%s = %s", param, res) class EndpointFactory: @@ -4003,14 +4145,14 @@ class EndpointFactory: def create_start( self, branch_name: str, - endpoint_id: Optional[str] = None, - tenant_id: Optional[TenantId] = None, - lsn: Optional[Lsn] = None, + endpoint_id: str | None = None, + tenant_id: TenantId | None = None, + lsn: Lsn | None = None, hot_standby: bool = False, - config_lines: Optional[list[str]] = None, - remote_ext_config: Optional[str] = None, - pageserver_id: Optional[int] = None, - basebackup_request_tries: Optional[int] = None, + config_lines: list[str] | None = None, + remote_ext_config: str | None = None, + pageserver_id: int | None = None, + basebackup_request_tries: int | None = None, ) -> Endpoint: ep = Endpoint( self.env, @@ -4035,12 +4177,12 @@ class EndpointFactory: def create( self, branch_name: str, - endpoint_id: Optional[str] = None, - tenant_id: Optional[TenantId] = None, - lsn: Optional[Lsn] = None, + endpoint_id: str | None = None, + tenant_id: TenantId | None = None, + lsn: Lsn | None = None, hot_standby: bool = False, - config_lines: Optional[list[str]] = None, - pageserver_id: Optional[int] = None, + config_lines: list[str] | None = None, + pageserver_id: int | None = None, ) -> Endpoint: ep = Endpoint( self.env, @@ -4063,7 +4205,7 @@ class EndpointFactory: pageserver_id=pageserver_id, ) - def stop_all(self, fail_on_error=True) -> EndpointFactory: + def stop_all(self, fail_on_error=True) -> Self: exception = None for ep in self.endpoints: try: @@ -4078,7 +4220,7 @@ class EndpointFactory: return self def new_replica( - self, origin: Endpoint, endpoint_id: str, config_lines: Optional[list[str]] = None + self, origin: Endpoint, endpoint_id: str, config_lines: list[str] | None = None ): branch_name = origin.branch_name assert origin in self.endpoints @@ -4094,7 +4236,7 @@ class EndpointFactory: ) def new_replica_start( - self, origin: Endpoint, endpoint_id: str, config_lines: Optional[list[str]] = None + self, origin: Endpoint, endpoint_id: str, config_lines: list[str] | None = None ): branch_name = origin.branch_name assert origin in self.endpoints @@ -4132,7 +4274,7 @@ class Safekeeper(LogUtils): port: SafekeeperPort, id: int, running: bool = False, - extra_opts: Optional[list[str]] = None, + extra_opts: list[str] | None = None, ): self.env = env self.port = port @@ -4158,8 +4300,8 @@ class Safekeeper(LogUtils): self.extra_opts = extra_opts def start( - self, extra_opts: Optional[list[str]] = None, timeout_in_seconds: Optional[int] = None - ) -> Safekeeper: + self, extra_opts: list[str] | None = None, timeout_in_seconds: int | None = None + ) -> Self: if extra_opts is None: # Apply either the extra_opts passed in, or the ones from our constructor: we do not merge the two. extra_opts = self.extra_opts @@ -4194,7 +4336,7 @@ class Safekeeper(LogUtils): break # success return self - def stop(self, immediate: bool = False) -> Safekeeper: + def stop(self, immediate: bool = False) -> Self: self.env.neon_cli.safekeeper_stop(self.id, immediate) self.running = False return self @@ -4238,7 +4380,7 @@ class Safekeeper(LogUtils): return res def http_client( - self, auth_token: Optional[str] = None, gen_sk_wide_token: bool = True + self, auth_token: str | None = None, gen_sk_wide_token: bool = True ) -> SafekeeperHttpClient: """ When auth_token is None but gen_sk_wide is True creates safekeeper wide @@ -4263,6 +4405,10 @@ class Safekeeper(LogUtils): log.info(f"sk {self.id} flush LSN: {flush_lsn}") return flush_lsn + def get_commit_lsn(self, tenant_id: TenantId, timeline_id: TimelineId) -> Lsn: + timeline_status = self.http_client().timeline_status(tenant_id, timeline_id) + return timeline_status.commit_lsn + def pull_timeline( self, srcs: list[Safekeeper], tenant_id: TenantId, timeline_id: TimelineId ) -> dict[str, Any]: @@ -4371,14 +4517,14 @@ class NeonBroker(LogUtils): def start( self, - timeout_in_seconds: Optional[int] = None, - ): + timeout_in_seconds: int | None = None, + ) -> Self: assert not self.running self.env.neon_cli.storage_broker_start(timeout_in_seconds) self.running = True return self - def stop(self): + def stop(self) -> Self: if self.running: self.env.neon_cli.storage_broker_stop() self.running = False @@ -4394,8 +4540,7 @@ class NeonBroker(LogUtils): assert_no_errors(self.logfile, "storage_controller", []) -# TODO: Replace with `StrEnum` when we upgrade to python 3.11 -class NodeKind(str, Enum): +class NodeKind(StrEnum): PAGESERVER = "pageserver" SAFEKEEPER = "safekeeper" @@ -4406,7 +4551,7 @@ class StorageScrubber: self.log_dir = log_dir def scrubber_cli( - self, args: list[str], timeout, extra_env: Optional[dict[str, str]] = None + self, args: list[str], timeout, extra_env: dict[str, str] | None = None ) -> str: assert isinstance(self.env.pageserver_remote_storage, S3Storage) s3_storage = self.env.pageserver_remote_storage @@ -4469,8 +4614,8 @@ class StorageScrubber: self, post_to_storage_controller: bool = False, node_kind: NodeKind = NodeKind.PAGESERVER, - timeline_lsns: Optional[list[dict[str, Any]]] = None, - extra_env: Optional[dict[str, str]] = None, + timeline_lsns: list[dict[str, Any]] | None = None, + extra_env: dict[str, str] | None = None, ) -> tuple[bool, Any]: """ Returns the health status and the metadata summary. @@ -4504,8 +4649,8 @@ class StorageScrubber: def pageserver_physical_gc( self, min_age_secs: int, - tenant_ids: Optional[list[TenantId]] = None, - mode: Optional[str] = None, + tenant_ids: list[TenantId] | None = None, + mode: str | None = None, ): args = ["pageserver-physical-gc", "--min-age", f"{min_age_secs}s"] @@ -4619,7 +4764,7 @@ def check_restored_datadir_content( test_output_dir: Path, env: NeonEnv, endpoint: Endpoint, - ignored_files: Optional[list[str]] = None, + ignored_files: list[str] | None = None, ): pg_bin = PgBin(test_output_dir, env.pg_distrib_dir, env.pg_version) @@ -4721,7 +4866,7 @@ def logical_replication_sync(subscriber: PgProtocol, publisher: PgProtocol) -> L def tenant_get_shards( - env: NeonEnv, tenant_id: TenantId, pageserver_id: Optional[int] = None + env: NeonEnv, tenant_id: TenantId, pageserver_id: int | None = None ) -> list[tuple[TenantShardId, NeonPageserver]]: """ Helper for when you want to talk to one or more pageservers, and the @@ -4784,8 +4929,8 @@ def wait_for_last_flush_lsn( endpoint: Endpoint, tenant: TenantId, timeline: TimelineId, - pageserver_id: Optional[int] = None, - auth_token: Optional[str] = None, + pageserver_id: int | None = None, + auth_token: str | None = None, ) -> Lsn: """Wait for pageserver to catch up the latest flush LSN, returns the last observed lsn.""" @@ -4809,12 +4954,39 @@ def wait_for_last_flush_lsn( return min(results) +def wait_for_commit_lsn( + env: NeonEnv, + tenant: TenantId, + timeline: TimelineId, + lsn: Lsn, +) -> Lsn: + # TODO: it would be better to poll this in the compute, but there's no API for it. See: + # https://github.com/neondatabase/neon/issues/9758 + "Wait for the given LSN to be committed on any Safekeeper" + + max_commit_lsn = Lsn(0) + for i in range(1000): + for sk in env.safekeepers: + commit_lsn = sk.get_commit_lsn(tenant, timeline) + if commit_lsn >= lsn: + log.info(f"{tenant}/{timeline} at commit_lsn {commit_lsn}") + return commit_lsn + max_commit_lsn = max(max_commit_lsn, commit_lsn) + + if i % 10 == 0: + log.info( + f"{tenant}/{timeline} waiting for commit_lsn to reach {lsn}, now {max_commit_lsn}" + ) + time.sleep(0.1) + raise Exception(f"timed out while waiting for commit_lsn to reach {lsn}, was {max_commit_lsn}") + + def flush_ep_to_pageserver( env: NeonEnv, ep: Endpoint, tenant: TenantId, timeline: TimelineId, - pageserver_id: Optional[int] = None, + pageserver_id: int | None = None, ) -> Lsn: """ Stop endpoint and wait until all committed WAL reaches the pageserver @@ -4857,7 +5029,7 @@ def wait_for_wal_insert_lsn( endpoint: Endpoint, tenant: TenantId, timeline: TimelineId, - pageserver_id: Optional[int] = None, + pageserver_id: int | None = None, ) -> Lsn: """Wait for pageserver to catch up the latest flush LSN, returns the last observed lsn.""" last_flush_lsn = Lsn(endpoint.safe_psql("SELECT pg_current_wal_insert_lsn()")[0][0]) @@ -4878,7 +5050,7 @@ def fork_at_current_lsn( endpoint: Endpoint, new_branch_name: str, ancestor_branch_name: str, - tenant_id: Optional[TenantId] = None, + tenant_id: TenantId | None = None, ) -> TimelineId: """ Create new branch at the last LSN of an existing branch. @@ -4951,8 +5123,9 @@ def last_flush_lsn_upload( endpoint: Endpoint, tenant_id: TenantId, timeline_id: TimelineId, - pageserver_id: Optional[int] = None, - auth_token: Optional[str] = None, + pageserver_id: int | None = None, + auth_token: str | None = None, + wait_until_uploaded: bool = True, ) -> Lsn: """ Wait for pageserver to catch to the latest flush LSN of given endpoint, @@ -4966,7 +5139,9 @@ def last_flush_lsn_upload( for tenant_shard_id, pageserver in shards: ps_http = pageserver.http_client(auth_token=auth_token) wait_for_last_record_lsn(ps_http, tenant_shard_id, timeline_id, last_flush_lsn) - ps_http.timeline_checkpoint(tenant_shard_id, timeline_id, wait_until_uploaded=True) + ps_http.timeline_checkpoint( + tenant_shard_id, timeline_id, wait_until_uploaded=wait_until_uploaded + ) return last_flush_lsn @@ -4987,10 +5162,11 @@ def generate_uploads_and_deletions( env: NeonEnv, *, init: bool = True, - tenant_id: Optional[TenantId] = None, - timeline_id: Optional[TimelineId] = None, - data: Optional[str] = None, + tenant_id: TenantId | None = None, + timeline_id: TimelineId | None = None, + data: str | None = None, pageserver: NeonPageserver, + wait_until_uploaded: bool = True, ): """ Using the environment's default tenant + timeline, generate a load pattern @@ -5013,7 +5189,12 @@ def generate_uploads_and_deletions( if init: endpoint.safe_psql("CREATE TABLE foo (id INTEGER PRIMARY KEY, val text)") last_flush_lsn_upload( - env, endpoint, tenant_id, timeline_id, pageserver_id=pageserver.id + env, + endpoint, + tenant_id, + timeline_id, + pageserver_id=pageserver.id, + wait_until_uploaded=wait_until_uploaded, ) def churn(data): @@ -5036,7 +5217,12 @@ def generate_uploads_and_deletions( # in a state where there are "future layers" in remote storage that will generate deletions # after a restart. last_flush_lsn_upload( - env, endpoint, tenant_id, timeline_id, pageserver_id=pageserver.id + env, + endpoint, + tenant_id, + timeline_id, + pageserver_id=pageserver.id, + wait_until_uploaded=wait_until_uploaded, ) # Compaction should generate some GC-elegible layers @@ -5052,4 +5238,4 @@ def generate_uploads_and_deletions( # background ingest, no more uploads pending, and therefore no non-determinism # in subsequent actions like pageserver restarts. flush_ep_to_pageserver(env, endpoint, tenant_id, timeline_id, pageserver.id) - ps_http.timeline_checkpoint(tenant_id, timeline_id, wait_until_uploaded=True) + ps_http.timeline_checkpoint(tenant_id, timeline_id, wait_until_uploaded=wait_until_uploaded) diff --git a/test_runner/fixtures/pageserver/common_types.py b/test_runner/fixtures/pageserver/common_types.py index 2319701e0b..0e068db593 100644 --- a/test_runner/fixtures/pageserver/common_types.py +++ b/test_runner/fixtures/pageserver/common_types.py @@ -2,7 +2,7 @@ from __future__ import annotations import re from dataclasses import dataclass -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING from fixtures.common_types import KEY_MAX, KEY_MIN, Key, Lsn @@ -46,7 +46,7 @@ class DeltaLayerName: return ret -LayerName = Union[ImageLayerName, DeltaLayerName] +LayerName = ImageLayerName | DeltaLayerName class InvalidFileName(Exception): diff --git a/test_runner/fixtures/pageserver/http.py b/test_runner/fixtures/pageserver/http.py index 01583757fa..4cf3ece396 100644 --- a/test_runner/fixtures/pageserver/http.py +++ b/test_runner/fixtures/pageserver/http.py @@ -1,24 +1,32 @@ from __future__ import annotations +import dataclasses +import json +import random +import string import time from collections import defaultdict from dataclasses import dataclass from datetime import datetime -from typing import TYPE_CHECKING, Any +from typing import Any import requests from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry -from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineArchivalState, TimelineId +from fixtures.common_types import ( + Id, + Lsn, + TenantId, + TenantShardId, + TimelineArchivalState, + TimelineId, +) from fixtures.log_helper import log from fixtures.metrics import Metrics, MetricsGetter, parse_metrics from fixtures.pg_version import PgVersion from fixtures.utils import Fn -if TYPE_CHECKING: - from typing import Optional, Union - class PageserverApiException(Exception): def __init__(self, message, status_code: int): @@ -27,6 +35,69 @@ class PageserverApiException(Exception): self.status_code = status_code +@dataclass +class ImportPgdataIdemptencyKey: + key: str + + @staticmethod + def random() -> ImportPgdataIdemptencyKey: + return ImportPgdataIdemptencyKey( + "".join(random.choices(string.ascii_letters + string.digits, k=20)) + ) + + +@dataclass +class LocalFs: + path: str + + +@dataclass +class AwsS3: + region: str + bucket: str + key: str + + +@dataclass +class ImportPgdataLocation: + LocalFs: None | LocalFs = None + AwsS3: None | AwsS3 = None + + +@dataclass +class TimelineCreateRequestModeImportPgdata: + location: ImportPgdataLocation + idempotency_key: ImportPgdataIdemptencyKey + + +@dataclass +class TimelineCreateRequestMode: + Branch: None | dict[str, Any] = None + Bootstrap: None | dict[str, Any] = None + ImportPgdata: None | TimelineCreateRequestModeImportPgdata = None + + +@dataclass +class TimelineCreateRequest: + new_timeline_id: TimelineId + mode: TimelineCreateRequestMode + + def to_json(self) -> str: + class EnhancedJSONEncoder(json.JSONEncoder): + def default(self, o): + if dataclasses.is_dataclass(o) and not isinstance(o, type): + return dataclasses.asdict(o) + elif isinstance(o, Id): + return o.id.hex() + return super().default(o) + + # mode is flattened + this = dataclasses.asdict(self) + mode = this.pop("mode") + this.update(mode) + return json.dumps(self, cls=EnhancedJSONEncoder) + + class TimelineCreate406(PageserverApiException): def __init__(self, res: requests.Response): assert res.status_code == 406 @@ -43,7 +114,7 @@ class TimelineCreate409(PageserverApiException): class InMemoryLayerInfo: kind: str lsn_start: str - lsn_end: Optional[str] + lsn_end: str | None @classmethod def from_json(cls, d: dict[str, Any]) -> InMemoryLayerInfo: @@ -60,10 +131,10 @@ class HistoricLayerInfo: layer_file_name: str layer_file_size: int lsn_start: str - lsn_end: Optional[str] + lsn_end: str | None remote: bool # None for image layers, true if pageserver thinks this is an L0 delta layer - l0: Optional[bool] + l0: bool | None visible: bool @classmethod @@ -180,8 +251,8 @@ class PageserverHttpClient(requests.Session, MetricsGetter): self, port: int, is_testing_enabled_or_skip: Fn, - auth_token: Optional[str] = None, - retries: Optional[Retry] = None, + auth_token: str | None = None, + retries: Retry | None = None, ): super().__init__() self.port = port @@ -278,7 +349,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def tenant_attach( self, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, generation: int, config: None | dict[str, Any] = None, ): @@ -305,7 +376,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): }, ) - def tenant_reset(self, tenant_id: Union[TenantId, TenantShardId], drop_cache: bool): + def tenant_reset(self, tenant_id: TenantId | TenantShardId, drop_cache: bool): params = {} if drop_cache: params["drop_cache"] = "true" @@ -315,10 +386,10 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def tenant_location_conf( self, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, location_conf: dict[str, Any], flush_ms=None, - lazy: Optional[bool] = None, + lazy: bool | None = None, ): body = location_conf.copy() @@ -346,20 +417,20 @@ class PageserverHttpClient(requests.Session, MetricsGetter): assert isinstance(res_json["tenant_shards"], list) return res_json - def tenant_get_location(self, tenant_id: TenantShardId): + def tenant_get_location(self, tenant_id: TenantId | TenantShardId): res = self.get( f"http://localhost:{self.port}/v1/location_config/{tenant_id}", ) self.verbose_error(res) return res.json() - def tenant_delete(self, tenant_id: Union[TenantId, TenantShardId]): + def tenant_delete(self, tenant_id: TenantId | TenantShardId): res = self.delete(f"http://localhost:{self.port}/v1/tenant/{tenant_id}") self.verbose_error(res) return res def tenant_status( - self, tenant_id: Union[TenantId, TenantShardId], activate: bool = False + self, tenant_id: TenantId | TenantShardId, activate: bool = False ) -> dict[Any, Any]: """ :activate: hint the server not to accelerate activation of this tenant in response @@ -378,17 +449,17 @@ class PageserverHttpClient(requests.Session, MetricsGetter): assert isinstance(res_json, dict) return res_json - def tenant_config(self, tenant_id: Union[TenantId, TenantShardId]) -> TenantConfig: + def tenant_config(self, tenant_id: TenantId | TenantShardId) -> TenantConfig: res = self.get(f"http://localhost:{self.port}/v1/tenant/{tenant_id}/config") self.verbose_error(res) return TenantConfig.from_json(res.json()) - def tenant_heatmap_upload(self, tenant_id: Union[TenantId, TenantShardId]): + def tenant_heatmap_upload(self, tenant_id: TenantId | TenantShardId): res = self.post(f"http://localhost:{self.port}/v1/tenant/{tenant_id}/heatmap_upload") self.verbose_error(res) def tenant_secondary_download( - self, tenant_id: Union[TenantId, TenantShardId], wait_ms: Optional[int] = None + self, tenant_id: TenantId | TenantShardId, wait_ms: int | None = None ) -> tuple[int, dict[Any, Any]]: url = f"http://localhost:{self.port}/v1/tenant/{tenant_id}/secondary/download" if wait_ms is not None: @@ -397,13 +468,13 @@ class PageserverHttpClient(requests.Session, MetricsGetter): self.verbose_error(res) return (res.status_code, res.json()) - def tenant_secondary_status(self, tenant_id: Union[TenantId, TenantShardId]): + def tenant_secondary_status(self, tenant_id: TenantId | TenantShardId): url = f"http://localhost:{self.port}/v1/tenant/{tenant_id}/secondary/status" res = self.get(url) self.verbose_error(res) return res.json() - def set_tenant_config(self, tenant_id: Union[TenantId, TenantShardId], config: dict[str, Any]): + def set_tenant_config(self, tenant_id: TenantId | TenantShardId, config: dict[str, Any]): """ Only use this via storage_controller.pageserver_api(). @@ -420,8 +491,8 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def patch_tenant_config_client_side( self, tenant_id: TenantId, - inserts: Optional[dict[str, Any]] = None, - removes: Optional[list[str]] = None, + inserts: dict[str, Any] | None = None, + removes: list[str] | None = None, ): """ Only use this via storage_controller.pageserver_api(). @@ -436,11 +507,11 @@ class PageserverHttpClient(requests.Session, MetricsGetter): del current[key] self.set_tenant_config(tenant_id, current) - def tenant_size(self, tenant_id: Union[TenantId, TenantShardId]) -> int: + def tenant_size(self, tenant_id: TenantId | TenantShardId) -> int: return self.tenant_size_and_modelinputs(tenant_id)[0] def tenant_size_and_modelinputs( - self, tenant_id: Union[TenantId, TenantShardId] + self, tenant_id: TenantId | TenantShardId ) -> tuple[int, dict[str, Any]]: """ Returns the tenant size, together with the model inputs as the second tuple item. @@ -456,7 +527,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): assert isinstance(inputs, dict) return (size, inputs) - def tenant_size_debug(self, tenant_id: Union[TenantId, TenantShardId]) -> str: + def tenant_size_debug(self, tenant_id: TenantId | TenantShardId) -> str: """ Returns the tenant size debug info, as an HTML string """ @@ -468,10 +539,10 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def tenant_time_travel_remote_storage( self, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, timestamp: datetime, done_if_after: datetime, - shard_counts: Optional[list[int]] = None, + shard_counts: list[int] | None = None, ): """ Issues a request to perform time travel operations on the remote storage @@ -490,7 +561,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def timeline_list( self, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, include_non_incremental_logical_size: bool = False, include_timeline_dir_layer_file_size_sum: bool = False, ) -> list[dict[str, Any]]: @@ -510,7 +581,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def timeline_and_offloaded_list( self, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, ) -> TimelinesInfoAndOffloaded: res = self.get( f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline_and_offloaded", @@ -523,11 +594,11 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def timeline_create( self, pg_version: PgVersion, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, new_timeline_id: TimelineId, - ancestor_timeline_id: Optional[TimelineId] = None, - ancestor_start_lsn: Optional[Lsn] = None, - existing_initdb_timeline_id: Optional[TimelineId] = None, + ancestor_timeline_id: TimelineId | None = None, + ancestor_start_lsn: Lsn | None = None, + existing_initdb_timeline_id: TimelineId | None = None, **kwargs, ) -> dict[Any, Any]: body: dict[str, Any] = { @@ -558,7 +629,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def timeline_detail( self, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, include_non_incremental_logical_size: bool = False, include_timeline_dir_layer_file_size_sum: bool = False, @@ -584,7 +655,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): return res_json def timeline_delete( - self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId, **kwargs + self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, **kwargs ): """ Note that deletion is not instant, it is scheduled and performed mostly in the background. @@ -600,9 +671,9 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def timeline_gc( self, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, - gc_horizon: Optional[int], + gc_horizon: int | None, ) -> dict[str, Any]: """ Unlike most handlers, this will wait for the layers to be actually @@ -624,16 +695,14 @@ class PageserverHttpClient(requests.Session, MetricsGetter): assert isinstance(res_json, dict) return res_json - def timeline_block_gc(self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId): + def timeline_block_gc(self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId): res = self.post( f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/block_gc", ) log.info(f"Got GC request response code: {res.status_code}") self.verbose_error(res) - def timeline_unblock_gc( - self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId - ): + def timeline_unblock_gc(self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId): res = self.post( f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/unblock_gc", ) @@ -642,7 +711,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def timeline_offload( self, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, ): self.is_testing_enabled_or_skip() @@ -658,14 +727,14 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def timeline_compact( self, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, force_repartition=False, force_image_layer_creation=False, force_l0_compaction=False, wait_until_uploaded=False, enhanced_gc_bottom_most_compaction=False, - body: Optional[dict[str, Any]] = None, + body: dict[str, Any] | None = None, ): self.is_testing_enabled_or_skip() query = {} @@ -692,7 +761,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): assert res_json is None def timeline_preserve_initdb_archive( - self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId + self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId ): log.info( f"Requesting initdb archive preservation for tenant {tenant_id} and timeline {timeline_id}" @@ -704,7 +773,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def timeline_archival_config( self, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, state: TimelineArchivalState, ): @@ -720,7 +789,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def timeline_get_lsn_by_timestamp( self, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, timestamp: datetime, with_lease: bool = False, @@ -739,7 +808,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): return res_json def timeline_lsn_lease( - self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId, lsn: Lsn + self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, lsn: Lsn ): data = { "lsn": str(lsn), @@ -755,7 +824,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): return res_json def timeline_get_timestamp_of_lsn( - self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId, lsn: Lsn + self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, lsn: Lsn ): log.info(f"Requesting time range of lsn {lsn}, tenant {tenant_id}, timeline {timeline_id}") res = self.get( @@ -765,9 +834,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): res_json = res.json() return res_json - def timeline_layer_map_info( - self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId - ): + def timeline_layer_map_info(self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId): log.info(f"Requesting layer map info of tenant {tenant_id}, timeline {timeline_id}") res = self.get( f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/layer", @@ -778,13 +845,13 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def timeline_checkpoint( self, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, force_repartition=False, force_image_layer_creation=False, force_l0_compaction=False, wait_until_uploaded=False, - compact: Optional[bool] = None, + compact: bool | None = None, **kwargs, ): self.is_testing_enabled_or_skip() @@ -801,7 +868,9 @@ class PageserverHttpClient(requests.Session, MetricsGetter): if compact is not None: query["compact"] = "true" if compact else "false" - log.info(f"Requesting checkpoint: tenant {tenant_id}, timeline {timeline_id}") + log.info( + f"Requesting checkpoint: tenant {tenant_id}, timeline {timeline_id}, wait_until_uploaded={wait_until_uploaded}" + ) res = self.put( f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/checkpoint", params=query, @@ -814,7 +883,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def timeline_spawn_download_remote_layers( self, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, max_concurrent_downloads: int, ) -> dict[str, Any]: @@ -833,7 +902,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def timeline_poll_download_remote_layers_status( self, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, spawn_response: dict[str, Any], poll_state=None, @@ -855,7 +924,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def timeline_download_remote_layers( self, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, max_concurrent_downloads: int, errors_ok=False, @@ -905,7 +974,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): timeline_id: TimelineId, file_kind: str, op_kind: str, - ) -> Optional[int]: + ) -> int | None: metrics = [ "pageserver_remote_timeline_client_calls_started_total", "pageserver_remote_timeline_client_calls_finished_total", @@ -929,7 +998,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def layer_map_info( self, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, ) -> LayerMapInfo: res = self.get( @@ -939,7 +1008,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): return LayerMapInfo.from_json(res.json()) def timeline_layer_scan_disposable_keys( - self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId, layer_name: str + self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, layer_name: str ) -> ScanDisposableKeysResponse: res = self.post( f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/layer/{layer_name}/scan_disposable_keys", @@ -949,7 +1018,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): return ScanDisposableKeysResponse.from_json(res.json()) def download_layer( - self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId, layer_name: str + self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, layer_name: str ): res = self.get( f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/layer/{layer_name}", @@ -958,9 +1027,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): assert res.status_code == 200 - def download_all_layers( - self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId - ): + def download_all_layers(self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId): info = self.layer_map_info(tenant_id, timeline_id) for layer in info.historic_layers: if not layer.remote: @@ -969,9 +1036,9 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def detach_ancestor( self, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, - batch_size: Optional[int] = None, + batch_size: int | None = None, **kwargs, ) -> set[TimelineId]: params = {} @@ -987,7 +1054,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): return set(map(TimelineId, json["reparented_timelines"])) def evict_layer( - self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId, layer_name: str + self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, layer_name: str ): res = self.delete( f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/layer/{layer_name}", @@ -996,7 +1063,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): assert res.status_code in (200, 304) - def evict_all_layers(self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId): + def evict_all_layers(self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId): info = self.layer_map_info(tenant_id, timeline_id) for layer in info.historic_layers: self.evict_layer(tenant_id, timeline_id, layer.layer_file_name) @@ -1009,7 +1076,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): self.verbose_error(res) return res.json() - def tenant_break(self, tenant_id: Union[TenantId, TenantShardId]): + def tenant_break(self, tenant_id: TenantId | TenantShardId): res = self.put(f"http://localhost:{self.port}/v1/tenant/{tenant_id}/break") self.verbose_error(res) @@ -1058,7 +1125,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter): def perf_info( self, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, ): self.is_testing_enabled_or_skip() diff --git a/test_runner/fixtures/pageserver/many_tenants.py b/test_runner/fixtures/pageserver/many_tenants.py index 37b4246d40..b6d19af84c 100644 --- a/test_runner/fixtures/pageserver/many_tenants.py +++ b/test_runner/fixtures/pageserver/many_tenants.py @@ -13,7 +13,8 @@ from fixtures.neon_fixtures import ( from fixtures.remote_storage import LocalFsStorage, RemoteStorageKind if TYPE_CHECKING: - from typing import Any, Callable + from collections.abc import Callable + from typing import Any def single_timeline( diff --git a/test_runner/fixtures/pageserver/utils.py b/test_runner/fixtures/pageserver/utils.py index ac7497ee6c..46700e3fe3 100644 --- a/test_runner/fixtures/pageserver/utils.py +++ b/test_runner/fixtures/pageserver/utils.py @@ -17,14 +17,14 @@ from fixtures.remote_storage import RemoteStorage, RemoteStorageKind, S3Storage from fixtures.utils import wait_until if TYPE_CHECKING: - from typing import Any, Optional, Union + from typing import Any def assert_tenant_state( pageserver_http: PageserverHttpClient, tenant: TenantId, expected_state: str, - message: Optional[str] = None, + message: str | None = None, ) -> None: tenant_status = pageserver_http.tenant_status(tenant) log.info(f"tenant_status: {tenant_status}") @@ -33,7 +33,7 @@ def assert_tenant_state( def remote_consistent_lsn( pageserver_http: PageserverHttpClient, - tenant: Union[TenantId, TenantShardId], + tenant: TenantId | TenantShardId, timeline: TimelineId, ) -> Lsn: detail = pageserver_http.timeline_detail(tenant, timeline) @@ -51,7 +51,7 @@ def remote_consistent_lsn( def wait_for_upload( pageserver_http: PageserverHttpClient, - tenant: Union[TenantId, TenantShardId], + tenant: TenantId | TenantShardId, timeline: TimelineId, lsn: Lsn, ): @@ -138,7 +138,7 @@ def wait_until_all_tenants_state( def wait_until_timeline_state( pageserver_http: PageserverHttpClient, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, expected_state: str, iterations: int, @@ -188,7 +188,7 @@ def wait_until_tenant_active( def last_record_lsn( pageserver_http_client: PageserverHttpClient, - tenant: Union[TenantId, TenantShardId], + tenant: TenantId | TenantShardId, timeline: TimelineId, ) -> Lsn: detail = pageserver_http_client.timeline_detail(tenant, timeline) @@ -200,7 +200,7 @@ def last_record_lsn( def wait_for_last_record_lsn( pageserver_http: PageserverHttpClient, - tenant: Union[TenantId, TenantShardId], + tenant: TenantId | TenantShardId, timeline: TimelineId, lsn: Lsn, ) -> Lsn: @@ -267,10 +267,10 @@ def wait_for_upload_queue_empty( def wait_timeline_detail_404( pageserver_http: PageserverHttpClient, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, iterations: int, - interval: Optional[float] = None, + interval: float | None = None, ): if interval is None: interval = 0.25 @@ -292,10 +292,10 @@ def wait_timeline_detail_404( def timeline_delete_wait_completed( pageserver_http: PageserverHttpClient, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, iterations: int = 20, - interval: Optional[float] = None, + interval: float | None = None, **delete_args, ) -> None: pageserver_http.timeline_delete(tenant_id=tenant_id, timeline_id=timeline_id, **delete_args) @@ -304,9 +304,9 @@ def timeline_delete_wait_completed( # remote_storage must not be None, but that's easier for callers to make mypy happy def assert_prefix_empty( - remote_storage: Optional[RemoteStorage], - prefix: Optional[str] = None, - allowed_postfix: Optional[str] = None, + remote_storage: RemoteStorage | None, + prefix: str | None = None, + allowed_postfix: str | None = None, delimiter: str = "/", ) -> None: assert remote_storage is not None @@ -348,8 +348,8 @@ def assert_prefix_empty( # remote_storage must not be None, but that's easier for callers to make mypy happy def assert_prefix_not_empty( - remote_storage: Optional[RemoteStorage], - prefix: Optional[str] = None, + remote_storage: RemoteStorage | None, + prefix: str | None = None, delimiter: str = "/", ): assert remote_storage is not None @@ -358,7 +358,7 @@ def assert_prefix_not_empty( def list_prefix( - remote: RemoteStorage, prefix: Optional[str] = None, delimiter: str = "/" + remote: RemoteStorage, prefix: str | None = None, delimiter: str = "/" ) -> ListObjectsV2OutputTypeDef: """ Note that this function takes into account prefix_in_bucket. diff --git a/test_runner/fixtures/parametrize.py b/test_runner/fixtures/parametrize.py index 1131bf090f..f57c0f801f 100644 --- a/test_runner/fixtures/parametrize.py +++ b/test_runner/fixtures/parametrize.py @@ -11,7 +11,7 @@ from _pytest.python import Metafunc from fixtures.pg_version import PgVersion if TYPE_CHECKING: - from typing import Any, Optional + from typing import Any """ @@ -20,31 +20,31 @@ Dynamically parametrize tests by different parameters @pytest.fixture(scope="function", autouse=True) -def pg_version() -> Optional[PgVersion]: +def pg_version() -> PgVersion | None: return None @pytest.fixture(scope="function", autouse=True) -def build_type() -> Optional[str]: +def build_type() -> str | None: return None @pytest.fixture(scope="session", autouse=True) -def platform() -> Optional[str]: +def platform() -> str | None: return None @pytest.fixture(scope="function", autouse=True) -def pageserver_virtual_file_io_engine() -> Optional[str]: +def pageserver_virtual_file_io_engine() -> str | None: return os.getenv("PAGESERVER_VIRTUAL_FILE_IO_ENGINE") @pytest.fixture(scope="function", autouse=True) -def pageserver_virtual_file_io_mode() -> Optional[str]: +def pageserver_virtual_file_io_mode() -> str | None: return os.getenv("PAGESERVER_VIRTUAL_FILE_IO_MODE") -def get_pageserver_default_tenant_config_compaction_algorithm() -> Optional[dict[str, Any]]: +def get_pageserver_default_tenant_config_compaction_algorithm() -> dict[str, Any] | None: toml_table = os.getenv("PAGESERVER_DEFAULT_TENANT_CONFIG_COMPACTION_ALGORITHM") if toml_table is None: return None @@ -54,7 +54,7 @@ def get_pageserver_default_tenant_config_compaction_algorithm() -> Optional[dict @pytest.fixture(scope="function", autouse=True) -def pageserver_default_tenant_config_compaction_algorithm() -> Optional[dict[str, Any]]: +def pageserver_default_tenant_config_compaction_algorithm() -> dict[str, Any] | None: return get_pageserver_default_tenant_config_compaction_algorithm() @@ -66,6 +66,7 @@ def pytest_generate_tests(metafunc: Metafunc): metafunc.parametrize("build_type", build_types) + pg_versions: list[PgVersion] if (v := os.getenv("DEFAULT_PG_VERSION")) is None: pg_versions = [version for version in PgVersion if version != PgVersion.NOT_SET] else: @@ -115,5 +116,6 @@ def pytest_runtest_makereport(*args, **kwargs): }.get(os.uname().machine, "UNKNOWN") arch = os.getenv("RUNNER_ARCH", uname_m) allure.dynamic.parameter("__arch", arch) + allure.dynamic.parameter("__lfc", os.getenv("USE_LFC") != "false") yield diff --git a/test_runner/fixtures/paths.py b/test_runner/fixtures/paths.py index 60221573eb..80777d65e9 100644 --- a/test_runner/fixtures/paths.py +++ b/test_runner/fixtures/paths.py @@ -18,7 +18,6 @@ from fixtures.utils import allure_attach_from_dir if TYPE_CHECKING: from collections.abc import Iterator - from typing import Optional BASE_DIR = Path(__file__).parents[2] @@ -26,14 +25,12 @@ COMPUTE_CONFIG_DIR = BASE_DIR / "compute" / "etc" DEFAULT_OUTPUT_DIR: str = "test_output" -def get_test_dir( - request: FixtureRequest, top_output_dir: Path, prefix: Optional[str] = None -) -> Path: +def get_test_dir(request: FixtureRequest, top_output_dir: Path, prefix: str | None = None) -> Path: """Compute the path to a working directory for an individual test.""" test_name = request.node.name test_dir = top_output_dir / f"{prefix or ''}{test_name.replace('/', '-')}" - # We rerun flaky tests multiple times, use a separate directory for each run. + # We rerun failed tests multiple times, use a separate directory for each run. if (suffix := getattr(request.node, "execution_count", None)) is not None: test_dir = test_dir.parent / f"{test_dir.name}-{suffix}" @@ -112,7 +109,7 @@ def compatibility_snapshot_dir() -> Iterator[Path]: @pytest.fixture(scope="session") -def compatibility_neon_binpath() -> Iterator[Optional[Path]]: +def compatibility_neon_binpath() -> Iterator[Path | None]: if os.getenv("REMOTE_ENV"): return comp_binpath = None @@ -133,7 +130,7 @@ def pg_distrib_dir(base_dir: Path) -> Iterator[Path]: @pytest.fixture(scope="session") -def compatibility_pg_distrib_dir() -> Iterator[Optional[Path]]: +def compatibility_pg_distrib_dir() -> Iterator[Path | None]: compat_distrib_dir = None if env_compat_postgres_bin := os.environ.get("COMPATIBILITY_POSTGRES_DISTRIB_DIR"): compat_distrib_dir = Path(env_compat_postgres_bin).resolve() @@ -197,7 +194,7 @@ class FileAndThreadLock: def __init__(self, path: Path): self.path = path self.thread_lock = threading.Lock() - self.fd: Optional[int] = None + self.fd: int | None = None def __enter__(self): self.fd = os.open(self.path, os.O_CREAT | os.O_WRONLY) @@ -208,9 +205,9 @@ class FileAndThreadLock: def __exit__( self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - exc_traceback: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_traceback: TracebackType | None, ): assert self.fd is not None assert self.thread_lock.locked() # ... by us @@ -263,9 +260,9 @@ class SnapshotDir: def __exit__( self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - exc_traceback: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_traceback: TracebackType | None, ): self._lock.__exit__(exc_type, exc_value, exc_traceback) @@ -277,7 +274,7 @@ def shared_snapshot_dir(top_output_dir: Path, ident: str) -> SnapshotDir: @pytest.fixture(scope="function") -def test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Optional[Path]: +def test_overlay_dir(request: FixtureRequest, top_output_dir: Path) -> Path | None: """ Idempotently create a test's overlayfs mount state directory. If the functionality isn't enabled via env var, returns None. diff --git a/test_runner/fixtures/pg_version.py b/test_runner/fixtures/pg_version.py index 798db1e8d9..46423e8c76 100644 --- a/test_runner/fixtures/pg_version.py +++ b/test_runner/fixtures/pg_version.py @@ -1,22 +1,16 @@ from __future__ import annotations -import enum -from typing import TYPE_CHECKING +from enum import StrEnum from typing_extensions import override -if TYPE_CHECKING: - from typing import Optional - - """ This fixture is used to determine which version of Postgres to use for tests. """ # Inherit PgVersion from str rather than int to make it easier to pass as a command-line argument -# TODO: use enum.StrEnum for Python >= 3.11 -class PgVersion(str, enum.Enum): +class PgVersion(StrEnum): V14 = "14" V15 = "15" V16 = "16" @@ -34,7 +28,6 @@ class PgVersion(str, enum.Enum): def __repr__(self) -> str: return f"'{self.value}'" - # Make this explicit for Python 3.11 compatibility, which changes the behavior of enums @override def __str__(self) -> str: return self.value @@ -47,16 +40,18 @@ class PgVersion(str, enum.Enum): @classmethod @override - def _missing_(cls, value: object) -> Optional[PgVersion]: - known_values = {v.value for _, v in cls.__members__.items()} + def _missing_(cls, value: object) -> PgVersion | None: + if not isinstance(value, str): + return None - # Allow passing version as a string with "v" prefix (e.g. "v14") - if isinstance(value, str) and value.lower().startswith("v") and value[1:] in known_values: - return cls(value[1:]) - # Allow passing version as an int (e.g. 15 or 150002, both will be converted to PgVersion.V15) - elif isinstance(value, int) and str(value)[:2] in known_values: - return cls(str(value)[:2]) + known_values = set(cls.__members__.values()) + + # Allow passing version as v-prefixed string (e.g. "v14") + if value.lower().startswith("v") and (v := value[1:]) in known_values: + return cls(v) + + # Allow passing version as an int (i.e. both "15" and "150002" matches PgVersion.V15) + if value.isdigit() and (v := value[:2]) in known_values: + return cls(v) - # Make mypy happy - # See https://github.com/python/mypy/issues/3974 return None diff --git a/test_runner/fixtures/port_distributor.py b/test_runner/fixtures/port_distributor.py index df0eb2a809..6a829a9399 100644 --- a/test_runner/fixtures/port_distributor.py +++ b/test_runner/fixtures/port_distributor.py @@ -3,13 +3,9 @@ from __future__ import annotations import re import socket from contextlib import closing -from typing import TYPE_CHECKING from fixtures.log_helper import log -if TYPE_CHECKING: - from typing import Union - def can_bind(host: str, port: int) -> bool: """ @@ -49,17 +45,19 @@ class PortDistributor: "port range configured for test is exhausted, consider enlarging the range" ) - def replace_with_new_port(self, value: Union[int, str]) -> Union[int, str]: + def replace_with_new_port(self, value: int | str) -> int | str: """ Returns a new port for a port number in a string (like "localhost:1234") or int. Replacements are memorised, so a substitution for the same port is always the same. """ - # TODO: replace with structural pattern matching for Python >= 3.10 - if isinstance(value, int): - return self._replace_port_int(value) - - return self._replace_port_str(value) + match value: + case int(): + return self._replace_port_int(value) + case str(): + return self._replace_port_str(value) + case _: + raise TypeError(f"Unsupported type {type(value)}, should be int | str") def _replace_port_int(self, value: int) -> int: known_port = self.port_map.get(value) diff --git a/test_runner/fixtures/remote_storage.py b/test_runner/fixtures/remote_storage.py index c630ea98b4..4e1e8a884f 100644 --- a/test_runner/fixtures/remote_storage.py +++ b/test_runner/fixtures/remote_storage.py @@ -6,8 +6,9 @@ import json import os import re from dataclasses import dataclass +from enum import StrEnum from pathlib import Path -from typing import TYPE_CHECKING, Union +from typing import TYPE_CHECKING import boto3 import toml @@ -20,7 +21,7 @@ from fixtures.log_helper import log from fixtures.pageserver.common_types import IndexPartDump if TYPE_CHECKING: - from typing import Any, Optional + from typing import Any TIMELINE_INDEX_PART_FILE_NAME = "index_part.json" @@ -28,7 +29,7 @@ TENANT_HEATMAP_FILE_NAME = "heatmap-v1.json" @enum.unique -class RemoteStorageUser(str, enum.Enum): +class RemoteStorageUser(StrEnum): """ Instead of using strings for the users, use a more strict enum. """ @@ -77,21 +78,19 @@ class MockS3Server: class LocalFsStorage: root: Path - def tenant_path(self, tenant_id: Union[TenantId, TenantShardId]) -> Path: + def tenant_path(self, tenant_id: TenantId | TenantShardId) -> Path: return self.root / "tenants" / str(tenant_id) - def timeline_path( - self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId - ) -> Path: + def timeline_path(self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId) -> Path: return self.tenant_path(tenant_id) / "timelines" / str(timeline_id) def timeline_latest_generation( - self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId - ) -> Optional[int]: + self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId + ) -> int | None: timeline_files = os.listdir(self.timeline_path(tenant_id, timeline_id)) index_parts = [f for f in timeline_files if f.startswith("index_part")] - def parse_gen(filename: str) -> Optional[int]: + def parse_gen(filename: str) -> int | None: log.info(f"parsing index_part '{filename}'") parts = filename.split("-") if len(parts) == 2: @@ -104,9 +103,7 @@ class LocalFsStorage: raise RuntimeError(f"No index_part found for {tenant_id}/{timeline_id}") return generations[-1] - def index_path( - self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId - ) -> Path: + def index_path(self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId) -> Path: latest_gen = self.timeline_latest_generation(tenant_id, timeline_id) if latest_gen is None: filename = TIMELINE_INDEX_PART_FILE_NAME @@ -120,7 +117,7 @@ class LocalFsStorage: tenant_id: TenantId, timeline_id: TimelineId, local_name: str, - generation: Optional[int] = None, + generation: int | None = None, ): if generation is None: generation = self.timeline_latest_generation(tenant_id, timeline_id) @@ -130,9 +127,7 @@ class LocalFsStorage: filename = f"{local_name}-{generation:08x}" return self.timeline_path(tenant_id, timeline_id) / filename - def index_content( - self, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId - ) -> Any: + def index_content(self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId) -> Any: with self.index_path(tenant_id, timeline_id).open("r") as f: return json.load(f) @@ -164,17 +159,17 @@ class LocalFsStorage: class S3Storage: bucket_name: str bucket_region: str - access_key: Optional[str] - secret_key: Optional[str] - aws_profile: Optional[str] + access_key: str | None + secret_key: str | None + aws_profile: str | None prefix_in_bucket: str client: S3Client cleanup: bool """Is this MOCK_S3 (false) or REAL_S3 (true)""" real: bool - endpoint: Optional[str] = None + endpoint: str | None = None """formatting deserialized with humantime crate, for example "1s".""" - custom_timeout: Optional[str] = None + custom_timeout: str | None = None def access_env_vars(self) -> dict[str, str]: if self.aws_profile is not None: @@ -272,12 +267,10 @@ class S3Storage: def tenants_path(self) -> str: return f"{self.prefix_in_bucket}/tenants" - def tenant_path(self, tenant_id: Union[TenantShardId, TenantId]) -> str: + def tenant_path(self, tenant_id: TenantShardId | TenantId) -> str: return f"{self.tenants_path()}/{tenant_id}" - def timeline_path( - self, tenant_id: Union[TenantShardId, TenantId], timeline_id: TimelineId - ) -> str: + def timeline_path(self, tenant_id: TenantShardId | TenantId, timeline_id: TimelineId) -> str: return f"{self.tenant_path(tenant_id)}/timelines/{timeline_id}" def get_latest_index_key(self, index_keys: list[str]) -> str: @@ -315,11 +308,11 @@ class S3Storage: assert self.real is False -RemoteStorage = Union[LocalFsStorage, S3Storage] +RemoteStorage = LocalFsStorage | S3Storage @enum.unique -class RemoteStorageKind(str, enum.Enum): +class RemoteStorageKind(StrEnum): LOCAL_FS = "local_fs" MOCK_S3 = "mock_s3" REAL_S3 = "real_s3" @@ -331,8 +324,8 @@ class RemoteStorageKind(str, enum.Enum): run_id: str, test_name: str, user: RemoteStorageUser, - bucket_name: Optional[str] = None, - bucket_region: Optional[str] = None, + bucket_name: str | None = None, + bucket_region: str | None = None, ) -> RemoteStorage: if self == RemoteStorageKind.LOCAL_FS: return LocalFsStorage(LocalFsStorage.component_path(repo_dir, user)) diff --git a/test_runner/fixtures/reruns.py b/test_runner/fixtures/reruns.py new file mode 100644 index 0000000000..f2a25ae8f6 --- /dev/null +++ b/test_runner/fixtures/reruns.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from collections.abc import MutableMapping +from typing import TYPE_CHECKING, cast + +import pytest + +if TYPE_CHECKING: + from collections.abc import MutableMapping + from typing import Any + + from _pytest.config import Config + + +def pytest_collection_modifyitems(config: Config, items: list[pytest.Item]): + # pytest-rerunfailures is not compatible with pytest-timeout (timeout is not set for reruns), + # we can workaround it by setting `timeout_func_only` to True[1]. + # Unfortunately, setting `timeout_func_only = True` globally in pytest.ini is broken[2], + # but we still can do it using pytest marker. + # + # - [1] https://github.com/pytest-dev/pytest-rerunfailures/issues/99 + # - [2] https://github.com/pytest-dev/pytest-timeout/issues/142 + + if not config.getoption("--reruns"): + return + + for item in items: + timeout_marker = item.get_closest_marker("timeout") + if timeout_marker is not None: + kwargs = cast("MutableMapping[str, Any]", timeout_marker.kwargs) + kwargs["func_only"] = True diff --git a/test_runner/fixtures/safekeeper/http.py b/test_runner/fixtures/safekeeper/http.py index 5d9a3bd149..094188c0b5 100644 --- a/test_runner/fixtures/safekeeper/http.py +++ b/test_runner/fixtures/safekeeper/http.py @@ -13,7 +13,7 @@ from fixtures.metrics import Metrics, MetricsGetter, parse_metrics from fixtures.utils import wait_until if TYPE_CHECKING: - from typing import Any, Optional, Union + from typing import Any # Walreceiver as returned by sk's timeline status endpoint. @@ -72,7 +72,7 @@ class TermBumpResponse: class SafekeeperHttpClient(requests.Session, MetricsGetter): HTTPError = requests.HTTPError - def __init__(self, port: int, auth_token: Optional[str] = None, is_testing_enabled=False): + def __init__(self, port: int, auth_token: str | None = None, is_testing_enabled=False): super().__init__() self.port = port self.auth_token = auth_token @@ -98,7 +98,7 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter): if not self.is_testing_enabled: pytest.skip("safekeeper was built without 'testing' feature") - def configure_failpoints(self, config_strings: Union[tuple[str, str], list[tuple[str, str]]]): + def configure_failpoints(self, config_strings: tuple[str, str] | list[tuple[str, str]]): self.is_testing_enabled_or_skip() if isinstance(config_strings, tuple): @@ -195,7 +195,7 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter): assert isinstance(res_json, dict) return res_json - def debug_dump(self, params: Optional[dict[str, str]] = None) -> dict[str, Any]: + def debug_dump(self, params: dict[str, str] | None = None) -> dict[str, Any]: params = params or {} res = self.get(f"http://localhost:{self.port}/v1/debug_dump", params=params) res.raise_for_status() @@ -204,7 +204,7 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter): return res_json def debug_dump_timeline( - self, timeline_id: TimelineId, params: Optional[dict[str, str]] = None + self, timeline_id: TimelineId, params: dict[str, str] | None = None ) -> Any: params = params or {} params["timeline_id"] = str(timeline_id) @@ -285,7 +285,7 @@ class SafekeeperHttpClient(requests.Session, MetricsGetter): self, tenant_id: TenantId, timeline_id: TimelineId, - term: Optional[int], + term: int | None, ) -> TermBumpResponse: body = {} if term is not None: diff --git a/test_runner/fixtures/storage_controller_proxy.py b/test_runner/fixtures/storage_controller_proxy.py index c174358ef5..be95a98ff9 100644 --- a/test_runner/fixtures/storage_controller_proxy.py +++ b/test_runner/fixtures/storage_controller_proxy.py @@ -13,14 +13,14 @@ from werkzeug.wrappers.response import Response from fixtures.log_helper import log if TYPE_CHECKING: - from typing import Any, Optional + from typing import Any class StorageControllerProxy: def __init__(self, server: HTTPServer): self.server: HTTPServer = server self.listen: str = f"http://{server.host}:{server.port}" - self.routing_to: Optional[str] = None + self.routing_to: str | None = None def route_to(self, storage_controller_api: str): self.routing_to = storage_controller_api diff --git a/test_runner/fixtures/utils.py b/test_runner/fixtures/utils.py index bb45385ea6..04e98fe494 100644 --- a/test_runner/fixtures/utils.py +++ b/test_runner/fixtures/utils.py @@ -8,10 +8,10 @@ import subprocess import tarfile import threading import time -from collections.abc import Iterable +from collections.abc import Callable, Iterable from hashlib import sha256 from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar from urllib.parse import urlencode import allure @@ -29,7 +29,7 @@ from fixtures.pg_version import PgVersion if TYPE_CHECKING: from collections.abc import Iterable - from typing import IO, Optional + from typing import IO from fixtures.common_types import TimelineId from fixtures.neon_fixtures import PgBin @@ -57,6 +57,10 @@ VERSIONS_COMBINATIONS = ( ) # fmt: on +# If the environment variable USE_LFC is set and its value is "false", then LFC is disabled for tests. +# If it is not set or set to a value not equal to "false", LFC is enabled by default. +USE_LFC = os.environ.get("USE_LFC") != "false" + def subprocess_capture( capture_dir: Path, @@ -66,10 +70,10 @@ def subprocess_capture( echo_stderr: bool = False, echo_stdout: bool = False, capture_stdout: bool = False, - timeout: Optional[float] = None, + timeout: float | None = None, with_command_header: bool = True, **popen_kwargs: Any, -) -> tuple[str, Optional[str], int]: +) -> tuple[str, str | None, int]: """Run a process and bifurcate its output to files and the `log` logger stderr and stdout are always captured in files. They are also optionally @@ -536,7 +540,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str """ started_at = time.time() - def hash_extracted(reader: Optional[IO[bytes]]) -> bytes: + def hash_extracted(reader: IO[bytes] | None) -> bytes: assert reader is not None digest = sha256(usedforsecurity=False) while True: @@ -563,7 +567,7 @@ def assert_pageserver_backups_equal(left: Path, right: Path, skip_files: set[str mismatching: set[str] = set() - for left_tuple, right_tuple in zip(left_list, right_list): + for left_tuple, right_tuple in zip(left_list, right_list, strict=False): left_path, left_hash = left_tuple right_path, right_hash = right_tuple assert ( @@ -595,7 +599,7 @@ class PropagatingThread(threading.Thread): self.exc = e @override - def join(self, timeout: Optional[float] = None) -> Any: + def join(self, timeout: float | None = None) -> Any: super().join(timeout) if self.exc: raise self.exc @@ -653,6 +657,23 @@ def allpairs_versions(): return {"argnames": "combination", "argvalues": tuple(argvalues), "ids": ids} +def size_to_bytes(hr_size: str) -> int: + """ + Gets human-readable size from postgresql.conf (e.g. 512kB, 10MB) + returns size in bytes + """ + units = {"B": 1, "kB": 1024, "MB": 1024**2, "GB": 1024**3, "TB": 1024**4, "PB": 1024**5} + match = re.search(r"^\'?(\d+)\s*([kMGTP]?B)?\'?$", hr_size) + assert match is not None, f'"{hr_size}" is not a well-formatted human-readable size' + number, unit = match.groups() + + if unit: + amp = units[unit] + else: + amp = 8192 + return int(number) * amp + + def skip_on_postgres(version: PgVersion, reason: str): return pytest.mark.skipif( PgVersion(os.getenv("DEFAULT_PG_VERSION", PgVersion.DEFAULT)) is version, @@ -674,6 +695,13 @@ def run_only_on_default_postgres(reason: str): ) +def run_only_on_postgres(versions: Iterable[PgVersion], reason: str): + return pytest.mark.skipif( + PgVersion(os.getenv("DEFAULT_PG_VERSION", PgVersion.DEFAULT)) not in versions, + reason=reason, + ) + + def skip_in_debug_build(reason: str): return pytest.mark.skipif( os.getenv("BUILD_TYPE", "debug") == "debug", diff --git a/test_runner/fixtures/workload.py b/test_runner/fixtures/workload.py index e869c43185..1b8c9fef44 100644 --- a/test_runner/fixtures/workload.py +++ b/test_runner/fixtures/workload.py @@ -15,7 +15,7 @@ from fixtures.neon_fixtures import ( from fixtures.pageserver.utils import wait_for_last_record_lsn if TYPE_CHECKING: - from typing import Any, Optional + from typing import Any # neon_local doesn't handle creating/modifying endpoints concurrently, so we use a mutex # to ensure we don't do that: this enables running lots of Workloads in parallel safely. @@ -36,8 +36,8 @@ class Workload: env: NeonEnv, tenant_id: TenantId, timeline_id: TimelineId, - branch_name: Optional[str] = None, - endpoint_opts: Optional[dict[str, Any]] = None, + branch_name: str | None = None, + endpoint_opts: dict[str, Any] | None = None, ): self.env = env self.tenant_id = tenant_id @@ -50,10 +50,10 @@ class Workload: self.expect_rows = 0 self.churn_cursor = 0 - self._endpoint: Optional[Endpoint] = None + self._endpoint: Endpoint | None = None self._endpoint_opts = endpoint_opts or {} - def reconfigure(self): + def reconfigure(self) -> None: """ Request the endpoint to reconfigure based on location reported by storage controller """ @@ -61,7 +61,7 @@ class Workload: with ENDPOINT_LOCK: self._endpoint.reconfigure() - def endpoint(self, pageserver_id: Optional[int] = None) -> Endpoint: + def endpoint(self, pageserver_id: int | None = None) -> Endpoint: # We may be running alongside other Workloads for different tenants. Full TTID is # obnoxiously long for use here, but a cut-down version is still unique enough for tests. endpoint_id = f"ep-workload-{str(self.tenant_id)[0:4]}-{str(self.timeline_id)[0:4]}" @@ -94,16 +94,17 @@ class Workload: def __del__(self): self.stop() - def init(self, pageserver_id: Optional[int] = None): + def init(self, pageserver_id: int | None = None, allow_recreate=False): endpoint = self.endpoint(pageserver_id) - + if allow_recreate: + endpoint.safe_psql(f"DROP TABLE IF EXISTS {self.table};") endpoint.safe_psql(f"CREATE TABLE {self.table} (id INTEGER PRIMARY KEY, val text);") endpoint.safe_psql("CREATE EXTENSION IF NOT EXISTS neon_test_utils;") last_flush_lsn_upload( self.env, endpoint, self.tenant_id, self.timeline_id, pageserver_id=pageserver_id ) - def write_rows(self, n: int, pageserver_id: Optional[int] = None, upload: bool = True): + def write_rows(self, n: int, pageserver_id: int | None = None, upload: bool = True): endpoint = self.endpoint(pageserver_id) start = self.expect_rows end = start + n - 1 @@ -125,7 +126,7 @@ class Workload: return False def churn_rows( - self, n: int, pageserver_id: Optional[int] = None, upload: bool = True, ingest: bool = True + self, n: int, pageserver_id: int | None = None, upload: bool = True, ingest: bool = True ): assert self.expect_rows >= n @@ -190,9 +191,9 @@ class Workload: else: log.info(f"Churn: not waiting for upload, disk LSN {last_flush_lsn}") - def validate(self, pageserver_id: Optional[int] = None): + def validate(self, pageserver_id: int | None = None): endpoint = self.endpoint(pageserver_id) - endpoint.clear_shared_buffers() + endpoint.clear_buffers() result = endpoint.safe_psql(f"SELECT COUNT(*) FROM {self.table}") log.info(f"validate({self.expect_rows}): {result}") diff --git a/test_runner/performance/pageserver/test_page_service_batching.py b/test_runner/performance/pageserver/test_page_service_batching.py index 669ce32d57..c47a849fec 100644 --- a/test_runner/performance/pageserver/test_page_service_batching.py +++ b/test_runner/performance/pageserver/test_page_service_batching.py @@ -3,7 +3,7 @@ import json import time from dataclasses import dataclass from pathlib import Path -from typing import Any, Optional, Union +from typing import Any import pytest from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker @@ -16,21 +16,32 @@ TARGET_RUNTIME = 30 @dataclass class PageServicePipeliningConfig: + pass + + +@dataclass +class PageServicePipeliningConfigSerial(PageServicePipeliningConfig): + mode: str = "serial" + + +@dataclass +class PageServicePipeliningConfigPipelined(PageServicePipeliningConfig): max_batch_size: int - protocol_pipelining_mode: str + execution: str + mode: str = "pipelined" -PROTOCOL_PIPELINING_MODES = ["concurrent-futures", "tasks"] +EXECUTION = ["concurrent-futures", "tasks"] -NON_BATCHABLE: list[Optional[PageServicePipeliningConfig]] = [None] +NON_BATCHABLE: list[PageServicePipeliningConfig] = [PageServicePipeliningConfigSerial()] for max_batch_size in [1, 32]: - for protocol_pipelining_mode in PROTOCOL_PIPELINING_MODES: - NON_BATCHABLE.append(PageServicePipeliningConfig(max_batch_size, protocol_pipelining_mode)) + for execution in EXECUTION: + NON_BATCHABLE.append(PageServicePipeliningConfigPipelined(max_batch_size, execution)) -BATCHABLE: list[Optional[PageServicePipeliningConfig]] = [None] +BATCHABLE: list[PageServicePipeliningConfig] = [PageServicePipeliningConfigSerial()] for max_batch_size in [1, 2, 4, 8, 16, 32]: - for protocol_pipelining_mode in PROTOCOL_PIPELINING_MODES: - BATCHABLE.append(PageServicePipeliningConfig(max_batch_size, protocol_pipelining_mode)) + for execution in EXECUTION: + BATCHABLE.append(PageServicePipeliningConfigPipelined(max_batch_size, execution)) @pytest.mark.parametrize( @@ -45,7 +56,7 @@ for max_batch_size in [1, 2, 4, 8, 16, 32]: TARGET_RUNTIME, 1, 128, - f"not batchable {dataclasses.asdict(config) if config else None}", + f"not batchable {dataclasses.asdict(config)}", ) for config in NON_BATCHABLE ], @@ -57,7 +68,7 @@ for max_batch_size in [1, 2, 4, 8, 16, 32]: TARGET_RUNTIME, 100, 128, - f"batchable {dataclasses.asdict(config) if config else None}", + f"batchable {dataclasses.asdict(config)}", ) for config in BATCHABLE ], @@ -67,7 +78,7 @@ def test_throughput( neon_env_builder: NeonEnvBuilder, zenbenchmark: NeonBenchmarker, tablesize_mib: int, - pipelining_config: None | PageServicePipeliningConfig, + pipelining_config: PageServicePipeliningConfig, target_runtime: int, effective_io_concurrency: int, readhead_buffer_size: int, @@ -94,25 +105,23 @@ def test_throughput( # # record perf-related parameters as metrics to simplify processing of results # - params: dict[str, tuple[Union[float, int], dict[str, Any]]] = {} + params: dict[str, tuple[float | int, dict[str, Any]]] = {} params.update( { "tablesize_mib": (tablesize_mib, {"unit": "MiB"}), - "pipelining_enabled": (1 if pipelining_config else 0, {}), # target_runtime is just a polite ask to the workload to run for this long "effective_io_concurrency": (effective_io_concurrency, {}), "readhead_buffer_size": (readhead_buffer_size, {}), # name is not a metric, we just use it to identify the test easily in the `test_...[...]`` notation } ) - if pipelining_config: - params.update( - { - f"pipelining_config.{k}": (v, {}) - for k, v in dataclasses.asdict(pipelining_config).items() - } - ) + params.update( + { + f"pipelining_config.{k}": (v, {}) + for k, v in dataclasses.asdict(pipelining_config).items() + } + ) log.info("params: %s", params) @@ -224,8 +233,6 @@ def test_throughput( env.pageserver.patch_config_toml_nonrecursive( {"page_service_pipelining": dataclasses.asdict(pipelining_config)} - if pipelining_config is not None - else {} ) env.pageserver.restart() metrics = workload() @@ -255,23 +262,21 @@ def test_throughput( ) -PRECISION_CONFIGS: list[Optional[PageServicePipeliningConfig]] = [None] +PRECISION_CONFIGS: list[PageServicePipeliningConfig] = [PageServicePipeliningConfigSerial()] for max_batch_size in [1, 32]: - for protocol_pipelining_mode in PROTOCOL_PIPELINING_MODES: - PRECISION_CONFIGS.append( - PageServicePipeliningConfig(max_batch_size, protocol_pipelining_mode) - ) + for execution in EXECUTION: + PRECISION_CONFIGS.append(PageServicePipeliningConfigPipelined(max_batch_size, execution)) @pytest.mark.parametrize( "pipelining_config,name", - [(config, f"{dataclasses.asdict(config) if config else None}") for config in PRECISION_CONFIGS], + [(config, f"{dataclasses.asdict(config)}") for config in PRECISION_CONFIGS], ) def test_latency( neon_env_builder: NeonEnvBuilder, zenbenchmark: NeonBenchmarker, pg_bin: PgBin, - pipelining_config: Optional[PageServicePipeliningConfig], + pipelining_config: PageServicePipeliningConfig, name: str, ): """ diff --git a/test_runner/performance/pageserver/util.py b/test_runner/performance/pageserver/util.py index 227319c425..bcc3db69f0 100644 --- a/test_runner/performance/pageserver/util.py +++ b/test_runner/performance/pageserver/util.py @@ -16,7 +16,8 @@ from fixtures.neon_fixtures import ( from fixtures.pageserver.utils import wait_until_all_tenants_state if TYPE_CHECKING: - from typing import Any, Callable, Optional + from collections.abc import Callable + from typing import Any def ensure_pageserver_ready_for_benchmarking(env: NeonEnv, n_tenants: int): @@ -46,7 +47,7 @@ def setup_pageserver_with_tenants( name: str, n_tenants: int, setup: Callable[[NeonEnv], tuple[TenantId, TimelineId, dict[str, Any]]], - timeout_in_seconds: Optional[int] = None, + timeout_in_seconds: int | None = None, ) -> NeonEnv: """ Utility function to set up a pageserver with a given number of identical tenants. diff --git a/test_runner/performance/test_bulk_insert.py b/test_runner/performance/test_bulk_insert.py index 36090dcad7..680eb62b39 100644 --- a/test_runner/performance/test_bulk_insert.py +++ b/test_runner/performance/test_bulk_insert.py @@ -56,7 +56,7 @@ def test_bulk_insert(neon_with_baseline: PgCompare): def measure_recovery_time(env: NeonCompare): client = env.env.pageserver.http_client() - pg_version = PgVersion(client.timeline_detail(env.tenant, env.timeline)["pg_version"]) + pg_version = PgVersion(str(client.timeline_detail(env.tenant, env.timeline)["pg_version"])) # Delete the Tenant in the pageserver: this will drop local and remote layers, such that # when we "create" the Tenant again, we will replay the WAL from the beginning. diff --git a/test_runner/performance/test_compaction.py b/test_runner/performance/test_compaction.py index 8868dddf39..0cd1080fa7 100644 --- a/test_runner/performance/test_compaction.py +++ b/test_runner/performance/test_compaction.py @@ -103,6 +103,9 @@ def test_compaction_l0_memory(neon_compare: NeonCompare): cur.execute(f"update tbl{i} set j = {j};") wait_for_last_flush_lsn(env, endpoint, tenant_id, timeline_id) + pageserver_http.timeline_checkpoint( + tenant_id, timeline_id, compact=False + ) # ^1: flush all in-memory layers endpoint.stop() # Check we have generated the L0 stack we expected @@ -118,7 +121,9 @@ def test_compaction_l0_memory(neon_compare: NeonCompare): return v * 1024 before = rss_hwm() - pageserver_http.timeline_compact(tenant_id, timeline_id) + pageserver_http.timeline_compact( + tenant_id, timeline_id + ) # ^1: we must ensure during this process no new L0 layers are flushed after = rss_hwm() log.info(f"RSS across compaction: {before} -> {after} (grew {after - before})") @@ -137,7 +142,7 @@ def test_compaction_l0_memory(neon_compare: NeonCompare): # To be fixed in https://github.com/neondatabase/neon/issues/8184, after which # this memory estimate can be revised far downwards to something that doesn't scale # linearly with the layer sizes. - MEMORY_ESTIMATE = (initial_l0s_size - final_l0s_size) * 1.5 + MEMORY_ESTIMATE = (initial_l0s_size - final_l0s_size) * 1.25 # If we find that compaction is using more memory, this may indicate a regression assert compaction_mapped_rss < MEMORY_ESTIMATE diff --git a/test_runner/performance/test_copy.py b/test_runner/performance/test_copy.py index d571fab6b5..0e56fdc96f 100644 --- a/test_runner/performance/test_copy.py +++ b/test_runner/performance/test_copy.py @@ -2,7 +2,7 @@ from __future__ import annotations from contextlib import closing from io import BufferedReader, RawIOBase -from typing import Optional, final +from typing import final from fixtures.compare_fixtures import PgCompare from typing_extensions import override @@ -13,7 +13,7 @@ class CopyTestData(RawIOBase): def __init__(self, rows: int): self.rows = rows self.rownum = 0 - self.linebuf: Optional[bytes] = None + self.linebuf: bytes | None = None self.ptr = 0 @override diff --git a/test_runner/performance/test_ingest_logical_message.py b/test_runner/performance/test_ingest_logical_message.py new file mode 100644 index 0000000000..d3118eb15a --- /dev/null +++ b/test_runner/performance/test_ingest_logical_message.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import pytest +from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker +from fixtures.common_types import Lsn +from fixtures.log_helper import log +from fixtures.neon_fixtures import ( + NeonEnvBuilder, + wait_for_commit_lsn, + wait_for_last_flush_lsn, +) +from fixtures.pageserver.utils import wait_for_last_record_lsn + + +@pytest.mark.timeout(600) +@pytest.mark.parametrize("size", [1024, 8192, 131072]) +@pytest.mark.parametrize("fsync", [True, False], ids=["fsync", "nofsync"]) +def test_ingest_logical_message( + request: pytest.FixtureRequest, + neon_env_builder: NeonEnvBuilder, + zenbenchmark: NeonBenchmarker, + fsync: bool, + size: int, +): + """ + Benchmarks ingestion of 10 GB of logical message WAL. These are essentially noops, and don't + incur any pageserver writes. + """ + + VOLUME = 10 * 1024**3 + count = VOLUME // size + + neon_env_builder.safekeepers_enable_fsync = fsync + + env = neon_env_builder.init_start() + endpoint = env.endpoints.create_start( + "main", + config_lines=[ + f"fsync = {fsync}", + # Disable backpressure. We don't want to block on pageserver. + "max_replication_apply_lag = 0", + "max_replication_flush_lag = 0", + "max_replication_write_lag = 0", + ], + ) + client = env.pageserver.http_client() + + # Wait for the timeline to be propagated to the pageserver. + wait_for_last_flush_lsn(env, endpoint, env.initial_tenant, env.initial_timeline) + + # Ingest data and measure durations. + start_lsn = Lsn(endpoint.safe_psql("select pg_current_wal_lsn()")[0][0]) + + with endpoint.cursor() as cur: + cur.execute("set statement_timeout = 0") + + # Postgres will return once the logical messages have been written to its local WAL, without + # waiting for Safekeeper commit. We measure ingestion time both for Postgres, Safekeeper, + # and Pageserver to detect bottlenecks. + log.info("Ingesting data") + with zenbenchmark.record_duration("pageserver_ingest"): + with zenbenchmark.record_duration("safekeeper_ingest"): + with zenbenchmark.record_duration("postgres_ingest"): + cur.execute(f""" + select pg_logical_emit_message(false, '', repeat('x', {size})) + from generate_series(1, {count}) + """) + + end_lsn = Lsn(endpoint.safe_psql("select pg_current_wal_lsn()")[0][0]) + + # Wait for Safekeeper. + log.info("Waiting for Safekeeper to catch up") + wait_for_commit_lsn(env, env.initial_tenant, env.initial_timeline, end_lsn) + + # Wait for Pageserver. + log.info("Waiting for Pageserver to catch up") + wait_for_last_record_lsn(client, env.initial_tenant, env.initial_timeline, end_lsn) + + # Now that all data is ingested, delete and recreate the tenant in the pageserver. This will + # reingest all the WAL from the safekeeper without any other constraints. This gives us a + # baseline of how fast the pageserver can ingest this WAL in isolation. + status = env.storage_controller.inspect(tenant_shard_id=env.initial_tenant) + assert status is not None + + client.tenant_delete(env.initial_tenant) + env.pageserver.tenant_create(tenant_id=env.initial_tenant, generation=status[0]) + + with zenbenchmark.record_duration("pageserver_recover_ingest"): + log.info("Recovering WAL into pageserver") + client.timeline_create(env.pg_version, env.initial_tenant, env.initial_timeline) + wait_for_last_flush_lsn(env, endpoint, env.initial_tenant, env.initial_timeline) + + # Emit metrics. + wal_written_mb = round((end_lsn - start_lsn) / (1024 * 1024)) + zenbenchmark.record("wal_written", wal_written_mb, "MB", MetricReport.TEST_PARAM) + zenbenchmark.record("message_count", count, "messages", MetricReport.TEST_PARAM) + + props = {p["name"]: p["value"] for _, p in request.node.user_properties} + for name in ("postgres", "safekeeper", "pageserver", "pageserver_recover"): + throughput = int(wal_written_mb / props[f"{name}_ingest"]) + zenbenchmark.record(f"{name}_throughput", throughput, "MB/s", MetricReport.HIGHER_IS_BETTER) diff --git a/test_runner/performance/test_perf_ingest_using_pgcopydb.py b/test_runner/performance/test_perf_ingest_using_pgcopydb.py index 2f4574ba88..37f2e9db50 100644 --- a/test_runner/performance/test_perf_ingest_using_pgcopydb.py +++ b/test_runner/performance/test_perf_ingest_using_pgcopydb.py @@ -60,13 +60,13 @@ def build_pgcopydb_command(pgcopydb_filter_file: Path, test_output_dir: Path): "--no-acl", "--skip-db-properties", "--table-jobs", - "4", + "8", "--index-jobs", - "4", + "8", "--restore-jobs", - "4", + "8", "--split-tables-larger-than", - "10GB", + "5GB", "--skip-extensions", "--use-copy-binary", "--filters", @@ -136,7 +136,7 @@ def run_command_and_log_output(command, log_file_path: Path): "LD_LIBRARY_PATH": f"{os.getenv('PGCOPYDB_LIB_PATH')}:{os.getenv('PG_16_LIB_PATH')}", "PGCOPYDB_SOURCE_PGURI": cast(str, os.getenv("BENCHMARK_INGEST_SOURCE_CONNSTR")), "PGCOPYDB_TARGET_PGURI": cast(str, os.getenv("BENCHMARK_INGEST_TARGET_CONNSTR")), - "PGOPTIONS": "-c maintenance_work_mem=8388608 -c max_parallel_maintenance_workers=7", + "PGOPTIONS": "-c maintenance_work_mem=8388608 -c max_parallel_maintenance_workers=16", } # Combine the current environment with custom variables env = os.environ.copy() diff --git a/test_runner/performance/test_physical_replication.py b/test_runner/performance/test_physical_replication.py index d56f6dce09..38b04b9114 100644 --- a/test_runner/performance/test_physical_replication.py +++ b/test_runner/performance/test_physical_replication.py @@ -18,7 +18,7 @@ from fixtures.neon_api import connection_parameters_to_env from fixtures.pg_version import PgVersion if TYPE_CHECKING: - from typing import Any, Optional + from typing import Any from fixtures.benchmark_fixture import NeonBenchmarker from fixtures.neon_api import NeonAPI @@ -247,7 +247,7 @@ def test_replication_start_stop( ], env=master_env, ) - replica_pgbench: list[Optional[subprocess.Popen[Any]]] = [None for _ in range(num_replicas)] + replica_pgbench: list[subprocess.Popen[Any] | None] = [None] * num_replicas # Use the bits of iconfig to tell us which configuration we are on. For example # a iconfig of 2 is 10 in binary, indicating replica 0 is suspended and replica 1 is diff --git a/test_runner/performance/test_sharded_ingest.py b/test_runner/performance/test_sharded_ingest.py index 77e8f2cf17..4c21e799c8 100644 --- a/test_runner/performance/test_sharded_ingest.py +++ b/test_runner/performance/test_sharded_ingest.py @@ -15,21 +15,61 @@ from fixtures.neon_fixtures import ( @pytest.mark.timeout(600) @pytest.mark.parametrize("shard_count", [1, 8, 32]) +@pytest.mark.parametrize( + "wal_receiver_protocol", + [ + "vanilla", + "interpreted-bincode-compressed", + "interpreted-protobuf-compressed", + ], +) def test_sharded_ingest( neon_env_builder: NeonEnvBuilder, zenbenchmark: NeonBenchmarker, shard_count: int, + wal_receiver_protocol: str, ): """ Benchmarks sharded ingestion throughput, by ingesting a large amount of WAL into a Safekeeper and fanning out to a large number of shards on dedicated Pageservers. Comparing the base case (shard_count=1) to the sharded case indicates the overhead of sharding. """ - ROW_COUNT = 100_000_000 # about 7 GB of WAL neon_env_builder.num_pageservers = shard_count - env = neon_env_builder.init_start() + env = neon_env_builder.init_configs() + + for ps in env.pageservers: + if wal_receiver_protocol == "vanilla": + ps.patch_config_toml_nonrecursive( + { + "wal_receiver_protocol": { + "type": "vanilla", + } + } + ) + elif wal_receiver_protocol == "interpreted-bincode-compressed": + ps.patch_config_toml_nonrecursive( + { + "wal_receiver_protocol": { + "type": "interpreted", + "args": {"format": "bincode", "compression": {"zstd": {"level": 1}}}, + } + } + ) + elif wal_receiver_protocol == "interpreted-protobuf-compressed": + ps.patch_config_toml_nonrecursive( + { + "wal_receiver_protocol": { + "type": "interpreted", + "args": {"format": "protobuf", "compression": {"zstd": {"level": 1}}}, + } + } + ) + else: + raise AssertionError("Test must use explicit wal receiver protocol config") + + env.start() # Create a sharded tenant and timeline, and migrate it to the respective pageservers. Ensure # the storage controller doesn't mess with shard placements. @@ -50,7 +90,6 @@ def test_sharded_ingest( # Start the endpoint. endpoint = env.endpoints.create_start("main", tenant_id=tenant_id) start_lsn = Lsn(endpoint.safe_psql("select pg_current_wal_lsn()")[0][0]) - # Ingest data and measure WAL volume and duration. with closing(endpoint.connect()) as conn: with conn.cursor() as cur: @@ -68,4 +107,48 @@ def test_sharded_ingest( wal_written_mb = round((end_lsn - start_lsn) / (1024 * 1024)) zenbenchmark.record("wal_written", wal_written_mb, "MB", MetricReport.TEST_PARAM) + total_ingested = 0 + total_records_received = 0 + ingested_by_ps = [] + for pageserver in env.pageservers: + ingested = pageserver.http_client().get_metric_value( + "pageserver_wal_ingest_bytes_received_total" + ) + records_received = pageserver.http_client().get_metric_value( + "pageserver_wal_ingest_records_received_total" + ) + + if ingested is None: + ingested = 0 + + if records_received is None: + records_received = 0 + + ingested_by_ps.append( + ( + pageserver.id, + { + "ingested": ingested, + "records_received": records_received, + }, + ) + ) + + total_ingested += int(ingested) + total_records_received += int(records_received) + + total_ingested_mb = total_ingested / (1024 * 1024) + zenbenchmark.record("wal_ingested", total_ingested_mb, "MB", MetricReport.LOWER_IS_BETTER) + zenbenchmark.record( + "records_received", total_records_received, "records", MetricReport.LOWER_IS_BETTER + ) + + ingested_by_ps.sort(key=lambda x: x[0]) + for _, stats in ingested_by_ps: + for k in stats: + if k != "records_received": + stats[k] /= 1024**2 + + log.info(f"WAL ingested by each pageserver {ingested_by_ps}") + assert tenant_get_shards(env, tenant_id) == shards, "shards moved" diff --git a/test_runner/performance/test_storage_controller_scale.py b/test_runner/performance/test_storage_controller_scale.py index dc051483f8..142bd3d669 100644 --- a/test_runner/performance/test_storage_controller_scale.py +++ b/test_runner/performance/test_storage_controller_scale.py @@ -4,7 +4,7 @@ import concurrent.futures import random import time from collections import defaultdict -from enum import Enum +from enum import StrEnum import pytest from fixtures.common_types import TenantId, TenantShardId, TimelineArchivalState, TimelineId @@ -139,7 +139,7 @@ def test_storage_controller_many_tenants( tenant_timelines_count = 100 # These lists are maintained for use with rng.choice - tenants_with_timelines = list(rng.sample(tenants.keys(), tenant_timelines_count)) + tenants_with_timelines = list(rng.sample(list(tenants.keys()), tenant_timelines_count)) tenants_without_timelines = list( tenant_id for tenant_id in tenants if tenant_id not in tenants_with_timelines ) @@ -171,7 +171,7 @@ def test_storage_controller_many_tenants( # start timing on test nodes if we aren't a bit careful. create_concurrency = 16 - class Operation(str, Enum): + class Operation(StrEnum): TIMELINE_OPS = "timeline_ops" SHARD_MIGRATE = "shard_migrate" TENANT_PASSTHROUGH = "tenant_passthrough" diff --git a/test_runner/performance/test_wal_backpressure.py b/test_runner/performance/test_wal_backpressure.py index 576a4f0467..c6d795ce4d 100644 --- a/test_runner/performance/test_wal_backpressure.py +++ b/test_runner/performance/test_wal_backpressure.py @@ -17,7 +17,8 @@ from fixtures.neon_fixtures import NeonEnvBuilder, PgBin, flush_ep_to_pageserver from performance.test_perf_pgbench import get_durations_matrix, get_scales_matrix if TYPE_CHECKING: - from typing import Any, Callable + from collections.abc import Callable + from typing import Any @pytest.fixture(params=["vanilla", "neon_off", "neon_on"]) diff --git a/test_runner/regress/test_attach_tenant_config.py b/test_runner/regress/test_attach_tenant_config.py index 7d19ba3b5d..670c2698f5 100644 --- a/test_runner/regress/test_attach_tenant_config.py +++ b/test_runner/regress/test_attach_tenant_config.py @@ -2,7 +2,6 @@ from __future__ import annotations from collections.abc import Generator from dataclasses import dataclass -from typing import Optional import pytest from fixtures.common_types import TenantId @@ -105,7 +104,7 @@ def test_null_config(negative_env: NegativeTests): @pytest.mark.parametrize("content_type", [None, "application/json"]) -def test_empty_config(positive_env: NeonEnv, content_type: Optional[str]): +def test_empty_config(positive_env: NeonEnv, content_type: str | None): """ When the 'config' body attribute is omitted, the request should be accepted and the tenant should use the default configuration @@ -175,6 +174,10 @@ def test_fully_custom_config(positive_env: NeonEnv): "lsn_lease_length": "1m", "lsn_lease_length_for_ts": "5s", "timeline_offloading": True, + "wal_receiver_protocol_override": { + "type": "interpreted", + "args": {"format": "bincode", "compression": {"zstd": {"level": 1}}}, + }, } vps_http = env.storage_controller.pageserver_api() diff --git a/test_runner/regress/test_combocid.py b/test_runner/regress/test_combocid.py index 57d5b2d8b3..2db16d9f64 100644 --- a/test_runner/regress/test_combocid.py +++ b/test_runner/regress/test_combocid.py @@ -5,12 +5,7 @@ from fixtures.neon_fixtures import NeonEnvBuilder, flush_ep_to_pageserver def do_combocid_op(neon_env_builder: NeonEnvBuilder, op): env = neon_env_builder.init_start() - endpoint = env.endpoints.create_start( - "main", - config_lines=[ - "shared_buffers='1MB'", - ], - ) + endpoint = env.endpoints.create_start("main") conn = endpoint.connect() cur = conn.cursor() @@ -36,7 +31,7 @@ def do_combocid_op(neon_env_builder: NeonEnvBuilder, op): # Clear the cache, so that we exercise reconstructing the pages # from WAL - endpoint.clear_shared_buffers() + endpoint.clear_buffers() # Check that the cursor opened earlier still works. If the # combocids are not restored correctly, it won't. @@ -65,12 +60,7 @@ def test_combocid_lock(neon_env_builder: NeonEnvBuilder): def test_combocid_multi_insert(neon_env_builder: NeonEnvBuilder): env = neon_env_builder.init_start() - endpoint = env.endpoints.create_start( - "main", - config_lines=[ - "shared_buffers='1MB'", - ], - ) + endpoint = env.endpoints.create_start("main") conn = endpoint.connect() cur = conn.cursor() @@ -98,7 +88,7 @@ def test_combocid_multi_insert(neon_env_builder: NeonEnvBuilder): cur.execute("delete from t") # Clear the cache, so that we exercise reconstructing the pages # from WAL - endpoint.clear_shared_buffers() + endpoint.clear_buffers() # Check that the cursor opened earlier still works. If the # combocids are not restored correctly, it won't. diff --git a/test_runner/regress/test_compaction.py b/test_runner/regress/test_compaction.py index 48950a5a50..302a8fd0d1 100644 --- a/test_runner/regress/test_compaction.py +++ b/test_runner/regress/test_compaction.py @@ -1,24 +1,20 @@ from __future__ import annotations -import enum import json import time -from typing import TYPE_CHECKING +from enum import StrEnum import pytest from fixtures.log_helper import log from fixtures.neon_fixtures import ( NeonEnvBuilder, + PageserverWalReceiverProtocol, generate_uploads_and_deletions, ) from fixtures.pageserver.http import PageserverApiException from fixtures.utils import skip_in_debug_build, wait_until from fixtures.workload import Workload -if TYPE_CHECKING: - from typing import Optional - - AGGRESIVE_COMPACTION_TENANT_CONF = { # Disable gc and compaction. The test runs compaction manually. "gc_period": "0s", @@ -32,7 +28,13 @@ AGGRESIVE_COMPACTION_TENANT_CONF = { @skip_in_debug_build("only run with release build") -def test_pageserver_compaction_smoke(neon_env_builder: NeonEnvBuilder): +@pytest.mark.parametrize( + "wal_receiver_protocol", + [PageserverWalReceiverProtocol.VANILLA, PageserverWalReceiverProtocol.INTERPRETED], +) +def test_pageserver_compaction_smoke( + neon_env_builder: NeonEnvBuilder, wal_receiver_protocol: PageserverWalReceiverProtocol +): """ This is a smoke test that compaction kicks in. The workload repeatedly churns a small number of rows and manually instructs the pageserver to run compaction @@ -41,6 +43,8 @@ def test_pageserver_compaction_smoke(neon_env_builder: NeonEnvBuilder): observed bounds. """ + neon_env_builder.pageserver_wal_receiver_protocol = wal_receiver_protocol + # Effectively disable the page cache to rely only on image layers # to shorten reads. neon_env_builder.pageserver_config_override = """ @@ -172,7 +176,7 @@ LARGE_STRIPES = 32768 def test_sharding_compaction( neon_env_builder: NeonEnvBuilder, stripe_size: int, - shard_count: Optional[int], + shard_count: int | None, gc_compaction: bool, ): """ @@ -277,7 +281,7 @@ def test_sharding_compaction( ) -class CompactionAlgorithm(str, enum.Enum): +class CompactionAlgorithm(StrEnum): LEGACY = "legacy" TIERED = "tiered" diff --git a/test_runner/regress/test_compatibility.py b/test_runner/regress/test_compatibility.py index 96ba3dd5a4..ba7305148f 100644 --- a/test_runner/regress/test_compatibility.py +++ b/test_runner/regress/test_compatibility.py @@ -7,7 +7,6 @@ import subprocess import tempfile from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING import fixtures.utils import pytest @@ -28,10 +27,6 @@ from fixtures.pg_version import PgVersion from fixtures.remote_storage import RemoteStorageKind, S3Storage, s3_storage from fixtures.workload import Workload -if TYPE_CHECKING: - from typing import Optional - - # # A test suite that help to prevent unintentionally breaking backward or forward compatibility between Neon releases. # - `test_create_snapshot` a script wrapped in a test that creates a data snapshot. @@ -385,7 +380,7 @@ def check_neon_works(env: NeonEnv, test_output_dir: Path, sql_dump_path: Path, r def dump_differs( - first: Path, second: Path, output: Path, allowed_diffs: Optional[list[str]] = None + first: Path, second: Path, output: Path, allowed_diffs: list[str] | None = None ) -> bool: """ Runs diff(1) command on two SQL dumps and write the output to the given output file. diff --git a/test_runner/regress/test_compute_catalog.py b/test_runner/regress/test_compute_catalog.py index d43c71ceac..b3719a45ed 100644 --- a/test_runner/regress/test_compute_catalog.py +++ b/test_runner/regress/test_compute_catalog.py @@ -3,13 +3,60 @@ from __future__ import annotations import requests from fixtures.neon_fixtures import NeonEnv +TEST_DB_NAMES = [ + { + "name": "neondb", + "owner": "cloud_admin", + }, + { + "name": "db with spaces", + "owner": "cloud_admin", + }, + { + "name": "db with%20spaces ", + "owner": "cloud_admin", + }, + { + "name": "db with whitespaces ", + "owner": "cloud_admin", + }, + { + "name": "injective db with spaces'; SELECT pg_sleep(10);", + "owner": "cloud_admin", + }, + { + "name": "db with #pound-sign and &ersands=true", + "owner": "cloud_admin", + }, + { + "name": "db with emoji 🌍", + "owner": "cloud_admin", + }, +] + def test_compute_catalog(neon_simple_env: NeonEnv): + """ + Create a bunch of databases with tricky names and test that we can list them + and dump via API. + """ env = neon_simple_env - endpoint = env.endpoints.create_start("main", config_lines=["log_min_messages=debug1"]) - client = endpoint.http_client() + endpoint = env.endpoints.create_start("main") + # Update the spec.json file to include new databases + # and reconfigure the endpoint to create some test databases. + endpoint.respec_deep( + **{ + "skip_pg_catalog_updates": False, + "cluster": { + "databases": TEST_DB_NAMES, + }, + } + ) + endpoint.reconfigure() + + client = endpoint.http_client() objects = client.dbs_and_roles() # Assert that 'cloud_admin' role exists in the 'roles' list @@ -22,9 +69,24 @@ def test_compute_catalog(neon_simple_env: NeonEnv): db["name"] == "postgres" for db in objects["databases"] ), "The 'postgres' database is missing" - ddl = client.database_schema(database="postgres") + # Check other databases + for test_db in TEST_DB_NAMES: + db = next((db for db in objects["databases"] if db["name"] == test_db["name"]), None) + assert db is not None, f"The '{test_db['name']}' database is missing" + assert ( + db["owner"] == test_db["owner"] + ), f"The '{test_db['name']}' database has incorrect owner" - assert "-- PostgreSQL database dump" in ddl + ddl = client.database_schema(database=test_db["name"]) + + # Check that it looks like a valid PostgreSQL dump + assert "-- PostgreSQL database dump" in ddl + + # Check that it doesn't contain health_check and migration traces. + # They are only created in system `postgres` database, so by checking + # that we ensure that we dump right databases. + assert "health_check" not in ddl, f"The '{test_db['name']}' database contains health_check" + assert "migration" not in ddl, f"The '{test_db['name']}' database contains migrations data" try: client.database_schema(database="nonexistentdb") @@ -33,3 +95,44 @@ def test_compute_catalog(neon_simple_env: NeonEnv): assert ( e.response.status_code == 404 ), f"Expected 404 status code, but got {e.response.status_code}" + + +def test_compute_create_databases(neon_simple_env: NeonEnv): + """ + Test that compute_ctl can create and work with databases with special + characters (whitespaces, %, tabs, etc.) in the name. + """ + env = neon_simple_env + + # Create and start endpoint so that neon_local put all the generated + # stuff into the spec.json file. + endpoint = env.endpoints.create_start("main") + + # Update the spec.json file to include new databases + # and reconfigure the endpoint to apply the changes. + endpoint.respec_deep( + **{ + "skip_pg_catalog_updates": False, + "cluster": { + "databases": TEST_DB_NAMES, + }, + } + ) + endpoint.reconfigure() + + for db in TEST_DB_NAMES: + # Check that database has a correct name in the system catalog + with endpoint.cursor() as cursor: + cursor.execute("SELECT datname FROM pg_database WHERE datname = %s", (db["name"],)) + catalog_db = cursor.fetchone() + assert catalog_db is not None + assert len(catalog_db) == 1 + assert catalog_db[0] == db["name"] + + # Check that we can connect to this database without any issues + with endpoint.cursor(dbname=db["name"]) as cursor: + cursor.execute("SELECT * FROM current_database()") + curr_db = cursor.fetchone() + assert curr_db is not None + assert len(curr_db) == 1 + assert curr_db[0] == db["name"] diff --git a/test_runner/regress/test_compute_metrics.py b/test_runner/regress/test_compute_metrics.py index c5e3034591..1b15c5f15e 100644 --- a/test_runner/regress/test_compute_metrics.py +++ b/test_runner/regress/test_compute_metrics.py @@ -3,6 +3,7 @@ from __future__ import annotations import enum import os import shutil +from enum import StrEnum from pathlib import Path from typing import TYPE_CHECKING, cast @@ -16,7 +17,7 @@ from fixtures.paths import BASE_DIR, COMPUTE_CONFIG_DIR if TYPE_CHECKING: from types import TracebackType - from typing import Optional, TypedDict, Union + from typing import Self, TypedDict from fixtures.neon_fixtures import NeonEnv from fixtures.pg_version import PgVersion @@ -26,15 +27,15 @@ if TYPE_CHECKING: metric_name: str type: str help: str - key_labels: Optional[list[str]] - values: Optional[list[str]] - query: Optional[str] - query_ref: Optional[str] + key_labels: list[str] | None + values: list[str] | None + query: str | None + query_ref: str | None class Collector(TypedDict): collector_name: str metrics: list[Metric] - queries: Optional[list[Query]] + queries: list[Query] | None class Query(TypedDict): query_name: str @@ -53,12 +54,12 @@ def __import_callback(dir: str, rel: str) -> tuple[str, bytes]: if not rel: raise RuntimeError("Empty filename") - full_path: Optional[str] = None + full_path: str | None = None if os.path.isabs(rel): full_path = rel else: for p in (dir, *JSONNET_PATH): - assert isinstance(p, (str, Path)), "for mypy" + assert isinstance(p, str | Path), "for mypy" full_path = os.path.join(p, rel) assert isinstance(full_path, str), "for mypy" @@ -82,9 +83,9 @@ def __import_callback(dir: str, rel: str) -> tuple[str, bytes]: def jsonnet_evaluate_file( - jsonnet_file: Union[str, Path], - ext_vars: Optional[Union[str, dict[str, str]]] = None, - tla_vars: Optional[Union[str, dict[str, str]]] = None, + jsonnet_file: str | Path, + ext_vars: str | dict[str, str] | None = None, + tla_vars: str | dict[str, str] | None = None, ) -> str: return cast( "str", @@ -102,7 +103,7 @@ def evaluate_collector(jsonnet_file: Path, pg_version: PgVersion) -> str: def evaluate_config( - jsonnet_file: Path, collector_name: str, collector_file: Union[str, Path], connstr: str + jsonnet_file: Path, collector_name: str, collector_file: str | Path, connstr: str ) -> str: return jsonnet_evaluate_file( jsonnet_file, @@ -115,7 +116,7 @@ def evaluate_config( @enum.unique -class SqlExporterProcess(str, enum.Enum): +class SqlExporterProcess(StrEnum): COMPUTE = "compute" AUTOSCALING = "autoscaling" @@ -184,16 +185,16 @@ class SqlExporterRunner: def stop(self) -> None: raise NotImplementedError() - def __enter__(self) -> SqlExporterRunner: + def __enter__(self) -> Self: self.start() return self def __exit__( self, - exc_type: Optional[type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, ): self.stop() @@ -241,8 +242,7 @@ if SQL_EXPORTER is None: self.with_volume_mapping(str(config_file), container_config_file, "z") self.with_volume_mapping(str(collector_file), container_collector_file, "z") - @override - def start(self) -> SqlExporterContainer: + def start(self) -> Self: super().start() log.info("Waiting for sql_exporter to be ready") diff --git a/test_runner/regress/test_crafted_wal_end.py b/test_runner/regress/test_crafted_wal_end.py index 23c6fa3a5a..6b9dcbba07 100644 --- a/test_runner/regress/test_crafted_wal_end.py +++ b/test_runner/regress/test_crafted_wal_end.py @@ -3,7 +3,7 @@ from __future__ import annotations import pytest from fixtures.log_helper import log from fixtures.neon_cli import WalCraft -from fixtures.neon_fixtures import NeonEnvBuilder +from fixtures.neon_fixtures import NeonEnvBuilder, PageserverWalReceiverProtocol # Restart nodes with WAL end having specially crafted shape, like last record # crossing segment boundary, to test decoding issues. @@ -19,7 +19,17 @@ from fixtures.neon_fixtures import NeonEnvBuilder "wal_record_crossing_segment_followed_by_small_one", ], ) -def test_crafted_wal_end(neon_env_builder: NeonEnvBuilder, wal_type: str): +@pytest.mark.parametrize( + "wal_receiver_protocol", + [PageserverWalReceiverProtocol.VANILLA, PageserverWalReceiverProtocol.INTERPRETED], +) +def test_crafted_wal_end( + neon_env_builder: NeonEnvBuilder, + wal_type: str, + wal_receiver_protocol: PageserverWalReceiverProtocol, +): + neon_env_builder.pageserver_wal_receiver_protocol = wal_receiver_protocol + env = neon_env_builder.init_start() env.create_branch("test_crafted_wal_end") env.pageserver.allowed_errors.extend( diff --git a/test_runner/regress/test_ddl_forwarding.py b/test_runner/regress/test_ddl_forwarding.py index e517e83e6f..1c5554c379 100644 --- a/test_runner/regress/test_ddl_forwarding.py +++ b/test_runner/regress/test_ddl_forwarding.py @@ -13,7 +13,7 @@ from werkzeug.wrappers.request import Request from werkzeug.wrappers.response import Response if TYPE_CHECKING: - from typing import Any, Optional + from typing import Any, Self def handle_db(dbs, roles, operation): @@ -91,15 +91,15 @@ class DdlForwardingContext: lambda request: ddl_forward_handler(request, self.dbs, self.roles, self) ) - def __enter__(self): + def __enter__(self) -> Self: self.pg.start() return self def __exit__( self, - exc_type: Optional[type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, ): self.pg.stop() diff --git a/test_runner/regress/test_disk_usage_eviction.py b/test_runner/regress/test_disk_usage_eviction.py index c8d3b2ff3e..1807511008 100644 --- a/test_runner/regress/test_disk_usage_eviction.py +++ b/test_runner/regress/test_disk_usage_eviction.py @@ -5,6 +5,7 @@ import time from collections import Counter from collections.abc import Iterable from dataclasses import dataclass +from enum import StrEnum from typing import TYPE_CHECKING import pytest @@ -80,7 +81,7 @@ def test_min_resident_size_override_handling( @enum.unique -class EvictionOrder(str, enum.Enum): +class EvictionOrder(StrEnum): RELATIVE_ORDER_EQUAL = "relative_equal" RELATIVE_ORDER_SPARE = "relative_spare" diff --git a/test_runner/regress/test_explain_with_lfc_stats.py b/test_runner/regress/test_explain_with_lfc_stats.py index 2128bd93dd..382556fd7e 100644 --- a/test_runner/regress/test_explain_with_lfc_stats.py +++ b/test_runner/regress/test_explain_with_lfc_stats.py @@ -2,10 +2,13 @@ from __future__ import annotations from pathlib import Path +import pytest from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnv +from fixtures.utils import USE_LFC +@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping") def test_explain_with_lfc_stats(neon_simple_env: NeonEnv): env = neon_simple_env @@ -16,8 +19,6 @@ def test_explain_with_lfc_stats(neon_simple_env: NeonEnv): endpoint = env.endpoints.create_start( "main", config_lines=[ - "shared_buffers='1MB'", - f"neon.file_cache_path='{cache_dir}/file.cache'", "neon.max_file_cache_size='128MB'", "neon.file_cache_size_limit='64MB'", ], diff --git a/test_runner/regress/test_hot_standby.py b/test_runner/regress/test_hot_standby.py index a906e7a243..0b1ac11c16 100644 --- a/test_runner/regress/test_hot_standby.py +++ b/test_runner/regress/test_hot_standby.py @@ -170,7 +170,7 @@ def test_hot_standby_gc(neon_env_builder: NeonEnvBuilder, pause_apply: bool): # re-execute the query, it will make GetPage # requests. This does not clear the last-written LSN cache # so we still remember the LSNs of the pages. - secondary.clear_shared_buffers(cursor=s_cur) + secondary.clear_buffers(cursor=s_cur) if pause_apply: s_cur.execute("SELECT pg_wal_replay_pause()") diff --git a/test_runner/regress/test_import_pgdata.py b/test_runner/regress/test_import_pgdata.py new file mode 100644 index 0000000000..29229b73c1 --- /dev/null +++ b/test_runner/regress/test_import_pgdata.py @@ -0,0 +1,307 @@ +import json +import re +import time +from enum import Enum + +import psycopg2 +import psycopg2.errors +import pytest +from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId +from fixtures.log_helper import log +from fixtures.neon_fixtures import NeonEnvBuilder, VanillaPostgres +from fixtures.pageserver.http import ( + ImportPgdataIdemptencyKey, + PageserverApiException, +) +from fixtures.pg_version import PgVersion +from fixtures.remote_storage import RemoteStorageKind +from fixtures.utils import run_only_on_postgres +from pytest_httpserver import HTTPServer +from werkzeug.wrappers.request import Request +from werkzeug.wrappers.response import Response + +num_rows = 1000 + + +class RelBlockSize(Enum): + ONE_STRIPE_SIZE = 1 + TWO_STRPES_PER_SHARD = 2 + MULTIPLE_RELATION_SEGMENTS = 3 + + +smoke_params = [ + # unsharded (the stripe size needs to be given for rel block size calculations) + *[(None, 1024, s) for s in RelBlockSize], + # many shards, small stripe size to speed up test + *[(8, 1024, s) for s in RelBlockSize], +] + + +@run_only_on_postgres( + [PgVersion.V14, PgVersion.V15, PgVersion.V16], + "newer control file catalog version and struct format isn't supported", +) +@pytest.mark.parametrize("shard_count,stripe_size,rel_block_size", smoke_params) +def test_pgdata_import_smoke( + vanilla_pg: VanillaPostgres, + neon_env_builder: NeonEnvBuilder, + shard_count: int | None, + stripe_size: int, + rel_block_size: RelBlockSize, + make_httpserver: HTTPServer, +): + # + # Setup fake control plane for import progress + # + def handler(request: Request) -> Response: + log.info(f"control plane request: {request.json}") + return Response(json.dumps({}), status=200) + + cplane_mgmt_api_server = make_httpserver + cplane_mgmt_api_server.expect_request(re.compile(".*")).respond_with_handler(handler) + + neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS) + env = neon_env_builder.init_start() + + env.pageserver.patch_config_toml_nonrecursive( + { + "import_pgdata_upcall_api": f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/path/to/mgmt/api" + } + ) + env.pageserver.stop() + env.pageserver.start() + + # + # Put data in vanilla pg + # + + vanilla_pg.start() + vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser") + + log.info("create relblock data") + if rel_block_size == RelBlockSize.ONE_STRIPE_SIZE: + target_relblock_size = stripe_size * 8192 + elif rel_block_size == RelBlockSize.TWO_STRPES_PER_SHARD: + target_relblock_size = (shard_count or 1) * stripe_size * 8192 * 2 + elif rel_block_size == RelBlockSize.MULTIPLE_RELATION_SEGMENTS: + target_relblock_size = int(((2.333 * 1024 * 1024 * 1024) // 8192) * 8192) + else: + raise ValueError + + # fillfactor so we don't need to produce that much data + # 900 byte per row is > 10% => 1 row per page + vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""") + + nrows = 0 + while True: + relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')") + log.info( + f"relblock size: {relblock_size/8192} pages (target: {target_relblock_size//8192}) pages" + ) + if relblock_size >= target_relblock_size: + break + addrows = int((target_relblock_size - relblock_size) // 8192) + assert addrows >= 1, "forward progress" + vanilla_pg.safe_psql(f"insert into t select generate_series({nrows+1}, {nrows + addrows})") + nrows += addrows + expect_nrows = nrows + expect_sum = ( + (nrows) * (nrows + 1) // 2 + ) # https://stackoverflow.com/questions/43901484/sum-of-the-integers-from-1-to-n + + def validate_vanilla_equivalence(ep): + # TODO: would be nicer to just compare pgdump + assert ep.safe_psql("select count(*), sum(data::bigint)::bigint from t") == [ + (expect_nrows, expect_sum) + ] + + validate_vanilla_equivalence(vanilla_pg) + + vanilla_pg.stop() + + # + # We have a Postgres data directory now. + # Make a localfs remote storage that looks like how after `fast_import` ran. + # TODO: actually exercise fast_import here + # TODO: test s3 remote storage + # + importbucket = neon_env_builder.repo_dir / "importbucket" + importbucket.mkdir() + # what cplane writes before scheduling fast_import + specpath = importbucket / "spec.json" + specpath.write_text(json.dumps({"branch_id": "somebranch", "project_id": "someproject"})) + # what fast_import writes + vanilla_pg.pgdatadir.rename(importbucket / "pgdata") + statusdir = importbucket / "status" + statusdir.mkdir() + (statusdir / "pgdata").write_text(json.dumps({"done": True})) + + # + # Do the import + # + + tenant_id = TenantId.generate() + env.storage_controller.tenant_create( + tenant_id, shard_count=shard_count, shard_stripe_size=stripe_size + ) + + timeline_id = TimelineId.generate() + log.info("starting import") + start = time.monotonic() + + idempotency = ImportPgdataIdemptencyKey.random() + log.info(f"idempotency key {idempotency}") + # TODO: teach neon_local CLI about the idempotency & 429 error so we can run inside the loop + # and check for 429 + + import_branch_name = "imported" + env.storage_controller.timeline_create( + tenant_id, + { + "new_timeline_id": str(timeline_id), + "import_pgdata": { + "idempotency_key": str(idempotency), + "location": {"LocalFs": {"path": str(importbucket.absolute())}}, + }, + }, + ) + env.neon_cli.mappings_map_branch(import_branch_name, tenant_id, timeline_id) + + while True: + locations = env.storage_controller.locate(tenant_id) + active_count = 0 + for location in locations: + shard_id = TenantShardId.parse(location["shard_id"]) + ps = env.get_pageserver(location["node_id"]) + try: + detail = ps.http_client().timeline_detail(shard_id, timeline_id) + state = detail["state"] + log.info(f"shard {shard_id} state: {state}") + if state == "Active": + active_count += 1 + except PageserverApiException as e: + if e.status_code == 404: + log.info("not found, import is in progress") + continue + elif e.status_code == 429: + log.info("import is in progress") + continue + else: + raise + + shard_status_file = statusdir / f"shard-{shard_id.shard_index}" + if state == "Active": + shard_status_file_contents = ( + shard_status_file.read_text() + ) # Active state implies import is done + shard_status = json.loads(shard_status_file_contents) + assert shard_status["done"] is True + + if active_count == len(locations): + log.info("all shards are active") + break + time.sleep(1) + + import_duration = time.monotonic() - start + log.info(f"import complete; duration={import_duration:.2f}s") + + # + # Get some timeline details for later. + # + locations = env.storage_controller.locate(tenant_id) + [shard_zero] = [ + loc for loc in locations if TenantShardId.parse(loc["shard_id"]).shard_number == 0 + ] + shard_zero_ps = env.get_pageserver(shard_zero["node_id"]) + shard_zero_http = shard_zero_ps.http_client() + shard_zero_timeline_info = shard_zero_http.timeline_detail(shard_zero["shard_id"], timeline_id) + initdb_lsn = Lsn(shard_zero_timeline_info["initdb_lsn"]) + latest_gc_cutoff_lsn = Lsn(shard_zero_timeline_info["latest_gc_cutoff_lsn"]) + last_record_lsn = Lsn(shard_zero_timeline_info["last_record_lsn"]) + disk_consistent_lsn = Lsn(shard_zero_timeline_info["disk_consistent_lsn"]) + _remote_consistent_lsn = Lsn(shard_zero_timeline_info["remote_consistent_lsn"]) + remote_consistent_lsn_visible = Lsn(shard_zero_timeline_info["remote_consistent_lsn_visible"]) + # assert remote_consistent_lsn_visible == remote_consistent_lsn TODO: this fails initially and after restart, presumably because `UploadQueue::clean.1` is still `None` + assert remote_consistent_lsn_visible == disk_consistent_lsn + assert initdb_lsn == latest_gc_cutoff_lsn + assert disk_consistent_lsn == initdb_lsn + 8 + assert last_record_lsn == disk_consistent_lsn + # TODO: assert these values are the same everywhere + + # + # Validate the resulting remote storage state. + # + + # + # Validate the imported data + # + + ro_endpoint = env.endpoints.create_start( + branch_name=import_branch_name, endpoint_id="ro", tenant_id=tenant_id, lsn=last_record_lsn + ) + + validate_vanilla_equivalence(ro_endpoint) + + # ensure the import survives restarts + ro_endpoint.stop() + env.pageserver.stop(immediate=True) + env.pageserver.start() + ro_endpoint.start() + validate_vanilla_equivalence(ro_endpoint) + + # + # validate the layer files in each shard only have the shard-specific data + # (the implementation would be functional but not efficient without this characteristic) + # + + shards = env.storage_controller.locate(tenant_id) + for shard in shards: + shard_ps = env.get_pageserver(shard["node_id"]) + result = shard_ps.timeline_scan_no_disposable_keys(shard["shard_id"], timeline_id) + assert result.tally.disposable_count == 0 + assert ( + result.tally.not_disposable_count > 0 + ), "sanity check, each shard should have some data" + + # + # validate that we can write + # + rw_endpoint = env.endpoints.create_start( + branch_name=import_branch_name, endpoint_id="rw", tenant_id=tenant_id + ) + rw_endpoint.safe_psql("create table othertable(values text)") + rw_lsn = Lsn(rw_endpoint.safe_psql_scalar("select pg_current_wal_flush_lsn()")) + + # TODO: consider using `class Workload` here + # to do compaction and whatnot? + + # + # validate that we can branch (important use case) + # + + # ... at the tip + _ = env.create_branch( + new_branch_name="br-tip", + ancestor_branch_name=import_branch_name, + tenant_id=tenant_id, + ancestor_start_lsn=rw_lsn, + ) + br_tip_endpoint = env.endpoints.create_start( + branch_name="br-tip", endpoint_id="br-tip-ro", tenant_id=tenant_id + ) + validate_vanilla_equivalence(br_tip_endpoint) + br_tip_endpoint.safe_psql("select * from othertable") + + # ... at the initdb lsn + _ = env.create_branch( + new_branch_name="br-initdb", + ancestor_branch_name=import_branch_name, + tenant_id=tenant_id, + ancestor_start_lsn=initdb_lsn, + ) + br_initdb_endpoint = env.endpoints.create_start( + branch_name="br-initdb", endpoint_id="br-initdb-ro", tenant_id=tenant_id + ) + validate_vanilla_equivalence(br_initdb_endpoint) + with pytest.raises(psycopg2.errors.UndefinedTable): + br_initdb_endpoint.safe_psql("select * from othertable") diff --git a/test_runner/regress/test_ingestion_layer_size.py b/test_runner/regress/test_ingestion_layer_size.py index 2916748925..9c9bc5b519 100644 --- a/test_runner/regress/test_ingestion_layer_size.py +++ b/test_runner/regress/test_ingestion_layer_size.py @@ -2,16 +2,12 @@ from __future__ import annotations from collections.abc import Iterable from dataclasses import dataclass -from typing import TYPE_CHECKING from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnvBuilder, wait_for_last_flush_lsn from fixtures.pageserver.http import HistoricLayerInfo, LayerMapInfo from fixtures.utils import human_bytes, skip_in_debug_build -if TYPE_CHECKING: - from typing import Union - @skip_in_debug_build("debug run is unnecessarily slow") def test_ingesting_large_batches_of_images(neon_env_builder: NeonEnvBuilder): @@ -109,14 +105,12 @@ def test_ingesting_large_batches_of_images(neon_env_builder: NeonEnvBuilder): @dataclass class Histogram: - buckets: list[Union[int, float]] + buckets: list[int | float] counts: list[int] sums: list[int] -def histogram_historic_layers( - infos: LayerMapInfo, minimum_sizes: list[Union[int, float]] -) -> Histogram: +def histogram_historic_layers(infos: LayerMapInfo, minimum_sizes: list[int | float]) -> Histogram: def log_layer(layer: HistoricLayerInfo) -> HistoricLayerInfo: log.info( f"{layer.layer_file_name} {human_bytes(layer.layer_file_size)} ({layer.layer_file_size} bytes)" @@ -128,7 +122,7 @@ def histogram_historic_layers( return histogram(sizes, minimum_sizes) -def histogram(sizes: Iterable[int], minimum_sizes: list[Union[int, float]]) -> Histogram: +def histogram(sizes: Iterable[int], minimum_sizes: list[int | float]) -> Histogram: assert all(minimum_sizes[i] < minimum_sizes[i + 1] for i in range(len(minimum_sizes) - 1)) buckets = list(enumerate(minimum_sizes)) counts = [0 for _ in buckets] diff --git a/test_runner/regress/test_installed_extensions.py b/test_runner/regress/test_installed_extensions.py index 54ce7c8340..04ccec5875 100644 --- a/test_runner/regress/test_installed_extensions.py +++ b/test_runner/regress/test_installed_extensions.py @@ -99,11 +99,15 @@ def test_installed_extensions(neon_simple_env: NeonEnv): res = client.metrics() info("Metrics: %s", res) m = parse_metrics(res) - neon_m = m.query_all("installed_extensions", {"extension_name": "neon", "version": "1.2"}) + neon_m = m.query_all( + "compute_installed_extensions", {"extension_name": "neon", "version": "1.2"} + ) assert len(neon_m) == 1 for sample in neon_m: assert sample.value == 2 - neon_m = m.query_all("installed_extensions", {"extension_name": "neon", "version": "1.3"}) + neon_m = m.query_all( + "compute_installed_extensions", {"extension_name": "neon", "version": "1.3"} + ) assert len(neon_m) == 1 for sample in neon_m: assert sample.value == 1 @@ -116,7 +120,7 @@ def test_installed_extensions(neon_simple_env: NeonEnv): try: res = client.metrics() timeout = -1 - if len(parse_metrics(res).query_all("installed_extensions")) < 4: + if len(parse_metrics(res).query_all("compute_installed_extensions")) < 4: # Assume that not all metrics that are collected yet time.sleep(1) timeout -= 1 @@ -128,17 +132,21 @@ def test_installed_extensions(neon_simple_env: NeonEnv): continue assert ( - len(parse_metrics(res).query_all("installed_extensions")) >= 4 + len(parse_metrics(res).query_all("compute_installed_extensions")) >= 4 ), "Not all metrics are collected" info("After restart metrics: %s", res) m = parse_metrics(res) - neon_m = m.query_all("installed_extensions", {"extension_name": "neon", "version": "1.2"}) + neon_m = m.query_all( + "compute_installed_extensions", {"extension_name": "neon", "version": "1.2"} + ) assert len(neon_m) == 1 for sample in neon_m: assert sample.value == 1 - neon_m = m.query_all("installed_extensions", {"extension_name": "neon", "version": "1.3"}) + neon_m = m.query_all( + "compute_installed_extensions", {"extension_name": "neon", "version": "1.3"} + ) assert len(neon_m) == 1 for sample in neon_m: assert sample.value == 1 diff --git a/test_runner/regress/test_layers_from_future.py b/test_runner/regress/test_layers_from_future.py index 309e0f3015..761ec7568f 100644 --- a/test_runner/regress/test_layers_from_future.py +++ b/test_runner/regress/test_layers_from_future.py @@ -2,6 +2,7 @@ from __future__ import annotations import time +import pytest from fixtures.common_types import Lsn from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnvBuilder, flush_ep_to_pageserver @@ -19,7 +20,11 @@ from fixtures.remote_storage import LocalFsStorage, RemoteStorageKind from fixtures.utils import query_scalar, wait_until -def test_issue_5878(neon_env_builder: NeonEnvBuilder): +@pytest.mark.parametrize( + "attach_mode", + ["default_generation", "same_generation"], +) +def test_issue_5878(neon_env_builder: NeonEnvBuilder, attach_mode: str): """ Regression test for issue https://github.com/neondatabase/neon/issues/5878 . @@ -168,11 +173,32 @@ def test_issue_5878(neon_env_builder: NeonEnvBuilder): tenant_conf = ps_http.tenant_config(tenant_id) generation_before_detach = get_generation_number() env.pageserver.tenant_detach(tenant_id) - failpoint_name = "before-delete-layer-pausable" + failpoint_deletion_queue = "deletion-queue-before-execute-pause" - ps_http.configure_failpoints((failpoint_name, "pause")) - env.pageserver.tenant_attach(tenant_id, tenant_conf.tenant_specific_overrides) - generation_after_reattach = get_generation_number() + ps_http.configure_failpoints((failpoint_deletion_queue, "pause")) + + if attach_mode == "default_generation": + env.pageserver.tenant_attach(tenant_id, tenant_conf.tenant_specific_overrides) + elif attach_mode == "same_generation": + # Attach with the same generation number -- this is possible with timeline offload and detach ancestor + env.pageserver.tenant_attach( + tenant_id, + tenant_conf.tenant_specific_overrides, + generation=generation_before_detach, + # We want to avoid the generation bump and don't want to talk with the storcon + override_storage_controller_generation=False, + ) + else: + raise AssertionError(f"Unknown attach_mode: {attach_mode}") + + # Get it from pageserver API instead of storcon API b/c we might not have attached using the storcon + # API if attach_mode == "same_generation" + tenant_location = env.pageserver.http_client().tenant_get_location(tenant_id) + generation_after_reattach = tenant_location["generation"] + + if attach_mode == "same_generation": + # The generation number should be the same as before the detach + assert generation_before_detach == generation_after_reattach wait_until_tenant_active(ps_http, tenant_id) # Ensure the IndexPart upload that unlinks the layer file finishes, i.e., doesn't clog the queue. @@ -182,15 +208,8 @@ def test_issue_5878(neon_env_builder: NeonEnvBuilder): wait_until(10, 0.5, future_layer_is_gone_from_index_part) - # NB: the layer file is unlinked index part now, but, because we made the delete - # operation stuck, the layer file itself is still in the remote_storage - wait_until( - 10, - 0.5, - lambda: env.pageserver.assert_log_contains( - f".*{tenant_id}.*at failpoint.*{failpoint_name}" - ), - ) + # We already make deletion stuck here, but we don't necessarily hit the failpoint + # because deletions are batched. future_layer_path = env.pageserver_remote_storage.remote_layer_path( tenant_id, timeline_id, future_layer.to_str(), generation=generation_before_detach ) @@ -224,11 +243,13 @@ def test_issue_5878(neon_env_builder: NeonEnvBuilder): break time.sleep(1) - # Window has passed, unstuck the delete, let upload queue drain. + # Window has passed, unstuck the delete, let deletion queue drain; the upload queue should + # have drained because we put these layer deletion operations into the deletion queue and + # have consumed the operation from the upload queue. log.info("unstuck the DELETE") - ps_http.configure_failpoints(("before-delete-layer-pausable", "off")) - + ps_http.configure_failpoints((failpoint_deletion_queue, "off")) wait_for_upload_queue_empty(ps_http, tenant_id, timeline_id) + env.pageserver.http_client().deletion_queue_flush(True) # Examine the resulting S3 state. log.info("integrity-check the remote storage") @@ -247,3 +268,12 @@ def test_issue_5878(neon_env_builder: NeonEnvBuilder): final_stat = future_layer_path.stat() log.info(f"future layer path: {future_layer_path}") assert final_stat.st_mtime != pre_stat.st_mtime + + # Ensure no weird errors in the end... + wait_for_upload_queue_empty(ps_http, tenant_id, timeline_id) + + if attach_mode == "same_generation": + # we should have detected a race upload and deferred it + env.pageserver.assert_log_contains( + "waiting for deletion queue flush to complete before uploading layer" + ) diff --git a/test_runner/regress/test_lfc_resize.py b/test_runner/regress/test_lfc_resize.py index 3083128d87..377b0fb4d4 100644 --- a/test_runner/regress/test_lfc_resize.py +++ b/test_runner/regress/test_lfc_resize.py @@ -1,6 +1,5 @@ from __future__ import annotations -import os import random import re import subprocess @@ -10,20 +9,24 @@ import time import pytest from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnv, PgBin +from fixtures.utils import USE_LFC @pytest.mark.timeout(600) +@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping") def test_lfc_resize(neon_simple_env: NeonEnv, pg_bin: PgBin): """ Test resizing the Local File Cache """ env = neon_simple_env + cache_dir = env.repo_dir / "file_cache" + cache_dir.mkdir(exist_ok=True) + env.create_branch("test_lfc_resize") endpoint = env.endpoints.create_start( "main", config_lines=[ - "neon.file_cache_path='file.cache'", - "neon.max_file_cache_size=512MB", - "neon.file_cache_size_limit=512MB", + "neon.max_file_cache_size=1GB", + "neon.file_cache_size_limit=1GB", ], ) n_resize = 10 @@ -63,8 +66,8 @@ def test_lfc_resize(neon_simple_env: NeonEnv, pg_bin: PgBin): cur.execute("select pg_reload_conf()") nretries = 10 while True: - lfc_file_path = f"{endpoint.pg_data_dir_path()}/file.cache" - lfc_file_size = os.path.getsize(lfc_file_path) + lfc_file_path = endpoint.lfc_path() + lfc_file_size = lfc_file_path.stat().st_size res = subprocess.run( ["ls", "-sk", lfc_file_path], check=True, text=True, capture_output=True ) diff --git a/test_runner/regress/test_lfc_working_set_approximation.py b/test_runner/regress/test_lfc_working_set_approximation.py index 36dfec969f..17068849d4 100644 --- a/test_runner/regress/test_lfc_working_set_approximation.py +++ b/test_runner/regress/test_lfc_working_set_approximation.py @@ -3,11 +3,13 @@ from __future__ import annotations import time from pathlib import Path +import pytest from fixtures.log_helper import log from fixtures.neon_fixtures import NeonEnv -from fixtures.utils import query_scalar +from fixtures.utils import USE_LFC, query_scalar +@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping") def test_lfc_working_set_approximation(neon_simple_env: NeonEnv): env = neon_simple_env @@ -18,8 +20,6 @@ def test_lfc_working_set_approximation(neon_simple_env: NeonEnv): endpoint = env.endpoints.create_start( "main", config_lines=[ - "shared_buffers='1MB'", - f"neon.file_cache_path='{cache_dir}/file.cache'", "neon.max_file_cache_size='128MB'", "neon.file_cache_size_limit='64MB'", ], @@ -72,9 +72,10 @@ WITH (fillfactor='100'); # verify working set size after some index access of a few select pages only blocks = query_scalar(cur, "select approximate_working_set_size(true)") log.info(f"working set size after some index access of a few select pages only {blocks}") - assert blocks < 10 + assert blocks < 12 +@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping") def test_sliding_working_set_approximation(neon_simple_env: NeonEnv): env = neon_simple_env diff --git a/test_runner/regress/test_local_file_cache.py b/test_runner/regress/test_local_file_cache.py index fbf018a167..94c630ffcf 100644 --- a/test_runner/regress/test_local_file_cache.py +++ b/test_runner/regress/test_local_file_cache.py @@ -6,10 +6,12 @@ import random import threading import time +import pytest from fixtures.neon_fixtures import NeonEnvBuilder -from fixtures.utils import query_scalar +from fixtures.utils import USE_LFC, query_scalar +@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping") def test_local_file_cache_unlink(neon_env_builder: NeonEnvBuilder): env = neon_env_builder.init_start() @@ -19,8 +21,6 @@ def test_local_file_cache_unlink(neon_env_builder: NeonEnvBuilder): endpoint = env.endpoints.create_start( "main", config_lines=[ - "shared_buffers='1MB'", - f"neon.file_cache_path='{cache_dir}/file.cache'", "neon.max_file_cache_size='64MB'", "neon.file_cache_size_limit='10MB'", ], diff --git a/test_runner/regress/test_logical_replication.py b/test_runner/regress/test_logical_replication.py index df83ca1c44..ba471b7147 100644 --- a/test_runner/regress/test_logical_replication.py +++ b/test_runner/regress/test_logical_replication.py @@ -12,7 +12,7 @@ from fixtures.neon_fixtures import ( logical_replication_sync, wait_for_last_flush_lsn, ) -from fixtures.utils import wait_until +from fixtures.utils import USE_LFC, wait_until if TYPE_CHECKING: from fixtures.neon_fixtures import ( @@ -576,7 +576,15 @@ def test_subscriber_synchronous_commit(neon_simple_env: NeonEnv, vanilla_pg: Van # We want all data to fit into shared_buffers because later we stop # safekeeper and insert more; this shouldn't cause page requests as they # will be stuck. - sub = env.endpoints.create("subscriber", config_lines=["shared_buffers=128MB"]) + sub = env.endpoints.create( + "subscriber", + config_lines=[ + "neon.max_file_cache_size = 32MB", + "neon.file_cache_size_limit = 32MB", + ] + if USE_LFC + else [], + ) sub.start() with vanilla_pg.cursor() as pcur: diff --git a/test_runner/regress/test_lsn_mapping.py b/test_runner/regress/test_lsn_mapping.py index 8b41d0cb1c..7f0b541128 100644 --- a/test_runner/regress/test_lsn_mapping.py +++ b/test_runner/regress/test_lsn_mapping.py @@ -3,7 +3,7 @@ from __future__ import annotations import re import time from concurrent.futures import ThreadPoolExecutor -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta import pytest from fixtures.common_types import Lsn @@ -207,7 +207,7 @@ def test_ts_of_lsn_api(neon_env_builder: NeonEnvBuilder): for i in range(1000): cur.execute("INSERT INTO foo VALUES(%s)", (i,)) # Get the timestamp at UTC - after_timestamp = query_scalar(cur, "SELECT clock_timestamp()").replace(tzinfo=timezone.utc) + after_timestamp = query_scalar(cur, "SELECT clock_timestamp()").replace(tzinfo=UTC) after_lsn = query_scalar(cur, "SELECT pg_current_wal_lsn()") tbl.append([i, after_timestamp, after_lsn]) time.sleep(0.02) @@ -273,11 +273,7 @@ def test_ts_of_lsn_api(neon_env_builder: NeonEnvBuilder): ) log.info("result: %s, after_ts: %s", result, after_timestamp) - # TODO use fromisoformat once we have Python 3.11+ - # which has https://github.com/python/cpython/pull/92177 - timestamp = datetime.strptime(result, "%Y-%m-%dT%H:%M:%S.%f000Z").replace( - tzinfo=timezone.utc - ) + timestamp = datetime.fromisoformat(result).replace(tzinfo=UTC) assert timestamp < after_timestamp, "after_timestamp after timestamp" if i > 1: before_timestamp = tbl[i - step_size][1] diff --git a/test_runner/regress/test_oid_overflow.py b/test_runner/regress/test_oid_overflow.py index f69c1112c7..e2bde8be6f 100644 --- a/test_runner/regress/test_oid_overflow.py +++ b/test_runner/regress/test_oid_overflow.py @@ -39,7 +39,7 @@ def test_oid_overflow(neon_env_builder: NeonEnvBuilder): oid = cur.fetchall()[0][0] log.info(f"t2.relfilenode={oid}") - endpoint.clear_shared_buffers(cursor=cur) + endpoint.clear_buffers(cursor=cur) cur.execute("SELECT x from t1") assert cur.fetchone() == (1,) diff --git a/test_runner/regress/test_ondemand_slru_download.py b/test_runner/regress/test_ondemand_slru_download.py index 5eaba78331..f0f12290cc 100644 --- a/test_runner/regress/test_ondemand_slru_download.py +++ b/test_runner/regress/test_ondemand_slru_download.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Optional - import pytest from fixtures.common_types import Lsn from fixtures.log_helper import log @@ -13,7 +11,7 @@ from fixtures.utils import query_scalar # Test on-demand download of the pg_xact SLRUs # @pytest.mark.parametrize("shard_count", [None, 4]) -def test_ondemand_download_pg_xact(neon_env_builder: NeonEnvBuilder, shard_count: Optional[int]): +def test_ondemand_download_pg_xact(neon_env_builder: NeonEnvBuilder, shard_count: int | None): if shard_count is not None: neon_env_builder.num_pageservers = shard_count @@ -79,7 +77,7 @@ def test_ondemand_download_pg_xact(neon_env_builder: NeonEnvBuilder, shard_count @pytest.mark.parametrize("shard_count", [None, 4]) -def test_ondemand_download_replica(neon_env_builder: NeonEnvBuilder, shard_count: Optional[int]): +def test_ondemand_download_replica(neon_env_builder: NeonEnvBuilder, shard_count: int | None): if shard_count is not None: neon_env_builder.num_pageservers = shard_count diff --git a/test_runner/regress/test_pageserver_api.py b/test_runner/regress/test_pageserver_api.py index d1b70b9ee6..05e81b82e0 100644 --- a/test_runner/regress/test_pageserver_api.py +++ b/test_runner/regress/test_pageserver_api.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Optional - from fixtures.common_types import Lsn, TenantId, TimelineId from fixtures.neon_fixtures import ( DEFAULT_BRANCH_NAME, @@ -82,7 +80,7 @@ def expect_updated_msg_lsn( client: PageserverHttpClient, tenant_id: TenantId, timeline_id: TimelineId, - prev_msg_lsn: Optional[Lsn], + prev_msg_lsn: Lsn | None, ) -> Lsn: timeline_details = client.timeline_detail(tenant_id, timeline_id=timeline_id) diff --git a/test_runner/regress/test_pageserver_generations.py b/test_runner/regress/test_pageserver_generations.py index 4f59efb8b3..6ba5753420 100644 --- a/test_runner/regress/test_pageserver_generations.py +++ b/test_runner/regress/test_pageserver_generations.py @@ -11,11 +11,10 @@ of the pageserver are: from __future__ import annotations -import enum import os import re import time -from typing import TYPE_CHECKING +from enum import StrEnum import pytest from fixtures.common_types import TenantId, TimelineId @@ -41,10 +40,6 @@ from fixtures.remote_storage import ( from fixtures.utils import run_only_on_default_postgres, wait_until from fixtures.workload import Workload -if TYPE_CHECKING: - from typing import Optional - - # A tenant configuration that is convenient for generating uploads and deletions # without a large amount of postgres traffic. TENANT_CONF = { @@ -65,7 +60,7 @@ TENANT_CONF = { def read_all( - env: NeonEnv, tenant_id: Optional[TenantId] = None, timeline_id: Optional[TimelineId] = None + env: NeonEnv, tenant_id: TenantId | None = None, timeline_id: TimelineId | None = None ): if tenant_id is None: tenant_id = env.initial_tenant @@ -286,12 +281,12 @@ def test_deferred_deletion(neon_env_builder: NeonEnvBuilder): assert get_deletion_queue_unexpected_errors(ps_http) == 0 -class KeepAttachment(str, enum.Enum): +class KeepAttachment(StrEnum): KEEP = "keep" LOSE = "lose" -class ValidateBefore(str, enum.Enum): +class ValidateBefore(StrEnum): VALIDATE = "validate" NO_VALIDATE = "no-validate" @@ -464,7 +459,11 @@ def test_emergency_mode(neon_env_builder: NeonEnvBuilder, pg_bin: PgBin): env.pageserver.start() # The pageserver should provide service to clients - generate_uploads_and_deletions(env, init=False, pageserver=env.pageserver) + # Because it is in emergency mode, it will not attempt to validate deletions required by the initial barrier, and therefore + # other files cannot be uploaded b/c it's waiting for the initial barrier to be validated. + generate_uploads_and_deletions( + env, init=False, pageserver=env.pageserver, wait_until_uploaded=False + ) # The pageserver should neither validate nor execute any deletions, it should have # loaded the DeletionLists from before though diff --git a/test_runner/regress/test_pageserver_layer_rolling.py b/test_runner/regress/test_pageserver_layer_rolling.py index 200a323a3a..f6a7bfa1ad 100644 --- a/test_runner/regress/test_pageserver_layer_rolling.py +++ b/test_runner/regress/test_pageserver_layer_rolling.py @@ -2,7 +2,6 @@ from __future__ import annotations import asyncio import time -from typing import TYPE_CHECKING import psutil import pytest @@ -17,17 +16,13 @@ from fixtures.pageserver.http import PageserverHttpClient from fixtures.pageserver.utils import wait_for_last_record_lsn, wait_for_upload from fixtures.utils import skip_in_debug_build, wait_until -if TYPE_CHECKING: - from typing import Optional - - TIMELINE_COUNT = 10 ENTRIES_PER_TIMELINE = 10_000 CHECKPOINT_TIMEOUT_SECONDS = 60 async def run_worker_for_tenant( - env: NeonEnv, entries: int, tenant: TenantId, offset: Optional[int] = None + env: NeonEnv, entries: int, tenant: TenantId, offset: int | None = None ) -> Lsn: if offset is None: offset = 0 @@ -136,7 +131,7 @@ def test_pageserver_small_inmemory_layers( wait_until_pageserver_is_caught_up(env, last_flush_lsns) # We didn't write enough data to trigger a size-based checkpoint: we should see dirty data. - wait_until(10, 1, lambda: assert_dirty_bytes_nonzero(env)) # type: ignore + wait_until(10, 1, lambda: assert_dirty_bytes_nonzero(env)) ps_http_client = env.pageserver.http_client() total_wal_ingested_before_restart = wait_for_wal_ingest_metric(ps_http_client) @@ -144,7 +139,7 @@ def test_pageserver_small_inmemory_layers( # Within ~ the checkpoint interval, all the ephemeral layers should be frozen and flushed, # such that there are zero bytes of ephemeral layer left on the pageserver log.info("Waiting for background checkpoints...") - wait_until(CHECKPOINT_TIMEOUT_SECONDS * 2, 1, lambda: assert_dirty_bytes(env, 0)) # type: ignore + wait_until(CHECKPOINT_TIMEOUT_SECONDS * 2, 1, lambda: assert_dirty_bytes(env, 0)) # Zero ephemeral layer bytes does not imply that all the frozen layers were uploaded: they # must be uploaded to remain visible to the pageserver after restart. @@ -185,7 +180,7 @@ def test_idle_checkpoints(neon_env_builder: NeonEnvBuilder): wait_until_pageserver_is_caught_up(env, last_flush_lsns) # We didn't write enough data to trigger a size-based checkpoint: we should see dirty data. - wait_until(10, 1, lambda: assert_dirty_bytes_nonzero(env)) # type: ignore + wait_until(10, 1, lambda: assert_dirty_bytes_nonzero(env)) # Stop the safekeepers, so that we cannot have any more WAL receiver connections for sk in env.safekeepers: @@ -198,7 +193,7 @@ def test_idle_checkpoints(neon_env_builder: NeonEnvBuilder): # Within ~ the checkpoint interval, all the ephemeral layers should be frozen and flushed, # such that there are zero bytes of ephemeral layer left on the pageserver log.info("Waiting for background checkpoints...") - wait_until(CHECKPOINT_TIMEOUT_SECONDS * 2, 1, lambda: assert_dirty_bytes(env, 0)) # type: ignore + wait_until(CHECKPOINT_TIMEOUT_SECONDS * 2, 1, lambda: assert_dirty_bytes(env, 0)) # The code below verifies that we do not flush on the first write # after an idle period longer than the checkpoint timeout. @@ -215,7 +210,7 @@ def test_idle_checkpoints(neon_env_builder: NeonEnvBuilder): run_worker_for_tenant(env, 5, tenant_with_extra_writes, offset=ENTRIES_PER_TIMELINE) ) - dirty_after_write = wait_until(10, 1, lambda: assert_dirty_bytes_nonzero(env)) # type: ignore + dirty_after_write = wait_until(10, 1, lambda: assert_dirty_bytes_nonzero(env)) # We shouldn't flush since we've just opened a new layer waited_for = 0 @@ -317,4 +312,4 @@ def test_total_size_limit(neon_env_builder: NeonEnvBuilder): dirty_bytes = get_dirty_bytes(env) assert dirty_bytes < max_dirty_data - wait_until(compaction_period_s * 2, 1, lambda: assert_dirty_data_limited()) # type: ignore + wait_until(compaction_period_s * 2, 1, lambda: assert_dirty_data_limited()) diff --git a/test_runner/regress/test_pageserver_restart.py b/test_runner/regress/test_pageserver_restart.py index fb6050689c..4bf5705517 100644 --- a/test_runner/regress/test_pageserver_restart.py +++ b/test_runner/regress/test_pageserver_restart.py @@ -2,7 +2,6 @@ from __future__ import annotations import random from contextlib import closing -from typing import Optional import pytest from fixtures.log_helper import log @@ -156,7 +155,7 @@ def test_pageserver_restart(neon_env_builder: NeonEnvBuilder): @pytest.mark.timeout(540) @pytest.mark.parametrize("shard_count", [None, 4]) @skip_in_debug_build("times out in debug builds") -def test_pageserver_chaos(neon_env_builder: NeonEnvBuilder, shard_count: Optional[int]): +def test_pageserver_chaos(neon_env_builder: NeonEnvBuilder, shard_count: int | None): # same rationale as with the immediate stop; we might leave orphan layers behind. neon_env_builder.disable_scrub_on_exit() neon_env_builder.enable_pageserver_remote_storage(s3_storage()) diff --git a/test_runner/regress/test_pageserver_secondary.py b/test_runner/regress/test_pageserver_secondary.py index 12134048e6..a264f4d3c9 100644 --- a/test_runner/regress/test_pageserver_secondary.py +++ b/test_runner/regress/test_pageserver_secondary.py @@ -23,7 +23,7 @@ from werkzeug.wrappers.request import Request from werkzeug.wrappers.response import Response if TYPE_CHECKING: - from typing import Any, Optional, Union + from typing import Any # A tenant configuration that is convenient for generating uploads and deletions @@ -199,7 +199,7 @@ def test_location_conf_churn(neon_env_builder: NeonEnvBuilder, make_httpserver, # state if it was running attached with a stale generation last_state[pageserver.id] = ("Detached", None) else: - secondary_conf: Optional[dict[str, Any]] = None + secondary_conf: dict[str, Any] | None = None if mode == "Secondary": secondary_conf = {"warm": rng.choice([True, False])} @@ -469,7 +469,7 @@ def test_heatmap_uploads(neon_env_builder: NeonEnvBuilder): def list_elegible_layers( - pageserver, tenant_id: Union[TenantId, TenantShardId], timeline_id: TimelineId + pageserver, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId ) -> list[Path]: """ The subset of layer filenames that are elegible for secondary download: at time of writing this @@ -702,7 +702,7 @@ def test_secondary_background_downloads(neon_env_builder: NeonEnvBuilder): else: timeout = int(deadline - now) + 1 try: - wait_until(timeout, 1, lambda: pageserver.assert_log_contains(expression)) # type: ignore + wait_until(timeout, 1, lambda: pageserver.assert_log_contains(expression)) except: log.error(f"Timed out waiting for '{expression}'") raise diff --git a/test_runner/regress/test_pg_regress.py b/test_runner/regress/test_pg_regress.py index 6a5e388c53..2877f14e0e 100644 --- a/test_runner/regress/test_pg_regress.py +++ b/test_runner/regress/test_pg_regress.py @@ -21,8 +21,6 @@ from fixtures.remote_storage import s3_storage from fixtures.utils import skip_in_debug_build if TYPE_CHECKING: - from typing import Optional - from fixtures.neon_fixtures import PgBin from pytest import CaptureFixture @@ -48,7 +46,7 @@ def post_checks(env: NeonEnv, test_output_dir: Path, db_name: str, endpoint: End data properly. """ - ignored_files: Optional[list[str]] = None + ignored_files: list[str] | None = None # Neon handles unlogged relations in a special manner. During a # basebackup, we ship the init fork as the main fork. This presents a @@ -131,7 +129,7 @@ def test_pg_regress( capsys: CaptureFixture[str], base_dir: Path, pg_distrib_dir: Path, - shard_count: Optional[int], + shard_count: int | None, ): DBNAME = "regression" @@ -205,7 +203,7 @@ def test_isolation( capsys: CaptureFixture[str], base_dir: Path, pg_distrib_dir: Path, - shard_count: Optional[int], + shard_count: int | None, ): DBNAME = "isolation_regression" @@ -274,7 +272,7 @@ def test_sql_regress( capsys: CaptureFixture[str], base_dir: Path, pg_distrib_dir: Path, - shard_count: Optional[int], + shard_count: int | None, ): DBNAME = "regression" diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index e59d46e352..5a01d90d85 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -13,7 +13,7 @@ import requests from fixtures.neon_fixtures import PSQL, NeonProxy, VanillaPostgres if TYPE_CHECKING: - from typing import Any, Optional + from typing import Any GET_CONNECTION_PID_QUERY = "SELECT pid FROM pg_stat_activity WHERE state = 'active'" @@ -228,7 +228,7 @@ def test_sql_over_http_serverless_driver(static_proxy: NeonProxy): def test_sql_over_http(static_proxy: NeonProxy): static_proxy.safe_psql("create role http with login password 'http' superuser") - def q(sql: str, params: Optional[list[Any]] = None) -> Any: + def q(sql: str, params: list[Any] | None = None) -> Any: params = params or [] connstr = f"postgresql://http:http@{static_proxy.domain}:{static_proxy.proxy_port}/postgres" response = requests.post( @@ -291,7 +291,7 @@ def test_sql_over_http_db_name_with_space(static_proxy: NeonProxy): ) ) - def q(sql: str, params: Optional[list[Any]] = None) -> Any: + def q(sql: str, params: list[Any] | None = None) -> Any: params = params or [] connstr = f"postgresql://http:http@{static_proxy.domain}:{static_proxy.proxy_port}/{urllib.parse.quote(db)}" response = requests.post( @@ -310,7 +310,7 @@ def test_sql_over_http_db_name_with_space(static_proxy: NeonProxy): def test_sql_over_http_output_options(static_proxy: NeonProxy): static_proxy.safe_psql("create role http2 with login password 'http2' superuser") - def q(sql: str, raw_text: bool, array_mode: bool, params: Optional[list[Any]] = None) -> Any: + def q(sql: str, raw_text: bool, array_mode: bool, params: list[Any] | None = None) -> Any: params = params or [] connstr = ( f"postgresql://http2:http2@{static_proxy.domain}:{static_proxy.proxy_port}/postgres" @@ -346,7 +346,7 @@ def test_sql_over_http_batch(static_proxy: NeonProxy): static_proxy.safe_psql("create role http with login password 'http' superuser") def qq( - queries: list[tuple[str, Optional[list[Any]]]], + queries: list[tuple[str, list[Any] | None]], read_only: bool = False, deferrable: bool = False, ) -> Any: diff --git a/test_runner/regress/test_read_validation.py b/test_runner/regress/test_read_validation.py index 471a3b406a..70a7a675df 100644 --- a/test_runner/regress/test_read_validation.py +++ b/test_runner/regress/test_read_validation.py @@ -54,7 +54,7 @@ def test_read_validation(neon_simple_env: NeonEnv): log.info("Clear buffer cache to ensure no stale pages are brought into the cache") - endpoint.clear_shared_buffers(cursor=c) + endpoint.clear_buffers(cursor=c) cache_entries = query_scalar( c, f"select count(*) from pg_buffercache where relfilenode = {relfilenode}" diff --git a/test_runner/regress/test_readonly_node.py b/test_runner/regress/test_readonly_node.py index 826136d5f9..70d558ac5a 100644 --- a/test_runner/regress/test_readonly_node.py +++ b/test_runner/regress/test_readonly_node.py @@ -1,7 +1,6 @@ from __future__ import annotations import time -from typing import Union import pytest from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId @@ -122,7 +121,6 @@ def test_readonly_node(neon_simple_env: NeonEnv): ) -@pytest.mark.skip("See https://github.com/neondatabase/neon/issues/9754") def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder): """ Test static endpoint is protected from GC by acquiring and renewing lsn leases. @@ -175,7 +173,7 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder): def get_layers_protected_by_lease( ps_http: PageserverHttpClient, - tenant_id: Union[TenantId, TenantShardId], + tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, lease_lsn: Lsn, ) -> set[str]: @@ -232,7 +230,7 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder): return offset # Insert some records on main branch - with env.endpoints.create_start("main") as ep_main: + with env.endpoints.create_start("main", config_lines=["shared_buffers=1MB"]) as ep_main: with ep_main.cursor() as cur: cur.execute("CREATE TABLE t0(v0 int primary key, v1 text)") lsn = Lsn(0) diff --git a/test_runner/regress/test_remote_storage.py b/test_runner/regress/test_remote_storage.py index 79b5ebe39a..137e75f784 100644 --- a/test_runner/regress/test_remote_storage.py +++ b/test_runner/regress/test_remote_storage.py @@ -5,7 +5,6 @@ import queue import shutil import threading import time -from typing import TYPE_CHECKING import pytest from fixtures.common_types import Lsn, TenantId, TimelineId @@ -37,9 +36,6 @@ from fixtures.utils import ( ) from requests import ReadTimeout -if TYPE_CHECKING: - from typing import Optional - # # Tests that a piece of data is backed up and restored correctly: @@ -452,7 +448,7 @@ def test_remote_timeline_client_calls_started_metric( for (file_kind, op_kind), observations in calls_started.items(): log.info(f"ensure_calls_started_grew: {file_kind} {op_kind}: {observations}") assert all( - x < y for x, y in zip(observations, observations[1:]) + x < y for x, y in zip(observations, observations[1:], strict=False) ), f"observations for {file_kind} {op_kind} did not grow monotonically: {observations}" def churn(data_pass1, data_pass2): @@ -731,7 +727,7 @@ def test_empty_branch_remote_storage_upload_on_restart(neon_env_builder: NeonEnv # sleep a bit to force the upload task go into exponential backoff time.sleep(1) - q: queue.Queue[Optional[PageserverApiException]] = queue.Queue() + q: queue.Queue[PageserverApiException | None] = queue.Queue() barrier = threading.Barrier(2) def create_in_background(): diff --git a/test_runner/regress/test_s3_restore.py b/test_runner/regress/test_s3_restore.py index 7a9e6d62b2..8764da3c2f 100644 --- a/test_runner/regress/test_s3_restore.py +++ b/test_runner/regress/test_s3_restore.py @@ -1,7 +1,7 @@ from __future__ import annotations import time -from datetime import datetime, timezone +from datetime import UTC, datetime from fixtures.common_types import Lsn from fixtures.log_helper import log @@ -77,7 +77,7 @@ def test_tenant_s3_restore( # These sleeps are important because they fend off differences in clocks between us and S3 time.sleep(4) - ts_before_deletion = datetime.now(tz=timezone.utc).replace(tzinfo=None) + ts_before_deletion = datetime.now(tz=UTC).replace(tzinfo=None) time.sleep(4) assert ( @@ -104,7 +104,7 @@ def test_tenant_s3_restore( ) time.sleep(4) - ts_after_deletion = datetime.now(tz=timezone.utc).replace(tzinfo=None) + ts_after_deletion = datetime.now(tz=UTC).replace(tzinfo=None) time.sleep(4) ps_http.tenant_time_travel_remote_storage( diff --git a/test_runner/regress/test_sharding.py b/test_runner/regress/test_sharding.py index 3194fe6ec4..411574bd86 100644 --- a/test_runner/regress/test_sharding.py +++ b/test_runner/regress/test_sharding.py @@ -3,7 +3,7 @@ from __future__ import annotations import os import time from collections import defaultdict -from typing import TYPE_CHECKING, Any +from typing import Any import pytest import requests @@ -27,9 +27,6 @@ from typing_extensions import override from werkzeug.wrappers.request import Request from werkzeug.wrappers.response import Response -if TYPE_CHECKING: - from typing import Optional, Union - def test_sharding_smoke( neon_env_builder: NeonEnvBuilder, @@ -189,7 +186,7 @@ def test_sharding_split_unsharded( ], ) def test_sharding_split_compaction( - neon_env_builder: NeonEnvBuilder, failpoint: Optional[str], build_type: str + neon_env_builder: NeonEnvBuilder, failpoint: str | None, build_type: str ): """ Test that after a split, we clean up parent layer data in the child shards via compaction. @@ -782,7 +779,7 @@ def test_sharding_split_stripe_size( tenant_id = env.initial_tenant assert len(notifications) == 1 - expect: dict[str, Union[list[dict[str, int]], str, None, int]] = { + expect: dict[str, list[dict[str, int]] | str | None | int] = { "tenant_id": str(env.initial_tenant), "stripe_size": None, "shards": [{"node_id": int(env.pageservers[0].id), "shard_number": 0}], @@ -798,7 +795,7 @@ def test_sharding_split_stripe_size( # Check that we ended up with the stripe size that we expected, both on the pageserver # and in the notifications to compute assert len(notifications) == 2 - expect_after: dict[str, Union[list[dict[str, int]], str, None, int]] = { + expect_after: dict[str, list[dict[str, int]] | str | None | int] = { "tenant_id": str(env.initial_tenant), "stripe_size": new_stripe_size, "shards": [ @@ -1046,7 +1043,7 @@ def test_sharding_ingest_gaps( class Failure: - pageserver_id: Optional[int] + pageserver_id: int | None def apply(self, env: NeonEnv): raise NotImplementedError() @@ -1370,7 +1367,7 @@ def test_sharding_split_failures( assert attached_count == initial_shard_count - def assert_split_done(exclude_ps_id: Optional[int] = None) -> None: + def assert_split_done(exclude_ps_id: int | None = None) -> None: secondary_count = 0 attached_count = 0 for ps in env.pageservers: @@ -1408,7 +1405,7 @@ def test_sharding_split_failures( # e.g. while waiting for a storage controller to re-attach a parent shard if we failed # inside the pageserver and the storage controller responds by detaching children and attaching # parents concurrently (https://github.com/neondatabase/neon/issues/7148) - wait_until(10, 1, lambda: workload.churn_rows(10, upload=False, ingest=False)) # type: ignore + wait_until(10, 1, lambda: workload.churn_rows(10, upload=False, ingest=False)) workload.validate() diff --git a/test_runner/regress/test_sni_router.py b/test_runner/regress/test_sni_router.py index 402f27b384..2a26fef59a 100644 --- a/test_runner/regress/test_sni_router.py +++ b/test_runner/regress/test_sni_router.py @@ -3,7 +3,6 @@ from __future__ import annotations import socket import subprocess from pathlib import Path -from types import TracebackType from typing import TYPE_CHECKING import backoff @@ -12,7 +11,8 @@ from fixtures.neon_fixtures import PgProtocol, VanillaPostgres from fixtures.port_distributor import PortDistributor if TYPE_CHECKING: - from typing import Optional + from types import TracebackType + from typing import Self def generate_tls_cert(cn, certout, keyout): @@ -55,10 +55,10 @@ class PgSniRouter(PgProtocol): self.destination = destination self.tls_cert = tls_cert self.tls_key = tls_key - self._popen: Optional[subprocess.Popen[bytes]] = None + self._popen: subprocess.Popen[bytes] | None = None self.test_output_dir = test_output_dir - def start(self) -> PgSniRouter: + def start(self) -> Self: assert self._popen is None args = [ str(self.neon_binpath / "pg_sni_router"), @@ -91,14 +91,14 @@ class PgSniRouter(PgProtocol): if self._popen: self._popen.wait(timeout=2) - def __enter__(self) -> PgSniRouter: + def __enter__(self) -> Self: return self def __exit__( self, - exc_type: Optional[type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, ): if self._popen is not None: self._popen.terminate() diff --git a/test_runner/regress/test_storage_controller.py b/test_runner/regress/test_storage_controller.py index 2c3d79b18a..13bc54a114 100644 --- a/test_runner/regress/test_storage_controller.py +++ b/test_runner/regress/test_storage_controller.py @@ -5,7 +5,7 @@ import json import threading import time from collections import defaultdict -from datetime import datetime, timezone +from datetime import UTC, datetime from enum import Enum from typing import TYPE_CHECKING @@ -56,7 +56,7 @@ from werkzeug.wrappers.request import Request from werkzeug.wrappers.response import Response if TYPE_CHECKING: - from typing import Any, Optional, Union + from typing import Any def get_node_shard_counts(env: NeonEnv, tenant_ids): @@ -593,7 +593,7 @@ def test_storage_controller_compute_hook( # Initial notification from tenant creation assert len(notifications) == 1 - expect: dict[str, Union[list[dict[str, int]], str, None, int]] = { + expect: dict[str, list[dict[str, int]] | str | None | int] = { "tenant_id": str(env.initial_tenant), "stripe_size": None, "shards": [{"node_id": int(env.pageservers[0].id), "shard_number": 0}], @@ -708,7 +708,7 @@ def test_storage_controller_stuck_compute_hook( # Initial notification from tenant creation assert len(notifications) == 1 - expect: dict[str, Union[list[dict[str, int]], str, None, int]] = { + expect: dict[str, list[dict[str, int]] | str | None | int] = { "tenant_id": str(env.initial_tenant), "stripe_size": None, "shards": [{"node_id": int(env.pageservers[0].id), "shard_number": 0}], @@ -1048,7 +1048,7 @@ def test_storage_controller_s3_time_travel_recovery( ) time.sleep(4) - ts_before_disaster = datetime.now(tz=timezone.utc).replace(tzinfo=None) + ts_before_disaster = datetime.now(tz=UTC).replace(tzinfo=None) time.sleep(4) # Simulate a "disaster": delete some random files from remote storage for one of the shards @@ -1072,7 +1072,7 @@ def test_storage_controller_s3_time_travel_recovery( pass time.sleep(4) - ts_after_disaster = datetime.now(tz=timezone.utc).replace(tzinfo=None) + ts_after_disaster = datetime.now(tz=UTC).replace(tzinfo=None) time.sleep(4) # Do time travel recovery @@ -2274,7 +2274,7 @@ def test_storage_controller_node_deletion( @pytest.mark.parametrize("shard_count", [None, 2]) def test_storage_controller_metadata_health( neon_env_builder: NeonEnvBuilder, - shard_count: Optional[int], + shard_count: int | None, ): """ Create three tenants A, B, C. @@ -2494,14 +2494,14 @@ def start_env(env: NeonEnv, storage_controller_port: int): for pageserver in env.pageservers: futs.append( executor.submit( - lambda ps=pageserver: ps.start(timeout_in_seconds=timeout_in_seconds) + lambda ps=pageserver: ps.start(timeout_in_seconds=timeout_in_seconds) # type: ignore[misc] ) ) for safekeeper in env.safekeepers: futs.append( executor.submit( - lambda sk=safekeeper: sk.start(timeout_in_seconds=timeout_in_seconds) + lambda sk=safekeeper: sk.start(timeout_in_seconds=timeout_in_seconds) # type: ignore[misc] ) ) diff --git a/test_runner/regress/test_storage_scrubber.py b/test_runner/regress/test_storage_scrubber.py index 11ad2173ae..3991bd7061 100644 --- a/test_runner/regress/test_storage_scrubber.py +++ b/test_runner/regress/test_storage_scrubber.py @@ -6,7 +6,6 @@ import shutil import threading import time from concurrent.futures import ThreadPoolExecutor -from typing import TYPE_CHECKING import pytest from fixtures.common_types import TenantId, TenantShardId, TimelineId @@ -20,12 +19,9 @@ from fixtures.remote_storage import S3Storage, s3_storage from fixtures.utils import wait_until from fixtures.workload import Workload -if TYPE_CHECKING: - from typing import Optional - @pytest.mark.parametrize("shard_count", [None, 4]) -def test_scrubber_tenant_snapshot(neon_env_builder: NeonEnvBuilder, shard_count: Optional[int]): +def test_scrubber_tenant_snapshot(neon_env_builder: NeonEnvBuilder, shard_count: int | None): """ Test the `tenant-snapshot` subcommand, which grabs data from remote storage @@ -131,7 +127,7 @@ def drop_local_state(env: NeonEnv, tenant_id: TenantId): @pytest.mark.parametrize("shard_count", [None, 4]) -def test_scrubber_physical_gc(neon_env_builder: NeonEnvBuilder, shard_count: Optional[int]): +def test_scrubber_physical_gc(neon_env_builder: NeonEnvBuilder, shard_count: int | None): neon_env_builder.enable_pageserver_remote_storage(s3_storage()) neon_env_builder.num_pageservers = 2 @@ -179,9 +175,7 @@ def test_scrubber_physical_gc(neon_env_builder: NeonEnvBuilder, shard_count: Opt @pytest.mark.parametrize("shard_count", [None, 2]) -def test_scrubber_physical_gc_ancestors( - neon_env_builder: NeonEnvBuilder, shard_count: Optional[int] -): +def test_scrubber_physical_gc_ancestors(neon_env_builder: NeonEnvBuilder, shard_count: int | None): neon_env_builder.enable_pageserver_remote_storage(s3_storage()) neon_env_builder.num_pageservers = 2 @@ -499,7 +493,7 @@ def test_scrubber_physical_gc_ancestors_split(neon_env_builder: NeonEnvBuilder): @pytest.mark.parametrize("shard_count", [None, 4]) def test_scrubber_scan_pageserver_metadata( - neon_env_builder: NeonEnvBuilder, shard_count: Optional[int] + neon_env_builder: NeonEnvBuilder, shard_count: int | None ): """ Create some layers. Delete an object listed in index. Run scrubber and see if it detects the defect. diff --git a/test_runner/regress/test_subxacts.py b/test_runner/regress/test_subxacts.py index 7a46f0140c..b235da0bc7 100644 --- a/test_runner/regress/test_subxacts.py +++ b/test_runner/regress/test_subxacts.py @@ -1,6 +1,11 @@ from __future__ import annotations -from fixtures.neon_fixtures import NeonEnv, check_restored_datadir_content +import pytest +from fixtures.neon_fixtures import ( + NeonEnvBuilder, + PageserverWalReceiverProtocol, + check_restored_datadir_content, +) # Test subtransactions @@ -9,8 +14,14 @@ from fixtures.neon_fixtures import NeonEnv, check_restored_datadir_content # maintained in the pageserver, so subtransactions are not very exciting for # Neon. They are included in the commit record though and updated in the # CLOG. -def test_subxacts(neon_simple_env: NeonEnv, test_output_dir): - env = neon_simple_env +@pytest.mark.parametrize( + "wal_receiver_protocol", + [PageserverWalReceiverProtocol.VANILLA, PageserverWalReceiverProtocol.INTERPRETED], +) +def test_subxacts(neon_env_builder: NeonEnvBuilder, test_output_dir, wal_receiver_protocol): + neon_env_builder.pageserver_wal_receiver_protocol = wal_receiver_protocol + + env = neon_env_builder.init_start() endpoint = env.endpoints.create_start("main") pg_conn = endpoint.connect() diff --git a/test_runner/regress/test_tenant_detach.py b/test_runner/regress/test_tenant_detach.py index 59c14b3263..8d7ca7bc4e 100644 --- a/test_runner/regress/test_tenant_detach.py +++ b/test_runner/regress/test_tenant_detach.py @@ -1,11 +1,10 @@ from __future__ import annotations import asyncio -import enum import random import time +from enum import StrEnum from threading import Thread -from typing import TYPE_CHECKING import asyncpg import pytest @@ -28,10 +27,6 @@ from fixtures.remote_storage import ( from fixtures.utils import query_scalar, wait_until from prometheus_client.samples import Sample -if TYPE_CHECKING: - from typing import Optional - - # In tests that overlap endpoint activity with tenant attach/detach, there are # a variety of warnings that the page service may emit when it cannot acquire # an active tenant to serve a request @@ -57,7 +52,7 @@ def do_gc_target( log.info("gc http thread returning") -class ReattachMode(str, enum.Enum): +class ReattachMode(StrEnum): REATTACH_EXPLICIT = "explicit" REATTACH_RESET = "reset" REATTACH_RESET_DROP = "reset_drop" @@ -498,7 +493,7 @@ def test_metrics_while_ignoring_broken_tenant_and_reloading( r".* Changing Active tenant to Broken state, reason: broken from test" ) - def only_int(samples: list[Sample]) -> Optional[int]: + def only_int(samples: list[Sample]) -> int | None: if len(samples) == 1: return int(samples[0].value) assert len(samples) == 0 diff --git a/test_runner/regress/test_tenant_relocation.py b/test_runner/regress/test_tenant_relocation.py index fc9adb14c9..bf6120aa0a 100644 --- a/test_runner/regress/test_tenant_relocation.py +++ b/test_runner/regress/test_tenant_relocation.py @@ -28,7 +28,7 @@ from fixtures.utils import ( ) if TYPE_CHECKING: - from typing import Any, Optional + from typing import Any def assert_abs_margin_ratio(a: float, b: float, margin_ratio: float): @@ -78,7 +78,7 @@ def populate_branch( tenant_id: TenantId, ps_http: PageserverHttpClient, create_table: bool, - expected_sum: Optional[int], + expected_sum: int | None, ) -> tuple[TimelineId, Lsn]: # insert some data with pg_cur(endpoint) as cur: diff --git a/test_runner/regress/test_timeline_archive.py b/test_runner/regress/test_timeline_archive.py index 0650f12cd1..bc2e048f69 100644 --- a/test_runner/regress/test_timeline_archive.py +++ b/test_runner/regress/test_timeline_archive.py @@ -4,7 +4,6 @@ import json import random import threading import time -from typing import Optional import pytest import requests @@ -661,7 +660,7 @@ def test_timeline_archival_chaos(neon_env_builder: NeonEnvBuilder): ], ) def test_timeline_retain_lsn( - neon_env_builder: NeonEnvBuilder, with_intermediary: bool, offload_child: Optional[str] + neon_env_builder: NeonEnvBuilder, with_intermediary: bool, offload_child: str | None ): """ Ensure that retain_lsn functionality for timelines works, both for offloaded and non-offloaded ones diff --git a/test_runner/regress/test_timeline_detach_ancestor.py b/test_runner/regress/test_timeline_detach_ancestor.py index ef0eb05612..9c7e851ba8 100644 --- a/test_runner/regress/test_timeline_detach_ancestor.py +++ b/test_runner/regress/test_timeline_detach_ancestor.py @@ -5,6 +5,7 @@ import enum import threading import time from concurrent.futures import ThreadPoolExecutor +from enum import StrEnum from queue import Empty, Queue from threading import Barrier @@ -22,7 +23,8 @@ from fixtures.neon_fixtures import ( from fixtures.pageserver.http import HistoricLayerInfo, PageserverApiException from fixtures.pageserver.utils import wait_for_last_record_lsn, wait_timeline_detail_404 from fixtures.remote_storage import LocalFsStorage, RemoteStorageKind -from fixtures.utils import assert_pageserver_backups_equal, wait_until +from fixtures.utils import assert_pageserver_backups_equal, skip_in_debug_build, wait_until +from fixtures.workload import Workload from requests import ReadTimeout @@ -36,7 +38,7 @@ def layer_name(info: HistoricLayerInfo) -> str: @enum.unique -class Branchpoint(str, enum.Enum): +class Branchpoint(StrEnum): """ Have branches at these Lsns possibly relative to L0 layer boundary. """ @@ -414,7 +416,7 @@ def test_detached_receives_flushes_while_being_detached(neon_env_builder: NeonEn assert client.timeline_detail(env.initial_tenant, timeline_id)["ancestor_timeline_id"] is None - ep.clear_shared_buffers() + ep.clear_buffers() assert ep.safe_psql("SELECT count(*) FROM foo;")[0][0] == rows assert ep.safe_psql("SELECT SUM(LENGTH(aux)) FROM foo")[0][0] != 0 ep.stop() @@ -1549,6 +1551,57 @@ def test_timeline_is_deleted_before_timeline_detach_ancestor_completes( env.pageserver.assert_log_contains(".* gc_loop.*: 1 timelines need GC", offset) +@skip_in_debug_build("only run with release build") +def test_pageserver_compaction_detach_ancestor_smoke(neon_env_builder: NeonEnvBuilder): + SMOKE_CONF = { + # Run both gc and gc-compaction. + "gc_period": "5s", + "compaction_period": "5s", + # No PiTR interval and small GC horizon + "pitr_interval": "0s", + "gc_horizon": f"{1024 ** 2}", + "lsn_lease_length": "0s", + # Small checkpoint distance to create many layers + "checkpoint_distance": 1024**2, + # Compact small layers + "compaction_target_size": 1024**2, + "image_creation_threshold": 2, + } + + env = neon_env_builder.init_start(initial_tenant_conf=SMOKE_CONF) + + tenant_id = env.initial_tenant + timeline_id = env.initial_timeline + + row_count = 10000 + churn_rounds = 50 + + ps_http = env.pageserver.http_client() + + workload_parent = Workload(env, tenant_id, timeline_id) + workload_parent.init(env.pageserver.id) + log.info("Writing initial data ...") + workload_parent.write_rows(row_count, env.pageserver.id) + branch_id = env.create_branch("child") + workload_child = Workload(env, tenant_id, branch_id, branch_name="child") + workload_child.init(env.pageserver.id, allow_recreate=True) + log.info("Writing initial data on child...") + workload_child.write_rows(row_count, env.pageserver.id) + + for i in range(1, churn_rounds + 1): + if i % 10 == 0: + log.info(f"Running churn round {i}/{churn_rounds} ...") + + workload_parent.churn_rows(row_count, env.pageserver.id) + workload_child.churn_rows(row_count, env.pageserver.id) + + ps_http.detach_ancestor(tenant_id, branch_id) + + log.info("Validating at workload end ...") + workload_parent.validate(env.pageserver.id) + workload_child.validate(env.pageserver.id) + + # TODO: # - branch near existing L1 boundary, image layers? # - investigate: why are layers started at uneven lsn? not just after branching, but in general. diff --git a/test_runner/regress/test_timeline_gc_blocking.py b/test_runner/regress/test_timeline_gc_blocking.py index c19c78e251..5a5ca3290a 100644 --- a/test_runner/regress/test_timeline_gc_blocking.py +++ b/test_runner/regress/test_timeline_gc_blocking.py @@ -3,7 +3,6 @@ from __future__ import annotations import time from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from typing import TYPE_CHECKING import pytest from fixtures.log_helper import log @@ -14,9 +13,6 @@ from fixtures.neon_fixtures import ( ) from fixtures.pageserver.utils import wait_timeline_detail_404 -if TYPE_CHECKING: - from typing import Optional - @pytest.mark.parametrize("sharded", [True, False]) def test_gc_blocking_by_timeline(neon_env_builder: NeonEnvBuilder, sharded: bool): @@ -89,7 +85,7 @@ def wait_for_another_gc_round(): @dataclass class ScrollableLog: pageserver: NeonPageserver - offset: Optional[LogCursor] + offset: LogCursor | None def assert_log_contains(self, what: str): msg, offset = self.pageserver.assert_log_contains(what, offset=self.offset) diff --git a/test_runner/regress/test_timeline_size.py b/test_runner/regress/test_timeline_size.py index 85c6d17142..4528bc6180 100644 --- a/test_runner/regress/test_timeline_size.py +++ b/test_runner/regress/test_timeline_size.py @@ -7,7 +7,6 @@ import time from collections import defaultdict from contextlib import closing from pathlib import Path -from typing import Optional import psycopg2.errors import psycopg2.extras @@ -668,7 +667,7 @@ def test_tenant_physical_size(neon_env_builder: NeonEnvBuilder): class TimelinePhysicalSizeValues: api_current_physical: int prometheus_resident_physical: float - prometheus_remote_physical: Optional[float] = None + prometheus_remote_physical: float | None = None python_timelinedir_layerfiles_physical: int layer_map_file_size_sum: int diff --git a/test_runner/regress/test_vm_bits.py b/test_runner/regress/test_vm_bits.py index d4c2ca7e07..f93fc6bd8b 100644 --- a/test_runner/regress/test_vm_bits.py +++ b/test_runner/regress/test_vm_bits.py @@ -63,7 +63,7 @@ def test_vm_bit_clear(neon_simple_env: NeonEnv): # Clear the buffer cache, to force the VM page to be re-fetched from # the page server - endpoint.clear_shared_buffers(cursor=cur) + endpoint.clear_buffers(cursor=cur) # Check that an index-only scan doesn't see the deleted row. If the # clearing of the VM bit was not replayed correctly, this would incorrectly diff --git a/test_runner/regress/test_wal_acceptor.py b/test_runner/regress/test_wal_acceptor.py index 6eaaa3c37f..8fa33b81a9 100644 --- a/test_runner/regress/test_wal_acceptor.py +++ b/test_runner/regress/test_wal_acceptor.py @@ -61,7 +61,7 @@ from fixtures.utils import ( ) if TYPE_CHECKING: - from typing import Any, Optional + from typing import Any, Self def wait_lsn_force_checkpoint( @@ -189,7 +189,7 @@ def test_many_timelines(neon_env_builder: NeonEnvBuilder): m.flush_lsns.append(Lsn(int(sk_m.flush_lsn_inexact(tenant_id, timeline_id)))) m.commit_lsns.append(Lsn(int(sk_m.commit_lsn_inexact(tenant_id, timeline_id)))) - for flush_lsn, commit_lsn in zip(m.flush_lsns, m.commit_lsns): + for flush_lsn, commit_lsn in zip(m.flush_lsns, m.commit_lsns, strict=False): # Invariant. May be < when transaction is in progress. assert ( commit_lsn <= flush_lsn @@ -224,7 +224,7 @@ def test_many_timelines(neon_env_builder: NeonEnvBuilder): def __init__(self) -> None: super().__init__(daemon=True) self.should_stop = threading.Event() - self.exception: Optional[BaseException] = None + self.exception: BaseException | None = None def run(self) -> None: try: @@ -521,7 +521,7 @@ def test_wal_backup(neon_env_builder: NeonEnvBuilder): # Shut down subsequently each of safekeepers and fill a segment while sk is # down; ensure segment gets offloaded by others. offloaded_seg_end = [Lsn("0/2000000"), Lsn("0/3000000"), Lsn("0/4000000")] - for victim, seg_end in zip(env.safekeepers, offloaded_seg_end): + for victim, seg_end in zip(env.safekeepers, offloaded_seg_end, strict=False): victim.stop() # roughly fills one segment cur.execute("insert into t select generate_series(1,250000), 'payload'") @@ -666,7 +666,7 @@ def test_s3_wal_replay(neon_env_builder: NeonEnvBuilder): # recreate timeline on pageserver from scratch ps_http.timeline_create( - pg_version=PgVersion(pg_version), + pg_version=PgVersion(str(pg_version)), tenant_id=tenant_id, new_timeline_id=timeline_id, ) @@ -1177,14 +1177,14 @@ def cmp_sk_wal(sks: list[Safekeeper], tenant_id: TenantId, timeline_id: Timeline # report/understand if WALs are different due to that. statuses = [sk_http_cli.timeline_status(tenant_id, timeline_id) for sk_http_cli in sk_http_clis] term_flush_lsns = [(s.last_log_term, s.flush_lsn) for s in statuses] - for tfl, sk in zip(term_flush_lsns[1:], sks[1:]): + for tfl, sk in zip(term_flush_lsns[1:], sks[1:], strict=False): assert ( term_flush_lsns[0] == tfl ), f"(last_log_term, flush_lsn) are not equal on sks {sks[0].id} and {sk.id}: {term_flush_lsns[0]} != {tfl}" # check that WALs are identic. segs = [sk.list_segments(tenant_id, timeline_id) for sk in sks] - for cmp_segs, sk in zip(segs[1:], sks[1:]): + for cmp_segs, sk in zip(segs[1:], sks[1:], strict=False): assert ( segs[0] == cmp_segs ), f"lists of segments on sks {sks[0].id} and {sk.id} are not identic: {segs[0]} and {cmp_segs}" @@ -1455,12 +1455,12 @@ class SafekeeperEnv: self.pg_bin = pg_bin self.num_safekeepers = num_safekeepers self.bin_safekeeper = str(neon_binpath / "safekeeper") - self.safekeepers: Optional[list[subprocess.CompletedProcess[Any]]] = None - self.postgres: Optional[ProposerPostgres] = None - self.tenant_id: Optional[TenantId] = None - self.timeline_id: Optional[TimelineId] = None + self.safekeepers: list[subprocess.CompletedProcess[Any]] | None = None + self.postgres: ProposerPostgres | None = None + self.tenant_id: TenantId | None = None + self.timeline_id: TimelineId | None = None - def init(self) -> SafekeeperEnv: + def init(self) -> Self: assert self.postgres is None, "postgres is already initialized" assert self.safekeepers is None, "safekeepers are already initialized" @@ -1541,7 +1541,7 @@ class SafekeeperEnv: log.info(f"Killing safekeeper with pid {pid}") os.kill(pid, signal.SIGKILL) - def __enter__(self): + def __enter__(self) -> Self: return self def __exit__(self, exc_type, exc_value, traceback): @@ -2446,7 +2446,7 @@ def test_broker_discovery(neon_env_builder: NeonEnvBuilder): # generate some data to commit WAL on safekeepers endpoint.safe_psql("insert into t select generate_series(1,100), 'action'") # clear the buffers - endpoint.clear_shared_buffers() + endpoint.clear_buffers() # read data to fetch pages from pageserver endpoint.safe_psql("select sum(i) from t") diff --git a/test_runner/regress/test_wal_acceptor_async.py b/test_runner/regress/test_wal_acceptor_async.py index d3e989afa8..b32b028fa1 100644 --- a/test_runner/regress/test_wal_acceptor_async.py +++ b/test_runner/regress/test_wal_acceptor_async.py @@ -5,21 +5,22 @@ import random import time from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING import asyncpg import pytest import toml from fixtures.common_types import Lsn, TenantId, TimelineId from fixtures.log_helper import getLogger -from fixtures.neon_fixtures import Endpoint, NeonEnv, NeonEnvBuilder, Safekeeper +from fixtures.neon_fixtures import ( + Endpoint, + NeonEnv, + NeonEnvBuilder, + PageserverWalReceiverProtocol, + Safekeeper, +) from fixtures.remote_storage import RemoteStorageKind from fixtures.utils import skip_in_debug_build -if TYPE_CHECKING: - from typing import Optional - - log = getLogger("root.safekeeper_async") @@ -261,7 +262,7 @@ def test_restarts_frequent_checkpoints(neon_env_builder: NeonEnvBuilder): def endpoint_create_start( - env: NeonEnv, branch: str, pgdir_name: Optional[str], allow_multiple: bool = False + env: NeonEnv, branch: str, pgdir_name: str | None, allow_multiple: bool = False ): endpoint = Endpoint( env, @@ -287,7 +288,7 @@ async def exec_compute_query( env: NeonEnv, branch: str, query: str, - pgdir_name: Optional[str] = None, + pgdir_name: str | None = None, allow_multiple: bool = False, ): with endpoint_create_start( @@ -627,8 +628,15 @@ async def run_segment_init_failure(env: NeonEnv): # Test (injected) failure during WAL segment init. # https://github.com/neondatabase/neon/issues/6401 # https://github.com/neondatabase/neon/issues/6402 -def test_segment_init_failure(neon_env_builder: NeonEnvBuilder): +@pytest.mark.parametrize( + "wal_receiver_protocol", + [PageserverWalReceiverProtocol.VANILLA, PageserverWalReceiverProtocol.INTERPRETED], +) +def test_segment_init_failure( + neon_env_builder: NeonEnvBuilder, wal_receiver_protocol: PageserverWalReceiverProtocol +): neon_env_builder.num_safekeepers = 1 + neon_env_builder.pageserver_wal_receiver_protocol = wal_receiver_protocol env = neon_env_builder.init_start() asyncio.run(run_segment_init_failure(env)) @@ -705,7 +713,7 @@ async def run_wal_lagging(env: NeonEnv, endpoint: Endpoint, test_output_dir: Pat # invalid, to make them unavailable to the endpoint. We use # ports 10, 11 and 12 to simulate unavailable safekeepers. config = toml.load(test_output_dir / "repo" / "config") - for i, (_sk, active) in enumerate(zip(env.safekeepers, active_sk)): + for i, (_sk, active) in enumerate(zip(env.safekeepers, active_sk, strict=False)): if active: config["safekeepers"][i]["pg_port"] = env.safekeepers[i].port.pg else: diff --git a/vendor/postgres-v14 b/vendor/postgres-v14 index aeecd27b1f..284ae56be2 160000 --- a/vendor/postgres-v14 +++ b/vendor/postgres-v14 @@ -1 +1 @@ -Subproject commit aeecd27b1f0775b606409d1cbb9c8aa9853a82af +Subproject commit 284ae56be2397fd3eaf20777fa220b2d0ad968f5 diff --git a/vendor/postgres-v15 b/vendor/postgres-v15 index 544620db4c..aed79ee87b 160000 --- a/vendor/postgres-v15 +++ b/vendor/postgres-v15 @@ -1 +1 @@ -Subproject commit 544620db4ca6945be4f1f686a7fbd2cdfb0bf96f +Subproject commit aed79ee87b94779cc52ec13e3b74eba6ada93f05 diff --git a/vendor/postgres-v16 b/vendor/postgres-v16 index 3cc152ae2d..f5cfc6fa89 160000 --- a/vendor/postgres-v16 +++ b/vendor/postgres-v16 @@ -1 +1 @@ -Subproject commit 3cc152ae2d17b19679c7102486bdb94677705c02 +Subproject commit f5cfc6fa898544050e821ac688adafece1ac3cff diff --git a/vendor/postgres-v17 b/vendor/postgres-v17 index e5d795a1a0..3c15b6565f 160000 --- a/vendor/postgres-v17 +++ b/vendor/postgres-v17 @@ -1 +1 @@ -Subproject commit e5d795a1a0c25da907176d37c905badab70e00c0 +Subproject commit 3c15b6565f6c8d36d169ed9ea7412cf90cfb2a8f diff --git a/vendor/revisions.json b/vendor/revisions.json index a13ef29e45..4dae88e73d 100644 --- a/vendor/revisions.json +++ b/vendor/revisions.json @@ -1,18 +1,18 @@ { "v17": [ "17.2", - "e5d795a1a0c25da907176d37c905badab70e00c0" + "3c15b6565f6c8d36d169ed9ea7412cf90cfb2a8f" ], "v16": [ "16.6", - "3cc152ae2d17b19679c7102486bdb94677705c02" + "f5cfc6fa898544050e821ac688adafece1ac3cff" ], "v15": [ "15.10", - "544620db4ca6945be4f1f686a7fbd2cdfb0bf96f" + "aed79ee87b94779cc52ec13e3b74eba6ada93f05" ], "v14": [ "14.15", - "aeecd27b1f0775b606409d1cbb9c8aa9853a82af" + "284ae56be2397fd3eaf20777fa220b2d0ad968f5" ] } diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 53d3a7364b..c0a3abc377 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -19,7 +19,8 @@ ahash = { version = "0.8" } anyhow = { version = "1", features = ["backtrace"] } axum = { version = "0.7", features = ["ws"] } axum-core = { version = "0.4", default-features = false, features = ["tracing"] } -base64 = { version = "0.21", features = ["alloc"] } +base64-594e8ee84c453af0 = { package = "base64", version = "0.13", features = ["alloc"] } +base64-647d43efb71741da = { package = "base64", version = "0.21", features = ["alloc"] } base64ct = { version = "1", default-features = false, features = ["std"] } bytes = { version = "1", features = ["serde"] } camino = { version = "1", default-features = false, features = ["serde1"] } @@ -52,13 +53,13 @@ lazy_static = { version = "1", default-features = false, features = ["spin_no_st libc = { version = "0.2", features = ["extra_traits", "use_std"] } log = { version = "0.4", default-features = false, features = ["std"] } memchr = { version = "2" } +nix = { version = "0.26" } nom = { version = "7" } num-bigint = { version = "0.4" } num-integer = { version = "0.1", features = ["i128"] } num-traits = { version = "0.2", features = ["i128", "libm"] } once_cell = { version = "1" } parquet = { version = "53", default-features = false, features = ["zstd"] } -postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon", default-features = false, features = ["with-serde_json-1"] } prost = { version = "0.13", features = ["prost-derive"] } rand = { version = "0.8", features = ["small_rng"] } regex = { version = "1" } @@ -77,8 +78,7 @@ subtle = { version = "2" } sync_wrapper = { version = "0.1", default-features = false, features = ["futures"] } tikv-jemalloc-sys = { version = "0.6", features = ["stats"] } time = { version = "0.3", features = ["macros", "serde-well-known"] } -tokio = { version = "1", features = ["fs", "io-std", "io-util", "macros", "net", "process", "rt-multi-thread", "signal", "test-util"] } -tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon", features = ["with-serde_json-1"] } +tokio = { version = "1", features = ["full", "test-util"] } tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring", "tls12"] } tokio-stream = { version = "0.1", features = ["net"] } tokio-util = { version = "0.7", features = ["codec", "compat", "io", "rt"] }