diff --git a/.github/workflows/_build-and-test-locally.yml b/.github/workflows/_build-and-test-locally.yml index 663afa2c8b..e2203a38ec 100644 --- a/.github/workflows/_build-and-test-locally.yml +++ b/.github/workflows/_build-and-test-locally.yml @@ -38,6 +38,11 @@ on: required: false default: 1 type: number + rerun-failed: + description: 'rerun failed tests to tolerate flaky tests' + required: false + default: true + type: boolean defaults: run: @@ -99,11 +104,10 @@ jobs: # Set some environment variables used by all the steps. # - # CARGO_FLAGS is extra options to pass to "cargo build", "cargo test" etc. - # It also includes --features, if any + # CARGO_FLAGS is extra options to pass to all "cargo" subcommands. # - # CARGO_FEATURES is passed to "cargo metadata". It is separate from CARGO_FLAGS, - # because "cargo metadata" doesn't accept --release or --debug options + # CARGO_PROFILE is passed to "cargo build", "cargo test" etc., but not to + # "cargo metadata", because it doesn't accept --release or --debug options. # # We run tests with additional features that are turned off by default (e.g. in release builds); see # corresponding Cargo.toml files for their descriptions. @@ -112,16 +116,16 @@ jobs: ARCH: ${{ inputs.arch }} SANITIZERS: ${{ inputs.sanitizers }} run: | - CARGO_FEATURES="--features testing" + CARGO_FLAGS="--locked --features testing" if [[ $BUILD_TYPE == "debug" && $ARCH == 'x64' ]]; then cov_prefix="scripts/coverage --profraw-prefix=$GITHUB_JOB --dir=/tmp/coverage run" - CARGO_FLAGS="--locked" + CARGO_PROFILE="" elif [[ $BUILD_TYPE == "debug" ]]; then cov_prefix="" - CARGO_FLAGS="--locked" + CARGO_PROFILE="" elif [[ $BUILD_TYPE == "release" ]]; then cov_prefix="" - CARGO_FLAGS="--locked --release" + CARGO_PROFILE="--release" fi if [[ $SANITIZERS == 'enabled' ]]; then make_vars="WITH_SANITIZERS=yes" @@ -131,8 +135,8 @@ jobs: { echo "cov_prefix=${cov_prefix}" echo "make_vars=${make_vars}" - echo "CARGO_FEATURES=${CARGO_FEATURES}" echo "CARGO_FLAGS=${CARGO_FLAGS}" + echo "CARGO_PROFILE=${CARGO_PROFILE}" echo "CARGO_HOME=${GITHUB_WORKSPACE}/.cargo" } >> $GITHUB_ENV @@ -184,34 +188,18 @@ jobs: path: pg_install/v17 key: v1-${{ runner.os }}-${{ runner.arch }}-${{ inputs.build-type }}-pg-${{ steps.pg_v17_rev.outputs.pg_rev }}-bookworm-${{ hashFiles('Makefile', 'build-tools.Dockerfile') }} - - name: Build postgres v14 - if: steps.cache_pg_14.outputs.cache-hit != 'true' - run: mold -run make ${make_vars} postgres-v14 -j$(nproc) - - - name: Build postgres v15 - if: steps.cache_pg_15.outputs.cache-hit != 'true' - run: mold -run make ${make_vars} postgres-v15 -j$(nproc) - - - name: Build postgres v16 - if: steps.cache_pg_16.outputs.cache-hit != 'true' - run: mold -run make ${make_vars} postgres-v16 -j$(nproc) - - - name: Build postgres v17 - if: steps.cache_pg_17.outputs.cache-hit != 'true' - run: mold -run make ${make_vars} postgres-v17 -j$(nproc) - - - name: Build neon extensions - run: mold -run make ${make_vars} neon-pg-ext -j$(nproc) + - name: Build all + # Note: the Makefile picks up BUILD_TYPE and CARGO_PROFILE from the environment + run: mold -run make ${make_vars} all -j$(nproc) CARGO_BUILD_FLAGS="$CARGO_FLAGS" - name: Build walproposer-lib run: mold -run make ${make_vars} walproposer-lib -j$(nproc) - - name: Run cargo build - env: - WITH_TESTS: ${{ inputs.sanitizers != 'enabled' && '--tests' || '' }} + - name: Build unit tests + if: inputs.sanitizers != 'enabled' run: | export ASAN_OPTIONS=detect_leaks=0 - ${cov_prefix} mold -run cargo build $CARGO_FLAGS $CARGO_FEATURES 
--bins ${WITH_TESTS} + ${cov_prefix} mold -run cargo build $CARGO_FLAGS $CARGO_PROFILE --tests # Do install *before* running rust tests because they might recompile the # binaries with different features/flags. @@ -223,7 +211,7 @@ jobs: # Install target binaries mkdir -p /tmp/neon/bin/ binaries=$( - ${cov_prefix} cargo metadata $CARGO_FEATURES --format-version=1 --no-deps | + ${cov_prefix} cargo metadata $CARGO_FLAGS --format-version=1 --no-deps | jq -r '.packages[].targets[] | select(.kind | index("bin")) | .name' ) for bin in $binaries; do @@ -240,7 +228,7 @@ jobs: mkdir -p /tmp/neon/test_bin/ test_exe_paths=$( - ${cov_prefix} cargo test $CARGO_FLAGS $CARGO_FEATURES --message-format=json --no-run | + ${cov_prefix} cargo test $CARGO_FLAGS $CARGO_PROFILE --message-format=json --no-run | jq -r '.executable | select(. != null)' ) for bin in $test_exe_paths; do @@ -274,10 +262,10 @@ jobs: export LD_LIBRARY_PATH #nextest does not yet support running doctests - ${cov_prefix} cargo test --doc $CARGO_FLAGS $CARGO_FEATURES + ${cov_prefix} cargo test --doc $CARGO_FLAGS $CARGO_PROFILE # run all non-pageserver tests - ${cov_prefix} cargo nextest run $CARGO_FLAGS $CARGO_FEATURES -E '!package(pageserver)' + ${cov_prefix} cargo nextest run $CARGO_FLAGS $CARGO_PROFILE -E '!package(pageserver)' # run pageserver tests # (When developing new pageserver features gated by config fields, we commonly make the rust @@ -286,13 +274,13 @@ # pageserver tests from non-pageserver tests cuts down the time it takes for this CI step.) NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IOENGINE=tokio-epoll-uring \ ${cov_prefix} \ - cargo nextest run $CARGO_FLAGS $CARGO_FEATURES -E 'package(pageserver)' + cargo nextest run $CARGO_FLAGS $CARGO_PROFILE -E 'package(pageserver)' # Run separate tests for real S3 export ENABLE_REAL_S3_REMOTE_STORAGE=nonempty export REMOTE_STORAGE_S3_BUCKET=neon-github-ci-tests export REMOTE_STORAGE_S3_REGION=eu-central-1 - ${cov_prefix} cargo nextest run $CARGO_FLAGS $CARGO_FEATURES -E 'package(remote_storage)' -E 'test(test_real_s3)' + ${cov_prefix} cargo nextest run $CARGO_FLAGS $CARGO_PROFILE -E 'package(remote_storage)' -E 'test(test_real_s3)' # Run separate tests for real Azure Blob Storage # XXX: replace region with `eu-central-1`-like region @@ -301,17 +289,17 @@ export AZURE_STORAGE_ACCESS_KEY="${{ secrets.AZURE_STORAGE_ACCESS_KEY_DEV }}" export REMOTE_STORAGE_AZURE_CONTAINER="${{ vars.REMOTE_STORAGE_AZURE_CONTAINER }}" export REMOTE_STORAGE_AZURE_REGION="${{ vars.REMOTE_STORAGE_AZURE_REGION }}" - ${cov_prefix} cargo nextest run $CARGO_FLAGS $CARGO_FEATURES -E 'package(remote_storage)' -E 'test(test_real_azure)' + ${cov_prefix} cargo nextest run $CARGO_FLAGS $CARGO_PROFILE -E 'package(remote_storage)' -E 'test(test_real_azure)' - name: Install postgres binaries run: | # Use tar to copy files matching the pattern, preserving the paths in the destination tar c \ pg_install/v* \ - pg_install/build/*/src/test/regress/*.so \ - pg_install/build/*/src/test/regress/pg_regress \ - pg_install/build/*/src/test/isolation/isolationtester \ - pg_install/build/*/src/test/isolation/pg_isolation_regress \ + build/*/src/test/regress/*.so \ + build/*/src/test/regress/pg_regress \ + build/*/src/test/isolation/isolationtester \ + build/*/src/test/isolation/pg_isolation_regress \ | tar x -C /tmp/neon - name: Upload Neon artifact @@ -379,7 +367,7 @@ jobs: - name: Pytest regression tests continue-on-error: ${{ matrix.lfc_state == 'with-lfc' && inputs.build-type == 'debug' }} uses: 
./.github/actions/run-python-test-set - timeout-minutes: ${{ inputs.sanitizers != 'enabled' && 75 || 180 }} + timeout-minutes: ${{ (inputs.build-type == 'release' && inputs.sanitizers != 'enabled') && 75 || 180 }} with: build_type: ${{ inputs.build-type }} test_selection: regress @@ -387,14 +375,14 @@ run_with_real_s3: true real_s3_bucket: neon-github-ci-tests real_s3_region: eu-central-1 - rerun_failed: ${{ inputs.test-run-count == 1 }} + rerun_failed: ${{ inputs.rerun-failed }} pg_version: ${{ matrix.pg_version }} sanitizers: ${{ inputs.sanitizers }} aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} # `--session-timeout` (in seconds) is set somewhat below the `timeout-minutes` limit. # Attempt to stop tests gracefully to generate test reports # until they are forcibly stopped by the stricter `timeout-minutes` limit. - extra_params: --session-timeout=${{ inputs.sanitizers != 'enabled' && 3000 || 10200 }} --count=${{ inputs.test-run-count }} + extra_params: --session-timeout=${{ (inputs.build-type == 'release' && inputs.sanitizers != 'enabled') && 3000 || 10200 }} --count=${{ inputs.test-run-count }} ${{ inputs.test-selection != '' && format('-k "{0}"', inputs.test-selection) || '' }} env: TEST_RESULT_CONNSTR: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }} diff --git a/.github/workflows/build-macos.yml b/.github/workflows/build-macos.yml index 0f7fa3e813..160c3d05bc 100644 --- a/.github/workflows/build-macos.yml +++ b/.github/workflows/build-macos.yml @@ -110,7 +110,7 @@ jobs: build-walproposer-lib: if: | - inputs.pg_versions != '[]' || inputs.rebuild_everything || + contains(inputs.pg_versions, 'v17') || inputs.rebuild_everything || contains(github.event.pull_request.labels.*.name, 'run-extra-build-macos') || contains(github.event.pull_request.labels.*.name, 'run-extra-build-*') || github.ref_name == 'main' @@ -144,7 +144,7 @@ jobs: id: cache_walproposer_lib uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: - path: pg_install/build/walproposer-lib + path: build/walproposer-lib key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-walproposer_lib-v17-${{ steps.pg_rev.outputs.pg_rev }}-${{ hashFiles('Makefile') }} - name: Checkout submodule vendor/postgres-v17 @@ -169,11 +169,11 @@ jobs: run: make walproposer-lib -j$(sysctl -n hw.ncpu) - - name: Upload "pg_install/build/walproposer-lib" artifact + - name: Upload "build/walproposer-lib" artifact uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 with: - name: pg_install--build--walproposer-lib - path: pg_install/build/walproposer-lib + name: build--walproposer-lib + path: build/walproposer-lib # The artifact is supposed to be used by the next job in the same workflow, # so there’s no need to store it for too long. 
retention-days: 1 @@ -226,11 +226,11 @@ jobs: name: pg_install--v17 path: pg_install/v17 - - name: Download "pg_install/build/walproposer-lib" artifact + - name: Download "build/walproposer-lib" artifact uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0 with: - name: pg_install--build--walproposer-lib - path: pg_install/build/walproposer-lib + name: build--walproposer-lib + path: build/walproposer-lib # `actions/download-artifact` doesn't preserve permissions: # https://github.com/actions/download-artifact?tab=readme-ov-file#permission-loss diff --git a/.github/workflows/build_and_run_selected_test.yml b/.github/workflows/build_and_run_selected_test.yml index 7f1eb991c4..6d3541d1b6 100644 --- a/.github/workflows/build_and_run_selected_test.yml +++ b/.github/workflows/build_and_run_selected_test.yml @@ -58,6 +58,7 @@ jobs: test-cfg: ${{ inputs.pg-versions }} test-selection: ${{ inputs.test-selection }} test-run-count: ${{ fromJson(inputs.run-count) }} + rerun-failed: false secrets: inherit create-test-report: diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 9f2fa3d52c..94f768719f 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -199,6 +199,28 @@ jobs: build-tools-image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm secrets: inherit + validate-compute-manifest: + runs-on: ubuntu-22.04 + needs: [ meta, check-permissions ] + # We do need to run this in `.*-rc-pr` because of hotfixes. + if: ${{ contains(fromJSON('["pr", "push-main", "storage-rc-pr", "proxy-rc-pr", "compute-rc-pr"]'), needs.meta.outputs.run-kind) }} + steps: + - name: Harden the runner (Audit all outbound calls) + uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0 + with: + egress-policy: audit + + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Set up Node.js + uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0 + with: + node-version: '24' + + - name: Validate manifest against schema + run: | + make -C compute manifest-schema-validation + build-and-test-locally: needs: [ meta, build-build-tools-image ] # We do need to run this in `.*-rc-pr` because of hotfixes. 
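Note: the `validate-compute-manifest` job added above boils down to the `manifest-schema-validation` target that this diff adds to `compute/Makefile` further down. A minimal sketch of the equivalent local invocation, assuming the `jsonschema` CLI is supplied by the npm dependencies declared in `compute/package.json` (a file not shown in this diff):

    cd compute
    npm install    # the Makefile's node_modules prerequisite runs this on first use
    # validate manifest.yaml against manifest.schema.json (JSON Schema draft 2020-12)
    node_modules/.bin/jsonschema validate \
        -d https://json-schema.org/draft/2020-12/schema \
        manifest.schema.json manifest.yaml
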
@@ -648,7 +670,7 @@ jobs: ghcr.io/neondatabase/neon:${{ needs.meta.outputs.build-tag }}-bookworm-arm64 compute-node-image-arch: - needs: [ check-permissions, build-build-tools-image, meta ] + needs: [ check-permissions, meta ] if: ${{ contains(fromJSON('["push-main", "pr", "compute-rc-pr"]'), needs.meta.outputs.run-kind) }} permissions: id-token: write # aws-actions/configure-aws-credentials @@ -721,7 +743,6 @@ jobs: GIT_VERSION=${{ github.event.pull_request.head.sha || github.sha }} PG_VERSION=${{ matrix.version.pg }} BUILD_TAG=${{ needs.meta.outputs.release-tag || needs.meta.outputs.build-tag }} - TAG=${{ needs.build-build-tools-image.outputs.image-tag }}-${{ matrix.version.debian }} DEBIAN_VERSION=${{ matrix.version.debian }} provenance: false push: true @@ -741,7 +762,6 @@ jobs: GIT_VERSION=${{ github.event.pull_request.head.sha || github.sha }} PG_VERSION=${{ matrix.version.pg }} BUILD_TAG=${{ needs.meta.outputs.release-tag || needs.meta.outputs.build-tag }} - TAG=${{ needs.build-build-tools-image.outputs.image-tag }}-${{ matrix.version.debian }} DEBIAN_VERSION=${{ matrix.version.debian }} provenance: false push: true diff --git a/.github/workflows/build_and_test_fully.yml b/.github/workflows/build_and_test_fully.yml new file mode 100644 index 0000000000..dd1d63b02b --- /dev/null +++ b/.github/workflows/build_and_test_fully.yml @@ -0,0 +1,151 @@ +name: Build and Test Fully + +on: + schedule: + # * is a special character in YAML so you have to quote this string + # ┌───────────── minute (0 - 59) + # │ ┌───────────── hour (0 - 23) + # │ │ ┌───────────── day of the month (1 - 31) + # │ │ │ ┌───────────── month (1 - 12 or JAN-DEC) + # │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT) + - cron: '0 3 * * *' # run once a day, timezone is utc + workflow_dispatch: + +defaults: + run: + shell: bash -euxo pipefail {0} + +concurrency: + # Allow only one workflow per any non-`main` branch. 
+ group: ${{ github.workflow }}-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }} + cancel-in-progress: true + +env: + RUST_BACKTRACE: 1 + COPT: '-Werror' + +jobs: + tag: + runs-on: [ self-hosted, small ] + container: ${{ vars.NEON_DEV_AWS_ACCOUNT_ID }}.dkr.ecr.${{ vars.AWS_ECR_REGION }}.amazonaws.com/base:pinned + outputs: + build-tag: ${{steps.build-tag.outputs.tag}} + + steps: + # Need `fetch-depth: 0` to count the number of commits in the branch + - name: Harden the runner (Audit all outbound calls) + uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0 + with: + egress-policy: audit + + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + + - name: Get build tag + run: | + echo run:$GITHUB_RUN_ID + echo ref:$GITHUB_REF_NAME + echo rev:$(git rev-list --count HEAD) + if [[ "$GITHUB_REF_NAME" == "main" ]]; then + echo "tag=$(git rev-list --count HEAD)" >> $GITHUB_OUTPUT + elif [[ "$GITHUB_REF_NAME" == "release" ]]; then + echo "tag=release-$(git rev-list --count HEAD)" >> $GITHUB_OUTPUT + elif [[ "$GITHUB_REF_NAME" == "release-proxy" ]]; then + echo "tag=release-proxy-$(git rev-list --count HEAD)" >> $GITHUB_OUTPUT + elif [[ "$GITHUB_REF_NAME" == "release-compute" ]]; then + echo "tag=release-compute-$(git rev-list --count HEAD)" >> $GITHUB_OUTPUT + else + echo "GITHUB_REF_NAME (value '$GITHUB_REF_NAME') is not one of 'main', 'release', 'release-proxy', 'release-compute'" + echo "tag=$GITHUB_RUN_ID" >> $GITHUB_OUTPUT + fi + shell: bash + id: build-tag + + build-build-tools-image: + uses: ./.github/workflows/build-build-tools-image.yml + secrets: inherit + + build-and-test-locally: + needs: [ tag, build-build-tools-image ] + strategy: + fail-fast: false + matrix: + arch: [ x64, arm64 ] + build-type: [ debug, release ] + uses: ./.github/workflows/_build-and-test-locally.yml + with: + arch: ${{ matrix.arch }} + build-tools-image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm + build-tag: ${{ needs.tag.outputs.build-tag }} + build-type: ${{ matrix.build-type }} + rerun-failed: false + test-cfg: '[{"pg_version":"v14", "lfc_state": "with-lfc"}, + {"pg_version":"v15", "lfc_state": "with-lfc"}, + {"pg_version":"v16", "lfc_state": "with-lfc"}, + {"pg_version":"v17", "lfc_state": "with-lfc"}, + {"pg_version":"v14", "lfc_state": "without-lfc"}, + {"pg_version":"v15", "lfc_state": "without-lfc"}, + {"pg_version":"v16", "lfc_state": "without-lfc"}, + {"pg_version":"v17", "lfc_state": "without-lfc"}]' + secrets: inherit + + + create-test-report: + needs: [ build-and-test-locally, build-build-tools-image ] + if: ${{ !cancelled() }} + permissions: + id-token: write # aws-actions/configure-aws-credentials + statuses: write + contents: write + pull-requests: write + outputs: + report-url: ${{ steps.create-allure-report.outputs.report-url }} + + runs-on: [ self-hosted, small ] + container: + image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm + credentials: + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + options: --init + + steps: + - name: Harden the runner (Audit all outbound calls) + uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0 + with: + egress-policy: audit + + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Create Allure report + if: ${{ !cancelled() }} + id: create-allure-report + uses: ./.github/actions/allure-report-generate + with: + 
store-test-results-into-db: true + aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} + env: + REGRESS_TEST_RESULT_CONNSTR_NEW: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }} + + - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + if: ${{ !cancelled() }} + with: + # Retry script for 5XX server errors: https://github.com/actions/github-script#retries + retries: 5 + script: | + const report = { + reportUrl: "${{ steps.create-allure-report.outputs.report-url }}", + reportJsonUrl: "${{ steps.create-allure-report.outputs.report-json-url }}", + } + + const coverage = {} + + const script = require("./scripts/comment-test-report.js") + await script({ + github, + context, + fetch, + report, + coverage, + }) diff --git a/.github/workflows/build_and_test_with_sanitizers.yml b/.github/workflows/build_and_test_with_sanitizers.yml index c54448dedc..32fb3c7c15 100644 --- a/.github/workflows/build_and_test_with_sanitizers.yml +++ b/.github/workflows/build_and_test_with_sanitizers.yml @@ -79,6 +79,7 @@ jobs: build-tools-image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm build-tag: ${{ needs.tag.outputs.build-tag }} build-type: ${{ matrix.build-type }} + rerun-failed: false test-cfg: '[{"pg_version":"v17"}]' sanitizers: enabled secrets: inherit diff --git a/.github/workflows/large_oltp_benchmark.yml b/.github/workflows/large_oltp_benchmark.yml index 42dcc8e918..050b9047c7 100644 --- a/.github/workflows/large_oltp_benchmark.yml +++ b/.github/workflows/large_oltp_benchmark.yml @@ -33,11 +33,19 @@ jobs: fail-fast: false # allow other variants to continue even if one fails matrix: include: + # test only read-only custom scripts in new branch without database maintenance + - target: new_branch + custom_scripts: select_any_webhook_with_skew.sql@300 select_recent_webhook.sql@397 select_prefetch_webhook.sql@3 + test_maintenance: false + # test all custom scripts in new branch with database maintenance - target: new_branch custom_scripts: insert_webhooks.sql@200 select_any_webhook_with_skew.sql@300 select_recent_webhook.sql@397 select_prefetch_webhook.sql@3 IUD_one_transaction.sql@100 + test_maintenance: true + # test all custom scripts in reuse branch with database maintenance - target: reuse_branch custom_scripts: insert_webhooks.sql@200 select_any_webhook_with_skew.sql@300 select_recent_webhook.sql@397 select_prefetch_webhook.sql@3 IUD_one_transaction.sql@100 - max-parallel: 1 # we want to run each stripe size sequentially to be able to compare the results + test_maintenance: true + max-parallel: 1 # we want to run each benchmark sequentially to not have noisy neighbors on shared storage (PS, SK) permissions: contents: write statuses: write @@ -145,6 +153,7 @@ jobs: PERF_TEST_RESULT_CONNSTR: "${{ secrets.PERF_TEST_RESULT_CONNSTR }}" - name: Benchmark database maintenance + if: ${{ matrix.test_maintenance == 'true' }} uses: ./.github/actions/run-python-test-set with: build_type: ${{ env.BUILD_TYPE }} diff --git a/.github/workflows/large_oltp_growth.yml b/.github/workflows/large_oltp_growth.yml new file mode 100644 index 0000000000..8ca640d6ef --- /dev/null +++ b/.github/workflows/large_oltp_growth.yml @@ -0,0 +1,175 @@ +name: large oltp growth +# workflow to grow the reuse branch of large oltp benchmark continuously (about 16 GB per run) + +on: + # uncomment to run on push for debugging your PR + # push: + # branches: [ bodobolero/increase_large_oltp_workload ] + + schedule: + # * is a special character in YAML so you have to quote this string + # ┌───────────── 
minute (0 - 59) + # │ ┌───────────── hour (0 - 23) + # │ │ ┌───────────── day of the month (1 - 31) + # │ │ │ ┌───────────── month (1 - 12 or JAN-DEC) + # │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT) + - cron: '0 6 * * *' # 06:00 UTC + - cron: '0 8 * * *' # 08:00 UTC + - cron: '0 10 * * *' # 10:00 UTC + - cron: '0 12 * * *' # 12:00 UTC + - cron: '0 14 * * *' # 14:00 UTC + - cron: '0 16 * * *' # 16:00 UTC + workflow_dispatch: # adds ability to run this manually + +defaults: + run: + shell: bash -euxo pipefail {0} + +concurrency: + # Allow only one workflow globally because we need dedicated resources which only exist once + group: large-oltp-growth + cancel-in-progress: true + +permissions: + contents: read + +jobs: + oltp: + strategy: + fail-fast: false # allow other variants to continue even if one fails + matrix: + include: + # for now only grow the reuse branch, not the other branches. + - target: reuse_branch + custom_scripts: + - grow_action_blocks.sql + - grow_action_kwargs.sql + - grow_device_fingerprint_event.sql + - grow_edges.sql + - grow_hotel_rate_mapping.sql + - grow_ocr_pipeline_results_version.sql + - grow_priceline_raw_response.sql + - grow_relabled_transactions.sql + - grow_state_values.sql + - grow_values.sql + - grow_vertices.sql + - update_accounting_coding_body_tracking_category_selection.sql + - update_action_blocks.sql + - update_action_kwargs.sql + - update_denormalized_approval_workflow.sql + - update_device_fingerprint_event.sql + - update_edges.sql + - update_heron_transaction_enriched_log.sql + - update_heron_transaction_enrichment_requests.sql + - update_hotel_rate_mapping.sql + - update_incoming_webhooks.sql + - update_manual_transaction.sql + - update_ml_receipt_matching_log.sql + - update_ocr_pipeine_results_version.sql + - update_orc_pipeline_step_results.sql + - update_orc_pipeline_step_results_version.sql + - update_priceline_raw_response.sql + - update_quickbooks_transactions.sql + - update_raw_finicity_transaction.sql + - update_relabeled_transactions.sql + - update_state_values.sql + - update_stripe_authorization_event_log.sql + - update_transaction.sql + - update_values.sql + - update_vertices.sql + max-parallel: 1 # we want to run each growth workload sequentially (for now there is just one) + permissions: + contents: write + statuses: write + id-token: write # aws-actions/configure-aws-credentials + env: + TEST_PG_BENCH_DURATIONS_MATRIX: "1h" + TEST_PGBENCH_CUSTOM_SCRIPTS: ${{ join(matrix.custom_scripts, ' ') }} + POSTGRES_DISTRIB_DIR: /tmp/neon/pg_install + PG_VERSION: 16 # pre-determined by pre-determined project + TEST_OUTPUT: /tmp/test_output + BUILD_TYPE: remote + PLATFORM: ${{ matrix.target }} + + runs-on: [ self-hosted, us-east-2, x64 ] + container: + image: ghcr.io/neondatabase/build-tools:pinned-bookworm + credentials: + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + options: --init + + steps: + - name: Harden the runner (Audit all outbound calls) + uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0 + with: + egress-policy: audit + + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Configure AWS credentials # necessary to download artefacts + uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2 + with: + aws-region: eu-central-1 + role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} + role-duration-seconds: 18000 # 5 hours is currently max associated with IAM role + + - name: Download Neon 
artifact + uses: ./.github/actions/download + with: + name: neon-${{ runner.os }}-${{ runner.arch }}-release-artifact + path: /tmp/neon/ + prefix: latest + aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} + + - name: Set up Connection String + id: set-up-connstr + run: | + case "${{ matrix.target }}" in + reuse_branch) + CONNSTR=${{ secrets.BENCHMARK_LARGE_OLTP_REUSE_CONNSTR }} + ;; + *) + echo >&2 "Unknown target=${{ matrix.target }}" + exit 1 + ;; + esac + + CONNSTR_WITHOUT_POOLER="${CONNSTR//-pooler/}" + + echo "connstr=${CONNSTR}" >> $GITHUB_OUTPUT + echo "connstr_without_pooler=${CONNSTR_WITHOUT_POOLER}" >> $GITHUB_OUTPUT + + - name: pgbench with custom-scripts + uses: ./.github/actions/run-python-test-set + with: + build_type: ${{ env.BUILD_TYPE }} + test_selection: performance + run_in_parallel: false + save_perf_report: true + extra_params: -m remote_cluster --timeout 7200 -k test_perf_oltp_large_tenant_growth + pg_version: ${{ env.PG_VERSION }} + aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} + env: + BENCHMARK_CONNSTR: ${{ steps.set-up-connstr.outputs.connstr }} + VIP_VAP_ACCESS_TOKEN: "${{ secrets.VIP_VAP_ACCESS_TOKEN }}" + PERF_TEST_RESULT_CONNSTR: "${{ secrets.PERF_TEST_RESULT_CONNSTR }}" + + - name: Create Allure report + id: create-allure-report + if: ${{ !cancelled() }} + uses: ./.github/actions/allure-report-generate + with: + aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} + + - name: Post to a Slack channel + if: ${{ github.event.schedule && failure() }} + uses: slackapi/slack-github-action@fcfb566f8b0aab22203f066d80ca1d7e4b5d05b3 # v1.27.1 + with: + channel-id: "C06KHQVQ7U3" # on-call-qa-staging-stream + slack-message: | + Periodic large oltp tenant growth increase: ${{ job.status }} + <${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|GitHub Run> + <${{ steps.create-allure-report.outputs.report-url }}|Allure report> + env: + SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }} diff --git a/.github/workflows/proxy-benchmark.yml b/.github/workflows/proxy-benchmark.yml new file mode 100644 index 0000000000..75ecacaced --- /dev/null +++ b/.github/workflows/proxy-benchmark.yml @@ -0,0 +1,83 @@ +name: Periodic proxy performance test on unit-perf hetzner runner + +on: + push: # TODO: remove after testing + branches: + - test-proxy-bench # Runs on pushes to branches starting with test-proxy-bench + # schedule: + # * is a special character in YAML so you have to quote this string + # ┌───────────── minute (0 - 59) + # │ ┌───────────── hour (0 - 23) + # │ │ ┌───────────── day of the month (1 - 31) + # │ │ │ ┌───────────── month (1 - 12 or JAN-DEC) + # │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT) + # - cron: '0 5 * * *' # Runs at 5 UTC once a day + workflow_dispatch: # adds an ability to run this manually + +defaults: + run: + shell: bash -euo pipefail {0} + +concurrency: + group: ${{ github.workflow }} + cancel-in-progress: false + +permissions: + contents: read + +jobs: + run_periodic_proxybench_test: + permissions: + id-token: write # aws-actions/configure-aws-credentials + statuses: write + contents: write + pull-requests: write + runs-on: [self-hosted, unit-perf] + timeout-minutes: 60 # 1h timeout + container: + image: ghcr.io/neondatabase/build-tools:pinned-bookworm + credentials: + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + options: --init + steps: + - name: Checkout proxy-bench Repo + uses: actions/checkout@v4 + with: + repository: neondatabase/proxy-bench + path: proxy-bench + + - name: 
Set up the environment which depends on $RUNNER_TEMP on nvme drive + id: set-env + shell: bash -euxo pipefail {0} + run: | + PROXY_BENCH_PATH=$(realpath ./proxy-bench) + { + echo "PROXY_BENCH_PATH=$PROXY_BENCH_PATH" + echo "NEON_DIR=${RUNNER_TEMP}/neon" + echo "TEST_OUTPUT=${PROXY_BENCH_PATH}/test_output" + echo "" + } >> "$GITHUB_ENV" + + - name: Run proxy-bench + run: ./${PROXY_BENCH_PATH}/run.sh + + - name: Ingest Bench Results # neon repo script + if: success() + run: | + mkdir -p $TEST_OUTPUT + python $NEON_DIR/scripts/proxy_bench_results_ingest.py --out $TEST_OUTPUT + + - name: Push Metrics to Proxy perf database + if: success() + env: + PERF_TEST_RESULT_CONNSTR: "${{ secrets.PROXY_TEST_RESULT_CONNSTR }}" + REPORT_FROM: $TEST_OUTPUT + run: $NEON_DIR/scripts/generate_and_push_perf_report.sh + + - name: Docker cleanup + run: docker compose down + + - name: Notify Failure + if: failure() + run: echo "Proxy bench job failed" && exit 1 \ No newline at end of file diff --git a/.gitignore b/.gitignore index 45eb4dbf0e..70c7e96303 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ /artifact_cache +/build /pg_install /target /tmp_check diff --git a/Cargo.lock b/Cargo.lock index 588a63b6a3..1fee728d9c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -753,6 +753,7 @@ dependencies = [ "axum", "axum-core", "bytes", + "form_urlencoded", "futures-util", "headers", "http 1.1.0", @@ -761,6 +762,8 @@ dependencies = [ "mime", "pin-project-lite", "serde", + "serde_html_form", + "serde_path_to_error", "tower 0.5.2", "tower-layer", "tower-service", @@ -900,12 +903,6 @@ version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" -[[package]] -name = "base64" -version = "0.20.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ea22880d78093b0cbe17c89f64a7d457941e65759157ec6cb31a31d652b05e5" - [[package]] name = "base64" version = "0.21.7" @@ -1297,7 +1294,7 @@ dependencies = [ "aws-smithy-types", "axum", "axum-extra", - "base64 0.13.1", + "base64 0.22.1", "bytes", "camino", "cfg-if", @@ -1321,6 +1318,7 @@ dependencies = [ "p256 0.13.2", "postgres", "postgres_initdb", + "postgres_versioninfo", "regex", "remote_storage", "reqwest", @@ -1423,7 +1421,7 @@ name = "control_plane" version = "0.1.0" dependencies = [ "anyhow", - "base64 0.13.1", + "base64 0.22.1", "camino", "clap", "comfy-table", @@ -1445,6 +1443,7 @@ dependencies = [ "regex", "reqwest", "safekeeper_api", + "safekeeper_client", "scopeguard", "serde", "serde_json", @@ -2054,6 +2053,7 @@ dependencies = [ "axum-extra", "camino", "camino-tempfile", + "clap", "futures", "http-body-util", "itertools 0.10.5", @@ -4256,6 +4256,7 @@ dependencies = [ "tokio-util", "tonic 0.13.1", "tracing", + "url", "utils", "workspace_hack", ] @@ -4335,6 +4336,7 @@ dependencies = [ "postgres_backend", "postgres_connection", "postgres_ffi", + "postgres_ffi_types", "postgres_initdb", "posthog_client_lite", "pprof", @@ -4404,7 +4406,8 @@ dependencies = [ "nix 0.30.1", "once_cell", "postgres_backend", - "postgres_ffi", + "postgres_ffi_types", + "postgres_versioninfo", "rand 0.8.5", "remote_storage", "reqwest", @@ -4428,6 +4431,7 @@ dependencies = [ "futures", "http-utils", "pageserver_api", + "postgres_versioninfo", "reqwest", "serde", "thiserror 1.0.69", @@ -4466,11 +4470,16 @@ dependencies = [ name = "pageserver_page_api" version = "0.1.0" dependencies = [ + "anyhow", "bytes", + "futures", "pageserver_api", "postgres_ffi", "prost 0.13.5", + 
"strum", + "strum_macros", "thiserror 1.0.69", + "tokio", "tonic 0.13.1", "tonic-build", "utils", @@ -4813,7 +4822,7 @@ dependencies = [ name = "postgres-protocol2" version = "0.1.0" dependencies = [ - "base64 0.20.0", + "base64 0.22.1", "byteorder", "bytes", "fallible-iterator", @@ -4890,6 +4899,8 @@ dependencies = [ "memoffset 0.9.0", "once_cell", "postgres", + "postgres_ffi_types", + "postgres_versioninfo", "pprof", "regex", "serde", @@ -4898,17 +4909,37 @@ dependencies = [ "utils", ] +[[package]] +name = "postgres_ffi_types" +version = "0.1.0" +dependencies = [ + "thiserror 1.0.69", + "workspace_hack", +] + [[package]] name = "postgres_initdb" version = "0.1.0" dependencies = [ "anyhow", "camino", + "postgres_versioninfo", "thiserror 1.0.69", "tokio", "workspace_hack", ] +[[package]] +name = "postgres_versioninfo" +version = "0.1.0" +dependencies = [ + "anyhow", + "serde", + "serde_repr", + "thiserror 1.0.69", + "workspace_hack", +] + [[package]] name = "posthog_client_lite" version = "0.1.0" @@ -5185,7 +5216,7 @@ dependencies = [ "aws-config", "aws-sdk-iam", "aws-sigv4", - "base64 0.13.1", + "base64 0.22.1", "bstr", "bytes", "camino", @@ -6100,6 +6131,7 @@ dependencies = [ "postgres-protocol", "postgres_backend", "postgres_ffi", + "postgres_versioninfo", "pprof", "pq_proto", "rand 0.8.5", @@ -6144,6 +6176,7 @@ dependencies = [ "const_format", "pageserver_api", "postgres_ffi", + "postgres_versioninfo", "pq_proto", "serde", "serde_json", @@ -6420,6 +6453,19 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "serde_html_form" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d2de91cf02bbc07cde38891769ccd5d4f073d22a40683aa4bc7a95781aaa2c4" +dependencies = [ + "form_urlencoded", + "indexmap 2.9.0", + "itoa", + "ryu", + "serde", +] + [[package]] name = "serde_json" version = "1.0.125" @@ -6453,6 +6499,17 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "serde_repr" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "175ee3e80ae9982737ca543e96133087cbd9a485eecc3bc4de9c1a37b47ea59c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + [[package]] name = "serde_spanned" version = "0.6.6" @@ -6476,15 +6533,17 @@ dependencies = [ [[package]] name = "serde_with" -version = "2.3.3" +version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07ff71d2c147a7b57362cead5e22f772cd52f6ab31cfcd9edcd7f6aeb2a0afbe" +checksum = "d6b6f7f2fcb69f747921f79f3926bd1e203fce4fef62c268dd3abfb6d86029aa" dependencies = [ - "base64 0.13.1", + "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", + "indexmap 2.9.0", "serde", + "serde_derive", "serde_json", "serde_with_macros", "time", @@ -6492,9 +6551,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "2.3.3" +version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "881b6f881b17d13214e5d494c939ebab463d01264ce1811e9d4ac3a882e7695f" +checksum = "8d00caa5193a3c8362ac2b73be6b9e768aa5a4b2f721d8f4b339600c3cb51f8e" dependencies = [ "darling", "proc-macro2", @@ -6756,6 +6815,7 @@ dependencies = [ "hex", "http-utils", "humantime", + "humantime-serde", "hyper 0.14.30", "itertools 0.10.5", "json-structural-diff", @@ -6766,6 +6826,7 @@ dependencies = [ "pageserver_api", "pageserver_client", "postgres_connection", + "posthog_client_lite", "rand 0.8.5", "regex", "reqwest", @@ -7544,6 +7605,7 @@ dependencies = [ "axum", "base64 0.22.1", "bytes", + "flate2", "h2 
0.4.4", "http 1.1.0", "http-body 1.0.0", @@ -7563,6 +7625,7 @@ dependencies = [ "tower-layer", "tower-service", "tracing", + "zstd", ] [[package]] @@ -8142,6 +8205,7 @@ dependencies = [ "futures", "pageserver_api", "postgres_ffi", + "postgres_ffi_types", "pprof", "prost 0.13.5", "remote_storage", @@ -8565,7 +8629,6 @@ dependencies = [ "anyhow", "axum", "axum-core", - "base64 0.13.1", "base64 0.21.7", "base64ct", "bytes", diff --git a/Cargo.toml b/Cargo.toml index a040010fb7..857bc5d5d9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,8 @@ members = [ "libs/http-utils", "libs/pageserver_api", "libs/postgres_ffi", + "libs/postgres_ffi_types", + "libs/postgres_versioninfo", "libs/safekeeper_api", "libs/desim", "libs/neon-shmem", @@ -71,8 +73,8 @@ aws-credential-types = "1.2.0" aws-sigv4 = { version = "1.2", features = ["sign-http"] } aws-types = "1.3" axum = { version = "0.8.1", features = ["ws"] } -axum-extra = { version = "0.10.0", features = ["typed-header"] } -base64 = "0.13.0" +axum-extra = { version = "0.10.0", features = ["typed-header", "query"] } +base64 = "0.22" bincode = "1.3" bindgen = "0.71" bit_field = "0.10.2" @@ -171,8 +173,9 @@ sentry = { version = "0.37", default-features = false, features = ["backtrace", serde = { version = "1.0", features = ["derive"] } serde_json = "1" serde_path_to_error = "0.1" -serde_with = { version = "2.0", features = [ "base64" ] } +serde_with = { version = "3", features = [ "base64" ] } serde_assert = "0.5.0" +serde_repr = "0.1.20" sha2 = "0.10.2" signal-hook = "0.3" smallvec = "1.11" @@ -199,7 +202,7 @@ tokio-tar = "0.3" tokio-util = { version = "0.7.10", features = ["io", "rt"] } toml = "0.8" toml_edit = "0.22" -tonic = { version = "0.13.1", default-features = false, features = ["channel", "codegen", "prost", "router", "server", "tls-ring", "tls-native-roots"] } +tonic = { version = "0.13.1", default-features = false, features = ["channel", "codegen", "gzip", "prost", "router", "server", "tls-ring", "tls-native-roots", "zstd"] } tonic-reflection = { version = "0.13.1", features = ["server"] } tower = { version = "0.5.2", default-features = false } tower-http = { version = "0.6.2", features = ["auth", "request-id", "trace"] } @@ -259,6 +262,8 @@ pageserver_page_api = { path = "./pageserver/page_api" } postgres_backend = { version = "0.1", path = "./libs/postgres_backend/" } postgres_connection = { version = "0.1", path = "./libs/postgres_connection/" } postgres_ffi = { version = "0.1", path = "./libs/postgres_ffi/" } +postgres_ffi_types = { version = "0.1", path = "./libs/postgres_ffi_types/" } +postgres_versioninfo = { version = "0.1", path = "./libs/postgres_versioninfo/" } postgres_initdb = { path = "./libs/postgres_initdb" } posthog_client_lite = { version = "0.1", path = "./libs/posthog_client_lite" } pq_proto = { version = "0.1", path = "./libs/pq_proto/" } diff --git a/Dockerfile b/Dockerfile index 3b7962dcf9..69657067de 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,8 +5,6 @@ ARG REPOSITORY=ghcr.io/neondatabase ARG IMAGE=build-tools ARG TAG=pinned -ARG DEFAULT_PG_VERSION=17 -ARG STABLE_PG_VERSION=16 ARG DEBIAN_VERSION=bookworm ARG DEBIAN_FLAVOR=${DEBIAN_VERSION}-slim @@ -47,7 +45,6 @@ COPY --chown=nonroot scripts/ninstall.sh scripts/ninstall.sh ENV BUILD_TYPE=release RUN set -e \ && mold -run make -j $(nproc) -s neon-pg-ext \ - && rm -rf pg_install/build \ && tar -C pg_install -czf /home/nonroot/postgres_install.tar.gz . 
# Prepare cargo-chef recipe @@ -63,14 +60,11 @@ FROM $REPOSITORY/$IMAGE:$TAG AS build WORKDIR /home/nonroot ARG GIT_VERSION=local ARG BUILD_TAG -ARG STABLE_PG_VERSION COPY --from=pg-build /home/nonroot/pg_install/v14/include/postgresql/server pg_install/v14/include/postgresql/server COPY --from=pg-build /home/nonroot/pg_install/v15/include/postgresql/server pg_install/v15/include/postgresql/server COPY --from=pg-build /home/nonroot/pg_install/v16/include/postgresql/server pg_install/v16/include/postgresql/server COPY --from=pg-build /home/nonroot/pg_install/v17/include/postgresql/server pg_install/v17/include/postgresql/server -COPY --from=pg-build /home/nonroot/pg_install/v16/lib pg_install/v16/lib -COPY --from=pg-build /home/nonroot/pg_install/v17/lib pg_install/v17/lib COPY --from=plan /home/nonroot/recipe.json recipe.json ARG ADDITIONAL_RUSTFLAGS="" @@ -97,7 +91,6 @@ RUN set -e \ # Build final image # FROM $BASE_IMAGE_SHA -ARG DEFAULT_PG_VERSION WORKDIR /data RUN set -e \ @@ -107,9 +100,20 @@ RUN set -e \ libreadline-dev \ libseccomp-dev \ ca-certificates \ - # System postgres for use with client libraries (e.g. in storage controller) - postgresql-15 \ openssl \ + unzip \ + curl \ + && ARCH=$(uname -m) \ + && if [ "$ARCH" = "x86_64" ]; then \ + curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"; \ + elif [ "$ARCH" = "aarch64" ]; then \ + curl "https://awscli.amazonaws.com/awscli-exe-linux-aarch64.zip" -o "awscliv2.zip"; \ + else \ + echo "Unsupported architecture: $ARCH" && exit 1; \ + fi \ + && unzip awscliv2.zip \ + && ./aws/install \ + && rm -rf aws awscliv2.zip \ && rm -f /etc/apt/apt.conf.d/80-retries \ && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* \ && useradd -d /data neon \ diff --git a/Makefile b/Makefile index 0911465fb8..d39b9b68c8 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,18 @@ ROOT_PROJECT_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST)))) -# Where to install Postgres, default is ./pg_install, maybe useful for package managers +# Where to install Postgres. The default is ./pg_install; overriding it may be +# useful for package managers. POSTGRES_INSTALL_DIR ?= $(ROOT_PROJECT_DIR)/pg_install/ +# CARGO_BUILD_FLAGS: Extra flags to pass to `cargo build`. `--locked` +# and `--features testing` are popular examples. +# +# CARGO_PROFILE: can also be set to override the cargo profile to +# use. By default, it is derived from BUILD_TYPE. + +# All intermediate build artifacts are stored here. +BUILD_DIR := build + ICU_PREFIX_DIR := /usr/local/icu # @@ -16,12 +26,12 @@ ifeq ($(BUILD_TYPE),release) PG_CONFIGURE_OPTS = --enable-debug --with-openssl PG_CFLAGS += -O2 -g3 $(CFLAGS) PG_LDFLAGS = $(LDFLAGS) - # Unfortunately, `--profile=...` is a nightly feature - CARGO_BUILD_FLAGS += --release + CARGO_PROFILE ?= --profile=release else ifeq ($(BUILD_TYPE),debug) PG_CONFIGURE_OPTS = --enable-debug --with-openssl --enable-cassert --enable-depend PG_CFLAGS += -O0 -g3 $(CFLAGS) PG_LDFLAGS = $(LDFLAGS) + CARGO_PROFILE ?= --profile=dev else $(error Bad build type '$(BUILD_TYPE)', see Makefile for options) endif @@ -93,7 +103,7 @@ all: neon postgres neon-pg-ext .PHONY: neon neon: postgres-headers walproposer-lib cargo-target-dir +@echo "Compiling Neon" - $(CARGO_CMD_PREFIX) cargo build $(CARGO_BUILD_FLAGS) + $(CARGO_CMD_PREFIX) cargo build $(CARGO_BUILD_FLAGS) $(CARGO_PROFILE) .PHONY: cargo-target-dir cargo-target-dir: # https://github.com/rust-lang/cargo/issues/14281 @@ -104,21 +114,20 @@ cargo-target-dir: # Some rules are duplicated for Postgres v14 and 15. 
We may want to refactor # to avoid the duplication in the future, but it's tolerable for now. # -$(POSTGRES_INSTALL_DIR)/build/%/config.status: - - mkdir -p $(POSTGRES_INSTALL_DIR) - test -e $(POSTGRES_INSTALL_DIR)/CACHEDIR.TAG || echo "$(CACHEDIR_TAG_CONTENTS)" > $(POSTGRES_INSTALL_DIR)/CACHEDIR.TAG +$(BUILD_DIR)/%/config.status: + mkdir -p $(BUILD_DIR) + test -e $(BUILD_DIR)/CACHEDIR.TAG || echo "$(CACHEDIR_TAG_CONTENTS)" > $(BUILD_DIR)/CACHEDIR.TAG +@echo "Configuring Postgres $* build" @test -s $(ROOT_PROJECT_DIR)/vendor/postgres-$*/configure || { \ echo "\nPostgres submodule not found in $(ROOT_PROJECT_DIR)/vendor/postgres-$*/, execute "; \ echo "'git submodule update --init --recursive --depth 2 --progress .' in project root.\n"; \ exit 1; } - mkdir -p $(POSTGRES_INSTALL_DIR)/build/$* + mkdir -p $(BUILD_DIR)/$* VERSION=$*; \ EXTRA_VERSION=$$(cd $(ROOT_PROJECT_DIR)/vendor/postgres-$$VERSION && git rev-parse HEAD); \ - (cd $(POSTGRES_INSTALL_DIR)/build/$$VERSION && \ + (cd $(BUILD_DIR)/$$VERSION && \ env PATH="$(EXTRA_PATH_OVERRIDES):$$PATH" $(ROOT_PROJECT_DIR)/vendor/postgres-$$VERSION/configure \ CFLAGS='$(PG_CFLAGS)' LDFLAGS='$(PG_LDFLAGS)' \ $(PG_CONFIGURE_OPTS) --with-extra-version=" ($$EXTRA_VERSION)" \ @@ -130,96 +139,52 @@ $(POSTGRES_INSTALL_DIR)/build/%/config.status: # the "build-all-versions" entry points) where direct mention of PostgreSQL # versions is used. .PHONY: postgres-configure-v17 -postgres-configure-v17: $(POSTGRES_INSTALL_DIR)/build/v17/config.status +postgres-configure-v17: $(BUILD_DIR)/v17/config.status .PHONY: postgres-configure-v16 -postgres-configure-v16: $(POSTGRES_INSTALL_DIR)/build/v16/config.status +postgres-configure-v16: $(BUILD_DIR)/v16/config.status .PHONY: postgres-configure-v15 -postgres-configure-v15: $(POSTGRES_INSTALL_DIR)/build/v15/config.status +postgres-configure-v15: $(BUILD_DIR)/v15/config.status .PHONY: postgres-configure-v14 -postgres-configure-v14: $(POSTGRES_INSTALL_DIR)/build/v14/config.status +postgres-configure-v14: $(BUILD_DIR)/v14/config.status # Install the PostgreSQL header files into $(POSTGRES_INSTALL_DIR)//include .PHONY: postgres-headers-% postgres-headers-%: postgres-configure-% +@echo "Installing PostgreSQL $* headers" - $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/src/include MAKELEVEL=0 install + $(MAKE) -C $(BUILD_DIR)/$*/src/include MAKELEVEL=0 install # Compile and install PostgreSQL .PHONY: postgres-% postgres-%: postgres-configure-% \ postgres-headers-% # to prevent `make install` conflicts with neon's `postgres-headers` +@echo "Compiling PostgreSQL $*" - $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$* MAKELEVEL=0 install - +@echo "Compiling libpq $*" - $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/src/interfaces/libpq install + $(MAKE) -C $(BUILD_DIR)/$* MAKELEVEL=0 install +@echo "Compiling pg_prewarm $*" - $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pg_prewarm install + $(MAKE) -C $(BUILD_DIR)/$*/contrib/pg_prewarm install +@echo "Compiling pg_buffercache $*" - $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pg_buffercache install + $(MAKE) -C $(BUILD_DIR)/$*/contrib/pg_buffercache install +@echo "Compiling pg_visibility $*" - $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pg_visibility install + $(MAKE) -C $(BUILD_DIR)/$*/contrib/pg_visibility install +@echo "Compiling pageinspect $*" - $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pageinspect install + $(MAKE) -C $(BUILD_DIR)/$*/contrib/pageinspect install +@echo "Compiling pg_trgm $*" - $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pg_trgm 
install + $(MAKE) -C $(BUILD_DIR)/$*/contrib/pg_trgm install +@echo "Compiling amcheck $*" - $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/amcheck install + $(MAKE) -C $(BUILD_DIR)/$*/contrib/amcheck install +@echo "Compiling test_decoding $*" - $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/test_decoding install - -.PHONY: postgres-clean-% -postgres-clean-%: - $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$* MAKELEVEL=0 clean - $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pg_buffercache clean - $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/contrib/pageinspect clean - $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/src/interfaces/libpq clean + $(MAKE) -C $(BUILD_DIR)/$*/contrib/test_decoding install .PHONY: postgres-check-% postgres-check-%: postgres-% - $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$* MAKELEVEL=0 check + $(MAKE) -C $(BUILD_DIR)/$* MAKELEVEL=0 check .PHONY: neon-pg-ext-% neon-pg-ext-%: postgres-% - +@echo "Compiling neon $*" - mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-$* + +@echo "Compiling neon-specific Postgres extensions for $*" + mkdir -p $(BUILD_DIR)/pgxn-$* $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \ - -C $(POSTGRES_INSTALL_DIR)/build/neon-$* \ - -f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile install - +@echo "Compiling neon_walredo $*" - mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-walredo-$* - $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \ - -C $(POSTGRES_INSTALL_DIR)/build/neon-walredo-$* \ - -f $(ROOT_PROJECT_DIR)/pgxn/neon_walredo/Makefile install - +@echo "Compiling neon_rmgr $*" - mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-rmgr-$* - $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \ - -C $(POSTGRES_INSTALL_DIR)/build/neon-rmgr-$* \ - -f $(ROOT_PROJECT_DIR)/pgxn/neon_rmgr/Makefile install - +@echo "Compiling neon_test_utils $*" - mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$* - $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \ - -C $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$* \ - -f $(ROOT_PROJECT_DIR)/pgxn/neon_test_utils/Makefile install - +@echo "Compiling neon_utils $*" - mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* - $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \ - -C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \ - -f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile install - -.PHONY: neon-pg-clean-ext-% -neon-pg-clean-ext-%: - $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config \ - -C $(POSTGRES_INSTALL_DIR)/build/neon-$* \ - -f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile clean - $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config \ - -C $(POSTGRES_INSTALL_DIR)/build/neon-walredo-$* \ - -f $(ROOT_PROJECT_DIR)/pgxn/neon_walredo/Makefile clean - $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config \ - -C $(POSTGRES_INSTALL_DIR)/build/neon-test-utils-$* \ - -f $(ROOT_PROJECT_DIR)/pgxn/neon_test_utils/Makefile clean - $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config \ - -C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \ - -f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile clean + -C $(BUILD_DIR)/pgxn-$*\ + -f $(ROOT_PROJECT_DIR)/pgxn/Makefile install # Build walproposer as a static library. walproposer source code is located # in the pgxn/neon directory. 
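Note: to make the path churn in the next hunk easier to follow, here is a rough shell equivalent of what the `walproposer-lib` recipe does after this change, using the new top-level `build/` directory (abridged; the Makefile does not state why the `ar d` members are deleted, presumably to avoid providing the same symbols twice at link time):

    # build the static walproposer library from the pgxn/neon sources
    mkdir -p build/walproposer-lib
    make PG_CONFIG=pg_install/v17/bin/pg_config COPT="$COPT" \
        -C build/walproposer-lib \
        -f "$ROOT_PROJECT_DIR/pgxn/neon/Makefile" walproposer-lib
    # stage the Postgres static libraries next to the walproposer objects
    cp pg_install/v17/lib/libpgport.a build/walproposer-lib
    cp pg_install/v17/lib/libpgcommon.a build/walproposer-lib
    # `ar d` deletes the named members from an archive in place
    ar d build/walproposer-lib/libpgport.a pg_strong_random.o
    ar d build/walproposer-lib/libpgcommon.a \
        checksum_helper.o cryptohash_openssl.o hmac_openssl.o \
        parse_manifest.o scram-common.o
    ar d build/walproposer-lib/libpgcommon.a pg_crc32c.o   # Linux only
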
@@ -233,15 +198,15 @@ neon-pg-clean-ext-%: .PHONY: walproposer-lib walproposer-lib: neon-pg-ext-v17 +@echo "Compiling walproposer-lib" - mkdir -p $(POSTGRES_INSTALL_DIR)/build/walproposer-lib + mkdir -p $(BUILD_DIR)/walproposer-lib $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v17/bin/pg_config COPT='$(COPT)' \ - -C $(POSTGRES_INSTALL_DIR)/build/walproposer-lib \ + -C $(BUILD_DIR)/walproposer-lib \ -f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile walproposer-lib - cp $(POSTGRES_INSTALL_DIR)/v17/lib/libpgport.a $(POSTGRES_INSTALL_DIR)/build/walproposer-lib - cp $(POSTGRES_INSTALL_DIR)/v17/lib/libpgcommon.a $(POSTGRES_INSTALL_DIR)/build/walproposer-lib - $(AR) d $(POSTGRES_INSTALL_DIR)/build/walproposer-lib/libpgport.a \ + cp $(POSTGRES_INSTALL_DIR)/v17/lib/libpgport.a $(BUILD_DIR)/walproposer-lib + cp $(POSTGRES_INSTALL_DIR)/v17/lib/libpgcommon.a $(BUILD_DIR)/walproposer-lib + $(AR) d $(BUILD_DIR)/walproposer-lib/libpgport.a \ pg_strong_random.o - $(AR) d $(POSTGRES_INSTALL_DIR)/build/walproposer-lib/libpgcommon.a \ + $(AR) d $(BUILD_DIR)/walproposer-lib/libpgcommon.a \ checksum_helper.o \ cryptohash_openssl.o \ hmac_openssl.o \ @@ -249,16 +214,10 @@ walproposer-lib: neon-pg-ext-v17 parse_manifest.o \ scram-common.o ifeq ($(UNAME_S),Linux) - $(AR) d $(POSTGRES_INSTALL_DIR)/build/walproposer-lib/libpgcommon.a \ + $(AR) d $(BUILD_DIR)/walproposer-lib/libpgcommon.a \ pg_crc32c.o endif -.PHONY: walproposer-lib-clean -walproposer-lib-clean: - $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v17/bin/pg_config \ - -C $(POSTGRES_INSTALL_DIR)/build/walproposer-lib \ - -f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile clean - .PHONY: neon-pg-ext neon-pg-ext: \ neon-pg-ext-v14 \ @@ -266,13 +225,6 @@ neon-pg-ext: \ neon-pg-ext-v16 \ neon-pg-ext-v17 -.PHONY: neon-pg-clean-ext -neon-pg-clean-ext: \ - neon-pg-clean-ext-v14 \ - neon-pg-clean-ext-v15 \ - neon-pg-clean-ext-v16 \ - neon-pg-clean-ext-v17 - # shorthand to build all Postgres versions .PHONY: postgres postgres: \ @@ -288,13 +240,6 @@ postgres-headers: \ postgres-headers-v16 \ postgres-headers-v17 -.PHONY: postgres-clean -postgres-clean: \ - postgres-clean-v14 \ - postgres-clean-v15 \ - postgres-clean-v16 \ - postgres-clean-v17 - .PHONY: postgres-check postgres-check: \ postgres-check-v14 \ @@ -302,12 +247,6 @@ postgres-check: \ postgres-check-v16 \ postgres-check-v17 -# This doesn't remove the effects of 'configure'. -.PHONY: clean -clean: postgres-clean neon-pg-clean-ext - $(MAKE) -C compute clean - $(CARGO_CMD_PREFIX) cargo clean - # This removes everything .PHONY: distclean distclean: @@ -320,7 +259,7 @@ fmt: postgres-%-pg-bsd-indent: postgres-% +@echo "Compiling pg_bsd_indent" - $(MAKE) -C $(POSTGRES_INSTALL_DIR)/build/$*/src/tools/pg_bsd_indent/ + $(MAKE) -C $(BUILD_DIR)/$*/src/tools/pg_bsd_indent/ # Create typedef list for the core. Note that generally it should be combined with # buildfarm one to cover platform specific stuff. @@ -339,7 +278,7 @@ postgres-%-pgindent: postgres-%-pg-bsd-indent postgres-%-typedefs.list cat $(ROOT_PROJECT_DIR)/vendor/postgres-$*/src/tools/pgindent/typedefs.list |\ cat - postgres-$*-typedefs.list | sort | uniq > postgres-$*-typedefs-full.list +@echo note: you might want to run it on selected files/dirs instead. 
- INDENT=$(POSTGRES_INSTALL_DIR)/build/$*/src/tools/pg_bsd_indent/pg_bsd_indent \ + INDENT=$(BUILD_DIR)/$*/src/tools/pg_bsd_indent/pg_bsd_indent \ $(ROOT_PROJECT_DIR)/vendor/postgres-$*/src/tools/pgindent/pgindent --typedefs postgres-$*-typedefs-full.list \ $(ROOT_PROJECT_DIR)/vendor/postgres-$*/src/ \ --excludes $(ROOT_PROJECT_DIR)/vendor/postgres-$*/src/tools/pgindent/exclude_file_patterns @@ -350,9 +289,9 @@ postgres-%-pgindent: postgres-%-pg-bsd-indent postgres-%-typedefs.list neon-pgindent: postgres-v17-pg-bsd-indent neon-pg-ext-v17 $(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v17/bin/pg_config COPT='$(COPT)' \ FIND_TYPEDEF=$(ROOT_PROJECT_DIR)/vendor/postgres-v17/src/tools/find_typedef \ - INDENT=$(POSTGRES_INSTALL_DIR)/build/v17/src/tools/pg_bsd_indent/pg_bsd_indent \ + INDENT=$(BUILD_DIR)/v17/src/tools/pg_bsd_indent/pg_bsd_indent \ PGINDENT_SCRIPT=$(ROOT_PROJECT_DIR)/vendor/postgres-v17/src/tools/pgindent/pgindent \ - -C $(POSTGRES_INSTALL_DIR)/build/neon-v17 \ + -C $(BUILD_DIR)/neon-v17 \ -f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile pgindent diff --git a/compute/.gitignore b/compute/.gitignore index 70980d335a..aab2afaa4e 100644 --- a/compute/.gitignore +++ b/compute/.gitignore @@ -3,3 +3,6 @@ etc/neon_collector.yml etc/neon_collector_autoscaling.yml etc/sql_exporter.yml etc/sql_exporter_autoscaling.yml + +# Node.js dependencies +node_modules/ diff --git a/compute/Makefile b/compute/Makefile index 0036196160..ef2e55f7b1 100644 --- a/compute/Makefile +++ b/compute/Makefile @@ -22,7 +22,7 @@ sql_exporter.yml: $(jsonnet_files) --output-file etc/$@ \ --tla-str collector_name=neon_collector \ --tla-str collector_file=neon_collector.yml \ - --tla-str 'connection_string=postgresql://cloud_admin@127.0.0.1:5432/postgres?sslmode=disable&application_name=sql_exporter' \ + --tla-str 'connection_string=postgresql://cloud_admin@127.0.0.1:5432/postgres?sslmode=disable&application_name=sql_exporter&pgaudit.log=none' \ etc/sql_exporter.jsonnet sql_exporter_autoscaling.yml: $(jsonnet_files) @@ -30,7 +30,7 @@ sql_exporter_autoscaling.yml: $(jsonnet_files) --output-file etc/$@ \ --tla-str collector_name=neon_collector_autoscaling \ --tla-str collector_file=neon_collector_autoscaling.yml \ - --tla-str 'connection_string=postgresql://cloud_admin@127.0.0.1:5432/postgres?sslmode=disable&application_name=sql_exporter_autoscaling' \ + --tla-str 'connection_string=postgresql://cloud_admin@127.0.0.1:5432/postgres?sslmode=disable&application_name=sql_exporter_autoscaling&pgaudit.log=none' \ etc/sql_exporter.jsonnet .PHONY: clean @@ -48,3 +48,11 @@ jsonnetfmt-test: .PHONY: jsonnetfmt-format jsonnetfmt-format: jsonnetfmt --in-place $(jsonnet_files) + +.PHONY: manifest-schema-validation +manifest-schema-validation: node_modules + node_modules/.bin/jsonschema validate -d https://json-schema.org/draft/2020-12/schema manifest.schema.json manifest.yaml + +node_modules: package.json + npm install + touch node_modules diff --git a/compute/compute-node.Dockerfile b/compute/compute-node.Dockerfile index 248f52088b..35ece73030 100644 --- a/compute/compute-node.Dockerfile +++ b/compute/compute-node.Dockerfile @@ -77,9 +77,6 @@ # build_and_test.yml github workflow for how that's done. 
ARG PG_VERSION -ARG REPOSITORY=ghcr.io/neondatabase -ARG IMAGE=build-tools -ARG TAG=pinned ARG BUILD_TAG ARG DEBIAN_VERSION=bookworm ARG DEBIAN_FLAVOR=${DEBIAN_VERSION}-slim @@ -149,8 +146,11 @@ RUN case $DEBIAN_VERSION in \ ninja-build git autoconf automake libtool build-essential bison flex libreadline-dev \ zlib1g-dev libxml2-dev libcurl4-openssl-dev libossp-uuid-dev wget ca-certificates pkg-config libssl-dev \ libicu-dev libxslt1-dev liblz4-dev libzstd-dev zstd curl unzip g++ \ + libclang-dev \ + jsonnet \ $VERSION_INSTALLS \ - && apt clean && rm -rf /var/lib/apt/lists/* + && apt clean && rm -rf /var/lib/apt/lists/* && \ + useradd -ms /bin/bash nonroot -b /home ######################################################################################### # @@ -171,9 +171,6 @@ RUN cd postgres && \ eval $CONFIGURE_CMD && \ make MAKELEVEL=0 -j $(getconf _NPROCESSORS_ONLN) -s install && \ make MAKELEVEL=0 -j $(getconf _NPROCESSORS_ONLN) -s -C contrib/ install && \ - # Install headers - make MAKELEVEL=0 -j $(getconf _NPROCESSORS_ONLN) -s -C src/include install && \ - make MAKELEVEL=0 -j $(getconf _NPROCESSORS_ONLN) -s -C src/interfaces/libpq install && \ # Enable some of contrib extensions echo 'trusted = true' >> /usr/local/pgsql/share/extension/autoinc.control && \ echo 'trusted = true' >> /usr/local/pgsql/share/extension/dblink.control && \ @@ -1057,17 +1054,10 @@ RUN make -j $(getconf _NPROCESSORS_ONLN) && \ ######################################################################################### # -# Layer "pg build with nonroot user and cargo installed" -# This layer is base and common for layers with `pgrx` +# Layer "build-deps with Rust toolchain installed" # ######################################################################################### -FROM pg-build AS pg-build-nonroot-with-cargo -ARG PG_VERSION - -RUN apt update && \ - apt install --no-install-recommends --no-install-suggests -y curl libclang-dev && \ - apt clean && rm -rf /var/lib/apt/lists/* && \ - useradd -ms /bin/bash nonroot -b /home +FROM build-deps AS build-deps-with-cargo ENV HOME=/home/nonroot ENV PATH="/home/nonroot/.cargo/bin:$PATH" @@ -1082,13 +1072,29 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux ./rustup-init -y --no-modify-path --profile minimal --default-toolchain stable && \ rm rustup-init +######################################################################################### +# +# Layer "pg-build with Rust toolchain installed" +# This layer is base and common for layers with `pgrx` +# +######################################################################################### +FROM pg-build AS pg-build-with-cargo +ARG PG_VERSION + +ENV HOME=/home/nonroot +ENV PATH="/home/nonroot/.cargo/bin:$PATH" +USER nonroot +WORKDIR /home/nonroot + +COPY --from=build-deps-with-cargo /home/nonroot /home/nonroot + ######################################################################################### # # Layer "rust extensions" # This layer is used to build `pgrx` deps # ######################################################################################### -FROM pg-build-nonroot-with-cargo AS rust-extensions-build +FROM pg-build-with-cargo AS rust-extensions-build ARG PG_VERSION RUN case "${PG_VERSION:?}" in \ @@ -1110,7 +1116,7 @@ USER root # and eventually get merged with `rust-extensions-build` # ######################################################################################### -FROM pg-build-nonroot-with-cargo AS rust-extensions-build-pgrx12 +FROM pg-build-with-cargo 
AS rust-extensions-build-pgrx12 ARG PG_VERSION RUN cargo install --locked --version 0.12.9 cargo-pgrx && \ @@ -1127,7 +1133,7 @@ USER root # and eventually get merged with `rust-extensions-build` # ######################################################################################### -FROM pg-build-nonroot-with-cargo AS rust-extensions-build-pgrx14 +FROM pg-build-with-cargo AS rust-extensions-build-pgrx14 ARG PG_VERSION RUN cargo install --locked --version 0.14.1 cargo-pgrx && \ @@ -1144,10 +1150,12 @@ USER root FROM build-deps AS pgrag-src ARG PG_VERSION - WORKDIR /ext-src +COPY compute/patches/onnxruntime.patch . + RUN wget https://github.com/microsoft/onnxruntime/archive/refs/tags/v1.18.1.tar.gz -O onnxruntime.tar.gz && \ mkdir onnxruntime-src && cd onnxruntime-src && tar xzf ../onnxruntime.tar.gz --strip-components=1 -C . && \ + patch -p1 < /ext-src/onnxruntime.patch && \ echo "#nothing to test here" > neon-test.sh RUN wget https://github.com/neondatabase-labs/pgrag/archive/refs/tags/v0.1.2.tar.gz -O pgrag.tar.gz && \ @@ -1557,20 +1565,20 @@ ARG PG_VERSION WORKDIR /ext-src RUN case "${PG_VERSION}" in \ "v14") \ - export PGAUDIT_VERSION=1.6.2 \ - export PGAUDIT_CHECKSUM=1f350d70a0cbf488c0f2b485e3a5c9b11f78ad9e3cbb95ef6904afa1eb3187eb \ + export PGAUDIT_VERSION=1.6.3 \ + export PGAUDIT_CHECKSUM=37a8f5a7cc8d9188e536d15cf0fdc457fcdab2547caedb54442c37f124110919 \ ;; \ "v15") \ - export PGAUDIT_VERSION=1.7.0 \ - export PGAUDIT_CHECKSUM=8f4a73e451c88c567e516e6cba7dc1e23bc91686bb6f1f77f8f3126d428a8bd8 \ + export PGAUDIT_VERSION=1.7.1 \ + export PGAUDIT_CHECKSUM=e9c8e6e092d82b2f901d72555ce0fe7780552f35f8985573796cd7e64b09d4ec \ ;; \ "v16") \ - export PGAUDIT_VERSION=16.0 \ - export PGAUDIT_CHECKSUM=d53ef985f2d0b15ba25c512c4ce967dce07b94fd4422c95bd04c4c1a055fe738 \ + export PGAUDIT_VERSION=16.1 \ + export PGAUDIT_CHECKSUM=3bae908ab70ba0c6f51224009dbcfff1a97bd6104c6273297a64292e1b921fee \ ;; \ "v17") \ - export PGAUDIT_VERSION=17.0 \ - export PGAUDIT_CHECKSUM=7d0d08d030275d525f36cd48b38c6455f1023da863385badff0cec44965bfd8c \ + export PGAUDIT_VERSION=17.1 \ + export PGAUDIT_CHECKSUM=9c5f37504d393486cc75d2ced83f75f5899be64fa85f689d6babb833b4361e6c \ ;; \ *) \ echo "pgaudit is not supported on this PostgreSQL version" && exit 1;; \ @@ -1621,18 +1629,7 @@ FROM pg-build AS neon-ext-build ARG PG_VERSION COPY pgxn/ pgxn/ -RUN make -j $(getconf _NPROCESSORS_ONLN) \ - -C pgxn/neon \ - -s install && \ - make -j $(getconf _NPROCESSORS_ONLN) \ - -C pgxn/neon_utils \ - -s install && \ - make -j $(getconf _NPROCESSORS_ONLN) \ - -C pgxn/neon_test_utils \ - -s install && \ - make -j $(getconf _NPROCESSORS_ONLN) \ - -C pgxn/neon_rmgr \ - -s install +RUN make -j $(getconf _NPROCESSORS_ONLN) -C pgxn -s install-compute ######################################################################################### # @@ -1722,7 +1719,7 @@ FROM extensions-${EXTENSIONS} AS neon-pg-ext-build # Compile the Neon-specific `compute_ctl`, `fast_import`, and `local_proxy` binaries # ######################################################################################### -FROM $REPOSITORY/$IMAGE:$TAG AS compute-tools +FROM build-deps-with-cargo AS compute-tools ARG BUILD_TAG ENV BUILD_TAG=$BUILD_TAG @@ -1732,7 +1729,7 @@ COPY --chown=nonroot . . 
RUN --mount=type=cache,uid=1000,target=/home/nonroot/.cargo/registry \ --mount=type=cache,uid=1000,target=/home/nonroot/.cargo/git \ --mount=type=cache,uid=1000,target=/home/nonroot/target \ - mold -run cargo build --locked --profile release-line-debug-size-lto --bin compute_ctl --bin fast_import --bin local_proxy && \ + cargo build --locked --profile release-line-debug-size-lto --bin compute_ctl --bin fast_import --bin local_proxy && \ mkdir target-bin && \ cp target/release-line-debug-size-lto/compute_ctl \ target/release-line-debug-size-lto/fast_import \ @@ -1826,10 +1823,11 @@ RUN rm /usr/local/pgsql/lib/lib*.a # Preprocess the sql_exporter configuration files # ######################################################################################### -FROM $REPOSITORY/$IMAGE:$TAG AS sql_exporter_preprocessor +FROM build-deps AS sql_exporter_preprocessor ARG PG_VERSION USER nonroot +WORKDIR /home/nonroot COPY --chown=nonroot compute compute diff --git a/compute/etc/pgbouncer.ini b/compute/etc/pgbouncer.ini index 9d68cbb8d5..fbcdfd4a87 100644 --- a/compute/etc/pgbouncer.ini +++ b/compute/etc/pgbouncer.ini @@ -21,6 +21,8 @@ unix_socket_dir=/tmp/ unix_socket_mode=0777 ; required for pgbouncer_exporter ignore_startup_parameters=extra_float_digits +; pidfile for graceful termination +pidfile=/tmp/pgbouncer.pid ;; Disable connection logging. It produces a lot of logs that no one looks at, ;; and we can get similar log entries from the proxy too. We had incidents in diff --git a/compute/manifest.schema.json b/compute/manifest.schema.json new file mode 100644 index 0000000000..a25055b45a --- /dev/null +++ b/compute/manifest.schema.json @@ -0,0 +1,209 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "title": "Neon Compute Manifest Schema", + "description": "Schema for Neon compute node configuration manifest", + "type": "object", + "properties": { + "pg_settings": { + "type": "object", + "properties": { + "common": { + "type": "object", + "properties": { + "client_connection_check_interval": { + "type": "string", + "description": "Check for client disconnection interval in milliseconds" + }, + "effective_io_concurrency": { + "type": "string", + "description": "Effective IO concurrency setting" + }, + "fsync": { + "type": "string", + "enum": ["on", "off"], + "description": "Whether to force fsync to disk" + }, + "hot_standby": { + "type": "string", + "enum": ["on", "off"], + "description": "Whether hot standby is enabled" + }, + "idle_in_transaction_session_timeout": { + "type": "string", + "description": "Timeout for idle transactions in milliseconds" + }, + "listen_addresses": { + "type": "string", + "description": "Addresses to listen on" + }, + "log_connections": { + "type": "string", + "enum": ["on", "off"], + "description": "Whether to log connections" + }, + "log_disconnections": { + "type": "string", + "enum": ["on", "off"], + "description": "Whether to log disconnections" + }, + "log_temp_files": { + "type": "string", + "description": "Size threshold for logging temporary files in KB" + }, + "log_error_verbosity": { + "type": "string", + "enum": ["terse", "verbose", "default"], + "description": "Error logging verbosity level" + }, + "log_min_error_statement": { + "type": "string", + "description": "Minimum error level for statement logging" + }, + "maintenance_io_concurrency": { + "type": "string", + "description": "Maintenance IO concurrency setting" + }, + "max_connections": { + "type": "string", + "description": "Maximum number of connections" + }, + 
"max_replication_flush_lag": { + "type": "string", + "description": "Maximum replication flush lag" + }, + "max_replication_slots": { + "type": "string", + "description": "Maximum number of replication slots" + }, + "max_replication_write_lag": { + "type": "string", + "description": "Maximum replication write lag" + }, + "max_wal_senders": { + "type": "string", + "description": "Maximum number of WAL senders" + }, + "max_wal_size": { + "type": "string", + "description": "Maximum WAL size" + }, + "neon.unstable_extensions": { + "type": "string", + "description": "List of unstable extensions" + }, + "neon.protocol_version": { + "type": "string", + "description": "Neon protocol version" + }, + "password_encryption": { + "type": "string", + "description": "Password encryption method" + }, + "restart_after_crash": { + "type": "string", + "enum": ["on", "off"], + "description": "Whether to restart after crash" + }, + "superuser_reserved_connections": { + "type": "string", + "description": "Number of reserved connections for superuser" + }, + "synchronous_standby_names": { + "type": "string", + "description": "Names of synchronous standby servers" + }, + "wal_keep_size": { + "type": "string", + "description": "WAL keep size" + }, + "wal_level": { + "type": "string", + "description": "WAL level" + }, + "wal_log_hints": { + "type": "string", + "enum": ["on", "off"], + "description": "Whether to log hints in WAL" + }, + "wal_sender_timeout": { + "type": "string", + "description": "WAL sender timeout in milliseconds" + } + }, + "required": [ + "client_connection_check_interval", + "effective_io_concurrency", + "fsync", + "hot_standby", + "idle_in_transaction_session_timeout", + "listen_addresses", + "log_connections", + "log_disconnections", + "log_temp_files", + "log_error_verbosity", + "log_min_error_statement", + "maintenance_io_concurrency", + "max_connections", + "max_replication_flush_lag", + "max_replication_slots", + "max_replication_write_lag", + "max_wal_senders", + "max_wal_size", + "neon.unstable_extensions", + "neon.protocol_version", + "password_encryption", + "restart_after_crash", + "superuser_reserved_connections", + "synchronous_standby_names", + "wal_keep_size", + "wal_level", + "wal_log_hints", + "wal_sender_timeout" + ] + }, + "replica": { + "type": "object", + "properties": { + "hot_standby": { + "type": "string", + "enum": ["on", "off"], + "description": "Whether hot standby is enabled for replicas" + } + }, + "required": ["hot_standby"] + }, + "per_version": { + "type": "object", + "patternProperties": { + "^1[4-7]$": { + "type": "object", + "properties": { + "common": { + "type": "object", + "properties": { + "io_combine_limit": { + "type": "string", + "description": "IO combine limit" + } + } + }, + "replica": { + "type": "object", + "properties": { + "recovery_prefetch": { + "type": "string", + "enum": ["on", "off"], + "description": "Whether to enable recovery prefetch for PostgreSQL replicas" + } + } + } + } + } + } + } + }, + "required": ["common", "replica", "per_version"] + } + }, + "required": ["pg_settings"] +} diff --git a/compute/manifest.yaml b/compute/manifest.yaml index f1cd20c497..4425241d8a 100644 --- a/compute/manifest.yaml +++ b/compute/manifest.yaml @@ -105,17 +105,17 @@ pg_settings: # Neon hot standby ignores pages that are not in the shared_buffers recovery_prefetch: "off" 16: - common: + common: {} replica: # prefetching of blocks referenced in WAL doesn't make sense for us # Neon hot standby ignores pages that are not in the shared_buffers 
recovery_prefetch: "off" 15: - common: + common: {} replica: # prefetching of blocks referenced in WAL doesn't make sense for us # Neon hot standby ignores pages that are not in the shared_buffers recovery_prefetch: "off" 14: - common: - replica: + common: {} + replica: {} diff --git a/compute/package-lock.json b/compute/package-lock.json new file mode 100644 index 0000000000..693a37cfcb --- /dev/null +++ b/compute/package-lock.json @@ -0,0 +1,37 @@ +{ + "name": "neon-compute", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "neon-compute", + "dependencies": { + "@sourcemeta/jsonschema": "9.3.4" + } + }, + "node_modules/@sourcemeta/jsonschema": { + "version": "9.3.4", + "resolved": "https://registry.npmjs.org/@sourcemeta/jsonschema/-/jsonschema-9.3.4.tgz", + "integrity": "sha512-hkujfkZAIGXUs4U//We9faZW8LZ4/H9LqagRYsFSulH/VLcKPNhZyCTGg7AhORuzm27zqENvKpnX4g2FzudYFw==", + "cpu": [ + "x64", + "arm64" + ], + "license": "AGPL-3.0", + "os": [ + "darwin", + "linux", + "win32" + ], + "bin": { + "jsonschema": "cli.js" + }, + "engines": { + "node": ">=16" + }, + "funding": { + "url": "https://github.com/sponsors/sourcemeta" + } + } + } +} diff --git a/compute/package.json b/compute/package.json new file mode 100644 index 0000000000..581384dc13 --- /dev/null +++ b/compute/package.json @@ -0,0 +1,7 @@ +{ + "name": "neon-compute", + "private": true, + "dependencies": { + "@sourcemeta/jsonschema": "9.3.4" + } +} \ No newline at end of file diff --git a/compute/patches/onnxruntime.patch b/compute/patches/onnxruntime.patch new file mode 100644 index 0000000000..2347547e73 --- /dev/null +++ b/compute/patches/onnxruntime.patch @@ -0,0 +1,15 @@ +diff --git a/cmake/deps.txt b/cmake/deps.txt +index d213b09034..229de2ebf0 100644 +--- a/cmake/deps.txt ++++ b/cmake/deps.txt +@@ -22,7 +22,9 @@ dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b3132 + # it contains changes on top of 3.4.0 which are required to fix build issues. + # Until the 3.4.1 release this is the best option we have. 
+ # Issue link: https://gitlab.com/libeigen/eigen/-/issues/2744 +-eigen;https://gitlab.com/libeigen/eigen/-/archive/e7248b26a1ed53fa030c5c459f7ea095dfd276ac/eigen-e7248b26a1ed53fa030c5c459f7ea095dfd276ac.zip;be8be39fdbc6e60e94fa7870b280707069b5b81a ++# Moved to github mirror to avoid gitlab issues. ++# Issue link: https://github.com/bazelbuild/bazel-central-registry/issues/4355 ++eigen;https://github.com/eigen-mirror/eigen/archive/e7248b26a1ed53fa030c5c459f7ea095dfd276ac/eigen-e7248b26a1ed53fa030c5c459f7ea095dfd276ac.zip;61418a349000ba7744a3ad03cf5071f22ebf860a + flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v23.5.26.zip;59422c3b5e573dd192fead2834d25951f1c1670c + fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494 + fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1 diff --git a/compute/vm-image-spec-bookworm.yaml b/compute/vm-image-spec-bookworm.yaml index 057099994a..267e4c83b5 100644 --- a/compute/vm-image-spec-bookworm.yaml +++ b/compute/vm-image-spec-bookworm.yaml @@ -26,7 +26,7 @@ commands: - name: postgres-exporter user: nobody sysvInitAction: respawn - shell: 'DATA_SOURCE_NAME="user=cloud_admin sslmode=disable dbname=postgres application_name=postgres-exporter" /bin/postgres_exporter --config.file=/etc/postgres_exporter.yml' + shell: 'DATA_SOURCE_NAME="user=cloud_admin sslmode=disable dbname=postgres application_name=postgres-exporter pgaudit.log=none" /bin/postgres_exporter --config.file=/etc/postgres_exporter.yml' - name: pgbouncer-exporter user: postgres sysvInitAction: respawn @@ -59,7 +59,7 @@ files: # the rules use ALL as the hostname. Avoid the pointless lookups and the "unable to # resolve host" log messages that they generate. Defaults !fqdn - + # Allow postgres user (which is what compute_ctl runs as) to run /neonvm/bin/resize-swap # and /neonvm/bin/set-disk-quota as root without requiring entering a password (NOPASSWD), # regardless of hostname (ALL) diff --git a/compute/vm-image-spec-bullseye.yaml b/compute/vm-image-spec-bullseye.yaml index d048e20b2e..2b6e77b656 100644 --- a/compute/vm-image-spec-bullseye.yaml +++ b/compute/vm-image-spec-bullseye.yaml @@ -26,7 +26,7 @@ commands: - name: postgres-exporter user: nobody sysvInitAction: respawn - shell: 'DATA_SOURCE_NAME="user=cloud_admin sslmode=disable dbname=postgres application_name=postgres-exporter" /bin/postgres_exporter --config.file=/etc/postgres_exporter.yml' + shell: 'DATA_SOURCE_NAME="user=cloud_admin sslmode=disable dbname=postgres application_name=postgres-exporter pgaudit.log=none" /bin/postgres_exporter --config.file=/etc/postgres_exporter.yml' - name: pgbouncer-exporter user: postgres sysvInitAction: respawn @@ -59,7 +59,7 @@ files: # the rules use ALL as the hostname. Avoid the pointless lookups and the "unable to # resolve host" log messages that they generate.
Defaults !fqdn - + # Allow postgres user (which is what compute_ctl runs as) to run /neonvm/bin/resize-swap # and /neonvm/bin/set-disk-quota as root without requiring entering a password (NOPASSWD), # regardless of hostname (ALL) diff --git a/compute_tools/Cargo.toml b/compute_tools/Cargo.toml index f9da3ba700..a5879c4b7c 100644 --- a/compute_tools/Cargo.toml +++ b/compute_tools/Cargo.toml @@ -64,6 +64,7 @@ uuid.workspace = true walkdir.workspace = true x509-cert.workspace = true +postgres_versioninfo.workspace = true postgres_initdb.workspace = true compute_api.workspace = true utils.workspace = true diff --git a/compute_tools/src/bin/compute_ctl.rs b/compute_tools/src/bin/compute_ctl.rs index 8b502a058e..d7ff381f1b 100644 --- a/compute_tools/src/bin/compute_ctl.rs +++ b/compute_tools/src/bin/compute_ctl.rs @@ -124,6 +124,10 @@ struct Cli { /// Interval in seconds for collecting installed extensions statistics #[arg(long, default_value = "3600")] pub installed_extensions_collection_interval: u64, + + /// Run in development mode, skipping VM-specific operations like process termination + #[arg(long, action = clap::ArgAction::SetTrue)] + pub dev: bool, } impl Cli { @@ -159,7 +163,7 @@ fn main() -> Result<()> { .build()?; let _rt_guard = runtime.enter(); - runtime.block_on(init())?; + runtime.block_on(init(cli.dev))?; // enable core dumping for all child processes setrlimit(Resource::CORE, rlimit::INFINITY, rlimit::INFINITY)?; @@ -198,13 +202,13 @@ fn main() -> Result<()> { deinit_and_exit(exit_code); } -async fn init() -> Result<()> { +async fn init(dev_mode: bool) -> Result<()> { init_tracing_and_logging(DEFAULT_LOG_LEVEL).await?; let mut signals = Signals::new([SIGINT, SIGTERM, SIGQUIT])?; thread::spawn(move || { for sig in signals.forever() { - handle_exit_signal(sig); + handle_exit_signal(sig, dev_mode); } }); @@ -263,9 +267,9 @@ fn deinit_and_exit(exit_code: Option<i32>) -> ! { /// When compute_ctl is killed, send also termination signal to sync-safekeepers /// to prevent leakage. TODO: it is better to convert compute_ctl to async and /// wait for termination which would be easy then.
-fn handle_exit_signal(sig: i32) { +fn handle_exit_signal(sig: i32, dev_mode: bool) { info!("received {sig} termination signal"); - forward_termination_signal(); + forward_termination_signal(dev_mode); exit(1); } diff --git a/compute_tools/src/bin/fast_import.rs b/compute_tools/src/bin/fast_import.rs index e65c210b23..0eca9aba53 100644 --- a/compute_tools/src/bin/fast_import.rs +++ b/compute_tools/src/bin/fast_import.rs @@ -29,7 +29,7 @@ use anyhow::{Context, bail}; use aws_config::BehaviorVersion; use camino::{Utf8Path, Utf8PathBuf}; use clap::{Parser, Subcommand}; -use compute_tools::extension_server::{PostgresMajorVersion, get_pg_version}; +use compute_tools::extension_server::get_pg_version; use nix::unistd::Pid; use std::ops::Not; use tracing::{Instrument, error, info, info_span, warn}; @@ -179,12 +179,8 @@ impl PostgresProcess { .await .context("create pgdata directory")?; - let pg_version = match get_pg_version(self.pgbin.as_ref()) { - PostgresMajorVersion::V14 => 14, - PostgresMajorVersion::V15 => 15, - PostgresMajorVersion::V16 => 16, - PostgresMajorVersion::V17 => 17, - }; + let pg_version = get_pg_version(self.pgbin.as_ref()); + postgres_initdb::do_run_initdb(postgres_initdb::RunInitdbArgs { superuser: initdb_user, locale: DEFAULT_LOCALE, // XXX: this shouldn't be hard-coded, @@ -486,10 +482,8 @@ async fn cmd_pgdata( }; let superuser = "cloud_admin"; - let destination_connstring = format!( - "host=localhost port={} user={} dbname=neondb", - pg_port, superuser - ); + let destination_connstring = + format!("host=localhost port={pg_port} user={superuser} dbname=neondb"); let pgdata_dir = workdir.join("pgdata"); let mut proc = PostgresProcess::new(pgdata_dir.clone(), pg_bin_dir.clone(), pg_lib_dir.clone()); diff --git a/compute_tools/src/bin/fast_import/s3_uri.rs b/compute_tools/src/bin/fast_import/s3_uri.rs index cf4dab7c02..e1a85c73e7 100644 --- a/compute_tools/src/bin/fast_import/s3_uri.rs +++ b/compute_tools/src/bin/fast_import/s3_uri.rs @@ -69,7 +69,7 @@ impl clap::builder::TypedValueParser for S3Uri { S3Uri::from_str(value_str).map_err(|e| { clap::Error::raw( clap::error::ErrorKind::InvalidValue, - format!("Failed to parse S3 URI: {}", e), + format!("Failed to parse S3 URI: {e}"), ) }) } diff --git a/compute_tools/src/catalog.rs b/compute_tools/src/catalog.rs index 082ba62b8e..bc9f64075a 100644 --- a/compute_tools/src/catalog.rs +++ b/compute_tools/src/catalog.rs @@ -22,7 +22,7 @@ pub async fn get_dbs_and_roles(compute: &Arc<ComputeNode>) -> anyhow::Result<CatalogObjects> { let mut lines = stderr_reader.lines(); if let Some(line) = lines.next_line().await?
{ - if line.contains(&format!("FATAL: database \"{}\" does not exist", dbname)) { + if line.contains(&format!("FATAL: database \"{dbname}\" does not exist")) { return Err(SchemaDumpError::DatabaseDoesNotExist); } warn!("pg_dump stderr: {}", line) diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index bd6ed910be..70b2d28bf2 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -35,6 +35,7 @@ use url::Url; use utils::id::{TenantId, TimelineId}; use utils::lsn::Lsn; use utils::measured_stream::MeasuredReader; +use utils::pid_file; use crate::configurator::launch_configurator; use crate::disk_quota::set_disk_quota; @@ -44,6 +45,7 @@ use crate::lsn_lease::launch_lsn_lease_bg_task_for_static; use crate::metrics::COMPUTE_CTL_UP; use crate::monitor::launch_monitor; use crate::pg_helpers::*; +use crate::pgbouncer::*; use crate::rsyslog::{ PostgresLogsRsyslogConfig, configure_audit_rsyslog, configure_postgres_logs_export, launch_pgaudit_gc, @@ -161,6 +163,10 @@ pub struct ComputeState { pub lfc_prewarm_state: LfcPrewarmState, pub lfc_offload_state: LfcOffloadState, + /// WAL flush LSN that is set after terminating Postgres and syncing safekeepers if + /// mode == ComputeMode::Primary. None otherwise. + pub terminate_flush_lsn: Option<Lsn>, + pub metrics: ComputeMetrics, } @@ -176,6 +182,7 @@ impl ComputeState { metrics: ComputeMetrics::default(), lfc_prewarm_state: LfcPrewarmState::default(), lfc_offload_state: LfcOffloadState::default(), + terminate_flush_lsn: None, } } @@ -215,6 +222,45 @@ pub struct ParsedSpec { pub endpoint_storage_token: Option<String>, } +impl ParsedSpec { + pub fn validate(&self) -> Result<(), String> { + // Only Primary nodes use safekeeper_connstrings, and at the moment + // this method only validates that part of the spec. + if self.spec.mode != ComputeMode::Primary { + return Ok(()); + } + + // While it seems like a good idea to check for an odd number of entries in + // the safekeepers connection string, changes to the list of safekeepers may + // involve appending a new server to a list of 3, in which case a list of 4 + // entries is okay in production. + // + // Still, we want unique entries, and at least one entry in the vector + if self.safekeeper_connstrings.is_empty() { + return Err(String::from("safekeeper_connstrings is empty")); + } + + // check for uniqueness of the connection strings in the set + let mut connstrings = self.safekeeper_connstrings.clone(); + + connstrings.sort(); + let mut previous = &connstrings[0]; + + for current in connstrings.iter().skip(1) { + // duplicate entry?
+ if current == previous { + return Err(format!( + "duplicate entry in safekeeper_connstrings: {current}!", + )); + } + + previous = current; + } + + Ok(()) + } +} + impl TryFrom<ComputeSpec> for ParsedSpec { type Error = String; fn try_from(spec: ComputeSpec) -> Result<Self, Self::Error> { @@ -244,6 +290,7 @@ impl TryFrom<ComputeSpec> for ParsedSpec { } else { spec.safekeeper_connstrings.clone() }; + let storage_auth_token = spec.storage_auth_token.clone(); let tenant_id: TenantId = if let Some(tenant_id) = spec.tenant_id { tenant_id @@ -278,7 +325,7 @@ impl TryFrom<ComputeSpec> for ParsedSpec { .clone() .or_else(|| spec.cluster.settings.find("neon.endpoint_storage_token")); - Ok(ParsedSpec { + let res = ParsedSpec { spec, pageserver_connstr, safekeeper_connstrings, @@ -287,7 +334,11 @@ impl TryFrom<ComputeSpec> for ParsedSpec { timeline_id, endpoint_storage_addr, endpoint_storage_token, - }) + }; + + // Now check validity of the parsed specification + res.validate()?; + Ok(res) } } @@ -354,14 +405,11 @@ impl ComputeNode { // that can affect `compute_ctl` and prevent it from properly configuring the database schema. // Unset them via connection string options before connecting to the database. // N.B. keep it in sync with `ZENITH_OPTIONS` in `get_maintenance_client()`. - // - // TODO(ololobus): we currently pass `-c default_transaction_read_only=off` from control plane - // as well. After rolling out this code, we can remove this parameter from control plane. - // In the meantime, double-passing is fine, the last value is applied. - // See: - const EXTRA_OPTIONS: &str = "-c role=cloud_admin -c default_transaction_read_only=off -c search_path=public -c statement_timeout=0"; + const EXTRA_OPTIONS: &str = "-c role=cloud_admin -c default_transaction_read_only=off -c search_path=public -c statement_timeout=0 -c pgaudit.log=none"; let options = match conn_conf.get_options() { - Some(options) => format!("{} {}", options, EXTRA_OPTIONS), + // Allow the control plane to override any options set by the + // compute + Some(options) => format!("{EXTRA_OPTIONS} {options}"), None => EXTRA_OPTIONS.to_string(), }; conn_conf.options(&options); @@ -489,12 +537,21 @@ impl ComputeNode { // Reap the postgres process delay_exit |= this.cleanup_after_postgres_exit()?; + // /terminate returns an LSN. If we don't sleep at all, the connection will break and we + // won't get the result. If we sleep too much, tests will take significantly longer + // and the GitHub Actions run will error out + let sleep_duration = if delay_exit { + Duration::from_secs(30) + } else { + Duration::from_millis(300) + }; + // If launch failed, keep serving HTTP requests for a while, so the cloud // control plane can get the actual error.
if delay_exit { info!("giving control plane 30s to collect the error before shutdown"); - std::thread::sleep(Duration::from_secs(30)); } + std::thread::sleep(sleep_duration); Ok(exit_code) } @@ -785,7 +842,7 @@ impl ComputeNode { self.spawn_extension_stats_task(); if pspec.spec.autoprewarm { - self.prewarm_lfc(); + self.prewarm_lfc(None); } Ok(()) } @@ -866,20 +923,25 @@ impl ComputeNode { // Maybe sync safekeepers again, to speed up next startup let compute_state = self.state.lock().unwrap().clone(); let pspec = compute_state.pspec.as_ref().expect("spec must be set"); - if matches!(pspec.spec.mode, compute_api::spec::ComputeMode::Primary) { + let lsn = if matches!(pspec.spec.mode, compute_api::spec::ComputeMode::Primary) { info!("syncing safekeepers on shutdown"); let storage_auth_token = pspec.storage_auth_token.clone(); let lsn = self.sync_safekeepers(storage_auth_token)?; - info!("synced safekeepers at lsn {lsn}"); - } + info!(%lsn, "synced safekeepers"); + Some(lsn) + } else { + info!("not primary, not syncing safekeepers"); + None + }; let mut delay_exit = false; let mut state = self.state.lock().unwrap(); - if state.status == ComputeStatus::TerminationPending { + state.terminate_flush_lsn = lsn; + if let ComputeStatus::TerminationPending { mode } = state.status { state.status = ComputeStatus::Terminated; self.state_changed.notify_all(); // we were asked to terminate gracefully, don't exit to avoid restart - delay_exit = true + delay_exit = mode == compute_api::responses::TerminateMode::Fast } drop(state); @@ -1064,7 +1126,7 @@ impl ComputeNode { let sk_configs = sk_connstrs.into_iter().map(|connstr| { // Format connstr let id = connstr.clone(); - let connstr = format!("postgresql://no_user@{}", connstr); + let connstr = format!("postgresql://no_user@{connstr}"); let options = format!( "-c timeline_id={} tenant_id={}", pspec.timeline_id, pspec.tenant_id ); @@ -1427,7 +1489,7 @@ impl ComputeNode { let (mut client, connection) = conf.connect(NoTls).await?; tokio::spawn(async move { if let Err(e) = connection.await { - eprintln!("connection error: {}", e); + eprintln!("connection error: {e}"); } }); @@ -1570,7 +1632,7 @@ impl ComputeNode { Ok((mut client, connection)) => { tokio::spawn(async move { if let Err(e) = connection.await { - eprintln!("connection error: {}", e); + eprintln!("connection error: {e}"); } }); if let Err(e) = handle_migrations(&mut client).await { @@ -1750,7 +1812,7 @@ impl ComputeNode { // exit loop ComputeStatus::Failed - | ComputeStatus::TerminationPending + | ComputeStatus::TerminationPending { .. } | ComputeStatus::Terminated => break 'cert_update, // wait @@ -1874,7 +1936,7 @@ impl ComputeNode { let (client, connection) = connect_result.unwrap(); tokio::spawn(async move { if let Err(e) = connection.await { - eprintln!("connection error: {}", e); + eprintln!("connection error: {e}"); } }); let result = client @@ -2043,7 +2105,7 @@ LIMIT 100", db_client .simple_query(&query) .await - .with_context(|| format!("Failed to execute query: {}", query))?; + .with_context(|| format!("Failed to execute query: {query}"))?; } Ok(()) @@ -2070,7 +2132,7 @@ LIMIT 100", let version: Option<String> = db_client .query_opt(version_query, &[&ext_name]) .await - .with_context(|| format!("Failed to execute query: {}", version_query))? + .with_context(|| format!("Failed to execute query: {version_query}"))? .map(|row| row.get(0)); // sanitize the inputs as postgres idents.
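The hunks around here interpolate ext_name and quoted_version directly into CREATE EXTENSION / ALTER EXTENSION statements, which is only safe because the caller sanitizes both values first, as the trailing comment notes. A minimal sketch of the escaping this relies on, modeled on the escape_literal helper touched in pg_helpers.rs further below; quote_version here is a hypothetical stand-in, not the repository's actual function:

// Hypothetical illustration of quoting a version string before it is
// spliced into a CREATE/ALTER EXTENSION statement.
fn quote_version(s: &str) -> String {
    // Double single quotes and backslashes; use the E'' form when a
    // backslash survives, mirroring escape_literal in pg_helpers.rs.
    let escaped = s.replace('\'', "''").replace('\\', "\\\\");
    if escaped.contains('\\') {
        format!("E'{escaped}'")
    } else {
        format!("'{escaped}'")
    }
}

fn main() {
    let quoted_version = quote_version("1.3'; DROP EXTENSION x; --");
    // The embedded quote is doubled, so the value stays one SQL literal.
    println!("CREATE EXTENSION IF NOT EXISTS pg_example WITH VERSION {quoted_version}");
}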
@@ -2085,14 +2147,14 @@ db_client .simple_query(&query) .await - .with_context(|| format!("Failed to execute query: {}", query))?; + .with_context(|| format!("Failed to execute query: {query}"))?; } else { let query = format!("CREATE EXTENSION IF NOT EXISTS {ext_name} WITH VERSION {quoted_version}"); db_client .simple_query(&query) .await - .with_context(|| format!("Failed to execute query: {}", query))?; + .with_context(|| format!("Failed to execute query: {query}"))?; } Ok(ext_version) @@ -2251,12 +2313,68 @@ pub async fn installed_extensions(conf: tokio_postgres::Config) -> Result<()> { Ok(()) } -pub fn forward_termination_signal() { +pub fn forward_termination_signal(dev_mode: bool) { let ss_pid = SYNC_SAFEKEEPERS_PID.load(Ordering::SeqCst); if ss_pid != 0 { let ss_pid = nix::unistd::Pid::from_raw(ss_pid as i32); kill(ss_pid, Signal::SIGTERM).ok(); } + + if !dev_mode { + // Terminate pgbouncer with SIGKILL + match pid_file::read(PGBOUNCER_PIDFILE.into()) { + Ok(pid_file::PidFileRead::LockedByOtherProcess(pid)) => { + info!("sending SIGKILL to pgbouncer process pid: {}", pid); + if let Err(e) = kill(pid, Signal::SIGKILL) { + error!("failed to terminate pgbouncer: {}", e); + } + } + // pgbouncer does not lock the pid file, so we read and kill the process directly + Ok(pid_file::PidFileRead::NotHeldByAnyProcess(_)) => { + if let Ok(pid_str) = std::fs::read_to_string(PGBOUNCER_PIDFILE) { + if let Ok(pid) = pid_str.trim().parse::<i32>() { + info!( + "sending SIGKILL to pgbouncer process pid: {} (from unlocked pid file)", + pid + ); + if let Err(e) = kill(Pid::from_raw(pid), Signal::SIGKILL) { + error!("failed to terminate pgbouncer: {}", e); + } + } + } else { + info!("pgbouncer pid file exists but process not running"); + } + } + Ok(pid_file::PidFileRead::NotExist) => { + info!("pgbouncer pid file not found, process may not be running"); + } + Err(e) => { + error!("error reading pgbouncer pid file: {}", e); + } + } + + // Terminate local_proxy + match pid_file::read("/etc/local_proxy/pid".into()) { + Ok(pid_file::PidFileRead::LockedByOtherProcess(pid)) => { + info!("sending SIGTERM to local_proxy process pid: {}", pid); + if let Err(e) = kill(pid, Signal::SIGTERM) { + error!("failed to terminate local_proxy: {}", e); + } + } + Ok(pid_file::PidFileRead::NotHeldByAnyProcess(_)) => { + info!("local_proxy PID file exists but process not running"); + } + Ok(pid_file::PidFileRead::NotExist) => { + info!("local_proxy PID file not found, process may not be running"); + } + Err(e) => { + error!("error reading local_proxy PID file: {}", e); + } + } + } else { + info!("Skipping pgbouncer and local_proxy termination because in dev mode"); + } + let pg_pid = PG_PID.load(Ordering::SeqCst); if pg_pid != 0 { let pg_pid = nix::unistd::Pid::from_raw(pg_pid as i32); @@ -2289,3 +2407,21 @@ impl<T> JoinSetExt<T> for tokio::task::JoinSet<T> { }) } } + +#[cfg(test)] +mod tests { + use std::fs::File; + + use super::*; + + #[test] + fn duplicate_safekeeper_connstring() { + let file = File::open("tests/cluster_spec.json").unwrap(); + let spec: ComputeSpec = serde_json::from_reader(file).unwrap(); + + match ParsedSpec::try_from(spec.clone()) { + Ok(_p) => panic!("Failed to detect duplicate entry"), + Err(e) => assert!(e.starts_with("duplicate entry in safekeeper_connstrings:")), + }; + } +} diff --git a/compute_tools/src/compute_prewarm.rs b/compute_tools/src/compute_prewarm.rs index a6a84b3f1f..1c7a7bef60 100644 --- a/compute_tools/src/compute_prewarm.rs +++ b/compute_tools/src/compute_prewarm.rs @@ -25,11 +25,16 @@
struct EndpointStoragePair { } const KEY: &str = "lfc_state"; -impl TryFrom<&crate::compute::ParsedSpec> for EndpointStoragePair { - type Error = anyhow::Error; - fn try_from(pspec: &crate::compute::ParsedSpec) -> Result<Self, Self::Error> { - let Some(ref endpoint_id) = pspec.spec.endpoint_id else { - bail!("pspec.endpoint_id missing") +impl EndpointStoragePair { + /// endpoint_id is set to None while prewarming from other endpoint, see replica promotion + /// If not None, takes precedence over pspec.spec.endpoint_id + fn from_spec_and_endpoint( + pspec: &crate::compute::ParsedSpec, + endpoint_id: Option<String>, + ) -> Result<Self> { + let endpoint_id = endpoint_id.as_ref().or(pspec.spec.endpoint_id.as_ref()); + let Some(ref endpoint_id) = endpoint_id else { + bail!("pspec.endpoint_id missing, other endpoint_id not provided") }; let Some(ref base_uri) = pspec.endpoint_storage_addr else { bail!("pspec.endpoint_storage_addr missing") @@ -84,7 +89,7 @@ impl ComputeNode { } /// Returns false if there is a prewarm request ongoing, true otherwise - pub fn prewarm_lfc(self: &Arc<Self>) -> bool { + pub fn prewarm_lfc(self: &Arc<Self>, from_endpoint: Option<String>) -> bool { crate::metrics::LFC_PREWARM_REQUESTS.inc(); { let state = &mut self.state.lock().unwrap().lfc_prewarm_state; @@ -97,7 +102,7 @@ impl ComputeNode { let cloned = self.clone(); spawn(async move { - let Err(err) = cloned.prewarm_impl().await else { + let Err(err) = cloned.prewarm_impl(from_endpoint).await else { cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Completed; return; }; @@ -109,13 +114,14 @@ impl ComputeNode { true } - fn endpoint_storage_pair(&self) -> Result<EndpointStoragePair> { + /// from_endpoint: None for endpoint managed by this compute_ctl + fn endpoint_storage_pair(&self, from_endpoint: Option<String>) -> Result<EndpointStoragePair> { let state = self.state.lock().unwrap(); - state.pspec.as_ref().unwrap().try_into() + EndpointStoragePair::from_spec_and_endpoint(state.pspec.as_ref().unwrap(), from_endpoint) } - async fn prewarm_impl(&self) -> Result<()> { - let EndpointStoragePair { url, token } = self.endpoint_storage_pair()?; + async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<()> { + let EndpointStoragePair { url, token } = self.endpoint_storage_pair(from_endpoint)?; info!(%url, "requesting LFC state from endpoint storage"); let request = Client::new().get(&url).bearer_auth(token); @@ -173,7 +179,7 @@ impl ComputeNode { } async fn offload_lfc_impl(&self) -> Result<()> { - let EndpointStoragePair { url, token } = self.endpoint_storage_pair()?; + let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?; info!(%url, "requesting LFC state from postgres"); let mut compressed = Vec::new(); diff --git a/compute_tools/src/config.rs b/compute_tools/src/config.rs index 933b30134f..169de5c963 100644 --- a/compute_tools/src/config.rs +++ b/compute_tools/src/config.rs @@ -51,7 +51,7 @@ pub fn write_postgres_conf( // Write the postgresql.conf content from the spec file as is.
if let Some(conf) = &spec.cluster.postgresql_conf { - writeln!(file, "{}", conf)?; + writeln!(file, "{conf}")?; } // Add options for connecting to storage @@ -70,7 +70,7 @@ pub fn write_postgres_conf( ); // If generation is given, prepend sk list with g#number: if let Some(generation) = spec.safekeepers_generation { - write!(neon_safekeepers_value, "g#{}:", generation)?; + write!(neon_safekeepers_value, "g#{generation}:")?; } neon_safekeepers_value.push_str(&spec.safekeeper_connstrings.join(",")); writeln!( @@ -109,8 +109,8 @@ pub fn write_postgres_conf( tls::update_key_path_blocking(pgdata_path, tls_config); // these are the default, but good to be explicit. - writeln!(file, "ssl_cert_file = '{}'", SERVER_CRT)?; - writeln!(file, "ssl_key_file = '{}'", SERVER_KEY)?; + writeln!(file, "ssl_cert_file = '{SERVER_CRT}'")?; + writeln!(file, "ssl_key_file = '{SERVER_KEY}'")?; } // Locales @@ -191,8 +191,7 @@ pub fn write_postgres_conf( } writeln!( file, - "shared_preload_libraries='{}{}'", - libs, extra_shared_preload_libraries + "shared_preload_libraries='{libs}{extra_shared_preload_libraries}'" )?; } else { // Typically, this should be unreacheable, @@ -244,8 +243,7 @@ pub fn write_postgres_conf( } writeln!( file, - "shared_preload_libraries='{}{}'", - libs, extra_shared_preload_libraries + "shared_preload_libraries='{libs}{extra_shared_preload_libraries}'" )?; } else { // Typically, this should be unreacheable, @@ -263,7 +261,7 @@ pub fn write_postgres_conf( } } - writeln!(file, "neon.extension_server_port={}", extension_server_port)?; + writeln!(file, "neon.extension_server_port={extension_server_port}")?; if spec.drop_subscriptions_before_start { writeln!(file, "neon.disable_logical_replication_subscribers=true")?; @@ -291,7 +289,7 @@ where { let path = pgdata_path.join("compute_ctl_temp_override.conf"); let mut file = File::create(path)?; - write!(file, "{}", options)?; + write!(file, "{options}")?; let res = exec(); diff --git a/compute_tools/src/extension_server.rs b/compute_tools/src/extension_server.rs index 3764bc1525..47931d5f72 100644 --- a/compute_tools/src/extension_server.rs +++ b/compute_tools/src/extension_server.rs @@ -74,9 +74,11 @@ More specifically, here is an example ext_index.json use std::path::Path; use std::str; +use crate::metrics::{REMOTE_EXT_REQUESTS_TOTAL, UNKNOWN_HTTP_STATUS}; use anyhow::{Context, Result, bail}; use bytes::Bytes; use compute_api::spec::RemoteExtSpec; +use postgres_versioninfo::PgMajorVersion; use regex::Regex; use remote_storage::*; use reqwest::StatusCode; @@ -86,8 +88,6 @@ use tracing::log::warn; use url::Url; use zstd::stream::read::Decoder; -use crate::metrics::{REMOTE_EXT_REQUESTS_TOTAL, UNKNOWN_HTTP_STATUS}; - fn get_pg_config(argument: &str, pgbin: &str) -> String { // gives the result of `pg_config [argument]` // where argument is a flag like `--version` or `--sharedir` @@ -106,7 +106,7 @@ fn get_pg_config(argument: &str, pgbin: &str) -> String { .to_string() } -pub fn get_pg_version(pgbin: &str) -> PostgresMajorVersion { +pub fn get_pg_version(pgbin: &str) -> PgMajorVersion { // pg_config --version returns a (platform specific) human readable string // such as "PostgreSQL 15.4". We parse this to v14/v15/v16 etc. 
let human_version = get_pg_config("--version", pgbin); @@ -114,25 +114,11 @@ pub fn get_pg_version(pgbin: &str) -> PostgresMajorVersion { } pub fn get_pg_version_string(pgbin: &str) -> String { - match get_pg_version(pgbin) { - PostgresMajorVersion::V14 => "v14", - PostgresMajorVersion::V15 => "v15", - PostgresMajorVersion::V16 => "v16", - PostgresMajorVersion::V17 => "v17", - } - .to_owned() + get_pg_version(pgbin).v_str() } -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum PostgresMajorVersion { - V14, - V15, - V16, - V17, -} - -fn parse_pg_version(human_version: &str) -> PostgresMajorVersion { - use PostgresMajorVersion::*; +fn parse_pg_version(human_version: &str) -> PgMajorVersion { + use PgMajorVersion::*; // Normal releases have version strings like "PostgreSQL 15.4". But there // are also pre-release versions like "PostgreSQL 17devel" or "PostgreSQL // 16beta2" or "PostgreSQL 17rc1". And with the --with-extra-version @@ -143,10 +129,10 @@ fn parse_pg_version(human_version: &str) -> PgMajorVersion { .captures(human_version) { Some(captures) if captures.len() == 2 => match &captures["major"] { - "14" => return V14, - "15" => return V15, - "16" => return V16, - "17" => return V17, + "14" => return PG14, + "15" => return PG15, + "16" => return PG16, + "17" => return PG17, _ => {} }, _ => {} @@ -310,10 +296,7 @@ async fn download_extension_tar(remote_ext_base_url: &Url, ext_path: &str) -> Re async fn do_extension_server_request(uri: Url) -> Result<Bytes, (String, String)> { let resp = reqwest::get(uri).await.map_err(|e| { ( - format!( - "could not perform remote extensions server request: {:?}", - e - ), + format!("could not perform remote extensions server request: {e:?}"), UNKNOWN_HTTP_STATUS.to_string(), ) })?; @@ -323,7 +306,7 @@ async fn do_extension_server_request(uri: Url) -> Result<Bytes, (String, String)> match resp.bytes().await { Ok(resp) => Ok(resp), Err(e) => Err(( - format!("could not read remote extensions server response: {:?}", e), + format!("could not read remote extensions server response: {e:?}"), // It's fine to return and report error with status as 200 OK, // because we still failed to read the response.
status.to_string(), @@ -334,10 +317,7 @@ async fn do_extension_server_request(uri: Url) -> Result<Bytes, (String, String)> Err(( - format!( - "unexpected remote extensions server response status code: {}", - status - ), + format!("unexpected remote extensions server response status code: {status}"), status.to_string(), )), } @@ -349,25 +329,25 @@ mod tests { #[test] fn test_parse_pg_version() { - use super::PostgresMajorVersion::*; - assert_eq!(parse_pg_version("PostgreSQL 15.4"), V15); - assert_eq!(parse_pg_version("PostgreSQL 15.14"), V15); + use postgres_versioninfo::PgMajorVersion::*; + assert_eq!(parse_pg_version("PostgreSQL 15.4"), PG15); + assert_eq!(parse_pg_version("PostgreSQL 15.14"), PG15); assert_eq!( parse_pg_version("PostgreSQL 15.4 (Ubuntu 15.4-0ubuntu0.23.04.1)"), - V15 + PG15 ); - assert_eq!(parse_pg_version("PostgreSQL 14.15"), V14); - assert_eq!(parse_pg_version("PostgreSQL 14.0"), V14); + assert_eq!(parse_pg_version("PostgreSQL 14.15"), PG14); + assert_eq!(parse_pg_version("PostgreSQL 14.0"), PG14); assert_eq!( parse_pg_version("PostgreSQL 14.9 (Debian 14.9-1.pgdg120+1"), - V14 + PG14 ); - assert_eq!(parse_pg_version("PostgreSQL 16devel"), V16); - assert_eq!(parse_pg_version("PostgreSQL 16beta1"), V16); - assert_eq!(parse_pg_version("PostgreSQL 16rc2"), V16); - assert_eq!(parse_pg_version("PostgreSQL 16extra"), V16); + assert_eq!(parse_pg_version("PostgreSQL 16devel"), PG16); + assert_eq!(parse_pg_version("PostgreSQL 16beta1"), PG16); + assert_eq!(parse_pg_version("PostgreSQL 16rc2"), PG16); + assert_eq!(parse_pg_version("PostgreSQL 16extra"), PG16); } #[test] diff --git a/compute_tools/src/http/routes/configure.rs b/compute_tools/src/http/routes/configure.rs index c29e3a97da..b7325d283f 100644 --- a/compute_tools/src/http/routes/configure.rs +++ b/compute_tools/src/http/routes/configure.rs @@ -65,7 +65,7 @@ pub(in crate::http) async fn configure( if state.status == ComputeStatus::Failed { let err = state.error.as_ref().map_or("unknown error", |x| x); - let msg = format!("compute configuration failed: {:?}", err); + let msg = format!("compute configuration failed: {err:?}"); return Err(msg); } } diff --git a/compute_tools/src/http/routes/lfc.rs b/compute_tools/src/http/routes/lfc.rs index 07bcc6bfb7..e98bd781a2 100644 --- a/compute_tools/src/http/routes/lfc.rs +++ b/compute_tools/src/http/routes/lfc.rs @@ -2,6 +2,7 @@ use crate::compute_prewarm::LfcPrewarmStateWithProgress; use crate::http::JsonResponse; use axum::response::{IntoResponse, Response}; use axum::{Json, http::StatusCode}; +use axum_extra::extract::OptionalQuery; use compute_api::responses::LfcOffloadState; type Compute = axum::extract::State<Arc<ComputeNode>>; @@ -16,8 +17,16 @@ pub(in crate::http) async fn offload_state(compute: Compute) -> Json<LfcOffloadState> -pub(in crate::http) async fn prewarm(compute: Compute) -> Response { - if compute.prewarm_lfc() { +#[derive(serde::Deserialize)] +pub struct PrewarmQuery { + pub from_endpoint: String, +} + +pub(in crate::http) async fn prewarm( + compute: Compute, + OptionalQuery(query): OptionalQuery<PrewarmQuery>, +) -> Response { + if compute.prewarm_lfc(query.map(|q| q.from_endpoint)) { StatusCode::ACCEPTED.into_response() } else { JsonResponse::error( diff --git a/compute_tools/src/http/routes/terminate.rs b/compute_tools/src/http/routes/terminate.rs index 2c24d4ad6b..32d90a5990 100644 --- a/compute_tools/src/http/routes/terminate.rs +++ b/compute_tools/src/http/routes/terminate.rs @@ -1,32 +1,42 @@ -use std::sync::Arc; - +use crate::compute::{ComputeNode, forward_termination_signal}; +use crate::http::JsonResponse; use axum::extract::State; -use axum::response::{IntoResponse, Response};
-use compute_api::responses::ComputeStatus; +use axum::response::Response; +use axum_extra::extract::OptionalQuery; +use compute_api::responses::{ComputeStatus, TerminateResponse}; use http::StatusCode; +use serde::Deserialize; +use std::sync::Arc; use tokio::task; use tracing::info; -use crate::compute::{ComputeNode, forward_termination_signal}; -use crate::http::JsonResponse; +#[derive(Deserialize, Default)] +pub struct TerminateQuery { + mode: compute_api::responses::TerminateMode, +} /// Terminate the compute. -pub(in crate::http) async fn terminate(State(compute): State<Arc<ComputeNode>>) -> Response { +pub(in crate::http) async fn terminate( + State(compute): State<Arc<ComputeNode>>, + OptionalQuery(terminate): OptionalQuery<TerminateQuery>, +) -> Response { + let mode = terminate.unwrap_or_default().mode; { let mut state = compute.state.lock().unwrap(); if state.status == ComputeStatus::Terminated { - return StatusCode::CREATED.into_response(); + return JsonResponse::success(StatusCode::CREATED, state.terminate_flush_lsn); } if !matches!(state.status, ComputeStatus::Empty | ComputeStatus::Running) { return JsonResponse::invalid_status(state.status); } - - state.set_status(ComputeStatus::TerminationPending, &compute.state_changed); - drop(state); + state.set_status( + ComputeStatus::TerminationPending { mode }, + &compute.state_changed, + ); } - forward_termination_signal(); + forward_termination_signal(false); info!("sent signal and notified waiters"); // Spawn a blocking thread to wait for compute to become Terminated. @@ -34,7 +44,7 @@ pub(in crate::http) async fn terminate(State(compute): State<Arc<ComputeNode>>) // be able to serve other requests while some particular request // is waiting for compute to finish configuration. let c = compute.clone(); - task::spawn_blocking(move || { + let lsn = task::spawn_blocking(move || { let mut state = c.state.lock().unwrap(); while state.status != ComputeStatus::Terminated { state = c.state_changed.wait(state).unwrap(); @@ -44,11 +54,10 @@ pub(in crate::http) async fn terminate(State(compute): State<Arc<ComputeNode>>) state.status ); } + state.terminate_flush_lsn }) .await .unwrap(); - info!("terminated Postgres"); - - StatusCode::OK.into_response() + JsonResponse::success(StatusCode::OK, TerminateResponse { lsn }) } diff --git a/compute_tools/src/installed_extensions.rs b/compute_tools/src/installed_extensions.rs index d95c168a99..411e03b7ec 100644 --- a/compute_tools/src/installed_extensions.rs +++ b/compute_tools/src/installed_extensions.rs @@ -43,7 +43,7 @@ pub async fn get_installed_extensions(mut conf: Config) -> Result Result Result> { let mut client = config.connect(NoTls)?; - let cmd = format!("lease lsn {} {} {} ", tenant_shard_id, timeline_id, lsn); + let cmd = format!("lease lsn {tenant_shard_id} {timeline_id} {lsn} "); let res = client.simple_query(&cmd)?; let msg = match res.first() { Some(msg) => msg, diff --git a/compute_tools/src/monitor.rs b/compute_tools/src/monitor.rs index 3311ee47b3..8a2f6addad 100644 --- a/compute_tools/src/monitor.rs +++ b/compute_tools/src/monitor.rs @@ -13,6 +13,12 @@ use crate::metrics::{PG_CURR_DOWNTIME_MS, PG_TOTAL_DOWNTIME_MS}; const MONITOR_CHECK_INTERVAL: Duration = Duration::from_millis(500); +/// Struct to store runtime state of the compute monitor thread. +/// In theory, this could be a part of `Compute`, but i) +/// this state is expected to be accessed only by a single thread, +/// so we don't need to care about locking; ii) `Compute` is +/// already quite big. Thus, it seems to be a good idea to keep +/// all the activity/health monitoring parts here.
struct ComputeMonitor { compute: Arc<ComputeNode>, @@ -70,12 +76,38 @@ impl ComputeMonitor { ) } + /// Check if compute is in some terminal or soon-to-be-terminal + /// state, then return `true`, signalling the caller that it + /// should exit gracefully. Otherwise, return `false`. + fn check_interrupts(&mut self) -> bool { + let compute_status = self.compute.get_status(); + if matches!( + compute_status, + ComputeStatus::Terminated + | ComputeStatus::TerminationPending { .. } + | ComputeStatus::Failed + ) { + info!( + "compute is in {} status, stopping compute monitor", + compute_status + ); + return true; + } + + false + } + /// Spin in a loop and figure out the last activity time in the Postgres. - /// Then update it in the shared state. This function never errors out. + /// Then update it in the shared state. This function currently never + /// errors out explicitly, but there is a graceful termination path. + /// Every time we receive an error trying to check Postgres, we use + /// [`ComputeMonitor::check_interrupts()`] because the compute may + /// already be terminating, in which case we exit gracefully + /// instead of producing error noise in the log. /// NB: the only expected panic is at `Mutex` unwrap(), all other errors /// should be handled gracefully. #[instrument(skip_all)] - pub fn run(&mut self) { + pub fn run(&mut self) -> anyhow::Result<()> { // Suppose that `connstr` doesn't change let connstr = self.compute.params.connstr.clone(); let conf = self @@ -93,6 +125,10 @@ impl ComputeMonitor { info!("starting compute monitor for {}", connstr); loop { + if self.check_interrupts() { + break; + } + match &mut client { Ok(cli) => { if cli.is_closed() { @@ -100,6 +136,10 @@ impl ComputeMonitor { downtime_info = self.downtime_info(), "connection to Postgres is closed, trying to reconnect" ); + if self.check_interrupts() { + break; + } + self.report_down(); // Connection is closed, reconnect and try again. @@ -111,15 +151,19 @@ impl ComputeMonitor { self.compute.update_last_active(self.last_active); } Err(e) => { + error!( + downtime_info = self.downtime_info(), + "could not check Postgres: {}", e + ); + if self.check_interrupts() { + break; + } + // Although we have many places where we can return errors in `check()`, // normally it shouldn't happen. I.e., we will likely return error if // connection got broken, query timed out, Postgres returned invalid data, etc. // In all such cases it's suspicious, so let's report this as downtime. self.report_down(); - error!( - downtime_info = self.downtime_info(), - "could not check Postgres: {}", e - ); // Reconnect to Postgres just in case. During tests, I noticed // that queries in `check()` can fail with `connection closed`, @@ -136,6 +180,10 @@ impl ComputeMonitor { downtime_info = self.downtime_info(), "could not connect to Postgres: {}, retrying", e ); + if self.check_interrupts() { + break; + } + self.report_down(); // Establish a new connection and try again.
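Every error and reconnect path in the loop above now calls check_interrupts() before reporting downtime, so a compute that is shutting down stops the monitor instead of logging spurious connection errors. A condensed sketch of that pattern; the Status enum and run_monitor below are simplified illustrations, not the real ComputeStatus or monitor:

use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;

#[derive(Clone, Copy, Debug)]
enum Status { Running, TerminationPending, Terminated, Failed }

fn run_monitor(status: Arc<Mutex<Status>>) {
    loop {
        // Consult the shared status first: if the compute is going away,
        // exit the loop instead of treating broken connections as downtime.
        let s = *status.lock().unwrap();
        if matches!(s, Status::TerminationPending | Status::Terminated | Status::Failed) {
            println!("compute is in {s:?} status, stopping compute monitor");
            break;
        }
        // ... check Postgres activity here, re-running the same guard on
        // every error path before reporting downtime ...
        thread::sleep(Duration::from_millis(100));
    }
}

fn main() {
    let status = Arc::new(Mutex::new(Status::Running));
    let monitor_status = Arc::clone(&status);
    let handle = thread::spawn(move || run_monitor(monitor_status));
    // Simulate a termination request arriving from the HTTP handler.
    *status.lock().unwrap() = Status::TerminationPending;
    handle.join().unwrap();
}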
@@ -147,6 +195,9 @@ impl ComputeMonitor { self.last_checked = Utc::now(); thread::sleep(MONITOR_CHECK_INTERVAL); } + + // Graceful termination path + Ok(()) } #[instrument(skip_all)] @@ -429,7 +480,10 @@ pub fn launch_monitor(compute: &Arc<ComputeNode>) -> thread::JoinHandle<()> { .spawn(move || { let span = span!(Level::INFO, "compute_monitor"); let _enter = span.enter(); - monitor.run(); + match monitor.run() { + Ok(_) => info!("compute monitor thread terminated gracefully"), + Err(err) => error!("compute monitor thread terminated abnormally {:?}", err), + } }) .expect("cannot launch compute monitor thread") } diff --git a/compute_tools/src/pg_helpers.rs b/compute_tools/src/pg_helpers.rs index 94467a0d2f..0a3ceed2fa 100644 --- a/compute_tools/src/pg_helpers.rs +++ b/compute_tools/src/pg_helpers.rs @@ -36,9 +36,9 @@ pub fn escape_literal(s: &str) -> String { let res = s.replace('\'', "''").replace('\\', "\\\\"); if res.contains('\\') { - format!("E'{}'", res) + format!("E'{res}'") } else { - format!("'{}'", res) + format!("'{res}'") } } @@ -46,7 +46,7 @@ pub fn escape_literal(s: &str) -> String { /// with `'{}'` is not required, as it returns a ready-to-use config string. pub fn escape_conf_value(s: &str) -> String { let res = s.replace('\'', "''").replace('\\', "\\\\"); - format!("'{}'", res) + format!("'{res}'") } pub trait GenericOptionExt { @@ -446,7 +446,7 @@ pub async fn tune_pgbouncer( let mut pgbouncer_connstr = "host=localhost port=6432 dbname=pgbouncer user=postgres sslmode=disable".to_string(); if let Ok(pass) = std::env::var("PGBOUNCER_PASSWORD") { - pgbouncer_connstr.push_str(format!(" password={}", pass).as_str()); + pgbouncer_connstr.push_str(format!(" password={pass}").as_str()); } pgbouncer_connstr }; @@ -464,7 +464,7 @@ pub async fn tune_pgbouncer( Ok((client, connection)) => { tokio::spawn(async move { if let Err(e) = connection.await { - eprintln!("connection error: {}", e); + eprintln!("connection error: {e}"); } }); break client; diff --git a/compute_tools/src/pgbouncer.rs b/compute_tools/src/pgbouncer.rs new file mode 100644 index 0000000000..189dfabac9 --- /dev/null +++ b/compute_tools/src/pgbouncer.rs @@ -0,0 +1 @@ +pub const PGBOUNCER_PIDFILE: &str = "/tmp/pgbouncer.pid"; diff --git a/compute_tools/src/spec.rs b/compute_tools/src/spec.rs index 4b38e6e29c..43cfbb48f7 100644 --- a/compute_tools/src/spec.rs +++ b/compute_tools/src/spec.rs @@ -23,12 +23,12 @@ fn do_control_plane_request( ) -> Result<ControlPlaneSpecResponse, (bool, String, String)> { let resp = reqwest::blocking::Client::new() .get(uri) - .header("Authorization", format!("Bearer {}", jwt)) + .header("Authorization", format!("Bearer {jwt}")) .send() .map_err(|e| { ( true, - format!("could not perform request to control plane: {:?}", e), + format!("could not perform request to control plane: {e:?}"), UNKNOWN_HTTP_STATUS.to_string(), ) })?; @@ -39,7 +39,7 @@ fn do_control_plane_request( Ok(spec_resp) => Ok(spec_resp), Err(e) => Err(( true, - format!("could not deserialize control plane response: {:?}", e), + format!("could not deserialize control plane response: {e:?}"), status.to_string(), )), }, @@ -62,7 +62,7 @@ fn do_control_plane_request( // or some internal failure happened. Doesn't make much sense to retry in this case.
_ => Err(( false, - format!("unexpected control plane response status code: {}", status), + format!("unexpected control plane response status code: {status}"), status.to_string(), )), } diff --git a/compute_tools/src/spec_apply.rs b/compute_tools/src/spec_apply.rs index 0d1389dbad..fcd072263a 100644 --- a/compute_tools/src/spec_apply.rs +++ b/compute_tools/src/spec_apply.rs @@ -933,56 +933,53 @@ async fn get_operations<'a>( PerDatabasePhase::DeleteDBRoleReferences => { let ctx = ctx.read().await; - let operations = - spec.delta_operations - .iter() - .flatten() - .filter(|op| op.action == "delete_role") - .filter_map(move |op| { - if db.is_owned_by(&op.name) { - return None; - } - if !ctx.roles.contains_key(&op.name) { - return None; - } - let quoted = op.name.pg_quote(); - let new_owner = match &db { - DB::SystemDB => PgIdent::from("cloud_admin").pg_quote(), - DB::UserDB(db) => db.owner.pg_quote(), - }; - let (escaped_role, outer_tag) = op.name.pg_quote_dollar(); + let operations = spec + .delta_operations + .iter() + .flatten() + .filter(|op| op.action == "delete_role") + .filter_map(move |op| { + if db.is_owned_by(&op.name) { + return None; + } + if !ctx.roles.contains_key(&op.name) { + return None; + } + let quoted = op.name.pg_quote(); + let new_owner = match &db { + DB::SystemDB => PgIdent::from("cloud_admin").pg_quote(), + DB::UserDB(db) => db.owner.pg_quote(), + }; + let (escaped_role, outer_tag) = op.name.pg_quote_dollar(); - Some(vec![ - // This will reassign all dependent objects to the db owner - Operation { - query: format!( - "REASSIGN OWNED BY {} TO {}", - quoted, new_owner, - ), - comment: None, - }, - // Revoke some potentially blocking privileges (Neon-specific currently) - Operation { - query: format!( - include_str!("sql/pre_drop_role_revoke_privileges.sql"), - // N.B. this has to be properly dollar-escaped with `pg_quote_dollar()` - role_name = escaped_role, - outer_tag = outer_tag, - ), - comment: None, - }, - // This now will only drop privileges of the role - // TODO: this is obviously not 100% true because of the above case, - // there could be still some privileges that are not revoked. Maybe this - // only drops privileges that were granted *by this* role, not *to this* role, - // but this has to be checked. - Operation { - query: format!("DROP OWNED BY {}", quoted), - comment: None, - }, - ]) - }) - .flatten(); + Some(vec![ + // This will reassign all dependent objects to the db owner + Operation { + query: format!("REASSIGN OWNED BY {quoted} TO {new_owner}",), + comment: None, + }, + // Revoke some potentially blocking privileges (Neon-specific currently) + Operation { + query: format!( + include_str!("sql/pre_drop_role_revoke_privileges.sql"), + // N.B. this has to be properly dollar-escaped with `pg_quote_dollar()` + role_name = escaped_role, + outer_tag = outer_tag, + ), + comment: None, + }, + // This now will only drop privileges of the role + // TODO: this is obviously not 100% true because of the above case, + // there could be still some privileges that are not revoked. Maybe this + // only drops privileges that were granted *by this* role, not *to this* role, + // but this has to be checked. 
+ Operation { + query: format!("DROP OWNED BY {quoted}"), + comment: None, + }, + ]) + }) + .flatten(); Ok(Box::new(operations)) } diff --git a/compute_tools/src/sync_sk.rs b/compute_tools/src/sync_sk.rs index 22b7027b93..6c348644b2 100644 --- a/compute_tools/src/sync_sk.rs +++ b/compute_tools/src/sync_sk.rs @@ -27,7 +27,7 @@ pub async fn ping_safekeeper( let (client, conn) = config.connect(tokio_postgres::NoTls).await?; tokio::spawn(async move { if let Err(e) = conn.await { - eprintln!("connection error: {}", e); + eprintln!("connection error: {e}"); } }); diff --git a/compute_tools/tests/README.md b/compute_tools/tests/README.md new file mode 100644 index 0000000000..adeb9ef4b6 --- /dev/null +++ b/compute_tools/tests/README.md @@ -0,0 +1,6 @@ +### Test files + +The file `cluster_spec.json` has been copied over from libs/compute_api +tests, with some edits: + + - the neon.safekeepers setting contains a duplicate value diff --git a/compute_tools/tests/cluster_spec.json b/compute_tools/tests/cluster_spec.json new file mode 100644 index 0000000000..5655a94de4 --- /dev/null +++ b/compute_tools/tests/cluster_spec.json @@ -0,0 +1,245 @@ +{ + "format_version": 1.0, + + "timestamp": "2021-05-23T18:25:43.511Z", + "operation_uuid": "0f657b36-4b0f-4a2d-9c2e-1dcd615e7d8b", + + "cluster": { + "cluster_id": "test-cluster-42", + "name": "Zenith Test", + "state": "restarted", + "roles": [ + { + "name": "postgres", + "encrypted_password": "6b1d16b78004bbd51fa06af9eda75972", + "options": null + }, + { + "name": "alexk", + "encrypted_password": null, + "options": null + }, + { + "name": "zenith \"new\"", + "encrypted_password": "5b1d16b78004bbd51fa06af9eda75972", + "options": null + }, + { + "name": "zen", + "encrypted_password": "9b1d16b78004bbd51fa06af9eda75972" + }, + { + "name": "\"name\";\\n select 1;", + "encrypted_password": "5b1d16b78004bbd51fa06af9eda75972" + }, + { + "name": "MyRole", + "encrypted_password": "5b1d16b78004bbd51fa06af9eda75972" + } + ], + "databases": [ + { + "name": "DB2", + "owner": "alexk", + "options": [ + { + "name": "LC_COLLATE", + "value": "C", + "vartype": "string" + }, + { + "name": "LC_CTYPE", + "value": "C", + "vartype": "string" + }, + { + "name": "TEMPLATE", + "value": "template0", + "vartype": "enum" + } + ] + }, + { + "name": "zenith", + "owner": "MyRole" + }, + { + "name": "zen", + "owner": "zen" + } + ], + "settings": [ + { + "name": "fsync", + "value": "off", + "vartype": "bool" + }, + { + "name": "wal_level", + "value": "logical", + "vartype": "enum" + }, + { + "name": "hot_standby", + "value": "on", + "vartype": "bool" + }, + { + "name": "prewarm_lfc_on_startup", + "value": "off", + "vartype": "bool" + }, + { + "name": "neon.safekeepers", + "value": "127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501,127.0.0.1:6502", + "vartype": "string" + }, + { + "name": "wal_log_hints", + "value": "on", + "vartype": "bool" + }, + { + "name": "log_connections", + "value": "on", + "vartype": "bool" + }, + { + "name": "shared_buffers", + "value": "32768", + "vartype": "integer" + }, + { + "name": "port", + "value": "55432", + "vartype": "integer" + }, + { + "name": "max_connections", + "value": "100", + "vartype": "integer" + }, + { + "name": "max_wal_senders", + "value": "10", + "vartype": "integer" + }, + { + "name": "listen_addresses", + "value": "0.0.0.0", + "vartype": "string" + }, + { + "name": "wal_sender_timeout", + "value": "0", + "vartype": "integer" + }, + { + "name": "password_encryption", + "value": "md5", + "vartype": "enum" + }, + { + "name": "maintenance_work_mem", + 
"value": "65536", + "vartype": "integer" + }, + { + "name": "max_parallel_workers", + "value": "8", + "vartype": "integer" + }, + { + "name": "max_worker_processes", + "value": "8", + "vartype": "integer" + }, + { + "name": "neon.tenant_id", + "value": "b0554b632bd4d547a63b86c3630317e8", + "vartype": "string" + }, + { + "name": "max_replication_slots", + "value": "10", + "vartype": "integer" + }, + { + "name": "neon.timeline_id", + "value": "2414a61ffc94e428f14b5758fe308e13", + "vartype": "string" + }, + { + "name": "shared_preload_libraries", + "value": "neon", + "vartype": "string" + }, + { + "name": "synchronous_standby_names", + "value": "walproposer", + "vartype": "string" + }, + { + "name": "neon.pageserver_connstring", + "value": "host=127.0.0.1 port=6400", + "vartype": "string" + }, + { + "name": "test.escaping", + "value": "here's a backslash \\ and a quote ' and a double-quote \" hooray", + "vartype": "string" + } + ] + }, + "delta_operations": [ + { + "action": "delete_db", + "name": "zenith_test" + }, + { + "action": "rename_db", + "name": "DB", + "new_name": "DB2" + }, + { + "action": "delete_role", + "name": "zenith2" + }, + { + "action": "rename_role", + "name": "zenith new", + "new_name": "zenith \"new\"" + } + ], + "remote_extensions": { + "library_index": { + "postgis-3": "postgis", + "libpgrouting-3.4": "postgis", + "postgis_raster-3": "postgis", + "postgis_sfcgal-3": "postgis", + "postgis_topology-3": "postgis", + "address_standardizer-3": "postgis" + }, + "extension_data": { + "postgis": { + "archive_path": "5834329303/v15/extensions/postgis.tar.zst", + "control_data": { + "postgis.control": "# postgis extension\ncomment = ''PostGIS geometry and geography spatial types and functions''\ndefault_version = ''3.3.2''\nmodule_pathname = ''$libdir/postgis-3''\nrelocatable = false\ntrusted = true\n", + "pgrouting.control": "# pgRouting Extension\ncomment = ''pgRouting Extension''\ndefault_version = ''3.4.2''\nmodule_pathname = ''$libdir/libpgrouting-3.4''\nrelocatable = true\nrequires = ''plpgsql''\nrequires = ''postgis''\ntrusted = true\n", + "postgis_raster.control": "# postgis_raster extension\ncomment = ''PostGIS raster types and functions''\ndefault_version = ''3.3.2''\nmodule_pathname = ''$libdir/postgis_raster-3''\nrelocatable = false\nrequires = postgis\ntrusted = true\n", + "postgis_sfcgal.control": "# postgis topology extension\ncomment = ''PostGIS SFCGAL functions''\ndefault_version = ''3.3.2''\nrelocatable = true\nrequires = postgis\ntrusted = true\n", + "postgis_topology.control": "# postgis topology extension\ncomment = ''PostGIS topology spatial types and functions''\ndefault_version = ''3.3.2''\nrelocatable = false\nschema = topology\nrequires = postgis\ntrusted = true\n", + "address_standardizer.control": "# address_standardizer extension\ncomment = ''Used to parse an address into constituent elements. 
Generally used to support geocoding address normalization step.''\ndefault_version = ''3.3.2''\nrelocatable = true\ntrusted = true\n", + "postgis_tiger_geocoder.control": "# postgis tiger geocoder extension\ncomment = ''PostGIS tiger geocoder and reverse geocoder''\ndefault_version = ''3.3.2''\nrelocatable = false\nschema = tiger\nrequires = ''postgis,fuzzystrmatch''\nsuperuser= false\ntrusted = true\n", + "address_standardizer_data_us.control": "# address standardizer us dataset\ncomment = ''Address Standardizer US dataset example''\ndefault_version = ''3.3.2''\nrelocatable = true\ntrusted = true\n" + } + } + }, + "custom_extensions": [], + "public_extensions": ["postgis"] + }, + "pgbouncer_settings": { + "default_pool_size": "42", + "pool_mode": "session" + } +} diff --git a/control_plane/Cargo.toml b/control_plane/Cargo.toml index 62c039047f..bbaa3f12b9 100644 --- a/control_plane/Cargo.toml +++ b/control_plane/Cargo.toml @@ -36,6 +36,7 @@ pageserver_api.workspace = true pageserver_client.workspace = true postgres_backend.workspace = true safekeeper_api.workspace = true +safekeeper_client.workspace = true postgres_connection.workspace = true storage_broker.workspace = true http-utils.workspace = true diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index ef6985d697..c818d07fef 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -18,7 +18,7 @@ use clap::Parser; use compute_api::requests::ComputeClaimsScope; use compute_api::spec::ComputeMode; use control_plane::broker::StorageBroker; -use control_plane::endpoint::ComputeControlPlane; +use control_plane::endpoint::{ComputeControlPlane, EndpointTerminateMode, PageserverProtocol}; use control_plane::endpoint_storage::{ENDPOINT_STORAGE_DEFAULT_ADDR, EndpointStorage}; use control_plane::local_env; use control_plane::local_env::{ @@ -45,10 +45,10 @@ use pageserver_api::models::{ use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardCount, ShardStripeSize, TenantShardId}; use postgres_backend::AuthType; use postgres_connection::parse_host_port; -use safekeeper_api::membership::SafekeeperGeneration; +use safekeeper_api::membership::{SafekeeperGeneration, SafekeeperId}; use safekeeper_api::{ DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT, - DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT, + DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT, PgMajorVersion, PgVersionId, }; use storage_broker::DEFAULT_LISTEN_ADDR as DEFAULT_BROKER_ADDR; use tokio::task::JoinSet; @@ -64,7 +64,7 @@ const DEFAULT_PAGESERVER_ID: NodeId = NodeId(1); const DEFAULT_BRANCH_NAME: &str = "main"; project_git_version!(GIT_VERSION); -const DEFAULT_PG_VERSION: u32 = 17; +const DEFAULT_PG_VERSION: PgMajorVersion = PgMajorVersion::PG17; const DEFAULT_PAGESERVER_CONTROL_PLANE_API: &str = "http://127.0.0.1:1234/upcall/v1/"; @@ -169,7 +169,7 @@ struct TenantCreateCmdArgs { #[arg(default_value_t = DEFAULT_PG_VERSION)] #[clap(long, help = "Postgres version to use for the initial timeline")] - pg_version: u32, + pg_version: PgMajorVersion, #[clap( long, @@ -292,7 +292,7 @@ struct TimelineCreateCmdArgs { #[arg(default_value_t = DEFAULT_PG_VERSION)] #[clap(long, help = "Postgres version")] - pg_version: u32, + pg_version: PgMajorVersion, } #[derive(clap::Args)] @@ -324,7 +324,7 @@ struct TimelineImportCmdArgs { #[arg(default_value_t = DEFAULT_PG_VERSION)] #[clap(long, help = "Postgres version of the backup being imported")] - pg_version: u32, + pg_version: PgMajorVersion, } 
#[derive(clap::Subcommand)] @@ -603,7 +603,15 @@ struct EndpointCreateCmdArgs { #[arg(default_value_t = DEFAULT_PG_VERSION)] #[clap(long, help = "Postgres version")] - pg_version: u32, + pg_version: PgMajorVersion, + + /// Use gRPC to communicate with Pageservers, by generating grpc:// connstrings. + /// + /// Specified on creation such that it's retained across reconfiguration and restarts. + /// + /// NB: not yet supported by computes. + #[clap(long)] + grpc: bool, #[clap( long, @@ -664,6 +672,13 @@ struct EndpointStartCmdArgs { #[clap(short = 't', long, value_parser= humantime::parse_duration, help = "timeout until we fail the command")] #[arg(default_value = "90s")] start_timeout: Duration, + + #[clap( + long, + help = "Run in development mode, skipping VM-specific operations like process termination", + action = clap::ArgAction::SetTrue + )] + dev: bool, } #[derive(clap::Args)] @@ -696,10 +711,9 @@ struct EndpointStopCmdArgs { )] destroy: bool, - #[clap(long, help = "Postgres shutdown mode, passed to \"pg_ctl -m \"")] - #[arg(value_parser(["smart", "fast", "immediate"]))] - #[arg(default_value = "fast")] - mode: String, + #[clap(long, help = "Postgres shutdown mode")] + #[clap(default_value = "fast")] + mode: EndpointTerminateMode, } #[derive(clap::Args)] @@ -905,7 +919,7 @@ fn print_timeline( br_sym = "┗━"; } - print!("{} @{}: ", br_sym, ancestor_lsn); + print!("{br_sym} @{ancestor_lsn}: "); } // Finally print a timeline id and name with new line @@ -1255,6 +1269,45 @@ async fn handle_timeline(cmd: &TimelineCmd, env: &mut local_env::LocalEnv) -> Re pageserver .timeline_import(tenant_id, timeline_id, base, pg_wal, args.pg_version) .await?; + if env.storage_controller.timelines_onto_safekeepers { + println!("Creating timeline on safekeeper ..."); + let timeline_info = pageserver + .timeline_info( + TenantShardId::unsharded(tenant_id), + timeline_id, + pageserver_client::mgmt_api::ForceAwaitLogicalSize::No, + ) + .await?; + let default_sk = SafekeeperNode::from_env(env, env.safekeepers.first().unwrap()); + let default_host = default_sk + .conf + .listen_addr + .clone() + .unwrap_or_else(|| "localhost".to_string()); + let mconf = safekeeper_api::membership::Configuration { + generation: SafekeeperGeneration::new(1), + members: safekeeper_api::membership::MemberSet { + m: vec![SafekeeperId { + host: default_host, + id: default_sk.conf.id, + pg_port: default_sk.conf.pg_port, + }], + }, + new_members: None, + }; + let pg_version = PgVersionId::from(args.pg_version); + let req = safekeeper_api::models::TimelineCreateRequest { + tenant_id, + timeline_id, + mconf, + pg_version, + system_id: None, + wal_seg_size: None, + start_lsn: timeline_info.last_record_lsn, + commit_lsn: None, + }; + default_sk.create_timeline(&req).await?; + } env.register_branch_mapping(branch_name.to_string(), tenant_id, timeline_id)?; println!("Done"); } @@ -1412,6 +1465,7 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res args.internal_http_port, args.pg_version, mode, + args.grpc, !args.update_catalog, false, )?; @@ -1452,13 +1506,20 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res let (pageservers, stripe_size) = if let Some(pageserver_id) = pageserver_id { let conf = env.get_pageserver_conf(pageserver_id).unwrap(); - let parsed = parse_host_port(&conf.listen_pg_addr).expect("Bad config"); - ( - vec![(parsed.0, parsed.1.unwrap_or(5432))], - // If caller is telling us what pageserver to use, this is not a tenant which is - // full managed by storage 
controller, therefore not sharded. - DEFAULT_STRIPE_SIZE, - ) + // Use gRPC if requested. + let pageserver = if endpoint.grpc { + let grpc_addr = conf.listen_grpc_addr.as_ref().expect("bad config"); + let (host, port) = parse_host_port(grpc_addr)?; + let port = port.unwrap_or(DEFAULT_PAGESERVER_GRPC_PORT); + (PageserverProtocol::Grpc, host, port) + } else { + let (host, port) = parse_host_port(&conf.listen_pg_addr)?; + let port = port.unwrap_or(5432); + (PageserverProtocol::Libpq, host, port) + }; + // If caller is telling us what pageserver to use, this is not a tenant which is + // fully managed by storage controller, therefore not sharded. + (vec![pageserver], DEFAULT_STRIPE_SIZE) } else { // Look up the currently attached location of the tenant, and its striping metadata, // to pass these on to postgres. @@ -1477,11 +1538,20 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res .await?; } - anyhow::Ok(( - Host::parse(&shard.listen_pg_addr) - .expect("Storage controller reported bad hostname"), - shard.listen_pg_port, - )) + let pageserver = if endpoint.grpc { + ( + PageserverProtocol::Grpc, + Host::parse(&shard.listen_grpc_addr.expect("no gRPC address"))?, + shard.listen_grpc_port.expect("no gRPC port"), + ) + } else { + ( + PageserverProtocol::Libpq, + Host::parse(&shard.listen_pg_addr)?, + shard.listen_pg_port, + ) + }; + anyhow::Ok(pageserver) }), ) .await?; @@ -1526,6 +1596,7 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res stripe_size.0 as usize, args.create_test_user, args.start_timeout, + args.dev, ) .await?; } @@ -1536,11 +1607,19 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res .get(endpoint_id.as_str()) .with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?; let pageservers = if let Some(ps_id) = args.endpoint_pageserver_id { - let pageserver = PageServerNode::from_env(env, env.get_pageserver_conf(ps_id)?); - vec![( - pageserver.pg_connection_config.host().clone(), - pageserver.pg_connection_config.port(), - )] + let conf = env.get_pageserver_conf(ps_id)?; + // Use gRPC if requested. + let pageserver = if endpoint.grpc { + let grpc_addr = conf.listen_grpc_addr.as_ref().expect("bad config"); + let (host, port) = parse_host_port(grpc_addr)?; + let port = port.unwrap_or(DEFAULT_PAGESERVER_GRPC_PORT); + (PageserverProtocol::Grpc, host, port) + } else { + let (host, port) = parse_host_port(&conf.listen_pg_addr)?; + let port = port.unwrap_or(5432); + (PageserverProtocol::Libpq, host, port) + }; + vec![pageserver] } else { let storage_controller = StorageController::from_env(env); storage_controller @@ -1549,11 +1628,21 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res .shards .into_iter() .map(|shard| { - ( - Host::parse(&shard.listen_pg_addr) - .expect("Storage controller reported malformed host"), - shard.listen_pg_port, - ) + // Use gRPC if requested. 
+ if endpoint.grpc { + ( + PageserverProtocol::Grpc, + Host::parse(&shard.listen_grpc_addr.expect("no gRPC address")) + .expect("bad hostname"), + shard.listen_grpc_port.expect("no gRPC port"), + ) + } else { + ( + PageserverProtocol::Libpq, + Host::parse(&shard.listen_pg_addr).expect("bad hostname"), + shard.listen_pg_port, + ) + } }) .collect::>() }; @@ -1568,7 +1657,10 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res .endpoints .get(endpoint_id) .with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?; - endpoint.stop(&args.mode, args.destroy)?; + match endpoint.stop(args.mode, args.destroy).await?.lsn { + Some(lsn) => println!("{lsn}"), + None => println!("null"), + } } EndpointCmd::GenerateJwt(args) => { let endpoint = { @@ -1650,7 +1742,7 @@ async fn handle_pageserver(subcmd: &PageserverCmd, env: &local_env::LocalEnv) -> StopMode::Immediate => true, }; if let Err(e) = get_pageserver(env, args.pageserver_id)?.stop(immediate) { - eprintln!("pageserver stop failed: {}", e); + eprintln!("pageserver stop failed: {e}"); exit(1); } } @@ -1659,7 +1751,7 @@ async fn handle_pageserver(subcmd: &PageserverCmd, env: &local_env::LocalEnv) -> let pageserver = get_pageserver(env, args.pageserver_id)?; //TODO what shutdown strategy should we use here? if let Err(e) = pageserver.stop(false) { - eprintln!("pageserver stop failed: {}", e); + eprintln!("pageserver stop failed: {e}"); exit(1); } @@ -1676,7 +1768,7 @@ async fn handle_pageserver(subcmd: &PageserverCmd, env: &local_env::LocalEnv) -> { Ok(_) => println!("Page server is up and running"), Err(err) => { - eprintln!("Page server is not available: {}", err); + eprintln!("Page server is not available: {err}"); exit(1); } } @@ -1713,7 +1805,7 @@ async fn handle_storage_controller( }, }; if let Err(e) = svc.stop(stop_args).await { - eprintln!("stop failed: {}", e); + eprintln!("stop failed: {e}"); exit(1); } } @@ -1735,7 +1827,7 @@ async fn handle_safekeeper(subcmd: &SafekeeperCmd, env: &local_env::LocalEnv) -> let safekeeper = get_safekeeper(env, args.id)?; if let Err(e) = safekeeper.start(&args.extra_opt, &args.start_timeout).await { - eprintln!("safekeeper start failed: {}", e); + eprintln!("safekeeper start failed: {e}"); exit(1); } } @@ -1747,7 +1839,7 @@ async fn handle_safekeeper(subcmd: &SafekeeperCmd, env: &local_env::LocalEnv) -> StopMode::Immediate => true, }; if let Err(e) = safekeeper.stop(immediate) { - eprintln!("safekeeper stop failed: {}", e); + eprintln!("safekeeper stop failed: {e}"); exit(1); } } @@ -1760,12 +1852,12 @@ async fn handle_safekeeper(subcmd: &SafekeeperCmd, env: &local_env::LocalEnv) -> }; if let Err(e) = safekeeper.stop(immediate) { - eprintln!("safekeeper stop failed: {}", e); + eprintln!("safekeeper stop failed: {e}"); exit(1); } if let Err(e) = safekeeper.start(&args.extra_opt, &args.start_timeout).await { - eprintln!("safekeeper start failed: {}", e); + eprintln!("safekeeper start failed: {e}"); exit(1); } } @@ -2000,11 +2092,16 @@ async fn handle_stop_all(args: &StopCmdArgs, env: &local_env::LocalEnv) -> Resul } async fn try_stop_all(env: &local_env::LocalEnv, immediate: bool) { + let mode = if immediate { + EndpointTerminateMode::Immediate + } else { + EndpointTerminateMode::Fast + }; // Stop all endpoints match ComputeControlPlane::load(env.clone()) { Ok(cplane) => { for (_k, node) in cplane.endpoints { - if let Err(e) = node.stop(if immediate { "immediate" } else { "fast" }, false) { + if let Err(e) = node.stop(mode, false).await { eprintln!("postgres stop 
failed: {e:#}"); } } @@ -2016,7 +2113,7 @@ async fn try_stop_all(env: &local_env::LocalEnv, immediate: bool) { let storage = EndpointStorage::from_env(env); if let Err(e) = storage.stop(immediate) { - eprintln!("endpoint_storage stop failed: {:#}", e); + eprintln!("endpoint_storage stop failed: {e:#}"); } for ps_conf in &env.pageservers { diff --git a/control_plane/src/endpoint.rs b/control_plane/src/endpoint.rs index 774a0053f8..e3faa082db 100644 --- a/control_plane/src/endpoint.rs +++ b/control_plane/src/endpoint.rs @@ -37,6 +37,7 @@ //! ``` //! use std::collections::BTreeMap; +use std::fmt::Display; use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream}; use std::path::PathBuf; use std::process::Command; @@ -45,11 +46,14 @@ use std::sync::Arc; use std::time::{Duration, Instant}; use anyhow::{Context, Result, anyhow, bail}; +use base64::Engine; +use base64::prelude::BASE64_URL_SAFE_NO_PAD; use compute_api::requests::{ COMPUTE_AUDIENCE, ComputeClaims, ComputeClaimsScope, ConfigurationRequest, }; use compute_api::responses::{ - ComputeConfig, ComputeCtlConfig, ComputeStatus, ComputeStatusResponse, TlsConfig, + ComputeConfig, ComputeCtlConfig, ComputeStatus, ComputeStatusResponse, TerminateResponse, + TlsConfig, }; use compute_api::spec::{ Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, PgIdent, @@ -63,6 +67,7 @@ use nix::sys::signal::{Signal, kill}; use pageserver_api::shard::ShardStripeSize; use pem::Pem; use reqwest::header::CONTENT_TYPE; +use safekeeper_api::PgMajorVersion; use safekeeper_api::membership::SafekeeperGeneration; use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; @@ -74,7 +79,6 @@ use utils::id::{NodeId, TenantId, TimelineId}; use crate::local_env::LocalEnv; use crate::postgresql_conf::PostgresConf; -use crate::storage_controller::StorageController; // contents of a endpoint.json file #[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)] @@ -86,7 +90,8 @@ pub struct EndpointConf { pg_port: u16, external_http_port: u16, internal_http_port: u16, - pg_version: u32, + pg_version: PgMajorVersion, + grpc: bool, skip_pg_catalog_updates: bool, reconfigure_concurrency: usize, drop_subscriptions_before_start: bool, @@ -164,7 +169,7 @@ impl ComputeControlPlane { public_key_use: Some(PublicKeyUse::Signature), key_operations: Some(vec![KeyOperations::Verify]), key_algorithm: Some(KeyAlgorithm::EdDSA), - key_id: Some(base64::encode_config(key_hash, base64::URL_SAFE_NO_PAD)), + key_id: Some(BASE64_URL_SAFE_NO_PAD.encode(key_hash)), x509_url: None::, x509_chain: None::>, x509_sha1_fingerprint: None::, @@ -173,7 +178,7 @@ impl ComputeControlPlane { algorithm: AlgorithmParameters::OctetKeyPair(OctetKeyPairParameters { key_type: OctetKeyPairType::OctetKeyPair, curve: EllipticCurve::Ed25519, - x: base64::encode_config(public_key, base64::URL_SAFE_NO_PAD), + x: BASE64_URL_SAFE_NO_PAD.encode(public_key), }), }], }) @@ -188,8 +193,9 @@ impl ComputeControlPlane { pg_port: Option, external_http_port: Option, internal_http_port: Option, - pg_version: u32, + pg_version: PgMajorVersion, mode: ComputeMode, + grpc: bool, skip_pg_catalog_updates: bool, drop_subscriptions_before_start: bool, ) -> Result> { @@ -224,6 +230,7 @@ impl ComputeControlPlane { // we also skip catalog updates in the cloud. 
skip_pg_catalog_updates, drop_subscriptions_before_start, + grpc, reconfigure_concurrency: 1, features: vec![], cluster: None, @@ -242,6 +249,7 @@ impl ComputeControlPlane { internal_http_port, pg_port, pg_version, + grpc, skip_pg_catalog_updates, drop_subscriptions_before_start, reconfigure_concurrency: 1, @@ -296,6 +304,8 @@ pub struct Endpoint { pub tenant_id: TenantId, pub timeline_id: TimelineId, pub mode: ComputeMode, + /// If true, the endpoint should use gRPC to communicate with Pageservers. + pub grpc: bool, // port and address of the Postgres server and `compute_ctl`'s HTTP APIs pub pg_address: SocketAddr, @@ -303,7 +313,7 @@ pub struct Endpoint { pub internal_http_address: SocketAddr, // postgres major version in the format: 14, 15, etc. - pg_version: u32, + pg_version: PgMajorVersion, // These are not part of the endpoint as such, but the environment // the endpoint runs in. @@ -331,15 +341,58 @@ pub enum EndpointStatus { RunningNoPidfile, } -impl std::fmt::Display for EndpointStatus { +impl Display for EndpointStatus { fn fmt(&self, writer: &mut std::fmt::Formatter) -> std::fmt::Result { - let s = match self { + writer.write_str(match self { Self::Running => "running", Self::Stopped => "stopped", Self::Crashed => "crashed", Self::RunningNoPidfile => "running, no pidfile", - }; - write!(writer, "{}", s) + }) + } +} + +#[derive(Default, Clone, Copy, clap::ValueEnum)] +pub enum EndpointTerminateMode { + #[default] + /// Use pg_ctl stop -m fast + Fast, + /// Use pg_ctl stop -m immediate + Immediate, + /// Use /terminate?mode=immediate + ImmediateTerminate, +} + +impl std::fmt::Display for EndpointTerminateMode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(match &self { + EndpointTerminateMode::Fast => "fast", + EndpointTerminateMode::Immediate => "immediate", + EndpointTerminateMode::ImmediateTerminate => "immediate-terminate", + }) + } +} + +/// Protocol used to connect to a Pageserver. +#[derive(Clone, Copy, Debug)] +pub enum PageserverProtocol { + Libpq, + Grpc, +} + +impl PageserverProtocol { + /// Returns the URL scheme for the protocol, used in connstrings. 
+ pub fn scheme(&self) -> &'static str { + match self { + Self::Libpq => "postgresql", + Self::Grpc => "grpc", + } + } +} + +impl Display for PageserverProtocol { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.scheme()) } } @@ -378,6 +431,7 @@ impl Endpoint { mode: conf.mode, tenant_id: conf.tenant_id, pg_version: conf.pg_version, + grpc: conf.grpc, skip_pg_catalog_updates: conf.skip_pg_catalog_updates, reconfigure_concurrency: conf.reconfigure_concurrency, drop_subscriptions_before_start: conf.drop_subscriptions_before_start, @@ -504,7 +558,7 @@ impl Endpoint { conf.append("hot_standby", "on"); // prefetching of blocks referenced in WAL doesn't make sense for us // Neon hot standby ignores pages that are not in the shared_buffers - if self.pg_version >= 15 { + if self.pg_version >= PgMajorVersion::PG15 { conf.append("recovery_prefetch", "off"); } } @@ -606,10 +660,10 @@ impl Endpoint { } } - fn build_pageserver_connstr(pageservers: &[(Host, u16)]) -> String { + fn build_pageserver_connstr(pageservers: &[(PageserverProtocol, Host, u16)]) -> String { pageservers .iter() - .map(|(host, port)| format!("postgresql://no_user@{host}:{port}")) + .map(|(scheme, host, port)| format!("{scheme}://no_user@{host}:{port}")) .collect::>() .join(",") } @@ -654,11 +708,12 @@ impl Endpoint { endpoint_storage_addr: String, safekeepers_generation: Option, safekeepers: Vec, - pageservers: Vec<(Host, u16)>, + pageservers: Vec<(PageserverProtocol, Host, u16)>, remote_ext_base_url: Option<&String>, shard_stripe_size: usize, create_test_user: bool, start_timeout: Duration, + dev: bool, ) -> Result<()> { if self.status() == EndpointStatus::Running { anyhow::bail!("The endpoint is already running"); @@ -792,10 +847,10 @@ impl Endpoint { // Launch compute_ctl let conn_str = self.connstr("cloud_admin", "postgres"); - println!("Starting postgres node at '{}'", conn_str); + println!("Starting postgres node at '{conn_str}'"); if create_test_user { let conn_str = self.connstr("test", "neondb"); - println!("Also at '{}'", conn_str); + println!("Also at '{conn_str}'"); } let mut cmd = Command::new(self.env.neon_distrib_dir.join("compute_ctl")); cmd.args([ @@ -829,6 +884,10 @@ impl Endpoint { cmd.args(["--remote-ext-base-url", remote_ext_base_url]); } + if dev { + cmd.arg("--dev"); + } + let child = cmd.spawn()?; // set up a scopeguard to kill & wait for the child in case we panic or bail below let child = scopeguard::guard(child, |mut child| { @@ -881,7 +940,7 @@ impl Endpoint { ComputeStatus::Empty | ComputeStatus::ConfigurationPending | ComputeStatus::Configuration - | ComputeStatus::TerminationPending + | ComputeStatus::TerminationPending { .. 
} | ComputeStatus::Terminated => { bail!("unexpected compute status: {:?}", state.status) } @@ -890,8 +949,7 @@ impl Endpoint { Err(e) => { if Instant::now().duration_since(start_at) > start_timeout { return Err(e).context(format!( - "timed out {:?} waiting to connect to compute_ctl HTTP", - start_timeout, + "timed out {start_timeout:?} waiting to connect to compute_ctl HTTP", )); } } @@ -930,7 +988,7 @@ impl Endpoint { // reqwest does not export its error construction utility functions, so let's craft the message ourselves let url = response.url().to_owned(); let msg = match response.text().await { - Ok(err_body) => format!("Error: {}", err_body), + Ok(err_body) => format!("Error: {err_body}"), Err(_) => format!("Http error ({}) at {}.", status.as_u16(), url), }; Err(anyhow::anyhow!(msg)) @@ -939,10 +997,12 @@ impl Endpoint { pub async fn reconfigure( &self, - mut pageservers: Vec<(Host, u16)>, + pageservers: Vec<(PageserverProtocol, Host, u16)>, stripe_size: Option, safekeepers: Option>, ) -> Result<()> { + anyhow::ensure!(!pageservers.is_empty(), "no pageservers provided"); + let (mut spec, compute_ctl_config) = { let config_path = self.endpoint_path().join("config.json"); let file = std::fs::File::open(config_path)?; @@ -954,25 +1014,7 @@ impl Endpoint { let postgresql_conf = self.read_postgresql_conf()?; spec.cluster.postgresql_conf = Some(postgresql_conf); - // If we weren't given explicit pageservers, query the storage controller - if pageservers.is_empty() { - let storage_controller = StorageController::from_env(&self.env); - let locate_result = storage_controller.tenant_locate(self.tenant_id).await?; - pageservers = locate_result - .shards - .into_iter() - .map(|shard| { - ( - Host::parse(&shard.listen_pg_addr) - .expect("Storage controller reported bad hostname"), - shard.listen_pg_port, - ) - }) - .collect::>(); - } - let pageserver_connstr = Self::build_pageserver_connstr(&pageservers); - assert!(!pageserver_connstr.is_empty()); spec.pageserver_connstring = Some(pageserver_connstr); if stripe_size.is_some() { spec.shard_stripe_size = stripe_size.map(|s| s.0 as usize); @@ -1012,15 +1054,34 @@ impl Endpoint { } else { let url = response.url().to_owned(); let msg = match response.text().await { - Ok(err_body) => format!("Error: {}", err_body), + Ok(err_body) => format!("Error: {err_body}"), Err(_) => format!("Http error ({}) at {}.", status.as_u16(), url), }; Err(anyhow::anyhow!(msg)) } } - pub fn stop(&self, mode: &str, destroy: bool) -> Result<()> { - self.pg_ctl(&["-m", mode, "stop"], &None)?; + pub async fn stop( + &self, + mode: EndpointTerminateMode, + destroy: bool, + ) -> Result { + // pg_ctl stop is fast but doesn't allow us to collect LSN. /terminate is + // slow, and test runs time out. Solution: special mode "immediate-terminate" + // which uses /terminate + let response = if let EndpointTerminateMode::ImmediateTerminate = mode { + let ip = self.external_http_address.ip(); + let port = self.external_http_address.port(); + let url = format!("http://{ip}:{port}/terminate?mode=immediate"); + let token = self.generate_jwt(Some(ComputeClaimsScope::Admin))?; + let request = reqwest::Client::new().post(url).bearer_auth(token); + let response = request.send().await.context("/terminate")?; + let text = response.text().await.context("/terminate result")?; + serde_json::from_str(&text).with_context(|| format!("deserializing {text}"))? 
+ } else { + self.pg_ctl(&["-m", &mode.to_string(), "stop"], &None)?; + TerminateResponse { lsn: None } + }; // Also wait for the compute_ctl process to die. It might have some // cleanup work to do after postgres stops, like syncing safekeepers, @@ -1030,7 +1091,7 @@ impl Endpoint { // waiting. Sometimes we do *not* want this cleanup: tests intentionally // do stop when majority of safekeepers is down, so sync-safekeepers // would hang otherwise. This could be a separate flag though. - let send_sigterm = destroy || mode == "immediate"; + let send_sigterm = destroy || !matches!(mode, EndpointTerminateMode::Fast); self.wait_for_compute_ctl_to_exit(send_sigterm)?; if destroy { println!( @@ -1039,7 +1100,7 @@ impl Endpoint { ); std::fs::remove_dir_all(self.endpoint_path())?; } - Ok(()) + Ok(response) } pub fn connstr(&self, user: &str, db_name: &str) -> String { diff --git a/control_plane/src/local_env.rs b/control_plane/src/local_env.rs index 47b77f0720..16cd2d8c08 100644 --- a/control_plane/src/local_env.rs +++ b/control_plane/src/local_env.rs @@ -12,9 +12,11 @@ use std::{env, fs}; use anyhow::{Context, bail}; use clap::ValueEnum; +use pageserver_api::config::PostHogConfig; use pem::Pem; use postgres_backend::AuthType; use reqwest::{Certificate, Url}; +use safekeeper_api::PgMajorVersion; use serde::{Deserialize, Serialize}; use utils::auth::encode_from_key_file; use utils::id::{NodeId, TenantId, TenantTimelineId, TimelineId}; @@ -209,6 +211,12 @@ pub struct NeonStorageControllerConf { pub use_https_safekeeper_api: bool, pub use_local_compute_notifications: bool, + + pub timeline_safekeeper_count: Option, + + pub posthog_config: Option, + + pub kick_secondary_downloads: Option, } impl NeonStorageControllerConf { @@ -236,9 +244,12 @@ impl Default for NeonStorageControllerConf { heartbeat_interval: Self::DEFAULT_HEARTBEAT_INTERVAL, long_reconcile_threshold: None, use_https_pageserver_api: false, - timelines_onto_safekeepers: false, + timelines_onto_safekeepers: true, use_https_safekeeper_api: false, use_local_compute_notifications: true, + timeline_safekeeper_count: None, + posthog_config: None, + kick_secondary_downloads: None, } } } @@ -254,7 +265,7 @@ impl Default for EndpointStorageConf { impl NeonBroker { pub fn client_url(&self) -> Url { let url = if let Some(addr) = self.listen_https_addr { - format!("https://{}", addr) + format!("https://{addr}") } else { format!( "http://{}", @@ -418,25 +429,21 @@ impl LocalEnv { self.pg_distrib_dir.clone() } - pub fn pg_distrib_dir(&self, pg_version: u32) -> anyhow::Result { + pub fn pg_distrib_dir(&self, pg_version: PgMajorVersion) -> anyhow::Result { let path = self.pg_distrib_dir.clone(); - #[allow(clippy::manual_range_patterns)] - match pg_version { - 14 | 15 | 16 | 17 => Ok(path.join(format!("v{pg_version}"))), - _ => bail!("Unsupported postgres version: {}", pg_version), - } + Ok(path.join(pg_version.v_str())) } - pub fn pg_dir(&self, pg_version: u32, dir_name: &str) -> anyhow::Result { + pub fn pg_dir(&self, pg_version: PgMajorVersion, dir_name: &str) -> anyhow::Result { Ok(self.pg_distrib_dir(pg_version)?.join(dir_name)) } - pub fn pg_bin_dir(&self, pg_version: u32) -> anyhow::Result { + pub fn pg_bin_dir(&self, pg_version: PgMajorVersion) -> anyhow::Result { self.pg_dir(pg_version, "bin") } - pub fn pg_lib_dir(&self, pg_version: u32) -> anyhow::Result { + pub fn pg_lib_dir(&self, pg_version: PgMajorVersion) -> anyhow::Result { self.pg_dir(pg_version, "lib") } @@ -727,7 +734,7 @@ impl LocalEnv { let config_toml_path = 
dentry.path().join("pageserver.toml"); let config_toml: PageserverConfigTomlSubset = toml_edit::de::from_str( &std::fs::read_to_string(&config_toml_path) - .with_context(|| format!("read {:?}", config_toml_path))?, + .with_context(|| format!("read {config_toml_path:?}"))?, ) .context("parse pageserver.toml")?; let identity_toml_path = dentry.path().join("identity.toml"); @@ -737,7 +744,7 @@ impl LocalEnv { } let identity_toml: IdentityTomlSubset = toml_edit::de::from_str( &std::fs::read_to_string(&identity_toml_path) - .with_context(|| format!("read {:?}", identity_toml_path))?, + .with_context(|| format!("read {identity_toml_path:?}"))?, ) .context("parse identity.toml")?; let PageserverConfigTomlSubset { diff --git a/control_plane/src/pageserver.rs b/control_plane/src/pageserver.rs index db14d98afd..942cefffa5 100644 --- a/control_plane/src/pageserver.rs +++ b/control_plane/src/pageserver.rs @@ -16,11 +16,13 @@ use std::time::Duration; use anyhow::{Context, bail}; use camino::Utf8PathBuf; +use pageserver_api::config::{DEFAULT_GRPC_LISTEN_PORT, DEFAULT_HTTP_LISTEN_PORT}; use pageserver_api::models::{self, TenantInfo, TimelineInfo}; use pageserver_api::shard::TenantShardId; use pageserver_client::mgmt_api; use postgres_backend::AuthType; use postgres_connection::{PgConnectionConfig, parse_host_port}; +use safekeeper_api::PgMajorVersion; use utils::auth::{Claims, Scope}; use utils::id::{NodeId, TenantId, TimelineId}; use utils::lsn::Lsn; @@ -120,7 +122,7 @@ impl PageServerNode { .env .generate_auth_token(&Claims::new(None, Scope::GenerationsApi)) .unwrap(); - overrides.push(format!("control_plane_api_token='{}'", jwt_token)); + overrides.push(format!("control_plane_api_token='{jwt_token}'")); } if !conf.other.contains_key("remote_storage") { @@ -254,9 +256,10 @@ impl PageServerNode { // the storage controller let metadata_path = datadir.join("metadata.json"); - let (_http_host, http_port) = + let http_host = "localhost".to_string(); + let (_, http_port) = parse_host_port(&self.conf.listen_http_addr).expect("Unable to parse listen_http_addr"); - let http_port = http_port.unwrap_or(9898); + let http_port = http_port.unwrap_or(DEFAULT_HTTP_LISTEN_PORT); let https_port = match self.conf.listen_https_addr.as_ref() { Some(https_addr) => { @@ -267,6 +270,13 @@ impl PageServerNode { None => None, }; + let (mut grpc_host, mut grpc_port) = (None, None); + if let Some(grpc_addr) = &self.conf.listen_grpc_addr { + let (_, port) = parse_host_port(grpc_addr).expect("Unable to parse listen_grpc_addr"); + grpc_host = Some("localhost".to_string()); + grpc_port = Some(port.unwrap_or(DEFAULT_GRPC_LISTEN_PORT)); + } + // Intentionally hand-craft JSON: this acts as an implicit format compat test // in case the pageserver-side structure is edited, and reflects the real life // situation: the metadata is written by some other script. 
@@ -275,7 +285,9 @@ impl PageServerNode { serde_json::to_vec(&pageserver_api::config::NodeMetadata { postgres_host: "localhost".to_string(), postgres_port: self.pg_connection_config.port(), - http_host: "localhost".to_string(), + grpc_host, + grpc_port, + http_host, http_port, https_port, other: HashMap::from([( @@ -598,7 +610,7 @@ impl PageServerNode { timeline_id: TimelineId, base: (Lsn, PathBuf), pg_wal: Option<(Lsn, PathBuf)>, - pg_version: u32, + pg_version: PgMajorVersion, ) -> anyhow::Result<()> { // Init base reader let (start_lsn, base_tarfile_path) = base; @@ -637,4 +649,16 @@ impl PageServerNode { Ok(()) } + pub async fn timeline_info( + &self, + tenant_shard_id: TenantShardId, + timeline_id: TimelineId, + force_await_logical_size: mgmt_api::ForceAwaitLogicalSize, + ) -> anyhow::Result { + let timeline_info = self + .http_client + .timeline_info(tenant_shard_id, timeline_id, force_await_logical_size) + .await?; + Ok(timeline_info) + } } diff --git a/control_plane/src/safekeeper.rs b/control_plane/src/safekeeper.rs index 25274d09d8..9234d4ce94 100644 --- a/control_plane/src/safekeeper.rs +++ b/control_plane/src/safekeeper.rs @@ -6,7 +6,6 @@ //! .neon/safekeepers/ //! ``` use std::error::Error as _; -use std::future::Future; use std::io::Write; use std::path::PathBuf; use std::time::Duration; @@ -14,9 +13,9 @@ use std::{io, result}; use anyhow::Context; use camino::Utf8PathBuf; -use http_utils::error::HttpErrorBody; use postgres_connection::PgConnectionConfig; -use reqwest::{IntoUrl, Method}; +use safekeeper_api::models::TimelineCreateRequest; +use safekeeper_client::mgmt_api; use thiserror::Error; use utils::auth::{Claims, Scope}; use utils::id::NodeId; @@ -35,25 +34,14 @@ pub enum SafekeeperHttpError { type Result = result::Result; -pub(crate) trait ResponseErrorMessageExt: Sized { - fn error_from_body(self) -> impl Future> + Send; -} - -impl ResponseErrorMessageExt for reqwest::Response { - async fn error_from_body(self) -> Result { - let status = self.status(); - if !(status.is_client_error() || status.is_server_error()) { - return Ok(self); - } - - // reqwest does not export its error construction utility functions, so let's craft the message ourselves - let url = self.url().to_owned(); - Err(SafekeeperHttpError::Response( - match self.json::().await { - Ok(err_body) => format!("Error: {}", err_body.msg), - Err(_) => format!("Http error ({}) at {}.", status.as_u16(), url), - }, - )) +fn err_from_client_err(err: mgmt_api::Error) -> SafekeeperHttpError { + use mgmt_api::Error::*; + match err { + ApiError(_, str) => SafekeeperHttpError::Response(str), + Cancelled => SafekeeperHttpError::Response("Cancelled".to_owned()), + ReceiveBody(err) => SafekeeperHttpError::Transport(err), + ReceiveErrorBody(err) => SafekeeperHttpError::Response(err), + Timeout(str) => SafekeeperHttpError::Response(format!("timeout: {str}")), } } @@ -70,9 +58,8 @@ pub struct SafekeeperNode { pub pg_connection_config: PgConnectionConfig, pub env: LocalEnv, - pub http_client: reqwest::Client, + pub http_client: mgmt_api::Client, pub listen_addr: String, - pub http_base_url: String, } impl SafekeeperNode { @@ -82,13 +69,14 @@ impl SafekeeperNode { } else { "127.0.0.1".to_string() }; + let jwt = None; + let http_base_url = format!("http://{}:{}", listen_addr, conf.http_port); SafekeeperNode { id: conf.id, conf: conf.clone(), pg_connection_config: Self::safekeeper_connection_config(&listen_addr, conf.pg_port), env: env.clone(), - http_client: env.create_http_client(), - http_base_url: format!("http://{}:{}/v1", 
listen_addr, conf.http_port), + http_client: mgmt_api::Client::new(env.create_http_client(), http_base_url, jwt), listen_addr, } } @@ -155,7 +143,7 @@ impl SafekeeperNode { let id_string = id.to_string(); // TODO: add availability_zone to the config. // Right now we just specify any value here and use it to check metrics in tests. - let availability_zone = format!("sk-{}", id_string); + let availability_zone = format!("sk-{id_string}"); let mut args = vec![ "-D".to_owned(), @@ -279,20 +267,19 @@ impl SafekeeperNode { ) } - fn http_request(&self, method: Method, url: U) -> reqwest::RequestBuilder { - // TODO: authentication - //if self.env.auth_type == AuthType::NeonJWT { - // builder = builder.bearer_auth(&self.env.safekeeper_auth_token) - //} - self.http_client.request(method, url) + pub async fn check_status(&self) -> Result<()> { + self.http_client + .status() + .await + .map_err(err_from_client_err)?; + Ok(()) } - pub async fn check_status(&self) -> Result<()> { - self.http_request(Method::GET, format!("{}/{}", self.http_base_url, "status")) - .send() - .await? - .error_from_body() - .await?; + pub async fn create_timeline(&self, req: &TimelineCreateRequest) -> Result<()> { + self.http_client + .create_timeline(req) + .await + .map_err(err_from_client_err)?; Ok(()) } } diff --git a/control_plane/src/storage_controller.rs b/control_plane/src/storage_controller.rs index 755d67a7ad..dea7ae2ccf 100644 --- a/control_plane/src/storage_controller.rs +++ b/control_plane/src/storage_controller.rs @@ -6,6 +6,8 @@ use std::str::FromStr; use std::sync::OnceLock; use std::time::{Duration, Instant}; +use crate::background_process; +use crate::local_env::{LocalEnv, NeonStorageControllerConf}; use camino::{Utf8Path, Utf8PathBuf}; use hyper0::Uri; use nix::unistd::Pid; @@ -22,6 +24,7 @@ use pageserver_client::mgmt_api::ResponseErrorMessageExt; use pem::Pem; use postgres_backend::AuthType; use reqwest::{Method, Response}; +use safekeeper_api::PgMajorVersion; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use tokio::process::Command; @@ -31,9 +34,6 @@ use utils::auth::{Claims, Scope, encode_from_key_file}; use utils::id::{NodeId, TenantId}; use whoami::username; -use crate::background_process; -use crate::local_env::{LocalEnv, NeonStorageControllerConf}; - pub struct StorageController { env: LocalEnv, private_key: Option, @@ -48,7 +48,7 @@ pub struct StorageController { const COMMAND: &str = "storage_controller"; -const STORAGE_CONTROLLER_POSTGRES_VERSION: u32 = 16; +const STORAGE_CONTROLLER_POSTGRES_VERSION: PgMajorVersion = PgMajorVersion::PG16; const DB_NAME: &str = "storage_controller"; @@ -167,7 +167,7 @@ impl StorageController { fn storage_controller_instance_dir(&self, instance_id: u8) -> PathBuf { self.env .base_data_dir - .join(format!("storage_controller_{}", instance_id)) + .join(format!("storage_controller_{instance_id}")) } fn pid_file(&self, instance_id: u8) -> Utf8PathBuf { @@ -184,9 +184,15 @@ impl StorageController { /// to other versions if that one isn't found. Some automated tests create circumstances /// where only one version is available in pg_distrib_dir, such as `test_remote_extensions`. 
async fn get_pg_dir(&self, dir_name: &str) -> anyhow::Result { - let prefer_versions = [STORAGE_CONTROLLER_POSTGRES_VERSION, 16, 15, 14]; + const PREFER_VERSIONS: [PgMajorVersion; 5] = [ + STORAGE_CONTROLLER_POSTGRES_VERSION, + PgMajorVersion::PG16, + PgMajorVersion::PG15, + PgMajorVersion::PG14, + PgMajorVersion::PG17, + ]; - for v in prefer_versions { + for v in PREFER_VERSIONS { let path = Utf8PathBuf::from_path_buf(self.env.pg_dir(v, dir_name)?).unwrap(); if tokio::fs::try_exists(&path).await? { return Ok(path); @@ -220,7 +226,7 @@ impl StorageController { "-d", DB_NAME, "-p", - &format!("{}", postgres_port), + &format!("{postgres_port}"), ]; let pg_lib_dir = self.get_pg_lib_dir().await.unwrap(); let envs = [ @@ -263,7 +269,7 @@ impl StorageController { "-h", "localhost", "-p", - &format!("{}", postgres_port), + &format!("{postgres_port}"), "-U", &username(), "-O", @@ -425,7 +431,7 @@ impl StorageController { // from `LocalEnv`'s config file (`.neon/config`). tokio::fs::write( &pg_data_path.join("postgresql.conf"), - format!("port = {}\nfsync=off\n", postgres_port), + format!("port = {postgres_port}\nfsync=off\n"), ) .await?; @@ -477,7 +483,7 @@ impl StorageController { self.setup_database(postgres_port).await?; } - let database_url = format!("postgresql://localhost:{}/{DB_NAME}", postgres_port); + let database_url = format!("postgresql://localhost:{postgres_port}/{DB_NAME}"); // We support running a startup SQL script to fiddle with the database before we launch storcon. // This is used by the test suite. @@ -508,7 +514,7 @@ impl StorageController { drop(client); conn.await??; - let addr = format!("{}:{}", host, listen_port); + let addr = format!("{host}:{listen_port}"); let address_for_peers = Uri::builder() .scheme(scheme) .authority(addr.clone()) @@ -557,6 +563,10 @@ impl StorageController { args.push("--use-local-compute-notifications".to_string()); } + if let Some(value) = self.config.kick_secondary_downloads { + args.push(format!("--kick-secondary-downloads={value}")); + } + if let Some(ssl_ca_file) = self.env.ssl_ca_cert_path() { args.push(format!("--ssl-ca-file={}", ssl_ca_file.to_str().unwrap())); } @@ -628,6 +638,22 @@ impl StorageController { args.push("--timelines-onto-safekeepers".to_string()); } + if let Some(sk_cnt) = self.config.timeline_safekeeper_count { + args.push(format!("--timeline-safekeeper-count={sk_cnt}")); + } + + let mut envs = vec![ + ("LD_LIBRARY_PATH".to_owned(), pg_lib_dir.to_string()), + ("DYLD_LIBRARY_PATH".to_owned(), pg_lib_dir.to_string()), + ]; + + if let Some(posthog_config) = &self.config.posthog_config { + envs.push(( + "POSTHOG_CONFIG".to_string(), + serde_json::to_string(posthog_config)?, + )); + } + println!("Starting storage controller"); background_process::start_process( @@ -635,10 +661,7 @@ impl StorageController { &instance_dir, &self.env.storage_controller_bin(), args, - vec![ - ("LD_LIBRARY_PATH".to_owned(), pg_lib_dir.to_string()), - ("DYLD_LIBRARY_PATH".to_owned(), pg_lib_dir.to_string()), - ], + envs, background_process::InitialPidFile::Create(self.pid_file(start_args.instance_id)), &start_args.start_timeout, || async { @@ -802,9 +825,9 @@ impl StorageController { builder = builder.json(&body) } if let Some(private_key) = &self.private_key { - println!("Getting claims for path {}", path); + println!("Getting claims for path {path}"); if let Some(required_claims) = Self::get_claims_for_path(&path)? 
{ - println!("Got claims {:?} for path {}", required_claims, path); + println!("Got claims {required_claims:?} for path {path}"); let jwt_token = encode_from_key_file(&required_claims, private_key)?; builder = builder.header( reqwest::header::AUTHORIZATION, diff --git a/control_plane/storcon_cli/src/main.rs b/control_plane/storcon_cli/src/main.rs index 19c686dcfd..0036b7d0f6 100644 --- a/control_plane/storcon_cli/src/main.rs +++ b/control_plane/storcon_cli/src/main.rs @@ -36,6 +36,10 @@ enum Command { listen_pg_addr: String, #[arg(long)] listen_pg_port: u16, + #[arg(long)] + listen_grpc_addr: Option, + #[arg(long)] + listen_grpc_port: Option, #[arg(long)] listen_http_addr: String, @@ -61,10 +65,16 @@ enum Command { #[arg(long)] scheduling: Option, }, + // Set a node status as deleted. NodeDelete { #[arg(long)] node_id: NodeId, }, + /// Delete a tombstone of node from the storage controller. + NodeDeleteTombstone { + #[arg(long)] + node_id: NodeId, + }, /// Modify a tenant's policies in the storage controller TenantPolicy { #[arg(long)] @@ -82,6 +92,8 @@ enum Command { }, /// List nodes known to the storage controller Nodes {}, + /// List soft deleted nodes known to the storage controller + NodeTombstones {}, /// List tenants known to the storage controller Tenants { /// If this field is set, it will list the tenants on a specific node @@ -410,6 +422,8 @@ async fn main() -> anyhow::Result<()> { node_id, listen_pg_addr, listen_pg_port, + listen_grpc_addr, + listen_grpc_port, listen_http_addr, listen_http_port, listen_https_port, @@ -423,6 +437,8 @@ async fn main() -> anyhow::Result<()> { node_id, listen_pg_addr, listen_pg_port, + listen_grpc_addr, + listen_grpc_port, listen_http_addr, listen_http_port, listen_https_port, @@ -633,7 +649,7 @@ async fn main() -> anyhow::Result<()> { response .new_shards .iter() - .map(|s| format!("{:?}", s)) + .map(|s| format!("{s:?}")) .collect::>() .join(",") ); @@ -755,8 +771,8 @@ async fn main() -> anyhow::Result<()> { println!("Tenant {tenant_id}"); let mut table = comfy_table::Table::new(); - table.add_row(["Policy", &format!("{:?}", policy)]); - table.add_row(["Stripe size", &format!("{:?}", stripe_size)]); + table.add_row(["Policy", &format!("{policy:?}")]); + table.add_row(["Stripe size", &format!("{stripe_size:?}")]); table.add_row(["Config", &serde_json::to_string_pretty(&config).unwrap()]); println!("{table}"); println!("Shards:"); @@ -773,7 +789,7 @@ async fn main() -> anyhow::Result<()> { let secondary = shard .node_secondary .iter() - .map(|n| format!("{}", n)) + .map(|n| format!("{n}")) .collect::>() .join(","); @@ -847,7 +863,7 @@ async fn main() -> anyhow::Result<()> { } } else { // Make it obvious to the user that since they've omitted an AZ, we're clearing it - eprintln!("Clearing preferred AZ for tenant {}", tenant_id); + eprintln!("Clearing preferred AZ for tenant {tenant_id}"); } // Construct a request that modifies all the tenant's shards @@ -900,6 +916,39 @@ async fn main() -> anyhow::Result<()> { .dispatch::<(), ()>(Method::DELETE, format!("control/v1/node/{node_id}"), None) .await?; } + Command::NodeDeleteTombstone { node_id } => { + storcon_client + .dispatch::<(), ()>( + Method::DELETE, + format!("debug/v1/tombstone/{node_id}"), + None, + ) + .await?; + } + Command::NodeTombstones {} => { + let mut resp = storcon_client + .dispatch::<(), Vec>( + Method::GET, + "debug/v1/tombstone".to_string(), + None, + ) + .await?; + + resp.sort_by(|a, b| a.listen_http_addr.cmp(&b.listen_http_addr)); + + let mut table = comfy_table::Table::new(); + 
table.set_header(["Id", "Hostname", "AZ", "Scheduling", "Availability"]); + for node in resp { + table.add_row([ + format!("{}", node.id), + node.listen_http_addr, + node.availability_zone_id, + format!("{:?}", node.scheduling), + format!("{:?}", node.availability), + ]); + } + println!("{table}"); + } Command::TenantSetTimeBasedEviction { tenant_id, period, @@ -1085,8 +1134,7 @@ async fn main() -> anyhow::Result<()> { Err((tenant_shard_id, from, to, error)) => { failure += 1; println!( - "Failed to migrate {} from node {} to node {}: {}", - tenant_shard_id, from, to, error + "Failed to migrate {tenant_shard_id} from node {from} to node {to}: {error}" ); } } @@ -1228,8 +1276,7 @@ async fn main() -> anyhow::Result<()> { concurrency, } => { let mut path = format!( - "/v1/tenant/{}/timeline/{}/download_heatmap_layers", - tenant_shard_id, timeline_id, + "/v1/tenant/{tenant_shard_id}/timeline/{timeline_id}/download_heatmap_layers", ); if let Some(c) = concurrency { @@ -1254,8 +1301,7 @@ async fn watch_tenant_shard( ) -> anyhow::Result<()> { if let Some(until_migrated_to) = until_migrated_to { println!( - "Waiting for tenant shard {} to be migrated to node {}", - tenant_shard_id, until_migrated_to + "Waiting for tenant shard {tenant_shard_id} to be migrated to node {until_migrated_to}" ); } @@ -1278,7 +1324,7 @@ async fn watch_tenant_shard( "attached: {} secondary: {} {}", shard .node_attached - .map(|n| format!("{}", n)) + .map(|n| format!("{n}")) .unwrap_or("none".to_string()), shard .node_secondary @@ -1292,15 +1338,12 @@ async fn watch_tenant_shard( "(reconciler idle)" } ); - println!("{}", summary); + println!("{summary}"); // Maybe drop out if we finished migration if let Some(until_migrated_to) = until_migrated_to { if shard.node_attached == Some(until_migrated_to) && !shard.is_reconciling { - println!( - "Tenant shard {} is now on node {}", - tenant_shard_id, until_migrated_to - ); + println!("Tenant shard {tenant_shard_id} is now on node {until_migrated_to}"); break; } } diff --git a/docker-compose/compute_wrapper/shell/compute.sh b/docker-compose/compute_wrapper/shell/compute.sh index c8ca812bf9..1e62e91fd0 100755 --- a/docker-compose/compute_wrapper/shell/compute.sh +++ b/docker-compose/compute_wrapper/shell/compute.sh @@ -95,3 +95,4 @@ echo "Start compute node" -b /usr/local/bin/postgres \ --compute-id "compute-${RANDOM}" \ --config "${CONFIG_FILE}" + --dev diff --git a/docker-compose/ext-src/postgis-src/neon-test.sh b/docker-compose/ext-src/postgis-src/neon-test.sh index 2866649a1b..13df1ec9d1 100755 --- a/docker-compose/ext-src/postgis-src/neon-test.sh +++ b/docker-compose/ext-src/postgis-src/neon-test.sh @@ -1,9 +1,6 @@ -#!/bin/bash +#!/bin/sh set -ex cd "$(dirname "$0")" -if [[ ${PG_VERSION} = v17 ]]; then - sed -i '/computed_columns/d' regress/core/tests.mk -fi -patch -p1 =" 120),1) +- TESTS += \ +- $(top_srcdir)/regress/core/computed_columns +-endif +- + ifeq ($(shell expr "$(POSTGIS_GEOS_VERSION)" ">=" 30700),1) + # GEOS-3.7 adds: + # ST_FrechetDistance diff --git a/regress/runtest.mk b/regress/runtest.mk index c051f03..010e493 100644 --- a/regress/runtest.mk diff --git a/docker-compose/ext-src/postgis-src/postgis-common-v17.patch b/docker-compose/ext-src/postgis-src/postgis-common-v17.patch new file mode 100644 index 0000000000..0b8978281e --- /dev/null +++ b/docker-compose/ext-src/postgis-src/postgis-common-v17.patch @@ -0,0 +1,35 @@ +diff --git a/regress/core/tests.mk b/regress/core/tests.mk +index 9e05244..90987df 100644 +--- a/regress/core/tests.mk ++++ 
b/regress/core/tests.mk +@@ -143,8 +143,7 @@ TESTS += \ + $(top_srcdir)/regress/core/oriented_envelope \ + $(top_srcdir)/regress/core/point_coordinates \ + $(top_srcdir)/regress/core/out_geojson \ +- $(top_srcdir)/regress/core/wrapx \ +- $(top_srcdir)/regress/core/computed_columns ++ $(top_srcdir)/regress/core/wrapx + + # Slow slow tests + TESTS_SLOW = \ +diff --git a/regress/runtest.mk b/regress/runtest.mk +index 4b95b7e..449d5a2 100644 +--- a/regress/runtest.mk ++++ b/regress/runtest.mk +@@ -24,16 +24,6 @@ check-regress: + + @POSTGIS_TOP_BUILD_DIR=$(abs_top_builddir) $(PERL) $(top_srcdir)/regress/run_test.pl $(RUNTESTFLAGS) $(RUNTESTFLAGS_INTERNAL) $(TESTS) + +- @if echo "$(RUNTESTFLAGS)" | grep -vq -- --upgrade; then \ +- echo "Running upgrade test as RUNTESTFLAGS did not contain that"; \ +- POSTGIS_TOP_BUILD_DIR=$(abs_top_builddir) $(PERL) $(top_srcdir)/regress/run_test.pl \ +- --upgrade \ +- $(RUNTESTFLAGS) \ +- $(RUNTESTFLAGS_INTERNAL) \ +- $(TESTS); \ +- else \ +- echo "Skipping upgrade test as RUNTESTFLAGS already requested upgrades"; \ +- fi + + check-long: + $(PERL) $(top_srcdir)/regress/run_test.pl $(RUNTESTFLAGS) $(TESTS) $(TESTS_SLOW) diff --git a/docker-compose/ext-src/postgis-src/postgis-regular-v16.patch b/docker-compose/ext-src/postgis-src/postgis-regular-v16.patch index 2fd214c534..e7f01ad288 100644 --- a/docker-compose/ext-src/postgis-src/postgis-regular-v16.patch +++ b/docker-compose/ext-src/postgis-src/postgis-regular-v16.patch @@ -125,7 +125,7 @@ index 7a36b65..ad78fc7 100644 DROP SCHEMA tm CASCADE; + diff --git a/regress/core/tests.mk b/regress/core/tests.mk -index 3abd7bc..94903c3 100644 +index 64a9254..94903c3 100644 --- a/regress/core/tests.mk +++ b/regress/core/tests.mk @@ -23,7 +23,6 @@ current_dir := $(dir $(abspath $(lastword $(MAKEFILE_LIST)))) @@ -160,18 +160,6 @@ index 3abd7bc..94903c3 100644 $(top_srcdir)/regress/core/wkb \ $(top_srcdir)/regress/core/wkt \ $(top_srcdir)/regress/core/wmsservers \ -@@ -144,11 +140,6 @@ TESTS_SLOW = \ - $(top_srcdir)/regress/core/concave_hull_hard \ - $(top_srcdir)/regress/core/knn_recheck - --ifeq ($(shell expr "$(POSTGIS_PGSQL_VERSION)" ">=" 120),1) -- TESTS += \ -- $(top_srcdir)/regress/core/computed_columns --endif -- - ifeq ($(shell expr "$(POSTGIS_GEOS_VERSION)" ">=" 30700),1) - # GEOS-3.7 adds: - # ST_FrechetDistance diff --git a/regress/loader/tests.mk b/regress/loader/tests.mk index 1fc77ac..c3cb9de 100644 --- a/regress/loader/tests.mk diff --git a/docker-compose/ext-src/postgis-src/postgis-regular-v17.patch b/docker-compose/ext-src/postgis-src/postgis-regular-v17.patch index f4a9d83478..ae76e559df 100644 --- a/docker-compose/ext-src/postgis-src/postgis-regular-v17.patch +++ b/docker-compose/ext-src/postgis-src/postgis-regular-v17.patch @@ -125,7 +125,7 @@ index 7a36b65..ad78fc7 100644 DROP SCHEMA tm CASCADE; + diff --git a/regress/core/tests.mk b/regress/core/tests.mk -index 9e05244..a63a3e1 100644 +index 90987df..74fe3f1 100644 --- a/regress/core/tests.mk +++ b/regress/core/tests.mk @@ -16,14 +16,13 @@ POSTGIS_PGSQL_VERSION=170 @@ -168,16 +168,6 @@ index 9e05244..a63a3e1 100644 $(top_srcdir)/regress/core/wkb \ $(top_srcdir)/regress/core/wkt \ $(top_srcdir)/regress/core/wmsservers \ -@@ -143,8 +139,7 @@ TESTS += \ - $(top_srcdir)/regress/core/oriented_envelope \ - $(top_srcdir)/regress/core/point_coordinates \ - $(top_srcdir)/regress/core/out_geojson \ -- $(top_srcdir)/regress/core/wrapx \ -- $(top_srcdir)/regress/core/computed_columns -+ $(top_srcdir)/regress/core/wrapx - - # Slow slow tests - TESTS_SLOW = \ diff 
--git a/regress/loader/tests.mk
index ac4f8ad..4bad4fc 100644
--- a/regress/loader/tests.mk
diff --git a/docker-compose/ext-src/postgis-src/regular-test.sh b/docker-compose/ext-src/postgis-src/regular-test.sh
index 4b0b929946..1b1683b3f1 100755
--- a/docker-compose/ext-src/postgis-src/regular-test.sh
+++ b/docker-compose/ext-src/postgis-src/regular-test.sh
@@ -10,8 +10,8 @@ psql -d contrib_regression -c "ALTER DATABASE contrib_regression SET TimeZone='U
 -c "CREATE EXTENSION postgis_tiger_geocoder CASCADE" \
 -c "CREATE EXTENSION postgis_raster SCHEMA public" \
 -c "CREATE EXTENSION postgis_sfcgal SCHEMA public"
-patch -p1
+<bucket>/
+    tenants/
+        <tenant_id>/
+            timelines/
+                <timeline_id>/
+                    endpoints/
+                        <endpoint_id>/
+                            pgdata/
+
+For other blob storages an equivalent or similar path can be constructed.
+
+### Reliability, failure modes and corner cases (if relevant)
+Reliability is important, but not critical to the workings of Neon. The data
+stored in this service will, when lost, reduce performance, but won't be a
+cause of permanent data loss - only operational metadata is stored.
+
+Most, if not all, blob storage services have sufficiently high persistence
+guarantees to cater to our need for persistence and uptime. The only concern
+with blob storages is that the access latency is generally higher than local
+disk, but for the object types stored (cache state, ...) I don't think this
+will be much of an issue.
+
+### Interaction/Sequence diagram (if relevant)
+
+In these diagrams you can replace S3 with any persistent storage device of
+choice, but S3 is chosen as a representative name: the well-known and short
+name of AWS' blob storage. Azure Blob Storage should work too, but it has a
+much longer name, making it less practical for the diagrams.
+
+Write data:
+
+```http
+POST /tenants/<tenant_id>/timelines/<timeline_id>/endpoints/<endpoint_id>/pgdata/<path>
+Host: epufs.svc.neon.local
+
+<<<
+
+200 OK
+{
+    "version": "<version>", # opaque file version token, changes when the file contents change
+    "size": <size in bytes>,
+}
+```
+
+```mermaid
+sequenceDiagram
+    autonumber
+    participant co as Compute
+    participant ep as EPUFS
+    participant s3 as Blob Storage
+
+    co-->ep: Connect with credentials
+    co->>+ep: Store Unlogged Persistent File
+    opt is authenticated
+        ep->>s3: Write UPF to S3
+    end
+    ep->>-co: OK / Failure / Auth Failure
+    co-->ep: Cancel connection
+```
+
+Read data: (optionally with cache-relevant request parameters, e.g. If-Modified-Since)
+```http
+GET /tenants/<tenant_id>/timelines/<timeline_id>/endpoints/<endpoint_id>/pgdata/<path>
+Host: epufs.svc.neon.local
+
+<<<
+
+200 OK
+
+<file contents>
+```
+
+```mermaid
+sequenceDiagram
+    autonumber
+    participant co as Compute
+    participant ep as EPUFS
+    participant s3 as Blob Storage
+
+    co->>+ep: Read Unlogged Persistent File
+    opt is authenticated
+        ep->>+s3: Request UPF from storage
+        s3->>-ep: Receive UPF from storage
+    end
+    ep->>-co: OK(response) / Failure(storage, auth, ...)
+```
+
+Compute Startup:
+```mermaid
+sequenceDiagram
+    autonumber
+    participant co as Compute
+    participant ps as Pageserver
+    participant ep as EPUFS
+    participant es as Extension server
+
+    note over co: Bind endpoint ep-xxx
+    par Get basebackup
+        co->>+ps: Request basebackup @ LSN
+        ps-)ps: Construct basebackup
+        ps->>-co: Receive basebackup TAR @ LSN
+    and Get startup-critical Unlogged Persistent Files
+        co->>+ep: Get all UPFs of endpoint ep-xxx
+        ep-)ep: Retrieve and gather all UPFs
+        ep->>-co: TAR of UPFs
+    and Get startup-critical extensions
+        loop For every startup-critical extension
+            co->>es: Get critical extension
+            es->>co: Receive critical extension
+        end
+    end
+    note over co: Start compute
+```
+
+CPlane ops:
+```http
+DELETE /tenants/<tenant_id>/timelines/<timeline_id>/endpoints/<endpoint_id>
+Host: epufs.svc.neon.local
+
+<<<
+
+200 OK
+{
+    "tenant": "<tenant_id>",
+    "timeline": "<timeline_id>",
+    "endpoint": "<endpoint_id>",
+    "deleted": {
+        "files": <file count>,
+        "bytes": <byte count>,
+    },
+}
+```
+
+```http
+DELETE /tenants/<tenant_id>/timelines/<timeline_id>
+Host: epufs.svc.neon.local
+
+<<<
+
+200 OK
+{
+    "tenant": "<tenant_id>",
+    "timeline": "<timeline_id>",
+    "deleted": {
+        "files": <file count>,
+        "bytes": <byte count>,
+    },
+}
+```
+
+```http
+DELETE /tenants/<tenant_id>
+Host: epufs.svc.neon.local
+
+<<<
+
+200 OK
+{
+    "tenant": "<tenant_id>",
+    "deleted": {
+        "files": <file count>,
+        "bytes": <byte count>,
+    },
+}
+```
+
+```mermaid
+sequenceDiagram
+    autonumber
+    participant cp as Control Plane
+    participant ep as EPUFS
+    participant s3 as Blob Storage
+
+    alt Tenant deleted
+        cp-)ep: Tenant deleted
+        loop For every object associated with removed tenant
+            ep->>s3: Remove data of deleted tenant from Storage
+        end
+        opt
+            ep-)cp: Tenant cleanup complete
+        end
+    else Timeline deleted
+        cp-)ep: Timeline deleted
+        loop For every object associated with removed timeline
+            ep->>s3: Remove data of deleted timeline from Storage
+        end
+        opt
+            ep-)cp: Timeline cleanup complete
+        end
+    else Endpoint reassigned or removed
+        cp->>+ep: Endpoint reassigned
+        loop For every object associated with reassigned/removed endpoint
+            ep->>s3: Remove data from Storage
+        end
+        ep->>-cp: Cleanup complete
+    end
+```
+
+### Scalability (if relevant)
+
+Provisionally: as this service is going to be part of compute startup, it
+should be able to respond quickly to all requests. Therefore this service is
+deployed to every AZ we host Computes in, and Computes communicate (generally)
+only with the EPUFS endpoint of the AZ they're hosted in.
+
+Local caching of frequently restarted endpoints' data or metadata may be
+needed for best performance. However, due to the regional nature of stored
+data but the zonal nature of the service deployment, we should be careful when
+we implement any local caching, as it is possible that computes in AZ 1 will
+update data originally written and thus cached by AZ 2. Cache version tests
+and invalidation are therefore required if we want to roll out caching to this
+service, which is too broad a scope for an MVC. This is why caching is left
+out of scope for this RFC, and should be considered separately after this RFC
+is implemented.
+
+### Security implications (if relevant)
+This service must be able to authenticate users at least by Tenant ID,
+Timeline ID and Endpoint ID. This will use the existing JWT infrastructure of
+Compute, which will be upgraded to the extent needed to support Timeline- and
+Endpoint-based claims.
+
+The service requires unlimited access to (a prefix of) a blob storage bucket,
+and thus must be hosted outside the Compute VM sandbox.
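+
+As an illustration, a minimal sketch of the claim check this implies (the
+claim field names are hypothetical; the real claim set will come from the
+existing Compute JWT infrastructure):
+
+```rust
+use serde::Deserialize;
+
+/// Hypothetical claim set carried in a compute's JWT.
+#[derive(Deserialize)]
+struct EpufsClaims {
+    tenant_id: String,
+    timeline_id: String,
+    endpoint_id: String,
+}
+
+/// Allow a request only when every path component matches the corresponding
+/// token claim; anything else is rejected before touching blob storage.
+fn is_authorized(claims: &EpufsClaims, tenant: &str, timeline: &str, endpoint: &str) -> bool {
+    claims.tenant_id == tenant
+        && claims.timeline_id == timeline
+        && claims.endpoint_id == endpoint
+}
+```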
+
+A service that generates pre-signed request URLs for Compute to download the
+data from that URL is likely problematic, too: Compute would be able to write
+unlimited data to the bucket, or exfiltrate this signed URL to get read/write
+access to specific objects in this bucket, which would still effectively give
+users access to the S3 bucket (but with improved access logging).
+
+There may be a use case for transferring data associated with one endpoint to
+another endpoint (e.g. to make one endpoint warm its caches with the state of
+another endpoint), but that's not currently in scope, and specific needs may
+be solved through out-of-band communication of data or pre-signed URLs.
+
+### Unresolved questions (if relevant)
+Caching of files is not in the implementation scope of this document, but
+should at some future point be considered to maximize performance.
+
+## Alternative implementation (if relevant)
+Several ideas have come up to solve this issue:
+
+### Use AUXfile
+One prevalent idea was to WAL-log the files using our AUXfile mechanism.
+
+Benefits:
+
++ We already have this storage mechanism
+
+Demerits:
+
+- It isn't available on read replicas
+- Additional WAL will be consumed during shutdown and after the shutdown
+  checkpoint, which needs PG modifications to work without panics.
+- It increases the data we need to manage in our versioned storage, thus
+  causing higher storage costs with higher retention due to duplication at
+  the storage layer.
+
+### Sign URLs for read/write operations, instead of proxying them
+
+Benefits:
+
++ The service can be implemented with a much reduced IO budget
+
+Demerits:
+
+- Users could get access to these signed credentials
+- Not all blob storage services may implement URL signing
+
+### Give endpoints each their own directly accessed block volume
+
+Benefits:
+
++ Easier to integrate for PostgreSQL
+
+Demerits:
+
+- Little control over data size and contents
+- Potentially problematic, as we'd need to store data all across the pgdata
+  directory.
+- EBS is not a good candidate
+  - Attaches in 10s of seconds, if not more; i.e. too cold to start
+  - Shared EBS volumes are a no-go, as you'd have to schedule the endpoint
+    with users of the same EBS volumes, which can't work with VM migration
+  - EBS storage costs are very high (>80$/kilotenant when using a
+    volume/tenant)
+  - EBS volumes can't be mounted across AZ boundaries
+- Bucket per endpoint is unfeasible
+  - S3 buckets are priced at $20/month per 1k, which we could better spend
+    on developers.
+  - Allocating service accounts takes time (100s of ms), and service accounts
+    are a limited resource, too; so they're not a good candidate to allocate
+    on a per-endpoint basis.
+  - Giving credentials limited to a prefix has similar issues as the
+    pre-signed URL approach.
+  - Bucket DNS lookups will fill DNS caches and put pressure on DNS lookup
+    much more than our current systems would.
+- Volumes bound by hypervisor are unlikely
+  - This requires significant investment and increased software on the
+    hypervisor.
+  - It is unclear if we can attach volumes after boot, i.e. for pooled
+    instances.
+
+### Put the files into a table
+
+Benefits:
+
+ + Mostly already available in PostgreSQL
+
+Demerits:
+
+ - Uses WAL
+ - Can't be used after shutdown checkpoint
+ - Needs a RW endpoint, and table & catalog access to write to this data
+ - Gets hit with DB size limitations
+ - Depending on user access:
+   - Inaccessible:
+     The user doesn't have control over database size caused by
+     these systems.
+ - Accessible: + The user can corrupt these files and cause the system to crash while + user-corrupted files are present, thus increasing on-call overhead. + +## Definition of Done (if relevant) + +This project is done if we have: + +- One S3 bucket equivalent per region, which stores this per-endpoint data. +- A new service endpoint in at least every AZ, which indirectly grants + endpoints access to the data stored for these endpoints in these buckets. +- Compute writes & reads temp-data at shutdown and startup, respectively, for + at least the pg_prewarm or lfc_prewarm state files. +- Cleanup of endpoint data is triggered when the endpoint is deleted or is + detached from its current timeline. diff --git a/endpoint_storage/Cargo.toml b/endpoint_storage/Cargo.toml index b2c9d51551..c2e21d02e2 100644 --- a/endpoint_storage/Cargo.toml +++ b/endpoint_storage/Cargo.toml @@ -8,6 +8,7 @@ anyhow.workspace = true axum-extra.workspace = true axum.workspace = true camino.workspace = true +clap.workspace = true futures.workspace = true jsonwebtoken.workspace = true prometheus.workspace = true diff --git a/endpoint_storage/src/app.rs b/endpoint_storage/src/app.rs index f44efe6d7a..42431c0066 100644 --- a/endpoint_storage/src/app.rs +++ b/endpoint_storage/src/app.rs @@ -374,7 +374,7 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH let request = Request::builder() .uri(format!("/{tenant}/{timeline}/{endpoint}/sub/path/key")) .method(method) - .header("Authorization", format!("Bearer {}", token)) + .header("Authorization", format!("Bearer {token}")) .body(Body::empty()) .unwrap(); let status = ServiceExt::ready(&mut app) diff --git a/endpoint_storage/src/main.rs b/endpoint_storage/src/main.rs index 3d1f05575d..c96cef2083 100644 --- a/endpoint_storage/src/main.rs +++ b/endpoint_storage/src/main.rs @@ -4,6 +4,8 @@ //! for large computes. 
mod app; use anyhow::Context; +use clap::Parser; +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use tracing::info; use utils::logging; @@ -12,13 +14,29 @@ const fn max_upload_file_limit() -> usize { 100 * 1024 * 1024 } +const fn listen() -> SocketAddr { + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 51243) +} + +#[derive(Parser)] +struct Args { + #[arg(exclusive = true)] + config_file: Option, + #[arg(long, default_value = "false", requires = "config")] + /// to allow testing k8s helm chart where we don't have s3 credentials + no_s3_check_on_startup: bool, + #[arg(long, value_name = "FILE")] + /// inline config mode for k8s helm chart + config: Option, +} + #[derive(serde::Deserialize)] -#[serde(tag = "type")] struct Config { + #[serde(default = "listen")] listen: std::net::SocketAddr, pemfile: camino::Utf8PathBuf, #[serde(flatten)] - storage_config: remote_storage::RemoteStorageConfig, + storage_kind: remote_storage::TypedRemoteStorageKind, #[serde(default = "max_upload_file_limit")] max_upload_file_limit: usize, } @@ -31,13 +49,18 @@ async fn main() -> anyhow::Result<()> { logging::Output::Stdout, )?; - let config: String = std::env::args().skip(1).take(1).collect(); - if config.is_empty() { - anyhow::bail!("Usage: endpoint_storage config.json") - } - info!("Reading config from {config}"); - let config = std::fs::read_to_string(config.clone())?; - let config: Config = serde_json::from_str(&config).context("parsing config")?; + let args = Args::parse(); + let config: Config = if let Some(config_path) = args.config_file { + info!("Reading config from {config_path}"); + let config = std::fs::read_to_string(config_path)?; + serde_json::from_str(&config).context("parsing config")? + } else if let Some(config) = args.config { + info!("Reading inline config"); + serde_json::from_str(&config).context("parsing config")? + } else { + anyhow::bail!("Supply either config file path or --config=inline-config"); + }; + info!("Reading pemfile from {}", config.pemfile.clone()); let pemfile = std::fs::read(config.pemfile.clone())?; info!("Loading public key from {}", config.pemfile.clone()); @@ -46,9 +69,12 @@ async fn main() -> anyhow::Result<()> { let listener = tokio::net::TcpListener::bind(config.listen).await.unwrap(); info!("listening on {}", listener.local_addr().unwrap()); - let storage = remote_storage::GenericRemoteStorage::from_config(&config.storage_config).await?; + let storage = + remote_storage::GenericRemoteStorage::from_storage_kind(config.storage_kind).await?; let cancel = tokio_util::sync::CancellationToken::new(); - app::check_storage_permissions(&storage, cancel.clone()).await?; + if !args.no_s3_check_on_startup { + app::check_storage_permissions(&storage, cancel.clone()).await?; + } let proxy = std::sync::Arc::new(endpoint_storage::Storage { auth, diff --git a/libs/compute_api/src/requests.rs b/libs/compute_api/src/requests.rs index bbab271474..745c44c05b 100644 --- a/libs/compute_api/src/requests.rs +++ b/libs/compute_api/src/requests.rs @@ -16,6 +16,7 @@ pub static COMPUTE_AUDIENCE: &str = "compute"; pub enum ComputeClaimsScope { /// An admin-scoped token allows access to all of `compute_ctl`'s authorized /// facilities. 
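+    /// (Recognized in serialized form as the string `compute_ctl:admin`, both via serde and via `FromStr`; the `compute_request_scopes` test added below checks that the two parse paths agree.)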
+ #[serde(rename = "compute_ctl:admin")] Admin, } @@ -24,7 +25,7 @@ impl FromStr for ComputeClaimsScope { fn from_str(s: &str) -> Result { match s { - "admin" => Ok(ComputeClaimsScope::Admin), + "compute_ctl:admin" => Ok(ComputeClaimsScope::Admin), _ => Err(anyhow::anyhow!("invalid compute claims scope \"{s}\"")), } } @@ -80,3 +81,23 @@ pub struct SetRoleGrantsRequest { pub privileges: Vec, pub role: PgIdent, } + +#[cfg(test)] +mod test { + use std::str::FromStr; + + use crate::requests::ComputeClaimsScope; + + /// Confirm that whether we parse the scope by string or through serde, the + /// same values parse to the same enum variant. + #[test] + fn compute_request_scopes() { + const ADMIN_SCOPE: &str = "compute_ctl:admin"; + + let from_serde: ComputeClaimsScope = + serde_json::from_str(&format!("\"{ADMIN_SCOPE}\"")).unwrap(); + let from_str = ComputeClaimsScope::from_str(ADMIN_SCOPE).unwrap(); + + assert_eq!(from_serde, from_str); + } +} diff --git a/libs/compute_api/src/responses.rs b/libs/compute_api/src/responses.rs index 24d371c6eb..5cad849e3d 100644 --- a/libs/compute_api/src/responses.rs +++ b/libs/compute_api/src/responses.rs @@ -83,6 +83,16 @@ pub struct ComputeStatusResponse { pub error: Option, } +#[derive(Serialize, Clone, Copy, Debug, Deserialize, PartialEq, Eq, Default)] +#[serde(rename_all = "snake_case")] +pub enum TerminateMode { + #[default] + /// wait 30s till returning from /terminate to allow control plane to get the error + Fast, + /// return from /terminate immediately as soon as all components are terminated + Immediate, +} + #[derive(Serialize, Clone, Copy, Debug, Deserialize, PartialEq, Eq)] #[serde(rename_all = "snake_case")] pub enum ComputeStatus { @@ -103,11 +113,16 @@ pub enum ComputeStatus { // control-plane to terminate it. Failed, // Termination requested - TerminationPending, + TerminationPending { mode: TerminateMode }, // Terminated Postgres Terminated, } +#[derive(Deserialize, Serialize)] +pub struct TerminateResponse { + pub lsn: Option, +} + impl Display for ComputeStatus { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -117,7 +132,7 @@ impl Display for ComputeStatus { ComputeStatus::Running => f.write_str("running"), ComputeStatus::Configuration => f.write_str("configuration"), ComputeStatus::Failed => f.write_str("failed"), - ComputeStatus::TerminationPending => f.write_str("termination-pending"), + ComputeStatus::TerminationPending { .. } => f.write_str("termination-pending"), ComputeStatus::Terminated => f.write_str("terminated"), } } diff --git a/libs/desim/src/executor.rs b/libs/desim/src/executor.rs index df8b071c06..bdb9c6cd4b 100644 --- a/libs/desim/src/executor.rs +++ b/libs/desim/src/executor.rs @@ -71,7 +71,7 @@ impl Runtime { debug!("thread panicked: {:?}", e); let mut result = ctx.result.lock(); if result.0 == -1 { - *result = (256, format!("thread panicked: {:?}", e)); + *result = (256, format!("thread panicked: {e:?}")); } }); } @@ -419,13 +419,13 @@ pub fn now() -> u64 { with_thread_context(|ctx| ctx.clock.get().unwrap().now()) } -pub fn exit(code: i32, msg: String) { +pub fn exit(code: i32, msg: String) -> ! 
{ with_thread_context(|ctx| { ctx.allow_panic.store(true, Ordering::SeqCst); let mut result = ctx.result.lock(); *result = (code, msg); panic!("exit"); - }); + }) } pub(crate) fn get_thread_ctx() -> Arc { diff --git a/libs/desim/src/proto.rs b/libs/desim/src/proto.rs index 31bc29e6a6..7c3de4ff4b 100644 --- a/libs/desim/src/proto.rs +++ b/libs/desim/src/proto.rs @@ -47,8 +47,8 @@ impl Debug for AnyMessage { match self { AnyMessage::None => write!(f, "None"), AnyMessage::InternalConnect => write!(f, "InternalConnect"), - AnyMessage::Just32(v) => write!(f, "Just32({})", v), - AnyMessage::ReplCell(v) => write!(f, "ReplCell({:?})", v), + AnyMessage::Just32(v) => write!(f, "Just32({v})"), + AnyMessage::ReplCell(v) => write!(f, "ReplCell({v:?})"), AnyMessage::Bytes(v) => write!(f, "Bytes({})", hex::encode(v)), AnyMessage::LSN(v) => write!(f, "LSN({})", Lsn(*v)), } diff --git a/libs/http-utils/src/endpoint.rs b/libs/http-utils/src/endpoint.rs index 64147f2dd0..f32ced1180 100644 --- a/libs/http-utils/src/endpoint.rs +++ b/libs/http-utils/src/endpoint.rs @@ -582,14 +582,14 @@ pub fn attach_openapi_ui( deepLinking: true, showExtensions: true, showCommonExtensions: true, - url: "{}", + url: "{spec_mount_path}", }}) window.ui = ui; }}; - "#, spec_mount_path))).unwrap()) + "#))).unwrap()) }) ) } @@ -696,7 +696,7 @@ mod tests { let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80); let mut service = builder.build(remote_addr); if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await { - panic!("request service is not ready: {:?}", e); + panic!("request service is not ready: {e:?}"); } let mut req: Request = Request::default(); @@ -716,7 +716,7 @@ mod tests { let remote_addr = SocketAddr::new(IpAddr::from_str("127.0.0.1").unwrap(), 80); let mut service = builder.build(remote_addr); if let Err(e) = poll_fn(|ctx| service.poll_ready(ctx)).await { - panic!("request service is not ready: {:?}", e); + panic!("request service is not ready: {e:?}"); } let req: Request = Request::default(); diff --git a/libs/neon-shmem/src/lib.rs b/libs/neon-shmem/src/lib.rs index e1b14b1371..c689959b68 100644 --- a/libs/neon-shmem/src/lib.rs +++ b/libs/neon-shmem/src/lib.rs @@ -86,7 +86,7 @@ impl ShmemHandle { // somewhat smaller than that, because with anything close to that, you'll run out of // memory anyway. 
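// (For reference, 1 << 48 bytes is 256 TiB of virtual address space.)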
if max_size >= 1 << 48 { - panic!("max size {} too large", max_size); + panic!("max size {max_size} too large"); } if initial_size > max_size { panic!("initial size {initial_size} larger than max size {max_size}"); @@ -279,7 +279,7 @@ mod tests { fn assert_range(ptr: *const u8, expected: u8, range: Range) { for i in range { let b = unsafe { *(ptr.add(i)) }; - assert_eq!(expected, b, "unexpected byte at offset {}", i); + assert_eq!(expected, b, "unexpected byte at offset {i}"); } } diff --git a/libs/pageserver_api/Cargo.toml b/libs/pageserver_api/Cargo.toml index 25f29b8ecd..a34e065788 100644 --- a/libs/pageserver_api/Cargo.toml +++ b/libs/pageserver_api/Cargo.toml @@ -17,7 +17,8 @@ anyhow.workspace = true bytes.workspace = true byteorder.workspace = true utils.workspace = true -postgres_ffi.workspace = true +postgres_ffi_types.workspace = true +postgres_versioninfo.workspace = true enum-map.workspace = true strum.workspace = true strum_macros.workspace = true diff --git a/libs/pageserver_api/src/config.rs b/libs/pageserver_api/src/config.rs index 30b0612082..0cfa1c8485 100644 --- a/libs/pageserver_api/src/config.rs +++ b/libs/pageserver_api/src/config.rs @@ -12,6 +12,7 @@ pub const DEFAULT_HTTP_LISTEN_ADDR: &str = formatcp!("127.0.0.1:{DEFAULT_HTTP_LI pub const DEFAULT_GRPC_LISTEN_PORT: u16 = 51051; // storage-broker already uses 50051 use std::collections::HashMap; +use std::fmt::Display; use std::num::{NonZeroU64, NonZeroUsize}; use std::str::FromStr; use std::time::Duration; @@ -24,16 +25,17 @@ use utils::logging::LogFormat; use crate::models::{ImageCompressionAlgorithm, LsnLease}; // Certain metadata (e.g. externally-addressable name, AZ) is delivered -// as a separate structure. This information is not neeed by the pageserver +// as a separate structure. This information is not needed by the pageserver // itself, it is only used for registering the pageserver with the control // plane and/or storage controller. -// #[derive(PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)] pub struct NodeMetadata { #[serde(rename = "host")] pub postgres_host: String, #[serde(rename = "port")] pub postgres_port: u16, + pub grpc_host: Option, + pub grpc_port: Option, pub http_host: String, pub http_port: u16, pub https_port: Option, @@ -44,7 +46,25 @@ pub struct NodeMetadata { pub other: HashMap, } -/// PostHog integration config. +impl Display for NodeMetadata { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "postgresql://{}:{} ", + self.postgres_host, self.postgres_port + )?; + if let Some(grpc_host) = &self.grpc_host { + let grpc_port = self.grpc_port.unwrap_or_default(); + write!(f, "grpc://{grpc_host}:{grpc_port} ")?; + } + write!(f, "http://{}:{} ", self.http_host, self.http_port)?; + write!(f, "other:{:?}", self.other)?; + Ok(()) + } +} + +/// PostHog integration config. This is used in pageserver, storcon, and neon_local. +/// Ensure backward compatibility when adding new fields. #[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub struct PostHogConfig { /// PostHog project ID @@ -57,6 +77,13 @@ pub struct PostHogConfig { pub private_api_url: String, /// Public API URL pub public_api_url: String, + /// Refresh interval for the feature flag spec. + /// The storcon will push the feature flag spec to the pageserver. If the pageserver does not receive + /// the spec for `refresh_interval`, it will fetch the spec from the PostHog API. 
+ #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(with = "humantime_serde")] + pub refresh_interval: Option, } /// `pageserver.toml` @@ -337,17 +364,26 @@ pub struct TimelineImportConfig { pub struct BasebackupCacheConfig { #[serde(with = "humantime_serde")] pub cleanup_period: Duration, - // FIXME: Support max_size_bytes. - // pub max_size_bytes: usize, - pub max_size_entries: i64, + /// Maximum total size of basebackup cache entries on disk in bytes. + /// The cache may slightly exceed this limit because we do not know + /// the exact size of the cache entry until it's written to disk. + pub max_total_size_bytes: u64, + // TODO(diko): support max_entry_size_bytes. + // pub max_entry_size_bytes: u64, + pub max_size_entries: usize, + /// Size of the channel used to send prepare requests to the basebackup cache worker. + /// If exceeded, new prepare requests will be dropped. + pub prepare_channel_size: usize, } impl Default for BasebackupCacheConfig { fn default() -> Self { Self { cleanup_period: Duration::from_secs(60), - // max_size_bytes: 1024 * 1024 * 1024, // 1 GiB + max_total_size_bytes: 1024 * 1024 * 1024, // 1 GiB + // max_entry_size_bytes: 16 * 1024 * 1024, // 16 MiB max_size_entries: 1000, + prepare_channel_size: 100, } } } @@ -792,7 +828,7 @@ pub mod tenant_conf_defaults { // By default ingest enough WAL for two new L0 layers before checking if new image // image layers should be created. pub const DEFAULT_IMAGE_LAYER_CREATION_CHECK_THRESHOLD: u8 = 2; - pub const DEFAULT_GC_COMPACTION_ENABLED: bool = false; + pub const DEFAULT_GC_COMPACTION_ENABLED: bool = true; pub const DEFAULT_GC_COMPACTION_VERIFICATION: bool = true; pub const DEFAULT_GC_COMPACTION_INITIAL_THRESHOLD_KB: u64 = 5 * 1024 * 1024; // 5GB pub const DEFAULT_GC_COMPACTION_RATIO_PERCENT: u64 = 100; diff --git a/libs/pageserver_api/src/config/tests.rs b/libs/pageserver_api/src/config/tests.rs index 9e61873273..7137df969a 100644 --- a/libs/pageserver_api/src/config/tests.rs +++ b/libs/pageserver_api/src/config/tests.rs @@ -14,6 +14,8 @@ fn test_node_metadata_v1_backward_compatibilty() { NodeMetadata { postgres_host: "localhost".to_string(), postgres_port: 23, + grpc_host: None, + grpc_port: None, http_host: "localhost".to_string(), http_port: 42, https_port: None, @@ -37,6 +39,35 @@ fn test_node_metadata_v2_backward_compatibilty() { NodeMetadata { postgres_host: "localhost".to_string(), postgres_port: 23, + grpc_host: None, + grpc_port: None, + http_host: "localhost".to_string(), + http_port: 42, + https_port: Some(123), + other: HashMap::new(), + } + ) +} + +#[test] +fn test_node_metadata_v3_backward_compatibilty() { + let v3 = serde_json::to_vec(&serde_json::json!({ + "host": "localhost", + "port": 23, + "grpc_host": "localhost", + "grpc_port": 51, + "http_host": "localhost", + "http_port": 42, + "https_port": 123, + })); + + assert_eq!( + serde_json::from_slice::(&v3.unwrap()).unwrap(), + NodeMetadata { + postgres_host: "localhost".to_string(), + postgres_port: 23, + grpc_host: Some("localhost".to_string()), + grpc_port: Some(51), http_host: "localhost".to_string(), http_port: 42, https_port: Some(123), diff --git a/libs/pageserver_api/src/controller_api.rs b/libs/pageserver_api/src/controller_api.rs index c5b49edba0..ff18d40bfe 100644 --- a/libs/pageserver_api/src/controller_api.rs +++ b/libs/pageserver_api/src/controller_api.rs @@ -52,6 +52,8 @@ pub struct NodeRegisterRequest { pub listen_pg_addr: String, pub listen_pg_port: u16, + pub listen_grpc_addr: Option, + pub
listen_grpc_port: Option, pub listen_http_addr: String, pub listen_http_port: u16, @@ -101,6 +103,8 @@ pub struct TenantLocateResponseShard { pub listen_pg_addr: String, pub listen_pg_port: u16, + pub listen_grpc_addr: Option, + pub listen_grpc_port: Option, pub listen_http_addr: String, pub listen_http_port: u16, @@ -152,6 +156,8 @@ pub struct NodeDescribeResponse { pub listen_pg_addr: String, pub listen_pg_port: u16, + pub listen_grpc_addr: Option, + pub listen_grpc_port: Option, } #[derive(Serialize, Deserialize, Debug)] @@ -344,6 +350,35 @@ impl Default for ShardSchedulingPolicy { } } +#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug)] +pub enum NodeLifecycle { + Active, + Deleted, +} + +impl FromStr for NodeLifecycle { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + match s { + "active" => Ok(Self::Active), + "deleted" => Ok(Self::Deleted), + _ => Err(anyhow::anyhow!("Unknown node lifecycle '{s}'")), + } + } +} + +impl From for String { + fn from(value: NodeLifecycle) -> String { + use NodeLifecycle::*; + match value { + Active => "active", + Deleted => "deleted", + } + .to_string() + } +} + #[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug)] pub enum NodeSchedulingPolicy { Active, @@ -542,8 +577,7 @@ mod test { let err = serde_json::from_value::(create_request).unwrap_err(); assert!( err.to_string().contains("unknown field `unknown_field`"), - "expect unknown field `unknown_field` error, got: {}", - err + "expect unknown field `unknown_field` error, got: {err}" ); } diff --git a/libs/pageserver_api/src/key.rs b/libs/pageserver_api/src/key.rs index c14975167b..102bbee879 100644 --- a/libs/pageserver_api/src/key.rs +++ b/libs/pageserver_api/src/key.rs @@ -4,8 +4,8 @@ use std::ops::Range; use anyhow::{Result, bail}; use byteorder::{BE, ByteOrder}; use bytes::Bytes; -use postgres_ffi::relfile_utils::{FSM_FORKNUM, VISIBILITYMAP_FORKNUM}; -use postgres_ffi::{Oid, RepOriginId}; +use postgres_ffi_types::forknum::{FSM_FORKNUM, VISIBILITYMAP_FORKNUM}; +use postgres_ffi_types::{Oid, RepOriginId}; use serde::{Deserialize, Serialize}; use utils::const_assert; @@ -194,7 +194,7 @@ impl Key { /// will be rejected on the write path. #[allow(dead_code)] pub fn is_valid_key_on_write_path_strong(&self) -> bool { - use postgres_ffi::pg_constants::{DEFAULTTABLESPACE_OID, GLOBALTABLESPACE_OID}; + use postgres_ffi_types::constants::{DEFAULTTABLESPACE_OID, GLOBALTABLESPACE_OID}; if !self.is_i128_representable() { return false; } diff --git a/libs/pageserver_api/src/keyspace.rs b/libs/pageserver_api/src/keyspace.rs index 79e3ef553b..10a242e13b 100644 --- a/libs/pageserver_api/src/keyspace.rs +++ b/libs/pageserver_api/src/keyspace.rs @@ -1,7 +1,6 @@ use std::ops::Range; use itertools::Itertools; -use postgres_ffi::BLCKSZ; use crate::key::Key; use crate::shard::{ShardCount, ShardIdentity}; @@ -269,9 +268,13 @@ impl KeySpace { /// Partition a key space into roughly chunks of roughly 'target_size' bytes /// in each partition. /// - pub fn partition(&self, shard_identity: &ShardIdentity, target_size: u64) -> KeyPartitioning { - // Assume that each value is 8k in size. 
- let target_nblocks = (target_size / BLCKSZ as u64) as u32; + pub fn partition( + &self, + shard_identity: &ShardIdentity, + target_size: u64, + block_size: u64, + ) -> KeyPartitioning { + let target_nblocks = (target_size / block_size) as u32; let mut parts = Vec::new(); let mut current_part = Vec::new(); @@ -331,8 +334,7 @@ impl KeySpace { std::cmp::max(range.start, prev.start) < std::cmp::min(range.end, prev.end); assert!( !overlap, - "Attempt to merge ovelapping keyspaces: {:?} overlaps {:?}", - prev, range + "Attempt to merge overlapping keyspaces: {prev:?} overlaps {range:?}" ); } @@ -1101,7 +1103,7 @@ mod tests { // total range contains at least one shard-local page let all_nonzero = fragments.iter().all(|f| f.0 > 0); if !all_nonzero { - eprintln!("Found a zero-length fragment: {:?}", fragments); + eprintln!("Found a zero-length fragment: {fragments:?}"); } assert!(all_nonzero); } else { diff --git a/libs/pageserver_api/src/lib.rs b/libs/pageserver_api/src/lib.rs index ff705e79cd..52aed7a2c2 100644 --- a/libs/pageserver_api/src/lib.rs +++ b/libs/pageserver_api/src/lib.rs @@ -5,11 +5,10 @@ pub mod controller_api; pub mod key; pub mod keyspace; pub mod models; -pub mod record; +pub mod pagestream_api; pub mod reltag; pub mod shard; /// Public API types pub mod upcall_api; -pub mod value; pub mod config; diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index 881f24b86c..82a3ac0eb4 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -5,16 +5,13 @@ pub mod utilization; use core::ops::Range; use std::collections::HashMap; use std::fmt::Display; -use std::io::{BufRead, Read}; use std::num::{NonZeroU32, NonZeroU64, NonZeroUsize}; use std::str::FromStr; use std::time::{Duration, SystemTime}; -use byteorder::{BigEndian, ReadBytesExt}; -use bytes::{Buf, BufMut, Bytes, BytesMut}; #[cfg(feature = "testing")] use camino::Utf8PathBuf; -use postgres_ffi::BLCKSZ; +use postgres_versioninfo::PgMajorVersion; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use serde_with::serde_as; pub use utilization::PageserverUtilization; @@ -24,7 +21,6 @@ use utils::{completion, serde_system_time}; use crate::config::Ratio; use crate::key::{CompactKey, Key}; -use crate::reltag::RelTag; use crate::shard::{DEFAULT_STRIPE_SIZE, ShardCount, ShardStripeSize, TenantShardId}; /// The state of a tenant in this pageserver. @@ -403,7 +399,7 @@ pub enum TimelineCreateRequestMode { // inherits the ancestor's pg_version. Earlier code wasn't // using a flattened enum, so, it was an accepted field, and // we continue to accept it by having it here.
- pg_version: Option, + pg_version: Option, #[serde(default, skip_serializing_if = "std::ops::Not::not")] read_only: bool, }, @@ -415,7 +411,7 @@ pub enum TimelineCreateRequestMode { Bootstrap { #[serde(default)] existing_initdb_timeline_id: Option, - pg_version: Option, + pg_version: Option, }, } @@ -1187,7 +1183,7 @@ impl Display for ImageCompressionAlgorithm { ImageCompressionAlgorithm::Disabled => write!(f, "disabled"), ImageCompressionAlgorithm::Zstd { level } => { if let Some(level) = level { - write!(f, "zstd({})", level) + write!(f, "zstd({level})") } else { write!(f, "zstd") } @@ -1578,7 +1574,7 @@ pub struct TimelineInfo { pub last_received_msg_lsn: Option, /// the timestamp (in microseconds) of the last received message pub last_received_msg_ts: Option, - pub pg_version: u32, + pub pg_version: PgMajorVersion, pub state: TimelineState, @@ -1907,219 +1903,6 @@ pub struct ScanDisposableKeysResponse { pub not_disposable_count: usize, } -// Wrapped in libpq CopyData -#[derive(PartialEq, Eq, Debug)] -pub enum PagestreamFeMessage { - Exists(PagestreamExistsRequest), - Nblocks(PagestreamNblocksRequest), - GetPage(PagestreamGetPageRequest), - DbSize(PagestreamDbSizeRequest), - GetSlruSegment(PagestreamGetSlruSegmentRequest), - #[cfg(feature = "testing")] - Test(PagestreamTestRequest), -} - -// Wrapped in libpq CopyData -#[derive(Debug, strum_macros::EnumProperty)] -pub enum PagestreamBeMessage { - Exists(PagestreamExistsResponse), - Nblocks(PagestreamNblocksResponse), - GetPage(PagestreamGetPageResponse), - Error(PagestreamErrorResponse), - DbSize(PagestreamDbSizeResponse), - GetSlruSegment(PagestreamGetSlruSegmentResponse), - #[cfg(feature = "testing")] - Test(PagestreamTestResponse), -} - -// Keep in sync with `pagestore_client.h` -#[repr(u8)] -enum PagestreamFeMessageTag { - Exists = 0, - Nblocks = 1, - GetPage = 2, - DbSize = 3, - GetSlruSegment = 4, - /* future tags above this line */ - /// For testing purposes, not available in production. - #[cfg(feature = "testing")] - Test = 99, -} - -// Keep in sync with `pagestore_client.h` -#[repr(u8)] -enum PagestreamBeMessageTag { - Exists = 100, - Nblocks = 101, - GetPage = 102, - Error = 103, - DbSize = 104, - GetSlruSegment = 105, - /* future tags above this line */ - /// For testing purposes, not available in production. - #[cfg(feature = "testing")] - Test = 199, -} - -impl TryFrom for PagestreamFeMessageTag { - type Error = u8; - fn try_from(value: u8) -> Result { - match value { - 0 => Ok(PagestreamFeMessageTag::Exists), - 1 => Ok(PagestreamFeMessageTag::Nblocks), - 2 => Ok(PagestreamFeMessageTag::GetPage), - 3 => Ok(PagestreamFeMessageTag::DbSize), - 4 => Ok(PagestreamFeMessageTag::GetSlruSegment), - #[cfg(feature = "testing")] - 99 => Ok(PagestreamFeMessageTag::Test), - _ => Err(value), - } - } -} - -impl TryFrom for PagestreamBeMessageTag { - type Error = u8; - fn try_from(value: u8) -> Result { - match value { - 100 => Ok(PagestreamBeMessageTag::Exists), - 101 => Ok(PagestreamBeMessageTag::Nblocks), - 102 => Ok(PagestreamBeMessageTag::GetPage), - 103 => Ok(PagestreamBeMessageTag::Error), - 104 => Ok(PagestreamBeMessageTag::DbSize), - 105 => Ok(PagestreamBeMessageTag::GetSlruSegment), - #[cfg(feature = "testing")] - 199 => Ok(PagestreamBeMessageTag::Test), - _ => Err(value), - } - } -} - -// A GetPage request contains two LSN values: -// -// request_lsn: Get the page version at this point in time. Lsn::Max is a special value that means -// "get the latest version present". 
It's used by the primary server, which knows that no one else -// is writing WAL. 'not_modified_since' must be set to a proper value even if request_lsn is -// Lsn::Max. Standby servers use the current replay LSN as the request LSN. -// -// not_modified_since: Hint to the pageserver that the client knows that the page has not been -// modified between 'not_modified_since' and the request LSN. It's always correct to set -// 'not_modified_since equal' to 'request_lsn' (unless Lsn::Max is used as the 'request_lsn'), but -// passing an earlier LSN can speed up the request, by allowing the pageserver to process the -// request without waiting for 'request_lsn' to arrive. -// -// The now-defunct V1 interface contained only one LSN, and a boolean 'latest' flag. The V1 interface was -// sufficient for the primary; the 'lsn' was equivalent to the 'not_modified_since' value, and -// 'latest' was set to true. The V2 interface was added because there was no correct way for a -// standby to request a page at a particular non-latest LSN, and also include the -// 'not_modified_since' hint. That led to an awkward choice of either using an old LSN in the -// request, if the standby knows that the page hasn't been modified since, and risk getting an error -// if that LSN has fallen behind the GC horizon, or requesting the current replay LSN, which could -// require the pageserver unnecessarily to wait for the WAL to arrive up to that point. The new V2 -// interface allows sending both LSNs, and let the pageserver do the right thing. There was no -// difference in the responses between V1 and V2. -// -// V3 version of protocol adds request ID to all requests. This request ID is also included in response -// as well as other fields from requests, which allows to verify that we receive response for our request. -// We copy fields from request to response to make checking more reliable: request ID is formed from process ID -// and local counter, so in principle there can be duplicated requests IDs if process PID is reused. 
-// -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum PagestreamProtocolVersion { - V2, - V3, -} - -pub type RequestId = u64; - -#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] -pub struct PagestreamRequest { - pub reqid: RequestId, - pub request_lsn: Lsn, - pub not_modified_since: Lsn, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct PagestreamExistsRequest { - pub hdr: PagestreamRequest, - pub rel: RelTag, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct PagestreamNblocksRequest { - pub hdr: PagestreamRequest, - pub rel: RelTag, -} - -#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] -pub struct PagestreamGetPageRequest { - pub hdr: PagestreamRequest, - pub rel: RelTag, - pub blkno: u32, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct PagestreamDbSizeRequest { - pub hdr: PagestreamRequest, - pub dbnode: u32, -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct PagestreamGetSlruSegmentRequest { - pub hdr: PagestreamRequest, - pub kind: u8, - pub segno: u32, -} - -#[derive(Debug)] -pub struct PagestreamExistsResponse { - pub req: PagestreamExistsRequest, - pub exists: bool, -} - -#[derive(Debug)] -pub struct PagestreamNblocksResponse { - pub req: PagestreamNblocksRequest, - pub n_blocks: u32, -} - -#[derive(Debug)] -pub struct PagestreamGetPageResponse { - pub req: PagestreamGetPageRequest, - pub page: Bytes, -} - -#[derive(Debug)] -pub struct PagestreamGetSlruSegmentResponse { - pub req: PagestreamGetSlruSegmentRequest, - pub segment: Bytes, -} - -#[derive(Debug)] -pub struct PagestreamErrorResponse { - pub req: PagestreamRequest, - pub message: String, -} - -#[derive(Debug)] -pub struct PagestreamDbSizeResponse { - pub req: PagestreamDbSizeRequest, - pub db_size: i64, -} - -#[cfg(feature = "testing")] -#[derive(Debug, PartialEq, Eq, Clone)] -pub struct PagestreamTestRequest { - pub hdr: PagestreamRequest, - pub batch_key: u64, - pub message: String, -} - -#[cfg(feature = "testing")] -#[derive(Debug)] -pub struct PagestreamTestResponse { - pub req: PagestreamTestRequest, -} - // This is a cut-down version of TenantHistorySize from the pageserver crate, omitting fields // that require pageserver-internal types. It is sufficient to get the total size. #[derive(Serialize, Deserialize, Debug)] @@ -2131,506 +1914,6 @@ pub struct TenantHistorySize { pub size: Option, } -impl PagestreamFeMessage { - /// Serialize a compute -> pageserver message. This is currently only used in testing - /// tools. Always uses protocol version 3. 
- pub fn serialize(&self) -> Bytes { - let mut bytes = BytesMut::new(); - - match self { - Self::Exists(req) => { - bytes.put_u8(PagestreamFeMessageTag::Exists as u8); - bytes.put_u64(req.hdr.reqid); - bytes.put_u64(req.hdr.request_lsn.0); - bytes.put_u64(req.hdr.not_modified_since.0); - bytes.put_u32(req.rel.spcnode); - bytes.put_u32(req.rel.dbnode); - bytes.put_u32(req.rel.relnode); - bytes.put_u8(req.rel.forknum); - } - - Self::Nblocks(req) => { - bytes.put_u8(PagestreamFeMessageTag::Nblocks as u8); - bytes.put_u64(req.hdr.reqid); - bytes.put_u64(req.hdr.request_lsn.0); - bytes.put_u64(req.hdr.not_modified_since.0); - bytes.put_u32(req.rel.spcnode); - bytes.put_u32(req.rel.dbnode); - bytes.put_u32(req.rel.relnode); - bytes.put_u8(req.rel.forknum); - } - - Self::GetPage(req) => { - bytes.put_u8(PagestreamFeMessageTag::GetPage as u8); - bytes.put_u64(req.hdr.reqid); - bytes.put_u64(req.hdr.request_lsn.0); - bytes.put_u64(req.hdr.not_modified_since.0); - bytes.put_u32(req.rel.spcnode); - bytes.put_u32(req.rel.dbnode); - bytes.put_u32(req.rel.relnode); - bytes.put_u8(req.rel.forknum); - bytes.put_u32(req.blkno); - } - - Self::DbSize(req) => { - bytes.put_u8(PagestreamFeMessageTag::DbSize as u8); - bytes.put_u64(req.hdr.reqid); - bytes.put_u64(req.hdr.request_lsn.0); - bytes.put_u64(req.hdr.not_modified_since.0); - bytes.put_u32(req.dbnode); - } - - Self::GetSlruSegment(req) => { - bytes.put_u8(PagestreamFeMessageTag::GetSlruSegment as u8); - bytes.put_u64(req.hdr.reqid); - bytes.put_u64(req.hdr.request_lsn.0); - bytes.put_u64(req.hdr.not_modified_since.0); - bytes.put_u8(req.kind); - bytes.put_u32(req.segno); - } - #[cfg(feature = "testing")] - Self::Test(req) => { - bytes.put_u8(PagestreamFeMessageTag::Test as u8); - bytes.put_u64(req.hdr.reqid); - bytes.put_u64(req.hdr.request_lsn.0); - bytes.put_u64(req.hdr.not_modified_since.0); - bytes.put_u64(req.batch_key); - let message = req.message.as_bytes(); - bytes.put_u64(message.len() as u64); - bytes.put_slice(message); - } - } - - bytes.into() - } - - pub fn parse( - body: &mut R, - protocol_version: PagestreamProtocolVersion, - ) -> anyhow::Result { - // these correspond to the NeonMessageTag enum in pagestore_client.h - // - // TODO: consider using protobuf or serde bincode for less error prone - // serialization. - let msg_tag = body.read_u8()?; - let (reqid, request_lsn, not_modified_since) = match protocol_version { - PagestreamProtocolVersion::V2 => ( - 0, - Lsn::from(body.read_u64::()?), - Lsn::from(body.read_u64::()?), - ), - PagestreamProtocolVersion::V3 => ( - body.read_u64::()?, - Lsn::from(body.read_u64::()?), - Lsn::from(body.read_u64::()?), - ), - }; - - match PagestreamFeMessageTag::try_from(msg_tag) - .map_err(|tag: u8| anyhow::anyhow!("invalid tag {tag}"))? 
- { - PagestreamFeMessageTag::Exists => { - Ok(PagestreamFeMessage::Exists(PagestreamExistsRequest { - hdr: PagestreamRequest { - reqid, - request_lsn, - not_modified_since, - }, - rel: RelTag { - spcnode: body.read_u32::()?, - dbnode: body.read_u32::()?, - relnode: body.read_u32::()?, - forknum: body.read_u8()?, - }, - })) - } - PagestreamFeMessageTag::Nblocks => { - Ok(PagestreamFeMessage::Nblocks(PagestreamNblocksRequest { - hdr: PagestreamRequest { - reqid, - request_lsn, - not_modified_since, - }, - rel: RelTag { - spcnode: body.read_u32::()?, - dbnode: body.read_u32::()?, - relnode: body.read_u32::()?, - forknum: body.read_u8()?, - }, - })) - } - PagestreamFeMessageTag::GetPage => { - Ok(PagestreamFeMessage::GetPage(PagestreamGetPageRequest { - hdr: PagestreamRequest { - reqid, - request_lsn, - not_modified_since, - }, - rel: RelTag { - spcnode: body.read_u32::()?, - dbnode: body.read_u32::()?, - relnode: body.read_u32::()?, - forknum: body.read_u8()?, - }, - blkno: body.read_u32::()?, - })) - } - PagestreamFeMessageTag::DbSize => { - Ok(PagestreamFeMessage::DbSize(PagestreamDbSizeRequest { - hdr: PagestreamRequest { - reqid, - request_lsn, - not_modified_since, - }, - dbnode: body.read_u32::()?, - })) - } - PagestreamFeMessageTag::GetSlruSegment => Ok(PagestreamFeMessage::GetSlruSegment( - PagestreamGetSlruSegmentRequest { - hdr: PagestreamRequest { - reqid, - request_lsn, - not_modified_since, - }, - kind: body.read_u8()?, - segno: body.read_u32::()?, - }, - )), - #[cfg(feature = "testing")] - PagestreamFeMessageTag::Test => Ok(PagestreamFeMessage::Test(PagestreamTestRequest { - hdr: PagestreamRequest { - reqid, - request_lsn, - not_modified_since, - }, - batch_key: body.read_u64::()?, - message: { - let len = body.read_u64::()?; - let mut buf = vec![0; len as usize]; - body.read_exact(&mut buf)?; - String::from_utf8(buf)? 
- }, - })), - } - } -} - -impl PagestreamBeMessage { - pub fn serialize(&self, protocol_version: PagestreamProtocolVersion) -> Bytes { - let mut bytes = BytesMut::new(); - - use PagestreamBeMessageTag as Tag; - match protocol_version { - PagestreamProtocolVersion::V2 => { - match self { - Self::Exists(resp) => { - bytes.put_u8(Tag::Exists as u8); - bytes.put_u8(resp.exists as u8); - } - - Self::Nblocks(resp) => { - bytes.put_u8(Tag::Nblocks as u8); - bytes.put_u32(resp.n_blocks); - } - - Self::GetPage(resp) => { - bytes.put_u8(Tag::GetPage as u8); - bytes.put(&resp.page[..]) - } - - Self::Error(resp) => { - bytes.put_u8(Tag::Error as u8); - bytes.put(resp.message.as_bytes()); - bytes.put_u8(0); // null terminator - } - Self::DbSize(resp) => { - bytes.put_u8(Tag::DbSize as u8); - bytes.put_i64(resp.db_size); - } - - Self::GetSlruSegment(resp) => { - bytes.put_u8(Tag::GetSlruSegment as u8); - bytes.put_u32((resp.segment.len() / BLCKSZ as usize) as u32); - bytes.put(&resp.segment[..]); - } - - #[cfg(feature = "testing")] - Self::Test(resp) => { - bytes.put_u8(Tag::Test as u8); - bytes.put_u64(resp.req.batch_key); - let message = resp.req.message.as_bytes(); - bytes.put_u64(message.len() as u64); - bytes.put_slice(message); - } - } - } - PagestreamProtocolVersion::V3 => { - match self { - Self::Exists(resp) => { - bytes.put_u8(Tag::Exists as u8); - bytes.put_u64(resp.req.hdr.reqid); - bytes.put_u64(resp.req.hdr.request_lsn.0); - bytes.put_u64(resp.req.hdr.not_modified_since.0); - bytes.put_u32(resp.req.rel.spcnode); - bytes.put_u32(resp.req.rel.dbnode); - bytes.put_u32(resp.req.rel.relnode); - bytes.put_u8(resp.req.rel.forknum); - bytes.put_u8(resp.exists as u8); - } - - Self::Nblocks(resp) => { - bytes.put_u8(Tag::Nblocks as u8); - bytes.put_u64(resp.req.hdr.reqid); - bytes.put_u64(resp.req.hdr.request_lsn.0); - bytes.put_u64(resp.req.hdr.not_modified_since.0); - bytes.put_u32(resp.req.rel.spcnode); - bytes.put_u32(resp.req.rel.dbnode); - bytes.put_u32(resp.req.rel.relnode); - bytes.put_u8(resp.req.rel.forknum); - bytes.put_u32(resp.n_blocks); - } - - Self::GetPage(resp) => { - bytes.put_u8(Tag::GetPage as u8); - bytes.put_u64(resp.req.hdr.reqid); - bytes.put_u64(resp.req.hdr.request_lsn.0); - bytes.put_u64(resp.req.hdr.not_modified_since.0); - bytes.put_u32(resp.req.rel.spcnode); - bytes.put_u32(resp.req.rel.dbnode); - bytes.put_u32(resp.req.rel.relnode); - bytes.put_u8(resp.req.rel.forknum); - bytes.put_u32(resp.req.blkno); - bytes.put(&resp.page[..]) - } - - Self::Error(resp) => { - bytes.put_u8(Tag::Error as u8); - bytes.put_u64(resp.req.reqid); - bytes.put_u64(resp.req.request_lsn.0); - bytes.put_u64(resp.req.not_modified_since.0); - bytes.put(resp.message.as_bytes()); - bytes.put_u8(0); // null terminator - } - Self::DbSize(resp) => { - bytes.put_u8(Tag::DbSize as u8); - bytes.put_u64(resp.req.hdr.reqid); - bytes.put_u64(resp.req.hdr.request_lsn.0); - bytes.put_u64(resp.req.hdr.not_modified_since.0); - bytes.put_u32(resp.req.dbnode); - bytes.put_i64(resp.db_size); - } - - Self::GetSlruSegment(resp) => { - bytes.put_u8(Tag::GetSlruSegment as u8); - bytes.put_u64(resp.req.hdr.reqid); - bytes.put_u64(resp.req.hdr.request_lsn.0); - bytes.put_u64(resp.req.hdr.not_modified_since.0); - bytes.put_u8(resp.req.kind); - bytes.put_u32(resp.req.segno); - bytes.put_u32((resp.segment.len() / BLCKSZ as usize) as u32); - bytes.put(&resp.segment[..]); - } - - #[cfg(feature = "testing")] - Self::Test(resp) => { - bytes.put_u8(Tag::Test as u8); - bytes.put_u64(resp.req.hdr.reqid); - 
bytes.put_u64(resp.req.hdr.request_lsn.0); - bytes.put_u64(resp.req.hdr.not_modified_since.0); - bytes.put_u64(resp.req.batch_key); - let message = resp.req.message.as_bytes(); - bytes.put_u64(message.len() as u64); - bytes.put_slice(message); - } - } - } - } - bytes.into() - } - - pub fn deserialize(buf: Bytes) -> anyhow::Result { - let mut buf = buf.reader(); - let msg_tag = buf.read_u8()?; - - use PagestreamBeMessageTag as Tag; - let ok = - match Tag::try_from(msg_tag).map_err(|tag: u8| anyhow::anyhow!("invalid tag {tag}"))? { - Tag::Exists => { - let reqid = buf.read_u64::()?; - let request_lsn = Lsn(buf.read_u64::()?); - let not_modified_since = Lsn(buf.read_u64::()?); - let rel = RelTag { - spcnode: buf.read_u32::()?, - dbnode: buf.read_u32::()?, - relnode: buf.read_u32::()?, - forknum: buf.read_u8()?, - }; - let exists = buf.read_u8()? != 0; - Self::Exists(PagestreamExistsResponse { - req: PagestreamExistsRequest { - hdr: PagestreamRequest { - reqid, - request_lsn, - not_modified_since, - }, - rel, - }, - exists, - }) - } - Tag::Nblocks => { - let reqid = buf.read_u64::()?; - let request_lsn = Lsn(buf.read_u64::()?); - let not_modified_since = Lsn(buf.read_u64::()?); - let rel = RelTag { - spcnode: buf.read_u32::()?, - dbnode: buf.read_u32::()?, - relnode: buf.read_u32::()?, - forknum: buf.read_u8()?, - }; - let n_blocks = buf.read_u32::()?; - Self::Nblocks(PagestreamNblocksResponse { - req: PagestreamNblocksRequest { - hdr: PagestreamRequest { - reqid, - request_lsn, - not_modified_since, - }, - rel, - }, - n_blocks, - }) - } - Tag::GetPage => { - let reqid = buf.read_u64::()?; - let request_lsn = Lsn(buf.read_u64::()?); - let not_modified_since = Lsn(buf.read_u64::()?); - let rel = RelTag { - spcnode: buf.read_u32::()?, - dbnode: buf.read_u32::()?, - relnode: buf.read_u32::()?, - forknum: buf.read_u8()?, - }; - let blkno = buf.read_u32::()?; - let mut page = vec![0; 8192]; // TODO: use MaybeUninit - buf.read_exact(&mut page)?; - Self::GetPage(PagestreamGetPageResponse { - req: PagestreamGetPageRequest { - hdr: PagestreamRequest { - reqid, - request_lsn, - not_modified_since, - }, - rel, - blkno, - }, - page: page.into(), - }) - } - Tag::Error => { - let reqid = buf.read_u64::()?; - let request_lsn = Lsn(buf.read_u64::()?); - let not_modified_since = Lsn(buf.read_u64::()?); - let mut msg = Vec::new(); - buf.read_until(0, &mut msg)?; - let cstring = std::ffi::CString::from_vec_with_nul(msg)?; - let rust_str = cstring.to_str()?; - Self::Error(PagestreamErrorResponse { - req: PagestreamRequest { - reqid, - request_lsn, - not_modified_since, - }, - message: rust_str.to_owned(), - }) - } - Tag::DbSize => { - let reqid = buf.read_u64::()?; - let request_lsn = Lsn(buf.read_u64::()?); - let not_modified_since = Lsn(buf.read_u64::()?); - let dbnode = buf.read_u32::()?; - let db_size = buf.read_i64::()?; - Self::DbSize(PagestreamDbSizeResponse { - req: PagestreamDbSizeRequest { - hdr: PagestreamRequest { - reqid, - request_lsn, - not_modified_since, - }, - dbnode, - }, - db_size, - }) - } - Tag::GetSlruSegment => { - let reqid = buf.read_u64::()?; - let request_lsn = Lsn(buf.read_u64::()?); - let not_modified_since = Lsn(buf.read_u64::()?); - let kind = buf.read_u8()?; - let segno = buf.read_u32::()?; - let n_blocks = buf.read_u32::()?; - let mut segment = vec![0; n_blocks as usize * BLCKSZ as usize]; - buf.read_exact(&mut segment)?; - Self::GetSlruSegment(PagestreamGetSlruSegmentResponse { - req: PagestreamGetSlruSegmentRequest { - hdr: PagestreamRequest { - reqid, - request_lsn, - 
not_modified_since, - }, - kind, - segno, - }, - segment: segment.into(), - }) - } - #[cfg(feature = "testing")] - Tag::Test => { - let reqid = buf.read_u64::()?; - let request_lsn = Lsn(buf.read_u64::()?); - let not_modified_since = Lsn(buf.read_u64::()?); - let batch_key = buf.read_u64::()?; - let len = buf.read_u64::()?; - let mut msg = vec![0; len as usize]; - buf.read_exact(&mut msg)?; - let message = String::from_utf8(msg)?; - Self::Test(PagestreamTestResponse { - req: PagestreamTestRequest { - hdr: PagestreamRequest { - reqid, - request_lsn, - not_modified_since, - }, - batch_key, - message, - }, - }) - } - }; - let remaining = buf.into_inner(); - if !remaining.is_empty() { - anyhow::bail!( - "remaining bytes in msg with tag={msg_tag}: {}", - remaining.len() - ); - } - Ok(ok) - } - - pub fn kind(&self) -> &'static str { - match self { - Self::Exists(_) => "Exists", - Self::Nblocks(_) => "Nblocks", - Self::GetPage(_) => "GetPage", - Self::Error(_) => "Error", - Self::DbSize(_) => "DbSize", - Self::GetSlruSegment(_) => "GetSlruSegment", - #[cfg(feature = "testing")] - Self::Test(_) => "Test", - } - } -} - #[derive(Debug, Serialize, Deserialize)] pub struct PageTraceEvent { pub key: CompactKey, @@ -2656,68 +1939,6 @@ mod tests { use super::*; - #[test] - fn test_pagestream() { - // Test serialization/deserialization of PagestreamFeMessage - let messages = vec![ - PagestreamFeMessage::Exists(PagestreamExistsRequest { - hdr: PagestreamRequest { - reqid: 0, - request_lsn: Lsn(4), - not_modified_since: Lsn(3), - }, - rel: RelTag { - forknum: 1, - spcnode: 2, - dbnode: 3, - relnode: 4, - }, - }), - PagestreamFeMessage::Nblocks(PagestreamNblocksRequest { - hdr: PagestreamRequest { - reqid: 0, - request_lsn: Lsn(4), - not_modified_since: Lsn(4), - }, - rel: RelTag { - forknum: 1, - spcnode: 2, - dbnode: 3, - relnode: 4, - }, - }), - PagestreamFeMessage::GetPage(PagestreamGetPageRequest { - hdr: PagestreamRequest { - reqid: 0, - request_lsn: Lsn(4), - not_modified_since: Lsn(3), - }, - rel: RelTag { - forknum: 1, - spcnode: 2, - dbnode: 3, - relnode: 4, - }, - blkno: 7, - }), - PagestreamFeMessage::DbSize(PagestreamDbSizeRequest { - hdr: PagestreamRequest { - reqid: 0, - request_lsn: Lsn(4), - not_modified_since: Lsn(3), - }, - dbnode: 7, - }), - ]; - for msg in messages { - let bytes = msg.serialize(); - let reconstructed = - PagestreamFeMessage::parse(&mut bytes.reader(), PagestreamProtocolVersion::V3) - .unwrap(); - assert!(msg == reconstructed); - } - } - #[test] fn test_tenantinfo_serde() { // Test serialization/deserialization of TenantInfo @@ -2791,8 +2012,7 @@ mod tests { let err = serde_json::from_value::(config_request).unwrap_err(); assert!( err.to_string().contains("unknown field `unknown_field`"), - "expect unknown field `unknown_field` error, got: {}", - err + "expect unknown field `unknown_field` error, got: {err}" ); } diff --git a/libs/pageserver_api/src/pagestream_api.rs b/libs/pageserver_api/src/pagestream_api.rs new file mode 100644 index 0000000000..862da8268a --- /dev/null +++ b/libs/pageserver_api/src/pagestream_api.rs @@ -0,0 +1,798 @@ +//! Rust definitions of the libpq-based pagestream API +//! +//! See also the C implementation of the same API in pgxn/neon/pagestore_client.h + +use std::io::{BufRead, Read}; + +use crate::reltag::RelTag; + +use byteorder::{BigEndian, ReadBytesExt}; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use utils::lsn::Lsn; + +/// Block size. +/// +/// XXX: We assume 8k block size in the SLRU fetch API. 
It's not great to hardcode +/// that in the protocol, because Postgres supports different block sizes as a compile +/// time option. +const BLCKSZ: usize = 8192; + +// Wrapped in libpq CopyData +#[derive(PartialEq, Eq, Debug)] +pub enum PagestreamFeMessage { + Exists(PagestreamExistsRequest), + Nblocks(PagestreamNblocksRequest), + GetPage(PagestreamGetPageRequest), + DbSize(PagestreamDbSizeRequest), + GetSlruSegment(PagestreamGetSlruSegmentRequest), + #[cfg(feature = "testing")] + Test(PagestreamTestRequest), +} + +// Wrapped in libpq CopyData +#[derive(Debug, strum_macros::EnumProperty)] +pub enum PagestreamBeMessage { + Exists(PagestreamExistsResponse), + Nblocks(PagestreamNblocksResponse), + GetPage(PagestreamGetPageResponse), + Error(PagestreamErrorResponse), + DbSize(PagestreamDbSizeResponse), + GetSlruSegment(PagestreamGetSlruSegmentResponse), + #[cfg(feature = "testing")] + Test(PagestreamTestResponse), +} + +// Keep in sync with `pagestore_client.h` +#[repr(u8)] +enum PagestreamFeMessageTag { + Exists = 0, + Nblocks = 1, + GetPage = 2, + DbSize = 3, + GetSlruSegment = 4, + /* future tags above this line */ + /// For testing purposes, not available in production. + #[cfg(feature = "testing")] + Test = 99, +} + +// Keep in sync with `pagestore_client.h` +#[repr(u8)] +enum PagestreamBeMessageTag { + Exists = 100, + Nblocks = 101, + GetPage = 102, + Error = 103, + DbSize = 104, + GetSlruSegment = 105, + /* future tags above this line */ + /// For testing purposes, not available in production. + #[cfg(feature = "testing")] + Test = 199, +} + +impl TryFrom for PagestreamFeMessageTag { + type Error = u8; + fn try_from(value: u8) -> Result { + match value { + 0 => Ok(PagestreamFeMessageTag::Exists), + 1 => Ok(PagestreamFeMessageTag::Nblocks), + 2 => Ok(PagestreamFeMessageTag::GetPage), + 3 => Ok(PagestreamFeMessageTag::DbSize), + 4 => Ok(PagestreamFeMessageTag::GetSlruSegment), + #[cfg(feature = "testing")] + 99 => Ok(PagestreamFeMessageTag::Test), + _ => Err(value), + } + } +} + +impl TryFrom for PagestreamBeMessageTag { + type Error = u8; + fn try_from(value: u8) -> Result { + match value { + 100 => Ok(PagestreamBeMessageTag::Exists), + 101 => Ok(PagestreamBeMessageTag::Nblocks), + 102 => Ok(PagestreamBeMessageTag::GetPage), + 103 => Ok(PagestreamBeMessageTag::Error), + 104 => Ok(PagestreamBeMessageTag::DbSize), + 105 => Ok(PagestreamBeMessageTag::GetSlruSegment), + #[cfg(feature = "testing")] + 199 => Ok(PagestreamBeMessageTag::Test), + _ => Err(value), + } + } +} + +// A GetPage request contains two LSN values: +// +// request_lsn: Get the page version at this point in time. Lsn::Max is a special value that means +// "get the latest version present". It's used by the primary server, which knows that no one else +// is writing WAL. 'not_modified_since' must be set to a proper value even if request_lsn is +// Lsn::Max. Standby servers use the current replay LSN as the request LSN. +// +// not_modified_since: Hint to the pageserver that the client knows that the page has not been +// modified between 'not_modified_since' and the request LSN. It's always correct to set +// 'not_modified_since equal' to 'request_lsn' (unless Lsn::Max is used as the 'request_lsn'), but +// passing an earlier LSN can speed up the request, by allowing the pageserver to process the +// request without waiting for 'request_lsn' to arrive. +// +// The now-defunct V1 interface contained only one LSN, and a boolean 'latest' flag. 
The V1 interface was +// sufficient for the primary; the 'lsn' was equivalent to the 'not_modified_since' value, and +// 'latest' was set to true. The V2 interface was added because there was no correct way for a +// standby to request a page at a particular non-latest LSN, and also include the +// 'not_modified_since' hint. That led to an awkward choice of either using an old LSN in the +// request, if the standby knows that the page hasn't been modified since, and risking an error +// if that LSN has fallen behind the GC horizon, or requesting the current replay LSN, which could +// force the pageserver to wait unnecessarily for the WAL to arrive up to that point. The new V2 +// interface allows sending both LSNs, and lets the pageserver do the right thing. There was no +// difference in the responses between V1 and V2. +// +// The V3 version of the protocol adds a request ID to all requests. The request ID is also +// included in the response, along with the other fields from the request, which allows us to +// verify that a response matches our request. We copy fields from the request to the response to +// make this check more reliable: the request ID is formed from the process ID and a local counter, +// so in principle there can be duplicate request IDs if a process PID is reused. +// +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum PagestreamProtocolVersion { + V2, + V3, +} + +pub type RequestId = u64; + +#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] +pub struct PagestreamRequest { + pub reqid: RequestId, + pub request_lsn: Lsn, + pub not_modified_since: Lsn, +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct PagestreamExistsRequest { + pub hdr: PagestreamRequest, + pub rel: RelTag, +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct PagestreamNblocksRequest { + pub hdr: PagestreamRequest, + pub rel: RelTag, +} + +#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] +pub struct PagestreamGetPageRequest { + pub hdr: PagestreamRequest, + pub rel: RelTag, + pub blkno: u32, +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct PagestreamDbSizeRequest { + pub hdr: PagestreamRequest, + pub dbnode: u32, +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct PagestreamGetSlruSegmentRequest { + pub hdr: PagestreamRequest, + pub kind: u8, + pub segno: u32, +} + +#[derive(Debug)] +pub struct PagestreamExistsResponse { + pub req: PagestreamExistsRequest, + pub exists: bool, +} + +#[derive(Debug)] +pub struct PagestreamNblocksResponse { + pub req: PagestreamNblocksRequest, + pub n_blocks: u32, +} + +#[derive(Debug)] +pub struct PagestreamGetPageResponse { + pub req: PagestreamGetPageRequest, + pub page: Bytes, +} + +#[derive(Debug)] +pub struct PagestreamGetSlruSegmentResponse { + pub req: PagestreamGetSlruSegmentRequest, + pub segment: Bytes, +} + +#[derive(Debug)] +pub struct PagestreamErrorResponse { + pub req: PagestreamRequest, + pub message: String, +} + +#[derive(Debug)] +pub struct PagestreamDbSizeResponse { + pub req: PagestreamDbSizeRequest, + pub db_size: i64, +} + +#[cfg(feature = "testing")] +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct PagestreamTestRequest { + pub hdr: PagestreamRequest, + pub batch_key: u64, + pub message: String, +} + +#[cfg(feature = "testing")] +#[derive(Debug)] +pub struct PagestreamTestResponse { + pub req: PagestreamTestRequest, +} + +impl PagestreamFeMessage { + /// Serialize a compute -> pageserver message. This is currently only used in testing + /// tools. Always uses protocol version 3.
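+    ///
+    /// A minimal round-trip sketch (values are illustrative only, not a real
+    /// compute request):
+    ///
+    /// ```ignore
+    /// let msg = PagestreamFeMessage::DbSize(PagestreamDbSizeRequest {
+    ///     hdr: PagestreamRequest { reqid: 1, request_lsn: Lsn(0x10), not_modified_since: Lsn(0x10) },
+    ///     dbnode: 16384,
+    /// });
+    /// let bytes = msg.serialize();
+    /// let parsed = PagestreamFeMessage::parse(&mut bytes.reader(), PagestreamProtocolVersion::V3)?;
+    /// assert_eq!(msg, parsed);
+    /// ```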
+ pub fn serialize(&self) -> Bytes { + let mut bytes = BytesMut::new(); + + match self { + Self::Exists(req) => { + bytes.put_u8(PagestreamFeMessageTag::Exists as u8); + bytes.put_u64(req.hdr.reqid); + bytes.put_u64(req.hdr.request_lsn.0); + bytes.put_u64(req.hdr.not_modified_since.0); + bytes.put_u32(req.rel.spcnode); + bytes.put_u32(req.rel.dbnode); + bytes.put_u32(req.rel.relnode); + bytes.put_u8(req.rel.forknum); + } + + Self::Nblocks(req) => { + bytes.put_u8(PagestreamFeMessageTag::Nblocks as u8); + bytes.put_u64(req.hdr.reqid); + bytes.put_u64(req.hdr.request_lsn.0); + bytes.put_u64(req.hdr.not_modified_since.0); + bytes.put_u32(req.rel.spcnode); + bytes.put_u32(req.rel.dbnode); + bytes.put_u32(req.rel.relnode); + bytes.put_u8(req.rel.forknum); + } + + Self::GetPage(req) => { + bytes.put_u8(PagestreamFeMessageTag::GetPage as u8); + bytes.put_u64(req.hdr.reqid); + bytes.put_u64(req.hdr.request_lsn.0); + bytes.put_u64(req.hdr.not_modified_since.0); + bytes.put_u32(req.rel.spcnode); + bytes.put_u32(req.rel.dbnode); + bytes.put_u32(req.rel.relnode); + bytes.put_u8(req.rel.forknum); + bytes.put_u32(req.blkno); + } + + Self::DbSize(req) => { + bytes.put_u8(PagestreamFeMessageTag::DbSize as u8); + bytes.put_u64(req.hdr.reqid); + bytes.put_u64(req.hdr.request_lsn.0); + bytes.put_u64(req.hdr.not_modified_since.0); + bytes.put_u32(req.dbnode); + } + + Self::GetSlruSegment(req) => { + bytes.put_u8(PagestreamFeMessageTag::GetSlruSegment as u8); + bytes.put_u64(req.hdr.reqid); + bytes.put_u64(req.hdr.request_lsn.0); + bytes.put_u64(req.hdr.not_modified_since.0); + bytes.put_u8(req.kind); + bytes.put_u32(req.segno); + } + #[cfg(feature = "testing")] + Self::Test(req) => { + bytes.put_u8(PagestreamFeMessageTag::Test as u8); + bytes.put_u64(req.hdr.reqid); + bytes.put_u64(req.hdr.request_lsn.0); + bytes.put_u64(req.hdr.not_modified_since.0); + bytes.put_u64(req.batch_key); + let message = req.message.as_bytes(); + bytes.put_u64(message.len() as u64); + bytes.put_slice(message); + } + } + + bytes.into() + } + + pub fn parse<R: Read>( + body: &mut R, + protocol_version: PagestreamProtocolVersion, + ) -> anyhow::Result<PagestreamFeMessage> { + // these correspond to the NeonMessageTag enum in pagestore_client.h + // + // TODO: consider using protobuf or serde bincode for less error-prone + // serialization. + let msg_tag = body.read_u8()?; + let (reqid, request_lsn, not_modified_since) = match protocol_version { + PagestreamProtocolVersion::V2 => ( + 0, + Lsn::from(body.read_u64::<BigEndian>()?), + Lsn::from(body.read_u64::<BigEndian>()?), + ), + PagestreamProtocolVersion::V3 => ( + body.read_u64::<BigEndian>()?, + Lsn::from(body.read_u64::<BigEndian>()?), + Lsn::from(body.read_u64::<BigEndian>()?), + ), + }; + + match PagestreamFeMessageTag::try_from(msg_tag) + .map_err(|tag: u8| anyhow::anyhow!("invalid tag {tag}"))?
+ { + PagestreamFeMessageTag::Exists => { + Ok(PagestreamFeMessage::Exists(PagestreamExistsRequest { + hdr: PagestreamRequest { + reqid, + request_lsn, + not_modified_since, + }, + rel: RelTag { + spcnode: body.read_u32::<BigEndian>()?, + dbnode: body.read_u32::<BigEndian>()?, + relnode: body.read_u32::<BigEndian>()?, + forknum: body.read_u8()?, + }, + })) + } + PagestreamFeMessageTag::Nblocks => { + Ok(PagestreamFeMessage::Nblocks(PagestreamNblocksRequest { + hdr: PagestreamRequest { + reqid, + request_lsn, + not_modified_since, + }, + rel: RelTag { + spcnode: body.read_u32::<BigEndian>()?, + dbnode: body.read_u32::<BigEndian>()?, + relnode: body.read_u32::<BigEndian>()?, + forknum: body.read_u8()?, + }, + })) + } + PagestreamFeMessageTag::GetPage => { + Ok(PagestreamFeMessage::GetPage(PagestreamGetPageRequest { + hdr: PagestreamRequest { + reqid, + request_lsn, + not_modified_since, + }, + rel: RelTag { + spcnode: body.read_u32::<BigEndian>()?, + dbnode: body.read_u32::<BigEndian>()?, + relnode: body.read_u32::<BigEndian>()?, + forknum: body.read_u8()?, + }, + blkno: body.read_u32::<BigEndian>()?, + })) + } + PagestreamFeMessageTag::DbSize => { + Ok(PagestreamFeMessage::DbSize(PagestreamDbSizeRequest { + hdr: PagestreamRequest { + reqid, + request_lsn, + not_modified_since, + }, + dbnode: body.read_u32::<BigEndian>()?, + })) + } + PagestreamFeMessageTag::GetSlruSegment => Ok(PagestreamFeMessage::GetSlruSegment( + PagestreamGetSlruSegmentRequest { + hdr: PagestreamRequest { + reqid, + request_lsn, + not_modified_since, + }, + kind: body.read_u8()?, + segno: body.read_u32::<BigEndian>()?, + }, + )), + #[cfg(feature = "testing")] + PagestreamFeMessageTag::Test => Ok(PagestreamFeMessage::Test(PagestreamTestRequest { + hdr: PagestreamRequest { + reqid, + request_lsn, + not_modified_since, + }, + batch_key: body.read_u64::<BigEndian>()?, + message: { + let len = body.read_u64::<BigEndian>()?; + let mut buf = vec![0; len as usize]; + body.read_exact(&mut buf)?; + String::from_utf8(buf)?
+ }, + })), + } + } +} + +impl PagestreamBeMessage { + pub fn serialize(&self, protocol_version: PagestreamProtocolVersion) -> Bytes { + let mut bytes = BytesMut::new(); + + use PagestreamBeMessageTag as Tag; + match protocol_version { + PagestreamProtocolVersion::V2 => { + match self { + Self::Exists(resp) => { + bytes.put_u8(Tag::Exists as u8); + bytes.put_u8(resp.exists as u8); + } + + Self::Nblocks(resp) => { + bytes.put_u8(Tag::Nblocks as u8); + bytes.put_u32(resp.n_blocks); + } + + Self::GetPage(resp) => { + bytes.put_u8(Tag::GetPage as u8); + bytes.put(&resp.page[..]) + } + + Self::Error(resp) => { + bytes.put_u8(Tag::Error as u8); + bytes.put(resp.message.as_bytes()); + bytes.put_u8(0); // null terminator + } + Self::DbSize(resp) => { + bytes.put_u8(Tag::DbSize as u8); + bytes.put_i64(resp.db_size); + } + + Self::GetSlruSegment(resp) => { + bytes.put_u8(Tag::GetSlruSegment as u8); + bytes.put_u32((resp.segment.len() / BLCKSZ) as u32); + bytes.put(&resp.segment[..]); + } + + #[cfg(feature = "testing")] + Self::Test(resp) => { + bytes.put_u8(Tag::Test as u8); + bytes.put_u64(resp.req.batch_key); + let message = resp.req.message.as_bytes(); + bytes.put_u64(message.len() as u64); + bytes.put_slice(message); + } + } + } + PagestreamProtocolVersion::V3 => { + match self { + Self::Exists(resp) => { + bytes.put_u8(Tag::Exists as u8); + bytes.put_u64(resp.req.hdr.reqid); + bytes.put_u64(resp.req.hdr.request_lsn.0); + bytes.put_u64(resp.req.hdr.not_modified_since.0); + bytes.put_u32(resp.req.rel.spcnode); + bytes.put_u32(resp.req.rel.dbnode); + bytes.put_u32(resp.req.rel.relnode); + bytes.put_u8(resp.req.rel.forknum); + bytes.put_u8(resp.exists as u8); + } + + Self::Nblocks(resp) => { + bytes.put_u8(Tag::Nblocks as u8); + bytes.put_u64(resp.req.hdr.reqid); + bytes.put_u64(resp.req.hdr.request_lsn.0); + bytes.put_u64(resp.req.hdr.not_modified_since.0); + bytes.put_u32(resp.req.rel.spcnode); + bytes.put_u32(resp.req.rel.dbnode); + bytes.put_u32(resp.req.rel.relnode); + bytes.put_u8(resp.req.rel.forknum); + bytes.put_u32(resp.n_blocks); + } + + Self::GetPage(resp) => { + bytes.put_u8(Tag::GetPage as u8); + bytes.put_u64(resp.req.hdr.reqid); + bytes.put_u64(resp.req.hdr.request_lsn.0); + bytes.put_u64(resp.req.hdr.not_modified_since.0); + bytes.put_u32(resp.req.rel.spcnode); + bytes.put_u32(resp.req.rel.dbnode); + bytes.put_u32(resp.req.rel.relnode); + bytes.put_u8(resp.req.rel.forknum); + bytes.put_u32(resp.req.blkno); + bytes.put(&resp.page[..]) + } + + Self::Error(resp) => { + bytes.put_u8(Tag::Error as u8); + bytes.put_u64(resp.req.reqid); + bytes.put_u64(resp.req.request_lsn.0); + bytes.put_u64(resp.req.not_modified_since.0); + bytes.put(resp.message.as_bytes()); + bytes.put_u8(0); // null terminator + } + Self::DbSize(resp) => { + bytes.put_u8(Tag::DbSize as u8); + bytes.put_u64(resp.req.hdr.reqid); + bytes.put_u64(resp.req.hdr.request_lsn.0); + bytes.put_u64(resp.req.hdr.not_modified_since.0); + bytes.put_u32(resp.req.dbnode); + bytes.put_i64(resp.db_size); + } + + Self::GetSlruSegment(resp) => { + bytes.put_u8(Tag::GetSlruSegment as u8); + bytes.put_u64(resp.req.hdr.reqid); + bytes.put_u64(resp.req.hdr.request_lsn.0); + bytes.put_u64(resp.req.hdr.not_modified_since.0); + bytes.put_u8(resp.req.kind); + bytes.put_u32(resp.req.segno); + bytes.put_u32((resp.segment.len() / BLCKSZ) as u32); + bytes.put(&resp.segment[..]); + } + + #[cfg(feature = "testing")] + Self::Test(resp) => { + bytes.put_u8(Tag::Test as u8); + bytes.put_u64(resp.req.hdr.reqid); + 
bytes.put_u64(resp.req.hdr.request_lsn.0); + bytes.put_u64(resp.req.hdr.not_modified_since.0); + bytes.put_u64(resp.req.batch_key); + let message = resp.req.message.as_bytes(); + bytes.put_u64(message.len() as u64); + bytes.put_slice(message); + } + } + } + } + bytes.into() + } + + pub fn deserialize(buf: Bytes) -> anyhow::Result<Self> { + let mut buf = buf.reader(); + let msg_tag = buf.read_u8()?; + + use PagestreamBeMessageTag as Tag; + let ok = + match Tag::try_from(msg_tag).map_err(|tag: u8| anyhow::anyhow!("invalid tag {tag}"))? { + Tag::Exists => { + let reqid = buf.read_u64::<BigEndian>()?; + let request_lsn = Lsn(buf.read_u64::<BigEndian>()?); + let not_modified_since = Lsn(buf.read_u64::<BigEndian>()?); + let rel = RelTag { + spcnode: buf.read_u32::<BigEndian>()?, + dbnode: buf.read_u32::<BigEndian>()?, + relnode: buf.read_u32::<BigEndian>()?, + forknum: buf.read_u8()?, + }; + let exists = buf.read_u8()? != 0; + Self::Exists(PagestreamExistsResponse { + req: PagestreamExistsRequest { + hdr: PagestreamRequest { + reqid, + request_lsn, + not_modified_since, + }, + rel, + }, + exists, + }) + } + Tag::Nblocks => { + let reqid = buf.read_u64::<BigEndian>()?; + let request_lsn = Lsn(buf.read_u64::<BigEndian>()?); + let not_modified_since = Lsn(buf.read_u64::<BigEndian>()?); + let rel = RelTag { + spcnode: buf.read_u32::<BigEndian>()?, + dbnode: buf.read_u32::<BigEndian>()?, + relnode: buf.read_u32::<BigEndian>()?, + forknum: buf.read_u8()?, + }; + let n_blocks = buf.read_u32::<BigEndian>()?; + Self::Nblocks(PagestreamNblocksResponse { + req: PagestreamNblocksRequest { + hdr: PagestreamRequest { + reqid, + request_lsn, + not_modified_since, + }, + rel, + }, + n_blocks, + }) + } + Tag::GetPage => { + let reqid = buf.read_u64::<BigEndian>()?; + let request_lsn = Lsn(buf.read_u64::<BigEndian>()?); + let not_modified_since = Lsn(buf.read_u64::<BigEndian>()?); + let rel = RelTag { + spcnode: buf.read_u32::<BigEndian>()?, + dbnode: buf.read_u32::<BigEndian>()?, + relnode: buf.read_u32::<BigEndian>()?, + forknum: buf.read_u8()?, + }; + let blkno = buf.read_u32::<BigEndian>()?; + let mut page = vec![0; 8192]; // TODO: use MaybeUninit + buf.read_exact(&mut page)?; + Self::GetPage(PagestreamGetPageResponse { + req: PagestreamGetPageRequest { + hdr: PagestreamRequest { + reqid, + request_lsn, + not_modified_since, + }, + rel, + blkno, + }, + page: page.into(), + }) + } + Tag::Error => { + let reqid = buf.read_u64::<BigEndian>()?; + let request_lsn = Lsn(buf.read_u64::<BigEndian>()?); + let not_modified_since = Lsn(buf.read_u64::<BigEndian>()?); + let mut msg = Vec::new(); + buf.read_until(0, &mut msg)?; + let cstring = std::ffi::CString::from_vec_with_nul(msg)?; + let rust_str = cstring.to_str()?; + Self::Error(PagestreamErrorResponse { + req: PagestreamRequest { + reqid, + request_lsn, + not_modified_since, + }, + message: rust_str.to_owned(), + }) + } + Tag::DbSize => { + let reqid = buf.read_u64::<BigEndian>()?; + let request_lsn = Lsn(buf.read_u64::<BigEndian>()?); + let not_modified_since = Lsn(buf.read_u64::<BigEndian>()?); + let dbnode = buf.read_u32::<BigEndian>()?; + let db_size = buf.read_i64::<BigEndian>()?; + Self::DbSize(PagestreamDbSizeResponse { + req: PagestreamDbSizeRequest { + hdr: PagestreamRequest { + reqid, + request_lsn, + not_modified_since, + }, + dbnode, + }, + db_size, + }) + } + Tag::GetSlruSegment => { + let reqid = buf.read_u64::<BigEndian>()?; + let request_lsn = Lsn(buf.read_u64::<BigEndian>()?); + let not_modified_since = Lsn(buf.read_u64::<BigEndian>()?); + let kind = buf.read_u8()?; + let segno = buf.read_u32::<BigEndian>()?; + let n_blocks = buf.read_u32::<BigEndian>()?; + let mut segment = vec![0; n_blocks as usize * BLCKSZ]; + buf.read_exact(&mut segment)?; + Self::GetSlruSegment(PagestreamGetSlruSegmentResponse { + req: PagestreamGetSlruSegmentRequest { + hdr: PagestreamRequest { + reqid, + request_lsn, +
not_modified_since, + }, + kind, + segno, + }, + segment: segment.into(), + }) + } + #[cfg(feature = "testing")] + Tag::Test => { + let reqid = buf.read_u64::<BigEndian>()?; + let request_lsn = Lsn(buf.read_u64::<BigEndian>()?); + let not_modified_since = Lsn(buf.read_u64::<BigEndian>()?); + let batch_key = buf.read_u64::<BigEndian>()?; + let len = buf.read_u64::<BigEndian>()?; + let mut msg = vec![0; len as usize]; + buf.read_exact(&mut msg)?; + let message = String::from_utf8(msg)?; + Self::Test(PagestreamTestResponse { + req: PagestreamTestRequest { + hdr: PagestreamRequest { + reqid, + request_lsn, + not_modified_since, + }, + batch_key, + message, + }, + }) + } + }; + let remaining = buf.into_inner(); + if !remaining.is_empty() { + anyhow::bail!( + "remaining bytes in msg with tag={msg_tag}: {}", + remaining.len() + ); + } + Ok(ok) + } + + pub fn kind(&self) -> &'static str { + match self { + Self::Exists(_) => "Exists", + Self::Nblocks(_) => "Nblocks", + Self::GetPage(_) => "GetPage", + Self::Error(_) => "Error", + Self::DbSize(_) => "DbSize", + Self::GetSlruSegment(_) => "GetSlruSegment", + #[cfg(feature = "testing")] + Self::Test(_) => "Test", + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_pagestream() { + // Test serialization/deserialization of PagestreamFeMessage + let messages = vec![ + PagestreamFeMessage::Exists(PagestreamExistsRequest { + hdr: PagestreamRequest { + reqid: 0, + request_lsn: Lsn(4), + not_modified_since: Lsn(3), + }, + rel: RelTag { + forknum: 1, + spcnode: 2, + dbnode: 3, + relnode: 4, + }, + }), + PagestreamFeMessage::Nblocks(PagestreamNblocksRequest { + hdr: PagestreamRequest { + reqid: 0, + request_lsn: Lsn(4), + not_modified_since: Lsn(4), + }, + rel: RelTag { + forknum: 1, + spcnode: 2, + dbnode: 3, + relnode: 4, + }, + }), + PagestreamFeMessage::GetPage(PagestreamGetPageRequest { + hdr: PagestreamRequest { + reqid: 0, + request_lsn: Lsn(4), + not_modified_since: Lsn(3), + }, + rel: RelTag { + forknum: 1, + spcnode: 2, + dbnode: 3, + relnode: 4, + }, + blkno: 7, + }), + PagestreamFeMessage::DbSize(PagestreamDbSizeRequest { + hdr: PagestreamRequest { + reqid: 0, + request_lsn: Lsn(4), + not_modified_since: Lsn(3), + }, + dbnode: 7, + }), + ]; + for msg in messages { + let bytes = msg.serialize(); + let reconstructed = + PagestreamFeMessage::parse(&mut bytes.reader(), PagestreamProtocolVersion::V3) + .unwrap(); + assert!(msg == reconstructed); + } + } +} diff --git a/libs/pageserver_api/src/reltag.rs b/libs/pageserver_api/src/reltag.rs index e0dd4fdfe8..d0e37dffae 100644 --- a/libs/pageserver_api/src/reltag.rs +++ b/libs/pageserver_api/src/reltag.rs @@ -1,9 +1,9 @@ use std::cmp::Ordering; use std::fmt; -use postgres_ffi::Oid; -use postgres_ffi::pg_constants::GLOBALTABLESPACE_OID; -use postgres_ffi::relfile_utils::{MAIN_FORKNUM, forkname_to_number, forknumber_to_name}; +use postgres_ffi_types::Oid; +use postgres_ffi_types::constants::GLOBALTABLESPACE_OID; +use postgres_ffi_types::forknum::{MAIN_FORKNUM, forkname_to_number, forknumber_to_name}; use serde::{Deserialize, Serialize}; /// diff --git a/libs/pageserver_api/src/shard.rs b/libs/pageserver_api/src/shard.rs index feb59f5070..9c16be93e8 100644 --- a/libs/pageserver_api/src/shard.rs +++ b/libs/pageserver_api/src/shard.rs @@ -35,7 +35,7 @@ use std::hash::{Hash, Hasher}; #[doc(inline)] pub use ::utils::shard::*; -use postgres_ffi::relfile_utils::INIT_FORKNUM; +use postgres_ffi_types::forknum::INIT_FORKNUM; use serde::{Deserialize, Serialize}; use crate::key::Key; diff --git a/libs/pageserver_api/src/upcall_api.rs
b/libs/pageserver_api/src/upcall_api.rs index 4dce5f7817..07cada2eb1 100644 --- a/libs/pageserver_api/src/upcall_api.rs +++ b/libs/pageserver_api/src/upcall_api.rs @@ -9,7 +9,7 @@ use utils::id::{NodeId, TimelineId}; use crate::controller_api::NodeRegisterRequest; use crate::models::{LocationConfigMode, ShardImportStatus}; -use crate::shard::TenantShardId; +use crate::shard::{ShardStripeSize, TenantShardId}; /// Upcall message sent by the pageserver to the configured `control_plane_api` on /// startup. @@ -23,19 +23,13 @@ pub struct ReAttachRequest { pub register: Option<NodeRegisterRequest>, } -fn default_mode() -> LocationConfigMode { - LocationConfigMode::AttachedSingle -} - #[derive(Serialize, Deserialize, Debug)] pub struct ReAttachResponseTenant { pub id: TenantShardId, /// Mandatory if LocationConfigMode is None or set to an Attached* mode pub r#gen: Option<u32>, - - /// Default value only for backward compat: this field should be set - #[serde(default = "default_mode")] pub mode: LocationConfigMode, + pub stripe_size: ShardStripeSize, } #[derive(Serialize, Deserialize)] pub struct ReAttachResponse { diff --git a/libs/postgres_backend/src/lib.rs b/libs/postgres_backend/src/lib.rs index 714d8ac403..091299f842 100644 --- a/libs/postgres_backend/src/lib.rs +++ b/libs/postgres_backend/src/lib.rs @@ -939,7 +939,7 @@ impl PostgresBackendReader { FeMessage::CopyFail => Err(CopyStreamHandlerEnd::CopyFail), FeMessage::Terminate => Err(CopyStreamHandlerEnd::Terminate), _ => Err(CopyStreamHandlerEnd::from(ConnectionError::Protocol( - ProtocolError::Protocol(format!("unexpected message in COPY stream {:?}", msg)), + ProtocolError::Protocol(format!("unexpected message in COPY stream {msg:?}")), ))), }, None => Err(CopyStreamHandlerEnd::EOF), diff --git a/libs/postgres_backend/tests/simple_select.rs b/libs/postgres_backend/tests/simple_select.rs index 75ca123014..23e17799bd 100644 --- a/libs/postgres_backend/tests/simple_select.rs +++ b/libs/postgres_backend/tests/simple_select.rs @@ -61,7 +61,7 @@ async fn simple_select() { // so spawn it off to run on its own. tokio::spawn(async move { if let Err(e) = connection.await { - eprintln!("connection error: {}", e); + eprintln!("connection error: {e}"); } }); @@ -137,7 +137,7 @@ async fn simple_select_ssl() { // so spawn it off to run on its own.
tokio::spawn(async move { if let Err(e) = connection.await { - eprintln!("connection error: {}", e); + eprintln!("connection error: {e}"); } }); diff --git a/libs/postgres_connection/src/lib.rs b/libs/postgres_connection/src/lib.rs index cd981b3729..2388303329 100644 --- a/libs/postgres_connection/src/lib.rs +++ b/libs/postgres_connection/src/lib.rs @@ -223,7 +223,7 @@ mod tests_pg_connection_config { assert_eq!(cfg.port(), 123); assert_eq!(cfg.raw_address(), "stub.host.example:123"); assert_eq!( - format!("{:?}", cfg), + format!("{cfg:?}"), "PgConnectionConfig { host: Domain(\"stub.host.example\"), port: 123, password: None }" ); } @@ -239,7 +239,7 @@ mod tests_pg_connection_config { assert_eq!(cfg.port(), 123); assert_eq!(cfg.raw_address(), "[::1]:123"); assert_eq!( - format!("{:?}", cfg), + format!("{cfg:?}"), "PgConnectionConfig { host: Ipv6(::1), port: 123, password: None }" ); } @@ -252,7 +252,7 @@ mod tests_pg_connection_config { assert_eq!(cfg.port(), 123); assert_eq!(cfg.raw_address(), "stub.host.example:123"); assert_eq!( - format!("{:?}", cfg), + format!("{cfg:?}"), "PgConnectionConfig { host: Domain(\"stub.host.example\"), port: 123, password: Some(REDACTED-STRING) }" ); } diff --git a/libs/postgres_ffi/Cargo.toml b/libs/postgres_ffi/Cargo.toml index b7a376841d..d4fec6cbe9 100644 --- a/libs/postgres_ffi/Cargo.toml +++ b/libs/postgres_ffi/Cargo.toml @@ -16,8 +16,10 @@ memoffset.workspace = true pprof.workspace = true thiserror.workspace = true serde.workspace = true +postgres_ffi_types.workspace = true utils.workspace = true tracing.workspace = true +postgres_versioninfo.workspace = true [dev-dependencies] env_logger.workspace = true diff --git a/libs/postgres_ffi/benches/waldecoder.rs b/libs/postgres_ffi/benches/waldecoder.rs index 2e1d62e452..b2a884c7db 100644 --- a/libs/postgres_ffi/benches/waldecoder.rs +++ b/libs/postgres_ffi/benches/waldecoder.rs @@ -4,6 +4,7 @@ use criterion::{Bencher, Criterion, criterion_group, criterion_main}; use postgres_ffi::v17::wal_generator::LogicalMessageGenerator; use postgres_ffi::v17::waldecoder_handler::WalStreamDecoderHandler; use postgres_ffi::waldecoder::WalStreamDecoder; +use postgres_versioninfo::PgMajorVersion; use pprof::criterion::{Output, PProfProfiler}; use utils::lsn::Lsn; @@ -32,7 +33,7 @@ fn bench_complete_record(c: &mut Criterion) { let value_size = LogicalMessageGenerator::make_value_size(size, PREFIX); let value = vec![1; value_size]; - let mut decoder = WalStreamDecoder::new(Lsn(0), 170000); + let mut decoder = WalStreamDecoder::new(Lsn(0), PgMajorVersion::PG17); let msg = LogicalMessageGenerator::new(PREFIX, &value) .next() .unwrap() diff --git a/libs/postgres_ffi/src/lib.rs b/libs/postgres_ffi/src/lib.rs index 05d8de4c7a..9297ac46c9 100644 --- a/libs/postgres_ffi/src/lib.rs +++ b/libs/postgres_ffi/src/lib.rs @@ -14,6 +14,8 @@ use bytes::Bytes; use utils::bin_ser::SerializeError; use utils::lsn::Lsn; +pub use postgres_versioninfo::PgMajorVersion; + macro_rules! postgres_ffi { ($version:ident) => { #[path = "."] @@ -91,21 +93,22 @@ macro_rules! 
dispatch_pgversion { $version => $code, default = $invalid_pgver_handling, pgversions = [ - 14 : v14, - 15 : v15, - 16 : v16, - 17 : v17, + $crate::PgMajorVersion::PG14 => v14, + $crate::PgMajorVersion::PG15 => v15, + $crate::PgMajorVersion::PG16 => v16, + $crate::PgMajorVersion::PG17 => v17, ] ) }; ($pgversion:expr => $code:expr, default = $default:expr, - pgversions = [$($sv:literal : $vsv:ident),+ $(,)?]) => { - match ($pgversion) { + pgversions = [$($sv:pat => $vsv:ident),+ $(,)?]) => { + match ($pgversion.clone().into()) { $($sv => { use $crate::$vsv as pgv; $code },)+ + #[allow(unreachable_patterns)] _ => { $default } @@ -179,9 +182,9 @@ macro_rules! enum_pgversion { $($variant ( $crate::$md::$t )),+ } impl self::$name { - pub fn pg_version(&self) -> u32 { + pub fn pg_version(&self) -> PgMajorVersion { enum_pgversion_dispatch!(self, $name, _ign, { - pgv::bindings::PG_MAJORVERSION_NUM + pgv::bindings::MY_PGVERSION }) } } @@ -195,15 +198,15 @@ macro_rules! enum_pgversion { }; {name = $name:ident, path = $p:ident, - typ = $t:ident, + $(typ = $t:ident,)? pgversions = [$($variant:ident : $md:ident),+ $(,)?]} => { pub enum $name { - $($variant ($crate::$md::$p::$t)),+ + $($variant $(($crate::$md::$p::$t))?),+ } impl $name { - pub fn pg_version(&self) -> u32 { + pub fn pg_version(&self) -> PgMajorVersion { enum_pgversion_dispatch!(self, $name, _ign, { - pgv::bindings::PG_MAJORVERSION_NUM + pgv::bindings::MY_PGVERSION }) } } @@ -249,22 +252,21 @@ pub use v14::xlog_utils::{ try_from_pg_timestamp, }; -pub fn bkpimage_is_compressed(bimg_info: u8, version: u32) -> bool { +pub fn bkpimage_is_compressed(bimg_info: u8, version: PgMajorVersion) -> bool { dispatch_pgversion!(version, pgv::bindings::bkpimg_is_compressed(bimg_info)) } pub fn generate_wal_segment( segno: u64, system_id: u64, - pg_version: u32, + pg_version: PgMajorVersion, lsn: Lsn, ) -> Result<Bytes, SerializeError> { assert_eq!(segno, lsn.segment_number(WAL_SEGMENT_SIZE)); dispatch_pgversion!( pg_version, - pgv::xlog_utils::generate_wal_segment(segno, system_id, lsn), - Err(SerializeError::BadInput) + pgv::xlog_utils::generate_wal_segment(segno, system_id, lsn) ) } @@ -272,7 +274,7 @@ pub fn generate_pg_control( pg_control_bytes: &[u8], checkpoint_bytes: &[u8], lsn: Lsn, - pg_version: u32, + pg_version: PgMajorVersion, ) -> anyhow::Result<(Bytes, u64, bool)> { dispatch_pgversion!( pg_version, @@ -352,6 +354,7 @@ pub fn fsm_logical_to_physical(addr: BlockNumber) -> BlockNumber { pub mod waldecoder { use std::num::NonZeroU32; + use crate::PgMajorVersion; use bytes::{Buf, Bytes, BytesMut}; use thiserror::Error; use utils::lsn::Lsn; @@ -369,7 +372,7 @@ pub mod waldecoder { pub struct WalStreamDecoder { pub lsn: Lsn, - pub pg_version: u32, + pub pg_version: PgMajorVersion, pub inputbuf: BytesMut, pub state: State, } @@ -382,7 +385,7 @@ pub mod waldecoder { } impl WalStreamDecoder { - pub fn new(lsn: Lsn, pg_version: u32) -> WalStreamDecoder { + pub fn new(lsn: Lsn, pg_version: PgMajorVersion) -> WalStreamDecoder { WalStreamDecoder { lsn, pg_version, diff --git a/libs/postgres_ffi/src/pg_constants.rs b/libs/postgres_ffi/src/pg_constants.rs index b0bdd8a8da..f61b9a71c2 100644 --- a/libs/postgres_ffi/src/pg_constants.rs +++ b/libs/postgres_ffi/src/pg_constants.rs @@ -11,11 +11,7 @@ use crate::{BLCKSZ, PageHeaderData}; -// -// From pg_tablespace_d.h -// -pub const DEFAULTTABLESPACE_OID: u32 = 1663; -pub const GLOBALTABLESPACE_OID: u32 = 1664; +// Note: There are a few more widely-used constants in the postgres_ffi_types::constants module.
// From storage_xlog.h pub const XLOG_SMGR_CREATE: u8 = 0x10; diff --git a/libs/postgres_ffi/src/pg_constants_v14.rs b/libs/postgres_ffi/src/pg_constants_v14.rs index fe01a5df7c..fd393995db 100644 --- a/libs/postgres_ffi/src/pg_constants_v14.rs +++ b/libs/postgres_ffi/src/pg_constants_v14.rs @@ -1,3 +1,7 @@ +use crate::PgMajorVersion; + +pub const MY_PGVERSION: PgMajorVersion = PgMajorVersion::PG14; + pub const XLOG_DBASE_CREATE: u8 = 0x00; pub const XLOG_DBASE_DROP: u8 = 0x10; diff --git a/libs/postgres_ffi/src/pg_constants_v15.rs b/libs/postgres_ffi/src/pg_constants_v15.rs index 3cd1b7aec5..6c1e2c13de 100644 --- a/libs/postgres_ffi/src/pg_constants_v15.rs +++ b/libs/postgres_ffi/src/pg_constants_v15.rs @@ -1,3 +1,7 @@ +use crate::PgMajorVersion; + +pub const MY_PGVERSION: PgMajorVersion = PgMajorVersion::PG15; + pub const XACT_XINFO_HAS_DROPPED_STATS: u32 = 1u32 << 8; pub const XLOG_DBASE_CREATE_FILE_COPY: u8 = 0x00; diff --git a/libs/postgres_ffi/src/pg_constants_v16.rs b/libs/postgres_ffi/src/pg_constants_v16.rs index 31bd5b68fd..d84db502f3 100644 --- a/libs/postgres_ffi/src/pg_constants_v16.rs +++ b/libs/postgres_ffi/src/pg_constants_v16.rs @@ -1,3 +1,7 @@ +use crate::PgMajorVersion; + +pub const MY_PGVERSION: PgMajorVersion = PgMajorVersion::PG16; + pub const XACT_XINFO_HAS_DROPPED_STATS: u32 = 1u32 << 8; pub const XLOG_DBASE_CREATE_FILE_COPY: u8 = 0x00; diff --git a/libs/postgres_ffi/src/pg_constants_v17.rs b/libs/postgres_ffi/src/pg_constants_v17.rs index 2132938680..14d4b3d42f 100644 --- a/libs/postgres_ffi/src/pg_constants_v17.rs +++ b/libs/postgres_ffi/src/pg_constants_v17.rs @@ -1,3 +1,7 @@ +use crate::PgMajorVersion; + +pub const MY_PGVERSION: PgMajorVersion = PgMajorVersion::PG17; + pub const XACT_XINFO_HAS_DROPPED_STATS: u32 = 1u32 << 8; pub const XLOG_DBASE_CREATE_FILE_COPY: u8 = 0x00; diff --git a/libs/postgres_ffi/src/relfile_utils.rs b/libs/postgres_ffi/src/relfile_utils.rs index aa0e625b47..38f94b7221 100644 --- a/libs/postgres_ffi/src/relfile_utils.rs +++ b/libs/postgres_ffi/src/relfile_utils.rs @@ -4,50 +4,7 @@ use once_cell::sync::OnceCell; use regex::Regex; -// -// Fork numbers, from relpath.h -// -pub const MAIN_FORKNUM: u8 = 0; -pub const FSM_FORKNUM: u8 = 1; -pub const VISIBILITYMAP_FORKNUM: u8 = 2; -pub const INIT_FORKNUM: u8 = 3; - -#[derive(Debug, Clone, thiserror::Error, PartialEq, Eq)] -pub enum FilePathError { - #[error("invalid relation fork name")] - InvalidForkName, - #[error("invalid relation data file name")] - InvalidFileName, -} - -impl From<core::num::ParseIntError> for FilePathError { - fn from(_e: core::num::ParseIntError) -> Self { - FilePathError::InvalidFileName - } -} - -/// Convert Postgres relation file's fork suffix to fork number. -pub fn forkname_to_number(forkname: Option<&str>) -> Result<u8, FilePathError> { - match forkname { - // "main" is not in filenames, it's implicit if the fork name is not present - None => Ok(MAIN_FORKNUM), - Some("fsm") => Ok(FSM_FORKNUM), - Some("vm") => Ok(VISIBILITYMAP_FORKNUM), - Some("init") => Ok(INIT_FORKNUM), - Some(_) => Err(FilePathError::InvalidForkName), - } -} - -/// Convert Postgres fork number to the right suffix of the relation data file. -pub fn forknumber_to_name(forknum: u8) -> Option<&'static str> { - match forknum { - MAIN_FORKNUM => None, - FSM_FORKNUM => Some("fsm"), - VISIBILITYMAP_FORKNUM => Some("vm"), - INIT_FORKNUM => Some("init"), - _ => Some("UNKNOWN FORKNUM"), - } -} +use postgres_ffi_types::forknum::*; /// Parse a filename of a relation file. Returns (relfilenode, forknum, segno) tuple.
/// @@ -75,7 +32,9 @@ pub fn parse_relfilename(fname: &str) -> Result<(u32, u8, u32), FilePathError> { .ok_or(FilePathError::InvalidFileName)?; let relnode_str = caps.name("relnode").unwrap().as_str(); - let relnode = relnode_str.parse::<u32>()?; + let relnode = relnode_str + .parse::<u32>() + .map_err(|_e| FilePathError::InvalidFileName)?; let forkname = caps.name("forkname").map(|f| f.as_str()); let forknum = forkname_to_number(forkname)?; @@ -84,7 +43,11 @@ pub fn parse_relfilename(fname: &str) -> Result<(u32, u8, u32), FilePathError> { let segno = if segno_match.is_none() { 0 } else { - segno_match.unwrap().as_str().parse::<u32>()? + segno_match + .unwrap() + .as_str() + .parse::<u32>() + .map_err(|_e| FilePathError::InvalidFileName)? }; Ok((relnode, forknum, segno)) diff --git a/libs/postgres_ffi/src/waldecoder_handler.rs b/libs/postgres_ffi/src/waldecoder_handler.rs index b4d50375bd..9cd40645ec 100644 --- a/libs/postgres_ffi/src/waldecoder_handler.rs +++ b/libs/postgres_ffi/src/waldecoder_handler.rs @@ -114,7 +114,7 @@ impl WalStreamDecoderHandler for WalStreamDecoder { let hdr = XLogLongPageHeaderData::from_bytes(&mut self.inputbuf).map_err( |e| WalDecodeError { - msg: format!("long header deserialization failed {}", e), + msg: format!("long header deserialization failed {e}"), lsn: self.lsn, }, )?; @@ -130,7 +130,7 @@ impl WalStreamDecoderHandler for WalStreamDecoder { let hdr = XLogPageHeaderData::from_bytes(&mut self.inputbuf).map_err(|e| { WalDecodeError { - msg: format!("header deserialization failed {}", e), + msg: format!("header deserialization failed {e}"), lsn: self.lsn, } })?; @@ -155,7 +155,7 @@ impl WalStreamDecoderHandler for WalStreamDecoder { let xl_tot_len = (&self.inputbuf[0..4]).get_u32_le(); if (xl_tot_len as usize) < XLOG_SIZE_OF_XLOG_RECORD { return Err(WalDecodeError { - msg: format!("invalid xl_tot_len {}", xl_tot_len), + msg: format!("invalid xl_tot_len {xl_tot_len}"), lsn: self.lsn, }); } @@ -218,7 +218,7 @@ impl WalStreamDecoderHandler for WalStreamDecoder { let xlogrec = XLogRecord::from_slice(&recordbuf[0..XLOG_SIZE_OF_XLOG_RECORD]).map_err(|e| { WalDecodeError { - msg: format!("xlog record deserialization failed {}", e), + msg: format!("xlog record deserialization failed {e}"), lsn: self.lsn, } })?; diff --git a/libs/postgres_ffi/src/walrecord.rs b/libs/postgres_ffi/src/walrecord.rs index 1ccf4590a9..d593123dc0 100644 --- a/libs/postgres_ffi/src/walrecord.rs +++ b/libs/postgres_ffi/src/walrecord.rs @@ -9,8 +9,8 @@ use utils::bin_ser::DeserializeError; use utils::lsn::Lsn; use crate::{ - BLCKSZ, BlockNumber, MultiXactId, MultiXactOffset, MultiXactStatus, Oid, RepOriginId, - TimestampTz, TransactionId, XLOG_SIZE_OF_XLOG_RECORD, XLogRecord, pg_constants, + BLCKSZ, BlockNumber, MultiXactId, MultiXactOffset, MultiXactStatus, Oid, PgMajorVersion, + RepOriginId, TimestampTz, TransactionId, XLOG_SIZE_OF_XLOG_RECORD, XLogRecord, pg_constants, }; #[repr(C)] @@ -199,20 +196,17 @@ impl DecodedWALRecord { /// Check if this WAL record represents a legacy "copy" database creation, which populates new relations /// by reading other existing relations' data blocks. This is more complex to apply than new-style database /// creations which simply include all the desired blocks in the WAL, so we need a helper function to detect this case.
- pub fn is_dbase_create_copy(&self, pg_version: u32) -> bool { + pub fn is_dbase_create_copy(&self, pg_version: PgMajorVersion) -> bool { if self.xl_rmid == pg_constants::RM_DBASE_ID { let info = self.xl_info & pg_constants::XLR_RMGR_INFO_MASK; match pg_version { - 14 => { + PgMajorVersion::PG14 => { // Postgres 14 database creations are always the legacy kind info == crate::v14::bindings::XLOG_DBASE_CREATE } - 15 => info == crate::v15::bindings::XLOG_DBASE_CREATE_FILE_COPY, - 16 => info == crate::v16::bindings::XLOG_DBASE_CREATE_FILE_COPY, - 17 => info == crate::v17::bindings::XLOG_DBASE_CREATE_FILE_COPY, - _ => { - panic!("Unsupported postgres version {pg_version}") - } + PgMajorVersion::PG15 => info == crate::v15::bindings::XLOG_DBASE_CREATE_FILE_COPY, + PgMajorVersion::PG16 => info == crate::v16::bindings::XLOG_DBASE_CREATE_FILE_COPY, + PgMajorVersion::PG17 => info == crate::v17::bindings::XLOG_DBASE_CREATE_FILE_COPY, } } else { false @@ -248,7 +245,7 @@ pub fn decode_wal_record( record: Bytes, decoded: &mut DecodedWALRecord, - pg_version: u32, + pg_version: PgMajorVersion, ) -> anyhow::Result<()> { let mut rnode_spcnode: u32 = 0; let mut rnode_dbnode: u32 = 0; @@ -1106,9 +1103,9 @@ pub struct XlClogTruncate { } impl XlClogTruncate { - pub fn decode(buf: &mut Bytes, pg_version: u32) -> XlClogTruncate { + pub fn decode(buf: &mut Bytes, pg_version: PgMajorVersion) -> XlClogTruncate { XlClogTruncate { - pageno: if pg_version < 17 { + pageno: if pg_version < PgMajorVersion::PG17 { buf.get_u32_le() } else { buf.get_u64_le() as u32 @@ -1199,7 +1196,7 @@ pub fn describe_postgres_wal_record(record: &Bytes) -> Result<String, DeserializeError> { pg_constants::XLOG_HEAP2_MULTI_INSERT => "HEAP2 MULTI_INSERT", pg_constants::XLOG_HEAP2_VISIBLE => "HEAP2 VISIBLE", _ => { - unknown_str = format!("HEAP2 UNKNOWN_0x{:02x}", info); + unknown_str = format!("HEAP2 UNKNOWN_0x{info:02x}"); &unknown_str } } @@ -1212,7 +1209,7 @@ pub fn describe_postgres_wal_record(record: &Bytes) -> Result<String, DeserializeError> { pg_constants::XLOG_HEAP_UPDATE => "HEAP UPDATE", pg_constants::XLOG_HEAP_HOT_UPDATE => "HEAP HOT_UPDATE", _ => { - unknown_str = format!("HEAP2 UNKNOWN_0x{:02x}", info); + unknown_str = format!("HEAP2 UNKNOWN_0x{info:02x}"); &unknown_str } } @@ -1223,7 +1220,7 @@ pub fn describe_postgres_wal_record(record: &Bytes) -> Result<String, DeserializeError> { pg_constants::XLOG_FPI => "XLOG FPI", pg_constants::XLOG_FPI_FOR_HINT => "XLOG FPI_FOR_HINT", _ => { - unknown_str = format!("XLOG UNKNOWN_0x{:02x}", info); + unknown_str = format!("XLOG UNKNOWN_0x{info:02x}"); &unknown_str } } @@ -1231,7 +1228,7 @@ pub fn describe_postgres_wal_record(record: &Bytes) -> Result<String, DeserializeError> { _ => { let info = xlogrec.xl_info & pg_constants::XLR_RMGR_INFO_MASK; - unknown_str = format!("UNKNOWN_RM_{} INFO_0x{:02x}", rmid, info); + unknown_str = format!("UNKNOWN_RM_{rmid} INFO_0x{info:02x}"); &unknown_str } }; diff --git a/libs/postgres_ffi/src/xlog_utils.rs b/libs/postgres_ffi/src/xlog_utils.rs index 14fb1f2a1f..f7b6296053 100644 --- a/libs/postgres_ffi/src/xlog_utils.rs +++ b/libs/postgres_ffi/src/xlog_utils.rs @@ -11,9 +11,9 @@ use super::super::waldecoder::WalStreamDecoder; use super::bindings::{ CheckPoint, ControlFileData, DBState_DB_SHUTDOWNED, FullTransactionId, TimeLineID, TimestampTz, XLogLongPageHeaderData, XLogPageHeaderData, XLogRecPtr, XLogRecord, XLogSegNo, XLOG_PAGE_MAGIC, + MY_PGVERSION }; use super::wal_generator::LogicalMessageGenerator; -use super::PG_MAJORVERSION; use crate::pg_constants; use crate::PG_TLI; use crate::{uint32, uint64, Oid}; @@ -233,7 +233,7 @@ pub fn find_end_of_wal( let mut result = start_lsn; let mut curr_lsn = start_lsn; let mut buf = [0u8; XLOG_BLCKSZ]; -
let pg_version = PG_MAJORVERSION[1..3].parse::<u32>().unwrap(); + let pg_version = MY_PGVERSION; debug!("find_end_of_wal PG_VERSION: {}", pg_version); let mut decoder = WalStreamDecoder::new(start_lsn, pg_version); diff --git a/libs/postgres_ffi/wal_craft/src/bin/wal_craft.rs b/libs/postgres_ffi/wal_craft/src/bin/wal_craft.rs index 6151ce34ac..44bc4dfa95 100644 --- a/libs/postgres_ffi/wal_craft/src/bin/wal_craft.rs +++ b/libs/postgres_ffi/wal_craft/src/bin/wal_craft.rs @@ -4,6 +4,7 @@ use std::str::FromStr; use anyhow::*; use clap::{Arg, ArgMatches, Command, value_parser}; use postgres::Client; +use postgres_ffi::PgMajorVersion; use wal_craft::*; fn main() -> Result<()> { @@ -48,7 +49,7 @@ fn main() -> Result<()> { Some(("with-initdb", arg_matches)) => { let cfg = Conf { pg_version: *arg_matches - .get_one::<u32>("pg-version") + .get_one::<PgMajorVersion>("pg-version") .context("'pg-version' is required")?, pg_distrib_dir: arg_matches .get_one::<PathBuf>("pg-distrib-dir") diff --git a/libs/postgres_ffi/wal_craft/src/lib.rs b/libs/postgres_ffi/wal_craft/src/lib.rs index ca9530faef..ef9e854297 100644 --- a/libs/postgres_ffi/wal_craft/src/lib.rs +++ b/libs/postgres_ffi/wal_craft/src/lib.rs @@ -9,8 +9,8 @@ use log::*; use postgres::Client; use postgres::types::PgLsn; use postgres_ffi::{ - WAL_SEGMENT_SIZE, XLOG_BLCKSZ, XLOG_SIZE_OF_XLOG_LONG_PHD, XLOG_SIZE_OF_XLOG_RECORD, - XLOG_SIZE_OF_XLOG_SHORT_PHD, + PgMajorVersion, WAL_SEGMENT_SIZE, XLOG_BLCKSZ, XLOG_SIZE_OF_XLOG_LONG_PHD, + XLOG_SIZE_OF_XLOG_RECORD, XLOG_SIZE_OF_XLOG_SHORT_PHD, }; macro_rules! xlog_utils_test { @@ -29,7 +29,7 @@ macro_rules! xlog_utils_test { postgres_ffi::for_all_postgres_versions! { xlog_utils_test } pub struct Conf { - pub pg_version: u32, + pub pg_version: PgMajorVersion, pub pg_distrib_dir: PathBuf, pub datadir: PathBuf, } @@ -52,11 +52,7 @@ impl Conf { pub fn pg_distrib_dir(&self) -> anyhow::Result<PathBuf> { let path = self.pg_distrib_dir.clone(); - #[allow(clippy::manual_range_patterns)] - match self.pg_version { - 14 | 15 | 16 | 17 => Ok(path.join(format!("v{}", self.pg_version))), - _ => bail!("Unsupported postgres version: {}", self.pg_version), - } + Ok(path.join(self.pg_version.v_str())) } fn pg_bin_dir(&self) -> anyhow::Result<PathBuf> { diff --git a/libs/postgres_ffi/wal_craft/src/xlog_utils_test.rs b/libs/postgres_ffi/wal_craft/src/xlog_utils_test.rs index 4a33dbe25b..366aa7dbef 100644 --- a/libs/postgres_ffi/wal_craft/src/xlog_utils_test.rs +++ b/libs/postgres_ffi/wal_craft/src/xlog_utils_test.rs @@ -24,7 +24,7 @@ fn init_logging() { fn test_end_of_wal(test_name: &str) { use crate::*; - let pg_version = PG_MAJORVERSION[1..3].parse::<u32>().unwrap(); + let pg_version = MY_PGVERSION; // Craft some WAL let top_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) @@ -34,7 +34,7 @@ fn test_end_of_wal(test_name: &str) { let cfg = Conf { pg_version, pg_distrib_dir: top_path.join("pg_install"), - datadir: top_path.join(format!("test_output/{}-{PG_MAJORVERSION}", test_name)), + datadir: top_path.join(format!("test_output/{test_name}-{PG_MAJORVERSION}")), }; if cfg.datadir.exists() { fs::remove_dir_all(&cfg.datadir).unwrap(); } diff --git a/libs/postgres_ffi_types/Cargo.toml b/libs/postgres_ffi_types/Cargo.toml new file mode 100644 index 0000000000..50c6fc7874 --- /dev/null +++ b/libs/postgres_ffi_types/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "postgres_ffi_types" +version = "0.1.0" +edition.workspace = true +license.workspace = true + +[dependencies] +thiserror.workspace = true +workspace_hack = { version = "0.1", path = "../../workspace_hack" } + +[dev-dependencies]
diff --git a/libs/postgres_ffi_types/src/constants.rs b/libs/postgres_ffi_types/src/constants.rs new file mode 100644 index 0000000000..c1a004c5ab --- /dev/null +++ b/libs/postgres_ffi_types/src/constants.rs @@ -0,0 +1,8 @@ +//! Misc constants, copied from PostgreSQL headers. +//! +//! Any constants included here must be the same in all PostgreSQL versions and unlikely to change +//! in the future either! + +// From pg_tablespace_d.h +pub const DEFAULTTABLESPACE_OID: u32 = 1663; +pub const GLOBALTABLESPACE_OID: u32 = 1664; diff --git a/libs/postgres_ffi_types/src/forknum.rs b/libs/postgres_ffi_types/src/forknum.rs new file mode 100644 index 0000000000..9b225d8ce5 --- /dev/null +++ b/libs/postgres_ffi_types/src/forknum.rs @@ -0,0 +1,36 @@ +// Fork numbers, from relpath.h +pub const MAIN_FORKNUM: u8 = 0; +pub const FSM_FORKNUM: u8 = 1; +pub const VISIBILITYMAP_FORKNUM: u8 = 2; +pub const INIT_FORKNUM: u8 = 3; + +#[derive(Debug, Clone, thiserror::Error, PartialEq, Eq)] +pub enum FilePathError { + #[error("invalid relation fork name")] + InvalidForkName, + #[error("invalid relation data file name")] + InvalidFileName, +} + +/// Convert Postgres relation file's fork suffix to fork number. +pub fn forkname_to_number(forkname: Option<&str>) -> Result<u8, FilePathError> { + match forkname { + // "main" is not in filenames, it's implicit if the fork name is not present + None => Ok(MAIN_FORKNUM), + Some("fsm") => Ok(FSM_FORKNUM), + Some("vm") => Ok(VISIBILITYMAP_FORKNUM), + Some("init") => Ok(INIT_FORKNUM), + Some(_) => Err(FilePathError::InvalidForkName), + } +} + +/// Convert Postgres fork number to the right suffix of the relation data file. +pub fn forknumber_to_name(forknum: u8) -> Option<&'static str> { + match forknum { + MAIN_FORKNUM => None, + FSM_FORKNUM => Some("fsm"), + VISIBILITYMAP_FORKNUM => Some("vm"), + INIT_FORKNUM => Some("init"), + _ => Some("UNKNOWN FORKNUM"), + } +} diff --git a/libs/postgres_ffi_types/src/lib.rs b/libs/postgres_ffi_types/src/lib.rs new file mode 100644 index 0000000000..84ef499b9f --- /dev/null +++ b/libs/postgres_ffi_types/src/lib.rs @@ -0,0 +1,13 @@ +//! This package contains some PostgreSQL constants and datatypes that are the same in all versions +//! of PostgreSQL and unlikely to change in the future either. These could be derived from the +//! PostgreSQL headers with 'bindgen', but in order to avoid spreading the dependency on bindgen +//! and the PostgreSQL C headers to all services, we prefer to have this small stand-alone crate for +//! them instead. +//! +//! Be mindful about what you add here, as these types are deeply ingrained in the APIs.
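+//!
+//! A small usage sketch (paths are this crate's own modules):
+//!
+//! ```ignore
+//! use postgres_ffi_types::forknum::{MAIN_FORKNUM, forkname_to_number};
+//! assert_eq!(forkname_to_number(None)?, MAIN_FORKNUM);
+//! assert_eq!(forkname_to_number(Some("vm"))?, 2);
+//! ```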
+ +pub mod constants; +pub mod forknum; + +pub type Oid = u32; +pub type RepOriginId = u16; diff --git a/libs/postgres_initdb/Cargo.toml b/libs/postgres_initdb/Cargo.toml index 1605279bce..5b3b0cd936 100644 --- a/libs/postgres_initdb/Cargo.toml +++ b/libs/postgres_initdb/Cargo.toml @@ -9,4 +9,5 @@ anyhow.workspace = true tokio.workspace = true camino.workspace = true thiserror.workspace = true +postgres_versioninfo.workspace = true workspace_hack = { version = "0.1", path = "../../workspace_hack" } diff --git a/libs/postgres_initdb/src/lib.rs b/libs/postgres_initdb/src/lib.rs index ed54696861..a0c6ebef81 100644 --- a/libs/postgres_initdb/src/lib.rs +++ b/libs/postgres_initdb/src/lib.rs @@ -7,12 +7,13 @@ use std::fmt; use camino::Utf8Path; +use postgres_versioninfo::PgMajorVersion; pub struct RunInitdbArgs<'a> { pub superuser: &'a str, pub locale: &'a str, pub initdb_bin: &'a Utf8Path, - pub pg_version: u32, + pub pg_version: PgMajorVersion, pub library_search_path: &'a Utf8Path, pub pgdata: &'a Utf8Path, } @@ -31,15 +32,15 @@ pub enum Error { impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Error::Spawn(e) => write!(f, "Error spawning command: {:?}", e), + Error::Spawn(e) => write!(f, "Error spawning command: {e:?}"), Error::Failed { status, stderr } => write!( f, "Command failed with status {:?}: {}", status, String::from_utf8_lossy(stderr) ), - Error::WaitOutput(e) => write!(f, "Error waiting for command output: {:?}", e), - Error::Other(e) => write!(f, "Error: {:?}", e), + Error::WaitOutput(e) => write!(f, "Error waiting for command output: {e:?}"), + Error::Other(e) => write!(f, "Error: {e:?}"), } } } @@ -79,12 +80,16 @@ pub async fn do_run_initdb(args: RunInitdbArgs<'_>) -> Result<(), Error> { .stderr(std::process::Stdio::piped()); // Before version 14, only the libc provider was available. - if pg_version > 14 { + if pg_version > PgMajorVersion::PG14 { // Version 17 brought with it a builtin locale provider which only provides // C and C.UTF-8. While being safer for collation purposes since it is // guaranteed to be consistent throughout a major release, it is also more // performant. - let locale_provider = if pg_version >= 17 { "builtin" } else { "libc" }; + let locale_provider = if pg_version >= PgMajorVersion::PG17 { + "builtin" + } else { + "libc" + }; initdb_command.args(["--locale-provider", locale_provider]); } diff --git a/libs/postgres_versioninfo/Cargo.toml b/libs/postgres_versioninfo/Cargo.toml new file mode 100644 index 0000000000..cc59f9698d --- /dev/null +++ b/libs/postgres_versioninfo/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "postgres_versioninfo" +version = "0.1.0" +edition = "2024" +license.workspace = true + +[dependencies] +anyhow.workspace = true +thiserror.workspace = true +serde.workspace = true +serde_repr.workspace = true +workspace_hack = { version = "0.1", path = "../../workspace_hack" } diff --git a/libs/postgres_versioninfo/src/lib.rs b/libs/postgres_versioninfo/src/lib.rs new file mode 100644 index 0000000000..286507b654 --- /dev/null +++ b/libs/postgres_versioninfo/src/lib.rs @@ -0,0 +1,175 @@ +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use serde_repr::{Deserialize_repr, Serialize_repr}; +use std::fmt::{Display, Formatter}; +use std::str::FromStr; + +/// An enum with one variant for each major version of PostgreSQL that we support.
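+///
+/// A small usage sketch (hypothetical values):
+///
+/// ```ignore
+/// let v: PgMajorVersion = "16".parse()?;
+/// assert_eq!(v.major_version_num(), 16);
+/// assert_eq!(v.v_str(), "v16");
+/// assert_eq!(PgVersionId::from(v), PgVersionId::from_full_pg_version(160000));
+/// ```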
+/// +#[derive(Debug, Clone, Copy, Ord, PartialOrd, Eq, PartialEq, Deserialize_repr, Serialize_repr)] +#[repr(u32)] +pub enum PgMajorVersion { + PG14 = 14, + PG15 = 15, + PG16 = 16, + PG17 = 17, + // !!! When you add a new PgMajorVersion, don't forget to update PgMajorVersion::ALL +} + +/// A full PostgreSQL version ID, in MMmmbb numerical format (Major/minor/bugfix) +#[derive(Debug, Copy, Clone, Ord, PartialOrd, Eq, PartialEq)] +#[repr(transparent)] +pub struct PgVersionId(u32); + +impl PgVersionId { + pub const UNKNOWN: PgVersionId = PgVersionId(0); + + pub fn from_full_pg_version(version: u32) -> PgVersionId { + match version { + 0 => PgVersionId(version), // unknown version + 140000..180000 => PgVersionId(version), + _ => panic!("Invalid full PostgreSQL version ID {version}"), + } + } +} + +impl Display for PgVersionId { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + u32::fmt(&self.0, f) + } +} + +impl Serialize for PgVersionId { + fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> + where + S: Serializer, + { + u32::serialize(&self.0, serializer) + } +} + +impl<'de> Deserialize<'de> for PgVersionId { + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> + where + D: Deserializer<'de>, + { + u32::deserialize(deserializer).map(PgVersionId) + } + + fn deserialize_in_place<D>(deserializer: D, place: &mut Self) -> Result<(), D::Error> + where + D: Deserializer<'de>, + { + u32::deserialize_in_place(deserializer, &mut place.0) + } +} + +impl PgMajorVersion { + /// Get the numerical representation of the represented major version + pub const fn major_version_num(&self) -> u32 { + match self { + PgMajorVersion::PG14 => 14, + PgMajorVersion::PG15 => 15, + PgMajorVersion::PG16 => 16, + PgMajorVersion::PG17 => 17, + } + } + + /// Get the contents of this version's PG_VERSION file. + /// + /// The PG_VERSION file is used to determine the PostgreSQL version that currently + /// owns the data in a PostgreSQL data directory. + pub fn versionfile_string(&self) -> &'static str { + match self { + PgMajorVersion::PG14 => "14", + PgMajorVersion::PG15 => "15", + PgMajorVersion::PG16 => "16\x0A", + PgMajorVersion::PG17 => "17\x0A", + } + } + + /// Get the v{version} string of this major PostgreSQL version. + /// + /// Because this was hand-coded in various places, this was moved into a shared + /// implementation. + pub fn v_str(&self) -> String { + match self { + PgMajorVersion::PG14 => "v14", + PgMajorVersion::PG15 => "v15", + PgMajorVersion::PG16 => "v16", + PgMajorVersion::PG17 => "v17", + } + .to_string() + } + + /// All currently supported major versions of PostgreSQL.
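+    ///
+    /// E.g. a sketch for iterating every supported version:
+    /// `for v in PgMajorVersion::ALL { println!("{}", v.v_str()); }`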
+ pub const ALL: &'static [PgMajorVersion] = &[ + PgMajorVersion::PG14, + PgMajorVersion::PG15, + PgMajorVersion::PG16, + PgMajorVersion::PG17, + ]; +} + +impl Display for PgMajorVersion { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.write_str(match self { + PgMajorVersion::PG14 => "PgMajorVersion::PG14", + PgMajorVersion::PG15 => "PgMajorVersion::PG15", + PgMajorVersion::PG16 => "PgMajorVersion::PG16", + PgMajorVersion::PG17 => "PgMajorVersion::PG17", + }) + } +} + +#[derive(Debug, thiserror::Error)] +#[allow(dead_code)] +pub struct InvalidPgVersion(u32); + +impl Display for InvalidPgVersion { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "InvalidPgVersion({})", self.0) + } +} + +impl TryFrom<PgVersionId> for PgMajorVersion { + type Error = InvalidPgVersion; + + fn try_from(value: PgVersionId) -> Result<Self, Self::Error> { + Ok(match value.0 / 10000 { + 14 => PgMajorVersion::PG14, + 15 => PgMajorVersion::PG15, + 16 => PgMajorVersion::PG16, + 17 => PgMajorVersion::PG17, + _ => return Err(InvalidPgVersion(value.0)), + }) + } +} + +impl From<PgMajorVersion> for PgVersionId { + fn from(value: PgMajorVersion) -> Self { + PgVersionId((value as u32) * 10000) + } +} + +#[derive(Debug, PartialEq, Eq, thiserror::Error)] +pub struct PgMajorVersionParseError(String); + +impl Display for PgMajorVersionParseError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "PgMajorVersionParseError({})", self.0) + } +} + +impl FromStr for PgMajorVersion { + type Err = PgMajorVersionParseError; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + Ok(match s { + "14" => PgMajorVersion::PG14, + "15" => PgMajorVersion::PG15, + "16" => PgMajorVersion::PG16, + "17" => PgMajorVersion::PG17, + _ => return Err(PgMajorVersionParseError(s.to_string())), + }) + } +} diff --git a/libs/posthog_client_lite/src/background_loop.rs b/libs/posthog_client_lite/src/background_loop.rs index a404c76da9..08cb0d2264 100644 --- a/libs/posthog_client_lite/src/background_loop.rs +++ b/libs/posthog_client_lite/src/background_loop.rs @@ -1,17 +1,22 @@ //! A background loop that fetches feature flags from PostHog and updates the feature store. -use std::{sync::Arc, time::Duration}; +use std::{ + sync::Arc, + time::{Duration, SystemTime}, +}; use arc_swap::ArcSwap; use tokio_util::sync::CancellationToken; use tracing::{Instrument, info_span}; -use crate::{CaptureEvent, FeatureStore, PostHogClient, PostHogClientConfig}; +use crate::{ + CaptureEvent, FeatureStore, LocalEvaluationResponse, PostHogClient, PostHogClientConfig, +}; /// A background loop that fetches feature flags from PostHog and updates the feature store. pub struct FeatureResolverBackgroundLoop { posthog_client: PostHogClient, - feature_store: ArcSwap<FeatureStore>, + feature_store: ArcSwap<(SystemTime, Arc<FeatureStore>)>, cancel: CancellationToken, } @@ -19,11 +24,35 @@ impl FeatureResolverBackgroundLoop { pub fn new(config: PostHogClientConfig, shutdown_pageserver: CancellationToken) -> Self { Self { posthog_client: PostHogClient::new(config), - feature_store: ArcSwap::new(Arc::new(FeatureStore::new())), + feature_store: ArcSwap::new(Arc::new(( + SystemTime::UNIX_EPOCH, + Arc::new(FeatureStore::new()), + ))), cancel: shutdown_pageserver, } } + /// Update the feature store with a new feature flag spec bypassing the normal refresh loop.
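+    ///
+    /// Sketch of a hypothetical call site that received the spec JSON over HTTP:
+    /// `resolver.update(spec_json)?;` parses the body and atomically swaps in the
+    /// new `FeatureStore`.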
+ pub fn update(&self, spec: String) -> anyhow::Result<()> { + let resp: LocalEvaluationResponse = serde_json::from_str(&spec)?; + self.update_feature_store_nofail(resp, "http_propagate"); + Ok(()) + } + + fn update_feature_store_nofail(&self, resp: LocalEvaluationResponse, source: &'static str) { + let project_id = self.posthog_client.config.project_id.parse::<u64>().ok(); + match FeatureStore::new_with_flags(resp.flags, project_id) { + Ok(feature_store) => { + self.feature_store + .store(Arc::new((SystemTime::now(), Arc::new(feature_store)))); + tracing::info!("Feature flag updated from {}", source); + } + Err(e) => { + tracing::warn!("Cannot process feature flag spec from {}: {}", source, e); + } + } + } + pub fn spawn( self: Arc<Self>, handle: &tokio::runtime::Handle, @@ -36,7 +65,10 @@ impl FeatureResolverBackgroundLoop { // Main loop of updating the feature flags. handle.spawn( async move { - tracing::info!("Starting PostHog feature resolver"); + tracing::info!( + "Starting PostHog feature resolver with refresh period: {:?}", + refresh_period + ); let mut ticker = tokio::time::interval(refresh_period); ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); loop { @@ -44,6 +76,17 @@ impl FeatureResolverBackgroundLoop { _ = ticker.tick() => {} _ = cancel.cancelled() => break } + { + let last_update = this.feature_store.load().0; + if let Ok(elapsed) = last_update.elapsed() { + if elapsed < refresh_period { + tracing::debug!( + "Skipping feature flag refresh because it's too soon" + ); + continue; + } + } + } let resp = match this .posthog_client .get_feature_flags_local_evaluation() @@ -55,9 +98,7 @@ impl FeatureResolverBackgroundLoop { continue; } }; - let feature_store = FeatureStore::new_with_flags(resp.flags); - this.feature_store.store(Arc::new(feature_store)); - tracing::info!("Feature flag updated"); + this.update_feature_store_nofail(resp, "refresh_loop"); } tracing::info!("PostHog feature resolver stopped"); } @@ -82,6 +123,6 @@ impl FeatureResolverBackgroundLoop { } pub fn feature_store(&self) -> Arc<FeatureStore> { - self.feature_store.load_full() + self.feature_store.load().1.clone() } } diff --git a/libs/posthog_client_lite/src/lib.rs b/libs/posthog_client_lite/src/lib.rs index f607b1be0a..d042ee2410 100644 --- a/libs/posthog_client_lite/src/lib.rs +++ b/libs/posthog_client_lite/src/lib.rs @@ -39,6 +39,9 @@ pub struct LocalEvaluationResponse { #[derive(Deserialize)] pub struct LocalEvaluationFlag { + #[allow(dead_code)] + id: u64, + team_id: u64, key: String, filters: LocalEvaluationFlagFilters, active: bool, @@ -107,17 +110,32 @@ impl FeatureStore { } } - pub fn new_with_flags(flags: Vec<LocalEvaluationFlag>) -> Self { + pub fn new_with_flags( + flags: Vec<LocalEvaluationFlag>, + project_id: Option<u64>, + ) -> Result<Self, &'static str> { let mut store = Self::new(); - store.set_flags(flags); - store + store.set_flags(flags, project_id)?; + Ok(store) } - pub fn set_flags(&mut self, flags: Vec<LocalEvaluationFlag>) { + pub fn set_flags( + &mut self, + flags: Vec<LocalEvaluationFlag>, + project_id: Option<u64>, + ) -> Result<(), &'static str> { self.flags.clear(); for flag in flags { + if let Some(project_id) = project_id { + if flag.team_id != project_id { + return Err( + "Retrieved a spec with different project id, wrong config? Discarding the feature flags.", + ); + } + } self.flags.insert(flag.key.clone(), flag); } + Ok(()) } /// Generate a consistent hash for a user ID (e.g., tenant ID).
@@ -150,15 +168,13 @@ impl FeatureStore { let PostHogFlagFilterPropertyValue::String(provided) = provided else { // Left should be a string return Err(PostHogEvaluationError::Internal(format!( - "The left side of the condition is not a string: {:?}", - provided + "The left side of the condition is not a string: {provided:?}" ))); }; let PostHogFlagFilterPropertyValue::List(requested) = requested else { // Right should be a list of string return Err(PostHogEvaluationError::Internal(format!( - "The right side of the condition is not a list: {:?}", - requested + "The right side of the condition is not a list: {requested:?}" ))); }; Ok(requested.contains(provided)) @@ -167,14 +183,12 @@ impl FeatureStore { let PostHogFlagFilterPropertyValue::String(requested) = requested else { // Right should be a string return Err(PostHogEvaluationError::Internal(format!( - "The right side of the condition is not a string: {:?}", - requested + "The right side of the condition is not a string: {requested:?}" ))); }; let Ok(requested) = requested.parse::<f64>() else { return Err(PostHogEvaluationError::Internal(format!( - "Can not parse the right side of the condition as a number: {:?}", - requested + "Can not parse the right side of the condition as a number: {requested:?}" ))); }; // Left can either be a number or a string @@ -183,16 +197,14 @@ impl FeatureStore { PostHogFlagFilterPropertyValue::String(provided) => { let Ok(provided) = provided.parse::<f64>() else { return Err(PostHogEvaluationError::Internal(format!( - "Can not parse the left side of the condition as a number: {:?}", - provided + "Can not parse the left side of the condition as a number: {provided:?}" ))); }; provided } _ => { return Err(PostHogEvaluationError::Internal(format!( - "The left side of the condition is not a number or a string: {:?}", - provided + "The left side of the condition is not a number or a string: {provided:?}" ))); } }; @@ -200,14 +212,12 @@ impl FeatureStore { "lt" => Ok(provided < requested), "gt" => Ok(provided > requested), op => Err(PostHogEvaluationError::Internal(format!( - "Unsupported operator: {}", - op + "Unsupported operator: {op}" ))), } } _ => Err(PostHogEvaluationError::Internal(format!( - "Unsupported operator: {}", - operator + "Unsupported operator: {operator}" ))), } } @@ -355,8 +365,7 @@ impl FeatureStore { if let Some(flag_config) = self.flags.get(flag_key) { if !flag_config.active { return Err(PostHogEvaluationError::NotAvailable(format!( - "The feature flag is not active: {}", - flag_key + "The feature flag is not active: {flag_key}" ))); } let Some(ref multivariate) = flag_config.filters.multivariate else { @@ -383,8 +392,7 @@ impl FeatureStore { // This should not happen because the rollout percentage always adds up to 100, but just in case that PostHog // returned invalid spec, we return an error.
                         return Err(PostHogEvaluationError::Internal(format!(
-                            "Rollout percentage does not add up to 100: {}",
-                            flag_key
+                            "Rollout percentage does not add up to 100: {flag_key}"
                         )));
                     }
                     GroupEvaluationResult::Unmatched => continue,
@@ -395,8 +403,7 @@
         } else {
             // The feature flag is not available yet
             Err(PostHogEvaluationError::NotAvailable(format!(
-                "Not found in the local evaluation spec: {}",
-                flag_key
+                "Not found in the local evaluation spec: {flag_key}"
             )))
         }
     }
@@ -422,8 +429,7 @@
         if let Some(flag_config) = self.flags.get(flag_key) {
             if !flag_config.active {
                 return Err(PostHogEvaluationError::NotAvailable(format!(
-                    "The feature flag is not active: {}",
-                    flag_key
+                    "The feature flag is not active: {flag_key}"
                 )));
             }
             if flag_config.filters.multivariate.is_some() {
@@ -438,8 +444,7 @@
             match self.evaluate_group(group, hash_on_global_rollout_percentage, properties)? {
                 GroupEvaluationResult::MatchedAndOverride(_) => {
                     return Err(PostHogEvaluationError::Internal(format!(
-                        "Boolean flag cannot have overrides: {}",
-                        flag_key
+                        "Boolean flag cannot have overrides: {flag_key}"
                     )));
                 }
                 GroupEvaluationResult::MatchedAndEvaluate => {
@@ -453,8 +458,7 @@
         } else {
             // The feature flag is not available yet
             Err(PostHogEvaluationError::NotAvailable(format!(
-                "Not found in the local evaluation spec: {}",
-                flag_key
+                "Not found in the local evaluation spec: {flag_key}"
             )))
         }
     }
@@ -465,8 +469,7 @@
             Ok(flag_config.filters.multivariate.is_none())
         } else {
             Err(PostHogEvaluationError::NotAvailable(format!(
-                "Not found in the local evaluation spec: {}",
-                flag_key
+                "Not found in the local evaluation spec: {flag_key}"
             )))
         }
     }
@@ -534,23 +537,33 @@ impl PostHogClient {
         })
     }
 
-    /// Fetch the feature flag specs from the server.
-    ///
-    /// This is unfortunately an undocumented API at:
-    /// -
-    /// -
-    ///
-    /// The handling logic in [`FeatureStore`] mostly follows the Python API implementation.
-    /// See `_compute_flag_locally` in
-    pub async fn get_feature_flags_local_evaluation(
-        &self,
-    ) -> anyhow::Result<LocalEvaluationResponse> {
+    /// Check if the server API key is a feature flag secure API key. This key can only be
+    /// used to fetch the feature flag specs and can only be used on an undocumented API
+    /// endpoint.
+    fn is_feature_flag_secure_api_key(&self) -> bool {
+        self.config.server_api_key.starts_with("phs_")
+    }
+
+    /// Get the raw JSON spec, same as `get_feature_flags_local_evaluation` but without parsing.
+    pub async fn get_feature_flags_local_evaluation_raw(&self) -> anyhow::Result<String> {
         // BASE_URL/api/projects/:project_id/feature_flags/local_evaluation
         // with bearer token of self.server_api_key
-        let url = format!(
-            "{}/api/projects/{}/feature_flags/local_evaluation",
-            self.config.private_api_url, self.config.project_id
-        );
+        // OR
+        // BASE_URL/api/feature_flag/local_evaluation/
+        // with bearer token of feature flag specific self.server_api_key
+        let url = if self.is_feature_flag_secure_api_key() {
+            // The new feature local evaluation secure API token
+            format!(
+                "{}/api/feature_flag/local_evaluation",
+                self.config.private_api_url
+            )
+        } else {
+            // The old personal API token
+            format!(
+                "{}/api/projects/{}/feature_flags/local_evaluation",
+                self.config.private_api_url, self.config.project_id
+            )
+        };
         let response = self
             .client
             .get(url)
@@ -566,7 +579,22 @@ impl PostHogClient {
                 body
             ));
         }
-        Ok(serde_json::from_str(&body)?)
+        Ok(body)
+    }
+
+    /// Fetch the feature flag specs from the server.
+    ///
+    /// This is unfortunately an undocumented API at:
+    /// -
+    /// -
+    ///
+    /// The handling logic in [`FeatureStore`] mostly follows the Python API implementation.
+    /// See `_compute_flag_locally` in
+    pub async fn get_feature_flags_local_evaluation(
+        &self,
+    ) -> Result<LocalEvaluationResponse, anyhow::Error> {
+        let raw = self.get_feature_flags_local_evaluation_raw().await?;
+        Ok(serde_json::from_str(&raw)?)
     }
 
     /// Capture an event. This will only be used to report the feature flag usage back to PostHog, though
@@ -803,7 +831,7 @@ mod tests {
     fn evaluate_multivariate() {
         let mut store = FeatureStore::new();
         let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap();
-        store.set_flags(response.flags);
+        store.set_flags(response.flags, None).unwrap();
 
         // This lacks the required properties and cannot be evaluated.
         let variant =
@@ -873,7 +901,7 @@ mod tests {
 
         let mut store = FeatureStore::new();
         let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap();
-        store.set_flags(response.flags);
+        store.set_flags(response.flags, None).unwrap();
 
         // This lacks the required properties and cannot be evaluated.
         let variant = store.evaluate_boolean_inner("boolean-flag", 1.00, &HashMap::new());
@@ -929,7 +957,7 @@ mod tests {
 
         let mut store = FeatureStore::new();
         let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap();
-        store.set_flags(response.flags);
+        store.set_flags(response.flags, None).unwrap();
 
         // This lacks the required properties and cannot be evaluated.
         let variant =
diff --git a/libs/pq_proto/src/lib.rs b/libs/pq_proto/src/lib.rs
index e7afc64564..482dd9a298 100644
--- a/libs/pq_proto/src/lib.rs
+++ b/libs/pq_proto/src/lib.rs
@@ -198,7 +198,7 @@ impl fmt::Display for CancelKeyData {
         // This format is more compact and might work better for logs.
f.debug_tuple("CancelKeyData") - .field(&format_args!("{:x}", id)) + .field(&format_args!("{id:x}")) .finish() } } @@ -291,8 +291,7 @@ impl FeMessage { let len = (&buf[1..5]).read_u32::().unwrap(); if len < 4 { return Err(ProtocolError::Protocol(format!( - "invalid message length {}", - len + "invalid message length {len}" ))); } @@ -367,8 +366,7 @@ impl FeStartupPacket { #[allow(clippy::manual_range_contains)] if len < 8 || len > MAX_STARTUP_PACKET_LENGTH { return Err(ProtocolError::Protocol(format!( - "invalid startup packet message length {}", - len + "invalid startup packet message length {len}" ))); } diff --git a/libs/proxy/postgres-protocol2/Cargo.toml b/libs/proxy/postgres-protocol2/Cargo.toml index 7ebb05eec1..9c8f8f3531 100644 --- a/libs/proxy/postgres-protocol2/Cargo.toml +++ b/libs/proxy/postgres-protocol2/Cargo.toml @@ -5,7 +5,7 @@ edition = "2024" license = "MIT/Apache-2.0" [dependencies] -base64 = "0.20" +base64.workspace = true byteorder.workspace = true bytes.workspace = true fallible-iterator.workspace = true diff --git a/libs/proxy/postgres-protocol2/src/authentication/sasl.rs b/libs/proxy/postgres-protocol2/src/authentication/sasl.rs index 2daf9a80d4..b8304f9d8d 100644 --- a/libs/proxy/postgres-protocol2/src/authentication/sasl.rs +++ b/libs/proxy/postgres-protocol2/src/authentication/sasl.rs @@ -3,6 +3,8 @@ use std::fmt::Write; use std::{io, iter, mem, str}; +use base64::Engine as _; +use base64::prelude::BASE64_STANDARD; use hmac::{Hmac, Mac}; use rand::{self, Rng}; use sha2::digest::FixedOutput; @@ -226,7 +228,7 @@ impl ScramSha256 { let (client_key, server_key) = match password { Credentials::Password(password) => { - let salt = match base64::decode(parsed.salt) { + let salt = match BASE64_STANDARD.decode(parsed.salt) { Ok(salt) => salt, Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), }; @@ -255,7 +257,7 @@ impl ScramSha256 { let mut cbind_input = vec![]; cbind_input.extend(channel_binding.gs2_header().as_bytes()); cbind_input.extend(channel_binding.cbind_data()); - let cbind_input = base64::encode(&cbind_input); + let cbind_input = BASE64_STANDARD.encode(&cbind_input); self.message.clear(); write!(&mut self.message, "c={},r={}", cbind_input, parsed.nonce).unwrap(); @@ -272,7 +274,12 @@ impl ScramSha256 { *proof ^= signature; } - write!(&mut self.message, ",p={}", base64::encode(client_proof)).unwrap(); + write!( + &mut self.message, + ",p={}", + BASE64_STANDARD.encode(client_proof) + ) + .unwrap(); self.state = State::Finish { server_key, @@ -301,12 +308,12 @@ impl ScramSha256 { let verifier = match parsed { ServerFinalMessage::Error(e) => { - return Err(io::Error::other(format!("SCRAM error: {}", e))); + return Err(io::Error::other(format!("SCRAM error: {e}"))); } ServerFinalMessage::Verifier(verifier) => verifier, }; - let verifier = match base64::decode(verifier) { + let verifier = match BASE64_STANDARD.decode(verifier) { Ok(verifier) => verifier, Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidInput, e)), }; @@ -336,10 +343,8 @@ impl<'a> Parser<'a> { match self.it.next() { Some((_, c)) if c == target => Ok(()), Some((i, c)) => { - let m = format!( - "unexpected character at byte {}: expected `{}` but got `{}", - i, target, c - ); + let m = + format!("unexpected character at byte {i}: expected `{target}` but got `{c}"); Err(io::Error::new(io::ErrorKind::InvalidInput, m)) } None => Err(io::Error::new( @@ -405,7 +410,7 @@ impl<'a> Parser<'a> { match self.it.peek() { Some(&(i, _)) => Err(io::Error::new( io::ErrorKind::InvalidInput, - 
format!("unexpected trailing data at byte {}", i), + format!("unexpected trailing data at byte {i}"), )), None => Ok(()), } diff --git a/libs/proxy/postgres-protocol2/src/message/backend.rs b/libs/proxy/postgres-protocol2/src/message/backend.rs index d7eaef9509..3fc9a9335c 100644 --- a/libs/proxy/postgres-protocol2/src/message/backend.rs +++ b/libs/proxy/postgres-protocol2/src/message/backend.rs @@ -211,7 +211,7 @@ impl Message { tag => { return Err(io::Error::new( io::ErrorKind::InvalidInput, - format!("unknown authentication tag `{}`", tag), + format!("unknown authentication tag `{tag}`"), )); } }, @@ -238,7 +238,7 @@ impl Message { tag => { return Err(io::Error::new( io::ErrorKind::InvalidInput, - format!("unknown message tag `{}`", tag), + format!("unknown message tag `{tag}`"), )); } }; diff --git a/libs/proxy/postgres-protocol2/src/password/mod.rs b/libs/proxy/postgres-protocol2/src/password/mod.rs index 4cd9bfb060..e00ca1e34c 100644 --- a/libs/proxy/postgres-protocol2/src/password/mod.rs +++ b/libs/proxy/postgres-protocol2/src/password/mod.rs @@ -6,6 +6,8 @@ //! side. This is good because it ensures the cleartext password won't //! end up in logs pg_stat displays, etc. +use base64::Engine as _; +use base64::prelude::BASE64_STANDARD; use hmac::{Hmac, Mac}; use rand::RngCore; use sha2::digest::FixedOutput; @@ -83,8 +85,8 @@ pub(crate) async fn scram_sha_256_salt( format!( "SCRAM-SHA-256${}:{}${}:{}", SCRAM_DEFAULT_ITERATIONS, - base64::encode(salt), - base64::encode(stored_key), - base64::encode(server_key) + BASE64_STANDARD.encode(salt), + BASE64_STANDARD.encode(stored_key), + BASE64_STANDARD.encode(server_key) ) } diff --git a/libs/proxy/postgres-types2/src/lib.rs b/libs/proxy/postgres-types2/src/lib.rs index 7c9874bda3..c98c45636b 100644 --- a/libs/proxy/postgres-types2/src/lib.rs +++ b/libs/proxy/postgres-types2/src/lib.rs @@ -46,7 +46,7 @@ impl fmt::Display for Type { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match self.schema() { "public" | "pg_catalog" => {} - schema => write!(fmt, "{}.", schema)?, + schema => write!(fmt, "{schema}.")?, } fmt.write_str(self.name()) } diff --git a/libs/proxy/tokio-postgres2/src/cancel_query.rs b/libs/proxy/tokio-postgres2/src/cancel_query.rs index 0bdad0b554..94fbf333ed 100644 --- a/libs/proxy/tokio-postgres2/src/cancel_query.rs +++ b/libs/proxy/tokio-postgres2/src/cancel_query.rs @@ -1,5 +1,3 @@ -use std::io; - use tokio::net::TcpStream; use crate::client::SocketConfig; @@ -8,25 +6,15 @@ use crate::tls::MakeTlsConnect; use crate::{Error, cancel_query_raw, connect_socket}; pub(crate) async fn cancel_query( - config: Option, + config: SocketConfig, ssl_mode: SslMode, - mut tls: T, + tls: T, process_id: i32, secret_key: i32, ) -> Result<(), Error> where T: MakeTlsConnect, { - let config = match config { - Some(config) => config, - None => { - return Err(Error::connect(io::Error::new( - io::ErrorKind::InvalidInput, - "unknown host", - ))); - } - }; - let hostname = match &config.host { Host::Tcp(host) => &**host, }; diff --git a/libs/proxy/tokio-postgres2/src/cancel_token.rs b/libs/proxy/tokio-postgres2/src/cancel_token.rs index f6526395ee..c5566b4ad9 100644 --- a/libs/proxy/tokio-postgres2/src/cancel_token.rs +++ b/libs/proxy/tokio-postgres2/src/cancel_token.rs @@ -7,11 +7,16 @@ use crate::config::SslMode; use crate::tls::{MakeTlsConnect, TlsConnect}; use crate::{Error, cancel_query, cancel_query_raw}; -/// The capability to request cancellation of in-progress queries on a -/// connection. 
-#[derive(Clone, Serialize, Deserialize)] +/// A cancellation token that allows easy cancellation of a query. +#[derive(Clone)] pub struct CancelToken { - pub socket_config: Option, + pub socket_config: SocketConfig, + pub raw: RawCancelToken, +} + +/// A raw cancellation token that allows cancellation of a query, given a fresh connection to postgres. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RawCancelToken { pub ssl_mode: SslMode, pub process_id: i32, pub secret_key: i32, @@ -36,14 +41,16 @@ impl CancelToken { { cancel_query::cancel_query( self.socket_config.clone(), - self.ssl_mode, + self.raw.ssl_mode, tls, - self.process_id, - self.secret_key, + self.raw.process_id, + self.raw.secret_key, ) .await } +} +impl RawCancelToken { /// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new /// connection itself. pub async fn cancel_query_raw(&self, stream: S, tls: T) -> Result<(), Error> diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs index a7edfc076a..41b22e35b6 100644 --- a/libs/proxy/tokio-postgres2/src/client.rs +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -12,6 +12,7 @@ use postgres_protocol2::message::frontend; use serde::{Deserialize, Serialize}; use tokio::sync::mpsc; +use crate::cancel_token::RawCancelToken; use crate::codec::{BackendMessages, FrontendMessage}; use crate::config::{Host, SslMode}; use crate::query::RowStream; @@ -331,10 +332,12 @@ impl Client { /// connection associated with this client. pub fn cancel_token(&self) -> CancelToken { CancelToken { - socket_config: Some(self.socket_config.clone()), - ssl_mode: self.ssl_mode, - process_id: self.process_id, - secret_key: self.secret_key, + socket_config: self.socket_config.clone(), + raw: RawCancelToken { + ssl_mode: self.ssl_mode, + process_id: self.process_id, + secret_key: self.secret_key, + }, } } diff --git a/libs/proxy/tokio-postgres2/src/config.rs b/libs/proxy/tokio-postgres2/src/config.rs index 978d348741..961cbc923e 100644 --- a/libs/proxy/tokio-postgres2/src/config.rs +++ b/libs/proxy/tokio-postgres2/src/config.rs @@ -12,12 +12,13 @@ use tokio::net::TcpStream; use crate::connect::connect; use crate::connect_raw::{RawConnection, connect_raw}; -use crate::tls::{MakeTlsConnect, TlsConnect}; +use crate::connect_tls::connect_tls; +use crate::maybe_tls_stream::MaybeTlsStream; +use crate::tls::{MakeTlsConnect, TlsConnect, TlsStream}; use crate::{Client, Connection, Error}; /// TLS configuration. #[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] -#[non_exhaustive] pub enum SslMode { /// Do not use TLS. Disable, @@ -231,7 +232,7 @@ impl Config { /// Requires the `runtime` Cargo feature (enabled by default). 
     pub async fn connect<T>(
         &self,
-        tls: T,
+        tls: &T,
     ) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
     where
         T: MakeTlsConnect<TcpStream>,
@@ -239,7 +240,7 @@ impl Config {
         connect(tls, self).await
     }
 
-    pub async fn connect_raw<S, T>(
+    pub async fn tls_and_authenticate<S, T>(
         &self,
         stream: S,
         tls: T,
@@ -248,7 +249,19 @@ impl Config {
         S: AsyncRead + AsyncWrite + Unpin,
         T: TlsConnect<S>,
     {
-        connect_raw(stream, tls, self).await
+        let stream = connect_tls(stream, self.ssl_mode, tls).await?;
+        connect_raw(stream, self).await
+    }
+
+    pub async fn authenticate<S, T>(
+        &self,
+        stream: MaybeTlsStream<S, T>,
+    ) -> Result<RawConnection<S, T>, Error>
+    where
+        S: AsyncRead + AsyncWrite + Unpin,
+        T: TlsStream + Unpin,
+    {
+        connect_raw(stream, self).await
     }
 }
diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs
index 39a0a87c74..4a07eccf9a 100644
--- a/libs/proxy/tokio-postgres2/src/connect.rs
+++ b/libs/proxy/tokio-postgres2/src/connect.rs
@@ -9,11 +9,12 @@ use crate::codec::BackendMessage;
 use crate::config::Host;
 use crate::connect_raw::connect_raw;
 use crate::connect_socket::connect_socket;
+use crate::connect_tls::connect_tls;
 use crate::tls::{MakeTlsConnect, TlsConnect};
 use crate::{Client, Config, Connection, Error, RawConnection};
 
 pub async fn connect<T>(
-    mut tls: T,
+    tls: &T,
     config: &Config,
 ) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
 where
     T: MakeTlsConnect<TcpStream>,
@@ -44,13 +45,14 @@ where
     T: TlsConnect<TcpStream>,
 {
     let socket = connect_socket(host_addr, host, port, config.connect_timeout).await?;
+    let stream = connect_tls(socket, config.ssl_mode, tls).await?;
 
     let RawConnection {
         stream,
         parameters,
         delayed_notice,
         process_id,
         secret_key,
-    } = connect_raw(socket, tls, config).await?;
+    } = connect_raw(stream, config).await?;
 
     let socket_config = SocketConfig {
         host_addr,
diff --git a/libs/proxy/tokio-postgres2/src/connect_raw.rs b/libs/proxy/tokio-postgres2/src/connect_raw.rs
index 20dc538cf2..b89a600a2e 100644
--- a/libs/proxy/tokio-postgres2/src/connect_raw.rs
+++ b/libs/proxy/tokio-postgres2/src/connect_raw.rs
@@ -16,9 +16,8 @@ use tokio_util::codec::Framed;
 use crate::Error;
 use crate::codec::{BackendMessage, BackendMessages, FrontendMessage, PostgresCodec};
 use crate::config::{self, AuthKeys, Config};
-use crate::connect_tls::connect_tls;
 use crate::maybe_tls_stream::MaybeTlsStream;
-use crate::tls::{TlsConnect, TlsStream};
+use crate::tls::TlsStream;
 
 pub struct StartupStream<S, T> {
     inner: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
@@ -87,16 +86,13 @@ pub struct RawConnection<S, T> {
 }
 
 pub async fn connect_raw<S, T>(
-    stream: S,
-    tls: T,
+    stream: MaybeTlsStream<S, T>,
     config: &Config,
-) -> Result<RawConnection<S, T::Stream>, Error>
+) -> Result<RawConnection<S, T>, Error>
 where
     S: AsyncRead + AsyncWrite + Unpin,
-    T: TlsConnect<S>,
+    T: TlsStream + Unpin,
 {
-    let stream = connect_tls(stream, config.ssl_mode, tls).await?;
-
     let mut stream = StartupStream {
         inner: Framed::new(stream, PostgresCodec),
         buf: BackendMessages::empty(),
diff --git a/libs/proxy/tokio-postgres2/src/error/mod.rs b/libs/proxy/tokio-postgres2/src/error/mod.rs
index 8149bceeb9..5309bce17e 100644
--- a/libs/proxy/tokio-postgres2/src/error/mod.rs
+++ b/libs/proxy/tokio-postgres2/src/error/mod.rs
@@ -332,10 +332,10 @@ impl fmt::Display for DbError {
     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
         write!(fmt, "{}: {}", self.severity, self.message)?;
         if let Some(detail) = &self.detail {
-            write!(fmt, "\nDETAIL: {}", detail)?;
+            write!(fmt, "\nDETAIL: {detail}")?;
         }
         if let Some(hint) = &self.hint {
-            write!(fmt, "\nHINT: {}", hint)?;
+            write!(fmt, "\nHINT: {hint}")?;
         }
         Ok(())
     }
@@ -398,9 +398,9 @@ impl fmt::Display for Error {
             Kind::Io => fmt.write_str("error communicating with the server")?,
             Kind::UnexpectedMessage => fmt.write_str("unexpected message from server")?,
             Kind::Tls => fmt.write_str("error performing TLS handshake")?,
-            Kind::ToSql(idx) => write!(fmt, "error serializing parameter {}", idx)?,
-            Kind::FromSql(idx) => write!(fmt, "error deserializing column {}", idx)?,
-            Kind::Column(column) => write!(fmt, "invalid column `{}`", column)?,
+            Kind::ToSql(idx) => write!(fmt, "error serializing parameter {idx}")?,
+            Kind::FromSql(idx) => write!(fmt, "error deserializing column {idx}")?,
+            Kind::Column(column) => write!(fmt, "invalid column `{column}`")?,
             Kind::Closed => fmt.write_str("connection closed")?,
             Kind::Db => fmt.write_str("db error")?,
             Kind::Parse => fmt.write_str("error parsing response from server")?,
@@ -411,7 +411,7 @@ impl fmt::Display for Error {
             Kind::Timeout => fmt.write_str("timeout waiting for server")?,
         };
         if let Some(ref cause) = self.0.cause {
-            write!(fmt, ": {}", cause)?;
+            write!(fmt, ": {cause}")?;
         }
         Ok(())
     }
diff --git a/libs/proxy/tokio-postgres2/src/lib.rs b/libs/proxy/tokio-postgres2/src/lib.rs
index 9556070ed5..791c93b972 100644
--- a/libs/proxy/tokio-postgres2/src/lib.rs
+++ b/libs/proxy/tokio-postgres2/src/lib.rs
@@ -3,7 +3,7 @@
 
 use postgres_protocol2::message::backend::ReadyForQueryBody;
 
-pub use crate::cancel_token::CancelToken;
+pub use crate::cancel_token::{CancelToken, RawCancelToken};
 pub use crate::client::{Client, SocketConfig};
 pub use crate::config::Config;
 pub use crate::connect_raw::RawConnection;
diff --git a/libs/proxy/tokio-postgres2/src/row.rs b/libs/proxy/tokio-postgres2/src/row.rs
index 5fc955eef4..36d578558f 100644
--- a/libs/proxy/tokio-postgres2/src/row.rs
+++ b/libs/proxy/tokio-postgres2/src/row.rs
@@ -156,7 +156,7 @@ impl Row {
     {
         match self.get_inner(&idx) {
             Ok(ok) => ok,
-            Err(err) => panic!("error retrieving column {}: {}", idx, err),
+            Err(err) => panic!("error retrieving column {idx}: {err}"),
         }
     }
 
@@ -274,7 +274,7 @@ impl SimpleQueryRow {
     {
         match self.get_inner(&idx) {
             Ok(ok) => ok,
-            Err(err) => panic!("error retrieving column {}: {}", idx, err),
+            Err(err) => panic!("error retrieving column {idx}: {err}"),
         }
     }
 
diff --git a/libs/proxy/tokio-postgres2/src/tls.rs b/libs/proxy/tokio-postgres2/src/tls.rs
index 41b51368ff..f9cbcf4991 100644
--- a/libs/proxy/tokio-postgres2/src/tls.rs
+++ b/libs/proxy/tokio-postgres2/src/tls.rs
@@ -47,7 +47,7 @@ pub trait MakeTlsConnect<S> {
     /// Creates a new `TlsConnect`or.
     ///
     /// The domain name is provided for certificate verification and SNI.
-    fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
+    fn make_tls_connect(&self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
 }
 
 /// An asynchronous function wrapping a stream in a TLS session.
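A note on the `&mut self` → `&self` change in `make_tls_connect` just above: a connector that only needs a shared reference can live behind an `Arc` and serve many concurrent connection attempts without a mutex, which fits the borrowed `tls: &T` now taken by `Config::connect`. A small sketch of the sharing pattern this enables; `MakeConnector` is a hypothetical stand-in trait, not the real `MakeTlsConnect`:

```rust
// Sketch of why `&self` matters: one connector shared across tasks
// without a Mutex. `MakeConnector` is a hypothetical stand-in trait.
use std::sync::Arc;
use std::thread;

trait MakeConnector: Send + Sync {
    // Taking `&self` (not `&mut self`) means callers only need a shared
    // reference, so an Arc<dyn MakeConnector> can be cloned into each task.
    fn make(&self, domain: &str) -> String;
}

struct StaticConnector;

impl MakeConnector for StaticConnector {
    fn make(&self, domain: &str) -> String {
        format!("connector for {domain}")
    }
}

fn main() {
    let connector: Arc<dyn MakeConnector> = Arc::new(StaticConnector);
    let handles: Vec<_> = (0..4)
        .map(|i| {
            let c = Arc::clone(&connector);
            thread::spawn(move || c.make(&format!("host{i}.example")))
        })
        .collect();
    for h in handles {
        println!("{}", h.join().unwrap());
    }
}
```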
@@ -85,7 +85,7 @@ impl<S> MakeTlsConnect<S> for NoTls {
     type TlsConnect = NoTls;
     type Error = NoTlsError;
 
-    fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
+    fn make_tls_connect(&self, _: &str) -> Result<NoTls, NoTlsError> {
         Ok(NoTls)
     }
 }
diff --git a/libs/remote_storage/src/azure_blob.rs b/libs/remote_storage/src/azure_blob.rs
index 5363e935e3..e9c24ac723 100644
--- a/libs/remote_storage/src/azure_blob.rs
+++ b/libs/remote_storage/src/azure_blob.rs
@@ -10,7 +10,7 @@ use std::sync::Arc;
 use std::time::{Duration, SystemTime};
 use std::{env, io};
 
-use anyhow::{Context, Result};
+use anyhow::{Context, Result, anyhow};
 use azure_core::request_options::{IfMatchCondition, MaxResults, Metadata, Range};
 use azure_core::{Continuable, HttpClient, RetryOptions, TransportOptions};
 use azure_storage::StorageCredentials;
@@ -37,6 +37,7 @@ use crate::metrics::{AttemptOutcome, RequestKind, start_measuring_requests};
 use crate::{
     ConcurrencyLimiter, Download, DownloadError, DownloadKind, DownloadOpts, Listing, ListingMode,
     ListingObject, RemotePath, RemoteStorage, StorageMetadata, TimeTravelError, TimeoutOrCancel,
+    Version, VersionKind,
 };
 
 pub struct AzureBlobStorage {
@@ -405,6 +406,39 @@ impl AzureBlobStorage {
     pub fn container_name(&self) -> &str {
         &self.container_name
     }
+
+    async fn list_versions_with_permit(
+        &self,
+        _permit: &tokio::sync::SemaphorePermit<'_>,
+        prefix: Option<&RemotePath>,
+        mode: ListingMode,
+        max_keys: Option<NonZeroU32>,
+        cancel: &CancellationToken,
+    ) -> Result<crate::VersionListing, DownloadError> {
+        let customize_builder = |mut builder: ListBlobsBuilder| {
+            builder = builder.include_versions(true);
+            // We do not return this info back to `VersionListing` yet.
+            builder = builder.include_deleted(true);
+            builder
+        };
+        let kind = RequestKind::ListVersions;
+
+        let mut stream = std::pin::pin!(self.list_streaming_for_fn(
+            prefix,
+            mode,
+            max_keys,
+            cancel,
+            kind,
+            customize_builder
+        ));
+        let mut combined: crate::VersionListing =
+            stream.next().await.expect("At least one item required")?;
+        while let Some(list) = stream.next().await {
+            let list = list?;
+            combined.versions.extend(list.versions.into_iter());
+        }
+        Ok(combined)
+    }
 }
 
 trait ListingCollector {
@@ -488,27 +522,10 @@ impl RemoteStorage for AzureBlobStorage {
         max_keys: Option<NonZeroU32>,
         cancel: &CancellationToken,
     ) -> std::result::Result<crate::VersionListing, DownloadError> {
-        let customize_builder = |mut builder: ListBlobsBuilder| {
-            builder = builder.include_versions(true);
-            builder
-        };
         let kind = RequestKind::ListVersions;
-
-        let mut stream = std::pin::pin!(self.list_streaming_for_fn(
-            prefix,
-            mode,
-            max_keys,
-            cancel,
-            kind,
-            customize_builder
-        ));
-        let mut combined: crate::VersionListing =
-            stream.next().await.expect("At least one item required")?;
-        while let Some(list) = stream.next().await {
-            let list = list?;
-            combined.versions.extend(list.versions.into_iter());
-        }
-        Ok(combined)
+        let permit = self.permit(kind, cancel).await?;
+        self.list_versions_with_permit(&permit, prefix, mode, max_keys, cancel)
+            .await
     }
 
     async fn head_object(
@@ -803,14 +820,159 @@ impl RemoteStorage for AzureBlobStorage {
 
     async fn time_travel_recover(
         &self,
-        _prefix: Option<&RemotePath>,
-        _timestamp: SystemTime,
-        _done_if_after: SystemTime,
-        _cancel: &CancellationToken,
+        prefix: Option<&RemotePath>,
+        timestamp: SystemTime,
+        done_if_after: SystemTime,
+        cancel: &CancellationToken,
+        _complexity_limit: Option<NonZeroU32>,
     ) -> Result<(), TimeTravelError> {
-        // TODO use Azure point in time recovery feature for this
-        // https://learn.microsoft.com/en-us/azure/storage/blobs/point-in-time-restore-overview
-        Err(TimeTravelError::Unimplemented)
+        let msg = "PLEASE NOTE: Azure Blob storage time-travel recovery may not work as expected "
+            .to_string()
+            + "for some specific files. If a file gets deleted but then overwritten and we want to recover "
+            + "to a time when the file was not present, this functionality will recover the file. Only "
+            + "use the functionality for services that can tolerate this. For example, recovering a state of the "
+            + "pageserver tenants.";
+        tracing::error!("{}", msg);
+
+        let kind = RequestKind::TimeTravel;
+        let permit = self.permit(kind, cancel).await?;
+
+        let mode = ListingMode::NoDelimiter;
+        let version_listing = self
+            .list_versions_with_permit(&permit, prefix, mode, None, cancel)
+            .await
+            .map_err(|err| match err {
+                DownloadError::Other(e) => TimeTravelError::Other(e),
+                DownloadError::Cancelled => TimeTravelError::Cancelled,
+                other => TimeTravelError::Other(other.into()),
+            })?;
+        let versions_and_deletes = version_listing.versions;
+
+        tracing::info!(
+            "Built list for time travel with {} versions and deletions",
+            versions_and_deletes.len()
+        );
+
+        // Work on the list of references instead of the objects directly,
+        // otherwise we get lifetime errors in the sort_by_key call below.
+        let mut versions_and_deletes = versions_and_deletes.iter().collect::<Vec<_>>();
+
+        versions_and_deletes.sort_by_key(|vd| (&vd.key, &vd.last_modified));
+
+        let mut vds_for_key = HashMap::<_, Vec<_>>::new();
+
+        for vd in &versions_and_deletes {
+            let Version { key, .. } = &vd;
+            let version_id = vd.version_id().map(|v| v.0.as_str());
+            if version_id == Some("null") {
+                return Err(TimeTravelError::Other(anyhow!(
+                    "Received ListVersions response for key={key} with version_id='null', \
+                    indicating either disabled versioning, or legacy objects with null version id values"
+                )));
+            }
+            tracing::trace!("Parsing version key={key} kind={:?}", vd.kind);
+
+            vds_for_key.entry(key).or_default().push(vd);
+        }
+
+        let warn_threshold = 3;
+        let max_retries = 10;
+        let is_permanent = |e: &_| matches!(e, TimeTravelError::Cancelled);
+
+        for (key, versions) in vds_for_key {
+            let last_vd = versions.last().unwrap();
+            let key = self.relative_path_to_name(key);
+            if last_vd.last_modified > done_if_after {
+                tracing::debug!("Key {key} has version later than done_if_after, skipping");
+                continue;
+            }
+            // the version we want to restore to.
+            let version_to_restore_to =
+                match versions.binary_search_by_key(&timestamp, |tpl| tpl.last_modified) {
+                    Ok(v) => v,
+                    Err(e) => e,
+                };
+            if version_to_restore_to == versions.len() {
+                tracing::debug!("Key {key} has no changes since timestamp, skipping");
+                continue;
+            }
+            let mut do_delete = false;
+            if version_to_restore_to == 0 {
+                // All versions more recent, so the key didn't exist at the specified time point.
+                tracing::debug!(
+                    "All {} versions more recent for {key}, deleting",
+                    versions.len()
+                );
+                do_delete = true;
+            } else {
+                match &versions[version_to_restore_to - 1] {
+                    Version {
+                        kind: VersionKind::Version(version_id),
+                        ..
+                    } => {
+                        let source_url = format!(
+                            "{}/{}?versionid={}",
+                            self.client
+                                .url()
+                                .map_err(|e| TimeTravelError::Other(anyhow!("{e}")))?,
+                            key,
+                            version_id.0
+                        );
+                        tracing::debug!(
+                            "Promoting old version {} for {key} at {}...",
+                            version_id.0,
+                            source_url
+                        );
+                        backoff::retry(
+                            || async {
+                                let blob_client = self.client.blob_client(key.clone());
+                                let op = blob_client.copy(Url::from_str(&source_url).unwrap());
+                                tokio::select! {
+                                    res = op => res.map_err(|e| TimeTravelError::Other(e.into())),
+                                    _ = cancel.cancelled() => Err(TimeTravelError::Cancelled),
+                                }
+                            },
+                            is_permanent,
+                            warn_threshold,
+                            max_retries,
+                            "copying object version for time_travel_recover",
+                            cancel,
+                        )
+                        .await
+                        .ok_or_else(|| TimeTravelError::Cancelled)
+                        .and_then(|x| x)?;
+                        tracing::info!(?version_id, %key, "Copied old version in Azure blob storage");
+                    }
+                    Version {
+                        kind: VersionKind::DeletionMarker,
+                        ..
+                    } => {
+                        do_delete = true;
+                    }
+                }
+            };
+            if do_delete {
+                if matches!(last_vd.kind, VersionKind::DeletionMarker) {
+                    // Key has since been deleted (but there was some history), no need to do anything
+                    tracing::debug!("Key {key} already deleted, skipping.");
+                } else {
+                    tracing::debug!("Deleting {key}...");
+
+                    self.delete(&RemotePath::from_string(&key).unwrap(), cancel)
+                        .await
+                        .map_err(|e| {
+                            // delete_oid0 will use TimeoutOrCancel
+                            if TimeoutOrCancel::caused_by_cancel(&e) {
+                                TimeTravelError::Cancelled
+                            } else {
+                                TimeTravelError::Other(e)
+                            }
+                        })?;
+                }
+            }
+        }
+
+        Ok(())
     }
 }
diff --git a/libs/remote_storage/src/config.rs b/libs/remote_storage/src/config.rs
index 52978be5b4..5bc1f678ae 100644
--- a/libs/remote_storage/src/config.rs
+++ b/libs/remote_storage/src/config.rs
@@ -87,6 +87,28 @@ pub enum RemoteStorageKind {
     AzureContainer(AzureConfig),
 }
 
+#[derive(Deserialize)]
+#[serde(tag = "type")]
+/// Version of RemoteStorageKind which deserializes with type: LocalFs | AwsS3 | AzureContainer
+/// Needed for endpoint storage service
+pub enum TypedRemoteStorageKind {
+    LocalFs { local_path: Utf8PathBuf },
+    AwsS3(S3Config),
+    AzureContainer(AzureConfig),
+}
+
+impl From<TypedRemoteStorageKind> for RemoteStorageKind {
+    fn from(value: TypedRemoteStorageKind) -> Self {
+        match value {
+            TypedRemoteStorageKind::LocalFs { local_path } => {
+                RemoteStorageKind::LocalFs { local_path }
+            }
+            TypedRemoteStorageKind::AwsS3(v) => RemoteStorageKind::AwsS3(v),
+            TypedRemoteStorageKind::AzureContainer(v) => RemoteStorageKind::AzureContainer(v),
+        }
+    }
+}
+
 /// AWS S3 bucket coordinates and access credentials to manage the bucket contents (read and write).
#[derive(Clone, PartialEq, Eq, Deserialize, Serialize)] pub struct S3Config { diff --git a/libs/remote_storage/src/lib.rs b/libs/remote_storage/src/lib.rs index b265d37a62..ed416b2811 100644 --- a/libs/remote_storage/src/lib.rs +++ b/libs/remote_storage/src/lib.rs @@ -31,6 +31,7 @@ use anyhow::Context; pub use azure_core::Etag; use bytes::Bytes; use camino::{Utf8Path, Utf8PathBuf}; +pub use config::TypedRemoteStorageKind; pub use error::{DownloadError, TimeTravelError, TimeoutOrCancel}; use futures::StreamExt; use futures::stream::Stream; @@ -440,6 +441,7 @@ pub trait RemoteStorage: Send + Sync + 'static { timestamp: SystemTime, done_if_after: SystemTime, cancel: &CancellationToken, + complexity_limit: Option, ) -> Result<(), TimeTravelError>; } @@ -651,22 +653,23 @@ impl GenericRemoteStorage> { timestamp: SystemTime, done_if_after: SystemTime, cancel: &CancellationToken, + complexity_limit: Option, ) -> Result<(), TimeTravelError> { match self { Self::LocalFs(s) => { - s.time_travel_recover(prefix, timestamp, done_if_after, cancel) + s.time_travel_recover(prefix, timestamp, done_if_after, cancel, complexity_limit) .await } Self::AwsS3(s) => { - s.time_travel_recover(prefix, timestamp, done_if_after, cancel) + s.time_travel_recover(prefix, timestamp, done_if_after, cancel, complexity_limit) .await } Self::AzureBlob(s) => { - s.time_travel_recover(prefix, timestamp, done_if_after, cancel) + s.time_travel_recover(prefix, timestamp, done_if_after, cancel, complexity_limit) .await } Self::Unreliable(s) => { - s.time_travel_recover(prefix, timestamp, done_if_after, cancel) + s.time_travel_recover(prefix, timestamp, done_if_after, cancel, complexity_limit) .await } } @@ -674,6 +677,15 @@ impl GenericRemoteStorage> { } impl GenericRemoteStorage { + pub async fn from_storage_kind(kind: TypedRemoteStorageKind) -> anyhow::Result { + Self::from_config(&RemoteStorageConfig { + storage: kind.into(), + timeout: RemoteStorageConfig::DEFAULT_TIMEOUT, + small_timeout: RemoteStorageConfig::DEFAULT_SMALL_TIMEOUT, + }) + .await + } + pub async fn from_config(storage_config: &RemoteStorageConfig) -> anyhow::Result { let timeout = storage_config.timeout; diff --git a/libs/remote_storage/src/local_fs.rs b/libs/remote_storage/src/local_fs.rs index 6607b55f1a..30690b1bdb 100644 --- a/libs/remote_storage/src/local_fs.rs +++ b/libs/remote_storage/src/local_fs.rs @@ -400,7 +400,7 @@ impl RemoteStorage for LocalFs { key }; - let relative_key = format!("{}", relative_key); + let relative_key = format!("{relative_key}"); if relative_key.contains(REMOTE_STORAGE_PREFIX_SEPARATOR) { let first_part = relative_key .split(REMOTE_STORAGE_PREFIX_SEPARATOR) @@ -594,13 +594,9 @@ impl RemoteStorage for LocalFs { let from_path = from.with_base(&self.storage_root); let to_path = to.with_base(&self.storage_root); create_target_directory(&to_path).await?; - fs::copy(&from_path, &to_path).await.with_context(|| { - format!( - "Failed to copy file from '{from_path}' to '{to_path}'", - from_path = from_path, - to_path = to_path - ) - })?; + fs::copy(&from_path, &to_path) + .await + .with_context(|| format!("Failed to copy file from '{from_path}' to '{to_path}'"))?; Ok(()) } @@ -610,6 +606,7 @@ impl RemoteStorage for LocalFs { _timestamp: SystemTime, _done_if_after: SystemTime, _cancel: &CancellationToken, + _complexity_limit: Option, ) -> Result<(), TimeTravelError> { Err(TimeTravelError::Unimplemented) } @@ -1182,7 +1179,7 @@ mod fs_tests { .write(true) .create_new(true) .open(path)?; - write!(file_for_writing, "{}", contents)?; + 
write!(file_for_writing, "{contents}")?; drop(file_for_writing); let file_size = path.metadata()?.len() as usize; Ok(( diff --git a/libs/remote_storage/src/s3_bucket.rs b/libs/remote_storage/src/s3_bucket.rs index d98ff552ee..8a2e5bd10e 100644 --- a/libs/remote_storage/src/s3_bucket.rs +++ b/libs/remote_storage/src/s3_bucket.rs @@ -981,22 +981,16 @@ impl RemoteStorage for S3Bucket { timestamp: SystemTime, done_if_after: SystemTime, cancel: &CancellationToken, + complexity_limit: Option, ) -> Result<(), TimeTravelError> { let kind = RequestKind::TimeTravel; let permit = self.permit(kind, cancel).await?; tracing::trace!("Target time: {timestamp:?}, done_if_after {done_if_after:?}"); - // Limit the number of versions deletions, mostly so that we don't - // keep requesting forever if the list is too long, as we'd put the - // list in RAM. - // Building a list of 100k entries that reaches the limit roughly takes - // 40 seconds, and roughly corresponds to tenants of 2 TiB physical size. - const COMPLEXITY_LIMIT: Option = NonZeroU32::new(100_000); - let mode = ListingMode::NoDelimiter; let version_listing = self - .list_versions_with_permit(&permit, prefix, mode, COMPLEXITY_LIMIT, cancel) + .list_versions_with_permit(&permit, prefix, mode, complexity_limit, cancel) .await .map_err(|err| match err { DownloadError::Other(e) => TimeTravelError::Other(e), @@ -1022,6 +1016,7 @@ impl RemoteStorage for S3Bucket { let Version { key, .. } = &vd; let version_id = vd.version_id().map(|v| v.0.as_str()); if version_id == Some("null") { + // TODO: check the behavior of using the SDK on a non-versioned container return Err(TimeTravelError::Other(anyhow!( "Received ListVersions response for key={key} with version_id='null', \ indicating either disabled versioning, or legacy objects with null version id values" diff --git a/libs/remote_storage/src/simulate_failures.rs b/libs/remote_storage/src/simulate_failures.rs index 894cf600be..f9856a5856 100644 --- a/libs/remote_storage/src/simulate_failures.rs +++ b/libs/remote_storage/src/simulate_failures.rs @@ -240,11 +240,12 @@ impl RemoteStorage for UnreliableWrapper { timestamp: SystemTime, done_if_after: SystemTime, cancel: &CancellationToken, + complexity_limit: Option, ) -> Result<(), TimeTravelError> { self.attempt(RemoteOp::TimeTravelRecover(prefix.map(|p| p.to_owned()))) .map_err(TimeTravelError::Other)?; self.inner - .time_travel_recover(prefix, timestamp, done_if_after, cancel) + .time_travel_recover(prefix, timestamp, done_if_after, cancel, complexity_limit) .await } } diff --git a/libs/remote_storage/tests/test_real_s3.rs b/libs/remote_storage/tests/test_real_s3.rs index d38e13fd05..6b893edf75 100644 --- a/libs/remote_storage/tests/test_real_s3.rs +++ b/libs/remote_storage/tests/test_real_s3.rs @@ -157,7 +157,7 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow: // No changes after recovery to t2 (no-op) let t_final = time_point().await; ctx.client - .time_travel_recover(None, t2, t_final, &cancel) + .time_travel_recover(None, t2, t_final, &cancel, None) .await?; let t2_files_recovered = list_files(&ctx.client, &cancel).await?; println!("after recovery to t2: {t2_files_recovered:?}"); @@ -173,7 +173,7 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow: // after recovery to t1: path1 is back, path2 has the old content let t_final = time_point().await; ctx.client - .time_travel_recover(None, t1, t_final, &cancel) + .time_travel_recover(None, t1, t_final, &cancel, None) .await?; let 
t1_files_recovered = list_files(&ctx.client, &cancel).await?; println!("after recovery to t1: {t1_files_recovered:?}"); @@ -189,7 +189,7 @@ async fn s3_time_travel_recovery_works(ctx: &mut MaybeEnabledStorage) -> anyhow: // after recovery to t0: everything is gone except for path1 let t_final = time_point().await; ctx.client - .time_travel_recover(None, t0, t_final, &cancel) + .time_travel_recover(None, t0, t_final, &cancel, None) .await?; let t0_files_recovered = list_files(&ctx.client, &cancel).await?; println!("after recovery to t0: {t0_files_recovered:?}"); diff --git a/libs/safekeeper_api/Cargo.toml b/libs/safekeeper_api/Cargo.toml index d9d080e8fe..928e583b0b 100644 --- a/libs/safekeeper_api/Cargo.toml +++ b/libs/safekeeper_api/Cargo.toml @@ -10,6 +10,7 @@ const_format.workspace = true serde.workspace = true serde_json.workspace = true postgres_ffi.workspace = true +postgres_versioninfo.workspace = true pq_proto.workspace = true tokio.workspace = true utils.workspace = true diff --git a/libs/safekeeper_api/src/lib.rs b/libs/safekeeper_api/src/lib.rs index fa86523ad7..ba0bfee971 100644 --- a/libs/safekeeper_api/src/lib.rs +++ b/libs/safekeeper_api/src/lib.rs @@ -8,6 +8,8 @@ pub mod membership; /// Public API types pub mod models; +pub use postgres_versioninfo::{PgMajorVersion, PgVersionId}; + /// Consensus logical timestamp. Note: it is a part of sk control file. pub type Term = u64; /// With this term timeline is created initially. It @@ -20,7 +22,7 @@ pub const INITIAL_TERM: Term = 0; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct ServerInfo { /// Postgres server version - pub pg_version: u32, + pub pg_version: PgVersionId, pub system_id: SystemId, pub wal_seg_size: u32, } diff --git a/libs/safekeeper_api/src/membership.rs b/libs/safekeeper_api/src/membership.rs index 3d4d17096e..1751c54f6a 100644 --- a/libs/safekeeper_api/src/membership.rs +++ b/libs/safekeeper_api/src/membership.rs @@ -193,10 +193,10 @@ mod tests { }) .unwrap(); - println!("members: {}", members); + println!("members: {members}"); let j = serde_json::to_string(&members).expect("failed to serialize"); - println!("members json: {}", j); + println!("members json: {j}"); assert_eq!( j, r#"[{"id":42,"host":"lala.org","pg_port":5432},{"id":43,"host":"bubu.org","pg_port":5432}]"# diff --git a/libs/safekeeper_api/src/models.rs b/libs/safekeeper_api/src/models.rs index 8658dc4011..5c1ee41f7b 100644 --- a/libs/safekeeper_api/src/models.rs +++ b/libs/safekeeper_api/src/models.rs @@ -4,6 +4,7 @@ use std::net::SocketAddr; use pageserver_api::shard::ShardIdentity; use postgres_ffi::TimestampTz; +use postgres_versioninfo::PgVersionId; use serde::{Deserialize, Serialize}; use tokio::time::Instant; use utils::id::{NodeId, TenantId, TenantTimelineId, TimelineId}; @@ -13,7 +14,7 @@ use utils::pageserver_feedback::PageserverFeedback; use crate::membership::Configuration; use crate::{ServerInfo, Term}; -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] pub struct SafekeeperStatus { pub id: NodeId, } @@ -23,8 +24,7 @@ pub struct TimelineCreateRequest { pub tenant_id: TenantId, pub timeline_id: TimelineId, pub mconf: Configuration, - /// In the PG_VERSION_NUM macro format, like 140017. 
- pub pg_version: u32, + pub pg_version: PgVersionId, pub system_id: Option, // By default WAL_SEGMENT_SIZE pub wal_seg_size: Option, diff --git a/libs/utils/src/error.rs b/libs/utils/src/error.rs index 7ce203e918..6fa86916c1 100644 --- a/libs/utils/src/error.rs +++ b/libs/utils/src/error.rs @@ -41,7 +41,7 @@ pub fn report_compact_sources(e: &E) -> impl std::fmt::Dis // why is E a generic parameter here? hope that rustc will see through a default // Error::source implementation and leave the following out if there cannot be any // sources: - Sources(self.0.source()).try_for_each(|src| write!(f, ": {}", src)) + Sources(self.0.source()).try_for_each(|src| write!(f, ": {src}")) } } diff --git a/libs/utils/src/generation.rs b/libs/utils/src/generation.rs index b5e4a4644a..8a3bef914a 100644 --- a/libs/utils/src/generation.rs +++ b/libs/utils/src/generation.rs @@ -135,7 +135,7 @@ impl Debug for Generation { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Valid(v) => { - write!(f, "{:08x}", v) + write!(f, "{v:08x}") } Self::None => { write!(f, "") diff --git a/libs/utils/src/id.rs b/libs/utils/src/id.rs index 68cb1f0209..e3037aec21 100644 --- a/libs/utils/src/id.rs +++ b/libs/utils/src/id.rs @@ -280,7 +280,7 @@ impl TryFrom> for TimelineId { value .unwrap_or_default() .parse::() - .with_context(|| format!("Could not parse timeline id from {:?}", value)) + .with_context(|| format!("Could not parse timeline id from {value:?}")) } } diff --git a/libs/utils/src/postgres_client.rs b/libs/utils/src/postgres_client.rs index 4167839e28..7596fefe38 100644 --- a/libs/utils/src/postgres_client.rs +++ b/libs/utils/src/postgres_client.rs @@ -89,7 +89,7 @@ pub fn wal_stream_connection_config( .set_password(args.auth_token.map(|s| s.to_owned())); if let Some(availability_zone) = args.availability_zone { - connstr = connstr.extend_options([format!("availability_zone={}", availability_zone)]); + connstr = connstr.extend_options([format!("availability_zone={availability_zone}")]); } Ok(connstr) diff --git a/libs/utils/src/shard.rs b/libs/utils/src/shard.rs index c8c410a725..f2b81373e2 100644 --- a/libs/utils/src/shard.rs +++ b/libs/utils/src/shard.rs @@ -196,7 +196,7 @@ impl std::fmt::Display for TenantShardId { impl std::fmt::Debug for TenantShardId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // Debug is the same as Display: the compact hex representation - write!(f, "{}", self) + write!(f, "{self}") } } @@ -284,7 +284,7 @@ impl std::fmt::Display for ShardIndex { impl std::fmt::Debug for ShardIndex { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { // Debug is the same as Display: the compact hex representation - write!(f, "{}", self) + write!(f, "{self}") } } diff --git a/libs/utils/src/signals.rs b/libs/utils/src/signals.rs index 426bb65916..bdaa3cb665 100644 --- a/libs/utils/src/signals.rs +++ b/libs/utils/src/signals.rs @@ -29,7 +29,7 @@ impl ShutdownSignals { SIGINT => Signal::Interrupt, SIGTERM => Signal::Terminate, SIGQUIT => Signal::Quit, - other => panic!("unknown signal: {}", other), + other => panic!("unknown signal: {other}"), }; handler(signal)?; diff --git a/libs/vm_monitor/src/dispatcher.rs b/libs/vm_monitor/src/dispatcher.rs index 7b7201ab77..7bd6adc2f8 100644 --- a/libs/vm_monitor/src/dispatcher.rs +++ b/libs/vm_monitor/src/dispatcher.rs @@ -90,8 +90,7 @@ impl Dispatcher { Err(e) => { sink.send(Message::Text(Utf8Bytes::from( serde_json::to_string(&ProtocolResponse::Error(format!( - "Received protocol version 
range {} which does not overlap with {}", - agent_range, monitor_range + "Received protocol version range {agent_range} which does not overlap with {monitor_range}" ))) .unwrap(), ))) diff --git a/libs/vm_monitor/src/filecache.rs b/libs/vm_monitor/src/filecache.rs index bc42347e5a..55bbdea169 100644 --- a/libs/vm_monitor/src/filecache.rs +++ b/libs/vm_monitor/src/filecache.rs @@ -285,7 +285,7 @@ impl FileCacheState { // why we're constructing the query here. self.client .query( - &format!("ALTER SYSTEM SET neon.file_cache_size_limit = {};", num_mb), + &format!("ALTER SYSTEM SET neon.file_cache_size_limit = {num_mb};"), &[], ) .await diff --git a/libs/wal_decoder/Cargo.toml b/libs/wal_decoder/Cargo.toml index cb0ef4b00d..600ef091f5 100644 --- a/libs/wal_decoder/Cargo.toml +++ b/libs/wal_decoder/Cargo.toml @@ -14,6 +14,7 @@ bytes.workspace = true pageserver_api.workspace = true prost.workspace = true postgres_ffi.workspace = true +postgres_ffi_types.workspace = true serde.workspace = true thiserror.workspace = true tokio = { workspace = true, features = ["io-util"] } diff --git a/libs/wal_decoder/benches/bench_interpret_wal.rs b/libs/wal_decoder/benches/bench_interpret_wal.rs index ed6ba4d267..e3956eca05 100644 --- a/libs/wal_decoder/benches/bench_interpret_wal.rs +++ b/libs/wal_decoder/benches/bench_interpret_wal.rs @@ -10,7 +10,7 @@ use futures::StreamExt; use futures::stream::FuturesUnordered; use pageserver_api::shard::{ShardIdentity, ShardStripeSize}; use postgres_ffi::waldecoder::WalStreamDecoder; -use postgres_ffi::{MAX_SEND_SIZE, WAL_SEGMENT_SIZE}; +use postgres_ffi::{MAX_SEND_SIZE, PgMajorVersion, WAL_SEGMENT_SIZE}; use pprof::criterion::{Output, PProfProfiler}; use remote_storage::{ DownloadOpts, GenericRemoteStorage, ListingMode, RemoteStorageConfig, RemoteStorageKind, @@ -64,7 +64,7 @@ async fn download_bench_data( let temp_dir_parent: Utf8PathBuf = env::current_dir().unwrap().try_into()?; let temp_dir = camino_tempfile::tempdir_in(temp_dir_parent)?; - eprintln!("Downloading benchmark data to {:?}", temp_dir); + eprintln!("Downloading benchmark data to {temp_dir:?}"); let listing = client .list(None, ListingMode::NoDelimiter, None, cancel) @@ -115,12 +115,12 @@ struct BenchmarkData { #[derive(Deserialize)] struct BenchmarkMetadata { - pg_version: u32, + pg_version: PgMajorVersion, start_lsn: Lsn, } async fn load_bench_data(path: &Utf8Path, input_size: usize) -> anyhow::Result { - eprintln!("Loading benchmark data from {:?}", path); + eprintln!("Loading benchmark data from {path:?}"); let mut entries = tokio::fs::read_dir(path).await?; let mut ordered_segment_paths = Vec::new(); diff --git a/libs/wal_decoder/build.rs b/libs/wal_decoder/build.rs index d5b7ad02ad..e8acb52256 100644 --- a/libs/wal_decoder/build.rs +++ b/libs/wal_decoder/build.rs @@ -6,6 +6,6 @@ fn main() -> Result<(), Box> { // the build then. Anyway, per cargo docs build script shouldn't output to // anywhere but $OUT_DIR. 
     tonic_build::compile_protos("proto/interpreted_wal.proto")
-        .unwrap_or_else(|e| panic!("failed to compile protos {:?}", e));
+        .unwrap_or_else(|e| panic!("failed to compile protos {e:?}"));
     Ok(())
 }
diff --git a/libs/wal_decoder/src/decoder.rs b/libs/wal_decoder/src/decoder.rs
index cb0835e894..0843eb35bf 100644
--- a/libs/wal_decoder/src/decoder.rs
+++ b/libs/wal_decoder/src/decoder.rs
@@ -7,9 +7,9 @@ use bytes::{Buf, Bytes};
 use pageserver_api::key::rel_block_to_key;
 use pageserver_api::reltag::{RelTag, SlruKind};
 use pageserver_api::shard::ShardIdentity;
-use postgres_ffi::pg_constants;
-use postgres_ffi::relfile_utils::VISIBILITYMAP_FORKNUM;
 use postgres_ffi::walrecord::*;
+use postgres_ffi::{PgMajorVersion, pg_constants};
+use postgres_ffi_types::forknum::VISIBILITYMAP_FORKNUM;
 use utils::lsn::Lsn;
 
 use crate::models::*;
@@ -24,7 +24,7 @@ impl InterpretedWalRecord {
         buf: Bytes,
         shards: &[ShardIdentity],
         next_record_lsn: Lsn,
-        pg_version: u32,
+        pg_version: PgMajorVersion,
     ) -> anyhow::Result<HashMap<ShardIdentity, InterpretedWalRecord>> {
         let mut decoded = DecodedWALRecord::default();
         decode_wal_record(buf, &mut decoded, pg_version)?;
@@ -78,7 +78,7 @@ impl MetadataRecord {
         decoded: &DecodedWALRecord,
         shard_records: &mut HashMap<ShardIdentity, InterpretedWalRecord>,
         next_record_lsn: Lsn,
-        pg_version: u32,
+        pg_version: PgMajorVersion,
     ) -> anyhow::Result<()> {
         // Note: this doesn't actually copy the bytes since
         // the [`Bytes`] type implements it via a level of indirection.
@@ -193,7 +193,7 @@ impl MetadataRecord {
     fn decode_heapam_record(
         buf: &mut Bytes,
         decoded: &DecodedWALRecord,
-        pg_version: u32,
+        pg_version: PgMajorVersion,
     ) -> anyhow::Result<Option<MetadataRecord>> {
         // Handle VM bit updates that are implicitly part of heap records.
 
@@ -205,7 +205,7 @@ impl MetadataRecord {
         let mut flags = pg_constants::VISIBILITYMAP_VALID_BITS;
 
         match pg_version {
-            14 => {
+            PgMajorVersion::PG14 => {
                 if decoded.xl_rmid == pg_constants::RM_HEAP_ID {
                     let info = decoded.xl_info & pg_constants::XLOG_HEAP_OPMASK;
 
@@ -272,7 +272,7 @@ impl MetadataRecord {
                     anyhow::bail!("Unknown RMGR {} for Heap decoding", decoded.xl_rmid);
                 }
             }
-            15 => {
+            PgMajorVersion::PG15 => {
                 if decoded.xl_rmid == pg_constants::RM_HEAP_ID {
                     let info = decoded.xl_info & pg_constants::XLOG_HEAP_OPMASK;
 
@@ -339,7 +339,7 @@ impl MetadataRecord {
                     anyhow::bail!("Unknown RMGR {} for Heap decoding", decoded.xl_rmid);
                 }
             }
-            16 => {
+            PgMajorVersion::PG16 => {
                 if decoded.xl_rmid == pg_constants::RM_HEAP_ID {
                     let info = decoded.xl_info & pg_constants::XLOG_HEAP_OPMASK;
 
@@ -406,7 +406,7 @@ impl MetadataRecord {
                     anyhow::bail!("Unknown RMGR {} for Heap decoding", decoded.xl_rmid);
                 }
             }
-            17 => {
+            PgMajorVersion::PG17 => {
                 if decoded.xl_rmid == pg_constants::RM_HEAP_ID {
                     let info = decoded.xl_info & pg_constants::XLOG_HEAP_OPMASK;
 
@@ -473,7 +473,6 @@ impl MetadataRecord {
                     anyhow::bail!("Unknown RMGR {} for Heap decoding", decoded.xl_rmid);
                 }
             }
-            _ => {}
         }
 
         if new_heap_blkno.is_some() || old_heap_blkno.is_some() {
@@ -500,7 +499,7 @@ impl MetadataRecord {
     fn decode_neonmgr_record(
         buf: &mut Bytes,
         decoded: &DecodedWALRecord,
-        pg_version: u32,
+        pg_version: PgMajorVersion,
     ) -> anyhow::Result<Option<MetadataRecord>> {
         // Handle VM bit updates that are implicitly part of heap records.
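The mechanical `u32` → `PgMajorVersion` rewrite running through this file buys compile-time exhaustiveness: the old numeric match needed a silent `_ => {}` arm, while the enum forces every supported major version to be handled. A hedged sketch of the idea; the enum here is a simplified re-declaration, the real one lives in the `postgres_versioninfo` crate:

```rust
// Why the u32 -> PgMajorVersion migration pays off: the compiler now
// enforces that every supported major version is handled, so the old
// `_ => {}` catch-all (which silently ignored unknown versions) can go.
// Simplified re-declaration; the real enum lives in postgres_versioninfo.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum PgMajorVersion {
    PG14,
    PG15,
    PG16,
    PG17,
}

/// Mirrors the decode_neonmgr_record dispatch below: the Neon RMGR is
/// only known-compatible with v16 and v17.
fn neon_rmgr_supported(v: PgMajorVersion) -> bool {
    // Exhaustive match: adding PG18 to the enum turns this into a
    // compile error until a new arm is written.
    match v {
        PgMajorVersion::PG14 | PgMajorVersion::PG15 => false,
        PgMajorVersion::PG16 | PgMajorVersion::PG17 => true,
    }
}

/// Mirrors the decode_clog_record width choice: CLOG page numbers are
/// read as 32-bit before v17 and 64-bit from v17 onwards.
fn clog_pageno_is_64bit(v: PgMajorVersion) -> bool {
    // Deriving Ord keeps the `pg_version < PG17` comparisons working.
    v >= PgMajorVersion::PG17
}

fn main() {
    assert!(!neon_rmgr_supported(PgMajorVersion::PG15));
    assert!(neon_rmgr_supported(PgMajorVersion::PG17));
    assert!(!clog_pageno_is_64bit(PgMajorVersion::PG16));
    assert!(clog_pageno_is_64bit(PgMajorVersion::PG17));
}
```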
@@ -514,7 +513,7 @@ impl MetadataRecord { assert_eq!(decoded.xl_rmid, pg_constants::RM_NEON_ID); match pg_version { - 16 | 17 => { + PgMajorVersion::PG16 | PgMajorVersion::PG17 => { let info = decoded.xl_info & pg_constants::XLOG_HEAP_OPMASK; match info { @@ -574,7 +573,7 @@ impl MetadataRecord { info => anyhow::bail!("Unknown WAL record type for Neon RMGR: {}", info), } } - _ => anyhow::bail!( + PgMajorVersion::PG15 | PgMajorVersion::PG14 => anyhow::bail!( "Neon RMGR has no known compatibility with PostgreSQL version {}", pg_version ), @@ -629,116 +628,121 @@ impl MetadataRecord { fn decode_dbase_record( buf: &mut Bytes, decoded: &DecodedWALRecord, - pg_version: u32, + pg_version: PgMajorVersion, ) -> anyhow::Result> { // TODO: Refactor this to avoid the duplication between postgres versions. let info = decoded.xl_info & pg_constants::XLR_RMGR_INFO_MASK; tracing::debug!(%info, %pg_version, "handle RM_DBASE_ID"); - if pg_version == 14 { - if info == postgres_ffi::v14::bindings::XLOG_DBASE_CREATE { - let createdb = XlCreateDatabase::decode(buf); - tracing::debug!("XLOG_DBASE_CREATE v14"); + match pg_version { + PgMajorVersion::PG14 => { + if info == postgres_ffi::v14::bindings::XLOG_DBASE_CREATE { + let createdb = XlCreateDatabase::decode(buf); + tracing::debug!("XLOG_DBASE_CREATE v14"); - let record = MetadataRecord::Dbase(DbaseRecord::Create(DbaseCreate { - db_id: createdb.db_id, - tablespace_id: createdb.tablespace_id, - src_db_id: createdb.src_db_id, - src_tablespace_id: createdb.src_tablespace_id, - })); + let record = MetadataRecord::Dbase(DbaseRecord::Create(DbaseCreate { + db_id: createdb.db_id, + tablespace_id: createdb.tablespace_id, + src_db_id: createdb.src_db_id, + src_tablespace_id: createdb.src_tablespace_id, + })); - return Ok(Some(record)); - } else if info == postgres_ffi::v14::bindings::XLOG_DBASE_DROP { - let dropdb = XlDropDatabase::decode(buf); + return Ok(Some(record)); + } else if info == postgres_ffi::v14::bindings::XLOG_DBASE_DROP { + let dropdb = XlDropDatabase::decode(buf); - let record = MetadataRecord::Dbase(DbaseRecord::Drop(DbaseDrop { - db_id: dropdb.db_id, - tablespace_ids: dropdb.tablespace_ids, - })); + let record = MetadataRecord::Dbase(DbaseRecord::Drop(DbaseDrop { + db_id: dropdb.db_id, + tablespace_ids: dropdb.tablespace_ids, + })); - return Ok(Some(record)); + return Ok(Some(record)); + } } - } else if pg_version == 15 { - if info == postgres_ffi::v15::bindings::XLOG_DBASE_CREATE_WAL_LOG { - tracing::debug!("XLOG_DBASE_CREATE_WAL_LOG: noop"); - } else if info == postgres_ffi::v15::bindings::XLOG_DBASE_CREATE_FILE_COPY { - // The XLOG record was renamed between v14 and v15, - // but the record format is the same. - // So we can reuse XlCreateDatabase here. - tracing::debug!("XLOG_DBASE_CREATE_FILE_COPY"); + PgMajorVersion::PG15 => { + if info == postgres_ffi::v15::bindings::XLOG_DBASE_CREATE_WAL_LOG { + tracing::debug!("XLOG_DBASE_CREATE_WAL_LOG: noop"); + } else if info == postgres_ffi::v15::bindings::XLOG_DBASE_CREATE_FILE_COPY { + // The XLOG record was renamed between v14 and v15, + // but the record format is the same. + // So we can reuse XlCreateDatabase here. 
+ tracing::debug!("XLOG_DBASE_CREATE_FILE_COPY"); - let createdb = XlCreateDatabase::decode(buf); - let record = MetadataRecord::Dbase(DbaseRecord::Create(DbaseCreate { - db_id: createdb.db_id, - tablespace_id: createdb.tablespace_id, - src_db_id: createdb.src_db_id, - src_tablespace_id: createdb.src_tablespace_id, - })); + let createdb = XlCreateDatabase::decode(buf); + let record = MetadataRecord::Dbase(DbaseRecord::Create(DbaseCreate { + db_id: createdb.db_id, + tablespace_id: createdb.tablespace_id, + src_db_id: createdb.src_db_id, + src_tablespace_id: createdb.src_tablespace_id, + })); - return Ok(Some(record)); - } else if info == postgres_ffi::v15::bindings::XLOG_DBASE_DROP { - let dropdb = XlDropDatabase::decode(buf); - let record = MetadataRecord::Dbase(DbaseRecord::Drop(DbaseDrop { - db_id: dropdb.db_id, - tablespace_ids: dropdb.tablespace_ids, - })); + return Ok(Some(record)); + } else if info == postgres_ffi::v15::bindings::XLOG_DBASE_DROP { + let dropdb = XlDropDatabase::decode(buf); + let record = MetadataRecord::Dbase(DbaseRecord::Drop(DbaseDrop { + db_id: dropdb.db_id, + tablespace_ids: dropdb.tablespace_ids, + })); - return Ok(Some(record)); + return Ok(Some(record)); + } } - } else if pg_version == 16 { - if info == postgres_ffi::v16::bindings::XLOG_DBASE_CREATE_WAL_LOG { - tracing::debug!("XLOG_DBASE_CREATE_WAL_LOG: noop"); - } else if info == postgres_ffi::v16::bindings::XLOG_DBASE_CREATE_FILE_COPY { - // The XLOG record was renamed between v14 and v15, - // but the record format is the same. - // So we can reuse XlCreateDatabase here. - tracing::debug!("XLOG_DBASE_CREATE_FILE_COPY"); + PgMajorVersion::PG16 => { + if info == postgres_ffi::v16::bindings::XLOG_DBASE_CREATE_WAL_LOG { + tracing::debug!("XLOG_DBASE_CREATE_WAL_LOG: noop"); + } else if info == postgres_ffi::v16::bindings::XLOG_DBASE_CREATE_FILE_COPY { + // The XLOG record was renamed between v14 and v15, + // but the record format is the same. + // So we can reuse XlCreateDatabase here. + tracing::debug!("XLOG_DBASE_CREATE_FILE_COPY"); - let createdb = XlCreateDatabase::decode(buf); - let record = MetadataRecord::Dbase(DbaseRecord::Create(DbaseCreate { - db_id: createdb.db_id, - tablespace_id: createdb.tablespace_id, - src_db_id: createdb.src_db_id, - src_tablespace_id: createdb.src_tablespace_id, - })); + let createdb = XlCreateDatabase::decode(buf); + let record = MetadataRecord::Dbase(DbaseRecord::Create(DbaseCreate { + db_id: createdb.db_id, + tablespace_id: createdb.tablespace_id, + src_db_id: createdb.src_db_id, + src_tablespace_id: createdb.src_tablespace_id, + })); - return Ok(Some(record)); - } else if info == postgres_ffi::v16::bindings::XLOG_DBASE_DROP { - let dropdb = XlDropDatabase::decode(buf); - let record = MetadataRecord::Dbase(DbaseRecord::Drop(DbaseDrop { - db_id: dropdb.db_id, - tablespace_ids: dropdb.tablespace_ids, - })); + return Ok(Some(record)); + } else if info == postgres_ffi::v16::bindings::XLOG_DBASE_DROP { + let dropdb = XlDropDatabase::decode(buf); + let record = MetadataRecord::Dbase(DbaseRecord::Drop(DbaseDrop { + db_id: dropdb.db_id, + tablespace_ids: dropdb.tablespace_ids, + })); - return Ok(Some(record)); + return Ok(Some(record)); + } } - } else if pg_version == 17 { - if info == postgres_ffi::v17::bindings::XLOG_DBASE_CREATE_WAL_LOG { - tracing::debug!("XLOG_DBASE_CREATE_WAL_LOG: noop"); - } else if info == postgres_ffi::v17::bindings::XLOG_DBASE_CREATE_FILE_COPY { - // The XLOG record was renamed between v14 and v15, - // but the record format is the same. 
- // So we can reuse XlCreateDatabase here. - tracing::debug!("XLOG_DBASE_CREATE_FILE_COPY"); + PgMajorVersion::PG17 => { + if info == postgres_ffi::v17::bindings::XLOG_DBASE_CREATE_WAL_LOG { + tracing::debug!("XLOG_DBASE_CREATE_WAL_LOG: noop"); + } else if info == postgres_ffi::v17::bindings::XLOG_DBASE_CREATE_FILE_COPY { + // The XLOG record was renamed between v14 and v15, + // but the record format is the same. + // So we can reuse XlCreateDatabase here. + tracing::debug!("XLOG_DBASE_CREATE_FILE_COPY"); - let createdb = XlCreateDatabase::decode(buf); - let record = MetadataRecord::Dbase(DbaseRecord::Create(DbaseCreate { - db_id: createdb.db_id, - tablespace_id: createdb.tablespace_id, - src_db_id: createdb.src_db_id, - src_tablespace_id: createdb.src_tablespace_id, - })); + let createdb = XlCreateDatabase::decode(buf); + let record = MetadataRecord::Dbase(DbaseRecord::Create(DbaseCreate { + db_id: createdb.db_id, + tablespace_id: createdb.tablespace_id, + src_db_id: createdb.src_db_id, + src_tablespace_id: createdb.src_tablespace_id, + })); - return Ok(Some(record)); - } else if info == postgres_ffi::v17::bindings::XLOG_DBASE_DROP { - let dropdb = XlDropDatabase::decode(buf); - let record = MetadataRecord::Dbase(DbaseRecord::Drop(DbaseDrop { - db_id: dropdb.db_id, - tablespace_ids: dropdb.tablespace_ids, - })); + return Ok(Some(record)); + } else if info == postgres_ffi::v17::bindings::XLOG_DBASE_DROP { + let dropdb = XlDropDatabase::decode(buf); + let record = MetadataRecord::Dbase(DbaseRecord::Drop(DbaseDrop { + db_id: dropdb.db_id, + tablespace_ids: dropdb.tablespace_ids, + })); - return Ok(Some(record)); + return Ok(Some(record)); + } } } @@ -748,12 +752,12 @@ impl MetadataRecord { fn decode_clog_record( buf: &mut Bytes, decoded: &DecodedWALRecord, - pg_version: u32, + pg_version: PgMajorVersion, ) -> anyhow::Result> { let info = decoded.xl_info & !pg_constants::XLR_INFO_MASK; if info == pg_constants::CLOG_ZEROPAGE { - let pageno = if pg_version < 17 { + let pageno = if pg_version < PgMajorVersion::PG17 { buf.get_u32_le() } else { buf.get_u64_le() as u32 @@ -765,7 +769,7 @@ impl MetadataRecord { ClogZeroPage { segno, rpageno }, )))) } else { - assert!(info == pg_constants::CLOG_TRUNCATE); + assert_eq!(info, pg_constants::CLOG_TRUNCATE); let xlrec = XlClogTruncate::decode(buf, pg_version); Ok(Some(MetadataRecord::Clog(ClogRecord::Truncate( @@ -838,14 +842,14 @@ impl MetadataRecord { fn decode_multixact_record( buf: &mut Bytes, decoded: &DecodedWALRecord, - pg_version: u32, + pg_version: PgMajorVersion, ) -> anyhow::Result> { let info = decoded.xl_info & pg_constants::XLR_RMGR_INFO_MASK; if info == pg_constants::XLOG_MULTIXACT_ZERO_OFF_PAGE || info == pg_constants::XLOG_MULTIXACT_ZERO_MEM_PAGE { - let pageno = if pg_version < 17 { + let pageno = if pg_version < PgMajorVersion::PG17 { buf.get_u32_le() } else { buf.get_u64_le() as u32 diff --git a/libs/wal_decoder/src/models.rs b/libs/wal_decoder/src/models.rs index 7e1934c6c3..94a00c0e53 100644 --- a/libs/wal_decoder/src/models.rs +++ b/libs/wal_decoder/src/models.rs @@ -25,6 +25,9 @@ //! | //! 
|--> write to KV store within the pageserver +pub mod record; +pub mod value; + use bytes::Bytes; use pageserver_api::reltag::{RelTag, SlruKind}; use postgres_ffi::walrecord::{ diff --git a/libs/pageserver_api/src/record.rs b/libs/wal_decoder/src/models/record.rs similarity index 99% rename from libs/pageserver_api/src/record.rs rename to libs/wal_decoder/src/models/record.rs index 73516c5220..51659ed904 100644 --- a/libs/pageserver_api/src/record.rs +++ b/libs/wal_decoder/src/models/record.rs @@ -128,6 +128,6 @@ pub fn describe_wal_record(rec: &NeonWalRecord) -> Result Ok(format!("{:?}", rec)), + _ => Ok(format!("{rec:?}")), } } diff --git a/libs/pageserver_api/src/value.rs b/libs/wal_decoder/src/models/value.rs similarity index 99% rename from libs/pageserver_api/src/value.rs rename to libs/wal_decoder/src/models/value.rs index e9000939c3..3b4f896a45 100644 --- a/libs/pageserver_api/src/value.rs +++ b/libs/wal_decoder/src/models/value.rs @@ -10,7 +10,7 @@ use bytes::Bytes; use serde::{Deserialize, Serialize}; -use crate::record::NeonWalRecord; +use crate::models::record::NeonWalRecord; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub enum Value { diff --git a/libs/wal_decoder/src/serialized_batch.rs b/libs/wal_decoder/src/serialized_batch.rs index b451d6d8e0..ab38ff3d73 100644 --- a/libs/wal_decoder/src/serialized_batch.rs +++ b/libs/wal_decoder/src/serialized_batch.rs @@ -1,4 +1,4 @@ -//! This module implements batch type for serialized [`pageserver_api::value::Value`] +//! This module implements batch type for serialized [`crate::models::value::Value`] //! instances. Each batch contains a raw buffer (serialized values) //! and a list of metadata for each (key, LSN) tuple present in the batch. //! @@ -10,17 +10,17 @@ use std::collections::{BTreeSet, HashMap}; use bytes::{Bytes, BytesMut}; use pageserver_api::key::{CompactKey, Key, rel_block_to_key}; use pageserver_api::keyspace::KeySpace; -use pageserver_api::record::NeonWalRecord; use pageserver_api::reltag::RelTag; use pageserver_api::shard::ShardIdentity; -use pageserver_api::value::Value; use postgres_ffi::walrecord::{DecodedBkpBlock, DecodedWALRecord}; -use postgres_ffi::{BLCKSZ, page_is_new, page_set_lsn, pg_constants}; +use postgres_ffi::{BLCKSZ, PgMajorVersion, page_is_new, page_set_lsn, pg_constants}; use serde::{Deserialize, Serialize}; use utils::bin_ser::BeSer; use utils::lsn::Lsn; use crate::models::InterpretedWalRecord; +use crate::models::record::NeonWalRecord; +use crate::models::value::Value; static ZERO_PAGE: Bytes = Bytes::from_static(&[0u8; BLCKSZ as usize]); @@ -139,7 +139,7 @@ impl SerializedValueBatch { decoded: DecodedWALRecord, shard_records: &mut HashMap, next_record_lsn: Lsn, - pg_version: u32, + pg_version: PgMajorVersion, ) -> anyhow::Result<()> { // First determine how big the buffers need to be and allocate it up-front. // This duplicates some of the work below, but it's empirically much faster. 
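The recurring change in the hunks above is the replacement of `pg_version: u32` with the `PgMajorVersion` enum. Below is a minimal sketch of what the enum buys, assuming it derives its comparison traits in declaration order; the real definition lives in `postgres_ffi`/`postgres_versioninfo` and is not part of this diff.

// Stand-in for the real PgMajorVersion; the derives are an assumption.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum PgMajorVersion {
    PG14,
    PG15,
    PG16,
    PG17,
}

// `Ord` follows declaration order, so range checks like the
// `pg_version < PgMajorVersion::PG17` branches above still work: v17
// widened these WAL pageno fields from 32 to 64 bits.
fn pageno_width_bytes(v: PgMajorVersion) -> usize {
    if v < PgMajorVersion::PG17 { 4 } else { 8 }
}

// Unlike `match pg_version { 16 | 17 => .., _ => bail!(..) }` on a raw u32,
// an enum match is exhaustiveness-checked: adding a PG18 variant refuses to
// compile until every dispatch site (dbase, clog, multixact, ...) handles it.
fn neon_rmgr_supported(v: PgMajorVersion) -> bool {
    match v {
        PgMajorVersion::PG14 | PgMajorVersion::PG15 => false,
        PgMajorVersion::PG16 | PgMajorVersion::PG17 => true,
    }
}

fn main() {
    assert_eq!(pageno_width_bytes(PgMajorVersion::PG16), 4);
    assert_eq!(pageno_width_bytes(PgMajorVersion::PG17), 8);
    assert!(neon_rmgr_supported(PgMajorVersion::PG17));
}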
@@ -267,7 +267,7 @@ impl SerializedValueBatch { fn estimate_buffer_size( decoded: &DecodedWALRecord, shard: &ShardIdentity, - pg_version: u32, + pg_version: PgMajorVersion, ) -> usize { let mut estimate: usize = 0; @@ -303,7 +303,11 @@ impl SerializedValueBatch { estimate } - fn block_is_image(decoded: &DecodedWALRecord, blk: &DecodedBkpBlock, pg_version: u32) -> bool { + fn block_is_image( + decoded: &DecodedWALRecord, + blk: &DecodedBkpBlock, + pg_version: PgMajorVersion, + ) -> bool { blk.apply_image && blk.has_image && decoded.xl_rmid == pg_constants::RM_XLOG_ID diff --git a/libs/walproposer/build.rs b/libs/walproposer/build.rs index 530ceb1327..b13c8b32b4 100644 --- a/libs/walproposer/build.rs +++ b/libs/walproposer/build.rs @@ -13,22 +13,24 @@ fn main() -> anyhow::Result<()> { // Tell cargo to invalidate the built crate whenever the wrapper changes println!("cargo:rerun-if-changed=bindgen_deps.h"); + let root_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../.."); + // Finding the location of built libraries and Postgres C headers: // - if POSTGRES_INSTALL_DIR is set look into it, otherwise look into `/pg_install` // - if there's a `bin/pg_config` file use it for getting include server, otherwise use `/pg_install/{PG_MAJORVERSION}/include/postgresql/server` let pg_install_dir = if let Some(postgres_install_dir) = env::var_os("POSTGRES_INSTALL_DIR") { postgres_install_dir.into() } else { - PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../pg_install") + root_path.join("pg_install") }; let pg_install_abs = std::fs::canonicalize(pg_install_dir)?; - let walproposer_lib_dir = pg_install_abs.join("build/walproposer-lib"); + let walproposer_lib_dir = root_path.join("build/walproposer-lib"); let walproposer_lib_search_str = walproposer_lib_dir .to_str() .ok_or(anyhow!("Bad non-UTF path"))?; - let pgxn_neon = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../pgxn/neon"); + let pgxn_neon = root_path.join("pgxn/neon"); let pgxn_neon = std::fs::canonicalize(pgxn_neon)?; let pgxn_neon = pgxn_neon.to_str().ok_or(anyhow!("Bad non-UTF path"))?; diff --git a/libs/walproposer/src/api_bindings.rs b/libs/walproposer/src/api_bindings.rs index 4d6cbae9a9..7c6abf252e 100644 --- a/libs/walproposer/src/api_bindings.rs +++ b/libs/walproposer/src/api_bindings.rs @@ -311,7 +311,7 @@ extern "C" fn get_redo_start_lsn(wp: *mut WalProposer) -> XLogRecPtr { } } -extern "C-unwind" fn finish_sync_safekeepers(wp: *mut WalProposer, lsn: XLogRecPtr) { +unsafe extern "C-unwind" fn finish_sync_safekeepers(wp: *mut WalProposer, lsn: XLogRecPtr) -> ! { unsafe { let callback_data = (*(*wp).config).callback_data; let api = callback_data as *mut Box; @@ -376,7 +376,7 @@ impl Level { FATAL => Level::Fatal, PANIC => Level::Panic, WPEVENT => Level::WPEvent, - _ => panic!("unknown log level {}", elevel), + _ => panic!("unknown log level {elevel}"), } } } @@ -446,7 +446,7 @@ pub fn empty_shmem() -> crate::bindings::WalproposerShmemState { impl std::fmt::Display for Level { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{:?}", self) + write!(f, "{self:?}") } } diff --git a/libs/walproposer/src/walproposer.rs b/libs/walproposer/src/walproposer.rs index e95494297c..93bb0d5eb0 100644 --- a/libs/walproposer/src/walproposer.rs +++ b/libs/walproposer/src/walproposer.rs @@ -144,7 +144,7 @@ pub trait ApiImpl { todo!() } - fn finish_sync_safekeepers(&self, _lsn: u64) { + fn finish_sync_safekeepers(&self, _lsn: u64) -> ! 
{ todo!() } @@ -380,7 +380,7 @@ mod tests { } fn conn_send_query(&self, _: &mut crate::bindings::Safekeeper, query: &str) -> bool { - println!("conn_send_query: {}", query); + println!("conn_send_query: {query}"); true } @@ -399,13 +399,13 @@ mod tests { ) -> crate::bindings::PGAsyncReadResult { println!("conn_async_read"); let reply = self.next_safekeeper_reply(); - println!("conn_async_read result: {:?}", reply); + println!("conn_async_read result: {reply:?}"); vec.extend_from_slice(reply); crate::bindings::PGAsyncReadResult_PG_ASYNC_READ_SUCCESS } fn conn_blocking_write(&self, _: &mut crate::bindings::Safekeeper, buf: &[u8]) -> bool { - println!("conn_blocking_write: {:?}", buf); + println!("conn_blocking_write: {buf:?}"); self.check_walproposer_msg(buf); true } @@ -456,10 +456,7 @@ mod tests { timeout_millis: i64, ) -> super::WaitResult { let data = self.wait_events.get(); - println!( - "wait_event_set, timeout_millis={}, res={:?}", - timeout_millis, data - ); + println!("wait_event_set, timeout_millis={timeout_millis}, res={data:?}"); super::WaitResult::Network(data.sk, data.event_mask) } @@ -469,13 +466,13 @@ mod tests { true } - fn finish_sync_safekeepers(&self, lsn: u64) { + fn finish_sync_safekeepers(&self, lsn: u64) -> ! { self.sync_channel.send(lsn).unwrap(); panic!("sync safekeepers finished at lsn={}", lsn); } fn log_internal(&self, _wp: &mut crate::bindings::WalProposer, level: Level, msg: &str) { - println!("wp_log[{}] {}", level, msg); + println!("wp_log[{level}] {msg}"); } fn after_election(&self, _wp: &mut crate::bindings::WalProposer) { diff --git a/pageserver/Cargo.toml b/pageserver/Cargo.toml index 9591c729e8..8a2e2ed3be 100644 --- a/pageserver/Cargo.toml +++ b/pageserver/Cargo.toml @@ -12,6 +12,9 @@ testing = ["fail/failpoints", "pageserver_api/testing", "wal_decoder/testing", " fuzz-read-path = ["testing"] +# Enables benchmarking only APIs +benchmarking = [] + [dependencies] anyhow.workspace = true arc-swap.workspace = true @@ -56,6 +59,7 @@ pin-project-lite.workspace = true postgres_backend.workspace = true postgres_connection.workspace = true postgres_ffi.workspace = true +postgres_ffi_types.workspace = true postgres_initdb.workspace = true postgres-protocol.workspace = true postgres-types.workspace = true @@ -126,6 +130,7 @@ harness = false [[bench]] name = "bench_ingest" harness = false +required-features = ["benchmarking"] [[bench]] name = "upload_queue" diff --git a/pageserver/benches/bench_ingest.rs b/pageserver/benches/bench_ingest.rs index eaadfe14ae..438c6e235e 100644 --- a/pageserver/benches/bench_ingest.rs +++ b/pageserver/benches/bench_ingest.rs @@ -1,23 +1,30 @@ use std::env; use std::num::NonZeroUsize; +use std::sync::Arc; use bytes::Bytes; use camino::Utf8PathBuf; use criterion::{Criterion, criterion_group, criterion_main}; +use futures::stream::FuturesUnordered; use pageserver::config::PageServerConf; use pageserver::context::{DownloadBehavior, RequestContext}; +use pageserver::keyspace::KeySpace; use pageserver::l0_flush::{L0FlushConfig, L0FlushGlobalState}; use pageserver::task_mgr::TaskKind; -use pageserver::tenant::storage_layer::InMemoryLayer; +use pageserver::tenant::storage_layer::IoConcurrency; +use pageserver::tenant::storage_layer::{InMemoryLayer, ValuesReconstructState}; use pageserver::{page_cache, virtual_file}; +use pageserver_api::config::GetVectoredConcurrentIo; use pageserver_api::key::Key; use pageserver_api::models::virtual_file::IoMode; use pageserver_api::shard::TenantShardId; -use pageserver_api::value::Value; -use 
strum::IntoEnumIterator; +use tokio_stream::StreamExt; use tokio_util::sync::CancellationToken; use utils::bin_ser::BeSer; use utils::id::{TenantId, TimelineId}; +use utils::lsn::Lsn; +use utils::sync::gate::Gate; +use wal_decoder::models::value::Value; use wal_decoder::serialized_batch::SerializedValueBatch; // A very cheap hash for generating non-sequential keys. @@ -30,7 +37,7 @@ fn murmurhash32(mut h: u32) -> u32 { h } -#[derive(serde::Serialize, Clone, Copy, Debug)] +#[derive(serde::Serialize, Clone, Copy, Debug, PartialEq)] enum KeyLayout { /// Sequential unique keys Sequential, @@ -40,19 +47,30 @@ enum KeyLayout { RandomReuse(u32), } -#[derive(serde::Serialize, Clone, Copy, Debug)] +#[derive(serde::Serialize, Clone, Copy, Debug, PartialEq)] enum WriteDelta { Yes, No, } +#[derive(serde::Serialize, Clone, Copy, Debug, PartialEq)] +enum ConcurrentReads { + Yes, + No, +} + async fn ingest( conf: &'static PageServerConf, put_size: usize, put_count: usize, key_layout: KeyLayout, write_delta: WriteDelta, + concurrent_reads: ConcurrentReads, ) -> anyhow::Result<()> { + if concurrent_reads == ConcurrentReads::Yes { + assert_eq!(key_layout, KeyLayout::Sequential); + } + let mut lsn = utils::lsn::Lsn(1000); let mut key = Key::from_i128(0x0); @@ -68,16 +86,18 @@ async fn ingest( let gate = utils::sync::gate::Gate::default(); let cancel = CancellationToken::new(); - let layer = InMemoryLayer::create( - conf, - timeline_id, - tenant_shard_id, - lsn, - &gate, - &cancel, - &ctx, - ) - .await?; + let layer = Arc::new( + InMemoryLayer::create( + conf, + timeline_id, + tenant_shard_id, + lsn, + &gate, + &cancel, + &ctx, + ) + .await?, + ); let data = Value::Image(Bytes::from(vec![0u8; put_size])); let data_ser_size = data.serialized_size().unwrap() as usize; @@ -86,6 +106,61 @@ async fn ingest( pageserver::context::DownloadBehavior::Download, ); + const READ_BATCH_SIZE: u32 = 32; + let (tx, mut rx) = tokio::sync::watch::channel::>(None); + let reader_cancel = CancellationToken::new(); + let reader_handle = if concurrent_reads == ConcurrentReads::Yes { + Some(tokio::task::spawn({ + let cancel = reader_cancel.clone(); + let layer = layer.clone(); + let ctx = ctx.attached_child(); + async move { + let gate = Gate::default(); + let gate_guard = gate.enter().unwrap(); + let io_concurrency = IoConcurrency::spawn_from_conf( + GetVectoredConcurrentIo::SidecarTask, + gate_guard, + ); + + rx.wait_for(|key| key.is_some()).await.unwrap(); + + while !cancel.is_cancelled() { + let key = match *rx.borrow() { + Some(some) => some, + None => unreachable!(), + }; + + let mut start_key = key; + start_key.field6 = key.field6.saturating_sub(READ_BATCH_SIZE); + let key_range = start_key..key.next(); + + let mut reconstruct_state = ValuesReconstructState::new(io_concurrency.clone()); + + layer + .get_values_reconstruct_data( + KeySpace::single(key_range), + Lsn(1)..Lsn(u64::MAX), + &mut reconstruct_state, + &ctx, + ) + .await + .unwrap(); + + let mut collect_futs = std::mem::take(&mut reconstruct_state.keys) + .into_values() + .map(|state| state.sink_pending_ios()) + .collect::>(); + while collect_futs.next().await.is_some() {} + } + + drop(io_concurrency); + gate.close().await; + } + })) + } else { + None + }; + const BATCH_SIZE: usize = 16; let mut batch = Vec::new(); @@ -113,19 +188,27 @@ async fn ingest( batch.push((key.to_compact(), lsn, data_ser_size, data.clone())); if batch.len() >= BATCH_SIZE { + let last_key = Key::from_compact(batch.last().unwrap().0); + let this_batch = std::mem::take(&mut batch); let serialized 
= SerializedValueBatch::from_values(this_batch); layer.put_batch(serialized, &ctx).await?; + + tx.send(Some(last_key)).unwrap(); } } if !batch.is_empty() { + let last_key = Key::from_compact(batch.last().unwrap().0); + let this_batch = std::mem::take(&mut batch); let serialized = SerializedValueBatch::from_values(this_batch); layer.put_batch(serialized, &ctx).await?; + + tx.send(Some(last_key)).unwrap(); } layer.freeze(lsn + 1).await; - if matches!(write_delta, WriteDelta::Yes) { + if write_delta == WriteDelta::Yes { let l0_flush_state = L0FlushGlobalState::new(L0FlushConfig::Direct { max_concurrency: NonZeroUsize::new(1).unwrap(), }); @@ -136,6 +219,11 @@ async fn ingest( tokio::fs::remove_file(path).await?; } + reader_cancel.cancel(); + if let Some(handle) = reader_handle { + handle.await.unwrap(); + } + Ok(()) } @@ -147,6 +235,7 @@ fn ingest_main( put_count: usize, key_layout: KeyLayout, write_delta: WriteDelta, + concurrent_reads: ConcurrentReads, ) { pageserver::virtual_file::set_io_mode(io_mode); @@ -156,7 +245,15 @@ fn ingest_main( .unwrap(); runtime.block_on(async move { - let r = ingest(conf, put_size, put_count, key_layout, write_delta).await; + let r = ingest( + conf, + put_size, + put_count, + key_layout, + write_delta, + concurrent_reads, + ) + .await; if let Err(e) = r { panic!("{e:?}"); } @@ -195,6 +292,7 @@ fn criterion_benchmark(c: &mut Criterion) { key_size: usize, key_layout: KeyLayout, write_delta: WriteDelta, + concurrent_reads: ConcurrentReads, } #[derive(Clone)] struct HandPickedParameters { @@ -245,7 +343,7 @@ fn criterion_benchmark(c: &mut Criterion) { ]; let exploded_parameters = { let mut out = Vec::new(); - for io_mode in IoMode::iter() { + for concurrent_reads in [ConcurrentReads::Yes, ConcurrentReads::No] { for param in expect.clone() { let HandPickedParameters { volume_mib, @@ -253,12 +351,18 @@ fn criterion_benchmark(c: &mut Criterion) { key_layout, write_delta, } = param; + + if key_layout != KeyLayout::Sequential && concurrent_reads == ConcurrentReads::Yes { + continue; + } + out.push(ExplodedParameters { - io_mode, + io_mode: IoMode::DirectRw, volume_mib, key_size, key_layout, write_delta, + concurrent_reads, }); } } @@ -272,9 +376,10 @@ fn criterion_benchmark(c: &mut Criterion) { key_size, key_layout, write_delta, + concurrent_reads, } = self; format!( - "io_mode={io_mode:?} volume_mib={volume_mib:?} key_size_bytes={key_size:?} key_layout={key_layout:?} write_delta={write_delta:?}" + "io_mode={io_mode:?} volume_mib={volume_mib:?} key_size_bytes={key_size:?} key_layout={key_layout:?} write_delta={write_delta:?} concurrent_reads={concurrent_reads:?}" ) } } @@ -287,12 +392,23 @@ fn criterion_benchmark(c: &mut Criterion) { key_size, key_layout, write_delta, + concurrent_reads, } = params; let put_count = volume_mib * 1024 * 1024 / key_size; group.throughput(criterion::Throughput::Bytes((key_size * put_count) as u64)); group.sample_size(10); group.bench_function(id, |b| { - b.iter(|| ingest_main(conf, io_mode, key_size, put_count, key_layout, write_delta)) + b.iter(|| { + ingest_main( + conf, + io_mode, + key_size, + put_count, + key_layout, + write_delta, + concurrent_reads, + ) + }) }); } } diff --git a/pageserver/benches/bench_walredo.rs b/pageserver/benches/bench_walredo.rs index 215682d90c..efb970f705 100644 --- a/pageserver/benches/bench_walredo.rs +++ b/pageserver/benches/bench_walredo.rs @@ -67,12 +67,13 @@ use once_cell::sync::Lazy; use pageserver::config::PageServerConf; use pageserver::walredo::{PostgresRedoManager, RedoAttemptType}; use 
pageserver_api::key::Key; -use pageserver_api::record::NeonWalRecord; use pageserver_api::shard::TenantShardId; +use postgres_ffi::{BLCKSZ, PgMajorVersion}; use tokio::sync::Barrier; use tokio::task::JoinSet; use utils::id::TenantId; use utils::lsn::Lsn; +use wal_decoder::models::record::NeonWalRecord; fn bench(c: &mut Criterion) { macro_rules! bench_group { @@ -94,7 +95,7 @@ fn bench(c: &mut Criterion) { // // benchmark the protocol implementation // - let pg_version = 14; + let pg_version = PgMajorVersion::PG14; bench_group!( "ping", Arc::new(move |mgr: Arc| async move { @@ -107,7 +108,7 @@ fn bench(c: &mut Criterion) { let make_redo_work = |req: &'static Request| { Arc::new(move |mgr: Arc| async move { let page = req.execute(&mgr).await.unwrap(); - assert_eq!(page.remaining(), 8192); + assert_eq!(page.remaining(), BLCKSZ as usize); }) }; bench_group!("short", { @@ -208,7 +209,7 @@ struct Request { lsn: Lsn, base_img: Option<(Lsn, Bytes)>, records: Vec<(Lsn, NeonWalRecord)>, - pg_version: u32, + pg_version: PgMajorVersion, } impl Request { @@ -267,7 +268,7 @@ impl Request { pg_record(false, b"\xbc\0\0\0\0\0\0\0h?m\x01\0\0\0\0p\n\0\09\x08\xa3\xea\0 \x8c\0\x7f\x06\0\0\xd22\0\0\xeb\x04\0\0\0\0\0\0\xff\x02\0@\0\0another_table\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x98\x08\0\0\x02@\0\0\0\0\0\0\n\0\0\0\x02\0\0\0\0@\0\0\0\0\0\0\x05\0\0\0\0@zD\x05\0\0\0\0\0\0\0\0\0pr\x01\0\0\0\0\0\0\0\0\x01d\0\0\0\0\0\0\x04\0\0\x01\0\0\0\x02\0"), ), ], - pg_version: 14, + pg_version: PgMajorVersion::PG14, } } @@ -516,7 +517,7 @@ impl Request { (lsn!("0/16B8000"), pg_record(false, b"C\0\0\0\0\x04\0\0p\x7fk\x01\0\0\0\0\0\n\0\0\\\xc4:?\0 \x12\0\x7f\x06\0\0\xd22\0\0\0@\0\0\0\0\0\0\xff\x03\x01\0\0\x08\x01\0\0\0\x18\0\xe1\0\0\0\0\0\0\0\xe2\0\0")), (lsn!("0/16CBD68"), pg_record(false, b"@ \0\0\0\0\0\0\xc0|l\x01\0\0\0\0@\t\0\0\xdf\xb0\x1a`\0\x12\0\0\0 \0\0\x04\x7f\x06\0\0\xd22\0\0\0@\0\0\0\0\0\0\x01\x80\0\0\0\0\0\0\xff\x05\0\0\0\0\0\0\0\0\0\0\0\0\x18\0\0 \0 \x04 
\0\0\0\0\x01\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\
0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0
\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\
0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0
\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\x04\0\0\x01")), ], - pg_version: 14, + pg_version: PgMajorVersion::PG14, } } } diff --git a/pageserver/client/Cargo.toml b/pageserver/client/Cargo.toml index 970a437a42..47e2a6ddae 100644 --- a/pageserver/client/Cargo.toml +++ b/pageserver/client/Cargo.toml @@ -18,6 +18,7 @@ workspace_hack = { version = "0.1", path = "../../workspace_hack" } tokio-postgres.workspace = true tokio-stream.workspace = true tokio.workspace = true +postgres_versioninfo.workspace = true futures.workspace = true tokio-util.workspace = true anyhow.workspace = true diff --git a/pageserver/client/src/mgmt_api.rs b/pageserver/client/src/mgmt_api.rs index 219e63c9d4..af4be23b9b 100644 --- a/pageserver/client/src/mgmt_api.rs +++ b/pageserver/client/src/mgmt_api.rs @@ -7,6 +7,7 @@ use detach_ancestor::AncestorDetached; use http_utils::error::HttpErrorBody; use pageserver_api::models::*; use pageserver_api::shard::TenantShardId; +use postgres_versioninfo::PgMajorVersion; pub use reqwest::Body as ReqwestBody; use reqwest::{IntoUrl, Method, StatusCode, Url}; use utils::id::{TenantId, TimelineId}; @@ -508,11 +509,11 @@ impl Client { .expect("Cannot build URL"); path.query_pairs_mut() - .append_pair("recurse", &format!("{}", recurse)); + .append_pair("recurse", &format!("{recurse}")); if let Some(concurrency) = concurrency { path.query_pairs_mut() - .append_pair("concurrency", &format!("{}", concurrency)); + .append_pair("concurrency", 
&format!("{concurrency}")); } self.request(Method::POST, path, ()).await.map(|_| ()) @@ -745,9 +746,11 @@ impl Client { timeline_id: TimelineId, base_lsn: Lsn, end_lsn: Lsn, - pg_version: u32, + pg_version: PgMajorVersion, basebackup_tarball: ReqwestBody, ) -> Result<()> { + let pg_version = pg_version.major_version_num(); + let uri = format!( "{}/v1/tenant/{tenant_id}/timeline/{timeline_id}/import_basebackup?base_lsn={base_lsn}&end_lsn={end_lsn}&pg_version={pg_version}", self.mgmt_api_endpoint, @@ -841,4 +844,13 @@ impl Client { .await .map_err(Error::ReceiveBody) } + + pub async fn update_feature_flag_spec(&self, spec: String) -> Result<()> { + let uri = format!("{}/v1/feature_flag_spec", self.mgmt_api_endpoint); + self.request(Method::POST, uri, spec) + .await? + .json() + .await + .map_err(Error::ReceiveBody) + } } diff --git a/pageserver/client/src/page_service.rs b/pageserver/client/src/page_service.rs index ef35ac2f48..085c0e6543 100644 --- a/pageserver/client/src/page_service.rs +++ b/pageserver/client/src/page_service.rs @@ -2,7 +2,7 @@ use std::sync::{Arc, Mutex}; use futures::stream::{SplitSink, SplitStream}; use futures::{SinkExt, StreamExt}; -use pageserver_api::models::{ +use pageserver_api::pagestream_api::{ PagestreamBeMessage, PagestreamFeMessage, PagestreamGetPageRequest, PagestreamGetPageResponse, }; use pageserver_api::reltag::RelTag; diff --git a/pageserver/compaction/src/simulator/draw.rs b/pageserver/compaction/src/simulator/draw.rs index 3d35d1b91e..8322fe7d6d 100644 --- a/pageserver/compaction/src/simulator/draw.rs +++ b/pageserver/compaction/src/simulator/draw.rs @@ -152,7 +152,7 @@ pub fn draw_history(history: &[LayerTraceEvent], mut output: let key_diff = key_end - key_start; if key_start >= key_end { - panic!("Invalid key range {}-{}", key_start, key_end); + panic!("Invalid key range {key_start}-{key_end}"); } let lsn_start = lsn_map.map(f.lsn_range.start); @@ -212,12 +212,12 @@ pub fn draw_history(history: &[LayerTraceEvent], mut output: )?; writeln!(svg, "")?; } - Ordering::Greater => panic!("Invalid lsn range {}-{}", lsn_start, lsn_end), + Ordering::Greater => panic!("Invalid lsn range {lsn_start}-{lsn_end}"), } files_seen.insert(f); } - writeln!(svg, "{}", EndSvg)?; + writeln!(svg, "{EndSvg}")?; let mut layer_events_str = String::new(); let mut first = true; diff --git a/pageserver/ctl/src/draw_timeline_dir.rs b/pageserver/ctl/src/draw_timeline_dir.rs index 80ca414543..2135d302c1 100644 --- a/pageserver/ctl/src/draw_timeline_dir.rs +++ b/pageserver/ctl/src/draw_timeline_dir.rs @@ -20,7 +20,7 @@ //! //! # local timeline dir //! ls test_output/test_pgbench\[neon-45-684\]/repo/tenants/$TENANT/timelines/$TIMELINE | \ -//! grep "__" | cargo run --release --bin pagectl draw-timeline-dir > out.svg +//! grep "__" | cargo run --release --bin pagectl draw-timeline > out.svg //! //! # Layer map dump from `/v1/tenant/$TENANT/timeline/$TIMELINE/layer` //! 
(jq -r '.historic_layers[] | .layer_file_name' | cargo run -p pagectl draw-timeline) < layer-map.json > out.svg @@ -81,7 +81,11 @@ fn build_coordinate_compression_map(coords: Vec) -> BTreeMap (Range, Range) { let split: Vec<&str> = name.split("__").collect(); let keys: Vec<&str> = split[0].split('-').collect(); - let mut lsns: Vec<&str> = split[1].split('-').collect(); + + // Remove the temporary file extension, e.g., remove the `.d20a.___temp` part from the following filename: + // 000000067F000040490000404A00441B0000-000000067F000040490000404A00441B4000__000043483A34CE00.d20a.___temp + let lsns = split[1].split('.').collect::>()[0]; + let mut lsns: Vec<&str> = lsns.split('-').collect(); // The current format of the layer file name: 000000067F0000000400000B150100000000-000000067F0000000400000D350100000000__00000000014B7AC8-v1-00000001 @@ -224,7 +228,7 @@ pub fn main() -> Result<()> { let lsn_max = lsn_map.len(); if key_start >= key_end { - panic!("Invalid key range {}-{}", key_start, key_end); + panic!("Invalid key range {key_start}-{key_end}"); } let lsn_start = *lsn_map.get(&lsnr.start).unwrap(); @@ -246,7 +250,7 @@ pub fn main() -> Result<()> { ymargin = 0.05; fill = Fill::Color(rgb(0, 0, 0)); } - Ordering::Greater => panic!("Invalid lsn range {}-{}", lsn_start, lsn_end), + Ordering::Greater => panic!("Invalid lsn range {lsn_start}-{lsn_end}"), } println!( @@ -283,10 +287,10 @@ pub fn main() -> Result<()> { ); } - println!("{}", EndSvg); + println!("{EndSvg}"); - eprintln!("num_images: {}", num_images); - eprintln!("num_deltas: {}", num_deltas); + eprintln!("num_images: {num_images}"); + eprintln!("num_deltas: {num_deltas}"); Ok(()) } diff --git a/pageserver/ctl/src/key.rs b/pageserver/ctl/src/key.rs index 600f7c412e..c4daafdfd0 100644 --- a/pageserver/ctl/src/key.rs +++ b/pageserver/ctl/src/key.rs @@ -372,7 +372,7 @@ impl std::fmt::Debug for RelTagish { f.write_char('/')?; } first = false; - write!(f, "{}", x) + write!(f, "{x}") }) } } diff --git a/pageserver/ctl/src/layer_map_analyzer.rs b/pageserver/ctl/src/layer_map_analyzer.rs index c49c8b58df..ef844fbd0f 100644 --- a/pageserver/ctl/src/layer_map_analyzer.rs +++ b/pageserver/ctl/src/layer_map_analyzer.rs @@ -224,8 +224,7 @@ pub(crate) async fn main(cmd: &AnalyzeLayerMapCmd) -> Result<()> { } } println!( - "Total delta layers {} image layers {} excess layers {}", - total_delta_layers, total_image_layers, total_excess_layers + "Total delta layers {total_delta_layers} image layers {total_image_layers} excess layers {total_excess_layers}" ); Ok(()) } diff --git a/pageserver/ctl/src/layers.rs b/pageserver/ctl/src/layers.rs index 79f56a5a51..42b3e4a9ba 100644 --- a/pageserver/ctl/src/layers.rs +++ b/pageserver/ctl/src/layers.rs @@ -13,7 +13,7 @@ use pageserver::{page_cache, virtual_file}; use pageserver_api::key::Key; use utils::id::{TenantId, TimelineId}; -use crate::layer_map_analyzer::parse_filename; +use crate::layer_map_analyzer::{LayerFile, parse_filename}; #[derive(Subcommand)] pub(crate) enum LayerCmd { @@ -38,6 +38,8 @@ pub(crate) enum LayerCmd { /// The id from list-layer command id: usize, }, + /// Dump all information of a layer file locally + DumpLayerLocal { path: PathBuf }, RewriteSummary { layer_file_path: Utf8PathBuf, #[clap(long)] @@ -131,15 +133,7 @@ pub(crate) async fn main(cmd: &LayerCmd) -> Result<()> { } for (idx, layer_file) in to_print { - println!( - "[{:3}] key:{}-{}\n lsn:{}-{}\n delta:{}", - idx, - layer_file.key_range.start, - layer_file.key_range.end, - layer_file.lsn_range.start, - 
layer_file.lsn_range.end, - layer_file.is_delta, - ); + print_layer_file(idx, &layer_file); } Ok(()) } @@ -159,16 +153,7 @@ pub(crate) async fn main(cmd: &LayerCmd) -> Result<()> { let layer = layer?; if let Ok(layer_file) = parse_filename(&layer.file_name().into_string().unwrap()) { if *id == idx { - // TODO(chi): dedup code - println!( - "[{:3}] key:{}-{}\n lsn:{}-{}\n delta:{}", - idx, - layer_file.key_range.start, - layer_file.key_range.end, - layer_file.lsn_range.start, - layer_file.lsn_range.end, - layer_file.is_delta, - ); + print_layer_file(idx, &layer_file); if layer_file.is_delta { read_delta_file(layer.path(), &ctx).await?; @@ -183,6 +168,18 @@ pub(crate) async fn main(cmd: &LayerCmd) -> Result<()> { } Ok(()) } + LayerCmd::DumpLayerLocal { path } => { + if let Ok(layer_file) = parse_filename(path.file_name().unwrap().to_str().unwrap()) { + print_layer_file(0, &layer_file); + + if layer_file.is_delta { + read_delta_file(path, &ctx).await?; + } else { + read_image_file(path, &ctx).await?; + } + } + Ok(()) + } LayerCmd::RewriteSummary { layer_file_path, new_tenant_id, @@ -247,3 +244,15 @@ pub(crate) async fn main(cmd: &LayerCmd) -> Result<()> { } } } + +fn print_layer_file(idx: usize, layer_file: &LayerFile) { + println!( + "[{:3}] key:{}-{}\n lsn:{}-{}\n delta:{}", + idx, + layer_file.key_range.start, + layer_file.key_range.end, + layer_file.lsn_range.start, + layer_file.lsn_range.end, + layer_file.is_delta, + ); +} diff --git a/pageserver/ctl/src/main.rs b/pageserver/ctl/src/main.rs index 1d81b839a8..3cd4faaf2e 100644 --- a/pageserver/ctl/src/main.rs +++ b/pageserver/ctl/src/main.rs @@ -176,9 +176,11 @@ async fn main() -> anyhow::Result<()> { let config = RemoteStorageConfig::from_toml_str(&cmd.config_toml_str)?; let storage = remote_storage::GenericRemoteStorage::from_config(&config).await; let cancel = CancellationToken::new(); + // Complexity limit: as we are running this command locally, we should have a lot of memory available, and we do not + // need to limit the number of versions we are going to delete. storage .unwrap() - .time_travel_recover(Some(&prefix), timestamp, done_if_after, &cancel) + .time_travel_recover(Some(&prefix), timestamp, done_if_after, &cancel, None) .await?; } Commands::Key(dkc) => dkc.execute(), diff --git a/pageserver/page_api/Cargo.toml b/pageserver/page_api/Cargo.toml index e643b5749b..c5283c2b09 100644 --- a/pageserver/page_api/Cargo.toml +++ b/pageserver/page_api/Cargo.toml @@ -5,11 +5,16 @@ edition.workspace = true license.workspace = true [dependencies] +anyhow.workspace = true bytes.workspace = true +futures.workspace = true pageserver_api.workspace = true postgres_ffi.workspace = true prost.workspace = true +strum.workspace = true +strum_macros.workspace = true thiserror.workspace = true +tokio.workspace = true tonic.workspace = true utils.workspace = true workspace_hack.workspace = true diff --git a/pageserver/page_api/proto/page_service.proto b/pageserver/page_api/proto/page_service.proto index 44976084bf..d06b2cfca5 100644 --- a/pageserver/page_api/proto/page_service.proto +++ b/pageserver/page_api/proto/page_service.proto @@ -102,12 +102,27 @@ message CheckRelExistsResponse { bool exists = 1; } -// Requests a base backup at a given LSN. +// Requests a base backup. message GetBaseBackupRequest { - // The LSN to fetch a base backup at. - ReadLsn read_lsn = 1; + // The LSN to fetch the base backup at. 0 or absent means the latest LSN known to the Pageserver. + uint64 lsn = 1; // If true, logical replication slots will not be created. 
bool replica = 2; + // If true, include relation files in the base backup. Mainly for debugging and tests. + bool full = 3; + // Compression algorithm to use. Base backups send a compressed payload instead of using gRPC + // compression, so that we can cache compressed backups on the server. + BaseBackupCompression compression = 4; +} + +// Base backup compression algorithms. +enum BaseBackupCompression { + // Unknown algorithm. Used when clients send an unsupported algorithm. + BASE_BACKUP_COMPRESSION_UNKNOWN = 0; + // No compression. + BASE_BACKUP_COMPRESSION_NONE = 1; + // GZIP compression. + BASE_BACKUP_COMPRESSION_GZIP = 2; } // Base backup response chunk, returned as an ordered stream. diff --git a/pageserver/page_api/src/client.rs b/pageserver/page_api/src/client.rs new file mode 100644 index 0000000000..71d539ab91 --- /dev/null +++ b/pageserver/page_api/src/client.rs @@ -0,0 +1,199 @@ +use std::convert::TryInto; + +use bytes::Bytes; +use futures::TryStreamExt; +use futures::{Stream, StreamExt}; +use tonic::metadata::AsciiMetadataValue; +use tonic::metadata::errors::InvalidMetadataValue; +use tonic::transport::Channel; +use tonic::{Request, Streaming}; + +use utils::id::TenantId; +use utils::id::TimelineId; +use utils::shard::ShardIndex; + +use anyhow::Result; + +use crate::model; +use crate::proto; + +/// +/// AuthInterceptor adds tenant, timeline, and auth header to the channel. These +/// headers are required at the pageserver. +/// +#[derive(Clone)] +struct AuthInterceptor { + tenant_id: AsciiMetadataValue, + timeline_id: AsciiMetadataValue, + shard_id: AsciiMetadataValue, + auth_header: Option, // including "Bearer " prefix +} + +impl AuthInterceptor { + fn new( + tenant_id: TenantId, + timeline_id: TimelineId, + auth_token: Option, + shard_id: ShardIndex, + ) -> Result { + let tenant_ascii: AsciiMetadataValue = tenant_id.to_string().try_into()?; + let timeline_ascii: AsciiMetadataValue = timeline_id.to_string().try_into()?; + let shard_ascii: AsciiMetadataValue = shard_id.to_string().try_into()?; + + let auth_header: Option = match auth_token { + Some(token) => Some(format!("Bearer {token}").try_into()?), + None => None, + }; + + Ok(Self { + tenant_id: tenant_ascii, + shard_id: shard_ascii, + timeline_id: timeline_ascii, + auth_header, + }) + } +} + +impl tonic::service::Interceptor for AuthInterceptor { + fn call(&mut self, mut req: tonic::Request<()>) -> Result, tonic::Status> { + req.metadata_mut() + .insert("neon-tenant-id", self.tenant_id.clone()); + req.metadata_mut() + .insert("neon-shard-id", self.shard_id.clone()); + req.metadata_mut() + .insert("neon-timeline-id", self.timeline_id.clone()); + if let Some(auth_header) = &self.auth_header { + req.metadata_mut() + .insert("authorization", auth_header.clone()); + } + Ok(req) + } +} +#[derive(Clone)] +pub struct Client { + client: proto::PageServiceClient< + tonic::service::interceptor::InterceptedService, + >, +} + +impl Client { + pub async fn new + Send + Sync + 'static>( + into_endpoint: T, + tenant_id: TenantId, + timeline_id: TimelineId, + shard_id: ShardIndex, + auth_header: Option, + compression: Option, + ) -> anyhow::Result { + let endpoint: tonic::transport::Endpoint = into_endpoint + .try_into() + .map_err(|_e| anyhow::anyhow!("failed to convert endpoint"))?; + let channel = endpoint.connect().await?; + let auth = AuthInterceptor::new(tenant_id, timeline_id, auth_header, shard_id) + .map_err(|e| anyhow::anyhow!(e.to_string()))?; + let mut client = proto::PageServiceClient::with_interceptor(channel, auth); + + 
if let Some(compression) = compression { + // TODO: benchmark this (including network latency). + client = client + .accept_compressed(compression) + .send_compressed(compression); + } + + Ok(Self { client }) + } + + /// Returns whether a relation exists. + pub async fn check_rel_exists( + &mut self, + req: model::CheckRelExistsRequest, + ) -> Result { + let proto_req = proto::CheckRelExistsRequest::from(req); + + let response = self.client.check_rel_exists(proto_req).await?; + + let proto_resp = response.into_inner(); + Ok(proto_resp.into()) + } + + /// Fetches a base backup. + pub async fn get_base_backup( + &mut self, + req: model::GetBaseBackupRequest, + ) -> Result> + 'static, tonic::Status> { + let proto_req = proto::GetBaseBackupRequest::from(req); + + let response_stream: Streaming = + self.client.get_base_backup(proto_req).await?.into_inner(); + + // TODO: Consider dechunking internally + let domain_stream = response_stream.map(|chunk_res| { + chunk_res.and_then(|proto_chunk| { + proto_chunk.try_into().map_err(|e| { + tonic::Status::internal(format!("Failed to convert response chunk: {e}")) + }) + }) + }); + + Ok(domain_stream) + } + + /// Returns the total size of a database, as # of bytes. + pub async fn get_db_size( + &mut self, + req: model::GetDbSizeRequest, + ) -> Result { + let proto_req = proto::GetDbSizeRequest::from(req); + + let response = self.client.get_db_size(proto_req).await?; + Ok(response.into_inner().into()) + } + + /// Fetches pages. + /// + /// This is implemented as a bidirectional streaming RPC for performance. + /// Per-request errors are often returned as status_code instead of errors, + /// to avoid tearing down the entire stream via tonic::Status. + pub async fn get_pages( + &mut self, + inbound: ReqSt, + ) -> Result< + impl Stream> + Send + 'static, + tonic::Status, + > + where + ReqSt: Stream + Send + 'static, + { + let outbound_proto = inbound.map(|domain_req| domain_req.into()); + + let req_new = Request::new(outbound_proto); + + let response_stream: Streaming = + self.client.get_pages(req_new).await?.into_inner(); + + let domain_stream = response_stream.map_ok(model::GetPageResponse::from); + + Ok(domain_stream) + } + + /// Returns the size of a relation, as # of blocks. + pub async fn get_rel_size( + &mut self, + req: model::GetRelSizeRequest, + ) -> Result { + let proto_req = proto::GetRelSizeRequest::from(req); + let response = self.client.get_rel_size(proto_req).await?; + let proto_resp = response.into_inner(); + Ok(proto_resp.into()) + } + + /// Fetches an SLRU segment. + pub async fn get_slru_segment( + &mut self, + req: model::GetSlruSegmentRequest, + ) -> Result { + let proto_req = proto::GetSlruSegmentRequest::from(req); + let response = self.client.get_slru_segment(proto_req).await?; + Ok(response.into_inner().try_into()?) + } +} diff --git a/pageserver/page_api/src/lib.rs b/pageserver/page_api/src/lib.rs index f515f27f3e..e78f6ce206 100644 --- a/pageserver/page_api/src/lib.rs +++ b/pageserver/page_api/src/lib.rs @@ -18,6 +18,8 @@ pub mod proto { pub use page_service_server::{PageService, PageServiceServer}; } +mod client; +pub use client::Client; mod model; pub use model::*; diff --git a/pageserver/page_api/src/model.rs b/pageserver/page_api/src/model.rs index 1a08d04cc1..1ca89b4870 100644 --- a/pageserver/page_api/src/model.rs +++ b/pageserver/page_api/src/model.rs @@ -26,7 +26,7 @@ use utils::lsn::Lsn; use crate::proto; /// A protocol error. Typically returned via try_from() or try_into(). 
-#[derive(thiserror::Error, Debug)]
+#[derive(thiserror::Error, Clone, Debug)]
 pub enum ProtocolError {
     #[error("field '{0}' has invalid value '{1}'")]
     Invalid(&'static str, String),
@@ -182,13 +182,18 @@ impl From<CheckRelExistsResponse> for proto::CheckRelExistsResponse {
     }
 }
 
-/// Requests a base backup at a given LSN.
+/// Requests a base backup.
 #[derive(Clone, Copy, Debug)]
 pub struct GetBaseBackupRequest {
-    /// The LSN to fetch a base backup at.
-    pub read_lsn: ReadLsn,
+    /// The LSN to fetch a base backup at. If None, uses the latest LSN known to the Pageserver.
+    pub lsn: Option<Lsn>,
     /// If true, logical replication slots will not be created.
     pub replica: bool,
+    /// If true, include relation files in the base backup. Mainly for debugging and tests.
+    pub full: bool,
+    /// Compression algorithm to use. Base backups send a compressed payload instead of using gRPC
+    /// compression, so that we can cache compressed backups on the server.
+    pub compression: BaseBackupCompression,
 }
 
 impl TryFrom<proto::GetBaseBackupRequest> for GetBaseBackupRequest {
@@ -196,11 +201,10 @@ impl TryFrom<proto::GetBaseBackupRequest> for GetBaseBackupRequest {
     fn try_from(pb: proto::GetBaseBackupRequest) -> Result<Self, Self::Error> {
         Ok(Self {
-            read_lsn: pb
-                .read_lsn
-                .ok_or(ProtocolError::Missing("read_lsn"))?
-                .try_into()?,
+            lsn: (pb.lsn != 0).then_some(Lsn(pb.lsn)),
             replica: pb.replica,
+            full: pb.full,
+            compression: pb.compression.try_into()?,
         })
     }
 }
@@ -208,12 +212,58 @@ impl From<GetBaseBackupRequest> for proto::GetBaseBackupRequest {
     fn from(request: GetBaseBackupRequest) -> Self {
         Self {
-            read_lsn: Some(request.read_lsn.into()),
+            lsn: request.lsn.unwrap_or_default().0,
             replica: request.replica,
+            full: request.full,
+            compression: request.compression.into(),
         }
     }
 }
 
+/// Base backup compression algorithm.
+#[derive(Clone, Copy, Debug)]
+pub enum BaseBackupCompression {
+    None,
+    Gzip,
+}
+
+impl TryFrom<proto::BaseBackupCompression> for BaseBackupCompression {
+    type Error = ProtocolError;
+
+    fn try_from(pb: proto::BaseBackupCompression) -> Result<Self, Self::Error> {
+        match pb {
+            proto::BaseBackupCompression::Unknown => Err(ProtocolError::invalid("compression", pb)),
+            proto::BaseBackupCompression::None => Ok(Self::None),
+            proto::BaseBackupCompression::Gzip => Ok(Self::Gzip),
+        }
+    }
+}
+
+impl TryFrom<i32> for BaseBackupCompression {
+    type Error = ProtocolError;
+
+    fn try_from(compression: i32) -> Result<Self, Self::Error> {
+        proto::BaseBackupCompression::try_from(compression)
+            .map_err(|_| ProtocolError::invalid("compression", compression))
+            .and_then(Self::try_from)
+    }
+}
+
+impl From<BaseBackupCompression> for proto::BaseBackupCompression {
+    fn from(compression: BaseBackupCompression) -> Self {
+        match compression {
+            BaseBackupCompression::None => Self::None,
+            BaseBackupCompression::Gzip => Self::Gzip,
+        }
+    }
+}
+
+impl From<BaseBackupCompression> for i32 {
+    fn from(compression: BaseBackupCompression) -> Self {
+        proto::BaseBackupCompression::from(compression).into()
+    }
+}
+
 pub type GetBaseBackupResponseChunk = Bytes;
 
 impl TryFrom<proto::GetBaseBackupResponseChunk> for GetBaseBackupResponseChunk {
@@ -422,12 +472,45 @@ impl From<GetPageResponse> for proto::GetPageResponse {
     }
 }
 
+impl GetPageResponse {
+    /// Attempts to represent a tonic::Status as a GetPageResponse if appropriate. Returning a
+    /// tonic::Status will terminate the GetPage stream, so per-request errors are emitted as a
+    /// GetPageResponse with a non-OK status code instead.
+    #[allow(clippy::result_large_err)]
+    pub fn try_from_status(
+        status: tonic::Status,
+        request_id: RequestID,
+    ) -> Result<Self, tonic::Status> {
+        // We shouldn't see an OK status here, because we're emitting an error.
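+        // In debug builds the assertion below fails fast; release builds fall through to
+        // the runtime check and terminate the stream with an Internal status instead.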
+        debug_assert_ne!(status.code(), tonic::Code::Ok);
+        if status.code() == tonic::Code::Ok {
+            return Err(tonic::Status::internal(format!(
+                "unexpected OK status: {status:?}",
+            )));
+        }
+
+        // If we can't convert the tonic::Code to a GetPageStatusCode, this is not a per-request
+        // error and we should return a tonic::Status to terminate the stream.
+        let Ok(status_code) = status.code().try_into() else {
+            return Err(status);
+        };
+
+        // Return a GetPageResponse for the status.
+        Ok(Self {
+            request_id,
+            status_code,
+            reason: Some(status.message().to_string()),
+            page_images: Vec::new(),
+        })
+    }
+}
+
 /// A GetPage response status code.
 ///
 /// These are effectively equivalent to gRPC statuses. However, we use a bidirectional stream
 /// (potentially shared by many backends), and a gRPC status response would terminate the stream so
 /// we send GetPageResponse messages with these codes instead.
-#[derive(Clone, Copy, Debug)]
+#[derive(Clone, Copy, Debug, PartialEq, strum_macros::Display)]
 pub enum GetPageStatusCode {
     /// Unknown status. For forwards compatibility: used when an older client version receives a new
     /// status code from a newer server version.
@@ -485,8 +568,42 @@ impl From<GetPageStatusCode> for i32 {
     }
 }
 
+impl TryFrom<tonic::Code> for GetPageStatusCode {
+    type Error = tonic::Code;
+
+    fn try_from(code: tonic::Code) -> Result<Self, Self::Error> {
+        use tonic::Code;
+
+        let status_code = match code {
+            Code::Ok => Self::Ok,
+
+            // These are per-request errors, which should be returned as GetPageResponses.
+            Code::AlreadyExists => Self::InvalidRequest,
+            Code::DataLoss => Self::InternalError,
+            Code::FailedPrecondition => Self::InvalidRequest,
+            Code::InvalidArgument => Self::InvalidRequest,
+            Code::Internal => Self::InternalError,
+            Code::NotFound => Self::NotFound,
+            Code::OutOfRange => Self::InvalidRequest,
+            Code::ResourceExhausted => Self::SlowDown,
+
+            // These should terminate the stream by returning a tonic::Status.
+            Code::Aborted
+            | Code::Cancelled
+            | Code::DeadlineExceeded
+            | Code::PermissionDenied
+            | Code::Unauthenticated
+            | Code::Unavailable
+            | Code::Unimplemented
+            | Code::Unknown => return Err(code),
+        };
+        Ok(status_code)
+    }
+}
+
 // Fetches the size of a relation at a given LSN, as # of blocks. Only valid on shard 0, other
 // shards will error.
+#[derive(Clone, Copy, Debug)]
 pub struct GetRelSizeRequest {
     pub read_lsn: ReadLsn,
     pub rel: RelTag,
@@ -530,6 +647,7 @@ impl From<GetRelSizeResponse> for proto::GetRelSizeResponse {
 }
 
 /// Requests an SLRU segment. Only valid on shard 0, other shards will error.
+#[derive(Clone, Copy, Debug)] pub struct GetSlruSegmentRequest { pub read_lsn: ReadLsn, pub kind: SlruKind, diff --git a/pageserver/pagebench/Cargo.toml b/pageserver/pagebench/Cargo.toml index 5e4af88e69..f5dfc0db25 100644 --- a/pageserver/pagebench/Cargo.toml +++ b/pageserver/pagebench/Cargo.toml @@ -25,6 +25,7 @@ tokio.workspace = true tokio-stream.workspace = true tokio-util.workspace = true tonic.workspace = true +url.workspace = true pageserver_client.workspace = true pageserver_api.workspace = true diff --git a/pageserver/pagebench/src/cmd/aux_files.rs b/pageserver/pagebench/src/cmd/aux_files.rs index 6441c047c2..43d7a73399 100644 --- a/pageserver/pagebench/src/cmd/aux_files.rs +++ b/pageserver/pagebench/src/cmd/aux_files.rs @@ -62,7 +62,7 @@ async fn main_impl(args: Args) -> anyhow::Result<()> { let tenant_shard_id = TenantShardId::unsharded(timeline.tenant_id); let timeline_id = timeline.timeline_id; - println!("operating on timeline {}", timeline); + println!("operating on timeline {timeline}"); mgmt_api_client .set_tenant_config(&TenantConfigRequest { @@ -75,8 +75,8 @@ async fn main_impl(args: Args) -> anyhow::Result<()> { let items = (0..100) .map(|id| { ( - format!("pg_logical/mappings/{:03}.{:03}", batch, id), - format!("{:08}", id), + format!("pg_logical/mappings/{batch:03}.{id:03}"), + format!("{id:08}"), ) }) .collect::>(); diff --git a/pageserver/pagebench/src/cmd/basebackup.rs b/pageserver/pagebench/src/cmd/basebackup.rs index 43ad92980c..4111d09f92 100644 --- a/pageserver/pagebench/src/cmd/basebackup.rs +++ b/pageserver/pagebench/src/cmd/basebackup.rs @@ -1,20 +1,29 @@ use std::collections::HashMap; use std::num::NonZeroUsize; use std::ops::Range; -use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; +use std::pin::Pin; +use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::{Arc, Mutex}; use std::time::Instant; -use anyhow::Context; +use anyhow::anyhow; +use futures::TryStreamExt as _; use pageserver_api::shard::TenantShardId; use pageserver_client::mgmt_api::ForceAwaitLogicalSize; use pageserver_client::page_service::BasebackupRequest; +use pageserver_page_api as page_api; use rand::prelude::*; +use tokio::io::AsyncRead; use tokio::sync::Barrier; use tokio::task::JoinSet; +use tokio_util::compat::{TokioAsyncReadCompatExt as _, TokioAsyncWriteCompatExt as _}; +use tokio_util::io::StreamReader; +use tonic::async_trait; use tracing::{info, instrument}; +use url::Url; use utils::id::TenantTimelineId; use utils::lsn::Lsn; +use utils::shard::ShardIndex; use crate::util::tokio_thread_local_stats::AllThreadLocalStats; use crate::util::{request_stats, tokio_thread_local_stats}; @@ -24,14 +33,15 @@ use crate::util::{request_stats, tokio_thread_local_stats}; pub(crate) struct Args { #[clap(long, default_value = "http://localhost:9898")] mgmt_api_endpoint: String, - #[clap(long, default_value = "postgres://postgres@localhost:64000")] + /// The Pageserver to connect to. Use postgresql:// for libpq, or grpc:// for gRPC. 
+ #[clap(long, default_value = "postgresql://postgres@localhost:64000")] page_service_connstring: String, #[clap(long)] pageserver_jwt: Option, #[clap(long, default_value = "1")] num_clients: NonZeroUsize, - #[clap(long, default_value = "1.0")] - gzip_probability: f64, + #[clap(long)] + no_compression: bool, #[clap(long)] runtime: Option, #[clap(long)] @@ -146,12 +156,27 @@ async fn main_impl( let mut work_senders = HashMap::new(); let mut tasks = Vec::new(); - for tl in &timelines { + let scheme = match Url::parse(&args.page_service_connstring) { + Ok(url) => url.scheme().to_lowercase().to_string(), + Err(url::ParseError::RelativeUrlWithoutBase) => "postgresql".to_string(), + Err(err) => return Err(anyhow!("invalid connstring: {err}")), + }; + for &tl in &timelines { let (sender, receiver) = tokio::sync::mpsc::channel(1); // TODO: not sure what the implications of this are work_senders.insert(tl, sender); - tasks.push(tokio::spawn(client( - args, - *tl, + + let client: Box = match scheme.as_str() { + "postgresql" | "postgres" => Box::new( + LibpqClient::new(&args.page_service_connstring, tl, !args.no_compression).await?, + ), + "grpc" => Box::new( + GrpcClient::new(&args.page_service_connstring, tl, !args.no_compression).await?, + ), + scheme => return Err(anyhow!("invalid scheme {scheme}")), + }; + + tasks.push(tokio::spawn(run_worker( + client, Arc::clone(&start_work_barrier), receiver, Arc::clone(&all_work_done_barrier), @@ -166,13 +191,7 @@ async fn main_impl( let mut rng = rand::thread_rng(); let target = all_targets.choose(&mut rng).unwrap(); let lsn = target.lsn_range.clone().map(|r| rng.gen_range(r)); - ( - target.timeline, - Work { - lsn, - gzip: rng.gen_bool(args.gzip_probability), - }, - ) + (target.timeline, Work { lsn }) }; let sender = work_senders.get(&timeline).unwrap(); // TODO: what if this blocks? @@ -216,13 +235,11 @@ async fn main_impl( #[derive(Copy, Clone)] struct Work { lsn: Option, - gzip: bool, } #[instrument(skip_all)] -async fn client( - args: &'static Args, - timeline: TenantTimelineId, +async fn run_worker( + mut client: Box, start_work_barrier: Arc, mut work: tokio::sync::mpsc::Receiver, all_work_done_barrier: Arc, @@ -230,37 +247,14 @@ async fn client( ) { start_work_barrier.wait().await; - let client = pageserver_client::page_service::Client::new(args.page_service_connstring.clone()) - .await - .unwrap(); - - while let Some(Work { lsn, gzip }) = work.recv().await { + while let Some(Work { lsn }) = work.recv().await { let start = Instant::now(); - let copy_out_stream = client - .basebackup(&BasebackupRequest { - tenant_id: timeline.tenant_id, - timeline_id: timeline.timeline_id, - lsn, - gzip, - }) - .await - .with_context(|| format!("start basebackup for {timeline}")) - .unwrap(); + let stream = client.basebackup(lsn).await.unwrap(); - use futures::StreamExt; - let size = Arc::new(AtomicUsize::new(0)); - copy_out_stream - .for_each({ - |r| { - let size = Arc::clone(&size); - async move { - let size = Arc::clone(&size); - size.fetch_add(r.unwrap().len(), Ordering::Relaxed); - } - } - }) - .await; - info!("basebackup size is {} bytes", size.load(Ordering::Relaxed)); + let size = futures::io::copy(stream.compat(), &mut tokio::io::sink().compat_write()) + .await + .unwrap(); + info!("basebackup size is {size} bytes"); let elapsed = start.elapsed(); live_stats.inc(); STATS.with(|stats| { @@ -270,3 +264,100 @@ async fn client( all_work_done_barrier.wait().await; } + +/// A basebackup client. This allows switching out the client protocol implementation. 
+#[async_trait]
+trait Client: Send {
+    async fn basebackup(
+        &mut self,
+        lsn: Option<Lsn>,
+    ) -> anyhow::Result<Pin<Box<dyn AsyncRead + Send>>>;
+}
+
+/// A libpq-based Pageserver client.
+struct LibpqClient {
+    inner: pageserver_client::page_service::Client,
+    ttid: TenantTimelineId,
+    compression: bool,
+}
+
+impl LibpqClient {
+    async fn new(
+        connstring: &str,
+        ttid: TenantTimelineId,
+        compression: bool,
+    ) -> anyhow::Result<Self> {
+        Ok(Self {
+            inner: pageserver_client::page_service::Client::new(connstring.to_string()).await?,
+            ttid,
+            compression,
+        })
+    }
+}
+
+#[async_trait]
+impl Client for LibpqClient {
+    async fn basebackup(
+        &mut self,
+        lsn: Option<Lsn>,
+    ) -> anyhow::Result<Pin<Box<dyn AsyncRead + Send>>> {
+        let req = BasebackupRequest {
+            tenant_id: self.ttid.tenant_id,
+            timeline_id: self.ttid.timeline_id,
+            lsn,
+            gzip: self.compression,
+        };
+        let stream = self.inner.basebackup(&req).await?;
+        Ok(Box::pin(StreamReader::new(
+            stream.map_err(std::io::Error::other),
+        )))
+    }
+}
+
+/// A gRPC Pageserver client.
+struct GrpcClient {
+    inner: page_api::Client,
+    compression: page_api::BaseBackupCompression,
+}
+
+impl GrpcClient {
+    async fn new(
+        connstring: &str,
+        ttid: TenantTimelineId,
+        compression: bool,
+    ) -> anyhow::Result<Self> {
+        let inner = page_api::Client::new(
+            connstring.to_string(),
+            ttid.tenant_id,
+            ttid.timeline_id,
+            ShardIndex::unsharded(),
+            None,
+            None, // NB: uses payload compression
+        )
+        .await?;
+        let compression = match compression {
+            true => page_api::BaseBackupCompression::Gzip,
+            false => page_api::BaseBackupCompression::None,
+        };
+        Ok(Self { inner, compression })
+    }
+}
+
+#[async_trait]
+impl Client for GrpcClient {
+    async fn basebackup(
+        &mut self,
+        lsn: Option<Lsn>,
+    ) -> anyhow::Result<Pin<Box<dyn AsyncRead + Send>>> {
+        let req = page_api::GetBaseBackupRequest {
+            lsn,
+            replica: false,
+            full: false,
+            compression: self.compression,
+        };
+        let stream = self.inner.get_base_backup(req).await?;
+        Ok(Box::pin(StreamReader::new(
+            stream.map_err(std::io::Error::other),
+        )))
+    }
+}
diff --git a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs
index 3f3b6e396e..a297819e9b 100644
--- a/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs
+++ b/pageserver/pagebench/src/cmd/getpage_latest_lsn.rs
@@ -10,33 +10,31 @@ use anyhow::Context;
 use async_trait::async_trait;
 use bytes::Bytes;
 use camino::Utf8PathBuf;
+use futures::{Stream, StreamExt as _};
 use pageserver_api::key::Key;
 use pageserver_api::keyspace::KeySpaceAccum;
-use pageserver_api::models::{PagestreamGetPageRequest, PagestreamRequest};
+use pageserver_api::pagestream_api::{PagestreamGetPageRequest, PagestreamRequest};
 use pageserver_api::reltag::RelTag;
 use pageserver_api::shard::TenantShardId;
-use pageserver_page_api::proto;
+use pageserver_page_api as page_api;
 use rand::prelude::*;
 use tokio::task::JoinSet;
 use tokio_util::sync::CancellationToken;
 use tracing::info;
+use url::Url;
 use utils::id::TenantTimelineId;
 use utils::lsn::Lsn;
+use utils::shard::ShardIndex;
 
 use crate::util::tokio_thread_local_stats::AllThreadLocalStats;
 use crate::util::{request_stats, tokio_thread_local_stats};
 
-#[derive(clap::ValueEnum, Clone, Debug)]
-enum Protocol {
-    Libpq,
-    Grpc,
-}
-
 /// GetPage@LatestLSN, uniformly distributed across the compute-accessible keyspace.
 #[derive(clap::Parser)]
 pub(crate) struct Args {
     #[clap(long, default_value = "http://localhost:9898")]
     mgmt_api_endpoint: String,
+    /// Pageserver connection string. Supports postgresql:// and grpc:// protocols.
 #[clap(long, default_value = "postgres://postgres@localhost:64000")]
     page_service_connstring: String,
     #[clap(long)]
@@ -45,8 +43,9 @@ pub(crate) struct Args {
     num_clients: NonZeroUsize,
     #[clap(long)]
     runtime: Option<humantime::Duration>,
-    #[clap(long, value_enum, default_value = "libpq")]
-    protocol: Protocol,
+    /// If true, enable compression (only for gRPC).
+    #[clap(long)]
+    compression: bool,
     /// Each client sends requests at the given rate.
     ///
     /// If a request takes too long and we should be issuing a new request already,
@@ -325,18 +324,32 @@ async fn main_impl(
             .unwrap();
 
         Box::pin(async move {
-            let client: Box<dyn Client> = match args.protocol {
-                Protocol::Libpq => Box::new(
-                    LibpqClient::new(args.page_service_connstring.clone(), worker_id.timeline)
-                        .await
-                        .unwrap(),
+            let scheme = match Url::parse(&args.page_service_connstring) {
+                Ok(url) => url.scheme().to_lowercase().to_string(),
+                Err(url::ParseError::RelativeUrlWithoutBase) => "postgresql".to_string(),
+                Err(err) => panic!("invalid connstring: {err}"),
+            };
+            let client: Box<dyn Client> = match scheme.as_str() {
+                "postgresql" | "postgres" => {
+                    assert!(!args.compression, "libpq does not support compression");
+                    Box::new(
+                        LibpqClient::new(&args.page_service_connstring, worker_id.timeline)
+                            .await
+                            .unwrap(),
+                    )
+                }
+
+                "grpc" => Box::new(
+                    GrpcClient::new(
+                        &args.page_service_connstring,
+                        worker_id.timeline,
+                        args.compression,
+                    )
+                    .await
+                    .unwrap(),
                 ),
-                Protocol::Grpc => Box::new(
-                    GrpcClient::new(args.page_service_connstring.clone(), worker_id.timeline)
-                        .await
-                        .unwrap(),
-                ),
+                scheme => panic!("unsupported scheme {scheme}"),
             };
             run_worker(args, client, ss, cancel, rps_period, ranges, weights).await
         })
@@ -543,8 +556,8 @@ struct LibpqClient {
 }
 
 impl LibpqClient {
-    async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result<Self> {
-        let inner = pageserver_client::page_service::Client::new(connstring)
+    async fn new(connstring: &str, ttid: TenantTimelineId) -> anyhow::Result<Self> {
+        let inner = pageserver_client::page_service::Client::new(connstring.to_string())
             .await?
             .pagestream(ttid.tenant_id, ttid.timeline_id)
             .await?;
@@ -600,34 +613,36 @@ impl Client for LibpqClient {
     }
 }
 
-/// A gRPC client using the raw, no-frills gRPC client.
+/// A gRPC Pageserver client.
 struct GrpcClient {
-    req_tx: tokio::sync::mpsc::Sender<proto::GetPageRequest>,
-    resp_rx: tonic::Streaming<proto::GetPageResponse>,
+    req_tx: tokio::sync::mpsc::Sender<page_api::GetPageRequest>,
+    resp_rx: Pin<Box<dyn Stream<Item = Result<page_api::GetPageResponse, tonic::Status>> + Send>>,
 }
 
 impl GrpcClient {
-    async fn new(connstring: String, ttid: TenantTimelineId) -> anyhow::Result<Self> {
-        let mut client = pageserver_page_api::proto::PageServiceClient::connect(connstring).await?;
+    async fn new(
+        connstring: &str,
+        ttid: TenantTimelineId,
+        compression: bool,
+    ) -> anyhow::Result<Self> {
+        let mut client = page_api::Client::new(
+            connstring.to_string(),
+            ttid.tenant_id,
+            ttid.timeline_id,
+            ShardIndex::unsharded(),
+            None,
+            compression.then_some(tonic::codec::CompressionEncoding::Zstd),
+        )
+        .await?;
 
         // The channel has a buffer size of 1, since 0 is not allowed. It does not matter, since the
         // benchmark will control the queue depth (i.e. in-flight requests) anyway, and requests are
        // buffered by Tonic and the OS too.
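+        // In other words, the bounded channel is only a handoff point into tonic; it is not
+        // where the benchmark's queue depth is enforced.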
let (req_tx, req_rx) = tokio::sync::mpsc::channel(1); let req_stream = tokio_stream::wrappers::ReceiverStream::new(req_rx); - let mut req = tonic::Request::new(req_stream); - let metadata = req.metadata_mut(); - metadata.insert("neon-tenant-id", ttid.tenant_id.to_string().try_into()?); - metadata.insert("neon-timeline-id", ttid.timeline_id.to_string().try_into()?); - metadata.insert("neon-shard-id", "0000".try_into()?); + let resp_rx = Box::pin(client.get_pages(req_stream).await?); - let resp = client.get_pages(req).await?; - let resp_stream = resp.into_inner(); - - Ok(Self { - req_tx, - resp_rx: resp_stream, - }) + Ok(Self { req_tx, resp_rx }) } } @@ -641,27 +656,27 @@ impl Client for GrpcClient { rel: RelTag, blks: Vec, ) -> anyhow::Result<()> { - let req = proto::GetPageRequest { + let req = page_api::GetPageRequest { request_id: req_id, - request_class: proto::GetPageClass::Normal as i32, - read_lsn: Some(proto::ReadLsn { - request_lsn: req_lsn.0, - not_modified_since_lsn: mod_lsn.0, - }), - rel: Some(rel.into()), - block_number: blks, + request_class: page_api::GetPageClass::Normal, + read_lsn: page_api::ReadLsn { + request_lsn: req_lsn, + not_modified_since_lsn: Some(mod_lsn), + }, + rel, + block_numbers: blks, }; self.req_tx.send(req).await?; Ok(()) } async fn recv_get_page(&mut self) -> anyhow::Result<(u64, Vec)> { - let resp = self.resp_rx.message().await?.unwrap(); + let resp = self.resp_rx.next().await.unwrap().unwrap(); anyhow::ensure!( - resp.status_code == proto::GetPageStatusCode::Ok as i32, + resp.status_code == page_api::GetPageStatusCode::Ok, "unexpected status code: {}", - resp.status_code + resp.status_code, ); - Ok((resp.request_id, resp.page_image)) + Ok((resp.request_id, resp.page_images)) } } diff --git a/pageserver/src/basebackup.rs b/pageserver/src/basebackup.rs index 2a0548b811..36dada1e89 100644 --- a/pageserver/src/basebackup.rs +++ b/pageserver/src/basebackup.rs @@ -14,19 +14,19 @@ use std::fmt::Write as FmtWrite; use std::time::{Instant, SystemTime}; use anyhow::{Context, anyhow}; +use async_compression::tokio::write::GzipEncoder; use bytes::{BufMut, Bytes, BytesMut}; use fail::fail_point; use pageserver_api::key::{Key, rel_block_to_key}; use pageserver_api::reltag::{RelTag, SlruKind}; -use postgres_ffi::pg_constants::{ - DEFAULTTABLESPACE_OID, GLOBALTABLESPACE_OID, PG_HBA, PGDATA_SPECIAL_FILES, -}; -use postgres_ffi::relfile_utils::{INIT_FORKNUM, MAIN_FORKNUM}; +use postgres_ffi::pg_constants::{PG_HBA, PGDATA_SPECIAL_FILES}; use postgres_ffi::{ - BLCKSZ, PG_TLI, RELSEG_SIZE, WAL_SEGMENT_SIZE, XLogFileName, dispatch_pgversion, pg_constants, + BLCKSZ, PG_TLI, PgMajorVersion, RELSEG_SIZE, WAL_SEGMENT_SIZE, XLogFileName, + dispatch_pgversion, pg_constants, }; -use tokio::io; -use tokio::io::AsyncWrite; +use postgres_ffi_types::constants::{DEFAULTTABLESPACE_OID, GLOBALTABLESPACE_OID}; +use postgres_ffi_types::forknum::{INIT_FORKNUM, MAIN_FORKNUM}; +use tokio::io::{self, AsyncWrite, AsyncWriteExt as _}; use tokio_tar::{Builder, EntryType, Header}; use tracing::*; use utils::lsn::Lsn; @@ -97,6 +97,7 @@ impl From for tonic::Status { /// * When working without safekeepers. In this situation it is important to match the lsn /// we are taking basebackup on with the lsn that is used in pageserver's walreceiver /// to start the replication. 
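+///
+/// If `gzip_level` is `Some`, the tar stream is wrapped in a `GzipEncoder` of that quality and
+/// the encoder is shut down (flushed) before returning; with `None` the tarball is written
+/// uncompressed.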
+#[allow(clippy::too_many_arguments)] pub async fn send_basebackup_tarball<'a, W>( write: &'a mut W, timeline: &'a Timeline, @@ -104,6 +105,7 @@ pub async fn send_basebackup_tarball<'a, W>( prev_lsn: Option, full_backup: bool, replica: bool, + gzip_level: Option, ctx: &'a RequestContext, ) -> Result<(), BasebackupError> where @@ -122,7 +124,7 @@ where // prev_lsn value; that happens if the timeline was just branched from // an old LSN and it doesn't have any WAL of its own yet. We will set // prev_lsn to Lsn(0) if we cannot provide the correct value. - let (backup_prev, backup_lsn) = if let Some(req_lsn) = req_lsn { + let (backup_prev, lsn) = if let Some(req_lsn) = req_lsn { // Backup was requested at a particular LSN. The caller should've // already checked that it's a valid LSN. @@ -143,7 +145,7 @@ where }; // Consolidate the derived and the provided prev_lsn values - let prev_lsn = if let Some(provided_prev_lsn) = prev_lsn { + let prev_record_lsn = if let Some(provided_prev_lsn) = prev_lsn { if backup_prev != Lsn(0) && backup_prev != provided_prev_lsn { return Err(BasebackupError::Server(anyhow!( "backup_prev {backup_prev} != provided_prev_lsn {provided_prev_lsn}" @@ -155,30 +157,55 @@ where }; info!( - "taking basebackup lsn={}, prev_lsn={} (full_backup={}, replica={})", - backup_lsn, prev_lsn, full_backup, replica + "taking basebackup lsn={lsn}, prev_lsn={prev_record_lsn} \ + (full_backup={full_backup}, replica={replica}, gzip={gzip_level:?})", + ); + let span = info_span!("send_tarball", backup_lsn=%lsn); + + let io_concurrency = IoConcurrency::spawn_from_conf( + timeline.conf.get_vectored_concurrent_io, + timeline + .gate + .enter() + .map_err(|_| BasebackupError::Shutdown)?, ); - let basebackup = Basebackup { - ar: Builder::new_non_terminated(write), - timeline, - lsn: backup_lsn, - prev_record_lsn: prev_lsn, - full_backup, - replica, - ctx, - io_concurrency: IoConcurrency::spawn_from_conf( - timeline.conf.get_vectored_concurrent_io, - timeline - .gate - .enter() - .map_err(|_| BasebackupError::Shutdown)?, - ), - }; - basebackup + if let Some(gzip_level) = gzip_level { + let mut encoder = GzipEncoder::with_quality(write, gzip_level); + Basebackup { + ar: Builder::new_non_terminated(&mut encoder), + timeline, + lsn, + prev_record_lsn, + full_backup, + replica, + ctx, + io_concurrency, + } .send_tarball() - .instrument(info_span!("send_tarball", backup_lsn=%backup_lsn)) - .await + .instrument(span) + .await?; + encoder + .shutdown() + .await + .map_err(|err| BasebackupError::Client(err, "gzip"))?; + } else { + Basebackup { + ar: Builder::new_non_terminated(write), + timeline, + lsn, + prev_record_lsn, + full_backup, + replica, + ctx, + io_concurrency, + } + .send_tarball() + .instrument(span) + .await?; + } + + Ok(()) } /// This is short-living object only for the time of tarball creation, @@ -372,6 +399,7 @@ where .partition( self.timeline.get_shard_identity(), self.timeline.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64, + BLCKSZ as u64, ); let mut slru_builder = SlruSegmentsBuilder::new(&mut self.ar); @@ -619,10 +647,7 @@ where }; if spcnode == GLOBALTABLESPACE_OID { - let pg_version_str = match self.timeline.pg_version { - 14 | 15 => self.timeline.pg_version.to_string(), - ver => format!("{ver}\x0A"), - }; + let pg_version_str = self.timeline.pg_version.versionfile_string(); let header = new_tar_header("PG_VERSION", pg_version_str.len() as u64)?; self.ar .append(&header, pg_version_str.as_bytes()) @@ -669,7 +694,7 @@ where } // Append dir path for each database - let 
path = format!("base/{}", dbnode);
+        let path = format!("base/{dbnode}");
         let header = new_tar_header_dir(&path)?;
         self.ar
             .append(&header, io::empty())
@@ -677,19 +702,16 @@ where
             .map_err(|e| BasebackupError::Client(e, "add_dbdir,base"))?;
 
         if let Some(img) = relmap_img {
-            let dst_path = format!("base/{}/PG_VERSION", dbnode);
+            let dst_path = format!("base/{dbnode}/PG_VERSION");
 
-            let pg_version_str = match self.timeline.pg_version {
-                14 | 15 => self.timeline.pg_version.to_string(),
-                ver => format!("{ver}\x0A"),
-            };
+            let pg_version_str = self.timeline.pg_version.versionfile_string();
             let header = new_tar_header(&dst_path, pg_version_str.len() as u64)?;
             self.ar
                 .append(&header, pg_version_str.as_bytes())
                 .await
                 .map_err(|e| BasebackupError::Client(e, "add_dbdir,base/PG_VERSION"))?;
 
-            let relmap_path = format!("base/{}/pg_filenode.map", dbnode);
+            let relmap_path = format!("base/{dbnode}/pg_filenode.map");
             let header = new_tar_header(&relmap_path, img.len() as u64)?;
             self.ar
                 .append(&header, &img[..])
@@ -713,10 +735,10 @@ where
             buf.extend_from_slice(&img[..]);
             let crc = crc32c::crc32c(&img[..]);
             buf.put_u32_le(crc);
-            let path = if self.timeline.pg_version < 17 {
-                format!("pg_twophase/{:>08X}", xid)
+            let path = if self.timeline.pg_version < PgMajorVersion::PG17 {
+                format!("pg_twophase/{xid:>08X}")
             } else {
-                format!("pg_twophase/{:>016X}", xid)
+                format!("pg_twophase/{xid:>016X}")
             };
             let header = new_tar_header(&path, buf.len() as u64)?;
             self.ar
@@ -768,7 +790,7 @@ where
         //send wal segment
         let segno = self.lsn.segment_number(WAL_SEGMENT_SIZE);
         let wal_file_name = XLogFileName(PG_TLI, segno, WAL_SEGMENT_SIZE);
-        let wal_file_path = format!("pg_wal/{}", wal_file_name);
+        let wal_file_path = format!("pg_wal/{wal_file_name}");
         let header = new_tar_header(&wal_file_path, WAL_SEGMENT_SIZE as u64)?;
 
         let wal_seg = postgres_ffi::generate_wal_segment(
diff --git a/pageserver/src/basebackup_cache.rs b/pageserver/src/basebackup_cache.rs
index 3a8ec555f7..4966fee2d7 100644
--- a/pageserver/src/basebackup_cache.rs
+++ b/pageserver/src/basebackup_cache.rs
@@ -1,12 +1,12 @@
 use std::{collections::HashMap, sync::Arc};
 
-use async_compression::tokio::write::GzipEncoder;
+use anyhow::Context;
 use camino::{Utf8Path, Utf8PathBuf};
 use metrics::core::{AtomicU64, GenericCounter};
 use pageserver_api::{config::BasebackupCacheConfig, models::TenantState};
 use tokio::{
     io::{AsyncWriteExt, BufWriter},
-    sync::mpsc::{UnboundedReceiver, UnboundedSender},
+    sync::mpsc::{Receiver, Sender, error::TrySendError},
 };
 use tokio_util::sync::CancellationToken;
 use utils::{
@@ -18,7 +18,10 @@ use utils::{
 use crate::{
     basebackup::send_basebackup_tarball,
     context::{DownloadBehavior, RequestContext},
-    metrics::{BASEBACKUP_CACHE_ENTRIES, BASEBACKUP_CACHE_PREPARE, BASEBACKUP_CACHE_READ},
+    metrics::{
+        BASEBACKUP_CACHE_ENTRIES, BASEBACKUP_CACHE_PREPARE, BASEBACKUP_CACHE_PREPARE_QUEUE_SIZE,
+        BASEBACKUP_CACHE_READ, BASEBACKUP_CACHE_SIZE,
+    },
     task_mgr::TaskKind,
     tenant::{
         Timeline,
@@ -32,11 +35,16 @@ pub struct BasebackupPrepareRequest {
     pub lsn: Lsn,
 }
 
-pub type BasebackupPrepareSender = UnboundedSender<BasebackupPrepareRequest>;
-pub type BasebackupPrepareReceiver = UnboundedReceiver<BasebackupPrepareRequest>;
+pub type BasebackupPrepareSender = Sender<BasebackupPrepareRequest>;
+pub type BasebackupPrepareReceiver = Receiver<BasebackupPrepareRequest>;
 
-type BasebackupRemoveEntrySender = UnboundedSender<Utf8PathBuf>;
-type BasebackupRemoveEntryReceiver = UnboundedReceiver<Utf8PathBuf>;
+#[derive(Clone)]
+struct CacheEntry {
+    /// LSN at which the basebackup was taken.
+    lsn: Lsn,
+    /// Size of the basebackup archive in bytes.
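+    /// Used to keep the cache's entry/size accounting (and the size limits) in sync
+    /// with what is actually stored on disk.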
+    size_bytes: u64,
+}
 
 /// BasebackupCache stores cached basebackup archives for timelines on local disk.
 ///
@@ -52,68 +60,118 @@ type BasebackupRemoveEntryReceiver = UnboundedReceiver<Utf8PathBuf>;
 /// and ~1 RPS for get requests.
 pub struct BasebackupCache {
     data_dir: Utf8PathBuf,
-    config: BasebackupCacheConfig,
-    tenant_manager: Arc<TenantManager>,
-    remove_entry_sender: BasebackupRemoveEntrySender,
+    config: Option<BasebackupCacheConfig>,
 
-    entries: std::sync::Mutex<HashMap<TenantTimelineId, Lsn>>,
+    entries: std::sync::Mutex<HashMap<TenantTimelineId, CacheEntry>>,
 
-    cancel: CancellationToken,
+    prepare_sender: BasebackupPrepareSender,
 
     read_hit_count: GenericCounter<AtomicU64>,
     read_miss_count: GenericCounter<AtomicU64>,
     read_err_count: GenericCounter<AtomicU64>,
 
-    prepare_ok_count: GenericCounter<AtomicU64>,
     prepare_skip_count: GenericCounter<AtomicU64>,
-    prepare_err_count: GenericCounter<AtomicU64>,
 }
 
 impl BasebackupCache {
-    /// Creates a BasebackupCache and spawns the background task.
-    /// The initialization of the cache is performed in the background and does not
-    /// block the caller. The cache will return `None` for any get requests until
-    /// initialization is complete.
-    pub fn spawn(
-        runtime_handle: &tokio::runtime::Handle,
+    /// Create a new BasebackupCache instance.
+    /// Also returns a BasebackupPrepareReceiver which is needed to start
+    /// the background task.
+    /// The cache is initialized from the data_dir in the background task.
+    /// The cache will return `None` for any get requests until the initialization is complete.
+    /// The background task is spawned separately using [`Self::spawn_background_task`]
+    /// to avoid a circular dependency between the cache and the tenant manager.
+    pub fn new(
         data_dir: Utf8PathBuf,
         config: Option<BasebackupCacheConfig>,
-        prepare_receiver: BasebackupPrepareReceiver,
-        tenant_manager: Arc<TenantManager>,
-        cancel: CancellationToken,
-    ) -> Arc<Self> {
-        let (remove_entry_sender, remove_entry_receiver) = tokio::sync::mpsc::unbounded_channel();
+    ) -> (Arc<Self>, BasebackupPrepareReceiver) {
+        let chan_size = config.as_ref().map(|c| c.max_size_entries).unwrap_or(1);
 
-        let enabled = config.is_some();
+        let (prepare_sender, prepare_receiver) = tokio::sync::mpsc::channel(chan_size);
 
         let cache = Arc::new(BasebackupCache {
             data_dir,
-            config: config.unwrap_or_default(),
-            tenant_manager,
-            remove_entry_sender,
-
+            config,
             entries: std::sync::Mutex::new(HashMap::new()),
-
-            cancel,
+            prepare_sender,
 
             read_hit_count: BASEBACKUP_CACHE_READ.with_label_values(&["hit"]),
             read_miss_count: BASEBACKUP_CACHE_READ.with_label_values(&["miss"]),
             read_err_count: BASEBACKUP_CACHE_READ.with_label_values(&["error"]),
 
-            prepare_ok_count: BASEBACKUP_CACHE_PREPARE.with_label_values(&["ok"]),
             prepare_skip_count: BASEBACKUP_CACHE_PREPARE.with_label_values(&["skip"]),
-            prepare_err_count: BASEBACKUP_CACHE_PREPARE.with_label_values(&["error"]),
         });
 
-        if enabled {
-            runtime_handle.spawn(
-                cache
-                    .clone()
-                    .background(prepare_receiver, remove_entry_receiver),
-            );
-        }
+        (cache, prepare_receiver)
+    }
 
-        cache
-    }
+    /// Spawns the background task.
+    /// The background task initializes the cache from the disk,
+    /// processes prepare requests, and cleans up outdated cache entries.
+    /// Noop if the cache is disabled (config is None).
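+    ///
+    /// A rough wiring sketch (names approximate; see the pageserver startup code for the
+    /// real call sites):
+    /// ```ignore
+    /// let (cache, rx) = BasebackupCache::new(data_dir, Some(config));
+    /// cache.clone().spawn_background_task(handle, rx, tenant_manager, cancel);
+    /// cache.send_prepare(tenant_shard_id, timeline_id, lsn);
+    /// ```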
+ pub fn spawn_background_task( + self: Arc, + runtime_handle: &tokio::runtime::Handle, + prepare_receiver: BasebackupPrepareReceiver, + tenant_manager: Arc, + cancel: CancellationToken, + ) { + if let Some(config) = self.config.clone() { + let background = BackgroundTask { + c: self, + + config, + tenant_manager, + cancel, + + entry_count: 0, + total_size_bytes: 0, + + prepare_ok_count: BASEBACKUP_CACHE_PREPARE.with_label_values(&["ok"]), + prepare_skip_count: BASEBACKUP_CACHE_PREPARE.with_label_values(&["skip"]), + prepare_err_count: BASEBACKUP_CACHE_PREPARE.with_label_values(&["error"]), + }; + runtime_handle.spawn(background.run(prepare_receiver)); + } + } + + /// Send a basebackup prepare request to the background task. + /// The basebackup will be prepared asynchronously, it does not block the caller. + /// The request will be skipped if any cache limits are exceeded. + pub fn send_prepare(&self, tenant_shard_id: TenantShardId, timeline_id: TimelineId, lsn: Lsn) { + let req = BasebackupPrepareRequest { + tenant_shard_id, + timeline_id, + lsn, + }; + + BASEBACKUP_CACHE_PREPARE_QUEUE_SIZE.inc(); + let res = self.prepare_sender.try_send(req); + + if let Err(e) = res { + BASEBACKUP_CACHE_PREPARE_QUEUE_SIZE.dec(); + self.prepare_skip_count.inc(); + match e { + TrySendError::Full(_) => { + // Basebackup prepares are pretty rare, normally we should not hit this. + tracing::info!( + tenant_id = %tenant_shard_id.tenant_id, + %timeline_id, + %lsn, + "Basebackup prepare channel is full, skipping the request" + ); + } + TrySendError::Closed(_) => { + // Normal during shutdown, not critical. + tracing::info!( + tenant_id = %tenant_shard_id.tenant_id, + %timeline_id, + %lsn, + "Basebackup prepare channel is closed, skipping the request" + ); + } + } + } } /// Gets a basebackup entry from the cache. @@ -126,9 +184,13 @@ impl BasebackupCache { timeline_id: TimelineId, lsn: Lsn, ) -> Option { + if !self.is_enabled() { + return None; + } + // Fast path. Check if the entry exists using the in-memory state. let tti = TenantTimelineId::new(tenant_id, timeline_id); - if self.entries.lock().unwrap().get(&tti) != Some(&lsn) { + if self.entries.lock().unwrap().get(&tti).map(|e| e.lsn) != Some(lsn) { self.read_miss_count.inc(); return None; } @@ -153,6 +215,10 @@ impl BasebackupCache { } } + pub fn is_enabled(&self) -> bool { + self.config.is_some() + } + // Private methods. fn entry_filename(tenant_id: TenantId, timeline_id: TimelineId, lsn: Lsn) -> String { @@ -166,6 +232,42 @@ impl BasebackupCache { self.data_dir .join(Self::entry_filename(tenant_id, timeline_id, lsn)) } +} + +/// The background task that does the job to prepare basebackups +/// and manage the cache entries on disk. +/// It is a separate struct from BasebackupCache to allow holding +/// a mutable reference to this state without a mutex lock, +/// while BasebackupCache is referenced by the clients. +struct BackgroundTask { + c: Arc, + + config: BasebackupCacheConfig, + tenant_manager: Arc, + cancel: CancellationToken, + + /// Number of the entries in the cache. + /// This counter is used for metrics and applying cache limits. + /// It generally should be equal to c.entries.len(), but it's calculated + /// pessimistically for abnormal situations: if we encountered some errors + /// during removing the entry from disk, we won't decrement this counter to + /// make sure that we don't exceed the limit with "trashed" files on the disk. + /// It will also count files in the data_dir that are not valid cache entries. 
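+    /// In other words, both counters may over-count, but should never under-count,
+    /// what is actually on disk.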
+ entry_count: usize, + /// Total size of all the entries on the disk. + /// This counter is used for metrics and applying cache limits. + /// Similar to entry_count, it is calculated pessimistically for abnormal situations. + total_size_bytes: u64, + + prepare_ok_count: GenericCounter, + prepare_skip_count: GenericCounter, + prepare_err_count: GenericCounter, +} + +impl BackgroundTask { + fn tmp_dir(&self) -> Utf8PathBuf { + self.c.data_dir.join("tmp") + } fn entry_tmp_path( &self, @@ -173,9 +275,8 @@ impl BasebackupCache { timeline_id: TimelineId, lsn: Lsn, ) -> Utf8PathBuf { - self.data_dir - .join("tmp") - .join(Self::entry_filename(tenant_id, timeline_id, lsn)) + self.tmp_dir() + .join(BasebackupCache::entry_filename(tenant_id, timeline_id, lsn)) } fn parse_entry_filename(filename: &str) -> Option<(TenantId, TimelineId, Lsn)> { @@ -194,18 +295,21 @@ impl BasebackupCache { Some((tenant_id, timeline_id, lsn)) } - async fn cleanup(&self) -> anyhow::Result<()> { - // Cleanup tmp directory. - let tmp_dir = self.data_dir.join("tmp"); - let mut tmp_dir = tokio::fs::read_dir(&tmp_dir).await?; - while let Some(dir_entry) = tmp_dir.next_entry().await? { - if let Err(e) = tokio::fs::remove_file(dir_entry.path()).await { - tracing::warn!("Failed to remove basebackup cache tmp file: {:#}", e); - } + // Recreate the tmp directory to clear all files in it. + async fn clean_tmp_dir(&self) -> anyhow::Result<()> { + let tmp_dir = self.tmp_dir(); + if tmp_dir.exists() { + tokio::fs::remove_dir_all(&tmp_dir).await?; } + tokio::fs::create_dir_all(&tmp_dir).await?; + Ok(()) + } - // Remove outdated entries. - let entries_old = self.entries.lock().unwrap().clone(); + async fn cleanup(&mut self) -> anyhow::Result<()> { + self.clean_tmp_dir().await?; + + // Leave only up-to-date entries. + let entries_old = self.c.entries.lock().unwrap().clone(); let mut entries_new = HashMap::new(); for (tenant_shard_id, tenant_slot) in self.tenant_manager.list() { if !tenant_shard_id.is_shard_zero() { @@ -218,43 +322,42 @@ impl BasebackupCache { for timeline in tenant.list_timelines() { let tti = TenantTimelineId::new(tenant_id, timeline.timeline_id); - if let Some(&entry_lsn) = entries_old.get(&tti) { - if timeline.get_last_record_lsn() <= entry_lsn { - entries_new.insert(tti, entry_lsn); + if let Some(entry) = entries_old.get(&tti) { + if timeline.get_last_record_lsn() <= entry.lsn { + entries_new.insert(tti, entry.clone()); } } } } - for (&tti, &lsn) in entries_old.iter() { + // Try to remove all entries that are not up-to-date. + for (&tti, entry) in entries_old.iter() { if !entries_new.contains_key(&tti) { - self.remove_entry_sender - .send(self.entry_path(tti.tenant_id, tti.timeline_id, lsn)) - .unwrap(); + self.try_remove_entry(tti.tenant_id, tti.timeline_id, entry) + .await; } } - BASEBACKUP_CACHE_ENTRIES.set(entries_new.len() as i64); - *self.entries.lock().unwrap() = entries_new; + // Note: BackgroundTask is the only writer for self.c.entries, + // so it couldn't have been modified concurrently. + *self.c.entries.lock().unwrap() = entries_new; Ok(()) } - async fn on_startup(&self) -> anyhow::Result<()> { - // Create data_dir and tmp directory if they do not exist. - tokio::fs::create_dir_all(&self.data_dir.join("tmp")) + async fn on_startup(&mut self) -> anyhow::Result<()> { + // Create data_dir if it does not exist. 
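+        // (The tmp subdirectory is then recreated from scratch by clean_tmp_dir() below.)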
+ tokio::fs::create_dir_all(&self.c.data_dir) .await - .map_err(|e| { - anyhow::anyhow!( - "Failed to create basebackup cache data_dir {:?}: {:?}", - self.data_dir, - e - ) - })?; + .context("Failed to create basebackup cache data directory")?; + + self.clean_tmp_dir() + .await + .context("Failed to clean tmp directory")?; // Read existing entries from the data_dir and add them to in-memory state. - let mut entries = HashMap::new(); - let mut dir = tokio::fs::read_dir(&self.data_dir).await?; + let mut entries = HashMap::::new(); + let mut dir = tokio::fs::read_dir(&self.c.data_dir).await?; while let Some(dir_entry) = dir.next_entry().await? { let filename = dir_entry.file_name(); @@ -263,33 +366,43 @@ impl BasebackupCache { continue; } + let size_bytes = dir_entry + .metadata() + .await + .map_err(|e| { + anyhow::anyhow!("Failed to read metadata for file {:?}: {:?}", filename, e) + })? + .len(); + + self.entry_count += 1; + BASEBACKUP_CACHE_ENTRIES.set(self.entry_count as u64); + + self.total_size_bytes += size_bytes; + BASEBACKUP_CACHE_SIZE.set(self.total_size_bytes); + let parsed = Self::parse_entry_filename(filename.to_string_lossy().as_ref()); let Some((tenant_id, timeline_id, lsn)) = parsed else { tracing::warn!("Invalid basebackup cache file name: {:?}", filename); continue; }; + let cur_entry = CacheEntry { lsn, size_bytes }; + let tti = TenantTimelineId::new(tenant_id, timeline_id); use std::collections::hash_map::Entry::*; match entries.entry(tti) { Occupied(mut entry) => { - let entry_lsn = *entry.get(); + let found_entry = entry.get(); // Leave only the latest entry, remove the old one. - if lsn < entry_lsn { - self.remove_entry_sender.send(self.entry_path( - tenant_id, - timeline_id, - lsn, - ))?; - } else if lsn > entry_lsn { - self.remove_entry_sender.send(self.entry_path( - tenant_id, - timeline_id, - entry_lsn, - ))?; - entry.insert(lsn); + if cur_entry.lsn < found_entry.lsn { + self.try_remove_entry(tenant_id, timeline_id, &cur_entry) + .await; + } else if cur_entry.lsn > found_entry.lsn { + self.try_remove_entry(tenant_id, timeline_id, found_entry) + .await; + entry.insert(cur_entry); } else { // Two different filenames parsed to the same timline_id and LSN. // Should never happen. @@ -300,22 +413,17 @@ impl BasebackupCache { } } Vacant(entry) => { - entry.insert(lsn); + entry.insert(cur_entry); } } } - BASEBACKUP_CACHE_ENTRIES.set(entries.len() as i64); - *self.entries.lock().unwrap() = entries; + *self.c.entries.lock().unwrap() = entries; Ok(()) } - async fn background( - self: Arc, - mut prepare_receiver: BasebackupPrepareReceiver, - mut remove_entry_receiver: BasebackupRemoveEntryReceiver, - ) { + async fn run(mut self, mut prepare_receiver: BasebackupPrepareReceiver) { // Panic in the background is a safe fallback. // It will drop receivers and the cache will be effectively disabled. self.on_startup() @@ -328,6 +436,7 @@ impl BasebackupCache { loop { tokio::select! 
{ Some(req) = prepare_receiver.recv() => { + BASEBACKUP_CACHE_PREPARE_QUEUE_SIZE.dec(); if let Err(err) = self.prepare_basebackup( req.tenant_shard_id, req.timeline_id, @@ -338,11 +447,6 @@ impl BasebackupCache { continue; } } - Some(req) = remove_entry_receiver.recv() => { - if let Err(e) = tokio::fs::remove_file(req).await { - tracing::warn!("Failed to remove basebackup cache file: {:#}", e); - } - } _ = cleanup_ticker.tick() => { self.cleanup().await.unwrap_or_else(|e| { tracing::warn!("Failed to clean up basebackup cache: {:#}", e); @@ -356,6 +460,67 @@ impl BasebackupCache { } } + /// Try to remove an entry from disk. + /// The caller is responsible for removing the entry from the in-memory state. + /// Updates size counters and corresponding metrics. + /// Ignores the filesystem errors as not-so-important, but the size counters + /// are not decremented in this case, so the file will continue to be counted + /// towards the size limits. + async fn try_remove_entry( + &mut self, + tenant_id: TenantId, + timeline_id: TimelineId, + entry: &CacheEntry, + ) { + let entry_path = self.c.entry_path(tenant_id, timeline_id, entry.lsn); + + match tokio::fs::remove_file(&entry_path).await { + Ok(_) => {} + Err(e) if e.kind() == std::io::ErrorKind::NotFound => {} + Err(e) => { + tracing::warn!( + "Failed to remove basebackup cache file for tenant {} timeline {} LSN {}: {:#}", + tenant_id, + timeline_id, + entry.lsn, + e + ); + return; + } + } + + self.entry_count -= 1; + BASEBACKUP_CACHE_ENTRIES.set(self.entry_count as u64); + + self.total_size_bytes -= entry.size_bytes; + BASEBACKUP_CACHE_SIZE.set(self.total_size_bytes); + } + + /// Insert the cache entry into in-memory state and update the size counters. + /// Assumes that the file for the entry already exists on disk. + /// If the entry already exists with previous LSN, it will be removed. + async fn upsert_entry( + &mut self, + tenant_id: TenantId, + timeline_id: TimelineId, + entry: CacheEntry, + ) { + let tti = TenantTimelineId::new(tenant_id, timeline_id); + + self.entry_count += 1; + BASEBACKUP_CACHE_ENTRIES.set(self.entry_count as u64); + + self.total_size_bytes += entry.size_bytes; + BASEBACKUP_CACHE_SIZE.set(self.total_size_bytes); + + let old_entry = self.c.entries.lock().unwrap().insert(tti, entry); + + if let Some(old_entry) = old_entry { + self.try_remove_entry(tenant_id, timeline_id, &old_entry) + .await; + } + } + /// Prepare a basebackup for the given timeline. /// /// If the basebackup already exists with a higher LSN or the timeline already @@ -364,7 +529,7 @@ impl BasebackupCache { /// The basebackup is prepared in a temporary directory and then moved to the final /// location to make the operation atomic. async fn prepare_basebackup( - &self, + &mut self, tenant_shard_id: TenantShardId, timeline_id: TimelineId, req_lsn: Lsn, @@ -378,30 +543,44 @@ impl BasebackupCache { let tti = TenantTimelineId::new(tenant_shard_id.tenant_id, timeline_id); + // TODO(diko): I don't think we will hit the limit, + // but if we do, it makes sense to try to evict oldest entries. 
here + if self.entry_count >= self.config.max_size_entries { + tracing::info!( + %tenant_shard_id, + %timeline_id, + %req_lsn, + "Basebackup cache is full (max_size_entries), skipping basebackup", + ); + self.prepare_skip_count.inc(); + return Ok(()); + } + + if self.total_size_bytes >= self.config.max_total_size_bytes { + tracing::info!( + %tenant_shard_id, + %timeline_id, + %req_lsn, + "Basebackup cache is full (max_total_size_bytes), skipping basebackup", + ); + self.prepare_skip_count.inc(); + return Ok(()); + } + { - let entries = self.entries.lock().unwrap(); - if let Some(&entry_lsn) = entries.get(&tti) { - if entry_lsn >= req_lsn { + let entries = self.c.entries.lock().unwrap(); + if let Some(entry) = entries.get(&tti) { + if entry.lsn >= req_lsn { tracing::info!( %timeline_id, %req_lsn, - %entry_lsn, + %entry.lsn, "Basebackup entry already exists for timeline with higher LSN, skipping basebackup", ); self.prepare_skip_count.inc(); return Ok(()); } } - - if entries.len() as i64 >= self.config.max_size_entries { - tracing::info!( - %timeline_id, - %req_lsn, - "Basebackup cache is full, skipping basebackup", - ); - self.prepare_skip_count.inc(); - return Ok(()); - } } let tenant = self @@ -437,56 +616,54 @@ impl BasebackupCache { .prepare_basebackup_tmp(&entry_tmp_path, &timeline, req_lsn) .await; - if let Err(err) = res { - tracing::info!("Failed to prepare basebackup tmp file: {:#}", err); - // Try to clean up tmp file. If we fail, the background clean up task will take care of it. - match tokio::fs::remove_file(&entry_tmp_path).await { - Ok(_) => {} - Err(e) if e.kind() == std::io::ErrorKind::NotFound => {} - Err(e) => { - tracing::info!("Failed to remove basebackup tmp file: {:?}", e); + let entry = match res { + Ok(entry) => entry, + Err(err) => { + tracing::info!("Failed to prepare basebackup tmp file: {:#}", err); + // Try to clean up tmp file. If we fail, the background clean up task will take care of it. + match tokio::fs::remove_file(&entry_tmp_path).await { + Ok(_) => {} + Err(e) if e.kind() == std::io::ErrorKind::NotFound => {} + Err(e) => { + tracing::info!("Failed to remove basebackup tmp file: {:?}", e); + } } + return Err(err); } - return Err(err); - } + }; // Move the tmp file to the final location atomically. - let entry_path = self.entry_path(tenant_shard_id.tenant_id, timeline_id, req_lsn); + // The tmp file is fsynced, so it's guaranteed that we will not have a partial file + // in the main directory. + // It's not necessary to fsync the inode after renaming, because the worst case is that + // the rename operation will be rolled back on the disk failure, the entry will disappear + // from the main directory, and the entry access will cause a cache miss. + let entry_path = self + .c + .entry_path(tenant_shard_id.tenant_id, timeline_id, req_lsn); tokio::fs::rename(&entry_tmp_path, &entry_path).await?; - let mut entries = self.entries.lock().unwrap(); - if let Some(old_lsn) = entries.insert(tti, req_lsn) { - // Remove the old entry if it exists. - self.remove_entry_sender - .send(self.entry_path(tenant_shard_id.tenant_id, timeline_id, old_lsn)) - .unwrap(); - } - BASEBACKUP_CACHE_ENTRIES.set(entries.len() as i64); + self.upsert_entry(tenant_shard_id.tenant_id, timeline_id, entry) + .await; self.prepare_ok_count.inc(); Ok(()) } /// Prepares a basebackup in a temporary file. + /// Guarantees that the tmp file is fsynced before returning. 
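+    /// On success, returns the CacheEntry (request LSN plus on-disk size in bytes)
+    /// describing the prepared archive.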
async fn prepare_basebackup_tmp( &self, - emptry_tmp_path: &Utf8Path, + entry_tmp_path: &Utf8Path, timeline: &Arc, req_lsn: Lsn, - ) -> anyhow::Result<()> { + ) -> anyhow::Result { let ctx = RequestContext::new(TaskKind::BasebackupCache, DownloadBehavior::Download); let ctx = ctx.with_scope_timeline(timeline); - let file = tokio::fs::File::create(emptry_tmp_path).await?; + let file = tokio::fs::File::create(entry_tmp_path).await?; let mut writer = BufWriter::new(file); - let mut encoder = GzipEncoder::with_quality( - &mut writer, - // Level::Best because compression is not on the hot path of basebackup requests. - // The decompression is almost not affected by the compression level. - async_compression::Level::Best, - ); - // We may receive a request before the WAL record is applied to the timeline. // Wait for the requested LSN to be applied. timeline @@ -499,20 +676,28 @@ impl BasebackupCache { .await?; send_basebackup_tarball( - &mut encoder, + &mut writer, timeline, Some(req_lsn), None, false, false, + // Level::Best because compression is not on the hot path of basebackup requests. + // The decompression is almost not affected by the compression level. + Some(async_compression::Level::Best), &ctx, ) .await?; - encoder.shutdown().await?; writer.flush().await?; writer.into_inner().sync_all().await?; - Ok(()) + // TODO(diko): we can count it via Writer wrapper instead of a syscall. + let size_bytes = tokio::fs::metadata(entry_tmp_path).await?.len(); + + Ok(CacheEntry { + lsn: req_lsn, + size_bytes, + }) } } diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 45b90a5aa3..1bb99c4605 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -23,6 +23,7 @@ use pageserver::deletion_queue::DeletionQueue; use pageserver::disk_usage_eviction_task::{self, launch_disk_usage_global_eviction_task}; use pageserver::feature_resolver::FeatureResolver; use pageserver::metrics::{STARTUP_DURATION, STARTUP_IS_LOADING}; +use pageserver::page_service::GrpcPageServiceHandler; use pageserver::task_mgr::{ BACKGROUND_RUNTIME, COMPUTE_REQUEST_RUNTIME, MGMT_REQUEST_RUNTIME, WALRECEIVER_RUNTIME, }; @@ -581,11 +582,14 @@ fn start_pageserver( pageserver::l0_flush::L0FlushGlobalState::new(conf.l0_flush.clone()); // Scan the local 'tenants/' directory and start loading the tenants - let (basebackup_prepare_sender, basebackup_prepare_receiver) = - tokio::sync::mpsc::unbounded_channel(); + let (basebackup_cache, basebackup_prepare_receiver) = BasebackupCache::new( + conf.basebackup_cache_dir(), + conf.basebackup_cache_config.clone(), + ); let deletion_queue_client = deletion_queue.new_client(); let background_purges = mgr::BackgroundPurges::default(); - let tenant_manager = BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr( + + let tenant_manager = mgr::init( conf, background_purges.clone(), TenantSharedResources { @@ -593,18 +597,16 @@ fn start_pageserver( remote_storage: remote_storage.clone(), deletion_queue_client, l0_flush_global_state, - basebackup_prepare_sender, - feature_resolver, + basebackup_cache: Arc::clone(&basebackup_cache), + feature_resolver: feature_resolver.clone(), }, - order, shutdown_pageserver.clone(), - ))?; + ); let tenant_manager = Arc::new(tenant_manager); + BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(tenant_manager.clone(), order))?; - let basebackup_cache = BasebackupCache::spawn( + basebackup_cache.spawn_background_task( BACKGROUND_RUNTIME.handle(), - conf.basebackup_cache_dir(), - conf.basebackup_cache_config.clone(), 
basebackup_prepare_receiver, Arc::clone(&tenant_manager), shutdown_pageserver.child_token(), @@ -726,6 +728,7 @@ fn start_pageserver( disk_usage_eviction_state, deletion_queue.new_client(), secondary_controller, + feature_resolver, ) .context("Failed to initialize router state")?, ); @@ -816,7 +819,6 @@ fn start_pageserver( } else { None }, - basebackup_cache, ); // Spawn a Pageserver gRPC server task. It will spawn separate tasks for @@ -827,7 +829,7 @@ fn start_pageserver( // necessary? let mut page_service_grpc = None; if let Some(grpc_listener) = grpc_listener { - page_service_grpc = Some(page_service::spawn_grpc( + page_service_grpc = Some(GrpcPageServiceHandler::spawn( tenant_manager.clone(), grpc_auth, otel_guard.as_ref().map(|g| g.dispatch.clone()), diff --git a/pageserver/src/bin/test_helper_slow_client_reads.rs b/pageserver/src/bin/test_helper_slow_client_reads.rs index 0215dd06fb..be8e081945 100644 --- a/pageserver/src/bin/test_helper_slow_client_reads.rs +++ b/pageserver/src/bin/test_helper_slow_client_reads.rs @@ -2,7 +2,9 @@ use std::io::{Read, Write, stdin, stdout}; use std::time::Duration; use clap::Parser; -use pageserver_api::models::{PagestreamRequest, PagestreamTestRequest}; +use pageserver_api::pagestream_api::{ + PagestreamFeMessage, PagestreamRequest, PagestreamTestRequest, +}; use utils::id::{TenantId, TimelineId}; use utils::lsn::Lsn; @@ -28,17 +30,15 @@ async fn main() -> anyhow::Result<()> { let mut msg = 0; loop { msg += 1; - let fut = sender.send(pageserver_api::models::PagestreamFeMessage::Test( - PagestreamTestRequest { - hdr: PagestreamRequest { - reqid: 0, - request_lsn: Lsn(23), - not_modified_since: Lsn(23), - }, - batch_key: 42, - message: format!("message {}", msg), + let fut = sender.send(PagestreamFeMessage::Test(PagestreamTestRequest { + hdr: PagestreamRequest { + reqid: 0, + request_lsn: Lsn(23), + not_modified_since: Lsn(23), }, - )); + batch_key: 42, + message: format!("message {}", msg), + })); let Ok(res) = tokio::time::timeout(Duration::from_secs(10), fut).await else { eprintln!("pipe seems full"); break; diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index 3492a8d966..5b51a9617b 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -11,7 +11,7 @@ use std::num::NonZeroUsize; use std::sync::Arc; use std::time::Duration; -use anyhow::{Context, bail, ensure}; +use anyhow::{Context, ensure}; use camino::{Utf8Path, Utf8PathBuf}; use once_cell::sync::OnceCell; use pageserver_api::config::{ @@ -22,6 +22,7 @@ use pageserver_api::models::ImageCompressionAlgorithm; use pageserver_api::shard::TenantShardId; use pem::Pem; use postgres_backend::AuthType; +use postgres_ffi::PgMajorVersion; use remote_storage::{RemotePath, RemoteStorageConfig}; use reqwest::Url; use storage_broker::Uri; @@ -338,20 +339,16 @@ impl PageServerConf { // // Postgres distribution paths // - pub fn pg_distrib_dir(&self, pg_version: u32) -> anyhow::Result { + pub fn pg_distrib_dir(&self, pg_version: PgMajorVersion) -> anyhow::Result { let path = self.pg_distrib_dir.clone(); - #[allow(clippy::manual_range_patterns)] - match pg_version { - 14 | 15 | 16 | 17 => Ok(path.join(format!("v{pg_version}"))), - _ => bail!("Unsupported postgres version: {}", pg_version), - } + Ok(path.join(pg_version.v_str())) } - pub fn pg_bin_dir(&self, pg_version: u32) -> anyhow::Result { + pub fn pg_bin_dir(&self, pg_version: PgMajorVersion) -> anyhow::Result { Ok(self.pg_distrib_dir(pg_version)?.join("bin")) } - pub fn pg_lib_dir(&self, pg_version: u32) -> anyhow::Result 
{ + pub fn pg_lib_dir(&self, pg_version: PgMajorVersion) -> anyhow::Result { Ok(self.pg_distrib_dir(pg_version)?.join("lib")) } @@ -765,4 +762,23 @@ mod tests { let result = PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir); assert_eq!(result.is_ok(), is_valid); } + + #[test] + fn test_config_posthog_config_is_valid() { + let input = r#" + control_plane_api = "http://localhost:6666" + + [posthog_config] + server_api_key = "phs_AAA" + client_api_key = "phc_BBB" + project_id = "000" + private_api_url = "https://us.posthog.com" + public_api_url = "https://us.i.posthog.com" + "#; + let config_toml = toml_edit::de::from_str::(input) + .expect("posthogconfig is valid"); + let workdir = Utf8PathBuf::from("/nonexistent"); + PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir) + .expect("parse_and_validate"); + } } diff --git a/pageserver/src/controller_upcall_client.rs b/pageserver/src/controller_upcall_client.rs index dc38ea616c..f1f9aaf43c 100644 --- a/pageserver/src/controller_upcall_client.rs +++ b/pageserver/src/controller_upcall_client.rs @@ -159,14 +159,7 @@ impl StorageControllerUpcallApi for StorageControllerUpcallClient { Ok(m) => { // Since we run one time at startup, be generous in our logging and // dump all metadata. - tracing::info!( - "Loaded node metadata: postgres {}:{}, http {}:{}, other fields: {:?}", - m.postgres_host, - m.postgres_port, - m.http_host, - m.http_port, - m.other - ); + tracing::info!("Loaded node metadata: {m}"); let az_id = { let az_id_from_metadata = m @@ -195,6 +188,8 @@ impl StorageControllerUpcallApi for StorageControllerUpcallClient { node_id: conf.id, listen_pg_addr: m.postgres_host, listen_pg_port: m.postgres_port, + listen_grpc_addr: m.grpc_host, + listen_grpc_port: m.grpc_port, listen_http_addr: m.http_host, listen_http_port: m.http_port, listen_https_port: m.https_port, diff --git a/pageserver/src/feature_resolver.rs b/pageserver/src/feature_resolver.rs index 50de3b691c..92a9ef2880 100644 --- a/pageserver/src/feature_resolver.rs +++ b/pageserver/src/feature_resolver.rs @@ -1,5 +1,7 @@ use std::{collections::HashMap, sync::Arc, time::Duration}; +use arc_swap::ArcSwap; +use pageserver_api::config::NodeMetadata; use posthog_client_lite::{ CaptureEvent, FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError, PostHogFlagFilterPropertyValue, @@ -11,10 +13,13 @@ use utils::id::TenantId; use crate::{config::PageServerConf, metrics::FEATURE_FLAG_EVALUATION}; +const DEFAULT_POSTHOG_REFRESH_INTERVAL: Duration = Duration::from_secs(600); + #[derive(Clone)] pub struct FeatureResolver { inner: Option>, internal_properties: Option>>, + force_overrides_for_testing: Arc>>, } impl FeatureResolver { @@ -22,9 +27,17 @@ impl FeatureResolver { Self { inner: None, internal_properties: None, + force_overrides_for_testing: Arc::new(ArcSwap::new(Arc::new(HashMap::new()))), } } + pub fn update(&self, spec: String) -> anyhow::Result<()> { + if let Some(inner) = &self.inner { + inner.update(spec)?; + } + Ok(()) + } + pub fn spawn( conf: &PageServerConf, shutdown_pageserver: CancellationToken, @@ -86,7 +99,35 @@ impl FeatureResolver { } } } - // TODO: add pageserver URL. 
+ // TODO: move this to a background task so that we don't block startup in case of slow disk + let metadata_path = conf.metadata_path(); + match std::fs::read_to_string(&metadata_path) { + Ok(metadata_str) => match serde_json::from_str::<NodeMetadata>(&metadata_str) { + Ok(metadata) => { + properties.insert( + "hostname".to_string(), + PostHogFlagFilterPropertyValue::String(metadata.http_host), + ); + if let Some(cplane_region) = metadata.other.get("region_id") { + if let Some(cplane_region) = cplane_region.as_str() { + // This region contains the cell number + properties.insert( + "neon_region".to_string(), + PostHogFlagFilterPropertyValue::String( + cplane_region.to_string(), + ), + ); + } + } + } + Err(e) => { + tracing::warn!("Failed to parse metadata.json: {}", e); + } + }, + Err(e) => { + tracing::warn!("Failed to read metadata.json: {}", e); + } + } Arc::new(properties) }; let fake_tenants = { @@ -110,18 +151,23 @@ } tenants }; - // TODO: make refresh period configurable - inner - .clone() - .spawn(handle, Duration::from_secs(60), fake_tenants); + inner.clone().spawn( + handle, + posthog_config + .refresh_interval + .unwrap_or(DEFAULT_POSTHOG_REFRESH_INTERVAL), + fake_tenants, + ); Ok(FeatureResolver { inner: Some(inner), internal_properties: Some(internal_properties), + force_overrides_for_testing: Arc::new(ArcSwap::new(Arc::new(HashMap::new()))), }) } else { Ok(FeatureResolver { inner: None, internal_properties: None, + force_overrides_for_testing: Arc::new(ArcSwap::new(Arc::new(HashMap::new()))), }) } } @@ -161,6 +207,11 @@ impl FeatureResolver { flag_key: &str, tenant_id: TenantId, ) -> Result<String, PostHogEvaluationError> { + let force_overrides = self.force_overrides_for_testing.load(); + if let Some(value) = force_overrides.get(flag_key) { + return Ok(value.clone()); + } + if let Some(inner) = &self.inner { let res = inner.feature_store().evaluate_multivariate( flag_key, @@ -199,6 +250,15 @@ flag_key: &str, tenant_id: TenantId, ) -> Result<(), PostHogEvaluationError> { + let force_overrides = self.force_overrides_for_testing.load(); + if let Some(value) = force_overrides.get(flag_key) { + return if value == "true" { + Ok(()) + } else { + Err(PostHogEvaluationError::NoConditionGroupMatched) + }; + } + if let Some(inner) = &self.inner { let res = inner.feature_store().evaluate_boolean( flag_key, @@ -230,8 +290,22 @@ inner.feature_store().is_feature_flag_boolean(flag_key) } else { Err(PostHogEvaluationError::NotAvailable( - "PostHog integration is not enabled".to_string(), + "PostHog integration is not enabled, cannot auto-determine the flag type" + .to_string(), )) } } + + /// Force override a feature flag for testing purposes only. The caller is assumed to invoke it + /// from a single thread, so updates won't race.
+ pub fn force_override_for_testing(&self, flag_key: &str, value: Option<&str>) { + let mut force_overrides = self.force_overrides_for_testing.load().as_ref().clone(); + if let Some(value) = value { + force_overrides.insert(flag_key.to_string(), value.to_string()); + } else { + force_overrides.remove(flag_key); + } + self.force_overrides_for_testing + .store(Arc::new(force_overrides)); + } } diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index c8a2a0209f..aa9bec657c 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -41,6 +41,7 @@ use pageserver_api::models::{ TopTenantShardItem, TopTenantShardsRequest, TopTenantShardsResponse, }; use pageserver_api::shard::{ShardCount, TenantShardId}; +use postgres_ffi::PgMajorVersion; use remote_storage::{DownloadError, GenericRemoteStorage, TimeTravelError}; use scopeguard::defer; use serde_json::json; @@ -59,6 +60,7 @@ use crate::config::PageServerConf; use crate::context; use crate::context::{DownloadBehavior, RequestContext, RequestContextBuilder}; use crate::deletion_queue::DeletionQueueClient; +use crate::feature_resolver::FeatureResolver; use crate::pgdatadir_mapping::LsnForTimestamp; use crate::task_mgr::TaskKind; use crate::tenant::config::LocationConf; @@ -73,6 +75,7 @@ use crate::tenant::remote_timeline_client::{ use crate::tenant::secondary::SecondaryController; use crate::tenant::size::ModelInputs; use crate::tenant::storage_layer::{IoConcurrency, LayerAccessStatsReset, LayerName}; +use crate::tenant::timeline::layer_manager::LayerManagerLockHolder; use crate::tenant::timeline::offload::{OffloadError, offload_timeline}; use crate::tenant::timeline::{ CompactFlags, CompactOptions, CompactRequest, CompactionError, MarkInvisibleRequest, Timeline, @@ -106,6 +109,7 @@ pub struct State { deletion_queue_client: DeletionQueueClient, secondary_controller: SecondaryController, latest_utilization: tokio::sync::Mutex>, + feature_resolver: FeatureResolver, } impl State { @@ -119,6 +123,7 @@ impl State { disk_usage_eviction_state: Arc, deletion_queue_client: DeletionQueueClient, secondary_controller: SecondaryController, + feature_resolver: FeatureResolver, ) -> anyhow::Result { let allowlist_routes = &[ "/v1/status", @@ -139,6 +144,7 @@ impl State { deletion_queue_client, secondary_controller, latest_utilization: Default::default(), + feature_resolver, }) } } @@ -282,11 +288,11 @@ impl From for ApiError { GetActiveTenantError::WillNotBecomeActive(TenantState::Stopping { .. }) => { ApiError::ShuttingDown } - GetActiveTenantError::WillNotBecomeActive(_) => ApiError::Conflict(format!("{}", e)), + GetActiveTenantError::WillNotBecomeActive(_) => ApiError::Conflict(format!("{e}")), GetActiveTenantError::Cancelled => ApiError::ShuttingDown, GetActiveTenantError::NotFound(gte) => gte.into(), GetActiveTenantError::WaitForActiveTimeout { .. 
} => { - ApiError::ResourceUnavailable(format!("{}", e).into()) + ApiError::ResourceUnavailable(format!("{e}").into()) } GetActiveTenantError::SwitchedTenant => { // in our HTTP handlers, this error doesn't happen @@ -1010,7 +1016,7 @@ async fn get_lsn_by_timestamp_handler( let timeline_id: TimelineId = parse_request_param(&request, "timeline_id")?; let timestamp_raw = must_get_query_param(&request, "timestamp")?; let timestamp = humantime::parse_rfc3339(×tamp_raw) - .with_context(|| format!("Invalid time: {:?}", timestamp_raw)) + .with_context(|| format!("Invalid time: {timestamp_raw:?}")) .map_err(ApiError::BadRequest)?; let timestamp_pg = postgres_ffi::to_pg_timestamp(timestamp); @@ -1105,7 +1111,7 @@ async fn get_timestamp_of_lsn_handler( json_response(StatusCode::OK, time) } None => Err(ApiError::PreconditionFailed( - format!("Timestamp for lsn {} not found", lsn).into(), + format!("Timestamp for lsn {lsn} not found").into(), )), } } @@ -1451,7 +1457,10 @@ async fn timeline_layer_scan_disposable_keys( let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download) .with_scope_timeline(&timeline); - let guard = timeline.layers.read().await; + let guard = timeline + .layers + .read(LayerManagerLockHolder::GetLayerMapInfo) + .await; let Some(layer) = guard.try_get_from_key(&layer_name.clone().into()) else { return Err(ApiError::NotFound( anyhow::anyhow!("Layer {tenant_shard_id}/{timeline_id}/{layer_name} not found").into(), @@ -2413,7 +2422,7 @@ async fn timeline_offload_handler( } if let (false, reason) = timeline.can_offload() { return Err(ApiError::PreconditionFailed( - format!("Timeline::can_offload() check failed: {}", reason) .into(), + format!("Timeline::can_offload() check failed: {reason}") .into(), )); } offload_timeline(&tenant, &timeline) @@ -3377,7 +3386,7 @@ async fn put_tenant_timeline_import_basebackup( let timeline_id: TimelineId = parse_request_param(&request, "timeline_id")?; let base_lsn: Lsn = must_parse_query_param(&request, "base_lsn")?; let end_lsn: Lsn = must_parse_query_param(&request, "end_lsn")?; - let pg_version: u32 = must_parse_query_param(&request, "pg_version")?; + let pg_version: PgMajorVersion = must_parse_query_param(&request, "pg_version")?; check_permission(&request, Some(tenant_id))?; @@ -3671,8 +3680,8 @@ async fn tenant_evaluate_feature_flag( let tenant_shard_id: TenantShardId = parse_request_param(&request, "tenant_shard_id")?; check_permission(&request, Some(tenant_shard_id.tenant_id))?; - let flag: String = must_parse_query_param(&request, "flag")?; - let as_type: String = must_parse_query_param(&request, "as")?; + let flag: String = parse_request_param(&request, "flag_key")?; + let as_type: Option = parse_query_param(&request, "as")?; let state = get_state(&request); @@ -3681,11 +3690,11 @@ async fn tenant_evaluate_feature_flag( .tenant_manager .get_attached_tenant_shard(tenant_shard_id)?; let properties = tenant.feature_resolver.collect_properties(tenant_shard_id.tenant_id); - if as_type == "boolean" { + if as_type.as_deref() == Some("boolean") { let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id); let result = result.map(|_| true).map_err(|e| e.to_string()); json_response(StatusCode::OK, json!({ "result": result, "properties": properties })) - } else if as_type == "multivariate" { + } else if as_type.as_deref() == Some("multivariate") { let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string()); json_response(StatusCode::OK, 
json!({ "result": result, "properties": properties })) } else { @@ -3705,6 +3714,49 @@ async fn tenant_evaluate_feature_flag( .await } +async fn force_override_feature_flag_for_testing_put( + request: Request, + _cancel: CancellationToken, +) -> Result, ApiError> { + check_permission(&request, None)?; + + let flag: String = parse_request_param(&request, "flag_key")?; + let value: String = must_parse_query_param(&request, "value")?; + let state = get_state(&request); + state + .feature_resolver + .force_override_for_testing(&flag, Some(&value)); + json_response(StatusCode::OK, ()) +} + +async fn force_override_feature_flag_for_testing_delete( + request: Request, + _cancel: CancellationToken, +) -> Result, ApiError> { + check_permission(&request, None)?; + + let flag: String = parse_request_param(&request, "flag_key")?; + let state = get_state(&request); + state + .feature_resolver + .force_override_for_testing(&flag, None); + json_response(StatusCode::OK, ()) +} + +async fn update_feature_flag_spec( + mut request: Request, + _cancel: CancellationToken, +) -> Result, ApiError> { + check_permission(&request, None)?; + let body = json_request(&mut request).await?; + let state = get_state(&request); + state + .feature_resolver + .update(body) + .map_err(ApiError::InternalServerError)?; + json_response(StatusCode::OK, ()) +} + /// Common functionality of all the HTTP API handlers. /// /// - Adds a tracing span to each request (by `request_span`) @@ -4081,8 +4133,17 @@ pub fn make_router( "/v1/tenant/:tenant_shard_id/timeline/:timeline_id/activate_post_import", |r| api_handler(r, activate_post_import_handler), ) - .get("/v1/tenant/:tenant_shard_id/feature_flag", |r| { + .get("/v1/tenant/:tenant_shard_id/feature_flag/:flag_key", |r| { api_handler(r, tenant_evaluate_feature_flag) }) + .put("/v1/feature_flag/:flag_key", |r| { + testing_api_handler("force override feature flag - put", r, force_override_feature_flag_for_testing_put) + }) + .delete("/v1/feature_flag/:flag_key", |r| { + testing_api_handler("force override feature flag - delete", r, force_override_feature_flag_for_testing_delete) + }) + .post("/v1/feature_flag_spec", |r| { + api_handler(r, update_feature_flag_spec) + }) .any(handler_404)) } diff --git a/pageserver/src/import_datadir.rs b/pageserver/src/import_datadir.rs index 911449c7c5..96fe0c1078 100644 --- a/pageserver/src/import_datadir.rs +++ b/pageserver/src/import_datadir.rs @@ -520,7 +520,7 @@ async fn import_file( } if file_path.starts_with("global") { - let spcnode = postgres_ffi::pg_constants::GLOBALTABLESPACE_OID; + let spcnode = postgres_ffi_types::constants::GLOBALTABLESPACE_OID; let dbnode = 0; match file_name.as_ref() { @@ -553,7 +553,7 @@ async fn import_file( } } } else if file_path.starts_with("base") { - let spcnode = pg_constants::DEFAULTTABLESPACE_OID; + let spcnode = postgres_ffi_types::constants::DEFAULTTABLESPACE_OID; let dbnode: u32 = file_path .iter() .nth(1) diff --git a/pageserver/src/lib.rs b/pageserver/src/lib.rs index ae7cbf1d6b..0dd3c465e0 100644 --- a/pageserver/src/lib.rs +++ b/pageserver/src/lib.rs @@ -38,6 +38,7 @@ pub mod walredo; use camino::Utf8Path; use deletion_queue::DeletionQueue; +use postgres_ffi::PgMajorVersion; use tenant::mgr::{BackgroundPurges, TenantManager}; use tenant::secondary; use tracing::{info, info_span}; @@ -51,7 +52,7 @@ use tracing::{info, info_span}; /// backwards-compatible changes to the metadata format. 
pub const STORAGE_FORMAT_VERSION: u16 = 3; -pub const DEFAULT_PG_VERSION: u32 = 17; +pub const DEFAULT_PG_VERSION: PgMajorVersion = PgMajorVersion::PG17; // Magic constants used to identify different kinds of files pub const IMAGE_FILE_MAGIC: u16 = 0x5A60; diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index 3eb70ffac2..21faceef49 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -1053,6 +1053,15 @@ pub(crate) static TENANT_STATE_METRIC: Lazy = Lazy::new(|| { .expect("Failed to register pageserver_tenant_states_count metric") }); +pub(crate) static TIMELINE_STATE_METRIC: Lazy = Lazy::new(|| { + register_uint_gauge_vec!( + "pageserver_timeline_states_count", + "Count of timelines per state", + &["state"] + ) + .expect("Failed to register pageserver_timeline_states_count metric") +}); + /// A set of broken tenants. /// /// These are expected to be so rare that a set is fine. Set as in a new timeseries per each broken @@ -1718,12 +1727,7 @@ impl Drop for SmgrOpTimer { impl SmgrOpFlushInProgress { /// The caller must guarantee that `socket_fd`` outlives this function. - pub(crate) async fn measure( - self, - started_at: Instant, - mut fut: Fut, - socket_fd: RawFd, - ) -> O + pub(crate) async fn measure(self, started_at: Instant, fut: Fut, socket_fd: RawFd) -> O where Fut: std::future::Future, { @@ -3325,6 +3329,8 @@ impl TimelineMetrics { &timeline_id, ); + TIMELINE_STATE_METRIC.with_label_values(&["active"]).inc(); + TimelineMetrics { tenant_id, shard_id, @@ -3415,7 +3421,7 @@ impl TimelineMetrics { pub fn dec_frozen_layer(&self, layer: &InMemoryLayer) { assert!(matches!(layer.info(), InMemoryLayerInfo::Frozen { .. })); let labels = self.make_frozen_layer_labels(layer); - let size = layer.try_len().expect("frozen layer should have no writer"); + let size = layer.len(); TIMELINE_LAYER_COUNT .get_metric_with_label_values(&labels) .unwrap() @@ -3430,7 +3436,7 @@ impl TimelineMetrics { pub fn inc_frozen_layer(&self, layer: &InMemoryLayer) { assert!(matches!(layer.info(), InMemoryLayerInfo::Frozen { .. })); let labels = self.make_frozen_layer_labels(layer); - let size = layer.try_len().expect("frozen layer should have no writer"); + let size = layer.len(); TIMELINE_LAYER_COUNT .get_metric_with_label_values(&labels) .unwrap() @@ -3479,6 +3485,8 @@ impl TimelineMetrics { return; } + TIMELINE_STATE_METRIC.with_label_values(&["active"]).dec(); + let tenant_id = &self.tenant_id; let timeline_id = &self.timeline_id; let shard_id = &self.shard_id; @@ -4415,24 +4423,30 @@ pub(crate) static BASEBACKUP_CACHE_PREPARE: Lazy = Lazy::new(|| { .expect("failed to define a metric") }); -pub(crate) static BASEBACKUP_CACHE_ENTRIES: Lazy = Lazy::new(|| { - register_int_gauge!( +pub(crate) static BASEBACKUP_CACHE_ENTRIES: Lazy = Lazy::new(|| { + register_uint_gauge!( "pageserver_basebackup_cache_entries_total", "Number of entries in the basebackup cache" ) .expect("failed to define a metric") }); -// FIXME: Support basebackup cache size metrics. 
-#[allow(dead_code)] -pub(crate) static BASEBACKUP_CACHE_SIZE: Lazy = Lazy::new(|| { - register_int_gauge!( +pub(crate) static BASEBACKUP_CACHE_SIZE: Lazy = Lazy::new(|| { + register_uint_gauge!( "pageserver_basebackup_cache_size_bytes", "Total size of all basebackup cache entries on disk in bytes" ) .expect("failed to define a metric") }); +pub(crate) static BASEBACKUP_CACHE_PREPARE_QUEUE_SIZE: Lazy = Lazy::new(|| { + register_uint_gauge!( + "pageserver_basebackup_cache_prepare_queue_size", + "Number of requests in the basebackup prepare channel" + ) + .expect("failed to define a metric") +}); + static PAGESERVER_CONFIG_IGNORED_ITEMS: Lazy = Lazy::new(|| { register_uint_gauge_vec!( "pageserver_config_ignored_items", diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 4a1ddf09b5..0287a2bdb5 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -13,8 +13,7 @@ use std::time::{Duration, Instant, SystemTime}; use std::{io, str}; use anyhow::{Context as _, anyhow, bail}; -use async_compression::tokio::write::GzipEncoder; -use bytes::{Buf, BytesMut}; +use bytes::{Buf as _, BufMut as _, BytesMut}; use futures::future::BoxFuture; use futures::{FutureExt, Stream}; use itertools::Itertools; @@ -25,12 +24,13 @@ use pageserver_api::config::{ PageServiceProtocolPipelinedBatchingStrategy, PageServiceProtocolPipelinedExecutionStrategy, }; use pageserver_api::key::rel_block_to_key; -use pageserver_api::models::{ - self, PageTraceEvent, PagestreamBeMessage, PagestreamDbSizeRequest, PagestreamDbSizeResponse, +use pageserver_api::models::{PageTraceEvent, TenantState}; +use pageserver_api::pagestream_api::{ + self, PagestreamBeMessage, PagestreamDbSizeRequest, PagestreamDbSizeResponse, PagestreamErrorResponse, PagestreamExistsRequest, PagestreamExistsResponse, PagestreamFeMessage, PagestreamGetPageRequest, PagestreamGetSlruSegmentRequest, PagestreamGetSlruSegmentResponse, PagestreamNblocksRequest, PagestreamNblocksResponse, - PagestreamProtocolVersion, PagestreamRequest, TenantState, + PagestreamProtocolVersion, PagestreamRequest, }; use pageserver_api::reltag::SlruKind; use pageserver_api::shard::TenantShardId; @@ -40,7 +40,7 @@ use postgres_backend::{ AuthType, PostgresBackend, PostgresBackendReader, QueryError, is_expected_io_error, }; use postgres_ffi::BLCKSZ; -use postgres_ffi::pg_constants::DEFAULTTABLESPACE_OID; +use postgres_ffi_types::constants::DEFAULTTABLESPACE_OID; use pq_proto::framed::ConnectionError; use pq_proto::{BeMessage, FeMessage, FeStartupPacket, RowDescriptor}; use smallvec::{SmallVec, smallvec}; @@ -62,7 +62,6 @@ use utils::{failpoint_support, span_record}; use crate::auth::check_permission; use crate::basebackup::{self, BasebackupError}; -use crate::basebackup_cache::BasebackupCache; use crate::config::PageServerConf; use crate::context::{ DownloadBehavior, PerfInstrumentFutureExt, RequestContext, RequestContextBuilder, @@ -137,7 +136,6 @@ pub fn spawn( perf_trace_dispatch: Option, tcp_listener: tokio::net::TcpListener, tls_config: Option>, - basebackup_cache: Arc, ) -> Listener { let cancel = CancellationToken::new(); let libpq_ctx = RequestContext::todo_child( @@ -159,7 +157,6 @@ pub fn spawn( conf.pg_auth_type, tls_config, conf.page_service_pipelining.clone(), - basebackup_cache, libpq_ctx, cancel.clone(), ) @@ -169,99 +166,6 @@ pub fn spawn( Listener { cancel, task } } -/// Spawns a gRPC server for the page service. -/// -/// TODO: move this onto GrpcPageServiceHandler::spawn(). -/// TODO: this doesn't support TLS. 
We need TLS reloading via ReloadingCertificateResolver, so we -/// need to reimplement the TCP+TLS accept loop ourselves. -pub fn spawn_grpc( - tenant_manager: Arc, - auth: Option>, - perf_trace_dispatch: Option, - get_vectored_concurrent_io: GetVectoredConcurrentIo, - listener: std::net::TcpListener, -) -> anyhow::Result { - let cancel = CancellationToken::new(); - let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler) - .download_behavior(DownloadBehavior::Download) - .perf_span_dispatch(perf_trace_dispatch) - .detached_child(); - let gate = Gate::default(); - - // Set up the TCP socket. We take a preconfigured TcpListener to bind the - // port early during startup. - let incoming = { - let _runtime = COMPUTE_REQUEST_RUNTIME.enter(); // required by TcpListener::from_std - listener.set_nonblocking(true)?; - tonic::transport::server::TcpIncoming::from(tokio::net::TcpListener::from_std(listener)?) - .with_nodelay(Some(GRPC_TCP_NODELAY)) - .with_keepalive(Some(GRPC_TCP_KEEPALIVE_TIME)) - }; - - // Set up the gRPC server. - // - // TODO: consider tuning window sizes. - let mut server = tonic::transport::Server::builder() - .http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL)) - .http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT)) - .max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS)); - - // Main page service stack. Uses a mix of Tonic interceptors and Tower layers: - // - // * Interceptors: can inspect and modify the gRPC request. Sync code only, runs before service. - // - // * Layers: allow async code, can run code after the service response. However, only has access - // to the raw HTTP request/response, not the gRPC types. - let page_service_handler = GrpcPageServiceHandler { - tenant_manager, - ctx, - gate_guard: gate.enter().expect("gate was just created"), - get_vectored_concurrent_io, - }; - - let observability_layer = ObservabilityLayer; - let mut tenant_interceptor = TenantMetadataInterceptor; - let mut auth_interceptor = TenantAuthInterceptor::new(auth); - - let page_service = tower::ServiceBuilder::new() - // Create tracing span and record request start time. - .layer(observability_layer) - // Intercept gRPC requests. - .layer(tonic::service::InterceptorLayer::new(move |mut req| { - // Extract tenant metadata. - req = tenant_interceptor.call(req)?; - // Authenticate tenant JWT token. - req = auth_interceptor.call(req)?; - Ok(req) - })) - .service(proto::PageServiceServer::new(page_service_handler)); - let server = server.add_service(page_service); - - // Reflection service for use with e.g. grpcurl. - let reflection_service = tonic_reflection::server::Builder::configure() - .register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET) - .build_v1()?; - let server = server.add_service(reflection_service); - - // Spawn server task. - let task_cancel = cancel.clone(); - let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error( - "grpc listener", - async move { - let result = server - .serve_with_incoming_shutdown(incoming, task_cancel.cancelled()) - .await; - if result.is_ok() { - // TODO: revisit shutdown logic once page service is implemented. 
- gate.close().await; - } - result - }, - )); - - Ok(CancellableTask { task, cancel }) -} - impl Listener { pub async fn stop_accepting(self) -> Connections { self.cancel.cancel(); @@ -311,7 +215,6 @@ pub async fn libpq_listener_main( auth_type: AuthType, tls_config: Option>, pipelining_config: PageServicePipeliningConfig, - basebackup_cache: Arc, listener_ctx: RequestContext, listener_cancel: CancellationToken, ) -> Connections { @@ -355,7 +258,6 @@ pub async fn libpq_listener_main( auth_type, tls_config.clone(), pipelining_config.clone(), - Arc::clone(&basebackup_cache), connection_ctx, connections_cancel.child_token(), gate_guard, @@ -398,7 +300,6 @@ async fn page_service_conn_main( auth_type: AuthType, tls_config: Option>, pipelining_config: PageServicePipeliningConfig, - basebackup_cache: Arc, connection_ctx: RequestContext, cancel: CancellationToken, gate_guard: GateGuard, @@ -464,7 +365,6 @@ async fn page_service_conn_main( pipelining_config, conf.get_vectored_concurrent_io, perf_span_fields, - basebackup_cache, connection_ctx, cancel.clone(), gate_guard, @@ -484,16 +384,14 @@ async fn page_service_conn_main( } else { let tenant_id = conn_handler.timeline_handles.as_ref().unwrap().tenant_id(); Err(io_error).context(format!( - "Postgres connection error for tenant_id={:?} client at peer_addr={}", - tenant_id, peer_addr + "Postgres connection error for tenant_id={tenant_id:?} client at peer_addr={peer_addr}" )) } } other => { let tenant_id = conn_handler.timeline_handles.as_ref().unwrap().tenant_id(); other.context(format!( - "Postgres query error for tenant_id={:?} client peer_addr={}", - tenant_id, peer_addr + "Postgres query error for tenant_id={tenant_id:?} client peer_addr={peer_addr}" )) } } @@ -520,8 +418,6 @@ struct PageServerHandler { pipelining_config: PageServicePipeliningConfig, get_vectored_concurrent_io: GetVectoredConcurrentIo, - basebackup_cache: Arc, - gate_guard: GateGuard, } @@ -716,60 +612,6 @@ enum PageStreamError { BadRequest(Cow<'static, str>), } -impl PageStreamError { - /// Converts a PageStreamError into a proto::GetPageResponse with the appropriate status - /// code, or a gRPC status if it should terminate the stream (e.g. shutdown). This is a - /// convenience method for use from a get_pages gRPC stream. - #[allow(clippy::result_large_err)] - fn into_get_page_response( - self, - request_id: page_api::RequestID, - ) -> Result { - use page_api::GetPageStatusCode; - use tonic::Code; - - // We dispatch to Into first, and then map it to a GetPageResponse. - let status: tonic::Status = self.into(); - let status_code = match status.code() { - // We shouldn't see an OK status here, because we're emitting an error. - Code::Ok => { - debug_assert_ne!(status.code(), Code::Ok); - return Err(tonic::Status::internal(format!( - "unexpected OK status: {status:?}", - ))); - } - - // These are per-request errors, returned as GetPageResponses. - Code::AlreadyExists => GetPageStatusCode::InvalidRequest, - Code::DataLoss => GetPageStatusCode::InternalError, - Code::FailedPrecondition => GetPageStatusCode::InvalidRequest, - Code::InvalidArgument => GetPageStatusCode::InvalidRequest, - Code::Internal => GetPageStatusCode::InternalError, - Code::NotFound => GetPageStatusCode::NotFound, - Code::OutOfRange => GetPageStatusCode::InvalidRequest, - Code::ResourceExhausted => GetPageStatusCode::SlowDown, - - // These should terminate the stream. 
- Code::Aborted => return Err(status), - Code::Cancelled => return Err(status), - Code::DeadlineExceeded => return Err(status), - Code::PermissionDenied => return Err(status), - Code::Unauthenticated => return Err(status), - Code::Unavailable => return Err(status), - Code::Unimplemented => return Err(status), - Code::Unknown => return Err(status), - }; - - Ok(page_api::GetPageResponse { - request_id, - status_code, - reason: Some(status.message().to_string()), - page_images: Vec::new(), - } - .into()) - } -} - impl From for tonic::Status { fn from(err: PageStreamError) -> Self { use tonic::Code; @@ -859,7 +701,7 @@ struct BatchedGetPageRequest { #[cfg(feature = "testing")] struct BatchedTestRequest { - req: models::PagestreamTestRequest, + req: pagestream_api::PagestreamTestRequest, timer: SmgrOpTimer, } @@ -873,13 +715,13 @@ enum BatchedFeMessage { span: Span, timer: SmgrOpTimer, shard: WeakHandle, - req: models::PagestreamExistsRequest, + req: PagestreamExistsRequest, }, Nblocks { span: Span, timer: SmgrOpTimer, shard: WeakHandle, - req: models::PagestreamNblocksRequest, + req: PagestreamNblocksRequest, }, GetPage { span: Span, @@ -891,13 +733,13 @@ enum BatchedFeMessage { span: Span, timer: SmgrOpTimer, shard: WeakHandle, - req: models::PagestreamDbSizeRequest, + req: PagestreamDbSizeRequest, }, GetSlruSegment { span: Span, timer: SmgrOpTimer, shard: WeakHandle, - req: models::PagestreamGetSlruSegmentRequest, + req: PagestreamGetSlruSegmentRequest, }, #[cfg(feature = "testing")] Test { @@ -1061,7 +903,6 @@ impl PageServerHandler { pipelining_config: PageServicePipeliningConfig, get_vectored_concurrent_io: GetVectoredConcurrentIo, perf_span_fields: ConnectionPerfSpanFields, - basebackup_cache: Arc, connection_ctx: RequestContext, cancel: CancellationToken, gate_guard: GateGuard, @@ -1075,7 +916,6 @@ impl PageServerHandler { cancel, pipelining_config, get_vectored_concurrent_io, - basebackup_cache, gate_guard, } } @@ -2286,8 +2126,7 @@ impl PageServerHandler { if request_lsn < not_modified_since { return Err(PageStreamError::BadRequest( format!( - "invalid request with request LSN {} and not_modified_since {}", - request_lsn, not_modified_since, + "invalid request with request LSN {request_lsn} and not_modified_since {not_modified_since}", ) .into(), )); @@ -2590,10 +2429,9 @@ impl PageServerHandler { .map(|(req, res)| { res.map(|page| { ( - PagestreamBeMessage::GetPage(models::PagestreamGetPageResponse { - req: req.req, - page, - }), + PagestreamBeMessage::GetPage( + pagestream_api::PagestreamGetPageResponse { req: req.req, page }, + ), req.timer, req.ctx, ) @@ -2660,7 +2498,7 @@ impl PageServerHandler { .map(|(req, res)| { res.map(|()| { ( - PagestreamBeMessage::Test(models::PagestreamTestResponse { + PagestreamBeMessage::Test(pagestream_api::PagestreamTestResponse { req: req.req.clone(), }), req.timer, @@ -2763,6 +2601,7 @@ impl PageServerHandler { prev_lsn, full_backup, replica, + None, &ctx, ) .await?; @@ -2776,9 +2615,7 @@ impl PageServerHandler { && lsn.is_some() && prev_lsn.is_none() { - self.basebackup_cache - .get(tenant_id, timeline_id, lsn.unwrap()) - .await + timeline.get_cached_basebackup(lsn.unwrap()).await } else { None } @@ -2791,31 +2628,6 @@ impl PageServerHandler { .map_err(|err| { BasebackupError::Client(err, "handle_basebackup_request,cached,copy") })?; - } else if gzip { - let mut encoder = GzipEncoder::with_quality( - &mut writer, - // NOTE using fast compression because it's on the critical path - // for compute startup. 
For an empty database, we get - // <100KB with this method. The Level::Best compression method - // gives us <20KB, but maybe we should add basebackup caching - // on compute shutdown first. - async_compression::Level::Fastest, - ); - basebackup::send_basebackup_tarball( - &mut encoder, - &timeline, - lsn, - prev_lsn, - full_backup, - replica, - &ctx, - ) - .await?; - // shutdown the encoder to ensure the gzip footer is written - encoder - .shutdown() - .await - .map_err(|e| QueryError::Disconnected(ConnectionError::Io(e)))?; } else { basebackup::send_basebackup_tarball( &mut writer, @@ -2824,6 +2636,11 @@ impl PageServerHandler { prev_lsn, full_backup, replica, + // NB: using fast compression because it's on the critical path for compute + // startup. For an empty database, we get <100KB with this method. The + // Level::Best compression method gives us <20KB, but maybe we should add + // basebackup caching on compute shutdown first. + gzip.then_some(async_compression::Level::Fastest), &ctx, ) .await?; @@ -3366,6 +3183,108 @@ pub struct GrpcPageServiceHandler { } impl GrpcPageServiceHandler { + /// Spawns a gRPC server for the page service. + /// + /// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we + /// need to reimplement the TCP+TLS accept loop ourselves. + pub fn spawn( + tenant_manager: Arc, + auth: Option>, + perf_trace_dispatch: Option, + get_vectored_concurrent_io: GetVectoredConcurrentIo, + listener: std::net::TcpListener, + ) -> anyhow::Result { + let cancel = CancellationToken::new(); + let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler) + .download_behavior(DownloadBehavior::Download) + .perf_span_dispatch(perf_trace_dispatch) + .detached_child(); + let gate = Gate::default(); + + // Set up the TCP socket. We take a preconfigured TcpListener to bind the + // port early during startup. + let incoming = { + let _runtime = COMPUTE_REQUEST_RUNTIME.enter(); // required by TcpListener::from_std + listener.set_nonblocking(true)?; + tonic::transport::server::TcpIncoming::from(tokio::net::TcpListener::from_std( + listener, + )?) + .with_nodelay(Some(GRPC_TCP_NODELAY)) + .with_keepalive(Some(GRPC_TCP_KEEPALIVE_TIME)) + }; + + // Set up the gRPC server. + // + // TODO: consider tuning window sizes. + let mut server = tonic::transport::Server::builder() + .http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL)) + .http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT)) + .max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS)); + + // Main page service stack. Uses a mix of Tonic interceptors and Tower layers: + // + // * Interceptors: can inspect and modify the gRPC request. Sync code only, runs before service. + // + // * Layers: allow async code, can run code after the service response. However, only has access + // to the raw HTTP request/response, not the gRPC types. + let page_service_handler = GrpcPageServiceHandler { + tenant_manager, + ctx, + gate_guard: gate.enter().expect("gate was just created"), + get_vectored_concurrent_io, + }; + + let observability_layer = ObservabilityLayer; + let mut tenant_interceptor = TenantMetadataInterceptor; + let mut auth_interceptor = TenantAuthInterceptor::new(auth); + + let page_service = tower::ServiceBuilder::new() + // Create tracing span and record request start time. + .layer(observability_layer) + // Intercept gRPC requests. + .layer(tonic::service::InterceptorLayer::new(move |mut req| { + // Extract tenant metadata. 
+ req = tenant_interceptor.call(req)?; + // Authenticate tenant JWT token. + req = auth_interceptor.call(req)?; + Ok(req) + })) + // Run the page service. + .service( + proto::PageServiceServer::new(page_service_handler) + // Support both gzip and zstd compression. The client decides what to use. + .accept_compressed(tonic::codec::CompressionEncoding::Gzip) + .accept_compressed(tonic::codec::CompressionEncoding::Zstd) + .send_compressed(tonic::codec::CompressionEncoding::Gzip) + .send_compressed(tonic::codec::CompressionEncoding::Zstd), + ); + let server = server.add_service(page_service); + + // Reflection service for use with e.g. grpcurl. + let reflection_service = tonic_reflection::server::Builder::configure() + .register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET) + .build_v1()?; + let server = server.add_service(reflection_service); + + // Spawn server task. + let task_cancel = cancel.clone(); + let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error( + "grpc listener", + async move { + let result = server + .serve_with_incoming_shutdown(incoming, task_cancel.cancelled()) + .await; + if result.is_ok() { + // TODO: revisit shutdown logic once page service is implemented. + gate.close().await; + } + result + }, + )); + + Ok(CancellableTask { task, cancel }) + } + /// Errors if the request is executed on a non-zero shard. Only shard 0 has a complete view of /// relations and their sizes, as well as SLRU segments and similar data. #[allow(clippy::result_large_err)] @@ -3436,8 +3355,8 @@ impl GrpcPageServiceHandler { /// Processes a GetPage batch request, via the GetPages bidirectional streaming RPC. /// - /// NB: errors will terminate the stream. Per-request errors should return a GetPageResponse - /// with an appropriate status code instead. + /// NB: errors returned from here are intercepted in get_pages(), and may be converted to a + /// GetPageResponse with an appropriate status code to avoid terminating the stream. /// /// TODO: get_vectored() currently enforces a batch limit of 32. Postgres will typically send /// batches up to effective_io_concurrency = 100. Either we have to accept large batches, or @@ -3454,7 +3373,7 @@ impl GrpcPageServiceHandler { let ctx = ctx.with_scope_page_service_pagestream(&timeline); // Validate the request, decorate the span, and convert it to a Pagestream request. 
- let req: page_api::GetPageRequest = req.try_into()?; + let req = page_api::GetPageRequest::try_from(req)?; span_record!( req_id = %req.request_id, @@ -3465,7 +3384,7 @@ impl GrpcPageServiceHandler { ); let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); // hold guard - let effective_lsn = match PageServerHandler::effective_request_lsn( + let effective_lsn = PageServerHandler::effective_request_lsn( &timeline, timeline.get_last_record_lsn(), req.read_lsn.request_lsn, @@ -3473,10 +3392,7 @@ impl GrpcPageServiceHandler { .not_modified_since_lsn .unwrap_or(req.read_lsn.request_lsn), &latest_gc_cutoff_lsn, - ) { - Ok(lsn) => lsn, - Err(err) => return err.into_get_page_response(req.request_id), - }; + )?; let mut batch = SmallVec::with_capacity(req.block_numbers.len()); for blkno in req.block_numbers { @@ -3533,7 +3449,7 @@ impl GrpcPageServiceHandler { "unexpected response: {resp:?}" ))); } - Err(err) => return err.err.into_get_page_response(req.request_id), + Err(err) => return Err(err.err.into()), }; } @@ -3587,57 +3503,66 @@ impl proto::PageService for GrpcPageServiceHandler { Ok(tonic::Response::new(resp.into())) } - // TODO: ensure clients use gzip compression for the stream. #[instrument(skip_all, fields(lsn))] async fn get_base_backup( &self, req: tonic::Request, ) -> Result, tonic::Status> { - // Send 64 KB chunks to avoid large memory allocations. - const CHUNK_SIZE: usize = 64 * 1024; + // Send chunks of 256 KB to avoid large memory allocations. pagebench basebackup shows this + // to be the sweet spot where throughput is saturated. + const CHUNK_SIZE: usize = 256 * 1024; let timeline = self.get_request_timeline(&req).await?; let ctx = self.ctx.with_scope_timeline(&timeline); - // Validate the request, decorate the span, and wait for the LSN to arrive. - // - // TODO: this requires a read LSN, is that ok? + // Validate the request and decorate the span. Self::ensure_shard_zero(&timeline)?; if timeline.is_archived() == Some(true) { return Err(tonic::Status::failed_precondition("timeline is archived")); } let req: page_api::GetBaseBackupRequest = req.into_inner().try_into()?; - span_record!(lsn=%req.read_lsn); + span_record!(lsn=?req.lsn); - let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); - timeline - .wait_lsn( - req.read_lsn.request_lsn, - WaitLsnWaiter::PageService, - WaitLsnTimeout::Default, - &ctx, - ) - .await?; - timeline - .check_lsn_is_in_scope(req.read_lsn.request_lsn, &latest_gc_cutoff_lsn) - .map_err(|err| { - tonic::Status::invalid_argument(format!("invalid basebackup LSN: {err}")) - })?; + // Wait for the LSN to arrive, if given. + if let Some(lsn) = req.lsn { + let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); + timeline + .wait_lsn( + lsn, + WaitLsnWaiter::PageService, + WaitLsnTimeout::Default, + &ctx, + ) + .await?; + timeline + .check_lsn_is_in_scope(lsn, &latest_gc_cutoff_lsn) + .map_err(|err| { + tonic::Status::invalid_argument(format!("invalid basebackup LSN: {err}")) + })?; + } // Spawn a task to run the basebackup. - // - // TODO: do we need to support full base backups, for debugging? let span = Span::current(); let (mut simplex_read, mut simplex_write) = tokio::io::simplex(CHUNK_SIZE); let jh = tokio::spawn(async move { + let gzip_level = match req.compression { + page_api::BaseBackupCompression::None => None, + // NB: using fast compression because it's on the critical path for compute + // startup. For an empty database, we get <100KB with this method. 
The + // Level::Best compression method gives us <20KB, but maybe we should add + // basebackup caching on compute shutdown first. + page_api::BaseBackupCompression::Gzip => Some(async_compression::Level::Fastest), + }; + let result = basebackup::send_basebackup_tarball( &mut simplex_write, &timeline, - Some(req.read_lsn.request_lsn), + req.lsn, None, - false, + req.full, req.replica, + gzip_level, &ctx, ) .instrument(span) // propagate request span @@ -3650,20 +3575,21 @@ impl proto::PageService for GrpcPageServiceHandler { // Emit chunks of size CHUNK_SIZE. let chunks = async_stream::try_stream! { - let mut chunk = BytesMut::with_capacity(CHUNK_SIZE); loop { - let n = simplex_read.read_buf(&mut chunk).await.map_err(|err| { - tonic::Status::internal(format!("failed to read basebackup chunk: {err}")) - })?; - - // If we read 0 bytes, either the chunk is full or the stream is closed. - if n == 0 { - if chunk.is_empty() { - break; + let mut chunk = BytesMut::with_capacity(CHUNK_SIZE).limit(CHUNK_SIZE); + loop { + let n = simplex_read.read_buf(&mut chunk).await.map_err(|err| { + tonic::Status::internal(format!("failed to read basebackup chunk: {err}")) + })?; + if n == 0 { + break; // full chunk or closed stream } - yield proto::GetBaseBackupResponseChunk::from(chunk.clone().freeze()); - chunk.clear(); } + let chunk = chunk.into_inner().freeze(); + if chunk.is_empty() { + break; + } + yield proto::GetBaseBackupResponseChunk::from(chunk); } // Wait for the basebackup task to exit and check for errors. jh.await.map_err(|err| { @@ -3740,9 +3666,16 @@ impl proto::PageService for GrpcPageServiceHandler { .await? .downgrade(); while let Some(req) = reqs.message().await? { - yield Self::get_page(&ctx, &timeline, req, io_concurrency.clone()) + let req_id = req.request_id; + let result = Self::get_page(&ctx, &timeline, req, io_concurrency.clone()) .instrument(span.clone()) // propagate request span - .await? + .await; + yield match result { + Ok(resp) => resp, + // Convert per-request errors to GetPageResponses as appropriate, or terminate + // the stream with a tonic::Status. 
+ Err(err) => page_api::GetPageResponse::try_from_status(err, req_id)?.into(), + } } }; diff --git a/pageserver/src/pgdatadir_mapping.rs b/pageserver/src/pgdatadir_mapping.rs index 633d62210d..09a7a8a651 100644 --- a/pageserver/src/pgdatadir_mapping.rs +++ b/pageserver/src/pgdatadir_mapping.rs @@ -23,12 +23,11 @@ use pageserver_api::key::{ }; use pageserver_api::keyspace::{KeySpaceRandomAccum, SparseKeySpace}; use pageserver_api::models::RelSizeMigration; -use pageserver_api::record::NeonWalRecord; use pageserver_api::reltag::{BlockNumber, RelTag, SlruKind}; use pageserver_api::shard::ShardIdentity; -use pageserver_api::value::Value; -use postgres_ffi::relfile_utils::{FSM_FORKNUM, VISIBILITYMAP_FORKNUM}; -use postgres_ffi::{BLCKSZ, Oid, RepOriginId, TimestampTz, TransactionId}; +use postgres_ffi::{BLCKSZ, PgMajorVersion, TimestampTz, TransactionId}; +use postgres_ffi_types::forknum::{FSM_FORKNUM, VISIBILITYMAP_FORKNUM}; +use postgres_ffi_types::{Oid, RepOriginId}; use serde::{Deserialize, Serialize}; use strum::IntoEnumIterator; use tokio_util::sync::CancellationToken; @@ -36,6 +35,8 @@ use tracing::{debug, info, info_span, trace, warn}; use utils::bin_ser::{BeSer, DeserializeError}; use utils::lsn::Lsn; use utils::pausable_failpoint; +use wal_decoder::models::record::NeonWalRecord; +use wal_decoder::models::value::Value; use wal_decoder::serialized_batch::{SerializedValueBatch, ValueMeta}; use super::tenant::{PageReconstructError, Timeline}; @@ -720,6 +721,7 @@ impl Timeline { let batches = keyspace.partition( self.get_shard_identity(), self.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64, + BLCKSZ as u64, ); let io_concurrency = IoConcurrency::spawn_from_conf( @@ -960,6 +962,7 @@ impl Timeline { let batches = keyspace.partition( self.get_shard_identity(), self.conf.max_get_vectored_keys.get() as u64 * BLCKSZ as u64, + BLCKSZ as u64, ); let io_concurrency = IoConcurrency::spawn_from_conf( @@ -1078,7 +1081,7 @@ impl Timeline { // fetch directory entry let buf = self.get(TWOPHASEDIR_KEY, lsn, ctx).await?; - if self.pg_version >= 17 { + if self.pg_version >= PgMajorVersion::PG17 { Ok(TwoPhaseDirectoryV17::des(&buf)?.xids) } else { Ok(TwoPhaseDirectory::des(&buf)? 
@@ -1182,7 +1185,7 @@ impl Timeline { } let origin_id = k.field6 as RepOriginId; let origin_lsn = Lsn::des(&v) - .with_context(|| format!("decode replorigin value for {}: {v:?}", origin_id))?; + .with_context(|| format!("decode replorigin value for {origin_id}: {v:?}"))?; if origin_lsn != Lsn::INVALID { result.insert(origin_id, origin_lsn); } @@ -1610,7 +1613,7 @@ impl DatadirModification<'_> { .push((DirectoryKind::Db, MetricsUpdate::Set(0))); self.put(DBDIR_KEY, Value::Image(buf.into())); - let buf = if self.tline.pg_version >= 17 { + let buf = if self.tline.pg_version >= PgMajorVersion::PG17 { TwoPhaseDirectoryV17::ser(&TwoPhaseDirectoryV17 { xids: HashSet::new(), }) @@ -1964,7 +1967,7 @@ impl DatadirModification<'_> { ) -> Result<(), WalIngestError> { // Add it to the directory entry let dirbuf = self.get(TWOPHASEDIR_KEY, ctx).await?; - let newdirbuf = if self.tline.pg_version >= 17 { + let newdirbuf = if self.tline.pg_version >= PgMajorVersion::PG17 { let mut dir = TwoPhaseDirectoryV17::des(&dirbuf)?; if !dir.xids.insert(xid) { Err(WalIngestErrorKind::FileAlreadyExists(xid))?; @@ -2380,7 +2383,7 @@ impl DatadirModification<'_> { ) -> Result<(), WalIngestError> { // Remove it from the directory entry let buf = self.get(TWOPHASEDIR_KEY, ctx).await?; - let newdirbuf = if self.tline.pg_version >= 17 { + let newdirbuf = if self.tline.pg_version >= PgMajorVersion::PG17 { let mut dir = TwoPhaseDirectoryV17::des(&buf)?; if !dir.xids.remove(&xid) { @@ -2437,8 +2440,7 @@ impl DatadirModification<'_> { if path == p { assert!( modifying_file.is_none(), - "duplicated entries found for {}", - path + "duplicated entries found for {path}" ); modifying_file = Some(content); } else { diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index f9fdc143b4..2613528143 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -38,6 +38,7 @@ use pageserver_api::models::{ WalRedoManagerStatus, }; use pageserver_api::shard::{ShardIdentity, ShardStripeSize, TenantShardId}; +use postgres_ffi::PgMajorVersion; use remote_storage::{DownloadError, GenericRemoteStorage, TimeoutOrCancel}; use remote_timeline_client::index::GcCompactionState; use remote_timeline_client::manifest::{ @@ -51,6 +52,7 @@ use secondary::heatmap::{HeatMapTenant, HeatMapTimeline}; use storage_broker::BrokerClientChannel; use timeline::compaction::{CompactionOutcome, GcCompactionQueue}; use timeline::import_pgdata::ImportingTimeline; +use timeline::layer_manager::LayerManagerLockHolder; use timeline::offload::{OffloadError, offload_timeline}; use timeline::{ CompactFlags, CompactOptions, CompactionError, PreviousHeatmap, ShutdownMode, import_pgdata, @@ -78,7 +80,7 @@ use self::timeline::uninit::{TimelineCreateGuard, TimelineExclusionError, Uninit use self::timeline::{ EvictionTaskTenantState, GcCutoffs, TimelineDeleteProgress, TimelineResources, WaitLsnError, }; -use crate::basebackup_cache::BasebackupPrepareSender; +use crate::basebackup_cache::BasebackupCache; use crate::config::PageServerConf; use crate::context; use crate::context::RequestContextBuilder; @@ -89,7 +91,8 @@ use crate::l0_flush::L0FlushGlobalState; use crate::metrics::{ BROKEN_TENANTS_SET, CIRCUIT_BREAKERS_BROKEN, CIRCUIT_BREAKERS_UNBROKEN, CONCURRENT_INITDBS, INITDB_RUN_TIME, INITDB_SEMAPHORE_ACQUISITION_TIME, TENANT, TENANT_OFFLOADED_TIMELINES, - TENANT_STATE_METRIC, TENANT_SYNTHETIC_SIZE_METRIC, remove_tenant_metrics, + TENANT_STATE_METRIC, TENANT_SYNTHETIC_SIZE_METRIC, TIMELINE_STATE_METRIC, + remove_tenant_metrics, }; use 
crate::task_mgr::TaskKind; use crate::tenant::config::LocationMode; @@ -159,7 +162,7 @@ pub struct TenantSharedResources { pub remote_storage: GenericRemoteStorage, pub deletion_queue_client: DeletionQueueClient, pub l0_flush_global_state: L0FlushGlobalState, - pub basebackup_prepare_sender: BasebackupPrepareSender, + pub basebackup_cache: Arc<BasebackupCache>, pub feature_resolver: FeatureResolver, } @@ -328,7 +331,7 @@ pub struct TenantShard { deletion_queue_client: DeletionQueueClient, /// A channel to send async requests to prepare a basebackup for the basebackup cache. - basebackup_prepare_sender: BasebackupPrepareSender, + basebackup_cache: Arc<BasebackupCache>, /// Cached logical sizes updated on each [`TenantShard::gather_size_inputs`]. cached_logical_sizes: tokio::sync::Mutex>, @@ -494,8 +497,8 @@ impl WalRedoManager { key: pageserver_api::key::Key, lsn: Lsn, base_img: Option<(Lsn, bytes::Bytes)>, - records: Vec<(Lsn, pageserver_api::record::NeonWalRecord)>, - pg_version: u32, + records: Vec<(Lsn, wal_decoder::models::record::NeonWalRecord)>, + pg_version: PgMajorVersion, redo_attempt_type: RedoAttemptType, ) -> Result { match self { @@ -544,6 +547,28 @@ pub struct OffloadedTimeline { /// Part of the `OffloadedTimeline` object's lifecycle: this needs to be set before we drop it pub deleted_from_ancestor: AtomicBool, + + _metrics_guard: OffloadedTimelineMetricsGuard, +} + +/// Increases the offloaded timeline count metric when created, and decreases when dropped. +struct OffloadedTimelineMetricsGuard; + +impl OffloadedTimelineMetricsGuard { + fn new() -> Self { + TIMELINE_STATE_METRIC + .with_label_values(&["offloaded"]) + .inc(); + Self + } +} + +impl Drop for OffloadedTimelineMetricsGuard { + fn drop(&mut self) { + TIMELINE_STATE_METRIC + .with_label_values(&["offloaded"]) + .dec(); + } } impl OffloadedTimeline { @@ -576,6 +601,8 @@ impl OffloadedTimeline { delete_progress: timeline.delete_progress.clone(), deleted_from_ancestor: AtomicBool::new(false), + + _metrics_guard: OffloadedTimelineMetricsGuard::new(), }) } fn from_manifest(tenant_shard_id: TenantShardId, manifest: &OffloadedTimelineManifest) -> Self { @@ -595,6 +622,7 @@ impl OffloadedTimeline { archived_at, delete_progress: TimelineDeleteProgress::default(), deleted_from_ancestor: AtomicBool::new(false), + _metrics_guard: OffloadedTimelineMetricsGuard::new(), } } fn manifest(&self) -> OffloadedTimelineManifest { @@ -906,7 +934,7 @@ pub(crate) enum CreateTimelineParams { pub(crate) struct CreateTimelineParamsBootstrap { pub(crate) new_timeline_id: TimelineId, pub(crate) existing_initdb_timeline_id: Option<TimelineId>, - pub(crate) pg_version: u32, + pub(crate) pg_version: PgMajorVersion, } /// NB: See comment on [`CreateTimelineIdempotency::Branch`] for why there's no `pg_version` here. @@ -944,7 +972,7 @@ pub(crate) enum CreateTimelineIdempotency { /// NB: special treatment, see comment in [`Self`]. FailWithConflict, Bootstrap { - pg_version: u32, + pg_version: PgMajorVersion, }, /// NB: branches always have the same `pg_version` as their ancestor.
/// While [`pageserver_api::models::TimelineCreateRequestMode::Branch::pg_version`] @@ -1289,7 +1317,7 @@ impl TenantShard { ancestor.is_some() || timeline .layers - .read() + .read(LayerManagerLockHolder::LoadLayerMap) .await .layer_map() .expect( @@ -1335,7 +1363,7 @@ impl TenantShard { remote_storage, deletion_queue_client, l0_flush_global_state, - basebackup_prepare_sender, + basebackup_cache, feature_resolver, } = resources; @@ -1352,7 +1380,7 @@ impl TenantShard { remote_storage.clone(), deletion_queue_client, l0_flush_global_state, - basebackup_prepare_sender, + basebackup_cache, feature_resolver, )); @@ -1832,6 +1860,29 @@ impl TenantShard { } } + // At this point we've initialized all timelines and are tracking them. + // Now compute the layer visibility for all (not offloaded) timelines. + let compute_visibility_for = { + let timelines_accessor = self.timelines.lock().unwrap(); + let mut timelines_offloaded_accessor = self.timelines_offloaded.lock().unwrap(); + + timelines_offloaded_accessor.extend(offloaded_timelines_list.into_iter()); + + // Before activation, populate each Timeline's GcInfo with information about its children + self.initialize_gc_info(&timelines_accessor, &timelines_offloaded_accessor, None); + + timelines_accessor.values().cloned().collect::<Vec<_>>() + }; + + for tl in compute_visibility_for { + tl.update_layer_visibility().await.with_context(|| { + format!( + "failed initial timeline visibility computation {} for tenant {}", + tl.timeline_id, self.tenant_shard_id + ) + })?; + } + // Walk through deleted timelines, resume deletion for (timeline_id, index_part, remote_timeline_client) in timelines_to_resume_deletions { remote_timeline_client @@ -1851,10 +1902,6 @@ .context("resume_deletion") .map_err(LoadLocalTimelineError::ResumeDeletion)?; } - { - let mut offloaded_timelines_accessor = self.timelines_offloaded.lock().unwrap(); - offloaded_timelines_accessor.extend(offloaded_timelines_list.into_iter()); - } // Stash the preloaded tenant manifest, and upload a new manifest if changed. // @@ -2495,7 +2542,7 @@ impl TenantShard { self: &Arc<Self>, new_timeline_id: TimelineId, initdb_lsn: Lsn, - pg_version: u32, + pg_version: PgMajorVersion, ctx: &RequestContext, ) -> anyhow::Result<(UninitializedTimeline, RequestContext)> { anyhow::ensure!( @@ -2547,7 +2594,7 @@ impl TenantShard { self: &Arc<Self>, new_timeline_id: TimelineId, initdb_lsn: Lsn, - pg_version: u32, + pg_version: PgMajorVersion, ctx: &RequestContext, ) -> anyhow::Result<Arc<Timeline>> { let (uninit_tl, ctx) = self @@ -2586,7 +2633,7 @@ impl TenantShard { self: &Arc<Self>, new_timeline_id: TimelineId, initdb_lsn: Lsn, - pg_version: u32, + pg_version: PgMajorVersion, ctx: &RequestContext, in_memory_layer_desc: Vec<InMemoryLayerTestDesc>, delta_layer_desc: Vec<DeltaLayerTestDesc>, @@ -2617,7 +2664,7 @@ } let layer_names = tline .layers - .read() + .read(LayerManagerLockHolder::Testing) .await .layer_map() .unwrap() @@ -2852,7 +2899,7 @@ impl TenantShard { Lsn(0), initdb_lsn, initdb_lsn, - 15, + PgMajorVersion::PG15, ); this.prepare_new_timeline( new_timeline_id, @@ -3132,7 +3179,12 @@ impl TenantShard { for timeline in &compact { // Collect L0 counts. Can't await while holding lock above.
- if let Ok(lm) = timeline.layers.read().await.layer_map() { + if let Ok(lm) = timeline + .layers + .read(LayerManagerLockHolder::Compaction) + .await + .layer_map() + { l0_counts.insert(timeline.timeline_id, lm.level0_deltas().len()); } } @@ -3398,7 +3450,7 @@ impl TenantShard { use pageserver_api::models::ActivatingFrom; match &*current_state { TenantState::Activating(_) | TenantState::Active | TenantState::Broken { .. } | TenantState::Stopping { .. } => { - panic!("caller is responsible for calling activate() only on Loading / Attaching tenants, got {state:?}", state = current_state); + panic!("caller is responsible for calling activate() only on Loading / Attaching tenants, got {current_state:?}"); } TenantState::Attaching => { *current_state = TenantState::Activating(ActivatingFrom::Attaching); @@ -3417,9 +3469,6 @@ impl TenantShard { .values() .filter(|timeline| !(timeline.is_broken() || timeline.is_stopping())); - // Before activation, populate each Timeline's GcInfo with information about its children - self.initialize_gc_info(&timelines_accessor, &timelines_offloaded_accessor, None); - // Spawn gc and compaction loops. The loops will shut themselves // down when they notice that the tenant is inactive. tasks::start_background_loops(self, background_jobs_can_start); @@ -4331,7 +4380,7 @@ impl TenantShard { remote_storage: GenericRemoteStorage, deletion_queue_client: DeletionQueueClient, l0_flush_global_state: L0FlushGlobalState, - basebackup_prepare_sender: BasebackupPrepareSender, + basebackup_cache: Arc, feature_resolver: FeatureResolver, ) -> TenantShard { assert!(!attached_conf.location.generation.is_none()); @@ -4436,7 +4485,7 @@ impl TenantShard { ongoing_timeline_detach: std::sync::Mutex::default(), gc_block: Default::default(), l0_flush_global_state, - basebackup_prepare_sender, + basebackup_cache, feature_resolver, } } @@ -4874,7 +4923,7 @@ impl TenantShard { } let layer_names = tline .layers - .read() + .read(LayerManagerLockHolder::Testing) .await .layer_map() .unwrap() @@ -5042,7 +5091,7 @@ impl TenantShard { pub(crate) async fn bootstrap_timeline_test( self: &Arc, timeline_id: TimelineId, - pg_version: u32, + pg_version: PgMajorVersion, load_existing_initdb: Option, ctx: &RequestContext, ) -> anyhow::Result> { @@ -5184,7 +5233,7 @@ impl TenantShard { async fn bootstrap_timeline( self: &Arc, timeline_id: TimelineId, - pg_version: u32, + pg_version: PgMajorVersion, load_existing_initdb: Option, ctx: &RequestContext, ) -> Result { @@ -5365,7 +5414,7 @@ impl TenantShard { pagestream_throttle_metrics: self.pagestream_throttle_metrics.clone(), l0_compaction_trigger: self.l0_compaction_trigger.clone(), l0_flush_global_state: self.l0_flush_global_state.clone(), - basebackup_prepare_sender: self.basebackup_prepare_sender.clone(), + basebackup_cache: self.basebackup_cache.clone(), feature_resolver: self.feature_resolver.clone(), } } @@ -5722,7 +5771,7 @@ impl TenantShard { async fn run_initdb( conf: &'static PageServerConf, initdb_target_dir: &Utf8Path, - pg_version: u32, + pg_version: PgMajorVersion, cancel: &CancellationToken, ) -> Result<(), InitdbError> { let initdb_bin_path = conf @@ -5804,10 +5853,10 @@ pub(crate) mod harness { use once_cell::sync::OnceCell; use pageserver_api::key::Key; use pageserver_api::models::ShardParameters; - use pageserver_api::record::NeonWalRecord; use pageserver_api::shard::ShardIndex; use utils::id::TenantId; use utils::logging; + use wal_decoder::models::record::NeonWalRecord; use super::*; use crate::deletion_queue::mock::MockDeletionQueue; 
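The `OffloadedTimelineMetricsGuard` introduced above keeps `pageserver_timeline_states_count{state="offloaded"}` in sync with the set of live `OffloadedTimeline` objects by tying the gauge to the struct's lifetime. Here is a minimal sketch of the same Drop-guard pattern, with a plain `AtomicI64` standing in for the real prometheus gauge (all names below are illustrative, not the actual metrics types):

```rust
use std::sync::atomic::{AtomicI64, Ordering};

// Stand-in for the real TIMELINE_STATE_METRIC gauge labeled {state="offloaded"}.
static OFFLOADED_TIMELINES: AtomicI64 = AtomicI64::new(0);

/// RAII guard: increments the counter on creation and decrements it on drop,
/// so the metric cannot leak no matter how the owning struct goes away.
struct OffloadedGauge;

impl OffloadedGauge {
    fn new() -> Self {
        OFFLOADED_TIMELINES.fetch_add(1, Ordering::Relaxed);
        Self
    }
}

impl Drop for OffloadedGauge {
    fn drop(&mut self) {
        OFFLOADED_TIMELINES.fetch_sub(1, Ordering::Relaxed);
    }
}

struct OffloadedTimeline {
    // ...timeline fields elided...
    _metrics_guard: OffloadedGauge, // held only for its Drop side effect
}

fn main() {
    let tl = OffloadedTimeline { _metrics_guard: OffloadedGauge::new() };
    assert_eq!(OFFLOADED_TIMELINES.load(Ordering::Relaxed), 1);
    drop(tl);
    assert_eq!(OFFLOADED_TIMELINES.load(Ordering::Relaxed), 0);
}
```

Because the decrement lives in `Drop`, every exit path (early returns, error paths, unwinding) releases the count exactly once, which is why both `OffloadedTimeline` constructors in the diff only need to create the guard and never touch the metric again.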
@@ -5951,7 +6000,7 @@ pub(crate) mod harness { ) -> anyhow::Result> { let walredo_mgr = Arc::new(WalRedoManager::from(TestRedoManager)); - let (basebackup_requst_sender, _) = tokio::sync::mpsc::unbounded_channel(); + let (basebackup_cache, _) = BasebackupCache::new(Utf8PathBuf::new(), None); let tenant = Arc::new(TenantShard::new( TenantState::Attaching, @@ -5969,7 +6018,7 @@ pub(crate) mod harness { self.deletion_queue.new_client(), // TODO: ideally we should run all unit tests with both configs L0FlushGlobalState::new(L0FlushConfig::default()), - basebackup_requst_sender, + basebackup_cache, FeatureResolver::new_disabled(), )); @@ -6003,7 +6052,7 @@ pub(crate) mod harness { lsn: Lsn, base_img: Option<(Lsn, Bytes)>, records: Vec<(Lsn, NeonWalRecord)>, - _pg_version: u32, + _pg_version: PgMajorVersion, _redo_attempt_type: RedoAttemptType, ) -> Result { let records_neon = records.iter().all(|r| apply_neon::can_apply_in_neon(&r.1)); @@ -6062,9 +6111,6 @@ mod tests { #[cfg(feature = "testing")] use pageserver_api::keyspace::KeySpaceRandomAccum; use pageserver_api::models::{CompactionAlgorithm, CompactionAlgorithmSettings}; - #[cfg(feature = "testing")] - use pageserver_api::record::NeonWalRecord; - use pageserver_api::value::Value; use pageserver_compaction::helpers::overlaps_with; #[cfg(feature = "testing")] use rand::SeedableRng; @@ -6085,6 +6131,9 @@ mod tests { use timeline::{CompactOptions, DeltaLayerTestDesc, VersionedKeySpaceQuery}; use utils::id::TenantId; use utils::shard::{ShardCount, ShardNumber}; + #[cfg(feature = "testing")] + use wal_decoder::models::record::NeonWalRecord; + use wal_decoder::models::value::Value; use super::*; use crate::DEFAULT_PG_VERSION; @@ -6175,7 +6224,7 @@ mod tests { async fn randomize_timeline( tenant: &Arc, new_timeline_id: TimelineId, - pg_version: u32, + pg_version: PgMajorVersion, spec: TestTimelineSpecification, random: &mut rand::rngs::StdRng, ctx: &RequestContext, @@ -6568,7 +6617,7 @@ mod tests { .put( *TEST_KEY, lsn, - &Value::Image(test_img(&format!("foo at {}", lsn))), + &Value::Image(test_img(&format!("foo at {lsn}"))), ctx, ) .await?; @@ -6578,7 +6627,7 @@ mod tests { .put( *TEST_KEY, lsn, - &Value::Image(test_img(&format!("foo at {}", lsn))), + &Value::Image(test_img(&format!("foo at {lsn}"))), ctx, ) .await?; @@ -6592,7 +6641,7 @@ mod tests { .put( *TEST_KEY, lsn, - &Value::Image(test_img(&format!("foo at {}", lsn))), + &Value::Image(test_img(&format!("foo at {lsn}"))), ctx, ) .await?; @@ -6602,7 +6651,7 @@ mod tests { .put( *TEST_KEY, lsn, - &Value::Image(test_img(&format!("foo at {}", lsn))), + &Value::Image(test_img(&format!("foo at {lsn}"))), ctx, ) .await?; @@ -6944,7 +6993,7 @@ mod tests { .await?; make_some_layers(tline.as_ref(), Lsn(0x20), &ctx).await?; - let layer_map = tline.layers.read().await; + let layer_map = tline.layers.read(LayerManagerLockHolder::Testing).await; let level0_deltas = layer_map .layer_map()? 
.level0_deltas() @@ -7101,7 +7150,7 @@ mod tests { .put( test_key, lsn, - &Value::Image(test_img(&format!("{} at {}", blknum, lsn))), + &Value::Image(test_img(&format!("{blknum} at {lsn}"))), ctx, ) .await?; @@ -7180,7 +7229,7 @@ mod tests { let lsn = Lsn(0x10); let inserted = bulk_insert_compact_gc(&tenant, &tline, &ctx, lsn, 50, 10000).await?; - let guard = tline.layers.read().await; + let guard = tline.layers.read(LayerManagerLockHolder::Testing).await; let lm = guard.layer_map()?; lm.dump(true, &ctx).await?; @@ -7389,7 +7438,7 @@ mod tests { .put( gap_at_key, current_lsn, - &Value::Image(test_img(&format!("{} at {}", gap_at_key, current_lsn))), + &Value::Image(test_img(&format!("{gap_at_key} at {current_lsn}"))), &ctx, ) .await?; @@ -7428,7 +7477,7 @@ mod tests { .put( current_key, current_lsn, - &Value::Image(test_img(&format!("{} at {}", current_key, current_lsn))), + &Value::Image(test_img(&format!("{current_key} at {current_lsn}"))), &ctx, ) .await?; @@ -7536,7 +7585,7 @@ mod tests { while key < end_key { current_lsn += 0x10; - let image_value = format!("{} at {}", child_gap_at_key, current_lsn); + let image_value = format!("{child_gap_at_key} at {current_lsn}"); let mut writer = parent_timeline.writer().await; writer @@ -7579,7 +7628,7 @@ mod tests { .put( key, current_lsn, - &Value::Image(test_img(&format!("{} at {}", key, current_lsn))), + &Value::Image(test_img(&format!("{key} at {current_lsn}"))), &ctx, ) .await?; @@ -7700,7 +7749,7 @@ mod tests { .put( test_key, lsn, - &Value::Image(test_img(&format!("{} at {}", blknum, lsn))), + &Value::Image(test_img(&format!("{blknum} at {lsn}"))), &ctx, ) .await?; @@ -7721,7 +7770,7 @@ mod tests { .put( test_key, lsn, - &Value::Image(test_img(&format!("{} at {}", blknum, lsn))), + &Value::Image(test_img(&format!("{blknum} at {lsn}"))), &ctx, ) .await?; @@ -7735,7 +7784,7 @@ mod tests { test_key.field6 = blknum as u32; assert_eq!( tline.get(test_key, lsn, &ctx).await?, - test_img(&format!("{} at {}", blknum, last_lsn)) + test_img(&format!("{blknum} at {last_lsn}")) ); } @@ -7781,7 +7830,7 @@ mod tests { .put( test_key, lsn, - &Value::Image(test_img(&format!("{} at {}", blknum, lsn))), + &Value::Image(test_img(&format!("{blknum} at {lsn}"))), &ctx, ) .await?; @@ -7810,11 +7859,11 @@ mod tests { .put( test_key, lsn, - &Value::Image(test_img(&format!("{} at {}", blknum, lsn))), + &Value::Image(test_img(&format!("{blknum} at {lsn}"))), &ctx, ) .await?; - println!("updating {} at {}", blknum, lsn); + println!("updating {blknum} at {lsn}"); writer.finish_write(lsn); drop(writer); updated[blknum] = lsn; @@ -7825,7 +7874,7 @@ mod tests { test_key.field6 = blknum as u32; assert_eq!( tline.get(test_key, lsn, &ctx).await?, - test_img(&format!("{} at {}", blknum, last_lsn)) + test_img(&format!("{blknum} at {last_lsn}")) ); } @@ -7878,11 +7927,11 @@ mod tests { .put( test_key, lsn, - &Value::Image(test_img(&format!("{} {} at {}", idx, blknum, lsn))), + &Value::Image(test_img(&format!("{idx} {blknum} at {lsn}"))), &ctx, ) .await?; - println!("updating [{}][{}] at {}", idx, blknum, lsn); + println!("updating [{idx}][{blknum}] at {lsn}"); writer.finish_write(lsn); drop(writer); updated[idx][blknum] = lsn; @@ -8088,7 +8137,7 @@ mod tests { .put( test_key, lsn, - &Value::Image(test_img(&format!("{} at {}", blknum, lsn))), + &Value::Image(test_img(&format!("{blknum} at {lsn}"))), &ctx, ) .await?; @@ -8105,7 +8154,7 @@ mod tests { test_key.field6 = (blknum * STEP) as u32; assert_eq!( tline.get(test_key, lsn, &ctx).await?, - test_img(&format!("{} at {}", 
blknum, last_lsn)) + test_img(&format!("{blknum} at {last_lsn}")) ); } @@ -8142,7 +8191,7 @@ mod tests { .put( test_key, lsn, - &Value::Image(test_img(&format!("{} at {}", blknum, lsn))), + &Value::Image(test_img(&format!("{blknum} at {lsn}"))), &ctx, ) .await?; @@ -8208,12 +8257,23 @@ mod tests { tline.freeze_and_flush().await?; // force create a delta layer } - let before_num_l0_delta_files = - tline.layers.read().await.layer_map()?.level0_deltas().len(); + let before_num_l0_delta_files = tline + .layers + .read(LayerManagerLockHolder::Testing) + .await + .layer_map()? + .level0_deltas() + .len(); tline.compact(&cancel, EnumSet::default(), &ctx).await?; - let after_num_l0_delta_files = tline.layers.read().await.layer_map()?.level0_deltas().len(); + let after_num_l0_delta_files = tline + .layers + .read(LayerManagerLockHolder::Testing) + .await + .layer_map()? + .level0_deltas() + .len(); assert!( after_num_l0_delta_files < before_num_l0_delta_files, @@ -8384,7 +8444,7 @@ mod tests { .put( test_key, lsn, - &Value::Image(test_img(&format!("{} at {}", blknum, lsn))), + &Value::Image(test_img(&format!("{blknum} at {lsn}"))), &ctx, ) .await?; @@ -8404,7 +8464,7 @@ mod tests { .put( test_key, lsn, - &Value::Image(test_img(&format!("{} at {}", blknum, lsn))), + &Value::Image(test_img(&format!("{blknum} at {lsn}"))), &ctx, ) .await?; @@ -9325,12 +9385,7 @@ mod tests { let end_lsn = Lsn(0x100); let image_layers = (0x20..=0x90) .step_by(0x10) - .map(|n| { - ( - Lsn(n), - vec![(key, test_img(&format!("data key at {:x}", n)))], - ) - }) + .map(|n| (Lsn(n), vec![(key, test_img(&format!("data key at {n:x}")))])) .collect(); let timeline = tenant diff --git a/pageserver/src/tenant/checks.rs b/pageserver/src/tenant/checks.rs index d5b979ab2a..83d54f09de 100644 --- a/pageserver/src/tenant/checks.rs +++ b/pageserver/src/tenant/checks.rs @@ -63,8 +63,7 @@ pub fn check_valid_layermap(metadata: &[LayerName]) -> Option { && overlaps_with(&layer.key_range, &other_layer.key_range) { let err = format!( - "layer violates the layer map LSN split assumption: layer {} intersects with layer {}", - layer, other_layer + "layer violates the layer map LSN split assumption: layer {layer} intersects with layer {other_layer}" ); return Some(err); } diff --git a/pageserver/src/tenant/config.rs b/pageserver/src/tenant/config.rs index bf82fc8df8..c5087f7e0f 100644 --- a/pageserver/src/tenant/config.rs +++ b/pageserver/src/tenant/config.rs @@ -61,8 +61,10 @@ pub(crate) struct LocationConf { /// The detailed shard identity. This structure is already scoped within /// a TenantShardId, but we need the full ShardIdentity to enable calculating /// key->shard mappings. - #[serde(default = "ShardIdentity::unsharded")] - #[serde(skip_serializing_if = "ShardIdentity::is_unsharded")] + /// + /// NB: we store this even for unsharded tenants, so that we agree with storcon on the intended + /// stripe size. Otherwise, a split request that does not specify a stripe size may use a + /// different default than storcon, which can lead to incorrect stripe sizes and corruption. pub(crate) shard: ShardIdentity, /// The pan-cluster tenant configuration, the same on all locations @@ -149,7 +151,12 @@ impl LocationConf { /// For use when attaching/re-attaching: update the generation stored in this /// structure. If we were in a secondary state, promote to attached (posession /// of a fresh generation implies this). 
- pub(crate) fn attach_in_generation(&mut self, mode: AttachmentMode, generation: Generation) { + pub(crate) fn attach_in_generation( + &mut self, + mode: AttachmentMode, + generation: Generation, + stripe_size: ShardStripeSize, + ) { match &mut self.mode { LocationMode::Attached(attach_conf) => { attach_conf.generation = generation; @@ -163,6 +170,8 @@ impl LocationConf { }) } } + + self.shard.stripe_size = stripe_size; } pub(crate) fn try_from(conf: &'_ models::LocationConfig) -> anyhow::Result { diff --git a/pageserver/src/tenant/ephemeral_file.rs b/pageserver/src/tenant/ephemeral_file.rs index 2edf22e9fd..203b5bf592 100644 --- a/pageserver/src/tenant/ephemeral_file.rs +++ b/pageserver/src/tenant/ephemeral_file.rs @@ -3,7 +3,7 @@ use std::io; use std::sync::Arc; -use std::sync::atomic::AtomicU64; +use std::sync::atomic::{AtomicU64, Ordering}; use camino::Utf8PathBuf; use num_traits::Num; @@ -18,6 +18,7 @@ use crate::assert_u64_eq_usize::{U64IsUsize, UsizeIsU64}; use crate::config::PageServerConf; use crate::context::RequestContext; use crate::page_cache; +use crate::tenant::storage_layer::inmemory_layer::GlobalResourceUnits; use crate::tenant::storage_layer::inmemory_layer::vectored_dio_read::File; use crate::virtual_file::owned_buffers_io::io_buf_aligned::IoBufAlignedMut; use crate::virtual_file::owned_buffers_io::slice::SliceMutExt; @@ -30,9 +31,13 @@ pub struct EphemeralFile { _tenant_shard_id: TenantShardId, _timeline_id: TimelineId, page_cache_file_id: page_cache::FileId, - bytes_written: u64, file: TempVirtualFileCoOwnedByEphemeralFileAndBufferedWriter, - buffered_writer: BufferedWriter, + + buffered_writer: tokio::sync::RwLock, + + bytes_written: AtomicU64, + + resource_units: std::sync::Mutex, } type BufferedWriter = owned_buffers_io::write::BufferedWriter< @@ -94,9 +99,8 @@ impl EphemeralFile { _tenant_shard_id: tenant_shard_id, _timeline_id: timeline_id, page_cache_file_id, - bytes_written: 0, file: file.clone(), - buffered_writer: BufferedWriter::new( + buffered_writer: tokio::sync::RwLock::new(BufferedWriter::new( file, 0, || IoBufferMut::with_capacity(TAIL_SZ), @@ -104,7 +108,9 @@ impl EphemeralFile { cancel.child_token(), ctx, info_span!(parent: None, "ephemeral_file_buffered_writer", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), timeline_id=%timeline_id, path = %filename), - ), + )), + bytes_written: AtomicU64::new(0), + resource_units: std::sync::Mutex::new(GlobalResourceUnits::new()), }) } } @@ -151,15 +157,17 @@ impl std::ops::Deref for TempVirtualFileCoOwnedByEphemeralFileAndBufferedWriter #[derive(Debug, thiserror::Error)] pub(crate) enum EphemeralFileWriteError { - #[error("{0}")] - TooLong(String), #[error("cancelled")] Cancelled, } impl EphemeralFile { pub(crate) fn len(&self) -> u64 { - self.bytes_written + // TODO(vlad): The value returned here is not always correct if + // we have more than one concurrent writer. Writes are always + // sequenced, but we could grab the buffered writer lock if we wanted + // to. + self.bytes_written.load(Ordering::Acquire) } pub(crate) fn page_cache_file_id(&self) -> page_cache::FileId { @@ -186,7 +194,7 @@ impl EphemeralFile { /// Panics if the write is short because there's no way we can recover from that. /// TODO: make upstack handle this as an error. 
pub(crate) async fn write_raw( - &mut self, + &self, srcbuf: &[u8], ctx: &RequestContext, ) -> Result<u64, EphemeralFileWriteError> { @@ -198,22 +206,13 @@ impl EphemeralFile { async fn write_raw_controlled( - &mut self, + &self, srcbuf: &[u8], ctx: &RequestContext, ) -> Result<(u64, Option), EphemeralFileWriteError> { - let pos = self.bytes_written; + let mut writer = self.buffered_writer.write().await; - let new_bytes_written = pos.checked_add(srcbuf.len().into_u64()).ok_or_else(|| { - EphemeralFileWriteError::TooLong(format!( - "write would grow EphemeralFile beyond u64::MAX: len={pos} writen={srcbuf_len}", - srcbuf_len = srcbuf.len(), - )) - })?; - - // Write the payload - let (nwritten, control) = self - .buffered_writer + let (nwritten, control) = writer .write_buffered_borrowed_controlled(srcbuf, ctx) .await .map_err(|e| match e { @@ -225,43 +224,69 @@ impl EphemeralFile { "buffered writer has no short writes" ); - self.bytes_written = new_bytes_written; + // There's no realistic risk of overflow here. We won't have exabyte-sized files on disk. + let pos = self + .bytes_written + .fetch_add(srcbuf.len().into_u64(), Ordering::AcqRel); + + let mut resource_units = self.resource_units.lock().unwrap(); + resource_units.maybe_publish_size(self.bytes_written.load(Ordering::Relaxed)); Ok((pos, control)) } + + pub(crate) fn tick(&self) -> Option { + let mut resource_units = self.resource_units.lock().unwrap(); + let len = self.bytes_written.load(Ordering::Relaxed); + resource_units.publish_size(len) + } } impl super::storage_layer::inmemory_layer::vectored_dio_read::File for EphemeralFile { async fn read_exact_at_eof_ok( &self, start: u64, - dst: tokio_epoll_uring::Slice, + mut dst: tokio_epoll_uring::Slice, ctx: &RequestContext, ) -> std::io::Result<(tokio_epoll_uring::Slice, usize)> { - let submitted_offset = self.buffered_writer.bytes_submitted(); + // We will fill the slice in back to front. Hence, we need + // the slice to be fully initialized. + // TODO(vlad): Is there a nicer way of doing this? + dst.as_mut_rust_slice_full_zeroed(); - let mutable = match self.buffered_writer.inspect_mutable() { - Some(mutable) => &mutable[0..mutable.pending()], - None => { - // Timeline::cancel and hence buffered writer flush was cancelled. - // Remain read-available while timeline is shutting down. - &[] - } - }; + let writer = self.buffered_writer.read().await; - let maybe_flushed = self.buffered_writer.inspect_maybe_flushed(); + // Read bytes written while under lock. This is a hack to deal with concurrent + // writes updating the number of bytes written. `bytes_written` is not DIO-aligned + // but we may end the read there. + // + // TODO(vlad): Feels like there's a nicer path where we align the end if it + // shoots over the end of the file. + let bytes_written = self.bytes_written.load(Ordering::Acquire); let dst_cap = dst.bytes_total().into_u64(); let end = { // saturating_add is correct here because the max file size is u64::MAX, so, // if start + dst.len() > u64::MAX, then we know it will be a short read let mut end: u64 = start.saturating_add(dst_cap); - if end > self.bytes_written { - end = self.bytes_written; + if end > bytes_written { + end = bytes_written; } end }; + let submitted_offset = writer.bytes_submitted(); + let maybe_flushed = writer.inspect_maybe_flushed(); + + let mutable = match writer.inspect_mutable() { + Some(mutable) => &mutable[0..mutable.pending()], + None => { + // Timeline::cancel and hence buffered writer flush was cancelled. + // Remain read-available while timeline is shutting down. 
+ &[] + } + }; // inclusive, exclusive #[derive(Debug)] struct Range(N, N); @@ -306,13 +331,33 @@ impl super::storage_layer::inmemory_layer::vectored_dio_read::File for Ephemeral let mutable_range = Range(std::cmp::max(start, submitted_offset), end); - let dst = if written_range.len() > 0 { + // There are three sources from which we might have to read data: + // 1. The file itself + // 2. The buffer which contains changes currently being flushed + // 3. The buffer which contains changes yet to be flushed + // + // For better concurrency, we do them in reverse order: perform the in-memory + // reads while holding the writer lock, drop the writer lock and read from the + // file if required. + + let dst = if mutable_range.len() > 0 { + let offset_in_buffer = mutable_range + .0 + .checked_sub(submitted_offset) + .unwrap() + .into_usize(); + let to_copy = + &mutable[offset_in_buffer..(offset_in_buffer + mutable_range.len().into_usize())]; let bounds = dst.bounds(); - let slice = self - .file - .read_exact_at(dst.slice(0..written_range.len().into_usize()), start, ctx) - .await?; - Slice::from_buf_bounds(Slice::into_inner(slice), bounds) + let mut view = dst.slice({ + let start = + written_range.len().into_usize() + maybe_flushed_range.len().into_usize(); + let end = start.checked_add(mutable_range.len().into_usize()).unwrap(); + start..end + }); + view.as_mut_rust_slice_full_zeroed() + .copy_from_slice(to_copy); + Slice::from_buf_bounds(Slice::into_inner(view), bounds) } else { dst }; @@ -342,24 +387,15 @@ impl super::storage_layer::inmemory_layer::vectored_dio_read::File for Ephemeral dst }; - let dst = if mutable_range.len() > 0 { - let offset_in_buffer = mutable_range - .0 - .checked_sub(submitted_offset) - .unwrap() - .into_usize(); - let to_copy = - &mutable[offset_in_buffer..(offset_in_buffer + mutable_range.len().into_usize())]; + drop(writer); + + let dst = if written_range.len() > 0 { let bounds = dst.bounds(); - let mut view = dst.slice({ - let start = - written_range.len().into_usize() + maybe_flushed_range.len().into_usize(); - let end = start.checked_add(mutable_range.len().into_usize()).unwrap(); - start..end - }); - view.as_mut_rust_slice_full_zeroed() - .copy_from_slice(to_copy); - Slice::from_buf_bounds(Slice::into_inner(view), bounds) + let slice = self + .file + .read_exact_at(dst.slice(0..written_range.len().into_usize()), start, ctx) + .await?; + Slice::from_buf_bounds(Slice::into_inner(slice), bounds) } else { dst }; @@ -460,13 +496,15 @@ mod tests { let gate = utils::sync::gate::Gate::default(); let cancel = CancellationToken::new(); - let mut file = EphemeralFile::create(conf, tenant_id, timeline_id, &gate, &cancel, &ctx) + let file = EphemeralFile::create(conf, tenant_id, timeline_id, &gate, &cancel, &ctx) .await .unwrap(); - let mutable = file.buffered_writer.mutable(); + let writer = file.buffered_writer.read().await; + let mutable = writer.mutable(); let cap = mutable.capacity(); let align = mutable.align(); + drop(writer); let write_nbytes = cap * 2 + cap / 2; @@ -504,10 +542,11 @@ mod tests { let file_contents = std::fs::read(file.file.path()).unwrap(); assert!(file_contents == content[0..cap * 2]); - let maybe_flushed_buffer_contents = file.buffered_writer.inspect_maybe_flushed().unwrap(); + let writer = file.buffered_writer.read().await; + let maybe_flushed_buffer_contents = writer.inspect_maybe_flushed().unwrap(); assert_eq!(&maybe_flushed_buffer_contents[..], &content[cap..cap * 2]); - let mutable_buffer_contents = file.buffered_writer.mutable(); + let 
mutable_buffer_contents = writer.mutable(); assert_eq!(mutable_buffer_contents, &content[cap * 2..write_nbytes]); } @@ -517,12 +556,14 @@ mod tests { let gate = utils::sync::gate::Gate::default(); let cancel = CancellationToken::new(); - let mut file = EphemeralFile::create(conf, tenant_id, timeline_id, &gate, &cancel, &ctx) + let file = EphemeralFile::create(conf, tenant_id, timeline_id, &gate, &cancel, &ctx) .await .unwrap(); // mutable buffer and maybe_flushed buffer each has `cap` bytes. - let cap = file.buffered_writer.mutable().capacity(); + let writer = file.buffered_writer.read().await; + let cap = writer.mutable().capacity(); + drop(writer); let content: Vec = rand::thread_rng() .sample_iter(rand::distributions::Standard) @@ -540,12 +581,13 @@ mod tests { 2 * cap.into_u64(), "buffered writer requires one write to be flushed if we write 2.5x buffer capacity" ); + let writer = file.buffered_writer.read().await; assert_eq!( - &file.buffered_writer.inspect_maybe_flushed().unwrap()[0..cap], + &writer.inspect_maybe_flushed().unwrap()[0..cap], &content[cap..cap * 2] ); assert_eq!( - &file.buffered_writer.mutable()[0..cap / 2], + &writer.mutable()[0..cap / 2], &content[cap * 2..cap * 2 + cap / 2] ); } @@ -563,13 +605,15 @@ mod tests { let gate = utils::sync::gate::Gate::default(); let cancel = CancellationToken::new(); - let mut file = EphemeralFile::create(conf, tenant_id, timeline_id, &gate, &cancel, &ctx) + let file = EphemeralFile::create(conf, tenant_id, timeline_id, &gate, &cancel, &ctx) .await .unwrap(); - let mutable = file.buffered_writer.mutable(); + let writer = file.buffered_writer.read().await; + let mutable = writer.mutable(); let cap = mutable.capacity(); let align = mutable.align(); + drop(writer); let content: Vec = rand::thread_rng() .sample_iter(rand::distributions::Standard) .take(cap * 2 + cap / 2) diff --git a/pageserver/src/tenant/metadata.rs b/pageserver/src/tenant/metadata.rs index bea3128265..2f407de951 100644 --- a/pageserver/src/tenant/metadata.rs +++ b/pageserver/src/tenant/metadata.rs @@ -18,6 +18,7 @@ //! 
[`IndexPart`]: super::remote_timeline_client::index::IndexPart use anyhow::ensure; +use postgres_ffi::PgMajorVersion; use serde::{Deserialize, Serialize}; use utils::bin_ser::{BeSer, SerializeError}; use utils::id::TimelineId; @@ -136,7 +137,7 @@ struct TimelineMetadataBodyV2 { latest_gc_cutoff_lsn: Lsn, initdb_lsn: Lsn, - pg_version: u32, + pg_version: PgMajorVersion, } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -167,7 +168,7 @@ impl TimelineMetadata { ancestor_lsn: Lsn, latest_gc_cutoff_lsn: Lsn, initdb_lsn: Lsn, - pg_version: u32, + pg_version: PgMajorVersion, ) -> Self { Self { hdr: TimelineMetadataHeader { @@ -215,7 +216,7 @@ impl TimelineMetadata { ancestor_lsn: body.ancestor_lsn, latest_gc_cutoff_lsn: body.latest_gc_cutoff_lsn, initdb_lsn: body.initdb_lsn, - pg_version: 14, // All timelines created before this version had pg_version 14 + pg_version: PgMajorVersion::PG14, // All timelines created before this version had pg_version 14 }; hdr.format_version = METADATA_FORMAT_VERSION; @@ -317,7 +318,7 @@ impl TimelineMetadata { self.body.initdb_lsn } - pub fn pg_version(&self) -> u32 { + pub fn pg_version(&self) -> PgMajorVersion { self.body.pg_version } @@ -331,7 +332,7 @@ impl TimelineMetadata { Lsn::from_hex("00000000").unwrap(), Lsn::from_hex("00000000").unwrap(), Lsn::from_hex("00000000").unwrap(), - 0, + PgMajorVersion::PG14, ); let bytes = instance.to_bytes().unwrap(); Self::from_bytes(&bytes).unwrap() @@ -545,13 +546,12 @@ mod tests { Lsn(0), Lsn(0), Lsn(0), - 14, // All timelines created before this version had pg_version 14 + PgMajorVersion::PG14, // All timelines created before this version had pg_version 14 ); assert_eq!( deserialized_metadata.body, expected_metadata.body, - "Metadata of the old version {} should be upgraded to the latest version {}", - METADATA_OLD_FORMAT_VERSION, METADATA_FORMAT_VERSION + "Metadata of the old version {METADATA_OLD_FORMAT_VERSION} should be upgraded to the latest version {METADATA_FORMAT_VERSION}" ); } @@ -566,7 +566,7 @@ mod tests { Lsn(0), // Updating this version to 17 will cause the test to fail at the // next assert_eq!(). - 16, + PgMajorVersion::PG16, ); let expected_bytes = vec![ /* TimelineMetadataHeader */ diff --git a/pageserver/src/tenant/mgr.rs b/pageserver/src/tenant/mgr.rs index 86aef9b42c..0a494e7923 100644 --- a/pageserver/src/tenant/mgr.rs +++ b/pageserver/src/tenant/mgr.rs @@ -12,7 +12,6 @@ use anyhow::Context; use camino::{Utf8DirEntry, Utf8Path, Utf8PathBuf}; use futures::StreamExt; use itertools::Itertools; -use once_cell::sync::Lazy; use pageserver_api::key::Key; use pageserver_api::models::{DetachBehavior, LocationConfigMode}; use pageserver_api::shard::{ @@ -52,6 +51,7 @@ use crate::tenant::config::{ use crate::tenant::span::debug_assert_current_span_has_tenant_id; use crate::tenant::storage_layer::inmemory_layer; use crate::tenant::timeline::ShutdownMode; +use crate::tenant::timeline::layer_manager::LayerManagerLockHolder; use crate::tenant::{ AttachedTenantConf, GcError, LoadConfigError, SpawnMode, TenantShard, TenantState, }; @@ -103,7 +103,7 @@ pub(crate) enum TenantsMap { /// [`init_tenant_mgr`] is not done yet. Initializing, /// [`init_tenant_mgr`] is done, all on-disk tenants have been loaded. - /// New tenants can be added using [`tenant_map_acquire_slot`]. + /// New tenants can be added using [`TenantManager::tenant_map_acquire_slot`]. Open(BTreeMap), /// The pageserver has entered shutdown mode via [`TenantManager::shutdown`]. 
/// Existing tenants are still accessible, but no new tenants can be created. @@ -129,7 +129,7 @@ pub(crate) enum ShardSelector { /// /// This represents the subset of a LocationConfig that we receive during re-attach. pub(crate) enum TenantStartupMode { - Attached((AttachmentMode, Generation)), + Attached((AttachmentMode, Generation, ShardStripeSize)), Secondary, } @@ -143,15 +149,21 @@ impl TenantStartupMode { match (rart.mode, rart.r#gen) { (LocationConfigMode::Detached, _) => None, (LocationConfigMode::Secondary, _) => Some(Self::Secondary), - (LocationConfigMode::AttachedMulti, Some(g)) => { - Some(Self::Attached((AttachmentMode::Multi, Generation::new(g)))) - } - (LocationConfigMode::AttachedSingle, Some(g)) => { - Some(Self::Attached((AttachmentMode::Single, Generation::new(g)))) - } - (LocationConfigMode::AttachedStale, Some(g)) => { - Some(Self::Attached((AttachmentMode::Stale, Generation::new(g)))) - } + (LocationConfigMode::AttachedMulti, Some(g)) => Some(Self::Attached(( + AttachmentMode::Multi, + Generation::new(g), + rart.stripe_size, + ))), + (LocationConfigMode::AttachedSingle, Some(g)) => Some(Self::Attached(( + AttachmentMode::Single, + Generation::new(g), + rart.stripe_size, + ))), + (LocationConfigMode::AttachedStale, Some(g)) => Some(Self::Attached(( + AttachmentMode::Stale, + Generation::new(g), + rart.stripe_size, + ))), _ => { tracing::warn!( "Received invalid re-attach state for tenant {}: {rart:?}", @@ -284,9 +290,6 @@ impl BackgroundPurges { } } -static TENANTS: Lazy<std::sync::RwLock<TenantsMap>> = - Lazy::new(|| std::sync::RwLock::new(TenantsMap::Initializing)); - /// Responsible for storing and mutating the collection of all tenants /// that this pageserver has state for. /// @@ -297,10 +300,7 @@ static TENANTS: Lazy<std::sync::RwLock<TenantsMap>> = /// and attached modes concurrently. pub struct TenantManager { conf: &'static PageServerConf, - // TODO: currently this is a &'static pointing to TENANTs. When we finish refactoring - // out of that static variable, the TenantManager can own this. - // See https://github.com/neondatabase/neon/issues/5796 - tenants: &'static std::sync::RwLock<TenantsMap>, + tenants: std::sync::RwLock<TenantsMap>, resources: TenantSharedResources, // Long-running operations that happen outside of a [`Tenant`] lifetime should respect this token. @@ -325,9 +325,11 @@ fn emergency_generations( Some(( *tid, match &lc.mode { - LocationMode::Attached(alc) => { - TenantStartupMode::Attached((alc.attach_mode, alc.generation)) - } + LocationMode::Attached(alc) => TenantStartupMode::Attached(( + alc.attach_mode, + alc.generation, + ShardStripeSize::default(), + )), LocationMode::Secondary(_) => TenantStartupMode::Secondary, }, )) @@ -371,7 +373,7 @@ async fn init_load_generations( .iter() .flat_map(|(id, start_mode)| { match start_mode { - TenantStartupMode::Attached((_mode, generation)) => Some(generation), + TenantStartupMode::Attached((_mode, generation, _stripe_size)) => Some(generation), TenantStartupMode::Secondary => None, } .map(|gen_| (*id, *gen_)) @@ -479,21 +481,43 @@ pub(crate) enum DeleteTenantError { Other(#[from] anyhow::Error), } -/// Initialize repositories with locally available timelines. +/// Initialize repositories in the `Initializing` state. 
+pub fn init( + conf: &'static PageServerConf, + background_purges: BackgroundPurges, + resources: TenantSharedResources, + cancel: CancellationToken, +) -> TenantManager { + TenantManager { + conf, + tenants: std::sync::RwLock::new(TenantsMap::Initializing), + resources, + cancel, + background_purges, + } +} + +/// Transition repositories from `Initializing` state to `Open` state with locally available timelines. /// Timelines that are only partially available locally (remote storage has more data than this pageserver) /// are scheduled for download and added to the tenant once download is completed. #[instrument(skip_all)] pub async fn init_tenant_mgr( - conf: &'static PageServerConf, - background_purges: BackgroundPurges, - resources: TenantSharedResources, + tenant_manager: Arc, init_order: InitializationOrder, - cancel: CancellationToken, -) -> anyhow::Result { +) -> anyhow::Result<()> { + debug_assert!(matches!( + *tenant_manager.tenants.read().unwrap(), + TenantsMap::Initializing + )); let mut tenants = BTreeMap::new(); let ctx = RequestContext::todo_child(TaskKind::Startup, DownloadBehavior::Warn); + let conf = tenant_manager.conf; + let resources = &tenant_manager.resources; + let cancel = &tenant_manager.cancel; + let background_purges = &tenant_manager.background_purges; + // Initialize dynamic limits that depend on system resources let system_memory = sysinfo::System::new_with_specifics(sysinfo::RefreshKind::new().with_memory()) @@ -512,7 +536,7 @@ pub async fn init_tenant_mgr( let tenant_configs = init_load_tenant_configs(conf).await; // Determine which tenants are to be secondary or attached, and in which generation - let tenant_modes = init_load_generations(conf, &tenant_configs, &resources, &cancel).await?; + let tenant_modes = init_load_generations(conf, &tenant_configs, resources, cancel).await?; tracing::info!( "Attaching {} tenants at startup, warming up {} at a time", @@ -569,7 +593,7 @@ pub async fn init_tenant_mgr( location_conf.mode = LocationMode::Secondary(DEFAULT_SECONDARY_CONF); } } - Some(TenantStartupMode::Attached((attach_mode, generation))) => { + Some(TenantStartupMode::Attached((attach_mode, generation, stripe_size))) => { let old_gen_higher = match &location_conf.mode { LocationMode::Attached(AttachedLocationConfig { generation: old_generation, @@ -593,7 +617,7 @@ pub async fn init_tenant_mgr( // local disk content: demote to secondary rather than detaching. location_conf.mode = LocationMode::Secondary(DEFAULT_SECONDARY_CONF); } else { - location_conf.attach_in_generation(*attach_mode, *generation); + location_conf.attach_in_generation(*attach_mode, *generation, *stripe_size); } } } @@ -669,18 +693,10 @@ pub async fn init_tenant_mgr( info!("Processed {} local tenants at startup", tenants.len()); - let mut tenants_map = TENANTS.write().unwrap(); - assert!(matches!(&*tenants_map, &TenantsMap::Initializing)); + let mut tenant_map = tenant_manager.tenants.write().unwrap(); + *tenant_map = TenantsMap::Open(tenants); - *tenants_map = TenantsMap::Open(tenants); - - Ok(TenantManager { - conf, - tenants: &TENANTS, - resources, - cancel: CancellationToken::new(), - background_purges, - }) + Ok(()) } /// Wrapper for Tenant::spawn that checks invariants before running @@ -719,142 +735,6 @@ fn tenant_spawn( ) } -async fn shutdown_all_tenants0(tenants: &std::sync::RwLock) { - let mut join_set = JoinSet::new(); - - #[cfg(all(debug_assertions, not(test)))] - { - // Check that our metrics properly tracked the size of the tenants map. 
This is a convenient location to check, - // as it happens implicitly at the end of tests etc. - let m = tenants.read().unwrap(); - debug_assert_eq!(METRICS.slots_total(), m.len() as u64); - } - - // Atomically, 1. create the shutdown tasks and 2. prevent creation of new tenants. - let (total_in_progress, total_attached) = { - let mut m = tenants.write().unwrap(); - match &mut *m { - TenantsMap::Initializing => { - *m = TenantsMap::ShuttingDown(BTreeMap::default()); - info!("tenants map is empty"); - return; - } - TenantsMap::Open(tenants) => { - let mut shutdown_state = BTreeMap::new(); - let mut total_in_progress = 0; - let mut total_attached = 0; - - for (tenant_shard_id, v) in std::mem::take(tenants).into_iter() { - match v { - TenantSlot::Attached(t) => { - shutdown_state.insert(tenant_shard_id, TenantSlot::Attached(t.clone())); - join_set.spawn( - async move { - let res = { - let (_guard, shutdown_progress) = completion::channel(); - t.shutdown(shutdown_progress, ShutdownMode::FreezeAndFlush).await - }; - - if let Err(other_progress) = res { - // join the another shutdown in progress - other_progress.wait().await; - } - - // we cannot afford per tenant logging here, because if s3 is degraded, we are - // going to log too many lines - debug!("tenant successfully stopped"); - } - .instrument(info_span!("shutdown", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug())), - ); - - total_attached += 1; - } - TenantSlot::Secondary(state) => { - // We don't need to wait for this individually per-tenant: the - // downloader task will be waited on eventually, this cancel - // is just to encourage it to drop out if it is doing work - // for this tenant right now. - state.cancel.cancel(); - - shutdown_state.insert(tenant_shard_id, TenantSlot::Secondary(state)); - } - TenantSlot::InProgress(notify) => { - // InProgress tenants are not visible in TenantsMap::ShuttingDown: we will - // wait for their notifications to fire in this function. - join_set.spawn(async move { - notify.wait().await; - }); - - total_in_progress += 1; - } - } - } - *m = TenantsMap::ShuttingDown(shutdown_state); - (total_in_progress, total_attached) - } - TenantsMap::ShuttingDown(_) => { - error!( - "already shutting down, this function isn't supposed to be called more than once" - ); - return; - } - } - }; - - let started_at = std::time::Instant::now(); - - info!( - "Waiting for {} InProgress tenants and {} Attached tenants to shut down", - total_in_progress, total_attached - ); - - let total = join_set.len(); - let mut panicked = 0; - let mut buffering = true; - const BUFFER_FOR: std::time::Duration = std::time::Duration::from_millis(500); - let mut buffered = std::pin::pin!(tokio::time::sleep(BUFFER_FOR)); - - while !join_set.is_empty() { - tokio::select! { - Some(joined) = join_set.join_next() => { - match joined { - Ok(()) => {}, - Err(join_error) if join_error.is_cancelled() => { - unreachable!("we are not cancelling any of the tasks"); - } - Err(join_error) if join_error.is_panic() => { - // cannot really do anything, as this panic is likely a bug - panicked += 1; - } - Err(join_error) => { - warn!("unknown kind of JoinError: {join_error}"); - } - } - if !buffering { - // buffer so that every 500ms since the first update (or starting) we'll log - // how far away we are; this is because we will get SIGKILL'd at 10s, and we - // are not able to log *then*. 
- buffering = true; - buffered.as_mut().reset(tokio::time::Instant::now() + BUFFER_FOR); - } - }, - _ = &mut buffered, if buffering => { - buffering = false; - info!(remaining = join_set.len(), total, elapsed_ms = started_at.elapsed().as_millis(), "waiting for tenants to shutdown"); - } - } - } - - if panicked > 0 { - warn!( - panicked, - total, "observed panicks while shutting down tenants" - ); - } - - // caller will log how long we took -} - #[derive(thiserror::Error, Debug)] pub(crate) enum UpsertLocationError { #[error("Bad config request: {0}")] @@ -1056,7 +936,8 @@ impl TenantManager { // the tenant is inaccessible to the outside world while we are doing this, but that is sensible: // the state is ill-defined while we're in transition. Transitions are async, but fast: we do // not do significant I/O, and shutdowns should be prompt via cancellation tokens. - let mut slot_guard = tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any) + let mut slot_guard = self + .tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any) .map_err(|e| match e { TenantSlotError::NotFound(_) => { unreachable!("Called with mode Any") @@ -1223,6 +1104,75 @@ impl TenantManager { } } + fn tenant_map_acquire_slot( + &self, + tenant_shard_id: &TenantShardId, + mode: TenantSlotAcquireMode, + ) -> Result { + use TenantSlotAcquireMode::*; + METRICS.tenant_slot_writes.inc(); + + let mut locked = self.tenants.write().unwrap(); + let span = tracing::info_span!("acquire_slot", tenant_id=%tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug()); + let _guard = span.enter(); + + let m = match &mut *locked { + TenantsMap::Initializing => return Err(TenantMapError::StillInitializing.into()), + TenantsMap::ShuttingDown(_) => return Err(TenantMapError::ShuttingDown.into()), + TenantsMap::Open(m) => m, + }; + + use std::collections::btree_map::Entry; + + let entry = m.entry(*tenant_shard_id); + + match entry { + Entry::Vacant(v) => match mode { + MustExist => { + tracing::debug!("Vacant && MustExist: return NotFound"); + Err(TenantSlotError::NotFound(*tenant_shard_id)) + } + _ => { + let (completion, barrier) = utils::completion::channel(); + let inserting = TenantSlot::InProgress(barrier); + METRICS.slot_inserted(&inserting); + v.insert(inserting); + tracing::debug!("Vacant, inserted InProgress"); + Ok(SlotGuard::new( + *tenant_shard_id, + None, + completion, + &self.tenants, + )) + } + }, + Entry::Occupied(mut o) => { + // Apply mode-driven checks + match (o.get(), mode) { + (TenantSlot::InProgress(_), _) => { + tracing::debug!("Occupied, failing for InProgress"); + Err(TenantSlotError::InProgress) + } + _ => { + // Happy case: the slot was not in any state that violated our mode + let (completion, barrier) = utils::completion::channel(); + let in_progress = TenantSlot::InProgress(barrier); + METRICS.slot_inserted(&in_progress); + let old_value = o.insert(in_progress); + METRICS.slot_removed(&old_value); + tracing::debug!("Occupied, replaced with InProgress"); + Ok(SlotGuard::new( + *tenant_shard_id, + Some(old_value), + completion, + &self.tenants, + )) + } + } + } + } + } + /// Resetting a tenant is equivalent to detaching it, then attaching it again with the same /// LocationConf that was last used to attach it. Optionally, the local file cache may be /// dropped before re-attaching. 
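`tenant_map_acquire_slot` above replaces the map entry with `TenantSlot::InProgress(barrier)` and hands the paired completion to the returned `SlotGuard`, so concurrent callers block on the barrier until the guard is dropped. A rough stand-in for that completion/barrier pair, assuming watch-channel semantics (this is an illustration, not Neon's actual `utils::completion` implementation):

```rust
use tokio::sync::watch;

// Dropping the Completion is what releases the waiters.
pub struct Completion(watch::Sender<bool>);

#[derive(Clone)]
pub struct Barrier(watch::Receiver<bool>);

pub fn channel() -> (Completion, Barrier) {
    let (tx, rx) = watch::channel(false);
    (Completion(tx), Barrier(rx))
}

impl Barrier {
    /// Resolves once the paired Completion is dropped.
    pub async fn wait(mut self) {
        // `changed()` errors out when the sender half is gone; until then
        // we simply keep waiting, since nothing is ever sent.
        while self.0.changed().await.is_ok() {}
    }
}
```

Under this model, stashing a `Barrier` clone in the `InProgress` slot lets any number of later callers `wait().await` without coordination, which is why the slot protocol needs no explicit notify call on the release path.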
@@ -1239,7 +1189,8 @@ impl TenantManager { drop_cache: bool, ctx: &RequestContext, ) -> anyhow::Result<()> { - let mut slot_guard = tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?; + let mut slot_guard = + self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?; let Some(old_slot) = slot_guard.get_old_value() else { anyhow::bail!("Tenant not found when trying to reset"); }; @@ -1388,7 +1339,8 @@ impl TenantManager { Ok(()) } - let slot_guard = tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?; + let slot_guard = + self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?; match &slot_guard.old_value { Some(TenantSlot::Attached(tenant)) => { // Legacy deletion flow: the tenant remains attached, goes to Stopping state, and @@ -1539,7 +1491,7 @@ impl TenantManager { // Phase 2: Put the parent shard to InProgress and grab a reference to the parent Tenant drop(tenant); let mut parent_slot_guard = - tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?; + self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?; let parent = match parent_slot_guard.get_old_value() { Some(TenantSlot::Attached(t)) => t, Some(TenantSlot::Secondary(_)) => anyhow::bail!("Tenant location in secondary mode"), @@ -1671,7 +1623,12 @@ impl TenantManager { } } - // Phase 5: Shut down the parent shard, and erase it from disk + // Phase 5: Shut down the parent shard. We leave it on disk in case the split fails and we + // have to roll back to the parent shard, avoiding a cold start. It will be cleaned up once + // the storage controller commits the split, or if all else fails, on the next restart. + // + // TODO: We don't flush the ephemeral layer here, because the split is likely to succeed and + // catching up the parent should be reasonably quick. Consider using FreezeAndFlush instead. let (_guard, progress) = completion::channel(); match parent.shutdown(progress, ShutdownMode::Hard).await { Ok(()) => {} @@ -1679,11 +1636,6 @@ impl TenantManager { other.wait().await; } } - let local_tenant_directory = self.conf.tenant_path(&tenant_shard_id); - let tmp_path = safe_rename_tenant_dir(&local_tenant_directory) - .await - .with_context(|| format!("local tenant directory {local_tenant_directory:?} rename"))?; - self.background_purges.spawn(tmp_path); fail::fail_point!("shard-split-pre-finish", |_| Err(anyhow::anyhow!( "failpoint" @@ -1715,7 +1667,10 @@ impl TenantManager { let parent_timelines = timelines.keys().cloned().collect::>(); for timeline in timelines.values() { tracing::info!(timeline_id=%timeline.timeline_id, "Loading list of layers to hardlink"); - let layers = timeline.layers.read().await; + let layers = timeline + .layers + .read(LayerManagerLockHolder::GetLayerMapInfo) + .await; for layer in layers.likely_resident_layers() { let relative_path = layer @@ -1843,44 +1798,210 @@ impl TenantManager { pub(crate) async fn shutdown(&self) { self.cancel.cancel(); - shutdown_all_tenants0(self.tenants).await + self.shutdown_all_tenants0().await } + async fn shutdown_all_tenants0(&self) { + let mut join_set = JoinSet::new(); + + #[cfg(all(debug_assertions, not(test)))] + { + // Check that our metrics properly tracked the size of the tenants map. This is a convenient location to check, + // as it happens implicitly at the end of tests etc. + let m = self.tenants.read().unwrap(); + debug_assert_eq!(METRICS.slots_total(), m.len() as u64); + } + + // Atomically, 1. create the shutdown tasks and 2. 
prevent creation of new tenants. + let (total_in_progress, total_attached) = { + let mut m = self.tenants.write().unwrap(); + match &mut *m { + TenantsMap::Initializing => { + *m = TenantsMap::ShuttingDown(BTreeMap::default()); + info!("tenants map is empty"); + return; + } + TenantsMap::Open(tenants) => { + let mut shutdown_state = BTreeMap::new(); + let mut total_in_progress = 0; + let mut total_attached = 0; + + for (tenant_shard_id, v) in std::mem::take(tenants).into_iter() { + match v { + TenantSlot::Attached(t) => { + shutdown_state + .insert(tenant_shard_id, TenantSlot::Attached(t.clone())); + join_set.spawn( + async move { + let res = { + let (_guard, shutdown_progress) = completion::channel(); + t.shutdown(shutdown_progress, ShutdownMode::FreezeAndFlush).await + }; + + if let Err(other_progress) = res { + // join the other shutdown in progress + other_progress.wait().await; + } + + // we cannot afford per tenant logging here, because if s3 is degraded, we are + // going to log too many lines + debug!("tenant successfully stopped"); + } + .instrument(info_span!("shutdown", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug())), + ); + + total_attached += 1; + } + TenantSlot::Secondary(state) => { + // We don't need to wait for this individually per-tenant: the + // downloader task will be waited on eventually, this cancel + // is just to encourage it to drop out if it is doing work + // for this tenant right now. + state.cancel.cancel(); + + shutdown_state + .insert(tenant_shard_id, TenantSlot::Secondary(state)); + } + TenantSlot::InProgress(notify) => { + // InProgress tenants are not visible in TenantsMap::ShuttingDown: we will + // wait for their notifications to fire in this function. + join_set.spawn(async move { + notify.wait().await; + }); + + total_in_progress += 1; + } + } + } + *m = TenantsMap::ShuttingDown(shutdown_state); + (total_in_progress, total_attached) + } + TenantsMap::ShuttingDown(_) => { + error!( + "already shutting down, this function isn't supposed to be called more than once" + ); + return; + } + } + }; + + let started_at = std::time::Instant::now(); + + info!( + "Waiting for {} InProgress tenants and {} Attached tenants to shut down", + total_in_progress, total_attached + ); + + let total = join_set.len(); + let mut panicked = 0; + let mut buffering = true; + const BUFFER_FOR: std::time::Duration = std::time::Duration::from_millis(500); + let mut buffered = std::pin::pin!(tokio::time::sleep(BUFFER_FOR)); + + while !join_set.is_empty() { + tokio::select! { + Some(joined) = join_set.join_next() => { + match joined { + Ok(()) => {}, + Err(join_error) if join_error.is_cancelled() => { + unreachable!("we are not cancelling any of the tasks"); + } + Err(join_error) if join_error.is_panic() => { + // cannot really do anything, as this panic is likely a bug + panicked += 1; + } + Err(join_error) => { + warn!("unknown kind of JoinError: {join_error}"); + } + } + if !buffering { + // buffer so that every 500ms since the first update (or starting) we'll log + // how far away we are; this is because we will get SIGKILL'd at 10s, and we + // are not able to log *then*. 
+ buffering = true; + buffered.as_mut().reset(tokio::time::Instant::now() + BUFFER_FOR); + } + }, + _ = &mut buffered, if buffering => { + buffering = false; + info!(remaining = join_set.len(), total, elapsed_ms = started_at.elapsed().as_millis(), "waiting for tenants to shutdown"); + } + } + } + + if panicked > 0 { + warn!( + panicked, + total, "observed panics while shutting down tenants" + ); + } + + // caller will log how long we took + } + + /// Detaches a tenant, and removes its local files asynchronously. + /// + /// File removal is idempotent: even if the tenant has already been removed, this will still + /// remove any local files. This is used during shard splits, where we leave the parent shard's + /// files around in case we have to roll back the split. pub(crate) async fn detach_tenant( &self, conf: &'static PageServerConf, tenant_shard_id: TenantShardId, deletion_queue_client: &DeletionQueueClient, ) -> Result<(), TenantStateError> { - let tmp_path = self + if let Some(tmp_path) = self .detach_tenant0(conf, tenant_shard_id, deletion_queue_client) - .await?; - self.background_purges.spawn(tmp_path); + .await? + { + self.background_purges.spawn(tmp_path); + } Ok(()) } + /// Detaches a tenant. This renames the tenant directory to a temporary path and returns it, + /// allowing the caller to delete it asynchronously. Returns None if the dir is already removed. async fn detach_tenant0( &self, conf: &'static PageServerConf, tenant_shard_id: TenantShardId, deletion_queue_client: &DeletionQueueClient, - ) -> Result<Utf8PathBuf, TenantStateError> { + ) -> Result<Option<Utf8PathBuf>, TenantStateError> { let tenant_dir_rename_operation = |tenant_id_to_clean: TenantShardId| async move { let local_tenant_directory = conf.tenant_path(&tenant_id_to_clean); + if !tokio::fs::try_exists(&local_tenant_directory).await? { + // If the tenant directory doesn't exist, it's already cleaned up. + return Ok(None); + } safe_rename_tenant_dir(&local_tenant_directory) .await .with_context(|| { format!("local tenant directory {local_tenant_directory:?} rename") }) + .map(Some) }; - let removal_result = remove_tenant_from_memory( - self.tenants, - tenant_shard_id, - tenant_dir_rename_operation(tenant_shard_id), - ) - .await; + let mut removal_result = self + .remove_tenant_from_memory( + tenant_shard_id, + tenant_dir_rename_operation(tenant_shard_id), + ) + .await; + + // If the tenant was not found, it was likely already removed. Attempt to remove the tenant + // directory on disk anyway. For example, during shard splits, we shut down and remove the + // parent shard, but leave its directory on disk in case we have to roll back the split. + // + // TODO: it would be better to leave the parent shard attached until the split is committed. + // This will be needed by the gRPC page service too, such that a compute can continue to + // read from the parent shard until it's notified about the new child shards. See: + // . + if let Err(TenantStateError::SlotError(TenantSlotError::NotFound(_))) = removal_result { + removal_result = tenant_dir_rename_operation(tenant_shard_id) + .await + .map_err(TenantStateError::Other); + } // Flush pending deletions, so that they have a good chance of passing validation // before this tenant is potentially re-attached elsewhere. 
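`shutdown_all_tenants0` above drains the `JoinSet` with a `select!` loop whose second branch is a resettable 500ms sleep, so progress is logged at most every half second no matter how many tenants finish in between. A self-contained sketch of that pattern — `drain_with_progress` and its log lines are illustrative, not part of the diff:

```rust
use tokio::task::JoinSet;
use tokio::time::{sleep, Duration, Instant};

async fn drain_with_progress(mut join_set: JoinSet<()>) {
    let total = join_set.len();
    const BUFFER_FOR: Duration = Duration::from_millis(500);
    let mut buffering = true;
    // A single pinned Sleep is re-armed instead of re-created on each pass.
    let mut buffered = std::pin::pin!(sleep(BUFFER_FOR));

    while !join_set.is_empty() {
        tokio::select! {
            Some(res) = join_set.join_next() => {
                if let Err(e) = res {
                    eprintln!("task failed: {e}");
                }
                if !buffering {
                    // Re-arm the timer so the next progress line appears
                    // 500ms after this completion, not immediately.
                    buffering = true;
                    buffered.as_mut().reset(Instant::now() + BUFFER_FOR);
                }
            },
            _ = &mut buffered, if buffering => {
                buffering = false;
                println!("waiting: {} of {total} tasks remain", join_set.len());
            }
        }
    }
}
```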
@@ -1920,17 +2041,16 @@ impl TenantManager { ) -> Result, detach_ancestor::Error> { use detach_ancestor::Error; - let slot_guard = - tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::MustExist).map_err( - |e| { - use TenantSlotError::*; + let slot_guard = self + .tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::MustExist) + .map_err(|e| { + use TenantSlotError::*; - match e { - MapState(TenantMapError::ShuttingDown) => Error::ShuttingDown, - NotFound(_) | InProgress | MapState(_) => Error::DetachReparent(e.into()), - } - }, - )?; + match e { + MapState(TenantMapError::ShuttingDown) => Error::ShuttingDown, + NotFound(_) | InProgress | MapState(_) => Error::DetachReparent(e.into()), + } + })?; let tenant = { let old_slot = slot_guard @@ -2263,6 +2383,80 @@ impl TenantManager { other => ApiError::InternalServerError(anyhow::anyhow!(other)), }) } + + /// Stops and removes the tenant from memory, if it's not [`TenantState::Stopping`] already, bails otherwise. + /// Allows to remove other tenant resources manually, via `tenant_cleanup`. + /// If the cleanup fails, tenant will stay in memory in [`TenantState::Broken`] state, and another removal + async fn remove_tenant_from_memory( + &self, + tenant_shard_id: TenantShardId, + tenant_cleanup: F, + ) -> Result + where + F: std::future::Future>, + { + let mut slot_guard = + self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::MustExist)?; + + // allow pageserver shutdown to await for our completion + let (_guard, progress) = completion::channel(); + + // The SlotGuard allows us to manipulate the Tenant object without fear of some + // concurrent API request doing something else for the same tenant ID. + let attached_tenant = match slot_guard.get_old_value() { + Some(TenantSlot::Attached(tenant)) => { + // whenever we remove a tenant from memory, we don't want to flush and wait for upload + let shutdown_mode = ShutdownMode::Hard; + + // shutdown is sure to transition tenant to stopping, and wait for all tasks to complete, so + // that we can continue safely to cleanup. + match tenant.shutdown(progress, shutdown_mode).await { + Ok(()) => {} + Err(_other) => { + // if pageserver shutdown or other detach/ignore is already ongoing, we don't want to + // wait for it but return an error right away because these are distinct requests. + slot_guard.revert(); + return Err(TenantStateError::IsStopping(tenant_shard_id)); + } + } + Some(tenant) + } + Some(TenantSlot::Secondary(secondary_state)) => { + tracing::info!("Shutting down in secondary mode"); + secondary_state.shutdown().await; + None + } + Some(TenantSlot::InProgress(_)) => { + // Acquiring a slot guarantees its old value was not InProgress + unreachable!(); + } + None => None, + }; + + match tenant_cleanup + .await + .with_context(|| format!("Failed to run cleanup for tenant {tenant_shard_id}")) + { + Ok(hook_value) => { + // Success: drop the old TenantSlot::Attached. + slot_guard + .drop_old_value() + .expect("We just called shutdown"); + + Ok(hook_value) + } + Err(e) => { + // If we had a Tenant, set it to Broken and put it back in the TenantsMap + if let Some(attached_tenant) = attached_tenant { + attached_tenant.set_broken(e.to_string()).await; + } + // Leave the broken tenant in the map + slot_guard.revert(); + + Err(TenantStateError::Other(e)) + } + } + } } #[derive(Debug, thiserror::Error)] @@ -2427,7 +2621,7 @@ pub(crate) enum TenantMapError { /// this tenant to retry later, or wait for the InProgress state to end. 
/// /// This structure enforces the important invariant that we do not have overlapping -/// tasks that will try use local storage for a the same tenant ID: we enforce that +/// tasks that will try to use local storage for the same tenant ID: we enforce that /// the previous contents of a slot have been shut down before the slot can be /// left empty or used for something else /// @@ -2440,7 +2634,7 @@ pub(crate) enum TenantMapError { /// The `old_value` may be dropped before the SlotGuard is dropped, by calling /// `drop_old_value`. It is an error to call this without shutting down /// the contents of `old_value`. -pub(crate) struct SlotGuard { +pub(crate) struct SlotGuard<'a> { tenant_shard_id: TenantShardId, old_value: Option<TenantSlot>, upserted: bool, @@ -2448,19 +2642,23 @@ pub(crate) struct SlotGuard { /// [`TenantSlot::InProgress`] carries the corresponding Barrier: it will /// release any waiters as soon as this SlotGuard is dropped. completion: utils::completion::Completion, + + tenants: &'a std::sync::RwLock<TenantsMap>, } -impl SlotGuard { +impl<'a> SlotGuard<'a> { fn new( tenant_shard_id: TenantShardId, old_value: Option<TenantSlot>, completion: utils::completion::Completion, + tenants: &'a std::sync::RwLock<TenantsMap>, ) -> Self { Self { tenant_shard_id, old_value, upserted: false, completion, + tenants, } } @@ -2484,8 +2682,8 @@ impl<'a> SlotGuard<'a> { )); } - let replaced = { - let mut locked = TENANTS.write().unwrap(); + let replaced: Option<TenantSlot> = { + let mut locked = self.tenants.write().unwrap(); if let TenantSlot::InProgress(_) = new_value { // It is never expected to try and upsert InProgress via this path: it should @@ -2593,7 +2791,7 @@ impl<'a> SlotGuard<'a> { } } -impl Drop for SlotGuard { +impl<'a> Drop for SlotGuard<'a> { fn drop(&mut self) { if self.upserted { return; @@ -2601,7 +2799,7 @@ impl<'a> Drop for SlotGuard<'a> { // Our old value is already shutdown, or it never existed: it is safe // for us to fully release the TenantSlot back into an empty state - let mut locked = TENANTS.write().unwrap(); + let mut locked = self.tenants.write().unwrap(); let m = match &mut *locked { TenantsMap::Initializing => { @@ -2683,151 +2881,6 @@ enum TenantSlotAcquireMode { MustExist, } -fn tenant_map_acquire_slot( - tenant_shard_id: &TenantShardId, - mode: TenantSlotAcquireMode, -) -> Result<SlotGuard, TenantSlotError> { - tenant_map_acquire_slot_impl(tenant_shard_id, &TENANTS, mode) -} - -fn tenant_map_acquire_slot_impl( - tenant_shard_id: &TenantShardId, - tenants: &std::sync::RwLock<TenantsMap>, - mode: TenantSlotAcquireMode, -) -> Result<SlotGuard, TenantSlotError> { - use TenantSlotAcquireMode::*; - METRICS.tenant_slot_writes.inc(); - - let mut locked = tenants.write().unwrap(); - let span = tracing::info_span!("acquire_slot", tenant_id=%tenant_shard_id.tenant_id, shard_id = %tenant_shard_id.shard_slug()); - let _guard = span.enter(); - - let m = match &mut *locked { - TenantsMap::Initializing => return Err(TenantMapError::StillInitializing.into()), - TenantsMap::ShuttingDown(_) => return Err(TenantMapError::ShuttingDown.into()), - TenantsMap::Open(m) => m, - }; - - use std::collections::btree_map::Entry; - - let entry = m.entry(*tenant_shard_id); - - match entry { - Entry::Vacant(v) => match mode { - MustExist => { - tracing::debug!("Vacant && MustExist: return NotFound"); - Err(TenantSlotError::NotFound(*tenant_shard_id)) - } - _ => { - let (completion, barrier) = utils::completion::channel(); - let inserting = TenantSlot::InProgress(barrier); - METRICS.slot_inserted(&inserting); - v.insert(inserting); - tracing::debug!("Vacant, inserted InProgress"); - Ok(SlotGuard::new(*tenant_shard_id, None, 
completion)) - } - }, - Entry::Occupied(mut o) => { - // Apply mode-driven checks - match (o.get(), mode) { - (TenantSlot::InProgress(_), _) => { - tracing::debug!("Occupied, failing for InProgress"); - Err(TenantSlotError::InProgress) - } - _ => { - // Happy case: the slot was not in any state that violated our mode - let (completion, barrier) = utils::completion::channel(); - let in_progress = TenantSlot::InProgress(barrier); - METRICS.slot_inserted(&in_progress); - let old_value = o.insert(in_progress); - METRICS.slot_removed(&old_value); - tracing::debug!("Occupied, replaced with InProgress"); - Ok(SlotGuard::new( - *tenant_shard_id, - Some(old_value), - completion, - )) - } - } - } - } -} - -/// Stops and removes the tenant from memory, if it's not [`TenantState::Stopping`] already, bails otherwise. -/// Allows to remove other tenant resources manually, via `tenant_cleanup`. -/// If the cleanup fails, tenant will stay in memory in [`TenantState::Broken`] state, and another removal -/// operation would be needed to remove it. -async fn remove_tenant_from_memory( - tenants: &std::sync::RwLock, - tenant_shard_id: TenantShardId, - tenant_cleanup: F, -) -> Result -where - F: std::future::Future>, -{ - let mut slot_guard = - tenant_map_acquire_slot_impl(&tenant_shard_id, tenants, TenantSlotAcquireMode::MustExist)?; - - // allow pageserver shutdown to await for our completion - let (_guard, progress) = completion::channel(); - - // The SlotGuard allows us to manipulate the Tenant object without fear of some - // concurrent API request doing something else for the same tenant ID. - let attached_tenant = match slot_guard.get_old_value() { - Some(TenantSlot::Attached(tenant)) => { - // whenever we remove a tenant from memory, we don't want to flush and wait for upload - let shutdown_mode = ShutdownMode::Hard; - - // shutdown is sure to transition tenant to stopping, and wait for all tasks to complete, so - // that we can continue safely to cleanup. - match tenant.shutdown(progress, shutdown_mode).await { - Ok(()) => {} - Err(_other) => { - // if pageserver shutdown or other detach/ignore is already ongoing, we don't want to - // wait for it but return an error right away because these are distinct requests. - slot_guard.revert(); - return Err(TenantStateError::IsStopping(tenant_shard_id)); - } - } - Some(tenant) - } - Some(TenantSlot::Secondary(secondary_state)) => { - tracing::info!("Shutting down in secondary mode"); - secondary_state.shutdown().await; - None - } - Some(TenantSlot::InProgress(_)) => { - // Acquiring a slot guarantees its old value was not InProgress - unreachable!(); - } - None => None, - }; - - match tenant_cleanup - .await - .with_context(|| format!("Failed to run cleanup for tenant {tenant_shard_id}")) - { - Ok(hook_value) => { - // Success: drop the old TenantSlot::Attached. 
- slot_guard - .drop_old_value() - .expect("We just called shutdown"); - - Ok(hook_value) - } - Err(e) => { - // If we had a Tenant, set it to Broken and put it back in the TenantsMap - if let Some(attached_tenant) = attached_tenant { - attached_tenant.set_broken(e.to_string()).await; - } - // Leave the broken tenant in the map - slot_guard.revert(); - - Err(TenantStateError::Other(e)) - } - } -} - use http_utils::error::ApiError; use pageserver_api::models::TimelineGcRequest; @@ -2838,11 +2891,19 @@ mod tests { use std::collections::BTreeMap; use std::sync::Arc; + use camino::Utf8PathBuf; + use storage_broker::BrokerClientChannel; use tracing::Instrument; use super::super::harness::TenantHarness; use super::TenantsMap; - use crate::tenant::mgr::TenantSlot; + use crate::{ + basebackup_cache::BasebackupCache, + tenant::{ + TenantSharedResources, + mgr::{BackgroundPurges, TenantManager, TenantSlot}, + }, + }; #[tokio::test(start_paused = true)] async fn shutdown_awaits_in_progress_tenant() { @@ -2863,23 +2924,45 @@ mod tests { let _e = span.enter(); let tenants = BTreeMap::from([(id, TenantSlot::Attached(t.clone()))]); - let tenants = Arc::new(std::sync::RwLock::new(TenantsMap::Open(tenants))); // Invoke remove_tenant_from_memory with a cleanup hook that blocks until we manually // permit it to proceed: that will stick the tenant in InProgress + let (basebackup_cache, _) = BasebackupCache::new(Utf8PathBuf::new(), None); + + let tenant_manager = TenantManager { + tenants: std::sync::RwLock::new(TenantsMap::Open(tenants)), + conf: h.conf, + resources: TenantSharedResources { + broker_client: BrokerClientChannel::connect_lazy("foobar.com") + .await + .unwrap(), + remote_storage: h.remote_storage.clone(), + deletion_queue_client: h.deletion_queue.new_client(), + l0_flush_global_state: crate::l0_flush::L0FlushGlobalState::new( + h.conf.l0_flush.clone(), + ), + basebackup_cache, + feature_resolver: crate::feature_resolver::FeatureResolver::new_disabled(), + }, + cancel: tokio_util::sync::CancellationToken::new(), + background_purges: BackgroundPurges::default(), + }; + + let tenant_manager = Arc::new(tenant_manager); + let (until_cleanup_completed, can_complete_cleanup) = utils::completion::channel(); let (until_cleanup_started, cleanup_started) = utils::completion::channel(); let mut remove_tenant_from_memory_task = { + let tenant_manager = tenant_manager.clone(); let jh = tokio::spawn({ - let tenants = tenants.clone(); async move { let cleanup = async move { drop(until_cleanup_started); can_complete_cleanup.wait().await; anyhow::Ok(()) }; - super::remove_tenant_from_memory(&tenants, id, cleanup).await + tenant_manager.remove_tenant_from_memory(id, cleanup).await } .instrument(h.span()) }); @@ -2892,9 +2975,11 @@ mod tests { let mut shutdown_task = { let (until_shutdown_started, shutdown_started) = utils::completion::channel(); + let tenant_manager = tenant_manager.clone(); + let shutdown_task = tokio::spawn(async move { drop(until_shutdown_started); - super::shutdown_all_tenants0(&tenants).await; + tenant_manager.shutdown_all_tenants0().await; }); shutdown_started.wait().await; diff --git a/pageserver/src/tenant/remote_timeline_client/index.rs b/pageserver/src/tenant/remote_timeline_client/index.rs index a5cd8989aa..6060c42cbb 100644 --- a/pageserver/src/tenant/remote_timeline_client/index.rs +++ b/pageserver/src/tenant/remote_timeline_client/index.rs @@ -427,8 +427,8 @@ impl GcBlocking { #[cfg(test)] mod tests { + use postgres_ffi::PgMajorVersion; use std::str::FromStr; - use 
utils::id::TimelineId; use super::*; @@ -831,7 +831,7 @@ mod tests { Lsn::INVALID, Lsn::from_str("0/1696070").unwrap(), Lsn::from_str("0/1696070").unwrap(), - 14, + PgMajorVersion::PG14, ).with_recalculated_checksum().unwrap(), deleted_at: Some(parse_naive_datetime("2023-07-31T09:00:00.123000000")), archived_at: None, @@ -893,7 +893,7 @@ mod tests { Lsn::INVALID, Lsn::from_str("0/1696070").unwrap(), Lsn::from_str("0/1696070").unwrap(), - 14, + PgMajorVersion::PG14, ).with_recalculated_checksum().unwrap(), deleted_at: Some(parse_naive_datetime("2023-07-31T09:00:00.123000000")), archived_at: Some(parse_naive_datetime("2023-04-29T09:00:00.123000000")), @@ -957,7 +957,7 @@ mod tests { Lsn::INVALID, Lsn::from_str("0/1696070").unwrap(), Lsn::from_str("0/1696070").unwrap(), - 14, + PgMajorVersion::PG14, ).with_recalculated_checksum().unwrap(), deleted_at: None, lineage: Default::default(), @@ -1033,7 +1033,7 @@ mod tests { Lsn::INVALID, Lsn::from_str("0/1696070").unwrap(), Lsn::from_str("0/1696070").unwrap(), - 14, + PgMajorVersion::PG14, ).with_recalculated_checksum().unwrap(), deleted_at: None, lineage: Default::default(), @@ -1114,7 +1114,7 @@ mod tests { Lsn::INVALID, Lsn::from_str("0/1696070").unwrap(), Lsn::from_str("0/1696070").unwrap(), - 14, + PgMajorVersion::PG14, ).with_recalculated_checksum().unwrap(), deleted_at: None, lineage: Default::default(), @@ -1199,7 +1199,7 @@ mod tests { Lsn::INVALID, Lsn::from_str("0/1696070").unwrap(), Lsn::from_str("0/1696070").unwrap(), - 14, + PgMajorVersion::PG14, ).with_recalculated_checksum().unwrap(), deleted_at: None, lineage: Default::default(), @@ -1287,7 +1287,7 @@ mod tests { Lsn::INVALID, Lsn::from_str("0/1696070").unwrap(), Lsn::from_str("0/1696070").unwrap(), - 14, + PgMajorVersion::PG14, ).with_recalculated_checksum().unwrap(), deleted_at: None, lineage: Default::default(), diff --git a/pageserver/src/tenant/remote_timeline_client/upload.rs b/pageserver/src/tenant/remote_timeline_client/upload.rs index 89f6136530..ffb4717d9f 100644 --- a/pageserver/src/tenant/remote_timeline_client/upload.rs +++ b/pageserver/src/tenant/remote_timeline_client/upload.rs @@ -1,6 +1,7 @@ //! Helper functions to upload files to remote storage with a RemoteStorage use std::io::{ErrorKind, SeekFrom}; +use std::num::NonZeroU32; use std::time::SystemTime; use anyhow::{Context, bail}; @@ -228,11 +229,25 @@ pub(crate) async fn time_travel_recover_tenant( let timelines_path = super::remote_timelines_path(tenant_shard_id); prefixes.push(timelines_path); } + + // Limit the number of version deletions, mostly so that we don't + // keep requesting forever if the list is too long, as we'd put the + // list in RAM. + // Building a list of 100k entries that reaches the limit roughly takes + // 40 seconds, and roughly corresponds to tenants of 2 TiB physical size.
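To make the intent of the cap concrete, here is a minimal, self-contained sketch of a bounded recovery loop. The `Storage` struct, `TimeTravelError` variants, and the inline retry loop below are simplified stand-ins, not the real `remote_storage` API or the `backoff::retry` helper; only the idea of threading an `Option<NonZeroU32>` complexity limit through the call and treating limit-exceeded as a permanent (non-retried) error is taken from the change above.

```rust
use std::num::NonZeroU32;

// Hypothetical stand-ins for the real remote_storage types.
#[derive(Debug)]
enum TimeTravelError {
    TooManyVersions, // listing exceeded the complexity limit; permanent
    Other(String),   // transient failure, worth retrying
}

struct Storage;

impl Storage {
    // Recovers one prefix, refusing to buffer more than `limit` versions in RAM.
    fn time_travel_recover(
        &self,
        prefix: &str,
        limit: Option<NonZeroU32>,
    ) -> Result<(), TimeTravelError> {
        let _ = (prefix, limit);
        Ok(())
    }
}

// Cap chosen so that version listings stay small enough to hold in memory.
const COMPLEXITY_LIMIT: Option<NonZeroU32> = NonZeroU32::new(100_000);

fn recover_all(storage: &Storage, prefixes: &[&str]) -> Result<(), TimeTravelError> {
    for prefix in prefixes {
        let mut attempts = 0;
        loop {
            match storage.time_travel_recover(prefix, COMPLEXITY_LIMIT) {
                Ok(()) => break,
                // Exceeding the limit is permanent: retrying would only
                // rebuild the same oversized list.
                Err(TimeTravelError::TooManyVersions) => {
                    return Err(TimeTravelError::TooManyVersions);
                }
                Err(e) if attempts < 3 => {
                    attempts += 1;
                    eprintln!("retrying {prefix}: {e:?}");
                }
                Err(e) => return Err(e),
            }
        }
    }
    Ok(())
}

fn main() {
    let storage = Storage;
    recover_all(&storage, &["tenants/xyz/", "tenants/xyz/timelines/"]).unwrap();
}
```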
+ const COMPLEXITY_LIMIT: Option<NonZeroU32> = NonZeroU32::new(100_000); + for prefix in &prefixes { backoff::retry( || async { storage - .time_travel_recover(Some(prefix), timestamp, done_if_after, cancel) + .time_travel_recover( + Some(prefix), + timestamp, + done_if_after, + cancel, + COMPLEXITY_LIMIT, + ) .await }, |e| !matches!(e, TimeTravelError::Other(_)), diff --git a/pageserver/src/tenant/secondary/downloader.rs b/pageserver/src/tenant/secondary/downloader.rs index dd49c843f3..6b315dc4bc 100644 --- a/pageserver/src/tenant/secondary/downloader.rs +++ b/pageserver/src/tenant/secondary/downloader.rs @@ -1427,7 +1427,7 @@ async fn init_timeline_state( let local_meta = dentry .metadata() .await - .fatal_err(&format!("Read metadata on {}", file_path)); + .fatal_err(&format!("Read metadata on {file_path}")); let file_name = file_path.file_name().expect("created it from the dentry"); if crate::is_temporary(&file_path) diff --git a/pageserver/src/tenant/storage_layer.rs b/pageserver/src/tenant/storage_layer.rs index 9d15e7c4de..9fbb9d2438 100644 --- a/pageserver/src/tenant/storage_layer.rs +++ b/pageserver/src/tenant/storage_layer.rs @@ -34,11 +34,11 @@ pub use layer_name::{DeltaLayerName, ImageLayerName, LayerName}; use pageserver_api::config::GetVectoredConcurrentIo; use pageserver_api::key::Key; use pageserver_api::keyspace::{KeySpace, KeySpaceRandomAccum}; -use pageserver_api::record::NeonWalRecord; -use pageserver_api::value::Value; use tracing::{Instrument, info_span, trace}; use utils::lsn::Lsn; use utils::sync::gate::GateGuard; +use wal_decoder::models::record::NeonWalRecord; +use wal_decoder::models::value::Value; use self::inmemory_layer::InMemoryLayerFileId; use super::PageReconstructError; @@ -109,7 +109,7 @@ pub(crate) enum OnDiskValue { /// Reconstruct data accumulated for a single key during a vectored get #[derive(Debug, Default)] -pub(crate) struct VectoredValueReconstructState { +pub struct VectoredValueReconstructState { pub(crate) on_disk_values: Vec<(Lsn, OnDiskValueIoWaiter)>, pub(crate) situation: ValueReconstructSituation, @@ -244,13 +244,60 @@ impl VectoredValueReconstructState { res } + + /// Benchmarking utility to await the completion of all pending IOs + /// + /// # Cancel-Safety + /// + /// Technically fine to stop polling this future, but the IOs will still + /// be executed to completion by the sidecar task and hold on to / consume resources. + /// Better not to do it, to make reasoning about the system easier. + #[cfg(feature = "benchmarking")] + pub async fn sink_pending_ios(self) -> Result<(), std::io::Error> { + let mut res = Ok(()); + + // We should try hard not to bail early, so that by the time we return from this + // function, all IO for this value is done. It's not required -- we could totally + // stop polling the IO futures in the sidecar task, they need to support that, + // but just stopping polling doesn't reduce the IO load on the disk. It's easier + // to reason about the system if we just wait for all IO to complete, even if + // we're no longer interested in the result. + // + // Revisit this when IO futures are replaced with a more sophisticated IO system + // and an IO scheduler, where we know which IOs were submitted and which ones + // just queued. Cf. the comment on IoConcurrency::spawn_io.
+ for (_lsn, waiter) in self.on_disk_values { + let value_recv_res = waiter + .wait_completion() + // we rely on the caller to poll us to completion, so this is not a bail point + .await; + + match (&mut res, value_recv_res) { + (Err(_), _) => { + // We've already failed, no need to process more. + } + (Ok(_), Err(_wait_err)) => { + // This shouldn't happen - likely the sidecar task panicked. + unreachable!(); + } + (Ok(_), Ok(Err(err))) => { + let err: std::io::Error = err; + res = Err(err); + } + (Ok(_ok), Ok(Ok(OnDiskValue::RawImage(_img)))) => {} + (Ok(_ok), Ok(Ok(OnDiskValue::WalRecordOrImage(_buf)))) => {} + } + } + + res + } } /// Bag of data accumulated during a vectored get.. -pub(crate) struct ValuesReconstructState { +pub struct ValuesReconstructState { /// The keys will be removed after `get_vectored` completes. The caller outside `Timeline` /// should not expect to get anything from this hashmap. - pub(crate) keys: HashMap, + pub keys: HashMap, /// The keys which are already retrieved keys_done: KeySpaceRandomAccum, @@ -272,7 +319,7 @@ pub(crate) struct ValuesReconstructState { /// The desired end state is that we always do parallel IO. /// This struct and the dispatching in the impl will be removed once /// we've built enough confidence. -pub(crate) enum IoConcurrency { +pub enum IoConcurrency { Sequential, SidecarTask { task_id: usize, @@ -317,10 +364,7 @@ impl IoConcurrency { Self::spawn(SelectedIoConcurrency::Sequential) } - pub(crate) fn spawn_from_conf( - conf: GetVectoredConcurrentIo, - gate_guard: GateGuard, - ) -> IoConcurrency { + pub fn spawn_from_conf(conf: GetVectoredConcurrentIo, gate_guard: GateGuard) -> IoConcurrency { let selected = match conf { GetVectoredConcurrentIo::Sequential => SelectedIoConcurrency::Sequential, GetVectoredConcurrentIo::SidecarTask => SelectedIoConcurrency::SidecarTask(gate_guard), @@ -425,16 +469,6 @@ impl IoConcurrency { } } - pub(crate) fn clone(&self) -> Self { - match self { - IoConcurrency::Sequential => IoConcurrency::Sequential, - IoConcurrency::SidecarTask { task_id, ios_tx } => IoConcurrency::SidecarTask { - task_id: *task_id, - ios_tx: ios_tx.clone(), - }, - } - } - /// Submit an IO to be executed in the background. DEADLOCK RISK, read the full doc string. /// /// The IO is represented as an opaque future. @@ -573,6 +607,18 @@ impl IoConcurrency { } } +impl Clone for IoConcurrency { + fn clone(&self) -> Self { + match self { + IoConcurrency::Sequential => IoConcurrency::Sequential, + IoConcurrency::SidecarTask { task_id, ios_tx } => IoConcurrency::SidecarTask { + task_id: *task_id, + ios_tx: ios_tx.clone(), + }, + } + } +} + /// Make noise in case the [`ValuesReconstructState`] gets dropped while /// there are still IOs in flight. /// Refer to `collect_pending_ios` for why we prefer not to do that. 
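The hunk above also moves `clone` from an inherent method on `IoConcurrency` to a `Clone` trait implementation. The behavior is identical; the payoff is interoperability, since trait bounds and `#[derive(Clone)]` on containing types only see the trait, never an inherent method. A minimal sketch with a simplified enum (the channel payload type and variant contents here are placeholders, not the real sidecar types):

```rust
use std::sync::mpsc::Sender;

// Simplified stand-in for the real IoConcurrency enum.
enum IoConcurrency {
    Sequential,
    SidecarTask { task_id: usize, ios_tx: Sender<u64> },
}

// Implementing the trait (rather than an inherent `fn clone`) means generic
// code with a `T: Clone` bound, and `#[derive(Clone)]` on structs that embed
// this type, both work.
impl Clone for IoConcurrency {
    fn clone(&self) -> Self {
        match self {
            IoConcurrency::Sequential => IoConcurrency::Sequential,
            IoConcurrency::SidecarTask { task_id, ios_tx } => IoConcurrency::SidecarTask {
                task_id: *task_id,
                ios_tx: ios_tx.clone(),
            },
        }
    }
}

// Only compiles because IoConcurrency implements the Clone *trait*; an
// inherent `fn clone` would satisfy direct `io.clone()` calls but not this bound.
fn duplicate<T: Clone>(v: &T) -> T {
    v.clone()
}

fn main() {
    let (tx, _rx) = std::sync::mpsc::channel::<u64>();
    let io = IoConcurrency::SidecarTask { task_id: 7, ios_tx: tx };
    let _copy = duplicate(&io);
}
```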
@@ -603,7 +649,7 @@ impl Drop for ValuesReconstructState { } impl ValuesReconstructState { - pub(crate) fn new(io_concurrency: IoConcurrency) -> Self { + pub fn new(io_concurrency: IoConcurrency) -> Self { Self { keys: HashMap::new(), keys_done: KeySpaceRandomAccum::new(), diff --git a/pageserver/src/tenant/storage_layer/batch_split_writer.rs b/pageserver/src/tenant/storage_layer/batch_split_writer.rs index 51f2e909a2..1d50a5f3a0 100644 --- a/pageserver/src/tenant/storage_layer/batch_split_writer.rs +++ b/pageserver/src/tenant/storage_layer/batch_split_writer.rs @@ -4,11 +4,11 @@ use std::sync::Arc; use bytes::Bytes; use pageserver_api::key::{KEY_SIZE, Key}; -use pageserver_api::value::Value; use tokio_util::sync::CancellationToken; use utils::id::TimelineId; use utils::lsn::Lsn; use utils::shard::TenantShardId; +use wal_decoder::models::value::Value; use super::errors::PutError; use super::layer::S3_UPLOAD_LIMIT; diff --git a/pageserver/src/tenant/storage_layer/delta_layer.rs b/pageserver/src/tenant/storage_layer/delta_layer.rs index 2c1b27c8d5..c2f76c859c 100644 --- a/pageserver/src/tenant/storage_layer/delta_layer.rs +++ b/pageserver/src/tenant/storage_layer/delta_layer.rs @@ -44,7 +44,6 @@ use pageserver_api::key::{DBDIR_KEY, KEY_SIZE, Key}; use pageserver_api::keyspace::KeySpace; use pageserver_api::models::ImageCompressionAlgorithm; use pageserver_api::shard::TenantShardId; -use pageserver_api::value::Value; use serde::{Deserialize, Serialize}; use tokio::sync::OnceCell; use tokio_epoll_uring::IoBuf; @@ -54,6 +53,7 @@ use utils::bin_ser::BeSer; use utils::bin_ser::SerializeError; use utils::id::{TenantId, TimelineId}; use utils::lsn::Lsn; +use wal_decoder::models::value::Value; use super::errors::PutError; use super::{ @@ -783,7 +783,7 @@ impl DeltaLayer { ctx, ) .await - .with_context(|| format!("Failed to open file '{}'", path))?; + .with_context(|| format!("Failed to open file '{path}'"))?; let file_id = page_cache::next_file_id(); let block_reader = FileBlockReader::new(&file, file_id); let summary_blk = block_reader.read_blk(0, ctx).await?; @@ -1306,7 +1306,7 @@ impl DeltaLayerInner { // is it an image or will_init walrecord? 
// FIXME: this could be handled by threading the BlobRef to the // VectoredReadBuilder - let will_init = pageserver_api::value::ValueBytes::will_init(&data) + let will_init = wal_decoder::models::value::ValueBytes::will_init(&data) .inspect_err(|_e| { #[cfg(feature = "testing")] tracing::error!(data=?utils::Hex(&data), err=?_e, %key, %lsn, "failed to parse will_init out of serialized value"); @@ -1369,7 +1369,7 @@ impl DeltaLayerInner { format!(" img {} bytes", img.len()) } Value::WalRecord(rec) => { - let wal_desc = pageserver_api::record::describe_wal_record(&rec)?; + let wal_desc = wal_decoder::models::record::describe_wal_record(&rec)?; format!( " rec {} bytes will_init: {} {}", buf.len(), @@ -1401,7 +1401,7 @@ impl DeltaLayerInner { match val { Value::Image(img) => { let checkpoint = CheckPoint::decode(&img)?; - println!(" CHECKPOINT: {:?}", checkpoint); + println!(" CHECKPOINT: {checkpoint:?}"); } Value::WalRecord(_rec) => { println!(" unexpected walrecord value for checkpoint key"); @@ -1622,12 +1622,6 @@ impl DeltaLayerIterator<'_> { pub(crate) mod test { use std::collections::BTreeMap; - use bytes::Bytes; - use itertools::MinMaxResult; - use pageserver_api::value::Value; - use rand::prelude::{SeedableRng, SliceRandom, StdRng}; - use rand::{Rng, RngCore}; - use super::*; use crate::DEFAULT_PG_VERSION; use crate::context::DownloadBehavior; @@ -1635,7 +1629,13 @@ pub(crate) mod test { use crate::tenant::disk_btree::tests::TestDisk; use crate::tenant::harness::{TIMELINE_ID, TenantHarness}; use crate::tenant::storage_layer::{Layer, ResidentLayer}; + use crate::tenant::timeline::layer_manager::LayerManagerLockHolder; use crate::tenant::{TenantShard, Timeline}; + use bytes::Bytes; + use itertools::MinMaxResult; + use postgres_ffi::PgMajorVersion; + use rand::prelude::{SeedableRng, SliceRandom, StdRng}; + use rand::{Rng, RngCore}; /// Construct an index for a fictional delta layer and and then /// traverse in order to plan vectored reads for a query. 
Finally, @@ -1987,7 +1987,7 @@ pub(crate) mod test { #[tokio::test] async fn copy_delta_prefix_smoke() { use bytes::Bytes; - use pageserver_api::record::NeonWalRecord; + use wal_decoder::models::record::NeonWalRecord; let h = crate::tenant::harness::TenantHarness::create("truncate_delta_smoke") .await @@ -1995,14 +1995,14 @@ pub(crate) mod test { let (tenant, ctx) = h.load().await; let ctx = &ctx; let timeline = tenant - .create_test_timeline(TimelineId::generate(), Lsn(0x10), 14, ctx) + .create_test_timeline(TimelineId::generate(), Lsn(0x10), PgMajorVersion::PG14, ctx) .await .unwrap(); let ctx = &ctx.with_scope_timeline(&timeline); let initdb_layer = timeline .layers - .read() + .read(crate::tenant::timeline::layer_manager::LayerManagerLockHolder::Testing) .await .likely_resident_layers() .next() @@ -2078,7 +2078,7 @@ pub(crate) mod test { let new_layer = timeline .layers - .read() + .read(LayerManagerLockHolder::Testing) .await .likely_resident_layers() .find(|&x| x != &initdb_layer) diff --git a/pageserver/src/tenant/storage_layer/filter_iterator.rs b/pageserver/src/tenant/storage_layer/filter_iterator.rs index 1a330ecfc2..d345195446 100644 --- a/pageserver/src/tenant/storage_layer/filter_iterator.rs +++ b/pageserver/src/tenant/storage_layer/filter_iterator.rs @@ -4,8 +4,8 @@ use std::sync::Arc; use anyhow::bail; use pageserver_api::key::Key; use pageserver_api::keyspace::{KeySpace, SparseKeySpace}; -use pageserver_api::value::Value; use utils::lsn::Lsn; +use wal_decoder::models::value::Value; use super::PersistentLayerKey; use super::merge_iterator::{MergeIterator, MergeIteratorItem}; @@ -126,7 +126,6 @@ mod tests { #[tokio::test] async fn filter_keyspace_iterator() { use bytes::Bytes; - use pageserver_api::value::Value; let harness = TenantHarness::create("filter_iterator_filter_keyspace_iterator") .await diff --git a/pageserver/src/tenant/storage_layer/image_layer.rs b/pageserver/src/tenant/storage_layer/image_layer.rs index 740f53f928..9f76f697d3 100644 --- a/pageserver/src/tenant/storage_layer/image_layer.rs +++ b/pageserver/src/tenant/storage_layer/image_layer.rs @@ -42,7 +42,6 @@ use pageserver_api::config::MaxVectoredReadBytes; use pageserver_api::key::{DBDIR_KEY, KEY_SIZE, Key}; use pageserver_api::keyspace::KeySpace; use pageserver_api::shard::{ShardIdentity, TenantShardId}; -use pageserver_api::value::Value; use serde::{Deserialize, Serialize}; use tokio::sync::OnceCell; use tokio_stream::StreamExt; @@ -52,6 +51,7 @@ use utils::bin_ser::BeSer; use utils::bin_ser::SerializeError; use utils::id::{TenantId, TimelineId}; use utils::lsn::Lsn; +use wal_decoder::models::value::Value; use super::errors::PutError; use super::layer_name::ImageLayerName; @@ -272,8 +272,7 @@ impl ImageLayer { conf.timeline_path(&tenant_shard_id, &timeline_id) .join(format!( - "{fname}.{:x}.{TEMP_FILE_SUFFIX}", - filename_disambiguator + "{fname}.{filename_disambiguator:x}.{TEMP_FILE_SUFFIX}" )) } @@ -370,7 +369,7 @@ impl ImageLayer { ctx, ) .await - .with_context(|| format!("Failed to open file '{}'", path))?; + .with_context(|| format!("Failed to open file '{path}'"))?; let file_id = page_cache::next_file_id(); let block_reader = FileBlockReader::new(&file, file_id); let summary_blk = block_reader.read_blk(0, ctx).await?; @@ -1232,10 +1231,10 @@ mod test { use itertools::Itertools; use pageserver_api::key::Key; use pageserver_api::shard::{ShardCount, ShardIdentity, ShardNumber, ShardStripeSize}; - use pageserver_api::value::Value; use utils::generation::Generation; use utils::id::{TenantId, 
TimelineId}; use utils::lsn::Lsn; + use wal_decoder::models::value::Value; use super::{ImageLayerIterator, ImageLayerWriter}; use crate::DEFAULT_PG_VERSION; @@ -1475,7 +1474,7 @@ mod test { assert_eq!(l1, expect_lsn); assert_eq!(&i1, i2); } - (o1, o2) => panic!("iterators length mismatch: {:?}, {:?}", o1, o2), + (o1, o2) => panic!("iterators length mismatch: {o1:?}, {o2:?}"), } } } diff --git a/pageserver/src/tenant/storage_layer/inmemory_layer.rs b/pageserver/src/tenant/storage_layer/inmemory_layer.rs index 200beba115..c4d53c6405 100644 --- a/pageserver/src/tenant/storage_layer/inmemory_layer.rs +++ b/pageserver/src/tenant/storage_layer/inmemory_layer.rs @@ -70,23 +70,15 @@ pub struct InMemoryLayer { /// We use a separate lock for the index to reduce the critical section /// during which reads cannot be planned. /// - /// If you need access to both the index and the underlying file at the same time, - /// respect the following locking order to avoid deadlocks: - /// 1. [`InMemoryLayer::inner`] - /// 2. [`InMemoryLayer::index`] - /// - /// Note that the file backing [`InMemoryLayer::inner`] is append-only, - /// so it is not necessary to hold simultaneous locks on index. - /// This avoids holding index locks across IO, and is crucial for avoiding read tail latency. + /// Note that the file backing [`InMemoryLayer::file`] is append-only, + /// so it is not necessary to hold a lock on the index while reading or writing from the file. /// In particular: - /// 1. It is safe to read and release [`InMemoryLayer::index`] before locking and reading from [`InMemoryLayer::inner`]. - /// 2. It is safe to write and release [`InMemoryLayer::inner`] before locking and updating [`InMemoryLayer::index`]. + /// 1. It is safe to read and release [`InMemoryLayer::index`] before reading from [`InMemoryLayer::file`]. + /// 2. It is safe to write to [`InMemoryLayer::file`] before locking and updating [`InMemoryLayer::index`]. index: RwLock>>, - /// The above fields never change, except for `end_lsn`, which is only set once, - /// and `index` (see rationale there). - /// All other changing parts are in `inner`, and protected by a mutex. - inner: RwLock, + /// Wrapper for the actual on-disk file. Uses interior mutability for concurrent reads/writes. + file: EphemeralFile, estimated_in_mem_size: AtomicU64, } @@ -96,20 +88,10 @@ impl std::fmt::Debug for InMemoryLayer { f.debug_struct("InMemoryLayer") .field("start_lsn", &self.start_lsn) .field("end_lsn", &self.end_lsn) - .field("inner", &self.inner) .finish() } } -pub struct InMemoryLayerInner { - /// The values are stored in a serialized format in this file. - /// Each serialized Value is preceded by a 'u32' length field. - /// PerSeg::page_versions map stores offsets into this file. - file: EphemeralFile, - - resource_units: GlobalResourceUnits, -} - /// Support the same max blob length as blob_io, because ultimately /// all the InMemoryLayer contents end up being written into a delta layer, /// using the [`crate::tenant::blob_io`]. @@ -258,12 +240,6 @@ struct IndexEntryUnpacked { pos: u64, } -impl std::fmt::Debug for InMemoryLayerInner { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("InMemoryLayerInner").finish() - } -} - /// State shared by all in-memory (ephemeral) layers. Updated infrequently during background ticks in Timeline, /// to minimize contention. 
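The rewritten locking comment above rests on one invariant: the backing file is append-only, and an index entry is only published after the bytes it points to are fully written. Readers therefore never need to hold the index lock across file IO. A minimal sketch of that discipline, using synchronous `std` locks as a stand-in for the real async `EphemeralFile` and index types:

```rust
use std::collections::BTreeMap;
use std::sync::{Mutex, RwLock};

// Append-only byte log: bytes at existing offsets never change once written.
struct AppendOnlyFile {
    buf: Mutex<Vec<u8>>,
}

impl AppendOnlyFile {
    // Returns the offset and length of the appended value.
    fn append(&self, value: &[u8]) -> (u64, usize) {
        let mut buf = self.buf.lock().unwrap();
        let pos = buf.len() as u64;
        buf.extend_from_slice(value);
        (pos, value.len())
    }

    fn read(&self, pos: u64, len: usize) -> Vec<u8> {
        let buf = self.buf.lock().unwrap();
        buf[pos as usize..pos as usize + len].to_vec()
    }
}

struct Layer {
    file: AppendOnlyFile,
    // key -> (offset, len); protected separately from the file.
    index: RwLock<BTreeMap<u32, (u64, usize)>>,
}

impl Layer {
    // Write path: append first, then publish the index entry. Any reader that
    // sees the entry is guaranteed the bytes are already in the file.
    fn put(&self, key: u32, value: &[u8]) {
        let (pos, len) = self.file.append(value);
        self.index.write().unwrap().insert(key, (pos, len));
    }

    // Read path: look up and *release* the index lock, then do the file IO,
    // so no index lock is ever held across IO.
    fn get(&self, key: u32) -> Option<Vec<u8>> {
        let entry = { *self.index.read().unwrap().get(&key)? };
        Some(self.file.read(entry.0, entry.1))
    }
}

fn main() {
    let layer = Layer {
        file: AppendOnlyFile { buf: Mutex::new(Vec::new()) },
        index: RwLock::new(BTreeMap::new()),
    };
    layer.put(7, b"hello");
    assert_eq!(layer.get(7).as_deref(), Some(&b"hello"[..]));
}
```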
/// @@ -280,7 +256,7 @@ pub(crate) struct GlobalResources { } // Per-timeline RAII struct for its contribution to [`GlobalResources`] -struct GlobalResourceUnits { +pub(crate) struct GlobalResourceUnits { // How many dirty bytes have I added to the global dirty_bytes: this guard object is responsible // for decrementing the global counter by this many bytes when dropped. dirty_bytes: u64, @@ -292,7 +268,7 @@ impl GlobalResourceUnits { // updated when the Timeline "ticks" in the background. const MAX_SIZE_DRIFT: u64 = 10 * 1024 * 1024; - fn new() -> Self { + pub(crate) fn new() -> Self { GLOBAL_RESOURCES .dirty_layers .fetch_add(1, AtomicOrdering::Relaxed); @@ -304,7 +280,7 @@ impl GlobalResourceUnits { /// /// Returns the effective layer size limit that should be applied, if any, to keep /// the total number of dirty bytes below the configured maximum. - fn publish_size(&mut self, size: u64) -> Option { + pub(crate) fn publish_size(&mut self, size: u64) -> Option { let new_global_dirty_bytes = match size.cmp(&self.dirty_bytes) { Ordering::Equal => GLOBAL_RESOURCES.dirty_bytes.load(AtomicOrdering::Relaxed), Ordering::Greater => { @@ -349,7 +325,7 @@ impl GlobalResourceUnits { // Call publish_size if the input size differs from last published size by more than // the drift limit - fn maybe_publish_size(&mut self, size: u64) { + pub(crate) fn maybe_publish_size(&mut self, size: u64) { let publish = match size.cmp(&self.dirty_bytes) { Ordering::Equal => false, Ordering::Greater => size - self.dirty_bytes > Self::MAX_SIZE_DRIFT, @@ -398,8 +374,8 @@ impl InMemoryLayer { } } - pub(crate) fn try_len(&self) -> Option { - self.inner.try_read().map(|i| i.file.len()).ok() + pub(crate) fn len(&self) -> u64 { + self.file.len() } pub(crate) fn assert_writable(&self) { @@ -430,7 +406,7 @@ impl InMemoryLayer { // Look up the keys in the provided keyspace and update // the reconstruct state with whatever is found. - pub(crate) async fn get_values_reconstruct_data( + pub async fn get_values_reconstruct_data( self: &Arc, keyspace: KeySpace, lsn_range: Range, @@ -479,14 +455,13 @@ impl InMemoryLayer { } } } - drop(index); // release the lock before we spawn the IO; if it's serial-mode IO we will deadlock on the read().await below + drop(index); // release the lock before we spawn the IO let read_from = Arc::clone(self); let read_ctx = ctx.attached_child(); reconstruct_state .spawn_io(async move { - let inner = read_from.inner.read().await; let f = vectored_dio_read::execute( - &inner.file, + &read_from.file, reads .iter() .flat_map(|(_, value_reads)| value_reads.iter().map(|v| &v.read)), @@ -518,7 +493,6 @@ impl InMemoryLayer { // This is kinda forced for InMemoryLayer because we need to inner.read() anyway, // but it's less obvious for DeltaLayer and ImageLayer. So, keep this explicit // drop for consistency among all three layer types. - drop(inner); drop(read_from); }) .await; @@ -537,7 +511,7 @@ fn inmem_layer_log_display( start_lsn: Lsn, end_lsn: Lsn, ) -> std::fmt::Result { - write!(f, "timeline {} in-memory ", timeline)?; + write!(f, "timeline {timeline} in-memory ")?; inmem_layer_display(f, start_lsn, end_lsn) } @@ -549,12 +523,6 @@ impl std::fmt::Display for InMemoryLayer { } impl InMemoryLayer { - /// Get layer size. 
- pub async fn size(&self) -> Result { - let inner = self.inner.read().await; - Ok(inner.file.len()) - } - pub fn estimated_in_mem_size(&self) -> u64 { self.estimated_in_mem_size.load(AtomicOrdering::Relaxed) } @@ -587,10 +555,7 @@ impl InMemoryLayer { end_lsn: OnceLock::new(), opened_at: Instant::now(), index: RwLock::new(BTreeMap::new()), - inner: RwLock::new(InMemoryLayerInner { - file, - resource_units: GlobalResourceUnits::new(), - }), + file, estimated_in_mem_size: AtomicU64::new(0), }) } @@ -599,41 +564,37 @@ impl InMemoryLayer { /// /// Errors are not retryable, the [`InMemoryLayer`] must be discarded, and not be read from. /// The reason why it's not retryable is that the [`EphemeralFile`] writes are not retryable. + /// + /// This method shall not be called concurrently. We enforce this property via [`crate::tenant::Timeline::write_lock`]. + /// /// TODO: it can be made retryable if we aborted the process on EphemeralFile write errors. pub async fn put_batch( &self, serialized_batch: SerializedValueBatch, ctx: &RequestContext, ) -> anyhow::Result<()> { - let (base_offset, metadata) = { - let mut inner = self.inner.write().await; - self.assert_writable(); + self.assert_writable(); - let base_offset = inner.file.len(); + let base_offset = self.file.len(); - let SerializedValueBatch { - raw, - metadata, - max_lsn: _, - len: _, - } = serialized_batch; + let SerializedValueBatch { + raw, + metadata, + max_lsn: _, + len: _, + } = serialized_batch; - // Write the batch to the file - inner.file.write_raw(&raw, ctx).await?; - let new_size = inner.file.len(); + // Write the batch to the file + self.file.write_raw(&raw, ctx).await?; + let new_size = self.file.len(); - let expected_new_len = base_offset - .checked_add(raw.len().into_u64()) - // write_raw would error if we were to overflow u64. - // also IndexEntry and higher levels in - //the code don't allow the file to grow that large - .unwrap(); - assert_eq!(new_size, expected_new_len); - - inner.resource_units.maybe_publish_size(new_size); - - (base_offset, metadata) - }; + let expected_new_len = base_offset + .checked_add(raw.len().into_u64()) + // write_raw would error if we were to overflow u64. + // also IndexEntry and higher levels in + //the code don't allow the file to grow that large + .unwrap(); + assert_eq!(new_size, expected_new_len); // Update the index with the new entries let mut index = self.index.write().await; @@ -686,10 +647,8 @@ impl InMemoryLayer { self.opened_at } - pub(crate) async fn tick(&self) -> Option { - let mut inner = self.inner.write().await; - let size = inner.file.len(); - inner.resource_units.publish_size(size) + pub(crate) fn tick(&self) -> Option { + self.file.tick() } pub(crate) async fn put_tombstones(&self, _key_ranges: &[(Range, Lsn)]) -> Result<()> { @@ -753,12 +712,6 @@ impl InMemoryLayer { gate: &utils::sync::gate::Gate, cancel: CancellationToken, ) -> Result> { - // Grab the lock in read-mode. We hold it over the I/O, but because this - // layer is not writeable anymore, no one should be trying to acquire the - // write lock on it, so we shouldn't block anyone. See the comment on - // [`InMemoryLayer::freeze`] to understand how locking between the append path - // and layer flushing works. - let inner = self.inner.read().await; let index = self.index.read().await; use l0_flush::Inner; @@ -793,7 +746,7 @@ impl InMemoryLayer { match l0_flush_global_state { l0_flush::Inner::Direct { .. 
} => { - let file_contents = inner.file.load_to_io_buf(ctx).await?; + let file_contents = self.file.load_to_io_buf(ctx).await?; let file_contents = file_contents.freeze(); for (key, vec_map) in index.iter() { diff --git a/pageserver/src/tenant/storage_layer/inmemory_layer/vectored_dio_read.rs b/pageserver/src/tenant/storage_layer/inmemory_layer/vectored_dio_read.rs index ea354fc716..27fbc6f5fb 100644 --- a/pageserver/src/tenant/storage_layer/inmemory_layer/vectored_dio_read.rs +++ b/pageserver/src/tenant/storage_layer/inmemory_layer/vectored_dio_read.rs @@ -380,7 +380,7 @@ impl std::fmt::Debug for LogicalReadState { write!(f, "Ongoing({:?})", BufferDebug::from(b as &dyn Buffer)) } LogicalReadState::Ok(b) => write!(f, "Ok({:?})", BufferDebug::from(b as &dyn Buffer)), - LogicalReadState::Error(e) => write!(f, "Error({:?})", e), + LogicalReadState::Error(e) => write!(f, "Error({e:?})"), LogicalReadState::Undefined => write!(f, "Undefined"), } } diff --git a/pageserver/src/tenant/storage_layer/layer.rs b/pageserver/src/tenant/storage_layer/layer.rs index 3d55972017..0be13e67a8 100644 --- a/pageserver/src/tenant/storage_layer/layer.rs +++ b/pageserver/src/tenant/storage_layer/layer.rs @@ -105,7 +105,7 @@ impl std::fmt::Display for Layer { impl std::fmt::Debug for Layer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } diff --git a/pageserver/src/tenant/storage_layer/layer/tests.rs b/pageserver/src/tenant/storage_layer/layer/tests.rs index b6fd4678d6..9bdce163c9 100644 --- a/pageserver/src/tenant/storage_layer/layer/tests.rs +++ b/pageserver/src/tenant/storage_layer/layer/tests.rs @@ -1,6 +1,7 @@ use std::time::UNIX_EPOCH; use pageserver_api::key::{CONTROLFILE_KEY, Key}; +use postgres_ffi::PgMajorVersion; use tokio::task::JoinSet; use utils::completion::{self, Completion}; use utils::id::TimelineId; @@ -10,6 +11,7 @@ use super::*; use crate::context::DownloadBehavior; use crate::tenant::harness::{TenantHarness, test_img}; use crate::tenant::storage_layer::{IoConcurrency, LayerVisibilityHint}; +use crate::tenant::timeline::layer_manager::LayerManagerLockHolder; /// Used in tests to advance a future to wanted await point, and not futher. const ADVANCE: std::time::Duration = std::time::Duration::from_secs(3600); @@ -44,7 +46,7 @@ async fn smoke_test() { .create_test_timeline_with_layers( TimelineId::generate(), Lsn(0x10), - 14, + PgMajorVersion::PG14, &ctx, Default::default(), // in-memory layers Default::default(), @@ -59,7 +61,7 @@ async fn smoke_test() { // there to avoid the timeline being illegally empty let (layer, dummy_layer) = { let mut layers = { - let layers = timeline.layers.read().await; + let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await; layers.likely_resident_layers().cloned().collect::>() }; @@ -215,7 +217,7 @@ async fn smoke_test() { // Simulate GC removing our test layer. 
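Throughout the test diffs that follow, every `layers.read()` / `layers.write()` call gains a `LayerManagerLockHolder` argument. Requiring callers to name themselves makes lock contention attributable to a subsystem. A sketch of the wrapper idea, with holder names and internals assumed from the call sites in this diff (the real type wraps an async lock; `std::sync::RwLock` is used here for brevity):

```rust
use std::sync::RwLock;

// Assumed shape: the real LayerManagerLockHolder has more variants
// (GetPage, Compaction, FlushLoop, ...), per the call sites in this patch.
#[derive(Debug, Clone, Copy)]
enum LockHolder {
    Testing,
    Compaction,
    GarbageCollection,
}

struct LockedLayerManager<T> {
    inner: RwLock<T>,
}

impl<T> LockedLayerManager<T> {
    fn new(value: T) -> Self {
        Self { inner: RwLock::new(value) }
    }

    // Callers must declare who they are; the tag can be logged or exported
    // as a metric to attribute lock waits to a subsystem.
    fn read(&self, holder: LockHolder) -> std::sync::RwLockReadGuard<'_, T> {
        println!("layer map read-locked by {holder:?}");
        self.inner.read().unwrap()
    }

    fn write(&self, holder: LockHolder) -> std::sync::RwLockWriteGuard<'_, T> {
        println!("layer map write-locked by {holder:?}");
        self.inner.write().unwrap()
    }
}

fn main() {
    let layers = LockedLayerManager::new(vec!["layer-a", "layer-b"]);
    let guard = layers.read(LockHolder::Testing);
    println!("{} resident layers", guard.len());
    drop(guard);
    layers.write(LockHolder::GarbageCollection).push("layer-c");
}
```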
{ - let mut g = timeline.layers.write().await; + let mut g = timeline.layers.write(LayerManagerLockHolder::Testing).await; let layers = &[layer]; g.open_mut().unwrap().finish_gc_timeline(layers); @@ -255,13 +257,18 @@ async fn evict_and_wait_on_wanted_deleted() { let (tenant, ctx) = h.load().await; let timeline = tenant - .create_test_timeline(TimelineId::generate(), Lsn(0x10), 14, &ctx) + .create_test_timeline( + TimelineId::generate(), + Lsn(0x10), + PgMajorVersion::PG14, + &ctx, + ) .await .unwrap(); let layer = { let mut layers = { - let layers = timeline.layers.read().await; + let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await; layers.likely_resident_layers().cloned().collect::>() }; @@ -305,7 +312,7 @@ async fn evict_and_wait_on_wanted_deleted() { // assert that once we remove the `layer` from the layer map and drop our reference, // the deletion of the layer in remote_storage happens. { - let mut layers = timeline.layers.write().await; + let mut layers = timeline.layers.write(LayerManagerLockHolder::Testing).await; layers.open_mut().unwrap().finish_gc_timeline(&[layer]); } @@ -340,14 +347,19 @@ fn read_wins_pending_eviction() { let download_span = span.in_scope(|| tracing::info_span!("downloading", timeline_id = 1)); let timeline = tenant - .create_test_timeline(TimelineId::generate(), Lsn(0x10), 14, &ctx) + .create_test_timeline( + TimelineId::generate(), + Lsn(0x10), + PgMajorVersion::PG14, + &ctx, + ) .await .unwrap(); let ctx = ctx.with_scope_timeline(&timeline); let layer = { let mut layers = { - let layers = timeline.layers.read().await; + let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await; layers.likely_resident_layers().cloned().collect::>() }; @@ -473,14 +485,19 @@ fn multiple_pending_evictions_scenario(name: &'static str, in_order: bool) { let download_span = span.in_scope(|| tracing::info_span!("downloading", timeline_id = 1)); let timeline = tenant - .create_test_timeline(TimelineId::generate(), Lsn(0x10), 14, &ctx) + .create_test_timeline( + TimelineId::generate(), + Lsn(0x10), + PgMajorVersion::PG14, + &ctx, + ) .await .unwrap(); let ctx = ctx.with_scope_timeline(&timeline); let layer = { let mut layers = { - let layers = timeline.layers.read().await; + let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await; layers.likely_resident_layers().cloned().collect::>() }; @@ -643,7 +660,12 @@ async fn cancelled_get_or_maybe_download_does_not_cancel_eviction() { let (tenant, ctx) = h.load().await; let timeline = tenant - .create_test_timeline(TimelineId::generate(), Lsn(0x10), 14, &ctx) + .create_test_timeline( + TimelineId::generate(), + Lsn(0x10), + PgMajorVersion::PG14, + &ctx, + ) .await .unwrap(); let ctx = ctx.with_scope_timeline(&timeline); @@ -655,7 +677,7 @@ async fn cancelled_get_or_maybe_download_does_not_cancel_eviction() { let layer = { let mut layers = { - let layers = timeline.layers.read().await; + let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await; layers.likely_resident_layers().cloned().collect::>() }; @@ -729,7 +751,12 @@ async fn evict_and_wait_does_not_wait_for_download() { let download_span = span.in_scope(|| tracing::info_span!("downloading", timeline_id = 1)); let timeline = tenant - .create_test_timeline(TimelineId::generate(), Lsn(0x10), 14, &ctx) + .create_test_timeline( + TimelineId::generate(), + Lsn(0x10), + PgMajorVersion::PG14, + &ctx, + ) .await .unwrap(); let ctx = ctx.with_scope_timeline(&timeline); @@ -741,7 +768,7 @@ async fn 
evict_and_wait_does_not_wait_for_download() { let layer = { let mut layers = { - let layers = timeline.layers.read().await; + let layers = timeline.layers.read(LayerManagerLockHolder::Testing).await; layers.likely_resident_layers().cloned().collect::>() }; @@ -823,7 +850,7 @@ async fn evict_and_wait_does_not_wait_for_download() { #[tokio::test(start_paused = true)] async fn eviction_cancellation_on_drop() { use bytes::Bytes; - use pageserver_api::value::Value; + use wal_decoder::models::value::Value; // this is the runtime on which Layer spawns the blocking tasks on let handle = tokio::runtime::Handle::current(); @@ -835,7 +862,12 @@ async fn eviction_cancellation_on_drop() { let (tenant, ctx) = h.load().await; let timeline = tenant - .create_test_timeline(TimelineId::generate(), Lsn(0x10), 14, &ctx) + .create_test_timeline( + TimelineId::generate(), + Lsn(0x10), + PgMajorVersion::PG14, + &ctx, + ) .await .unwrap(); @@ -862,7 +894,7 @@ async fn eviction_cancellation_on_drop() { let (evicted_layer, not_evicted) = { let mut layers = { - let mut guard = timeline.layers.write().await; + let mut guard = timeline.layers.write(LayerManagerLockHolder::Testing).await; let layers = guard.likely_resident_layers().cloned().collect::>(); // remove the layers from layermap guard.open_mut().unwrap().finish_gc_timeline(&layers); diff --git a/pageserver/src/tenant/storage_layer/merge_iterator.rs b/pageserver/src/tenant/storage_layer/merge_iterator.rs index ea3dea50c3..c15abcdf3f 100644 --- a/pageserver/src/tenant/storage_layer/merge_iterator.rs +++ b/pageserver/src/tenant/storage_layer/merge_iterator.rs @@ -4,8 +4,8 @@ use std::sync::Arc; use anyhow::bail; use pageserver_api::key::Key; -use pageserver_api::value::Value; use utils::lsn::Lsn; +use wal_decoder::models::value::Value; use super::delta_layer::{DeltaLayerInner, DeltaLayerIterator}; use super::image_layer::{ImageLayerInner, ImageLayerIterator}; @@ -402,9 +402,9 @@ impl<'a> MergeIterator<'a> { mod tests { use itertools::Itertools; use pageserver_api::key::Key; - #[cfg(feature = "testing")] - use pageserver_api::record::NeonWalRecord; use utils::lsn::Lsn; + #[cfg(feature = "testing")] + use wal_decoder::models::record::NeonWalRecord; use super::*; use crate::DEFAULT_PG_VERSION; @@ -436,7 +436,6 @@ mod tests { #[tokio::test] async fn merge_in_between() { use bytes::Bytes; - use pageserver_api::value::Value; let harness = TenantHarness::create("merge_iterator_merge_in_between") .await @@ -501,7 +500,6 @@ mod tests { #[tokio::test] async fn delta_merge() { use bytes::Bytes; - use pageserver_api::value::Value; let harness = TenantHarness::create("merge_iterator_delta_merge") .await @@ -578,7 +576,6 @@ mod tests { #[tokio::test] async fn delta_image_mixed_merge() { use bytes::Bytes; - use pageserver_api::value::Value; let harness = TenantHarness::create("merge_iterator_delta_image_mixed_merge") .await diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 6798606141..bec2f0ed52 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -35,7 +35,11 @@ use fail::fail_point; use futures::stream::FuturesUnordered; use futures::{FutureExt, StreamExt}; use handle::ShardTimelineId; -use layer_manager::Shutdown; +use layer_manager::{ + LayerManagerLockHolder, LayerManagerReadGuard, LayerManagerWriteGuard, LockedLayerManager, + Shutdown, +}; + use offload::OffloadError; use once_cell::sync::Lazy; use pageserver_api::config::tenant_conf_defaults::DEFAULT_PITR_INTERVAL; @@ -52,11 +56,9 @@ use 
pageserver_api::models::{ }; use pageserver_api::reltag::{BlockNumber, RelTag}; use pageserver_api::shard::{ShardIdentity, ShardIndex, ShardNumber, TenantShardId}; -#[cfg(test)] -use pageserver_api::value::Value; use postgres_connection::PgConnectionConfig; use postgres_ffi::v14::xlog_utils; -use postgres_ffi::{WAL_SEGMENT_SIZE, to_pg_timestamp}; +use postgres_ffi::{PgMajorVersion, WAL_SEGMENT_SIZE, to_pg_timestamp}; use rand::Rng; use remote_storage::DownloadError; use serde_with::serde_as; @@ -77,12 +79,13 @@ use utils::seqwait::SeqWait; use utils::simple_rcu::{Rcu, RcuReadGuard}; use utils::sync::gate::{Gate, GateGuard}; use utils::{completion, critical, fs_ext, pausable_failpoint}; +#[cfg(test)] +use wal_decoder::models::value::Value; use wal_decoder::serialized_batch::{SerializedValueBatch, ValueMeta}; use self::delete::DeleteTimelineFlow; pub(super) use self::eviction_task::EvictionTaskTenantState; use self::eviction_task::EvictionTaskTimelineState; -use self::layer_manager::LayerManager; use self::logical_size::LogicalSize; use self::walreceiver::{WalReceiver, WalReceiverConf}; use super::remote_timeline_client::RemoteTimelineClient; @@ -92,12 +95,12 @@ use super::storage_layer::{LayerFringe, LayerVisibilityHint, ReadableLayer}; use super::tasks::log_compaction_error; use super::upload_queue::NotInitialized; use super::{ - AttachedTenantConf, BasebackupPrepareSender, GcError, HeatMapTimeline, MaybeOffloaded, + AttachedTenantConf, GcError, HeatMapTimeline, MaybeOffloaded, debug_assert_current_span_has_tenant_and_timeline_id, }; use crate::PERF_TRACE_TARGET; use crate::aux_file::AuxFileSizeEstimator; -use crate::basebackup_cache::BasebackupPrepareRequest; +use crate::basebackup_cache::BasebackupCache; use crate::config::PageServerConf; use crate::context::{ DownloadBehavior, PerfInstrumentFutureExt, RequestContext, RequestContextBuilder, @@ -175,19 +178,19 @@ pub enum LastImageLayerCreationStatus { impl std::fmt::Display for ImageLayerCreationMode { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) + write!(f, "{self:?}") } } /// Temporary function for immutable storage state refactor, ensures we are dropping mutex guard instead of other things. /// Can be removed after all refactors are done. -fn drop_rlock(rlock: tokio::sync::RwLockReadGuard) { +fn drop_layer_manager_rlock(rlock: LayerManagerReadGuard<'_>) { drop(rlock) } /// Temporary function for immutable storage state refactor, ensures we are dropping mutex guard instead of other things. /// Can be removed after all refactors are done. -fn drop_wlock(rlock: tokio::sync::RwLockWriteGuard<'_, T>) { +fn drop_layer_manager_wlock(rlock: LayerManagerWriteGuard<'_>) { drop(rlock) } @@ -198,7 +201,7 @@ pub struct TimelineResources { pub pagestream_throttle_metrics: Arc, pub l0_compaction_trigger: Arc, pub l0_flush_global_state: l0_flush::L0FlushGlobalState, - pub basebackup_prepare_sender: BasebackupPrepareSender, + pub basebackup_cache: Arc, pub feature_resolver: FeatureResolver, } @@ -222,7 +225,7 @@ pub struct Timeline { /// to shards, and is constant through the lifetime of this Timeline. shard_identity: ShardIdentity, - pub pg_version: u32, + pub pg_version: PgMajorVersion, /// The tuple has two elements. /// 1. `LayerFileManager` keeps track of the various physical representations of the layer files (inmem, local, remote). @@ -241,7 +244,7 @@ pub struct Timeline { /// /// In the future, we'll be able to split up the tuple of LayerMap and `LayerFileManager`, /// so that e.g. 
on-demand-download/eviction, and layer spreading, can operate just on `LayerFileManager`. - pub(crate) layers: tokio::sync::RwLock, + pub(crate) layers: LockedLayerManager, last_freeze_at: AtomicLsn, // Atomic would be more appropriate here. @@ -445,7 +448,7 @@ pub struct Timeline { wait_lsn_log_slow: tokio::sync::Semaphore, /// A channel to send async requests to prepare a basebackup for the basebackup cache. - basebackup_prepare_sender: BasebackupPrepareSender, + basebackup_cache: Arc, feature_resolver: FeatureResolver, } @@ -629,7 +632,7 @@ pub enum ReadPathLayerId { impl std::fmt::Display for ReadPathLayerId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - ReadPathLayerId::PersistentLayer(key) => write!(f, "{}", key), + ReadPathLayerId::PersistentLayer(key) => write!(f, "{key}"), ReadPathLayerId::InMemoryLayer(range) => { write!(f, "in-mem {}..{}", range.start, range.end) } @@ -705,7 +708,7 @@ impl MissingKeyError { impl std::fmt::Debug for MissingKeyError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self) + write!(f, "{self}") } } @@ -718,19 +721,19 @@ impl std::fmt::Display for MissingKeyError { )?; if let Some(ref ancestor_lsn) = self.ancestor_lsn { - write!(f, ", ancestor {}", ancestor_lsn)?; + write!(f, ", ancestor {ancestor_lsn}")?; } if let Some(ref query) = self.query { - write!(f, ", query {}", query)?; + write!(f, ", query {query}")?; } if let Some(ref read_path) = self.read_path { - write!(f, "\n{}", read_path)?; + write!(f, "\n{read_path}")?; } if let Some(ref backtrace) = self.backtrace { - write!(f, "\n{}", backtrace)?; + write!(f, "\n{backtrace}")?; } Ok(()) @@ -813,7 +816,7 @@ impl From for FlushLayerError { } #[derive(thiserror::Error, Debug)] -pub(crate) enum GetVectoredError { +pub enum GetVectoredError { #[error("timeline shutting down")] Cancelled, @@ -846,7 +849,7 @@ impl From for GetVectoredError { } #[derive(thiserror::Error, Debug)] -pub(crate) enum GetReadyAncestorError { +pub enum GetReadyAncestorError { #[error("ancestor LSN wait error")] AncestorLsnTimeout(#[from] WaitLsnError), @@ -936,7 +939,7 @@ impl std::fmt::Debug for Timeline { } #[derive(thiserror::Error, Debug, Clone)] -pub(crate) enum WaitLsnError { +pub enum WaitLsnError { // Called on a timeline which is shutting down #[error("Shutdown")] Shutdown, @@ -1055,8 +1058,8 @@ pub(crate) enum WaitLsnWaiter<'a> { /// Argument to [`Timeline::shutdown`]. #[derive(Debug, Clone, Copy)] pub(crate) enum ShutdownMode { - /// Graceful shutdown, may do a lot of I/O as we flush any open layers to disk and then - /// also to remote storage. This method can easily take multiple seconds for a busy timeline. + /// Graceful shutdown, may do a lot of I/O as we flush any open layers to disk. This method can + /// take multiple seconds for a busy timeline. /// /// While we are flushing, we continue to accept read I/O for LSNs ingested before /// the call to [`Timeline::shutdown`]. @@ -1535,7 +1538,10 @@ impl Timeline { /// This method makes no distinction between local and remote layers. /// Hence, the result **does not represent local filesystem usage**. pub(crate) async fn layer_size_sum(&self) -> u64 { - let guard = self.layers.read().await; + let guard = self + .layers + .read(LayerManagerLockHolder::GetLayerMapInfo) + .await; guard.layer_size_sum() } @@ -1845,7 +1851,7 @@ impl Timeline { // time, and this was missed. 
// if write_guard.is_none() { return; } - let Ok(layers_guard) = self.layers.try_read() else { + let Ok(layers_guard) = self.layers.try_read(LayerManagerLockHolder::TryFreezeLayer) else { // Don't block if the layer lock is busy return; }; @@ -1896,16 +1902,11 @@ impl Timeline { return; }; - let Some(current_size) = open_layer.try_len() else { - // Unexpected: since we hold the write guard, nobody else should be writing to this layer, so - // read lock to get size should always succeed. - tracing::warn!("Lock conflict while reading size of open layer"); - return; - }; + let current_size = open_layer.len(); let current_lsn = self.get_last_record_lsn(); - let checkpoint_distance_override = open_layer.tick().await; + let checkpoint_distance_override = open_layer.tick(); if let Some(size_override) = checkpoint_distance_override { if current_size > size_override { @@ -2158,7 +2159,7 @@ impl Timeline { if let ShutdownMode::FreezeAndFlush = mode { let do_flush = if let Some((open, frozen)) = self .layers - .read() + .read(LayerManagerLockHolder::Shutdown) .await .layer_map() .map(|lm| (lm.open_layer.is_some(), lm.frozen_layers.len())) @@ -2262,7 +2263,10 @@ impl Timeline { // Allow any remaining in-memory layers to do cleanup -- until that, they hold the gate // open. let mut write_guard = self.write_lock.lock().await; - self.layers.write().await.shutdown(&mut write_guard); + self.layers + .write(LayerManagerLockHolder::Shutdown) + .await + .shutdown(&mut write_guard); } // Finally wait until any gate-holders are complete. @@ -2365,7 +2369,10 @@ impl Timeline { &self, reset: LayerAccessStatsReset, ) -> Result { - let guard = self.layers.read().await; + let guard = self + .layers + .read(LayerManagerLockHolder::GetLayerMapInfo) + .await; let layer_map = guard.layer_map()?; let mut in_memory_layers = Vec::with_capacity(layer_map.frozen_layers.len() + 1); if let Some(open_layer) = &layer_map.open_layer { @@ -2493,6 +2500,13 @@ impl Timeline { .unwrap_or(self.conf.default_tenant_conf.basebackup_cache_enabled) } + /// Try to get a basebackup from the on-disk cache. + pub(crate) async fn get_cached_basebackup(&self, lsn: Lsn) -> Option { + self.basebackup_cache + .get(self.tenant_shard_id.tenant_id, self.timeline_id, lsn) + .await + } + /// Prepare basebackup for the given LSN and store it in the basebackup cache. /// The method is asynchronous and returns immediately. /// The actual basebackup preparation is performed in the background @@ -2506,18 +2520,16 @@ impl Timeline { // Preparing basebackup doesn't make sense for shards other than shard zero. return; } - - let res = self - .basebackup_prepare_sender - .send(BasebackupPrepareRequest { - tenant_shard_id: self.tenant_shard_id, - timeline_id: self.timeline_id, - lsn, - }); - if let Err(e) = res { - // May happen during shutdown, it's not critical. - info!("Failed to send shutdown checkpoint: {e:#}"); + if !self.is_active() { + // May happen during initial timeline creation. + // Such timeline is not in the global timeline map yet, + // so basebackup cache will not be able to find it. + // TODO(diko): We can prepare such timelines in finish_creation(). 
+ return; } + + self.basebackup_cache + .send_prepare(self.tenant_shard_id, self.timeline_id, lsn); } } @@ -2899,7 +2911,7 @@ impl Timeline { shard_identity: ShardIdentity, walredo_mgr: Option>, resources: TimelineResources, - pg_version: u32, + pg_version: PgMajorVersion, state: TimelineState, attach_wal_lag_cooldown: Arc>, create_idempotency: crate::tenant::CreateTimelineIdempotency, @@ -3074,7 +3086,7 @@ impl Timeline { wait_lsn_log_slow: tokio::sync::Semaphore::new(1), - basebackup_prepare_sender: resources.basebackup_prepare_sender, + basebackup_cache: resources.basebackup_cache, feature_resolver: resources.feature_resolver, }; @@ -3225,7 +3237,7 @@ impl Timeline { /// Initialize with an empty layer map. Used when creating a new timeline. pub(super) fn init_empty_layer_map(&self, start_lsn: Lsn) { - let mut layers = self.layers.try_write().expect( + let mut layers = self.layers.try_write(LayerManagerLockHolder::Init).expect( "in the context where we call this function, no other task has access to the object", ); layers @@ -3245,7 +3257,10 @@ impl Timeline { use init::Decision::*; use init::{Discovered, DismissedLayer}; - let mut guard = self.layers.write().await; + let mut guard = self + .layers + .write(LayerManagerLockHolder::LoadLayerMap) + .await; let timer = self.metrics.load_layer_map_histo.start_timer(); @@ -3400,10 +3415,6 @@ impl Timeline { // TenantShard::create_timeline will wait for these uploads to happen before returning, or // on retry. - // Now that we have the full layer map, we may calculate the visibility of layers within it (a global scan) - drop(guard); // drop write lock, update_layer_visibility will take a read lock. - self.update_layer_visibility().await?; - info!( "loaded layer map with {} layers at {}, total physical size: {}", num_layers, disk_consistent_lsn, total_physical_size @@ -3862,7 +3873,10 @@ impl Timeline { &self, layer_name: &LayerName, ) -> Result, layer_manager::Shutdown> { - let guard = self.layers.read().await; + let guard = self + .layers + .read(LayerManagerLockHolder::GetLayerMapInfo) + .await; let layer = guard .layer_map()? .iter_historic_layers() @@ -3895,7 +3909,10 @@ impl Timeline { return None; } - let guard = self.layers.read().await; + let guard = self + .layers + .read(LayerManagerLockHolder::GenerateHeatmap) + .await; // Firstly, if there's any heatmap left over from when this location // was a secondary, take that into account. Keep layers that are: @@ -3993,7 +4010,10 @@ impl Timeline { } pub(super) async fn generate_unarchival_heatmap(&self, end_lsn: Lsn) -> PreviousHeatmap { - let guard = self.layers.read().await; + let guard = self + .layers + .read(LayerManagerLockHolder::GenerateHeatmap) + .await; let now = SystemTime::now(); let mut heatmap_layers = Vec::default(); @@ -4335,7 +4355,7 @@ impl Timeline { query: &VersionedKeySpaceQuery, ) -> Result { let mut fringe = LayerFringe::new(); - let guard = self.layers.read().await; + let guard = self.layers.read(LayerManagerLockHolder::GetPage).await; match query { VersionedKeySpaceQuery::Uniform { keyspace, lsn } => { @@ -4438,7 +4458,7 @@ impl Timeline { // required for correctness, but avoids visiting extra layers // which turns out to be a perf bottleneck in some cases. if !unmapped_keyspace.is_empty() { - let guard = timeline.layers.read().await; + let guard = timeline.layers.read(LayerManagerLockHolder::GetPage).await; guard.update_search_fringe(&unmapped_keyspace, cont_lsn, &mut fringe)?; // It's safe to drop the layer map lock after planning the next round of reads. 
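The comment that closes the hunk above describes the read path's central trick: the layer-map lock is held only long enough to plan which layers the next round of reads must visit, and is dropped before any IO starts. A condensed sketch of that plan/execute split (types are simplified stand-ins; the real code plans into a `LayerFringe` and reads through layer handles):

```rust
use std::sync::RwLock;

// Simplified stand-ins: a "layer" is just a name, and a planned read is the
// layer plus the key range it should serve.
#[derive(Clone)]
struct Layer(String);

struct PlannedRead {
    layer: Layer,
    key_range: std::ops::Range<u32>,
}

fn plan_reads(layer_map: &RwLock<Vec<Layer>>, keys: std::ops::Range<u32>) -> Vec<PlannedRead> {
    // Phase 1: hold the lock only while planning. Cloning the (cheap) layer
    // handles lets us release the lock before any IO starts.
    let guard = layer_map.read().unwrap();
    guard
        .iter()
        .map(|layer| PlannedRead { layer: layer.clone(), key_range: keys.clone() })
        .collect()
    // guard dropped here, before execute_reads runs
}

fn execute_reads(plan: Vec<PlannedRead>) {
    // Phase 2: slow IO happens with no lock held, so writers (flush,
    // compaction) are not blocked behind reads.
    for read in plan {
        println!("reading {:?} from {}", read.key_range, read.layer.0);
    }
}

fn main() {
    let layer_map = RwLock::new(vec![Layer("000-100".into()), Layer("100-200".into())]);
    let plan = plan_reads(&layer_map, 0..50);
    execute_reads(plan);
}
```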
@@ -4548,7 +4568,10 @@ impl Timeline { _guard: &tokio::sync::MutexGuard<'_, Option>, ctx: &RequestContext, ) -> anyhow::Result> { - let mut guard = self.layers.write().await; + let mut guard = self + .layers + .write(LayerManagerLockHolder::GetLayerForWrite) + .await; let last_record_lsn = self.get_last_record_lsn(); ensure!( @@ -4590,7 +4613,10 @@ impl Timeline { write_lock: &mut tokio::sync::MutexGuard<'_, Option>, ) -> Result { let frozen = { - let mut guard = self.layers.write().await; + let mut guard = self + .layers + .write(LayerManagerLockHolder::TryFreezeLayer) + .await; guard .open_mut()? .try_freeze_in_memory_layer(at, &self.last_freeze_at, write_lock, &self.metrics) @@ -4631,7 +4657,12 @@ impl Timeline { ctx: &RequestContext, ) { // Subscribe to L0 delta layer updates, for compaction backpressure. - let mut watch_l0 = match self.layers.read().await.layer_map() { + let mut watch_l0 = match self + .layers + .read(LayerManagerLockHolder::FlushLoop) + .await + .layer_map() + { Ok(lm) => lm.watch_level0_deltas(), Err(Shutdown) => return, }; @@ -4668,7 +4699,7 @@ impl Timeline { // Fetch the next layer to flush, if any. let (layer, l0_count, frozen_count, frozen_size) = { - let layers = self.layers.read().await; + let layers = self.layers.read(LayerManagerLockHolder::FlushLoop).await; let Ok(lm) = layers.layer_map() else { info!("dropping out of flush loop for timeline shutdown"); return; @@ -4964,7 +4995,10 @@ impl Timeline { // in-memory layer from the map now. The flushed layer is stored in // the mapping in `create_delta_layer`. { - let mut guard = self.layers.write().await; + let mut guard = self + .layers + .write(LayerManagerLockHolder::FlushFrozenLayer) + .await; guard.open_mut()?.finish_flush_l0_layer( delta_layer_to_add.as_ref(), @@ -5166,7 +5200,11 @@ impl Timeline { } let (dense_ks, sparse_ks) = self.collect_keyspace(lsn, ctx).await?; - let dense_partitioning = dense_ks.partition(&self.shard_identity, partition_size); + let dense_partitioning = dense_ks.partition( + &self.shard_identity, + partition_size, + postgres_ffi::BLCKSZ as u64, + ); let sparse_partitioning = SparseKeyPartitioning { parts: vec![sparse_ks], }; // no partitioning for metadata keys for now @@ -5179,7 +5217,7 @@ impl Timeline { async fn time_for_new_image_layer(&self, partition: &KeySpace, lsn: Lsn) -> bool { let threshold = self.get_image_creation_threshold(); - let guard = self.layers.read().await; + let guard = self.layers.read(LayerManagerLockHolder::Compaction).await; let Ok(layers) = guard.layer_map() else { return false; }; @@ -5597,7 +5635,7 @@ impl Timeline { if let ImageLayerCreationMode::Force = mode { // When forced to create image layers, we might try and create them where they already // exist. This mode is only used in tests/debug. - let layers = self.layers.read().await; + let layers = self.layers.read(LayerManagerLockHolder::Compaction).await; if layers.contains_key(&PersistentLayerKey { key_range: img_range.clone(), lsn_range: PersistentLayerDesc::image_layer_lsn_range(lsn), @@ -5722,7 +5760,7 @@ impl Timeline { let image_layers = batch_image_writer.finish(self, ctx).await?; - let mut guard = self.layers.write().await; + let mut guard = self.layers.write(LayerManagerLockHolder::Compaction).await; // FIXME: we could add the images to be uploaded *before* returning from here, but right // now they are being scheduled outside of write lock; current way is inconsistent with @@ -5730,7 +5768,7 @@ impl Timeline { guard .open_mut()? 
.track_new_image_layers(&image_layers, &self.metrics); - drop_wlock(guard); + drop_layer_manager_wlock(guard); let duration = timer.stop_and_record(); // Creating image layers may have caused some previously visible layers to be covered @@ -5894,7 +5932,7 @@ impl Drop for Timeline { if let Ok(mut gc_info) = ancestor.gc_info.write() { if !gc_info.remove_child_not_offloaded(self.timeline_id) { tracing::error!(tenant_id = %self.tenant_shard_id.tenant_id, shard_id = %self.tenant_shard_id.shard_slug(), timeline_id = %self.timeline_id, - "Couldn't remove retain_lsn entry from offloaded timeline's parent: already removed"); + "Couldn't remove retain_lsn entry from timeline's parent on drop: already removed"); } } } @@ -6100,7 +6138,7 @@ impl Timeline { layers_to_remove: &[Layer], ) -> Result<(), CompactionError> { let mut guard = tokio::select! { - guard = self.layers.write() => guard, + guard = self.layers.write(LayerManagerLockHolder::Compaction) => guard, _ = self.cancel.cancelled() => { return Err(CompactionError::ShuttingDown); } @@ -6149,7 +6187,7 @@ impl Timeline { self.remote_client .schedule_compaction_update(&remove_layers, new_deltas)?; - drop_wlock(guard); + drop_layer_manager_wlock(guard); Ok(()) } @@ -6159,7 +6197,7 @@ impl Timeline { mut replace_layers: Vec<(Layer, ResidentLayer)>, mut drop_layers: Vec, ) -> Result<(), CompactionError> { - let mut guard = self.layers.write().await; + let mut guard = self.layers.write(LayerManagerLockHolder::Compaction).await; // Trim our lists in case our caller (compaction) raced with someone else (GC) removing layers: we want // to avoid double-removing, and avoid rewriting something that was removed. @@ -6498,7 +6536,7 @@ impl Timeline { debug!("retain_lsns: {:?}", retain_lsns); - let mut layers_to_remove = Vec::new(); + let max_retain_lsn = retain_lsns.iter().max(); // Scan all layers in the timeline (remote or on-disk). // @@ -6508,105 +6546,110 @@ impl Timeline { // 3. it doesn't need to be retained for 'retain_lsns'; // 4. it does not need to be kept for LSNs holding valid leases. // 5. newer on-disk image layers cover the layer's whole key range - // - // TODO holding a write lock is too agressive and avoidable - let mut guard = self.layers.write().await; - let layers = guard.layer_map()?; - 'outer: for l in layers.iter_historic_layers() { - result.layers_total += 1; + let layers_to_remove = { + let mut layers_to_remove = Vec::new(); - // 1. Is it newer than GC horizon cutoff point? - if l.get_lsn_range().end > space_cutoff { - info!( - "keeping {} because it's newer than space_cutoff {}", - l.layer_name(), - space_cutoff, - ); - result.layers_needed_by_cutoff += 1; - continue 'outer; - } + let guard = self + .layers + .read(LayerManagerLockHolder::GarbageCollection) + .await; + let layers = guard.layer_map()?; + 'outer: for l in layers.iter_historic_layers() { + result.layers_total += 1; - // 2. It is newer than PiTR cutoff point? - if l.get_lsn_range().end > time_cutoff { - info!( - "keeping {} because it's newer than time_cutoff {}", - l.layer_name(), - time_cutoff, - ); - result.layers_needed_by_pitr += 1; - continue 'outer; - } - - // 3. Is it needed by a child branch? - // NOTE With that we would keep data that - // might be referenced by child branches forever. - // We can track this in child timeline GC and delete parent layers when - // they are no longer needed. This might be complicated with long inheritance chains. 
- // - // TODO Vec is not a great choice for `retain_lsns` - for retain_lsn in &retain_lsns { - // start_lsn is inclusive - if &l.get_lsn_range().start <= retain_lsn { - info!( - "keeping {} because it's still might be referenced by child branch forked at {} is_dropped: xx is_incremental: {}", + // 1. Is it newer than GC horizon cutoff point? + if l.get_lsn_range().end > space_cutoff { + debug!( + "keeping {} because it's newer than space_cutoff {}", l.layer_name(), - retain_lsn, - l.is_incremental(), + space_cutoff, ); - result.layers_needed_by_branches += 1; + result.layers_needed_by_cutoff += 1; continue 'outer; } - } - // 4. Is there a valid lease that requires us to keep this layer? - if let Some(lsn) = &max_lsn_with_valid_lease { - // keep if layer start <= any of the lease - if &l.get_lsn_range().start <= lsn { - info!( - "keeping {} because there is a valid lease preventing GC at {}", + // 2. It is newer than PiTR cutoff point? + if l.get_lsn_range().end > time_cutoff { + debug!( + "keeping {} because it's newer than time_cutoff {}", l.layer_name(), - lsn, + time_cutoff, ); - result.layers_needed_by_leases += 1; + result.layers_needed_by_pitr += 1; continue 'outer; } + + // 3. Is it needed by a child branch? + // NOTE With that we would keep data that + // might be referenced by child branches forever. + // We can track this in child timeline GC and delete parent layers when + // they are no longer needed. This might be complicated with long inheritance chains. + if let Some(retain_lsn) = max_retain_lsn { + // start_lsn is inclusive + if &l.get_lsn_range().start <= retain_lsn { + debug!( + "keeping {} because it's still might be referenced by child branch forked at {} is_dropped: xx is_incremental: {}", + l.layer_name(), + retain_lsn, + l.is_incremental(), + ); + result.layers_needed_by_branches += 1; + continue 'outer; + } + } + + // 4. Is there a valid lease that requires us to keep this layer? + if let Some(lsn) = &max_lsn_with_valid_lease { + // keep if layer start <= any of the lease + if &l.get_lsn_range().start <= lsn { + debug!( + "keeping {} because there is a valid lease preventing GC at {}", + l.layer_name(), + lsn, + ); + result.layers_needed_by_leases += 1; + continue 'outer; + } + } + + // 5. Is there a later on-disk layer for this relation? + // + // The end-LSN is exclusive, while disk_consistent_lsn is + // inclusive. For example, if disk_consistent_lsn is 100, it is + // OK for a delta layer to have end LSN 101, but if the end LSN + // is 102, then it might not have been fully flushed to disk + // before crash. + // + // For example, imagine that the following layers exist: + // + // 1000 - image (A) + // 1000-2000 - delta (B) + // 2000 - image (C) + // 2000-3000 - delta (D) + // 3000 - image (E) + // + // If GC horizon is at 2500, we can remove layers A and B, but + // we cannot remove C, even though it's older than 2500, because + // the delta layer 2000-3000 depends on it. + if !layers + .image_layer_exists(&l.get_key_range(), &(l.get_lsn_range().end..new_gc_cutoff)) + { + debug!("keeping {} because it is the latest layer", l.layer_name()); + result.layers_not_updated += 1; + continue 'outer; + } + + // We didn't find any reason to keep this file, so remove it. + info!( + "garbage collecting {} is_dropped: xx is_incremental: {}", + l.layer_name(), + l.is_incremental(), + ); + layers_to_remove.push(l); } - // 5. Is there a later on-disk layer for this relation? - // - // The end-LSN is exclusive, while disk_consistent_lsn is - // inclusive. 
For example, if disk_consistent_lsn is 100, it is - // OK for a delta layer to have end LSN 101, but if the end LSN - // is 102, then it might not have been fully flushed to disk - // before crash. - // - // For example, imagine that the following layers exist: - // - // 1000 - image (A) - // 1000-2000 - delta (B) - // 2000 - image (C) - // 2000-3000 - delta (D) - // 3000 - image (E) - // - // If GC horizon is at 2500, we can remove layers A and B, but - // we cannot remove C, even though it's older than 2500, because - // the delta layer 2000-3000 depends on it. - if !layers - .image_layer_exists(&l.get_key_range(), &(l.get_lsn_range().end..new_gc_cutoff)) - { - info!("keeping {} because it is the latest layer", l.layer_name()); - result.layers_not_updated += 1; - continue 'outer; - } - - // We didn't find any reason to keep this file, so remove it. - info!( - "garbage collecting {} is_dropped: xx is_incremental: {}", - l.layer_name(), - l.is_incremental(), - ); - layers_to_remove.push(l); - } + layers_to_remove + }; if !layers_to_remove.is_empty() { // Persist the new GC cutoff value before we actually remove anything. @@ -6622,15 +6665,19 @@ impl Timeline { } })?; + let mut guard = self + .layers + .write(LayerManagerLockHolder::GarbageCollection) + .await; + let gc_layers = layers_to_remove .iter() - .map(|x| guard.get_from_desc(x)) + .flat_map(|desc| guard.try_get_from_key(&desc.key()).cloned()) .collect::>(); result.layers_removed = gc_layers.len() as u64; self.remote_client.schedule_gc_update(&gc_layers)?; - guard.open_mut()?.finish_gc_timeline(&gc_layers); #[cfg(feature = "testing")] @@ -6812,7 +6859,10 @@ impl Timeline { use pageserver_api::models::DownloadRemoteLayersTaskState; let remaining = { - let guard = self.layers.read().await; + let guard = self + .layers + .read(LayerManagerLockHolder::GetLayerMapInfo) + .await; let Ok(lm) = guard.layer_map() else { // technically here we could look into iterating accessible layers, but downloading // all layers of a shutdown timeline makes no sense regardless. @@ -6918,7 +6968,7 @@ impl Timeline { impl Timeline { /// Returns non-remote layers for eviction. 
pub(crate) async fn get_local_layers_for_disk_usage_eviction(&self) -> DiskUsageEvictionInfo { - let guard = self.layers.read().await; + let guard = self.layers.read(LayerManagerLockHolder::Eviction).await; let mut max_layer_size: Option = None; let resident_layers = guard @@ -7019,7 +7069,7 @@ impl Timeline { let image_layer = Layer::finish_creating(self.conf, self, desc, &path)?; info!("force created image layer {}", image_layer.local_path()); { - let mut guard = self.layers.write().await; + let mut guard = self.layers.write(LayerManagerLockHolder::Testing).await; guard .open_mut() .unwrap() @@ -7082,7 +7132,7 @@ impl Timeline { let delta_layer = Layer::finish_creating(self.conf, self, desc, &path)?; info!("force created delta layer {}", delta_layer.local_path()); { - let mut guard = self.layers.write().await; + let mut guard = self.layers.write(LayerManagerLockHolder::Testing).await; guard .open_mut() .unwrap() @@ -7127,9 +7177,7 @@ impl Timeline { if let Some(end) = layer_end_lsn { assert!( end <= last_record_lsn, - "advance last record lsn before inserting a layer, end_lsn={}, last_record_lsn={}", - end, - last_record_lsn, + "advance last record lsn before inserting a layer, end_lsn={end}, last_record_lsn={last_record_lsn}", ); } @@ -7177,7 +7225,7 @@ impl Timeline { // Link the layer to the layer map { - let mut guard = self.layers.write().await; + let mut guard = self.layers.write(LayerManagerLockHolder::Testing).await; let layer_map = guard.open_mut().unwrap(); layer_map.force_insert_in_memory_layer(Arc::new(layer)); } @@ -7194,7 +7242,7 @@ impl Timeline { io_concurrency: IoConcurrency, ) -> anyhow::Result> { let mut all_data = Vec::new(); - let guard = self.layers.read().await; + let guard = self.layers.read(LayerManagerLockHolder::Testing).await; for layer in guard.layer_map()?.iter_historic_layers() { if !layer.is_delta() && layer.image_layer_lsn() == lsn { let layer = guard.get_from_desc(&layer); @@ -7223,7 +7271,7 @@ impl Timeline { self: &Arc, ) -> anyhow::Result> { let mut layers = Vec::new(); - let guard = self.layers.read().await; + let guard = self.layers.read(LayerManagerLockHolder::Testing).await; for layer in guard.layer_map()?.iter_historic_layers() { layers.push(layer.key()); } @@ -7315,7 +7363,7 @@ impl TimelineWriter<'_> { .tl .get_layer_for_write(at, &self.write_guard, ctx) .await?; - let initial_size = layer.size().await?; + let initial_size = layer.len(); let last_freeze_at = self.last_freeze_at.load(); self.write_guard.replace(TimelineWriterState::new( @@ -7335,7 +7383,7 @@ impl TimelineWriter<'_> { let l0_count = self .tl .layers - .read() + .read(LayerManagerLockHolder::GetLayerMapInfo) .await .layer_map()? 
.level0_deltas() @@ -7543,17 +7591,19 @@ mod tests { use std::sync::Arc; use pageserver_api::key::Key; - use pageserver_api::value::Value; + use postgres_ffi::PgMajorVersion; use std::iter::Iterator; use tracing::Instrument; use utils::id::TimelineId; use utils::lsn::Lsn; + use wal_decoder::models::value::Value; use super::HeatMapTimeline; use crate::context::RequestContextBuilder; use crate::tenant::harness::{TenantHarness, test_img}; use crate::tenant::layer_map::LayerMap; use crate::tenant::storage_layer::{Layer, LayerName, LayerVisibilityHint}; + use crate::tenant::timeline::layer_manager::LayerManagerLockHolder; use crate::tenant::timeline::{DeltaLayerTestDesc, EvictionError}; use crate::tenant::{PreviousHeatmap, Timeline}; @@ -7616,7 +7666,7 @@ mod tests { .create_test_timeline_with_layers( TimelineId::generate(), Lsn(0x10), - 14, + PgMajorVersion::PG14, &ctx, Vec::new(), // in-memory layers delta_layers, @@ -7661,7 +7711,7 @@ mod tests { // Evict all the layers and stash the old heatmap in the timeline. // This simulates a migration to a cold secondary location. - let guard = timeline.layers.read().await; + let guard = timeline.layers.read(LayerManagerLockHolder::Testing).await; let mut all_layers = Vec::new(); let forever = std::time::Duration::from_secs(120); for layer in guard.likely_resident_layers() { @@ -7752,7 +7802,7 @@ mod tests { .create_test_timeline_with_layers( TimelineId::generate(), Lsn(0x10), - 14, + PgMajorVersion::PG14, &ctx, Vec::new(), // in-memory layers delta_layers, @@ -7783,7 +7833,7 @@ mod tests { }))); // Evict all the layers in the previous heatmap - let guard = timeline.layers.read().await; + let guard = timeline.layers.read(LayerManagerLockHolder::Testing).await; let forever = std::time::Duration::from_secs(120); for layer in guard.likely_resident_layers() { layer.evict_and_wait(forever).await.unwrap(); @@ -7812,7 +7862,12 @@ mod tests { let (tenant, ctx) = harness.load().await; let timeline = tenant - .create_test_timeline(TimelineId::generate(), Lsn(0x10), 14, &ctx) + .create_test_timeline( + TimelineId::generate(), + Lsn(0x10), + PgMajorVersion::PG14, + &ctx, + ) .await .unwrap(); @@ -7846,7 +7901,10 @@ mod tests { } async fn find_some_layer(timeline: &Timeline) -> Layer { - let layers = timeline.layers.read().await; + let layers = timeline + .layers + .read(LayerManagerLockHolder::GetLayerMapInfo) + .await; let desc = layers .layer_map() .unwrap() diff --git a/pageserver/src/tenant/timeline/analysis.rs b/pageserver/src/tenant/timeline/analysis.rs index 96864ec44b..90c70086ed 100644 --- a/pageserver/src/tenant/timeline/analysis.rs +++ b/pageserver/src/tenant/timeline/analysis.rs @@ -4,6 +4,7 @@ use std::ops::Range; use utils::lsn::Lsn; use super::Timeline; +use crate::tenant::timeline::layer_manager::LayerManagerLockHolder; #[derive(serde::Serialize)] pub(crate) struct RangeAnalysis { @@ -24,7 +25,10 @@ impl Timeline { let num_of_l0; let all_layer_files = { - let guard = self.layers.read().await; + let guard = self + .layers + .read(LayerManagerLockHolder::GetLayerMapInfo) + .await; num_of_l0 = guard.layer_map().unwrap().level0_deltas().len(); guard.all_persistent_layers() }; diff --git a/pageserver/src/tenant/timeline/compaction.rs b/pageserver/src/tenant/timeline/compaction.rs index 72ca0f9cc1..1b8e5f4b9c 100644 --- a/pageserver/src/tenant/timeline/compaction.rs +++ b/pageserver/src/tenant/timeline/compaction.rs @@ -9,7 +9,7 @@ use std::ops::{Deref, Range}; use std::sync::Arc; use std::time::{Duration, Instant}; -use 
super::layer_manager::LayerManager; +use super::layer_manager::{LayerManagerLockHolder, LayerManagerReadGuard}; use super::{ CompactFlags, CompactOptions, CompactionError, CreateImageLayersError, DurationRecorder, GetVectoredError, ImageLayerCreationMode, LastImageLayerCreationStatus, RecordedDuration, @@ -29,9 +29,7 @@ use pageserver_api::config::tenant_conf_defaults::DEFAULT_CHECKPOINT_DISTANCE; use pageserver_api::key::{KEY_SIZE, Key}; use pageserver_api::keyspace::{KeySpace, ShardedRange}; use pageserver_api::models::{CompactInfoResponse, CompactKeyRange}; -use pageserver_api::record::NeonWalRecord; use pageserver_api::shard::{ShardCount, ShardIdentity, TenantShardId}; -use pageserver_api::value::Value; use pageserver_compaction::helpers::{fully_contains, overlaps_with}; use pageserver_compaction::interface::*; use serde::Serialize; @@ -41,6 +39,8 @@ use tracing::{Instrument, debug, error, info, info_span, trace, warn}; use utils::critical; use utils::id::TimelineId; use utils::lsn::Lsn; +use wal_decoder::models::record::NeonWalRecord; +use wal_decoder::models::value::Value; use crate::context::{AccessStatsBehavior, RequestContext, RequestContextBuilder}; use crate::page_cache; @@ -62,7 +62,7 @@ use crate::tenant::storage_layer::{ use crate::tenant::tasks::log_compaction_error; use crate::tenant::timeline::{ DeltaLayerWriter, ImageLayerCreationOutcome, ImageLayerWriter, IoConcurrency, Layer, - ResidentLayer, drop_rlock, + ResidentLayer, drop_layer_manager_rlock, }; use crate::tenant::{DeltaLayer, MaybeOffloaded}; use crate::virtual_file::{MaybeFatalIo, VirtualFile}; @@ -314,7 +314,10 @@ impl GcCompactionQueue { .unwrap_or(Lsn::INVALID); let layers = { - let guard = timeline.layers.read().await; + let guard = timeline + .layers + .read(LayerManagerLockHolder::GetLayerMapInfo) + .await; let layer_map = guard.layer_map()?; layer_map.iter_historic_layers().collect_vec() }; @@ -408,7 +411,10 @@ impl GcCompactionQueue { timeline: &Arc, lsn: Lsn, ) -> Result { - let guard = timeline.layers.read().await; + let guard = timeline + .layers + .read(LayerManagerLockHolder::GetLayerMapInfo) + .await; let layer_map = guard.layer_map()?; let layers = layer_map.iter_historic_layers().collect_vec(); let mut size = 0; @@ -851,7 +857,7 @@ impl KeyHistoryRetention { } let layer_generation; { - let guard = tline.layers.read().await; + let guard = tline.layers.read(LayerManagerLockHolder::Compaction).await; if !guard.contains_key(key) { return false; } @@ -971,7 +977,7 @@ impl KeyHistoryRetention { tline .reconstruct_value(key, lsn, data, RedoAttemptType::GcCompaction) .await - .with_context(|| format!("verification failed for key {} at lsn {}", key, lsn))?; + .with_context(|| format!("verification failed for key {key} at lsn {lsn}"))?; Ok(()) } @@ -1282,7 +1288,10 @@ impl Timeline { // We do the repartition on the L0-L1 boundary. All data below the boundary // are compacted by L0 with low read amplification, thus making the `repartition` // function run fast. 
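The conversions in this hunk are typical of the tagged lock API this PR introduces (see `layer_manager.rs` later in the diff). A minimal usage sketch, assuming only the `LockedLayerManager` and `LayerManagerLockHolder` types defined in this diff; the function name is hypothetical:

    // Every acquisition names the operation holding the lock; the guard
    // derefs to LayerManager, and its Drop impl releases the lock and then
    // warns if it was held past the threshold.
    async fn count_l0_deltas(layers: &LockedLayerManager) -> usize {
        let guard = layers
            .read(LayerManagerLockHolder::GetLayerMapInfo)
            .await;
        let n = guard.layer_map().unwrap().level0_deltas().len();
        drop(guard); // lock is released first, then the hold time is checked
        n
    }

The guards wrap tokio's `RwLock` guards in `ManuallyDrop` so that `Drop` can release the lock first and only then compute and log the hold duration, keeping the warning itself out of the critical section.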
- let guard = self.layers.read().await; + let guard = self + .layers + .read(LayerManagerLockHolder::GetLayerMapInfo) + .await; guard .all_persistent_layers() .iter() @@ -1461,7 +1470,7 @@ impl Timeline { let latest_gc_cutoff = self.get_applied_gc_cutoff_lsn(); let pitr_cutoff = self.gc_info.read().unwrap().cutoffs.time; - let layers = self.layers.read().await; + let layers = self.layers.read(LayerManagerLockHolder::Compaction).await; let layers_iter = layers.layer_map()?.iter_historic_layers(); let (layers_total, mut layers_checked) = (layers_iter.len(), 0); for layer_desc in layers_iter { @@ -1722,7 +1731,10 @@ impl Timeline { // are implicitly left visible, because LayerVisibilityHint's default is Visible, and we never modify it here. // Note that L0 deltas _can_ be covered by image layers, but we consider them 'visible' because we anticipate that // they will be subject to L0->L1 compaction in the near future. - let layer_manager = self.layers.read().await; + let layer_manager = self + .layers + .read(LayerManagerLockHolder::GetLayerMapInfo) + .await; let layer_map = layer_manager.layer_map()?; let readable_points = { @@ -1775,7 +1787,7 @@ impl Timeline { }; let begin = tokio::time::Instant::now(); - let phase1_layers_locked = self.layers.read().await; + let phase1_layers_locked = self.layers.read(LayerManagerLockHolder::Compaction).await; let now = tokio::time::Instant::now(); stats.read_lock_acquisition_micros = DurationRecorder::Recorded(RecordedDuration(now - begin), now); @@ -1803,7 +1815,7 @@ impl Timeline { /// Level0 files first phase of compaction, explained in the [`Self::compact_legacy`] comment. async fn compact_level0_phase1<'a>( self: &'a Arc, - guard: tokio::sync::RwLockReadGuard<'a, LayerManager>, + guard: LayerManagerReadGuard<'a>, mut stats: CompactLevel0Phase1StatsBuilder, target_file_size: u64, force_compaction_ignore_threshold: bool, @@ -2029,7 +2041,7 @@ impl Timeline { holes }; stats.read_lock_held_compute_holes_micros = stats.read_lock_held_key_sort_micros.till_now(); - drop_rlock(guard); + drop_layer_manager_rlock(guard); if self.cancel.is_cancelled() { return Err(CompactionError::ShuttingDown); @@ -2469,7 +2481,7 @@ impl Timeline { // Find the top of the historical layers let end_lsn = { - let guard = self.layers.read().await; + let guard = self.layers.read(LayerManagerLockHolder::Compaction).await; let layers = guard.layer_map()?; let l0_deltas = layers.level0_deltas(); @@ -2635,15 +2647,15 @@ impl Timeline { use std::fmt::Write; let mut output = String::new(); if let Some((key, _, _)) = replay_history.first() { - write!(output, "key={} ", key).unwrap(); + write!(output, "key={key} ").unwrap(); let mut cnt = 0; for (_, lsn, val) in replay_history { if val.is_image() { - write!(output, "i@{} ", lsn).unwrap(); + write!(output, "i@{lsn} ").unwrap(); } else if val.will_init() { - write!(output, "di@{} ", lsn).unwrap(); + write!(output, "di@{lsn} ").unwrap(); } else { - write!(output, "d@{} ", lsn).unwrap(); + write!(output, "d@{lsn} ").unwrap(); } cnt += 1; if cnt >= 128 { @@ -3008,7 +3020,7 @@ impl Timeline { } split_key_ranges.sort(); let all_layers = { - let guard = self.layers.read().await; + let guard = self.layers.read(LayerManagerLockHolder::Compaction).await; let layer_map = guard.layer_map()?; layer_map.iter_historic_layers().collect_vec() }; @@ -3112,12 +3124,12 @@ impl Timeline { .await?; let jobs_len = jobs.len(); for (idx, job) in jobs.into_iter().enumerate() { - info!( - "running enhanced gc bottom-most compaction, sub-compaction {}/{}", - idx + 1, 
- jobs_len - ); + let sub_compaction_progress = format!("{}/{}", idx + 1, jobs_len); self.compact_with_gc_inner(cancel, job, ctx, yield_for_l0) + .instrument(info_span!( + "sub_compaction", + sub_compaction_progress = sub_compaction_progress + )) .await?; } if jobs_len == 0 { @@ -3185,7 +3197,10 @@ impl Timeline { // 1. If a layer is in the selection, all layers below it are in the selection. // 2. Inferred from (1), for each key in the layer selection, the value can be reconstructed only with the layers in the layer selection. let job_desc = { - let guard = self.layers.read().await; + let guard = self + .layers + .read(LayerManagerLockHolder::GarbageCollection) + .await; let layers = guard.layer_map()?; let gc_info = self.gc_info.read().unwrap(); let mut retain_lsns_below_horizon = Vec::new(); @@ -3956,7 +3971,10 @@ impl Timeline { // First, do a sanity check to ensure the newly-created layer map does not contain overlaps. let all_layers = { - let guard = self.layers.read().await; + let guard = self + .layers + .read(LayerManagerLockHolder::GarbageCollection) + .await; let layer_map = guard.layer_map()?; layer_map.iter_historic_layers().collect_vec() }; @@ -4020,7 +4038,10 @@ impl Timeline { let update_guard = self.gc_compaction_layer_update_lock.write().await; // Acquiring the update guard ensures current read operations end and new read operations are blocked. // TODO: can we use `latest_gc_cutoff` Rcu to achieve the same effect? - let mut guard = self.layers.write().await; + let mut guard = self + .layers + .write(LayerManagerLockHolder::GarbageCollection) + .await; guard .open_mut()? .finish_gc_compaction(&layer_selection, &compact_to, &self.metrics); @@ -4088,7 +4109,11 @@ impl TimelineAdaptor { pub async fn flush_updates(&mut self) -> Result<(), CompactionError> { let layers_to_delete = { - let guard = self.timeline.layers.read().await; + let guard = self + .timeline + .layers + .read(LayerManagerLockHolder::Compaction) + .await; self.layers_to_delete .iter() .map(|x| guard.get_from_desc(x)) @@ -4133,7 +4158,11 @@ impl CompactionJobExecutor for TimelineAdaptor { ) -> anyhow::Result>> { self.flush_updates().await?; - let guard = self.timeline.layers.read().await; + let guard = self + .timeline + .layers + .read(LayerManagerLockHolder::Compaction) + .await; let layer_map = guard.layer_map()?; let result = layer_map @@ -4172,7 +4201,11 @@ impl CompactionJobExecutor for TimelineAdaptor { // this is a lot more complex than a simple downcast... 
if layer.is_delta() { let l = { - let guard = self.timeline.layers.read().await; + let guard = self + .timeline + .layers + .read(LayerManagerLockHolder::Compaction) + .await; guard.get_from_desc(layer) }; let result = l.download_and_keep_resident(ctx).await?; diff --git a/pageserver/src/tenant/timeline/detach_ancestor.rs b/pageserver/src/tenant/timeline/detach_ancestor.rs index 40eda8c785..f47ce5408b 100644 --- a/pageserver/src/tenant/timeline/detach_ancestor.rs +++ b/pageserver/src/tenant/timeline/detach_ancestor.rs @@ -19,7 +19,7 @@ use utils::id::TimelineId; use utils::lsn::Lsn; use utils::sync::gate::GateError; -use super::layer_manager::LayerManager; +use super::layer_manager::{LayerManager, LayerManagerLockHolder}; use super::{FlushLayerError, Timeline}; use crate::context::{DownloadBehavior, RequestContext}; use crate::task_mgr::TaskKind; @@ -199,7 +199,10 @@ pub(crate) async fn generate_tombstone_image_layer( let image_lsn = ancestor_lsn; { - let layers = detached.layers.read().await; + let layers = detached + .layers + .read(LayerManagerLockHolder::DetachAncestor) + .await; for layer in layers.all_persistent_layers() { if !layer.is_delta && layer.lsn_range.start == image_lsn @@ -423,7 +426,7 @@ pub(super) async fn prepare( // we do not need to start from our layers, because they can only be layers that come // *after* ancestor_lsn let layers = tokio::select! { - guard = ancestor.layers.read() => guard, + guard = ancestor.layers.read(LayerManagerLockHolder::DetachAncestor) => guard, _ = detached.cancel.cancelled() => { return Err(ShuttingDown); } @@ -869,7 +872,12 @@ async fn remote_copy( // Double check that the file is orphan (probably from an earlier attempt), then delete it let key = file_name.clone().into(); - if adoptee.layers.read().await.contains_key(&key) { + if adoptee + .layers + .read(LayerManagerLockHolder::DetachAncestor) + .await + .contains_key(&key) + { // We are supposed to filter out such cases before coming to this function return Err(Error::Prepare(anyhow::anyhow!( "layer file {file_name} already present and inside layer map" diff --git a/pageserver/src/tenant/timeline/eviction_task.rs b/pageserver/src/tenant/timeline/eviction_task.rs index b1b0d32c9b..1328c3ac12 100644 --- a/pageserver/src/tenant/timeline/eviction_task.rs +++ b/pageserver/src/tenant/timeline/eviction_task.rs @@ -33,6 +33,7 @@ use crate::tenant::size::CalculateSyntheticSizeError; use crate::tenant::storage_layer::LayerVisibilityHint; use crate::tenant::tasks::{BackgroundLoopKind, BackgroundLoopSemaphorePermit, sleep_random}; use crate::tenant::timeline::EvictionError; +use crate::tenant::timeline::layer_manager::LayerManagerLockHolder; use crate::tenant::{LogicalSizeCalculationCause, TenantShard}; #[derive(Default)] @@ -208,7 +209,7 @@ impl Timeline { let mut js = tokio::task::JoinSet::new(); { - let guard = self.layers.read().await; + let guard = self.layers.read(LayerManagerLockHolder::Eviction).await; guard .likely_resident_layers() diff --git a/pageserver/src/tenant/timeline/import_pgdata.rs b/pageserver/src/tenant/timeline/import_pgdata.rs index 606ad09ef1..817d76ce2f 100644 --- a/pageserver/src/tenant/timeline/import_pgdata.rs +++ b/pageserver/src/tenant/timeline/import_pgdata.rs @@ -15,6 +15,7 @@ use super::{Timeline, TimelineDeleteProgress}; use crate::context::RequestContext; use crate::controller_upcall_client::{StorageControllerUpcallApi, StorageControllerUpcallClient}; use crate::tenant::metadata::TimelineMetadata; +use 
crate::tenant::timeline::layer_manager::LayerManagerLockHolder; mod flow; mod importbucket_client; @@ -163,7 +164,10 @@ async fn prepare_import( info!("wipe the slate clean"); { // TODO: do we need to hold GC lock for this? - let mut guard = timeline.layers.write().await; + let mut guard = timeline + .layers + .write(LayerManagerLockHolder::ImportPgData) + .await; assert!( guard.layer_map()?.open_layer.is_none(), "while importing, there should be no in-memory layer" // this just seems like a good place to assert it diff --git a/pageserver/src/tenant/timeline/import_pgdata/flow.rs b/pageserver/src/tenant/timeline/import_pgdata/flow.rs index e003bb6810..d471e9fc69 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/flow.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/flow.rs @@ -36,8 +36,8 @@ use pageserver_api::keyspace::{ShardedRange, singleton_range}; use pageserver_api::models::{ShardImportProgress, ShardImportProgressV1, ShardImportStatus}; use pageserver_api::reltag::{RelTag, SlruKind}; use pageserver_api::shard::ShardIdentity; +use postgres_ffi::BLCKSZ; use postgres_ffi::relfile_utils::parse_relfilename; -use postgres_ffi::{BLCKSZ, pg_constants}; use remote_storage::RemotePath; use tokio::sync::Semaphore; use tokio_stream::StreamExt; @@ -56,6 +56,7 @@ use crate::pgdatadir_mapping::{ }; use crate::task_mgr::TaskKind; use crate::tenant::storage_layer::{AsLayerDesc, ImageLayerWriter, Layer}; +use crate::tenant::timeline::layer_manager::LayerManagerLockHolder; pub async fn run( timeline: Arc, @@ -557,7 +558,7 @@ impl PgDataDir { PgDataDirDb::new( storage, &basedir.join(dboid.to_string()), - pg_constants::DEFAULTTABLESPACE_OID, + postgres_ffi_types::constants::DEFAULTTABLESPACE_OID, dboid, &datadir_path, ) @@ -570,7 +571,7 @@ impl PgDataDir { PgDataDirDb::new( storage, &datadir_path.join("global"), - postgres_ffi::pg_constants::GLOBALTABLESPACE_OID, + postgres_ffi_types::constants::GLOBALTABLESPACE_OID, 0, &datadir_path, ) @@ -984,7 +985,10 @@ impl ChunkProcessingJob { let (desc, path) = writer.finish(ctx).await?; { - let guard = timeline.layers.read().await; + let guard = timeline + .layers + .read(LayerManagerLockHolder::ImportPgData) + .await; let existing_layer = guard.try_get_from_key(&desc.key()); if let Some(layer) = existing_layer { if layer.metadata().generation == timeline.generation { @@ -1007,7 +1011,10 @@ impl ChunkProcessingJob { // certain that the existing layer is identical to the new one, so in that case // we replace the old layer with the one we just generated. 
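A simplified sketch (hypothetical types) of the generation rule applied here: a layer that already exists under the same key is only reusable if it was produced in the current generation, i.e. by an earlier attempt of this same import; anything from another generation must be replaced under the write lock below.

    struct LayerMeta {
        generation: u32,
    }

    /// Returns true if the existing layer can be kept as-is.
    fn reuse_existing(existing: Option<&LayerMeta>, current_generation: u32) -> bool {
        match existing {
            // Same generation: an identical file from a previous attempt.
            Some(meta) => meta.generation == current_generation,
            // No existing layer: nothing to reuse, just insert the new one.
            None => false,
        }
    }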
- let mut guard = timeline.layers.write().await; + let mut guard = timeline + .layers + .write(LayerManagerLockHolder::ImportPgData) + .await; let existing_layer = guard .try_get_from_key(&resident_layer.layer_desc().key()) @@ -1036,7 +1043,7 @@ impl ChunkProcessingJob { } } - crate::tenant::timeline::drop_wlock(guard); + crate::tenant::timeline::drop_layer_manager_wlock(guard); timeline .remote_client diff --git a/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs b/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs index bf2d9875c1..98c44313f1 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/importbucket_client.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use anyhow::Context; use bytes::Bytes; -use postgres_ffi::ControlFileData; +use postgres_ffi::{ControlFileData, PgMajorVersion}; use remote_storage::{ Download, DownloadError, DownloadKind, DownloadOpts, GenericRemoteStorage, Listing, ListingObject, RemotePath, RemoteStorageConfig, @@ -264,7 +264,7 @@ impl ControlFile { pub(crate) fn base_lsn(&self) -> Lsn { Lsn(self.control_file_data.checkPoint).align() } - pub(crate) fn pg_version(&self) -> u32 { + pub(crate) fn pg_version(&self) -> PgMajorVersion { self.try_pg_version() .expect("prepare() checks that try_pg_version doesn't error") } @@ -274,13 +274,14 @@ impl ControlFile { pub(crate) fn control_file_buf(&self) -> &Bytes { &self.control_file_buf } - fn try_pg_version(&self) -> anyhow::Result<u32> { + + fn try_pg_version(&self) -> anyhow::Result<PgMajorVersion> { Ok(match self.control_file_data.catalog_version_no { // these are from catversion.h - 202107181 => 14, - 202209061 => 15, - 202307071 => 16, - 202406281 => 17, + 202107181 => PgMajorVersion::PG14, + 202209061 => PgMajorVersion::PG15, + 202307071 => PgMajorVersion::PG16, + 202406281 => PgMajorVersion::PG17, catversion => { anyhow::bail!("unrecognized catalog version {catversion}") } diff --git a/pageserver/src/tenant/timeline/layer_manager.rs b/pageserver/src/tenant/timeline/layer_manager.rs index ae898260d2..2eccf48579 100644 --- a/pageserver/src/tenant/timeline/layer_manager.rs +++ b/pageserver/src/tenant/timeline/layer_manager.rs @@ -1,5 +1,8 @@ use std::collections::HashMap; +use std::mem::ManuallyDrop; +use std::ops::{Deref, DerefMut}; use std::sync::Arc; +use std::time::Duration; use anyhow::{Context, bail, ensure}; use itertools::Itertools; @@ -20,6 +23,155 @@ use crate::tenant::storage_layer::{ PersistentLayerKey, ReadableLayerWeak, ResidentLayer, }; +/// Warn if the lock was held for longer than this threshold. +/// It's very generous and we should bring this value down over time.
+const LAYER_MANAGER_LOCK_WARN_THRESHOLD: Duration = Duration::from_secs(5); +const LAYER_MANAGER_LOCK_READ_WARN_THRESHOLD: Duration = Duration::from_secs(30); + +/// Describes the operation that is holding the layer manager lock +#[derive(Debug, Clone, Copy, strum_macros::Display)] +#[strum(serialize_all = "kebab_case")] +pub(crate) enum LayerManagerLockHolder { + GetLayerMapInfo, + GenerateHeatmap, + GetPage, + Init, + LoadLayerMap, + GetLayerForWrite, + TryFreezeLayer, + FlushFrozenLayer, + FlushLoop, + Compaction, + GarbageCollection, + Shutdown, + ImportPgData, + DetachAncestor, + Eviction, + #[cfg(test)] + Testing, +} + +/// Wrapper for the layer manager that tracks the amount of time during which +/// it was held under read or write lock +#[derive(Default)] +pub(crate) struct LockedLayerManager { + locked: tokio::sync::RwLock<LayerManager>, +} + +pub(crate) struct LayerManagerReadGuard<'a> { + guard: ManuallyDrop<tokio::sync::RwLockReadGuard<'a, LayerManager>>, + acquired_at: std::time::Instant, + holder: LayerManagerLockHolder, +} + +pub(crate) struct LayerManagerWriteGuard<'a> { + guard: ManuallyDrop<tokio::sync::RwLockWriteGuard<'a, LayerManager>>, + acquired_at: std::time::Instant, + holder: LayerManagerLockHolder, +} + +impl Drop for LayerManagerReadGuard<'_> { + fn drop(&mut self) { + // Drop the lock first, before potentially warning if it was held for too long. + // SAFETY: ManuallyDrop in Drop implementation + unsafe { ManuallyDrop::drop(&mut self.guard) }; + + let held_for = self.acquired_at.elapsed(); + if held_for >= LAYER_MANAGER_LOCK_READ_WARN_THRESHOLD { + tracing::warn!( + holder=%self.holder, + "Layer manager read lock held for {}s", + held_for.as_secs_f64(), + ); + } + } +} + +impl Drop for LayerManagerWriteGuard<'_> { + fn drop(&mut self) { + // Drop the lock first, before potentially warning if it was held for too long. + // SAFETY: ManuallyDrop in Drop implementation + unsafe { ManuallyDrop::drop(&mut self.guard) }; + + let held_for = self.acquired_at.elapsed(); + if held_for >= LAYER_MANAGER_LOCK_WARN_THRESHOLD { + tracing::warn!( + holder=%self.holder, + "Layer manager write lock held for {}s", + held_for.as_secs_f64(), + ); + } + } +} + +impl Deref for LayerManagerReadGuard<'_> { + type Target = LayerManager; + + fn deref(&self) -> &Self::Target { + self.guard.deref() + } +} + +impl Deref for LayerManagerWriteGuard<'_> { + type Target = LayerManager; + + fn deref(&self) -> &Self::Target { + self.guard.deref() + } +} + +impl DerefMut for LayerManagerWriteGuard<'_> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.guard.deref_mut() + } +} + +impl LockedLayerManager { + pub(crate) async fn read(&self, holder: LayerManagerLockHolder) -> LayerManagerReadGuard { + let guard = ManuallyDrop::new(self.locked.read().await); + LayerManagerReadGuard { + guard, + acquired_at: std::time::Instant::now(), + holder, + } + } + + pub(crate) fn try_read( + &self, + holder: LayerManagerLockHolder, + ) -> Result<LayerManagerReadGuard, tokio::sync::TryLockError> { + let guard = ManuallyDrop::new(self.locked.try_read()?); + + Ok(LayerManagerReadGuard { + guard, + acquired_at: std::time::Instant::now(), + holder, + }) + } + + pub(crate) async fn write(&self, holder: LayerManagerLockHolder) -> LayerManagerWriteGuard { + let guard = ManuallyDrop::new(self.locked.write().await); + LayerManagerWriteGuard { + guard, + acquired_at: std::time::Instant::now(), + holder, + } + } + + pub(crate) fn try_write( + &self, + holder: LayerManagerLockHolder, + ) -> Result<LayerManagerWriteGuard, tokio::sync::TryLockError> { + let guard = ManuallyDrop::new(self.locked.try_write()?); + + Ok(LayerManagerWriteGuard { + guard, + acquired_at: std::time::Instant::now(), + holder, + }) + } +} + /// Provides
semantic APIs to manipulate the layer map. pub(crate) enum LayerManager { /// Open as in not shutdown layer manager; we still have in-memory layers and we can manipulate diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index 343e04f5f0..6d52da1f00 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -275,12 +275,20 @@ pub(super) async fn handle_walreceiver_connection( let copy_stream = replication_client.copy_both_simple(&query).await?; let mut physical_stream = pin!(ReplicationStream::new(copy_stream)); - let mut walingest = WalIngest::new(timeline.as_ref(), startpoint, &ctx) - .await - .map_err(|e| match e.kind { - crate::walingest::WalIngestErrorKind::Cancelled => WalReceiverError::Cancelled, - _ => WalReceiverError::Other(e.into()), - })?; + let walingest_future = WalIngest::new(timeline.as_ref(), startpoint, &ctx); + let walingest_res = select! { + walingest_res = walingest_future => walingest_res, + _ = cancellation.cancelled() => { + // We are doing reads in WalIngest::new, and those can hang as they come from the network. + // Timeline cancellation hits the walreceiver cancellation token before it hits the timeline global one. + debug!("Connection cancelled"); + return Err(WalReceiverError::Cancelled); + }, + }; + let mut walingest = walingest_res.map_err(|e| match e.kind { + crate::walingest::WalIngestErrorKind::Cancelled => WalReceiverError::Cancelled, + _ => WalReceiverError::Other(e.into()), + })?; let (format, compression) = match protocol { PostgresClientProtocol::Interpreted { @@ -360,8 +368,7 @@ pub(super) async fn handle_walreceiver_connection( match raw_wal_start_lsn.cmp(&expected_wal_start) { std::cmp::Ordering::Greater => { let msg = format!( - "Gap in streamed WAL: [{}, {})", - expected_wal_start, raw_wal_start_lsn + "Gap in streamed WAL: [{expected_wal_start}, {raw_wal_start_lsn})" ); critical!("{msg}"); return Err(WalReceiverError::Other(anyhow!(msg))); diff --git a/pageserver/src/virtual_file/owned_buffers_io/aligned_buffer/buffer.rs b/pageserver/src/virtual_file/owned_buffers_io/aligned_buffer/buffer.rs index 090d2ece85..85ea5c4d80 100644 --- a/pageserver/src/virtual_file/owned_buffers_io/aligned_buffer/buffer.rs +++ b/pageserver/src/virtual_file/owned_buffers_io/aligned_buffer/buffer.rs @@ -68,16 +68,9 @@ impl AlignedBuffer { assert!( begin <= end, - "range start must not be greater than end: {:?} <= {:?}", - begin, - end, - ); - assert!( - end <= len, - "range end out of bounds: {:?} <= {:?}", - end, - len, + "range start must not be greater than end: {begin:?} <= {end:?}", ); + assert!(end <= len, "range end out of bounds: {end:?} <= {len:?}",); let begin = self.range.start + begin; let end = self.range.start + end; diff --git a/pageserver/src/virtual_file/owned_buffers_io/aligned_buffer/buffer_mut.rs b/pageserver/src/virtual_file/owned_buffers_io/aligned_buffer/buffer_mut.rs index 07f949b89e..93116ea85e 100644 --- a/pageserver/src/virtual_file/owned_buffers_io/aligned_buffer/buffer_mut.rs +++ b/pageserver/src/virtual_file/owned_buffers_io/aligned_buffer/buffer_mut.rs @@ -242,10 +242,7 @@ unsafe impl bytes::BufMut for AlignedBufferMut { /// Panic with a nice error message. #[cold] fn panic_advance(idx: usize, len: usize) -> ! 
{ - panic!( - "advance out of bounds: the len is {} but advancing by {}", - len, idx - ); + panic!("advance out of bounds: the len is {len} but advancing by {idx}"); } /// Safety: [`AlignedBufferMut`] has exclusive ownership of the io buffer, diff --git a/pageserver/src/walingest.rs b/pageserver/src/walingest.rs index c1a3b79915..a597aedee3 100644 --- a/pageserver/src/walingest.rs +++ b/pageserver/src/walingest.rs @@ -28,20 +28,20 @@ use std::time::{Duration, Instant, SystemTime}; use bytes::{Buf, Bytes}; use pageserver_api::key::{Key, rel_block_to_key}; -use pageserver_api::record::NeonWalRecord; use pageserver_api::reltag::{BlockNumber, RelTag, SlruKind}; use pageserver_api::shard::ShardIdentity; -use postgres_ffi::relfile_utils::{FSM_FORKNUM, INIT_FORKNUM, MAIN_FORKNUM, VISIBILITYMAP_FORKNUM}; use postgres_ffi::walrecord::*; use postgres_ffi::{ - TimestampTz, TransactionId, dispatch_pgversion, enum_pgversion, enum_pgversion_dispatch, - fsm_logical_to_physical, pg_constants, + PgMajorVersion, TimestampTz, TransactionId, dispatch_pgversion, enum_pgversion, + enum_pgversion_dispatch, fsm_logical_to_physical, pg_constants, }; +use postgres_ffi_types::forknum::{FSM_FORKNUM, INIT_FORKNUM, MAIN_FORKNUM, VISIBILITYMAP_FORKNUM}; use tracing::*; use utils::bin_ser::{DeserializeError, SerializeError}; use utils::lsn::Lsn; use utils::rate_limit::RateLimit; use utils::{critical, failpoint_support}; +use wal_decoder::models::record::NeonWalRecord; use wal_decoder::models::*; use crate::ZERO_PAGE; @@ -781,7 +781,7 @@ impl WalIngest { ) -> Result<(), WalIngestError> { let (xact_common, is_commit, is_prepared) = match record { XactRecord::Prepare(XactPrepare { xl_xid, data }) => { - let xid: u64 = if modification.tline.pg_version >= 17 { + let xid: u64 = if modification.tline.pg_version >= PgMajorVersion::PG17 { self.adjust_to_full_transaction_id(xl_xid)? } else { xl_xid as u64 @@ -886,7 +886,7 @@ impl WalIngest { xl_xid, parsed.xid, lsn, ); - let xid: u64 = if modification.tline.pg_version >= 17 { + let xid: u64 = if modification.tline.pg_version >= PgMajorVersion::PG17 { self.adjust_to_full_transaction_id(parsed.xid)? } else { parsed.xid as u64 @@ -1241,7 +1241,7 @@ impl WalIngest { if xlog_checkpoint.oldestActiveXid == pg_constants::INVALID_TRANSACTION_ID && info == pg_constants::XLOG_CHECKPOINT_SHUTDOWN { - let oldest_active_xid = if pg_version >= 17 { + let oldest_active_xid = if pg_version >= PgMajorVersion::PG17 { let mut oldest_active_full_xid = cp.nextXid.value; for xid in modification.tline.list_twophase_files(lsn, ctx).await? 
{ if xid < oldest_active_full_xid { @@ -1475,10 +1475,11 @@ impl WalIngest { const fn rate_limiter( &self, - pg_version: u32, + pg_version: PgMajorVersion, ) -> Option<&Lazy>> { - const MIN_PG_VERSION: u32 = 14; - const MAX_PG_VERSION: u32 = 17; + const MIN_PG_VERSION: u32 = PgMajorVersion::PG14.major_version_num(); + const MAX_PG_VERSION: u32 = PgMajorVersion::PG17.major_version_num(); + let pg_version = pg_version.major_version_num(); if pg_version < MIN_PG_VERSION || pg_version > MAX_PG_VERSION { return None; @@ -1603,6 +1604,7 @@ async fn get_relsize( #[cfg(test)] mod tests { use anyhow::Result; + use postgres_ffi::PgMajorVersion; use postgres_ffi::RELSEG_SIZE; use super::*; @@ -1625,7 +1627,7 @@ mod tests { #[tokio::test] async fn test_zeroed_checkpoint_decodes_correctly() -> Result<(), anyhow::Error> { - for i in 14..=16 { + for i in PgMajorVersion::ALL { dispatch_pgversion!(i, { pgv::CheckPoint::decode(&pgv::ZERO_CHECKPOINT)?; }); @@ -2108,7 +2110,7 @@ mod tests { // Check relation content for blkno in 0..relsize { let lsn = Lsn(0x20); - let data = format!("foo blk {} at {}", blkno, lsn); + let data = format!("foo blk {blkno} at {lsn}"); assert_eq!( tline .get_rel_page_at_lsn( @@ -2142,7 +2144,7 @@ mod tests { for blkno in 0..1 { let lsn = Lsn(0x20); - let data = format!("foo blk {} at {}", blkno, lsn); + let data = format!("foo blk {blkno} at {lsn}"); assert_eq!( tline .get_rel_page_at_lsn( @@ -2167,7 +2169,7 @@ mod tests { ); for blkno in 0..relsize { let lsn = Lsn(0x20); - let data = format!("foo blk {} at {}", blkno, lsn); + let data = format!("foo blk {blkno} at {lsn}"); assert_eq!( tline .get_rel_page_at_lsn( @@ -2188,7 +2190,7 @@ mod tests { let lsn = Lsn(0x80); let mut m = tline.begin_modification(lsn); for blkno in 0..relsize { - let data = format!("foo blk {} at {}", blkno, lsn); + let data = format!("foo blk {blkno} at {lsn}"); walingest .put_rel_page_image(&mut m, TESTREL_A, blkno, test_img(&data), &ctx) .await?; @@ -2210,7 +2212,7 @@ mod tests { // Check relation content for blkno in 0..relsize { let lsn = Lsn(0x80); - let data = format!("foo blk {} at {}", blkno, lsn); + let data = format!("foo blk {blkno} at {lsn}"); assert_eq!( tline .get_rel_page_at_lsn( @@ -2335,7 +2337,7 @@ mod tests { // 5. Grep sk logs for "restart decoder" to get startpoint // 6. Run just the decoder from this test to get the endpoint. // It's the last LSN the decoder will output. 
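The `pg_version` changes in this file and the ones above all follow the same migration: bare `u32` version numbers become the `PgMajorVersion` enum. A compilable sketch of that pattern, reconstructed only from the call sites visible in this diff (`ALL`, `major_version_num`, and ordered comparisons such as `>= PgMajorVersion::PG17`); the real type lives in `postgres_ffi` and may differ in detail:

    #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
    pub enum PgMajorVersion {
        PG14,
        PG15,
        PG16,
        PG17,
    }

    impl PgMajorVersion {
        /// Used by tests to iterate every supported version.
        pub const ALL: [PgMajorVersion; 4] =
            [Self::PG14, Self::PG15, Self::PG16, Self::PG17];

        /// Back-conversion for code that still needs the numeric version.
        pub const fn major_version_num(self) -> u32 {
            match self {
                Self::PG14 => 14,
                Self::PG15 => 15,
                Self::PG16 => 16,
                Self::PG17 => 17,
            }
        }
    }

Because the variants are declared in ascending order, the derived `Ord` makes comparisons like `pg_version >= PgMajorVersion::PG17` behave the same as the old integer comparisons, while rejecting arbitrary integers at compile time.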
- let pg_version = 15; // The test data was generated by pg15 + let pg_version = PgMajorVersion::PG15; // The test data was generated by pg15 let path = "test_data/sk_wal_segment_from_pgbench"; let wal_segment_path = format!("{path}/000000010000000000000001.zst"); let source_initdb_path = format!("{path}/{INITDB_PATH}"); @@ -2414,6 +2416,6 @@ mod tests { } let duration = started_at.elapsed(); - println!("done in {:?}", duration); + println!("done in {duration:?}"); } } diff --git a/pageserver/src/walredo.rs b/pageserver/src/walredo.rs index ed8a954369..b17b5a15f9 100644 --- a/pageserver/src/walredo.rs +++ b/pageserver/src/walredo.rs @@ -32,12 +32,13 @@ use anyhow::Context; use bytes::{Bytes, BytesMut}; use pageserver_api::key::Key; use pageserver_api::models::{WalRedoManagerProcessStatus, WalRedoManagerStatus}; -use pageserver_api::record::NeonWalRecord; use pageserver_api::shard::TenantShardId; +use postgres_ffi::PgMajorVersion; use tracing::*; use utils::lsn::Lsn; use utils::sync::gate::GateError; use utils::sync::heavier_once_cell; +use wal_decoder::models::record::NeonWalRecord; use crate::config::PageServerConf; use crate::metrics::{ @@ -165,7 +166,7 @@ impl PostgresRedoManager { lsn: Lsn, base_img: Option<(Lsn, Bytes)>, records: Vec<(Lsn, NeonWalRecord)>, - pg_version: u32, + pg_version: PgMajorVersion, redo_attempt_type: RedoAttemptType, ) -> Result { if records.is_empty() { @@ -232,7 +233,7 @@ impl PostgresRedoManager { /// # Cancel-Safety /// /// This method is cancellation-safe. - pub async fn ping(&self, pg_version: u32) -> Result<(), Error> { + pub async fn ping(&self, pg_version: PgMajorVersion) -> Result<(), Error> { self.do_with_walredo_process(pg_version, |proc| async move { proc.ping(Duration::from_secs(1)) .await @@ -342,7 +343,7 @@ impl PostgresRedoManager { O, >( &self, - pg_version: u32, + pg_version: PgMajorVersion, closure: F, ) -> Result { let proc: Arc = match self.redo_process.get_or_init_detached().await { @@ -442,7 +443,7 @@ impl PostgresRedoManager { base_img_lsn: Lsn, records: &[(Lsn, NeonWalRecord)], wal_redo_timeout: Duration, - pg_version: u32, + pg_version: PgMajorVersion, max_retry_attempts: u32, ) -> Result { *(self.last_redo_at.lock().unwrap()) = Some(Instant::now()); @@ -571,11 +572,12 @@ mod tests { use bytes::Bytes; use pageserver_api::key::Key; - use pageserver_api::record::NeonWalRecord; use pageserver_api::shard::TenantShardId; + use postgres_ffi::PgMajorVersion; use tracing::Instrument; use utils::id::TenantId; use utils::lsn::Lsn; + use wal_decoder::models::record::NeonWalRecord; use super::PostgresRedoManager; use crate::config::PageServerConf; @@ -586,7 +588,7 @@ mod tests { let h = RedoHarness::new().unwrap(); h.manager - .ping(14) + .ping(PgMajorVersion::PG14) .instrument(h.span()) .await .expect("ping should work"); @@ -612,7 +614,7 @@ mod tests { Lsn::from_str("0/16E2408").unwrap(), None, short_records(), - 14, + PgMajorVersion::PG14, RedoAttemptType::ReadPage, ) .instrument(h.span()) @@ -641,7 +643,7 @@ mod tests { Lsn::from_str("0/16E2408").unwrap(), None, short_records(), - 14, + PgMajorVersion::PG14, RedoAttemptType::ReadPage, ) .instrument(h.span()) @@ -663,7 +665,7 @@ mod tests { Lsn::INVALID, None, short_records(), - 16, /* 16 currently produces stderr output on startup, which adds a nice extra edge */ + PgMajorVersion::PG16, /* 16 currently produces stderr output on startup, which adds a nice extra edge */ RedoAttemptType::ReadPage, ) .instrument(h.span()) diff --git a/pageserver/src/walredo/apply_neon.rs 
b/pageserver/src/walredo/apply_neon.rs index a3840f1f6f..a525579082 100644 --- a/pageserver/src/walredo/apply_neon.rs +++ b/pageserver/src/walredo/apply_neon.rs @@ -2,16 +2,16 @@ use anyhow::Context; use byteorder::{ByteOrder, LittleEndian}; use bytes::BytesMut; use pageserver_api::key::Key; -use pageserver_api::record::NeonWalRecord; use pageserver_api::reltag::SlruKind; -use postgres_ffi::relfile_utils::VISIBILITYMAP_FORKNUM; use postgres_ffi::v14::nonrelfile_utils::{ mx_offset_to_flags_bitshift, mx_offset_to_flags_offset, mx_offset_to_member_offset, transaction_id_set_status, }; use postgres_ffi::{BLCKSZ, pg_constants}; +use postgres_ffi_types::forknum::VISIBILITYMAP_FORKNUM; use tracing::*; use utils::lsn::Lsn; +use wal_decoder::models::record::NeonWalRecord; /// Can this request be served by neon redo functions /// or we need to pass it to wal-redo postgres process? @@ -52,8 +52,7 @@ pub(crate) fn apply_in_neon( let (rel, _) = key.to_rel_block().context("invalid record")?; assert!( rel.forknum == VISIBILITYMAP_FORKNUM, - "TruncateVisibilityMap record on unexpected rel {}", - rel + "TruncateVisibilityMap record on unexpected rel {rel}" ); let map = &mut page[pg_constants::MAXALIGN_SIZE_OF_PAGE_HEADER_DATA..]; map[*trunc_byte + 1..].fill(0u8); @@ -78,8 +77,7 @@ pub(crate) fn apply_in_neon( let (rel, blknum) = key.to_rel_block().context("invalid record")?; assert!( rel.forknum == VISIBILITYMAP_FORKNUM, - "ClearVisibilityMapFlags record on unexpected rel {}", - rel + "ClearVisibilityMapFlags record on unexpected rel {rel}" ); if let Some(heap_blkno) = *new_heap_blkno { // Calculate the VM block and offset that corresponds to the heap block. @@ -124,8 +122,7 @@ pub(crate) fn apply_in_neon( assert_eq!( slru_kind, SlruKind::Clog, - "ClogSetCommitted record with unexpected key {}", - key + "ClogSetCommitted record with unexpected key {key}" ); for &xid in xids { let pageno = xid / pg_constants::CLOG_XACTS_PER_PAGE; @@ -135,15 +132,11 @@ pub(crate) fn apply_in_neon( // Check that we're modifying the correct CLOG block. assert!( segno == expected_segno, - "ClogSetCommitted record for XID {} with unexpected key {}", - xid, - key + "ClogSetCommitted record for XID {xid} with unexpected key {key}" ); assert!( blknum == expected_blknum, - "ClogSetCommitted record for XID {} with unexpected key {}", - xid, - key + "ClogSetCommitted record for XID {xid} with unexpected key {key}" ); transaction_id_set_status(xid, pg_constants::TRANSACTION_STATUS_COMMITTED, page); @@ -169,8 +162,7 @@ pub(crate) fn apply_in_neon( assert_eq!( slru_kind, SlruKind::Clog, - "ClogSetAborted record with unexpected key {}", - key + "ClogSetAborted record with unexpected key {key}" ); for &xid in xids { let pageno = xid / pg_constants::CLOG_XACTS_PER_PAGE; @@ -180,15 +172,11 @@ pub(crate) fn apply_in_neon( // Check that we're modifying the correct CLOG block. 
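The asserts that follow verify the SLRU addressing math. A small self-contained sketch of that mapping, using the standard Postgres constants (a BLCKSZ of 8192 holds four two-bit transaction statuses per byte, and an SLRU segment holds 32 pages); the helper name is hypothetical:

    const CLOG_XACTS_PER_PAGE: u32 = 8192 * 4; // BLCKSZ * CLOG_XACTS_PER_BYTE
    const SLRU_PAGES_PER_SEGMENT: u32 = 32;

    /// Maps an XID to the (segment, block) pair its CLOG status lives in.
    fn clog_location(xid: u32) -> (u32, u32) {
        let pageno = xid / CLOG_XACTS_PER_PAGE;
        (pageno / SLRU_PAGES_PER_SEGMENT, pageno % SLRU_PAGES_PER_SEGMENT)
    }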
assert!( segno == expected_segno, - "ClogSetAborted record for XID {} with unexpected key {}", - xid, - key + "ClogSetAborted record for XID {xid} with unexpected key {key}" ); assert!( blknum == expected_blknum, - "ClogSetAborted record for XID {} with unexpected key {}", - xid, - key + "ClogSetAborted record for XID {xid} with unexpected key {key}" ); transaction_id_set_status(xid, pg_constants::TRANSACTION_STATUS_ABORTED, page); @@ -199,8 +187,7 @@ pub(crate) fn apply_in_neon( assert_eq!( slru_kind, SlruKind::MultiXactOffsets, - "MultixactOffsetCreate record with unexpected key {}", - key + "MultixactOffsetCreate record with unexpected key {key}" ); // Compute the block and offset to modify. // See RecordNewMultiXact in PostgreSQL sources. @@ -213,15 +200,11 @@ pub(crate) fn apply_in_neon( let expected_blknum = pageno % pg_constants::SLRU_PAGES_PER_SEGMENT; assert!( segno == expected_segno, - "MultiXactOffsetsCreate record for multi-xid {} with unexpected key {}", - mid, - key + "MultiXactOffsetsCreate record for multi-xid {mid} with unexpected key {key}" ); assert!( blknum == expected_blknum, - "MultiXactOffsetsCreate record for multi-xid {} with unexpected key {}", - mid, - key + "MultiXactOffsetsCreate record for multi-xid {mid} with unexpected key {key}" ); LittleEndian::write_u32(&mut page[offset..offset + 4], *moff); @@ -231,8 +214,7 @@ pub(crate) fn apply_in_neon( assert_eq!( slru_kind, SlruKind::MultiXactMembers, - "MultixactMembersCreate record with unexpected key {}", - key + "MultixactMembersCreate record with unexpected key {key}" ); for (i, member) in members.iter().enumerate() { let offset = moff + i as u32; @@ -249,15 +231,11 @@ pub(crate) fn apply_in_neon( let expected_blknum = pageno % pg_constants::SLRU_PAGES_PER_SEGMENT; assert!( segno == expected_segno, - "MultiXactMembersCreate record for offset {} with unexpected key {}", - moff, - key + "MultiXactMembersCreate record for offset {moff} with unexpected key {key}" ); assert!( blknum == expected_blknum, - "MultiXactMembersCreate record for offset {} with unexpected key {}", - moff, - key + "MultiXactMembersCreate record for offset {moff} with unexpected key {key}" ); let mut flagsval = LittleEndian::read_u32(&page[flagsoff..flagsoff + 4]); diff --git a/pageserver/src/walredo/process.rs b/pageserver/src/walredo/process.rs index 6d4a38d4ff..c8b0846480 100644 --- a/pageserver/src/walredo/process.rs +++ b/pageserver/src/walredo/process.rs @@ -10,14 +10,14 @@ use std::time::Duration; use anyhow::Context; use bytes::Bytes; -use pageserver_api::record::NeonWalRecord; use pageserver_api::reltag::RelTag; use pageserver_api::shard::TenantShardId; -use postgres_ffi::BLCKSZ; +use postgres_ffi::{BLCKSZ, PgMajorVersion}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tracing::{Instrument, debug, error, instrument}; use utils::lsn::Lsn; use utils::poison::Poison; +use wal_decoder::models::record::NeonWalRecord; use self::no_leak_child::NoLeakChild; use crate::config::PageServerConf; @@ -54,11 +54,11 @@ impl WalRedoProcess { // // Start postgres binary in special WAL redo mode. 
// - #[instrument(skip_all,fields(pg_version=pg_version))] + #[instrument(skip_all,fields(pg_version=pg_version.major_version_num()))] pub(crate) fn launch( conf: &'static PageServerConf, tenant_shard_id: TenantShardId, - pg_version: u32, + pg_version: PgMajorVersion, ) -> anyhow::Result { crate::span::debug_assert_current_span_has_tenant_id(); diff --git a/pgxn/Makefile b/pgxn/Makefile new file mode 100644 index 0000000000..8f190668ea --- /dev/null +++ b/pgxn/Makefile @@ -0,0 +1,28 @@ +# This makefile assumes that 'pg_config' is in the path, or is passed in the +# PG_CONFIG variable. +# +# This is used in two different ways: +# +# 1. The main makefile calls this, when you invoke the `make neon-pg-ext-%` +# target. It passes PG_CONFIG pointing to pg_install/%/bin/pg_config. +# This is a VPATH build; the current directory is build/pgxn-%, and +# the path to the Makefile is passed with the -f argument. +# +# 2. compute-node.Dockerfile invokes this to build the compute extensions +# for the specific Postgres version. It relies on pg_config already +# being in $(PATH). + +srcdir = $(dir $(firstword $(MAKEFILE_LIST))) + +PG_CONFIG = pg_config + +subdirs = neon neon_rmgr neon_walredo neon_utils neon_test_utils + +.PHONY: install install-compute install-storage $(subdirs) +install: $(subdirs) +install-compute: neon neon_utils neon_test_utils neon_rmgr +install-storage: neon_rmgr neon_walredo + +$(subdirs): %: + mkdir -p $* + $(MAKE) PG_CONFIG=$(PG_CONFIG) -C $* -f $(abspath $(srcdir)/$@/Makefile) install diff --git a/pgxn/neon/Makefile b/pgxn/neon/Makefile index 8bcc6bf924..9bce0e798a 100644 --- a/pgxn/neon/Makefile +++ b/pgxn/neon/Makefile @@ -21,7 +21,7 @@ OBJS = \ unstable_extensions.o \ walproposer.o \ walproposer_pg.o \ - control_plane_connector.o \ + neon_ddl_handler.o \ walsender_hooks.o PG_CPPFLAGS = -I$(libpq_srcdir) diff --git a/pgxn/neon/communicator.c b/pgxn/neon/communicator.c index 2655a45bcc..7c84be7d15 100644 --- a/pgxn/neon/communicator.c +++ b/pgxn/neon/communicator.c @@ -1092,13 +1092,15 @@ communicator_prefetch_register_bufferv(BufferTag tag, neon_request_lsns *frlsns, MyPState->ring_last <= ring_index); } -/* internal version. Returns the ring index */ +/* Internal version. 
Returns the ring index of the last block (result of this function is used only +* when nblocks==1) +*/ static uint64 prefetch_register_bufferv(BufferTag tag, neon_request_lsns *frlsns, BlockNumber nblocks, const bits8 *mask, bool is_prefetch) { - uint64 min_ring_index; + uint64 last_ring_index; PrefetchRequest hashkey; #ifdef USE_ASSERT_CHECKING bool any_hits = false; @@ -1122,13 +1124,12 @@ Retry: MyPState->ring_unused - MyPState->ring_receive; MyNeonCounters->getpage_prefetches_buffered = MyPState->n_responses_buffered; + last_ring_index = UINT64_MAX; - min_ring_index = UINT64_MAX; for (int i = 0; i < nblocks; i++) { PrefetchRequest *slot = NULL; PrfHashEntry *entry = NULL; - uint64 ring_index; neon_request_lsns *lsns; if (PointerIsValid(mask) && BITMAP_ISSET(mask, i)) @@ -1152,12 +1153,12 @@ Retry: if (entry != NULL) { slot = entry->slot; - ring_index = slot->my_ring_index; - Assert(slot == GetPrfSlot(ring_index)); + last_ring_index = slot->my_ring_index; + Assert(slot == GetPrfSlot(last_ring_index)); Assert(slot->status != PRFS_UNUSED); - Assert(MyPState->ring_last <= ring_index && - ring_index < MyPState->ring_unused); + Assert(MyPState->ring_last <= last_ring_index && + last_ring_index < MyPState->ring_unused); Assert(BufferTagsEqual(&slot->buftag, &hashkey.buftag)); /* @@ -1169,9 +1170,9 @@ Retry: if (!neon_prefetch_response_usable(lsns, slot)) { /* Wait for the old request to finish and discard it */ - if (!prefetch_wait_for(ring_index)) + if (!prefetch_wait_for(last_ring_index)) goto Retry; - prefetch_set_unused(ring_index); + prefetch_set_unused(last_ring_index); entry = NULL; slot = NULL; pgBufferUsage.prefetch.expired += 1; @@ -1188,13 +1189,12 @@ Retry: */ if (slot->status == PRFS_TAG_REMAINS) { - prefetch_set_unused(ring_index); + prefetch_set_unused(last_ring_index); entry = NULL; slot = NULL; } else { - min_ring_index = Min(min_ring_index, ring_index); /* The buffered request is good enough, return that index */ if (is_prefetch) pgBufferUsage.prefetch.duplicates++; @@ -1283,12 +1283,12 @@ Retry: * The next buffer pointed to by `ring_unused` is now definitely empty, so * we can insert the new request to it. 
*/ - ring_index = MyPState->ring_unused; + last_ring_index = MyPState->ring_unused; - Assert(MyPState->ring_last <= ring_index && - ring_index <= MyPState->ring_unused); + Assert(MyPState->ring_last <= last_ring_index && + last_ring_index <= MyPState->ring_unused); - slot = GetPrfSlotNoCheck(ring_index); + slot = GetPrfSlotNoCheck(last_ring_index); Assert(slot->status == PRFS_UNUSED); @@ -1298,11 +1298,9 @@ Retry: */ slot->buftag = hashkey.buftag; slot->shard_no = get_shard_number(&tag); - slot->my_ring_index = ring_index; + slot->my_ring_index = last_ring_index; slot->flags = 0; - min_ring_index = Min(min_ring_index, ring_index); - if (is_prefetch) MyNeonCounters->getpage_prefetch_requests_total++; else @@ -1315,11 +1313,12 @@ Retry: MyPState->ring_unused - MyPState->ring_receive; Assert(any_hits); + Assert(last_ring_index != UINT64_MAX); - Assert(GetPrfSlot(min_ring_index)->status == PRFS_REQUESTED || - GetPrfSlot(min_ring_index)->status == PRFS_RECEIVED); - Assert(MyPState->ring_last <= min_ring_index && - min_ring_index < MyPState->ring_unused); + Assert(GetPrfSlot(last_ring_index)->status == PRFS_REQUESTED || + GetPrfSlot(last_ring_index)->status == PRFS_RECEIVED); + Assert(MyPState->ring_last <= last_ring_index && + last_ring_index < MyPState->ring_unused); if (flush_every_n_requests > 0 && MyPState->ring_unused - MyPState->ring_flush >= flush_every_n_requests) @@ -1335,7 +1334,7 @@ Retry: MyPState->ring_flush = MyPState->ring_unused; } - return min_ring_index; + return last_ring_index; } static bool diff --git a/pgxn/neon/control_plane_connector.h b/pgxn/neon/control_plane_connector.h deleted file mode 100644 index 7eed449200..0000000000 --- a/pgxn/neon/control_plane_connector.h +++ /dev/null @@ -1,6 +0,0 @@ -#ifndef CONTROL_PLANE_CONNECTOR_H -#define CONTROL_PLANE_CONNECTOR_H - -void InitControlPlaneConnector(void); - -#endif diff --git a/pgxn/neon/file_cache.c b/pgxn/neon/file_cache.c index 45a4695495..8cfa09bc87 100644 --- a/pgxn/neon/file_cache.c +++ b/pgxn/neon/file_cache.c @@ -1295,7 +1295,8 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno, if (iteration_hits != 0) { - /* chunk offset (# of pages) into the LFC file */ + /* chunk offset (# + of pages) into the LFC file */ off_t first_read_offset = (off_t) entry_offset * lfc_blocks_per_chunk; int nwrite = iov_last_used - first_block_in_chunk_read; /* offset of first IOV */ @@ -1313,16 +1314,6 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno, lfc_disable("read"); return -1; } - - /* - * We successfully read the pages we know were valid when we - * started reading; now mark those pages as read - */ - for (int i = first_block_in_chunk_read; i < iov_last_used; i++) - { - if (BITMAP_ISSET(chunk_mask, i)) - BITMAP_SET(mask, buf_offset + i); - } } /* Place entry to the head of LRU list */ @@ -1340,6 +1331,15 @@ lfc_readv_select(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno, { lfc_ctl->time_read += io_time_us; inc_page_cache_read_wait(io_time_us); + /* + * We successfully read the pages we know were valid when we + * started reading; now mark those pages as read + */ + for (int i = first_block_in_chunk_read; i < iov_last_used; i++) + { + if (BITMAP_ISSET(chunk_mask, i)) + BITMAP_SET(mask, buf_offset + i); + } } CriticalAssert(entry->access_count > 0); diff --git a/pgxn/neon/neon--1.6--1.5.sql b/pgxn/neon/neon--1.6--1.5.sql index 57512980f5..50c62238a3 100644 --- a/pgxn/neon/neon--1.6--1.5.sql +++ b/pgxn/neon/neon--1.6--1.5.sql @@ -2,6 +2,6 @@ DROP FUNCTION IF EXISTS 
get_prewarm_info(out total_pages integer, out prewarmed_ DROP FUNCTION IF EXISTS get_local_cache_state(max_chunks integer); -DROP FUNCTION IF EXISTS prewarm_local_cache(state bytea, n_workers integer default 1); +DROP FUNCTION IF EXISTS prewarm_local_cache(state bytea, n_workers integer); diff --git a/pgxn/neon/neon.c b/pgxn/neon/neon.c index 5b4ced7cf0..8a405f4129 100644 --- a/pgxn/neon/neon.c +++ b/pgxn/neon/neon.c @@ -33,9 +33,9 @@ #include "extension_server.h" #include "file_cache.h" #include "neon.h" +#include "neon_ddl_handler.h" #include "neon_lwlsncache.h" #include "neon_perf_counters.h" -#include "control_plane_connector.h" #include "logical_replication_monitor.h" #include "unstable_extensions.h" #include "walsender_hooks.h" @@ -454,7 +454,7 @@ _PG_init(void) InitUnstableExtensionsSupport(); InitLogicalReplicationMonitor(); - InitControlPlaneConnector(); + InitDDLHandler(); pg_init_extension_server(); diff --git a/pgxn/neon/neon.control b/pgxn/neon/neon.control index af69116e21..51193f63c8 100644 --- a/pgxn/neon/neon.control +++ b/pgxn/neon/neon.control @@ -1,6 +1,6 @@ # neon extension comment = 'cloud storage for PostgreSQL' -default_version = '1.5' +default_version = '1.6' module_pathname = '$libdir/neon' relocatable = true trusted = true diff --git a/pgxn/neon/control_plane_connector.c b/pgxn/neon/neon_ddl_handler.c similarity index 57% rename from pgxn/neon/control_plane_connector.c rename to pgxn/neon/neon_ddl_handler.c index 47ed37da06..dba28c0ed6 100644 --- a/pgxn/neon/control_plane_connector.c +++ b/pgxn/neon/neon_ddl_handler.c @@ -1,6 +1,6 @@ /*------------------------------------------------------------------------- * - * control_plane_connector.c + * neon_ddl_handler.c * Captures updates to roles/databases using ProcessUtility_hook and * sends them to the control ProcessUtility_hook. The changes are sent * via HTTP to the URL specified by the GUC neon.console_url when the @@ -13,18 +13,30 @@ * accumulate changes. On subtransaction commit, the top of the stack * is merged with the table below it. * + * Support event triggers for neon_superuser + * + * IDENTIFICATION + * contrib/neon/neon_dll_handler.c + * *------------------------------------------------------------------------- */ #include "postgres.h" #include +#include #include "access/xact.h" +#include "catalog/pg_authid.h" +#include "catalog/pg_proc.h" #include "commands/defrem.h" +#include "commands/event_trigger.h" +#include "commands/user.h" #include "fmgr.h" #include "libpq/crypt.h" #include "miscadmin.h" +#include "nodes/makefuncs.h" +#include "parser/parse_func.h" #include "tcop/pquery.h" #include "tcop/utility.h" #include "utils/acl.h" @@ -32,11 +44,16 @@ #include "utils/hsearch.h" #include "utils/memutils.h" #include "utils/jsonb.h" +#include +#include -#include "control_plane_connector.h" +#include "neon_ddl_handler.h" #include "neon_utils.h" static ProcessUtility_hook_type PreviousProcessUtilityHook = NULL; +static fmgr_hook_type next_fmgr_hook = NULL; +static needs_fmgr_hook_type next_needs_fmgr_hook = NULL; +static bool neon_event_triggers = true; static const char *jwt_token = NULL; @@ -773,6 +790,7 @@ HandleDropRole(DropRoleStmt *stmt) } } + static void HandleRename(RenameStmt *stmt) { @@ -782,6 +800,460 @@ HandleRename(RenameStmt *stmt) return HandleRoleRename(stmt); } + +/* + * Support for Event Triggers. + * + * In vanilla only superuser can create Event Triggers. + * + * We allow it for neon_superuser by temporary switching to superuser. 
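+ * (See switch_to_superuser() and switch_to_original_role() below.)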
But as + * far as event trigger can fire in superuser context we should protect + * superuser from execution of arbitrary user's code. + * + * The idea was taken from Supabase PR series starting at + * https://github.com/supabase/supautils/pull/98 + */ + +static bool +neon_needs_fmgr_hook(Oid functionId) { + + return (next_needs_fmgr_hook && (*next_needs_fmgr_hook) (functionId)) + || get_func_rettype(functionId) == EVENT_TRIGGEROID; +} + +static void +LookupFuncOwnerSecDef(Oid functionId, Oid *funcOwner, bool *is_secdef) +{ + Form_pg_proc procForm; + HeapTuple proc_tup = SearchSysCache1(PROCOID, ObjectIdGetDatum(functionId)); + + if (!HeapTupleIsValid(proc_tup)) + ereport(ERROR, + (errmsg("cache lookup failed for function %u", functionId))); + + procForm = (Form_pg_proc) GETSTRUCT(proc_tup); + + *funcOwner = procForm->proowner; + *is_secdef = procForm->prosecdef; + + ReleaseSysCache(proc_tup); +} + + +PG_FUNCTION_INFO_V1(noop); +Datum noop(__attribute__ ((unused)) PG_FUNCTION_ARGS) { PG_RETURN_VOID();} + +static void +force_noop(FmgrInfo *finfo) +{ + finfo->fn_addr = (PGFunction) noop; + finfo->fn_oid = InvalidOid; /* not a known function OID anymore */ + finfo->fn_nargs = 0; /* no arguments for noop */ + finfo->fn_strict = false; + finfo->fn_retset = false; + finfo->fn_stats = 0; /* no stats collection */ + finfo->fn_extra = NULL; /* clear out old context data */ + finfo->fn_mcxt = CurrentMemoryContext; + finfo->fn_expr = NULL; /* no parse tree */ +} + + +/* + * Skip executing Event Triggers execution for superusers, because Event + * Triggers are SECURITY DEFINER and user provided code could then attempt + * privilege escalation. + * + * Also skip executing Event Triggers when GUC neon.event_triggers has been + * set to false. This might be necessary to be able to connect again after a + * LOGIN Event Trigger has been installed that would prevent connections as + * neon_superuser. + */ +static void +neon_fmgr_hook(FmgrHookEventType event, FmgrInfo *flinfo, Datum *private) +{ + /* + * It can be other needs_fmgr_hook which cause our hook to be invoked for + * non-trigger function, so recheck that is is trigger function. + */ + if (flinfo->fn_oid != InvalidOid && + get_func_rettype(flinfo->fn_oid) != EVENT_TRIGGEROID) + { + if (next_fmgr_hook) + (*next_fmgr_hook) (event, flinfo, private); + + return; + } + + /* + * The neon_superuser role can use the GUC neon.event_triggers to disable + * firing Event Trigger. + * + * SET neon.event_triggers TO false; + * + * This only applies to the neon_superuser role though, and only allows + * skipping Event Triggers owned by neon_superuser, which we check by + * proxy of the Event Trigger function being owned by neon_superuser. + * + * A role that is created in role neon_superuser should be allowed to also + * benefit from the neon_event_triggers GUC, and will be considered the + * same as the neon_superuser role. 
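+ *
+ * FHET_START fires before the trigger function itself is called, so
+ * swapping the function for a noop here (see force_noop below) is enough
+ * to prevent the trigger body from running.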
+ */ + if (event == FHET_START + && !neon_event_triggers + && is_neon_superuser()) + { + Oid neon_superuser_oid = get_role_oid("neon_superuser", false); + + /* Find the Function Attributes (owner Oid, security definer) */ + const char *fun_owner_name = NULL; + Oid fun_owner = InvalidOid; + bool fun_is_secdef = false; + + LookupFuncOwnerSecDef(flinfo->fn_oid, &fun_owner, &fun_is_secdef); + fun_owner_name = GetUserNameFromId(fun_owner, false); + + if (RoleIsNeonSuperuser(fun_owner_name) + || has_privs_of_role(fun_owner, neon_superuser_oid)) + { + elog(WARNING, + "Skipping Event Trigger: neon.event_triggers is false"); + + /* + * we can't skip execution directly inside the fmgr_hook so instead we + * change the event trigger function to a noop function. + */ + force_noop(flinfo); + } + } + + /* + * Fire Event Trigger if both function owner and current user are + * superuser, or none of them are. + */ + else if (event == FHET_START + /* still enable it to pass pg_regress tests */ + && !RegressTestMode) + { + /* + * Get the current user oid as of before SECURITY DEFINER change of + * CurrentUserId, and that would be SessionUserId. + */ + Oid current_role_oid = GetSessionUserId(); + bool role_is_super = superuser_arg(current_role_oid); + + /* Find the Function Attributes (owner Oid, security definer) */ + Oid function_owner = InvalidOid; + bool function_is_secdef = false; + bool function_is_owned_by_super = false; + + LookupFuncOwnerSecDef(flinfo->fn_oid, &function_owner, &function_is_secdef); + + function_is_owned_by_super = superuser_arg(function_owner); + + /* + * 1. Refuse to run SECURITY DEFINER function that belongs to a + * superuser when the current user is not a superuser itself. + */ + if (!role_is_super + && function_is_owned_by_super + && function_is_secdef) + { + char *func_name = get_func_name(flinfo->fn_oid); + + ereport(WARNING, + (errmsg("Skipping Event Trigger"), + errdetail("Event Trigger function \"%s\" is owned by \"%s\" " + "and is SECURITY DEFINER", + func_name, + GetUserNameFromId(function_owner, false)))); + + /* + * we can't skip execution directly inside the fmgr_hook so + * instead we change the event trigger function to a noop + * function. + */ + force_noop(flinfo); + } + + /* + * 2. Refuse to run functions that belongs to a non-superuser when the + * current user is a superuser. + * + * We could run a SECURITY DEFINER user-function here and be safe with + * privilege escalation risks, but superuser roles are only used for + * infrastructure maintenance operations, where we prefer to skip + * running user-defined code. + */ + else if (role_is_super && !function_is_owned_by_super) + { + char *func_name = get_func_name(flinfo->fn_oid); + + ereport(WARNING, + (errmsg("Skipping Event Trigger"), + errdetail("Event Trigger function \"%s\" " + "is owned by non-superuser role \"%s\", " + "and current_user \"%s\" is superuser", + func_name, + GetUserNameFromId(function_owner, false), + GetUserNameFromId(current_role_oid, false)))); + + /* + * we can't skip execution directly inside the fmgr_hook so + * instead we change the event trigger function to a noop + * function. + */ + force_noop(flinfo); + } + + } + + if (next_fmgr_hook) + (*next_fmgr_hook) (event, flinfo, private); +} + +static Oid prev_role_oid = 0; +static int prev_role_sec_context = 0; +static bool switched_to_superuser = false; + +/* + * Switch tp superuser if not yet superuser. + * Returns false if already switched to superuser. 
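+ *
+ * Typical pairing, as in ProcessCreateEventTrigger() below:
+ *
+ *     sudo = switch_to_superuser();
+ *     ... run the command with elevated privileges ...
+ *     if (sudo)
+ *         switch_to_original_role();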
+ */ +static bool +switch_to_superuser(void) +{ + Oid superuser_oid; + + if (switched_to_superuser) + return false; + switched_to_superuser = true; + + superuser_oid = get_role_oid("cloud_admin", true /*missing_ok*/); + if (superuser_oid == InvalidOid) + superuser_oid = BOOTSTRAP_SUPERUSERID; + + GetUserIdAndSecContext(&prev_role_oid, &prev_role_sec_context); + SetUserIdAndSecContext(superuser_oid, prev_role_sec_context | + SECURITY_LOCAL_USERID_CHANGE | + SECURITY_RESTRICTED_OPERATION); + return true; +} + +static void +switch_to_original_role(void) +{ + SetUserIdAndSecContext(prev_role_oid, prev_role_sec_context); + switched_to_superuser = false; +} + +/* + * ALTER ROLE ... SUPERUSER; + * + * Used internally to give superuser to a non-privileged role to allow + * ownership of superuser-only objects such as Event Trigger. + * + * ALTER ROLE foo SUPERUSER; + * ALTER EVENT TRIGGER ... OWNED BY foo; + * ALTER ROLE foo NOSUPERUSER; + * + * Now the EVENT TRIGGER is owned by foo, who can DROP it without having to be + * superuser again. + */ +static void +alter_role_super(const char* rolename, bool make_super) +{ + AlterRoleStmt *alter_stmt = makeNode(AlterRoleStmt); + + DefElem *defel_superuser = +#if PG_MAJORVERSION_NUM <= 14 + makeDefElem("superuser", (Node *) makeInteger(make_super), -1); +#else + makeDefElem("superuser", (Node *) makeBoolean(make_super), -1); +#endif + + RoleSpec *rolespec = makeNode(RoleSpec); + rolespec->roletype = ROLESPEC_CSTRING; + rolespec->rolename = pstrdup(rolename); + rolespec->location = -1; + + alter_stmt->role = rolespec; + alter_stmt->options = list_make1(defel_superuser); + +#if PG_MAJORVERSION_NUM < 15 + AlterRole(alter_stmt); +#else + /* ParseState *pstate, AlterRoleStmt *stmt */ + AlterRole(NULL, alter_stmt); +#endif + + CommandCounterIncrement(); +} + + +/* + * Changes the OWNER of an Event Trigger. + * + * Event Triggers can only be owned by superusers, so this ALTER ROLE with + * SUPERUSER and then removes the property. + */ +static void +alter_event_trigger_owner(const char *obj_name, Oid role_oid) +{ + char* role_name = GetUserNameFromId(role_oid, false); + + alter_role_super(role_name, true); + + AlterEventTriggerOwner(obj_name, role_oid); + CommandCounterIncrement(); + + alter_role_super(role_name, false); +} + + +/* + * Neon processing of the CREATE EVENT TRIGGER requires special attention and + * is worth having its own ProcessUtility_hook for that. + */ +static void +ProcessCreateEventTrigger( + PlannedStmt *pstmt, + const char *queryString, + bool readOnlyTree, + ProcessUtilityContext context, + ParamListInfo params, + QueryEnvironment *queryEnv, + DestReceiver *dest, + QueryCompletion *qc) +{ + Node *parseTree = pstmt->utilityStmt; + bool sudo = false; + + /* We double-check that after local variable declaration block */ + CreateEventTrigStmt *stmt = (CreateEventTrigStmt *) parseTree; + + /* + * We are going to change the current user privileges (sudo) and might + * need after execution cleanup. For that we want to capture the UserId + * before changing it for our sudo implementation. + */ + const Oid current_user_id = GetUserId(); + bool current_user_is_super = superuser_arg(current_user_id); + + if (nodeTag(parseTree) != T_CreateEventTrigStmt) + { + ereport(ERROR, + errcode(ERRCODE_INTERNAL_ERROR), + errmsg("ProcessCreateEventTrigger called for the wrong command")); + } + + /* + * Allow neon_superuser to create Event Trigger, while keeping the + * ownership of the object. 
+ * + * For that we give superuser membership to the role for the execution of + * the command. + */ + if (IsTransactionState() && is_neon_superuser()) + { + /* Find the Event Trigger function Oid */ + Oid func_oid = LookupFuncName(stmt->funcname, 0, NULL, false); + + /* Find the Function Owner Oid */ + Oid func_owner = InvalidOid; + bool is_secdef = false; + bool function_is_owned_by_super = false; + + LookupFuncOwnerSecDef(func_oid, &func_owner, &is_secdef); + + function_is_owned_by_super = superuser_arg(func_owner); + + if(!current_user_is_super && function_is_owned_by_super) + { + ereport(ERROR, + (errcode(ERRCODE_INSUFFICIENT_PRIVILEGE), + errmsg("Permission denied to execute " + "a function owned by a superuser role"), + errdetail("current user \"%s\" is not a superuser " + "and Event Trigger function \"%s\" " + "is owned by a superuser", + GetUserNameFromId(current_user_id, false), + NameListToString(stmt->funcname)))); + } + + if(current_user_is_super && !function_is_owned_by_super) + { + ereport(ERROR, + (errcode(ERRCODE_INSUFFICIENT_PRIVILEGE), + errmsg("Permission denied to execute " + "a function owned by a non-superuser role"), + errdetail("current user \"%s\" is a superuser " + "and function \"%s\" is " + "owned by a non-superuser", + GetUserNameFromId(current_user_id, false), + NameListToString(stmt->funcname)))); + } + + sudo = switch_to_superuser(); + } + + PG_TRY(); + { + if (PreviousProcessUtilityHook) + { + PreviousProcessUtilityHook( + pstmt, + queryString, + readOnlyTree, + context, + params, + queryEnv, + dest, + qc); + } + else + { + standard_ProcessUtility( + pstmt, + queryString, + readOnlyTree, + context, + params, + queryEnv, + dest, + qc); + } + + /* + * Now that the Event Trigger has been installed via our sudo + * mechanism, if the original role was not a superuser then change + * the event trigger ownership back to the original role. + * + * That way [ ALTER | DROP ] EVENT TRIGGER commands just work. + */ + if (IsTransactionState() && is_neon_superuser()) + { + if (!current_user_is_super) + { + /* + * Change event trigger owner to the current role (making + * it a privileged role during the ALTER OWNER command). + */ + alter_event_trigger_owner(stmt->trigname, current_user_id); + } + } + } + PG_FINALLY(); + { + if (sudo) + switch_to_original_role(); + } + PG_END_TRY(); +} + + +/* + * Neon hooks for DDLs (handling privileges, limiting features, etc). + */ static void NeonProcessUtility( PlannedStmt *pstmt, @@ -795,6 +1267,27 @@ NeonProcessUtility( { Node *parseTree = pstmt->utilityStmt; + /* + * The process utility hook for CREATE EVENT TRIGGER is its own + * implementation and warrant being addressed separately from here. 
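+ *
+ * In short, ProcessCreateEventTrigger (1) rejects mixed superuser and
+ * non-superuser combinations of current user and trigger function owner,
+ * (2) temporarily escalates via switch_to_superuser() to run the command,
+ * and (3) hands ownership of the new trigger back to the original role.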
+ */ + if (nodeTag(parseTree) == T_CreateEventTrigStmt) + { + ProcessCreateEventTrigger( + pstmt, + queryString, + readOnlyTree, + context, + params, + queryEnv, + dest, + qc); + return; + } + + /* + * Other commands that need Neon specific implementations are handled here: + */ switch (nodeTag(parseTree)) { case T_CreatedbStmt: @@ -833,37 +1326,82 @@ NeonProcessUtility( if (PreviousProcessUtilityHook) { PreviousProcessUtilityHook( - pstmt, - queryString, - readOnlyTree, - context, - params, - queryEnv, - dest, - qc); + pstmt, + queryString, + readOnlyTree, + context, + params, + queryEnv, + dest, + qc); } else { standard_ProcessUtility( - pstmt, - queryString, - readOnlyTree, - context, - params, - queryEnv, - dest, - qc); + pstmt, + queryString, + readOnlyTree, + context, + params, + queryEnv, + dest, + qc); } } +/* + * Only neon_superuser is granted privilege to edit neon.event_triggers GUC. + */ +static void +neon_event_triggers_assign_hook(bool newval, void *extra) +{ + /* MyDatabaseId == InvalidOid || !OidIsValid(GetUserId()) */ + + if (IsTransactionState() && !is_neon_superuser()) + { + ereport(ERROR, + (errcode(ERRCODE_INSUFFICIENT_PRIVILEGE), + errmsg("permission denied to set neon.event_triggers"), + errdetail("Only \"neon_superuser\" is allowed to set the GUC"))); + } +} + + void -InitControlPlaneConnector() +InitDDLHandler() { PreviousProcessUtilityHook = ProcessUtility_hook; ProcessUtility_hook = NeonProcessUtility; + + next_needs_fmgr_hook = needs_fmgr_hook; + needs_fmgr_hook = neon_needs_fmgr_hook; + + next_fmgr_hook = fmgr_hook; + fmgr_hook = neon_fmgr_hook; + RegisterXactCallback(NeonXactCallback, NULL); RegisterSubXactCallback(NeonSubXactCallback, NULL); + /* + * The GUC neon.event_triggers should provide the same effect as the + * Postgres GUC event_triggers, but the neon one is PGC_USERSET. + * + * This allows using the GUC in the connection string and work out of a + * LOGIN Event Trigger that would break database access, all without + * having to edit and reload the Postgres configuration file. + */ + DefineCustomBoolVariable( + "neon.event_triggers", + "Enable firing of event triggers", + NULL, + &neon_event_triggers, + true, + PGC_USERSET, + 0, + NULL, + neon_event_triggers_assign_hook, + NULL); + DefineCustomStringVariable( "neon.console_url", "URL of the Neon Console, which will be forwarded changes to dbs and roles", diff --git a/pgxn/neon/neon_ddl_handler.h b/pgxn/neon/neon_ddl_handler.h new file mode 100644 index 0000000000..de18ed3d82 --- /dev/null +++ b/pgxn/neon/neon_ddl_handler.h @@ -0,0 +1,6 @@ +#ifndef CONTROL_DDL_HANDLER_H +#define CONTROL_DDL_HANDLER_H + +void InitDDLHandler(void); + +#endif diff --git a/pgxn/neon/walproposer.c b/pgxn/neon/walproposer.c index 91d39345e2..ba6e4a54ff 100644 --- a/pgxn/neon/walproposer.c +++ b/pgxn/neon/walproposer.c @@ -1135,7 +1135,7 @@ VotesCollectedMset(WalProposer *wp, MemberSet *mset, Safekeeper **msk, StringInf wp->propTermStartLsn = sk->voteResponse.flushLsn; wp->donor = sk; } - wp->truncateLsn = Max(wp->safekeeper[i].voteResponse.truncateLsn, wp->truncateLsn); + wp->truncateLsn = Max(sk->voteResponse.truncateLsn, wp->truncateLsn); if (n_votes > 0) appendStringInfoString(s, ", "); diff --git a/pgxn/neon/walproposer.h b/pgxn/neon/walproposer.h index 08087e5a55..4b223b6b18 100644 --- a/pgxn/neon/walproposer.h +++ b/pgxn/neon/walproposer.h @@ -679,8 +679,7 @@ typedef struct walproposer_api * Finish sync safekeepers with the given LSN. This function should not * return and should exit the program. 
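 * (The __attribute__((noreturn)) annotation added below encodes this
 * contract for the compiler.)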
*/ - void (*finish_sync_safekeepers) (WalProposer *wp, XLogRecPtr lsn); - + void (*finish_sync_safekeepers) (WalProposer *wp, XLogRecPtr lsn) __attribute__((noreturn)) ; /* * Called after every AppendResponse from the safekeeper. Used to * propagate backpressure feedback and to confirm WAL persistence (has diff --git a/pgxn/neon/walproposer_pg.c b/pgxn/neon/walproposer_pg.c index 3d6a92ad79..185fc83ace 100644 --- a/pgxn/neon/walproposer_pg.c +++ b/pgxn/neon/walproposer_pg.c @@ -1890,7 +1890,7 @@ walprop_pg_wait_event_set(WalProposer *wp, long timeout, Safekeeper **sk, uint32 return rc; } -static void +static void __attribute__((noreturn)) walprop_pg_finish_sync_safekeepers(WalProposer *wp, XLogRecPtr lsn) { fprintf(stdout, "%X/%X\n", LSN_FORMAT_ARGS(lsn)); diff --git a/poetry.lock b/poetry.lock index 21a2664555..1bc5077eb7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -746,23 +746,23 @@ xray = ["mypy-boto3-xray (>=1.26.0,<1.27.0)"] [[package]] name = "botocore" -version = "1.34.11" +version = "1.34.162" description = "Low-level, data-driven core of boto 3." optional = false -python-versions = ">= 3.8" +python-versions = ">=3.8" groups = ["main"] files = [ - {file = "botocore-1.34.11-py3-none-any.whl", hash = "sha256:1ff1398b6ea670e1c01ac67a33af3da854f8e700d3528289c04f319c330d8250"}, - {file = "botocore-1.34.11.tar.gz", hash = "sha256:51905c3d623c60df5dc5794387de7caf886d350180a01a3dfa762e903edb45a9"}, + {file = "botocore-1.34.162-py3-none-any.whl", hash = "sha256:2d918b02db88d27a75b48275e6fb2506e9adaaddbec1ffa6a8a0898b34e769be"}, + {file = "botocore-1.34.162.tar.gz", hash = "sha256:adc23be4fb99ad31961236342b7cbf3c0bfc62532cd02852196032e8c0d682f3"}, ] [package.dependencies] jmespath = ">=0.7.1,<2.0.0" python-dateutil = ">=2.1,<3.0.0" -urllib3 = {version = ">=1.25.4,<2.1", markers = "python_version >= \"3.10\""} +urllib3 = {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""} [package.extras] -crt = ["awscrt (==0.19.19)"] +crt = ["awscrt (==0.21.2)"] [[package]] name = "botocore-stubs" @@ -3051,19 +3051,19 @@ files = [ [[package]] name = "requests" -version = "2.32.3" +version = "2.32.4" description = "Python HTTP for Humans." optional = false python-versions = ">=3.8" groups = ["main"] files = [ - {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, - {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, + {file = "requests-2.32.4-py3-none-any.whl", hash = "sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c"}, + {file = "requests-2.32.4.tar.gz", hash = "sha256:27d0316682c8a29834d3264820024b62a36942083d52caf2f14c0591336d3422"}, ] [package.dependencies] certifi = ">=2017.4.17" -charset-normalizer = ">=2,<4" +charset_normalizer = ">=2,<4" idna = ">=2.5,<4" urllib3 = ">=1.21.1,<3" @@ -3422,20 +3422,21 @@ files = [ [[package]] name = "urllib3" -version = "1.26.19" +version = "2.5.0" description = "HTTP library with thread-safe connection pooling, file post, and more." 
optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" +python-versions = ">=3.9" groups = ["main"] files = [ - {file = "urllib3-1.26.19-py2.py3-none-any.whl", hash = "sha256:37a0344459b199fce0e80b0d3569837ec6b6937435c5244e7fd73fa6006830f3"}, - {file = "urllib3-1.26.19.tar.gz", hash = "sha256:3e3d753a8618b86d7de333b4223005f68720bcd6a7d2bcb9fbd2229ec7c1e429"}, + {file = "urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc"}, + {file = "urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760"}, ] [package.extras] -brotli = ["brotli (==1.0.9) ; os_name != \"nt\" and python_version < \"3\" and platform_python_implementation == \"CPython\"", "brotli (>=1.0.9) ; python_version >= \"3\" and platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; (os_name != \"nt\" or python_version >= \"3\") and platform_python_implementation != \"CPython\"", "brotlipy (>=0.6.0) ; os_name == \"nt\" and python_version < \"3\""] -secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress ; python_version == \"2.7\"", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] -socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] +brotli = ["brotli (>=1.0.9) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\""] +h2 = ["h2 (>=4,<5)"] +socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] +zstd = ["zstandard (>=0.18.0)"] [[package]] name = "websockets" @@ -3846,4 +3847,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.1" python-versions = "^3.11" -content-hash = "7ab1e7b975af34b3271b7c6018fa22a261d3f73c7c0a0403b6b2bb86b5fbd36e" +content-hash = "bd93313f110110aa53b24a3ed47ba2d7f60e2c658a79cdff7320fed1bb1b57b5" diff --git a/proxy/src/auth/backend/classic.rs b/proxy/src/auth/backend/classic.rs index 8445368740..f35b3ecc05 100644 --- a/proxy/src/auth/backend/classic.rs +++ b/proxy/src/auth/backend/classic.rs @@ -18,11 +18,6 @@ pub(super) async fn authenticate( secret: AuthSecret, ) -> auth::Result { let scram_keys = match secret { - #[cfg(any(test, feature = "testing"))] - AuthSecret::Md5(_) => { - debug!("auth endpoint chooses MD5"); - return Err(auth::AuthError::MalformedPassword("MD5 not supported")); - } AuthSecret::Scram(secret) => { debug!("auth endpoint chooses SCRAM"); diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index c388848926..8440d198df 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -6,18 +6,17 @@ use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{info, info_span}; -use super::ComputeCredentialKeys; -use crate::auth::IpPattern; use crate::auth::backend::ComputeUserInfo; use crate::cache::Cached; +use crate::compute::AuthInfo; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::client::cplane_proxy_v1; use crate::control_plane::{self, CachedNodeInfo, NodeInfo}; use crate::error::{ReportableError, UserFacingError}; -use crate::pglb::connect_compute::ComputeConnectBackend; use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; +use crate::proxy::wake_compute::WakeComputeBackend; use crate::stream::PqStream; use crate::types::RoleName; use crate::{auth, compute, waiters}; @@ -98,15 +97,11 @@ impl ConsoleRedirectBackend { ctx: &RequestContext, auth_config: &'static AuthenticationConfig, client: &mut 
PqStream, - ) -> auth::Result<( - ConsoleRedirectNodeInfo, - ComputeUserInfo, - Option>, - )> { + ) -> auth::Result<(ConsoleRedirectNodeInfo, AuthInfo, ComputeUserInfo)> { authenticate(ctx, auth_config, &self.console_uri, client) .await - .map(|(node_info, user_info, ip_allowlist)| { - (ConsoleRedirectNodeInfo(node_info), user_info, ip_allowlist) + .map(|(node_info, auth_info, user_info)| { + (ConsoleRedirectNodeInfo(node_info), auth_info, user_info) }) } } @@ -114,17 +109,13 @@ impl ConsoleRedirectBackend { pub struct ConsoleRedirectNodeInfo(pub(super) NodeInfo); #[async_trait] -impl ComputeConnectBackend for ConsoleRedirectNodeInfo { +impl WakeComputeBackend for ConsoleRedirectNodeInfo { async fn wake_compute( &self, _ctx: &RequestContext, ) -> Result { Ok(Cached::new_uncached(self.0.clone())) } - - fn get_keys(&self) -> &ComputeCredentialKeys { - &ComputeCredentialKeys::None - } } async fn authenticate( @@ -132,7 +123,7 @@ async fn authenticate( auth_config: &'static AuthenticationConfig, link_uri: &reqwest::Url, client: &mut PqStream, -) -> auth::Result<(NodeInfo, ComputeUserInfo, Option>)> { +) -> auth::Result<(NodeInfo, AuthInfo, ComputeUserInfo)> { ctx.set_auth_method(crate::context::AuthMethod::ConsoleRedirect); // registering waiter can fail if we get unlucky with rng. @@ -192,10 +183,24 @@ async fn authenticate( client.write_message(BeMessage::NoticeResponse("Connecting to database.")); - // This config should be self-contained, because we won't - // take username or dbname from client's startup message. - let mut config = compute::ConnCfg::new(db_info.host.to_string(), db_info.port); - config.dbname(&db_info.dbname).user(&db_info.user); + // Backwards compatibility. pg_sni_proxy uses "--" in domain names + // while direct connections do not. Once we migrate to pg_sni_proxy + // everywhere, we can remove this. + let ssl_mode = if db_info.host.contains("--") { + // we need TLS connection with SNI info to properly route it + SslMode::Require + } else { + SslMode::Disable + }; + + let conn_info = compute::ConnectInfo { + host: db_info.host.into(), + port: db_info.port, + ssl_mode, + host_addr: None, + }; + let auth_info = + AuthInfo::for_console_redirect(&db_info.dbname, &db_info.user, db_info.password.as_deref()); let user: RoleName = db_info.user.into(); let user_info = ComputeUserInfo { @@ -209,26 +214,12 @@ async fn authenticate( ctx.set_project(db_info.aux.clone()); info!("woken up a compute node"); - // Backwards compatibility. pg_sni_proxy uses "--" in domain names - // while direct connections do not. Once we migrate to pg_sni_proxy - // everywhere, we can remove this. 
- if db_info.host.contains("--") { - // we need TLS connection with SNI info to properly route it - config.ssl_mode(SslMode::Require); - } else { - config.ssl_mode(SslMode::Disable); - } - - if let Some(password) = db_info.password { - config.password(password.as_ref()); - } - Ok(( NodeInfo { - config, + conn_info, aux: db_info.aux, }, + auth_info, user_info, - db_info.allowed_ips, )) } diff --git a/proxy/src/auth/backend/jwt.rs b/proxy/src/auth/backend/jwt.rs index a48f67199a..5edc878243 100644 --- a/proxy/src/auth/backend/jwt.rs +++ b/proxy/src/auth/backend/jwt.rs @@ -4,6 +4,8 @@ use std::sync::Arc; use std::time::{Duration, SystemTime}; use arc_swap::ArcSwapOption; +use base64::Engine as _; +use base64::prelude::BASE64_URL_SAFE_NO_PAD; use clashmap::ClashMap; use jose_jwk::crypto::KeyInfo; use reqwest::{Client, redirect}; @@ -347,17 +349,17 @@ impl JwkCacheEntryLock { .split_once('.') .ok_or(JwtEncodingError::InvalidCompactForm)?; - let header = base64::decode_config(header, base64::URL_SAFE_NO_PAD)?; + let header = BASE64_URL_SAFE_NO_PAD.decode(header)?; let header = serde_json::from_slice::>(&header)?; - let payloadb = base64::decode_config(payload, base64::URL_SAFE_NO_PAD)?; + let payloadb = BASE64_URL_SAFE_NO_PAD.decode(payload)?; let payload = serde_json::from_slice::>(&payloadb)?; if let Some(iss) = &payload.issuer { ctx.set_jwt_issuer(iss.as_ref().to_owned()); } - let sig = base64::decode_config(signature, base64::URL_SAFE_NO_PAD)?; + let sig = BASE64_URL_SAFE_NO_PAD.decode(signature)?; let kid = header.key_id.ok_or(JwtError::MissingKeyId)?; @@ -796,7 +798,6 @@ mod tests { use std::net::SocketAddr; use std::time::SystemTime; - use base64::URL_SAFE_NO_PAD; use bytes::Bytes; use http::Response; use http_body_util::Full; @@ -871,9 +872,8 @@ mod tests { key_id: Some(Cow::Owned(kid)), }; - let header = - base64::encode_config(serde_json::to_string(&header).unwrap(), URL_SAFE_NO_PAD); - let body = base64::encode_config(serde_json::to_string(&body).unwrap(), URL_SAFE_NO_PAD); + let header = BASE64_URL_SAFE_NO_PAD.encode(serde_json::to_string(&header).unwrap()); + let body = BASE64_URL_SAFE_NO_PAD.encode(serde_json::to_string(&body).unwrap()); format!("{header}.{body}") } @@ -883,7 +883,7 @@ mod tests { let payload = build_jwt_payload(kid, jose_jwa::Signing::Es256); let sig: Signature = SigningKey::from(key).sign(payload.as_bytes()); - let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD); + let sig = BASE64_URL_SAFE_NO_PAD.encode(sig.to_bytes()); format!("{payload}.{sig}") } @@ -893,7 +893,7 @@ mod tests { let payload = build_custom_jwt_payload(kid, body, jose_jwa::Signing::Es256); let sig: Signature = SigningKey::from(key).sign(payload.as_bytes()); - let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD); + let sig = BASE64_URL_SAFE_NO_PAD.encode(sig.to_bytes()); format!("{payload}.{sig}") } @@ -904,7 +904,7 @@ mod tests { let payload = build_jwt_payload(kid, jose_jwa::Signing::Rs256); let sig = SigningKey::::new(key).sign(payload.as_bytes()); - let sig = base64::encode_config(sig.to_bytes(), URL_SAFE_NO_PAD); + let sig = BASE64_URL_SAFE_NO_PAD.encode(sig.to_bytes()); format!("{payload}.{sig}") } diff --git a/proxy/src/auth/backend/local.rs b/proxy/src/auth/backend/local.rs index 7a6dceb194..2224f492b8 100644 --- a/proxy/src/auth/backend/local.rs +++ b/proxy/src/auth/backend/local.rs @@ -1,11 +1,12 @@ use std::net::SocketAddr; use arc_swap::ArcSwapOption; +use postgres_client::config::SslMode; use tokio::sync::Semaphore; use super::jwt::{AuthRule, FetchAuthRules}; 
use crate::auth::backend::jwt::FetchAuthRulesError; -use crate::compute::ConnCfg; +use crate::compute::ConnectInfo; use crate::compute_ctl::ComputeCtlApi; use crate::context::RequestContext; use crate::control_plane::NodeInfo; @@ -29,7 +30,12 @@ impl LocalBackend { api: http::Endpoint::new(compute_ctl, http::new_client()), }, node_info: NodeInfo { - config: ConnCfg::new(postgres_addr.ip().to_string(), postgres_addr.port()), + conn_info: ConnectInfo { + host_addr: Some(postgres_addr.ip()), + host: postgres_addr.ip().to_string().into(), + port: postgres_addr.port(), + ssl_mode: SslMode::Disable, + }, // TODO(conrad): make this better reflect compute info rather than endpoint info. aux: MetricsAuxInfo { endpoint_id: EndpointIdTag::get_interner().get_or_intern("local"), diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index f978f655c4..2e3013ead0 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -14,20 +14,21 @@ use serde::{Deserialize, Serialize}; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::{debug, info}; -use crate::auth::{self, AuthError, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange}; +use crate::auth::{self, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange}; use crate::cache::Cached; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::client::ControlPlaneClient; use crate::control_plane::errors::GetAuthInfoError; +use crate::control_plane::messages::EndpointRateLimitConfig; use crate::control_plane::{ self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl, RoleAccessControl, }; use crate::intern::EndpointIdInt; -use crate::pglb::connect_compute::ComputeConnectBackend; use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; +use crate::proxy::wake_compute::WakeComputeBackend; use crate::rate_limiter::EndpointRateLimiter; use crate::stream::Stream; use crate::types::{EndpointCacheKey, EndpointId, RoleName}; @@ -168,8 +169,6 @@ impl ComputeUserInfo { #[cfg_attr(test, derive(Debug))] pub(crate) enum ComputeCredentialKeys { - #[cfg(any(test, feature = "testing"))] - Password(Vec), AuthKeys(AuthKeys), JwtPayload(Vec), None, @@ -232,11 +231,8 @@ async fn auth_quirks( config.is_vpc_acccess_proxy, )?; - let endpoint = EndpointIdInt::from(&info.endpoint); - let rate_limit_config = None; - if !endpoint_rate_limiter.check(endpoint, rate_limit_config, 1) { - return Err(AuthError::too_many_connections()); - } + access_controls.connection_attempt_rate_limit(ctx, &info.endpoint, &endpoint_rate_limiter)?; + let role_access = api .get_role_access_control(ctx, &info.endpoint, &info.user) .await?; @@ -403,29 +399,23 @@ impl Backend<'_, ComputeUserInfo> { allowed_ips: Arc::new(vec![]), allowed_vpce: Arc::new(vec![]), flags: AccessBlockerFlags::default(), + rate_limits: EndpointRateLimitConfig::default(), }), } } } #[async_trait::async_trait] -impl ComputeConnectBackend for Backend<'_, ComputeCredentials> { +impl WakeComputeBackend for Backend<'_, ComputeUserInfo> { async fn wake_compute( &self, ctx: &RequestContext, ) -> Result { match self { - Self::ControlPlane(api, creds) => api.wake_compute(ctx, &creds.info).await, + Self::ControlPlane(api, info) => api.wake_compute(ctx, info).await, Self::Local(local) => Ok(Cached::new_uncached(local.node_info.clone())), } } - - fn get_keys(&self) -> &ComputeCredentialKeys { - match self { - Self::ControlPlane(_, creds) => &creds.keys, - Self::Local(_) => &ComputeCredentialKeys::None, 
- } - } } #[cfg(test)] @@ -448,6 +438,7 @@ mod tests { use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern}; use crate::config::AuthenticationConfig; use crate::context::RequestContext; + use crate::control_plane::messages::EndpointRateLimitConfig; use crate::control_plane::{ self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl, }; @@ -486,6 +477,7 @@ mod tests { allowed_ips: Arc::new(self.ips.clone()), allowed_vpce: Arc::new(self.vpc_endpoint_ids.clone()), flags: self.access_blocker_flags, + rate_limits: EndpointRateLimitConfig::default(), }) } diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index 8fbc4577e9..c825d5bf4b 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -169,13 +169,6 @@ pub(crate) async fn validate_password_and_exchange( secret: AuthSecret, ) -> super::Result> { match secret { - #[cfg(any(test, feature = "testing"))] - AuthSecret::Md5(_) => { - // test only - Ok(sasl::Outcome::Success(ComputeCredentialKeys::Password( - password.to_owned(), - ))) - } // perform scram authentication as both client and server to validate the keys AuthSecret::Scram(scram_secret) => { let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?; diff --git a/proxy/src/batch.rs b/proxy/src/batch.rs new file mode 100644 index 0000000000..33e08797f2 --- /dev/null +++ b/proxy/src/batch.rs @@ -0,0 +1,180 @@ +//! Batch processing system based on intrusive linked lists. +//! +//! Enqueuing a batch job requires no allocations, with +//! direct support for cancelling jobs early. +use std::collections::BTreeMap; +use std::pin::pin; +use std::sync::Mutex; + +use scopeguard::ScopeGuard; +use tokio::sync::oneshot::error::TryRecvError; + +use crate::ext::LockExt; + +pub trait QueueProcessing: Send + 'static { + type Req: Send + 'static; + type Res: Send; + + /// Get the desired batch size. + fn batch_size(&self, queue_size: usize) -> usize; + + /// This applies a full batch of events. + /// Must respond with a full batch of replies. + /// + /// If this apply can error, it's expected that errors be forwarded to each Self::Res. + /// + /// Batching does not need to happen atomically. + fn apply(&mut self, req: Vec) -> impl Future> + Send; +} + +pub struct BatchQueue { + processor: tokio::sync::Mutex
<P>,
+    inner: Mutex<BatchQueueInner<P>>,
+}
+
+struct BatchJob<P: QueueProcessing> {
+    req: P::Req,
+    res: tokio::sync::oneshot::Sender<P::Res>,
+}
+
+impl<P: QueueProcessing> BatchQueue<P> {
+    pub fn new(p: P) -> Self {
+        Self {
+            processor: tokio::sync::Mutex::new(p),
+            inner: Mutex::new(BatchQueueInner {
+                version: 0,
+                queue: BTreeMap::new(),
+            }),
+        }
+    }
+
+    /// Perform a single request-response process; this may be batched internally.
+    ///
+    /// This function is not cancel safe.
+    pub async fn call<E>(
+        &self,
+        req: P::Req,
+        cancelled: impl Future<Output = E>,
+    ) -> Result<P::Res, E> {
+        let (id, mut rx) = self.inner.lock_propagate_poison().register_job(req);
+
+        let mut cancelled = pin!(cancelled);
+        let resp = loop {
+            // try to become the leader, or wait for success.
+            let mut processor = tokio::select! {
+                // try to become leader.
+                p = self.processor.lock() => p,
+                // wait for success.
+                resp = &mut rx => break resp.ok(),
+                // wait for cancellation.
+                cancel = cancelled.as_mut() => {
+                    let mut inner = self.inner.lock_propagate_poison();
+                    if inner.queue.remove(&id).is_some() {
+                        tracing::warn!("batched task cancelled before completion");
+                    }
+                    return Err(cancel);
+                },
+            };
+
+            tracing::debug!(id, "batch: became leader");
+            let (reqs, resps) = self.inner.lock_propagate_poison().get_batch(&processor);
+
+            // snitch in case the task gets cancelled.
+            let cancel_safety = scopeguard::guard((), |()| {
+                if !std::thread::panicking() {
+                    tracing::error!(
+                        id,
+                        "batch: leader cancelled, despite not being cancellation safe"
+                    );
+                }
+            });
+
+            // apply a batch.
+            // if this is cancelled, jobs will not be completed and will panic.
+            let values = processor.apply(reqs).await;
+
+            // good: we didn't get cancelled.
+            ScopeGuard::into_inner(cancel_safety);
+
+            if values.len() != resps.len() {
+                tracing::error!(
+                    "batch: invalid response size, expected={}, got={}",
+                    resps.len(),
+                    values.len()
+                );
+            }
+
+            // send response values.
+            for (tx, value) in std::iter::zip(resps, values) {
+                if tx.send(value).is_err() {
+                    // receiver hung up but that's fine.
+                }
+            }
+
+            match rx.try_recv() {
+                Ok(resp) => break Some(resp),
+                Err(TryRecvError::Closed) => break None,
+                // edge case - there was a race condition where
+                // we became the leader but were not in the batch.
+                //
+                // Example:
+                // thread 1: register job id=1
+                // thread 2: register job id=2
+                // thread 2: processor.lock().await
+                // thread 1: processor.lock().await
+                // thread 2: becomes leader, batch_size=1, jobs=[1].
+                Err(TryRecvError::Empty) => {}
+            }
+        };
+
+        tracing::debug!(id, "batch: job completed");
+
+        Ok(resp.expect("no response found. batch processor should not panic"))
+    }
+}
+
+struct BatchQueueInner<P: QueueProcessing> {
+    version: u64,
+    queue: BTreeMap<u64, BatchJob<P>>,
+}
+
+impl<P: QueueProcessing> BatchQueueInner<P>
{ + fn register_job(&mut self, req: P::Req) -> (u64, tokio::sync::oneshot::Receiver) { + let (tx, rx) = tokio::sync::oneshot::channel(); + + let id = self.version; + + // Overflow concern: + // This is a u64, and we might enqueue 2^16 tasks per second. + // This gives us 2^48 seconds (9 million years). + // Even if this does overflow, it will not break, but some + // jobs with the higher version might never get prioritised. + self.version += 1; + + self.queue.insert(id, BatchJob { req, res: tx }); + + tracing::debug!(id, "batch: registered job in the queue"); + + (id, rx) + } + + fn get_batch(&mut self, p: &P) -> (Vec, Vec>) { + let batch_size = p.batch_size(self.queue.len()); + let mut reqs = Vec::with_capacity(batch_size); + let mut resps = Vec::with_capacity(batch_size); + let mut ids = Vec::with_capacity(batch_size); + + while reqs.len() < batch_size { + let Some((id, job)) = self.queue.pop_first() else { + break; + }; + reqs.push(job.req); + resps.push(job.res); + ids.push(id); + } + + tracing::debug!(ids=?ids, "batch: acquired jobs"); + + (reqs, resps) + } +} diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs index ba10fce7b4..423ecf821e 100644 --- a/proxy/src/binary/local_proxy.rs +++ b/proxy/src/binary/local_proxy.rs @@ -201,7 +201,7 @@ pub async fn run() -> anyhow::Result<()> { auth_backend, http_listener, shutdown.clone(), - Arc::new(CancellationHandler::new(&config.connect_to_compute, None)), + Arc::new(CancellationHandler::new(&config.connect_to_compute)), endpoint_rate_limiter, ); @@ -279,7 +279,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig }, proxy_protocol_v2: config::ProxyProtocolV2::Rejected, handshake_timeout: Duration::from_secs(10), - region: "local".into(), wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?, connect_compute_locks, connect_to_compute: compute_config, diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs index a4f517fead..070c73cdcf 100644 --- a/proxy/src/binary/pg_sni_router.rs +++ b/proxy/src/binary/pg_sni_router.rs @@ -28,10 +28,9 @@ use crate::context::RequestContext; use crate::metrics::{Metrics, ThreadPoolMetrics}; use crate::pqproto::FeStartupPacket; use crate::protocol2::ConnectionInfo; -use crate::proxy::{ - ErrorSource, TlsRequired, copy_bidirectional_client_compute, run_until_cancelled, -}; +use crate::proxy::{ErrorSource, TlsRequired, copy_bidirectional_client_compute}; use crate::stream::{PqStream, Stream}; +use crate::util::run_until_cancelled; project_git_version!(GIT_VERSION); @@ -237,7 +236,6 @@ pub(super) async fn task_main( extra: None, }, crate::metrics::Protocol::SniRouter, - "sni", ); handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await } diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 757c1e988b..9ead05d492 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -11,17 +11,20 @@ use anyhow::Context; use anyhow::{bail, ensure}; use arc_swap::ArcSwapOption; use futures::future::Either; +use itertools::{Itertools, Position}; +use rand::{Rng, thread_rng}; use remote_storage::RemoteStorageConfig; use tokio::net::TcpListener; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; -use tracing::{Instrument, info, warn}; +use tracing::{Instrument, error, info, warn}; use utils::sentry_init::init_sentry; use utils::{project_build_tag, project_git_version}; use crate::auth::backend::jwt::JwkCache; use 
crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned}; -use crate::cancellation::{CancellationHandler, handle_cancel_messages}; +use crate::batch::BatchQueue; +use crate::cancellation::{CancellationHandler, CancellationProcessor}; use crate::config::{ self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions, ProxyConfig, ProxyProtocolV2, remote_storage_from_toml, @@ -120,12 +123,6 @@ struct ProxyCliArgs { /// timeout for the TLS handshake #[clap(long, default_value = "15s", value_parser = humantime::parse_duration)] handshake_timeout: tokio::time::Duration, - /// http endpoint to receive periodic metric updates - #[clap(long)] - metric_collection_endpoint: Option, - /// how often metrics should be sent to a collection endpoint - #[clap(long)] - metric_collection_interval: Option, /// cache for `wake_compute` api method (use `size=0` to disable) #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)] wake_compute_cache: String, @@ -152,40 +149,31 @@ struct ProxyCliArgs { /// Wake compute rate limiter max number of requests per second. #[clap(long, default_values_t = RateBucketInfo::DEFAULT_SET)] wake_compute_limit: Vec, - /// Redis rate limiter max number of requests per second. - #[clap(long, default_values_t = RateBucketInfo::DEFAULT_REDIS_SET)] - redis_rps_limit: Vec, /// Cancellation channel size (max queue size for redis kv client) #[clap(long, default_value_t = 1024)] cancellation_ch_size: usize, /// Cancellation ops batch size for redis #[clap(long, default_value_t = 8)] cancellation_batch_size: usize, - /// cache for `allowed_ips` (use `size=0` to disable) - #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)] - allowed_ips_cache: String, - /// cache for `role_secret` (use `size=0` to disable) - #[clap(long, default_value = config::CacheOptions::CACHE_DEFAULT_OPTIONS)] - role_secret_cache: String, - /// redis url for notifications (if empty, redis_host:port will be used for both notifications and streaming connections) - #[clap(long)] - redis_notifications: Option, - /// what from the available authentications type to use for the regional redis we have. Supported are "irsa" and "plain". + /// redis url for plain authentication + #[clap(long, alias("redis-notifications"))] + redis_plain: Option, + /// what from the available authentications type to use for redis. Supported are "irsa" and "plain". 
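+    /// ("plain" requires --redis-plain to be set; "irsa" requires
+    /// --redis-host and --redis-port, see configure_redis below.)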
#[clap(long, default_value = "irsa")] redis_auth_type: String, - /// redis host for streaming connections (might be different from the notifications host) + /// redis host for irsa authentication #[clap(long)] redis_host: Option, - /// redis port for streaming connections (might be different from the notifications host) + /// redis port for irsa authentication #[clap(long)] redis_port: Option, - /// redis cluster name, used in aws elasticache + /// redis cluster name for irsa authentication #[clap(long)] redis_cluster_name: Option, - /// redis user_id, used in aws elasticache + /// redis user_id for irsa authentication #[clap(long)] redis_user_id: Option, - /// aws region to retrieve credentials + /// aws region for irsa authentication #[clap(long, default_value_t = String::new())] aws_region: String, /// cache for `project_info` (use `size=0` to disable) @@ -197,6 +185,12 @@ struct ProxyCliArgs { #[clap(flatten)] parquet_upload: ParquetUploadArgs, + /// http endpoint to receive periodic metric updates + #[clap(long)] + metric_collection_endpoint: Option, + /// how often metrics should be sent to a collection endpoint + #[clap(long)] + metric_collection_interval: Option, /// interval for backup metric collection #[clap(long, default_value = "10m", value_parser = humantime::parse_duration)] metric_backup_collection_interval: std::time::Duration, @@ -209,6 +203,7 @@ struct ProxyCliArgs { /// Size of each event is no more than 400 bytes, so 2**22 is about 200MB before the compression. #[clap(long, default_value = "4194304")] metric_backup_collection_chunk_size: usize, + /// Whether to retry the connection to the compute node #[clap(long, default_value = config::RetryConfig::CONNECT_TO_COMPUTE_DEFAULT_VALUES)] connect_to_compute_retry: String, @@ -314,7 +309,7 @@ pub async fn run() -> anyhow::Result<()> { let jemalloc = match crate::jemalloc::MetricRecorder::new() { Ok(t) => Some(t), Err(e) => { - tracing::error!(error = ?e, "could not start jemalloc metrics loop"); + error!(error = ?e, "could not start jemalloc metrics loop"); None } }; @@ -328,7 +323,7 @@ pub async fn run() -> anyhow::Result<()> { Either::Right(auth_backend) => info!("Authentication backend: {auth_backend:?}"), } info!("Using region: {}", args.aws_region); - let (regional_redis_client, redis_notifications_client) = configure_redis(&args).await?; + let redis_client = configure_redis(&args).await?; // Check that we can bind to address before further initialization info!("Starting http on {}", args.http); @@ -383,20 +378,7 @@ pub async fn run() -> anyhow::Result<()> { let cancellation_token = CancellationToken::new(); - let redis_rps_limit = Vec::leak(args.redis_rps_limit.clone()); - RateBucketInfo::validate(redis_rps_limit)?; - - let redis_kv_client = regional_redis_client - .as_ref() - .map(|redis_publisher| RedisKVClient::new(redis_publisher.clone(), redis_rps_limit)); - - // channel size should be higher than redis client limit to avoid blocking - let cancel_ch_size = args.cancellation_ch_size; - let (tx_cancel, rx_cancel) = tokio::sync::mpsc::channel(cancel_ch_size); - let cancellation_handler = Arc::new(CancellationHandler::new( - &config.connect_to_compute, - Some(tx_cancel), - )); + let cancellation_handler = Arc::new(CancellationHandler::new(&config.connect_to_compute)); let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards( RateBucketInfo::to_leaky_bucket(&args.endpoint_rps_limit) @@ -475,6 +457,7 @@ pub async fn run() -> anyhow::Result<()> { client_tasks.spawn(crate::context::parquet::worker( 
cancellation_token.clone(), args.parquet_upload, + args.region, )); // maintenance tasks. these never return unless there's an error @@ -498,53 +481,47 @@ pub async fn run() -> anyhow::Result<()> { #[cfg_attr(not(any(test, feature = "testing")), expect(irrefutable_let_patterns))] if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend { if let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api { - match (redis_notifications_client, regional_redis_client.clone()) { - (None, None) => {} - (client1, client2) => { - let cache = api.caches.project_info.clone(); - if let Some(client) = client1 { - maintenance_tasks.spawn(notifications::task_main( - client, - cache.clone(), - args.region.clone(), - )); + if let Some(client) = redis_client { + // project info cache and invalidation of that cache. + let cache = api.caches.project_info.clone(); + maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone())); + maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); + + // Try to connect to Redis 3 times with 1 + (0..0.1) second interval. + // This prevents immediate exit and pod restart, + // which can cause hammering of the redis in case of connection issues. + // cancellation key management + let mut redis_kv_client = RedisKVClient::new(client.clone()); + for attempt in (0..3).with_position() { + match redis_kv_client.try_connect().await { + Ok(()) => { + info!("Connected to Redis KV client"); + cancellation_handler.init_tx(BatchQueue::new(CancellationProcessor { + client: redis_kv_client, + batch_size: args.cancellation_batch_size, + })); + + break; + } + Err(e) => { + error!("Failed to connect to Redis KV client: {e}"); + if matches!(attempt, Position::Last(_)) { + bail!( + "Failed to connect to Redis KV client after {} attempts", + attempt.into_inner() + ); + } + let jitter = thread_rng().gen_range(0..100); + tokio::time::sleep(Duration::from_millis(1000 + jitter)).await; + } } - if let Some(client) = client2 { - maintenance_tasks.spawn(notifications::task_main( - client, - cache.clone(), - args.region.clone(), - )); - } - maintenance_tasks.spawn(async move { cache.clone().gc_worker().await }); } - } - if let Some(mut redis_kv_client) = redis_kv_client { - maintenance_tasks.spawn(async move { - redis_kv_client.try_connect().await?; - handle_cancel_messages( - &mut redis_kv_client, - rx_cancel, - args.cancellation_batch_size, - ) - .await?; - - drop(redis_kv_client); - - // `handle_cancel_messages` was terminated due to the tx_cancel - // being dropped. this is not worthy of an error, and this task can only return `Err`, - // so let's wait forever instead. 
- std::future::pending().await - }); - } - - if let Some(regional_redis_client) = regional_redis_client { + // listen for notifications of new projects/endpoints/branches let cache = api.caches.endpoints_cache.clone(); - let con = regional_redis_client; let span = tracing::info_span!("endpoints_cache"); maintenance_tasks.spawn( - async move { cache.do_read(con, cancellation_token.clone()).await } + async move { cache.do_read(client, cancellation_token.clone()).await } .instrument(span), ); } @@ -673,7 +650,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { authentication_config, proxy_protocol_v2: args.proxy_protocol_v2, handshake_timeout: args.handshake_timeout, - region: args.region.clone(), wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?, connect_compute_locks, connect_to_compute: compute_config, @@ -835,21 +811,18 @@ fn build_auth_backend( async fn configure_redis( args: &ProxyCliArgs, -) -> anyhow::Result<( - Option, - Option, -)> { +) -> anyhow::Result> { // TODO: untangle the config args - let regional_redis_client = match (args.redis_auth_type.as_str(), &args.redis_notifications) { - ("plain", redis_url) => match redis_url { + let redis_client = match &*args.redis_auth_type { + "plain" => match &args.redis_plain { None => { - bail!("plain auth requires redis_notifications to be set"); + bail!("plain auth requires redis_plain to be set"); } Some(url) => { Some(ConnectionWithCredentialsProvider::new_with_static_credentials(url.clone())) } }, - ("irsa", _) => match (&args.redis_host, args.redis_port) { + "irsa" => match (&args.redis_host, args.redis_port) { (Some(host), Some(port)) => Some( ConnectionWithCredentialsProvider::new_with_credentials_provider( host.clone(), @@ -873,18 +846,12 @@ async fn configure_redis( bail!("redis-host and redis-port must be specified together"); } }, - _ => { - bail!("unknown auth type given"); + auth_type => { + bail!("unknown auth type {auth_type:?} given") } }; - let redis_notifications_client = if let Some(url) = &args.redis_notifications { - Some(ConnectionWithCredentialsProvider::new_with_static_credentials(&**url)) - } else { - regional_redis_client.clone() - }; - - Ok((regional_redis_client, redis_notifications_client)) + Ok(redis_client) } #[cfg(test)] diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index 81c88e3ddd..d37c107323 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -18,6 +18,7 @@ use crate::types::{EndpointId, RoleName}; #[async_trait] pub(crate) trait ProjectInfoCache { + fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt); fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt); fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt); fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt); @@ -100,6 +101,13 @@ pub struct ProjectInfoCacheImpl { #[async_trait] impl ProjectInfoCache for ProjectInfoCacheImpl { + fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) { + info!("invalidating endpoint access for `{endpoint_id}`"); + if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { + endpoint_info.invalidate_endpoint(); + } + } + fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) { info!("invalidating endpoint access for project `{project_id}`"); let endpoints = self @@ -356,6 +364,7 @@ mod tests { use std::sync::Arc; use super::*; + use 
crate::control_plane::messages::EndpointRateLimitConfig; use crate::control_plane::{AccessBlockerFlags, AuthSecret}; use crate::scram::ServerSecret; use crate::types::ProjectId; @@ -391,6 +400,7 @@ mod tests { allowed_ips: allowed_ips.clone(), allowed_vpce: Arc::new(vec![]), flags: AccessBlockerFlags::default(), + rate_limits: EndpointRateLimitConfig::default(), }, RoleAccessControl { secret: secret1.clone(), @@ -406,6 +416,7 @@ mod tests { allowed_ips: allowed_ips.clone(), allowed_vpce: Arc::new(vec![]), flags: AccessBlockerFlags::default(), + rate_limits: EndpointRateLimitConfig::default(), }, RoleAccessControl { secret: secret2.clone(), @@ -431,6 +442,7 @@ mod tests { allowed_ips: allowed_ips.clone(), allowed_vpce: Arc::new(vec![]), flags: AccessBlockerFlags::default(), + rate_limits: EndpointRateLimitConfig::default(), }, RoleAccessControl { secret: secret3.clone(), diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index d26641db46..ffc0cf43f1 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -1,19 +1,24 @@ +use std::convert::Infallible; use std::net::{IpAddr, SocketAddr}; -use std::sync::Arc; +use std::pin::pin; +use std::sync::{Arc, OnceLock}; +use std::time::Duration; -use anyhow::{Context, anyhow}; +use anyhow::anyhow; +use futures::FutureExt; use ipnet::{IpNet, Ipv4Net, Ipv6Net}; -use postgres_client::CancelToken; +use postgres_client::RawCancelToken; use postgres_client::tls::MakeTlsConnect; use redis::{Cmd, FromRedisValue, Value}; use serde::{Deserialize, Serialize}; use thiserror::Error; use tokio::net::TcpStream; -use tokio::sync::{mpsc, oneshot}; -use tracing::{debug, error, info, warn}; +use tokio::time::timeout; +use tracing::{debug, error, info}; use crate::auth::AuthError; use crate::auth::backend::ComputeUserInfo; +use crate::batch::{BatchQueue, QueueProcessing}; use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::ControlPlaneApi; @@ -24,50 +29,39 @@ use crate::pqproto::CancelKeyData; use crate::rate_limiter::LeakyBucketRateLimiter; use crate::redis::keys::KeyPrefix; use crate::redis::kv_ops::RedisKVClient; -use crate::tls::postgres_rustls::MakeRustlsConnect; type IpSubnetKey = IpNet; -const CANCEL_KEY_TTL: i64 = 1_209_600; // 2 weeks cancellation key expire time +const CANCEL_KEY_TTL: std::time::Duration = std::time::Duration::from_secs(600); +const CANCEL_KEY_REFRESH: std::time::Duration = std::time::Duration::from_secs(570); // Message types for sending through mpsc channel pub enum CancelKeyOp { StoreCancelKey { - key: String, - field: String, - value: String, - resp_tx: Option>>, - _guard: CancelChannelSizeGuard<'static>, - expire: i64, // TTL for key + key: CancelKeyData, + value: Box, + expire: std::time::Duration, }, GetCancelData { - key: String, - resp_tx: oneshot::Sender>>, - _guard: CancelChannelSizeGuard<'static>, - }, - RemoveCancelKey { - key: String, - field: String, - resp_tx: Option>>, - _guard: CancelChannelSizeGuard<'static>, + key: CancelKeyData, }, } pub struct Pipeline { inner: redis::Pipeline, - replies: Vec, + replies: usize, } impl Pipeline { fn with_capacity(n: usize) -> Self { Self { inner: redis::Pipeline::with_capacity(n), - replies: Vec::with_capacity(n), + replies: 0, } } - async fn execute(&mut self, client: &mut RedisKVClient) { - let responses = self.replies.len(); + async fn execute(self, client: &mut RedisKVClient) -> Vec> { + let responses = self.replies; let batch_size = self.inner.len(); match client.query(&self.inner).await { @@ -77,176 +71,72 @@ 
                     batch_size,
                     responses,
                     "successfully completed cancellation jobs",
                 );
-                for (value, reply) in std::iter::zip(values, self.replies.drain(..)) {
-                    reply.send_value(value);
-                }
+                values.into_iter().map(Ok).collect()
             }
             Ok(value) => {
                 error!(batch_size, ?value, "unexpected redis return value");
-                for reply in self.replies.drain(..) {
-                    reply.send_err(anyhow!("incorrect response type from redis"));
-                }
+                std::iter::repeat_with(|| Err(anyhow!("incorrect response type from redis")))
+                    .take(responses)
+                    .collect()
             }
             Err(err) => {
-                for reply in self.replies.drain(..) {
-                    reply.send_err(anyhow!("could not send cmd to redis: {err}"));
-                }
+                std::iter::repeat_with(|| Err(anyhow!("could not send cmd to redis: {err}")))
+                    .take(responses)
+                    .collect()
             }
         }
-
-        self.inner.clear();
-        self.replies.clear();
     }

-    fn add_command_with_reply(&mut self, cmd: Cmd, reply: CancelReplyOp) {
+    fn add_command_with_reply(&mut self, cmd: Cmd) {
         self.inner.add_command(cmd);
-        self.replies.push(reply);
+        self.replies += 1;
     }

     fn add_command_no_reply(&mut self, cmd: Cmd) {
         self.inner.add_command(cmd).ignore();
     }
-
-    fn add_command(&mut self, cmd: Cmd, reply: Option<CancelReplyOp>) {
-        match reply {
-            Some(reply) => self.add_command_with_reply(cmd, reply),
-            None => self.add_command_no_reply(cmd),
-        }
-    }
 }

 impl CancelKeyOp {
-    fn register(self, pipe: &mut Pipeline) {
-        #[allow(clippy::used_underscore_binding)]
+    fn register(&self, pipe: &mut Pipeline) {
         match self {
-            CancelKeyOp::StoreCancelKey {
-                key,
-                field,
-                value,
-                resp_tx,
-                _guard,
-                expire,
-            } => {
-                let reply =
-                    resp_tx.map(|resp_tx| CancelReplyOp::StoreCancelKey { resp_tx, _guard });
-                pipe.add_command(Cmd::hset(&key, field, value), reply);
-                pipe.add_command_no_reply(Cmd::expire(key, expire));
+            CancelKeyOp::StoreCancelKey { key, value, expire } => {
+                let key = KeyPrefix::Cancel(*key).build_redis_key();
+                pipe.add_command_with_reply(Cmd::hset(&key, "data", &**value));
+                pipe.add_command_no_reply(Cmd::expire(&key, expire.as_secs() as i64));
             }
-            CancelKeyOp::GetCancelData {
-                key,
-                resp_tx,
-                _guard,
-            } => {
-                let reply = CancelReplyOp::GetCancelData { resp_tx, _guard };
-                pipe.add_command_with_reply(Cmd::hgetall(key), reply);
-            }
-            CancelKeyOp::RemoveCancelKey {
-                key,
-                field,
-                resp_tx,
-                _guard,
-            } => {
-                let reply =
-                    resp_tx.map(|resp_tx| CancelReplyOp::RemoveCancelKey { resp_tx, _guard });
-                pipe.add_command(Cmd::hdel(key, field), reply);
+            CancelKeyOp::GetCancelData { key } => {
+                let key = KeyPrefix::Cancel(*key).build_redis_key();
+                pipe.add_command_with_reply(Cmd::hget(key, "data"));
             }
         }
     }
 }

-// Message types for sending through mpsc channel
-pub enum CancelReplyOp {
-    StoreCancelKey {
-        resp_tx: oneshot::Sender<anyhow::Result<()>>,
-        _guard: CancelChannelSizeGuard<'static>,
-    },
-    GetCancelData {
-        resp_tx: oneshot::Sender<anyhow::Result<Vec<(String, String)>>>,
-        _guard: CancelChannelSizeGuard<'static>,
-    },
-    RemoveCancelKey {
-        resp_tx: oneshot::Sender<anyhow::Result<()>>,
-        _guard: CancelChannelSizeGuard<'static>,
-    },
+pub struct CancellationProcessor {
+    pub client: RedisKVClient,
+    pub batch_size: usize,
 }

-impl CancelReplyOp {
-    fn send_err(self, e: anyhow::Error) {
-        match self {
-            CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
-                resp_tx
-                    .send(Err(e))
-                    .inspect_err(|_| tracing::debug!("could not send reply"))
-                    .ok();
-            }
-            CancelReplyOp::GetCancelData { resp_tx, _guard } => {
-                resp_tx
-                    .send(Err(e))
-                    .inspect_err(|_| tracing::debug!("could not send reply"))
-                    .ok();
-            }
-            CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
-                resp_tx
-                    .send(Err(e))
-                    .inspect_err(|_| tracing::debug!("could not send reply"))
-                    .ok();
-            }
-        }
+impl QueueProcessing for CancellationProcessor {
+    type Req = (CancelChannelSizeGuard<'static>, CancelKeyOp);
+    type Res = anyhow::Result<Value>;
+
+    fn batch_size(&self, _queue_size: usize) -> usize {
+        self.batch_size
     }

-    fn send_value(self, v: redis::Value) {
-        match self {
-            CancelReplyOp::StoreCancelKey { resp_tx, _guard } => {
-                let send =
-                    FromRedisValue::from_owned_redis_value(v).context("could not parse value");
-                resp_tx
-                    .send(send)
-                    .inspect_err(|_| tracing::debug!("could not send reply"))
-                    .ok();
-            }
-            CancelReplyOp::GetCancelData { resp_tx, _guard } => {
-                let send =
-                    FromRedisValue::from_owned_redis_value(v).context("could not parse value");
-                resp_tx
-                    .send(send)
-                    .inspect_err(|_| tracing::debug!("could not send reply"))
-                    .ok();
-            }
-            CancelReplyOp::RemoveCancelKey { resp_tx, _guard } => {
-                let send =
-                    FromRedisValue::from_owned_redis_value(v).context("could not parse value");
-                resp_tx
-                    .send(send)
-                    .inspect_err(|_| tracing::debug!("could not send reply"))
-                    .ok();
-            }
-        }
-    }
-}
-
-// Running as a separate task to accept messages through the rx channel
-pub async fn handle_cancel_messages(
-    client: &mut RedisKVClient,
-    mut rx: mpsc::Receiver<CancelKeyOp>,
-    batch_size: usize,
-) -> anyhow::Result<()> {
-    let mut batch = Vec::with_capacity(batch_size);
-    let mut pipeline = Pipeline::with_capacity(batch_size);
-
-    loop {
-        if rx.recv_many(&mut batch, batch_size).await == 0 {
-            warn!("shutting down cancellation queue");
-            break Ok(());
-        }
+    async fn apply(&mut self, batch: Vec<Self::Req>) -> Vec<Self::Res> {
+        let mut pipeline = Pipeline::with_capacity(batch.len());

         let batch_size = batch.len();
         debug!(batch_size, "running cancellation jobs");

-        for msg in batch.drain(..) {
-            msg.register(&mut pipeline);
+        for (_, op) in &batch {
+            op.register(&mut pipeline);
         }

-        pipeline.execute(client).await;
+        pipeline.execute(&mut self.client).await
     }
 }

@@ -257,7 +147,7 @@
 pub struct CancellationHandler {
     compute_config: &'static ComputeConfig,
     // rate limiter of cancellation requests
     limiter: Arc<std::sync::Mutex<LeakyBucketRateLimiter<IpSubnetKey>>>,
-    tx: Option<mpsc::Sender<CancelKeyOp>>, // send messages to the redis KV client task
+    tx: OnceLock<BatchQueue<CancellationProcessor>>, // send messages to the redis KV client task
 }

 #[derive(Debug, Error)]
@@ -297,13 +187,10 @@ impl ReportableError for CancelError {
 }

 impl CancellationHandler {
-    pub fn new(
-        compute_config: &'static ComputeConfig,
-        tx: Option<mpsc::Sender<CancelKeyOp>>,
-    ) -> Self {
+    pub fn new(compute_config: &'static ComputeConfig) -> Self {
         Self {
             compute_config,
-            tx,
+            tx: OnceLock::new(),
             limiter: Arc::new(std::sync::Mutex::new(
                 LeakyBucketRateLimiter::<IpSubnetKey>::new_with_shards(
                     LeakyBucketRateLimiter::<IpSubnetKey>::DEFAULT,
@@ -313,7 +200,14 @@ impl CancellationHandler {
         }
     }

-    pub(crate) fn get_key(self: &Arc<Self>) -> Session {
+    pub fn init_tx(&self, queue: BatchQueue<CancellationProcessor>) {
+        self.tx
+            .set(queue)
+            .map_err(|_| {})
+            .expect("cancellation queue should be registered once");
+    }
+
+    pub(crate) fn get_key(self: Arc<Self>) -> Session {
         // we intentionally generate a random "backend pid" and "secret key" here.
         // we use the corresponding u64 as an identifier for the
         // actual endpoint+pid+secret for postgres/pgbouncer.
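Context for reviewers: `BatchQueue` and `QueueProcessing` come from the new `proxy/src/batch.rs` module, which is not included in this diff. Judging from how it is used here and below (`tx.call((guard, op), cancel)` returns a `Result` whose error is the cancel future's output, and `apply` must answer every request), it multiplexes many callers onto a single worker task that drains requests in batches and hands each caller back its response by position. A minimal sketch of that shape, assuming tokio channels; the internals here are guesses for illustration, not the actual module:

    // Hypothetical sketch of proxy/src/batch.rs (not part of this diff).
    use std::future::Future;

    use tokio::sync::{mpsc, oneshot};

    pub trait QueueProcessing: Send + 'static {
        type Req: Send + 'static;
        type Res: Send + 'static;

        /// How many queued requests to drain into one batch.
        fn batch_size(&self, queue_size: usize) -> usize;

        /// Process a batch; must return exactly one response per request,
        /// in request order.
        fn apply(&mut self, batch: Vec<Self::Req>) -> impl Future<Output = Vec<Self::Res>> + Send;
    }

    pub struct BatchQueue<P: QueueProcessing> {
        tx: mpsc::UnboundedSender<(P::Req, oneshot::Sender<P::Res>)>,
    }

    impl<P: QueueProcessing> BatchQueue<P> {
        /// Spawn a worker task that drains the queue in batches.
        pub fn new(mut processor: P) -> Self {
            let (tx, mut rx) = mpsc::unbounded_channel::<(P::Req, oneshot::Sender<P::Res>)>();
            tokio::spawn(async move {
                let mut buf = Vec::new();
                loop {
                    let limit = processor.batch_size(rx.len()).max(1);
                    if rx.recv_many(&mut buf, limit).await == 0 {
                        break; // all queue handles dropped: shut down
                    }
                    let (reqs, txs): (Vec<_>, Vec<_>) = buf.drain(..).unzip();
                    // Responses are matched to requests by position.
                    for (res, tx) in processor.apply(reqs).await.into_iter().zip(txs) {
                        let _ = tx.send(res); // caller may have given up; ignore
                    }
                }
            });
            Self { tx }
        }

        /// Submit one request and wait for its response, unless `cancel`
        /// completes first.
        pub async fn call<C: Future>(&self, req: P::Req, cancel: C) -> Result<P::Res, C::Output> {
            let (res_tx, res_rx) = oneshot::channel();
            self.tx.send((req, res_tx)).expect("batch worker exited");
            tokio::select! {
                res = res_rx => Ok(res.expect("batch worker dropped the reply")),
                out = cancel => Err(out),
            }
        }
    }

The invariant `Pipeline::execute` upholds to make this work is that it returns exactly one result per expected reply, in submission order, even on error.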
@@ -323,83 +217,68 @@ impl CancellationHandler {
         let key: CancelKeyData = rand::random();

-        let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
-        let redis_key = prefix_key.build_redis_key();
-
         debug!("registered new query cancellation key {key}");
         Session {
             key,
-            redis_key,
-            cancellation_handler: Arc::clone(self),
+            cancellation_handler: self,
         }
     }

+    /// This is not cancel safe
     async fn get_cancel_key(
         &self,
         key: CancelKeyData,
     ) -> Result<Option<CancelClosure>, CancelError> {
-        let prefix_key: KeyPrefix = KeyPrefix::Cancel(key);
-        let redis_key = prefix_key.build_redis_key();
+        let guard = Metrics::get()
+            .proxy
+            .cancel_channel_size
+            .guard(RedisMsgKind::HGet);
+        let op = CancelKeyOp::GetCancelData { key };

-        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
-        let op = CancelKeyOp::GetCancelData {
-            key: redis_key,
-            resp_tx,
-            _guard: Metrics::get()
-                .proxy
-                .cancel_channel_size
-                .guard(RedisMsgKind::HGetAll),
-        };
-
-        let Some(tx) = &self.tx else {
+        let Some(tx) = self.tx.get() else {
             tracing::warn!("cancellation handler is not available");
             return Err(CancelError::InternalError);
         };

-        tx.try_send(op)
-            .map_err(|e| {
-                tracing::warn!("failed to send GetCancelData for {key}: {e}");
-            })
-            .map_err(|()| CancelError::InternalError)?;
-
-        let result = resp_rx.await.map_err(|e| {
+        const TIMEOUT: Duration = Duration::from_secs(5);
+        let result = timeout(
+            TIMEOUT,
+            tx.call((guard, op), std::future::pending::<Infallible>()),
+        )
+        .await
+        .map_err(|_| {
+            tracing::warn!("timed out waiting to receive GetCancelData response");
+            CancelError::RateLimit
+        })?
+        // cannot be cancelled
+        .unwrap_or_else(|x| match x {})
+        .map_err(|e| {
             tracing::warn!("failed to receive GetCancelData response: {e}");
             CancelError::InternalError
         })?;

-        let cancel_state_str: Option<String> = match result {
-            Ok(mut state) => {
-                if state.len() == 1 {
-                    Some(state.remove(0).1)
-                } else {
-                    tracing::warn!("unexpected number of entries in cancel state: {state:?}");
-                    return Err(CancelError::InternalError);
-                }
-            }
-            Err(e) => {
-                tracing::warn!("failed to receive cancel state from redis: {e}");
-                return Err(CancelError::InternalError);
-            }
-        };
+        let cancel_state_str = String::from_owned_redis_value(result).map_err(|e| {
+            tracing::warn!("failed to receive GetCancelData response: {e}");
+            CancelError::InternalError
+        })?;

-        let cancel_state: Option<CancelClosure> = match cancel_state_str {
-            Some(state) => {
-                let cancel_closure: CancelClosure = serde_json::from_str(&state).map_err(|e| {
-                    tracing::warn!("failed to deserialize cancel state: {e}");
-                    CancelError::InternalError
-                })?;
-                Some(cancel_closure)
-            }
-            None => None,
-        };
-        Ok(cancel_state)
+        let cancel_closure: CancelClosure =
+            serde_json::from_str(&cancel_state_str).map_err(|e| {
+                tracing::warn!("failed to deserialize cancel state: {e}");
+                CancelError::InternalError
+            })?;
+
+        Ok(Some(cancel_closure))
     }

+    /// Try to cancel a running query for the corresponding connection.
     /// If the cancellation key is not found, it will be published to Redis.
     /// check_allowed - if true, check if the IP is allowed to cancel the query.
     /// Will fetch IP allowlist internally.
     ///
     /// return Result primarily for tests
+    ///
+    /// This is not cancel safe
     pub(crate) async fn cancel_session(
         &self,
         key: CancelKeyData,
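A note on the `std::future::pending::<Infallible>()` argument used in `get_cancel_key` above: since that cancel future can never complete, the cancellation error of `call` is `Infallible`, and `.unwrap_or_else(|x| match x {})` removes the branch at compile time. A standalone illustration of the empty-match idiom:

    use std::convert::Infallible;

    // `unwrap_or_else(|x| match x {})` statically eliminates the error arm:
    // no value of type `Infallible` can exist, so the empty match type-checks
    // as unreachable and the closure can never run.
    fn collapse<T>(res: Result<T, Infallible>) -> T {
        res.unwrap_or_else(|x| match x {})
    }

    fn main() {
        let always_ok: Result<u32, Infallible> = Ok(42);
        assert_eq!(collapse(always_ok), 42);
    }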
@@ -468,10 +347,10 @@ impl CancellationHandler {
 /// This should've been a [`std::future::Future`], but
 /// it's impossible to name a type of an unboxed future
 /// (we'd need something like `#![feature(type_alias_impl_trait)]`).
-#[derive(Clone, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct CancelClosure {
     socket_addr: SocketAddr,
-    cancel_token: CancelToken,
+    cancel_token: RawCancelToken,
     hostname: String, // for pg_sni router
     user_info: ComputeUserInfo,
 }
@@ -479,7 +358,7 @@ pub struct CancelClosure {
 impl CancelClosure {
     pub(crate) fn new(
         socket_addr: SocketAddr,
-        cancel_token: CancelToken,
+        cancel_token: RawCancelToken,
         hostname: String,
         user_info: ComputeUserInfo,
     ) -> Self {
@@ -492,15 +371,13 @@ impl CancelClosure {
     }
     /// Cancels the query running on user's compute node.
     pub(crate) async fn try_cancel_query(
-        self,
+        &self,
         compute_config: &ComputeConfig,
     ) -> Result<(), CancelError> {
         let socket = TcpStream::connect(self.socket_addr).await?;
-        let mut mk_tls =
-            crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone());
-        let tls = <MakeRustlsConnect as MakeTlsConnect<TcpStream>>::make_tls_connect(
-            &mut mk_tls,
+        let tls = <_ as MakeTlsConnect<TcpStream>>::make_tls_connect(
+            compute_config,
             &self.hostname,
         )
         .map_err(|e| CancelError::IO(std::io::Error::other(e.to_string())))?;
@@ -515,7 +392,6 @@ impl CancelClosure {
 pub(crate) struct Session {
     /// The user-facing key identifying this session.
     key: CancelKeyData,
-    redis_key: String,
     cancellation_handler: Arc<CancellationHandler>,
 }

@@ -524,60 +400,75 @@ impl Session {
         &self.key
     }

-    // Send the store key op to the cancellation handler and set TTL for the key
-    pub(crate) fn write_cancel_key(
+    /// Ensure the cancel key is continuously refreshed,
+    /// but stop when the channel is dropped.
+    ///
+    /// This is not cancel safe
+    pub(crate) async fn maintain_cancel_key(
         &self,
-        cancel_closure: CancelClosure,
-    ) -> Result<(), CancelError> {
-        let Some(tx) = &self.cancellation_handler.tx else {
+        session_id: uuid::Uuid,
+        cancel: tokio::sync::oneshot::Receiver<Infallible>,
+        cancel_closure: &CancelClosure,
+        compute_config: &ComputeConfig,
+    ) {
+        let Some(tx) = self.cancellation_handler.tx.get() else {
             tracing::warn!("cancellation handler is not available");
-            return Err(CancelError::InternalError);
+            // don't exit, as we only want to exit if cancelled externally.
+ std::future::pending().await }; - let closure_json = serde_json::to_string(&cancel_closure).map_err(|e| { - tracing::warn!("failed to serialize cancel closure: {e}"); - CancelError::InternalError - })?; + let closure_json = serde_json::to_string(&cancel_closure) + .expect("serialising to json string should not fail") + .into_boxed_str(); - let op = CancelKeyOp::StoreCancelKey { - key: self.redis_key.clone(), - field: "data".to_string(), - value: closure_json, - resp_tx: None, - _guard: Metrics::get() + let mut cancel = pin!(cancel); + + loop { + let guard = Metrics::get() .proxy .cancel_channel_size - .guard(RedisMsgKind::HSet), - expire: CANCEL_KEY_TTL, - }; + .guard(RedisMsgKind::HSet); + let op = CancelKeyOp::StoreCancelKey { + key: self.key, + value: closure_json.clone(), + expire: CANCEL_KEY_TTL, + }; - let _ = tx.try_send(op).map_err(|e| { - let key = self.key; - tracing::warn!("failed to send StoreCancelKey for {key}: {e}"); - }); - Ok(()) - } + tracing::debug!( + src=%self.key, + dest=?cancel_closure.cancel_token, + "registering cancellation key" + ); - pub(crate) fn remove_cancel_key(&self) -> Result<(), CancelError> { - let Some(tx) = &self.cancellation_handler.tx else { - tracing::warn!("cancellation handler is not available"); - return Err(CancelError::InternalError); - }; + match tx.call((guard, op), cancel.as_mut()).await { + Ok(Ok(_)) => { + tracing::debug!( + src=%self.key, + dest=?cancel_closure.cancel_token, + "registered cancellation key" + ); - let op = CancelKeyOp::RemoveCancelKey { - key: self.redis_key.clone(), - field: "data".to_string(), - resp_tx: None, - _guard: Metrics::get() - .proxy - .cancel_channel_size - .guard(RedisMsgKind::HDel), - }; + // wait before continuing. + tokio::time::sleep(CANCEL_KEY_REFRESH).await; + } + // retry immediately. 
+ Ok(Err(error)) => { + tracing::warn!(?error, "error registering cancellation key"); + } + Err(Err(_cancelled)) => break, + } + } - let _ = tx.try_send(op).map_err(|e| { - let key = self.key; - tracing::warn!("failed to send RemoveCancelKey for {key}: {e}"); - }); - Ok(()) + if let Err(err) = cancel_closure + .try_cancel_query(compute_config) + .boxed() + .await + { + tracing::warn!( + ?session_id, + ?err, + "could not cancel the query in the database" + ); + } } } diff --git a/proxy/src/compute.rs b/proxy/src/compute/mod.rs similarity index 55% rename from proxy/src/compute.rs rename to proxy/src/compute/mod.rs index 2899f25129..7fb88e6a45 100644 --- a/proxy/src/compute.rs +++ b/proxy/src/compute/mod.rs @@ -1,21 +1,24 @@ +mod tls; + use std::fmt::Debug; use std::io; -use std::net::SocketAddr; -use std::time::Duration; +use std::net::{IpAddr, SocketAddr}; use futures::{FutureExt, TryFutureExt}; use itertools::Itertools; +use postgres_client::config::{AuthKeys, ChannelBinding, SslMode}; +use postgres_client::maybe_tls_stream::MaybeTlsStream; use postgres_client::tls::MakeTlsConnect; -use postgres_client::{CancelToken, RawConnection}; +use postgres_client::{NoTls, RawCancelToken, RawConnection}; use postgres_protocol::message::backend::NoticeResponseBody; -use rustls::pki_types::InvalidDnsNameError; use thiserror::Error; use tokio::net::{TcpStream, lookup_host}; use tracing::{debug, error, info, warn}; -use crate::auth::backend::ComputeUserInfo; +use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; use crate::auth::parse_endpoint_param; use crate::cancellation::CancelClosure; +use crate::compute::tls::TlsError; use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::client::ApiLockError; @@ -25,23 +28,58 @@ use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumDbConnectionsGuard}; use crate::pqproto::StartupMessageParams; use crate::proxy::neon_option; -use crate::tls::postgres_rustls::MakeRustlsConnect; use crate::types::Host; pub const COULD_NOT_CONNECT: &str = "Couldn't connect to compute node"; #[derive(Debug, Error)] -pub(crate) enum ConnectionError { +pub(crate) enum PostgresError { /// This error doesn't seem to reveal any secrets; for instance, /// `postgres_client::error::Kind` doesn't contain ip addresses and such. #[error("{COULD_NOT_CONNECT}: {0}")] Postgres(#[from] postgres_client::Error), +} - #[error("{COULD_NOT_CONNECT}: {0}")] - CouldNotConnect(#[from] io::Error), +impl UserFacingError for PostgresError { + fn to_string_client(&self) -> String { + match self { + // This helps us drop irrelevant library-specific prefixes. + // TODO: propagate severity level and other parameters. + PostgresError::Postgres(err) => match err.as_db_error() { + Some(err) => { + let msg = err.message(); + if msg.starts_with("unsupported startup parameter: ") + || msg.starts_with("unsupported startup parameter in options: ") + { + format!( + "{msg}. Please use unpooled connection or remove this parameter from the startup package. 
More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter" + ) + } else { + msg.to_owned() + } + } + None => err.to_string(), + }, + } + } +} + +impl ReportableError for PostgresError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + match self { + PostgresError::Postgres(e) if e.as_db_error().is_some() => { + crate::error::ErrorKind::Postgres + } + PostgresError::Postgres(_) => crate::error::ErrorKind::Compute, + } + } +} + +#[derive(Debug, Error)] +pub(crate) enum ConnectionError { #[error("{COULD_NOT_CONNECT}: {0}")] - TlsError(#[from] InvalidDnsNameError), + TlsError(#[from] TlsError), #[error("{COULD_NOT_CONNECT}: {0}")] WakeComputeError(#[from] WakeComputeError), @@ -53,27 +91,11 @@ pub(crate) enum ConnectionError { impl UserFacingError for ConnectionError { fn to_string_client(&self) -> String { match self { - // This helps us drop irrelevant library-specific prefixes. - // TODO: propagate severity level and other parameters. - ConnectionError::Postgres(err) => match err.as_db_error() { - Some(err) => { - let msg = err.message(); - - if msg.starts_with("unsupported startup parameter: ") - || msg.starts_with("unsupported startup parameter in options: ") - { - format!("{msg}. Please use unpooled connection or remove this parameter from the startup package. More details: https://neon.tech/docs/connect/connection-errors#unsupported-startup-parameter") - } else { - msg.to_owned() - } - } - None => err.to_string(), - }, ConnectionError::WakeComputeError(err) => err.to_string_client(), ConnectionError::TooManyConnectionAttempts(_) => { "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned() } - _ => COULD_NOT_CONNECT.to_owned(), + ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(), } } } @@ -81,11 +103,6 @@ impl UserFacingError for ConnectionError { impl ReportableError for ConnectionError { fn get_error_kind(&self) -> crate::error::ErrorKind { match self { - ConnectionError::Postgres(e) if e.as_db_error().is_some() => { - crate::error::ErrorKind::Postgres - } - ConnectionError::Postgres(_) => crate::error::ErrorKind::Compute, - ConnectionError::CouldNotConnect(_) => crate::error::ErrorKind::Compute, ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute, ConnectionError::WakeComputeError(e) => e.get_error_kind(), ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(), @@ -96,34 +113,91 @@ impl ReportableError for ConnectionError { /// A pair of `ClientKey` & `ServerKey` for `SCRAM-SHA-256`. pub(crate) type ScramKeys = postgres_client::config::ScramKeys<32>; -/// A config for establishing a connection to compute node. -/// Eventually, `postgres_client` will be replaced with something better. -/// Newtype allows us to implement methods on top of it. #[derive(Clone)] -pub(crate) struct ConnCfg(Box); +pub enum Auth { + /// Only used during console-redirect. + Password(Vec), + /// Used by sql-over-http, ws, tcp. + Scram(Box), +} + +/// A config for authenticating to the compute node. +pub(crate) struct AuthInfo { + /// None for local-proxy, as we use trust-based localhost auth. + /// Some for sql-over-http, ws, tcp, and in most cases for console-redirect. + /// Might be None for console-redirect, but that's only a consequence of testing environments ATM. + auth: Option, + server_params: StartupMessageParams, + + channel_binding: ChannelBinding, + + /// Console redirect sets user and database, we shouldn't re-use those from the params. 
+ skip_db_user: bool, +} + +/// Contains only the data needed to establish a secure connection to compute. +#[derive(Clone)] +pub struct ConnectInfo { + pub host_addr: Option, + pub host: Host, + pub port: u16, + pub ssl_mode: SslMode, +} /// Creation and initialization routines. -impl ConnCfg { - pub(crate) fn new(host: String, port: u16) -> Self { - Self(Box::new(postgres_client::Config::new(host, port))) - } - - /// Reuse password or auth keys from the other config. - pub(crate) fn reuse_password(&mut self, other: Self) { - if let Some(password) = other.get_password() { - self.password(password); - } - - if let Some(keys) = other.get_auth_keys() { - self.auth_keys(keys); +impl AuthInfo { + pub(crate) fn for_console_redirect(db: &str, user: &str, pw: Option<&str>) -> Self { + let mut server_params = StartupMessageParams::default(); + server_params.insert("database", db); + server_params.insert("user", user); + Self { + auth: pw.map(|pw| Auth::Password(pw.as_bytes().to_owned())), + server_params, + skip_db_user: true, + // pg-sni-router is a mitm so this would fail. + channel_binding: ChannelBinding::Disable, } } - pub(crate) fn get_host(&self) -> Host { - match self.0.get_host() { - postgres_client::config::Host::Tcp(s) => s.into(), + pub(crate) fn with_auth_keys(keys: ComputeCredentialKeys) -> Self { + Self { + auth: match keys { + ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => { + Some(Auth::Scram(Box::new(auth_keys))) + } + ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None, + }, + server_params: StartupMessageParams::default(), + skip_db_user: false, + channel_binding: ChannelBinding::Prefer, } } +} + +impl ConnectInfo { + pub fn to_postgres_client_config(&self) -> postgres_client::Config { + let mut config = postgres_client::Config::new(self.host.to_string(), self.port); + config.ssl_mode(self.ssl_mode); + if let Some(host_addr) = self.host_addr { + config.set_host_addr(host_addr); + } + config + } +} + +impl AuthInfo { + fn enrich(&self, mut config: postgres_client::Config) -> postgres_client::Config { + match &self.auth { + Some(Auth::Scram(keys)) => config.auth_keys(AuthKeys::ScramSha256(**keys)), + Some(Auth::Password(pw)) => config.password(pw), + None => &mut config, + }; + config.channel_binding(self.channel_binding); + for (k, v) in self.server_params.iter() { + config.set_param(k, v); + } + config + } /// Apply startup message params to the connection config. pub(crate) fn set_startup_params( @@ -132,53 +206,90 @@ impl ConnCfg { arbitrary_params: bool, ) { if !arbitrary_params { - self.set_param("client_encoding", "UTF8"); + self.server_params.insert("client_encoding", "UTF8"); } for (k, v) in params.iter() { match k { // Only set `user` if it's not present in the config. // Console redirect auth flow takes username from the console's response. - "user" if self.user_is_set() => {} - "database" if self.db_is_set() => {} + "user" | "database" if self.skip_db_user => {} "options" => { if let Some(options) = filtered_options(v) { - self.set_param(k, &options); + self.server_params.insert(k, &options); } } "user" | "database" | "application_name" | "replication" => { - self.set_param(k, v); + self.server_params.insert(k, v); } // if we allow arbitrary params, then we forward them through. 
// this is a flag for a period of backwards compatibility k if arbitrary_params => { - self.set_param(k, v); + self.server_params.insert(k, v); } _ => {} } } } -} -impl std::ops::Deref for ConnCfg { - type Target = postgres_client::Config; + pub async fn authenticate( + &self, + ctx: &RequestContext, + compute: &mut ComputeConnection, + user_info: ComputeUserInfo, + ) -> Result { + // client config with stubbed connect info. + // TODO(conrad): should we rewrite this to bypass tokio-postgres2 entirely, + // utilising pqproto.rs. + let mut tmp_config = postgres_client::Config::new(String::new(), 0); + // We have already established SSL if necessary. + tmp_config.ssl_mode(SslMode::Disable); + let tmp_config = self.enrich(tmp_config); - fn deref(&self) -> &Self::Target { - &self.0 + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); + let connection = tmp_config + .tls_and_authenticate(&mut compute.stream, NoTls) + .await?; + drop(pause); + + let RawConnection { + stream: _, + parameters, + delayed_notice, + process_id, + secret_key, + } = connection; + + tracing::Span::current().record("pid", tracing::field::display(process_id)); + + // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw. + // Yet another reason to rework the connection establishing code. + let cancel_closure = CancelClosure::new( + compute.socket_addr, + RawCancelToken { + ssl_mode: compute.ssl_mode, + process_id, + secret_key, + }, + compute.hostname.to_string(), + user_info, + ); + + Ok(PostgresSettings { + params: parameters, + cancel_closure, + delayed_notice, + }) } } -/// For now, let's make it easier to setup the config. -impl std::ops::DerefMut for ConnCfg { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -impl ConnCfg { - /// Establish a raw TCP connection to the compute node. - async fn connect_raw(&self, timeout: Duration) -> io::Result<(SocketAddr, TcpStream, &str)> { - use postgres_client::config::Host; +impl ConnectInfo { + /// Establish a raw TCP+TLS connection to the compute node. + async fn connect_raw( + &self, + config: &ComputeConfig, + ) -> Result<(SocketAddr, MaybeTlsStream), TlsError> { + let timeout = config.timeout; // wrap TcpStream::connect with timeout let connect_with_timeout = |addrs| { @@ -208,112 +319,84 @@ impl ConnCfg { // We can't reuse connection establishing logic from `postgres_client` here, // because it has no means for extracting the underlying socket which we // require for our business. - let port = self.0.get_port(); - let host = self.0.get_host(); + let port = self.port; + let host = &*self.host; - let host = match host { - Host::Tcp(host) => host.as_str(), - }; - - let addrs = match self.0.get_host_addr() { + let addrs = match self.host_addr { Some(addr) => vec![SocketAddr::new(addr, port)], None => lookup_host((host, port)).await?.collect(), }; match connect_once(&*addrs).await { - Ok((sockaddr, stream)) => Ok((sockaddr, stream, host)), + Ok((sockaddr, stream)) => Ok(( + sockaddr, + tls::connect_tls(stream, self.ssl_mode, config, host).await?, + )), Err(err) => { warn!("couldn't connect to compute node at {host}:{port}: {err}"); - Err(err) + Err(TlsError::Connection(err)) } } } } -type RustlsStream = >::Stream; +pub type RustlsStream = >::Stream; +pub type MaybeRustlsStream = MaybeTlsStream; -pub(crate) struct PostgresConnection { - /// Socket connected to a compute node. - pub(crate) stream: - postgres_client::maybe_tls_stream::MaybeTlsStream, +// TODO(conrad): we don't need to parse these. 
+// These are just immediately forwarded back to the client. +// We could instead stream them out instead of reading them into memory. +pub struct PostgresSettings { /// PostgreSQL connection parameters. - pub(crate) params: std::collections::HashMap, + pub params: std::collections::HashMap, /// Query cancellation token. - pub(crate) cancel_closure: CancelClosure, - /// Labels for proxy's metrics. - pub(crate) aux: MetricsAuxInfo, + pub cancel_closure: CancelClosure, /// Notices received from compute after authenticating - pub(crate) delayed_notice: Vec, - - _guage: NumDbConnectionsGuard<'static>, + pub delayed_notice: Vec, } -impl ConnCfg { +pub struct ComputeConnection { + /// Socket connected to a compute node. + pub stream: MaybeTlsStream, + /// Labels for proxy's metrics. + pub aux: MetricsAuxInfo, + pub hostname: Host, + pub ssl_mode: SslMode, + pub socket_addr: SocketAddr, + pub guage: NumDbConnectionsGuard<'static>, +} + +impl ConnectInfo { /// Connect to a corresponding compute node. - pub(crate) async fn connect( + pub async fn connect( &self, ctx: &RequestContext, - aux: MetricsAuxInfo, + aux: &MetricsAuxInfo, config: &ComputeConfig, - user_info: ComputeUserInfo, - ) -> Result { + ) -> Result { let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let (socket_addr, stream, host) = self.connect_raw(config.timeout).await?; + let (socket_addr, stream) = self.connect_raw(config).await?; drop(pause); - let mut mk_tls = crate::tls::postgres_rustls::MakeRustlsConnect::new(config.tls.clone()); - let tls = >::make_tls_connect( - &mut mk_tls, - host, - )?; - - // connect_raw() will not use TLS if sslmode is "disable" - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let connection = self.0.connect_raw(stream, tls).await?; - drop(pause); - - let RawConnection { - stream, - parameters, - delayed_notice, - process_id, - secret_key, - } = connection; - - tracing::Span::current().record("pid", tracing::field::display(process_id)); tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id)); - let stream = stream.into_inner(); // TODO: lots of useful info but maybe we can move it elsewhere (eg traces?) info!( cold_start_info = ctx.cold_start_info().as_str(), - "connected to compute node at {host} ({socket_addr}) sslmode={:?}, latency={}, query_id={}", - self.0.get_ssl_mode(), + "connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}", + self.host, + self.ssl_mode, ctx.get_proxy_latency(), ctx.get_testodrome_id().unwrap_or_default(), ); - // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw. - // Yet another reason to rework the connection establishing code. 
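Big picture for this file: the old `ConnCfg::connect` (being removed here) both dialed the compute node and authenticated in one step. After this refactor those are two phases, `ConnectInfo::connect` establishing TCP plus optional TLS, then `AuthInfo::authenticate` driving startup and SCRAM over the established stream. A hypothetical caller, with the two error types folded into `anyhow` for brevity (illustrative sketch only, not code from this patch):

    // Names are the ones introduced in this diff (NodeInfo, AuthInfo, etc.).
    async fn open_compute(
        ctx: &RequestContext,
        node: &NodeInfo,
        auth: &AuthInfo,
        user: ComputeUserInfo,
        config: &ComputeConfig,
    ) -> anyhow::Result<(ComputeConnection, PostgresSettings)> {
        // Phase 1: TCP + optional TLS; no Postgres authentication happens yet.
        let mut conn = node.connect(ctx, config).await?;
        // Phase 2: startup message + SCRAM over the established stream,
        // yielding server parameters and the cancellation closure.
        let settings = auth.authenticate(ctx, &mut conn, user).await?;
        Ok((conn, settings))
    }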
- let cancel_closure = CancelClosure::new( - socket_addr, - CancelToken { - socket_config: None, - ssl_mode: self.0.get_ssl_mode(), - process_id, - secret_key, - }, - host.to_string(), - user_info, - ); - - let connection = PostgresConnection { + let connection = ComputeConnection { stream, - params: parameters, - delayed_notice, - cancel_closure, - aux, - _guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()), + socket_addr, + hostname: self.host.clone(), + ssl_mode: self.ssl_mode, + aux: aux.clone(), + guage: Metrics::get().proxy.db_connections.guard(ctx.protocol()), }; Ok(connection) diff --git a/proxy/src/compute/tls.rs b/proxy/src/compute/tls.rs new file mode 100644 index 0000000000..000d75fca5 --- /dev/null +++ b/proxy/src/compute/tls.rs @@ -0,0 +1,63 @@ +use futures::FutureExt; +use postgres_client::config::SslMode; +use postgres_client::maybe_tls_stream::MaybeTlsStream; +use postgres_client::tls::{MakeTlsConnect, TlsConnect}; +use rustls::pki_types::InvalidDnsNameError; +use thiserror::Error; +use tokio::io::{AsyncRead, AsyncWrite}; + +use crate::pqproto::request_tls; +use crate::proxy::retry::CouldRetry; + +#[derive(Debug, Error)] +pub enum TlsError { + #[error(transparent)] + Dns(#[from] InvalidDnsNameError), + #[error(transparent)] + Connection(#[from] std::io::Error), + #[error("TLS required but not provided")] + Required, +} + +impl CouldRetry for TlsError { + fn could_retry(&self) -> bool { + match self { + TlsError::Dns(_) => false, + TlsError::Connection(err) => err.could_retry(), + // perhaps compute didn't realise it supports TLS? + TlsError::Required => true, + } + } +} + +pub async fn connect_tls( + mut stream: S, + mode: SslMode, + tls: &T, + host: &str, +) -> Result, TlsError> +where + S: AsyncRead + AsyncWrite + Unpin + Send, + T: MakeTlsConnect< + S, + Error = InvalidDnsNameError, + TlsConnect: TlsConnect, + >, +{ + match mode { + SslMode::Disable => return Ok(MaybeTlsStream::Raw(stream)), + SslMode::Prefer | SslMode::Require => {} + } + + if !request_tls(&mut stream).await? 
{ + if SslMode::Require == mode { + return Err(TlsError::Required); + } + + return Ok(MaybeTlsStream::Raw(stream)); + } + + Ok(MaybeTlsStream::Tls( + tls.make_tls_connect(host)?.connect(stream).boxed().await?, + )) +} diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 248584a19a..cee15ac7fa 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -22,7 +22,6 @@ pub struct ProxyConfig { pub http_config: HttpConfig, pub authentication_config: AuthenticationConfig, pub proxy_protocol_v2: ProxyProtocolV2, - pub region: String, pub handshake_timeout: Duration, pub wake_compute_retry_config: RetryConfig, pub connect_compute_locks: ApiLocks, diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index f2484b54b8..112465a89b 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -11,13 +11,12 @@ use crate::config::{ProxyConfig, ProxyProtocolV2}; use crate::context::RequestContext; use crate::error::ReportableError; use crate::metrics::{Metrics, NumClientConnectionsGuard}; -use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute}; use crate::pglb::handshake::{HandshakeData, handshake}; use crate::pglb::passthrough::ProxyPassthrough; use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol}; -use crate::proxy::{ - ClientRequestError, ErrorSource, prepare_client_connection, run_until_cancelled, -}; +use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute}; +use crate::proxy::{ClientRequestError, ErrorSource, prepare_client_connection}; +use crate::util::run_until_cancelled; pub async fn task_main( config: &'static ProxyConfig, @@ -90,12 +89,7 @@ pub async fn task_main( } } - let ctx = RequestContext::new( - session_id, - conn_info, - crate::metrics::Protocol::Tcp, - &config.region, - ); + let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp); let res = handle_client( config, @@ -121,7 +115,7 @@ pub async fn task_main( Ok(Some(p)) => { ctx.set_success(); let _disconnect = ctx.log_connect(); - match p.proxy_pass(&config.connect_to_compute).await { + match p.proxy_pass().await { Ok(()) => {} Err(ErrorSource::Client(e)) => { error!( @@ -210,20 +204,18 @@ pub(crate) async fn handle_client( ctx.set_db_options(params.clone()); - let (node_info, user_info, _ip_allowlist) = match backend + let (node_info, mut auth_info, user_info) = match backend .authenticate(ctx, &config.authentication_config, &mut stream) .await { Ok(auth_result) => auth_result, Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, }; + auth_info.set_startup_params(¶ms, true); - let node = connect_to_compute( + let mut node = connect_to_compute( ctx, &TcpMechanism { - user_info, - params_compat: true, - params: ¶ms, locks: &config.connect_compute_locks, }, &node_info, @@ -233,22 +225,40 @@ pub(crate) async fn handle_client( .or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) }) .await?; - let cancellation_handler_clone = Arc::clone(&cancellation_handler); - let session = cancellation_handler_clone.get_key(); + let pg_settings = auth_info + .authenticate(ctx, &mut node, user_info) + .or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) }) + .await?; - session.write_cancel_key(node.cancel_closure.clone())?; + let session = cancellation_handler.get_key(); - prepare_client_connection(&node, *session.key(), &mut stream); + prepare_client_connection(&pg_settings, *session.key(), &mut stream); let stream = stream.flush_and_into_inner().await?; + let session_id = 
ctx.session_id(); + let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel(); + tokio::spawn(async move { + session + .maintain_cancel_key( + session_id, + cancel, + &pg_settings.cancel_closure, + &config.connect_to_compute, + ) + .await; + }); + Ok(Some(ProxyPassthrough { client: stream, - aux: node.aux.clone(), + compute: node.stream, + + aux: node.aux, private_link_id: None, - compute: node, - session_id: ctx.session_id(), - cancel: session, + + _cancel_on_shutdown: cancel_on_shutdown, + _req: request_gauge, _conn: conn_gauge, + _db_conn: node.guage, })) } diff --git a/proxy/src/context/mod.rs b/proxy/src/context/mod.rs index 24268997ba..df1c4e194a 100644 --- a/proxy/src/context/mod.rs +++ b/proxy/src/context/mod.rs @@ -46,7 +46,6 @@ struct RequestContextInner { pub(crate) session_id: Uuid, pub(crate) protocol: Protocol, first_packet: chrono::DateTime, - region: &'static str, pub(crate) span: Span, // filled in as they are discovered @@ -94,7 +93,6 @@ impl Clone for RequestContext { session_id: inner.session_id, protocol: inner.protocol, first_packet: inner.first_packet, - region: inner.region, span: info_span!("background_task"), project: inner.project, @@ -124,12 +122,7 @@ impl Clone for RequestContext { } impl RequestContext { - pub fn new( - session_id: Uuid, - conn_info: ConnectionInfo, - protocol: Protocol, - region: &'static str, - ) -> Self { + pub fn new(session_id: Uuid, conn_info: ConnectionInfo, protocol: Protocol) -> Self { // TODO: be careful with long lived spans let span = info_span!( "connect_request", @@ -145,7 +138,6 @@ impl RequestContext { session_id, protocol, first_packet: Utc::now(), - region, span, project: None, @@ -179,7 +171,7 @@ impl RequestContext { let ip = IpAddr::from([127, 0, 0, 1]); let addr = SocketAddr::new(ip, 5432); let conn_info = ConnectionInfo { addr, extra: None }; - RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp, "test") + RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp) } pub(crate) fn console_application_name(&self) -> String { diff --git a/proxy/src/context/parquet.rs b/proxy/src/context/parquet.rs index c9d3905abd..b55cc14532 100644 --- a/proxy/src/context/parquet.rs +++ b/proxy/src/context/parquet.rs @@ -74,7 +74,7 @@ pub(crate) const FAILED_UPLOAD_MAX_RETRIES: u32 = 10; #[derive(parquet_derive::ParquetRecordWriter)] pub(crate) struct RequestData { - region: &'static str, + region: String, protocol: &'static str, /// Must be UTC. 
The derive macro doesn't like the timezones timestamp: chrono::NaiveDateTime, @@ -147,7 +147,7 @@ impl From<&RequestContextInner> for RequestData { }), jwt_issuer: value.jwt_issuer.clone(), protocol: value.protocol.as_str(), - region: value.region, + region: String::new(), error: value.error_kind.as_ref().map(|e| e.to_metric_label()), success: value.success, cold_start_info: value.cold_start_info.as_str(), @@ -167,6 +167,7 @@ impl From<&RequestContextInner> for RequestData { pub async fn worker( cancellation_token: CancellationToken, config: ParquetUploadArgs, + region: String, ) -> anyhow::Result<()> { let Some(remote_storage_config) = config.parquet_upload_remote_storage else { tracing::warn!("parquet request upload: no s3 bucket configured"); @@ -232,12 +233,17 @@ pub async fn worker( .context("remote storage for disconnect events init")?; let parquet_config_disconnect = parquet_config.clone(); tokio::try_join!( - worker_inner(storage, rx, parquet_config), - worker_inner(storage_disconnect, rx_disconnect, parquet_config_disconnect) + worker_inner(storage, rx, parquet_config, ®ion), + worker_inner( + storage_disconnect, + rx_disconnect, + parquet_config_disconnect, + ®ion + ) ) .map(|_| ()) } else { - worker_inner(storage, rx, parquet_config).await + worker_inner(storage, rx, parquet_config, ®ion).await } } @@ -257,6 +263,7 @@ async fn worker_inner( storage: GenericRemoteStorage, rx: impl Stream, config: ParquetConfig, + region: &str, ) -> anyhow::Result<()> { #[cfg(any(test, feature = "testing"))] let storage = if config.test_remote_failures > 0 { @@ -277,7 +284,8 @@ async fn worker_inner( let mut last_upload = time::Instant::now(); let mut len = 0; - while let Some(row) = rx.next().await { + while let Some(mut row) = rx.next().await { + region.clone_into(&mut row.region); rows.push(row); let force = last_upload.elapsed() > config.max_duration; if rows.len() == config.rows_per_group || force { @@ -533,7 +541,7 @@ mod tests { auth_method: None, jwt_issuer: None, protocol: ["tcp", "ws", "http"][rng.gen_range(0..3)], - region: "us-east-1", + region: String::new(), error: None, success: rng.r#gen(), cold_start_info: "no", @@ -565,7 +573,9 @@ mod tests { .await .unwrap(); - worker_inner(storage, rx, config).await.unwrap(); + worker_inner(storage, rx, config, "us-east-1") + .await + .unwrap(); let mut files = WalkDir::new(tmpdir.as_std_path()) .into_iter() diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index da548d6b2c..8c76d034f7 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -146,6 +146,7 @@ impl NeonControlPlaneClient { public_access_blocked: block_public_connections, vpc_access_blocked: block_vpc_connections, }, + rate_limits: body.rate_limits, }) } .inspect_err(|e| tracing::debug!(error = ?e)) @@ -261,24 +262,18 @@ impl NeonControlPlaneClient { Some(_) => SslMode::Require, None => SslMode::Disable, }; - let host_name = match body.server_name { - Some(host) => host, - None => host.to_owned(), + let host = match body.server_name { + Some(host) => host.into(), + None => host.into(), }; - // Don't set anything but host and port! This config will be cached. - // We'll set username and such later using the startup message. - // TODO: add more type safety (in progress). 
- let mut config = compute::ConnCfg::new(host_name, port); - - if let Some(addr) = host_addr { - config.set_host_addr(addr); - } - - config.ssl_mode(ssl_mode); - let node = NodeInfo { - config, + conn_info: compute::ConnectInfo { + host_addr, + host, + port, + ssl_mode, + }, aux: body.aux, }; @@ -318,6 +313,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { allowed_ips: Arc::new(auth_info.allowed_ips), allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids), flags: auth_info.access_blocker_flags, + rate_limits: auth_info.rate_limits, }; let role_control = RoleAccessControl { secret: auth_info.secret, @@ -363,6 +359,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { allowed_ips: Arc::new(auth_info.allowed_ips), allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids), flags: auth_info.access_blocker_flags, + rate_limits: auth_info.rate_limits, }; let role_control = RoleAccessControl { secret: auth_info.secret, diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index ece7153fce..b84dba6b09 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -6,6 +6,7 @@ use std::str::FromStr; use std::sync::Arc; use futures::TryFutureExt; +use postgres_client::config::SslMode; use thiserror::Error; use tokio_postgres::Client; use tracing::{Instrument, error, info, info_span, warn}; @@ -14,19 +15,20 @@ use crate::auth::IpPattern; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; use crate::cache::Cached; +use crate::compute::ConnectInfo; use crate::context::RequestContext; use crate::control_plane::errors::{ ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError, }; -use crate::control_plane::messages::MetricsAuxInfo; +use crate::control_plane::messages::{EndpointRateLimitConfig, MetricsAuxInfo}; use crate::control_plane::{ AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo, RoleAccessControl, }; use crate::intern::RoleNameInt; +use crate::scram; use crate::types::{BranchId, EndpointId, ProjectId, RoleName}; use crate::url::ApiUrl; -use crate::{compute, scram}; #[derive(Debug, Error)] enum MockApiError { @@ -87,8 +89,7 @@ impl MockControlPlane { .await? 
        {
             info!("got a secret: {entry}");
             // safe since it's not a prod scenario
-            let secret = scram::ServerSecret::parse(&entry).map(AuthSecret::Scram);
-            secret.or_else(|| parse_md5(&entry).map(AuthSecret::Md5))
+            scram::ServerSecret::parse(&entry).map(AuthSecret::Scram)
         } else {
             warn!("user '{role}' does not exist");
             None
@@ -129,6 +130,7 @@ impl MockControlPlane {
             project_id: None,
             account_id: None,
             access_blocker_flags: AccessBlockerFlags::default(),
+            rate_limits: EndpointRateLimitConfig::default(),
         })
     }

@@ -170,25 +172,23 @@ impl MockControlPlane {
     async fn do_wake_compute(&self) -> Result<NodeInfo, WakeComputeError> {
         let port = self.endpoint.port().unwrap_or(5432);

-        let mut config = match self.endpoint.host_str() {
-            None => {
-                let mut config = compute::ConnCfg::new("localhost".to_string(), port);
-                config.set_host_addr(IpAddr::V4(Ipv4Addr::LOCALHOST));
-                config
-            }
-            Some(host) => {
-                let mut config = compute::ConnCfg::new(host.to_string(), port);
-                if let Ok(addr) = IpAddr::from_str(host) {
-                    config.set_host_addr(addr);
-                }
-                config
-            }
+        let conn_info = match self.endpoint.host_str() {
+            None => ConnectInfo {
+                host_addr: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)),
+                host: "localhost".into(),
+                port,
+                ssl_mode: SslMode::Disable,
+            },
+            Some(host) => ConnectInfo {
+                host_addr: IpAddr::from_str(host).ok(),
+                host: host.into(),
+                port,
+                ssl_mode: SslMode::Disable,
+            },
         };

-        config.ssl_mode(postgres_client::config::SslMode::Disable);
-
         let node = NodeInfo {
-            config,
+            conn_info,
             aux: MetricsAuxInfo {
                 endpoint_id: (&EndpointId::from("endpoint")).into(),
                 project_id: (&ProjectId::from("project")).into(),
@@ -234,6 +234,7 @@ impl super::ControlPlaneApi for MockControlPlane {
             allowed_ips: Arc::new(info.allowed_ips),
             allowed_vpce: Arc::new(info.allowed_vpc_endpoint_ids),
             flags: info.access_blocker_flags,
+            rate_limits: info.rate_limits,
         })
     }

@@ -266,12 +267,3 @@ impl super::ControlPlaneApi for MockControlPlane {
         self.do_wake_compute().map_ok(Cached::new_uncached).await
     }
 }
-
-fn parse_md5(input: &str) -> Option<[u8; 16]> {
-    let text = input.strip_prefix("md5")?;
-
-    let mut bytes = [0u8; 16];
-    hex::decode_to_slice(text, &mut bytes).ok()?;
-
-    Some(bytes)
-}
diff --git a/proxy/src/control_plane/client/mod.rs b/proxy/src/control_plane/client/mod.rs
index 9b9d1e25ea..4e5f5c7899 100644
--- a/proxy/src/control_plane/client/mod.rs
+++ b/proxy/src/control_plane/client/mod.rs
@@ -10,6 +10,7 @@ use clashmap::ClashMap;
 use tokio::time::Instant;
 use tracing::{debug, info};

+use super::{EndpointAccessControl, RoleAccessControl};
 use crate::auth::backend::ComputeUserInfo;
 use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
 use crate::cache::endpoints::EndpointsCache;
@@ -22,8 +23,6 @@ use crate::metrics::ApiLockMetrics;
 use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
 use crate::types::EndpointId;

-use super::{EndpointAccessControl, RoleAccessControl};
-
 #[non_exhaustive]
 #[derive(Clone)]
 pub enum ControlPlaneClient {
diff --git a/proxy/src/control_plane/messages.rs b/proxy/src/control_plane/messages.rs
index ec4554eab5..f0314f91f0 100644
--- a/proxy/src/control_plane/messages.rs
+++ b/proxy/src/control_plane/messages.rs
@@ -227,12 +227,35 @@ pub(crate) struct UserFacingMessage {
 #[derive(Deserialize)]
 pub(crate) struct GetEndpointAccessControl {
     pub(crate) role_secret: Box<str>,
-    pub(crate) allowed_ips: Option<Vec<IpPattern>>,
-    pub(crate) allowed_vpc_endpoint_ids: Option<Vec<String>>,
+
     pub(crate) project_id: Option<ProjectIdInt>,
     pub(crate) account_id: Option<AccountIdInt>,
+
+    pub(crate) allowed_ips: Option<Vec<IpPattern>>,
+    pub(crate) allowed_vpc_endpoint_ids: Option<Vec<String>>,

     pub(crate) block_public_connections: Option<bool>,
     pub(crate) block_vpc_connections: Option<bool>,
+
+    #[serde(default)]
+    pub(crate) rate_limits: EndpointRateLimitConfig,
+}
+
+#[derive(Copy, Clone, Deserialize, Default)]
+pub struct EndpointRateLimitConfig {
+    pub connection_attempts: ConnectionAttemptsLimit,
+}
+
+#[derive(Copy, Clone, Deserialize, Default)]
+pub struct ConnectionAttemptsLimit {
+    pub tcp: Option<LeakyBucketSetting>,
+    pub ws: Option<LeakyBucketSetting>,
+    pub http: Option<LeakyBucketSetting>,
+}
+
+#[derive(Copy, Clone, Deserialize)]
+pub struct LeakyBucketSetting {
+    pub rps: f64,
+    pub burst: f64,
 }

 /// Response which holds compute node's `host:port` pair.
diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs
index 7ff093d9dc..a8c59dad0c 100644
--- a/proxy/src/control_plane/mod.rs
+++ b/proxy/src/control_plane/mod.rs
@@ -11,15 +11,18 @@ pub(crate) mod errors;

 use std::sync::Arc;

+use messages::EndpointRateLimitConfig;
+
+use crate::auth::backend::ComputeUserInfo;
 use crate::auth::backend::jwt::AuthRule;
-use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
 use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list};
 use crate::cache::{Cached, TimedLru};
 use crate::config::ComputeConfig;
 use crate::context::RequestContext;
 use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
-use crate::intern::{AccountIdInt, ProjectIdInt};
+use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt};
 use crate::protocol2::ConnectionInfoExtra;
+use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig};
 use crate::types::{EndpointCacheKey, EndpointId, RoleName};
 use crate::{compute, scram};

@@ -39,10 +42,6 @@ pub mod mgmt;
 /// Auth secret which is managed by the cloud.
 #[derive(Clone, Eq, PartialEq, Debug)]
 pub(crate) enum AuthSecret {
-    #[cfg(any(test, feature = "testing"))]
-    /// Md5 hash of user's password.
-    Md5([u8; 16]),
-
     /// [SCRAM](crate::scram) authentication info.
     Scram(scram::ServerSecret),
 }
@@ -60,16 +59,14 @@ pub(crate) struct AuthInfo {
     pub(crate) account_id: Option<AccountIdInt>,
     /// Are public connections or VPC connections blocked?
     pub(crate) access_blocker_flags: AccessBlockerFlags,
+    /// The rate limits for this endpoint.
+    pub(crate) rate_limits: EndpointRateLimitConfig,
 }

 /// Info for establishing a connection to a compute node.
-/// This is what we get after auth succeeded, but not before!
 #[derive(Clone)]
 pub(crate) struct NodeInfo {
-    /// Compute node connection params.
-    /// It's sad that we have to clone this, but this will improve
-    /// once we migrate to a bespoke connection logic.
-    pub(crate) config: compute::ConnCfg,
+    pub(crate) conn_info: compute::ConnectInfo,
     /// Labels for proxy's metrics.
pub(crate) aux: MetricsAuxInfo, @@ -80,24 +77,8 @@ impl NodeInfo { &self, ctx: &RequestContext, config: &ComputeConfig, - user_info: ComputeUserInfo, - ) -> Result { - self.config - .connect(ctx, self.aux.clone(), config, user_info) - .await - } - - pub(crate) fn reuse_settings(&mut self, other: Self) { - self.config.reuse_password(other.config); - } - - pub(crate) fn set_keys(&mut self, keys: &ComputeCredentialKeys) { - match keys { - #[cfg(any(test, feature = "testing"))] - ComputeCredentialKeys::Password(password) => self.config.password(password), - ComputeCredentialKeys::AuthKeys(auth_keys) => self.config.auth_keys(*auth_keys), - ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => &mut self.config, - }; + ) -> Result { + self.conn_info.connect(ctx, &self.aux, config).await } } @@ -121,6 +102,8 @@ pub struct EndpointAccessControl { pub allowed_ips: Arc>, pub allowed_vpce: Arc>, pub flags: AccessBlockerFlags, + + pub rate_limits: EndpointRateLimitConfig, } impl EndpointAccessControl { @@ -159,6 +142,36 @@ impl EndpointAccessControl { Ok(()) } + + pub fn connection_attempt_rate_limit( + &self, + ctx: &RequestContext, + endpoint: &EndpointId, + rate_limiter: &EndpointRateLimiter, + ) -> Result<(), AuthError> { + let endpoint = EndpointIdInt::from(endpoint); + + let limits = &self.rate_limits.connection_attempts; + let config = match ctx.protocol() { + crate::metrics::Protocol::Http => limits.http, + crate::metrics::Protocol::Ws => limits.ws, + crate::metrics::Protocol::Tcp => limits.tcp, + crate::metrics::Protocol::SniRouter => return Ok(()), + }; + let config = config.and_then(|config| { + if config.rps <= 0.0 || config.burst <= 0.0 { + return None; + } + + Some(LeakyBucketConfig::new(config.rps, config.burst)) + }); + + if !rate_limiter.check(endpoint, config, 1) { + return Err(AuthError::too_many_connections()); + } + + Ok(()) + } } /// This will allocate per each call, but the http requests alone diff --git a/proxy/src/lib.rs b/proxy/src/lib.rs index d65d056585..263d784e78 100644 --- a/proxy/src/lib.rs +++ b/proxy/src/lib.rs @@ -61,6 +61,10 @@ clippy::too_many_lines, clippy::unused_self )] +#![allow( + clippy::unsafe_derive_deserialize, + reason = "false positive: https://github.com/rust-lang/rust-clippy/issues/15120" +)] #![cfg_attr( any(test, feature = "testing"), allow( @@ -75,6 +79,7 @@ pub mod binary; mod auth; +mod batch; mod cache; mod cancellation; mod compute; @@ -106,4 +111,5 @@ mod tls; mod types; mod url; mod usage_metrics; +mod util; mod waiters; diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 4b22c912eb..4c340edfd5 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -610,11 +610,11 @@ pub enum RedisEventsCount { BranchCreated, ProjectCreated, CancelSession, - PasswordUpdate, - AllowedIpsUpdate, - AllowedVpcEndpointIdsUpdateForProjects, - AllowedVpcEndpointIdsUpdateForAllProjectsInOrg, - BlockPublicOrVpcAccessUpdate, + InvalidateRole, + InvalidateEndpoint, + InvalidateProject, + InvalidateProjects, + InvalidateOrg, } pub struct ThreadPoolWorkers(usize); diff --git a/proxy/src/pglb/mod.rs b/proxy/src/pglb/mod.rs index 4b107142a7..cb82524cf6 100644 --- a/proxy/src/pglb/mod.rs +++ b/proxy/src/pglb/mod.rs @@ -1,4 +1,3 @@ -pub mod connect_compute; pub mod copy_bidirectional; pub mod handshake; pub mod inprocess; diff --git a/proxy/src/pglb/passthrough.rs b/proxy/src/pglb/passthrough.rs index 6f651d383d..d4c029f6d9 100644 --- a/proxy/src/pglb/passthrough.rs +++ b/proxy/src/pglb/passthrough.rs @@ -1,15 +1,17 @@ -use 
futures::FutureExt; +use std::convert::Infallible; + use smol_str::SmolStr; use tokio::io::{AsyncRead, AsyncWrite}; use tracing::debug; use utils::measured_stream::MeasuredStream; use super::copy_bidirectional::ErrorSource; -use crate::cancellation; -use crate::compute::PostgresConnection; -use crate::config::ComputeConfig; +use crate::compute::MaybeRustlsStream; use crate::control_plane::messages::MetricsAuxInfo; -use crate::metrics::{Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard}; +use crate::metrics::{ + Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard, + NumDbConnectionsGuard, +}; use crate::stream::Stream; use crate::usage_metrics::{Ids, MetricCounterRecorder, USAGE_METRICS}; @@ -64,40 +66,20 @@ pub(crate) async fn proxy_pass( pub(crate) struct ProxyPassthrough { pub(crate) client: Stream, - pub(crate) compute: PostgresConnection, + pub(crate) compute: MaybeRustlsStream, + pub(crate) aux: MetricsAuxInfo, - pub(crate) session_id: uuid::Uuid, pub(crate) private_link_id: Option, - pub(crate) cancel: cancellation::Session, + + pub(crate) _cancel_on_shutdown: tokio::sync::oneshot::Sender, pub(crate) _req: NumConnectionRequestsGuard<'static>, pub(crate) _conn: NumClientConnectionsGuard<'static>, + pub(crate) _db_conn: NumDbConnectionsGuard<'static>, } impl ProxyPassthrough { - pub(crate) async fn proxy_pass( - self, - compute_config: &ComputeConfig, - ) -> Result<(), ErrorSource> { - let res = proxy_pass( - self.client, - self.compute.stream, - self.aux, - self.private_link_id, - ) - .await; - if let Err(err) = self - .compute - .cancel_closure - .try_cancel_query(compute_config) - .boxed() - .await - { - tracing::warn!(session_id = ?self.session_id, ?err, "could not cancel the query in the database"); - } - - drop(self.cancel.remove_cancel_key()); // we don't need a result. If the queue is full, we just log the error - - res + pub(crate) async fn proxy_pass(self) -> Result<(), ErrorSource> { + proxy_pass(self.client, self.compute, self.aux, self.private_link_id).await } } diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs index 43074bf208..ad99eecda5 100644 --- a/proxy/src/pqproto.rs +++ b/proxy/src/pqproto.rs @@ -8,7 +8,7 @@ use std::io::{self, Cursor}; use bytes::{Buf, BufMut}; use itertools::Itertools; use rand::distributions::{Distribution, Standard}; -use tokio::io::{AsyncRead, AsyncReadExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use zerocopy::{FromBytes, Immutable, IntoBytes, big_endian}; pub type ErrorCode = [u8; 5]; @@ -53,6 +53,28 @@ impl fmt::Debug for ProtocolVersion { } } +/// +const MAX_STARTUP_PACKET_LENGTH: usize = 10000; +const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234; +/// +const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678); +/// +const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679); +/// +const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680); + +/// This first reads the startup message header, is 8 bytes. +/// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number. +/// +/// The length value is inclusive of the header. For example, +/// an empty message will always have length 8. +#[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)] +#[repr(C)] +struct StartupHeader { + len: big_endian::U32, + version: ProtocolVersion, +} + /// read the type from the stream using zerocopy. /// /// not cancel safe. @@ -66,32 +88,38 @@ macro_rules! 
read { }}; } +/// Returns true if TLS is supported. +/// +/// This is not cancel safe. +pub async fn request_tls(stream: &mut S) -> io::Result +where + S: AsyncRead + AsyncWrite + Unpin, +{ + let payload = StartupHeader { + len: 8.into(), + version: NEGOTIATE_SSL_CODE, + }; + stream.write_all(payload.as_bytes()).await?; + stream.flush().await?; + + // we expect back either `S` or `N` as a single byte. + let mut res = *b"0"; + stream.read_exact(&mut res).await?; + + debug_assert!( + res == *b"S" || res == *b"N", + "unexpected SSL negotiation response: {}", + char::from(res[0]), + ); + + // S for SSL. + Ok(res == *b"S") +} + pub async fn read_startup(stream: &mut S) -> io::Result where S: AsyncRead + Unpin, { - /// - const MAX_STARTUP_PACKET_LENGTH: usize = 10000; - const RESERVED_INVALID_MAJOR_VERSION: u16 = 1234; - /// - const CANCEL_REQUEST_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5678); - /// - const NEGOTIATE_SSL_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5679); - /// - const NEGOTIATE_GSS_CODE: ProtocolVersion = ProtocolVersion::new(1234, 5680); - - /// This first reads the startup message header, is 8 bytes. - /// The first 4 bytes is a big-endian message length, and the next 4 bytes is a version number. - /// - /// The length value is inclusive of the header. For example, - /// an empty message will always have length 8. - #[derive(Clone, Copy, FromBytes, IntoBytes, Immutable)] - #[repr(C)] - struct StartupHeader { - len: big_endian::U32, - version: ProtocolVersion, - } - let header = read!(stream => StartupHeader); // @@ -564,9 +592,8 @@ mod tests { use tokio::io::{AsyncWriteExt, duplex}; use zerocopy::IntoBytes; - use crate::pqproto::{FeStartupPacket, read_message, read_startup}; - use super::ProtocolVersion; + use crate::pqproto::{FeStartupPacket, read_message, read_startup}; #[tokio::test] async fn reject_large_startup() { diff --git a/proxy/src/pglb/connect_compute.rs b/proxy/src/proxy/connect_compute.rs similarity index 75% rename from proxy/src/pglb/connect_compute.rs rename to proxy/src/proxy/connect_compute.rs index 1d6ca5fbb3..aa675a439e 100644 --- a/proxy/src/pglb/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -2,26 +2,24 @@ use async_trait::async_trait; use tokio::time; use tracing::{debug, info, warn}; -use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; -use crate::compute::{self, COULD_NOT_CONNECT, PostgresConnection}; +use crate::compute::{self, COULD_NOT_CONNECT, ComputeConnection}; use crate::config::{ComputeConfig, RetryConfig}; use crate::context::RequestContext; use crate::control_plane::errors::WakeComputeError; use crate::control_plane::locks::ApiLocks; -use crate::control_plane::{self, CachedNodeInfo, NodeInfo}; +use crate::control_plane::{self, NodeInfo}; use crate::error::ReportableError; use crate::metrics::{ ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType, }; -use crate::pqproto::StartupMessageParams; use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry}; -use crate::proxy::wake_compute::wake_compute; +use crate::proxy::wake_compute::{WakeComputeBackend, wake_compute}; use crate::types::Host; /// If we couldn't connect, a cached connection info might be to blame /// (e.g. the compute node's address might've changed at the wrong time). /// Invalidate the cache entry (if any) to prevent subsequent errors. 
-#[tracing::instrument(name = "invalidate_cache", skip_all)] +#[tracing::instrument(skip_all)] pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> NodeInfo { let is_cached = node_info.cached(); if is_cached { @@ -48,35 +46,16 @@ pub(crate) trait ConnectMechanism { node_info: &control_plane::CachedNodeInfo, config: &ComputeConfig, ) -> Result; - - fn update_connect_config(&self, conf: &mut compute::ConnCfg); } -#[async_trait] -pub(crate) trait ComputeConnectBackend { - async fn wake_compute( - &self, - ctx: &RequestContext, - ) -> Result; - - fn get_keys(&self) -> &ComputeCredentialKeys; -} - -pub(crate) struct TcpMechanism<'a> { - pub(crate) params_compat: bool, - - /// KV-dictionary with PostgreSQL connection params. - pub(crate) params: &'a StartupMessageParams, - +pub(crate) struct TcpMechanism { /// connect_to_compute concurrency lock pub(crate) locks: &'static ApiLocks, - - pub(crate) user_info: ComputeUserInfo, } #[async_trait] -impl ConnectMechanism for TcpMechanism<'_> { - type Connection = PostgresConnection; +impl ConnectMechanism for TcpMechanism { + type Connection = ComputeConnection; type ConnectError = compute::ConnectionError; type Error = compute::ConnectionError; @@ -89,20 +68,15 @@ impl ConnectMechanism for TcpMechanism<'_> { ctx: &RequestContext, node_info: &control_plane::CachedNodeInfo, config: &ComputeConfig, - ) -> Result { - let host = node_info.config.get_host(); - let permit = self.locks.get_permit(&host).await?; - permit.release_result(node_info.connect(ctx, config, self.user_info.clone()).await) - } - - fn update_connect_config(&self, config: &mut compute::ConnCfg) { - config.set_startup_params(self.params, self.params_compat); + ) -> Result { + let permit = self.locks.get_permit(&node_info.conn_info.host).await?; + permit.release_result(node_info.connect(ctx, config).await) } } /// Try to connect to the compute node, retrying if necessary. #[tracing::instrument(skip_all)] -pub(crate) async fn connect_to_compute( +pub(crate) async fn connect_to_compute( ctx: &RequestContext, mechanism: &M, user_info: &B, @@ -114,12 +88,9 @@ where M::Error: From, { let mut num_retries = 0; - let mut node_info = + let node_info = wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?; - node_info.set_keys(user_info.get_keys()); - mechanism.update_connect_config(&mut node_info.config); - // try once let err = match mechanism.connect_once(ctx, &node_info, compute).await { Ok(res) => { @@ -155,14 +126,9 @@ where } else { // if we failed to connect, it's likely that the compute node was suspended, wake a new compute node debug!("compute node's state has likely changed; requesting a wake-up"); - let old_node_info = invalidate_cache(node_info); + invalidate_cache(node_info); // TODO: increment num_retries? - let mut node_info = - wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?; - node_info.reuse_settings(old_node_info); - - mechanism.update_connect_config(&mut node_info.config); - node_info + wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await? }; // now that we have a new node, try connect to it repeatedly. 
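Note on the refactor above: connect_to_compute now wakes the compute once, makes a single connection attempt, and only invalidates the cached node and re-wakes when the failure suggests the cached address went stale; startup params and auth keys travel in compute::AuthInfo instead of being written into the cached node config between attempts. A minimal, self-contained sketch of that control flow (Node, ConnError, wake and connect_once are illustrative stand-ins, not the crate's API):

#[derive(Debug)]
struct Node;

#[derive(Debug)]
struct ConnError {
    could_retry: bool,
    should_retry_wake: bool,
}

// Stand-in for wake_compute: resolves a (possibly cached) compute node.
fn wake(num_retries: &mut u32) -> Result<Node, ConnError> {
    *num_retries += 1;
    Ok(Node)
}

// Stand-in for ConnectMechanism::connect_once.
fn connect_once(_node: &Node) -> Result<(), ConnError> {
    Ok(())
}

fn connect_with_retry(max_retries: u32) -> Result<(), ConnError> {
    let mut num_retries = 0;
    let node = wake(&mut num_retries)?;

    // First attempt runs against the possibly-cached node.
    let node = match connect_once(&node) {
        Ok(conn) => return Ok(conn),
        Err(e) if !e.could_retry => return Err(e),
        Err(e) if e.should_retry_wake => {
            // The cached address is likely stale: forget it, wake a fresh node.
            drop(node);
            wake(&mut num_retries)?
        }
        Err(_) => node,
    };

    // Remaining attempts reuse the fresh node until the retry budget runs out.
    loop {
        match connect_once(&node) {
            Ok(conn) => return Ok(conn),
            Err(e) if !e.could_retry || num_retries >= max_retries => return Err(e),
            Err(_) => num_retries += 1,
        }
    }
}

The real implementation additionally sleeps with retry_after between attempts and records ConnectOutcome/RetriesMetricGroup metrics; the sketch keeps only the control flow.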
diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 0e138cc0c7..6b84e47982 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -1,8 +1,10 @@ #[cfg(test)] mod tests; +pub(crate) mod connect_compute; pub(crate) mod retry; pub(crate) mod wake_compute; + use std::sync::Arc; use futures::FutureExt; @@ -21,15 +23,16 @@ use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig}; use crate::context::RequestContext; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumClientConnectionsGuard}; -use crate::pglb::connect_compute::{TcpMechanism, connect_to_compute}; pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake}; use crate::pglb::passthrough::ProxyPassthrough; use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams}; use crate::protocol2::{ConnectHeader, ConnectionInfo, ConnectionInfoExtra, read_proxy_protocol}; +use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute}; use crate::rate_limiter::EndpointRateLimiter; use crate::stream::{PqStream, Stream}; use crate::types::EndpointCacheKey; +use crate::util::run_until_cancelled; use crate::{auth, compute}; const ERR_INSECURE_CONNECTION: &str = "connection is insecure (try using `sslmode=require`)"; @@ -46,21 +49,6 @@ impl ReportableError for TlsRequired { impl UserFacingError for TlsRequired {} -pub async fn run_until_cancelled( - f: F, - cancellation_token: &CancellationToken, -) -> Option { - match futures::future::select( - std::pin::pin!(f), - std::pin::pin!(cancellation_token.cancelled()), - ) - .await - { - futures::future::Either::Left((f, _)) => Some(f), - futures::future::Either::Right(((), _)) => None, - } -} - pub async fn task_main( config: &'static ProxyConfig, auth_backend: &'static auth::Backend<'static, ()>, @@ -134,12 +122,7 @@ pub async fn task_main( } } - let ctx = RequestContext::new( - session_id, - conn_info, - crate::metrics::Protocol::Tcp, - &config.region, - ); + let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp); let res = handle_client( config, @@ -167,7 +150,7 @@ pub async fn task_main( Ok(Some(p)) => { ctx.set_success(); let _disconnect = ctx.log_connect(); - match p.proxy_pass(&config.connect_to_compute).await { + match p.proxy_pass().await { Ok(()) => {} Err(ErrorSource::Client(e)) => { warn!( @@ -358,41 +341,54 @@ pub(crate) async fn handle_client( } }; - let compute_user_info = match &user_info { - auth::Backend::ControlPlane(_, info) => &info.info, + let (cplane, creds) = match user_info { + auth::Backend::ControlPlane(cplane, creds) => (cplane, creds), auth::Backend::Local(_) => unreachable!("local proxy does not run tcp proxy service"), }; - let params_compat = compute_user_info - .options - .get(NeonOptions::PARAMS_COMPAT) - .is_some(); + let params_compat = creds.info.options.get(NeonOptions::PARAMS_COMPAT).is_some(); + let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys); + auth_info.set_startup_params(¶ms, params_compat); let res = connect_to_compute( ctx, &TcpMechanism { - user_info: compute_user_info.clone(), - params_compat, - params: ¶ms, locks: &config.connect_compute_locks, }, - &user_info, + &auth::Backend::ControlPlane(cplane, creds.info.clone()), config.wake_compute_retry_config, &config.connect_to_compute, ) .await; - let node = match res { + let mut node = match res { Ok(node) => node, Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, }; - let 
cancellation_handler_clone = Arc::clone(&cancellation_handler); - let session = cancellation_handler_clone.get_key(); + let pg_settings = auth_info.authenticate(ctx, &mut node, creds.info).await; + let pg_settings = match pg_settings { + Ok(pg_settings) => pg_settings, + Err(e) => Err(stream.throw_error(e, Some(ctx)).await)?, + }; - session.write_cancel_key(node.cancel_closure.clone())?; - prepare_client_connection(&node, *session.key(), &mut stream); + let session = cancellation_handler.get_key(); + + prepare_client_connection(&pg_settings, *session.key(), &mut stream); let stream = stream.flush_and_into_inner().await?; + let session_id = ctx.session_id(); + let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel(); + tokio::spawn(async move { + session + .maintain_cancel_key( + session_id, + cancel, + &pg_settings.cancel_closure, + &config.connect_to_compute, + ) + .await; + }); + let private_link_id = match ctx.extra() { Some(ConnectionInfoExtra::Aws { vpce_id }) => Some(vpce_id.clone()), Some(ConnectionInfoExtra::Azure { link_id }) => Some(link_id.to_smolstr()), @@ -401,31 +397,34 @@ pub(crate) async fn handle_client( Ok(Some(ProxyPassthrough { client: stream, - aux: node.aux.clone(), + compute: node.stream, + + aux: node.aux, private_link_id, - compute: node, - session_id: ctx.session_id(), - cancel: session, + + _cancel_on_shutdown: cancel_on_shutdown, + _req: request_gauge, _conn: conn_gauge, + _db_conn: node.guage, })) } /// Finish client connection initialization: confirm auth success, send params, etc. pub(crate) fn prepare_client_connection( - node: &compute::PostgresConnection, + settings: &compute::PostgresSettings, cancel_key_data: CancelKeyData, stream: &mut PqStream, ) { // Forward all deferred notices to the client. - for notice in &node.delayed_notice { + for notice in &settings.delayed_notice { stream.write_raw(notice.as_bytes().len(), b'N', |buf| { buf.extend_from_slice(notice.as_bytes()); }); } // Forward all postgres connection params to the client. 
- for (name, value) in &node.params { + for (name, value) in &settings.params { stream.write_message(BeMessage::ParameterStatus { name: name.as_bytes(), value: value.as_bytes(), diff --git a/proxy/src/proxy/retry.rs b/proxy/src/proxy/retry.rs index 01e603ec14..e9eca95724 100644 --- a/proxy/src/proxy/retry.rs +++ b/proxy/src/proxy/retry.rs @@ -99,17 +99,15 @@ impl ShouldRetryWakeCompute for postgres_client::Error { impl CouldRetry for compute::ConnectionError { fn could_retry(&self) -> bool { match self { - compute::ConnectionError::Postgres(err) => err.could_retry(), - compute::ConnectionError::CouldNotConnect(err) => err.could_retry(), + compute::ConnectionError::TlsError(err) => err.could_retry(), compute::ConnectionError::WakeComputeError(err) => err.could_retry(), - _ => false, + compute::ConnectionError::TooManyConnectionAttempts(_) => false, } } } impl ShouldRetryWakeCompute for compute::ConnectionError { fn should_retry_wake_compute(&self) -> bool { match self { - compute::ConnectionError::Postgres(err) => err.should_retry_wake_compute(), // the cache entry was not checked for validity compute::ConnectionError::TooManyConnectionAttempts(_) => false, _ => true, diff --git a/proxy/src/proxy/tests/mitm.rs b/proxy/src/proxy/tests/mitm.rs index c92ee49b8d..67dd0ab522 100644 --- a/proxy/src/proxy/tests/mitm.rs +++ b/proxy/src/proxy/tests/mitm.rs @@ -169,7 +169,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> { .dbname("db") .password("password") .ssl_mode(SslMode::Require) - .connect_raw(server, client_config.make_tls_connect()?) + .tls_and_authenticate(server, client_config.make_tls_connect()?) .await?; proxy.await? @@ -252,7 +252,7 @@ async fn connect_failure( .dbname("db") .password("password") .ssl_mode(SslMode::Require) - .connect_raw(server, client_config.make_tls_connect()?) + .tls_and_authenticate(server, client_config.make_tls_connect()?) .await .err() .context("client shouldn't be able to connect")?; diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index e5db0013a7..29a269208a 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -19,17 +19,14 @@ use tracing_test::traced_test; use super::retry::CouldRetry; use super::*; -use crate::auth::backend::{ - ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo, MaybeOwned, -}; +use crate::auth::backend::{ComputeUserInfo, MaybeOwned}; use crate::config::{ComputeConfig, RetryConfig}; use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient}; use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status}; use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache}; use crate::error::ErrorKind; -use crate::pglb::connect_compute::ConnectMechanism; +use crate::proxy::connect_compute::ConnectMechanism; use crate::tls::client_config::compute_client_config_with_certs; -use crate::tls::postgres_rustls::MakeRustlsConnect; use crate::tls::server_config::CertResolver; use crate::types::{BranchId, EndpointId, ProjectId}; use crate::{sasl, scram}; @@ -72,13 +69,14 @@ struct ClientConfig<'a> { hostname: &'a str, } -type TlsConnect = >::TlsConnect; +type TlsConnect = >::TlsConnect; impl ClientConfig<'_> { fn make_tls_connect(self) -> anyhow::Result> { - let mut mk = MakeRustlsConnect::new(self.config); - let tls = MakeTlsConnect::::make_tls_connect(&mut mk, self.hostname)?; - Ok(tls) + Ok(crate::tls::postgres_rustls::make_tls_connect( + &self.config, + self.hostname, + )?) 
} } @@ -201,7 +199,7 @@ async fn handshake_tls_is_enforced_by_proxy() -> anyhow::Result<()> { .user("john_doe") .dbname("earth") .ssl_mode(SslMode::Disable) - .connect_raw(server, NoTls) + .tls_and_authenticate(server, NoTls) .await .err() // -> Option .context("client shouldn't be able to connect")?; @@ -230,7 +228,7 @@ async fn handshake_tls() -> anyhow::Result<()> { .user("john_doe") .dbname("earth") .ssl_mode(SslMode::Require) - .connect_raw(server, client_config.make_tls_connect()?) + .tls_and_authenticate(server, client_config.make_tls_connect()?) .await?; proxy.await? @@ -247,7 +245,7 @@ async fn handshake_raw() -> anyhow::Result<()> { .dbname("earth") .set_param("options", "project=generic-project-name") .ssl_mode(SslMode::Prefer) - .connect_raw(server, NoTls) + .tls_and_authenticate(server, NoTls) .await?; proxy.await? @@ -295,7 +293,7 @@ async fn scram_auth_good(#[case] password: &str) -> anyhow::Result<()> { .dbname("db") .password(password) .ssl_mode(SslMode::Require) - .connect_raw(server, client_config.make_tls_connect()?) + .tls_and_authenticate(server, client_config.make_tls_connect()?) .await?; proxy.await? @@ -319,7 +317,7 @@ async fn scram_auth_disable_channel_binding() -> anyhow::Result<()> { .dbname("db") .password("password") .ssl_mode(SslMode::Require) - .connect_raw(server, client_config.make_tls_connect()?) + .tls_and_authenticate(server, client_config.make_tls_connect()?) .await?; proxy.await? @@ -346,7 +344,7 @@ async fn scram_auth_mock() -> anyhow::Result<()> { .dbname("db") .password(&password) // no password will match the mocked secret .ssl_mode(SslMode::Require) - .connect_raw(server, client_config.make_tls_connect()?) + .tls_and_authenticate(server, client_config.make_tls_connect()?) .await .err() // -> Option .context("client shouldn't be able to connect")?; @@ -497,8 +495,6 @@ impl ConnectMechanism for TestConnectMechanism { x => panic!("expecting action {x:?}, connect is called instead"), } } - - fn update_connect_config(&self, _conf: &mut compute::ConnCfg) {} } impl TestControlPlaneClient for TestConnectMechanism { @@ -557,7 +553,12 @@ impl TestControlPlaneClient for TestConnectMechanism { fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo { let node = NodeInfo { - config: compute::ConnCfg::new("test".to_owned(), 5432), + conn_info: compute::ConnectInfo { + host: "test".into(), + port: 5432, + ssl_mode: SslMode::Disable, + host_addr: None, + }, aux: MetricsAuxInfo { endpoint_id: (&EndpointId::from("endpoint")).into(), project_id: (&ProjectId::from("project")).into(), @@ -572,16 +573,13 @@ fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeIn fn helper_create_connect_info( mechanism: &TestConnectMechanism, -) -> auth::Backend<'static, ComputeCredentials> { +) -> auth::Backend<'static, ComputeUserInfo> { auth::Backend::ControlPlane( MaybeOwned::Owned(ControlPlaneClient::Test(Box::new(mechanism.clone()))), - ComputeCredentials { - info: ComputeUserInfo { - endpoint: "endpoint".into(), - user: "user".into(), - options: NeonOptions::parse_options_raw(""), - }, - keys: ComputeCredentialKeys::Password("password".into()), + ComputeUserInfo { + endpoint: "endpoint".into(), + user: "user".into(), + options: NeonOptions::parse_options_raw(""), }, ) } diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs index 06c2da58db..b8edf9fd5c 100644 --- a/proxy/src/proxy/wake_compute.rs +++ b/proxy/src/proxy/wake_compute.rs @@ -1,3 +1,4 @@ +use async_trait::async_trait; use tracing::{error, 
info}; use crate::config::RetryConfig; @@ -8,7 +9,6 @@ use crate::error::ReportableError; use crate::metrics::{ ConnectOutcome, ConnectionFailuresBreakdownGroup, Metrics, RetriesMetricGroup, RetryType, }; -use crate::pglb::connect_compute::ComputeConnectBackend; use crate::proxy::retry::{retry_after, should_retry}; // Use macro to retain original callsite. @@ -23,7 +23,12 @@ macro_rules! log_wake_compute_error { }; } -pub(crate) async fn wake_compute( +#[async_trait] +pub(crate) trait WakeComputeBackend { + async fn wake_compute(&self, ctx: &RequestContext) -> Result; +} + +pub(crate) async fn wake_compute( num_retries: &mut u32, ctx: &RequestContext, api: &B, diff --git a/proxy/src/rate_limiter/leaky_bucket.rs b/proxy/src/rate_limiter/leaky_bucket.rs index 0c79b5e92f..f7e54ebfe7 100644 --- a/proxy/src/rate_limiter/leaky_bucket.rs +++ b/proxy/src/rate_limiter/leaky_bucket.rs @@ -69,9 +69,8 @@ pub struct LeakyBucketConfig { pub max: f64, } -#[cfg(test)] impl LeakyBucketConfig { - pub(crate) fn new(rps: f64, max: f64) -> Self { + pub fn new(rps: f64, max: f64) -> Self { assert!(rps > 0.0, "rps must be positive"); assert!(max > 0.0, "max must be positive"); Self { rps, max } diff --git a/proxy/src/rate_limiter/limiter.rs b/proxy/src/rate_limiter/limiter.rs index 9d700c1b52..2e40f5bf60 100644 --- a/proxy/src/rate_limiter/limiter.rs +++ b/proxy/src/rate_limiter/limiter.rs @@ -12,11 +12,10 @@ use rand::{Rng, SeedableRng}; use tokio::time::{Duration, Instant}; use tracing::info; +use super::LeakyBucketConfig; use crate::ext::LockExt; use crate::intern::EndpointIdInt; -use super::LeakyBucketConfig; - pub struct GlobalRateLimiter { data: Vec, info: Vec, @@ -140,12 +139,6 @@ impl RateBucketInfo { Self::new(200, Duration::from_secs(600)), ]; - // For all the sessions will be cancel key. So this limit is essentially global proxy limit. 
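LeakyBucketConfig::new(rps, max) loses its #[cfg(test)] above because the per-endpoint connection-attempt limits (see connection_attempt_rate_limit earlier in this patch) now build bucket configs from per-protocol settings at runtime. The semantics, as a self-contained toy sketch and not the crate's implementation: the bucket fills by the request's cost and drains at rps tokens per second, rejecting once the level would exceed the burst capacity max.

use std::time::Instant;

struct Bucket {
    level: f64,
    last: Instant,
}

impl Bucket {
    fn new() -> Self {
        Bucket { level: 0.0, last: Instant::now() }
    }

    // Admit a request of weight `cost` against a bucket that drains at
    // `rps` tokens/sec and holds at most `max` tokens.
    fn check(&mut self, rps: f64, max: f64, cost: f64) -> bool {
        let now = Instant::now();
        let drained = now.duration_since(self.last).as_secs_f64() * rps;
        self.level = (self.level - drained).max(0.0);
        self.last = now;
        if self.level + cost > max {
            return false; // over the burst capacity: reject
        }
        self.level += cost;
        true
    }
}

With rps = 1.0 and max = 10.0, a cold bucket admits a burst of 10 requests and then roughly one per second. This is also why connection_attempt_rate_limit treats a non-positive rps or burst as "no limit": LeakyBucketConfig::new asserts both values are positive.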
- pub const DEFAULT_REDIS_SET: [Self; 2] = [ - Self::new(100_000, Duration::from_secs(1)), - Self::new(50_000, Duration::from_secs(10)), - ]; - pub fn rps(&self) -> f64 { (self.max_rpi as f64) / self.interval.as_secs_f64() } diff --git a/proxy/src/redis/keys.rs b/proxy/src/redis/keys.rs index 3113bad949..ffb7bc876b 100644 --- a/proxy/src/redis/keys.rs +++ b/proxy/src/redis/keys.rs @@ -1,8 +1,4 @@ -use std::io::ErrorKind; - -use anyhow::Ok; - -use crate::pqproto::{CancelKeyData, id_to_cancel_key}; +use crate::pqproto::CancelKeyData; pub mod keyspace { pub const CANCEL_PREFIX: &str = "cancel"; @@ -23,40 +19,12 @@ impl KeyPrefix { } } } - - #[allow(dead_code)] - pub(crate) fn as_str(&self) -> &'static str { - match self { - KeyPrefix::Cancel(_) => keyspace::CANCEL_PREFIX, - } - } -} - -#[allow(dead_code)] -pub(crate) fn parse_redis_key(key: &str) -> anyhow::Result { - let (prefix, key_str) = key.split_once(':').ok_or_else(|| { - anyhow::anyhow!(std::io::Error::new( - ErrorKind::InvalidData, - "missing prefix" - )) - })?; - - match prefix { - keyspace::CANCEL_PREFIX => { - let id = u64::from_str_radix(key_str, 16)?; - - Ok(KeyPrefix::Cancel(id_to_cancel_key(id))) - } - _ => Err(anyhow::anyhow!(std::io::Error::new( - ErrorKind::InvalidData, - "unknown prefix" - ))), - } } #[cfg(test)] mod tests { use super::*; + use crate::pqproto::id_to_cancel_key; #[test] fn test_build_redis_key() { @@ -65,16 +33,4 @@ mod tests { let redis_key = cancel_key.build_redis_key(); assert_eq!(redis_key, "cancel:30390000d431"); } - - #[test] - fn test_parse_redis_key() { - let redis_key = "cancel:30390000d431"; - let key: KeyPrefix = parse_redis_key(redis_key).expect("Failed to parse key"); - - let ref_key = id_to_cancel_key(12345 << 32 | 54321); - - assert_eq!(key.as_str(), KeyPrefix::Cancel(ref_key).as_str()); - let KeyPrefix::Cancel(cancel_key) = key; - assert_eq!(ref_key, cancel_key); - } } diff --git a/proxy/src/redis/kv_ops.rs b/proxy/src/redis/kv_ops.rs index f71730c533..671fe09b0b 100644 --- a/proxy/src/redis/kv_ops.rs +++ b/proxy/src/redis/kv_ops.rs @@ -1,12 +1,13 @@ +use std::time::Duration; + +use futures::FutureExt; use redis::aio::ConnectionLike; use redis::{Cmd, FromRedisValue, Pipeline, RedisResult}; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; -use crate::rate_limiter::{GlobalRateLimiter, RateBucketInfo}; pub struct RedisKVClient { client: ConnectionWithCredentialsProvider, - limiter: GlobalRateLimiter, } #[allow(async_fn_in_trait)] @@ -27,42 +28,41 @@ impl Queryable for Cmd { } impl RedisKVClient { - pub fn new(client: ConnectionWithCredentialsProvider, info: &'static [RateBucketInfo]) -> Self { - Self { - client, - limiter: GlobalRateLimiter::new(info.into()), - } + pub fn new(client: ConnectionWithCredentialsProvider) -> Self { + Self { client } } pub async fn try_connect(&mut self) -> anyhow::Result<()> { - match self.client.connect().await { - Ok(()) => {} - Err(e) => { - tracing::error!("failed to connect to redis: {e}"); - return Err(e); - } - } - Ok(()) + self.client + .connect() + .boxed() + .await + .inspect_err(|e| tracing::error!("failed to connect to redis: {e}")) } pub(crate) async fn query( &mut self, q: &impl Queryable, ) -> anyhow::Result { - if !self.limiter.check() { - tracing::info!("Rate limit exceeded. 
Skipping query"); - return Err(anyhow::anyhow!("Rate limit exceeded")); - } - - match q.query(&mut self.client).await { + let e = match q.query(&mut self.client).await { Ok(t) => return Ok(t), - Err(e) => { - tracing::error!("failed to run query: {e}"); + Err(e) => e, + }; + + tracing::error!("failed to run query: {e}"); + match e.retry_method() { + redis::RetryMethod::Reconnect => { + tracing::info!("Redis client is disconnected. Reconnecting..."); + self.try_connect().await?; } + redis::RetryMethod::RetryImmediately => {} + redis::RetryMethod::WaitAndRetry => { + // somewhat arbitrary. + tokio::time::sleep(Duration::from_millis(100)).await; + } + _ => Err(e)?, } - tracing::info!("Redis client is disconnected. Reconnecting..."); - self.try_connect().await?; Ok(q.query(&mut self.client).await?) } } diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index a9d6b40603..973a4c5b02 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -3,12 +3,12 @@ use std::sync::Arc; use futures::StreamExt; use redis::aio::PubSub; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use tokio_util::sync::CancellationToken; use super::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::cache::project_info::ProjectInfoCache; -use crate::intern::{AccountIdInt, ProjectIdInt, RoleNameInt}; +use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; use crate::metrics::{Metrics, RedisErrors, RedisEventsCount}; const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates"; @@ -27,42 +27,37 @@ struct NotificationHeader<'a> { topic: &'a str, } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] #[serde(tag = "topic", content = "data")] -pub(crate) enum Notification { +enum Notification { #[serde( - rename = "/allowed_ips_updated", + rename = "/account_settings_update", + alias = "/allowed_vpc_endpoints_updated_for_org", deserialize_with = "deserialize_json_string" )] - AllowedIpsUpdate { - allowed_ips_update: AllowedIpsUpdate, - }, + AccountSettingsUpdate(InvalidateAccount), + #[serde( - rename = "/block_public_or_vpc_access_updated", + rename = "/endpoint_settings_update", deserialize_with = "deserialize_json_string" )] - BlockPublicOrVpcAccessUpdated { - block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated, - }, + EndpointSettingsUpdate(InvalidateEndpoint), + #[serde( - rename = "/allowed_vpc_endpoints_updated_for_org", + rename = "/project_settings_update", + alias = "/allowed_ips_updated", + alias = "/block_public_or_vpc_access_updated", + alias = "/allowed_vpc_endpoints_updated_for_projects", deserialize_with = "deserialize_json_string" )] - AllowedVpcEndpointsUpdatedForOrg { - allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg, - }, + ProjectSettingsUpdate(InvalidateProject), + #[serde( - rename = "/allowed_vpc_endpoints_updated_for_projects", + rename = "/role_setting_update", + alias = "/password_updated", deserialize_with = "deserialize_json_string" )] - AllowedVpcEndpointsUpdatedForProjects { - allowed_vpc_endpoints_updated_for_projects: AllowedVpcEndpointsUpdatedForProjects, - }, - #[serde( - rename = "/password_updated", - deserialize_with = "deserialize_json_string" - )] - PasswordUpdate { password_update: PasswordUpdate }, + RoleSettingUpdate(InvalidateRole), #[serde( other, @@ -72,28 +67,56 @@ pub(crate) enum Notification { UnknownTopic, } -#[derive(Clone, Debug, Serialize, Deserialize, 
Eq, PartialEq)] -pub(crate) struct AllowedIpsUpdate { - project_id: ProjectIdInt, +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +#[serde(rename_all = "snake_case")] +enum InvalidateEndpoint { + EndpointId(EndpointIdInt), + EndpointIds(Vec), +} +impl std::ops::Deref for InvalidateEndpoint { + type Target = [EndpointIdInt]; + fn deref(&self) -> &Self::Target { + match self { + Self::EndpointId(id) => std::slice::from_ref(id), + Self::EndpointIds(ids) => ids, + } + } } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub(crate) struct BlockPublicOrVpcAccessUpdated { - project_id: ProjectIdInt, +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +#[serde(rename_all = "snake_case")] +enum InvalidateProject { + ProjectId(ProjectIdInt), + ProjectIds(Vec), +} +impl std::ops::Deref for InvalidateProject { + type Target = [ProjectIdInt]; + fn deref(&self) -> &Self::Target { + match self { + Self::ProjectId(id) => std::slice::from_ref(id), + Self::ProjectIds(ids) => ids, + } + } } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub(crate) struct AllowedVpcEndpointsUpdatedForOrg { - account_id: AccountIdInt, +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +#[serde(rename_all = "snake_case")] +enum InvalidateAccount { + AccountId(AccountIdInt), + AccountIds(Vec), +} +impl std::ops::Deref for InvalidateAccount { + type Target = [AccountIdInt]; + fn deref(&self) -> &Self::Target { + match self { + Self::AccountId(id) => std::slice::from_ref(id), + Self::AccountIds(ids) => ids, + } + } } -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub(crate) struct AllowedVpcEndpointsUpdatedForProjects { - project_ids: Vec, -} - -#[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] -pub(crate) struct PasswordUpdate { +#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +struct InvalidateRole { project_id: ProjectIdInt, role_name: RoleNameInt, } @@ -118,29 +141,19 @@ where struct MessageHandler { cache: Arc, - region_id: String, } impl Clone for MessageHandler { fn clone(&self) -> Self { Self { cache: self.cache.clone(), - region_id: self.region_id.clone(), } } } impl MessageHandler { - pub(crate) fn new(cache: Arc, region_id: String) -> Self { - Self { cache, region_id } - } - - pub(crate) async fn increment_active_listeners(&self) { - self.cache.increment_active_listeners().await; - } - - pub(crate) async fn decrement_active_listeners(&self) { - self.cache.decrement_active_listeners().await; + pub(crate) fn new(cache: Arc) -> Self { + Self { cache } } #[tracing::instrument(skip(self, msg), fields(session_id = tracing::field::Empty))] @@ -177,41 +190,29 @@ impl MessageHandler { tracing::debug!(?msg, "received a message"); match msg { - Notification::AllowedIpsUpdate { .. } - | Notification::PasswordUpdate { .. } - | Notification::BlockPublicOrVpcAccessUpdated { .. } - | Notification::AllowedVpcEndpointsUpdatedForOrg { .. } - | Notification::AllowedVpcEndpointsUpdatedForProjects { .. } => { + Notification::RoleSettingUpdate { .. } + | Notification::EndpointSettingsUpdate { .. } + | Notification::ProjectSettingsUpdate { .. } + | Notification::AccountSettingsUpdate { .. } => { invalidate_cache(self.cache.clone(), msg.clone()); - if matches!(msg, Notification::AllowedIpsUpdate { .. }) { - Metrics::get() - .proxy - .redis_events_count - .inc(RedisEventsCount::AllowedIpsUpdate); - } else if matches!(msg, Notification::PasswordUpdate { .. 
}) { - Metrics::get() - .proxy - .redis_events_count - .inc(RedisEventsCount::PasswordUpdate); - } else if matches!( - msg, - Notification::AllowedVpcEndpointsUpdatedForProjects { .. } - ) { - Metrics::get() - .proxy - .redis_events_count - .inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForProjects); - } else if matches!(msg, Notification::AllowedVpcEndpointsUpdatedForOrg { .. }) { - Metrics::get() - .proxy - .redis_events_count - .inc(RedisEventsCount::AllowedVpcEndpointIdsUpdateForAllProjectsInOrg); - } else if matches!(msg, Notification::BlockPublicOrVpcAccessUpdated { .. }) { - Metrics::get() - .proxy - .redis_events_count - .inc(RedisEventsCount::BlockPublicOrVpcAccessUpdate); + + let m = &Metrics::get().proxy.redis_events_count; + match msg { + Notification::RoleSettingUpdate { .. } => { + m.inc(RedisEventsCount::InvalidateRole); + } + Notification::EndpointSettingsUpdate { .. } => { + m.inc(RedisEventsCount::InvalidateEndpoint); + } + Notification::ProjectSettingsUpdate { .. } => { + m.inc(RedisEventsCount::InvalidateProject); + } + Notification::AccountSettingsUpdate { .. } => { + m.inc(RedisEventsCount::InvalidateOrg); + } + Notification::UnknownTopic => {} } + // TODO: add additional metrics for the other event types. // It might happen that the invalid entry is on the way to be cached. @@ -233,30 +234,23 @@ impl MessageHandler { fn invalidate_cache(cache: Arc, msg: Notification) { match msg { - Notification::AllowedIpsUpdate { - allowed_ips_update: AllowedIpsUpdate { project_id }, - } - | Notification::BlockPublicOrVpcAccessUpdated { - block_public_or_vpc_access_updated: BlockPublicOrVpcAccessUpdated { project_id }, - } => cache.invalidate_endpoint_access_for_project(project_id), - Notification::AllowedVpcEndpointsUpdatedForOrg { - allowed_vpc_endpoints_updated_for_org: AllowedVpcEndpointsUpdatedForOrg { account_id }, - } => cache.invalidate_endpoint_access_for_org(account_id), - Notification::AllowedVpcEndpointsUpdatedForProjects { - allowed_vpc_endpoints_updated_for_projects: - AllowedVpcEndpointsUpdatedForProjects { project_ids }, - } => { - for project in project_ids { - cache.invalidate_endpoint_access_for_project(project); - } - } - Notification::PasswordUpdate { - password_update: - PasswordUpdate { - project_id, - role_name, - }, - } => cache.invalidate_role_secret_for_project(project_id, role_name), + Notification::EndpointSettingsUpdate(ids) => ids + .iter() + .for_each(|&id| cache.invalidate_endpoint_access(id)), + + Notification::AccountSettingsUpdate(ids) => ids + .iter() + .for_each(|&id| cache.invalidate_endpoint_access_for_org(id)), + + Notification::ProjectSettingsUpdate(ids) => ids + .iter() + .for_each(|&id| cache.invalidate_endpoint_access_for_project(id)), + + Notification::RoleSettingUpdate(InvalidateRole { + project_id, + role_name, + }) => cache.invalidate_role_secret_for_project(project_id, role_name), + Notification::UnknownTopic => unreachable!(), } } @@ -272,7 +266,7 @@ async fn handle_messages( } let mut conn = match try_connect(&redis).await { Ok(conn) => { - handler.increment_active_listeners().await; + handler.cache.increment_active_listeners().await; conn } Err(e) => { @@ -293,11 +287,11 @@ async fn handle_messages( } } if cancellation_token.is_cancelled() { - handler.decrement_active_listeners().await; + handler.cache.decrement_active_listeners().await; return Ok(()); } } - handler.decrement_active_listeners().await; + handler.cache.decrement_active_listeners().await; } } @@ -306,12 +300,11 @@ async fn handle_messages( pub async fn task_main( 
redis: ConnectionWithCredentialsProvider, cache: Arc, - region_id: String, ) -> anyhow::Result where C: ProjectInfoCache + Send + Sync + 'static, { - let handler = MessageHandler::new(cache, region_id); + let handler = MessageHandler::new(cache); // 6h - 1m. // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost. let mut interval = tokio::time::interval(std::time::Duration::from_secs(6 * 60 * 60 - 60)); @@ -353,11 +346,32 @@ mod tests { let result: Notification = serde_json::from_str(&text)?; assert_eq!( result, - Notification::AllowedIpsUpdate { - allowed_ips_update: AllowedIpsUpdate { - project_id: (&project_id).into() - } - } + Notification::ProjectSettingsUpdate(InvalidateProject::ProjectId((&project_id).into())) + ); + + Ok(()) + } + + #[test] + fn parse_multiple_projects() -> anyhow::Result<()> { + let project_id1: ProjectId = "new_project1".into(); + let project_id2: ProjectId = "new_project2".into(); + let data = format!("{{\"project_ids\": [\"{project_id1}\",\"{project_id2}\"]}}"); + let text = json!({ + "type": "message", + "topic": "/allowed_vpc_endpoints_updated_for_projects", + "data": data, + "extre_fields": "something" + }) + .to_string(); + + let result: Notification = serde_json::from_str(&text)?; + assert_eq!( + result, + Notification::ProjectSettingsUpdate(InvalidateProject::ProjectIds(vec![ + (&project_id1).into(), + (&project_id2).into() + ])) ); Ok(()) @@ -379,12 +393,10 @@ mod tests { let result: Notification = serde_json::from_str(&text)?; assert_eq!( result, - Notification::PasswordUpdate { - password_update: PasswordUpdate { - project_id: (&project_id).into(), - role_name: (&role_name).into(), - } - } + Notification::RoleSettingUpdate(InvalidateRole { + project_id: (&project_id).into(), + role_name: (&role_name).into(), + }) ); Ok(()) diff --git a/proxy/src/sasl/channel_binding.rs b/proxy/src/sasl/channel_binding.rs index fdd011448e..e548cf3a83 100644 --- a/proxy/src/sasl/channel_binding.rs +++ b/proxy/src/sasl/channel_binding.rs @@ -1,5 +1,8 @@ //! Definition and parser for channel binding flag (a part of the `GS2` header). +use base64::Engine as _; +use base64::prelude::BASE64_STANDARD; + /// Channel binding flag (possibly with params). 
#[derive(Debug, PartialEq, Eq)] pub(crate) enum ChannelBinding { @@ -55,7 +58,7 @@ impl ChannelBinding { let mut cbind_input = vec![]; write!(&mut cbind_input, "p={mode},,",).unwrap(); cbind_input.extend_from_slice(get_cbind_data(mode)?); - base64::encode(&cbind_input).into() + BASE64_STANDARD.encode(&cbind_input).into() } }) } @@ -70,9 +73,9 @@ mod tests { use ChannelBinding::*; let cases = [ - (NotSupportedClient, base64::encode("n,,")), - (NotSupportedServer, base64::encode("y,,")), - (Required("foo"), base64::encode("p=foo,,bar")), + (NotSupportedClient, BASE64_STANDARD.encode("n,,")), + (NotSupportedServer, BASE64_STANDARD.encode("y,,")), + (Required("foo"), BASE64_STANDARD.encode("p=foo,,bar")), ]; for (cb, input) in cases { diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs index abd5aeae5b..3ba8a79368 100644 --- a/proxy/src/scram/exchange.rs +++ b/proxy/src/scram/exchange.rs @@ -2,6 +2,8 @@ use std::convert::Infallible; +use base64::Engine as _; +use base64::prelude::BASE64_STANDARD; use hmac::{Hmac, Mac}; use sha2::Sha256; @@ -105,7 +107,7 @@ pub(crate) async fn exchange( secret: &ServerSecret, password: &[u8], ) -> sasl::Result> { - let salt = base64::decode(&secret.salt_base64)?; + let salt = BASE64_STANDARD.decode(&secret.salt_base64)?; let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await; if secret.is_password_invalid(&client_key).into() { diff --git a/proxy/src/scram/messages.rs b/proxy/src/scram/messages.rs index e071417dab..42039f099c 100644 --- a/proxy/src/scram/messages.rs +++ b/proxy/src/scram/messages.rs @@ -3,6 +3,9 @@ use std::fmt; use std::ops::Range; +use base64::Engine as _; +use base64::prelude::BASE64_STANDARD; + use super::base64_decode_array; use super::key::{SCRAM_KEY_LEN, ScramKey}; use super::signature::SignatureBuilder; @@ -88,7 +91,7 @@ impl<'a> ClientFirstMessage<'a> { let mut message = String::new(); write!(&mut message, "r={}", self.nonce).unwrap(); - base64::encode_config_buf(nonce, base64::STANDARD, &mut message); + BASE64_STANDARD.encode_string(nonce, &mut message); let combined_nonce = 2..message.len(); write!(&mut message, ",s={salt_base64},i={iterations}").unwrap(); @@ -142,11 +145,7 @@ impl<'a> ClientFinalMessage<'a> { server_key: &ScramKey, ) -> String { let mut buf = String::from("v="); - base64::encode_config_buf( - signature_builder.build(server_key), - base64::STANDARD, - &mut buf, - ); + BASE64_STANDARD.encode_string(signature_builder.build(server_key), &mut buf); buf } @@ -251,7 +250,7 @@ mod tests { "iiYEfS3rOgn8S3rtpSdrOsHtPLWvIkdgmHxA0hf3JNOAG4dU" ); assert_eq!( - base64::encode(msg.proof), + BASE64_STANDARD.encode(msg.proof), "SRpfsIVS4Gk11w1LqQ4QvCUBZYQmqXNSDEcHqbQ3CHI=" ); } diff --git a/proxy/src/scram/mod.rs b/proxy/src/scram/mod.rs index 4f764c6087..5f627e062c 100644 --- a/proxy/src/scram/mod.rs +++ b/proxy/src/scram/mod.rs @@ -15,6 +15,8 @@ mod secret; mod signature; pub mod threadpool; +use base64::Engine as _; +use base64::prelude::BASE64_STANDARD; pub(crate) use exchange::{Exchange, exchange}; use hmac::{Hmac, Mac}; pub(crate) use key::ScramKey; @@ -32,7 +34,7 @@ pub(crate) const METHODS_WITHOUT_PLUS: &[&str] = &[SCRAM_SHA_256]; fn base64_decode_array(input: impl AsRef<[u8]>) -> Option<[u8; N]> { let mut bytes = [0u8; N]; - let size = base64::decode_config_slice(input, base64::STANDARD, &mut bytes).ok()?; + let size = BASE64_STANDARD.decode_slice(input, &mut bytes).ok()?; if size != N { return None; } diff --git a/proxy/src/scram/secret.rs b/proxy/src/scram/secret.rs 
index 8c6a08d432..f03617f34d 100644 --- a/proxy/src/scram/secret.rs +++ b/proxy/src/scram/secret.rs @@ -1,5 +1,7 @@ //! Tools for SCRAM server secret management. +use base64::Engine as _; +use base64::prelude::BASE64_STANDARD; use subtle::{Choice, ConstantTimeEq}; use super::base64_decode_array; @@ -56,7 +58,7 @@ impl ServerSecret { // iteration count 1 for our generated passwords going forward. // PG16 users can set iteration count=1 already today. iterations: 1, - salt_base64: base64::encode(nonce), + salt_base64: BASE64_STANDARD.encode(nonce), stored_key: ScramKey::default(), server_key: ScramKey::default(), doomed: true, @@ -88,7 +90,7 @@ mod tests { assert_eq!(parsed.iterations, iterations); assert_eq!(parsed.salt_base64, salt); - assert_eq!(base64::encode(parsed.stored_key), stored_key); - assert_eq!(base64::encode(parsed.server_key), server_key); + assert_eq!(BASE64_STANDARD.encode(parsed.stored_key), stored_key); + assert_eq!(BASE64_STANDARD.encode(parsed.server_key), server_key); } } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 748e0ce6f2..26269d0a6e 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -21,9 +21,8 @@ use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool}; use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client}; use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool}; use crate::auth::backend::local::StaticAuthRules; -use crate::auth::backend::{ComputeCredentials, ComputeUserInfo}; +use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo}; use crate::auth::{self, AuthError}; -use crate::compute; use crate::compute_ctl::{ ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest, }; @@ -35,7 +34,7 @@ use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; use crate::control_plane::locks::ApiLocks; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::intern::EndpointIdInt; -use crate::pglb::connect_compute::ConnectMechanism; +use crate::proxy::connect_compute::ConnectMechanism; use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute}; use crate::rate_limiter::EndpointRateLimiter; use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX}; @@ -69,17 +68,20 @@ impl PoolingBackend { self.config.authentication_config.is_vpc_acccess_proxy, )?; - let ep = EndpointIdInt::from(&user_info.endpoint); - let rate_limit_config = None; - if !self.endpoint_rate_limiter.check(ep, rate_limit_config, 1) { - return Err(AuthError::too_many_connections()); - } + access_control.connection_attempt_rate_limit( + ctx, + &user_info.endpoint, + &self.endpoint_rate_limiter, + )?; + let role_access = backend.get_role_secret(ctx).await?; let Some(secret) = role_access.secret else { // If we don't have an authentication secret, for the http flow we can just return an error. 
info!("authentication info not found"); return Err(AuthError::password_failed(&*user_info.user)); }; + + let ep = EndpointIdInt::from(&user_info.endpoint); let auth_outcome = crate::auth::validate_password_and_exchange( &self.config.authentication_config.thread_pool, ep, @@ -181,14 +183,15 @@ impl PoolingBackend { let conn_id = uuid::Uuid::new_v4(); tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); - let backend = self.auth_backend.as_ref().map(|()| keys); - crate::pglb::connect_compute::connect_to_compute( + let backend = self.auth_backend.as_ref().map(|()| keys.info); + crate::proxy::connect_compute::connect_to_compute( ctx, &TokioMechanism { conn_id, conn_info, pool: self.pool.clone(), locks: &self.config.connect_compute_locks, + keys: keys.keys, }, &backend, self.config.wake_compute_retry_config, @@ -215,18 +218,15 @@ impl PoolingBackend { let conn_id = uuid::Uuid::new_v4(); tracing::Span::current().record("conn_id", display(conn_id)); debug!(%conn_id, "pool: opening a new connection '{conn_info}'"); - let backend = self.auth_backend.as_ref().map(|()| ComputeCredentials { - info: ComputeUserInfo { - user: conn_info.user_info.user.clone(), - endpoint: EndpointId::from(format!( - "{}{LOCAL_PROXY_SUFFIX}", - conn_info.user_info.endpoint.normalize() - )), - options: conn_info.user_info.options.clone(), - }, - keys: crate::auth::backend::ComputeCredentialKeys::None, + let backend = self.auth_backend.as_ref().map(|()| ComputeUserInfo { + user: conn_info.user_info.user.clone(), + endpoint: EndpointId::from(format!( + "{}{LOCAL_PROXY_SUFFIX}", + conn_info.user_info.endpoint.normalize() + )), + options: conn_info.user_info.options.clone(), }); - crate::pglb::connect_compute::connect_to_compute( + crate::proxy::connect_compute::connect_to_compute( ctx, &HyperMechanism { conn_id, @@ -305,12 +305,13 @@ impl PoolingBackend { tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "local_pool: opening a new connection '{conn_info}'"); - let mut node_info = local_backend.node_info.clone(); - let (key, jwk) = create_random_jwk(); - let config = node_info - .config + let mut config = local_backend + .node_info + .conn_info + .to_postgres_client_config(); + config .user(&conn_info.user_info.user) .dbname(&conn_info.dbname) .set_param( @@ -322,7 +323,7 @@ impl PoolingBackend { ); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let (client, connection) = config.connect(postgres_client::NoTls).await?; + let (client, connection) = config.connect(&postgres_client::NoTls).await?; drop(pause); let pid = client.get_process_id(); @@ -336,7 +337,7 @@ impl PoolingBackend { connection, key, conn_id, - node_info.aux.clone(), + local_backend.node_info.aux.clone(), ); { @@ -495,6 +496,7 @@ struct TokioMechanism { pool: Arc>>, conn_info: ConnInfo, conn_id: uuid::Uuid, + keys: ComputeCredentialKeys, /// connect_to_compute concurrency lock locks: &'static ApiLocks, @@ -512,19 +514,20 @@ impl ConnectMechanism for TokioMechanism { node_info: &CachedNodeInfo, compute_config: &ComputeConfig, ) -> Result { - let host = node_info.config.get_host(); - let permit = self.locks.get_permit(&host).await?; + let permit = self.locks.get_permit(&node_info.conn_info.host).await?; - let mut config = (*node_info.config).clone(); + let mut config = node_info.conn_info.to_postgres_client_config(); let config = config .user(&self.conn_info.user_info.user) .dbname(&self.conn_info.dbname) 
.connect_timeout(compute_config.timeout); - let mk_tls = - crate::tls::postgres_rustls::MakeRustlsConnect::new(compute_config.tls.clone()); + if let ComputeCredentialKeys::AuthKeys(auth_keys) = self.keys { + config.auth_keys(auth_keys); + } + let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let res = config.connect(mk_tls).await; + let res = config.connect(compute_config).await; drop(pause); let (client, connection) = permit.release_result(res)?; @@ -548,8 +551,6 @@ impl ConnectMechanism for TokioMechanism { node_info.aux.clone(), )) } - - fn update_connect_config(&self, _config: &mut compute::ConnCfg) {} } struct HyperMechanism { @@ -573,20 +574,20 @@ impl ConnectMechanism for HyperMechanism { node_info: &CachedNodeInfo, config: &ComputeConfig, ) -> Result { - let host_addr = node_info.config.get_host_addr(); - let host = node_info.config.get_host(); - let permit = self.locks.get_permit(&host).await?; + let host_addr = node_info.conn_info.host_addr; + let host = &node_info.conn_info.host; + let permit = self.locks.get_permit(host).await?; let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let tls = if node_info.config.get_ssl_mode() == SslMode::Disable { + let tls = if node_info.conn_info.ssl_mode == SslMode::Disable { None } else { Some(&config.tls) }; - let port = node_info.config.get_port(); - let res = connect_http2(host_addr, &host, port, config.timeout, tls).await; + let port = node_info.conn_info.port; + let res = connect_http2(host_addr, host, port, config.timeout, tls).await; drop(pause); let (client, connection) = permit.release_result(res)?; @@ -609,8 +610,6 @@ impl ConnectMechanism for HyperMechanism { node_info.aux.clone(), )) } - - fn update_connect_config(&self, _config: &mut compute::ConnCfg) {} } async fn connect_http2( diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 87176ff7d6..dd8cf052c5 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -23,12 +23,12 @@ use super::conn_pool_lib::{ Client, ClientDataEnum, ClientInnerCommon, ClientInnerExt, ConnInfo, EndpointConnPool, GlobalConnPool, }; +use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::messages::MetricsAuxInfo; use crate::metrics::Metrics; -use crate::tls::postgres_rustls::MakeRustlsConnect; -type TlsStream = >::Stream; +type TlsStream = >::Stream; #[derive(Debug, Clone)] pub(crate) struct ConnInfoWithAuth { diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs index bb5637cd5f..c367615fb8 100644 --- a/proxy/src/serverless/local_conn_pool.rs +++ b/proxy/src/serverless/local_conn_pool.rs @@ -16,6 +16,8 @@ use std::sync::atomic::AtomicUsize; use std::task::{Poll, ready}; use std::time::Duration; +use base64::Engine as _; +use base64::prelude::BASE64_URL_SAFE_NO_PAD; use ed25519_dalek::{Signature, Signer, SigningKey}; use futures::Future; use futures::future::poll_fn; @@ -346,7 +348,7 @@ fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String { jwt.push_str("eyJhbGciOiJFZERTQSJ9."); // encode the jwt payload in-place - base64::encode_config_buf(payload, base64::URL_SAFE_NO_PAD, &mut jwt); + BASE64_URL_SAFE_NO_PAD.encode_string(payload, &mut jwt); // create the signature from the encoded header || payload let sig: Signature = sk.sign(jwt.as_bytes()); @@ -354,7 +356,7 @@ fn sign_jwt(sk: &SigningKey, payload: &[u8]) -> String { jwt.push('.'); // encode the jwt signature in-place - 
base64::encode_config_buf(sig.to_bytes(), base64::URL_SAFE_NO_PAD, &mut jwt); + BASE64_URL_SAFE_NO_PAD.encode_string(sig.to_bytes(), &mut jwt); debug_assert_eq!( jwt.len(), diff --git a/proxy/src/serverless/mod.rs b/proxy/src/serverless/mod.rs index f6f681ac45..d8942bb814 100644 --- a/proxy/src/serverless/mod.rs +++ b/proxy/src/serverless/mod.rs @@ -50,10 +50,10 @@ use crate::context::RequestContext; use crate::ext::TaskExt; use crate::metrics::Metrics; use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol}; -use crate::proxy::run_until_cancelled; use crate::rate_limiter::EndpointRateLimiter; use crate::serverless::backend::PoolingBackend; use crate::serverless::http_util::{api_error_into_response, json_response}; +use crate::util::run_until_cancelled; pub(crate) const SERVERLESS_DRIVER_SNI: &str = "api"; pub(crate) const AUTH_BROKER_SNI: &str = "apiauth"; @@ -417,12 +417,7 @@ async fn request_handler( if config.http_config.accept_websockets && framed_websockets::upgrade::is_upgrade_request(&request) { - let ctx = RequestContext::new( - session_id, - conn_info, - crate::metrics::Protocol::Ws, - &config.region, - ); + let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Ws); ctx.set_user_agent( request @@ -462,12 +457,7 @@ async fn request_handler( // Return the response so the spawned future can continue. Ok(response.map(|b| b.map_err(|x| match x {}).boxed())) } else if request.uri().path() == "/sql" && *request.method() == Method::POST { - let ctx = RequestContext::new( - session_id, - conn_info, - crate::metrics::Protocol::Http, - &config.region, - ); + let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http); let span = ctx.span(); let testodrome_id = request diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index eb80ac9ad0..b2eb801f5c 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -41,10 +41,11 @@ use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::http::{ReadBodyError, read_body_with_limit}; use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind}; use crate::pqproto::StartupMessageParams; -use crate::proxy::{NeonOptions, run_until_cancelled}; +use crate::proxy::NeonOptions; use crate::serverless::backend::HttpConnError; use crate::types::{DbName, RoleName}; use crate::usage_metrics::{MetricCounter, MetricCounterRecorder}; +use crate::util::run_until_cancelled; #[derive(serde::Deserialize)] #[serde(rename_all = "camelCase")] diff --git a/proxy/src/serverless/websocket.rs b/proxy/src/serverless/websocket.rs index 8648a94869..0d374e6df2 100644 --- a/proxy/src/serverless/websocket.rs +++ b/proxy/src/serverless/websocket.rs @@ -167,7 +167,7 @@ pub(crate) async fn serve_websocket( Ok(Some(p)) => { ctx.set_success(); ctx.log_connect(); - match p.proxy_pass(&config.connect_to_compute).await { + match p.proxy_pass().await { Ok(()) => Ok(()), Err(ErrorSource::Client(err)) => Err(err).context("client"), Err(ErrorSource::Compute(err)) => Err(err).context("compute"), diff --git a/proxy/src/tls/mod.rs b/proxy/src/tls/mod.rs index 7fe71abf48..f576214255 100644 --- a/proxy/src/tls/mod.rs +++ b/proxy/src/tls/mod.rs @@ -3,6 +3,8 @@ pub mod postgres_rustls; pub mod server_config; use anyhow::Context; +use base64::Engine as _; +use base64::prelude::BASE64_STANDARD; use rustls::pki_types::CertificateDer; use sha2::{Digest, Sha256}; use tracing::{error, info}; @@ -58,7 +60,7 @@ impl TlsServerEndPoint { let oid 
= certificate.signature_algorithm.oid; if SHA256_OIDS.contains(&oid) { let tls_server_end_point: [u8; 32] = Sha256::new().chain_update(cert).finalize().into(); - info!(%subject, tls_server_end_point = %base64::encode(tls_server_end_point), "determined channel binding"); + info!(%subject, tls_server_end_point = %BASE64_STANDARD.encode(tls_server_end_point), "determined channel binding"); Ok(Self::Sha256(tls_server_end_point)) } else { error!(%subject, "unknown channel binding"); diff --git a/proxy/src/tls/postgres_rustls.rs b/proxy/src/tls/postgres_rustls.rs index 013b307f0b..9269ad8a06 100644 --- a/proxy/src/tls/postgres_rustls.rs +++ b/proxy/src/tls/postgres_rustls.rs @@ -2,10 +2,11 @@ use std::convert::TryFrom; use std::sync::Arc; use postgres_client::tls::MakeTlsConnect; -use rustls::ClientConfig; -use rustls::pki_types::ServerName; +use rustls::pki_types::{InvalidDnsNameError, ServerName}; use tokio::io::{AsyncRead, AsyncWrite}; +use crate::config::ComputeConfig; + mod private { use std::future::Future; use std::io; @@ -123,36 +124,27 @@ mod private { } } -/// A `MakeTlsConnect` implementation using `rustls`. -/// -/// That way you can connect to PostgreSQL using `rustls` as the TLS stack. -#[derive(Clone)] -pub struct MakeRustlsConnect { - pub config: Arc, -} - -impl MakeRustlsConnect { - /// Creates a new `MakeRustlsConnect` from the provided `ClientConfig`. - #[must_use] - pub fn new(config: Arc) -> Self { - Self { config } - } -} - -impl MakeTlsConnect for MakeRustlsConnect +impl MakeTlsConnect for ComputeConfig where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { type Stream = private::RustlsStream; type TlsConnect = private::RustlsConnect; - type Error = rustls::pki_types::InvalidDnsNameError; + type Error = InvalidDnsNameError; - fn make_tls_connect(&mut self, hostname: &str) -> Result { - ServerName::try_from(hostname).map(|dns_name| { - private::RustlsConnect(private::RustlsConnectData { - hostname: dns_name.to_owned(), - connector: Arc::clone(&self.config).into(), - }) - }) + fn make_tls_connect(&self, hostname: &str) -> Result { + make_tls_connect(&self.tls, hostname) } } + +pub fn make_tls_connect( + tls: &Arc, + hostname: &str, +) -> Result { + ServerName::try_from(hostname).map(|dns_name| { + private::RustlsConnect(private::RustlsConnectData { + hostname: dns_name.to_owned(), + connector: tls.clone().into(), + }) + }) +} diff --git a/proxy/src/usage_metrics.rs b/proxy/src/usage_metrics.rs index 115b958c54..c82c4865a7 100644 --- a/proxy/src/usage_metrics.rs +++ b/proxy/src/usage_metrics.rs @@ -399,7 +399,7 @@ async fn collect_metrics_iteration( fn create_remote_path_prefix(now: DateTime) -> String { format!( - "year={year:04}/month={month:02}/day={day:02}/{hour:02}:{minute:02}:{second:02}Z", + "year={year:04}/month={month:02}/day={day:02}/hour={hour:02}/{hour:02}:{minute:02}:{second:02}Z", year = now.year(), month = now.month(), day = now.day(), @@ -461,7 +461,7 @@ async fn upload_backup_events( real_now.second().into(), real_now.nanosecond(), )); - let path = format!("{path_prefix}_{id}.json.gz"); + let path = format!("{path_prefix}_{id}.ndjson.gz"); let remote_path = match RemotePath::from_string(&path) { Ok(remote_path) => remote_path, Err(e) => { @@ -471,9 +471,12 @@ async fn upload_backup_events( // TODO: This is async compression from Vec to Vec. Rewrite as byte stream. // Use sync compression in blocking threadpool. 
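Context for the serialization change just below: backup events switch from one gzipped JSON document per chunk to gzipped NDJSON, one JSON event per line (hence the .json.gz to .ndjson.gz rename earlier in this file's diff), so consumers can decode events line by line without holding a whole chunk in memory. A hedged sketch of that framing, with Event as a stand-in type (serde_json errors convert into std::io::Error, so ? works):

use std::io::Write;

#[derive(serde::Serialize)]
struct Event {
    metric: &'static str,
    value: u64,
}

// Write events as NDJSON: one JSON document, then a newline, per event.
fn write_ndjson<W: Write>(mut out: W, events: &[Event]) -> std::io::Result<()> {
    for event in events {
        serde_json::to_writer(&mut out, event)?;
        out.write_all(b"\n")?;
    }
    Ok(())
}

In the patch the same framing is written through the async GzipEncoder; the test side then reads it back with a BufReader over a GzDecoder, calling serde_json::from_str once per line.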
- let data = serde_json::to_vec(chunk).context("serialize metrics")?; let mut encoder = GzipEncoder::new(Vec::new()); - encoder.write_all(&data).await.context("compress metrics")?; + for event in chunk.events.iter() { + let data = serde_json::to_vec(event).context("serialize metrics")?; + encoder.write_all(&data).await.context("compress metrics")?; + encoder.write_all(b"\n").await.context("compress metrics")?; + } encoder.shutdown().await.context("compress metrics")?; let compressed_data: Bytes = encoder.get_ref().clone().into(); backoff::retry( @@ -499,7 +502,7 @@ async fn upload_backup_events( #[cfg(test)] mod tests { use std::fs; - use std::io::BufReader; + use std::io::{BufRead, BufReader}; use std::sync::{Arc, Mutex}; use anyhow::Error; @@ -673,11 +676,22 @@ mod tests { { let path = local_fs_path.join(&path_prefix).to_string(); if entry.path().to_str().unwrap().starts_with(&path) { - let chunk = serde_json::from_reader(flate2::bufread::GzDecoder::new( - BufReader::new(fs::File::open(entry.into_path()).unwrap()), - )) - .unwrap(); - stored_chunks.push(chunk); + let file = fs::File::open(entry.into_path()).unwrap(); + let decoder = flate2::bufread::GzDecoder::new(BufReader::new(file)); + let reader = BufReader::new(decoder); + + let mut events: Vec> = Vec::new(); + for line in reader.lines() { + let line = line.unwrap(); + let event: Event = serde_json::from_str(&line).unwrap(); + events.push(event); + } + + let report = Report { + events: Cow::Owned(events), + }; + + stored_chunks.push(report); } } storage_test_dir.close().ok(); diff --git a/proxy/src/util.rs b/proxy/src/util.rs new file mode 100644 index 0000000000..7fc2d9fbdb --- /dev/null +++ b/proxy/src/util.rs @@ -0,0 +1,14 @@ +use std::pin::pin; + +use futures::future::{Either, select}; +use tokio_util::sync::CancellationToken; + +pub async fn run_until_cancelled( + f: F, + cancellation_token: &CancellationToken, +) -> Option { + match select(pin!(f), pin!(cancellation_token.cancelled())).await { + Either::Left((f, _)) => Some(f), + Either::Right(((), _)) => None, + } +} diff --git a/pyproject.toml b/pyproject.toml index c6dfdc223c..e7e314d144 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ pytest = "^7.4.4" psycopg2-binary = "^2.9.10" typing-extensions = "^4.12.2" PyJWT = {version = "^2.1.0", extras = ["crypto"]} -requests = "^2.32.3" +requests = "^2.32.4" pytest-xdist = "^3.3.1" asyncpg = "^0.30.0" aiopg = "^1.4.0" diff --git a/safekeeper/Cargo.toml b/safekeeper/Cargo.toml index 0a8cc415be..6955028c73 100644 --- a/safekeeper/Cargo.toml +++ b/safekeeper/Cargo.toml @@ -58,6 +58,7 @@ metrics.workspace = true pem.workspace = true postgres_backend.workspace = true postgres_ffi.workspace = true +postgres_versioninfo.workspace = true pq_proto.workspace = true remote_storage.workspace = true safekeeper_api.workspace = true diff --git a/safekeeper/client/src/mgmt_api.rs b/safekeeper/client/src/mgmt_api.rs index b364ac8e48..2e46a7b529 100644 --- a/safekeeper/client/src/mgmt_api.rs +++ b/safekeeper/client/src/mgmt_api.rs @@ -8,8 +8,8 @@ use std::error::Error as _; use http_utils::error::HttpErrorBody; use reqwest::{IntoUrl, Method, StatusCode}; use safekeeper_api::models::{ - self, PullTimelineRequest, PullTimelineResponse, SafekeeperUtilization, TimelineCreateRequest, - TimelineStatus, + self, PullTimelineRequest, PullTimelineResponse, SafekeeperStatus, SafekeeperUtilization, + TimelineCreateRequest, TimelineStatus, }; use utils::id::{NodeId, TenantId, TimelineId}; use utils::logging::SecretString; @@ -183,6 +183,12 @@ 
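// Usage sketch for the run_until_cancelled helper added in proxy/src/util.rs above
// (signature assumed: run_until_cancelled<F: Future>(f: F, token: &CancellationToken) -> Option<F::Output>):
//
//     let token = CancellationToken::new();
//     if let Some(sum) = run_until_cancelled(async { 2 + 2 }, &token).await {
//         assert_eq!(sum, 4); // the future completed before cancellation
//     } // None would mean the token fired first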
impl Client { self.get(&uri).await } + pub async fn status(&self) -> Result { + let uri = format!("{}/v1/status", self.mgmt_api_endpoint); + let resp = self.get(&uri).await?; + resp.json().await.map_err(Error::ReceiveBody) + } + pub async fn utilization(&self) -> Result { let uri = format!("{}/v1/utilization", self.mgmt_api_endpoint); let resp = self.get(&uri).await?; diff --git a/safekeeper/src/control_file.rs b/safekeeper/src/control_file.rs index 1bf3e4cac1..4fc62fb229 100644 --- a/safekeeper/src/control_file.rs +++ b/safekeeper/src/control_file.rs @@ -206,16 +206,10 @@ impl Storage for FileStorage { let buf: Vec = s.write_to_buf()?; control_partial.write_all(&buf).await.with_context(|| { - format!( - "failed to write safekeeper state into control file at: {}", - control_partial_path - ) + format!("failed to write safekeeper state into control file at: {control_partial_path}") })?; control_partial.flush().await.with_context(|| { - format!( - "failed to flush safekeeper state into control file at: {}", - control_partial_path - ) + format!("failed to flush safekeeper state into control file at: {control_partial_path}") })?; let control_path = self.timeline_dir.join(CONTROL_FILE_NAME); diff --git a/safekeeper/src/control_file_upgrade.rs b/safekeeper/src/control_file_upgrade.rs index 1ad9e62f9b..555cbe457b 100644 --- a/safekeeper/src/control_file_upgrade.rs +++ b/safekeeper/src/control_file_upgrade.rs @@ -2,6 +2,7 @@ use std::vec; use anyhow::{Result, bail}; +use postgres_versioninfo::PgVersionId; use pq_proto::SystemId; use safekeeper_api::membership::{Configuration, INVALID_GENERATION}; use safekeeper_api::{ServerInfo, Term}; @@ -46,7 +47,7 @@ struct SafeKeeperStateV1 { #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct ServerInfoV2 { /// Postgres server version - pub pg_version: u32, + pub pg_version: PgVersionId, pub system_id: SystemId, pub tenant_id: TenantId, pub timeline_id: TimelineId, @@ -75,7 +76,7 @@ pub struct SafeKeeperStateV2 { #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub struct ServerInfoV3 { /// Postgres server version - pub pg_version: u32, + pub pg_version: PgVersionId, pub system_id: SystemId, #[serde(with = "hex")] pub tenant_id: TenantId, @@ -444,13 +445,13 @@ pub fn upgrade_control_file(buf: &[u8], version: u32) -> Result TimelinePersisten mod tests { use std::str::FromStr; + use postgres_versioninfo::PgMajorVersion; use utils::Hex; use utils::id::NodeId; @@ -563,7 +565,7 @@ mod tests { epoch: 43, }, server: ServerInfoV2 { - pg_version: 14, + pg_version: PgVersionId::from(PgMajorVersion::PG14), system_id: 0x1234567887654321, tenant_id, timeline_id, @@ -586,8 +588,8 @@ mod tests { 0x2a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // epoch 0x2b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - // pg_version - 0x0e, 0x00, 0x00, 0x00, + // pg_version = 140000 + 0xE0, 0x22, 0x02, 0x00, // system_id 0x21, 0x43, 0x65, 0x87, 0x78, 0x56, 0x34, 0x12, // tenant_id @@ -626,7 +628,7 @@ mod tests { }]), }, server: ServerInfoV2 { - pg_version: 14, + pg_version: PgVersionId::from(PgMajorVersion::PG14), system_id: 0x1234567887654321, tenant_id, timeline_id, @@ -646,7 +648,7 @@ mod tests { let expected = [ 0x2a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x29, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x21, 0x43, 0x65, 0x87, 0x78, 0x56, + 0x00, 0x00, 0x00, 0x00, 0xE0, 0x22, 0x02, 0x00, 0x21, 0x43, 0x65, 0x87, 0x78, 0x56, 0x34, 
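// For reference: PgVersionId serializes the full PostgreSQL version number, so the expected
// control-file bytes in the tests below change from the bare major version to the full
// version, little-endian:
//
//     assert_eq!(14u32.to_le_bytes(), [0x0E, 0x00, 0x00, 0x00]);     // old encoding
//     assert_eq!(140000u32.to_le_bytes(), [0xE0, 0x22, 0x02, 0x00]); // new encoding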
0x12, 0xcf, 0x04, 0x80, 0x92, 0x97, 0x07, 0xee, 0x75, 0x37, 0x23, 0x37, 0xef, 0xaa, 0x5e, 0xcf, 0x96, 0x11, 0x2d, 0xed, 0x66, 0x42, 0x2a, 0xa5, 0xe9, 0x53, 0xe5, 0x44, 0x0f, 0xa5, 0x42, 0x7a, 0xc4, 0x78, 0x56, 0x34, 0x12, 0xc4, 0x7a, 0x42, 0xa5, @@ -675,7 +677,7 @@ mod tests { }]), }, server: ServerInfoV3 { - pg_version: 14, + pg_version: PgVersionId::from(PgMajorVersion::PG14), system_id: 0x1234567887654321, tenant_id, timeline_id, @@ -695,7 +697,7 @@ mod tests { let expected = [ 0x2a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x29, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x0e, 0x00, 0x00, 0x00, 0x21, 0x43, 0x65, 0x87, 0x78, 0x56, + 0x00, 0x00, 0x00, 0x00, 0xE0, 0x22, 0x02, 0x00, 0x21, 0x43, 0x65, 0x87, 0x78, 0x56, 0x34, 0x12, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x63, 0x66, 0x30, 0x34, 0x38, 0x30, 0x39, 0x32, 0x39, 0x37, 0x30, 0x37, 0x65, 0x65, 0x37, 0x35, 0x33, 0x37, 0x32, 0x33, 0x33, 0x37, 0x65, 0x66, 0x61, 0x61, 0x35, 0x65, 0x63, 0x66, 0x39, 0x36, @@ -731,7 +733,7 @@ mod tests { }]), }, server: ServerInfo { - pg_version: 14, + pg_version: PgVersionId::from(PgMajorVersion::PG14), system_id: 0x1234567887654321, wal_seg_size: 0x12345678, }, @@ -765,7 +767,7 @@ mod tests { 0x30, 0x66, 0x61, 0x35, 0x34, 0x32, 0x37, 0x61, 0x63, 0x34, 0x2a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x29, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x0e, 0x00, 0x00, 0x00, 0x21, 0x43, 0x65, 0x87, 0x78, 0x56, 0x34, 0x12, 0x78, 0x56, + 0xE0, 0x22, 0x02, 0x00, 0x21, 0x43, 0x65, 0x87, 0x78, 0x56, 0x34, 0x12, 0x78, 0x56, 0x34, 0x12, 0x20, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x63, 0x34, 0x37, 0x61, 0x34, 0x32, 0x61, 0x35, 0x30, 0x66, 0x34, 0x34, 0x65, 0x35, 0x35, 0x33, 0x65, 0x39, 0x61, 0x35, 0x32, 0x61, 0x34, 0x32, 0x36, 0x36, 0x65, 0x64, 0x32, 0x64, 0x31, 0x31, diff --git a/safekeeper/src/handler.rs b/safekeeper/src/handler.rs index b54bee8bfb..5e7f1d8758 100644 --- a/safekeeper/src/handler.rs +++ b/safekeeper/src/handler.rs @@ -73,7 +73,7 @@ fn parse_cmd(cmd: &str) -> anyhow::Result { let re = Regex::new(r"START_WAL_PUSH(\s+?\((.*)\))?").unwrap(); let caps = re .captures(cmd) - .context(format!("failed to parse START_WAL_PUSH command {}", cmd))?; + .context(format!("failed to parse START_WAL_PUSH command {cmd}"))?; // capture () content let options = caps.get(2).map(|m| m.as_str()).unwrap_or(""); // default values @@ -85,24 +85,20 @@ fn parse_cmd(cmd: &str) -> anyhow::Result { } let mut kvit = kvstr.split_whitespace(); let key = kvit.next().context(format!( - "failed to parse key in kv {} in command {}", - kvstr, cmd + "failed to parse key in kv {kvstr} in command {cmd}" ))?; let value = kvit.next().context(format!( - "failed to parse value in kv {} in command {}", - kvstr, cmd + "failed to parse value in kv {kvstr} in command {cmd}" ))?; let value_trimmed = value.trim_matches('\''); if key == "proto_version" { proto_version = value_trimmed.parse::().context(format!( - "failed to parse proto_version value {} in command {}", - value, cmd + "failed to parse proto_version value {value} in command {cmd}" ))?; } if key == "allow_timeline_creation" { allow_timeline_creation = value_trimmed.parse::().context(format!( - "failed to parse allow_timeline_creation value {} in command {}", - value, cmd + "failed to parse allow_timeline_creation value {value} in command {cmd}" ))?; } } @@ -118,7 +114,7 @@ fn 
parse_cmd(cmd: &str) -> anyhow::Result { .unwrap(); let caps = re .captures(cmd) - .context(format!("failed to parse START_REPLICATION command {}", cmd))?; + .context(format!("failed to parse START_REPLICATION command {cmd}"))?; let start_lsn = Lsn::from_str(&caps[1]).context("parse start LSN from START_REPLICATION command")?; let term = if let Some(m) = caps.get(2) { diff --git a/safekeeper/src/safekeeper.rs b/safekeeper/src/safekeeper.rs index 886cac869d..4d15fc9de3 100644 --- a/safekeeper/src/safekeeper.rs +++ b/safekeeper/src/safekeeper.rs @@ -9,6 +9,7 @@ use anyhow::{Context, Result, bail}; use byteorder::{LittleEndian, ReadBytesExt}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use postgres_ffi::{MAX_SEND_SIZE, TimeLineID}; +use postgres_versioninfo::{PgMajorVersion, PgVersionId}; use pq_proto::SystemId; use safekeeper_api::membership::{ INVALID_GENERATION, MemberSet, SafekeeperGeneration as Generation, SafekeeperId, @@ -29,7 +30,7 @@ use crate::{control_file, wal_storage}; pub const SK_PROTO_VERSION_2: u32 = 2; pub const SK_PROTO_VERSION_3: u32 = 3; -pub const UNKNOWN_SERVER_VERSION: u32 = 0; +pub const UNKNOWN_SERVER_VERSION: PgVersionId = PgVersionId::UNKNOWN; #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)] pub struct TermLsn { @@ -64,10 +65,10 @@ impl TermHistory { for i in 0..n_entries { let term = bytes .get_u64_f() - .with_context(|| format!("TermHistory pos {} misses term", i))?; + .with_context(|| format!("TermHistory pos {i} misses term"))?; let lsn = bytes .get_u64_f() - .with_context(|| format!("TermHistory pos {} misses lsn", i))? + .with_context(|| format!("TermHistory pos {i} misses lsn"))? .into(); res.push(TermLsn { term, lsn }) } @@ -121,9 +122,7 @@ impl TermHistory { if let Some(sk_th_last) = sk_th.last() { assert!( sk_th_last.lsn <= sk_wal_end, - "safekeeper term history end {:?} LSN is higher than WAL end {:?}", - sk_th_last, - sk_wal_end + "safekeeper term history end {sk_th_last:?} LSN is higher than WAL end {sk_wal_end:?}" ); } @@ -220,7 +219,7 @@ pub struct ProposerGreeting { pub timeline_id: TimelineId, pub mconf: membership::Configuration, /// Postgres server version - pub pg_version: u32, + pub pg_version: PgVersionId, pub system_id: SystemId, pub wal_seg_size: u32, } @@ -231,7 +230,7 @@ pub struct ProposerGreetingV2 { /// proposer-acceptor protocol version pub protocol_version: u32, /// Postgres server version - pub pg_version: u32, + pub pg_version: PgVersionId, pub proposer_id: PgUuid, pub system_id: SystemId, pub timeline_id: TimelineId, @@ -438,11 +437,11 @@ impl ProposerAcceptorMessage { for i in 0..members_len { let id = buf .get_u64_f() - .with_context(|| format!("reading member {} node_id", i))?; - let host = Self::get_cstr(buf).with_context(|| format!("reading member {} host", i))?; + .with_context(|| format!("reading member {i} node_id"))?; + let host = Self::get_cstr(buf).with_context(|| format!("reading member {i} host"))?; let pg_port = buf .get_u16_f() - .with_context(|| format!("reading member {} port", i))?; + .with_context(|| format!("reading member {i} port"))?; let sk = SafekeeperId { id: NodeId(id), host, @@ -463,12 +462,12 @@ impl ProposerAcceptorMessage { for i in 0..new_members_len { let id = buf .get_u64_f() - .with_context(|| format!("reading new member {} node_id", i))?; - let host = Self::get_cstr(buf) - .with_context(|| format!("reading new member {} host", i))?; + .with_context(|| format!("reading new member {i} node_id"))?; + let host = + Self::get_cstr(buf).with_context(|| format!("reading 
new member {i} host"))?; let pg_port = buf .get_u16_f() - .with_context(|| format!("reading new member {} port", i))?; + .with_context(|| format!("reading new member {i} port"))?; let sk = SafekeeperId { id: NodeId(id), host, @@ -513,7 +512,7 @@ impl ProposerAcceptorMessage { tenant_id, timeline_id, mconf, - pg_version, + pg_version: PgVersionId::from_full_pg_version(pg_version), system_id, wal_seg_size, }; @@ -963,7 +962,8 @@ where * because safekeepers parse WAL headers and the format * may change between versions. */ - if msg.pg_version / 10000 != self.state.server.pg_version / 10000 + if PgMajorVersion::try_from(msg.pg_version)? + != PgMajorVersion::try_from(self.state.server.pg_version)? && self.state.server.pg_version != UNKNOWN_SERVER_VERSION { bail!( @@ -1508,7 +1508,7 @@ mod tests { let mut vote_resp = sk.process_msg(&vote_request).await; match vote_resp.unwrap() { Some(AcceptorProposerMessage::VoteResponse(resp)) => assert!(resp.vote_given), - r => panic!("unexpected response: {:?}", r), + r => panic!("unexpected response: {r:?}"), } // reboot... @@ -1523,7 +1523,7 @@ mod tests { vote_resp = sk.process_msg(&vote_request).await; match vote_resp.unwrap() { Some(AcceptorProposerMessage::VoteResponse(resp)) => assert!(!resp.vote_given), - r => panic!("unexpected response: {:?}", r), + r => panic!("unexpected response: {r:?}"), } } @@ -1750,7 +1750,7 @@ mod tests { }]), }, server: ServerInfo { - pg_version: 14, + pg_version: PgVersionId::from_full_pg_version(140000), system_id: 0x1234567887654321, wal_seg_size: 0x12345678, }, diff --git a/safekeeper/src/send_interpreted_wal.rs b/safekeeper/src/send_interpreted_wal.rs index 2b1fd7b854..2192f5eab4 100644 --- a/safekeeper/src/send_interpreted_wal.rs +++ b/safekeeper/src/send_interpreted_wal.rs @@ -8,8 +8,8 @@ use futures::StreamExt; use futures::future::Either; use pageserver_api::shard::ShardIdentity; use postgres_backend::{CopyStreamHandlerEnd, PostgresBackend}; -use postgres_ffi::get_current_timestamp; use postgres_ffi::waldecoder::{WalDecodeError, WalStreamDecoder}; +use postgres_ffi::{PgMajorVersion, get_current_timestamp}; use pq_proto::{BeMessage, InterpretedWalRecordsBody, WalSndKeepAlive}; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::sync::mpsc::error::SendError; @@ -78,7 +78,7 @@ pub(crate) struct InterpretedWalReader { shard_senders: HashMap>, shard_notification_rx: Option>, state: Arc>, - pg_version: u32, + pg_version: PgMajorVersion, } /// A handle for [`InterpretedWalReader`] which allows for interacting with it @@ -258,7 +258,7 @@ impl InterpretedWalReader { start_pos: Lsn, tx: tokio::sync::mpsc::Sender, shard: ShardIdentity, - pg_version: u32, + pg_version: PgMajorVersion, appname: &Option, ) -> InterpretedWalReaderHandle { let state = Arc::new(std::sync::RwLock::new(InterpretedWalReaderState::Running { @@ -322,7 +322,7 @@ impl InterpretedWalReader { start_pos: Lsn, tx: tokio::sync::mpsc::Sender, shard: ShardIdentity, - pg_version: u32, + pg_version: PgMajorVersion, shard_notification_rx: Option< tokio::sync::mpsc::UnboundedReceiver, >, @@ -718,7 +718,7 @@ mod tests { use std::time::Duration; use pageserver_api::shard::{ShardIdentity, ShardStripeSize}; - use postgres_ffi::MAX_SEND_SIZE; + use postgres_ffi::{MAX_SEND_SIZE, PgMajorVersion}; use tokio::sync::mpsc::error::TryRecvError; use utils::id::{NodeId, TenantTimelineId}; use utils::lsn::Lsn; @@ -734,7 +734,7 @@ mod tests { const SIZE: usize = 8 * 1024; const MSG_COUNT: usize = 200; - const PG_VERSION: u32 = 17; + const PG_VERSION: PgMajorVersion = 
PgMajorVersion::PG17; const SHARD_COUNT: u8 = 2; let start_lsn = Lsn::from_str("0/149FD18").unwrap(); @@ -876,7 +876,7 @@ mod tests { const SIZE: usize = 8 * 1024; const MSG_COUNT: usize = 200; - const PG_VERSION: u32 = 17; + const PG_VERSION: PgMajorVersion = PgMajorVersion::PG17; const SHARD_COUNT: u8 = 2; let start_lsn = Lsn::from_str("0/149FD18").unwrap(); @@ -1025,7 +1025,7 @@ mod tests { const SIZE: usize = 64 * 1024; const MSG_COUNT: usize = 10; - const PG_VERSION: u32 = 17; + const PG_VERSION: PgMajorVersion = PgMajorVersion::PG17; const SHARD_COUNT: u8 = 2; const WAL_READER_BATCH_SIZE: usize = 8192; @@ -1148,7 +1148,7 @@ mod tests { const SIZE: usize = 8 * 1024; const MSG_COUNT: usize = 10; - const PG_VERSION: u32 = 17; + const PG_VERSION: PgMajorVersion = PgMajorVersion::PG17; let start_lsn = Lsn::from_str("0/149FD18").unwrap(); let env = Env::new(true).unwrap(); diff --git a/safekeeper/src/send_wal.rs b/safekeeper/src/send_wal.rs index 05f827494e..177e759db5 100644 --- a/safekeeper/src/send_wal.rs +++ b/safekeeper/src/send_wal.rs @@ -12,7 +12,7 @@ use futures::FutureExt; use itertools::Itertools; use parking_lot::Mutex; use postgres_backend::{CopyStreamHandlerEnd, PostgresBackend, PostgresBackendReader, QueryError}; -use postgres_ffi::{MAX_SEND_SIZE, TimestampTz, get_current_timestamp}; +use postgres_ffi::{MAX_SEND_SIZE, PgMajorVersion, TimestampTz, get_current_timestamp}; use pq_proto::{BeMessage, WalSndKeepAlive, XLogDataBody}; use safekeeper_api::Term; use safekeeper_api::models::{ @@ -559,7 +559,9 @@ impl SafekeeperPostgresHandler { format, compression, } => { - let pg_version = tli.tli.get_state().await.1.server.pg_version / 10000; + let pg_version = + PgMajorVersion::try_from(tli.tli.get_state().await.1.server.pg_version) + .unwrap(); let end_watch_view = end_watch.view(); let wal_residence_guard = tli.wal_residence_guard().await?; let (tx, rx) = tokio::sync::mpsc::channel::(2); diff --git a/safekeeper/src/state.rs b/safekeeper/src/state.rs index 7533005c35..b6cf73be2e 100644 --- a/safekeeper/src/state.rs +++ b/safekeeper/src/state.rs @@ -7,6 +7,7 @@ use std::time::SystemTime; use anyhow::{Result, bail}; use postgres_ffi::WAL_SEGMENT_SIZE; +use postgres_versioninfo::{PgMajorVersion, PgVersionId}; use safekeeper_api::membership::Configuration; use safekeeper_api::models::{TimelineMembershipSwitchResponse, TimelineTermBumpResponse}; use safekeeper_api::{INITIAL_TERM, ServerInfo, Term}; @@ -149,8 +150,8 @@ impl TimelinePersistentState { &TenantTimelineId::empty(), Configuration::empty(), ServerInfo { - pg_version: 170000, /* Postgres server version (major * 10000) */ - system_id: 0, /* Postgres system identifier */ + pg_version: PgVersionId::from(PgMajorVersion::PG17), + system_id: 0, /* Postgres system identifier */ wal_seg_size: WAL_SEGMENT_SIZE as u32, }, Lsn::INVALID, diff --git a/safekeeper/src/timeline.rs b/safekeeper/src/timeline.rs index 588bd4f2c9..2bee41537f 100644 --- a/safekeeper/src/timeline.rs +++ b/safekeeper/src/timeline.rs @@ -395,6 +395,8 @@ pub enum TimelineError { Cancelled(TenantTimelineId), #[error("Timeline {0} was not found in global map")] NotFound(TenantTimelineId), + #[error("Timeline {0} has been deleted")] + Deleted(TenantTimelineId), #[error("Timeline {0} creation is in progress")] CreationInProgress(TenantTimelineId), #[error("Timeline {0} exists on disk, but wasn't loaded on startup")] diff --git a/safekeeper/src/timeline_eviction.rs b/safekeeper/src/timeline_eviction.rs index e817dbf6f9..47b65a579a 100644 --- 
a/safekeeper/src/timeline_eviction.rs +++ b/safekeeper/src/timeline_eviction.rs @@ -342,7 +342,7 @@ where let bytes_read1 = reader1 .read(&mut buffer1[..bytes_to_read]) .await - .with_context(|| format!("failed to read from reader1 at offset {}", offset))?; + .with_context(|| format!("failed to read from reader1 at offset {offset}"))?; if bytes_read1 == 0 { anyhow::bail!("unexpected EOF from reader1 at offset {}", offset); } @@ -351,10 +351,7 @@ where .read_exact(&mut buffer2[..bytes_read1]) .await .with_context(|| { - format!( - "failed to read {} bytes from reader2 at offset {}", - bytes_read1, offset - ) + format!("failed to read {bytes_read1} bytes from reader2 at offset {offset}") })?; assert!(bytes_read2 == bytes_read1); diff --git a/safekeeper/src/timeline_manager.rs b/safekeeper/src/timeline_manager.rs index 48eda92fed..a68752bfdd 100644 --- a/safekeeper/src/timeline_manager.rs +++ b/safekeeper/src/timeline_manager.rs @@ -108,7 +108,7 @@ impl std::fmt::Debug for ManagerCtlMessage { match self { ManagerCtlMessage::GuardRequest(_) => write!(f, "GuardRequest"), ManagerCtlMessage::TryGuardRequest(_) => write!(f, "TryGuardRequest"), - ManagerCtlMessage::GuardDrop(id) => write!(f, "GuardDrop({:?})", id), + ManagerCtlMessage::GuardDrop(id) => write!(f, "GuardDrop({id:?})"), ManagerCtlMessage::BackupPartialReset(_) => write!(f, "BackupPartialReset"), } } diff --git a/safekeeper/src/timelines_global_map.rs b/safekeeper/src/timelines_global_map.rs index e3f7d88f7c..a81a7298a9 100644 --- a/safekeeper/src/timelines_global_map.rs +++ b/safekeeper/src/timelines_global_map.rs @@ -78,7 +78,13 @@ impl GlobalTimelinesState { Some(GlobalMapTimeline::CreationInProgress) => { Err(TimelineError::CreationInProgress(*ttid)) } - None => Err(TimelineError::NotFound(*ttid)), + None => { + if self.has_tombstone(ttid) { + Err(TimelineError::Deleted(*ttid)) + } else { + Err(TimelineError::NotFound(*ttid)) + } + } } } @@ -141,7 +147,7 @@ impl GlobalTimelines { }; let mut tenant_count = 0; for tenants_dir_entry in std::fs::read_dir(&tenants_dir) - .with_context(|| format!("failed to list tenants dir {}", tenants_dir))? + .with_context(|| format!("failed to list tenants dir {tenants_dir}"))? { match &tenants_dir_entry { Ok(tenants_dir_entry) => { @@ -182,7 +188,7 @@ impl GlobalTimelines { let timelines_dir = get_tenant_dir(&conf, &tenant_id); for timelines_dir_entry in std::fs::read_dir(&timelines_dir) - .with_context(|| format!("failed to list timelines dir {}", timelines_dir))? + .with_context(|| format!("failed to list timelines dir {timelines_dir}"))? 
{ match &timelines_dir_entry { Ok(timeline_dir_entry) => { diff --git a/safekeeper/src/wal_backup_partial.rs b/safekeeper/src/wal_backup_partial.rs index fe0f1b3607..cdf68262dd 100644 --- a/safekeeper/src/wal_backup_partial.rs +++ b/safekeeper/src/wal_backup_partial.rs @@ -364,8 +364,7 @@ impl PartialBackup { // there should always be zero or one uploaded segment assert!( new_segments.is_empty(), - "too many uploaded segments: {:?}", - new_segments + "too many uploaded segments: {new_segments:?}" ); } diff --git a/safekeeper/src/wal_storage.rs b/safekeeper/src/wal_storage.rs index 8ba3e7cc47..da00df2dd7 100644 --- a/safekeeper/src/wal_storage.rs +++ b/safekeeper/src/wal_storage.rs @@ -19,6 +19,7 @@ use futures::future::BoxFuture; use postgres_ffi::v14::xlog_utils::{IsPartialXLogFileName, IsXLogFileName, XLogFromFileName}; use postgres_ffi::waldecoder::WalStreamDecoder; use postgres_ffi::{PG_TLI, XLogFileName, XLogSegNo, dispatch_pgversion}; +use postgres_versioninfo::{PgMajorVersion, PgVersionId}; use pq_proto::SystemId; use remote_storage::RemotePath; use std::sync::Arc; @@ -92,7 +93,7 @@ pub struct PhysicalStorage { /// Size of WAL segment in bytes. wal_seg_size: usize, - pg_version: u32, + pg_version: PgVersionId, system_id: u64, /// Written to disk, but possibly still in the cache and not fully persisted. @@ -180,7 +181,7 @@ impl PhysicalStorage { let write_lsn = if state.commit_lsn == Lsn(0) { Lsn(0) } else { - let version = state.server.pg_version / 10000; + let version = PgMajorVersion::try_from(state.server.pg_version).unwrap(); dispatch_pgversion!( version, @@ -226,7 +227,10 @@ impl PhysicalStorage { write_record_lsn: write_lsn, flush_lsn, flush_record_lsn: flush_lsn, - decoder: WalStreamDecoder::new(write_lsn, state.server.pg_version / 10000), + decoder: WalStreamDecoder::new( + write_lsn, + PgMajorVersion::try_from(state.server.pg_version).unwrap(), + ), file: None, pending_wal_truncation: true, }) @@ -408,7 +412,7 @@ impl Storage for PhysicalStorage { let segno = init_lsn.segment_number(self.wal_seg_size); let (mut file, _) = self.open_or_create(segno).await?; - let major_pg_version = self.pg_version / 10000; + let major_pg_version = PgMajorVersion::try_from(self.pg_version).unwrap(); let wal_seg = postgres_ffi::generate_wal_segment(segno, self.system_id, major_pg_version, init_lsn)?; file.seek(SeekFrom::Start(0)).await?; @@ -654,7 +658,7 @@ pub struct WalReader { // pos is in the same segment as timeline_start_lsn. timeline_start_lsn: Lsn, // integer version number of PostgreSQL, e.g. 14; 15; 16 - pg_version: u32, + pg_version: PgMajorVersion, system_id: SystemId, timeline_start_segment: Option, } @@ -697,7 +701,7 @@ impl WalReader { wal_backup, local_start_lsn: state.local_start_lsn, timeline_start_lsn: state.timeline_start_lsn, - pg_version: state.server.pg_version / 10000, + pg_version: PgMajorVersion::try_from(state.server.pg_version).unwrap(), system_id: state.server.system_id, timeline_start_segment: None, }) @@ -841,7 +845,7 @@ pub(crate) async fn open_wal_file( // If that failed, try it without the .partial extension. 
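// Throughout wal_storage.rs, PgMajorVersion::try_from replaces the old `pg_version / 10000`
// arithmetic (e.g. a stored 140000 maps to PG14); being a TryFrom, it can presumably reject
// unknown or corrupt version numbers, which the plain division silently accepted.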
let pf = tokio::fs::File::open(&wal_file_path) .await - .with_context(|| format!("failed to open WAL file {:#}", wal_file_path)) + .with_context(|| format!("failed to open WAL file {wal_file_path:#}")) .map_err(|e| { warn!("{}", e); e diff --git a/safekeeper/tests/walproposer_sim/log.rs b/safekeeper/tests/walproposer_sim/log.rs index e2ba3282ca..cecbc859e6 100644 --- a/safekeeper/tests/walproposer_sim/log.rs +++ b/safekeeper/tests/walproposer_sim/log.rs @@ -33,7 +33,7 @@ impl FormatTime for SimClock { if let Some(clock) = clock.as_ref() { let now = clock.now(); - write!(w, "[{}]", now) + write!(w, "[{now}]") } else { write!(w, "[?]") } diff --git a/safekeeper/tests/walproposer_sim/safekeeper.rs b/safekeeper/tests/walproposer_sim/safekeeper.rs index 5fb29683f2..1fdf8e4949 100644 --- a/safekeeper/tests/walproposer_sim/safekeeper.rs +++ b/safekeeper/tests/walproposer_sim/safekeeper.rs @@ -257,7 +257,7 @@ pub fn run_server(os: NodeOs, disk: Arc) -> Result<()> { let estr = e.to_string(); if !estr.contains("finished processing START_REPLICATION") { warn!("conn {:?} error: {:?}", connection_id, e); - panic!("unexpected error at safekeeper: {:#}", e); + panic!("unexpected error at safekeeper: {e:#}"); } conns.remove(&connection_id); break; diff --git a/safekeeper/tests/walproposer_sim/safekeeper_disk.rs b/safekeeper/tests/walproposer_sim/safekeeper_disk.rs index 94a849b5f0..029f8fab0a 100644 --- a/safekeeper/tests/walproposer_sim/safekeeper_disk.rs +++ b/safekeeper/tests/walproposer_sim/safekeeper_disk.rs @@ -7,8 +7,8 @@ use anyhow::Result; use bytes::{Buf, BytesMut}; use futures::future::BoxFuture; use parking_lot::Mutex; -use postgres_ffi::XLogSegNo; use postgres_ffi::waldecoder::WalStreamDecoder; +use postgres_ffi::{PgMajorVersion, XLogSegNo}; use safekeeper::metrics::WalStorageMetrics; use safekeeper::state::TimelinePersistentState; use safekeeper::{control_file, wal_storage}; @@ -142,7 +142,7 @@ impl DiskWALStorage { write_lsn, write_record_lsn: flush_lsn, flush_record_lsn: flush_lsn, - decoder: WalStreamDecoder::new(flush_lsn, 16), + decoder: WalStreamDecoder::new(flush_lsn, PgMajorVersion::PG16), unflushed_bytes: BytesMut::new(), disk, }) @@ -151,7 +151,7 @@ impl DiskWALStorage { fn find_end_of_wal(disk: Arc, start_lsn: Lsn) -> Result { let mut buf = [0; 8192]; let mut pos = start_lsn.0; - let mut decoder = WalStreamDecoder::new(start_lsn, 16); + let mut decoder = WalStreamDecoder::new(start_lsn, PgMajorVersion::PG16); let mut result = start_lsn; loop { disk.wal.lock().read(pos, &mut buf); @@ -204,7 +204,7 @@ impl wal_storage::Storage for DiskWALStorage { self.decoder.available(), startpos, ); - self.decoder = WalStreamDecoder::new(startpos, 16); + self.decoder = WalStreamDecoder::new(startpos, PgMajorVersion::PG16); } self.decoder.feed_bytes(buf); loop { @@ -242,7 +242,7 @@ impl wal_storage::Storage for DiskWALStorage { self.write_record_lsn = end_pos; self.flush_record_lsn = end_pos; self.unflushed_bytes.clear(); - self.decoder = WalStreamDecoder::new(end_pos, 16); + self.decoder = WalStreamDecoder::new(end_pos, PgMajorVersion::PG16); Ok(()) } diff --git a/safekeeper/tests/walproposer_sim/simulation.rs b/safekeeper/tests/walproposer_sim/simulation.rs index 70fecfbe22..edd3bf2d9e 100644 --- a/safekeeper/tests/walproposer_sim/simulation.rs +++ b/safekeeper/tests/walproposer_sim/simulation.rs @@ -217,7 +217,7 @@ impl TestConfig { ]; let server_ids = [servers[0].id, servers[1].id, servers[2].id]; - let safekeepers_addrs = server_ids.map(|id| format!("node:{}", id)).to_vec(); + let 
safekeepers_addrs = server_ids.map(|id| format!("node:{id}")).to_vec(); let ttid = TenantTimelineId::generate(); diff --git a/safekeeper/tests/walproposer_sim/walproposer_api.rs b/safekeeper/tests/walproposer_sim/walproposer_api.rs index 82e7a32881..29b361db7e 100644 --- a/safekeeper/tests/walproposer_sim/walproposer_api.rs +++ b/safekeeper/tests/walproposer_sim/walproposer_api.rs @@ -499,7 +499,7 @@ impl ApiImpl for SimulationApi { true } - fn finish_sync_safekeepers(&self, lsn: u64) { + fn finish_sync_safekeepers(&self, lsn: u64) -> ! { debug!("finish_sync_safekeepers, lsn={}", lsn); executor::exit(0, Lsn(lsn).to_string()); } @@ -523,7 +523,7 @@ impl ApiImpl for SimulationApi { // Voting bug when safekeeper disconnects after voting executor::exit(1, msg.to_owned()); } - panic!("unknown FATAL error from walproposer: {}", msg); + panic!("unknown FATAL error from walproposer: {msg}"); } } @@ -544,10 +544,7 @@ impl ApiImpl for SimulationApi { } } - let msg = format!( - "prop_elected;{};{};{};{}", - prop_lsn, prop_term, prev_lsn, prev_term - ); + let msg = format!("prop_elected;{prop_lsn};{prop_term};{prev_lsn};{prev_term}"); debug!(msg); self.os.log_event(msg); diff --git a/scripts/ingest_perf_test_result.py b/scripts/ingest_perf_test_result.py index 804f8a3cde..898e1ee954 100644 --- a/scripts/ingest_perf_test_result.py +++ b/scripts/ingest_perf_test_result.py @@ -26,7 +26,7 @@ CREATE TABLE IF NOT EXISTS perf_test_results ( metric_unit VARCHAR(10), metric_report_type TEXT, recorded_at_timestamp TIMESTAMP WITH TIME ZONE DEFAULT NOW(), - labels JSONB with default '{}' + labels JSONB DEFAULT '{}'::jsonb ) """ diff --git a/scripts/proxy_bench_results_ingest.py b/scripts/proxy_bench_results_ingest.py new file mode 100644 index 0000000000..475d053ed2 --- /dev/null +++ b/scripts/proxy_bench_results_ingest.py @@ -0,0 +1,187 @@ +#!/usr/bin/env python3 + +import argparse +import json +import time +from typing import Any, TypedDict, cast + +import requests + +PROMETHEUS_URL = "http://localhost:9090" +SAMPLE_INTERVAL = 1 # seconds + +DEFAULT_REVISION = "unknown" +DEFAULT_PLATFORM = "unknown" +DEFAULT_SUIT = "proxy_bench" + + +class MetricConfig(TypedDict, total=False): + name: str + promql: str + unit: str + report: str + labels: dict[str, str] + is_vector: bool + label_field: str + + +METRICS: list[MetricConfig] = [ + { + "name": "latency_p99", + "promql": 'histogram_quantile(0.99, sum(rate(proxy_compute_connection_latency_seconds_bucket{outcome="success", excluded="client_and_cplane"}[5m])) by (le))', + "unit": "s", + "report": "LOWER_IS_BETTER", + "labels": {}, + }, + { + "name": "error_rate", + "promql": 'sum(rate(proxy_errors_total{type!~"user|clientdisconnect|quota"}[5m])) / sum(rate(proxy_accepted_connections_total[5m]))', + "unit": "", + "report": "LOWER_IS_BETTER", + "labels": {}, + }, + { + "name": "max_memory_kb", + "promql": "max(libmetrics_maxrss_kb)", + "unit": "kB", + "report": "LOWER_IS_BETTER", + "labels": {}, + }, + { + "name": "jemalloc_active_bytes", + "promql": "sum(jemalloc_active_bytes)", + "unit": "bytes", + "report": "LOWER_IS_BETTER", + "labels": {}, + }, + { + "name": "open_connections", + "promql": "sum by (protocol) (proxy_opened_client_connections_total - proxy_closed_client_connections_total)", + "unit": "", + "report": "HIGHER_IS_BETTER", + "labels": {}, + "is_vector": True, + "label_field": "protocol", + }, +] + + +class PrometheusMetric(TypedDict): + metric: dict[str, str] + value: list[str | float] + + +class PrometheusResult(TypedDict): + result: 
list[PrometheusMetric] + + +class PrometheusResponse(TypedDict): + data: PrometheusResult + + +def query_prometheus(promql: str) -> PrometheusResponse: + resp = requests.get(f"{PROMETHEUS_URL}/api/v1/query", params={"query": promql}) + resp.raise_for_status() + return cast("PrometheusResponse", resp.json()) + + +def extract_scalar_metric(result_json: PrometheusResponse) -> float | None: + try: + return float(result_json["data"]["result"][0]["value"][1]) + except (IndexError, KeyError, ValueError, TypeError): + return None + + +def extract_vector_metric( + result_json: PrometheusResponse, label_field: str +) -> list[tuple[str | None, float, dict[str, str]]]: + out: list[tuple[str | None, float, dict[str, str]]] = [] + for entry in result_json["data"]["result"]: + try: + value_str = entry["value"][1] + if not isinstance(value_str, (str | float)): + continue + value = float(value_str) + except (IndexError, KeyError, ValueError, TypeError): + continue + labels = entry.get("metric", {}) + label_val = labels.get(label_field, None) + out.append((label_val, value, labels)) + return out + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Collect Prometheus metrics and output in benchmark fixture format" + ) + parser.add_argument("--revision", default=DEFAULT_REVISION) + parser.add_argument("--platform", default=DEFAULT_PLATFORM) + parser.add_argument("--suit", default=DEFAULT_SUIT) + parser.add_argument("--out", default="metrics_benchmarks.json", help="Output JSON file") + parser.add_argument( + "--interval", default=SAMPLE_INTERVAL, type=int, help="Sampling interval (s)" + ) + args = parser.parse_args() + + start_time = int(time.time()) + samples: list[dict[str, Any]] = [] + + print("Collecting metrics (Ctrl+C to stop)...") + try: + while True: + ts = int(time.time()) + for metric in METRICS: + if metric.get("is_vector", False): + # Vector (per-label, e.g. 
per-protocol) + for label_val, value, labels in extract_vector_metric( + query_prometheus(metric["promql"]), metric["label_field"] + ): + entry = { + "name": f"{metric['name']}.{label_val}" + if label_val + else metric["name"], + "value": value, + "unit": metric["unit"], + "report": metric["report"], + "labels": {**metric.get("labels", {}), **labels}, + "timestamp": ts, + } + samples.append(entry) + else: + result = extract_scalar_metric(query_prometheus(metric["promql"])) + if result is not None: + entry = { + "name": metric["name"], + "value": result, + "unit": metric["unit"], + "report": metric["report"], + "labels": metric.get("labels", {}), + "timestamp": ts, + } + samples.append(entry) + time.sleep(args.interval) + except KeyboardInterrupt: + print("Collection stopped.") + + total_duration = int(time.time()) - start_time + + # Compose output + out = { + "revision": args.revision, + "platform": args.platform, + "result": [ + { + "suit": args.suit, + "total_duration": total_duration, + "data": samples, + } + ], + } + + with open(args.out, "w") as f: + json.dump(out, f, indent=2) + print(f"Wrote metrics in fixture format to {args.out}") + + +if __name__ == "__main__": + main() diff --git a/storage_broker/benches/rps.rs b/storage_broker/benches/rps.rs index 9953ccfa91..5f3e594687 100644 --- a/storage_broker/benches/rps.rs +++ b/storage_broker/benches/rps.rs @@ -161,7 +161,7 @@ async fn publish(client: Option, n_keys: u64) { } }; let response = client.publish_safekeeper_info(Request::new(outbound)).await; - println!("pub response is {:?}", response); + println!("pub response is {response:?}"); } #[tokio::main] diff --git a/storage_broker/build.rs b/storage_broker/build.rs index 08dadeacd5..77c441dddd 100644 --- a/storage_broker/build.rs +++ b/storage_broker/build.rs @@ -6,6 +6,6 @@ fn main() -> Result<(), Box> { // the build then. Anyway, per cargo docs build script shouldn't output to // anywhere but $OUT_DIR. 
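// Re scripts/proxy_bench_results_ingest.py above: the emitted fixture has roughly this
// shape (values illustrative only):
//
//     {"revision": "abc123", "platform": "unknown",
//      "result": [{"suit": "proxy_bench", "total_duration": 300,
//                  "data": [{"name": "latency_p99", "value": 0.012, "unit": "s",
//                            "report": "LOWER_IS_BETTER", "labels": {}, "timestamp": 1700000000}]}]}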
tonic_build::compile_protos("proto/broker.proto") - .unwrap_or_else(|e| panic!("failed to compile protos {:?}", e)); + .unwrap_or_else(|e| panic!("failed to compile protos {e:?}")); Ok(()) } diff --git a/storage_broker/src/lib.rs b/storage_broker/src/lib.rs index 149656a191..7d8b57380f 100644 --- a/storage_broker/src/lib.rs +++ b/storage_broker/src/lib.rs @@ -86,13 +86,9 @@ impl BrokerClientChannel { #[allow(clippy::result_large_err, reason = "TODO")] pub fn parse_proto_ttid(proto_ttid: &ProtoTenantTimelineId) -> Result { let tenant_id = TenantId::from_slice(&proto_ttid.tenant_id) - .map_err(|e| Status::new(Code::InvalidArgument, format!("malformed tenant_id: {}", e)))?; - let timeline_id = TimelineId::from_slice(&proto_ttid.timeline_id).map_err(|e| { - Status::new( - Code::InvalidArgument, - format!("malformed timeline_id: {}", e), - ) - })?; + .map_err(|e| Status::new(Code::InvalidArgument, format!("malformed tenant_id: {e}")))?; + let timeline_id = TimelineId::from_slice(&proto_ttid.timeline_id) + .map_err(|e| Status::new(Code::InvalidArgument, format!("malformed timeline_id: {e}")))?; Ok(TenantTimelineId { tenant_id, timeline_id, diff --git a/storage_controller/Cargo.toml b/storage_controller/Cargo.toml index c41e174d9d..3a0806b3b2 100644 --- a/storage_controller/Cargo.toml +++ b/storage_controller/Cargo.toml @@ -27,6 +27,7 @@ governor.workspace = true hex.workspace = true hyper0.workspace = true humantime.workspace = true +humantime-serde.workspace = true itertools.workspace = true json-structural-diff.workspace = true lasso.workspace = true @@ -34,6 +35,7 @@ once_cell.workspace = true pageserver_api.workspace = true pageserver_client.workspace = true postgres_connection.workspace = true +posthog_client_lite.workspace = true rand.workspace = true reqwest = { workspace = true, features = ["stream"] } routerify.workspace = true diff --git a/storage_controller/migrations/2025-06-01-201442_add_lifecycle_to_nodes/down.sql b/storage_controller/migrations/2025-06-01-201442_add_lifecycle_to_nodes/down.sql new file mode 100644 index 0000000000..a09acb916b --- /dev/null +++ b/storage_controller/migrations/2025-06-01-201442_add_lifecycle_to_nodes/down.sql @@ -0,0 +1 @@ +ALTER TABLE nodes DROP COLUMN lifecycle; diff --git a/storage_controller/migrations/2025-06-01-201442_add_lifecycle_to_nodes/up.sql b/storage_controller/migrations/2025-06-01-201442_add_lifecycle_to_nodes/up.sql new file mode 100644 index 0000000000..e03a0cadba --- /dev/null +++ b/storage_controller/migrations/2025-06-01-201442_add_lifecycle_to_nodes/up.sql @@ -0,0 +1 @@ +ALTER TABLE nodes ADD COLUMN lifecycle VARCHAR NOT NULL DEFAULT 'active'; diff --git a/storage_controller/migrations/2025-06-17-082247_pageserver_grpc_addr/down.sql b/storage_controller/migrations/2025-06-17-082247_pageserver_grpc_addr/down.sql new file mode 100644 index 0000000000..f9f2ebb070 --- /dev/null +++ b/storage_controller/migrations/2025-06-17-082247_pageserver_grpc_addr/down.sql @@ -0,0 +1 @@ +ALTER TABLE nodes DROP listen_grpc_addr, listen_grpc_port; diff --git a/storage_controller/migrations/2025-06-17-082247_pageserver_grpc_addr/up.sql b/storage_controller/migrations/2025-06-17-082247_pageserver_grpc_addr/up.sql new file mode 100644 index 0000000000..8291864b16 --- /dev/null +++ b/storage_controller/migrations/2025-06-17-082247_pageserver_grpc_addr/up.sql @@ -0,0 +1 @@ +ALTER TABLE nodes ADD listen_grpc_addr VARCHAR NULL, ADD listen_grpc_port INTEGER NULL; diff --git a/storage_controller/src/compute_hook.rs b/storage_controller/src/compute_hook.rs 
index 57709302e1..0b5569b3d6 100644 --- a/storage_controller/src/compute_hook.rs +++ b/storage_controller/src/compute_hook.rs @@ -5,10 +5,11 @@ use std::sync::Arc; use std::time::Duration; use anyhow::Context; -use control_plane::endpoint::{ComputeControlPlane, EndpointStatus}; +use control_plane::endpoint::{ComputeControlPlane, EndpointStatus, PageserverProtocol}; use control_plane::local_env::LocalEnv; use futures::StreamExt; use hyper::StatusCode; +use pageserver_api::config::DEFAULT_GRPC_LISTEN_PORT; use pageserver_api::controller_api::AvailabilityZone; use pageserver_api::shard::{ShardCount, ShardNumber, ShardStripeSize, TenantShardId}; use postgres_connection::parse_host_port; @@ -369,7 +370,7 @@ impl ComputeHook { let authorization_header = config .control_plane_jwt_token .clone() - .map(|jwt| format!("Bearer {}", jwt)); + .map(|jwt| format!("Bearer {jwt}")); let mut client = reqwest::ClientBuilder::new().timeout(NOTIFY_REQUEST_TIMEOUT); for cert in &config.ssl_ca_certs { @@ -420,23 +421,31 @@ impl ComputeHook { preferred_az: _preferred_az, } = reconfigure_request; - let compute_pageservers = shards - .iter() - .map(|shard| { - let ps_conf = env - .get_pageserver_conf(shard.node_id) - .expect("Unknown pageserver"); - let (pg_host, pg_port) = parse_host_port(&ps_conf.listen_pg_addr) - .expect("Unable to parse listen_pg_addr"); - (pg_host, pg_port.unwrap_or(5432)) - }) - .collect::>(); - for (endpoint_name, endpoint) in &cplane.endpoints { if endpoint.tenant_id == *tenant_id && endpoint.status() == EndpointStatus::Running { - tracing::info!("Reconfiguring endpoint {}", endpoint_name,); + tracing::info!("Reconfiguring endpoint {endpoint_name}"); + + let pageservers = shards + .iter() + .map(|shard| { + let ps_conf = env + .get_pageserver_conf(shard.node_id) + .expect("Unknown pageserver"); + if endpoint.grpc { + let addr = ps_conf.listen_grpc_addr.as_ref().expect("no gRPC address"); + let (host, port) = parse_host_port(addr).expect("invalid gRPC address"); + let port = port.unwrap_or(DEFAULT_GRPC_LISTEN_PORT); + (PageserverProtocol::Grpc, host, port) + } else { + let (host, port) = parse_host_port(&ps_conf.listen_pg_addr) + .expect("Unable to parse listen_pg_addr"); + (PageserverProtocol::Libpq, host, port.unwrap_or(5432)) + } + }) + .collect::>(); + endpoint - .reconfigure(compute_pageservers.clone(), *stripe_size, None) + .reconfigure(pageservers, *stripe_size, None) .await .map_err(NotifyError::NeonLocal)?; } diff --git a/storage_controller/src/drain_utils.rs b/storage_controller/src/drain_utils.rs index bd4b8ba38f..0dae7b8147 100644 --- a/storage_controller/src/drain_utils.rs +++ b/storage_controller/src/drain_utils.rs @@ -62,7 +62,7 @@ pub(crate) fn validate_node_state( nodes: Arc>, ) -> Result<(), OperationError> { let node = nodes.get(node_id).ok_or(OperationError::NodeStateChanged( - format!("node {} was removed", node_id).into(), + format!("node {node_id} was removed").into(), ))?; let current_policy = node.get_scheduling(); @@ -70,7 +70,7 @@ pub(crate) fn validate_node_state( // TODO(vlad): maybe cancel pending reconciles before erroring out. 
need to think // about it return Err(OperationError::NodeStateChanged( - format!("node {} changed state to {:?}", node_id, current_policy).into(), + format!("node {node_id} changed state to {current_policy:?}").into(), )); } @@ -145,7 +145,7 @@ impl TenantShardDrain { if !nodes.contains_key(&destination) { return Err(OperationError::NodeStateChanged( - format!("node {} was removed", destination).into(), + format!("node {destination} was removed").into(), )); } diff --git a/storage_controller/src/http.rs b/storage_controller/src/http.rs index 2b1c0db12f..a7e86b5224 100644 --- a/storage_controller/src/http.rs +++ b/storage_controller/src/http.rs @@ -721,9 +721,9 @@ async fn handle_tenant_timeline_passthrough( // Callers will always pass an unsharded tenant ID. Before proxying, we must // rewrite this to a shard-aware shard zero ID. - let path = format!("{}", path); + let path = format!("{path}"); let tenant_str = tenant_or_shard_id.tenant_id.to_string(); - let tenant_shard_str = format!("{}", tenant_shard_id); + let tenant_shard_str = format!("{tenant_shard_id}"); let path = path.replace(&tenant_str, &tenant_shard_str); let latency = &METRICS_REGISTRY @@ -907,6 +907,42 @@ async fn handle_node_delete(req: Request) -> Result, ApiErr json_response(StatusCode::OK, state.service.node_delete(node_id).await?) } +async fn handle_tombstone_list(req: Request) -> Result, ApiError> { + check_permissions(&req, Scope::Admin)?; + + let req = match maybe_forward(req).await { + ForwardOutcome::Forwarded(res) => { + return res; + } + ForwardOutcome::NotForwarded(req) => req, + }; + + let state = get_state(&req); + let mut nodes = state.service.tombstone_list().await?; + nodes.sort_by_key(|n| n.get_id()); + let api_nodes = nodes.into_iter().map(|n| n.describe()).collect::>(); + + json_response(StatusCode::OK, api_nodes) +} + +async fn handle_tombstone_delete(req: Request) -> Result, ApiError> { + check_permissions(&req, Scope::Admin)?; + + let req = match maybe_forward(req).await { + ForwardOutcome::Forwarded(res) => { + return res; + } + ForwardOutcome::NotForwarded(req) => req, + }; + + let state = get_state(&req); + let node_id: NodeId = parse_request_param(&req, "node_id")?; + json_response( + StatusCode::OK, + state.service.tombstone_delete(node_id).await?, + ) +} + async fn handle_node_configure(req: Request) -> Result, ApiError> { check_permissions(&req, Scope::Admin)?; @@ -1362,6 +1398,31 @@ async fn handle_timeline_import(req: Request) -> Result, Ap ) } +async fn handle_tenant_timeline_locate( + service: Arc, + req: Request, +) -> Result, ApiError> { + let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?; + let timeline_id: TimelineId = parse_request_param(&req, "timeline_id")?; + + check_permissions(&req, Scope::Admin)?; + maybe_rate_limit(&req, tenant_id).await; + + match maybe_forward(req).await { + ForwardOutcome::Forwarded(res) => { + return res; + } + ForwardOutcome::NotForwarded(_req) => {} + }; + + json_response( + StatusCode::OK, + service + .tenant_timeline_locate(tenant_id, timeline_id) + .await?, + ) +} + async fn handle_tenants_dump(req: Request) -> Result, ApiError> { check_permissions(&req, Scope::Admin)?; @@ -1478,7 +1539,7 @@ async fn handle_ready(req: Request) -> Result, ApiError> { impl From for ApiError { fn from(value: ReconcileError) -> Self { - ApiError::Conflict(format!("Reconciliation error: {}", value)) + ApiError::Conflict(format!("Reconciliation error: {value}")) } } @@ -2009,10 +2070,10 @@ pub fn make_router( router .data(Arc::new(HttpState::new(service, auth, 
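// The tombstone handlers above are registered as debug routes further down:
//   GET    /debug/v1/tombstone           -> handle_tombstone_list
//   DELETE /debug/v1/tombstone/:node_id  -> handle_tombstone_delete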
build_info))) + // Non-prefixed generic endpoints (status, metrics, profiling) .get("/metrics", |r| { named_request_span(r, measured_metrics_handler, RequestName("metrics")) }) - // Non-prefixed generic endpoints (status, metrics, profiling) .get("/status", |r| { named_request_span(r, handle_status, RequestName("status")) }) @@ -2062,6 +2123,20 @@ pub fn make_router( .post("/debug/v1/node/:node_id/drop", |r| { named_request_span(r, handle_node_drop, RequestName("debug_v1_node_drop")) }) + .delete("/debug/v1/tombstone/:node_id", |r| { + named_request_span( + r, + handle_tombstone_delete, + RequestName("debug_v1_tombstone_delete"), + ) + }) + .get("/debug/v1/tombstone", |r| { + named_request_span( + r, + handle_tombstone_list, + RequestName("debug_v1_tombstone_list"), + ) + }) .post("/debug/v1/tenant/:tenant_id/import", |r| { named_request_span( r, @@ -2089,6 +2164,16 @@ ) }, ) + .get( + "/debug/v1/tenant/:tenant_id/timeline/:timeline_id/locate", + |r| { + tenant_service_handler( + r, + handle_tenant_timeline_locate, + RequestName("v1_tenant_timeline_locate"), + ) + }, + ) .get("/debug/v1/scheduler", |r| { named_request_span(r, handle_scheduler_dump, RequestName("debug_v1_scheduler")) }) diff --git a/storage_controller/src/main.rs b/storage_controller/src/main.rs index 2eea2f9d10..296a98e620 100644 --- a/storage_controller/src/main.rs +++ b/storage_controller/src/main.rs @@ -5,17 +5,22 @@ use std::time::Duration; use anyhow::{Context, anyhow}; use camino::Utf8PathBuf; + +#[cfg(feature = "testing")] +use clap::ArgAction; use clap::Parser; use futures::future::OptionFuture; use http_utils::tls_certs::ReloadingCertificateResolver; use hyper0::Uri; use metrics::BuildInfo; use metrics::launch_timestamp::LaunchTimestamp; +use pageserver_api::config::PostHogConfig; use reqwest::Certificate; use storage_controller::http::make_router; use storage_controller::metrics::preinitialize_metrics; use storage_controller::persistence::Persistence; use storage_controller::service::chaos_injector::ChaosInjector; +use storage_controller::service::feature_flag::FeatureFlagService; use storage_controller::service::{ Config, HEARTBEAT_INTERVAL_DEFAULT, LONG_RECONCILE_THRESHOLD_DEFAULT, MAX_OFFLINE_INTERVAL_DEFAULT, MAX_WARMING_UP_INTERVAL_DEFAULT, @@ -207,6 +212,19 @@ struct Cli { /// the compute notification directly (instead of via control plane). #[arg(long, default_value = "false")] use_local_compute_notifications: bool, + + /// Number of safekeepers to choose for a timeline when creating it. + /// Safekeepers will be chosen from different availability zones. + /// This option exists primarily for testing purposes. + #[arg(long, default_value = "3", value_parser = clap::value_parser!(i64).range(1..))] + timeline_safekeeper_count: i64, + + /// When set, actively checks and initiates heatmap downloads/uploads during reconciliation. + /// This speeds up migrations by avoiding the default wait for the heatmap download interval. + /// Primarily useful for testing to reduce test execution time.
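// Example invocation (flag names assumed from the clap field names):
//   storage_controller --dev --timeline-safekeeper-count=1 --kick-secondary-downloads=false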
+ #[cfg(feature = "testing")] + #[arg(long, default_value = "true", action=ArgAction::Set)] + kick_secondary_downloads: bool, } enum StrictMode { @@ -236,6 +254,8 @@ struct Secrets { peer_jwt_token: Option, } +const POSTHOG_CONFIG_ENV: &str = "POSTHOG_CONFIG"; + impl Secrets { const DATABASE_URL_ENV: &'static str = "DATABASE_URL"; const PAGESERVER_JWT_TOKEN_ENV: &'static str = "PAGESERVER_JWT_TOKEN"; @@ -371,6 +391,11 @@ async fn async_main() -> anyhow::Result<()> { StrictMode::Strict if args.use_local_compute_notifications => { anyhow::bail!("`--use-local-compute-notifications` is only permitted in `--dev` mode"); } + StrictMode::Strict if args.timeline_safekeeper_count < 3 => { + anyhow::bail!( + "Running with less than 3 safekeepers per timeline is only permitted in `--dev` mode" + ); + } StrictMode::Strict => { tracing::info!("Starting in strict mode: configuration is OK.") } @@ -388,6 +413,18 @@ async fn async_main() -> anyhow::Result<()> { None => Vec::new(), }; + let posthog_config = if let Ok(json) = std::env::var(POSTHOG_CONFIG_ENV) { + let res: Result = serde_json::from_str(&json); + if let Ok(config) = res { + Some(config) + } else { + tracing::warn!("Invalid posthog config: {json}"); + None + } + } else { + None + }; + let config = Config { pageserver_jwt_token: secrets.pageserver_jwt_token, safekeeper_jwt_token: secrets.safekeeper_jwt_token, @@ -433,6 +470,10 @@ async fn async_main() -> anyhow::Result<()> { ssl_ca_certs, timelines_onto_safekeepers: args.timelines_onto_safekeepers, use_local_compute_notifications: args.use_local_compute_notifications, + timeline_safekeeper_count: args.timeline_safekeeper_count, + posthog_config: posthog_config.clone(), + #[cfg(feature = "testing")] + kick_secondary_downloads: args.kick_secondary_downloads, }; // Validate that we can connect to the database @@ -513,6 +554,23 @@ async fn async_main() -> anyhow::Result<()> { ) }); + let feature_flag_task = if let Some(posthog_config) = posthog_config { + let service = service.clone(); + let cancel = CancellationToken::new(); + let cancel_bg = cancel.clone(); + let task = tokio::task::spawn( + async move { + let feature_flag_service = FeatureFlagService::new(service, posthog_config); + let feature_flag_service = Arc::new(feature_flag_service); + feature_flag_service.run(cancel_bg).await + } + .instrument(tracing::info_span!("feature_flag_service")), + ); + Some((task, cancel)) + } else { + None + }; + // Wait until we receive a signal let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())?; let mut sigquit = tokio::signal::unix::signal(SignalKind::quit())?; @@ -560,6 +618,12 @@ async fn async_main() -> anyhow::Result<()> { chaos_jh.await.ok(); } + // If we were running the feature flag service, stop that so that we're not calling into Service while it shuts down + if let Some((feature_flag_task, feature_flag_cancel)) = feature_flag_task { + feature_flag_cancel.cancel(); + feature_flag_task.await.ok(); + } + service.shutdown().await; tracing::info!("Service shutdown complete"); diff --git a/storage_controller/src/metrics.rs b/storage_controller/src/metrics.rs index ccdbcad139..07713c3fbc 100644 --- a/storage_controller/src/metrics.rs +++ b/storage_controller/src/metrics.rs @@ -97,7 +97,7 @@ pub(crate) struct StorageControllerMetricGroup { /// Count of HTTP requests to the safekeeper that resulted in an error, /// broken down by the safekeeper node id, request name and method pub(crate) storage_controller_safekeeper_request_error: - measured::CounterVec, + measured::CounterVec, /// 
Latency of HTTP requests to the pageserver, broken down by pageserver /// node id, request name and method. This includes both successful and unsuccessful @@ -111,7 +111,7 @@ pub(crate) struct StorageControllerMetricGroup { /// requests. #[metric(metadata = histogram::Thresholds::exponential_buckets(0.1, 2.0))] pub(crate) storage_controller_safekeeper_request_latency: - measured::HistogramVec, + measured::HistogramVec, /// Count of pass-through HTTP requests to the pageserver that resulted in an error, /// broken down by the pageserver node id, request name and method @@ -136,16 +136,17 @@ pub(crate) storage_controller_leadership_status: measured::GaugeVec, - /// HTTP request status counters for handled requests + /// Indicator of stuck (long-running) reconciles, broken down by tenant, shard and sequence. + /// The metric is automatically removed once the reconciliation completes. pub(crate) storage_controller_reconcile_long_running: measured::CounterVec, /// Indicator of safekeeper reconciler queue depth, broken down by safekeeper, excluding ongoing reconciles. - pub(crate) storage_controller_safkeeper_reconciles_queued: + pub(crate) storage_controller_safekeeper_reconciles_queued: measured::GaugeVec, /// Indicator of completed safekeeper reconciles, broken down by safekeeper. - pub(crate) storage_controller_safkeeper_reconciles_complete: + pub(crate) storage_controller_safekeeper_reconciles_complete: measured::CounterVec, } @@ -218,6 +219,16 @@ pub(crate) struct PageserverRequestLabelGroup<'a> { pub(crate) method: Method, } +#[derive(measured::LabelGroup, Clone)] +#[label(set = SafekeeperRequestLabelGroupSet)] +pub(crate) struct SafekeeperRequestLabelGroup<'a> { + #[label(dynamic_with = lasso::ThreadedRodeo, default)] + pub(crate) safekeeper_id: &'a str, + #[label(dynamic_with = lasso::ThreadedRodeo, default)] + pub(crate) path: &'a str, + pub(crate) method: Method, +} + #[derive(measured::LabelGroup)] #[label(set = DatabaseQueryErrorLabelGroupSet)] pub(crate) struct DatabaseQueryErrorLabelGroup { diff --git a/storage_controller/src/node.rs b/storage_controller/src/node.rs index e180c49b43..cba007d75f 100644 --- a/storage_controller/src/node.rs +++ b/storage_controller/src/node.rs @@ -2,7 +2,7 @@ use std::str::FromStr; use std::time::Duration; use pageserver_api::controller_api::{ - AvailabilityZone, NodeAvailability, NodeDescribeResponse, NodeRegisterRequest, + AvailabilityZone, NodeAvailability, NodeDescribeResponse, NodeLifecycle, NodeRegisterRequest, NodeSchedulingPolicy, TenantLocateResponseShard, }; use pageserver_api::shard::TenantShardId; @@ -29,6 +29,7 @@ pub(crate) struct Node { availability: NodeAvailability, scheduling: NodeSchedulingPolicy, + lifecycle: NodeLifecycle, listen_http_addr: String, listen_http_port: u16, @@ -36,6 +37,8 @@ listen_pg_addr: String, listen_pg_port: u16, + listen_grpc_addr: Option, + listen_grpc_port: Option, availability_zone_id: AvailabilityZone, @@ -99,8 +102,8 @@ impl Node { self.id == register_req.node_id && self.listen_http_addr == register_req.listen_http_addr && self.listen_http_port == register_req.listen_http_port - // Note: listen_https_port may change. See [`Self::need_update`] for mode details. - // && self.listen_https_port == register_req.listen_https_port + // Note: HTTPS and gRPC addresses may change, to allow for migrations. See + // [`Self::need_update`] for more details.
&& self.listen_pg_addr == register_req.listen_pg_addr && self.listen_pg_port == register_req.listen_pg_port && self.availability_zone_id == register_req.availability_zone_id @@ -108,9 +111,10 @@ impl Node { // Do we need to update an existing record in DB on this registration request? pub(crate) fn need_update(&self, register_req: &NodeRegisterRequest) -> bool { - // listen_https_port is checked here because it may change during migration to https. - // After migration, this check may be moved to registration_match. + // These are checked here, since they may change before we're fully migrated. self.listen_https_port != register_req.listen_https_port + || self.listen_grpc_addr != register_req.listen_grpc_addr + || self.listen_grpc_port != register_req.listen_grpc_port } /// For a shard located on this node, populate a response object @@ -124,6 +128,8 @@ impl Node { listen_https_port: self.listen_https_port, listen_pg_addr: self.listen_pg_addr.clone(), listen_pg_port: self.listen_pg_port, + listen_grpc_addr: self.listen_grpc_addr.clone(), + listen_grpc_port: self.listen_grpc_port, } } @@ -210,6 +216,8 @@ impl Node { listen_https_port: Option, listen_pg_addr: String, listen_pg_port: u16, + listen_grpc_addr: Option, + listen_grpc_port: Option, availability_zone_id: AvailabilityZone, use_https: bool, ) -> anyhow::Result { @@ -220,6 +228,10 @@ impl Node { ); } + if listen_grpc_addr.is_some() != listen_grpc_port.is_some() { + anyhow::bail!("cannot create node {id}: must specify both gRPC address and port"); + } + Ok(Self { id, listen_http_addr, @@ -227,7 +239,10 @@ impl Node { listen_https_port, listen_pg_addr, listen_pg_port, + listen_grpc_addr, + listen_grpc_port, scheduling: NodeSchedulingPolicy::Active, + lifecycle: NodeLifecycle::Active, availability: NodeAvailability::Offline, availability_zone_id, use_https, @@ -239,11 +254,14 @@ impl Node { NodePersistence { node_id: self.id.0 as i64, scheduling_policy: self.scheduling.into(), + lifecycle: self.lifecycle.into(), listen_http_addr: self.listen_http_addr.clone(), listen_http_port: self.listen_http_port as i32, listen_https_port: self.listen_https_port.map(|x| x as i32), listen_pg_addr: self.listen_pg_addr.clone(), listen_pg_port: self.listen_pg_port as i32, + listen_grpc_addr: self.listen_grpc_addr.clone(), + listen_grpc_port: self.listen_grpc_port.map(|port| port as i32), availability_zone_id: self.availability_zone_id.0.clone(), } } @@ -257,17 +275,27 @@ impl Node { ); } + if np.listen_grpc_addr.is_some() != np.listen_grpc_port.is_some() { + anyhow::bail!( + "can't load node {}: must specify both gRPC address and port", + np.node_id + ); + } + Ok(Self { id: NodeId(np.node_id as u64), // At startup we consider a node offline until proven otherwise. 
availability: NodeAvailability::Offline, scheduling: NodeSchedulingPolicy::from_str(&np.scheduling_policy) .expect("Bad scheduling policy in DB"), + lifecycle: NodeLifecycle::from_str(&np.lifecycle).expect("Bad lifecycle in DB"), listen_http_addr: np.listen_http_addr, listen_http_port: np.listen_http_port as u16, listen_https_port: np.listen_https_port.map(|x| x as u16), listen_pg_addr: np.listen_pg_addr, listen_pg_port: np.listen_pg_port as u16, + listen_grpc_addr: np.listen_grpc_addr, + listen_grpc_port: np.listen_grpc_port.map(|port| port as u16), availability_zone_id: AvailabilityZone(np.availability_zone_id), use_https, cancel: CancellationToken::new(), @@ -357,6 +385,8 @@ impl Node { listen_https_port: self.listen_https_port, listen_pg_addr: self.listen_pg_addr.clone(), listen_pg_port: self.listen_pg_port, + listen_grpc_addr: self.listen_grpc_addr.clone(), + listen_grpc_port: self.listen_grpc_port, } } } diff --git a/storage_controller/src/pageserver_client.rs b/storage_controller/src/pageserver_client.rs index 817409e112..d6fe173eb3 100644 --- a/storage_controller/src/pageserver_client.rs +++ b/storage_controller/src/pageserver_client.rs @@ -376,4 +376,13 @@ impl PageserverClient { .await ) } + + pub(crate) async fn update_feature_flag_spec(&self, spec: String) -> Result<()> { + measured_request!( + "update_feature_flag_spec", + crate::metrics::Method::Post, + &self.node_id_label, + self.inner.update_feature_flag_spec(spec).await + ) + } } diff --git a/storage_controller/src/persistence.rs b/storage_controller/src/persistence.rs index 052c0f02eb..2948e9019f 100644 --- a/storage_controller/src/persistence.rs +++ b/storage_controller/src/persistence.rs @@ -19,7 +19,7 @@ use futures::FutureExt; use futures::future::BoxFuture; use itertools::Itertools; use pageserver_api::controller_api::{ - AvailabilityZone, MetadataHealthRecord, NodeSchedulingPolicy, PlacementPolicy, + AvailabilityZone, MetadataHealthRecord, NodeLifecycle, NodeSchedulingPolicy, PlacementPolicy, SafekeeperDescribeResponse, ShardSchedulingPolicy, SkSchedulingPolicy, }; use pageserver_api::models::{ShardImportStatus, TenantConfig}; @@ -102,6 +102,7 @@ pub(crate) enum DatabaseOperation { UpdateNode, DeleteNode, ListNodes, + ListTombstones, BeginShardSplit, CompleteShardSplit, AbortShardSplit, @@ -357,6 +358,8 @@ impl Persistence { } /// When a node is first registered, persist it before using it for anything + /// If the provided node_id already exists, this returns an error. + /// The common case is a node marked for deletion attempting to re-register. pub(crate) async fn insert_node(&self, node: &Node) -> DatabaseResult<()> { let np = &node.to_persistent(); self.with_measured_conn(DatabaseOperation::InsertNode, move |conn| { @@ -373,19 +376,41 @@ impl Persistence { /// At startup, populate the list of nodes which our shards may be placed on pub(crate) async fn list_nodes(&self) -> DatabaseResult> { - let nodes: Vec = self + use crate::schema::nodes::dsl::*; + + let result: Vec = self .with_measured_conn(DatabaseOperation::ListNodes, move |conn| { Box::pin(async move { Ok(crate::schema::nodes::table + .filter(lifecycle.ne(String::from(NodeLifecycle::Deleted))) .load::(conn) .await?)
}) }) .await?; - tracing::info!("list_nodes: loaded {} nodes", nodes.len()); + tracing::info!("list_nodes: loaded {} nodes", result.len()); - Ok(nodes) + Ok(result) + } + + pub(crate) async fn list_tombstones(&self) -> DatabaseResult> { + use crate::schema::nodes::dsl::*; + + let result: Vec = self + .with_measured_conn(DatabaseOperation::ListTombstones, move |conn| { + Box::pin(async move { + Ok(crate::schema::nodes::table + .filter(lifecycle.eq(String::from(NodeLifecycle::Deleted))) + .load::(conn) + .await?) + }) + }) + .await?; + + tracing::info!("list_tombstones: loaded {} nodes", result.len()); + + Ok(result) } pub(crate) async fn update_node( @@ -404,6 +429,7 @@ impl Persistence { Box::pin(async move { let updated = diesel::update(nodes) .filter(node_id.eq(input_node_id.0 as i64)) + .filter(lifecycle.ne(String::from(NodeLifecycle::Deleted))) .set(values) .execute(conn) .await?; @@ -447,6 +473,55 @@ impl Persistence { .await } + /// Tombstone is a special state where the node is not deleted from the database, + /// but it is not available for use. + /// The main reason for it is to prevent a flaky node from re-registering. + pub(crate) async fn set_tombstone(&self, del_node_id: NodeId) -> DatabaseResult<()> { + use crate::schema::nodes::dsl::*; + self.update_node( + del_node_id, + lifecycle.eq(String::from(NodeLifecycle::Deleted)), + ) + .await + } + + pub(crate) async fn delete_node(&self, del_node_id: NodeId) -> DatabaseResult<()> { + use crate::schema::nodes::dsl::*; + self.with_measured_conn(DatabaseOperation::DeleteNode, move |conn| { + Box::pin(async move { + // A node can be hard deleted only if it has a tombstone, + // so we check that its lifecycle is set to deleted. + let node_to_delete = nodes + .filter(node_id.eq(del_node_id.0 as i64)) + .first::(conn) + .await + .optional()?; + + if let Some(np) = node_to_delete { + let lc = NodeLifecycle::from_str(&np.lifecycle).map_err(|e| { + DatabaseError::Logical(format!( + "Node {del_node_id} has invalid lifecycle: {e}" + )) + })?; + + if lc != NodeLifecycle::Deleted { + return Err(DatabaseError::Logical(format!( + "Node {del_node_id} was not soft deleted before, cannot hard delete it" + ))); + } + + diesel::delete(nodes) + .filter(node_id.eq(del_node_id.0 as i64)) + .execute(conn) + .await?; + } + + Ok(()) + }) + }) + .await + } + /// At startup, load the high level state for shards, such as their config + policy. This will /// be enriched at runtime with state discovered on pageservers. /// @@ -543,21 +618,6 @@ impl Persistence { .await } - pub(crate) async fn delete_node(&self, del_node_id: NodeId) -> DatabaseResult<()> { - use crate::schema::nodes::dsl::*; - self.with_measured_conn(DatabaseOperation::DeleteNode, move |conn| { - Box::pin(async move { - diesel::delete(nodes) - .filter(node_id.eq(del_node_id.0 as i64)) - .execute(conn) - .await?; - - Ok(()) - }) - }) - .await - } - /// When a tenant invokes the /re-attach API, this function is responsible for doing an efficient /// batched increment of the generations of all tenants whose generation_pageserver is equal to /// the node that called /re-attach.
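The persistence changes above, together with the re-attach guard in the next hunk, implement a two-phase removal: `set_tombstone` marks the node's lifecycle as deleted (soft delete), and only then may `delete_node` drop the row (hard delete); tombstoned nodes are filtered out of `list_nodes` and refused on re-attach. A minimal self-contained sketch of these ordering rules, using a hypothetical `NodeLifecycle` enum standing in for the real one from `pageserver_api`:

```rust
// Sketch only: hypothetical stand-in for pageserver_api's NodeLifecycle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NodeLifecycle {
    Active,
    Deleted, // tombstone: row kept in the DB, node unusable
}

#[derive(Debug, PartialEq, Eq)]
enum LifecycleError {
    NotTombstoned, // hard delete attempted before soft delete
    Tombstoned,    // re-attach/re-register attempted after soft delete
}

// Hard deletion is only legal once the tombstone is in place.
fn check_hard_delete(lifecycle: NodeLifecycle) -> Result<(), LifecycleError> {
    match lifecycle {
        NodeLifecycle::Deleted => Ok(()),
        NodeLifecycle::Active => Err(LifecycleError::NotTombstoned),
    }
}

// Re-attach (and registration) must be refused for tombstoned nodes.
fn check_re_attach(lifecycle: NodeLifecycle) -> Result<(), LifecycleError> {
    match lifecycle {
        NodeLifecycle::Active => Ok(()),
        NodeLifecycle::Deleted => Err(LifecycleError::Tombstoned),
    }
}

fn main() {
    // A live node cannot be hard deleted; a tombstoned one cannot come back.
    assert_eq!(check_hard_delete(NodeLifecycle::Active), Err(LifecycleError::NotTombstoned));
    assert_eq!(check_hard_delete(NodeLifecycle::Deleted), Ok(()));
    assert_eq!(check_re_attach(NodeLifecycle::Deleted), Err(LifecycleError::Tombstoned));
}
```

The point of the intermediate tombstone state is that a flaky or decommissioned node cannot silently re-register in the window between being drained and being forgotten.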
@@ -571,6 +631,19 @@ impl Persistence { let updated = self .with_measured_conn(DatabaseOperation::ReAttach, move |conn| { Box::pin(async move { + // Check if the node is not marked as deleted + let deleted_node: i64 = nodes + .filter(node_id.eq(input_node_id.0 as i64)) + .filter(lifecycle.eq(String::from(NodeLifecycle::Deleted))) + .count() + .get_result(conn) + .await?; + if deleted_node > 0 { + return Err(DatabaseError::Logical(format!( + "Node {input_node_id} is marked as deleted, re-attach is not allowed" + ))); + } + let rows_updated = diesel::update(tenant_shards) .filter(generation_pageserver.eq(input_node_id.0 as i64)) .set(generation.eq(generation + 1)) @@ -927,7 +1000,7 @@ impl Persistence { .execute(conn).await?; if u8::try_from(updated) .map_err(|_| DatabaseError::Logical( - format!("Overflow existing shard count {} while splitting", updated)) + format!("Overflow existing shard count {updated} while splitting")) )? != old_shard_count.count() { // Perhaps a deletion or another split raced with this attempt to split, mutating // the parent shards that we intend to split. In this case the split request should fail. @@ -1267,8 +1340,7 @@ impl Persistence { if inserted_updated != 1 { return Err(DatabaseError::Logical(format!( - "unexpected number of rows ({})", - inserted_updated + "unexpected number of rows ({inserted_updated})" ))); } @@ -1330,8 +1402,7 @@ impl Persistence { 0 => Ok(false), 1 => Ok(true), _ => Err(DatabaseError::Logical(format!( - "unexpected number of rows ({})", - inserted_updated + "unexpected number of rows ({inserted_updated})" ))), } }) @@ -1400,8 +1471,7 @@ impl Persistence { 0 => Ok(()), 1 => Ok(()), _ => Err(DatabaseError::Logical(format!( - "unexpected number of rows ({})", - updated + "unexpected number of rows ({updated})" ))), } }) @@ -1494,8 +1564,7 @@ impl Persistence { 0 => Ok(false), 1 => Ok(true), _ => Err(DatabaseError::Logical(format!( - "unexpected number of rows ({})", - inserted_updated + "unexpected number of rows ({inserted_updated})" ))), } }) @@ -2048,6 +2117,9 @@ pub(crate) struct NodePersistence { pub(crate) listen_pg_port: i32, pub(crate) availability_zone_id: String, pub(crate) listen_https_port: Option, + pub(crate) lifecycle: String, + pub(crate) listen_grpc_addr: Option, + pub(crate) listen_grpc_port: Option, } /// Tenant metadata health status that are stored durably. 
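The `scheduler.rs` hunk a few files below changes AZ selection from comparing raw home-shard counts to comparing per-node normalized counts. A worked example of why the normalization matters (illustrative numbers only, not taken from the source):

```rust
fn main() {
    // (az id, home_shard_count, node_count) -- illustrative numbers only.
    let azs = [("az-a", 10usize, 10usize), ("az-b", 6, 3)];

    // Scale by the largest node count before dividing so that integer
    // division doesn't collapse small shard counts to zero.
    let max_node_count = azs.iter().map(|az| az.2).max().unwrap_or(0);

    let selected = azs
        .iter()
        .min_by_key(|(az_id, shards, nodes)| ((shards * max_node_count) / nodes, *az_id))
        .unwrap();

    // az-a scores (10 * 10) / 10 = 10 per node; az-b scores (6 * 10) / 3 = 20.
    // az-a wins despite hosting more shards in absolute terms, because it has
    // more nodes to spread them over.
    assert_eq!(selected.0, "az-a");
}
```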
diff --git a/storage_controller/src/reconciler.rs b/storage_controller/src/reconciler.rs index b03a6dae04..92844c9c7b 100644 --- a/storage_controller/src/reconciler.rs +++ b/storage_controller/src/reconciler.rs @@ -856,6 +856,7 @@ impl Reconciler { &self.shard, &self.config, &self.placement_policy, + self.intent.secondary.len(), ); match self.observed.locations.get(&node.get_id()) { Some(conf) if conf.conf.as_ref() == Some(&wanted_conf) => { @@ -1235,11 +1236,11 @@ pub(crate) fn attached_location_conf( shard: &ShardIdentity, config: &TenantConfig, policy: &PlacementPolicy, + secondary_count: usize, ) -> LocationConfig { let has_secondaries = match policy { - PlacementPolicy::Attached(0) | PlacementPolicy::Detached | PlacementPolicy::Secondary => { - false - } + PlacementPolicy::Detached | PlacementPolicy::Secondary => false, + PlacementPolicy::Attached(0) => secondary_count > 0, PlacementPolicy::Attached(_) => true, }; diff --git a/storage_controller/src/safekeeper_client.rs b/storage_controller/src/safekeeper_client.rs index 1f3ea96d96..bcf223c731 100644 --- a/storage_controller/src/safekeeper_client.rs +++ b/storage_controller/src/safekeeper_client.rs @@ -5,7 +5,7 @@ use safekeeper_client::mgmt_api::{Client, Result}; use utils::id::{NodeId, TenantId, TimelineId}; use utils::logging::SecretString; -use crate::metrics::PageserverRequestLabelGroup; +use crate::metrics::SafekeeperRequestLabelGroup; /// Thin wrapper around [`safekeeper_client::mgmt_api::Client`]. It allows the storage /// controller to collect metrics in a non-intrusive manner. @@ -19,8 +19,8 @@ pub(crate) struct SafekeeperClient { macro_rules! measured_request { ($name:literal, $method:expr, $node_id: expr, $invoke:expr) => {{ - let labels = PageserverRequestLabelGroup { - pageserver_id: $node_id, + let labels = SafekeeperRequestLabelGroup { + safekeeper_id: $node_id, path: $name, method: $method, }; @@ -35,7 +35,7 @@ macro_rules! measured_request { if res.is_err() { let error_counters = &crate::metrics::METRICS_REGISTRY .metrics_group - .storage_controller_pageserver_request_error; + .storage_controller_safekeeper_request_error; error_counters.inc(labels) } diff --git a/storage_controller/src/scheduler.rs b/storage_controller/src/scheduler.rs index 773373391e..b86b4dfab1 100644 --- a/storage_controller/src/scheduler.rs +++ b/storage_controller/src/scheduler.rs @@ -23,7 +23,7 @@ pub enum ScheduleError { impl From for ApiError { fn from(value: ScheduleError) -> Self { - ApiError::Conflict(format!("Scheduling error: {}", value)) + ApiError::Conflict(format!("Scheduling error: {value}")) } } @@ -825,6 +825,7 @@ impl Scheduler { struct AzScore { home_shard_count: usize, scheduleable: bool, + node_count: usize, } let mut azs: HashMap<&AvailabilityZone, AzScore> = HashMap::new(); @@ -832,6 +833,7 @@ impl Scheduler { let az = azs.entry(&node.az).or_default(); az.home_shard_count += node.home_shard_count; az.scheduleable |= matches!(node.may_schedule, MaySchedule::Yes(_)); + az.node_count += 1; } // If any AZs are schedulable, then filter out the non-schedulable ones (i.e. AZs where @@ -840,10 +842,20 @@ impl Scheduler { azs.retain(|_, i| i.scheduleable); } + // For scoring, we multiply shard counts up by the max node count before dividing by + // each AZ's own node count, to get a normalized per-node score that doesn't collapse + // to zero when the absolute shard count is less than the node count.
+ let max_node_count = azs.values().map(|i| i.node_count).max().unwrap_or(0); + // Find the AZ with the lowest number of shards currently allocated Some( azs.into_iter() - .min_by_key(|i| (i.1.home_shard_count, i.0)) + .min_by_key(|i| { + ( + (i.1.home_shard_count * max_node_count) / i.1.node_count, + i.0, + ) + }) .unwrap() .0 .clone(), @@ -891,7 +903,7 @@ impl Scheduler { /// rigorously updating them on every change. pub(crate) fn update_metrics(&self) { for (node_id, node) in &self.nodes { - let node_id_str = format!("{}", node_id); + let node_id_str = format!("{node_id}"); let label_group = NodeLabelGroup { az: &node.az.0, node_id: &node_id_str, @@ -945,6 +957,8 @@ pub(crate) mod test_utils { None, format!("pghost-{i}"), 5432 + i as u16, + Some(format!("grpchost-{i}")), + Some(51051 + i as u16), az_iter .next() .cloned() @@ -1312,7 +1326,7 @@ mod tests { .map(|(node_id, node)| (node_id, node.home_shard_count)) .collect::>(); node_home_counts.sort_by_key(|i| i.0); - eprintln!("Selected {}, vs nodes {:?}", preferred_az, node_home_counts); + eprintln!("Selected {preferred_az}, vs nodes {node_home_counts:?}"); let tenant_shard_id = TenantShardId { tenant_id: TenantId::generate(), diff --git a/storage_controller/src/schema.rs b/storage_controller/src/schema.rs index 20be9bb5ca..312f7e0b0e 100644 --- a/storage_controller/src/schema.rs +++ b/storage_controller/src/schema.rs @@ -33,6 +33,9 @@ diesel::table! { listen_pg_port -> Int4, availability_zone_id -> Varchar, listen_https_port -> Nullable, + lifecycle -> Varchar, + listen_grpc_addr -> Nullable, + listen_grpc_port -> Nullable, } } diff --git a/storage_controller/src/service.rs b/storage_controller/src/service.rs index 790797bae2..b4dfd01249 100644 --- a/storage_controller/src/service.rs +++ b/storage_controller/src/service.rs @@ -1,5 +1,6 @@ pub mod chaos_injector; mod context_iterator; +pub mod feature_flag; pub(crate) mod safekeeper_reconciler; mod safekeeper_service; @@ -25,6 +26,7 @@ use futures::stream::FuturesUnordered; use http_utils::error::ApiError; use hyper::Uri; use itertools::Itertools; +use pageserver_api::config::PostHogConfig; use pageserver_api::controller_api::{ AvailabilityZone, MetadataHealthRecord, MetadataHealthUpdateRequest, NodeAvailability, NodeRegisterRequest, NodeSchedulingPolicy, NodeShard, NodeShardResponse, PlacementPolicy, @@ -166,6 +168,7 @@ enum NodeOperations { Register, Configure, Delete, + DeleteTombstone, } /// The leadership status for the storage controller process. @@ -259,7 +262,7 @@ fn passthrough_api_error(node: &Node, e: mgmt_api::Error) -> ApiError { // Presume errors receiving body are connectivity/availability issues except for decoding errors let src_str = err.source().map(|e| e.to_string()).unwrap_or_default(); ApiError::ResourceUnavailable( - format!("{node} error receiving error body: {err} {}", src_str).into(), + format!("{node} error receiving error body: {err} {src_str}").into(), ) } mgmt_api::Error::ApiError(StatusCode::NOT_FOUND, msg) => { @@ -465,6 +468,16 @@ pub struct Config { pub timelines_onto_safekeepers: bool, pub use_local_compute_notifications: bool, + + /// Number of safekeepers to choose for a timeline when creating it. + /// Safekeepers will be chosen from different availability zones.
+ pub timeline_safekeeper_count: i64, + + /// PostHog integration config + pub posthog_config: Option, + + #[cfg(feature = "testing")] + pub kick_secondary_downloads: bool, } impl From for ApiError { @@ -663,7 +676,7 @@ impl std::fmt::Display for StopReconciliationsReason { Self::ShuttingDown => "Shutting down", Self::SteppingDown => "Stepping down", }; - write!(writer, "{}", s) + write!(writer, "{s}") } } @@ -1107,7 +1120,8 @@ impl Service { observed } - /// Used during [`Self::startup_reconcile`]: detach a list of unknown-to-us tenants from pageservers. + /// Used during [`Self::startup_reconcile`] and shard splits: detach a list of unknown-to-us + /// tenants from pageservers. /// /// This is safe to run in the background, because if we don't have this TenantShardId in our map of /// tenants, then it is probably something incompletely deleted before: we will not fight with any @@ -1681,6 +1695,8 @@ impl Service { None, "".to_string(), 123, + None, + None, AvailabilityZone("test_az".to_string()), false, ) @@ -2056,6 +2072,7 @@ impl Service { &tenant_shard.shard, &tenant_shard.config, &PlacementPolicy::Attached(0), + tenant_shard.intent.get_secondary().len(), )), }, )]); @@ -2265,6 +2282,7 @@ impl Service { // fail, and start from scratch, so it doesn't make sense for us to try and preserve // the stale/multi states at this point. mode: LocationConfigMode::AttachedSingle, + stripe_size: shard.shard.stripe_size, }); shard.generation = std::cmp::max(shard.generation, Some(new_gen)); @@ -2298,6 +2316,7 @@ impl Service { id: *tenant_shard_id, r#gen: None, mode: LocationConfigMode::Secondary, + stripe_size: shard.shard.stripe_size, }); // We must not update observed, because we have no guarantee that our @@ -5264,7 +5283,7 @@ impl Service { shard_params, result .iter() - .map(|s| format!("{:?}", s)) + .map(|s| format!("{s:?}")) .collect::>() .join(",") ); @@ -5595,7 +5614,15 @@ impl Service { for parent_id in parent_ids { let child_ids = parent_id.split(new_shard_count); - let (pageserver, generation, policy, parent_ident, config, preferred_az) = { + let ( + pageserver, + generation, + policy, + parent_ident, + config, + preferred_az, + secondary_count, + ) = { let mut old_state = tenants .remove(&parent_id) .expect("It was present, we just split it"); @@ -5615,6 +5642,7 @@ impl Service { old_state.shard, old_state.config.clone(), old_state.preferred_az().cloned(), + old_state.intent.get_secondary().len(), ) }; @@ -5636,6 +5664,7 @@ impl Service { &child_shard, &config, &policy, + secondary_count, )), }, ); @@ -6177,7 +6206,7 @@ impl Service { }, ) .await - .map_err(|e| ApiError::Conflict(format!("Failed to split {}: {}", parent_id, e)))?; + .map_err(|e| ApiError::Conflict(format!("Failed to split {parent_id}: {e}")))?; fail::fail_point!("shard-split-post-remote", |_| Err(ApiError::Conflict( "failpoint".to_string() @@ -6194,7 +6223,7 @@ impl Service { response .new_shards .iter() - .map(|s| format!("{:?}", s)) + .map(|s| format!("{s:?}")) .collect::>() .join(",") ); @@ -6210,7 +6239,11 @@ impl Service { } } - pausable_failpoint!("shard-split-pre-complete"); + fail::fail_point!("shard-split-pre-complete", |_| Err(ApiError::Conflict( + "failpoint".to_string() + ))); + + pausable_failpoint!("shard-split-pre-complete-pause"); // TODO: if the pageserver restarted concurrently with our split API call, // the actual generation of the child shard might differ from the generation @@ -6232,6 +6265,15 @@ impl Service { let (response, child_locations, waiters) = 
self.tenant_shard_split_commit_inmem(tenant_id, new_shard_count, new_stripe_size); + // Notify all page servers to detach and clean up the old shards because they will no longer + // be needed. This is best-effort: if it fails, it will be cleaned up on a subsequent + // Pageserver re-attach/startup. + let shards_to_cleanup = targets + .iter() + .map(|target| (target.parent_id, target.node.get_id())) + .collect(); + self.cleanup_locations(shards_to_cleanup).await; + // Send compute notifications for all the new shards let mut failed_notifications = Vec::new(); for (child_id, child_ps, stripe_size) in child_locations { @@ -6634,6 +6676,8 @@ impl Service { /// This is for debug/support only: assuming tenant data is already present in S3, we "create" a /// tenant with a very high generation number so that it will see the existing data. + /// It does not create timelines on safekeepers, because they might already exist on some + /// safekeeper set. So, the timelines are not storcon-managed after the import. pub(crate) async fn tenant_import( &self, tenant_id: TenantId, @@ -6909,7 +6953,7 @@ impl Service { /// detaching or deleting it on pageservers. We do not try and re-schedule any /// tenants that were on this node. pub(crate) async fn node_drop(&self, node_id: NodeId) -> Result<(), ApiError> { - self.persistence.delete_node(node_id).await?; + self.persistence.set_tombstone(node_id).await?; let mut locked = self.inner.write().unwrap(); @@ -7033,9 +7077,10 @@ impl Service { // That is safe because in Service::spawn we only use generation_pageserver if it refers to a node // that exists. - // 2. Actually delete the node from the database and from in-memory state + // 2. Actually delete the node from in-memory state and set a tombstone in the database + // to prevent the node from registering again. tracing::info!("Deleting node from database"); - self.persistence.delete_node(node_id).await?; + self.persistence.set_tombstone(node_id).await?; Ok(()) } @@ -7054,6 +7099,34 @@ impl Service { Ok(nodes) } + pub(crate) async fn tombstone_list(&self) -> Result, ApiError> { + self.persistence + .list_tombstones() + .await? + .into_iter() + .map(|np| Node::from_persistent(np, false)) + .collect::, _>>() + .map_err(ApiError::InternalServerError) + } + + pub(crate) async fn tombstone_delete(&self, node_id: NodeId) -> Result<(), ApiError> { + let _node_lock = trace_exclusive_lock( + &self.node_op_locks, + node_id, + NodeOperations::DeleteTombstone, + ) + .await; + + if matches!(self.get_node(node_id).await, Err(ApiError::NotFound(_))) { + self.persistence.delete_node(node_id).await?; + Ok(()) + } else { + Err(ApiError::Conflict(format!( + "Node {node_id} is in use, consider using tombstone API first" + ))) + } + } + pub(crate) async fn get_node(&self, node_id: NodeId) -> Result { self.inner .read() @@ -7205,6 +7278,12 @@ impl Service { )); } + if register_req.listen_grpc_addr.is_some() != register_req.listen_grpc_port.is_some() { + return Err(ApiError::BadRequest(anyhow::anyhow!( + "must specify both gRPC address and port" + ))); + } + // Ordering: we must persist the new node _before_ adding it to in-memory state. // This ensures that before we use it for anything or expose it via any external // API, it is guaranteed to be available after a restart.
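The registration hunk that follows maps a database unique-key violation onto a `Conflict` response pointing at the tombstone API, since a tombstoned row still occupies the node's primary key while being invisible in normal node listings. A rough sketch of that mapping, with hypothetical stand-ins for the diesel and API error types:

```rust
// Hypothetical stand-ins for the real DatabaseError/ApiError types.
#[derive(Debug)]
enum DbError {
    UniqueViolation,
    Other(String),
}

#[derive(Debug)]
enum ApiError {
    Conflict(String),
    Internal(String),
}

fn map_insert_node_error(node_id: u64, e: DbError) -> ApiError {
    match e {
        // The conflicting row may be a tombstone, which node listings filter
        // out -- hence the hint to check the tombstone API first.
        DbError::UniqueViolation => ApiError::Conflict(format!(
            "Node {node_id} already exists (it may be tombstoned; check tombstones first)"
        )),
        DbError::Other(msg) => ApiError::Internal(msg),
    }
}

fn main() {
    println!("{:?}", map_insert_node_error(42, DbError::UniqueViolation));
    println!("{:?}", map_insert_node_error(42, DbError::Other("connection reset".into())));
}
```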
@@ -7215,6 +7294,8 @@ register_req.listen_https_port, register_req.listen_pg_addr, register_req.listen_pg_port, + register_req.listen_grpc_addr, + register_req.listen_grpc_port, register_req.availability_zone_id.clone(), self.config.use_https_pageserver_api, ); @@ -7224,7 +7305,25 @@ }; match registration_status { - RegistrationStatus::New => self.persistence.insert_node(&new_node).await?, + RegistrationStatus::New => { + self.persistence.insert_node(&new_node).await.map_err(|e| { + if matches!( + e, + crate::persistence::DatabaseError::Query( + diesel::result::Error::DatabaseError( + diesel::result::DatabaseErrorKind::UniqueViolation, + _, + ) + ) + ) { + // The node may have been deleted via the tombstone API, in which case it does + // not show up in the list of nodes. If you see this error, check tombstones first. + ApiError::Conflict(format!("Node {} already exists", new_node.get_id())) + } else { + ApiError::from(e) + } + })?; + } RegistrationStatus::NeedUpdate => { self.persistence .update_node_on_registration( @@ -7573,7 +7672,7 @@ impl Service { if let Some(ongoing) = ongoing_op { return Err(ApiError::PreconditionFailed( - format!("Background operation already ongoing for node: {}", ongoing).into(), + format!("Background operation already ongoing for node: {ongoing}").into(), )); } @@ -7704,7 +7803,7 @@ impl Service { if let Some(ongoing) = ongoing_op { return Err(ApiError::PreconditionFailed( - format!("Background operation already ongoing for node: {}", ongoing).into(), + format!("Background operation already ongoing for node: {ongoing}").into(), )); } @@ -8292,6 +8391,11 @@ impl Service { /// we have this helper to move things along faster. #[cfg(feature = "testing")] async fn kick_secondary_download(&self, tenant_shard_id: TenantShardId) { + if !self.config.kick_secondary_downloads { + // No-op if the kick_secondary_downloads functionality is not enabled + return; + } + let (attached_node, secondaries) = { let locked = self.inner.read().unwrap(); let Some(shard) = locked.tenants.get(&tenant_shard_id) else { @@ -8701,15 +8805,22 @@ impl Service { let waiter_count = waiters.len(); match self.await_waiters(waiters, RECONCILE_TIMEOUT).await { Ok(()) => {} - Err(ReconcileWaitError::Failed(_, reconcile_error)) - if matches!(*reconcile_error, ReconcileError::Cancel) => - { - // Ignore reconciler cancel errors: this reconciler might have shut down - // because some other change superceded it. We will return a nonzero number, - // so the caller knows they might have to call again to quiesce the system. - } Err(e) => { - return Err(e); + if let ReconcileWaitError::Failed(_, reconcile_error) = &e { + match **reconcile_error { + ReconcileError::Cancel + | ReconcileError::Remote(mgmt_api::Error::Cancelled) => { + // Ignore reconciler cancel errors: this reconciler might have shut down + // because some other change superseded it. We will return a nonzero number, + // so the caller knows they might have to call again to quiesce the system.
+ } + _ => { + return Err(e); + } + } + } else { + return Err(e); + } } }; @@ -8763,7 +8874,7 @@ impl Service { let nodes = self.inner.read().unwrap().nodes.clone(); let node = nodes.get(secondary).ok_or(mgmt_api::Error::ApiError( StatusCode::NOT_FOUND, - format!("Node with id {} not found", secondary), + format!("Node with id {secondary} not found"), ))?; match node @@ -8842,8 +8953,7 @@ impl Service { Err(err) => { return Err(OperationError::FinalizeError( format!( - "Failed to finalise drain cancel of {} by setting scheduling policy to Active: {}", - node_id, err + "Failed to finalise drain cancel of {node_id} by setting scheduling policy to Active: {err}" ) .into(), )); @@ -8947,8 +9057,7 @@ impl Service { Err(err) => { return Err(OperationError::FinalizeError( format!( - "Failed to finalise drain cancel of {} by setting scheduling policy to Active: {}", - node_id, err + "Failed to finalise drain cancel of {node_id} by setting scheduling policy to Active: {err}" ) .into(), )); @@ -9158,8 +9267,7 @@ impl Service { Err(err) => { return Err(OperationError::FinalizeError( format!( - "Failed to finalise drain cancel of {} by setting scheduling policy to Active: {}", - node_id, err + "Failed to finalise drain cancel of {node_id} by setting scheduling policy to Active: {err}" ) .into(), )); @@ -9241,8 +9349,7 @@ impl Service { Err(err) => { return Err(OperationError::FinalizeError( format!( - "Failed to finalise drain cancel of {} by setting scheduling policy to Active: {}", - node_id, err + "Failed to finalise drain cancel of {node_id} by setting scheduling policy to Active: {err}" ) .into(), )); diff --git a/storage_controller/src/service/chaos_injector.rs b/storage_controller/src/service/chaos_injector.rs index 9c7a9e3798..4087de200a 100644 --- a/storage_controller/src/service/chaos_injector.rs +++ b/storage_controller/src/service/chaos_injector.rs @@ -107,7 +107,7 @@ impl ChaosInjector { // - Skip shards doing a graceful migration already, so that we allow these to run to // completion rather than only exercising the first part and then cancelling with // some other chaos. 
- !matches!(shard.get_scheduling_policy(), ShardSchedulingPolicy::Active) + matches!(shard.get_scheduling_policy(), ShardSchedulingPolicy::Active) && shard.get_preferred_node().is_none() } diff --git a/storage_controller/src/service/feature_flag.rs b/storage_controller/src/service/feature_flag.rs new file mode 100644 index 0000000000..645eb75237 --- /dev/null +++ b/storage_controller/src/service/feature_flag.rs @@ -0,0 +1,117 @@ +use std::{sync::Arc, time::Duration}; + +use futures::StreamExt; +use pageserver_api::config::PostHogConfig; +use pageserver_client::mgmt_api; +use posthog_client_lite::{PostHogClient, PostHogClientConfig}; +use reqwest::StatusCode; +use tokio::time::MissedTickBehavior; +use tokio_util::sync::CancellationToken; + +use crate::{pageserver_client::PageserverClient, service::Service}; + +pub struct FeatureFlagService { + service: Arc, + config: PostHogConfig, + client: PostHogClient, + http_client: reqwest::Client, +} + +const DEFAULT_POSTHOG_REFRESH_INTERVAL: Duration = Duration::from_secs(30); + +impl FeatureFlagService { + pub fn new(service: Arc, config: PostHogConfig) -> Self { + let client = PostHogClient::new(PostHogClientConfig { + project_id: config.project_id.clone(), + server_api_key: config.server_api_key.clone(), + client_api_key: config.client_api_key.clone(), + private_api_url: config.private_api_url.clone(), + public_api_url: config.public_api_url.clone(), + }); + Self { + service, + config, + client, + http_client: reqwest::Client::new(), + } + } + + async fn refresh(self: Arc, cancel: CancellationToken) -> Result<(), anyhow::Error> { + let nodes = { + let inner = self.service.inner.read().unwrap(); + inner.nodes.clone() + }; + + let feature_flag_spec = self.client.get_feature_flags_local_evaluation_raw().await?; + let stream = futures::stream::iter(nodes.values().cloned()).map(|node| { + let this = self.clone(); + let feature_flag_spec = feature_flag_spec.clone(); + async move { + let res = async { + let client = PageserverClient::new( + node.get_id(), + this.http_client.clone(), + node.base_url(), + // TODO: what if we rotate the token during storcon lifetime? + this.service.config.pageserver_jwt_token.as_deref(), + ); + + client.update_feature_flag_spec(feature_flag_spec).await?; + tracing::info!( + "Updated {}({}) with feature flag spec", + node.get_id(), + node.base_url() + ); + Ok::<_, mgmt_api::Error>(()) + }; + + if let Err(e) = res.await { + if let mgmt_api::Error::ApiError(status, _) = e { + if status == StatusCode::NOT_FOUND { + // This is expected during deployments where the API is not available, so we can ignore it + return; + } + } + tracing::warn!( + "Failed to update feature flag spec for {}: {e}", + node.get_id() + ); + } + } + }); + let mut stream = stream.buffer_unordered(8); + + while stream.next().await.is_some() { + if cancel.is_cancelled() { + return Ok(()); + } + } + + Ok(()) + } + + pub async fn run(self: Arc, cancel: CancellationToken) { + let refresh_interval = self + .config + .refresh_interval + .unwrap_or(DEFAULT_POSTHOG_REFRESH_INTERVAL); + let mut interval = tokio::time::interval(refresh_interval); + interval.set_missed_tick_behavior(MissedTickBehavior::Skip); + tracing::info!( + "Starting feature flag service with refresh interval: {:?}", + refresh_interval + ); + loop { + tokio::select! 
{ + _ = interval.tick() => {} + _ = cancel.cancelled() => { + break; + } + } + let res = self.clone().refresh(cancel.clone()).await; + if let Err(e) = res { + tracing::error!("Failed to refresh feature flags: {e:#?}"); + } + } + } +} diff --git a/storage_controller/src/service/safekeeper_reconciler.rs b/storage_controller/src/service/safekeeper_reconciler.rs index fbf0b5c4e3..a3c5082be6 100644 --- a/storage_controller/src/service/safekeeper_reconciler.rs +++ b/storage_controller/src/service/safekeeper_reconciler.rs @@ -230,7 +230,7 @@ impl ReconcilerHandle { // increase it before putting into the queue. let queued_gauge = &METRICS_REGISTRY .metrics_group - .storage_controller_safkeeper_reconciles_queued; + .storage_controller_safekeeper_reconciles_queued; let label_group = SafekeeperReconcilerLabelGroup { sk_az: &sk_az, sk_node_id: &sk_node_id, @@ -306,7 +306,7 @@ impl SafekeeperReconciler { let queued_gauge = &METRICS_REGISTRY .metrics_group - .storage_controller_safkeeper_reconciles_queued; + .storage_controller_safekeeper_reconciles_queued; queued_gauge.set( SafekeeperReconcilerLabelGroup { sk_az: &req.safekeeper.skp.availability_zone_id, @@ -547,7 +547,7 @@ impl SafekeeperReconcilerInner { let complete_counter = &METRICS_REGISTRY .metrics_group - .storage_controller_safkeeper_reconciles_complete; + .storage_controller_safekeeper_reconciles_complete; complete_counter.inc(SafekeeperReconcilerLabelGroup { sk_az: &req.safekeeper.skp.availability_zone_id, sk_node_id: &req.safekeeper.get_id().to_string(), diff --git a/storage_controller/src/service/safekeeper_service.rs b/storage_controller/src/service/safekeeper_service.rs index 1f673fe445..fec81fb661 100644 --- a/storage_controller/src/service/safekeeper_service.rs +++ b/storage_controller/src/service/safekeeper_service.rs @@ -1,3 +1,4 @@ +use std::cmp::max; use std::collections::HashSet; use std::str::FromStr; use std::sync::Arc; @@ -17,7 +18,8 @@ use pageserver_api::controller_api::{ SafekeeperDescribeResponse, SkSchedulingPolicy, TimelineImportRequest, }; use pageserver_api::models::{SafekeeperInfo, SafekeepersInfo, TimelineInfo}; -use safekeeper_api::membership::{MemberSet, SafekeeperId}; +use safekeeper_api::PgVersionId; +use safekeeper_api::membership::{MemberSet, SafekeeperGeneration, SafekeeperId}; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; use utils::id::{NodeId, TenantId, TimelineId}; @@ -26,6 +28,13 @@ use utils::lsn::Lsn; use super::Service; +#[derive(serde::Serialize, serde::Deserialize, Clone)] +pub struct TimelineLocateResponse { + pub generation: SafekeeperGeneration, + pub sk_set: Vec, + pub new_sk_set: Option>, +} + impl Service { /// Timeline creation on safekeepers /// @@ -36,7 +45,7 @@ impl Service { &self, tenant_id: TenantId, timeline_id: TimelineId, - pg_version: u32, + pg_version: PgVersionId, timeline_persistence: &TimelinePersistence, ) -> Result, ApiError> { // If quorum is reached, return if we are outside of a specified timeout @@ -211,7 +220,7 @@ impl Service { read_only: bool, ) -> Result { let timeline_id = timeline_info.timeline_id; - let pg_version = timeline_info.pg_version * 10000; + let pg_version = PgVersionId::from(timeline_info.pg_version); // Initially start_lsn is determined by last_record_lsn in pageserver // response as it does initdb. However, later we persist it and in sk // creation calls replace with the value from the timeline row if it @@ -396,6 +405,38 @@ impl Service { Ok(()) } + /// Locate safekeepers for a timeline. 
+ /// Return the generation, sk_set and new_sk_set if present. + /// If the timeline is not storcon-managed, return NotFound. + pub(crate) async fn tenant_timeline_locate( + &self, + tenant_id: TenantId, + timeline_id: TimelineId, + ) -> Result { + let timeline = self + .persistence + .get_timeline(tenant_id, timeline_id) + .await?; + + let Some(timeline) = timeline else { + return Err(ApiError::NotFound( + anyhow::anyhow!("Timeline {}/{} not found", tenant_id, timeline_id).into(), + )); + }; + + Ok(TimelineLocateResponse { + generation: SafekeeperGeneration::new(timeline.generation as u32), + sk_set: timeline + .sk_set + .iter() + .map(|id| NodeId(*id as u64)) + .collect(), + new_sk_set: timeline + .new_sk_set + .map(|sk_set| sk_set.iter().map(|id| NodeId(*id as u64)).collect()), + }) + } + /// Perform timeline deletion on safekeepers. Will return success: we persist the deletion into the reconciler. pub(super) async fn tenant_timeline_delete_safekeepers( self: &Arc, @@ -569,7 +610,8 @@ impl Service { Ok(()) } - /// Choose safekeepers for the new timeline: 3 in different azs. + /// Choose safekeepers for the new timeline in different azs. + /// 3 are chosen by default, but the count may be changed via config (for testing). pub(crate) async fn safekeepers_for_new_timeline( &self, ) -> Result, ApiError> { @@ -612,18 +654,14 @@ ) }); // Number of safekeepers in different AZs we are looking for - let wanted_count = match all_safekeepers.len() { - 0 => { - return Err(ApiError::InternalServerError(anyhow::anyhow!( - "couldn't find any active safekeeper for new timeline", - ))); - } - // Have laxer requirements on testig mode as we don't want to - // spin up three safekeepers for every single test - #[cfg(feature = "testing")] - 1 | 2 => all_safekeepers.len(), - _ => 3, - }; + let mut wanted_count = self.config.timeline_safekeeper_count as usize; + // TODO(diko): remove this when `timeline_safekeeper_count` option is in the release + // branch and is specified in tests/neon_local config. + if cfg!(feature = "testing") && all_safekeepers.len() < wanted_count { + // In testing mode, we may have fewer safekeepers than the config says + wanted_count = max(all_safekeepers.len(), 1); + } + let mut sks = Vec::new(); let mut azs = HashSet::new(); for (_sk_util, sk_info, az_id) in all_safekeepers.iter() { diff --git a/storage_controller/src/tenant_shard.rs b/storage_controller/src/tenant_shard.rs index c7b2628ec4..359921ecbf 100644 --- a/storage_controller/src/tenant_shard.rs +++ b/storage_controller/src/tenant_shard.rs @@ -1184,11 +1184,19 @@ impl TenantShard { for secondary in self.intent.get_secondary() { // Make sure we don't try to migrate a secondary to our attached location: this case happens // easily in environments without multiple AZs. - let exclude = match self.intent.attached { + let mut exclude = match self.intent.attached { Some(attached) => vec![attached], None => vec![], }; + // Exclude all other secondaries from the scheduling process to avoid replacing + // one existing secondary with another existing secondary.
+ for another_secondary in self.intent.secondary.iter() { + if another_secondary != secondary { + exclude.push(*another_secondary); + } + } + let replacement = match &self.policy { PlacementPolicy::Attached(_) => { // Secondaries for an attached shard should be scheduled using `SecondaryShardTag` @@ -1348,28 +1356,19 @@ impl TenantShard { /// Reconciliation may still be needed for other aspects of state such as secondaries (see [`Self::dirty`]): this /// funciton should not be used to decide whether to reconcile. pub(crate) fn stably_attached(&self) -> Option { - if let Some(attach_intent) = self.intent.attached { - match self.observed.locations.get(&attach_intent) { - Some(loc) => match &loc.conf { - Some(conf) => match conf.mode { - LocationConfigMode::AttachedMulti - | LocationConfigMode::AttachedSingle - | LocationConfigMode::AttachedStale => { - // Our intent and observed state agree that this node is in an attached state. - Some(attach_intent) - } - // Our observed config is not an attached state - _ => None, - }, - // Our observed state is None, i.e. in flux - None => None, - }, - // We have no observed state for this node - None => None, - } - } else { - // Our intent is not to attach - None + // We have an intent to attach for this node + let attach_intent = self.intent.attached?; + // We have an observed state for this node + let location = self.observed.locations.get(&attach_intent)?; + // Our observed state is not None, i.e. not in flux + let location_config = location.conf.as_ref()?; + + // Check if our intent and observed state agree that this node is in an attached state. + match location_config.mode { + LocationConfigMode::AttachedMulti + | LocationConfigMode::AttachedSingle + | LocationConfigMode::AttachedStale => Some(attach_intent), + _ => None, } } @@ -1382,8 +1381,13 @@ impl TenantShard { .generation .expect("Attempted to enter attached state without a generation"); - let wanted_conf = - attached_location_conf(generation, &self.shard, &self.config, &self.policy); + let wanted_conf = attached_location_conf( + generation, + &self.shard, + &self.config, + &self.policy, + self.intent.get_secondary().len(), + ); match self.observed.locations.get(&node_id) { Some(conf) if conf.conf.as_ref() == Some(&wanted_conf) => {} Some(_) | None => { @@ -3004,21 +3008,18 @@ pub(crate) mod tests { if attachments_in_wrong_az > 0 { violations.push(format!( - "{} attachments scheduled to the incorrect AZ", - attachments_in_wrong_az + "{attachments_in_wrong_az} attachments scheduled to the incorrect AZ" )); } if secondaries_in_wrong_az > 0 { violations.push(format!( - "{} secondaries scheduled to the incorrect AZ", - secondaries_in_wrong_az + "{secondaries_in_wrong_az} secondaries scheduled to the incorrect AZ" )); } eprintln!( - "attachments_in_wrong_az={} secondaries_in_wrong_az={}", - attachments_in_wrong_az, secondaries_in_wrong_az + "attachments_in_wrong_az={attachments_in_wrong_az} secondaries_in_wrong_az={secondaries_in_wrong_az}" ); for (node_id, stats) in &node_stats { diff --git a/storage_controller/src/timeline_import.rs b/storage_controller/src/timeline_import.rs index eb50819d02..e88bce4c82 100644 --- a/storage_controller/src/timeline_import.rs +++ b/storage_controller/src/timeline_import.rs @@ -195,7 +195,7 @@ impl UpcallClient { let authorization_header = config .control_plane_jwt_token .clone() - .map(|jwt| format!("Bearer {}", jwt)); + .map(|jwt| format!("Bearer {jwt}")); let client = reqwest::ClientBuilder::new() .timeout(IMPORT_COMPLETE_REQUEST_TIMEOUT) diff --git 
a/storage_scrubber/src/checks.rs b/storage_scrubber/src/checks.rs index 865f0908f9..774418f237 100644 --- a/storage_scrubber/src/checks.rs +++ b/storage_scrubber/src/checks.rs @@ -146,7 +146,7 @@ pub(crate) async fn branch_cleanup_and_check_errors( for (layer, metadata) in index_part.layer_metadata { if metadata.file_size == 0 { result.errors.push(format!( - "index_part.json contains a layer {} that has 0 size in its layer metadata", layer, + "index_part.json contains a layer {layer} that has 0 size in its layer metadata", )) } diff --git a/storage_scrubber/src/lib.rs b/storage_scrubber/src/lib.rs index 25a157f108..d3ed5a8357 100644 --- a/storage_scrubber/src/lib.rs +++ b/storage_scrubber/src/lib.rs @@ -123,7 +123,7 @@ impl S3Target { pub fn with_sub_segment(&self, new_segment: &str) -> Self { let mut new_self = self.clone(); if new_self.prefix_in_bucket.is_empty() { - new_self.prefix_in_bucket = format!("/{}/", new_segment); + new_self.prefix_in_bucket = format!("/{new_segment}/"); } else { if new_self.prefix_in_bucket.ends_with('/') { new_self.prefix_in_bucket.pop(); diff --git a/storage_scrubber/src/scan_safekeeper_metadata.rs b/storage_scrubber/src/scan_safekeeper_metadata.rs index f10d758097..cf0a3d19e9 100644 --- a/storage_scrubber/src/scan_safekeeper_metadata.rs +++ b/storage_scrubber/src/scan_safekeeper_metadata.rs @@ -265,7 +265,7 @@ async fn load_timelines_from_db( // so spawn it off to run on its own. tokio::spawn(async move { if let Err(e) = connection.await { - eprintln!("connection error: {}", e); + eprintln!("connection error: {e}"); } }); @@ -274,7 +274,7 @@ async fn load_timelines_from_db( "and tenant_id in ({})", tenant_ids .iter() - .map(|t| format!("'{}'", t)) + .map(|t| format!("'{t}'")) .collect::>() .join(", ") ) diff --git a/test_runner/fixtures/endpoint/http.py b/test_runner/fixtures/endpoint/http.py index 4b4b98aa6c..f5be544439 100644 --- a/test_runner/fixtures/endpoint/http.py +++ b/test_runner/fixtures/endpoint/http.py @@ -24,7 +24,7 @@ The value to place in the `aud` claim. 
@final class ComputeClaimsScope(StrEnum): - ADMIN = "admin" + ADMIN = "compute_ctl:admin" @final @@ -69,15 +69,17 @@ class EndpointHttpClient(requests.Session): json: dict[str, str] = res.json() return json - def prewarm_lfc(self): - self.post(f"http://localhost:{self.external_port}/lfc/prewarm").raise_for_status() + def prewarm_lfc(self, from_endpoint_id: str | None = None): + url: str = f"http://localhost:{self.external_port}/lfc/prewarm" + params = {"from_endpoint": from_endpoint_id} if from_endpoint_id else dict() + self.post(url, params=params).raise_for_status() def prewarmed(): json = self.prewarm_lfc_status() status, err = json["status"], json.get("error") assert status == "completed", f"{status}, error {err}" - wait_until(prewarmed) + wait_until(prewarmed, timeout=60) def offload_lfc(self): url = f"http://localhost:{self.external_port}/lfc/offload" diff --git a/test_runner/fixtures/neon_api.py b/test_runner/fixtures/neon_api.py index 0cf5945458..9d85b9a332 100644 --- a/test_runner/fixtures/neon_api.py +++ b/test_runner/fixtures/neon_api.py @@ -129,6 +129,18 @@ class NeonAPI: return cast("dict[str, Any]", resp.json()) + def get_project_limits(self, project_id: str) -> dict[str, Any]: + resp = self.__request( + "GET", + f"/projects/{project_id}/limits", + headers={ + "Accept": "application/json", + "Content-Type": "application/json", + }, + ) + + return cast("dict[str, Any]", resp.json()) + def delete_project( self, project_id: str, diff --git a/test_runner/fixtures/neon_cli.py b/test_runner/fixtures/neon_cli.py index bb07e2b6d1..1b09e5bdd0 100644 --- a/test_runner/fixtures/neon_cli.py +++ b/test_runner/fixtures/neon_cli.py @@ -497,6 +497,7 @@ class NeonLocalCli(AbstractNeonCli): tenant_id: TenantId, pg_version: PgVersion, endpoint_id: str | None = None, + grpc: bool | None = None, hot_standby: bool = False, lsn: Lsn | None = None, pageserver_id: int | None = None, @@ -521,6 +522,8 @@ class NeonLocalCli(AbstractNeonCli): args.extend(["--external-http-port", str(external_http_port)]) if internal_http_port is not None: args.extend(["--internal-http-port", str(internal_http_port)]) + if grpc: + args.append("--grpc") if endpoint_id is not None: args.append(endpoint_id) if hot_standby: @@ -564,6 +567,7 @@ class NeonLocalCli(AbstractNeonCli): basebackup_request_tries: int | None = None, timeout: str | None = None, env: dict[str, str] | None = None, + dev: bool = False, ) -> subprocess.CompletedProcess[str]: args = [ "endpoint", @@ -589,6 +593,8 @@ class NeonLocalCli(AbstractNeonCli): args.extend(["--create-test-user"]) if timeout is not None: args.extend(["--start-timeout", str(timeout)]) + if dev: + args.extend(["--dev"]) res = self.raw_cli(args, extra_env_vars) res.check_returncode() @@ -617,7 +623,7 @@ class NeonLocalCli(AbstractNeonCli): destroy=False, check_return_code=True, mode: str | None = None, - ) -> subprocess.CompletedProcess[str]: + ) -> tuple[Lsn | None, subprocess.CompletedProcess[str]]: args = [ "endpoint", "stop", @@ -629,7 +635,11 @@ class NeonLocalCli(AbstractNeonCli): if endpoint_id is not None: args.append(endpoint_id) - return self.raw_cli(args, check_return_code=check_return_code) + proc = self.raw_cli(args, check_return_code=check_return_code) + log.debug(f"endpoint stop stdout: {proc.stdout}") + lsn_str = proc.stdout.split()[-1] + lsn: Lsn | None = None if lsn_str == "null" else Lsn(lsn_str) + return lsn, proc def mappings_map_branch( self, name: str, tenant_id: TenantId, timeline_id: TimelineId diff --git a/test_runner/fixtures/neon_fixtures.py 
b/test_runner/fixtures/neon_fixtures.py index db3f080261..4eb85119ca 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -453,6 +453,7 @@ class NeonEnvBuilder: pageserver_get_vectored_concurrent_io: str | None = None, pageserver_tracing_config: PageserverTracingConfig | None = None, pageserver_import_config: PageserverImportConfig | None = None, + storcon_kick_secondary_downloads: bool | None = None, ): self.repo_dir = repo_dir self.rust_log_override = rust_log_override @@ -489,7 +490,9 @@ class NeonEnvBuilder: self.config_init_force: str | None = None self.top_output_dir = top_output_dir self.control_plane_hooks_api: str | None = None - self.storage_controller_config: dict[Any, Any] | None = None + self.storage_controller_config: dict[Any, Any] | None = { + "timelines_onto_safekeepers": True, + } # Flag to enable https listener in pageserver, generate local ssl certs, # and force storage controller to use https for pageserver api. @@ -512,6 +515,8 @@ class NeonEnvBuilder: self.pageserver_tracing_config = pageserver_tracing_config self.pageserver_import_config = pageserver_import_config + self.storcon_kick_secondary_downloads = storcon_kick_secondary_downloads + self.pageserver_default_tenant_config_compaction_algorithm: dict[str, Any] | None = ( pageserver_default_tenant_config_compaction_algorithm ) @@ -1219,6 +1224,14 @@ class NeonEnv: else: cfg["storage_controller"] = {"use_local_compute_notifications": False} + if config.storcon_kick_secondary_downloads is not None: + # Configure whether storage controller should actively kick off secondary downloads + if "storage_controller" not in cfg: + cfg["storage_controller"] = {} + cfg["storage_controller"]["kick_secondary_downloads"] = ( + config.storcon_kick_secondary_downloads + ) + # Create config for pageserver http_auth_type = "NeonJWT" if config.auth_enabled else "Trust" pg_auth_type = "NeonJWT" if config.auth_enabled else "Trust" @@ -1228,6 +1241,7 @@ class NeonEnv: ): pageserver_port = PageserverPort( pg=self.port_distributor.get_port(), + grpc=self.port_distributor.get_port(), http=self.port_distributor.get_port(), https=self.port_distributor.get_port() if config.use_https_pageserver_api else None, ) @@ -1243,13 +1257,14 @@ class NeonEnv: ps_cfg: dict[str, Any] = { "id": ps_id, "listen_pg_addr": f"localhost:{pageserver_port.pg}", + "listen_grpc_addr": f"localhost:{pageserver_port.grpc}", "listen_http_addr": f"localhost:{pageserver_port.http}", "listen_https_addr": f"localhost:{pageserver_port.https}" if config.use_https_pageserver_api else None, "pg_auth_type": pg_auth_type, - "http_auth_type": http_auth_type, "grpc_auth_type": grpc_auth_type, + "http_auth_type": http_auth_type, "availability_zone": availability_zone, # Disable pageserver disk syncs in tests: when running tests concurrently, this avoids # the pageserver taking a long time to start up due to syncfs flushing other tests' data @@ -1762,6 +1777,7 @@ def neon_env_builder( @dataclass class PageserverPort: pg: int + grpc: int http: int https: int | None = None @@ -2054,6 +2070,14 @@ class NeonStorageController(MetricsGetter, LogUtils): headers=self.headers(TokenScope.ADMIN), ) + def tombstone_delete(self, node_id): + log.info(f"tombstone_delete({node_id})") + self.request( + "DELETE", + f"{self.api}/debug/v1/tombstone/{node_id}", + headers=self.headers(TokenScope.ADMIN), + ) + def node_drain(self, node_id): log.info(f"node_drain({node_id})") self.request( @@ -2110,6 +2134,14 @@ class NeonStorageController(MetricsGetter, 
LogUtils): ) return response.json() + def tombstone_list(self): + response = self.request( + "GET", + f"{self.api}/debug/v1/tombstone", + headers=self.headers(TokenScope.ADMIN), + ) + return response.json() + def tenant_shard_dump(self): """ Debug listing API: dumps the internal map of tenant shards @@ -2207,6 +2239,17 @@ class NeonStorageController(MetricsGetter, LogUtils): shards: list[dict[str, Any]] = body["shards"] return shards + def timeline_locate(self, tenant_id: TenantId, timeline_id: TimelineId): + """ + :return: dict {"generation": int, "sk_set": [int], "new_sk_set": [int]} + """ + response = self.request( + "GET", + f"{self.api}/debug/v1/tenant/{tenant_id}/timeline/{timeline_id}/locate", + headers=self.headers(TokenScope.ADMIN), + ) + return response.json() + def tenant_describe(self, tenant_id: TenantId): """ :return: list of {"shard_id": "", "node_id": int, "listen_pg_addr": str, "listen_pg_port": int, "listen_http_addr: str, "listen_http_port: int, preferred_az_id: str} @@ -2333,6 +2376,7 @@ class NeonStorageController(MetricsGetter, LogUtils): delay_max = max_interval while n > 0: n = self.reconcile_all() + if n == 0: break elif time.time() - start_at > timeout_secs: @@ -4030,6 +4074,16 @@ def static_proxy( "CREATE TABLE neon_control_plane.endpoints (endpoint_id VARCHAR(255) PRIMARY KEY, allowed_ips VARCHAR(255))" ) + vanilla_pg.stop() + vanilla_pg.edit_hba( + [ + "local all all trust", + "host all all 127.0.0.1/32 scram-sha-256", + "host all all ::1/128 scram-sha-256", + ] + ) + vanilla_pg.start() + proxy_port = port_distributor.get_port() mgmt_port = port_distributor.get_port() http_port = port_distributor.get_port() @@ -4155,6 +4209,8 @@ class Endpoint(PgProtocol, LogUtils): self._running = threading.Semaphore(0) self.__jwt: str | None = None + self.terminate_flush_lsn: Lsn | None = None + def http_client(self, retries: Retry | None = None) -> EndpointHttpClient: assert self.__jwt is not None return EndpointHttpClient( @@ -4167,6 +4223,7 @@ class Endpoint(PgProtocol, LogUtils): self, branch_name: str, endpoint_id: str | None = None, + grpc: bool | None = None, hot_standby: bool = False, lsn: Lsn | None = None, config_lines: list[str] | None = None, @@ -4191,6 +4248,7 @@ class Endpoint(PgProtocol, LogUtils): endpoint_id=self.endpoint_id, tenant_id=self.tenant_id, lsn=lsn, + grpc=grpc, hot_standby=hot_standby, pg_port=self.pg_port, external_http_port=self.external_http_port, @@ -4457,9 +4515,10 @@ class Endpoint(PgProtocol, LogUtils): running = self._running.acquire(blocking=False) if running: assert self.endpoint_id is not None - self.env.neon_cli.endpoint_stop( + lsn, _ = self.env.neon_cli.endpoint_stop( self.endpoint_id, check_return_code=self.check_stop_result, mode=mode ) + self.terminate_flush_lsn = lsn if sks_wait_walreceiver_gone is not None: for sk in sks_wait_walreceiver_gone[0]: @@ -4477,9 +4536,10 @@ class Endpoint(PgProtocol, LogUtils): running = self._running.acquire(blocking=False) if running: assert self.endpoint_id is not None - self.env.neon_cli.endpoint_stop( + lsn, _ = self.env.neon_cli.endpoint_stop( self.endpoint_id, True, check_return_code=self.check_stop_result, mode=mode ) + self.terminate_flush_lsn = lsn self.endpoint_id = None return self @@ -4488,6 +4548,7 @@ class Endpoint(PgProtocol, LogUtils): self, branch_name: str, endpoint_id: str | None = None, + grpc: bool | None = None, hot_standby: bool = False, lsn: Lsn | None = None, config_lines: list[str] | None = None, @@ -4505,6 +4566,7 @@ class Endpoint(PgProtocol, LogUtils): 
branch_name=branch_name, endpoint_id=endpoint_id, config_lines=config_lines, + grpc=grpc, hot_standby=hot_standby, lsn=lsn, pageserver_id=pageserver_id, @@ -4592,6 +4654,7 @@ class EndpointFactory: endpoint_id: str | None = None, tenant_id: TenantId | None = None, lsn: Lsn | None = None, + grpc: bool | None = None, hot_standby: bool = False, config_lines: list[str] | None = None, remote_ext_base_url: str | None = None, @@ -4611,6 +4674,7 @@ class EndpointFactory: return ep.create_start( branch_name=branch_name, endpoint_id=endpoint_id, + grpc=grpc, hot_standby=hot_standby, config_lines=config_lines, lsn=lsn, @@ -4625,6 +4689,7 @@ class EndpointFactory: endpoint_id: str | None = None, tenant_id: TenantId | None = None, lsn: Lsn | None = None, + grpc: bool | None = None, hot_standby: bool = False, config_lines: list[str] | None = None, pageserver_id: int | None = None, @@ -4647,6 +4712,7 @@ class EndpointFactory: branch_name=branch_name, endpoint_id=endpoint_id, lsn=lsn, + grpc=grpc, hot_standby=hot_standby, config_lines=config_lines, pageserver_id=pageserver_id, @@ -4671,6 +4737,7 @@ class EndpointFactory: self, origin: Endpoint, endpoint_id: str | None = None, + grpc: bool | None = None, config_lines: list[str] | None = None, ) -> Endpoint: branch_name = origin.branch_name @@ -4682,6 +4749,7 @@ class EndpointFactory: endpoint_id=endpoint_id, tenant_id=origin.tenant_id, lsn=None, + grpc=grpc, hot_standby=True, config_lines=config_lines, ) @@ -4690,6 +4758,7 @@ class EndpointFactory: self, origin: Endpoint, endpoint_id: str | None = None, + grpc: bool | None = None, config_lines: list[str] | None = None, ) -> Endpoint: branch_name = origin.branch_name @@ -4701,6 +4770,7 @@ class EndpointFactory: endpoint_id=endpoint_id, tenant_id=origin.tenant_id, lsn=None, + grpc=grpc, hot_standby=True, config_lines=config_lines, ) @@ -4852,6 +4922,9 @@ class Safekeeper(LogUtils): log.info(f"finished pulling timeline from {src_ids} to {self.id}") return res + def safekeeper_id(self) -> SafekeeperId: + return SafekeeperId(self.id, "localhost", self.port.pg_tenant_only) + @property def data_dir(self) -> Path: return self.env.repo_dir / "safekeepers" / f"sk{self.id}" diff --git a/test_runner/fixtures/pageserver/http.py b/test_runner/fixtures/pageserver/http.py index c29192c25c..d9037f2d08 100644 --- a/test_runner/fixtures/pageserver/http.py +++ b/test_runner/fixtures/pageserver/http.py @@ -1219,3 +1219,31 @@ class PageserverHttpClient(requests.Session, MetricsGetter): ) self.verbose_error(res) return res.json() + + def force_override_feature_flag(self, flag: str, value: str | None = None): + if value is None: + res = self.delete( + f"http://localhost:{self.port}/v1/feature_flag/{flag}", + ) + else: + res = self.put( + f"http://localhost:{self.port}/v1/feature_flag/{flag}", + params={"value": value}, + ) + self.verbose_error(res) + + def evaluate_feature_flag_boolean(self, tenant_id: TenantId, flag: str) -> Any: + res = self.get( + f"http://localhost:{self.port}/v1/tenant/{tenant_id}/feature_flag/{flag}", + params={"as": "boolean"}, + ) + self.verbose_error(res) + return res.json() + + def evaluate_feature_flag_multivariate(self, tenant_id: TenantId, flag: str) -> Any: + res = self.get( + f"http://localhost:{self.port}/v1/tenant/{tenant_id}/feature_flag/{flag}", + params={"as": "multivariate"}, + ) + self.verbose_error(res) + return res.json() diff --git a/test_runner/performance/large_synthetic_oltp/grow_action_blocks.sql b/test_runner/performance/large_synthetic_oltp/grow_action_blocks.sql new file mode 
100644 index 0000000000..0860b76331 --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/grow_action_blocks.sql @@ -0,0 +1,22 @@ +-- add 100000 rows or approximately 11 MB to the action_blocks table +-- takes about 1 second +INSERT INTO workflows.action_blocks ( + id, + uuid, + created_at, + status, + function_signature, + reference_id, + blocking, + run_synchronously +) +SELECT + id, + uuid_generate_v4(), + now() - (random() * interval '100 days'), -- Random date within the last 100 days + 'CONDITIONS_NOT_MET', + 'function_signature_' || id, -- Create a unique function signature using id + CASE WHEN random() > 0.5 THEN 'reference_' || id ELSE NULL END, -- 50% chance of being NULL + true, + CASE WHEN random() > 0.5 THEN true ELSE false END -- Random boolean value +FROM generate_series(1, 100000) AS id; \ No newline at end of file diff --git a/test_runner/performance/large_synthetic_oltp/grow_action_kwargs.sql b/test_runner/performance/large_synthetic_oltp/grow_action_kwargs.sql new file mode 100644 index 0000000000..8a2b7c398a --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/grow_action_kwargs.sql @@ -0,0 +1,11 @@ +-- add 100000 rows or approximately 10 MB to the action_kwargs table +-- takes about 5 minutes +INSERT INTO workflows.action_kwargs (created_at, key, uuid, value_id, state_value_id, action_block_id) +SELECT + now(), -- Using the default value for `created_at` + 'key_' || gs.id, -- Generating a unique key based on the id + uuid_generate_v4(), -- Generating a new UUID for each row + CASE WHEN gs.id % 2 = 0 THEN gs.id ELSE NULL END, -- Setting value_id for even ids + CASE WHEN gs.id % 2 <> 0 THEN gs.id ELSE NULL END, -- Setting state_value_id for odd ids + 1 -- Setting action_block_id as 1 for simplicity +FROM generate_series(1, 100000) AS gs(id); \ No newline at end of file diff --git a/test_runner/performance/large_synthetic_oltp/grow_device_fingerprint_event.sql b/test_runner/performance/large_synthetic_oltp/grow_device_fingerprint_event.sql new file mode 100644 index 0000000000..1ef38451b7 --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/grow_device_fingerprint_event.sql @@ -0,0 +1,56 @@ +-- add 100000 rows or approx. 30 MB to the device_fingerprint_event table +-- takes about 4 minutes +INSERT INTO authentication.device_fingerprint_event ( + uuid, + created_at, + identity_uuid, + fingerprint_request_id, + fingerprint_id, + confidence_score, + ip_address, + url, + client_referrer, + last_seen_at, + raw_fingerprint_response, + session_uuid, + fingerprint_response, + browser_version, + browser_name, + device, + operating_system, + operating_system_version, + user_agent, + ip_address_location_city, + ip_address_location_region, + ip_address_location_country_code, + ip_address_location_latitude, + ip_address_location_longitude, + is_incognito +) +SELECT + gen_random_uuid(), -- Generates a random UUID for primary key + now() - (random() * interval '10 days'), -- Random timestamp within the last 10 days + gen_random_uuid(), -- Random UUID for identity + md5(gs::text), -- Simulates unique fingerprint request ID using `md5` hash of series number + md5((gs + 10000)::text), -- Simulates unique fingerprint ID + round(CAST(random() AS numeric), 2), -- Generates a random score between 0 and 1, cast `random()` to numeric + '192.168.' || (random() * 255)::int || '.' 
|| (random() * 255)::int, -- Random IP address + 'https://example.com/' || (gs % 1000), -- Random URL with series number suffix + CASE WHEN random() < 0.5 THEN NULL ELSE 'https://referrer.com/' || (gs % 100)::text END, -- Random referrer, 50% chance of being NULL + now() - (random() * interval '5 days'), -- Last seen timestamp within the last 5 days + NULL, -- Keeping raw_fingerprint_response NULL for simplicity + CASE WHEN random() < 0.3 THEN gen_random_uuid() ELSE NULL END, -- Session UUID, 30% chance of NULL + NULL, -- Keeping fingerprint_response NULL for simplicity + CASE WHEN random() < 0.5 THEN '93.0' ELSE '92.0' END, -- Random browser version + CASE WHEN random() < 0.5 THEN 'Firefox' ELSE 'Chrome' END, -- Random browser name + CASE WHEN random() < 0.5 THEN 'Desktop' ELSE 'Mobile' END, -- Random device type + 'Windows', -- Static value for operating system + '10.0', -- Static value for operating system version + 'Mozilla/5.0', -- Static value for user agent + 'City ' || (gs % 1000)::text, -- Random city name + 'Region ' || (gs % 100)::text, -- Random region name + 'US', -- Static country code + random() * 180 - 90, -- Random latitude between -90 and 90 + random() * 360 - 180, -- Random longitude between -180 and 180 + random() < 0.1 -- 10% chance of being incognito +FROM generate_series(1, 100000) AS gs; \ No newline at end of file diff --git a/test_runner/performance/large_synthetic_oltp/grow_edges.sql b/test_runner/performance/large_synthetic_oltp/grow_edges.sql new file mode 100644 index 0000000000..17f289fe5b --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/grow_edges.sql @@ -0,0 +1,10 @@ +-- add 100000 rows or approximately 11 MB to the edges table +-- takes about 1 minute +INSERT INTO workflows.edges (created_at, workflow_id, uuid, from_vertex_id, to_vertex_id) +SELECT + now() - (random() * interval '365 days'), -- Random `created_at` timestamp in the last year + (random() * 100)::int + 1, -- Random `workflow_id` between 1 and 100 + uuid_generate_v4(), -- Generate a new UUID for each row + (random() * 100000)::bigint + 1, -- Random `from_vertex_id` between 1 and 100,000 + (random() * 100000)::bigint + 1 -- Random `to_vertex_id` between 1 and 100,000 +FROM generate_series(1, 100000) AS gs; -- Generate 100,000 sequential IDs \ No newline at end of file diff --git a/test_runner/performance/large_synthetic_oltp/grow_hotel_rate_mapping.sql b/test_runner/performance/large_synthetic_oltp/grow_hotel_rate_mapping.sql new file mode 100644 index 0000000000..1e79f94eab --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/grow_hotel_rate_mapping.sql @@ -0,0 +1,21 @@ +-- add 100000 rows or approximately 10 MB to the hotel_rate_mapping table +-- takes about 1 second +INSERT INTO booking_inventory.hotel_rate_mapping ( + uuid, + created_at, + updated_at, + hotel_rate_id, + remote_id, + source +) +SELECT + uuid_generate_v4(), -- Unique UUID for each row + now(), -- Created at timestamp + now(), -- Updated at timestamp + 'rate_' || gs AS hotel_rate_id, -- Unique hotel_rate_id + 'remote_' || gs AS remote_id, -- Unique remote_id + CASE WHEN gs % 3 = 0 THEN 'source_1' + WHEN gs % 3 = 1 THEN 'source_2' + ELSE 'source_3' + END AS source -- Distributing sources among three options +FROM generate_series(1, 100000) AS gs; \ No newline at end of file diff --git a/test_runner/performance/large_synthetic_oltp/grow_ocr_pipeline_results_version.sql b/test_runner/performance/large_synthetic_oltp/grow_ocr_pipeline_results_version.sql new file mode 100644 index 0000000000..21ebac74d2 
--- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/grow_ocr_pipeline_results_version.sql @@ -0,0 +1,31 @@ +-- add 100000 rows or approximately 20 MB to the ocr_pipeline_results_version table +-- takes about 1 second +INSERT INTO ocr.ocr_pipeline_results_version ( + id, transaction_id, operation_type, created_at, updated_at, s3_filename, completed_at, result, + end_transaction_id, pipeline_type, is_async, callback, callback_kwargs, input, error, file_type, s3_bucket_name, pipeline_kwargs +) +SELECT + gs.aid, -- id + gs.aid, -- transaction_id (same as id for simplicity) + (gs.aid % 5)::smallint + 1, -- operation_type (cyclic values from 1 to 5) + now() - interval '1 day' * (random() * 30), -- created_at (random timestamp within the last 30 days) + now() - interval '1 day' * (random() * 30), -- updated_at (random timestamp within the last 30 days) + 's3_file_' || gs.aid || '.txt', -- s3_filename (synthetic filename) + now() - interval '1 day' * (random() * 30), -- completed_at (random timestamp within the last 30 days) + '{}'::jsonb, -- result (empty JSON object) + NULL, -- end_transaction_id (NULL) + CASE (gs.aid % 3) -- pipeline_type (cyclic text values) + WHEN 0 THEN 'OCR' + WHEN 1 THEN 'PDF' + ELSE 'Image' + END, + gs.aid % 2 = 0, -- is_async (alternating between true and false) + 'http://callback/' || gs.aid, -- callback (synthetic URL) + '{}'::jsonb, -- callback_kwargs (empty JSON object) + 'Input text ' || gs.aid, -- input (synthetic input text) + NULL, -- error (NULL) + 'pdf', -- file_type (default to 'pdf') + 'bucket_' || gs.aid % 10, -- s3_bucket_name (synthetic bucket names) + '{}'::jsonb -- pipeline_kwargs (empty JSON object) +FROM + generate_series(1, 100000) AS gs(aid); \ No newline at end of file diff --git a/test_runner/performance/large_synthetic_oltp/grow_priceline_raw_response.sql b/test_runner/performance/large_synthetic_oltp/grow_priceline_raw_response.sql new file mode 100644 index 0000000000..28c4f1a7fb --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/grow_priceline_raw_response.sql @@ -0,0 +1,18 @@ +-- add 100000 rows or approx. 20 MB to the priceline_raw_response table +-- takes about 20 seconds +INSERT INTO booking_inventory.priceline_raw_response ( + uuid, created_at, updated_at, url, base_url, path, method, params, request, response +) +SELECT + gen_random_uuid(), -- Generate random UUIDs + now() - (random() * interval '30 days'), -- Random creation time within the past 30 days + now() - (random() * interval '30 days'), -- Random update time within the past 30 days + 'https://example.com/resource/' || gs, -- Construct a unique URL for each row + 'https://example.com', -- Base URL for all rows + '/resource/' || gs, -- Path for each row + CASE WHEN gs % 2 = 0 THEN 'GET' ELSE 'POST' END, -- Alternate between GET and POST methods + 'id=' || gs, -- Simple parameter pattern for each row + '{}'::jsonb, -- Empty JSON object for request + jsonb_build_object('status', 'success', 'data', gs) -- Construct a valid JSON response +FROM + generate_series(1, 100000) AS gs; \ No newline at end of file diff --git a/test_runner/performance/large_synthetic_oltp/grow_relabled_transactions.sql b/test_runner/performance/large_synthetic_oltp/grow_relabled_transactions.sql new file mode 100644 index 0000000000..0b1aa2d2bd --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/grow_relabled_transactions.sql @@ -0,0 +1,26 @@ +-- add 100000 rows or approx. 
1 MB to the relabeled_transactions table +-- takes about 1 second +INSERT INTO heron.relabeled_transactions ( + id, + created_at, + universal_transaction_id, + raw_result, + category, + category_confidence, + merchant, + batch_id +) +SELECT + gs.aid AS id, + now() - (gs.aid % 1000) * interval '1 second' AS created_at, + 'txn_' || gs.aid AS universal_transaction_id, + '{}'::jsonb AS raw_result, + CASE WHEN gs.aid % 5 = 0 THEN 'grocery' + WHEN gs.aid % 5 = 1 THEN 'electronics' + WHEN gs.aid % 5 = 2 THEN 'clothing' + WHEN gs.aid % 5 = 3 THEN 'utilities' + ELSE NULL END AS category, + ROUND(RANDOM()::numeric, 2) AS category_confidence, + CASE WHEN gs.aid % 2 = 0 THEN 'Merchant_' || gs.aid % 20 ELSE NULL END AS merchant, + gs.aid % 100 + 1 AS batch_id +FROM generate_series(1, 100000) AS gs(aid); \ No newline at end of file diff --git a/test_runner/performance/large_synthetic_oltp/grow_state_values.sql b/test_runner/performance/large_synthetic_oltp/grow_state_values.sql new file mode 100644 index 0000000000..8a8ce146be --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/grow_state_values.sql @@ -0,0 +1,9 @@ +-- add 100000 rows or approx. 10 MB to the state_values table +-- takes about 14 seconds +INSERT INTO workflows.state_values (key, workflow_id, state_type, value_id) +SELECT + 'key_' || gs::text, -- Key: Generate as 'key_1', 'key_2', etc. + (gs - 1) / 1000 + 1, -- workflow_id: Distribute over a range (1000 workflows) + 'STATIC', -- state_type: Use constant 'STATIC' as defined in schema + gs::bigint -- value_id: Use the same as the series value +FROM generate_series(1, 100000) AS gs; -- Generate 100,000 rows \ No newline at end of file diff --git a/test_runner/performance/large_synthetic_oltp/grow_values.sql b/test_runner/performance/large_synthetic_oltp/grow_values.sql new file mode 100644 index 0000000000..3afdafdf86 --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/grow_values.sql @@ -0,0 +1,30 @@ +-- add 100000 rows or approx. 24 MB to the values table +-- takes about 126 seconds +INSERT INTO workflows.values ( + id, + type, + int_value, + string_value, + child_type, + bool_value, + uuid, + numeric_value, + workflow_id, + jsonb_value, + parent_value_id +) +SELECT + gs AS id, + 'TYPE_A' AS type, + CASE WHEN selector = 1 THEN gs ELSE NULL END AS int_value, + CASE WHEN selector = 2 THEN 'string_value_' || gs::text ELSE NULL END AS string_value, + 'CHILD_TYPE_A' AS child_type, -- Always non-null + CASE WHEN selector = 3 THEN (gs % 2 = 0) ELSE NULL END AS bool_value, + uuid_generate_v4() AS uuid, -- Always non-null + CASE WHEN selector = 4 THEN gs * 1.0 ELSE NULL END AS numeric_value, + (array[1, 2, 3, 4, 5])[gs % 5 + 1] AS workflow_id, -- Use only existing workflow IDs + CASE WHEN selector = 5 THEN ('{"key":' || gs::text || '}')::jsonb ELSE NULL END AS jsonb_value, + (gs % 100) + 1 AS parent_value_id -- Always non-null +FROM + generate_series(1, 100000) AS gs, + (SELECT floor(random() * 5 + 1)::int AS selector) AS s; \ No newline at end of file diff --git a/test_runner/performance/large_synthetic_oltp/grow_vertices.sql b/test_runner/performance/large_synthetic_oltp/grow_vertices.sql new file mode 100644 index 0000000000..87a2410e8a --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/grow_vertices.sql @@ -0,0 +1,26 @@ +-- add 100000 rows or approx.
18 MB to the vertices table +-- takes about 90 seconds +INSERT INTO workflows.vertices( + uuid, + created_at, + condition_block_id, + operator, + has_been_visited, + reference_id, + workflow_id, + meta_data, + -- id, + action_block_id +) +SELECT + uuid_generate_v4() AS uuid, + now() AS created_at, + CASE WHEN (gs % 2 = 0) THEN gs % 10 ELSE NULL END AS condition_block_id, -- Every alternate row has a condition_block_id + 'operator_' || (gs % 10) AS operator, -- Cyclical operator values (e.g., operator_0, operator_1) + false AS has_been_visited, + 'ref_' || gs AS reference_id, -- Unique reference_id for each row + (gs % 1000) + 1 AS workflow_id, -- Cyclic workflow_id values between 1 and 1000 + '{}'::jsonb AS meta_data, -- Empty JSON metadata + -- gs AS id, -- default from sequence to get unique ID + CASE WHEN (gs % 2 = 1) THEN gs ELSE NULL END AS action_block_id -- Complementary to condition_block_id +FROM generate_series(1, 100000) AS gs; \ No newline at end of file diff --git a/test_runner/performance/large_synthetic_oltp/update_accounting_coding_body_tracking_category_selection.sql b/test_runner/performance/large_synthetic_oltp/update_accounting_coding_body_tracking_category_selection.sql new file mode 100644 index 0000000000..78688fc8ba --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_accounting_coding_body_tracking_category_selection.sql @@ -0,0 +1,9 @@ +-- update approximately 2000 rows or 200 kb in the accounting_coding_body_tracking_category_selection table +-- takes about 1 second +UPDATE accounting.accounting_coding_body_tracking_category_selection +SET created_at = now() +WHERE ctid in ( + SELECT ctid + FROM accounting.accounting_coding_body_tracking_category_selection + TABLESAMPLE SYSTEM (0.0005) +); diff --git a/test_runner/performance/large_synthetic_oltp/update_action_blocks.sql b/test_runner/performance/large_synthetic_oltp/update_action_blocks.sql new file mode 100644 index 0000000000..ad1ee6c749 --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_action_blocks.sql @@ -0,0 +1,9 @@ +-- update approximately 9000 rows or 1 MB in the action_blocks table +-- takes about 1 second +UPDATE workflows.action_blocks +SET run_synchronously = NOT run_synchronously +WHERE ctid in ( + SELECT ctid + FROM workflows.action_blocks + TABLESAMPLE SYSTEM (0.001) +); diff --git a/test_runner/performance/large_synthetic_oltp/update_action_kwargs.sql b/test_runner/performance/large_synthetic_oltp/update_action_kwargs.sql new file mode 100644 index 0000000000..b939c0ff2d --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_action_kwargs.sql @@ -0,0 +1,9 @@ +-- update approximately 5000 rows or 1 MB in the action_kwargs table +-- takes about 1 second +UPDATE workflows.action_kwargs +SET created_at = now() +WHERE ctid in ( + SELECT ctid + FROM workflows.action_kwargs + TABLESAMPLE SYSTEM (0.0002) +); diff --git a/test_runner/performance/large_synthetic_oltp/update_denormalized_approval_workflow.sql b/test_runner/performance/large_synthetic_oltp/update_denormalized_approval_workflow.sql new file mode 100644 index 0000000000..671ddbc2d4 --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_denormalized_approval_workflow.sql @@ -0,0 +1,10 @@ +-- update approximately 3000 rows or 500 KB in the denormalized_approval_workflow table +-- takes about 1 second +UPDATE approvals_v2.denormalized_approval_workflow +SET created_at = now() +WHERE ctid in ( + SELECT ctid + FROM approvals_v2.denormalized_approval_workflow + TABLESAMPLE
SYSTEM (0.0005) +); + diff --git a/test_runner/performance/large_synthetic_oltp/update_device_fingerprint_event.sql b/test_runner/performance/large_synthetic_oltp/update_device_fingerprint_event.sql new file mode 100644 index 0000000000..20baf12887 --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_device_fingerprint_event.sql @@ -0,0 +1,9 @@ +-- update approximately 2000 rows or 1 MB in the device_fingerprint_event table +-- takes about 5 seconds +UPDATE authentication.device_fingerprint_event +SET is_incognito = NOT is_incognito +WHERE ctid in ( + SELECT ctid + FROM authentication.device_fingerprint_event + TABLESAMPLE SYSTEM (0.001) +); diff --git a/test_runner/performance/large_synthetic_oltp/update_edges.sql b/test_runner/performance/large_synthetic_oltp/update_edges.sql new file mode 100644 index 0000000000..d79da78de3 --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_edges.sql @@ -0,0 +1,9 @@ +-- update approximately 4000 rows or 600 kb in the edges table +-- takes about 1 second +UPDATE workflows.edges +SET created_at = now() +WHERE ctid in ( + SELECT ctid + FROM workflows.edges + TABLESAMPLE SYSTEM (0.0005) +); diff --git a/test_runner/performance/large_synthetic_oltp/update_heron_transaction_enriched_log.sql b/test_runner/performance/large_synthetic_oltp/update_heron_transaction_enriched_log.sql new file mode 100644 index 0000000000..5bcc885736 --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_heron_transaction_enriched_log.sql @@ -0,0 +1,9 @@ +-- update approximately 10000 rows or 200 KB in the heron_transaction_enriched_log table +-- takes about 1 minute +UPDATE heron.heron_transaction_enriched_log +SET created_at = now() +WHERE ctid in ( + SELECT ctid + FROM heron.heron_transaction_enriched_log + TABLESAMPLE SYSTEM (0.0005) +); diff --git a/test_runner/performance/large_synthetic_oltp/update_heron_transaction_enrichment_requests.sql b/test_runner/performance/large_synthetic_oltp/update_heron_transaction_enrichment_requests.sql new file mode 100644 index 0000000000..02cf0ca420 --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_heron_transaction_enrichment_requests.sql @@ -0,0 +1,9 @@ +-- update approximately 4000 rows or 1 MB in the heron_transaction_enrichment_requests table +-- takes about 2 minutes +UPDATE heron.heron_transaction_enrichment_requests +SET created_at = now() +WHERE ctid in ( + SELECT ctid + FROM heron.heron_transaction_enrichment_requests + TABLESAMPLE SYSTEM (0.0002) +); diff --git a/test_runner/performance/large_synthetic_oltp/update_hotel_rate_mapping.sql b/test_runner/performance/large_synthetic_oltp/update_hotel_rate_mapping.sql new file mode 100644 index 0000000000..3210b6dff8 --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_hotel_rate_mapping.sql @@ -0,0 +1,9 @@ +-- update approximately 6000 rows or 600 kb in the hotel_rate_mapping table +-- takes about 1 second +UPDATE booking_inventory.hotel_rate_mapping +SET created_at = now() +WHERE ctid in ( + SELECT ctid + FROM booking_inventory.hotel_rate_mapping + TABLESAMPLE SYSTEM (0.0005) +); diff --git a/test_runner/performance/large_synthetic_oltp/update_incoming_webhooks.sql b/test_runner/performance/large_synthetic_oltp/update_incoming_webhooks.sql new file mode 100644 index 0000000000..ea284eb47c --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_incoming_webhooks.sql @@ -0,0 +1,9 @@ +-- update approximately 2000 rows or 1 MB in the incoming_webhooks table +-- takes about 5
seconds +UPDATE webhook.incoming_webhooks +SET is_body_encrypted = NOT is_body_encrypted +WHERE ctid in ( + SELECT ctid + FROM webhook.incoming_webhooks + TABLESAMPLE SYSTEM (0.0002) +); diff --git a/test_runner/performance/large_synthetic_oltp/update_manual_transaction.sql b/test_runner/performance/large_synthetic_oltp/update_manual_transaction.sql new file mode 100644 index 0000000000..190bc625e2 --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_manual_transaction.sql @@ -0,0 +1,9 @@ +-- update approximately 1000 rows or 200 kb in the manual_transaction table +-- takes about 2 seconds +UPDATE banking.manual_transaction +SET created_at = now() +WHERE ctid in ( + SELECT ctid + FROM banking.manual_transaction + TABLESAMPLE SYSTEM (0.0005) +); diff --git a/test_runner/performance/large_synthetic_oltp/update_ml_receipt_matching_log.sql b/test_runner/performance/large_synthetic_oltp/update_ml_receipt_matching_log.sql new file mode 100644 index 0000000000..810021b09d --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_ml_receipt_matching_log.sql @@ -0,0 +1,9 @@ +-- update approximately 1000 rows or 100 kb in the ml_receipt_matching_log table +-- takes about 1 second +UPDATE receipt.ml_receipt_matching_log +SET is_shadow_mode = NOT is_shadow_mode +WHERE ctid in ( + SELECT ctid + FROM receipt.ml_receipt_matching_log + TABLESAMPLE SYSTEM (0.0005) +); diff --git a/test_runner/performance/large_synthetic_oltp/update_ocr_pipeine_results_version.sql b/test_runner/performance/large_synthetic_oltp/update_ocr_pipeine_results_version.sql new file mode 100644 index 0000000000..a1da8fdb07 --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_ocr_pipeine_results_version.sql @@ -0,0 +1,9 @@ +-- update approximately 2000 rows or 400 kb in the ocr_pipeline_results_version table +-- takes about 1 second +UPDATE ocr.ocr_pipeline_results_version +SET is_async = NOT is_async +WHERE ctid in ( + SELECT ctid + FROM ocr.ocr_pipeline_results_version + TABLESAMPLE SYSTEM (0.0005) +); diff --git a/test_runner/performance/large_synthetic_oltp/update_orc_pipeline_step_results.sql b/test_runner/performance/large_synthetic_oltp/update_orc_pipeline_step_results.sql new file mode 100644 index 0000000000..b7bb4932bd --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_orc_pipeline_step_results.sql @@ -0,0 +1,9 @@ +-- update approximately 3000 rows or 1 MB in the ocr_pipeline_step_results table +-- takes about 11 seconds +UPDATE ocr.ocr_pipeline_step_results +SET created_at = now() +WHERE ctid in ( + SELECT ctid + FROM ocr.ocr_pipeline_step_results + TABLESAMPLE SYSTEM (0.0005) +); diff --git a/test_runner/performance/large_synthetic_oltp/update_orc_pipeline_step_results_version.sql b/test_runner/performance/large_synthetic_oltp/update_orc_pipeline_step_results_version.sql new file mode 100644 index 0000000000..83e9765d22 --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_orc_pipeline_step_results_version.sql @@ -0,0 +1,9 @@ +-- update approximately 5000 rows or 1 MB in the ocr_pipeline_step_results_version table +-- takes about 40 seconds +UPDATE ocr.ocr_pipeline_step_results_version +SET created_at = now() +WHERE ctid in ( + SELECT ctid + FROM ocr.ocr_pipeline_step_results_version + TABLESAMPLE SYSTEM (0.0005) +); diff --git a/test_runner/performance/large_synthetic_oltp/update_priceline_raw_response.sql b/test_runner/performance/large_synthetic_oltp/update_priceline_raw_response.sql new file mode 100644 index 
0000000000..a434c6cb63 --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_priceline_raw_response.sql @@ -0,0 +1,9 @@ +-- update approximately 5000 rows or 1 MB in the priceline_raw_response table +-- takes about 1 second +UPDATE booking_inventory.priceline_raw_response +SET created_at = now() +WHERE ctid in ( + SELECT ctid + FROM booking_inventory.priceline_raw_response + TABLESAMPLE SYSTEM (0.0005) +); diff --git a/test_runner/performance/large_synthetic_oltp/update_quickbooks_transactions.sql b/test_runner/performance/large_synthetic_oltp/update_quickbooks_transactions.sql new file mode 100644 index 0000000000..a783246c4c --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_quickbooks_transactions.sql @@ -0,0 +1,9 @@ +-- update approximately 5000 rows or 1 MB in the quickbooks_transactions table +-- takes about 30 seconds +UPDATE accounting.quickbooks_transactions +SET created_at = now() +WHERE ctid in ( + SELECT ctid + FROM accounting.quickbooks_transactions + TABLESAMPLE SYSTEM (0.0005) +); diff --git a/test_runner/performance/large_synthetic_oltp/update_raw_finicity_transaction.sql b/test_runner/performance/large_synthetic_oltp/update_raw_finicity_transaction.sql new file mode 100644 index 0000000000..91fb1bc789 --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_raw_finicity_transaction.sql @@ -0,0 +1,15 @@ +-- update approximately 6000 rows or 600 kb in the raw_finicity_transaction table +-- takes about 1 second +UPDATE banking.raw_finicity_transaction +SET raw_data = + jsonb_set( + raw_data, + '{updated}', + to_jsonb(now()), + true + ) +WHERE ctid IN ( + SELECT ctid + FROM banking.raw_finicity_transaction + TABLESAMPLE SYSTEM (0.0005) +); diff --git a/test_runner/performance/large_synthetic_oltp/update_relabeled_transactions.sql b/test_runner/performance/large_synthetic_oltp/update_relabeled_transactions.sql new file mode 100644 index 0000000000..87b402f9e7 --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_relabeled_transactions.sql @@ -0,0 +1,9 @@ +-- update approximately 8000 rows or 1 MB in the relabeled_transactions table +-- takes about 1 second +UPDATE heron.relabeled_transactions +SET created_at = now() +WHERE ctid in ( + SELECT ctid + FROM heron.relabeled_transactions + TABLESAMPLE SYSTEM (0.0005) +); diff --git a/test_runner/performance/large_synthetic_oltp/update_state_values.sql b/test_runner/performance/large_synthetic_oltp/update_state_values.sql new file mode 100644 index 0000000000..2365ea3d6b --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_state_values.sql @@ -0,0 +1,9 @@ +-- update approximately 8000 rows or 1 MB in the state_values table +-- takes about 2 minutes +UPDATE workflows.state_values +SET state_type = now()::text +WHERE ctid in ( + SELECT ctid + FROM workflows.state_values + TABLESAMPLE SYSTEM (0.0002) +); \ No newline at end of file diff --git a/test_runner/performance/large_synthetic_oltp/update_stripe_authorization_event_log.sql b/test_runner/performance/large_synthetic_oltp/update_stripe_authorization_event_log.sql new file mode 100644 index 0000000000..5328db9fb8 --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_stripe_authorization_event_log.sql @@ -0,0 +1,9 @@ +-- update approximately 4000 rows or 1 MB in the stripe_authorization_event_log table +-- takes about 5 minutes +UPDATE stripe.stripe_authorization_event_log +SET approved = NOT approved +WHERE ctid in ( + SELECT ctid + FROM stripe.stripe_authorization_event_log + 
TABLESAMPLE SYSTEM (0.0002) +); \ No newline at end of file diff --git a/test_runner/performance/large_synthetic_oltp/update_transaction.sql b/test_runner/performance/large_synthetic_oltp/update_transaction.sql new file mode 100644 index 0000000000..83bec52065 --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_transaction.sql @@ -0,0 +1,9 @@ +-- update approximately 2000 rows or 301 MB in the transaction table +-- takes about 90 seconds +UPDATE transaction.transaction +SET is_last = NOT is_last +WHERE ctid in ( + SELECT ctid + FROM transaction.transaction + TABLESAMPLE SYSTEM (0.0002) +); diff --git a/test_runner/performance/large_synthetic_oltp/update_values.sql b/test_runner/performance/large_synthetic_oltp/update_values.sql new file mode 100644 index 0000000000..e5d576dae5 --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_values.sql @@ -0,0 +1,9 @@ +-- update approximately 2500 rows or 1 MB in the values table +-- takes about 3 minutes +UPDATE workflows.values +SET bool_value = NOT bool_value +WHERE ctid in ( + SELECT ctid + FROM workflows.values + TABLESAMPLE SYSTEM (0.0002) +) AND bool_value IS NOT NULL; diff --git a/test_runner/performance/large_synthetic_oltp/update_vertices.sql b/test_runner/performance/large_synthetic_oltp/update_vertices.sql new file mode 100644 index 0000000000..714c38965b --- /dev/null +++ b/test_runner/performance/large_synthetic_oltp/update_vertices.sql @@ -0,0 +1,9 @@ +-- update approximately 10000 rows or 2 MB in the vertices table +-- takes about 1 minute +UPDATE workflows.vertices +SET has_been_visited = NOT has_been_visited +WHERE ctid in ( + SELECT ctid + FROM workflows.vertices + TABLESAMPLE SYSTEM (0.0002) +); \ No newline at end of file diff --git a/test_runner/performance/pageserver/pagebench/test_large_slru_basebackup.py b/test_runner/performance/pageserver/pagebench/test_large_slru_basebackup.py index 8af52dcbd0..25dfd5277c 100644 --- a/test_runner/performance/pageserver/pagebench/test_large_slru_basebackup.py +++ b/test_runner/performance/pageserver/pagebench/test_large_slru_basebackup.py @@ -146,8 +146,6 @@ def run_benchmark(env: NeonEnv, pg_bin: PgBin, record, duration_secs: int): ps_http.base_url, "--page-service-connstring", env.pageserver.connstr(password=None), - "--gzip-probability", - "1", "--runtime", f"{duration_secs}s", # don't specify the targets explicitly, let pagebench auto-discover them diff --git a/test_runner/performance/test_perf_oltp_large_tenant.py b/test_runner/performance/test_perf_oltp_large_tenant.py index b45394d627..bd00f6b65f 100644 --- a/test_runner/performance/test_perf_oltp_large_tenant.py +++ b/test_runner/performance/test_perf_oltp_large_tenant.py @@ -31,7 +31,9 @@ def get_custom_scripts( return rv -def run_test_pgbench(env: PgCompare, custom_scripts: str, duration: int): +def run_test_pgbench( + env: PgCompare, custom_scripts: str, duration: int, clients: int = 500, jobs: int = 100 +): password = env.pg.default_options.get("password", None) options = env.pg.default_options.get("options", "") # drop password from the connection string by passing password=None and set password separately @@ -46,8 +48,8 @@ def run_test_pgbench(env: PgCompare, custom_scripts: str, duration: int): "-n", # no explicit vacuum before the test - we want to rely on auto-vacuum "-M", "prepared", - "--client=500", - "--jobs=100", + f"--client={clients}", + f"--jobs={jobs}", f"-T{duration}", "-P60", # progress every minute "--progress-timestamp", @@ -164,6 +166,12 @@ def 
test_perf_oltp_large_tenant_pgbench( run_test_pgbench(remote_compare, custom_scripts, duration) +@pytest.mark.parametrize("duration", get_durations_matrix()) +@pytest.mark.remote_cluster +def test_perf_oltp_large_tenant_growth(remote_compare: PgCompare, duration: int): + run_test_pgbench(remote_compare, " ".join(get_custom_scripts()), duration, 35, 35) + + @pytest.mark.remote_cluster def test_perf_oltp_large_tenant_maintenance(remote_compare: PgCompare): # run analyze, vacuum, re-index after the test and measure and report its duration diff --git a/test_runner/random_ops/test_random_ops.py index 645c9b7b9d..d3815c40bb 100644 --- a/test_runner/random_ops/test_random_ops.py +++ b/test_runner/random_ops/test_random_ops.py @@ -45,6 +45,8 @@ class NeonEndpoint: if self.branch.connect_env: self.connect_env = self.branch.connect_env.copy() self.connect_env["PGHOST"] = self.host + if self.type == "read_only": + self.project.read_only_endpoints_total += 1 def delete(self): self.project.delete_endpoint(self.id) @@ -228,8 +230,13 @@ class NeonProject: self.benchmarks: dict[str, subprocess.Popen[Any]] = {} self.restore_num: int = 0 self.restart_pgbench_on_console_errors: bool = False + self.limits: dict[str, Any] = self.get_limits()["limits"] + self.read_only_endpoints_total: int = 0 - def delete(self): + def get_limits(self) -> dict[str, Any]: + return self.neon_api.get_project_limits(self.id) + + def delete(self) -> None: self.neon_api.delete_project(self.id) def create_branch(self, parent_id: str | None = None) -> NeonBranch | None: @@ -282,6 +289,7 @@ class NeonProject: self.neon_api.delete_endpoint(self.id, endpoint_id) self.endpoints[endpoint_id].branch.endpoints.pop(endpoint_id) self.endpoints.pop(endpoint_id) + self.read_only_endpoints_total -= 1 self.wait() def start_benchmark(self, target: str, clients: int = 10) -> subprocess.Popen[Any]: @@ -369,49 +377,64 @@ def setup_class( print(f"::warning::Retried on 524 error {neon_api.retries524} times") if neon_api.retries4xx > 0: print(f"::warning::Retried on 4xx error {neon_api.retries4xx} times") - log.info("Removing the project") + log.info("Removing the project %s", project.id) project.delete() -def do_action(project: NeonProject, action: str) -> None: +def do_action(project: NeonProject, action: str) -> bool: """ Runs the action """ log.info("Action: %s", action) if action == "new_branch": log.info("Trying to create a new branch") + if 0 <= project.limits["max_branches"] <= len(project.branches): + log.info( + "Maximum branch limit reached (%s of %s)", + len(project.branches), + project.limits["max_branches"], + ) + return False parent = project.branches[ random.choice(list(set(project.branches.keys()) - project.reset_branches)) ] log.info("Parent: %s", parent) child = parent.create_child_branch() if child is None: - return + return False log.info("Created branch %s", child) child.start_benchmark() elif action == "delete_branch": if project.leaf_branches: - target = random.choice(list(project.leaf_branches.values())) + target: NeonBranch = random.choice(list(project.leaf_branches.values())) log.info("Trying to delete branch %s", target) target.delete() else: log.info("Leaf branches not found, skipping") + return False elif action == "new_ro_endpoint": + if 0 <= project.limits["max_read_only_endpoints"] <= project.read_only_endpoints_total: + log.info( + "Maximum read-only endpoint limit reached (%s of %s)", + project.read_only_endpoints_total, + project.limits["max_read_only_endpoints"], + ) + 
return False ep = random.choice( [br for br in project.branches.values() if br.id not in project.reset_branches] ).create_ro_endpoint() log.info("Created the RO endpoint with id %s branch: %s", ep.id, ep.branch.id) ep.start_benchmark() elif action == "delete_ro_endpoint": + if project.read_only_endpoints_total == 0: + log.info("no read_only endpoints present, skipping") + return False ro_endpoints: list[NeonEndpoint] = [ endpoint for endpoint in project.endpoints.values() if endpoint.type == "read_only" ] - if ro_endpoints: - target_ep: NeonEndpoint = random.choice(ro_endpoints) - target_ep.delete() - log.info("endpoint %s deleted", target_ep.id) - else: - log.info("no read_only endpoints present, skipping") + target_ep: NeonEndpoint = random.choice(ro_endpoints) + target_ep.delete() + log.info("endpoint %s deleted", target_ep.id) elif action == "restore_random_time": if project.leaf_branches: br: NeonBranch = random.choice(list(project.leaf_branches.values())) @@ -419,8 +442,10 @@ def do_action(project: NeonProject, action: str) -> None: br.restore_random_time() else: log.info("No leaf branches found") + return False else: raise ValueError(f"The action {action} is unknown") + return True @pytest.mark.timeout(7200) @@ -457,8 +482,9 @@ def test_api_random( pg_bin.run(["pgbench", "-i", "-I", "dtGvp", "-s100"], env=project.main_branch.connect_env) for _ in range(num_operations): log.info("Starting action #%s", _ + 1) - do_action( + while not do_action( project, random.choices([a[0] for a in ACTIONS], weights=[w[1] for w in ACTIONS])[0] - ) + ): + log.info("Retrying...") project.check_all_benchmarks() assert True diff --git a/test_runner/regress/test_attach_tenant_config.py b/test_runner/regress/test_attach_tenant_config.py index dc44fc77db..7788faceb4 100644 --- a/test_runner/regress/test_attach_tenant_config.py +++ b/test_runner/regress/test_attach_tenant_config.py @@ -184,7 +184,7 @@ def test_fully_custom_config(positive_env: NeonEnv): "timeline_offloading": False, "rel_size_v2_enabled": True, "relsize_snapshot_cache_capacity": 10000, - "gc_compaction_enabled": True, + "gc_compaction_enabled": False, "gc_compaction_verification": False, "gc_compaction_initial_threshold_kb": 1024000, "gc_compaction_ratio_percent": 200, diff --git a/test_runner/regress/test_basebackup.py b/test_runner/regress/test_basebackup.py index b083c394c7..d1b10ec85d 100644 --- a/test_runner/regress/test_basebackup.py +++ b/test_runner/regress/test_basebackup.py @@ -26,6 +26,10 @@ def test_basebackup_cache(neon_env_builder: NeonEnvBuilder): ps = env.pageserver ps_http = ps.http_client() + storcon_managed_timelines = (env.storage_controller_config or {}).get( + "timelines_onto_safekeepers", False + ) + # 1. Check that we always hit the cache after compute restart. for i in range(3): ep.start() @@ -33,15 +37,26 @@ def test_basebackup_cache(neon_env_builder: NeonEnvBuilder): def check_metrics(i=i): metrics = ps_http.get_metrics() - # Never miss. - # The first time compute_ctl sends `get_basebackup` with lsn=None, we do not cache such requests. - # All other requests should be a hit - assert ( - metrics.query_one( - "pageserver_basebackup_cache_read_total", {"result": "miss"} - ).value - == 0 - ) + if storcon_managed_timelines: + # We do not cache the initial basebackup yet, + # so the first compute startup should be a miss. 
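+                    # (Unlike the lsn=None case below, this first request still goes through the cache path, so it is counted as a miss rather than skipped.)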
+ assert ( + metrics.query_one( + "pageserver_basebackup_cache_read_total", {"result": "miss"} + ).value + == 1 + ) + else: + # If the timeline is not initialized on safekeepers, + # compute_ctl sends `get_basebackup` with lsn=None for the first startup. + # We do not use the cache for such requests, so it's neither a hit nor a miss. + assert ( + metrics.query_one( + "pageserver_basebackup_cache_read_total", {"result": "miss"} + ).value + == 0 + ) + # All requests after the first are hits. assert ( metrics.query_one("pageserver_basebackup_cache_read_total", {"result": "hit"}).value @@ -54,6 +69,11 @@ def test_basebackup_cache(neon_env_builder: NeonEnvBuilder): ).value == i + 1 ) + # There should be only one basebackup file in the cache. + assert metrics.query_one("pageserver_basebackup_cache_entries_total").value == 1 + # The size of one basebackup for a new DB is ~20KB. + size_bytes = metrics.query_one("pageserver_basebackup_cache_size_bytes").value + assert 10 * 1024 <= size_bytes <= 100 * 1024 wait_until(check_metrics) diff --git a/test_runner/regress/test_branching.py index 9ce618b2ad..920c538069 100644 --- a/test_runner/regress/test_branching.py +++ b/test_runner/regress/test_branching.py @@ -11,6 +11,7 @@ from fixtures.common_types import Lsn, TimelineId from fixtures.log_helper import log from fixtures.pageserver.http import PageserverApiException from fixtures.pageserver.utils import wait_until_tenant_active +from fixtures.safekeeper.http import MembershipConfiguration, TimelineCreateRequest from fixtures.utils import query_scalar from performance.test_perf_pgbench import get_scales_matrix from requests import RequestException @@ -164,6 +165,19 @@ def test_cannot_create_endpoint_on_non_uploaded_timeline(neon_env_builder: NeonE ps_http.configure_failpoints(("before-upload-index-pausable", "pause")) env.pageserver.tenant_create(env.initial_tenant) + sk = env.safekeepers[0] + assert sk + sk.http_client().timeline_create( + TimelineCreateRequest( + env.initial_tenant, + env.initial_timeline, + MembershipConfiguration(generation=1, members=[sk.safekeeper_id()], new_members=None), + int(env.pg_version) * 10000, + Lsn(0), + None, + ) + ) + initial_branch = "initial_branch" def start_creating_timeline(): diff --git a/test_runner/regress/test_compatibility.py index 784afbba82..16ab2bb359 100644 --- a/test_runner/regress/test_compatibility.py +++ b/test_runner/regress/test_compatibility.py @@ -18,6 +18,8 @@ from fixtures.neon_fixtures import ( NeonEnv, NeonEnvBuilder, PgBin, + Safekeeper, + StorageControllerApiException, flush_ep_to_pageserver, ) from fixtures.pageserver.http import PageserverApiException @@ -26,6 +28,7 @@ from fixtures.pageserver.utils import ( ) from fixtures.pg_version import PgVersion from fixtures.remote_storage import RemoteStorageKind, S3Storage, s3_storage +from fixtures.safekeeper.http import MembershipConfiguration from fixtures.workload import Workload if TYPE_CHECKING: @@ -125,6 +128,12 @@ check_ondisk_data_compatibility_if_enabled = pytest.mark.skipif( reason="CHECK_ONDISK_DATA_COMPATIBILITY env is not set", ) +skip_old_debug_versions = pytest.mark.skipif( + os.getenv("BUILD_TYPE", "debug") == "debug" + and os.getenv("DEFAULT_PG_VERSION") in [PgVersion.V14, PgVersion.V15, PgVersion.V16], + reason="compatibility snapshots not available for old versions of debug builds", +) + @pytest.mark.xdist_group("compatibility") @pytest.mark.order(before="test_forward_compatibility") @@ -195,6
+204,7 @@ ingest_lag_log_line = ".*ingesting record with timestamp lagging more than wait_ @check_ondisk_data_compatibility_if_enabled +@skip_old_debug_versions @pytest.mark.xdist_group("compatibility") @pytest.mark.order(after="test_create_snapshot") def test_backward_compatibility( @@ -222,6 +232,7 @@ def test_backward_compatibility( @check_ondisk_data_compatibility_if_enabled +@skip_old_debug_versions @pytest.mark.xdist_group("compatibility") @pytest.mark.order(after="test_create_snapshot") def test_forward_compatibility( @@ -291,7 +302,20 @@ def test_forward_compatibility( def check_neon_works(env: NeonEnv, test_output_dir: Path, sql_dump_path: Path, repo_dir: Path): ep = env.endpoints.create("main") ep_env = {"LD_LIBRARY_PATH": str(env.pg_distrib_dir / f"v{env.pg_version}/lib")} - ep.start(env=ep_env) + + # If the compatibility snapshot was created with --timelines-onto-safekeepers=false, + # we should not pass safekeeper_generation to the endpoint because the compute + # will not be able to start. + # Zero generation is INVALID_GENERATION. + generation = 0 + try: + res = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline) + generation = res["generation"] + except StorageControllerApiException as e: + if e.status_code != 404 or not re.search(r"Timeline .* not found", str(e)): + raise e + + ep.start(env=ep_env, safekeeper_generation=generation) connstr = ep.connstr() @@ -341,7 +365,7 @@ def check_neon_works(env: NeonEnv, test_output_dir: Path, sql_dump_path: Path, r ) # Timeline exists again: restart the endpoint - ep.start(env=ep_env) + ep.start(env=ep_env, safekeeper_generation=generation) pg_bin.run_capture( ["pg_dumpall", f"--dbname={connstr}", f"--file={test_output_dir / 'dump-from-wal.sql'}"] @@ -542,6 +566,24 @@ def test_historic_storage_formats( # All our artifacts should contain at least one timeline assert len(timelines) > 0 + # Import tenant does not create the timeline on safekeepers, + # because it is a debug handler and the timeline may have already been + # created on some set of safekeepers. + # Create the timeline on safekeepers manually. + # TODO(diko): when we have the script/storcon handler to migrate + # the timeline to storcon, we can replace this code with it. + mconf = MembershipConfiguration( + generation=1, + members=Safekeeper.sks_to_safekeeper_ids([env.safekeepers[0]]), + new_members=None, + ) + members_sks = Safekeeper.mconf_sks(env, mconf) + + for timeline in timelines: + Safekeeper.create_timeline( + dataset.tenant_id, timeline["timeline_id"], env.pageserver, mconf, members_sks + ) + # TODO: ensure that the snapshots we're importing contain a sensible variety of content, at the very # least they should include a mixture of deltas and image layers. Preferably they should also # contain some "exotic" stuff like aux files from logical replication. 
@@ -573,6 +615,7 @@ def test_historic_storage_formats( @check_ondisk_data_compatibility_if_enabled +@skip_old_debug_versions @pytest.mark.xdist_group("compatibility") @pytest.mark.parametrize( **fixtures.utils.allpairs_versions(), diff --git a/test_runner/regress/test_compute_metrics.py b/test_runner/regress/test_compute_metrics.py index c751a3e7cc..d1e61e597c 100644 --- a/test_runner/regress/test_compute_metrics.py +++ b/test_runner/regress/test_compute_metrics.py @@ -418,7 +418,7 @@ def test_sql_exporter_metrics_e2e( pg_user = conn_options["user"] pg_dbname = conn_options["dbname"] pg_application_name = f"sql_exporter{stem_suffix}" - connstr = f"postgresql://{pg_user}@{pg_host}:{pg_port}/{pg_dbname}?sslmode=disable&application_name={pg_application_name}" + connstr = f"postgresql://{pg_user}@{pg_host}:{pg_port}/{pg_dbname}?sslmode=disable&application_name={pg_application_name}&pgaudit.log=none" def escape_go_filepath_match_characters(s: str) -> str: """ diff --git a/test_runner/regress/test_compute_reconfigure.py b/test_runner/regress/test_compute_reconfigure.py index b533d45b1e..cc792333ba 100644 --- a/test_runner/regress/test_compute_reconfigure.py +++ b/test_runner/regress/test_compute_reconfigure.py @@ -9,6 +9,8 @@ from fixtures.utils import wait_until if TYPE_CHECKING: from fixtures.neon_fixtures import NeonEnv +from fixtures.log_helper import log + def test_compute_reconfigure(neon_simple_env: NeonEnv): """ @@ -85,3 +87,57 @@ def test_compute_reconfigure(neon_simple_env: NeonEnv): samples = metrics.query_all("compute_ctl_up", {"build_tag": build_tag}) assert len(samples) == 1 assert samples[0].value == 1 + + +def test_compute_safekeeper_connstrings_duplicate(neon_simple_env: NeonEnv): + """ + Test that we catch duplicate entries in neon.safekeepers. + """ + env = neon_simple_env + + endpoint = env.endpoints.create_start("main") + + # grab the current value of neon.safekeepers + sk_list = [] + with endpoint.cursor() as cursor: + cursor.execute("SHOW neon.safekeepers;") + row = cursor.fetchone() + assert row is not None + + log.info(f' initial neon.safekeepers: "{row}"') + + # build a safekeepers list with a duplicate + sk_list.append(row[0]) + sk_list.append(row[0]) + + safekeepers = ",".join(sk_list) + log.info(f'reconfigure neon.safekeepers: "{safekeepers}"') + + # introduce duplicate entry in neon.safekeepers, on purpose + endpoint.respec_deep( + **{ + "spec": { + "skip_pg_catalog_updates": True, + "cluster": { + "settings": [ + { + "name": "neon.safekeepers", + "vartype": "string", + "value": safekeepers, + } + ] + }, + }, + } + ) + + try: + endpoint.reconfigure() + + # Check that in logs we see that it was actually reconfigured, + # not restarted or something else. 
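+        # compute_ctl serves reconfiguration as POST /configure on its HTTP API, so a successful reconfigure leaves this request line in the log.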
+ endpoint.log_contains("INFO request{method=POST uri=/configure") + + except Exception as e: + # we expect a failure here + log.info(f"RAISED: {e}") diff --git a/test_runner/regress/test_feature_flag.py b/test_runner/regress/test_feature_flag.py new file mode 100644 index 0000000000..2712d13dcc --- /dev/null +++ b/test_runner/regress/test_feature_flag.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from fixtures.utils import run_only_on_default_postgres + +if TYPE_CHECKING: + from fixtures.neon_fixtures import NeonEnvBuilder + + +@run_only_on_default_postgres("Pageserver-only test only needs to run on one version") +def test_feature_flag(neon_env_builder: NeonEnvBuilder): + env = neon_env_builder.init_start() + env.pageserver.http_client().force_override_feature_flag("test-feature-flag", "true") + assert env.pageserver.http_client().evaluate_feature_flag_boolean( + env.initial_tenant, "test-feature-flag" + )["result"]["Ok"] + assert ( + env.pageserver.http_client().evaluate_feature_flag_multivariate( + env.initial_tenant, "test-feature-flag" + )["result"]["Ok"] + == "true" + ) + + env.pageserver.http_client().force_override_feature_flag("test-feature-flag", "false") + assert ( + env.pageserver.http_client().evaluate_feature_flag_boolean( + env.initial_tenant, "test-feature-flag" + )["result"]["Err"] + == "No condition group is matched" + ) + assert ( + env.pageserver.http_client().evaluate_feature_flag_multivariate( + env.initial_tenant, "test-feature-flag" + )["result"]["Ok"] + == "false" + ) + + env.pageserver.http_client().force_override_feature_flag("test-feature-flag", None) + assert ( + "Err" + in env.pageserver.http_client().evaluate_feature_flag_boolean( + env.initial_tenant, "test-feature-flag" + )["result"] + ) + assert ( + "Err" + in env.pageserver.http_client().evaluate_feature_flag_multivariate( + env.initial_tenant, "test-feature-flag" + )["result"] + ) diff --git a/test_runner/regress/test_import.py index 55737c35f0..e1070a81e6 100644 --- a/test_runner/regress/test_import.py +++ b/test_runner/regress/test_import.py @@ -87,6 +87,9 @@ def test_import_from_vanilla(test_output_dir, pg_bin, vanilla_pg, neon_env_build # Set up pageserver for import neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS) + neon_env_builder.storage_controller_config = { + "timelines_onto_safekeepers": True, + } env = neon_env_builder.init_start() env.pageserver.tenant_create(tenant) diff --git a/test_runner/regress/test_lfc_prewarm.py index 82e1e9fcba..e1058cd644 100644 --- a/test_runner/regress/test_lfc_prewarm.py +++ b/test_runner/regress/test_lfc_prewarm.py @@ -59,7 +59,7 @@ def test_lfc_prewarm(neon_simple_env: NeonEnv, query: LfcQueryMethod): pg_conn = endpoint.connect() pg_cur = pg_conn.cursor() - pg_cur.execute("create extension neon version '1.6'") + pg_cur.execute("create extension neon") pg_cur.execute("create database lfc") lfc_conn = endpoint.connect(dbname="lfc") @@ -84,11 +84,8 @@ def test_lfc_prewarm(neon_simple_env: NeonEnv, query: LfcQueryMethod): endpoint.stop() endpoint.start() - # wait until compute_ctl completes downgrade of extension to default version - time.sleep(1) pg_conn = endpoint.connect() pg_cur = pg_conn.cursor() - pg_cur.execute("alter extension neon update to '1.6'") lfc_conn = endpoint.connect(dbname="lfc") lfc_cur = lfc_conn.cursor() @@ -144,7 +141,7 @@ def test_lfc_prewarm_under_workload(neon_simple_env: NeonEnv,
query: LfcQueryMethod): pg_conn = endpoint.connect() pg_cur = pg_conn.cursor() - pg_cur.execute("create extension neon version '1.6'") + pg_cur.execute("create extension neon") pg_cur.execute("CREATE DATABASE lfc") lfc_conn = endpoint.connect(dbname="lfc") @@ -188,7 +185,8 @@ def test_lfc_prewarm_under_workload(neon_simple_env: NeonEnv, query: LfcQueryMet pg_cur.execute("select pg_reload_conf()") if query is LfcQueryMethod.COMPUTE_CTL: - http_client.prewarm_lfc() + # Same as prewarm_lfc() with no argument, but exercises the other calling method + http_client.prewarm_lfc(endpoint.endpoint_id) else: pg_cur.execute("select prewarm_local_cache(%s)", (lfc_state,)) diff --git a/test_runner/regress/test_neon_extension.py index e79ab458ca..6bcd15d463 100644 --- a/test_runner/regress/test_neon_extension.py +++ b/test_runner/regress/test_neon_extension.py @@ -29,7 +29,7 @@ def test_neon_extension(neon_env_builder: NeonEnvBuilder): # IMPORTANT: # If the version has changed, the test should be updated. # Ensure that the default version is also updated in the neon.control file - assert cur.fetchone() == ("1.5",) + assert cur.fetchone() == ("1.6",) cur.execute("SELECT * from neon.NEON_STAT_FILE_CACHE") res = cur.fetchall() log.info(res) @@ -53,10 +53,10 @@ def test_neon_extension_compatibility(neon_env_builder: NeonEnvBuilder): # IMPORTANT: # If the version has changed, the test should be updated. # Ensure that the default version is also updated in the neon.control file - assert cur.fetchone() == ("1.5",) + assert cur.fetchone() == ("1.6",) cur.execute("SELECT * from neon.NEON_STAT_FILE_CACHE") - all_versions = ["1.5", "1.4", "1.3", "1.2", "1.1", "1.0"] - current_version = "1.5" + all_versions = ["1.6", "1.5", "1.4", "1.3", "1.2", "1.1", "1.0"] + current_version = "1.6" for idx, begin_version in enumerate(all_versions): for target_version in all_versions[idx + 1 :]: if current_version != begin_version: diff --git a/test_runner/regress/test_normal_work.py index 44590ea4b9..3335cf686c 100644 --- a/test_runner/regress/test_normal_work.py +++ b/test_runner/regress/test_normal_work.py @@ -64,6 +64,11 @@ def test_normal_work( """ neon_env_builder.num_safekeepers = num_safekeepers + + if safekeeper_proto_version == 2: + neon_env_builder.storage_controller_config = { + "timelines_onto_safekeepers": False, + } env = neon_env_builder.init_start() pageserver_http = env.pageserver.http_client() diff --git a/test_runner/regress/test_pageserver_restarts_under_workload.py index 9f19c887a4..6b33b3e046 100644 --- a/test_runner/regress/test_pageserver_restarts_under_workload.py +++ b/test_runner/regress/test_pageserver_restarts_under_workload.py @@ -16,7 +16,7 @@ if TYPE_CHECKING: # Test restarting page server, while safekeeper and compute node keep # running.
-def test_pageserver_restarts_under_worload(neon_simple_env: NeonEnv, pg_bin: PgBin): +def test_pageserver_restarts_under_workload(neon_simple_env: NeonEnv, pg_bin: PgBin): env = neon_simple_env env.create_branch("test_pageserver_restarts") endpoint = env.endpoints.create_start("test_pageserver_restarts") @@ -28,7 +28,11 @@ def test_pageserver_restarts_under_worload(neon_simple_env: NeonEnv, pg_bin: PgB pg_bin.run_capture(["pgbench", "-i", "-I", "dtGvp", f"-s{scale}", connstr]) pg_bin.run_capture(["pgbench", f"-T{n_restarts}", connstr]) - thread = threading.Thread(target=run_pgbench, args=(endpoint.connstr(),), daemon=True) + thread = threading.Thread( + target=run_pgbench, + args=(endpoint.connstr(options="-cstatement_timeout=360s"),), + daemon=True, + ) thread.start() for _ in range(n_restarts): diff --git a/test_runner/regress/test_pg_regress.py b/test_runner/regress/test_pg_regress.py index 474002353b..728241b465 100644 --- a/test_runner/regress/test_pg_regress.py +++ b/test_runner/regress/test_pg_regress.py @@ -173,7 +173,11 @@ def test_pg_regress( (runpath / "testtablespace").mkdir(parents=True) # Compute all the file locations that pg_regress will need. - build_path = pg_distrib_dir / f"build/{env.pg_version.v_prefixed}/src/test/regress" + # + # XXX: We assume that the `build` directory is a sibling of the + # pg_distrib_dir. That is the default when you check out the + # repository; `build` and `pg_install` are created side by side. + build_path = pg_distrib_dir / f"../build/{env.pg_version.v_prefixed}/src/test/regress" src_path = base_dir / f"vendor/postgres-{env.pg_version.v_prefixed}/src/test/regress" bindir = pg_distrib_dir / f"v{env.pg_version}/bin" schedule = src_path / "parallel_schedule" @@ -250,7 +254,11 @@ def test_isolation( (runpath / "testtablespace").mkdir(parents=True) # Compute all the file locations that pg_isolation_regress will need. - build_path = pg_distrib_dir / f"build/{env.pg_version.v_prefixed}/src/test/isolation" + # + # XXX: We assume that the `build` directory is a sibling of the + # pg_distrib_dir. That is the default when you check out the + # repository; `build` and `pg_install` are created side by side. + build_path = pg_distrib_dir / f"../build/{env.pg_version.v_prefixed}/src/test/isolation" src_path = base_dir / f"vendor/postgres-{env.pg_version.v_prefixed}/src/test/isolation" bindir = pg_distrib_dir / f"v{env.pg_version}/bin" schedule = src_path / "isolation_schedule" @@ -306,13 +314,7 @@ def test_sql_regress( ) # Connect to postgres and create a database called "regression". - endpoint = env.endpoints.create_start( - "main", - config_lines=[ - # Enable the test mode, so that we don't need to patch the test cases. - "neon.regress_test_mode = true", - ], - ) + endpoint = env.endpoints.create_start("main") endpoint.safe_psql(f"CREATE DATABASE {DBNAME}") # Create some local directories for pg_regress to run in. @@ -320,8 +322,11 @@ def test_sql_regress( (runpath / "testtablespace").mkdir(parents=True) # Compute all the file locations that pg_regress will need. - # This test runs neon specific tests - build_path = pg_distrib_dir / f"build/v{env.pg_version}/src/test/regress" + # + # XXX: We assume that the `build` directory is a sibling of the + # pg_distrib_dir. That is the default when you check out the + # repository; `build` and `pg_install` are created side by side. 
+    build_path = pg_distrib_dir / f"../build/{env.pg_version.v_prefixed}/src/test/regress"
     src_path = base_dir / "test_runner/sql_regress"
     bindir = pg_distrib_dir / f"v{env.pg_version}/bin"
     schedule = src_path / "parallel_schedule"
diff --git a/test_runner/regress/test_proxy_allowed_ips.py b/test_runner/regress/test_proxy_allowed_ips.py
index 7384326385..5ac74585b9 100644
--- a/test_runner/regress/test_proxy_allowed_ips.py
+++ b/test_runner/regress/test_proxy_allowed_ips.py
@@ -19,11 +19,15 @@ TABLE_NAME = "neon_control_plane.endpoints"
 async def test_proxy_psql_allowed_ips(static_proxy: NeonProxy, vanilla_pg: VanillaPostgres):
     # Shouldn't be able to connect to this project
     vanilla_pg.safe_psql(
-        f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('private-project', '8.8.8.8')"
+        f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('private-project', '8.8.8.8')",
+        user="proxy",
+        password="password",
     )
     # Should be able to connect to this project
     vanilla_pg.safe_psql(
-        f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('generic-project', '::1,127.0.0.1')"
+        f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('generic-project', '::1,127.0.0.1')",
+        user="proxy",
+        password="password",
     )
 
     def check_cannot_connect(**kwargs):
@@ -60,7 +64,9 @@ async def test_proxy_http_allowed_ips(static_proxy: NeonProxy, vanilla_pg: Vanil
 
     # Shouldn't be able to connect to this project
     vanilla_pg.safe_psql(
-        f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('proxy', '8.8.8.8')"
+        f"INSERT INTO {TABLE_NAME} (endpoint_id, allowed_ips) VALUES ('proxy', '8.8.8.8')",
+        user="proxy",
+        password="password",
     )
 
     def query(status: int, query: str, *args):
@@ -75,6 +81,8 @@ async def test_proxy_http_allowed_ips(static_proxy: NeonProxy, vanilla_pg: Vanil
     query(400, "select 1;")  # ip address is not allowed
     # Should be able to connect to this project
     vanilla_pg.safe_psql(
-        f"UPDATE {TABLE_NAME} SET allowed_ips = '8.8.8.8,127.0.0.1' WHERE endpoint_id = 'proxy'"
+        f"UPDATE {TABLE_NAME} SET allowed_ips = '8.8.8.8,127.0.0.1' WHERE endpoint_id = 'proxy'",
+        user="proxy",
+        password="password",
     )
     query(200, "select 1;")  # should work now
diff --git a/test_runner/regress/test_replica_promotes.py b/test_runner/regress/test_replica_promotes.py
index e378d37635..4486901bae 100644
--- a/test_runner/regress/test_replica_promotes.py
+++ b/test_runner/regress/test_replica_promotes.py
@@ -4,13 +4,25 @@
 File with secondary->primary promotion testing.
 
 This far, only contains a test that we don't break and that the data is persisted.
""" +from typing import cast + import psycopg2 +from fixtures.common_types import Lsn from fixtures.log_helper import log from fixtures.neon_fixtures import Endpoint, NeonEnv, wait_replica_caughtup from fixtures.pg_version import PgVersion from pytest import raises +def stop_and_check_lsn(ep: Endpoint, expected_lsn: Lsn | None): + ep.stop(mode="immediate-terminate") + lsn = ep.terminate_flush_lsn + if expected_lsn is not None: + assert lsn >= expected_lsn, f"{expected_lsn=} < {lsn=}" + else: + assert lsn == expected_lsn, f"{expected_lsn=} != {lsn=}" + + def test_replica_promotes(neon_simple_env: NeonEnv, pg_version: PgVersion): """ Test that a replica safely promotes, and can commit data updates which @@ -37,7 +49,9 @@ def test_replica_promotes(neon_simple_env: NeonEnv, pg_version: PgVersion): pg_current_wal_flush_lsn() """ ) - log.info(f"Primary: Current LSN after workload is {primary_cur.fetchone()}") + lsn_triple = cast("tuple[str, str, str]", primary_cur.fetchone()) + log.info(f"Primary: Current LSN after workload is {lsn_triple}") + expected_primary_lsn: Lsn = Lsn(lsn_triple[2]) primary_cur.execute("show neon.safekeepers") safekeepers = primary_cur.fetchall()[0][0] @@ -57,7 +71,7 @@ def test_replica_promotes(neon_simple_env: NeonEnv, pg_version: PgVersion): secondary_cur.execute("select count(*) from t") assert secondary_cur.fetchone() == (100,) - primary.stop_and_destroy(mode="immediate") + stop_and_check_lsn(primary, expected_primary_lsn) # Reconnect to the secondary to make sure we get a read-write connection promo_conn = secondary.connect() @@ -109,9 +123,10 @@ def test_replica_promotes(neon_simple_env: NeonEnv, pg_version: PgVersion): # wait_for_last_flush_lsn(env, secondary, env.initial_tenant, env.initial_timeline) - secondary.stop_and_destroy() + # secondaries don't sync safekeepers on finish so LSN will be None + stop_and_check_lsn(secondary, None) - primary = env.endpoints.create_start(branch_name="main", endpoint_id="primary") + primary = env.endpoints.create_start(branch_name="main", endpoint_id="primary2") with primary.connect() as new_primary: new_primary_cur = new_primary.cursor() @@ -122,7 +137,9 @@ def test_replica_promotes(neon_simple_env: NeonEnv, pg_version: PgVersion): pg_current_wal_flush_lsn() """ ) - log.info(f"New primary: Boot LSN is {new_primary_cur.fetchone()}") + lsn_triple = cast("tuple[str, str, str]", new_primary_cur.fetchone()) + expected_primary_lsn = Lsn(lsn_triple[2]) + log.info(f"New primary: Boot LSN is {lsn_triple}") new_primary_cur.execute("select count(*) from t") assert new_primary_cur.fetchone() == (200,) @@ -130,4 +147,4 @@ def test_replica_promotes(neon_simple_env: NeonEnv, pg_version: PgVersion): new_primary_cur.execute("select count(*) from t") assert new_primary_cur.fetchone() == (300,) - primary.stop(mode="immediate") + stop_and_check_lsn(primary, expected_primary_lsn) diff --git a/test_runner/regress/test_s3_restore.py b/test_runner/regress/test_s3_restore.py index 082808f9ff..2d7be1f9d1 100644 --- a/test_runner/regress/test_s3_restore.py +++ b/test_runner/regress/test_s3_restore.py @@ -74,7 +74,7 @@ def test_tenant_s3_restore( last_flush_lsn = Lsn(endpoint.safe_psql("SELECT pg_current_wal_flush_lsn()")[0][0]) last_flush_lsns.append(last_flush_lsn) ps_http.timeline_checkpoint(tenant_id, timeline_id) - wait_for_upload(ps_http, tenant_id, timeline_id, last_flush_lsn) + wait_for_upload(ps_http, tenant_id, timeline_id, last_flush_lsn, timeout=60) log.info(f"{timeline} timeline {timeline_id} {last_flush_lsn=}") parent = timeline diff --git 
a/test_runner/regress/test_safekeeper_deletion.py b/test_runner/regress/test_safekeeper_deletion.py index b681a86103..bc79969e9a 100644 --- a/test_runner/regress/test_safekeeper_deletion.py +++ b/test_runner/regress/test_safekeeper_deletion.py @@ -30,6 +30,7 @@ def test_safekeeper_delete_timeline(neon_env_builder: NeonEnvBuilder, auth_enabl env.pageserver.allowed_errors.extend( [ ".*Timeline .* was not found in global map.*", + ".*Timeline .* has been deleted.*", ".*Timeline .* was cancelled and cannot be used anymore.*", ] ) @@ -198,6 +199,7 @@ def test_safekeeper_delete_timeline_under_load(neon_env_builder: NeonEnvBuilder) env.pageserver.allowed_errors.extend( [ ".*Timeline.*was cancelled.*", + ".*Timeline.*has been deleted.*", ".*Timeline.*was not found.*", ] ) diff --git a/test_runner/regress/test_sharding.py b/test_runner/regress/test_sharding.py index 4c9887fb92..93c621f564 100644 --- a/test_runner/regress/test_sharding.py +++ b/test_runner/regress/test_sharding.py @@ -1337,7 +1337,7 @@ def test_sharding_split_failures( # Create bystander tenants with various shard counts. They should not be affected by the aborted # splits. Regression test for https://github.com/neondatabase/cloud/issues/28589. bystanders = {} # id → shard_count - for bystander_shard_count in [1, 2, 4, 8]: + for bystander_shard_count in [1, 2, 4]: id, _ = env.create_tenant(shard_count=bystander_shard_count) bystanders[id] = bystander_shard_count @@ -1358,6 +1358,8 @@ def test_sharding_split_failures( ".*Reconcile error.*Cancelled.*", # While parent shard's client is stopped during split, flush loop updating LSNs will emit this warning ".*Failed to schedule metadata upload after updating disk_consistent_lsn.*", + # We didn't identify a secondary to remove. + ".*Keeping extra secondaries.*", ] ) @@ -1388,51 +1390,36 @@ def test_sharding_split_failures( with pytest.raises(failure.expect_exception()): env.storage_controller.tenant_shard_split(tenant_id, shard_count=4) + def assert_shard_count(shard_count: int, exclude_ps_id: int | None = None) -> None: + secondary_count = 0 + attached_count = 0 + log.info(f"Iterating over {len(env.pageservers)} pageservers to check shard count") + for ps in env.pageservers: + if exclude_ps_id is not None and ps.id == exclude_ps_id: + continue + + locations = ps.http_client().tenant_list_locations()["tenant_shards"] + for loc in locations: + tenant_shard_id = TenantShardId.parse(loc[0]) + if tenant_shard_id.tenant_id != tenant_id: + continue # skip bystanders + log.info(f"Shard {tenant_shard_id} seen on node {ps.id} in mode {loc[1]['mode']}") + assert tenant_shard_id.shard_count == shard_count + if loc[1]["mode"] == "Secondary": + secondary_count += 1 + else: + attached_count += 1 + assert secondary_count == shard_count + assert attached_count == shard_count + # We expect that the overall operation will fail, but some split requests # will have succeeded: the net result should be to return to a clean state, including # detaching any child shards. 
     def assert_rolled_back(exclude_ps_id=None) -> None:
-        secondary_count = 0
-        attached_count = 0
-        for ps in env.pageservers:
-            if exclude_ps_id is not None and ps.id == exclude_ps_id:
-                continue
-
-            locations = ps.http_client().tenant_list_locations()["tenant_shards"]
-            for loc in locations:
-                tenant_shard_id = TenantShardId.parse(loc[0])
-                if tenant_shard_id.tenant_id != tenant_id:
-                    continue  # skip bystanders
-                log.info(f"Shard {tenant_shard_id} seen on node {ps.id} in mode {loc[1]['mode']}")
-                assert tenant_shard_id.shard_count == initial_shard_count
-                if loc[1]["mode"] == "Secondary":
-                    secondary_count += 1
-                else:
-                    attached_count += 1
-
-        assert secondary_count == initial_shard_count
-        assert attached_count == initial_shard_count
+        assert_shard_count(initial_shard_count, exclude_ps_id)
 
     def assert_split_done(exclude_ps_id: int | None = None) -> None:
-        secondary_count = 0
-        attached_count = 0
-        for ps in env.pageservers:
-            if exclude_ps_id is not None and ps.id == exclude_ps_id:
-                continue
-
-            locations = ps.http_client().tenant_list_locations()["tenant_shards"]
-            for loc in locations:
-                tenant_shard_id = TenantShardId.parse(loc[0])
-                if tenant_shard_id.tenant_id != tenant_id:
-                    continue  # skip bystanders
-                log.info(f"Shard {tenant_shard_id} seen on node {ps.id} in mode {loc[1]['mode']}")
-                assert tenant_shard_id.shard_count == split_shard_count
-                if loc[1]["mode"] == "Secondary":
-                    secondary_count += 1
-                else:
-                    attached_count += 1
-        assert attached_count == split_shard_count
-        assert secondary_count == split_shard_count
+        assert_shard_count(split_shard_count, exclude_ps_id)
 
     def finish_split():
         # Having failed+rolled back, we should be able to split again
@@ -1468,6 +1455,7 @@ def test_sharding_split_failures(
 
         # The split should appear to be rolled back from the point of view of all pageservers
         # apart from the one that is offline
+        env.storage_controller.reconcile_until_idle(timeout_secs=60, max_interval=2)
         wait_until(lambda: assert_rolled_back(exclude_ps_id=failure.pageserver_id))
 
         finish_split()
@@ -1482,6 +1470,7 @@ def test_sharding_split_failures(
         log.info("Clearing failure...")
         failure.clear(env)
 
+        env.storage_controller.reconcile_until_idle(timeout_secs=60, max_interval=2)
         wait_until(assert_rolled_back)
 
         # Having rolled back, the tenant should be working
@@ -1836,3 +1825,90 @@ def test_sharding_gc(
     shard_gc_cutoff_lsn = Lsn(shard_index["metadata_bytes"]["latest_gc_cutoff_lsn"])
     log.info(f"Shard {shard_number} cutoff LSN: {shard_gc_cutoff_lsn}")
     assert shard_gc_cutoff_lsn == shard_0_gc_cutoff_lsn
+
+
+def test_split_ps_delete_old_shard_after_commit(neon_env_builder: NeonEnvBuilder):
+    """
+    Check that PageServer only deletes old shards after the split is committed, so that it
+    doesn't have to re-download a lot of files if the split aborts.
+    """
+    DBNAME = "regression"
+
+    init_shard_count = 4
+    neon_env_builder.num_pageservers = init_shard_count
+    stripe_size = 32
+
+    env = neon_env_builder.init_start(
+        initial_tenant_shard_count=init_shard_count, initial_tenant_shard_stripe_size=stripe_size
+    )
+
+    env.storage_controller.allowed_errors.extend(
+        [
+            # All split failures log a warning when they enqueue the abort operation
+            ".*Enqueuing background abort.*",
+            # Tolerate any error logs that mention a failpoint
+            ".*failpoint.*",
+        ]
+    )
+
+    endpoint = env.endpoints.create("main")
+    endpoint.respec(skip_pg_catalog_updates=False)
+    endpoint.start()
+
+    # Write some initial data.
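+    # Each row inserted below carries a ~1 KB payload, so the 1000 single-row
+    # inserts amount to roughly 1 MB of table data striped across the shards.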
+    endpoint.safe_psql(f"CREATE DATABASE {DBNAME}")
+    endpoint.safe_psql("CREATE TABLE usertable ( YCSB_KEY INT, FIELD0 TEXT);")
+
+    for _ in range(1000):
+        endpoint.safe_psql(
+            "INSERT INTO usertable SELECT random(), repeat('a', 1000);", log_query=False
+        )
+
+    # Record how many bytes we've downloaded before the split.
+    def collect_downloaded_bytes() -> list[float | None]:
+        downloaded_bytes = []
+        for page_server in env.pageservers:
+            metric = page_server.http_client().get_metric_value(
+                "pageserver_remote_ondemand_downloaded_bytes_total"
+            )
+            downloaded_bytes.append(metric)
+        return downloaded_bytes
+
+    downloaded_bytes_before = collect_downloaded_bytes()
+
+    # Attempt to split the tenant, but fail the split before it completes.
+    env.storage_controller.configure_failpoints(("shard-split-pre-complete", "return(1)"))
+    with pytest.raises(StorageControllerApiException):
+        env.storage_controller.tenant_shard_split(env.initial_tenant, shard_count=16)
+
+    # Wait until the split is aborted.
+    def check_split_is_aborted():
+        tenants = env.storage_controller.tenant_list()
+        assert len(tenants) == 1
+        shards = tenants[0]["shards"]
+        assert len(shards) == 4
+        for shard in shards:
+            assert not shard["is_splitting"]
+            assert not shard["is_reconciling"]
+
+        # Make sure all new shards have been deleted.
+        valid_shards = 0
+        for ps in env.pageservers:
+            for tenant_dir in os.listdir(ps.workdir / "tenants"):
+                try:
+                    tenant_shard_id = TenantShardId.parse(tenant_dir)
+                    valid_shards += 1
+                    assert tenant_shard_id.shard_count == 4
+                except ValueError:
+                    log.info(f"{tenant_dir} is not a valid tenant shard id")
+        assert valid_shards >= 4
+
+    wait_until(check_split_is_aborted)
+
+    endpoint.safe_psql("SELECT count(*) from usertable;", log_query=False)
+
+    # Make sure we didn't download anything following the aborted split.
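+    # (Had the old parent shards been deleted eagerly, re-attaching them after
+    # the abort would register here as on-demand downloads.)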
+    downloaded_bytes_after = collect_downloaded_bytes()
+
+    assert downloaded_bytes_before == downloaded_bytes_after
+    endpoint.stop_and_destroy()
diff --git a/test_runner/regress/test_storage_controller.py b/test_runner/regress/test_storage_controller.py
index 346ef0951d..70772766d7 100644
--- a/test_runner/regress/test_storage_controller.py
+++ b/test_runner/regress/test_storage_controller.py
@@ -88,6 +88,12 @@ def test_storage_controller_smoke(
     neon_env_builder.control_plane_hooks_api = compute_reconfigure_listener.control_plane_hooks_api
     env = neon_env_builder.init_configs()
 
+    # These bubble up from safekeepers
+    for ps in env.pageservers:
+        ps.allowed_errors.extend(
+            [".*Timeline.* has been deleted.*", ".*Timeline.*was cancelled and cannot be used"]
+        )
+
     # Start services by hand so that we can skip a pageserver (this will start + register later)
     env.broker.start()
     env.storage_controller.start()
@@ -2956,7 +2962,7 @@ def test_storage_controller_leadership_transfer_during_split(
         env.storage_controller.allowed_errors.extend(
             [".*Unexpected child shard count.*", ".*Enqueuing background abort.*"]
         )
-        pause_failpoint = "shard-split-pre-complete"
+        pause_failpoint = "shard-split-pre-complete-pause"
         env.storage_controller.configure_failpoints((pause_failpoint, "pause"))
 
         split_fut = executor.submit(
@@ -3003,7 +3009,7 @@ def test_storage_controller_leadership_transfer_during_split(
         env.storage_controller.request(
             "PUT",
             f"http://127.0.0.1:{storage_controller_1_port}/debug/v1/failpoints",
-            json=[{"name": "shard-split-pre-complete", "actions": "off"}],
+            json=[{"name": pause_failpoint, "actions": "off"}],
             headers=env.storage_controller.headers(TokenScope.ADMIN),
         )
 
@@ -3093,6 +3099,58 @@ def test_storage_controller_ps_restarted_during_drain(neon_env_builder: NeonEnvB
     wait_until(reconfigure_node_again)
 
+
+def test_ps_unavailable_after_delete(neon_env_builder: NeonEnvBuilder):
+    neon_env_builder.num_pageservers = 3
+
+    env = neon_env_builder.init_start()
+
+    def assert_nodes_count(n: int):
+        nodes = env.storage_controller.node_list()
+        assert len(nodes) == n
+
+    # All three nodes must be registered before the deletion
+    assert_nodes_count(3)
+
+    ps = env.pageservers[0]
+    env.storage_controller.node_delete(ps.id)
+
+    # After deletion, the node count must be reduced
+    assert_nodes_count(2)
+
+    # Restart the deleted pageserver in a separate thread
+    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
+        log.info("Restarting tombstoned pageserver...")
+        ps.stop()
+        ps_start_fut = executor.submit(lambda: ps.start(await_active=False))
+
+        # After deleted pageserver restart, the node count must remain the same
+        assert_nodes_count(2)
+
+        tombstones = env.storage_controller.tombstone_list()
+        assert len(tombstones) == 1 and tombstones[0]["id"] == ps.id
+
+        env.storage_controller.tombstone_delete(ps.id)
+
+        tombstones = env.storage_controller.tombstone_list()
+        assert len(tombstones) == 0
+
+        # Wait for the pageserver start operation to complete.
+        # If it fails with an exception, we try restarting the pageserver since the failure
+        # may be due to the storage controller refusing to register the node.
+        # However, a TimeoutError means the pageserver is completely hung, which is an
+        # unexpected failure mode that we'll let propagate up.
+        try:
+            ps_start_fut.result(timeout=20)
+        except TimeoutError:
+            raise
+        except Exception:
+            log.info("Restarting deleted pageserver...")
+            ps.restart()
+
+    # Finally, the node can be registered again after the tombstone is deleted
+    wait_until(lambda: assert_nodes_count(3))
+
+
 def test_storage_controller_timeline_crud_race(neon_env_builder: NeonEnvBuilder):
     """
     The storage controller is meant to handle the case where a timeline CRUD operation races
@@ -3403,7 +3461,7 @@ def test_safekeeper_deployment_time_update(neon_env_builder: NeonEnvBuilder):
 
     assert target.get_safekeeper(fake_id) is None
 
-    assert len(target.get_safekeepers()) == 0
+    start_sks = target.get_safekeepers()
 
     sk_0 = env.safekeepers[0]
 
@@ -3425,7 +3483,7 @@ def test_safekeeper_deployment_time_update(neon_env_builder: NeonEnvBuilder):
 
     inserted = target.get_safekeeper(fake_id)
     assert inserted is not None
-    assert target.get_safekeepers() == [inserted]
+    assert target.get_safekeepers() == start_sks + [inserted]
     assert eq_safekeeper_records(body, inserted)
 
     # error out if pk is changed (unexpected)
@@ -3437,7 +3495,7 @@ def test_safekeeper_deployment_time_update(neon_env_builder: NeonEnvBuilder):
     assert exc.value.status_code == 400
 
     inserted_again = target.get_safekeeper(fake_id)
-    assert target.get_safekeepers() == [inserted_again]
+    assert target.get_safekeepers() == start_sks + [inserted_again]
     assert inserted_again is not None
     assert eq_safekeeper_records(inserted, inserted_again)
 
@@ -3446,7 +3504,7 @@ def test_safekeeper_deployment_time_update(neon_env_builder: NeonEnvBuilder):
     body["version"] += 1
     target.on_safekeeper_deploy(fake_id, body)
     inserted_now = target.get_safekeeper(fake_id)
-    assert target.get_safekeepers() == [inserted_now]
+    assert target.get_safekeepers() == start_sks + [inserted_now]
     assert inserted_now is not None
     assert eq_safekeeper_records(body, inserted_now)
 
@@ -3455,7 +3513,7 @@ def test_safekeeper_deployment_time_update(neon_env_builder: NeonEnvBuilder):
     body["https_port"] = 123
     target.on_safekeeper_deploy(fake_id, body)
     inserted_now = target.get_safekeeper(fake_id)
-    assert target.get_safekeepers() == [inserted_now]
+    assert target.get_safekeepers() == start_sks + [inserted_now]
     assert inserted_now is not None
     assert eq_safekeeper_records(body, inserted_now)
     env.storage_controller.consistency_check()
@@ -3464,7 +3522,7 @@ def test_safekeeper_deployment_time_update(neon_env_builder: NeonEnvBuilder):
     body["https_port"] = None
     target.on_safekeeper_deploy(fake_id, body)
     inserted_now = target.get_safekeeper(fake_id)
-    assert target.get_safekeepers() == [inserted_now]
+    assert target.get_safekeepers() == start_sks + [inserted_now]
     assert inserted_now is not None
     assert eq_safekeeper_records(body, inserted_now)
     env.storage_controller.consistency_check()
@@ -3583,6 +3641,11 @@ def test_timeline_delete_mid_live_migration(neon_env_builder: NeonEnvBuilder, mi
     env = neon_env_builder.init_configs()
     env.start()
 
+    for ps in env.pageservers:
+        ps.allowed_errors.extend(
+            [".*Timeline.* has been deleted.*", ".*Timeline.*was cancelled and cannot be used"]
+        )
+
     tenant_id = TenantId.generate()
     timeline_id = TimelineId.generate()
     env.storage_controller.tenant_create(tenant_id, placement_policy={"Attached": 1})
@@ -4373,6 +4436,53 @@ def test_storage_controller_graceful_migration(neon_env_builder: NeonEnvBuilder,
     assert initial_ps.http_client().tenant_list_locations()["tenant_shards"] == []
 
 
+def test_attached_0_graceful_migration(neon_env_builder: NeonEnvBuilder):
+    neon_env_builder.num_pageservers = 4
+    neon_env_builder.num_azs = 2
+
+    neon_env_builder.storcon_kick_secondary_downloads = False
+
+    env = neon_env_builder.init_start()
+
+    # This is the default, but we set it explicitly to ensure that no secondary locations are requested
+    env.storage_controller.tenant_policy_update(env.initial_tenant, {"placement": {"Attached": 0}})
+    env.storage_controller.reconcile_until_idle()
+
+    desc = env.storage_controller.tenant_describe(env.initial_tenant)["shards"][0]
+    src_ps_id = desc["node_attached"]
+    src_ps = env.get_pageserver(src_ps_id)
+    src_az = desc["preferred_az_id"]
+
+    # There must be no secondary locations with Attached(0) placement policy
+    assert len(desc["node_secondary"]) == 0
+
+    # Migrate the tenant shard to another node in the same AZ
+    dst_ps = [ps for ps in env.pageservers if ps.id != src_ps_id and ps.az_id == src_az][0]
+
+    env.storage_controller.tenant_shard_migrate(
+        TenantShardId(env.initial_tenant, 0, 0),
+        dst_ps.id,
+        config=StorageControllerMigrationConfig(prewarm=True),
+    )
+
+    def tenant_shard_migrated():
+        src_locations = src_ps.http_client().tenant_list_locations()["tenant_shards"]
+        assert len(src_locations) == 0
+        log.info(f"Tenant shard migrated from {src_ps.id}")
+        dst_locations = dst_ps.http_client().tenant_list_locations()["tenant_shards"]
+        assert len(dst_locations) == 1
+        assert dst_locations[0][1]["mode"] == "AttachedSingle"
+        log.info(f"Tenant shard migrated to {dst_ps.id}")
+
+    # After all this, we expect the tenant shard to exist only on the dst node.
+    # We wait so long because [`DEFAULT_HEATMAP_PERIOD`] and [`DEFAULT_DOWNLOAD_INTERVAL`]
+    # are set to 60 seconds by default.
+    #
+    # TODO: we should consider making these configurable, so the test can run faster.
+    wait_until(tenant_shard_migrated, timeout=180, interval=5, status_interval=10)
+    log.info("Tenant shard migrated successfully")
+
+
 @run_only_on_default_postgres("this is like a 'unit test' against storcon db")
 def test_storage_controller_migrate_with_pageserver_restart(
     neon_env_builder: NeonEnvBuilder, make_httpserver
diff --git a/test_runner/regress/test_storage_scrubber.py b/test_runner/regress/test_storage_scrubber.py
index 03cd133ccb..e29cb801d5 100644
--- a/test_runner/regress/test_storage_scrubber.py
+++ b/test_runner/regress/test_storage_scrubber.py
@@ -341,6 +341,11 @@ def test_scrubber_physical_gc_timeline_deletion(neon_env_builder: NeonEnvBuilder
     env = neon_env_builder.init_configs()
     env.start()
 
+    for ps in env.pageservers:
+        ps.allowed_errors.extend(
+            [".*Timeline.* has been deleted.*", ".*Timeline.*was cancelled and cannot be used"]
+        )
+
     tenant_id = TenantId.generate()
     timeline_id = TimelineId.generate()
     env.create_tenant(
diff --git a/test_runner/regress/test_tenant_delete.py b/test_runner/regress/test_tenant_delete.py
index 8379908631..a0ff9a3ae2 100644
--- a/test_runner/regress/test_tenant_delete.py
+++ b/test_runner/regress/test_tenant_delete.py
@@ -430,6 +430,7 @@ def test_tenant_delete_stale_shards(neon_env_builder: NeonEnvBuilder, pg_bin: Pg
     workload.init()
     workload.write_rows(256)
     workload.validate()
+    workload.stop()
 
     assert_prefix_not_empty(
         neon_env_builder.pageserver_remote_storage,
diff --git a/test_runner/regress/test_tenants.py b/test_runner/regress/test_tenants.py
index d08692500f..c54dd8b38d 100644
--- a/test_runner/regress/test_tenants.py
+++ b/test_runner/regress/test_tenants.py
@@ -12,7 +12,7 @@ from typing import TYPE_CHECKING
 import pytest
 import requests
-from fixtures.common_types import Lsn, TenantId, TimelineId
+from fixtures.common_types import Lsn, TenantId, TimelineArchivalState, TimelineId
 from fixtures.log_helper import log
 from fixtures.metrics import (
     PAGESERVER_GLOBAL_METRICS,
@@ -299,6 +299,65 @@ def test_pageserver_metrics_removed_after_detach(neon_env_builder: NeonEnvBuilde
     assert post_detach_samples == set()
 
 
+def test_pageserver_metrics_removed_after_offload(neon_env_builder: NeonEnvBuilder):
+    """Tests that when a timeline is offloaded, the tenant-specific metrics are not left behind"""
+
+    neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.MOCK_S3)
+
+    neon_env_builder.num_safekeepers = 3
+
+    env = neon_env_builder.init_start()
+    tenant_1, _ = env.create_tenant()
+
+    timeline_1 = env.create_timeline("test_metrics_removed_after_offload_1", tenant_id=tenant_1)
+    timeline_2 = env.create_timeline("test_metrics_removed_after_offload_2", tenant_id=tenant_1)
+
+    endpoint_tenant1 = env.endpoints.create_start(
+        "test_metrics_removed_after_offload_1", tenant_id=tenant_1
+    )
+    endpoint_tenant2 = env.endpoints.create_start(
+        "test_metrics_removed_after_offload_2", tenant_id=tenant_1
+    )
+
+    for endpoint in [endpoint_tenant1, endpoint_tenant2]:
+        with closing(endpoint.connect()) as conn:
+            with conn.cursor() as cur:
+                cur.execute("CREATE TABLE t(key int primary key, value text)")
+                cur.execute("INSERT INTO t SELECT generate_series(1,100000), 'payload'")
+                cur.execute("SELECT sum(key) FROM t")
+                assert cur.fetchone() == (5000050000,)
+        endpoint.stop()
+
+    def get_ps_metric_samples_for_timeline(
+        tenant_id: TenantId, timeline_id: TimelineId
+    ) -> list[Sample]:
+        ps_metrics = env.pageserver.http_client().get_metrics()
+        samples = []
+        for metric_name in ps_metrics.metrics:
+            for sample in ps_metrics.query_all(
+                name=metric_name,
+                filter={"tenant_id": str(tenant_id), "timeline_id": str(timeline_id)},
+            ):
+                samples.append(sample)
+        return samples
+
+    for timeline in [timeline_1, timeline_2]:
+        pre_offload_samples = set(
+            [x.name for x in get_ps_metric_samples_for_timeline(tenant_1, timeline)]
+        )
+        assert len(pre_offload_samples) > 0, f"expected at least one sample for {timeline}"
+        env.pageserver.http_client().timeline_archival_config(
+            tenant_1,
+            timeline,
+            state=TimelineArchivalState.ARCHIVED,
+        )
+        env.pageserver.http_client().timeline_offload(tenant_1, timeline)
+        post_offload_samples = set(
+            [x.name for x in get_ps_metric_samples_for_timeline(tenant_1, timeline)]
+        )
+        assert post_offload_samples == set()
+
+
 def test_pageserver_with_empty_tenants(neon_env_builder: NeonEnvBuilder):
     env = neon_env_builder.init_start()
diff --git a/test_runner/regress/test_timeline_detach_ancestor.py b/test_runner/regress/test_timeline_detach_ancestor.py
index f0810270b1..b5cc431afe 100644
--- a/test_runner/regress/test_timeline_detach_ancestor.py
+++ b/test_runner/regress/test_timeline_detach_ancestor.py
@@ -21,7 +21,10 @@ from fixtures.neon_fixtures import (
     last_flush_lsn_upload,
     wait_for_last_flush_lsn,
 )
-from fixtures.pageserver.http import HistoricLayerInfo, PageserverApiException
+from fixtures.pageserver.http import (
+    HistoricLayerInfo,
+    PageserverApiException,
+)
 from fixtures.pageserver.utils import wait_for_last_record_lsn, wait_timeline_detail_404
 from fixtures.remote_storage import LocalFsStorage, RemoteStorageKind
 from fixtures.utils import assert_pageserver_backups_equal, skip_in_debug_build, wait_until
@@ -413,6 +416,7 @@ def test_ancestor_detach_behavior_v2(neon_env_builder: NeonEnvBuilder, snapshots
             "read_only": True,
         },
     )
+
     sk = env.safekeepers[0]
     assert sk
     with pytest.raises(requests.exceptions.HTTPError, match="Not Found"):
@@ -504,8 +508,15 @@ def test_ancestor_detach_behavior_v2(neon_env_builder: NeonEnvBuilder, snapshots
     assert len(lineage.get("original_ancestor", [])) == 0
     assert len(lineage.get("reparenting_history", [])) == 0
 
-    for name, _, _, rows, starts in expected_result:
-        with env.endpoints.create_start(name, tenant_id=env.initial_tenant) as ep:
+    for branch_name, queried_timeline, _, rows, starts in expected_result:
+        details = client.timeline_detail(env.initial_tenant, queried_timeline)
+        log.info(f"reading data from branch {branch_name}")
+        # specifying the lsn makes the endpoint read-only, so it does not connect to safekeepers
+        with env.endpoints.create(
+            branch_name,
+            lsn=Lsn(details["last_record_lsn"]),
+        ) as ep:
+            ep.start(safekeeper_generation=1)
             assert ep.safe_psql("SELECT count(*) FROM foo;")[0][0] == rows
             assert ep.safe_psql(f"SELECT count(*) FROM audit WHERE starts = {starts}")[0][0] == 1
 
@@ -1088,6 +1099,9 @@ def test_timeline_detach_ancestor_interrupted_by_deletion(
 
     for ps in env.pageservers:
         ps.allowed_errors.extend(SHUTDOWN_ALLOWED_ERRORS)
+        ps.allowed_errors.extend(
+            [".*Timeline.* has been deleted.*", ".*Timeline.*was cancelled and cannot be used"]
+        )
 
     pageservers = dict((int(p.id), p) for p in env.pageservers)
 
@@ -1209,6 +1223,9 @@ def test_sharded_tad_interleaved_after_partial_success(neon_env_builder: NeonEnv
 
     for ps in env.pageservers:
         ps.allowed_errors.extend(SHUTDOWN_ALLOWED_ERRORS)
+        ps.allowed_errors.extend(
+            [".*Timeline.* has been deleted.*", ".*Timeline.*was cancelled and cannot be used"]
+        )
 
     pageservers = dict((int(p.id), p) for p in env.pageservers)
 
diff --git a/test_runner/regress/test_timeline_gc_blocking.py b/test_runner/regress/test_timeline_gc_blocking.py
index 9a710f5b80..daba8019b6 100644
--- a/test_runner/regress/test_timeline_gc_blocking.py
+++ b/test_runner/regress/test_timeline_gc_blocking.py
@@ -24,6 +24,10 @@ def test_gc_blocking_by_timeline(neon_env_builder: NeonEnvBuilder, sharded: bool
         initial_tenant_conf={"gc_period": "1s", "lsn_lease_length": "0s"},
         initial_tenant_shard_count=2 if sharded else None,
     )
+    for ps in env.pageservers:
+        ps.allowed_errors.extend(
+            [".*Timeline.* has been deleted.*", ".*Timeline.*was cancelled and cannot be used"]
+        )
 
     if sharded:
         http = env.storage_controller.pageserver_api()
diff --git a/test_runner/regress/test_wal_acceptor.py b/test_runner/regress/test_wal_acceptor.py
index 3c337e26aa..5bd7c6022b 100644
--- a/test_runner/regress/test_wal_acceptor.py
+++ b/test_runner/regress/test_wal_acceptor.py
@@ -229,7 +229,7 @@ def test_many_timelines(neon_env_builder: NeonEnvBuilder):
 
     # Test timeline_list endpoint.
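+    # (4 rather than 3: the initial timeline is now created on the safekeepers as well.)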
     http_cli = env.safekeepers[0].http_client()
-    assert len(http_cli.timeline_list()) == 3
+    assert len(http_cli.timeline_list()) == 4
 
     # Check that dead minority doesn't prevent the commits: execute insert n_inserts
@@ -433,6 +433,7 @@ def test_wal_backup(neon_env_builder: NeonEnvBuilder):
     env.pageserver.allowed_errors.extend(
         [
             ".*Timeline .* was not found in global map.*",
+            ".*Timeline .* has been deleted.*",
             ".*Timeline .* was cancelled and cannot be used anymore.*",
         ]
     )
@@ -739,8 +740,8 @@ def test_timeline_status(neon_env_builder: NeonEnvBuilder, auth_enabled: bool):
     env = neon_env_builder.init_start()
 
     tenant_id = env.initial_tenant
-    timeline_id = env.create_branch("test_timeline_status")
-    endpoint = env.endpoints.create_start("test_timeline_status")
+    timeline_id = env.initial_timeline
+    endpoint = env.endpoints.create_start("main")
 
     wa = env.safekeepers[0]
 
@@ -1291,6 +1292,12 @@ def test_lagging_sk(neon_env_builder: NeonEnvBuilder):
 # it works without compute at all.
 def test_peer_recovery(neon_env_builder: NeonEnvBuilder):
     neon_env_builder.num_safekeepers = 3
+
+    # timelines should be created the old way
+    neon_env_builder.storage_controller_config = {
+        "timelines_onto_safekeepers": False,
+    }
+
     env = neon_env_builder.init_start()
 
     tenant_id = env.initial_tenant
@@ -1532,6 +1539,11 @@ def test_safekeeper_without_pageserver(
 
 
 def test_replace_safekeeper(neon_env_builder: NeonEnvBuilder):
+    # timelines should be created the old way manually until we have migration support
+    neon_env_builder.storage_controller_config = {
+        "timelines_onto_safekeepers": False,
+    }
+
     def execute_payload(endpoint: Endpoint):
         with closing(endpoint.connect()) as conn:
             with conn.cursor() as cur:
@@ -1661,6 +1673,15 @@ def test_pull_timeline(neon_env_builder: NeonEnvBuilder, live_sk_change: bool):
     res = env.safekeepers[3].pull_timeline(
         [env.safekeepers[0], env.safekeepers[2]], tenant_id, timeline_id
     )
+    sk_id_1 = env.safekeepers[0].safekeeper_id()
+    sk_id_3 = env.safekeepers[2].safekeeper_id()
+    sk_id_4 = env.safekeepers[3].safekeeper_id()
+    new_conf = MembershipConfiguration(
+        generation=2, members=[sk_id_1, sk_id_3, sk_id_4], new_members=None
+    )
+    for i in [0, 2, 3]:
+        env.safekeepers[i].http_client().membership_switch(tenant_id, timeline_id, new_conf)
+
     log.info("Finished pulling timeline")
     log.info(res)
 
@@ -1705,13 +1726,15 @@ def test_pull_timeline_gc(neon_env_builder: NeonEnvBuilder):
     neon_env_builder.num_safekeepers = 3
     neon_env_builder.enable_safekeeper_remote_storage(default_remote_storage())
     env = neon_env_builder.init_start()
-    tenant_id = env.initial_tenant
-    timeline_id = env.initial_timeline
 
     (src_sk, dst_sk) = (env.safekeepers[0], env.safekeepers[2])
+    dst_sk.stop()
+
+    [tenant_id, timeline_id] = env.create_tenant()
+
     log.info("use only first 2 safekeepers, 3rd will be seeded")
-    endpoint = env.endpoints.create("main")
+    endpoint = env.endpoints.create("main", tenant_id=tenant_id)
     endpoint.active_safekeepers = [1, 2]
     endpoint.start()
     endpoint.safe_psql("create table t(key int, value text)")
@@ -1723,6 +1746,7 @@ def test_pull_timeline_gc(neon_env_builder: NeonEnvBuilder):
     src_http = src_sk.http_client()
     # run pull_timeline which will halt before downloading files
     src_http.configure_failpoints(("sk-snapshot-after-list-pausable", "pause"))
+    dst_sk.start()
     pt_handle = PropagatingThread(
         target=dst_sk.pull_timeline, args=([src_sk], tenant_id, timeline_id)
     )
@@ -1782,23 +1806,27 @@ def test_pull_timeline_term_change(neon_env_builder: NeonEnvBuilder):
     neon_env_builder.enable_safekeeper_remote_storage(default_remote_storage())
     env = neon_env_builder.init_start()
     tenant_id = env.initial_tenant
-    timeline_id = env.initial_timeline
 
     (src_sk, dst_sk) = (env.safekeepers[0], env.safekeepers[2])
+    dst_sk.stop()
+    src_http = src_sk.http_client()
+    src_http.configure_failpoints(("sk-snapshot-after-list-pausable", "pause"))
+
+    timeline_id = env.create_branch("pull_timeline_term_changes")
+
+    # run pull_timeline which will halt before downloading files
     log.info("use only first 2 safekeepers, 3rd will be seeded")
-    ep = env.endpoints.create("main")
+    ep = env.endpoints.create("pull_timeline_term_changes")
     ep.active_safekeepers = [1, 2]
     ep.start()
     ep.safe_psql("create table t(key int, value text)")
     ep.safe_psql("insert into t select generate_series(1, 1000), 'pear'")
 
-    src_http = src_sk.http_client()
-    # run pull_timeline which will halt before downloading files
-    src_http.configure_failpoints(("sk-snapshot-after-list-pausable", "pause"))
     pt_handle = PropagatingThread(
         target=dst_sk.pull_timeline, args=([src_sk], tenant_id, timeline_id)
     )
+    dst_sk.start()
     pt_handle.start()
 
     src_sk.wait_until_paused("sk-snapshot-after-list-pausable")
@@ -1807,7 +1835,7 @@ def test_pull_timeline_term_change(neon_env_builder: NeonEnvBuilder):
 
     # restart compute to bump term
     ep.stop()
-    ep = env.endpoints.create("main")
+    ep = env.endpoints.create("pull_timeline_term_changes")
     ep.active_safekeepers = [1, 2]
     ep.start()
     ep.safe_psql("insert into t select generate_series(1, 100), 'pear'")
@@ -1929,12 +1957,18 @@ def test_pull_timeline_while_evicted(neon_env_builder: NeonEnvBuilder):
 @run_only_on_default_postgres("tests only safekeeper API")
 def test_membership_api(neon_env_builder: NeonEnvBuilder):
     neon_env_builder.num_safekeepers = 1
+    # timelines should be created the old way
+    neon_env_builder.storage_controller_config = {
+        "timelines_onto_safekeepers": False,
+    }
+
     env = neon_env_builder.init_start()
 
     # These are expected after timeline deletion on safekeepers.
     env.pageserver.allowed_errors.extend(
         [
             ".*Timeline .* was not found in global map.*",
+            ".*Timeline .* has been deleted.*",
             ".*Timeline .* was cancelled and cannot be used anymore.*",
         ]
     )
@@ -2008,6 +2042,12 @@ def test_explicit_timeline_creation(neon_env_builder: NeonEnvBuilder):
     created manually, later storcon will do that.
""" neon_env_builder.num_safekeepers = 3 + + # timelines should be created the old way manually + neon_env_builder.storage_controller_config = { + "timelines_onto_safekeepers": False, + } + env = neon_env_builder.init_start() tenant_id = env.initial_tenant @@ -2063,7 +2103,7 @@ def test_idle_reconnections(neon_env_builder: NeonEnvBuilder): env = neon_env_builder.init_start() tenant_id = env.initial_tenant - timeline_id = env.create_branch("test_idle_reconnections") + timeline_id = env.initial_timeline def collect_stats() -> dict[str, float]: # we need to collect safekeeper_pg_queries_received_total metric from all safekeepers @@ -2094,7 +2134,7 @@ def test_idle_reconnections(neon_env_builder: NeonEnvBuilder): collect_stats() - endpoint = env.endpoints.create_start("test_idle_reconnections") + endpoint = env.endpoints.create_start("main") # just write something to the timeline endpoint.safe_psql("create table t(i int)") collect_stats() diff --git a/test_runner/regress/test_wal_acceptor_async.py b/test_runner/regress/test_wal_acceptor_async.py index d8a7dc2a2b..1bad387a90 100644 --- a/test_runner/regress/test_wal_acceptor_async.py +++ b/test_runner/regress/test_wal_acceptor_async.py @@ -590,6 +590,13 @@ async def run_wal_truncation(env: NeonEnv, safekeeper_proto_version: int): @pytest.mark.parametrize("safekeeper_proto_version", [2, 3]) def test_wal_truncation(neon_env_builder: NeonEnvBuilder, safekeeper_proto_version: int): neon_env_builder.num_safekeepers = 3 + if safekeeper_proto_version == 2: + # On the legacy protocol, we don't support generations, which are part of + # `timelines_onto_safekeepers` + neon_env_builder.storage_controller_config = { + "timelines_onto_safekeepers": False, + } + env = neon_env_builder.init_start() asyncio.run(run_wal_truncation(env, safekeeper_proto_version)) @@ -713,6 +720,11 @@ async def run_quorum_sanity(env: NeonEnv): # we don't. def test_quorum_sanity(neon_env_builder: NeonEnvBuilder): neon_env_builder.num_safekeepers = 4 + + # The test fails basically always on the new mode. + neon_env_builder.storage_controller_config = { + "timelines_onto_safekeepers": False, + } env = neon_env_builder.init_start() asyncio.run(run_quorum_sanity(env)) diff --git a/test_runner/regress/test_wal_receiver.py b/test_runner/regress/test_wal_receiver.py index 0252b590cc..d281c055b0 100644 --- a/test_runner/regress/test_wal_receiver.py +++ b/test_runner/regress/test_wal_receiver.py @@ -16,6 +16,13 @@ if TYPE_CHECKING: # Checks that pageserver's walreceiver state is printed in the logs during WAL wait timeout. # Ensures that walreceiver does not run without any data inserted and only starts after the insertion. def test_pageserver_lsn_wait_error_start(neon_env_builder: NeonEnvBuilder): + # we assert below that the walreceiver is not active before data writes. + # with manually created timelines, it is active. 
+    # FIXME: remove this test once we remove timelines_onto_safekeepers
+    neon_env_builder.storage_controller_config = {
+        "timelines_onto_safekeepers": False,
+    }
+
     # Trigger WAL wait timeout faster
     neon_env_builder.pageserver_config_override = "wait_lsn_timeout = '1s'"
     env = neon_env_builder.init_start()
diff --git a/test_runner/sql_regress/expected/neon-event-triggers.out b/test_runner/sql_regress/expected/neon-event-triggers.out
new file mode 100644
index 0000000000..3a62e67316
--- /dev/null
+++ b/test_runner/sql_regress/expected/neon-event-triggers.out
@@ -0,0 +1,90 @@
+create or replace function admin_proc()
+    returns event_trigger
+    language plpgsql as
+$$
+begin
+    raise notice 'admin event trigger is executed for %', current_user;
+end;
+$$;
+create role neon_superuser;
+create role neon_admin login inherit createrole createdb in role neon_superuser;
+grant create on schema public to neon_admin;
+create database neondb with owner neon_admin;
+grant all privileges on database neondb to neon_superuser;
+create role neon_user;
+grant create on schema public to neon_user;
+create event trigger on_ddl1 on ddl_command_end
+execute procedure admin_proc();
+set role neon_user;
+-- check that non-privileged user cannot change neon.event_triggers
+set neon.event_triggers to false;
+ERROR: permission denied to set neon.event_triggers
+DETAIL: Only "neon_superuser" is allowed to set the GUC
+-- Non-privileged neon user should not be able to create event triggers
+create event trigger on_ddl2 on ddl_command_end
+execute procedure admin_proc();
+ERROR: permission denied to create event trigger "on_ddl2"
+HINT: Must be superuser to create an event trigger.
+set role neon_admin;
+-- neon_superuser should be able to create event triggers
+create or replace function neon_proc()
+    returns event_trigger
+    language plpgsql as
+$$
+begin
+    raise notice 'neon event trigger is executed for %', current_user;
+end;
+$$;
+NOTICE: admin event trigger is executed for neon_admin
+create event trigger on_ddl2 on ddl_command_end
+execute procedure neon_proc();
+\c neondb neon_admin
+create or replace function neondb_proc()
+    returns event_trigger
+    language plpgsql as
+$$
+begin
+    raise notice 'neondb event trigger is executed for %', current_user;
+end;
+$$;
+create or replace function neondb_secdef_proc()
+    returns event_trigger
+    language plpgsql
+    SECURITY DEFINER
+as
+$$
+begin
+    raise notice 'neondb secdef event trigger is executed for %', current_user;
+end;
+$$;
+-- neon_admin (neon_superuser member) should be able to create event triggers
+create event trigger on_ddl3 on ddl_command_end
+execute procedure neondb_proc();
+create event trigger on_ddl4 on ddl_command_end
+execute procedure neondb_secdef_proc();
+-- Check that event trigger is fired for neon_admin
+create table t1(x integer);
+NOTICE: neondb event trigger is executed for neon_admin
+NOTICE: neondb secdef event trigger is executed for neon_admin
+-- Check that event trigger can be skipped
+set neon.event_triggers to false;
+create table t2(x integer);
+WARNING: Skipping Event Trigger: neon.event_triggers is false
+WARNING: Skipping Event Trigger: neon.event_triggers is false
+\c regression cloud_admin
+-- Check that event triggers are not fired for superuser
+create table t3(x integer);
+NOTICE: admin event trigger is executed for cloud_admin
+WARNING: Skipping Event Trigger
+DETAIL: Event Trigger function "neon_proc" is owned by non-superuser role "neon_admin", and current_user "cloud_admin" is superuser
+\c neondb cloud_admin
+-- Check that user-defined event triggers are not fired for superuser
+create table t4(x integer);
+WARNING: Skipping Event Trigger
+DETAIL: Event Trigger function "neondb_proc" is owned by non-superuser role "neon_admin", and current_user "cloud_admin" is superuser
+WARNING: Skipping Event Trigger
+DETAIL: Event Trigger function "neondb_secdef_proc" is owned by non-superuser role "neon_admin", and current_user "cloud_admin" is superuser
+\c neondb neon_admin
+-- Check that neon_admin can drop event triggers
+drop event trigger on_ddl3;
+drop event trigger on_ddl4;
diff --git a/test_runner/sql_regress/parallel_schedule b/test_runner/sql_regress/parallel_schedule
index d9508d1c90..d1bd7226ed 100644
--- a/test_runner/sql_regress/parallel_schedule
+++ b/test_runner/sql_regress/parallel_schedule
@@ -9,3 +9,4 @@ test: neon-rel-truncate
 test: neon-clog
 test: neon-test-utils
 test: neon-vacuum-full
+test: neon-event-triggers
diff --git a/test_runner/sql_regress/sql/neon-event-triggers.sql b/test_runner/sql_regress/sql/neon-event-triggers.sql
new file mode 100644
index 0000000000..75365455dc
--- /dev/null
+++ b/test_runner/sql_regress/sql/neon-event-triggers.sql
@@ -0,0 +1,96 @@
+create or replace function admin_proc()
+    returns event_trigger
+    language plpgsql as
+$$
+begin
+    raise notice 'admin event trigger is executed for %', current_user;
+end;
+$$;
+
+create role neon_superuser;
+create role neon_admin login inherit createrole createdb in role neon_superuser;
+grant create on schema public to neon_admin;
+create database neondb with owner neon_admin;
+grant all privileges on database neondb to neon_superuser;
+
+create role neon_user;
+grant create on schema public to neon_user;
+
+create event trigger on_ddl1 on ddl_command_end
+execute procedure admin_proc();
+
+set role neon_user;
+
+-- check that non-privileged user cannot change neon.event_triggers
+set neon.event_triggers to false;
+
+-- Non-privileged neon user should not be able to create event triggers
+create event trigger on_ddl2 on ddl_command_end
+execute procedure admin_proc();
+
+set role neon_admin;
+
+-- neon_superuser should be able to create event triggers
+create or replace function neon_proc()
+    returns event_trigger
+    language plpgsql as
+$$
+begin
+    raise notice 'neon event trigger is executed for %', current_user;
+end;
+$$;
+
+create event trigger on_ddl2 on ddl_command_end
+execute procedure neon_proc();
+
+\c neondb neon_admin
+
+create or replace function neondb_proc()
+    returns event_trigger
+    language plpgsql as
+$$
+begin
+    raise notice 'neondb event trigger is executed for %', current_user;
+end;
+$$;
+
+create or replace function neondb_secdef_proc()
+    returns event_trigger
+    language plpgsql
+    SECURITY DEFINER
+as
+$$
+begin
+    raise notice 'neondb secdef event trigger is executed for %', current_user;
+end;
+$$;
+
+-- neon_admin (neon_superuser member) should be able to create event triggers
+create event trigger on_ddl3 on ddl_command_end
+execute procedure neondb_proc();
+
+create event trigger on_ddl4 on ddl_command_end
+execute procedure neondb_secdef_proc();
+
+-- Check that event trigger is fired for neon_admin
+create table t1(x integer);
+
+-- Check that event trigger can be skipped
+set neon.event_triggers to false;
+create table t2(x integer);
+
+\c regression cloud_admin
+
+-- Check that event triggers are not fired for superuser
+create table t3(x integer);
+
+\c neondb cloud_admin
+
+-- Check that user-defined event triggers are not fired for superuser
+create table t4(x integer);
+
+\c neondb neon_admin
+
+-- Check that neon_admin can drop event triggers
+drop event trigger on_ddl3;
+drop event trigger on_ddl4;
diff --git a/vendor/postgres-v14 b/vendor/postgres-v14
index 6770bc2513..9085654ee8 160000
--- a/vendor/postgres-v14
+++ b/vendor/postgres-v14
@@ -1 +1 @@
-Subproject commit 6770bc251301ef40c66f7ecb731741dc435b5051
+Subproject commit 9085654ee8022d5cc4ca719380a1dc53e5e3246f
diff --git a/vendor/revisions.json b/vendor/revisions.json
index 12d5499ddb..b260698c86 100644
--- a/vendor/revisions.json
+++ b/vendor/revisions.json
@@ -13,6 +13,6 @@
     ],
     "v14": [
         "14.18",
-        "6770bc251301ef40c66f7ecb731741dc435b5051"
+        "9085654ee8022d5cc4ca719380a1dc53e5e3246f"
     ]
 }
diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml
index 2b07889871..b74df50f86 100644
--- a/workspace_hack/Cargo.toml
+++ b/workspace_hack/Cargo.toml
@@ -20,8 +20,7 @@ anstream = { version = "0.6" }
 anyhow = { version = "1", features = ["backtrace"] }
 axum = { version = "0.8", features = ["ws"] }
 axum-core = { version = "0.5", default-features = false, features = ["tracing"] }
-base64-594e8ee84c453af0 = { package = "base64", version = "0.13", features = ["alloc"] }
-base64-647d43efb71741da = { package = "base64", version = "0.21" }
+base64 = { version = "0.21" }
 base64ct = { version = "1", default-features = false, features = ["std"] }
 bytes = { version = "1", features = ["serde"] }
 camino = { version = "1", default-features = false, features = ["serde1"] }