diff --git a/.github/actionlint.yml b/.github/actionlint.yml index 25b2fc702a..8a4bcaf811 100644 --- a/.github/actionlint.yml +++ b/.github/actionlint.yml @@ -31,7 +31,7 @@ config-variables: - NEON_PROD_AWS_ACCOUNT_ID - PGREGRESS_PG16_PROJECT_ID - PGREGRESS_PG17_PROJECT_ID - - PREWARM_PGBENCH_SIZE + - PREWARM_PROJECT_ID - REMOTE_STORAGE_AZURE_CONTAINER - REMOTE_STORAGE_AZURE_REGION - SLACK_CICD_CHANNEL_ID diff --git a/.github/workflows/benchbase_tpcc.yml b/.github/workflows/benchbase_tpcc.yml new file mode 100644 index 0000000000..3a36a97bb1 --- /dev/null +++ b/.github/workflows/benchbase_tpcc.yml @@ -0,0 +1,384 @@ +name: TPC-C like benchmark using benchbase + +on: + schedule: + # * is a special character in YAML so you have to quote this string + # ┌───────────── minute (0 - 59) + # │ ┌───────────── hour (0 - 23) + # │ │ ┌───────────── day of the month (1 - 31) + # │ │ │ ┌───────────── month (1 - 12 or JAN-DEC) + # │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT) + - cron: '0 6 * * *' # run once a day at 6 AM UTC + workflow_dispatch: # adds ability to run this manually + +defaults: + run: + shell: bash -euxo pipefail {0} + +concurrency: + # Allow only one workflow globally because we do not want to be too noisy in production environment + group: benchbase-tpcc-workflow + cancel-in-progress: false + +permissions: + contents: read + +jobs: + benchbase-tpcc: + strategy: + fail-fast: false # allow other variants to continue even if one fails + matrix: + include: + - warehouses: 50 # defines number of warehouses and is used to compute number of terminals + max_rate: 800 # measured max TPS at scale factor based on experiments. Adjust if performance is better/worse + min_cu: 0.25 # simulate free tier plan (0.25 -2 CU) + max_cu: 2 + - warehouses: 500 # serverless plan (2-8 CU) + max_rate: 2000 + min_cu: 2 + max_cu: 8 + - warehouses: 1000 # business plan (2-16 CU) + max_rate: 2900 + min_cu: 2 + max_cu: 16 + max-parallel: 1 # we want to run each workload size sequentially to avoid noisy neighbors + permissions: + contents: write + statuses: write + id-token: write # aws-actions/configure-aws-credentials + env: + PG_CONFIG: /tmp/neon/pg_install/v17/bin/pg_config + PSQL: /tmp/neon/pg_install/v17/bin/psql + PG_17_LIB_PATH: /tmp/neon/pg_install/v17/lib + POSTGRES_VERSION: 17 + runs-on: [ self-hosted, us-east-2, x64 ] + timeout-minutes: 1440 + + steps: + - name: Harden the runner (Audit all outbound calls) + uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0 + with: + egress-policy: audit + + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Configure AWS credentials # necessary to download artefacts + uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2 + with: + aws-region: eu-central-1 + role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} + role-duration-seconds: 18000 # 5 hours is currently max associated with IAM role + + - name: Download Neon artifact + uses: ./.github/actions/download + with: + name: neon-${{ runner.os }}-${{ runner.arch }}-release-artifact + path: /tmp/neon/ + prefix: latest + aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} + + - name: Create Neon Project + id: create-neon-project-tpcc + uses: ./.github/actions/neon-project-create + with: + region_id: aws-us-east-2 + postgres_version: ${{ env.POSTGRES_VERSION }} + compute_units: '[${{ matrix.min_cu }}, ${{ matrix.max_cu }}]' + api_key: ${{ secrets.NEON_PRODUCTION_API_KEY_4_BENCHMARKS }} + api_host: 
console.neon.tech # production (!) + + - name: Initialize Neon project + env: + BENCHMARK_TPCC_CONNSTR: ${{ steps.create-neon-project-tpcc.outputs.dsn }} + PROJECT_ID: ${{ steps.create-neon-project-tpcc.outputs.project_id }} + run: | + echo "Initializing Neon project with project_id: ${PROJECT_ID}" + export LD_LIBRARY_PATH=${PG_17_LIB_PATH} + + # Retry logic for psql connection with 1 minute sleep between attempts + for attempt in {1..3}; do + echo "Attempt ${attempt}/3: Creating extensions in Neon project" + if ${PSQL} "${BENCHMARK_TPCC_CONNSTR}" -c "CREATE EXTENSION IF NOT EXISTS neon; CREATE EXTENSION IF NOT EXISTS neon_utils;"; then + echo "Successfully created extensions" + break + else + echo "Failed to create extensions on attempt ${attempt}" + if [ ${attempt} -lt 3 ]; then + echo "Waiting 60 seconds before retry..." + sleep 60 + else + echo "All attempts failed, exiting" + exit 1 + fi + fi + done + + echo "BENCHMARK_TPCC_CONNSTR=${BENCHMARK_TPCC_CONNSTR}" >> $GITHUB_ENV + + - name: Generate BenchBase workload configuration + env: + WAREHOUSES: ${{ matrix.warehouses }} + MAX_RATE: ${{ matrix.max_rate }} + run: | + echo "Generating BenchBase configs for warehouses: ${WAREHOUSES}, max_rate: ${MAX_RATE}" + + # Extract hostname and password from connection string + # Format: postgresql://username:password@hostname/database?params (no port for Neon) + HOSTNAME=$(echo "${BENCHMARK_TPCC_CONNSTR}" | sed -n 's|.*://[^:]*:[^@]*@\([^/]*\)/.*|\1|p') + PASSWORD=$(echo "${BENCHMARK_TPCC_CONNSTR}" | sed -n 's|.*://[^:]*:\([^@]*\)@.*|\1|p') + + echo "Extracted hostname: ${HOSTNAME}" + + # Use runner temp (NVMe) as working directory + cd "${RUNNER_TEMP}" + + # Copy the generator script + cp "${GITHUB_WORKSPACE}/test_runner/performance/benchbase_tpc_c_helpers/generate_workload_size.py" . + + # Generate configs and scripts + python3 generate_workload_size.py \ + --warehouses ${WAREHOUSES} \ + --max-rate ${MAX_RATE} \ + --hostname ${HOSTNAME} \ + --password ${PASSWORD} \ + --runner-arch ${{ runner.arch }} + + # Fix path mismatch: move generated configs and scripts to expected locations + mv ../configs ./configs + mv ../scripts ./scripts + + - name: Prepare database (load data) + env: + WAREHOUSES: ${{ matrix.warehouses }} + run: | + cd "${RUNNER_TEMP}" + + echo "Loading ${WAREHOUSES} warehouses into database..." + + # Run the loader script and capture output to log file while preserving stdout/stderr + ./scripts/load_${WAREHOUSES}_warehouses.sh 2>&1 | tee "load_${WAREHOUSES}_warehouses.log" + + echo "Database loading completed" + + - name: Run TPC-C benchmark (warmup phase, then benchmark at 70% of configuredmax TPS) + env: + WAREHOUSES: ${{ matrix.warehouses }} + run: | + cd "${RUNNER_TEMP}" + + echo "Running TPC-C benchmark with ${WAREHOUSES} warehouses..." + + # Run the optimal rate benchmark + ./scripts/execute_${WAREHOUSES}_warehouses_opt_rate.sh + + echo "Benchmark execution completed" + + - name: Run TPC-C benchmark (warmup phase, then ramp down TPS and up again in 5 minute intervals) + + env: + WAREHOUSES: ${{ matrix.warehouses }} + run: | + cd "${RUNNER_TEMP}" + + echo "Running TPC-C ramp-down-up with ${WAREHOUSES} warehouses..." 
+ + # Run the optimal rate benchmark + ./scripts/execute_${WAREHOUSES}_warehouses_ramp_up.sh + + echo "Benchmark execution completed" + + - name: Process results (upload to test results database and generate diagrams) + env: + WAREHOUSES: ${{ matrix.warehouses }} + MIN_CU: ${{ matrix.min_cu }} + MAX_CU: ${{ matrix.max_cu }} + PROJECT_ID: ${{ steps.create-neon-project-tpcc.outputs.project_id }} + REVISION: ${{ github.sha }} + PERF_DB_CONNSTR: ${{ secrets.PERF_TEST_RESULT_CONNSTR }} + run: | + cd "${RUNNER_TEMP}" + + echo "Creating temporary Python environment for results processing..." + + # Create temporary virtual environment + python3 -m venv temp_results_env + source temp_results_env/bin/activate + + # Install required packages in virtual environment + pip install matplotlib pandas psycopg2-binary + + echo "Copying results processing scripts..." + + # Copy both processing scripts + cp "${GITHUB_WORKSPACE}/test_runner/performance/benchbase_tpc_c_helpers/generate_diagrams.py" . + cp "${GITHUB_WORKSPACE}/test_runner/performance/benchbase_tpc_c_helpers/upload_results_to_perf_test_results.py" . + + echo "Processing load phase metrics..." + + # Find and process load log + LOAD_LOG=$(find . -name "load_${WAREHOUSES}_warehouses.log" -type f | head -1) + if [ -n "$LOAD_LOG" ]; then + echo "Processing load metrics from: $LOAD_LOG" + python upload_results_to_perf_test_results.py \ + --load-log "$LOAD_LOG" \ + --run-type "load" \ + --warehouses "${WAREHOUSES}" \ + --min-cu "${MIN_CU}" \ + --max-cu "${MAX_CU}" \ + --project-id "${PROJECT_ID}" \ + --revision "${REVISION}" \ + --connection-string "${PERF_DB_CONNSTR}" + else + echo "Warning: Load log file not found: load_${WAREHOUSES}_warehouses.log" + fi + + echo "Processing warmup results for optimal rate..." + + # Find and process warmup results + WARMUP_CSV=$(find results_warmup -name "*.results.csv" -type f | head -1) + WARMUP_JSON=$(find results_warmup -name "*.summary.json" -type f | head -1) + + if [ -n "$WARMUP_CSV" ] && [ -n "$WARMUP_JSON" ]; then + echo "Generating warmup diagram from: $WARMUP_CSV" + python generate_diagrams.py \ + --input-csv "$WARMUP_CSV" \ + --output-svg "warmup_${WAREHOUSES}_warehouses_performance.svg" \ + --title-suffix "Warmup at max TPS" + + echo "Uploading warmup metrics from: $WARMUP_JSON" + python upload_results_to_perf_test_results.py \ + --summary-json "$WARMUP_JSON" \ + --results-csv "$WARMUP_CSV" \ + --run-type "warmup" \ + --min-cu "${MIN_CU}" \ + --max-cu "${MAX_CU}" \ + --project-id "${PROJECT_ID}" \ + --revision "${REVISION}" \ + --connection-string "${PERF_DB_CONNSTR}" + else + echo "Warning: Missing warmup results files (CSV: $WARMUP_CSV, JSON: $WARMUP_JSON)" + fi + + echo "Processing optimal rate results..." 
+ + # Find and process optimal rate results + OPTRATE_CSV=$(find results_opt_rate -name "*.results.csv" -type f | head -1) + OPTRATE_JSON=$(find results_opt_rate -name "*.summary.json" -type f | head -1) + + if [ -n "$OPTRATE_CSV" ] && [ -n "$OPTRATE_JSON" ]; then + echo "Generating optimal rate diagram from: $OPTRATE_CSV" + python generate_diagrams.py \ + --input-csv "$OPTRATE_CSV" \ + --output-svg "benchmark_${WAREHOUSES}_warehouses_performance.svg" \ + --title-suffix "70% of max TPS" + + echo "Uploading optimal rate metrics from: $OPTRATE_JSON" + python upload_results_to_perf_test_results.py \ + --summary-json "$OPTRATE_JSON" \ + --results-csv "$OPTRATE_CSV" \ + --run-type "opt-rate" \ + --min-cu "${MIN_CU}" \ + --max-cu "${MAX_CU}" \ + --project-id "${PROJECT_ID}" \ + --revision "${REVISION}" \ + --connection-string "${PERF_DB_CONNSTR}" + else + echo "Warning: Missing optimal rate results files (CSV: $OPTRATE_CSV, JSON: $OPTRATE_JSON)" + fi + + echo "Processing warmup 2 results for ramp down/up phase..." + + # Find and process warmup results + WARMUP_CSV=$(find results_warmup -name "*.results.csv" -type f | tail -1) + WARMUP_JSON=$(find results_warmup -name "*.summary.json" -type f | tail -1) + + if [ -n "$WARMUP_CSV" ] && [ -n "$WARMUP_JSON" ]; then + echo "Generating warmup diagram from: $WARMUP_CSV" + python generate_diagrams.py \ + --input-csv "$WARMUP_CSV" \ + --output-svg "warmup_2_${WAREHOUSES}_warehouses_performance.svg" \ + --title-suffix "Warmup at max TPS" + + echo "Uploading warmup metrics from: $WARMUP_JSON" + python upload_results_to_perf_test_results.py \ + --summary-json "$WARMUP_JSON" \ + --results-csv "$WARMUP_CSV" \ + --run-type "warmup" \ + --min-cu "${MIN_CU}" \ + --max-cu "${MAX_CU}" \ + --project-id "${PROJECT_ID}" \ + --revision "${REVISION}" \ + --connection-string "${PERF_DB_CONNSTR}" + else + echo "Warning: Missing warmup results files (CSV: $WARMUP_CSV, JSON: $WARMUP_JSON)" + fi + + echo "Processing ramp results..." 
+ + # Find and process ramp results + RAMPUP_CSV=$(find results_ramp_up -name "*.results.csv" -type f | head -1) + RAMPUP_JSON=$(find results_ramp_up -name "*.summary.json" -type f | head -1) + + if [ -n "$RAMPUP_CSV" ] && [ -n "$RAMPUP_JSON" ]; then + echo "Generating ramp diagram from: $RAMPUP_CSV" + python generate_diagrams.py \ + --input-csv "$RAMPUP_CSV" \ + --output-svg "ramp_${WAREHOUSES}_warehouses_performance.svg" \ + --title-suffix "ramp TPS down and up in 5 minute intervals" + + echo "Uploading ramp metrics from: $RAMPUP_JSON" + python upload_results_to_perf_test_results.py \ + --summary-json "$RAMPUP_JSON" \ + --results-csv "$RAMPUP_CSV" \ + --run-type "ramp-up" \ + --min-cu "${MIN_CU}" \ + --max-cu "${MAX_CU}" \ + --project-id "${PROJECT_ID}" \ + --revision "${REVISION}" \ + --connection-string "${PERF_DB_CONNSTR}" + else + echo "Warning: Missing ramp results files (CSV: $RAMPUP_CSV, JSON: $RAMPUP_JSON)" + fi + + # Deactivate and clean up virtual environment + deactivate + rm -rf temp_results_env + rm upload_results_to_perf_test_results.py + + echo "Results processing completed and environment cleaned up" + + - name: Set date for upload + id: set-date + run: echo "date=$(date +%Y-%m-%d)" >> $GITHUB_OUTPUT + + - name: Configure AWS credentials # necessary to upload results + uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2 + with: + aws-region: us-east-2 + role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} + role-duration-seconds: 900 # 900 is minimum value + + - name: Upload benchmark results to S3 + env: + S3_BUCKET: neon-public-benchmark-results + S3_PREFIX: benchbase-tpc-c/${{ steps.set-date.outputs.date }}/${{ github.run_id }}/${{ matrix.warehouses }}-warehouses + run: | + echo "Redacting passwords from configuration files before upload..." + + # Mask all passwords in XML config files + find "${RUNNER_TEMP}/configs" -name "*.xml" -type f -exec sed -i 's|[^<]*|redacted|g' {} \; + + echo "Uploading benchmark results to s3://${S3_BUCKET}/${S3_PREFIX}/" + + # Upload the entire benchmark directory recursively + aws s3 cp --only-show-errors --recursive "${RUNNER_TEMP}" s3://${S3_BUCKET}/${S3_PREFIX}/ + + echo "Upload completed" + + - name: Delete Neon Project + if: ${{ always() }} + uses: ./.github/actions/neon-project-delete + with: + project_id: ${{ steps.create-neon-project-tpcc.outputs.project_id }} + api_key: ${{ secrets.NEON_PRODUCTION_API_KEY_4_BENCHMARKS }} + api_host: console.neon.tech # production (!) \ No newline at end of file diff --git a/.github/workflows/benchmarking.yml b/.github/workflows/benchmarking.yml index df80bad579..c9a998bd4e 100644 --- a/.github/workflows/benchmarking.yml +++ b/.github/workflows/benchmarking.yml @@ -418,7 +418,7 @@ jobs: statuses: write id-token: write # aws-actions/configure-aws-credentials env: - PGBENCH_SIZE: ${{ vars.PREWARM_PGBENCH_SIZE }} + PROJECT_ID: ${{ vars.PREWARM_PROJECT_ID }} POSTGRES_DISTRIB_DIR: /tmp/neon/pg_install DEFAULT_PG_VERSION: 17 TEST_OUTPUT: /tmp/test_output diff --git a/.github/workflows/build-build-tools-image.yml b/.github/workflows/build-build-tools-image.yml index 24e4c8fa3d..5e53d8231f 100644 --- a/.github/workflows/build-build-tools-image.yml +++ b/.github/workflows/build-build-tools-image.yml @@ -146,7 +146,9 @@ jobs: with: file: build-tools/Dockerfile context: . 
- provenance: false + attests: | + type=provenance,mode=max + type=sbom,generator=docker.io/docker/buildkit-syft-scanner:1 push: true pull: true build-args: | diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index f237a991cc..0dcbd1c6dd 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -634,7 +634,9 @@ jobs: DEBIAN_VERSION=bookworm secrets: | SUBZERO_ACCESS_TOKEN=${{ secrets.CI_ACCESS_TOKEN }} - provenance: false + attests: | + type=provenance,mode=max + type=sbom,generator=docker.io/docker/buildkit-syft-scanner:1 push: true pull: true file: Dockerfile @@ -747,7 +749,9 @@ jobs: PG_VERSION=${{ matrix.version.pg }} BUILD_TAG=${{ needs.meta.outputs.release-tag || needs.meta.outputs.build-tag }} DEBIAN_VERSION=${{ matrix.version.debian }} - provenance: false + attests: | + type=provenance,mode=max + type=sbom,generator=docker.io/docker/buildkit-syft-scanner:1 push: true pull: true file: compute/compute-node.Dockerfile @@ -766,7 +770,9 @@ jobs: PG_VERSION=${{ matrix.version.pg }} BUILD_TAG=${{ needs.meta.outputs.release-tag || needs.meta.outputs.build-tag }} DEBIAN_VERSION=${{ matrix.version.debian }} - provenance: false + attests: | + type=provenance,mode=max + type=sbom,generator=docker.io/docker/buildkit-syft-scanner:1 push: true pull: true file: compute/compute-node.Dockerfile diff --git a/.github/workflows/pg-clients.yml b/.github/workflows/pg-clients.yml index 6efe0b4c8c..40b2c51624 100644 --- a/.github/workflows/pg-clients.yml +++ b/.github/workflows/pg-clients.yml @@ -48,8 +48,20 @@ jobs: uses: ./.github/workflows/build-build-tools-image.yml secrets: inherit + generate-ch-tmppw: + runs-on: ubuntu-22.04 + outputs: + tmp_val: ${{ steps.pwgen.outputs.tmp_val }} + steps: + - name: Generate a random password + id: pwgen + run: | + set +x + p=$(dd if=/dev/random bs=14 count=1 2>/dev/null | base64) + echo tmp_val="${p//\//}" >> "${GITHUB_OUTPUT}" + test-logical-replication: - needs: [ build-build-tools-image ] + needs: [ build-build-tools-image, generate-ch-tmppw ] runs-on: ubuntu-22.04 container: @@ -60,16 +72,21 @@ jobs: options: --init --user root services: clickhouse: - image: clickhouse/clickhouse-server:24.6.3.64 + image: clickhouse/clickhouse-server:25.6 + env: + CLICKHOUSE_PASSWORD: ${{ needs.generate-ch-tmppw.outputs.tmp_val }} + PGSSLCERT: /tmp/postgresql.crt ports: - 9000:9000 - 8123:8123 zookeeper: - image: quay.io/debezium/zookeeper:2.7 + image: quay.io/debezium/zookeeper:3.1.3.Final ports: - 2181:2181 + - 2888:2888 + - 3888:3888 kafka: - image: quay.io/debezium/kafka:2.7 + image: quay.io/debezium/kafka:3.1.3.Final env: ZOOKEEPER_CONNECT: "zookeeper:2181" KAFKA_ADVERTISED_LISTENERS: PLAINTEXT://kafka:9092 @@ -79,7 +96,7 @@ jobs: ports: - 9092:9092 debezium: - image: quay.io/debezium/connect:2.7 + image: quay.io/debezium/connect:3.1.3.Final env: BOOTSTRAP_SERVERS: kafka:9092 GROUP_ID: 1 @@ -125,6 +142,7 @@ jobs: aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} env: BENCHMARK_CONNSTR: ${{ steps.create-neon-project.outputs.dsn }} + CLICKHOUSE_PASSWORD: ${{ needs.generate-ch-tmppw.outputs.tmp_val }} - name: Delete Neon Project if: always() diff --git a/Cargo.lock b/Cargo.lock index 133ca5def9..9a0cc9076a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -211,11 +211,11 @@ dependencies = [ [[package]] name = "async-lock" -version = "3.2.0" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7125e42787d53db9dd54261812ef17e937c95a51e4d291373b670342fa44310c" 
+checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18" dependencies = [ - "event-listener 4.0.0", + "event-listener 5.4.0", "event-listener-strategy", "pin-project-lite", ] @@ -1404,9 +1404,9 @@ dependencies = [ [[package]] name = "concurrent-queue" -version = "2.3.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f057a694a54f12365049b0958a1685bb52d567f5593b355fbf685838e873d400" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" dependencies = [ "crossbeam-utils", ] @@ -2232,9 +2232,9 @@ checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" [[package]] name = "event-listener" -version = "4.0.0" +version = "5.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "770d968249b5d99410d61f5bf89057f3199a077a04d087092f58e7d10692baae" +checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae" dependencies = [ "concurrent-queue", "parking", @@ -2243,11 +2243,11 @@ dependencies = [ [[package]] name = "event-listener-strategy" -version = "0.4.0" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3" +checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" dependencies = [ - "event-listener 4.0.0", + "event-listener 5.4.0", "pin-project-lite", ] @@ -2516,6 +2516,20 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "304de19db7028420975a296ab0fcbbc8e69438c4ed254a1e41e2a7f37d5f0e0a" +[[package]] +name = "generator" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d18470a76cb7f8ff746cf1f7470914f900252ec36bbc40b569d74b1258446827" +dependencies = [ + "cc", + "cfg-if", + "libc", + "log", + "rustversion", + "windows 0.61.3", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -2834,7 +2848,7 @@ checksum = "f9c7c7c8ac16c798734b8a24560c1362120597c40d5e1459f09498f8f6c8f2ba" dependencies = [ "cfg-if", "libc", - "windows", + "windows 0.52.0", ] [[package]] @@ -3105,7 +3119,7 @@ dependencies = [ "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows-core", + "windows-core 0.52.0", ] [[package]] @@ -3656,6 +3670,19 @@ version = "0.4.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" +[[package]] +name = "loom" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" +dependencies = [ + "cfg-if", + "generator", + "scoped-tls", + "tracing", + "tracing-subscriber", +] + [[package]] name = "lru" version = "0.12.3" @@ -3872,6 +3899,25 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "moka" +version = "0.12.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9321642ca94a4282428e6ea4af8cc2ca4eac48ac7a6a4ea8f33f76d0ce70926" +dependencies = [ + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "loom", + "parking_lot 0.12.1", + "portable-atomic", + "rustc_version", + "smallvec", + "tagptr", + "thiserror 1.0.69", + "uuid", +] + [[package]] name = "multimap" version = "0.8.3" @@ -5031,8 +5077,6 @@ dependencies = [ "crc32c", "criterion", "env_logger", - "log", - "memoffset 0.9.0", "once_cell", "postgres", "postgres_ffi_types", @@ -5385,7 +5429,6 @@ 
dependencies = [ "futures", "gettid", "hashbrown 0.14.5", - "hashlink", "hex", "hmac", "hostname", @@ -5407,6 +5450,7 @@ dependencies = [ "lasso", "measured", "metrics", + "moka", "once_cell", "opentelemetry", "ouroboros", @@ -5473,6 +5517,7 @@ dependencies = [ "workspace_hack", "x509-cert", "zerocopy 0.8.24", + "zeroize", ] [[package]] @@ -6420,6 +6465,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.1.0" @@ -7269,6 +7320,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "tar" version = "0.4.40" @@ -8638,10 +8695,32 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" dependencies = [ - "windows-core", + "windows-core 0.52.0", "windows-targets 0.52.6", ] +[[package]] +name = "windows" +version = "0.61.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893" +dependencies = [ + "windows-collections", + "windows-core 0.61.2", + "windows-future", + "windows-link", + "windows-numerics", +] + +[[package]] +name = "windows-collections" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" +dependencies = [ + "windows-core 0.61.2", +] + [[package]] name = "windows-core" version = "0.52.0" @@ -8651,6 +8730,86 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-core" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-future" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" +dependencies = [ + "windows-core 0.61.2", + "windows-link", + "windows-threading", +] + +[[package]] +name = "windows-implement" +version = "0.60.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "windows-interface" +version = "0.59.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "windows-link" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" + +[[package]] +name = "windows-numerics" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" +dependencies = [ + "windows-core 0.61.2", + 
"windows-link", +] + +[[package]] +name = "windows-result" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -8709,6 +8868,15 @@ dependencies = [ "windows_x86_64_msvc 0.52.6", ] +[[package]] +name = "windows-threading" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" +dependencies = [ + "windows-link", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.0" @@ -8845,6 +9013,8 @@ dependencies = [ "clap", "clap_builder", "const-oid", + "crossbeam-epoch", + "crossbeam-utils", "crypto-bigint 0.5.5", "der 0.7.8", "deranged", @@ -8890,6 +9060,7 @@ dependencies = [ "once_cell", "p256 0.13.2", "parquet", + "portable-atomic", "prettyplease", "proc-macro2", "prost 0.13.5", diff --git a/Cargo.toml b/Cargo.toml index 18236a81f5..3f23086797 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,10 +46,10 @@ members = [ "libs/proxy/json", "libs/proxy/postgres-protocol2", "libs/proxy/postgres-types2", + "libs/proxy/subzero_core", "libs/proxy/tokio-postgres2", "endpoint_storage", "pgxn/neon/communicator", - "proxy/subzero_core", ] [workspace.package] @@ -135,7 +135,7 @@ lock_api = "0.4.13" md5 = "0.7.0" measured = { version = "0.0.22", features=["lasso"] } measured-process = { version = "0.0.22" } -memoffset = "0.9" +moka = { version = "0.12", features = ["sync"] } nix = { version = "0.30.1", features = ["dir", "fs", "mman", "process", "socket", "signal", "poll"] } # Do not update to >= 7.0.0, at least. The update will have a significant impact # on compute startup metrics (start_postgres_ms), >= 25% degradation. 
@@ -233,9 +233,10 @@ uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] } walkdir = "2.3.2" rustls-native-certs = "0.8" whoami = "1.5.1" -zerocopy = { version = "0.8", features = ["derive", "simd"] } json-structural-diff = { version = "0.2.0" } x509-cert = { version = "0.2.5" } +zerocopy = { version = "0.8", features = ["derive", "simd"] } +zeroize = "1.8" ## TODO replace this with tracing env_logger = "0.11" diff --git a/Dockerfile b/Dockerfile index 654ae72e56..63cc954873 100644 --- a/Dockerfile +++ b/Dockerfile @@ -103,7 +103,7 @@ RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \ && if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \ export CARGO_FEATURES="rest_broker"; \ fi \ - && RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo build \ + && RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo auditable build \ --features $CARGO_FEATURES \ --bin pg_sni_router \ --bin pageserver \ diff --git a/build-tools/Dockerfile b/build-tools/Dockerfile index b5fe642e6f..c9760f610b 100644 --- a/build-tools/Dockerfile +++ b/build-tools/Dockerfile @@ -39,13 +39,13 @@ COPY build-tools/patches/pgcopydbv017.patch /pgcopydbv017.patch RUN if [ "${DEBIAN_VERSION}" = "bookworm" ]; then \ set -e && \ - apt update && \ - apt install -y --no-install-recommends \ + apt-get update && \ + apt-get install -y --no-install-recommends \ ca-certificates wget gpg && \ wget -qO - https://www.postgresql.org/media/keys/ACCC4CF8.asc | gpg --dearmor -o /usr/share/keyrings/postgresql-keyring.gpg && \ echo "deb [signed-by=/usr/share/keyrings/postgresql-keyring.gpg] http://apt.postgresql.org/pub/repos/apt bookworm-pgdg main" > /etc/apt/sources.list.d/pgdg.list && \ apt-get update && \ - apt install -y --no-install-recommends \ + apt-get install -y --no-install-recommends \ build-essential \ autotools-dev \ libedit-dev \ @@ -89,8 +89,7 @@ RUN useradd -ms /bin/bash nonroot -b /home # Use strict mode for bash to catch errors early SHELL ["/bin/bash", "-euo", "pipefail", "-c"] -RUN mkdir -p /pgcopydb/bin && \ - mkdir -p /pgcopydb/lib && \ +RUN mkdir -p /pgcopydb/{bin,lib} && \ chmod -R 755 /pgcopydb && \ chown -R nonroot:nonroot /pgcopydb @@ -106,8 +105,8 @@ RUN echo 'Acquire::Retries "5";' > /etc/apt/apt.conf.d/80-retries && \ # 'gdb' is included so that we get backtraces of core dumps produced in # regression tests RUN set -e \ - && apt update \ - && apt install -y \ + && apt-get update \ + && apt-get install -y --no-install-recommends \ autoconf \ automake \ bison \ @@ -183,22 +182,22 @@ RUN curl -sL "https://github.com/peak/s5cmd/releases/download/v${S5CMD_VERSION}/ ENV LLVM_VERSION=20 RUN curl -fsSL 'https://apt.llvm.org/llvm-snapshot.gpg.key' | apt-key add - \ && echo "deb http://apt.llvm.org/${DEBIAN_VERSION}/ llvm-toolchain-${DEBIAN_VERSION}-${LLVM_VERSION} main" > /etc/apt/sources.list.d/llvm.stable.list \ - && apt update \ - && apt install -y clang-${LLVM_VERSION} llvm-${LLVM_VERSION} \ + && apt-get update \ + && apt-get install -y --no-install-recommends clang-${LLVM_VERSION} llvm-${LLVM_VERSION} \ && bash -c 'for f in /usr/bin/clang*-${LLVM_VERSION} /usr/bin/llvm*-${LLVM_VERSION}; do ln -s "${f}" "${f%-${LLVM_VERSION}}"; done' \ && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* # Install node ENV NODE_VERSION=24 RUN curl -fsSL https://deb.nodesource.com/setup_${NODE_VERSION}.x | bash - \ - && apt install -y nodejs \ + && apt-get 
install -y --no-install-recommends nodejs \ && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* # Install docker RUN curl -fsSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /usr/share/keyrings/docker-archive-keyring.gpg \ && echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/debian ${DEBIAN_VERSION} stable" > /etc/apt/sources.list.d/docker.list \ - && apt update \ - && apt install -y docker-ce docker-ce-cli \ + && apt-get update \ + && apt-get install -y --no-install-recommends docker-ce docker-ce-cli \ && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* # Configure sudo & docker @@ -215,12 +214,11 @@ RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-$(uname -m).zip" -o "aws # Mold: A Modern Linker ENV MOLD_VERSION=v2.37.1 RUN set -e \ - && git clone https://github.com/rui314/mold.git \ + && git clone -b "${MOLD_VERSION}" --depth 1 https://github.com/rui314/mold.git \ && mkdir mold/build \ - && cd mold/build \ - && git checkout ${MOLD_VERSION} \ + && cd mold/build \ && cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_COMPILER=clang++ .. \ - && cmake --build . -j $(nproc) \ + && cmake --build . -j "$(nproc)" \ && cmake --install . \ && cd .. \ && rm -rf mold @@ -254,7 +252,7 @@ ENV ICU_VERSION=67.1 ENV ICU_PREFIX=/usr/local/icu # Download and build static ICU -RUN wget -O /tmp/libicu-${ICU_VERSION}.tgz https://github.com/unicode-org/icu/releases/download/release-${ICU_VERSION//./-}/icu4c-${ICU_VERSION//./_}-src.tgz && \ +RUN wget -O "/tmp/libicu-${ICU_VERSION}.tgz" https://github.com/unicode-org/icu/releases/download/release-${ICU_VERSION//./-}/icu4c-${ICU_VERSION//./_}-src.tgz && \ echo "94a80cd6f251a53bd2a997f6f1b5ac6653fe791dfab66e1eb0227740fb86d5dc /tmp/libicu-${ICU_VERSION}.tgz" | sha256sum --check && \ mkdir /tmp/icu && \ pushd /tmp/icu && \ @@ -265,8 +263,7 @@ RUN wget -O /tmp/libicu-${ICU_VERSION}.tgz https://github.com/unicode-org/icu/re make install && \ popd && \ rm -rf icu && \ - rm -f /tmp/libicu-${ICU_VERSION}.tgz && \ - popd + rm -f /tmp/libicu-${ICU_VERSION}.tgz # Switch to nonroot user USER nonroot:nonroot @@ -279,19 +276,19 @@ ENV PYTHON_VERSION=3.11.12 \ PYENV_ROOT=/home/nonroot/.pyenv \ PATH=/home/nonroot/.pyenv/shims:/home/nonroot/.pyenv/bin:/home/nonroot/.poetry/bin:$PATH RUN set -e \ - && cd $HOME \ + && cd "$HOME" \ && curl -sSO https://raw.githubusercontent.com/pyenv/pyenv-installer/master/bin/pyenv-installer \ && chmod +x pyenv-installer \ && ./pyenv-installer \ && export PYENV_ROOT=/home/nonroot/.pyenv \ && export PATH="$PYENV_ROOT/bin:$PATH" \ && export PATH="$PYENV_ROOT/shims:$PATH" \ - && pyenv install ${PYTHON_VERSION} \ - && pyenv global ${PYTHON_VERSION} \ + && pyenv install "${PYTHON_VERSION}" \ + && pyenv global "${PYTHON_VERSION}" \ && python --version \ - && pip install --upgrade pip \ + && pip install --no-cache-dir --upgrade pip \ && pip --version \ - && pip install pipenv wheel poetry + && pip install --no-cache-dir pipenv wheel poetry # Switch to nonroot user (again) USER nonroot:nonroot @@ -302,6 +299,7 @@ WORKDIR /home/nonroot ENV RUSTC_VERSION=1.88.0 ENV RUSTUP_HOME="/home/nonroot/.rustup" ENV PATH="/home/nonroot/.cargo/bin:${PATH}" +ARG CARGO_AUDITABLE_VERSION=0.7.0 ARG RUSTFILT_VERSION=0.2.1 ARG CARGO_HAKARI_VERSION=0.9.36 ARG CARGO_DENY_VERSION=0.18.2 @@ -317,14 +315,16 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux . 
"$HOME/.cargo/env" && \ cargo --version && rustup --version && \ rustup component add llvm-tools rustfmt clippy && \ - cargo install rustfilt --locked --version ${RUSTFILT_VERSION} && \ - cargo install cargo-hakari --locked --version ${CARGO_HAKARI_VERSION} && \ - cargo install cargo-deny --locked --version ${CARGO_DENY_VERSION} && \ - cargo install cargo-hack --locked --version ${CARGO_HACK_VERSION} && \ - cargo install cargo-nextest --locked --version ${CARGO_NEXTEST_VERSION} && \ - cargo install cargo-chef --locked --version ${CARGO_CHEF_VERSION} && \ - cargo install diesel_cli --locked --version ${CARGO_DIESEL_CLI_VERSION} \ - --features postgres-bundled --no-default-features && \ + cargo install cargo-auditable --locked --version "${CARGO_AUDITABLE_VERSION}" && \ + cargo auditable install cargo-auditable --locked --version "${CARGO_AUDITABLE_VERSION}" --force && \ + cargo auditable install rustfilt --version "${RUSTFILT_VERSION}" && \ + cargo auditable install cargo-hakari --locked --version "${CARGO_HAKARI_VERSION}" && \ + cargo auditable install cargo-deny --locked --version "${CARGO_DENY_VERSION}" && \ + cargo auditable install cargo-hack --locked --version "${CARGO_HACK_VERSION}" && \ + cargo auditable install cargo-nextest --locked --version "${CARGO_NEXTEST_VERSION}" && \ + cargo auditable install cargo-chef --locked --version "${CARGO_CHEF_VERSION}" && \ + cargo auditable install diesel_cli --locked --version "${CARGO_DIESEL_CLI_VERSION}" \ + --features postgres-bundled --no-default-features && \ rm -rf /home/nonroot/.cargo/registry && \ rm -rf /home/nonroot/.cargo/git diff --git a/compute/patches/pg_repack.patch b/compute/patches/pg_repack.patch index 10ed1054ff..b8a057e222 100644 --- a/compute/patches/pg_repack.patch +++ b/compute/patches/pg_repack.patch @@ -1,5 +1,11 @@ +commit 5eb393810cf7c7bafa4e394dad2e349e2a8cb2cb +Author: Alexey Masterov +Date: Mon Jul 28 18:11:02 2025 +0200 + + Patch for pg_repack + diff --git a/regress/Makefile b/regress/Makefile -index bf6edcb..89b4c7f 100644 +index bf6edcb..110e734 100644 --- a/regress/Makefile +++ b/regress/Makefile @@ -17,7 +17,7 @@ INTVERSION := $(shell echo $$(($$(echo $(VERSION).0 | sed 's/\([[:digit:]]\{1,\} @@ -7,18 +13,36 @@ index bf6edcb..89b4c7f 100644 # -REGRESS := init-extension repack-setup repack-run error-on-invalid-idx no-error-on-invalid-idx after-schema repack-check nosuper tablespace get_order_by trigger -+REGRESS := init-extension repack-setup repack-run error-on-invalid-idx no-error-on-invalid-idx after-schema repack-check nosuper get_order_by trigger ++REGRESS := init-extension noautovacuum repack-setup repack-run error-on-invalid-idx no-error-on-invalid-idx after-schema repack-check nosuper get_order_by trigger autovacuum USE_PGXS = 1 # use pgxs if not in contrib directory PGXS := $(shell $(PG_CONFIG) --pgxs) -diff --git a/regress/expected/init-extension.out b/regress/expected/init-extension.out -index 9f2e171..f6e4f8d 100644 ---- a/regress/expected/init-extension.out -+++ b/regress/expected/init-extension.out -@@ -1,3 +1,2 @@ - SET client_min_messages = warning; - CREATE EXTENSION pg_repack; --RESET client_min_messages; +diff --git a/regress/expected/autovacuum.out b/regress/expected/autovacuum.out +new file mode 100644 +index 0000000..e7f2363 +--- /dev/null ++++ b/regress/expected/autovacuum.out +@@ -0,0 +1,7 @@ ++ALTER SYSTEM SET autovacuum='on'; ++SELECT pg_reload_conf(); ++ pg_reload_conf ++---------------- ++ t ++(1 row) ++ +diff --git a/regress/expected/noautovacuum.out 
b/regress/expected/noautovacuum.out +new file mode 100644 +index 0000000..fc7978e +--- /dev/null ++++ b/regress/expected/noautovacuum.out +@@ -0,0 +1,7 @@ ++ALTER SYSTEM SET autovacuum='off'; ++SELECT pg_reload_conf(); ++ pg_reload_conf ++---------------- ++ t ++(1 row) ++ diff --git a/regress/expected/nosuper.out b/regress/expected/nosuper.out index 8d0a94e..63b68bf 100644 --- a/regress/expected/nosuper.out @@ -50,14 +74,22 @@ index 8d0a94e..63b68bf 100644 INFO: repacking table "public.tbl_cluster" ERROR: query failed: ERROR: current transaction is aborted, commands ignored until end of transaction block DETAIL: query was: RESET lock_timeout -diff --git a/regress/sql/init-extension.sql b/regress/sql/init-extension.sql -index 9f2e171..f6e4f8d 100644 ---- a/regress/sql/init-extension.sql -+++ b/regress/sql/init-extension.sql -@@ -1,3 +1,2 @@ - SET client_min_messages = warning; - CREATE EXTENSION pg_repack; --RESET client_min_messages; +diff --git a/regress/sql/autovacuum.sql b/regress/sql/autovacuum.sql +new file mode 100644 +index 0000000..a8eda63 +--- /dev/null ++++ b/regress/sql/autovacuum.sql +@@ -0,0 +1,2 @@ ++ALTER SYSTEM SET autovacuum='on'; ++SELECT pg_reload_conf(); +diff --git a/regress/sql/noautovacuum.sql b/regress/sql/noautovacuum.sql +new file mode 100644 +index 0000000..13d4836 +--- /dev/null ++++ b/regress/sql/noautovacuum.sql +@@ -0,0 +1,2 @@ ++ALTER SYSTEM SET autovacuum='off'; ++SELECT pg_reload_conf(); diff --git a/regress/sql/nosuper.sql b/regress/sql/nosuper.sql index 072f0fa..dbe60f8 100644 --- a/regress/sql/nosuper.sql diff --git a/compute_tools/README.md b/compute_tools/README.md index 446b441c18..e92e5920b9 100644 --- a/compute_tools/README.md +++ b/compute_tools/README.md @@ -54,11 +54,11 @@ stateDiagram-v2 Running --> TerminationPendingImmediate : Requested termination Running --> ConfigurationPending : Received a /configure request with spec Running --> RefreshConfigurationPending : Received a /refresh_configuration request, compute node will pull a new spec and reconfigure - RefreshConfigurationPending --> Running : Compute has been re-configured + RefreshConfigurationPending --> RefreshConfiguration: Received compute spec and started configuration + RefreshConfiguration --> Running : Compute has been re-configured + RefreshConfiguration --> RefreshConfigurationPending : Configuration failed and to be retried TerminationPendingFast --> Terminated compute with 30s delay for cplane to inspect status TerminationPendingImmediate --> Terminated : Terminated compute immediately - Running --> TerminationPending : Requested termination - TerminationPending --> Terminated : Terminated compute Failed --> RefreshConfigurationPending : Received a /refresh_configuration request Failed --> [*] : Compute exited Terminated --> [*] : Compute exited diff --git a/compute_tools/src/bin/compute_ctl.rs b/compute_tools/src/bin/compute_ctl.rs index 83a2e6dc68..2b4802f309 100644 --- a/compute_tools/src/bin/compute_ctl.rs +++ b/compute_tools/src/bin/compute_ctl.rs @@ -49,10 +49,10 @@ use compute_tools::compute::{ BUILD_TAG, ComputeNode, ComputeNodeParams, forward_termination_signal, }; use compute_tools::extension_server::get_pg_version_string; -use compute_tools::logger::*; use compute_tools::params::*; use compute_tools::pg_isready::get_pg_isready_bin; use compute_tools::spec::*; +use compute_tools::{hadron_metrics, installed_extensions, logger::*}; use rlimit::{Resource, setrlimit}; use signal_hook::consts::{SIGINT, SIGQUIT, SIGTERM}; use signal_hook::iterator::Signals; @@ -82,6 
+82,15 @@ struct Cli { #[arg(long, default_value_t = 3081)] pub internal_http_port: u16, + /// Backwards-compatible --http-port for Hadron deployments. Functionally the + /// same as --external-http-port. + #[arg( + long, + conflicts_with = "external_http_port", + conflicts_with = "internal_http_port" + )] + pub http_port: Option, + #[arg(short = 'D', long, value_name = "DATADIR")] pub pgdata: String, @@ -181,6 +190,26 @@ impl Cli { } } +// Hadron helpers to get compatible compute_ctl http ports from Cli. The old `--http-port` +// arg is used and acts the same as `--external-http-port`. The internal http port is defined +// to be http_port + 1. Hadron runs in the dblet environment which uses the host network, so +// we need to be careful with the ports to choose. +fn get_external_http_port(cli: &Cli) -> u16 { + if cli.lakebase_mode { + return cli.http_port.unwrap_or(cli.external_http_port); + } + cli.external_http_port +} +fn get_internal_http_port(cli: &Cli) -> u16 { + if cli.lakebase_mode { + return cli + .http_port + .map(|p| p + 1) + .unwrap_or(cli.internal_http_port); + } + cli.internal_http_port +} + fn main() -> Result<()> { let cli = Cli::parse(); @@ -205,10 +234,18 @@ fn main() -> Result<()> { // enable core dumping for all child processes setrlimit(Resource::CORE, rlimit::INFINITY, rlimit::INFINITY)?; + if cli.lakebase_mode { + installed_extensions::initialize_metrics(); + hadron_metrics::initialize_metrics(); + } + let connstr = Url::parse(&cli.connstr).context("cannot parse connstr as a URL")?; let config = get_config(&cli)?; + let external_http_port = get_external_http_port(&cli); + let internal_http_port = get_internal_http_port(&cli); + let compute_node = ComputeNode::new( ComputeNodeParams { compute_id: cli.compute_id, @@ -217,8 +254,8 @@ fn main() -> Result<()> { pgdata: cli.pgdata.clone(), pgbin: cli.pgbin.clone(), pgversion: get_pg_version_string(&cli.pgbin), - external_http_port: cli.external_http_port, - internal_http_port: cli.internal_http_port, + external_http_port, + internal_http_port, remote_ext_base_url: cli.remote_ext_base_url.clone(), resize_swap_on_bind: cli.resize_swap_on_bind, set_disk_quota_for_fs: cli.set_disk_quota_for_fs, diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index e3ac887e9c..27d33d8cd8 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -6,7 +6,8 @@ use compute_api::responses::{ LfcPrewarmState, PromoteState, TlsConfig, }; use compute_api::spec::{ - ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PageserverProtocol, PgIdent, + ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, GenericOption, + PageserverProtocol, PgIdent, Role, }; use futures::StreamExt; use futures::future::join_all; @@ -41,8 +42,9 @@ use utils::shard::{ShardCount, ShardIndex, ShardNumber}; use crate::configurator::launch_configurator; use crate::disk_quota::set_disk_quota; +use crate::hadron_metrics::COMPUTE_ATTACHED; use crate::installed_extensions::get_installed_extensions; -use crate::logger::startup_context_from_env; +use crate::logger::{self, startup_context_from_env}; use crate::lsn_lease::launch_lsn_lease_bg_task_for_static; use crate::metrics::COMPUTE_CTL_UP; use crate::monitor::launch_monitor; @@ -412,6 +414,130 @@ struct StartVmMonitorResult { vm_monitor: Option>>, } +// BEGIN_HADRON +/// This function creates roles that are used by Databricks. +/// These roles are not needs to be botostrapped at PG Compute provisioning time. 
+/// The auth method for these roles are configured in databricks_pg_hba.conf in universe repository. +pub(crate) fn create_databricks_roles() -> Vec { + let roles = vec![ + // Role for prometheus_stats_exporter + Role { + name: "databricks_monitor".to_string(), + // This uses "local" connection and auth method for that is "trust", so no password is needed. + encrypted_password: None, + options: Some(vec![GenericOption { + name: "IN ROLE pg_monitor".to_string(), + value: None, + vartype: "string".to_string(), + }]), + }, + // Role for brickstore control plane + Role { + name: "databricks_control_plane".to_string(), + // Certificate user does not need password. + encrypted_password: None, + options: Some(vec![GenericOption { + name: "SUPERUSER".to_string(), + value: None, + vartype: "string".to_string(), + }]), + }, + // Role for brickstore httpgateway. + Role { + name: "databricks_gateway".to_string(), + // Certificate user does not need password. + encrypted_password: None, + options: None, + }, + ]; + + roles + .into_iter() + .map(|role| { + let query = format!( + r#" + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT FROM pg_catalog.pg_roles WHERE rolname = '{}') + THEN + CREATE ROLE {} {}; + END IF; + END + $$;"#, + role.name, + role.name.pg_quote(), + role.to_pg_options(), + ); + query + }) + .collect() +} + +/// Databricks-specific environment variables to be passed to the `postgres` sub-process. +pub struct DatabricksEnvVars { + /// The Databricks "endpoint ID" of the compute instance. Used by `postgres` to check + /// the token scopes of internal auth tokens. + pub endpoint_id: String, + /// Hostname of the Databricks workspace URL this compute instance belongs to. + /// Used by postgres to verify Databricks PAT tokens. + pub workspace_host: String, + + pub lakebase_mode: bool, +} + +impl DatabricksEnvVars { + pub fn new( + compute_spec: &ComputeSpec, + compute_id: Option<&String>, + instance_id: Option, + lakebase_mode: bool, + ) -> Self { + let endpoint_id = if let Some(instance_id) = instance_id { + // Use instance_id as endpoint_id if it is set. This code path is for PuPr model. + instance_id + } else { + // Use compute_id as endpoint_id if instance_id is not set. The code path is for PrPr model. + // compute_id is a string format of "{endpoint_id}/{compute_idx}" + // endpoint_id is a uuid. We only need to pass down endpoint_id to postgres. + // Panics if compute_id is not set or not in the expected format. + compute_id.unwrap().split('/').next().unwrap().to_string() + }; + let workspace_host = compute_spec + .databricks_settings + .as_ref() + .map(|s| s.databricks_workspace_host.clone()) + .unwrap_or("".to_string()); + Self { + endpoint_id, + workspace_host, + lakebase_mode, + } + } + + /// Constants for the names of Databricks-specific postgres environment variables. + const DATABRICKS_ENDPOINT_ID_ENVVAR: &'static str = "DATABRICKS_ENDPOINT_ID"; + const DATABRICKS_WORKSPACE_HOST_ENVVAR: &'static str = "DATABRICKS_WORKSPACE_HOST"; + + /// Convert DatabricksEnvVars to a list of string pairs that can be passed as env vars. Consumes `self`. + pub fn to_env_var_list(self) -> Vec<(String, String)> { + if !self.lakebase_mode { + // In neon env, we don't need to pass down the env vars to postgres. 
+ return vec![]; + } + vec![ + ( + Self::DATABRICKS_ENDPOINT_ID_ENVVAR.to_string(), + self.endpoint_id.clone(), + ), + ( + Self::DATABRICKS_WORKSPACE_HOST_ENVVAR.to_string(), + self.workspace_host.clone(), + ), + ] + } +} + impl ComputeNode { pub fn new(params: ComputeNodeParams, config: ComputeConfig) -> Result { let connstr = params.connstr.as_str(); @@ -448,7 +574,11 @@ impl ComputeNode { let mut new_state = ComputeState::new(); if let Some(spec) = config.spec { let pspec = ParsedSpec::try_from(spec).map_err(|msg| anyhow::anyhow!(msg))?; - new_state.pspec = Some(pspec); + if params.lakebase_mode { + ComputeNode::set_spec(¶ms, &mut new_state, pspec); + } else { + new_state.pspec = Some(pspec); + } } Ok(ComputeNode { @@ -1046,7 +1176,14 @@ impl ComputeNode { // If it is something different then create_dir() will error out anyway. let pgdata = &self.params.pgdata; let _ok = fs::remove_dir_all(pgdata); - fs::create_dir(pgdata)?; + if self.params.lakebase_mode { + // Ignore creation errors if the directory already exists (e.g. mounting it ahead of time). + // If it is something different then PG startup will error out anyway. + let _ok = fs::create_dir(pgdata); + } else { + fs::create_dir(pgdata)?; + } + fs::set_permissions(pgdata, fs::Permissions::from_mode(0o700))?; Ok(()) @@ -1410,6 +1547,8 @@ impl ComputeNode { let pgdata_path = Path::new(&self.params.pgdata); let tls_config = self.tls_config(&pspec.spec); + let databricks_settings = spec.databricks_settings.as_ref(); + let postgres_port = self.params.connstr.port(); // Remove/create an empty pgdata directory and put configuration there. self.create_pgdata()?; @@ -1417,8 +1556,11 @@ impl ComputeNode { pgdata_path, &self.params, &pspec.spec, + postgres_port, self.params.internal_http_port, tls_config, + databricks_settings, + self.params.lakebase_mode, )?; // Syncing safekeepers is only safe with primary nodes: if a primary @@ -1458,8 +1600,28 @@ impl ComputeNode { ) })?; - // Update pg_hba.conf received with basebackup. - update_pg_hba(pgdata_path, None)?; + if let Some(settings) = databricks_settings { + copy_tls_certificates( + &settings.pg_compute_tls_settings.key_file, + &settings.pg_compute_tls_settings.cert_file, + pgdata_path, + )?; + + // Update pg_hba.conf received with basebackup including additional databricks settings. + update_pg_hba(pgdata_path, Some(&settings.databricks_pg_hba))?; + update_pg_ident(pgdata_path, Some(&settings.databricks_pg_ident))?; + } else { + // Update pg_hba.conf received with basebackup. + update_pg_hba(pgdata_path, None)?; + } + + if let Some(databricks_settings) = spec.databricks_settings.as_ref() { + copy_tls_certificates( + &databricks_settings.pg_compute_tls_settings.key_file, + &databricks_settings.pg_compute_tls_settings.cert_file, + pgdata_path, + )?; + } // Place pg_dynshmem under /dev/shm. This allows us to use // 'dynamic_shared_memory_type = mmap' so that the files are placed in @@ -1500,7 +1662,7 @@ impl ComputeNode { // symlink doesn't affect anything. // // See https://github.com/neondatabase/autoscaling/issues/800 - std::fs::remove_dir(pgdata_path.join("pg_dynshmem"))?; + std::fs::remove_dir_all(pgdata_path.join("pg_dynshmem"))?; symlink("/dev/shm/", pgdata_path.join("pg_dynshmem"))?; match spec.mode { @@ -1515,6 +1677,12 @@ impl ComputeNode { /// Start and stop a postgres process to warm up the VM for startup. pub fn prewarm_postgres_vm_memory(&self) -> Result<()> { + if self.params.lakebase_mode { + // We are running in Hadron mode. 
Disabling this prewarming step for now as it could run + // into dblet port conflicts and also doesn't add much value with our current infra. + info!("Skipping postgres prewarming in Hadron mode"); + return Ok(()); + } info!("prewarming VM memory"); // Create pgdata @@ -1572,14 +1740,36 @@ impl ComputeNode { pub fn start_postgres(&self, storage_auth_token: Option) -> Result { let pgdata_path = Path::new(&self.params.pgdata); + let env_vars: Vec<(String, String)> = if self.params.lakebase_mode { + let databricks_env_vars = { + let state = self.state.lock().unwrap(); + let spec = &state.pspec.as_ref().unwrap().spec; + DatabricksEnvVars::new( + spec, + Some(&self.params.compute_id), + self.params.instance_id.clone(), + self.params.lakebase_mode, + ) + }; + + info!( + "Starting Postgres for databricks endpoint id: {}", + &databricks_env_vars.endpoint_id + ); + + let mut env_vars = databricks_env_vars.to_env_var_list(); + env_vars.extend(storage_auth_token.map(|t| ("NEON_AUTH_TOKEN".to_string(), t))); + env_vars + } else if let Some(storage_auth_token) = &storage_auth_token { + vec![("NEON_AUTH_TOKEN".to_owned(), storage_auth_token.to_owned())] + } else { + vec![] + }; + // Run postgres as a child process. let mut pg = maybe_cgexec(&self.params.pgbin) .args(["-D", &self.params.pgdata]) - .envs(if let Some(storage_auth_token) = &storage_auth_token { - vec![("NEON_AUTH_TOKEN", storage_auth_token)] - } else { - vec![] - }) + .envs(env_vars) .stderr(Stdio::piped()) .spawn() .expect("cannot start postgres process"); @@ -1731,7 +1921,15 @@ impl ComputeNode { /// Do initial configuration of the already started Postgres. #[instrument(skip_all)] pub fn apply_config(&self, compute_state: &ComputeState) -> Result<()> { - let conf = self.get_tokio_conn_conf(Some("compute_ctl:apply_config")); + let mut conf = self.get_tokio_conn_conf(Some("compute_ctl:apply_config")); + + if self.params.lakebase_mode { + // Set a 2-minute statement_timeout for the session applying config. The individual SQL statements + // used in apply_spec_sql() should not take long (they are just creating users and installing + // extensions). If any of them are stuck for an extended period of time it usually indicates a + // pageserver connectivity problem and we should bail out. + conf.options("-c statement_timeout=2min"); + } let conf = Arc::new(conf); let spec = Arc::new( @@ -1882,12 +2080,16 @@ impl ComputeNode { // Write new config let pgdata_path = Path::new(&self.params.pgdata); + let postgres_port = self.params.connstr.port(); config::write_postgres_conf( pgdata_path, &self.params, &spec, + postgres_port, self.params.internal_http_port, tls_config, + spec.databricks_settings.as_ref(), + self.params.lakebase_mode, )?; self.pg_reload_conf()?; @@ -1993,6 +2195,7 @@ impl ComputeNode { // wait ComputeStatus::Init | ComputeStatus::Configuration + | ComputeStatus::RefreshConfiguration | ComputeStatus::RefreshConfigurationPending | ComputeStatus::Empty => { state = self.state_changed.wait(state).unwrap(); @@ -2044,7 +2247,17 @@ impl ComputeNode { pub fn check_for_core_dumps(&self) -> Result<()> { let core_dump_dir = match std::env::consts::OS { "macos" => Path::new("/cores/"), - _ => Path::new(&self.params.pgdata), + // BEGIN HADRON + // NB: Read core dump files from a fixed location outside of + // the data directory since `compute_ctl` wipes the data directory + // across container restarts. 
+ _ => { + if self.params.lakebase_mode { + Path::new("/databricks/logs/brickstore") + } else { + Path::new(&self.params.pgdata) + } + } // END HADRON }; // Collect core dump paths if any @@ -2357,7 +2570,7 @@ LIMIT 100", if let Some(libs) = spec.cluster.settings.find("shared_preload_libraries") { libs_vec = libs .split(&[',', '\'', ' ']) - .filter(|s| *s != "neon" && !s.is_empty()) + .filter(|s| *s != "neon" && *s != "databricks_auth" && !s.is_empty()) .map(str::to_string) .collect(); } @@ -2376,7 +2589,7 @@ LIMIT 100", if let Some(libs) = shared_preload_libraries_line.split("='").nth(1) { preload_libs_vec = libs .split(&[',', '\'', ' ']) - .filter(|s| *s != "neon" && !s.is_empty()) + .filter(|s| *s != "neon" && *s != "databricks_auth" && !s.is_empty()) .map(str::to_string) .collect(); } @@ -2550,6 +2763,34 @@ LIMIT 100", ); } } + + /// Set the compute spec and update related metrics. + /// This is the central place where pspec is updated. + pub fn set_spec(params: &ComputeNodeParams, state: &mut ComputeState, pspec: ParsedSpec) { + state.pspec = Some(pspec); + ComputeNode::update_attached_metric(params, state); + let _ = logger::update_ids(¶ms.instance_id, &Some(params.compute_id.clone())); + } + + pub fn update_attached_metric(params: &ComputeNodeParams, state: &mut ComputeState) { + // Update the pg_cctl_attached gauge when all identifiers are available. + if let Some(instance_id) = ¶ms.instance_id { + if let Some(pspec) = &state.pspec { + // Clear all values in the metric + COMPUTE_ATTACHED.reset(); + + // Set new metric value + COMPUTE_ATTACHED + .with_label_values(&[ + ¶ms.compute_id, + instance_id, + &pspec.tenant_id.to_string(), + &pspec.timeline_id.to_string(), + ]) + .set(1); + } + } + } } pub async fn installed_extensions(conf: tokio_postgres::Config) -> Result<()> { diff --git a/compute_tools/src/config.rs b/compute_tools/src/config.rs index dd46353343..55a1eda0b7 100644 --- a/compute_tools/src/config.rs +++ b/compute_tools/src/config.rs @@ -7,11 +7,14 @@ use std::io::prelude::*; use std::path::Path; use compute_api::responses::TlsConfig; -use compute_api::spec::{ComputeAudit, ComputeMode, ComputeSpec, GenericOption}; +use compute_api::spec::{ + ComputeAudit, ComputeMode, ComputeSpec, DatabricksSettings, GenericOption, +}; use crate::compute::ComputeNodeParams; use crate::pg_helpers::{ - GenericOptionExt, GenericOptionsSearch, PgOptionsSerialize, escape_conf_value, + DatabricksSettingsExt as _, GenericOptionExt, GenericOptionsSearch, PgOptionsSerialize, + escape_conf_value, }; use crate::tls::{self, SERVER_CRT, SERVER_KEY}; @@ -40,12 +43,16 @@ pub fn line_in_file(path: &Path, line: &str) -> Result { } /// Create or completely rewrite configuration file specified by `path` +#[allow(clippy::too_many_arguments)] pub fn write_postgres_conf( pgdata_path: &Path, params: &ComputeNodeParams, spec: &ComputeSpec, + postgres_port: Option, extension_server_port: u16, tls_config: &Option, + databricks_settings: Option<&DatabricksSettings>, + lakebase_mode: bool, ) -> Result<()> { let path = pgdata_path.join("postgresql.conf"); // File::create() destroys the file content if it exists. @@ -285,6 +292,24 @@ pub fn write_postgres_conf( writeln!(file, "log_destination='stderr,syslog'")?; } + if lakebase_mode { + // Explicitly set the port based on the connstr, overriding any previous port setting. + // Note: It is important that we don't specify a different port again after this. 
+ let port = postgres_port.expect("port must be present in connstr"); + writeln!(file, "port = {port}")?; + + // This is databricks specific settings. + // This should be at the end of the file but before `compute_ctl_temp_override.conf` below + // so that it can override any settings above. + // `compute_ctl_temp_override.conf` is intended to override any settings above during specific operations. + // To prevent potential breakage in the future, we keep it above `compute_ctl_temp_override.conf`. + writeln!(file, "# Databricks settings start")?; + if let Some(settings) = databricks_settings { + writeln!(file, "{}", settings.as_pg_settings())?; + } + writeln!(file, "# Databricks settings end")?; + } + // This is essential to keep this line at the end of the file, // because it is intended to override any settings above. writeln!(file, "include_if_exists = 'compute_ctl_temp_override.conf'")?; diff --git a/compute_tools/src/configurator.rs b/compute_tools/src/configurator.rs index 864335fd2c..feca8337b2 100644 --- a/compute_tools/src/configurator.rs +++ b/compute_tools/src/configurator.rs @@ -2,6 +2,7 @@ use std::fs::File; use std::thread; use std::{path::Path, sync::Arc}; +use anyhow::Result; use compute_api::responses::{ComputeConfig, ComputeStatus}; use tracing::{error, info, instrument}; @@ -13,6 +14,10 @@ fn configurator_main_loop(compute: &Arc) { info!("waiting for reconfiguration requests"); loop { let mut state = compute.state.lock().unwrap(); + /* BEGIN_HADRON */ + // RefreshConfiguration should only be used inside the loop + assert_ne!(state.status, ComputeStatus::RefreshConfiguration); + /* END_HADRON */ if compute.params.lakebase_mode { while state.status != ComputeStatus::ConfigurationPending @@ -54,53 +59,81 @@ fn configurator_main_loop(compute: &Arc) { info!( "compute node suspects its configuration is out of date, now refreshing configuration" ); - // Drop the lock guard here to avoid holding the lock while downloading spec from the control plane / HCC. - // This is the only thread that can move compute_ctl out of the `RefreshConfigurationPending` state, so it + state.set_status(ComputeStatus::RefreshConfiguration, &compute.state_changed); + // Drop the lock guard here to avoid holding the lock while downloading config from the control plane / HCC. + // This is the only thread that can move compute_ctl out of the `RefreshConfiguration` state, so it // is safe to drop the lock like this. drop(state); - let spec = if let Some(config_path) = &compute.params.config_path_test_only { - // This path is only to make testing easier. In production we always get the spec from the HCC. - info!( - "reloading config.json from path: {}", - config_path.to_string_lossy() - ); - let path = Path::new(config_path); - if let Ok(file) = File::open(path) { - match serde_json::from_reader::(file) { - Ok(config) => config.spec, - Err(e) => { - error!("could not parse spec file: {}", e); - None - } - } - } else { - error!( - "could not open config file at path: {}", + let get_config_result: anyhow::Result = + if let Some(config_path) = &compute.params.config_path_test_only { + // This path is only to make testing easier. In production we always get the config from the HCC. 
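Because later lines in `postgresql.conf` override earlier ones, correctness of the block above is purely an ordering invariant: the Databricks settings must land after the generic settings and before the `include_if_exists` override. A hypothetical test sketch that checks the generated file for that ordering:

```rust
// Hypothetical ordering check for the generated postgresql.conf: the last occurrence of a
// setting wins, so the Databricks block must appear before the `include_if_exists` line
// that compute_ctl uses for temporary overrides.
fn assert_conf_ordering(conf: &str) {
    let databricks_block = conf
        .find("# Databricks settings start")
        .expect("missing Databricks settings block");
    let temp_override = conf
        .find("include_if_exists = 'compute_ctl_temp_override.conf'")
        .expect("missing compute_ctl temp override include");
    assert!(
        databricks_block < temp_override,
        "Databricks settings must precede the temp override include"
    );
}
```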
+ info!( + "reloading config.json from path: {}", config_path.to_string_lossy() ); - None - } - } else if let Some(control_plane_uri) = &compute.params.control_plane_uri { - match get_config_from_control_plane(control_plane_uri, &compute.params.compute_id) { - Ok(config) => config.spec, - Err(e) => { - error!("could not get config from control plane: {}", e); - None + let path = Path::new(config_path); + if let Ok(file) = File::open(path) { + match serde_json::from_reader::(file) { + Ok(config) => Ok(config), + Err(e) => { + error!("could not parse config file: {}", e); + Err(anyhow::anyhow!("could not parse config file: {}", e)) + } + } + } else { + error!( + "could not open config file at path: {:?}", + config_path.to_string_lossy() + ); + Err(anyhow::anyhow!( + "could not open config file at path: {}", + config_path.to_string_lossy() + )) } - } - } else { - None - }; + } else if let Some(control_plane_uri) = &compute.params.control_plane_uri { + get_config_from_control_plane(control_plane_uri, &compute.params.compute_id) + } else { + Err(anyhow::anyhow!("config_path_test_only is not set")) + }; - if let Some(spec) = spec { - if let Ok(pspec) = ParsedSpec::try_from(spec) { + // Parse any received ComputeSpec and transpose the result into a Result>. + let parsed_spec_result: Result> = + get_config_result.and_then(|config| { + if let Some(spec) = config.spec { + if let Ok(pspec) = ParsedSpec::try_from(spec) { + Ok(Some(pspec)) + } else { + Err(anyhow::anyhow!("could not parse spec")) + } + } else { + Ok(None) + } + }); + + let new_status: ComputeStatus; + match parsed_spec_result { + // Control plane (HCM) returned a spec and we were able to parse it. + Ok(Some(pspec)) => { { let mut state = compute.state.lock().unwrap(); // Defensive programming to make sure this thread is indeed the only one that can move the compute - // node out of the `RefreshConfigurationPending` state. Would be nice if we can encode this invariant + // node out of the `RefreshConfiguration` state. Would be nice if we can encode this invariant // into the type system. - assert_eq!(state.status, ComputeStatus::RefreshConfigurationPending); + assert_eq!(state.status, ComputeStatus::RefreshConfiguration); + + if state.pspec.as_ref().map(|ps| ps.pageserver_connstr.clone()) + == Some(pspec.pageserver_connstr.clone()) + { + info!( + "Refresh configuration: Retrieved spec is the same as the current spec. Waiting for control plane to update the spec before attempting reconfiguration." + ); + state.status = ComputeStatus::Running; + compute.state_changed.notify_all(); + drop(state); + std::thread::sleep(std::time::Duration::from_secs(5)); + continue; + } // state.pspec is consumed by compute.reconfigure() below. Note that compute.reconfigure() will acquire // the compute.state lock again so we need to have the lock guard go out of scope here. We could add a // "locked" variant of compute.reconfigure() that takes the lock guard as an argument to make this cleaner, @@ -110,20 +143,45 @@ fn configurator_main_loop(compute: &Arc) { match compute.reconfigure() { Ok(_) => { info!("Refresh configuration: compute node configured"); - compute.set_status(ComputeStatus::Running); + new_status = ComputeStatus::Running; } Err(e) => { error!( "Refresh configuration: could not configure compute node: {}", e ); - // Leave the compute node in the `RefreshConfigurationPending` state if the configuration + // Set the compute node back to the `RefreshConfigurationPending` state if the configuration // was not successful. 
It should be okay to treat this situation the same as if the loop // hasn't executed yet as long as the detection side keeps notifying. + new_status = ComputeStatus::RefreshConfigurationPending; } } } + // Control plane (HCM)'s response does not contain a spec. This is the "Empty" attachment case. + Ok(None) => { + info!( + "Compute Manager signaled that this compute is no longer attached to any storage. Exiting." + ); + // We just immediately terminate the whole compute_ctl in this case. It's not necessary to attempt a + // clean shutdown as Postgres is probably not responding anyway (which is why we are in this refresh + // configuration state). + std::process::exit(1); + } + // Various error cases: + // - The request to the control plane (HCM) either failed or returned a malformed spec. + // - compute_ctl itself is configured incorrectly (e.g., compute_id is not set). + Err(e) => { + error!( + "Refresh configuration: error getting a parsed spec: {:?}", + e + ); + new_status = ComputeStatus::RefreshConfigurationPending; + // We may be dealing with an overloaded HCM if we end up in this path. Backoff 5 seconds before + // retrying to avoid hammering the HCM. + std::thread::sleep(std::time::Duration::from_secs(5)); + } } + compute.set_status(new_status); } else if state.status == ComputeStatus::Failed { info!("compute node is now in Failed state, exiting"); break; diff --git a/compute_tools/src/http/routes/configure.rs b/compute_tools/src/http/routes/configure.rs index b7325d283f..943ff45357 100644 --- a/compute_tools/src/http/routes/configure.rs +++ b/compute_tools/src/http/routes/configure.rs @@ -43,7 +43,12 @@ pub(in crate::http) async fn configure( // configure request for tracing purposes. state.startup_span = Some(tracing::Span::current()); - state.pspec = Some(pspec); + if compute.params.lakebase_mode { + ComputeNode::set_spec(&compute.params, &mut state, pspec); + } else { + state.pspec = Some(pspec); + } + state.set_status(ComputeStatus::ConfigurationPending, &compute.state_changed); drop(state); } diff --git a/compute_tools/src/http/routes/metrics.rs b/compute_tools/src/http/routes/metrics.rs index 96b464fd12..8406746327 100644 --- a/compute_tools/src/http/routes/metrics.rs +++ b/compute_tools/src/http/routes/metrics.rs @@ -13,6 +13,7 @@ use metrics::{Encoder, TextEncoder}; use crate::communicator_socket_client::connect_communicator_socket; use crate::compute::ComputeNode; +use crate::hadron_metrics; use crate::http::JsonResponse; use crate::metrics::collect; @@ -21,11 +22,18 @@ pub(in crate::http) async fn get_metrics() -> Response { // When we call TextEncoder::encode() below, it will immediately return an // error if a metric family has no metrics, so we need to preemptively // filter out metric families with no metrics. - let metrics = collect() + let mut metrics = collect() .into_iter() .filter(|m| !m.get_metric().is_empty()) .collect::>(); + // Add Hadron metrics. 
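Taken together, the branches above implement a small state machine around the new `RefreshConfiguration` status. A simplified sketch of the transition rules, using a local stand-in enum (not a type from this diff) for the possible outcomes of one refresh attempt:

```rust
// Simplified sketch of the transitions implemented by the refresh loop above.
// `RefreshOutcome` is a local stand-in, not a type from the diff.
enum RefreshOutcome {
    Reconfigured,    // new spec fetched, parsed, and applied
    SpecUnchanged,   // control plane still returns the old pageserver_connstr
    EmptyAttachment, // control plane reports no attached storage
    Failed,          // fetch, parse, or reconfigure error
}

fn next_step(outcome: RefreshOutcome) -> &'static str {
    match outcome {
        RefreshOutcome::Reconfigured => "status -> Running",
        RefreshOutcome::SpecUnchanged => "status -> Running, sleep 5s, re-check later",
        RefreshOutcome::EmptyAttachment => "exit the compute_ctl process",
        RefreshOutcome::Failed => "status -> RefreshConfigurationPending, sleep 5s, retry",
    }
}
```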
+ let hadron_metrics: Vec = hadron_metrics::collect() + .into_iter() + .filter(|m| !m.get_metric().is_empty()) + .collect(); + metrics.extend(hadron_metrics); + let encoder = TextEncoder::new(); let mut buffer = vec![]; diff --git a/compute_tools/src/http/routes/refresh_configuration.rs b/compute_tools/src/http/routes/refresh_configuration.rs index 512abaa0a6..9b2f95ca5a 100644 --- a/compute_tools/src/http/routes/refresh_configuration.rs +++ b/compute_tools/src/http/routes/refresh_configuration.rs @@ -9,7 +9,7 @@ use axum::{ use http::StatusCode; use crate::compute::ComputeNode; -// use crate::hadron_metrics::POSTGRES_PAGESTREAM_REQUEST_ERRORS; +use crate::hadron_metrics::POSTGRES_PAGESTREAM_REQUEST_ERRORS; use crate::http::JsonResponse; /// The /refresh_configuration POST method is used to nudge compute_ctl to pull a new spec @@ -21,6 +21,7 @@ use crate::http::JsonResponse; pub(in crate::http) async fn refresh_configuration( State(compute): State>, ) -> Response { + POSTGRES_PAGESTREAM_REQUEST_ERRORS.inc(); match compute.signal_refresh_configuration().await { Ok(_) => StatusCode::OK.into_response(), Err(e) => JsonResponse::error(StatusCode::INTERNAL_SERVER_ERROR, e), diff --git a/compute_tools/src/http/routes/terminate.rs b/compute_tools/src/http/routes/terminate.rs index 5b30b020c8..deac760f43 100644 --- a/compute_tools/src/http/routes/terminate.rs +++ b/compute_tools/src/http/routes/terminate.rs @@ -1,7 +1,7 @@ use crate::compute::{ComputeNode, forward_termination_signal}; use crate::http::JsonResponse; use axum::extract::State; -use axum::response::Response; +use axum::response::{IntoResponse, Response}; use axum_extra::extract::OptionalQuery; use compute_api::responses::{ComputeStatus, TerminateMode, TerminateResponse}; use http::StatusCode; @@ -33,7 +33,29 @@ pub(in crate::http) async fn terminate( if !matches!(state.status, ComputeStatus::Empty | ComputeStatus::Running) { return JsonResponse::invalid_status(state.status); } + + // If compute is Empty, there's no Postgres to terminate. The regular compute_ctl termination path + // assumes Postgres to be configured and running, so we just special-handle this case by exiting + // the process directly. + if compute.params.lakebase_mode && state.status == ComputeStatus::Empty { + drop(state); + info!("terminating empty compute - will exit process"); + + // Queue a task to exit the process after 5 seconds. The 5-second delay aims to + // give enough time for the HTTP response to be sent so that HCM doesn't get an abrupt + // connection termination. + tokio::spawn(async { + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; + info!("exiting process after terminating empty compute"); + std::process::exit(0); + }); + + return StatusCode::OK.into_response(); + } + + // For Running status, proceed with normal termination state.set_status(mode.into(), &compute.state_changed); + drop(state); } forward_termination_signal(false); diff --git a/compute_tools/src/spec.rs b/compute_tools/src/spec.rs index d00f86a2c0..2dda9c3134 100644 --- a/compute_tools/src/spec.rs +++ b/compute_tools/src/spec.rs @@ -142,7 +142,7 @@ pub fn update_pg_hba(pgdata_path: &Path, databricks_pg_hba: Option<&String>) -> // Update pg_hba to contains databricks specfic settings before adding neon settings // PG uses the first record that matches to perform authentication, so we need to have // our rules before the default ones from neon. 
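The `/metrics` change above concatenates the Hadron metric families with the regular ones before encoding. A minimal sketch of the same merge-and-filter step using the `prometheus` crate directly, with the two `Vec<MetricFamily>` arguments standing in for the outputs of the `metrics` and `hadron_metrics` collectors:

```rust
use prometheus::proto::MetricFamily;
use prometheus::{Encoder, TextEncoder};

// Minimal sketch: merge two sets of metric families and drop empty ones before encoding,
// because TextEncoder::encode() errors out on a family that contains no metrics.
fn render_metrics(mut base: Vec<MetricFamily>, extra: Vec<MetricFamily>) -> Vec<u8> {
    base.extend(extra);
    let families: Vec<MetricFamily> = base
        .into_iter()
        .filter(|m| !m.get_metric().is_empty())
        .collect();

    let mut buffer = Vec::new();
    TextEncoder::new()
        .encode(&families, &mut buffer)
        .expect("encoding non-empty metric families should not fail");
    buffer
}
```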
- // See https://www.postgresql.org/docs/16/auth-pg-hba-conf.html + // See https://www.postgresql.org/docs/current/auth-pg-hba-conf.html if let Some(databricks_pg_hba) = databricks_pg_hba { if config::line_in_file( &pghba_path, diff --git a/compute_tools/src/spec_apply.rs b/compute_tools/src/spec_apply.rs index 47bf61ae1b..2356078703 100644 --- a/compute_tools/src/spec_apply.rs +++ b/compute_tools/src/spec_apply.rs @@ -13,17 +13,19 @@ use tokio_postgres::Client; use tokio_postgres::error::SqlState; use tracing::{Instrument, debug, error, info, info_span, instrument, warn}; -use crate::compute::{ComputeNode, ComputeNodeParams, ComputeState}; +use crate::compute::{ComputeNode, ComputeNodeParams, ComputeState, create_databricks_roles}; +use crate::hadron_metrics::COMPUTE_CONFIGURE_STATEMENT_TIMEOUT_ERRORS; use crate::pg_helpers::{ DatabaseExt, Escaping, GenericOptionsSearch, RoleExt, get_existing_dbs_async, get_existing_roles_async, }; use crate::spec_apply::ApplySpecPhase::{ - CreateAndAlterDatabases, CreateAndAlterRoles, CreateAvailabilityCheck, CreatePgauditExtension, + AddDatabricksGrants, AlterDatabricksRoles, CreateAndAlterDatabases, CreateAndAlterRoles, + CreateAvailabilityCheck, CreateDatabricksMisc, CreateDatabricksRoles, CreatePgauditExtension, CreatePgauditlogtofileExtension, CreatePrivilegedRole, CreateSchemaNeon, DisablePostgresDBPgAudit, DropInvalidDatabases, DropRoles, FinalizeDropLogicalSubscriptions, - HandleNeonExtension, HandleOtherExtensions, RenameAndDeleteDatabases, RenameRoles, - RunInEachDatabase, + HandleDatabricksAuthExtension, HandleNeonExtension, HandleOtherExtensions, + RenameAndDeleteDatabases, RenameRoles, RunInEachDatabase, }; use crate::spec_apply::PerDatabasePhase::{ ChangeSchemaPerms, DeleteDBRoleReferences, DropLogicalSubscriptions, @@ -166,6 +168,7 @@ impl ComputeNode { concurrency_token.clone(), db, [DropLogicalSubscriptions].to_vec(), + self.params.lakebase_mode, ); Ok(tokio::spawn(fut)) @@ -186,15 +189,33 @@ impl ComputeNode { }; } - for phase in [ - CreatePrivilegedRole, + let phases = if self.params.lakebase_mode { + vec![ + CreatePrivilegedRole, + // BEGIN_HADRON + CreateDatabricksRoles, + AlterDatabricksRoles, + // END_HADRON DropInvalidDatabases, RenameRoles, CreateAndAlterRoles, RenameAndDeleteDatabases, CreateAndAlterDatabases, CreateSchemaNeon, - ] { + ] + } else { + vec![ + CreatePrivilegedRole, + DropInvalidDatabases, + RenameRoles, + CreateAndAlterRoles, + RenameAndDeleteDatabases, + CreateAndAlterDatabases, + CreateSchemaNeon, + ] + }; + + for phase in phases { info!("Applying phase {:?}", &phase); apply_operations( params.clone(), @@ -203,6 +224,7 @@ impl ComputeNode { jwks_roles.clone(), phase, || async { Ok(&client) }, + self.params.lakebase_mode, ) .await?; } @@ -254,6 +276,7 @@ impl ComputeNode { concurrency_token.clone(), db, phases, + self.params.lakebase_mode, ); Ok(tokio::spawn(fut)) @@ -265,12 +288,28 @@ impl ComputeNode { handle.await??; } - let mut phases = vec![ + let mut phases = if self.params.lakebase_mode { + vec![ + HandleOtherExtensions, + HandleNeonExtension, // This step depends on CreateSchemaNeon + // BEGIN_HADRON + HandleDatabricksAuthExtension, + // END_HADRON + CreateAvailabilityCheck, + DropRoles, + // BEGIN_HADRON + AddDatabricksGrants, + CreateDatabricksMisc, + // END_HADRON + ] + } else { + vec![ HandleOtherExtensions, HandleNeonExtension, // This step depends on CreateSchemaNeon CreateAvailabilityCheck, DropRoles, - ]; + ] + }; // This step depends on CreateSchemaNeon if spec.drop_subscriptions_before_start && 
!drop_subscriptions_done { @@ -303,6 +342,7 @@ impl ComputeNode { jwks_roles.clone(), phase, || async { Ok(&client) }, + self.params.lakebase_mode, ) .await?; } @@ -328,6 +368,7 @@ impl ComputeNode { concurrency_token: Arc, db: DB, subphases: Vec, + lakebase_mode: bool, ) -> Result<()> { let _permit = concurrency_token.acquire().await?; @@ -355,6 +396,7 @@ impl ComputeNode { let client = client_conn.as_ref().unwrap(); Ok(client) }, + lakebase_mode, ) .await?; } @@ -477,6 +519,10 @@ pub enum PerDatabasePhase { #[derive(Clone, Debug)] pub enum ApplySpecPhase { CreatePrivilegedRole, + // BEGIN_HADRON + CreateDatabricksRoles, + AlterDatabricksRoles, + // END_HADRON DropInvalidDatabases, RenameRoles, CreateAndAlterRoles, @@ -489,7 +535,14 @@ pub enum ApplySpecPhase { DisablePostgresDBPgAudit, HandleOtherExtensions, HandleNeonExtension, + // BEGIN_HADRON + HandleDatabricksAuthExtension, + // END_HADRON CreateAvailabilityCheck, + // BEGIN_HADRON + AddDatabricksGrants, + CreateDatabricksMisc, + // END_HADRON DropRoles, FinalizeDropLogicalSubscriptions, } @@ -525,6 +578,7 @@ pub async fn apply_operations<'a, Fut, F>( jwks_roles: Arc>, apply_spec_phase: ApplySpecPhase, client: F, + lakebase_mode: bool, ) -> Result<()> where F: FnOnce() -> Fut, @@ -571,6 +625,23 @@ where }, query ); + if !lakebase_mode { + return res; + } + // BEGIN HADRON + if let Err(e) = res.as_ref() { + if let Some(sql_state) = e.code() { + if sql_state.code() == "57014" { + // SQL State 57014 (ERRCODE_QUERY_CANCELED) is used for statement timeouts. + // Increment the counter whenever a statement timeout occurs. Timeouts on + // this configuration path can only occur due to PS connectivity problems that + // Postgres failed to recover from. + COMPUTE_CONFIGURE_STATEMENT_TIMEOUT_ERRORS.inc(); + } + } + } + // END HADRON + res } .instrument(inspan) @@ -612,6 +683,35 @@ async fn get_operations<'a>( ), comment: None, }))), + // BEGIN_HADRON + // New Hadron phase + ApplySpecPhase::CreateDatabricksRoles => { + let queries = create_databricks_roles(); + let operations = queries.into_iter().map(|query| Operation { + query, + comment: None, + }); + Ok(Box::new(operations)) + } + + // Backfill existing databricks_reader_* roles with statement timeout from GUC + ApplySpecPhase::AlterDatabricksRoles => { + let query = String::from(include_str!( + "sql/alter_databricks_reader_roles_timeout.sql" + )); + + let operations = once(Operation { + query, + comment: Some( + "Backfill existing databricks_reader_* roles with statement timeout" + .to_string(), + ), + }); + + Ok(Box::new(operations)) + } + // End of new Hadron Phase + // END_HADRON ApplySpecPhase::DropInvalidDatabases => { let mut ctx = ctx.write().await; let databases = &mut ctx.dbs; @@ -981,7 +1081,10 @@ async fn get_operations<'a>( // N.B. 
this has to be properly dollar-escaped with `pg_quote_dollar()` role_name = escaped_role, outer_tag = outer_tag, - ), + ) + // HADRON change: + .replace("neon_superuser", ¶ms.privileged_role_name), + // HADRON change end , comment: None, }, // This now will only drop privileges of the role @@ -1017,7 +1120,8 @@ async fn get_operations<'a>( comment: None, }, Operation { - query: String::from(include_str!("sql/default_grants.sql")), + query: String::from(include_str!("sql/default_grants.sql")) + .replace("neon_superuser", ¶ms.privileged_role_name), comment: None, }, ] @@ -1086,6 +1190,28 @@ async fn get_operations<'a>( Ok(Box::new(operations)) } + // BEGIN_HADRON + // Note: we may want to version the extension someday, but for now we just drop it and recreate it. + ApplySpecPhase::HandleDatabricksAuthExtension => { + let operations = vec![ + Operation { + query: String::from("DROP EXTENSION IF EXISTS databricks_auth"), + comment: Some(String::from("dropping existing databricks_auth extension")), + }, + Operation { + query: String::from("CREATE EXTENSION databricks_auth"), + comment: Some(String::from("creating databricks_auth extension")), + }, + Operation { + query: String::from("GRANT SELECT ON databricks_auth_metrics TO pg_monitor"), + comment: Some(String::from("grant select on databricks auth counters")), + }, + ] + .into_iter(); + + Ok(Box::new(operations)) + } + // END_HADRON ApplySpecPhase::CreateAvailabilityCheck => Ok(Box::new(once(Operation { query: String::from(include_str!("sql/add_availabilitycheck_tables.sql")), comment: None, @@ -1103,6 +1229,63 @@ async fn get_operations<'a>( Ok(Box::new(operations)) } + + // BEGIN_HADRON + // New Hadron phases + // + // Grants permissions to roles that are used by Databricks. + ApplySpecPhase::AddDatabricksGrants => { + let operations = vec![ + Operation { + query: String::from("GRANT USAGE ON SCHEMA neon TO databricks_monitor"), + comment: Some(String::from( + "Permissions needed to execute neon.* functions (in the postgres database)", + )), + }, + Operation { + query: String::from( + "GRANT SELECT, INSERT, UPDATE ON health_check TO databricks_monitor", + ), + comment: Some(String::from("Permissions needed for read and write probes")), + }, + Operation { + query: String::from( + "GRANT EXECUTE ON FUNCTION pg_ls_dir(text) TO databricks_monitor", + ), + comment: Some(String::from( + "Permissions needed to monitor .snap file counts", + )), + }, + Operation { + query: String::from( + "GRANT SELECT ON neon.neon_perf_counters TO databricks_monitor", + ), + comment: Some(String::from( + "Permissions needed to access neon performance counters view", + )), + }, + Operation { + query: String::from( + "GRANT EXECUTE ON FUNCTION neon.get_perf_counters() TO databricks_monitor", + ), + comment: Some(String::from( + "Permissions needed to execute the underlying performance counters function", + )), + }, + ] + .into_iter(); + + Ok(Box::new(operations)) + } + // Creates minor objects that are used by Databricks. 
+ ApplySpecPhase::CreateDatabricksMisc => Ok(Box::new(once(Operation { + query: String::from(include_str!("sql/create_databricks_misc.sql")), + comment: Some(String::from( + "The function databricks_monitor uses to convert exception to 0 or 1", + )), + }))), + // End of new Hadron phases + // END_HADRON ApplySpecPhase::FinalizeDropLogicalSubscriptions => Ok(Box::new(once(Operation { query: String::from(include_str!("sql/finalize_drop_subscriptions.sql")), comment: None, diff --git a/compute_tools/src/sql/alter_databricks_reader_roles_timeout.sql b/compute_tools/src/sql/alter_databricks_reader_roles_timeout.sql new file mode 100644 index 0000000000..db16df3817 --- /dev/null +++ b/compute_tools/src/sql/alter_databricks_reader_roles_timeout.sql @@ -0,0 +1,25 @@ +DO $$ +DECLARE + reader_role RECORD; + timeout_value TEXT; +BEGIN + -- Get the current GUC setting for reader statement timeout + SELECT current_setting('databricks.reader_statement_timeout', true) INTO timeout_value; + + -- Only proceed if timeout_value is not null/empty and not '0' (disabled) + IF timeout_value IS NOT NULL AND timeout_value != '' AND timeout_value != '0' THEN + -- Find all databricks_reader_* roles and update their statement_timeout + FOR reader_role IN + SELECT r.rolname + FROM pg_roles r + WHERE r.rolname ~ '^databricks_reader_\d+$' + LOOP + -- Apply the timeout setting to the role (will overwrite existing setting) + EXECUTE format('ALTER ROLE %I SET statement_timeout = %L', + reader_role.rolname, timeout_value); + + RAISE LOG 'Updated statement_timeout = % for role %', timeout_value, reader_role.rolname; + END LOOP; + END IF; +END +$$; diff --git a/compute_tools/src/sql/create_databricks_misc.sql b/compute_tools/src/sql/create_databricks_misc.sql new file mode 100644 index 0000000000..a6dc379078 --- /dev/null +++ b/compute_tools/src/sql/create_databricks_misc.sql @@ -0,0 +1,15 @@ +ALTER ROLE databricks_monitor SET statement_timeout = '60s'; + +CREATE OR REPLACE FUNCTION health_check_write_succeeds() +RETURNS INTEGER AS $$ +BEGIN +INSERT INTO health_check VALUES (1, now()) +ON CONFLICT (id) DO UPDATE + SET updated_at = now(); + +RETURN 1; +EXCEPTION WHEN OTHERS THEN +RAISE EXCEPTION '[DATABRICKS_SMGR] health_check failed: [%] %', SQLSTATE, SQLERRM; +RETURN 0; +END; +$$ LANGUAGE plpgsql; diff --git a/control_plane/src/endpoint.rs b/control_plane/src/endpoint.rs index 20dcf85562..1c7f489d68 100644 --- a/control_plane/src/endpoint.rs +++ b/control_plane/src/endpoint.rs @@ -793,6 +793,7 @@ impl Endpoint { autoprewarm: args.autoprewarm, offload_lfc_interval_seconds: args.offload_lfc_interval_seconds, suspend_timeout_seconds: -1, // Only used in neon_local. + databricks_settings: None, }; // this strange code is needed to support respec() in tests @@ -938,7 +939,8 @@ impl Endpoint { | ComputeStatus::TerminationPendingFast | ComputeStatus::TerminationPendingImmediate | ComputeStatus::Terminated - | ComputeStatus::RefreshConfigurationPending => { + | ComputeStatus::RefreshConfigurationPending + | ComputeStatus::RefreshConfiguration => { bail!("unexpected compute status: {:?}", state.status) } } diff --git a/libs/compute_api/src/responses.rs b/libs/compute_api/src/responses.rs index 7efd94c76a..a27301e45e 100644 --- a/libs/compute_api/src/responses.rs +++ b/libs/compute_api/src/responses.rs @@ -174,6 +174,9 @@ pub enum ComputeStatus { Terminated, // A spec refresh is being requested RefreshConfigurationPending, + // A spec refresh is being applied. 
We cannot refresh configuration again until the current + // refresh is done, i.e., signal_refresh_configuration() will return 500 error. + RefreshConfiguration, } #[derive(Deserialize, Serialize)] @@ -186,6 +189,10 @@ impl Display for ComputeStatus { match self { ComputeStatus::Empty => f.write_str("empty"), ComputeStatus::ConfigurationPending => f.write_str("configuration-pending"), + ComputeStatus::RefreshConfiguration => f.write_str("refresh-configuration"), + ComputeStatus::RefreshConfigurationPending => { + f.write_str("refresh-configuration-pending") + } ComputeStatus::Init => f.write_str("init"), ComputeStatus::Running => f.write_str("running"), ComputeStatus::Configuration => f.write_str("configuration"), @@ -195,9 +202,6 @@ impl Display for ComputeStatus { f.write_str("termination-pending-immediate") } ComputeStatus::Terminated => f.write_str("terminated"), - ComputeStatus::RefreshConfigurationPending => { - f.write_str("refresh-configuration-pending") - } } } } diff --git a/libs/compute_api/src/spec.rs b/libs/compute_api/src/spec.rs index 061ac3e66d..6709c06fc6 100644 --- a/libs/compute_api/src/spec.rs +++ b/libs/compute_api/src/spec.rs @@ -193,6 +193,9 @@ pub struct ComputeSpec { /// /// We use this value to derive other values, such as the installed extensions metric. pub suspend_timeout_seconds: i64, + + // Databricks specific options for compute instance. + pub databricks_settings: Option, } /// Feature flag to signal `compute_ctl` to enable certain experimental functionality. diff --git a/libs/http-utils/src/endpoint.rs b/libs/http-utils/src/endpoint.rs index a61bf8e08a..c23d95f3e6 100644 --- a/libs/http-utils/src/endpoint.rs +++ b/libs/http-utils/src/endpoint.rs @@ -558,11 +558,11 @@ async fn add_request_id_header_to_response( mut res: Response, req_info: RequestInfo, ) -> Result, ApiError> { - if let Some(request_id) = req_info.context::() { - if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) { - res.headers_mut() - .insert(&X_REQUEST_ID_HEADER, request_header_value); - }; + if let Some(request_id) = req_info.context::() + && let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) + { + res.headers_mut() + .insert(&X_REQUEST_ID_HEADER, request_header_value); }; Ok(res) diff --git a/libs/http-utils/src/server.rs b/libs/http-utils/src/server.rs index f93f71c962..ce90b8d710 100644 --- a/libs/http-utils/src/server.rs +++ b/libs/http-utils/src/server.rs @@ -72,10 +72,10 @@ impl Server { if err.is_incomplete_message() || err.is_closed() || err.is_timeout() { return true; } - if let Some(inner) = err.source() { - if let Some(io) = inner.downcast_ref::() { - return suppress_io_error(io); - } + if let Some(inner) = err.source() + && let Some(io) = inner.downcast_ref::() + { + return suppress_io_error(io); } false } diff --git a/libs/metrics/src/lib.rs b/libs/metrics/src/lib.rs index 41873cdcd6..6cf27abcaf 100644 --- a/libs/metrics/src/lib.rs +++ b/libs/metrics/src/lib.rs @@ -129,6 +129,12 @@ impl InfoMetric { } } +impl Default for InfoMetric { + fn default() -> Self { + InfoMetric::new(L::default()) + } +} + impl> InfoMetric { pub fn with_metric(label: L, metric: M) -> Self { Self { diff --git a/libs/neon-shmem/src/hash.rs b/libs/neon-shmem/src/hash.rs index 58726b9ba3..a58797d8fa 100644 --- a/libs/neon-shmem/src/hash.rs +++ b/libs/neon-shmem/src/hash.rs @@ -363,7 +363,7 @@ where // TODO: An Iterator might be nicer. The communicator's clock algorithm needs to // _slowly_ iterate through all buckets with its clock hand, without holding a lock. 
// If we switch to an Iterator, it must not hold the lock. - pub fn get_at_bucket(&self, pos: usize) -> Option> { + pub fn get_at_bucket(&self, pos: usize) -> Option> { let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read(); if pos >= map.buckets.len() { return None; diff --git a/libs/postgres_ffi/Cargo.toml b/libs/postgres_ffi/Cargo.toml index d4fec6cbe9..23fabeccd2 100644 --- a/libs/postgres_ffi/Cargo.toml +++ b/libs/postgres_ffi/Cargo.toml @@ -9,10 +9,7 @@ regex.workspace = true bytes.workspace = true anyhow.workspace = true crc32c.workspace = true -criterion.workspace = true once_cell.workspace = true -log.workspace = true -memoffset.workspace = true pprof.workspace = true thiserror.workspace = true serde.workspace = true @@ -22,6 +19,7 @@ tracing.workspace = true postgres_versioninfo.workspace = true [dev-dependencies] +criterion.workspace = true env_logger.workspace = true postgres.workspace = true diff --git a/libs/postgres_ffi/src/controlfile_utils.rs b/libs/postgres_ffi/src/controlfile_utils.rs index eaa9450294..d6d17ce3fb 100644 --- a/libs/postgres_ffi/src/controlfile_utils.rs +++ b/libs/postgres_ffi/src/controlfile_utils.rs @@ -34,9 +34,8 @@ const SIZEOF_CONTROLDATA: usize = size_of::(); impl ControlFileData { /// Compute the offset of the `crc` field within the `ControlFileData` struct. /// Equivalent to offsetof(ControlFileData, crc) in C. - // Someday this can be const when the right compiler features land. - fn pg_control_crc_offset() -> usize { - memoffset::offset_of!(ControlFileData, crc) + const fn pg_control_crc_offset() -> usize { + std::mem::offset_of!(ControlFileData, crc) } /// diff --git a/libs/postgres_ffi/src/nonrelfile_utils.rs b/libs/postgres_ffi/src/nonrelfile_utils.rs index e3e7133b94..f6693d4ec1 100644 --- a/libs/postgres_ffi/src/nonrelfile_utils.rs +++ b/libs/postgres_ffi/src/nonrelfile_utils.rs @@ -4,12 +4,11 @@ use crate::pg_constants; use crate::transaction_id_precedes; use bytes::BytesMut; -use log::*; use super::bindings::MultiXactId; pub fn transaction_id_set_status(xid: u32, status: u8, page: &mut BytesMut) { - trace!( + tracing::trace!( "handle_apply_request for RM_XACT_ID-{} (1-commit, 2-abort, 3-sub_commit)", status ); diff --git a/libs/postgres_ffi/src/waldecoder_handler.rs b/libs/postgres_ffi/src/waldecoder_handler.rs index 9cd40645ec..563a3426a0 100644 --- a/libs/postgres_ffi/src/waldecoder_handler.rs +++ b/libs/postgres_ffi/src/waldecoder_handler.rs @@ -14,7 +14,6 @@ use super::xlog_utils::*; use crate::WAL_SEGMENT_SIZE; use bytes::{Buf, BufMut, Bytes, BytesMut}; use crc32c::*; -use log::*; use std::cmp::min; use std::num::NonZeroU32; use utils::lsn::Lsn; @@ -236,7 +235,7 @@ impl WalStreamDecoderHandler for WalStreamDecoder { // XLOG_SWITCH records are special. If we see one, we need to skip // to the next WAL segment. 
let next_lsn = if xlogrec.is_xlog_switch_record() { - trace!("saw xlog switch record at {}", self.lsn); + tracing::trace!("saw xlog switch record at {}", self.lsn); self.lsn + self.lsn.calc_padding(WAL_SEGMENT_SIZE as u64) } else { // Pad to an 8-byte boundary diff --git a/libs/postgres_ffi/src/xlog_utils.rs b/libs/postgres_ffi/src/xlog_utils.rs index 134baf5ff7..913e6b453f 100644 --- a/libs/postgres_ffi/src/xlog_utils.rs +++ b/libs/postgres_ffi/src/xlog_utils.rs @@ -23,8 +23,6 @@ use crate::{WAL_SEGMENT_SIZE, XLOG_BLCKSZ}; use bytes::BytesMut; use bytes::{Buf, Bytes}; -use log::*; - use serde::Serialize; use std::ffi::{CString, OsStr}; use std::fs::File; @@ -235,7 +233,7 @@ pub fn find_end_of_wal( let mut curr_lsn = start_lsn; let mut buf = [0u8; XLOG_BLCKSZ]; let pg_version = MY_PGVERSION; - debug!("find_end_of_wal PG_VERSION: {}", pg_version); + tracing::debug!("find_end_of_wal PG_VERSION: {}", pg_version); let mut decoder = WalStreamDecoder::new(start_lsn, pg_version); @@ -247,7 +245,7 @@ pub fn find_end_of_wal( match open_wal_segment(&seg_file_path)? { None => { // no more segments - debug!( + tracing::debug!( "find_end_of_wal reached end at {:?}, segment {:?} doesn't exist", result, seg_file_path ); @@ -260,7 +258,7 @@ pub fn find_end_of_wal( while curr_lsn.segment_number(wal_seg_size) == segno { let bytes_read = segment.read(&mut buf)?; if bytes_read == 0 { - debug!( + tracing::debug!( "find_end_of_wal reached end at {:?}, EOF in segment {:?} at offset {}", result, seg_file_path, @@ -276,7 +274,7 @@ pub fn find_end_of_wal( match decoder.poll_decode() { Ok(Some(record)) => result = record.0, Err(e) => { - debug!( + tracing::debug!( "find_end_of_wal reached end at {:?}, decode error: {:?}", result, e ); diff --git a/proxy/subzero_core/.gitignore b/libs/proxy/subzero_core/.gitignore similarity index 100% rename from proxy/subzero_core/.gitignore rename to libs/proxy/subzero_core/.gitignore diff --git a/proxy/subzero_core/Cargo.toml b/libs/proxy/subzero_core/Cargo.toml similarity index 100% rename from proxy/subzero_core/Cargo.toml rename to libs/proxy/subzero_core/Cargo.toml diff --git a/proxy/subzero_core/src/lib.rs b/libs/proxy/subzero_core/src/lib.rs similarity index 100% rename from proxy/subzero_core/src/lib.rs rename to libs/proxy/subzero_core/src/lib.rs diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs index f5aed010ef..90ff39aff1 100644 --- a/libs/proxy/tokio-postgres2/src/client.rs +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -185,6 +185,7 @@ impl Client { ssl_mode: SslMode, process_id: i32, secret_key: i32, + write_buf: BytesMut, ) -> Client { Client { inner: InnerClient { @@ -195,7 +196,7 @@ impl Client { waiting: 0, received: 0, }, - buffer: Default::default(), + buffer: write_buf, }, cached_typeinfo: Default::default(), diff --git a/libs/proxy/tokio-postgres2/src/codec.rs b/libs/proxy/tokio-postgres2/src/codec.rs index 71fe062fca..35f616d229 100644 --- a/libs/proxy/tokio-postgres2/src/codec.rs +++ b/libs/proxy/tokio-postgres2/src/codec.rs @@ -47,14 +47,7 @@ impl Encoder for PostgresCodec { type Error = io::Error; fn encode(&mut self, item: BytesMut, dst: &mut BytesMut) -> io::Result<()> { - // When it comes to request/response workflows, we usually flush the entire write - // buffer in order to wait for the response before we send a new request. - // Therefore we can avoid the copy and just replace the buffer. 
- if dst.is_empty() { - *dst = item; - } else { - dst.extend_from_slice(&item); - } + dst.unsplit(item); Ok(()) } } diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs index ca6f69f049..b1df87811e 100644 --- a/libs/proxy/tokio-postgres2/src/connect.rs +++ b/libs/proxy/tokio-postgres2/src/connect.rs @@ -77,6 +77,9 @@ where connect_timeout, }; + let mut stream = stream.into_framed(); + let write_buf = std::mem::take(stream.write_buffer_mut()); + let (client_tx, conn_rx) = mpsc::unbounded_channel(); let (conn_tx, client_rx) = mpsc::channel(4); let client = Client::new( @@ -86,9 +89,9 @@ where ssl_mode, process_id, secret_key, + write_buf, ); - let stream = stream.into_framed(); let connection = Connection::new(stream, conn_tx, conn_rx); Ok((client, connection)) diff --git a/libs/proxy/tokio-postgres2/src/connection.rs b/libs/proxy/tokio-postgres2/src/connection.rs index bee4b3372d..303de71cfa 100644 --- a/libs/proxy/tokio-postgres2/src/connection.rs +++ b/libs/proxy/tokio-postgres2/src/connection.rs @@ -229,8 +229,11 @@ where Poll::Ready(()) => { trace!("poll_flush: flushed"); - // GC the write buffer if we managed to flush - gc_bytesmut(self.stream.write_buffer_mut()); + // Since our codec prefers to share the buffer with the `Client`, + // if we don't release our share, then the `Client` would have to re-alloc + // the buffer when they next use it. + debug_assert!(self.stream.write_buffer().is_empty()); + *self.stream.write_buffer_mut() = BytesMut::new(); Poll::Ready(Ok(())) } diff --git a/libs/proxy/tokio-postgres2/src/error/mod.rs b/libs/proxy/tokio-postgres2/src/error/mod.rs index 6e68b1e595..3fbb97f9bb 100644 --- a/libs/proxy/tokio-postgres2/src/error/mod.rs +++ b/libs/proxy/tokio-postgres2/src/error/mod.rs @@ -9,7 +9,7 @@ use postgres_protocol2::message::backend::{ErrorFields, ErrorResponseBody}; pub use self::sqlstate::*; #[allow(clippy::unreadable_literal)] -mod sqlstate; +pub mod sqlstate; /// The severity of a Postgres error or notice. 
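The codec now calls `dst.unsplit(item)` and the connection hands its emptied write buffer back to the `Client`, so both sides keep reusing one allocation instead of copying or reallocating per request. A self-contained sketch of the `BytesMut` behaviour this relies on:

```rust
use bytes::{BufMut, BytesMut};

// Sketch of the BytesMut behaviour the codec relies on: `unsplit` reattaches a buffer that
// came from the same allocation, so passing write buffers between Client and Connection
// avoids per-request copies. If the halves are not contiguous it falls back to a copy.
fn main() {
    let mut buf = BytesMut::with_capacity(64);
    buf.put_slice(b"query 1");

    let mut head = buf.split(); // head owns b"query 1", buf keeps the trailing capacity
    assert!(buf.is_empty());

    head.unsplit(buf); // contiguous halves are stitched back together without copying
    assert_eq!(&head[..], b"query 1");
}
```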
#[derive(Debug, Copy, Clone, PartialEq, Eq)] diff --git a/libs/tracing-utils/src/perf_span.rs b/libs/tracing-utils/src/perf_span.rs index 16f713c67e..4eec0829f7 100644 --- a/libs/tracing-utils/src/perf_span.rs +++ b/libs/tracing-utils/src/perf_span.rs @@ -49,7 +49,7 @@ impl PerfSpan { } } - pub fn enter(&self) -> PerfSpanEntered { + pub fn enter(&self) -> PerfSpanEntered<'_> { if let Some(ref id) = self.inner.id() { self.dispatch.enter(id); } diff --git a/libs/walproposer/src/api_bindings.rs b/libs/walproposer/src/api_bindings.rs index 74662f8b12..a9a39948a8 100644 --- a/libs/walproposer/src/api_bindings.rs +++ b/libs/walproposer/src/api_bindings.rs @@ -346,6 +346,9 @@ extern "C" fn reset_safekeeper_statuses_for_metrics(wp: *mut WalProposer, num_sa unsafe { let callback_data = (*(*wp).config).callback_data; let api = callback_data as *mut Box; + if api.is_null() { + return; + } (*api).reset_safekeeper_statuses_for_metrics(&mut (*wp), num_safekeepers); } } @@ -358,6 +361,9 @@ extern "C" fn update_safekeeper_status_for_metrics( unsafe { let callback_data = (*(*wp).config).callback_data; let api = callback_data as *mut Box; + if api.is_null() { + return; + } (*api).update_safekeeper_status_for_metrics(&mut (*wp), sk_index, status); } } diff --git a/pageserver/client_grpc/src/client.rs b/pageserver/client_grpc/src/client.rs index e6a90fb582..dad37ebe74 100644 --- a/pageserver/client_grpc/src/client.rs +++ b/pageserver/client_grpc/src/client.rs @@ -14,9 +14,9 @@ use utils::logging::warn_slow; use crate::pool::{ChannelPool, ClientGuard, ClientPool, StreamGuard, StreamPool}; use crate::retry::Retry; -use crate::split::GetPageSplitter; use compute_api::spec::PageserverProtocol; use pageserver_page_api as page_api; +use pageserver_page_api::GetPageSplitter; use utils::id::{TenantId, TimelineId}; use utils::shard::{ShardCount, ShardIndex, ShardNumber, ShardStripeSize}; @@ -230,16 +230,14 @@ impl PageserverClient { ) -> tonic::Result { // Fast path: request is for a single shard. if let Some(shard_id) = - GetPageSplitter::for_single_shard(&req, shards.count, shards.stripe_size) - .map_err(|err| tonic::Status::internal(err.to_string()))? + GetPageSplitter::for_single_shard(&req, shards.count, shards.stripe_size)? { return Self::get_page_with_shard(req, shards.get(shard_id)?).await; } // Request spans multiple shards. Split it, dispatch concurrent per-shard requests, and // reassemble the responses. - let mut splitter = GetPageSplitter::split(req, shards.count, shards.stripe_size) - .map_err(|err| tonic::Status::internal(err.to_string()))?; + let mut splitter = GetPageSplitter::split(req, shards.count, shards.stripe_size)?; let mut shard_requests = FuturesUnordered::new(); for (shard_id, shard_req) in splitter.drain_requests() { @@ -249,14 +247,10 @@ impl PageserverClient { } while let Some((shard_id, shard_response)) = shard_requests.next().await.transpose()? { - splitter - .add_response(shard_id, shard_response) - .map_err(|err| tonic::Status::internal(err.to_string()))?; + splitter.add_response(shard_id, shard_response)?; } - splitter - .get_response() - .map_err(|err| tonic::Status::internal(err.to_string())) + Ok(splitter.collect_response()?) } /// Fetches pages on the given shard. Does not retry internally. 
diff --git a/pageserver/client_grpc/src/lib.rs b/pageserver/client_grpc/src/lib.rs index 14fb3fbd5a..4999fd3d0a 100644 --- a/pageserver/client_grpc/src/lib.rs +++ b/pageserver/client_grpc/src/lib.rs @@ -1,6 +1,5 @@ mod client; mod pool; mod retry; -mod split; pub use client::{PageserverClient, ShardSpec}; diff --git a/pageserver/page_api/src/lib.rs b/pageserver/page_api/src/lib.rs index e78f6ce206..b9be6b8b91 100644 --- a/pageserver/page_api/src/lib.rs +++ b/pageserver/page_api/src/lib.rs @@ -19,7 +19,9 @@ pub mod proto { } mod client; -pub use client::Client; mod model; +mod split; +pub use client::Client; pub use model::*; +pub use split::{GetPageSplitter, SplitError}; diff --git a/pageserver/client_grpc/src/split.rs b/pageserver/page_api/src/split.rs similarity index 73% rename from pageserver/client_grpc/src/split.rs rename to pageserver/page_api/src/split.rs index 8631638686..27c1c995e0 100644 --- a/pageserver/client_grpc/src/split.rs +++ b/pageserver/page_api/src/split.rs @@ -1,20 +1,19 @@ use std::collections::HashMap; -use anyhow::anyhow; use bytes::Bytes; +use crate::model::*; use pageserver_api::key::rel_block_to_key; use pageserver_api::shard::key_to_shard_number; -use pageserver_page_api as page_api; use utils::shard::{ShardCount, ShardIndex, ShardStripeSize}; /// Splits GetPageRequests that straddle shard boundaries and assembles the responses. /// TODO: add tests for this. pub struct GetPageSplitter { /// Split requests by shard index. - requests: HashMap, + requests: HashMap, /// The response being assembled. Preallocated with empty pages, to be filled in. - response: page_api::GetPageResponse, + response: GetPageResponse, /// Maps the offset in `request.block_numbers` and `response.pages` to the owning shard. Used /// to assemble the response pages in the same order as the original request. block_shards: Vec, @@ -24,22 +23,22 @@ impl GetPageSplitter { /// Checks if the given request only touches a single shard, and returns the shard ID. This is /// the common case, so we check first in order to avoid unnecessary allocations and overhead. pub fn for_single_shard( - req: &page_api::GetPageRequest, + req: &GetPageRequest, count: ShardCount, stripe_size: Option, - ) -> anyhow::Result> { + ) -> Result, SplitError> { // Fast path: unsharded tenant. if count.is_unsharded() { return Ok(Some(ShardIndex::unsharded())); } let Some(stripe_size) = stripe_size else { - return Err(anyhow!("stripe size must be given for sharded tenants")); + return Err("stripe size must be given for sharded tenants".into()); }; // Find the first page's shard, for comparison. let Some(&first_page) = req.block_numbers.first() else { - return Err(anyhow!("no block numbers in request")); + return Err("no block numbers in request".into()); }; let key = rel_block_to_key(req.rel, first_page); let shard_number = key_to_shard_number(count, stripe_size, &key); @@ -57,10 +56,10 @@ impl GetPageSplitter { /// Splits the given request. pub fn split( - req: page_api::GetPageRequest, + req: GetPageRequest, count: ShardCount, stripe_size: Option, - ) -> anyhow::Result { + ) -> Result { // The caller should make sure we don't split requests unnecessarily. 
debug_assert!( Self::for_single_shard(&req, count, stripe_size)?.is_none(), @@ -68,10 +67,10 @@ impl GetPageSplitter { ); if count.is_unsharded() { - return Err(anyhow!("unsharded tenant, no point in splitting request")); + return Err("unsharded tenant, no point in splitting request".into()); } let Some(stripe_size) = stripe_size else { - return Err(anyhow!("stripe size must be given for sharded tenants")); + return Err("stripe size must be given for sharded tenants".into()); }; // Split the requests by shard index. @@ -84,7 +83,7 @@ impl GetPageSplitter { requests .entry(shard_id) - .or_insert_with(|| page_api::GetPageRequest { + .or_insert_with(|| GetPageRequest { request_id: req.request_id, request_class: req.request_class, rel: req.rel, @@ -98,16 +97,16 @@ impl GetPageSplitter { // Construct a response to be populated by shard responses. Preallocate empty page slots // with the expected block numbers. - let response = page_api::GetPageResponse { + let response = GetPageResponse { request_id: req.request_id, - status_code: page_api::GetPageStatusCode::Ok, + status_code: GetPageStatusCode::Ok, reason: None, rel: req.rel, pages: req .block_numbers .into_iter() .map(|block_number| { - page_api::Page { + Page { block_number, image: Bytes::new(), // empty page slot to be filled in } @@ -123,43 +122,38 @@ impl GetPageSplitter { } /// Drains the per-shard requests, moving them out of the splitter to avoid extra allocations. - pub fn drain_requests( - &mut self, - ) -> impl Iterator { + pub fn drain_requests(&mut self) -> impl Iterator { self.requests.drain() } /// Adds a response from the given shard. The response must match the request ID and have an OK /// status code. A response must not already exist for the given shard ID. - #[allow(clippy::result_large_err)] pub fn add_response( &mut self, shard_id: ShardIndex, - response: page_api::GetPageResponse, - ) -> anyhow::Result<()> { + response: GetPageResponse, + ) -> Result<(), SplitError> { // The caller should already have converted status codes into tonic::Status. - if response.status_code != page_api::GetPageStatusCode::Ok { - return Err(anyhow!( + if response.status_code != GetPageStatusCode::Ok { + return Err(SplitError(format!( "unexpected non-OK response for shard {shard_id}: {} {}", response.status_code, response.reason.unwrap_or_default() - )); + ))); } if response.request_id != self.response.request_id { - return Err(anyhow!( + return Err(SplitError(format!( "response ID mismatch for shard {shard_id}: expected {}, got {}", - self.response.request_id, - response.request_id - )); + self.response.request_id, response.request_id + ))); } if response.request_id != self.response.request_id { - return Err(anyhow!( + return Err(SplitError(format!( "response ID mismatch for shard {shard_id}: expected {}, got {}", - self.response.request_id, - response.request_id - )); + self.response.request_id, response.request_id + ))); } // Place the shard response pages into the assembled response, in request order. 
@@ -171,26 +165,27 @@ impl GetPageSplitter { } let Some(slot) = self.response.pages.get_mut(i) else { - return Err(anyhow!("no block_shards slot {i} for shard {shard_id}")); + return Err(SplitError(format!( + "no block_shards slot {i} for shard {shard_id}" + ))); }; let Some(page) = pages.next() else { - return Err(anyhow!( + return Err(SplitError(format!( "missing page {} in shard {shard_id} response", slot.block_number - )); + ))); }; if page.block_number != slot.block_number { - return Err(anyhow!( + return Err(SplitError(format!( "shard {shard_id} returned wrong page at index {i}, expected {} got {}", - slot.block_number, - page.block_number - )); + slot.block_number, page.block_number + ))); } if !slot.image.is_empty() { - return Err(anyhow!( + return Err(SplitError(format!( "shard {shard_id} returned duplicate page {} at index {i}", slot.block_number - )); + ))); } *slot = page; @@ -198,32 +193,54 @@ impl GetPageSplitter { // Make sure we've consumed all pages from the shard response. if let Some(extra_page) = pages.next() { - return Err(anyhow!( + return Err(SplitError(format!( "shard {shard_id} returned extra page: {}", extra_page.block_number - )); + ))); } Ok(()) } - /// Fetches the final, assembled response. - #[allow(clippy::result_large_err)] - pub fn get_response(self) -> anyhow::Result { + /// Collects the final, assembled response. + pub fn collect_response(self) -> Result { // Check that the response is complete. for (i, page) in self.response.pages.iter().enumerate() { if page.image.is_empty() { - return Err(anyhow!( + return Err(SplitError(format!( "missing page {} for shard {}", page.block_number, self.block_shards .get(i) .map(|s| s.to_string()) .unwrap_or_else(|| "?".to_string()) - )); + ))); } } Ok(self.response) } } + +/// A GetPageSplitter error. +#[derive(Debug, thiserror::Error)] +#[error("{0}")] +pub struct SplitError(String); + +impl From<&str> for SplitError { + fn from(err: &str) -> Self { + SplitError(err.to_string()) + } +} + +impl From for SplitError { + fn from(err: String) -> Self { + SplitError(err) + } +} + +impl From for tonic::Status { + fn from(err: SplitError) -> Self { + tonic::Status::internal(err.0) + } +} diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 855af7009c..b1566c2d6e 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -715,7 +715,7 @@ fn start_pageserver( disk_usage_eviction_state, deletion_queue.new_client(), secondary_controller, - feature_resolver, + feature_resolver.clone(), ) .context("Failed to initialize router state")?, ); @@ -841,14 +841,14 @@ fn start_pageserver( } else { None }, + feature_resolver.clone(), ); - // Spawn a Pageserver gRPC server task. It will spawn separate tasks for - // each stream/request. + // Spawn a Pageserver gRPC server task. It will spawn separate tasks for each request/stream. + // It uses a separate compute request Tokio runtime (COMPUTE_REQUEST_RUNTIME). // - // TODO: this uses a separate Tokio runtime for the page service. If we want - // other gRPC services, they will need their own port and runtime. Is this - // necessary? + // NB: this port is exposed to computes. It should only provide services that we're okay with + // computes accessing. Internal services should use a separate port. 
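Moving the splitter off `anyhow` onto the string-backed `SplitError` lets the gRPC client propagate errors with `?`, because of the `From<SplitError> for tonic::Status` impl. A trimmed-down sketch of that conversion chain, mirroring the definitions added in `page_api/src/split.rs`:

```rust
use thiserror::Error;

// Trimmed-down mirror of the new error type: a plain string payload plus a conversion
// into tonic::Status, so gRPC call sites can use `?` instead of map_err everywhere.
#[derive(Debug, Error)]
#[error("{0}")]
struct SplitError(String);

impl From<&str> for SplitError {
    fn from(err: &str) -> Self {
        SplitError(err.to_string())
    }
}

impl From<SplitError> for tonic::Status {
    fn from(err: SplitError) -> Self {
        tonic::Status::internal(err.0)
    }
}

fn stripe_size_or_err(stripe_size: Option<u32>) -> Result<u32, SplitError> {
    stripe_size.ok_or_else(|| "stripe size must be given for sharded tenants".into())
}

fn grpc_handler(stripe_size: Option<u32>) -> Result<u32, tonic::Status> {
    // `?` applies From<SplitError> for tonic::Status, yielding Status::internal(..).
    Ok(stripe_size_or_err(stripe_size)?)
}
```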
let mut page_service_grpc = None; if let Some(grpc_listener) = grpc_listener { page_service_grpc = Some(GrpcPageServiceHandler::spawn( diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index a0ea9b90dc..669eeffa32 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -2005,6 +2005,10 @@ async fn put_tenant_location_config_handler( let state = get_state(&request); let conf = state.conf; + fail::fail_point!("put-location-conf-handler", |_| { + Err(ApiError::ResourceUnavailable("failpoint".into())) + }); + // The `Detached` state is special, it doesn't upsert a tenant, it removes // its local disk content and drops it from memory. if let LocationConfigMode::Detached = request_data.config.mode { diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index f60eab68e6..1b783326a0 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -2914,7 +2914,6 @@ pub(crate) struct WalIngestMetrics { pub(crate) records_received: IntCounter, pub(crate) records_observed: IntCounter, pub(crate) records_committed: IntCounter, - pub(crate) records_filtered: IntCounter, pub(crate) values_committed_metadata_images: IntCounter, pub(crate) values_committed_metadata_deltas: IntCounter, pub(crate) values_committed_data_images: IntCounter, @@ -2970,11 +2969,6 @@ pub(crate) static WAL_INGEST: Lazy = Lazy::new(|| { "Number of WAL records which resulted in writes to pageserver storage" ) .expect("failed to define a metric"), - records_filtered: register_int_counter!( - "pageserver_wal_ingest_records_filtered", - "Number of WAL records filtered out due to sharding" - ) - .expect("failed to define a metric"), values_committed_metadata_images: values_committed.with_label_values(&["metadata", "image"]), values_committed_metadata_deltas: values_committed.with_label_values(&["metadata", "delta"]), values_committed_data_images: values_committed.with_label_values(&["data", "image"]), diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 26a23da66f..116e289e99 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -16,7 +16,8 @@ use anyhow::{Context as _, bail}; use bytes::{Buf as _, BufMut as _, BytesMut}; use chrono::Utc; use futures::future::BoxFuture; -use futures::{FutureExt, Stream}; +use futures::stream::FuturesUnordered; +use futures::{FutureExt, Stream, StreamExt as _}; use itertools::Itertools; use jsonwebtoken::TokenData; use once_cell::sync::OnceCell; @@ -35,8 +36,8 @@ use pageserver_api::pagestream_api::{ }; use pageserver_api::reltag::SlruKind; use pageserver_api::shard::TenantShardId; -use pageserver_page_api as page_api; use pageserver_page_api::proto; +use pageserver_page_api::{self as page_api, GetPageSplitter}; use postgres_backend::{ AuthType, PostgresBackend, PostgresBackendReader, QueryError, is_expected_io_error, }; @@ -68,6 +69,7 @@ use crate::config::PageServerConf; use crate::context::{ DownloadBehavior, PerfInstrumentFutureExt, RequestContext, RequestContextBuilder, }; +use crate::feature_resolver::FeatureResolver; use crate::metrics::{ self, COMPUTE_COMMANDS_COUNTERS, ComputeCommandKind, GetPageBatchBreakReason, LIVE_CONNECTIONS, MISROUTED_PAGESTREAM_REQUESTS, PAGESTREAM_HANDLER_RESULTS_TOTAL, SmgrOpTimer, TimelineMetrics, @@ -139,6 +141,7 @@ pub fn spawn( perf_trace_dispatch: Option, tcp_listener: tokio::net::TcpListener, tls_config: Option>, + feature_resolver: FeatureResolver, ) -> Listener { let cancel = CancellationToken::new(); let libpq_ctx = 
RequestContext::todo_child( @@ -160,6 +163,7 @@ pub fn spawn( conf.pg_auth_type, tls_config, conf.page_service_pipelining.clone(), + feature_resolver, libpq_ctx, cancel.clone(), ) @@ -218,6 +222,7 @@ pub async fn libpq_listener_main( auth_type: AuthType, tls_config: Option>, pipelining_config: PageServicePipeliningConfig, + feature_resolver: FeatureResolver, listener_ctx: RequestContext, listener_cancel: CancellationToken, ) -> Connections { @@ -261,6 +266,7 @@ pub async fn libpq_listener_main( auth_type, tls_config.clone(), pipelining_config.clone(), + feature_resolver.clone(), connection_ctx, connections_cancel.child_token(), gate_guard, @@ -303,6 +309,7 @@ async fn page_service_conn_main( auth_type: AuthType, tls_config: Option>, pipelining_config: PageServicePipeliningConfig, + feature_resolver: FeatureResolver, connection_ctx: RequestContext, cancel: CancellationToken, gate_guard: GateGuard, @@ -370,6 +377,7 @@ async fn page_service_conn_main( perf_span_fields, connection_ctx, cancel.clone(), + feature_resolver.clone(), gate_guard, ); let pgbackend = @@ -421,6 +429,8 @@ struct PageServerHandler { pipelining_config: PageServicePipeliningConfig, get_vectored_concurrent_io: GetVectoredConcurrentIo, + feature_resolver: FeatureResolver, + gate_guard: GateGuard, } @@ -457,13 +467,6 @@ impl TimelineHandles { self.handles .get(timeline_id, shard_selector, &self.wrapper) .await - .map_err(|e| match e { - timeline::handle::GetError::TenantManager(e) => e, - timeline::handle::GetError::PerTimelineStateShutDown => { - trace!("per-timeline state shut down"); - GetActiveTimelineError::Timeline(GetTimelineError::ShuttingDown) - } - }) } fn tenant_id(&self) -> Option { @@ -479,11 +482,9 @@ pub(crate) struct TenantManagerWrapper { tenant_id: once_cell::sync::OnceCell, } -#[derive(Debug)] pub(crate) struct TenantManagerTypes; impl timeline::handle::Types for TenantManagerTypes { - type TenantManagerError = GetActiveTimelineError; type TenantManager = TenantManagerWrapper; type Timeline = TenantManagerCacheItem; } @@ -535,6 +536,7 @@ impl timeline::handle::TenantManager for TenantManagerWrappe match resolved { ShardResolveResult::Found(tenant_shard) => break tenant_shard, ShardResolveResult::NotFound => { + MISROUTED_PAGESTREAM_REQUESTS.inc(); return Err(GetActiveTimelineError::Tenant( GetActiveTenantError::NotFound(GetTenantError::NotFound(*tenant_id)), )); @@ -586,6 +588,15 @@ impl timeline::handle::TenantManager for TenantManagerWrappe } } +/// Whether to hold the applied GC cutoff guard when processing GetPage requests. +/// This is determined once at the start of pagestream subprotocol handling based on +/// feature flags, configuration, and test conditions. 
+#[derive(Debug, Clone, Copy)] +enum HoldAppliedGcCutoffGuard { + Yes, + No, +} + #[derive(thiserror::Error, Debug)] enum PageStreamError { /// We encountered an error that should prompt the client to reconnect: @@ -729,6 +740,7 @@ enum BatchedFeMessage { GetPage { span: Span, shard: WeakHandle, + applied_gc_cutoff_guard: Option>, pages: SmallVec<[BatchedGetPageRequest; 1]>, batch_break_reason: GetPageBatchBreakReason, }, @@ -908,6 +920,7 @@ impl PageServerHandler { perf_span_fields: ConnectionPerfSpanFields, connection_ctx: RequestContext, cancel: CancellationToken, + feature_resolver: FeatureResolver, gate_guard: GateGuard, ) -> Self { PageServerHandler { @@ -919,6 +932,7 @@ impl PageServerHandler { cancel, pipelining_config, get_vectored_concurrent_io, + feature_resolver, gate_guard, } } @@ -958,6 +972,7 @@ impl PageServerHandler { ctx: &RequestContext, protocol_version: PagestreamProtocolVersion, parent_span: Span, + hold_gc_cutoff_guard: HoldAppliedGcCutoffGuard, ) -> Result, QueryError> where IO: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, @@ -1195,19 +1210,27 @@ impl PageServerHandler { }) .await?; + let applied_gc_cutoff_guard = shard.get_applied_gc_cutoff_lsn(); // hold guard // We're holding the Handle let effective_lsn = match Self::effective_request_lsn( &shard, shard.get_last_record_lsn(), req.hdr.request_lsn, req.hdr.not_modified_since, - &shard.get_applied_gc_cutoff_lsn(), + &applied_gc_cutoff_guard, ) { Ok(lsn) => lsn, Err(e) => { return respond_error!(span, e); } }; + let applied_gc_cutoff_guard = match hold_gc_cutoff_guard { + HoldAppliedGcCutoffGuard::Yes => Some(applied_gc_cutoff_guard), + HoldAppliedGcCutoffGuard::No => { + drop(applied_gc_cutoff_guard); + None + } + }; let batch_wait_ctx = if ctx.has_perf_span() { Some( @@ -1228,6 +1251,7 @@ impl PageServerHandler { BatchedFeMessage::GetPage { span, shard: shard.downgrade(), + applied_gc_cutoff_guard, pages: smallvec![BatchedGetPageRequest { req, timer, @@ -1328,13 +1352,28 @@ impl PageServerHandler { match (eligible_batch, this_msg) { ( BatchedFeMessage::GetPage { - pages: accum_pages, .. + pages: accum_pages, + applied_gc_cutoff_guard: accum_applied_gc_cutoff_guard, + .. }, BatchedFeMessage::GetPage { - pages: this_pages, .. + pages: this_pages, + applied_gc_cutoff_guard: this_applied_gc_cutoff_guard, + .. }, ) => { accum_pages.extend(this_pages); + // the minimum of the two guards will keep data for both alive + match (&accum_applied_gc_cutoff_guard, this_applied_gc_cutoff_guard) { + (None, None) => (), + (None, Some(this)) => *accum_applied_gc_cutoff_guard = Some(this), + (Some(_), None) => (), + (Some(accum), Some(this)) => { + if **accum > *this { + *accum_applied_gc_cutoff_guard = Some(this); + } + } + }; Ok(()) } #[cfg(feature = "testing")] @@ -1649,6 +1688,7 @@ impl PageServerHandler { BatchedFeMessage::GetPage { span, shard, + applied_gc_cutoff_guard, pages, batch_break_reason, } => { @@ -1668,6 +1708,7 @@ impl PageServerHandler { .instrument(span.clone()) .await; assert_eq!(res.len(), npages); + drop(applied_gc_cutoff_guard); res }, span, @@ -1749,7 +1790,7 @@ impl PageServerHandler { /// Coding discipline within this function: all interaction with the `pgb` connection /// needs to be sensitive to connection shutdown, currently signalled via [`Self::cancel`]. /// This is so that we can shutdown page_service quickly. 
- #[instrument(skip_all)] + #[instrument(skip_all, fields(hold_gc_cutoff_guard))] async fn handle_pagerequests( &mut self, pgb: &mut PostgresBackend, @@ -1795,6 +1836,30 @@ impl PageServerHandler { .take() .expect("implementation error: timeline_handles should not be locked"); + // Evaluate the expensive feature resolver check once per pagestream subprotocol handling + // instead of once per GetPage request. This is shared between pipelined and serial paths. + let hold_gc_cutoff_guard = if cfg!(test) || cfg!(feature = "testing") { + HoldAppliedGcCutoffGuard::Yes + } else { + // Use the global feature resolver with the tenant ID directly, avoiding the need + // to get a timeline/shard which might not be available on this pageserver node. + let empty_properties = std::collections::HashMap::new(); + match self.feature_resolver.evaluate_boolean( + "page-service-getpage-hold-applied-gc-cutoff-guard", + tenant_id, + &empty_properties, + ) { + Ok(()) => HoldAppliedGcCutoffGuard::Yes, + Err(_) => HoldAppliedGcCutoffGuard::No, + } + }; + // Record it in the span of handle_pagerequests so that both the request_span + // and the pipeline implementation spans contain the field. + Span::current().record( + "hold_gc_cutoff_guard", + tracing::field::debug(&hold_gc_cutoff_guard), + ); + let request_span = info_span!("request"); let ((pgb_reader, timeline_handles), result) = match self.pipelining_config.clone() { PageServicePipeliningConfig::Pipelined(pipelining_config) => { @@ -1808,6 +1873,7 @@ impl PageServerHandler { pipelining_config, protocol_version, io_concurrency, + hold_gc_cutoff_guard, &ctx, ) .await @@ -1822,6 +1888,7 @@ impl PageServerHandler { request_span, protocol_version, io_concurrency, + hold_gc_cutoff_guard, &ctx, ) .await @@ -1850,6 +1917,7 @@ impl PageServerHandler { request_span: Span, protocol_version: PagestreamProtocolVersion, io_concurrency: IoConcurrency, + hold_gc_cutoff_guard: HoldAppliedGcCutoffGuard, ctx: &RequestContext, ) -> ( (PostgresBackendReader, TimelineHandles), @@ -1871,6 +1939,7 @@ impl PageServerHandler { ctx, protocol_version, request_span.clone(), + hold_gc_cutoff_guard, ) .await; let msg = match msg { @@ -1918,6 +1987,7 @@ impl PageServerHandler { pipelining_config: PageServicePipeliningConfigPipelined, protocol_version: PagestreamProtocolVersion, io_concurrency: IoConcurrency, + hold_gc_cutoff_guard: HoldAppliedGcCutoffGuard, ctx: &RequestContext, ) -> ( (PostgresBackendReader, TimelineHandles), @@ -2021,6 +2091,7 @@ impl PageServerHandler { &ctx, protocol_version, request_span.clone(), + hold_gc_cutoff_guard, ) .await; let Some(read_res) = read_res.transpose() else { @@ -2067,6 +2138,7 @@ impl PageServerHandler { pages, span: _, shard: _, + applied_gc_cutoff_guard: _, batch_break_reason: _, } = &mut batch { @@ -3352,18 +3424,6 @@ impl GrpcPageServiceHandler { Ok(CancellableTask { task, cancel }) } - /// Errors if the request is executed on a non-zero shard. Only shard 0 has a complete view of - /// relations and their sizes, as well as SLRU segments and similar data. - #[allow(clippy::result_large_err)] - fn ensure_shard_zero(timeline: &Handle) -> Result<(), tonic::Status> { - match timeline.get_shard_index().shard_number.0 { - 0 => Ok(()), - shard => Err(tonic::Status::invalid_argument(format!( - "request must execute on shard zero (is shard {shard})", - ))), - } - } - - /// Generates a PagestreamRequest header from a ReadLsn and request ID.
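The hold_gc_cutoff_guard decision above is made once per pagestream connection and then recorded on the enclosing span, so the value is visible in the context of every request handled on that connection. A compressed sketch of that shape; the resolver call is replaced by a plain closure and the names are illustrative, not the real API surface:

use tracing::{Span, info_span};

#[derive(Debug, Clone, Copy)]
enum HoldGuard {
    Yes,
    No,
}

fn start_pagestream(flag_is_enabled: impl Fn() -> bool) {
    let span = info_span!("handle_pagerequests", hold_gc_cutoff_guard = tracing::field::Empty);
    let _entered = span.enter();

    // Evaluate once per connection, not once per GetPage request.
    let hold = if flag_is_enabled() { HoldGuard::Yes } else { HoldGuard::No };

    // Record on the current span so the decision shows up alongside every
    // request span created while this connection is being served.
    Span::current().record("hold_gc_cutoff_guard", tracing::field::debug(&hold));
}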
fn make_hdr( read_lsn: page_api::ReadLsn, @@ -3378,30 +3438,72 @@ impl GrpcPageServiceHandler { } } - /// Acquires a timeline handle for the given request. + /// Acquires a timeline handle for the given request. The shard index must match a local shard. /// - /// TODO: during shard splits, the compute may still be sending requests to the parent shard - /// until the entire split is committed and the compute is notified. Consider installing a - /// temporary shard router from the parent to the children while the split is in progress. - /// - /// TODO: consider moving this to a middleware layer; all requests need it. Needs to manage - /// the TimelineHandles lifecycle. - /// - /// TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to avoid - /// the unnecessary overhead. + /// NB: this will fail during shard splits, see comment on [`Self::maybe_split_get_page`]. async fn get_request_timeline( &self, req: &tonic::Request, ) -> Result, GetActiveTimelineError> { - let ttid = *extract::(req); + let TenantTimelineId { + tenant_id, + timeline_id, + } = *extract::(req); let shard_index = *extract::(req); - let shard_selector = ShardSelector::Known(shard_index); + // TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to + // avoid the unnecessary overhead. TimelineHandles::new(self.tenant_manager.clone()) - .get(ttid.tenant_id, ttid.timeline_id, shard_selector) + .get(tenant_id, timeline_id, ShardSelector::Known(shard_index)) .await } + /// Acquires a timeline handle for the given request, which must be for shard zero. Most + /// metadata requests are only valid on shard zero. + /// + /// NB: during an ongoing shard split, the compute will keep talking to the parent shard until + /// the split is committed, but the parent shard may have been removed in the meanwhile. In that + /// case, we reroute the request to the new child shard. See [`Self::maybe_split_get_page`]. + /// + /// TODO: revamp the split protocol to avoid this child routing. + async fn get_request_timeline_shard_zero( + &self, + req: &tonic::Request, + ) -> Result, tonic::Status> { + let TenantTimelineId { + tenant_id, + timeline_id, + } = *extract::(req); + let shard_index = *extract::(req); + + if shard_index.shard_number.0 != 0 { + return Err(tonic::Status::invalid_argument(format!( + "request only valid on shard zero (requested shard {shard_index})", + ))); + } + + // TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to + // avoid the unnecessary overhead. + let mut handles = TimelineHandles::new(self.tenant_manager.clone()); + match handles + .get(tenant_id, timeline_id, ShardSelector::Known(shard_index)) + .await + { + Ok(timeline) => Ok(timeline), + Err(err) => { + // We may be in the middle of a shard split. Try to find a child shard 0. + if let Ok(timeline) = handles + .get(tenant_id, timeline_id, ShardSelector::Zero) + .await + && timeline.get_shard_index().shard_count > shard_index.shard_count + { + return Ok(timeline); + } + Err(err.into()) + } + } + } + /// Starts a SmgrOpTimer at received_at, throttles the request, and records execution start. /// Only errors if the timeline is shutting down. /// @@ -3428,32 +3530,37 @@ impl GrpcPageServiceHandler { /// NB: errors returned from here are intercepted in get_pages(), and may be converted to a /// GetPageResponse with an appropriate status code to avoid terminating the stream. /// - /// TODO: verify that the requested pages belong to this shard. 
- /// /// TODO: get_vectored() currently enforces a batch limit of 32. Postgres will typically send /// batches up to effective_io_concurrency = 100. Either we have to accept large batches, or /// split them up in the client or server. - #[instrument(skip_all, fields(req_id, rel, blkno, blks, req_lsn, mod_lsn))] + #[instrument(skip_all, fields( + req_id = %req.request_id, + rel = %req.rel, + blkno = %req.block_numbers[0], + blks = %req.block_numbers.len(), + lsn = %req.read_lsn, + ))] async fn get_page( ctx: &RequestContext, - timeline: &WeakHandle, - req: proto::GetPageRequest, + timeline: Handle, + req: page_api::GetPageRequest, io_concurrency: IoConcurrency, - ) -> Result { - let received_at = Instant::now(); - let timeline = timeline.upgrade()?; + received_at: Instant, + ) -> Result { let ctx = ctx.with_scope_page_service_pagestream(&timeline); - // Validate the request, decorate the span, and convert it to a Pagestream request. - let req = page_api::GetPageRequest::try_from(req)?; - - span_record!( - req_id = %req.request_id, - rel = %req.rel, - blkno = %req.block_numbers[0], - blks = %req.block_numbers.len(), - lsn = %req.read_lsn, - ); + for &blkno in &req.block_numbers { + let shard = timeline.get_shard_identity(); + let key = rel_block_to_key(req.rel, blkno); + if !shard.is_key_local(&key) { + return Err(tonic::Status::invalid_argument(format!( + "block {blkno} of relation {} requested on wrong shard {} (is on {})", + req.rel, + timeline.get_shard_index(), + ShardIndex::new(shard.get_shard_number(&key), shard.count), + ))); + } + } let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); // hold guard let effective_lsn = PageServerHandler::effective_request_lsn( @@ -3529,7 +3636,89 @@ impl GrpcPageServiceHandler { }; } - Ok(resp.into()) + Ok(resp) + } + + /// Processes a GetPage request when there is a potential shard split in progress. We have to + /// reroute the request to any local child shards, and split batch requests that straddle + /// multiple child shards. + /// + /// Parent shards are split and removed incrementally (there may be many parent shards when + /// splitting an already-sharded tenant), but the compute is only notified once the overall + /// split commits, which can take several minutes. In the meanwhile, the compute will be sending + /// requests to the parent shards. + /// + /// TODO: add test infrastructure to provoke this situation frequently and for long periods of + /// time, to properly exercise it. + /// + /// TODO: revamp the split protocol to avoid this, e.g.: + /// * Keep the parent shard until the split commits and the compute is notified. + /// * Notify the compute about each subsplit. + /// * Return an error that updates the compute's shard map. + #[instrument(skip_all)] + #[allow(clippy::too_many_arguments)] + async fn maybe_split_get_page( + ctx: &RequestContext, + handles: &mut TimelineHandles, + tenant_id: TenantId, + timeline_id: TimelineId, + parent: ShardIndex, + req: page_api::GetPageRequest, + io_concurrency: IoConcurrency, + received_at: Instant, + ) -> Result { + // Check the first page to see if we have any child shards at all. Otherwise, the compute is + // just talking to the wrong Pageserver. If the parent has been split, the shard now owning + // the page must have a higher shard count. 
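get_page() above now rejects any block that does not belong to the shard serving the request. A schematic version of that validation; owning_shard is a hypothetical stand-in for the real key-to-shard mapping (the pageserver derives it from the key and the tenant's stripe size via rel_block_to_key/is_key_local):

struct ShardId {
    number: u32,
    count: u32,
}

// Hypothetical mapping for illustration only; not the real placement function.
fn owning_shard(rel_key: u64, block: u32, shard_count: u32) -> u32 {
    ((rel_key ^ block as u64) % shard_count.max(1) as u64) as u32
}

fn validate_blocks(local: &ShardId, rel_key: u64, blocks: &[u32]) -> Result<(), String> {
    for &blkno in blocks {
        let owner = owning_shard(rel_key, blkno, local.count);
        if owner != local.number {
            // Mirrors the invalid_argument error in the diff: the compute asked
            // the wrong shard for this page.
            return Err(format!(
                "block {blkno} requested on wrong shard {} (owned by shard {owner})",
                local.number
            ));
        }
    }
    Ok(())
}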
+ let timeline = handles + .get( + tenant_id, + timeline_id, + ShardSelector::Page(rel_block_to_key(req.rel, req.block_numbers[0])), + ) + .await?; + + let shard_id = timeline.get_shard_identity(); + if shard_id.count <= parent.shard_count { + return Err(HandleUpgradeError::ShutDown.into()); // emulate original error + } + + // Fast path: the request fits in a single shard. + if let Some(shard_index) = + GetPageSplitter::for_single_shard(&req, shard_id.count, Some(shard_id.stripe_size))? + { + // We got the shard ID from the first page, so these must be equal. + assert_eq!(shard_index.shard_number, shard_id.number); + assert_eq!(shard_index.shard_count, shard_id.count); + return Self::get_page(ctx, timeline, req, io_concurrency, received_at).await; + } + + // The request spans multiple shards; split it and dispatch parallel requests. All pages + // were originally in the parent shard, and during a split all children are local, so we + // expect to find local shards for all pages. + let mut splitter = GetPageSplitter::split(req, shard_id.count, Some(shard_id.stripe_size))?; + + let mut shard_requests = FuturesUnordered::new(); + for (shard_index, shard_req) in splitter.drain_requests() { + let timeline = handles + .get(tenant_id, timeline_id, ShardSelector::Known(shard_index)) + .await?; + let future = Self::get_page( + ctx, + timeline, + shard_req, + io_concurrency.clone(), + received_at, + ) + .map(move |result| result.map(|resp| (shard_index, resp))); + shard_requests.push(future); + } + + while let Some((shard_index, shard_response)) = shard_requests.next().await.transpose()? { + splitter.add_response(shard_index, shard_response)?; + } + + Ok(splitter.collect_response()?) } } @@ -3558,11 +3747,10 @@ impl proto::PageService for GrpcPageServiceHandler { // to be the sweet spot where throughput is saturated. const CHUNK_SIZE: usize = 256 * 1024; - let timeline = self.get_request_timeline(&req).await?; + let timeline = self.get_request_timeline_shard_zero(&req).await?; let ctx = self.ctx.with_scope_timeline(&timeline); // Validate the request and decorate the span. - Self::ensure_shard_zero(&timeline)?; if timeline.is_archived() == Some(true) { return Err(tonic::Status::failed_precondition("timeline is archived")); } @@ -3678,11 +3866,10 @@ impl proto::PageService for GrpcPageServiceHandler { req: tonic::Request, ) -> Result, tonic::Status> { let received_at = extract::(&req).0; - let timeline = self.get_request_timeline(&req).await?; + let timeline = self.get_request_timeline_shard_zero(&req).await?; let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); // Validate the request, decorate the span, and convert it to a Pagestream request. - Self::ensure_shard_zero(&timeline)?; let req: page_api::GetDbSizeRequest = req.into_inner().try_into()?; span_record!(db_oid=%req.db_oid, lsn=%req.read_lsn); @@ -3711,14 +3898,29 @@ impl proto::PageService for GrpcPageServiceHandler { req: tonic::Request>, ) -> Result, tonic::Status> { // Extract the timeline from the request and check that it exists. - let ttid = *extract::(&req); + // + // NB: during shard splits, the compute may still send requests to the parent shard. We'll + // reroute requests to the child shards below, but we also detect the common cases here + // where either the shard exists or no shards exist at all. If we have a child shard, we + // can't acquire a weak handle because we don't know which child shard to use yet. 
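maybe_split_get_page() above fans a straddling batch out to the owning child shards and reassembles the responses. A self-contained scatter-gather sketch using FuturesUnordered; the per-shard request and page types are simplified stand-ins for GetPageSplitter and the real response types:

use std::collections::HashMap;

use futures::stream::{FuturesUnordered, StreamExt};

// Stand-in for Self::get_page() against one child shard.
async fn fetch_from_shard(shard: u32, blocks: Vec<u32>) -> Result<(u32, Vec<Vec<u8>>), String> {
    Ok((shard, blocks.into_iter().map(|_| vec![0u8; 8192]).collect()))
}

async fn scatter_gather(
    per_shard: HashMap<u32, Vec<u32>>,
) -> Result<HashMap<u32, Vec<Vec<u8>>>, String> {
    // Dispatch one sub-request per shard; they run concurrently.
    let mut in_flight = FuturesUnordered::new();
    for (shard, blocks) in per_shard {
        in_flight.push(fetch_from_shard(shard, blocks));
    }

    // Completion order is arbitrary; the shard id carried in each result lets the
    // caller stitch the batch back together afterwards.
    let mut responses = HashMap::new();
    while let Some(result) = in_flight.next().await {
        let (shard, pages) = result?;
        responses.insert(shard, pages);
    }
    Ok(responses)
}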
+ let TenantTimelineId { + tenant_id, + timeline_id, + } = *extract::(&req); let shard_index = *extract::(&req); - let shard_selector = ShardSelector::Known(shard_index); let mut handles = TimelineHandles::new(self.tenant_manager.clone()); - handles - .get(ttid.tenant_id, ttid.timeline_id, shard_selector) - .await?; + let timeline = match handles + .get(tenant_id, timeline_id, ShardSelector::Known(shard_index)) + .await + { + // The timeline shard exists. Keep a weak handle to reuse for each request. + Ok(timeline) => Some(timeline.downgrade()), + // The shard doesn't exist, but a child shard does. We'll reroute requests later. + Err(_) if self.tenant_manager.has_child_shard(tenant_id, shard_index) => None, + // Failed to fetch the timeline, and no child shard exists. Error out. + Err(err) => return Err(err.into()), + }; // Spawn an IoConcurrency sidecar, if enabled. let gate_guard = self @@ -3735,11 +3937,9 @@ impl proto::PageService for GrpcPageServiceHandler { let mut reqs = req.into_inner(); let resps = async_stream::try_stream! { - let timeline = handles - .get(ttid.tenant_id, ttid.timeline_id, shard_selector) - .await? - .downgrade(); loop { + // Wait for the next client request. + // // NB: Tonic considers the entire stream to be an in-flight request and will wait // for it to complete before shutting down. React to cancellation between requests. let req = tokio::select! { @@ -3752,16 +3952,44 @@ impl proto::PageService for GrpcPageServiceHandler { Err(err) => Err(err), }, }?; + + let received_at = Instant::now(); let req_id = req.request_id.map(page_api::RequestID::from).unwrap_or_default(); - let result = Self::get_page(&ctx, &timeline, req, io_concurrency.clone()) + + // Process the request, using a closure to capture errors. + let process_request = async || { + let req = page_api::GetPageRequest::try_from(req)?; + + // Fast path: use the pre-acquired timeline handle. + if let Some(Ok(timeline)) = timeline.as_ref().map(|t| t.upgrade()) { + return Self::get_page(&ctx, timeline, req, io_concurrency.clone(), received_at) + .instrument(span.clone()) // propagate request span + .await + } + + // The timeline handle is stale. During shard splits, the compute may still be + // sending requests to the parent shard. Try to re-route requests to the child + // shards, and split any batch requests that straddle multiple child shards. + Self::maybe_split_get_page( + &ctx, + &mut handles, + tenant_id, + timeline_id, + shard_index, + req, + io_concurrency.clone(), + received_at, + ) .instrument(span.clone()) // propagate request span - .await; - yield match result { - Ok(resp) => resp, - // Convert per-request errors to GetPageResponses as appropriate, or terminate - // the stream with a tonic::Status. Log the error regardless, since - // ObservabilityLayer can't automatically log stream errors. + .await + }; + + // Return the response. Convert per-request errors to GetPageResponses if + // appropriate, or terminate the stream with a tonic::Status. + yield match process_request().await { + Ok(resp) => resp.into(), Err(status) => { + // Log the error, since ObservabilityLayer won't see stream errors. // TODO: it would be nice if we could propagate the get_page() fields here. 
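The get_pages stream handler above keeps a weak timeline handle resolved at stream start and upgrades it per request; only when that upgrade fails, or no handle could be resolved because the parent shard is mid-split, does it fall back to re-resolving and possibly splitting the request. A sketch of that fast-path/fallback structure with simplified types:

use std::sync::{Arc, Weak};

struct Timeline;

// Stand-in for the shard-split rerouting path (maybe_split_get_page).
async fn resolve_again() -> Result<Arc<Timeline>, String> {
    Ok(Arc::new(Timeline))
}

async fn timeline_for_request(cached: Option<&Weak<Timeline>>) -> Result<Arc<Timeline>, String> {
    if let Some(strong) = cached.and_then(|weak| weak.upgrade()) {
        // Fast path: the shard resolved at stream start is still alive.
        return Ok(strong);
    }
    // Slow path: the handle is stale or was never available; re-resolve, which
    // may route to a child shard during an ongoing split.
    resolve_again().await
}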
span.in_scope(|| { warn!("request failed with {:?}: {}", status.code(), status.message()); @@ -3781,11 +4009,10 @@ impl proto::PageService for GrpcPageServiceHandler { req: tonic::Request, ) -> Result, tonic::Status> { let received_at = extract::(&req).0; - let timeline = self.get_request_timeline(&req).await?; + let timeline = self.get_request_timeline_shard_zero(&req).await?; let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); // Validate the request, decorate the span, and convert it to a Pagestream request. - Self::ensure_shard_zero(&timeline)?; let req: page_api::GetRelSizeRequest = req.into_inner().try_into()?; let allow_missing = req.allow_missing; @@ -3818,11 +4045,10 @@ impl proto::PageService for GrpcPageServiceHandler { req: tonic::Request, ) -> Result, tonic::Status> { let received_at = extract::(&req).0; - let timeline = self.get_request_timeline(&req).await?; + let timeline = self.get_request_timeline_shard_zero(&req).await?; let ctx = self.ctx.with_scope_page_service_pagestream(&timeline); // Validate the request, decorate the span, and convert it to a Pagestream request. - Self::ensure_shard_zero(&timeline)?; let req: page_api::GetSlruSegmentRequest = req.into_inner().try_into()?; span_record!(kind=%req.kind, segno=%req.segno, lsn=%req.read_lsn); @@ -3852,6 +4078,10 @@ impl proto::PageService for GrpcPageServiceHandler { &self, req: tonic::Request, ) -> Result, tonic::Status> { + // TODO: this won't work during shard splits, as the request is directed at a specific shard + // but the parent shard is removed before the split commits and the compute is notified + // (which can take several minutes for large tenants). That's also the case for the libpq + // implementation, so we keep the behavior for now. let timeline = self.get_request_timeline(&req).await?; let ctx = self.ctx.with_scope_timeline(&timeline); diff --git a/pageserver/src/tenant/mgr.rs b/pageserver/src/tenant/mgr.rs index b47bab16d8..0feba5e9c8 100644 --- a/pageserver/src/tenant/mgr.rs +++ b/pageserver/src/tenant/mgr.rs @@ -826,6 +826,18 @@ impl TenantManager { peek_slot.is_some() } + /// Returns whether a local shard exists that's a child of the given tenant shard. Note that + /// this just checks for any shard with a larger shard count, and it may not be a direct child + /// of the given shard (their keyspace may not overlap). + pub(crate) fn has_child_shard(&self, tenant_id: TenantId, shard_index: ShardIndex) -> bool { + match &*self.tenants.read().unwrap() { + TenantsMap::Initializing => false, + TenantsMap::Open(slots) | TenantsMap::ShuttingDown(slots) => slots + .range(TenantShardId::tenant_range(tenant_id)) + .any(|(tsid, _)| tsid.shard_count > shard_index.shard_count), + } + } + #[instrument(skip_all, fields(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug()))] pub(crate) async fn upsert_location( &self, @@ -1522,6 +1534,13 @@ impl TenantManager { self.resources.deletion_queue_client.flush_advisory(); // Phase 2: Put the parent shard to InProgress and grab a reference to the parent Tenant + // + // TODO: keeping the parent as InProgress while spawning the children causes read + // unavailability, as we can't acquire a new timeline handle for it (existing handles appear + // to still work though, even downgraded ones). The parent should be available for reads + // until the children are ready -- potentially until *all* subsplits across all parent + // shards are complete and the compute has been notified. See: + // . 
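has_child_shard() above only scans the tenant's slot range for any shard with a larger shard count; it does not check that the key spaces actually overlap. The same range scan over an ordered map, sketched with a simplified key type in place of TenantShardId:

use std::collections::BTreeMap;

#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
struct TenantShardKey {
    tenant_id: u128,
    shard_number: u8,
    shard_count: u8,
}

fn has_child_shard(
    slots: &BTreeMap<TenantShardKey, ()>,
    tenant_id: u128,
    shard_count: u8,
) -> bool {
    // Range over every slot belonging to this tenant, regardless of shard.
    let lo = TenantShardKey { tenant_id, shard_number: 0, shard_count: 0 };
    let hi = TenantShardKey { tenant_id, shard_number: u8::MAX, shard_count: u8::MAX };
    // Any slot with a larger shard count than the queried shard counts as a child.
    slots.range(lo..=hi).any(|(key, _)| key.shard_count > shard_count)
}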
drop(tenant); let mut parent_slot_guard = self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?; diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index f82c47ec3a..ff66b0ecc8 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -70,7 +70,7 @@ use tracing::*; use utils::generation::Generation; use utils::guard_arc_swap::GuardArcSwap; use utils::id::TimelineId; -use utils::logging::{MonitorSlowFutureCallback, monitor_slow_future}; +use utils::logging::{MonitorSlowFutureCallback, log_slow, monitor_slow_future}; use utils::lsn::{AtomicLsn, Lsn, RecordLsn}; use utils::postgres_client::PostgresClientProtocol; use utils::rate_limit::RateLimit; @@ -6915,7 +6915,13 @@ impl Timeline { write_guard.store_and_unlock(new_gc_cutoff) }; - waitlist.wait().await; + let waitlist_wait_fut = std::pin::pin!(waitlist.wait()); + log_slow( + "applied_gc_cutoff waitlist wait", + Duration::from_secs(30), + waitlist_wait_fut, + ) + .await; info!("GC starting"); diff --git a/pageserver/src/tenant/timeline/handle.rs b/pageserver/src/tenant/timeline/handle.rs index 3570cab301..537b9ff373 100644 --- a/pageserver/src/tenant/timeline/handle.rs +++ b/pageserver/src/tenant/timeline/handle.rs @@ -224,11 +224,11 @@ use tracing::{instrument, trace}; use utils::id::TimelineId; use utils::shard::{ShardIndex, ShardNumber}; -use crate::tenant::mgr::ShardSelector; +use crate::page_service::GetActiveTimelineError; +use crate::tenant::GetTimelineError; +use crate::tenant::mgr::{GetActiveTenantError, ShardSelector}; -/// The requirement for Debug is so that #[derive(Debug)] works in some places. -pub(crate) trait Types: Sized + std::fmt::Debug { - type TenantManagerError: Sized + std::fmt::Debug; +pub(crate) trait Types: Sized { type TenantManager: TenantManager + Sized; type Timeline: Timeline + Sized; } @@ -307,12 +307,11 @@ impl Default for PerTimelineState { /// Abstract view of [`crate::tenant::mgr`], for testability. pub(crate) trait TenantManager { /// Invoked by [`Cache::get`] to resolve a [`ShardTimelineId`] to a [`Types::Timeline`]. - /// Errors are returned as [`GetError::TenantManager`]. async fn resolve( &self, timeline_id: TimelineId, shard_selector: ShardSelector, - ) -> Result; + ) -> Result; } /// Abstract view of an [`Arc`], for testability. @@ -322,13 +321,6 @@ pub(crate) trait Timeline { fn per_timeline_state(&self) -> &PerTimelineState; } -/// Errors returned by [`Cache::get`]. -#[derive(Debug)] -pub(crate) enum GetError { - TenantManager(T::TenantManagerError), - PerTimelineStateShutDown, -} - /// Internal type used in [`Cache::get`]. enum RoutingResult { FastPath(Handle), @@ -345,7 +337,7 @@ impl Cache { timeline_id: TimelineId, shard_selector: ShardSelector, tenant_manager: &T::TenantManager, - ) -> Result, GetError> { + ) -> Result, GetActiveTimelineError> { const GET_MAX_RETRIES: usize = 10; const RETRY_BACKOFF: Duration = Duration::from_millis(100); let mut attempt = 0; @@ -356,7 +348,11 @@ impl Cache { .await { Ok(handle) => return Ok(handle), - Err(e) => { + Err( + e @ GetActiveTimelineError::Tenant(GetActiveTenantError::WaitForActiveTimeout { + .. 
+ }), + ) => { // Retry on tenant manager error to handle tenant split more gracefully if attempt < GET_MAX_RETRIES { tokio::time::sleep(RETRY_BACKOFF).await; @@ -370,6 +366,7 @@ impl Cache { return Err(e); } } + Err(err) => return Err(err), } } } @@ -388,7 +385,7 @@ impl Cache { timeline_id: TimelineId, shard_selector: ShardSelector, tenant_manager: &T::TenantManager, - ) -> Result, GetError> { + ) -> Result, GetActiveTimelineError> { // terminates because when every iteration we remove an element from the map let miss: ShardSelector = loop { let routing_state = self.shard_routing(timeline_id, shard_selector); @@ -468,60 +465,50 @@ impl Cache { timeline_id: TimelineId, shard_selector: ShardSelector, tenant_manager: &T::TenantManager, - ) -> Result, GetError> { - match tenant_manager.resolve(timeline_id, shard_selector).await { - Ok(timeline) => { - let key = timeline.shard_timeline_id(); - match &shard_selector { - ShardSelector::Zero => assert_eq!(key.shard_index.shard_number, ShardNumber(0)), - ShardSelector::Page(_) => (), // gotta trust tenant_manager - ShardSelector::Known(idx) => assert_eq!(idx, &key.shard_index), - } - - trace!("creating new HandleInner"); - let timeline = Arc::new(timeline); - let handle_inner_arc = - Arc::new(Mutex::new(HandleInner::Open(Arc::clone(&timeline)))); - let handle_weak = WeakHandle { - inner: Arc::downgrade(&handle_inner_arc), - }; - let handle = handle_weak - .upgrade() - .ok() - .expect("we just created it and it's not linked anywhere yet"); - { - let mut lock_guard = timeline - .per_timeline_state() - .handles - .lock() - .expect("mutex poisoned"); - match &mut *lock_guard { - Some(per_timeline_state) => { - let replaced = - per_timeline_state.insert(self.id, Arc::clone(&handle_inner_arc)); - assert!(replaced.is_none(), "some earlier code left a stale handle"); - match self.map.entry(key) { - hash_map::Entry::Occupied(_o) => { - // This cannot not happen because - // 1. we're the _miss_ handle, i.e., `self.map` didn't contain an entry and - // 2. we were holding &mut self during .resolve().await above, so, no other thread can have inserted a handle - // while we were waiting for the tenant manager. 
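Cache::get() above now retries only the wait-for-active timeout variant (the transient error expected around a tenant/shard split), with a fixed backoff and a bounded number of attempts; every other error is returned immediately. The retry skeleton in isolation, with a placeholder error type:

use std::time::Duration;

#[derive(Debug)]
enum GetError {
    WaitForActiveTimeout,
    Other(String),
}

async fn get_with_retry<T, F, Fut>(mut attempt_once: F) -> Result<T, GetError>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, GetError>>,
{
    const MAX_RETRIES: usize = 10;
    const BACKOFF: Duration = Duration::from_millis(100);

    let mut attempts = 0;
    loop {
        match attempt_once().await {
            Ok(v) => return Ok(v),
            // Only this variant is worth retrying; it is expected to clear up
            // once the split settles.
            Err(GetError::WaitForActiveTimeout) if attempts < MAX_RETRIES => {
                attempts += 1;
                tokio::time::sleep(BACKOFF).await;
            }
            Err(e) => return Err(e),
        }
    }
}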
- unreachable!() - } - hash_map::Entry::Vacant(v) => { - v.insert(handle_weak); - } - } - } - None => { - return Err(GetError::PerTimelineStateShutDown); - } - } - } - Ok(handle) - } - Err(e) => Err(GetError::TenantManager(e)), + ) -> Result, GetActiveTimelineError> { + let timeline = tenant_manager.resolve(timeline_id, shard_selector).await?; + let key = timeline.shard_timeline_id(); + match &shard_selector { + ShardSelector::Zero => assert_eq!(key.shard_index.shard_number, ShardNumber(0)), + ShardSelector::Page(_) => (), // gotta trust tenant_manager + ShardSelector::Known(idx) => assert_eq!(idx, &key.shard_index), } + + trace!("creating new HandleInner"); + let timeline = Arc::new(timeline); + let handle_inner_arc = Arc::new(Mutex::new(HandleInner::Open(Arc::clone(&timeline)))); + let handle_weak = WeakHandle { + inner: Arc::downgrade(&handle_inner_arc), + }; + let handle = handle_weak + .upgrade() + .ok() + .expect("we just created it and it's not linked anywhere yet"); + let mut lock_guard = timeline + .per_timeline_state() + .handles + .lock() + .expect("mutex poisoned"); + let Some(per_timeline_state) = &mut *lock_guard else { + return Err(GetActiveTimelineError::Timeline( + GetTimelineError::ShuttingDown, + )); + }; + let replaced = per_timeline_state.insert(self.id, Arc::clone(&handle_inner_arc)); + assert!(replaced.is_none(), "some earlier code left a stale handle"); + match self.map.entry(key) { + hash_map::Entry::Occupied(_o) => { + // This cannot not happen because + // 1. we're the _miss_ handle, i.e., `self.map` didn't contain an entry and + // 2. we were holding &mut self during .resolve().await above, so, no other thread can have inserted a handle + // while we were waiting for the tenant manager. + unreachable!() + } + hash_map::Entry::Vacant(v) => { + v.insert(handle_weak); + } + } + Ok(handle) } } @@ -655,7 +642,8 @@ mod tests { use pageserver_api::models::ShardParameters; use pageserver_api::reltag::RelTag; use pageserver_api::shard::DEFAULT_STRIPE_SIZE; - use utils::shard::ShardCount; + use utils::id::TenantId; + use utils::shard::{ShardCount, TenantShardId}; use utils::sync::gate::GateGuard; use super::*; @@ -665,7 +653,6 @@ mod tests { #[derive(Debug)] struct TestTypes; impl Types for TestTypes { - type TenantManagerError = anyhow::Error; type TenantManager = StubManager; type Timeline = Entered; } @@ -716,40 +703,48 @@ mod tests { &self, timeline_id: TimelineId, shard_selector: ShardSelector, - ) -> anyhow::Result { + ) -> Result { + fn enter_gate( + timeline: &StubTimeline, + ) -> Result, GetActiveTimelineError> { + Ok(Arc::new(timeline.gate.enter().map_err(|_| { + GetActiveTimelineError::Timeline(GetTimelineError::ShuttingDown) + })?)) + } + for timeline in &self.shards { if timeline.id == timeline_id { - let enter_gate = || { - let gate_guard = timeline.gate.enter()?; - let gate_guard = Arc::new(gate_guard); - anyhow::Ok(gate_guard) - }; match &shard_selector { ShardSelector::Zero if timeline.shard.is_shard_zero() => { return Ok(Entered { timeline: Arc::clone(timeline), - gate_guard: enter_gate()?, + gate_guard: enter_gate(timeline)?, }); } ShardSelector::Zero => continue, ShardSelector::Page(key) if timeline.shard.is_key_local(key) => { return Ok(Entered { timeline: Arc::clone(timeline), - gate_guard: enter_gate()?, + gate_guard: enter_gate(timeline)?, }); } ShardSelector::Page(_) => continue, ShardSelector::Known(idx) if idx == &timeline.shard.shard_index() => { return Ok(Entered { timeline: Arc::clone(timeline), - gate_guard: enter_gate()?, + gate_guard: 
enter_gate(timeline)?, }); } ShardSelector::Known(_) => continue, } } } - anyhow::bail!("not found") + Err(GetActiveTimelineError::Timeline( + GetTimelineError::NotFound { + tenant_id: TenantShardId::unsharded(TenantId::from([0; 16])), + timeline_id, + }, + )) } } diff --git a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs index bba9e51cdc..7ec5aa3b77 100644 --- a/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs +++ b/pageserver/src/tenant/timeline/walreceiver/walreceiver_connection.rs @@ -16,7 +16,6 @@ use postgres_connection::PgConnectionConfig; use postgres_ffi::WAL_SEGMENT_SIZE; use postgres_ffi::v14::xlog_utils::normalize_lsn; use postgres_ffi::waldecoder::WalDecodeError; -use postgres_ffi::waldecoder::WalStreamDecoder; use postgres_protocol::message::backend::ReplicationMessage; use postgres_types::PgLsn; use tokio::sync::watch; @@ -32,7 +31,7 @@ use utils::lsn::Lsn; use utils::pageserver_feedback::PageserverFeedback; use utils::postgres_client::PostgresClientProtocol; use utils::sync::gate::GateError; -use wal_decoder::models::{FlushUncommittedRecords, InterpretedWalRecord, InterpretedWalRecords}; +use wal_decoder::models::{FlushUncommittedRecords, InterpretedWalRecords}; use wal_decoder::wire_format::FromWireFormat; use super::TaskStateUpdate; @@ -276,8 +275,6 @@ pub(super) async fn handle_walreceiver_connection( let copy_stream = replication_client.copy_both_simple(&query).await?; let mut physical_stream = pin!(ReplicationStream::new(copy_stream)); - let mut waldecoder = WalStreamDecoder::new(startpoint, timeline.pg_version); - let mut walingest = WalIngest::new(timeline.as_ref(), startpoint, &ctx) .await .map_err(|e| match e.kind { @@ -285,8 +282,6 @@ pub(super) async fn handle_walreceiver_connection( _ => WalReceiverError::Other(e.into()), })?; - let shard = vec![*timeline.get_shard_identity()]; - let (format, compression) = match protocol { PostgresClientProtocol::Interpreted { format, @@ -517,143 +512,6 @@ pub(super) async fn handle_walreceiver_connection( Some(streaming_lsn) } - ReplicationMessage::XLogData(xlog_data) => { - async fn commit( - modification: &mut DatadirModification<'_>, - uncommitted: &mut u64, - filtered: &mut u64, - ctx: &RequestContext, - ) -> anyhow::Result<()> { - let stats = modification.stats(); - modification.commit(ctx).await?; - WAL_INGEST - .records_committed - .inc_by(*uncommitted - *filtered); - WAL_INGEST.inc_values_committed(&stats); - *uncommitted = 0; - *filtered = 0; - Ok(()) - } - - // Pass the WAL data to the decoder, and see if we can decode - // more records as a result. - let data = xlog_data.data(); - let startlsn = Lsn::from(xlog_data.wal_start()); - let endlsn = startlsn + data.len() as u64; - - trace!("received XLogData between {startlsn} and {endlsn}"); - - WAL_INGEST.bytes_received.inc_by(data.len() as u64); - waldecoder.feed_bytes(data); - - { - let mut modification = timeline.begin_modification(startlsn); - let mut uncommitted_records = 0; - let mut filtered_records = 0; - - while let Some((next_record_lsn, recdata)) = waldecoder.poll_decode()? { - // It is important to deal with the aligned records as lsn in getPage@LSN is - // aligned and can be several bytes bigger. Without this alignment we are - // at risk of hitting a deadlock. 
- if !next_record_lsn.is_aligned() { - return Err(WalReceiverError::Other(anyhow!("LSN not aligned"))); - } - - // Deserialize and interpret WAL record - let interpreted = InterpretedWalRecord::from_bytes_filtered( - recdata, - &shard, - next_record_lsn, - modification.tline.pg_version, - )? - .remove(timeline.get_shard_identity()) - .unwrap(); - - if matches!(interpreted.flush_uncommitted, FlushUncommittedRecords::Yes) - && uncommitted_records > 0 - { - // Special case: legacy PG database creations operate by reading pages from a 'template' database: - // these are the only kinds of WAL record that require reading data blocks while ingesting. Ensure - // all earlier writes of data blocks are visible by committing any modification in flight. - commit( - &mut modification, - &mut uncommitted_records, - &mut filtered_records, - &ctx, - ) - .await?; - } - - // Ingest the records without immediately committing them. - timeline.metrics.wal_records_received.inc(); - let ingested = walingest - .ingest_record(interpreted, &mut modification, &ctx) - .await - .with_context(|| { - format!("could not ingest record at {next_record_lsn}") - }) - .inspect_err(|err| { - // TODO: we can't differentiate cancellation errors with - // anyhow::Error, so just ignore it if we're cancelled. - if !cancellation.is_cancelled() { - critical_timeline!( - timeline.tenant_shard_id, - timeline.timeline_id, - Some(&timeline.corruption_detected), - "{err:?}" - ) - } - })?; - if !ingested { - tracing::debug!("ingest: filtered out record @ LSN {next_record_lsn}"); - WAL_INGEST.records_filtered.inc(); - filtered_records += 1; - } - - // FIXME: this cannot be made pausable_failpoint without fixing the - // failpoint library; in tests, the added amount of debugging will cause us - // to timeout the tests. - fail_point!("walreceiver-after-ingest"); - - last_rec_lsn = next_record_lsn; - - // Commit every ingest_batch_size records. Even if we filtered out - // all records, we still need to call commit to advance the LSN. - uncommitted_records += 1; - if uncommitted_records >= ingest_batch_size - || modification.approx_pending_bytes() - > DatadirModification::MAX_PENDING_BYTES - { - commit( - &mut modification, - &mut uncommitted_records, - &mut filtered_records, - &ctx, - ) - .await?; - } - } - - // Commit the remaining records. - if uncommitted_records > 0 { - commit( - &mut modification, - &mut uncommitted_records, - &mut filtered_records, - &ctx, - ) - .await?; - } - } - - if !caught_up && endlsn >= end_of_wal { - info!("caught up at LSN {endlsn}"); - caught_up = true; - } - - Some(endlsn) - } - ReplicationMessage::PrimaryKeepAlive(keepalive) => { let wal_end = keepalive.wal_end(); let timestamp = keepalive.timestamp(); diff --git a/pageserver/src/utilization.rs b/pageserver/src/utilization.rs index 0dafa5c4bb..cec28f8059 100644 --- a/pageserver/src/utilization.rs +++ b/pageserver/src/utilization.rs @@ -52,7 +52,7 @@ pub(crate) fn regenerate( }; // Express a static value for how many shards we may schedule on one node - const MAX_SHARDS: u32 = 5000; + const MAX_SHARDS: u32 = 2500; let mut doc = PageserverUtilization { disk_usage_bytes: used, diff --git a/pgxn/neon/communicator.c b/pgxn/neon/communicator.c index 5a08b3e331..4c03193d7e 100644 --- a/pgxn/neon/communicator.c +++ b/pgxn/neon/communicator.c @@ -79,10 +79,6 @@ #include "access/xlogrecovery.h" #endif -#if PG_VERSION_NUM < 160000 -typedef PGAlignedBlock PGIOAlignedBlock; -#endif - #define NEON_PANIC_CONNECTION_STATE(shard_no, elvl, message, ...) 
\ neon_shard_log(shard_no, elvl, "Broken connection state: " message, \ ##__VA_ARGS__) diff --git a/pgxn/neon/extension_server.c b/pgxn/neon/extension_server.c index 00dcb6920e..d64cd3e4af 100644 --- a/pgxn/neon/extension_server.c +++ b/pgxn/neon/extension_server.c @@ -14,7 +14,7 @@ #include "extension_server.h" #include "neon_utils.h" -static int extension_server_port = 0; +int hadron_extension_server_port = 0; static int extension_server_request_timeout = 60; static int extension_server_connect_timeout = 60; @@ -47,7 +47,7 @@ neon_download_extension_file_http(const char *filename, bool is_library) curl_easy_setopt(handle, CURLOPT_CONNECTTIMEOUT, (long)extension_server_connect_timeout /* seconds */ ); compute_ctl_url = psprintf("http://localhost:%d/extension_server/%s%s", - extension_server_port, filename, is_library ? "?is_library=true" : ""); + hadron_extension_server_port, filename, is_library ? "?is_library=true" : ""); elog(LOG, "Sending request to compute_ctl: %s", compute_ctl_url); @@ -82,7 +82,7 @@ pg_init_extension_server() DefineCustomIntVariable("neon.extension_server_port", "connection string to the compute_ctl", NULL, - &extension_server_port, + &hadron_extension_server_port, 0, 0, INT_MAX, PGC_POSTMASTER, 0, /* no flags required */ diff --git a/pgxn/neon/file_cache.c b/pgxn/neon/file_cache.c index 4da6c176cd..3c680eab86 100644 --- a/pgxn/neon/file_cache.c +++ b/pgxn/neon/file_cache.c @@ -635,6 +635,11 @@ lfc_init(void) NULL); } +/* + * Dump a list of pages that are currently in the LFC + * + * This is used to get a snapshot that can be used to prewarm the LFC later. + */ FileCacheState* lfc_get_state(size_t max_entries) { @@ -1827,125 +1832,46 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno, LWLockRelease(lfc_lock); } -typedef struct +/* + * Return metrics about the LFC. + * + * The return format is a palloc'd array of LfcStatsEntrys. The size + * of the returned array is returned in *num_entries. + */ +LfcStatsEntry * +lfc_get_stats(size_t *num_entries) { - TupleDesc tupdesc; -} NeonGetStatsCtx; + LfcStatsEntry *entries; + size_t n = 0; -#define NUM_NEON_GET_STATS_COLS 2 +#define MAX_ENTRIES 10 + entries = palloc(sizeof(LfcStatsEntry) * MAX_ENTRIES); -PG_FUNCTION_INFO_V1(neon_get_lfc_stats); -Datum -neon_get_lfc_stats(PG_FUNCTION_ARGS) -{ - FuncCallContext *funcctx; - NeonGetStatsCtx *fctx; - MemoryContext oldcontext; - TupleDesc tupledesc; - Datum result; - HeapTuple tuple; - char const *key; - uint64 value = 0; - Datum values[NUM_NEON_GET_STATS_COLS]; - bool nulls[NUM_NEON_GET_STATS_COLS]; + entries[n++] = (LfcStatsEntry) {"file_cache_chunk_size_pages", lfc_ctl == NULL, + lfc_ctl ? lfc_blocks_per_chunk : 0 }; + entries[n++] = (LfcStatsEntry) {"file_cache_misses", lfc_ctl == NULL, + lfc_ctl ? lfc_ctl->misses : 0}; + entries[n++] = (LfcStatsEntry) {"file_cache_hits", lfc_ctl == NULL, + lfc_ctl ? lfc_ctl->hits : 0 }; + entries[n++] = (LfcStatsEntry) {"file_cache_used", lfc_ctl == NULL, + lfc_ctl ? lfc_ctl->used : 0 }; + entries[n++] = (LfcStatsEntry) {"file_cache_writes", lfc_ctl == NULL, + lfc_ctl ? lfc_ctl->writes : 0 }; + entries[n++] = (LfcStatsEntry) {"file_cache_size", lfc_ctl == NULL, + lfc_ctl ? lfc_ctl->size : 0 }; + entries[n++] = (LfcStatsEntry) {"file_cache_used_pages", lfc_ctl == NULL, + lfc_ctl ? lfc_ctl->used_pages : 0 }; + entries[n++] = (LfcStatsEntry) {"file_cache_evicted_pages", lfc_ctl == NULL, + lfc_ctl ? lfc_ctl->evicted_pages : 0 }; + entries[n++] = (LfcStatsEntry) {"file_cache_limit", lfc_ctl == NULL, + lfc_ctl ? 
lfc_ctl->limit : 0 }; + entries[n++] = (LfcStatsEntry) {"file_cache_chunks_pinned", lfc_ctl == NULL, + lfc_ctl ? lfc_ctl->pinned : 0 }; + Assert(n <= MAX_ENTRIES); +#undef MAX_ENTRIES - if (SRF_IS_FIRSTCALL()) - { - funcctx = SRF_FIRSTCALL_INIT(); - - /* Switch context when allocating stuff to be used in later calls */ - oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx); - - /* Create a user function context for cross-call persistence */ - fctx = (NeonGetStatsCtx *) palloc(sizeof(NeonGetStatsCtx)); - - /* Construct a tuple descriptor for the result rows. */ - tupledesc = CreateTemplateTupleDesc(NUM_NEON_GET_STATS_COLS); - - TupleDescInitEntry(tupledesc, (AttrNumber) 1, "lfc_key", - TEXTOID, -1, 0); - TupleDescInitEntry(tupledesc, (AttrNumber) 2, "lfc_value", - INT8OID, -1, 0); - - fctx->tupdesc = BlessTupleDesc(tupledesc); - funcctx->user_fctx = fctx; - - /* Return to original context when allocating transient memory */ - MemoryContextSwitchTo(oldcontext); - } - - funcctx = SRF_PERCALL_SETUP(); - - /* Get the saved state */ - fctx = (NeonGetStatsCtx *) funcctx->user_fctx; - - switch (funcctx->call_cntr) - { - case 0: - key = "file_cache_misses"; - if (lfc_ctl) - value = lfc_ctl->misses; - break; - case 1: - key = "file_cache_hits"; - if (lfc_ctl) - value = lfc_ctl->hits; - break; - case 2: - key = "file_cache_used"; - if (lfc_ctl) - value = lfc_ctl->used; - break; - case 3: - key = "file_cache_writes"; - if (lfc_ctl) - value = lfc_ctl->writes; - break; - case 4: - key = "file_cache_size"; - if (lfc_ctl) - value = lfc_ctl->size; - break; - case 5: - key = "file_cache_used_pages"; - if (lfc_ctl) - value = lfc_ctl->used_pages; - break; - case 6: - key = "file_cache_evicted_pages"; - if (lfc_ctl) - value = lfc_ctl->evicted_pages; - break; - case 7: - key = "file_cache_limit"; - if (lfc_ctl) - value = lfc_ctl->limit; - break; - case 8: - key = "file_cache_chunk_size_pages"; - value = lfc_blocks_per_chunk; - break; - case 9: - key = "file_cache_chunks_pinned"; - if (lfc_ctl) - value = lfc_ctl->pinned; - break; - default: - SRF_RETURN_DONE(funcctx); - } - values[0] = PointerGetDatum(cstring_to_text(key)); - nulls[0] = false; - if (lfc_ctl) - { - nulls[1] = false; - values[1] = Int64GetDatum(value); - } - else - nulls[1] = true; - - tuple = heap_form_tuple(fctx->tupdesc, values, nulls); - result = HeapTupleGetDatum(tuple); - SRF_RETURN_NEXT(funcctx, result); + *num_entries = n; + return entries; } @@ -1953,193 +1879,86 @@ neon_get_lfc_stats(PG_FUNCTION_ARGS) * Function returning data from the local file cache * relation node/tablespace/database/blocknum and access_counter */ -PG_FUNCTION_INFO_V1(local_cache_pages); - -/* - * Record structure holding the to be exposed cache data. - */ -typedef struct +LocalCachePagesRec * +lfc_local_cache_pages(size_t *num_entries) { - uint32 pageoffs; - Oid relfilenode; - Oid reltablespace; - Oid reldatabase; - ForkNumber forknum; - BlockNumber blocknum; - uint16 accesscount; -} LocalCachePagesRec; + HASH_SEQ_STATUS status; + FileCacheEntry *entry; + size_t n_pages; + size_t n; + LocalCachePagesRec *result; -/* - * Function context for data persisting over repeated calls. - */ -typedef struct -{ - TupleDesc tupdesc; - LocalCachePagesRec *record; -} LocalCachePagesContext; - - -#define NUM_LOCALCACHE_PAGES_ELEM 7 - -Datum -local_cache_pages(PG_FUNCTION_ARGS) -{ - FuncCallContext *funcctx; - Datum result; - MemoryContext oldcontext; - LocalCachePagesContext *fctx; /* User function context. 
*/ - TupleDesc tupledesc; - TupleDesc expected_tupledesc; - HeapTuple tuple; - - if (SRF_IS_FIRSTCALL()) + if (!lfc_ctl) { - HASH_SEQ_STATUS status; - FileCacheEntry *entry; - uint32 n_pages = 0; + *num_entries = 0; + return NULL; + } - funcctx = SRF_FIRSTCALL_INIT(); + LWLockAcquire(lfc_lock, LW_SHARED); + if (!LFC_ENABLED()) + { + LWLockRelease(lfc_lock); + *num_entries = 0; + return NULL; + } - /* Switch context when allocating stuff to be used in later calls */ - oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx); - - /* Create a user function context for cross-call persistence */ - fctx = (LocalCachePagesContext *) palloc(sizeof(LocalCachePagesContext)); - - /* - * To smoothly support upgrades from version 1.0 of this extension - * transparently handle the (non-)existence of the pinning_backends - * column. We unfortunately have to get the result type for that... - - * we can't use the result type determined by the function definition - * without potentially crashing when somebody uses the old (or even - * wrong) function definition though. - */ - if (get_call_result_type(fcinfo, NULL, &expected_tupledesc) != TYPEFUNC_COMPOSITE) - neon_log(ERROR, "return type must be a row type"); - - if (expected_tupledesc->natts != NUM_LOCALCACHE_PAGES_ELEM) - neon_log(ERROR, "incorrect number of output arguments"); - - /* Construct a tuple descriptor for the result rows. */ - tupledesc = CreateTemplateTupleDesc(expected_tupledesc->natts); - TupleDescInitEntry(tupledesc, (AttrNumber) 1, "pageoffs", - INT8OID, -1, 0); -#if PG_MAJORVERSION_NUM < 16 - TupleDescInitEntry(tupledesc, (AttrNumber) 2, "relfilenode", - OIDOID, -1, 0); -#else - TupleDescInitEntry(tupledesc, (AttrNumber) 2, "relfilenumber", - OIDOID, -1, 0); -#endif - TupleDescInitEntry(tupledesc, (AttrNumber) 3, "reltablespace", - OIDOID, -1, 0); - TupleDescInitEntry(tupledesc, (AttrNumber) 4, "reldatabase", - OIDOID, -1, 0); - TupleDescInitEntry(tupledesc, (AttrNumber) 5, "relforknumber", - INT2OID, -1, 0); - TupleDescInitEntry(tupledesc, (AttrNumber) 6, "relblocknumber", - INT8OID, -1, 0); - TupleDescInitEntry(tupledesc, (AttrNumber) 7, "accesscount", - INT4OID, -1, 0); - - fctx->tupdesc = BlessTupleDesc(tupledesc); - - if (lfc_ctl) + /* Count the pages first */ + n_pages = 0; + hash_seq_init(&status, lfc_hash); + while ((entry = hash_seq_search(&status)) != NULL) + { + /* Skip hole tags */ + if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0) { - LWLockAcquire(lfc_lock, LW_SHARED); + for (int i = 0; i < lfc_blocks_per_chunk; i++) + n_pages += GET_STATE(entry, i) == AVAILABLE; + } + } - if (LFC_ENABLED()) + if (n_pages == 0) + { + LWLockRelease(lfc_lock); + *num_entries = 0; + return NULL; + } + + result = (LocalCachePagesRec *) + MemoryContextAllocHuge(CurrentMemoryContext, + sizeof(LocalCachePagesRec) * n_pages); + + /* + * Scan through all the cache entries, saving the relevant fields + * in the result structure. 
+ */ + n = 0; + hash_seq_init(&status, lfc_hash); + while ((entry = hash_seq_search(&status)) != NULL) + { + for (int i = 0; i < lfc_blocks_per_chunk; i++) + { + if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0) { - hash_seq_init(&status, lfc_hash); - while ((entry = hash_seq_search(&status)) != NULL) + if (GET_STATE(entry, i) == AVAILABLE) { - /* Skip hole tags */ - if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0) - { - for (int i = 0; i < lfc_blocks_per_chunk; i++) - n_pages += GET_STATE(entry, i) == AVAILABLE; - } + result[n].pageoffs = entry->offset * lfc_blocks_per_chunk + i; + result[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)); + result[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(entry->key)); + result[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(entry->key)); + result[n].forknum = entry->key.forkNum; + result[n].blocknum = entry->key.blockNum + i; + result[n].accesscount = entry->access_count; + n += 1; } } } - fctx->record = (LocalCachePagesRec *) - MemoryContextAllocHuge(CurrentMemoryContext, - sizeof(LocalCachePagesRec) * n_pages); - - /* Set max calls and remember the user function context. */ - funcctx->max_calls = n_pages; - funcctx->user_fctx = fctx; - - /* Return to original context when allocating transient memory */ - MemoryContextSwitchTo(oldcontext); - - if (n_pages != 0) - { - /* - * Scan through all the cache entries, saving the relevant fields - * in the fctx->record structure. - */ - uint32 n = 0; - - hash_seq_init(&status, lfc_hash); - while ((entry = hash_seq_search(&status)) != NULL) - { - for (int i = 0; i < lfc_blocks_per_chunk; i++) - { - if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0) - { - if (GET_STATE(entry, i) == AVAILABLE) - { - fctx->record[n].pageoffs = entry->offset * lfc_blocks_per_chunk + i; - fctx->record[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)); - fctx->record[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(entry->key)); - fctx->record[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(entry->key)); - fctx->record[n].forknum = entry->key.forkNum; - fctx->record[n].blocknum = entry->key.blockNum + i; - fctx->record[n].accesscount = entry->access_count; - n += 1; - } - } - } - } - Assert(n_pages == n); - } - if (lfc_ctl) - LWLockRelease(lfc_lock); } + Assert(n_pages == n); + LWLockRelease(lfc_lock); - funcctx = SRF_PERCALL_SETUP(); - - /* Get the saved state */ - fctx = funcctx->user_fctx; - - if (funcctx->call_cntr < funcctx->max_calls) - { - uint32 i = funcctx->call_cntr; - Datum values[NUM_LOCALCACHE_PAGES_ELEM]; - bool nulls[NUM_LOCALCACHE_PAGES_ELEM] = { - false, false, false, false, false, false, false - }; - - values[0] = Int64GetDatum((int64) fctx->record[i].pageoffs); - values[1] = ObjectIdGetDatum(fctx->record[i].relfilenode); - values[2] = ObjectIdGetDatum(fctx->record[i].reltablespace); - values[3] = ObjectIdGetDatum(fctx->record[i].reldatabase); - values[4] = ObjectIdGetDatum(fctx->record[i].forknum); - values[5] = Int64GetDatum((int64) fctx->record[i].blocknum); - values[6] = Int32GetDatum(fctx->record[i].accesscount); - - /* Build and return the tuple. */ - tuple = heap_form_tuple(fctx->tupdesc, values, nulls); - result = HeapTupleGetDatum(tuple); - - SRF_RETURN_NEXT(funcctx, result); - } - else - SRF_RETURN_DONE(funcctx); + *num_entries = n_pages; + return result; } - /* * Internal implementation of the approximate_working_set_size_seconds() * function. 
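lfc_local_cache_pages() above replaces the old per-call SRF machinery with a two-pass scan under one shared lock: count the available pages, allocate a single flat array, fill it, and assert that the two passes agree. The same shape as a conceptual Rust sketch (a pre-sized Vec stands in for the palloc'd array; lock handling is omitted):

struct Chunk {
    base_block: u32,
    page_available: Vec<bool>,
}

struct PageRef {
    block: u32,
}

fn snapshot_pages(chunks: &[Chunk]) -> Vec<PageRef> {
    // Pass 1: count, so the result buffer can be allocated in one go.
    let total: usize = chunks
        .iter()
        .map(|c| c.page_available.iter().filter(|&&a| a).count())
        .sum();

    // Pass 2: fill the preallocated buffer.
    let mut out = Vec::with_capacity(total);
    for chunk in chunks {
        for (i, &available) in chunk.page_available.iter().enumerate() {
            if available {
                out.push(PageRef { block: chunk.base_block + i as u32 });
            }
        }
    }
    // Both passes run under the same (implied) lock, so the counts must match.
    debug_assert_eq!(out.len(), total);
    out
}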
@@ -2267,4 +2086,3 @@ get_prewarm_info(PG_FUNCTION_ARGS) PG_RETURN_DATUM(HeapTupleGetDatum(heap_form_tuple(tupdesc, values, nulls))); } - diff --git a/pgxn/neon/file_cache.h b/pgxn/neon/file_cache.h index 14e5d4f753..4145327942 100644 --- a/pgxn/neon/file_cache.h +++ b/pgxn/neon/file_cache.h @@ -47,6 +47,26 @@ extern bool lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blk extern FileCacheState* lfc_get_state(size_t max_entries); extern void lfc_prewarm(FileCacheState* fcs, uint32 n_workers); +typedef struct LfcStatsEntry +{ + const char *metric_name; + bool isnull; + uint64 value; +} LfcStatsEntry; +extern LfcStatsEntry *lfc_get_stats(size_t *num_entries); + +typedef struct +{ + uint32 pageoffs; + Oid relfilenode; + Oid reltablespace; + Oid reldatabase; + ForkNumber forknum; + BlockNumber blocknum; + uint16 accesscount; +} LocalCachePagesRec; +extern LocalCachePagesRec *lfc_local_cache_pages(size_t *num_entries); + extern int32 lfc_approximate_working_set_size_seconds(time_t duration, bool reset); diff --git a/pgxn/neon/libpagestore.c b/pgxn/neon/libpagestore.c index ff7ec05ba4..158118f860 100644 --- a/pgxn/neon/libpagestore.c +++ b/pgxn/neon/libpagestore.c @@ -13,6 +13,8 @@ #include #include +#include + #include "libpq-int.h" #include "access/xlog.h" @@ -86,6 +88,10 @@ static int pageserver_response_log_timeout = 10000; /* 2.5 minutes. A bit higher than highest default TCP retransmission timeout */ static int pageserver_response_disconnect_timeout = 150000; +static int conf_refresh_reconnect_attempt_threshold = 16; +// Hadron: timeout for refresh errors (1 minute) +static uint64 kRefreshErrorTimeoutUSec = 1 * USECS_PER_MINUTE; + typedef struct { char connstring[MAX_SHARDS][MAX_PAGESERVER_CONNSTRING_SIZE]; @@ -130,7 +136,7 @@ static uint64 pagestore_local_counter = 0; typedef enum PSConnectionState { PS_Disconnected, /* no connection yet */ PS_Connecting_Startup, /* connection starting up */ - PS_Connecting_PageStream, /* negotiating pagestream */ + PS_Connecting_PageStream, /* negotiating pagestream */ PS_Connected, /* connected, pagestream established */ } PSConnectionState; @@ -401,7 +407,7 @@ get_shard_number(BufferTag *tag) } static inline void -CLEANUP_AND_DISCONNECT(PageServer *shard) +CLEANUP_AND_DISCONNECT(PageServer *shard) { if (shard->wes_read) { @@ -423,7 +429,7 @@ CLEANUP_AND_DISCONNECT(PageServer *shard) * complete the connection (e.g. due to receiving an earlier cancellation * during connection start). * Returns true if successfully connected; false if the connection failed. - * + * * Throws errors in unrecoverable situations, or when this backend's query * is canceled. */ @@ -1030,6 +1036,101 @@ pageserver_disconnect_shard(shardno_t shard_no) shard->state = PS_Disconnected; } +// BEGIN HADRON +/* + * Nudge compute_ctl to refresh our configuration. Called when we suspect we may be + * connecting to the wrong pageservers due to a stale configuration. + * + * This is a best-effort operation. If we couldn't send the local loopback HTTP request + * to compute_ctl or if the request fails for any reason, we just log the error and move + * on. + */ + +extern int hadron_extension_server_port; + +// The timestamp (usec) of the first error that occurred while trying to refresh the configuration. +// Will be reset to 0 after a successful refresh. +static uint64 first_recorded_refresh_error_usec = 0; + +// Request compute_ctl to refresh the configuration. This operation may fail, e.g., if the compute_ctl +// is already in the configuration state. 
The function returns true if the caller needs to cancel the +// current query to avoid dead/live lock. +static bool +hadron_request_configuration_refresh() { + static CURL *handle = NULL; + CURLcode res; + char *compute_ctl_url; + bool cancel_query = false; + + if (!lakebase_mode) + return false; + + if (handle == NULL) + { + handle = alloc_curl_handle(); + + curl_easy_setopt(handle, CURLOPT_CUSTOMREQUEST, "POST"); + curl_easy_setopt(handle, CURLOPT_TIMEOUT, 3L /* seconds */ ); + curl_easy_setopt(handle, CURLOPT_POSTFIELDS, ""); + } + + // Set the URL + compute_ctl_url = psprintf("http://localhost:%d/refresh_configuration", hadron_extension_server_port); + + + elog(LOG, "Sending refresh configuration request to compute_ctl: %s", compute_ctl_url); + + curl_easy_setopt(handle, CURLOPT_URL, compute_ctl_url); + + res = curl_easy_perform(handle); + if (res != CURLE_OK ) + { + elog(WARNING, "refresh_configuration request failed: %s\n", curl_easy_strerror(res)); + } + else + { + long http_code = 0; + curl_easy_getinfo(handle, CURLINFO_RESPONSE_CODE, &http_code); + if ( res != CURLE_OK ) + { + elog(WARNING, "compute_ctl refresh_configuration request getinfo failed: %s\n", curl_easy_strerror(res)); + } + else + { + elog(LOG, "compute_ctl refresh_configuration got HTTP response: %ld\n", http_code); + if( http_code == 200 ) + { + first_recorded_refresh_error_usec = 0; + } + else + { + if (first_recorded_refresh_error_usec == 0) + { + first_recorded_refresh_error_usec = GetCurrentTimestamp(); + } + else if(GetCurrentTimestamp() - first_recorded_refresh_error_usec > kRefreshErrorTimeoutUSec) + { + { + first_recorded_refresh_error_usec = 0; + cancel_query = true; + } + } + } + } + } + + // In regular Postgres usage, it is not necessary to manually free memory allocated by palloc (psprintf) because + // it will be cleaned up after the "memory context" is reset (e.g. after the query or the transaction is finished). + // However, the number of times this function gets called during a single query/transaction can be unbounded due to + // the various retry loops around calls to pageservers. Therefore, we need to manually free this memory here. + if (compute_ctl_url != NULL) + { + pfree(compute_ctl_url); + } + return cancel_query; +} +// END HADRON + static bool pageserver_send(shardno_t shard_no, NeonRequest *request) { @@ -1064,6 +1165,11 @@ pageserver_send(shardno_t shard_no, NeonRequest *request) while (!pageserver_connect(shard_no, shard->n_reconnect_attempts < max_reconnect_attempts ? LOG : ERROR)) { shard->n_reconnect_attempts += 1; + if (shard->n_reconnect_attempts > conf_refresh_reconnect_attempt_threshold + && hadron_request_configuration_refresh() ) + { + neon_shard_log(shard_no, ERROR, "request failed too many times, cancelling query"); + } } shard->n_reconnect_attempts = 0; } else { @@ -1171,17 +1277,26 @@ pageserver_receive(shardno_t shard_no) pfree(msg); pageserver_disconnect(shard_no); resp = NULL; + + /* + * Always poke compute_ctl to request a configuration refresh if we have issues receiving data from pageservers after + * successfully connecting to it. It could be an indication that we are connecting to the wrong pageservers (e.g. PS + * is in secondary mode or otherwise refuses to respond our request). 
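hadron_request_configuration_refresh() above is deliberately best-effort: it remembers when refreshes first started failing, clears that marker on any successful (HTTP 200) refresh, and only asks the caller to cancel the running query once failures have persisted past the one-minute timeout. The timeout state machine in isolation, sketched in Rust with Instant in place of GetCurrentTimestamp():

use std::time::{Duration, Instant};

const REFRESH_ERROR_TIMEOUT: Duration = Duration::from_secs(60);

#[derive(Default)]
struct RefreshState {
    first_error_at: Option<Instant>,
}

impl RefreshState {
    /// Returns true when the caller should cancel the current query because
    /// configuration refreshes have been failing for longer than the timeout.
    fn observe(&mut self, refresh_succeeded: bool, now: Instant) -> bool {
        if refresh_succeeded {
            self.first_error_at = None;
            return false;
        }
        match self.first_error_at {
            None => {
                // First failure: start the clock, but don't cancel yet.
                self.first_error_at = Some(now);
                false
            }
            Some(first) if now.duration_since(first) > REFRESH_ERROR_TIMEOUT => {
                // Persistent failures: reset the marker (as the C code does) and
                // tell the caller to cancel.
                self.first_error_at = None;
                true
            }
            Some(_) => false,
        }
    }
}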
+ */ + hadron_request_configuration_refresh(); } else if (rc == -2) { char *msg = pchomp(PQerrorMessage(pageserver_conn)); pageserver_disconnect(shard_no); + hadron_request_configuration_refresh(); neon_shard_log(shard_no, ERROR, "pageserver_receive disconnect: could not read COPY data: %s", msg); } else { pageserver_disconnect(shard_no); + hadron_request_configuration_refresh(); neon_shard_log(shard_no, ERROR, "pageserver_receive disconnect: unexpected PQgetCopyData return value: %d", rc); } @@ -1249,21 +1364,34 @@ pageserver_try_receive(shardno_t shard_no) neon_shard_log(shard_no, LOG, "pageserver_receive disconnect: psql end of copy data: %s", pchomp(PQerrorMessage(pageserver_conn))); pageserver_disconnect(shard_no); resp = NULL; + hadron_request_configuration_refresh(); } else if (rc == -2) { char *msg = pchomp(PQerrorMessage(pageserver_conn)); pageserver_disconnect(shard_no); + hadron_request_configuration_refresh(); neon_shard_log(shard_no, LOG, "pageserver_receive disconnect: could not read COPY data: %s", msg); resp = NULL; } else { pageserver_disconnect(shard_no); + hadron_request_configuration_refresh(); neon_shard_log(shard_no, ERROR, "pageserver_receive disconnect: unexpected PQgetCopyData return value: %d", rc); } + /* + * Always poke compute_ctl to request a configuration refresh if we have issues receiving data from pageservers after + * successfully connecting to it. It could be an indication that we are connecting to the wrong pageservers (e.g. PS + * is in secondary mode or otherwise refuses to respond our request). + */ + if ( rc < 0 && hadron_request_configuration_refresh() ) + { + neon_shard_log(shard_no, ERROR, "refresh_configuration request failed, cancelling query"); + } + shard->nresponses_received++; return (NeonResponse *) resp; } @@ -1459,6 +1587,16 @@ pg_init_libpagestore(void) PGC_SU_BACKEND, 0, /* no flags required */ NULL, NULL, NULL); + DefineCustomIntVariable("hadron.conf_refresh_reconnect_attempt_threshold", + "Threshold of the number of consecutive failed pageserver " + "connection attempts (per shard) before signaling " + "compute_ctl for a configuration refresh.", + NULL, + &conf_refresh_reconnect_attempt_threshold, + 16, 0, INT_MAX, + PGC_USERSET, + 0, + NULL, NULL, NULL); DefineCustomIntVariable("neon.pageserver_response_log_timeout", "pageserver response log timeout", diff --git a/pgxn/neon/neon.c b/pgxn/neon/neon.c index 5a6d7dcd2f..e831fca7f8 100644 --- a/pgxn/neon/neon.c +++ b/pgxn/neon/neon.c @@ -1,7 +1,7 @@ /*------------------------------------------------------------------------- * * neon.c - * Main entry point into the neon exension + * Main entry point into the neon extension * *------------------------------------------------------------------------- */ @@ -51,6 +51,7 @@ void _PG_init(void); bool lakebase_mode = false; static int running_xacts_overflow_policy; +static emit_log_hook_type prev_emit_log_hook; static bool monitor_query_exec_time = false; static ExecutorStart_hook_type prev_ExecutorStart = NULL; @@ -447,17 +448,19 @@ ReportSearchPath(void) static int neon_pgstat_file_size_limit; #endif -#if PG_VERSION_NUM >= 160000 && PG_VERSION_NUM < 170000 -static void DatabricksSqlErrorHookImpl(int sqlerrcode) { - if (sqlerrcode == ERRCODE_DATA_CORRUPTED) { +static void DatabricksSqlErrorHookImpl(ErrorData *edata) { + if (prev_emit_log_hook != NULL) { + prev_emit_log_hook(edata); + } + + if (edata->sqlerrcode == ERRCODE_DATA_CORRUPTED) { pg_atomic_fetch_add_u32(&databricks_metrics_shared->data_corruption_count, 1); - } else if (sqlerrcode == 
ERRCODE_INDEX_CORRUPTED) { + } else if (edata->sqlerrcode == ERRCODE_INDEX_CORRUPTED) { pg_atomic_fetch_add_u32(&databricks_metrics_shared->index_corruption_count, 1); - } else if (sqlerrcode == ERRCODE_INTERNAL_ERROR) { + } else if (edata->sqlerrcode == ERRCODE_INTERNAL_ERROR) { pg_atomic_fetch_add_u32(&databricks_metrics_shared->internal_error_count, 1); } } -#endif void _PG_init(void) @@ -470,11 +473,10 @@ _PG_init(void) load_file("$libdir/neon_rmgr", false); #endif -#if PG_VERSION_NUM >= 160000 && PG_VERSION_NUM < 170000 if (lakebase_mode) { - SqlErrorCode_hook = DatabricksSqlErrorHookImpl; + prev_emit_log_hook = emit_log_hook; + emit_log_hook = DatabricksSqlErrorHookImpl; } -#endif /* * Initializing a pre-loaded Postgres extension happens in three stages: @@ -528,7 +530,7 @@ _PG_init(void) DefineCustomBoolVariable( "neon.disable_logical_replication_subscribers", - "Disables incomming logical replication", + "Disable incoming logical replication", NULL, &disable_logical_replication_subscribers, false, @@ -587,7 +589,7 @@ _PG_init(void) DefineCustomEnumVariable( "neon.debug_compare_local", - "Debug mode for compaing content of pages in prefetch ring/LFC/PS and local disk", + "Debug mode for comparing content of pages in prefetch ring/LFC/PS and local disk", NULL, &debug_compare_local, DEBUG_COMPARE_LOCAL_NONE, @@ -658,11 +660,15 @@ _PG_init(void) ExecutorEnd_hook = neon_ExecutorEnd; } +/* Various functions exposed at SQL level */ + PG_FUNCTION_INFO_V1(pg_cluster_size); PG_FUNCTION_INFO_V1(backpressure_lsns); PG_FUNCTION_INFO_V1(backpressure_throttling_time); PG_FUNCTION_INFO_V1(approximate_working_set_size_seconds); PG_FUNCTION_INFO_V1(approximate_working_set_size); +PG_FUNCTION_INFO_V1(neon_get_lfc_stats); +PG_FUNCTION_INFO_V1(local_cache_pages); Datum pg_cluster_size(PG_FUNCTION_ARGS) @@ -737,6 +743,76 @@ approximate_working_set_size(PG_FUNCTION_ARGS) PG_RETURN_INT32(dc); } +Datum +neon_get_lfc_stats(PG_FUNCTION_ARGS) +{ +#define NUM_NEON_GET_STATS_COLS 2 + ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo; + LfcStatsEntry *entries; + size_t num_entries; + + InitMaterializedSRF(fcinfo, 0); + + /* lfc_get_stats() does all the heavy lifting */ + entries = lfc_get_stats(&num_entries); + + /* Convert the LfcStatsEntrys to a result set */ + for (size_t i = 0; i < num_entries; i++) + { + LfcStatsEntry *entry = &entries[i]; + Datum values[NUM_NEON_GET_STATS_COLS]; + bool nulls[NUM_NEON_GET_STATS_COLS]; + + values[0] = CStringGetTextDatum(entry->metric_name); + nulls[0] = false; + values[1] = Int64GetDatum(entry->isnull ? 
0 : entry->value); + nulls[1] = entry->isnull; + tuplestore_putvalues(rsinfo->setResult, rsinfo->setDesc, values, nulls); + } + PG_RETURN_VOID(); + +#undef NUM_NEON_GET_STATS_COLS +} + +Datum +local_cache_pages(PG_FUNCTION_ARGS) +{ +#define NUM_LOCALCACHE_PAGES_COLS 7 + ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo; + LocalCachePagesRec *entries; + size_t num_entries; + + InitMaterializedSRF(fcinfo, 0); + + /* lfc_local_cache_pages() does all the heavy lifting */ + entries = lfc_local_cache_pages(&num_entries); + + /* Convert the LocalCachePagesRec structs to a result set */ + for (size_t i = 0; i < num_entries; i++) + { + LocalCachePagesRec *entry = &entries[i]; + Datum values[NUM_LOCALCACHE_PAGES_COLS]; + bool nulls[NUM_LOCALCACHE_PAGES_COLS] = { + false, false, false, false, false, false, false + }; + + values[0] = Int64GetDatum((int64) entry->pageoffs); + values[1] = ObjectIdGetDatum(entry->relfilenode); + values[2] = ObjectIdGetDatum(entry->reltablespace); + values[3] = ObjectIdGetDatum(entry->reldatabase); + values[4] = ObjectIdGetDatum(entry->forknum); + values[5] = Int64GetDatum((int64) entry->blocknum); + values[6] = Int32GetDatum(entry->accesscount); + + /* Build and return the tuple. */ + tuplestore_putvalues(rsinfo->setResult, rsinfo->setDesc, values, nulls); + } + + PG_RETURN_VOID(); + +#undef NUM_LOCALCACHE_PAGES_COLS +} + /* * Initialization stage 2: make requests for the amount of shared memory we * will need. @@ -768,7 +844,6 @@ neon_shmem_request_hook(void) static void neon_shmem_startup_hook(void) { - /* Initialize */ if (prev_shmem_startup_hook) prev_shmem_startup_hook(); diff --git a/pgxn/neon/neon_perf_counters.c b/pgxn/neon/neon_perf_counters.c index 69d58e7d99..1b060910fe 100644 --- a/pgxn/neon/neon_perf_counters.c +++ b/pgxn/neon/neon_perf_counters.c @@ -20,7 +20,6 @@ #include "neon.h" #include "neon_perf_counters.h" #include "walproposer.h" -#include "walproposer.h" /* BEGIN_HADRON */ databricks_metrics *databricks_metrics_shared; diff --git a/pgxn/neon/neon_perf_counters.h b/pgxn/neon/neon_perf_counters.h index 6a6e16cd26..5c0b7ded7a 100644 --- a/pgxn/neon/neon_perf_counters.h +++ b/pgxn/neon/neon_perf_counters.h @@ -167,11 +167,7 @@ extern neon_per_backend_counters *neon_per_backend_counters_shared; */ #define NUM_NEON_PERF_COUNTER_SLOTS (MaxBackends + NUM_AUXILIARY_PROCS) -#if PG_VERSION_NUM >= 170000 #define MyNeonCounters (&neon_per_backend_counters_shared[MyProcNumber]) -#else -#define MyNeonCounters (&neon_per_backend_counters_shared[MyProc->pgprocno]) -#endif extern void inc_getpage_wait(uint64 latency); extern void inc_page_cache_read_wait(uint64 latency); diff --git a/pgxn/neon/neon_pgversioncompat.h b/pgxn/neon/neon_pgversioncompat.h index 3ab8d3e5f5..dbe0e5aa3d 100644 --- a/pgxn/neon/neon_pgversioncompat.h +++ b/pgxn/neon/neon_pgversioncompat.h @@ -9,6 +9,10 @@ #include "fmgr.h" #include "storage/buf_internals.h" +#if PG_MAJORVERSION_NUM < 16 +typedef PGAlignedBlock PGIOAlignedBlock; +#endif + #if PG_MAJORVERSION_NUM < 17 #define NRelFileInfoBackendIsTemp(rinfo) (rinfo.backend != InvalidBackendId) #else @@ -158,6 +162,10 @@ InitBufferTag(BufferTag *tag, const RelFileNode *rnode, #define AmAutoVacuumWorkerProcess() (IsAutoVacuumWorkerProcess()) #endif +#if PG_MAJORVERSION_NUM < 17 +#define MyProcNumber (MyProc - &ProcGlobal->allProcs[0]) +#endif + #if PG_MAJORVERSION_NUM < 15 extern void InitMaterializedSRF(FunctionCallInfo fcinfo, bits32 flags); extern TimeLineID GetWALInsertionTimeLine(void); diff --git a/pgxn/neon/pagestore_smgr.c 
b/pgxn/neon/pagestore_smgr.c index 9d25266e10..d3e51ba682 100644 --- a/pgxn/neon/pagestore_smgr.c +++ b/pgxn/neon/pagestore_smgr.c @@ -72,10 +72,6 @@ #include "access/xlogrecovery.h" #endif -#if PG_VERSION_NUM < 160000 -typedef PGAlignedBlock PGIOAlignedBlock; -#endif - #include "access/nbtree.h" #include "storage/bufpage.h" #include "access/xlog_internal.h" diff --git a/pgxn/neon/relsize_cache.c b/pgxn/neon/relsize_cache.c index bf7961574a..c6b4aeb394 100644 --- a/pgxn/neon/relsize_cache.c +++ b/pgxn/neon/relsize_cache.c @@ -13,6 +13,7 @@ #include "neon.h" #include "neon_pgversioncompat.h" +#include "miscadmin.h" #include "pagestore_client.h" #include RELFILEINFO_HDR #include "storage/smgr.h" @@ -23,10 +24,6 @@ #include "utils/dynahash.h" #include "utils/guc.h" -#if PG_VERSION_NUM >= 150000 -#include "miscadmin.h" -#endif - typedef struct { NRelFileInfo rinfo; diff --git a/pgxn/neon/walproposer_pg.c b/pgxn/neon/walproposer_pg.c index 6c1f56d919..8ffac13daf 100644 --- a/pgxn/neon/walproposer_pg.c +++ b/pgxn/neon/walproposer_pg.c @@ -508,19 +508,45 @@ backpressure_lag_impl(void) LSN_FORMAT_ARGS(flushPtr), LSN_FORMAT_ARGS(applyPtr)); - if ((writePtr != InvalidXLogRecPtr && max_replication_write_lag > 0 && myFlushLsn > writePtr + max_replication_write_lag * MB)) + if (lakebase_mode) { - return (myFlushLsn - writePtr - max_replication_write_lag * MB); - } + // in case PG does not have shard map initialized, we assume PG always has 1 shard at minimum. + shardno_t num_shards = Max(1, get_num_shards()); + int tenant_max_replication_apply_lag = num_shards * max_replication_apply_lag; + int tenant_max_replication_flush_lag = num_shards * max_replication_flush_lag; + int tenant_max_replication_write_lag = num_shards * max_replication_write_lag; - if ((flushPtr != InvalidXLogRecPtr && max_replication_flush_lag > 0 && myFlushLsn > flushPtr + max_replication_flush_lag * MB)) - { - return (myFlushLsn - flushPtr - max_replication_flush_lag * MB); - } + if ((writePtr != InvalidXLogRecPtr && tenant_max_replication_write_lag > 0 && myFlushLsn > writePtr + tenant_max_replication_write_lag * MB)) + { + return (myFlushLsn - writePtr - tenant_max_replication_write_lag * MB); + } - if ((applyPtr != InvalidXLogRecPtr && max_replication_apply_lag > 0 && myFlushLsn > applyPtr + max_replication_apply_lag * MB)) + if ((flushPtr != InvalidXLogRecPtr && tenant_max_replication_flush_lag > 0 && myFlushLsn > flushPtr + tenant_max_replication_flush_lag * MB)) + { + return (myFlushLsn - flushPtr - tenant_max_replication_flush_lag * MB); + } + + if ((applyPtr != InvalidXLogRecPtr && tenant_max_replication_apply_lag > 0 && myFlushLsn > applyPtr + tenant_max_replication_apply_lag * MB)) + { + return (myFlushLsn - applyPtr - tenant_max_replication_apply_lag * MB); + } + } + else { - return (myFlushLsn - applyPtr - max_replication_apply_lag * MB); + if ((writePtr != InvalidXLogRecPtr && max_replication_write_lag > 0 && myFlushLsn > writePtr + max_replication_write_lag * MB)) + { + return (myFlushLsn - writePtr - max_replication_write_lag * MB); + } + + if ((flushPtr != InvalidXLogRecPtr && max_replication_flush_lag > 0 && myFlushLsn > flushPtr + max_replication_flush_lag * MB)) + { + return (myFlushLsn - flushPtr - max_replication_flush_lag * MB); + } + + if ((applyPtr != InvalidXLogRecPtr && max_replication_apply_lag > 0 && myFlushLsn > applyPtr + max_replication_apply_lag * MB)) + { + return (myFlushLsn - applyPtr - max_replication_apply_lag * MB); + } } } return 0; diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 
3c3f93c8e3..0ece79c329 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -33,7 +33,6 @@ env_logger.workspace = true framed-websockets.workspace = true futures.workspace = true hashbrown.workspace = true -hashlink.workspace = true hex.workspace = true hmac.workspace = true hostname.workspace = true @@ -54,6 +53,7 @@ json = { path = "../libs/proxy/json" } lasso = { workspace = true, features = ["multi-threaded"] } measured = { workspace = true, features = ["lasso"] } metrics.workspace = true +moka.workspace = true once_cell.workspace = true opentelemetry = { workspace = true, features = ["trace"] } papaya = "0.2.0" @@ -107,10 +107,11 @@ uuid.workspace = true x509-cert.workspace = true redis.workspace = true zerocopy.workspace = true +zeroize.workspace = true # uncomment this to use the real subzero-core crate # subzero-core = { git = "https://github.com/neondatabase/subzero", rev = "396264617e78e8be428682f87469bb25429af88a", features = ["postgresql"], optional = true } # this is a stub for the subzero-core crate -subzero-core = { path = "./subzero_core", features = ["postgresql"], optional = true} +subzero-core = { path = "../libs/proxy/subzero_core", features = ["postgresql"], optional = true} ouroboros = { version = "0.18", optional = true } # jwt stuff diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index b06ed3a0ae..2a02748a10 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -8,11 +8,12 @@ use tracing::{info, info_span}; use crate::auth::backend::ComputeUserInfo; use crate::cache::Cached; +use crate::cache::node_info::CachedNodeInfo; use crate::compute::AuthInfo; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::client::cplane_proxy_v1; -use crate::control_plane::{self, CachedNodeInfo, NodeInfo}; +use crate::control_plane::{self, NodeInfo}; use crate::error::{ReportableError, UserFacingError}; use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 1e5c076fb9..491f14b1b6 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -6,7 +6,7 @@ use crate::auth::{self, AuthFlow}; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::AuthSecret; -use crate::intern::EndpointIdInt; +use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::sasl; use crate::stream::{self, Stream}; @@ -25,13 +25,15 @@ pub(crate) async fn authenticate_cleartext( ctx.set_auth_method(crate::context::AuthMethod::Cleartext); let ep = EndpointIdInt::from(&info.endpoint); + let role = RoleNameInt::from(&info.user); let auth_flow = AuthFlow::new( client, auth::CleartextPassword { secret, endpoint: ep, - pool: config.thread_pool.clone(), + role, + pool: config.scram_thread_pool.clone(), }, ); let auth_outcome = { diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index e7805d8bfe..a6df2a7011 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -16,16 +16,16 @@ use tracing::{debug, info}; use crate::auth::{self, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange}; use crate::cache::Cached; +use crate::cache::node_info::CachedNodeInfo; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::client::ControlPlaneClient; use crate::control_plane::errors::GetAuthInfoError; use 
crate::control_plane::messages::EndpointRateLimitConfig; use crate::control_plane::{ - self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl, - RoleAccessControl, + self, AccessBlockerFlags, AuthSecret, ControlPlaneApi, EndpointAccessControl, RoleAccessControl, }; -use crate::intern::EndpointIdInt; +use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; use crate::proxy::wake_compute::WakeComputeBackend; @@ -273,9 +273,11 @@ async fn authenticate_with_secret( ) -> auth::Result { if let Some(password) = unauthenticated_password { let ep = EndpointIdInt::from(&info.endpoint); + let role = RoleNameInt::from(&info.user); let auth_outcome = - validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?; + validate_password_and_exchange(&config.scram_thread_pool, ep, role, &password, secret) + .await?; let keys = match auth_outcome { crate::sasl::Outcome::Success(key) => key, crate::sasl::Outcome::Failure(reason) => { @@ -433,11 +435,12 @@ mod tests { use super::auth_quirks; use super::jwt::JwkCache; use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern}; + use crate::cache::node_info::CachedNodeInfo; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::messages::EndpointRateLimitConfig; use crate::control_plane::{ - self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl, + self, AccessBlockerFlags, EndpointAccessControl, RoleAccessControl, }; use crate::proxy::NeonOptions; use crate::rate_limiter::EndpointRateLimiter; @@ -498,7 +501,7 @@ mod tests { static CONFIG: Lazy = Lazy::new(|| AuthenticationConfig { jwks_cache: JwkCache::default(), - thread_pool: ThreadPool::new(1), + scram_thread_pool: ThreadPool::new(1), scram_protocol_timeout: std::time::Duration::from_secs(5), ip_allowlist_check_enabled: true, is_vpc_acccess_proxy: false, diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index c825d5bf4b..00cd274e99 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -10,7 +10,7 @@ use super::backend::ComputeCredentialKeys; use super::{AuthError, PasswordHackPayload}; use crate::context::RequestContext; use crate::control_plane::AuthSecret; -use crate::intern::EndpointIdInt; +use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage}; use crate::sasl; use crate::scram::threadpool::ThreadPool; @@ -46,6 +46,7 @@ pub(crate) struct PasswordHack; pub(crate) struct CleartextPassword { pub(crate) pool: Arc, pub(crate) endpoint: EndpointIdInt, + pub(crate) role: RoleNameInt, pub(crate) secret: AuthSecret, } @@ -111,6 +112,7 @@ impl AuthFlow<'_, S, CleartextPassword> { let outcome = validate_password_and_exchange( &self.state.pool, self.state.endpoint, + self.state.role, password, self.state.secret, ) @@ -165,13 +167,15 @@ impl AuthFlow<'_, S, Scram<'_>> { pub(crate) async fn validate_password_and_exchange( pool: &ThreadPool, endpoint: EndpointIdInt, + role: RoleNameInt, password: &[u8], secret: AuthSecret, ) -> super::Result> { match secret { // perform scram authentication as both client and server to validate the keys AuthSecret::Scram(scram_secret) => { - let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?; + let outcome = + crate::scram::exchange(pool, endpoint, role, &scram_secret, password).await?; let client_key = match outcome { sasl::Outcome::Success(client_key) => client_key, diff --git 
a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs index 7b9012dc69..86b64c62c9 100644 --- a/proxy/src/binary/local_proxy.rs +++ b/proxy/src/binary/local_proxy.rs @@ -29,7 +29,7 @@ use crate::config::{ }; use crate::control_plane::locks::ApiLocks; use crate::http::health_server::AppMetrics; -use crate::metrics::{Metrics, ThreadPoolMetrics}; +use crate::metrics::{Metrics, ServiceInfo}; use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo}; use crate::scram::threadpool::ThreadPool; use crate::serverless::cancel_set::CancelSet; @@ -114,8 +114,6 @@ pub async fn run() -> anyhow::Result<()> { let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook(); let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]); - Metrics::install(Arc::new(ThreadPoolMetrics::new(0))); - // TODO: refactor these to use labels debug!("Version: {GIT_VERSION}"); debug!("Build_tag: {BUILD_TAG}"); @@ -207,6 +205,11 @@ pub async fn run() -> anyhow::Result<()> { endpoint_rate_limiter, ); + Metrics::get() + .service + .info + .set_label(ServiceInfo::running()); + match futures::future::select(pin!(maintenance_tasks.join_next()), pin!(task)).await { // exit immediately on maintenance task completion Either::Left((Some(res), _)) => match crate::error::flatten_err(res)? {}, @@ -279,7 +282,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig http_config, authentication_config: AuthenticationConfig { jwks_cache: JwkCache::default(), - thread_pool: ThreadPool::new(0), + scram_thread_pool: ThreadPool::new(0), scram_protocol_timeout: Duration::from_secs(10), ip_allowlist_check_enabled: true, is_vpc_acccess_proxy: false, diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs index f3782312dc..cdbf0f09ac 100644 --- a/proxy/src/binary/pg_sni_router.rs +++ b/proxy/src/binary/pg_sni_router.rs @@ -26,7 +26,7 @@ use utils::project_git_version; use utils::sentry_init::init_sentry; use crate::context::RequestContext; -use crate::metrics::{Metrics, ThreadPoolMetrics}; +use crate::metrics::{Metrics, ServiceInfo}; use crate::pglb::TlsRequired; use crate::pqproto::FeStartupPacket; use crate::protocol2::ConnectionInfo; @@ -80,8 +80,6 @@ pub async fn run() -> anyhow::Result<()> { let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook(); let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]); - Metrics::install(Arc::new(ThreadPoolMetrics::new(0))); - let args = cli().get_matches(); let destination: String = args .get_one::("dest") @@ -135,6 +133,12 @@ pub async fn run() -> anyhow::Result<()> { cancellation_token.clone(), )) .map(crate::error::flatten_err); + + Metrics::get() + .service + .info + .set_label(ServiceInfo::running()); + let signals_task = tokio::spawn(crate::signals::handle(cancellation_token, || {})); // the signal task cant ever succeed. 
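
Note on the local_proxy and pg_sni_router changes above: both binaries stop installing placeholder thread-pool metrics at process start and instead flip a service-info label to "running" only once config, listeners and maintenance tasks are wired up. Below is a std-only sketch of that ordering; ServiceState and SERVICE_STATE are illustrative stand-ins, not the proxy's real Metrics/ServiceInfo types.

use std::sync::atomic::{AtomicU8, Ordering};

#[derive(Debug, Clone, Copy, PartialEq)]
#[repr(u8)]
enum ServiceState {
    Starting = 0,
    Running = 1,
}

// Readiness flag that a health/metrics endpoint could read.
static SERVICE_STATE: AtomicU8 = AtomicU8::new(ServiceState::Starting as u8);

fn set_state(state: ServiceState) {
    SERVICE_STATE.store(state as u8, Ordering::Relaxed);
}

fn current_state() -> ServiceState {
    match SERVICE_STATE.load(Ordering::Relaxed) {
        1 => ServiceState::Running,
        _ => ServiceState::Starting,
    }
}

fn main() {
    // ... bind listeners, build config, spawn accept loops here ...

    // Only after setup completes is the service marked as running, mirroring
    // the diff's `Metrics::get().service.info.set_label(ServiceInfo::running())`.
    set_state(ServiceState::Running);
    assert_eq!(current_state(), ServiceState::Running);
}
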
diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 4148f4bc62..29b0ad53f2 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -40,7 +40,7 @@ use crate::config::{ }; use crate::context::parquet::ParquetUploadArgs; use crate::http::health_server::AppMetrics; -use crate::metrics::Metrics; +use crate::metrics::{Metrics, ServiceInfo}; use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::redis::kv_ops::RedisKVClient; @@ -535,12 +535,7 @@ pub async fn run() -> anyhow::Result<()> { // add a task to flush the db_schema cache every 10 minutes #[cfg(feature = "rest_broker")] if let Some(db_schema_cache) = &config.rest_config.db_schema_cache { - maintenance_tasks.spawn(async move { - loop { - tokio::time::sleep(Duration::from_secs(600)).await; - db_schema_cache.flush(); - } - }); + maintenance_tasks.spawn(db_schema_cache.maintain()); } if let Some(metrics_config) = &config.metric_collection { @@ -590,6 +585,11 @@ pub async fn run() -> anyhow::Result<()> { } } + Metrics::get() + .service + .info + .set_label(ServiceInfo::running()); + let maintenance = loop { // get one complete task match futures::future::select( @@ -617,7 +617,12 @@ pub async fn run() -> anyhow::Result<()> { /// ProxyConfig is created at proxy startup, and lives forever. fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { let thread_pool = ThreadPool::new(args.scram_thread_pool_size); - Metrics::install(thread_pool.metrics.clone()); + Metrics::get() + .proxy + .scram_pool + .0 + .set(thread_pool.metrics.clone()) + .ok(); let tls_config = match (&args.tls_key, &args.tls_cert) { (Some(key_path), Some(cert_path)) => Some(config::configure_tls( @@ -690,7 +695,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { }; let authentication_config = AuthenticationConfig { jwks_cache: JwkCache::default(), - thread_pool, + scram_thread_pool: thread_pool, scram_protocol_timeout: args.scram_protocol_timeout, ip_allowlist_check_enabled: !args.is_private_access_proxy, is_vpc_acccess_proxy: args.is_private_access_proxy, @@ -711,12 +716,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { info!("Using DbSchemaCache with options={db_schema_cache_config:?}"); let db_schema_cache = if args.is_rest_broker { - Some(DbSchemaCache::new( - "db_schema_cache", - db_schema_cache_config.size, - db_schema_cache_config.ttl, - true, - )) + Some(DbSchemaCache::new(db_schema_cache_config)) } else { None }; diff --git a/proxy/src/cache/common.rs b/proxy/src/cache/common.rs index b5caf94788..9a7d0d99cf 100644 --- a/proxy/src/cache/common.rs +++ b/proxy/src/cache/common.rs @@ -1,4 +1,16 @@ use std::ops::{Deref, DerefMut}; +use std::time::{Duration, Instant}; + +use moka::Expiry; +use moka::notification::RemovalCause; + +use crate::control_plane::messages::ControlPlaneErrorMessage; +use crate::metrics::{ + CacheEviction, CacheKind, CacheOutcome, CacheOutcomeGroup, CacheRemovalCause, Metrics, +}; + +/// Default TTL used when caching errors from control plane. +pub const DEFAULT_ERROR_TTL: Duration = Duration::from_secs(30); /// A generic trait which exposes types of cache's key and value, /// as well as the notion of cache entry invalidation. @@ -10,20 +22,16 @@ pub(crate) trait Cache { /// Entry's value. type Value; - /// Used for entry invalidation. 
- type LookupInfo; - /// Invalidate an entry using a lookup info. /// We don't have an empty default impl because it's error-prone. - fn invalidate(&self, _: &Self::LookupInfo); + fn invalidate(&self, _: &Self::Key); } impl Cache for &C { type Key = C::Key; type Value = C::Value; - type LookupInfo = C::LookupInfo; - fn invalidate(&self, info: &Self::LookupInfo) { + fn invalidate(&self, info: &Self::Key) { C::invalidate(self, info); } } @@ -31,7 +39,7 @@ impl Cache for &C { /// Wrapper for convenient entry invalidation. pub(crate) struct Cached::Value> { /// Cache + lookup info. - pub(crate) token: Option<(C, C::LookupInfo)>, + pub(crate) token: Option<(C, C::Key)>, /// The value itself. pub(crate) value: V, @@ -43,23 +51,6 @@ impl Cached { Self { token: None, value } } - pub(crate) fn take_value(self) -> (Cached, V) { - ( - Cached { - token: self.token, - value: (), - }, - self.value, - ) - } - - pub(crate) fn map(self, f: impl FnOnce(V) -> U) -> Cached { - Cached { - token: self.token, - value: f(self.value), - } - } - /// Drop this entry from a cache if it's still there. pub(crate) fn invalidate(self) -> V { if let Some((cache, info)) = &self.token { @@ -87,3 +78,91 @@ impl DerefMut for Cached { &mut self.value } } + +pub type ControlPlaneResult = Result>; + +#[derive(Clone, Copy)] +pub struct CplaneExpiry { + pub error: Duration, +} + +impl Default for CplaneExpiry { + fn default() -> Self { + Self { + error: DEFAULT_ERROR_TTL, + } + } +} + +impl CplaneExpiry { + pub fn expire_early( + &self, + value: &ControlPlaneResult, + updated: Instant, + ) -> Option { + match value { + Ok(_) => None, + Err(err) => Some(self.expire_err_early(err, updated)), + } + } + + pub fn expire_err_early(&self, err: &ControlPlaneErrorMessage, updated: Instant) -> Duration { + err.status + .as_ref() + .and_then(|s| s.details.retry_info.as_ref()) + .map_or(self.error, |r| r.retry_at.into_std() - updated) + } +} + +impl Expiry> for CplaneExpiry { + fn expire_after_create( + &self, + _key: &K, + value: &ControlPlaneResult, + created_at: Instant, + ) -> Option { + self.expire_early(value, created_at) + } + + fn expire_after_update( + &self, + _key: &K, + value: &ControlPlaneResult, + updated_at: Instant, + _duration_until_expiry: Option, + ) -> Option { + self.expire_early(value, updated_at) + } +} + +pub fn eviction_listener(kind: CacheKind, cause: RemovalCause) { + let cause = match cause { + RemovalCause::Expired => CacheRemovalCause::Expired, + RemovalCause::Explicit => CacheRemovalCause::Explicit, + RemovalCause::Replaced => CacheRemovalCause::Replaced, + RemovalCause::Size => CacheRemovalCause::Size, + }; + Metrics::get() + .cache + .evicted_total + .inc(CacheEviction { cache: kind, cause }); +} + +#[inline] +pub fn count_cache_outcome(kind: CacheKind, cache_result: Option) -> Option { + let outcome = if cache_result.is_some() { + CacheOutcome::Hit + } else { + CacheOutcome::Miss + }; + Metrics::get().cache.request_total.inc(CacheOutcomeGroup { + cache: kind, + outcome, + }); + cache_result +} + +#[inline] +pub fn count_cache_insert(kind: CacheKind) { + Metrics::get().cache.inserted_total.inc(kind); +} diff --git a/proxy/src/cache/mod.rs b/proxy/src/cache/mod.rs index ce7f781213..0a607a1409 100644 --- a/proxy/src/cache/mod.rs +++ b/proxy/src/cache/mod.rs @@ -1,6 +1,5 @@ pub(crate) mod common; +pub(crate) mod node_info; pub(crate) mod project_info; -mod timed_lru; -pub(crate) use common::{Cache, Cached}; -pub(crate) use timed_lru::TimedLru; +pub(crate) use common::{Cached, ControlPlaneResult, CplaneExpiry}; 
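
The reworked cache/common.rs above drives per-entry lifetimes through moka's Expiry trait: successful control-plane responses keep the cache-wide TTL/TTI, while cached errors expire after DEFAULT_ERROR_TTL (30 seconds) unless the message carries a retry_at hint. A minimal sketch of that policy follows, assuming the moka crate with its sync feature; FakeError is a simplified stand-in for ControlPlaneErrorMessage and its retry info.

use std::time::{Duration, Instant};

use moka::Expiry;

#[derive(Clone)]
struct FakeError {
    retry_at: Option<Instant>,
}

type Entry = Result<String, FakeError>;

struct ErrExpiry {
    default_error_ttl: Duration, // 30 seconds in the diff
}

impl Expiry<String, Entry> for ErrExpiry {
    fn expire_after_create(
        &self,
        _key: &String,
        value: &Entry,
        created_at: Instant,
    ) -> Option<Duration> {
        match value {
            // Ok entries keep the cache-wide TTL/TTI (no per-entry override).
            Ok(_) => None,
            // Errors live until the control-plane retry_at hint, or the default TTL.
            Err(e) => Some(
                e.retry_at
                    .map_or(self.default_error_ttl, |at| at.saturating_duration_since(created_at)),
            ),
        }
    }
}

fn main() {
    let cache: moka::sync::Cache<String, Entry> = moka::sync::Cache::builder()
        .max_capacity(4_000)
        .time_to_idle(Duration::from_secs(4 * 60))
        .expire_after(ErrExpiry { default_error_ttl: Duration::from_secs(30) })
        .build();

    // An error response is cached, but only for roughly 30 seconds.
    cache.insert("ep-1".to_string(), Err(FakeError { retry_at: None }));
    assert!(cache.get("ep-1").is_some());
}
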
diff --git a/proxy/src/cache/node_info.rs b/proxy/src/cache/node_info.rs new file mode 100644 index 0000000000..47fc7a5b08 --- /dev/null +++ b/proxy/src/cache/node_info.rs @@ -0,0 +1,60 @@ +use crate::cache::common::{Cache, count_cache_insert, count_cache_outcome, eviction_listener}; +use crate::cache::{Cached, ControlPlaneResult, CplaneExpiry}; +use crate::config::CacheOptions; +use crate::control_plane::NodeInfo; +use crate::metrics::{CacheKind, Metrics}; +use crate::types::EndpointCacheKey; + +pub(crate) struct NodeInfoCache(moka::sync::Cache>); +pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; + +impl Cache for NodeInfoCache { + type Key = EndpointCacheKey; + type Value = ControlPlaneResult; + + fn invalidate(&self, info: &EndpointCacheKey) { + self.0.invalidate(info); + } +} + +impl NodeInfoCache { + pub fn new(config: CacheOptions) -> Self { + let builder = moka::sync::Cache::builder() + .name("node_info") + .expire_after(CplaneExpiry::default()); + let builder = config.moka(builder); + + if let Some(size) = config.size { + Metrics::get() + .cache + .capacity + .set(CacheKind::NodeInfo, size as i64); + } + + let builder = builder + .eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::NodeInfo, cause)); + + Self(builder.build()) + } + + pub fn insert(&self, key: EndpointCacheKey, value: ControlPlaneResult) { + count_cache_insert(CacheKind::NodeInfo); + self.0.insert(key, value); + } + + pub fn get(&self, key: &EndpointCacheKey) -> Option> { + count_cache_outcome(CacheKind::NodeInfo, self.0.get(key)) + } + + pub fn get_entry( + &'static self, + key: &EndpointCacheKey, + ) -> Option> { + self.get(key).map(|res| { + res.map(|value| Cached { + token: Some((self, key.clone())), + value, + }) + }) + } +} diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index a589dd175b..f8a38be287 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -1,84 +1,20 @@ -use std::collections::{HashMap, HashSet, hash_map}; +use std::collections::HashSet; use std::convert::Infallible; -use std::time::Duration; -use async_trait::async_trait; use clashmap::ClashMap; -use clashmap::mapref::one::Ref; -use rand::Rng; -use tokio::time::Instant; +use moka::sync::Cache; use tracing::{debug, info}; +use crate::cache::common::{ + ControlPlaneResult, CplaneExpiry, count_cache_insert, count_cache_outcome, eviction_listener, +}; use crate::config::ProjectInfoCacheOptions; use crate::control_plane::messages::{ControlPlaneErrorMessage, Reason}; use crate::control_plane::{EndpointAccessControl, RoleAccessControl}; use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; +use crate::metrics::{CacheKind, Metrics}; use crate::types::{EndpointId, RoleName}; -#[async_trait] -pub(crate) trait ProjectInfoCache { - fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt); - fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt); - fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt); - fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt); -} - -struct Entry { - expires_at: Instant, - value: T, -} - -impl Entry { - pub(crate) fn new(value: T, ttl: Duration) -> Self { - Self { - expires_at: Instant::now() + ttl, - value, - } - } - - pub(crate) fn get(&self) -> Option<&T> { - (!self.is_expired()).then_some(&self.value) - } - - fn is_expired(&self) -> bool { - self.expires_at <= Instant::now() - } -} - -struct EndpointInfo { - 
role_controls: HashMap>>, - controls: Option>>, -} - -type ControlPlaneResult = Result>; - -impl EndpointInfo { - pub(crate) fn get_role_secret_with_ttl( - &self, - role_name: RoleNameInt, - ) -> Option<(ControlPlaneResult, Duration)> { - let entry = self.role_controls.get(&role_name)?; - let ttl = entry.expires_at - Instant::now(); - Some((entry.get()?.clone(), ttl)) - } - - pub(crate) fn get_controls_with_ttl( - &self, - ) -> Option<(ControlPlaneResult, Duration)> { - let entry = self.controls.as_ref()?; - let ttl = entry.expires_at - Instant::now(); - Some((entry.get()?.clone(), ttl)) - } - - pub(crate) fn invalidate_endpoint(&mut self) { - self.controls = None; - } - - pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) { - self.role_controls.remove(&role_name); - } -} - /// Cache for project info. /// This is used to cache auth data for endpoints. /// Invalidation is done by console notifications or by TTL (if console notifications are disabled). @@ -86,8 +22,9 @@ impl EndpointInfo { /// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data. /// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available? /// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache. -pub struct ProjectInfoCacheImpl { - cache: ClashMap, +pub struct ProjectInfoCache { + role_controls: Cache<(EndpointIdInt, RoleNameInt), ControlPlaneResult>, + ep_controls: Cache>, project2ep: ClashMap>, // FIXME(stefan): we need a way to GC the account2ep map. @@ -96,16 +33,13 @@ pub struct ProjectInfoCacheImpl { config: ProjectInfoCacheOptions, } -#[async_trait] -impl ProjectInfoCache for ProjectInfoCacheImpl { - fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) { +impl ProjectInfoCache { + pub fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) { info!("invalidating endpoint access for `{endpoint_id}`"); - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_endpoint(); - } + self.ep_controls.invalidate(&endpoint_id); } - fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) { + pub fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) { info!("invalidating endpoint access for project `{project_id}`"); let endpoints = self .project2ep @@ -113,13 +47,11 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { .map(|kv| kv.value().clone()) .unwrap_or_default(); for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_endpoint(); - } + self.ep_controls.invalidate(&endpoint_id); } } - fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) { + pub fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) { info!("invalidating endpoint access for org `{account_id}`"); let endpoints = self .account2ep @@ -127,13 +59,15 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { .map(|kv| kv.value().clone()) .unwrap_or_default(); for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_endpoint(); - } + self.ep_controls.invalidate(&endpoint_id); } } - fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) { + pub fn invalidate_role_secret_for_project( + &self, + project_id: ProjectIdInt, + role_name: RoleNameInt, + ) { info!( "invalidating 
role secret for project_id `{}` and role_name `{}`", project_id, role_name, @@ -144,47 +78,73 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { .map(|kv| kv.value().clone()) .unwrap_or_default(); for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_role_secret(role_name); - } + self.role_controls.invalidate(&(endpoint_id, role_name)); } } } -impl ProjectInfoCacheImpl { +impl ProjectInfoCache { pub(crate) fn new(config: ProjectInfoCacheOptions) -> Self { + Metrics::get().cache.capacity.set( + CacheKind::ProjectInfoRoles, + (config.size * config.max_roles) as i64, + ); + Metrics::get() + .cache + .capacity + .set(CacheKind::ProjectInfoEndpoints, config.size as i64); + + // we cache errors for 30 seconds, unless retry_at is set. + let expiry = CplaneExpiry::default(); Self { - cache: ClashMap::new(), + role_controls: Cache::builder() + .name("project_info_roles") + .eviction_listener(|_k, _v, cause| { + eviction_listener(CacheKind::ProjectInfoRoles, cause); + }) + .max_capacity(config.size * config.max_roles) + .time_to_live(config.ttl) + .expire_after(expiry) + .build(), + ep_controls: Cache::builder() + .name("project_info_endpoints") + .eviction_listener(|_k, _v, cause| { + eviction_listener(CacheKind::ProjectInfoEndpoints, cause); + }) + .max_capacity(config.size) + .time_to_live(config.ttl) + .expire_after(expiry) + .build(), project2ep: ClashMap::new(), account2ep: ClashMap::new(), config, } } - fn get_endpoint_cache( - &self, - endpoint_id: &EndpointId, - ) -> Option> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; - self.cache.get(&endpoint_id) - } - - pub(crate) fn get_role_secret_with_ttl( + pub(crate) fn get_role_secret( &self, endpoint_id: &EndpointId, role_name: &RoleName, - ) -> Option<(ControlPlaneResult, Duration)> { + ) -> Option> { + let endpoint_id = EndpointIdInt::get(endpoint_id)?; let role_name = RoleNameInt::get(role_name)?; - let endpoint_info = self.get_endpoint_cache(endpoint_id)?; - endpoint_info.get_role_secret_with_ttl(role_name) + + count_cache_outcome( + CacheKind::ProjectInfoRoles, + self.role_controls.get(&(endpoint_id, role_name)), + ) } - pub(crate) fn get_endpoint_access_with_ttl( + pub(crate) fn get_endpoint_access( &self, endpoint_id: &EndpointId, - ) -> Option<(ControlPlaneResult, Duration)> { - let endpoint_info = self.get_endpoint_cache(endpoint_id)?; - endpoint_info.get_controls_with_ttl() + ) -> Option> { + let endpoint_id = EndpointIdInt::get(endpoint_id)?; + + count_cache_outcome( + CacheKind::ProjectInfoEndpoints, + self.ep_controls.get(&endpoint_id), + ) } pub(crate) fn insert_endpoint_access( @@ -203,34 +163,17 @@ impl ProjectInfoCacheImpl { self.insert_project2endpoint(project_id, endpoint_id); } - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. 
- return; - } - debug!( key = &*endpoint_id, "created a cache entry for endpoint access" ); - let controls = Some(Entry::new(Ok(controls), self.config.ttl)); - let role_controls = Entry::new(Ok(role_controls), self.config.ttl); + count_cache_insert(CacheKind::ProjectInfoEndpoints); + count_cache_insert(CacheKind::ProjectInfoRoles); - match self.cache.entry(endpoint_id) { - clashmap::Entry::Vacant(e) => { - e.insert(EndpointInfo { - role_controls: HashMap::from_iter([(role_name, role_controls)]), - controls, - }); - } - clashmap::Entry::Occupied(mut e) => { - let ep = e.get_mut(); - ep.controls = controls; - if ep.role_controls.len() < self.config.max_roles { - ep.role_controls.insert(role_name, role_controls); - } - } - } + self.ep_controls.insert(endpoint_id, Ok(controls)); + self.role_controls + .insert((endpoint_id, role_name), Ok(role_controls)); } pub(crate) fn insert_endpoint_access_err( @@ -238,55 +181,34 @@ impl ProjectInfoCacheImpl { endpoint_id: EndpointIdInt, role_name: RoleNameInt, msg: Box, - ttl: Option, ) { - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. - return; - } - debug!( key = &*endpoint_id, "created a cache entry for an endpoint access error" ); - let ttl = ttl.unwrap_or(self.config.ttl); - - let controls = if msg.get_reason() == Reason::RoleProtected { - // RoleProtected is the only role-specific error that control plane can give us. - // If a given role name does not exist, it still returns a successful response, - // just with an empty secret. - None - } else { - // We can cache all the other errors in EndpointInfo.controls, - // because they don't depend on what role name we pass to control plane. - Some(Entry::new(Err(msg.clone()), ttl)) - }; - - let role_controls = Entry::new(Err(msg), ttl); - - match self.cache.entry(endpoint_id) { - clashmap::Entry::Vacant(e) => { - e.insert(EndpointInfo { - role_controls: HashMap::from_iter([(role_name, role_controls)]), - controls, + // RoleProtected is the only role-specific error that control plane can give us. + // If a given role name does not exist, it still returns a successful response, + // just with an empty secret. + if msg.get_reason() != Reason::RoleProtected { + // We can cache all the other errors in ep_controls because they don't + // depend on what role name we pass to control plane. + self.ep_controls + .entry(endpoint_id) + .and_compute_with(|entry| match entry { + // leave the entry alone if it's already Ok + Some(entry) if entry.value().is_ok() => moka::ops::compute::Op::Nop, + // replace the entry + _ => { + count_cache_insert(CacheKind::ProjectInfoEndpoints); + moka::ops::compute::Op::Put(Err(msg.clone())) + } }); - } - clashmap::Entry::Occupied(mut e) => { - let ep = e.get_mut(); - if let Some(entry) = &ep.controls - && !entry.is_expired() - && entry.value.is_ok() - { - // If we have cached non-expired, non-error controls, keep them. 
- } else { - ep.controls = controls; - } - if ep.role_controls.len() < self.config.max_roles { - ep.role_controls.insert(role_name, role_controls); - } - } } + + count_cache_insert(CacheKind::ProjectInfoRoles); + self.role_controls + .insert((endpoint_id, role_name), Err(msg)); } fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) { @@ -307,73 +229,35 @@ impl ProjectInfoCacheImpl { } } - pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) { - let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else { - return; - }; - let Some(role_name) = RoleNameInt::get(role_name) else { - return; - }; - - let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) else { - return; - }; - - let entry = endpoint_info.role_controls.entry(role_name); - let hash_map::Entry::Occupied(role_controls) = entry else { - return; - }; - - if role_controls.get().is_expired() { - role_controls.remove(); - } + pub fn maybe_invalidate_role_secret(&self, _endpoint_id: &EndpointId, _role_name: &RoleName) { + // TODO: Expire the value early if the key is idle. + // Currently not an issue as we would just use the TTL to decide, which is what already happens. } pub async fn gc_worker(&self) -> anyhow::Result { - let mut interval = - tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32); + let mut interval = tokio::time::interval(self.config.gc_interval); loop { interval.tick().await; - if self.cache.len() < self.config.size { - // If there are not too many entries, wait until the next gc cycle. - continue; - } - self.gc(); + self.ep_controls.run_pending_tasks(); + self.role_controls.run_pending_tasks(); } } - - fn gc(&self) { - let shard = rand::rng().random_range(0..self.project2ep.shards().len()); - debug!(shard, "project_info_cache: performing epoch reclamation"); - - // acquire a random shard lock - let mut removed = 0; - let shard = self.project2ep.shards()[shard].write(); - for (_, endpoints) in shard.iter() { - for endpoint in endpoints { - self.cache.remove(endpoint); - removed += 1; - } - } - // We can drop this shard only after making sure that all endpoints are removed. - drop(shard); - info!("project_info_cache: removed {removed} endpoints"); - } } #[cfg(test)] mod tests { + use std::sync::Arc; + use std::time::Duration; + use super::*; use crate::control_plane::messages::{Details, EndpointRateLimitConfig, ErrorInfo, Status}; use crate::control_plane::{AccessBlockerFlags, AuthSecret}; use crate::scram::ServerSecret; - use std::sync::Arc; #[tokio::test] async fn test_project_info_cache_settings() { - tokio::time::pause(); - let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { - size: 2, + let cache = ProjectInfoCache::new(ProjectInfoCacheOptions { + size: 1, max_roles: 2, ttl: Duration::from_secs(1), gc_interval: Duration::from_secs(600), @@ -423,22 +307,17 @@ mod tests { }, ); - let (cached, ttl) = cache - .get_role_secret_with_ttl(&endpoint_id, &user1) - .unwrap(); + let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); assert_eq!(cached.unwrap().secret, secret1); - assert_eq!(ttl, cache.config.ttl); - let (cached, ttl) = cache - .get_role_secret_with_ttl(&endpoint_id, &user2) - .unwrap(); + let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); assert_eq!(cached.unwrap().secret, secret2); - assert_eq!(ttl, cache.config.ttl); // Shouldn't add more than 2 roles. 
let user3: RoleName = "user3".into(); let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32]))); + cache.role_controls.run_pending_tasks(); cache.insert_endpoint_access( account_id, project_id, @@ -455,31 +334,18 @@ mod tests { }, ); - assert!( - cache - .get_role_secret_with_ttl(&endpoint_id, &user3) - .is_none() - ); + cache.role_controls.run_pending_tasks(); + assert_eq!(cache.role_controls.entry_count(), 2); - let cached = cache - .get_endpoint_access_with_ttl(&endpoint_id) - .unwrap() - .0 - .unwrap(); - assert_eq!(cached.allowed_ips, allowed_ips); + tokio::time::sleep(Duration::from_secs(2)).await; - tokio::time::advance(Duration::from_secs(2)).await; - let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user1); - assert!(cached.is_none()); - let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user2); - assert!(cached.is_none()); - let cached = cache.get_endpoint_access_with_ttl(&endpoint_id); - assert!(cached.is_none()); + cache.role_controls.run_pending_tasks(); + assert_eq!(cache.role_controls.entry_count(), 0); } #[tokio::test] async fn test_caching_project_info_errors() { - let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { + let cache = ProjectInfoCache::new(ProjectInfoCacheOptions { size: 10, max_roles: 10, ttl: Duration::from_secs(1), @@ -519,34 +385,23 @@ mod tests { status: None, }); - let get_role_secret = |endpoint_id, role_name| { - cache - .get_role_secret_with_ttl(endpoint_id, role_name) - .unwrap() - .0 - }; - let get_endpoint_access = - |endpoint_id| cache.get_endpoint_access_with_ttl(endpoint_id).unwrap().0; + let get_role_secret = + |endpoint_id, role_name| cache.get_role_secret(endpoint_id, role_name).unwrap(); + let get_endpoint_access = |endpoint_id| cache.get_endpoint_access(endpoint_id).unwrap(); // stores role-specific errors only for get_role_secret - cache.insert_endpoint_access_err( - (&endpoint_id).into(), - (&user1).into(), - role_msg.clone(), - None, - ); + cache.insert_endpoint_access_err((&endpoint_id).into(), (&user1).into(), role_msg.clone()); assert_eq!( get_role_secret(&endpoint_id, &user1).unwrap_err().error, role_msg.error ); - assert!(cache.get_endpoint_access_with_ttl(&endpoint_id).is_none()); + assert!(cache.get_endpoint_access(&endpoint_id).is_none()); // stores non-role specific errors for both get_role_secret and get_endpoint_access cache.insert_endpoint_access_err( (&endpoint_id).into(), (&user1).into(), generic_msg.clone(), - None, ); assert_eq!( get_role_secret(&endpoint_id, &user1).unwrap_err().error, @@ -558,11 +413,7 @@ mod tests { ); // error isn't returned for other roles in the same endpoint - assert!( - cache - .get_role_secret_with_ttl(&endpoint_id, &user2) - .is_none() - ); + assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); // success for a role does not overwrite errors for other roles cache.insert_endpoint_access( @@ -590,7 +441,6 @@ mod tests { (&endpoint_id).into(), (&user2).into(), generic_msg.clone(), - None, ); assert!(get_role_secret(&endpoint_id, &user2).is_err()); assert!(get_endpoint_access(&endpoint_id).is_ok()); diff --git a/proxy/src/cache/timed_lru.rs b/proxy/src/cache/timed_lru.rs deleted file mode 100644 index 0a7fb40b0c..0000000000 --- a/proxy/src/cache/timed_lru.rs +++ /dev/null @@ -1,262 +0,0 @@ -use std::borrow::Borrow; -use std::hash::Hash; -use std::time::{Duration, Instant}; - -// This seems to make more sense than `lru` or `cached`: -// -// * `near/nearcore` ditched `cached` in favor of `lru` -// 
(https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed). -// -// * `lru` methods use an obscure `KeyRef` type in their contraints (which is deliberately excluded from docs). -// This severely hinders its usage both in terms of creating wrappers and supported key types. -// -// On the other hand, `hashlink` has good download stats and appears to be maintained. -use hashlink::{LruCache, linked_hash_map::RawEntryMut}; -use tracing::debug; - -use super::Cache; -use super::common::Cached; - -/// An implementation of timed LRU cache with fixed capacity. -/// Key properties: -/// -/// * Whenever a new entry is inserted, the least recently accessed one is evicted. -/// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`). -/// -/// * If `update_ttl_on_retrieval` is `true`. When the entry is about to be retrieved, we check its expiration timestamp. -/// If the entry has expired, we remove it from the cache; Otherwise we bump the -/// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong -/// its existence. -/// -/// * There's an API for immediate invalidation (removal) of a cache entry; -/// It's useful in case we know for sure that the entry is no longer correct. -/// See [`Cached`] for more information. -/// -/// * Expired entries are kept in the cache, until they are evicted by the LRU policy, -/// or by a successful lookup (i.e. the entry hasn't expired yet). -/// There is no background job to reap the expired records. -/// -/// * It's possible for an entry that has not yet expired entry to be evicted -/// before expired items. That's a bit wasteful, but probably fine in practice. -pub(crate) struct TimedLru { - /// Cache's name for tracing. - name: &'static str, - - /// The underlying cache implementation. - cache: parking_lot::Mutex>>, - - /// Default time-to-live of a single entry. - ttl: Duration, - - update_ttl_on_retrieval: bool, -} - -impl Cache for TimedLru { - type Key = K; - type Value = V; - type LookupInfo = Key; - - fn invalidate(&self, info: &Self::LookupInfo) { - self.invalidate_raw(info); - } -} - -struct Entry { - created_at: Instant, - expires_at: Instant, - ttl: Duration, - update_ttl_on_retrieval: bool, - value: T, -} - -impl TimedLru { - /// Construct a new LRU cache with timed entries. - pub(crate) fn new( - name: &'static str, - capacity: usize, - ttl: Duration, - update_ttl_on_retrieval: bool, - ) -> Self { - Self { - name, - cache: LruCache::new(capacity).into(), - ttl, - update_ttl_on_retrieval, - } - } - - /// Drop an entry from the cache if it's outdated. - #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] - fn invalidate_raw(&self, key: &K) { - // Do costly things before taking the lock. - let mut cache = self.cache.lock(); - let entry = match cache.raw_entry_mut().from_key(key) { - RawEntryMut::Vacant(_) => return, - RawEntryMut::Occupied(x) => x.remove(), - }; - drop(cache); // drop lock before logging - - let Entry { - created_at, - expires_at, - .. - } = entry; - - debug!( - ?created_at, - ?expires_at, - "processed a cache entry invalidation event" - ); - } - - /// Try retrieving an entry by its key, then execute `extract` if it exists. - #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] - fn get_raw(&self, key: &Q, extract: impl FnOnce(&K, &Entry) -> R) -> Option - where - K: Borrow, - Q: Hash + Eq + ?Sized, - { - let now = Instant::now(); - - // Do costly things before taking the lock. 
- let mut cache = self.cache.lock(); - let mut raw_entry = match cache.raw_entry_mut().from_key(key) { - RawEntryMut::Vacant(_) => return None, - RawEntryMut::Occupied(x) => x, - }; - - // Immeditely drop the entry if it has expired. - let entry = raw_entry.get(); - if entry.expires_at <= now { - raw_entry.remove(); - return None; - } - - let value = extract(raw_entry.key(), entry); - let (created_at, expires_at) = (entry.created_at, entry.expires_at); - - // Update the deadline and the entry's position in the LRU list. - let deadline = now.checked_add(raw_entry.get().ttl).expect("time overflow"); - if raw_entry.get().update_ttl_on_retrieval { - raw_entry.get_mut().expires_at = deadline; - } - raw_entry.to_back(); - - drop(cache); // drop lock before logging - debug!( - created_at = format_args!("{created_at:?}"), - old_expires_at = format_args!("{expires_at:?}"), - new_expires_at = format_args!("{deadline:?}"), - "accessed a cache entry" - ); - - Some(value) - } - - /// Insert an entry to the cache. If an entry with the same key already - /// existed, return the previous value and its creation timestamp. - #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] - fn insert_raw(&self, key: K, value: V) -> (Instant, Option) { - self.insert_raw_ttl(key, value, self.ttl, self.update_ttl_on_retrieval) - } - - /// Insert an entry to the cache. If an entry with the same key already - /// existed, return the previous value and its creation timestamp. - #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] - fn insert_raw_ttl( - &self, - key: K, - value: V, - ttl: Duration, - update: bool, - ) -> (Instant, Option) { - let created_at = Instant::now(); - let expires_at = created_at.checked_add(ttl).expect("time overflow"); - - let entry = Entry { - created_at, - expires_at, - ttl, - update_ttl_on_retrieval: update, - value, - }; - - // Do costly things before taking the lock. - let old = self - .cache - .lock() - .insert(key, entry) - .map(|entry| entry.value); - - debug!( - created_at = format_args!("{created_at:?}"), - expires_at = format_args!("{expires_at:?}"), - replaced = old.is_some(), - "created a cache entry" - ); - - (created_at, old) - } -} - -impl TimedLru { - pub(crate) fn insert_ttl(&self, key: K, value: V, ttl: Duration) { - self.insert_raw_ttl(key, value, ttl, false); - } - - #[cfg(feature = "rest_broker")] - pub(crate) fn insert(&self, key: K, value: V) { - self.insert_raw_ttl(key, value, self.ttl, self.update_ttl_on_retrieval); - } - - pub(crate) fn insert_unit(&self, key: K, value: V) -> (Option, Cached<&Self, ()>) { - let (_, old) = self.insert_raw(key.clone(), value); - - let cached = Cached { - token: Some((self, key)), - value: (), - }; - - (old, cached) - } - - #[cfg(feature = "rest_broker")] - pub(crate) fn flush(&self) { - let now = Instant::now(); - let mut cache = self.cache.lock(); - - // Collect keys of expired entries first - let expired_keys: Vec<_> = cache - .iter() - .filter_map(|(key, entry)| { - if entry.expires_at <= now { - Some(key.clone()) - } else { - None - } - }) - .collect(); - - // Remove expired entries - for key in expired_keys { - cache.remove(&key); - } - } -} - -impl TimedLru { - /// Retrieve a cached entry in convenient wrapper, alongside timing information. 
- pub(crate) fn get_with_created_at( - &self, - key: &Q, - ) -> Option::Value, Instant)>> - where - K: Borrow + Clone, - Q: Hash + Eq + ?Sized, - { - self.get_raw(key, |key, entry| Cached { - token: Some((self, key.clone())), - value: (entry.value.clone(), entry.created_at), - }) - } -} diff --git a/proxy/src/compute/mod.rs b/proxy/src/compute/mod.rs index ca784423ee..43cfe70206 100644 --- a/proxy/src/compute/mod.rs +++ b/proxy/src/compute/mod.rs @@ -8,6 +8,7 @@ use futures::{FutureExt, TryFutureExt}; use itertools::Itertools; use postgres_client::config::{AuthKeys, ChannelBinding, SslMode}; use postgres_client::connect_raw::StartupStream; +use postgres_client::error::SqlState; use postgres_client::maybe_tls_stream::MaybeTlsStream; use postgres_client::tls::MakeTlsConnect; use thiserror::Error; @@ -22,7 +23,7 @@ use crate::context::RequestContext; use crate::control_plane::client::ApiLockError; use crate::control_plane::errors::WakeComputeError; use crate::control_plane::messages::MetricsAuxInfo; -use crate::error::{ReportableError, UserFacingError}; +use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumDbConnectionsGuard}; use crate::pqproto::StartupMessageParams; use crate::proxy::connect_compute::TlsNegotiation; @@ -65,12 +66,13 @@ impl UserFacingError for PostgresError { } impl ReportableError for PostgresError { - fn get_error_kind(&self) -> crate::error::ErrorKind { + fn get_error_kind(&self) -> ErrorKind { match self { - PostgresError::Postgres(e) if e.as_db_error().is_some() => { - crate::error::ErrorKind::Postgres - } - PostgresError::Postgres(_) => crate::error::ErrorKind::Compute, + PostgresError::Postgres(err) => match err.as_db_error() { + Some(err) if err.code() == &SqlState::INVALID_CATALOG_NAME => ErrorKind::User, + Some(_) => ErrorKind::Postgres, + None => ErrorKind::Compute, + }, } } } @@ -110,9 +112,9 @@ impl UserFacingError for ConnectionError { } impl ReportableError for ConnectionError { - fn get_error_kind(&self) -> crate::error::ErrorKind { + fn get_error_kind(&self) -> ErrorKind { match self { - ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute, + ConnectionError::TlsError(_) => ErrorKind::Compute, ConnectionError::WakeComputeError(e) => e.get_error_kind(), ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(), #[cfg(test)] diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 16b1dff5f4..22902dbcab 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -19,7 +19,7 @@ use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings}; use crate::ext::TaskExt; use crate::intern::RoleNameInt; use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig}; -use crate::scram::threadpool::ThreadPool; +use crate::scram; use crate::serverless::GlobalConnPoolOptions; use crate::serverless::cancel_set::CancelSet; #[cfg(feature = "rest_broker")] @@ -75,7 +75,7 @@ pub struct HttpConfig { } pub struct AuthenticationConfig { - pub thread_pool: Arc, + pub scram_thread_pool: Arc, pub scram_protocol_timeout: tokio::time::Duration, pub ip_allowlist_check_enabled: bool, pub is_vpc_acccess_proxy: bool, @@ -107,20 +107,23 @@ pub fn remote_storage_from_toml(s: &str) -> anyhow::Result #[derive(Debug)] pub struct CacheOptions { /// Max number of entries. - pub size: usize, + pub size: Option, /// Entry's time-to-live. - pub ttl: Duration, + pub absolute_ttl: Option, + /// Entry's time-to-idle. 
+ pub idle_ttl: Option, } impl CacheOptions { - /// Default options for [`crate::control_plane::NodeInfoCache`]. - pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m"; + /// Default options for [`crate::cache::node_info::NodeInfoCache`]. + pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,idle_ttl=4m"; /// Parse cache options passed via cmdline. /// Example: [`Self::CACHE_DEFAULT_OPTIONS`]. fn parse(options: &str) -> anyhow::Result { let mut size = None; - let mut ttl = None; + let mut absolute_ttl = None; + let mut idle_ttl = None; for option in options.split(',') { let (key, value) = option @@ -129,21 +132,34 @@ impl CacheOptions { match key { "size" => size = Some(value.parse()?), - "ttl" => ttl = Some(humantime::parse_duration(value)?), + "absolute_ttl" | "ttl" => absolute_ttl = Some(humantime::parse_duration(value)?), + "idle_ttl" | "tti" => idle_ttl = Some(humantime::parse_duration(value)?), unknown => bail!("unknown key: {unknown}"), } } - // TTL doesn't matter if cache is always empty. - if let Some(0) = size { - ttl.get_or_insert(Duration::default()); - } - Ok(Self { - size: size.context("missing `size`")?, - ttl: ttl.context("missing `ttl`")?, + size, + absolute_ttl, + idle_ttl, }) } + + pub fn moka( + &self, + mut builder: moka::sync::CacheBuilder, + ) -> moka::sync::CacheBuilder { + if let Some(size) = self.size { + builder = builder.max_capacity(size); + } + if let Some(ttl) = self.absolute_ttl { + builder = builder.time_to_live(ttl); + } + if let Some(tti) = self.idle_ttl { + builder = builder.time_to_idle(tti); + } + builder + } } impl FromStr for CacheOptions { @@ -159,17 +175,17 @@ impl FromStr for CacheOptions { #[derive(Debug)] pub struct ProjectInfoCacheOptions { /// Max number of entries. - pub size: usize, + pub size: u64, /// Entry's time-to-live. pub ttl: Duration, /// Max number of roles per endpoint. - pub max_roles: usize, + pub max_roles: u64, /// Gc interval. pub gc_interval: Duration, } impl ProjectInfoCacheOptions { - /// Default options for [`crate::control_plane::NodeInfoCache`]. + /// Default options for [`crate::cache::project_info::ProjectInfoCache`]. 
pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=10000,ttl=4m,max_roles=10,gc_interval=60m"; @@ -496,21 +512,37 @@ mod tests { #[test] fn test_parse_cache_options() -> anyhow::Result<()> { - let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?; - assert_eq!(size, 4096); - assert_eq!(ttl, Duration::from_secs(5 * 60)); + let CacheOptions { + size, + absolute_ttl, + idle_ttl: _, + } = "size=4096,ttl=5min".parse()?; + assert_eq!(size, Some(4096)); + assert_eq!(absolute_ttl, Some(Duration::from_secs(5 * 60))); - let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?; - assert_eq!(size, 2); - assert_eq!(ttl, Duration::from_secs(4 * 60)); + let CacheOptions { + size, + absolute_ttl, + idle_ttl: _, + } = "ttl=4m,size=2".parse()?; + assert_eq!(size, Some(2)); + assert_eq!(absolute_ttl, Some(Duration::from_secs(4 * 60))); - let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?; - assert_eq!(size, 0); - assert_eq!(ttl, Duration::from_secs(1)); + let CacheOptions { + size, + absolute_ttl, + idle_ttl: _, + } = "size=0,ttl=1s".parse()?; + assert_eq!(size, Some(0)); + assert_eq!(absolute_ttl, Some(Duration::from_secs(1))); - let CacheOptions { size, ttl } = "size=0".parse()?; - assert_eq!(size, 0); - assert_eq!(ttl, Duration::default()); + let CacheOptions { + size, + absolute_ttl, + idle_ttl: _, + } = "size=0".parse()?; + assert_eq!(size, Some(0)); + assert_eq!(absolute_ttl, None); Ok(()) } diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index 8a0403c0b0..b76b13e2c2 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -3,7 +3,6 @@ use std::net::IpAddr; use std::str::FromStr; use std::sync::Arc; -use std::time::Duration; use ::http::HeaderName; use ::http::header::AUTHORIZATION; @@ -17,6 +16,8 @@ use tracing::{Instrument, debug, info, info_span, warn}; use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute}; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; +use crate::cache::Cached; +use crate::cache::node_info::CachedNodeInfo; use crate::context::RequestContext; use crate::control_plane::caches::ApiCaches; use crate::control_plane::errors::{ @@ -25,8 +26,7 @@ use crate::control_plane::errors::{ use crate::control_plane::locks::ApiLocks; use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse}; use crate::control_plane::{ - AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo, - RoleAccessControl, + AccessBlockerFlags, AuthInfo, AuthSecret, EndpointAccessControl, NodeInfo, RoleAccessControl, }; use crate::metrics::Metrics; use crate::proxy::retry::CouldRetry; @@ -118,7 +118,6 @@ impl NeonControlPlaneClient { cache_key.into(), role.into(), msg.clone(), - retry_info.map(|r| Duration::from_millis(r.retry_delay_ms)), ); Err(err) @@ -347,18 +346,11 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { ) -> Result { let key = endpoint.normalize(); - if let Some((role_control, ttl)) = self - .caches - .project_info - .get_role_secret_with_ttl(&key, role) - { + if let Some(role_control) = self.caches.project_info.get_role_secret(&key, role) { return match role_control { - Err(mut msg) => { + Err(msg) => { info!(key = &*key, "found cached get_role_access_control error"); - // if retry_delay_ms is set change it to the remaining TTL - replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64); - 
Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg))) } Ok(role_control) => { @@ -383,17 +375,14 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { ) -> Result { let key = endpoint.normalize(); - if let Some((control, ttl)) = self.caches.project_info.get_endpoint_access_with_ttl(&key) { + if let Some(control) = self.caches.project_info.get_endpoint_access(&key) { return match control { - Err(mut msg) => { + Err(msg) => { info!( key = &*key, "found cached get_endpoint_access_control error" ); - // if retry_delay_ms is set change it to the remaining TTL - replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64); - Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg))) } Ok(control) => { @@ -426,17 +415,11 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { macro_rules! check_cache { () => { - if let Some(cached) = self.caches.node_info.get_with_created_at(&key) { - let (cached, (info, created_at)) = cached.take_value(); + if let Some(info) = self.caches.node_info.get_entry(&key) { return match info { - Err(mut msg) => { + Err(msg) => { info!(key = &*key, "found cached wake_compute error"); - // if retry_delay_ms is set, reduce it by the amount of time it spent in cache - replace_retry_delay_ms(&mut msg, |delay| { - delay.saturating_sub(created_at.elapsed().as_millis() as u64) - }); - Err(WakeComputeError::ControlPlane(ControlPlaneError::Message( msg, ))) @@ -444,7 +427,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { Ok(info) => { debug!(key = &*key, "found cached compute node info"); ctx.set_project(info.aux.clone()); - Ok(cached.map(|()| info)) + Ok(info) } }; } @@ -483,10 +466,12 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { let mut stored_node = node.clone(); // store the cached node as 'warm_cached' stored_node.aux.cold_start_info = ColdStartInfo::WarmCached; + self.caches.node_info.insert(key.clone(), Ok(stored_node)); - let (_, cached) = self.caches.node_info.insert_unit(key, Ok(stored_node)); - - Ok(cached.map(|()| node)) + Ok(Cached { + token: Some((&self.caches.node_info, key)), + value: node, + }) } Err(err) => match err { WakeComputeError::ControlPlane(ControlPlaneError::Message(ref msg)) => { @@ -503,11 +488,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { "created a cache entry for the wake compute error" ); - let ttl = retry_info.map_or(Duration::from_secs(30), |r| { - Duration::from_millis(r.retry_delay_ms) - }); - - self.caches.node_info.insert_ttl(key, Err(msg.clone()), ttl); + self.caches.node_info.insert(key, Err(msg.clone())); Err(err) } @@ -517,14 +498,6 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { } } -fn replace_retry_delay_ms(msg: &mut ControlPlaneErrorMessage, f: impl FnOnce(u64) -> u64) { - if let Some(status) = &mut msg.status - && let Some(retry_info) = &mut status.details.retry_info - { - retry_info.retry_delay_ms = f(retry_info.retry_delay_ms); - } -} - /// Parse http response body, taking status code into account. 
fn parse_body serde::Deserialize<'a>>( status: StatusCode, diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index b84dba6b09..9e48d91340 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -15,6 +15,7 @@ use crate::auth::IpPattern; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; use crate::cache::Cached; +use crate::cache::node_info::CachedNodeInfo; use crate::compute::ConnectInfo; use crate::context::RequestContext; use crate::control_plane::errors::{ @@ -22,8 +23,7 @@ use crate::control_plane::errors::{ }; use crate::control_plane::messages::{EndpointRateLimitConfig, MetricsAuxInfo}; use crate::control_plane::{ - AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo, - RoleAccessControl, + AccessBlockerFlags, AuthInfo, AuthSecret, EndpointAccessControl, NodeInfo, RoleAccessControl, }; use crate::intern::RoleNameInt; use crate::scram; diff --git a/proxy/src/control_plane/client/mod.rs b/proxy/src/control_plane/client/mod.rs index ecd4db29b2..ec26746873 100644 --- a/proxy/src/control_plane/client/mod.rs +++ b/proxy/src/control_plane/client/mod.rs @@ -13,10 +13,11 @@ use tracing::{debug, info}; use super::{EndpointAccessControl, RoleAccessControl}; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError}; -use crate::cache::project_info::ProjectInfoCacheImpl; +use crate::cache::node_info::{CachedNodeInfo, NodeInfoCache}; +use crate::cache::project_info::ProjectInfoCache; use crate::config::{CacheOptions, ProjectInfoCacheOptions}; use crate::context::RequestContext; -use crate::control_plane::{CachedNodeInfo, ControlPlaneApi, NodeInfoCache, errors}; +use crate::control_plane::{ControlPlaneApi, errors}; use crate::error::ReportableError; use crate::metrics::ApiLockMetrics; use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token}; @@ -119,7 +120,7 @@ pub struct ApiCaches { /// Cache for the `wake_compute` API method. pub(crate) node_info: NodeInfoCache, /// Cache which stores project_id -> endpoint_ids mapping. 
- pub project_info: Arc, + pub project_info: Arc, } impl ApiCaches { @@ -128,13 +129,8 @@ impl ApiCaches { project_info_cache_config: ProjectInfoCacheOptions, ) -> Self { Self { - node_info: NodeInfoCache::new( - "node_info_cache", - wake_compute_cache_config.size, - wake_compute_cache_config.ttl, - true, - ), - project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)), + node_info: NodeInfoCache::new(wake_compute_cache_config), + project_info: Arc::new(ProjectInfoCache::new(project_info_cache_config)), } } } diff --git a/proxy/src/control_plane/messages.rs b/proxy/src/control_plane/messages.rs index d44d7efcc3..a23ddeb5c8 100644 --- a/proxy/src/control_plane/messages.rs +++ b/proxy/src/control_plane/messages.rs @@ -1,8 +1,10 @@ use std::fmt::{self, Display}; +use std::time::Duration; use measured::FixedCardinalityLabel; use serde::{Deserialize, Serialize}; use smol_str::SmolStr; +use tokio::time::Instant; use crate::auth::IpPattern; use crate::intern::{AccountIdInt, BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; @@ -231,7 +233,13 @@ impl Reason { #[derive(Copy, Clone, Debug, Deserialize)] #[allow(dead_code)] pub(crate) struct RetryInfo { - pub(crate) retry_delay_ms: u64, + #[serde(rename = "retry_delay_ms", deserialize_with = "milliseconds_from_now")] + pub(crate) retry_at: Instant, +} + +fn milliseconds_from_now<'de, D: serde::Deserializer<'de>>(d: D) -> Result { + let millis = u64::deserialize(d)?; + Ok(Instant::now() + Duration::from_millis(millis)) } #[derive(Debug, Deserialize, Clone)] diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index 5bfa24c92d..6f326d789a 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -16,13 +16,13 @@ use messages::EndpointRateLimitConfig; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list}; -use crate::cache::{Cached, TimedLru}; +use crate::cache::node_info::CachedNodeInfo; use crate::context::RequestContext; -use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo}; +use crate::control_plane::messages::MetricsAuxInfo; use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt}; use crate::protocol2::ConnectionInfoExtra; use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig}; -use crate::types::{EndpointCacheKey, EndpointId, RoleName}; +use crate::types::{EndpointId, RoleName}; use crate::{compute, scram}; /// Various cache-related types. 
@@ -77,10 +77,6 @@ pub(crate) struct AccessBlockerFlags { pub vpc_access_blocked: bool, } -pub(crate) type NodeInfoCache = - TimedLru>>; -pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; - #[derive(Clone, Debug)] pub struct RoleAccessControl { pub secret: Option, diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 7524133093..905c9b5279 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -2,59 +2,60 @@ use std::sync::{Arc, OnceLock}; use lasso::ThreadedRodeo; use measured::label::{ - FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue, StaticLabelSet, + FixedCardinalitySet, LabelGroupSet, LabelGroupVisitor, LabelName, LabelSet, LabelValue, + StaticLabelSet, }; +use measured::metric::group::Encoding; use measured::metric::histogram::Thresholds; use measured::metric::name::MetricName; use measured::{ - Counter, CounterVec, FixedCardinalityLabel, Gauge, Histogram, HistogramVec, LabelGroup, - MetricGroup, + Counter, CounterVec, FixedCardinalityLabel, Gauge, GaugeVec, Histogram, HistogramVec, + LabelGroup, MetricGroup, }; -use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLogVec}; +use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLogVec, InfoMetric}; use tokio::time::{self, Instant}; use crate::control_plane::messages::ColdStartInfo; use crate::error::ErrorKind; #[derive(MetricGroup)] -#[metric(new(thread_pool: Arc))] +#[metric(new())] pub struct Metrics { #[metric(namespace = "proxy")] - #[metric(init = ProxyMetrics::new(thread_pool))] + #[metric(init = ProxyMetrics::new())] pub proxy: ProxyMetrics, #[metric(namespace = "wake_compute_lock")] pub wake_compute_lock: ApiLockMetrics, + + #[metric(namespace = "service")] + pub service: ServiceMetrics, + + #[metric(namespace = "cache")] + pub cache: CacheMetrics, } -static SELF: OnceLock = OnceLock::new(); impl Metrics { - pub fn install(thread_pool: Arc) { - let mut metrics = Metrics::new(thread_pool); - - metrics.proxy.errors_total.init_all_dense(); - metrics.proxy.redis_errors_total.init_all_dense(); - metrics.proxy.redis_events_count.init_all_dense(); - metrics.proxy.retries_metric.init_all_dense(); - metrics.proxy.connection_failures_total.init_all_dense(); - - SELF.set(metrics) - .ok() - .expect("proxy metrics must not be installed more than once"); - } - + #[track_caller] pub fn get() -> &'static Self { - #[cfg(test)] - return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0)))); + static SELF: OnceLock = OnceLock::new(); - #[cfg(not(test))] - SELF.get() - .expect("proxy metrics must be installed by the main() function") + SELF.get_or_init(|| { + let mut metrics = Metrics::new(); + + metrics.proxy.errors_total.init_all_dense(); + metrics.proxy.redis_errors_total.init_all_dense(); + metrics.proxy.redis_events_count.init_all_dense(); + metrics.proxy.retries_metric.init_all_dense(); + metrics.proxy.connection_failures_total.init_all_dense(); + + metrics + }) } } #[derive(MetricGroup)] -#[metric(new(thread_pool: Arc))] +#[metric(new())] pub struct ProxyMetrics { #[metric(flatten)] pub db_connections: CounterPairVec, @@ -127,6 +128,9 @@ pub struct ProxyMetrics { /// Number of TLS handshake failures pub tls_handshake_failures: Counter, + /// Number of SHA 256 rounds executed. 
+ pub sha_rounds: Counter, + /// HLL approximate cardinality of endpoints that are connecting pub connecting_endpoints: HyperLogLogVec, 32>, @@ -144,8 +148,25 @@ pub struct ProxyMetrics { pub connect_compute_lock: ApiLockMetrics, #[metric(namespace = "scram_pool")] - #[metric(init = thread_pool)] - pub scram_pool: Arc, + pub scram_pool: OnceLockWrapper>, +} + +/// A Wrapper over [`OnceLock`] to implement [`MetricGroup`]. +pub struct OnceLockWrapper(pub OnceLock); + +impl Default for OnceLockWrapper { + fn default() -> Self { + Self(OnceLock::new()) + } +} + +impl> MetricGroup for OnceLockWrapper { + fn collect_group_into(&self, enc: &mut Enc) -> Result<(), Enc::Err> { + if let Some(inner) = self.0.get() { + inner.collect_group_into(enc)?; + } + Ok(()) + } } #[derive(MetricGroup)] @@ -215,13 +236,6 @@ pub enum Bool { False, } -#[derive(FixedCardinalityLabel, Copy, Clone)] -#[label(singleton = "outcome")] -pub enum CacheOutcome { - Hit, - Miss, -} - #[derive(LabelGroup)] #[label(set = ConsoleRequestSet)] pub struct ConsoleRequest<'a> { @@ -553,14 +567,6 @@ impl From for Bool { } } -#[derive(LabelGroup)] -#[label(set = InvalidEndpointsSet)] -pub struct InvalidEndpointsGroup { - pub protocol: Protocol, - pub rejected: Bool, - pub outcome: ConnectOutcome, -} - #[derive(LabelGroup)] #[label(set = RetriesMetricSet)] pub struct RetriesMetricGroup { @@ -660,3 +666,100 @@ pub struct ThreadPoolMetrics { #[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))] pub worker_task_skips_total: CounterVec, } + +#[derive(MetricGroup, Default)] +pub struct ServiceMetrics { + pub info: InfoMetric, +} + +#[derive(Default)] +pub struct ServiceInfo { + pub state: ServiceState, +} + +impl ServiceInfo { + pub const fn running() -> Self { + ServiceInfo { + state: ServiceState::Running, + } + } + + pub const fn terminating() -> Self { + ServiceInfo { + state: ServiceState::Terminating, + } + } +} + +impl LabelGroup for ServiceInfo { + fn visit_values(&self, v: &mut impl LabelGroupVisitor) { + const STATE: &LabelName = LabelName::from_str("state"); + v.write_value(STATE, &self.state); + } +} + +#[derive(FixedCardinalityLabel, Clone, Copy, Debug, Default)] +#[label(singleton = "state")] +pub enum ServiceState { + #[default] + Init, + Running, + Terminating, +} + +#[derive(MetricGroup)] +#[metric(new())] +pub struct CacheMetrics { + /// The capacity of the cache + pub capacity: GaugeVec>, + /// The total number of entries inserted into the cache + pub inserted_total: CounterVec>, + /// The total number of entries removed from the cache + pub evicted_total: CounterVec, + /// The total number of cache requests + pub request_total: CounterVec, +} + +impl Default for CacheMetrics { + fn default() -> Self { + Self::new() + } +} + +#[derive(FixedCardinalityLabel, Clone, Copy, Debug)] +#[label(singleton = "cache")] +pub enum CacheKind { + NodeInfo, + ProjectInfoEndpoints, + ProjectInfoRoles, + Schema, + Pbkdf2, +} + +#[derive(FixedCardinalityLabel, Clone, Copy, Debug)] +pub enum CacheRemovalCause { + Expired, + Explicit, + Replaced, + Size, +} + +#[derive(LabelGroup)] +#[label(set = CacheEvictionSet)] +pub struct CacheEviction { + pub cache: CacheKind, + pub cause: CacheRemovalCause, +} + +#[derive(FixedCardinalityLabel, Copy, Clone)] +pub enum CacheOutcome { + Hit, + Miss, +} + +#[derive(LabelGroup)] +#[label(set = CacheOutcomeSet)] +pub struct CacheOutcomeGroup { + pub cache: CacheKind, + pub outcome: CacheOutcome, +} diff --git a/proxy/src/proxy/connect_auth.rs b/proxy/src/proxy/connect_auth.rs index 
5a1d1ae314..77578c71b1 100644 --- a/proxy/src/proxy/connect_auth.rs +++ b/proxy/src/proxy/connect_auth.rs @@ -2,7 +2,7 @@ use thiserror::Error; use crate::auth::Backend; use crate::auth::backend::ComputeUserInfo; -use crate::cache::Cache; +use crate::cache::common::Cache; use crate::compute::{AuthInfo, ComputeConnection, ConnectionError, PostgresError}; use crate::config::ProxyConfig; use crate::context::RequestContext; diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index 1a4e5f77d2..515f925236 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -1,11 +1,12 @@ use tokio::time; use tracing::{debug, info, warn}; +use crate::cache::node_info::CachedNodeInfo; use crate::compute::{self, COULD_NOT_CONNECT, ComputeConnection}; use crate::config::{ComputeConfig, ProxyConfig, RetryConfig}; use crate::context::RequestContext; +use crate::control_plane::NodeInfo; use crate::control_plane::locks::ApiLocks; -use crate::control_plane::{self, NodeInfo}; use crate::metrics::{ ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType, }; @@ -17,7 +18,7 @@ use crate::types::Host; /// (e.g. the compute node's address might've changed at the wrong time). /// Invalidate the cache entry (if any) to prevent subsequent errors. #[tracing::instrument(skip_all)] -pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> NodeInfo { +pub(crate) fn invalidate_cache(node_info: CachedNodeInfo) -> NodeInfo { let is_cached = node_info.cached(); if is_cached { warn!("invalidating stalled compute node info cache entry"); @@ -37,7 +38,7 @@ pub(crate) trait ConnectMechanism { async fn connect_once( &self, ctx: &RequestContext, - node_info: &control_plane::CachedNodeInfo, + node_info: &CachedNodeInfo, config: &ComputeConfig, ) -> Result; } @@ -66,7 +67,7 @@ impl ConnectMechanism for TcpMechanism<'_> { async fn connect_once( &self, ctx: &RequestContext, - node_info: &control_plane::CachedNodeInfo, + node_info: &CachedNodeInfo, config: &ComputeConfig, ) -> Result { let permit = self.locks.get_permit(&node_info.conn_info.host).await?; diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index d1084628b1..7e0710749e 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -15,15 +15,17 @@ use rstest::rstest; use rustls::crypto::ring; use rustls::pki_types; use tokio::io::{AsyncRead, AsyncWrite, DuplexStream}; +use tokio::time::Instant; use tracing_test::traced_test; use super::retry::CouldRetry; use crate::auth::backend::{ComputeUserInfo, MaybeOwned}; -use crate::config::{ComputeConfig, RetryConfig, TlsConfig}; +use crate::cache::node_info::{CachedNodeInfo, NodeInfoCache}; +use crate::config::{CacheOptions, ComputeConfig, RetryConfig, TlsConfig}; use crate::context::RequestContext; use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient}; use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status}; -use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache}; +use crate::control_plane::{self, NodeInfo}; use crate::error::ErrorKind; use crate::pglb::ERR_INSECURE_CONNECTION; use crate::pglb::handshake::{HandshakeData, handshake}; @@ -417,12 +419,11 @@ impl TestConnectMechanism { Self { counter: Arc::new(std::sync::Mutex::new(0)), sequence, - cache: Box::leak(Box::new(NodeInfoCache::new( - "test", - 1, - Duration::from_secs(100), - false, - ))), + cache: Box::leak(Box::new(NodeInfoCache::new(CacheOptions { + 
size: Some(1), + absolute_ttl: Some(Duration::from_secs(100)), + idle_ttl: None, + }))), } } } @@ -436,7 +437,7 @@ impl ConnectMechanism for TestConnectMechanism { async fn connect_once( &self, _ctx: &RequestContext, - _node_info: &control_plane::CachedNodeInfo, + _node_info: &CachedNodeInfo, _config: &ComputeConfig, ) -> Result { let mut counter = self.counter.lock().unwrap(); @@ -501,7 +502,7 @@ impl TestControlPlaneClient for TestConnectMechanism { details: Details { error_info: None, retry_info: Some(control_plane::messages::RetryInfo { - retry_delay_ms: 1, + retry_at: Instant::now() + Duration::from_millis(1), }), user_facing_message: None, }, @@ -546,8 +547,11 @@ fn helper_create_uncached_node_info() -> NodeInfo { fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo { let node = helper_create_uncached_node_info(); - let (_, node2) = cache.insert_unit("key".into(), Ok(node.clone())); - node2.map(|()| node) + cache.insert("key".into(), Ok(node.clone())); + CachedNodeInfo { + token: Some((cache, "key".into())), + value: node, + } } fn helper_create_connect_info( diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs index b8edf9fd5c..66f2df2af4 100644 --- a/proxy/src/proxy/wake_compute.rs +++ b/proxy/src/proxy/wake_compute.rs @@ -1,9 +1,9 @@ use async_trait::async_trait; use tracing::{error, info}; +use crate::cache::node_info::CachedNodeInfo; use crate::config::RetryConfig; use crate::context::RequestContext; -use crate::control_plane::CachedNodeInfo; use crate::control_plane::errors::{ControlPlaneError, WakeComputeError}; use crate::error::ReportableError; use crate::metrics::{ diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 88d5550fff..a3a378c7e2 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -131,11 +131,11 @@ where Ok(()) } -struct MessageHandler { +struct MessageHandler { cache: Arc, } -impl Clone for MessageHandler { +impl Clone for MessageHandler { fn clone(&self) -> Self { Self { cache: self.cache.clone(), @@ -143,8 +143,8 @@ impl Clone for MessageHandler { } } -impl MessageHandler { - pub(crate) fn new(cache: Arc) -> Self { +impl MessageHandler { + pub(crate) fn new(cache: Arc) -> Self { Self { cache } } @@ -224,7 +224,7 @@ impl MessageHandler { } } -fn invalidate_cache(cache: Arc, msg: Notification) { +fn invalidate_cache(cache: Arc, msg: Notification) { match msg { Notification::EndpointSettingsUpdate(ids) => ids .iter() @@ -247,8 +247,8 @@ fn invalidate_cache(cache: Arc, msg: Notification) { } } -async fn handle_messages( - handler: MessageHandler, +async fn handle_messages( + handler: MessageHandler, redis: ConnectionWithCredentialsProvider, cancellation_token: CancellationToken, ) -> anyhow::Result<()> { @@ -284,13 +284,10 @@ async fn handle_messages( /// Handle console's invalidation messages. #[tracing::instrument(name = "redis_notifications", skip_all)] -pub async fn task_main( +pub async fn task_main( redis: ConnectionWithCredentialsProvider, - cache: Arc, -) -> anyhow::Result -where - C: ProjectInfoCache + Send + Sync + 'static, -{ + cache: Arc, +) -> anyhow::Result { let handler = MessageHandler::new(cache); // 6h - 1m. // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost. 
diff --git a/proxy/src/scram/cache.rs b/proxy/src/scram/cache.rs new file mode 100644 index 0000000000..9ade7af458 --- /dev/null +++ b/proxy/src/scram/cache.rs @@ -0,0 +1,84 @@ +use tokio::time::Instant; +use zeroize::Zeroize as _; + +use super::pbkdf2; +use crate::cache::Cached; +use crate::cache::common::{Cache, count_cache_insert, count_cache_outcome, eviction_listener}; +use crate::intern::{EndpointIdInt, RoleNameInt}; +use crate::metrics::{CacheKind, Metrics}; + +pub(crate) struct Pbkdf2Cache(moka::sync::Cache<(EndpointIdInt, RoleNameInt), Pbkdf2CacheEntry>); +pub(crate) type CachedPbkdf2<'a> = Cached<&'a Pbkdf2Cache>; + +impl Cache for Pbkdf2Cache { + type Key = (EndpointIdInt, RoleNameInt); + type Value = Pbkdf2CacheEntry; + + fn invalidate(&self, info: &(EndpointIdInt, RoleNameInt)) { + self.0.invalidate(info); + } +} + +/// To speed up password hashing for more active customers, we store the tail results of the +/// PBKDF2 algorithm. If the output of PBKDF2 is U1 ^ U2 ^ ⋯ ^ Uc, then we store +/// suffix = U17 ^ U18 ^ ⋯ ^ Uc. We only need to calculate U1 ^ U2 ^ ⋯ ^ U15 ^ U16 +/// to determine the final result. +/// +/// The suffix alone isn't enough to crack the password. The stored_key is still required. +/// While both are cached in memory, keeping them in different locations makes it much +/// harder to exploit, even if any such memory exploit exists in proxy. +#[derive(Clone)] +pub struct Pbkdf2CacheEntry { + /// corresponds to [`super::ServerSecret::cached_at`] + pub(super) cached_from: Instant, + pub(super) suffix: pbkdf2::Block, +} + +impl Drop for Pbkdf2CacheEntry { + fn drop(&mut self) { + self.suffix.zeroize(); + } +} + +impl Pbkdf2Cache { + pub fn new() -> Self { + const SIZE: u64 = 100; + const TTL: std::time::Duration = std::time::Duration::from_secs(60); + + let builder = moka::sync::Cache::builder() + .name("pbkdf2") + .max_capacity(SIZE) + // We use time_to_live so we don't refresh the lifetime for an invalid password attempt.
+ .time_to_live(TTL); + + Metrics::get() + .cache + .capacity + .set(CacheKind::Pbkdf2, SIZE as i64); + + let builder = + builder.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::Pbkdf2, cause)); + + Self(builder.build()) + } + + pub fn insert(&self, endpoint: EndpointIdInt, role: RoleNameInt, value: Pbkdf2CacheEntry) { + count_cache_insert(CacheKind::Pbkdf2); + self.0.insert((endpoint, role), value); + } + + fn get(&self, endpoint: EndpointIdInt, role: RoleNameInt) -> Option { + count_cache_outcome(CacheKind::Pbkdf2, self.0.get(&(endpoint, role))) + } + + pub fn get_entry( + &self, + endpoint: EndpointIdInt, + role: RoleNameInt, + ) -> Option> { + self.get(endpoint, role).map(|value| Cached { + token: Some((self, (endpoint, role))), + value, + }) + } +} diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs index a0918fca9f..3f4b0d534b 100644 --- a/proxy/src/scram/exchange.rs +++ b/proxy/src/scram/exchange.rs @@ -4,10 +4,8 @@ use std::convert::Infallible; use base64::Engine as _; use base64::prelude::BASE64_STANDARD; -use hmac::{Hmac, Mac}; -use sha2::Sha256; +use tracing::{debug, trace}; -use super::ScramKey; use super::messages::{ ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN, }; @@ -15,8 +13,10 @@ use super::pbkdf2::Pbkdf2; use super::secret::ServerSecret; use super::signature::SignatureBuilder; use super::threadpool::ThreadPool; -use crate::intern::EndpointIdInt; +use super::{ScramKey, pbkdf2}; +use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::sasl::{self, ChannelBinding, Error as SaslError}; +use crate::scram::cache::Pbkdf2CacheEntry; /// The only channel binding mode we currently support. #[derive(Debug)] @@ -77,46 +77,113 @@ impl<'a> Exchange<'a> { } } -// copied from async fn derive_client_key( pool: &ThreadPool, endpoint: EndpointIdInt, password: &[u8], salt: &[u8], iterations: u32, -) -> ScramKey { - let salted_password = pool - .spawn_job(endpoint, Pbkdf2::start(password, salt, iterations)) - .await; - - let make_key = |name| { - let key = Hmac::::new_from_slice(&salted_password) - .expect("HMAC is able to accept all key sizes") - .chain_update(name) - .finalize(); - - <[u8; 32]>::from(key.into_bytes()) - }; - - make_key(b"Client Key").into() +) -> pbkdf2::Block { + pool.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations)) + .await } +/// For cleartext flow, we need to derive the client key to +/// 1. authenticate the client. +/// 2. authenticate with compute. pub(crate) async fn exchange( pool: &ThreadPool, endpoint: EndpointIdInt, + role: RoleNameInt, + secret: &ServerSecret, + password: &[u8], +) -> sasl::Result> { + if secret.iterations > CACHED_ROUNDS { + exchange_with_cache(pool, endpoint, role, secret, password).await + } else { + let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?; + let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await; + Ok(validate_pbkdf2(secret, &hash)) + } +} + +/// Compute the client key using a cache. We cache the suffix of the pbkdf2 result only, +/// which is not enough by itself to perform an offline brute force. 
+async fn exchange_with_cache( + pool: &ThreadPool, + endpoint: EndpointIdInt, + role: RoleNameInt, secret: &ServerSecret, password: &[u8], ) -> sasl::Result> { let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?; - let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await; + debug_assert!( + secret.iterations > CACHED_ROUNDS, + "we should not cache password data if there isn't enough rounds needed" + ); + + // compute the prefix of the pbkdf2 output. + let prefix = derive_client_key(pool, endpoint, password, &salt, CACHED_ROUNDS).await; + + if let Some(entry) = pool.cache.get_entry(endpoint, role) { + // hot path: let's check the threadpool cache + if secret.cached_at == entry.cached_from { + // cache is valid. compute the full hash by adding the prefix to the suffix. + let mut hash = prefix; + pbkdf2::xor_assign(&mut hash, &entry.suffix); + let outcome = validate_pbkdf2(secret, &hash); + + if matches!(outcome, sasl::Outcome::Success(_)) { + trace!("password validated from cache"); + } + + return Ok(outcome); + } + + // cached key is no longer valid. + debug!("invalidating cached password"); + entry.invalidate(); + } + + // slow path: full password hash. + let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await; + let outcome = validate_pbkdf2(secret, &hash); + + let client_key = match outcome { + sasl::Outcome::Success(client_key) => client_key, + sasl::Outcome::Failure(_) => return Ok(outcome), + }; + + trace!("storing cached password"); + + // time to cache, compute the suffix by subtracting the prefix from the hash. + let mut suffix = hash; + pbkdf2::xor_assign(&mut suffix, &prefix); + + pool.cache.insert( + endpoint, + role, + Pbkdf2CacheEntry { + cached_from: secret.cached_at, + suffix, + }, + ); + + Ok(sasl::Outcome::Success(client_key)) +} + +fn validate_pbkdf2(secret: &ServerSecret, hash: &pbkdf2::Block) -> sasl::Outcome { + let client_key = super::ScramKey::client_key(&(*hash).into()); if secret.is_password_invalid(&client_key).into() { - Ok(sasl::Outcome::Failure("password doesn't match")) + sasl::Outcome::Failure("password doesn't match") } else { - Ok(sasl::Outcome::Success(client_key)) + sasl::Outcome::Success(client_key) } } +const CACHED_ROUNDS: u32 = 16; + impl SaslInitial { fn transition( &self, diff --git a/proxy/src/scram/key.rs b/proxy/src/scram/key.rs index fe55ff493b..7dc52fd409 100644 --- a/proxy/src/scram/key.rs +++ b/proxy/src/scram/key.rs @@ -1,6 +1,12 @@ //! Tools for client/server/stored key management. +use hmac::Mac as _; +use sha2::Digest as _; use subtle::ConstantTimeEq; +use zeroize::Zeroize as _; + +use crate::metrics::Metrics; +use crate::scram::pbkdf2::Prf; /// Faithfully taken from PostgreSQL. pub(crate) const SCRAM_KEY_LEN: usize = 32; @@ -14,6 +20,12 @@ pub(crate) struct ScramKey { bytes: [u8; SCRAM_KEY_LEN], } +impl Drop for ScramKey { + fn drop(&mut self) { + self.bytes.zeroize(); + } +} + impl PartialEq for ScramKey { fn eq(&self, other: &Self) -> bool { self.ct_eq(other).into() @@ -28,12 +40,26 @@ impl ConstantTimeEq for ScramKey { impl ScramKey { pub(crate) fn sha256(&self) -> Self { - super::sha256([self.as_ref()]).into() + Metrics::get().proxy.sha_rounds.inc_by(1); + Self { + bytes: sha2::Sha256::digest(self.as_bytes()).into(), + } } pub(crate) fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] { self.bytes } + + pub(crate) fn client_key(b: &[u8; 32]) -> Self { + // Prf::new_from_slice will run 2 sha256 rounds. + // Update + Finalize run 2 sha256 rounds. 
+ Metrics::get().proxy.sha_rounds.inc_by(4); + + let mut prf = Prf::new_from_slice(b).expect("HMAC is able to accept all key sizes"); + prf.update(b"Client Key"); + let client_key: [u8; 32] = prf.finalize().into_bytes().into(); + client_key.into() + } } impl From<[u8; SCRAM_KEY_LEN]> for ScramKey { diff --git a/proxy/src/scram/mod.rs b/proxy/src/scram/mod.rs index 5f627e062c..04722d920b 100644 --- a/proxy/src/scram/mod.rs +++ b/proxy/src/scram/mod.rs @@ -6,6 +6,7 @@ //! * //! * +mod cache; mod countmin; mod exchange; mod key; @@ -18,10 +19,8 @@ pub mod threadpool; use base64::Engine as _; use base64::prelude::BASE64_STANDARD; pub(crate) use exchange::{Exchange, exchange}; -use hmac::{Hmac, Mac}; pub(crate) use key::ScramKey; pub(crate) use secret::ServerSecret; -use sha2::{Digest, Sha256}; const SCRAM_SHA_256: &str = "SCRAM-SHA-256"; const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS"; @@ -42,29 +41,13 @@ fn base64_decode_array(input: impl AsRef<[u8]>) -> Option<[u8; N Some(bytes) } -/// This function essentially is `Hmac(sha256, key, input)`. -/// Further reading: . -fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator) -> [u8; 32] { - let mut mac = Hmac::::new_from_slice(key).expect("bad key size"); - parts.into_iter().for_each(|s| mac.update(s)); - - mac.finalize().into_bytes().into() -} - -fn sha256<'a>(parts: impl IntoIterator) -> [u8; 32] { - let mut hasher = Sha256::new(); - parts.into_iter().for_each(|s| hasher.update(s)); - - hasher.finalize().into() -} - #[cfg(test)] mod tests { use super::threadpool::ThreadPool; use super::{Exchange, ServerSecret}; - use crate::intern::EndpointIdInt; + use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::sasl::{Mechanism, Step}; - use crate::types::EndpointId; + use crate::types::{EndpointId, RoleName}; #[test] fn snapshot() { @@ -114,23 +97,34 @@ mod tests { ); } - async fn run_round_trip_test(server_password: &str, client_password: &str) { - let pool = ThreadPool::new(1); - + async fn check( + pool: &ThreadPool, + scram_secret: &ServerSecret, + password: &[u8], + ) -> Result<(), &'static str> { let ep = EndpointId::from("foo"); let ep = EndpointIdInt::from(ep); + let role = RoleName::from("user"); + let role = RoleNameInt::from(&role); - let scram_secret = ServerSecret::build(server_password).await.unwrap(); - let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes()) + let outcome = super::exchange(pool, ep, role, scram_secret, password) .await .unwrap(); match outcome { - crate::sasl::Outcome::Success(_) => {} - crate::sasl::Outcome::Failure(r) => panic!("{r}"), + crate::sasl::Outcome::Success(_) => Ok(()), + crate::sasl::Outcome::Failure(r) => Err(r), } } + async fn run_round_trip_test(server_password: &str, client_password: &str) { + let pool = ThreadPool::new(1); + let scram_secret = ServerSecret::build(server_password).await.unwrap(); + check(&pool, &scram_secret, client_password.as_bytes()) + .await + .unwrap(); + } + #[tokio::test] async fn round_trip() { run_round_trip_test("pencil", "pencil").await; @@ -141,4 +135,27 @@ mod tests { async fn failure() { run_round_trip_test("pencil", "eraser").await; } + + #[tokio::test] + #[tracing_test::traced_test] + async fn password_cache() { + let pool = ThreadPool::new(1); + let scram_secret = ServerSecret::build("password").await.unwrap(); + + // wrong passwords are not added to cache + check(&pool, &scram_secret, b"wrong").await.unwrap_err(); + assert!(!logs_contain("storing cached password")); + + // correct passwords get cached + check(&pool, 
&scram_secret, b"password").await.unwrap(); + assert!(logs_contain("storing cached password")); + + // wrong passwords do not match the cache + check(&pool, &scram_secret, b"wrong").await.unwrap_err(); + assert!(!logs_contain("password validated from cache")); + + // correct passwords match the cache + check(&pool, &scram_secret, b"password").await.unwrap(); + assert!(logs_contain("password validated from cache")); + } } diff --git a/proxy/src/scram/pbkdf2.rs b/proxy/src/scram/pbkdf2.rs index 7f48e00c41..1300310de2 100644 --- a/proxy/src/scram/pbkdf2.rs +++ b/proxy/src/scram/pbkdf2.rs @@ -1,25 +1,50 @@ +//! For postgres password authentication, we need to perform a PBKDF2 using +//! PRF=HMAC-SHA2-256, producing only 1 block (32 bytes) of output key. + +use hmac::Mac as _; use hmac::digest::consts::U32; use hmac::digest::generic_array::GenericArray; -use hmac::{Hmac, Mac}; -use sha2::Sha256; +use zeroize::Zeroize as _; + +use crate::metrics::Metrics; + +/// The pseudo-random function used during PBKDF2 and the SCRAM-SHA-256 handshake. +pub type Prf = hmac::Hmac; +pub(crate) type Block = GenericArray; pub(crate) struct Pbkdf2 { - hmac: Hmac, - prev: GenericArray, - hi: GenericArray, + hmac: Prf, + /// U{r-1} for whatever iteration r we are currently on. + prev: Block, + /// the output of `fold(xor, U{1}..U{r})` for whatever iteration r we are currently on. + hi: Block, + /// number of iterations left iterations: u32, } +impl Drop for Pbkdf2 { + fn drop(&mut self) { + self.prev.zeroize(); + self.hi.zeroize(); + } +} + // inspired from impl Pbkdf2 { - pub(crate) fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self { + pub(crate) fn start(pw: &[u8], salt: &[u8], iterations: u32) -> Self { // key the HMAC and derive the first block in-place - let mut hmac = - Hmac::::new_from_slice(str).expect("HMAC is able to accept all key sizes"); + let mut hmac = Prf::new_from_slice(pw).expect("HMAC is able to accept all key sizes"); + + // U1 = PRF(Password, Salt + INT_32_BE(i)) + // i = 1 since we only need 1 block of output. hmac.update(salt); hmac.update(&1u32.to_be_bytes()); let init_block = hmac.finalize_reset().into_bytes(); + // Prf::new_from_slice will run 2 sha256 rounds. + // Our update + finalize run 2 sha256 rounds for each pbkdf2 round. + Metrics::get().proxy.sha_rounds.inc_by(4); + Self { hmac, // one iteration spent above @@ -33,7 +58,11 @@ impl Pbkdf2 { (self.iterations).clamp(0, 4096) } - pub(crate) fn turn(&mut self) -> std::task::Poll<[u8; 32]> { + /// For "fairness", we implement PBKDF2 with cooperative yielding, which is why we use this `turn` + /// function that only executes a fixed number of iterations before continuing. + /// + /// Task must be rescheduled if this returns [`std::task::Poll::Pending`]. + pub(crate) fn turn(&mut self) -> std::task::Poll { let Self { hmac, prev, @@ -44,25 +73,37 @@ impl Pbkdf2 { // only do up to 4096 iterations per turn for fairness let n = (*iterations).clamp(0, 4096); for _ in 0..n { - hmac.update(prev); - let block = hmac.finalize_reset().into_bytes(); - - for (hi_byte, &b) in hi.iter_mut().zip(block.iter()) { - *hi_byte ^= b; - } - - *prev = block; + let next = single_round(hmac, prev); + xor_assign(hi, &next); + *prev = next; } + // Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
+ Metrics::get().proxy.sha_rounds.inc_by(2 * n as u64); + *iterations -= n; if *iterations == 0 { - std::task::Poll::Ready((*hi).into()) + std::task::Poll::Ready(*hi) } else { std::task::Poll::Pending } } } +#[inline(always)] +pub fn xor_assign(x: &mut Block, y: &Block) { + for (x, &y) in std::iter::zip(x, y) { + *x ^= y; + } +} + +#[inline(always)] +fn single_round(prf: &mut Prf, ui: &Block) -> Block { + // Ui = PRF(Password, Ui-1) + prf.update(ui); + prf.finalize_reset().into_bytes() +} + #[cfg(test)] mod tests { use pbkdf2::pbkdf2_hmac_array; @@ -76,11 +117,11 @@ mod tests { let pass = b"Ne0n_!5_50_C007"; let mut job = Pbkdf2::start(pass, salt, 60000); - let hash = loop { + let hash: [u8; 32] = loop { let std::task::Poll::Ready(hash) = job.turn() else { continue; }; - break hash; + break hash.into(); }; let expected = pbkdf2_hmac_array::(pass, salt, 60000); diff --git a/proxy/src/scram/secret.rs b/proxy/src/scram/secret.rs index 0e070c2f27..a3a64f271c 100644 --- a/proxy/src/scram/secret.rs +++ b/proxy/src/scram/secret.rs @@ -3,6 +3,7 @@ use base64::Engine as _; use base64::prelude::BASE64_STANDARD; use subtle::{Choice, ConstantTimeEq}; +use tokio::time::Instant; use super::base64_decode_array; use super::key::ScramKey; @@ -11,6 +12,9 @@ use super::key::ScramKey; /// and is used throughout the authentication process. #[derive(Clone, Eq, PartialEq, Debug)] pub(crate) struct ServerSecret { + /// When this secret was cached. + pub(crate) cached_at: Instant, + /// Number of iterations for `PBKDF2` function. pub(crate) iterations: u32, /// Salt used to hash user's password. @@ -34,6 +38,7 @@ impl ServerSecret { params.split_once(':').zip(keys.split_once(':'))?; let secret = ServerSecret { + cached_at: Instant::now(), iterations: iterations.parse().ok()?, salt_base64: salt.into(), stored_key: base64_decode_array(stored_key)?.into(), @@ -54,6 +59,7 @@ impl ServerSecret { /// See `auth-scram.c : mock_scram_secret` for details. pub(crate) fn mock(nonce: [u8; 32]) -> Self { Self { + cached_at: Instant::now(), // this doesn't reveal much information as we're going to use // iteration count 1 for our generated passwords going forward. // PG16 users can set iteration count=1 already today. diff --git a/proxy/src/scram/signature.rs b/proxy/src/scram/signature.rs index a5b1c3e9f4..8e074272b6 100644 --- a/proxy/src/scram/signature.rs +++ b/proxy/src/scram/signature.rs @@ -1,6 +1,10 @@ //! Tools for client/server signature management. +use hmac::Mac as _; + use super::key::{SCRAM_KEY_LEN, ScramKey}; +use crate::metrics::Metrics; +use crate::scram::pbkdf2::Prf; /// A collection of message parts needed to derive the client's signature. #[derive(Debug)] @@ -12,15 +16,18 @@ pub(crate) struct SignatureBuilder<'a> { impl SignatureBuilder<'_> { pub(crate) fn build(&self, key: &ScramKey) -> Signature { - let parts = [ - self.client_first_message_bare.as_bytes(), - b",", - self.server_first_message.as_bytes(), - b",", - self.client_final_message_without_proof.as_bytes(), - ]; + // don't know exactly. 
this is a rough approx + Metrics::get().proxy.sha_rounds.inc_by(8); - super::hmac_sha256(key.as_ref(), parts).into() + let mut mac = Prf::new_from_slice(key.as_ref()).expect("HMAC accepts all key sizes"); + mac.update(self.client_first_message_bare.as_bytes()); + mac.update(b","); + mac.update(self.server_first_message.as_bytes()); + mac.update(b","); + mac.update(self.client_final_message_without_proof.as_bytes()); + Signature { + bytes: mac.finalize().into_bytes().into(), + } } } diff --git a/proxy/src/scram/threadpool.rs b/proxy/src/scram/threadpool.rs index ea2e29ede9..20a1df2b53 100644 --- a/proxy/src/scram/threadpool.rs +++ b/proxy/src/scram/threadpool.rs @@ -15,6 +15,8 @@ use futures::FutureExt; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; +use super::cache::Pbkdf2Cache; +use super::pbkdf2; use super::pbkdf2::Pbkdf2; use crate::intern::EndpointIdInt; use crate::metrics::{ThreadPoolMetrics, ThreadPoolWorkerId}; @@ -23,6 +25,10 @@ use crate::scram::countmin::CountMinSketch; pub struct ThreadPool { runtime: Option, pub metrics: Arc, + + // we hash a lot of passwords. + // we keep a cache of partial hashes for faster validation. + pub(super) cache: Pbkdf2Cache, } /// How often to reset the sketch values @@ -68,6 +74,7 @@ impl ThreadPool { Self { runtime: Some(runtime), metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)), + cache: Pbkdf2Cache::new(), } }) } @@ -130,7 +137,7 @@ struct JobSpec { } impl Future for JobSpec { - type Output = [u8; 32]; + type Output = pbkdf2::Block; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { STATE.with_borrow_mut(|state| { @@ -166,10 +173,10 @@ impl Future for JobSpec { } } -pub(crate) struct JobHandle(tokio::task::JoinHandle<[u8; 32]>); +pub(crate) struct JobHandle(tokio::task::JoinHandle); impl Future for JobHandle { - type Output = [u8; 32]; + type Output = pbkdf2::Block; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.0.poll_unpin(cx) { @@ -203,10 +210,10 @@ mod tests { .spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096)) .await; - let expected = [ + let expected = &[ 10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242, 178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140, ]; - assert_eq!(actual, expected); + assert_eq!(actual.as_slice(), expected); } } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 0987b6927f..511bdc4e42 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -4,6 +4,7 @@ use std::time::Duration; use ed25519_dalek::SigningKey; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use jose_jwk::jose_b64; +use postgres_client::error::SqlState; use postgres_client::maybe_tls_stream::MaybeTlsStream; use rand_core::OsRng; use tracing::field::display; @@ -26,7 +27,7 @@ use crate::context::RequestContext; use crate::control_plane::client::ApiLockError; use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; use crate::error::{ErrorKind, ReportableError, UserFacingError}; -use crate::intern::EndpointIdInt; +use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::pqproto::StartupMessageParams; use crate::proxy::{connect_auth, connect_compute}; use crate::rate_limiter::EndpointRateLimiter; @@ -76,9 +77,11 @@ impl PoolingBackend { }; let ep = EndpointIdInt::from(&user_info.endpoint); + let role = RoleNameInt::from(&user_info.user); let auth_outcome = crate::auth::validate_password_and_exchange( - &self.config.authentication_config.thread_pool, + 
&self.config.authentication_config.scram_thread_pool, ep, + role, password, secret, ) @@ -457,15 +460,14 @@ impl ReportableError for HttpConnError { match self { HttpConnError::ConnectError(_) => ErrorKind::Compute, HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute, - HttpConnError::PostgresConnectionError(p) => { - if p.as_db_error().is_some() { - // postgres rejected the connection - ErrorKind::Postgres - } else { - // couldn't even reach postgres - ErrorKind::Compute - } - } + HttpConnError::PostgresConnectionError(p) => match p.as_db_error() { + // user provided a wrong database name + Some(err) if err.code() == &SqlState::INVALID_CATALOG_NAME => ErrorKind::User, + // postgres rejected the connection + Some(_) => ErrorKind::Postgres, + // couldn't even reach postgres + None => ErrorKind::Compute, + }, HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute, HttpConnError::ComputeCtl(_) => ErrorKind::Service, HttpConnError::JwtPayloadError(_) => ErrorKind::User, diff --git a/proxy/src/serverless/rest.rs b/proxy/src/serverless/rest.rs index c9b5e99747..0c3d2c958d 100644 --- a/proxy/src/serverless/rest.rs +++ b/proxy/src/serverless/rest.rs @@ -1,5 +1,6 @@ use std::borrow::Cow; use std::collections::HashMap; +use std::convert::Infallible; use std::sync::Arc; use bytes::Bytes; @@ -12,6 +13,7 @@ use hyper::body::Incoming; use hyper::http::{HeaderName, HeaderValue}; use hyper::{Request, Response, StatusCode}; use indexmap::IndexMap; +use moka::sync::Cache; use ouroboros::self_referencing; use serde::de::DeserializeOwned; use serde::{Deserialize, Deserializer}; @@ -53,12 +55,12 @@ use super::http_util::{ }; use super::json::JsonConversionError; use crate::auth::backend::ComputeCredentialKeys; -use crate::cache::{Cached, TimedLru}; +use crate::cache::common::{count_cache_insert, count_cache_outcome, eviction_listener}; use crate::config::ProxyConfig; use crate::context::RequestContext; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::http::read_body_with_limit; -use crate::metrics::Metrics; +use crate::metrics::{CacheKind, Metrics}; use crate::serverless::sql_over_http::HEADER_VALUE_TRUE; use crate::types::EndpointCacheKey; use crate::util::deserialize_json_string; @@ -138,8 +140,31 @@ pub struct ApiConfig { } // The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint -pub(crate) type DbSchemaCache = TimedLru>; +pub(crate) struct DbSchemaCache(Cache>); impl DbSchemaCache { + pub fn new(config: crate::config::CacheOptions) -> Self { + let builder = Cache::builder().name("schema"); + let builder = config.moka(builder); + + let metrics = &Metrics::get().cache; + if let Some(size) = config.size { + metrics.capacity.set(CacheKind::Schema, size as i64); + } + + let builder = + builder.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::Schema, cause)); + + Self(builder.build()) + } + + pub async fn maintain(&self) -> Result { + let mut ticker = tokio::time::interval(std::time::Duration::from_secs(60)); + loop { + ticker.tick().await; + self.0.run_pending_tasks(); + } + } + pub async fn get_cached_or_remote( &self, endpoint_id: &EndpointCacheKey, @@ -149,8 +174,9 @@ impl DbSchemaCache { ctx: &RequestContext, config: &'static ProxyConfig, ) -> Result, RestError> { - match self.get_with_created_at(endpoint_id) { - Some(Cached { value: (v, _), .. 
}) => Ok(v), + let cache_result = count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id)); + match cache_result { + Some(v) => Ok(v), None => { info!("db_schema cache miss for endpoint: {:?}", endpoint_id); let remote_value = self @@ -173,7 +199,8 @@ impl DbSchemaCache { db_extra_search_path: None, }; let value = Arc::new((api_config, schema_owned)); - self.insert(endpoint_id.clone(), value); + count_cache_insert(CacheKind::Schema); + self.0.insert(endpoint_id.clone(), value); return Err(e); } Err(e) => { @@ -181,7 +208,8 @@ impl DbSchemaCache { } }; let value = Arc::new((api_config, schema_owned)); - self.insert(endpoint_id.clone(), value.clone()); + count_cache_insert(CacheKind::Schema); + self.0.insert(endpoint_id.clone(), value.clone()); Ok(value) } } diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index 26f65379e7..c334e820d7 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -192,34 +192,29 @@ pub(crate) async fn handle( let line = get(db_error, |db| db.line().map(|l| l.to_string())); let routine = get(db_error, |db| db.routine()); - match &e { - SqlOverHttpError::Postgres(e) - if e.as_db_error().is_some() && error_kind == ErrorKind::User => - { - // this error contains too much info, and it's not an error we care about. - if tracing::enabled!(Level::DEBUG) { - tracing::debug!( - kind=error_kind.to_metric_label(), - error=%e, - msg=message, - "forwarding error to user" - ); - } else { - tracing::info!( - kind = error_kind.to_metric_label(), - error = "bad query", - "forwarding error to user" - ); - } - } - _ => { - tracing::info!( + if db_error.is_some() && error_kind == ErrorKind::User { + // this error contains too much info, and it's not an error we care about. + if tracing::enabled!(Level::DEBUG) { + debug!( kind=error_kind.to_metric_label(), error=%e, msg=message, "forwarding error to user" ); + } else { + info!( + kind = error_kind.to_metric_label(), + error = "bad query", + "forwarding error to user" + ); } + } else { + info!( + kind=error_kind.to_metric_label(), + error=%e, + msg=message, + "forwarding error to user" + ); } json_response( diff --git a/proxy/src/signals.rs b/proxy/src/signals.rs index 32b2344a1c..63ef7d1061 100644 --- a/proxy/src/signals.rs +++ b/proxy/src/signals.rs @@ -4,6 +4,8 @@ use anyhow::bail; use tokio_util::sync::CancellationToken; use tracing::{info, warn}; +use crate::metrics::{Metrics, ServiceInfo}; + /// Handle unix signals appropriately. pub async fn handle( token: CancellationToken, @@ -28,10 +30,12 @@ where // Shut down the whole application. 
_ = interrupt.recv() => { warn!("received SIGINT, exiting immediately"); + Metrics::get().service.info.set_label(ServiceInfo::terminating()); bail!("interrupted"); } _ = terminate.recv() => { warn!("received SIGTERM, shutting down once all existing connections have closed"); + Metrics::get().service.info.set_label(ServiceInfo::terminating()); token.cancel(); } } diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index d6a43df188..9447b9623b 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -102,7 +102,7 @@ pub struct ReportedError { } impl ReportedError { - pub fn new(e: (impl UserFacingError + Into)) -> Self { + pub fn new(e: impl UserFacingError + Into) -> Self { let error_kind = e.get_error_kind(); Self { source: e.into(), diff --git a/safekeeper/src/wal_backup.rs b/safekeeper/src/wal_backup.rs index 03c8f7e84a..191f8aacf1 100644 --- a/safekeeper/src/wal_backup.rs +++ b/safekeeper/src/wal_backup.rs @@ -12,7 +12,7 @@ use futures::stream::{self, FuturesOrdered}; use postgres_ffi::v14::xlog_utils::XLogSegNoOffsetToRecPtr; use postgres_ffi::{PG_TLI, XLogFileName, XLogSegNo}; use remote_storage::{ - DownloadOpts, GenericRemoteStorage, ListingMode, RemotePath, StorageMetadata, + DownloadError, DownloadOpts, GenericRemoteStorage, ListingMode, RemotePath, StorageMetadata, }; use safekeeper_api::models::PeerInfo; use tokio::fs::File; @@ -607,6 +607,9 @@ pub(crate) async fn copy_partial_segment( storage.copy_object(source, destination, &cancel).await } +const WAL_READ_WARN_THRESHOLD: u32 = 2; +const WAL_READ_MAX_RETRIES: u32 = 3; + pub async fn read_object( storage: &GenericRemoteStorage, file_path: &RemotePath, @@ -620,12 +623,23 @@ pub async fn read_object( byte_start: std::ops::Bound::Included(offset), ..Default::default() }; + + // This retry only covers connection errors: subsequent reads can still fail as this function returns + // a stream. + let download = backoff::retry( + || async { storage.download(file_path, &opts, &cancel).await }, + DownloadError::is_permanent, + WAL_READ_WARN_THRESHOLD, + WAL_READ_MAX_RETRIES, + "download WAL segment", + &cancel, + ) + .await + .ok_or_else(|| DownloadError::Cancelled) + .and_then(|x| x) + .with_context(|| { + format!("Failed to open WAL segment download stream for remote path {file_path:?}") + })?; - let download = storage - .download(file_path, &opts, &cancel) - .await - .with_context(|| { - format!("Failed to open WAL segment download stream for remote path {file_path:?}") - })?; let reader = tokio_util::io::StreamReader::new(download.download_stream); diff --git a/storage_controller/src/operation_utils.rs b/storage_controller/src/operation_utils.rs index af86010ab7..1060c92832 100644 --- a/storage_controller/src/operation_utils.rs +++ b/storage_controller/src/operation_utils.rs @@ -46,11 +46,31 @@ impl TenantShardDrain { &self, tenants: &BTreeMap, scheduler: &Scheduler, - ) -> Option { - let tenant_shard = tenants.get(&self.tenant_shard_id)?; + ) -> TenantShardDrainAction { + let Some(tenant_shard) = tenants.get(&self.tenant_shard_id) else { + return TenantShardDrainAction::Skip; + }; if *tenant_shard.intent.get_attached() != Some(self.drained_node) { - return None; + // If the intent attached node is not the drained node, check the observed state + // of the shard on the drained node. If it is Attached*, it means the shard is + // being migrated from the drained node. The drain loop needs to wait for the + // reconciliation to complete for a smooth draining.
+ + use pageserver_api::models::LocationConfigMode::*; + + let attach_mode = tenant_shard + .observed + .locations + .get(&self.drained_node) + .and_then(|observed| observed.conf.as_ref().map(|conf| conf.mode)); + + return match (attach_mode, tenant_shard.intent.get_attached()) { + (Some(AttachedSingle | AttachedMulti | AttachedStale), Some(intent_node_id)) => { + TenantShardDrainAction::Reconcile(*intent_node_id) + } + _ => TenantShardDrainAction::Skip, + }; } // Only tenants with a normal (Active) scheduling policy are proactively moved @@ -63,19 +83,19 @@ impl TenantShardDrain { } ShardSchedulingPolicy::Pause | ShardSchedulingPolicy::Stop => { // If we have been asked to avoid rescheduling this shard, then do not migrate it during a drain - return None; + return TenantShardDrainAction::Skip; } } match tenant_shard.preferred_secondary(scheduler) { - Some(node) => Some(node), + Some(node) => TenantShardDrainAction::RescheduleToSecondary(node), None => { tracing::warn!( tenant_id=%self.tenant_shard_id.tenant_id, shard_id=%self.tenant_shard_id.shard_slug(), "No eligible secondary while draining {}", self.drained_node ); - None + TenantShardDrainAction::Skip } } } @@ -138,3 +158,17 @@ impl TenantShardDrain { } } } + +/// Action to take when draining a tenant shard. +pub(crate) enum TenantShardDrainAction { + /// The tenant shard is on the draining node. + /// Reschedule the tenant shard to a secondary location. + /// Holds a destination node id to reschedule to. + RescheduleToSecondary(NodeId), + /// The tenant shard is being migrated from the draining node. + /// Wait for the reconciliation to complete. + /// Holds the intent attached node id. + Reconcile(NodeId), + /// The tenant shard is not eligible for draining, skip it. + Skip, +} diff --git a/storage_controller/src/reconciler.rs b/storage_controller/src/reconciler.rs index d1590ec75e..ff5a3831cd 100644 --- a/storage_controller/src/reconciler.rs +++ b/storage_controller/src/reconciler.rs @@ -981,6 +981,7 @@ impl Reconciler { )); } + let mut first_err = None; for (node, conf) in changes { if self.cancel.is_cancelled() { return Err(ReconcileError::Cancel); @@ -990,7 +991,12 @@ impl Reconciler { // shard _available_ (the attached location), and configuring secondary locations // can be done lazily when the node becomes available (via background reconciliation). if node.is_available() { - self.location_config(&node, conf, None, false).await?; + let res = self.location_config(&node, conf, None, false).await; + if let Err(err) = res { + if first_err.is_none() { + first_err = Some(err); + } + } } else { // If the node is unavailable, we skip and consider the reconciliation successful: this // is a common case where a pageserver is marked unavailable: we demote a location on @@ -1002,6 +1008,10 @@ } } + if let Some(err) = first_err { + return Err(err); + } + // The condition below identifies a detach. We must have no attached intent and // must have been attached to something previously. Pass this information to // the [`ComputeHook`] such that it can update its tenant-wide state.
diff --git a/storage_controller/src/service.rs b/storage_controller/src/service.rs index 8f5efe8ac4..37380b8fbe 100644 --- a/storage_controller/src/service.rs +++ b/storage_controller/src/service.rs @@ -79,7 +79,7 @@ use crate::id_lock_map::{ use crate::leadership::Leadership; use crate::metrics; use crate::node::{AvailabilityTransition, Node}; -use crate::operation_utils::{self, TenantShardDrain}; +use crate::operation_utils::{self, TenantShardDrain, TenantShardDrainAction}; use crate::pageserver_client::PageserverClient; use crate::peer_client::GlobalObservedState; use crate::persistence::split_state::SplitState; @@ -1274,7 +1274,7 @@ impl Service { // Always attempt autosplits. Sharding is crucial for bulk ingest performance, so we // must be responsive when new projects begin ingesting and reach the threshold. self.autosplit_tenants().await; - } + }, _ = self.reconcilers_cancel.cancelled() => return } } @@ -1530,10 +1530,19 @@ impl Service { // so that waiters will see the correct error after waiting. tenant.set_last_error(result.sequence, e); - // Skip deletions on reconcile failures - let upsert_deltas = - deltas.filter(|delta| matches!(delta, ObservedStateDelta::Upsert(_))); - tenant.apply_observed_deltas(upsert_deltas); + // If the reconciliation failed, don't clear the observed state for places where we + // detached. Instead, mark the observed state as uncertain. + let failed_reconcile_deltas = deltas.map(|delta| { + if let ObservedStateDelta::Delete(node_id) = delta { + ObservedStateDelta::Upsert(Box::new(( + node_id, + ObservedStateLocation { conf: None }, + ))) + } else { + delta + } + }); + tenant.apply_observed_deltas(failed_reconcile_deltas); } } @@ -8867,6 +8876,9 @@ impl Service { for (_tenant_id, schedule_context, shards) in TenantShardExclusiveIterator::new(tenants, ScheduleMode::Speculative) { + if work.len() >= MAX_OPTIMIZATIONS_PLAN_PER_PASS { + break; + } for shard in shards { if work.len() >= MAX_OPTIMIZATIONS_PLAN_PER_PASS { break; @@ -9631,16 +9643,16 @@ impl Service { tenant_shard_id: tid, }; - let dest_node_id = { + let drain_action = { let locked = self.inner.read().unwrap(); + tid_drain.tenant_shard_eligible_for_drain(&locked.tenants, &locked.scheduler) + }; - match tid_drain - .tenant_shard_eligible_for_drain(&locked.tenants, &locked.scheduler) - { - Some(node_id) => node_id, - None => { - continue; - } + let dest_node_id = match drain_action { + TenantShardDrainAction::RescheduleToSecondary(dest_node_id) => dest_node_id, + TenantShardDrainAction::Reconcile(intent_node_id) => intent_node_id, + TenantShardDrainAction::Skip => { + continue; } }; @@ -9675,14 +9687,16 @@ impl Service { { let mut locked = self.inner.write().unwrap(); let (nodes, tenants, scheduler) = locked.parts_mut(); - let rescheduled = tid_drain.reschedule_to_secondary( - dest_node_id, - tenants, - scheduler, - nodes, - )?; - if let Some(tenant_shard) = rescheduled { + let tenant_shard = match drain_action { + TenantShardDrainAction::RescheduleToSecondary(dest_node_id) => tid_drain + .reschedule_to_secondary(dest_node_id, tenants, scheduler, nodes)?, + TenantShardDrainAction::Reconcile(_) => tenants.get_mut(&tid), + // Note: Unreachable, handled above. 
+ TenantShardDrainAction::Skip => None, + }; + + if let Some(tenant_shard) = tenant_shard { let waiter = self.maybe_configured_reconcile_shard( tenant_shard, nodes, diff --git a/storage_controller/src/tenant_shard.rs b/storage_controller/src/tenant_shard.rs index f60378470e..bf16c642af 100644 --- a/storage_controller/src/tenant_shard.rs +++ b/storage_controller/src/tenant_shard.rs @@ -249,6 +249,10 @@ impl IntentState { } pub(crate) fn push_secondary(&mut self, scheduler: &mut Scheduler, new_secondary: NodeId) { + // Every assertion here should probably have a corresponding check in + // `validate_optimization` unless it is an invariant that should never be violated. Note + // that the lock is not held between planning optimizations and applying them so you have to + // assume any valid state transition of the intent state may have occurred assert!(!self.secondary.contains(&new_secondary)); assert!(self.attached != Some(new_secondary)); scheduler.update_node_ref_counts( @@ -808,8 +812,6 @@ impl TenantShard { /// if the swap is not possible and leaves the intent state in its original state. /// /// Arguments: - /// `attached_to`: the currently attached location matching the intent state (may be None if the - /// shard is not attached) /// `promote_to`: an optional secondary location of this tenant shard. If set to None, we ask /// the scheduler to recommend a node pub(crate) fn reschedule_to_secondary( @@ -1335,8 +1337,9 @@ impl TenantShard { true } - /// Check that the desired modifications to the intent state are compatible with - /// the current intent state + /// Check that the desired modifications to the intent state are compatible with the current + /// intent state. Note that the lock is not held between planning optimizations and applying + /// them so any valid state transition of the intent state may have occurred. fn validate_optimization(&self, optimization: &ScheduleOptimization) -> bool { match optimization.action { ScheduleOptimizationAction::MigrateAttachment(MigrateAttachment { @@ -1352,6 +1355,9 @@ impl TenantShard { }) => { // It's legal to remove a secondary that is not present in the intent state !self.intent.secondary.contains(&new_node_id) + // Ensure the secondary hasn't already been promoted to attached by a concurrent + // optimization/migration. 
+ && self.intent.attached != Some(new_node_id) } ScheduleOptimizationAction::CreateSecondary(new_node_id) => { !self.intent.secondary.contains(&new_node_id) diff --git a/test_runner/fixtures/compare_fixtures.py b/test_runner/fixtures/compare_fixtures.py index 150046b99a..3d248efc04 100644 --- a/test_runner/fixtures/compare_fixtures.py +++ b/test_runner/fixtures/compare_fixtures.py @@ -16,6 +16,7 @@ from typing_extensions import override from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker from fixtures.log_helper import log from fixtures.neon_fixtures import ( + Endpoint, NeonEnv, PgBin, PgProtocol, @@ -129,6 +130,10 @@ class NeonCompare(PgCompare): # Start pg self._pg = self.env.endpoints.create_start("main", "main", self.tenant) + @property + def endpoint(self) -> Endpoint: + return self._pg + @property @override def pg(self) -> PgProtocol: diff --git a/test_runner/fixtures/endpoint/http.py b/test_runner/fixtures/endpoint/http.py index 64db2b1f17..d235ac2143 100644 --- a/test_runner/fixtures/endpoint/http.py +++ b/test_runner/fixtures/endpoint/http.py @@ -79,18 +79,28 @@ class EndpointHttpClient(requests.Session): return json def prewarm_lfc(self, from_endpoint_id: str | None = None): + """ + Prewarm LFC cache from given endpoint and wait till it finishes or errors + """ params = {"from_endpoint": from_endpoint_id} if from_endpoint_id else dict() self.post(self.prewarm_url, params=params).raise_for_status() self.prewarm_lfc_wait() def prewarm_lfc_wait(self): + """ + Wait till LFC prewarm returns with error or success. + If prewarm was not requested before calling this function, it will error + """ + statuses = "failed", "completed", "skipped" + def prewarmed(): json = self.prewarm_lfc_status() status, err = json["status"], json.get("error") - assert status in ["failed", "completed", "skipped"], f"{status}, {err=}" + assert status in statuses, f"{status}, {err=}" wait_until(prewarmed, timeout=60) - assert self.prewarm_lfc_status()["status"] != "failed" + res = self.prewarm_lfc_status() + assert res["status"] != "failed", res def offload_lfc_status(self) -> dict[str, str]: res = self.get(self.offload_url) @@ -99,17 +109,26 @@ class EndpointHttpClient(requests.Session): return json def offload_lfc(self): + """ + Offload LFC cache to endpoint storage and wait till offload finishes or errors + """ self.post(self.offload_url).raise_for_status() self.offload_lfc_wait() def offload_lfc_wait(self): + """ + Wait till LFC offload returns with error or success. 
+ If offload was not requested before calling this function, it will error + """ + def offloaded(): json = self.offload_lfc_status() status, err = json["status"], json.get("error") assert status in ["failed", "completed"], f"{status}, {err=}" - wait_until(offloaded) - assert self.offload_lfc_status()["status"] != "failed" + wait_until(offloaded, timeout=60) + res = self.offload_lfc_status() + assert res["status"] != "failed", res def promote(self, promote_spec: dict[str, Any], disconnect: bool = False): url = f"http://localhost:{self.external_port}/promote" diff --git a/test_runner/fixtures/neon_api.py b/test_runner/fixtures/neon_api.py index 8d447c837f..69160dab20 100644 --- a/test_runner/fixtures/neon_api.py +++ b/test_runner/fixtures/neon_api.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re import time from typing import TYPE_CHECKING, cast, final @@ -13,6 +14,17 @@ if TYPE_CHECKING: from fixtures.pg_version import PgVersion +def connstr_to_env(connstr: str) -> dict[str, str]: + # postgresql://neondb_owner:npg_kuv6Rqi1cB@ep-old-silence-w26pxsvz-pooler.us-east-2.aws.neon.build/neondb?sslmode=require&channel_binding=...' + parts = re.split(r":|@|\/|\?", connstr.removeprefix("postgresql://")) + return { + "PGUSER": parts[0], + "PGPASSWORD": parts[1], + "PGHOST": parts[2], + "PGDATABASE": parts[3], + } + + def connection_parameters_to_env(params: dict[str, str]) -> dict[str, str]: return { "PGHOST": params["host"], diff --git a/test_runner/fixtures/neon_cli.py b/test_runner/fixtures/neon_cli.py index d7634f24a4..390efe0309 100644 --- a/test_runner/fixtures/neon_cli.py +++ b/test_runner/fixtures/neon_cli.py @@ -587,7 +587,9 @@ class NeonLocalCli(AbstractNeonCli): ] extra_env_vars = env or {} if basebackup_request_tries is not None: - extra_env_vars["NEON_COMPUTE_TESTING_BASEBACKUP_TRIES"] = str(basebackup_request_tries) + extra_env_vars["NEON_COMPUTE_TESTING_BASEBACKUP_RETRIES"] = str( + basebackup_request_tries + ) if remote_ext_base_url is not None: args.extend(["--remote-ext-base-url", remote_ext_base_url]) @@ -623,6 +625,7 @@ class NeonLocalCli(AbstractNeonCli): pageserver_id: int | None = None, safekeepers: list[int] | None = None, check_return_code=True, + timeout_sec: float | None = None, ) -> subprocess.CompletedProcess[str]: args = ["endpoint", "reconfigure", endpoint_id] if tenant_id is not None: @@ -631,7 +634,7 @@ class NeonLocalCli(AbstractNeonCli): args.extend(["--pageserver-id", str(pageserver_id)]) if safekeepers is not None: args.extend(["--safekeepers", (",".join(map(str, safekeepers)))]) - return self.raw_cli(args, check_return_code=check_return_code) + return self.raw_cli(args, check_return_code=check_return_code, timeout=timeout_sec) def endpoint_refresh_configuration( self, diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index e02b3b12f8..7f59547c73 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -4930,15 +4930,34 @@ class Endpoint(PgProtocol, LogUtils): def is_running(self): return self._running._value > 0 - def reconfigure(self, pageserver_id: int | None = None, safekeepers: list[int] | None = None): + def reconfigure( + self, + pageserver_id: int | None = None, + safekeepers: list[int] | None = None, + timeout_sec: float = 120, + ): assert self.endpoint_id is not None # If `safekeepers` is not None, they are remember them as active and use # in the following commands. 
if safekeepers is not None: self.active_safekeepers = safekeepers - self.env.neon_cli.endpoint_reconfigure( - self.endpoint_id, self.tenant_id, pageserver_id, self.active_safekeepers - ) + + start_time = time.time() + while True: + try: + self.env.neon_cli.endpoint_reconfigure( + self.endpoint_id, + self.tenant_id, + pageserver_id, + self.active_safekeepers, + timeout_sec=timeout_sec, + ) + return + except RuntimeError as e: + if time.time() - start_time > timeout_sec: + raise e + log.warning(f"Reconfigure failed with error: {e}. Retrying...") + time.sleep(5) def refresh_configuration(self): assert self.endpoint_id is not None diff --git a/test_runner/fixtures/workload.py b/test_runner/fixtures/workload.py index e17a8e989b..3ac61b5d8c 100644 --- a/test_runner/fixtures/workload.py +++ b/test_runner/fixtures/workload.py @@ -78,6 +78,9 @@ class Workload: """ if self._endpoint is not None: with ENDPOINT_LOCK: + # It's important that we update config.json before issuing the reconfigure request to make sure + # that PG-initiated spec refresh doesn't mess things up by reverting to the old spec. + self._endpoint.update_pageservers_in_config() self._endpoint.reconfigure() def endpoint(self, pageserver_id: int | None = None) -> Endpoint: @@ -97,10 +100,10 @@ class Workload: self._endpoint.start(pageserver_id=pageserver_id) self._configured_pageserver = pageserver_id else: - if self._configured_pageserver != pageserver_id: - self._configured_pageserver = pageserver_id - self._endpoint.reconfigure(pageserver_id=pageserver_id) - self._endpoint_config = pageserver_id + # It's important that we update config.json before issuing the reconfigure request to make sure + # that PG-initiated spec refresh doesn't mess things up by reverting to the old spec. + self._endpoint.update_pageservers_in_config(pageserver_id=pageserver_id) + self._endpoint.reconfigure(pageserver_id=pageserver_id) connstring = self._endpoint.safe_psql( "SELECT setting FROM pg_settings WHERE name='neon.pageserver_connstring'" diff --git a/test_runner/logical_repl/README.md b/test_runner/logical_repl/README.md index 449e56e21d..74af203c03 100644 --- a/test_runner/logical_repl/README.md +++ b/test_runner/logical_repl/README.md @@ -9,9 +9,10 @@ ```bash export BENCHMARK_CONNSTR=postgres://user:pass@ep-abc-xyz-123.us-east-2.aws.neon.build/neondb +export CLICKHOUSE_PASSWORD=ch_password123 docker compose -f test_runner/logical_repl/clickhouse/docker-compose.yml up -d -./scripts/pytest -m remote_cluster -k test_clickhouse +./scripts/pytest -m remote_cluster -k 'test_clickhouse[release-pg17]' docker compose -f test_runner/logical_repl/clickhouse/docker-compose.yml down ``` @@ -21,6 +22,6 @@ docker compose -f test_runner/logical_repl/clickhouse/docker-compose.yml down export BENCHMARK_CONNSTR=postgres://user:pass@ep-abc-xyz-123.us-east-2.aws.neon.build/neondb docker compose -f test_runner/logical_repl/debezium/docker-compose.yml up -d -./scripts/pytest -m remote_cluster -k test_debezium +./scripts/pytest -m remote_cluster -k 'test_debezium[release-pg17]' docker compose -f test_runner/logical_repl/debezium/docker-compose.yml down ``` diff --git a/test_runner/logical_repl/clickhouse/docker-compose.yml b/test_runner/logical_repl/clickhouse/docker-compose.yml index e00038b811..4131fbf0a5 100644 --- a/test_runner/logical_repl/clickhouse/docker-compose.yml +++ b/test_runner/logical_repl/clickhouse/docker-compose.yml @@ -1,9 +1,11 @@ services: clickhouse: - image: clickhouse/clickhouse-server + image: clickhouse/clickhouse-server:25.6 user: "101:101" 
container_name: clickhouse hostname: clickhouse + environment: + - CLICKHOUSE_PASSWORD=${CLICKHOUSE_PASSWORD:-ch_password123} ports: - 127.0.0.1:8123:8123 - 127.0.0.1:9000:9000 diff --git a/test_runner/logical_repl/debezium/docker-compose.yml b/test_runner/logical_repl/debezium/docker-compose.yml index fee127a2fd..fd446f173f 100644 --- a/test_runner/logical_repl/debezium/docker-compose.yml +++ b/test_runner/logical_repl/debezium/docker-compose.yml @@ -1,18 +1,28 @@ services: zookeeper: - image: quay.io/debezium/zookeeper:2.7 + image: quay.io/debezium/zookeeper:3.1.3.Final + ports: + - 127.0.0.1:2181:2181 + - 127.0.0.1:2888:2888 + - 127.0.0.1:3888:3888 kafka: - image: quay.io/debezium/kafka:2.7 + image: quay.io/debezium/kafka:3.1.3.Final + depends_on: [zookeeper] environment: ZOOKEEPER_CONNECT: "zookeeper:2181" - KAFKA_ADVERTISED_LISTENERS: PLAINTEXT://kafka:9092 + KAFKA_LISTENERS: INTERNAL://:9092,EXTERNAL://:29092 + KAFKA_ADVERTISED_LISTENERS: INTERNAL://kafka:9092,EXTERNAL://localhost:29092 + KAFKA_LISTENER_SECURITY_PROTOCOL_MAP: INTERNAL:PLAINTEXT,EXTERNAL:PLAINTEXT + KAFKA_INTER_BROKER_LISTENER_NAME: INTERNAL KAFKA_BROKER_ID: 1 KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR: 1 KAFKA_JMX_PORT: 9991 ports: - - 127.0.0.1:9092:9092 + - 9092:9092 + - 29092:29092 debezium: - image: quay.io/debezium/connect:2.7 + image: quay.io/debezium/connect:3.1.3.Final + depends_on: [kafka] environment: BOOTSTRAP_SERVERS: kafka:9092 GROUP_ID: 1 diff --git a/test_runner/logical_repl/test_clickhouse.py b/test_runner/logical_repl/test_clickhouse.py index c05684baf9..ef41ee6187 100644 --- a/test_runner/logical_repl/test_clickhouse.py +++ b/test_runner/logical_repl/test_clickhouse.py @@ -53,8 +53,13 @@ def test_clickhouse(remote_pg: RemotePostgres): cur.execute("CREATE TABLE table1 (id integer primary key, column1 varchar(10));") cur.execute("INSERT INTO table1 (id, column1) VALUES (1, 'abc'), (2, 'def');") conn.commit() - client = clickhouse_connect.get_client(host=clickhouse_host) + if "CLICKHOUSE_PASSWORD" not in os.environ: + raise RuntimeError("CLICKHOUSE_PASSWORD is not set") + client = clickhouse_connect.get_client( + host=clickhouse_host, password=os.environ["CLICKHOUSE_PASSWORD"] + ) client.command("SET allow_experimental_database_materialized_postgresql=1") + client.command("DROP DATABASE IF EXISTS db1_postgres") client.command( "CREATE DATABASE db1_postgres ENGINE = " f"MaterializedPostgreSQL('{conn_options['host']}', " diff --git a/test_runner/logical_repl/test_debezium.py b/test_runner/logical_repl/test_debezium.py index a53e6cef92..becdaffcf8 100644 --- a/test_runner/logical_repl/test_debezium.py +++ b/test_runner/logical_repl/test_debezium.py @@ -17,6 +17,7 @@ from fixtures.utils import wait_until if TYPE_CHECKING: from fixtures.neon_fixtures import RemotePostgres + from kafka import KafkaConsumer class DebeziumAPI: @@ -101,9 +102,13 @@ def debezium(remote_pg: RemotePostgres): assert len(dbz.list_connectors()) == 1 from kafka import KafkaConsumer + kafka_host = "kafka" if (os.getenv("CI", "false") == "true") else "127.0.0.1" + kafka_port = 9092 if (os.getenv("CI", "false") == "true") else 29092 + log.info("Connecting to Kafka: %s:%s", kafka_host, kafka_port) + consumer = KafkaConsumer( "dbserver1.inventory.customers", - bootstrap_servers=["kafka:9092"], + bootstrap_servers=[f"{kafka_host}:{kafka_port}"], auto_offset_reset="earliest", enable_auto_commit=False, ) @@ -112,7 +117,7 @@ def debezium(remote_pg: RemotePostgres): assert resp.status_code == 204 -def get_kafka_msg(consumer, ts_ms, before=None, 
after=None) -> None: +def get_kafka_msg(consumer: KafkaConsumer, ts_ms, before=None, after=None) -> None: """ Gets the message from Kafka and checks its validity Arguments: @@ -124,6 +129,7 @@ def get_kafka_msg(consumer, ts_ms, before=None, after=None) -> None: after: a dictionary, if not None, the after field from the kafka message must have the same values for the same keys """ + log.info("Bootstrap servers: %s", consumer.config["bootstrap_servers"]) msg = consumer.poll() assert msg, "Empty message" for val in msg.values(): diff --git a/test_runner/performance/benchbase_tpc_c_helpers/generate_diagrams.py b/test_runner/performance/benchbase_tpc_c_helpers/generate_diagrams.py new file mode 100644 index 0000000000..cf41a4ff59 --- /dev/null +++ b/test_runner/performance/benchbase_tpc_c_helpers/generate_diagrams.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +""" +Generate TPS and latency charts from BenchBase TPC-C results CSV files. + +This script reads a CSV file containing BenchBase results and generates two charts: +1. TPS (requests per second) over time +2. P95 and P99 latencies over time + +Both charts are combined in a single SVG file. +""" + +import argparse +import sys +from pathlib import Path + +import matplotlib.pyplot as plt # type: ignore[import-not-found] +import pandas as pd # type: ignore[import-untyped] + + +def load_results_csv(csv_file_path): + """Load BenchBase results CSV file into a pandas DataFrame.""" + try: + df = pd.read_csv(csv_file_path) + + # Validate required columns exist + required_columns = [ + "Time (seconds)", + "Throughput (requests/second)", + "95th Percentile Latency (millisecond)", + "99th Percentile Latency (millisecond)", + ] + + missing_columns = [col for col in required_columns if col not in df.columns] + if missing_columns: + print(f"Error: Missing required columns: {missing_columns}") + sys.exit(1) + + return df + + except FileNotFoundError: + print(f"Error: CSV file not found: {csv_file_path}") + sys.exit(1) + except pd.errors.EmptyDataError: + print(f"Error: CSV file is empty: {csv_file_path}") + sys.exit(1) + except Exception as e: + print(f"Error reading CSV file: {e}") + sys.exit(1) + + +def generate_charts(df, input_filename, output_svg_path, title_suffix=None): + """Generate combined TPS and latency charts and save as SVG.""" + + # Get the filename without extension for chart titles + file_label = Path(input_filename).stem + + # Build title ending with optional suffix + if title_suffix: + title_ending = f"{title_suffix} - {file_label}" + else: + title_ending = file_label + + # Create figure with two subplots + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10)) + + # Chart 1: Time vs TPS + ax1.plot( + df["Time (seconds)"], + df["Throughput (requests/second)"], + linewidth=1, + color="blue", + alpha=0.7, + ) + ax1.set_xlabel("Time (seconds)") + ax1.set_ylabel("TPS (Requests Per Second)") + ax1.set_title(f"Benchbase TPC-C Like Throughput (TPS) - {title_ending}") + ax1.grid(True, alpha=0.3) + ax1.set_xlim(0, df["Time (seconds)"].max()) + + # Chart 2: Time vs P95 and P99 Latencies + ax2.plot( + df["Time (seconds)"], + df["95th Percentile Latency (millisecond)"], + linewidth=1, + color="orange", + alpha=0.7, + label="Latency P95", + ) + ax2.plot( + df["Time (seconds)"], + df["99th Percentile Latency (millisecond)"], + linewidth=1, + color="red", + alpha=0.7, + label="Latency P99", + ) + ax2.set_xlabel("Time (seconds)") + ax2.set_ylabel("Latency (ms)") + ax2.set_title(f"Benchbase TPC-C Like Latency - {title_ending}") + ax2.grid(True, alpha=0.3) 
+ ax2.set_xlim(0, df["Time (seconds)"].max()) + ax2.legend() + + plt.tight_layout() + + # Save as SVG + try: + plt.savefig(output_svg_path, format="svg", dpi=300, bbox_inches="tight") + print(f"Charts saved to: {output_svg_path}") + except Exception as e: + print(f"Error saving SVG file: {e}") + sys.exit(1) + + +def main(): + """Main function to parse arguments and generate charts.""" + parser = argparse.ArgumentParser( + description="Generate TPS and latency charts from BenchBase TPC-C results CSV" + ) + parser.add_argument( + "--input-csv", type=str, required=True, help="Path to the input CSV results file" + ) + parser.add_argument( + "--output-svg", type=str, required=True, help="Path for the output SVG chart file" + ) + parser.add_argument( + "--title-suffix", + type=str, + required=False, + help="Optional suffix to add to chart titles (e.g., 'Warmup', 'Benchmark Phase')", + ) + + args = parser.parse_args() + + # Validate input file exists + if not Path(args.input_csv).exists(): + print(f"Error: Input CSV file does not exist: {args.input_csv}") + sys.exit(1) + + # Create output directory if it doesn't exist + output_path = Path(args.output_svg) + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Load data and generate charts + df = load_results_csv(args.input_csv) + generate_charts(df, args.input_csv, args.output_svg, args.title_suffix) + + print(f"Successfully generated charts from {len(df)} data points") + + +if __name__ == "__main__": + main() diff --git a/test_runner/performance/benchbase_tpc_c_helpers/generate_workload_size.py b/test_runner/performance/benchbase_tpc_c_helpers/generate_workload_size.py new file mode 100644 index 0000000000..1549c74b87 --- /dev/null +++ b/test_runner/performance/benchbase_tpc_c_helpers/generate_workload_size.py @@ -0,0 +1,339 @@ +import argparse +import html +import math +import os +import sys +from pathlib import Path + +CONFIGS_DIR = Path("../configs") +SCRIPTS_DIR = Path("../scripts") + +# Constants +## TODO increase times after testing +WARMUP_TIME_SECONDS = 1200 # 20 minutes +BENCHMARK_TIME_SECONDS = 3600 # 1 hour +RAMP_STEP_TIME_SECONDS = 300 # 5 minutes +BASE_TERMINALS = 130 +TERMINALS_PER_WAREHOUSE = 0.2 +OPTIMAL_RATE_FACTOR = 0.7 # 70% of max rate +BATCH_SIZE = 1000 +LOADER_THREADS = 4 +TRANSACTION_WEIGHTS = "45,43,4,4,4" # NewOrder, Payment, OrderStatus, Delivery, StockLevel +# Ramp-up rate multipliers +RAMP_RATE_FACTORS = [1.5, 1.1, 0.9, 0.7, 0.6, 0.4, 0.6, 0.7, 0.9, 1.1] + +# Templates for XML configs +WARMUP_XML = """ + + POSTGRES + org.postgresql.Driver + jdbc:postgresql://{hostname}/neondb?sslmode=require&ApplicationName=tpcc&reWriteBatchedInserts=true + neondb_owner + {password} + true + TRANSACTION_READ_COMMITTED + {batch_size} + {warehouses} + 0 + {terminals} + + + + {transaction_weights} + unlimited + POISSON + ZIPFIAN + + + + NewOrder + Payment + OrderStatus + Delivery + StockLevel + + +""" + +MAX_RATE_XML = """ + + POSTGRES + org.postgresql.Driver + jdbc:postgresql://{hostname}/neondb?sslmode=require&ApplicationName=tpcc&reWriteBatchedInserts=true + neondb_owner + {password} + true + TRANSACTION_READ_COMMITTED + {batch_size} + {warehouses} + 0 + {terminals} + + + + {transaction_weights} + unlimited + POISSON + ZIPFIAN + + + + NewOrder + Payment + OrderStatus + Delivery + StockLevel + + +""" + +OPT_RATE_XML = """ + + POSTGRES + org.postgresql.Driver + jdbc:postgresql://{hostname}/neondb?sslmode=require&ApplicationName=tpcc&reWriteBatchedInserts=true + neondb_owner + {password} + true + TRANSACTION_READ_COMMITTED + 
{batch_size} + {warehouses} + 0 + {terminals} + + + + {opt_rate} + {transaction_weights} + POISSON + ZIPFIAN + + + + NewOrder + Payment + OrderStatus + Delivery + StockLevel + + +""" + +RAMP_UP_XML = """ + + POSTGRES + org.postgresql.Driver + jdbc:postgresql://{hostname}/neondb?sslmode=require&ApplicationName=tpcc&reWriteBatchedInserts=true + neondb_owner + {password} + true + TRANSACTION_READ_COMMITTED + {batch_size} + {warehouses} + 0 + {terminals} + +{works} + + + NewOrder + Payment + OrderStatus + Delivery + StockLevel + + +""" + +WORK_TEMPLATE = f""" \n \n {{rate}}\n {TRANSACTION_WEIGHTS}\n POISSON\n ZIPFIAN\n \n""" + +# Templates for shell scripts +EXECUTE_SCRIPT = """# Create results directories +mkdir -p results_warmup +mkdir -p results_{suffix} +chmod 777 results_warmup results_{suffix} + +# Run warmup phase +docker run --network=host --rm \ + -v $(pwd)/configs:/configs \ + -v $(pwd)/results_warmup:/results \ + {docker_image}\ + -b tpcc \ + -c /configs/execute_{warehouses}_warehouses_warmup.xml \ + -d /results \ + --create=false --load=false --execute=true + +# Run benchmark phase +docker run --network=host --rm \ + -v $(pwd)/configs:/configs \ + -v $(pwd)/results_{suffix}:/results \ + {docker_image}\ + -b tpcc \ + -c /configs/execute_{warehouses}_warehouses_{suffix}.xml \ + -d /results \ + --create=false --load=false --execute=true\n""" + +LOAD_XML = """ + + POSTGRES + org.postgresql.Driver + jdbc:postgresql://{hostname}/neondb?sslmode=require&ApplicationName=tpcc&reWriteBatchedInserts=true + neondb_owner + {password} + true + TRANSACTION_READ_COMMITTED + {batch_size} + {warehouses} + {loader_threads} + +""" + +LOAD_SCRIPT = """# Create results directory for loading +mkdir -p results_load +chmod 777 results_load + +docker run --network=host --rm \ + -v $(pwd)/configs:/configs \ + -v $(pwd)/results_load:/results \ + {docker_image}\ + -b tpcc \ + -c /configs/load_{warehouses}_warehouses.xml \ + -d /results \ + --create=true --load=true --execute=false\n""" + + +def write_file(path, content): + path.parent.mkdir(parents=True, exist_ok=True) + try: + with open(path, "w") as f: + f.write(content) + except OSError as e: + print(f"Error writing {path}: {e}") + sys.exit(1) + # If it's a shell script, set executable permission + if str(path).endswith(".sh"): + os.chmod(path, 0o755) + + +def escape_xml_password(password): + """Escape XML special characters in password.""" + return html.escape(password, quote=True) + + +def get_docker_arch_tag(runner_arch): + """Map GitHub Actions runner.arch to Docker image architecture tag.""" + arch_mapping = {"X64": "amd64", "ARM64": "arm64"} + return arch_mapping.get(runner_arch, "amd64") # Default to amd64 + + +def main(): + parser = argparse.ArgumentParser(description="Generate BenchBase workload configs and scripts.") + parser.add_argument("--warehouses", type=int, required=True, help="Number of warehouses") + parser.add_argument("--max-rate", type=int, required=True, help="Max rate (TPS)") + parser.add_argument("--hostname", type=str, required=True, help="Database hostname") + parser.add_argument("--password", type=str, required=True, help="Database password") + parser.add_argument( + "--runner-arch", type=str, required=True, help="GitHub Actions runner architecture" + ) + args = parser.parse_args() + + warehouses = args.warehouses + max_rate = args.max_rate + hostname = args.hostname + password = args.password + runner_arch = args.runner_arch + + # Escape password for safe XML insertion + escaped_password = escape_xml_password(password) + + # Get 
the appropriate Docker architecture tag + docker_arch = get_docker_arch_tag(runner_arch) + docker_image = f"ghcr.io/neondatabase-labs/benchbase-postgres:latest-{docker_arch}" + + opt_rate = math.ceil(max_rate * OPTIMAL_RATE_FACTOR) + # Calculate terminals as next rounded integer of 40% of warehouses + terminals = math.ceil(BASE_TERMINALS + warehouses * TERMINALS_PER_WAREHOUSE) + ramp_rates = [math.ceil(max_rate * factor) for factor in RAMP_RATE_FACTORS] + + # Write configs + write_file( + CONFIGS_DIR / f"execute_{warehouses}_warehouses_warmup.xml", + WARMUP_XML.format( + warehouses=warehouses, + hostname=hostname, + password=escaped_password, + terminals=terminals, + batch_size=BATCH_SIZE, + warmup_time=WARMUP_TIME_SECONDS, + transaction_weights=TRANSACTION_WEIGHTS, + ), + ) + write_file( + CONFIGS_DIR / f"execute_{warehouses}_warehouses_max_rate.xml", + MAX_RATE_XML.format( + warehouses=warehouses, + hostname=hostname, + password=escaped_password, + terminals=terminals, + batch_size=BATCH_SIZE, + benchmark_time=BENCHMARK_TIME_SECONDS, + transaction_weights=TRANSACTION_WEIGHTS, + ), + ) + write_file( + CONFIGS_DIR / f"execute_{warehouses}_warehouses_opt_rate.xml", + OPT_RATE_XML.format( + warehouses=warehouses, + opt_rate=opt_rate, + hostname=hostname, + password=escaped_password, + terminals=terminals, + batch_size=BATCH_SIZE, + benchmark_time=BENCHMARK_TIME_SECONDS, + transaction_weights=TRANSACTION_WEIGHTS, + ), + ) + + ramp_works = "".join([WORK_TEMPLATE.format(rate=rate) for rate in ramp_rates]) + write_file( + CONFIGS_DIR / f"execute_{warehouses}_warehouses_ramp_up.xml", + RAMP_UP_XML.format( + warehouses=warehouses, + works=ramp_works, + hostname=hostname, + password=escaped_password, + terminals=terminals, + batch_size=BATCH_SIZE, + ), + ) + + # Loader config + write_file( + CONFIGS_DIR / f"load_{warehouses}_warehouses.xml", + LOAD_XML.format( + warehouses=warehouses, + hostname=hostname, + password=escaped_password, + batch_size=BATCH_SIZE, + loader_threads=LOADER_THREADS, + ), + ) + + # Write scripts + for suffix in ["max_rate", "opt_rate", "ramp_up"]: + script = EXECUTE_SCRIPT.format( + warehouses=warehouses, suffix=suffix, docker_image=docker_image + ) + write_file(SCRIPTS_DIR / f"execute_{warehouses}_warehouses_{suffix}.sh", script) + + # Loader script + write_file( + SCRIPTS_DIR / f"load_{warehouses}_warehouses.sh", + LOAD_SCRIPT.format(warehouses=warehouses, docker_image=docker_image), + ) + + print(f"Generated configs and scripts for {warehouses} warehouses and max rate {max_rate}.") + + +if __name__ == "__main__": + main() diff --git a/test_runner/performance/benchbase_tpc_c_helpers/upload_results_to_perf_test_results.py b/test_runner/performance/benchbase_tpc_c_helpers/upload_results_to_perf_test_results.py new file mode 100644 index 0000000000..3706d14fd4 --- /dev/null +++ b/test_runner/performance/benchbase_tpc_c_helpers/upload_results_to_perf_test_results.py @@ -0,0 +1,591 @@ +#!/usr/bin/env python3 +# ruff: noqa +# we exclude the file from ruff because on the github runner we have python 3.9 and ruff +# is running with newer python 3.12 which suggests changes incompatible with python 3.9 +""" +Upload BenchBase TPC-C results from summary.json and results.csv files to perf_test_results database. + +This script extracts metrics from BenchBase *.summary.json and *.results.csv files and uploads them +to a PostgreSQL database table for performance tracking and analysis. 
+""" + +import argparse +import json +import re +import sys +from datetime import datetime, timezone +from pathlib import Path + +import pandas as pd # type: ignore[import-untyped] +import psycopg2 + + +def load_summary_json(json_file_path): + """Load summary.json file and return parsed data.""" + try: + with open(json_file_path) as f: + return json.load(f) + except FileNotFoundError: + print(f"Error: Summary JSON file not found: {json_file_path}") + sys.exit(1) + except json.JSONDecodeError as e: + print(f"Error: Invalid JSON in file {json_file_path}: {e}") + sys.exit(1) + except Exception as e: + print(f"Error loading JSON file {json_file_path}: {e}") + sys.exit(1) + + +def get_metric_info(metric_name): + """Get metric unit and report type for a given metric name.""" + metrics_config = { + "Throughput": {"unit": "req/s", "report_type": "higher_is_better"}, + "Goodput": {"unit": "req/s", "report_type": "higher_is_better"}, + "Measured Requests": {"unit": "requests", "report_type": "higher_is_better"}, + "95th Percentile Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "Maximum Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "Median Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "Minimum Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "25th Percentile Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "90th Percentile Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "99th Percentile Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "75th Percentile Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "Average Latency": {"unit": "µs", "report_type": "lower_is_better"}, + } + + return metrics_config.get(metric_name, {"unit": "", "report_type": "higher_is_better"}) + + +def extract_metrics(summary_data): + """Extract relevant metrics from summary JSON data.""" + metrics = [] + + # Direct top-level metrics + direct_metrics = { + "Throughput (requests/second)": "Throughput", + "Goodput (requests/second)": "Goodput", + "Measured Requests": "Measured Requests", + } + + for json_key, clean_name in direct_metrics.items(): + if json_key in summary_data: + metrics.append((clean_name, summary_data[json_key])) + + # Latency metrics from nested "Latency Distribution" object + if "Latency Distribution" in summary_data: + latency_data = summary_data["Latency Distribution"] + latency_metrics = { + "95th Percentile Latency (microseconds)": "95th Percentile Latency", + "Maximum Latency (microseconds)": "Maximum Latency", + "Median Latency (microseconds)": "Median Latency", + "Minimum Latency (microseconds)": "Minimum Latency", + "25th Percentile Latency (microseconds)": "25th Percentile Latency", + "90th Percentile Latency (microseconds)": "90th Percentile Latency", + "99th Percentile Latency (microseconds)": "99th Percentile Latency", + "75th Percentile Latency (microseconds)": "75th Percentile Latency", + "Average Latency (microseconds)": "Average Latency", + } + + for json_key, clean_name in latency_metrics.items(): + if json_key in latency_data: + metrics.append((clean_name, latency_data[json_key])) + + return metrics + + +def build_labels(summary_data, project_id): + """Build labels JSON object from summary data and project info.""" + labels = {} + + # Extract required label keys from summary data + label_keys = [ + "DBMS Type", + "DBMS Version", + "Benchmark Type", + "Final State", + "isolation", + "scalefactor", + "terminals", + ] + + for key in label_keys: + if key in summary_data: + labels[key] = 
summary_data[key] + + # Add project_id from workflow + labels["project_id"] = project_id + + return labels + + +def build_suit_name(scalefactor, terminals, run_type, min_cu, max_cu): + """Build the suit name according to specification.""" + return f"benchbase-tpc-c-{scalefactor}-{terminals}-{run_type}-{min_cu}-{max_cu}" + + +def convert_timestamp_to_utc(timestamp_ms): + """Convert millisecond timestamp to PostgreSQL-compatible UTC timestamp.""" + try: + dt = datetime.fromtimestamp(timestamp_ms / 1000.0, tz=timezone.utc) + return dt.isoformat() + except (ValueError, TypeError) as e: + print(f"Warning: Could not convert timestamp {timestamp_ms}: {e}") + return datetime.now(timezone.utc).isoformat() + + +def insert_metrics(conn, metrics_data): + """Insert metrics data into the perf_test_results table.""" + insert_query = """ + INSERT INTO perf_test_results + (suit, revision, platform, metric_name, metric_value, metric_unit, + metric_report_type, recorded_at_timestamp, labels) + VALUES (%(suit)s, %(revision)s, %(platform)s, %(metric_name)s, %(metric_value)s, + %(metric_unit)s, %(metric_report_type)s, %(recorded_at_timestamp)s, %(labels)s) + """ + + try: + with conn.cursor() as cursor: + cursor.executemany(insert_query, metrics_data) + conn.commit() + print(f"Successfully inserted {len(metrics_data)} metrics into perf_test_results") + + # Log some sample data for verification + if metrics_data: + print( + f"Sample metric: {metrics_data[0]['metric_name']} = {metrics_data[0]['metric_value']} {metrics_data[0]['metric_unit']}" + ) + + except Exception as e: + print(f"Error inserting metrics into database: {e}") + sys.exit(1) + + +def create_benchbase_results_details_table(conn): + """Create benchbase_results_details table if it doesn't exist.""" + create_table_query = """ + CREATE TABLE IF NOT EXISTS benchbase_results_details ( + id BIGSERIAL PRIMARY KEY, + suit TEXT, + revision CHAR(40), + platform TEXT, + recorded_at_timestamp TIMESTAMP WITH TIME ZONE, + requests_per_second NUMERIC, + average_latency_ms NUMERIC, + minimum_latency_ms NUMERIC, + p25_latency_ms NUMERIC, + median_latency_ms NUMERIC, + p75_latency_ms NUMERIC, + p90_latency_ms NUMERIC, + p95_latency_ms NUMERIC, + p99_latency_ms NUMERIC, + maximum_latency_ms NUMERIC + ); + + CREATE INDEX IF NOT EXISTS benchbase_results_details_recorded_at_timestamp_idx + ON benchbase_results_details USING BRIN (recorded_at_timestamp); + CREATE INDEX IF NOT EXISTS benchbase_results_details_suit_idx + ON benchbase_results_details USING BTREE (suit text_pattern_ops); + """ + + try: + with conn.cursor() as cursor: + cursor.execute(create_table_query) + conn.commit() + print("Successfully created/verified benchbase_results_details table") + except Exception as e: + print(f"Error creating benchbase_results_details table: {e}") + sys.exit(1) + + +def process_csv_results(csv_file_path, start_timestamp_ms, suit, revision, platform): + """Process CSV results and return data for database insertion.""" + try: + # Read CSV file + df = pd.read_csv(csv_file_path) + + # Validate required columns exist + required_columns = [ + "Time (seconds)", + "Throughput (requests/second)", + "Average Latency (millisecond)", + "Minimum Latency (millisecond)", + "25th Percentile Latency (millisecond)", + "Median Latency (millisecond)", + "75th Percentile Latency (millisecond)", + "90th Percentile Latency (millisecond)", + "95th Percentile Latency (millisecond)", + "99th Percentile Latency (millisecond)", + "Maximum Latency (millisecond)", + ] + + missing_columns = [col for col in 
required_columns if col not in df.columns] + if missing_columns: + print(f"Error: Missing required columns in CSV: {missing_columns}") + return [] + + csv_data = [] + + for _, row in df.iterrows(): + # Calculate timestamp: start_timestamp_ms + (time_seconds * 1000) + time_seconds = row["Time (seconds)"] + row_timestamp_ms = start_timestamp_ms + (time_seconds * 1000) + + # Convert to UTC timestamp + row_timestamp = datetime.fromtimestamp( + row_timestamp_ms / 1000.0, tz=timezone.utc + ).isoformat() + + csv_row = { + "suit": suit, + "revision": revision, + "platform": platform, + "recorded_at_timestamp": row_timestamp, + "requests_per_second": float(row["Throughput (requests/second)"]), + "average_latency_ms": float(row["Average Latency (millisecond)"]), + "minimum_latency_ms": float(row["Minimum Latency (millisecond)"]), + "p25_latency_ms": float(row["25th Percentile Latency (millisecond)"]), + "median_latency_ms": float(row["Median Latency (millisecond)"]), + "p75_latency_ms": float(row["75th Percentile Latency (millisecond)"]), + "p90_latency_ms": float(row["90th Percentile Latency (millisecond)"]), + "p95_latency_ms": float(row["95th Percentile Latency (millisecond)"]), + "p99_latency_ms": float(row["99th Percentile Latency (millisecond)"]), + "maximum_latency_ms": float(row["Maximum Latency (millisecond)"]), + } + csv_data.append(csv_row) + + print(f"Processed {len(csv_data)} rows from CSV file") + return csv_data + + except FileNotFoundError: + print(f"Error: CSV file not found: {csv_file_path}") + return [] + except Exception as e: + print(f"Error processing CSV file {csv_file_path}: {e}") + return [] + + +def insert_csv_results(conn, csv_data): + """Insert CSV results into benchbase_results_details table.""" + if not csv_data: + print("No CSV data to insert") + return + + insert_query = """ + INSERT INTO benchbase_results_details + (suit, revision, platform, recorded_at_timestamp, requests_per_second, + average_latency_ms, minimum_latency_ms, p25_latency_ms, median_latency_ms, + p75_latency_ms, p90_latency_ms, p95_latency_ms, p99_latency_ms, maximum_latency_ms) + VALUES (%(suit)s, %(revision)s, %(platform)s, %(recorded_at_timestamp)s, %(requests_per_second)s, + %(average_latency_ms)s, %(minimum_latency_ms)s, %(p25_latency_ms)s, %(median_latency_ms)s, + %(p75_latency_ms)s, %(p90_latency_ms)s, %(p95_latency_ms)s, %(p99_latency_ms)s, %(maximum_latency_ms)s) + """ + + try: + with conn.cursor() as cursor: + cursor.executemany(insert_query, csv_data) + conn.commit() + print( + f"Successfully inserted {len(csv_data)} detailed results into benchbase_results_details" + ) + + # Log some sample data for verification + sample = csv_data[0] + print( + f"Sample detail: {sample['requests_per_second']} req/s at {sample['recorded_at_timestamp']}" + ) + + except Exception as e: + print(f"Error inserting CSV results into database: {e}") + sys.exit(1) + + +def parse_load_log(log_file_path, scalefactor): + """Parse load log file and extract load metrics.""" + try: + with open(log_file_path) as f: + log_content = f.read() + + # Regex patterns to match the timestamp lines + loading_pattern = r"\[INFO \] (\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}),\d{3}.*Loading data into TPCC database" + finished_pattern = r"\[INFO \] (\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}),\d{3}.*Finished loading data into TPCC database" + + loading_match = re.search(loading_pattern, log_content) + finished_match = re.search(finished_pattern, log_content) + + if not loading_match or not finished_match: + print(f"Warning: Could not find loading 
timestamps in log file {log_file_path}") + return None + + # Parse timestamps + loading_time = datetime.strptime(loading_match.group(1), "%Y-%m-%d %H:%M:%S") + finished_time = datetime.strptime(finished_match.group(1), "%Y-%m-%d %H:%M:%S") + + # Calculate duration in seconds + duration_seconds = (finished_time - loading_time).total_seconds() + + # Calculate throughput: scalefactor/warehouses: 10 warehouses is approx. 1 GB of data + load_throughput = (scalefactor * 1024 / 10.0) / duration_seconds + + # Convert end time to UTC timestamp for database + finished_time_utc = finished_time.replace(tzinfo=timezone.utc).isoformat() + + print(f"Load metrics: Duration={duration_seconds}s, Throughput={load_throughput:.2f} MB/s") + + return { + "duration_seconds": duration_seconds, + "throughput_mb_per_sec": load_throughput, + "end_timestamp": finished_time_utc, + } + + except FileNotFoundError: + print(f"Warning: Load log file not found: {log_file_path}") + return None + except Exception as e: + print(f"Error parsing load log file {log_file_path}: {e}") + return None + + +def insert_load_metrics(conn, load_metrics, suit, revision, platform, labels_json): + """Insert load metrics into perf_test_results table.""" + if not load_metrics: + print("No load metrics to insert") + return + + load_metrics_data = [ + { + "suit": suit, + "revision": revision, + "platform": platform, + "metric_name": "load_duration_seconds", + "metric_value": load_metrics["duration_seconds"], + "metric_unit": "seconds", + "metric_report_type": "lower_is_better", + "recorded_at_timestamp": load_metrics["end_timestamp"], + "labels": labels_json, + }, + { + "suit": suit, + "revision": revision, + "platform": platform, + "metric_name": "load_throughput", + "metric_value": load_metrics["throughput_mb_per_sec"], + "metric_unit": "MB/second", + "metric_report_type": "higher_is_better", + "recorded_at_timestamp": load_metrics["end_timestamp"], + "labels": labels_json, + }, + ] + + insert_query = """ + INSERT INTO perf_test_results + (suit, revision, platform, metric_name, metric_value, metric_unit, + metric_report_type, recorded_at_timestamp, labels) + VALUES (%(suit)s, %(revision)s, %(platform)s, %(metric_name)s, %(metric_value)s, + %(metric_unit)s, %(metric_report_type)s, %(recorded_at_timestamp)s, %(labels)s) + """ + + try: + with conn.cursor() as cursor: + cursor.executemany(insert_query, load_metrics_data) + conn.commit() + print(f"Successfully inserted {len(load_metrics_data)} load metrics into perf_test_results") + + except Exception as e: + print(f"Error inserting load metrics into database: {e}") + sys.exit(1) + + +def main(): + """Main function to parse arguments and upload results.""" + parser = argparse.ArgumentParser( + description="Upload BenchBase TPC-C results to perf_test_results database" + ) + parser.add_argument( + "--summary-json", type=str, required=False, help="Path to the summary.json file" + ) + parser.add_argument( + "--run-type", + type=str, + required=True, + choices=["warmup", "opt-rate", "ramp-up", "load"], + help="Type of benchmark run", + ) + parser.add_argument("--min-cu", type=float, required=True, help="Minimum compute units") + parser.add_argument("--max-cu", type=float, required=True, help="Maximum compute units") + parser.add_argument("--project-id", type=str, required=True, help="Neon project ID") + parser.add_argument( + "--revision", type=str, required=True, help="Git commit hash (40 characters)" + ) + parser.add_argument( + "--connection-string", type=str, required=True, help="PostgreSQL 
connection string" + ) + parser.add_argument( + "--results-csv", + type=str, + required=False, + help="Path to the results.csv file for detailed metrics upload", + ) + parser.add_argument( + "--load-log", + type=str, + required=False, + help="Path to the load log file for load phase metrics", + ) + parser.add_argument( + "--warehouses", + type=int, + required=False, + help="Number of warehouses (scalefactor) for load metrics calculation", + ) + + args = parser.parse_args() + + # Validate inputs + if args.summary_json and not Path(args.summary_json).exists(): + print(f"Error: Summary JSON file does not exist: {args.summary_json}") + sys.exit(1) + + if not args.summary_json and not args.load_log: + print("Error: Either summary JSON or load log file must be provided") + sys.exit(1) + + if len(args.revision) != 40: + print(f"Warning: Revision should be 40 characters, got {len(args.revision)}") + + # Load and process summary data if provided + summary_data = None + metrics = [] + + if args.summary_json: + summary_data = load_summary_json(args.summary_json) + metrics = extract_metrics(summary_data) + if not metrics: + print("Warning: No metrics found in summary JSON") + + # Build common data for all metrics + if summary_data: + scalefactor = summary_data.get("scalefactor", "unknown") + terminals = summary_data.get("terminals", "unknown") + labels = build_labels(summary_data, args.project_id) + else: + # For load-only processing, use warehouses argument as scalefactor + scalefactor = args.warehouses if args.warehouses else "unknown" + terminals = "unknown" + labels = {"project_id": args.project_id} + + suit = build_suit_name(scalefactor, terminals, args.run_type, args.min_cu, args.max_cu) + platform = f"prod-us-east-2-{args.project_id}" + + # Convert timestamp - only needed for summary metrics and CSV processing + current_timestamp_ms = None + start_timestamp_ms = None + recorded_at = None + + if summary_data: + current_timestamp_ms = summary_data.get("Current Timestamp (milliseconds)") + start_timestamp_ms = summary_data.get("Start timestamp (milliseconds)") + + if current_timestamp_ms: + recorded_at = convert_timestamp_to_utc(current_timestamp_ms) + else: + print("Warning: No timestamp found in JSON, using current time") + recorded_at = datetime.now(timezone.utc).isoformat() + + if not start_timestamp_ms: + print("Warning: No start timestamp found in JSON, CSV upload may be incorrect") + start_timestamp_ms = ( + current_timestamp_ms or datetime.now(timezone.utc).timestamp() * 1000 + ) + + # Print Grafana dashboard link for cross-service endpoint debugging + if start_timestamp_ms and current_timestamp_ms: + grafana_url = ( + f"https://neonprod.grafana.net/d/cdya0okb81zwga/cross-service-endpoint-debugging" + f"?orgId=1&from={int(start_timestamp_ms)}&to={int(current_timestamp_ms)}" + f"&timezone=utc&var-env=prod&var-input_project_id={args.project_id}" + ) + print(f'Cross service endpoint dashboard for "{args.run_type}" phase: {grafana_url}') + + # Prepare metrics data for database insertion (only if we have summary metrics) + metrics_data = [] + if metrics and recorded_at: + for metric_name, metric_value in metrics: + metric_info = get_metric_info(metric_name) + + row = { + "suit": suit, + "revision": args.revision, + "platform": platform, + "metric_name": metric_name, + "metric_value": float(metric_value), # Ensure numeric type + "metric_unit": metric_info["unit"], + "metric_report_type": metric_info["report_type"], + "recorded_at_timestamp": recorded_at, + "labels": json.dumps(labels), # Convert 
to JSON string for JSONB column + } + metrics_data.append(row) + + print(f"Prepared {len(metrics_data)} summary metrics for upload to database") + print(f"Suit: {suit}") + print(f"Platform: {platform}") + + # Connect to database and insert metrics + try: + conn = psycopg2.connect(args.connection_string) + + # Insert summary metrics into perf_test_results (if any) + if metrics_data: + insert_metrics(conn, metrics_data) + else: + print("No summary metrics to upload") + + # Process and insert detailed CSV results if provided + if args.results_csv: + print(f"Processing detailed CSV results from: {args.results_csv}") + + # Create table if it doesn't exist + create_benchbase_results_details_table(conn) + + # Process CSV data + csv_data = process_csv_results( + args.results_csv, start_timestamp_ms, suit, args.revision, platform + ) + + # Insert CSV data + if csv_data: + insert_csv_results(conn, csv_data) + else: + print("No CSV data to upload") + else: + print("No CSV file provided, skipping detailed results upload") + + # Process and insert load metrics if provided + if args.load_log: + print(f"Processing load metrics from: {args.load_log}") + + # Parse load log and extract metrics + load_metrics = parse_load_log(args.load_log, scalefactor) + + # Insert load metrics + if load_metrics: + insert_load_metrics( + conn, load_metrics, suit, args.revision, platform, json.dumps(labels) + ) + else: + print("No load metrics to upload") + else: + print("No load log file provided, skipping load metrics upload") + + conn.close() + print("Database upload completed successfully") + + except psycopg2.Error as e: + print(f"Database connection/query error: {e}") + sys.exit(1) + except Exception as e: + print(f"Unexpected error: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/test_runner/performance/test_lfc_prewarm.py b/test_runner/performance/test_lfc_prewarm.py index 6c0083de95..d459f9f3bf 100644 --- a/test_runner/performance/test_lfc_prewarm.py +++ b/test_runner/performance/test_lfc_prewarm.py @@ -2,45 +2,48 @@ from __future__ import annotations import os import timeit -import traceback -from concurrent.futures import ThreadPoolExecutor as Exec from pathlib import Path +from threading import Thread from time import sleep -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, cast import pytest from fixtures.benchmark_fixture import NeonBenchmarker, PgBenchRunResult from fixtures.log_helper import log -from fixtures.neon_api import NeonAPI, connection_parameters_to_env +from fixtures.neon_api import NeonAPI, connstr_to_env + +from performance.test_perf_pgbench import utc_now_timestamp if TYPE_CHECKING: from fixtures.compare_fixtures import NeonCompare from fixtures.neon_fixtures import Endpoint, PgBin from fixtures.pg_version import PgVersion -from performance.test_perf_pgbench import utc_now_timestamp # These tests compare performance for a write-heavy and read-heavy workloads of an ordinary endpoint -# compared to the endpoint which saves its LFC and prewarms using it on startup. 
+# compared to the endpoint which saves its LFC and prewarms using it on startup def test_compare_prewarmed_pgbench_perf(neon_compare: NeonCompare): env = neon_compare.env - env.create_branch("normal") env.create_branch("prewarmed") pg_bin = neon_compare.pg_bin - ep_normal: Endpoint = env.endpoints.create_start("normal") - ep_prewarmed: Endpoint = env.endpoints.create_start("prewarmed", autoprewarm=True) + ep_ordinary: Endpoint = neon_compare.endpoint + ep_prewarmed: Endpoint = env.endpoints.create_start("prewarmed") - for ep in [ep_normal, ep_prewarmed]: + for ep in [ep_ordinary, ep_prewarmed]: connstr: str = ep.connstr() pg_bin.run(["pgbench", "-i", "-I", "dtGvp", connstr, "-s100"]) - ep.safe_psql("CREATE EXTENSION neon") - client = ep.http_client() - client.offload_lfc() - ep.stop() - ep.start() - client.prewarm_lfc_wait() + ep.safe_psql("CREATE SCHEMA neon; CREATE EXTENSION neon WITH SCHEMA neon") + if ep == ep_prewarmed: + client = ep.http_client() + client.offload_lfc() + ep.stop() + ep.start(autoprewarm=True) + client.prewarm_lfc_wait() + else: + ep.stop() + ep.start() run_start_timestamp = utc_now_timestamp() t0 = timeit.default_timer() @@ -59,6 +62,36 @@ def test_compare_prewarmed_pgbench_perf(neon_compare: NeonCompare): neon_compare.zenbenchmark.record_pg_bench_result(name, res) +def test_compare_prewarmed_read_perf(neon_compare: NeonCompare): + env = neon_compare.env + env.create_branch("prewarmed") + ep_ordinary: Endpoint = neon_compare.endpoint + ep_prewarmed: Endpoint = env.endpoints.create_start("prewarmed") + + sql = [ + "CREATE SCHEMA neon", + "CREATE EXTENSION neon WITH SCHEMA neon", + "CREATE TABLE foo(key serial primary key, t text default 'foooooooooooooooooooooooooooooooooooooooooooooooooooo')", + "INSERT INTO foo SELECT FROM generate_series(1,1000000)", + ] + sql_check = "SELECT count(*) from foo" + + ep_ordinary.safe_psql_many(sql) + ep_ordinary.stop() + ep_ordinary.start() + with neon_compare.record_duration("ordinary_run_duration"): + ep_ordinary.safe_psql(sql_check) + + ep_prewarmed.safe_psql_many(sql) + client = ep_prewarmed.http_client() + client.offload_lfc() + ep_prewarmed.stop() + ep_prewarmed.start(autoprewarm=True) + client.prewarm_lfc_wait() + with neon_compare.record_duration("prewarmed_run_duration"): + ep_prewarmed.safe_psql(sql_check) + + @pytest.mark.remote_cluster @pytest.mark.timeout(2 * 60 * 60) def test_compare_prewarmed_pgbench_perf_benchmark( @@ -67,67 +100,66 @@ def test_compare_prewarmed_pgbench_perf_benchmark( pg_version: PgVersion, zenbenchmark: NeonBenchmarker, ): - name = f"Test prewarmed pgbench performance, GITHUB_RUN_ID={os.getenv('GITHUB_RUN_ID')}" - project = neon_api.create_project(pg_version, name) - project_id = project["project"]["id"] - neon_api.wait_for_operation_to_finish(project_id) - err = False - try: - benchmark_impl(pg_bin, neon_api, project, zenbenchmark) - except Exception as e: - err = True - log.error(f"Caught exception: {e}") - log.error(traceback.format_exc()) - finally: - assert not err - neon_api.delete_project(project_id) + """ + Prewarm API is not public, so this test relies on a pre-created project + with pgbench size of 3424, pgbench -i -IdtGvp -s3424. 
Sleeping and + offloading constants are hardcoded to this size as well + """ + project_id = os.getenv("PROJECT_ID") + assert project_id + ordinary_branch_id = "" + prewarmed_branch_id = "" + for branch in neon_api.get_branches(project_id)["branches"]: + if branch["name"] == "ordinary": + ordinary_branch_id = branch["id"] + if branch["name"] == "prewarmed": + prewarmed_branch_id = branch["id"] + assert len(ordinary_branch_id) > 0 + assert len(prewarmed_branch_id) > 0 + + ep_ordinary = None + ep_prewarmed = None + for ep in neon_api.get_endpoints(project_id)["endpoints"]: + if ep["branch_id"] == ordinary_branch_id: + ep_ordinary = ep + if ep["branch_id"] == prewarmed_branch_id: + ep_prewarmed = ep + assert ep_ordinary + assert ep_prewarmed + ordinary_id = ep_ordinary["id"] + prewarmed_id = ep_prewarmed["id"] -def benchmark_impl( - pg_bin: PgBin, neon_api: NeonAPI, project: dict[str, Any], zenbenchmark: NeonBenchmarker -): - pgbench_size = int(os.getenv("PGBENCH_SIZE") or "3424") # 50GB offload_secs = 20 - test_duration_min = 5 + test_duration_min = 3 pgbench_duration = f"-T{test_duration_min * 60}" - # prewarm API is not publicly exposed. In order to test performance of a - # fully prewarmed endpoint, wait after it restarts. - # The number here is empirical, based on manual runs on staging + pgbench_init_cmd = ["pgbench", "-P10", "-n", "-c10", pgbench_duration, "-Mprepared"] + pgbench_perf_cmd = pgbench_init_cmd + ["-S"] prewarmed_sleep_secs = 180 - branch_id = project["branch"]["id"] - project_id = project["project"]["id"] - normal_env = connection_parameters_to_env( - project["connection_uris"][0]["connection_parameters"] - ) - normal_id = project["endpoints"][0]["id"] - - prewarmed_branch_id = neon_api.create_branch( - project_id, "prewarmed", parent_id=branch_id, add_endpoint=False - )["branch"]["id"] - neon_api.wait_for_operation_to_finish(project_id) - - ep_prewarmed = neon_api.create_endpoint( - project_id, - prewarmed_branch_id, - endpoint_type="read_write", - settings={"autoprewarm": True, "offload_lfc_interval_seconds": offload_secs}, - ) - neon_api.wait_for_operation_to_finish(project_id) - - prewarmed_env = normal_env.copy() - prewarmed_env["PGHOST"] = ep_prewarmed["endpoint"]["host"] - prewarmed_id = ep_prewarmed["endpoint"]["id"] + ordinary_uri = neon_api.get_connection_uri(project_id, ordinary_branch_id, ordinary_id)["uri"] + prewarmed_uri = neon_api.get_connection_uri(project_id, prewarmed_branch_id, prewarmed_id)[ + "uri" + ] def bench(endpoint_name, endpoint_id, env): - pg_bin.run(["pgbench", "-i", "-I", "dtGvp", f"-s{pgbench_size}"], env) - sleep(offload_secs * 2) # ensure LFC is offloaded after pgbench finishes - neon_api.restart_endpoint(project_id, endpoint_id) - sleep(prewarmed_sleep_secs) + log.info(f"Running pgbench for {pgbench_duration}s to warm up the cache") + pg_bin.run_capture(pgbench_init_cmd, env) # capture useful for debugging + log.info(f"Initialized {endpoint_name}") + if endpoint_name == "prewarmed": + log.info(f"sleeping {offload_secs * 2} to ensure LFC is offloaded") + sleep(offload_secs * 2) + neon_api.restart_endpoint(project_id, endpoint_id) + log.info(f"sleeping {prewarmed_sleep_secs} to ensure LFC is prewarmed") + sleep(prewarmed_sleep_secs) + else: + neon_api.restart_endpoint(project_id, endpoint_id) + + log.info(f"Starting benchmark for {endpoint_name}") run_start_timestamp = utc_now_timestamp() t0 = timeit.default_timer() - out = pg_bin.run_capture(["pgbench", "-c10", pgbench_duration, "-Mprepared"], env) + out = 
pg_bin.run_capture(pgbench_perf_cmd, env) run_duration = timeit.default_timer() - t0 run_end_timestamp = utc_now_timestamp() @@ -140,29 +172,9 @@ def benchmark_impl( ) zenbenchmark.record_pg_bench_result(endpoint_name, res) - with Exec(max_workers=2) as exe: - exe.submit(bench, "normal", normal_id, normal_env) - exe.submit(bench, "prewarmed", prewarmed_id, prewarmed_env) + prewarmed_args = ("prewarmed", prewarmed_id, connstr_to_env(prewarmed_uri)) + prewarmed_thread = Thread(target=bench, args=prewarmed_args) + prewarmed_thread.start() - -def test_compare_prewarmed_read_perf(neon_compare: NeonCompare): - env = neon_compare.env - env.create_branch("normal") - env.create_branch("prewarmed") - ep_normal: Endpoint = env.endpoints.create_start("normal") - ep_prewarmed: Endpoint = env.endpoints.create_start("prewarmed", autoprewarm=True) - - sql = [ - "CREATE EXTENSION neon", - "CREATE TABLE foo(key serial primary key, t text default 'foooooooooooooooooooooooooooooooooooooooooooooooooooo')", - "INSERT INTO foo SELECT FROM generate_series(1,1000000)", - ] - for ep in [ep_normal, ep_prewarmed]: - ep.safe_psql_many(sql) - client = ep.http_client() - client.offload_lfc() - ep.stop() - ep.start() - client.prewarm_lfc_wait() - with neon_compare.record_duration(f"{ep.branch_name}_run_duration"): - ep.safe_psql("SELECT count(*) from foo") + bench("ordinary", ordinary_id, connstr_to_env(ordinary_uri)) + prewarmed_thread.join() diff --git a/test_runner/regress/test_change_pageserver.py b/test_runner/regress/test_change_pageserver.py index bcdccac14e..af736af825 100644 --- a/test_runner/regress/test_change_pageserver.py +++ b/test_runner/regress/test_change_pageserver.py @@ -17,7 +17,7 @@ def reconfigure_endpoint(endpoint: Endpoint, pageserver_id: int, use_explicit_re # to make sure that PG-initiated config refresh doesn't mess things up by reverting to the old config. endpoint.update_pageservers_in_config(pageserver_id=pageserver_id) - # PG will eventually automatically refresh its configuration if it detects connectivity issues with pageservers. + # PG will automatically refresh its configuration if it detects connectivity issues with pageservers. # We also allow the test to explicitly request a reconfigure so that the test can be sure that the # endpoint is running with the latest configuration. # diff --git a/test_runner/regress/test_compute_termination.py b/test_runner/regress/test_compute_termination.py new file mode 100644 index 0000000000..2d62ccf20f --- /dev/null +++ b/test_runner/regress/test_compute_termination.py @@ -0,0 +1,369 @@ +from __future__ import annotations + +import json +import os +import shutil +import subprocess +import threading +import time +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import TYPE_CHECKING + +import requests +from fixtures.log_helper import log +from typing_extensions import override + +if TYPE_CHECKING: + from typing import Any + + from fixtures.common_types import TenantId, TimelineId + from fixtures.neon_fixtures import NeonEnv + from fixtures.port_distributor import PortDistributor + + +def launch_compute_ctl( + env: NeonEnv, + endpoint_name: str, + external_http_port: int, + internal_http_port: int, + pg_port: int, + control_plane_port: int, +) -> subprocess.Popen[str]: + """ + Helper function to launch compute_ctl process with common configuration. + Returns the Popen process object. 
+ """ + # Create endpoint directory structure following the standard pattern + endpoint_path = env.repo_dir / "endpoints" / endpoint_name + + # Clean up any existing endpoint directory to avoid conflicts + if endpoint_path.exists(): + shutil.rmtree(endpoint_path) + + endpoint_path.mkdir(mode=0o755, parents=True, exist_ok=True) + + # pgdata path - compute_ctl will create this directory during basebackup + pgdata_path = endpoint_path / "pgdata" + + # Create log file in endpoint directory + log_file = endpoint_path / "compute.log" + log_handle = open(log_file, "w") + + # Start compute_ctl pointing to our control plane + compute_ctl_path = env.neon_binpath / "compute_ctl" + connstr = f"postgresql://cloud_admin@localhost:{pg_port}/postgres" + + # Find postgres binary path + pg_bin_path = env.pg_distrib_dir / env.pg_version.v_prefixed / "bin" / "postgres" + pg_lib_path = env.pg_distrib_dir / env.pg_version.v_prefixed / "lib" + + env_vars = { + "INSTANCE_ID": "lakebase-instance-id", + "LD_LIBRARY_PATH": str(pg_lib_path), # Linux, etc. + "DYLD_LIBRARY_PATH": str(pg_lib_path), # macOS + } + + cmd = [ + str(compute_ctl_path), + "--external-http-port", + str(external_http_port), + "--internal-http-port", + str(internal_http_port), + "--pgdata", + str(pgdata_path), + "--connstr", + connstr, + "--pgbin", + str(pg_bin_path), + "--compute-id", + endpoint_name, # Use endpoint_name as compute-id + "--control-plane-uri", + f"http://127.0.0.1:{control_plane_port}", + "--lakebase-mode", + "true", + ] + + print(f"Launching compute_ctl with command: {cmd}") + + # Start compute_ctl + process = subprocess.Popen( + cmd, + env=env_vars, + stdout=log_handle, + stderr=subprocess.STDOUT, # Combine stderr with stdout + text=True, + ) + + return process + + +def wait_for_compute_status( + compute_process: subprocess.Popen[str], + http_port: int, + expected_status: str, + timeout_seconds: int = 10, +) -> None: + """ + Wait for compute_ctl to reach the expected status. + Raises an exception if timeout is reached or process exits unexpectedly. + """ + start_time = time.time() + while time.time() - start_time < timeout_seconds: + try: + # Try to connect to the HTTP endpoint + response = requests.get(f"http://localhost:{http_port}/status", timeout=0.5) + if response.status_code == 200: + status_json = response.json() + # Check if it's in expected status + if status_json.get("status") == expected_status: + return + except (requests.ConnectionError, requests.Timeout): + pass + + # Check if process has exited + if compute_process.poll() is not None: + raise Exception( + f"compute_ctl exited unexpectedly with code {compute_process.returncode}." + ) + + time.sleep(0.5) + + # Timeout reached + compute_process.terminate() + raise Exception( + f"compute_ctl failed to reach {expected_status} status within {timeout_seconds} seconds." 
+ ) + + +class EmptySpecHandler(BaseHTTPRequestHandler): + """HTTP handler that returns an Empty compute spec response""" + + def do_GET(self): + if self.path.startswith("/compute/api/v2/computes/") and self.path.endswith("/spec"): + # Return empty status which will put compute in Empty state + response: dict[str, Any] = { + "status": "empty", + "spec": None, + "compute_ctl_config": {"jwks": {"keys": []}}, + } + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps(response).encode()) + else: + self.send_error(404) + + @override + def log_message(self, format: str, *args: Any): + # Suppress request logging + pass + + +def test_compute_terminate_empty(neon_simple_env: NeonEnv, port_distributor: PortDistributor): + """ + Test that terminating a compute in Empty status works correctly. + + This tests the bug fix where terminating an Empty compute would hang + waiting for a non-existent postgres process to terminate. + """ + env = neon_simple_env + + # Get ports for our test + control_plane_port = port_distributor.get_port() + external_http_port = port_distributor.get_port() + internal_http_port = port_distributor.get_port() + pg_port = port_distributor.get_port() + + # Start a simple HTTP server that will serve the Empty spec + server = HTTPServer(("127.0.0.1", control_plane_port), EmptySpecHandler) + server_thread = threading.Thread(target=server.serve_forever) + server_thread.daemon = True + server_thread.start() + + compute_process = None + try: + # Start compute_ctl with ephemeral tenant ID + compute_process = launch_compute_ctl( + env, + "test-empty-compute", + external_http_port, + internal_http_port, + pg_port, + control_plane_port, + ) + + # Wait for compute_ctl to start and report "empty" status + wait_for_compute_status(compute_process, external_http_port, "empty") + + # Now send terminate request + response = requests.post(f"http://localhost:{external_http_port}/terminate") + + # Verify that the termination request sends back a 200 OK response and is not abruptly terminated. 
+ assert response.status_code == 200, ( + f"Expected 200 OK, got {response.status_code}: {response.text}" + ) + + # Wait for compute_ctl to exit + exit_code = compute_process.wait(timeout=10) + assert exit_code == 0, f"compute_ctl exited with non-zero code: {exit_code}" + + finally: + # Clean up + server.shutdown() + if compute_process and compute_process.poll() is None: + compute_process.terminate() + compute_process.wait() + + +class SwitchableConfigHandler(BaseHTTPRequestHandler): + """HTTP handler that can switch between normal compute configs and compute configs without specs""" + + return_empty_spec: bool = False + tenant_id: TenantId | None = None + timeline_id: TimelineId | None = None + pageserver_port: int | None = None + safekeeper_connstrs: list[str] | None = None + + def do_GET(self): + if self.path.startswith("/compute/api/v2/computes/") and self.path.endswith("/spec"): + if self.return_empty_spec: + # Return empty status + response: dict[str, object | None] = { + "status": "empty", + "spec": None, + "compute_ctl_config": { + "jwks": {"keys": []}, + }, + } + else: + # Return normal attached spec + response = { + "status": "attached", + "spec": { + "format_version": 1.0, + "cluster": { + "roles": [], + "databases": [], + "postgresql_conf": "shared_preload_libraries='neon'", + }, + "tenant_id": str(self.tenant_id) if self.tenant_id else "", + "timeline_id": str(self.timeline_id) if self.timeline_id else "", + "pageserver_connstring": f"postgres://no_user@localhost:{self.pageserver_port}" + if self.pageserver_port + else "", + "safekeeper_connstrings": self.safekeeper_connstrs or [], + "mode": "Primary", + "skip_pg_catalog_updates": True, + "reconfigure_concurrency": 1, + "suspend_timeout_seconds": -1, + }, + "compute_ctl_config": { + "jwks": {"keys": []}, + }, + } + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps(response).encode()) + else: + self.send_error(404) + + @override + def log_message(self, format: str, *args: Any): + # Suppress request logging + pass + + +def test_compute_empty_spec_during_refresh_configuration( + neon_simple_env: NeonEnv, port_distributor: PortDistributor +): + """ + Test that compute exits when it receives an empty spec during refresh configuration state. + + This test: + 1. Start compute with a normal spec + 2. Change the spec handler to return empty spec + 3. Trigger some condition to force compute to refresh configuration + 4. 
Verify that compute_ctl exits + """ + env = neon_simple_env + + # Get ports for our test + control_plane_port = port_distributor.get_port() + external_http_port = port_distributor.get_port() + internal_http_port = port_distributor.get_port() + pg_port = port_distributor.get_port() + + # Set up handler class variables + SwitchableConfigHandler.tenant_id = env.initial_tenant + SwitchableConfigHandler.timeline_id = env.initial_timeline + SwitchableConfigHandler.pageserver_port = env.pageserver.service_port.pg + # Convert comma-separated string to list + safekeeper_connstrs = env.get_safekeeper_connstrs() + if safekeeper_connstrs: + SwitchableConfigHandler.safekeeper_connstrs = safekeeper_connstrs.split(",") + else: + SwitchableConfigHandler.safekeeper_connstrs = [] + SwitchableConfigHandler.return_empty_spec = False # Start with normal spec + + # Start HTTP server with switchable spec handler + server = HTTPServer(("127.0.0.1", control_plane_port), SwitchableConfigHandler) + server_thread = threading.Thread(target=server.serve_forever) + server_thread.daemon = True + server_thread.start() + + compute_process = None + try: + # Start compute_ctl with tenant and timeline IDs + # Use a unique endpoint name to avoid conflicts + endpoint_name = f"test-refresh-compute-{os.getpid()}" + compute_process = launch_compute_ctl( + env, + endpoint_name, + external_http_port, + internal_http_port, + pg_port, + control_plane_port, + ) + + # Wait for compute_ctl to start and report "running" status + wait_for_compute_status(compute_process, external_http_port, "running", timeout_seconds=30) + + log.info("Compute is running. Now returning empty spec and trigger configuration refresh.") + + # Switch spec fetch handler to return empty spec + SwitchableConfigHandler.return_empty_spec = True + + # Trigger a configuration refresh + try: + requests.post(f"http://localhost:{internal_http_port}/refresh_configuration") + except requests.RequestException as e: + log.info(f"Call to /refresh_configuration failed: {e}") + log.info( + "Ignoring the error, assuming that compute_ctl is already refreshing or has exited" + ) + + # Wait for compute_ctl to exit (it should exit when it gets an empty spec during refresh) + exit_start_time = time.time() + while time.time() - exit_start_time < 30: + if compute_process.poll() is not None: + # Process exited + break + time.sleep(0.5) + + # Verify that compute_ctl exited + exit_code = compute_process.poll() + if exit_code is None: + compute_process.terminate() + raise Exception("compute_ctl did not exit after receiving empty spec.") + + # The exit code might not be 0 in this case since it's an unexpected termination + # but we mainly care that it did exit + assert exit_code is not None, "compute_ctl should have exited" + + finally: + # Clean up + server.shutdown() + if compute_process and compute_process.poll() is None: + compute_process.terminate() + compute_process.wait() diff --git a/test_runner/regress/test_hadron_ps_connectivity_metrics.py b/test_runner/regress/test_hadron_ps_connectivity_metrics.py new file mode 100644 index 0000000000..ff1f37b634 --- /dev/null +++ b/test_runner/regress/test_hadron_ps_connectivity_metrics.py @@ -0,0 +1,137 @@ +import json +import shutil + +from fixtures.common_types import TenantShardId +from fixtures.log_helper import log +from fixtures.metrics import parse_metrics +from fixtures.neon_fixtures import Endpoint, NeonEnvBuilder, NeonPageserver +from requests.exceptions import ConnectionError + + +# Helper function to attempt reconfiguration of the 
compute to point to a new pageserver. Note that in these tests,
+# we don't expect the reconfiguration attempts to go through, as we will be pointing the compute at a "wrong" pageserver.
+def _attempt_reconfiguration(endpoint: Endpoint, new_pageserver_id: int, timeout_sec: float):
+    try:
+        endpoint.reconfigure(pageserver_id=new_pageserver_id, timeout_sec=timeout_sec)
+    except Exception as e:
+        log.info(f"reconfiguration failed with exception {e}")
+        pass
+
+
+def read_misrouted_metric_value(pageserver: NeonPageserver) -> float:
+    return (
+        pageserver.http_client()
+        .get_metrics()
+        .query_one("pageserver_misrouted_pagestream_requests_total")
+        .value
+    )
+
+
+def read_request_error_metric_value(endpoint: Endpoint) -> float:
+    return (
+        parse_metrics(endpoint.http_client().metrics())
+        .query_one("pg_cctl_pagestream_request_errors_total")
+        .value
+    )
+
+
+def test_misrouted_to_secondary(
+    neon_env_builder: NeonEnvBuilder,
+):
+    """
+    Tests that the following metrics are incremented when the compute tries to talk to a secondary pageserver:
+    - On the pageserver receiving the request: pageserver_misrouted_pagestream_requests_total
+    - On the compute: pg_cctl_pagestream_request_errors_total
+    """
+    neon_env_builder.num_pageservers = 2
+    env = neon_env_builder.init_configs()
+    env.broker.start()
+    env.storage_controller.start()
+    for ps in env.pageservers:
+        ps.start()
+    for sk in env.safekeepers:
+        sk.start()
+
+    # Create a tenant that has one primary and one secondary. Due to primary/secondary placement constraints,
+    # the primary and secondary pageservers will be different.
+    tenant_id, _ = env.create_tenant(shard_count=1, placement_policy=json.dumps({"Attached": 1}))
+    endpoint = env.endpoints.create(
+        "main", tenant_id=tenant_id, config_lines=["neon.lakebase_mode = true"]
+    )
+    endpoint.respec(skip_pg_catalog_updates=False)
+    endpoint.start()
+
+    # Get the primary pageserver serving the zero shard of the tenant, and detach it from the primary pageserver.
+    # This test operation configures the tenant directly on the pageserver / does not go through the storage controller,
+    # so the compute does not get any notifications and will keep pointing at the detached pageserver.
+    tenant_zero_shard = TenantShardId(tenant_id, shard_number=0, shard_count=1)
+
+    primary_ps = env.get_tenant_pageserver(tenant_zero_shard)
+    secondary_ps = (
+        env.pageservers[1] if primary_ps.id == env.pageservers[0].id else env.pageservers[0]
+    )
+
+    # Now try to point the compute at the pageserver that is acting as secondary for the tenant. Test that the metrics
+    # on both compute_ctl and the pageserver register the misrouted requests following the reconfiguration attempt.
+    assert read_misrouted_metric_value(secondary_ps) == 0
+    assert read_request_error_metric_value(endpoint) == 0
+    _attempt_reconfiguration(endpoint, new_pageserver_id=secondary_ps.id, timeout_sec=2.0)
+    assert read_misrouted_metric_value(secondary_ps) > 0
+    try:
+        assert read_request_error_metric_value(endpoint) > 0
+    except ConnectionError:
+        # When configuring PG to use a misconfigured pageserver, PG will cancel the query after a certain number of failed
+        # reconfigure attempts. This will cause compute_ctl to exit.
+ log.info("Cannot connect to PG, ignoring") + pass + + +def test_misrouted_to_ps_not_hosting_tenant( + neon_env_builder: NeonEnvBuilder, +): + """ + Tests that the following metrics are incremented when compute tries to talk to a pageserver that does not host the tenant: + - On pageserver receiving the request: pageserver_misrouted_pagestream_requests_total + - On compute: pg_cctl_pagestream_request_errors_total + """ + neon_env_builder.num_pageservers = 2 + env = neon_env_builder.init_configs() + env.broker.start() + env.storage_controller.start(handle_ps_local_disk_loss=False) + for ps in env.pageservers: + ps.start() + for sk in env.safekeepers: + sk.start() + + tenant_id, _ = env.create_tenant(shard_count=1) + endpoint = env.endpoints.create( + "main", tenant_id=tenant_id, config_lines=["neon.lakebase_mode = true"] + ) + endpoint.respec(skip_pg_catalog_updates=False) + endpoint.start() + + tenant_ps_id = env.get_tenant_pageserver( + TenantShardId(tenant_id, shard_number=0, shard_count=1) + ).id + non_hosting_ps = ( + env.pageservers[1] if tenant_ps_id == env.pageservers[0].id else env.pageservers[0] + ) + + # Clear the disk of the non-hosting PS to make sure that it indeed doesn't have any information about the tenant. + non_hosting_ps.stop(immediate=True) + shutil.rmtree(non_hosting_ps.tenant_dir()) + non_hosting_ps.start() + + # Now try to point the compute to the non-hosting pageserver. Test that the metrics + # on both compute_ctl and the pageserver register the misrouted requests following the reconfiguration attempt. + assert read_misrouted_metric_value(non_hosting_ps) == 0 + assert read_request_error_metric_value(endpoint) == 0 + _attempt_reconfiguration(endpoint, new_pageserver_id=non_hosting_ps.id, timeout_sec=2.0) + assert read_misrouted_metric_value(non_hosting_ps) > 0 + try: + assert read_request_error_metric_value(endpoint) > 0 + except ConnectionError: + # When configuring PG to use misconfigured pageserver, PG will cancel the query after certain number of failed + # reconfigure attempts. This will cause compute_ctl to exit. + log.info("Cannot connect to PG, ignoring") + pass diff --git a/test_runner/regress/test_readonly_node.py b/test_runner/regress/test_readonly_node.py index 5612236250..e151b0ba13 100644 --- a/test_runner/regress/test_readonly_node.py +++ b/test_runner/regress/test_readonly_node.py @@ -129,7 +129,10 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder): Test static endpoint is protected from GC by acquiring and renewing lsn leases. """ - LSN_LEASE_LENGTH = 8 + LSN_LEASE_LENGTH = ( + 14 # This value needs to be large enough for compute_ctl to send two lease requests. + ) + neon_env_builder.num_pageservers = 2 # GC is manual triggered. env = neon_env_builder.init_start( @@ -230,6 +233,15 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder): log.info(f"`SELECT` query succeed after GC, {ctx=}") return offset + # It's not reliable to let the compute renew the lease in this test case as we have a very tight + # lease timeout. Therefore, the test case itself will renew the lease. + # + # This is a workaround to make the test case more deterministic. 
+ def renew_lease(env: NeonEnv, lease_lsn: Lsn): + env.storage_controller.pageserver_api().timeline_lsn_lease( + env.initial_tenant, env.initial_timeline, lease_lsn + ) + # Insert some records on main branch with env.endpoints.create_start("main", config_lines=["shared_buffers=1MB"]) as ep_main: with ep_main.cursor() as cur: @@ -242,6 +254,9 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder): XLOG_BLCKSZ = 8192 lsn = Lsn((int(lsn) // XLOG_BLCKSZ) * XLOG_BLCKSZ) + # We need to mock the way cplane works: it gets a lease for a branch before starting the compute. + renew_lease(env, lsn) + with env.endpoints.create_start( branch_name="main", endpoint_id="static", @@ -251,9 +266,6 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder): cur.execute("SELECT count(*) FROM t0") assert cur.fetchone() == (ROW_COUNT,) - # Wait for static compute to renew lease at least once. - time.sleep(LSN_LEASE_LENGTH / 2) - generate_updates_on_main(env, ep_main, 3, end=100) offset = trigger_gc_and_select( @@ -263,10 +275,10 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder): # Trigger Pageserver restarts for ps in env.pageservers: ps.stop() - # Static compute should have at least one lease request failure due to connection. - time.sleep(LSN_LEASE_LENGTH / 2) ps.start() + renew_lease(env, lsn) + trigger_gc_and_select( env, ep_static, @@ -282,6 +294,9 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder): ) env.storage_controller.reconcile_until_idle() + # Wait for static compute to renew lease on the new pageserver. + time.sleep(LSN_LEASE_LENGTH + 3) + trigger_gc_and_select( env, ep_static, @@ -292,7 +307,6 @@ def test_readonly_node_gc(neon_env_builder: NeonEnvBuilder): # Do some update so we can increment gc_cutoff generate_updates_on_main(env, ep_main, i, end=100) - # Wait for the existing lease to expire. time.sleep(LSN_LEASE_LENGTH + 1) # Now trigger GC again, layers should be removed. 
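The test_readonly_node.py change above replaces compute-driven lease renewal with explicit timeline_lsn_lease calls issued by the test at the points where a lease must be valid. Where a test instead needs to hold a lease across a longer window, the same storage-controller call can be driven from a small background loop. A minimal sketch under the assumptions of the diff above (the env fixture, lease_lsn, and LSN_LEASE_LENGTH); the helper name and the caller-supplied stop event are hypothetical, not part of this change:

import threading

def hold_lsn_lease(env, lease_lsn, stop: threading.Event, lease_length_secs: float = 14.0):
    """Renew the LSN lease in the background until `stop` is set."""

    def renew_loop():
        while not stop.is_set():
            # Same call the test makes explicitly; renewing at half the lease
            # length keeps the lease alive with some slack.
            env.storage_controller.pageserver_api().timeline_lsn_lease(
                env.initial_tenant, env.initial_timeline, lease_lsn
            )
            stop.wait(lease_length_secs / 2)  # wakes early once the test is done

    thread = threading.Thread(target=renew_loop, daemon=True)
    thread.start()
    return thread

A test would create a threading.Event, start the helper before triggering GC, and set the event and join the thread once its checks are complete.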
diff --git a/test_runner/regress/test_sharding.py b/test_runner/regress/test_sharding.py
index c2907d8a4f..4e46b67988 100644
--- a/test_runner/regress/test_sharding.py
+++ b/test_runner/regress/test_sharding.py
@@ -1751,14 +1751,15 @@ def test_back_pressure_per_shard(neon_env_builder: NeonEnvBuilder):
             "max_replication_apply_lag = 0",
             "max_replication_flush_lag = 15MB",
             "neon.max_cluster_size = 10GB",
+            "neon.lakebase_mode = true",
         ],
     )
     endpoint.respec(skip_pg_catalog_updates=False)
     endpoint.start()

-    # generate 10MB of data
+    # generate 20MB of data
     endpoint.safe_psql(
-        "CREATE TABLE usertable AS SELECT s AS KEY, repeat('a', 1000) as VALUE from generate_series(1, 10000) s;"
+        "CREATE TABLE usertable AS SELECT s AS KEY, repeat('a', 1000) as VALUE from generate_series(1, 20000) s;"
     )
     res = endpoint.safe_psql("SELECT neon.backpressure_throttling_time() as throttling_time")[0]
     assert res[0] == 0, f"throttling_time should be 0, but got {res[0]}"
diff --git a/test_runner/regress/test_storage_controller.py b/test_runner/regress/test_storage_controller.py
index 9986c1f24a..e11be1df8c 100644
--- a/test_runner/regress/test_storage_controller.py
+++ b/test_runner/regress/test_storage_controller.py
@@ -3309,6 +3309,7 @@ def test_ps_unavailable_after_delete(
         ps.allowed_errors.append(".*request was dropped before completing.*")
         env.storage_controller.node_delete(ps.id, force=True)
         wait_until(lambda: assert_nodes_count(2))
+        env.storage_controller.reconcile_until_idle()
     elif deletion_api == DeletionAPIKind.OLD:
         env.storage_controller.node_delete_old(ps.id)
         assert_nodes_count(2)
@@ -4959,3 +4960,49 @@ def test_storage_controller_forward_404(neon_env_builder: NeonEnvBuilder):
     env.storage_controller.configure_failpoints(
         ("reconciler-live-migrate-post-generation-inc", "off")
     )
+
+
+def test_re_attach_with_stuck_secondary(neon_env_builder: NeonEnvBuilder):
+    """
+    This test assumes that the secondary location cannot be configured for whatever reason.
+    It then attempts to detach and then attach the tenant back again and, finally, checks
+    for observed state consistency by attempting to create a timeline.
+
+    See LKB-204 for more details.
+ """ + + neon_env_builder.num_pageservers = 2 + + env = neon_env_builder.init_configs() + env.start() + + env.storage_controller.allowed_errors.append(".*failpoint.*") + + tenant_id, _ = env.create_tenant(shard_count=1, placement_policy='{"Attached":1}') + env.storage_controller.reconcile_until_idle() + + locations = env.storage_controller.locate(tenant_id) + assert len(locations) == 1 + primary: int = locations[0]["node_id"] + + not_primary = [ps.id for ps in env.pageservers if ps.id != primary] + assert len(not_primary) == 1 + secondary = not_primary[0] + + env.get_pageserver(secondary).http_client().configure_failpoints( + ("put-location-conf-handler", "return(1)") + ) + + env.storage_controller.tenant_policy_update(tenant_id, {"placement": "Detached"}) + + with pytest.raises(Exception, match="failpoint"): + env.storage_controller.reconcile_all() + + env.storage_controller.tenant_policy_update(tenant_id, {"placement": {"Attached": 1}}) + + with pytest.raises(Exception, match="failpoint"): + env.storage_controller.reconcile_all() + + env.storage_controller.pageserver_api().timeline_create( + pg_version=PgVersion.NOT_SET, tenant_id=tenant_id, new_timeline_id=TimelineId.generate() + ) diff --git a/test_runner/regress/test_tenant_size.py b/test_runner/regress/test_tenant_size.py index 8b291b7cbe..5564d9342c 100644 --- a/test_runner/regress/test_tenant_size.py +++ b/test_runner/regress/test_tenant_size.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING @@ -768,6 +769,14 @@ def test_lsn_lease_storcon(neon_env_builder: NeonEnvBuilder): "compaction_period": "0s", } env = neon_env_builder.init_start(initial_tenant_conf=conf) + # ShardSplit is slow in debug builds, so ignore the warning + if os.getenv("BUILD_TYPE", "debug") == "debug": + env.storage_controller.allowed_errors.extend( + [ + ".*Exclusive lock by ShardSplit was held.*", + ] + ) + with env.endpoints.create_start( "main", ) as ep: diff --git a/test_runner/regress/test_tenants.py b/test_runner/regress/test_tenants.py index 7f32f34d36..49bc02a3e7 100644 --- a/test_runner/regress/test_tenants.py +++ b/test_runner/regress/test_tenants.py @@ -298,15 +298,26 @@ def test_pageserver_metrics_removed_after_detach(neon_env_builder: NeonEnvBuilde assert post_detach_samples == set() -def test_pageserver_metrics_removed_after_offload(neon_env_builder: NeonEnvBuilder): +@pytest.mark.parametrize("compaction", ["compaction_enabled", "compaction_disabled"]) +def test_pageserver_metrics_removed_after_offload( + neon_env_builder: NeonEnvBuilder, compaction: str +): """Tests that when a timeline is offloaded, the tenant specific metrics are not left behind""" neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.MOCK_S3) - neon_env_builder.num_safekeepers = 3 env = neon_env_builder.init_start() - tenant_1, _ = env.create_tenant() + tenant_1, _ = env.create_tenant( + conf={ + # disable background compaction and GC so that we don't have leftover tasks + # after offloading. 
+ "gc_period": "0s", + "compaction_period": "0s", + } + if compaction == "compaction_disabled" + else None + ) timeline_1 = env.create_timeline("test_metrics_removed_after_offload_1", tenant_id=tenant_1) timeline_2 = env.create_timeline("test_metrics_removed_after_offload_2", tenant_id=tenant_1) @@ -351,6 +362,23 @@ def test_pageserver_metrics_removed_after_offload(neon_env_builder: NeonEnvBuild state=TimelineArchivalState.ARCHIVED, ) env.pageserver.http_client().timeline_offload(tenant_1, timeline) + # We need to wait until all background jobs are finished before we can check the metrics. + # There're many of them: compaction, GC, etc. + wait_until( + lambda: all( + sample.value == 0 + for sample in env.pageserver.http_client() + .get_metrics() + .query_all("pageserver_background_loop_semaphore_waiting_tasks") + ) + and all( + sample.value == 0 + for sample in env.pageserver.http_client() + .get_metrics() + .query_all("pageserver_background_loop_semaphore_running_tasks") + ) + ) + post_offload_samples = set( [x.name for x in get_ps_metric_samples_for_timeline(tenant_1, timeline)] ) diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index f5984d3ac3..e1eba9149d 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -28,6 +28,8 @@ chrono = { version = "0.4", default-features = false, features = ["clock", "serd clap = { version = "4", features = ["derive", "env", "string"] } clap_builder = { version = "4", default-features = false, features = ["color", "env", "help", "std", "string", "suggestions", "usage"] } const-oid = { version = "0.9", default-features = false, features = ["db", "std"] } +crossbeam-epoch = { version = "0.9" } +crossbeam-utils = { version = "0.8" } crypto-bigint = { version = "0.5", features = ["generic-array", "zeroize"] } der = { version = "0.7", default-features = false, features = ["derive", "flagset", "oid", "pem", "std"] } deranged = { version = "0.3", default-features = false, features = ["powerfmt", "serde", "std"] } @@ -73,6 +75,7 @@ num-traits = { version = "0.2", features = ["i128", "libm"] } once_cell = { version = "1" } p256 = { version = "0.13", features = ["jwk"] } parquet = { version = "53", default-features = false, features = ["zstd"] } +portable-atomic = { version = "1", features = ["require-cas"] } prost = { version = "0.13", features = ["no-recursion-limit", "prost-derive"] } rand = { version = "0.9" } regex = { version = "1" }