diff --git a/.github/actionlint.yml b/.github/actionlint.yml index 25b2fc702a..8a4bcaf811 100644 --- a/.github/actionlint.yml +++ b/.github/actionlint.yml @@ -31,7 +31,7 @@ config-variables: - NEON_PROD_AWS_ACCOUNT_ID - PGREGRESS_PG16_PROJECT_ID - PGREGRESS_PG17_PROJECT_ID - - PREWARM_PGBENCH_SIZE + - PREWARM_PROJECT_ID - REMOTE_STORAGE_AZURE_CONTAINER - REMOTE_STORAGE_AZURE_REGION - SLACK_CICD_CHANNEL_ID diff --git a/.github/workflows/benchbase_tpcc.yml b/.github/workflows/benchbase_tpcc.yml new file mode 100644 index 0000000000..3a36a97bb1 --- /dev/null +++ b/.github/workflows/benchbase_tpcc.yml @@ -0,0 +1,384 @@ +name: TPC-C like benchmark using benchbase + +on: + schedule: + # * is a special character in YAML so you have to quote this string + # ┌───────────── minute (0 - 59) + # │ ┌───────────── hour (0 - 23) + # │ │ ┌───────────── day of the month (1 - 31) + # │ │ │ ┌───────────── month (1 - 12 or JAN-DEC) + # │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT) + - cron: '0 6 * * *' # run once a day at 6 AM UTC + workflow_dispatch: # adds ability to run this manually + +defaults: + run: + shell: bash -euxo pipefail {0} + +concurrency: + # Allow only one workflow globally because we do not want to be too noisy in production environment + group: benchbase-tpcc-workflow + cancel-in-progress: false + +permissions: + contents: read + +jobs: + benchbase-tpcc: + strategy: + fail-fast: false # allow other variants to continue even if one fails + matrix: + include: + - warehouses: 50 # defines number of warehouses and is used to compute number of terminals + max_rate: 800 # measured max TPS at scale factor based on experiments. Adjust if performance is better/worse + min_cu: 0.25 # simulate free tier plan (0.25 -2 CU) + max_cu: 2 + - warehouses: 500 # serverless plan (2-8 CU) + max_rate: 2000 + min_cu: 2 + max_cu: 8 + - warehouses: 1000 # business plan (2-16 CU) + max_rate: 2900 + min_cu: 2 + max_cu: 16 + max-parallel: 1 # we want to run each workload size sequentially to avoid noisy neighbors + permissions: + contents: write + statuses: write + id-token: write # aws-actions/configure-aws-credentials + env: + PG_CONFIG: /tmp/neon/pg_install/v17/bin/pg_config + PSQL: /tmp/neon/pg_install/v17/bin/psql + PG_17_LIB_PATH: /tmp/neon/pg_install/v17/lib + POSTGRES_VERSION: 17 + runs-on: [ self-hosted, us-east-2, x64 ] + timeout-minutes: 1440 + + steps: + - name: Harden the runner (Audit all outbound calls) + uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0 + with: + egress-policy: audit + + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Configure AWS credentials # necessary to download artefacts + uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2 + with: + aws-region: eu-central-1 + role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} + role-duration-seconds: 18000 # 5 hours is currently max associated with IAM role + + - name: Download Neon artifact + uses: ./.github/actions/download + with: + name: neon-${{ runner.os }}-${{ runner.arch }}-release-artifact + path: /tmp/neon/ + prefix: latest + aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} + + - name: Create Neon Project + id: create-neon-project-tpcc + uses: ./.github/actions/neon-project-create + with: + region_id: aws-us-east-2 + postgres_version: ${{ env.POSTGRES_VERSION }} + compute_units: '[${{ matrix.min_cu }}, ${{ matrix.max_cu }}]' + api_key: ${{ secrets.NEON_PRODUCTION_API_KEY_4_BENCHMARKS }} + api_host: 
console.neon.tech # production (!) + + - name: Initialize Neon project + env: + BENCHMARK_TPCC_CONNSTR: ${{ steps.create-neon-project-tpcc.outputs.dsn }} + PROJECT_ID: ${{ steps.create-neon-project-tpcc.outputs.project_id }} + run: | + echo "Initializing Neon project with project_id: ${PROJECT_ID}" + export LD_LIBRARY_PATH=${PG_17_LIB_PATH} + + # Retry logic for psql connection with 1 minute sleep between attempts + for attempt in {1..3}; do + echo "Attempt ${attempt}/3: Creating extensions in Neon project" + if ${PSQL} "${BENCHMARK_TPCC_CONNSTR}" -c "CREATE EXTENSION IF NOT EXISTS neon; CREATE EXTENSION IF NOT EXISTS neon_utils;"; then + echo "Successfully created extensions" + break + else + echo "Failed to create extensions on attempt ${attempt}" + if [ ${attempt} -lt 3 ]; then + echo "Waiting 60 seconds before retry..." + sleep 60 + else + echo "All attempts failed, exiting" + exit 1 + fi + fi + done + + echo "BENCHMARK_TPCC_CONNSTR=${BENCHMARK_TPCC_CONNSTR}" >> $GITHUB_ENV + + - name: Generate BenchBase workload configuration + env: + WAREHOUSES: ${{ matrix.warehouses }} + MAX_RATE: ${{ matrix.max_rate }} + run: | + echo "Generating BenchBase configs for warehouses: ${WAREHOUSES}, max_rate: ${MAX_RATE}" + + # Extract hostname and password from connection string + # Format: postgresql://username:password@hostname/database?params (no port for Neon) + HOSTNAME=$(echo "${BENCHMARK_TPCC_CONNSTR}" | sed -n 's|.*://[^:]*:[^@]*@\([^/]*\)/.*|\1|p') + PASSWORD=$(echo "${BENCHMARK_TPCC_CONNSTR}" | sed -n 's|.*://[^:]*:\([^@]*\)@.*|\1|p') + + echo "Extracted hostname: ${HOSTNAME}" + + # Use runner temp (NVMe) as working directory + cd "${RUNNER_TEMP}" + + # Copy the generator script + cp "${GITHUB_WORKSPACE}/test_runner/performance/benchbase_tpc_c_helpers/generate_workload_size.py" . + + # Generate configs and scripts + python3 generate_workload_size.py \ + --warehouses ${WAREHOUSES} \ + --max-rate ${MAX_RATE} \ + --hostname ${HOSTNAME} \ + --password ${PASSWORD} \ + --runner-arch ${{ runner.arch }} + + # Fix path mismatch: move generated configs and scripts to expected locations + mv ../configs ./configs + mv ../scripts ./scripts + + - name: Prepare database (load data) + env: + WAREHOUSES: ${{ matrix.warehouses }} + run: | + cd "${RUNNER_TEMP}" + + echo "Loading ${WAREHOUSES} warehouses into database..." + + # Run the loader script and capture output to log file while preserving stdout/stderr + ./scripts/load_${WAREHOUSES}_warehouses.sh 2>&1 | tee "load_${WAREHOUSES}_warehouses.log" + + echo "Database loading completed" + + - name: Run TPC-C benchmark (warmup phase, then benchmark at 70% of configured max TPS) + env: + WAREHOUSES: ${{ matrix.warehouses }} + run: | + cd "${RUNNER_TEMP}" + + echo "Running TPC-C benchmark with ${WAREHOUSES} warehouses..." + + # Run the optimal rate benchmark + ./scripts/execute_${WAREHOUSES}_warehouses_opt_rate.sh + + echo "Benchmark execution completed" + + - name: Run TPC-C benchmark (warmup phase, then ramp down TPS and up again in 5 minute intervals) + + env: + WAREHOUSES: ${{ matrix.warehouses }} + run: | + cd "${RUNNER_TEMP}" + + echo "Running TPC-C ramp-down-up with ${WAREHOUSES} warehouses..." 
+ + # Run the optimal rate benchmark + ./scripts/execute_${WAREHOUSES}_warehouses_ramp_up.sh + + echo "Benchmark execution completed" + + - name: Process results (upload to test results database and generate diagrams) + env: + WAREHOUSES: ${{ matrix.warehouses }} + MIN_CU: ${{ matrix.min_cu }} + MAX_CU: ${{ matrix.max_cu }} + PROJECT_ID: ${{ steps.create-neon-project-tpcc.outputs.project_id }} + REVISION: ${{ github.sha }} + PERF_DB_CONNSTR: ${{ secrets.PERF_TEST_RESULT_CONNSTR }} + run: | + cd "${RUNNER_TEMP}" + + echo "Creating temporary Python environment for results processing..." + + # Create temporary virtual environment + python3 -m venv temp_results_env + source temp_results_env/bin/activate + + # Install required packages in virtual environment + pip install matplotlib pandas psycopg2-binary + + echo "Copying results processing scripts..." + + # Copy both processing scripts + cp "${GITHUB_WORKSPACE}/test_runner/performance/benchbase_tpc_c_helpers/generate_diagrams.py" . + cp "${GITHUB_WORKSPACE}/test_runner/performance/benchbase_tpc_c_helpers/upload_results_to_perf_test_results.py" . + + echo "Processing load phase metrics..." + + # Find and process load log + LOAD_LOG=$(find . -name "load_${WAREHOUSES}_warehouses.log" -type f | head -1) + if [ -n "$LOAD_LOG" ]; then + echo "Processing load metrics from: $LOAD_LOG" + python upload_results_to_perf_test_results.py \ + --load-log "$LOAD_LOG" \ + --run-type "load" \ + --warehouses "${WAREHOUSES}" \ + --min-cu "${MIN_CU}" \ + --max-cu "${MAX_CU}" \ + --project-id "${PROJECT_ID}" \ + --revision "${REVISION}" \ + --connection-string "${PERF_DB_CONNSTR}" + else + echo "Warning: Load log file not found: load_${WAREHOUSES}_warehouses.log" + fi + + echo "Processing warmup results for optimal rate..." + + # Find and process warmup results + WARMUP_CSV=$(find results_warmup -name "*.results.csv" -type f | head -1) + WARMUP_JSON=$(find results_warmup -name "*.summary.json" -type f | head -1) + + if [ -n "$WARMUP_CSV" ] && [ -n "$WARMUP_JSON" ]; then + echo "Generating warmup diagram from: $WARMUP_CSV" + python generate_diagrams.py \ + --input-csv "$WARMUP_CSV" \ + --output-svg "warmup_${WAREHOUSES}_warehouses_performance.svg" \ + --title-suffix "Warmup at max TPS" + + echo "Uploading warmup metrics from: $WARMUP_JSON" + python upload_results_to_perf_test_results.py \ + --summary-json "$WARMUP_JSON" \ + --results-csv "$WARMUP_CSV" \ + --run-type "warmup" \ + --min-cu "${MIN_CU}" \ + --max-cu "${MAX_CU}" \ + --project-id "${PROJECT_ID}" \ + --revision "${REVISION}" \ + --connection-string "${PERF_DB_CONNSTR}" + else + echo "Warning: Missing warmup results files (CSV: $WARMUP_CSV, JSON: $WARMUP_JSON)" + fi + + echo "Processing optimal rate results..." 
+ + # Find and process optimal rate results + OPTRATE_CSV=$(find results_opt_rate -name "*.results.csv" -type f | head -1) + OPTRATE_JSON=$(find results_opt_rate -name "*.summary.json" -type f | head -1) + + if [ -n "$OPTRATE_CSV" ] && [ -n "$OPTRATE_JSON" ]; then + echo "Generating optimal rate diagram from: $OPTRATE_CSV" + python generate_diagrams.py \ + --input-csv "$OPTRATE_CSV" \ + --output-svg "benchmark_${WAREHOUSES}_warehouses_performance.svg" \ + --title-suffix "70% of max TPS" + + echo "Uploading optimal rate metrics from: $OPTRATE_JSON" + python upload_results_to_perf_test_results.py \ + --summary-json "$OPTRATE_JSON" \ + --results-csv "$OPTRATE_CSV" \ + --run-type "opt-rate" \ + --min-cu "${MIN_CU}" \ + --max-cu "${MAX_CU}" \ + --project-id "${PROJECT_ID}" \ + --revision "${REVISION}" \ + --connection-string "${PERF_DB_CONNSTR}" + else + echo "Warning: Missing optimal rate results files (CSV: $OPTRATE_CSV, JSON: $OPTRATE_JSON)" + fi + + echo "Processing warmup 2 results for ramp down/up phase..." + + # Find and process warmup results + WARMUP_CSV=$(find results_warmup -name "*.results.csv" -type f | tail -1) + WARMUP_JSON=$(find results_warmup -name "*.summary.json" -type f | tail -1) + + if [ -n "$WARMUP_CSV" ] && [ -n "$WARMUP_JSON" ]; then + echo "Generating warmup diagram from: $WARMUP_CSV" + python generate_diagrams.py \ + --input-csv "$WARMUP_CSV" \ + --output-svg "warmup_2_${WAREHOUSES}_warehouses_performance.svg" \ + --title-suffix "Warmup at max TPS" + + echo "Uploading warmup metrics from: $WARMUP_JSON" + python upload_results_to_perf_test_results.py \ + --summary-json "$WARMUP_JSON" \ + --results-csv "$WARMUP_CSV" \ + --run-type "warmup" \ + --min-cu "${MIN_CU}" \ + --max-cu "${MAX_CU}" \ + --project-id "${PROJECT_ID}" \ + --revision "${REVISION}" \ + --connection-string "${PERF_DB_CONNSTR}" + else + echo "Warning: Missing warmup results files (CSV: $WARMUP_CSV, JSON: $WARMUP_JSON)" + fi + + echo "Processing ramp results..." 
+ + # Find and process ramp results + RAMPUP_CSV=$(find results_ramp_up -name "*.results.csv" -type f | head -1) + RAMPUP_JSON=$(find results_ramp_up -name "*.summary.json" -type f | head -1) + + if [ -n "$RAMPUP_CSV" ] && [ -n "$RAMPUP_JSON" ]; then + echo "Generating ramp diagram from: $RAMPUP_CSV" + python generate_diagrams.py \ + --input-csv "$RAMPUP_CSV" \ + --output-svg "ramp_${WAREHOUSES}_warehouses_performance.svg" \ + --title-suffix "ramp TPS down and up in 5 minute intervals" + + echo "Uploading ramp metrics from: $RAMPUP_JSON" + python upload_results_to_perf_test_results.py \ + --summary-json "$RAMPUP_JSON" \ + --results-csv "$RAMPUP_CSV" \ + --run-type "ramp-up" \ + --min-cu "${MIN_CU}" \ + --max-cu "${MAX_CU}" \ + --project-id "${PROJECT_ID}" \ + --revision "${REVISION}" \ + --connection-string "${PERF_DB_CONNSTR}" + else + echo "Warning: Missing ramp results files (CSV: $RAMPUP_CSV, JSON: $RAMPUP_JSON)" + fi + + # Deactivate and clean up virtual environment + deactivate + rm -rf temp_results_env + rm upload_results_to_perf_test_results.py + + echo "Results processing completed and environment cleaned up" + + - name: Set date for upload + id: set-date + run: echo "date=$(date +%Y-%m-%d)" >> $GITHUB_OUTPUT + + - name: Configure AWS credentials # necessary to upload results + uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2 + with: + aws-region: us-east-2 + role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} + role-duration-seconds: 900 # 900 is minimum value + + - name: Upload benchmark results to S3 + env: + S3_BUCKET: neon-public-benchmark-results + S3_PREFIX: benchbase-tpc-c/${{ steps.set-date.outputs.date }}/${{ github.run_id }}/${{ matrix.warehouses }}-warehouses + run: | + echo "Redacting passwords from configuration files before upload..." + + # Mask all passwords in XML config files + find "${RUNNER_TEMP}/configs" -name "*.xml" -type f -exec sed -i 's|<password>[^<]*</password>|<password>redacted</password>|g' {} \; + + echo "Uploading benchmark results to s3://${S3_BUCKET}/${S3_PREFIX}/" + + # Upload the entire benchmark directory recursively + aws s3 cp --only-show-errors --recursive "${RUNNER_TEMP}" s3://${S3_BUCKET}/${S3_PREFIX}/ + + echo "Upload completed" + + - name: Delete Neon Project + if: ${{ always() }} + uses: ./.github/actions/neon-project-delete + with: + project_id: ${{ steps.create-neon-project-tpcc.outputs.project_id }} + api_key: ${{ secrets.NEON_PRODUCTION_API_KEY_4_BENCHMARKS }} + api_host: console.neon.tech # production (!) \ No newline at end of file diff --git a/.github/workflows/benchmarking.yml b/.github/workflows/benchmarking.yml index df80bad579..c9a998bd4e 100644 --- a/.github/workflows/benchmarking.yml +++ b/.github/workflows/benchmarking.yml @@ -418,7 +418,7 @@ jobs: statuses: write id-token: write # aws-actions/configure-aws-credentials env: - PGBENCH_SIZE: ${{ vars.PREWARM_PGBENCH_SIZE }} + PROJECT_ID: ${{ vars.PREWARM_PROJECT_ID }} POSTGRES_DISTRIB_DIR: /tmp/neon/pg_install DEFAULT_PG_VERSION: 17 TEST_OUTPUT: /tmp/test_output diff --git a/.github/workflows/build-build-tools-image.yml b/.github/workflows/build-build-tools-image.yml index 24e4c8fa3d..5e53d8231f 100644 --- a/.github/workflows/build-build-tools-image.yml +++ b/.github/workflows/build-build-tools-image.yml @@ -146,7 +146,9 @@ jobs: with: file: build-tools/Dockerfile context: . 
- provenance: false + attests: | + type=provenance,mode=max + type=sbom,generator=docker.io/docker/buildkit-syft-scanner:1 push: true pull: true build-args: | diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index f237a991cc..0dcbd1c6dd 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -634,7 +634,9 @@ jobs: DEBIAN_VERSION=bookworm secrets: | SUBZERO_ACCESS_TOKEN=${{ secrets.CI_ACCESS_TOKEN }} - provenance: false + attests: | + type=provenance,mode=max + type=sbom,generator=docker.io/docker/buildkit-syft-scanner:1 push: true pull: true file: Dockerfile @@ -747,7 +749,9 @@ jobs: PG_VERSION=${{ matrix.version.pg }} BUILD_TAG=${{ needs.meta.outputs.release-tag || needs.meta.outputs.build-tag }} DEBIAN_VERSION=${{ matrix.version.debian }} - provenance: false + attests: | + type=provenance,mode=max + type=sbom,generator=docker.io/docker/buildkit-syft-scanner:1 push: true pull: true file: compute/compute-node.Dockerfile @@ -766,7 +770,9 @@ jobs: PG_VERSION=${{ matrix.version.pg }} BUILD_TAG=${{ needs.meta.outputs.release-tag || needs.meta.outputs.build-tag }} DEBIAN_VERSION=${{ matrix.version.debian }} - provenance: false + attests: | + type=provenance,mode=max + type=sbom,generator=docker.io/docker/buildkit-syft-scanner:1 push: true pull: true file: compute/compute-node.Dockerfile diff --git a/.github/workflows/pg-clients.yml b/.github/workflows/pg-clients.yml index 6efe0b4c8c..40b2c51624 100644 --- a/.github/workflows/pg-clients.yml +++ b/.github/workflows/pg-clients.yml @@ -48,8 +48,20 @@ jobs: uses: ./.github/workflows/build-build-tools-image.yml secrets: inherit + generate-ch-tmppw: + runs-on: ubuntu-22.04 + outputs: + tmp_val: ${{ steps.pwgen.outputs.tmp_val }} + steps: + - name: Generate a random password + id: pwgen + run: | + set +x + p=$(dd if=/dev/random bs=14 count=1 2>/dev/null | base64) + echo tmp_val="${p//\//}" >> "${GITHUB_OUTPUT}" + test-logical-replication: - needs: [ build-build-tools-image ] + needs: [ build-build-tools-image, generate-ch-tmppw ] runs-on: ubuntu-22.04 container: @@ -60,16 +72,21 @@ jobs: options: --init --user root services: clickhouse: - image: clickhouse/clickhouse-server:24.6.3.64 + image: clickhouse/clickhouse-server:25.6 + env: + CLICKHOUSE_PASSWORD: ${{ needs.generate-ch-tmppw.outputs.tmp_val }} + PGSSLCERT: /tmp/postgresql.crt ports: - 9000:9000 - 8123:8123 zookeeper: - image: quay.io/debezium/zookeeper:2.7 + image: quay.io/debezium/zookeeper:3.1.3.Final ports: - 2181:2181 + - 2888:2888 + - 3888:3888 kafka: - image: quay.io/debezium/kafka:2.7 + image: quay.io/debezium/kafka:3.1.3.Final env: ZOOKEEPER_CONNECT: "zookeeper:2181" KAFKA_ADVERTISED_LISTENERS: PLAINTEXT://kafka:9092 @@ -79,7 +96,7 @@ jobs: ports: - 9092:9092 debezium: - image: quay.io/debezium/connect:2.7 + image: quay.io/debezium/connect:3.1.3.Final env: BOOTSTRAP_SERVERS: kafka:9092 GROUP_ID: 1 @@ -125,6 +142,7 @@ jobs: aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} env: BENCHMARK_CONNSTR: ${{ steps.create-neon-project.outputs.dsn }} + CLICKHOUSE_PASSWORD: ${{ needs.generate-ch-tmppw.outputs.tmp_val }} - name: Delete Neon Project if: always() diff --git a/.github/workflows/proxy-benchmark.yml b/.github/workflows/proxy-benchmark.yml index 0ae93ce295..e48fe41b45 100644 --- a/.github/workflows/proxy-benchmark.yml +++ b/.github/workflows/proxy-benchmark.yml @@ -3,7 +3,7 @@ name: Periodic proxy performance test on unit-perf-aws-arm runners on: push: # TODO: remove after testing branches: - - 
test-proxy-bench # Runs on pushes to branches starting with test-proxy-bench + - test-proxy-bench # Runs on pushes to test-proxy-bench branch # schedule: # * is a special character in YAML so you have to quote this string # ┌───────────── minute (0 - 59) @@ -32,7 +32,7 @@ jobs: statuses: write contents: write pull-requests: write - runs-on: [self-hosted, unit-perf-aws-arm] + runs-on: [ self-hosted, unit-perf-aws-arm ] timeout-minutes: 60 # 1h timeout container: image: ghcr.io/neondatabase/build-tools:pinned-bookworm @@ -55,30 +55,58 @@ jobs: { echo "PROXY_BENCH_PATH=$PROXY_BENCH_PATH" echo "NEON_DIR=${RUNNER_TEMP}/neon" + echo "NEON_PROXY_PATH=${RUNNER_TEMP}/neon/bin/proxy" echo "TEST_OUTPUT=${PROXY_BENCH_PATH}/test_output" echo "" } >> "$GITHUB_ENV" - - name: Run proxy-bench - run: ${PROXY_BENCH_PATH}/run.sh + - name: Cache poetry deps + uses: actions/cache@v4 + with: + path: ~/.cache/pypoetry/virtualenvs + key: v2-${{ runner.os }}-${{ runner.arch }}-python-deps-bookworm-${{ hashFiles('poetry.lock') }} - - name: Ingest Bench Results # neon repo script + - name: Install Python deps + shell: bash -euxo pipefail {0} + run: ./scripts/pysync + + - name: show ulimits + shell: bash -euxo pipefail {0} + run: | + ulimit -a + + - name: Run proxy-bench + working-directory: ${{ env.PROXY_BENCH_PATH }} + run: ./run.sh --with-grafana --bare-metal + + - name: Ingest Bench Results if: always() + working-directory: ${{ env.NEON_DIR }} run: | mkdir -p $TEST_OUTPUT python $NEON_DIR/scripts/proxy_bench_results_ingest.py --out $TEST_OUTPUT - name: Push Metrics to Proxy perf database + shell: bash -euxo pipefail {0} if: always() env: PERF_TEST_RESULT_CONNSTR: "${{ secrets.PROXY_TEST_RESULT_CONNSTR }}" REPORT_FROM: $TEST_OUTPUT + working-directory: ${{ env.NEON_DIR }} run: $NEON_DIR/scripts/generate_and_push_perf_report.sh - - name: Docker cleanup - if: always() - run: docker compose down - - name: Notify Failure if: failure() - run: echo "Proxy bench job failed" && exit 1 \ No newline at end of file + run: echo "Proxy bench job failed" && exit 1 + + - name: Cleanup Test Resources + if: always() + shell: bash -euxo pipefail {0} + run: | + # Cleanup the test resources + if [[ -d "${TEST_OUTPUT}" ]]; then + rm -rf ${TEST_OUTPUT} + fi + if [[ -d "${PROXY_BENCH_PATH}/test_output" ]]; then + rm -rf ${PROXY_BENCH_PATH}/test_output + fi \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 669fb0eac7..f51be760d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -211,11 +211,11 @@ dependencies = [ [[package]] name = "async-lock" -version = "3.2.0" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7125e42787d53db9dd54261812ef17e937c95a51e4d291373b670342fa44310c" +checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18" dependencies = [ - "event-listener 4.0.0", + "event-listener 5.4.0", "event-listener-strategy", "pin-project-lite", ] @@ -1317,7 +1317,6 @@ dependencies = [ "itertools 0.10.5", "libc", "measured", - "metrics", "neon-shmem", "nix 0.30.1", "pageserver_api", @@ -1418,6 +1417,7 @@ dependencies = [ "tower-http", "tower-otel", "tracing", + "tracing-appender", "tracing-opentelemetry", "tracing-subscriber", "tracing-utils", @@ -1433,9 +1433,9 @@ dependencies = [ [[package]] name = "concurrent-queue" -version = "2.3.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f057a694a54f12365049b0958a1685bb52d567f5593b355fbf685838e873d400" +checksum = 
"4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" dependencies = [ "crossbeam-utils", ] @@ -2261,9 +2261,9 @@ checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" [[package]] name = "event-listener" -version = "4.0.0" +version = "5.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "770d968249b5d99410d61f5bf89057f3199a077a04d087092f58e7d10692baae" +checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae" dependencies = [ "concurrent-queue", "parking", @@ -2272,11 +2272,11 @@ dependencies = [ [[package]] name = "event-listener-strategy" -version = "0.4.0" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3" +checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" dependencies = [ - "event-listener 4.0.0", + "event-listener 5.4.0", "pin-project-lite", ] @@ -2554,6 +2554,20 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "304de19db7028420975a296ab0fcbbc8e69438c4ed254a1e41e2a7f37d5f0e0a" +[[package]] +name = "generator" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d18470a76cb7f8ff746cf1f7470914f900252ec36bbc40b569d74b1258446827" +dependencies = [ + "cc", + "cfg-if", + "libc", + "log", + "rustversion", + "windows 0.61.3", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -2872,7 +2886,7 @@ checksum = "f9c7c7c8ac16c798734b8a24560c1362120597c40d5e1459f09498f8f6c8f2ba" dependencies = [ "cfg-if", "libc", - "windows", + "windows 0.52.0", ] [[package]] @@ -3143,7 +3157,7 @@ dependencies = [ "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows-core", + "windows-core 0.52.0", ] [[package]] @@ -3694,6 +3708,19 @@ version = "0.4.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" +[[package]] +name = "loom" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" +dependencies = [ + "cfg-if", + "generator", + "scoped-tls", + "tracing", + "tracing-subscriber", +] + [[package]] name = "lru" version = "0.12.3" @@ -3910,6 +3937,25 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "moka" +version = "0.12.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9321642ca94a4282428e6ea4af8cc2ca4eac48ac7a6a4ea8f33f76d0ce70926" +dependencies = [ + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "loom", + "parking_lot 0.12.1", + "portable-atomic", + "rustc_version", + "smallvec", + "tagptr", + "thiserror 1.0.69", + "uuid", +] + [[package]] name = "multimap" version = "0.8.3" @@ -5083,8 +5129,6 @@ dependencies = [ "crc32c", "criterion", "env_logger", - "log", - "memoffset 0.9.0", "once_cell", "postgres", "postgres_ffi_types", @@ -5437,7 +5481,6 @@ dependencies = [ "futures", "gettid", "hashbrown 0.14.5", - "hashlink", "hex", "hmac", "hostname", @@ -5459,6 +5502,7 @@ dependencies = [ "lasso", "measured", "metrics", + "moka", "once_cell", "opentelemetry", "ouroboros", @@ -5525,6 +5569,7 @@ dependencies = [ "workspace_hack", "x509-cert", "zerocopy 0.8.24", + "zeroize", ] [[package]] @@ -6472,6 +6517,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.1.0" @@ -7321,6 +7372,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "tar" version = "0.4.40" @@ -7997,11 +8054,12 @@ dependencies = [ [[package]] name = "tracing-appender" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09d48f71a791638519505cefafe162606f706c25592e4bde4d97600c0195312e" +checksum = "3566e8ce28cc0a3fe42519fc80e6b4c943cc4c8cef275620eb8dac2d3d4e06cf" dependencies = [ "crossbeam-channel", + "thiserror 1.0.69", "time", "tracing-subscriber", ] @@ -8699,10 +8757,32 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" dependencies = [ - "windows-core", + "windows-core 0.52.0", "windows-targets 0.52.6", ] +[[package]] +name = "windows" +version = "0.61.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893" +dependencies = [ + "windows-collections", + "windows-core 0.61.2", + "windows-future", + "windows-link", + "windows-numerics", +] + +[[package]] +name = "windows-collections" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" +dependencies = [ + "windows-core 0.61.2", +] + [[package]] name = "windows-core" version = "0.52.0" @@ -8712,6 +8792,86 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-core" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-future" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" +dependencies = [ + "windows-core 0.61.2", + "windows-link", + "windows-threading", +] + +[[package]] +name = "windows-implement" +version = "0.60.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "windows-interface" +version = "0.59.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "windows-link" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" + +[[package]] +name = "windows-numerics" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" +dependencies = [ + "windows-core 0.61.2", + "windows-link", +] + +[[package]] +name = 
"windows-result" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -8770,6 +8930,15 @@ dependencies = [ "windows_x86_64_msvc 0.52.6", ] +[[package]] +name = "windows-threading" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" +dependencies = [ + "windows-link", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.0" @@ -8907,6 +9076,7 @@ dependencies = [ "clap_builder", "const-oid", "criterion", + "crossbeam-epoch", "crypto-bigint 0.5.5", "der 0.7.8", "deranged", @@ -8951,6 +9121,7 @@ dependencies = [ "num-traits", "p256 0.13.2", "parquet", + "portable-atomic", "prettyplease", "proc-macro2", "prost 0.13.5", diff --git a/Cargo.toml b/Cargo.toml index 1402c0b38d..4f7481dbd5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,10 +46,10 @@ members = [ "libs/proxy/json", "libs/proxy/postgres-protocol2", "libs/proxy/postgres-types2", + "libs/proxy/subzero_core", "libs/proxy/tokio-postgres2", "endpoint_storage", "pgxn/neon/communicator", - "proxy/subzero_core", ] [workspace.package] @@ -136,7 +136,7 @@ lock_api = "0.4.13" md5 = "0.7.0" measured = { version = "0.0.22", features=["lasso"] } measured-process = { version = "0.0.22" } -memoffset = "0.9" +moka = { version = "0.12", features = ["sync"] } nix = { version = "0.30.1", features = ["dir", "fs", "mman", "process", "socket", "signal", "poll"] } # Do not update to >= 7.0.0, at least. The update will have a significant impact # on compute startup metrics (start_postgres_ms), >= 25% degradation. 
@@ -224,6 +224,7 @@ tracing-log = "0.2" tracing-opentelemetry = "0.31" tracing-serde = "0.2.0" tracing-subscriber = { version = "0.3", default-features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] } +tracing-appender = "0.2.3" try-lock = "0.2.5" test-log = { version = "0.2.17", default-features = false, features = ["log"] } twox-hash = { version = "1.6.3", default-features = false } @@ -234,9 +235,10 @@ uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] } walkdir = "2.3.2" rustls-native-certs = "0.8" whoami = "1.5.1" -zerocopy = { version = "0.8", features = ["derive", "simd"] } json-structural-diff = { version = "0.2.0" } x509-cert = { version = "0.2.5" } +zerocopy = { version = "0.8", features = ["derive", "simd"] } +zeroize = "1.8" ## TODO replace this with tracing env_logger = "0.11" diff --git a/Dockerfile b/Dockerfile index 654ae72e56..63cc954873 100644 --- a/Dockerfile +++ b/Dockerfile @@ -103,7 +103,7 @@ RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \ && if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \ export CARGO_FEATURES="rest_broker"; \ fi \ - && RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo build \ + && RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo auditable build \ --features $CARGO_FEATURES \ --bin pg_sni_router \ --bin pageserver \ diff --git a/build-tools/Dockerfile b/build-tools/Dockerfile index b5fe642e6f..c9760f610b 100644 --- a/build-tools/Dockerfile +++ b/build-tools/Dockerfile @@ -39,13 +39,13 @@ COPY build-tools/patches/pgcopydbv017.patch /pgcopydbv017.patch RUN if [ "${DEBIAN_VERSION}" = "bookworm" ]; then \ set -e && \ - apt update && \ - apt install -y --no-install-recommends \ + apt-get update && \ + apt-get install -y --no-install-recommends \ ca-certificates wget gpg && \ wget -qO - https://www.postgresql.org/media/keys/ACCC4CF8.asc | gpg --dearmor -o /usr/share/keyrings/postgresql-keyring.gpg && \ echo "deb [signed-by=/usr/share/keyrings/postgresql-keyring.gpg] http://apt.postgresql.org/pub/repos/apt bookworm-pgdg main" > /etc/apt/sources.list.d/pgdg.list && \ apt-get update && \ - apt install -y --no-install-recommends \ + apt-get install -y --no-install-recommends \ build-essential \ autotools-dev \ libedit-dev \ @@ -89,8 +89,7 @@ RUN useradd -ms /bin/bash nonroot -b /home # Use strict mode for bash to catch errors early SHELL ["/bin/bash", "-euo", "pipefail", "-c"] -RUN mkdir -p /pgcopydb/bin && \ - mkdir -p /pgcopydb/lib && \ +RUN mkdir -p /pgcopydb/{bin,lib} && \ chmod -R 755 /pgcopydb && \ chown -R nonroot:nonroot /pgcopydb @@ -106,8 +105,8 @@ RUN echo 'Acquire::Retries "5";' > /etc/apt/apt.conf.d/80-retries && \ # 'gdb' is included so that we get backtraces of core dumps produced in # regression tests RUN set -e \ - && apt update \ - && apt install -y \ + && apt-get update \ + && apt-get install -y --no-install-recommends \ autoconf \ automake \ bison \ @@ -183,22 +182,22 @@ RUN curl -sL "https://github.com/peak/s5cmd/releases/download/v${S5CMD_VERSION}/ ENV LLVM_VERSION=20 RUN curl -fsSL 'https://apt.llvm.org/llvm-snapshot.gpg.key' | apt-key add - \ && echo "deb http://apt.llvm.org/${DEBIAN_VERSION}/ llvm-toolchain-${DEBIAN_VERSION}-${LLVM_VERSION} main" > /etc/apt/sources.list.d/llvm.stable.list \ - && apt update \ - && apt install -y clang-${LLVM_VERSION} llvm-${LLVM_VERSION} \ + && apt-get update \ + && 
apt-get install -y --no-install-recommends clang-${LLVM_VERSION} llvm-${LLVM_VERSION} \ && bash -c 'for f in /usr/bin/clang*-${LLVM_VERSION} /usr/bin/llvm*-${LLVM_VERSION}; do ln -s "${f}" "${f%-${LLVM_VERSION}}"; done' \ && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* # Install node ENV NODE_VERSION=24 RUN curl -fsSL https://deb.nodesource.com/setup_${NODE_VERSION}.x | bash - \ - && apt install -y nodejs \ + && apt-get install -y --no-install-recommends nodejs \ && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* # Install docker RUN curl -fsSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /usr/share/keyrings/docker-archive-keyring.gpg \ && echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/debian ${DEBIAN_VERSION} stable" > /etc/apt/sources.list.d/docker.list \ - && apt update \ - && apt install -y docker-ce docker-ce-cli \ + && apt-get update \ + && apt-get install -y --no-install-recommends docker-ce docker-ce-cli \ && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* # Configure sudo & docker @@ -215,12 +214,11 @@ RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-$(uname -m).zip" -o "aws # Mold: A Modern Linker ENV MOLD_VERSION=v2.37.1 RUN set -e \ - && git clone https://github.com/rui314/mold.git \ + && git clone -b "${MOLD_VERSION}" --depth 1 https://github.com/rui314/mold.git \ && mkdir mold/build \ - && cd mold/build \ - && git checkout ${MOLD_VERSION} \ + && cd mold/build \ && cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_COMPILER=clang++ .. \ - && cmake --build . -j $(nproc) \ + && cmake --build . -j "$(nproc)" \ && cmake --install . \ && cd .. \ && rm -rf mold @@ -254,7 +252,7 @@ ENV ICU_VERSION=67.1 ENV ICU_PREFIX=/usr/local/icu # Download and build static ICU -RUN wget -O /tmp/libicu-${ICU_VERSION}.tgz https://github.com/unicode-org/icu/releases/download/release-${ICU_VERSION//./-}/icu4c-${ICU_VERSION//./_}-src.tgz && \ +RUN wget -O "/tmp/libicu-${ICU_VERSION}.tgz" https://github.com/unicode-org/icu/releases/download/release-${ICU_VERSION//./-}/icu4c-${ICU_VERSION//./_}-src.tgz && \ echo "94a80cd6f251a53bd2a997f6f1b5ac6653fe791dfab66e1eb0227740fb86d5dc /tmp/libicu-${ICU_VERSION}.tgz" | sha256sum --check && \ mkdir /tmp/icu && \ pushd /tmp/icu && \ @@ -265,8 +263,7 @@ RUN wget -O /tmp/libicu-${ICU_VERSION}.tgz https://github.com/unicode-org/icu/re make install && \ popd && \ rm -rf icu && \ - rm -f /tmp/libicu-${ICU_VERSION}.tgz && \ - popd + rm -f /tmp/libicu-${ICU_VERSION}.tgz # Switch to nonroot user USER nonroot:nonroot @@ -279,19 +276,19 @@ ENV PYTHON_VERSION=3.11.12 \ PYENV_ROOT=/home/nonroot/.pyenv \ PATH=/home/nonroot/.pyenv/shims:/home/nonroot/.pyenv/bin:/home/nonroot/.poetry/bin:$PATH RUN set -e \ - && cd $HOME \ + && cd "$HOME" \ && curl -sSO https://raw.githubusercontent.com/pyenv/pyenv-installer/master/bin/pyenv-installer \ && chmod +x pyenv-installer \ && ./pyenv-installer \ && export PYENV_ROOT=/home/nonroot/.pyenv \ && export PATH="$PYENV_ROOT/bin:$PATH" \ && export PATH="$PYENV_ROOT/shims:$PATH" \ - && pyenv install ${PYTHON_VERSION} \ - && pyenv global ${PYTHON_VERSION} \ + && pyenv install "${PYTHON_VERSION}" \ + && pyenv global "${PYTHON_VERSION}" \ && python --version \ - && pip install --upgrade pip \ + && pip install --no-cache-dir --upgrade pip \ && pip --version \ - && pip install pipenv wheel poetry + && pip install --no-cache-dir pipenv wheel poetry # Switch to nonroot user (again) USER nonroot:nonroot @@ -302,6 +299,7 @@ WORKDIR /home/nonroot ENV 
RUSTC_VERSION=1.88.0 ENV RUSTUP_HOME="/home/nonroot/.rustup" ENV PATH="/home/nonroot/.cargo/bin:${PATH}" +ARG CARGO_AUDITABLE_VERSION=0.7.0 ARG RUSTFILT_VERSION=0.2.1 ARG CARGO_HAKARI_VERSION=0.9.36 ARG CARGO_DENY_VERSION=0.18.2 @@ -317,14 +315,16 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux . "$HOME/.cargo/env" && \ cargo --version && rustup --version && \ rustup component add llvm-tools rustfmt clippy && \ - cargo install rustfilt --locked --version ${RUSTFILT_VERSION} && \ - cargo install cargo-hakari --locked --version ${CARGO_HAKARI_VERSION} && \ - cargo install cargo-deny --locked --version ${CARGO_DENY_VERSION} && \ - cargo install cargo-hack --locked --version ${CARGO_HACK_VERSION} && \ - cargo install cargo-nextest --locked --version ${CARGO_NEXTEST_VERSION} && \ - cargo install cargo-chef --locked --version ${CARGO_CHEF_VERSION} && \ - cargo install diesel_cli --locked --version ${CARGO_DIESEL_CLI_VERSION} \ - --features postgres-bundled --no-default-features && \ + cargo install cargo-auditable --locked --version "${CARGO_AUDITABLE_VERSION}" && \ + cargo auditable install cargo-auditable --locked --version "${CARGO_AUDITABLE_VERSION}" --force && \ + cargo auditable install rustfilt --version "${RUSTFILT_VERSION}" && \ + cargo auditable install cargo-hakari --locked --version "${CARGO_HAKARI_VERSION}" && \ + cargo auditable install cargo-deny --locked --version "${CARGO_DENY_VERSION}" && \ + cargo auditable install cargo-hack --locked --version "${CARGO_HACK_VERSION}" && \ + cargo auditable install cargo-nextest --locked --version "${CARGO_NEXTEST_VERSION}" && \ + cargo auditable install cargo-chef --locked --version "${CARGO_CHEF_VERSION}" && \ + cargo auditable install diesel_cli --locked --version "${CARGO_DIESEL_CLI_VERSION}" \ + --features postgres-bundled --no-default-features && \ rm -rf /home/nonroot/.cargo/registry && \ rm -rf /home/nonroot/.cargo/git diff --git a/build-tools/package-lock.json b/build-tools/package-lock.json index b2c44ed9b4..0d48345fd5 100644 --- a/build-tools/package-lock.json +++ b/build-tools/package-lock.json @@ -6,7 +6,7 @@ "": { "name": "build-tools", "devDependencies": { - "@redocly/cli": "1.34.4", + "@redocly/cli": "1.34.5", "@sourcemeta/jsonschema": "10.0.0" } }, @@ -472,9 +472,9 @@ } }, "node_modules/@redocly/cli": { - "version": "1.34.4", - "resolved": "https://registry.npmjs.org/@redocly/cli/-/cli-1.34.4.tgz", - "integrity": "sha512-seH/GgrjSB1EeOsgJ/4Ct6Jk2N7sh12POn/7G8UQFARMyUMJpe1oHtBwT2ndfp4EFCpgBAbZ/82Iw6dwczNxEA==", + "version": "1.34.5", + "resolved": "https://registry.npmjs.org/@redocly/cli/-/cli-1.34.5.tgz", + "integrity": "sha512-5IEwxs7SGP5KEXjBKLU8Ffdz9by/KqNSeBk6YUVQaGxMXK//uYlTJIPntgUXbo1KAGG2d2q2XF8y4iFz6qNeiw==", "dev": true, "license": "MIT", "dependencies": { @@ -484,14 +484,14 @@ "@opentelemetry/sdk-trace-node": "1.26.0", "@opentelemetry/semantic-conventions": "1.27.0", "@redocly/config": "^0.22.0", - "@redocly/openapi-core": "1.34.4", - "@redocly/respect-core": "1.34.4", + "@redocly/openapi-core": "1.34.5", + "@redocly/respect-core": "1.34.5", "abort-controller": "^3.0.0", "chokidar": "^3.5.1", "colorette": "^1.2.0", "core-js": "^3.32.1", "dotenv": "16.4.7", - "form-data": "^4.0.0", + "form-data": "^4.0.4", "get-port-please": "^3.0.1", "glob": "^7.1.6", "handlebars": "^4.7.6", @@ -522,9 +522,9 @@ "license": "MIT" }, "node_modules/@redocly/openapi-core": { - "version": "1.34.4", - "resolved": "https://registry.npmjs.org/@redocly/openapi-core/-/openapi-core-1.34.4.tgz", - "integrity": 
"sha512-hf53xEgpXIgWl3b275PgZU3OTpYh1RoD2LHdIfQ1JzBNTWsiNKczTEsI/4Tmh2N1oq9YcphhSMyk3lDh85oDjg==", + "version": "1.34.5", + "resolved": "https://registry.npmjs.org/@redocly/openapi-core/-/openapi-core-1.34.5.tgz", + "integrity": "sha512-0EbE8LRbkogtcCXU7liAyC00n9uNG9hJ+eMyHFdUsy9lB/WGqnEBgwjA9q2cyzAVcdTkQqTBBU1XePNnN3OijA==", "dev": true, "license": "MIT", "dependencies": { @@ -544,21 +544,21 @@ } }, "node_modules/@redocly/respect-core": { - "version": "1.34.4", - "resolved": "https://registry.npmjs.org/@redocly/respect-core/-/respect-core-1.34.4.tgz", - "integrity": "sha512-MitKyKyQpsizA4qCVv+MjXL4WltfhFQAoiKiAzrVR1Kusro3VhYb6yJuzoXjiJhR0ukLP5QOP19Vcs7qmj9dZg==", + "version": "1.34.5", + "resolved": "https://registry.npmjs.org/@redocly/respect-core/-/respect-core-1.34.5.tgz", + "integrity": "sha512-GheC/g/QFztPe9UA9LamooSplQuy9pe0Yr8XGTqkz0ahivLDl7svoy/LSQNn1QH3XGtLKwFYMfTwFR2TAYyh5Q==", "dev": true, "license": "MIT", "dependencies": { "@faker-js/faker": "^7.6.0", "@redocly/ajv": "8.11.2", - "@redocly/openapi-core": "1.34.4", + "@redocly/openapi-core": "1.34.5", "better-ajv-errors": "^1.2.0", "colorette": "^2.0.20", "concat-stream": "^2.0.0", "cookie": "^0.7.2", "dotenv": "16.4.7", - "form-data": "4.0.0", + "form-data": "^4.0.4", "jest-diff": "^29.3.1", "jest-matcher-utils": "^29.3.1", "js-yaml": "4.1.0", @@ -582,21 +582,6 @@ "dev": true, "license": "MIT" }, - "node_modules/@redocly/respect-core/node_modules/form-data": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.0.tgz", - "integrity": "sha512-ETEklSGi5t0QMZuiXoA/Q6vcnxcLQP5vdugSpuAyi6SVGi2clPPp+xgEhuMaHC+zGgn31Kd235W35f7Hykkaww==", - "dev": true, - "license": "MIT", - "dependencies": { - "asynckit": "^0.4.0", - "combined-stream": "^1.0.8", - "mime-types": "^2.1.12" - }, - "engines": { - "node": ">= 6" - } - }, "node_modules/@sinclair/typebox": { "version": "0.27.8", "resolved": "https://registry.npmjs.org/@sinclair/typebox/-/typebox-0.27.8.tgz", @@ -1345,9 +1330,9 @@ "license": "MIT" }, "node_modules/form-data": { - "version": "4.0.3", - "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.3.tgz", - "integrity": "sha512-qsITQPfmvMOSAdeyZ+12I1c+CKSstAFAwu+97zrnWAbIr5u8wfsExUzCesVLC8NgHuRUqNN4Zy6UPWUTRGslcA==", + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.4.tgz", + "integrity": "sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==", "dev": true, "license": "MIT", "dependencies": { diff --git a/build-tools/package.json b/build-tools/package.json index 000969c672..2dc1359075 100644 --- a/build-tools/package.json +++ b/build-tools/package.json @@ -2,7 +2,7 @@ "name": "build-tools", "private": true, "devDependencies": { - "@redocly/cli": "1.34.4", + "@redocly/cli": "1.34.5", "@sourcemeta/jsonschema": "10.0.0" } } diff --git a/compute/compute-node.Dockerfile b/compute/compute-node.Dockerfile index a658738d76..6eecb89291 100644 --- a/compute/compute-node.Dockerfile +++ b/compute/compute-node.Dockerfile @@ -133,7 +133,7 @@ RUN case $DEBIAN_VERSION in \ # Install newer version (3.25) from backports. 
# libstdc++-10-dev is required for plv8 bullseye) \ - echo "deb http://deb.debian.org/debian bullseye-backports main" > /etc/apt/sources.list.d/bullseye-backports.list; \ + echo "deb http://archive.debian.org/debian bullseye-backports main" > /etc/apt/sources.list.d/bullseye-backports.list; \ VERSION_INSTALLS="cmake/bullseye-backports cmake-data/bullseye-backports libstdc++-10-dev"; \ ;; \ # Version-specific installs for Bookworm (PG17): diff --git a/compute/etc/sql_exporter/checkpoints_req.17.sql b/compute/etc/sql_exporter/checkpoints_req.17.sql index a4b946e8e2..28c868ff72 100644 --- a/compute/etc/sql_exporter/checkpoints_req.17.sql +++ b/compute/etc/sql_exporter/checkpoints_req.17.sql @@ -1 +1 @@ -SELECT num_requested AS checkpoints_req FROM pg_stat_checkpointer; +SELECT num_requested AS checkpoints_req FROM pg_catalog.pg_stat_checkpointer; diff --git a/compute/etc/sql_exporter/checkpoints_req.sql b/compute/etc/sql_exporter/checkpoints_req.sql index eb8427c883..421448c0de 100644 --- a/compute/etc/sql_exporter/checkpoints_req.sql +++ b/compute/etc/sql_exporter/checkpoints_req.sql @@ -1 +1 @@ -SELECT checkpoints_req FROM pg_stat_bgwriter; +SELECT checkpoints_req FROM pg_catalog.pg_stat_bgwriter; diff --git a/compute/etc/sql_exporter/checkpoints_timed.sql b/compute/etc/sql_exporter/checkpoints_timed.sql index c50853134c..bfa9b1b3d6 100644 --- a/compute/etc/sql_exporter/checkpoints_timed.sql +++ b/compute/etc/sql_exporter/checkpoints_timed.sql @@ -1 +1 @@ -SELECT checkpoints_timed FROM pg_stat_bgwriter; +SELECT checkpoints_timed FROM pg_catalog.pg_stat_bgwriter; diff --git a/compute/etc/sql_exporter/compute_backpressure_throttling_seconds_total.sql b/compute/etc/sql_exporter/compute_backpressure_throttling_seconds_total.sql index d97d625d4c..3fe638e489 100644 --- a/compute/etc/sql_exporter/compute_backpressure_throttling_seconds_total.sql +++ b/compute/etc/sql_exporter/compute_backpressure_throttling_seconds_total.sql @@ -1 +1 @@ -SELECT (neon.backpressure_throttling_time()::float8 / 1000000) AS throttled; +SELECT (neon.backpressure_throttling_time()::pg_catalog.float8 / 1000000) AS throttled; diff --git a/compute/etc/sql_exporter/compute_current_lsn.sql b/compute/etc/sql_exporter/compute_current_lsn.sql index be02b8a094..9a042547f0 100644 --- a/compute/etc/sql_exporter/compute_current_lsn.sql +++ b/compute/etc/sql_exporter/compute_current_lsn.sql @@ -1,4 +1,4 @@ SELECT CASE - WHEN pg_catalog.pg_is_in_recovery() THEN (pg_last_wal_replay_lsn() - '0/0')::FLOAT8 - ELSE (pg_current_wal_lsn() - '0/0')::FLOAT8 + WHEN pg_catalog.pg_is_in_recovery() THEN (pg_catalog.pg_last_wal_replay_lsn() - '0/0')::pg_catalog.FLOAT8 + ELSE (pg_catalog.pg_current_wal_lsn() - '0/0')::pg_catalog.FLOAT8 END AS lsn; diff --git a/compute/etc/sql_exporter/compute_logical_snapshot_files.sql b/compute/etc/sql_exporter/compute_logical_snapshot_files.sql index f2454235b7..2224c02d8d 100644 --- a/compute/etc/sql_exporter/compute_logical_snapshot_files.sql +++ b/compute/etc/sql_exporter/compute_logical_snapshot_files.sql @@ -1,7 +1,7 @@ SELECT - (SELECT setting FROM pg_settings WHERE name = 'neon.timeline_id') AS timeline_id, + (SELECT setting FROM pg_catalog.pg_settings WHERE name = 'neon.timeline_id') AS timeline_id, -- Postgres creates temporary snapshot files of the form %X-%X.snap.%d.tmp. -- These temporary snapshot files are renamed to the actual snapshot files -- after they are completely built. 
We only WAL-log the completely built -- snapshot files - (SELECT COUNT(*) FROM pg_ls_dir('pg_logical/snapshots') AS name WHERE name LIKE '%.snap') AS num_logical_snapshot_files; + (SELECT COUNT(*) FROM pg_catalog.pg_ls_dir('pg_logical/snapshots') AS name WHERE name LIKE '%.snap') AS num_logical_snapshot_files; diff --git a/compute/etc/sql_exporter/compute_logical_snapshots_bytes.15.sql b/compute/etc/sql_exporter/compute_logical_snapshots_bytes.15.sql index 73a9c11405..17cf1228d3 100644 --- a/compute/etc/sql_exporter/compute_logical_snapshots_bytes.15.sql +++ b/compute/etc/sql_exporter/compute_logical_snapshots_bytes.15.sql @@ -1,7 +1,7 @@ SELECT - (SELECT current_setting('neon.timeline_id')) AS timeline_id, + (SELECT pg_catalog.current_setting('neon.timeline_id')) AS timeline_id, -- Postgres creates temporary snapshot files of the form %X-%X.snap.%d.tmp. -- These temporary snapshot files are renamed to the actual snapshot files -- after they are completely built. We only WAL-log the completely built -- snapshot files - (SELECT COALESCE(sum(size), 0) FROM pg_ls_logicalsnapdir() WHERE name LIKE '%.snap') AS logical_snapshots_bytes; + (SELECT COALESCE(pg_catalog.sum(size), 0) FROM pg_catalog.pg_ls_logicalsnapdir() WHERE name LIKE '%.snap') AS logical_snapshots_bytes; diff --git a/compute/etc/sql_exporter/compute_logical_snapshots_bytes.sql b/compute/etc/sql_exporter/compute_logical_snapshots_bytes.sql index 16da899de2..33ca1137dc 100644 --- a/compute/etc/sql_exporter/compute_logical_snapshots_bytes.sql +++ b/compute/etc/sql_exporter/compute_logical_snapshots_bytes.sql @@ -1,9 +1,9 @@ SELECT - (SELECT setting FROM pg_settings WHERE name = 'neon.timeline_id') AS timeline_id, + (SELECT setting FROM pg_catalog.pg_settings WHERE name = 'neon.timeline_id') AS timeline_id, -- Postgres creates temporary snapshot files of the form %X-%X.snap.%d.tmp. -- These temporary snapshot files are renamed to the actual snapshot files -- after they are completely built. 
We only WAL-log the completely built -- snapshot files - (SELECT COALESCE(sum((pg_stat_file('pg_logical/snapshots/' || name, missing_ok => true)).size), 0) - FROM (SELECT * FROM pg_ls_dir('pg_logical/snapshots') WHERE pg_ls_dir LIKE '%.snap') AS name + (SELECT COALESCE(pg_catalog.sum((pg_catalog.pg_stat_file('pg_logical/snapshots/' || name, missing_ok => true)).size), 0) + FROM (SELECT * FROM pg_catalog.pg_ls_dir('pg_logical/snapshots') WHERE pg_ls_dir LIKE '%.snap') AS name ) AS logical_snapshots_bytes; diff --git a/compute/etc/sql_exporter/compute_max_connections.sql b/compute/etc/sql_exporter/compute_max_connections.sql index 99a49483a6..1613c962a2 100644 --- a/compute/etc/sql_exporter/compute_max_connections.sql +++ b/compute/etc/sql_exporter/compute_max_connections.sql @@ -1 +1 @@ -SELECT current_setting('max_connections') as max_connections; +SELECT pg_catalog.current_setting('max_connections') AS max_connections; diff --git a/compute/etc/sql_exporter/compute_pg_oldest_frozen_xid_age.sql b/compute/etc/sql_exporter/compute_pg_oldest_frozen_xid_age.sql index d2281fdd42..e613939e71 100644 --- a/compute/etc/sql_exporter/compute_pg_oldest_frozen_xid_age.sql +++ b/compute/etc/sql_exporter/compute_pg_oldest_frozen_xid_age.sql @@ -1,4 +1,4 @@ SELECT datname database_name, - age(datfrozenxid) frozen_xid_age -FROM pg_database + pg_catalog.age(datfrozenxid) frozen_xid_age +FROM pg_catalog.pg_database ORDER BY frozen_xid_age DESC LIMIT 10; diff --git a/compute/etc/sql_exporter/compute_pg_oldest_mxid_age.sql b/compute/etc/sql_exporter/compute_pg_oldest_mxid_age.sql index ed57894b3a..7949bacfff 100644 --- a/compute/etc/sql_exporter/compute_pg_oldest_mxid_age.sql +++ b/compute/etc/sql_exporter/compute_pg_oldest_mxid_age.sql @@ -1,4 +1,4 @@ SELECT datname database_name, - mxid_age(datminmxid) min_mxid_age -FROM pg_database + pg_catalog.mxid_age(datminmxid) min_mxid_age +FROM pg_catalog.pg_database ORDER BY min_mxid_age DESC LIMIT 10; diff --git a/compute/etc/sql_exporter/compute_receive_lsn.sql b/compute/etc/sql_exporter/compute_receive_lsn.sql index 318b31ab41..fb96056881 100644 --- a/compute/etc/sql_exporter/compute_receive_lsn.sql +++ b/compute/etc/sql_exporter/compute_receive_lsn.sql @@ -1,4 +1,4 @@ SELECT CASE - WHEN pg_catalog.pg_is_in_recovery() THEN (pg_last_wal_receive_lsn() - '0/0')::FLOAT8 + WHEN pg_catalog.pg_is_in_recovery() THEN (pg_catalog.pg_last_wal_receive_lsn() - '0/0')::pg_catalog.FLOAT8 ELSE 0 END AS lsn; diff --git a/compute/etc/sql_exporter/compute_subscriptions_count.sql b/compute/etc/sql_exporter/compute_subscriptions_count.sql index 50740cb5df..e380a7acc7 100644 --- a/compute/etc/sql_exporter/compute_subscriptions_count.sql +++ b/compute/etc/sql_exporter/compute_subscriptions_count.sql @@ -1 +1 @@ -SELECT subenabled::text AS enabled, count(*) AS subscriptions_count FROM pg_subscription GROUP BY subenabled; +SELECT subenabled::pg_catalog.text AS enabled, pg_catalog.count(*) AS subscriptions_count FROM pg_catalog.pg_subscription GROUP BY subenabled; diff --git a/compute/etc/sql_exporter/connection_counts.sql b/compute/etc/sql_exporter/connection_counts.sql index 6824480fdb..480c4fb439 100644 --- a/compute/etc/sql_exporter/connection_counts.sql +++ b/compute/etc/sql_exporter/connection_counts.sql @@ -1 +1 @@ -SELECT datname, state, count(*) AS count FROM pg_stat_activity WHERE state <> '' GROUP BY datname, state; +SELECT datname, state, pg_catalog.count(*) AS count FROM pg_catalog.pg_stat_activity WHERE state <> '' GROUP BY datname, state; diff --git 
a/compute/etc/sql_exporter/db_total_size.sql b/compute/etc/sql_exporter/db_total_size.sql index fe0360ab5c..59205e6ed3 100644 --- a/compute/etc/sql_exporter/db_total_size.sql +++ b/compute/etc/sql_exporter/db_total_size.sql @@ -1,5 +1,5 @@ -SELECT sum(pg_database_size(datname)) AS total -FROM pg_database +SELECT pg_catalog.sum(pg_catalog.pg_database_size(datname)) AS total +FROM pg_catalog.pg_database -- Ignore invalid databases, as we will likely have problems with -- getting their size from the Pageserver. WHERE datconnlimit != -2; diff --git a/compute/etc/sql_exporter/lfc_approximate_working_set_size_windows.autoscaling.sql b/compute/etc/sql_exporter/lfc_approximate_working_set_size_windows.autoscaling.sql index 35fa42c34c..02cb2b4649 100644 --- a/compute/etc/sql_exporter/lfc_approximate_working_set_size_windows.autoscaling.sql +++ b/compute/etc/sql_exporter/lfc_approximate_working_set_size_windows.autoscaling.sql @@ -3,6 +3,6 @@ -- minutes. SELECT - x::text as duration_seconds, + x::pg_catalog.text AS duration_seconds, neon.approximate_working_set_size_seconds(x) AS size FROM (SELECT generate_series * 60 AS x FROM generate_series(1, 60)) AS t (x); diff --git a/compute/etc/sql_exporter/lfc_approximate_working_set_size_windows.sql b/compute/etc/sql_exporter/lfc_approximate_working_set_size_windows.sql index 46c7d1610c..aab93d433a 100644 --- a/compute/etc/sql_exporter/lfc_approximate_working_set_size_windows.sql +++ b/compute/etc/sql_exporter/lfc_approximate_working_set_size_windows.sql @@ -3,6 +3,6 @@ SELECT x AS duration, - neon.approximate_working_set_size_seconds(extract('epoch' FROM x::interval)::int) AS size FROM ( + neon.approximate_working_set_size_seconds(extract('epoch' FROM x::pg_catalog.interval)::pg_catalog.int4) AS size FROM ( VALUES ('5m'), ('15m'), ('1h') ) AS t (x); diff --git a/compute/etc/sql_exporter/lfc_cache_size_limit.sql b/compute/etc/sql_exporter/lfc_cache_size_limit.sql index 378904c1fe..41c11e0adc 100644 --- a/compute/etc/sql_exporter/lfc_cache_size_limit.sql +++ b/compute/etc/sql_exporter/lfc_cache_size_limit.sql @@ -1 +1 @@ -SELECT pg_size_bytes(current_setting('neon.file_cache_size_limit')) AS lfc_cache_size_limit; +SELECT pg_catalog.pg_size_bytes(pg_catalog.current_setting('neon.file_cache_size_limit')) AS lfc_cache_size_limit; diff --git a/compute/etc/sql_exporter/logical_slot_restart_lsn.sql b/compute/etc/sql_exporter/logical_slot_restart_lsn.sql index 1b1c038501..8964d0d8ff 100644 --- a/compute/etc/sql_exporter/logical_slot_restart_lsn.sql +++ b/compute/etc/sql_exporter/logical_slot_restart_lsn.sql @@ -1,3 +1,3 @@ -SELECT slot_name, (restart_lsn - '0/0')::FLOAT8 as restart_lsn -FROM pg_replication_slots +SELECT slot_name, (restart_lsn - '0/0')::pg_catalog.FLOAT8 AS restart_lsn +FROM pg_catalog.pg_replication_slots WHERE slot_type = 'logical'; diff --git a/compute/etc/sql_exporter/max_cluster_size.sql b/compute/etc/sql_exporter/max_cluster_size.sql index 2d2355a9a7..d44fdebe38 100644 --- a/compute/etc/sql_exporter/max_cluster_size.sql +++ b/compute/etc/sql_exporter/max_cluster_size.sql @@ -1 +1 @@ -SELECT setting::int AS max_cluster_size FROM pg_settings WHERE name = 'neon.max_cluster_size'; +SELECT setting::pg_catalog.int4 AS max_cluster_size FROM pg_catalog.pg_settings WHERE name = 'neon.max_cluster_size'; diff --git a/compute/etc/sql_exporter/pg_stats_userdb.sql b/compute/etc/sql_exporter/pg_stats_userdb.sql index 12e6c4ae59..1a1e54a7c6 100644 --- a/compute/etc/sql_exporter/pg_stats_userdb.sql +++ b/compute/etc/sql_exporter/pg_stats_userdb.sql @@ -1,13 
+1,13 @@ -- We export stats for 10 non-system databases. Without this limit it is too -- easy to abuse the system by creating lots of databases. -SELECT pg_database_size(datname) AS db_size, +SELECT pg_catalog.pg_database_size(datname) AS db_size, deadlocks, tup_inserted AS inserted, tup_updated AS updated, tup_deleted AS deleted, datname -FROM pg_stat_database +FROM pg_catalog.pg_stat_database WHERE datname IN ( SELECT datname FROM pg_database -- Ignore invalid databases, as we will likely have problems with diff --git a/compute/etc/sql_exporter/replication_delay_bytes.sql b/compute/etc/sql_exporter/replication_delay_bytes.sql index 60a6981acd..d3b7aa724b 100644 --- a/compute/etc/sql_exporter/replication_delay_bytes.sql +++ b/compute/etc/sql_exporter/replication_delay_bytes.sql @@ -3,4 +3,4 @@ -- replay LSN may have advanced past the receive LSN we are using for the -- calculation. -SELECT GREATEST(0, pg_wal_lsn_diff(pg_last_wal_receive_lsn(), pg_last_wal_replay_lsn())) AS replication_delay_bytes; +SELECT GREATEST(0, pg_catalog.pg_wal_lsn_diff(pg_catalog.pg_last_wal_receive_lsn(), pg_catalog.pg_last_wal_replay_lsn())) AS replication_delay_bytes; diff --git a/compute/etc/sql_exporter/replication_delay_seconds.sql b/compute/etc/sql_exporter/replication_delay_seconds.sql index a76809ad74..af4dd3fd90 100644 --- a/compute/etc/sql_exporter/replication_delay_seconds.sql +++ b/compute/etc/sql_exporter/replication_delay_seconds.sql @@ -1,5 +1,5 @@ SELECT CASE - WHEN pg_last_wal_receive_lsn() = pg_last_wal_replay_lsn() THEN 0 - ELSE GREATEST(0, EXTRACT (EPOCH FROM now() - pg_last_xact_replay_timestamp())) + WHEN pg_catalog.pg_last_wal_receive_lsn() = pg_catalog.pg_last_wal_replay_lsn() THEN 0 + ELSE GREATEST(0, EXTRACT (EPOCH FROM pg_catalog.now() - pg_catalog.pg_last_xact_replay_timestamp())) END AS replication_delay_seconds; diff --git a/compute/etc/sql_exporter/retained_wal.sql b/compute/etc/sql_exporter/retained_wal.sql index 3e2aadfc28..ccb3504d58 100644 --- a/compute/etc/sql_exporter/retained_wal.sql +++ b/compute/etc/sql_exporter/retained_wal.sql @@ -1,10 +1,10 @@ SELECT slot_name, - pg_wal_lsn_diff( + pg_catalog.pg_wal_lsn_diff( CASE - WHEN pg_is_in_recovery() THEN pg_last_wal_replay_lsn() - ELSE pg_current_wal_lsn() + WHEN pg_catalog.pg_is_in_recovery() THEN pg_catalog.pg_last_wal_replay_lsn() + ELSE pg_catalog.pg_current_wal_lsn() END, - restart_lsn)::FLOAT8 AS retained_wal -FROM pg_replication_slots + restart_lsn)::pg_catalog.FLOAT8 AS retained_wal +FROM pg_catalog.pg_replication_slots WHERE active = false; diff --git a/compute/etc/sql_exporter/wal_is_lost.sql b/compute/etc/sql_exporter/wal_is_lost.sql index 5521270851..5a94cc3373 100644 --- a/compute/etc/sql_exporter/wal_is_lost.sql +++ b/compute/etc/sql_exporter/wal_is_lost.sql @@ -4,4 +4,4 @@ SELECT WHEN wal_status = 'lost' THEN 1 ELSE 0 END AS wal_is_lost -FROM pg_replication_slots; +FROM pg_catalog.pg_replication_slots; diff --git a/compute/patches/pg_repack.patch b/compute/patches/pg_repack.patch index 10ed1054ff..b8a057e222 100644 --- a/compute/patches/pg_repack.patch +++ b/compute/patches/pg_repack.patch @@ -1,5 +1,11 @@ +commit 5eb393810cf7c7bafa4e394dad2e349e2a8cb2cb +Author: Alexey Masterov +Date: Mon Jul 28 18:11:02 2025 +0200 + + Patch for pg_repack + diff --git a/regress/Makefile b/regress/Makefile -index bf6edcb..89b4c7f 100644 +index bf6edcb..110e734 100644 --- a/regress/Makefile +++ b/regress/Makefile @@ -17,7 +17,7 @@ INTVERSION := $(shell echo $$(($$(echo $(VERSION).0 | sed 's/\([[:digit:]]\{1,\} @@ -7,18 +13,36 @@ 
index bf6edcb..89b4c7f 100644 # -REGRESS := init-extension repack-setup repack-run error-on-invalid-idx no-error-on-invalid-idx after-schema repack-check nosuper tablespace get_order_by trigger -+REGRESS := init-extension repack-setup repack-run error-on-invalid-idx no-error-on-invalid-idx after-schema repack-check nosuper get_order_by trigger ++REGRESS := init-extension noautovacuum repack-setup repack-run error-on-invalid-idx no-error-on-invalid-idx after-schema repack-check nosuper get_order_by trigger autovacuum USE_PGXS = 1 # use pgxs if not in contrib directory PGXS := $(shell $(PG_CONFIG) --pgxs) -diff --git a/regress/expected/init-extension.out b/regress/expected/init-extension.out -index 9f2e171..f6e4f8d 100644 ---- a/regress/expected/init-extension.out -+++ b/regress/expected/init-extension.out -@@ -1,3 +1,2 @@ - SET client_min_messages = warning; - CREATE EXTENSION pg_repack; --RESET client_min_messages; +diff --git a/regress/expected/autovacuum.out b/regress/expected/autovacuum.out +new file mode 100644 +index 0000000..e7f2363 +--- /dev/null ++++ b/regress/expected/autovacuum.out +@@ -0,0 +1,7 @@ ++ALTER SYSTEM SET autovacuum='on'; ++SELECT pg_reload_conf(); ++ pg_reload_conf ++---------------- ++ t ++(1 row) ++ +diff --git a/regress/expected/noautovacuum.out b/regress/expected/noautovacuum.out +new file mode 100644 +index 0000000..fc7978e +--- /dev/null ++++ b/regress/expected/noautovacuum.out +@@ -0,0 +1,7 @@ ++ALTER SYSTEM SET autovacuum='off'; ++SELECT pg_reload_conf(); ++ pg_reload_conf ++---------------- ++ t ++(1 row) ++ diff --git a/regress/expected/nosuper.out b/regress/expected/nosuper.out index 8d0a94e..63b68bf 100644 --- a/regress/expected/nosuper.out @@ -50,14 +74,22 @@ index 8d0a94e..63b68bf 100644 INFO: repacking table "public.tbl_cluster" ERROR: query failed: ERROR: current transaction is aborted, commands ignored until end of transaction block DETAIL: query was: RESET lock_timeout -diff --git a/regress/sql/init-extension.sql b/regress/sql/init-extension.sql -index 9f2e171..f6e4f8d 100644 ---- a/regress/sql/init-extension.sql -+++ b/regress/sql/init-extension.sql -@@ -1,3 +1,2 @@ - SET client_min_messages = warning; - CREATE EXTENSION pg_repack; --RESET client_min_messages; +diff --git a/regress/sql/autovacuum.sql b/regress/sql/autovacuum.sql +new file mode 100644 +index 0000000..a8eda63 +--- /dev/null ++++ b/regress/sql/autovacuum.sql +@@ -0,0 +1,2 @@ ++ALTER SYSTEM SET autovacuum='on'; ++SELECT pg_reload_conf(); +diff --git a/regress/sql/noautovacuum.sql b/regress/sql/noautovacuum.sql +new file mode 100644 +index 0000000..13d4836 +--- /dev/null ++++ b/regress/sql/noautovacuum.sql +@@ -0,0 +1,2 @@ ++ALTER SYSTEM SET autovacuum='off'; ++SELECT pg_reload_conf(); diff --git a/regress/sql/nosuper.sql b/regress/sql/nosuper.sql index 072f0fa..dbe60f8 100644 --- a/regress/sql/nosuper.sql diff --git a/compute/vm-image-spec-bookworm.yaml b/compute/vm-image-spec-bookworm.yaml index 267e4c83b5..5f27b6bf9d 100644 --- a/compute/vm-image-spec-bookworm.yaml +++ b/compute/vm-image-spec-bookworm.yaml @@ -26,7 +26,13 @@ commands: - name: postgres-exporter user: nobody sysvInitAction: respawn - shell: 'DATA_SOURCE_NAME="user=cloud_admin sslmode=disable dbname=postgres application_name=postgres-exporter pgaudit.log=none" /bin/postgres_exporter --config.file=/etc/postgres_exporter.yml' + # Turn off database collector (`--no-collector.database`), we don't use `pg_database_size_bytes` metric anyway, see + # 
https://github.com/neondatabase/flux-fleet/blob/5e19b3fd897667b70d9a7ad4aa06df0ca22b49ff/apps/base/compute-metrics/scrape-compute-pg-exporter-neon.yaml#L29 + # but it's enabled by default and it doesn't filter out invalid databases, see + # https://github.com/prometheus-community/postgres_exporter/blob/06a553c8166512c9d9c5ccf257b0f9bba8751dbc/collector/pg_database.go#L67 + # so if it hits one, it starts spamming logs + # ERROR: [NEON_SMGR] [reqid d9700000018] could not read db size of db 705302 from page server at lsn 5/A2457EB0 + shell: 'DATA_SOURCE_NAME="user=cloud_admin sslmode=disable dbname=postgres application_name=postgres-exporter pgaudit.log=none" /bin/postgres_exporter --no-collector.database --config.file=/etc/postgres_exporter.yml' - name: pgbouncer-exporter user: postgres sysvInitAction: respawn diff --git a/compute/vm-image-spec-bullseye.yaml b/compute/vm-image-spec-bullseye.yaml index 2b6e77b656..cf26ace72a 100644 --- a/compute/vm-image-spec-bullseye.yaml +++ b/compute/vm-image-spec-bullseye.yaml @@ -26,7 +26,13 @@ commands: - name: postgres-exporter user: nobody sysvInitAction: respawn - shell: 'DATA_SOURCE_NAME="user=cloud_admin sslmode=disable dbname=postgres application_name=postgres-exporter pgaudit.log=none" /bin/postgres_exporter --config.file=/etc/postgres_exporter.yml' + # Turn off database collector (`--no-collector.database`), we don't use `pg_database_size_bytes` metric anyway, see + # https://github.com/neondatabase/flux-fleet/blob/5e19b3fd897667b70d9a7ad4aa06df0ca22b49ff/apps/base/compute-metrics/scrape-compute-pg-exporter-neon.yaml#L29 + # but it's enabled by default and it doesn't filter out invalid databases, see + # https://github.com/prometheus-community/postgres_exporter/blob/06a553c8166512c9d9c5ccf257b0f9bba8751dbc/collector/pg_database.go#L67 + # so if it hits one, it starts spamming logs + # ERROR: [NEON_SMGR] [reqid d9700000018] could not read db size of db 705302 from page server at lsn 5/A2457EB0 + shell: 'DATA_SOURCE_NAME="user=cloud_admin sslmode=disable dbname=postgres application_name=postgres-exporter pgaudit.log=none" /bin/postgres_exporter --no-collector.database --config.file=/etc/postgres_exporter.yml' - name: pgbouncer-exporter user: postgres sysvInitAction: respawn diff --git a/compute_tools/Cargo.toml b/compute_tools/Cargo.toml index 496471acc7..558760b0ad 100644 --- a/compute_tools/Cargo.toml +++ b/compute_tools/Cargo.toml @@ -62,6 +62,7 @@ tokio-stream.workspace = true tonic.workspace = true tower-otel.workspace = true tracing.workspace = true +tracing-appender.workspace = true tracing-opentelemetry.workspace = true tracing-subscriber.workspace = true tracing-utils.workspace = true diff --git a/compute_tools/README.md b/compute_tools/README.md index 49f1368f0e..e92e5920b9 100644 --- a/compute_tools/README.md +++ b/compute_tools/README.md @@ -52,8 +52,14 @@ stateDiagram-v2 Init --> Running : Started Postgres Running --> TerminationPendingFast : Requested termination Running --> TerminationPendingImmediate : Requested termination + Running --> ConfigurationPending : Received a /configure request with spec + Running --> RefreshConfigurationPending : Received a /refresh_configuration request, compute node will pull a new spec and reconfigure + RefreshConfigurationPending --> RefreshConfiguration: Received compute spec and started configuration + RefreshConfiguration --> Running : Compute has been re-configured + RefreshConfiguration --> RefreshConfigurationPending : Configuration failed and to be retried TerminationPendingFast --> 
Terminated compute with 30s delay for cplane to inspect status TerminationPendingImmediate --> Terminated : Terminated compute immediately + Failed --> RefreshConfigurationPending : Received a /refresh_configuration request Failed --> [*] : Compute exited Terminated --> [*] : Compute exited ``` diff --git a/compute_tools/src/bin/compute_ctl.rs b/compute_tools/src/bin/compute_ctl.rs index 04723d6f3d..f383683ef8 100644 --- a/compute_tools/src/bin/compute_ctl.rs +++ b/compute_tools/src/bin/compute_ctl.rs @@ -49,9 +49,10 @@ use compute_tools::compute::{ BUILD_TAG, ComputeNode, ComputeNodeParams, forward_termination_signal, }; use compute_tools::extension_server::get_pg_version_string; -use compute_tools::logger::*; use compute_tools::params::*; +use compute_tools::pg_isready::get_pg_isready_bin; use compute_tools::spec::*; +use compute_tools::{hadron_metrics, installed_extensions, logger::*}; use rlimit::{Resource, setrlimit}; use signal_hook::consts::{SIGINT, SIGQUIT, SIGTERM}; use signal_hook::iterator::Signals; @@ -81,6 +82,15 @@ struct Cli { #[arg(long, default_value_t = 3081)] pub internal_http_port: u16, + /// Backwards-compatible --http-port for Hadron deployments. Functionally the + /// same as --external-http-port. + #[arg( + long, + conflicts_with = "external_http_port", + conflicts_with = "internal_http_port" + )] + pub http_port: Option, + #[arg(short = 'D', long, value_name = "DATADIR")] pub pgdata: String, @@ -180,6 +190,26 @@ impl Cli { } } +// Hadron helpers to get compatible compute_ctl http ports from Cli. The old `--http-port` +// arg is used and acts the same as `--external-http-port`. The internal http port is defined +// to be http_port + 1. Hadron runs in the dblet environment which uses the host network, so +// we need to be careful with the ports to choose. 
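// Reviewer note (not part of this patch): in lakebase mode with e.g. --http-port 3080,
// the helpers below resolve to external = 3080 and internal = 3081 (http_port + 1);
// without --http-port, the regular --external-http-port / --internal-http-port values
// (internal defaults to 3081) are used unchanged.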
+fn get_external_http_port(cli: &Cli) -> u16 { + if cli.lakebase_mode { + return cli.http_port.unwrap_or(cli.external_http_port); + } + cli.external_http_port +} +fn get_internal_http_port(cli: &Cli) -> u16 { + if cli.lakebase_mode { + return cli + .http_port + .map(|p| p + 1) + .unwrap_or(cli.internal_http_port); + } + cli.internal_http_port +} + fn main() -> Result<()> { let cli = Cli::parse(); @@ -194,15 +224,28 @@ fn main() -> Result<()> { .build()?; let _rt_guard = runtime.enter(); - let tracing_provider = init(cli.dev)?; + let mut log_dir = None; + if cli.lakebase_mode { + log_dir = std::env::var("COMPUTE_CTL_LOG_DIRECTORY").ok(); + } + + let (tracing_provider, _file_logs_guard) = init(cli.dev, log_dir)?; // enable core dumping for all child processes setrlimit(Resource::CORE, rlimit::INFINITY, rlimit::INFINITY)?; + if cli.lakebase_mode { + installed_extensions::initialize_metrics(); + hadron_metrics::initialize_metrics(); + } + let connstr = Url::parse(&cli.connstr).context("cannot parse connstr as a URL")?; let config = get_config(&cli)?; + let external_http_port = get_external_http_port(&cli); + let internal_http_port = get_internal_http_port(&cli); + let compute_node = ComputeNode::new( ComputeNodeParams { compute_id: cli.compute_id, @@ -211,8 +254,8 @@ fn main() -> Result<()> { pgdata: cli.pgdata.clone(), pgbin: cli.pgbin.clone(), pgversion: get_pg_version_string(&cli.pgbin), - external_http_port: cli.external_http_port, - internal_http_port: cli.internal_http_port, + external_http_port, + internal_http_port, remote_ext_base_url: cli.remote_ext_base_url.clone(), resize_swap_on_bind: cli.resize_swap_on_bind, set_disk_quota_for_fs: cli.set_disk_quota_for_fs, @@ -226,20 +269,31 @@ fn main() -> Result<()> { cli.installed_extensions_collection_interval, )), pg_init_timeout: cli.pg_init_timeout.map(Duration::from_secs), + pg_isready_bin: get_pg_isready_bin(&cli.pgbin), + instance_id: std::env::var("INSTANCE_ID").ok(), lakebase_mode: cli.lakebase_mode, + build_tag: BUILD_TAG.to_string(), + control_plane_uri: cli.control_plane_uri, + config_path_test_only: cli.config, }, config, )?; - let exit_code = compute_node.run()?; + let exit_code = compute_node.run().context("running compute node")?; scenario.teardown(); deinit_and_exit(tracing_provider, exit_code); } -fn init(dev_mode: bool) -> Result> { - let provider = init_tracing_and_logging(DEFAULT_LOG_LEVEL)?; +fn init( + dev_mode: bool, + log_dir: Option, +) -> Result<( + Option, + Option, +)> { + let (provider, file_logs_guard) = init_tracing_and_logging(DEFAULT_LOG_LEVEL, &log_dir)?; let mut signals = Signals::new([SIGINT, SIGTERM, SIGQUIT])?; thread::spawn(move || { @@ -250,7 +304,7 @@ fn init(dev_mode: bool) -> Result> { info!("compute build_tag: {}", &BUILD_TAG.to_string()); - Ok(provider) + Ok((provider, file_logs_guard)) } fn get_config(cli: &Cli) -> Result { diff --git a/compute_tools/src/checker.rs b/compute_tools/src/checker.rs index e4207876ac..2458fe3c11 100644 --- a/compute_tools/src/checker.rs +++ b/compute_tools/src/checker.rs @@ -24,9 +24,9 @@ pub async fn check_writability(compute: &ComputeNode) -> Result<()> { }); let query = " - INSERT INTO health_check VALUES (1, now()) + INSERT INTO public.health_check VALUES (1, pg_catalog.now()) ON CONFLICT (id) DO UPDATE - SET updated_at = now();"; + SET updated_at = pg_catalog.now();"; match client.simple_query(query).await { Result::Ok(result) => { diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index dac17cf6c9..1df837e1e6 100644 --- 
a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -6,8 +6,8 @@ use compute_api::responses::{ LfcPrewarmState, PromoteState, TlsConfig, }; use compute_api::spec::{ - ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PageserverConnectionInfo, - PageserverProtocol, PageserverShardConnectionInfo, PageserverShardInfo, PgIdent, + ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, GenericOption, + PageserverConnectionInfo, PageserverProtocol, PgIdent, Role, }; use futures::StreamExt; use futures::future::join_all; @@ -22,6 +22,7 @@ use postgres::NoTls; use postgres::error::SqlState; use remote_storage::{DownloadError, RemotePath}; use std::collections::{HashMap, HashSet}; +use std::ffi::OsString; use std::os::unix::fs::{PermissionsExt, symlink}; use std::path::Path; use std::process::{Command, Stdio}; @@ -31,18 +32,23 @@ use std::sync::{Arc, Condvar, Mutex, RwLock}; use std::time::{Duration, Instant}; use std::{env, fs}; use tokio::{spawn, sync::watch, task::JoinHandle, time}; +use tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info, instrument, warn}; use url::Url; +use utils::backoff::{ + DEFAULT_BASE_BACKOFF_SECONDS, DEFAULT_MAX_BACKOFF_SECONDS, exponential_backoff_duration, +}; use utils::id::{TenantId, TimelineId}; use utils::lsn::Lsn; use utils::measured_stream::MeasuredReader; use utils::pid_file; -use utils::shard::{ShardCount, ShardIndex, ShardNumber}; +use utils::shard::{ShardIndex, ShardNumber, ShardStripeSize}; use crate::configurator::launch_configurator; use crate::disk_quota::set_disk_quota; +use crate::hadron_metrics::COMPUTE_ATTACHED; use crate::installed_extensions::get_installed_extensions; -use crate::logger::startup_context_from_env; +use crate::logger::{self, startup_context_from_env}; use crate::lsn_lease::launch_lsn_lease_bg_task_for_static; use crate::metrics::COMPUTE_CTL_UP; use crate::monitor::launch_monitor; @@ -114,11 +120,17 @@ pub struct ComputeNodeParams { /// Interval for installed extensions collection pub installed_extensions_collection_interval: Arc, - + /// Hadron instance ID of the compute node. + pub instance_id: Option, /// Timeout of PG compute startup in the Init state. pub pg_init_timeout: Option, - + // Path to the `pg_isready` binary. + pub pg_isready_bin: String, pub lakebase_mode: bool, + + pub build_tag: String, + pub control_plane_uri: Option, + pub config_path_test_only: Option, } type TaskHandle = Mutex>>; @@ -184,6 +196,7 @@ pub struct ComputeState { pub startup_span: Option, pub lfc_prewarm_state: LfcPrewarmState, + pub lfc_prewarm_token: CancellationToken, pub lfc_offload_state: LfcOffloadState, /// WAL flush LSN that is set after terminating Postgres and syncing safekeepers if @@ -209,6 +222,7 @@ impl ComputeState { lfc_offload_state: LfcOffloadState::default(), terminate_flush_lsn: None, promote_state: None, + lfc_prewarm_token: CancellationToken::new(), } } @@ -288,72 +302,6 @@ impl ParsedSpec { } } -/// Extract PageserverConnectionInfo from a comma-separated list of libpq connection strings. -/// -/// This is used for backwards-compatilibity, to parse the legacye `pageserver_connstr` -/// field in the compute spec, or the 'neon.pageserver_connstring' GUC. Nowadays, the -/// 'pageserver_connection_info' field should be used instead. 
-fn extract_pageserver_conninfo_from_connstr( - connstr: &str, - stripe_size: Option, -) -> Result { - let shard_infos: Vec<_> = connstr - .split(',') - .map(|connstr| PageserverShardInfo { - pageservers: vec![PageserverShardConnectionInfo { - id: None, - libpq_url: Some(connstr.to_string()), - grpc_url: None, - }], - }) - .collect(); - - match shard_infos.len() { - 0 => anyhow::bail!("empty connection string"), - 1 => { - // We assume that if there's only connection string, it means "unsharded", - // rather than a sharded system with just a single shard. The latter is - // possible in principle, but we never do it. - let shard_count = ShardCount::unsharded(); - let only_shard = shard_infos.first().unwrap().clone(); - let shards = vec![(ShardIndex::unsharded(), only_shard)]; - Ok(PageserverConnectionInfo { - shard_count, - stripe_size: None, - shards: shards.into_iter().collect(), - prefer_protocol: PageserverProtocol::Libpq, - }) - } - n => { - if stripe_size.is_none() { - anyhow::bail!("{n} shards but no stripe_size"); - } - let shard_count = ShardCount(n.try_into()?); - let shards = shard_infos - .into_iter() - .enumerate() - .map(|(idx, shard_info)| { - ( - ShardIndex { - shard_count, - shard_number: ShardNumber( - idx.try_into().expect("shard number fits in u8"), - ), - }, - shard_info, - ) - }) - .collect(); - Ok(PageserverConnectionInfo { - shard_count, - stripe_size, - shards, - prefer_protocol: PageserverProtocol::Libpq, - }) - } - } -} - impl TryFrom for ParsedSpec { type Error = anyhow::Error; fn try_from(spec: ComputeSpec) -> Result { @@ -367,7 +315,7 @@ impl TryFrom for ParsedSpec { let mut pageserver_conninfo = spec.pageserver_connection_info.clone(); if pageserver_conninfo.is_none() { if let Some(pageserver_connstr_field) = &spec.pageserver_connstring { - pageserver_conninfo = Some(extract_pageserver_conninfo_from_connstr( + pageserver_conninfo = Some(PageserverConnectionInfo::from_connstr( pageserver_connstr_field, spec.shard_stripe_size, )?); @@ -377,12 +325,12 @@ impl TryFrom for ParsedSpec { if let Some(guc) = spec.cluster.settings.find("neon.pageserver_connstring") { let stripe_size = if let Some(guc) = spec.cluster.settings.find("neon.stripe_size") { - Some(u32::from_str(&guc)?) + Some(ShardStripeSize(u32::from_str(&guc)?)) } else { None }; pageserver_conninfo = - Some(extract_pageserver_conninfo_from_connstr(&guc, stripe_size)?); + Some(PageserverConnectionInfo::from_connstr(&guc, stripe_size)?); } } let pageserver_conninfo = pageserver_conninfo.ok_or(anyhow::anyhow!( @@ -494,6 +442,130 @@ struct StartVmMonitorResult { vm_monitor: Option>>, } +// BEGIN_HADRON +/// This function creates roles that are used by Databricks. +/// These roles are not needs to be botostrapped at PG Compute provisioning time. +/// The auth method for these roles are configured in databricks_pg_hba.conf in universe repository. +pub(crate) fn create_databricks_roles() -> Vec { + let roles = vec![ + // Role for prometheus_stats_exporter + Role { + name: "databricks_monitor".to_string(), + // This uses "local" connection and auth method for that is "trust", so no password is needed. + encrypted_password: None, + options: Some(vec![GenericOption { + name: "IN ROLE pg_monitor".to_string(), + value: None, + vartype: "string".to_string(), + }]), + }, + // Role for brickstore control plane + Role { + name: "databricks_control_plane".to_string(), + // Certificate user does not need password. 
+ encrypted_password: None, + options: Some(vec![GenericOption { + name: "SUPERUSER".to_string(), + value: None, + vartype: "string".to_string(), + }]), + }, + // Role for brickstore httpgateway. + Role { + name: "databricks_gateway".to_string(), + // Certificate user does not need password. + encrypted_password: None, + options: None, + }, + ]; + + roles + .into_iter() + .map(|role| { + let query = format!( + r#" + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT FROM pg_catalog.pg_roles WHERE rolname = '{}') + THEN + CREATE ROLE {} {}; + END IF; + END + $$;"#, + role.name, + role.name.pg_quote(), + role.to_pg_options(), + ); + query + }) + .collect() +} + +/// Databricks-specific environment variables to be passed to the `postgres` sub-process. +pub struct DatabricksEnvVars { + /// The Databricks "endpoint ID" of the compute instance. Used by `postgres` to check + /// the token scopes of internal auth tokens. + pub endpoint_id: String, + /// Hostname of the Databricks workspace URL this compute instance belongs to. + /// Used by postgres to verify Databricks PAT tokens. + pub workspace_host: String, + + pub lakebase_mode: bool, +} + +impl DatabricksEnvVars { + pub fn new( + compute_spec: &ComputeSpec, + compute_id: Option<&String>, + instance_id: Option, + lakebase_mode: bool, + ) -> Self { + let endpoint_id = if let Some(instance_id) = instance_id { + // Use instance_id as endpoint_id if it is set. This code path is for PuPr model. + instance_id + } else { + // Use compute_id as endpoint_id if instance_id is not set. The code path is for PrPr model. + // compute_id is a string format of "{endpoint_id}/{compute_idx}" + // endpoint_id is a uuid. We only need to pass down endpoint_id to postgres. + // Panics if compute_id is not set or not in the expected format. + compute_id.unwrap().split('/').next().unwrap().to_string() + }; + let workspace_host = compute_spec + .databricks_settings + .as_ref() + .map(|s| s.databricks_workspace_host.clone()) + .unwrap_or("".to_string()); + Self { + endpoint_id, + workspace_host, + lakebase_mode, + } + } + + /// Constants for the names of Databricks-specific postgres environment variables. + const DATABRICKS_ENDPOINT_ID_ENVVAR: &'static str = "DATABRICKS_ENDPOINT_ID"; + const DATABRICKS_WORKSPACE_HOST_ENVVAR: &'static str = "DATABRICKS_WORKSPACE_HOST"; + + /// Convert DatabricksEnvVars to a list of string pairs that can be passed as env vars. Consumes `self`. + pub fn to_env_var_list(self) -> Vec<(String, String)> { + if !self.lakebase_mode { + // In neon env, we don't need to pass down the env vars to postgres. + return vec![]; + } + vec![ + ( + Self::DATABRICKS_ENDPOINT_ID_ENVVAR.to_string(), + self.endpoint_id.clone(), + ), + ( + Self::DATABRICKS_WORKSPACE_HOST_ENVVAR.to_string(), + self.workspace_host.clone(), + ), + ] + } +} + impl ComputeNode { pub fn new(params: ComputeNodeParams, config: ComputeConfig) -> Result { let connstr = params.connstr.as_str(); @@ -517,7 +589,7 @@ impl ComputeNode { // that can affect `compute_ctl` and prevent it from properly configuring the database schema. // Unset them via connection string options before connecting to the database. // N.B. keep it in sync with `ZENITH_OPTIONS` in `get_maintenance_client()`. 
- const EXTRA_OPTIONS: &str = "-c role=cloud_admin -c default_transaction_read_only=off -c search_path=public -c statement_timeout=0 -c pgaudit.log=none"; + const EXTRA_OPTIONS: &str = "-c role=cloud_admin -c default_transaction_read_only=off -c search_path='' -c statement_timeout=0 -c pgaudit.log=none"; let options = match conn_conf.get_options() { // Allow the control plane to override any options set by the // compute @@ -530,7 +602,11 @@ impl ComputeNode { let mut new_state = ComputeState::new(); if let Some(spec) = config.spec { let pspec = ParsedSpec::try_from(spec).map_err(|msg| anyhow::anyhow!(msg))?; - new_state.pspec = Some(pspec); + if params.lakebase_mode { + ComputeNode::set_spec(¶ms, &mut new_state, pspec); + } else { + new_state.pspec = Some(pspec); + } } Ok(ComputeNode { @@ -575,6 +651,7 @@ impl ComputeNode { port: this.params.external_http_port, config: this.compute_ctl_config.clone(), compute_id: this.params.compute_id.clone(), + instance_id: this.params.instance_id.clone(), } .launch(&this); @@ -1127,7 +1204,14 @@ impl ComputeNode { // If it is something different then create_dir() will error out anyway. let pgdata = &self.params.pgdata; let _ok = fs::remove_dir_all(pgdata); - fs::create_dir(pgdata)?; + if self.params.lakebase_mode { + // Ignore creation errors if the directory already exists (e.g. mounting it ahead of time). + // If it is something different then PG startup will error out anyway. + let _ok = fs::create_dir(pgdata); + } else { + fs::create_dir(pgdata)?; + } + fs::set_permissions(pgdata, fs::Permissions::from_mode(0o700))?; Ok(()) @@ -1184,22 +1268,10 @@ impl ComputeNode { shard_number: ShardNumber(0), shard_count: spec.pageserver_conninfo.shard_count, }; - let shard0 = spec + let shard0_url = spec .pageserver_conninfo - .shards - .get(&shard0_index) - .ok_or_else(|| { - anyhow::anyhow!("shard connection info missing for shard {}", shard0_index) - })?; - let pageserver = shard0 - .pageservers - .first() - .expect("must have at least one pageserver"); - let shard0_url = pageserver - .grpc_url - .clone() - .expect("no grpc_url for shard 0"); - + .shard_url(ShardNumber(0), PageserverProtocol::Grpc)? + .to_owned(); let (reader, connected) = tokio::runtime::Handle::current().block_on(async move { let mut client = page_api::Client::connect( shard0_url, @@ -1237,26 +1309,10 @@ impl ComputeNode { /// Fetches a basebackup via libpq. The connstring must use postgresql://. Returns the timestamp /// when the connection was established, and the (compressed) size of the basebackup. fn try_get_basebackup_libpq(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<(Instant, usize)> { - let shard0_index = ShardIndex { - shard_number: ShardNumber(0), - shard_count: spec.pageserver_conninfo.shard_count, - }; - let shard0 = spec + let shard0_connstr = spec .pageserver_conninfo - .shards - .get(&shard0_index) - .ok_or_else(|| { - anyhow::anyhow!("shard connection info missing for shard {}", shard0_index) - })?; - let pageserver = shard0 - .pageservers - .first() - .expect("must have at least one pageserver"); - let shard0_connstr = pageserver - .libpq_url - .clone() - .expect("no libpq_url for shard 0"); - let mut config = postgres::Config::from_str(&shard0_connstr)?; + .shard_url(ShardNumber(0), PageserverProtocol::Libpq)?; + let mut config = postgres::Config::from_str(shard0_connstr)?; // Use the storage auth token from the config file, if given. // Note: this overrides any password set in the connection string. 
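Reviewer note (not part of the patch): the `search_path=''` hardening in EXTRA_OPTIONS above (and the matching ZENITH_OPTIONS change further down), combined with the pg_catalog-qualified SQL across this diff, keeps compute_ctl sessions from resolving objects a user may have created in `public`. A minimal tokio-postgres sketch of the effect, with a placeholder DSN:

use tokio_postgres::NoTls;

async fn search_path_demo(dsn: &str) -> Result<(), tokio_postgres::Error> {
    // Placeholder DSN, e.g. "host=localhost user=cloud_admin dbname=postgres".
    let (client, connection) = tokio_postgres::connect(dsn, NoTls).await?;
    tokio::spawn(connection);

    // Same effect as the `-c search_path=''` startup option above.
    client.simple_query("SET search_path = ''").await?;

    // Resolves: explicitly schema-qualified, like the queries touched in this diff.
    client.simple_query("SELECT pg_catalog.now()").await?;

    // An unqualified user relation no longer resolves, which is why checker.rs now
    // targets public.health_check explicitly:
    // client.simple_query("INSERT INTO health_check ...").await?; // relation does not exist
    Ok(())
}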
@@ -1504,6 +1560,41 @@ impl ComputeNode { Ok(lsn) } + fn sync_safekeepers_with_retries(&self, storage_auth_token: Option) -> Result { + let max_retries = 5; + let mut attempts = 0; + loop { + let result = self.sync_safekeepers(storage_auth_token.clone()); + match &result { + Ok(_) => { + if attempts > 0 { + tracing::info!("sync_safekeepers succeeded after {attempts} retries"); + } + return result; + } + Err(e) if attempts < max_retries => { + tracing::info!( + "sync_safekeepers failed, will retry (attempt {attempts}): {e:#}" + ); + } + Err(err) => { + tracing::warn!( + "sync_safekeepers still failed after {attempts} retries, giving up: {err:?}" + ); + return result; + } + } + // sleep and retry + let backoff = exponential_backoff_duration( + attempts, + DEFAULT_BASE_BACKOFF_SECONDS, + DEFAULT_MAX_BACKOFF_SECONDS, + ); + std::thread::sleep(backoff); + attempts += 1; + } + } + /// Do all the preparations like PGDATA directory creation, configuration, /// safekeepers sync, basebackup, etc. #[instrument(skip_all)] @@ -1513,6 +1604,8 @@ impl ComputeNode { let pgdata_path = Path::new(&self.params.pgdata); let tls_config = self.tls_config(&pspec.spec); + let databricks_settings = spec.databricks_settings.as_ref(); + let postgres_port = self.params.connstr.port(); // Remove/create an empty pgdata directory and put configuration there. self.create_pgdata()?; @@ -1520,8 +1613,11 @@ impl ComputeNode { pgdata_path, &self.params, &pspec.spec, + postgres_port, self.params.internal_http_port, tls_config, + databricks_settings, + self.params.lakebase_mode, )?; // Syncing safekeepers is only safe with primary nodes: if a primary @@ -1534,7 +1630,7 @@ impl ComputeNode { lsn } else { info!("starting safekeepers syncing"); - self.sync_safekeepers(pspec.storage_auth_token.clone()) + self.sync_safekeepers_with_retries(pspec.storage_auth_token.clone()) .with_context(|| "failed to sync safekeepers")? }; info!("safekeepers synced at LSN {}", lsn); @@ -1553,8 +1649,28 @@ impl ComputeNode { self.get_basebackup(compute_state, lsn) .with_context(|| format!("failed to get basebackup@{lsn}"))?; - // Update pg_hba.conf received with basebackup. - update_pg_hba(pgdata_path, None)?; + if let Some(settings) = databricks_settings { + copy_tls_certificates( + &settings.pg_compute_tls_settings.key_file, + &settings.pg_compute_tls_settings.cert_file, + pgdata_path, + )?; + + // Update pg_hba.conf received with basebackup including additional databricks settings. + update_pg_hba(pgdata_path, Some(&settings.databricks_pg_hba))?; + update_pg_ident(pgdata_path, Some(&settings.databricks_pg_ident))?; + } else { + // Update pg_hba.conf received with basebackup. + update_pg_hba(pgdata_path, None)?; + } + + if let Some(databricks_settings) = spec.databricks_settings.as_ref() { + copy_tls_certificates( + &databricks_settings.pg_compute_tls_settings.key_file, + &databricks_settings.pg_compute_tls_settings.cert_file, + pgdata_path, + )?; + } // Place pg_dynshmem under /dev/shm. This allows us to use // 'dynamic_shared_memory_type = mmap' so that the files are placed in @@ -1595,7 +1711,7 @@ impl ComputeNode { // symlink doesn't affect anything. // // See https://github.com/neondatabase/autoscaling/issues/800 - std::fs::remove_dir(pgdata_path.join("pg_dynshmem"))?; + std::fs::remove_dir_all(pgdata_path.join("pg_dynshmem"))?; symlink("/dev/shm/", pgdata_path.join("pg_dynshmem"))?; match spec.mode { @@ -1610,6 +1726,12 @@ impl ComputeNode { /// Start and stop a postgres process to warm up the VM for startup. 
pub fn prewarm_postgres_vm_memory(&self) -> Result<()> { + if self.params.lakebase_mode { + // We are running in Hadron mode. Disabling this prewarming step for now as it could run + // into dblet port conflicts and also doesn't add much value with our current infra. + info!("Skipping postgres prewarming in Hadron mode"); + return Ok(()); + } info!("prewarming VM memory"); // Create pgdata @@ -1667,14 +1789,36 @@ impl ComputeNode { pub fn start_postgres(&self, storage_auth_token: Option) -> Result { let pgdata_path = Path::new(&self.params.pgdata); + let env_vars: Vec<(String, String)> = if self.params.lakebase_mode { + let databricks_env_vars = { + let state = self.state.lock().unwrap(); + let spec = &state.pspec.as_ref().unwrap().spec; + DatabricksEnvVars::new( + spec, + Some(&self.params.compute_id), + self.params.instance_id.clone(), + self.params.lakebase_mode, + ) + }; + + info!( + "Starting Postgres for databricks endpoint id: {}", + &databricks_env_vars.endpoint_id + ); + + let mut env_vars = databricks_env_vars.to_env_var_list(); + env_vars.extend(storage_auth_token.map(|t| ("NEON_AUTH_TOKEN".to_string(), t))); + env_vars + } else if let Some(storage_auth_token) = &storage_auth_token { + vec![("NEON_AUTH_TOKEN".to_owned(), storage_auth_token.to_owned())] + } else { + vec![] + }; + // Run postgres as a child process. let mut pg = maybe_cgexec(&self.params.pgbin) .args(["-D", &self.params.pgdata]) - .envs(if let Some(storage_auth_token) = &storage_auth_token { - vec![("NEON_AUTH_TOKEN", storage_auth_token)] - } else { - vec![] - }) + .envs(env_vars) .stderr(Stdio::piped()) .spawn() .expect("cannot start postgres process"); @@ -1781,7 +1925,7 @@ impl ComputeNode { // It doesn't matter what were the options before, here we just want // to connect and create a new superuser role. - const ZENITH_OPTIONS: &str = "-c role=zenith_admin -c default_transaction_read_only=off -c search_path=public -c statement_timeout=0"; + const ZENITH_OPTIONS: &str = "-c role=zenith_admin -c default_transaction_read_only=off -c search_path='' -c statement_timeout=0"; zenith_admin_conf.options(ZENITH_OPTIONS); let mut client = @@ -1826,7 +1970,15 @@ impl ComputeNode { /// Do initial configuration of the already started Postgres. #[instrument(skip_all)] pub fn apply_config(&self, compute_state: &ComputeState) -> Result<()> { - let conf = self.get_tokio_conn_conf(Some("compute_ctl:apply_config")); + let mut conf = self.get_tokio_conn_conf(Some("compute_ctl:apply_config")); + + if self.params.lakebase_mode { + // Set a 2-minute statement_timeout for the session applying config. The individual SQL statements + // used in apply_spec_sql() should not take long (they are just creating users and installing + // extensions). If any of them are stuck for an extended period of time it usually indicates a + // pageserver connectivity problem and we should bail out. + conf.options("-c statement_timeout=2min"); + } let conf = Arc::new(conf); let spec = Arc::new( @@ -1888,6 +2040,34 @@ impl ComputeNode { Ok::<(), anyhow::Error>(()) } + // Signal to the configurator to refresh the configuration by pulling a new spec from the HCC. + // Note that this merely triggers a notification on a condition variable the configurator thread + // waits on. The configurator thread (in configurator.rs) pulls the new spec from the HCC and + // applies it. 
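// Reviewer note (not part of this patch): callers only flip state here. For example, an
// HTTP handler for /refresh_configuration can simply call
//     compute.signal_refresh_configuration().await?;
// and return right away, while the configurator loop (see configurator.rs further down)
// performs the actual spec pull and reconfiguration.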
+ pub async fn signal_refresh_configuration(&self) -> Result<()> { + let states_allowing_configuration_refresh = [ + ComputeStatus::Running, + ComputeStatus::Failed, + ComputeStatus::RefreshConfigurationPending, + ]; + + let mut state = self.state.lock().expect("state lock poisoned"); + if states_allowing_configuration_refresh.contains(&state.status) { + state.status = ComputeStatus::RefreshConfigurationPending; + self.state_changed.notify_all(); + Ok(()) + } else if state.status == ComputeStatus::Init { + // If the compute is in Init state, we can't refresh the configuration immediately, + // but we should be able to do that soon. + Ok(()) + } else { + Err(anyhow::anyhow!( + "Cannot refresh compute configuration in state {:?}", + state.status + )) + } + } + // Wrapped this around `pg_ctl reload`, but right now we don't use // `pg_ctl` for start / stop. #[instrument(skip_all)] @@ -1949,12 +2129,16 @@ impl ComputeNode { // Write new config let pgdata_path = Path::new(&self.params.pgdata); + let postgres_port = self.params.connstr.port(); config::write_postgres_conf( pgdata_path, &self.params, &spec, + postgres_port, self.params.internal_http_port, tls_config, + spec.databricks_settings.as_ref(), + self.params.lakebase_mode, )?; self.pg_reload_conf()?; @@ -2060,6 +2244,8 @@ impl ComputeNode { // wait ComputeStatus::Init | ComputeStatus::Configuration + | ComputeStatus::RefreshConfiguration + | ComputeStatus::RefreshConfigurationPending | ComputeStatus::Empty => { state = self.state_changed.wait(state).unwrap(); } @@ -2110,7 +2296,17 @@ impl ComputeNode { pub fn check_for_core_dumps(&self) -> Result<()> { let core_dump_dir = match std::env::consts::OS { "macos" => Path::new("/cores/"), - _ => Path::new(&self.params.pgdata), + // BEGIN HADRON + // NB: Read core dump files from a fixed location outside of + // the data directory since `compute_ctl` wipes the data directory + // across container restarts. + _ => { + if self.params.lakebase_mode { + Path::new("/databricks/logs/brickstore") + } else { + Path::new(&self.params.pgdata) + } + } // END HADRON }; // Collect core dump paths if any @@ -2184,13 +2380,13 @@ impl ComputeNode { let result = client .simple_query( "SELECT - row_to_json(pg_stat_statements) + pg_catalog.row_to_json(pss) FROM - pg_stat_statements + public.pg_stat_statements pss WHERE - userid != 'cloud_admin'::regrole::oid + pss.userid != 'cloud_admin'::pg_catalog.regrole::pg_catalog.oid ORDER BY - (mean_exec_time + mean_plan_time) DESC + (pss.mean_exec_time + pss.mean_plan_time) DESC LIMIT 100", ) .await; @@ -2318,11 +2514,11 @@ LIMIT 100", // check the role grants first - to gracefully handle read-replicas. 
let select = "SELECT privilege_type - FROM pg_namespace - JOIN LATERAL (SELECT * FROM aclexplode(nspacl) AS x) acl ON true - JOIN pg_user users ON acl.grantee = users.usesysid - WHERE users.usename = $1 - AND nspname = $2"; + FROM pg_catalog.pg_namespace + JOIN LATERAL (SELECT * FROM aclexplode(nspacl) AS x) AS acl ON true + JOIN pg_catalog.pg_user users ON acl.grantee = users.usesysid + WHERE users.usename OPERATOR(pg_catalog.=) $1::pg_catalog.name + AND nspname OPERATOR(pg_catalog.=) $2::pg_catalog.name"; let rows = db_client .query(select, &[role_name, schema_name]) .await @@ -2391,8 +2587,9 @@ LIMIT 100", .await .with_context(|| format!("Failed to execute query: {query}"))?; } else { - let query = - format!("CREATE EXTENSION IF NOT EXISTS {ext_name} WITH VERSION {quoted_version}"); + let query = format!( + "CREATE EXTENSION IF NOT EXISTS {ext_name} WITH SCHEMA public VERSION {quoted_version}" + ); db_client .simple_query(&query) .await @@ -2423,7 +2620,7 @@ LIMIT 100", if let Some(libs) = spec.cluster.settings.find("shared_preload_libraries") { libs_vec = libs .split(&[',', '\'', ' ']) - .filter(|s| *s != "neon" && !s.is_empty()) + .filter(|s| *s != "neon" && *s != "databricks_auth" && !s.is_empty()) .map(str::to_string) .collect(); } @@ -2442,7 +2639,7 @@ LIMIT 100", if let Some(libs) = shared_preload_libraries_line.split("='").nth(1) { preload_libs_vec = libs .split(&[',', '\'', ' ']) - .filter(|s| *s != "neon" && !s.is_empty()) + .filter(|s| *s != "neon" && *s != "databricks_auth" && !s.is_empty()) .map(str::to_string) .collect(); } @@ -2616,6 +2813,34 @@ LIMIT 100", ); } } + + /// Set the compute spec and update related metrics. + /// This is the central place where pspec is updated. + pub fn set_spec(params: &ComputeNodeParams, state: &mut ComputeState, pspec: ParsedSpec) { + state.pspec = Some(pspec); + ComputeNode::update_attached_metric(params, state); + let _ = logger::update_ids(¶ms.instance_id, &Some(params.compute_id.clone())); + } + + pub fn update_attached_metric(params: &ComputeNodeParams, state: &mut ComputeState) { + // Update the pg_cctl_attached gauge when all identifiers are available. + if let Some(instance_id) = ¶ms.instance_id { + if let Some(pspec) = &state.pspec { + // Clear all values in the metric + COMPUTE_ATTACHED.reset(); + + // Set new metric value + COMPUTE_ATTACHED + .with_label_values(&[ + ¶ms.compute_id, + instance_id, + &pspec.tenant_id.to_string(), + &pspec.timeline_id.to_string(), + ]) + .set(1); + } + } + } } pub async fn installed_extensions(conf: tokio_postgres::Config) -> Result<()> { diff --git a/compute_tools/src/compute_prewarm.rs b/compute_tools/src/compute_prewarm.rs index 07b4a596cc..82cb28f1ac 100644 --- a/compute_tools/src/compute_prewarm.rs +++ b/compute_tools/src/compute_prewarm.rs @@ -7,7 +7,8 @@ use http::StatusCode; use reqwest::Client; use std::mem::replace; use std::sync::Arc; -use tokio::{io::AsyncReadExt, spawn}; +use tokio::{io::AsyncReadExt, select, spawn}; +use tokio_util::sync::CancellationToken; use tracing::{error, info}; #[derive(serde::Serialize, Default)] @@ -90,36 +91,37 @@ impl ComputeNode { } /// If there is a prewarm request ongoing, return `false`, `true` otherwise. 
+ /// Has a failpoint "compute-prewarm" pub fn prewarm_lfc(self: &Arc, from_endpoint: Option) -> bool { + let token: CancellationToken; { - let state = &mut self.state.lock().unwrap().lfc_prewarm_state; - if let LfcPrewarmState::Prewarming = replace(state, LfcPrewarmState::Prewarming) { + let state = &mut self.state.lock().unwrap(); + token = state.lfc_prewarm_token.clone(); + if let LfcPrewarmState::Prewarming = + replace(&mut state.lfc_prewarm_state, LfcPrewarmState::Prewarming) + { return false; } } crate::metrics::LFC_PREWARMS.inc(); - let cloned = self.clone(); + let this = self.clone(); spawn(async move { - let state = match cloned.prewarm_impl(from_endpoint).await { - Ok(true) => LfcPrewarmState::Completed, - Ok(false) => { - info!( - "skipping LFC prewarm because LFC state is not found in endpoint storage" - ); - LfcPrewarmState::Skipped - } + let prewarm_state = match this.prewarm_impl(from_endpoint, token).await { + Ok(state) => state, Err(err) => { crate::metrics::LFC_PREWARM_ERRORS.inc(); error!(%err, "could not prewarm LFC"); - - LfcPrewarmState::Failed { - error: err.to_string(), - } + let error = format!("{err:#}"); + LfcPrewarmState::Failed { error } } }; - cloned.state.lock().unwrap().lfc_prewarm_state = state; + let state = &mut this.state.lock().unwrap(); + if let LfcPrewarmState::Cancelled = prewarm_state { + state.lfc_prewarm_token = CancellationToken::new(); + } + state.lfc_prewarm_state = prewarm_state; }); true } @@ -132,43 +134,70 @@ impl ComputeNode { /// Request LFC state from endpoint storage and load corresponding pages into Postgres. /// Returns a result with `false` if the LFC state is not found in endpoint storage. - async fn prewarm_impl(&self, from_endpoint: Option) -> Result { - let EndpointStoragePair { url, token } = self.endpoint_storage_pair(from_endpoint)?; + async fn prewarm_impl( + &self, + from_endpoint: Option, + token: CancellationToken, + ) -> Result { + let EndpointStoragePair { + url, + token: storage_token, + } = self.endpoint_storage_pair(from_endpoint)?; + + #[cfg(feature = "testing")] + fail::fail_point!("compute-prewarm", |_| bail!("compute-prewarm failpoint")); info!(%url, "requesting LFC state from endpoint storage"); - let request = Client::new().get(&url).bearer_auth(token); - let res = request.send().await.context("querying endpoint storage")?; - let status = res.status(); - match status { + let request = Client::new().get(&url).bearer_auth(storage_token); + let response = select! { + _ = token.cancelled() => return Ok(LfcPrewarmState::Cancelled), + response = request.send() => response + } + .context("querying endpoint storage")?; + + match response.status() { StatusCode::OK => (), - StatusCode::NOT_FOUND => { - return Ok(false); - } - _ => bail!("{status} querying endpoint storage"), + StatusCode::NOT_FOUND => return Ok(LfcPrewarmState::Skipped), + status => bail!("{status} querying endpoint storage"), } let mut uncompressed = Vec::new(); - let lfc_state = res - .bytes() - .await - .context("getting request body from endpoint storage")?; - ZstdDecoder::new(lfc_state.iter().as_slice()) - .read_to_end(&mut uncompressed) - .await - .context("decoding LFC state")?; + let lfc_state = select! { + _ = token.cancelled() => return Ok(LfcPrewarmState::Cancelled), + lfc_state = response.bytes() => lfc_state + } + .context("getting request body from endpoint storage")?; + + let mut decoder = ZstdDecoder::new(lfc_state.iter().as_slice()); + select! 
{ + _ = token.cancelled() => return Ok(LfcPrewarmState::Cancelled), + read = decoder.read_to_end(&mut uncompressed) => read + } + .context("decoding LFC state")?; + let uncompressed_len = uncompressed.len(); + info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}"); - info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}, loading into Postgres"); - - ComputeNode::get_maintenance_client(&self.tokio_conn_conf) + // Client connection and prewarm info querying are fast and therefore don't need + // cancellation + let client = ComputeNode::get_maintenance_client(&self.tokio_conn_conf) .await - .context("connecting to postgres")? - .query_one("select neon.prewarm_local_cache($1)", &[&uncompressed]) - .await - .context("loading LFC state into postgres") - .map(|_| ())?; + .context("connecting to postgres")?; + let pg_token = client.cancel_token(); - Ok(true) + let params: Vec<&(dyn postgres_types::ToSql + Sync)> = vec![&uncompressed]; + select! { + res = client.query_one("select neon.prewarm_local_cache($1)", ¶ms) => res, + _ = token.cancelled() => { + pg_token.cancel_query(postgres::NoTls).await + .context("cancelling neon.prewarm_local_cache()")?; + return Ok(LfcPrewarmState::Cancelled) + } + } + .context("loading LFC state into postgres") + .map(|_| ())?; + + Ok(LfcPrewarmState::Completed) } /// If offload request is ongoing, return false, true otherwise @@ -196,33 +225,39 @@ impl ComputeNode { async fn offload_lfc_with_state_update(&self) { crate::metrics::LFC_OFFLOADS.inc(); - - let Err(err) = self.offload_lfc_impl().await else { - self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Completed; - return; + let state = match self.offload_lfc_impl().await { + Ok(state) => state, + Err(err) => { + crate::metrics::LFC_OFFLOAD_ERRORS.inc(); + error!(%err, "could not offload LFC"); + let error = format!("{err:#}"); + LfcOffloadState::Failed { error } + } }; - crate::metrics::LFC_OFFLOAD_ERRORS.inc(); - error!(%err, "could not offload LFC state to endpoint storage"); - self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Failed { - error: err.to_string(), - }; + self.state.lock().unwrap().lfc_offload_state = state; } - async fn offload_lfc_impl(&self) -> Result<()> { + async fn offload_lfc_impl(&self) -> Result { let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?; info!(%url, "requesting LFC state from Postgres"); - let mut compressed = Vec::new(); - ComputeNode::get_maintenance_client(&self.tokio_conn_conf) + let row = ComputeNode::get_maintenance_client(&self.tokio_conn_conf) .await .context("connecting to postgres")? .query_one("select neon.get_local_cache_state()", &[]) .await - .context("querying LFC state")? - .try_get::(0) - .context("deserializing LFC state") - .map(ZstdEncoder::new)? 
+ .context("querying LFC state")?; + let state = row + .try_get::>(0) + .context("deserializing LFC state")?; + let Some(state) = state else { + info!(%url, "empty LFC state, not exporting"); + return Ok(LfcOffloadState::Skipped); + }; + + let mut compressed = Vec::new(); + ZstdEncoder::new(state) .read_to_end(&mut compressed) .await .context("compressing LFC state")?; @@ -232,7 +267,7 @@ impl ComputeNode { let request = Client::new().put(url).bearer_auth(token).body(compressed); match request.send().await { - Ok(res) if res.status() == StatusCode::OK => Ok(()), + Ok(res) if res.status() == StatusCode::OK => Ok(LfcOffloadState::Completed), Ok(res) => bail!( "Request to endpoint storage failed with status: {}", res.status() @@ -240,4 +275,8 @@ impl ComputeNode { Err(err) => Err(err).context("writing to endpoint storage"), } } + + pub fn cancel_prewarm(self: &Arc) { + self.state.lock().unwrap().lfc_prewarm_token.cancel(); + } } diff --git a/compute_tools/src/compute_promote.rs b/compute_tools/src/compute_promote.rs index 42256faa22..29195b60e9 100644 --- a/compute_tools/src/compute_promote.rs +++ b/compute_tools/src/compute_promote.rs @@ -1,11 +1,12 @@ use crate::compute::ComputeNode; use anyhow::{Context, Result, bail}; -use compute_api::{ - responses::{LfcPrewarmState, PromoteState, SafekeepersLsn}, - spec::ComputeMode, -}; +use compute_api::responses::{LfcPrewarmState, PromoteConfig, PromoteState}; +use compute_api::spec::ComputeMode; +use itertools::Itertools; +use std::collections::HashMap; use std::{sync::Arc, time::Duration}; use tokio::time::sleep; +use tracing::info; use utils::lsn::Lsn; impl ComputeNode { @@ -13,21 +14,22 @@ impl ComputeNode { /// and http client disconnects, this does not stop promotion, and subsequent /// calls block until promote finishes. /// Called by control plane on secondary after primary endpoint is terminated - pub async fn promote(self: &Arc, safekeepers_lsn: SafekeepersLsn) -> PromoteState { + /// Has a failpoint "compute-promotion" + pub async fn promote(self: &Arc, cfg: PromoteConfig) -> PromoteState { let cloned = self.clone(); + let promote_fn = async move || { + let Err(err) = cloned.promote_impl(cfg).await else { + return PromoteState::Completed; + }; + tracing::error!(%err, "promoting"); + PromoteState::Failed { + error: format!("{err:#}"), + } + }; + let start_promotion = || { let (tx, rx) = tokio::sync::watch::channel(PromoteState::NotPromoted); - tokio::spawn(async move { - tx.send(match cloned.promote_impl(safekeepers_lsn).await { - Ok(_) => PromoteState::Completed, - Err(err) => { - tracing::error!(%err, "promoting"); - PromoteState::Failed { - error: err.to_string(), - } - } - }) - }); + tokio::spawn(async move { tx.send(promote_fn().await) }); rx }; @@ -47,9 +49,7 @@ impl ComputeNode { task.borrow().clone() } - // Why do we have to supply safekeepers? 
- // For secondary we use primary_connection_conninfo so safekeepers field is empty - async fn promote_impl(&self, safekeepers_lsn: SafekeepersLsn) -> Result<()> { + async fn promote_impl(&self, mut cfg: PromoteConfig) -> Result<()> { { let state = self.state.lock().unwrap(); let mode = &state.pspec.as_ref().unwrap().spec.mode; @@ -73,12 +73,12 @@ impl ComputeNode { .await .context("connecting to postgres")?; - let primary_lsn = safekeepers_lsn.wal_flush_lsn; + let primary_lsn = cfg.wal_flush_lsn; let mut last_wal_replay_lsn: Lsn = Lsn::INVALID; const RETRIES: i32 = 20; for i in 0..=RETRIES { let row = client - .query_one("SELECT pg_last_wal_replay_lsn()", &[]) + .query_one("SELECT pg_catalog.pg_last_wal_replay_lsn()", &[]) .await .context("getting last replay lsn")?; let lsn: u64 = row.get::(0).into(); @@ -86,7 +86,7 @@ impl ComputeNode { if last_wal_replay_lsn >= primary_lsn { break; } - tracing::info!("Try {i}, replica lsn {last_wal_replay_lsn}, primary lsn {primary_lsn}"); + info!("Try {i}, replica lsn {last_wal_replay_lsn}, primary lsn {primary_lsn}"); sleep(Duration::from_secs(1)).await; } if last_wal_replay_lsn < primary_lsn { @@ -96,18 +96,24 @@ impl ComputeNode { // using $1 doesn't work with ALTER SYSTEM SET let safekeepers_sql = format!( "ALTER SYSTEM SET neon.safekeepers='{}'", - safekeepers_lsn.safekeepers + cfg.spec.safekeeper_connstrings.join(",") ); client .query(&safekeepers_sql, &[]) .await .context("setting safekeepers")?; client - .query("SELECT pg_reload_conf()", &[]) + .query("SELECT pg_catalog.pg_reload_conf()", &[]) .await .context("reloading postgres config")?; + + #[cfg(feature = "testing")] + fail::fail_point!("compute-promotion", |_| { + bail!("promotion configured to fail because of a failpoint") + }); + let row = client - .query_one("SELECT * FROM pg_promote()", &[]) + .query_one("SELECT * FROM pg_catalog.pg_promote()", &[]) .await .context("pg_promote")?; if !row.get::(0) { @@ -125,8 +131,36 @@ impl ComputeNode { bail!("replica in read only mode after promotion"); } - let mut state = self.state.lock().unwrap(); - state.pspec.as_mut().unwrap().spec.mode = ComputeMode::Primary; - Ok(()) + { + let mut state = self.state.lock().unwrap(); + let spec = &mut state.pspec.as_mut().unwrap().spec; + spec.mode = ComputeMode::Primary; + let new_conf = cfg.spec.cluster.postgresql_conf.as_mut().unwrap(); + let existing_conf = spec.cluster.postgresql_conf.as_ref().unwrap(); + Self::merge_spec(new_conf, existing_conf); + } + info!("applied new spec, reconfiguring as primary"); + self.reconfigure() + } + + /// Merge old and new Postgres conf specs to apply on secondary. 
+    /// Change new spec's port and safekeepers since they are supplied
+    /// differently
+    fn merge_spec(new_conf: &mut String, existing_conf: &str) {
+        let mut new_conf_set: HashMap<&str, &str> = new_conf
+            .split_terminator('\n')
+            .map(|e| e.split_once("=").expect("invalid item"))
+            .collect();
+        new_conf_set.remove("neon.safekeepers");
+
+        let existing_conf_set: HashMap<&str, &str> = existing_conf
+            .split_terminator('\n')
+            .map(|e| e.split_once("=").expect("invalid item"))
+            .collect();
+        new_conf_set.insert("port", existing_conf_set["port"]);
+        *new_conf = new_conf_set
+            .iter()
+            .map(|(k, v)| format!("{k}={v}"))
+            .join("\n");
+    }
 }
diff --git a/compute_tools/src/config.rs b/compute_tools/src/config.rs
index 8821611f0c..73768938c0 100644
--- a/compute_tools/src/config.rs
+++ b/compute_tools/src/config.rs
@@ -7,11 +7,14 @@ use std::io::prelude::*;
 use std::path::Path;
 
 use compute_api::responses::TlsConfig;
-use compute_api::spec::{ComputeAudit, ComputeMode, ComputeSpec, GenericOption};
+use compute_api::spec::{
+    ComputeAudit, ComputeMode, ComputeSpec, DatabricksSettings, GenericOption, PageserverProtocol,
+};
 
 use crate::compute::ComputeNodeParams;
 use crate::pg_helpers::{
-    GenericOptionExt, GenericOptionsSearch, PgOptionsSerialize, escape_conf_value,
+    DatabricksSettingsExt as _, GenericOptionExt, GenericOptionsSearch, PgOptionsSerialize,
+    escape_conf_value,
 };
 use crate::tls::{self, SERVER_CRT, SERVER_KEY};
 
@@ -42,12 +45,16 @@ pub fn line_in_file(path: &Path, line: &str) -> Result {
 }
 
 /// Create or completely rewrite configuration file specified by `path`
+#[allow(clippy::too_many_arguments)]
 pub fn write_postgres_conf(
     pgdata_path: &Path,
     params: &ComputeNodeParams,
     spec: &ComputeSpec,
+    postgres_port: Option<u16>,
     extension_server_port: u16,
     tls_config: &Option<TlsConfig>,
+    databricks_settings: Option<&DatabricksSettings>,
+    lakebase_mode: bool,
 ) -> Result<()> {
     let path = pgdata_path.join("postgresql.conf");
     // File::create() destroys the file content if it exists.
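Reviewer note (not part of the patch): a worked example of `merge_spec` from compute_promote.rs above; the setting names and values are illustrative only.

// new_conf (from the promote spec)        existing_conf (running secondary)
//   port=5432                               port=55433
//   neon.safekeepers=sk-1:5454              shared_buffers=128MB
//   shared_buffers=256MB
//
// After merge_spec(&mut new_conf, existing_conf):
//   new_conf == "port=55433\nshared_buffers=256MB"   (line order unspecified: HashMap)
// neon.safekeepers is dropped because it was already applied via ALTER SYSTEM earlier in
// promote_impl, and the secondary keeps its own port.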
@@ -62,11 +69,20 @@ pub fn write_postgres_conf( writeln!(file, "# Neon storage settings")?; writeln!(file)?; if let Some(conninfo) = &spec.pageserver_connection_info { + match conninfo.prefer_protocol { + PageserverProtocol::Libpq => { + writeln!(file, "neon.use_communicator_worker=false")?; + } + PageserverProtocol::Grpc => { + writeln!(file, "neon.use_communicator_worker=true")?; + } + } + // Stripe size GUC should be defined prior to connection string if let Some(stripe_size) = conninfo.stripe_size { writeln!( file, - "# from compute spec's pageserver_conninfo.stripe_size field" + "# from compute spec's pageserver_connection_info.stripe_size field" )?; writeln!(file, "neon.stripe_size={stripe_size}")?; } @@ -117,7 +133,7 @@ pub fn write_postgres_conf( if let Some(libpq_urls) = libpq_urls { writeln!( file, - "# derived from compute spec's pageserver_conninfo field" + "# derived from compute spec's pageserver_connection_info field" )?; writeln!( file, @@ -141,12 +157,13 @@ pub fn write_postgres_conf( writeln!(file, "# no neon.pageserver_grpc_urls")?; } } else { + writeln!(file, "neon.use_communicator_worker=false")?; + // Stripe size GUC should be defined prior to connection string if let Some(stripe_size) = spec.shard_stripe_size { writeln!(file, "# from compute spec's shard_stripe_size field")?; writeln!(file, "neon.stripe_size={stripe_size}")?; } - if let Some(s) = &spec.pageserver_connstring { writeln!(file, "# from compute spec's pageserver_connstring field")?; writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(s))?; @@ -373,6 +390,24 @@ pub fn write_postgres_conf( writeln!(file, "log_destination='stderr,syslog'")?; } + if lakebase_mode { + // Explicitly set the port based on the connstr, overriding any previous port setting. + // Note: It is important that we don't specify a different port again after this. + let port = postgres_port.expect("port must be present in connstr"); + writeln!(file, "port = {port}")?; + + // This is databricks specific settings. + // This should be at the end of the file but before `compute_ctl_temp_override.conf` below + // so that it can override any settings above. + // `compute_ctl_temp_override.conf` is intended to override any settings above during specific operations. + // To prevent potential breakage in the future, we keep it above `compute_ctl_temp_override.conf`. + writeln!(file, "# Databricks settings start")?; + if let Some(settings) = databricks_settings { + writeln!(file, "{}", settings.as_pg_settings())?; + } + writeln!(file, "# Databricks settings end")?; + } + // This is essential to keep this line at the end of the file, // because it is intended to override any settings above. 
writeln!(file, "include_if_exists = 'compute_ctl_temp_override.conf'")?; diff --git a/compute_tools/src/configurator.rs b/compute_tools/src/configurator.rs index d97bd37285..79eb80c4a0 100644 --- a/compute_tools/src/configurator.rs +++ b/compute_tools/src/configurator.rs @@ -1,23 +1,40 @@ -use std::sync::Arc; +use std::fs::File; use std::thread; +use std::{path::Path, sync::Arc}; -use compute_api::responses::ComputeStatus; +use anyhow::Result; +use compute_api::responses::{ComputeConfig, ComputeStatus}; use tracing::{error, info, instrument}; -use crate::compute::ComputeNode; +use crate::compute::{ComputeNode, ParsedSpec}; +use crate::spec::get_config_from_control_plane; #[instrument(skip_all)] fn configurator_main_loop(compute: &Arc) { info!("waiting for reconfiguration requests"); loop { let mut state = compute.state.lock().unwrap(); + /* BEGIN_HADRON */ + // RefreshConfiguration should only be used inside the loop + assert_ne!(state.status, ComputeStatus::RefreshConfiguration); + /* END_HADRON */ - // We have to re-check the status after re-acquiring the lock because it could be that - // the status has changed while we were waiting for the lock, and we might not need to - // wait on the condition variable. Otherwise, we might end up in some soft-/deadlock, i.e. - // we are waiting for a condition variable that will never be signaled. - if state.status != ComputeStatus::ConfigurationPending { - state = compute.state_changed.wait(state).unwrap(); + if compute.params.lakebase_mode { + while state.status != ComputeStatus::ConfigurationPending + && state.status != ComputeStatus::RefreshConfigurationPending + && state.status != ComputeStatus::Failed + { + info!("configurator: compute status: {:?}, sleeping", state.status); + state = compute.state_changed.wait(state).unwrap(); + } + } else { + // We have to re-check the status after re-acquiring the lock because it could be that + // the status has changed while we were waiting for the lock, and we might not need to + // wait on the condition variable. Otherwise, we might end up in some soft-/deadlock, i.e. + // we are waiting for a condition variable that will never be signaled. + if state.status != ComputeStatus::ConfigurationPending { + state = compute.state_changed.wait(state).unwrap(); + } } // Re-check the status after waking up @@ -37,6 +54,136 @@ fn configurator_main_loop(compute: &Arc) { // XXX: used to test that API is blocking // std::thread::sleep(std::time::Duration::from_millis(10000)); + compute.set_status(new_status); + } else if state.status == ComputeStatus::RefreshConfigurationPending { + info!( + "compute node suspects its configuration is out of date, now refreshing configuration" + ); + state.set_status(ComputeStatus::RefreshConfiguration, &compute.state_changed); + // Drop the lock guard here to avoid holding the lock while downloading config from the control plane / HCC. + // This is the only thread that can move compute_ctl out of the `RefreshConfiguration` state, so it + // is safe to drop the lock like this. + drop(state); + + let get_config_result: anyhow::Result = + if let Some(config_path) = &compute.params.config_path_test_only { + // This path is only to make testing easier. In production we always get the config from the HCC. 
+ info!( + "reloading config.json from path: {}", + config_path.to_string_lossy() + ); + let path = Path::new(config_path); + if let Ok(file) = File::open(path) { + match serde_json::from_reader::(file) { + Ok(config) => Ok(config), + Err(e) => { + error!("could not parse config file: {}", e); + Err(anyhow::anyhow!("could not parse config file: {}", e)) + } + } + } else { + error!( + "could not open config file at path: {:?}", + config_path.to_string_lossy() + ); + Err(anyhow::anyhow!( + "could not open config file at path: {}", + config_path.to_string_lossy() + )) + } + } else if let Some(control_plane_uri) = &compute.params.control_plane_uri { + get_config_from_control_plane(control_plane_uri, &compute.params.compute_id) + } else { + Err(anyhow::anyhow!("config_path_test_only is not set")) + }; + + // Parse any received ComputeSpec and transpose the result into a Result>. + let parsed_spec_result: Result> = + get_config_result.and_then(|config| { + if let Some(spec) = config.spec { + if let Ok(pspec) = ParsedSpec::try_from(spec) { + Ok(Some(pspec)) + } else { + Err(anyhow::anyhow!("could not parse spec")) + } + } else { + Ok(None) + } + }); + + let new_status: ComputeStatus; + match parsed_spec_result { + // Control plane (HCM) returned a spec and we were able to parse it. + Ok(Some(pspec)) => { + { + let mut state = compute.state.lock().unwrap(); + // Defensive programming to make sure this thread is indeed the only one that can move the compute + // node out of the `RefreshConfiguration` state. Would be nice if we can encode this invariant + // into the type system. + assert_eq!(state.status, ComputeStatus::RefreshConfiguration); + + if state + .pspec + .as_ref() + .map(|ps| ps.pageserver_conninfo.clone()) + == Some(pspec.pageserver_conninfo.clone()) + { + info!( + "Refresh configuration: Retrieved spec is the same as the current spec. Waiting for control plane to update the spec before attempting reconfiguration." + ); + state.status = ComputeStatus::Running; + compute.state_changed.notify_all(); + drop(state); + std::thread::sleep(std::time::Duration::from_secs(5)); + continue; + } + // state.pspec is consumed by compute.reconfigure() below. Note that compute.reconfigure() will acquire + // the compute.state lock again so we need to have the lock guard go out of scope here. We could add a + // "locked" variant of compute.reconfigure() that takes the lock guard as an argument to make this cleaner, + // but it's not worth forking the codebase too much for this minor point alone right now. + state.pspec = Some(pspec); + } + match compute.reconfigure() { + Ok(_) => { + info!("Refresh configuration: compute node configured"); + new_status = ComputeStatus::Running; + } + Err(e) => { + error!( + "Refresh configuration: could not configure compute node: {}", + e + ); + // Set the compute node back to the `RefreshConfigurationPending` state if the configuration + // was not successful. It should be okay to treat this situation the same as if the loop + // hasn't executed yet as long as the detection side keeps notifying. + new_status = ComputeStatus::RefreshConfigurationPending; + } + } + } + // Control plane (HCM)'s response does not contain a spec. This is the "Empty" attachment case. + Ok(None) => { + info!( + "Compute Manager signaled that this compute is no longer attached to any storage. Exiting." + ); + // We just immediately terminate the whole compute_ctl in this case. 
It's not necessary to attempt a + // clean shutdown as Postgres is probably not responding anyway (which is why we are in this refresh + // configuration state). + std::process::exit(1); + } + // Various error cases: + // - The request to the control plane (HCM) either failed or returned a malformed spec. + // - compute_ctl itself is configured incorrectly (e.g., compute_id is not set). + Err(e) => { + error!( + "Refresh configuration: error getting a parsed spec: {:?}", + e + ); + new_status = ComputeStatus::RefreshConfigurationPending; + // We may be dealing with an overloaded HCM if we end up in this path. Back off for 5 seconds before + // retrying to avoid hammering the HCM. + std::thread::sleep(std::time::Duration::from_secs(5)); + } + } compute.set_status(new_status); } else if state.status == ComputeStatus::Failed { info!("compute node is now in Failed state, exiting"); diff --git a/compute_tools/src/hadron_metrics.rs b/compute_tools/src/hadron_metrics.rs new file mode 100644 index 0000000000..17c4e82622 --- /dev/null +++ b/compute_tools/src/hadron_metrics.rs @@ -0,0 +1,60 @@ +use metrics::{ + IntCounter, IntGaugeVec, core::Collector, proto::MetricFamily, register_int_counter, + register_int_gauge_vec, +}; +use once_cell::sync::Lazy; + +// Counter keeping track of the number of PageStream request errors reported by Postgres. +// An error is registered every time Postgres calls compute_ctl's /refresh_configuration API. +// Postgres will invoke this API if it detects trouble with PageStream requests (get_page@lsn, +// get_base_backup, etc.) it sends to any pageserver. An increase in this counter value typically +// indicates Postgres downtime, as PageStream requests are critical for Postgres to function. +pub static POSTGRES_PAGESTREAM_REQUEST_ERRORS: Lazy = Lazy::new(|| { + register_int_counter!( + "pg_cctl_pagestream_request_errors_total", + "Number of PageStream request errors reported by the postgres process" + ) + .expect("failed to define a metric") +}); + +// Counter keeping track of the number of compute configuration errors due to Postgres statement +// timeouts. An error is registered every time `ComputeNode::reconfigure()` fails due to Postgres +// error code 57014 (query cancelled). This statement timeout typically occurs when postgres is +// stuck in a problematic retry loop when the PS is rejecting its connection requests (usually due +// to PG pointing at the wrong PS). We should investigate the root cause when this counter value +// increases by checking PG and PS logs. +pub static COMPUTE_CONFIGURE_STATEMENT_TIMEOUT_ERRORS: Lazy = Lazy::new(|| { + register_int_counter!( + "pg_cctl_configure_statement_timeout_errors_total", + "Number of compute configuration errors due to Postgres statement timeouts."
+ ) + .expect("failed to define a metric") +}); + +pub static COMPUTE_ATTACHED: Lazy = Lazy::new(|| { + register_int_gauge_vec!( + "pg_cctl_attached", + "Compute node attached status (1 if attached)", + &[ + "pg_compute_id", + "pg_instance_id", + "tenant_id", + "timeline_id" + ] + ) + .expect("failed to define a metric") +}); + +pub fn collect() -> Vec { + let mut metrics = Vec::new(); + metrics.extend(POSTGRES_PAGESTREAM_REQUEST_ERRORS.collect()); + metrics.extend(COMPUTE_CONFIGURE_STATEMENT_TIMEOUT_ERRORS.collect()); + metrics.extend(COMPUTE_ATTACHED.collect()); + metrics +} + +pub fn initialize_metrics() { + Lazy::force(&POSTGRES_PAGESTREAM_REQUEST_ERRORS); + Lazy::force(&COMPUTE_CONFIGURE_STATEMENT_TIMEOUT_ERRORS); + Lazy::force(&COMPUTE_ATTACHED); +} diff --git a/compute_tools/src/http/middleware/authorize.rs b/compute_tools/src/http/middleware/authorize.rs index a82f46e062..407833bb0e 100644 --- a/compute_tools/src/http/middleware/authorize.rs +++ b/compute_tools/src/http/middleware/authorize.rs @@ -16,13 +16,29 @@ use crate::http::JsonResponse; #[derive(Clone, Debug)] pub(in crate::http) struct Authorize { compute_id: String, + // BEGIN HADRON + // Hadron instance ID. Only set if it's a Lakebase V1 a.k.a. Hadron instance. + instance_id: Option, + // END HADRON jwks: JwkSet, validation: Validation, } impl Authorize { - pub fn new(compute_id: String, jwks: JwkSet) -> Self { + pub fn new(compute_id: String, instance_id: Option, jwks: JwkSet) -> Self { let mut validation = Validation::new(Algorithm::EdDSA); + + // BEGIN HADRON + let use_rsa = jwks.keys.iter().any(|jwk| { + jwk.common + .key_algorithm + .is_some_and(|alg| alg == jsonwebtoken::jwk::KeyAlgorithm::RS256) + }); + if use_rsa { + validation = Validation::new(Algorithm::RS256); + } + // END HADRON + validation.validate_exp = true; // Unused by the control plane validation.validate_nbf = false; @@ -34,6 +50,7 @@ impl Authorize { Self { compute_id, + instance_id, jwks, validation, } @@ -47,10 +64,20 @@ impl AsyncAuthorizeRequest for Authorize { fn authorize(&mut self, mut request: Request) -> Self::Future { let compute_id = self.compute_id.clone(); + let is_hadron_instance = self.instance_id.is_some(); let jwks = self.jwks.clone(); let validation = self.validation.clone(); Box::pin(async move { + // BEGIN HADRON + // In Hadron deployments the "external" HTTP endpoint on compute_ctl can only be + // accessed by trusted components (enforced by dblet network policy), so we can bypass + // all auth here. 
+ if is_hadron_instance { + return Ok(request); + } + // END HADRON + let TypedHeader(Authorization(bearer)) = request .extract_parts::>>() .await diff --git a/compute_tools/src/http/openapi_spec.yaml b/compute_tools/src/http/openapi_spec.yaml index 3cf5ea7c51..27e610a87d 100644 --- a/compute_tools/src/http/openapi_spec.yaml +++ b/compute_tools/src/http/openapi_spec.yaml @@ -96,7 +96,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/SafekeepersLsn" + $ref: "#/components/schemas/ComputeSchemaWithLsn" responses: 200: description: Promote succeeded or wasn't started @@ -139,6 +139,15 @@ paths: application/json: schema: $ref: "#/components/schemas/LfcPrewarmState" + delete: + tags: + - Prewarm + summary: Cancel ongoing LFC prewarm + description: "" + operationId: cancelLfcPrewarm + responses: + 202: + description: Prewarm cancelled /lfc/offload: post: @@ -297,14 +306,7 @@ paths: content: application/json: schema: - type: object - required: - - spec - properties: - spec: - # XXX: I don't want to explain current spec in the OpenAPI format, - # as it could be changed really soon. Consider doing it later. - type: object + $ref: "#/components/schemas/ComputeSchema" responses: 200: description: Compute configuration finished. @@ -591,18 +593,25 @@ components: type: string example: "1.0.0" - SafekeepersLsn: + ComputeSchema: type: object required: - - safekeepers + - spec + properties: + spec: + type: object + ComputeSchemaWithLsn: + type: object + required: + - spec - wal_flush_lsn properties: - safekeepers: - description: Primary replica safekeepers - type: string + spec: + $ref: "#/components/schemas/ComputeState" wal_flush_lsn: - description: Primary last WAL flush LSN type: string + description: "last WAL flush LSN" + example: "0/028F10D8" LfcPrewarmState: type: object @@ -636,7 +645,7 @@ components: properties: status: description: LFC offload status - enum: [not_offloaded, offloading, completed, failed] + enum: [not_offloaded, offloading, completed, skipped, failed] type: string error: description: LFC offload error, if any diff --git a/compute_tools/src/http/routes/configure.rs b/compute_tools/src/http/routes/configure.rs index b7325d283f..943ff45357 100644 --- a/compute_tools/src/http/routes/configure.rs +++ b/compute_tools/src/http/routes/configure.rs @@ -43,7 +43,12 @@ pub(in crate::http) async fn configure( // configure request for tracing purposes. state.startup_span = Some(tracing::Span::current()); - state.pspec = Some(pspec); + if compute.params.lakebase_mode { + ComputeNode::set_spec(&compute.params, &mut state, pspec); + } else { + state.pspec = Some(pspec); + } + state.set_status(ComputeStatus::ConfigurationPending, &compute.state_changed); drop(state); } diff --git a/compute_tools/src/http/routes/hadron_liveness_probe.rs b/compute_tools/src/http/routes/hadron_liveness_probe.rs new file mode 100644 index 0000000000..4f66b6b139 --- /dev/null +++ b/compute_tools/src/http/routes/hadron_liveness_probe.rs @@ -0,0 +1,34 @@ +use crate::pg_isready::pg_isready; +use crate::{compute::ComputeNode, http::JsonResponse}; +use axum::{extract::State, http::StatusCode, response::Response}; +use std::sync::Arc; + +/// NOTE: NOT ENABLED YET +/// Detect if the compute is alive. +/// Called by the liveness probe of the compute container. 
+pub(in crate::http) async fn hadron_liveness_probe( + State(compute): State>, +) -> Response { + let port = match compute.params.connstr.port() { + Some(port) => port, + None => { + return JsonResponse::error( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to get the port from the connection string", + ); + } + }; + match pg_isready(&compute.params.pg_isready_bin, port) { + Ok(_) => { + // The connection is successful, so the compute is alive. + // Return a 200 OK response. + JsonResponse::success(StatusCode::OK, "ok") + } + Err(e) => { + tracing::error!("Hadron liveness probe failed: {}", e); + // The connection failed, so the compute is not alive. + // Return a 500 Internal Server Error response. + JsonResponse::error(StatusCode::INTERNAL_SERVER_ERROR, e) + } + } +} diff --git a/compute_tools/src/http/routes/lfc.rs b/compute_tools/src/http/routes/lfc.rs index e98bd781a2..7483198723 100644 --- a/compute_tools/src/http/routes/lfc.rs +++ b/compute_tools/src/http/routes/lfc.rs @@ -46,3 +46,8 @@ pub(in crate::http) async fn offload(compute: Compute) -> Response { ) } } + +pub(in crate::http) async fn cancel_prewarm(compute: Compute) -> StatusCode { + compute.cancel_prewarm(); + StatusCode::ACCEPTED +} diff --git a/compute_tools/src/http/routes/metrics.rs b/compute_tools/src/http/routes/metrics.rs index 96b464fd12..8406746327 100644 --- a/compute_tools/src/http/routes/metrics.rs +++ b/compute_tools/src/http/routes/metrics.rs @@ -13,6 +13,7 @@ use metrics::{Encoder, TextEncoder}; use crate::communicator_socket_client::connect_communicator_socket; use crate::compute::ComputeNode; +use crate::hadron_metrics; use crate::http::JsonResponse; use crate::metrics::collect; @@ -21,11 +22,18 @@ pub(in crate::http) async fn get_metrics() -> Response { // When we call TextEncoder::encode() below, it will immediately return an // error if a metric family has no metrics, so we need to preemptively // filter out metric families with no metrics. - let metrics = collect() + let mut metrics = collect() .into_iter() .filter(|m| !m.get_metric().is_empty()) .collect::>(); + // Add Hadron metrics. 
+ let hadron_metrics: Vec = hadron_metrics::collect() + .into_iter() + .filter(|m| !m.get_metric().is_empty()) + .collect(); + metrics.extend(hadron_metrics); + let encoder = TextEncoder::new(); let mut buffer = vec![]; diff --git a/compute_tools/src/http/routes/mod.rs b/compute_tools/src/http/routes/mod.rs index dd71f663eb..c0f68701c6 100644 --- a/compute_tools/src/http/routes/mod.rs +++ b/compute_tools/src/http/routes/mod.rs @@ -10,11 +10,13 @@ pub(in crate::http) mod extension_server; pub(in crate::http) mod extensions; pub(in crate::http) mod failpoints; pub(in crate::http) mod grants; +pub(in crate::http) mod hadron_liveness_probe; pub(in crate::http) mod insights; pub(in crate::http) mod lfc; pub(in crate::http) mod metrics; pub(in crate::http) mod metrics_json; pub(in crate::http) mod promote; +pub(in crate::http) mod refresh_configuration; pub(in crate::http) mod status; pub(in crate::http) mod terminate; diff --git a/compute_tools/src/http/routes/promote.rs b/compute_tools/src/http/routes/promote.rs index bc5f93b4da..7ca3464b63 100644 --- a/compute_tools/src/http/routes/promote.rs +++ b/compute_tools/src/http/routes/promote.rs @@ -1,14 +1,14 @@ use crate::http::JsonResponse; -use axum::Form; +use axum::extract::Json; use http::StatusCode; pub(in crate::http) async fn promote( compute: axum::extract::State>, - Form(safekeepers_lsn): Form, + Json(cfg): Json, ) -> axum::response::Response { - let state = compute.promote(safekeepers_lsn).await; - if let compute_api::responses::PromoteState::Failed { error } = state { - return JsonResponse::error(StatusCode::INTERNAL_SERVER_ERROR, error); + let state = compute.promote(cfg).await; + if let compute_api::responses::PromoteState::Failed { error: _ } = state { + return JsonResponse::create_response(StatusCode::INTERNAL_SERVER_ERROR, state); } JsonResponse::success(StatusCode::OK, state) } diff --git a/compute_tools/src/http/routes/refresh_configuration.rs b/compute_tools/src/http/routes/refresh_configuration.rs new file mode 100644 index 0000000000..9b2f95ca5a --- /dev/null +++ b/compute_tools/src/http/routes/refresh_configuration.rs @@ -0,0 +1,29 @@ +// This file is added by Hadron + +use std::sync::Arc; + +use axum::{ + extract::State, + response::{IntoResponse, Response}, +}; +use http::StatusCode; + +use crate::compute::ComputeNode; +use crate::hadron_metrics::POSTGRES_PAGESTREAM_REQUEST_ERRORS; +use crate::http::JsonResponse; + +/// The /refresh_configuration POST method is used to nudge compute_ctl to pull a new spec +/// from the HCC and attempt to reconfigure Postgres with the new spec. The method does not wait +/// for the reconfiguration to complete. Rather, it simply delivers a signal that will cause +/// configuration to be reloaded in a best-effort manner. Invocation of this method does not +/// guarantee that a reconfiguration will occur. The caller should keep sending this +/// request for as long as it believes that the compute configuration is out of date.
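Before the handler itself, a sketch of that calling pattern from a hypothetical external monitor (the URL, interval, and staleness check are illustrative; assumes the reqwest crate):

use std::{thread, time::Duration};

fn nudge_until_refreshed(base_url: &str) {
    let client = reqwest::blocking::Client::new();
    // Best effort: keep signalling while the configuration still looks stale,
    // ignoring individual request failures.
    while configuration_looks_stale() {
        let _ = client
            .post(format!("{base_url}/refresh_configuration"))
            .send();
        thread::sleep(Duration::from_secs(5));
    }
}

fn configuration_looks_stale() -> bool {
    // Placeholder for the caller's own detection logic
    // (e.g. repeated PageStream request errors).
    false
}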
+pub(in crate::http) async fn refresh_configuration( + State(compute): State>, +) -> Response { + POSTGRES_PAGESTREAM_REQUEST_ERRORS.inc(); + match compute.signal_refresh_configuration().await { + Ok(_) => StatusCode::OK.into_response(), + Err(e) => JsonResponse::error(StatusCode::INTERNAL_SERVER_ERROR, e), + } +} diff --git a/compute_tools/src/http/routes/terminate.rs b/compute_tools/src/http/routes/terminate.rs index 5b30b020c8..deac760f43 100644 --- a/compute_tools/src/http/routes/terminate.rs +++ b/compute_tools/src/http/routes/terminate.rs @@ -1,7 +1,7 @@ use crate::compute::{ComputeNode, forward_termination_signal}; use crate::http::JsonResponse; use axum::extract::State; -use axum::response::Response; +use axum::response::{IntoResponse, Response}; use axum_extra::extract::OptionalQuery; use compute_api::responses::{ComputeStatus, TerminateMode, TerminateResponse}; use http::StatusCode; @@ -33,7 +33,29 @@ pub(in crate::http) async fn terminate( if !matches!(state.status, ComputeStatus::Empty | ComputeStatus::Running) { return JsonResponse::invalid_status(state.status); } + + // If compute is Empty, there's no Postgres to terminate. The regular compute_ctl termination path + // assumes Postgres to be configured and running, so we just special-handle this case by exiting + // the process directly. + if compute.params.lakebase_mode && state.status == ComputeStatus::Empty { + drop(state); + info!("terminating empty compute - will exit process"); + + // Queue a task to exit the process after 5 seconds. The 5-second delay aims to + // give enough time for the HTTP response to be sent so that HCM doesn't get an abrupt + // connection termination. + tokio::spawn(async { + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; + info!("exiting process after terminating empty compute"); + std::process::exit(0); + }); + + return StatusCode::OK.into_response(); + } + + // For Running status, proceed with normal termination state.set_status(mode.into(), &compute.state_changed); + drop(state); } forward_termination_signal(false); diff --git a/compute_tools/src/http/server.rs b/compute_tools/src/http/server.rs index f0fbca8263..869fdef11d 100644 --- a/compute_tools/src/http/server.rs +++ b/compute_tools/src/http/server.rs @@ -23,7 +23,8 @@ use super::{ middleware::authorize::Authorize, routes::{ check_writability, configure, database_schema, dbs_and_roles, extension_server, extensions, - grants, insights, lfc, metrics, metrics_json, promote, status, terminate, + grants, hadron_liveness_probe, insights, lfc, metrics, metrics_json, promote, + refresh_configuration, status, terminate, }, }; use crate::compute::ComputeNode; @@ -43,6 +44,7 @@ pub enum Server { port: u16, config: ComputeCtlConfig, compute_id: String, + instance_id: Option, }, } @@ -67,7 +69,12 @@ impl From<&Server> for Router> { post(extension_server::download_extension), ) .route("/extensions", post(extensions::install_extension)) - .route("/grants", post(grants::add_grant)); + .route("/grants", post(grants::add_grant)) + // Hadron: Compute-initiated configuration refresh + .route( + "/refresh_configuration", + post(refresh_configuration::refresh_configuration), + ); // Add in any testing support if cfg!(feature = "testing") { @@ -79,7 +86,10 @@ impl From<&Server> for Router> { router } Server::External { - config, compute_id, .. + config, + compute_id, + instance_id, + .. 
} => { let unauthenticated_router = Router::>::new() .route("/metrics", get(metrics::get_metrics)) @@ -89,7 +99,12 @@ impl From<&Server> for Router> { ); let authenticated_router = Router::>::new() - .route("/lfc/prewarm", get(lfc::prewarm_state).post(lfc::prewarm)) + .route( + "/lfc/prewarm", + get(lfc::prewarm_state) + .post(lfc::prewarm) + .delete(lfc::cancel_prewarm), + ) .route("/lfc/offload", get(lfc::offload_state).post(lfc::offload)) .route("/promote", post(promote::promote)) .route("/check_writability", post(check_writability::is_writable)) @@ -100,8 +115,13 @@ impl From<&Server> for Router> { .route("/metrics.json", get(metrics_json::get_metrics)) .route("/status", get(status::get_status)) .route("/terminate", post(terminate::terminate)) + .route( + "/hadron_liveness_probe", + get(hadron_liveness_probe::hadron_liveness_probe), + ) .layer(AsyncRequireAuthorizationLayer::new(Authorize::new( compute_id.clone(), + instance_id.clone(), config.jwks.clone(), ))); diff --git a/compute_tools/src/installed_extensions.rs b/compute_tools/src/installed_extensions.rs index 90e1a17be4..a9ddef58e5 100644 --- a/compute_tools/src/installed_extensions.rs +++ b/compute_tools/src/installed_extensions.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use anyhow::Result; use compute_api::responses::{InstalledExtension, InstalledExtensions}; +use once_cell::sync::Lazy; use tokio_postgres::error::Error as PostgresError; use tokio_postgres::{Client, Config, NoTls}; @@ -18,7 +19,7 @@ async fn list_dbs(client: &mut Client) -> Result, PostgresError> { .query( "SELECT datname FROM pg_catalog.pg_database WHERE datallowconn - AND datconnlimit <> - 2 + AND datconnlimit OPERATOR(pg_catalog.<>) (OPERATOR(pg_catalog.-) 2::pg_catalog.int4) LIMIT 500", &[], ) @@ -66,7 +67,7 @@ pub async fn get_installed_extensions( let extensions: Vec<(String, String, i32)> = client .query( - "SELECT extname, extversion, extowner::integer FROM pg_catalog.pg_extension", + "SELECT extname, extversion, extowner::pg_catalog.int4 FROM pg_catalog.pg_extension", &[], ) .await? @@ -119,3 +120,7 @@ pub async fn get_installed_extensions( extensions: extensions_map.into_values().collect(), }) } + +pub fn initialize_metrics() { + Lazy::force(&INSTALLED_EXTENSIONS); +} diff --git a/compute_tools/src/lib.rs b/compute_tools/src/lib.rs index 4d0a7dca05..85a6f955d9 100644 --- a/compute_tools/src/lib.rs +++ b/compute_tools/src/lib.rs @@ -16,6 +16,7 @@ pub mod compute_prewarm; pub mod compute_promote; pub mod disk_quota; pub mod extension_server; +pub mod hadron_metrics; pub mod installed_extensions; pub mod local_proxy; pub mod lsn_lease; @@ -24,6 +25,7 @@ mod migration; pub mod monitor; pub mod params; pub mod pg_helpers; +pub mod pg_isready; pub mod pgbouncer; pub mod rsyslog; pub mod spec; diff --git a/compute_tools/src/logger.rs b/compute_tools/src/logger.rs index cd076472a6..83e666223c 100644 --- a/compute_tools/src/logger.rs +++ b/compute_tools/src/logger.rs @@ -1,7 +1,10 @@ use std::collections::HashMap; +use std::sync::{LazyLock, RwLock}; +use tracing::Subscriber; use tracing::info; -use tracing_subscriber::layer::SubscriberExt; +use tracing_appender; use tracing_subscriber::prelude::*; +use tracing_subscriber::{fmt, layer::SubscriberExt, registry::LookupSpan}; /// Initialize logging to stderr, and OpenTelemetry tracing and exporter. 
/// @@ -15,16 +18,44 @@ use tracing_subscriber::prelude::*; /// pub fn init_tracing_and_logging( default_log_level: &str, -) -> anyhow::Result> { + log_dir_opt: &Option, +) -> anyhow::Result<( + Option, + Option, +)> { // Initialize Logging let env_filter = tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(default_log_level)); + // Standard output streams let fmt_layer = tracing_subscriber::fmt::layer() .with_ansi(false) .with_target(false) .with_writer(std::io::stderr); + // Logs with file rotation. Files in `$log_dir/pgcctl.yyyy-MM-dd` + let (json_to_file_layer, _file_logs_guard) = if let Some(log_dir) = log_dir_opt { + std::fs::create_dir_all(log_dir)?; + let file_logs_appender = tracing_appender::rolling::RollingFileAppender::builder() + .rotation(tracing_appender::rolling::Rotation::DAILY) + .filename_prefix("pgcctl") + // Lib appends to existing files, so we will keep files for up to 2 days even on restart loops. + // At minimum, log-daemon will have 1 day to detect and upload a file (if created right before midnight). + .max_log_files(2) + .build(log_dir) + .expect("Initializing rolling file appender should succeed"); + let (file_logs_writer, _file_logs_guard) = + tracing_appender::non_blocking(file_logs_appender); + let json_to_file_layer = tracing_subscriber::fmt::layer() + .with_ansi(false) + .with_target(false) + .event_format(PgJsonLogShapeFormatter) + .with_writer(file_logs_writer); + (Some(json_to_file_layer), Some(_file_logs_guard)) + } else { + (None, None) + }; + // Initialize OpenTelemetry let provider = tracing_utils::init_tracing("compute_ctl", tracing_utils::ExportConfig::default()); @@ -35,12 +66,13 @@ pub fn init_tracing_and_logging( .with(env_filter) .with(otlp_layer) .with(fmt_layer) + .with(json_to_file_layer) .init(); tracing::info!("logging and tracing started"); utils::logging::replace_panic_hook_with_tracing_panic_hook().forget(); - Ok(provider) + Ok((provider, _file_logs_guard)) } /// Replace all newline characters with a special character to make it @@ -95,3 +127,157 @@ pub fn startup_context_from_env() -> Option { None } } + +/// Track relevant id's +const UNKNOWN_IDS: &str = r#""pg_instance_id": "", "pg_compute_id": """#; +static IDS: LazyLock> = LazyLock::new(|| RwLock::new(UNKNOWN_IDS.to_string())); + +pub fn update_ids(instance_id: &Option, compute_id: &Option) -> anyhow::Result<()> { + let ids = format!( + r#""pg_instance_id": "{}", "pg_compute_id": "{}""#, + instance_id.as_ref().map(|s| s.as_str()).unwrap_or_default(), + compute_id.as_ref().map(|s| s.as_str()).unwrap_or_default() + ); + let mut guard = IDS + .write() + .map_err(|e| anyhow::anyhow!("Log set id's rwlock poisoned: {}", e))?; + *guard = ids; + Ok(()) +} + +/// Massage compute_ctl logs into PG json log shape so we can use the same Lumberjack setup. +struct PgJsonLogShapeFormatter; +impl fmt::format::FormatEvent for PgJsonLogShapeFormatter +where + S: Subscriber + for<'a> LookupSpan<'a>, + N: for<'a> fmt::format::FormatFields<'a> + 'static, +{ + fn format_event( + &self, + ctx: &fmt::FmtContext<'_, S, N>, + mut writer: fmt::format::Writer<'_>, + event: &tracing::Event<'_>, + ) -> std::fmt::Result { + // Format values from the event's metadata, and open message string + let metadata = event.metadata(); + { + let ids_guard = IDS.read(); + let ids = ids_guard + .as_ref() + .map(|guard| guard.as_str()) + // Surpress so that we don't lose all uploaded/ file logs if something goes super wrong. We would notice the missing id's. 
+ .unwrap_or(UNKNOWN_IDS); + write!( + &mut writer, + r#"{{"timestamp": "{}", "error_severity": "{}", "file_name": "{}", "backend_type": "compute_ctl_self", {}, "message": "#, + chrono::Utc::now().format("%Y-%m-%d %H:%M:%S%.3f GMT"), + metadata.level(), + metadata.target(), + ids + )?; + } + + let mut message = String::new(); + let message_writer = fmt::format::Writer::new(&mut message); + + // Gather the message + ctx.field_format().format_fields(message_writer, event)?; + + // TODO: any better options than to copy-paste this OSS span formatter? + // impl FormatEvent for Format + // https://docs.rs/tracing-subscriber/latest/tracing_subscriber/fmt/trait.FormatEvent.html#impl-FormatEvent%3CS,+N%3E-for-Format%3CFull,+T%3E + + // write message, close bracket, and new line + writeln!(writer, "{}}}", serde_json::to_string(&message).unwrap()) + } +} + +#[cfg(feature = "testing")] +#[cfg(test)] +mod test { + use super::*; + use std::{cell::RefCell, io}; + + // Use thread_local! instead of Mutex for test isolation + thread_local! { + static WRITER_OUTPUT: RefCell = const { RefCell::new(String::new()) }; + } + + #[derive(Clone, Default)] + struct StaticStringWriter; + + impl io::Write for StaticStringWriter { + fn write(&mut self, buf: &[u8]) -> io::Result { + let output = String::from_utf8(buf.to_vec()).expect("Invalid UTF-8 in test output"); + WRITER_OUTPUT.with(|s| s.borrow_mut().push_str(&output)); + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } + } + + impl fmt::MakeWriter<'_> for StaticStringWriter { + type Writer = Self; + + fn make_writer(&self) -> Self::Writer { + Self + } + } + + #[test] + fn test_log_pg_json_shape_formatter() { + // Use a scoped subscriber to prevent global state pollution + let subscriber = tracing_subscriber::registry().with( + tracing_subscriber::fmt::layer() + .with_ansi(false) + .with_target(false) + .event_format(PgJsonLogShapeFormatter) + .with_writer(StaticStringWriter), + ); + + let _ = update_ids(&Some("000".to_string()), &Some("111".to_string())); + + // Clear any previous test state + WRITER_OUTPUT.with(|s| s.borrow_mut().clear()); + + let messages = [ + "test message", + r#"json escape check: name="BatchSpanProcessor.Flush.ExportError" reason="Other(reqwest::Error { kind: Request, url: \"http://localhost:4318/v1/traces\", source: hyper_ + util::client::legacy::Error(Connect, ConnectError(\"tcp connect error\", Os { code: 111, kind: ConnectionRefused, message: \"Connection refused\" })) })" Failed during the export process"#, + ]; + + tracing::subscriber::with_default(subscriber, || { + for message in messages { + tracing::info!(message); + } + }); + tracing::info!("not test message"); + + // Get captured output + let output = WRITER_OUTPUT.with(|s| s.borrow().clone()); + + let json_strings: Vec<&str> = output.lines().collect(); + assert_eq!( + json_strings.len(), + messages.len(), + "Log didn't have the expected number of json strings." 
+ ); + + let json_string_shape_regex = regex::Regex::new( + r#"\{"timestamp": "\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3} GMT", "error_severity": "INFO", "file_name": ".+", "backend_type": "compute_ctl_self", "pg_instance_id": "000", "pg_compute_id": "111", "message": ".+"\}"# + ).unwrap(); + + for (i, expected_message) in messages.iter().enumerate() { + let json_string = json_strings[i]; + assert!( + json_string_shape_regex.is_match(json_string), + "Json log didn't match expected pattern:\n{json_string}", + ); + let parsed_json: serde_json::Value = serde_json::from_str(json_string).unwrap(); + let actual_message = parsed_json["message"].as_str().unwrap(); + assert_eq!(*expected_message, actual_message); + } + } +} diff --git a/compute_tools/src/migration.rs b/compute_tools/src/migration.rs index 88d870df97..c8f911b5a6 100644 --- a/compute_tools/src/migration.rs +++ b/compute_tools/src/migration.rs @@ -76,7 +76,7 @@ impl<'m> MigrationRunner<'m> { self.client .simple_query("CREATE SCHEMA IF NOT EXISTS neon_migration") .await?; - self.client.simple_query("CREATE TABLE IF NOT EXISTS neon_migration.migration_id (key INT NOT NULL PRIMARY KEY, id bigint NOT NULL DEFAULT 0)").await?; + self.client.simple_query("CREATE TABLE IF NOT EXISTS neon_migration.migration_id (key pg_catalog.int4 NOT NULL PRIMARY KEY, id pg_catalog.int8 NOT NULL DEFAULT 0)").await?; self.client .simple_query( "INSERT INTO neon_migration.migration_id VALUES (0, 0) ON CONFLICT DO NOTHING", diff --git a/compute_tools/src/migrations/0002-alter_roles.sql b/compute_tools/src/migrations/0002-alter_roles.sql index 367356e6eb..6e28e8d32c 100644 --- a/compute_tools/src/migrations/0002-alter_roles.sql +++ b/compute_tools/src/migrations/0002-alter_roles.sql @@ -15,17 +15,17 @@ DO $$ DECLARE role_name text; BEGIN - FOR role_name IN SELECT rolname FROM pg_roles WHERE pg_has_role(rolname, '{privileged_role_name}', 'member') + FOR role_name IN SELECT rolname FROM pg_catalog.pg_roles WHERE pg_catalog.pg_has_role(rolname, '{privileged_role_name}', 'member') LOOP - RAISE NOTICE 'EXECUTING ALTER ROLE % INHERIT', quote_ident(role_name); - EXECUTE 'ALTER ROLE ' || quote_ident(role_name) || ' INHERIT'; + RAISE NOTICE 'EXECUTING ALTER ROLE % INHERIT', pg_catalog.quote_ident(role_name); + EXECUTE pg_catalog.format('ALTER ROLE %I INHERIT;', role_name); END LOOP; - FOR role_name IN SELECT rolname FROM pg_roles + FOR role_name IN SELECT rolname FROM pg_catalog.pg_roles WHERE - NOT pg_has_role(rolname, '{privileged_role_name}', 'member') AND NOT starts_with(rolname, 'pg_') + NOT pg_catalog.pg_has_role(rolname, '{privileged_role_name}', 'member') AND NOT pg_catalog.starts_with(rolname, 'pg_') LOOP - RAISE NOTICE 'EXECUTING ALTER ROLE % NOBYPASSRLS', quote_ident(role_name); - EXECUTE 'ALTER ROLE ' || quote_ident(role_name) || ' NOBYPASSRLS'; + RAISE NOTICE 'EXECUTING ALTER ROLE % NOBYPASSRLS', pg_catalog.quote_ident(role_name); + EXECUTE pg_catalog.format('ALTER ROLE %I NOBYPASSRLS;', role_name); END LOOP; END $$; diff --git a/compute_tools/src/migrations/0003-grant_pg_create_subscription_to_privileged_role.sql b/compute_tools/src/migrations/0003-grant_pg_create_subscription_to_privileged_role.sql index adf159dc06..d67d6457c6 100644 --- a/compute_tools/src/migrations/0003-grant_pg_create_subscription_to_privileged_role.sql +++ b/compute_tools/src/migrations/0003-grant_pg_create_subscription_to_privileged_role.sql @@ -1,6 +1,6 @@ DO $$ BEGIN - IF (SELECT setting::numeric >= 160000 FROM pg_settings WHERE name = 'server_version_num') THEN + IF (SELECT 
setting::pg_catalog.numeric >= 160000 FROM pg_catalog.pg_settings WHERE name = 'server_version_num') THEN EXECUTE 'GRANT pg_create_subscription TO {privileged_role_name}'; END IF; END $$; diff --git a/compute_tools/src/migrations/0009-revoke_replication_for_previously_allowed_roles.sql b/compute_tools/src/migrations/0009-revoke_replication_for_previously_allowed_roles.sql index 47129d65b8..7f74d4ee28 100644 --- a/compute_tools/src/migrations/0009-revoke_replication_for_previously_allowed_roles.sql +++ b/compute_tools/src/migrations/0009-revoke_replication_for_previously_allowed_roles.sql @@ -5,9 +5,9 @@ DO $$ DECLARE role_name TEXT; BEGIN - FOR role_name IN SELECT rolname FROM pg_roles WHERE rolreplication IS TRUE + FOR role_name IN SELECT rolname FROM pg_catalog.pg_roles WHERE rolreplication IS TRUE LOOP - RAISE NOTICE 'EXECUTING ALTER ROLE % NOREPLICATION', quote_ident(role_name); - EXECUTE 'ALTER ROLE ' || quote_ident(role_name) || ' NOREPLICATION'; + RAISE NOTICE 'EXECUTING ALTER ROLE % NOREPLICATION', pg_catalog.quote_ident(role_name); + EXECUTE pg_catalog.format('ALTER ROLE %I NOREPLICATION;', role_name); END LOOP; END $$; diff --git a/compute_tools/src/migrations/0010-grant_snapshot_synchronization_funcs_to_privileged_role.sql b/compute_tools/src/migrations/0010-grant_snapshot_synchronization_funcs_to_privileged_role.sql index 84fcb36391..714bdc735a 100644 --- a/compute_tools/src/migrations/0010-grant_snapshot_synchronization_funcs_to_privileged_role.sql +++ b/compute_tools/src/migrations/0010-grant_snapshot_synchronization_funcs_to_privileged_role.sql @@ -1,6 +1,6 @@ DO $$ BEGIN - IF (SELECT setting::numeric >= 160000 FROM pg_settings WHERE name = 'server_version_num') THEN + IF (SELECT setting::pg_catalog.numeric >= 160000 FROM pg_catalog.pg_settings WHERE name OPERATOR(pg_catalog.=) 'server_version_num'::pg_catalog.text) THEN EXECUTE 'GRANT EXECUTE ON FUNCTION pg_export_snapshot TO {privileged_role_name}'; EXECUTE 'GRANT EXECUTE ON FUNCTION pg_log_standby_snapshot TO {privileged_role_name}'; END IF; diff --git a/compute_tools/src/migrations/tests/0001-add_bypass_rls_to_privileged_role.sql b/compute_tools/src/migrations/tests/0001-add_bypass_rls_to_privileged_role.sql index 0c81cef1c4..b5b209ef5e 100644 --- a/compute_tools/src/migrations/tests/0001-add_bypass_rls_to_privileged_role.sql +++ b/compute_tools/src/migrations/tests/0001-add_bypass_rls_to_privileged_role.sql @@ -2,7 +2,7 @@ DO $$ DECLARE bypassrls boolean; BEGIN - SELECT rolbypassrls INTO bypassrls FROM pg_roles WHERE rolname = 'neon_superuser'; + SELECT rolbypassrls INTO bypassrls FROM pg_catalog.pg_roles WHERE rolname = 'neon_superuser'; IF NOT bypassrls THEN RAISE EXCEPTION 'neon_superuser cannot bypass RLS'; END IF; diff --git a/compute_tools/src/migrations/tests/0002-alter_roles.sql b/compute_tools/src/migrations/tests/0002-alter_roles.sql index 433f7b34f7..1755c9088c 100644 --- a/compute_tools/src/migrations/tests/0002-alter_roles.sql +++ b/compute_tools/src/migrations/tests/0002-alter_roles.sql @@ -4,8 +4,8 @@ DECLARE BEGIN FOR role IN SELECT rolname AS name, rolinherit AS inherit - FROM pg_roles - WHERE pg_has_role(rolname, 'neon_superuser', 'member') + FROM pg_catalog.pg_roles + WHERE pg_catalog.pg_has_role(rolname, 'neon_superuser', 'member') LOOP IF NOT role.inherit THEN RAISE EXCEPTION '% cannot inherit', quote_ident(role.name); @@ -14,12 +14,12 @@ BEGIN FOR role IN SELECT rolname AS name, rolbypassrls AS bypassrls - FROM pg_roles - WHERE NOT pg_has_role(rolname, 'neon_superuser', 'member') - AND NOT 
starts_with(rolname, 'pg_') + FROM pg_catalog.pg_roles + WHERE NOT pg_catalog.pg_has_role(rolname, 'neon_superuser', 'member') + AND NOT pg_catalog.starts_with(rolname, 'pg_') LOOP IF role.bypassrls THEN - RAISE EXCEPTION '% can bypass RLS', quote_ident(role.name); + RAISE EXCEPTION '% can bypass RLS', pg_catalog.quote_ident(role.name); END IF; END LOOP; END $$; diff --git a/compute_tools/src/migrations/tests/0003-grant_pg_create_subscription_to_privileged_role.sql b/compute_tools/src/migrations/tests/0003-grant_pg_create_subscription_to_privileged_role.sql index b164d61295..498770f4fa 100644 --- a/compute_tools/src/migrations/tests/0003-grant_pg_create_subscription_to_privileged_role.sql +++ b/compute_tools/src/migrations/tests/0003-grant_pg_create_subscription_to_privileged_role.sql @@ -1,10 +1,10 @@ DO $$ BEGIN - IF (SELECT current_setting('server_version_num')::numeric < 160000) THEN + IF (SELECT pg_catalog.current_setting('server_version_num')::pg_catalog.numeric < 160000) THEN RETURN; END IF; - IF NOT (SELECT pg_has_role('neon_superuser', 'pg_create_subscription', 'member')) THEN + IF NOT (SELECT pg_catalog.pg_has_role('neon_superuser', 'pg_create_subscription', 'member')) THEN RAISE EXCEPTION 'neon_superuser cannot execute pg_create_subscription'; END IF; END $$; diff --git a/compute_tools/src/migrations/tests/0004-grant_pg_monitor_to_privileged_role.sql b/compute_tools/src/migrations/tests/0004-grant_pg_monitor_to_privileged_role.sql index 3464a2b1cf..ec04cfe199 100644 --- a/compute_tools/src/migrations/tests/0004-grant_pg_monitor_to_privileged_role.sql +++ b/compute_tools/src/migrations/tests/0004-grant_pg_monitor_to_privileged_role.sql @@ -2,12 +2,12 @@ DO $$ DECLARE monitor record; BEGIN - SELECT pg_has_role('neon_superuser', 'pg_monitor', 'member') AS member, + SELECT pg_catalog.pg_has_role('neon_superuser', 'pg_monitor', 'member') AS member, admin_option AS admin INTO monitor - FROM pg_auth_members - WHERE roleid = 'pg_monitor'::regrole - AND member = 'neon_superuser'::regrole; + FROM pg_catalog.pg_auth_members + WHERE roleid = 'pg_monitor'::pg_catalog.regrole + AND member = 'neon_superuser'::pg_catalog.regrole; IF monitor IS NULL THEN RAISE EXCEPTION 'no entry in pg_auth_members for neon_superuser and pg_monitor'; diff --git a/compute_tools/src/migrations/tests/0010-grant_snapshot_synchronization_funcs_to_privileged_role.sql b/compute_tools/src/migrations/tests/0010-grant_snapshot_synchronization_funcs_to_privileged_role.sql index af7f50e95d..f3b28d76c9 100644 --- a/compute_tools/src/migrations/tests/0010-grant_snapshot_synchronization_funcs_to_privileged_role.sql +++ b/compute_tools/src/migrations/tests/0010-grant_snapshot_synchronization_funcs_to_privileged_role.sql @@ -2,11 +2,11 @@ DO $$ DECLARE can_execute boolean; BEGIN - SELECT bool_and(has_function_privilege('neon_superuser', oid, 'execute')) + SELECT pg_catalog.bool_and(pg_catalog.has_function_privilege('neon_superuser', oid, 'execute')) INTO can_execute - FROM pg_proc + FROM pg_catalog.pg_proc WHERE proname IN ('pg_export_snapshot', 'pg_log_standby_snapshot') - AND pronamespace = 'pg_catalog'::regnamespace; + AND pronamespace = 'pg_catalog'::pg_catalog.regnamespace; IF NOT can_execute THEN RAISE EXCEPTION 'neon_superuser cannot execute both pg_export_snapshot and pg_log_standby_snapshot'; END IF; diff --git a/compute_tools/src/migrations/tests/0011-grant_pg_show_replication_origin_status_to_privileged_role.sql b/compute_tools/src/migrations/tests/0011-grant_pg_show_replication_origin_status_to_privileged_role.sql 
index e55dcdc3b6..197211300b 100644 --- a/compute_tools/src/migrations/tests/0011-grant_pg_show_replication_origin_status_to_privileged_role.sql +++ b/compute_tools/src/migrations/tests/0011-grant_pg_show_replication_origin_status_to_privileged_role.sql @@ -2,9 +2,9 @@ DO $$ DECLARE can_execute boolean; BEGIN - SELECT has_function_privilege('neon_superuser', oid, 'execute') + SELECT pg_catalog.has_function_privilege('neon_superuser', oid, 'execute') INTO can_execute - FROM pg_proc + FROM pg_catalog.pg_proc WHERE proname = 'pg_show_replication_origin_status' AND pronamespace = 'pg_catalog'::regnamespace; IF NOT can_execute THEN diff --git a/compute_tools/src/migrations/tests/0012-grant_pg_signal_backend_to_privileged_role.sql b/compute_tools/src/migrations/tests/0012-grant_pg_signal_backend_to_privileged_role.sql index e62b742d30..0f772d67bd 100644 --- a/compute_tools/src/migrations/tests/0012-grant_pg_signal_backend_to_privileged_role.sql +++ b/compute_tools/src/migrations/tests/0012-grant_pg_signal_backend_to_privileged_role.sql @@ -2,10 +2,10 @@ DO $$ DECLARE signal_backend record; BEGIN - SELECT pg_has_role('neon_superuser', 'pg_signal_backend', 'member') AS member, + SELECT pg_catalog.pg_has_role('neon_superuser', 'pg_signal_backend', 'member') AS member, admin_option AS admin INTO signal_backend - FROM pg_auth_members + FROM pg_catalog.pg_auth_members WHERE roleid = 'pg_signal_backend'::regrole AND member = 'neon_superuser'::regrole; diff --git a/compute_tools/src/monitor.rs b/compute_tools/src/monitor.rs index e164f15dba..78ac423a9b 100644 --- a/compute_tools/src/monitor.rs +++ b/compute_tools/src/monitor.rs @@ -407,9 +407,9 @@ fn get_database_stats(cli: &mut Client) -> anyhow::Result<(f64, i64)> { // like `postgres_exporter` use it to query Postgres statistics. // Use explicit 8 bytes type casts to match Rust types. let stats = cli.query_one( - "SELECT coalesce(sum(active_time), 0.0)::float8 AS total_active_time, - coalesce(sum(sessions), 0)::bigint AS total_sessions - FROM pg_stat_database + "SELECT pg_catalog.coalesce(pg_catalog.sum(active_time), 0.0)::pg_catalog.float8 AS total_active_time, + pg_catalog.coalesce(pg_catalog.sum(sessions), 0)::pg_catalog.bigint AS total_sessions + FROM pg_catalog.pg_stat_database WHERE datname NOT IN ( 'postgres', 'template0', @@ -445,11 +445,11 @@ fn get_backends_state_change(cli: &mut Client) -> anyhow::Result> = None; // Get all running client backends except ourself, use RFC3339 DateTime format. let backends = cli.query( - "SELECT state, to_char(state_change, 'YYYY-MM-DD\"T\"HH24:MI:SS.US\"Z\"') AS state_change + "SELECT state, pg_catalog.to_char(state_change, 'YYYY-MM-DD\"T\"HH24:MI:SS.US\"Z\"'::pg_catalog.text) AS state_change FROM pg_stat_activity - WHERE backend_type = 'client backend' - AND pid != pg_backend_pid() - AND usename != 'cloud_admin';", // XXX: find a better way to filter other monitors? + WHERE backend_type OPERATOR(pg_catalog.=) 'client backend'::pg_catalog.text + AND pid OPERATOR(pg_catalog.!=) pg_catalog.pg_backend_pid() + AND usename OPERATOR(pg_catalog.!=) 'cloud_admin'::pg_catalog.name;", // XXX: find a better way to filter other monitors? 
&[], ); diff --git a/compute_tools/src/pg_helpers.rs b/compute_tools/src/pg_helpers.rs index 09bbe89b41..4e16a75181 100644 --- a/compute_tools/src/pg_helpers.rs +++ b/compute_tools/src/pg_helpers.rs @@ -299,9 +299,9 @@ pub async fn get_existing_dbs_async( .query_raw::( "SELECT datname AS name, - (SELECT rolname FROM pg_roles WHERE oid = datdba) AS owner, + (SELECT rolname FROM pg_catalog.pg_roles WHERE oid OPERATOR(pg_catalog.=) datdba) AS owner, NOT datallowconn AS restrict_conn, - datconnlimit = - 2 AS invalid + datconnlimit OPERATOR(pg_catalog.=) (OPERATOR(pg_catalog.-) 2) AS invalid FROM pg_catalog.pg_database;", &[], diff --git a/compute_tools/src/pg_isready.rs b/compute_tools/src/pg_isready.rs new file mode 100644 index 0000000000..76c45d6b0a --- /dev/null +++ b/compute_tools/src/pg_isready.rs @@ -0,0 +1,30 @@ +use anyhow::{Context, anyhow}; + +// Run `/usr/local/bin/pg_isready -p {port}` +// Check the connectivity of PG +// Success means PG is listening on the port and accepting connections +// Note that PG does not need to authenticate the connection, nor reserve a connection quota for it. +// See https://www.postgresql.org/docs/current/app-pg-isready.html +pub fn pg_isready(bin: &str, port: u16) -> anyhow::Result<()> { + let child_result = std::process::Command::new(bin) + .arg("-p") + .arg(port.to_string()) + .spawn(); + + child_result + .context("spawn() failed") + .and_then(|mut child| child.wait().context("wait() failed")) + .and_then(|status| match status.success() { + true => Ok(()), + false => Err(anyhow!("process exited with {status}")), + }) + // wrap any prior error with the overall context that we couldn't run the command + .with_context(|| format!("could not run `{bin} --port {port}`")) +} + +// It's safe to assume pg_isready is under the same directory with postgres, +// because it is a PG util bin installed along with postgres +pub fn get_pg_isready_bin(pgbin: &str) -> String { + let split = pgbin.split("/").collect::>(); + split[0..split.len() - 1].join("/") + "/pg_isready" +} diff --git a/compute_tools/src/spec.rs b/compute_tools/src/spec.rs index d00f86a2c0..2dda9c3134 100644 --- a/compute_tools/src/spec.rs +++ b/compute_tools/src/spec.rs @@ -142,7 +142,7 @@ pub fn update_pg_hba(pgdata_path: &Path, databricks_pg_hba: Option<&String>) -> // Update pg_hba to contains databricks specfic settings before adding neon settings // PG uses the first record that matches to perform authentication, so we need to have // our rules before the default ones from neon. 
- // See https://www.postgresql.org/docs/16/auth-pg-hba-conf.html + // See https://www.postgresql.org/docs/current/auth-pg-hba-conf.html if let Some(databricks_pg_hba) = databricks_pg_hba { if config::line_in_file( &pghba_path, diff --git a/compute_tools/src/spec_apply.rs b/compute_tools/src/spec_apply.rs index 47bf61ae1b..90c0e234d5 100644 --- a/compute_tools/src/spec_apply.rs +++ b/compute_tools/src/spec_apply.rs @@ -13,17 +13,19 @@ use tokio_postgres::Client; use tokio_postgres::error::SqlState; use tracing::{Instrument, debug, error, info, info_span, instrument, warn}; -use crate::compute::{ComputeNode, ComputeNodeParams, ComputeState}; +use crate::compute::{ComputeNode, ComputeNodeParams, ComputeState, create_databricks_roles}; +use crate::hadron_metrics::COMPUTE_CONFIGURE_STATEMENT_TIMEOUT_ERRORS; use crate::pg_helpers::{ DatabaseExt, Escaping, GenericOptionsSearch, RoleExt, get_existing_dbs_async, get_existing_roles_async, }; use crate::spec_apply::ApplySpecPhase::{ - CreateAndAlterDatabases, CreateAndAlterRoles, CreateAvailabilityCheck, CreatePgauditExtension, + AddDatabricksGrants, AlterDatabricksRoles, CreateAndAlterDatabases, CreateAndAlterRoles, + CreateAvailabilityCheck, CreateDatabricksMisc, CreateDatabricksRoles, CreatePgauditExtension, CreatePgauditlogtofileExtension, CreatePrivilegedRole, CreateSchemaNeon, DisablePostgresDBPgAudit, DropInvalidDatabases, DropRoles, FinalizeDropLogicalSubscriptions, - HandleNeonExtension, HandleOtherExtensions, RenameAndDeleteDatabases, RenameRoles, - RunInEachDatabase, + HandleDatabricksAuthExtension, HandleNeonExtension, HandleOtherExtensions, + RenameAndDeleteDatabases, RenameRoles, RunInEachDatabase, }; use crate::spec_apply::PerDatabasePhase::{ ChangeSchemaPerms, DeleteDBRoleReferences, DropLogicalSubscriptions, @@ -80,7 +82,7 @@ impl ComputeNode { info!("Checking if drop subscription operation was already performed for timeline_id: {}", timeline_id); drop_subscriptions_done = match - client.query("select 1 from neon.drop_subscriptions_done where timeline_id = $1", &[&timeline_id.to_string()]).await { + client.query("select 1 from neon.drop_subscriptions_done where timeline_id OPERATOR(pg_catalog.=) $1", &[&timeline_id.to_string()]).await { Ok(result) => !result.is_empty(), Err(e) => { @@ -166,6 +168,7 @@ impl ComputeNode { concurrency_token.clone(), db, [DropLogicalSubscriptions].to_vec(), + self.params.lakebase_mode, ); Ok(tokio::spawn(fut)) @@ -186,15 +189,33 @@ impl ComputeNode { }; } - for phase in [ - CreatePrivilegedRole, + let phases = if self.params.lakebase_mode { + vec![ + CreatePrivilegedRole, + // BEGIN_HADRON + CreateDatabricksRoles, + AlterDatabricksRoles, + // END_HADRON DropInvalidDatabases, RenameRoles, CreateAndAlterRoles, RenameAndDeleteDatabases, CreateAndAlterDatabases, CreateSchemaNeon, - ] { + ] + } else { + vec![ + CreatePrivilegedRole, + DropInvalidDatabases, + RenameRoles, + CreateAndAlterRoles, + RenameAndDeleteDatabases, + CreateAndAlterDatabases, + CreateSchemaNeon, + ] + }; + + for phase in phases { info!("Applying phase {:?}", &phase); apply_operations( params.clone(), @@ -203,6 +224,7 @@ impl ComputeNode { jwks_roles.clone(), phase, || async { Ok(&client) }, + self.params.lakebase_mode, ) .await?; } @@ -254,6 +276,7 @@ impl ComputeNode { concurrency_token.clone(), db, phases, + self.params.lakebase_mode, ); Ok(tokio::spawn(fut)) @@ -265,12 +288,28 @@ impl ComputeNode { handle.await??; } - let mut phases = vec![ + let mut phases = if self.params.lakebase_mode { + vec![ + HandleOtherExtensions, + 
HandleNeonExtension, // This step depends on CreateSchemaNeon + // BEGIN_HADRON + HandleDatabricksAuthExtension, + // END_HADRON + CreateAvailabilityCheck, + DropRoles, + // BEGIN_HADRON + AddDatabricksGrants, + CreateDatabricksMisc, + // END_HADRON + ] + } else { + vec![ HandleOtherExtensions, HandleNeonExtension, // This step depends on CreateSchemaNeon CreateAvailabilityCheck, DropRoles, - ]; + ] + }; // This step depends on CreateSchemaNeon if spec.drop_subscriptions_before_start && !drop_subscriptions_done { @@ -303,6 +342,7 @@ impl ComputeNode { jwks_roles.clone(), phase, || async { Ok(&client) }, + self.params.lakebase_mode, ) .await?; } @@ -328,6 +368,7 @@ impl ComputeNode { concurrency_token: Arc, db: DB, subphases: Vec, + lakebase_mode: bool, ) -> Result<()> { let _permit = concurrency_token.acquire().await?; @@ -355,6 +396,7 @@ impl ComputeNode { let client = client_conn.as_ref().unwrap(); Ok(client) }, + lakebase_mode, ) .await?; } @@ -477,6 +519,10 @@ pub enum PerDatabasePhase { #[derive(Clone, Debug)] pub enum ApplySpecPhase { CreatePrivilegedRole, + // BEGIN_HADRON + CreateDatabricksRoles, + AlterDatabricksRoles, + // END_HADRON DropInvalidDatabases, RenameRoles, CreateAndAlterRoles, @@ -489,7 +535,14 @@ pub enum ApplySpecPhase { DisablePostgresDBPgAudit, HandleOtherExtensions, HandleNeonExtension, + // BEGIN_HADRON + HandleDatabricksAuthExtension, + // END_HADRON CreateAvailabilityCheck, + // BEGIN_HADRON + AddDatabricksGrants, + CreateDatabricksMisc, + // END_HADRON DropRoles, FinalizeDropLogicalSubscriptions, } @@ -525,6 +578,7 @@ pub async fn apply_operations<'a, Fut, F>( jwks_roles: Arc>, apply_spec_phase: ApplySpecPhase, client: F, + lakebase_mode: bool, ) -> Result<()> where F: FnOnce() -> Fut, @@ -571,6 +625,23 @@ where }, query ); + if !lakebase_mode { + return res; + } + // BEGIN HADRON + if let Err(e) = res.as_ref() { + if let Some(sql_state) = e.code() { + if sql_state.code() == "57014" { + // SQL State 57014 (ERRCODE_QUERY_CANCELED) is used for statement timeouts. + // Increment the counter whenever a statement timeout occurs. Timeouts on + // this configuration path can only occur due to PS connectivity problems that + // Postgres failed to recover from. 
+ COMPUTE_CONFIGURE_STATEMENT_TIMEOUT_ERRORS.inc(); + } + } + } + // END HADRON + res } .instrument(inspan) @@ -608,10 +679,44 @@ async fn get_operations<'a>( ApplySpecPhase::CreatePrivilegedRole => Ok(Box::new(once(Operation { query: format!( include_str!("sql/create_privileged_role.sql"), - privileged_role_name = params.privileged_role_name + privileged_role_name = params.privileged_role_name, + privileges = if params.lakebase_mode { + "CREATEDB CREATEROLE NOLOGIN BYPASSRLS" + } else { + "CREATEDB CREATEROLE NOLOGIN REPLICATION BYPASSRLS" + } ), comment: None, }))), + // BEGIN_HADRON + // New Hadron phase + ApplySpecPhase::CreateDatabricksRoles => { + let queries = create_databricks_roles(); + let operations = queries.into_iter().map(|query| Operation { + query, + comment: None, + }); + Ok(Box::new(operations)) + } + + // Backfill existing databricks_reader_* roles with statement timeout from GUC + ApplySpecPhase::AlterDatabricksRoles => { + let query = String::from(include_str!( + "sql/alter_databricks_reader_roles_timeout.sql" + )); + + let operations = once(Operation { + query, + comment: Some( + "Backfill existing databricks_reader_* roles with statement timeout" + .to_string(), + ), + }); + + Ok(Box::new(operations)) + } + // End of new Hadron Phase + // END_HADRON ApplySpecPhase::DropInvalidDatabases => { let mut ctx = ctx.write().await; let databases = &mut ctx.dbs; @@ -981,7 +1086,10 @@ async fn get_operations<'a>( // N.B. this has to be properly dollar-escaped with `pg_quote_dollar()` role_name = escaped_role, outer_tag = outer_tag, - ), + ) + // HADRON change: + .replace("neon_superuser", ¶ms.privileged_role_name), + // HADRON change end , comment: None, }, // This now will only drop privileges of the role @@ -1017,7 +1125,8 @@ async fn get_operations<'a>( comment: None, }, Operation { - query: String::from(include_str!("sql/default_grants.sql")), + query: String::from(include_str!("sql/default_grants.sql")) + .replace("neon_superuser", ¶ms.privileged_role_name), comment: None, }, ] @@ -1033,7 +1142,9 @@ async fn get_operations<'a>( if let Some(libs) = spec.cluster.settings.find("shared_preload_libraries") { if libs.contains("pg_stat_statements") { return Ok(Box::new(once(Operation { - query: String::from("CREATE EXTENSION IF NOT EXISTS pg_stat_statements"), + query: String::from( + "CREATE EXTENSION IF NOT EXISTS pg_stat_statements WITH SCHEMA public", + ), comment: Some(String::from("create system extensions")), }))); } @@ -1041,11 +1152,13 @@ async fn get_operations<'a>( Ok(Box::new(empty())) } ApplySpecPhase::CreatePgauditExtension => Ok(Box::new(once(Operation { - query: String::from("CREATE EXTENSION IF NOT EXISTS pgaudit"), + query: String::from("CREATE EXTENSION IF NOT EXISTS pgaudit WITH SCHEMA public"), comment: Some(String::from("create pgaudit extensions")), }))), ApplySpecPhase::CreatePgauditlogtofileExtension => Ok(Box::new(once(Operation { - query: String::from("CREATE EXTENSION IF NOT EXISTS pgauditlogtofile"), + query: String::from( + "CREATE EXTENSION IF NOT EXISTS pgauditlogtofile WITH SCHEMA public", + ), comment: Some(String::from("create pgauditlogtofile extensions")), }))), // Disable pgaudit logging for postgres database. 
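The "57014" literal checked above is PostgreSQL's SQLSTATE for query_canceled, which is what a statement timeout raises. A small sketch of an equivalent check using the named constant (assumes the tokio_postgres crate already imported in this file):

use tokio_postgres::error::SqlState;

/// Equivalent to comparing the raw SQLSTATE string against "57014".
fn is_statement_timeout(err: &tokio_postgres::Error) -> bool {
    err.code() == Some(&SqlState::QUERY_CANCELED)
}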
@@ -1069,7 +1182,7 @@ async fn get_operations<'a>( }, Operation { query: String::from( - "UPDATE pg_extension SET extrelocatable = true WHERE extname = 'neon'", + "UPDATE pg_catalog.pg_extension SET extrelocatable = true WHERE extname OPERATOR(pg_catalog.=) 'neon'::pg_catalog.name AND extrelocatable OPERATOR(pg_catalog.=) false", ), comment: Some(String::from("compat/fix: make neon relocatable")), }, @@ -1086,6 +1199,28 @@ async fn get_operations<'a>( Ok(Box::new(operations)) } + // BEGIN_HADRON + // Note: we may want to version the extension someday, but for now we just drop it and recreate it. + ApplySpecPhase::HandleDatabricksAuthExtension => { + let operations = vec![ + Operation { + query: String::from("DROP EXTENSION IF EXISTS databricks_auth"), + comment: Some(String::from("dropping existing databricks_auth extension")), + }, + Operation { + query: String::from("CREATE EXTENSION databricks_auth"), + comment: Some(String::from("creating databricks_auth extension")), + }, + Operation { + query: String::from("GRANT SELECT ON databricks_auth_metrics TO pg_monitor"), + comment: Some(String::from("grant select on databricks auth counters")), + }, + ] + .into_iter(); + + Ok(Box::new(operations)) + } + // END_HADRON ApplySpecPhase::CreateAvailabilityCheck => Ok(Box::new(once(Operation { query: String::from(include_str!("sql/add_availabilitycheck_tables.sql")), comment: None, @@ -1103,6 +1238,63 @@ async fn get_operations<'a>( Ok(Box::new(operations)) } + + // BEGIN_HADRON + // New Hadron phases + // + // Grants permissions to roles that are used by Databricks. + ApplySpecPhase::AddDatabricksGrants => { + let operations = vec![ + Operation { + query: String::from("GRANT USAGE ON SCHEMA neon TO databricks_monitor"), + comment: Some(String::from( + "Permissions needed to execute neon.* functions (in the postgres database)", + )), + }, + Operation { + query: String::from( + "GRANT SELECT, INSERT, UPDATE ON health_check TO databricks_monitor", + ), + comment: Some(String::from("Permissions needed for read and write probes")), + }, + Operation { + query: String::from( + "GRANT EXECUTE ON FUNCTION pg_ls_dir(text) TO databricks_monitor", + ), + comment: Some(String::from( + "Permissions needed to monitor .snap file counts", + )), + }, + Operation { + query: String::from( + "GRANT SELECT ON neon.neon_perf_counters TO databricks_monitor", + ), + comment: Some(String::from( + "Permissions needed to access neon performance counters view", + )), + }, + Operation { + query: String::from( + "GRANT EXECUTE ON FUNCTION neon.get_perf_counters() TO databricks_monitor", + ), + comment: Some(String::from( + "Permissions needed to execute the underlying performance counters function", + )), + }, + ] + .into_iter(); + + Ok(Box::new(operations)) + } + // Creates minor objects that are used by Databricks. 
+ ApplySpecPhase::CreateDatabricksMisc => Ok(Box::new(once(Operation { + query: String::from(include_str!("sql/create_databricks_misc.sql")), + comment: Some(String::from( + "The function databricks_monitor uses to convert exception to 0 or 1", + )), + }))), + // End of new Hadron phases + // END_HADRON ApplySpecPhase::FinalizeDropLogicalSubscriptions => Ok(Box::new(once(Operation { query: String::from(include_str!("sql/finalize_drop_subscriptions.sql")), comment: None, diff --git a/compute_tools/src/sql/add_availabilitycheck_tables.sql b/compute_tools/src/sql/add_availabilitycheck_tables.sql index 7c60690c78..dd27105e16 100644 --- a/compute_tools/src/sql/add_availabilitycheck_tables.sql +++ b/compute_tools/src/sql/add_availabilitycheck_tables.sql @@ -3,16 +3,17 @@ BEGIN IF NOT EXISTS( SELECT 1 FROM pg_catalog.pg_tables - WHERE tablename = 'health_check' + WHERE tablename::pg_catalog.name OPERATOR(pg_catalog.=) 'health_check'::pg_catalog.name + AND schemaname::pg_catalog.name OPERATOR(pg_catalog.=) 'public'::pg_catalog.name ) THEN - CREATE TABLE health_check ( - id serial primary key, - updated_at timestamptz default now() + CREATE TABLE public.health_check ( + id pg_catalog.int4 primary key generated by default as identity, + updated_at pg_catalog.timestamptz default pg_catalog.now() ); - INSERT INTO health_check VALUES (1, now()) + INSERT INTO public.health_check VALUES (1, pg_catalog.now()) ON CONFLICT (id) DO UPDATE - SET updated_at = now(); + SET updated_at = pg_catalog.now(); END IF; END $$ \ No newline at end of file diff --git a/compute_tools/src/sql/alter_databricks_reader_roles_timeout.sql b/compute_tools/src/sql/alter_databricks_reader_roles_timeout.sql new file mode 100644 index 0000000000..db16df3817 --- /dev/null +++ b/compute_tools/src/sql/alter_databricks_reader_roles_timeout.sql @@ -0,0 +1,25 @@ +DO $$ +DECLARE + reader_role RECORD; + timeout_value TEXT; +BEGIN + -- Get the current GUC setting for reader statement timeout + SELECT current_setting('databricks.reader_statement_timeout', true) INTO timeout_value; + + -- Only proceed if timeout_value is not null/empty and not '0' (disabled) + IF timeout_value IS NOT NULL AND timeout_value != '' AND timeout_value != '0' THEN + -- Find all databricks_reader_* roles and update their statement_timeout + FOR reader_role IN + SELECT r.rolname + FROM pg_roles r + WHERE r.rolname ~ '^databricks_reader_\d+$' + LOOP + -- Apply the timeout setting to the role (will overwrite existing setting) + EXECUTE format('ALTER ROLE %I SET statement_timeout = %L', + reader_role.rolname, timeout_value); + + RAISE LOG 'Updated statement_timeout = % for role %', timeout_value, reader_role.rolname; + END LOOP; + END IF; +END +$$; diff --git a/compute_tools/src/sql/anon_ext_fn_reassign.sql b/compute_tools/src/sql/anon_ext_fn_reassign.sql deleted file mode 100644 index 3d7b15c590..0000000000 --- a/compute_tools/src/sql/anon_ext_fn_reassign.sql +++ /dev/null @@ -1,12 +0,0 @@ -DO $$ -DECLARE - query varchar; -BEGIN - FOR query IN SELECT 'ALTER FUNCTION '||nsp.nspname||'.'||p.proname||'('||pg_get_function_identity_arguments(p.oid)||') OWNER TO {db_owner};' - FROM pg_proc p - JOIN pg_namespace nsp ON p.pronamespace = nsp.oid - WHERE nsp.nspname = 'anon' LOOP - EXECUTE query; - END LOOP; -END -$$; diff --git a/compute_tools/src/sql/create_databricks_misc.sql b/compute_tools/src/sql/create_databricks_misc.sql new file mode 100644 index 0000000000..a6dc379078 --- /dev/null +++ b/compute_tools/src/sql/create_databricks_misc.sql @@ -0,0 +1,15 @@ +ALTER ROLE 
databricks_monitor SET statement_timeout = '60s'; + +CREATE OR REPLACE FUNCTION health_check_write_succeeds() +RETURNS INTEGER AS $$ +BEGIN +INSERT INTO health_check VALUES (1, now()) +ON CONFLICT (id) DO UPDATE + SET updated_at = now(); + +RETURN 1; +EXCEPTION WHEN OTHERS THEN +RAISE EXCEPTION '[DATABRICKS_SMGR] health_check failed: [%] %', SQLSTATE, SQLERRM; +RETURN 0; +END; +$$ LANGUAGE plpgsql; diff --git a/compute_tools/src/sql/create_privileged_role.sql b/compute_tools/src/sql/create_privileged_role.sql index df27ac32fc..a682089cce 100644 --- a/compute_tools/src/sql/create_privileged_role.sql +++ b/compute_tools/src/sql/create_privileged_role.sql @@ -1,8 +1,8 @@ DO $$ BEGIN - IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{privileged_role_name}') + IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname OPERATOR(pg_catalog.=) '{privileged_role_name}'::pg_catalog.name) THEN - CREATE ROLE {privileged_role_name} CREATEDB CREATEROLE NOLOGIN REPLICATION BYPASSRLS IN ROLE pg_read_all_data, pg_write_all_data; + CREATE ROLE {privileged_role_name} {privileges} IN ROLE pg_read_all_data, pg_write_all_data; END IF; END $$; diff --git a/compute_tools/src/sql/default_grants.sql b/compute_tools/src/sql/default_grants.sql index 58ebb0690b..d572332270 100644 --- a/compute_tools/src/sql/default_grants.sql +++ b/compute_tools/src/sql/default_grants.sql @@ -4,14 +4,14 @@ $$ IF EXISTS( SELECT nspname FROM pg_catalog.pg_namespace - WHERE nspname = 'public' + WHERE nspname OPERATOR(pg_catalog.=) 'public' ) AND - current_setting('server_version_num')::int / 10000 >= 15 + pg_catalog.current_setting('server_version_num')::int OPERATOR(pg_catalog./) 10000 OPERATOR(pg_catalog.>=) 15 THEN IF EXISTS( SELECT rolname FROM pg_catalog.pg_roles - WHERE rolname = 'web_access' + WHERE rolname OPERATOR(pg_catalog.=) 'web_access' ) THEN GRANT CREATE ON SCHEMA public TO web_access; @@ -20,7 +20,7 @@ $$ IF EXISTS( SELECT nspname FROM pg_catalog.pg_namespace - WHERE nspname = 'public' + WHERE nspname OPERATOR(pg_catalog.=) 'public' ) THEN ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO neon_superuser WITH GRANT OPTION; diff --git a/compute_tools/src/sql/drop_subscriptions.sql b/compute_tools/src/sql/drop_subscriptions.sql index f5d9420130..68b3f8b729 100644 --- a/compute_tools/src/sql/drop_subscriptions.sql +++ b/compute_tools/src/sql/drop_subscriptions.sql @@ -2,11 +2,17 @@ DO ${outer_tag}$ DECLARE subname TEXT; BEGIN - LOCK TABLE pg_subscription IN ACCESS EXCLUSIVE MODE; - FOR subname IN SELECT pg_subscription.subname FROM pg_subscription WHERE subdbid = (SELECT oid FROM pg_database WHERE datname = {datname_str}) LOOP - EXECUTE format('ALTER SUBSCRIPTION %I DISABLE;', subname); - EXECUTE format('ALTER SUBSCRIPTION %I SET (slot_name = NONE);', subname); - EXECUTE format('DROP SUBSCRIPTION %I;', subname); + LOCK TABLE pg_catalog.pg_subscription IN ACCESS EXCLUSIVE MODE; + FOR subname IN + SELECT pg_subscription.subname + FROM pg_catalog.pg_subscription + WHERE subdbid OPERATOR(pg_catalog.=) ( + SELECT oid FROM pg_database WHERE datname OPERATOR(pg_catalog.=) {datname_str}::pg_catalog.name + ) + LOOP + EXECUTE pg_catalog.format('ALTER SUBSCRIPTION %I DISABLE;', subname); + EXECUTE pg_catalog.format('ALTER SUBSCRIPTION %I SET (slot_name = NONE);', subname); + EXECUTE pg_catalog.format('DROP SUBSCRIPTION %I;', subname); END LOOP; END; ${outer_tag}$; diff --git a/compute_tools/src/sql/finalize_drop_subscriptions.sql b/compute_tools/src/sql/finalize_drop_subscriptions.sql index 
4bb291876f..1a8876ad61 100644 --- a/compute_tools/src/sql/finalize_drop_subscriptions.sql +++ b/compute_tools/src/sql/finalize_drop_subscriptions.sql @@ -3,19 +3,19 @@ BEGIN IF NOT EXISTS( SELECT 1 FROM pg_catalog.pg_tables - WHERE tablename = 'drop_subscriptions_done' - AND schemaname = 'neon' + WHERE tablename OPERATOR(pg_catalog.=) 'drop_subscriptions_done'::pg_catalog.name + AND schemaname OPERATOR(pg_catalog.=) 'neon'::pg_catalog.name ) THEN CREATE TABLE neon.drop_subscriptions_done - (id serial primary key, timeline_id text); + (id pg_catalog.int4 primary key generated by default as identity, timeline_id pg_catalog.text); END IF; -- preserve the timeline_id of the last drop_subscriptions run -- to ensure that the cleanup of a timeline is executed only once. -- use upsert to avoid the table bloat in case of cascade branching (branch of a branch) - INSERT INTO neon.drop_subscriptions_done VALUES (1, current_setting('neon.timeline_id')) + INSERT INTO neon.drop_subscriptions_done VALUES (1, pg_catalog.current_setting('neon.timeline_id')) ON CONFLICT (id) DO UPDATE - SET timeline_id = current_setting('neon.timeline_id'); + SET timeline_id = pg_catalog.current_setting('neon.timeline_id')::pg_catalog.text; END $$ diff --git a/compute_tools/src/sql/pre_drop_role_revoke_privileges.sql b/compute_tools/src/sql/pre_drop_role_revoke_privileges.sql index 734607be02..2ed0f94bad 100644 --- a/compute_tools/src/sql/pre_drop_role_revoke_privileges.sql +++ b/compute_tools/src/sql/pre_drop_role_revoke_privileges.sql @@ -15,15 +15,15 @@ BEGIN WHERE schema_name IN ('public') LOOP FOR grantor IN EXECUTE - format( - 'SELECT DISTINCT rtg.grantor FROM information_schema.role_table_grants AS rtg WHERE grantee = %s', + pg_catalog.format( + 'SELECT DISTINCT rtg.grantor FROM information_schema.role_table_grants AS rtg WHERE grantee OPERATOR(pg_catalog.=) %s', -- N.B. this has to be properly dollar-escaped with `pg_quote_dollar()` quote_literal({role_name}) ) LOOP - EXECUTE format('SET LOCAL ROLE %I', grantor); + EXECUTE pg_catalog.format('SET LOCAL ROLE %I', grantor); - revoke_query := format( + revoke_query := pg_catalog.format( 'REVOKE ALL PRIVILEGES ON ALL TABLES IN SCHEMA %I FROM %I GRANTED BY %I', schema, -- N.B. 
this has to be properly dollar-escaped with `pg_quote_dollar()` diff --git a/compute_tools/src/sql/set_public_schema_owner.sql b/compute_tools/src/sql/set_public_schema_owner.sql index dc502c6d2d..41bd0d4689 100644 --- a/compute_tools/src/sql/set_public_schema_owner.sql +++ b/compute_tools/src/sql/set_public_schema_owner.sql @@ -5,17 +5,17 @@ DO ${outer_tag}$ IF EXISTS( SELECT nspname FROM pg_catalog.pg_namespace - WHERE nspname = 'public' + WHERE nspname OPERATOR(pg_catalog.=) 'public'::pg_catalog.name ) THEN SELECT nspowner::regrole::text FROM pg_catalog.pg_namespace - WHERE nspname = 'public' + WHERE nspname OPERATOR(pg_catalog.=) 'public'::pg_catalog.text INTO schema_owner; - IF schema_owner = 'cloud_admin' OR schema_owner = 'zenith_admin' + IF schema_owner OPERATOR(pg_catalog.=) 'cloud_admin'::pg_catalog.text OR schema_owner OPERATOR(pg_catalog.=) 'zenith_admin'::pg_catalog.text THEN - EXECUTE format('ALTER SCHEMA public OWNER TO %I', {db_owner}); + EXECUTE pg_catalog.format('ALTER SCHEMA public OWNER TO %I', {db_owner}); END IF; END IF; END diff --git a/compute_tools/src/sql/unset_template_for_drop_dbs.sql b/compute_tools/src/sql/unset_template_for_drop_dbs.sql index 36dc648beb..03225d5e64 100644 --- a/compute_tools/src/sql/unset_template_for_drop_dbs.sql +++ b/compute_tools/src/sql/unset_template_for_drop_dbs.sql @@ -3,10 +3,10 @@ DO ${outer_tag}$ IF EXISTS( SELECT 1 FROM pg_catalog.pg_database - WHERE datname = {datname} + WHERE datname OPERATOR(pg_catalog.=) {datname}::pg_catalog.name ) THEN - EXECUTE format('ALTER DATABASE %I is_template false', {datname}); + EXECUTE pg_catalog.format('ALTER DATABASE %I is_template false', {datname}); END IF; END ${outer_tag}$; diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index 8cd923fc72..2b81c3957c 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -16,13 +16,11 @@ use std::time::Duration; use anyhow::{Context, Result, anyhow, bail}; use clap::Parser; use compute_api::requests::ComputeClaimsScope; -use compute_api::spec::{ - ComputeMode, PageserverConnectionInfo, PageserverProtocol, PageserverShardInfo, -}; +use compute_api::spec::{ComputeMode, PageserverProtocol}; use control_plane::broker::StorageBroker; use control_plane::endpoint::{ComputeControlPlane, EndpointTerminateMode}; use control_plane::endpoint::{ - pageserver_conf_to_shard_conn_info, tenant_locate_response_to_conn_info, + local_pageserver_conf_to_conn_info, tenant_locate_response_to_conn_info, }; use control_plane::endpoint_storage::{ENDPOINT_STORAGE_DEFAULT_ADDR, EndpointStorage}; use control_plane::local_env; @@ -60,7 +58,6 @@ use utils::auth::{Claims, Scope}; use utils::id::{NodeId, TenantId, TenantTimelineId, TimelineId}; use utils::lsn::Lsn; use utils::project_git_version; -use utils::shard::ShardIndex; // Default id of a safekeeper node, if not specified on the command line. const DEFAULT_SAFEKEEPER_ID: NodeId = NodeId(1); @@ -74,8 +71,9 @@ const DEFAULT_PG_VERSION_NUM: &str = "17"; const DEFAULT_PAGESERVER_CONTROL_PLANE_API: &str = "http://127.0.0.1:1234/upcall/v1/"; +/// Neon CLI. #[derive(clap::Parser)] -#[command(version = GIT_VERSION, about, name = "Neon CLI")] +#[command(version = GIT_VERSION, name = "Neon CLI")] struct Cli { #[command(subcommand)] command: NeonLocalCmd, @@ -110,30 +108,31 @@ enum NeonLocalCmd { Stop(StopCmdArgs), } +/// Initialize a new Neon repository, preparing configs for services to start with. 
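The neon_local.rs changes that follow repeatedly swap explicit `#[clap(about = ...)]` / `help = "..."` strings for `///` doc comments. A hedged, standalone illustration (not repo code) of why that is equivalent with clap's derive API:

```rust
// With clap's derive feature, the first line of a `///` doc comment becomes the
// command's `about` text and field doc comments become per-flag help, so the
// explicit about/help attributes can be dropped. Names here are illustrative.
use clap::Parser;

/// Example tool; this line is what `--help` prints as the about text.
#[derive(Parser)]
struct ExampleCli {
    /// How many workers to start (rendered as this flag's help text).
    #[arg(long, default_value_t = 1)]
    workers: u8,
}

fn main() {
    let cli = ExampleCli::parse();
    println!("workers = {}", cli.workers);
}
```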
#[derive(clap::Args)] -#[clap(about = "Initialize a new Neon repository, preparing configs for services to start with")] struct InitCmdArgs { - #[clap(long, help("How many pageservers to create (default 1)"))] + /// How many pageservers to create (default 1). + #[clap(long)] num_pageservers: Option, #[clap(long)] config: Option, - #[clap(long, help("Force initialization even if the repository is not empty"))] + /// Force initialization even if the repository is not empty. + #[clap(long, default_value = "must-not-exist")] #[arg(value_parser)] - #[clap(default_value = "must-not-exist")] force: InitForceMode, } +/// Start pageserver and safekeepers. #[derive(clap::Args)] -#[clap(about = "Start pageserver and safekeepers")] struct StartCmdArgs { #[clap(long = "start-timeout", default_value = "10s")] timeout: humantime::Duration, } +/// Stop pageserver and safekeepers. #[derive(clap::Args)] -#[clap(about = "Stop pageserver and safekeepers")] struct StopCmdArgs { #[arg(value_enum)] #[clap(long, default_value_t = StopMode::Fast)] @@ -146,8 +145,8 @@ enum StopMode { Immediate, } +/// Manage tenants. #[derive(clap::Subcommand)] -#[clap(about = "Manage tenants")] enum TenantCmd { List, Create(TenantCreateCmdArgs), @@ -158,38 +157,36 @@ enum TenantCmd { #[derive(clap::Args)] struct TenantCreateCmdArgs { - #[clap( - long = "tenant-id", - help = "Tenant id. Represented as a hexadecimal string 32 symbols length" - )] + /// Tenant ID, as a 32-byte hexadecimal string. + #[clap(long = "tenant-id")] tenant_id: Option, - #[clap( - long, - help = "Use a specific timeline id when creating a tenant and its initial timeline" - )] + /// Use a specific timeline id when creating a tenant and its initial timeline. + #[clap(long)] timeline_id: Option, #[clap(short = 'c')] config: Vec, + /// Postgres version to use for the initial timeline. #[arg(default_value = DEFAULT_PG_VERSION_NUM)] - #[clap(long, help = "Postgres version to use for the initial timeline")] + #[clap(long)] pg_version: PgMajorVersion, - #[clap( - long, - help = "Use this tenant in future CLI commands where tenant_id is needed, but not specified" - )] + /// Use this tenant in future CLI commands where tenant_id is needed, but not specified. + #[clap(long)] set_default: bool, - #[clap(long, help = "Number of shards in the new tenant")] + /// Number of shards in the new tenant. + #[clap(long)] #[arg(default_value_t = 0)] shard_count: u8, - #[clap(long, help = "Sharding stripe size in pages")] + /// Sharding stripe size in pages. + #[clap(long)] shard_stripe_size: Option, - #[clap(long, help = "Placement policy shards in this tenant")] + /// Placement policy shards in this tenant. + #[clap(long)] #[arg(value_parser = parse_placement_policy)] placement_policy: Option, } @@ -198,44 +195,35 @@ fn parse_placement_policy(s: &str) -> anyhow::Result { Ok(serde_json::from_str::(s)?) } +/// Set a particular tenant as default in future CLI commands where tenant_id is needed, but not +/// specified. #[derive(clap::Args)] -#[clap( - about = "Set a particular tenant as default in future CLI commands where tenant_id is needed, but not specified" -)] struct TenantSetDefaultCmdArgs { - #[clap( - long = "tenant-id", - help = "Tenant id. Represented as a hexadecimal string 32 symbols length" - )] + /// Tenant ID, as a 32-byte hexadecimal string. + #[clap(long = "tenant-id")] tenant_id: TenantId, } #[derive(clap::Args)] struct TenantConfigCmdArgs { - #[clap( - long = "tenant-id", - help = "Tenant id. 
Represented as a hexadecimal string 32 symbols length" - )] + /// Tenant ID, as a 32-byte hexadecimal string. + #[clap(long = "tenant-id")] tenant_id: Option, #[clap(short = 'c')] config: Vec, } +/// Import a tenant that is present in remote storage, and create branches for its timelines. #[derive(clap::Args)] -#[clap( - about = "Import a tenant that is present in remote storage, and create branches for its timelines" -)] struct TenantImportCmdArgs { - #[clap( - long = "tenant-id", - help = "Tenant id. Represented as a hexadecimal string 32 symbols length" - )] + /// Tenant ID, as a 32-byte hexadecimal string. + #[clap(long = "tenant-id")] tenant_id: TenantId, } +/// Manage timelines. #[derive(clap::Subcommand)] -#[clap(about = "Manage timelines")] enum TimelineCmd { List(TimelineListCmdArgs), Branch(TimelineBranchCmdArgs), @@ -243,98 +231,87 @@ enum TimelineCmd { Import(TimelineImportCmdArgs), } +/// List all timelines available to this pageserver. #[derive(clap::Args)] -#[clap(about = "List all timelines available to this pageserver")] struct TimelineListCmdArgs { - #[clap( - long = "tenant-id", - help = "Tenant id. Represented as a hexadecimal string 32 symbols length" - )] + /// Tenant ID, as a 32-byte hexadecimal string. + #[clap(long = "tenant-id")] tenant_shard_id: Option, } +/// Create a new timeline, branching off from another timeline. #[derive(clap::Args)] -#[clap(about = "Create a new timeline, branching off from another timeline")] struct TimelineBranchCmdArgs { - #[clap( - long = "tenant-id", - help = "Tenant id. Represented as a hexadecimal string 32 symbols length" - )] + /// Tenant ID, as a 32-byte hexadecimal string. + #[clap(long = "tenant-id")] tenant_id: Option, - - #[clap(long, help = "New timeline's ID")] + /// New timeline's ID, as a 32-byte hexadecimal string. + #[clap(long)] timeline_id: Option, - - #[clap(long, help = "Human-readable alias for the new timeline")] + /// Human-readable alias for the new timeline. + #[clap(long)] branch_name: String, - - #[clap( - long, - help = "Use last Lsn of another timeline (and its data) as base when creating the new timeline. The timeline gets resolved by its branch name." - )] + /// Use last Lsn of another timeline (and its data) as base when creating the new timeline. The + /// timeline gets resolved by its branch name. + #[clap(long)] ancestor_branch_name: Option, - - #[clap( - long, - help = "When using another timeline as base, use a specific Lsn in it instead of the latest one" - )] + /// When using another timeline as base, use a specific Lsn in it instead of the latest one. + #[clap(long)] ancestor_start_lsn: Option, } +/// Create a new blank timeline. #[derive(clap::Args)] -#[clap(about = "Create a new blank timeline")] struct TimelineCreateCmdArgs { - #[clap( - long = "tenant-id", - help = "Tenant id. Represented as a hexadecimal string 32 symbols length" - )] + /// Tenant ID, as a 32-byte hexadecimal string. + #[clap(long = "tenant-id")] tenant_id: Option, - - #[clap(long, help = "New timeline's ID")] + /// New timeline's ID, as a 32-byte hexadecimal string. + #[clap(long)] timeline_id: Option, - - #[clap(long, help = "Human-readable alias for the new timeline")] + /// Human-readable alias for the new timeline. + #[clap(long)] branch_name: String, + /// Postgres version. #[arg(default_value = DEFAULT_PG_VERSION_NUM)] - #[clap(long, help = "Postgres version")] + #[clap(long)] pg_version: PgMajorVersion, } +/// Import a timeline from a basebackup directory. 
#[derive(clap::Args)] -#[clap(about = "Import timeline from a basebackup directory")] struct TimelineImportCmdArgs { - #[clap( - long = "tenant-id", - help = "Tenant id. Represented as a hexadecimal string 32 symbols length" - )] + /// Tenant ID, as a 32-byte hexadecimal string. + #[clap(long = "tenant-id")] tenant_id: Option, - - #[clap(long, help = "New timeline's ID")] + /// New timeline's ID, as a 32-byte hexadecimal string. + #[clap(long)] timeline_id: TimelineId, - - #[clap(long, help = "Human-readable alias for the new timeline")] + /// Human-readable alias for the new timeline. + #[clap(long)] branch_name: String, - - #[clap(long, help = "Basebackup tarfile to import")] + /// Basebackup tarfile to import. + #[clap(long)] base_tarfile: PathBuf, - - #[clap(long, help = "Lsn the basebackup starts at")] + /// LSN the basebackup starts at. + #[clap(long)] base_lsn: Lsn, - - #[clap(long, help = "Wal to add after base")] + /// WAL to add after base. + #[clap(long)] wal_tarfile: Option, - - #[clap(long, help = "Lsn the basebackup ends at")] + /// LSN the basebackup ends at. + #[clap(long)] end_lsn: Option, + /// Postgres version of the basebackup being imported. #[arg(default_value = DEFAULT_PG_VERSION_NUM)] - #[clap(long, help = "Postgres version of the backup being imported")] + #[clap(long)] pg_version: PgMajorVersion, } +/// Manage pageservers. #[derive(clap::Subcommand)] -#[clap(about = "Manage pageservers")] enum PageserverCmd { Status(PageserverStatusCmdArgs), Start(PageserverStartCmdArgs), @@ -342,259 +319,234 @@ enum PageserverCmd { Restart(PageserverRestartCmdArgs), } +/// Show status of a local pageserver. #[derive(clap::Args)] -#[clap(about = "Show status of a local pageserver")] struct PageserverStatusCmdArgs { - #[clap(long = "id", help = "pageserver id")] + /// Pageserver ID. + #[clap(long = "id")] pageserver_id: Option, } +/// Start local pageserver. #[derive(clap::Args)] -#[clap(about = "Start local pageserver")] struct PageserverStartCmdArgs { - #[clap(long = "id", help = "pageserver id")] + /// Pageserver ID. + #[clap(long = "id")] pageserver_id: Option, - - #[clap(short = 't', long, help = "timeout until we fail the command")] + /// Timeout until we fail the command. + #[clap(short = 't', long)] #[arg(default_value = "10s")] start_timeout: humantime::Duration, } +/// Stop local pageserver. #[derive(clap::Args)] -#[clap(about = "Stop local pageserver")] struct PageserverStopCmdArgs { - #[clap(long = "id", help = "pageserver id")] + /// Pageserver ID. + #[clap(long = "id")] pageserver_id: Option, - - #[clap( - short = 'm', - help = "If 'immediate', don't flush repository data at shutdown" - )] + /// If 'immediate', don't flush repository data at shutdown + #[clap(short = 'm')] #[arg(value_enum, default_value = "fast")] stop_mode: StopMode, } +/// Restart local pageserver. #[derive(clap::Args)] -#[clap(about = "Restart local pageserver")] struct PageserverRestartCmdArgs { - #[clap(long = "id", help = "pageserver id")] + /// Pageserver ID. + #[clap(long = "id")] pageserver_id: Option, - - #[clap(short = 't', long, help = "timeout until we fail the command")] + /// Timeout until we fail the command. + #[clap(short = 't', long)] #[arg(default_value = "10s")] start_timeout: humantime::Duration, } +/// Manage storage controller. #[derive(clap::Subcommand)] -#[clap(about = "Manage storage controller")] enum StorageControllerCmd { Start(StorageControllerStartCmdArgs), Stop(StorageControllerStopCmdArgs), } +/// Start storage controller. 
#[derive(clap::Args)] -#[clap(about = "Start storage controller")] struct StorageControllerStartCmdArgs { - #[clap(short = 't', long, help = "timeout until we fail the command")] + /// Timeout until we fail the command. + #[clap(short = 't', long)] #[arg(default_value = "10s")] start_timeout: humantime::Duration, - - #[clap( - long, - help = "Identifier used to distinguish storage controller instances" - )] + /// Identifier used to distinguish storage controller instances. + #[clap(long)] #[arg(default_value_t = 1)] instance_id: u8, - - #[clap( - long, - help = "Base port for the storage controller instance idenfified by instance-id (defaults to pageserver cplane api)" - )] + /// Base port for the storage controller instance identified by instance-id (defaults to + /// pageserver cplane api). + #[clap(long)] base_port: Option, - #[clap( - long, - help = "Whether the storage controller should handle pageserver-reported local disk loss events." - )] + /// Whether the storage controller should handle pageserver-reported local disk loss events. + #[clap(long)] handle_ps_local_disk_loss: Option, } +/// Stop storage controller. #[derive(clap::Args)] -#[clap(about = "Stop storage controller")] struct StorageControllerStopCmdArgs { - #[clap( - short = 'm', - help = "If 'immediate', don't flush repository data at shutdown" - )] + /// If 'immediate', don't flush repository data at shutdown + #[clap(short = 'm')] #[arg(value_enum, default_value = "fast")] stop_mode: StopMode, - - #[clap( - long, - help = "Identifier used to distinguish storage controller instances" - )] + /// Identifier used to distinguish storage controller instances. + #[clap(long)] #[arg(default_value_t = 1)] instance_id: u8, } +/// Manage storage broker. #[derive(clap::Subcommand)] -#[clap(about = "Manage storage broker")] enum StorageBrokerCmd { Start(StorageBrokerStartCmdArgs), Stop(StorageBrokerStopCmdArgs), } +/// Start broker. #[derive(clap::Args)] -#[clap(about = "Start broker")] struct StorageBrokerStartCmdArgs { - #[clap(short = 't', long, help = "timeout until we fail the command")] - #[arg(default_value = "10s")] + /// Timeout until we fail the command. + #[clap(short = 't', long, default_value = "10s")] start_timeout: humantime::Duration, } +/// Stop broker. #[derive(clap::Args)] -#[clap(about = "stop broker")] struct StorageBrokerStopCmdArgs { - #[clap( - short = 'm', - help = "If 'immediate', don't flush repository data at shutdown" - )] + /// If 'immediate', don't flush repository data on shutdown. + #[clap(short = 'm')] #[arg(value_enum, default_value = "fast")] stop_mode: StopMode, } +/// Manage safekeepers. #[derive(clap::Subcommand)] -#[clap(about = "Manage safekeepers")] enum SafekeeperCmd { Start(SafekeeperStartCmdArgs), Stop(SafekeeperStopCmdArgs), Restart(SafekeeperRestartCmdArgs), } +/// Manage object storage. #[derive(clap::Subcommand)] -#[clap(about = "Manage object storage")] enum EndpointStorageCmd { Start(EndpointStorageStartCmd), Stop(EndpointStorageStopCmd), } +/// Start object storage. #[derive(clap::Args)] -#[clap(about = "Start object storage")] struct EndpointStorageStartCmd { - #[clap(short = 't', long, help = "timeout until we fail the command")] + /// Timeout until we fail the command. + #[clap(short = 't', long)] #[arg(default_value = "10s")] start_timeout: humantime::Duration, } +/// Stop object storage. #[derive(clap::Args)] -#[clap(about = "Stop object storage")] struct EndpointStorageStopCmd { + /// If 'immediate', don't flush repository data on shutdown. 
+ #[clap(short = 'm')] #[arg(value_enum, default_value = "fast")] - #[clap( - short = 'm', - help = "If 'immediate', don't flush repository data at shutdown" - )] stop_mode: StopMode, } +/// Start local safekeeper. #[derive(clap::Args)] -#[clap(about = "Start local safekeeper")] struct SafekeeperStartCmdArgs { - #[clap(help = "safekeeper id")] + /// Safekeeper ID. #[arg(default_value_t = NodeId(1))] id: NodeId, - #[clap( - short = 'e', - long = "safekeeper-extra-opt", - help = "Additional safekeeper invocation options, e.g. -e=--http-auth-public-key-path=foo" - )] + /// Additional safekeeper invocation options, e.g. -e=--http-auth-public-key-path=foo. + #[clap(short = 'e', long = "safekeeper-extra-opt")] extra_opt: Vec, - #[clap(short = 't', long, help = "timeout until we fail the command")] + /// Timeout until we fail the command. + #[clap(short = 't', long)] #[arg(default_value = "10s")] start_timeout: humantime::Duration, } +/// Stop local safekeeper. #[derive(clap::Args)] -#[clap(about = "Stop local safekeeper")] struct SafekeeperStopCmdArgs { - #[clap(help = "safekeeper id")] + /// Safekeeper ID. #[arg(default_value_t = NodeId(1))] id: NodeId, + /// If 'immediate', don't flush repository data on shutdown. #[arg(value_enum, default_value = "fast")] - #[clap( - short = 'm', - help = "If 'immediate', don't flush repository data at shutdown" - )] + #[clap(short = 'm')] stop_mode: StopMode, } +/// Restart local safekeeper. #[derive(clap::Args)] -#[clap(about = "Restart local safekeeper")] struct SafekeeperRestartCmdArgs { - #[clap(help = "safekeeper id")] + /// Safekeeper ID. #[arg(default_value_t = NodeId(1))] id: NodeId, + /// If 'immediate', don't flush repository data on shutdown. #[arg(value_enum, default_value = "fast")] - #[clap( - short = 'm', - help = "If 'immediate', don't flush repository data at shutdown" - )] + #[clap(short = 'm')] stop_mode: StopMode, - #[clap( - short = 'e', - long = "safekeeper-extra-opt", - help = "Additional safekeeper invocation options, e.g. -e=--http-auth-public-key-path=foo" - )] + /// Additional safekeeper invocation options, e.g. -e=--http-auth-public-key-path=foo. + #[clap(short = 'e', long = "safekeeper-extra-opt")] extra_opt: Vec, - #[clap(short = 't', long, help = "timeout until we fail the command")] + /// Timeout until we fail the command. + #[clap(short = 't', long)] #[arg(default_value = "10s")] start_timeout: humantime::Duration, } +/// Manage Postgres instances. #[derive(clap::Subcommand)] -#[clap(about = "Manage Postgres instances")] enum EndpointCmd { List(EndpointListCmdArgs), Create(EndpointCreateCmdArgs), Start(EndpointStartCmdArgs), Reconfigure(EndpointReconfigureCmdArgs), + RefreshConfiguration(EndpointRefreshConfigurationArgs), Stop(EndpointStopCmdArgs), + UpdatePageservers(EndpointUpdatePageserversCmdArgs), GenerateJwt(EndpointGenerateJwtCmdArgs), } +/// List endpoints. #[derive(clap::Args)] -#[clap(about = "List endpoints")] struct EndpointListCmdArgs { - #[clap( - long = "tenant-id", - help = "Tenant id. Represented as a hexadecimal string 32 symbols length" - )] + /// Tenant ID, as a 32-byte hexadecimal string. + #[clap(long = "tenant-id")] tenant_shard_id: Option, } +/// Create a compute endpoint. #[derive(clap::Args)] -#[clap(about = "Create a compute endpoint")] struct EndpointCreateCmdArgs { - #[clap( - long = "tenant-id", - help = "Tenant id. Represented as a hexadecimal string 32 symbols length" - )] + /// Tenant ID, as a 32-byte hexadecimal string. 
+ #[clap(long = "tenant-id")] tenant_id: Option, - - #[clap(help = "Postgres endpoint id")] + /// Postgres endpoint ID. endpoint_id: Option, - #[clap(long, help = "Name of the branch the endpoint will run on")] + /// Name of the branch the endpoint will run on. + #[clap(long)] branch_name: Option, - #[clap( - long, - help = "Specify Lsn on the timeline to start from. By default, end of the timeline would be used" - )] + /// Specify LSN on the timeline to start from. By default, end of the timeline would be used. + #[clap(long)] lsn: Option, #[clap(long)] pg_port: Option, @@ -605,16 +557,13 @@ struct EndpointCreateCmdArgs { #[clap(long = "pageserver-id")] endpoint_pageserver_id: Option, - #[clap( - long, - help = "Don't do basebackup, create endpoint directory with only config files", - action = clap::ArgAction::Set, - default_value_t = false - )] + /// Don't do basebackup, create endpoint directory with only config files. + #[clap(long, action = clap::ArgAction::Set, default_value_t = false)] config_only: bool, + /// Postgres version. #[arg(default_value = DEFAULT_PG_VERSION_NUM)] - #[clap(long, help = "Postgres version")] + #[clap(long)] pg_version: PgMajorVersion, /// Use gRPC to communicate with Pageservers, by generating grpc:// connstrings. @@ -625,153 +574,140 @@ struct EndpointCreateCmdArgs { #[clap(long)] grpc: bool, - #[clap( - long, - help = "If set, the node will be a hot replica on the specified timeline", - action = clap::ArgAction::Set, - default_value_t = false - )] + /// If set, the node will be a hot replica on the specified timeline. + #[clap(long, action = clap::ArgAction::Set, default_value_t = false)] hot_standby: bool, - - #[clap(long, help = "If set, will set up the catalog for neon_superuser")] + /// If set, will set up the catalog for neon_superuser. + #[clap(long)] update_catalog: bool, - - #[clap( - long, - help = "Allow multiple primary endpoints running on the same branch. Shouldn't be used normally, but useful for tests." - )] + /// Allow multiple primary endpoints running on the same branch. Shouldn't be used normally, but + /// useful for tests. + #[clap(long)] allow_multiple: bool, - /// Only allow changing it on creation - #[clap(long, help = "Name of the privileged role for the endpoint")] + /// Name of the privileged role for the endpoint. + // Only allow changing it on creation. + #[clap(long)] privileged_role_name: Option, } +/// Start Postgres. If the endpoint doesn't exist yet, it is created. #[derive(clap::Args)] -#[clap(about = "Start postgres. If the endpoint doesn't exist yet, it is created.")] struct EndpointStartCmdArgs { - #[clap(help = "Postgres endpoint id")] + /// Postgres endpoint ID. endpoint_id: String, + /// Pageserver ID. #[clap(long = "pageserver-id")] endpoint_pageserver_id: Option, - - #[clap( - long, - help = "Safekeepers membership generation to prefix neon.safekeepers with. Normally neon_local sets it on its own, but this option allows to override. Non zero value forces endpoint to use membership configurations." - )] + /// Safekeepers membership generation to prefix neon.safekeepers with. + #[clap(long)] safekeepers_generation: Option, - #[clap( - long, - help = "List of safekeepers endpoint will talk to. Normally neon_local chooses them on its own, but this option allows to override." - )] + /// List of safekeepers endpoint will talk to. 
+ #[clap(long)] safekeepers: Option, - - #[clap( - long, - help = "Configure the remote extensions storage proxy gateway URL to request for extensions.", - alias = "remote-ext-config" - )] + /// Configure the remote extensions storage proxy gateway URL to request for extensions. + #[clap(long, alias = "remote-ext-config")] remote_ext_base_url: Option, - - #[clap( - long, - help = "If set, will create test user `user` and `neondb` database. Requires `update-catalog = true`" - )] + /// If set, will create test user `user` and `neondb` database. Requires `update-catalog = true` + #[clap(long)] create_test_user: bool, - - #[clap( - long, - help = "Allow multiple primary endpoints running on the same branch. Shouldn't be used normally, but useful for tests." - )] + /// Allow multiple primary endpoints running on the same branch. Shouldn't be used normally, but + /// useful for tests. + #[clap(long)] allow_multiple: bool, - - #[clap(short = 't', long, value_parser= humantime::parse_duration, help = "timeout until we fail the command")] + /// Timeout until we fail the command. + #[clap(short = 't', long, value_parser= humantime::parse_duration)] #[arg(default_value = "90s")] start_timeout: Duration, - #[clap( - long, - help = "Download LFC cache from endpoint storage on endpoint startup", - default_value = "false" - )] + /// Download LFC cache from endpoint storage on endpoint startup + #[clap(long, default_value = "false")] autoprewarm: bool, - #[clap(long, help = "Upload LFC cache to endpoint storage periodically")] + /// Upload LFC cache to endpoint storage periodically + #[clap(long)] offload_lfc_interval_seconds: Option, - #[clap( - long, - help = "Run in development mode, skipping VM-specific operations like process termination", - action = clap::ArgAction::SetTrue - )] + /// Run in development mode, skipping VM-specific operations like process termination + #[clap(long, action = clap::ArgAction::SetTrue)] dev: bool, } +/// Reconfigure an endpoint. #[derive(clap::Args)] -#[clap(about = "Reconfigure an endpoint")] struct EndpointReconfigureCmdArgs { - #[clap( - long = "tenant-id", - help = "Tenant id. Represented as a hexadecimal string 32 symbols length" - )] + /// Tenant id. Represented as a hexadecimal string 32 symbols length + #[clap(long = "tenant-id")] tenant_id: Option, - - #[clap(help = "Postgres endpoint id")] + /// Postgres endpoint ID. endpoint_id: String, + /// Pageserver ID. #[clap(long = "pageserver-id")] endpoint_pageserver_id: Option, - #[clap(long)] safekeepers: Option, } +/// Refresh the endpoint's configuration by forcing it reload it's spec #[derive(clap::Args)] -#[clap(about = "Stop an endpoint")] -struct EndpointStopCmdArgs { - #[clap(help = "Postgres endpoint id")] +struct EndpointRefreshConfigurationArgs { + /// Postgres endpoint id endpoint_id: String, +} - #[clap( - long, - help = "Also delete data directory (now optional, should be default in future)" - )] +/// Stop an endpoint. +#[derive(clap::Args)] +struct EndpointStopCmdArgs { + /// Postgres endpoint ID. + endpoint_id: String, + /// Also delete data directory (now optional, should be default in future). + #[clap(long)] destroy: bool, - #[clap(long, help = "Postgres shutdown mode")] + /// Postgres shutdown mode, passed to `pg_ctl -m `. 
+ #[clap(long)] #[clap(default_value = "fast")] mode: EndpointTerminateMode, } +/// Update the pageservers in the spec file of the compute endpoint #[derive(clap::Args)] -#[clap(about = "Generate a JWT for an endpoint")] -struct EndpointGenerateJwtCmdArgs { - #[clap(help = "Postgres endpoint id")] +struct EndpointUpdatePageserversCmdArgs { + /// Postgres endpoint id endpoint_id: String, - #[clap(short = 's', long, help = "Scope to generate the JWT with", value_parser = ComputeClaimsScope::from_str)] + /// Specified pageserver id + #[clap(short = 'p', long)] + pageserver_id: Option, +} + +/// Generate a JWT for an endpoint. +#[derive(clap::Args)] +struct EndpointGenerateJwtCmdArgs { + /// Postgres endpoint ID. + endpoint_id: String, + /// Scope to generate the JWT with. + #[clap(short = 's', long, value_parser = ComputeClaimsScope::from_str)] scope: Option, } +/// Manage neon_local branch name mappings. #[derive(clap::Subcommand)] -#[clap(about = "Manage neon_local branch name mappings")] enum MappingsCmd { Map(MappingsMapCmdArgs), } +/// Create new mapping which cannot exist already. #[derive(clap::Args)] -#[clap(about = "Create new mapping which cannot exist already")] struct MappingsMapCmdArgs { - #[clap( - long, - help = "Tenant id. Represented as a hexadecimal string 32 symbols length" - )] + /// Tenant ID, as a 32-byte hexadecimal string. + #[clap(long)] tenant_id: TenantId, - #[clap( - long, - help = "Timeline id. Represented as a hexadecimal string 32 symbols length" - )] + /// Timeline ID, as a 32-byte hexadecimal string. + #[clap(long)] timeline_id: TimelineId, - #[clap(long, help = "Branch name to give to the timeline")] + /// Branch name to give to the timeline. + #[clap(long)] branch_name: String, } @@ -1521,7 +1457,7 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res let endpoint = cplane .endpoints .get(endpoint_id.as_str()) - .ok_or_else(|| anyhow::anyhow!("endpoint {endpoint_id} not found"))?; + .ok_or_else(|| anyhow!("endpoint {endpoint_id} not found"))?; if !args.allow_multiple { cplane.check_conflicting_endpoints( @@ -1539,22 +1475,7 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res let mut pageserver_conninfo = if let Some(ps_id) = pageserver_id { let conf = env.get_pageserver_conf(ps_id).unwrap(); - let ps_conninfo = pageserver_conf_to_shard_conn_info(conf)?; - - let shard_info = PageserverShardInfo { - pageservers: vec![ps_conninfo], - }; - // If caller is telling us what pageserver to use, this is not a tenant which is - // fully managed by storage controller, therefore not sharded. - let shards: HashMap<_, _> = vec![(ShardIndex::unsharded(), shard_info)] - .into_iter() - .collect(); - PageserverConnectionInfo { - shard_count: ShardCount(0), - stripe_size: None, - shards, - prefer_protocol, - } + local_pageserver_conf_to_conn_info(conf)? } else { // Look up the currently attached location of the tenant, and its striping metadata, // to pass these on to postgres. 
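The `Start`, `Reconfigure`, and new `UpdatePageservers` branches all resolve pageserver connection info the same way. A hedged sketch of that shared pattern, using the helpers introduced in this diff; imports and exact signatures from the control_plane crate are assumed, and the real handlers additionally override `prefer_protocol` on the result:

```rust
// Hedged sketch (assumed signatures): a pinned pageserver yields unsharded
// connection info built from its local config; otherwise the storage controller
// is asked where the tenant's shards currently live.
async fn resolve_conninfo(
    env: &local_env::LocalEnv,
    pinned_pageserver: Option<NodeId>,
    tenant_id: TenantId,
) -> anyhow::Result<PageserverConnectionInfo> {
    match pinned_pageserver {
        Some(ps_id) => {
            let conf = env.get_pageserver_conf(ps_id)?;
            local_pageserver_conf_to_conn_info(conf)
        }
        None => {
            let storage_controller = StorageController::from_env(env);
            let locate_result = storage_controller.tenant_locate(tenant_id).await?;
            tenant_locate_response_to_conn_info(&locate_result)
        }
    }
}
```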
@@ -1622,6 +1543,36 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res println!("Starting existing endpoint {endpoint_id}..."); endpoint.start(args).await?; } + EndpointCmd::UpdatePageservers(args) => { + let endpoint_id = &args.endpoint_id; + let endpoint = cplane + .endpoints + .get(endpoint_id.as_str()) + .with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?; + let prefer_protocol = if endpoint.grpc { + PageserverProtocol::Grpc + } else { + PageserverProtocol::Libpq + }; + let mut pageserver_conninfo = match args.pageserver_id { + Some(pageserver_id) => { + let conf = env.get_pageserver_conf(pageserver_id)?; + local_pageserver_conf_to_conn_info(conf)? + } + None => { + let storage_controller = StorageController::from_env(env); + let locate_result = + storage_controller.tenant_locate(endpoint.tenant_id).await?; + + tenant_locate_response_to_conn_info(&locate_result)? + } + }; + pageserver_conninfo.prefer_protocol = prefer_protocol; + + endpoint + .update_pageservers_in_config(&pageserver_conninfo) + .await?; + } EndpointCmd::Reconfigure(args) => { let endpoint_id = &args.endpoint_id; let endpoint = cplane @@ -1636,22 +1587,7 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res }; let mut pageserver_conninfo = if let Some(ps_id) = args.endpoint_pageserver_id { let conf = env.get_pageserver_conf(ps_id)?; - let ps_conninfo = pageserver_conf_to_shard_conn_info(conf)?; - let shard_info = PageserverShardInfo { - pageservers: vec![ps_conninfo], - }; - - // If caller is telling us what pageserver to use, this is not a tenant which is - // fully managed by storage controller, therefore not sharded. - let shards: HashMap<_, _> = vec![(ShardIndex::unsharded(), shard_info)] - .into_iter() - .collect(); - PageserverConnectionInfo { - shard_count: ShardCount::unsharded(), - stripe_size: None, - shards, - prefer_protocol, - } + local_pageserver_conf_to_conn_info(conf)? } else { // Look up the currently attached location of the tenant, and its striping metadata, // to pass these on to postgres. @@ -1669,6 +1605,14 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res .reconfigure(Some(&pageserver_conninfo), safekeepers, None) .await?; } + EndpointCmd::RefreshConfiguration(args) => { + let endpoint_id = &args.endpoint_id; + let endpoint = cplane + .endpoints + .get(endpoint_id.as_str()) + .with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?; + endpoint.refresh_configuration().await?; + } EndpointCmd::Stop(args) => { let endpoint_id = &args.endpoint_id; let endpoint = cplane diff --git a/control_plane/src/endpoint.rs b/control_plane/src/endpoint.rs index 149ea07a6b..814ee2a52f 100644 --- a/control_plane/src/endpoint.rs +++ b/control_plane/src/endpoint.rs @@ -79,7 +79,7 @@ use spki::der::Decode; use spki::{SubjectPublicKeyInfo, SubjectPublicKeyInfoRef}; use tracing::debug; use utils::id::{NodeId, TenantId, TimelineId}; -use utils::shard::{ShardIndex, ShardNumber}; +use utils::shard::{ShardCount, ShardIndex, ShardNumber}; use pageserver_api::config::DEFAULT_GRPC_LISTEN_PORT as DEFAULT_PAGESERVER_GRPC_PORT; use postgres_connection::parse_host_port; @@ -728,14 +728,13 @@ impl Endpoint { // For the sake of backwards-compatibility, also fill in 'pageserver_connstring' // + // XXX: I believe this is not really needed, except to make + // test_forward_compatibility happy. + // // Use a closure so that we can conviniently return None in the middle of the // loop. 
let pageserver_connstring: Option = (|| { - let num_shards = if args.pageserver_conninfo.shard_count.is_unsharded() { - 1 - } else { - args.pageserver_conninfo.shard_count.0 - }; + let num_shards = args.pageserver_conninfo.shard_count.count(); let mut connstrings = Vec::new(); for shard_no in 0..num_shards { let shard_index = ShardIndex { @@ -827,6 +826,7 @@ impl Endpoint { autoprewarm: args.autoprewarm, offload_lfc_interval_seconds: args.offload_lfc_interval_seconds, suspend_timeout_seconds: -1, // Only used in neon_local. + databricks_settings: None, }; // this strange code is needed to support respec() in tests @@ -971,7 +971,9 @@ impl Endpoint { | ComputeStatus::Configuration | ComputeStatus::TerminationPendingFast | ComputeStatus::TerminationPendingImmediate - | ComputeStatus::Terminated => { + | ComputeStatus::Terminated + | ComputeStatus::RefreshConfigurationPending + | ComputeStatus::RefreshConfiguration => { bail!("unexpected compute status: {:?}", state.status) } } @@ -994,6 +996,27 @@ impl Endpoint { Ok(()) } + // Update the pageservers in the spec file of the endpoint. This is useful to test the spec refresh scenario. + pub async fn update_pageservers_in_config( + &self, + pageserver_conninfo: &PageserverConnectionInfo, + ) -> Result<()> { + let config_path = self.endpoint_path().join("config.json"); + let mut config: ComputeConfig = { + let file = std::fs::File::open(&config_path)?; + serde_json::from_reader(file)? + }; + + let mut spec = config.spec.unwrap(); + spec.pageserver_connection_info = Some(pageserver_conninfo.clone()); + config.spec = Some(spec); + + let file = std::fs::File::create(&config_path)?; + serde_json::to_writer_pretty(file, &config)?; + + Ok(()) + } + // Call the /status HTTP API pub async fn get_status(&self) -> Result { let client = reqwest::Client::new(); @@ -1156,6 +1179,33 @@ impl Endpoint { Ok(response) } + pub async fn refresh_configuration(&self) -> Result<()> { + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(30)) + .build() + .unwrap(); + let response = client + .post(format!( + "http://{}:{}/refresh_configuration", + self.internal_http_address.ip(), + self.internal_http_address.port() + )) + .send() + .await?; + + let status = response.status(); + if !(status.is_client_error() || status.is_server_error()) { + Ok(()) + } else { + let url = response.url().to_owned(); + let msg = match response.text().await { + Ok(err_body) => format!("Error: {err_body}"), + Err(_) => format!("Http error ({}) at {}.", status.as_u16(), url), + }; + Err(anyhow::anyhow!(msg)) + } + } + pub fn connstr(&self, user: &str, db_name: &str) -> String { format!( "postgresql://{}@{}:{}/{}", @@ -1167,9 +1217,11 @@ impl Endpoint { } } -pub fn pageserver_conf_to_shard_conn_info( +/// If caller is telling us what pageserver to use, this is not a tenant which is +/// fully managed by storage controller, therefore not sharded. 
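The two endpoint helpers added above are intended for exercising the spec-refresh path. A hedged usage sketch combining them, with error handling elided and the surrounding test harness assumed:

```rust
// Hedged sketch: rewrite the endpoint's config.json to point at different
// pageservers, then ask the running compute to reload its spec without a
// restart. Both methods are the ones introduced in this diff.
async fn repoint_and_refresh(
    endpoint: &Endpoint,
    new_conninfo: &PageserverConnectionInfo,
) -> anyhow::Result<()> {
    endpoint.update_pageservers_in_config(new_conninfo).await?;
    endpoint.refresh_configuration().await?;
    Ok(())
}
```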
+pub fn local_pageserver_conf_to_conn_info( conf: &crate::local_env::PageServerConf, -) -> Result { +) -> Result { let libpq_url = { let (host, port) = parse_host_port(&conf.listen_pg_addr)?; let port = port.unwrap_or(5432); @@ -1182,10 +1234,24 @@ pub fn pageserver_conf_to_shard_conn_info( } else { None }; - Ok(PageserverShardConnectionInfo { - id: Some(conf.id.to_string()), + let ps_conninfo = PageserverShardConnectionInfo { + id: Some(conf.id), libpq_url, grpc_url, + }; + + let shard_info = PageserverShardInfo { + pageservers: vec![ps_conninfo], + }; + + let shards: HashMap<_, _> = vec![(ShardIndex::unsharded(), shard_info)] + .into_iter() + .collect(); + Ok(PageserverConnectionInfo { + shard_count: ShardCount::unsharded(), + stripe_size: None, + shards, + prefer_protocol: PageserverProtocol::default(), }) } @@ -1210,7 +1276,7 @@ pub fn tenant_locate_response_to_conn_info( let shard_info = PageserverShardInfo { pageservers: vec![PageserverShardConnectionInfo { - id: Some(shard.node_id.to_string()), + id: Some(shard.node_id), libpq_url, grpc_url, }], @@ -1222,7 +1288,7 @@ pub fn tenant_locate_response_to_conn_info( let stripe_size = if response.shard_params.count.is_unsharded() { None } else { - Some(response.shard_params.stripe_size.0) + Some(response.shard_params.stripe_size) }; Ok(PageserverConnectionInfo { shard_count: response.shard_params.count, diff --git a/docker-compose/compute_wrapper/var/db/postgres/configs/config.json b/docker-compose/compute_wrapper/var/db/postgres/configs/config.json index 60e232425b..06a47f437d 100644 --- a/docker-compose/compute_wrapper/var/db/postgres/configs/config.json +++ b/docker-compose/compute_wrapper/var/db/postgres/configs/config.json @@ -120,6 +120,11 @@ "value": "host=pageserver port=6400", "vartype": "string" }, + { + "name": "neon.pageserver_grpc_urls", + "value": "grpc://pageserver:6401/", + "vartype": "string" + }, { "name": "max_replication_write_lag", "value": "500MB", diff --git a/docker-compose/pageserver_config/pageserver.toml b/docker-compose/pageserver_config/pageserver.toml index 81445ed412..fe7a5744be 100644 --- a/docker-compose/pageserver_config/pageserver.toml +++ b/docker-compose/pageserver_config/pageserver.toml @@ -1,6 +1,7 @@ broker_endpoint='http://storage_broker:50051' pg_distrib_dir='/usr/local/' listen_pg_addr='0.0.0.0:6400' +listen_grpc_addr='0.0.0.0:6401' listen_http_addr='0.0.0.0:9898' remote_storage={ endpoint='http://minio:9000', bucket_name='neon', bucket_region='eu-north-1', prefix_in_bucket='/pageserver' } control_plane_api='http://0.0.0.0:6666' # No storage controller in docker compose, specify a junk address diff --git a/docs/rfcs/2025-07-07-node-deletion-api-improvement.md b/docs/rfcs/2025-07-07-node-deletion-api-improvement.md new file mode 100644 index 0000000000..47dadaee35 --- /dev/null +++ b/docs/rfcs/2025-07-07-node-deletion-api-improvement.md @@ -0,0 +1,246 @@ +# Node deletion API improvement + +Created on 2025-07-07 +Implemented on _TBD_ + +## Summary + +This RFC describes improvements to the storage controller API for gracefully deleting pageserver +nodes. + +## Motivation + +The basic node deletion API introduced in [#8226](https://github.com/neondatabase/neon/issues/8333) +has several limitations: + +- Deleted nodes can re-add themselves if they restart (e.g., a flaky node that keeps restarting and +we cannot reach via SSH to stop the pageserver). 
This issue has been resolved by tombstone +mechanism in [#12036](https://github.com/neondatabase/neon/issues/12036) +- Process of node deletion is not graceful, i.e. it just imitates a node failure + +In this context, "graceful" node deletion means that users do not experience any disruption or +negative effects, provided the system remains in a healthy state (i.e., the remaining pageservers +can handle the workload and all requirements are met). To achieve this, the system must perform +live migration of all tenant shards from the node being deleted while the node is still running +and continue processing all incoming requests. The node is removed only after all tenant shards +have been safely migrated. + +Although live migrations can be achieved with the drain functionality, it leads to incorrect shard +placement, such as not matching availability zones. This results in unnecessary work to optimize +the placement that was just recently performed. + +If we delete a node before its tenant shards are fully moved, the new node won't have all the +needed data (e.g. heatmaps) ready. This means user requests to the new node will be much slower at +first. If there are many tenant shards, this slowdown affects a huge amount of users. + +Graceful node deletion is more complicated and can introduce new issues. It takes longer because +live migration of each tenant shard can last several minutes. Using non-blocking accessors may +also cause deletion to wait if other processes are holding inner state lock. It also gets trickier +because we need to handle other requests, like drain and fill, at the same time. + +## Impacted components (e.g. pageserver, safekeeper, console, etc) + +- storage controller +- pageserver (indirectly) + +## Proposed implementation + +### Tombstones + +To resolve the problem of deleted nodes re-adding themselves, a tombstone mechanism was introduced +as part of the node stored information. Each node has a separate `NodeLifecycle` field with two +possible states: `Active` and `Deleted`. When node deletion completes, the database row is not +deleted but instead has its `NodeLifecycle` column switched to `Deleted`. Nodes with `Deleted` +lifecycle are treated as if the row is absent for most handlers, with several exceptions: reattach +and register functionality must be aware of tombstones. Additionally, new debug handlers are +available for listing and deleting tombstones via the `/debug/v1/tombstone` path. + +### Gracefulness + +The problem of making node deletion graceful is complex and involves several challenges: + +- **Cancellable**: The operation must be cancellable to allow administrators to abort the process +if needed, e.g. if run by mistake. +- **Non-blocking**: We don't want to block deployment operations like draining/filling on the node +deletion process. We need clear policies for handling concurrent operations: what happens when a +drain/fill request arrives while deletion is in progress, and what happens when a delete request +arrives while drain/fill is in progress. +- **Persistent**: If the storage controller restarts during this long-running operation, we must +preserve progress and automatically resume the deletion process after the storage controller +restarts. +- **Migrated correctly**: We cannot simply use the existing drain mechanism for nodes scheduled +for deletion, as this would move shards to irrelevant locations. The drain process expects the +node to return, so it only moves shards to backup locations, not to their preferred AZs. 
It also +leaves secondary locations unmoved. This could result in unnecessary load on the storage +controller and inefficient resource utilization. +- **Force option**: Administrators need the ability to force immediate, non-graceful deletion when +time constraints or emergency situations require it, bypassing the normal graceful migration +process. + +See below for a detailed breakdown of the proposed changes and mechanisms. + +#### Node lifecycle + +New `NodeLifecycle` enum and a matching database field with these values: +- `Active`: The normal state. All operations are allowed. +- `ScheduledForDeletion`: The node is marked to be deleted soon. Deletion may be in progress or +will happen later, but the node will eventually be removed. All operations are allowed. +- `Deleted`: The node is fully deleted. No operations are allowed, and the node cannot be brought +back. The only action left is to remove its record from the database. Any attempt to register a +node in this state will fail. + +This state persists across storage controller restarts. + +**State transition** +``` + +--------------------+ + +---| Active |<---------------------+ + | +--------------------+ | + | ^ | + | start_node_delete | cancel_node_delete | + v | | + +----------------------------------+ | + | ScheduledForDeletion | | + +----------------------------------+ | + | | + | node_register | + | | + | delete_node (at the finish) | + | | + v | + +---------+ tombstone_delete +----------+ + | Deleted |-------------------------------->| no row | + +---------+ +----------+ +``` + +#### NodeSchedulingPolicy::Deleting + +A `Deleting` variant to the `NodeSchedulingPolicy` enum. This means the deletion function is +running for the node right now. Only one node can have the `Deleting` policy at a time. + +The `NodeSchedulingPolicy::Deleting` state is persisted in the database. However, after a storage +controller restart, any node previously marked as `Deleting` will have its scheduling policy reset +to `Pause`. The policy will only transition back to `Deleting` when the deletion operation is +actively started again, as triggered by the node's `NodeLifecycle::ScheduledForDeletion` state. + +`NodeSchedulingPolicy` transition details: +1. When `node_delete` begins, set the policy to `NodeSchedulingPolicy::Deleting`. +2. If `node_delete` is cancelled (for example, due to a concurrent drain operation), revert the +policy to its previous value. The policy is persisted in storcon DB. +3. After `node_delete` completes, the final value of the scheduling policy is irrelevant, since +`NodeLifecycle::Deleted` prevents any further access to this field. + +The deletion process cannot be initiated for nodes currently undergoing deployment-related +operations (`Draining`, `Filling`, or `PauseForRestart` policies). Deletion will only be triggered +once the node transitions to either the `Active` or `Pause` state. + +#### OperationTracker + +A replacement for `Option ongoing_operation`, the `OperationTracker` is a +dedicated service state object responsible for managing all long-running node operations (drain, +fill, delete) with robust concurrency control. + +Key responsibilities: +- Orchestrates the execution of operations +- Supports cancellation of currently running operations +- Enforces operation constraints, e.g. 
allowing only single drain/fill operation at a time +- Persists deletion state, enabling recovery of pending deletions across restarts +- Ensures thread safety across concurrent requests + +#### Attached tenant shard processing + +When deleting a node, handle each attached tenant shard as follows: + +1. Pick the best node to become the new attached (the candidate). +2. If the candidate already has this shard as a secondary: + - Create a new secondary for the shard on another suitable node. + Otherwise: + - Create a secondary for the shard on the candidate node. +3. Wait until all secondaries are ready and pre-warmed. +4. Promote the candidate's secondary to attached. +5. Remove the secondary from the node being deleted. + +This process safely moves all attached shards before deleting the node. + +#### Secondary tenant shard processing + +When deleting a node, handle each secondary tenant shard as follows: + +1. Choose the best node to become the new secondary. +2. Create a secondary for the shard on that node. +3. Wait until the new secondary is ready. +4. Remove the secondary from the node being deleted. + +This ensures all secondary shards are safely moved before deleting the node. + +### Reliability, failure modes and corner cases + +In case of a storage controller failure and following restart, the system behavior depends on the +`NodeLifecycle` state: + +- If `NodeLifecycle` is `Active`: No action is taken for this node. +- If `NodeLifecycle` is `Deleted`: The node will not be re-added. +- If `NodeLifecycle` is `ScheduledForDeletion`: A deletion background task will be launched for +this node. + +In case of a pageserver node failure during deletion, the behavior depends on the `force` flag: +- If `force` is set: The node deletion will proceed regardless of the node's availability. +- If `force` is not set: The deletion will be retried a limited number of times. If the node +remains unavailable, the deletion process will pause and automatically resume when the node +becomes healthy again. + +### Operations concurrency + +The following sections describe the behavior when different types of requests arrive at the storage +controller and how they interact with ongoing operations. + +#### Delete request + +Handler: `PUT /control/v1/node/:node_id/delete` + +1. If node lifecycle is `NodeLifecycle::ScheduledForDeletion`: + - Return `200 OK`: there is already an ongoing deletion request for this node +2. Update & persist lifecycle to `NodeLifecycle::ScheduledForDeletion` +3. Persist current scheduling policy +4. If there is no active operation (drain/fill/delete): + - Run deletion process for this node + +#### Cancel delete request + +Handler: `DELETE /control/v1/node/:node_id/delete` + +1. If node lifecycle is not `NodeLifecycle::ScheduledForDeletion`: + - Return `404 Not Found`: there is no current deletion request for this node +2. If the active operation is deleting this node, cancel it +3. Update & persist lifecycle to `NodeLifecycle::Active` +4. Restore the last scheduling policy from persistence + +#### Drain/fill request + +1. If there are already ongoing drain/fill processes: + - Return `409 Conflict`: queueing of drain/fill processes is not supported +2. If there is an ongoing delete process: + - Cancel it and wait until it is cancelled +3. Run the drain/fill process +4. After the drain/fill process is cancelled or finished: + - Try to find another candidate to delete and run the deletion process for that node + +#### Drain/fill cancel request + +1. 
If the active operation is not the related process: + - Return `400 Bad Request`: cancellation request is incorrect, operations are not the same +2. Cancel the active operation +3. Try to find another candidate to delete and run the deletion process for that node + +## Definition of Done + +- [x] Fix flaky node scenario and introduce related debug handlers +- [ ] Node deletion intent is persistent - a node will be eventually deleted after a deletion +request regardless of draining/filling requests and restarts +- [ ] Node deletion can be graceful - deletion completes only after moving all tenant shards to +recommended locations +- [ ] Deploying does not break due to long deletions - drain/fill operations override deletion +process and deletion resumes after drain/fill completes +- [ ] `force` flag is implemented and provides fast, failure-tolerant node removal (e.g., when a +pageserver node does not respond) +- [ ] Legacy delete handler code is removed from storage_controller, test_runner, and storcon_cli diff --git a/libs/compute_api/src/responses.rs b/libs/compute_api/src/responses.rs index 5b8fc49750..a918644e4c 100644 --- a/libs/compute_api/src/responses.rs +++ b/libs/compute_api/src/responses.rs @@ -68,11 +68,15 @@ pub enum LfcPrewarmState { /// We tried to fetch the corresponding LFC state from the endpoint storage, /// but received `Not Found 404`. This should normally happen only during the /// first endpoint start after creation with `autoprewarm: true`. + /// This may also happen if LFC is turned off or not initialized /// /// During the orchestrated prewarm via API, when a caller explicitly /// provides the LFC state key to prewarm from, it's the caller responsibility /// to handle this status as an error state in this case. Skipped, + /// LFC prewarm was cancelled. Some pages in LFC cache may be prewarmed if query + /// has started working before cancellation + Cancelled, } impl Display for LfcPrewarmState { @@ -83,6 +87,7 @@ impl Display for LfcPrewarmState { LfcPrewarmState::Completed => f.write_str("Completed"), LfcPrewarmState::Skipped => f.write_str("Skipped"), LfcPrewarmState::Failed { error } => write!(f, "Error({error})"), + LfcPrewarmState::Cancelled => f.write_str("Cancelled"), } } } @@ -97,6 +102,7 @@ pub enum LfcOffloadState { Failed { error: String, }, + Skipped, } #[derive(Serialize, Debug, Clone, PartialEq)] @@ -108,11 +114,10 @@ pub enum PromoteState { Failed { error: String }, } -#[derive(Deserialize, Serialize, Default, Debug, Clone)] +#[derive(Deserialize, Default, Debug)] #[serde(rename_all = "snake_case")] -/// Result of /safekeepers_lsn -pub struct SafekeepersLsn { - pub safekeepers: String, +pub struct PromoteConfig { + pub spec: ComputeSpec, pub wal_flush_lsn: utils::lsn::Lsn, } @@ -173,6 +178,11 @@ pub enum ComputeStatus { TerminationPendingImmediate, // Terminated Postgres Terminated, + // A spec refresh is being requested + RefreshConfigurationPending, + // A spec refresh is being applied. We cannot refresh configuration again until the current + // refresh is done, i.e., signal_refresh_configuration() will return 500 error. 
+ RefreshConfiguration, } #[derive(Deserialize, Serialize)] @@ -185,6 +195,10 @@ impl Display for ComputeStatus { match self { ComputeStatus::Empty => f.write_str("empty"), ComputeStatus::ConfigurationPending => f.write_str("configuration-pending"), + ComputeStatus::RefreshConfiguration => f.write_str("refresh-configuration"), + ComputeStatus::RefreshConfigurationPending => { + f.write_str("refresh-configuration-pending") + } ComputeStatus::Init => f.write_str("init"), ComputeStatus::Running => f.write_str("running"), ComputeStatus::Configuration => f.write_str("configuration"), diff --git a/libs/compute_api/src/spec.rs b/libs/compute_api/src/spec.rs index 8cfd6b974a..12d825e1bf 100644 --- a/libs/compute_api/src/spec.rs +++ b/libs/compute_api/src/spec.rs @@ -12,9 +12,9 @@ use regex::Regex; use remote_storage::RemotePath; use serde::{Deserialize, Serialize}; use url::Url; -use utils::id::{TenantId, TimelineId}; +use utils::id::{NodeId, TenantId, TimelineId}; use utils::lsn::Lsn; -use utils::shard::{ShardCount, ShardIndex}; +use utils::shard::{ShardCount, ShardIndex, ShardNumber, ShardStripeSize}; use crate::responses::TlsConfig; @@ -115,10 +115,18 @@ pub struct ComputeSpec { /// The goal is to use method 1. everywhere. But for backwards-compatibility with old /// versions of the control plane, `compute_ctl` will check 2. and 3. if the /// `pageserver_connection_info` field is missing. + /// + /// If both `pageserver_connection_info` and `pageserver_connstring`+`shard_stripe_size` are + /// given, they must contain the same information. pub pageserver_connection_info: Option, pub pageserver_connstring: Option, + /// Stripe size for pageserver sharding, in pages. This is set together with the legacy + /// `pageserver_connstring` field. When the modern `pageserver_connection_info` field is used, + /// the stripe size is stored in `pageserver_connection_info.stripe_size` instead. + pub shard_stripe_size: Option, + // More neon ids that we expose to the compute_ctl // and to postgres as neon extension GUCs. pub project_id: Option, @@ -151,10 +159,6 @@ pub struct ComputeSpec { pub pgbouncer_settings: Option>, - // Stripe size for pageserver sharding, in pages - #[serde(default)] - pub shard_stripe_size: Option, - /// Local Proxy configuration used for JWT authentication #[serde(default)] pub local_proxy_config: Option, @@ -205,6 +209,9 @@ pub struct ComputeSpec { /// /// We use this value to derive other values, such as the installed extensions metric. pub suspend_timeout_seconds: i64, + + // Databricks specific options for compute instance. + pub databricks_settings: Option, } /// Feature flag to signal `compute_ctl` to enable certain experimental functionality. @@ -232,14 +239,122 @@ pub struct PageserverConnectionInfo { pub shard_count: ShardCount, /// INVARIANT: null if shard_count is 0, otherwise non-null and immutable - pub stripe_size: Option, + pub stripe_size: Option, pub shards: HashMap, + /// If the compute supports both protocols, this indicates which one it should use. The compute + /// may use other available protocols too, if it doesn't support the preferred one. The URL's + /// for the protocol specified here must be present for all shards, i.e. do not mark a protocol + /// as preferred if it cannot actually be used with all the pageservers. #[serde(default)] pub prefer_protocol: PageserverProtocol, } +/// Extract PageserverConnectionInfo from a comma-separated list of libpq connection strings. 
+/// +/// This is used for backwards-compatibility, to parse the legacy +/// [ComputeSpec::pageserver_connstring] field, or the 'neon.pageserver_connstring' GUC. Nowadays, +/// the 'pageserver_connection_info' field should be used instead. +impl PageserverConnectionInfo { + pub fn from_connstr( + connstr: &str, + stripe_size: Option, + ) -> Result { + let shard_infos: Vec<_> = connstr + .split(',') + .map(|connstr| PageserverShardInfo { + pageservers: vec![PageserverShardConnectionInfo { + id: None, + libpq_url: Some(connstr.to_string()), + grpc_url: None, + }], + }) + .collect(); + + match shard_infos.len() { + 0 => anyhow::bail!("empty connection string"), + 1 => { + // We assume that if there's only connection string, it means "unsharded", + // rather than a sharded system with just a single shard. The latter is + // possible in principle, but we never do it. + let shard_count = ShardCount::unsharded(); + let only_shard = shard_infos.first().unwrap().clone(); + let shards = vec![(ShardIndex::unsharded(), only_shard)]; + Ok(PageserverConnectionInfo { + shard_count, + stripe_size: None, + shards: shards.into_iter().collect(), + prefer_protocol: PageserverProtocol::Libpq, + }) + } + n => { + if stripe_size.is_none() { + anyhow::bail!("{n} shards but no stripe_size"); + } + let shard_count = ShardCount(n.try_into()?); + let shards = shard_infos + .into_iter() + .enumerate() + .map(|(idx, shard_info)| { + ( + ShardIndex { + shard_count, + shard_number: ShardNumber( + idx.try_into().expect("shard number fits in u8"), + ), + }, + shard_info, + ) + }) + .collect(); + Ok(PageserverConnectionInfo { + shard_count, + stripe_size, + shards, + prefer_protocol: PageserverProtocol::Libpq, + }) + } + } + } + + /// Convenience routine to get the connection string for a shard. + pub fn shard_url( + &self, + shard_number: ShardNumber, + protocol: PageserverProtocol, + ) -> anyhow::Result<&str> { + let shard_index = ShardIndex { + shard_number, + shard_count: self.shard_count, + }; + let shard = self.shards.get(&shard_index).ok_or_else(|| { + anyhow::anyhow!("shard connection info missing for shard {}", shard_index) + })?; + + // Just use the first pageserver in the list. That's good enough for this + // convenience routine; if you need more control, like round robin policy or + // failover support, roll your own. (As of this writing, we never have more than + // one pageserver per shard anyway, but that will change in the future.) 
+ let pageserver = shard + .pageservers + .first() + .ok_or(anyhow::anyhow!("must have at least one pageserver"))?; + + let result = match protocol { + PageserverProtocol::Grpc => pageserver + .grpc_url + .as_ref() + .ok_or(anyhow::anyhow!("no grpc_url for shard {shard_index}"))?, + PageserverProtocol::Libpq => pageserver + .libpq_url + .as_ref() + .ok_or(anyhow::anyhow!("no libpq_url for shard {shard_index}"))?, + }; + Ok(result) + } +} + #[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] pub struct PageserverShardInfo { pub pageservers: Vec, @@ -247,7 +362,7 @@ pub struct PageserverShardInfo { #[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)] pub struct PageserverShardConnectionInfo { - pub id: Option, + pub id: Option, pub libpq_url: Option, pub grpc_url: Option, } diff --git a/libs/http-utils/src/endpoint.rs b/libs/http-utils/src/endpoint.rs index a61bf8e08a..c23d95f3e6 100644 --- a/libs/http-utils/src/endpoint.rs +++ b/libs/http-utils/src/endpoint.rs @@ -558,11 +558,11 @@ async fn add_request_id_header_to_response( mut res: Response, req_info: RequestInfo, ) -> Result, ApiError> { - if let Some(request_id) = req_info.context::() { - if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) { - res.headers_mut() - .insert(&X_REQUEST_ID_HEADER, request_header_value); - }; + if let Some(request_id) = req_info.context::() + && let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) + { + res.headers_mut() + .insert(&X_REQUEST_ID_HEADER, request_header_value); }; Ok(res) diff --git a/libs/http-utils/src/server.rs b/libs/http-utils/src/server.rs index f93f71c962..ce90b8d710 100644 --- a/libs/http-utils/src/server.rs +++ b/libs/http-utils/src/server.rs @@ -72,10 +72,10 @@ impl Server { if err.is_incomplete_message() || err.is_closed() || err.is_timeout() { return true; } - if let Some(inner) = err.source() { - if let Some(io) = inner.downcast_ref::() { - return suppress_io_error(io); - } + if let Some(inner) = err.source() + && let Some(io) = inner.downcast_ref::() + { + return suppress_io_error(io); } false } diff --git a/libs/metrics/src/lib.rs b/libs/metrics/src/lib.rs index 41873cdcd6..6cf27abcaf 100644 --- a/libs/metrics/src/lib.rs +++ b/libs/metrics/src/lib.rs @@ -129,6 +129,12 @@ impl InfoMetric { } } +impl Default for InfoMetric { + fn default() -> Self { + InfoMetric::new(L::default()) + } +} + impl> InfoMetric { pub fn with_metric(label: L, metric: M) -> Self { Self { diff --git a/libs/neon-shmem/src/hash.rs b/libs/neon-shmem/src/hash.rs index 28309fa5de..eb9b01d727 100644 --- a/libs/neon-shmem/src/hash.rs +++ b/libs/neon-shmem/src/hash.rs @@ -405,7 +405,7 @@ where // TODO: An Iterator might be nicer. The communicator's clock algorithm needs to // _slowly_ iterate through all buckets with its clock hand, without holding a lock. // If we switch to an Iterator, it must not hold the lock. 
- pub fn get_at_bucket(&self, pos: usize) -> Option> { + pub fn get_at_bucket(&self, pos: usize) -> Option> { let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read(); if pos >= map.buckets.len() { return None; diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index 7c7c65fb70..230c1f46ea 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -1500,6 +1500,7 @@ pub struct TimelineArchivalConfigRequest { #[derive(Serialize, Deserialize, PartialEq, Eq, Clone)] pub struct TimelinePatchIndexPartRequest { pub rel_size_migration: Option, + pub rel_size_migrated_at: Option, pub gc_compaction_last_completed_lsn: Option, pub applied_gc_cutoff_lsn: Option, #[serde(default)] @@ -1533,10 +1534,10 @@ pub enum RelSizeMigration { /// `None` is the same as `Some(RelSizeMigration::Legacy)`. Legacy, /// The tenant is migrating to the new rel_size format. Both old and new rel_size format are - /// persisted in the index part. The read path will read both formats and merge them. + /// persisted in the storage. The read path will read both formats and validate them. Migrating, /// The tenant has migrated to the new rel_size format. Only the new rel_size format is persisted - /// in the index part, and the read path will not read the old format. + /// in the storage, and the read path will not read the old format. Migrated, } @@ -1619,6 +1620,7 @@ pub struct TimelineInfo { /// The status of the rel_size migration. pub rel_size_migration: Option, + pub rel_size_migrated_at: Option, /// Whether the timeline is invisible in synthetic size calculations. pub is_invisible: Option, diff --git a/libs/postgres_ffi/Cargo.toml b/libs/postgres_ffi/Cargo.toml index d4fec6cbe9..23fabeccd2 100644 --- a/libs/postgres_ffi/Cargo.toml +++ b/libs/postgres_ffi/Cargo.toml @@ -9,10 +9,7 @@ regex.workspace = true bytes.workspace = true anyhow.workspace = true crc32c.workspace = true -criterion.workspace = true once_cell.workspace = true -log.workspace = true -memoffset.workspace = true pprof.workspace = true thiserror.workspace = true serde.workspace = true @@ -22,6 +19,7 @@ tracing.workspace = true postgres_versioninfo.workspace = true [dev-dependencies] +criterion.workspace = true env_logger.workspace = true postgres.workspace = true diff --git a/libs/postgres_ffi/src/controlfile_utils.rs b/libs/postgres_ffi/src/controlfile_utils.rs index eaa9450294..d6d17ce3fb 100644 --- a/libs/postgres_ffi/src/controlfile_utils.rs +++ b/libs/postgres_ffi/src/controlfile_utils.rs @@ -34,9 +34,8 @@ const SIZEOF_CONTROLDATA: usize = size_of::(); impl ControlFileData { /// Compute the offset of the `crc` field within the `ControlFileData` struct. /// Equivalent to offsetof(ControlFileData, crc) in C. - // Someday this can be const when the right compiler features land. 
- fn pg_control_crc_offset() -> usize { - memoffset::offset_of!(ControlFileData, crc) + const fn pg_control_crc_offset() -> usize { + std::mem::offset_of!(ControlFileData, crc) } /// diff --git a/libs/postgres_ffi/src/nonrelfile_utils.rs b/libs/postgres_ffi/src/nonrelfile_utils.rs index e3e7133b94..f6693d4ec1 100644 --- a/libs/postgres_ffi/src/nonrelfile_utils.rs +++ b/libs/postgres_ffi/src/nonrelfile_utils.rs @@ -4,12 +4,11 @@ use crate::pg_constants; use crate::transaction_id_precedes; use bytes::BytesMut; -use log::*; use super::bindings::MultiXactId; pub fn transaction_id_set_status(xid: u32, status: u8, page: &mut BytesMut) { - trace!( + tracing::trace!( "handle_apply_request for RM_XACT_ID-{} (1-commit, 2-abort, 3-sub_commit)", status ); diff --git a/libs/postgres_ffi/src/waldecoder_handler.rs b/libs/postgres_ffi/src/waldecoder_handler.rs index 9cd40645ec..563a3426a0 100644 --- a/libs/postgres_ffi/src/waldecoder_handler.rs +++ b/libs/postgres_ffi/src/waldecoder_handler.rs @@ -14,7 +14,6 @@ use super::xlog_utils::*; use crate::WAL_SEGMENT_SIZE; use bytes::{Buf, BufMut, Bytes, BytesMut}; use crc32c::*; -use log::*; use std::cmp::min; use std::num::NonZeroU32; use utils::lsn::Lsn; @@ -236,7 +235,7 @@ impl WalStreamDecoderHandler for WalStreamDecoder { // XLOG_SWITCH records are special. If we see one, we need to skip // to the next WAL segment. let next_lsn = if xlogrec.is_xlog_switch_record() { - trace!("saw xlog switch record at {}", self.lsn); + tracing::trace!("saw xlog switch record at {}", self.lsn); self.lsn + self.lsn.calc_padding(WAL_SEGMENT_SIZE as u64) } else { // Pad to an 8-byte boundary diff --git a/libs/postgres_ffi/src/xlog_utils.rs b/libs/postgres_ffi/src/xlog_utils.rs index 134baf5ff7..913e6b453f 100644 --- a/libs/postgres_ffi/src/xlog_utils.rs +++ b/libs/postgres_ffi/src/xlog_utils.rs @@ -23,8 +23,6 @@ use crate::{WAL_SEGMENT_SIZE, XLOG_BLCKSZ}; use bytes::BytesMut; use bytes::{Buf, Bytes}; -use log::*; - use serde::Serialize; use std::ffi::{CString, OsStr}; use std::fs::File; @@ -235,7 +233,7 @@ pub fn find_end_of_wal( let mut curr_lsn = start_lsn; let mut buf = [0u8; XLOG_BLCKSZ]; let pg_version = MY_PGVERSION; - debug!("find_end_of_wal PG_VERSION: {}", pg_version); + tracing::debug!("find_end_of_wal PG_VERSION: {}", pg_version); let mut decoder = WalStreamDecoder::new(start_lsn, pg_version); @@ -247,7 +245,7 @@ pub fn find_end_of_wal( match open_wal_segment(&seg_file_path)? 
{ None => { // no more segments - debug!( + tracing::debug!( "find_end_of_wal reached end at {:?}, segment {:?} doesn't exist", result, seg_file_path ); @@ -260,7 +258,7 @@ pub fn find_end_of_wal( while curr_lsn.segment_number(wal_seg_size) == segno { let bytes_read = segment.read(&mut buf)?; if bytes_read == 0 { - debug!( + tracing::debug!( "find_end_of_wal reached end at {:?}, EOF in segment {:?} at offset {}", result, seg_file_path, @@ -276,7 +274,7 @@ pub fn find_end_of_wal( match decoder.poll_decode() { Ok(Some(record)) => result = record.0, Err(e) => { - debug!( + tracing::debug!( "find_end_of_wal reached end at {:?}, decode error: {:?}", result, e ); diff --git a/proxy/subzero_core/.gitignore b/libs/proxy/subzero_core/.gitignore similarity index 100% rename from proxy/subzero_core/.gitignore rename to libs/proxy/subzero_core/.gitignore diff --git a/proxy/subzero_core/Cargo.toml b/libs/proxy/subzero_core/Cargo.toml similarity index 100% rename from proxy/subzero_core/Cargo.toml rename to libs/proxy/subzero_core/Cargo.toml diff --git a/proxy/subzero_core/src/lib.rs b/libs/proxy/subzero_core/src/lib.rs similarity index 100% rename from proxy/subzero_core/src/lib.rs rename to libs/proxy/subzero_core/src/lib.rs diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs index 068566e955..90ff39aff1 100644 --- a/libs/proxy/tokio-postgres2/src/client.rs +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -15,6 +15,7 @@ use tokio::sync::mpsc; use crate::cancel_token::RawCancelToken; use crate::codec::{BackendMessages, FrontendMessage, RecordNotices}; use crate::config::{Host, SslMode}; +use crate::connection::gc_bytesmut; use crate::query::RowStream; use crate::simple_query::SimpleQueryStream; use crate::types::{Oid, Type}; @@ -95,20 +96,13 @@ impl InnerClient { Ok(PartialQuery(Some(self))) } - // pub fn send_with_sync(&mut self, f: F) -> Result<&mut Responses, Error> - // where - // F: FnOnce(&mut BytesMut) -> Result<(), Error>, - // { - // self.start()?.send_with_sync(f) - // } - pub fn send_simple_query(&mut self, query: &str) -> Result<&mut Responses, Error> { self.responses.waiting += 1; self.buffer.clear(); // simple queries do not need sync. 
frontend::query(query, &mut self.buffer).map_err(Error::encode)?; - let buf = self.buffer.split().freeze(); + let buf = self.buffer.split(); self.send_message(FrontendMessage::Raw(buf)) } @@ -125,7 +119,7 @@ impl Drop for PartialQuery<'_> { if let Some(client) = self.0.take() { client.buffer.clear(); frontend::sync(&mut client.buffer); - let buf = client.buffer.split().freeze(); + let buf = client.buffer.split(); let _ = client.send_message(FrontendMessage::Raw(buf)); } } @@ -141,7 +135,7 @@ impl<'a> PartialQuery<'a> { client.buffer.clear(); f(&mut client.buffer)?; frontend::flush(&mut client.buffer); - let buf = client.buffer.split().freeze(); + let buf = client.buffer.split(); client.send_message(FrontendMessage::Raw(buf)) } @@ -154,7 +148,7 @@ impl<'a> PartialQuery<'a> { client.buffer.clear(); f(&mut client.buffer)?; frontend::sync(&mut client.buffer); - let buf = client.buffer.split().freeze(); + let buf = client.buffer.split(); let _ = client.send_message(FrontendMessage::Raw(buf)); Ok(&mut self.0.take().unwrap().responses) @@ -191,6 +185,7 @@ impl Client { ssl_mode: SslMode, process_id: i32, secret_key: i32, + write_buf: BytesMut, ) -> Client { Client { inner: InnerClient { @@ -201,7 +196,7 @@ impl Client { waiting: 0, received: 0, }, - buffer: Default::default(), + buffer: write_buf, }, cached_typeinfo: Default::default(), @@ -292,8 +287,35 @@ impl Client { simple_query::batch_execute(self.inner_mut(), query).await } - pub async fn discard_all(&mut self) -> Result { - self.batch_execute("discard all").await + /// Similar to `discard_all`, but it does not clear any query plans + /// + /// This runs in the background, so it can be executed without `await`ing. + pub fn reset_session_background(&mut self) -> Result<(), Error> { + // "CLOSE ALL": closes any cursors + // "SET SESSION AUTHORIZATION DEFAULT": resets the current_user back to the session_user + // "RESET ALL": resets any GUCs back to their session defaults. + // "DEALLOCATE ALL": deallocates any prepared statements + // "UNLISTEN *": stops listening on all channels + // "SELECT pg_advisory_unlock_all();": unlocks all advisory locks + // "DISCARD TEMP;": drops all temporary tables + // "DISCARD SEQUENCES;": deallocates all cached sequence state + + let _responses = self.inner_mut().send_simple_query( + "ROLLBACK; + CLOSE ALL; + SET SESSION AUTHORIZATION DEFAULT; + RESET ALL; + DEALLOCATE ALL; + UNLISTEN *; + SELECT pg_advisory_unlock_all(); + DISCARD TEMP; + DISCARD SEQUENCES;", + )?; + + // Clean up memory usage. + gc_bytesmut(&mut self.inner_mut().buffer); + + Ok(()) } /// Begins a new database transaction. 
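The new `Client::reset_session_background` above replaces the awaited `discard_all` call: it queues one simple-query batch (ROLLBACK, CLOSE ALL, RESET ALL, DEALLOCATE ALL, ...) and returns immediately, leaving the background `Connection` task to flush it and then shrinking the shared write buffer. A minimal usage sketch, assuming the crate is imported as `tokio_postgres2` and a hypothetical pool wrapper (`PooledConn` and `release` are illustrative, not part of the proxy code):

```rust
// Sketch only: reset session state when a connection is returned to a pool,
// without paying a server round trip on the release path.
use tokio_postgres2::{Client, Error};

struct PooledConn {
    client: Client,
}

impl PooledConn {
    /// Called when the connection is checked back into the pool. The reset batch is only
    /// queued here; per the method's contract it does not need to be awaited.
    fn release(&mut self) -> Result<(), Error> {
        self.client.reset_session_background()
    }
}
```

Because the reset is handed off to the background connection task, returning a connection to the pool stays a synchronous, non-blocking operation.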
diff --git a/libs/proxy/tokio-postgres2/src/codec.rs b/libs/proxy/tokio-postgres2/src/codec.rs index 813faa0e35..35f616d229 100644 --- a/libs/proxy/tokio-postgres2/src/codec.rs +++ b/libs/proxy/tokio-postgres2/src/codec.rs @@ -1,13 +1,13 @@ use std::io; -use bytes::{Bytes, BytesMut}; +use bytes::BytesMut; use fallible_iterator::FallibleIterator; use postgres_protocol2::message::backend; use tokio::sync::mpsc::UnboundedSender; use tokio_util::codec::{Decoder, Encoder}; pub enum FrontendMessage { - Raw(Bytes), + Raw(BytesMut), RecordNotices(RecordNotices), } @@ -17,7 +17,10 @@ pub struct RecordNotices { } pub enum BackendMessage { - Normal { messages: BackendMessages }, + Normal { + messages: BackendMessages, + ready: bool, + }, Async(backend::Message), } @@ -40,11 +43,11 @@ impl FallibleIterator for BackendMessages { pub struct PostgresCodec; -impl Encoder for PostgresCodec { +impl Encoder for PostgresCodec { type Error = io::Error; - fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> io::Result<()> { - dst.extend_from_slice(&item); + fn encode(&mut self, item: BytesMut, dst: &mut BytesMut) -> io::Result<()> { + dst.unsplit(item); Ok(()) } } @@ -56,6 +59,7 @@ impl Decoder for PostgresCodec { fn decode(&mut self, src: &mut BytesMut) -> Result, io::Error> { let mut idx = 0; + let mut ready = false; while let Some(header) = backend::Header::parse(&src[idx..])? { let len = header.len() as usize + 1; if src[idx..].len() < len { @@ -79,6 +83,7 @@ impl Decoder for PostgresCodec { idx += len; if header.tag() == backend::READY_FOR_QUERY_TAG { + ready = true; break; } } @@ -88,6 +93,7 @@ impl Decoder for PostgresCodec { } else { Ok(Some(BackendMessage::Normal { messages: BackendMessages(src.split_to(idx)), + ready, })) } } diff --git a/libs/proxy/tokio-postgres2/src/config.rs b/libs/proxy/tokio-postgres2/src/config.rs index 961cbc923e..3579dd94a2 100644 --- a/libs/proxy/tokio-postgres2/src/config.rs +++ b/libs/proxy/tokio-postgres2/src/config.rs @@ -11,9 +11,8 @@ use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; use crate::connect::connect; -use crate::connect_raw::{RawConnection, connect_raw}; +use crate::connect_raw::{self, StartupStream}; use crate::connect_tls::connect_tls; -use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::{MakeTlsConnect, TlsConnect, TlsStream}; use crate::{Client, Connection, Error}; @@ -244,24 +243,27 @@ impl Config { &self, stream: S, tls: T, - ) -> Result, Error> + ) -> Result, Error> where S: AsyncRead + AsyncWrite + Unpin, T: TlsConnect, { let stream = connect_tls(stream, self.ssl_mode, tls).await?; - connect_raw(stream, self).await + let mut stream = StartupStream::new(stream); + connect_raw::authenticate(&mut stream, self).await?; + + Ok(stream) } - pub async fn authenticate( + pub fn authenticate( &self, - stream: MaybeTlsStream, - ) -> Result, Error> + stream: &mut StartupStream, + ) -> impl Future> where S: AsyncRead + AsyncWrite + Unpin, T: TlsStream + Unpin, { - connect_raw(stream, self).await + connect_raw::authenticate(stream, self) } } diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs index 2f718e1e7d..b1df87811e 100644 --- a/libs/proxy/tokio-postgres2/src/connect.rs +++ b/libs/proxy/tokio-postgres2/src/connect.rs @@ -1,15 +1,17 @@ use std::net::IpAddr; +use futures_util::TryStreamExt; +use postgres_protocol2::message::backend::Message; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; use tokio::sync::mpsc; use crate::client::SocketConfig; -use 
crate::config::Host; -use crate::connect_raw::connect_raw; +use crate::config::{Host, SslMode}; +use crate::connect_raw::StartupStream; use crate::connect_socket::connect_socket; -use crate::connect_tls::connect_tls; use crate::tls::{MakeTlsConnect, TlsConnect}; -use crate::{Client, Config, Connection, Error, RawConnection}; +use crate::{Client, Config, Connection, Error}; pub async fn connect( tls: &T, @@ -43,34 +45,78 @@ where T: TlsConnect, { let socket = connect_socket(host_addr, host, port, config.connect_timeout).await?; - let stream = connect_tls(socket, config.ssl_mode, tls).await?; - let RawConnection { + let stream = config.tls_and_authenticate(socket, tls).await?; + managed( stream, - parameters: _, - delayed_notice: _, - process_id, - secret_key, - } = connect_raw(stream, config).await?; + host_addr, + host.clone(), + port, + config.ssl_mode, + config.connect_timeout, + ) + .await +} + +pub async fn managed( + mut stream: StartupStream, + host_addr: Option, + host: Host, + port: u16, + ssl_mode: SslMode, + connect_timeout: Option, +) -> Result<(Client, Connection), Error> +where + TlsStream: AsyncRead + AsyncWrite + Unpin, +{ + let (process_id, secret_key) = wait_until_ready(&mut stream).await?; let socket_config = SocketConfig { host_addr, - host: host.clone(), + host, port, - connect_timeout: config.connect_timeout, + connect_timeout, }; + let mut stream = stream.into_framed(); + let write_buf = std::mem::take(stream.write_buffer_mut()); + let (client_tx, conn_rx) = mpsc::unbounded_channel(); let (conn_tx, client_rx) = mpsc::channel(4); let client = Client::new( client_tx, client_rx, socket_config, - config.ssl_mode, + ssl_mode, process_id, secret_key, + write_buf, ); let connection = Connection::new(stream, conn_tx, conn_rx); Ok((client, connection)) } + +async fn wait_until_ready(stream: &mut StartupStream) -> Result<(i32, i32), Error> +where + S: AsyncRead + AsyncWrite + Unpin, + T: AsyncRead + AsyncWrite + Unpin, +{ + let mut process_id = 0; + let mut secret_key = 0; + + loop { + match stream.try_next().await.map_err(Error::io)? { + Some(Message::BackendKeyData(body)) => { + process_id = body.process_id(); + secret_key = body.secret_key(); + } + // These values are currently not used by `Client`/`Connection`. Ignore them. 
+ Some(Message::ParameterStatus(_)) | Some(Message::NoticeResponse(_)) => {} + Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key)), + Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), + Some(_) => return Err(Error::unexpected_message()), + None => return Err(Error::closed()), + } + } +} diff --git a/libs/proxy/tokio-postgres2/src/connect_raw.rs b/libs/proxy/tokio-postgres2/src/connect_raw.rs index 462e1be1aa..17237eeef5 100644 --- a/libs/proxy/tokio-postgres2/src/connect_raw.rs +++ b/libs/proxy/tokio-postgres2/src/connect_raw.rs @@ -1,52 +1,27 @@ -use std::collections::HashMap; use std::io; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::task::{Context, Poll, ready}; -use bytes::{Bytes, BytesMut}; +use bytes::BytesMut; use fallible_iterator::FallibleIterator; -use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready}; +use futures_util::{SinkExt, Stream, TryStreamExt}; use postgres_protocol2::authentication::sasl; use postgres_protocol2::authentication::sasl::ScramSha256; -use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody}; +use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message}; use postgres_protocol2::message::frontend; -use tokio::io::{AsyncRead, AsyncWrite}; -use tokio_util::codec::Framed; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_util::codec::{Framed, FramedParts}; use crate::Error; -use crate::codec::{BackendMessage, BackendMessages, PostgresCodec}; +use crate::codec::PostgresCodec; use crate::config::{self, AuthKeys, Config}; +use crate::connection::{GC_THRESHOLD, INITIAL_CAPACITY}; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::TlsStream; pub struct StartupStream { inner: Framed, PostgresCodec>, - buf: BackendMessages, - delayed_notice: Vec, -} - -impl Sink for StartupStream -where - S: AsyncRead + AsyncWrite + Unpin, - T: AsyncRead + AsyncWrite + Unpin, -{ - type Error = io::Error; - - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_ready(cx) - } - - fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> io::Result<()> { - Pin::new(&mut self.inner).start_send(item) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_flush(cx) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_close(cx) - } + read_buf: BytesMut, } impl Stream for StartupStream @@ -56,78 +31,109 @@ where { type Item = io::Result; - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - loop { - match self.buf.next() { - Ok(Some(message)) => return Poll::Ready(Some(Ok(message))), - Ok(None) => {} - Err(e) => return Poll::Ready(Some(Err(e))), - } + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // We don't use `self.inner.poll_next()` as that might over-read into the read buffer. - match ready!(Pin::new(&mut self.inner).poll_next(cx)) { - Some(Ok(BackendMessage::Normal { messages, .. })) => self.buf = messages, - Some(Ok(BackendMessage::Async(message))) => return Poll::Ready(Some(Ok(message))), - Some(Err(e)) => return Poll::Ready(Some(Err(e))), - None => return Poll::Ready(None), - } + // read 1 byte tag, 4 bytes length. 
+ let header = ready!(self.as_mut().poll_fill_buf_exact(cx, 5)?); + + let len = u32::from_be_bytes(header[1..5].try_into().unwrap()); + if len < 4 { + return Poll::Ready(Some(Err(std::io::Error::other( + "postgres message too small", + )))); } + if len >= 65536 { + return Poll::Ready(Some(Err(std::io::Error::other( + "postgres message too large", + )))); + } + + // the tag is an additional byte. + let _message = ready!(self.as_mut().poll_fill_buf_exact(cx, len as usize + 1)?); + + // Message::parse will remove the all the bytes from the buffer. + Poll::Ready(Message::parse(&mut self.read_buf).transpose()) } } -pub struct RawConnection { - pub stream: Framed, PostgresCodec>, - pub parameters: HashMap, - pub delayed_notice: Vec, - pub process_id: i32, - pub secret_key: i32, -} - -pub async fn connect_raw( - stream: MaybeTlsStream, - config: &Config, -) -> Result, Error> -where - S: AsyncRead + AsyncWrite + Unpin, - T: TlsStream + Unpin, -{ - let mut stream = StartupStream { - inner: Framed::new(stream, PostgresCodec), - buf: BackendMessages::empty(), - delayed_notice: Vec::new(), - }; - - startup(&mut stream, config).await?; - authenticate(&mut stream, config).await?; - let (process_id, secret_key, parameters) = read_info(&mut stream).await?; - - Ok(RawConnection { - stream: stream.inner, - parameters, - delayed_notice: stream.delayed_notice, - process_id, - secret_key, - }) -} - -async fn startup(stream: &mut StartupStream, config: &Config) -> Result<(), Error> +impl StartupStream where S: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin, { - let mut buf = BytesMut::new(); - frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?; + /// Fill the buffer until it's the exact length provided. No additional data will be read from the socket. + /// + /// If the current buffer length is greater, nothing happens. + fn poll_fill_buf_exact( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + len: usize, + ) -> Poll> { + let this = self.get_mut(); + let mut stream = Pin::new(this.inner.get_mut()); - stream.send(buf.freeze()).await.map_err(Error::io) + let mut n = this.read_buf.len(); + while n < len { + this.read_buf.resize(len, 0); + + let mut buf = ReadBuf::new(&mut this.read_buf[..]); + buf.set_filled(n); + + if stream.as_mut().poll_read(cx, &mut buf)?.is_pending() { + this.read_buf.truncate(n); + return Poll::Pending; + } + + if buf.filled().len() == n { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "early eof", + ))); + } + n = buf.filled().len(); + + this.read_buf.truncate(n); + } + + Poll::Ready(Ok(&this.read_buf[..len])) + } + + pub fn into_framed(mut self) -> Framed, PostgresCodec> { + *self.inner.read_buffer_mut() = self.read_buf; + self.inner + } + + pub fn new(io: MaybeTlsStream) -> Self { + let mut parts = FramedParts::new(io, PostgresCodec); + parts.write_buf = BytesMut::with_capacity(INITIAL_CAPACITY); + + let mut inner = Framed::from_parts(parts); + + // This is the default already, but nice to be explicit. + // We divide by two because writes will overshoot the boundary. + // We don't want constant overshoots to cause us to constantly re-shrink the buffer. 
+ inner.set_backpressure_boundary(GC_THRESHOLD / 2); + + Self { + inner, + read_buf: BytesMut::with_capacity(INITIAL_CAPACITY), + } + } } -async fn authenticate(stream: &mut StartupStream, config: &Config) -> Result<(), Error> +pub(crate) async fn authenticate( + stream: &mut StartupStream, + config: &Config, +) -> Result<(), Error> where S: AsyncRead + AsyncWrite + Unpin, T: TlsStream + Unpin, { + frontend::startup_message(&config.server_params, stream.inner.write_buffer_mut()) + .map_err(Error::encode)?; + + stream.inner.flush().await.map_err(Error::io)?; match stream.try_next().await.map_err(Error::io)? { Some(Message::AuthenticationOk) => { can_skip_channel_binding(config)?; @@ -141,7 +147,8 @@ where .as_ref() .ok_or_else(|| Error::config("password missing".into()))?; - authenticate_password(stream, pass).await?; + frontend::password_message(pass, stream.inner.write_buffer_mut()) + .map_err(Error::encode)?; } Some(Message::AuthenticationSasl(body)) => { authenticate_sasl(stream, body, config).await?; @@ -160,6 +167,7 @@ where None => return Err(Error::closed()), } + stream.inner.flush().await.map_err(Error::io)?; match stream.try_next().await.map_err(Error::io)? { Some(Message::AuthenticationOk) => Ok(()), Some(Message::ErrorResponse(body)) => Err(Error::db(body)), @@ -177,20 +185,6 @@ fn can_skip_channel_binding(config: &Config) -> Result<(), Error> { } } -async fn authenticate_password( - stream: &mut StartupStream, - password: &[u8], -) -> Result<(), Error> -where - S: AsyncRead + AsyncWrite + Unpin, - T: AsyncRead + AsyncWrite + Unpin, -{ - let mut buf = BytesMut::new(); - frontend::password_message(password, &mut buf).map_err(Error::encode)?; - - stream.send(buf.freeze()).await.map_err(Error::io) -} - async fn authenticate_sasl( stream: &mut StartupStream, body: AuthenticationSaslBody, @@ -245,10 +239,10 @@ where return Err(Error::config("password or auth keys missing".into())); }; - let mut buf = BytesMut::new(); - frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?; - stream.send(buf.freeze()).await.map_err(Error::io)?; + frontend::sasl_initial_response(mechanism, scram.message(), stream.inner.write_buffer_mut()) + .map_err(Error::encode)?; + stream.inner.flush().await.map_err(Error::io)?; let body = match stream.try_next().await.map_err(Error::io)? { Some(Message::AuthenticationSaslContinue(body)) => body, Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), @@ -261,10 +255,10 @@ where .await .map_err(|e| Error::authentication(e.into()))?; - let mut buf = BytesMut::new(); - frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?; - stream.send(buf.freeze()).await.map_err(Error::io)?; + frontend::sasl_response(scram.message(), stream.inner.write_buffer_mut()) + .map_err(Error::encode)?; + stream.inner.flush().await.map_err(Error::io)?; let body = match stream.try_next().await.map_err(Error::io)? { Some(Message::AuthenticationSaslFinal(body)) => body, Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), @@ -278,35 +272,3 @@ where Ok(()) } - -async fn read_info( - stream: &mut StartupStream, -) -> Result<(i32, i32, HashMap), Error> -where - S: AsyncRead + AsyncWrite + Unpin, - T: AsyncRead + AsyncWrite + Unpin, -{ - let mut process_id = 0; - let mut secret_key = 0; - let mut parameters = HashMap::new(); - - loop { - match stream.try_next().await.map_err(Error::io)? 
{ - Some(Message::BackendKeyData(body)) => { - process_id = body.process_id(); - secret_key = body.secret_key(); - } - Some(Message::ParameterStatus(body)) => { - parameters.insert( - body.name().map_err(Error::parse)?.to_string(), - body.value().map_err(Error::parse)?.to_string(), - ); - } - Some(Message::NoticeResponse(body)) => stream.delayed_notice.push(body), - Some(Message::ReadyForQuery(_)) => return Ok((process_id, secret_key, parameters)), - Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), - Some(_) => return Err(Error::unexpected_message()), - None => return Err(Error::closed()), - } - } -} diff --git a/libs/proxy/tokio-postgres2/src/connection.rs b/libs/proxy/tokio-postgres2/src/connection.rs index c43a22ffe7..303de71cfa 100644 --- a/libs/proxy/tokio-postgres2/src/connection.rs +++ b/libs/proxy/tokio-postgres2/src/connection.rs @@ -44,6 +44,27 @@ pub struct Connection { state: State, } +pub const INITIAL_CAPACITY: usize = 2 * 1024; +pub const GC_THRESHOLD: usize = 16 * 1024; + +/// Gargabe collect the [`BytesMut`] if it has too much spare capacity. +pub fn gc_bytesmut(buf: &mut BytesMut) { + // We use a different mode to shrink the buf when above the threshold. + // When above the threshold, we only re-allocate when the buf has 2x spare capacity. + let reclaim = GC_THRESHOLD.checked_sub(buf.len()).unwrap_or(buf.len()); + + // `try_reclaim` tries to get the capacity from any shared `BytesMut`s, + // before then comparing the length against the capacity. + if buf.try_reclaim(reclaim) { + let capacity = usize::max(buf.len(), INITIAL_CAPACITY); + + // Allocate a new `BytesMut` so that we deallocate the old version. + let mut new = BytesMut::with_capacity(capacity); + new.extend_from_slice(buf); + *buf = new; + } +} + pub enum Never {} impl Connection @@ -86,7 +107,14 @@ where continue; } BackendMessage::Async(_) => continue, - BackendMessage::Normal { messages } => messages, + BackendMessage::Normal { messages, ready } => { + // if we read a ReadyForQuery from postgres, let's try GC the read buffer. + if ready { + gc_bytesmut(self.stream.read_buffer_mut()); + } + + messages + } } } }; @@ -177,12 +205,7 @@ where // Send a terminate message to postgres Poll::Ready(None) => { trace!("poll_write: at eof, terminating"); - let mut request = BytesMut::new(); - frontend::terminate(&mut request); - - Pin::new(&mut self.stream) - .start_send(request.freeze()) - .map_err(Error::io)?; + frontend::terminate(self.stream.write_buffer_mut()); trace!("poll_write: sent eof, closing"); trace!("poll_write: done"); @@ -205,6 +228,13 @@ where { Poll::Ready(()) => { trace!("poll_flush: flushed"); + + // Since our codec prefers to share the buffer with the `Client`, + // if we don't release our share, then the `Client` would have to re-alloc + // the buffer when they next use it. + debug_assert!(self.stream.write_buffer().is_empty()); + *self.stream.write_buffer_mut() = BytesMut::new(); + Poll::Ready(Ok(())) } Poll::Pending => { diff --git a/libs/proxy/tokio-postgres2/src/error/mod.rs b/libs/proxy/tokio-postgres2/src/error/mod.rs index 5309bce17e..3fbb97f9bb 100644 --- a/libs/proxy/tokio-postgres2/src/error/mod.rs +++ b/libs/proxy/tokio-postgres2/src/error/mod.rs @@ -9,7 +9,7 @@ use postgres_protocol2::message::backend::{ErrorFields, ErrorResponseBody}; pub use self::sqlstate::*; #[allow(clippy::unreadable_literal)] -mod sqlstate; +pub mod sqlstate; /// The severity of a Postgres error or notice. 
#[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -452,16 +452,16 @@ impl Error { Error(Box::new(ErrorInner { kind, cause })) } - pub(crate) fn closed() -> Error { + pub fn closed() -> Error { Error::new(Kind::Closed, None) } - pub(crate) fn unexpected_message() -> Error { + pub fn unexpected_message() -> Error { Error::new(Kind::UnexpectedMessage, None) } #[allow(clippy::needless_pass_by_value)] - pub(crate) fn db(error: ErrorResponseBody) -> Error { + pub fn db(error: ErrorResponseBody) -> Error { match DbError::parse(&mut error.fields()) { Ok(e) => Error::new(Kind::Db, Some(Box::new(e))), Err(e) => Error::new(Kind::Parse, Some(Box::new(e))), @@ -493,7 +493,7 @@ impl Error { Error::new(Kind::Tls, Some(e)) } - pub(crate) fn io(e: io::Error) -> Error { + pub fn io(e: io::Error) -> Error { Error::new(Kind::Io, Some(Box::new(e))) } diff --git a/libs/proxy/tokio-postgres2/src/lib.rs b/libs/proxy/tokio-postgres2/src/lib.rs index e3dd6d9261..da2665095c 100644 --- a/libs/proxy/tokio-postgres2/src/lib.rs +++ b/libs/proxy/tokio-postgres2/src/lib.rs @@ -6,7 +6,6 @@ use postgres_protocol2::message::backend::ReadyForQueryBody; pub use crate::cancel_token::{CancelToken, RawCancelToken}; pub use crate::client::{Client, SocketConfig}; pub use crate::config::Config; -pub use crate::connect_raw::RawConnection; pub use crate::connection::Connection; pub use crate::error::Error; pub use crate::generic_client::GenericClient; @@ -49,8 +48,8 @@ mod cancel_token; mod client; mod codec; pub mod config; -mod connect; -mod connect_raw; +pub mod connect; +pub mod connect_raw; mod connect_socket; mod connect_tls; mod connection; diff --git a/libs/safekeeper_api/src/models.rs b/libs/safekeeper_api/src/models.rs index a300c8464f..b34ed947c0 100644 --- a/libs/safekeeper_api/src/models.rs +++ b/libs/safekeeper_api/src/models.rs @@ -301,7 +301,12 @@ pub struct PullTimelineRequest { pub tenant_id: TenantId, pub timeline_id: TimelineId, pub http_hosts: Vec, - pub ignore_tombstone: Option, + /// Membership configuration to switch to after pull. + /// It guarantees that if pull_timeline returns successfully, the timeline will + /// not be deleted by request with an older generation. + /// Storage controller always sets this field. + /// None is only allowed for manual pull_timeline requests. + pub mconf: Option, } #[derive(Debug, Serialize, Deserialize)] diff --git a/libs/tracing-utils/src/perf_span.rs b/libs/tracing-utils/src/perf_span.rs index 16f713c67e..4eec0829f7 100644 --- a/libs/tracing-utils/src/perf_span.rs +++ b/libs/tracing-utils/src/perf_span.rs @@ -49,7 +49,7 @@ impl PerfSpan { } } - pub fn enter(&self) -> PerfSpanEntered { + pub fn enter(&self) -> PerfSpanEntered<'_> { if let Some(ref id) = self.inner.id() { self.dispatch.enter(id); } diff --git a/libs/utils/src/logging.rs b/libs/utils/src/logging.rs index d67c0f123b..9f118048f3 100644 --- a/libs/utils/src/logging.rs +++ b/libs/utils/src/logging.rs @@ -34,13 +34,16 @@ macro_rules! critical { #[macro_export] macro_rules! 
critical_timeline { - ($tenant_shard_id:expr, $timeline_id:expr, $($arg:tt)*) => {{ + ($tenant_shard_id:expr, $timeline_id:expr, $corruption_detected:expr, $($arg:tt)*) => {{ if cfg!(debug_assertions) { panic!($($arg)*); } // Increment both metrics $crate::logging::TRACING_EVENT_COUNT_METRIC.inc_critical(); $crate::logging::HADRON_CRITICAL_STORAGE_EVENT_COUNT_METRIC.inc(&$tenant_shard_id.to_string(), &$timeline_id.to_string()); + if let Some(c) = $corruption_detected.as_ref() { + c.store(true, std::sync::atomic::Ordering::Relaxed); + } let backtrace = std::backtrace::Backtrace::capture(); tracing::error!("CRITICAL: [tenant_shard_id: {}, timeline_id: {}] {}\n{backtrace}", $tenant_shard_id, $timeline_id, format!($($arg)*)); diff --git a/libs/utils/src/lsn.rs b/libs/utils/src/lsn.rs index 1abb63817b..47b7e6a888 100644 --- a/libs/utils/src/lsn.rs +++ b/libs/utils/src/lsn.rs @@ -310,6 +310,11 @@ impl AtomicLsn { } } + /// Consumes the atomic and returns the contained value. + pub const fn into_inner(self) -> Lsn { + Lsn(self.inner.into_inner()) + } + /// Atomically retrieve the `Lsn` value from memory. pub fn load(&self) -> Lsn { Lsn(self.inner.load(Ordering::Acquire)) diff --git a/libs/utils/src/pageserver_feedback.rs b/libs/utils/src/pageserver_feedback.rs index cffbc0b4d6..da5b53306a 100644 --- a/libs/utils/src/pageserver_feedback.rs +++ b/libs/utils/src/pageserver_feedback.rs @@ -32,6 +32,9 @@ pub struct PageserverFeedback { pub replytime: SystemTime, /// Used to track feedbacks from different shards. Always zero for unsharded tenants. pub shard_number: u32, + /// If true, the pageserver has detected corruption and the safekeeper and postgres + /// should stop sending WAL. + pub corruption_detected: bool, } impl PageserverFeedback { @@ -43,6 +46,7 @@ impl PageserverFeedback { disk_consistent_lsn: Lsn::INVALID, replytime: *PG_EPOCH, shard_number: 0, + corruption_detected: false, } } @@ -101,6 +105,13 @@ impl PageserverFeedback { buf.put_u32(self.shard_number); } + if self.corruption_detected { + nkeys += 1; + buf.put_slice(b"corruption_detected\0"); + buf.put_i32(1); + buf.put_u8(1); + } + buf[buf_ptr] = nkeys; } @@ -147,6 +158,11 @@ impl PageserverFeedback { assert_eq!(len, 4); rf.shard_number = buf.get_u32(); } + b"corruption_detected" => { + let len = buf.get_i32(); + assert_eq!(len, 1); + rf.corruption_detected = buf.get_u8() != 0; + } _ => { let len = buf.get_i32(); warn!( @@ -206,6 +222,26 @@ mod tests { assert_eq!(rf, rf_parsed); } + // Test that databricks-specific fields added to the PageserverFeedback message are serialized + // and deserialized correctly, in addition to the existing fields from upstream. 
+ #[test] + fn test_replication_feedback_databricks_fields() { + let mut rf = PageserverFeedback::empty(); + rf.current_timeline_size = 12345678; + rf.last_received_lsn = Lsn(23456789); + rf.disk_consistent_lsn = Lsn(34567890); + rf.remote_consistent_lsn = Lsn(45678901); + rf.replytime = *PG_EPOCH + Duration::from_secs(100_000_000); + rf.shard_number = 1; + rf.corruption_detected = true; + + let mut data = BytesMut::new(); + rf.serialize(&mut data); + + let rf_parsed = PageserverFeedback::parse(data.freeze()); + assert_eq!(rf, rf_parsed); + } + #[test] fn test_replication_feedback_unknown_key() { let mut rf = PageserverFeedback::empty(); diff --git a/libs/walproposer/src/api_bindings.rs b/libs/walproposer/src/api_bindings.rs index 825a137d0f..9c90beb379 100644 --- a/libs/walproposer/src/api_bindings.rs +++ b/libs/walproposer/src/api_bindings.rs @@ -341,6 +341,34 @@ extern "C-unwind" fn log_internal( } } +/* BEGIN_HADRON */ +extern "C" fn reset_safekeeper_statuses_for_metrics(wp: *mut WalProposer, num_safekeepers: u32) { + unsafe { + let callback_data = (*(*wp).config).callback_data; + let api = callback_data as *mut Box; + if api.is_null() { + return; + } + (*api).reset_safekeeper_statuses_for_metrics(&mut (*wp), num_safekeepers); + } +} + +extern "C" fn update_safekeeper_status_for_metrics( + wp: *mut WalProposer, + sk_index: u32, + status: u8, +) { + unsafe { + let callback_data = (*(*wp).config).callback_data; + let api = callback_data as *mut Box; + if api.is_null() { + return; + } + (*api).update_safekeeper_status_for_metrics(&mut (*wp), sk_index, status); + } +} +/* END_HADRON */ + #[derive(Debug, PartialEq)] pub enum Level { Debug5, @@ -414,6 +442,10 @@ pub(crate) fn create_api() -> walproposer_api { finish_sync_safekeepers: Some(finish_sync_safekeepers), process_safekeeper_feedback: Some(process_safekeeper_feedback), log_internal: Some(log_internal), + /* BEGIN_HADRON */ + reset_safekeeper_statuses_for_metrics: Some(reset_safekeeper_statuses_for_metrics), + update_safekeeper_status_for_metrics: Some(update_safekeeper_status_for_metrics), + /* END_HADRON */ } } @@ -426,12 +458,15 @@ pub fn empty_shmem() -> crate::bindings::WalproposerShmemState { remote_consistent_lsn: 0, replytime: 0, shard_number: 0, + corruption_detected: false, }; let empty_wal_rate_limiter = crate::bindings::WalRateLimiter { + effective_max_wal_bytes_per_second: crate::bindings::pg_atomic_uint32 { value: 0 }, should_limit: crate::bindings::pg_atomic_uint32 { value: 0 }, sent_bytes: 0, - last_recorded_time_us: crate::bindings::pg_atomic_uint64 { value: 0 }, + batch_start_time_us: crate::bindings::pg_atomic_uint64 { value: 0 }, + batch_end_time_us: crate::bindings::pg_atomic_uint64 { value: 0 }, }; crate::bindings::WalproposerShmemState { @@ -448,6 +483,8 @@ pub fn empty_shmem() -> crate::bindings::WalproposerShmemState { replica_promote: false, min_ps_feedback: empty_feedback, wal_rate_limiter: empty_wal_rate_limiter, + num_safekeepers: 0, + safekeeper_status: [0; 32], } } diff --git a/libs/walproposer/src/walproposer.rs b/libs/walproposer/src/walproposer.rs index 93bb0d5eb0..8453279c5c 100644 --- a/libs/walproposer/src/walproposer.rs +++ b/libs/walproposer/src/walproposer.rs @@ -159,6 +159,21 @@ pub trait ApiImpl { fn after_election(&self, _wp: &mut WalProposer) { todo!() } + + /* BEGIN_HADRON */ + fn reset_safekeeper_statuses_for_metrics(&self, _wp: &mut WalProposer, _num_safekeepers: u32) { + // Do nothing for testing purposes. 
+ } + + fn update_safekeeper_status_for_metrics( + &self, + _wp: &mut WalProposer, + _sk_index: u32, + _status: u8, + ) { + // Do nothing for testing purposes. + } + /* END_HADRON */ } #[derive(Debug)] diff --git a/pageserver/client_grpc/src/client.rs b/pageserver/client_grpc/src/client.rs index b8ee57bf9f..dad37ebe74 100644 --- a/pageserver/client_grpc/src/client.rs +++ b/pageserver/client_grpc/src/client.rs @@ -230,16 +230,14 @@ impl PageserverClient { ) -> tonic::Result { // Fast path: request is for a single shard. if let Some(shard_id) = - GetPageSplitter::for_single_shard(&req, shards.count, shards.stripe_size) - .map_err(|err| tonic::Status::internal(err.to_string()))? + GetPageSplitter::for_single_shard(&req, shards.count, shards.stripe_size)? { return Self::get_page_with_shard(req, shards.get(shard_id)?).await; } // Request spans multiple shards. Split it, dispatch concurrent per-shard requests, and // reassemble the responses. - let mut splitter = GetPageSplitter::split(req, shards.count, shards.stripe_size) - .map_err(|err| tonic::Status::internal(err.to_string()))?; + let mut splitter = GetPageSplitter::split(req, shards.count, shards.stripe_size)?; let mut shard_requests = FuturesUnordered::new(); for (shard_id, shard_req) in splitter.drain_requests() { @@ -249,14 +247,10 @@ impl PageserverClient { } while let Some((shard_id, shard_response)) = shard_requests.next().await.transpose()? { - splitter - .add_response(shard_id, shard_response) - .map_err(|err| tonic::Status::internal(err.to_string()))?; + splitter.add_response(shard_id, shard_response)?; } - splitter - .get_response() - .map_err(|err| tonic::Status::internal(err.to_string())) + Ok(splitter.collect_response()?) } /// Fetches pages on the given shard. Does not retry internally. diff --git a/pageserver/page_api/src/lib.rs b/pageserver/page_api/src/lib.rs index b44df6337f..b9be6b8b91 100644 --- a/pageserver/page_api/src/lib.rs +++ b/pageserver/page_api/src/lib.rs @@ -24,4 +24,4 @@ mod split; pub use client::Client; pub use model::*; -pub use split::GetPageSplitter; +pub use split::{GetPageSplitter, SplitError}; diff --git a/pageserver/page_api/src/split.rs b/pageserver/page_api/src/split.rs index 5ecc90a166..27c1c995e0 100644 --- a/pageserver/page_api/src/split.rs +++ b/pageserver/page_api/src/split.rs @@ -1,6 +1,5 @@ use std::collections::HashMap; -use anyhow::anyhow; use bytes::Bytes; use crate::model::*; @@ -27,19 +26,19 @@ impl GetPageSplitter { req: &GetPageRequest, count: ShardCount, stripe_size: Option, - ) -> anyhow::Result> { + ) -> Result, SplitError> { // Fast path: unsharded tenant. if count.is_unsharded() { return Ok(Some(ShardIndex::unsharded())); } let Some(stripe_size) = stripe_size else { - return Err(anyhow!("stripe size must be given for sharded tenants")); + return Err("stripe size must be given for sharded tenants".into()); }; // Find the first page's shard, for comparison. let Some(&first_page) = req.block_numbers.first() else { - return Err(anyhow!("no block numbers in request")); + return Err("no block numbers in request".into()); }; let key = rel_block_to_key(req.rel, first_page); let shard_number = key_to_shard_number(count, stripe_size, &key); @@ -60,7 +59,7 @@ impl GetPageSplitter { req: GetPageRequest, count: ShardCount, stripe_size: Option, - ) -> anyhow::Result { + ) -> Result { // The caller should make sure we don't split requests unnecessarily. 
debug_assert!( Self::for_single_shard(&req, count, stripe_size)?.is_none(), @@ -68,10 +67,10 @@ impl GetPageSplitter { ); if count.is_unsharded() { - return Err(anyhow!("unsharded tenant, no point in splitting request")); + return Err("unsharded tenant, no point in splitting request".into()); } let Some(stripe_size) = stripe_size else { - return Err(anyhow!("stripe size must be given for sharded tenants")); + return Err("stripe size must be given for sharded tenants".into()); }; // Split the requests by shard index. @@ -129,35 +128,32 @@ impl GetPageSplitter { /// Adds a response from the given shard. The response must match the request ID and have an OK /// status code. A response must not already exist for the given shard ID. - #[allow(clippy::result_large_err)] pub fn add_response( &mut self, shard_id: ShardIndex, response: GetPageResponse, - ) -> anyhow::Result<()> { + ) -> Result<(), SplitError> { // The caller should already have converted status codes into tonic::Status. if response.status_code != GetPageStatusCode::Ok { - return Err(anyhow!( + return Err(SplitError(format!( "unexpected non-OK response for shard {shard_id}: {} {}", response.status_code, response.reason.unwrap_or_default() - )); + ))); } if response.request_id != self.response.request_id { - return Err(anyhow!( + return Err(SplitError(format!( "response ID mismatch for shard {shard_id}: expected {}, got {}", - self.response.request_id, - response.request_id - )); + self.response.request_id, response.request_id + ))); } if response.request_id != self.response.request_id { - return Err(anyhow!( + return Err(SplitError(format!( "response ID mismatch for shard {shard_id}: expected {}, got {}", - self.response.request_id, - response.request_id - )); + self.response.request_id, response.request_id + ))); } // Place the shard response pages into the assembled response, in request order. @@ -169,26 +165,27 @@ impl GetPageSplitter { } let Some(slot) = self.response.pages.get_mut(i) else { - return Err(anyhow!("no block_shards slot {i} for shard {shard_id}")); + return Err(SplitError(format!( + "no block_shards slot {i} for shard {shard_id}" + ))); }; let Some(page) = pages.next() else { - return Err(anyhow!( + return Err(SplitError(format!( "missing page {} in shard {shard_id} response", slot.block_number - )); + ))); }; if page.block_number != slot.block_number { - return Err(anyhow!( + return Err(SplitError(format!( "shard {shard_id} returned wrong page at index {i}, expected {} got {}", - slot.block_number, - page.block_number - )); + slot.block_number, page.block_number + ))); } if !slot.image.is_empty() { - return Err(anyhow!( + return Err(SplitError(format!( "shard {shard_id} returned duplicate page {} at index {i}", slot.block_number - )); + ))); } *slot = page; @@ -196,32 +193,54 @@ impl GetPageSplitter { // Make sure we've consumed all pages from the shard response. if let Some(extra_page) = pages.next() { - return Err(anyhow!( + return Err(SplitError(format!( "shard {shard_id} returned extra page: {}", extra_page.block_number - )); + ))); } Ok(()) } - /// Fetches the final, assembled response. - #[allow(clippy::result_large_err)] - pub fn get_response(self) -> anyhow::Result { + /// Collects the final, assembled response. + pub fn collect_response(self) -> Result { // Check that the response is complete. 
for (i, page) in self.response.pages.iter().enumerate() { if page.image.is_empty() { - return Err(anyhow!( + return Err(SplitError(format!( "missing page {} for shard {}", page.block_number, self.block_shards .get(i) .map(|s| s.to_string()) .unwrap_or_else(|| "?".to_string()) - )); + ))); } } Ok(self.response) } } + +/// A GetPageSplitter error. +#[derive(Debug, thiserror::Error)] +#[error("{0}")] +pub struct SplitError(String); + +impl From<&str> for SplitError { + fn from(err: &str) -> Self { + SplitError(err.to_string()) + } +} + +impl From for SplitError { + fn from(err: String) -> Self { + SplitError(err) + } +} + +impl From for tonic::Status { + fn from(err: SplitError) -> Self { + tonic::Status::internal(err.0) + } +} diff --git a/pageserver/src/basebackup.rs b/pageserver/src/basebackup.rs index 1a44c80e2d..1f1a3f8157 100644 --- a/pageserver/src/basebackup.rs +++ b/pageserver/src/basebackup.rs @@ -11,6 +11,7 @@ //! from data stored in object storage. //! use std::fmt::Write as FmtWrite; +use std::sync::Arc; use std::time::{Instant, SystemTime}; use anyhow::{Context, anyhow}; @@ -420,12 +421,16 @@ where } let mut min_restart_lsn: Lsn = Lsn::MAX; + + let mut dbdir_cnt = 0; + let mut rel_cnt = 0; + // Create tablespace directories for ((spcnode, dbnode), has_relmap_file) in self.timeline.list_dbdirs(self.lsn, self.ctx).await? { self.add_dbdir(spcnode, dbnode, has_relmap_file).await?; - + dbdir_cnt += 1; // If full backup is requested, include all relation files. // Otherwise only include init forks of unlogged relations. let rels = self @@ -433,6 +438,7 @@ where .list_rels(spcnode, dbnode, Version::at(self.lsn), self.ctx) .await?; for &rel in rels.iter() { + rel_cnt += 1; // Send init fork as main fork to provide well formed empty // contents of UNLOGGED relations. Postgres copies it in // `reinit.c` during recovery. @@ -455,6 +461,10 @@ where } } + self.timeline + .db_rel_count + .store(Some(Arc::new((dbdir_cnt, rel_cnt)))); + let start_time = Instant::now(); let aux_files = self .timeline diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 855af7009c..b1566c2d6e 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -715,7 +715,7 @@ fn start_pageserver( disk_usage_eviction_state, deletion_queue.new_client(), secondary_controller, - feature_resolver, + feature_resolver.clone(), ) .context("Failed to initialize router state")?, ); @@ -841,14 +841,14 @@ fn start_pageserver( } else { None }, + feature_resolver.clone(), ); - // Spawn a Pageserver gRPC server task. It will spawn separate tasks for - // each stream/request. + // Spawn a Pageserver gRPC server task. It will spawn separate tasks for each request/stream. + // It uses a separate compute request Tokio runtime (COMPUTE_REQUEST_RUNTIME). // - // TODO: this uses a separate Tokio runtime for the page service. If we want - // other gRPC services, they will need their own port and runtime. Is this - // necessary? + // NB: this port is exposed to computes. It should only provide services that we're okay with + // computes accessing. Internal services should use a separate port. 
let mut page_service_grpc = None; if let Some(grpc_listener) = grpc_listener { page_service_grpc = Some(GrpcPageServiceHandler::spawn( diff --git a/pageserver/src/feature_resolver.rs b/pageserver/src/feature_resolver.rs index 11b0e972b4..678d7e052b 100644 --- a/pageserver/src/feature_resolver.rs +++ b/pageserver/src/feature_resolver.rs @@ -156,6 +156,8 @@ impl FeatureResolver { let tenant_properties = PerTenantProperties { remote_size_mb: Some(rand::rng().random_range(100.0..1000000.00)), + db_count_max: Some(rand::rng().random_range(1..1000)), + rel_count_max: Some(rand::rng().random_range(1..1000)), } .into_posthog_properties(); @@ -344,6 +346,8 @@ impl FeatureResolver { struct PerTenantProperties { pub remote_size_mb: Option, + pub db_count_max: Option, + pub rel_count_max: Option, } impl PerTenantProperties { @@ -355,6 +359,18 @@ impl PerTenantProperties { PostHogFlagFilterPropertyValue::Number(remote_size_mb), ); } + if let Some(db_count) = self.db_count_max { + properties.insert( + "tenant_db_count_max".to_string(), + PostHogFlagFilterPropertyValue::Number(db_count as f64), + ); + } + if let Some(rel_count) = self.rel_count_max { + properties.insert( + "tenant_rel_count_max".to_string(), + PostHogFlagFilterPropertyValue::Number(rel_count as f64), + ); + } properties } } @@ -409,7 +425,11 @@ impl TenantFeatureResolver { /// Refresh the cached properties and flags on the critical path. pub fn refresh_properties_and_flags(&self, tenant_shard: &TenantShard) { + // Any of the remote size is none => this property is none. let mut remote_size_mb = Some(0.0); + // Any of the db or rel count is available => this property is available. + let mut db_count_max = None; + let mut rel_count_max = None; for timeline in tenant_shard.list_timelines() { let size = timeline.metrics.resident_physical_size_get(); if size == 0 { @@ -419,9 +439,25 @@ impl TenantFeatureResolver { if let Some(ref mut remote_size_mb) = remote_size_mb { *remote_size_mb += size as f64 / 1024.0 / 1024.0; } + if let Some(data) = timeline.db_rel_count.load_full() { + let (db_count, rel_count) = *data.as_ref(); + if db_count_max.is_none() { + db_count_max = Some(db_count); + } + if rel_count_max.is_none() { + rel_count_max = Some(rel_count); + } + db_count_max = db_count_max.map(|max| max.max(db_count)); + rel_count_max = rel_count_max.map(|max| max.max(rel_count)); + } } self.cached_tenant_properties.store(Arc::new( - PerTenantProperties { remote_size_mb }.into_posthog_properties(), + PerTenantProperties { + remote_size_mb, + db_count_max, + rel_count_max, + } + .into_posthog_properties(), )); // BEGIN: Update the feature flag on the critical path. 
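A minimal sketch (not part of the patch) of the publish/read pattern behind the new db_rel_count counter used above. The concrete field type is not visible in this hunk; the sketch assumes an arc_swap::ArcSwapOption<(usize, usize)>, which is consistent with the .store(Some(Arc::new((dbdir_cnt, rel_cnt)))) and .load_full() call sites, and the struct/method names here are illustrative only.

// Sketch only: assumed types, not the patch's actual definitions.
use std::sync::Arc;

use arc_swap::ArcSwapOption;

struct DbRelCounter {
    // (database count, relation count) last observed by basebackup or logical size calculation.
    db_rel_count: ArcSwapOption<(usize, usize)>,
}

impl DbRelCounter {
    // Writers replace the whole pair atomically; readers never block them.
    fn publish(&self, dbdir_cnt: usize, rel_cnt: usize) {
        self.db_rel_count.store(Some(Arc::new((dbdir_cnt, rel_cnt))));
    }

    // Fold one timeline's counts into per-tenant maxima, leaving them unset until at least
    // one timeline has published counts (mirroring refresh_properties_and_flags above).
    fn fold_max(&self, db_max: &mut Option<usize>, rel_max: &mut Option<usize>) {
        if let Some(counts) = self.db_rel_count.load_full() {
            let (db, rel) = *counts;
            *db_max = Some(db_max.map_or(db, |m| m.max(db)));
            *rel_max = Some(rel_max.map_or(rel, |m| m.max(rel)));
        }
    }
}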
diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index 3a08244d71..669eeffa32 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -484,6 +484,8 @@ async fn build_timeline_info_common( *timeline.get_applied_gc_cutoff_lsn(), ); + let (rel_size_migration, rel_size_migrated_at) = timeline.get_rel_size_v2_status(); + let info = TimelineInfo { tenant_id: timeline.tenant_shard_id, timeline_id: timeline.timeline_id, @@ -515,7 +517,8 @@ async fn build_timeline_info_common( state, is_archived: Some(is_archived), - rel_size_migration: Some(timeline.get_rel_size_v2_status()), + rel_size_migration: Some(rel_size_migration), + rel_size_migrated_at, is_invisible: Some(is_invisible), walreceiver_status, @@ -930,9 +933,16 @@ async fn timeline_patch_index_part_handler( active_timeline_of_active_tenant(&state.tenant_manager, tenant_shard_id, timeline_id) .await?; + if request_data.rel_size_migration.is_none() && request_data.rel_size_migrated_at.is_some() + { + return Err(ApiError::BadRequest(anyhow!( + "updating rel_size_migrated_at without rel_size_migration is not allowed" + ))); + } + if let Some(rel_size_migration) = request_data.rel_size_migration { timeline - .update_rel_size_v2_status(rel_size_migration) + .update_rel_size_v2_status(rel_size_migration, request_data.rel_size_migrated_at) .map_err(ApiError::InternalServerError)?; } @@ -1995,6 +2005,10 @@ async fn put_tenant_location_config_handler( let state = get_state(&request); let conf = state.conf; + fail::fail_point!("put-location-conf-handler", |_| { + Err(ApiError::ResourceUnavailable("failpoint".into())) + }); + // The `Detached` state is special, it doesn't upsert a tenant, it removes // its local disk content and drops it from memory. if let LocationConfigMode::Detached = request_data.config.mode { diff --git a/pageserver/src/import_datadir.rs b/pageserver/src/import_datadir.rs index 409cc2e3c5..5b674adbb3 100644 --- a/pageserver/src/import_datadir.rs +++ b/pageserver/src/import_datadir.rs @@ -57,7 +57,7 @@ pub async fn import_timeline_from_postgres_datadir( // TODO this shoud be start_lsn, which is not necessarily equal to end_lsn (aka lsn) // Then fishing out pg_control would be unnecessary - let mut modification = tline.begin_modification(pgdata_lsn); + let mut modification = tline.begin_modification_for_import(pgdata_lsn); modification.init_empty()?; // Import all but pg_wal @@ -309,7 +309,7 @@ async fn import_wal( waldecoder.feed_bytes(&buf); let mut nrecords = 0; - let mut modification = tline.begin_modification(last_lsn); + let mut modification = tline.begin_modification_for_import(last_lsn); while last_lsn <= endpoint { if let Some((lsn, recdata)) = waldecoder.poll_decode()? { let interpreted = InterpretedWalRecord::from_bytes_filtered( @@ -357,7 +357,7 @@ pub async fn import_basebackup_from_tar( ctx: &RequestContext, ) -> Result<()> { info!("importing base at {base_lsn}"); - let mut modification = tline.begin_modification(base_lsn); + let mut modification = tline.begin_modification_for_import(base_lsn); modification.init_empty()?; let mut pg_control: Option = None; @@ -457,7 +457,7 @@ pub async fn import_wal_from_tar( waldecoder.feed_bytes(&bytes[offset..]); - let mut modification = tline.begin_modification(last_lsn); + let mut modification = tline.begin_modification_for_import(last_lsn); while last_lsn <= end_lsn { if let Some((lsn, recdata)) = waldecoder.poll_decode()? 
{ let interpreted = InterpretedWalRecord::from_bytes_filtered( diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index ab1c77076c..61cf2954c1 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -69,6 +69,7 @@ use crate::config::PageServerConf; use crate::context::{ DownloadBehavior, PerfInstrumentFutureExt, RequestContext, RequestContextBuilder, }; +use crate::feature_resolver::FeatureResolver; use crate::metrics::{ self, COMPUTE_COMMANDS_COUNTERS, ComputeCommandKind, GetPageBatchBreakReason, LIVE_CONNECTIONS, MISROUTED_PAGESTREAM_REQUESTS, PAGESTREAM_HANDLER_RESULTS_TOTAL, SmgrOpTimer, TimelineMetrics, @@ -140,6 +141,7 @@ pub fn spawn( perf_trace_dispatch: Option, tcp_listener: tokio::net::TcpListener, tls_config: Option>, + feature_resolver: FeatureResolver, ) -> Listener { let cancel = CancellationToken::new(); let libpq_ctx = RequestContext::todo_child( @@ -161,6 +163,7 @@ pub fn spawn( conf.pg_auth_type, tls_config, conf.page_service_pipelining.clone(), + feature_resolver, libpq_ctx, cancel.clone(), ) @@ -219,6 +222,7 @@ pub async fn libpq_listener_main( auth_type: AuthType, tls_config: Option>, pipelining_config: PageServicePipeliningConfig, + feature_resolver: FeatureResolver, listener_ctx: RequestContext, listener_cancel: CancellationToken, ) -> Connections { @@ -262,6 +266,7 @@ pub async fn libpq_listener_main( auth_type, tls_config.clone(), pipelining_config.clone(), + feature_resolver.clone(), connection_ctx, connections_cancel.child_token(), gate_guard, @@ -304,6 +309,7 @@ async fn page_service_conn_main( auth_type: AuthType, tls_config: Option>, pipelining_config: PageServicePipeliningConfig, + feature_resolver: FeatureResolver, connection_ctx: RequestContext, cancel: CancellationToken, gate_guard: GateGuard, @@ -371,6 +377,7 @@ async fn page_service_conn_main( perf_span_fields, connection_ctx, cancel.clone(), + feature_resolver.clone(), gate_guard, ); let pgbackend = @@ -422,6 +429,8 @@ struct PageServerHandler { pipelining_config: PageServicePipeliningConfig, get_vectored_concurrent_io: GetVectoredConcurrentIo, + feature_resolver: FeatureResolver, + gate_guard: GateGuard, } @@ -459,25 +468,11 @@ impl TimelineHandles { self.handles .get(timeline_id, shard_selector, &self.wrapper) .await - .map_err(|e| match e { - timeline::handle::GetError::TenantManager(e) => e, - timeline::handle::GetError::PerTimelineStateShutDown => { - trace!("per-timeline state shut down"); - GetActiveTimelineError::Timeline(GetTimelineError::ShuttingDown) - } - }) } fn tenant_id(&self) -> Option { self.wrapper.tenant_id.get().copied() } - - /// Returns whether a child shard exists locally for the given shard. 
- fn has_child_shard(&self, tenant_id: TenantId, shard_index: ShardIndex) -> bool { - self.wrapper - .tenant_manager - .has_child_shard(tenant_id, shard_index) - } } pub(crate) struct TenantManagerWrapper { @@ -488,11 +483,9 @@ pub(crate) struct TenantManagerWrapper { tenant_id: once_cell::sync::OnceCell, } -#[derive(Debug)] pub(crate) struct TenantManagerTypes; impl timeline::handle::Types for TenantManagerTypes { - type TenantManagerError = GetActiveTimelineError; type TenantManager = TenantManagerWrapper; type Timeline = TenantManagerCacheItem; } @@ -544,6 +537,7 @@ impl timeline::handle::TenantManager for TenantManagerWrappe match resolved { ShardResolveResult::Found(tenant_shard) => break tenant_shard, ShardResolveResult::NotFound => { + MISROUTED_PAGESTREAM_REQUESTS.inc(); return Err(GetActiveTimelineError::Tenant( GetActiveTenantError::NotFound(GetTenantError::NotFound(*tenant_id)), )); @@ -595,6 +589,15 @@ impl timeline::handle::TenantManager for TenantManagerWrappe } } +/// Whether to hold the applied GC cutoff guard when processing GetPage requests. +/// This is determined once at the start of pagestream subprotocol handling based on +/// feature flags, configuration, and test conditions. +#[derive(Debug, Clone, Copy)] +enum HoldAppliedGcCutoffGuard { + Yes, + No, +} + #[derive(thiserror::Error, Debug)] enum PageStreamError { /// We encountered an error that should prompt the client to reconnect: @@ -738,6 +741,7 @@ enum BatchedFeMessage { GetPage { span: Span, shard: WeakHandle, + applied_gc_cutoff_guard: Option>, pages: SmallVec<[BatchedGetPageRequest; 1]>, batch_break_reason: GetPageBatchBreakReason, }, @@ -917,6 +921,7 @@ impl PageServerHandler { perf_span_fields: ConnectionPerfSpanFields, connection_ctx: RequestContext, cancel: CancellationToken, + feature_resolver: FeatureResolver, gate_guard: GateGuard, ) -> Self { PageServerHandler { @@ -928,6 +933,7 @@ impl PageServerHandler { cancel, pipelining_config, get_vectored_concurrent_io, + feature_resolver, gate_guard, } } @@ -967,6 +973,7 @@ impl PageServerHandler { ctx: &RequestContext, protocol_version: PagestreamProtocolVersion, parent_span: Span, + hold_gc_cutoff_guard: HoldAppliedGcCutoffGuard, ) -> Result, QueryError> where IO: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, @@ -1204,19 +1211,27 @@ impl PageServerHandler { }) .await?; + let applied_gc_cutoff_guard = shard.get_applied_gc_cutoff_lsn(); // hold guard // We're holding the Handle let effective_lsn = match Self::effective_request_lsn( &shard, shard.get_last_record_lsn(), req.hdr.request_lsn, req.hdr.not_modified_since, - &shard.get_applied_gc_cutoff_lsn(), + &applied_gc_cutoff_guard, ) { Ok(lsn) => lsn, Err(e) => { return respond_error!(span, e); } }; + let applied_gc_cutoff_guard = match hold_gc_cutoff_guard { + HoldAppliedGcCutoffGuard::Yes => Some(applied_gc_cutoff_guard), + HoldAppliedGcCutoffGuard::No => { + drop(applied_gc_cutoff_guard); + None + } + }; let batch_wait_ctx = if ctx.has_perf_span() { Some( @@ -1237,6 +1252,7 @@ impl PageServerHandler { BatchedFeMessage::GetPage { span, shard: shard.downgrade(), + applied_gc_cutoff_guard, pages: smallvec![BatchedGetPageRequest { req, timer, @@ -1337,13 +1353,28 @@ impl PageServerHandler { match (eligible_batch, this_msg) { ( BatchedFeMessage::GetPage { - pages: accum_pages, .. + pages: accum_pages, + applied_gc_cutoff_guard: accum_applied_gc_cutoff_guard, + .. }, BatchedFeMessage::GetPage { - pages: this_pages, .. + pages: this_pages, + applied_gc_cutoff_guard: this_applied_gc_cutoff_guard, + .. 
                 },
             ) => {
                 accum_pages.extend(this_pages);
+                // the minimum of the two guards will keep data for both alive
+                match (&accum_applied_gc_cutoff_guard, this_applied_gc_cutoff_guard) {
+                    (None, None) => (),
+                    (None, Some(this)) => *accum_applied_gc_cutoff_guard = Some(this),
+                    (Some(_), None) => (),
+                    (Some(accum), Some(this)) => {
+                        if **accum > *this {
+                            *accum_applied_gc_cutoff_guard = Some(this);
+                        }
+                    }
+                };
                 Ok(())
             }
             #[cfg(feature = "testing")]
@@ -1658,6 +1689,7 @@ impl PageServerHandler {
             BatchedFeMessage::GetPage {
                 span,
                 shard,
+                applied_gc_cutoff_guard,
                 pages,
                 batch_break_reason,
             } => {
@@ -1677,6 +1709,7 @@ impl PageServerHandler {
                             .instrument(span.clone())
                             .await;
                         assert_eq!(res.len(), npages);
+                        drop(applied_gc_cutoff_guard);
                         res
                     },
                     span,
@@ -1758,7 +1791,7 @@ impl PageServerHandler {
     /// Coding discipline within this function: all interaction with the `pgb` connection
     /// needs to be sensitive to connection shutdown, currently signalled via [`Self::cancel`].
     /// This is so that we can shutdown page_service quickly.
-    #[instrument(skip_all)]
+    #[instrument(skip_all, fields(hold_gc_cutoff_guard))]
     async fn handle_pagerequests(
         &mut self,
         pgb: &mut PostgresBackend<IO>,
@@ -1804,6 +1837,30 @@ impl PageServerHandler {
             .take()
             .expect("implementation error: timeline_handles should not be locked");
+        // Evaluate the expensive feature resolver check once per pagestream subprotocol handling
+        // instead of once per GetPage request. This is shared between pipelined and serial paths.
+        let hold_gc_cutoff_guard = if cfg!(test) || cfg!(feature = "testing") {
+            HoldAppliedGcCutoffGuard::Yes
+        } else {
+            // Use the global feature resolver with the tenant ID directly, avoiding the need
+            // to get a timeline/shard which might not be available on this pageserver node.
+            let empty_properties = std::collections::HashMap::new();
+            match self.feature_resolver.evaluate_boolean(
+                "page-service-getpage-hold-applied-gc-cutoff-guard",
+                tenant_id,
+                &empty_properties,
+            ) {
+                Ok(()) => HoldAppliedGcCutoffGuard::Yes,
+                Err(_) => HoldAppliedGcCutoffGuard::No,
+            }
+        };
+        // record it in the span of handle_pagerequests so that both the request_span
+        // and the pipeline implementation spans contain the field.
+ Span::current().record( + "hold_gc_cutoff_guard", + tracing::field::debug(&hold_gc_cutoff_guard), + ); + let request_span = info_span!("request"); let ((pgb_reader, timeline_handles), result) = match self.pipelining_config.clone() { PageServicePipeliningConfig::Pipelined(pipelining_config) => { @@ -1817,6 +1874,7 @@ impl PageServerHandler { pipelining_config, protocol_version, io_concurrency, + hold_gc_cutoff_guard, &ctx, ) .await @@ -1831,6 +1889,7 @@ impl PageServerHandler { request_span, protocol_version, io_concurrency, + hold_gc_cutoff_guard, &ctx, ) .await @@ -1859,6 +1918,7 @@ impl PageServerHandler { request_span: Span, protocol_version: PagestreamProtocolVersion, io_concurrency: IoConcurrency, + hold_gc_cutoff_guard: HoldAppliedGcCutoffGuard, ctx: &RequestContext, ) -> ( (PostgresBackendReader, TimelineHandles), @@ -1880,6 +1940,7 @@ impl PageServerHandler { ctx, protocol_version, request_span.clone(), + hold_gc_cutoff_guard, ) .await; let msg = match msg { @@ -1927,6 +1988,7 @@ impl PageServerHandler { pipelining_config: PageServicePipeliningConfigPipelined, protocol_version: PagestreamProtocolVersion, io_concurrency: IoConcurrency, + hold_gc_cutoff_guard: HoldAppliedGcCutoffGuard, ctx: &RequestContext, ) -> ( (PostgresBackendReader, TimelineHandles), @@ -2030,6 +2092,7 @@ impl PageServerHandler { &ctx, protocol_version, request_span.clone(), + hold_gc_cutoff_guard, ) .await; let Some(read_res) = read_res.transpose() else { @@ -2076,6 +2139,7 @@ impl PageServerHandler { pages, span: _, shard: _, + applied_gc_cutoff_guard: _, batch_break_reason: _, } = &mut batch { @@ -3361,18 +3425,6 @@ impl GrpcPageServiceHandler { Ok(CancellableTask { task, cancel }) } - /// Errors if the request is executed on a non-zero shard. Only shard 0 has a complete view of - /// relations and their sizes, as well as SLRU segments and similar data. - #[allow(clippy::result_large_err)] - fn ensure_shard_zero(timeline: &Handle) -> Result<(), tonic::Status> { - match timeline.get_shard_index().shard_number.0 { - 0 => Ok(()), - shard => Err(tonic::Status::invalid_argument(format!( - "request must execute on shard zero (is shard {shard})", - ))), - } - } - /// Generates a PagestreamRequest header from a ReadLsn and request ID. fn make_hdr( read_lsn: page_api::ReadLsn, @@ -3394,56 +3446,55 @@ impl GrpcPageServiceHandler { &self, req: &tonic::Request, ) -> Result, GetActiveTimelineError> { - let ttid = *extract::(req); + let TenantTimelineId { + tenant_id, + timeline_id, + } = *extract::(req); let shard_index = *extract::(req); - let shard_selector = ShardSelector::Known(shard_index); // TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to // avoid the unnecessary overhead. TimelineHandles::new(self.tenant_manager.clone()) - .get(ttid.tenant_id, ttid.timeline_id, shard_selector) + .get(tenant_id, timeline_id, ShardSelector::Known(shard_index)) .await } - /// Acquires a timeline handle for the given request, which must be for shard zero. + /// Acquires a timeline handle for the given request, which must be for shard zero. Most + /// metadata requests are only valid on shard zero. /// /// NB: during an ongoing shard split, the compute will keep talking to the parent shard until /// the split is committed, but the parent shard may have been removed in the meanwhile. In that /// case, we reroute the request to the new child shard. See [`Self::maybe_split_get_page`]. /// /// TODO: revamp the split protocol to avoid this child routing. 
- async fn get_shard_zero_request_timeline( + async fn get_request_timeline_shard_zero( &self, req: &tonic::Request, ) -> Result, tonic::Status> { - let ttid = *extract::(req); + let TenantTimelineId { + tenant_id, + timeline_id, + } = *extract::(req); let shard_index = *extract::(req); if shard_index.shard_number.0 != 0 { return Err(tonic::Status::invalid_argument(format!( - "request must use shard zero (requested shard {shard_index})", + "request only valid on shard zero (requested shard {shard_index})", ))); } // TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to // avoid the unnecessary overhead. - // - // TODO: this does internal retries, which will delay requests during shard splits (we won't - // look for the child until the parent's retries are exhausted). Don't do that. let mut handles = TimelineHandles::new(self.tenant_manager.clone()); match handles - .get( - ttid.tenant_id, - ttid.timeline_id, - ShardSelector::Known(shard_index), - ) + .get(tenant_id, timeline_id, ShardSelector::Known(shard_index)) .await { Ok(timeline) => Ok(timeline), Err(err) => { // We may be in the middle of a shard split. Try to find a child shard 0. if let Ok(timeline) = handles - .get(ttid.tenant_id, ttid.timeline_id, ShardSelector::Zero) + .get(tenant_id, timeline_id, ShardSelector::Zero) .await && timeline.get_shard_index().shard_count > shard_index.shard_count { @@ -3480,8 +3531,6 @@ impl GrpcPageServiceHandler { /// NB: errors returned from here are intercepted in get_pages(), and may be converted to a /// GetPageResponse with an appropriate status code to avoid terminating the stream. /// - /// TODO: verify that the requested pages belong to this shard. - /// /// TODO: get_vectored() currently enforces a batch limit of 32. Postgres will typically send /// batches up to effective_io_concurrency = 100. Either we have to accept large batches, or /// split them up in the client or server. @@ -3501,6 +3550,19 @@ impl GrpcPageServiceHandler { ) -> Result { let ctx = ctx.with_scope_page_service_pagestream(&timeline); + for &blkno in &req.block_numbers { + let shard = timeline.get_shard_identity(); + let key = rel_block_to_key(req.rel, blkno); + if !shard.is_key_local(&key) { + return Err(tonic::Status::invalid_argument(format!( + "block {blkno} of relation {} requested on wrong shard {} (is on {})", + req.rel, + timeline.get_shard_index(), + ShardIndex::new(shard.get_shard_number(&key), shard.count), + ))); + } + } + let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); // hold guard let effective_lsn = PageServerHandler::effective_request_lsn( &timeline, @@ -3579,12 +3641,13 @@ impl GrpcPageServiceHandler { } /// Processes a GetPage request when there is a potential shard split in progress. We have to - /// reroute the request any local child shards, and split batch requests that straddle multiple - /// child shards. + /// reroute the request to any local child shards, and split batch requests that straddle + /// multiple child shards. /// - /// Parent shards are split and removed incrementally, but the compute is only notified once the - /// entire split commits, which can take several minutes. In the meanwhile, the compute will be - /// sending requests to the parent shard. + /// Parent shards are split and removed incrementally (there may be many parent shards when + /// splitting an already-sharded tenant), but the compute is only notified once the overall + /// split commits, which can take several minutes. 
In the meanwhile, the compute will be sending + /// requests to the parent shards. /// /// TODO: add test infrastructure to provoke this situation frequently and for long periods of /// time, to properly exercise it. @@ -3594,10 +3657,12 @@ impl GrpcPageServiceHandler { /// * Notify the compute about each subsplit. /// * Return an error that updates the compute's shard map. #[instrument(skip_all)] + #[allow(clippy::too_many_arguments)] async fn maybe_split_get_page( ctx: &RequestContext, handles: &mut TimelineHandles, - ttid: TenantTimelineId, + tenant_id: TenantId, + timeline_id: TimelineId, parent: ShardIndex, req: page_api::GetPageRequest, io_concurrency: IoConcurrency, @@ -3608,8 +3673,8 @@ impl GrpcPageServiceHandler { // the page must have a higher shard count. let timeline = handles .get( - ttid.tenant_id, - ttid.timeline_id, + tenant_id, + timeline_id, ShardSelector::Page(rel_block_to_key(req.rel, req.block_numbers[0])), ) .await?; @@ -3621,8 +3686,7 @@ impl GrpcPageServiceHandler { // Fast path: the request fits in a single shard. if let Some(shard_index) = - GetPageSplitter::for_single_shard(&req, shard_id.count, Some(shard_id.stripe_size)) - .map_err(|err| tonic::Status::internal(err.to_string()))? + GetPageSplitter::for_single_shard(&req, shard_id.count, Some(shard_id.stripe_size))? { // We got the shard ID from the first page, so these must be equal. assert_eq!(shard_index.shard_number, shard_id.number); @@ -3633,17 +3697,12 @@ impl GrpcPageServiceHandler { // The request spans multiple shards; split it and dispatch parallel requests. All pages // were originally in the parent shard, and during a split all children are local, so we // expect to find local shards for all pages. - let mut splitter = GetPageSplitter::split(req, shard_id.count, Some(shard_id.stripe_size)) - .map_err(|err| tonic::Status::internal(err.to_string()))?; + let mut splitter = GetPageSplitter::split(req, shard_id.count, Some(shard_id.stripe_size))?; let mut shard_requests = FuturesUnordered::new(); for (shard_index, shard_req) in splitter.drain_requests() { let timeline = handles - .get( - ttid.tenant_id, - ttid.timeline_id, - ShardSelector::Known(shard_index), - ) + .get(tenant_id, timeline_id, ShardSelector::Known(shard_index)) .await?; let future = Self::get_page( ctx, @@ -3657,14 +3716,10 @@ impl GrpcPageServiceHandler { } while let Some((shard_index, shard_response)) = shard_requests.next().await.transpose()? { - splitter - .add_response(shard_index, shard_response) - .map_err(|err| tonic::Status::internal(err.to_string()))?; + splitter.add_response(shard_index, shard_response)?; } - splitter - .get_response() - .map_err(|err| tonic::Status::internal(err.to_string())) + Ok(splitter.collect_response()?) } } @@ -3693,7 +3748,7 @@ impl proto::PageService for GrpcPageServiceHandler { // to be the sweet spot where throughput is saturated. const CHUNK_SIZE: usize = 256 * 1024; - let timeline = self.get_shard_zero_request_timeline(&req).await?; + let timeline = self.get_request_timeline_shard_zero(&req).await?; let ctx = self.ctx.with_scope_timeline(&timeline); // Validate the request and decorate the span. 
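The split/dispatch/reassemble flow that PageserverClient::get_page and maybe_split_get_page now share can be condensed as below. This is an illustrative sketch, not part of the patch: fetch_from_shard is a hypothetical stand-in for get_page_with_shard / Self::get_page, and the parameter types are approximate.

use futures::stream::{FuturesUnordered, StreamExt};

// Sketch of the shared fan-out pattern; `fetch_from_shard` is hypothetical.
async fn fan_out_get_page(
    req: page_api::GetPageRequest,
    count: ShardCount,
    stripe_size: Option<ShardStripeSize>,
) -> Result<page_api::GetPageResponse, tonic::Status> {
    // Fast path: every requested block maps to a single shard, no reassembly needed.
    if let Some(shard_id) = GetPageSplitter::for_single_shard(&req, count, stripe_size)? {
        return fetch_from_shard(shard_id, req).await;
    }

    // Slow path: split per shard, dispatch concurrently, then stitch the pages back
    // together in the original request order.
    let mut splitter = GetPageSplitter::split(req, count, stripe_size)?;
    let mut shard_requests = FuturesUnordered::new();
    for (shard_id, shard_req) in splitter.drain_requests() {
        shard_requests.push(async move { (shard_id, fetch_from_shard(shard_id, shard_req).await) });
    }
    while let Some((shard_id, response)) = shard_requests.next().await {
        // SplitError converts into tonic::Status::internal via the new From impl.
        splitter.add_response(shard_id, response?)?;
    }
    Ok(splitter.collect_response()?)
}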
@@ -3812,7 +3867,7 @@ impl proto::PageService for GrpcPageServiceHandler { req: tonic::Request, ) -> Result, tonic::Status> { let received_at = extract::(&req).0; - let timeline = self.get_shard_zero_request_timeline(&req).await?; + let timeline = self.get_request_timeline_shard_zero(&req).await?; let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); // Validate the request, decorate the span, and convert it to a Pagestream request. @@ -3849,25 +3904,21 @@ impl proto::PageService for GrpcPageServiceHandler { // reroute requests to the child shards below, but we also detect the common cases here // where either the shard exists or no shards exist at all. If we have a child shard, we // can't acquire a weak handle because we don't know which child shard to use yet. - // - // TODO: TimelineHandles.get() does internal retries, which will delay requests during shard - // splits. It shouldn't. - let ttid = *extract::(&req); + let TenantTimelineId { + tenant_id, + timeline_id, + } = *extract::(&req); let shard_index = *extract::(&req); let mut handles = TimelineHandles::new(self.tenant_manager.clone()); let timeline = match handles - .get( - ttid.tenant_id, - ttid.timeline_id, - ShardSelector::Known(shard_index), - ) + .get(tenant_id, timeline_id, ShardSelector::Known(shard_index)) .await { // The timeline shard exists. Keep a weak handle to reuse for each request. Ok(timeline) => Some(timeline.downgrade()), // The shard doesn't exist, but a child shard does. We'll reroute requests later. - Err(_) if handles.has_child_shard(ttid.tenant_id, shard_index) => None, + Err(_) if self.tenant_manager.has_child_shard(tenant_id, shard_index) => None, // Failed to fetch the timeline, and no child shard exists. Error out. Err(err) => return Err(err.into()), }; @@ -3923,7 +3974,8 @@ impl proto::PageService for GrpcPageServiceHandler { Self::maybe_split_get_page( &ctx, &mut handles, - ttid, + tenant_id, + timeline_id, shard_index, req, io_concurrency.clone(), @@ -3958,7 +4010,7 @@ impl proto::PageService for GrpcPageServiceHandler { req: tonic::Request, ) -> Result, tonic::Status> { let received_at = extract::(&req).0; - let timeline = self.get_shard_zero_request_timeline(&req).await?; + let timeline = self.get_request_timeline_shard_zero(&req).await?; let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); // Validate the request, decorate the span, and convert it to a Pagestream request. @@ -3994,11 +4046,10 @@ impl proto::PageService for GrpcPageServiceHandler { req: tonic::Request, ) -> Result, tonic::Status> { let received_at = extract::(&req).0; - let timeline = self.get_shard_zero_request_timeline(&req).await?; + let timeline = self.get_request_timeline_shard_zero(&req).await?; let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); // Validate the request, decorate the span, and convert it to a Pagestream request. - Self::ensure_shard_zero(&timeline)?; let req: page_api::GetSlruSegmentRequest = req.into_inner().try_into()?; span_record!(kind=%req.kind, segno=%req.segno, lsn=%req.read_lsn); diff --git a/pageserver/src/pgdatadir_mapping.rs b/pageserver/src/pgdatadir_mapping.rs index ab9cc88e5f..cedf77fb37 100644 --- a/pageserver/src/pgdatadir_mapping.rs +++ b/pageserver/src/pgdatadir_mapping.rs @@ -6,8 +6,9 @@ //! walingest.rs handles a few things like implicit relation creation and extension. //! Clarify that) //! 
-use std::collections::{HashMap, HashSet, hash_map}; +use std::collections::{BTreeSet, HashMap, HashSet, hash_map}; use std::ops::{ControlFlow, Range}; +use std::sync::Arc; use crate::walingest::{WalIngestError, WalIngestErrorKind}; use crate::{PERF_TRACE_TARGET, ensure_walingest}; @@ -226,6 +227,25 @@ impl Timeline { pending_nblocks: 0, pending_directory_entries: Vec::new(), pending_metadata_bytes: 0, + is_importing_pgdata: false, + lsn, + } + } + + pub fn begin_modification_for_import(&self, lsn: Lsn) -> DatadirModification + where + Self: Sized, + { + DatadirModification { + tline: self, + pending_lsns: Vec::new(), + pending_metadata_pages: HashMap::new(), + pending_data_batch: None, + pending_deletions: Vec::new(), + pending_nblocks: 0, + pending_directory_entries: Vec::new(), + pending_metadata_bytes: 0, + is_importing_pgdata: true, lsn, } } @@ -595,6 +615,50 @@ impl Timeline { self.get_rel_exists_in_reldir(tag, version, None, ctx).await } + async fn get_rel_exists_in_reldir_v1( + &self, + tag: RelTag, + version: Version<'_>, + deserialized_reldir_v1: Option<(Key, &RelDirectory)>, + ctx: &RequestContext, + ) -> Result { + let key = rel_dir_to_key(tag.spcnode, tag.dbnode); + if let Some((cached_key, dir)) = deserialized_reldir_v1 { + if cached_key == key { + return Ok(dir.rels.contains(&(tag.relnode, tag.forknum))); + } else if cfg!(test) || cfg!(feature = "testing") { + panic!("cached reldir key mismatch: {cached_key} != {key}"); + } else { + warn!("cached reldir key mismatch: {cached_key} != {key}"); + } + // Fallback to reading the directory from the datadir. + } + + let buf = version.get(self, key, ctx).await?; + + let dir = RelDirectory::des(&buf)?; + Ok(dir.rels.contains(&(tag.relnode, tag.forknum))) + } + + async fn get_rel_exists_in_reldir_v2( + &self, + tag: RelTag, + version: Version<'_>, + ctx: &RequestContext, + ) -> Result { + let key = rel_tag_sparse_key(tag.spcnode, tag.dbnode, tag.relnode, tag.forknum); + let buf = RelDirExists::decode_option(version.sparse_get(self, key, ctx).await?).map_err( + |_| { + PageReconstructError::Other(anyhow::anyhow!( + "invalid reldir key: decode failed, {}", + key + )) + }, + )?; + let exists_v2 = buf == RelDirExists::Exists; + Ok(exists_v2) + } + /// Does the relation exist? With a cached deserialized `RelDirectory`. /// /// There are some cases where the caller loops across all relations. In that specific case, @@ -626,45 +690,134 @@ impl Timeline { return Ok(false); } - // Read path: first read the new reldir keyspace. Early return if the relation exists. - // Otherwise, read the old reldir keyspace. - // TODO: if IndexPart::rel_size_migration is `Migrated`, we only need to read from v2. + let (v2_status, migrated_lsn) = self.get_rel_size_v2_status(); - if let RelSizeMigration::Migrated | RelSizeMigration::Migrating = - self.get_rel_size_v2_status() - { - // fetch directory listing (new) - let key = rel_tag_sparse_key(tag.spcnode, tag.dbnode, tag.relnode, tag.forknum); - let buf = RelDirExists::decode_option(version.sparse_get(self, key, ctx).await?) - .map_err(|_| PageReconstructError::Other(anyhow::anyhow!("invalid reldir key")))?; - let exists_v2 = buf == RelDirExists::Exists; - // Fast path: if the relation exists in the new format, return true. - // TODO: we should have a verification mode that checks both keyspaces - // to ensure the relation only exists in one of them. 
- if exists_v2 { - return Ok(true); + match v2_status { + RelSizeMigration::Legacy => { + let v1_exists = self + .get_rel_exists_in_reldir_v1(tag, version, deserialized_reldir_v1, ctx) + .await?; + Ok(v1_exists) + } + RelSizeMigration::Migrating | RelSizeMigration::Migrated + if version.get_lsn() < migrated_lsn.unwrap_or(Lsn(0)) => + { + // For requests below the migrated LSN, we still use the v1 read path. + let v1_exists = self + .get_rel_exists_in_reldir_v1(tag, version, deserialized_reldir_v1, ctx) + .await?; + Ok(v1_exists) + } + RelSizeMigration::Migrating => { + let v1_exists = self + .get_rel_exists_in_reldir_v1(tag, version, deserialized_reldir_v1, ctx) + .await?; + let v2_exists_res = self.get_rel_exists_in_reldir_v2(tag, version, ctx).await; + match v2_exists_res { + Ok(v2_exists) if v1_exists == v2_exists => {} + Ok(v2_exists) => { + tracing::warn!( + "inconsistent v1/v2 reldir keyspace for rel {}: v1_exists={}, v2_exists={}", + tag, + v1_exists, + v2_exists + ); + } + Err(e) => { + tracing::warn!("failed to get rel exists in v2: {e}"); + } + } + Ok(v1_exists) + } + RelSizeMigration::Migrated => { + let v2_exists = self.get_rel_exists_in_reldir_v2(tag, version, ctx).await?; + Ok(v2_exists) } } + } - // fetch directory listing (old) - - let key = rel_dir_to_key(tag.spcnode, tag.dbnode); - - if let Some((cached_key, dir)) = deserialized_reldir_v1 { - if cached_key == key { - return Ok(dir.rels.contains(&(tag.relnode, tag.forknum))); - } else if cfg!(test) || cfg!(feature = "testing") { - panic!("cached reldir key mismatch: {cached_key} != {key}"); - } else { - warn!("cached reldir key mismatch: {cached_key} != {key}"); - } - // Fallback to reading the directory from the datadir. - } + async fn list_rels_v1( + &self, + spcnode: Oid, + dbnode: Oid, + version: Version<'_>, + ctx: &RequestContext, + ) -> Result, PageReconstructError> { + let key = rel_dir_to_key(spcnode, dbnode); let buf = version.get(self, key, ctx).await?; - let dir = RelDirectory::des(&buf)?; - let exists_v1 = dir.rels.contains(&(tag.relnode, tag.forknum)); - Ok(exists_v1) + let rels_v1: HashSet = + HashSet::from_iter(dir.rels.iter().map(|(relnode, forknum)| RelTag { + spcnode, + dbnode, + relnode: *relnode, + forknum: *forknum, + })); + Ok(rels_v1) + } + + async fn list_rels_v2( + &self, + spcnode: Oid, + dbnode: Oid, + version: Version<'_>, + ctx: &RequestContext, + ) -> Result, PageReconstructError> { + let key_range = rel_tag_sparse_key_range(spcnode, dbnode); + let io_concurrency = IoConcurrency::spawn_from_conf( + self.conf.get_vectored_concurrent_io, + self.gate + .enter() + .map_err(|_| PageReconstructError::Cancelled)?, + ); + let results = self + .scan( + KeySpace::single(key_range), + version.get_lsn(), + ctx, + io_concurrency, + ) + .await?; + let mut rels = HashSet::new(); + for (key, val) in results { + let val = RelDirExists::decode(&val?).map_err(|_| { + PageReconstructError::Other(anyhow::anyhow!( + "invalid reldir key: decode failed, {}", + key + )) + })?; + if key.field6 != 1 { + return Err(PageReconstructError::Other(anyhow::anyhow!( + "invalid reldir key: field6 != 1, {}", + key + ))); + } + if key.field2 != spcnode { + return Err(PageReconstructError::Other(anyhow::anyhow!( + "invalid reldir key: field2 != spcnode, {}", + key + ))); + } + if key.field3 != dbnode { + return Err(PageReconstructError::Other(anyhow::anyhow!( + "invalid reldir key: field3 != dbnode, {}", + key + ))); + } + let tag = RelTag { + spcnode, + dbnode, + relnode: key.field4, + forknum: key.field5, + }; + if val == 
RelDirExists::Removed { + debug_assert!(!rels.contains(&tag), "removed reltag in v2"); + continue; + } + let did_not_contain = rels.insert(tag); + debug_assert!(did_not_contain, "duplicate reltag in v2"); + } + Ok(rels) } /// Get a list of all existing relations in given tablespace and database. @@ -682,60 +835,45 @@ impl Timeline { version: Version<'_>, ctx: &RequestContext, ) -> Result, PageReconstructError> { - // fetch directory listing (old) - let key = rel_dir_to_key(spcnode, dbnode); - let buf = version.get(self, key, ctx).await?; + let (v2_status, migrated_lsn) = self.get_rel_size_v2_status(); - let dir = RelDirectory::des(&buf)?; - let rels_v1: HashSet = - HashSet::from_iter(dir.rels.iter().map(|(relnode, forknum)| RelTag { - spcnode, - dbnode, - relnode: *relnode, - forknum: *forknum, - })); - - if let RelSizeMigration::Legacy = self.get_rel_size_v2_status() { - return Ok(rels_v1); - } - - // scan directory listing (new), merge with the old results - let key_range = rel_tag_sparse_key_range(spcnode, dbnode); - let io_concurrency = IoConcurrency::spawn_from_conf( - self.conf.get_vectored_concurrent_io, - self.gate - .enter() - .map_err(|_| PageReconstructError::Cancelled)?, - ); - let results = self - .scan( - KeySpace::single(key_range), - version.get_lsn(), - ctx, - io_concurrency, - ) - .await?; - let mut rels = rels_v1; - for (key, val) in results { - let val = RelDirExists::decode(&val?) - .map_err(|_| PageReconstructError::Other(anyhow::anyhow!("invalid reldir key")))?; - assert_eq!(key.field6, 1); - assert_eq!(key.field2, spcnode); - assert_eq!(key.field3, dbnode); - let tag = RelTag { - spcnode, - dbnode, - relnode: key.field4, - forknum: key.field5, - }; - if val == RelDirExists::Removed { - debug_assert!(!rels.contains(&tag), "removed reltag in v2"); - continue; + match v2_status { + RelSizeMigration::Legacy => { + let rels_v1 = self.list_rels_v1(spcnode, dbnode, version, ctx).await?; + Ok(rels_v1) + } + RelSizeMigration::Migrating | RelSizeMigration::Migrated + if version.get_lsn() < migrated_lsn.unwrap_or(Lsn(0)) => + { + // For requests below the migrated LSN, we still use the v1 read path. + let rels_v1 = self.list_rels_v1(spcnode, dbnode, version, ctx).await?; + Ok(rels_v1) + } + RelSizeMigration::Migrating => { + let rels_v1 = self.list_rels_v1(spcnode, dbnode, version, ctx).await?; + let rels_v2_res = self.list_rels_v2(spcnode, dbnode, version, ctx).await; + match rels_v2_res { + Ok(rels_v2) if rels_v1 == rels_v2 => {} + Ok(rels_v2) => { + tracing::warn!( + "inconsistent v1/v2 reldir keyspace for db {} {}: v1_rels.len()={}, v2_rels.len()={}", + spcnode, + dbnode, + rels_v1.len(), + rels_v2.len() + ); + } + Err(e) => { + tracing::warn!("failed to list rels in v2: {e}"); + } + } + Ok(rels_v1) + } + RelSizeMigration::Migrated => { + let rels_v2 = self.list_rels_v2(spcnode, dbnode, version, ctx).await?; + Ok(rels_v2) } - let did_not_contain = rels.insert(tag); - debug_assert!(did_not_contain, "duplicate reltag in v2"); } - Ok(rels) } /// Get the whole SLRU segment @@ -1254,11 +1392,16 @@ impl Timeline { let dbdir = DbDirectory::des(&buf)?; let mut total_size: u64 = 0; - for (spcnode, dbnode) in dbdir.dbdirs.keys() { + let mut dbdir_cnt = 0; + let mut rel_cnt = 0; + + for &(spcnode, dbnode) in dbdir.dbdirs.keys() { + dbdir_cnt += 1; for rel in self - .list_rels(*spcnode, *dbnode, Version::at(lsn), ctx) + .list_rels(spcnode, dbnode, Version::at(lsn), ctx) .await? 
{ + rel_cnt += 1; if self.cancel.is_cancelled() { return Err(CalculateLogicalSizeError::Cancelled); } @@ -1269,6 +1412,10 @@ impl Timeline { total_size += relsize as u64; } } + + self.db_rel_count + .store(Some(Arc::new((dbdir_cnt, rel_cnt)))); + Ok(total_size * BLCKSZ as u64) } @@ -1556,6 +1703,9 @@ pub struct DatadirModification<'a> { /// An **approximation** of how many metadata bytes will be written to the EphemeralFile. pending_metadata_bytes: usize, + + /// Whether we are importing a pgdata directory. + is_importing_pgdata: bool, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -1568,6 +1718,14 @@ pub enum MetricsUpdate { Sub(u64), } +/// Controls the behavior of the reldir keyspace. +pub struct RelDirMode { + // Whether we can read the v2 keyspace or not. + current_status: RelSizeMigration, + // Whether we should initialize the v2 keyspace or not. + initialize: bool, +} + impl DatadirModification<'_> { // When a DatadirModification is committed, we do a monolithic serialization of all its contents. WAL records can // contain multiple pages, so the pageserver's record-based batch size isn't sufficient to bound this allocation: we @@ -1923,30 +2081,49 @@ impl DatadirModification<'_> { } /// Returns `true` if the rel_size_v2 write path is enabled. If it is the first time that - /// we enable it, we also need to persist it in `index_part.json`. - pub fn maybe_enable_rel_size_v2(&mut self) -> anyhow::Result { - let status = self.tline.get_rel_size_v2_status(); + /// we enable it, we also need to persist it in `index_part.json` (initialize is true). + /// + /// As this function is only used on the write path, we do not need to read the migrated_at + /// field. + pub fn maybe_enable_rel_size_v2(&mut self, is_create: bool) -> anyhow::Result { + // TODO: define the behavior of the tenant-level config flag and use feature flag to enable this feature + + let (status, _) = self.tline.get_rel_size_v2_status(); let config = self.tline.get_rel_size_v2_enabled(); match (config, status) { (false, RelSizeMigration::Legacy) => { // tenant config didn't enable it and we didn't write any reldir_v2 key yet - Ok(false) + Ok(RelDirMode { + current_status: RelSizeMigration::Legacy, + initialize: false, + }) } - (false, RelSizeMigration::Migrating | RelSizeMigration::Migrated) => { + (false, status @ RelSizeMigration::Migrating | status @ RelSizeMigration::Migrated) => { // index_part already persisted that the timeline has enabled rel_size_v2 - Ok(true) + Ok(RelDirMode { + current_status: status, + initialize: false, + }) } (true, RelSizeMigration::Legacy) => { // The first time we enable it, we need to persist it in `index_part.json` - self.tline - .update_rel_size_v2_status(RelSizeMigration::Migrating)?; - tracing::info!("enabled rel_size_v2"); - Ok(true) + // The caller should update the reldir status once the initialization is done. + // + // Only initialize the v2 keyspace on new relation creation. No initialization + // during `timeline_create` (TODO: fix this, we should allow, but currently it + // hits consistency issues). 
+                Ok(RelDirMode {
+                    current_status: RelSizeMigration::Legacy,
+                    initialize: is_create && !self.is_importing_pgdata,
+                })
             }
-            (true, RelSizeMigration::Migrating | RelSizeMigration::Migrated) => {
+            (true, status @ RelSizeMigration::Migrating | status @ RelSizeMigration::Migrated) => {
                 // index_part already persisted that the timeline has enabled rel_size_v2
                 // and we don't need to do anything
-                Ok(true)
+                Ok(RelDirMode {
+                    current_status: status,
+                    initialize: false,
+                })
             }
         }
     }
@@ -1959,8 +2136,8 @@ impl DatadirModification<'_> {
         img: Bytes,
         ctx: &RequestContext,
     ) -> Result<(), WalIngestError> {
-        let v2_enabled = self
-            .maybe_enable_rel_size_v2()
+        let v2_mode = self
+            .maybe_enable_rel_size_v2(false)
             .map_err(WalIngestErrorKind::MaybeRelSizeV2Error)?;
         // Add it to the directory (if it doesn't exist already)
@@ -1976,17 +2153,19 @@ impl DatadirModification<'_> {
             self.put(DBDIR_KEY, Value::Image(buf.into()));
         }
         if r.is_none() {
-            // Create RelDirectory
-            // TODO: if we have fully migrated to v2, no need to create this directory
+            if v2_mode.current_status != RelSizeMigration::Legacy {
+                self.pending_directory_entries
+                    .push((DirectoryKind::RelV2, MetricsUpdate::Set(0)));
+            }
+
+            // Create RelDirectory in the v1 keyspace. TODO: if we have fully migrated to v2, no need to create this directory.
+            // Some code paths rely on this directory being present. We should remove it once we start setting tenants to the
+            // `RelSizeMigration::Migrated` state (currently we don't; all tenants will have `RelSizeMigration::Migrating`).
             let buf = RelDirectory::ser(&RelDirectory {
                 rels: HashSet::new(),
             })?;
             self.pending_directory_entries
                 .push((DirectoryKind::Rel, MetricsUpdate::Set(0)));
-            if v2_enabled {
-                self.pending_directory_entries
-                    .push((DirectoryKind::RelV2, MetricsUpdate::Set(0)));
-            }
             self.put(
                 rel_dir_to_key(spcnode, dbnode),
                 Value::Image(Bytes::from(buf)),
@@ -2093,6 +2272,109 @@ impl DatadirModification<'_> {
         Ok(())
     }
+    async fn initialize_rel_size_v2_keyspace(
+        &mut self,
+        ctx: &RequestContext,
+        dbdir: &DbDirectory,
+    ) -> Result<(), WalIngestError> {
+        // Copy everything from the v1 reldir to v2. TODO: check if there's any key in the v2 keyspace; if so, abort.
+ tracing::info!("initializing rel_size_v2 keyspace"); + let mut rel_cnt = 0; + // relmap_exists (the value of dbdirs hashmap) does not affect the migration: we need to copy things over anyways + for &(spcnode, dbnode) in dbdir.dbdirs.keys() { + let rel_dir_key = rel_dir_to_key(spcnode, dbnode); + let rel_dir = RelDirectory::des(&self.get(rel_dir_key, ctx).await?)?; + for (relnode, forknum) in rel_dir.rels { + let sparse_rel_dir_key = rel_tag_sparse_key(spcnode, dbnode, relnode, forknum); + self.put( + sparse_rel_dir_key, + Value::Image(RelDirExists::Exists.encode()), + ); + tracing::info!( + "migrated rel_size_v2: {}", + RelTag { + spcnode, + dbnode, + relnode, + forknum + } + ); + rel_cnt += 1; + } + } + tracing::info!( + "initialized rel_size_v2 keyspace at lsn {}: migrated {} relations", + self.lsn, + rel_cnt + ); + self.tline + .update_rel_size_v2_status(RelSizeMigration::Migrating, Some(self.lsn)) + .map_err(WalIngestErrorKind::MaybeRelSizeV2Error)?; + Ok::<_, WalIngestError>(()) + } + + async fn put_rel_creation_v1( + &mut self, + rel: RelTag, + dbdir_exists: bool, + ctx: &RequestContext, + ) -> Result<(), WalIngestError> { + // Reldir v1 write path + let rel_dir_key = rel_dir_to_key(rel.spcnode, rel.dbnode); + let mut rel_dir = if !dbdir_exists { + // Create the RelDirectory + RelDirectory::default() + } else { + // reldir already exists, fetch it + RelDirectory::des(&self.get(rel_dir_key, ctx).await?)? + }; + + // Add the new relation to the rel directory entry, and write it back + if !rel_dir.rels.insert((rel.relnode, rel.forknum)) { + Err(WalIngestErrorKind::RelationAlreadyExists(rel))?; + } + if !dbdir_exists { + self.pending_directory_entries + .push((DirectoryKind::Rel, MetricsUpdate::Set(0))) + } + self.pending_directory_entries + .push((DirectoryKind::Rel, MetricsUpdate::Add(1))); + self.put( + rel_dir_key, + Value::Image(Bytes::from(RelDirectory::ser(&rel_dir)?)), + ); + Ok(()) + } + + async fn put_rel_creation_v2( + &mut self, + rel: RelTag, + dbdir_exists: bool, + ctx: &RequestContext, + ) -> Result<(), WalIngestError> { + // Reldir v2 write path + let sparse_rel_dir_key = + rel_tag_sparse_key(rel.spcnode, rel.dbnode, rel.relnode, rel.forknum); + // check if the rel_dir_key exists in v2 + let val = self.sparse_get(sparse_rel_dir_key, ctx).await?; + let val = RelDirExists::decode_option(val) + .map_err(|_| WalIngestErrorKind::InvalidRelDirKey(sparse_rel_dir_key))?; + if val == RelDirExists::Exists { + Err(WalIngestErrorKind::RelationAlreadyExists(rel))?; + } + self.put( + sparse_rel_dir_key, + Value::Image(RelDirExists::Exists.encode()), + ); + if !dbdir_exists { + self.pending_directory_entries + .push((DirectoryKind::RelV2, MetricsUpdate::Set(0))); + } + self.pending_directory_entries + .push((DirectoryKind::RelV2, MetricsUpdate::Add(1))); + Ok(()) + } + /// Create a relation fork. /// /// 'nblocks' is the initial size. @@ -2126,66 +2408,31 @@ impl DatadirModification<'_> { true }; - let rel_dir_key = rel_dir_to_key(rel.spcnode, rel.dbnode); - let mut rel_dir = if !dbdir_exists { - // Create the RelDirectory - RelDirectory::default() - } else { - // reldir already exists, fetch it - RelDirectory::des(&self.get(rel_dir_key, ctx).await?)? 
- }; - - let v2_enabled = self - .maybe_enable_rel_size_v2() + let mut v2_mode = self + .maybe_enable_rel_size_v2(true) .map_err(WalIngestErrorKind::MaybeRelSizeV2Error)?; - if v2_enabled { - if rel_dir.rels.contains(&(rel.relnode, rel.forknum)) { - Err(WalIngestErrorKind::RelationAlreadyExists(rel))?; + if v2_mode.initialize { + if let Err(e) = self.initialize_rel_size_v2_keyspace(ctx, &dbdir).await { + tracing::warn!("error initializing rel_size_v2 keyspace: {}", e); + // TODO: circuit breaker so that it won't retry forever + } else { + v2_mode.current_status = RelSizeMigration::Migrating; } - let sparse_rel_dir_key = - rel_tag_sparse_key(rel.spcnode, rel.dbnode, rel.relnode, rel.forknum); - // check if the rel_dir_key exists in v2 - let val = self.sparse_get(sparse_rel_dir_key, ctx).await?; - let val = RelDirExists::decode_option(val) - .map_err(|_| WalIngestErrorKind::InvalidRelDirKey(sparse_rel_dir_key))?; - if val == RelDirExists::Exists { - Err(WalIngestErrorKind::RelationAlreadyExists(rel))?; + } + + if v2_mode.current_status != RelSizeMigration::Migrated { + self.put_rel_creation_v1(rel, dbdir_exists, ctx).await?; + } + + if v2_mode.current_status != RelSizeMigration::Legacy { + let write_v2_res = self.put_rel_creation_v2(rel, dbdir_exists, ctx).await; + if let Err(e) = write_v2_res { + if v2_mode.current_status == RelSizeMigration::Migrated { + return Err(e); + } + tracing::warn!("error writing rel_size_v2 keyspace: {}", e); } - self.put( - sparse_rel_dir_key, - Value::Image(RelDirExists::Exists.encode()), - ); - if !dbdir_exists { - self.pending_directory_entries - .push((DirectoryKind::Rel, MetricsUpdate::Set(0))); - self.pending_directory_entries - .push((DirectoryKind::RelV2, MetricsUpdate::Set(0))); - // We don't write `rel_dir_key -> rel_dir.rels` back to the storage in the v2 path unless it's the initial creation. - // TODO: if we have fully migrated to v2, no need to create this directory. Otherwise, there - // will be key not found errors if we don't create an empty one for rel_size_v2. 
- self.put( - rel_dir_key, - Value::Image(Bytes::from(RelDirectory::ser(&RelDirectory::default())?)), - ); - } - self.pending_directory_entries - .push((DirectoryKind::RelV2, MetricsUpdate::Add(1))); - } else { - // Add the new relation to the rel directory entry, and write it back - if !rel_dir.rels.insert((rel.relnode, rel.forknum)) { - Err(WalIngestErrorKind::RelationAlreadyExists(rel))?; - } - if !dbdir_exists { - self.pending_directory_entries - .push((DirectoryKind::Rel, MetricsUpdate::Set(0))) - } - self.pending_directory_entries - .push((DirectoryKind::Rel, MetricsUpdate::Add(1))); - self.put( - rel_dir_key, - Value::Image(Bytes::from(RelDirectory::ser(&rel_dir)?)), - ); } // Put size @@ -2260,15 +2507,12 @@ impl DatadirModification<'_> { Ok(()) } - /// Drop some relations - pub(crate) async fn put_rel_drops( + async fn put_rel_drop_v1( &mut self, drop_relations: HashMap<(u32, u32), Vec>, ctx: &RequestContext, - ) -> Result<(), WalIngestError> { - let v2_enabled = self - .maybe_enable_rel_size_v2() - .map_err(WalIngestErrorKind::MaybeRelSizeV2Error)?; + ) -> Result, WalIngestError> { + let mut dropped_rels = BTreeSet::new(); for ((spc_node, db_node), rel_tags) in drop_relations { let dir_key = rel_dir_to_key(spc_node, db_node); let buf = self.get(dir_key, ctx).await?; @@ -2280,25 +2524,8 @@ impl DatadirModification<'_> { self.pending_directory_entries .push((DirectoryKind::Rel, MetricsUpdate::Sub(1))); dirty = true; + dropped_rels.insert(rel_tag); true - } else if v2_enabled { - // The rel is not found in the old reldir key, so we need to check the new sparse keyspace. - // Note that a relation can only exist in one of the two keyspaces (guaranteed by the ingestion - // logic). - let key = - rel_tag_sparse_key(spc_node, db_node, rel_tag.relnode, rel_tag.forknum); - let val = RelDirExists::decode_option(self.sparse_get(key, ctx).await?) - .map_err(|_| WalIngestErrorKind::InvalidKey(key, self.lsn))?; - if val == RelDirExists::Exists { - self.pending_directory_entries - .push((DirectoryKind::RelV2, MetricsUpdate::Sub(1))); - // put tombstone - self.put(key, Value::Image(RelDirExists::Removed.encode())); - // no need to set dirty to true - true - } else { - false - } } else { false }; @@ -2321,7 +2548,67 @@ impl DatadirModification<'_> { self.put(dir_key, Value::Image(Bytes::from(RelDirectory::ser(&dir)?))); } } + Ok(dropped_rels) + } + async fn put_rel_drop_v2( + &mut self, + drop_relations: HashMap<(u32, u32), Vec>, + ctx: &RequestContext, + ) -> Result, WalIngestError> { + let mut dropped_rels = BTreeSet::new(); + for ((spc_node, db_node), rel_tags) in drop_relations { + for rel_tag in rel_tags { + let key = rel_tag_sparse_key(spc_node, db_node, rel_tag.relnode, rel_tag.forknum); + let val = RelDirExists::decode_option(self.sparse_get(key, ctx).await?) 
+ .map_err(|_| WalIngestErrorKind::InvalidKey(key, self.lsn))?; + if val == RelDirExists::Exists { + dropped_rels.insert(rel_tag); + self.pending_directory_entries + .push((DirectoryKind::RelV2, MetricsUpdate::Sub(1))); + // put tombstone + self.put(key, Value::Image(RelDirExists::Removed.encode())); + } + } + } + Ok(dropped_rels) + } + + /// Drop some relations + pub(crate) async fn put_rel_drops( + &mut self, + drop_relations: HashMap<(u32, u32), Vec>, + ctx: &RequestContext, + ) -> Result<(), WalIngestError> { + let v2_mode = self + .maybe_enable_rel_size_v2(false) + .map_err(WalIngestErrorKind::MaybeRelSizeV2Error)?; + match v2_mode.current_status { + RelSizeMigration::Legacy => { + self.put_rel_drop_v1(drop_relations, ctx).await?; + } + RelSizeMigration::Migrating => { + let dropped_rels_v1 = self.put_rel_drop_v1(drop_relations.clone(), ctx).await?; + let dropped_rels_v2_res = self.put_rel_drop_v2(drop_relations, ctx).await; + match dropped_rels_v2_res { + Ok(dropped_rels_v2) => { + if dropped_rels_v1 != dropped_rels_v2 { + tracing::warn!( + "inconsistent v1/v2 rel drop: dropped_rels_v1.len()={}, dropped_rels_v2.len()={}", + dropped_rels_v1.len(), + dropped_rels_v2.len() + ); + } + } + Err(e) => { + tracing::warn!("error dropping rels: {}", e); + } + } + } + RelSizeMigration::Migrated => { + self.put_rel_drop_v2(drop_relations, ctx).await?; + } + } Ok(()) } diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 4c8856c386..91b717a2e9 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -1205,6 +1205,7 @@ impl TenantShard { idempotency.clone(), index_part.gc_compaction.clone(), index_part.rel_size_migration.clone(), + index_part.rel_size_migrated_at, ctx, )?; let disk_consistent_lsn = timeline.get_disk_consistent_lsn(); @@ -2584,6 +2585,7 @@ impl TenantShard { initdb_lsn, None, None, + None, ctx, ) .await @@ -2913,6 +2915,7 @@ impl TenantShard { initdb_lsn, None, None, + None, ctx, ) .await @@ -4342,6 +4345,7 @@ impl TenantShard { create_idempotency: CreateTimelineIdempotency, gc_compaction_state: Option, rel_size_v2_status: Option, + rel_size_migrated_at: Option, ctx: &RequestContext, ) -> anyhow::Result<(Arc, RequestContext)> { let state = match cause { @@ -4376,6 +4380,7 @@ impl TenantShard { create_idempotency, gc_compaction_state, rel_size_v2_status, + rel_size_migrated_at, self.cancel.child_token(), ); @@ -5085,6 +5090,7 @@ impl TenantShard { src_timeline.pg_version, ); + let (rel_size_v2_status, rel_size_migrated_at) = src_timeline.get_rel_size_v2_status(); let (uninitialized_timeline, _timeline_ctx) = self .prepare_new_timeline( dst_id, @@ -5092,7 +5098,8 @@ impl TenantShard { timeline_create_guard, start_lsn + 1, Some(Arc::clone(src_timeline)), - Some(src_timeline.get_rel_size_v2_status()), + Some(rel_size_v2_status), + rel_size_migrated_at, ctx, ) .await?; @@ -5379,6 +5386,7 @@ impl TenantShard { pgdata_lsn, None, None, + None, ctx, ) .await?; @@ -5462,14 +5470,17 @@ impl TenantShard { start_lsn: Lsn, ancestor: Option>, rel_size_v2_status: Option, + rel_size_migrated_at: Option, ctx: &RequestContext, ) -> anyhow::Result<(UninitializedTimeline<'a>, RequestContext)> { let tenant_shard_id = self.tenant_shard_id; let resources = self.build_timeline_resources(new_timeline_id); - resources - .remote_client - .init_upload_queue_for_empty_remote(new_metadata, rel_size_v2_status.clone())?; + resources.remote_client.init_upload_queue_for_empty_remote( + new_metadata, + rel_size_v2_status.clone(), + rel_size_migrated_at, + )?; let (timeline_struct, 
timeline_ctx) = self .create_timeline_struct( @@ -5482,6 +5493,7 @@ impl TenantShard { create_guard.idempotency.clone(), None, rel_size_v2_status, + rel_size_migrated_at, ctx, ) .context("Failed to create timeline data structure")?; diff --git a/pageserver/src/tenant/mgr.rs b/pageserver/src/tenant/mgr.rs index 9bd870d90e..0feba5e9c8 100644 --- a/pageserver/src/tenant/mgr.rs +++ b/pageserver/src/tenant/mgr.rs @@ -826,9 +826,9 @@ impl TenantManager { peek_slot.is_some() } - /// Returns whether a local slot exists for a child shard of the given tenant and shard count. - /// Note that this just checks for a shard with a larger shard count, and it may not be a - /// direct child of the given shard. + /// Returns whether a local shard exists that's a child of the given tenant shard. Note that + /// this just checks for any shard with a larger shard count, and it may not be a direct child + /// of the given shard (their keyspace may not overlap). pub(crate) fn has_child_shard(&self, tenant_id: TenantId, shard_index: ShardIndex) -> bool { match &*self.tenants.read().unwrap() { TenantsMap::Initializing => false, @@ -1534,6 +1534,13 @@ impl TenantManager { self.resources.deletion_queue_client.flush_advisory(); // Phase 2: Put the parent shard to InProgress and grab a reference to the parent Tenant + // + // TODO: keeping the parent as InProgress while spawning the children causes read + // unavailability, as we can't acquire a new timeline handle for it (existing handles appear + // to still work though, even downgraded ones). The parent should be available for reads + // until the children are ready -- potentially until *all* subsplits across all parent + // shards are complete and the compute has been notified. See: + // . drop(tenant); let mut parent_slot_guard = self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?; diff --git a/pageserver/src/tenant/remote_timeline_client.rs b/pageserver/src/tenant/remote_timeline_client.rs index fd65000379..6b650beb3f 100644 --- a/pageserver/src/tenant/remote_timeline_client.rs +++ b/pageserver/src/tenant/remote_timeline_client.rs @@ -443,7 +443,8 @@ impl RemoteTimelineClient { pub fn init_upload_queue_for_empty_remote( &self, local_metadata: &TimelineMetadata, - rel_size_v2_status: Option, + rel_size_v2_migration: Option, + rel_size_migrated_at: Option, ) -> anyhow::Result<()> { // Set the maximum number of inprogress tasks to the remote storage concurrency. There's // certainly no point in starting more upload tasks than this. 
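The relation creation and drop paths earlier in this diff both dispatch on the same migration status. A condensed sketch of that policy follows; the closures stand in for the actual keyspace writes, and the real code additionally cross-checks the dropped-relation sets while migrating.

```rust
/// Mirrors the dispatch in `put_rel_creation`/`put_rel_drops` above.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
enum RelSizeMigration { Legacy, Migrating, Migrated }

fn dual_write<E: std::fmt::Debug>(
    status: RelSizeMigration,
    write_v1: impl FnOnce() -> Result<(), E>,
    write_v2: impl FnOnce() -> Result<(), E>,
) -> Result<(), E> {
    // v1 stays authoritative until the migration has fully completed.
    if status != RelSizeMigration::Migrated {
        write_v1()?;
    }
    // v2 is written once the migration has started; while migrating, v2
    // failures are only logged so they cannot break ingestion.
    if status != RelSizeMigration::Legacy {
        if let Err(e) = write_v2() {
            if status == RelSizeMigration::Migrated {
                return Err(e);
            }
            eprintln!("error writing rel_size_v2 keyspace: {e:?}");
        }
    }
    Ok(())
}

fn main() {
    // While migrating, a v2 failure is tolerated because v1 already succeeded.
    let res = dual_write(
        RelSizeMigration::Migrating,
        || -> Result<(), &str> { Ok(()) },
        || Err("v2 keyspace unavailable"),
    );
    assert!(res.is_ok());

    // Once fully migrated, the same v2 failure is fatal.
    let res = dual_write(
        RelSizeMigration::Migrated,
        || -> Result<(), &str> { Ok(()) },
        || Err("v2 keyspace unavailable"),
    );
    assert!(res.is_err());
}
```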
@@ -455,7 +456,8 @@ impl RemoteTimelineClient { let mut upload_queue = self.upload_queue.lock().unwrap(); let initialized_queue = upload_queue.initialize_empty_remote(local_metadata, inprogress_limit)?; - initialized_queue.dirty.rel_size_migration = rel_size_v2_status; + initialized_queue.dirty.rel_size_migration = rel_size_v2_migration; + initialized_queue.dirty.rel_size_migrated_at = rel_size_migrated_at; self.update_remote_physical_size_gauge(None); info!("initialized upload queue as empty"); Ok(()) @@ -994,10 +996,12 @@ impl RemoteTimelineClient { pub(crate) fn schedule_index_upload_for_rel_size_v2_status_update( self: &Arc, rel_size_v2_status: RelSizeMigration, + rel_size_migrated_at: Option, ) -> anyhow::Result<()> { let mut guard = self.upload_queue.lock().unwrap(); let upload_queue = guard.initialized_mut()?; upload_queue.dirty.rel_size_migration = Some(rel_size_v2_status); + upload_queue.dirty.rel_size_migrated_at = rel_size_migrated_at; // TODO: allow this operation to bypass the validation check because we might upload the index part // with no layers but the flag updated. For now, we just modify the index part in memory and the next // upload will include the flag. diff --git a/pageserver/src/tenant/remote_timeline_client/index.rs b/pageserver/src/tenant/remote_timeline_client/index.rs index 6060c42cbb..9531e7f650 100644 --- a/pageserver/src/tenant/remote_timeline_client/index.rs +++ b/pageserver/src/tenant/remote_timeline_client/index.rs @@ -114,6 +114,11 @@ pub struct IndexPart { /// The timestamp when the timeline was marked invisible in synthetic size calculations. #[serde(skip_serializing_if = "Option::is_none", default)] pub(crate) marked_invisible_at: Option, + + /// The LSN at which we started the rel size migration. Accesses below this LSN should be + /// processed with the v1 read path. Usually this LSN should be set together with `rel_size_migration`. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub(crate) rel_size_migrated_at: Option, } #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] @@ -142,10 +147,12 @@ impl IndexPart { /// - 12: +l2_lsn /// - 13: +gc_compaction /// - 14: +marked_invisible_at - const LATEST_VERSION: usize = 14; + /// - 15: +rel_size_migrated_at + const LATEST_VERSION: usize = 15; // Versions we may see when reading from a bucket. 
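The compatibility story for the new index field is the usual serde one: absent in older index files, omitted from output while unset. A small sketch with a toy struct (using `serde`/`serde_json`; the real field is an `Lsn`, represented here as a string for simplicity):

```rust
// Cargo.toml (assumed): serde = { version = "1", features = ["derive"] }, serde_json = "1"
use serde::{Deserialize, Serialize};

/// Toy stand-in for IndexPart: only the pieces relevant to the compatibility story.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct ToyIndexPart {
    version: usize,
    /// Optional, defaulted field: absent in v14-and-older files, omitted from
    /// output while unset, so old and new readers/writers interoperate.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    rel_size_migrated_at: Option<String>,
}

fn main() {
    // A v14-style document without the field still parses; the field defaults to None.
    let old: ToyIndexPart = serde_json::from_str(r#"{ "version": 14 }"#).unwrap();
    assert_eq!(old.rel_size_migrated_at, None);

    // Serializing with the field unset does not emit it at all.
    assert_eq!(serde_json::to_string(&old).unwrap(), r#"{"version":14}"#);

    // A v15 document carries the LSN through.
    let new: ToyIndexPart =
        serde_json::from_str(r#"{ "version": 15, "rel_size_migrated_at": "0/16960E8" }"#).unwrap();
    assert_eq!(new.rel_size_migrated_at.as_deref(), Some("0/16960E8"));
}
```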
- pub const KNOWN_VERSIONS: &'static [usize] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; + pub const KNOWN_VERSIONS: &'static [usize] = + &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]; pub const FILE_NAME: &'static str = "index_part.json"; @@ -165,6 +172,7 @@ impl IndexPart { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, } } @@ -475,6 +483,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -524,6 +533,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -574,6 +584,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -627,6 +638,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let empty_layers_parsed = IndexPart::from_json_bytes(empty_layers_json.as_bytes()).unwrap(); @@ -675,6 +687,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -726,6 +739,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -782,6 +796,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -843,6 +858,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -905,6 +921,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -972,6 +989,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -1052,6 +1070,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -1133,6 +1152,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -1220,6 +1240,7 @@ mod tests { last_completed_lsn: "0/16960E8".parse::().unwrap(), }), marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -1308,6 +1329,97 @@ mod tests { last_completed_lsn: "0/16960E8".parse::().unwrap(), }), marked_invisible_at: Some(parse_naive_datetime("2023-07-31T09:00:00.123000000")), + rel_size_migrated_at: None, + }; + + let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); + assert_eq!(part, expected); + } + + #[test] + fn v15_rel_size_migrated_at_is_parsed() { + let example = r#"{ + "version": 15, + "layer_metadata":{ + "000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9": { "file_size": 25600000 }, + 
"000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51": { "file_size": 9007199254741001 } + }, + "disk_consistent_lsn":"0/16960E8", + "metadata": { + "disk_consistent_lsn": "0/16960E8", + "prev_record_lsn": "0/1696070", + "ancestor_timeline": "e45a7f37d3ee2ff17dc14bf4f4e3f52e", + "ancestor_lsn": "0/0", + "latest_gc_cutoff_lsn": "0/1696070", + "initdb_lsn": "0/1696070", + "pg_version": 14 + }, + "gc_blocking": { + "started_at": "2024-07-19T09:00:00.123", + "reasons": ["DetachAncestor"] + }, + "import_pgdata": { + "V1": { + "Done": { + "idempotency_key": "specified-by-client-218a5213-5044-4562-a28d-d024c5f057f5", + "started_at": "2024-11-13T09:23:42.123", + "finished_at": "2024-11-13T09:42:23.123" + } + } + }, + "rel_size_migration": "legacy", + "l2_lsn": "0/16960E8", + "gc_compaction": { + "last_completed_lsn": "0/16960E8" + }, + "marked_invisible_at": "2023-07-31T09:00:00.123", + "rel_size_migrated_at": "0/16960E8" + }"#; + + let expected = IndexPart { + version: 15, + layer_metadata: HashMap::from([ + ("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9".parse().unwrap(), LayerFileMetadata { + file_size: 25600000, + generation: Generation::none(), + shard: ShardIndex::unsharded() + }), + ("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51".parse().unwrap(), LayerFileMetadata { + file_size: 9007199254741001, + generation: Generation::none(), + shard: ShardIndex::unsharded() + }) + ]), + disk_consistent_lsn: "0/16960E8".parse::().unwrap(), + metadata: TimelineMetadata::new( + Lsn::from_str("0/16960E8").unwrap(), + Some(Lsn::from_str("0/1696070").unwrap()), + Some(TimelineId::from_str("e45a7f37d3ee2ff17dc14bf4f4e3f52e").unwrap()), + Lsn::INVALID, + Lsn::from_str("0/1696070").unwrap(), + Lsn::from_str("0/1696070").unwrap(), + PgMajorVersion::PG14, + ).with_recalculated_checksum().unwrap(), + deleted_at: None, + lineage: Default::default(), + gc_blocking: Some(GcBlocking { + started_at: parse_naive_datetime("2024-07-19T09:00:00.123000000"), + reasons: enumset::EnumSet::from_iter([GcBlockingReason::DetachAncestor]), + }), + last_aux_file_policy: Default::default(), + archived_at: None, + import_pgdata: Some(import_pgdata::index_part_format::Root::V1(import_pgdata::index_part_format::V1::Done(import_pgdata::index_part_format::Done{ + started_at: parse_naive_datetime("2024-11-13T09:23:42.123000000"), + finished_at: parse_naive_datetime("2024-11-13T09:42:23.123000000"), + idempotency_key: import_pgdata::index_part_format::IdempotencyKey::new("specified-by-client-218a5213-5044-4562-a28d-d024c5f057f5".to_string()), + }))), + rel_size_migration: Some(RelSizeMigration::Legacy), + l2_lsn: Some("0/16960E8".parse::().unwrap()), + gc_compaction: Some(GcCompactionState { + last_completed_lsn: "0/16960E8".parse::().unwrap(), + }), + marked_invisible_at: Some(parse_naive_datetime("2023-07-31T09:00:00.123000000")), + rel_size_migrated_at: Some("0/16960E8".parse::().unwrap()), }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 7f6173db3f..ff66b0ecc8 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -70,7 +70,7 @@ use tracing::*; use utils::generation::Generation; use utils::guard_arc_swap::GuardArcSwap; use utils::id::TimelineId; -use utils::logging::{MonitorSlowFutureCallback, 
monitor_slow_future}; +use utils::logging::{MonitorSlowFutureCallback, log_slow, monitor_slow_future}; use utils::lsn::{AtomicLsn, Lsn, RecordLsn}; use utils::postgres_client::PostgresClientProtocol; use utils::rate_limit::RateLimit; @@ -287,7 +287,7 @@ pub struct Timeline { ancestor_lsn: Lsn, // The LSN of gc-compaction that was last applied to this timeline. - gc_compaction_state: ArcSwap>, + gc_compaction_state: ArcSwapOption, pub(crate) metrics: Arc, @@ -397,6 +397,11 @@ pub struct Timeline { /// If true, the last compaction failed. compaction_failed: AtomicBool, + /// Begin Hadron: If true, the pageserver has likely detected data corruption in the timeline. + /// We need to feed this information back to the Safekeeper and postgres for them to take the + /// appropriate action. + corruption_detected: AtomicBool, + /// Notifies the tenant compaction loop that there is pending L0 compaction work. l0_compaction_trigger: Arc, @@ -441,7 +446,7 @@ pub struct Timeline { /// heatmap on demand. heatmap_layers_downloader: Mutex>, - pub(crate) rel_size_v2_status: ArcSwapOption, + pub(crate) rel_size_v2_status: ArcSwap<(Option, Option)>, wait_lsn_log_slow: tokio::sync::Semaphore, @@ -450,6 +455,9 @@ pub struct Timeline { #[expect(dead_code)] feature_resolver: Arc, + + /// Basebackup will collect the count and store it here. Used for reldirv2 rollout. + pub(crate) db_rel_count: ArcSwapOption<(usize, usize)>, } pub(crate) enum PreviousHeatmap { @@ -2891,12 +2899,9 @@ impl Timeline { .unwrap_or(self.conf.default_tenant_conf.rel_size_v2_enabled) } - pub(crate) fn get_rel_size_v2_status(&self) -> RelSizeMigration { - self.rel_size_v2_status - .load() - .as_ref() - .map(|s| s.as_ref().clone()) - .unwrap_or(RelSizeMigration::Legacy) + pub(crate) fn get_rel_size_v2_status(&self) -> (RelSizeMigration, Option) { + let (status, migrated_at) = self.rel_size_v2_status.load().as_ref().clone(); + (status.unwrap_or(RelSizeMigration::Legacy), migrated_at) } fn get_compaction_upper_limit(&self) -> usize { @@ -3171,6 +3176,7 @@ impl Timeline { create_idempotency: crate::tenant::CreateTimelineIdempotency, gc_compaction_state: Option, rel_size_v2_status: Option, + rel_size_migrated_at: Option, cancel: CancellationToken, ) -> Arc { let disk_consistent_lsn = metadata.disk_consistent_lsn(); @@ -3237,7 +3243,7 @@ impl Timeline { }), disk_consistent_lsn: AtomicLsn::new(disk_consistent_lsn.0), - gc_compaction_state: ArcSwap::new(Arc::new(gc_compaction_state)), + gc_compaction_state: ArcSwapOption::from_pointee(gc_compaction_state), last_freeze_at: AtomicLsn::new(disk_consistent_lsn.0), last_freeze_ts: RwLock::new(Instant::now()), @@ -3309,6 +3315,7 @@ impl Timeline { compaction_lock: tokio::sync::Mutex::default(), compaction_failed: AtomicBool::default(), + corruption_detected: AtomicBool::default(), l0_compaction_trigger: resources.l0_compaction_trigger, gc_lock: tokio::sync::Mutex::default(), @@ -3335,13 +3342,18 @@ impl Timeline { heatmap_layers_downloader: Mutex::new(None), - rel_size_v2_status: ArcSwapOption::from_pointee(rel_size_v2_status), + rel_size_v2_status: ArcSwap::from_pointee(( + rel_size_v2_status, + rel_size_migrated_at, + )), wait_lsn_log_slow: tokio::sync::Semaphore::new(1), basebackup_cache: resources.basebackup_cache, feature_resolver: resources.feature_resolver.clone(), + + db_rel_count: ArcSwapOption::from_pointee(None), }; result.repartition_threshold = @@ -3413,7 +3425,7 @@ impl Timeline { gc_compaction_state: GcCompactionState, ) -> anyhow::Result<()> { self.gc_compaction_state - 
.store(Arc::new(Some(gc_compaction_state.clone()))); + .store(Some(Arc::new(gc_compaction_state.clone()))); self.remote_client .schedule_index_upload_for_gc_compaction_state_update(gc_compaction_state) } @@ -3421,15 +3433,24 @@ impl Timeline { pub(crate) fn update_rel_size_v2_status( &self, rel_size_v2_status: RelSizeMigration, + rel_size_migrated_at: Option, ) -> anyhow::Result<()> { - self.rel_size_v2_status - .store(Some(Arc::new(rel_size_v2_status.clone()))); + self.rel_size_v2_status.store(Arc::new(( + Some(rel_size_v2_status.clone()), + rel_size_migrated_at, + ))); self.remote_client - .schedule_index_upload_for_rel_size_v2_status_update(rel_size_v2_status) + .schedule_index_upload_for_rel_size_v2_status_update( + rel_size_v2_status, + rel_size_migrated_at, + ) } pub(crate) fn get_gc_compaction_state(&self) -> Option { - self.gc_compaction_state.load_full().as_ref().clone() + self.gc_compaction_state + .load() + .as_ref() + .map(|x| x.as_ref().clone()) } /// Creates and starts the wal receiver. @@ -5989,6 +6010,17 @@ impl Timeline { ))) }); + // Begin Hadron + // + fail_point!("create-image-layer-fail-simulated-corruption", |_| { + self.corruption_detected + .store(true, std::sync::atomic::Ordering::Relaxed); + Err(CreateImageLayersError::Other(anyhow::anyhow!( + "failpoint create-image-layer-fail-simulated-corruption" + ))) + }); + // End Hadron + let io_concurrency = IoConcurrency::spawn_from_conf( self.conf.get_vectored_concurrent_io, self.gate @@ -6883,7 +6915,13 @@ impl Timeline { write_guard.store_and_unlock(new_gc_cutoff) }; - waitlist.wait().await; + let waitlist_wait_fut = std::pin::pin!(waitlist.wait()); + log_slow( + "applied_gc_cutoff waitlist wait", + Duration::from_secs(30), + waitlist_wait_fut, + ) + .await; info!("GC starting"); @@ -7128,6 +7166,7 @@ impl Timeline { critical_timeline!( self.tenant_shard_id, self.timeline_id, + Some(&self.corruption_detected), "walredo failure during page reconstruction: {err:?}" ); } diff --git a/pageserver/src/tenant/timeline/compaction.rs b/pageserver/src/tenant/timeline/compaction.rs index 9bca952a46..c5363d84b7 100644 --- a/pageserver/src/tenant/timeline/compaction.rs +++ b/pageserver/src/tenant/timeline/compaction.rs @@ -1397,6 +1397,7 @@ impl Timeline { critical_timeline!( self.tenant_shard_id, self.timeline_id, + Some(&self.corruption_detected), "missing key during compaction: {err:?}" ); } @@ -1441,6 +1442,7 @@ impl Timeline { critical_timeline!( self.tenant_shard_id, self.timeline_id, + Some(&self.corruption_detected), "could not compact, repartitioning keyspace failed: {e:?}" ); } diff --git a/pageserver/src/tenant/timeline/delete.rs b/pageserver/src/tenant/timeline/delete.rs index f7dc44be90..2f6eccdbf9 100644 --- a/pageserver/src/tenant/timeline/delete.rs +++ b/pageserver/src/tenant/timeline/delete.rs @@ -332,6 +332,7 @@ impl DeleteTimelineFlow { crate::tenant::CreateTimelineIdempotency::FailWithConflict, // doesn't matter what we put here None, // doesn't matter what we put here None, // doesn't matter what we put here + None, // doesn't matter what we put here ctx, ) .context("create_timeline_struct")?; diff --git a/pageserver/src/tenant/timeline/handle.rs b/pageserver/src/tenant/timeline/handle.rs index 0b118dd65d..537b9ff373 100644 --- a/pageserver/src/tenant/timeline/handle.rs +++ b/pageserver/src/tenant/timeline/handle.rs @@ -224,11 +224,11 @@ use tracing::{instrument, trace}; use utils::id::TimelineId; use utils::shard::{ShardIndex, ShardNumber}; -use crate::tenant::mgr::ShardSelector; +use 
crate::page_service::GetActiveTimelineError; +use crate::tenant::GetTimelineError; +use crate::tenant::mgr::{GetActiveTenantError, ShardSelector}; -/// The requirement for Debug is so that #[derive(Debug)] works in some places. -pub(crate) trait Types: Sized + std::fmt::Debug { - type TenantManagerError: Sized + std::fmt::Debug; +pub(crate) trait Types: Sized { type TenantManager: TenantManager + Sized; type Timeline: Timeline + Sized; } @@ -307,12 +307,11 @@ impl Default for PerTimelineState { /// Abstract view of [`crate::tenant::mgr`], for testability. pub(crate) trait TenantManager { /// Invoked by [`Cache::get`] to resolve a [`ShardTimelineId`] to a [`Types::Timeline`]. - /// Errors are returned as [`GetError::TenantManager`]. async fn resolve( &self, timeline_id: TimelineId, shard_selector: ShardSelector, - ) -> Result; + ) -> Result; } /// Abstract view of an [`Arc`], for testability. @@ -322,13 +321,6 @@ pub(crate) trait Timeline { fn per_timeline_state(&self) -> &PerTimelineState; } -/// Errors returned by [`Cache::get`]. -#[derive(Debug)] -pub(crate) enum GetError { - TenantManager(T::TenantManagerError), - PerTimelineStateShutDown, -} - /// Internal type used in [`Cache::get`]. enum RoutingResult { FastPath(Handle), @@ -345,7 +337,7 @@ impl Cache { timeline_id: TimelineId, shard_selector: ShardSelector, tenant_manager: &T::TenantManager, - ) -> Result, GetError> { + ) -> Result, GetActiveTimelineError> { const GET_MAX_RETRIES: usize = 10; const RETRY_BACKOFF: Duration = Duration::from_millis(100); let mut attempt = 0; @@ -356,13 +348,17 @@ impl Cache { .await { Ok(handle) => return Ok(handle), - Err(e) => { + Err( + e @ GetActiveTimelineError::Tenant(GetActiveTenantError::WaitForActiveTimeout { + .. + }), + ) => { // Retry on tenant manager error to handle tenant split more gracefully if attempt < GET_MAX_RETRIES { tokio::time::sleep(RETRY_BACKOFF).await; continue; } else { - tracing::warn!( + tracing::info!( "Failed to resolve tenant shard after {} attempts: {:?}", GET_MAX_RETRIES, e @@ -370,6 +366,7 @@ impl Cache { return Err(e); } } + Err(err) => return Err(err), } } } @@ -388,7 +385,7 @@ impl Cache { timeline_id: TimelineId, shard_selector: ShardSelector, tenant_manager: &T::TenantManager, - ) -> Result, GetError> { + ) -> Result, GetActiveTimelineError> { // terminates because when every iteration we remove an element from the map let miss: ShardSelector = loop { let routing_state = self.shard_routing(timeline_id, shard_selector); @@ -468,60 +465,50 @@ impl Cache { timeline_id: TimelineId, shard_selector: ShardSelector, tenant_manager: &T::TenantManager, - ) -> Result, GetError> { - match tenant_manager.resolve(timeline_id, shard_selector).await { - Ok(timeline) => { - let key = timeline.shard_timeline_id(); - match &shard_selector { - ShardSelector::Zero => assert_eq!(key.shard_index.shard_number, ShardNumber(0)), - ShardSelector::Page(_) => (), // gotta trust tenant_manager - ShardSelector::Known(idx) => assert_eq!(idx, &key.shard_index), - } - - trace!("creating new HandleInner"); - let timeline = Arc::new(timeline); - let handle_inner_arc = - Arc::new(Mutex::new(HandleInner::Open(Arc::clone(&timeline)))); - let handle_weak = WeakHandle { - inner: Arc::downgrade(&handle_inner_arc), - }; - let handle = handle_weak - .upgrade() - .ok() - .expect("we just created it and it's not linked anywhere yet"); - { - let mut lock_guard = timeline - .per_timeline_state() - .handles - .lock() - .expect("mutex poisoned"); - match &mut *lock_guard { - Some(per_timeline_state) => { - let 
replaced = - per_timeline_state.insert(self.id, Arc::clone(&handle_inner_arc)); - assert!(replaced.is_none(), "some earlier code left a stale handle"); - match self.map.entry(key) { - hash_map::Entry::Occupied(_o) => { - // This cannot not happen because - // 1. we're the _miss_ handle, i.e., `self.map` didn't contain an entry and - // 2. we were holding &mut self during .resolve().await above, so, no other thread can have inserted a handle - // while we were waiting for the tenant manager. - unreachable!() - } - hash_map::Entry::Vacant(v) => { - v.insert(handle_weak); - } - } - } - None => { - return Err(GetError::PerTimelineStateShutDown); - } - } - } - Ok(handle) - } - Err(e) => Err(GetError::TenantManager(e)), + ) -> Result, GetActiveTimelineError> { + let timeline = tenant_manager.resolve(timeline_id, shard_selector).await?; + let key = timeline.shard_timeline_id(); + match &shard_selector { + ShardSelector::Zero => assert_eq!(key.shard_index.shard_number, ShardNumber(0)), + ShardSelector::Page(_) => (), // gotta trust tenant_manager + ShardSelector::Known(idx) => assert_eq!(idx, &key.shard_index), } + + trace!("creating new HandleInner"); + let timeline = Arc::new(timeline); + let handle_inner_arc = Arc::new(Mutex::new(HandleInner::Open(Arc::clone(&timeline)))); + let handle_weak = WeakHandle { + inner: Arc::downgrade(&handle_inner_arc), + }; + let handle = handle_weak + .upgrade() + .ok() + .expect("we just created it and it's not linked anywhere yet"); + let mut lock_guard = timeline + .per_timeline_state() + .handles + .lock() + .expect("mutex poisoned"); + let Some(per_timeline_state) = &mut *lock_guard else { + return Err(GetActiveTimelineError::Timeline( + GetTimelineError::ShuttingDown, + )); + }; + let replaced = per_timeline_state.insert(self.id, Arc::clone(&handle_inner_arc)); + assert!(replaced.is_none(), "some earlier code left a stale handle"); + match self.map.entry(key) { + hash_map::Entry::Occupied(_o) => { + // This cannot not happen because + // 1. we're the _miss_ handle, i.e., `self.map` didn't contain an entry and + // 2. we were holding &mut self during .resolve().await above, so, no other thread can have inserted a handle + // while we were waiting for the tenant manager. 
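A simplified, synchronous model of the retry policy in `Cache::get` above: a bounded number of attempts with a fixed 100 ms backoff, retrying only the transient not-yet-active case (as seen during a shard split), while all other errors fail fast. The error variants below are illustrative stand-ins, not the real `GetActiveTimelineError` tree.

```rust
use std::time::Duration;

/// Toy error type: only the timeout variant is considered transient.
#[derive(Debug)]
enum ResolveError { WaitForActiveTimeout, NotFound, ShuttingDown }

fn resolve_with_retry<T>(
    mut attempt_fn: impl FnMut() -> Result<T, ResolveError>,
) -> Result<T, ResolveError> {
    const MAX_RETRIES: usize = 10;
    const BACKOFF: Duration = Duration::from_millis(100);
    for attempt in 0usize.. {
        match attempt_fn() {
            Ok(v) => return Ok(v),
            Err(ResolveError::WaitForActiveTimeout) if attempt < MAX_RETRIES => {
                // Transient: the tenant may just not be active yet.
                std::thread::sleep(BACKOFF);
            }
            // Everything else (not found, shutting down, exhausted retries) is permanent.
            Err(e) => return Err(e),
        }
    }
    unreachable!("loop either returns or retries")
}

fn main() {
    let mut calls = 0;
    let res = resolve_with_retry(|| {
        calls += 1;
        if calls < 3 { Err(ResolveError::WaitForActiveTimeout) } else { Ok("handle") }
    });
    assert_eq!(res.unwrap(), "handle");
}
```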
+ unreachable!() + } + hash_map::Entry::Vacant(v) => { + v.insert(handle_weak); + } + } + Ok(handle) } } @@ -655,7 +642,8 @@ mod tests { use pageserver_api::models::ShardParameters; use pageserver_api::reltag::RelTag; use pageserver_api::shard::DEFAULT_STRIPE_SIZE; - use utils::shard::ShardCount; + use utils::id::TenantId; + use utils::shard::{ShardCount, TenantShardId}; use utils::sync::gate::GateGuard; use super::*; @@ -665,7 +653,6 @@ mod tests { #[derive(Debug)] struct TestTypes; impl Types for TestTypes { - type TenantManagerError = anyhow::Error; type TenantManager = StubManager; type Timeline = Entered; } @@ -716,40 +703,48 @@ mod tests { &self, timeline_id: TimelineId, shard_selector: ShardSelector, - ) -> anyhow::Result { + ) -> Result { + fn enter_gate( + timeline: &StubTimeline, + ) -> Result, GetActiveTimelineError> { + Ok(Arc::new(timeline.gate.enter().map_err(|_| { + GetActiveTimelineError::Timeline(GetTimelineError::ShuttingDown) + })?)) + } + for timeline in &self.shards { if timeline.id == timeline_id { - let enter_gate = || { - let gate_guard = timeline.gate.enter()?; - let gate_guard = Arc::new(gate_guard); - anyhow::Ok(gate_guard) - }; match &shard_selector { ShardSelector::Zero if timeline.shard.is_shard_zero() => { return Ok(Entered { timeline: Arc::clone(timeline), - gate_guard: enter_gate()?, + gate_guard: enter_gate(timeline)?, }); } ShardSelector::Zero => continue, ShardSelector::Page(key) if timeline.shard.is_key_local(key) => { return Ok(Entered { timeline: Arc::clone(timeline), - gate_guard: enter_gate()?, + gate_guard: enter_gate(timeline)?, }); } ShardSelector::Page(_) => continue, ShardSelector::Known(idx) if idx == &timeline.shard.shard_index() => { return Ok(Entered { timeline: Arc::clone(timeline), - gate_guard: enter_gate()?, + gate_guard: enter_gate(timeline)?, }); } ShardSelector::Known(_) => continue, } } } - anyhow::bail!("not found") + Err(GetActiveTimelineError::Timeline( + GetTimelineError::NotFound { + tenant_id: TenantShardId::unsharded(TenantId::from([0; 16])), + timeline_id, + }, + )) } } diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index f619c69599..7ec5aa3b77 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -365,6 +365,7 @@ pub(super) async fn handle_walreceiver_connection( critical_timeline!( timeline.tenant_shard_id, timeline.timeline_id, + Some(&timeline.corruption_detected), "{msg}" ); return Err(WalReceiverError::Other(anyhow!(msg))); @@ -382,6 +383,7 @@ pub(super) async fn handle_walreceiver_connection( critical_timeline!( timeline.tenant_shard_id, timeline.timeline_id, + Some(&timeline.corruption_detected), "{msg}" ); return Err(WalReceiverError::Other(anyhow!(msg))); @@ -455,6 +457,7 @@ pub(super) async fn handle_walreceiver_connection( critical_timeline!( timeline.tenant_shard_id, timeline.timeline_id, + Some(&timeline.corruption_detected), "{err:?}" ); } @@ -586,6 +589,9 @@ pub(super) async fn handle_walreceiver_connection( remote_consistent_lsn, replytime: ts, shard_number: timeline.tenant_shard_id.shard_number.0 as u32, + corruption_detected: timeline + .corruption_detected + .load(std::sync::atomic::Ordering::Relaxed), }; debug!("neon_status_update {status_update:?}"); diff --git a/pageserver/src/utilization.rs b/pageserver/src/utilization.rs index 0dafa5c4bb..cec28f8059 100644 --- a/pageserver/src/utilization.rs 
+++ b/pageserver/src/utilization.rs @@ -52,7 +52,7 @@ pub(crate) fn regenerate( }; // Express a static value for how many shards we may schedule on one node - const MAX_SHARDS: u32 = 5000; + const MAX_SHARDS: u32 = 2500; let mut doc = PageserverUtilization { disk_usage_bytes: used, diff --git a/pageserver/src/walingest.rs b/pageserver/src/walingest.rs index 3acf98b020..c364334dab 100644 --- a/pageserver/src/walingest.rs +++ b/pageserver/src/walingest.rs @@ -23,6 +23,7 @@ use std::backtrace::Backtrace; use std::collections::HashMap; +use std::sync::atomic::AtomicBool; use std::sync::{Arc, OnceLock}; use std::time::{Duration, Instant, SystemTime}; @@ -422,6 +423,8 @@ impl WalIngest { critical_timeline!( modification.tline.tenant_shard_id, modification.tline.timeline_id, + // Hadron: No need to raise the corruption flag here; the caller of `ingest_record()` will do it. + None::<&AtomicBool>, "clear_vm_bits for unknown VM relation {vm_rel}" ); return Ok(()); @@ -431,6 +434,8 @@ impl WalIngest { critical_timeline!( modification.tline.tenant_shard_id, modification.tline.timeline_id, + // Hadron: No need to raise the corruption flag here; the caller of `ingest_record()` will do it. + None::<&AtomicBool>, "new_vm_blk {blknum} not in {vm_rel} of size {vm_size}" ); new_vm_blk = None; @@ -441,6 +446,8 @@ impl WalIngest { critical_timeline!( modification.tline.tenant_shard_id, modification.tline.timeline_id, + // Hadron: No need to raise the corruption flag here; the caller of `ingest_record()` will do it. + None::<&AtomicBool>, "old_vm_blk {blknum} not in {vm_rel} of size {vm_size}" ); old_vm_blk = None; diff --git a/pgxn/neon/Makefile b/pgxn/neon/Makefile index 04a06fcb63..95a6dd382b 100644 --- a/pgxn/neon/Makefile +++ b/pgxn/neon/Makefile @@ -35,6 +35,10 @@ SHLIB_LINK = -lcurl UNAME_S := $(shell uname -s) ifeq ($(UNAME_S), Darwin) SHLIB_LINK += -framework Security -framework CoreFoundation -framework SystemConfiguration + + # Link against object files for the current macOS version, to avoid spurious linker warnings. + MACOSX_DEPLOYMENT_TARGET := $(shell xcrun --sdk macosx --show-sdk-version) + export MACOSX_DEPLOYMENT_TARGET endif EXTENSION = neon diff --git a/pgxn/neon/communicator.c b/pgxn/neon/communicator.c index 5a08b3e331..4c03193d7e 100644 --- a/pgxn/neon/communicator.c +++ b/pgxn/neon/communicator.c @@ -79,10 +79,6 @@ #include "access/xlogrecovery.h" #endif -#if PG_VERSION_NUM < 160000 -typedef PGAlignedBlock PGIOAlignedBlock; -#endif - #define NEON_PANIC_CONNECTION_STATE(shard_no, elvl, message, ...) 
\ neon_shard_log(shard_no, elvl, "Broken connection state: " message, \ ##__VA_ARGS__) diff --git a/pgxn/neon/communicator/Cargo.toml b/pgxn/neon/communicator/Cargo.toml index 0ea13dce04..ec29bd7e57 100644 --- a/pgxn/neon/communicator/Cargo.toml +++ b/pgxn/neon/communicator/Cargo.toml @@ -34,7 +34,6 @@ tokio-pipe = { version = "0.2.12" } tracing.workspace = true tracing-subscriber.workspace = true -metrics.workspace = true uring-common = { workspace = true, features = ["bytes"] } pageserver_client_grpc.workspace = true diff --git a/pgxn/neon/communicator/src/backend_interface.rs b/pgxn/neon/communicator/src/backend_interface.rs index 0583286ee5..c3ae4a2436 100644 --- a/pgxn/neon/communicator/src/backend_interface.rs +++ b/pgxn/neon/communicator/src/backend_interface.rs @@ -6,9 +6,11 @@ use std::os::fd::OwnedFd; use crate::backend_comms::NeonIORequestSlot; use crate::init::CommunicatorInitStruct; use crate::integrated_cache::{BackendCacheReadOp, IntegratedCacheReadAccess}; -use crate::neon_request::{CCachedGetPageVResult, COid}; +use crate::neon_request::{CCachedGetPageVResult, CLsn, COid}; use crate::neon_request::{NeonIORequest, NeonIOResult}; +use utils::lsn::Lsn; + pub struct CommunicatorBackendStruct<'t> { my_proc_number: i32, @@ -18,7 +20,7 @@ pub struct CommunicatorBackendStruct<'t> { pending_cache_read_op: Option>, - integrated_cache: &'t IntegratedCacheReadAccess, + integrated_cache: &'t IntegratedCacheReadAccess<'t>, } #[unsafe(no_mangle)] @@ -174,17 +176,21 @@ pub extern "C" fn bcomm_finish_cache_read(bs: &mut CommunicatorBackendStruct) -> } } -/// Check if the local file cache contians the given block +/// Check if LFC contains the given buffer, and update its last-written LSN if not. +/// +/// This is used in WAL replay in read replica, to skip updating pages that are +/// not in cache. #[unsafe(no_mangle)] -pub extern "C" fn bcomm_cache_contains( +pub extern "C" fn bcomm_update_lw_lsn_for_block_if_not_cached( bs: &mut CommunicatorBackendStruct, spc_oid: COid, db_oid: COid, rel_number: u32, fork_number: u8, block_number: u32, + lsn: CLsn, ) -> bool { - bs.integrated_cache.cache_contains_page( + bs.integrated_cache.update_lw_lsn_for_block_if_not_cached( &pageserver_page_api::RelTag { spcnode: spc_oid, dbnode: db_oid, @@ -192,6 +198,7 @@ pub extern "C" fn bcomm_cache_contains( forknum: fork_number, }, block_number, + Lsn(lsn), ) } diff --git a/pgxn/neon/communicator/src/file_cache.rs b/pgxn/neon/communicator/src/file_cache.rs index 60cb1f3cd9..085da01a13 100644 --- a/pgxn/neon/communicator/src/file_cache.rs +++ b/pgxn/neon/communicator/src/file_cache.rs @@ -14,6 +14,11 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use std::sync::Mutex; +use measured::metric; +use measured::metric::MetricEncoding; +use measured::metric::gauge::GaugeState; +use measured::{Gauge, MetricGroup}; + use crate::BLCKSZ; use tokio::task::spawn_blocking; @@ -22,7 +27,6 @@ pub type CacheBlock = u64; pub const INVALID_CACHE_BLOCK: CacheBlock = u64::MAX; -#[derive(Debug)] pub struct FileCache { file: Arc, free_list: Mutex, @@ -31,9 +35,17 @@ pub struct FileCache { // on an existing file descroptor, so we have to save the path. 
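The metrics changes in this file and the following ones all follow one pattern: replace hand-rolled prometheus `IntGauge`s plus a custom `Collector` with a `measured` metric group whose field doc comments become the help text, and refresh the gauge values at collection time before delegating to the derived group. A condensed sketch of the declaration side, following the crate usage shown in this diff (not independently verified):

```rust
// Cargo.toml (assumed): measured = "..."
use measured::{Gauge, MetricGroup};

/// Doc comments on the fields become the metric help text.
#[derive(MetricGroup)]
#[metric(new())]
struct FileCacheMetricGroup {
    /// Local File Cache size in 8KiB blocks
    max_blocks: Gauge,

    /// Number of free 8KiB blocks in Local File Cache
    num_free_blocks: Gauge,
}

fn main() {
    let metrics = FileCacheMetricGroup::new();
    // In the real code these are set lazily, inside a manual MetricGroup impl
    // that updates the gauges from the free list and then delegates to the
    // derived `collect_group_into`.
    metrics.max_blocks.set(1024);
    metrics.num_free_blocks.set(512);
}
```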
path: PathBuf, - // metrics - max_blocks_gauge: metrics::IntGauge, - num_free_blocks_gauge: metrics::IntGauge, + metrics: FileCacheMetricGroup, +} + +#[derive(MetricGroup)] +#[metric(new())] +struct FileCacheMetricGroup { + /// Local File Cache size in 8KiB blocks + max_blocks: Gauge, + + /// Number of free 8KiB blocks in Local File Cache + num_free_blocks: Gauge, } // TODO: We keep track of all free blocks in this vec. That doesn't really scale. @@ -64,17 +76,6 @@ impl FileCache { .create(true) .open(file_cache_path)?; - let max_blocks_gauge = metrics::IntGauge::new( - "file_cache_max_blocks", - "Local File Cache size in 8KiB blocks", - ) - .unwrap(); - let num_free_blocks_gauge = metrics::IntGauge::new( - "file_cache_num_free_blocks", - "Number of free 8KiB blocks in Local File Cache", - ) - .unwrap(); - tracing::info!("initialized file cache with {} blocks", initial_size); Ok(FileCache { @@ -85,8 +86,7 @@ impl FileCache { free_blocks: Vec::new(), }), path: file_cache_path.to_path_buf(), - max_blocks_gauge, - num_free_blocks_gauge, + metrics: FileCacheMetricGroup::new(), }) } @@ -222,27 +222,21 @@ impl FileCache { } } -impl metrics::core::Collector for FileCache { - fn desc(&self) -> Vec<&metrics::core::Desc> { - let mut descs = Vec::new(); - descs.append(&mut self.max_blocks_gauge.desc()); - descs.append(&mut self.num_free_blocks_gauge.desc()); - descs - } - fn collect(&self) -> Vec { +impl MetricGroup for FileCache +where + GaugeState: MetricEncoding, +{ + fn collect_group_into(&self, enc: &mut T) -> Result<(), ::Err> { // Update the gauges with fresh values first { let free_list = self.free_list.lock().unwrap(); - self.max_blocks_gauge.set(free_list.max_blocks as i64); + self.metrics.max_blocks.set(free_list.max_blocks as i64); let total_free_blocks: i64 = free_list.free_blocks.len() as i64 + (free_list.max_blocks as i64 - free_list.next_free_block as i64); - self.num_free_blocks_gauge.set(total_free_blocks); + self.metrics.num_free_blocks.set(total_free_blocks); } - let mut values = Vec::new(); - values.append(&mut self.max_blocks_gauge.collect()); - values.append(&mut self.num_free_blocks_gauge.collect()); - values + self.metrics.collect_group_into(enc) } } diff --git a/pgxn/neon/communicator/src/global_allocator.rs b/pgxn/neon/communicator/src/global_allocator.rs index 0c8e88071f..250cad1eb0 100644 --- a/pgxn/neon/communicator/src/global_allocator.rs +++ b/pgxn/neon/communicator/src/global_allocator.rs @@ -18,9 +18,12 @@ use std::alloc::{GlobalAlloc, Layout, System}; use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; -use metrics::IntGauge; +use measured::metric; +use measured::metric::MetricEncoding; +use measured::metric::gauge::GaugeState; +use measured::{Gauge, MetricGroup}; -struct MyAllocator { +pub(crate) struct MyAllocator { allocations: AtomicU64, deallocations: AtomicU64, @@ -28,6 +31,22 @@ struct MyAllocator { high: AtomicUsize, } +#[derive(MetricGroup)] +#[metric(new())] +struct MyAllocatorMetricGroup { + /// Number of allocations in Rust code + communicator_mem_allocations: Gauge, + + /// Number of deallocations in Rust code + communicator_mem_deallocations: Gauge, + + /// Bytes currently allocated + communicator_mem_allocated: Gauge, + + /// High watermark of allocated bytes + communicator_mem_high: Gauge, +} + unsafe impl GlobalAlloc for MyAllocator { unsafe fn alloc(&self, layout: Layout) -> *mut u8 { self.allocations.fetch_add(1, Ordering::Relaxed); @@ -52,58 +71,37 @@ static GLOBAL: MyAllocator = MyAllocator { high: AtomicUsize::new(0), }; -pub struct 
MyAllocatorCollector { - allocations: IntGauge, - deallocations: IntGauge, - allocated: IntGauge, - high: IntGauge, +pub(crate) struct MyAllocatorCollector { + metrics: MyAllocatorMetricGroup, } impl MyAllocatorCollector { - pub fn new() -> MyAllocatorCollector { - MyAllocatorCollector { - allocations: IntGauge::new("allocations_total", "Number of allocations in Rust code") - .unwrap(), - deallocations: IntGauge::new( - "deallocations_total", - "Number of deallocations in Rust code", - ) - .unwrap(), - allocated: IntGauge::new("allocated_total", "Bytes currently allocated").unwrap(), - high: IntGauge::new("allocated_high", "High watermark of allocated bytes").unwrap(), + pub(crate) fn new() -> Self { + Self { + metrics: MyAllocatorMetricGroup::new(), } } } -impl metrics::core::Collector for MyAllocatorCollector { - fn desc(&self) -> Vec<&metrics::core::Desc> { - let mut descs = Vec::new(); - - descs.append(&mut self.allocations.desc()); - descs.append(&mut self.deallocations.desc()); - descs.append(&mut self.allocated.desc()); - descs.append(&mut self.high.desc()); - - descs - } - - fn collect(&self) -> Vec { - let mut values = Vec::new(); - - // update the gauges - self.allocations +impl MetricGroup for MyAllocatorCollector +where + GaugeState: MetricEncoding, +{ + fn collect_group_into(&self, enc: &mut T) -> Result<(), ::Err> { + // Update the gauges with fresh values first + self.metrics + .communicator_mem_allocations .set(GLOBAL.allocations.load(Ordering::Relaxed) as i64); - self.deallocations + self.metrics + .communicator_mem_deallocations .set(GLOBAL.allocations.load(Ordering::Relaxed) as i64); - self.allocated + self.metrics + .communicator_mem_allocated .set(GLOBAL.allocated.load(Ordering::Relaxed) as i64); - self.high.set(GLOBAL.high.load(Ordering::Relaxed) as i64); + self.metrics + .communicator_mem_high + .set(GLOBAL.high.load(Ordering::Relaxed) as i64); - values.append(&mut self.allocations.collect()); - values.append(&mut self.deallocations.collect()); - values.append(&mut self.allocated.collect()); - values.append(&mut self.high.collect()); - - values + self.metrics.collect_group_into(enc) } } diff --git a/pgxn/neon/communicator/src/init.rs b/pgxn/neon/communicator/src/init.rs index 811c77935d..7aebe4afab 100644 --- a/pgxn/neon/communicator/src/init.rs +++ b/pgxn/neon/communicator/src/init.rs @@ -38,7 +38,7 @@ pub struct CommunicatorInitStruct { pub neon_request_slots: &'static [NeonIORequestSlot], - pub integrated_cache_init_struct: IntegratedCacheInitStruct, + pub integrated_cache_init_struct: IntegratedCacheInitStruct<'static>, } impl std::fmt::Debug for CommunicatorInitStruct { @@ -122,8 +122,6 @@ pub extern "C" fn rcommunicator_shmem_init( cis } -// fixme: currently unused -#[allow(dead_code)] pub fn alloc_from_slice( area: &mut [MaybeUninit], ) -> (&mut MaybeUninit, &mut [MaybeUninit]) { diff --git a/pgxn/neon/communicator/src/integrated_cache.rs b/pgxn/neon/communicator/src/integrated_cache.rs index 4c3da29662..d775a4143e 100644 --- a/pgxn/neon/communicator/src/integrated_cache.rs +++ b/pgxn/neon/communicator/src/integrated_cache.rs @@ -23,16 +23,21 @@ // use std::mem::MaybeUninit; -use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; use utils::lsn::{AtomicLsn, Lsn}; use crate::file_cache::INVALID_CACHE_BLOCK; use crate::file_cache::{CacheBlock, FileCache}; +use crate::init::alloc_from_slice; use pageserver_page_api::RelTag; -use 
metrics::{IntCounter, IntGauge}; +use measured::metric; +use measured::metric::MetricEncoding; +use measured::metric::counter::CounterState; +use measured::metric::gauge::GaugeState; +use measured::{Counter, Gauge, MetricGroup}; use neon_shmem::hash::{HashMapInit, entry::Entry}; use neon_shmem::shmem::ShmemHandle; @@ -41,49 +46,78 @@ use neon_shmem::shmem::ShmemHandle; const RELSIZE_CACHE_SIZE: u32 = 64 * 1024; /// This struct is initialized at postmaster startup, and passed to all the processes via fork(). -pub struct IntegratedCacheInitStruct { + +pub struct IntegratedCacheInitStruct<'t> { + shared: &'t IntegratedCacheShared, relsize_cache_handle: HashMapInit, block_map_handle: HashMapInit, } -/// Represents write-access to the integrated cache. This is used by the communicator process. +/// This struct is allocated in the (fixed-size) shared memory area at postmaster startup. +/// It is accessible by all the backends and the communicator process. #[derive(Debug)] -pub struct IntegratedCacheWriteAccess { +pub struct IntegratedCacheShared { + global_lw_lsn: AtomicU64, +} + +/// Represents write-access to the integrated cache. This is used by the communicator process. +pub struct IntegratedCacheWriteAccess<'t> { + shared: &'t IntegratedCacheShared, relsize_cache: neon_shmem::hash::HashMapAccess, block_map: Arc>, - global_lw_lsn: AtomicU64, - pub(crate) file_cache: Option, // Fields for eviction - clock_hand: std::sync::Mutex, + clock_hand: AtomicUsize, - // Metrics - page_evictions_counter: IntCounter, - clock_iterations_counter: IntCounter, + metrics: IntegratedCacheMetricGroup, +} + +#[derive(MetricGroup)] +#[metric(new())] +struct IntegratedCacheMetricGroup { + /// Page evictions from the Local File Cache + cache_page_evictions_counter: Counter, + + /// Block entry evictions from the integrated cache + block_entry_evictions_counter: Counter, + + /// Number of times the clock hand has moved + clock_iterations_counter: Counter, // metrics from the hash map - block_map_num_buckets: IntGauge, - block_map_num_buckets_in_use: IntGauge, + /// Allocated size of the block cache hash map + block_map_num_buckets: Gauge, - relsize_cache_num_buckets: IntGauge, - relsize_cache_num_buckets_in_use: IntGauge, + /// Number of buckets in use in the block cache hash map + block_map_num_buckets_in_use: Gauge, + + /// Allocated size of the relsize cache hash map + relsize_cache_num_buckets: Gauge, + + /// Number of buckets in use in the relsize cache hash map + relsize_cache_num_buckets_in_use: Gauge, } /// Represents read-only access to the integrated cache. Backend processes have this. -pub struct IntegratedCacheReadAccess { +pub struct IntegratedCacheReadAccess<'t> { + shared: &'t IntegratedCacheShared, relsize_cache: neon_shmem::hash::HashMapAccess, block_map: neon_shmem::hash::HashMapAccess, } -impl IntegratedCacheInitStruct { +impl<'t> IntegratedCacheInitStruct<'t> { /// Return the desired size in bytes of the fixed-size shared memory area to reserve for the /// integrated cache. pub fn shmem_size() -> usize { // The relsize cache is fixed-size. The block map is allocated in a separate resizable // area. - HashMapInit::::estimate_size(RELSIZE_CACHE_SIZE) + let mut sz = 0; + sz += std::mem::size_of::(); + sz += HashMapInit::::estimate_size(RELSIZE_CACHE_SIZE); + + sz } /// Initialize the shared memory segment. This runs once in postmaster. 
Returns a struct which @@ -92,10 +126,16 @@ impl IntegratedCacheInitStruct { shmem_area: &'static mut [MaybeUninit], initial_file_cache_size: u64, max_file_cache_size: u64, - ) -> IntegratedCacheInitStruct { - // Initialize the relsize cache in the fixed-size area + ) -> IntegratedCacheInitStruct<'t> { + // Initialize the shared struct + let (shared, remain_shmem_area) = alloc_from_slice::(shmem_area); + let shared = shared.write(IntegratedCacheShared { + global_lw_lsn: AtomicU64::new(0), + }); + + // Use the remaining part of the fixed-size area for the relsize cache let relsize_cache_handle = - neon_shmem::hash::HashMapInit::with_fixed(RELSIZE_CACHE_SIZE, shmem_area); + neon_shmem::hash::HashMapInit::with_fixed(RELSIZE_CACHE_SIZE, remain_shmem_area); let max_bytes = HashMapInit::::estimate_size(max_file_cache_size as u32); @@ -106,6 +146,7 @@ impl IntegratedCacheInitStruct { let block_map_handle = neon_shmem::hash::HashMapInit::with_shmem(initial_file_cache_size as u32, shmem_handle); IntegratedCacheInitStruct { + shared, relsize_cache_handle, block_map_handle, } @@ -116,62 +157,35 @@ impl IntegratedCacheInitStruct { self, lsn: Lsn, file_cache: Option, - ) -> IntegratedCacheWriteAccess { + ) -> IntegratedCacheWriteAccess<'t> { let IntegratedCacheInitStruct { + shared, relsize_cache_handle, block_map_handle, } = self; + + shared.global_lw_lsn.store(lsn.0, Ordering::Relaxed); + IntegratedCacheWriteAccess { + shared, relsize_cache: relsize_cache_handle.attach_writer(), block_map: block_map_handle.attach_writer().into(), - global_lw_lsn: AtomicU64::new(lsn.0), file_cache, - clock_hand: std::sync::Mutex::new(0), - - page_evictions_counter: metrics::IntCounter::new( - "integrated_cache_evictions", - "Page evictions from the Local File Cache", - ) - .unwrap(), - - clock_iterations_counter: metrics::IntCounter::new( - "clock_iterations", - "Number of times the clock hand has moved", - ) - .unwrap(), - - block_map_num_buckets: metrics::IntGauge::new( - "block_map_num_buckets", - "Allocated size of the block cache hash map", - ) - .unwrap(), - block_map_num_buckets_in_use: metrics::IntGauge::new( - "block_map_num_buckets_in_use", - "Number of buckets in use in the block cache hash map", - ) - .unwrap(), - - relsize_cache_num_buckets: metrics::IntGauge::new( - "relsize_cache_num_buckets", - "Allocated size of the relsize cache hash map", - ) - .unwrap(), - relsize_cache_num_buckets_in_use: metrics::IntGauge::new( - "relsize_cache_num_buckets_in_use", - "Number of buckets in use in the relsize cache hash map", - ) - .unwrap(), + clock_hand: AtomicUsize::new(0), + metrics: IntegratedCacheMetricGroup::new(), } } /// Initialize access to the integrated cache for a backend process - pub fn backend_init(self) -> IntegratedCacheReadAccess { + pub fn backend_init(self) -> IntegratedCacheReadAccess<'t> { let IntegratedCacheInitStruct { + shared, relsize_cache_handle, block_map_handle, } = self; IntegratedCacheReadAccess { + shared, relsize_cache: relsize_cache_handle.attach_reader(), block_map: block_map_handle.attach_reader(), } @@ -254,12 +268,25 @@ pub enum CacheResult { NotFound(Lsn), } -impl IntegratedCacheWriteAccess { - pub fn get_rel_size(&self, rel: &RelTag) -> CacheResult { +/// Return type of [try_evict_entry] +enum EvictResult { + /// Could not evict page because it was pinned + Pinned, + + /// The victim bucket was already vacant + Vacant, + + /// Evicted an entry. 
If it had a cache block associated with it, it's returned + /// here, otherwise None + Evicted(Option), +} + +impl<'t> IntegratedCacheWriteAccess<'t> { + pub fn get_rel_size(&'t self, rel: &RelTag) -> CacheResult { if let Some(nblocks) = get_rel_size(&self.relsize_cache, rel) { CacheResult::Found(nblocks) } else { - let lsn = Lsn(self.global_lw_lsn.load(Ordering::Relaxed)); + let lsn = Lsn(self.shared.global_lw_lsn.load(Ordering::Relaxed)); CacheResult::NotFound(lsn) } } @@ -284,7 +311,7 @@ impl IntegratedCacheWriteAccess { return Ok(CacheResult::NotFound(block_entry.lw_lsn.load())); } } else { - let lsn = Lsn(self.global_lw_lsn.load(Ordering::Relaxed)); + let lsn = Lsn(self.shared.global_lw_lsn.load(Ordering::Relaxed)); return Ok(CacheResult::NotFound(lsn)); }; @@ -317,7 +344,7 @@ impl IntegratedCacheWriteAccess { Ok(CacheResult::NotFound(block_entry.lw_lsn.load())) } } else { - let lsn = Lsn(self.global_lw_lsn.load(Ordering::Relaxed)); + let lsn = Lsn(self.shared.global_lw_lsn.load(Ordering::Relaxed)); Ok(CacheResult::NotFound(lsn)) } } @@ -329,7 +356,7 @@ impl IntegratedCacheWriteAccess { if let Some(_rel_entry) = self.relsize_cache.get(&RelKey::from(rel)) { CacheResult::Found(true) } else { - let lsn = Lsn(self.global_lw_lsn.load(Ordering::Relaxed)); + let lsn = Lsn(self.shared.global_lw_lsn.load(Ordering::Relaxed)); CacheResult::NotFound(lsn) } } @@ -340,14 +367,14 @@ impl IntegratedCacheWriteAccess { // e.g. psql \l+ command, so the user will feel the latency. // fixme: is this right lsn? - let lsn = Lsn(self.global_lw_lsn.load(Ordering::Relaxed)); + let lsn = Lsn(self.shared.global_lw_lsn.load(Ordering::Relaxed)); CacheResult::NotFound(lsn) } pub fn remember_rel_size(& self, rel: &RelTag, nblocks: u32, lsn: Lsn) { match self.relsize_cache.entry(RelKey::from(rel)) { Entry::Vacant(e) => { - tracing::info!("inserting rel entry for {rel:?}, {nblocks} blocks"); + tracing::trace!("inserting rel entry for {rel:?}, {nblocks} blocks"); // FIXME: what to do if we run out of memory? Evict other relation entries? 
_ = e .insert(RelEntry { @@ -357,7 +384,7 @@ impl IntegratedCacheWriteAccess { .expect("out of memory"); } Entry::Occupied(e) => { - tracing::info!("updating rel entry for {rel:?}, {nblocks} blocks"); + tracing::trace!("updating rel entry for {rel:?}, {nblocks} blocks"); e.get().nblocks.store(nblocks, Ordering::Relaxed); e.get().lw_lsn.store(lsn); } @@ -417,7 +444,7 @@ impl IntegratedCacheWriteAccess { if let Some(x) = file_cache.alloc_block() { break x; } - if let Some(x) = self.try_evict_one_cache_block() { + if let Some(x) = self.try_evict_cache_block() { break x; } } @@ -432,39 +459,45 @@ impl IntegratedCacheWriteAccess { // FIXME: unpin the block entry on error // Update the block entry - let entry = self.block_map.entry(key); - assert_eq!(found_existing, matches!(entry, Entry::Occupied(_))); - match entry { - Entry::Occupied(e) => { - let block_entry = e.get(); - // Update the cache block - let old_blk = block_entry.cache_block.compare_exchange( - INVALID_CACHE_BLOCK, - cache_block, - Ordering::Relaxed, - Ordering::Relaxed, - ); - assert!(old_blk == Ok(INVALID_CACHE_BLOCK) || old_blk == Err(cache_block)); + loop { + let entry = self.block_map.entry(key.clone()); + assert_eq!(found_existing, matches!(entry, Entry::Occupied(_))); + match entry { + Entry::Occupied(e) => { + let block_entry = e.get(); + // Update the cache block + let old_blk = block_entry.cache_block.compare_exchange( + INVALID_CACHE_BLOCK, + cache_block, + Ordering::Relaxed, + Ordering::Relaxed, + ); + assert!(old_blk == Ok(INVALID_CACHE_BLOCK) || old_blk == Err(cache_block)); - block_entry.lw_lsn.store(lw_lsn); + block_entry.lw_lsn.store(lw_lsn); - block_entry.referenced.store(true, Ordering::Relaxed); + block_entry.referenced.store(true, Ordering::Relaxed); - let pin_count = block_entry.pinned.fetch_sub(1, Ordering::Relaxed); - assert!(pin_count > 0); - } - Entry::Vacant(e) => { - // FIXME: what to do if we run out of memory? Evict other relation entries? Remove - // block entries first? - _ = e - .insert(BlockEntry { + let pin_count = block_entry.pinned.fetch_sub(1, Ordering::Relaxed); + assert!(pin_count > 0); + break; + } + Entry::Vacant(e) => { + if e.insert(BlockEntry { lw_lsn: AtomicLsn::new(lw_lsn.0), cache_block: AtomicU64::new(cache_block), pinned: AtomicU64::new(0), referenced: AtomicBool::new(true), }) - .expect("out of memory"); + .is_ok() + { + break; + } else { + // The hash map was full. Evict an entry and retry. + } + } } + self.try_evict_block_entry(); } } else { // !is_write @@ -479,7 +512,7 @@ impl IntegratedCacheWriteAccess { if let Some(x) = file_cache.alloc_block() { break x; } - if let Some(x) = self.try_evict_one_cache_block() { + if let Some(x) = self.try_evict_cache_block() { break x; } } @@ -492,43 +525,53 @@ impl IntegratedCacheWriteAccess { .expect("error writing to cache"); // FIXME: handle errors gracefully. - match self.block_map.entry(key) { - Entry::Occupied(e) => { - let block_entry = e.get(); - // FIXME: could there be concurrent readers? - assert!(block_entry.pinned.load(Ordering::Relaxed) == 0); + loop { + match self.block_map.entry(key.clone()) { + Entry::Occupied(e) => { + let block_entry = e.get(); + // FIXME: could there be concurrent readers? 
+ assert!(block_entry.pinned.load(Ordering::Relaxed) == 0); - let old_cache_block = - block_entry.cache_block.swap(cache_block, Ordering::Relaxed); - if old_cache_block != INVALID_CACHE_BLOCK { - panic!( - "remember_page called in !is_write mode, but page is already cached at blk {old_cache_block}" - ); + let old_cache_block = + block_entry.cache_block.swap(cache_block, Ordering::Relaxed); + if old_cache_block != INVALID_CACHE_BLOCK { + panic!( + "remember_page called in !is_write mode, but page is already cached at blk {old_cache_block}" + ); + } + break; } - } - Entry::Vacant(e) => { - // FIXME: what to do if we run out of memory? Evict other relation entries? Remove - // block entries first? - _ = e - .insert(BlockEntry { + Entry::Vacant(e) => { + if e.insert(BlockEntry { lw_lsn: AtomicLsn::new(lw_lsn.0), cache_block: AtomicU64::new(cache_block), pinned: AtomicU64::new(0), referenced: AtomicBool::new(true), }) - .expect("out of memory"); - } + .is_ok() + { + break; + } else { + // The hash map was full. Evict an entry and retry. + } + } + }; + + self.try_evict_block_entry(); } } } /// Forget information about given relation in the cache. (For DROP TABLE and such) - pub fn forget_rel(& self, rel: &RelTag, _nblocks: Option, flush_lsn: Lsn) { - tracing::info!("forgetting rel entry for {rel:?}"); + pub fn forget_rel(&'t self, rel: &RelTag, _nblocks: Option, flush_lsn: Lsn) { + tracing::trace!("forgetting rel entry for {rel:?}"); self.relsize_cache.remove(&RelKey::from(rel)); // update with flush LSN - let _ = self.global_lw_lsn.fetch_max(flush_lsn.0, Ordering::Relaxed); + let _ = self + .shared + .global_lw_lsn + .fetch_max(flush_lsn.0, Ordering::Relaxed); // also forget all cached blocks for the relation // FIXME @@ -576,66 +619,144 @@ impl IntegratedCacheWriteAccess { // Maintenance routines - /// Evict one block from the file cache. This is used when the file cache fills up - /// Returns the evicted block. It's not put to the free list, so it's available for the - /// caller to use immediately. - pub fn try_evict_one_cache_block(&self) -> Option { - let mut clock_hand = self.clock_hand.lock().unwrap(); - for _ in 0..100 { - self.clock_iterations_counter.inc(); + /// Evict one block entry from the cache. + /// + /// This is called when the hash map is full, to make an entry available for a new + /// insertion. There's no guarantee that the entry is free by the time this function + /// returns anymore; it can taken by a concurrent thread at any time. So you need to + /// call this and retry repeatedly until you succeed. + fn try_evict_block_entry(&self) { + let num_buckets = self.block_map.get_num_buckets(); + loop { + self.metrics.clock_iterations_counter.inc(); + let victim_bucket = self.clock_hand.fetch_add(1, Ordering::Relaxed) % num_buckets; - (*clock_hand) += 1; - - let mut evict_this = false; - let num_buckets = self.block_map.get_num_logical_buckets(); - match self - .block_map - .get_at_bucket((*clock_hand) % num_buckets) - .as_deref() - { + let evict_this = match self.block_map.get_at_bucket(victim_bucket).as_deref() { None => { - // This bucket was unused + // The caller wants to have a free bucket. If there's one already, we're good. + return; } Some((_, blk_entry)) => { - if !blk_entry.referenced.swap(false, Ordering::Relaxed) { - // Evict this. Maybe. - evict_this = true; + // Clear the 'referenced' flag. If it was already clear, + // release the lock (by exiting this scope), and try to + // evict it. 
+ !blk_entry.referenced.swap(false, Ordering::Relaxed) + } + }; + if evict_this { + match self.try_evict_entry(victim_bucket) { + EvictResult::Pinned => { + // keep looping } + EvictResult::Vacant => { + // This was released by someone else. Return so that + // the caller will try to use it. (Chances are that it + // will be reused by someone else, but let's try.) + return; + } + EvictResult::Evicted(None) => { + // This is now free. + return; + } + EvictResult::Evicted(Some(cache_block)) => { + // This is now free. We must not leak the cache block, so put it to the freelist + self.file_cache.as_ref().unwrap().dealloc_block(cache_block); + return; + } + } + } + // TODO: add some kind of a backstop to error out if we loop + // too many times without finding any unpinned entries + } + } + + /// Evict one block from the file cache. This is called when the file cache fills up, + /// to release a cache block. + /// + /// Returns the evicted block. It's not put to the free list, so it's available for + /// the caller to use immediately. + fn try_evict_cache_block(&self) -> Option { + let num_buckets = self.block_map.get_num_buckets(); + let mut iterations = 0; + while iterations < 100 { + self.metrics.clock_iterations_counter.inc(); + let victim_bucket = self.clock_hand.fetch_add(1, Ordering::Relaxed) % num_buckets; + + let evict_this = match self.block_map.get_at_bucket(victim_bucket).as_deref() { + None => { + // This bucket was unused. It's no use for finding a free cache block + continue; + } + Some((_, blk_entry)) => { + // Clear the 'referenced' flag. If it was already clear, + // release the lock (by exiting this scope), and try to + // evict it. + !blk_entry.referenced.swap(false, Ordering::Relaxed) } }; if evict_this { - // grab the write lock - let mut evicted_cache_block = None; - if let Some(e) = self.block_map.entry_at_bucket(*clock_hand % num_buckets) { - let old = e.get(); - // note: all the accesses to 'pinned' currently happen - // within update_with_fn(), or while holding ValueReadGuard, which protects from concurrent - // updates. Otherwise, another thread could set the 'pinned' - // flag just after we have checked it here. - if old.pinned.load(Ordering::Relaxed) == 0 { - let _ = self - .global_lw_lsn - .fetch_max(old.lw_lsn.load().0, Ordering::Relaxed); - let cache_block = - old.cache_block.swap(INVALID_CACHE_BLOCK, Ordering::Relaxed); - if cache_block != INVALID_CACHE_BLOCK { - evicted_cache_block = Some(cache_block); - } - e.remove(); + match self.try_evict_entry(victim_bucket) { + EvictResult::Pinned => { + // keep looping + } + EvictResult::Vacant => { + // This was released by someone else. Keep looping. + } + EvictResult::Evicted(None) => { + // This is now free, but it didn't have a cache block + // associated with it. Keep looping. + } + EvictResult::Evicted(Some(cache_block)) => { + // Reuse this + return Some(cache_block); } } - - if evicted_cache_block.is_some() { - self.page_evictions_counter.inc(); - return evicted_cache_block; - } } + + iterations += 1; } - // Give up if we didn find anything + + // Reached the max iteration count without finding an entry. 
Return + // to give the caller a chance to do other things None } + /// Returns Err, if the page could not be evicted because it was pinned + fn try_evict_entry(&self, victim: usize) -> EvictResult { + // grab the write lock + if let Some(e) = self.block_map.entry_at_bucket(victim) { + let old = e.get(); + // note: all the accesses to 'pinned' currently happen + // within update_with_fn(), or while holding ValueReadGuard, which protects from concurrent + // updates. Otherwise, another thread could set the 'pinned' + // flag just after we have checked it here. + // + // FIXME: ^^ outdated comment, update_with_fn() is no more + + if old.pinned.load(Ordering::Relaxed) == 0 { + let old_val = e.remove(); + let _ = self + .shared + .global_lw_lsn + .fetch_max(old_val.lw_lsn.into_inner().0, Ordering::Relaxed); + let evicted_cache_block = match old_val.cache_block.into_inner() { + INVALID_CACHE_BLOCK => None, + n => Some(n), + }; + if evicted_cache_block.is_some() { + self.metrics.cache_page_evictions_counter.inc(); + } + self.metrics.block_entry_evictions_counter.inc(); + EvictResult::Evicted(evicted_cache_block) + } else { + EvictResult::Pinned + } + } else { + EvictResult::Vacant + } + } + /// Resize the local file cache. pub fn resize_file_cache(&'static self, num_blocks: u32) { // TODO(quantumish): unclear what the semantics of this entire operation is @@ -657,21 +778,23 @@ impl IntegratedCacheWriteAccess { file_cache.grow(remaining); debug_assert!(file_cache.free_space() > remaining); } else { - let page_evictions = &self.page_evictions_counter; - let global_lw_lsn = &self.global_lw_lsn; + let page_evictions = &self.metrics.cache_page_evictions_counter; + let global_lw_lsn = &self.shared.global_lw_lsn; let block_map = self.block_map.clone(); tokio::task::spawn_blocking(move || { - // Don't hold clock hand lock any longer than necessary, should be ok to evict in parallel - // but we don't want to compete with the eviction logic in the to-be-shrunk region. - { - let mut clock_hand = self.clock_hand.lock().unwrap(); - - block_map.begin_shrink(num_blocks); - // Avoid skipping over beginning entries due to modulo shift. - if *clock_hand > num_blocks as usize { - *clock_hand = num_blocks as usize - 1; + block_map.begin_shrink(num_blocks); + let mut old_hand = self.clock_hand.load(Ordering::Relaxed); + if old_hand > num_blocks as usize { + loop { + match self.clock_hand.compare_exchange_weak( + old_hand, 0, Ordering::Relaxed, Ordering::Relaxed + ) { + Ok(_) => break, + Err(x) => old_hand = x, + } } } + // Try and evict everything in to-be-shrinked space // TODO(quantumish): consider moving things ahead of clock hand? let mut file_evictions = 0; @@ -707,7 +830,7 @@ impl IntegratedCacheWriteAccess { // enough space. Waiting for stragglers at the end of the map could *in theory* // take indefinite amounts of time depending on how long they stay pinned. 
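For reference while reading the eviction helpers above: the EvictResult type they match on is not defined in this hunk. Judging from the match arms, its definition elsewhere in this module is presumably along these lines (a sketch, not the actual code):

    /// Outcome of trying to evict the entry at one bucket (sketch only).
    enum EvictResult {
        /// The entry is currently pinned and cannot be evicted.
        Pinned,
        /// The bucket was already empty when the write lock was taken.
        Vacant,
        /// The entry was removed. If it owned a file-cache block, the block is
        /// returned so the caller can reuse it or put it back on the freelist.
        Evicted(Option<u64>),
    }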
while file_evictions < difference { - if let Some(i) = self.try_evict_one_cache_block() { + if let Some(i) = self.try_evict_cache_block() { if i != INVALID_CACHE_BLOCK { file_cache.delete_block(i); file_evictions += 1; @@ -764,42 +887,31 @@ impl IntegratedCacheWriteAccess { } } -impl metrics::core::Collector for IntegratedCacheWriteAccess { - fn desc(&self) -> Vec<&metrics::core::Desc> { - let mut descs = Vec::new(); - descs.append(&mut self.page_evictions_counter.desc()); - descs.append(&mut self.clock_iterations_counter.desc()); - - descs.append(&mut self.block_map_num_buckets.desc()); - descs.append(&mut self.block_map_num_buckets_in_use.desc()); - - descs.append(&mut self.relsize_cache_num_buckets.desc()); - descs.append(&mut self.relsize_cache_num_buckets_in_use.desc()); - - descs - } - fn collect(&self) -> Vec { +impl MetricGroup for IntegratedCacheWriteAccess<'_> +where + CounterState: MetricEncoding, + GaugeState: MetricEncoding, +{ + fn collect_group_into(&self, enc: &mut T) -> Result<(), ::Err> { // Update gauges - self.block_map_num_buckets + self.metrics + .block_map_num_buckets .set(self.block_map.get_num_buckets() as i64); - self.block_map_num_buckets_in_use + self.metrics + .block_map_num_buckets_in_use .set(self.block_map.get_num_buckets_in_use() as i64); - self.relsize_cache_num_buckets + self.metrics + .relsize_cache_num_buckets .set(self.relsize_cache.get_num_buckets() as i64); - self.relsize_cache_num_buckets_in_use + self.metrics + .relsize_cache_num_buckets_in_use .set(self.relsize_cache.get_num_buckets_in_use() as i64); - let mut values = Vec::new(); - values.append(&mut self.page_evictions_counter.collect()); - values.append(&mut self.clock_iterations_counter.collect()); + if let Some(file_cache) = &self.file_cache { + file_cache.collect_group_into(enc)?; + } - values.append(&mut self.block_map_num_buckets.collect()); - values.append(&mut self.block_map_num_buckets_in_use.collect()); - - values.append(&mut self.relsize_cache_num_buckets.collect()); - values.append(&mut self.relsize_cache_num_buckets_in_use.collect()); - - values + self.metrics.collect_group_into(enc) } } @@ -833,7 +945,7 @@ pub enum GetBucketResult { /// /// This allows backends to read pages from the cache directly, on their own, without making a /// request to the communicator process. -impl IntegratedCacheReadAccess { +impl<'t> IntegratedCacheReadAccess<'t> { pub fn get_rel_size(& self, rel: &RelTag) -> Option { get_rel_size(&self.relsize_cache, rel) } @@ -845,11 +957,64 @@ impl IntegratedCacheReadAccess { } } - /// Check if the given page is present in the cache - pub fn cache_contains_page(& self, rel: &RelTag, block_number: u32) -> bool { - self.block_map - .get(&BlockKey::from((rel, block_number))) - .is_some() + /// Check if LFC contains the given buffer, and update its last-written LSN if not. + /// + /// Returns: + /// true if the block is in the LFC + /// false if it's not. + /// + /// If the block was not in the LFC (i.e. when this returns false), the last-written LSN + /// value on the block is updated to the given 'lsn', so that the next read of the block + /// will read the new version. Otherwise the caller is assumed to modify the page and + /// to update the last-written LSN later by writing the new page. 
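+ ///
+ /// # Example (sketch)
+ ///
+ /// A WAL-redo caller might use this roughly as follows; `apply_record` is a
+ /// hypothetical helper, not part of this patch:
+ ///
+ /// ```ignore
+ /// if read_access.update_lw_lsn_for_block_if_not_cached(&rel, blkno, record_lsn) {
+ ///     // Block is in the LFC: apply the WAL record and write the new page
+ ///     // version, which also advances its last-written LSN.
+ ///     apply_record(&rel, blkno, record_lsn);
+ /// } else {
+ ///     // Not cached: the last-written LSN was already advanced to `record_lsn`,
+ ///     // so the next read of this block will fetch the new version.
+ /// }
+ /// ```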
+ pub fn update_lw_lsn_for_block_if_not_cached( + &'t self, + rel: &RelTag, + block_number: u32, + lsn: Lsn, + ) -> bool { + let key = BlockKey::from((rel, block_number)); + let entry = self.block_map.entry(key); + match entry { + Entry::Occupied(e) => { + let block_entry = e.get(); + if block_entry.cache_block.load(Ordering::Relaxed) != INVALID_CACHE_BLOCK { + block_entry.referenced.store(true, Ordering::Relaxed); + true + } else { + let old_lwlsn = block_entry.lw_lsn.fetch_max(lsn); + if old_lwlsn >= lsn { + // shouldn't happen + tracing::warn!( + "attempted to move last-written LSN backwards from {old_lwlsn} to {lsn} for rel {rel} blk {block_number}" + ); + } + false + } + } + Entry::Vacant(e) => { + if e.insert(BlockEntry { + lw_lsn: AtomicLsn::new(lsn.0), + cache_block: AtomicU64::new(INVALID_CACHE_BLOCK), + pinned: AtomicU64::new(0), + referenced: AtomicBool::new(true), + }) + .is_ok() + { + false + } else { + // The hash table is full. + // + // TODO: Evict something. But for now, just set the global lw LSN instead. + // That's correct, but not very efficient for future reads + let _ = self + .shared + .global_lw_lsn + .fetch_max(lsn.0, Ordering::Relaxed); + false + } + } + } } pub fn get_bucket(&self, bucket_no: usize) -> GetBucketResult { @@ -873,7 +1038,7 @@ impl IntegratedCacheReadAccess { pub struct BackendCacheReadOp<'t> { read_guards: Vec, - map_access: &'t IntegratedCacheReadAccess, + map_access: &'t IntegratedCacheReadAccess<'t>, } impl<'e> BackendCacheReadOp<'e> { diff --git a/pgxn/neon/communicator/src/neon_request.rs b/pgxn/neon/communicator/src/neon_request.rs index e528d451f3..809e645f8c 100644 --- a/pgxn/neon/communicator/src/neon_request.rs +++ b/pgxn/neon/communicator/src/neon_request.rs @@ -36,6 +36,10 @@ pub enum NeonIORequest { PrefetchV(CPrefetchVRequest), DbSize(CDbSizeRequest), + /// This is like GetPageV, but bypasses the LFC and allows specifiying the + /// request LSNs directly. For debugging purposes only. + GetPageVUncached(CGetPageVUncachedRequest), + // Write requests. These are needed to keep the relation size cache and LFC up-to-date. // They are not sent to the pageserver. WritePage(CWritePageRequest), @@ -89,6 +93,7 @@ impl NeonIORequest { Empty => 0, RelSize(req) => req.request_id, GetPageV(req) => req.request_id, + GetPageVUncached(req) => req.request_id, ReadSlruSegment(req) => req.request_id, PrefetchV(req) => req.request_id, DbSize(req) => req.request_id, @@ -191,6 +196,24 @@ pub struct CGetPageVRequest { pub dest: [ShmemBuf; MAX_GETPAGEV_PAGES], } +#[repr(C)] +#[derive(Copy, Clone, Debug)] +pub struct CGetPageVUncachedRequest { + pub request_id: u64, + pub spc_oid: COid, + pub db_oid: COid, + pub rel_number: u32, + pub fork_number: u8, + pub block_number: u32, + pub nblocks: u8, + + pub request_lsn: CLsn, + pub not_modified_since: CLsn, + + // These fields define where the result is written. Must point into a buffer in shared memory! 
+ pub dest: [ShmemBuf; MAX_GETPAGEV_PAGES], +} + #[repr(C)] #[derive(Copy, Clone, Debug)] pub struct CReadSlruSegmentRequest { @@ -331,6 +354,17 @@ impl CGetPageVRequest { } } +impl CGetPageVUncachedRequest { + pub fn reltag(&self) -> page_api::RelTag { + page_api::RelTag { + spcnode: self.spc_oid, + dbnode: self.db_oid, + relnode: self.rel_number, + forknum: self.fork_number, + } + } +} + impl CPrefetchVRequest { pub fn reltag(&self) -> page_api::RelTag { page_api::RelTag { diff --git a/pgxn/neon/communicator/src/worker_process/control_socket.rs b/pgxn/neon/communicator/src/worker_process/control_socket.rs index 00e9f9bf11..5e2b35de1e 100644 --- a/pgxn/neon/communicator/src/worker_process/control_socket.rs +++ b/pgxn/neon/communicator/src/worker_process/control_socket.rs @@ -19,6 +19,9 @@ use http::StatusCode; use http::header::CONTENT_TYPE; use measured::MetricGroup; +use measured::metric::MetricEncoding; +use measured::metric::gauge::GaugeState; +use measured::metric::group::Encoding; use measured::text::BufferedTextEncoder; use std::io::ErrorKind; @@ -27,6 +30,7 @@ use std::sync::Arc; use tokio::net::UnixListener; use crate::NEON_COMMUNICATOR_SOCKET_NAME; +use crate::worker_process::lfc_metrics::LfcMetricsCollector; use crate::worker_process::main_loop::CommunicatorWorkerProcessStruct; enum ControlSocketState<'a> { @@ -34,7 +38,20 @@ enum ControlSocketState<'a> { Legacy(LegacyControlSocketState), } -struct LegacyControlSocketState; +struct LegacyControlSocketState { + pub(crate) lfc_metrics: LfcMetricsCollector, +} + +impl MetricGroup for LegacyControlSocketState +where + T: Encoding, + GaugeState: MetricEncoding, +{ + fn collect_group_into(&self, enc: &mut T) -> Result<(), T::Err> { + self.lfc_metrics.collect_group_into(enc)?; + Ok(()) + } +} /// Launch the listener pub(crate) async fn launch_listener( @@ -44,7 +61,9 @@ pub(crate) async fn launch_listener( let state = match worker { Some(worker) => ControlSocketState::Full(worker), - None => ControlSocketState::Legacy(LegacyControlSocketState), + None => ControlSocketState::Legacy(LegacyControlSocketState { + lfc_metrics: LfcMetricsCollector, + }), }; let app = Router::new() @@ -84,9 +103,7 @@ pub(crate) async fn launch_listener( async fn get_metrics(State(state): State>>) -> Response { match state.as_ref() { ControlSocketState::Full(worker) => metrics_to_response(&worker).await, - ControlSocketState::Legacy(_) => { - todo!() - } + ControlSocketState::Legacy(legacy) => metrics_to_response(&legacy).await, } } @@ -96,9 +113,7 @@ async fn get_metrics(State(state): State>>) -> Respon async fn get_autoscaling_metrics(State(state): State>>) -> Response { match state.as_ref() { ControlSocketState::Full(worker) => metrics_to_response(&worker.lfc_metrics).await, - ControlSocketState::Legacy(_) => { - todo!() - } + ControlSocketState::Legacy(legacy) => metrics_to_response(&legacy.lfc_metrics).await, } } @@ -132,8 +147,10 @@ async fn dump_cache_map(State(state): State>>) -> Res .body(Body::from(buf)) .unwrap() } - ControlSocketState::Legacy(_) => { - todo!() - } + ControlSocketState::Legacy(_) => Response::builder() + .status(StatusCode::NOT_FOUND) + .header(CONTENT_TYPE, "application/text") + .body(Body::from(Vec::new())) + .unwrap(), } } diff --git a/pgxn/neon/communicator/src/worker_process/main_loop.rs b/pgxn/neon/communicator/src/worker_process/main_loop.rs index c9036e8ff5..f7bbed5e00 100644 --- a/pgxn/neon/communicator/src/worker_process/main_loop.rs +++ b/pgxn/neon/communicator/src/worker_process/main_loop.rs @@ -9,7 +9,7 @@ use 
crate::file_cache::FileCache; use crate::global_allocator::MyAllocatorCollector; use crate::init::CommunicatorInitStruct; use crate::integrated_cache::{CacheResult, IntegratedCacheWriteAccess}; -use crate::neon_request::{CGetPageVRequest, CPrefetchVRequest}; +use crate::neon_request::{CGetPageVRequest, CGetPageVUncachedRequest, CPrefetchVRequest}; use crate::neon_request::{INVALID_BLOCK_NUMBER, NeonIORequest, NeonIOResult}; use crate::worker_process::control_socket; use crate::worker_process::in_progress_ios::{RequestInProgressKey, RequestInProgressTable}; @@ -23,6 +23,7 @@ use uring_common::buf::IoBuf; use measured::MetricGroup; use measured::metric::MetricEncoding; +use measured::metric::counter::CounterState; use measured::metric::gauge::GaugeState; use measured::metric::group::Encoding; use measured::{Gauge, GaugeVec}; @@ -30,7 +31,7 @@ use utils::id::{TenantId, TimelineId}; use super::callbacks::{get_request_lsn, notify_proc}; -use tracing::{debug, error, info, info_span, trace}; +use tracing::{error, info, info_span, trace}; use utils::lsn::Lsn; @@ -52,7 +53,7 @@ pub struct CommunicatorWorkerProcessStruct<'a> { in_progress_table: RequestInProgressTable, /// Local File Cache, relation size tracking, last-written LSN tracking - pub(crate) cache: IntegratedCacheWriteAccess, + pub(crate) cache: IntegratedCacheWriteAccess<'a>, /*** Metrics ***/ pub(crate) lfc_metrics: LfcMetricsCollector, @@ -65,7 +66,6 @@ pub struct CommunicatorWorkerProcessStruct<'a> { // For the requests that affect multiple blocks, have separate counters for the # of blocks affected request_nblocks_counters: GaugeVec, - #[allow(dead_code)] allocator_metrics: MyAllocatorCollector, } @@ -145,8 +145,6 @@ pub(super) fn init( .integrated_cache_init_struct .worker_process_init(last_lsn, file_cache); - debug!("Initialised integrated cache: {cache:?}"); - let client = { let _guard = runtime.enter(); PageserverClient::new( @@ -266,7 +264,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { // This needs to be removed once more regression tests are passing. // See also similar hack in the backend code, in wait_request_completion() let result = tokio::time::timeout( - tokio::time::Duration::from_secs(30), + tokio::time::Duration::from_secs(60), self.handle_request(slot.get_request()), ) .await @@ -373,7 +371,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { { Ok(Some(nblocks)) => { // update the cache - tracing::info!( + tracing::trace!( "updated relsize for {:?} in cache: {}, lsn {}", rel, nblocks, @@ -389,8 +387,9 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { NeonIOResult::RelSize(INVALID_BLOCK_NUMBER) } Err(err) => { + // FIXME: Could we map the tonic StatusCode to a libc errno in a more fine-grained way? 
Or pass the error message to the backend info!("tonic error: {err:?}"); - NeonIOResult::Error(0) + NeonIOResult::Error(libc::EIO) } } } @@ -398,6 +397,12 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { Ok(()) => NeonIOResult::GetPageV, Err(errno) => NeonIOResult::Error(errno), }, + NeonIORequest::GetPageVUncached(req) => { + match self.handle_get_pagev_uncached_request(req).await { + Ok(()) => NeonIOResult::GetPageV, + Err(errno) => NeonIOResult::Error(errno), + } + } NeonIORequest::ReadSlruSegment(req) => { let lsn = Lsn(req.request_lsn); let file_path = req.destination_file_path(); @@ -413,7 +418,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { { Ok(slru_bytes) => { if let Err(e) = tokio::fs::write(&file_path, &slru_bytes).await { - info!("could not write slru segment to file {file_path}: {e}"); + error!("could not write slru segment to file {file_path}: {e}"); return NeonIOResult::Error(e.raw_os_error().unwrap_or(libc::EIO)); } @@ -422,8 +427,9 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { NeonIOResult::ReadSlruSegment(blocks_count as _) } Err(err) => { + // FIXME: Could we map the tonic StatusCode to a libc errno in a more fine-grained way? Or pass the error message to the backend info!("tonic error: {err:?}"); - NeonIOResult::Error(0) + NeonIOResult::Error(libc::EIO) } } } @@ -431,6 +437,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { self.request_nblocks_counters .inc_by(RequestTypeLabelGroup::from_req(request), req.nblocks as i64); let req = *req; + // FIXME: handle_request() runs in a separate task already, do we really need to spawn a new one here? tokio::spawn(async move { self.handle_prefetchv_request(&req).await }); NeonIOResult::PrefetchVLaunched } @@ -459,8 +466,9 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { { Ok(db_size) => NeonIOResult::DbSize(db_size), Err(err) => { + // FIXME: Could we map the tonic StatusCode to a libc errno in a more fine-grained way? Or pass the error message to the backend info!("tonic error: {err:?}"); - NeonIOResult::Error(0) + NeonIOResult::Error(libc::EIO) } } } @@ -575,7 +583,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { continue; } Ok(CacheResult::NotFound(lsn)) => lsn, - Err(_io_error) => return Err(-1), // FIXME errno? + Err(_io_error) => return Err(libc::EIO), // FIXME print the error? }; cache_misses.push((blkno, not_modified_since, dest, in_progress_guard)); } @@ -599,7 +607,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { .map(|(blkno, _lsn, _dest, _guard)| *blkno) .collect(); let read_lsn = self.request_lsns(not_modified_since); - info!( + trace!( "sending getpage request for blocks {:?} in rel {:?} lsns {}", block_numbers, rel, read_lsn ); @@ -623,10 +631,10 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { resp.pages.len(), block_numbers.len(), ); - return Err(-1); + return Err(libc::EIO); } - info!( + trace!( "received getpage response for blocks {:?} in rel {:?} lsns {}", block_numbers, rel, read_lsn ); @@ -652,8 +660,75 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { } } Err(err) => { + // FIXME: Could we map the tonic StatusCode to a libc errno in a more fine-grained way? Or pass the error message to the backend info!("tonic error: {err:?}"); - return Err(-1); + return Err(libc::EIO); + } + } + Ok(()) + } + + /// Subroutine to handle an GetPageVUncached request. + /// + /// Note: this bypasses the cache, in-progress IO locking, and all other side-effects. + /// This request type is only used in tests. 
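A side note on the repeated error-handling FIXMEs in this hunk, which ask whether the tonic status could be mapped to something more specific than EIO: one possible mapping, shown here purely as a sketch (the concrete policy is left open by this patch), might look like this:

    // Sketch only: a possible tonic::Code -> errno mapping for the FIXMEs above.
    // The particular errno choices are assumptions, not something this patch defines.
    fn tonic_code_to_errno(code: tonic::Code) -> i32 {
        use tonic::Code;
        match code {
            Code::NotFound => libc::ENOENT,
            Code::InvalidArgument | Code::OutOfRange => libc::EINVAL,
            Code::DeadlineExceeded => libc::ETIMEDOUT,
            Code::PermissionDenied | Code::Unauthenticated => libc::EACCES,
            Code::ResourceExhausted | Code::Unavailable => libc::EAGAIN,
            _ => libc::EIO,
        }
    }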
+ async fn handle_get_pagev_uncached_request( + &'t self, + req: &CGetPageVUncachedRequest, + ) -> Result<(), i32> { + let rel = req.reltag(); + + // Construct a pageserver request + let block_numbers: Vec = + (req.block_number..(req.block_number + (req.nblocks as u32))).collect(); + let read_lsn = page_api::ReadLsn { + request_lsn: Lsn(req.request_lsn), + not_modified_since_lsn: Some(Lsn(req.not_modified_since)), + }; + trace!( + "sending (uncached) getpage request for blocks {:?} in rel {:?} lsns {}", + block_numbers, rel, read_lsn + ); + match self + .client + .get_page(page_api::GetPageRequest { + request_id: req.request_id.into(), + request_class: page_api::GetPageClass::Normal, + read_lsn, + rel, + block_numbers: block_numbers.clone(), + }) + .await + { + Ok(resp) => { + // Write the received page images directly to the shared memory location + // that the backend requested. + if resp.pages.len() != block_numbers.len() { + error!( + "received unexpected response with {} page images from pageserver for a request for {} pages", + resp.pages.len(), + block_numbers.len(), + ); + return Err(libc::EIO); + } + + trace!( + "received getpage response for blocks {:?} in rel {:?} lsns {}", + block_numbers, rel, read_lsn + ); + + for (page, dest) in resp.pages.into_iter().zip(req.dest) { + let src: &[u8] = page.image.as_ref(); + let len = std::cmp::min(src.len(), dest.bytes_total()); + unsafe { + std::ptr::copy_nonoverlapping(src.as_ptr(), dest.as_mut_ptr(), len); + }; + } + } + Err(err) => { + // FIXME: Could we map the tonic StatusCode to a libc errno in a more fine-grained way? Or pass the error message to the backend + info!("tonic error: {err:?}"); + return Err(libc::EIO); } } Ok(()) @@ -684,7 +759,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { continue; } Ok(CacheResult::NotFound(lsn)) => lsn, - Err(_io_error) => return Err(-1), // FIXME errno? + Err(_io_error) => return Err(libc::EIO), // FIXME print the error? }; cache_misses.push((blkno, not_modified_since, in_progress_guard)); } @@ -726,7 +801,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { resp.pages.len(), block_numbers.len(), ); - return Err(-1); + return Err(libc::EIO); } for (page, (blkno, _lsn, _guard)) in resp.pages.into_iter().zip(cache_misses) { @@ -736,8 +811,9 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { } } Err(err) => { + // FIXME: Could we map the tonic StatusCode to a libc errno in a more fine-grained way? 
Or pass the error message to the backend info!("tonic error: {err:?}"); - return Err(-1); + return Err(libc::EIO); } } Ok(()) @@ -747,6 +823,7 @@ impl<'t> CommunicatorWorkerProcessStruct<'t> { impl MetricGroup for CommunicatorWorkerProcessStruct<'_> where T: Encoding, + CounterState: MetricEncoding, GaugeState: MetricEncoding, { fn collect_group_into(&self, enc: &mut T) -> Result<(), T::Err> { @@ -754,12 +831,12 @@ where use measured::metric::name::MetricName; self.lfc_metrics.collect_group_into(enc)?; + self.cache.collect_group_into(enc)?; self.request_counters .collect_family_into(MetricName::from_str("request_counters"), enc)?; self.request_nblocks_counters .collect_family_into(MetricName::from_str("request_nblocks_counters"), enc)?; - - // FIXME: allocator metrics + self.allocator_metrics.collect_group_into(enc)?; Ok(()) } diff --git a/pgxn/neon/communicator_new.c b/pgxn/neon/communicator_new.c index c4b7263e19..f4088ab264 100644 --- a/pgxn/neon/communicator_new.c +++ b/pgxn/neon/communicator_new.c @@ -110,12 +110,12 @@ typedef struct CommunicatorShmemData * * Note that this is not protected by any locks. That's sloppy, but works * fine in practice. To "add" a value to the HLL state, we just overwrite - * one of the timestamps. Calculating the estimate reads all the values, but - * it also doesn't depend on seeing a consistent snapshot of the values. We - * could get bogus results if accessing the TimestampTz was not atomic, but - * it on any 64-bit platforms we care about it is, and even if we observed a - * torn read every now and then, it wouldn't affect the overall estimate - * much. + * one of the timestamps. Calculating the estimate reads all the values, + * but it also doesn't depend on seeing a consistent snapshot of the + * values. We could get bogus results if accessing the TimestampTz was not + * atomic, but it on any 64-bit platforms we care about it is, and even if + * we observed a torn read every now and then, it wouldn't affect the + * overall estimate much. */ HyperLogLogState wss_estimation; @@ -397,21 +397,23 @@ communicator_new_prefetch_register_bufferv(NRelFileInfo rinfo, ForkNumber forkNu } /* - * Does the LFC contains the given buffer? + * Check if LFC contains the given buffer, and update its last-written LSN if + * not. * * This is used in WAL replay in read replica, to skip updating pages that are * not in cache. 
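+ *
+ * Returns true if the block is present in the LFC.  If it is not (false is
+ * returned), the block's last-written LSN is first advanced to 'lsn', so that
+ * the next read of the block fetches the new page version from the pageserver
+ * and the redo of that page can be skipped locally.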
*/ bool -communicator_new_cache_contains(NRelFileInfo rinfo, ForkNumber forkNum, - BlockNumber blockno) +communicator_new_update_lwlsn_for_block_if_not_cached(NRelFileInfo rinfo, ForkNumber forkNum, + BlockNumber blockno, XLogRecPtr lsn) { - return bcomm_cache_contains(my_bs, - NInfoGetSpcOid(rinfo), - NInfoGetDbOid(rinfo), - NInfoGetRelNumber(rinfo), - forkNum, - blockno); + return bcomm_update_lw_lsn_for_block_if_not_cached(my_bs, + NInfoGetSpcOid(rinfo), + NInfoGetDbOid(rinfo), + NInfoGetRelNumber(rinfo), + forkNum, + blockno, + lsn); } /* Dump a list of blocks in the LFC, for use in prewarming later */ @@ -419,8 +421,9 @@ FileCacheState * communicator_new_get_lfc_state(size_t max_entries) { struct FileCacheIterator iter; - FileCacheState* fcs; + FileCacheState *fcs; uint8 *bitmap; + /* TODO: Max(max_entries, ) */ size_t n_entries = max_entries; size_t state_size = FILE_CACHE_STATE_SIZE_FOR_CHUNKS(n_entries, 1); @@ -436,14 +439,17 @@ communicator_new_get_lfc_state(size_t max_entries) bcomm_cache_iterate_begin(my_bs, &iter); while (n_pages < max_entries && bcomm_cache_iterate_next(my_bs, &iter)) { - BufferTag tag; + BufferTag tag; BufTagInit(tag, iter.rel_number, iter.fork_number, iter.block_number, iter.spc_oid, iter.db_oid); fcs->chunks[n_pages] = tag; n_pages++; } - /* fill bitmap. TODO: memset would be more efficient, but this is a silly format anyway */ + /* + * fill bitmap. TODO: memset would be more efficient, but this is a silly + * format anyway + */ for (size_t i = 0; i < n_pages; i++) { BITMAP_SET(bitmap, i); @@ -526,7 +532,7 @@ start_request(NeonIORequest *request, struct NeonIOResult *immediate_result_p) inflight_requests[num_inflight_requests] = request_idx; num_inflight_requests++; - elog(LOG, "started communicator request %s at slot %d", print_neon_io_request(request), request_idx); + elog(DEBUG5, "started communicator request %s at slot %d", print_neon_io_request(request), request_idx); return request_idx; } @@ -550,8 +556,8 @@ wait_request_completion(int request_idx, struct NeonIOResult *result_p) if (poll_res == -1) { /* - * Wake up periodically for CHECK_FOR_INTERRUPTS(). Because - * we wait on MyIOCompletionLatch rather than MyLatch, we won't be + * Wake up periodically for CHECK_FOR_INTERRUPTS(). Because we + * wait on MyIOCompletionLatch rather than MyLatch, we won't be * woken up for the standard interrupts. */ long timeout_ms = 1000; @@ -565,13 +571,14 @@ wait_request_completion(int request_idx, struct NeonIOResult *result_p) CHECK_FOR_INTERRUPTS(); /* - * FIXME: as a temporary hack, panic if we don't get a response promptly. - * Lots of regression tests are getting stuck and failing at the moment, - * this makes them fail a little faster, which it faster to iterate. - * This needs to be removed once more regression tests are passing. + * FIXME: as a temporary hack, panic if we don't get a response + * promptly. Lots of regression tests are getting stuck and + * failing at the moment, this makes them fail a little faster, + * which it faster to iterate. This needs to be removed once more + * regression tests are passing. 
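+ * (The communicator-side request timeout in main_loop.rs is raised to 60
+ * seconds in this patch; the 120 second backstop below is presumably kept
+ * longer so that the worker process gets a chance to report an error first.)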
*/ now = GetCurrentTimestamp(); - if (now - start_time > 60 * 1000 * 1000) + if (now - start_time > 120 * 1000 * 1000) { elog(PANIC, "timed out waiting for response from communicator process at slot %d", request_idx); } @@ -618,10 +625,11 @@ communicator_new_rel_exists(NRelFileInfo rinfo, ForkNumber forkNum) case NeonIOResult_RelSize: return result.rel_size != InvalidBlockNumber; case NeonIOResult_Error: + errno = result.error; ereport(ERROR, (errcode_for_file_access(), - errmsg("could not check existence of rel %u/%u/%u.%u: %s", - RelFileInfoFmt(rinfo), forkNum, pg_strerror(result.error)))); + errmsg("could not check existence of rel %u/%u/%u.%u: %m", + RelFileInfoFmt(rinfo), forkNum))); break; default: elog(ERROR, "unexpected result for RelSize operation: %d", result.tag); @@ -633,8 +641,8 @@ communicator_new_rel_exists(NRelFileInfo rinfo, ForkNumber forkNum) * Read N consecutive pages from a relation */ void -communicator_new_read_at_lsnv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blockno, - void **buffers, BlockNumber nblocks) +communicator_new_readv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blockno, + void **buffers, BlockNumber nblocks) { NeonIOResult result; CCachedGetPageVResult cached_result; @@ -654,7 +662,7 @@ communicator_new_read_at_lsnv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumbe }; { - BufferTag tag; + BufferTag tag; CopyNRelFileInfoToBufTag(tag, rinfo); tag.forkNum = forkNum; @@ -662,7 +670,7 @@ communicator_new_read_at_lsnv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumbe { tag.blockNum = blockno; addSHLL(&communicator_shmem_ptr->wss_estimation, - hash_bytes((uint8_t *) &tag, sizeof(tag))); + hash_bytes((uint8_t *) & tag, sizeof(tag))); } } @@ -696,8 +704,8 @@ communicator_new_read_at_lsnv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumbe /* Split the vector-request into single page requests */ for (int j = 0; j < nblocks; j++) { - communicator_new_read_at_lsnv(rinfo, forkNum, blockno + j, - &buffers[j], 1); + communicator_new_readv(rinfo, forkNum, blockno + j, + &buffers[j], 1); } return; } @@ -789,13 +797,75 @@ retry: memcpy(buffers[0], bounce_buf_used, BLCKSZ); return; case NeonIOResult_Error: - ereport(ERROR, - (errcode_for_file_access(), - errmsg("could not read block %u in rel %u/%u/%u.%u: %s", - blockno, RelFileInfoFmt(rinfo), forkNum, pg_strerror(result.error)))); + errno = result.error; + if (nblocks > 0) + ereport(ERROR, + (errcode_for_file_access(), + errmsg("could not read block %u in rel %u/%u/%u.%u: %m", + blockno, RelFileInfoFmt(rinfo), forkNum))); + else + ereport(ERROR, + (errcode_for_file_access(), + errmsg("could not read %u blocks at %u in rel %u/%u/%u.%u: %m", + nblocks, blockno, RelFileInfoFmt(rinfo), forkNum))); break; default: - elog(ERROR, "unexpected result for GetPage operation: %d", result.tag); + elog(ERROR, "unexpected result for GetPageV operation: %d", result.tag); + break; + } +} + +/* + * Read a page at given LSN, bypassing the LFC. + * + * For tests and debugging purposes only. 
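+ *
+ * This issues a NeonIORequest_GetPageVUncached request, handled by
+ * handle_get_pagev_uncached_request() in main_loop.rs; it bypasses the LFC and
+ * the in-progress IO tracking entirely and always reads a single block.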
+ */ +void +communicator_new_read_at_lsn_uncached(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blockno, + void *buffer, XLogRecPtr request_lsn, XLogRecPtr not_modified_since) +{ + NeonIOResult result; + void *bounce_buf_used; + NeonIORequest request = { + .tag = NeonIORequest_GetPageVUncached, + .get_page_v_uncached = { + .request_id = assign_request_id(), + .spc_oid = NInfoGetSpcOid(rinfo), + .db_oid = NInfoGetDbOid(rinfo), + .rel_number = NInfoGetRelNumber(rinfo), + .fork_number = forkNum, + .block_number = blockno, + .nblocks = 1, + .request_lsn = request_lsn, + .not_modified_since = not_modified_since, + } + }; + + /* + * This is for tests only and doesn't need to be particularly fast. Always + * use the bounce buffer for simplicity + */ + request.get_page_v_uncached.dest[0].ptr = bounce_buf_used = bounce_buf(); + + /* + * don't use the specialized bcomm_start_get_page_v_request() function + * here, because we want to bypass the LFC + */ + perform_request(&request, &result); + switch (result.tag) + { + case NeonIOResult_GetPageV: + memcpy(buffer, bounce_buf_used, BLCKSZ); + return; + case NeonIOResult_Error: + errno = result.error; + ereport(ERROR, + (errcode_for_file_access(), + errmsg("could not read (uncached) block %u in rel %u/%u/%u.%u: %m", + blockno, RelFileInfoFmt(rinfo), forkNum))); + break; + default: + elog(ERROR, "unexpected result for GetPageV operation: %d", result.tag); break; } } @@ -825,10 +895,11 @@ communicator_new_rel_nblocks(NRelFileInfo rinfo, ForkNumber forkNum) case NeonIOResult_RelSize: return result.rel_size; case NeonIOResult_Error: + errno = result.error; ereport(ERROR, (errcode_for_file_access(), - errmsg("could not read size of rel %u/%u/%u.%u: %s", - RelFileInfoFmt(rinfo), forkNum, pg_strerror(result.error)))); + errmsg("could not read size of rel %u/%u/%u.%u: %m", + RelFileInfoFmt(rinfo), forkNum))); break; default: elog(ERROR, "unexpected result for RelSize operation: %d", result.tag); @@ -857,10 +928,11 @@ communicator_new_dbsize(Oid dbNode) case NeonIOResult_DbSize: return (int64) result.db_size; case NeonIOResult_Error: + errno = result.error; ereport(ERROR, (errcode_for_file_access(), - errmsg("could not read database size of database %u: %s", - dbNode, pg_strerror(result.error)))); + errmsg("could not read database size of database %u: %m", + dbNode))); break; default: elog(ERROR, "unexpected result for DbSize operation: %d", result.tag); @@ -870,10 +942,10 @@ communicator_new_dbsize(Oid dbNode) int communicator_new_read_slru_segment( - SlruKind kind, - uint32_t segno, - neon_request_lsns *request_lsns, - const char* path) + SlruKind kind, + uint32_t segno, + neon_request_lsns * request_lsns, + const char *path) { NeonIOResult result = {}; NeonIORequest request = { @@ -885,10 +957,11 @@ communicator_new_read_slru_segment( .request_lsn = request_lsns->request_lsn, } }; - int nblocks = -1; - char *temp_path = bounce_buf(); + int nblocks = -1; + char *temp_path = bounce_buf(); - if (path == NULL) { + if (path == NULL) + { elog(ERROR, "read_slru_segment called with NULL path"); return -1; } @@ -897,7 +970,7 @@ communicator_new_read_slru_segment( request.read_slru_segment.destination_file_path.ptr = (uint8_t *) temp_path; elog(DEBUG5, "readslrusegment called for kind=%u, segno=%u, file_path=\"%s\"", - kind, segno, request.read_slru_segment.destination_file_path.ptr); + kind, segno, request.read_slru_segment.destination_file_path.ptr); /* FIXME: see `request_lsns` in main_loop.rs for why this is needed */ 
XLogSetAsyncXactLSN(request_lsns->request_lsn); @@ -910,10 +983,11 @@ communicator_new_read_slru_segment( nblocks = result.read_slru_segment; break; case NeonIOResult_Error: + errno = result.error; ereport(ERROR, (errcode_for_file_access(), - errmsg("could not read slru segment, kind=%u, segno=%u: %s", - kind, segno, pg_strerror(result.error)))); + errmsg("could not read slru segment, kind=%u, segno=%u: %m", + kind, segno))); break; default: elog(ERROR, "unexpected result for read SLRU operation: %d", result.tag); @@ -953,10 +1027,11 @@ communicator_new_write_page(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber case NeonIOResult_WriteOK: return; case NeonIOResult_Error: + errno = result.error; ereport(ERROR, (errcode_for_file_access(), - errmsg("could not write block %u in rel %u/%u/%u.%u: %s", - blockno, RelFileInfoFmt(rinfo), forkNum, pg_strerror(result.error)))); + errmsg("could not write block %u in rel %u/%u/%u.%u: %m", + blockno, RelFileInfoFmt(rinfo), forkNum))); break; default: elog(ERROR, "unexpected result for WritePage operation: %d", result.tag); @@ -993,10 +1068,11 @@ communicator_new_rel_extend(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber case NeonIOResult_WriteOK: return; case NeonIOResult_Error: + errno = result.error; ereport(ERROR, (errcode_for_file_access(), - errmsg("could not extend to block %u in rel %u/%u/%u.%u: %s", - blockno, RelFileInfoFmt(rinfo), forkNum, pg_strerror(result.error)))); + errmsg("could not extend to block %u in rel %u/%u/%u.%u: %m", + blockno, RelFileInfoFmt(rinfo), forkNum))); break; default: elog(ERROR, "unexpected result for Extend operation: %d", result.tag); @@ -1032,10 +1108,11 @@ communicator_new_rel_zeroextend(NRelFileInfo rinfo, ForkNumber forkNum, BlockNum case NeonIOResult_WriteOK: return; case NeonIOResult_Error: + errno = result.error; ereport(ERROR, (errcode_for_file_access(), - errmsg("could not zeroextend to block %u in rel %u/%u/%u.%u: %s", - blockno, RelFileInfoFmt(rinfo), forkNum, pg_strerror(result.error)))); + errmsg("could not zeroextend to block %u in rel %u/%u/%u.%u: %m", + blockno, RelFileInfoFmt(rinfo), forkNum))); break; default: elog(ERROR, "unexpected result for ZeroExtend operation: %d", result.tag); @@ -1068,10 +1145,11 @@ communicator_new_rel_create(NRelFileInfo rinfo, ForkNumber forkNum, XLogRecPtr l case NeonIOResult_WriteOK: return; case NeonIOResult_Error: + errno = result.error; ereport(ERROR, (errcode_for_file_access(), - errmsg("could not create rel %u/%u/%u.%u: %s", - RelFileInfoFmt(rinfo), forkNum, pg_strerror(result.error)))); + errmsg("could not create rel %u/%u/%u.%u: %m", + RelFileInfoFmt(rinfo), forkNum))); break; default: elog(ERROR, "unexpected result for Create operation: %d", result.tag); @@ -1105,10 +1183,11 @@ communicator_new_rel_truncate(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumbe case NeonIOResult_WriteOK: return; case NeonIOResult_Error: + errno = result.error; ereport(ERROR, (errcode_for_file_access(), - errmsg("could not truncate rel %u/%u/%u.%u to %u blocks: %s", - RelFileInfoFmt(rinfo), forkNum, nblocks, pg_strerror(result.error)))); + errmsg("could not truncate rel %u/%u/%u.%u to %u blocks: %m", + RelFileInfoFmt(rinfo), forkNum, nblocks))); break; default: elog(ERROR, "unexpected result for Truncate operation: %d", result.tag); @@ -1141,10 +1220,11 @@ communicator_new_rel_unlink(NRelFileInfo rinfo, ForkNumber forkNum, XLogRecPtr l case NeonIOResult_WriteOK: return; case NeonIOResult_Error: + errno = result.error; ereport(ERROR, (errcode_for_file_access(), - errmsg("could 
not unlink rel %u/%u/%u.%u: %s", - RelFileInfoFmt(rinfo), forkNum, pg_strerror(result.error)))); + errmsg("could not unlink rel %u/%u/%u.%u: %m", + RelFileInfoFmt(rinfo), forkNum))); break; default: elog(ERROR, "unexpected result for Unlink operation: %d", result.tag); @@ -1175,10 +1255,11 @@ communicator_new_update_cached_rel_size(NRelFileInfo rinfo, ForkNumber forkNum, case NeonIOResult_WriteOK: return; case NeonIOResult_Error: + errno = result.error; ereport(ERROR, (errcode_for_file_access(), - errmsg("could not update cached size for rel %u/%u/%u.%u: %s", - RelFileInfoFmt(rinfo), forkNum, pg_strerror(result.error)))); + errmsg("could not update cached size for rel %u/%u/%u.%u: %m", + RelFileInfoFmt(rinfo), forkNum))); break; default: elog(ERROR, "unexpected result for UpdateCachedRelSize operation: %d", result.tag); @@ -1213,8 +1294,18 @@ print_neon_io_request(NeonIORequest *request) CGetPageVRequest *r = &request->get_page_v; snprintf(buf, sizeof(buf), "GetPageV: req " UINT64_FORMAT " rel %u/%u/%u.%u blks %d-%d", - r->request_id, - r->spc_oid, r->db_oid, r->rel_number, r->fork_number, r->block_number, r->block_number + r->nblocks); + r->request_id, + r->spc_oid, r->db_oid, r->rel_number, r->fork_number, r->block_number, r->block_number + r->nblocks); + return buf; + } + case NeonIORequest_GetPageVUncached: + { + CGetPageVUncachedRequest *r = &request->get_page_v_uncached; + + snprintf(buf, sizeof(buf), "GetPageVUncached: req " UINT64_FORMAT " rel %u/%u/%u.%u blk %d request_lsn %X/%X not_modified_since %X/%X", + r->request_id, + r->spc_oid, r->db_oid, r->rel_number, r->fork_number, r->block_number, + LSN_FORMAT_ARGS(r->request_lsn), LSN_FORMAT_ARGS(r->not_modified_since)); return buf; } case NeonIORequest_ReadSlruSegment: @@ -1222,11 +1313,11 @@ print_neon_io_request(NeonIORequest *request) CReadSlruSegmentRequest *r = &request->read_slru_segment; snprintf(buf, sizeof(buf), "ReadSlruSegment: req " UINT64_FORMAT " slrukind=%u, segno=%u, lsn=%X/%X, file_path=\"%s\"", - r->request_id, - r->slru_kind, - r->segment_number, - LSN_FORMAT_ARGS(r->request_lsn), - r->destination_file_path.ptr); + r->request_id, + r->slru_kind, + r->segment_number, + LSN_FORMAT_ARGS(r->request_lsn), + r->destination_file_path.ptr); return buf; } case NeonIORequest_PrefetchV: @@ -1234,8 +1325,8 @@ print_neon_io_request(NeonIORequest *request) CPrefetchVRequest *r = &request->prefetch_v; snprintf(buf, sizeof(buf), "PrefetchV: req " UINT64_FORMAT " rel %u/%u/%u.%u blks %d-%d", - r->request_id, - r->spc_oid, r->db_oid, r->rel_number, r->fork_number, r->block_number, r->block_number + r->nblocks); + r->request_id, + r->spc_oid, r->db_oid, r->rel_number, r->fork_number, r->block_number, r->block_number + r->nblocks); return buf; } case NeonIORequest_DbSize: @@ -1243,7 +1334,7 @@ print_neon_io_request(NeonIORequest *request) CDbSizeRequest *r = &request->db_size; snprintf(buf, sizeof(buf), "PrefetchV: req " UINT64_FORMAT " db %u", - r->request_id, r->db_oid); + r->request_id, r->db_oid); return buf; } case NeonIORequest_WritePage: @@ -1251,9 +1342,9 @@ print_neon_io_request(NeonIORequest *request) CWritePageRequest *r = &request->write_page; snprintf(buf, sizeof(buf), "WritePage: req " UINT64_FORMAT " rel %u/%u/%u.%u blk %u lsn %X/%X", - r->request_id, - r->spc_oid, r->db_oid, r->rel_number, r->fork_number, r->block_number, - LSN_FORMAT_ARGS(r->lsn)); + r->request_id, + r->spc_oid, r->db_oid, r->rel_number, r->fork_number, r->block_number, + LSN_FORMAT_ARGS(r->lsn)); return buf; } case NeonIORequest_RelExtend: @@ 
-1261,9 +1352,9 @@ print_neon_io_request(NeonIORequest *request) CRelExtendRequest *r = &request->rel_extend; snprintf(buf, sizeof(buf), "RelExtend: req " UINT64_FORMAT " rel %u/%u/%u.%u blk %u lsn %X/%X", - r->request_id, - r->spc_oid, r->db_oid, r->rel_number, r->fork_number, r->block_number, - LSN_FORMAT_ARGS(r->lsn)); + r->request_id, + r->spc_oid, r->db_oid, r->rel_number, r->fork_number, r->block_number, + LSN_FORMAT_ARGS(r->lsn)); return buf; } case NeonIORequest_RelZeroExtend: @@ -1271,9 +1362,9 @@ print_neon_io_request(NeonIORequest *request) CRelZeroExtendRequest *r = &request->rel_zero_extend; snprintf(buf, sizeof(buf), "RelZeroExtend: req " UINT64_FORMAT " rel %u/%u/%u.%u blks %u-%u lsn %X/%X", - r->request_id, - r->spc_oid, r->db_oid, r->rel_number, r->fork_number, r->block_number, r->block_number + r->nblocks, - LSN_FORMAT_ARGS(r->lsn)); + r->request_id, + r->spc_oid, r->db_oid, r->rel_number, r->fork_number, r->block_number, r->block_number + r->nblocks, + LSN_FORMAT_ARGS(r->lsn)); return buf; } case NeonIORequest_RelCreate: @@ -1281,8 +1372,8 @@ print_neon_io_request(NeonIORequest *request) CRelCreateRequest *r = &request->rel_create; snprintf(buf, sizeof(buf), "RelCreate: req " UINT64_FORMAT " rel %u/%u/%u.%u", - r->request_id, - r->spc_oid, r->db_oid, r->rel_number, r->fork_number); + r->request_id, + r->spc_oid, r->db_oid, r->rel_number, r->fork_number); return buf; } case NeonIORequest_RelTruncate: @@ -1290,8 +1381,8 @@ print_neon_io_request(NeonIORequest *request) CRelTruncateRequest *r = &request->rel_truncate; snprintf(buf, sizeof(buf), "RelTruncate: req " UINT64_FORMAT " rel %u/%u/%u.%u blks %u", - r->request_id, - r->spc_oid, r->db_oid, r->rel_number, r->fork_number, r->nblocks); + r->request_id, + r->spc_oid, r->db_oid, r->rel_number, r->fork_number, r->nblocks); return buf; } case NeonIORequest_RelUnlink: @@ -1299,8 +1390,8 @@ print_neon_io_request(NeonIORequest *request) CRelUnlinkRequest *r = &request->rel_unlink; snprintf(buf, sizeof(buf), "RelUnlink: req " UINT64_FORMAT " rel %u/%u/%u.%u", - r->request_id, - r->spc_oid, r->db_oid, r->rel_number, r->fork_number); + r->request_id, + r->spc_oid, r->db_oid, r->rel_number, r->fork_number); return buf; } case NeonIORequest_UpdateCachedRelSize: @@ -1308,9 +1399,9 @@ print_neon_io_request(NeonIORequest *request) CUpdateCachedRelSizeRequest *r = &request->update_cached_rel_size; snprintf(buf, sizeof(buf), "UpdateCachedRelSize: req " UINT64_FORMAT " rel %u/%u/%u.%u blocks: %u", - r->request_id, - r->spc_oid, r->db_oid, r->rel_number, r->fork_number, - r->nblocks); + r->request_id, + r->spc_oid, r->db_oid, r->rel_number, r->fork_number, + r->nblocks); return buf; } } @@ -1365,45 +1456,68 @@ communicator_new_approximate_working_set_size_seconds(time_t duration, bool rese return dc; } - /* * Return an array of LfcStatsEntrys */ LfcStatsEntry * -communicator_new_get_lfc_stats(uint32 *num_entries) +communicator_new_lfc_get_stats(size_t *num_entries) { LfcStatsEntry *entries; - int n = 0; - uint64 cache_misses = 0; + size_t n = 0; uint64 cache_hits = 0; + uint64 cache_misses = 0; for (int i = 0; i < MaxProcs; i++) { - cache_misses += communicator_shmem_ptr->backends[i].cache_misses; cache_hits += communicator_shmem_ptr->backends[i].cache_hits; + cache_misses += communicator_shmem_ptr->backends[i].cache_misses; } #define NUM_ENTRIES 10 entries = palloc(sizeof(LfcStatsEntry) * NUM_ENTRIES); - entries[n++] = (LfcStatsEntry) {"file_cache_misses", false, cache_misses}; - entries[n++] = (LfcStatsEntry) {"file_cache_hits", false, 
cache_hits }; + entries[n++] = (LfcStatsEntry) + { + "file_cache_hits", false, cache_hits + }; + entries[n++] = (LfcStatsEntry) + { + "file_cache_misses", false, cache_misses + }; - entries[n++] = (LfcStatsEntry) {"file_cache_used_pages", false, - bcomm_cache_get_num_pages_used(my_bs) }; + entries[n++] = (LfcStatsEntry) + { + "file_cache_used_pages", false, + bcomm_cache_get_num_pages_used(my_bs) + }; /* TODO: these stats are exposed by the legacy LFC implementation */ #if 0 - entries[n++] = (LfcStatsEntry) {"file_cache_used", lfc_ctl == NULL, - lfc_ctl ? lfc_ctl->used : 0 }; - entries[n++] = (LfcStatsEntry) {"file_cache_writes", lfc_ctl == NULL, - lfc_ctl ? lfc_ctl->writes : 0 }; - entries[n++] = (LfcStatsEntry) {"file_cache_size", lfc_ctl == NULL, - lfc_ctl ? lfc_ctl->size : 0 }; - entries[n++] = (LfcStatsEntry) {"file_cache_evicted_pages", lfc_ctl == NULL, - lfc_ctl ? lfc_ctl->evicted_pages : 0 }; - entries[n++] = (LfcStatsEntry) {"file_cache_limit", lfc_ctl == NULL, - lfc_ctl ? lfc_ctl->limit : 0 }; + entries[n++] = (LfcStatsEntry) + { + "file_cache_used", lfc_ctl == NULL, + lfc_ctl ? lfc_ctl->used : 0 + }; + entries[n++] = (LfcStatsEntry) + { + "file_cache_writes", lfc_ctl == NULL, + lfc_ctl ? lfc_ctl->writes : 0 + }; + entries[n++] = (LfcStatsEntry) + { + "file_cache_size", lfc_ctl == NULL, + lfc_ctl ? lfc_ctl->size : 0 + }; + entries[n++] = (LfcStatsEntry) + { + "file_cache_evicted_pages", lfc_ctl == NULL, + lfc_ctl ? lfc_ctl->evicted_pages : 0 + }; + entries[n++] = (LfcStatsEntry) + { + "file_cache_limit", lfc_ctl == NULL, + lfc_ctl ? lfc_ctl->limit : 0 + }; #endif Assert(n <= NUM_ENTRIES); @@ -1411,3 +1525,40 @@ communicator_new_get_lfc_stats(uint32 *num_entries) *num_entries = n; return entries; } + +/* + * Get metrics, for the built-in metrics exporter that's part of the + * communicator process. + * + * NB: This is called from a Rust tokio task inside the communicator process. + * Acquiring lwlocks, elog(), allocating memory or anything else non-trivial + * is strictly prohibited here! 
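+ *
+ * Note: the new implementation does not track lfc_used and lfc_writes yet;
+ * they are currently reported as 0 (see the TODOs below).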
+ */ +struct LfcMetrics +communicator_new_get_lfc_metrics_unsafe(void) +{ + uint64 cache_hits = 0; + uint64 cache_misses = 0; + + struct LfcMetrics result = { + .lfc_cache_size_limit = (int64) lfc_size_limit * 1024 * 1024, + .lfc_used = 0, /* TODO */ + .lfc_writes = 0, /* TODO */ + }; + + for (int i = 0; i < MaxProcs; i++) + { + cache_hits += communicator_shmem_ptr->backends[i].cache_hits; + cache_misses += communicator_shmem_ptr->backends[i].cache_misses; + } + result.lfc_hits = cache_hits; + result.lfc_misses = cache_misses; + + for (int minutes = 1; minutes <= 60; minutes++) + { + result.lfc_approximate_working_set_size_windows[minutes - 1] = + communicator_new_approximate_working_set_size_seconds(minutes * 60, false); + } + + return result; +} diff --git a/pgxn/neon/communicator_new.h b/pgxn/neon/communicator_new.h index d3d4da20d5..d68c02db2e 100644 --- a/pgxn/neon/communicator_new.h +++ b/pgxn/neon/communicator_new.h @@ -30,19 +30,21 @@ extern void communicator_new_init(void); extern bool communicator_new_rel_exists(NRelFileInfo rinfo, ForkNumber forkNum); extern BlockNumber communicator_new_rel_nblocks(NRelFileInfo rinfo, ForkNumber forknum); extern int64 communicator_new_dbsize(Oid dbNode); -extern void communicator_new_read_at_lsnv(NRelFileInfo rinfo, ForkNumber forkNum, - BlockNumber base_blockno, - void **buffers, BlockNumber nblocks); +extern void communicator_new_readv(NRelFileInfo rinfo, ForkNumber forkNum, + BlockNumber base_blockno, + void **buffers, BlockNumber nblocks); +extern void communicator_new_read_at_lsn_uncached(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blockno, + void *buffer, XLogRecPtr request_lsn, XLogRecPtr not_modified_since); extern void communicator_new_prefetch_register_bufferv(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blockno, BlockNumber nblocks); -extern bool communicator_new_cache_contains(NRelFileInfo rinfo, ForkNumber forkNum, - BlockNumber blockno); -extern int communicator_new_read_slru_segment( - SlruKind kind, - uint32_t segno, - neon_request_lsns *request_lsns, - const char *path +extern bool communicator_new_update_lwlsn_for_block_if_not_cached(NRelFileInfo rinfo, ForkNumber forkNum, + BlockNumber blockno, XLogRecPtr lsn); +extern int communicator_new_read_slru_segment( + SlruKind kind, + uint32_t segno, + neon_request_lsns * request_lsns, + const char *path ); /* Write requests, to keep the caches up-to-date */ @@ -60,7 +62,8 @@ extern void communicator_new_update_cached_rel_size(NRelFileInfo rinfo, ForkNumb /* other functions */ extern int32 communicator_new_approximate_working_set_size_seconds(time_t duration, bool reset); +extern struct LfcMetrics communicator_new_get_lfc_metrics_unsafe(void); extern FileCacheState *communicator_new_get_lfc_state(size_t max_entries); -extern LfcStatsEntry *communicator_new_get_lfc_stats(uint32 *num_entries); +extern struct LfcStatsEntry *communicator_new_lfc_get_stats(size_t *num_entries); #endif /* COMMUNICATOR_NEW_H */ diff --git a/pgxn/neon/communicator_process.c b/pgxn/neon/communicator_process.c index af1b4e1497..dd0acaef13 100644 --- a/pgxn/neon/communicator_process.c +++ b/pgxn/neon/communicator_process.c @@ -32,6 +32,7 @@ #include "tcop/tcopprot.h" #include "utils/timestamp.h" +#include "communicator_new.h" #include "communicator_process.h" #include "file_cache.h" #include "neon.h" @@ -133,6 +134,7 @@ communicator_new_bgworker_main(Datum main_arg) connstrings = palloc(shard_map.num_shards * sizeof(char *)); for (int i = 0; i < shard_map.num_shards; i++) connstrings[i] = 
shard_map.connstring[i]; + AssignNumShards(shard_map.num_shards); proc_handle = communicator_worker_process_launch( cis, neon_tenant, @@ -231,6 +233,7 @@ communicator_new_bgworker_main(Datum main_arg) for (int i = 0; i < shard_map.num_shards; i++) connstrings[i] = shard_map.connstring[i]; + AssignNumShards(shard_map.num_shards); communicator_worker_config_reload(proc_handle, file_cache_size, connstrings, @@ -369,3 +372,16 @@ callback_get_request_lsn_unsafe(void) return flushlsn; } } + +/* + * Get metrics, for the built-in metrics exporter that's part of the + * communicator process. + */ +struct LfcMetrics +callback_get_lfc_metrics_unsafe(void) +{ + if (neon_use_communicator_worker) + return communicator_new_get_lfc_metrics_unsafe(); + else + return lfc_get_metrics_unsafe(); +} diff --git a/pgxn/neon/extension_server.c b/pgxn/neon/extension_server.c index 00dcb6920e..d64cd3e4af 100644 --- a/pgxn/neon/extension_server.c +++ b/pgxn/neon/extension_server.c @@ -14,7 +14,7 @@ #include "extension_server.h" #include "neon_utils.h" -static int extension_server_port = 0; +int hadron_extension_server_port = 0; static int extension_server_request_timeout = 60; static int extension_server_connect_timeout = 60; @@ -47,7 +47,7 @@ neon_download_extension_file_http(const char *filename, bool is_library) curl_easy_setopt(handle, CURLOPT_CONNECTTIMEOUT, (long)extension_server_connect_timeout /* seconds */ ); compute_ctl_url = psprintf("http://localhost:%d/extension_server/%s%s", - extension_server_port, filename, is_library ? "?is_library=true" : ""); + hadron_extension_server_port, filename, is_library ? "?is_library=true" : ""); elog(LOG, "Sending request to compute_ctl: %s", compute_ctl_url); @@ -82,7 +82,7 @@ pg_init_extension_server() DefineCustomIntVariable("neon.extension_server_port", "connection string to the compute_ctl", NULL, - &extension_server_port, + &hadron_extension_server_port, 0, 0, INT_MAX, PGC_POSTMASTER, 0, /* no flags required */ diff --git a/pgxn/neon/file_cache.c b/pgxn/neon/file_cache.c index 2096a39d5e..e447ba6d32 100644 --- a/pgxn/neon/file_cache.c +++ b/pgxn/neon/file_cache.c @@ -49,6 +49,7 @@ #include "neon.h" #include "neon_lwlsncache.h" #include "neon_perf_counters.h" +#include "neon_utils.h" #include "pagestore_client.h" #include "communicator.h" @@ -624,8 +625,19 @@ lfc_get_state(size_t max_entries) { if (GET_STATE(entry, j) != UNAVAILABLE) { - BITMAP_SET(bitmap, i*lfc_blocks_per_chunk + j); - n_pages += 1; + /* Validate the buffer tag before including it */ + BufferTag test_tag = entry->key; + test_tag.blockNum += j; + + if (BufferTagIsValid(&test_tag)) + { + BITMAP_SET(bitmap, i*lfc_blocks_per_chunk + j); + n_pages += 1; + } + else + { + elog(ERROR, "LFC: Skipping invalid buffer tag during cache state capture: blockNum=%u", test_tag.blockNum); + } } } if (++i == n_entries) @@ -634,7 +646,7 @@ lfc_get_state(size_t max_entries) Assert(i == n_entries); fcs->n_pages = n_pages; Assert(pg_popcount((char*)bitmap, ((n_entries << lfc_chunk_size_log) + 7)/8) == n_pages); - elog(LOG, "LFC: save state of %d chunks %d pages", (int)n_entries, (int)n_pages); + elog(LOG, "LFC: save state of %d chunks %d pages (validated)", (int)n_entries, (int)n_pages); } LWLockRelease(lfc_lock); @@ -1535,16 +1547,19 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno, } /* - * Return an array of LfcStatsEntrys + * Return metrics about the LFC. + * + * The return format is a palloc'd array of LfcStatsEntrys. The size + * of the returned array is returned in *num_entries. 
*/ LfcStatsEntry * -get_lfc_stats(uint32 *num_entries) +lfc_get_stats(size_t *num_entries) { LfcStatsEntry *entries; - int n = 0; + size_t n = 0; -#define NUM_ENTRIES 10 - entries = palloc(sizeof(LfcStatsEntry) * NUM_ENTRIES); +#define MAX_ENTRIES 10 + entries = palloc(sizeof(LfcStatsEntry) * MAX_ENTRIES); entries[n++] = (LfcStatsEntry) {"file_cache_chunk_size_pages", lfc_ctl == NULL, lfc_ctl ? lfc_blocks_per_chunk : 0 }; @@ -1566,7 +1581,8 @@ get_lfc_stats(uint32 *num_entries) lfc_ctl ? lfc_ctl->limit : 0 }; entries[n++] = (LfcStatsEntry) {"file_cache_chunks_pinned", lfc_ctl == NULL, lfc_ctl ? lfc_ctl->pinned : 0 }; - Assert(n <= NUM_ENTRIES); + Assert(n <= MAX_ENTRIES); +#undef MAX_ENTRIES *num_entries = n; return entries; @@ -1576,193 +1592,86 @@ get_lfc_stats(uint32 *num_entries) * Function returning data from the local file cache * relation node/tablespace/database/blocknum and access_counter */ -PG_FUNCTION_INFO_V1(local_cache_pages); - -/* - * Record structure holding the to be exposed cache data. - */ -typedef struct +LocalCachePagesRec * +lfc_local_cache_pages(size_t *num_entries) { - uint32 pageoffs; - Oid relfilenode; - Oid reltablespace; - Oid reldatabase; - ForkNumber forknum; - BlockNumber blocknum; - uint16 accesscount; -} LocalCachePagesRec; + HASH_SEQ_STATUS status; + FileCacheEntry *entry; + size_t n_pages; + size_t n; + LocalCachePagesRec *result; -/* - * Function context for data persisting over repeated calls. - */ -typedef struct -{ - TupleDesc tupdesc; - LocalCachePagesRec *record; -} LocalCachePagesContext; - - -#define NUM_LOCALCACHE_PAGES_ELEM 7 - -Datum -local_cache_pages(PG_FUNCTION_ARGS) -{ - FuncCallContext *funcctx; - Datum result; - MemoryContext oldcontext; - LocalCachePagesContext *fctx; /* User function context. */ - TupleDesc tupledesc; - TupleDesc expected_tupledesc; - HeapTuple tuple; - - if (SRF_IS_FIRSTCALL()) + if (!lfc_ctl) { - HASH_SEQ_STATUS status; - FileCacheEntry *entry; - uint32 n_pages = 0; + *num_entries = 0; + return NULL; + } - funcctx = SRF_FIRSTCALL_INIT(); + LWLockAcquire(lfc_lock, LW_SHARED); + if (!LFC_ENABLED()) + { + LWLockRelease(lfc_lock); + *num_entries = 0; + return NULL; + } - /* Switch context when allocating stuff to be used in later calls */ - oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx); - - /* Create a user function context for cross-call persistence */ - fctx = (LocalCachePagesContext *) palloc(sizeof(LocalCachePagesContext)); - - /* - * To smoothly support upgrades from version 1.0 of this extension - * transparently handle the (non-)existence of the pinning_backends - * column. We unfortunately have to get the result type for that... - - * we can't use the result type determined by the function definition - * without potentially crashing when somebody uses the old (or even - * wrong) function definition though. - */ - if (get_call_result_type(fcinfo, NULL, &expected_tupledesc) != TYPEFUNC_COMPOSITE) - neon_log(ERROR, "return type must be a row type"); - - if (expected_tupledesc->natts != NUM_LOCALCACHE_PAGES_ELEM) - neon_log(ERROR, "incorrect number of output arguments"); - - /* Construct a tuple descriptor for the result rows. 
*/ - tupledesc = CreateTemplateTupleDesc(expected_tupledesc->natts); - TupleDescInitEntry(tupledesc, (AttrNumber) 1, "pageoffs", - INT8OID, -1, 0); -#if PG_MAJORVERSION_NUM < 16 - TupleDescInitEntry(tupledesc, (AttrNumber) 2, "relfilenode", - OIDOID, -1, 0); -#else - TupleDescInitEntry(tupledesc, (AttrNumber) 2, "relfilenumber", - OIDOID, -1, 0); -#endif - TupleDescInitEntry(tupledesc, (AttrNumber) 3, "reltablespace", - OIDOID, -1, 0); - TupleDescInitEntry(tupledesc, (AttrNumber) 4, "reldatabase", - OIDOID, -1, 0); - TupleDescInitEntry(tupledesc, (AttrNumber) 5, "relforknumber", - INT2OID, -1, 0); - TupleDescInitEntry(tupledesc, (AttrNumber) 6, "relblocknumber", - INT8OID, -1, 0); - TupleDescInitEntry(tupledesc, (AttrNumber) 7, "accesscount", - INT4OID, -1, 0); - - fctx->tupdesc = BlessTupleDesc(tupledesc); - - if (lfc_ctl) + /* Count the pages first */ + n_pages = 0; + hash_seq_init(&status, lfc_hash); + while ((entry = hash_seq_search(&status)) != NULL) + { + /* Skip hole tags */ + if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0) { - LWLockAcquire(lfc_lock, LW_SHARED); + for (int i = 0; i < lfc_blocks_per_chunk; i++) + n_pages += GET_STATE(entry, i) == AVAILABLE; + } + } - if (LFC_ENABLED()) + if (n_pages == 0) + { + LWLockRelease(lfc_lock); + *num_entries = 0; + return NULL; + } + + result = (LocalCachePagesRec *) + MemoryContextAllocHuge(CurrentMemoryContext, + sizeof(LocalCachePagesRec) * n_pages); + + /* + * Scan through all the cache entries, saving the relevant fields + * in the result structure. + */ + n = 0; + hash_seq_init(&status, lfc_hash); + while ((entry = hash_seq_search(&status)) != NULL) + { + for (int i = 0; i < lfc_blocks_per_chunk; i++) + { + if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0) { - hash_seq_init(&status, lfc_hash); - while ((entry = hash_seq_search(&status)) != NULL) + if (GET_STATE(entry, i) == AVAILABLE) { - /* Skip hole tags */ - if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0) - { - for (int i = 0; i < lfc_blocks_per_chunk; i++) - n_pages += GET_STATE(entry, i) == AVAILABLE; - } + result[n].pageoffs = entry->offset * lfc_blocks_per_chunk + i; + result[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)); + result[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(entry->key)); + result[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(entry->key)); + result[n].forknum = entry->key.forkNum; + result[n].blocknum = entry->key.blockNum + i; + result[n].accesscount = entry->access_count; + n += 1; } } } - fctx->record = (LocalCachePagesRec *) - MemoryContextAllocHuge(CurrentMemoryContext, - sizeof(LocalCachePagesRec) * n_pages); - - /* Set max calls and remember the user function context. */ - funcctx->max_calls = n_pages; - funcctx->user_fctx = fctx; - - /* Return to original context when allocating transient memory */ - MemoryContextSwitchTo(oldcontext); - - if (n_pages != 0) - { - /* - * Scan through all the cache entries, saving the relevant fields - * in the fctx->record structure. 
- */ - uint32 n = 0; - - hash_seq_init(&status, lfc_hash); - while ((entry = hash_seq_search(&status)) != NULL) - { - for (int i = 0; i < lfc_blocks_per_chunk; i++) - { - if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0) - { - if (GET_STATE(entry, i) == AVAILABLE) - { - fctx->record[n].pageoffs = entry->offset * lfc_blocks_per_chunk + i; - fctx->record[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)); - fctx->record[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(entry->key)); - fctx->record[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(entry->key)); - fctx->record[n].forknum = entry->key.forkNum; - fctx->record[n].blocknum = entry->key.blockNum + i; - fctx->record[n].accesscount = entry->access_count; - n += 1; - } - } - } - } - Assert(n_pages == n); - } - if (lfc_ctl) - LWLockRelease(lfc_lock); } + Assert(n_pages == n); + LWLockRelease(lfc_lock); - funcctx = SRF_PERCALL_SETUP(); - - /* Get the saved state */ - fctx = funcctx->user_fctx; - - if (funcctx->call_cntr < funcctx->max_calls) - { - uint32 i = funcctx->call_cntr; - Datum values[NUM_LOCALCACHE_PAGES_ELEM]; - bool nulls[NUM_LOCALCACHE_PAGES_ELEM] = { - false, false, false, false, false, false, false - }; - - values[0] = Int64GetDatum((int64) fctx->record[i].pageoffs); - values[1] = ObjectIdGetDatum(fctx->record[i].relfilenode); - values[2] = ObjectIdGetDatum(fctx->record[i].reltablespace); - values[3] = ObjectIdGetDatum(fctx->record[i].reldatabase); - values[4] = ObjectIdGetDatum(fctx->record[i].forknum); - values[5] = Int64GetDatum((int64) fctx->record[i].blocknum); - values[6] = Int32GetDatum(fctx->record[i].accesscount); - - /* Build and return the tuple. */ - tuple = heap_form_tuple(fctx->tupdesc, values, nulls); - result = HeapTupleGetDatum(tuple); - - SRF_RETURN_NEXT(funcctx, result); - } - else - SRF_RETURN_DONE(funcctx); + *num_entries = n_pages; + return result; } - /* * Internal implementation of the approximate_working_set_size_seconds() * function. @@ -1782,15 +1691,15 @@ lfc_approximate_working_set_size_seconds(time_t duration, bool reset) } /* - * Get metrics, for the built-in metrics exporter that's part of the communicator - * process. + * Get metrics, for the built-in metrics exporter that's part of the + * communicator process. * * NB: This is called from a Rust tokio task inside the communicator process. * Acquiring lwlocks, elog(), allocating memory or anything else non-trivial * is strictly prohibited here! 
*/ struct LfcMetrics -callback_get_lfc_metrics_unsafe(void) +lfc_get_metrics_unsafe(void) { struct LfcMetrics result = { .lfc_cache_size_limit = (int64) lfc_size_limit * 1024 * 1024, diff --git a/pgxn/neon/file_cache.h b/pgxn/neon/file_cache.h index d6ada4c6fb..d46cc92e4b 100644 --- a/pgxn/neon/file_cache.h +++ b/pgxn/neon/file_cache.h @@ -44,8 +44,24 @@ extern bool lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blk const void* buffer, XLogRecPtr lsn); extern FileCacheState* lfc_get_state(size_t max_entries); +extern LfcStatsEntry *lfc_get_stats(size_t *num_entries); + +struct LfcMetrics; /* defined in communicator_bindings.h */ +extern struct LfcMetrics lfc_get_metrics_unsafe(void); + +typedef struct +{ + uint32 pageoffs; + Oid relfilenode; + Oid reltablespace; + Oid reldatabase; + ForkNumber forknum; + BlockNumber blocknum; + uint16 accesscount; +} LocalCachePagesRec; +extern LocalCachePagesRec *lfc_local_cache_pages(size_t *num_entries); + extern int32 lfc_approximate_working_set_size_seconds(time_t duration, bool reset); -extern LfcStatsEntry *get_lfc_stats(uint32 *num_entries); static inline bool lfc_read(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno, diff --git a/pgxn/neon/lfc_prewarm.c b/pgxn/neon/lfc_prewarm.c index 1608fde628..5c23d52f3c 100644 --- a/pgxn/neon/lfc_prewarm.c +++ b/pgxn/neon/lfc_prewarm.c @@ -17,6 +17,7 @@ #include "file_cache.h" #include "lfc_prewarm.h" #include "neon.h" +#include "neon_utils.h" #include "pagestore_client.h" #include "funcapi.h" @@ -350,6 +351,10 @@ lfc_prewarm_main(Datum main_arg) { tag = fcs->chunks[snd_idx >> fcs_chunk_size_log]; tag.blockNum += snd_idx & ((1 << fcs_chunk_size_log) - 1); + + if (!BufferTagIsValid(&tag)) + elog(ERROR, "LFC: Invalid buffer tag: %u", tag.blockNum); + if (!lfc_cache_contains(BufTagGetNRelFileInfo(tag), tag.forkNum, tag.blockNum)) { (void) communicator_prefetch_register_bufferv(tag, NULL, 1, NULL); @@ -478,6 +483,9 @@ lfc_prewarm_with_async_requests(FileCacheState *fcs) BlockNumber request_startblkno = InvalidBlockNumber; BlockNumber request_endblkno; + if (!BufferTagIsValid(chunk_tag)) + elog(ERROR, "LFC: Invalid buffer tag: %u", chunk_tag->blockNum); + if (lfc_prewarm_cancel) { prewarm_ctl->prewarm_canceled = true; diff --git a/pgxn/neon/libpagestore.c b/pgxn/neon/libpagestore.c index 690dfd8635..d48824df8d 100644 --- a/pgxn/neon/libpagestore.c +++ b/pgxn/neon/libpagestore.c @@ -13,6 +13,8 @@ #include #include +#include + #include "libpq-int.h" #include "access/xlog.h" @@ -87,6 +89,10 @@ static int pageserver_response_log_timeout = 10000; /* 2.5 minutes. A bit higher than highest default TCP retransmission timeout */ static int pageserver_response_disconnect_timeout = 150000; +static int conf_refresh_reconnect_attempt_threshold = 16; +// Hadron: timeout for refresh errors (1 minute) +static uint64 kRefreshErrorTimeoutUSec = 1 * USECS_PER_MINUTE; + /* * PagestoreShmemState is kept in shared memory. It contains the connection * strings for each shard. @@ -172,6 +178,8 @@ static PageServer page_servers[MAX_SHARDS]; static bool pageserver_flush(shardno_t shard_no); static void pageserver_disconnect(shardno_t shard_no); static void pageserver_disconnect_shard(shardno_t shard_no); +// HADRON +shardno_t get_num_shards(void); static void AssignShardMap(const char *newval); @@ -307,6 +315,43 @@ AssignShardMap(const char *newval) } } +/* + * Set the 'num_shards' variable in shared memory. + * + * This is only used with the new communicator. 
The new communicator doesn't + * use the shard_map in shared memory, except for the shard count, which is + * needed by get_num_shards() calls in the walproposer. This is called to set + * that. This is only called from the communicator process, at process startup + * or if the configuration is reloaded. + */ +void +AssignNumShards(shardno_t num_shards) +{ + Assert(neon_use_communicator_worker); + + pg_atomic_add_fetch_u64(&pagestore_shared->begin_update_counter, 1); + pg_write_barrier(); + pagestore_shared->shard_map.num_shards = num_shards; + pg_write_barrier(); + pg_atomic_add_fetch_u64(&pagestore_shared->end_update_counter, 1); +} + +/* BEGIN_HADRON */ +/** + * Return the total number of shards seen in the shard map. + */ +shardno_t get_num_shards(void) +{ + const ShardMap *shard_map; + + Assert(pagestore_shared); + shard_map = &pagestore_shared->shard_map; + + Assert(shard_map != NULL); + return shard_map->num_shards; +} +/* END_HADRON */ + /* * Get the current number of shards, and/or the connection string for a * particular shard from the shard map in shared memory. @@ -1033,6 +1078,101 @@ pageserver_disconnect_shard(shardno_t shard_no) shard->state = PS_Disconnected; } +// BEGIN HADRON +/* + * Nudge compute_ctl to refresh our configuration. Called when we suspect we may be + * connecting to the wrong pageservers due to a stale configuration. + * + * This is a best-effort operation. If we couldn't send the local loopback HTTP request + * to compute_ctl or if the request fails for any reason, we just log the error and move + * on. + */ + +extern int hadron_extension_server_port; + +// The timestamp (usec) of the first error that occurred while trying to refresh the configuration. +// Will be reset to 0 after a successful refresh. +static uint64 first_recorded_refresh_error_usec = 0; + +// Request compute_ctl to refresh the configuration. This operation may fail, e.g., if the compute_ctl +// is already in the configuration state. The function returns true if the caller needs to cancel the +// current query to avoid dead/live lock. 
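/*
 * A minimal sketch (not part of this patch) of the escalation rule described
 * above: the first refresh failure only records a timestamp, and the caller is
 * asked to cancel its query only once failures have persisted for longer than
 * kRefreshErrorTimeoutUSec (one minute). The helper name and plain-C types are
 * hypothetical; the real logic lives in hadron_request_configuration_refresh()
 * below.
 */
#include <stdbool.h>
#include <stdint.h>

#define SKETCH_REFRESH_ERROR_TIMEOUT_USEC (60ULL * 1000000ULL)

static bool
sketch_should_cancel_query(uint64_t now_usec, uint64_t *first_error_usec)
{
	if (*first_error_usec == 0)
	{
		/* First failure in this streak: remember when it started. */
		*first_error_usec = now_usec;
		return false;
	}
	if (now_usec - *first_error_usec > SKETCH_REFRESH_ERROR_TIMEOUT_USEC)
	{
		/* Failures have persisted for over a minute: reset and cancel the query. */
		*first_error_usec = 0;
		return true;
	}
	return false;
}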
+static bool +hadron_request_configuration_refresh() { + static CURL *handle = NULL; + CURLcode res; + char *compute_ctl_url; + bool cancel_query = false; + + if (!lakebase_mode) + return false; + + if (handle == NULL) + { + handle = alloc_curl_handle(); + + curl_easy_setopt(handle, CURLOPT_CUSTOMREQUEST, "POST"); + curl_easy_setopt(handle, CURLOPT_TIMEOUT, 3L /* seconds */ ); + curl_easy_setopt(handle, CURLOPT_POSTFIELDS, ""); + } + + // Set the URL + compute_ctl_url = psprintf("http://localhost:%d/refresh_configuration", hadron_extension_server_port); + + + elog(LOG, "Sending refresh configuration request to compute_ctl: %s", compute_ctl_url); + + curl_easy_setopt(handle, CURLOPT_URL, compute_ctl_url); + + res = curl_easy_perform(handle); + if (res != CURLE_OK ) + { + elog(WARNING, "refresh_configuration request failed: %s\n", curl_easy_strerror(res)); + } + else + { + long http_code = 0; + curl_easy_getinfo(handle, CURLINFO_RESPONSE_CODE, &http_code); + if ( res != CURLE_OK ) + { + elog(WARNING, "compute_ctl refresh_configuration request getinfo failed: %s\n", curl_easy_strerror(res)); + } + else + { + elog(LOG, "compute_ctl refresh_configuration got HTTP response: %ld\n", http_code); + if( http_code == 200 ) + { + first_recorded_refresh_error_usec = 0; + } + else + { + if (first_recorded_refresh_error_usec == 0) + { + first_recorded_refresh_error_usec = GetCurrentTimestamp(); + } + else if(GetCurrentTimestamp() - first_recorded_refresh_error_usec > kRefreshErrorTimeoutUSec) + { + { + first_recorded_refresh_error_usec = 0; + cancel_query = true; + } + } + } + } + } + + // In regular Postgres usage, it is not necessary to manually free memory allocated by palloc (psprintf) because + // it will be cleaned up after the "memory context" is reset (e.g. after the query or the transaction is finished). + // However, the number of times this function gets called during a single query/transaction can be unbounded due to + // the various retry loops around calls to pageservers. Therefore, we need to manually free this memory here. + if (compute_ctl_url != NULL) + { + pfree(compute_ctl_url); + } + return cancel_query; +} +// END HADRON + static bool pageserver_send(shardno_t shard_no, NeonRequest *request) { @@ -1067,6 +1207,11 @@ pageserver_send(shardno_t shard_no, NeonRequest *request) while (!pageserver_connect(shard_no, shard->n_reconnect_attempts < max_reconnect_attempts ? LOG : ERROR)) { shard->n_reconnect_attempts += 1; + if (shard->n_reconnect_attempts > conf_refresh_reconnect_attempt_threshold + && hadron_request_configuration_refresh() ) + { + neon_shard_log(shard_no, ERROR, "request failed too many times, cancelling query"); + } } shard->n_reconnect_attempts = 0; } else { @@ -1174,17 +1319,26 @@ pageserver_receive(shardno_t shard_no) pfree(msg); pageserver_disconnect(shard_no); resp = NULL; + + /* + * Always poke compute_ctl to request a configuration refresh if we have issues receiving data from pageservers after + * successfully connecting to it. It could be an indication that we are connecting to the wrong pageservers (e.g. PS + * is in secondary mode or otherwise refuses to respond our request). 
+ */ + hadron_request_configuration_refresh(); } else if (rc == -2) { char *msg = pchomp(PQerrorMessage(pageserver_conn)); pageserver_disconnect(shard_no); + hadron_request_configuration_refresh(); neon_shard_log(shard_no, ERROR, "pageserver_receive disconnect: could not read COPY data: %s", msg); } else { pageserver_disconnect(shard_no); + hadron_request_configuration_refresh(); neon_shard_log(shard_no, ERROR, "pageserver_receive disconnect: unexpected PQgetCopyData return value: %d", rc); } @@ -1252,21 +1406,34 @@ pageserver_try_receive(shardno_t shard_no) neon_shard_log(shard_no, LOG, "pageserver_receive disconnect: psql end of copy data: %s", pchomp(PQerrorMessage(pageserver_conn))); pageserver_disconnect(shard_no); resp = NULL; + hadron_request_configuration_refresh(); } else if (rc == -2) { char *msg = pchomp(PQerrorMessage(pageserver_conn)); pageserver_disconnect(shard_no); + hadron_request_configuration_refresh(); neon_shard_log(shard_no, LOG, "pageserver_receive disconnect: could not read COPY data: %s", msg); resp = NULL; } else { pageserver_disconnect(shard_no); + hadron_request_configuration_refresh(); neon_shard_log(shard_no, ERROR, "pageserver_receive disconnect: unexpected PQgetCopyData return value: %d", rc); } + /* + * Always poke compute_ctl to request a configuration refresh if we have issues receiving data from pageservers after + * successfully connecting to it. It could be an indication that we are connecting to the wrong pageservers (e.g. PS + * is in secondary mode or otherwise refuses to respond our request). + */ + if ( rc < 0 && hadron_request_configuration_refresh() ) + { + neon_shard_log(shard_no, ERROR, "refresh_configuration request failed, cancelling query"); + } + shard->nresponses_received++; return (NeonResponse *) resp; } @@ -1315,7 +1482,6 @@ check_neon_id(char **newval, void **extra, GucSource source) return **newval == '\0' || HexDecodeString(id, *newval, 16); } - void PagestoreShmemInit(void) { @@ -1472,6 +1638,16 @@ pg_init_libpagestore(void) PGC_SU_BACKEND, 0, /* no flags required */ NULL, NULL, NULL); + DefineCustomIntVariable("hadron.conf_refresh_reconnect_attempt_threshold", + "Threshold of the number of consecutive failed pageserver " + "connection attempts (per shard) before signaling " + "compute_ctl for a configuration refresh.", + NULL, + &conf_refresh_reconnect_attempt_threshold, + 16, 0, INT_MAX, + PGC_USERSET, + 0, + NULL, NULL, NULL); DefineCustomIntVariable("neon.pageserver_response_log_timeout", "pageserver response log timeout", diff --git a/pgxn/neon/neon.c b/pgxn/neon/neon.c index faf4b1d13b..7a6936a740 100644 --- a/pgxn/neon/neon.c +++ b/pgxn/neon/neon.c @@ -1,7 +1,7 @@ /*------------------------------------------------------------------------- * * neon.c - * Main entry point into the neon exension + * Main entry point into the neon extension * *------------------------------------------------------------------------- */ @@ -50,8 +50,10 @@ PG_MODULE_MAGIC; void _PG_init(void); +bool lakebase_mode = false; static int running_xacts_overflow_policy; +static emit_log_hook_type prev_emit_log_hook; static bool monitor_query_exec_time = false; static ExecutorStart_hook_type prev_ExecutorStart = NULL; @@ -82,6 +84,8 @@ uint32 WAIT_EVENT_NEON_PS_READ; uint32 WAIT_EVENT_NEON_WAL_DL; #endif +int databricks_test_hook = 0; + enum RunningXactsOverflowPolicies { OP_IGNORE, OP_SKIP, @@ -446,6 +450,20 @@ ReportSearchPath(void) static int neon_pgstat_file_size_limit; #endif +static void DatabricksSqlErrorHookImpl(ErrorData *edata) { + if 
(prev_emit_log_hook != NULL) { + prev_emit_log_hook(edata); + } + + if (edata->sqlerrcode == ERRCODE_DATA_CORRUPTED) { + pg_atomic_fetch_add_u32(&databricks_metrics_shared->data_corruption_count, 1); + } else if (edata->sqlerrcode == ERRCODE_INDEX_CORRUPTED) { + pg_atomic_fetch_add_u32(&databricks_metrics_shared->index_corruption_count, 1); + } else if (edata->sqlerrcode == ERRCODE_INTERNAL_ERROR) { + pg_atomic_fetch_add_u32(&databricks_metrics_shared->internal_error_count, 1); + } +} + void _PG_init(void) { @@ -467,6 +485,11 @@ _PG_init(void) 0, NULL, NULL, NULL); + if (lakebase_mode) { + prev_emit_log_hook = emit_log_hook; + emit_log_hook = DatabricksSqlErrorHookImpl; + } + /* * Initializing a pre-loaded Postgres extension happens in three stages: * @@ -503,7 +526,7 @@ _PG_init(void) lfc_init(); pg_init_prewarm(); pg_init_walproposer(); - init_lwlsncache(); + pg_init_lwlsncache(); pg_init_communicator_process(); @@ -521,7 +544,7 @@ _PG_init(void) DefineCustomBoolVariable( "neon.disable_logical_replication_subscribers", - "Disables incomming logical replication", + "Disable incoming logical replication", NULL, &disable_logical_replication_subscribers, false, @@ -580,7 +603,7 @@ _PG_init(void) DefineCustomEnumVariable( "neon.debug_compare_local", - "Debug mode for compaing content of pages in prefetch ring/LFC/PS and local disk", + "Debug mode for comparing content of pages in prefetch ring/LFC/PS and local disk", NULL, &debug_compare_local, DEBUG_COMPARE_LOCAL_NONE, @@ -597,6 +620,29 @@ _PG_init(void) "neon_superuser", PGC_POSTMASTER, 0, NULL, NULL, NULL); + DefineCustomBoolVariable( + "neon.lakebase_mode", + "Is neon running in Lakebase?", + NULL, + &lakebase_mode, + false, + PGC_POSTMASTER, + 0, + NULL, NULL, NULL); + + // A test hook used in sql regress to trigger specific behaviors + // to test features easily. 
+ DefineCustomIntVariable( + "databricks.test_hook", + "The test hook used in sql regress tests only", + NULL, + &databricks_test_hook, + 0, + 0, INT32_MAX, + PGC_SUSET, + 0, + NULL, NULL, NULL); + /* * Important: This must happen after other parts of the extension are * loaded, otherwise any settings to GUCs that were set before the @@ -628,11 +674,15 @@ _PG_init(void) ExecutorEnd_hook = neon_ExecutorEnd; } +/* Various functions exposed at SQL level */ + PG_FUNCTION_INFO_V1(pg_cluster_size); PG_FUNCTION_INFO_V1(backpressure_lsns); PG_FUNCTION_INFO_V1(backpressure_throttling_time); PG_FUNCTION_INFO_V1(approximate_working_set_size_seconds); PG_FUNCTION_INFO_V1(approximate_working_set_size); +PG_FUNCTION_INFO_V1(neon_get_lfc_stats); +PG_FUNCTION_INFO_V1(local_cache_pages); Datum pg_cluster_size(PG_FUNCTION_ARGS) @@ -713,38 +763,78 @@ approximate_working_set_size(PG_FUNCTION_ARGS) PG_RETURN_INT32(dc); } -PG_FUNCTION_INFO_V1(neon_get_lfc_stats); Datum neon_get_lfc_stats(PG_FUNCTION_ARGS) { -#define NUM_NEON_GET_STATS_COLS 2 +#define NUM_NEON_GET_STATS_COLS 2 ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo; LfcStatsEntry *entries; - uint32 num_entries; + size_t num_entries; InitMaterializedSRF(fcinfo, 0); + /* lfc_get_stats() does all the heavy lifting */ if (neon_use_communicator_worker) - entries = communicator_new_get_lfc_stats(&num_entries); + entries = communicator_new_lfc_get_stats(&num_entries); else - entries = get_lfc_stats(&num_entries); + entries = lfc_get_stats(&num_entries); - for (uint32 i = 0; i < num_entries; i++) + /* Convert the LfcStatsEntrys to a result set */ + for (size_t i = 0; i < num_entries; i++) { LfcStatsEntry *entry = &entries[i]; Datum values[NUM_NEON_GET_STATS_COLS]; bool nulls[NUM_NEON_GET_STATS_COLS]; - nulls[0] = false; values[0] = CStringGetTextDatum(entry->metric_name); - nulls[1] = entry->isnull; + nulls[0] = false; values[1] = Int64GetDatum(entry->isnull ? 0 : entry->value); + nulls[1] = entry->isnull; + tuplestore_putvalues(rsinfo->setResult, rsinfo->setDesc, values, nulls); + } + PG_RETURN_VOID(); + +#undef NUM_NEON_GET_STATS_COLS +} + +Datum +local_cache_pages(PG_FUNCTION_ARGS) +{ +#define NUM_LOCALCACHE_PAGES_COLS 7 + ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo; + LocalCachePagesRec *entries; + size_t num_entries; + + InitMaterializedSRF(fcinfo, 0); + + /* lfc_local_cache_pages() does all the heavy lifting */ + entries = lfc_local_cache_pages(&num_entries); + + /* Convert the LocalCachePagesRec structs to a result set */ + for (size_t i = 0; i < num_entries; i++) + { + LocalCachePagesRec *entry = &entries[i]; + Datum values[NUM_LOCALCACHE_PAGES_COLS]; + bool nulls[NUM_LOCALCACHE_PAGES_COLS] = { + false, false, false, false, false, false, false + }; + + values[0] = Int64GetDatum((int64) entry->pageoffs); + values[1] = ObjectIdGetDatum(entry->relfilenode); + values[2] = ObjectIdGetDatum(entry->reltablespace); + values[3] = ObjectIdGetDatum(entry->reldatabase); + values[4] = ObjectIdGetDatum(entry->forknum); + values[5] = Int64GetDatum((int64) entry->blocknum); + values[6] = Int32GetDatum(entry->accesscount); + + /* Build and return the tuple. 
*/ tuplestore_putvalues(rsinfo->setResult, rsinfo->setDesc, values, nulls); } PG_RETURN_VOID(); -} +#undef NUM_LOCALCACHE_PAGES_COLS +} /* * Initialization stage 2: make requests for the amount of shared memory we @@ -779,7 +869,6 @@ neon_shmem_request_hook(void) static void neon_shmem_startup_hook(void) { - /* Initialize */ if (prev_shmem_startup_hook) prev_shmem_startup_hook(); @@ -788,6 +877,9 @@ neon_shmem_startup_hook(void) LfcShmemInit(); PrewarmShmemInit(); NeonPerfCountersShmemInit(); + if (lakebase_mode) { + DatabricksMetricsShmemInit(); + } PagestoreShmemInit(); RelsizeCacheShmemInit(); WalproposerShmemInit(); diff --git a/pgxn/neon/neon.h b/pgxn/neon/neon.h index ad843553a5..d350e9468d 100644 --- a/pgxn/neon/neon.h +++ b/pgxn/neon/neon.h @@ -21,6 +21,7 @@ extern int wal_acceptor_reconnect_timeout; extern int wal_acceptor_connection_timeout; extern int readahead_getpage_pull_timeout_ms; extern bool disable_wal_prev_lsn_checks; +extern bool lakebase_mode; extern bool AmPrewarmWorker; @@ -84,7 +85,8 @@ extern void WalproposerShmemInit(void); extern void LwLsnCacheShmemInit(void); extern void NeonPerfCountersShmemInit(void); -typedef struct LfcStatsEntry { +typedef struct LfcStatsEntry +{ const char *metric_name; bool isnull; uint64 value; diff --git a/pgxn/neon/neon_lwlsncache.c b/pgxn/neon/neon_lwlsncache.c index 5887c02c36..16935edf10 100644 --- a/pgxn/neon/neon_lwlsncache.c +++ b/pgxn/neon/neon_lwlsncache.c @@ -85,12 +85,54 @@ static set_lwlsn_db_hook_type prev_set_lwlsn_db_hook = NULL; static void neon_set_max_lwlsn(XLogRecPtr lsn); void -init_lwlsncache(void) +pg_init_lwlsncache(void) { if (!process_shared_preload_libraries_in_progress) ereport(ERROR, errcode(ERRCODE_INTERNAL_ERROR), errmsg("Loading of shared preload libraries is not in progress. 
Exiting")); lwlc_register_gucs(); +} + + +void +LwLsnCacheShmemRequest(void) +{ + Size requested_size; + + if (neon_use_communicator_worker) + return; + + requested_size = sizeof(LwLsnCacheCtl); + requested_size += hash_estimate_size(lwlsn_cache_size, sizeof(LastWrittenLsnCacheEntry)); + + RequestAddinShmemSpace(requested_size); +} + +void +LwLsnCacheShmemInit(void) +{ + static HASHCTL info; + bool found; + + if (neon_use_communicator_worker) + return; + + Assert(lwlsn_cache_size > 0); + + info.keysize = sizeof(BufferTag); + info.entrysize = sizeof(LastWrittenLsnCacheEntry); + lastWrittenLsnCache = ShmemInitHash("last_written_lsn_cache", + lwlsn_cache_size, lwlsn_cache_size, + &info, + HASH_ELEM | HASH_BLOBS); + LwLsnCache = ShmemInitStruct("neon/LwLsnCacheCtl", sizeof(LwLsnCacheCtl), &found); + // Now set the size in the struct + LwLsnCache->lastWrittenLsnCacheSize = lwlsn_cache_size; + if (found) { + return; + } + dlist_init(&LwLsnCache->lastWrittenLsnLRU); + LwLsnCache->maxLastWrittenLsn = GetRedoRecPtr(); prev_set_lwlsn_block_range_hook = set_lwlsn_block_range_hook; set_lwlsn_block_range_hook = neon_set_lwlsn_block_range; @@ -106,41 +148,6 @@ init_lwlsncache(void) set_lwlsn_db_hook = neon_set_lwlsn_db; } - -void -LwLsnCacheShmemRequest(void) -{ - Size requested_size = sizeof(LwLsnCacheCtl); - - requested_size += hash_estimate_size(lwlsn_cache_size, sizeof(LastWrittenLsnCacheEntry)); - - RequestAddinShmemSpace(requested_size); -} - -void -LwLsnCacheShmemInit(void) -{ - static HASHCTL info; - bool found; - if (lwlsn_cache_size > 0) - { - info.keysize = sizeof(BufferTag); - info.entrysize = sizeof(LastWrittenLsnCacheEntry); - lastWrittenLsnCache = ShmemInitHash("last_written_lsn_cache", - lwlsn_cache_size, lwlsn_cache_size, - &info, - HASH_ELEM | HASH_BLOBS); - LwLsnCache = ShmemInitStruct("neon/LwLsnCacheCtl", sizeof(LwLsnCacheCtl), &found); - // Now set the size in the struct - LwLsnCache->lastWrittenLsnCacheSize = lwlsn_cache_size; - if (found) { - return; - } - } - dlist_init(&LwLsnCache->lastWrittenLsnLRU); - LwLsnCache->maxLastWrittenLsn = GetRedoRecPtr(); -} - /* * neon_get_lwlsn -- Returns maximal LSN of written page. 
* It returns an upper bound for the last written LSN of a given page, @@ -155,6 +162,7 @@ neon_get_lwlsn(NRelFileInfo rlocator, ForkNumber forknum, BlockNumber blkno) XLogRecPtr lsn; LastWrittenLsnCacheEntry* entry; + Assert(!neon_use_communicator_worker); Assert(LwLsnCache->lastWrittenLsnCacheSize != 0); LWLockAcquire(LastWrittenLsnLock, LW_SHARED); @@ -207,7 +215,10 @@ neon_get_lwlsn(NRelFileInfo rlocator, ForkNumber forknum, BlockNumber blkno) return lsn; } -static void neon_set_max_lwlsn(XLogRecPtr lsn) { +static void +neon_set_max_lwlsn(XLogRecPtr lsn) +{ + Assert(!neon_use_communicator_worker); LWLockAcquire(LastWrittenLsnLock, LW_EXCLUSIVE); LwLsnCache->maxLastWrittenLsn = lsn; LWLockRelease(LastWrittenLsnLock); @@ -228,6 +239,7 @@ neon_get_lwlsn_v(NRelFileInfo relfilenode, ForkNumber forknum, LastWrittenLsnCacheEntry* entry; XLogRecPtr lsn; + Assert(!neon_use_communicator_worker); Assert(LwLsnCache->lastWrittenLsnCacheSize != 0); Assert(nblocks > 0); Assert(PointerIsValid(lsns)); @@ -376,6 +388,8 @@ SetLastWrittenLSNForBlockRangeInternal(XLogRecPtr lsn, XLogRecPtr neon_set_lwlsn_block_range(XLogRecPtr lsn, NRelFileInfo rlocator, ForkNumber forknum, BlockNumber from, BlockNumber n_blocks) { + Assert(!neon_use_communicator_worker); + if (lsn == InvalidXLogRecPtr || n_blocks == 0 || LwLsnCache->lastWrittenLsnCacheSize == 0) return lsn; @@ -412,6 +426,8 @@ neon_set_lwlsn_block_v(const XLogRecPtr *lsns, NRelFileInfo relfilenode, Oid dbOid = NInfoGetDbOid(relfilenode); Oid relNumber = NInfoGetRelNumber(relfilenode); + Assert(!neon_use_communicator_worker); + if (lsns == NULL || nblocks == 0 || LwLsnCache->lastWrittenLsnCacheSize == 0 || NInfoGetRelNumber(relfilenode) == InvalidOid) return InvalidXLogRecPtr; @@ -469,6 +485,7 @@ neon_set_lwlsn_block_v(const XLogRecPtr *lsns, NRelFileInfo relfilenode, XLogRecPtr neon_set_lwlsn_block(XLogRecPtr lsn, NRelFileInfo rlocator, ForkNumber forknum, BlockNumber blkno) { + Assert(!neon_use_communicator_worker); return neon_set_lwlsn_block_range(lsn, rlocator, forknum, blkno, 1); } @@ -478,6 +495,7 @@ neon_set_lwlsn_block(XLogRecPtr lsn, NRelFileInfo rlocator, ForkNumber forknum, XLogRecPtr neon_set_lwlsn_relation(XLogRecPtr lsn, NRelFileInfo rlocator, ForkNumber forknum) { + Assert(!neon_use_communicator_worker); return neon_set_lwlsn_block(lsn, rlocator, forknum, REL_METADATA_PSEUDO_BLOCKNO); } @@ -488,6 +506,8 @@ XLogRecPtr neon_set_lwlsn_db(XLogRecPtr lsn) { NRelFileInfo dummyNode = {InvalidOid, InvalidOid, InvalidOid}; + + Assert(!neon_use_communicator_worker); return neon_set_lwlsn_block(lsn, dummyNode, MAIN_FORKNUM, 0); } diff --git a/pgxn/neon/neon_lwlsncache.h b/pgxn/neon/neon_lwlsncache.h index acb5561c0c..e022e7a998 100644 --- a/pgxn/neon/neon_lwlsncache.h +++ b/pgxn/neon/neon_lwlsncache.h @@ -3,7 +3,7 @@ #include "neon_pgversioncompat.h" -void init_lwlsncache(void); +extern void pg_init_lwlsncache(void); /* Hooks */ XLogRecPtr neon_get_lwlsn(NRelFileInfo rlocator, ForkNumber forknum, BlockNumber blkno); @@ -14,4 +14,4 @@ XLogRecPtr neon_set_lwlsn_block(XLogRecPtr lsn, NRelFileInfo rlocator, ForkNumbe XLogRecPtr neon_set_lwlsn_relation(XLogRecPtr lsn, NRelFileInfo rlocator, ForkNumber forknum); XLogRecPtr neon_set_lwlsn_db(XLogRecPtr lsn); -#endif /* NEON_LWLSNCACHE_H */ \ No newline at end of file +#endif /* NEON_LWLSNCACHE_H */ diff --git a/pgxn/neon/neon_perf_counters.c b/pgxn/neon/neon_perf_counters.c index dd576e4e73..4527084514 100644 --- a/pgxn/neon/neon_perf_counters.c +++ b/pgxn/neon/neon_perf_counters.c @@ -19,7 +19,36 @@ #include 
"neon.h" #include "neon_perf_counters.h" -#include "neon_pgversioncompat.h" +#include "walproposer.h" + +/* BEGIN_HADRON */ +databricks_metrics *databricks_metrics_shared; + +Size +DatabricksMetricsShmemSize(void) +{ + return sizeof(databricks_metrics); +} + +void +DatabricksMetricsShmemInit(void) +{ + bool found; + + databricks_metrics_shared = + ShmemInitStruct("Databricks counters", + DatabricksMetricsShmemSize(), + &found); + Assert(found == IsUnderPostmaster); + if (!found) + { + pg_atomic_init_u32(&databricks_metrics_shared->index_corruption_count, 0); + pg_atomic_init_u32(&databricks_metrics_shared->data_corruption_count, 0); + pg_atomic_init_u32(&databricks_metrics_shared->internal_error_count, 0); + pg_atomic_init_u32(&databricks_metrics_shared->ps_corruption_detected, 0); + } +} +/* END_HADRON */ neon_per_backend_counters *neon_per_backend_counters_shared; @@ -38,11 +67,12 @@ NeonPerfCountersShmemRequest(void) #else size = mul_size(NUM_NEON_PERF_COUNTER_SLOTS, sizeof(neon_per_backend_counters)); #endif + if (lakebase_mode) { + size = add_size(size, DatabricksMetricsShmemSize()); + } RequestAddinShmemSpace(size); } - - void NeonPerfCountersShmemInit(void) { @@ -361,6 +391,12 @@ neon_get_perf_counters(PG_FUNCTION_ARGS) neon_per_backend_counters totals = {0}; metric_t *metrics; + /* BEGIN_HADRON */ + WalproposerShmemState *wp_shmem; + uint32 num_safekeepers; + uint32 num_active_safekeepers; + /* END_HADRON */ + /* We put all the tuples into a tuplestore in one go. */ InitMaterializedSRF(fcinfo, 0); @@ -395,6 +431,55 @@ neon_get_perf_counters(PG_FUNCTION_ARGS) metric_to_datums(&metrics[i], &values[0], &nulls[0]); tuplestore_putvalues(rsinfo->setResult, rsinfo->setDesc, values, nulls); } + + if (lakebase_mode) { + + if (databricks_test_hook == TestHookCorruption) { + ereport(ERROR, + (errcode(ERRCODE_DATA_CORRUPTED), + errmsg("test corruption"))); + } + + // Not ideal but piggyback our databricks counters into the neon perf counters view + // so that we don't need to introduce neon--1.x+1.sql to add a new view. + { + // Keeping this code in its own block to work around the C90 "don't mix declarations and code" rule when we define + // the `databricks_metrics` array in the next block. Yes, we are seriously dealing with C90 rules in 2025. + + // Read safekeeper status from wal proposer shared memory first. + // Note that we are taking a mutex when reading from walproposer shared memory so that the total safekeeper count is + // consistent with the active wal acceptors count. Assuming that we don't query this view too often the mutex should + // not be a huge deal. 
+ wp_shmem = GetWalpropShmemState(); + SpinLockAcquire(&wp_shmem->mutex); + num_safekeepers = wp_shmem->num_safekeepers; + num_active_safekeepers = 0; + for (int i = 0; i < num_safekeepers; i++) { + if (wp_shmem->safekeeper_status[i] == 1) { + num_active_safekeepers++; + } + } + SpinLockRelease(&wp_shmem->mutex); + } + { + metric_t databricks_metrics[] = { + {"sql_index_corruption_count", false, 0, (double) pg_atomic_read_u32(&databricks_metrics_shared->index_corruption_count)}, + {"sql_data_corruption_count", false, 0, (double) pg_atomic_read_u32(&databricks_metrics_shared->data_corruption_count)}, + {"sql_internal_error_count", false, 0, (double) pg_atomic_read_u32(&databricks_metrics_shared->internal_error_count)}, + {"ps_corruption_detected", false, 0, (double) pg_atomic_read_u32(&databricks_metrics_shared->ps_corruption_detected)}, + {"num_active_safekeepers", false, 0.0, (double) num_active_safekeepers}, + {"num_configured_safekeepers", false, 0.0, (double) num_safekeepers}, + {NULL, false, 0, 0}, + }; + for (int i = 0; databricks_metrics[i].name != NULL; i++) + { + metric_to_datums(&databricks_metrics[i], &values[0], &nulls[0]); + tuplestore_putvalues(rsinfo->setResult, rsinfo->setDesc, values, nulls); + } + } + /* END_HADRON */ + } + pfree(metrics); return (Datum) 0; diff --git a/pgxn/neon/neon_perf_counters.h b/pgxn/neon/neon_perf_counters.h index 4b611b0636..5c0b7ded7a 100644 --- a/pgxn/neon/neon_perf_counters.h +++ b/pgxn/neon/neon_perf_counters.h @@ -167,11 +167,7 @@ extern neon_per_backend_counters *neon_per_backend_counters_shared; */ #define NUM_NEON_PERF_COUNTER_SLOTS (MaxBackends + NUM_AUXILIARY_PROCS) -#if PG_VERSION_NUM >= 170000 #define MyNeonCounters (&neon_per_backend_counters_shared[MyProcNumber]) -#else -#define MyNeonCounters (&neon_per_backend_counters_shared[MyProc->pgprocno]) -#endif extern void inc_getpage_wait(uint64 latency); extern void inc_page_cache_read_wait(uint64 latency); @@ -181,5 +177,24 @@ extern void inc_query_time(uint64 elapsed); extern Size NeonPerfCountersShmemSize(void); extern void NeonPerfCountersShmemInit(void); +/* BEGIN_HADRON */ +typedef struct +{ + pg_atomic_uint32 index_corruption_count; + pg_atomic_uint32 data_corruption_count; + pg_atomic_uint32 internal_error_count; + pg_atomic_uint32 ps_corruption_detected; +} databricks_metrics; + +extern databricks_metrics *databricks_metrics_shared; + +extern Size DatabricksMetricsShmemSize(void); +extern void DatabricksMetricsShmemInit(void); + +extern int databricks_test_hook; + +static const int TestHookCorruption = 1; +/* END_HADRON */ + #endif /* NEON_PERF_COUNTERS_H */ diff --git a/pgxn/neon/neon_utils.c b/pgxn/neon/neon_utils.c index 1fad44bd58..847d380eb3 100644 --- a/pgxn/neon/neon_utils.c +++ b/pgxn/neon/neon_utils.c @@ -183,3 +183,22 @@ alloc_curl_handle(void) } #endif + +/* + * Check if a BufferTag is valid by verifying all its fields are not invalid. 
+ */ +bool +BufferTagIsValid(const BufferTag *tag) +{ + #if PG_MAJORVERSION_NUM >= 16 + return (tag->spcOid != InvalidOid) && + (tag->relNumber != InvalidRelFileNumber) && + (tag->forkNum != InvalidForkNumber) && + (tag->blockNum != InvalidBlockNumber); + #else + return (tag->rnode.spcNode != InvalidOid) && + (tag->rnode.relNode != InvalidOid) && + (tag->forkNum != InvalidForkNumber) && + (tag->blockNum != InvalidBlockNumber); + #endif +} diff --git a/pgxn/neon/neon_utils.h b/pgxn/neon/neon_utils.h index 7480ac28cc..65d280788d 100644 --- a/pgxn/neon/neon_utils.h +++ b/pgxn/neon/neon_utils.h @@ -2,6 +2,7 @@ #define __NEON_UTILS_H__ #include "lib/stringinfo.h" +#include "storage/buf_internals.h" #ifndef WALPROPOSER_LIB #include @@ -16,6 +17,9 @@ void pq_sendint32_le(StringInfo buf, uint32 i); void pq_sendint64_le(StringInfo buf, uint64 i); void disable_core_dump(void); +/* Buffer tag validation function */ +bool BufferTagIsValid(const BufferTag *tag); + #ifndef WALPROPOSER_LIB CURL * alloc_curl_handle(void); diff --git a/pgxn/neon/pagestore_client.h b/pgxn/neon/pagestore_client.h index 47417a7bd5..50bf6b4e4b 100644 --- a/pgxn/neon/pagestore_client.h +++ b/pgxn/neon/pagestore_client.h @@ -256,6 +256,8 @@ typedef struct extern bool parse_shard_map(const char *connstr, ShardMap *result); extern shardno_t get_shard_number(BufferTag* tag); +extern void AssignNumShards(shardno_t num_shards); + extern const f_smgr *smgr_neon(ProcNumber backend, NRelFileInfo rinfo); extern void smgr_init_neon(void); extern void readahead_buffer_resize(int newsize, void *extra); diff --git a/pgxn/neon/pagestore_smgr.c b/pgxn/neon/pagestore_smgr.c index 06ce61d2e5..8187cc2359 100644 --- a/pgxn/neon/pagestore_smgr.c +++ b/pgxn/neon/pagestore_smgr.c @@ -73,10 +73,6 @@ #include "access/xlogrecovery.h" #endif -#if PG_VERSION_NUM < 160000 -typedef PGAlignedBlock PGIOAlignedBlock; -#endif - #include "access/nbtree.h" #include "storage/bufpage.h" #include "access/xlog_internal.h" @@ -88,7 +84,7 @@ static char *hexdump_page(char *page); NInfoGetRelNumber(InfoFromSMgrRel(reln)) >= FirstNormalObjectId \ ) -const int SmgrTrace = DEBUG1; +const int SmgrTrace = DEBUG5; /* unlogged relation build states */ typedef enum @@ -306,7 +302,7 @@ neon_wallog_pagev(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, */ lsns[batch_size++] = lsn; - if (batch_size >= BLOCK_BATCH_SIZE) + if (batch_size >= BLOCK_BATCH_SIZE && !neon_use_communicator_worker) { neon_set_lwlsn_block_v(lsns, InfoFromSMgrRel(reln), forknum, batch_blockno, @@ -316,7 +312,7 @@ neon_wallog_pagev(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, } } - if (batch_size != 0) + if (batch_size != 0 && !neon_use_communicator_worker) { neon_set_lwlsn_block_v(lsns, InfoFromSMgrRel(reln), forknum, batch_blockno, @@ -441,11 +437,17 @@ neon_wallog_page(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, co forknum, LSN_FORMAT_ARGS(lsn)))); } - /* - * Remember the LSN on this page. When we read the page again, we must - * read the same or newer version of it. - */ - neon_set_lwlsn_block(lsn, InfoFromSMgrRel(reln), forknum, blocknum); + if (!neon_use_communicator_worker) + { + /* + * Remember the LSN on this page. When we read the page again, we must + * read the same or newer version of it. 
+ * + * (With the new communicator, the caller will make a write-request + * for this page, which updates the last-written LSN too) + */ + neon_set_lwlsn_block(lsn, InfoFromSMgrRel(reln), forknum, blocknum); + } } /* @@ -568,6 +570,7 @@ neon_get_request_lsns(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno, { XLogRecPtr last_written_lsns[PG_IOV_MAX]; + Assert(!neon_use_communicator_worker); Assert(nblocks <= PG_IOV_MAX); neon_get_lwlsn_v(rinfo, forknum, blkno, (int) nblocks, last_written_lsns); @@ -906,8 +909,25 @@ neon_create(SMgrRelation reln, ForkNumber forkNum, bool isRedo) if (isRedo) { + /* + * TODO: the protocol can check for existence and get the relsize + * in one roundtrip. Add a similar call to the + * backend<->communicator API. (The size is cached on the + * rel_exists call, so this does only one roundtrip to the + * pageserver, but two function calls and two cache lookups.) + */ if (!communicator_new_rel_exists(InfoFromSMgrRel(reln), forkNum)) + { communicator_new_rel_create(InfoFromSMgrRel(reln), forkNum, lsn); + reln->smgr_cached_nblocks[forkNum] = 0; + } + else + { + BlockNumber nblocks; + + nblocks = communicator_new_rel_nblocks(InfoFromSMgrRel(reln), forkNum); + reln->smgr_cached_nblocks[forkNum] = nblocks; + } } else communicator_new_rel_create(InfoFromSMgrRel(reln), forkNum, lsn); @@ -991,6 +1011,7 @@ neon_extend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, #endif { XLogRecPtr lsn; + bool lsn_was_zero; BlockNumber n_blocks = 0; switch (reln->smgr_relpersistence) @@ -1055,9 +1076,19 @@ neon_extend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, forkNum, blkno, (uint32) (lsn >> 32), (uint32) lsn); + /* + * smgr_extend is often called with an all-zeroes page, so + * lsn==InvalidXLogRecPtr. An smgr_write() call will come for the buffer + * later, after it has been initialized with the real page contents, and + * it is eventually evicted from the buffer cache. But we need a valid LSN + * to the relation metadata update now. + */ + lsn_was_zero = (lsn == InvalidXLogRecPtr); + if (lsn_was_zero) + lsn = GetXLogInsertRecPtr(); + if (neon_use_communicator_worker) { - // FIXME: this can pass lsn == invalid. Is that ok? communicator_new_rel_extend(InfoFromSMgrRel(reln), forkNum, blkno, (const void *) buffer, lsn); if (debug_compare_local) @@ -1084,11 +1115,8 @@ neon_extend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, * it is eventually evicted from the buffer cache. But we need a valid LSN * to the relation metadata update now. */ - if (lsn == InvalidXLogRecPtr) - { - lsn = GetXLogInsertRecPtr(); + if (lsn_was_zero) neon_set_lwlsn_block(lsn, InfoFromSMgrRel(reln), forkNum, blkno); - } neon_set_lwlsn_relation(lsn, InfoFromSMgrRel(reln), forkNum); } } @@ -1410,7 +1438,7 @@ neon_read_at_lsn(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno, { // FIXME: request_lsns is ignored. That affects the neon_test_utils callers. // Add the capability to specify the LSNs explicitly, for the sake of neon_test_utils ? 
- communicator_new_read_at_lsnv(rinfo, forkNum, blkno, &buffer, 1); + communicator_new_read_at_lsn_uncached(rinfo, forkNum, blkno, buffer, request_lsns.request_lsn, request_lsns.not_modified_since); } else communicator_read_at_lsnv(rinfo, forkNum, blkno, &request_lsns, &buffer, 1, NULL); @@ -1541,8 +1569,8 @@ neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void *buffer if (neon_use_communicator_worker) { - communicator_new_read_at_lsnv(InfoFromSMgrRel(reln), forkNum, blkno, - (void *) &buffer, 1); + communicator_new_readv(InfoFromSMgrRel(reln), forkNum, blkno, + (void *) &buffer, 1); } else { @@ -1657,8 +1685,8 @@ neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, if (neon_use_communicator_worker) { - communicator_new_read_at_lsnv(InfoFromSMgrRel(reln), forknum, blocknum, - buffers, nblocks); + communicator_new_readv(InfoFromSMgrRel(reln), forknum, blocknum, + buffers, nblocks); } else { @@ -2505,10 +2533,6 @@ neon_extend_rel_size(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blkno, if (blkno >= relsize) communicator_new_rel_zeroextend(rinfo, forknum, relsize, (blkno - relsize) + 1, end_recptr); - /* - * FIXME: does this need to update the last-written LSN too, like the - * old implementation? - */ return; } @@ -2666,21 +2690,27 @@ neon_redo_read_buffer_filter(XLogReaderState *record, uint8 block_id) } /* - * we don't have the buffer in memory, update lwLsn past this record, also - * evict page from file cache + * We don't have the buffer in shared buffers. Check if it's in the LFC. + * If it's not there either, update the lwLsn past this record. */ if (no_redo_needed) { - neon_set_lwlsn_block(end_recptr, rinfo, forknum, blkno); + bool in_cache; + /* - * Redo changes if page exists in LFC. - * We should perform this check after assigning LwLSN to prevent - * prefetching of some older version of the page by some other backend. + * Redo changes if the page is present in the LFC. 
*/ if (neon_use_communicator_worker) - no_redo_needed = communicator_new_cache_contains(rinfo, forknum, blkno); + { + in_cache = communicator_new_update_lwlsn_for_block_if_not_cached(rinfo, forknum, blkno, end_recptr); + } else - no_redo_needed = !lfc_cache_contains(rinfo, forknum, blkno); + { + in_cache = lfc_cache_contains(rinfo, forknum, blkno); + neon_set_lwlsn_block(end_recptr, rinfo, forknum, blkno); + } + + no_redo_needed = !in_cache; } LWLockRelease(partitionLock); diff --git a/pgxn/neon/relsize_cache.c b/pgxn/neon/relsize_cache.c index 613e98f0d4..89ecccebc1 100644 --- a/pgxn/neon/relsize_cache.c +++ b/pgxn/neon/relsize_cache.c @@ -13,6 +13,7 @@ #include "neon.h" #include "neon_pgversioncompat.h" +#include "miscadmin.h" #include "pagestore_client.h" #include RELFILEINFO_HDR #include "storage/smgr.h" @@ -23,8 +24,6 @@ #include "utils/dynahash.h" #include "utils/guc.h" -#include "miscadmin.h" - typedef struct { NRelFileInfo rinfo; diff --git a/pgxn/neon/walproposer.c b/pgxn/neon/walproposer.c index ba6e4a54ff..dd42eaf18e 100644 --- a/pgxn/neon/walproposer.c +++ b/pgxn/neon/walproposer.c @@ -154,7 +154,9 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api) wp->safekeeper[wp->n_safekeepers].state = SS_OFFLINE; wp->safekeeper[wp->n_safekeepers].active_state = SS_ACTIVE_SEND; wp->safekeeper[wp->n_safekeepers].wp = wp; - + /* BEGIN_HADRON */ + wp->safekeeper[wp->n_safekeepers].index = wp->n_safekeepers; + /* END_HADRON */ { Safekeeper *sk = &wp->safekeeper[wp->n_safekeepers]; int written = 0; @@ -183,6 +185,10 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api) if (wp->safekeepers_generation > INVALID_GENERATION && wp->config->proto_version < 3) wp_log(FATAL, "enabling generations requires protocol version 3"); wp_log(LOG, "using safekeeper protocol version %d", wp->config->proto_version); + + /* BEGIN_HADRON */ + wp->api.reset_safekeeper_statuses_for_metrics(wp, wp->n_safekeepers); + /* END_HADRON */ /* Fill the greeting package */ wp->greetRequest.pam.tag = 'g'; @@ -355,6 +361,10 @@ ShutdownConnection(Safekeeper *sk) sk->state = SS_OFFLINE; sk->streamingAt = InvalidXLogRecPtr; + /* BEGIN_HADRON */ + sk->wp->api.update_safekeeper_status_for_metrics(sk->wp, sk->index, 0); + /* END_HADRON */ + MembershipConfigurationFree(&sk->greetResponse.mconf); if (sk->voteResponse.termHistory.entries) pfree(sk->voteResponse.termHistory.entries); @@ -1530,6 +1540,10 @@ StartStreaming(Safekeeper *sk) sk->active_state = SS_ACTIVE_SEND; sk->streamingAt = sk->startStreamingAt; + /* BEGIN_HADRON */ + sk->wp->api.update_safekeeper_status_for_metrics(sk->wp, sk->index, 1); + /* END_HADRON */ + /* * Donors can only be in SS_ACTIVE state, so we potentially update the * donor when we switch one to SS_ACTIVE. @@ -1887,6 +1901,12 @@ ParsePageserverFeedbackMessage(WalProposer *wp, StringInfo reply_message, Pagese ps_feedback->shard_number = pq_getmsgint(reply_message, sizeof(uint32)); psfeedback_log("%u", key, ps_feedback->shard_number); } + else if (strcmp(key, "corruption_detected") == 0) + { + Assert(value_len == 1); + ps_feedback->corruption_detected = pq_getmsgbyte(reply_message) != 0; + psfeedback_log("%s", key, ps_feedback->corruption_detected ? 
"true" : "false"); + } else { /* diff --git a/pgxn/neon/walproposer.h b/pgxn/neon/walproposer.h index 19d23925a5..ac42c2925d 100644 --- a/pgxn/neon/walproposer.h +++ b/pgxn/neon/walproposer.h @@ -374,6 +374,8 @@ typedef struct PageserverFeedback XLogRecPtr remote_consistent_lsn; TimestampTz replytime; uint32 shard_number; + /* true if the pageserver has detected data corruption in the timeline */ + bool corruption_detected; } PageserverFeedback; /* BEGIN_HADRON */ @@ -389,12 +391,21 @@ typedef struct PageserverFeedback */ typedef struct WalRateLimiter { - /* If the value is 1, PG backends will hit backpressure. */ + /* The effective wal write rate. Could be changed dynamically + based on whether PG has backpressure or not.*/ + pg_atomic_uint32 effective_max_wal_bytes_per_second; + /* If the value is 1, PG backends will hit backpressure until the time has past batch_end_time_us. */ pg_atomic_uint32 should_limit; /* The number of bytes sent in the current second. */ uint64 sent_bytes; - /* The last recorded time in microsecond. */ - pg_atomic_uint64 last_recorded_time_us; + /* The timestamp when the write starts in the current batch. A batch is a time interval (e.g., )that we + track and throttle writes. Most times a batch is 1s, but it could become larger if the PG overwrites the WALs + and we will adjust the batch accordingly to compensate (e.g., if PG writes 10MB at once and max WAL write rate + is 1MB/s, then the current batch will become 10s). */ + pg_atomic_uint64 batch_start_time_us; + /* The timestamp (in the future) that the current batch should end and accept more writes + (after should_limit is set to 1). */ + pg_atomic_uint64 batch_end_time_us; } WalRateLimiter; /* END_HADRON */ @@ -421,6 +432,10 @@ typedef struct WalproposerShmemState /* BEGIN_HADRON */ /* The WAL rate limiter */ WalRateLimiter wal_rate_limiter; + /* Number of safekeepers in the config */ + uint32 num_safekeepers; + /* Per-safekeeper status flags: 0=inactive, 1=active */ + uint8 safekeeper_status[MAX_SAFEKEEPERS]; /* END_HADRON */ } WalproposerShmemState; @@ -472,6 +487,11 @@ typedef struct Safekeeper char const *host; char const *port; + /* BEGIN_HADRON */ + /* index of this safekeeper in the WalProposer array */ + uint32 index; + /* END_HADRON */ + /* * connection string for connecting/reconnecting. * @@ -720,6 +740,23 @@ typedef struct walproposer_api * handled by elog(). */ void (*log_internal) (WalProposer *wp, int level, const char *line); + + /* + * BEGIN_HADRON + * APIs manipulating shared memory state used for Safekeeper quorum health metrics. + */ + + /* + * Reset the safekeeper statuses in shared memory for metric purposes. + */ + void (*reset_safekeeper_statuses_for_metrics) (WalProposer *wp, uint32 num_safekeepers); + + /* + * Update the safekeeper status in shared memory for metric purposes. + */ + void (*update_safekeeper_status_for_metrics) (WalProposer *wp, uint32 sk_index, uint8 status); + + /* END_HADRON */ } walproposer_api; /* diff --git a/pgxn/neon/walproposer_pg.c b/pgxn/neon/walproposer_pg.c index 93807be8c2..47b5ec523f 100644 --- a/pgxn/neon/walproposer_pg.c +++ b/pgxn/neon/walproposer_pg.c @@ -49,6 +49,7 @@ #include "libpqwalproposer.h" #include "neon.h" +#include "neon_perf_counters.h" #include "neon_walreader.h" #include "walproposer.h" @@ -68,6 +69,14 @@ int safekeeper_proto_version = 3; char *safekeeper_conninfo_options = ""; /* BEGIN_HADRON */ int databricks_max_wal_mb_per_second = -1; +// during throttling, we will limit the effective WAL write rate to 10KB. 
+// PG can still push some WAL to SK, but at a very low rate. +int databricks_throttled_max_wal_bytes_per_second = 10 * 1024; +// The max sleep time of a batch. This is to make sure the rate limiter does not +// overshoot too much and block PG for a very long time. +// This is set to 5 minutes for now. PG can send as much as 10MB of WALs to SK in one batch, +// so this effectively caps the write rate to ~30KB/s in the worst case. +static uint64 kRateLimitMaxBatchUSecs = 300 * USECS_PER_SEC; /* END_HADRON */ /* Set to true in the walproposer bgw. */ @@ -86,6 +95,7 @@ static HotStandbyFeedback agg_hs_feedback; static void nwp_register_gucs(void); static void assign_neon_safekeepers(const char *newval, void *extra); static uint64 backpressure_lag_impl(void); +static uint64 hadron_backpressure_lag_impl(void); static uint64 startup_backpressure_wrap(void); static bool backpressure_throttling_impl(void); static void walprop_register_bgworker(void); @@ -110,6 +120,22 @@ static void rm_safekeeper_event_set(Safekeeper *to_remove, bool is_sk); static void CheckGracefulShutdown(WalProposer *wp); +/* BEGIN_HADRON */ +shardno_t get_num_shards(void); + +static int positive_mb_to_bytes(int mb) +{ + if (mb <= 0) + { + return mb; + } + else + { + return mb * 1024 * 1024; + } +} +/* END_HADRON */ + static void init_walprop_config(bool syncSafekeepers) { @@ -257,6 +283,16 @@ nwp_register_gucs(void) PGC_SUSET, GUC_UNIT_MB, NULL, NULL, NULL); + + DefineCustomIntVariable( + "databricks.throttled_max_wal_bytes_per_second", + "The maximum WAL bytes per second when PG is being throttled.", + NULL, + &databricks_throttled_max_wal_bytes_per_second, + 10 * 1024, 0, INT_MAX, + PGC_SUSET, + GUC_UNIT_BYTE, + NULL, NULL, NULL); /* END_HADRON */ } @@ -395,19 +431,65 @@ assign_neon_safekeepers(const char *newval, void *extra) pfree(oldval); } -/* Check if we need to suspend inserts because of lagging replication. */ -static uint64 -backpressure_lag_impl(void) +/* BEGIN_HADRON */ +static uint64 hadron_backpressure_lag_impl(void) { struct WalproposerShmemState* state = NULL; + uint64 lag = 0; - /* BEGIN_HADRON */ if(max_cluster_size < 0){ // if max cluster size is not set, then we don't apply backpressure because we're reconfiguring PG return 0; } - /* END_HADRON */ + lag = backpressure_lag_impl(); + state = GetWalpropShmemState(); + if ( state != NULL && databricks_max_wal_mb_per_second != -1 ) + { + int old_limit = pg_atomic_read_u32(&state->wal_rate_limiter.effective_max_wal_bytes_per_second); + int new_limit = (lag == 0)? 
positive_mb_to_bytes(databricks_max_wal_mb_per_second) : databricks_throttled_max_wal_bytes_per_second; + if( old_limit != new_limit ) + { + uint64 batch_start_time = pg_atomic_read_u64(&state->wal_rate_limiter.batch_start_time_us); + uint64 batch_end_time = pg_atomic_read_u64(&state->wal_rate_limiter.batch_end_time_us); + // The rate limit has changed; reset the rate limiter's batch end time. + pg_atomic_write_u32(&state->wal_rate_limiter.effective_max_wal_bytes_per_second, new_limit); + pg_atomic_write_u64(&state->wal_rate_limiter.batch_end_time_us, Min(batch_start_time + USECS_PER_SEC, batch_end_time)); + } + if( new_limit == -1 ) + { + return 0; + } + + if (pg_atomic_read_u32(&state->wal_rate_limiter.should_limit) == true) + { + TimestampTz now = GetCurrentTimestamp(); + struct WalRateLimiter *limiter = &state->wal_rate_limiter; + uint64 batch_end_time = pg_atomic_read_u64(&limiter->batch_end_time_us); + if ( now >= batch_end_time ) + { + /* + * The backend has passed the batch end time and it's time to push more WALs. + * If the backends are pushing WALs too fast, the wal proposer will rate limit them again. + */ + uint32 expected = true; + pg_atomic_compare_exchange_u32(&state->wal_rate_limiter.should_limit, &expected, false); + return 0; + } + return Max(lag, 1); + } + // The rate limiter decided not to throttle; return 0. + return 0; + } + + return lag; +} +/* END_HADRON */ + +/* Check if we need to suspend inserts because of lagging replication. */ +static uint64 +backpressure_lag_impl(void) +{ if (max_replication_apply_lag > 0 || max_replication_flush_lag > 0 || max_replication_write_lag > 0) { XLogRecPtr writePtr; @@ -426,45 +508,47 @@ backpressure_lag_impl(void) LSN_FORMAT_ARGS(flushPtr), LSN_FORMAT_ARGS(applyPtr)); - if ((writePtr != InvalidXLogRecPtr && max_replication_write_lag > 0 && myFlushLsn > writePtr + max_replication_write_lag * MB)) + if (lakebase_mode) { - return (myFlushLsn - writePtr - max_replication_write_lag * MB); - } + // in case PG does not have shard map initialized, we assume PG always has 1 shard at minimum. 
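/*
 * Worked example of the per-tenant scaling applied just below (numbers are
 * hypothetical): with 4 shards and max_replication_write_lag = 500 (MB), the
 * effective threshold becomes
 *
 *     tenant_max_replication_write_lag = 4 * 500 MB = 2000 MB
 *
 * so backpressure starts only once this compute's flush LSN is more than
 * 2000 MB ahead of the reported write position (writePtr), and the returned
 * lag is myFlushLsn - writePtr - 2000 MB.
 */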
+ shardno_t num_shards = Max(1, get_num_shards()); + int tenant_max_replication_apply_lag = num_shards * max_replication_apply_lag; + int tenant_max_replication_flush_lag = num_shards * max_replication_flush_lag; + int tenant_max_replication_write_lag = num_shards * max_replication_write_lag; - if ((flushPtr != InvalidXLogRecPtr && max_replication_flush_lag > 0 && myFlushLsn > flushPtr + max_replication_flush_lag * MB)) - { - return (myFlushLsn - flushPtr - max_replication_flush_lag * MB); - } + if ((writePtr != InvalidXLogRecPtr && tenant_max_replication_write_lag > 0 && myFlushLsn > writePtr + tenant_max_replication_write_lag * MB)) + { + return (myFlushLsn - writePtr - tenant_max_replication_write_lag * MB); + } - if ((applyPtr != InvalidXLogRecPtr && max_replication_apply_lag > 0 && myFlushLsn > applyPtr + max_replication_apply_lag * MB)) + if ((flushPtr != InvalidXLogRecPtr && tenant_max_replication_flush_lag > 0 && myFlushLsn > flushPtr + tenant_max_replication_flush_lag * MB)) + { + return (myFlushLsn - flushPtr - tenant_max_replication_flush_lag * MB); + } + + if ((applyPtr != InvalidXLogRecPtr && tenant_max_replication_apply_lag > 0 && myFlushLsn > applyPtr + tenant_max_replication_apply_lag * MB)) + { + return (myFlushLsn - applyPtr - tenant_max_replication_apply_lag * MB); + } + } + else { - return (myFlushLsn - applyPtr - max_replication_apply_lag * MB); + if ((writePtr != InvalidXLogRecPtr && max_replication_write_lag > 0 && myFlushLsn > writePtr + max_replication_write_lag * MB)) + { + return (myFlushLsn - writePtr - max_replication_write_lag * MB); + } + + if ((flushPtr != InvalidXLogRecPtr && max_replication_flush_lag > 0 && myFlushLsn > flushPtr + max_replication_flush_lag * MB)) + { + return (myFlushLsn - flushPtr - max_replication_flush_lag * MB); + } + + if ((applyPtr != InvalidXLogRecPtr && max_replication_apply_lag > 0 && myFlushLsn > applyPtr + max_replication_apply_lag * MB)) + { + return (myFlushLsn - applyPtr - max_replication_apply_lag * MB); + } } } - - /* BEGIN_HADRON */ - if (databricks_max_wal_mb_per_second == -1) { - return 0; - } - - state = GetWalpropShmemState(); - if (state != NULL && !!pg_atomic_read_u32(&state->wal_rate_limiter.should_limit)) - { - TimestampTz now = GetCurrentTimestamp(); - struct WalRateLimiter *limiter = &state->wal_rate_limiter; - uint64 last_recorded_time = pg_atomic_read_u64(&limiter->last_recorded_time_us); - if (now - last_recorded_time > USECS_PER_SEC) - { - /* - * The backend has past 1 second since the last recorded time and it's time to push more WALs. - * If the backends are pushing WALs too fast, the wal proposer will rate limit them again. 
- */ - uint32 expected = true; - pg_atomic_compare_exchange_u32(&state->wal_rate_limiter.should_limit, &expected, false); - } - return 1; - } - /* END_HADRON */ return 0; } @@ -479,9 +563,9 @@ startup_backpressure_wrap(void) if (AmStartupProcess() || !IsUnderPostmaster) return 0; - delay_backend_us = &backpressure_lag_impl; + delay_backend_us = &hadron_backpressure_lag_impl; - return backpressure_lag_impl(); + return hadron_backpressure_lag_impl(); } /* @@ -511,8 +595,10 @@ WalproposerShmemInit(void) pg_atomic_init_u64(&walprop_shared->backpressureThrottlingTime, 0); pg_atomic_init_u64(&walprop_shared->currentClusterSize, 0); /* BEGIN_HADRON */ + pg_atomic_init_u32(&walprop_shared->wal_rate_limiter.effective_max_wal_bytes_per_second, -1); pg_atomic_init_u32(&walprop_shared->wal_rate_limiter.should_limit, 0); - pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.last_recorded_time_us, 0); + pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.batch_start_time_us, 0); + pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.batch_end_time_us, 0); /* END_HADRON */ } } @@ -527,8 +613,10 @@ WalproposerShmemInit_SyncSafekeeper(void) pg_atomic_init_u64(&walprop_shared->mineLastElectedTerm, 0); pg_atomic_init_u64(&walprop_shared->backpressureThrottlingTime, 0); /* BEGIN_HADRON */ + pg_atomic_init_u32(&walprop_shared->wal_rate_limiter.effective_max_wal_bytes_per_second, -1); pg_atomic_init_u32(&walprop_shared->wal_rate_limiter.should_limit, 0); - pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.last_recorded_time_us, 0); + pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.batch_start_time_us, 0); + pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.batch_end_time_us, 0); /* END_HADRON */ } @@ -560,7 +648,7 @@ backpressure_throttling_impl(void) return retry; /* Calculate replicas lag */ - lag = backpressure_lag_impl(); + lag = hadron_backpressure_lag_impl(); if (lag == 0) return retry; @@ -646,18 +734,24 @@ walprop_pg_get_shmem_state(WalProposer *wp) * Record new ps_feedback in the array with shards and update min_feedback. */ static PageserverFeedback -record_pageserver_feedback(PageserverFeedback *ps_feedback) +record_pageserver_feedback(PageserverFeedback *ps_feedback, shardno_t num_shards) { PageserverFeedback min_feedback; Assert(ps_feedback->present); Assert(ps_feedback->shard_number < MAX_SHARDS); + Assert(ps_feedback->shard_number < num_shards); + + // Begin Hadron: Record any corruption signal from the pageserver first. + if (ps_feedback->corruption_detected) { + pg_atomic_write_u32(&databricks_metrics_shared->ps_corruption_detected, 1); + } SpinLockAcquire(&walprop_shared->mutex); - /* Update the number of shards */ - if (ps_feedback->shard_number + 1 > walprop_shared->num_shards) - walprop_shared->num_shards = ps_feedback->shard_number + 1; + // Hadron: Update the num_shards from the source-of-truth (shard map) lazily when we receive + // a new pageserver feedback. 
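/*
 * Editorial aside (not part of the patch): in the lakebase_mode branch above,
 * each replication-lag GUC is treated as a per-shard budget, so the effective
 * tenant-wide threshold is num_shards * limit_mb * MB, with at least one shard
 * assumed. A minimal stand-alone restatement, using illustrative names only:
 */
#include <stdint.h>

#define SKETCH_MB ((uint64_t) 1024 * 1024)

/* Bytes by which my_flush_lsn exceeds the allowed window, or 0 if within it. */
static uint64_t
tenant_lag_over_limit(uint64_t my_flush_lsn, uint64_t remote_ptr,
                      uint64_t per_shard_limit_mb, uint32_t num_shards)
{
    uint64_t limit_bytes;

    if (remote_ptr == 0 || per_shard_limit_mb == 0)
        return 0;                        /* no feedback yet, or limit disabled */
    if (num_shards == 0)
        num_shards = 1;                  /* assume at least one shard */

    limit_bytes = (uint64_t) num_shards * per_shard_limit_mb * SKETCH_MB;
    return my_flush_lsn > remote_ptr + limit_bytes
               ? my_flush_lsn - remote_ptr - limit_bytes
               : 0;
}
/* Example: 3 shards with a 500 MB write-lag limit tolerate up to 1500 MB
 * between the local flush LSN and the slowest shard's reported write LSN. */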
+ walprop_shared->num_shards = Max(walprop_shared->num_shards, num_shards); /* Update the feedback */ memcpy(&walprop_shared->shard_ps_feedback[ps_feedback->shard_number], ps_feedback, sizeof(PageserverFeedback)); @@ -1475,6 +1569,7 @@ XLogBroadcastWalProposer(WalProposer *wp) XLogRecPtr endptr; struct WalproposerShmemState *state = NULL; TimestampTz now = 0; + int effective_max_wal_bytes_per_second = 0; /* Start from the last sent position */ startptr = sentPtr; @@ -1529,22 +1624,36 @@ XLogBroadcastWalProposer(WalProposer *wp) /* BEGIN_HADRON */ state = GetWalpropShmemState(); - if (databricks_max_wal_mb_per_second != -1 && state != NULL) + effective_max_wal_bytes_per_second = pg_atomic_read_u32(&state->wal_rate_limiter.effective_max_wal_bytes_per_second); + if (effective_max_wal_bytes_per_second != -1 && state != NULL) { - uint64 max_wal_bytes = (uint64) databricks_max_wal_mb_per_second * 1024 * 1024; struct WalRateLimiter *limiter = &state->wal_rate_limiter; - uint64 last_recorded_time = pg_atomic_read_u64(&limiter->last_recorded_time_us); - if (now - last_recorded_time > USECS_PER_SEC) + uint64 batch_end_time = pg_atomic_read_u64(&limiter->batch_end_time_us); + if ( now >= batch_end_time ) { - /* Reset the rate limiter */ + // Reset the rate limiter to start a new batch limiter->sent_bytes = 0; - pg_atomic_write_u64(&limiter->last_recorded_time_us, now); pg_atomic_write_u32(&limiter->should_limit, false); + pg_atomic_write_u64(&limiter->batch_start_time_us, now); + /* tentatively assign the batch end time as 1s from now. This could result in one of the following cases: + 1. If sent_bytes does not reach effective_max_wal_bytes_per_second in 1s, + then we will reset the current batch and clear sent_bytes. No throttling happens. + 2. Otherwise, we will recompute the end time (below) based on how many bytes are actually written, + and throttle PG until the batch end time. */ + pg_atomic_write_u64(&limiter->batch_end_time_us, now + USECS_PER_SEC); } limiter->sent_bytes += (endptr - startptr); - if (limiter->sent_bytes > max_wal_bytes) + if (limiter->sent_bytes > effective_max_wal_bytes_per_second) { + uint64_t batch_start_time = pg_atomic_read_u64(&limiter->batch_start_time_us); + uint64 throttle_usecs = USECS_PER_SEC * limiter->sent_bytes / Max(effective_max_wal_bytes_per_second, 1); + if (throttle_usecs > kRateLimitMaxBatchUSecs){ + elog(LOG, "throttle_usecs %lu is too large, limiting to %lu", throttle_usecs, kRateLimitMaxBatchUSecs); + throttle_usecs = kRateLimitMaxBatchUSecs; + } + pg_atomic_write_u32(&limiter->should_limit, true); + pg_atomic_write_u64(&limiter->batch_end_time_us, batch_start_time + throttle_usecs); } } /* END_HADRON */ @@ -2023,19 +2132,43 @@ walprop_pg_process_safekeeper_feedback(WalProposer *wp, Safekeeper *sk) if (wp->config->syncSafekeepers) return; + /* handle fresh ps_feedback */ if (sk->appendResponse.ps_feedback.present) { - PageserverFeedback min_feedback = record_pageserver_feedback(&sk->appendResponse.ps_feedback); + shardno_t num_shards = get_num_shards(); - /* Only one main shard sends non-zero currentClusterSize */ - if (sk->appendResponse.ps_feedback.currentClusterSize > 0) - SetNeonCurrentClusterSize(sk->appendResponse.ps_feedback.currentClusterSize); - - if (min_feedback.disk_consistent_lsn != standby_apply_lsn) + // During shard split, we receive ps_feedback from child shards before + // the split commits and our shard map GUC has been updated. We must + // filter out such feedback here because record_pageserver_feedback() + // doesn't do it. 
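/*
 * Editorial aside (not part of the patch): the proposer-side hunk above sizes
 * the throttle window from the bytes actually sent in the current batch:
 * window_us = USECS_PER_SEC * sent_bytes / max_bytes_per_second, capped by
 * kRateLimitMaxBatchUSecs. A tiny restatement with illustrative names:
 */
#include <stdint.h>

#define SKETCH_USECS_PER_SEC 1000000ULL

static uint64_t
batch_end_time(uint64_t batch_start_us, uint64_t sent_bytes,
               uint64_t max_bytes_per_sec, uint64_t max_batch_us)
{
    /* time the sent bytes "should" have taken at the configured rate */
    uint64_t window_us = SKETCH_USECS_PER_SEC * sent_bytes /
                         (max_bytes_per_sec ? max_bytes_per_sec : 1);

    if (window_us > max_batch_us)
        window_us = max_batch_us;        /* bound the stall for pathological bursts */
    return batch_start_us + window_us;
}
/* Example: with a 64 MiB/s budget, a batch that sent 96 MiB ends at
 * start + 1.5 s, so backends remain throttled for the rest of that window. */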
+ // + // NB: what we would actually want to happen is that we only receive + // ps_feedback from the parent shards when the split is committed, then + // apply the split to our set of tracked feedback and from here on only + // receive ps_feedback from child shards. This filter condition doesn't + // do that: if we split from N parent to 2N child shards, the first N + // child shards' feedback messages will pass this condition, even before + // the split is committed. That's a bit sloppy, but OK for now. + if (sk->appendResponse.ps_feedback.shard_number < num_shards) { - standby_apply_lsn = min_feedback.disk_consistent_lsn; - needToAdvanceSlot = true; + PageserverFeedback min_feedback = record_pageserver_feedback(&sk->appendResponse.ps_feedback, num_shards); + + /* Only one main shard sends non-zero currentClusterSize */ + if (sk->appendResponse.ps_feedback.currentClusterSize > 0) + SetNeonCurrentClusterSize(sk->appendResponse.ps_feedback.currentClusterSize); + + if (min_feedback.disk_consistent_lsn != standby_apply_lsn) + { + standby_apply_lsn = min_feedback.disk_consistent_lsn; + needToAdvanceSlot = true; + } + } + else + { + // HADRON + elog(DEBUG2, "Ignoring pageserver feedback for unknown shard %d (current shard number %d)", + sk->appendResponse.ps_feedback.shard_number, num_shards); } } @@ -2128,6 +2261,27 @@ GetNeonCurrentClusterSize(void) } uint64 GetNeonCurrentClusterSize(void); +/* BEGIN_HADRON */ +static void +walprop_pg_reset_safekeeper_statuses_for_metrics(WalProposer *wp, uint32 num_safekeepers) +{ + WalproposerShmemState* shmem = wp->api.get_shmem_state(wp); + SpinLockAcquire(&shmem->mutex); + shmem->num_safekeepers = num_safekeepers; + memset(shmem->safekeeper_status, 0, sizeof(shmem->safekeeper_status)); + SpinLockRelease(&shmem->mutex); +} + +static void +walprop_pg_update_safekeeper_status_for_metrics(WalProposer *wp, uint32 sk_index, uint8 status) +{ + WalproposerShmemState* shmem = wp->api.get_shmem_state(wp); + Assert(sk_index < MAX_SAFEKEEPERS); + SpinLockAcquire(&shmem->mutex); + shmem->safekeeper_status[sk_index] = status; + SpinLockRelease(&shmem->mutex); +} +/* END_HADRON */ static const walproposer_api walprop_pg = { .get_shmem_state = walprop_pg_get_shmem_state, @@ -2161,4 +2315,6 @@ static const walproposer_api walprop_pg = { .finish_sync_safekeepers = walprop_pg_finish_sync_safekeepers, .process_safekeeper_feedback = walprop_pg_process_safekeeper_feedback, .log_internal = walprop_pg_log_internal, + .reset_safekeeper_statuses_for_metrics = walprop_pg_reset_safekeeper_statuses_for_metrics, + .update_safekeeper_status_for_metrics = walprop_pg_update_safekeeper_status_for_metrics, }; diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 3c3f93c8e3..0ece79c329 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -33,7 +33,6 @@ env_logger.workspace = true framed-websockets.workspace = true futures.workspace = true hashbrown.workspace = true -hashlink.workspace = true hex.workspace = true hmac.workspace = true hostname.workspace = true @@ -54,6 +53,7 @@ json = { path = "../libs/proxy/json" } lasso = { workspace = true, features = ["multi-threaded"] } measured = { workspace = true, features = ["lasso"] } metrics.workspace = true +moka.workspace = true once_cell.workspace = true opentelemetry = { workspace = true, features = ["trace"] } papaya = "0.2.0" @@ -107,10 +107,11 @@ uuid.workspace = true x509-cert.workspace = true redis.workspace = true zerocopy.workspace = true +zeroize.workspace = true # uncomment this to use the real subzero-core crate # subzero-core 
= { git = "https://github.com/neondatabase/subzero", rev = "396264617e78e8be428682f87469bb25429af88a", features = ["postgresql"], optional = true } # this is a stub for the subzero-core crate -subzero-core = { path = "./subzero_core", features = ["postgresql"], optional = true} +subzero-core = { path = "../libs/proxy/subzero_core", features = ["postgresql"], optional = true} ouroboros = { version = "0.18", optional = true } # jwt stuff diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index b06ed3a0ae..2a02748a10 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -8,11 +8,12 @@ use tracing::{info, info_span}; use crate::auth::backend::ComputeUserInfo; use crate::cache::Cached; +use crate::cache::node_info::CachedNodeInfo; use crate::compute::AuthInfo; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::client::cplane_proxy_v1; -use crate::control_plane::{self, CachedNodeInfo, NodeInfo}; +use crate::control_plane::{self, NodeInfo}; use crate::error::{ReportableError, UserFacingError}; use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 1e5c076fb9..491f14b1b6 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -6,7 +6,7 @@ use crate::auth::{self, AuthFlow}; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::AuthSecret; -use crate::intern::EndpointIdInt; +use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::sasl; use crate::stream::{self, Stream}; @@ -25,13 +25,15 @@ pub(crate) async fn authenticate_cleartext( ctx.set_auth_method(crate::context::AuthMethod::Cleartext); let ep = EndpointIdInt::from(&info.endpoint); + let role = RoleNameInt::from(&info.user); let auth_flow = AuthFlow::new( client, auth::CleartextPassword { secret, endpoint: ep, - pool: config.thread_pool.clone(), + role, + pool: config.scram_thread_pool.clone(), }, ); let auth_outcome = { diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index e7805d8bfe..a6df2a7011 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -16,16 +16,16 @@ use tracing::{debug, info}; use crate::auth::{self, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange}; use crate::cache::Cached; +use crate::cache::node_info::CachedNodeInfo; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::client::ControlPlaneClient; use crate::control_plane::errors::GetAuthInfoError; use crate::control_plane::messages::EndpointRateLimitConfig; use crate::control_plane::{ - self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl, - RoleAccessControl, + self, AccessBlockerFlags, AuthSecret, ControlPlaneApi, EndpointAccessControl, RoleAccessControl, }; -use crate::intern::EndpointIdInt; +use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; use crate::proxy::wake_compute::WakeComputeBackend; @@ -273,9 +273,11 @@ async fn authenticate_with_secret( ) -> auth::Result { if let Some(password) = unauthenticated_password { let ep = EndpointIdInt::from(&info.endpoint); + let role = RoleNameInt::from(&info.user); let auth_outcome = - validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?; + 
validate_password_and_exchange(&config.scram_thread_pool, ep, role, &password, secret) + .await?; let keys = match auth_outcome { crate::sasl::Outcome::Success(key) => key, crate::sasl::Outcome::Failure(reason) => { @@ -433,11 +435,12 @@ mod tests { use super::auth_quirks; use super::jwt::JwkCache; use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern}; + use crate::cache::node_info::CachedNodeInfo; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::messages::EndpointRateLimitConfig; use crate::control_plane::{ - self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl, + self, AccessBlockerFlags, EndpointAccessControl, RoleAccessControl, }; use crate::proxy::NeonOptions; use crate::rate_limiter::EndpointRateLimiter; @@ -498,7 +501,7 @@ mod tests { static CONFIG: Lazy = Lazy::new(|| AuthenticationConfig { jwks_cache: JwkCache::default(), - thread_pool: ThreadPool::new(1), + scram_thread_pool: ThreadPool::new(1), scram_protocol_timeout: std::time::Duration::from_secs(5), ip_allowlist_check_enabled: true, is_vpc_acccess_proxy: false, diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index c825d5bf4b..00cd274e99 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -10,7 +10,7 @@ use super::backend::ComputeCredentialKeys; use super::{AuthError, PasswordHackPayload}; use crate::context::RequestContext; use crate::control_plane::AuthSecret; -use crate::intern::EndpointIdInt; +use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage}; use crate::sasl; use crate::scram::threadpool::ThreadPool; @@ -46,6 +46,7 @@ pub(crate) struct PasswordHack; pub(crate) struct CleartextPassword { pub(crate) pool: Arc, pub(crate) endpoint: EndpointIdInt, + pub(crate) role: RoleNameInt, pub(crate) secret: AuthSecret, } @@ -111,6 +112,7 @@ impl AuthFlow<'_, S, CleartextPassword> { let outcome = validate_password_and_exchange( &self.state.pool, self.state.endpoint, + self.state.role, password, self.state.secret, ) @@ -165,13 +167,15 @@ impl AuthFlow<'_, S, Scram<'_>> { pub(crate) async fn validate_password_and_exchange( pool: &ThreadPool, endpoint: EndpointIdInt, + role: RoleNameInt, password: &[u8], secret: AuthSecret, ) -> super::Result> { match secret { // perform scram authentication as both client and server to validate the keys AuthSecret::Scram(scram_secret) => { - let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?; + let outcome = + crate::scram::exchange(pool, endpoint, role, &scram_secret, password).await?; let client_key = match outcome { sasl::Outcome::Success(client_key) => client_key, diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs index 7b9012dc69..86b64c62c9 100644 --- a/proxy/src/binary/local_proxy.rs +++ b/proxy/src/binary/local_proxy.rs @@ -29,7 +29,7 @@ use crate::config::{ }; use crate::control_plane::locks::ApiLocks; use crate::http::health_server::AppMetrics; -use crate::metrics::{Metrics, ThreadPoolMetrics}; +use crate::metrics::{Metrics, ServiceInfo}; use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo}; use crate::scram::threadpool::ThreadPool; use crate::serverless::cancel_set::CancelSet; @@ -114,8 +114,6 @@ pub async fn run() -> anyhow::Result<()> { let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook(); let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]); - 
Metrics::install(Arc::new(ThreadPoolMetrics::new(0))); - // TODO: refactor these to use labels debug!("Version: {GIT_VERSION}"); debug!("Build_tag: {BUILD_TAG}"); @@ -207,6 +205,11 @@ pub async fn run() -> anyhow::Result<()> { endpoint_rate_limiter, ); + Metrics::get() + .service + .info + .set_label(ServiceInfo::running()); + match futures::future::select(pin!(maintenance_tasks.join_next()), pin!(task)).await { // exit immediately on maintenance task completion Either::Left((Some(res), _)) => match crate::error::flatten_err(res)? {}, @@ -279,7 +282,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig http_config, authentication_config: AuthenticationConfig { jwks_cache: JwkCache::default(), - thread_pool: ThreadPool::new(0), + scram_thread_pool: ThreadPool::new(0), scram_protocol_timeout: Duration::from_secs(10), ip_allowlist_check_enabled: true, is_vpc_acccess_proxy: false, diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs index f3782312dc..cdbf0f09ac 100644 --- a/proxy/src/binary/pg_sni_router.rs +++ b/proxy/src/binary/pg_sni_router.rs @@ -26,7 +26,7 @@ use utils::project_git_version; use utils::sentry_init::init_sentry; use crate::context::RequestContext; -use crate::metrics::{Metrics, ThreadPoolMetrics}; +use crate::metrics::{Metrics, ServiceInfo}; use crate::pglb::TlsRequired; use crate::pqproto::FeStartupPacket; use crate::protocol2::ConnectionInfo; @@ -80,8 +80,6 @@ pub async fn run() -> anyhow::Result<()> { let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook(); let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]); - Metrics::install(Arc::new(ThreadPoolMetrics::new(0))); - let args = cli().get_matches(); let destination: String = args .get_one::("dest") @@ -135,6 +133,12 @@ pub async fn run() -> anyhow::Result<()> { cancellation_token.clone(), )) .map(crate::error::flatten_err); + + Metrics::get() + .service + .info + .set_label(ServiceInfo::running()); + let signals_task = tokio::spawn(crate::signals::handle(cancellation_token, || {})); // the signal task cant ever succeed. diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 4148f4bc62..583cdc95bf 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -40,7 +40,7 @@ use crate::config::{ }; use crate::context::parquet::ParquetUploadArgs; use crate::http::health_server::AppMetrics; -use crate::metrics::Metrics; +use crate::metrics::{Metrics, ServiceInfo}; use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::redis::kv_ops::RedisKVClient; @@ -535,12 +535,7 @@ pub async fn run() -> anyhow::Result<()> { // add a task to flush the db_schema cache every 10 minutes #[cfg(feature = "rest_broker")] if let Some(db_schema_cache) = &config.rest_config.db_schema_cache { - maintenance_tasks.spawn(async move { - loop { - tokio::time::sleep(Duration::from_secs(600)).await; - db_schema_cache.flush(); - } - }); + maintenance_tasks.spawn(db_schema_cache.maintain()); } if let Some(metrics_config) = &config.metric_collection { @@ -590,6 +585,11 @@ pub async fn run() -> anyhow::Result<()> { } } + Metrics::get() + .service + .info + .set_label(ServiceInfo::running()); + let maintenance = loop { // get one complete task match futures::future::select( @@ -617,7 +617,12 @@ pub async fn run() -> anyhow::Result<()> { /// ProxyConfig is created at proxy startup, and lives forever. 
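// Editorial note (not part of the patch): the hunks above share one pattern --
// the per-binary Metrics::install(...) calls go away, and instead each binary
// flips a ServiceInfo::running() label on the shared metrics only after its
// listeners and maintenance tasks are wired up, which reads as a readiness
// signal for metric scrapes. Similarly, the hand-rolled 10-minute flush loop
// for the db_schema_cache is replaced by spawning its maintain() future onto
// the existing maintenance task set.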
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { let thread_pool = ThreadPool::new(args.scram_thread_pool_size); - Metrics::install(thread_pool.metrics.clone()); + Metrics::get() + .proxy + .scram_pool + .0 + .set(thread_pool.metrics.clone()) + .ok(); let tls_config = match (&args.tls_key, &args.tls_cert) { (Some(key_path), Some(cert_path)) => Some(config::configure_tls( @@ -690,12 +695,15 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { }; let authentication_config = AuthenticationConfig { jwks_cache: JwkCache::default(), - thread_pool, + scram_thread_pool: thread_pool, scram_protocol_timeout: args.scram_protocol_timeout, ip_allowlist_check_enabled: !args.is_private_access_proxy, is_vpc_acccess_proxy: args.is_private_access_proxy, is_auth_broker: args.is_auth_broker, + #[cfg(not(feature = "rest_broker"))] accept_jwts: args.is_auth_broker, + #[cfg(feature = "rest_broker")] + accept_jwts: args.is_auth_broker || args.is_rest_broker, console_redirect_confirmation_timeout: args.webauth_confirmation_timeout, }; @@ -711,12 +719,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { info!("Using DbSchemaCache with options={db_schema_cache_config:?}"); let db_schema_cache = if args.is_rest_broker { - Some(DbSchemaCache::new( - "db_schema_cache", - db_schema_cache_config.size, - db_schema_cache_config.ttl, - true, - )) + Some(DbSchemaCache::new(db_schema_cache_config)) } else { None }; diff --git a/proxy/src/cache/common.rs b/proxy/src/cache/common.rs index b5caf94788..9a7d0d99cf 100644 --- a/proxy/src/cache/common.rs +++ b/proxy/src/cache/common.rs @@ -1,4 +1,16 @@ use std::ops::{Deref, DerefMut}; +use std::time::{Duration, Instant}; + +use moka::Expiry; +use moka::notification::RemovalCause; + +use crate::control_plane::messages::ControlPlaneErrorMessage; +use crate::metrics::{ + CacheEviction, CacheKind, CacheOutcome, CacheOutcomeGroup, CacheRemovalCause, Metrics, +}; + +/// Default TTL used when caching errors from control plane. +pub const DEFAULT_ERROR_TTL: Duration = Duration::from_secs(30); /// A generic trait which exposes types of cache's key and value, /// as well as the notion of cache entry invalidation. @@ -10,20 +22,16 @@ pub(crate) trait Cache { /// Entry's value. type Value; - /// Used for entry invalidation. - type LookupInfo; - /// Invalidate an entry using a lookup info. /// We don't have an empty default impl because it's error-prone. - fn invalidate(&self, _: &Self::LookupInfo); + fn invalidate(&self, _: &Self::Key); } impl Cache for &C { type Key = C::Key; type Value = C::Value; - type LookupInfo = C::LookupInfo; - fn invalidate(&self, info: &Self::LookupInfo) { + fn invalidate(&self, info: &Self::Key) { C::invalidate(self, info); } } @@ -31,7 +39,7 @@ impl Cache for &C { /// Wrapper for convenient entry invalidation. pub(crate) struct Cached::Value> { /// Cache + lookup info. - pub(crate) token: Option<(C, C::LookupInfo)>, + pub(crate) token: Option<(C, C::Key)>, /// The value itself. pub(crate) value: V, @@ -43,23 +51,6 @@ impl Cached { Self { token: None, value } } - pub(crate) fn take_value(self) -> (Cached, V) { - ( - Cached { - token: self.token, - value: (), - }, - self.value, - ) - } - - pub(crate) fn map(self, f: impl FnOnce(V) -> U) -> Cached { - Cached { - token: self.token, - value: f(self.value), - } - } - /// Drop this entry from a cache if it's still there. 
pub(crate) fn invalidate(self) -> V { if let Some((cache, info)) = &self.token { @@ -87,3 +78,91 @@ impl DerefMut for Cached { &mut self.value } } + +pub type ControlPlaneResult = Result>; + +#[derive(Clone, Copy)] +pub struct CplaneExpiry { + pub error: Duration, +} + +impl Default for CplaneExpiry { + fn default() -> Self { + Self { + error: DEFAULT_ERROR_TTL, + } + } +} + +impl CplaneExpiry { + pub fn expire_early( + &self, + value: &ControlPlaneResult, + updated: Instant, + ) -> Option { + match value { + Ok(_) => None, + Err(err) => Some(self.expire_err_early(err, updated)), + } + } + + pub fn expire_err_early(&self, err: &ControlPlaneErrorMessage, updated: Instant) -> Duration { + err.status + .as_ref() + .and_then(|s| s.details.retry_info.as_ref()) + .map_or(self.error, |r| r.retry_at.into_std() - updated) + } +} + +impl Expiry> for CplaneExpiry { + fn expire_after_create( + &self, + _key: &K, + value: &ControlPlaneResult, + created_at: Instant, + ) -> Option { + self.expire_early(value, created_at) + } + + fn expire_after_update( + &self, + _key: &K, + value: &ControlPlaneResult, + updated_at: Instant, + _duration_until_expiry: Option, + ) -> Option { + self.expire_early(value, updated_at) + } +} + +pub fn eviction_listener(kind: CacheKind, cause: RemovalCause) { + let cause = match cause { + RemovalCause::Expired => CacheRemovalCause::Expired, + RemovalCause::Explicit => CacheRemovalCause::Explicit, + RemovalCause::Replaced => CacheRemovalCause::Replaced, + RemovalCause::Size => CacheRemovalCause::Size, + }; + Metrics::get() + .cache + .evicted_total + .inc(CacheEviction { cache: kind, cause }); +} + +#[inline] +pub fn count_cache_outcome(kind: CacheKind, cache_result: Option) -> Option { + let outcome = if cache_result.is_some() { + CacheOutcome::Hit + } else { + CacheOutcome::Miss + }; + Metrics::get().cache.request_total.inc(CacheOutcomeGroup { + cache: kind, + outcome, + }); + cache_result +} + +#[inline] +pub fn count_cache_insert(kind: CacheKind) { + Metrics::get().cache.inserted_total.inc(kind); +} diff --git a/proxy/src/cache/mod.rs b/proxy/src/cache/mod.rs index ce7f781213..0a607a1409 100644 --- a/proxy/src/cache/mod.rs +++ b/proxy/src/cache/mod.rs @@ -1,6 +1,5 @@ pub(crate) mod common; +pub(crate) mod node_info; pub(crate) mod project_info; -mod timed_lru; -pub(crate) use common::{Cache, Cached}; -pub(crate) use timed_lru::TimedLru; +pub(crate) use common::{Cached, ControlPlaneResult, CplaneExpiry}; diff --git a/proxy/src/cache/node_info.rs b/proxy/src/cache/node_info.rs new file mode 100644 index 0000000000..47fc7a5b08 --- /dev/null +++ b/proxy/src/cache/node_info.rs @@ -0,0 +1,60 @@ +use crate::cache::common::{Cache, count_cache_insert, count_cache_outcome, eviction_listener}; +use crate::cache::{Cached, ControlPlaneResult, CplaneExpiry}; +use crate::config::CacheOptions; +use crate::control_plane::NodeInfo; +use crate::metrics::{CacheKind, Metrics}; +use crate::types::EndpointCacheKey; + +pub(crate) struct NodeInfoCache(moka::sync::Cache>); +pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; + +impl Cache for NodeInfoCache { + type Key = EndpointCacheKey; + type Value = ControlPlaneResult; + + fn invalidate(&self, info: &EndpointCacheKey) { + self.0.invalidate(info); + } +} + +impl NodeInfoCache { + pub fn new(config: CacheOptions) -> Self { + let builder = moka::sync::Cache::builder() + .name("node_info") + .expire_after(CplaneExpiry::default()); + let builder = config.moka(builder); + + if let Some(size) = config.size { + Metrics::get() + 
.cache + .capacity + .set(CacheKind::NodeInfo, size as i64); + } + + let builder = builder + .eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::NodeInfo, cause)); + + Self(builder.build()) + } + + pub fn insert(&self, key: EndpointCacheKey, value: ControlPlaneResult) { + count_cache_insert(CacheKind::NodeInfo); + self.0.insert(key, value); + } + + pub fn get(&self, key: &EndpointCacheKey) -> Option> { + count_cache_outcome(CacheKind::NodeInfo, self.0.get(key)) + } + + pub fn get_entry( + &'static self, + key: &EndpointCacheKey, + ) -> Option> { + self.get(key).map(|res| { + res.map(|value| Cached { + token: Some((self, key.clone())), + value, + }) + }) + } +} diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index a589dd175b..f8a38be287 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -1,84 +1,20 @@ -use std::collections::{HashMap, HashSet, hash_map}; +use std::collections::HashSet; use std::convert::Infallible; -use std::time::Duration; -use async_trait::async_trait; use clashmap::ClashMap; -use clashmap::mapref::one::Ref; -use rand::Rng; -use tokio::time::Instant; +use moka::sync::Cache; use tracing::{debug, info}; +use crate::cache::common::{ + ControlPlaneResult, CplaneExpiry, count_cache_insert, count_cache_outcome, eviction_listener, +}; use crate::config::ProjectInfoCacheOptions; use crate::control_plane::messages::{ControlPlaneErrorMessage, Reason}; use crate::control_plane::{EndpointAccessControl, RoleAccessControl}; use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; +use crate::metrics::{CacheKind, Metrics}; use crate::types::{EndpointId, RoleName}; -#[async_trait] -pub(crate) trait ProjectInfoCache { - fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt); - fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt); - fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt); - fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt); -} - -struct Entry { - expires_at: Instant, - value: T, -} - -impl Entry { - pub(crate) fn new(value: T, ttl: Duration) -> Self { - Self { - expires_at: Instant::now() + ttl, - value, - } - } - - pub(crate) fn get(&self) -> Option<&T> { - (!self.is_expired()).then_some(&self.value) - } - - fn is_expired(&self) -> bool { - self.expires_at <= Instant::now() - } -} - -struct EndpointInfo { - role_controls: HashMap>>, - controls: Option>>, -} - -type ControlPlaneResult = Result>; - -impl EndpointInfo { - pub(crate) fn get_role_secret_with_ttl( - &self, - role_name: RoleNameInt, - ) -> Option<(ControlPlaneResult, Duration)> { - let entry = self.role_controls.get(&role_name)?; - let ttl = entry.expires_at - Instant::now(); - Some((entry.get()?.clone(), ttl)) - } - - pub(crate) fn get_controls_with_ttl( - &self, - ) -> Option<(ControlPlaneResult, Duration)> { - let entry = self.controls.as_ref()?; - let ttl = entry.expires_at - Instant::now(); - Some((entry.get()?.clone(), ttl)) - } - - pub(crate) fn invalidate_endpoint(&mut self) { - self.controls = None; - } - - pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) { - self.role_controls.remove(&role_name); - } -} - /// Cache for project info. /// This is used to cache auth data for endpoints. /// Invalidation is done by console notifications or by TTL (if console notifications are disabled). 
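// Editorial aside (not part of the patch): the NodeInfoCache above layers a
// per-entry expiry policy over moka (the crate this PR adds) so that cached
// control-plane *errors* die quickly while successful lookups live out the
// normal TTL/TTI. A minimal, self-contained version of that idea; the names
// below are illustrative only:
use std::time::{Duration, Instant};

struct ShortErrTtl;

impl moka::Expiry<String, Result<String, String>> for ShortErrTtl {
    fn expire_after_create(
        &self,
        _key: &String,
        value: &Result<String, String>,
        _created_at: Instant,
    ) -> Option<Duration> {
        // errors expire after 30s; Ok entries fall back to the cache-wide policy
        value.is_err().then(|| Duration::from_secs(30))
    }
}

fn build_cache() -> moka::sync::Cache<String, Result<String, String>> {
    moka::sync::Cache::builder()
        .name("node_info_sketch")
        .max_capacity(4_000)
        .time_to_idle(Duration::from_secs(4 * 60))
        .expire_after(ShortErrTtl)
        .build()
}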
@@ -86,8 +22,9 @@ impl EndpointInfo { /// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data. /// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available? /// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache. -pub struct ProjectInfoCacheImpl { - cache: ClashMap, +pub struct ProjectInfoCache { + role_controls: Cache<(EndpointIdInt, RoleNameInt), ControlPlaneResult>, + ep_controls: Cache>, project2ep: ClashMap>, // FIXME(stefan): we need a way to GC the account2ep map. @@ -96,16 +33,13 @@ pub struct ProjectInfoCacheImpl { config: ProjectInfoCacheOptions, } -#[async_trait] -impl ProjectInfoCache for ProjectInfoCacheImpl { - fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) { +impl ProjectInfoCache { + pub fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) { info!("invalidating endpoint access for `{endpoint_id}`"); - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_endpoint(); - } + self.ep_controls.invalidate(&endpoint_id); } - fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) { + pub fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) { info!("invalidating endpoint access for project `{project_id}`"); let endpoints = self .project2ep @@ -113,13 +47,11 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { .map(|kv| kv.value().clone()) .unwrap_or_default(); for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_endpoint(); - } + self.ep_controls.invalidate(&endpoint_id); } } - fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) { + pub fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) { info!("invalidating endpoint access for org `{account_id}`"); let endpoints = self .account2ep @@ -127,13 +59,15 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { .map(|kv| kv.value().clone()) .unwrap_or_default(); for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_endpoint(); - } + self.ep_controls.invalidate(&endpoint_id); } } - fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) { + pub fn invalidate_role_secret_for_project( + &self, + project_id: ProjectIdInt, + role_name: RoleNameInt, + ) { info!( "invalidating role secret for project_id `{}` and role_name `{}`", project_id, role_name, @@ -144,47 +78,73 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { .map(|kv| kv.value().clone()) .unwrap_or_default(); for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_role_secret(role_name); - } + self.role_controls.invalidate(&(endpoint_id, role_name)); } } } -impl ProjectInfoCacheImpl { +impl ProjectInfoCache { pub(crate) fn new(config: ProjectInfoCacheOptions) -> Self { + Metrics::get().cache.capacity.set( + CacheKind::ProjectInfoRoles, + (config.size * config.max_roles) as i64, + ); + Metrics::get() + .cache + .capacity + .set(CacheKind::ProjectInfoEndpoints, config.size as i64); + + // we cache errors for 30 seconds, unless retry_at is set. 
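// Editorial aside (not part of the patch): splitting the old per-endpoint map
// into two moka caches changes the invalidation granularity -- role secrets are
// keyed by the composite (endpoint, role) pair, so a role-level invalidation
// drops exactly one entry, while endpoint/project/org invalidations walk the
// project2ep index and drop per-endpoint entries. Minimal illustration with
// simplified integer ids instead of the real interned id types:
fn invalidate_role(
    role_controls: &moka::sync::Cache<(u32, u32), String>,
    endpoint: u32,
    role: u32,
) {
    // moka's invalidate() takes a borrowed key; for a tuple key that means
    // borrowing the whole tuple, exactly as the patch does above.
    role_controls.invalidate(&(endpoint, role));
}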
+ let expiry = CplaneExpiry::default(); Self { - cache: ClashMap::new(), + role_controls: Cache::builder() + .name("project_info_roles") + .eviction_listener(|_k, _v, cause| { + eviction_listener(CacheKind::ProjectInfoRoles, cause); + }) + .max_capacity(config.size * config.max_roles) + .time_to_live(config.ttl) + .expire_after(expiry) + .build(), + ep_controls: Cache::builder() + .name("project_info_endpoints") + .eviction_listener(|_k, _v, cause| { + eviction_listener(CacheKind::ProjectInfoEndpoints, cause); + }) + .max_capacity(config.size) + .time_to_live(config.ttl) + .expire_after(expiry) + .build(), project2ep: ClashMap::new(), account2ep: ClashMap::new(), config, } } - fn get_endpoint_cache( - &self, - endpoint_id: &EndpointId, - ) -> Option> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; - self.cache.get(&endpoint_id) - } - - pub(crate) fn get_role_secret_with_ttl( + pub(crate) fn get_role_secret( &self, endpoint_id: &EndpointId, role_name: &RoleName, - ) -> Option<(ControlPlaneResult, Duration)> { + ) -> Option> { + let endpoint_id = EndpointIdInt::get(endpoint_id)?; let role_name = RoleNameInt::get(role_name)?; - let endpoint_info = self.get_endpoint_cache(endpoint_id)?; - endpoint_info.get_role_secret_with_ttl(role_name) + + count_cache_outcome( + CacheKind::ProjectInfoRoles, + self.role_controls.get(&(endpoint_id, role_name)), + ) } - pub(crate) fn get_endpoint_access_with_ttl( + pub(crate) fn get_endpoint_access( &self, endpoint_id: &EndpointId, - ) -> Option<(ControlPlaneResult, Duration)> { - let endpoint_info = self.get_endpoint_cache(endpoint_id)?; - endpoint_info.get_controls_with_ttl() + ) -> Option> { + let endpoint_id = EndpointIdInt::get(endpoint_id)?; + + count_cache_outcome( + CacheKind::ProjectInfoEndpoints, + self.ep_controls.get(&endpoint_id), + ) } pub(crate) fn insert_endpoint_access( @@ -203,34 +163,17 @@ impl ProjectInfoCacheImpl { self.insert_project2endpoint(project_id, endpoint_id); } - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. - return; - } - debug!( key = &*endpoint_id, "created a cache entry for endpoint access" ); - let controls = Some(Entry::new(Ok(controls), self.config.ttl)); - let role_controls = Entry::new(Ok(role_controls), self.config.ttl); + count_cache_insert(CacheKind::ProjectInfoEndpoints); + count_cache_insert(CacheKind::ProjectInfoRoles); - match self.cache.entry(endpoint_id) { - clashmap::Entry::Vacant(e) => { - e.insert(EndpointInfo { - role_controls: HashMap::from_iter([(role_name, role_controls)]), - controls, - }); - } - clashmap::Entry::Occupied(mut e) => { - let ep = e.get_mut(); - ep.controls = controls; - if ep.role_controls.len() < self.config.max_roles { - ep.role_controls.insert(role_name, role_controls); - } - } - } + self.ep_controls.insert(endpoint_id, Ok(controls)); + self.role_controls + .insert((endpoint_id, role_name), Ok(role_controls)); } pub(crate) fn insert_endpoint_access_err( @@ -238,55 +181,34 @@ impl ProjectInfoCacheImpl { endpoint_id: EndpointIdInt, role_name: RoleNameInt, msg: Box, - ttl: Option, ) { - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. - return; - } - debug!( key = &*endpoint_id, "created a cache entry for an endpoint access error" ); - let ttl = ttl.unwrap_or(self.config.ttl); - - let controls = if msg.get_reason() == Reason::RoleProtected { - // RoleProtected is the only role-specific error that control plane can give us. 
- // If a given role name does not exist, it still returns a successful response, - // just with an empty secret. - None - } else { - // We can cache all the other errors in EndpointInfo.controls, - // because they don't depend on what role name we pass to control plane. - Some(Entry::new(Err(msg.clone()), ttl)) - }; - - let role_controls = Entry::new(Err(msg), ttl); - - match self.cache.entry(endpoint_id) { - clashmap::Entry::Vacant(e) => { - e.insert(EndpointInfo { - role_controls: HashMap::from_iter([(role_name, role_controls)]), - controls, + // RoleProtected is the only role-specific error that control plane can give us. + // If a given role name does not exist, it still returns a successful response, + // just with an empty secret. + if msg.get_reason() != Reason::RoleProtected { + // We can cache all the other errors in ep_controls because they don't + // depend on what role name we pass to control plane. + self.ep_controls + .entry(endpoint_id) + .and_compute_with(|entry| match entry { + // leave the entry alone if it's already Ok + Some(entry) if entry.value().is_ok() => moka::ops::compute::Op::Nop, + // replace the entry + _ => { + count_cache_insert(CacheKind::ProjectInfoEndpoints); + moka::ops::compute::Op::Put(Err(msg.clone())) + } }); - } - clashmap::Entry::Occupied(mut e) => { - let ep = e.get_mut(); - if let Some(entry) = &ep.controls - && !entry.is_expired() - && entry.value.is_ok() - { - // If we have cached non-expired, non-error controls, keep them. - } else { - ep.controls = controls; - } - if ep.role_controls.len() < self.config.max_roles { - ep.role_controls.insert(role_name, role_controls); - } - } } + + count_cache_insert(CacheKind::ProjectInfoRoles); + self.role_controls + .insert((endpoint_id, role_name), Err(msg)); } fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) { @@ -307,73 +229,35 @@ impl ProjectInfoCacheImpl { } } - pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) { - let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else { - return; - }; - let Some(role_name) = RoleNameInt::get(role_name) else { - return; - }; - - let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) else { - return; - }; - - let entry = endpoint_info.role_controls.entry(role_name); - let hash_map::Entry::Occupied(role_controls) = entry else { - return; - }; - - if role_controls.get().is_expired() { - role_controls.remove(); - } + pub fn maybe_invalidate_role_secret(&self, _endpoint_id: &EndpointId, _role_name: &RoleName) { + // TODO: Expire the value early if the key is idle. + // Currently not an issue as we would just use the TTL to decide, which is what already happens. } pub async fn gc_worker(&self) -> anyhow::Result { - let mut interval = - tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32); + let mut interval = tokio::time::interval(self.config.gc_interval); loop { interval.tick().await; - if self.cache.len() < self.config.size { - // If there are not too many entries, wait until the next gc cycle. 
- continue; - } - self.gc(); + self.ep_controls.run_pending_tasks(); + self.role_controls.run_pending_tasks(); } } - - fn gc(&self) { - let shard = rand::rng().random_range(0..self.project2ep.shards().len()); - debug!(shard, "project_info_cache: performing epoch reclamation"); - - // acquire a random shard lock - let mut removed = 0; - let shard = self.project2ep.shards()[shard].write(); - for (_, endpoints) in shard.iter() { - for endpoint in endpoints { - self.cache.remove(endpoint); - removed += 1; - } - } - // We can drop this shard only after making sure that all endpoints are removed. - drop(shard); - info!("project_info_cache: removed {removed} endpoints"); - } } #[cfg(test)] mod tests { + use std::sync::Arc; + use std::time::Duration; + use super::*; use crate::control_plane::messages::{Details, EndpointRateLimitConfig, ErrorInfo, Status}; use crate::control_plane::{AccessBlockerFlags, AuthSecret}; use crate::scram::ServerSecret; - use std::sync::Arc; #[tokio::test] async fn test_project_info_cache_settings() { - tokio::time::pause(); - let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { - size: 2, + let cache = ProjectInfoCache::new(ProjectInfoCacheOptions { + size: 1, max_roles: 2, ttl: Duration::from_secs(1), gc_interval: Duration::from_secs(600), @@ -423,22 +307,17 @@ mod tests { }, ); - let (cached, ttl) = cache - .get_role_secret_with_ttl(&endpoint_id, &user1) - .unwrap(); + let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); assert_eq!(cached.unwrap().secret, secret1); - assert_eq!(ttl, cache.config.ttl); - let (cached, ttl) = cache - .get_role_secret_with_ttl(&endpoint_id, &user2) - .unwrap(); + let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); assert_eq!(cached.unwrap().secret, secret2); - assert_eq!(ttl, cache.config.ttl); // Shouldn't add more than 2 roles. 
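// Editorial aside (not part of the patch): moka applies inserts, evictions and
// expirations lazily, which is why the gc_worker above now just calls
// run_pending_tasks() on both caches, and why the tests around here flush
// pending work before asserting on entry_count(). Standalone snippet:
fn assert_two_entries(cache: &moka::sync::Cache<u32, u32>) {
    cache.insert(1, 1);
    cache.insert(2, 2);
    cache.run_pending_tasks(); // make the pending insert bookkeeping visible
    assert_eq!(cache.entry_count(), 2);
}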
let user3: RoleName = "user3".into(); let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32]))); + cache.role_controls.run_pending_tasks(); cache.insert_endpoint_access( account_id, project_id, @@ -455,31 +334,18 @@ mod tests { }, ); - assert!( - cache - .get_role_secret_with_ttl(&endpoint_id, &user3) - .is_none() - ); + cache.role_controls.run_pending_tasks(); + assert_eq!(cache.role_controls.entry_count(), 2); - let cached = cache - .get_endpoint_access_with_ttl(&endpoint_id) - .unwrap() - .0 - .unwrap(); - assert_eq!(cached.allowed_ips, allowed_ips); + tokio::time::sleep(Duration::from_secs(2)).await; - tokio::time::advance(Duration::from_secs(2)).await; - let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user1); - assert!(cached.is_none()); - let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user2); - assert!(cached.is_none()); - let cached = cache.get_endpoint_access_with_ttl(&endpoint_id); - assert!(cached.is_none()); + cache.role_controls.run_pending_tasks(); + assert_eq!(cache.role_controls.entry_count(), 0); } #[tokio::test] async fn test_caching_project_info_errors() { - let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { + let cache = ProjectInfoCache::new(ProjectInfoCacheOptions { size: 10, max_roles: 10, ttl: Duration::from_secs(1), @@ -519,34 +385,23 @@ mod tests { status: None, }); - let get_role_secret = |endpoint_id, role_name| { - cache - .get_role_secret_with_ttl(endpoint_id, role_name) - .unwrap() - .0 - }; - let get_endpoint_access = - |endpoint_id| cache.get_endpoint_access_with_ttl(endpoint_id).unwrap().0; + let get_role_secret = + |endpoint_id, role_name| cache.get_role_secret(endpoint_id, role_name).unwrap(); + let get_endpoint_access = |endpoint_id| cache.get_endpoint_access(endpoint_id).unwrap(); // stores role-specific errors only for get_role_secret - cache.insert_endpoint_access_err( - (&endpoint_id).into(), - (&user1).into(), - role_msg.clone(), - None, - ); + cache.insert_endpoint_access_err((&endpoint_id).into(), (&user1).into(), role_msg.clone()); assert_eq!( get_role_secret(&endpoint_id, &user1).unwrap_err().error, role_msg.error ); - assert!(cache.get_endpoint_access_with_ttl(&endpoint_id).is_none()); + assert!(cache.get_endpoint_access(&endpoint_id).is_none()); // stores non-role specific errors for both get_role_secret and get_endpoint_access cache.insert_endpoint_access_err( (&endpoint_id).into(), (&user1).into(), generic_msg.clone(), - None, ); assert_eq!( get_role_secret(&endpoint_id, &user1).unwrap_err().error, @@ -558,11 +413,7 @@ mod tests { ); // error isn't returned for other roles in the same endpoint - assert!( - cache - .get_role_secret_with_ttl(&endpoint_id, &user2) - .is_none() - ); + assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); // success for a role does not overwrite errors for other roles cache.insert_endpoint_access( @@ -590,7 +441,6 @@ mod tests { (&endpoint_id).into(), (&user2).into(), generic_msg.clone(), - None, ); assert!(get_role_secret(&endpoint_id, &user2).is_err()); assert!(get_endpoint_access(&endpoint_id).is_ok()); diff --git a/proxy/src/cache/timed_lru.rs b/proxy/src/cache/timed_lru.rs deleted file mode 100644 index 0a7fb40b0c..0000000000 --- a/proxy/src/cache/timed_lru.rs +++ /dev/null @@ -1,262 +0,0 @@ -use std::borrow::Borrow; -use std::hash::Hash; -use std::time::{Duration, Instant}; - -// This seems to make more sense than `lru` or `cached`: -// -// * `near/nearcore` ditched `cached` in favor of `lru` -// 
(https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed). -// -// * `lru` methods use an obscure `KeyRef` type in their contraints (which is deliberately excluded from docs). -// This severely hinders its usage both in terms of creating wrappers and supported key types. -// -// On the other hand, `hashlink` has good download stats and appears to be maintained. -use hashlink::{LruCache, linked_hash_map::RawEntryMut}; -use tracing::debug; - -use super::Cache; -use super::common::Cached; - -/// An implementation of timed LRU cache with fixed capacity. -/// Key properties: -/// -/// * Whenever a new entry is inserted, the least recently accessed one is evicted. -/// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`). -/// -/// * If `update_ttl_on_retrieval` is `true`. When the entry is about to be retrieved, we check its expiration timestamp. -/// If the entry has expired, we remove it from the cache; Otherwise we bump the -/// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong -/// its existence. -/// -/// * There's an API for immediate invalidation (removal) of a cache entry; -/// It's useful in case we know for sure that the entry is no longer correct. -/// See [`Cached`] for more information. -/// -/// * Expired entries are kept in the cache, until they are evicted by the LRU policy, -/// or by a successful lookup (i.e. the entry hasn't expired yet). -/// There is no background job to reap the expired records. -/// -/// * It's possible for an entry that has not yet expired entry to be evicted -/// before expired items. That's a bit wasteful, but probably fine in practice. -pub(crate) struct TimedLru { - /// Cache's name for tracing. - name: &'static str, - - /// The underlying cache implementation. - cache: parking_lot::Mutex>>, - - /// Default time-to-live of a single entry. - ttl: Duration, - - update_ttl_on_retrieval: bool, -} - -impl Cache for TimedLru { - type Key = K; - type Value = V; - type LookupInfo = Key; - - fn invalidate(&self, info: &Self::LookupInfo) { - self.invalidate_raw(info); - } -} - -struct Entry { - created_at: Instant, - expires_at: Instant, - ttl: Duration, - update_ttl_on_retrieval: bool, - value: T, -} - -impl TimedLru { - /// Construct a new LRU cache with timed entries. - pub(crate) fn new( - name: &'static str, - capacity: usize, - ttl: Duration, - update_ttl_on_retrieval: bool, - ) -> Self { - Self { - name, - cache: LruCache::new(capacity).into(), - ttl, - update_ttl_on_retrieval, - } - } - - /// Drop an entry from the cache if it's outdated. - #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] - fn invalidate_raw(&self, key: &K) { - // Do costly things before taking the lock. - let mut cache = self.cache.lock(); - let entry = match cache.raw_entry_mut().from_key(key) { - RawEntryMut::Vacant(_) => return, - RawEntryMut::Occupied(x) => x.remove(), - }; - drop(cache); // drop lock before logging - - let Entry { - created_at, - expires_at, - .. - } = entry; - - debug!( - ?created_at, - ?expires_at, - "processed a cache entry invalidation event" - ); - } - - /// Try retrieving an entry by its key, then execute `extract` if it exists. - #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] - fn get_raw(&self, key: &Q, extract: impl FnOnce(&K, &Entry) -> R) -> Option - where - K: Borrow, - Q: Hash + Eq + ?Sized, - { - let now = Instant::now(); - - // Do costly things before taking the lock. 
- let mut cache = self.cache.lock(); - let mut raw_entry = match cache.raw_entry_mut().from_key(key) { - RawEntryMut::Vacant(_) => return None, - RawEntryMut::Occupied(x) => x, - }; - - // Immeditely drop the entry if it has expired. - let entry = raw_entry.get(); - if entry.expires_at <= now { - raw_entry.remove(); - return None; - } - - let value = extract(raw_entry.key(), entry); - let (created_at, expires_at) = (entry.created_at, entry.expires_at); - - // Update the deadline and the entry's position in the LRU list. - let deadline = now.checked_add(raw_entry.get().ttl).expect("time overflow"); - if raw_entry.get().update_ttl_on_retrieval { - raw_entry.get_mut().expires_at = deadline; - } - raw_entry.to_back(); - - drop(cache); // drop lock before logging - debug!( - created_at = format_args!("{created_at:?}"), - old_expires_at = format_args!("{expires_at:?}"), - new_expires_at = format_args!("{deadline:?}"), - "accessed a cache entry" - ); - - Some(value) - } - - /// Insert an entry to the cache. If an entry with the same key already - /// existed, return the previous value and its creation timestamp. - #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] - fn insert_raw(&self, key: K, value: V) -> (Instant, Option) { - self.insert_raw_ttl(key, value, self.ttl, self.update_ttl_on_retrieval) - } - - /// Insert an entry to the cache. If an entry with the same key already - /// existed, return the previous value and its creation timestamp. - #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] - fn insert_raw_ttl( - &self, - key: K, - value: V, - ttl: Duration, - update: bool, - ) -> (Instant, Option) { - let created_at = Instant::now(); - let expires_at = created_at.checked_add(ttl).expect("time overflow"); - - let entry = Entry { - created_at, - expires_at, - ttl, - update_ttl_on_retrieval: update, - value, - }; - - // Do costly things before taking the lock. - let old = self - .cache - .lock() - .insert(key, entry) - .map(|entry| entry.value); - - debug!( - created_at = format_args!("{created_at:?}"), - expires_at = format_args!("{expires_at:?}"), - replaced = old.is_some(), - "created a cache entry" - ); - - (created_at, old) - } -} - -impl TimedLru { - pub(crate) fn insert_ttl(&self, key: K, value: V, ttl: Duration) { - self.insert_raw_ttl(key, value, ttl, false); - } - - #[cfg(feature = "rest_broker")] - pub(crate) fn insert(&self, key: K, value: V) { - self.insert_raw_ttl(key, value, self.ttl, self.update_ttl_on_retrieval); - } - - pub(crate) fn insert_unit(&self, key: K, value: V) -> (Option, Cached<&Self, ()>) { - let (_, old) = self.insert_raw(key.clone(), value); - - let cached = Cached { - token: Some((self, key)), - value: (), - }; - - (old, cached) - } - - #[cfg(feature = "rest_broker")] - pub(crate) fn flush(&self) { - let now = Instant::now(); - let mut cache = self.cache.lock(); - - // Collect keys of expired entries first - let expired_keys: Vec<_> = cache - .iter() - .filter_map(|(key, entry)| { - if entry.expires_at <= now { - Some(key.clone()) - } else { - None - } - }) - .collect(); - - // Remove expired entries - for key in expired_keys { - cache.remove(&key); - } - } -} - -impl TimedLru { - /// Retrieve a cached entry in convenient wrapper, alongside timing information. 
- pub(crate) fn get_with_created_at( - &self, - key: &Q, - ) -> Option::Value, Instant)>> - where - K: Borrow + Clone, - Q: Hash + Eq + ?Sized, - { - self.get_raw(key, |key, entry| Cached { - token: Some((self, key.clone())), - value: (entry.value.clone(), entry.created_at), - }) - } -} diff --git a/proxy/src/cancellation.rs b/proxy/src/cancellation.rs index f25121331f..13c6f0f6d7 100644 --- a/proxy/src/cancellation.rs +++ b/proxy/src/cancellation.rs @@ -429,26 +429,13 @@ impl CancellationHandler { /// (we'd need something like `#![feature(type_alias_impl_trait)]`). #[derive(Debug, Clone, Serialize, Deserialize)] pub struct CancelClosure { - socket_addr: SocketAddr, - cancel_token: RawCancelToken, - hostname: String, // for pg_sni router - user_info: ComputeUserInfo, + pub socket_addr: SocketAddr, + pub cancel_token: RawCancelToken, + pub hostname: String, // for pg_sni router + pub user_info: ComputeUserInfo, } impl CancelClosure { - pub(crate) fn new( - socket_addr: SocketAddr, - cancel_token: RawCancelToken, - hostname: String, - user_info: ComputeUserInfo, - ) -> Self { - Self { - socket_addr, - cancel_token, - hostname, - user_info, - } - } /// Cancels the query running on user's compute node. pub(crate) async fn try_cancel_query( &self, diff --git a/proxy/src/compute/mod.rs b/proxy/src/compute/mod.rs index 7b9183b05e..43cfe70206 100644 --- a/proxy/src/compute/mod.rs +++ b/proxy/src/compute/mod.rs @@ -7,26 +7,26 @@ use std::net::{IpAddr, SocketAddr}; use futures::{FutureExt, TryFutureExt}; use itertools::Itertools; use postgres_client::config::{AuthKeys, ChannelBinding, SslMode}; +use postgres_client::connect_raw::StartupStream; +use postgres_client::error::SqlState; use postgres_client::maybe_tls_stream::MaybeTlsStream; use postgres_client::tls::MakeTlsConnect; -use postgres_client::{NoTls, RawCancelToken, RawConnection}; -use postgres_protocol::message::backend::NoticeResponseBody; use thiserror::Error; use tokio::net::{TcpStream, lookup_host}; use tracing::{debug, error, info, warn}; -use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo}; +use crate::auth::backend::ComputeCredentialKeys; use crate::auth::parse_endpoint_param; -use crate::cancellation::CancelClosure; use crate::compute::tls::TlsError; use crate::config::ComputeConfig; use crate::context::RequestContext; use crate::control_plane::client::ApiLockError; use crate::control_plane::errors::WakeComputeError; use crate::control_plane::messages::MetricsAuxInfo; -use crate::error::{ReportableError, UserFacingError}; +use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumDbConnectionsGuard}; use crate::pqproto::StartupMessageParams; +use crate::proxy::connect_compute::TlsNegotiation; use crate::proxy::neon_option; use crate::types::Host; @@ -66,12 +66,13 @@ impl UserFacingError for PostgresError { } impl ReportableError for PostgresError { - fn get_error_kind(&self) -> crate::error::ErrorKind { + fn get_error_kind(&self) -> ErrorKind { match self { - PostgresError::Postgres(e) if e.as_db_error().is_some() => { - crate::error::ErrorKind::Postgres - } - PostgresError::Postgres(_) => crate::error::ErrorKind::Compute, + PostgresError::Postgres(err) => match err.as_db_error() { + Some(err) if err.code() == &SqlState::INVALID_CATALOG_NAME => ErrorKind::User, + Some(_) => ErrorKind::Postgres, + None => ErrorKind::Compute, + }, } } } @@ -86,6 +87,14 @@ pub(crate) enum ConnectionError { #[error("error acquiring resource permit: {0}")] TooManyConnectionAttempts(#[from] 
ApiLockError), + + #[cfg(test)] + #[error("retryable: {retryable}, wakeable: {wakeable}, kind: {kind:?}")] + TestError { + retryable: bool, + wakeable: bool, + kind: crate::error::ErrorKind, + }, } impl UserFacingError for ConnectionError { @@ -96,16 +105,20 @@ impl UserFacingError for ConnectionError { "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned() } ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(), + #[cfg(test)] + ConnectionError::TestError { .. } => self.to_string(), } } } impl ReportableError for ConnectionError { - fn get_error_kind(&self) -> crate::error::ErrorKind { + fn get_error_kind(&self) -> ErrorKind { match self { - ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute, + ConnectionError::TlsError(_) => ErrorKind::Compute, ConnectionError::WakeComputeError(e) => e.get_error_kind(), ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(), + #[cfg(test)] + ConnectionError::TestError { kind, .. } => *kind, } } } @@ -236,8 +249,7 @@ impl AuthInfo { &self, ctx: &RequestContext, compute: &mut ComputeConnection, - user_info: &ComputeUserInfo, - ) -> Result { + ) -> Result<(), PostgresError> { // client config with stubbed connect info. // TODO(conrad): should we rewrite this to bypass tokio-postgres2 entirely, // utilising pqproto.rs. @@ -247,39 +259,10 @@ impl AuthInfo { let tmp_config = self.enrich(tmp_config); let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let connection = tmp_config - .tls_and_authenticate(&mut compute.stream, NoTls) - .await?; + tmp_config.authenticate(&mut compute.stream).await?; drop(pause); - let RawConnection { - stream: _, - parameters, - delayed_notice, - process_id, - secret_key, - } = connection; - - tracing::Span::current().record("pid", tracing::field::display(process_id)); - - // NB: CancelToken is supposed to hold socket_addr, but we use connect_raw. - // Yet another reason to rework the connection establishing code. - let cancel_closure = CancelClosure::new( - compute.socket_addr, - RawCancelToken { - ssl_mode: compute.ssl_mode, - process_id, - secret_key, - }, - compute.hostname.to_string(), - user_info.clone(), - ); - - Ok(PostgresSettings { - params: parameters, - cancel_closure, - delayed_notice, - }) + Ok(()) } } @@ -288,6 +271,7 @@ impl ConnectInfo { async fn connect_raw( &self, config: &ComputeConfig, + tls: TlsNegotiation, ) -> Result<(SocketAddr, MaybeTlsStream), TlsError> { let timeout = config.timeout; @@ -330,7 +314,7 @@ impl ConnectInfo { match connect_once(&*addrs).await { Ok((sockaddr, stream)) => Ok(( sockaddr, - tls::connect_tls(stream, self.ssl_mode, config, host).await?, + tls::connect_tls(stream, self.ssl_mode, config, host, tls).await?, )), Err(err) => { warn!("couldn't connect to compute node at {host}:{port}: {err}"); @@ -343,21 +327,9 @@ impl ConnectInfo { pub type RustlsStream = >::Stream; pub type MaybeRustlsStream = MaybeTlsStream; -// TODO(conrad): we don't need to parse these. -// These are just immediately forwarded back to the client. -// We could instead stream them out instead of reading them into memory. -pub struct PostgresSettings { - /// PostgreSQL connection parameters. - pub params: std::collections::HashMap, - /// Query cancellation token. - pub cancel_closure: CancelClosure, - /// Notices received from compute after authenticating - pub delayed_notice: Vec, -} - pub struct ComputeConnection { /// Socket connected to a compute node. 
- pub stream: MaybeTlsStream, + pub stream: StartupStream, /// Labels for proxy's metrics. pub aux: MetricsAuxInfo, pub hostname: Host, @@ -373,9 +345,10 @@ impl ConnectInfo { ctx: &RequestContext, aux: &MetricsAuxInfo, config: &ComputeConfig, + tls: TlsNegotiation, ) -> Result { let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let (socket_addr, stream) = self.connect_raw(config).await?; + let (socket_addr, stream) = self.connect_raw(config, tls).await?; drop(pause); tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id)); @@ -390,6 +363,7 @@ impl ConnectInfo { ctx.get_testodrome_id().unwrap_or_default(), ); + let stream = StartupStream::new(stream); let connection = ComputeConnection { stream, socket_addr, diff --git a/proxy/src/compute/tls.rs b/proxy/src/compute/tls.rs index 000d75fca5..cc1c0d1658 100644 --- a/proxy/src/compute/tls.rs +++ b/proxy/src/compute/tls.rs @@ -7,6 +7,7 @@ use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use crate::pqproto::request_tls; +use crate::proxy::connect_compute::TlsNegotiation; use crate::proxy::retry::CouldRetry; #[derive(Debug, Error)] @@ -35,6 +36,7 @@ pub async fn connect_tls( mode: SslMode, tls: &T, host: &str, + negotiation: TlsNegotiation, ) -> Result, TlsError> where S: AsyncRead + AsyncWrite + Unpin + Send, @@ -49,12 +51,15 @@ where SslMode::Prefer | SslMode::Require => {} } - if !request_tls(&mut stream).await? { - if SslMode::Require == mode { - return Err(TlsError::Required); - } - - return Ok(MaybeTlsStream::Raw(stream)); + match negotiation { + // No TLS request needed + TlsNegotiation::Direct => {} + // TLS request successful + TlsNegotiation::Postgres if request_tls(&mut stream).await? => {} + // TLS request failed but is required + TlsNegotiation::Postgres if SslMode::Require == mode => return Err(TlsError::Required), + // TLS request failed but is not required + TlsNegotiation::Postgres => return Ok(MaybeTlsStream::Raw(stream)), } Ok(MaybeTlsStream::Tls( diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 16b1dff5f4..22902dbcab 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -19,7 +19,7 @@ use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings}; use crate::ext::TaskExt; use crate::intern::RoleNameInt; use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig}; -use crate::scram::threadpool::ThreadPool; +use crate::scram; use crate::serverless::GlobalConnPoolOptions; use crate::serverless::cancel_set::CancelSet; #[cfg(feature = "rest_broker")] @@ -75,7 +75,7 @@ pub struct HttpConfig { } pub struct AuthenticationConfig { - pub thread_pool: Arc, + pub scram_thread_pool: Arc, pub scram_protocol_timeout: tokio::time::Duration, pub ip_allowlist_check_enabled: bool, pub is_vpc_acccess_proxy: bool, @@ -107,20 +107,23 @@ pub fn remote_storage_from_toml(s: &str) -> anyhow::Result #[derive(Debug)] pub struct CacheOptions { /// Max number of entries. - pub size: usize, + pub size: Option, /// Entry's time-to-live. - pub ttl: Duration, + pub absolute_ttl: Option, + /// Entry's time-to-idle. + pub idle_ttl: Option, } impl CacheOptions { - /// Default options for [`crate::control_plane::NodeInfoCache`]. - pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m"; + /// Default options for [`crate::cache::node_info::NodeInfoCache`]. + pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,idle_ttl=4m"; /// Parse cache options passed via cmdline. /// Example: [`Self::CACHE_DEFAULT_OPTIONS`]. 
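
The `CacheOptions` change above makes `size`, `absolute_ttl` (alias `ttl`), and `idle_ttl` (alias `tti`) all optional, and the `parse`/`moka` code that follows maps whichever ones are present onto a moka cache builder. A minimal stand-alone sketch of that mapping, using moka's public sync-cache builder API with placeholder `String` key/value types:

```rust
// Sketch of how the optional cache options translate into a moka builder.
use std::time::Duration;

fn build_cache(
    size: Option<u64>,
    absolute_ttl: Option<Duration>,
    idle_ttl: Option<Duration>,
) -> moka::sync::Cache<String, String> {
    let mut builder = moka::sync::Cache::builder();
    if let Some(size) = size {
        builder = builder.max_capacity(size); // cap on the number of entries
    }
    if let Some(ttl) = absolute_ttl {
        builder = builder.time_to_live(ttl); // expire relative to insertion
    }
    if let Some(tti) = idle_ttl {
        builder = builder.time_to_idle(tti); // expire relative to last access
    }
    builder.build()
}

// The new default "size=4000,idle_ttl=4m" would correspond to:
// build_cache(Some(4000), None, Some(Duration::from_secs(4 * 60)))
```

Options that are omitted simply leave the corresponding moka limit unset, which is why the old "missing `size`/`ttl`" errors disappear from `parse`.
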
fn parse(options: &str) -> anyhow::Result { let mut size = None; - let mut ttl = None; + let mut absolute_ttl = None; + let mut idle_ttl = None; for option in options.split(',') { let (key, value) = option @@ -129,21 +132,34 @@ impl CacheOptions { match key { "size" => size = Some(value.parse()?), - "ttl" => ttl = Some(humantime::parse_duration(value)?), + "absolute_ttl" | "ttl" => absolute_ttl = Some(humantime::parse_duration(value)?), + "idle_ttl" | "tti" => idle_ttl = Some(humantime::parse_duration(value)?), unknown => bail!("unknown key: {unknown}"), } } - // TTL doesn't matter if cache is always empty. - if let Some(0) = size { - ttl.get_or_insert(Duration::default()); - } - Ok(Self { - size: size.context("missing `size`")?, - ttl: ttl.context("missing `ttl`")?, + size, + absolute_ttl, + idle_ttl, }) } + + pub fn moka( + &self, + mut builder: moka::sync::CacheBuilder, + ) -> moka::sync::CacheBuilder { + if let Some(size) = self.size { + builder = builder.max_capacity(size); + } + if let Some(ttl) = self.absolute_ttl { + builder = builder.time_to_live(ttl); + } + if let Some(tti) = self.idle_ttl { + builder = builder.time_to_idle(tti); + } + builder + } } impl FromStr for CacheOptions { @@ -159,17 +175,17 @@ impl FromStr for CacheOptions { #[derive(Debug)] pub struct ProjectInfoCacheOptions { /// Max number of entries. - pub size: usize, + pub size: u64, /// Entry's time-to-live. pub ttl: Duration, /// Max number of roles per endpoint. - pub max_roles: usize, + pub max_roles: u64, /// Gc interval. pub gc_interval: Duration, } impl ProjectInfoCacheOptions { - /// Default options for [`crate::control_plane::NodeInfoCache`]. + /// Default options for [`crate::cache::project_info::ProjectInfoCache`]. pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=10000,ttl=4m,max_roles=10,gc_interval=60m"; @@ -496,21 +512,37 @@ mod tests { #[test] fn test_parse_cache_options() -> anyhow::Result<()> { - let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?; - assert_eq!(size, 4096); - assert_eq!(ttl, Duration::from_secs(5 * 60)); + let CacheOptions { + size, + absolute_ttl, + idle_ttl: _, + } = "size=4096,ttl=5min".parse()?; + assert_eq!(size, Some(4096)); + assert_eq!(absolute_ttl, Some(Duration::from_secs(5 * 60))); - let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?; - assert_eq!(size, 2); - assert_eq!(ttl, Duration::from_secs(4 * 60)); + let CacheOptions { + size, + absolute_ttl, + idle_ttl: _, + } = "ttl=4m,size=2".parse()?; + assert_eq!(size, Some(2)); + assert_eq!(absolute_ttl, Some(Duration::from_secs(4 * 60))); - let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?; - assert_eq!(size, 0); - assert_eq!(ttl, Duration::from_secs(1)); + let CacheOptions { + size, + absolute_ttl, + idle_ttl: _, + } = "size=0,ttl=1s".parse()?; + assert_eq!(size, Some(0)); + assert_eq!(absolute_ttl, Some(Duration::from_secs(1))); - let CacheOptions { size, ttl } = "size=0".parse()?; - assert_eq!(size, 0); - assert_eq!(ttl, Duration::default()); + let CacheOptions { + size, + absolute_ttl, + idle_ttl: _, + } = "size=0".parse()?; + assert_eq!(size, Some(0)); + assert_eq!(absolute_ttl, None); Ok(()) } diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 014317d823..f947abebc0 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -1,12 +1,13 @@ use std::sync::Arc; use futures::{FutureExt, TryFutureExt}; +use postgres_client::RawCancelToken; use tokio::io::{AsyncRead, AsyncWrite}; use 
tokio_util::sync::CancellationToken; use tracing::{Instrument, debug, error, info}; use crate::auth::backend::ConsoleRedirectBackend; -use crate::cancellation::CancellationHandler; +use crate::cancellation::{CancelClosure, CancellationHandler}; use crate::config::{ProxyConfig, ProxyProtocolV2}; use crate::context::RequestContext; use crate::error::ReportableError; @@ -15,8 +16,9 @@ use crate::pglb::ClientRequestError; use crate::pglb::handshake::{HandshakeData, handshake}; use crate::pglb::passthrough::ProxyPassthrough; use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol}; -use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute}; -use crate::proxy::{ErrorSource, finish_client_init}; +use crate::proxy::{ + ErrorSource, connect_compute, forward_compute_params_to_client, send_client_greeting, +}; use crate::util::run_until_cancelled; pub async fn task_main( @@ -214,33 +216,28 @@ pub(crate) async fn handle_client( }; auth_info.set_startup_params(¶ms, true); - let mut node = connect_to_compute( + let mut node = connect_compute::connect_to_compute( ctx, - &TcpMechanism { - locks: &config.connect_compute_locks, - }, + config, &node_info, - config.wake_compute_retry_config, - &config.connect_to_compute, + connect_compute::TlsNegotiation::Postgres, ) .or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) }) .await?; - let pg_settings = auth_info - .authenticate(ctx, &mut node, &user_info) + auth_info + .authenticate(ctx, &mut node) .or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) }) .await?; + send_client_greeting(ctx, &config.greetings, &mut stream); let session = cancellation_handler.get_key(); - finish_client_init( - ctx, - &pg_settings, - *session.key(), - &mut stream, - &config.greetings, - ); + let (process_id, secret_key) = + forward_compute_params_to_client(ctx, *session.key(), &mut stream, &mut node.stream) + .await?; let stream = stream.flush_and_into_inner().await?; + let hostname = node.hostname.to_string(); let session_id = ctx.session_id(); let (cancel_on_shutdown, cancel) = tokio::sync::oneshot::channel(); @@ -249,7 +246,16 @@ pub(crate) async fn handle_client( .maintain_cancel_key( session_id, cancel, - &pg_settings.cancel_closure, + &CancelClosure { + socket_addr: node.socket_addr, + cancel_token: RawCancelToken { + ssl_mode: node.ssl_mode, + process_id, + secret_key, + }, + hostname, + user_info, + }, &config.connect_to_compute, ) .await; @@ -257,7 +263,7 @@ pub(crate) async fn handle_client( Ok(Some(ProxyPassthrough { client: stream, - compute: node.stream, + compute: node.stream.into_framed().into_inner(), aux: node.aux, private_link_id: None, diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index 8a0403c0b0..b76b13e2c2 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -3,7 +3,6 @@ use std::net::IpAddr; use std::str::FromStr; use std::sync::Arc; -use std::time::Duration; use ::http::HeaderName; use ::http::header::AUTHORIZATION; @@ -17,6 +16,8 @@ use tracing::{Instrument, debug, info, info_span, warn}; use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute}; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; +use crate::cache::Cached; +use crate::cache::node_info::CachedNodeInfo; use crate::context::RequestContext; use crate::control_plane::caches::ApiCaches; use crate::control_plane::errors::{ @@ -25,8 +26,7 @@ 
use crate::control_plane::errors::{ use crate::control_plane::locks::ApiLocks; use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse}; use crate::control_plane::{ - AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo, - RoleAccessControl, + AccessBlockerFlags, AuthInfo, AuthSecret, EndpointAccessControl, NodeInfo, RoleAccessControl, }; use crate::metrics::Metrics; use crate::proxy::retry::CouldRetry; @@ -118,7 +118,6 @@ impl NeonControlPlaneClient { cache_key.into(), role.into(), msg.clone(), - retry_info.map(|r| Duration::from_millis(r.retry_delay_ms)), ); Err(err) @@ -347,18 +346,11 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { ) -> Result { let key = endpoint.normalize(); - if let Some((role_control, ttl)) = self - .caches - .project_info - .get_role_secret_with_ttl(&key, role) - { + if let Some(role_control) = self.caches.project_info.get_role_secret(&key, role) { return match role_control { - Err(mut msg) => { + Err(msg) => { info!(key = &*key, "found cached get_role_access_control error"); - // if retry_delay_ms is set change it to the remaining TTL - replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64); - Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg))) } Ok(role_control) => { @@ -383,17 +375,14 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { ) -> Result { let key = endpoint.normalize(); - if let Some((control, ttl)) = self.caches.project_info.get_endpoint_access_with_ttl(&key) { + if let Some(control) = self.caches.project_info.get_endpoint_access(&key) { return match control { - Err(mut msg) => { + Err(msg) => { info!( key = &*key, "found cached get_endpoint_access_control error" ); - // if retry_delay_ms is set change it to the remaining TTL - replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64); - Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg))) } Ok(control) => { @@ -426,17 +415,11 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { macro_rules! 
check_cache { () => { - if let Some(cached) = self.caches.node_info.get_with_created_at(&key) { - let (cached, (info, created_at)) = cached.take_value(); + if let Some(info) = self.caches.node_info.get_entry(&key) { return match info { - Err(mut msg) => { + Err(msg) => { info!(key = &*key, "found cached wake_compute error"); - // if retry_delay_ms is set, reduce it by the amount of time it spent in cache - replace_retry_delay_ms(&mut msg, |delay| { - delay.saturating_sub(created_at.elapsed().as_millis() as u64) - }); - Err(WakeComputeError::ControlPlane(ControlPlaneError::Message( msg, ))) @@ -444,7 +427,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { Ok(info) => { debug!(key = &*key, "found cached compute node info"); ctx.set_project(info.aux.clone()); - Ok(cached.map(|()| info)) + Ok(info) } }; } @@ -483,10 +466,12 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { let mut stored_node = node.clone(); // store the cached node as 'warm_cached' stored_node.aux.cold_start_info = ColdStartInfo::WarmCached; + self.caches.node_info.insert(key.clone(), Ok(stored_node)); - let (_, cached) = self.caches.node_info.insert_unit(key, Ok(stored_node)); - - Ok(cached.map(|()| node)) + Ok(Cached { + token: Some((&self.caches.node_info, key)), + value: node, + }) } Err(err) => match err { WakeComputeError::ControlPlane(ControlPlaneError::Message(ref msg)) => { @@ -503,11 +488,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { "created a cache entry for the wake compute error" ); - let ttl = retry_info.map_or(Duration::from_secs(30), |r| { - Duration::from_millis(r.retry_delay_ms) - }); - - self.caches.node_info.insert_ttl(key, Err(msg.clone()), ttl); + self.caches.node_info.insert(key, Err(msg.clone())); Err(err) } @@ -517,14 +498,6 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { } } -fn replace_retry_delay_ms(msg: &mut ControlPlaneErrorMessage, f: impl FnOnce(u64) -> u64) { - if let Some(status) = &mut msg.status - && let Some(retry_info) = &mut status.details.retry_info - { - retry_info.retry_delay_ms = f(retry_info.retry_delay_ms); - } -} - /// Parse http response body, taking status code into account. 
fn parse_body serde::Deserialize<'a>>( status: StatusCode, diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index b84dba6b09..9e48d91340 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -15,6 +15,7 @@ use crate::auth::IpPattern; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; use crate::cache::Cached; +use crate::cache::node_info::CachedNodeInfo; use crate::compute::ConnectInfo; use crate::context::RequestContext; use crate::control_plane::errors::{ @@ -22,8 +23,7 @@ use crate::control_plane::errors::{ }; use crate::control_plane::messages::{EndpointRateLimitConfig, MetricsAuxInfo}; use crate::control_plane::{ - AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo, - RoleAccessControl, + AccessBlockerFlags, AuthInfo, AuthSecret, EndpointAccessControl, NodeInfo, RoleAccessControl, }; use crate::intern::RoleNameInt; use crate::scram; diff --git a/proxy/src/control_plane/client/mod.rs b/proxy/src/control_plane/client/mod.rs index ecd4db29b2..ec26746873 100644 --- a/proxy/src/control_plane/client/mod.rs +++ b/proxy/src/control_plane/client/mod.rs @@ -13,10 +13,11 @@ use tracing::{debug, info}; use super::{EndpointAccessControl, RoleAccessControl}; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError}; -use crate::cache::project_info::ProjectInfoCacheImpl; +use crate::cache::node_info::{CachedNodeInfo, NodeInfoCache}; +use crate::cache::project_info::ProjectInfoCache; use crate::config::{CacheOptions, ProjectInfoCacheOptions}; use crate::context::RequestContext; -use crate::control_plane::{CachedNodeInfo, ControlPlaneApi, NodeInfoCache, errors}; +use crate::control_plane::{ControlPlaneApi, errors}; use crate::error::ReportableError; use crate::metrics::ApiLockMetrics; use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token}; @@ -119,7 +120,7 @@ pub struct ApiCaches { /// Cache for the `wake_compute` API method. pub(crate) node_info: NodeInfoCache, /// Cache which stores project_id -> endpoint_ids mapping. 
- pub project_info: Arc, + pub project_info: Arc, } impl ApiCaches { @@ -128,13 +129,8 @@ impl ApiCaches { project_info_cache_config: ProjectInfoCacheOptions, ) -> Self { Self { - node_info: NodeInfoCache::new( - "node_info_cache", - wake_compute_cache_config.size, - wake_compute_cache_config.ttl, - true, - ), - project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)), + node_info: NodeInfoCache::new(wake_compute_cache_config), + project_info: Arc::new(ProjectInfoCache::new(project_info_cache_config)), } } } diff --git a/proxy/src/control_plane/messages.rs b/proxy/src/control_plane/messages.rs index d44d7efcc3..a23ddeb5c8 100644 --- a/proxy/src/control_plane/messages.rs +++ b/proxy/src/control_plane/messages.rs @@ -1,8 +1,10 @@ use std::fmt::{self, Display}; +use std::time::Duration; use measured::FixedCardinalityLabel; use serde::{Deserialize, Serialize}; use smol_str::SmolStr; +use tokio::time::Instant; use crate::auth::IpPattern; use crate::intern::{AccountIdInt, BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; @@ -231,7 +233,13 @@ impl Reason { #[derive(Copy, Clone, Debug, Deserialize)] #[allow(dead_code)] pub(crate) struct RetryInfo { - pub(crate) retry_delay_ms: u64, + #[serde(rename = "retry_delay_ms", deserialize_with = "milliseconds_from_now")] + pub(crate) retry_at: Instant, +} + +fn milliseconds_from_now<'de, D: serde::Deserializer<'de>>(d: D) -> Result { + let millis = u64::deserialize(d)?; + Ok(Instant::now() + Duration::from_millis(millis)) } #[derive(Debug, Deserialize, Clone)] diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index 9bbd3f4fb7..6f326d789a 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -16,14 +16,13 @@ use messages::EndpointRateLimitConfig; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list}; -use crate::cache::{Cached, TimedLru}; -use crate::config::ComputeConfig; +use crate::cache::node_info::CachedNodeInfo; use crate::context::RequestContext; -use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo}; +use crate::control_plane::messages::MetricsAuxInfo; use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt}; use crate::protocol2::ConnectionInfoExtra; use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig}; -use crate::types::{EndpointCacheKey, EndpointId, RoleName}; +use crate::types::{EndpointId, RoleName}; use crate::{compute, scram}; /// Various cache-related types. 
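
The `RetryInfo` hunk above converts the control plane's relative `retry_delay_ms` into an absolute `tokio::time::Instant` at deserialization time via `milliseconds_from_now`. A stand-alone sketch of the same pattern, using `serde_json` only to provide an example input:

```rust
// Sketch of deserializing a relative delay into an absolute deadline, mirroring
// the `milliseconds_from_now` helper above. The JSON input is an example only.
use std::time::Duration;

use serde::Deserialize;
use tokio::time::Instant;

#[derive(Deserialize)]
struct RetryInfo {
    #[serde(rename = "retry_delay_ms", deserialize_with = "millis_from_now")]
    retry_at: Instant,
}

fn millis_from_now<'de, D: serde::Deserializer<'de>>(d: D) -> Result<Instant, D::Error> {
    let millis = u64::deserialize(d)?;
    Ok(Instant::now() + Duration::from_millis(millis))
}

fn main() -> Result<(), serde_json::Error> {
    let info: RetryInfo = serde_json::from_str(r#"{"retry_delay_ms": 1500}"#)?;
    assert!(info.retry_at > Instant::now()); // deadline is ~1.5s in the future
    Ok(())
}
```

Converting at the parse boundary means every consumer (caches, retry loops) sees the same deadline regardless of how long the message sat in a cache.
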
@@ -72,26 +71,12 @@ pub(crate) struct NodeInfo { pub(crate) aux: MetricsAuxInfo, } -impl NodeInfo { - pub(crate) async fn connect( - &self, - ctx: &RequestContext, - config: &ComputeConfig, - ) -> Result { - self.conn_info.connect(ctx, &self.aux, config).await - } -} - #[derive(Copy, Clone, Default, Debug)] pub(crate) struct AccessBlockerFlags { pub public_access_blocked: bool, pub vpc_access_blocked: bool, } -pub(crate) type NodeInfoCache = - TimedLru>>; -pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; - #[derive(Clone, Debug)] pub struct RoleAccessControl { pub secret: Option, diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 7524133093..905c9b5279 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -2,59 +2,60 @@ use std::sync::{Arc, OnceLock}; use lasso::ThreadedRodeo; use measured::label::{ - FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue, StaticLabelSet, + FixedCardinalitySet, LabelGroupSet, LabelGroupVisitor, LabelName, LabelSet, LabelValue, + StaticLabelSet, }; +use measured::metric::group::Encoding; use measured::metric::histogram::Thresholds; use measured::metric::name::MetricName; use measured::{ - Counter, CounterVec, FixedCardinalityLabel, Gauge, Histogram, HistogramVec, LabelGroup, - MetricGroup, + Counter, CounterVec, FixedCardinalityLabel, Gauge, GaugeVec, Histogram, HistogramVec, + LabelGroup, MetricGroup, }; -use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLogVec}; +use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLogVec, InfoMetric}; use tokio::time::{self, Instant}; use crate::control_plane::messages::ColdStartInfo; use crate::error::ErrorKind; #[derive(MetricGroup)] -#[metric(new(thread_pool: Arc))] +#[metric(new())] pub struct Metrics { #[metric(namespace = "proxy")] - #[metric(init = ProxyMetrics::new(thread_pool))] + #[metric(init = ProxyMetrics::new())] pub proxy: ProxyMetrics, #[metric(namespace = "wake_compute_lock")] pub wake_compute_lock: ApiLockMetrics, + + #[metric(namespace = "service")] + pub service: ServiceMetrics, + + #[metric(namespace = "cache")] + pub cache: CacheMetrics, } -static SELF: OnceLock = OnceLock::new(); impl Metrics { - pub fn install(thread_pool: Arc) { - let mut metrics = Metrics::new(thread_pool); - - metrics.proxy.errors_total.init_all_dense(); - metrics.proxy.redis_errors_total.init_all_dense(); - metrics.proxy.redis_events_count.init_all_dense(); - metrics.proxy.retries_metric.init_all_dense(); - metrics.proxy.connection_failures_total.init_all_dense(); - - SELF.set(metrics) - .ok() - .expect("proxy metrics must not be installed more than once"); - } - + #[track_caller] pub fn get() -> &'static Self { - #[cfg(test)] - return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0)))); + static SELF: OnceLock = OnceLock::new(); - #[cfg(not(test))] - SELF.get() - .expect("proxy metrics must be installed by the main() function") + SELF.get_or_init(|| { + let mut metrics = Metrics::new(); + + metrics.proxy.errors_total.init_all_dense(); + metrics.proxy.redis_errors_total.init_all_dense(); + metrics.proxy.redis_events_count.init_all_dense(); + metrics.proxy.retries_metric.init_all_dense(); + metrics.proxy.connection_failures_total.init_all_dense(); + + metrics + }) } } #[derive(MetricGroup)] -#[metric(new(thread_pool: Arc))] +#[metric(new())] pub struct ProxyMetrics { #[metric(flatten)] pub db_connections: CounterPairVec, @@ -127,6 +128,9 @@ pub struct ProxyMetrics { /// Number of TLS handshake failures pub tls_handshake_failures: Counter, + 
/// Number of SHA 256 rounds executed. + pub sha_rounds: Counter, + /// HLL approximate cardinality of endpoints that are connecting pub connecting_endpoints: HyperLogLogVec, 32>, @@ -144,8 +148,25 @@ pub struct ProxyMetrics { pub connect_compute_lock: ApiLockMetrics, #[metric(namespace = "scram_pool")] - #[metric(init = thread_pool)] - pub scram_pool: Arc, + pub scram_pool: OnceLockWrapper>, +} + +/// A Wrapper over [`OnceLock`] to implement [`MetricGroup`]. +pub struct OnceLockWrapper(pub OnceLock); + +impl Default for OnceLockWrapper { + fn default() -> Self { + Self(OnceLock::new()) + } +} + +impl> MetricGroup for OnceLockWrapper { + fn collect_group_into(&self, enc: &mut Enc) -> Result<(), Enc::Err> { + if let Some(inner) = self.0.get() { + inner.collect_group_into(enc)?; + } + Ok(()) + } } #[derive(MetricGroup)] @@ -215,13 +236,6 @@ pub enum Bool { False, } -#[derive(FixedCardinalityLabel, Copy, Clone)] -#[label(singleton = "outcome")] -pub enum CacheOutcome { - Hit, - Miss, -} - #[derive(LabelGroup)] #[label(set = ConsoleRequestSet)] pub struct ConsoleRequest<'a> { @@ -553,14 +567,6 @@ impl From for Bool { } } -#[derive(LabelGroup)] -#[label(set = InvalidEndpointsSet)] -pub struct InvalidEndpointsGroup { - pub protocol: Protocol, - pub rejected: Bool, - pub outcome: ConnectOutcome, -} - #[derive(LabelGroup)] #[label(set = RetriesMetricSet)] pub struct RetriesMetricGroup { @@ -660,3 +666,100 @@ pub struct ThreadPoolMetrics { #[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))] pub worker_task_skips_total: CounterVec, } + +#[derive(MetricGroup, Default)] +pub struct ServiceMetrics { + pub info: InfoMetric, +} + +#[derive(Default)] +pub struct ServiceInfo { + pub state: ServiceState, +} + +impl ServiceInfo { + pub const fn running() -> Self { + ServiceInfo { + state: ServiceState::Running, + } + } + + pub const fn terminating() -> Self { + ServiceInfo { + state: ServiceState::Terminating, + } + } +} + +impl LabelGroup for ServiceInfo { + fn visit_values(&self, v: &mut impl LabelGroupVisitor) { + const STATE: &LabelName = LabelName::from_str("state"); + v.write_value(STATE, &self.state); + } +} + +#[derive(FixedCardinalityLabel, Clone, Copy, Debug, Default)] +#[label(singleton = "state")] +pub enum ServiceState { + #[default] + Init, + Running, + Terminating, +} + +#[derive(MetricGroup)] +#[metric(new())] +pub struct CacheMetrics { + /// The capacity of the cache + pub capacity: GaugeVec>, + /// The total number of entries inserted into the cache + pub inserted_total: CounterVec>, + /// The total number of entries removed from the cache + pub evicted_total: CounterVec, + /// The total number of cache requests + pub request_total: CounterVec, +} + +impl Default for CacheMetrics { + fn default() -> Self { + Self::new() + } +} + +#[derive(FixedCardinalityLabel, Clone, Copy, Debug)] +#[label(singleton = "cache")] +pub enum CacheKind { + NodeInfo, + ProjectInfoEndpoints, + ProjectInfoRoles, + Schema, + Pbkdf2, +} + +#[derive(FixedCardinalityLabel, Clone, Copy, Debug)] +pub enum CacheRemovalCause { + Expired, + Explicit, + Replaced, + Size, +} + +#[derive(LabelGroup)] +#[label(set = CacheEvictionSet)] +pub struct CacheEviction { + pub cache: CacheKind, + pub cause: CacheRemovalCause, +} + +#[derive(FixedCardinalityLabel, Copy, Clone)] +pub enum CacheOutcome { + Hit, + Miss, +} + +#[derive(LabelGroup)] +#[label(set = CacheOutcomeSet)] +pub struct CacheOutcomeGroup { + pub cache: CacheKind, + pub outcome: CacheOutcome, +} diff --git a/proxy/src/pglb/mod.rs 
b/proxy/src/pglb/mod.rs index c4cab155c5..999fa6eb32 100644 --- a/proxy/src/pglb/mod.rs +++ b/proxy/src/pglb/mod.rs @@ -319,7 +319,7 @@ pub(crate) async fn handle_connection( Ok(Some(ProxyPassthrough { client, - compute: node.stream, + compute: node.stream.into_framed().into_inner(), aux: node.aux, private_link_id, diff --git a/proxy/src/pqproto.rs b/proxy/src/pqproto.rs index 680a23c435..7a68d430db 100644 --- a/proxy/src/pqproto.rs +++ b/proxy/src/pqproto.rs @@ -313,6 +313,14 @@ impl WriteBuf { self.0.set_position(0); } + /// Shrinks the buffer if efficient to do so, and returns the remaining size. + pub fn occupied_len(&mut self) -> usize { + if self.should_shrink() { + self.shrink(); + } + self.0.get_mut().len() + } + /// Write a raw message to the internal buffer. /// /// The size_hint value is only a hint for reserving space. It's ok if it's incorrect, since diff --git a/proxy/src/proxy/connect_auth.rs b/proxy/src/proxy/connect_auth.rs new file mode 100644 index 0000000000..77578c71b1 --- /dev/null +++ b/proxy/src/proxy/connect_auth.rs @@ -0,0 +1,82 @@ +use thiserror::Error; + +use crate::auth::Backend; +use crate::auth::backend::ComputeUserInfo; +use crate::cache::common::Cache; +use crate::compute::{AuthInfo, ComputeConnection, ConnectionError, PostgresError}; +use crate::config::ProxyConfig; +use crate::context::RequestContext; +use crate::control_plane::client::ControlPlaneClient; +use crate::error::{ReportableError, UserFacingError}; +use crate::proxy::connect_compute::{TlsNegotiation, connect_to_compute}; +use crate::proxy::retry::ShouldRetryWakeCompute; + +#[derive(Debug, Error)] +pub enum AuthError { + #[error(transparent)] + Auth(#[from] PostgresError), + #[error(transparent)] + Connect(#[from] ConnectionError), +} + +impl UserFacingError for AuthError { + fn to_string_client(&self) -> String { + match self { + AuthError::Auth(postgres_error) => postgres_error.to_string_client(), + AuthError::Connect(connection_error) => connection_error.to_string_client(), + } + } +} + +impl ReportableError for AuthError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + match self { + AuthError::Auth(postgres_error) => postgres_error.get_error_kind(), + AuthError::Connect(connection_error) => connection_error.get_error_kind(), + } + } +} + +/// Try to connect to the compute node, retrying if necessary. +#[tracing::instrument(skip_all)] +pub(crate) async fn connect_to_compute_and_auth( + ctx: &RequestContext, + config: &ProxyConfig, + user_info: &Backend<'_, ComputeUserInfo>, + auth_info: AuthInfo, + tls: TlsNegotiation, +) -> Result { + let mut attempt = 0; + + // NOTE: This is messy, but should hopefully be detangled with PGLB. + // We wanted to separate the concerns of **connect** to compute (a PGLB operation), + // from **authenticate** to compute (a NeonKeeper operation). + // + // This unfortunately removed retry handling for one error case where + // the compute was cached, and we connected, but the compute cache was actually stale + // and is associated with the wrong endpoint. We detect this when the **authentication** fails. + // As such, we retry once here if the `authenticate` function fails and the error is valid to retry. 
+ loop { + attempt += 1; + let mut node = connect_to_compute(ctx, config, user_info, tls).await?; + + let res = auth_info.authenticate(ctx, &mut node).await; + match res { + Ok(()) => return Ok(node), + Err(e) => { + if attempt < 2 + && let Backend::ControlPlane(cplane, user_info) = user_info + && let ControlPlaneClient::ProxyV1(cplane_proxy_v1) = &**cplane + && e.should_retry_wake_compute() + { + tracing::warn!(error = ?e, "retrying wake compute"); + let key = user_info.endpoint_cache_key(); + cplane_proxy_v1.caches.node_info.invalidate(&key); + continue; + } + + return Err(e)?; + } + } + } +} diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index ce9774e3eb..515f925236 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -1,18 +1,16 @@ -use async_trait::async_trait; use tokio::time; use tracing::{debug, info, warn}; +use crate::cache::node_info::CachedNodeInfo; use crate::compute::{self, COULD_NOT_CONNECT, ComputeConnection}; -use crate::config::{ComputeConfig, RetryConfig}; +use crate::config::{ComputeConfig, ProxyConfig, RetryConfig}; use crate::context::RequestContext; -use crate::control_plane::errors::WakeComputeError; +use crate::control_plane::NodeInfo; use crate::control_plane::locks::ApiLocks; -use crate::control_plane::{self, NodeInfo}; -use crate::error::ReportableError; use crate::metrics::{ ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType, }; -use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry}; +use crate::proxy::retry::{ShouldRetryWakeCompute, retry_after, should_retry}; use crate::proxy::wake_compute::{WakeComputeBackend, wake_compute}; use crate::types::Host; @@ -20,7 +18,7 @@ use crate::types::Host; /// (e.g. the compute node's address might've changed at the wrong time). /// Invalidate the cache entry (if any) to prevent subsequent errors. 
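
The new `connect_auth` module above keeps connect (a would-be PGLB concern) and compute-side authentication separate, but reintroduces one retry for the case the NOTE describes: the cached compute address connected fine, yet authentication failed because the cache entry was stale. A generic sketch of that shape, with placeholder closures rather than the proxy's real signatures:

```rust
// Sketch only: "retry once if the failure suggests the cached target was stale".
// `connect`, `authenticate`, and the two predicates are placeholders.
use std::future::Future;

async fn connect_then_auth<C, E, CF, CFut, AF, AFut>(
    mut connect: CF,
    mut authenticate: AF,
    error_suggests_stale_cache: impl Fn(&E) -> bool,
    invalidate_cached_compute: impl Fn(),
) -> Result<C, E>
where
    CF: FnMut() -> CFut,
    CFut: Future<Output = Result<C, E>>,
    AF: FnMut(C) -> AFut,
    AFut: Future<Output = Result<C, E>>,
{
    let mut attempt = 0;
    loop {
        attempt += 1;
        let conn = connect().await?;
        match authenticate(conn).await {
            Ok(conn) => return Ok(conn),
            Err(e) if attempt < 2 && error_suggests_stale_cache(&e) => {
                // Drop the stale cache entry and go around exactly once more;
                // the next connect() will wake the compute afresh.
                invalidate_cached_compute();
            }
            Err(e) => return Err(e),
        }
    }
}
```

In the real function the "stale cache" predicate is `ShouldRetryWakeCompute` and the invalidation goes through `caches.node_info.invalidate(&key)` on the ProxyV1 control-plane client, as shown above.
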
#[tracing::instrument(skip_all)] -pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> NodeInfo { +pub(crate) fn invalidate_cache(node_info: CachedNodeInfo) -> NodeInfo { let is_cached = node_info.cached(); if is_cached { warn!("invalidating stalled compute node info cache entry"); @@ -35,29 +33,32 @@ pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> Node node_info.invalidate() } -#[async_trait] pub(crate) trait ConnectMechanism { type Connection; - type ConnectError: ReportableError; - type Error: From; async fn connect_once( &self, ctx: &RequestContext, - node_info: &control_plane::CachedNodeInfo, + node_info: &CachedNodeInfo, config: &ComputeConfig, - ) -> Result; + ) -> Result; } -pub(crate) struct TcpMechanism { +struct TcpMechanism<'a> { /// connect_to_compute concurrency lock - pub(crate) locks: &'static ApiLocks, + locks: &'a ApiLocks, + tls: TlsNegotiation, } -#[async_trait] -impl ConnectMechanism for TcpMechanism { +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum TlsNegotiation { + /// TLS is assumed + Direct, + /// We must ask for TLS using the postgres SSLRequest message + Postgres, +} + +impl ConnectMechanism for TcpMechanism<'_> { type Connection = ComputeConnection; - type ConnectError = compute::ConnectionError; - type Error = compute::ConnectionError; #[tracing::instrument(skip_all, fields( pid = tracing::field::Empty, @@ -66,27 +67,49 @@ impl ConnectMechanism for TcpMechanism { async fn connect_once( &self, ctx: &RequestContext, - node_info: &control_plane::CachedNodeInfo, + node_info: &CachedNodeInfo, config: &ComputeConfig, - ) -> Result { + ) -> Result { let permit = self.locks.get_permit(&node_info.conn_info.host).await?; - permit.release_result(node_info.connect(ctx, config).await) + + permit.release_result( + node_info + .conn_info + .connect(ctx, &node_info.aux, config, self.tls) + .await, + ) } } /// Try to connect to the compute node, retrying if necessary. #[tracing::instrument(skip_all)] -pub(crate) async fn connect_to_compute( +pub(crate) async fn connect_to_compute( + ctx: &RequestContext, + config: &ProxyConfig, + user_info: &B, + tls: TlsNegotiation, +) -> Result { + connect_to_compute_inner( + ctx, + &TcpMechanism { + locks: &config.connect_compute_locks, + tls, + }, + user_info, + config.wake_compute_retry_config, + &config.connect_to_compute, + ) + .await +} + +/// Try to connect to the compute node, retrying if necessary. 
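
`TlsNegotiation` (introduced above) tells the compute connection path whether TLS is assumed from the first byte (`Direct`) or must be requested with the classic Postgres `SSLRequest` round-trip (`Postgres`), which is what the `compute/tls.rs` hunk earlier in this diff implements. A compact decision-table sketch, with simplified stand-ins for the real `SslMode`/`TlsNegotiation` types and assuming that `SslMode::Disable` short-circuits to a raw stream as in the unchanged arm of `connect_tls`:

```rust
// Sketch of the negotiation outcomes handled in compute/tls.rs above.
#[derive(Clone, Copy, PartialEq)]
enum TlsNegotiation { Direct, Postgres }

#[derive(Clone, Copy, PartialEq)]
enum SslMode { Disable, Prefer, Require }

#[derive(Debug, PartialEq)]
enum Outcome { Tls, Raw, Error }

fn negotiate(mode: SslMode, how: TlsNegotiation, server_accepts_ssl_request: bool) -> Outcome {
    if mode == SslMode::Disable {
        return Outcome::Raw; // TLS explicitly disabled; never negotiate
    }
    match how {
        // TLS is assumed; no SSLRequest round-trip is sent.
        TlsNegotiation::Direct => Outcome::Tls,
        // Server accepted the SSLRequest: upgrade the stream.
        TlsNegotiation::Postgres if server_accepts_ssl_request => Outcome::Tls,
        // Server declined but TLS is mandatory.
        TlsNegotiation::Postgres if mode == SslMode::Require => Outcome::Error,
        // Server declined and TLS is optional: continue in cleartext.
        TlsNegotiation::Postgres => Outcome::Raw,
    }
}
```

`Direct` also explains why `connect_to_compute` now takes the negotiation mode as a parameter: the caller, not the TLS helper, knows whether the endpoint terminates TLS without an `SSLRequest`.
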
+pub(crate) async fn connect_to_compute_inner( ctx: &RequestContext, mechanism: &M, user_info: &B, wake_compute_retry_config: RetryConfig, compute: &ComputeConfig, -) -> Result -where - M::ConnectError: CouldRetry + ShouldRetryWakeCompute + std::fmt::Debug, - M::Error: From, -{ +) -> Result { let mut num_retries = 0; let node_info = wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?; @@ -120,7 +143,7 @@ where }, num_retries.into(), ); - return Err(err.into()); + return Err(err); } node_info } else { @@ -161,7 +184,7 @@ where }, num_retries.into(), ); - return Err(e.into()); + return Err(e); } warn!(error = ?e, num_retries, retriable = true, COULD_NOT_CONNECT); diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 8b7c4ff55d..b42457cd95 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -1,6 +1,7 @@ #[cfg(test)] mod tests; +pub(crate) mod connect_auth; pub(crate) mod connect_compute; pub(crate) mod retry; pub(crate) mod wake_compute; @@ -9,26 +10,27 @@ use std::collections::HashSet; use std::convert::Infallible; use std::sync::Arc; +use futures::TryStreamExt; use itertools::Itertools; use once_cell::sync::OnceCell; +use postgres_client::RawCancelToken; +use postgres_client::connect_raw::StartupStream; +use postgres_protocol::message::backend::Message; use regex::Regex; use serde::{Deserialize, Serialize}; use smol_str::{SmolStr, format_smolstr}; use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::net::TcpStream; use tokio::sync::oneshot; use tracing::Instrument; -use crate::cache::Cache; -use crate::cancellation::CancellationHandler; -use crate::compute::ComputeConnection; +use crate::cancellation::{CancelClosure, CancellationHandler}; +use crate::compute::{ComputeConnection, PostgresError, RustlsStream}; use crate::config::ProxyConfig; use crate::context::RequestContext; -use crate::control_plane::client::ControlPlaneClient; pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; use crate::pglb::{ClientMode, ClientRequestError}; use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams}; -use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute}; -use crate::proxy::retry::ShouldRetryWakeCompute; use crate::rate_limiter::EndpointRateLimiter; use crate::stream::{PqStream, Stream}; use crate::types::EndpointCacheKey; @@ -90,62 +92,34 @@ pub(crate) async fn handle_client( let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys); auth_info.set_startup_params(params, params_compat); - let mut node; - let mut attempt = 0; - let connect = TcpMechanism { - locks: &config.connect_compute_locks, - }; let backend = auth::Backend::ControlPlane(cplane, creds.info); - // NOTE: This is messy, but should hopefully be detangled with PGLB. - // We wanted to separate the concerns of **connect** to compute (a PGLB operation), - // from **authenticate** to compute (a NeonKeeper operation). - // - // This unfortunately removed retry handling for one error case where - // the compute was cached, and we connected, but the compute cache was actually stale - // and is associated with the wrong endpoint. We detect this when the **authentication** fails. - // As such, we retry once here if the `authenticate` function fails and the error is valid to retry. 
- let pg_settings = loop { - attempt += 1; + // TODO: callback to pglb + let res = connect_auth::connect_to_compute_and_auth( + ctx, + config, + &backend, + auth_info, + connect_compute::TlsNegotiation::Postgres, + ) + .await; - // TODO: callback to pglb - let res = connect_to_compute( - ctx, - &connect, - &backend, - config.wake_compute_retry_config, - &config.connect_to_compute, - ) - .await; + let mut node = match res { + Ok(node) => node, + Err(e) => Err(client.throw_error(e, Some(ctx)).await)?, + }; - match res { - Ok(n) => node = n, - Err(e) => return Err(client.throw_error(e, Some(ctx)).await)?, - } + send_client_greeting(ctx, &config.greetings, client); - let auth::Backend::ControlPlane(cplane, user_info) = &backend else { - unreachable!("ensured above"); - }; - - let res = auth_info.authenticate(ctx, &mut node, user_info).await; - match res { - Ok(pg_settings) => break pg_settings, - Err(e) if attempt < 2 && e.should_retry_wake_compute() => { - tracing::warn!(error = ?e, "retrying wake compute"); - - #[allow(irrefutable_let_patterns)] - if let ControlPlaneClient::ProxyV1(cplane_proxy_v1) = &**cplane { - let key = user_info.endpoint_cache_key(); - cplane_proxy_v1.caches.node_info.invalidate(&key); - } - } - Err(e) => Err(client.throw_error(e, Some(ctx)).await)?, - } + let auth::Backend::ControlPlane(_, user_info) = backend else { + unreachable!("ensured above"); }; let session = cancellation_handler.get_key(); - finish_client_init(ctx, &pg_settings, *session.key(), client, &config.greetings); + let (process_id, secret_key) = + forward_compute_params_to_client(ctx, *session.key(), client, &mut node.stream).await?; + let hostname = node.hostname.to_string(); let session_id = ctx.session_id(); let (cancel_on_shutdown, cancel) = oneshot::channel(); @@ -154,7 +128,16 @@ pub(crate) async fn handle_client( .maintain_cancel_key( session_id, cancel, - &pg_settings.cancel_closure, + &CancelClosure { + socket_addr: node.socket_addr, + cancel_token: RawCancelToken { + ssl_mode: node.ssl_mode, + process_id, + secret_key, + }, + hostname, + user_info, + }, &config.connect_to_compute, ) .await; @@ -163,35 +146,18 @@ pub(crate) async fn handle_client( Ok((node, cancel_on_shutdown)) } -/// Finish client connection initialization: confirm auth success, send params, etc. -pub(crate) fn finish_client_init( +/// Greet the client with any useful information. +pub(crate) fn send_client_greeting( ctx: &RequestContext, - settings: &compute::PostgresSettings, - cancel_key_data: CancelKeyData, - client: &mut PqStream, greetings: &String, + client: &mut PqStream, ) { - // Forward all deferred notices to the client. - for notice in &settings.delayed_notice { - client.write_raw(notice.as_bytes().len(), b'N', |buf| { - buf.extend_from_slice(notice.as_bytes()); - }); - } - // Expose session_id to clients if we have a greeting message. if !greetings.is_empty() { let session_msg = format!("{}, session_id: {}", greetings, ctx.session_id()); client.write_message(BeMessage::NoticeResponse(session_msg.as_str())); } - // Forward all postgres connection params to the client. 
- for (name, value) in &settings.params { - client.write_message(BeMessage::ParameterStatus { - name: name.as_bytes(), - value: value.as_bytes(), - }); - } - // Forward recorded latencies for probing requests if let Some(testodrome_id) = ctx.get_testodrome_id() { client.write_message(BeMessage::ParameterStatus { @@ -221,9 +187,63 @@ pub(crate) fn finish_client_init( value: latency_measured.retry.as_micros().to_string().as_bytes(), }); } +} - client.write_message(BeMessage::BackendKeyData(cancel_key_data)); - client.write_message(BeMessage::ReadyForQuery); +pub(crate) async fn forward_compute_params_to_client( + ctx: &RequestContext, + cancel_key_data: CancelKeyData, + client: &mut PqStream, + compute: &mut StartupStream, +) -> Result<(i32, i32), ClientRequestError> { + let mut process_id = 0; + let mut secret_key = 0; + + let err = loop { + // if the client buffer is too large, let's write out some bytes now to save some space + client.write_if_full().await?; + + let msg = match compute.try_next().await { + Ok(msg) => msg, + Err(e) => break postgres_client::Error::io(e), + }; + + match msg { + // Send our cancellation key data instead. + Some(Message::BackendKeyData(body)) => { + client.write_message(BeMessage::BackendKeyData(cancel_key_data)); + process_id = body.process_id(); + secret_key = body.secret_key(); + } + // Forward all postgres connection params to the client. + Some(Message::ParameterStatus(body)) => { + if let Ok(name) = body.name() + && let Ok(value) = body.value() + { + client.write_message(BeMessage::ParameterStatus { + name: name.as_bytes(), + value: value.as_bytes(), + }); + } + } + // Forward all notices to the client. + Some(Message::NoticeResponse(notice)) => { + client.write_raw(notice.as_bytes().len(), b'N', |buf| { + buf.extend_from_slice(notice.as_bytes()); + }); + } + Some(Message::ReadyForQuery(_)) => { + client.write_message(BeMessage::ReadyForQuery); + return Ok((process_id, secret_key)); + } + Some(Message::ErrorResponse(body)) => break postgres_client::Error::db(body), + Some(_) => break postgres_client::Error::unexpected_message(), + None => break postgres_client::Error::closed(), + } + }; + + Err(client + .throw_error(PostgresError::Postgres(err), Some(ctx)) + .await)? 
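
`forward_compute_params_to_client` (above) streams the compute's startup messages straight to the client instead of first collecting them into the now-removed `PostgresSettings`. The only message it rewrites is `BackendKeyData`: the client receives the proxy's own cancel key, while the compute's real `(process_id, secret_key)` pair is returned to the caller and ends up in the `CancelClosure`. A simplified, pure-function model of that relay, with stand-in message types rather than the real `pqproto`/`postgres_protocol` ones (the real loop also flushes the client buffer when it grows, via the new `write_if_full`/`occupied_len` helpers):

```rust
// Simplified model of the startup-message relay until ReadyForQuery.
enum ComputeMsg {
    ParameterStatus { name: String, value: String },
    Notice(String),
    BackendKeyData { process_id: i32, secret_key: i32 },
    ReadyForQuery,
    Error(String),
}

enum ClientMsg {
    ParameterStatus { name: String, value: String },
    Notice(String),
    BackendKeyData { proxy_cancel_key: u64 }, // proxy's key, not compute's
    ReadyForQuery,
}

/// Returns the messages forwarded to the client plus the compute's real cancel
/// token, or the error text if the compute rejected the startup.
fn relay_startup(
    msgs: Vec<ComputeMsg>,
    proxy_cancel_key: u64,
) -> Result<(Vec<ClientMsg>, (i32, i32)), String> {
    let mut forwarded = Vec::new();
    let mut compute_token = (0, 0);
    for msg in msgs {
        match msg {
            ComputeMsg::ParameterStatus { name, value } => {
                forwarded.push(ClientMsg::ParameterStatus { name, value });
            }
            ComputeMsg::Notice(n) => forwarded.push(ClientMsg::Notice(n)),
            ComputeMsg::BackendKeyData { process_id, secret_key } => {
                // Keep the compute's token for the CancelClosure, hand the
                // client the proxy's key so cancel requests route via proxy.
                compute_token = (process_id, secret_key);
                forwarded.push(ClientMsg::BackendKeyData { proxy_cancel_key });
            }
            ComputeMsg::ReadyForQuery => {
                forwarded.push(ClientMsg::ReadyForQuery);
                return Ok((forwarded, compute_token));
            }
            ComputeMsg::Error(e) => return Err(e),
        }
    }
    Err("compute closed the connection before ReadyForQuery".to_owned())
}
```
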
} #[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)] diff --git a/proxy/src/proxy/retry.rs b/proxy/src/proxy/retry.rs index b06c3be72c..876d252517 100644 --- a/proxy/src/proxy/retry.rs +++ b/proxy/src/proxy/retry.rs @@ -31,18 +31,6 @@ impl CouldRetry for io::Error { } } -impl CouldRetry for postgres_client::error::DbError { - fn could_retry(&self) -> bool { - use postgres_client::error::SqlState; - matches!( - self.code(), - &SqlState::CONNECTION_FAILURE - | &SqlState::CONNECTION_EXCEPTION - | &SqlState::CONNECTION_DOES_NOT_EXIST - | &SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION, - ) - } -} impl ShouldRetryWakeCompute for postgres_client::error::DbError { fn should_retry_wake_compute(&self) -> bool { use postgres_client::error::SqlState; @@ -73,17 +61,6 @@ impl ShouldRetryWakeCompute for postgres_client::error::DbError { } } -impl CouldRetry for postgres_client::Error { - fn could_retry(&self) -> bool { - if let Some(io_err) = self.source().and_then(|x| x.downcast_ref()) { - io::Error::could_retry(io_err) - } else if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) { - postgres_client::error::DbError::could_retry(db_err) - } else { - false - } - } -} impl ShouldRetryWakeCompute for postgres_client::Error { fn should_retry_wake_compute(&self) -> bool { if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) { @@ -102,6 +79,8 @@ impl CouldRetry for compute::ConnectionError { compute::ConnectionError::TlsError(err) => err.could_retry(), compute::ConnectionError::WakeComputeError(err) => err.could_retry(), compute::ConnectionError::TooManyConnectionAttempts(_) => false, + #[cfg(test)] + compute::ConnectionError::TestError { retryable, .. } => *retryable, } } } @@ -110,6 +89,8 @@ impl ShouldRetryWakeCompute for compute::ConnectionError { match self { // the cache entry was not checked for validity compute::ConnectionError::TooManyConnectionAttempts(_) => false, + #[cfg(test)] + compute::ConnectionError::TestError { wakeable, .. 
} => *wakeable, _ => true, } } diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index f8bff450e1..7e0710749e 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -15,22 +15,24 @@ use rstest::rstest; use rustls::crypto::ring; use rustls::pki_types; use tokio::io::{AsyncRead, AsyncWrite, DuplexStream}; +use tokio::time::Instant; use tracing_test::traced_test; use super::retry::CouldRetry; use crate::auth::backend::{ComputeUserInfo, MaybeOwned}; -use crate::config::{ComputeConfig, RetryConfig, TlsConfig}; +use crate::cache::node_info::{CachedNodeInfo, NodeInfoCache}; +use crate::config::{CacheOptions, ComputeConfig, RetryConfig, TlsConfig}; use crate::context::RequestContext; use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient}; use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status}; -use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache}; -use crate::error::{ErrorKind, ReportableError}; +use crate::control_plane::{self, NodeInfo}; +use crate::error::ErrorKind; use crate::pglb::ERR_INSECURE_CONNECTION; use crate::pglb::handshake::{HandshakeData, handshake}; use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; -use crate::proxy::connect_compute::{ConnectMechanism, connect_to_compute}; -use crate::proxy::retry::{ShouldRetryWakeCompute, retry_after}; +use crate::proxy::connect_compute::{ConnectMechanism, connect_to_compute_inner}; +use crate::proxy::retry::retry_after; use crate::stream::{PqStream, Stream}; use crate::tls::client_config::compute_client_config_with_certs; use crate::tls::server_config::CertResolver; @@ -417,12 +419,11 @@ impl TestConnectMechanism { Self { counter: Arc::new(std::sync::Mutex::new(0)), sequence, - cache: Box::leak(Box::new(NodeInfoCache::new( - "test", - 1, - Duration::from_secs(100), - false, - ))), + cache: Box::leak(Box::new(NodeInfoCache::new(CacheOptions { + size: Some(1), + absolute_ttl: Some(Duration::from_secs(100)), + idle_ttl: None, + }))), } } } @@ -430,71 +431,36 @@ impl TestConnectMechanism { #[derive(Debug)] struct TestConnection; -#[derive(Debug)] -struct TestConnectError { - retryable: bool, - wakeable: bool, - kind: crate::error::ErrorKind, -} - -impl ReportableError for TestConnectError { - fn get_error_kind(&self) -> crate::error::ErrorKind { - self.kind - } -} - -impl std::fmt::Display for TestConnectError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{self:?}") - } -} - -impl std::error::Error for TestConnectError {} - -impl CouldRetry for TestConnectError { - fn could_retry(&self) -> bool { - self.retryable - } -} -impl ShouldRetryWakeCompute for TestConnectError { - fn should_retry_wake_compute(&self) -> bool { - self.wakeable - } -} - -#[async_trait] impl ConnectMechanism for TestConnectMechanism { type Connection = TestConnection; - type ConnectError = TestConnectError; - type Error = anyhow::Error; async fn connect_once( &self, _ctx: &RequestContext, - _node_info: &control_plane::CachedNodeInfo, + _node_info: &CachedNodeInfo, _config: &ComputeConfig, - ) -> Result { + ) -> Result { let mut counter = self.counter.lock().unwrap(); let action = self.sequence[*counter]; *counter += 1; match action { ConnectAction::Connect => Ok(TestConnection), - ConnectAction::Retry => Err(TestConnectError { + ConnectAction::Retry => Err(compute::ConnectionError::TestError { retryable: true, wakeable: true, kind: ErrorKind::Compute, }), - ConnectAction::RetryNoWake => 
Err(TestConnectError { + ConnectAction::RetryNoWake => Err(compute::ConnectionError::TestError { retryable: true, wakeable: false, kind: ErrorKind::Compute, }), - ConnectAction::Fail => Err(TestConnectError { + ConnectAction::Fail => Err(compute::ConnectionError::TestError { retryable: false, wakeable: true, kind: ErrorKind::Compute, }), - ConnectAction::FailNoWake => Err(TestConnectError { + ConnectAction::FailNoWake => Err(compute::ConnectionError::TestError { retryable: false, wakeable: false, kind: ErrorKind::Compute, @@ -536,7 +502,7 @@ impl TestControlPlaneClient for TestConnectMechanism { details: Details { error_info: None, retry_info: Some(control_plane::messages::RetryInfo { - retry_delay_ms: 1, + retry_at: Instant::now() + Duration::from_millis(1), }), user_facing_message: None, }, @@ -581,8 +547,11 @@ fn helper_create_uncached_node_info() -> NodeInfo { fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo { let node = helper_create_uncached_node_info(); - let (_, node2) = cache.insert_unit("key".into(), Ok(node.clone())); - node2.map(|()| node) + cache.insert("key".into(), Ok(node.clone())); + CachedNodeInfo { + token: Some((cache, "key".into())), + value: node, + } } fn helper_create_connect_info( @@ -620,7 +589,7 @@ async fn connect_to_compute_success() { let mechanism = TestConnectMechanism::new(vec![Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -634,7 +603,7 @@ async fn connect_to_compute_retry() { let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -649,7 +618,7 @@ async fn connect_to_compute_non_retry_1() { let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap_err(); mechanism.verify(); @@ -664,7 +633,7 @@ async fn connect_to_compute_non_retry_2() { let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -686,7 +655,7 @@ async fn connect_to_compute_non_retry_3() { backoff_factor: 2.0, }; let config = config(); - connect_to_compute( + connect_to_compute_inner( &ctx, &mechanism, &user_info, @@ -707,7 +676,7 @@ async fn wake_retry() { let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -722,7 +691,7 @@ async fn wake_non_retry() { let mechanism = 
TestConnectMechanism::new(vec![WakeRetry, WakeFail]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap_err(); mechanism.verify(); @@ -741,7 +710,7 @@ async fn fail_but_wake_invalidates_cache() { let user = helper_create_connect_info(&mech); let cfg = config(); - connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg) + connect_to_compute_inner(&ctx, &mech, &user, cfg.retry, &cfg) .await .unwrap(); @@ -762,7 +731,7 @@ async fn fail_no_wake_skips_cache_invalidation() { let user = helper_create_connect_info(&mech); let cfg = config(); - connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg) + connect_to_compute_inner(&ctx, &mech, &user, cfg.retry, &cfg) .await .unwrap(); @@ -783,7 +752,7 @@ async fn retry_but_wake_invalidates_cache() { let user_info = helper_create_connect_info(&mechanism); let cfg = config(); - connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg) + connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg) .await .unwrap(); mechanism.verify(); @@ -806,7 +775,7 @@ async fn retry_no_wake_skips_invalidation() { let user_info = helper_create_connect_info(&mechanism); let cfg = config(); - connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg) + connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg) .await .unwrap_err(); mechanism.verify(); @@ -829,7 +798,7 @@ async fn retry_no_wake_error_fast() { let user_info = helper_create_connect_info(&mechanism); let cfg = config(); - connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg) + connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg) .await .unwrap_err(); mechanism.verify(); @@ -852,7 +821,7 @@ async fn retry_cold_wake_skips_invalidation() { let user_info = helper_create_connect_info(&mechanism); let cfg = config(); - connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg) + connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg) .await .unwrap(); mechanism.verify(); diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs index b8edf9fd5c..66f2df2af4 100644 --- a/proxy/src/proxy/wake_compute.rs +++ b/proxy/src/proxy/wake_compute.rs @@ -1,9 +1,9 @@ use async_trait::async_trait; use tracing::{error, info}; +use crate::cache::node_info::CachedNodeInfo; use crate::config::RetryConfig; use crate::context::RequestContext; -use crate::control_plane::CachedNodeInfo; use crate::control_plane::errors::{ControlPlaneError, WakeComputeError}; use crate::error::ReportableError; use crate::metrics::{ diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 88d5550fff..a3a378c7e2 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -131,11 +131,11 @@ where Ok(()) } -struct MessageHandler { +struct MessageHandler { cache: Arc, } -impl Clone for MessageHandler { +impl Clone for MessageHandler { fn clone(&self) -> Self { Self { cache: self.cache.clone(), @@ -143,8 +143,8 @@ impl Clone for MessageHandler { } } -impl MessageHandler { - pub(crate) fn new(cache: Arc) -> Self { +impl MessageHandler { + pub(crate) fn new(cache: Arc) -> Self { Self { cache } } @@ -224,7 +224,7 @@ impl MessageHandler { } } -fn invalidate_cache(cache: Arc, msg: Notification) { +fn invalidate_cache(cache: Arc, msg: Notification) { match msg { Notification::EndpointSettingsUpdate(ids) => 
ids .iter() @@ -247,8 +247,8 @@ fn invalidate_cache(cache: Arc, msg: Notification) { } } -async fn handle_messages( - handler: MessageHandler, +async fn handle_messages( + handler: MessageHandler, redis: ConnectionWithCredentialsProvider, cancellation_token: CancellationToken, ) -> anyhow::Result<()> { @@ -284,13 +284,10 @@ async fn handle_messages( /// Handle console's invalidation messages. #[tracing::instrument(name = "redis_notifications", skip_all)] -pub async fn task_main( +pub async fn task_main( redis: ConnectionWithCredentialsProvider, - cache: Arc, -) -> anyhow::Result -where - C: ProjectInfoCache + Send + Sync + 'static, -{ + cache: Arc, +) -> anyhow::Result { let handler = MessageHandler::new(cache); // 6h - 1m. // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost. diff --git a/proxy/src/scram/cache.rs b/proxy/src/scram/cache.rs new file mode 100644 index 0000000000..9ade7af458 --- /dev/null +++ b/proxy/src/scram/cache.rs @@ -0,0 +1,84 @@ +use tokio::time::Instant; +use zeroize::Zeroize as _; + +use super::pbkdf2; +use crate::cache::Cached; +use crate::cache::common::{Cache, count_cache_insert, count_cache_outcome, eviction_listener}; +use crate::intern::{EndpointIdInt, RoleNameInt}; +use crate::metrics::{CacheKind, Metrics}; + +pub(crate) struct Pbkdf2Cache(moka::sync::Cache<(EndpointIdInt, RoleNameInt), Pbkdf2CacheEntry>); +pub(crate) type CachedPbkdf2<'a> = Cached<&'a Pbkdf2Cache>; + +impl Cache for Pbkdf2Cache { + type Key = (EndpointIdInt, RoleNameInt); + type Value = Pbkdf2CacheEntry; + + fn invalidate(&self, info: &(EndpointIdInt, RoleNameInt)) { + self.0.invalidate(info); + } +} + +/// To speed up password hashing for more active customers, we store the tail results of the +/// PBKDF2 algorithm. If the output of PBKDF2 is U1 ^ U2 ^ ⋯ ^ Uc, then we store +/// suffix = U17 ^ U18 ^ ⋯ ^ Uc. We only need to calculate U1 ^ U2 ^ ⋯ ^ U15 ^ U16 +/// to determine the final result. +/// +/// The suffix alone isn't enough to crack the password. The stored_key is still required. +/// While both are cached in memory, given they're in different locations is makes it much +/// harder to exploit, even if any such memory exploit exists in proxy. +#[derive(Clone)] +pub struct Pbkdf2CacheEntry { + /// corresponds to [`super::ServerSecret::cached_at`] + pub(super) cached_from: Instant, + pub(super) suffix: pbkdf2::Block, +} + +impl Drop for Pbkdf2CacheEntry { + fn drop(&mut self) { + self.suffix.zeroize(); + } +} + +impl Pbkdf2Cache { + pub fn new() -> Self { + const SIZE: u64 = 100; + const TTL: std::time::Duration = std::time::Duration::from_secs(60); + + let builder = moka::sync::Cache::builder() + .name("pbkdf2") + .max_capacity(SIZE) + // We use time_to_live so we don't refresh the lifetime for an invalid password attempt. 
+ .time_to_live(TTL); + + Metrics::get() + .cache + .capacity + .set(CacheKind::Pbkdf2, SIZE as i64); + + let builder = + builder.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::Pbkdf2, cause)); + + Self(builder.build()) + } + + pub fn insert(&self, endpoint: EndpointIdInt, role: RoleNameInt, value: Pbkdf2CacheEntry) { + count_cache_insert(CacheKind::Pbkdf2); + self.0.insert((endpoint, role), value); + } + + fn get(&self, endpoint: EndpointIdInt, role: RoleNameInt) -> Option { + count_cache_outcome(CacheKind::Pbkdf2, self.0.get(&(endpoint, role))) + } + + pub fn get_entry( + &self, + endpoint: EndpointIdInt, + role: RoleNameInt, + ) -> Option> { + self.get(endpoint, role).map(|value| Cached { + token: Some((self, (endpoint, role))), + value, + }) + } +} diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs index a0918fca9f..3f4b0d534b 100644 --- a/proxy/src/scram/exchange.rs +++ b/proxy/src/scram/exchange.rs @@ -4,10 +4,8 @@ use std::convert::Infallible; use base64::Engine as _; use base64::prelude::BASE64_STANDARD; -use hmac::{Hmac, Mac}; -use sha2::Sha256; +use tracing::{debug, trace}; -use super::ScramKey; use super::messages::{ ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN, }; @@ -15,8 +13,10 @@ use super::pbkdf2::Pbkdf2; use super::secret::ServerSecret; use super::signature::SignatureBuilder; use super::threadpool::ThreadPool; -use crate::intern::EndpointIdInt; +use super::{ScramKey, pbkdf2}; +use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::sasl::{self, ChannelBinding, Error as SaslError}; +use crate::scram::cache::Pbkdf2CacheEntry; /// The only channel binding mode we currently support. #[derive(Debug)] @@ -77,46 +77,113 @@ impl<'a> Exchange<'a> { } } -// copied from async fn derive_client_key( pool: &ThreadPool, endpoint: EndpointIdInt, password: &[u8], salt: &[u8], iterations: u32, -) -> ScramKey { - let salted_password = pool - .spawn_job(endpoint, Pbkdf2::start(password, salt, iterations)) - .await; - - let make_key = |name| { - let key = Hmac::::new_from_slice(&salted_password) - .expect("HMAC is able to accept all key sizes") - .chain_update(name) - .finalize(); - - <[u8; 32]>::from(key.into_bytes()) - }; - - make_key(b"Client Key").into() +) -> pbkdf2::Block { + pool.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations)) + .await } +/// For cleartext flow, we need to derive the client key to +/// 1. authenticate the client. +/// 2. authenticate with compute. pub(crate) async fn exchange( pool: &ThreadPool, endpoint: EndpointIdInt, + role: RoleNameInt, + secret: &ServerSecret, + password: &[u8], +) -> sasl::Result> { + if secret.iterations > CACHED_ROUNDS { + exchange_with_cache(pool, endpoint, role, secret, password).await + } else { + let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?; + let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await; + Ok(validate_pbkdf2(secret, &hash)) + } +} + +/// Compute the client key using a cache. We cache the suffix of the pbkdf2 result only, +/// which is not enough by itself to perform an offline brute force. 
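The split exploited here can be checked in isolation: for a single PBKDF2-HMAC-SHA-256 output block, PBKDF2(p, s, c) = U1 ^ … ^ Uc, so with a cached suffix U17 ^ … ^ Uc only the 16-round prefix has to be recomputed and XORed in. A standalone sketch of that identity, assuming the `hmac` and `sha2` crates (illustrative, not the proxy's `Pbkdf2`):

```rust
use hmac::{Hmac, Mac};
use sha2::Sha256;

type Block = [u8; 32];

/// XOR of U{from} ..= U{to} of PBKDF2-HMAC-SHA-256 for the first output block.
fn pbkdf2_partial(password: &[u8], salt: &[u8], from: u32, to: u32) -> Block {
    let mut prf = Hmac::<Sha256>::new_from_slice(password).expect("HMAC accepts any key size");
    // U1 = PRF(password, salt || INT_32_BE(1))
    prf.update(salt);
    prf.update(&1u32.to_be_bytes());
    let mut u: Block = prf.finalize_reset().into_bytes().into();
    let mut acc = if from <= 1 { u } else { [0u8; 32] };
    for i in 2..=to {
        // Ui = PRF(password, U{i-1})
        prf.update(&u);
        u = prf.finalize_reset().into_bytes().into();
        if i >= from {
            for (a, b) in acc.iter_mut().zip(u.iter()) {
                *a ^= b;
            }
        }
    }
    acc
}

fn main() {
    let (pw, salt, rounds) = (b"pencil".as_slice(), b"salt".as_slice(), 4096);
    let full = pbkdf2_partial(pw, salt, 1, rounds);
    let prefix = pbkdf2_partial(pw, salt, 1, 16); // recomputed on every login
    let suffix = pbkdf2_partial(pw, salt, 17, rounds); // what the cache stores
    let mut rebuilt = prefix;
    for (r, s) in rebuilt.iter_mut().zip(suffix.iter()) {
        *r ^= s;
    }
    assert_eq!(rebuilt, full, "prefix ^ suffix reproduces the full PBKDF2 block");
}
```

With PostgreSQL's default SCRAM iteration count of 4096, this turns a cache hit from thousands of HMAC invocations into sixteen plus one XOR.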
+async fn exchange_with_cache( + pool: &ThreadPool, + endpoint: EndpointIdInt, + role: RoleNameInt, secret: &ServerSecret, password: &[u8], ) -> sasl::Result> { let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?; - let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await; + debug_assert!( + secret.iterations > CACHED_ROUNDS, + "we should not cache password data if there isn't enough rounds needed" + ); + + // compute the prefix of the pbkdf2 output. + let prefix = derive_client_key(pool, endpoint, password, &salt, CACHED_ROUNDS).await; + + if let Some(entry) = pool.cache.get_entry(endpoint, role) { + // hot path: let's check the threadpool cache + if secret.cached_at == entry.cached_from { + // cache is valid. compute the full hash by adding the prefix to the suffix. + let mut hash = prefix; + pbkdf2::xor_assign(&mut hash, &entry.suffix); + let outcome = validate_pbkdf2(secret, &hash); + + if matches!(outcome, sasl::Outcome::Success(_)) { + trace!("password validated from cache"); + } + + return Ok(outcome); + } + + // cached key is no longer valid. + debug!("invalidating cached password"); + entry.invalidate(); + } + + // slow path: full password hash. + let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await; + let outcome = validate_pbkdf2(secret, &hash); + + let client_key = match outcome { + sasl::Outcome::Success(client_key) => client_key, + sasl::Outcome::Failure(_) => return Ok(outcome), + }; + + trace!("storing cached password"); + + // time to cache, compute the suffix by subtracting the prefix from the hash. + let mut suffix = hash; + pbkdf2::xor_assign(&mut suffix, &prefix); + + pool.cache.insert( + endpoint, + role, + Pbkdf2CacheEntry { + cached_from: secret.cached_at, + suffix, + }, + ); + + Ok(sasl::Outcome::Success(client_key)) +} + +fn validate_pbkdf2(secret: &ServerSecret, hash: &pbkdf2::Block) -> sasl::Outcome { + let client_key = super::ScramKey::client_key(&(*hash).into()); if secret.is_password_invalid(&client_key).into() { - Ok(sasl::Outcome::Failure("password doesn't match")) + sasl::Outcome::Failure("password doesn't match") } else { - Ok(sasl::Outcome::Success(client_key)) + sasl::Outcome::Success(client_key) } } +const CACHED_ROUNDS: u32 = 16; + impl SaslInitial { fn transition( &self, diff --git a/proxy/src/scram/key.rs b/proxy/src/scram/key.rs index fe55ff493b..7dc52fd409 100644 --- a/proxy/src/scram/key.rs +++ b/proxy/src/scram/key.rs @@ -1,6 +1,12 @@ //! Tools for client/server/stored key management. +use hmac::Mac as _; +use sha2::Digest as _; use subtle::ConstantTimeEq; +use zeroize::Zeroize as _; + +use crate::metrics::Metrics; +use crate::scram::pbkdf2::Prf; /// Faithfully taken from PostgreSQL. pub(crate) const SCRAM_KEY_LEN: usize = 32; @@ -14,6 +20,12 @@ pub(crate) struct ScramKey { bytes: [u8; SCRAM_KEY_LEN], } +impl Drop for ScramKey { + fn drop(&mut self) { + self.bytes.zeroize(); + } +} + impl PartialEq for ScramKey { fn eq(&self, other: &Self) -> bool { self.ct_eq(other).into() @@ -28,12 +40,26 @@ impl ConstantTimeEq for ScramKey { impl ScramKey { pub(crate) fn sha256(&self) -> Self { - super::sha256([self.as_ref()]).into() + Metrics::get().proxy.sha_rounds.inc_by(1); + Self { + bytes: sha2::Sha256::digest(self.as_bytes()).into(), + } } pub(crate) fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] { self.bytes } + + pub(crate) fn client_key(b: &[u8; 32]) -> Self { + // Prf::new_from_slice will run 2 sha256 rounds. + // Update + Finalize run 2 sha256 rounds. 
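`validate_pbkdf2` works because of how SCRAM-SHA-256 derives its keys (RFC 5802): the server stores only `StoredKey = SHA-256(HMAC(SaltedPassword, "Client Key"))`, so any candidate salted password can be checked by re-deriving the client key and comparing hashes in constant time. A sketch of that check with stand-in types (not the proxy's `ServerSecret`), assuming the `hmac`, `sha2` and `subtle` crates:

```rust
use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;

/// ClientKey = HMAC(SaltedPassword, "Client Key")
fn client_key(salted_password: &[u8; 32]) -> [u8; 32] {
    let mut prf =
        Hmac::<Sha256>::new_from_slice(salted_password).expect("HMAC accepts any key size");
    prf.update(b"Client Key");
    prf.finalize().into_bytes().into()
}

/// The server never stores the password, only StoredKey = SHA-256(ClientKey).
fn is_password_valid(stored_key: &[u8; 32], candidate_salted_password: &[u8; 32]) -> bool {
    let candidate: [u8; 32] = Sha256::digest(client_key(candidate_salted_password)).into();
    // constant-time comparison so timing does not leak how many bytes matched
    stored_key.as_slice().ct_eq(candidate.as_slice()).into()
}
```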
+ Metrics::get().proxy.sha_rounds.inc_by(4); + + let mut prf = Prf::new_from_slice(b).expect("HMAC is able to accept all key sizes"); + prf.update(b"Client Key"); + let client_key: [u8; 32] = prf.finalize().into_bytes().into(); + client_key.into() + } } impl From<[u8; SCRAM_KEY_LEN]> for ScramKey { diff --git a/proxy/src/scram/mod.rs b/proxy/src/scram/mod.rs index 5f627e062c..04722d920b 100644 --- a/proxy/src/scram/mod.rs +++ b/proxy/src/scram/mod.rs @@ -6,6 +6,7 @@ //! * //! * +mod cache; mod countmin; mod exchange; mod key; @@ -18,10 +19,8 @@ pub mod threadpool; use base64::Engine as _; use base64::prelude::BASE64_STANDARD; pub(crate) use exchange::{Exchange, exchange}; -use hmac::{Hmac, Mac}; pub(crate) use key::ScramKey; pub(crate) use secret::ServerSecret; -use sha2::{Digest, Sha256}; const SCRAM_SHA_256: &str = "SCRAM-SHA-256"; const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS"; @@ -42,29 +41,13 @@ fn base64_decode_array(input: impl AsRef<[u8]>) -> Option<[u8; N Some(bytes) } -/// This function essentially is `Hmac(sha256, key, input)`. -/// Further reading: . -fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator) -> [u8; 32] { - let mut mac = Hmac::::new_from_slice(key).expect("bad key size"); - parts.into_iter().for_each(|s| mac.update(s)); - - mac.finalize().into_bytes().into() -} - -fn sha256<'a>(parts: impl IntoIterator) -> [u8; 32] { - let mut hasher = Sha256::new(); - parts.into_iter().for_each(|s| hasher.update(s)); - - hasher.finalize().into() -} - #[cfg(test)] mod tests { use super::threadpool::ThreadPool; use super::{Exchange, ServerSecret}; - use crate::intern::EndpointIdInt; + use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::sasl::{Mechanism, Step}; - use crate::types::EndpointId; + use crate::types::{EndpointId, RoleName}; #[test] fn snapshot() { @@ -114,23 +97,34 @@ mod tests { ); } - async fn run_round_trip_test(server_password: &str, client_password: &str) { - let pool = ThreadPool::new(1); - + async fn check( + pool: &ThreadPool, + scram_secret: &ServerSecret, + password: &[u8], + ) -> Result<(), &'static str> { let ep = EndpointId::from("foo"); let ep = EndpointIdInt::from(ep); + let role = RoleName::from("user"); + let role = RoleNameInt::from(&role); - let scram_secret = ServerSecret::build(server_password).await.unwrap(); - let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes()) + let outcome = super::exchange(pool, ep, role, scram_secret, password) .await .unwrap(); match outcome { - crate::sasl::Outcome::Success(_) => {} - crate::sasl::Outcome::Failure(r) => panic!("{r}"), + crate::sasl::Outcome::Success(_) => Ok(()), + crate::sasl::Outcome::Failure(r) => Err(r), } } + async fn run_round_trip_test(server_password: &str, client_password: &str) { + let pool = ThreadPool::new(1); + let scram_secret = ServerSecret::build(server_password).await.unwrap(); + check(&pool, &scram_secret, client_password.as_bytes()) + .await + .unwrap(); + } + #[tokio::test] async fn round_trip() { run_round_trip_test("pencil", "pencil").await; @@ -141,4 +135,27 @@ mod tests { async fn failure() { run_round_trip_test("pencil", "eraser").await; } + + #[tokio::test] + #[tracing_test::traced_test] + async fn password_cache() { + let pool = ThreadPool::new(1); + let scram_secret = ServerSecret::build("password").await.unwrap(); + + // wrong passwords are not added to cache + check(&pool, &scram_secret, b"wrong").await.unwrap_err(); + assert!(!logs_contain("storing cached password")); + + // correct passwords get cached + check(&pool, 
&scram_secret, b"password").await.unwrap(); + assert!(logs_contain("storing cached password")); + + // wrong passwords do not match the cache + check(&pool, &scram_secret, b"wrong").await.unwrap_err(); + assert!(!logs_contain("password validated from cache")); + + // correct passwords match the cache + check(&pool, &scram_secret, b"password").await.unwrap(); + assert!(logs_contain("password validated from cache")); + } } diff --git a/proxy/src/scram/pbkdf2.rs b/proxy/src/scram/pbkdf2.rs index 7f48e00c41..1300310de2 100644 --- a/proxy/src/scram/pbkdf2.rs +++ b/proxy/src/scram/pbkdf2.rs @@ -1,25 +1,50 @@ +//! For postgres password authentication, we need to perform a PBKDF2 using +//! PRF=HMAC-SHA2-256, producing only 1 block (32 bytes) of output key. + +use hmac::Mac as _; use hmac::digest::consts::U32; use hmac::digest::generic_array::GenericArray; -use hmac::{Hmac, Mac}; -use sha2::Sha256; +use zeroize::Zeroize as _; + +use crate::metrics::Metrics; + +/// The Psuedo-random function used during PBKDF2 and the SCRAM-SHA-256 handshake. +pub type Prf = hmac::Hmac; +pub(crate) type Block = GenericArray; pub(crate) struct Pbkdf2 { - hmac: Hmac, - prev: GenericArray, - hi: GenericArray, + hmac: Prf, + /// U{r-1} for whatever iteration r we are currently on. + prev: Block, + /// the output of `fold(xor, U{1}..U{r})` for whatever iteration r we are currently on. + hi: Block, + /// number of iterations left iterations: u32, } +impl Drop for Pbkdf2 { + fn drop(&mut self) { + self.prev.zeroize(); + self.hi.zeroize(); + } +} + // inspired from impl Pbkdf2 { - pub(crate) fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self { + pub(crate) fn start(pw: &[u8], salt: &[u8], iterations: u32) -> Self { // key the HMAC and derive the first block in-place - let mut hmac = - Hmac::::new_from_slice(str).expect("HMAC is able to accept all key sizes"); + let mut hmac = Prf::new_from_slice(pw).expect("HMAC is able to accept all key sizes"); + + // U1 = PRF(Password, Salt + INT_32_BE(i)) + // i = 1 since we only need 1 block of output. hmac.update(salt); hmac.update(&1u32.to_be_bytes()); let init_block = hmac.finalize_reset().into_bytes(); + // Prf::new_from_slice will run 2 sha256 rounds. + // Our update + finalize run 2 sha256 rounds for each pbkdf2 round. + Metrics::get().proxy.sha_rounds.inc_by(4); + Self { hmac, // one iteration spent above @@ -33,7 +58,11 @@ impl Pbkdf2 { (self.iterations).clamp(0, 4096) } - pub(crate) fn turn(&mut self) -> std::task::Poll<[u8; 32]> { + /// For "fairness", we implement PBKDF2 with cooperative yielding, which is why we use this `turn` + /// function that only executes a fixed number of iterations before continuing. + /// + /// Task must be rescheuled if this returns [`std::task::Poll::Pending`]. + pub(crate) fn turn(&mut self) -> std::task::Poll { let Self { hmac, prev, @@ -44,25 +73,37 @@ impl Pbkdf2 { // only do up to 4096 iterations per turn for fairness let n = (*iterations).clamp(0, 4096); for _ in 0..n { - hmac.update(prev); - let block = hmac.finalize_reset().into_bytes(); - - for (hi_byte, &b) in hi.iter_mut().zip(block.iter()) { - *hi_byte ^= b; - } - - *prev = block; + let next = single_round(hmac, prev); + xor_assign(hi, &next); + *prev = next; } + // Our update + finalize run 2 sha256 rounds for each pbkdf2 round. 
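`turn` keeps each job to a bounded slice of work so one expensive PBKDF2 cannot monopolize a worker. The driving pattern looks roughly like the following sketch, assuming `tokio` and a toy job in place of the real `Pbkdf2` state:

```rust
use std::task::Poll;

struct Job {
    iterations_left: u32,
}

impl Job {
    /// Do at most 4096 iterations, then report whether the job is finished.
    fn turn(&mut self) -> Poll<u32> {
        let n = self.iterations_left.min(4096);
        self.iterations_left -= n;
        if self.iterations_left == 0 {
            Poll::Ready(n)
        } else {
            Poll::Pending
        }
    }
}

#[tokio::main]
async fn main() {
    let mut job = Job { iterations_left: 60_000 };
    let mut turns = 0u32;
    loop {
        turns += 1;
        match job.turn() {
            Poll::Ready(_) => break,
            // reschedule: let other jobs on this worker run before continuing
            Poll::Pending => tokio::task::yield_now().await,
        }
    }
    println!("finished after {turns} turns"); // 60_000 iterations -> 15 turns
}
```

The real pool wires this into a `Future` implementation (`JobSpec::poll`) rather than an explicit loop, but the fairness idea is the same.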
+ Metrics::get().proxy.sha_rounds.inc_by(2 * n as u64); + *iterations -= n; if *iterations == 0 { - std::task::Poll::Ready((*hi).into()) + std::task::Poll::Ready(*hi) } else { std::task::Poll::Pending } } } +#[inline(always)] +pub fn xor_assign(x: &mut Block, y: &Block) { + for (x, &y) in std::iter::zip(x, y) { + *x ^= y; + } +} + +#[inline(always)] +fn single_round(prf: &mut Prf, ui: &Block) -> Block { + // Ui = PRF(Password, Ui-1) + prf.update(ui); + prf.finalize_reset().into_bytes() +} + #[cfg(test)] mod tests { use pbkdf2::pbkdf2_hmac_array; @@ -76,11 +117,11 @@ mod tests { let pass = b"Ne0n_!5_50_C007"; let mut job = Pbkdf2::start(pass, salt, 60000); - let hash = loop { + let hash: [u8; 32] = loop { let std::task::Poll::Ready(hash) = job.turn() else { continue; }; - break hash; + break hash.into(); }; let expected = pbkdf2_hmac_array::(pass, salt, 60000); diff --git a/proxy/src/scram/secret.rs b/proxy/src/scram/secret.rs index 0e070c2f27..a3a64f271c 100644 --- a/proxy/src/scram/secret.rs +++ b/proxy/src/scram/secret.rs @@ -3,6 +3,7 @@ use base64::Engine as _; use base64::prelude::BASE64_STANDARD; use subtle::{Choice, ConstantTimeEq}; +use tokio::time::Instant; use super::base64_decode_array; use super::key::ScramKey; @@ -11,6 +12,9 @@ use super::key::ScramKey; /// and is used throughout the authentication process. #[derive(Clone, Eq, PartialEq, Debug)] pub(crate) struct ServerSecret { + /// When this secret was cached. + pub(crate) cached_at: Instant, + /// Number of iterations for `PBKDF2` function. pub(crate) iterations: u32, /// Salt used to hash user's password. @@ -34,6 +38,7 @@ impl ServerSecret { params.split_once(':').zip(keys.split_once(':'))?; let secret = ServerSecret { + cached_at: Instant::now(), iterations: iterations.parse().ok()?, salt_base64: salt.into(), stored_key: base64_decode_array(stored_key)?.into(), @@ -54,6 +59,7 @@ impl ServerSecret { /// See `auth-scram.c : mock_scram_secret` for details. pub(crate) fn mock(nonce: [u8; 32]) -> Self { Self { + cached_at: Instant::now(), // this doesn't reveal much information as we're going to use // iteration count 1 for our generated passwords going forward. // PG16 users can set iteration count=1 already today. diff --git a/proxy/src/scram/signature.rs b/proxy/src/scram/signature.rs index a5b1c3e9f4..8e074272b6 100644 --- a/proxy/src/scram/signature.rs +++ b/proxy/src/scram/signature.rs @@ -1,6 +1,10 @@ //! Tools for client/server signature management. +use hmac::Mac as _; + use super::key::{SCRAM_KEY_LEN, ScramKey}; +use crate::metrics::Metrics; +use crate::scram::pbkdf2::Prf; /// A collection of message parts needed to derive the client's signature. #[derive(Debug)] @@ -12,15 +16,18 @@ pub(crate) struct SignatureBuilder<'a> { impl SignatureBuilder<'_> { pub(crate) fn build(&self, key: &ScramKey) -> Signature { - let parts = [ - self.client_first_message_bare.as_bytes(), - b",", - self.server_first_message.as_bytes(), - b",", - self.client_final_message_without_proof.as_bytes(), - ]; + // don't know exactly. 
this is a rough approx + Metrics::get().proxy.sha_rounds.inc_by(8); - super::hmac_sha256(key.as_ref(), parts).into() + let mut mac = Prf::new_from_slice(key.as_ref()).expect("HMAC accepts all key sizes"); + mac.update(self.client_first_message_bare.as_bytes()); + mac.update(b","); + mac.update(self.server_first_message.as_bytes()); + mac.update(b","); + mac.update(self.client_final_message_without_proof.as_bytes()); + Signature { + bytes: mac.finalize().into_bytes().into(), + } } } diff --git a/proxy/src/scram/threadpool.rs b/proxy/src/scram/threadpool.rs index ea2e29ede9..20a1df2b53 100644 --- a/proxy/src/scram/threadpool.rs +++ b/proxy/src/scram/threadpool.rs @@ -15,6 +15,8 @@ use futures::FutureExt; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; +use super::cache::Pbkdf2Cache; +use super::pbkdf2; use super::pbkdf2::Pbkdf2; use crate::intern::EndpointIdInt; use crate::metrics::{ThreadPoolMetrics, ThreadPoolWorkerId}; @@ -23,6 +25,10 @@ use crate::scram::countmin::CountMinSketch; pub struct ThreadPool { runtime: Option, pub metrics: Arc, + + // we hash a lot of passwords. + // we keep a cache of partial hashes for faster validation. + pub(super) cache: Pbkdf2Cache, } /// How often to reset the sketch values @@ -68,6 +74,7 @@ impl ThreadPool { Self { runtime: Some(runtime), metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)), + cache: Pbkdf2Cache::new(), } }) } @@ -130,7 +137,7 @@ struct JobSpec { } impl Future for JobSpec { - type Output = [u8; 32]; + type Output = pbkdf2::Block; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { STATE.with_borrow_mut(|state| { @@ -166,10 +173,10 @@ impl Future for JobSpec { } } -pub(crate) struct JobHandle(tokio::task::JoinHandle<[u8; 32]>); +pub(crate) struct JobHandle(tokio::task::JoinHandle); impl Future for JobHandle { - type Output = [u8; 32]; + type Output = pbkdf2::Block; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.0.poll_unpin(cx) { @@ -203,10 +210,10 @@ mod tests { .spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096)) .await; - let expected = [ + let expected = &[ 10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242, 178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140, ]; - assert_eq!(actual, expected); + assert_eq!(actual.as_slice(), expected); } } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 59e4b09bc9..5b356c8460 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -1,46 +1,41 @@ -use std::io; -use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use std::time::Duration; -use async_trait::async_trait; use ed25519_dalek::SigningKey; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use jose_jwk::jose_b64; -use postgres_client::config::SslMode; +use postgres_client::error::SqlState; +use postgres_client::maybe_tls_stream::MaybeTlsStream; use rand_core::OsRng; -use rustls::pki_types::{DnsName, ServerName}; -use tokio::net::{TcpStream, lookup_host}; -use tokio_rustls::TlsConnector; use tracing::field::display; use tracing::{debug, info}; use super::AsyncRW; use super::conn_pool::poll_client; use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool}; -use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client}; +use super::http_conn_pool::{self, HttpConnPool, LocalProxyClient, poll_http2_client}; use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool}; use 
crate::auth::backend::local::StaticAuthRules; -use crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo}; +use crate::auth::backend::{ComputeCredentials, ComputeUserInfo}; use crate::auth::{self, AuthError}; +use crate::compute; use crate::compute_ctl::{ ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest, }; -use crate::config::{ComputeConfig, ProxyConfig}; +use crate::config::ProxyConfig; use crate::context::RequestContext; -use crate::control_plane::CachedNodeInfo; use crate::control_plane::client::ApiLockError; use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; -use crate::control_plane::locks::ApiLocks; use crate::error::{ErrorKind, ReportableError, UserFacingError}; -use crate::intern::EndpointIdInt; -use crate::proxy::connect_compute::ConnectMechanism; -use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute}; +use crate::intern::{EndpointIdInt, RoleNameInt}; +use crate::pqproto::StartupMessageParams; +use crate::proxy::{connect_auth, connect_compute}; use crate::rate_limiter::EndpointRateLimiter; -use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX}; +use crate::types::{EndpointId, LOCAL_PROXY_SUFFIX}; pub(crate) struct PoolingBackend { - pub(crate) http_conn_pool: Arc>>, + pub(crate) http_conn_pool: + Arc>>, pub(crate) local_pool: Arc>, pub(crate) pool: Arc>>, @@ -82,9 +77,11 @@ impl PoolingBackend { }; let ep = EndpointIdInt::from(&user_info.endpoint); + let role = RoleNameInt::from(&user_info.user); let auth_outcome = crate::auth::validate_password_and_exchange( - &self.config.authentication_config.thread_pool, + &self.config.authentication_config.scram_thread_pool, ep, + role, password, secret, ) @@ -185,20 +182,42 @@ impl PoolingBackend { tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); let backend = self.auth_backend.as_ref().map(|()| keys.info); - crate::proxy::connect_compute::connect_to_compute( + + let mut params = StartupMessageParams::default(); + params.insert("database", &conn_info.dbname); + params.insert("user", &conn_info.user_info.user); + + let mut auth_info = compute::AuthInfo::with_auth_keys(keys.keys); + auth_info.set_startup_params(¶ms, true); + + let node = connect_auth::connect_to_compute_and_auth( ctx, - &TokioMechanism { - conn_id, - conn_info, - pool: self.pool.clone(), - locks: &self.config.connect_compute_locks, - keys: keys.keys, - }, + self.config, &backend, - self.config.wake_compute_retry_config, - &self.config.connect_to_compute, + auth_info, + connect_compute::TlsNegotiation::Postgres, ) - .await + .await?; + + let (client, connection) = postgres_client::connect::managed( + node.stream, + Some(node.socket_addr.ip()), + postgres_client::config::Host::Tcp(node.hostname.to_string()), + node.socket_addr.port(), + node.ssl_mode, + Some(self.config.connect_to_compute.timeout), + ) + .await?; + + Ok(poll_client( + self.pool.clone(), + ctx, + conn_info, + client, + connection, + conn_id, + node.aux, + )) } // Wake up the destination if needed @@ -210,7 +229,7 @@ impl PoolingBackend { &self, ctx: &RequestContext, conn_info: ConnInfo, - ) -> Result, HttpConnError> { + ) -> Result, HttpConnError> { debug!("pool: looking for an existing connection"); if let Ok(Some(client)) = self.http_conn_pool.get(ctx, &conn_info) { return Ok(client); @@ -227,19 +246,38 @@ impl PoolingBackend { )), options: conn_info.user_info.options.clone(), }); - crate::proxy::connect_compute::connect_to_compute( + + let node = 
connect_compute::connect_to_compute( ctx, - &HyperMechanism { - conn_id, - conn_info, - pool: self.http_conn_pool.clone(), - locks: &self.config.connect_compute_locks, - }, + self.config, &backend, - self.config.wake_compute_retry_config, - &self.config.connect_to_compute, + connect_compute::TlsNegotiation::Direct, ) - .await + .await?; + + let stream = match node.stream.into_framed().into_inner() { + MaybeTlsStream::Raw(s) => Box::pin(s) as AsyncRW, + MaybeTlsStream::Tls(s) => Box::pin(s) as AsyncRW, + }; + + let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()) + .timer(TokioTimer::new()) + .keep_alive_interval(Duration::from_secs(20)) + .keep_alive_while_idle(true) + .keep_alive_timeout(Duration::from_secs(5)) + .handshake(TokioIo::new(stream)) + .await + .map_err(LocalProxyConnError::H2)?; + + Ok(poll_http2_client( + self.http_conn_pool.clone(), + ctx, + &conn_info, + client, + connection, + conn_id, + node.aux.clone(), + )) } /// Connect to postgres over localhost. @@ -379,6 +417,8 @@ fn create_random_jwk() -> (SigningKey, jose_jwk::Key) { pub(crate) enum HttpConnError { #[error("pooled connection closed at inconsistent state")] ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError), + #[error("could not connect to compute")] + ConnectError(#[from] compute::ConnectionError), #[error("could not connect to postgres in compute")] PostgresConnectionError(#[from] postgres_client::Error), #[error("could not connect to local-proxy in compute")] @@ -398,10 +438,19 @@ pub(crate) enum HttpConnError { TooManyConnectionAttempts(#[from] ApiLockError), } +impl From for HttpConnError { + fn from(value: connect_auth::AuthError) -> Self { + match value { + connect_auth::AuthError::Auth(compute::PostgresError::Postgres(error)) => { + Self::PostgresConnectionError(error) + } + connect_auth::AuthError::Connect(error) => Self::ConnectError(error), + } + } +} + #[derive(Debug, thiserror::Error)] pub(crate) enum LocalProxyConnError { - #[error("error with connection to local-proxy")] - Io(#[source] std::io::Error), #[error("could not establish h2 connection")] H2(#[from] hyper::Error), } @@ -409,16 +458,16 @@ pub(crate) enum LocalProxyConnError { impl ReportableError for HttpConnError { fn get_error_kind(&self) -> ErrorKind { match self { + HttpConnError::ConnectError(e) => e.get_error_kind(), HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute, - HttpConnError::PostgresConnectionError(p) => { - if p.as_db_error().is_some() { - // postgres rejected the connection - ErrorKind::Postgres - } else { - // couldn't even reach postgres - ErrorKind::Compute - } - } + HttpConnError::PostgresConnectionError(p) => match p.as_db_error() { + // user provided a wrong database name + Some(err) if err.code() == &SqlState::INVALID_CATALOG_NAME => ErrorKind::User, + // postgres rejected the connection + Some(_) => ErrorKind::Postgres, + // couldn't even reach postgres + None => ErrorKind::Compute, + }, HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute, HttpConnError::ComputeCtl(_) => ErrorKind::Service, HttpConnError::JwtPayloadError(_) => ErrorKind::User, @@ -433,6 +482,7 @@ impl ReportableError for HttpConnError { impl UserFacingError for HttpConnError { fn to_string_client(&self) -> String { match self { + HttpConnError::ConnectError(p) => p.to_string_client(), HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(), HttpConnError::PostgresConnectionError(p) => p.to_string(), HttpConnError::LocalProxyConnectionError(p) => 
p.to_string(), @@ -448,36 +498,9 @@ impl UserFacingError for HttpConnError { } } -impl CouldRetry for HttpConnError { - fn could_retry(&self) -> bool { - match self { - HttpConnError::PostgresConnectionError(e) => e.could_retry(), - HttpConnError::LocalProxyConnectionError(e) => e.could_retry(), - HttpConnError::ComputeCtl(_) => false, - HttpConnError::ConnectionClosedAbruptly(_) => false, - HttpConnError::JwtPayloadError(_) => false, - HttpConnError::GetAuthInfo(_) => false, - HttpConnError::AuthError(_) => false, - HttpConnError::WakeCompute(_) => false, - HttpConnError::TooManyConnectionAttempts(_) => false, - } - } -} -impl ShouldRetryWakeCompute for HttpConnError { - fn should_retry_wake_compute(&self) -> bool { - match self { - HttpConnError::PostgresConnectionError(e) => e.should_retry_wake_compute(), - // we never checked cache validity - HttpConnError::TooManyConnectionAttempts(_) => false, - _ => true, - } - } -} - impl ReportableError for LocalProxyConnError { fn get_error_kind(&self) -> ErrorKind { match self { - LocalProxyConnError::Io(_) => ErrorKind::Compute, LocalProxyConnError::H2(_) => ErrorKind::Compute, } } @@ -488,209 +511,3 @@ impl UserFacingError for LocalProxyConnError { "Could not establish HTTP connection to the database".to_string() } } - -impl CouldRetry for LocalProxyConnError { - fn could_retry(&self) -> bool { - match self { - LocalProxyConnError::Io(_) => false, - LocalProxyConnError::H2(_) => false, - } - } -} -impl ShouldRetryWakeCompute for LocalProxyConnError { - fn should_retry_wake_compute(&self) -> bool { - match self { - LocalProxyConnError::Io(_) => false, - LocalProxyConnError::H2(_) => false, - } - } -} - -struct TokioMechanism { - pool: Arc>>, - conn_info: ConnInfo, - conn_id: uuid::Uuid, - keys: ComputeCredentialKeys, - - /// connect_to_compute concurrency lock - locks: &'static ApiLocks, -} - -#[async_trait] -impl ConnectMechanism for TokioMechanism { - type Connection = Client; - type ConnectError = HttpConnError; - type Error = HttpConnError; - - async fn connect_once( - &self, - ctx: &RequestContext, - node_info: &CachedNodeInfo, - compute_config: &ComputeConfig, - ) -> Result { - let permit = self.locks.get_permit(&node_info.conn_info.host).await?; - - let mut config = node_info.conn_info.to_postgres_client_config(); - let config = config - .user(&self.conn_info.user_info.user) - .dbname(&self.conn_info.dbname) - .connect_timeout(compute_config.timeout); - - if let ComputeCredentialKeys::AuthKeys(auth_keys) = self.keys { - config.auth_keys(auth_keys); - } - - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let res = config.connect(compute_config).await; - drop(pause); - let (client, connection) = permit.release_result(res)?; - - tracing::Span::current().record("pid", tracing::field::display(client.get_process_id())); - tracing::Span::current().record( - "compute_id", - tracing::field::display(&node_info.aux.compute_id), - ); - - if let Some(query_id) = ctx.get_testodrome_id() { - info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id); - } - - Ok(poll_client( - self.pool.clone(), - ctx, - self.conn_info.clone(), - client, - connection, - self.conn_id, - node_info.aux.clone(), - )) - } -} - -struct HyperMechanism { - pool: Arc>>, - conn_info: ConnInfo, - conn_id: uuid::Uuid, - - /// connect_to_compute concurrency lock - locks: &'static ApiLocks, -} - -#[async_trait] -impl ConnectMechanism for HyperMechanism { - type Connection = http_conn_pool::Client; - type ConnectError = HttpConnError; - type Error = 
HttpConnError; - - async fn connect_once( - &self, - ctx: &RequestContext, - node_info: &CachedNodeInfo, - config: &ComputeConfig, - ) -> Result { - let host_addr = node_info.conn_info.host_addr; - let host = &node_info.conn_info.host; - let permit = self.locks.get_permit(host).await?; - - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - - let tls = if node_info.conn_info.ssl_mode == SslMode::Disable { - None - } else { - Some(&config.tls) - }; - - let port = node_info.conn_info.port; - let res = connect_http2(host_addr, host, port, config.timeout, tls).await; - drop(pause); - let (client, connection) = permit.release_result(res)?; - - tracing::Span::current().record( - "compute_id", - tracing::field::display(&node_info.aux.compute_id), - ); - - if let Some(query_id) = ctx.get_testodrome_id() { - info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id); - } - - Ok(poll_http2_client( - self.pool.clone(), - ctx, - &self.conn_info, - client, - connection, - self.conn_id, - node_info.aux.clone(), - )) - } -} - -async fn connect_http2( - host_addr: Option, - host: &str, - port: u16, - timeout: Duration, - tls: Option<&Arc>, -) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> { - let addrs = match host_addr { - Some(addr) => vec![SocketAddr::new(addr, port)], - None => lookup_host((host, port)) - .await - .map_err(LocalProxyConnError::Io)? - .collect(), - }; - let mut last_err = None; - - let mut addrs = addrs.into_iter(); - let stream = loop { - let Some(addr) = addrs.next() else { - return Err(last_err.unwrap_or_else(|| { - LocalProxyConnError::Io(io::Error::new( - io::ErrorKind::InvalidInput, - "could not resolve any addresses", - )) - })); - }; - - match tokio::time::timeout(timeout, TcpStream::connect(addr)).await { - Ok(Ok(stream)) => { - stream.set_nodelay(true).map_err(LocalProxyConnError::Io)?; - break stream; - } - Ok(Err(e)) => { - last_err = Some(LocalProxyConnError::Io(e)); - } - Err(e) => { - last_err = Some(LocalProxyConnError::Io(io::Error::new( - io::ErrorKind::TimedOut, - e, - ))); - } - } - }; - - let stream = if let Some(tls) = tls { - let host = DnsName::try_from(host) - .map_err(io::Error::other) - .map_err(LocalProxyConnError::Io)? 
- .to_owned(); - let stream = TlsConnector::from(tls.clone()) - .connect(ServerName::DnsName(host), stream) - .await - .map_err(LocalProxyConnError::Io)?; - Box::pin(stream) as AsyncRW - } else { - Box::pin(stream) as AsyncRW - }; - - let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()) - .timer(TokioTimer::new()) - .keep_alive_interval(Duration::from_secs(20)) - .keep_alive_while_idle(true) - .keep_alive_timeout(Duration::from_secs(5)) - .handshake(TokioIo::new(stream)) - .await?; - - Ok((client, connection)) -} diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 015c46f787..17305e30f1 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -190,6 +190,9 @@ mod tests { fn get_process_id(&self) -> i32 { 0 } + fn reset(&mut self) -> Result<(), postgres_client::Error> { + Ok(()) + } } fn create_inner() -> ClientInnerCommon { diff --git a/proxy/src/serverless/conn_pool_lib.rs b/proxy/src/serverless/conn_pool_lib.rs index ed5cc0ea03..6adca49723 100644 --- a/proxy/src/serverless/conn_pool_lib.rs +++ b/proxy/src/serverless/conn_pool_lib.rs @@ -7,10 +7,9 @@ use std::time::Duration; use clashmap::ClashMap; use parking_lot::RwLock; -use postgres_client::ReadyForQueryStatus; use rand::Rng; use smol_str::ToSmolStr; -use tracing::{Span, debug, info}; +use tracing::{Span, debug, info, warn}; use super::backend::HttpConnError; use super::conn_pool::ClientDataRemote; @@ -188,7 +187,7 @@ impl EndpointConnPool { self.pools.get_mut(&db_user) } - pub(crate) fn put(pool: &RwLock, conn_info: &ConnInfo, client: ClientInnerCommon) { + pub(crate) fn put(pool: &RwLock, conn_info: &ConnInfo, mut client: ClientInnerCommon) { let conn_id = client.get_conn_id(); let (max_conn, conn_count, pool_name) = { let pool = pool.read(); @@ -201,12 +200,17 @@ impl EndpointConnPool { }; if client.inner.is_closed() { - info!(%conn_id, "{}: throwing away connection '{conn_info}' because connection is closed", pool_name); + info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because connection is closed"); + return; + } + + if let Err(error) = client.inner.reset() { + warn!(?error, %conn_id, "{pool_name}: throwing away connection '{conn_info}' because connection could not be reset"); return; } if conn_count >= max_conn { - info!(%conn_id, "{}: throwing away connection '{conn_info}' because pool is full", pool_name); + info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because pool is full"); return; } @@ -691,6 +695,7 @@ impl Deref for Client { pub(crate) trait ClientInnerExt: Sync + Send + 'static { fn is_closed(&self) -> bool; fn get_process_id(&self) -> i32; + fn reset(&mut self) -> Result<(), postgres_client::Error>; } impl ClientInnerExt for postgres_client::Client { @@ -701,15 +706,13 @@ impl ClientInnerExt for postgres_client::Client { fn get_process_id(&self) -> i32 { self.get_process_id() } + + fn reset(&mut self) -> Result<(), postgres_client::Error> { + self.reset_session_background() + } } impl Discard<'_, C> { - pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) { - let conn_info = &self.conn_info; - if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 { - info!("pool: throwing away connection '{conn_info}' because connection is not idle"); - } - } pub(crate) fn discard(&mut self) { let conn_info = &self.conn_info; if std::mem::take(self.pool).strong_count() > 0 { diff --git a/proxy/src/serverless/http_conn_pool.rs 
b/proxy/src/serverless/http_conn_pool.rs index 7acd816026..bf6b934d20 100644 --- a/proxy/src/serverless/http_conn_pool.rs +++ b/proxy/src/serverless/http_conn_pool.rs @@ -23,8 +23,8 @@ use crate::protocol2::ConnectionInfoExtra; use crate::types::EndpointCacheKey; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; -pub(crate) type Send = http2::SendRequest>; -pub(crate) type Connect = +pub(crate) type LocalProxyClient = http2::SendRequest>; +pub(crate) type LocalProxyConnection = http2::Connection, BoxBody, TokioExecutor>; #[derive(Clone)] @@ -189,14 +189,14 @@ impl GlobalConnPool> { } pub(crate) fn poll_http2_client( - global_pool: Arc>>, + global_pool: Arc>>, ctx: &RequestContext, conn_info: &ConnInfo, - client: Send, - connection: Connect, + client: LocalProxyClient, + connection: LocalProxyConnection, conn_id: uuid::Uuid, aux: MetricsAuxInfo, -) -> Client { +) -> Client { let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol()); let session_id = ctx.session_id(); @@ -285,7 +285,7 @@ impl Client { } } -impl ClientInnerExt for Send { +impl ClientInnerExt for LocalProxyClient { fn is_closed(&self) -> bool { self.is_closed() } @@ -294,4 +294,10 @@ impl ClientInnerExt for Send { // ideally throw something meaningful -1 } + + fn reset(&mut self) -> Result<(), postgres_client::Error> { + // We use HTTP/2.0 to talk to local proxy. HTTP is stateless, + // so there's nothing to reset. + Ok(()) + } } diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs index f63d84d66b..b8a502c37e 100644 --- a/proxy/src/serverless/local_conn_pool.rs +++ b/proxy/src/serverless/local_conn_pool.rs @@ -269,11 +269,6 @@ impl ClientInnerCommon { local_data.jti += 1; let token = resign_jwt(&local_data.key, payload, local_data.jti)?; - self.inner - .discard_all() - .await - .map_err(SqlOverHttpError::InternalPostgres)?; - // initiates the auth session // this is safe from query injections as the jwt format free of any escape characters. 
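With `check_idle` gone, the pool's return path now hinges on `ClientInnerExt::reset`: a connection is kept only if it is still open and its session state can be reset (a no-op for the stateless HTTP/2 channel to local-proxy). The policy, as a small sketch with stand-in types rather than the proxy's pool:

```rust
/// Stand-in for the pool's ClientInnerExt: open/closed check plus best-effort
/// session reset.
trait PooledConn {
    fn is_closed(&self) -> bool;
    fn reset(&mut self) -> Result<(), String>;
}

/// Return a connection to the pool only if the next caller can safely reuse it.
fn try_return<C: PooledConn>(pool: &mut Vec<C>, max_conns: usize, mut conn: C) {
    if conn.is_closed() {
        eprintln!("throwing away connection: closed");
        return;
    }
    if let Err(error) = conn.reset() {
        eprintln!("throwing away connection: could not be reset: {error}");
        return;
    }
    if pool.len() >= max_conns {
        eprintln!("throwing away connection: pool is full");
        return;
    }
    pool.push(conn);
}
```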
let query = format!("select auth.jwt_session_init('{token}')"); diff --git a/proxy/src/serverless/rest.rs b/proxy/src/serverless/rest.rs index 173c2629f7..0c3d2c958d 100644 --- a/proxy/src/serverless/rest.rs +++ b/proxy/src/serverless/rest.rs @@ -1,5 +1,6 @@ use std::borrow::Cow; use std::collections::HashMap; +use std::convert::Infallible; use std::sync::Arc; use bytes::Bytes; @@ -12,6 +13,7 @@ use hyper::body::Incoming; use hyper::http::{HeaderName, HeaderValue}; use hyper::{Request, Response, StatusCode}; use indexmap::IndexMap; +use moka::sync::Cache; use ouroboros::self_referencing; use serde::de::DeserializeOwned; use serde::{Deserialize, Deserializer}; @@ -46,19 +48,19 @@ use super::backend::{HttpConnError, LocalProxyConnError, PoolingBackend}; use super::conn_pool::AuthData; use super::conn_pool_lib::ConnInfo; use super::error::{ConnInfoError, Credentials, HttpCodeError, ReadPayloadError}; -use super::http_conn_pool::{self, Send}; +use super::http_conn_pool::{self, LocalProxyClient}; use super::http_util::{ ALLOW_POOL, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_ISOLATION_LEVEL, TXN_READ_ONLY, get_conn_info, json_response, uuid_to_header_value, }; use super::json::JsonConversionError; use crate::auth::backend::ComputeCredentialKeys; -use crate::cache::{Cached, TimedLru}; +use crate::cache::common::{count_cache_insert, count_cache_outcome, eviction_listener}; use crate::config::ProxyConfig; use crate::context::RequestContext; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::http::read_body_with_limit; -use crate::metrics::Metrics; +use crate::metrics::{CacheKind, Metrics}; use crate::serverless::sql_over_http::HEADER_VALUE_TRUE; use crate::types::EndpointCacheKey; use crate::util::deserialize_json_string; @@ -138,19 +140,43 @@ pub struct ApiConfig { } // The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint -pub(crate) type DbSchemaCache = TimedLru>; +pub(crate) struct DbSchemaCache(Cache>); impl DbSchemaCache { + pub fn new(config: crate::config::CacheOptions) -> Self { + let builder = Cache::builder().name("schema"); + let builder = config.moka(builder); + + let metrics = &Metrics::get().cache; + if let Some(size) = config.size { + metrics.capacity.set(CacheKind::Schema, size as i64); + } + + let builder = + builder.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::Schema, cause)); + + Self(builder.build()) + } + + pub async fn maintain(&self) -> Result { + let mut ticker = tokio::time::interval(std::time::Duration::from_secs(60)); + loop { + ticker.tick().await; + self.0.run_pending_tasks(); + } + } + pub async fn get_cached_or_remote( &self, endpoint_id: &EndpointCacheKey, auth_header: &HeaderValue, connection_string: &str, - client: &mut http_conn_pool::Client, + client: &mut http_conn_pool::Client, ctx: &RequestContext, config: &'static ProxyConfig, ) -> Result, RestError> { - match self.get_with_created_at(endpoint_id) { - Some(Cached { value: (v, _), .. 
}) => Ok(v), + let cache_result = count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id)); + match cache_result { + Some(v) => Ok(v), None => { info!("db_schema cache miss for endpoint: {:?}", endpoint_id); let remote_value = self @@ -173,7 +199,8 @@ impl DbSchemaCache { db_extra_search_path: None, }; let value = Arc::new((api_config, schema_owned)); - self.insert(endpoint_id.clone(), value); + count_cache_insert(CacheKind::Schema); + self.0.insert(endpoint_id.clone(), value); return Err(e); } Err(e) => { @@ -181,7 +208,8 @@ impl DbSchemaCache { } }; let value = Arc::new((api_config, schema_owned)); - self.insert(endpoint_id.clone(), value.clone()); + count_cache_insert(CacheKind::Schema); + self.0.insert(endpoint_id.clone(), value.clone()); Ok(value) } } @@ -190,7 +218,7 @@ impl DbSchemaCache { &self, auth_header: &HeaderValue, connection_string: &str, - client: &mut http_conn_pool::Client, + client: &mut http_conn_pool::Client, ctx: &RequestContext, config: &'static ProxyConfig, ) -> Result<(ApiConfig, DbSchemaOwned), RestError> { @@ -430,7 +458,7 @@ struct BatchQueryData<'a> { } async fn make_local_proxy_request( - client: &mut http_conn_pool::Client, + client: &mut http_conn_pool::Client, headers: impl IntoIterator, body: QueryData<'_>, max_len: usize, @@ -461,7 +489,7 @@ async fn make_local_proxy_request( } async fn make_raw_local_proxy_request( - client: &mut http_conn_pool::Client, + client: &mut http_conn_pool::Client, headers: impl IntoIterator, body: String, ) -> Result, RestError> { diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index f254b41b5b..c334e820d7 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -192,34 +192,29 @@ pub(crate) async fn handle( let line = get(db_error, |db| db.line().map(|l| l.to_string())); let routine = get(db_error, |db| db.routine()); - match &e { - SqlOverHttpError::Postgres(e) - if e.as_db_error().is_some() && error_kind == ErrorKind::User => - { - // this error contains too much info, and it's not an error we care about. - if tracing::enabled!(Level::DEBUG) { - tracing::debug!( - kind=error_kind.to_metric_label(), - error=%e, - msg=message, - "forwarding error to user" - ); - } else { - tracing::info!( - kind = error_kind.to_metric_label(), - error = "bad query", - "forwarding error to user" - ); - } - } - _ => { - tracing::info!( + if db_error.is_some() && error_kind == ErrorKind::User { + // this error contains too much info, and it's not an error we care about. + if tracing::enabled!(Level::DEBUG) { + debug!( kind=error_kind.to_metric_label(), error=%e, msg=message, "forwarding error to user" ); + } else { + info!( + kind = error_kind.to_metric_label(), + error = "bad query", + "forwarding error to user" + ); } + } else { + info!( + kind=error_kind.to_metric_label(), + error=%e, + msg=message, + "forwarding error to user" + ); } json_response( @@ -735,9 +730,7 @@ impl QueryData { match batch_result { // The query successfully completed. 
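The logging change above boils down to a redaction rule: user-caused database errors are logged verbatim only at DEBUG and are reduced to a fixed "bad query" marker at INFO, while every other error keeps its full message. A sketch of that rule, assuming the `tracing` crate with a plain string error as a stand-in:

```rust
use tracing::{debug, info, Level};

fn log_forwarded_error(kind: &str, is_user_db_error: bool, error: &str) {
    if is_user_db_error {
        if tracing::enabled!(Level::DEBUG) {
            // full detail only when someone explicitly turned on DEBUG
            debug!(kind, error, "forwarding error to user");
        } else {
            // redact: the raw message may contain user data
            info!(kind, error = "bad query", "forwarding error to user");
        }
    } else {
        info!(kind, error, "forwarding error to user");
    }
}
```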
- Ok(status) => { - discard.check_idle(status); - + Ok(_) => { let json_output = String::from_utf8(json_buf).expect("json should be valid utf8"); Ok(json_output) } @@ -793,7 +786,7 @@ impl BatchQueryData { { Ok(json_output) => { info!("commit"); - let status = transaction + transaction .commit() .await .inspect_err(|_| { @@ -802,7 +795,6 @@ impl BatchQueryData { discard.discard(); }) .map_err(SqlOverHttpError::Postgres)?; - discard.check_idle(status); json_output } Err(SqlOverHttpError::Cancelled(_)) => { @@ -815,17 +807,6 @@ impl BatchQueryData { return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres)); } Err(err) => { - info!("rollback"); - let status = transaction - .rollback() - .await - .inspect_err(|_| { - // if we cannot rollback - for now don't return connection to pool - // TODO: get a query status from the error - discard.discard(); - }) - .map_err(SqlOverHttpError::Postgres)?; - discard.check_idle(status); return Err(err); } }; @@ -1012,12 +993,6 @@ impl Client { } impl Discard<'_> { - fn check_idle(&mut self, status: ReadyForQueryStatus) { - match self { - Discard::Remote(discard) => discard.check_idle(status), - Discard::Local(discard) => discard.check_idle(status), - } - } fn discard(&mut self) { match self { Discard::Remote(discard) => discard.discard(), diff --git a/proxy/src/signals.rs b/proxy/src/signals.rs index 32b2344a1c..63ef7d1061 100644 --- a/proxy/src/signals.rs +++ b/proxy/src/signals.rs @@ -4,6 +4,8 @@ use anyhow::bail; use tokio_util::sync::CancellationToken; use tracing::{info, warn}; +use crate::metrics::{Metrics, ServiceInfo}; + /// Handle unix signals appropriately. pub async fn handle( token: CancellationToken, @@ -28,10 +30,12 @@ where // Shut down the whole application. _ = interrupt.recv() => { warn!("received SIGINT, exiting immediately"); + Metrics::get().service.info.set_label(ServiceInfo::terminating()); bail!("interrupted"); } _ = terminate.recv() => { warn!("received SIGTERM, shutting down once all existing connections have closed"); + Metrics::get().service.info.set_label(ServiceInfo::terminating()); token.cancel(); } } diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index 4e55654515..9447b9623b 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -102,7 +102,7 @@ pub struct ReportedError { } impl ReportedError { - pub fn new(e: (impl UserFacingError + Into)) -> Self { + pub fn new(e: impl UserFacingError + Into) -> Self { let error_kind = e.get_error_kind(); Self { source: e.into(), @@ -154,6 +154,15 @@ impl PqStream { message.write_message(&mut self.write); } + /// Write the buffer to the socket until we have some more space again. + pub async fn write_if_full(&mut self) -> io::Result<()> { + while self.write.occupied_len() > 2048 { + self.stream.write_buf(&mut self.write).await?; + } + + Ok(()) + } + /// Flush the output buffer into the underlying stream. /// /// This is cancel safe. 
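`write_if_full` adds simple backpressure to the buffered protocol writer: once the pending output grows past a small threshold, it is drained to the socket before more messages are queued, so a slow client cannot force unbounded buffering. A sketch of the idea, assuming `tokio` and `bytes` with a plain `BytesMut` standing in for the proxy's internal write buffer:

```rust
use bytes::BytesMut;
use tokio::io::{AsyncWrite, AsyncWriteExt};

const SOFT_LIMIT: usize = 2048;

async fn write_if_full<S: AsyncWrite + Unpin>(
    stream: &mut S,
    buf: &mut BytesMut,
) -> std::io::Result<()> {
    // drain the buffer below the soft limit before queueing more messages
    while buf.len() > SOFT_LIMIT {
        stream.write_buf(buf).await?;
    }
    Ok(())
}
```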
diff --git a/safekeeper/src/copy_timeline.rs b/safekeeper/src/copy_timeline.rs index 7984c2e2b9..1ab6246206 100644 --- a/safekeeper/src/copy_timeline.rs +++ b/safekeeper/src/copy_timeline.rs @@ -161,9 +161,9 @@ pub async fn handle_request( FileStorage::create_new(&tli_dir_path, new_state.clone(), conf.no_sync).await?; // now we have a ready timeline in a temp directory - validate_temp_timeline(conf, request.destination_ttid, &tli_dir_path).await?; + validate_temp_timeline(conf, request.destination_ttid, &tli_dir_path, None).await?; global_timelines - .load_temp_timeline(request.destination_ttid, &tli_dir_path, true) + .load_temp_timeline(request.destination_ttid, &tli_dir_path, None) .await?; Ok(()) diff --git a/safekeeper/src/hadron.rs b/safekeeper/src/hadron.rs index 8c6a912166..72b377fcc4 100644 --- a/safekeeper/src/hadron.rs +++ b/safekeeper/src/hadron.rs @@ -193,7 +193,7 @@ pub async fn hcc_pull_timeline( tenant_id: timeline.tenant_id, timeline_id: timeline.timeline_id, http_hosts: Vec::new(), - ignore_tombstone: None, + mconf: None, }; for host in timeline.peers { if host.0 == conf.my_id.0 { @@ -387,6 +387,7 @@ pub fn get_filesystem_usage(path: &std::path::Path) -> u64 { critical_timeline!( placeholder_ttid.tenant_id, placeholder_ttid.timeline_id, + None::<&AtomicBool>, "Global disk usage watcher failed to read filesystem usage: {:?}", e ); diff --git a/safekeeper/src/http/routes.rs b/safekeeper/src/http/routes.rs index c9d8e7d3b0..9f4c7141ec 100644 --- a/safekeeper/src/http/routes.rs +++ b/safekeeper/src/http/routes.rs @@ -352,7 +352,7 @@ async fn timeline_exclude_handler(mut request: Request) -> Result>( pub struct FullTimelineInfo { pub ttid: TenantTimelineId, pub ps_feedback_count: u64, + pub ps_corruption_detected: bool, pub last_ps_feedback: PageserverFeedback, pub wal_backup_active: bool, pub timeline_is_active: bool, @@ -547,6 +548,7 @@ pub struct TimelineCollector { ps_last_received_lsn: GenericGaugeVec, feedback_last_time_seconds: GenericGaugeVec, ps_feedback_count: GenericGaugeVec, + ps_corruption_detected: IntGaugeVec, timeline_active: GenericGaugeVec, wal_backup_active: GenericGaugeVec, connected_computes: IntGaugeVec, @@ -654,6 +656,15 @@ impl TimelineCollector { ) .unwrap(); + let ps_corruption_detected = IntGaugeVec::new( + Opts::new( + "safekeeper_ps_corruption_detected", + "1 if corruption was detected in the timeline according to feedback from the pageserver, 0 otherwise", + ), + &["tenant_id", "timeline_id"], + ) + .unwrap(); + let timeline_active = GenericGaugeVec::new( Opts::new( "safekeeper_timeline_active", @@ -774,6 +785,7 @@ impl TimelineCollector { ps_last_received_lsn, feedback_last_time_seconds, ps_feedback_count, + ps_corruption_detected, timeline_active, wal_backup_active, connected_computes, @@ -892,6 +904,9 @@ impl Collector for TimelineCollector { self.ps_feedback_count .with_label_values(labels) .set(tli.ps_feedback_count); + self.ps_corruption_detected + .with_label_values(labels) + .set(tli.ps_corruption_detected as i64); if let Ok(unix_time) = tli .last_ps_feedback .replytime @@ -925,6 +940,7 @@ impl Collector for TimelineCollector { mfs.extend(self.ps_last_received_lsn.collect()); mfs.extend(self.feedback_last_time_seconds.collect()); mfs.extend(self.ps_feedback_count.collect()); + mfs.extend(self.ps_corruption_detected.collect()); mfs.extend(self.timeline_active.collect()); mfs.extend(self.wal_backup_active.collect()); mfs.extend(self.connected_computes.collect()); diff --git a/safekeeper/src/pull_timeline.rs b/safekeeper/src/pull_timeline.rs 
index b4c4877b2c..4febc7656e 100644 --- a/safekeeper/src/pull_timeline.rs +++ b/safekeeper/src/pull_timeline.rs @@ -13,8 +13,8 @@ use http_utils::error::ApiError; use postgres_ffi::{PG_TLI, XLogFileName, XLogSegNo}; use remote_storage::GenericRemoteStorage; use reqwest::Certificate; -use safekeeper_api::Term; use safekeeper_api::models::{PullTimelineRequest, PullTimelineResponse, TimelineStatus}; +use safekeeper_api::{Term, membership}; use safekeeper_client::mgmt_api; use safekeeper_client::mgmt_api::Client; use serde::Deserialize; @@ -453,12 +453,40 @@ pub async fn handle_request( global_timelines: Arc, wait_for_peer_timeline_status: bool, ) -> Result { + if let Some(mconf) = &request.mconf { + let sk_id = global_timelines.get_sk_id(); + if !mconf.contains(sk_id) { + return Err(ApiError::BadRequest(anyhow!( + "refused to pull timeline with {mconf}, node {sk_id} is not member of it", + ))); + } + } + let existing_tli = global_timelines.get(TenantTimelineId::new( request.tenant_id, request.timeline_id, )); - if existing_tli.is_ok() { - info!("Timeline {} already exists", request.timeline_id); + if let Ok(timeline) = existing_tli { + let cur_generation = timeline + .read_shared_state() + .await + .sk + .state() + .mconf + .generation; + + info!( + "Timeline {} already exists with generation {cur_generation}", + request.timeline_id, + ); + + if let Some(mconf) = request.mconf { + timeline + .membership_switch(mconf) + .await + .map_err(|e| ApiError::InternalServerError(anyhow::anyhow!(e)))?; + } + return Ok(PullTimelineResponse { safekeeper_host: None, }); @@ -495,6 +523,19 @@ pub async fn handle_request( for (i, response) in responses.into_iter().enumerate() { match response { Ok(status) => { + if let Some(mconf) = &request.mconf { + if status.mconf.generation > mconf.generation { + // We probably raced with another timeline membership change with higher generation. + // Ignore this request. + return Err(ApiError::Conflict(format!( + "cannot pull timeline with generation {}: timeline {} already exists with generation {} on {}", + mconf.generation, + request.timeline_id, + status.mconf.generation, + http_hosts[i], + ))); + } + } statuses.push((status, i)); } Err(e) => { @@ -571,19 +612,25 @@ pub async fn handle_request( } } + let max_term = statuses + .iter() + .map(|(status, _)| status.acceptor_state.term) + .max() + .unwrap(); + // Find the most advanced safekeeper let (status, i) = statuses .into_iter() .max_by_key(|(status, _)| { ( status.acceptor_state.epoch, + status.flush_lsn, /* BEGIN_HADRON */ // We need to pull from the SK with the highest term. // This is because another compute may come online and vote the same highest term again on the other two SKs. // Then, there will be 2 computes running on the same term. status.acceptor_state.term, /* END_HADRON */ - status.flush_lsn, status.commit_lsn, ) }) @@ -593,7 +640,21 @@ pub async fn handle_request( assert!(status.tenant_id == request.tenant_id); assert!(status.timeline_id == request.timeline_id); - let check_tombstone = !request.ignore_tombstone.unwrap_or_default(); + // TODO(diko): This is hadron only check to make sure that we pull the timeline + // from the safekeeper with the highest term during timeline restore. + // We could avoid returning the error by calling bump_term after pull_timeline. + // However, this is not a big deal because we retry the pull_timeline requests. + // The check should be removed together with removing custom hadron logic for + // safekeeper restore. 
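The selection logic above now orders candidate safekeepers by (last log term, flush LSN, term, commit LSN) and, on the hadron restore path, additionally requires the winner to hold the globally highest term. A compact sketch of that rule with simplified stand-in fields (not the real `TimelineStatus`):

```rust
#[derive(Clone, Copy, Debug)]
struct Status {
    last_log_term: u64,
    term: u64,
    flush_lsn: u64,
    commit_lsn: u64,
}

fn choose(statuses: &[Status]) -> Result<Status, String> {
    let max_term = statuses.iter().map(|s| s.term).max().ok_or("no statuses")?;
    let best = statuses
        .iter()
        .copied()
        .max_by_key(|s| (s.last_log_term, s.flush_lsn, s.term, s.commit_lsn))
        .ok_or("no statuses")?;
    if best.term != max_term {
        return Err(format!(
            "chosen safekeeper has term {}, but the most advanced term is {}",
            best.term, max_term
        ));
    }
    Ok(best)
}
```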
+ if wait_for_peer_timeline_status && status.acceptor_state.term != max_term { + return Err(ApiError::PreconditionFailed( + format!( + "choosen safekeeper {} has term {}, but the most advanced term is {}", + safekeeper_host, status.acceptor_state.term, max_term + ) + .into(), + )); + } match pull_timeline( status, @@ -601,7 +662,7 @@ pub async fn handle_request( sk_auth_token, http_client, global_timelines, - check_tombstone, + request.mconf, ) .await { @@ -611,6 +672,10 @@ pub async fn handle_request( Some(TimelineError::AlreadyExists(_)) => Ok(PullTimelineResponse { safekeeper_host: None, }), + Some(TimelineError::Deleted(_)) => Err(ApiError::Conflict(format!( + "Timeline {}/{} deleted", + request.tenant_id, request.timeline_id + ))), Some(TimelineError::CreationInProgress(_)) => { // We don't return success here because creation might still fail. Err(ApiError::Conflict("Creation in progress".to_owned())) @@ -627,7 +692,7 @@ async fn pull_timeline( sk_auth_token: Option, http_client: reqwest::Client, global_timelines: Arc, - check_tombstone: bool, + mconf: Option, ) -> Result { let ttid = TenantTimelineId::new(status.tenant_id, status.timeline_id); info!( @@ -689,8 +754,11 @@ async fn pull_timeline( // fsync temp timeline directory to remember its contents. fsync_async_opt(&tli_dir_path, !conf.no_sync).await?; + let generation = mconf.as_ref().map(|c| c.generation); + // Let's create timeline from temp directory and verify that it's correct - let (commit_lsn, flush_lsn) = validate_temp_timeline(conf, ttid, &tli_dir_path).await?; + let (commit_lsn, flush_lsn) = + validate_temp_timeline(conf, ttid, &tli_dir_path, generation).await?; info!( "finished downloading timeline {}, commit_lsn={}, flush_lsn={}", ttid, commit_lsn, flush_lsn @@ -698,10 +766,20 @@ async fn pull_timeline( assert!(status.commit_lsn <= status.flush_lsn); // Finally, load the timeline. - let _tli = global_timelines - .load_temp_timeline(ttid, &tli_dir_path, check_tombstone) + let timeline = global_timelines + .load_temp_timeline(ttid, &tli_dir_path, generation) .await?; + if let Some(mconf) = mconf { + // Switch to provided mconf to guarantee that the timeline will not + // be deleted by request with older generation. + // The generation might already be higer than the one in mconf, e.g. + // if another membership_switch request was executed between `load_temp_timeline` + // and `membership_switch`, but that's totaly fine. `membership_switch` will + // ignore switch to older generation. + timeline.membership_switch(mconf).await?; + } + Ok(PullTimelineResponse { safekeeper_host: Some(host), }) diff --git a/safekeeper/src/safekeeper.rs b/safekeeper/src/safekeeper.rs index 09ca041e22..6c658d30fb 100644 --- a/safekeeper/src/safekeeper.rs +++ b/safekeeper/src/safekeeper.rs @@ -1026,6 +1026,13 @@ where self.state.finish_change(&state).await?; } + if msg.mconf.generation > self.state.mconf.generation && !msg.mconf.contains(self.node_id) { + bail!( + "refused to switch into {}, node {} is not a member of it", + msg.mconf, + self.node_id, + ); + } // Switch into conf given by proposer conf if it is higher. 
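Both hunks above enforce the same generation discipline: a switch to an older (or equal) generation is ignored, and a node refuses to adopt a newer configuration that does not list it as a member. Roughly, with simplified stand-ins for `safekeeper_api::membership::Configuration`:

```rust
#[derive(Clone, Debug)]
struct Configuration {
    generation: u32,
    members: Vec<u64>,
}

impl Configuration {
    fn contains(&self, node_id: u64) -> bool {
        self.members.contains(&node_id)
    }
}

fn membership_switch(
    current: &mut Configuration,
    to: Configuration,
    node_id: u64,
) -> Result<(), String> {
    if to.generation <= current.generation {
        // stale or concurrent request with an older (or same) generation: no-op
        return Ok(());
    }
    if !to.contains(node_id) {
        return Err(format!(
            "refused to switch into generation {}, node {} is not a member of it",
            to.generation, node_id
        ));
    }
    *current = to;
    Ok(())
}
```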
self.state.membership_switch(msg.mconf.clone()).await?; diff --git a/safekeeper/src/send_interpreted_wal.rs b/safekeeper/src/send_interpreted_wal.rs index 671798298b..bfc4008c52 100644 --- a/safekeeper/src/send_interpreted_wal.rs +++ b/safekeeper/src/send_interpreted_wal.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use std::fmt::Display; use std::sync::Arc; +use std::sync::atomic::AtomicBool; use std::time::Duration; use anyhow::{Context, anyhow}; @@ -305,6 +306,9 @@ impl InterpretedWalReader { critical_timeline!( ttid.tenant_id, ttid.timeline_id, + // Hadron: The corruption flag is only used in PS so that it can feed this information back to SKs. + // We do not use these flags in SKs. + None::<&AtomicBool>, "failed to read WAL record: {err:?}" ); } @@ -375,6 +379,9 @@ impl InterpretedWalReader { critical_timeline!( ttid.tenant_id, ttid.timeline_id, + // Hadron: The corruption flag is only used in PS so that it can feed this information back to SKs. + // We do not use these flags in SKs. + None::<&AtomicBool>, "failed to decode WAL record: {err:?}" ); } diff --git a/safekeeper/src/send_wal.rs b/safekeeper/src/send_wal.rs index 5891fa88a4..2d6f7486a9 100644 --- a/safekeeper/src/send_wal.rs +++ b/safekeeper/src/send_wal.rs @@ -55,6 +55,7 @@ pub struct WalSenders { pub struct WalSendersTimelineMetricValues { pub ps_feedback_counter: u64, + pub ps_corruption_detected: bool, pub last_ps_feedback: PageserverFeedback, pub interpreted_wal_reader_tasks: usize, } @@ -193,6 +194,7 @@ impl WalSenders { WalSendersTimelineMetricValues { ps_feedback_counter: shared.ps_feedback_counter, + ps_corruption_detected: shared.ps_corruption_detected, last_ps_feedback: shared.last_ps_feedback, interpreted_wal_reader_tasks, } @@ -209,6 +211,9 @@ impl WalSenders { *shared.get_slot_mut(id).get_mut_feedback() = ReplicationFeedback::Pageserver(*feedback); shared.last_ps_feedback = *feedback; shared.ps_feedback_counter += 1; + if feedback.corruption_detected { + shared.ps_corruption_detected = true; + } drop(shared); RECEIVED_PS_FEEDBACKS.inc(); @@ -278,6 +283,9 @@ struct WalSendersShared { last_ps_feedback: PageserverFeedback, // total counter of pageserver feedbacks received ps_feedback_counter: u64, + // Hadron: true iff we received a pageserver feedback that incidated + // data corruption in the timeline + ps_corruption_detected: bool, slots: Vec>, } @@ -328,6 +336,7 @@ impl WalSendersShared { agg_standby_feedback: StandbyFeedback::empty(), last_ps_feedback: PageserverFeedback::empty(), ps_feedback_counter: 0, + ps_corruption_detected: false, slots: Vec::new(), } } diff --git a/safekeeper/src/timeline.rs b/safekeeper/src/timeline.rs index b8774b30ea..25ac8e5bd3 100644 --- a/safekeeper/src/timeline.rs +++ b/safekeeper/src/timeline.rs @@ -195,12 +195,14 @@ impl StateSK { to: Configuration, ) -> Result { let result = self.state_mut().membership_switch(to).await?; + let flush_lsn = self.flush_lsn(); + let last_log_term = self.state().acceptor_state.get_last_log_term(flush_lsn); Ok(TimelineMembershipSwitchResponse { previous_conf: result.previous_conf, current_conf: result.current_conf, - last_log_term: self.state().acceptor_state.term, - flush_lsn: self.flush_lsn(), + last_log_term, + flush_lsn, }) } @@ -594,7 +596,7 @@ impl Timeline { /// Cancel the timeline, requesting background activity to stop. Closing /// the `self.gate` waits for that. 
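A simplified model of why the switch response above now reports `get_last_log_term(flush_lsn)` rather than the acceptor's current term: the acceptor may have voted in a newer term without having flushed any WAL in it yet. The `TermSwitch`/`last_log_term` names and the flat `u64` LSNs are illustrative assumptions, not the real safekeeper types.

```rust
#[derive(Debug, Clone, Copy)]
struct TermSwitch {
    term: u64,
    start_lsn: u64, // LSN at which this term starts in the local WAL (flattened to u64)
}

/// Term of the last WAL record at or before `flush_lsn`, assuming `history`
/// is sorted by `start_lsn`.
fn last_log_term(history: &[TermSwitch], flush_lsn: u64) -> u64 {
    history
        .iter()
        .take_while(|e| e.start_lsn <= flush_lsn)
        .last()
        .map(|e| e.term)
        .unwrap_or(0)
}

fn main() {
    let history = [
        TermSwitch { term: 1, start_lsn: 0 },
        TermSwitch { term: 3, start_lsn: 0x1000 },
        TermSwitch { term: 5, start_lsn: 0x5000 }, // voted in term 5, nothing flushed yet
    ];
    // Flushed up to 0x4fff: the last *log* term is 3 even though the current term is 5,
    // which is what the membership-switch response should report.
    assert_eq!(last_log_term(&history, 0x4fff), 3);
    println!("last_log_term = {}", last_log_term(&history, 0x4fff));
}
```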
- pub async fn cancel(&self) { + pub fn cancel(&self) { info!("timeline {} shutting down", self.ttid); self.cancel.cancel(); } @@ -839,6 +841,7 @@ impl Timeline { let WalSendersTimelineMetricValues { ps_feedback_counter, + ps_corruption_detected, last_ps_feedback, interpreted_wal_reader_tasks, } = self.walsenders.info_for_metrics(); @@ -847,6 +850,7 @@ impl Timeline { Some(FullTimelineInfo { ttid: self.ttid, ps_feedback_count: ps_feedback_counter, + ps_corruption_detected, last_ps_feedback, wal_backup_active: self.wal_backup_active.load(Ordering::Relaxed), timeline_is_active: self.broker_active.load(Ordering::Relaxed), @@ -914,6 +918,13 @@ impl Timeline { to: Configuration, ) -> Result { let mut state = self.write_shared_state().await; + // Ensure we don't race with exclude/delete requests by checking the cancellation + // token under the write_shared_state lock. + // Exclude/delete cancel the timeline under the shared state lock, + // so the timeline cannot be deleted in the middle of the membership switch. + if self.is_cancelled() { + bail!(TimelineError::Cancelled(self.ttid)); + } state.sk.membership_switch(to).await } diff --git a/safekeeper/src/timelines_global_map.rs b/safekeeper/src/timelines_global_map.rs index a81a7298a9..f63d1abdcf 100644 --- a/safekeeper/src/timelines_global_map.rs +++ b/safekeeper/src/timelines_global_map.rs @@ -10,13 +10,13 @@ use std::time::{Duration, Instant}; use anyhow::{Context, Result, bail}; use camino::Utf8PathBuf; use camino_tempfile::Utf8TempDir; -use safekeeper_api::membership::Configuration; +use safekeeper_api::membership::{Configuration, SafekeeperGeneration}; use safekeeper_api::models::{SafekeeperUtilization, TimelineDeleteResult}; use safekeeper_api::{ServerInfo, membership}; use tokio::fs; use tracing::*; use utils::crashsafe::{durable_rename, fsync_async_opt}; -use utils::id::{TenantId, TenantTimelineId, TimelineId}; +use utils::id::{NodeId, TenantId, TenantTimelineId, TimelineId}; use utils::lsn::Lsn; use crate::defaults::DEFAULT_EVICTION_CONCURRENCY; @@ -40,10 +40,17 @@ enum GlobalMapTimeline { struct GlobalTimelinesState { timelines: HashMap, - // A tombstone indicates this timeline used to exist has been deleted. These are used to prevent - // on-demand timeline creation from recreating deleted timelines. This is only soft-enforced, as - // this map is dropped on restart. - tombstones: HashMap, + /// A tombstone indicates this timeline used to exist has been deleted. These are used to prevent + /// on-demand timeline creation from recreating deleted timelines. This is only soft-enforced, as + /// this map is dropped on restart. + /// The timeline might also be locally deleted (excluded) via safekeeper migration algorithm. In that case, + /// the tombsone contains the corresponding safekeeper generation. The pull_timeline requests with + /// higher generation ignore such tombstones and can recreate the timeline. + timeline_tombstones: HashMap, + /// A tombstone indicates that the tenant used to exist has been deleted. + /// These are created only by tenant_delete requests. They are always valid regardless of the + /// request generation. + /// This is only soft-enforced, as this map is dropped on restart. 
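A minimal sketch of the tombstone rule described above, with `SafekeeperGeneration` simplified to a `u32` and an illustrative `blocks` helper: an exclude-tombstone only blocks requests whose generation is not strictly newer, while delete-tombstones (and, analogously, tenant tombstones) block regardless of the request generation.

```rust
/// Simplified tombstone: `None` generation = permanent "delete" tombstone,
/// `Some(g)` = written by an "exclude" carried out at generation `g`.
struct TimelineTombstone {
    generation: Option<u32>,
}

impl TimelineTombstone {
    /// Does this tombstone still block a request made with `request_generation`?
    fn blocks(&self, request_generation: Option<u32>) -> bool {
        match (self.generation, request_generation) {
            // Plain delete: always blocks (until the tombstone TTL expires).
            (None, _) => true,
            // Request without a generation: any tombstone blocks it.
            (_, None) => true,
            // Exclude tombstone: blocks unless the request is strictly newer.
            (Some(tombstone_gen), Some(request_gen)) => tombstone_gen >= request_gen,
        }
    }
}

fn main() {
    let excluded_at_gen_4 = TimelineTombstone { generation: Some(4) };
    assert!(excluded_at_gen_4.blocks(Some(4))); // same generation: refuse to recreate
    assert!(!excluded_at_gen_4.blocks(Some(5))); // newer membership: pull may recreate
    let deleted = TimelineTombstone { generation: None };
    assert!(deleted.blocks(Some(100))); // delete tombstones ignore the generation
    println!("tombstone rules hold");
}
```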
tenant_tombstones: HashMap, conf: Arc, @@ -79,7 +86,7 @@ impl GlobalTimelinesState { Err(TimelineError::CreationInProgress(*ttid)) } None => { - if self.has_tombstone(ttid) { + if self.has_tombstone(ttid, None) { Err(TimelineError::Deleted(*ttid)) } else { Err(TimelineError::NotFound(*ttid)) @@ -88,20 +95,46 @@ impl GlobalTimelinesState { } } - fn has_tombstone(&self, ttid: &TenantTimelineId) -> bool { - self.tombstones.contains_key(ttid) || self.tenant_tombstones.contains_key(&ttid.tenant_id) + fn has_timeline_tombstone( + &self, + ttid: &TenantTimelineId, + generation: Option, + ) -> bool { + if let Some(generation) = generation { + self.timeline_tombstones + .get(ttid) + .is_some_and(|t| t.is_valid(generation)) + } else { + self.timeline_tombstones.contains_key(ttid) + } } - /// Removes all blocking tombstones for the given timeline ID. + fn has_tenant_tombstone(&self, tenant_id: &TenantId) -> bool { + self.tenant_tombstones.contains_key(tenant_id) + } + + /// Check if the state has a tenant or a timeline tombstone. + /// If `generation` is provided, check only for timeline tombsotnes with same or higher generation. + /// If `generation` is `None`, check for any timeline tombstone. + /// Tenant tombstones are checked regardless of the generation. + fn has_tombstone( + &self, + ttid: &TenantTimelineId, + generation: Option, + ) -> bool { + self.has_timeline_tombstone(ttid, generation) || self.has_tenant_tombstone(&ttid.tenant_id) + } + + /// Removes timeline tombstone for the given timeline ID. /// Returns `true` if there have been actual changes. - fn remove_tombstone(&mut self, ttid: &TenantTimelineId) -> bool { - self.tombstones.remove(ttid).is_some() - || self.tenant_tombstones.remove(&ttid.tenant_id).is_some() + fn remove_timeline_tombstone(&mut self, ttid: &TenantTimelineId) -> bool { + self.timeline_tombstones.remove(ttid).is_some() } - fn delete(&mut self, ttid: TenantTimelineId) { + fn delete(&mut self, ttid: TenantTimelineId, generation: Option) { self.timelines.remove(&ttid); - self.tombstones.insert(ttid, Instant::now()); + self.timeline_tombstones + .insert(ttid, TimelineTombstone::new(generation)); } fn add_tenant_tombstone(&mut self, tenant_id: TenantId) { @@ -120,7 +153,7 @@ impl GlobalTimelines { Self { state: Mutex::new(GlobalTimelinesState { timelines: HashMap::new(), - tombstones: HashMap::new(), + timeline_tombstones: HashMap::new(), tenant_tombstones: HashMap::new(), conf, broker_active_set: Arc::new(TimelinesSet::default()), @@ -261,6 +294,8 @@ impl GlobalTimelines { start_lsn: Lsn, commit_lsn: Lsn, ) -> Result> { + let generation = Some(mconf.generation); + let (conf, _, _, _) = { let state = self.state.lock().unwrap(); if let Ok(timeline) = state.get(&ttid) { @@ -268,8 +303,8 @@ impl GlobalTimelines { return Ok(timeline); } - if state.has_tombstone(&ttid) { - anyhow::bail!("Timeline {ttid} is deleted, refusing to recreate"); + if state.has_tombstone(&ttid, generation) { + anyhow::bail!(TimelineError::Deleted(ttid)); } state.get_dependencies() @@ -284,7 +319,9 @@ impl GlobalTimelines { // immediately initialize first WAL segment as well. 
let state = TimelinePersistentState::new(&ttid, mconf, server_info, start_lsn, commit_lsn)?; control_file::FileStorage::create_new(&tmp_dir_path, state, conf.no_sync).await?; - let timeline = self.load_temp_timeline(ttid, &tmp_dir_path, true).await?; + let timeline = self + .load_temp_timeline(ttid, &tmp_dir_path, generation) + .await?; Ok(timeline) } @@ -303,7 +340,7 @@ impl GlobalTimelines { &self, ttid: TenantTimelineId, tmp_path: &Utf8PathBuf, - check_tombstone: bool, + generation: Option, ) -> Result> { // Check for existence and mark that we're creating it. let (conf, broker_active_set, partial_backup_rate_limiter, wal_backup) = { @@ -317,18 +354,18 @@ impl GlobalTimelines { } _ => {} } - if check_tombstone { - if state.has_tombstone(&ttid) { - anyhow::bail!("timeline {ttid} is deleted, refusing to recreate"); - } - } else { - // We may be have been asked to load a timeline that was previously deleted (e.g. from `pull_timeline.rs`). We trust - // that the human doing this manual intervention knows what they are doing, and remove its tombstone. - // It's also possible that we enter this when the tenant has been deleted, even if the timeline itself has never existed. - if state.remove_tombstone(&ttid) { - warn!("un-deleted timeline {ttid}"); - } + + if state.has_tombstone(&ttid, generation) { + // If the timeline is deleted, we refuse to recreate it. + // This is a safeguard against accidentally overwriting a timeline that was deleted + // by concurrent request. + anyhow::bail!(TimelineError::Deleted(ttid)); } + + // We might have an outdated tombstone with the older generation. + // Remove it unconditionally. + state.remove_timeline_tombstone(&ttid); + state .timelines .insert(ttid, GlobalMapTimeline::CreationInProgress); @@ -503,11 +540,16 @@ impl GlobalTimelines { ttid: &TenantTimelineId, action: DeleteOrExclude, ) -> Result { + let generation = match &action { + DeleteOrExclude::Delete | DeleteOrExclude::DeleteLocal => None, + DeleteOrExclude::Exclude(mconf) => Some(mconf.generation), + }; + let tli_res = { let state = self.state.lock().unwrap(); // Do NOT check tenant tombstones here: those were set earlier - if state.tombstones.contains_key(ttid) { + if state.has_timeline_tombstone(ttid, generation) { // Presence of a tombstone guarantees that a previous deletion has completed and there is no work to do. info!("Timeline {ttid} was already deleted"); return Ok(TimelineDeleteResult { dir_existed: false }); @@ -528,6 +570,11 @@ impl GlobalTimelines { // We would like to avoid holding the lock while waiting for the // gate to finish as this is deadlock prone, so for actual // deletion will take it second time. + // + // Canceling the timeline will block membership switch requests, + // ensuring that the timeline generation will not increase + // after this point, and we will not remove a timeline with a generation + // higher than the requested one. if let DeleteOrExclude::Exclude(ref mconf) = action { let shared_state = timeline.read_shared_state().await; if shared_state.sk.state().mconf.generation > mconf.generation { @@ -536,9 +583,9 @@ impl GlobalTimelines { current: shared_state.sk.state().mconf.clone(), }); } - timeline.cancel().await; + timeline.cancel(); } else { - timeline.cancel().await; + timeline.cancel(); } timeline.close().await; @@ -565,7 +612,7 @@ impl GlobalTimelines { // Finalize deletion, by dropping Timeline objects and storing smaller tombstones. 
The tombstones // are used to prevent still-running computes from re-creating the same timeline when they send data, // and to speed up repeated deletion calls by avoiding re-listing objects. - self.state.lock().unwrap().delete(*ttid); + self.state.lock().unwrap().delete(*ttid, generation); result } @@ -627,12 +674,16 @@ impl GlobalTimelines { // may recreate a deleted timeline. let now = Instant::now(); state - .tombstones - .retain(|_, v| now.duration_since(*v) < *tombstone_ttl); + .timeline_tombstones + .retain(|_, v| now.duration_since(v.timestamp) < *tombstone_ttl); state .tenant_tombstones .retain(|_, v| now.duration_since(*v) < *tombstone_ttl); } + + pub fn get_sk_id(&self) -> NodeId { + self.state.lock().unwrap().conf.my_id + } } /// Action for delete_or_exclude. @@ -673,6 +724,7 @@ pub async fn validate_temp_timeline( conf: &SafeKeeperConf, ttid: TenantTimelineId, path: &Utf8PathBuf, + generation: Option, ) -> Result<(Lsn, Lsn)> { let control_path = path.join("safekeeper.control"); @@ -681,6 +733,15 @@ pub async fn validate_temp_timeline( bail!("wal_seg_size is not set"); } + if let Some(generation) = generation { + if control_store.mconf.generation > generation { + bail!( + "tmp timeline generation {} is higher than expected {generation}", + control_store.mconf.generation + ); + } + } + let wal_store = wal_storage::PhysicalStorage::new(&ttid, path, &control_store, conf.no_sync)?; let commit_lsn = control_store.commit_lsn; @@ -688,3 +749,28 @@ pub async fn validate_temp_timeline( Ok((commit_lsn, flush_lsn)) } + +/// A tombstone for a deleted timeline. +/// The generation is passed with "exclude" request and stored in the tombstone. +/// We ignore the tombstone if the request generation is higher than +/// the tombstone generation. +/// If the tombstone doesn't have a generation, it's considered permanent, +/// e.g. after "delete" request. +struct TimelineTombstone { + timestamp: Instant, + generation: Option, +} + +impl TimelineTombstone { + fn new(generation: Option) -> Self { + TimelineTombstone { + timestamp: Instant::now(), + generation, + } + } + + /// Check if the timeline is still valid for the given generation. 
+ fn is_valid(&self, generation: SafekeeperGeneration) -> bool { + self.generation.is_none_or(|g| g >= generation) + } +} diff --git a/safekeeper/src/wal_backup.rs b/safekeeper/src/wal_backup.rs index 03c8f7e84a..191f8aacf1 100644 --- a/safekeeper/src/wal_backup.rs +++ b/safekeeper/src/wal_backup.rs @@ -12,7 +12,7 @@ use futures::stream::{self, FuturesOrdered}; use postgres_ffi::v14::xlog_utils::XLogSegNoOffsetToRecPtr; use postgres_ffi::{PG_TLI, XLogFileName, XLogSegNo}; use remote_storage::{ - DownloadOpts, GenericRemoteStorage, ListingMode, RemotePath, StorageMetadata, + DownloadError, DownloadOpts, GenericRemoteStorage, ListingMode, RemotePath, StorageMetadata, }; use safekeeper_api::models::PeerInfo; use tokio::fs::File; @@ -607,6 +607,9 @@ pub(crate) async fn copy_partial_segment( storage.copy_object(source, destination, &cancel).await } +const WAL_READ_WARN_THRESHOLD: u32 = 2; +const WAL_READ_MAX_RETRIES: u32 = 3; + pub async fn read_object( storage: &GenericRemoteStorage, file_path: &RemotePath, @@ -620,12 +623,23 @@ pub async fn read_object( byte_start: std::ops::Bound::Included(offset), ..Default::default() }; - let download = storage - .download(file_path, &opts, &cancel) - .await - .with_context(|| { - format!("Failed to open WAL segment download stream for remote path {file_path:?}") - })?; + + // This retry only solves the connect errors: subsequent reads can still fail as this function returns + // a stream. + let download = backoff::retry( + || async { storage.download(file_path, &opts, &cancel).await }, + DownloadError::is_permanent, + WAL_READ_WARN_THRESHOLD, + WAL_READ_MAX_RETRIES, + "download WAL segment", + &cancel, + ) + .await + .ok_or_else(|| DownloadError::Cancelled) + .and_then(|x| x) + .with_context(|| { + format!("Failed to open WAL segment download stream for remote path {file_path:?}") + })?; let reader = tokio_util::io::StreamReader::new(download.download_stream); diff --git a/storage_controller/src/compute_hook.rs b/storage_controller/src/compute_hook.rs index 232f8abbab..1ceb242650 100644 --- a/storage_controller/src/compute_hook.rs +++ b/storage_controller/src/compute_hook.rs @@ -543,7 +543,7 @@ impl ApiMethod for ComputeHookTenant { None }; let pageserver = PageserverShardConnectionInfo { - id: Some(shard.node_id.to_string()), + id: Some(shard.node_id), libpq_url, grpc_url, }; @@ -561,7 +561,7 @@ impl ApiMethod for ComputeHookTenant { let pageserver_conninfo = PageserverConnectionInfo { shard_count, - stripe_size: stripe_size.map(|val| val.0), + stripe_size: stripe_size.map(|val| ShardStripeSize(val.0)), shards: shard_infos, prefer_protocol, }; diff --git a/storage_controller/src/operation_utils.rs b/storage_controller/src/operation_utils.rs index af86010ab7..1060c92832 100644 --- a/storage_controller/src/operation_utils.rs +++ b/storage_controller/src/operation_utils.rs @@ -46,11 +46,31 @@ impl TenantShardDrain { &self, tenants: &BTreeMap, scheduler: &Scheduler, - ) -> Option { - let tenant_shard = tenants.get(&self.tenant_shard_id)?; + ) -> TenantShardDrainAction { + let Some(tenant_shard) = tenants.get(&self.tenant_shard_id) else { + return TenantShardDrainAction::Skip; + }; if *tenant_shard.intent.get_attached() != Some(self.drained_node) { - return None; + // If the intent attached node is not the drained node, check the observed state + // of the shard on the drained node. If it is Attached*, it means the shard is + // beeing migrated from the drained node. The drain loop needs to wait for the + // reconciliation to complete for a smooth draining. 
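A condensed sketch of the drain decision explained in the comment above, using local stand-in types (`ObservedMode`, `DrainAction`, `drain_action`) rather than the controller's real ones, and omitting the scheduling-policy check for brevity: if the intent still attaches the shard to the drained node, plan a move to a secondary; if the intent already points elsewhere but the drained node still observes an attached location, the shard is mid-migration and the drain loop should wait for that reconcile; otherwise skip it.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
enum ObservedMode {
    AttachedSingle,
    AttachedMulti,
    AttachedStale,
    Secondary,
    Detached,
}

#[derive(Debug, PartialEq)]
enum DrainAction {
    RescheduleToSecondary(u64), // destination node id
    Reconcile(u64),             // intent attached node id to wait on
    Skip,
}

fn drain_action(
    drained_node: u64,
    intent_attached: Option<u64>,
    observed_on_drained: Option<ObservedMode>,
    preferred_secondary: Option<u64>,
) -> DrainAction {
    use ObservedMode::*;
    if intent_attached != Some(drained_node) {
        // The shard already intends to live elsewhere; only wait if the drained
        // node still holds an attached location for it.
        return match (observed_on_drained, intent_attached) {
            (Some(AttachedSingle | AttachedMulti | AttachedStale), Some(intent)) => {
                DrainAction::Reconcile(intent)
            }
            _ => DrainAction::Skip,
        };
    }
    match preferred_secondary {
        Some(dest) => DrainAction::RescheduleToSecondary(dest),
        None => DrainAction::Skip,
    }
}

fn main() {
    // Shard is mid-migration away from node 1: wait for the existing reconcile.
    assert_eq!(
        drain_action(1, Some(2), Some(ObservedMode::AttachedStale), Some(3)),
        DrainAction::Reconcile(2)
    );
    // Shard is still attached to node 1 by intent: move it to its secondary.
    assert_eq!(
        drain_action(1, Some(1), Some(ObservedMode::AttachedSingle), Some(3)),
        DrainAction::RescheduleToSecondary(3)
    );
    println!("drain decision sketch ok");
}
```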
+ + use pageserver_api::models::LocationConfigMode::*; + + let attach_mode = tenant_shard + .observed + .locations + .get(&self.drained_node) + .and_then(|observed| observed.conf.as_ref().map(|conf| conf.mode)); + + return match (attach_mode, tenant_shard.intent.get_attached()) { + (Some(AttachedSingle | AttachedMulti | AttachedStale), Some(intent_node_id)) => { + TenantShardDrainAction::Reconcile(*intent_node_id) + } + _ => TenantShardDrainAction::Skip, + }; } // Only tenants with a normal (Active) scheduling policy are proactively moved @@ -63,19 +83,19 @@ impl TenantShardDrain { } ShardSchedulingPolicy::Pause | ShardSchedulingPolicy::Stop => { // If we have been asked to avoid rescheduling this shard, then do not migrate it during a drain - return None; + return TenantShardDrainAction::Skip; } } match tenant_shard.preferred_secondary(scheduler) { - Some(node) => Some(node), + Some(node) => TenantShardDrainAction::RescheduleToSecondary(node), None => { tracing::warn!( tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(), "No eligible secondary while draining {}", self.drained_node ); - None + TenantShardDrainAction::Skip } } } @@ -138,3 +158,17 @@ impl TenantShardDrain { } } } + +/// Action to take when draining a tenant shard. +pub(crate) enum TenantShardDrainAction { + /// The tenant shard is on the draining node. + /// Reschedule the tenant shard to a secondary location. + /// Holds a destination node id to reschedule to. + RescheduleToSecondary(NodeId), + /// The tenant shard is beeing migrated from the draining node. + /// Wait for the reconciliation to complete. + /// Holds the intent attached node id. + Reconcile(NodeId), + /// The tenant shard is not eligible for drainining, skip it. + Skip, +} diff --git a/storage_controller/src/persistence.rs b/storage_controller/src/persistence.rs index 619b5f69b8..c61ef9ff5d 100644 --- a/storage_controller/src/persistence.rs +++ b/storage_controller/src/persistence.rs @@ -471,11 +471,17 @@ impl Persistence { &self, input_node_id: NodeId, input_https_port: Option, + input_grpc_addr: Option, + input_grpc_port: Option, ) -> DatabaseResult<()> { use crate::schema::nodes::dsl::*; self.update_node( input_node_id, - listen_https_port.eq(input_https_port.map(|x| x as i32)), + ( + listen_https_port.eq(input_https_port.map(|x| x as i32)), + listen_grpc_addr.eq(input_grpc_addr), + listen_grpc_port.eq(input_grpc_port.map(|x| x as i32)), + ), ) .await } diff --git a/storage_controller/src/reconciler.rs b/storage_controller/src/reconciler.rs index d1590ec75e..ff5a3831cd 100644 --- a/storage_controller/src/reconciler.rs +++ b/storage_controller/src/reconciler.rs @@ -981,6 +981,7 @@ impl Reconciler { )); } + let mut first_err = None; for (node, conf) in changes { if self.cancel.is_cancelled() { return Err(ReconcileError::Cancel); @@ -990,7 +991,12 @@ impl Reconciler { // shard _available_ (the attached location), and configuring secondary locations // can be done lazily when the node becomes available (via background reconciliation). 
if node.is_available() { - self.location_config(&node, conf, None, false).await?; + let res = self.location_config(&node, conf, None, false).await; + if let Err(err) = res { + if first_err.is_none() { + first_err = Some(err); + } + } } else { // If the node is unavailable, we skip and consider the reconciliation successful: this // is a common case where a pageserver is marked unavailable: we demote a location on @@ -1002,6 +1008,10 @@ impl Reconciler { } } + if let Some(err) = first_err { + return Err(err); + } + // The condition below identifies a detach. We must have no attached intent and // must have been attached to something previously. Pass this information to // the [`ComputeHook`] such that it can update its tenant-wide state. diff --git a/storage_controller/src/service.rs b/storage_controller/src/service.rs index 8f5efe8ac4..2c7e4ee2de 100644 --- a/storage_controller/src/service.rs +++ b/storage_controller/src/service.rs @@ -79,7 +79,7 @@ use crate::id_lock_map::{ use crate::leadership::Leadership; use crate::metrics; use crate::node::{AvailabilityTransition, Node}; -use crate::operation_utils::{self, TenantShardDrain}; +use crate::operation_utils::{self, TenantShardDrain, TenantShardDrainAction}; use crate::pageserver_client::PageserverClient; use crate::peer_client::GlobalObservedState; use crate::persistence::split_state::SplitState; @@ -1274,7 +1274,7 @@ impl Service { // Always attempt autosplits. Sharding is crucial for bulk ingest performance, so we // must be responsive when new projects begin ingesting and reach the threshold. self.autosplit_tenants().await; - } + }, _ = self.reconcilers_cancel.cancelled() => return } } @@ -1530,10 +1530,19 @@ impl Service { // so that waiters will see the correct error after waiting. tenant.set_last_error(result.sequence, e); - // Skip deletions on reconcile failures - let upsert_deltas = - deltas.filter(|delta| matches!(delta, ObservedStateDelta::Upsert(_))); - tenant.apply_observed_deltas(upsert_deltas); + // If the reconciliation failed, don't clear the observed state for places where we + // detached. Instead, mark the observed state as uncertain. + let failed_reconcile_deltas = deltas.map(|delta| { + if let ObservedStateDelta::Delete(node_id) = delta { + ObservedStateDelta::Upsert(Box::new(( + node_id, + ObservedStateLocation { conf: None }, + ))) + } else { + delta + } + }); + tenant.apply_observed_deltas(failed_reconcile_deltas); } } @@ -7804,7 +7813,7 @@ impl Service { register_req.listen_https_port, register_req.listen_pg_addr, register_req.listen_pg_port, - register_req.listen_grpc_addr, + register_req.listen_grpc_addr.clone(), register_req.listen_grpc_port, register_req.availability_zone_id.clone(), self.config.use_https_pageserver_api, @@ -7839,6 +7848,8 @@ impl Service { .update_node_on_registration( register_req.node_id, register_req.listen_https_port, + register_req.listen_grpc_addr, + register_req.listen_grpc_port, ) .await? 
} @@ -8867,6 +8878,9 @@ impl Service { for (_tenant_id, schedule_context, shards) in TenantShardExclusiveIterator::new(tenants, ScheduleMode::Speculative) { + if work.len() >= MAX_OPTIMIZATIONS_PLAN_PER_PASS { + break; + } for shard in shards { if work.len() >= MAX_OPTIMIZATIONS_PLAN_PER_PASS { break; @@ -9631,16 +9645,16 @@ impl Service { tenant_shard_id: tid, }; - let dest_node_id = { + let drain_action = { let locked = self.inner.read().unwrap(); + tid_drain.tenant_shard_eligible_for_drain(&locked.tenants, &locked.scheduler) + }; - match tid_drain - .tenant_shard_eligible_for_drain(&locked.tenants, &locked.scheduler) - { - Some(node_id) => node_id, - None => { - continue; - } + let dest_node_id = match drain_action { + TenantShardDrainAction::RescheduleToSecondary(dest_node_id) => dest_node_id, + TenantShardDrainAction::Reconcile(intent_node_id) => intent_node_id, + TenantShardDrainAction::Skip => { + continue; } }; @@ -9675,14 +9689,16 @@ impl Service { { let mut locked = self.inner.write().unwrap(); let (nodes, tenants, scheduler) = locked.parts_mut(); - let rescheduled = tid_drain.reschedule_to_secondary( - dest_node_id, - tenants, - scheduler, - nodes, - )?; - if let Some(tenant_shard) = rescheduled { + let tenant_shard = match drain_action { + TenantShardDrainAction::RescheduleToSecondary(dest_node_id) => tid_drain + .reschedule_to_secondary(dest_node_id, tenants, scheduler, nodes)?, + TenantShardDrainAction::Reconcile(_) => tenants.get_mut(&tid), + // Note: Unreachable, handled above. + TenantShardDrainAction::Skip => None, + }; + + if let Some(tenant_shard) = tenant_shard { let waiter = self.maybe_configured_reconcile_shard( tenant_shard, nodes, diff --git a/storage_controller/src/service/safekeeper_reconciler.rs b/storage_controller/src/service/safekeeper_reconciler.rs index b67a679fad..7dbbd3afe4 100644 --- a/storage_controller/src/service/safekeeper_reconciler.rs +++ b/storage_controller/src/service/safekeeper_reconciler.rs @@ -364,7 +364,12 @@ impl SafekeeperReconcilerInner { http_hosts, tenant_id: req.tenant_id, timeline_id, - ignore_tombstone: Some(false), + // TODO(diko): get mconf from "timelines" table and pass it here. + // Now we use pull_timeline reconciliation only for the timeline creation, + // so it's not critical right now. 
+ // It could be fixed together with other reconciliation issues: + // https://github.com/neondatabase/neon/issues/12189 + mconf: None, }; success = self .reconcile_inner( diff --git a/storage_controller/src/service/safekeeper_service.rs b/storage_controller/src/service/safekeeper_service.rs index bc77a1a6b8..fab1342d5d 100644 --- a/storage_controller/src/service/safekeeper_service.rs +++ b/storage_controller/src/service/safekeeper_service.rs @@ -24,12 +24,12 @@ use pageserver_api::controller_api::{ }; use pageserver_api::models::{SafekeeperInfo, SafekeepersInfo, TimelineInfo}; use safekeeper_api::PgVersionId; +use safekeeper_api::Term; use safekeeper_api::membership::{self, MemberSet, SafekeeperGeneration}; use safekeeper_api::models::{ PullTimelineRequest, TimelineLocateResponse, TimelineMembershipSwitchRequest, TimelineMembershipSwitchResponse, }; -use safekeeper_api::{INITIAL_TERM, Term}; use safekeeper_client::mgmt_api; use tokio::task::JoinSet; use tokio_util::sync::CancellationToken; @@ -991,6 +991,7 @@ impl Service { timeline_id: TimelineId, to_safekeepers: &[Safekeeper], from_safekeepers: &[Safekeeper], + mconf: membership::Configuration, ) -> Result<(), ApiError> { let http_hosts = from_safekeepers .iter() @@ -1009,14 +1010,11 @@ impl Service { .collect::>() ); - // TODO(diko): need to pass mconf/generation with the request - // to properly handle tombstones. Ignore tombstones for now. - // Worst case: we leave a timeline on a safekeeper which is not in the current set. let req = PullTimelineRequest { tenant_id, timeline_id, http_hosts, - ignore_tombstone: Some(true), + mconf: Some(mconf), }; const SK_PULL_TIMELINE_RECONCILE_TIMEOUT: Duration = Duration::from_secs(30); @@ -1300,13 +1298,7 @@ impl Service { ) .await?; - let mut sync_position = (INITIAL_TERM, Lsn::INVALID); - for res in results.into_iter().flatten() { - let sk_position = (res.last_log_term, res.flush_lsn); - if sync_position < sk_position { - sync_position = sk_position; - } - } + let sync_position = Self::get_sync_position(&results)?; tracing::info!( %generation, @@ -1336,6 +1328,7 @@ impl Service { timeline_id, &pull_to_safekeepers, &cur_safekeepers, + joint_config.clone(), ) .await?; @@ -1599,4 +1592,36 @@ impl Service { Ok(()) } + + /// Get membership switch responses from all safekeepers and return the sync position. + /// + /// Sync position is a position equal or greater than the commit position. + /// It is guaranteed that all WAL entries with (last_log_term, flush_lsn) + /// greater than the sync position are not committed (= not on a quorum). + /// + /// Returns error if there is no quorum of successful responses. + fn get_sync_position( + responses: &[mgmt_api::Result], + ) -> Result<(Term, Lsn), ApiError> { + let quorum_size = responses.len() / 2 + 1; + + let mut wal_positions = responses + .iter() + .flatten() + .map(|res| (res.last_log_term, res.flush_lsn)) + .collect::>(); + + // Should be already checked if the responses are from tenant_timeline_set_membership_quorum. 
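A toy illustration of the sync-position computation that follows, with `(Term, Lsn)` flattened to `(u64, u64)` and the quorum size passed in explicitly (in the real code it is derived from the full response list, including failed responses); `sync_position` is an illustrative name:

```rust
/// Simplified sync-position computation: `positions` are the successful
/// (last_log_term, flush_lsn) responses, flattened to (u64, u64).
fn sync_position(positions: &mut Vec<(u64, u64)>, total_safekeepers: usize) -> Option<(u64, u64)> {
    let quorum_size = total_safekeepers / 2 + 1;
    if positions.len() < quorum_size {
        return None; // not enough successful responses to conclude anything
    }
    positions.sort();
    // Per the doc comment above: entries greater than this position cannot be committed.
    Some(positions[quorum_size - 1])
}

fn main() {
    // Three safekeepers in the membership set, all of which responded; quorum is 2.
    let mut positions = vec![(5, 0x2000), (5, 0x1800), (4, 0x2200)];
    let sync = sync_position(&mut positions, 3);
    // Sorted: (4, 0x2200) < (5, 0x1800) < (5, 0x2000); index 1 gives (5, 0x1800).
    assert_eq!(sync, Some((5, 0x1800)));
    println!("sync position = {sync:?}");
}
```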
+ if wal_positions.len() < quorum_size { + return Err(ApiError::InternalServerError(anyhow::anyhow!( + "not enough successful responses to get sync position: {}/{}", + wal_positions.len(), + quorum_size, + ))); + } + + wal_positions.sort(); + + Ok(wal_positions[quorum_size - 1]) + } } diff --git a/storage_controller/src/tenant_shard.rs b/storage_controller/src/tenant_shard.rs index f60378470e..bf16c642af 100644 --- a/storage_controller/src/tenant_shard.rs +++ b/storage_controller/src/tenant_shard.rs @@ -249,6 +249,10 @@ impl IntentState { } pub(crate) fn push_secondary(&mut self, scheduler: &mut Scheduler, new_secondary: NodeId) { + // Every assertion here should probably have a corresponding check in + // `validate_optimization` unless it is an invariant that should never be violated. Note + // that the lock is not held between planning optimizations and applying them so you have to + // assume any valid state transition of the intent state may have occurred assert!(!self.secondary.contains(&new_secondary)); assert!(self.attached != Some(new_secondary)); scheduler.update_node_ref_counts( @@ -808,8 +812,6 @@ impl TenantShard { /// if the swap is not possible and leaves the intent state in its original state. /// /// Arguments: - /// `attached_to`: the currently attached location matching the intent state (may be None if the - /// shard is not attached) /// `promote_to`: an optional secondary location of this tenant shard. If set to None, we ask /// the scheduler to recommend a node pub(crate) fn reschedule_to_secondary( @@ -1335,8 +1337,9 @@ impl TenantShard { true } - /// Check that the desired modifications to the intent state are compatible with - /// the current intent state + /// Check that the desired modifications to the intent state are compatible with the current + /// intent state. Note that the lock is not held between planning optimizations and applying + /// them so any valid state transition of the intent state may have occurred. fn validate_optimization(&self, optimization: &ScheduleOptimization) -> bool { match optimization.action { ScheduleOptimizationAction::MigrateAttachment(MigrateAttachment { @@ -1352,6 +1355,9 @@ impl TenantShard { }) => { // It's legal to remove a secondary that is not present in the intent state !self.intent.secondary.contains(&new_node_id) + // Ensure the secondary hasn't already been promoted to attached by a concurrent + // optimization/migration. 
+ && self.intent.attached != Some(new_node_id) } ScheduleOptimizationAction::CreateSecondary(new_node_id) => { !self.intent.secondary.contains(&new_node_id) diff --git a/test_runner/fixtures/compare_fixtures.py b/test_runner/fixtures/compare_fixtures.py index 150046b99a..3d248efc04 100644 --- a/test_runner/fixtures/compare_fixtures.py +++ b/test_runner/fixtures/compare_fixtures.py @@ -16,6 +16,7 @@ from typing_extensions import override from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker from fixtures.log_helper import log from fixtures.neon_fixtures import ( + Endpoint, NeonEnv, PgBin, PgProtocol, @@ -129,6 +130,10 @@ class NeonCompare(PgCompare): # Start pg self._pg = self.env.endpoints.create_start("main", "main", self.tenant) + @property + def endpoint(self) -> Endpoint: + return self._pg + @property @override def pg(self) -> PgProtocol: diff --git a/test_runner/fixtures/endpoint/http.py b/test_runner/fixtures/endpoint/http.py index c43445e89d..c77a372017 100644 --- a/test_runner/fixtures/endpoint/http.py +++ b/test_runner/fixtures/endpoint/http.py @@ -78,18 +78,36 @@ class EndpointHttpClient(requests.Session): json: dict[str, str] = res.json() return json - def prewarm_lfc(self, from_endpoint_id: str | None = None): + def prewarm_lfc(self, from_endpoint_id: str | None = None) -> dict[str, str]: + """ + Prewarm LFC cache from given endpoint and wait till it finishes or errors + """ params = {"from_endpoint": from_endpoint_id} if from_endpoint_id else dict() self.post(self.prewarm_url, params=params).raise_for_status() - self.prewarm_lfc_wait() + return self.prewarm_lfc_wait() + + def cancel_prewarm_lfc(self): + """ + Cancel LFC prewarm if any is ongoing + """ + self.delete(self.prewarm_url).raise_for_status() + + def prewarm_lfc_wait(self) -> dict[str, str]: + """ + Wait till LFC prewarm returns with error or success. + If prewarm was not requested before calling this function, it will error + """ + statuses = "failed", "completed", "skipped", "cancelled" - def prewarm_lfc_wait(self): def prewarmed(): json = self.prewarm_lfc_status() status, err = json["status"], json.get("error") - assert status == "completed", f"{status}, {err=}" + assert status in statuses, f"{status}, {err=}" wait_until(prewarmed, timeout=60) + res = self.prewarm_lfc_status() + assert res["status"] != "failed", res + return res def offload_lfc_status(self) -> dict[str, str]: res = self.get(self.offload_url) @@ -97,27 +115,38 @@ class EndpointHttpClient(requests.Session): json: dict[str, str] = res.json() return json - def offload_lfc(self): + def offload_lfc(self) -> dict[str, str]: + """ + Offload LFC cache to endpoint storage and wait till offload finishes or errors + """ self.post(self.offload_url).raise_for_status() - self.offload_lfc_wait() + return self.offload_lfc_wait() + + def offload_lfc_wait(self) -> dict[str, str]: + """ + Wait till LFC offload returns with error or success. 
+ If offload was not requested before calling this function, it will error + """ + statuses = "failed", "completed", "skipped" - def offload_lfc_wait(self): def offloaded(): json = self.offload_lfc_status() status, err = json["status"], json.get("error") - assert status == "completed", f"{status}, {err=}" + assert status in statuses, f"{status}, {err=}" - wait_until(offloaded) + wait_until(offloaded, timeout=60) + res = self.offload_lfc_status() + assert res["status"] != "failed", res + return res - def promote(self, safekeepers_lsn: dict[str, Any], disconnect: bool = False): + def promote(self, promote_spec: dict[str, Any], disconnect: bool = False) -> dict[str, str]: url = f"http://localhost:{self.external_port}/promote" if disconnect: try: # send first request to start promote and disconnect - self.post(url, data=safekeepers_lsn, timeout=0.001) + self.post(url, json=promote_spec, timeout=0.001) except ReadTimeout: pass # wait on second request which returns on promotion finish - res = self.post(url, data=safekeepers_lsn) - res.raise_for_status() + res = self.post(url, json=promote_spec) json: dict[str, str] = res.json() return json diff --git a/test_runner/fixtures/neon_api.py b/test_runner/fixtures/neon_api.py index b26bcb286c..69160dab20 100644 --- a/test_runner/fixtures/neon_api.py +++ b/test_runner/fixtures/neon_api.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re import time from typing import TYPE_CHECKING, cast, final @@ -13,6 +14,17 @@ if TYPE_CHECKING: from fixtures.pg_version import PgVersion +def connstr_to_env(connstr: str) -> dict[str, str]: + # postgresql://neondb_owner:npg_kuv6Rqi1cB@ep-old-silence-w26pxsvz-pooler.us-east-2.aws.neon.build/neondb?sslmode=require&channel_binding=...' + parts = re.split(r":|@|\/|\?", connstr.removeprefix("postgresql://")) + return { + "PGUSER": parts[0], + "PGPASSWORD": parts[1], + "PGHOST": parts[2], + "PGDATABASE": parts[3], + } + + def connection_parameters_to_env(params: dict[str, str]) -> dict[str, str]: return { "PGHOST": params["host"], @@ -227,6 +239,16 @@ class NeonAPI: ) return cast("dict[str, Any]", resp.json()) + def reset_to_parent(self, project_id: str, branch_id: str) -> dict[str, Any]: + resp = self.__request( + "POST", + f"/projects/{project_id}/branches/{branch_id}/reset_to_parent", + headers={ + "Accept": "application/json", + }, + ) + return cast("dict[str, Any]", resp.json()) + def restore_branch( self, project_id: str, diff --git a/test_runner/fixtures/neon_cli.py b/test_runner/fixtures/neon_cli.py index 5ad00d155e..c5ca05d2fd 100644 --- a/test_runner/fixtures/neon_cli.py +++ b/test_runner/fixtures/neon_cli.py @@ -212,11 +212,13 @@ class NeonLocalCli(AbstractNeonCli): pg_version, ] if conf is not None: - args.extend( - chain.from_iterable( - product(["-c"], (f"{key}:{value}" for key, value in conf.items())) - ) - ) + for key, value in conf.items(): + if isinstance(value, bool): + args.extend( + ["-c", f"{key}:{str(value).lower()}"] + ) # only accepts true/false not True/False + else: + args.extend(["-c", f"{key}:{value}"]) if set_default: args.append("--set-default") @@ -528,7 +530,10 @@ class NeonLocalCli(AbstractNeonCli): args.extend(["--external-http-port", str(external_http_port)]) if internal_http_port is not None: args.extend(["--internal-http-port", str(internal_http_port)]) - if grpc: + + # XXX: By checking for None, we enable the new communicator for all tests + # by default + if grpc or grpc is None: args.append("--grpc") if endpoint_id is not None: args.append(endpoint_id) @@ -585,7 +590,9 @@ 
class NeonLocalCli(AbstractNeonCli): ] extra_env_vars = env or {} if basebackup_request_tries is not None: - extra_env_vars["NEON_COMPUTE_TESTING_BASEBACKUP_TRIES"] = str(basebackup_request_tries) + extra_env_vars["NEON_COMPUTE_TESTING_BASEBACKUP_RETRIES"] = str( + basebackup_request_tries + ) if remote_ext_base_url is not None: args.extend(["--remote-ext-base-url", remote_ext_base_url]) @@ -621,6 +628,7 @@ class NeonLocalCli(AbstractNeonCli): pageserver_id: int | None = None, safekeepers: list[int] | None = None, check_return_code=True, + timeout_sec: float | None = None, ) -> subprocess.CompletedProcess[str]: args = ["endpoint", "reconfigure", endpoint_id] if tenant_id is not None: @@ -629,7 +637,16 @@ class NeonLocalCli(AbstractNeonCli): args.extend(["--pageserver-id", str(pageserver_id)]) if safekeepers is not None: args.extend(["--safekeepers", (",".join(map(str, safekeepers)))]) - return self.raw_cli(args, check_return_code=check_return_code) + return self.raw_cli(args, check_return_code=check_return_code, timeout=timeout_sec) + + def endpoint_refresh_configuration( + self, + endpoint_id: str, + ) -> subprocess.CompletedProcess[str]: + args = ["endpoint", "refresh-configuration", endpoint_id] + res = self.raw_cli(args) + res.check_returncode() + return res def endpoint_stop( self, @@ -655,6 +672,22 @@ class NeonLocalCli(AbstractNeonCli): lsn: Lsn | None = None if lsn_str == "null" else Lsn(lsn_str) return lsn, proc + def endpoint_update_pageservers( + self, + endpoint_id: str, + pageserver_id: int | None = None, + ) -> subprocess.CompletedProcess[str]: + args = [ + "endpoint", + "update-pageservers", + endpoint_id, + ] + if pageserver_id is not None: + args.extend(["--pageserver-id", str(pageserver_id)]) + res = self.raw_cli(args) + res.check_returncode() + return res + def mappings_map_branch( self, name: str, tenant_id: TenantId, timeline_id: TimelineId ) -> subprocess.CompletedProcess[str]: diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index ee0a2f4fe9..d074706996 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -262,7 +262,6 @@ class PgProtocol: # pooler does not support statement_timeout # Check if the hostname contains the string 'pooler' hostname = result.get("host", "") - log.info(f"Hostname: {hostname}") options = result.get("options", "") if "statement_timeout" not in options and "pooler" not in hostname: options = f"-cstatement_timeout=120s {options}" @@ -1540,6 +1539,17 @@ class NeonEnv: raise RuntimeError(f"Pageserver with ID {id} not found") + def get_safekeeper(self, id: int) -> Safekeeper: + """ + Look up a safekeeper by its ID. 
+ """ + + for sk in self.safekeepers: + if sk.id == id: + return sk + + raise RuntimeError(f"Safekeeper with ID {id} not found") + def get_tenant_pageserver(self, tenant_id: TenantId | TenantShardId): """ Get the NeonPageserver where this tenant shard is currently attached, according @@ -2303,6 +2313,7 @@ class NeonStorageController(MetricsGetter, LogUtils): timeline_id: TimelineId, new_sk_set: list[int], ): + log.info(f"migrate_safekeepers({tenant_id}, {timeline_id}, {new_sk_set})") response = self.request( "POST", f"{self.api}/v1/tenant/{tenant_id}/timeline/{timeline_id}/safekeeper_migrate", @@ -3899,6 +3910,41 @@ class NeonProxy(PgProtocol): assert response.status_code == expected_code, f"response: {response.json()}" return response.json() + def http_multiquery(self, *queries, **kwargs): + # TODO maybe use default values if not provided + user = quote(kwargs["user"]) + password = quote(kwargs["password"]) + expected_code = kwargs.get("expected_code") + timeout = kwargs.get("timeout") + + json_queries = [] + for query in queries: + if type(query) is str: + json_queries.append({"query": query}) + else: + [query, params] = query + json_queries.append({"query": query, "params": params}) + + queries_str = [j["query"] for j in json_queries] + log.info(f"Executing http queries: {queries_str}") + + connstr = f"postgresql://{user}:{password}@{self.domain}:{self.proxy_port}/postgres" + response = requests.post( + f"https://{self.domain}:{self.external_http_port}/sql", + data=json.dumps({"queries": json_queries}), + headers={ + "Content-Type": "application/sql", + "Neon-Connection-String": connstr, + "Neon-Pool-Opt-In": "true", + }, + verify=str(self.test_output_dir / "proxy.crt"), + timeout=timeout, + ) + + if expected_code is not None: + assert response.status_code == expected_code, f"response: {response.json()}" + return response.json() + async def http2_query(self, query, args, **kwargs): # TODO maybe use default values if not provided user = kwargs["user"] @@ -4734,17 +4780,6 @@ class Endpoint(PgProtocol, LogUtils): # and make tests more stable. config_lines += ["max_replication_write_lag=15MB"] - # If gRPC is enabled, use the new communicator too. - # - # NB: the communicator is enabled by default, so force it to false otherwise. 
- # - # XXX: By checking for None, we enable the new communicator for all tests - # by default - if grpc or grpc is None: - config_lines += ["neon.use_communicator_worker=on"] - else: - config_lines += ["neon.use_communicator_worker=off"] - # Delete file cache if it exists (and we're recreating the endpoint) if USE_LFC: if (lfc_path := Path(self.lfc_path())).exists(): @@ -4759,9 +4794,10 @@ class Endpoint(PgProtocol, LogUtils): m = re.search(r"=\s*(\S+)", line) assert m is not None, f"malformed config line {line}" size = m.group(1) - assert size_to_bytes(size) >= size_to_bytes("1MB"), ( - "LFC size cannot be set less than 1MB" - ) + if size_to_bytes(size) > 0: + assert size_to_bytes(size) >= size_to_bytes("1MB"), ( + "LFC size cannot be set less than 1MB" + ) lfc_path_escaped = str(lfc_path).replace("'", "''") config_lines = [ f"neon.file_cache_path = '{lfc_path_escaped}'", @@ -4894,15 +4930,38 @@ class Endpoint(PgProtocol, LogUtils): def is_running(self): return self._running._value > 0 - def reconfigure(self, pageserver_id: int | None = None, safekeepers: list[int] | None = None): + def reconfigure( + self, + pageserver_id: int | None = None, + safekeepers: list[int] | None = None, + timeout_sec: float = 120, + ): assert self.endpoint_id is not None # If `safekeepers` is not None, they are remember them as active and use # in the following commands. if safekeepers is not None: self.active_safekeepers = safekeepers - self.env.neon_cli.endpoint_reconfigure( - self.endpoint_id, self.tenant_id, pageserver_id, self.active_safekeepers - ) + + start_time = time.time() + while True: + try: + self.env.neon_cli.endpoint_reconfigure( + self.endpoint_id, + self.tenant_id, + pageserver_id, + self.active_safekeepers, + timeout_sec=timeout_sec, + ) + return + except RuntimeError as e: + if time.time() - start_time > timeout_sec: + raise e + log.warning(f"Reconfigure failed with error: {e}. Retrying...") + time.sleep(5) + + def refresh_configuration(self): + assert self.endpoint_id is not None + self.env.neon_cli.endpoint_refresh_configuration(self.endpoint_id) def respec(self, **kwargs: Any) -> None: """Update the endpoint.json file used by control_plane.""" @@ -4916,6 +4975,10 @@ class Endpoint(PgProtocol, LogUtils): log.debug(json.dumps(dict(data_dict, **kwargs))) json.dump(dict(data_dict, **kwargs), file, indent=4) + def get_compute_spec(self) -> dict[str, Any]: + out = json.loads((Path(self.endpoint_path()) / "config.json").read_text())["spec"] + return cast("dict[str, Any]", out) + def respec_deep(self, **kwargs: Any) -> None: """ Update the endpoint.json file taking into account nested keys. @@ -4946,6 +5009,10 @@ class Endpoint(PgProtocol, LogUtils): log.debug("Updating compute config to: %s", json.dumps(config, indent=4)) json.dump(config, file, indent=4) + def update_pageservers_in_config(self, pageserver_id: int | None = None): + assert self.endpoint_id is not None + self.env.neon_cli.endpoint_update_pageservers(self.endpoint_id, pageserver_id) + def wait_for_migrations(self, wait_for: int = NUM_COMPUTE_MIGRATIONS) -> None: """ Wait for all compute migrations to be ran. Remember that migrations only @@ -5213,16 +5280,32 @@ class EndpointFactory: ) def stop_all(self, fail_on_error=True) -> Self: - exception = None - for ep in self.endpoints: + """ + Stop all the endpoints in parallel. + """ + + # Note: raising an exception from a task in a task group cancels + # all the other tasks. 
We don't want that, hence the 'stop_one' + # function catches exceptions and puts them on the 'exceptions' + # list for later processing. + exceptions = [] + + async def stop_one(ep): try: - ep.stop() + await asyncio.to_thread(ep.stop) except Exception as e: log.error(f"Failed to stop endpoint {ep.endpoint_id}: {e}") - exception = e + exceptions.append(e) - if fail_on_error and exception is not None: - raise exception + async def async_stop_all(): + async with asyncio.TaskGroup() as tg: + for ep in self.endpoints: + tg.create_task(stop_one(ep)) + + asyncio.run(async_stop_all()) + + if fail_on_error and exceptions: + raise ExceptionGroup("stopping an endpoint failed", exceptions) return self @@ -5402,15 +5485,24 @@ class Safekeeper(LogUtils): return timeline_status.commit_lsn def pull_timeline( - self, srcs: list[Safekeeper], tenant_id: TenantId, timeline_id: TimelineId + self, + srcs: list[Safekeeper], + tenant_id: TenantId, + timeline_id: TimelineId, + mconf: MembershipConfiguration | None = None, ) -> dict[str, Any]: """ pull_timeline from srcs to self. """ src_https = [f"http://localhost:{sk.port.http}" for sk in srcs] - res = self.http_client().pull_timeline( - {"tenant_id": str(tenant_id), "timeline_id": str(timeline_id), "http_hosts": src_https} - ) + body: dict[str, Any] = { + "tenant_id": str(tenant_id), + "timeline_id": str(timeline_id), + "http_hosts": src_https, + } + if mconf is not None: + body["mconf"] = mconf.__dict__ + res = self.http_client().pull_timeline(body) src_ids = [sk.id for sk in srcs] log.info(f"finished pulling timeline from {src_ids} to {self.id}") return res diff --git a/test_runner/fixtures/workload.py b/test_runner/fixtures/workload.py index e17a8e989b..3ac61b5d8c 100644 --- a/test_runner/fixtures/workload.py +++ b/test_runner/fixtures/workload.py @@ -78,6 +78,9 @@ class Workload: """ if self._endpoint is not None: with ENDPOINT_LOCK: + # It's important that we update config.json before issuing the reconfigure request to make sure + # that PG-initiated spec refresh doesn't mess things up by reverting to the old spec. + self._endpoint.update_pageservers_in_config() self._endpoint.reconfigure() def endpoint(self, pageserver_id: int | None = None) -> Endpoint: @@ -97,10 +100,10 @@ class Workload: self._endpoint.start(pageserver_id=pageserver_id) self._configured_pageserver = pageserver_id else: - if self._configured_pageserver != pageserver_id: - self._configured_pageserver = pageserver_id - self._endpoint.reconfigure(pageserver_id=pageserver_id) - self._endpoint_config = pageserver_id + # It's important that we update config.json before issuing the reconfigure request to make sure + # that PG-initiated spec refresh doesn't mess things up by reverting to the old spec. 
+ self._endpoint.update_pageservers_in_config(pageserver_id=pageserver_id) + self._endpoint.reconfigure(pageserver_id=pageserver_id) connstring = self._endpoint.safe_psql( "SELECT setting FROM pg_settings WHERE name='neon.pageserver_connstring'" diff --git a/test_runner/logical_repl/README.md b/test_runner/logical_repl/README.md index 449e56e21d..74af203c03 100644 --- a/test_runner/logical_repl/README.md +++ b/test_runner/logical_repl/README.md @@ -9,9 +9,10 @@ ```bash export BENCHMARK_CONNSTR=postgres://user:pass@ep-abc-xyz-123.us-east-2.aws.neon.build/neondb +export CLICKHOUSE_PASSWORD=ch_password123 docker compose -f test_runner/logical_repl/clickhouse/docker-compose.yml up -d -./scripts/pytest -m remote_cluster -k test_clickhouse +./scripts/pytest -m remote_cluster -k 'test_clickhouse[release-pg17]' docker compose -f test_runner/logical_repl/clickhouse/docker-compose.yml down ``` @@ -21,6 +22,6 @@ docker compose -f test_runner/logical_repl/clickhouse/docker-compose.yml down export BENCHMARK_CONNSTR=postgres://user:pass@ep-abc-xyz-123.us-east-2.aws.neon.build/neondb docker compose -f test_runner/logical_repl/debezium/docker-compose.yml up -d -./scripts/pytest -m remote_cluster -k test_debezium +./scripts/pytest -m remote_cluster -k 'test_debezium[release-pg17]' docker compose -f test_runner/logical_repl/debezium/docker-compose.yml down ``` diff --git a/test_runner/logical_repl/clickhouse/docker-compose.yml b/test_runner/logical_repl/clickhouse/docker-compose.yml index e00038b811..4131fbf0a5 100644 --- a/test_runner/logical_repl/clickhouse/docker-compose.yml +++ b/test_runner/logical_repl/clickhouse/docker-compose.yml @@ -1,9 +1,11 @@ services: clickhouse: - image: clickhouse/clickhouse-server + image: clickhouse/clickhouse-server:25.6 user: "101:101" container_name: clickhouse hostname: clickhouse + environment: + - CLICKHOUSE_PASSWORD=${CLICKHOUSE_PASSWORD:-ch_password123} ports: - 127.0.0.1:8123:8123 - 127.0.0.1:9000:9000 diff --git a/test_runner/logical_repl/debezium/docker-compose.yml b/test_runner/logical_repl/debezium/docker-compose.yml index fee127a2fd..fd446f173f 100644 --- a/test_runner/logical_repl/debezium/docker-compose.yml +++ b/test_runner/logical_repl/debezium/docker-compose.yml @@ -1,18 +1,28 @@ services: zookeeper: - image: quay.io/debezium/zookeeper:2.7 + image: quay.io/debezium/zookeeper:3.1.3.Final + ports: + - 127.0.0.1:2181:2181 + - 127.0.0.1:2888:2888 + - 127.0.0.1:3888:3888 kafka: - image: quay.io/debezium/kafka:2.7 + image: quay.io/debezium/kafka:3.1.3.Final + depends_on: [zookeeper] environment: ZOOKEEPER_CONNECT: "zookeeper:2181" - KAFKA_ADVERTISED_LISTENERS: PLAINTEXT://kafka:9092 + KAFKA_LISTENERS: INTERNAL://:9092,EXTERNAL://:29092 + KAFKA_ADVERTISED_LISTENERS: INTERNAL://kafka:9092,EXTERNAL://localhost:29092 + KAFKA_LISTENER_SECURITY_PROTOCOL_MAP: INTERNAL:PLAINTEXT,EXTERNAL:PLAINTEXT + KAFKA_INTER_BROKER_LISTENER_NAME: INTERNAL KAFKA_BROKER_ID: 1 KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR: 1 KAFKA_JMX_PORT: 9991 ports: - - 127.0.0.1:9092:9092 + - 9092:9092 + - 29092:29092 debezium: - image: quay.io/debezium/connect:2.7 + image: quay.io/debezium/connect:3.1.3.Final + depends_on: [kafka] environment: BOOTSTRAP_SERVERS: kafka:9092 GROUP_ID: 1 diff --git a/test_runner/logical_repl/test_clickhouse.py b/test_runner/logical_repl/test_clickhouse.py index c05684baf9..ef41ee6187 100644 --- a/test_runner/logical_repl/test_clickhouse.py +++ b/test_runner/logical_repl/test_clickhouse.py @@ -53,8 +53,13 @@ def test_clickhouse(remote_pg: RemotePostgres): cur.execute("CREATE 
TABLE table1 (id integer primary key, column1 varchar(10));") cur.execute("INSERT INTO table1 (id, column1) VALUES (1, 'abc'), (2, 'def');") conn.commit() - client = clickhouse_connect.get_client(host=clickhouse_host) + if "CLICKHOUSE_PASSWORD" not in os.environ: + raise RuntimeError("CLICKHOUSE_PASSWORD is not set") + client = clickhouse_connect.get_client( + host=clickhouse_host, password=os.environ["CLICKHOUSE_PASSWORD"] + ) client.command("SET allow_experimental_database_materialized_postgresql=1") + client.command("DROP DATABASE IF EXISTS db1_postgres") client.command( "CREATE DATABASE db1_postgres ENGINE = " f"MaterializedPostgreSQL('{conn_options['host']}', " diff --git a/test_runner/logical_repl/test_debezium.py b/test_runner/logical_repl/test_debezium.py index a53e6cef92..becdaffcf8 100644 --- a/test_runner/logical_repl/test_debezium.py +++ b/test_runner/logical_repl/test_debezium.py @@ -17,6 +17,7 @@ from fixtures.utils import wait_until if TYPE_CHECKING: from fixtures.neon_fixtures import RemotePostgres + from kafka import KafkaConsumer class DebeziumAPI: @@ -101,9 +102,13 @@ def debezium(remote_pg: RemotePostgres): assert len(dbz.list_connectors()) == 1 from kafka import KafkaConsumer + kafka_host = "kafka" if (os.getenv("CI", "false") == "true") else "127.0.0.1" + kafka_port = 9092 if (os.getenv("CI", "false") == "true") else 29092 + log.info("Connecting to Kafka: %s:%s", kafka_host, kafka_port) + consumer = KafkaConsumer( "dbserver1.inventory.customers", - bootstrap_servers=["kafka:9092"], + bootstrap_servers=[f"{kafka_host}:{kafka_port}"], auto_offset_reset="earliest", enable_auto_commit=False, ) @@ -112,7 +117,7 @@ def debezium(remote_pg: RemotePostgres): assert resp.status_code == 204 -def get_kafka_msg(consumer, ts_ms, before=None, after=None) -> None: +def get_kafka_msg(consumer: KafkaConsumer, ts_ms, before=None, after=None) -> None: """ Gets the message from Kafka and checks its validity Arguments: @@ -124,6 +129,7 @@ def get_kafka_msg(consumer, ts_ms, before=None, after=None) -> None: after: a dictionary, if not None, the after field from the kafka message must have the same values for the same keys """ + log.info("Bootstrap servers: %s", consumer.config["bootstrap_servers"]) msg = consumer.poll() assert msg, "Empty message" for val in msg.values(): diff --git a/test_runner/performance/benchbase_tpc_c_helpers/generate_diagrams.py b/test_runner/performance/benchbase_tpc_c_helpers/generate_diagrams.py new file mode 100644 index 0000000000..cf41a4ff59 --- /dev/null +++ b/test_runner/performance/benchbase_tpc_c_helpers/generate_diagrams.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +""" +Generate TPS and latency charts from BenchBase TPC-C results CSV files. + +This script reads a CSV file containing BenchBase results and generates two charts: +1. TPS (requests per second) over time +2. P95 and P99 latencies over time + +Both charts are combined in a single SVG file. 
+""" + +import argparse +import sys +from pathlib import Path + +import matplotlib.pyplot as plt # type: ignore[import-not-found] +import pandas as pd # type: ignore[import-untyped] + + +def load_results_csv(csv_file_path): + """Load BenchBase results CSV file into a pandas DataFrame.""" + try: + df = pd.read_csv(csv_file_path) + + # Validate required columns exist + required_columns = [ + "Time (seconds)", + "Throughput (requests/second)", + "95th Percentile Latency (millisecond)", + "99th Percentile Latency (millisecond)", + ] + + missing_columns = [col for col in required_columns if col not in df.columns] + if missing_columns: + print(f"Error: Missing required columns: {missing_columns}") + sys.exit(1) + + return df + + except FileNotFoundError: + print(f"Error: CSV file not found: {csv_file_path}") + sys.exit(1) + except pd.errors.EmptyDataError: + print(f"Error: CSV file is empty: {csv_file_path}") + sys.exit(1) + except Exception as e: + print(f"Error reading CSV file: {e}") + sys.exit(1) + + +def generate_charts(df, input_filename, output_svg_path, title_suffix=None): + """Generate combined TPS and latency charts and save as SVG.""" + + # Get the filename without extension for chart titles + file_label = Path(input_filename).stem + + # Build title ending with optional suffix + if title_suffix: + title_ending = f"{title_suffix} - {file_label}" + else: + title_ending = file_label + + # Create figure with two subplots + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10)) + + # Chart 1: Time vs TPS + ax1.plot( + df["Time (seconds)"], + df["Throughput (requests/second)"], + linewidth=1, + color="blue", + alpha=0.7, + ) + ax1.set_xlabel("Time (seconds)") + ax1.set_ylabel("TPS (Requests Per Second)") + ax1.set_title(f"Benchbase TPC-C Like Throughput (TPS) - {title_ending}") + ax1.grid(True, alpha=0.3) + ax1.set_xlim(0, df["Time (seconds)"].max()) + + # Chart 2: Time vs P95 and P99 Latencies + ax2.plot( + df["Time (seconds)"], + df["95th Percentile Latency (millisecond)"], + linewidth=1, + color="orange", + alpha=0.7, + label="Latency P95", + ) + ax2.plot( + df["Time (seconds)"], + df["99th Percentile Latency (millisecond)"], + linewidth=1, + color="red", + alpha=0.7, + label="Latency P99", + ) + ax2.set_xlabel("Time (seconds)") + ax2.set_ylabel("Latency (ms)") + ax2.set_title(f"Benchbase TPC-C Like Latency - {title_ending}") + ax2.grid(True, alpha=0.3) + ax2.set_xlim(0, df["Time (seconds)"].max()) + ax2.legend() + + plt.tight_layout() + + # Save as SVG + try: + plt.savefig(output_svg_path, format="svg", dpi=300, bbox_inches="tight") + print(f"Charts saved to: {output_svg_path}") + except Exception as e: + print(f"Error saving SVG file: {e}") + sys.exit(1) + + +def main(): + """Main function to parse arguments and generate charts.""" + parser = argparse.ArgumentParser( + description="Generate TPS and latency charts from BenchBase TPC-C results CSV" + ) + parser.add_argument( + "--input-csv", type=str, required=True, help="Path to the input CSV results file" + ) + parser.add_argument( + "--output-svg", type=str, required=True, help="Path for the output SVG chart file" + ) + parser.add_argument( + "--title-suffix", + type=str, + required=False, + help="Optional suffix to add to chart titles (e.g., 'Warmup', 'Benchmark Phase')", + ) + + args = parser.parse_args() + + # Validate input file exists + if not Path(args.input_csv).exists(): + print(f"Error: Input CSV file does not exist: {args.input_csv}") + sys.exit(1) + + # Create output directory if it doesn't exist + output_path = 
Path(args.output_svg) + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Load data and generate charts + df = load_results_csv(args.input_csv) + generate_charts(df, args.input_csv, args.output_svg, args.title_suffix) + + print(f"Successfully generated charts from {len(df)} data points") + + +if __name__ == "__main__": + main() diff --git a/test_runner/performance/benchbase_tpc_c_helpers/generate_workload_size.py b/test_runner/performance/benchbase_tpc_c_helpers/generate_workload_size.py new file mode 100644 index 0000000000..1549c74b87 --- /dev/null +++ b/test_runner/performance/benchbase_tpc_c_helpers/generate_workload_size.py @@ -0,0 +1,339 @@ +import argparse +import html +import math +import os +import sys +from pathlib import Path + +CONFIGS_DIR = Path("../configs") +SCRIPTS_DIR = Path("../scripts") + +# Constants +## TODO increase times after testing +WARMUP_TIME_SECONDS = 1200 # 20 minutes +BENCHMARK_TIME_SECONDS = 3600 # 1 hour +RAMP_STEP_TIME_SECONDS = 300 # 5 minutes +BASE_TERMINALS = 130 +TERMINALS_PER_WAREHOUSE = 0.2 +OPTIMAL_RATE_FACTOR = 0.7 # 70% of max rate +BATCH_SIZE = 1000 +LOADER_THREADS = 4 +TRANSACTION_WEIGHTS = "45,43,4,4,4" # NewOrder, Payment, OrderStatus, Delivery, StockLevel +# Ramp-up rate multipliers +RAMP_RATE_FACTORS = [1.5, 1.1, 0.9, 0.7, 0.6, 0.4, 0.6, 0.7, 0.9, 1.1] + +# Templates for XML configs +WARMUP_XML = """ + + POSTGRES + org.postgresql.Driver + jdbc:postgresql://{hostname}/neondb?sslmode=require&ApplicationName=tpcc&reWriteBatchedInserts=true + neondb_owner + {password} + true + TRANSACTION_READ_COMMITTED + {batch_size} + {warehouses} + 0 + {terminals} + + + + {transaction_weights} + unlimited + POISSON + ZIPFIAN + + + + NewOrder + Payment + OrderStatus + Delivery + StockLevel + + +""" + +MAX_RATE_XML = """ + + POSTGRES + org.postgresql.Driver + jdbc:postgresql://{hostname}/neondb?sslmode=require&ApplicationName=tpcc&reWriteBatchedInserts=true + neondb_owner + {password} + true + TRANSACTION_READ_COMMITTED + {batch_size} + {warehouses} + 0 + {terminals} + + + + {transaction_weights} + unlimited + POISSON + ZIPFIAN + + + + NewOrder + Payment + OrderStatus + Delivery + StockLevel + + +""" + +OPT_RATE_XML = """ + + POSTGRES + org.postgresql.Driver + jdbc:postgresql://{hostname}/neondb?sslmode=require&ApplicationName=tpcc&reWriteBatchedInserts=true + neondb_owner + {password} + true + TRANSACTION_READ_COMMITTED + {batch_size} + {warehouses} + 0 + {terminals} + + + + {opt_rate} + {transaction_weights} + POISSON + ZIPFIAN + + + + NewOrder + Payment + OrderStatus + Delivery + StockLevel + + +""" + +RAMP_UP_XML = """ + + POSTGRES + org.postgresql.Driver + jdbc:postgresql://{hostname}/neondb?sslmode=require&ApplicationName=tpcc&reWriteBatchedInserts=true + neondb_owner + {password} + true + TRANSACTION_READ_COMMITTED + {batch_size} + {warehouses} + 0 + {terminals} + +{works} + + + NewOrder + Payment + OrderStatus + Delivery + StockLevel + + +""" + +WORK_TEMPLATE = f""" \n \n {{rate}}\n {TRANSACTION_WEIGHTS}\n POISSON\n ZIPFIAN\n \n""" + +# Templates for shell scripts +EXECUTE_SCRIPT = """# Create results directories +mkdir -p results_warmup +mkdir -p results_{suffix} +chmod 777 results_warmup results_{suffix} + +# Run warmup phase +docker run --network=host --rm \ + -v $(pwd)/configs:/configs \ + -v $(pwd)/results_warmup:/results \ + {docker_image}\ + -b tpcc \ + -c /configs/execute_{warehouses}_warehouses_warmup.xml \ + -d /results \ + --create=false --load=false --execute=true + +# Run benchmark phase +docker run --network=host --rm \ + -v 
$(pwd)/configs:/configs \ + -v $(pwd)/results_{suffix}:/results \ + {docker_image}\ + -b tpcc \ + -c /configs/execute_{warehouses}_warehouses_{suffix}.xml \ + -d /results \ + --create=false --load=false --execute=true\n""" + +LOAD_XML = """ + + POSTGRES + org.postgresql.Driver + jdbc:postgresql://{hostname}/neondb?sslmode=require&ApplicationName=tpcc&reWriteBatchedInserts=true + neondb_owner + {password} + true + TRANSACTION_READ_COMMITTED + {batch_size} + {warehouses} + {loader_threads} + +""" + +LOAD_SCRIPT = """# Create results directory for loading +mkdir -p results_load +chmod 777 results_load + +docker run --network=host --rm \ + -v $(pwd)/configs:/configs \ + -v $(pwd)/results_load:/results \ + {docker_image}\ + -b tpcc \ + -c /configs/load_{warehouses}_warehouses.xml \ + -d /results \ + --create=true --load=true --execute=false\n""" + + +def write_file(path, content): + path.parent.mkdir(parents=True, exist_ok=True) + try: + with open(path, "w") as f: + f.write(content) + except OSError as e: + print(f"Error writing {path}: {e}") + sys.exit(1) + # If it's a shell script, set executable permission + if str(path).endswith(".sh"): + os.chmod(path, 0o755) + + +def escape_xml_password(password): + """Escape XML special characters in password.""" + return html.escape(password, quote=True) + + +def get_docker_arch_tag(runner_arch): + """Map GitHub Actions runner.arch to Docker image architecture tag.""" + arch_mapping = {"X64": "amd64", "ARM64": "arm64"} + return arch_mapping.get(runner_arch, "amd64") # Default to amd64 + + +def main(): + parser = argparse.ArgumentParser(description="Generate BenchBase workload configs and scripts.") + parser.add_argument("--warehouses", type=int, required=True, help="Number of warehouses") + parser.add_argument("--max-rate", type=int, required=True, help="Max rate (TPS)") + parser.add_argument("--hostname", type=str, required=True, help="Database hostname") + parser.add_argument("--password", type=str, required=True, help="Database password") + parser.add_argument( + "--runner-arch", type=str, required=True, help="GitHub Actions runner architecture" + ) + args = parser.parse_args() + + warehouses = args.warehouses + max_rate = args.max_rate + hostname = args.hostname + password = args.password + runner_arch = args.runner_arch + + # Escape password for safe XML insertion + escaped_password = escape_xml_password(password) + + # Get the appropriate Docker architecture tag + docker_arch = get_docker_arch_tag(runner_arch) + docker_image = f"ghcr.io/neondatabase-labs/benchbase-postgres:latest-{docker_arch}" + + opt_rate = math.ceil(max_rate * OPTIMAL_RATE_FACTOR) + # Calculate terminals as next rounded integer of 40% of warehouses + terminals = math.ceil(BASE_TERMINALS + warehouses * TERMINALS_PER_WAREHOUSE) + ramp_rates = [math.ceil(max_rate * factor) for factor in RAMP_RATE_FACTORS] + + # Write configs + write_file( + CONFIGS_DIR / f"execute_{warehouses}_warehouses_warmup.xml", + WARMUP_XML.format( + warehouses=warehouses, + hostname=hostname, + password=escaped_password, + terminals=terminals, + batch_size=BATCH_SIZE, + warmup_time=WARMUP_TIME_SECONDS, + transaction_weights=TRANSACTION_WEIGHTS, + ), + ) + write_file( + CONFIGS_DIR / f"execute_{warehouses}_warehouses_max_rate.xml", + MAX_RATE_XML.format( + warehouses=warehouses, + hostname=hostname, + password=escaped_password, + terminals=terminals, + batch_size=BATCH_SIZE, + benchmark_time=BENCHMARK_TIME_SECONDS, + transaction_weights=TRANSACTION_WEIGHTS, + ), + ) + write_file( + CONFIGS_DIR / 
f"execute_{warehouses}_warehouses_opt_rate.xml", + OPT_RATE_XML.format( + warehouses=warehouses, + opt_rate=opt_rate, + hostname=hostname, + password=escaped_password, + terminals=terminals, + batch_size=BATCH_SIZE, + benchmark_time=BENCHMARK_TIME_SECONDS, + transaction_weights=TRANSACTION_WEIGHTS, + ), + ) + + ramp_works = "".join([WORK_TEMPLATE.format(rate=rate) for rate in ramp_rates]) + write_file( + CONFIGS_DIR / f"execute_{warehouses}_warehouses_ramp_up.xml", + RAMP_UP_XML.format( + warehouses=warehouses, + works=ramp_works, + hostname=hostname, + password=escaped_password, + terminals=terminals, + batch_size=BATCH_SIZE, + ), + ) + + # Loader config + write_file( + CONFIGS_DIR / f"load_{warehouses}_warehouses.xml", + LOAD_XML.format( + warehouses=warehouses, + hostname=hostname, + password=escaped_password, + batch_size=BATCH_SIZE, + loader_threads=LOADER_THREADS, + ), + ) + + # Write scripts + for suffix in ["max_rate", "opt_rate", "ramp_up"]: + script = EXECUTE_SCRIPT.format( + warehouses=warehouses, suffix=suffix, docker_image=docker_image + ) + write_file(SCRIPTS_DIR / f"execute_{warehouses}_warehouses_{suffix}.sh", script) + + # Loader script + write_file( + SCRIPTS_DIR / f"load_{warehouses}_warehouses.sh", + LOAD_SCRIPT.format(warehouses=warehouses, docker_image=docker_image), + ) + + print(f"Generated configs and scripts for {warehouses} warehouses and max rate {max_rate}.") + + +if __name__ == "__main__": + main() diff --git a/test_runner/performance/benchbase_tpc_c_helpers/upload_results_to_perf_test_results.py b/test_runner/performance/benchbase_tpc_c_helpers/upload_results_to_perf_test_results.py new file mode 100644 index 0000000000..3706d14fd4 --- /dev/null +++ b/test_runner/performance/benchbase_tpc_c_helpers/upload_results_to_perf_test_results.py @@ -0,0 +1,591 @@ +#!/usr/bin/env python3 +# ruff: noqa +# we exclude the file from ruff because on the github runner we have python 3.9 and ruff +# is running with newer python 3.12 which suggests changes incompatible with python 3.9 +""" +Upload BenchBase TPC-C results from summary.json and results.csv files to perf_test_results database. + +This script extracts metrics from BenchBase *.summary.json and *.results.csv files and uploads them +to a PostgreSQL database table for performance tracking and analysis. 
+""" + +import argparse +import json +import re +import sys +from datetime import datetime, timezone +from pathlib import Path + +import pandas as pd # type: ignore[import-untyped] +import psycopg2 + + +def load_summary_json(json_file_path): + """Load summary.json file and return parsed data.""" + try: + with open(json_file_path) as f: + return json.load(f) + except FileNotFoundError: + print(f"Error: Summary JSON file not found: {json_file_path}") + sys.exit(1) + except json.JSONDecodeError as e: + print(f"Error: Invalid JSON in file {json_file_path}: {e}") + sys.exit(1) + except Exception as e: + print(f"Error loading JSON file {json_file_path}: {e}") + sys.exit(1) + + +def get_metric_info(metric_name): + """Get metric unit and report type for a given metric name.""" + metrics_config = { + "Throughput": {"unit": "req/s", "report_type": "higher_is_better"}, + "Goodput": {"unit": "req/s", "report_type": "higher_is_better"}, + "Measured Requests": {"unit": "requests", "report_type": "higher_is_better"}, + "95th Percentile Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "Maximum Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "Median Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "Minimum Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "25th Percentile Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "90th Percentile Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "99th Percentile Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "75th Percentile Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "Average Latency": {"unit": "µs", "report_type": "lower_is_better"}, + } + + return metrics_config.get(metric_name, {"unit": "", "report_type": "higher_is_better"}) + + +def extract_metrics(summary_data): + """Extract relevant metrics from summary JSON data.""" + metrics = [] + + # Direct top-level metrics + direct_metrics = { + "Throughput (requests/second)": "Throughput", + "Goodput (requests/second)": "Goodput", + "Measured Requests": "Measured Requests", + } + + for json_key, clean_name in direct_metrics.items(): + if json_key in summary_data: + metrics.append((clean_name, summary_data[json_key])) + + # Latency metrics from nested "Latency Distribution" object + if "Latency Distribution" in summary_data: + latency_data = summary_data["Latency Distribution"] + latency_metrics = { + "95th Percentile Latency (microseconds)": "95th Percentile Latency", + "Maximum Latency (microseconds)": "Maximum Latency", + "Median Latency (microseconds)": "Median Latency", + "Minimum Latency (microseconds)": "Minimum Latency", + "25th Percentile Latency (microseconds)": "25th Percentile Latency", + "90th Percentile Latency (microseconds)": "90th Percentile Latency", + "99th Percentile Latency (microseconds)": "99th Percentile Latency", + "75th Percentile Latency (microseconds)": "75th Percentile Latency", + "Average Latency (microseconds)": "Average Latency", + } + + for json_key, clean_name in latency_metrics.items(): + if json_key in latency_data: + metrics.append((clean_name, latency_data[json_key])) + + return metrics + + +def build_labels(summary_data, project_id): + """Build labels JSON object from summary data and project info.""" + labels = {} + + # Extract required label keys from summary data + label_keys = [ + "DBMS Type", + "DBMS Version", + "Benchmark Type", + "Final State", + "isolation", + "scalefactor", + "terminals", + ] + + for key in label_keys: + if key in summary_data: + labels[key] = 
summary_data[key] + + # Add project_id from workflow + labels["project_id"] = project_id + + return labels + + +def build_suit_name(scalefactor, terminals, run_type, min_cu, max_cu): + """Build the suit name according to specification.""" + return f"benchbase-tpc-c-{scalefactor}-{terminals}-{run_type}-{min_cu}-{max_cu}" + + +def convert_timestamp_to_utc(timestamp_ms): + """Convert millisecond timestamp to PostgreSQL-compatible UTC timestamp.""" + try: + dt = datetime.fromtimestamp(timestamp_ms / 1000.0, tz=timezone.utc) + return dt.isoformat() + except (ValueError, TypeError) as e: + print(f"Warning: Could not convert timestamp {timestamp_ms}: {e}") + return datetime.now(timezone.utc).isoformat() + + +def insert_metrics(conn, metrics_data): + """Insert metrics data into the perf_test_results table.""" + insert_query = """ + INSERT INTO perf_test_results + (suit, revision, platform, metric_name, metric_value, metric_unit, + metric_report_type, recorded_at_timestamp, labels) + VALUES (%(suit)s, %(revision)s, %(platform)s, %(metric_name)s, %(metric_value)s, + %(metric_unit)s, %(metric_report_type)s, %(recorded_at_timestamp)s, %(labels)s) + """ + + try: + with conn.cursor() as cursor: + cursor.executemany(insert_query, metrics_data) + conn.commit() + print(f"Successfully inserted {len(metrics_data)} metrics into perf_test_results") + + # Log some sample data for verification + if metrics_data: + print( + f"Sample metric: {metrics_data[0]['metric_name']} = {metrics_data[0]['metric_value']} {metrics_data[0]['metric_unit']}" + ) + + except Exception as e: + print(f"Error inserting metrics into database: {e}") + sys.exit(1) + + +def create_benchbase_results_details_table(conn): + """Create benchbase_results_details table if it doesn't exist.""" + create_table_query = """ + CREATE TABLE IF NOT EXISTS benchbase_results_details ( + id BIGSERIAL PRIMARY KEY, + suit TEXT, + revision CHAR(40), + platform TEXT, + recorded_at_timestamp TIMESTAMP WITH TIME ZONE, + requests_per_second NUMERIC, + average_latency_ms NUMERIC, + minimum_latency_ms NUMERIC, + p25_latency_ms NUMERIC, + median_latency_ms NUMERIC, + p75_latency_ms NUMERIC, + p90_latency_ms NUMERIC, + p95_latency_ms NUMERIC, + p99_latency_ms NUMERIC, + maximum_latency_ms NUMERIC + ); + + CREATE INDEX IF NOT EXISTS benchbase_results_details_recorded_at_timestamp_idx + ON benchbase_results_details USING BRIN (recorded_at_timestamp); + CREATE INDEX IF NOT EXISTS benchbase_results_details_suit_idx + ON benchbase_results_details USING BTREE (suit text_pattern_ops); + """ + + try: + with conn.cursor() as cursor: + cursor.execute(create_table_query) + conn.commit() + print("Successfully created/verified benchbase_results_details table") + except Exception as e: + print(f"Error creating benchbase_results_details table: {e}") + sys.exit(1) + + +def process_csv_results(csv_file_path, start_timestamp_ms, suit, revision, platform): + """Process CSV results and return data for database insertion.""" + try: + # Read CSV file + df = pd.read_csv(csv_file_path) + + # Validate required columns exist + required_columns = [ + "Time (seconds)", + "Throughput (requests/second)", + "Average Latency (millisecond)", + "Minimum Latency (millisecond)", + "25th Percentile Latency (millisecond)", + "Median Latency (millisecond)", + "75th Percentile Latency (millisecond)", + "90th Percentile Latency (millisecond)", + "95th Percentile Latency (millisecond)", + "99th Percentile Latency (millisecond)", + "Maximum Latency (millisecond)", + ] + + missing_columns = [col for col in 
required_columns if col not in df.columns] + if missing_columns: + print(f"Error: Missing required columns in CSV: {missing_columns}") + return [] + + csv_data = [] + + for _, row in df.iterrows(): + # Calculate timestamp: start_timestamp_ms + (time_seconds * 1000) + time_seconds = row["Time (seconds)"] + row_timestamp_ms = start_timestamp_ms + (time_seconds * 1000) + + # Convert to UTC timestamp + row_timestamp = datetime.fromtimestamp( + row_timestamp_ms / 1000.0, tz=timezone.utc + ).isoformat() + + csv_row = { + "suit": suit, + "revision": revision, + "platform": platform, + "recorded_at_timestamp": row_timestamp, + "requests_per_second": float(row["Throughput (requests/second)"]), + "average_latency_ms": float(row["Average Latency (millisecond)"]), + "minimum_latency_ms": float(row["Minimum Latency (millisecond)"]), + "p25_latency_ms": float(row["25th Percentile Latency (millisecond)"]), + "median_latency_ms": float(row["Median Latency (millisecond)"]), + "p75_latency_ms": float(row["75th Percentile Latency (millisecond)"]), + "p90_latency_ms": float(row["90th Percentile Latency (millisecond)"]), + "p95_latency_ms": float(row["95th Percentile Latency (millisecond)"]), + "p99_latency_ms": float(row["99th Percentile Latency (millisecond)"]), + "maximum_latency_ms": float(row["Maximum Latency (millisecond)"]), + } + csv_data.append(csv_row) + + print(f"Processed {len(csv_data)} rows from CSV file") + return csv_data + + except FileNotFoundError: + print(f"Error: CSV file not found: {csv_file_path}") + return [] + except Exception as e: + print(f"Error processing CSV file {csv_file_path}: {e}") + return [] + + +def insert_csv_results(conn, csv_data): + """Insert CSV results into benchbase_results_details table.""" + if not csv_data: + print("No CSV data to insert") + return + + insert_query = """ + INSERT INTO benchbase_results_details + (suit, revision, platform, recorded_at_timestamp, requests_per_second, + average_latency_ms, minimum_latency_ms, p25_latency_ms, median_latency_ms, + p75_latency_ms, p90_latency_ms, p95_latency_ms, p99_latency_ms, maximum_latency_ms) + VALUES (%(suit)s, %(revision)s, %(platform)s, %(recorded_at_timestamp)s, %(requests_per_second)s, + %(average_latency_ms)s, %(minimum_latency_ms)s, %(p25_latency_ms)s, %(median_latency_ms)s, + %(p75_latency_ms)s, %(p90_latency_ms)s, %(p95_latency_ms)s, %(p99_latency_ms)s, %(maximum_latency_ms)s) + """ + + try: + with conn.cursor() as cursor: + cursor.executemany(insert_query, csv_data) + conn.commit() + print( + f"Successfully inserted {len(csv_data)} detailed results into benchbase_results_details" + ) + + # Log some sample data for verification + sample = csv_data[0] + print( + f"Sample detail: {sample['requests_per_second']} req/s at {sample['recorded_at_timestamp']}" + ) + + except Exception as e: + print(f"Error inserting CSV results into database: {e}") + sys.exit(1) + + +def parse_load_log(log_file_path, scalefactor): + """Parse load log file and extract load metrics.""" + try: + with open(log_file_path) as f: + log_content = f.read() + + # Regex patterns to match the timestamp lines + loading_pattern = r"\[INFO \] (\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}),\d{3}.*Loading data into TPCC database" + finished_pattern = r"\[INFO \] (\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}),\d{3}.*Finished loading data into TPCC database" + + loading_match = re.search(loading_pattern, log_content) + finished_match = re.search(finished_pattern, log_content) + + if not loading_match or not finished_match: + print(f"Warning: Could not find loading 
timestamps in log file {log_file_path}") + return None + + # Parse timestamps + loading_time = datetime.strptime(loading_match.group(1), "%Y-%m-%d %H:%M:%S") + finished_time = datetime.strptime(finished_match.group(1), "%Y-%m-%d %H:%M:%S") + + # Calculate duration in seconds + duration_seconds = (finished_time - loading_time).total_seconds() + + # Calculate throughput: scalefactor/warehouses: 10 warehouses is approx. 1 GB of data + load_throughput = (scalefactor * 1024 / 10.0) / duration_seconds + + # Convert end time to UTC timestamp for database + finished_time_utc = finished_time.replace(tzinfo=timezone.utc).isoformat() + + print(f"Load metrics: Duration={duration_seconds}s, Throughput={load_throughput:.2f} MB/s") + + return { + "duration_seconds": duration_seconds, + "throughput_mb_per_sec": load_throughput, + "end_timestamp": finished_time_utc, + } + + except FileNotFoundError: + print(f"Warning: Load log file not found: {log_file_path}") + return None + except Exception as e: + print(f"Error parsing load log file {log_file_path}: {e}") + return None + + +def insert_load_metrics(conn, load_metrics, suit, revision, platform, labels_json): + """Insert load metrics into perf_test_results table.""" + if not load_metrics: + print("No load metrics to insert") + return + + load_metrics_data = [ + { + "suit": suit, + "revision": revision, + "platform": platform, + "metric_name": "load_duration_seconds", + "metric_value": load_metrics["duration_seconds"], + "metric_unit": "seconds", + "metric_report_type": "lower_is_better", + "recorded_at_timestamp": load_metrics["end_timestamp"], + "labels": labels_json, + }, + { + "suit": suit, + "revision": revision, + "platform": platform, + "metric_name": "load_throughput", + "metric_value": load_metrics["throughput_mb_per_sec"], + "metric_unit": "MB/second", + "metric_report_type": "higher_is_better", + "recorded_at_timestamp": load_metrics["end_timestamp"], + "labels": labels_json, + }, + ] + + insert_query = """ + INSERT INTO perf_test_results + (suit, revision, platform, metric_name, metric_value, metric_unit, + metric_report_type, recorded_at_timestamp, labels) + VALUES (%(suit)s, %(revision)s, %(platform)s, %(metric_name)s, %(metric_value)s, + %(metric_unit)s, %(metric_report_type)s, %(recorded_at_timestamp)s, %(labels)s) + """ + + try: + with conn.cursor() as cursor: + cursor.executemany(insert_query, load_metrics_data) + conn.commit() + print(f"Successfully inserted {len(load_metrics_data)} load metrics into perf_test_results") + + except Exception as e: + print(f"Error inserting load metrics into database: {e}") + sys.exit(1) + + +def main(): + """Main function to parse arguments and upload results.""" + parser = argparse.ArgumentParser( + description="Upload BenchBase TPC-C results to perf_test_results database" + ) + parser.add_argument( + "--summary-json", type=str, required=False, help="Path to the summary.json file" + ) + parser.add_argument( + "--run-type", + type=str, + required=True, + choices=["warmup", "opt-rate", "ramp-up", "load"], + help="Type of benchmark run", + ) + parser.add_argument("--min-cu", type=float, required=True, help="Minimum compute units") + parser.add_argument("--max-cu", type=float, required=True, help="Maximum compute units") + parser.add_argument("--project-id", type=str, required=True, help="Neon project ID") + parser.add_argument( + "--revision", type=str, required=True, help="Git commit hash (40 characters)" + ) + parser.add_argument( + "--connection-string", type=str, required=True, help="PostgreSQL 
connection string" + ) + parser.add_argument( + "--results-csv", + type=str, + required=False, + help="Path to the results.csv file for detailed metrics upload", + ) + parser.add_argument( + "--load-log", + type=str, + required=False, + help="Path to the load log file for load phase metrics", + ) + parser.add_argument( + "--warehouses", + type=int, + required=False, + help="Number of warehouses (scalefactor) for load metrics calculation", + ) + + args = parser.parse_args() + + # Validate inputs + if args.summary_json and not Path(args.summary_json).exists(): + print(f"Error: Summary JSON file does not exist: {args.summary_json}") + sys.exit(1) + + if not args.summary_json and not args.load_log: + print("Error: Either summary JSON or load log file must be provided") + sys.exit(1) + + if len(args.revision) != 40: + print(f"Warning: Revision should be 40 characters, got {len(args.revision)}") + + # Load and process summary data if provided + summary_data = None + metrics = [] + + if args.summary_json: + summary_data = load_summary_json(args.summary_json) + metrics = extract_metrics(summary_data) + if not metrics: + print("Warning: No metrics found in summary JSON") + + # Build common data for all metrics + if summary_data: + scalefactor = summary_data.get("scalefactor", "unknown") + terminals = summary_data.get("terminals", "unknown") + labels = build_labels(summary_data, args.project_id) + else: + # For load-only processing, use warehouses argument as scalefactor + scalefactor = args.warehouses if args.warehouses else "unknown" + terminals = "unknown" + labels = {"project_id": args.project_id} + + suit = build_suit_name(scalefactor, terminals, args.run_type, args.min_cu, args.max_cu) + platform = f"prod-us-east-2-{args.project_id}" + + # Convert timestamp - only needed for summary metrics and CSV processing + current_timestamp_ms = None + start_timestamp_ms = None + recorded_at = None + + if summary_data: + current_timestamp_ms = summary_data.get("Current Timestamp (milliseconds)") + start_timestamp_ms = summary_data.get("Start timestamp (milliseconds)") + + if current_timestamp_ms: + recorded_at = convert_timestamp_to_utc(current_timestamp_ms) + else: + print("Warning: No timestamp found in JSON, using current time") + recorded_at = datetime.now(timezone.utc).isoformat() + + if not start_timestamp_ms: + print("Warning: No start timestamp found in JSON, CSV upload may be incorrect") + start_timestamp_ms = ( + current_timestamp_ms or datetime.now(timezone.utc).timestamp() * 1000 + ) + + # Print Grafana dashboard link for cross-service endpoint debugging + if start_timestamp_ms and current_timestamp_ms: + grafana_url = ( + f"https://neonprod.grafana.net/d/cdya0okb81zwga/cross-service-endpoint-debugging" + f"?orgId=1&from={int(start_timestamp_ms)}&to={int(current_timestamp_ms)}" + f"&timezone=utc&var-env=prod&var-input_project_id={args.project_id}" + ) + print(f'Cross service endpoint dashboard for "{args.run_type}" phase: {grafana_url}') + + # Prepare metrics data for database insertion (only if we have summary metrics) + metrics_data = [] + if metrics and recorded_at: + for metric_name, metric_value in metrics: + metric_info = get_metric_info(metric_name) + + row = { + "suit": suit, + "revision": args.revision, + "platform": platform, + "metric_name": metric_name, + "metric_value": float(metric_value), # Ensure numeric type + "metric_unit": metric_info["unit"], + "metric_report_type": metric_info["report_type"], + "recorded_at_timestamp": recorded_at, + "labels": json.dumps(labels), # Convert 
to JSON string for JSONB column + } + metrics_data.append(row) + + print(f"Prepared {len(metrics_data)} summary metrics for upload to database") + print(f"Suit: {suit}") + print(f"Platform: {platform}") + + # Connect to database and insert metrics + try: + conn = psycopg2.connect(args.connection_string) + + # Insert summary metrics into perf_test_results (if any) + if metrics_data: + insert_metrics(conn, metrics_data) + else: + print("No summary metrics to upload") + + # Process and insert detailed CSV results if provided + if args.results_csv: + print(f"Processing detailed CSV results from: {args.results_csv}") + + # Create table if it doesn't exist + create_benchbase_results_details_table(conn) + + # Process CSV data + csv_data = process_csv_results( + args.results_csv, start_timestamp_ms, suit, args.revision, platform + ) + + # Insert CSV data + if csv_data: + insert_csv_results(conn, csv_data) + else: + print("No CSV data to upload") + else: + print("No CSV file provided, skipping detailed results upload") + + # Process and insert load metrics if provided + if args.load_log: + print(f"Processing load metrics from: {args.load_log}") + + # Parse load log and extract metrics + load_metrics = parse_load_log(args.load_log, scalefactor) + + # Insert load metrics + if load_metrics: + insert_load_metrics( + conn, load_metrics, suit, args.revision, platform, json.dumps(labels) + ) + else: + print("No load metrics to upload") + else: + print("No load log file provided, skipping load metrics upload") + + conn.close() + print("Database upload completed successfully") + + except psycopg2.Error as e: + print(f"Database connection/query error: {e}") + sys.exit(1) + except Exception as e: + print(f"Unexpected error: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/test_runner/performance/test_lfc_prewarm.py b/test_runner/performance/test_lfc_prewarm.py index 6c0083de95..d459f9f3bf 100644 --- a/test_runner/performance/test_lfc_prewarm.py +++ b/test_runner/performance/test_lfc_prewarm.py @@ -2,45 +2,48 @@ from __future__ import annotations import os import timeit -import traceback -from concurrent.futures import ThreadPoolExecutor as Exec from pathlib import Path +from threading import Thread from time import sleep -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, cast import pytest from fixtures.benchmark_fixture import NeonBenchmarker, PgBenchRunResult from fixtures.log_helper import log -from fixtures.neon_api import NeonAPI, connection_parameters_to_env +from fixtures.neon_api import NeonAPI, connstr_to_env + +from performance.test_perf_pgbench import utc_now_timestamp if TYPE_CHECKING: from fixtures.compare_fixtures import NeonCompare from fixtures.neon_fixtures import Endpoint, PgBin from fixtures.pg_version import PgVersion -from performance.test_perf_pgbench import utc_now_timestamp # These tests compare performance for a write-heavy and read-heavy workloads of an ordinary endpoint -# compared to the endpoint which saves its LFC and prewarms using it on startup. 
+# compared to the endpoint which saves its LFC and prewarms using it on startup def test_compare_prewarmed_pgbench_perf(neon_compare: NeonCompare): env = neon_compare.env - env.create_branch("normal") env.create_branch("prewarmed") pg_bin = neon_compare.pg_bin - ep_normal: Endpoint = env.endpoints.create_start("normal") - ep_prewarmed: Endpoint = env.endpoints.create_start("prewarmed", autoprewarm=True) + ep_ordinary: Endpoint = neon_compare.endpoint + ep_prewarmed: Endpoint = env.endpoints.create_start("prewarmed") - for ep in [ep_normal, ep_prewarmed]: + for ep in [ep_ordinary, ep_prewarmed]: connstr: str = ep.connstr() pg_bin.run(["pgbench", "-i", "-I", "dtGvp", connstr, "-s100"]) - ep.safe_psql("CREATE EXTENSION neon") - client = ep.http_client() - client.offload_lfc() - ep.stop() - ep.start() - client.prewarm_lfc_wait() + ep.safe_psql("CREATE SCHEMA neon; CREATE EXTENSION neon WITH SCHEMA neon") + if ep == ep_prewarmed: + client = ep.http_client() + client.offload_lfc() + ep.stop() + ep.start(autoprewarm=True) + client.prewarm_lfc_wait() + else: + ep.stop() + ep.start() run_start_timestamp = utc_now_timestamp() t0 = timeit.default_timer() @@ -59,6 +62,36 @@ def test_compare_prewarmed_pgbench_perf(neon_compare: NeonCompare): neon_compare.zenbenchmark.record_pg_bench_result(name, res) +def test_compare_prewarmed_read_perf(neon_compare: NeonCompare): + env = neon_compare.env + env.create_branch("prewarmed") + ep_ordinary: Endpoint = neon_compare.endpoint + ep_prewarmed: Endpoint = env.endpoints.create_start("prewarmed") + + sql = [ + "CREATE SCHEMA neon", + "CREATE EXTENSION neon WITH SCHEMA neon", + "CREATE TABLE foo(key serial primary key, t text default 'foooooooooooooooooooooooooooooooooooooooooooooooooooo')", + "INSERT INTO foo SELECT FROM generate_series(1,1000000)", + ] + sql_check = "SELECT count(*) from foo" + + ep_ordinary.safe_psql_many(sql) + ep_ordinary.stop() + ep_ordinary.start() + with neon_compare.record_duration("ordinary_run_duration"): + ep_ordinary.safe_psql(sql_check) + + ep_prewarmed.safe_psql_many(sql) + client = ep_prewarmed.http_client() + client.offload_lfc() + ep_prewarmed.stop() + ep_prewarmed.start(autoprewarm=True) + client.prewarm_lfc_wait() + with neon_compare.record_duration("prewarmed_run_duration"): + ep_prewarmed.safe_psql(sql_check) + + @pytest.mark.remote_cluster @pytest.mark.timeout(2 * 60 * 60) def test_compare_prewarmed_pgbench_perf_benchmark( @@ -67,67 +100,66 @@ def test_compare_prewarmed_pgbench_perf_benchmark( pg_version: PgVersion, zenbenchmark: NeonBenchmarker, ): - name = f"Test prewarmed pgbench performance, GITHUB_RUN_ID={os.getenv('GITHUB_RUN_ID')}" - project = neon_api.create_project(pg_version, name) - project_id = project["project"]["id"] - neon_api.wait_for_operation_to_finish(project_id) - err = False - try: - benchmark_impl(pg_bin, neon_api, project, zenbenchmark) - except Exception as e: - err = True - log.error(f"Caught exception: {e}") - log.error(traceback.format_exc()) - finally: - assert not err - neon_api.delete_project(project_id) + """ + Prewarm API is not public, so this test relies on a pre-created project + with pgbench size of 3424, pgbench -i -IdtGvp -s3424. 
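+    The project must contain two branches named "ordinary" and "prewarmed", each with an
+    endpoint whose connection URI is looked up through the API at run time.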
Sleeping and + offloading constants are hardcoded to this size as well + """ + project_id = os.getenv("PROJECT_ID") + assert project_id + ordinary_branch_id = "" + prewarmed_branch_id = "" + for branch in neon_api.get_branches(project_id)["branches"]: + if branch["name"] == "ordinary": + ordinary_branch_id = branch["id"] + if branch["name"] == "prewarmed": + prewarmed_branch_id = branch["id"] + assert len(ordinary_branch_id) > 0 + assert len(prewarmed_branch_id) > 0 + + ep_ordinary = None + ep_prewarmed = None + for ep in neon_api.get_endpoints(project_id)["endpoints"]: + if ep["branch_id"] == ordinary_branch_id: + ep_ordinary = ep + if ep["branch_id"] == prewarmed_branch_id: + ep_prewarmed = ep + assert ep_ordinary + assert ep_prewarmed + ordinary_id = ep_ordinary["id"] + prewarmed_id = ep_prewarmed["id"] -def benchmark_impl( - pg_bin: PgBin, neon_api: NeonAPI, project: dict[str, Any], zenbenchmark: NeonBenchmarker -): - pgbench_size = int(os.getenv("PGBENCH_SIZE") or "3424") # 50GB offload_secs = 20 - test_duration_min = 5 + test_duration_min = 3 pgbench_duration = f"-T{test_duration_min * 60}" - # prewarm API is not publicly exposed. In order to test performance of a - # fully prewarmed endpoint, wait after it restarts. - # The number here is empirical, based on manual runs on staging + pgbench_init_cmd = ["pgbench", "-P10", "-n", "-c10", pgbench_duration, "-Mprepared"] + pgbench_perf_cmd = pgbench_init_cmd + ["-S"] prewarmed_sleep_secs = 180 - branch_id = project["branch"]["id"] - project_id = project["project"]["id"] - normal_env = connection_parameters_to_env( - project["connection_uris"][0]["connection_parameters"] - ) - normal_id = project["endpoints"][0]["id"] - - prewarmed_branch_id = neon_api.create_branch( - project_id, "prewarmed", parent_id=branch_id, add_endpoint=False - )["branch"]["id"] - neon_api.wait_for_operation_to_finish(project_id) - - ep_prewarmed = neon_api.create_endpoint( - project_id, - prewarmed_branch_id, - endpoint_type="read_write", - settings={"autoprewarm": True, "offload_lfc_interval_seconds": offload_secs}, - ) - neon_api.wait_for_operation_to_finish(project_id) - - prewarmed_env = normal_env.copy() - prewarmed_env["PGHOST"] = ep_prewarmed["endpoint"]["host"] - prewarmed_id = ep_prewarmed["endpoint"]["id"] + ordinary_uri = neon_api.get_connection_uri(project_id, ordinary_branch_id, ordinary_id)["uri"] + prewarmed_uri = neon_api.get_connection_uri(project_id, prewarmed_branch_id, prewarmed_id)[ + "uri" + ] def bench(endpoint_name, endpoint_id, env): - pg_bin.run(["pgbench", "-i", "-I", "dtGvp", f"-s{pgbench_size}"], env) - sleep(offload_secs * 2) # ensure LFC is offloaded after pgbench finishes - neon_api.restart_endpoint(project_id, endpoint_id) - sleep(prewarmed_sleep_secs) + log.info(f"Running pgbench for {pgbench_duration}s to warm up the cache") + pg_bin.run_capture(pgbench_init_cmd, env) # capture useful for debugging + log.info(f"Initialized {endpoint_name}") + if endpoint_name == "prewarmed": + log.info(f"sleeping {offload_secs * 2} to ensure LFC is offloaded") + sleep(offload_secs * 2) + neon_api.restart_endpoint(project_id, endpoint_id) + log.info(f"sleeping {prewarmed_sleep_secs} to ensure LFC is prewarmed") + sleep(prewarmed_sleep_secs) + else: + neon_api.restart_endpoint(project_id, endpoint_id) + + log.info(f"Starting benchmark for {endpoint_name}") run_start_timestamp = utc_now_timestamp() t0 = timeit.default_timer() - out = pg_bin.run_capture(["pgbench", "-c10", pgbench_duration, "-Mprepared"], env) + out = 
pg_bin.run_capture(pgbench_perf_cmd, env) run_duration = timeit.default_timer() - t0 run_end_timestamp = utc_now_timestamp() @@ -140,29 +172,9 @@ def benchmark_impl( ) zenbenchmark.record_pg_bench_result(endpoint_name, res) - with Exec(max_workers=2) as exe: - exe.submit(bench, "normal", normal_id, normal_env) - exe.submit(bench, "prewarmed", prewarmed_id, prewarmed_env) + prewarmed_args = ("prewarmed", prewarmed_id, connstr_to_env(prewarmed_uri)) + prewarmed_thread = Thread(target=bench, args=prewarmed_args) + prewarmed_thread.start() - -def test_compare_prewarmed_read_perf(neon_compare: NeonCompare): - env = neon_compare.env - env.create_branch("normal") - env.create_branch("prewarmed") - ep_normal: Endpoint = env.endpoints.create_start("normal") - ep_prewarmed: Endpoint = env.endpoints.create_start("prewarmed", autoprewarm=True) - - sql = [ - "CREATE EXTENSION neon", - "CREATE TABLE foo(key serial primary key, t text default 'foooooooooooooooooooooooooooooooooooooooooooooooooooo')", - "INSERT INTO foo SELECT FROM generate_series(1,1000000)", - ] - for ep in [ep_normal, ep_prewarmed]: - ep.safe_psql_many(sql) - client = ep.http_client() - client.offload_lfc() - ep.stop() - ep.start() - client.prewarm_lfc_wait() - with neon_compare.record_duration(f"{ep.branch_name}_run_duration"): - ep.safe_psql("SELECT count(*) from foo") + bench("ordinary", ordinary_id, connstr_to_env(ordinary_uri)) + prewarmed_thread.join() diff --git a/test_runner/performance/test_perf_many_relations.py b/test_runner/performance/test_perf_many_relations.py index 81dae53759..9204a6a740 100644 --- a/test_runner/performance/test_perf_many_relations.py +++ b/test_runner/performance/test_perf_many_relations.py @@ -5,7 +5,7 @@ import pytest from fixtures.benchmark_fixture import NeonBenchmarker from fixtures.compare_fixtures import RemoteCompare from fixtures.log_helper import log -from fixtures.neon_fixtures import NeonEnvBuilder +from fixtures.neon_fixtures import NeonEnvBuilder, wait_for_last_flush_lsn from fixtures.utils import shared_buffers_for_max_cu @@ -69,13 +69,18 @@ def test_perf_many_relations(remote_compare: RemoteCompare, num_relations: int): ) -def test_perf_simple_many_relations_reldir_v2( - neon_env_builder: NeonEnvBuilder, zenbenchmark: NeonBenchmarker +@pytest.mark.parametrize( + "reldir,num_relations", + [("v1", 10000), ("v1v2", 10000), ("v2", 10000), ("v2", 100000)], + ids=["v1-small", "v1v2-small", "v2-small", "v2-large"], +) +def test_perf_simple_many_relations_reldir( + neon_env_builder: NeonEnvBuilder, zenbenchmark: NeonBenchmarker, reldir: str, num_relations: int ): """ Test creating many relations in a single database. 
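+    Parametrized over the reldir migration state (v1 legacy, v1v2 migrating, v2 migrated)
+    and the number of relations.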
""" - env = neon_env_builder.init_start(initial_tenant_conf={"rel_size_v2_enabled": "true"}) + env = neon_env_builder.init_start(initial_tenant_conf={"rel_size_v2_enabled": reldir != "v1"}) ep = env.endpoints.create_start( "main", config_lines=[ @@ -85,14 +90,38 @@ def test_perf_simple_many_relations_reldir_v2( ], ) - assert ( - env.pageserver.http_client().timeline_detail(env.initial_tenant, env.initial_timeline)[ - "rel_size_migration" - ] - != "legacy" - ) + ep.safe_psql("CREATE TABLE IF NOT EXISTS initial_table (v1 int)") + wait_for_last_flush_lsn(env, ep, env.initial_tenant, env.initial_timeline) - n = 100000 + if reldir == "v1": + assert ( + env.pageserver.http_client().timeline_detail(env.initial_tenant, env.initial_timeline)[ + "rel_size_migration" + ] + == "legacy" + ) + elif reldir == "v1v2": + assert ( + env.pageserver.http_client().timeline_detail(env.initial_tenant, env.initial_timeline)[ + "rel_size_migration" + ] + == "migrating" + ) + elif reldir == "v2": + # only read/write to the v2 keyspace + env.pageserver.http_client().timeline_patch_index_part( + env.initial_tenant, env.initial_timeline, {"rel_size_migration": "migrated"} + ) + assert ( + env.pageserver.http_client().timeline_detail(env.initial_tenant, env.initial_timeline)[ + "rel_size_migration" + ] + == "migrated" + ) + else: + raise AssertionError(f"Invalid reldir config: {reldir}") + + n = num_relations step = 5000 # Create many relations log.info(f"Creating {n} relations...") diff --git a/test_runner/random_ops/test_random_ops.py b/test_runner/random_ops/test_random_ops.py index b106e9b729..aae17e2fc4 100644 --- a/test_runner/random_ops/test_random_ops.py +++ b/test_runner/random_ops/test_random_ops.py @@ -96,6 +96,11 @@ class NeonBranch: ) self.benchmark: subprocess.Popen[Any] | None = None self.updated_at: datetime = datetime.fromisoformat(branch["branch"]["updated_at"]) + self.parent_timestamp: datetime = ( + datetime.fromisoformat(branch["branch"]["parent_timestamp"]) + if "parent_timestamp" in branch["branch"] + else datetime.fromtimestamp(0, tz=UTC) + ) self.connect_env: dict[str, str] | None = None if self.connection_parameters: self.connect_env = { @@ -113,8 +118,18 @@ class NeonBranch: """ return f"{self.id}{'(r)' if self.id in self.project.reset_branches else ''}, parent: {self.parent}" - def create_child_branch(self) -> NeonBranch | None: - return self.project.create_branch(self.id) + def random_time(self) -> datetime: + min_time = max( + self.updated_at + timedelta(seconds=1), + self.project.min_time, + self.parent_timestamp + timedelta(seconds=1), + ) + max_time = datetime.now(UTC) - timedelta(seconds=1) + log.info("min_time: %s, max_time: %s", min_time, max_time) + return (min_time + (max_time - min_time) * random.random()).replace(microsecond=0) + + def create_child_branch(self, parent_timestamp: datetime | None = None) -> NeonBranch | None: + return self.project.create_branch(self.id, parent_timestamp) def create_ro_endpoint(self) -> NeonEndpoint | None: if not self.project.check_limit_endpoints(): @@ -136,21 +151,33 @@ class NeonBranch: def terminate_benchmark(self) -> None: self.project.terminate_benchmark(self.id) + def reset_to_parent(self) -> None: + for ep in self.project.endpoints.values(): + if ep.type == "read_only": + ep.terminate_benchmark() + self.terminate_benchmark() + res = self.neon_api.reset_to_parent(self.project_id, self.id) + self.updated_at = datetime.fromisoformat(res["branch"]["updated_at"]) + self.parent_timestamp = datetime.fromisoformat(res["branch"]["parent_timestamp"]) 
+ self.project.wait() + self.start_benchmark() + for ep in self.project.endpoints.values(): + if ep.type == "read_only": + ep.start_benchmark() + def restore_random_time(self) -> None: """ Does PITR, i.e. calls the reset API call on the same branch to the random time in the past """ - min_time = self.updated_at + timedelta(seconds=1) - max_time = datetime.now(UTC) - timedelta(seconds=1) - target_time = (min_time + (max_time - min_time) * random.random()).replace(microsecond=0) res = self.restore( self.id, - source_timestamp=target_time.isoformat().replace("+00:00", "Z"), + source_timestamp=self.random_time().isoformat().replace("+00:00", "Z"), preserve_under_name=self.project.gen_restore_name(), ) if res is None: return self.updated_at = datetime.fromisoformat(res["branch"]["updated_at"]) + self.parent_timestamp = datetime.fromisoformat(res["branch"]["parent_timestamp"]) parent_id: str = res["branch"]["parent_id"] # Creates an object for the parent branch # After the reset operation a new parent branch is created @@ -225,6 +252,7 @@ class NeonProject: self.restart_pgbench_on_console_errors: bool = False self.limits: dict[str, Any] = self.get_limits()["limits"] self.read_only_endpoints_total: int = 0 + self.min_time: datetime = datetime.now(UTC) def get_limits(self) -> dict[str, Any]: return self.neon_api.get_project_limits(self.id) @@ -251,11 +279,20 @@ class NeonProject: ) return False - def create_branch(self, parent_id: str | None = None) -> NeonBranch | None: + def create_branch( + self, parent_id: str | None = None, parent_timestamp: datetime | None = None + ) -> NeonBranch | None: self.wait() if not self.check_limit_branches(): return None - branch_def = self.neon_api.create_branch(self.id, parent_id=parent_id) + if parent_timestamp: + log.info("Timestamp: %s", parent_timestamp) + parent_timestamp_str: str | None = None + if parent_timestamp: + parent_timestamp_str = parent_timestamp.isoformat().replace("+00:00", "Z") + branch_def = self.neon_api.create_branch( + self.id, parent_id=parent_id, parent_timestamp=parent_timestamp_str + ) new_branch = NeonBranch(self, branch_def) self.wait() return new_branch @@ -288,6 +325,14 @@ class NeonProject: if parent.id in self.reset_branches: parent.delete() + def get_random_leaf_branch(self) -> NeonBranch | None: + target: NeonBranch | None = None + if self.leaf_branches: + target = random.choice(list(self.leaf_branches.values())) + else: + log.info("No leaf branches found") + return target + def delete_endpoint(self, endpoint_id: str) -> None: self.terminate_benchmark(endpoint_id) self.neon_api.delete_endpoint(self.id, endpoint_id) @@ -390,24 +435,22 @@ def do_action(project: NeonProject, action: str) -> bool: Runs the action """ log.info("Action: %s", action) - if action == "new_branch": - log.info("Trying to create a new branch") + if action == "new_branch" or action == "new_branch_random_time": + use_random_time: bool = action == "new_branch_random_time" + log.info("Trying to create a new branch %s", "random time" if use_random_time else "") parent = project.branches[ random.choice(list(set(project.branches.keys()) - project.reset_branches)) ] - child = parent.create_child_branch() + child = parent.create_child_branch(parent.random_time() if use_random_time else None) if child is None: return False log.info("Created branch %s", child) child.start_benchmark() elif action == "delete_branch": - if project.leaf_branches: - target: NeonBranch = random.choice(list(project.leaf_branches.values())) - log.info("Trying to delete branch %s", target) - 
target.delete() - else: - log.info("Leaf branches not found, skipping") + if (target := project.get_random_leaf_branch()) is None: return False + log.info("Trying to delete branch %s", target) + target.delete() elif action == "new_ro_endpoint": ep = random.choice( [br for br in project.branches.values() if br.id not in project.reset_branches] @@ -427,13 +470,15 @@ def do_action(project: NeonProject, action: str) -> bool: target_ep.delete() log.info("endpoint %s deleted", target_ep.id) elif action == "restore_random_time": - if project.leaf_branches: - br: NeonBranch = random.choice(list(project.leaf_branches.values())) - log.info("Restore %s", br) - br.restore_random_time() - else: - log.info("No leaf branches found") + if (target := project.get_random_leaf_branch()) is None: return False + log.info("Restore %s", target) + target.restore_random_time() + elif action == "reset_to_parent": + if (target := project.get_random_leaf_branch()) is None: + return False + log.info("Reset to parent %s", target) + target.reset_to_parent() else: raise ValueError(f"The action {action} is unknown") return True @@ -460,17 +505,22 @@ def test_api_random( pg_bin, project = setup_class # Here we can assign weights ACTIONS = ( - ("new_branch", 1.5), + ("new_branch", 1.2), + ("new_branch_random_time", 0.5), ("new_ro_endpoint", 1.4), ("delete_ro_endpoint", 0.8), - ("delete_branch", 1.0), - ("restore_random_time", 1.2), + ("delete_branch", 1.2), + ("restore_random_time", 0.9), + ("reset_to_parent", 0.3), ) if num_ops_env := os.getenv("NUM_OPERATIONS"): num_operations = int(num_ops_env) else: num_operations = 250 pg_bin.run(["pgbench", "-i", "-I", "dtGvp", "-s100"], env=project.main_branch.connect_env) + # To not go to the past where pgbench tables do not exist + time.sleep(1) + project.min_time = datetime.now(UTC) for _ in range(num_operations): log.info("Starting action #%s", _ + 1) while not do_action( diff --git a/test_runner/regress/test_bad_connection.py b/test_runner/regress/test_bad_connection.py index d31c0c95d3..3c30296e6e 100644 --- a/test_runner/regress/test_bad_connection.py +++ b/test_runner/regress/test_bad_connection.py @@ -26,7 +26,7 @@ def test_compute_pageserver_connection_stress(neon_env_builder: NeonEnvBuilder): # Enable failpoint before starting everything else up so that we exercise the retry # on fetching basebackup pageserver_http = env.pageserver.http_client() - pageserver_http.configure_failpoints(("simulated-bad-compute-connection", "50%return(15)")) + pageserver_http.configure_failpoints(("simulated-bad-compute-connection", "20%return(15)")) env.create_branch("test_compute_pageserver_connection_stress") endpoint = env.endpoints.create_start("test_compute_pageserver_connection_stress") diff --git a/test_runner/regress/test_basebackup.py b/test_runner/regress/test_basebackup.py index d1b10ec85d..23b9105617 100644 --- a/test_runner/regress/test_basebackup.py +++ b/test_runner/regress/test_basebackup.py @@ -2,13 +2,15 @@ from __future__ import annotations from typing import TYPE_CHECKING +import pytest from fixtures.utils import wait_until if TYPE_CHECKING: from fixtures.neon_fixtures import NeonEnvBuilder -def test_basebackup_cache(neon_env_builder: NeonEnvBuilder): +@pytest.mark.parametrize("grpc", [True, False]) +def test_basebackup_cache(neon_env_builder: NeonEnvBuilder, grpc: bool): """ Simple test for basebackup cache. 1. Check that we always hit the cache after compute restart. 
@@ -22,7 +24,7 @@ def test_basebackup_cache(neon_env_builder: NeonEnvBuilder): """ env = neon_env_builder.init_start() - ep = env.endpoints.create("main") + ep = env.endpoints.create("main", grpc=grpc) ps = env.pageserver ps_http = ps.http_client() diff --git a/test_runner/regress/test_change_pageserver.py b/test_runner/regress/test_change_pageserver.py index b004db310c..af736af825 100644 --- a/test_runner/regress/test_change_pageserver.py +++ b/test_runner/regress/test_change_pageserver.py @@ -3,14 +3,35 @@ from __future__ import annotations import asyncio from typing import TYPE_CHECKING +import pytest from fixtures.log_helper import log +from fixtures.neon_fixtures import NeonEnvBuilder from fixtures.remote_storage import RemoteStorageKind if TYPE_CHECKING: - from fixtures.neon_fixtures import NeonEnvBuilder + from fixtures.neon_fixtures import Endpoint, NeonEnvBuilder -def test_change_pageserver(neon_env_builder: NeonEnvBuilder): +def reconfigure_endpoint(endpoint: Endpoint, pageserver_id: int, use_explicit_reconfigure: bool): + # It's important that we always update config.json before issuing any reconfigure requests + # to make sure that PG-initiated config refresh doesn't mess things up by reverting to the old config. + endpoint.update_pageservers_in_config(pageserver_id=pageserver_id) + + # PG will automatically refresh its configuration if it detects connectivity issues with pageservers. + # We also allow the test to explicitly request a reconfigure so that the test can be sure that the + # endpoint is running with the latest configuration. + # + # Note that explicit reconfiguration is not required for the system to function or for this test to pass. + # It is kept for reference as this is how this test used to work before the capability of initiating + # configuration refreshes was added to compute nodes. + if use_explicit_reconfigure: + endpoint.reconfigure(pageserver_id=pageserver_id) + + +@pytest.mark.parametrize("use_explicit_reconfigure_for_failover", [False, True]) +def test_change_pageserver( + neon_env_builder: NeonEnvBuilder, use_explicit_reconfigure_for_failover: bool +): """ A relatively low level test of reconfiguring a compute's pageserver at runtime. Usually this is all done via the storage controller, but this test will disable the storage controller's compute @@ -72,7 +93,10 @@ def test_change_pageserver(neon_env_builder: NeonEnvBuilder): execute("SELECT count(*) FROM foo") assert fetchone() == (100000,) - endpoint.reconfigure(pageserver_id=alt_pageserver_id) + # Reconfigure the endpoint to use the alt pageserver. We issue an explicit reconfigure request here + # regardless of test mode as this is testing the externally driven reconfiguration scenario, not the + # compute-initiated reconfiguration scenario upon detecting failures. 
+ reconfigure_endpoint(endpoint, pageserver_id=alt_pageserver_id, use_explicit_reconfigure=True) # Verify that the neon.pageserver_connstring GUC is set to the correct thing execute("SELECT setting FROM pg_settings WHERE name='neon.pageserver_connstring'") @@ -100,6 +124,12 @@ def test_change_pageserver(neon_env_builder: NeonEnvBuilder): env.storage_controller.node_configure(env.pageservers[1].id, {"availability": "Offline"}) env.storage_controller.reconcile_until_idle() + reconfigure_endpoint( + endpoint, + pageserver_id=env.pageservers[0].id, + use_explicit_reconfigure=use_explicit_reconfigure_for_failover, + ) + endpoint.reconfigure(pageserver_id=env.pageservers[0].id) execute("SELECT count(*) FROM foo") @@ -116,7 +146,11 @@ def test_change_pageserver(neon_env_builder: NeonEnvBuilder): await asyncio.sleep( 1 ) # Sleep for 1 second just to make sure we actually started our count(*) query - endpoint.reconfigure(pageserver_id=env.pageservers[1].id) + reconfigure_endpoint( + endpoint, + pageserver_id=env.pageservers[1].id, + use_explicit_reconfigure=use_explicit_reconfigure_for_failover, + ) def execute_count(): execute("SELECT count(*) FROM FOO") diff --git a/test_runner/regress/test_compaction.py b/test_runner/regress/test_compaction.py index 76485c8321..94c18ac548 100644 --- a/test_runner/regress/test_compaction.py +++ b/test_runner/regress/test_compaction.py @@ -58,7 +58,7 @@ PREEMPT_GC_COMPACTION_TENANT_CONF = { "compaction_upper_limit": 6, "lsn_lease_length": "0s", # Enable gc-compaction - "gc_compaction_enabled": "true", + "gc_compaction_enabled": True, "gc_compaction_initial_threshold_kb": 1024, # At a small threshold "gc_compaction_ratio_percent": 1, # No PiTR interval and small GC horizon @@ -540,7 +540,7 @@ def test_pageserver_gc_compaction_trigger(neon_env_builder: NeonEnvBuilder): "pitr_interval": "0s", "gc_horizon": f"{1024 * 16}", "lsn_lease_length": "0s", - "gc_compaction_enabled": "true", + "gc_compaction_enabled": True, "gc_compaction_initial_threshold_kb": "16", "gc_compaction_ratio_percent": "50", # Do not generate image layers with create_image_layers @@ -863,6 +863,89 @@ def test_pageserver_compaction_circuit_breaker(neon_env_builder: NeonEnvBuilder) assert not env.pageserver.log_contains(".*Circuit breaker failure ended.*") +def test_ps_corruption_detection_feedback(neon_env_builder: NeonEnvBuilder): + """ + Test that when the pageserver detects corruption during image layer creation, + it sends corruption feedback to the safekeeper which gets recorded in its + safekeeper_ps_corruption_detected metric. + """ + # Configure tenant with aggressive compaction settings to easily trigger compaction + TENANT_CONF = { + # Small checkpoint distance to create many layers + "checkpoint_distance": 1024 * 128, + # Compact small layers + "compaction_target_size": 1024 * 128, + # Create image layers eagerly + "image_creation_threshold": 1, + "image_layer_creation_check_threshold": 0, + # Force frequent compaction + "compaction_period": "1s", + } + + env = neon_env_builder.init_start(initial_tenant_conf=TENANT_CONF) + # We are simulating compaction failures so we should allow these error messages. 
+ env.pageserver.allowed_errors.append(".*Compaction failed.*") + tenant_id = env.initial_tenant + timeline_id = env.initial_timeline + + pageserver_http = env.pageserver.http_client() + workload = Workload( + env, tenant_id, timeline_id, endpoint_opts={"config_lines": ["neon.lakebase_mode=true"]} + ) + workload.init() + + # Enable the failpoint that will cause image layer creation to fail due to a (simulated) detected + # corruption. + pageserver_http.configure_failpoints(("create-image-layer-fail-simulated-corruption", "return")) + + # Write some data to trigger compaction and image layer creation + log.info("Writing data to trigger compaction...") + workload.write_rows(1024 * 64, upload=False) + workload.write_rows(1024 * 64, upload=False) + + # Returns True if the corruption signal from PS is propagated to the SK according to the "safekeeper_ps_corruption_detected" metric. + # Raises an exception otherwise. + def check_corruption_signal_propagated_to_sk(): + # Get metrics from all safekeepers + for sk in env.safekeepers: + sk_metrics = sk.http_client().get_metrics() + # Look for our corruption detected metric with the right tenant and timeline + corruption_metrics = sk_metrics.query_all("safekeeper_ps_corruption_detected") + + for metric in corruption_metrics: + # Check if there's a metric for our tenant and timeline that has value 1 + if ( + metric.labels.get("tenant_id") == str(tenant_id) + and metric.labels.get("timeline_id") == str(timeline_id) + and metric.value == 1 + ): + log.info(f"Corruption detected by safekeeper {sk.id}: {metric}") + return True + raise Exception("Corruption detection feedback not found in any safekeeper metrics") + + # Returns True if the corruption signal from PS is propagated to the PG according to the "ps_corruption_detected" metric + # in "neon_perf_counters". + # Raises an exception otherwise. 
+ def check_corruption_signal_propagated_to_pg(): + endpoint = workload.endpoint() + results = endpoint.safe_psql("CREATE EXTENSION IF NOT EXISTS neon") + results = endpoint.safe_psql( + "SELECT value FROM neon_perf_counters WHERE metric = 'ps_corruption_detected'" + ) + log.info("Query corruption detection metric, results: %s", results) + if results[0][0] == 1: + log.info("Corruption detection signal is raised on Postgres") + return True + raise Exception("Corruption detection signal is not raise on Postgres") + + # Confirm that the corruption signal propagates to both the safekeeper and Postgres + wait_until(check_corruption_signal_propagated_to_sk, timeout=10, interval=0.1) + wait_until(check_corruption_signal_propagated_to_pg, timeout=10, interval=0.1) + + # Cleanup the failpoint + pageserver_http.configure_failpoints(("create-image-layer-fail-simulated-corruption", "off")) + + @pytest.mark.parametrize("enabled", [True, False]) def test_image_layer_compression(neon_env_builder: NeonEnvBuilder, enabled: bool): tenant_conf = { diff --git a/test_runner/regress/test_compatibility.py b/test_runner/regress/test_compatibility.py index 734887c5b3..635a040800 100644 --- a/test_runner/regress/test_compatibility.py +++ b/test_runner/regress/test_compatibility.py @@ -538,6 +538,7 @@ def test_historic_storage_formats( neon_env_builder.enable_pageserver_remote_storage(s3_storage()) neon_env_builder.pg_version = dataset.pg_version env = neon_env_builder.init_configs() + env.start() assert isinstance(env.pageserver_remote_storage, S3Storage) @@ -576,6 +577,17 @@ def test_historic_storage_formats( # All our artifacts should contain at least one timeline assert len(timelines) > 0 + if dataset.name == "2025-04-08-tenant-manifest-v1": + # This dataset was created at a time where we decided to migrate to v2 reldir by simply disabling writes to v1 + # and starting writing to v2. This was too risky and we have reworked the migration plan. Therefore, we should + # opt in full relv2 mode for this dataset. + for timeline in timelines: + env.pageserver.http_client().timeline_patch_index_part( + dataset.tenant_id, + timeline["timeline_id"], + {"force_index_update": True, "rel_size_migration": "migrated"}, + ) + # Import tenant does not create the timeline on safekeepers, # because it is a debug handler and the timeline may have already been # created on some set of safekeepers. diff --git a/test_runner/regress/test_compute_termination.py b/test_runner/regress/test_compute_termination.py new file mode 100644 index 0000000000..2d62ccf20f --- /dev/null +++ b/test_runner/regress/test_compute_termination.py @@ -0,0 +1,369 @@ +from __future__ import annotations + +import json +import os +import shutil +import subprocess +import threading +import time +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import TYPE_CHECKING + +import requests +from fixtures.log_helper import log +from typing_extensions import override + +if TYPE_CHECKING: + from typing import Any + + from fixtures.common_types import TenantId, TimelineId + from fixtures.neon_fixtures import NeonEnv + from fixtures.port_distributor import PortDistributor + + +def launch_compute_ctl( + env: NeonEnv, + endpoint_name: str, + external_http_port: int, + internal_http_port: int, + pg_port: int, + control_plane_port: int, +) -> subprocess.Popen[str]: + """ + Helper function to launch compute_ctl process with common configuration. + Returns the Popen process object. 
+ """ + # Create endpoint directory structure following the standard pattern + endpoint_path = env.repo_dir / "endpoints" / endpoint_name + + # Clean up any existing endpoint directory to avoid conflicts + if endpoint_path.exists(): + shutil.rmtree(endpoint_path) + + endpoint_path.mkdir(mode=0o755, parents=True, exist_ok=True) + + # pgdata path - compute_ctl will create this directory during basebackup + pgdata_path = endpoint_path / "pgdata" + + # Create log file in endpoint directory + log_file = endpoint_path / "compute.log" + log_handle = open(log_file, "w") + + # Start compute_ctl pointing to our control plane + compute_ctl_path = env.neon_binpath / "compute_ctl" + connstr = f"postgresql://cloud_admin@localhost:{pg_port}/postgres" + + # Find postgres binary path + pg_bin_path = env.pg_distrib_dir / env.pg_version.v_prefixed / "bin" / "postgres" + pg_lib_path = env.pg_distrib_dir / env.pg_version.v_prefixed / "lib" + + env_vars = { + "INSTANCE_ID": "lakebase-instance-id", + "LD_LIBRARY_PATH": str(pg_lib_path), # Linux, etc. + "DYLD_LIBRARY_PATH": str(pg_lib_path), # macOS + } + + cmd = [ + str(compute_ctl_path), + "--external-http-port", + str(external_http_port), + "--internal-http-port", + str(internal_http_port), + "--pgdata", + str(pgdata_path), + "--connstr", + connstr, + "--pgbin", + str(pg_bin_path), + "--compute-id", + endpoint_name, # Use endpoint_name as compute-id + "--control-plane-uri", + f"http://127.0.0.1:{control_plane_port}", + "--lakebase-mode", + "true", + ] + + print(f"Launching compute_ctl with command: {cmd}") + + # Start compute_ctl + process = subprocess.Popen( + cmd, + env=env_vars, + stdout=log_handle, + stderr=subprocess.STDOUT, # Combine stderr with stdout + text=True, + ) + + return process + + +def wait_for_compute_status( + compute_process: subprocess.Popen[str], + http_port: int, + expected_status: str, + timeout_seconds: int = 10, +) -> None: + """ + Wait for compute_ctl to reach the expected status. + Raises an exception if timeout is reached or process exits unexpectedly. + """ + start_time = time.time() + while time.time() - start_time < timeout_seconds: + try: + # Try to connect to the HTTP endpoint + response = requests.get(f"http://localhost:{http_port}/status", timeout=0.5) + if response.status_code == 200: + status_json = response.json() + # Check if it's in expected status + if status_json.get("status") == expected_status: + return + except (requests.ConnectionError, requests.Timeout): + pass + + # Check if process has exited + if compute_process.poll() is not None: + raise Exception( + f"compute_ctl exited unexpectedly with code {compute_process.returncode}." + ) + + time.sleep(0.5) + + # Timeout reached + compute_process.terminate() + raise Exception( + f"compute_ctl failed to reach {expected_status} status within {timeout_seconds} seconds." 
+ ) + + +class EmptySpecHandler(BaseHTTPRequestHandler): + """HTTP handler that returns an Empty compute spec response""" + + def do_GET(self): + if self.path.startswith("/compute/api/v2/computes/") and self.path.endswith("/spec"): + # Return empty status which will put compute in Empty state + response: dict[str, Any] = { + "status": "empty", + "spec": None, + "compute_ctl_config": {"jwks": {"keys": []}}, + } + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps(response).encode()) + else: + self.send_error(404) + + @override + def log_message(self, format: str, *args: Any): + # Suppress request logging + pass + + +def test_compute_terminate_empty(neon_simple_env: NeonEnv, port_distributor: PortDistributor): + """ + Test that terminating a compute in Empty status works correctly. + + This tests the bug fix where terminating an Empty compute would hang + waiting for a non-existent postgres process to terminate. + """ + env = neon_simple_env + + # Get ports for our test + control_plane_port = port_distributor.get_port() + external_http_port = port_distributor.get_port() + internal_http_port = port_distributor.get_port() + pg_port = port_distributor.get_port() + + # Start a simple HTTP server that will serve the Empty spec + server = HTTPServer(("127.0.0.1", control_plane_port), EmptySpecHandler) + server_thread = threading.Thread(target=server.serve_forever) + server_thread.daemon = True + server_thread.start() + + compute_process = None + try: + # Start compute_ctl with ephemeral tenant ID + compute_process = launch_compute_ctl( + env, + "test-empty-compute", + external_http_port, + internal_http_port, + pg_port, + control_plane_port, + ) + + # Wait for compute_ctl to start and report "empty" status + wait_for_compute_status(compute_process, external_http_port, "empty") + + # Now send terminate request + response = requests.post(f"http://localhost:{external_http_port}/terminate") + + # Verify that the termination request sends back a 200 OK response and is not abruptly terminated. 
+ assert response.status_code == 200, ( + f"Expected 200 OK, got {response.status_code}: {response.text}" + ) + + # Wait for compute_ctl to exit + exit_code = compute_process.wait(timeout=10) + assert exit_code == 0, f"compute_ctl exited with non-zero code: {exit_code}" + + finally: + # Clean up + server.shutdown() + if compute_process and compute_process.poll() is None: + compute_process.terminate() + compute_process.wait() + + +class SwitchableConfigHandler(BaseHTTPRequestHandler): + """HTTP handler that can switch between normal compute configs and compute configs without specs""" + + return_empty_spec: bool = False + tenant_id: TenantId | None = None + timeline_id: TimelineId | None = None + pageserver_port: int | None = None + safekeeper_connstrs: list[str] | None = None + + def do_GET(self): + if self.path.startswith("/compute/api/v2/computes/") and self.path.endswith("/spec"): + if self.return_empty_spec: + # Return empty status + response: dict[str, object | None] = { + "status": "empty", + "spec": None, + "compute_ctl_config": { + "jwks": {"keys": []}, + }, + } + else: + # Return normal attached spec + response = { + "status": "attached", + "spec": { + "format_version": 1.0, + "cluster": { + "roles": [], + "databases": [], + "postgresql_conf": "shared_preload_libraries='neon'", + }, + "tenant_id": str(self.tenant_id) if self.tenant_id else "", + "timeline_id": str(self.timeline_id) if self.timeline_id else "", + "pageserver_connstring": f"postgres://no_user@localhost:{self.pageserver_port}" + if self.pageserver_port + else "", + "safekeeper_connstrings": self.safekeeper_connstrs or [], + "mode": "Primary", + "skip_pg_catalog_updates": True, + "reconfigure_concurrency": 1, + "suspend_timeout_seconds": -1, + }, + "compute_ctl_config": { + "jwks": {"keys": []}, + }, + } + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps(response).encode()) + else: + self.send_error(404) + + @override + def log_message(self, format: str, *args: Any): + # Suppress request logging + pass + + +def test_compute_empty_spec_during_refresh_configuration( + neon_simple_env: NeonEnv, port_distributor: PortDistributor +): + """ + Test that compute exits when it receives an empty spec during refresh configuration state. + + This test: + 1. Start compute with a normal spec + 2. Change the spec handler to return empty spec + 3. Trigger some condition to force compute to refresh configuration + 4. 
Verify that compute_ctl exits + """ + env = neon_simple_env + + # Get ports for our test + control_plane_port = port_distributor.get_port() + external_http_port = port_distributor.get_port() + internal_http_port = port_distributor.get_port() + pg_port = port_distributor.get_port() + + # Set up handler class variables + SwitchableConfigHandler.tenant_id = env.initial_tenant + SwitchableConfigHandler.timeline_id = env.initial_timeline + SwitchableConfigHandler.pageserver_port = env.pageserver.service_port.pg + # Convert comma-separated string to list + safekeeper_connstrs = env.get_safekeeper_connstrs() + if safekeeper_connstrs: + SwitchableConfigHandler.safekeeper_connstrs = safekeeper_connstrs.split(",") + else: + SwitchableConfigHandler.safekeeper_connstrs = [] + SwitchableConfigHandler.return_empty_spec = False # Start with normal spec + + # Start HTTP server with switchable spec handler + server = HTTPServer(("127.0.0.1", control_plane_port), SwitchableConfigHandler) + server_thread = threading.Thread(target=server.serve_forever) + server_thread.daemon = True + server_thread.start() + + compute_process = None + try: + # Start compute_ctl with tenant and timeline IDs + # Use a unique endpoint name to avoid conflicts + endpoint_name = f"test-refresh-compute-{os.getpid()}" + compute_process = launch_compute_ctl( + env, + endpoint_name, + external_http_port, + internal_http_port, + pg_port, + control_plane_port, + ) + + # Wait for compute_ctl to start and report "running" status + wait_for_compute_status(compute_process, external_http_port, "running", timeout_seconds=30) + + log.info("Compute is running. Now returning empty spec and trigger configuration refresh.") + + # Switch spec fetch handler to return empty spec + SwitchableConfigHandler.return_empty_spec = True + + # Trigger a configuration refresh + try: + requests.post(f"http://localhost:{internal_http_port}/refresh_configuration") + except requests.RequestException as e: + log.info(f"Call to /refresh_configuration failed: {e}") + log.info( + "Ignoring the error, assuming that compute_ctl is already refreshing or has exited" + ) + + # Wait for compute_ctl to exit (it should exit when it gets an empty spec during refresh) + exit_start_time = time.time() + while time.time() - exit_start_time < 30: + if compute_process.poll() is not None: + # Process exited + break + time.sleep(0.5) + + # Verify that compute_ctl exited + exit_code = compute_process.poll() + if exit_code is None: + compute_process.terminate() + raise Exception("compute_ctl did not exit after receiving empty spec.") + + # The exit code might not be 0 in this case since it's an unexpected termination + # but we mainly care that it did exit + assert exit_code is not None, "compute_ctl should have exited" + + finally: + # Clean up + server.shutdown() + if compute_process and compute_process.poll() is None: + compute_process.terminate() + compute_process.wait() diff --git a/test_runner/regress/test_feature_flag.py b/test_runner/regress/test_feature_flag.py index c6c192b6f1..6c1e3484fa 100644 --- a/test_runner/regress/test_feature_flag.py +++ b/test_runner/regress/test_feature_flag.py @@ -50,11 +50,15 @@ def test_feature_flag(neon_env_builder: NeonEnvBuilder): )["result"] ) + env.endpoints.create_start("main") # trigger basebackup env.pageserver.http_client().force_refresh_feature_flag(env.initial_tenant) # Check if the properties exist result = env.pageserver.http_client().evaluate_feature_flag_multivariate( env.initial_tenant, "test-feature-flag" ) + assert 
"tenant_remote_size_mb" in result["properties"] + assert "tenant_db_count_max" in result["properties"] + assert "tenant_rel_count_max" in result["properties"] assert "tenant_id" in result["properties"] diff --git a/test_runner/regress/test_gin_redo.py b/test_runner/regress/test_gin_redo.py index 3ec2163203..71382990dc 100644 --- a/test_runner/regress/test_gin_redo.py +++ b/test_runner/regress/test_gin_redo.py @@ -16,7 +16,6 @@ def test_gin_redo(neon_simple_env: NeonEnv): secondary = env.endpoints.new_replica_start(origin=primary, endpoint_id="secondary") con = primary.connect() cur = con.cursor() - cur.execute("select pg_switch_wal()") cur.execute("create table gin_test_tbl(id integer, i int4[])") cur.execute("create index gin_test_idx on gin_test_tbl using gin (i)") cur.execute("insert into gin_test_tbl select g,array[3, 1, g] from generate_series(1, 10000) g") diff --git a/test_runner/regress/test_hadron_ps_connectivity_metrics.py b/test_runner/regress/test_hadron_ps_connectivity_metrics.py new file mode 100644 index 0000000000..ff1f37b634 --- /dev/null +++ b/test_runner/regress/test_hadron_ps_connectivity_metrics.py @@ -0,0 +1,137 @@ +import json +import shutil + +from fixtures.common_types import TenantShardId +from fixtures.log_helper import log +from fixtures.metrics import parse_metrics +from fixtures.neon_fixtures import Endpoint, NeonEnvBuilder, NeonPageserver +from requests.exceptions import ConnectionError + + +# Helper function to attempt reconfiguration of the compute to point to a new pageserver. Note that in these tests, +# we don't expect the reconfiguration attempts to go through, as we will be pointing the compute at a "wrong" pageserver. +def _attempt_reconfiguration(endpoint: Endpoint, new_pageserver_id: int, timeout_sec: float): + try: + endpoint.reconfigure(pageserver_id=new_pageserver_id, timeout_sec=timeout_sec) + except Exception as e: + log.info(f"reconfiguration failed with exception {e}") + pass + + +def read_misrouted_metric_value(pageserver: NeonPageserver) -> float: + return ( + pageserver.http_client() + .get_metrics() + .query_one("pageserver_misrouted_pagestream_requests_total") + .value + ) + + +def read_request_error_metric_value(endpoint: Endpoint) -> float: + return ( + parse_metrics(endpoint.http_client().metrics()) + .query_one("pg_cctl_pagestream_request_errors_total") + .value + ) + + +def test_misrouted_to_secondary( + neon_env_builder: NeonEnvBuilder, +): + """ + Tests that the following metrics are incremented when compute tries to talk to a secondary pageserver: + - On pageserver receiving the request: pageserver_misrouted_pagestream_requests_total + - On compute: pg_cctl_pagestream_request_errors_total + """ + neon_env_builder.num_pageservers = 2 + env = neon_env_builder.init_configs() + env.broker.start() + env.storage_controller.start() + for ps in env.pageservers: + ps.start() + for sk in env.safekeepers: + sk.start() + + # Create a tenant that has one primary and one secondary. Due to primary/secondary placement constraints, + # the primary and secondary pageservers will be different. + tenant_id, _ = env.create_tenant(shard_count=1, placement_policy=json.dumps({"Attached": 1})) + endpoint = env.endpoints.create( + "main", tenant_id=tenant_id, config_lines=["neon.lakebase_mode = true"] + ) + endpoint.respec(skip_pg_catalog_updates=False) + endpoint.start() + + # Get the primary pageserver serving the zero shard of the tenant, and detach it from the primary pageserver. 
+ # This test operation configures tenant directly on the pageserver/does not go through the storage controller, + # so the compute does not get any notifications and will keep pointing at the detached pageserver. + tenant_zero_shard = TenantShardId(tenant_id, shard_number=0, shard_count=1) + + primary_ps = env.get_tenant_pageserver(tenant_zero_shard) + secondary_ps = ( + env.pageservers[1] if primary_ps.id == env.pageservers[0].id else env.pageservers[0] + ) + + # Now try to point the compute at the pageserver that is acting as secondary for the tenant. Test that the metrics + # on both compute_ctl and the pageserver register the misrouted requests following the reconfiguration attempt. + assert read_misrouted_metric_value(secondary_ps) == 0 + assert read_request_error_metric_value(endpoint) == 0 + _attempt_reconfiguration(endpoint, new_pageserver_id=secondary_ps.id, timeout_sec=2.0) + assert read_misrouted_metric_value(secondary_ps) > 0 + try: + assert read_request_error_metric_value(endpoint) > 0 + except ConnectionError: + # When configuring PG to use misconfigured pageserver, PG will cancel the query after certain number of failed + # reconfigure attempts. This will cause compute_ctl to exit. + log.info("Cannot connect to PG, ignoring") + pass + + +def test_misrouted_to_ps_not_hosting_tenant( + neon_env_builder: NeonEnvBuilder, +): + """ + Tests that the following metrics are incremented when compute tries to talk to a pageserver that does not host the tenant: + - On pageserver receiving the request: pageserver_misrouted_pagestream_requests_total + - On compute: pg_cctl_pagestream_request_errors_total + """ + neon_env_builder.num_pageservers = 2 + env = neon_env_builder.init_configs() + env.broker.start() + env.storage_controller.start(handle_ps_local_disk_loss=False) + for ps in env.pageservers: + ps.start() + for sk in env.safekeepers: + sk.start() + + tenant_id, _ = env.create_tenant(shard_count=1) + endpoint = env.endpoints.create( + "main", tenant_id=tenant_id, config_lines=["neon.lakebase_mode = true"] + ) + endpoint.respec(skip_pg_catalog_updates=False) + endpoint.start() + + tenant_ps_id = env.get_tenant_pageserver( + TenantShardId(tenant_id, shard_number=0, shard_count=1) + ).id + non_hosting_ps = ( + env.pageservers[1] if tenant_ps_id == env.pageservers[0].id else env.pageservers[0] + ) + + # Clear the disk of the non-hosting PS to make sure that it indeed doesn't have any information about the tenant. + non_hosting_ps.stop(immediate=True) + shutil.rmtree(non_hosting_ps.tenant_dir()) + non_hosting_ps.start() + + # Now try to point the compute to the non-hosting pageserver. Test that the metrics + # on both compute_ctl and the pageserver register the misrouted requests following the reconfiguration attempt. + assert read_misrouted_metric_value(non_hosting_ps) == 0 + assert read_request_error_metric_value(endpoint) == 0 + _attempt_reconfiguration(endpoint, new_pageserver_id=non_hosting_ps.id, timeout_sec=2.0) + assert read_misrouted_metric_value(non_hosting_ps) > 0 + try: + assert read_request_error_metric_value(endpoint) > 0 + except ConnectionError: + # When configuring PG to use misconfigured pageserver, PG will cancel the query after certain number of failed + # reconfigure attempts. This will cause compute_ctl to exit. 
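Both misrouted-request tests sample the pageserver counter exactly once, right after the timed-out reconfiguration attempt. A slightly more tolerant variant is sketched below; it is not part of the patch, it reuses the read_misrouted_metric_value helper defined above, and it assumes wait_until can be imported from fixtures.utils the way other regression tests do.

from fixtures.utils import wait_until  # assumed import path


def wait_for_misrouted_requests(pageserver: NeonPageserver, timeout: float = 10.0) -> float:
    # Poll pageserver_misrouted_pagestream_requests_total until it turns positive
    # instead of reading it a single time after the reconfiguration attempt.
    def check() -> float:
        value = read_misrouted_metric_value(pageserver)
        assert value > 0, f"misrouted counter still {value}"
        return value

    return wait_until(check, timeout=timeout, interval=0.5)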
+ log.info("Cannot connect to PG, ignoring") + pass diff --git a/test_runner/regress/test_hot_standby.py b/test_runner/regress/test_hot_standby.py index 1ff61ce8dc..a329a5e842 100644 --- a/test_runner/regress/test_hot_standby.py +++ b/test_runner/regress/test_hot_standby.py @@ -133,6 +133,9 @@ def test_hot_standby_gc(neon_env_builder: NeonEnvBuilder, pause_apply: bool): tenant_conf = { # set PITR interval to be small, so we can do GC "pitr_interval": "0 s", + # we want to control gc and checkpoint frequency precisely + "gc_period": "0s", + "compaction_period": "0s", } env = neon_env_builder.init_start(initial_tenant_conf=tenant_conf) timeline_id = env.initial_timeline @@ -186,6 +189,23 @@ def test_hot_standby_gc(neon_env_builder: NeonEnvBuilder, pause_apply: bool): client = pageserver.http_client() client.timeline_checkpoint(tenant_shard_id, timeline_id) client.timeline_compact(tenant_shard_id, timeline_id) + # Wait for standby horizon to get propagated. + # This shouldn't be necessary, but the current mechanism for + # standby_horizon propagation is imperfect. Detailed + # description in https://databricks.atlassian.net/browse/LKB-2499 + while True: + val = client.get_metric_value( + "pageserver_standby_horizon", + { + "tenant_id": str(tenant_shard_id.tenant_id), + "shard_id": str(tenant_shard_id.shard_index), + "timeline_id": str(timeline_id), + }, + ) + log.info("waiting for next standby_horizon push from safekeeper, {val=}") + if val != 0: + break + time.sleep(0.1) client.timeline_gc(tenant_shard_id, timeline_id, 0) # Re-execute the query. The GetPage requests that this diff --git a/test_runner/regress/test_lfc_prewarm.py b/test_runner/regress/test_lfc_prewarm.py index 0f0cf4cc6d..a96f18177c 100644 --- a/test_runner/regress/test_lfc_prewarm.py +++ b/test_runner/regress/test_lfc_prewarm.py @@ -1,6 +1,6 @@ import random -import threading from enum import StrEnum +from threading import Thread from time import sleep from typing import Any @@ -47,19 +47,23 @@ def offload_lfc(method: PrewarmMethod, client: EndpointHttpClient, cur: Cursor) # With autoprewarm, we need to be sure LFC was offloaded after all writes # finish, so we sleep. 
Otherwise we'll have less prewarmed pages than we want sleep(AUTOOFFLOAD_INTERVAL_SECS) - client.offload_lfc_wait() - return + offload_res = client.offload_lfc_wait() + log.info(offload_res) + return offload_res if method == PrewarmMethod.COMPUTE_CTL: status = client.prewarm_lfc_status() assert status["status"] == "not_prewarmed" assert "error" not in status - client.offload_lfc() + offload_res = client.offload_lfc() + log.info(offload_res) assert client.prewarm_lfc_status()["status"] == "not_prewarmed" + parsed = prom_parse(client) desired = {OFFLOAD_LABEL: 1, PREWARM_LABEL: 0, OFFLOAD_ERR_LABEL: 0, PREWARM_ERR_LABEL: 0} assert parsed == desired, f"{parsed=} != {desired=}" - return + + return offload_res raise AssertionError(f"{method} not in PrewarmMethod") @@ -68,21 +72,30 @@ def prewarm_endpoint( method: PrewarmMethod, client: EndpointHttpClient, cur: Cursor, lfc_state: str | None ): if method == PrewarmMethod.AUTOPREWARM: - client.prewarm_lfc_wait() + prewarm_res = client.prewarm_lfc_wait() + log.info(prewarm_res) elif method == PrewarmMethod.COMPUTE_CTL: - client.prewarm_lfc() + prewarm_res = client.prewarm_lfc() + log.info(prewarm_res) + return prewarm_res elif method == PrewarmMethod.POSTGRES: cur.execute("select neon.prewarm_local_cache(%s)", (lfc_state,)) -def check_prewarmed( +def check_prewarmed_contains( method: PrewarmMethod, client: EndpointHttpClient, desired_status: dict[str, str | int] ): if method == PrewarmMethod.AUTOPREWARM: - assert client.prewarm_lfc_status() == desired_status + prewarm_status = client.prewarm_lfc_status() + for k in desired_status: + assert desired_status[k] == prewarm_status[k] + assert prom_parse(client)[PREWARM_LABEL] == 1 elif method == PrewarmMethod.COMPUTE_CTL: - assert client.prewarm_lfc_status() == desired_status + prewarm_status = client.prewarm_lfc_status() + for k in desired_status: + assert desired_status[k] == prewarm_status[k] + desired = {OFFLOAD_LABEL: 0, PREWARM_LABEL: 1, PREWARM_ERR_LABEL: 0, OFFLOAD_ERR_LABEL: 0} assert prom_parse(client) == desired @@ -149,9 +162,6 @@ def test_lfc_prewarm(neon_simple_env: NeonEnv, method: PrewarmMethod): log.info(f"Used LFC size: {lfc_used_pages}") pg_cur.execute("select * from neon.get_prewarm_info()") total, prewarmed, skipped, _ = pg_cur.fetchall()[0] - log.info(f"Prewarm info: {total=} {prewarmed=} {skipped=}") - progress = (prewarmed + skipped) * 100 // total - log.info(f"Prewarm progress: {progress}%") assert lfc_used_pages > 10000 assert total > 0 assert prewarmed > 0 @@ -161,7 +171,72 @@ def test_lfc_prewarm(neon_simple_env: NeonEnv, method: PrewarmMethod): assert lfc_cur.fetchall()[0][0] == n_records * (n_records + 1) / 2 desired = {"status": "completed", "total": total, "prewarmed": prewarmed, "skipped": skipped} - check_prewarmed(method, client, desired) + check_prewarmed_contains(method, client, desired) + + +@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping") +def test_lfc_prewarm_cancel(neon_simple_env: NeonEnv): + """ + Test we can cancel LFC prewarm and prewarm successfully after + """ + env = neon_simple_env + n_records = 1000000 + cfg = [ + "autovacuum = off", + "shared_buffers=1MB", + "neon.max_file_cache_size=1GB", + "neon.file_cache_size_limit=1GB", + "neon.file_cache_prewarm_limit=1000", + ] + endpoint = env.endpoints.create_start(branch_name="main", config_lines=cfg) + + pg_conn = endpoint.connect() + pg_cur = pg_conn.cursor() + pg_cur.execute("create schema neon; create extension neon with schema neon") + pg_cur.execute("create database lfc") + + lfc_conn 
= endpoint.connect(dbname="lfc") + lfc_cur = lfc_conn.cursor() + log.info(f"Inserting {n_records} rows") + lfc_cur.execute("create table t(pk integer primary key, payload text default repeat('?', 128))") + lfc_cur.execute(f"insert into t (pk) values (generate_series(1,{n_records}))") + log.info(f"Inserted {n_records} rows") + + client = endpoint.http_client() + method = PrewarmMethod.COMPUTE_CTL + offload_lfc(method, client, pg_cur) + + endpoint.stop() + endpoint.start() + + thread = Thread(target=lambda: prewarm_endpoint(method, client, pg_cur, None)) + thread.start() + # wait 2 seconds to ensure we cancel prewarm SQL query + sleep(2) + client.cancel_prewarm_lfc() + thread.join() + assert client.prewarm_lfc_status()["status"] == "cancelled" + + prewarm_endpoint(method, client, pg_cur, None) + assert client.prewarm_lfc_status()["status"] == "completed" + + +@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping") +def test_lfc_prewarm_empty(neon_simple_env: NeonEnv): + """ + Test there are no errors when trying to offload or prewarm endpoint without cache using compute_ctl. + Endpoint without cache is simulated by turning off LFC manually, but in cloud/ setup this is + also reproduced on fresh endpoints + """ + env = neon_simple_env + ep = env.endpoints.create_start("main", config_lines=["neon.file_cache_size_limit=0"]) + client = ep.http_client() + conn = ep.connect() + cur = conn.cursor() + cur.execute("create schema neon; create extension neon with schema neon") + method = PrewarmMethod.COMPUTE_CTL + assert offload_lfc(method, client, cur)["status"] == "skipped" + assert prewarm_endpoint(method, client, cur, None)["status"] == "skipped" # autoprewarm isn't needed as we prewarm manually @@ -232,11 +307,11 @@ def test_lfc_prewarm_under_workload(neon_simple_env: NeonEnv, method: PrewarmMet workload_threads = [] for _ in range(n_threads): - t = threading.Thread(target=workload) + t = Thread(target=workload) workload_threads.append(t) t.start() - prewarm_thread = threading.Thread(target=prewarm) + prewarm_thread = Thread(target=prewarm) prewarm_thread.start() def prewarmed(): diff --git a/test_runner/regress/test_ondemand_slru_download.py b/test_runner/regress/test_ondemand_slru_download.py index f0f12290cc..607a2921a9 100644 --- a/test_runner/regress/test_ondemand_slru_download.py +++ b/test_runner/regress/test_ondemand_slru_download.py @@ -16,7 +16,7 @@ def test_ondemand_download_pg_xact(neon_env_builder: NeonEnvBuilder, shard_count neon_env_builder.num_pageservers = shard_count tenant_conf = { - "lazy_slru_download": "true", + "lazy_slru_download": True, # set PITR interval to be small, so we can do GC "pitr_interval": "0 s", } @@ -82,7 +82,7 @@ def test_ondemand_download_replica(neon_env_builder: NeonEnvBuilder, shard_count neon_env_builder.num_pageservers = shard_count tenant_conf = { - "lazy_slru_download": "true", + "lazy_slru_download": True, } env = neon_env_builder.init_start( initial_tenant_conf=tenant_conf, initial_tenant_shard_count=shard_count @@ -141,7 +141,7 @@ def test_ondemand_download_after_wal_switch(neon_env_builder: NeonEnvBuilder): """ tenant_conf = { - "lazy_slru_download": "true", + "lazy_slru_download": True, } env = neon_env_builder.init_start(initial_tenant_conf=tenant_conf) diff --git a/test_runner/regress/test_pg_regress.py b/test_runner/regress/test_pg_regress.py index dd9c5437ad..cc7f736239 100644 --- a/test_runner/regress/test_pg_regress.py +++ b/test_runner/regress/test_pg_regress.py @@ -395,23 +395,6 @@ def test_max_wal_rate(neon_simple_env: 
NeonEnv): tuples = endpoint.safe_psql("SELECT backpressure_throttling_time();") assert tuples[0][0] == 0, "Backpressure throttling detected" - # 0 MB/s max_wal_rate. WAL proposer can still push some WALs but will be super slow. - endpoint.safe_psql_many( - [ - "ALTER SYSTEM SET databricks.max_wal_mb_per_second = 0;", - "SELECT pg_reload_conf();", - ] - ) - - # Write ~10 KB data should hit backpressure. - with endpoint.cursor(dbname=DBNAME) as cur: - cur.execute("SET databricks.max_wal_mb_per_second = 0;") - for _ in range(0, 10): - cur.execute("INSERT INTO usertable SELECT random(), repeat('a', 1000);") - - tuples = endpoint.safe_psql("SELECT backpressure_throttling_time();") - assert tuples[0][0] > 0, "No backpressure throttling detected" - # 1 MB/s max_wal_rate. endpoint.safe_psql_many( [ @@ -457,21 +440,6 @@ def test_tx_abort_with_many_relations( ], ) - if reldir_type == "v1": - assert ( - env.pageserver.http_client().timeline_detail(env.initial_tenant, env.initial_timeline)[ - "rel_size_migration" - ] - == "legacy" - ) - else: - assert ( - env.pageserver.http_client().timeline_detail(env.initial_tenant, env.initial_timeline)[ - "rel_size_migration" - ] - != "legacy" - ) - # How many relations: this number is tuned to be long enough to take tens of seconds # if the rollback code path is buggy, tripping the test's timeout. n = 5000 @@ -556,3 +524,19 @@ def test_tx_abort_with_many_relations( except: exec.shutdown(wait=False, cancel_futures=True) raise + + # Do the check after everything is done, because the reldirv2 transition won't happen until create table. + if reldir_type == "v1": + assert ( + env.pageserver.http_client().timeline_detail(env.initial_tenant, env.initial_timeline)[ + "rel_size_migration" + ] + == "legacy" + ) + else: + assert ( + env.pageserver.http_client().timeline_detail(env.initial_tenant, env.initial_timeline)[ + "rel_size_migration" + ] + != "legacy" + ) diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index 9860658ba5..dadaf8a1cf 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -17,9 +17,6 @@ if TYPE_CHECKING: from typing import Any -GET_CONNECTION_PID_QUERY = "SELECT pid FROM pg_stat_activity WHERE state = 'active'" - - @pytest.mark.asyncio async def test_http_pool_begin_1(static_proxy: NeonProxy): static_proxy.safe_psql("create user http_auth with password 'http' superuser") @@ -479,7 +476,7 @@ def test_sql_over_http_pool(static_proxy: NeonProxy): def get_pid(status: int, pw: str, user="http_auth") -> Any: return static_proxy.http_query( - GET_CONNECTION_PID_QUERY, + "SELECT pg_backend_pid() as pid", [], user=user, password=pw, @@ -513,6 +510,35 @@ def test_sql_over_http_pool(static_proxy: NeonProxy): assert "password authentication failed for user" in res["message"] +def test_sql_over_http_pool_settings(static_proxy: NeonProxy): + static_proxy.safe_psql("create user http_auth with password 'http' superuser") + + def multiquery(*queries) -> Any: + results = static_proxy.http_multiquery( + *queries, + user="http_auth", + password="http", + expected_code=200, + ) + + return [result["rows"] for result in results["results"]] + + [[intervalstyle]] = static_proxy.safe_psql("SHOW IntervalStyle") + assert intervalstyle == "postgres", "'postgres' is the default IntervalStyle in postgres" + + result = multiquery("select '0 seconds'::interval as interval") + assert result[0][0]["interval"] == "00:00:00", "interval is expected in postgres format" + + result = multiquery( + "SET IntervalStyle = 
'iso_8601'", + "select '0 seconds'::interval as interval", + ) + assert result[1][0]["interval"] == "PT0S", "interval is expected in ISO-8601 format" + + result = multiquery("select '0 seconds'::interval as interval") + assert result[0][0]["interval"] == "00:00:00", "interval is expected in postgres format" + + def test_sql_over_http_urlencoding(static_proxy: NeonProxy): static_proxy.safe_psql("create user \"http+auth$$\" with password '%+$^&*@!' superuser") @@ -544,23 +570,37 @@ def test_http_pool_begin(static_proxy: NeonProxy): query(200, "SELECT 1;") # Query that should succeed regardless of the transaction -def test_sql_over_http_pool_idle(static_proxy: NeonProxy): +def test_sql_over_http_pool_tx_reuse(static_proxy: NeonProxy): static_proxy.safe_psql("create user http_auth2 with password 'http' superuser") - def query(status: int, query: str) -> Any: + def query(status: int, query: str, *args) -> Any: return static_proxy.http_query( query, - [], + args, user="http_auth2", password="http", expected_code=status, ) - pid1 = query(200, GET_CONNECTION_PID_QUERY)["rows"][0]["pid"] + def query_pid_txid() -> Any: + result = query( + 200, + "SELECT pg_backend_pid() as pid, pg_current_xact_id() as txid", + ) + + return result["rows"][0] + + res0 = query_pid_txid() + time.sleep(0.02) query(200, "BEGIN") - pid2 = query(200, GET_CONNECTION_PID_QUERY)["rows"][0]["pid"] - assert pid1 != pid2 + + res1 = query_pid_txid() + res2 = query_pid_txid() + + assert res0["pid"] == res1["pid"], "connection should be reused" + assert res0["pid"] == res2["pid"], "connection should be reused" + assert res1["txid"] != res2["txid"], "txid should be different" @pytest.mark.timeout(60) diff --git a/test_runner/regress/test_readonly_node.py b/test_runner/regress/test_readonly_node.py index 5612236250..e151b0ba13 100644 --- a/test_runner/regress/test_readonly_node.py +++ b/test_runner/regress/test_readonly_node.py @@ -129,7 +129,10 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder): Test static endpoint is protected from GC by acquiring and renewing lsn leases. """ - LSN_LEASE_LENGTH = 8 + LSN_LEASE_LENGTH = ( + 14 # This value needs to be large enough for compute_ctl to send two lease requests. + ) + neon_env_builder.num_pageservers = 2 # GC is manual triggered. env = neon_env_builder.init_start( @@ -230,6 +233,15 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder): log.info(f"`SELECT` query succeed after GC, {ctx=}") return offset + # It's not reliable to let the compute renew the lease in this test case as we have a very tight + # lease timeout. Therefore, the test case itself will renew the lease. + # + # This is a workaround to make the test case more deterministic. + def renew_lease(env: NeonEnv, lease_lsn: Lsn): + env.storage_controller.pageserver_api().timeline_lsn_lease( + env.initial_tenant, env.initial_timeline, lease_lsn + ) + # Insert some records on main branch with env.endpoints.create_start("main", config_lines=["shared_buffers=1MB"]) as ep_main: with ep_main.cursor() as cur: @@ -242,6 +254,9 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder): XLOG_BLCKSZ = 8192 lsn = Lsn((int(lsn) // XLOG_BLCKSZ) * XLOG_BLCKSZ) + # We need to mock the way cplane works: it gets a lease for a branch before starting the compute. 
+ renew_lease(env, lsn) + with env.endpoints.create_start( branch_name="main", endpoint_id="static", @@ -251,9 +266,6 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder): cur.execute("SELECT count(*) FROM t0") assert cur.fetchone() == (ROW_COUNT,) - # Wait for static compute to renew lease at least once. - time.sleep(LSN_LEASE_LENGTH / 2) - generate_updates_on_main(env, ep_main, 3, end=100) offset = trigger_gc_and_select( @@ -263,10 +275,10 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder): # Trigger Pageserver restarts for ps in env.pageservers: ps.stop() - # Static compute should have at least one lease request failure due to connection. - time.sleep(LSN_LEASE_LENGTH / 2) ps.start() + renew_lease(env, lsn) + trigger_gc_and_select( env, ep_static, @@ -282,6 +294,9 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder): ) env.storage_controller.reconcile_until_idle() + # Wait for static compute to renew lease on the new pageserver. + time.sleep(LSN_LEASE_LENGTH + 3) + trigger_gc_and_select( env, ep_static, @@ -292,7 +307,6 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder): # Do some update so we can increment gc_cutoff generate_updates_on_main(env, ep_main, i, end=100) - # Wait for the existing lease to expire. time.sleep(LSN_LEASE_LENGTH + 1) # Now trigger GC again, layers should be removed. diff --git a/test_runner/regress/test_relations.py b/test_runner/regress/test_relations.py index b2ddcb1c2e..6263ced9df 100644 --- a/test_runner/regress/test_relations.py +++ b/test_runner/regress/test_relations.py @@ -7,6 +7,8 @@ if TYPE_CHECKING: NeonEnvBuilder, ) +from fixtures.neon_fixtures import wait_for_last_flush_lsn + def test_pageserver_reldir_v2( neon_env_builder: NeonEnvBuilder, @@ -65,6 +67,8 @@ def test_pageserver_reldir_v2( endpoint.safe_psql("CREATE TABLE foo4 (id INTEGER PRIMARY KEY, val text)") # Delete a relation in v1 endpoint.safe_psql("DROP TABLE foo1") + # wait pageserver to apply the LSN + wait_for_last_flush_lsn(env, endpoint, env.initial_tenant, env.initial_timeline) # Check if both relations are still accessible endpoint.safe_psql("SELECT * FROM foo2") @@ -76,12 +80,16 @@ def test_pageserver_reldir_v2( # This will acquire a basebackup, which lists all relations. 
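The waits added in this file exist because the pageserver ingests WAL asynchronously: without them, the basebackup taken on the next endpoint start may predate the DROP. Purely as an illustration of the ordering the patch enforces (same calls as above, nothing new):

endpoint.safe_psql("DROP TABLE foo1")   # WAL is generated on the compute
wait_for_last_flush_lsn(env, endpoint, env.initial_tenant, env.initial_timeline)  # pageserver catches up
endpoint.stop()
endpoint.start()                        # the basebackup now reflects the drop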
endpoint.start() - # Check if both relations are still accessible + # Check if both relations are still accessible after restart endpoint.safe_psql("DROP TABLE IF EXISTS foo1") endpoint.safe_psql("SELECT * FROM foo2") endpoint.safe_psql("SELECT * FROM foo3") endpoint.safe_psql("SELECT * FROM foo4") endpoint.safe_psql("DROP TABLE foo3") + # wait pageserver to apply the LSN + wait_for_last_flush_lsn(env, endpoint, env.initial_tenant, env.initial_timeline) + + # Restart the endpoint again endpoint.stop() endpoint.start() @@ -99,6 +107,9 @@ def test_pageserver_reldir_v2( }, ) + endpoint.stop() + endpoint.start() + # Check if the relation is still accessible endpoint.safe_psql("SELECT * FROM foo2") endpoint.safe_psql("SELECT * FROM foo4") @@ -111,3 +122,10 @@ def test_pageserver_reldir_v2( ] == "migrating" ) + + assert ( + env.pageserver.http_client().timeline_detail(env.initial_tenant, env.initial_timeline)[ + "rel_size_migrated_at" + ] + is not None + ) diff --git a/test_runner/regress/test_replica_promotes.py b/test_runner/regress/test_replica_promotes.py index 8d39ac123a..9415d6886c 100644 --- a/test_runner/regress/test_replica_promotes.py +++ b/test_runner/regress/test_replica_promotes.py @@ -90,6 +90,7 @@ def test_replica_promote(neon_simple_env: NeonEnv, method: PromoteMethod): secondary_cur.execute("select count(*) from t") assert secondary_cur.fetchone() == (100,) + primary_spec = primary.get_compute_spec() primary_endpoint_id = primary.endpoint_id stop_and_check_lsn(primary, expected_primary_lsn) @@ -99,10 +100,9 @@ def test_replica_promote(neon_simple_env: NeonEnv, method: PromoteMethod): if method == PromoteMethod.COMPUTE_CTL: client = secondary.http_client() client.prewarm_lfc(primary_endpoint_id) - # control plane knows safekeepers, simulate it by querying primary assert (lsn := primary.terminate_flush_lsn) - safekeepers_lsn = {"safekeepers": safekeepers, "wal_flush_lsn": lsn} - assert client.promote(safekeepers_lsn)["status"] == "completed" + promote_spec = {"spec": primary_spec, "wal_flush_lsn": str(lsn)} + assert client.promote(promote_spec)["status"] == "completed" else: promo_cur.execute(f"alter system set neon.safekeepers='{safekeepers}'") promo_cur.execute("select pg_reload_conf()") @@ -131,21 +131,35 @@ def test_replica_promote(neon_simple_env: NeonEnv, method: PromoteMethod): lsn_triple = get_lsn_triple(new_primary_cur) log.info(f"Secondary: LSN after workload is {lsn_triple}") - expected_promoted_lsn = Lsn(lsn_triple[2]) + expected_lsn = Lsn(lsn_triple[2]) with secondary.connect() as conn, conn.cursor() as new_primary_cur: new_primary_cur.execute("select payload from t") assert new_primary_cur.fetchall() == [(it,) for it in range(1, 201)] if method == PromoteMethod.COMPUTE_CTL: - # compute_ctl's /promote switches replica type to Primary so it syncs - # safekeepers on finish - stop_and_check_lsn(secondary, expected_promoted_lsn) + # compute_ctl's /promote switches replica type to Primary so it syncs safekeepers on finish + stop_and_check_lsn(secondary, expected_lsn) else: - # on testing postgres, we don't update replica type, secondaries don't - # sync so lsn should be None + # on testing postgres, we don't update replica type, secondaries don't sync so lsn should be None stop_and_check_lsn(secondary, None) + if method == PromoteMethod.COMPUTE_CTL: + secondary.stop() + # In production, compute ultimately receives new compute spec from cplane. 
+ secondary.respec(mode="Primary") + secondary.start() + + with secondary.connect() as conn, conn.cursor() as new_primary_cur: + new_primary_cur.execute( + "INSERT INTO t (payload) SELECT generate_series(101, 200) RETURNING payload" + ) + assert new_primary_cur.fetchall() == [(it,) for it in range(101, 201)] + lsn_triple = get_lsn_triple(new_primary_cur) + log.info(f"Secondary: LSN after restart and workload is {lsn_triple}") + expected_lsn = Lsn(lsn_triple[2]) + stop_and_check_lsn(secondary, expected_lsn) + primary = env.endpoints.create_start(branch_name="main", endpoint_id="primary2") with primary.connect() as new_primary, new_primary.cursor() as new_primary_cur: @@ -154,10 +168,11 @@ def test_replica_promote(neon_simple_env: NeonEnv, method: PromoteMethod): log.info(f"New primary: Boot LSN is {lsn_triple}") new_primary_cur.execute("select count(*) from t") - assert new_primary_cur.fetchone() == (200,) + compute_ctl_count = 100 * (method == PromoteMethod.COMPUTE_CTL) + assert new_primary_cur.fetchone() == (200 + compute_ctl_count,) new_primary_cur.execute("INSERT INTO t (payload) SELECT generate_series(201, 300)") new_primary_cur.execute("select count(*) from t") - assert new_primary_cur.fetchone() == (300,) + assert new_primary_cur.fetchone() == (300 + compute_ctl_count,) stop_and_check_lsn(primary, expected_primary_lsn) @@ -175,18 +190,91 @@ def test_replica_promote_handler_disconnects(neon_simple_env: NeonEnv): cur.execute("create schema neon;create extension neon with schema neon") cur.execute("create table t(pk bigint GENERATED ALWAYS AS IDENTITY, payload integer)") cur.execute("INSERT INTO t(payload) SELECT generate_series(1, 100)") - cur.execute("show neon.safekeepers") - safekeepers = cur.fetchall()[0][0] primary.http_client().offload_lfc() + primary_spec = primary.get_compute_spec() primary_endpoint_id = primary.endpoint_id primary.stop(mode="immediate-terminate") assert (lsn := primary.terminate_flush_lsn) client = secondary.http_client() client.prewarm_lfc(primary_endpoint_id) - safekeepers_lsn = {"safekeepers": safekeepers, "wal_flush_lsn": lsn} - assert client.promote(safekeepers_lsn, disconnect=True)["status"] == "completed" + promote_spec = {"spec": primary_spec, "wal_flush_lsn": str(lsn)} + assert client.promote(promote_spec, disconnect=True)["status"] == "completed" + + with secondary.connect() as conn, conn.cursor() as cur: + cur.execute("select count(*) from t") + assert cur.fetchone() == (100,) + cur.execute("INSERT INTO t (payload) SELECT generate_series(101, 200) RETURNING payload") + cur.execute("select count(*) from t") + assert cur.fetchone() == (200,) + + +@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping") +def test_replica_promote_fails(neon_simple_env: NeonEnv): + """ + Test that if a /promote route fails, we can safely start primary back + """ + env: NeonEnv = neon_simple_env + primary: Endpoint = env.endpoints.create_start(branch_name="main", endpoint_id="primary") + secondary: Endpoint = env.endpoints.new_replica_start(origin=primary, endpoint_id="secondary") + secondary.stop() + secondary.start(env={"FAILPOINTS": "compute-promotion=return(0)"}) + + with primary.connect() as conn, conn.cursor() as cur: + cur.execute("create schema neon;create extension neon with schema neon") + cur.execute("create table t(pk bigint GENERATED ALWAYS AS IDENTITY, payload integer)") + cur.execute("INSERT INTO t(payload) SELECT generate_series(1, 100)") + + primary.http_client().offload_lfc() + primary_spec = primary.get_compute_spec() + primary_endpoint_id = 
primary.endpoint_id + primary.stop(mode="immediate-terminate") + assert (lsn := primary.terminate_flush_lsn) + + client = secondary.http_client() + client.prewarm_lfc(primary_endpoint_id) + promote_spec = {"spec": primary_spec, "wal_flush_lsn": str(lsn)} + assert client.promote(promote_spec)["status"] == "failed" + secondary.stop() + + primary.start() + with primary.connect() as conn, conn.cursor() as cur: + cur.execute("select count(*) from t") + assert cur.fetchone() == (100,) + cur.execute("INSERT INTO t (payload) SELECT generate_series(101, 200) RETURNING payload") + cur.execute("select count(*) from t") + assert cur.fetchone() == (200,) + + +@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping") +def test_replica_promote_prewarm_fails(neon_simple_env: NeonEnv): + """ + Test that if /lfc/prewarm route fails, we are able to promote + """ + env: NeonEnv = neon_simple_env + primary: Endpoint = env.endpoints.create_start(branch_name="main", endpoint_id="primary") + secondary: Endpoint = env.endpoints.new_replica_start(origin=primary, endpoint_id="secondary") + secondary.stop() + secondary.start(env={"FAILPOINTS": "compute-prewarm=return(0)"}) + + with primary.connect() as conn, conn.cursor() as cur: + cur.execute("create schema neon;create extension neon with schema neon") + cur.execute("create table t(pk bigint GENERATED ALWAYS AS IDENTITY, payload integer)") + cur.execute("INSERT INTO t(payload) SELECT generate_series(1, 100)") + + primary.http_client().offload_lfc() + primary_spec = primary.get_compute_spec() + primary_endpoint_id = primary.endpoint_id + primary.stop(mode="immediate-terminate") + assert (lsn := primary.terminate_flush_lsn) + + client = secondary.http_client() + with pytest.raises(AssertionError): + client.prewarm_lfc(primary_endpoint_id) + assert client.prewarm_lfc_status()["status"] == "failed" + promote_spec = {"spec": primary_spec, "wal_flush_lsn": str(lsn)} + assert client.promote(promote_spec)["status"] == "completed" with secondary.connect() as conn, conn.cursor() as cur: cur.execute("select count(*) from t") diff --git a/test_runner/regress/test_safekeeper_migration.py b/test_runner/regress/test_safekeeper_migration.py index 371bec0c62..97a6ece446 100644 --- a/test_runner/regress/test_safekeeper_migration.py +++ b/test_runner/regress/test_safekeeper_migration.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re from typing import TYPE_CHECKING import pytest @@ -12,7 +13,7 @@ if TYPE_CHECKING: # TODO(diko): pageserver spams with various errors during safekeeper migration. # Fix the code so it handles the migration better. 
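For orientation, the tests added below lean on a simple piece of generation bookkeeping: a timeline starts at generation 1, and every completed migrate_safekeepers call advances it by 2 (one step for entering the joint configuration that carries new_sk_set, one for committing it; the stale-timeline test further down observes the intermediate state at generation 2). A minimal sketch of that arithmetic:

generation = 1           # initial configuration created by the storage controller
for _ in range(3):       # three back-and-forth migrations in the tombstone test
    generation += 2      # joint configuration, then commit
assert generation == 7   # the remembered mconf; one more migration brings it to 9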
-ALLOWED_PAGESERVER_ERRORS = [ +PAGESERVER_ALLOWED_ERRORS = [ ".*Timeline .* was cancelled and cannot be used anymore.*", ".*Timeline .* has been deleted.*", ".*Timeline .* was not found in global map.*", @@ -35,7 +36,7 @@ def test_safekeeper_migration_simple(neon_env_builder: NeonEnvBuilder): "timeline_safekeeper_count": 1, } env = neon_env_builder.init_start() - env.pageserver.allowed_errors.extend(ALLOWED_PAGESERVER_ERRORS) + env.pageserver.allowed_errors.extend(PAGESERVER_ALLOWED_ERRORS) ep = env.endpoints.create("main", tenant_id=env.initial_tenant) @@ -136,7 +137,7 @@ def test_safekeeper_migration_common_set_failpoints(neon_env_builder: NeonEnvBui "timeline_safekeeper_count": 3, } env = neon_env_builder.init_start() - env.pageserver.allowed_errors.extend(ALLOWED_PAGESERVER_ERRORS) + env.pageserver.allowed_errors.extend(PAGESERVER_ALLOWED_ERRORS) mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline) assert len(mconf["sk_set"]) == 3 @@ -196,3 +197,266 @@ def test_safekeeper_migration_common_set_failpoints(neon_env_builder: NeonEnvBui assert ( f"timeline {env.initial_tenant}/{env.initial_timeline} deleted" in exc.value.response.text ) + + +def test_sk_generation_aware_tombstones(neon_env_builder: NeonEnvBuilder): + """ + Test that safekeeper respects generations: + 1. Check that migration back and forth between two safekeepers works. + 2. Check that sk refuses to execute requests with stale generation. + """ + neon_env_builder.num_safekeepers = 3 + neon_env_builder.storage_controller_config = { + "timelines_onto_safekeepers": True, + "timeline_safekeeper_count": 1, + } + env = neon_env_builder.init_start() + env.pageserver.allowed_errors.extend(PAGESERVER_ALLOWED_ERRORS) + + mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline) + assert mconf["new_sk_set"] is None + assert len(mconf["sk_set"]) == 1 + cur_sk = mconf["sk_set"][0] + + second_sk, third_sk = [sk.id for sk in env.safekeepers if sk.id != cur_sk] + cur_gen = 1 + + # Pull the timeline manually to third_sk, so the timeline exists there with stale generation. + # This is needed for the test later. + env.get_safekeeper(third_sk).pull_timeline( + [env.get_safekeeper(cur_sk)], env.initial_tenant, env.initial_timeline + ) + + def expect_deleted(sk_id: int): + with pytest.raises(requests.exceptions.HTTPError, match="Not Found") as exc: + env.get_safekeeper(sk_id).http_client().timeline_status( + env.initial_tenant, env.initial_timeline + ) + assert exc.value.response.status_code == 404 + assert re.match(r".*timeline .* deleted.*", exc.value.response.text) + + def get_mconf(sk_id: int): + status = ( + env.get_safekeeper(sk_id) + .http_client() + .timeline_status(env.initial_tenant, env.initial_timeline) + ) + assert status.mconf is not None + return status.mconf + + def migrate(): + nonlocal cur_sk, second_sk, cur_gen + env.storage_controller.migrate_safekeepers( + env.initial_tenant, env.initial_timeline, [second_sk] + ) + cur_sk, second_sk = second_sk, cur_sk + cur_gen += 2 + + # Migrate the timeline back and forth between cur_sk and second_sk. + for _i in range(3): + migrate() + # Timeline should exist on cur_sk. + assert get_mconf(cur_sk).generation == cur_gen + # Timeline should be deleted on second_sk. + expect_deleted(second_sk) + + # Remember current mconf. + mconf = get_mconf(cur_sk) + + # Migrate the timeline one more time. + # It increases the generation by 2. + migrate() + + # Check that sk refuses to execute the exclude request with the old mconf. 
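Concretely, with the numbers used in this test: after the extra migrate() the safekeeper holds generation 9 while the remembered mconf still carries generation 7, so both the exclude and the pull_timeline below are turned away with 409 Conflict. Illustrative values only:

held_generation = 9            # cur_gen after four migrations (1 + 4 * 2)
stale_request_generation = 7   # generation captured in mconf before the last migration
assert stale_request_generation < held_generation   # stale generation => request refused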
+ with pytest.raises(requests.exceptions.HTTPError, match="Conflict") as exc: + env.get_safekeeper(cur_sk).http_client().timeline_exclude( + env.initial_tenant, env.initial_timeline, mconf + ) + assert re.match(r".*refused to switch into excluding mconf.*", exc.value.response.text) + # We shouldn't have deleted the timeline. + assert get_mconf(cur_sk).generation == cur_gen + + # Check that sk refuses to execute the pull_timeline request with the old mconf. + # Note: we try to pull from third_sk, which has a timeline with stale generation. + # Thus, we bypass some preliminary generation checks and actually test tombstones. + with pytest.raises(requests.exceptions.HTTPError, match="Conflict") as exc: + env.get_safekeeper(second_sk).pull_timeline( + [env.get_safekeeper(third_sk)], env.initial_tenant, env.initial_timeline, mconf + ) + assert re.match(r".*Timeline .* deleted.*", exc.value.response.text) + # The timeline should remain deleted. + expect_deleted(second_sk) + + +def test_safekeeper_migration_stale_timeline(neon_env_builder: NeonEnvBuilder): + """ + Test that safekeeper migration handles stale timeline correctly by migrating to + a safekeeper with a stale timeline. + 1. Check that we are waiting for the stale timeline to catch up with the commit lsn. + The migration might fail if there is no compute to advance the WAL. + 2. Check that we rely on last_log_term (and not the current term) when waiting for the + sync_position on step 7. + 3. Check that migration succeeds if the compute is running. + """ + neon_env_builder.num_safekeepers = 2 + neon_env_builder.storage_controller_config = { + "timelines_onto_safekeepers": True, + "timeline_safekeeper_count": 1, + } + env = neon_env_builder.init_start() + env.pageserver.allowed_errors.extend(PAGESERVER_ALLOWED_ERRORS) + env.storage_controller.allowed_errors.append(".*not enough successful .* to reach quorum.*") + + mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline) + + active_sk = env.get_safekeeper(mconf["sk_set"][0]) + other_sk = [sk for sk in env.safekeepers if sk.id != active_sk.id][0] + + ep = env.endpoints.create("main", tenant_id=env.initial_tenant) + ep.start(safekeeper_generation=1, safekeepers=[active_sk.id]) + ep.safe_psql("CREATE TABLE t(a int)") + ep.safe_psql("INSERT INTO t VALUES (0)") + + # Pull the timeline to other_sk, so other_sk now has a "stale" timeline on it. + other_sk.pull_timeline([active_sk], env.initial_tenant, env.initial_timeline) + + # Advance the WAL on active_sk. + ep.safe_psql("INSERT INTO t VALUES (1)") + + # The test is more tricky if we have the same last_log_term but different term/flush_lsn. + # Stop the active_sk during the endpoint shutdown because otherwise compute_ctl runs + # sync_safekeepers and advances last_log_term on active_sk. + active_sk.stop() + ep.stop(mode="immediate") + active_sk.start() + + active_sk_status = active_sk.http_client().timeline_status( + env.initial_tenant, env.initial_timeline + ) + other_sk_status = other_sk.http_client().timeline_status( + env.initial_tenant, env.initial_timeline + ) + + # other_sk should have the same last_log_term, but a stale flush_lsn. + assert active_sk_status.last_log_term == other_sk_status.last_log_term + assert active_sk_status.flush_lsn > other_sk_status.flush_lsn + + commit_lsn = active_sk_status.flush_lsn + + # Bump the term on other_sk to make it higher than active_sk. + # This is to make sure we don't use current term instead of last_log_term in the algorithm. 
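The distinction the comment above relies on, with made-up numbers: term is the safekeeper's current election term, which term_bump can raise without writing any WAL, while last_log_term is the term of the last WAL record actually persisted. Ranking safekeepers by current term would wrongly prefer the bumped but lagging node:

active = {"term": 5, "last_log_term": 5, "flush_lsn": 200}   # has the newest WAL
stale = {"term": 105, "last_log_term": 5, "flush_lsn": 100}  # term bumped, WAL lagging

by_wal_position = max([active, stale], key=lambda s: (s["last_log_term"], s["flush_lsn"]))
by_current_term = max([active, stale], key=lambda s: s["term"])

assert by_wal_position is active   # what the migration must compare against
assert by_current_term is stale    # what a naive term comparison would pick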
+ other_sk.http_client().term_bump( + env.initial_tenant, env.initial_timeline, active_sk_status.term + 100 + ) + + # TODO(diko): now it fails because the timeline on other_sk is stale and there is no compute + # to catch up it with active_sk. It might be fixed in https://databricks.atlassian.net/browse/LKB-946 + # if we delete stale timelines before starting the migration. + # But the rest of the test is still valid: we should not lose committed WAL after the migration. + with pytest.raises( + StorageControllerApiException, match="not enough successful .* to reach quorum" + ): + env.storage_controller.migrate_safekeepers( + env.initial_tenant, env.initial_timeline, [other_sk.id] + ) + + mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline) + assert mconf["new_sk_set"] == [other_sk.id] + assert mconf["sk_set"] == [active_sk.id] + assert mconf["generation"] == 2 + + # Start the endpoint, so it advances the WAL on other_sk. + ep.start(safekeeper_generation=2, safekeepers=[active_sk.id, other_sk.id]) + # Now the migration should succeed. + env.storage_controller.migrate_safekeepers( + env.initial_tenant, env.initial_timeline, [other_sk.id] + ) + + # Check that we didn't lose committed WAL. + assert ( + other_sk.http_client().timeline_status(env.initial_tenant, env.initial_timeline).flush_lsn + >= commit_lsn + ) + assert ep.safe_psql("SELECT * FROM t") == [(0,), (1,)] + + +def test_pull_from_most_advanced_sk(neon_env_builder: NeonEnvBuilder): + """ + Test that we pull the timeline from the most advanced safekeeper during the + migration and do not lose committed WAL. + """ + neon_env_builder.num_safekeepers = 4 + neon_env_builder.storage_controller_config = { + "timelines_onto_safekeepers": True, + "timeline_safekeeper_count": 3, + } + env = neon_env_builder.init_start() + env.pageserver.allowed_errors.extend(PAGESERVER_ALLOWED_ERRORS) + + mconf = env.storage_controller.timeline_locate(env.initial_tenant, env.initial_timeline) + + sk_set = mconf["sk_set"] + assert len(sk_set) == 3 + + other_sk = [sk.id for sk in env.safekeepers if sk.id not in sk_set][0] + + ep = env.endpoints.create("main", tenant_id=env.initial_tenant) + ep.start(safekeeper_generation=1, safekeepers=sk_set) + ep.safe_psql("CREATE TABLE t(a int)") + ep.safe_psql("INSERT INTO t VALUES (0)") + + # Stop one sk, so we have a lagging WAL on it. + env.get_safekeeper(sk_set[0]).stop() + # Advance the WAL on the other sks. + ep.safe_psql("INSERT INTO t VALUES (1)") + + # Stop other sks to make sure compute_ctl doesn't advance the last_log_term on them during shutdown. + for sk_id in sk_set[1:]: + env.get_safekeeper(sk_id).stop() + ep.stop(mode="immediate") + for sk_id in sk_set: + env.get_safekeeper(sk_id).start() + + # Bump the term on the lagging sk to make sure we don't use it to choose the most advanced sk. + env.get_safekeeper(sk_set[0]).http_client().term_bump( + env.initial_tenant, env.initial_timeline, 100 + ) + + def get_commit_lsn(sk_set: list[int]): + flush_lsns = [] + last_log_terms = [] + for sk_id in sk_set: + sk = env.get_safekeeper(sk_id) + status = sk.http_client().timeline_status(env.initial_tenant, env.initial_timeline) + flush_lsns.append(status.flush_lsn) + last_log_terms.append(status.last_log_term) + + # In this test we assume that all sks have the same last_log_term. 
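The median-style pick in get_commit_lsn is the quorum commit LSN: with three safekeepers a quorum is two, so the highest LSN persisted on at least two of them is the second-largest flush_lsn (given equal last_log_term). A tiny worked example with stand-in integers:

flush_lsns = sorted([30, 10, 20], reverse=True)   # -> [30, 20, 10]
commit_lsn = flush_lsns[len(flush_lsns) // 2]     # index 1 -> 20
assert commit_lsn == 20                           # 20 is held by two of the three safekeepers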
+        assert len(set(last_log_terms)) == 1
+
+        flush_lsns.sort(reverse=True)
+        commit_lsn = flush_lsns[len(sk_set) // 2]
+
+        log.info(f"sk_set: {sk_set}, flush_lsns: {flush_lsns}, commit_lsn: {commit_lsn}")
+        return commit_lsn
+
+    commit_lsn_before_migration = get_commit_lsn(sk_set)
+
+    # Make two migrations, so the lagging sk stays in the sk_set but the other sks are replaced.
+    new_sk_set1 = [sk_set[0], sk_set[1], other_sk]  # remove sk_set[2], add other_sk
+    new_sk_set2 = [sk_set[0], other_sk, sk_set[2]]  # remove sk_set[1], add sk_set[2] back
+    env.storage_controller.migrate_safekeepers(
+        env.initial_tenant, env.initial_timeline, new_sk_set1
+    )
+    env.storage_controller.migrate_safekeepers(
+        env.initial_tenant, env.initial_timeline, new_sk_set2
+    )
+
+    commit_lsn_after_migration = get_commit_lsn(new_sk_set2)
+
+    # We should not lose committed WAL.
+    # If we had chosen the lagging sk to pull the timeline from, this might fail.
+    assert commit_lsn_before_migration <= commit_lsn_after_migration
+
+    ep.start(safekeeper_generation=5, safekeepers=new_sk_set2)
+    assert ep.safe_psql("SELECT * FROM t") == [(0,), (1,)]
diff --git a/test_runner/regress/test_sharding.py b/test_runner/regress/test_sharding.py
index 2252c098c7..4e46b67988 100644
--- a/test_runner/regress/test_sharding.py
+++ b/test_runner/regress/test_sharding.py
@@ -1508,20 +1508,55 @@ def test_sharding_split_failures(
     env.storage_controller.consistency_check()
 
 
-@pytest.mark.skip(reason="The backpressure change has not been merged yet.")
+# HADRON
+def test_create_tenant_after_split(neon_env_builder: NeonEnvBuilder):
+    """
+    Tests that creating a tenant and a timeline fails after a tenant split.
+    """
+    env = neon_env_builder.init_start(initial_tenant_shard_count=4)
+
+    env.storage_controller.allowed_errors.extend(
+        [
+            ".*already exists with a different shard count.*",
+        ]
+    )
+
+    ep = env.endpoints.create_start("main", tenant_id=env.initial_tenant)
+    ep.safe_psql("CREATE TABLE usertable ( YCSB_KEY INT, FIELD0 TEXT);")
+    ep.safe_psql("INSERT INTO usertable VALUES (1, 'test1');")
+    ep.safe_psql("INSERT INTO usertable VALUES (2, 'test2');")
+    ep.safe_psql("INSERT INTO usertable VALUES (3, 'test3');")
+
+    # Split the tenant
+
+    env.storage_controller.tenant_shard_split(env.initial_tenant, shard_count=8)
+
+    with pytest.raises(RuntimeError):
+        env.create_tenant(env.initial_tenant, env.initial_timeline, shard_count=4)
+
+    # run more queries
+    ep.safe_psql("SELECT * FROM usertable;")
+    ep.safe_psql("UPDATE usertable set FIELD0 = 'test4';")
+
+    ep.stop_and_destroy()
+
+
+# HADRON
 def test_back_pressure_during_split(neon_env_builder: NeonEnvBuilder):
     """
-    Test backpressure can ignore new shards during tenant split so that if we abort the split,
-    PG can continue without being blocked.
+    Test that backpressure works correctly during a shard split, and in particular that
+    PG does not get stuck forever after the split is aborted.
""" - DBNAME = "regression" - - init_shard_count = 4 + init_shard_count = 1 neon_env_builder.num_pageservers = init_shard_count stripe_size = 32 env = neon_env_builder.init_start( - initial_tenant_shard_count=init_shard_count, initial_tenant_shard_stripe_size=stripe_size + initial_tenant_shard_count=init_shard_count, + initial_tenant_shard_stripe_size=stripe_size, + initial_tenant_conf={ + "checkpoint_distance": 1024 * 1024 * 1024, + }, ) env.storage_controller.allowed_errors.extend( @@ -1537,19 +1572,31 @@ def test_back_pressure_during_split(neon_env_builder: NeonEnvBuilder): "main", config_lines=[ "max_replication_write_lag = 1MB", - "databricks.max_wal_mb_per_second = 1", "neon.max_cluster_size = 10GB", + "databricks.max_wal_mb_per_second=100", ], ) - endpoint.respec(skip_pg_catalog_updates=False) # Needed for databricks_system to get created. + endpoint.respec(skip_pg_catalog_updates=False) endpoint.start() - endpoint.safe_psql(f"CREATE DATABASE {DBNAME}") - - endpoint.safe_psql("CREATE TABLE usertable ( YCSB_KEY INT, FIELD0 TEXT);") + # generate 10MB of data + endpoint.safe_psql( + "CREATE TABLE usertable AS SELECT s AS KEY, repeat('a', 1000) as VALUE from generate_series(1, 10000) s;" + ) write_done = Event() - def write_data(write_done): + def get_write_lag(): + res = endpoint.safe_psql( + """ + SELECT + pg_wal_lsn_diff(pg_current_wal_flush_lsn(), received_lsn) as received_lsn_lag + FROM neon.backpressure_lsns(); + """, + log_query=False, + ) + return res[0][0] + + def write_data(write_done: Event): while not write_done.is_set(): endpoint.safe_psql( "INSERT INTO usertable SELECT random(), repeat('a', 1000);", log_query=False @@ -1560,35 +1607,39 @@ def test_back_pressure_during_split(neon_env_builder: NeonEnvBuilder): writer_thread.start() env.storage_controller.configure_failpoints(("shard-split-pre-complete", "return(1)")) + # sleep 10 seconds before re-activating the old shard when aborting the split. 
+    # This is to add some backpressure to PG.
+    env.pageservers[0].http_client().configure_failpoints(
+        ("attach-before-activate-sleep", "return(10000)"),
+    )
 
     # split the tenant
     with pytest.raises(StorageControllerApiException):
-        env.storage_controller.tenant_shard_split(env.initial_tenant, shard_count=16)
+        env.storage_controller.tenant_shard_split(env.initial_tenant, shard_count=4)
+
+    def check_tenant_status():
+        status = (
+            env.pageservers[0].http_client().tenant_status(TenantShardId(env.initial_tenant, 0, 1))
+        )
+        assert status["state"]["slug"] == "Active"
+
+    wait_until(check_tenant_status)
 
     write_done.set()
     writer_thread.join()
 
+    log.info(f"current write lag: {get_write_lag()}")
+
     # writing more data to page servers after split is aborted
-    for _i in range(5000):
-        endpoint.safe_psql(
-            "INSERT INTO usertable SELECT random(), repeat('a', 1000);", log_query=False
-        )
+    with endpoint.cursor() as cur:
+        for _i in range(1000):
+            cur.execute("INSERT INTO usertable SELECT random(), repeat('a', 1000);")
 
     # wait until write lag becomes 0
     def check_write_lag_is_zero():
-        res = endpoint.safe_psql(
-            """
-            SELECT
-            pg_wal_lsn_diff(pg_current_wal_flush_lsn(), received_lsn) as received_lsn_lag
-            FROM neon.backpressure_lsns();
-            """,
-            dbname="databricks_system",
-            log_query=False,
-        )
-        log.info(f"received_lsn_lag = {res[0][0]}")
-        assert res[0][0] == 0
+        res = get_write_lag()
+        assert res == 0
 
     wait_until(check_write_lag_is_zero)
 
-    endpoint.stop_and_destroy()
 
 # BEGIN_HADRON
@@ -1674,7 +1725,6 @@ def test_shard_resolve_during_split_abort(neon_env_builder: NeonEnvBuilder):
 
 
 # HADRON
-@pytest.mark.skip(reason="The backpressure change has not been merged yet.")
 def test_back_pressure_per_shard(neon_env_builder: NeonEnvBuilder):
     """
     Tests back pressure knobs are enforced on the per shard basis instead of at the tenant level.
@@ -1701,22 +1751,19 @@ def test_back_pressure_per_shard(neon_env_builder: NeonEnvBuilder):
             "max_replication_apply_lag = 0",
             "max_replication_flush_lag = 15MB",
             "neon.max_cluster_size = 10GB",
+            "neon.lakebase_mode = true",
         ],
     )
-    endpoint.respec(skip_pg_catalog_updates=False)  # Needed for databricks_system to get created.
+    endpoint.respec(skip_pg_catalog_updates=False)
     endpoint.start()
 
     # generate 20MB of data
     endpoint.safe_psql(
         "CREATE TABLE usertable AS SELECT s AS KEY, repeat('a', 1000) as VALUE from generate_series(1, 20000) s;"
     )
-    res = endpoint.safe_psql(
-        "SELECT neon.backpressure_throttling_time() as throttling_time", dbname="databricks_system"
-    )[0]
+    res = endpoint.safe_psql("SELECT neon.backpressure_throttling_time() as throttling_time")[0]
     assert res[0] == 0, f"throttling_time should be 0, but got {res[0]}"
 
-    endpoint.stop()
-
 
 # HADRON
 def test_shard_split_page_server_timeout(neon_env_builder: NeonEnvBuilder):
@@ -1880,14 +1927,14 @@ def test_sharding_backpressure(neon_env_builder: NeonEnvBuilder):
     shards_info()
 
     for _write_iter in range(30):
-        # approximately 1MB of data
-        workload.write_rows(8000, upload=False)
+        # approximately 10MB of data
+        workload.write_rows(80000, upload=False)
         update_write_lsn()
         infos = shards_info()
         min_lsn = min(Lsn(info["last_record_lsn"]) for info in infos)
         max_lsn = max(Lsn(info["last_record_lsn"]) for info in infos)
         diff = max_lsn - min_lsn
-        assert diff < 2 * 1024 * 1024, f"LSN diff={diff}, expected diff < 2MB due to backpressure"
+        assert diff < 8 * 1024 * 1024, f"LSN diff={diff}, expected diff < 8MB due to backpressure"
 
 
 def test_sharding_unlogged_relation(neon_env_builder: NeonEnvBuilder):
diff --git a/test_runner/regress/test_storage_controller.py b/test_runner/regress/test_storage_controller.py
index 9986c1f24a..e11be1df8c 100644
--- a/test_runner/regress/test_storage_controller.py
+++ b/test_runner/regress/test_storage_controller.py
@@ -3309,6 +3309,7 @@ def test_ps_unavailable_after_delete(
         ps.allowed_errors.append(".*request was dropped before completing.*")
         env.storage_controller.node_delete(ps.id, force=True)
         wait_until(lambda: assert_nodes_count(2))
+        env.storage_controller.reconcile_until_idle()
     elif deletion_api == DeletionAPIKind.OLD:
         env.storage_controller.node_delete_old(ps.id)
         assert_nodes_count(2)
@@ -4959,3 +4960,49 @@ def test_storage_controller_forward_404(neon_env_builder: NeonEnvBuilder):
     env.storage_controller.configure_failpoints(
         ("reconciler-live-migrate-post-generation-inc", "off")
     )
+
+
+def test_re_attach_with_stuck_secondary(neon_env_builder: NeonEnvBuilder):
+    """
+    This test assumes that the secondary location cannot be configured for whatever reason.
+    It then detaches the tenant and attaches it back again and, finally, checks
+    for observed state consistency by attempting to create a timeline.
+
+    See LKB-204 for more details.
+ """ + + neon_env_builder.num_pageservers = 2 + + env = neon_env_builder.init_configs() + env.start() + + env.storage_controller.allowed_errors.append(".*failpoint.*") + + tenant_id, _ = env.create_tenant(shard_count=1, placement_policy='{"Attached":1}') + env.storage_controller.reconcile_until_idle() + + locations = env.storage_controller.locate(tenant_id) + assert len(locations) == 1 + primary: int = locations[0]["node_id"] + + not_primary = [ps.id for ps in env.pageservers if ps.id != primary] + assert len(not_primary) == 1 + secondary = not_primary[0] + + env.get_pageserver(secondary).http_client().configure_failpoints( + ("put-location-conf-handler", "return(1)") + ) + + env.storage_controller.tenant_policy_update(tenant_id, {"placement": "Detached"}) + + with pytest.raises(Exception, match="failpoint"): + env.storage_controller.reconcile_all() + + env.storage_controller.tenant_policy_update(tenant_id, {"placement": {"Attached": 1}}) + + with pytest.raises(Exception, match="failpoint"): + env.storage_controller.reconcile_all() + + env.storage_controller.pageserver_api().timeline_create( + pg_version=PgVersion.NOT_SET, tenant_id=tenant_id, new_timeline_id=TimelineId.generate() + ) diff --git a/test_runner/regress/test_tenant_size.py b/test_runner/regress/test_tenant_size.py index 8b291b7cbe..5564d9342c 100644 --- a/test_runner/regress/test_tenant_size.py +++ b/test_runner/regress/test_tenant_size.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING @@ -768,6 +769,14 @@ def test_lsn_lease_storcon(neon_env_builder: NeonEnvBuilder): "compaction_period": "0s", } env = neon_env_builder.init_start(initial_tenant_conf=conf) + # ShardSplit is slow in debug builds, so ignore the warning + if os.getenv("BUILD_TYPE", "debug") == "debug": + env.storage_controller.allowed_errors.extend( + [ + ".*Exclusive lock by ShardSplit was held.*", + ] + ) + with env.endpoints.create_start( "main", ) as ep: diff --git a/test_runner/regress/test_tenants.py b/test_runner/regress/test_tenants.py index 7f32f34d36..49bc02a3e7 100644 --- a/test_runner/regress/test_tenants.py +++ b/test_runner/regress/test_tenants.py @@ -298,15 +298,26 @@ def test_pageserver_metrics_removed_after_detach(neon_env_builder: NeonEnvBuilde assert post_detach_samples == set() -def test_pageserver_metrics_removed_after_offload(neon_env_builder: NeonEnvBuilder): +@pytest.mark.parametrize("compaction", ["compaction_enabled", "compaction_disabled"]) +def test_pageserver_metrics_removed_after_offload( + neon_env_builder: NeonEnvBuilder, compaction: str +): """Tests that when a timeline is offloaded, the tenant specific metrics are not left behind""" neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.MOCK_S3) - neon_env_builder.num_safekeepers = 3 env = neon_env_builder.init_start() - tenant_1, _ = env.create_tenant() + tenant_1, _ = env.create_tenant( + conf={ + # disable background compaction and GC so that we don't have leftover tasks + # after offloading. 
+ "gc_period": "0s", + "compaction_period": "0s", + } + if compaction == "compaction_disabled" + else None + ) timeline_1 = env.create_timeline("test_metrics_removed_after_offload_1", tenant_id=tenant_1) timeline_2 = env.create_timeline("test_metrics_removed_after_offload_2", tenant_id=tenant_1) @@ -351,6 +362,23 @@ def test_pageserver_metrics_removed_after_offload(neon_env_builder: NeonEnvBuild state=TimelineArchivalState.ARCHIVED, ) env.pageserver.http_client().timeline_offload(tenant_1, timeline) + # We need to wait until all background jobs are finished before we can check the metrics. + # There're many of them: compaction, GC, etc. + wait_until( + lambda: all( + sample.value == 0 + for sample in env.pageserver.http_client() + .get_metrics() + .query_all("pageserver_background_loop_semaphore_waiting_tasks") + ) + and all( + sample.value == 0 + for sample in env.pageserver.http_client() + .get_metrics() + .query_all("pageserver_background_loop_semaphore_running_tasks") + ) + ) + post_offload_samples = set( [x.name for x in get_ps_metric_samples_for_timeline(tenant_1, timeline)] ) diff --git a/test_runner/regress/test_wal_acceptor.py b/test_runner/regress/test_wal_acceptor.py index c691087259..1e8ca216d0 100644 --- a/test_runner/regress/test_wal_acceptor.py +++ b/test_runner/regress/test_wal_acceptor.py @@ -2757,18 +2757,37 @@ def test_timeline_disk_usage_limit(neon_env_builder: NeonEnvBuilder): remote_storage_kind = s3_storage() neon_env_builder.enable_safekeeper_remote_storage(remote_storage_kind) - # Set a very small disk usage limit (1KB) - neon_env_builder.safekeeper_extra_opts = ["--max-timeline-disk-usage-bytes=1024"] - env = neon_env_builder.init_start() # Create a timeline and endpoint env.create_branch("test_timeline_disk_usage_limit") - endpoint = env.endpoints.create_start("test_timeline_disk_usage_limit") + endpoint = env.endpoints.create_start( + "test_timeline_disk_usage_limit", + config_lines=[ + "neon.lakebase_mode=true", + ], + ) + + # Install the neon extension in the test database. We need it to query perf counter metrics. + with closing(endpoint.connect()) as conn: + with conn.cursor() as cur: + cur.execute("CREATE EXTENSION IF NOT EXISTS neon") + # Sanity-check safekeeper connection status in neon_perf_counters in the happy case. 
+ cur.execute( + "SELECT value FROM neon_perf_counters WHERE metric = 'num_active_safekeepers'" + ) + assert cur.fetchone() == (1,), "Expected 1 active safekeeper" + cur.execute( + "SELECT value FROM neon_perf_counters WHERE metric = 'num_configured_safekeepers'" + ) + assert cur.fetchone() == (1,), "Expected 1 configured safekeeper" # Get the safekeeper sk = env.safekeepers[0] + # Restart the safekeeper with a very small disk usage limit (1KB) + sk.stop().start(["--max-timeline-disk-usage-bytes=1024"]) + # Inject a failpoint to stop WAL backup with sk.http_client() as http_cli: http_cli.configure_failpoints([("backup-lsn-range-pausable", "pause")]) @@ -2794,6 +2813,18 @@ def test_timeline_disk_usage_limit(neon_env_builder: NeonEnvBuilder): wait_until(error_logged) log.info("Found expected error message in compute log, resuming.") + with closing(endpoint.connect()) as conn: + with conn.cursor() as cur: + # Confirm that neon_perf_counters also indicates that there are no active safekeepers + cur.execute( + "SELECT value FROM neon_perf_counters WHERE metric = 'num_active_safekeepers'" + ) + assert cur.fetchone() == (0,), "Expected 0 active safekeepers" + cur.execute( + "SELECT value FROM neon_perf_counters WHERE metric = 'num_configured_safekeepers'" + ) + assert cur.fetchone() == (1,), "Expected 1 configured safekeeper" + # Sanity check that the hanging insert is indeed still hanging. Otherwise means the circuit breaker we # implemented didn't work as expected. time.sleep(2) diff --git a/test_runner/sql_regress/expected/neon-spgist.out b/test_runner/sql_regress/expected/neon-spgist.out new file mode 100644 index 0000000000..5982084109 --- /dev/null +++ b/test_runner/sql_regress/expected/neon-spgist.out @@ -0,0 +1,9 @@ +-- Test unlogged build of SPGIST index (no "Page evicted with zero LSN" error) +create table spgist_point_tbl(id int4, p point); +create index spgist_point_idx on spgist_point_tbl using spgist(p) with (fillfactor = 25); +insert into spgist_point_tbl (id, p) select g, point(g*10, g*10) from generate_series(1, 10000) g; +insert into spgist_point_tbl (id, p) select g, point(g*10, g*10) from generate_series(1, 10000) g; +insert into spgist_point_tbl (id, p) select g+100000, point(g*10+1, g*10+1) from generate_series(1, 10000) g; +vacuum spgist_point_tbl; +insert into spgist_point_tbl (id, p) select g+100000, point(g*10+1, g*10+1) from generate_series(1, 10000) g; +checkpoint; diff --git a/test_runner/sql_regress/parallel_schedule b/test_runner/sql_regress/parallel_schedule index 0ce9f0e28f..d724c750ff 100644 --- a/test_runner/sql_regress/parallel_schedule +++ b/test_runner/sql_regress/parallel_schedule @@ -9,5 +9,6 @@ test: neon-rel-truncate test: neon-clog test: neon-test-utils test: neon-vacuum-full -test: neon-event-triggers test: neon-subxacts +test: neon-spgist +test: neon-event-triggers diff --git a/test_runner/sql_regress/sql/neon-spgist.sql b/test_runner/sql_regress/sql/neon-spgist.sql new file mode 100644 index 0000000000..b26b692ff7 --- /dev/null +++ b/test_runner/sql_regress/sql/neon-spgist.sql @@ -0,0 +1,10 @@ +-- Test unlogged build of SPGIST index (no "Page evicted with zero LSN" error) +create table spgist_point_tbl(id int4, p point); +create index spgist_point_idx on spgist_point_tbl using spgist(p) with (fillfactor = 25); +insert into spgist_point_tbl (id, p) select g, point(g*10, g*10) from generate_series(1, 10000) g; +insert into spgist_point_tbl (id, p) select g, point(g*10, g*10) from generate_series(1, 10000) g; +insert into spgist_point_tbl (id, p) select 
g+100000, point(g*10+1, g*10+1) from generate_series(1, 10000) g; + +vacuum spgist_point_tbl; +insert into spgist_point_tbl (id, p) select g+100000, point(g*10+1, g*10+1) from generate_series(1, 10000) g; +checkpoint; diff --git a/vendor/revisions.json b/vendor/revisions.json index d62f8e5736..c02c748a72 100644 --- a/vendor/revisions.json +++ b/vendor/revisions.json @@ -1,18 +1,18 @@ { "v17": [ "17.5", - "ba750903a90dded8098f2f56d0b2a9012e6166af" + "1e01fcea2a6b38180021aa83e0051d95286d9096" ], "v16": [ "16.9", - "ad2b69b58230290fc44c08fbe0c97981c64f6c7d" + "a42351fcd41ea01edede1daed65f651e838988fc" ], "v15": [ "15.13", - "e5ee23d99874ea9f5b62f8acc7d076162ae95d6c" + "2aaab3bb4a13557aae05bb2ae0ef0a132d0c4f85" ], "v14": [ "14.18", - "4cacada8bde7f6424751a0727a657783c6a1d20b" + "2155cb165d05f617eb2c8ad7e43367189b627703" ] } diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 105e4afb87..4900f9dd35 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -29,6 +29,7 @@ clap = { version = "4", features = ["derive", "env", "string"] } clap_builder = { version = "4", default-features = false, features = ["color", "env", "help", "std", "string", "suggestions", "usage"] } const-oid = { version = "0.9", default-features = false, features = ["db", "std"] } criterion = { version = "0.5", features = ["html_reports"] } +crossbeam-epoch = { version = "0.9" } crypto-bigint = { version = "0.5", features = ["generic-array", "zeroize"] } der = { version = "0.7", default-features = false, features = ["derive", "flagset", "oid", "pem", "std"] } deranged = { version = "0.3", default-features = false, features = ["powerfmt", "serde", "std"] } @@ -73,6 +74,7 @@ num-rational = { version = "0.4", default-features = false, features = ["num-big num-traits = { version = "0.2", features = ["i128", "libm"] } p256 = { version = "0.13", features = ["jwk"] } parquet = { version = "53", default-features = false, features = ["zstd"] } +portable-atomic = { version = "1", features = ["require-cas"] } prost = { version = "0.13", features = ["no-recursion-limit", "prost-derive"] } rand = { version = "0.9" } regex = { version = "1" }