diff --git a/.github/actionlint.yml b/.github/actionlint.yml index 25b2fc702a..8a4bcaf811 100644 --- a/.github/actionlint.yml +++ b/.github/actionlint.yml @@ -31,7 +31,7 @@ config-variables: - NEON_PROD_AWS_ACCOUNT_ID - PGREGRESS_PG16_PROJECT_ID - PGREGRESS_PG17_PROJECT_ID - - PREWARM_PGBENCH_SIZE + - PREWARM_PROJECT_ID - REMOTE_STORAGE_AZURE_CONTAINER - REMOTE_STORAGE_AZURE_REGION - SLACK_CICD_CHANNEL_ID diff --git a/.github/workflows/benchbase_tpcc.yml b/.github/workflows/benchbase_tpcc.yml new file mode 100644 index 0000000000..3a36a97bb1 --- /dev/null +++ b/.github/workflows/benchbase_tpcc.yml @@ -0,0 +1,384 @@ +name: TPC-C like benchmark using benchbase + +on: + schedule: + # * is a special character in YAML so you have to quote this string + # ┌───────────── minute (0 - 59) + # │ ┌───────────── hour (0 - 23) + # │ │ ┌───────────── day of the month (1 - 31) + # │ │ │ ┌───────────── month (1 - 12 or JAN-DEC) + # │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT) + - cron: '0 6 * * *' # run once a day at 6 AM UTC + workflow_dispatch: # adds ability to run this manually + +defaults: + run: + shell: bash -euxo pipefail {0} + +concurrency: + # Allow only one workflow globally because we do not want to be too noisy in production environment + group: benchbase-tpcc-workflow + cancel-in-progress: false + +permissions: + contents: read + +jobs: + benchbase-tpcc: + strategy: + fail-fast: false # allow other variants to continue even if one fails + matrix: + include: + - warehouses: 50 # defines number of warehouses and is used to compute number of terminals + max_rate: 800 # measured max TPS at scale factor based on experiments. Adjust if performance is better/worse + min_cu: 0.25 # simulate free tier plan (0.25 -2 CU) + max_cu: 2 + - warehouses: 500 # serverless plan (2-8 CU) + max_rate: 2000 + min_cu: 2 + max_cu: 8 + - warehouses: 1000 # business plan (2-16 CU) + max_rate: 2900 + min_cu: 2 + max_cu: 16 + max-parallel: 1 # we want to run each workload size sequentially to avoid noisy neighbors + permissions: + contents: write + statuses: write + id-token: write # aws-actions/configure-aws-credentials + env: + PG_CONFIG: /tmp/neon/pg_install/v17/bin/pg_config + PSQL: /tmp/neon/pg_install/v17/bin/psql + PG_17_LIB_PATH: /tmp/neon/pg_install/v17/lib + POSTGRES_VERSION: 17 + runs-on: [ self-hosted, us-east-2, x64 ] + timeout-minutes: 1440 + + steps: + - name: Harden the runner (Audit all outbound calls) + uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0 + with: + egress-policy: audit + + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Configure AWS credentials # necessary to download artefacts + uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2 + with: + aws-region: eu-central-1 + role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} + role-duration-seconds: 18000 # 5 hours is currently max associated with IAM role + + - name: Download Neon artifact + uses: ./.github/actions/download + with: + name: neon-${{ runner.os }}-${{ runner.arch }}-release-artifact + path: /tmp/neon/ + prefix: latest + aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} + + - name: Create Neon Project + id: create-neon-project-tpcc + uses: ./.github/actions/neon-project-create + with: + region_id: aws-us-east-2 + postgres_version: ${{ env.POSTGRES_VERSION }} + compute_units: '[${{ matrix.min_cu }}, ${{ matrix.max_cu }}]' + api_key: ${{ secrets.NEON_PRODUCTION_API_KEY_4_BENCHMARKS }} + api_host: 
console.neon.tech # production (!) + + - name: Initialize Neon project + env: + BENCHMARK_TPCC_CONNSTR: ${{ steps.create-neon-project-tpcc.outputs.dsn }} + PROJECT_ID: ${{ steps.create-neon-project-tpcc.outputs.project_id }} + run: | + echo "Initializing Neon project with project_id: ${PROJECT_ID}" + export LD_LIBRARY_PATH=${PG_17_LIB_PATH} + + # Retry logic for psql connection with 1 minute sleep between attempts + for attempt in {1..3}; do + echo "Attempt ${attempt}/3: Creating extensions in Neon project" + if ${PSQL} "${BENCHMARK_TPCC_CONNSTR}" -c "CREATE EXTENSION IF NOT EXISTS neon; CREATE EXTENSION IF NOT EXISTS neon_utils;"; then + echo "Successfully created extensions" + break + else + echo "Failed to create extensions on attempt ${attempt}" + if [ ${attempt} -lt 3 ]; then + echo "Waiting 60 seconds before retry..." + sleep 60 + else + echo "All attempts failed, exiting" + exit 1 + fi + fi + done + + echo "BENCHMARK_TPCC_CONNSTR=${BENCHMARK_TPCC_CONNSTR}" >> $GITHUB_ENV + + - name: Generate BenchBase workload configuration + env: + WAREHOUSES: ${{ matrix.warehouses }} + MAX_RATE: ${{ matrix.max_rate }} + run: | + echo "Generating BenchBase configs for warehouses: ${WAREHOUSES}, max_rate: ${MAX_RATE}" + + # Extract hostname and password from connection string + # Format: postgresql://username:password@hostname/database?params (no port for Neon) + HOSTNAME=$(echo "${BENCHMARK_TPCC_CONNSTR}" | sed -n 's|.*://[^:]*:[^@]*@\([^/]*\)/.*|\1|p') + PASSWORD=$(echo "${BENCHMARK_TPCC_CONNSTR}" | sed -n 's|.*://[^:]*:\([^@]*\)@.*|\1|p') + + echo "Extracted hostname: ${HOSTNAME}" + + # Use runner temp (NVMe) as working directory + cd "${RUNNER_TEMP}" + + # Copy the generator script + cp "${GITHUB_WORKSPACE}/test_runner/performance/benchbase_tpc_c_helpers/generate_workload_size.py" . + + # Generate configs and scripts + python3 generate_workload_size.py \ + --warehouses ${WAREHOUSES} \ + --max-rate ${MAX_RATE} \ + --hostname ${HOSTNAME} \ + --password ${PASSWORD} \ + --runner-arch ${{ runner.arch }} + + # Fix path mismatch: move generated configs and scripts to expected locations + mv ../configs ./configs + mv ../scripts ./scripts + + - name: Prepare database (load data) + env: + WAREHOUSES: ${{ matrix.warehouses }} + run: | + cd "${RUNNER_TEMP}" + + echo "Loading ${WAREHOUSES} warehouses into database..." + + # Run the loader script and capture output to log file while preserving stdout/stderr + ./scripts/load_${WAREHOUSES}_warehouses.sh 2>&1 | tee "load_${WAREHOUSES}_warehouses.log" + + echo "Database loading completed" + + - name: Run TPC-C benchmark (warmup phase, then benchmark at 70% of configured max TPS) + env: + WAREHOUSES: ${{ matrix.warehouses }} + run: | + cd "${RUNNER_TEMP}" + + echo "Running TPC-C benchmark with ${WAREHOUSES} warehouses..." + + # Run the optimal rate benchmark + ./scripts/execute_${WAREHOUSES}_warehouses_opt_rate.sh + + echo "Benchmark execution completed" + + - name: Run TPC-C benchmark (warmup phase, then ramp down TPS and up again in 5 minute intervals) + + env: + WAREHOUSES: ${{ matrix.warehouses }} + run: | + cd "${RUNNER_TEMP}" + + echo "Running TPC-C ramp-down-up with ${WAREHOUSES} warehouses..."
+ + # Run the ramp down/up benchmark + ./scripts/execute_${WAREHOUSES}_warehouses_ramp_up.sh + + echo "Benchmark execution completed" + + - name: Process results (upload to test results database and generate diagrams) + env: + WAREHOUSES: ${{ matrix.warehouses }} + MIN_CU: ${{ matrix.min_cu }} + MAX_CU: ${{ matrix.max_cu }} + PROJECT_ID: ${{ steps.create-neon-project-tpcc.outputs.project_id }} + REVISION: ${{ github.sha }} + PERF_DB_CONNSTR: ${{ secrets.PERF_TEST_RESULT_CONNSTR }} + run: | + cd "${RUNNER_TEMP}" + + echo "Creating temporary Python environment for results processing..." + + # Create temporary virtual environment + python3 -m venv temp_results_env + source temp_results_env/bin/activate + + # Install required packages in virtual environment + pip install matplotlib pandas psycopg2-binary + + echo "Copying results processing scripts..." + + # Copy both processing scripts + cp "${GITHUB_WORKSPACE}/test_runner/performance/benchbase_tpc_c_helpers/generate_diagrams.py" . + cp "${GITHUB_WORKSPACE}/test_runner/performance/benchbase_tpc_c_helpers/upload_results_to_perf_test_results.py" . + + echo "Processing load phase metrics..." + + # Find and process load log + LOAD_LOG=$(find . -name "load_${WAREHOUSES}_warehouses.log" -type f | head -1) + if [ -n "$LOAD_LOG" ]; then + echo "Processing load metrics from: $LOAD_LOG" + python upload_results_to_perf_test_results.py \ + --load-log "$LOAD_LOG" \ + --run-type "load" \ + --warehouses "${WAREHOUSES}" \ + --min-cu "${MIN_CU}" \ + --max-cu "${MAX_CU}" \ + --project-id "${PROJECT_ID}" \ + --revision "${REVISION}" \ + --connection-string "${PERF_DB_CONNSTR}" + else + echo "Warning: Load log file not found: load_${WAREHOUSES}_warehouses.log" + fi + + echo "Processing warmup results for optimal rate..." + + # Find and process warmup results + WARMUP_CSV=$(find results_warmup -name "*.results.csv" -type f | head -1) + WARMUP_JSON=$(find results_warmup -name "*.summary.json" -type f | head -1) + + if [ -n "$WARMUP_CSV" ] && [ -n "$WARMUP_JSON" ]; then + echo "Generating warmup diagram from: $WARMUP_CSV" + python generate_diagrams.py \ + --input-csv "$WARMUP_CSV" \ + --output-svg "warmup_${WAREHOUSES}_warehouses_performance.svg" \ + --title-suffix "Warmup at max TPS" + + echo "Uploading warmup metrics from: $WARMUP_JSON" + python upload_results_to_perf_test_results.py \ + --summary-json "$WARMUP_JSON" \ + --results-csv "$WARMUP_CSV" \ + --run-type "warmup" \ + --min-cu "${MIN_CU}" \ + --max-cu "${MAX_CU}" \ + --project-id "${PROJECT_ID}" \ + --revision "${REVISION}" \ + --connection-string "${PERF_DB_CONNSTR}" + else + echo "Warning: Missing warmup results files (CSV: $WARMUP_CSV, JSON: $WARMUP_JSON)" + fi + + echo "Processing optimal rate results..."
+ + # Find and process optimal rate results + OPTRATE_CSV=$(find results_opt_rate -name "*.results.csv" -type f | head -1) + OPTRATE_JSON=$(find results_opt_rate -name "*.summary.json" -type f | head -1) + + if [ -n "$OPTRATE_CSV" ] && [ -n "$OPTRATE_JSON" ]; then + echo "Generating optimal rate diagram from: $OPTRATE_CSV" + python generate_diagrams.py \ + --input-csv "$OPTRATE_CSV" \ + --output-svg "benchmark_${WAREHOUSES}_warehouses_performance.svg" \ + --title-suffix "70% of max TPS" + + echo "Uploading optimal rate metrics from: $OPTRATE_JSON" + python upload_results_to_perf_test_results.py \ + --summary-json "$OPTRATE_JSON" \ + --results-csv "$OPTRATE_CSV" \ + --run-type "opt-rate" \ + --min-cu "${MIN_CU}" \ + --max-cu "${MAX_CU}" \ + --project-id "${PROJECT_ID}" \ + --revision "${REVISION}" \ + --connection-string "${PERF_DB_CONNSTR}" + else + echo "Warning: Missing optimal rate results files (CSV: $OPTRATE_CSV, JSON: $OPTRATE_JSON)" + fi + + echo "Processing warmup 2 results for ramp down/up phase..." + + # Find and process warmup results + WARMUP_CSV=$(find results_warmup -name "*.results.csv" -type f | tail -1) + WARMUP_JSON=$(find results_warmup -name "*.summary.json" -type f | tail -1) + + if [ -n "$WARMUP_CSV" ] && [ -n "$WARMUP_JSON" ]; then + echo "Generating warmup diagram from: $WARMUP_CSV" + python generate_diagrams.py \ + --input-csv "$WARMUP_CSV" \ + --output-svg "warmup_2_${WAREHOUSES}_warehouses_performance.svg" \ + --title-suffix "Warmup at max TPS" + + echo "Uploading warmup metrics from: $WARMUP_JSON" + python upload_results_to_perf_test_results.py \ + --summary-json "$WARMUP_JSON" \ + --results-csv "$WARMUP_CSV" \ + --run-type "warmup" \ + --min-cu "${MIN_CU}" \ + --max-cu "${MAX_CU}" \ + --project-id "${PROJECT_ID}" \ + --revision "${REVISION}" \ + --connection-string "${PERF_DB_CONNSTR}" + else + echo "Warning: Missing warmup results files (CSV: $WARMUP_CSV, JSON: $WARMUP_JSON)" + fi + + echo "Processing ramp results..." 
+ + # Find and process ramp results + RAMPUP_CSV=$(find results_ramp_up -name "*.results.csv" -type f | head -1) + RAMPUP_JSON=$(find results_ramp_up -name "*.summary.json" -type f | head -1) + + if [ -n "$RAMPUP_CSV" ] && [ -n "$RAMPUP_JSON" ]; then + echo "Generating ramp diagram from: $RAMPUP_CSV" + python generate_diagrams.py \ + --input-csv "$RAMPUP_CSV" \ + --output-svg "ramp_${WAREHOUSES}_warehouses_performance.svg" \ + --title-suffix "ramp TPS down and up in 5 minute intervals" + + echo "Uploading ramp metrics from: $RAMPUP_JSON" + python upload_results_to_perf_test_results.py \ + --summary-json "$RAMPUP_JSON" \ + --results-csv "$RAMPUP_CSV" \ + --run-type "ramp-up" \ + --min-cu "${MIN_CU}" \ + --max-cu "${MAX_CU}" \ + --project-id "${PROJECT_ID}" \ + --revision "${REVISION}" \ + --connection-string "${PERF_DB_CONNSTR}" + else + echo "Warning: Missing ramp results files (CSV: $RAMPUP_CSV, JSON: $RAMPUP_JSON)" + fi + + # Deactivate and clean up virtual environment + deactivate + rm -rf temp_results_env + rm upload_results_to_perf_test_results.py + + echo "Results processing completed and environment cleaned up" + + - name: Set date for upload + id: set-date + run: echo "date=$(date +%Y-%m-%d)" >> $GITHUB_OUTPUT + + - name: Configure AWS credentials # necessary to upload results + uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2 + with: + aws-region: us-east-2 + role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} + role-duration-seconds: 900 # 900 is minimum value + + - name: Upload benchmark results to S3 + env: + S3_BUCKET: neon-public-benchmark-results + S3_PREFIX: benchbase-tpc-c/${{ steps.set-date.outputs.date }}/${{ github.run_id }}/${{ matrix.warehouses }}-warehouses + run: | + echo "Redacting passwords from configuration files before upload..." + + # Mask all passwords in XML config files + find "${RUNNER_TEMP}/configs" -name "*.xml" -type f -exec sed -i 's|<password>[^<]*</password>|<password>redacted</password>|g' {} \; + + echo "Uploading benchmark results to s3://${S3_BUCKET}/${S3_PREFIX}/" + + # Upload the entire benchmark directory recursively + aws s3 cp --only-show-errors --recursive "${RUNNER_TEMP}" s3://${S3_BUCKET}/${S3_PREFIX}/ + + echo "Upload completed" + + - name: Delete Neon Project + if: ${{ always() }} + uses: ./.github/actions/neon-project-delete + with: + project_id: ${{ steps.create-neon-project-tpcc.outputs.project_id }} + api_key: ${{ secrets.NEON_PRODUCTION_API_KEY_4_BENCHMARKS }} + api_host: console.neon.tech # production (!)
\ No newline at end of file diff --git a/.github/workflows/benchmarking.yml b/.github/workflows/benchmarking.yml index df80bad579..c9a998bd4e 100644 --- a/.github/workflows/benchmarking.yml +++ b/.github/workflows/benchmarking.yml @@ -418,7 +418,7 @@ jobs: statuses: write id-token: write # aws-actions/configure-aws-credentials env: - PGBENCH_SIZE: ${{ vars.PREWARM_PGBENCH_SIZE }} + PROJECT_ID: ${{ vars.PREWARM_PROJECT_ID }} POSTGRES_DISTRIB_DIR: /tmp/neon/pg_install DEFAULT_PG_VERSION: 17 TEST_OUTPUT: /tmp/test_output diff --git a/.github/workflows/pg-clients.yml b/.github/workflows/pg-clients.yml index 6efe0b4c8c..b6b4eca2b8 100644 --- a/.github/workflows/pg-clients.yml +++ b/.github/workflows/pg-clients.yml @@ -48,8 +48,20 @@ jobs: uses: ./.github/workflows/build-build-tools-image.yml secrets: inherit + generate-ch-tmppw: + runs-on: ubuntu-22.04 + outputs: + tmp_val: ${{ steps.pwgen.outputs.tmp_val }} + steps: + - name: Generate a random password + id: pwgen + run: | + set +x + p=$(dd if=/dev/random bs=14 count=1 2>/dev/null | base64) + echo tmp_val="${p//\//}" >> "${GITHUB_OUTPUT}" + test-logical-replication: - needs: [ build-build-tools-image ] + needs: [ build-build-tools-image, generate-ch-tmppw ] runs-on: ubuntu-22.04 container: @@ -60,16 +72,20 @@ jobs: options: --init --user root services: clickhouse: - image: clickhouse/clickhouse-server:24.6.3.64 + image: clickhouse/clickhouse-server:24.8 + env: + CLICKHOUSE_PASSWORD: ${{ needs.generate-ch-tmppw.outputs.tmp_val }} ports: - 9000:9000 - 8123:8123 zookeeper: - image: quay.io/debezium/zookeeper:2.7 + image: quay.io/debezium/zookeeper:3.1.3.Final ports: - 2181:2181 + - 2888:2888 + - 3888:3888 kafka: - image: quay.io/debezium/kafka:2.7 + image: quay.io/debezium/kafka:3.1.3.Final env: ZOOKEEPER_CONNECT: "zookeeper:2181" KAFKA_ADVERTISED_LISTENERS: PLAINTEXT://kafka:9092 @@ -79,7 +95,7 @@ jobs: ports: - 9092:9092 debezium: - image: quay.io/debezium/connect:2.7 + image: quay.io/debezium/connect:3.1.3.Final env: BOOTSTRAP_SERVERS: kafka:9092 GROUP_ID: 1 @@ -125,6 +141,7 @@ jobs: aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} env: BENCHMARK_CONNSTR: ${{ steps.create-neon-project.outputs.dsn }} + CLICKHOUSE_PASSWORD: ${{ needs.generate-ch-tmppw.outputs.tmp_val }} - name: Delete Neon Project if: always() diff --git a/.github/workflows/proxy-benchmark.yml b/.github/workflows/proxy-benchmark.yml index 0ae93ce295..e48fe41b45 100644 --- a/.github/workflows/proxy-benchmark.yml +++ b/.github/workflows/proxy-benchmark.yml @@ -3,7 +3,7 @@ name: Periodic proxy performance test on unit-perf-aws-arm runners on: push: # TODO: remove after testing branches: - - test-proxy-bench # Runs on pushes to branches starting with test-proxy-bench + - test-proxy-bench # Runs on pushes to test-proxy-bench branch # schedule: # * is a special character in YAML so you have to quote this string # ┌───────────── minute (0 - 59) @@ -32,7 +32,7 @@ jobs: statuses: write contents: write pull-requests: write - runs-on: [self-hosted, unit-perf-aws-arm] + runs-on: [ self-hosted, unit-perf-aws-arm ] timeout-minutes: 60 # 1h timeout container: image: ghcr.io/neondatabase/build-tools:pinned-bookworm @@ -55,30 +55,58 @@ jobs: { echo "PROXY_BENCH_PATH=$PROXY_BENCH_PATH" echo "NEON_DIR=${RUNNER_TEMP}/neon" + echo "NEON_PROXY_PATH=${RUNNER_TEMP}/neon/bin/proxy" echo "TEST_OUTPUT=${PROXY_BENCH_PATH}/test_output" echo "" } >> "$GITHUB_ENV" - - name: Run proxy-bench - run: ${PROXY_BENCH_PATH}/run.sh + - name: Cache poetry deps + uses: actions/cache@v4 + with: + path: 
~/.cache/pypoetry/virtualenvs + key: v2-${{ runner.os }}-${{ runner.arch }}-python-deps-bookworm-${{ hashFiles('poetry.lock') }} - - name: Ingest Bench Results # neon repo script + - name: Install Python deps + shell: bash -euxo pipefail {0} + run: ./scripts/pysync + + - name: show ulimits + shell: bash -euxo pipefail {0} + run: | + ulimit -a + + - name: Run proxy-bench + working-directory: ${{ env.PROXY_BENCH_PATH }} + run: ./run.sh --with-grafana --bare-metal + + - name: Ingest Bench Results if: always() + working-directory: ${{ env.NEON_DIR }} run: | mkdir -p $TEST_OUTPUT python $NEON_DIR/scripts/proxy_bench_results_ingest.py --out $TEST_OUTPUT - name: Push Metrics to Proxy perf database + shell: bash -euxo pipefail {0} if: always() env: PERF_TEST_RESULT_CONNSTR: "${{ secrets.PROXY_TEST_RESULT_CONNSTR }}" REPORT_FROM: $TEST_OUTPUT + working-directory: ${{ env.NEON_DIR }} run: $NEON_DIR/scripts/generate_and_push_perf_report.sh - - name: Docker cleanup - if: always() - run: docker compose down - - name: Notify Failure if: failure() - run: echo "Proxy bench job failed" && exit 1 \ No newline at end of file + run: echo "Proxy bench job failed" && exit 1 + + - name: Cleanup Test Resources + if: always() + shell: bash -euxo pipefail {0} + run: | + # Cleanup the test resources + if [[ -d "${TEST_OUTPUT}" ]]; then + rm -rf ${TEST_OUTPUT} + fi + if [[ -d "${PROXY_BENCH_PATH}/test_output" ]]; then + rm -rf ${PROXY_BENCH_PATH}/test_output + fi \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index b35fd7d074..065e7c5bd8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -250,11 +250,11 @@ dependencies = [ [[package]] name = "async-lock" -version = "3.2.0" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7125e42787d53db9dd54261812ef17e937c95a51e4d291373b670342fa44310c" +checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18" dependencies = [ - "event-listener 4.0.0", + "event-listener 5.4.0", "event-listener-strategy", "pin-project-lite", ] @@ -1483,6 +1483,7 @@ dependencies = [ "tower-http", "tower-otel", "tracing", + "tracing-appender", "tracing-opentelemetry", "tracing-subscriber", "tracing-utils", @@ -1498,9 +1499,9 @@ dependencies = [ [[package]] name = "concurrent-queue" -version = "2.3.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f057a694a54f12365049b0958a1685bb52d567f5593b355fbf685838e873d400" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" dependencies = [ "crossbeam-utils", ] @@ -2349,9 +2350,9 @@ checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" [[package]] name = "event-listener" -version = "4.0.0" +version = "5.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "770d968249b5d99410d61f5bf89057f3199a077a04d087092f58e7d10692baae" +checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae" dependencies = [ "concurrent-queue", "parking", @@ -2360,11 +2361,11 @@ dependencies = [ [[package]] name = "event-listener-strategy" -version = "0.4.0" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3" +checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" dependencies = [ - "event-listener 4.0.0", + "event-listener 5.4.0", "pin-project-lite", ] @@ -2639,6 +2640,20 @@ version = "0.4.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "304de19db7028420975a296ab0fcbbc8e69438c4ed254a1e41e2a7f37d5f0e0a" +[[package]] +name = "generator" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d18470a76cb7f8ff746cf1f7470914f900252ec36bbc40b569d74b1258446827" +dependencies = [ + "cc", + "cfg-if", + "libc", + "log", + "rustversion", + "windows 0.61.3", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -2966,7 +2981,7 @@ checksum = "f9c7c7c8ac16c798734b8a24560c1362120597c40d5e1459f09498f8f6c8f2ba" dependencies = [ "cfg-if", "libc", - "windows", + "windows 0.52.0", ] [[package]] @@ -3237,7 +3252,7 @@ dependencies = [ "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows-core", + "windows-core 0.52.0", ] [[package]] @@ -3794,6 +3809,19 @@ version = "0.4.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" +[[package]] +name = "loom" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" +dependencies = [ + "cfg-if", + "generator", + "scoped-tls", + "tracing", + "tracing-subscriber", +] + [[package]] name = "lru" version = "0.12.3" @@ -4010,6 +4038,25 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "moka" +version = "0.12.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9321642ca94a4282428e6ea4af8cc2ca4eac48ac7a6a4ea8f33f76d0ce70926" +dependencies = [ + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "loom", + "parking_lot 0.12.1", + "portable-atomic", + "rustc_version", + "smallvec", + "tagptr", + "thiserror 1.0.69", + "uuid", +] + [[package]] name = "multimap" version = "0.8.3" @@ -5179,7 +5226,6 @@ dependencies = [ "criterion", "env_logger", "log", - "memoffset 0.9.0", "once_cell", "postgres", "postgres_ffi_types", @@ -5532,7 +5578,6 @@ dependencies = [ "futures", "gettid", "hashbrown 0.14.5", - "hashlink", "hex", "hmac", "hostname", @@ -5554,6 +5599,7 @@ dependencies = [ "lasso", "measured", "metrics", + "moka", "once_cell", "opentelemetry", "ouroboros", @@ -5620,6 +5666,7 @@ dependencies = [ "workspace_hack", "x509-cert", "zerocopy 0.8.24", + "zeroize", ] [[package]] @@ -6577,6 +6624,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.1.0" @@ -7427,6 +7480,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "tar" version = "0.4.40" @@ -8093,11 +8152,12 @@ dependencies = [ [[package]] name = "tracing-appender" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09d48f71a791638519505cefafe162606f706c25592e4bde4d97600c0195312e" +checksum = "3566e8ce28cc0a3fe42519fc80e6b4c943cc4c8cef275620eb8dac2d3d4e06cf" dependencies = [ "crossbeam-channel", + "thiserror 1.0.69", "time", "tracing-subscriber", ] @@ -8818,10 +8878,32 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be" dependencies = [ - "windows-core", + "windows-core 0.52.0", "windows-targets 0.52.6", ] +[[package]] +name = "windows" +version = "0.61.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893" +dependencies = [ + "windows-collections", + "windows-core 0.61.2", + "windows-future", + "windows-link", + "windows-numerics", +] + +[[package]] +name = "windows-collections" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" +dependencies = [ + "windows-core 0.61.2", +] + [[package]] name = "windows-core" version = "0.52.0" @@ -8831,6 +8913,86 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-core" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-future" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" +dependencies = [ + "windows-core 0.61.2", + "windows-link", + "windows-threading", +] + +[[package]] +name = "windows-implement" +version = "0.60.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "windows-interface" +version = "0.59.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.100", +] + +[[package]] +name = "windows-link" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" + +[[package]] +name = "windows-numerics" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" +dependencies = [ + "windows-core 0.61.2", + "windows-link", +] + +[[package]] +name = "windows-result" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-sys" version = "0.48.0" @@ -8889,6 +9051,15 @@ dependencies = [ "windows_x86_64_msvc 0.52.6", ] +[[package]] +name = "windows-threading" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" +dependencies = [ + "windows-link", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.0" @@ -9025,6 +9196,8 @@ dependencies = [ "clap", "clap_builder", "const-oid", + "crossbeam-epoch", + "crossbeam-utils", "crypto-bigint 
0.5.5", "der 0.7.8", "deranged", @@ -9071,6 +9244,7 @@ dependencies = [ "once_cell", "p256 0.13.2", "parquet", + "portable-atomic", "prettyplease", "proc-macro2", "prost 0.13.5", diff --git a/Cargo.toml b/Cargo.toml index 1de261ed06..3744115ebf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,10 +46,10 @@ members = [ "libs/proxy/json", "libs/proxy/postgres-protocol2", "libs/proxy/postgres-types2", + "libs/proxy/subzero_core", "libs/proxy/tokio-postgres2", "endpoint_storage", "pgxn/neon/communicator", - "proxy/subzero_core", ] [workspace.package] @@ -135,7 +135,7 @@ lock_api = "0.4.13" md5 = "0.7.0" measured = { version = "0.0.22", features=["lasso"] } measured-process = { version = "0.0.22" } -memoffset = "0.9" +moka = { version = "0.12", features = ["sync"] } nix = { version = "0.30.1", features = ["dir", "fs", "mman", "process", "socket", "signal", "poll"] } # Do not update to >= 7.0.0, at least. The update will have a significant impact # on compute startup metrics (start_postgres_ms), >= 25% degradation. @@ -146,7 +146,7 @@ oid-registry = "0.7.1" once_cell = "1.13" opentelemetry = "0.30" opentelemetry_sdk = "0.30" -opentelemetry-otlp = { version = "0.30", default-features = false, features = ["http-proto", "trace", "http", "reqwest-client"] } +opentelemetry-otlp = { version = "0.30", default-features = false, features = ["http-proto", "trace", "http", "reqwest-blocking-client"] } opentelemetry-semantic-conventions = "0.30" parking_lot = "0.12" parquet = { version = "53", default-features = false, features = ["zstd"] } @@ -224,6 +224,7 @@ tracing-log = "0.2" tracing-opentelemetry = "0.31" tracing-serde = "0.2.0" tracing-subscriber = { version = "0.3", default-features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] } +tracing-appender = "0.2.3" try-lock = "0.2.5" test-log = { version = "0.2.17", default-features = false, features = ["log"] } twox-hash = { version = "1.6.3", default-features = false } @@ -234,10 +235,11 @@ uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] } walkdir = "2.3.2" rustls-native-certs = "0.8" whoami = "1.5.1" -zerocopy = { version = "0.8", features = ["derive", "simd"] } json-structural-diff = { version = "0.2.0" } x509-cert = { version = "0.2.5" } x509-parser = "0.16" +zerocopy = { version = "0.8", features = ["derive", "simd"] } +zeroize = "1.8" ## TODO replace this with tracing env_logger = "0.11" diff --git a/build-tools/Dockerfile b/build-tools/Dockerfile index b5fe642e6f..87966591c1 100644 --- a/build-tools/Dockerfile +++ b/build-tools/Dockerfile @@ -39,13 +39,13 @@ COPY build-tools/patches/pgcopydbv017.patch /pgcopydbv017.patch RUN if [ "${DEBIAN_VERSION}" = "bookworm" ]; then \ set -e && \ - apt update && \ - apt install -y --no-install-recommends \ + apt-get update && \ + apt-get install -y --no-install-recommends \ ca-certificates wget gpg && \ wget -qO - https://www.postgresql.org/media/keys/ACCC4CF8.asc | gpg --dearmor -o /usr/share/keyrings/postgresql-keyring.gpg && \ echo "deb [signed-by=/usr/share/keyrings/postgresql-keyring.gpg] http://apt.postgresql.org/pub/repos/apt bookworm-pgdg main" > /etc/apt/sources.list.d/pgdg.list && \ apt-get update && \ - apt install -y --no-install-recommends \ + apt-get install -y --no-install-recommends \ build-essential \ autotools-dev \ libedit-dev \ @@ -89,8 +89,7 @@ RUN useradd -ms /bin/bash nonroot -b /home # Use strict mode for bash to catch errors early SHELL ["/bin/bash", "-euo", "pipefail", "-c"] -RUN mkdir -p /pgcopydb/bin && \ - mkdir -p /pgcopydb/lib && \ 
+RUN mkdir -p /pgcopydb/{bin,lib} && \ chmod -R 755 /pgcopydb && \ chown -R nonroot:nonroot /pgcopydb @@ -106,8 +105,8 @@ RUN echo 'Acquire::Retries "5";' > /etc/apt/apt.conf.d/80-retries && \ # 'gdb' is included so that we get backtraces of core dumps produced in # regression tests RUN set -e \ - && apt update \ - && apt install -y \ + && apt-get update \ + && apt-get install -y --no-install-recommends \ autoconf \ automake \ bison \ @@ -183,22 +182,22 @@ RUN curl -sL "https://github.com/peak/s5cmd/releases/download/v${S5CMD_VERSION}/ ENV LLVM_VERSION=20 RUN curl -fsSL 'https://apt.llvm.org/llvm-snapshot.gpg.key' | apt-key add - \ && echo "deb http://apt.llvm.org/${DEBIAN_VERSION}/ llvm-toolchain-${DEBIAN_VERSION}-${LLVM_VERSION} main" > /etc/apt/sources.list.d/llvm.stable.list \ - && apt update \ - && apt install -y clang-${LLVM_VERSION} llvm-${LLVM_VERSION} \ + && apt-get update \ + && apt-get install -y --no-install-recommends clang-${LLVM_VERSION} llvm-${LLVM_VERSION} \ && bash -c 'for f in /usr/bin/clang*-${LLVM_VERSION} /usr/bin/llvm*-${LLVM_VERSION}; do ln -s "${f}" "${f%-${LLVM_VERSION}}"; done' \ && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* # Install node ENV NODE_VERSION=24 RUN curl -fsSL https://deb.nodesource.com/setup_${NODE_VERSION}.x | bash - \ - && apt install -y nodejs \ + && apt-get install -y --no-install-recommends nodejs \ && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* # Install docker RUN curl -fsSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /usr/share/keyrings/docker-archive-keyring.gpg \ && echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/debian ${DEBIAN_VERSION} stable" > /etc/apt/sources.list.d/docker.list \ - && apt update \ - && apt install -y docker-ce docker-ce-cli \ + && apt-get update \ + && apt-get install -y --no-install-recommends docker-ce docker-ce-cli \ && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* # Configure sudo & docker @@ -215,12 +214,11 @@ RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-$(uname -m).zip" -o "aws # Mold: A Modern Linker ENV MOLD_VERSION=v2.37.1 RUN set -e \ - && git clone https://github.com/rui314/mold.git \ + && git clone -b "${MOLD_VERSION}" --depth 1 https://github.com/rui314/mold.git \ && mkdir mold/build \ - && cd mold/build \ - && git checkout ${MOLD_VERSION} \ + && cd mold/build \ && cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_COMPILER=clang++ .. \ - && cmake --build . -j $(nproc) \ + && cmake --build . -j "$(nproc)" \ && cmake --install . \ && cd .. 
\ && rm -rf mold @@ -254,7 +252,7 @@ ENV ICU_VERSION=67.1 ENV ICU_PREFIX=/usr/local/icu # Download and build static ICU -RUN wget -O /tmp/libicu-${ICU_VERSION}.tgz https://github.com/unicode-org/icu/releases/download/release-${ICU_VERSION//./-}/icu4c-${ICU_VERSION//./_}-src.tgz && \ +RUN wget -O "/tmp/libicu-${ICU_VERSION}.tgz" https://github.com/unicode-org/icu/releases/download/release-${ICU_VERSION//./-}/icu4c-${ICU_VERSION//./_}-src.tgz && \ echo "94a80cd6f251a53bd2a997f6f1b5ac6653fe791dfab66e1eb0227740fb86d5dc /tmp/libicu-${ICU_VERSION}.tgz" | sha256sum --check && \ mkdir /tmp/icu && \ pushd /tmp/icu && \ @@ -265,8 +263,7 @@ RUN wget -O /tmp/libicu-${ICU_VERSION}.tgz https://github.com/unicode-org/icu/re make install && \ popd && \ rm -rf icu && \ - rm -f /tmp/libicu-${ICU_VERSION}.tgz && \ - popd + rm -f /tmp/libicu-${ICU_VERSION}.tgz # Switch to nonroot user USER nonroot:nonroot @@ -279,19 +276,19 @@ ENV PYTHON_VERSION=3.11.12 \ PYENV_ROOT=/home/nonroot/.pyenv \ PATH=/home/nonroot/.pyenv/shims:/home/nonroot/.pyenv/bin:/home/nonroot/.poetry/bin:$PATH RUN set -e \ - && cd $HOME \ + && cd "$HOME" \ && curl -sSO https://raw.githubusercontent.com/pyenv/pyenv-installer/master/bin/pyenv-installer \ && chmod +x pyenv-installer \ && ./pyenv-installer \ && export PYENV_ROOT=/home/nonroot/.pyenv \ && export PATH="$PYENV_ROOT/bin:$PATH" \ && export PATH="$PYENV_ROOT/shims:$PATH" \ - && pyenv install ${PYTHON_VERSION} \ - && pyenv global ${PYTHON_VERSION} \ + && pyenv install "${PYTHON_VERSION}" \ + && pyenv global "${PYTHON_VERSION}" \ && python --version \ - && pip install --upgrade pip \ + && pip install --no-cache-dir --upgrade pip \ && pip --version \ - && pip install pipenv wheel poetry + && pip install --no-cache-dir pipenv wheel poetry # Switch to nonroot user (again) USER nonroot:nonroot @@ -317,13 +314,13 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux . 
"$HOME/.cargo/env" && \ cargo --version && rustup --version && \ rustup component add llvm-tools rustfmt clippy && \ - cargo install rustfilt --locked --version ${RUSTFILT_VERSION} && \ - cargo install cargo-hakari --locked --version ${CARGO_HAKARI_VERSION} && \ - cargo install cargo-deny --locked --version ${CARGO_DENY_VERSION} && \ - cargo install cargo-hack --locked --version ${CARGO_HACK_VERSION} && \ - cargo install cargo-nextest --locked --version ${CARGO_NEXTEST_VERSION} && \ - cargo install cargo-chef --locked --version ${CARGO_CHEF_VERSION} && \ - cargo install diesel_cli --locked --version ${CARGO_DIESEL_CLI_VERSION} \ + cargo install rustfilt --locked --version "${RUSTFILT_VERSION}" && \ + cargo install cargo-hakari --locked --version "${CARGO_HAKARI_VERSION}" && \ + cargo install cargo-deny --locked --version "${CARGO_DENY_VERSION}" && \ + cargo install cargo-hack --locked --version "${CARGO_HACK_VERSION}" && \ + cargo install cargo-nextest --locked --version "${CARGO_NEXTEST_VERSION}" && \ + cargo install cargo-chef --locked --version "${CARGO_CHEF_VERSION}" && \ + cargo install diesel_cli --locked --version "${CARGO_DIESEL_CLI_VERSION}" \ --features postgres-bundled --no-default-features && \ rm -rf /home/nonroot/.cargo/registry && \ rm -rf /home/nonroot/.cargo/git diff --git a/build-tools/package-lock.json b/build-tools/package-lock.json index b2c44ed9b4..0d48345fd5 100644 --- a/build-tools/package-lock.json +++ b/build-tools/package-lock.json @@ -6,7 +6,7 @@ "": { "name": "build-tools", "devDependencies": { - "@redocly/cli": "1.34.4", + "@redocly/cli": "1.34.5", "@sourcemeta/jsonschema": "10.0.0" } }, @@ -472,9 +472,9 @@ } }, "node_modules/@redocly/cli": { - "version": "1.34.4", - "resolved": "https://registry.npmjs.org/@redocly/cli/-/cli-1.34.4.tgz", - "integrity": "sha512-seH/GgrjSB1EeOsgJ/4Ct6Jk2N7sh12POn/7G8UQFARMyUMJpe1oHtBwT2ndfp4EFCpgBAbZ/82Iw6dwczNxEA==", + "version": "1.34.5", + "resolved": "https://registry.npmjs.org/@redocly/cli/-/cli-1.34.5.tgz", + "integrity": "sha512-5IEwxs7SGP5KEXjBKLU8Ffdz9by/KqNSeBk6YUVQaGxMXK//uYlTJIPntgUXbo1KAGG2d2q2XF8y4iFz6qNeiw==", "dev": true, "license": "MIT", "dependencies": { @@ -484,14 +484,14 @@ "@opentelemetry/sdk-trace-node": "1.26.0", "@opentelemetry/semantic-conventions": "1.27.0", "@redocly/config": "^0.22.0", - "@redocly/openapi-core": "1.34.4", - "@redocly/respect-core": "1.34.4", + "@redocly/openapi-core": "1.34.5", + "@redocly/respect-core": "1.34.5", "abort-controller": "^3.0.0", "chokidar": "^3.5.1", "colorette": "^1.2.0", "core-js": "^3.32.1", "dotenv": "16.4.7", - "form-data": "^4.0.0", + "form-data": "^4.0.4", "get-port-please": "^3.0.1", "glob": "^7.1.6", "handlebars": "^4.7.6", @@ -522,9 +522,9 @@ "license": "MIT" }, "node_modules/@redocly/openapi-core": { - "version": "1.34.4", - "resolved": "https://registry.npmjs.org/@redocly/openapi-core/-/openapi-core-1.34.4.tgz", - "integrity": "sha512-hf53xEgpXIgWl3b275PgZU3OTpYh1RoD2LHdIfQ1JzBNTWsiNKczTEsI/4Tmh2N1oq9YcphhSMyk3lDh85oDjg==", + "version": "1.34.5", + "resolved": "https://registry.npmjs.org/@redocly/openapi-core/-/openapi-core-1.34.5.tgz", + "integrity": "sha512-0EbE8LRbkogtcCXU7liAyC00n9uNG9hJ+eMyHFdUsy9lB/WGqnEBgwjA9q2cyzAVcdTkQqTBBU1XePNnN3OijA==", "dev": true, "license": "MIT", "dependencies": { @@ -544,21 +544,21 @@ } }, "node_modules/@redocly/respect-core": { - "version": "1.34.4", - "resolved": "https://registry.npmjs.org/@redocly/respect-core/-/respect-core-1.34.4.tgz", - "integrity": 
"sha512-MitKyKyQpsizA4qCVv+MjXL4WltfhFQAoiKiAzrVR1Kusro3VhYb6yJuzoXjiJhR0ukLP5QOP19Vcs7qmj9dZg==", + "version": "1.34.5", + "resolved": "https://registry.npmjs.org/@redocly/respect-core/-/respect-core-1.34.5.tgz", + "integrity": "sha512-GheC/g/QFztPe9UA9LamooSplQuy9pe0Yr8XGTqkz0ahivLDl7svoy/LSQNn1QH3XGtLKwFYMfTwFR2TAYyh5Q==", "dev": true, "license": "MIT", "dependencies": { "@faker-js/faker": "^7.6.0", "@redocly/ajv": "8.11.2", - "@redocly/openapi-core": "1.34.4", + "@redocly/openapi-core": "1.34.5", "better-ajv-errors": "^1.2.0", "colorette": "^2.0.20", "concat-stream": "^2.0.0", "cookie": "^0.7.2", "dotenv": "16.4.7", - "form-data": "4.0.0", + "form-data": "^4.0.4", "jest-diff": "^29.3.1", "jest-matcher-utils": "^29.3.1", "js-yaml": "4.1.0", @@ -582,21 +582,6 @@ "dev": true, "license": "MIT" }, - "node_modules/@redocly/respect-core/node_modules/form-data": { - "version": "4.0.0", - "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.0.tgz", - "integrity": "sha512-ETEklSGi5t0QMZuiXoA/Q6vcnxcLQP5vdugSpuAyi6SVGi2clPPp+xgEhuMaHC+zGgn31Kd235W35f7Hykkaww==", - "dev": true, - "license": "MIT", - "dependencies": { - "asynckit": "^0.4.0", - "combined-stream": "^1.0.8", - "mime-types": "^2.1.12" - }, - "engines": { - "node": ">= 6" - } - }, "node_modules/@sinclair/typebox": { "version": "0.27.8", "resolved": "https://registry.npmjs.org/@sinclair/typebox/-/typebox-0.27.8.tgz", @@ -1345,9 +1330,9 @@ "license": "MIT" }, "node_modules/form-data": { - "version": "4.0.3", - "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.3.tgz", - "integrity": "sha512-qsITQPfmvMOSAdeyZ+12I1c+CKSstAFAwu+97zrnWAbIr5u8wfsExUzCesVLC8NgHuRUqNN4Zy6UPWUTRGslcA==", + "version": "4.0.4", + "resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.4.tgz", + "integrity": "sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==", "dev": true, "license": "MIT", "dependencies": { diff --git a/build-tools/package.json b/build-tools/package.json index 000969c672..2dc1359075 100644 --- a/build-tools/package.json +++ b/build-tools/package.json @@ -2,7 +2,7 @@ "name": "build-tools", "private": true, "devDependencies": { - "@redocly/cli": "1.34.4", + "@redocly/cli": "1.34.5", "@sourcemeta/jsonschema": "10.0.0" } } diff --git a/compute/vm-image-spec-bookworm.yaml b/compute/vm-image-spec-bookworm.yaml index 267e4c83b5..5f27b6bf9d 100644 --- a/compute/vm-image-spec-bookworm.yaml +++ b/compute/vm-image-spec-bookworm.yaml @@ -26,7 +26,13 @@ commands: - name: postgres-exporter user: nobody sysvInitAction: respawn - shell: 'DATA_SOURCE_NAME="user=cloud_admin sslmode=disable dbname=postgres application_name=postgres-exporter pgaudit.log=none" /bin/postgres_exporter --config.file=/etc/postgres_exporter.yml' + # Turn off database collector (`--no-collector.database`), we don't use `pg_database_size_bytes` metric anyway, see + # https://github.com/neondatabase/flux-fleet/blob/5e19b3fd897667b70d9a7ad4aa06df0ca22b49ff/apps/base/compute-metrics/scrape-compute-pg-exporter-neon.yaml#L29 + # but it's enabled by default and it doesn't filter out invalid databases, see + # https://github.com/prometheus-community/postgres_exporter/blob/06a553c8166512c9d9c5ccf257b0f9bba8751dbc/collector/pg_database.go#L67 + # so if it hits one, it starts spamming logs + # ERROR: [NEON_SMGR] [reqid d9700000018] could not read db size of db 705302 from page server at lsn 5/A2457EB0 + shell: 'DATA_SOURCE_NAME="user=cloud_admin sslmode=disable dbname=postgres application_name=postgres-exporter 
pgaudit.log=none" /bin/postgres_exporter --no-collector.database --config.file=/etc/postgres_exporter.yml' - name: pgbouncer-exporter user: postgres sysvInitAction: respawn diff --git a/compute/vm-image-spec-bullseye.yaml b/compute/vm-image-spec-bullseye.yaml index 2b6e77b656..cf26ace72a 100644 --- a/compute/vm-image-spec-bullseye.yaml +++ b/compute/vm-image-spec-bullseye.yaml @@ -26,7 +26,13 @@ commands: - name: postgres-exporter user: nobody sysvInitAction: respawn - shell: 'DATA_SOURCE_NAME="user=cloud_admin sslmode=disable dbname=postgres application_name=postgres-exporter pgaudit.log=none" /bin/postgres_exporter --config.file=/etc/postgres_exporter.yml' + # Turn off database collector (`--no-collector.database`), we don't use `pg_database_size_bytes` metric anyway, see + # https://github.com/neondatabase/flux-fleet/blob/5e19b3fd897667b70d9a7ad4aa06df0ca22b49ff/apps/base/compute-metrics/scrape-compute-pg-exporter-neon.yaml#L29 + # but it's enabled by default and it doesn't filter out invalid databases, see + # https://github.com/prometheus-community/postgres_exporter/blob/06a553c8166512c9d9c5ccf257b0f9bba8751dbc/collector/pg_database.go#L67 + # so if it hits one, it starts spamming logs + # ERROR: [NEON_SMGR] [reqid d9700000018] could not read db size of db 705302 from page server at lsn 5/A2457EB0 + shell: 'DATA_SOURCE_NAME="user=cloud_admin sslmode=disable dbname=postgres application_name=postgres-exporter pgaudit.log=none" /bin/postgres_exporter --no-collector.database --config.file=/etc/postgres_exporter.yml' - name: pgbouncer-exporter user: postgres sysvInitAction: respawn diff --git a/compute_tools/Cargo.toml b/compute_tools/Cargo.toml index 496471acc7..558760b0ad 100644 --- a/compute_tools/Cargo.toml +++ b/compute_tools/Cargo.toml @@ -62,6 +62,7 @@ tokio-stream.workspace = true tonic.workspace = true tower-otel.workspace = true tracing.workspace = true +tracing-appender.workspace = true tracing-opentelemetry.workspace = true tracing-subscriber.workspace = true tracing-utils.workspace = true diff --git a/compute_tools/README.md b/compute_tools/README.md index 49f1368f0e..e92e5920b9 100644 --- a/compute_tools/README.md +++ b/compute_tools/README.md @@ -52,8 +52,14 @@ stateDiagram-v2 Init --> Running : Started Postgres Running --> TerminationPendingFast : Requested termination Running --> TerminationPendingImmediate : Requested termination + Running --> ConfigurationPending : Received a /configure request with spec + Running --> RefreshConfigurationPending : Received a /refresh_configuration request, compute node will pull a new spec and reconfigure + RefreshConfigurationPending --> RefreshConfiguration: Received compute spec and started configuration + RefreshConfiguration --> Running : Compute has been re-configured + RefreshConfiguration --> RefreshConfigurationPending : Configuration failed and to be retried TerminationPendingFast --> Terminated compute with 30s delay for cplane to inspect status TerminationPendingImmediate --> Terminated : Terminated compute immediately + Failed --> RefreshConfigurationPending : Received a /refresh_configuration request Failed --> [*] : Compute exited Terminated --> [*] : Compute exited ``` diff --git a/compute_tools/src/bin/compute_ctl.rs b/compute_tools/src/bin/compute_ctl.rs index 04723d6f3d..9c86aba531 100644 --- a/compute_tools/src/bin/compute_ctl.rs +++ b/compute_tools/src/bin/compute_ctl.rs @@ -49,9 +49,10 @@ use compute_tools::compute::{ BUILD_TAG, ComputeNode, ComputeNodeParams, forward_termination_signal, }; use 
compute_tools::extension_server::get_pg_version_string; -use compute_tools::logger::*; use compute_tools::params::*; +use compute_tools::pg_isready::get_pg_isready_bin; use compute_tools::spec::*; +use compute_tools::{hadron_metrics, installed_extensions, logger::*}; use rlimit::{Resource, setrlimit}; use signal_hook::consts::{SIGINT, SIGQUIT, SIGTERM}; use signal_hook::iterator::Signals; @@ -194,11 +195,19 @@ fn main() -> Result<()> { .build()?; let _rt_guard = runtime.enter(); - let tracing_provider = init(cli.dev)?; + let mut log_dir = None; + if cli.lakebase_mode { + log_dir = std::env::var("COMPUTE_CTL_LOG_DIRECTORY").ok(); + } + + let (tracing_provider, _file_logs_guard) = init(cli.dev, log_dir)?; // enable core dumping for all child processes setrlimit(Resource::CORE, rlimit::INFINITY, rlimit::INFINITY)?; + installed_extensions::initialize_metrics(); + hadron_metrics::initialize_metrics(); + let connstr = Url::parse(&cli.connstr).context("cannot parse connstr as a URL")?; let config = get_config(&cli)?; @@ -226,7 +235,12 @@ fn main() -> Result<()> { cli.installed_extensions_collection_interval, )), pg_init_timeout: cli.pg_init_timeout.map(Duration::from_secs), + pg_isready_bin: get_pg_isready_bin(&cli.pgbin), + instance_id: std::env::var("INSTANCE_ID").ok(), lakebase_mode: cli.lakebase_mode, + build_tag: BUILD_TAG.to_string(), + control_plane_uri: cli.control_plane_uri, + config_path_test_only: cli.config, }, config, )?; @@ -238,8 +252,14 @@ fn main() -> Result<()> { deinit_and_exit(tracing_provider, exit_code); } -fn init(dev_mode: bool) -> Result> { - let provider = init_tracing_and_logging(DEFAULT_LOG_LEVEL)?; +fn init( + dev_mode: bool, + log_dir: Option, +) -> Result<( + Option, + Option, +)> { + let (provider, file_logs_guard) = init_tracing_and_logging(DEFAULT_LOG_LEVEL, &log_dir)?; let mut signals = Signals::new([SIGINT, SIGTERM, SIGQUIT])?; thread::spawn(move || { @@ -250,7 +270,7 @@ fn init(dev_mode: bool) -> Result> { info!("compute build_tag: {}", &BUILD_TAG.to_string()); - Ok(provider) + Ok((provider, file_logs_guard)) } fn get_config(cli: &Cli) -> Result { diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index b4d7a6fca9..f53adbb1df 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -21,6 +21,7 @@ use postgres::NoTls; use postgres::error::SqlState; use remote_storage::{DownloadError, RemotePath}; use std::collections::{HashMap, HashSet}; +use std::ffi::OsString; use std::os::unix::fs::{PermissionsExt, symlink}; use std::path::Path; use std::process::{Command, Stdio}; @@ -40,8 +41,9 @@ use utils::shard::{ShardCount, ShardIndex, ShardNumber}; use crate::configurator::launch_configurator; use crate::disk_quota::set_disk_quota; +use crate::hadron_metrics::COMPUTE_ATTACHED; use crate::installed_extensions::get_installed_extensions; -use crate::logger::startup_context_from_env; +use crate::logger::{self, startup_context_from_env}; use crate::lsn_lease::launch_lsn_lease_bg_task_for_static; use crate::metrics::COMPUTE_CTL_UP; use crate::monitor::launch_monitor; @@ -113,11 +115,17 @@ pub struct ComputeNodeParams { /// Interval for installed extensions collection pub installed_extensions_collection_interval: Arc, - + /// Hadron instance ID of the compute node. + pub instance_id: Option, /// Timeout of PG compute startup in the Init state. pub pg_init_timeout: Option, - + // Path to the `pg_isready` binary. 
+ pub pg_isready_bin: String, pub lakebase_mode: bool, + + pub build_tag: String, + pub control_plane_uri: Option, + pub config_path_test_only: Option, } type TaskHandle = Mutex>>; @@ -405,6 +413,52 @@ struct StartVmMonitorResult { vm_monitor: Option>>, } +/// Databricks-specific environment variables to be passed to the `postgres` sub-process. +pub struct DatabricksEnvVars { + /// The Databricks "endpoint ID" of the compute instance. Used by `postgres` to check + /// the token scopes of internal auth tokens. + pub endpoint_id: String, + /// Hostname of the Databricks workspace URL this compute instance belongs to. + /// Used by postgres to verify Databricks PAT tokens. + pub workspace_host: String, +} + +impl DatabricksEnvVars { + pub fn new(compute_spec: &ComputeSpec, compute_id: Option<&String>) -> Self { + // compute_id is a string format of "{endpoint_id}/{compute_idx}" + // endpoint_id is a uuid. We only need to pass down endpoint_id to postgres. + // Panics if compute_id is not set or not in the expected format. + let endpoint_id = compute_id.unwrap().split('/').next().unwrap().to_string(); + let workspace_host = compute_spec + .databricks_settings + .as_ref() + .map(|s| s.databricks_workspace_host.clone()) + .unwrap_or("".to_string()); + Self { + endpoint_id, + workspace_host, + } + } + + /// Constants for the names of Databricks-specific postgres environment variables. + const DATABRICKS_ENDPOINT_ID_ENVVAR: &'static str = "DATABRICKS_ENDPOINT_ID"; + const DATABRICKS_WORKSPACE_HOST_ENVVAR: &'static str = "DATABRICKS_WORKSPACE_HOST"; + + /// Convert DatabricksEnvVars to a list of string pairs that can be passed as env vars. Consumes `self`. + pub fn to_env_var_list(self) -> Vec<(String, String)> { + vec![ + ( + Self::DATABRICKS_ENDPOINT_ID_ENVVAR.to_string(), + self.endpoint_id.clone(), + ), + ( + Self::DATABRICKS_WORKSPACE_HOST_ENVVAR.to_string(), + self.workspace_host.clone(), + ), + ] + } +} + impl ComputeNode { pub fn new(params: ComputeNodeParams, config: ComputeConfig) -> Result { let connstr = params.connstr.as_str(); @@ -486,6 +540,7 @@ impl ComputeNode { port: this.params.external_http_port, config: this.compute_ctl_config.clone(), compute_id: this.params.compute_id.clone(), + instance_id: this.params.instance_id.clone(), } .launch(&this); @@ -1402,6 +1457,8 @@ impl ComputeNode { let pgdata_path = Path::new(&self.params.pgdata); let tls_config = self.tls_config(&pspec.spec); + let databricks_settings = spec.databricks_settings.as_ref(); + let postgres_port = self.params.connstr.port(); // Remove/create an empty pgdata directory and put configuration there. self.create_pgdata()?; @@ -1409,8 +1466,11 @@ impl ComputeNode { pgdata_path, &self.params, &pspec.spec, + postgres_port, self.params.internal_http_port, tls_config, + databricks_settings, + self.params.lakebase_mode, )?; // Syncing safekeepers is only safe with primary nodes: if a primary @@ -1450,8 +1510,28 @@ impl ComputeNode { ) })?; - // Update pg_hba.conf received with basebackup. - update_pg_hba(pgdata_path, None)?; + if let Some(settings) = databricks_settings { + copy_tls_certificates( + &settings.pg_compute_tls_settings.key_file, + &settings.pg_compute_tls_settings.cert_file, + pgdata_path, + )?; + + // Update pg_hba.conf received with basebackup including additional databricks settings. + update_pg_hba(pgdata_path, Some(&settings.databricks_pg_hba))?; + update_pg_ident(pgdata_path, Some(&settings.databricks_pg_ident))?; + } else { + // Update pg_hba.conf received with basebackup. 
+ update_pg_hba(pgdata_path, None)?; + } + + if let Some(databricks_settings) = spec.databricks_settings.as_ref() { + copy_tls_certificates( + &databricks_settings.pg_compute_tls_settings.key_file, + &databricks_settings.pg_compute_tls_settings.cert_file, + pgdata_path, + )?; + } // Place pg_dynshmem under /dev/shm. This allows us to use // 'dynamic_shared_memory_type = mmap' so that the files are placed in @@ -1564,14 +1644,31 @@ impl ComputeNode { pub fn start_postgres(&self, storage_auth_token: Option) -> Result { let pgdata_path = Path::new(&self.params.pgdata); + let env_vars: Vec<(String, String)> = if self.params.lakebase_mode { + let databricks_env_vars = { + let state = self.state.lock().unwrap(); + let spec = &state.pspec.as_ref().unwrap().spec; + DatabricksEnvVars::new(spec, Some(&self.params.compute_id)) + }; + + info!( + "Starting Postgres for databricks endpoint id: {}", + &databricks_env_vars.endpoint_id + ); + + let mut env_vars = databricks_env_vars.to_env_var_list(); + env_vars.extend(storage_auth_token.map(|t| ("NEON_AUTH_TOKEN".to_string(), t))); + env_vars + } else if let Some(storage_auth_token) = &storage_auth_token { + vec![("NEON_AUTH_TOKEN".to_owned(), storage_auth_token.to_owned())] + } else { + vec![] + }; + // Run postgres as a child process. let mut pg = maybe_cgexec(&self.params.pgbin) .args(["-D", &self.params.pgdata]) - .envs(if let Some(storage_auth_token) = &storage_auth_token { - vec![("NEON_AUTH_TOKEN", storage_auth_token)] - } else { - vec![] - }) + .envs(env_vars) .stderr(Stdio::piped()) .spawn() .expect("cannot start postgres process"); @@ -1785,6 +1882,34 @@ impl ComputeNode { Ok::<(), anyhow::Error>(()) } + // Signal to the configurator to refresh the configuration by pulling a new spec from the HCC. + // Note that this merely triggers a notification on a condition variable the configurator thread + // waits on. The configurator thread (in configurator.rs) pulls the new spec from the HCC and + // applies it. + pub async fn signal_refresh_configuration(&self) -> Result<()> { + let states_allowing_configuration_refresh = [ + ComputeStatus::Running, + ComputeStatus::Failed, + ComputeStatus::RefreshConfigurationPending, + ]; + + let mut state = self.state.lock().expect("state lock poisoned"); + if states_allowing_configuration_refresh.contains(&state.status) { + state.status = ComputeStatus::RefreshConfigurationPending; + self.state_changed.notify_all(); + Ok(()) + } else if state.status == ComputeStatus::Init { + // If the compute is in Init state, we can't refresh the configuration immediately, + // but we should be able to do that soon. + Ok(()) + } else { + Err(anyhow::anyhow!( + "Cannot refresh compute configuration in state {:?}", + state.status + )) + } + } + // Wrapped this around `pg_ctl reload`, but right now we don't use // `pg_ctl` for start / stop. 
#[instrument(skip_all)] @@ -1846,12 +1971,16 @@ impl ComputeNode { // Write new config let pgdata_path = Path::new(&self.params.pgdata); + let postgres_port = self.params.connstr.port(); config::write_postgres_conf( pgdata_path, &self.params, &spec, + postgres_port, self.params.internal_http_port, tls_config, + spec.databricks_settings.as_ref(), + self.params.lakebase_mode, )?; self.pg_reload_conf()?; @@ -1957,6 +2086,8 @@ impl ComputeNode { // wait ComputeStatus::Init | ComputeStatus::Configuration + | ComputeStatus::RefreshConfiguration + | ComputeStatus::RefreshConfigurationPending | ComputeStatus::Empty => { state = self.state_changed.wait(state).unwrap(); } @@ -2513,6 +2644,34 @@ LIMIT 100", ); } } + + /// Set the compute spec and update related metrics. + /// This is the central place where pspec is updated. + pub fn set_spec(params: &ComputeNodeParams, state: &mut ComputeState, pspec: ParsedSpec) { + state.pspec = Some(pspec); + ComputeNode::update_attached_metric(params, state); + let _ = logger::update_ids(¶ms.instance_id, &Some(params.compute_id.clone())); + } + + pub fn update_attached_metric(params: &ComputeNodeParams, state: &mut ComputeState) { + // Update the pg_cctl_attached gauge when all identifiers are available. + if let Some(instance_id) = ¶ms.instance_id { + if let Some(pspec) = &state.pspec { + // Clear all values in the metric + COMPUTE_ATTACHED.reset(); + + // Set new metric value + COMPUTE_ATTACHED + .with_label_values(&[ + ¶ms.compute_id, + instance_id, + &pspec.tenant_id.to_string(), + &pspec.timeline_id.to_string(), + ]) + .set(1); + } + } + } } pub async fn installed_extensions(conf: tokio_postgres::Config) -> Result<()> { diff --git a/compute_tools/src/compute_prewarm.rs b/compute_tools/src/compute_prewarm.rs index 07b4a596cc..97e62c1c80 100644 --- a/compute_tools/src/compute_prewarm.rs +++ b/compute_tools/src/compute_prewarm.rs @@ -90,6 +90,7 @@ impl ComputeNode { } /// If there is a prewarm request ongoing, return `false`, `true` otherwise. 
+ /// Has a failpoint "compute-prewarm" pub fn prewarm_lfc(self: &Arc, from_endpoint: Option) -> bool { { let state = &mut self.state.lock().unwrap().lfc_prewarm_state; @@ -112,9 +113,8 @@ impl ComputeNode { Err(err) => { crate::metrics::LFC_PREWARM_ERRORS.inc(); error!(%err, "could not prewarm LFC"); - LfcPrewarmState::Failed { - error: err.to_string(), + error: format!("{err:#}"), } } }; @@ -135,16 +135,20 @@ impl ComputeNode { async fn prewarm_impl(&self, from_endpoint: Option) -> Result { let EndpointStoragePair { url, token } = self.endpoint_storage_pair(from_endpoint)?; + #[cfg(feature = "testing")] + fail::fail_point!("compute-prewarm", |_| { + bail!("prewarm configured to fail because of a failpoint") + }); + info!(%url, "requesting LFC state from endpoint storage"); let request = Client::new().get(&url).bearer_auth(token); let res = request.send().await.context("querying endpoint storage")?; - let status = res.status(); - match status { + match res.status() { StatusCode::OK => (), StatusCode::NOT_FOUND => { return Ok(false); } - _ => bail!("{status} querying endpoint storage"), + status => bail!("{status} querying endpoint storage"), } let mut uncompressed = Vec::new(); @@ -205,7 +209,7 @@ impl ComputeNode { crate::metrics::LFC_OFFLOAD_ERRORS.inc(); error!(%err, "could not offload LFC state to endpoint storage"); self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Failed { - error: err.to_string(), + error: format!("{err:#}"), }; } @@ -213,16 +217,22 @@ impl ComputeNode { let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?; info!(%url, "requesting LFC state from Postgres"); - let mut compressed = Vec::new(); - ComputeNode::get_maintenance_client(&self.tokio_conn_conf) + let row = ComputeNode::get_maintenance_client(&self.tokio_conn_conf) .await .context("connecting to postgres")? .query_one("select neon.get_local_cache_state()", &[]) .await - .context("querying LFC state")? - .try_get::(0) - .context("deserializing LFC state") - .map(ZstdEncoder::new)? + .context("querying LFC state")?; + let state = row + .try_get::>(0) + .context("deserializing LFC state")?; + let Some(state) = state else { + info!(%url, "empty LFC state, not exporting"); + return Ok(()); + }; + + let mut compressed = Vec::new(); + ZstdEncoder::new(state) .read_to_end(&mut compressed) .await .context("compressing LFC state")?; diff --git a/compute_tools/src/compute_promote.rs b/compute_tools/src/compute_promote.rs index 42256faa22..a34368c531 100644 --- a/compute_tools/src/compute_promote.rs +++ b/compute_tools/src/compute_promote.rs @@ -1,11 +1,12 @@ use crate::compute::ComputeNode; use anyhow::{Context, Result, bail}; -use compute_api::{ - responses::{LfcPrewarmState, PromoteState, SafekeepersLsn}, - spec::ComputeMode, -}; +use compute_api::responses::{LfcPrewarmState, PromoteConfig, PromoteState}; +use compute_api::spec::ComputeMode; +use itertools::Itertools; +use std::collections::HashMap; use std::{sync::Arc, time::Duration}; use tokio::time::sleep; +use tracing::info; use utils::lsn::Lsn; impl ComputeNode { @@ -13,21 +14,22 @@ impl ComputeNode { /// and http client disconnects, this does not stop promotion, and subsequent /// calls block until promote finishes. 
/// Called by control plane on secondary after primary endpoint is terminated - pub async fn promote(self: &Arc, safekeepers_lsn: SafekeepersLsn) -> PromoteState { + /// Has a failpoint "compute-promotion" + pub async fn promote(self: &Arc, cfg: PromoteConfig) -> PromoteState { let cloned = self.clone(); + let promote_fn = async move || { + let Err(err) = cloned.promote_impl(cfg).await else { + return PromoteState::Completed; + }; + tracing::error!(%err, "promoting"); + PromoteState::Failed { + error: format!("{err:#}"), + } + }; + let start_promotion = || { let (tx, rx) = tokio::sync::watch::channel(PromoteState::NotPromoted); - tokio::spawn(async move { - tx.send(match cloned.promote_impl(safekeepers_lsn).await { - Ok(_) => PromoteState::Completed, - Err(err) => { - tracing::error!(%err, "promoting"); - PromoteState::Failed { - error: err.to_string(), - } - } - }) - }); + tokio::spawn(async move { tx.send(promote_fn().await) }); rx }; @@ -47,9 +49,7 @@ impl ComputeNode { task.borrow().clone() } - // Why do we have to supply safekeepers? - // For secondary we use primary_connection_conninfo so safekeepers field is empty - async fn promote_impl(&self, safekeepers_lsn: SafekeepersLsn) -> Result<()> { + async fn promote_impl(&self, mut cfg: PromoteConfig) -> Result<()> { { let state = self.state.lock().unwrap(); let mode = &state.pspec.as_ref().unwrap().spec.mode; @@ -73,7 +73,7 @@ impl ComputeNode { .await .context("connecting to postgres")?; - let primary_lsn = safekeepers_lsn.wal_flush_lsn; + let primary_lsn = cfg.wal_flush_lsn; let mut last_wal_replay_lsn: Lsn = Lsn::INVALID; const RETRIES: i32 = 20; for i in 0..=RETRIES { @@ -86,7 +86,7 @@ impl ComputeNode { if last_wal_replay_lsn >= primary_lsn { break; } - tracing::info!("Try {i}, replica lsn {last_wal_replay_lsn}, primary lsn {primary_lsn}"); + info!("Try {i}, replica lsn {last_wal_replay_lsn}, primary lsn {primary_lsn}"); sleep(Duration::from_secs(1)).await; } if last_wal_replay_lsn < primary_lsn { @@ -96,7 +96,7 @@ impl ComputeNode { // using $1 doesn't work with ALTER SYSTEM SET let safekeepers_sql = format!( "ALTER SYSTEM SET neon.safekeepers='{}'", - safekeepers_lsn.safekeepers + cfg.spec.safekeeper_connstrings.join(",") ); client .query(&safekeepers_sql, &[]) @@ -106,6 +106,12 @@ impl ComputeNode { .query("SELECT pg_reload_conf()", &[]) .await .context("reloading postgres config")?; + + #[cfg(feature = "testing")] + fail::fail_point!("compute-promotion", |_| { + bail!("promotion configured to fail because of a failpoint") + }); + let row = client .query_one("SELECT * FROM pg_promote()", &[]) .await @@ -125,8 +131,36 @@ impl ComputeNode { bail!("replica in read only mode after promotion"); } - let mut state = self.state.lock().unwrap(); - state.pspec.as_mut().unwrap().spec.mode = ComputeMode::Primary; - Ok(()) + { + let mut state = self.state.lock().unwrap(); + let spec = &mut state.pspec.as_mut().unwrap().spec; + spec.mode = ComputeMode::Primary; + let new_conf = cfg.spec.cluster.postgresql_conf.as_mut().unwrap(); + let existing_conf = spec.cluster.postgresql_conf.as_ref().unwrap(); + Self::merge_spec(new_conf, existing_conf); + } + info!("applied new spec, reconfiguring as primary"); + self.reconfigure() + } + + /// Merge old and new Postgres conf specs to apply on secondary. 
+    /// Change new spec's port and safekeepers since they are supplied
+    /// differently
+    fn merge_spec(new_conf: &mut String, existing_conf: &str) {
+        let mut new_conf_set: HashMap<&str, &str> = new_conf
+            .split_terminator('\n')
+            .map(|e| e.split_once("=").expect("invalid item"))
+            .collect();
+        new_conf_set.remove("neon.safekeepers");
+
+        let existing_conf_set: HashMap<&str, &str> = existing_conf
+            .split_terminator('\n')
+            .map(|e| e.split_once("=").expect("invalid item"))
+            .collect();
+        new_conf_set.insert("port", existing_conf_set["port"]);
+        *new_conf = new_conf_set
+            .iter()
+            .map(|(k, v)| format!("{k}={v}"))
+            .join("\n");
+    }
 }
diff --git a/compute_tools/src/config.rs b/compute_tools/src/config.rs
index dd46353343..55a1eda0b7 100644
--- a/compute_tools/src/config.rs
+++ b/compute_tools/src/config.rs
@@ -7,11 +7,14 @@ use std::io::prelude::*;
 use std::path::Path;
 
 use compute_api::responses::TlsConfig;
-use compute_api::spec::{ComputeAudit, ComputeMode, ComputeSpec, GenericOption};
+use compute_api::spec::{
+    ComputeAudit, ComputeMode, ComputeSpec, DatabricksSettings, GenericOption,
+};
 
 use crate::compute::ComputeNodeParams;
 use crate::pg_helpers::{
-    GenericOptionExt, GenericOptionsSearch, PgOptionsSerialize, escape_conf_value,
+    DatabricksSettingsExt as _, GenericOptionExt, GenericOptionsSearch, PgOptionsSerialize,
+    escape_conf_value,
 };
 use crate::tls::{self, SERVER_CRT, SERVER_KEY};
 
@@ -40,12 +43,16 @@ pub fn line_in_file(path: &Path, line: &str) -> Result<bool> {
 }
 
 /// Create or completely rewrite configuration file specified by `path`
+#[allow(clippy::too_many_arguments)]
 pub fn write_postgres_conf(
     pgdata_path: &Path,
     params: &ComputeNodeParams,
     spec: &ComputeSpec,
+    postgres_port: Option<u16>,
     extension_server_port: u16,
     tls_config: &Option<TlsConfig>,
+    databricks_settings: Option<&DatabricksSettings>,
+    lakebase_mode: bool,
 ) -> Result<()> {
     let path = pgdata_path.join("postgresql.conf");
     // File::create() destroys the file content if it exists.
@@ -285,6 +292,24 @@ pub fn write_postgres_conf(
         writeln!(file, "log_destination='stderr,syslog'")?;
     }
 
+    if lakebase_mode {
+        // Explicitly set the port based on the connstr, overriding any previous port setting.
+        // Note: It is important that we don't specify a different port again after this.
+        let port = postgres_port.expect("port must be present in connstr");
+        writeln!(file, "port = {port}")?;
+
+        // These are Databricks-specific settings.
+        // They should be at the end of the file but before `compute_ctl_temp_override.conf` below
+        // so that they can override any settings above.
+        // `compute_ctl_temp_override.conf` is intended to override any settings above during specific operations.
+        // To prevent potential breakage in the future, we keep it above `compute_ctl_temp_override.conf`.
+        writeln!(file, "# Databricks settings start")?;
+        if let Some(settings) = databricks_settings {
+            writeln!(file, "{}", settings.as_pg_settings())?;
+        }
+        writeln!(file, "# Databricks settings end")?;
+    }
+
     // This is essential to keep this line at the end of the file,
     // because it is intended to override any settings above.
writeln!(file, "include_if_exists = 'compute_ctl_temp_override.conf'")?; diff --git a/compute_tools/src/configurator.rs b/compute_tools/src/configurator.rs index d97bd37285..feca8337b2 100644 --- a/compute_tools/src/configurator.rs +++ b/compute_tools/src/configurator.rs @@ -1,23 +1,40 @@ -use std::sync::Arc; +use std::fs::File; use std::thread; +use std::{path::Path, sync::Arc}; -use compute_api::responses::ComputeStatus; +use anyhow::Result; +use compute_api::responses::{ComputeConfig, ComputeStatus}; use tracing::{error, info, instrument}; -use crate::compute::ComputeNode; +use crate::compute::{ComputeNode, ParsedSpec}; +use crate::spec::get_config_from_control_plane; #[instrument(skip_all)] fn configurator_main_loop(compute: &Arc) { info!("waiting for reconfiguration requests"); loop { let mut state = compute.state.lock().unwrap(); + /* BEGIN_HADRON */ + // RefreshConfiguration should only be used inside the loop + assert_ne!(state.status, ComputeStatus::RefreshConfiguration); + /* END_HADRON */ - // We have to re-check the status after re-acquiring the lock because it could be that - // the status has changed while we were waiting for the lock, and we might not need to - // wait on the condition variable. Otherwise, we might end up in some soft-/deadlock, i.e. - // we are waiting for a condition variable that will never be signaled. - if state.status != ComputeStatus::ConfigurationPending { - state = compute.state_changed.wait(state).unwrap(); + if compute.params.lakebase_mode { + while state.status != ComputeStatus::ConfigurationPending + && state.status != ComputeStatus::RefreshConfigurationPending + && state.status != ComputeStatus::Failed + { + info!("configurator: compute status: {:?}, sleeping", state.status); + state = compute.state_changed.wait(state).unwrap(); + } + } else { + // We have to re-check the status after re-acquiring the lock because it could be that + // the status has changed while we were waiting for the lock, and we might not need to + // wait on the condition variable. Otherwise, we might end up in some soft-/deadlock, i.e. + // we are waiting for a condition variable that will never be signaled. + if state.status != ComputeStatus::ConfigurationPending { + state = compute.state_changed.wait(state).unwrap(); + } } // Re-check the status after waking up @@ -37,6 +54,133 @@ fn configurator_main_loop(compute: &Arc) { // XXX: used to test that API is blocking // std::thread::sleep(std::time::Duration::from_millis(10000)); + compute.set_status(new_status); + } else if state.status == ComputeStatus::RefreshConfigurationPending { + info!( + "compute node suspects its configuration is out of date, now refreshing configuration" + ); + state.set_status(ComputeStatus::RefreshConfiguration, &compute.state_changed); + // Drop the lock guard here to avoid holding the lock while downloading config from the control plane / HCC. + // This is the only thread that can move compute_ctl out of the `RefreshConfiguration` state, so it + // is safe to drop the lock like this. + drop(state); + + let get_config_result: anyhow::Result = + if let Some(config_path) = &compute.params.config_path_test_only { + // This path is only to make testing easier. In production we always get the config from the HCC. 
+ info!( + "reloading config.json from path: {}", + config_path.to_string_lossy() + ); + let path = Path::new(config_path); + if let Ok(file) = File::open(path) { + match serde_json::from_reader::(file) { + Ok(config) => Ok(config), + Err(e) => { + error!("could not parse config file: {}", e); + Err(anyhow::anyhow!("could not parse config file: {}", e)) + } + } + } else { + error!( + "could not open config file at path: {:?}", + config_path.to_string_lossy() + ); + Err(anyhow::anyhow!( + "could not open config file at path: {}", + config_path.to_string_lossy() + )) + } + } else if let Some(control_plane_uri) = &compute.params.control_plane_uri { + get_config_from_control_plane(control_plane_uri, &compute.params.compute_id) + } else { + Err(anyhow::anyhow!("config_path_test_only is not set")) + }; + + // Parse any received ComputeSpec and transpose the result into a Result>. + let parsed_spec_result: Result> = + get_config_result.and_then(|config| { + if let Some(spec) = config.spec { + if let Ok(pspec) = ParsedSpec::try_from(spec) { + Ok(Some(pspec)) + } else { + Err(anyhow::anyhow!("could not parse spec")) + } + } else { + Ok(None) + } + }); + + let new_status: ComputeStatus; + match parsed_spec_result { + // Control plane (HCM) returned a spec and we were able to parse it. + Ok(Some(pspec)) => { + { + let mut state = compute.state.lock().unwrap(); + // Defensive programming to make sure this thread is indeed the only one that can move the compute + // node out of the `RefreshConfiguration` state. Would be nice if we can encode this invariant + // into the type system. + assert_eq!(state.status, ComputeStatus::RefreshConfiguration); + + if state.pspec.as_ref().map(|ps| ps.pageserver_connstr.clone()) + == Some(pspec.pageserver_connstr.clone()) + { + info!( + "Refresh configuration: Retrieved spec is the same as the current spec. Waiting for control plane to update the spec before attempting reconfiguration." + ); + state.status = ComputeStatus::Running; + compute.state_changed.notify_all(); + drop(state); + std::thread::sleep(std::time::Duration::from_secs(5)); + continue; + } + // state.pspec is consumed by compute.reconfigure() below. Note that compute.reconfigure() will acquire + // the compute.state lock again so we need to have the lock guard go out of scope here. We could add a + // "locked" variant of compute.reconfigure() that takes the lock guard as an argument to make this cleaner, + // but it's not worth forking the codebase too much for this minor point alone right now. + state.pspec = Some(pspec); + } + match compute.reconfigure() { + Ok(_) => { + info!("Refresh configuration: compute node configured"); + new_status = ComputeStatus::Running; + } + Err(e) => { + error!( + "Refresh configuration: could not configure compute node: {}", + e + ); + // Set the compute node back to the `RefreshConfigurationPending` state if the configuration + // was not successful. It should be okay to treat this situation the same as if the loop + // hasn't executed yet as long as the detection side keeps notifying. + new_status = ComputeStatus::RefreshConfigurationPending; + } + } + } + // Control plane (HCM)'s response does not contain a spec. This is the "Empty" attachment case. + Ok(None) => { + info!( + "Compute Manager signaled that this compute is no longer attached to any storage. Exiting." + ); + // We just immediately terminate the whole compute_ctl in this case. 
It's not necessary to attempt a + // clean shutdown as Postgres is probably not responding anyway (which is why we are in this refresh + // configuration state). + std::process::exit(1); + } + // Various error cases: + // - The request to the control plane (HCM) either failed or returned a malformed spec. + // - compute_ctl itself is configured incorrectly (e.g., compute_id is not set). + Err(e) => { + error!( + "Refresh configuration: error getting a parsed spec: {:?}", + e + ); + new_status = ComputeStatus::RefreshConfigurationPending; + // We may be dealing with an overloaded HCM if we end up in this path. Backoff 5 seconds before + // retrying to avoid hammering the HCM. + std::thread::sleep(std::time::Duration::from_secs(5)); + } + } compute.set_status(new_status); } else if state.status == ComputeStatus::Failed { info!("compute node is now in Failed state, exiting"); diff --git a/compute_tools/src/http/middleware/authorize.rs b/compute_tools/src/http/middleware/authorize.rs index 1b0bf4d9c5..407833bb0e 100644 --- a/compute_tools/src/http/middleware/authorize.rs +++ b/compute_tools/src/http/middleware/authorize.rs @@ -16,12 +16,16 @@ use crate::http::JsonResponse; #[derive(Clone, Debug)] pub(in crate::http) struct Authorize { compute_id: String, + // BEGIN HADRON + // Hadron instance ID. Only set if it's a Lakebase V1 a.k.a. Hadron instance. + instance_id: Option, + // END HADRON jwks: JwkSet, validation: Validation, } impl Authorize { - pub fn new(compute_id: String, jwks: JwkSet) -> Self { + pub fn new(compute_id: String, instance_id: Option, jwks: JwkSet) -> Self { let mut validation = Validation::new(Algorithm::EdDSA); // BEGIN HADRON @@ -46,6 +50,7 @@ impl Authorize { Self { compute_id, + instance_id, jwks, validation, } @@ -59,10 +64,20 @@ impl AsyncAuthorizeRequest for Authorize { fn authorize(&mut self, mut request: Request) -> Self::Future { let compute_id = self.compute_id.clone(); + let is_hadron_instance = self.instance_id.is_some(); let jwks = self.jwks.clone(); let validation = self.validation.clone(); Box::pin(async move { + // BEGIN HADRON + // In Hadron deployments the "external" HTTP endpoint on compute_ctl can only be + // accessed by trusted components (enforced by dblet network policy), so we can bypass + // all auth here. + if is_hadron_instance { + return Ok(request); + } + // END HADRON + let TypedHeader(Authorization(bearer)) = request .extract_parts::>>() .await diff --git a/compute_tools/src/http/openapi_spec.yaml b/compute_tools/src/http/openapi_spec.yaml index 3cf5ea7c51..ab729d62b5 100644 --- a/compute_tools/src/http/openapi_spec.yaml +++ b/compute_tools/src/http/openapi_spec.yaml @@ -96,7 +96,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/SafekeepersLsn" + $ref: "#/components/schemas/ComputeSchemaWithLsn" responses: 200: description: Promote succeeded or wasn't started @@ -297,14 +297,7 @@ paths: content: application/json: schema: - type: object - required: - - spec - properties: - spec: - # XXX: I don't want to explain current spec in the OpenAPI format, - # as it could be changed really soon. Consider doing it later. - type: object + $ref: "#/components/schemas/ComputeSchema" responses: 200: description: Compute configuration finished. 
@@ -591,18 +584,25 @@ components: type: string example: "1.0.0" - SafekeepersLsn: + ComputeSchema: type: object required: - - safekeepers + - spec + properties: + spec: + type: object + ComputeSchemaWithLsn: + type: object + required: + - spec - wal_flush_lsn properties: - safekeepers: - description: Primary replica safekeepers - type: string + spec: + $ref: "#/components/schemas/ComputeState" wal_flush_lsn: - description: Primary last WAL flush LSN type: string + description: "last WAL flush LSN" + example: "0/028F10D8" LfcPrewarmState: type: object diff --git a/compute_tools/src/http/routes/configure.rs b/compute_tools/src/http/routes/configure.rs index b7325d283f..943ff45357 100644 --- a/compute_tools/src/http/routes/configure.rs +++ b/compute_tools/src/http/routes/configure.rs @@ -43,7 +43,12 @@ pub(in crate::http) async fn configure( // configure request for tracing purposes. state.startup_span = Some(tracing::Span::current()); - state.pspec = Some(pspec); + if compute.params.lakebase_mode { + ComputeNode::set_spec(&compute.params, &mut state, pspec); + } else { + state.pspec = Some(pspec); + } + state.set_status(ComputeStatus::ConfigurationPending, &compute.state_changed); drop(state); } diff --git a/compute_tools/src/http/routes/hadron_liveness_probe.rs b/compute_tools/src/http/routes/hadron_liveness_probe.rs new file mode 100644 index 0000000000..4f66b6b139 --- /dev/null +++ b/compute_tools/src/http/routes/hadron_liveness_probe.rs @@ -0,0 +1,34 @@ +use crate::pg_isready::pg_isready; +use crate::{compute::ComputeNode, http::JsonResponse}; +use axum::{extract::State, http::StatusCode, response::Response}; +use std::sync::Arc; + +/// NOTE: NOT ENABLED YET +/// Detect if the compute is alive. +/// Called by the liveness probe of the compute container. +pub(in crate::http) async fn hadron_liveness_probe( + State(compute): State>, +) -> Response { + let port = match compute.params.connstr.port() { + Some(port) => port, + None => { + return JsonResponse::error( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to get the port from the connection string", + ); + } + }; + match pg_isready(&compute.params.pg_isready_bin, port) { + Ok(_) => { + // The connection is successful, so the compute is alive. + // Return a 200 OK response. + JsonResponse::success(StatusCode::OK, "ok") + } + Err(e) => { + tracing::error!("Hadron liveness probe failed: {}", e); + // The connection failed, so the compute is not alive. + // Return a 500 Internal Server Error response. + JsonResponse::error(StatusCode::INTERNAL_SERVER_ERROR, e) + } + } +} diff --git a/compute_tools/src/http/routes/metrics.rs b/compute_tools/src/http/routes/metrics.rs index 96b464fd12..8406746327 100644 --- a/compute_tools/src/http/routes/metrics.rs +++ b/compute_tools/src/http/routes/metrics.rs @@ -13,6 +13,7 @@ use metrics::{Encoder, TextEncoder}; use crate::communicator_socket_client::connect_communicator_socket; use crate::compute::ComputeNode; +use crate::hadron_metrics; use crate::http::JsonResponse; use crate::metrics::collect; @@ -21,11 +22,18 @@ pub(in crate::http) async fn get_metrics() -> Response { // When we call TextEncoder::encode() below, it will immediately return an // error if a metric family has no metrics, so we need to preemptively // filter out metric families with no metrics. - let metrics = collect() + let mut metrics = collect() .into_iter() .filter(|m| !m.get_metric().is_empty()) .collect::>(); + // Add Hadron metrics. 
+    let hadron_metrics: Vec = hadron_metrics::collect()
+        .into_iter()
+        .filter(|m| !m.get_metric().is_empty())
+        .collect();
+    metrics.extend(hadron_metrics);
+
     let encoder = TextEncoder::new();
     let mut buffer = vec![];
diff --git a/compute_tools/src/http/routes/mod.rs b/compute_tools/src/http/routes/mod.rs
index dd71f663eb..c0f68701c6 100644
--- a/compute_tools/src/http/routes/mod.rs
+++ b/compute_tools/src/http/routes/mod.rs
@@ -10,11 +10,13 @@ pub(in crate::http) mod extension_server;
 pub(in crate::http) mod extensions;
 pub(in crate::http) mod failpoints;
 pub(in crate::http) mod grants;
+pub(in crate::http) mod hadron_liveness_probe;
 pub(in crate::http) mod insights;
 pub(in crate::http) mod lfc;
 pub(in crate::http) mod metrics;
 pub(in crate::http) mod metrics_json;
 pub(in crate::http) mod promote;
+pub(in crate::http) mod refresh_configuration;
 pub(in crate::http) mod status;
 pub(in crate::http) mod terminate;
 
diff --git a/compute_tools/src/http/routes/promote.rs b/compute_tools/src/http/routes/promote.rs
index bc5f93b4da..7ca3464b63 100644
--- a/compute_tools/src/http/routes/promote.rs
+++ b/compute_tools/src/http/routes/promote.rs
@@ -1,14 +1,14 @@
 use crate::http::JsonResponse;
-use axum::Form;
+use axum::extract::Json;
 use http::StatusCode;
 
 pub(in crate::http) async fn promote(
     compute: axum::extract::State>,
-    Form(safekeepers_lsn): Form,
+    Json(cfg): Json,
 ) -> axum::response::Response {
-    let state = compute.promote(safekeepers_lsn).await;
-    if let compute_api::responses::PromoteState::Failed { error } = state {
-        return JsonResponse::error(StatusCode::INTERNAL_SERVER_ERROR, error);
+    let state = compute.promote(cfg).await;
+    if let compute_api::responses::PromoteState::Failed { error: _ } = state {
+        return JsonResponse::create_response(StatusCode::INTERNAL_SERVER_ERROR, state);
     }
     JsonResponse::success(StatusCode::OK, state)
 }
diff --git a/compute_tools/src/http/routes/refresh_configuration.rs b/compute_tools/src/http/routes/refresh_configuration.rs
new file mode 100644
index 0000000000..9b2f95ca5a
--- /dev/null
+++ b/compute_tools/src/http/routes/refresh_configuration.rs
@@ -0,0 +1,29 @@
+// This file is added by Hadron
+
+use std::sync::Arc;
+
+use axum::{
+    extract::State,
+    response::{IntoResponse, Response},
+};
+use http::StatusCode;
+
+use crate::compute::ComputeNode;
+use crate::hadron_metrics::POSTGRES_PAGESTREAM_REQUEST_ERRORS;
+use crate::http::JsonResponse;
+
+/// The /refresh_configuration POST method is used to nudge compute_ctl to pull a new spec
+/// from the HCC and attempt to reconfigure Postgres with the new spec. The method does not wait
+/// for the reconfiguration to complete. Rather, it simply delivers a signal that will cause
+/// configuration to be reloaded in a best-effort manner. Invocation of this method does not
+/// guarantee that a reconfiguration will occur. The caller should keep sending this
+/// request while it believes that the compute configuration is out of date.
+pub(in crate::http) async fn refresh_configuration( + State(compute): State>, +) -> Response { + POSTGRES_PAGESTREAM_REQUEST_ERRORS.inc(); + match compute.signal_refresh_configuration().await { + Ok(_) => StatusCode::OK.into_response(), + Err(e) => JsonResponse::error(StatusCode::INTERNAL_SERVER_ERROR, e), + } +} diff --git a/compute_tools/src/http/routes/terminate.rs b/compute_tools/src/http/routes/terminate.rs index 5b30b020c8..deac760f43 100644 --- a/compute_tools/src/http/routes/terminate.rs +++ b/compute_tools/src/http/routes/terminate.rs @@ -1,7 +1,7 @@ use crate::compute::{ComputeNode, forward_termination_signal}; use crate::http::JsonResponse; use axum::extract::State; -use axum::response::Response; +use axum::response::{IntoResponse, Response}; use axum_extra::extract::OptionalQuery; use compute_api::responses::{ComputeStatus, TerminateMode, TerminateResponse}; use http::StatusCode; @@ -33,7 +33,29 @@ pub(in crate::http) async fn terminate( if !matches!(state.status, ComputeStatus::Empty | ComputeStatus::Running) { return JsonResponse::invalid_status(state.status); } + + // If compute is Empty, there's no Postgres to terminate. The regular compute_ctl termination path + // assumes Postgres to be configured and running, so we just special-handle this case by exiting + // the process directly. + if compute.params.lakebase_mode && state.status == ComputeStatus::Empty { + drop(state); + info!("terminating empty compute - will exit process"); + + // Queue a task to exit the process after 5 seconds. The 5-second delay aims to + // give enough time for the HTTP response to be sent so that HCM doesn't get an abrupt + // connection termination. + tokio::spawn(async { + tokio::time::sleep(tokio::time::Duration::from_secs(5)).await; + info!("exiting process after terminating empty compute"); + std::process::exit(0); + }); + + return StatusCode::OK.into_response(); + } + + // For Running status, proceed with normal termination state.set_status(mode.into(), &compute.state_changed); + drop(state); } forward_termination_signal(false); diff --git a/compute_tools/src/http/server.rs b/compute_tools/src/http/server.rs index f0fbca8263..2fd3121f4f 100644 --- a/compute_tools/src/http/server.rs +++ b/compute_tools/src/http/server.rs @@ -23,7 +23,8 @@ use super::{ middleware::authorize::Authorize, routes::{ check_writability, configure, database_schema, dbs_and_roles, extension_server, extensions, - grants, insights, lfc, metrics, metrics_json, promote, status, terminate, + grants, hadron_liveness_probe, insights, lfc, metrics, metrics_json, promote, + refresh_configuration, status, terminate, }, }; use crate::compute::ComputeNode; @@ -43,6 +44,7 @@ pub enum Server { port: u16, config: ComputeCtlConfig, compute_id: String, + instance_id: Option, }, } @@ -67,7 +69,12 @@ impl From<&Server> for Router> { post(extension_server::download_extension), ) .route("/extensions", post(extensions::install_extension)) - .route("/grants", post(grants::add_grant)); + .route("/grants", post(grants::add_grant)) + // Hadron: Compute-initiated configuration refresh + .route( + "/refresh_configuration", + post(refresh_configuration::refresh_configuration), + ); // Add in any testing support if cfg!(feature = "testing") { @@ -79,7 +86,10 @@ impl From<&Server> for Router> { router } Server::External { - config, compute_id, .. + config, + compute_id, + instance_id, + .. 
} => { let unauthenticated_router = Router::>::new() .route("/metrics", get(metrics::get_metrics)) @@ -100,8 +110,13 @@ impl From<&Server> for Router> { .route("/metrics.json", get(metrics_json::get_metrics)) .route("/status", get(status::get_status)) .route("/terminate", post(terminate::terminate)) + .route( + "/hadron_liveness_probe", + get(hadron_liveness_probe::hadron_liveness_probe), + ) .layer(AsyncRequireAuthorizationLayer::new(Authorize::new( compute_id.clone(), + instance_id.clone(), config.jwks.clone(), ))); diff --git a/compute_tools/src/installed_extensions.rs b/compute_tools/src/installed_extensions.rs index 90e1a17be4..5f60b711c8 100644 --- a/compute_tools/src/installed_extensions.rs +++ b/compute_tools/src/installed_extensions.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use anyhow::Result; use compute_api::responses::{InstalledExtension, InstalledExtensions}; +use once_cell::sync::Lazy; use tokio_postgres::error::Error as PostgresError; use tokio_postgres::{Client, Config, NoTls}; @@ -119,3 +120,7 @@ pub async fn get_installed_extensions( extensions: extensions_map.into_values().collect(), }) } + +pub fn initialize_metrics() { + Lazy::force(&INSTALLED_EXTENSIONS); +} diff --git a/compute_tools/src/lib.rs b/compute_tools/src/lib.rs index 5ffa2f004a..85a6f955d9 100644 --- a/compute_tools/src/lib.rs +++ b/compute_tools/src/lib.rs @@ -25,6 +25,7 @@ mod migration; pub mod monitor; pub mod params; pub mod pg_helpers; +pub mod pg_isready; pub mod pgbouncer; pub mod rsyslog; pub mod spec; diff --git a/compute_tools/src/logger.rs b/compute_tools/src/logger.rs index cd076472a6..83e666223c 100644 --- a/compute_tools/src/logger.rs +++ b/compute_tools/src/logger.rs @@ -1,7 +1,10 @@ use std::collections::HashMap; +use std::sync::{LazyLock, RwLock}; +use tracing::Subscriber; use tracing::info; -use tracing_subscriber::layer::SubscriberExt; +use tracing_appender; use tracing_subscriber::prelude::*; +use tracing_subscriber::{fmt, layer::SubscriberExt, registry::LookupSpan}; /// Initialize logging to stderr, and OpenTelemetry tracing and exporter. /// @@ -15,16 +18,44 @@ use tracing_subscriber::prelude::*; /// pub fn init_tracing_and_logging( default_log_level: &str, -) -> anyhow::Result> { + log_dir_opt: &Option, +) -> anyhow::Result<( + Option, + Option, +)> { // Initialize Logging let env_filter = tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new(default_log_level)); + // Standard output streams let fmt_layer = tracing_subscriber::fmt::layer() .with_ansi(false) .with_target(false) .with_writer(std::io::stderr); + // Logs with file rotation. Files in `$log_dir/pgcctl.yyyy-MM-dd` + let (json_to_file_layer, _file_logs_guard) = if let Some(log_dir) = log_dir_opt { + std::fs::create_dir_all(log_dir)?; + let file_logs_appender = tracing_appender::rolling::RollingFileAppender::builder() + .rotation(tracing_appender::rolling::Rotation::DAILY) + .filename_prefix("pgcctl") + // Lib appends to existing files, so we will keep files for up to 2 days even on restart loops. + // At minimum, log-daemon will have 1 day to detect and upload a file (if created right before midnight). 
+ .max_log_files(2) + .build(log_dir) + .expect("Initializing rolling file appender should succeed"); + let (file_logs_writer, _file_logs_guard) = + tracing_appender::non_blocking(file_logs_appender); + let json_to_file_layer = tracing_subscriber::fmt::layer() + .with_ansi(false) + .with_target(false) + .event_format(PgJsonLogShapeFormatter) + .with_writer(file_logs_writer); + (Some(json_to_file_layer), Some(_file_logs_guard)) + } else { + (None, None) + }; + // Initialize OpenTelemetry let provider = tracing_utils::init_tracing("compute_ctl", tracing_utils::ExportConfig::default()); @@ -35,12 +66,13 @@ pub fn init_tracing_and_logging( .with(env_filter) .with(otlp_layer) .with(fmt_layer) + .with(json_to_file_layer) .init(); tracing::info!("logging and tracing started"); utils::logging::replace_panic_hook_with_tracing_panic_hook().forget(); - Ok(provider) + Ok((provider, _file_logs_guard)) } /// Replace all newline characters with a special character to make it @@ -95,3 +127,157 @@ pub fn startup_context_from_env() -> Option { None } } + +/// Track relevant id's +const UNKNOWN_IDS: &str = r#""pg_instance_id": "", "pg_compute_id": """#; +static IDS: LazyLock> = LazyLock::new(|| RwLock::new(UNKNOWN_IDS.to_string())); + +pub fn update_ids(instance_id: &Option, compute_id: &Option) -> anyhow::Result<()> { + let ids = format!( + r#""pg_instance_id": "{}", "pg_compute_id": "{}""#, + instance_id.as_ref().map(|s| s.as_str()).unwrap_or_default(), + compute_id.as_ref().map(|s| s.as_str()).unwrap_or_default() + ); + let mut guard = IDS + .write() + .map_err(|e| anyhow::anyhow!("Log set id's rwlock poisoned: {}", e))?; + *guard = ids; + Ok(()) +} + +/// Massage compute_ctl logs into PG json log shape so we can use the same Lumberjack setup. +struct PgJsonLogShapeFormatter; +impl fmt::format::FormatEvent for PgJsonLogShapeFormatter +where + S: Subscriber + for<'a> LookupSpan<'a>, + N: for<'a> fmt::format::FormatFields<'a> + 'static, +{ + fn format_event( + &self, + ctx: &fmt::FmtContext<'_, S, N>, + mut writer: fmt::format::Writer<'_>, + event: &tracing::Event<'_>, + ) -> std::fmt::Result { + // Format values from the event's metadata, and open message string + let metadata = event.metadata(); + { + let ids_guard = IDS.read(); + let ids = ids_guard + .as_ref() + .map(|guard| guard.as_str()) + // Surpress so that we don't lose all uploaded/ file logs if something goes super wrong. We would notice the missing id's. + .unwrap_or(UNKNOWN_IDS); + write!( + &mut writer, + r#"{{"timestamp": "{}", "error_severity": "{}", "file_name": "{}", "backend_type": "compute_ctl_self", {}, "message": "#, + chrono::Utc::now().format("%Y-%m-%d %H:%M:%S%.3f GMT"), + metadata.level(), + metadata.target(), + ids + )?; + } + + let mut message = String::new(); + let message_writer = fmt::format::Writer::new(&mut message); + + // Gather the message + ctx.field_format().format_fields(message_writer, event)?; + + // TODO: any better options than to copy-paste this OSS span formatter? + // impl FormatEvent for Format + // https://docs.rs/tracing-subscriber/latest/tracing_subscriber/fmt/trait.FormatEvent.html#impl-FormatEvent%3CS,+N%3E-for-Format%3CFull,+T%3E + + // write message, close bracket, and new line + writeln!(writer, "{}}}", serde_json::to_string(&message).unwrap()) + } +} + +#[cfg(feature = "testing")] +#[cfg(test)] +mod test { + use super::*; + use std::{cell::RefCell, io}; + + // Use thread_local! instead of Mutex for test isolation + thread_local! 
{ + static WRITER_OUTPUT: RefCell = const { RefCell::new(String::new()) }; + } + + #[derive(Clone, Default)] + struct StaticStringWriter; + + impl io::Write for StaticStringWriter { + fn write(&mut self, buf: &[u8]) -> io::Result { + let output = String::from_utf8(buf.to_vec()).expect("Invalid UTF-8 in test output"); + WRITER_OUTPUT.with(|s| s.borrow_mut().push_str(&output)); + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } + } + + impl fmt::MakeWriter<'_> for StaticStringWriter { + type Writer = Self; + + fn make_writer(&self) -> Self::Writer { + Self + } + } + + #[test] + fn test_log_pg_json_shape_formatter() { + // Use a scoped subscriber to prevent global state pollution + let subscriber = tracing_subscriber::registry().with( + tracing_subscriber::fmt::layer() + .with_ansi(false) + .with_target(false) + .event_format(PgJsonLogShapeFormatter) + .with_writer(StaticStringWriter), + ); + + let _ = update_ids(&Some("000".to_string()), &Some("111".to_string())); + + // Clear any previous test state + WRITER_OUTPUT.with(|s| s.borrow_mut().clear()); + + let messages = [ + "test message", + r#"json escape check: name="BatchSpanProcessor.Flush.ExportError" reason="Other(reqwest::Error { kind: Request, url: \"http://localhost:4318/v1/traces\", source: hyper_ + util::client::legacy::Error(Connect, ConnectError(\"tcp connect error\", Os { code: 111, kind: ConnectionRefused, message: \"Connection refused\" })) })" Failed during the export process"#, + ]; + + tracing::subscriber::with_default(subscriber, || { + for message in messages { + tracing::info!(message); + } + }); + tracing::info!("not test message"); + + // Get captured output + let output = WRITER_OUTPUT.with(|s| s.borrow().clone()); + + let json_strings: Vec<&str> = output.lines().collect(); + assert_eq!( + json_strings.len(), + messages.len(), + "Log didn't have the expected number of json strings." + ); + + let json_string_shape_regex = regex::Regex::new( + r#"\{"timestamp": "\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{3} GMT", "error_severity": "INFO", "file_name": ".+", "backend_type": "compute_ctl_self", "pg_instance_id": "000", "pg_compute_id": "111", "message": ".+"\}"# + ).unwrap(); + + for (i, expected_message) in messages.iter().enumerate() { + let json_string = json_strings[i]; + assert!( + json_string_shape_regex.is_match(json_string), + "Json log didn't match expected pattern:\n{json_string}", + ); + let parsed_json: serde_json::Value = serde_json::from_str(json_string).unwrap(); + let actual_message = parsed_json["message"].as_str().unwrap(); + assert_eq!(*expected_message, actual_message); + } + } +} diff --git a/compute_tools/src/pg_isready.rs b/compute_tools/src/pg_isready.rs new file mode 100644 index 0000000000..76c45d6b0a --- /dev/null +++ b/compute_tools/src/pg_isready.rs @@ -0,0 +1,30 @@ +use anyhow::{Context, anyhow}; + +// Run `/usr/local/bin/pg_isready -p {port}` +// Check the connectivity of PG +// Success means PG is listening on the port and accepting connections +// Note that PG does not need to authenticate the connection, nor reserve a connection quota for it. 
+// See https://www.postgresql.org/docs/current/app-pg-isready.html +pub fn pg_isready(bin: &str, port: u16) -> anyhow::Result<()> { + let child_result = std::process::Command::new(bin) + .arg("-p") + .arg(port.to_string()) + .spawn(); + + child_result + .context("spawn() failed") + .and_then(|mut child| child.wait().context("wait() failed")) + .and_then(|status| match status.success() { + true => Ok(()), + false => Err(anyhow!("process exited with {status}")), + }) + // wrap any prior error with the overall context that we couldn't run the command + .with_context(|| format!("could not run `{bin} --port {port}`")) +} + +// It's safe to assume pg_isready is under the same directory with postgres, +// because it is a PG util bin installed along with postgres +pub fn get_pg_isready_bin(pgbin: &str) -> String { + let split = pgbin.split("/").collect::>(); + split[0..split.len() - 1].join("/") + "/pg_isready" +} diff --git a/compute_tools/src/spec.rs b/compute_tools/src/spec.rs index d00f86a2c0..2dda9c3134 100644 --- a/compute_tools/src/spec.rs +++ b/compute_tools/src/spec.rs @@ -142,7 +142,7 @@ pub fn update_pg_hba(pgdata_path: &Path, databricks_pg_hba: Option<&String>) -> // Update pg_hba to contains databricks specfic settings before adding neon settings // PG uses the first record that matches to perform authentication, so we need to have // our rules before the default ones from neon. - // See https://www.postgresql.org/docs/16/auth-pg-hba-conf.html + // See https://www.postgresql.org/docs/current/auth-pg-hba-conf.html if let Some(databricks_pg_hba) = databricks_pg_hba { if config::line_in_file( &pghba_path, diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index 579a696398..372118c6aa 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -560,7 +560,9 @@ enum EndpointCmd { Create(EndpointCreateCmdArgs), Start(EndpointStartCmdArgs), Reconfigure(EndpointReconfigureCmdArgs), + RefreshConfiguration(EndpointRefreshConfigurationArgs), Stop(EndpointStopCmdArgs), + UpdatePageservers(EndpointUpdatePageserversCmdArgs), GenerateJwt(EndpointGenerateJwtCmdArgs), } @@ -721,6 +723,13 @@ struct EndpointReconfigureCmdArgs { safekeepers: Option, } +#[derive(clap::Args)] +#[clap(about = "Refresh the endpoint's configuration by forcing it reload it's spec")] +struct EndpointRefreshConfigurationArgs { + #[clap(help = "Postgres endpoint id")] + endpoint_id: String, +} + #[derive(clap::Args)] #[clap(about = "Stop an endpoint")] struct EndpointStopCmdArgs { @@ -738,6 +747,16 @@ struct EndpointStopCmdArgs { mode: EndpointTerminateMode, } +#[derive(clap::Args)] +#[clap(about = "Update the pageservers in the spec file of the compute endpoint")] +struct EndpointUpdatePageserversCmdArgs { + #[clap(help = "Postgres endpoint id")] + endpoint_id: String, + + #[clap(short = 'p', long, help = "Specified pageserver id")] + pageserver_id: Option, +} + #[derive(clap::Args)] #[clap(about = "Generate a JWT for an endpoint")] struct EndpointGenerateJwtCmdArgs { @@ -1518,7 +1537,7 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res let endpoint = cplane .endpoints .get(endpoint_id.as_str()) - .ok_or_else(|| anyhow::anyhow!("endpoint {endpoint_id} not found"))?; + .ok_or_else(|| anyhow!("endpoint {endpoint_id} not found"))?; if !args.allow_multiple { cplane.check_conflicting_endpoints( @@ -1629,6 +1648,44 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res println!("Starting existing 
endpoint {endpoint_id}..."); endpoint.start(args).await?; } + EndpointCmd::UpdatePageservers(args) => { + let endpoint_id = &args.endpoint_id; + let endpoint = cplane + .endpoints + .get(endpoint_id.as_str()) + .with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?; + let pageservers = match args.pageserver_id { + Some(pageserver_id) => { + let pageserver = + PageServerNode::from_env(env, env.get_pageserver_conf(pageserver_id)?); + + vec![( + PageserverProtocol::Libpq, + pageserver.pg_connection_config.host().clone(), + pageserver.pg_connection_config.port(), + )] + } + None => { + let storage_controller = StorageController::from_env(env); + storage_controller + .tenant_locate(endpoint.tenant_id) + .await? + .shards + .into_iter() + .map(|shard| { + ( + PageserverProtocol::Libpq, + Host::parse(&shard.listen_pg_addr) + .expect("Storage controller reported malformed host"), + shard.listen_pg_port, + ) + }) + .collect::>() + } + }; + + endpoint.update_pageservers_in_config(pageservers).await?; + } EndpointCmd::Reconfigure(args) => { let endpoint_id = &args.endpoint_id; let endpoint = cplane @@ -1682,6 +1739,14 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res .reconfigure(Some(pageservers), None, safekeepers, None) .await?; } + EndpointCmd::RefreshConfiguration(args) => { + let endpoint_id = &args.endpoint_id; + let endpoint = cplane + .endpoints + .get(endpoint_id.as_str()) + .with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?; + endpoint.refresh_configuration().await?; + } EndpointCmd::Stop(args) => { let endpoint_id = &args.endpoint_id; let endpoint = cplane diff --git a/control_plane/src/endpoint.rs b/control_plane/src/endpoint.rs index 892180a4dc..4317c4d0f1 100644 --- a/control_plane/src/endpoint.rs +++ b/control_plane/src/endpoint.rs @@ -846,6 +846,7 @@ impl Endpoint { autoprewarm: args.autoprewarm, offload_lfc_interval_seconds: args.offload_lfc_interval_seconds, suspend_timeout_seconds: -1, // Only used in neon_local. + databricks_settings: None, }; // this strange code is needed to support respec() in tests @@ -990,7 +991,9 @@ impl Endpoint { | ComputeStatus::Configuration | ComputeStatus::TerminationPendingFast | ComputeStatus::TerminationPendingImmediate - | ComputeStatus::Terminated => { + | ComputeStatus::Terminated + | ComputeStatus::RefreshConfigurationPending + | ComputeStatus::RefreshConfiguration => { bail!("unexpected compute status: {:?}", state.status) } } @@ -1013,6 +1016,29 @@ impl Endpoint { Ok(()) } + // Update the pageservers in the spec file of the endpoint. This is useful to test the spec refresh scenario. + pub async fn update_pageservers_in_config( + &self, + pageservers: Vec<(PageserverProtocol, Host, u16)>, + ) -> Result<()> { + let config_path = self.endpoint_path().join("config.json"); + let mut config: ComputeConfig = { + let file = std::fs::File::open(&config_path)?; + serde_json::from_reader(file)? 
+ }; + + let pageserver_connstring = Self::build_pageserver_connstr(&pageservers); + assert!(!pageserver_connstring.is_empty()); + let mut spec = config.spec.unwrap(); + spec.pageserver_connstring = Some(pageserver_connstring); + config.spec = Some(spec); + + let file = std::fs::File::create(&config_path)?; + serde_json::to_writer_pretty(file, &config)?; + + Ok(()) + } + // Call the /status HTTP API pub async fn get_status(&self) -> Result { let client = reqwest::Client::new(); @@ -1178,6 +1204,33 @@ impl Endpoint { Ok(response) } + pub async fn refresh_configuration(&self) -> Result<()> { + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(30)) + .build() + .unwrap(); + let response = client + .post(format!( + "http://{}:{}/refresh_configuration", + self.internal_http_address.ip(), + self.internal_http_address.port() + )) + .send() + .await?; + + let status = response.status(); + if !(status.is_client_error() || status.is_server_error()) { + Ok(()) + } else { + let url = response.url().to_owned(); + let msg = match response.text().await { + Ok(err_body) => format!("Error: {err_body}"), + Err(_) => format!("Http error ({}) at {}.", status.as_u16(), url), + }; + Err(anyhow::anyhow!(msg)) + } + } + pub fn connstr(&self, user: &str, db_name: &str) -> String { format!( "postgresql://{}@{}:{}/{}", diff --git a/libs/compute_api/src/responses.rs b/libs/compute_api/src/responses.rs index 5b8fc49750..a27301e45e 100644 --- a/libs/compute_api/src/responses.rs +++ b/libs/compute_api/src/responses.rs @@ -108,11 +108,10 @@ pub enum PromoteState { Failed { error: String }, } -#[derive(Deserialize, Serialize, Default, Debug, Clone)] +#[derive(Deserialize, Default, Debug)] #[serde(rename_all = "snake_case")] -/// Result of /safekeepers_lsn -pub struct SafekeepersLsn { - pub safekeepers: String, +pub struct PromoteConfig { + pub spec: ComputeSpec, pub wal_flush_lsn: utils::lsn::Lsn, } @@ -173,6 +172,11 @@ pub enum ComputeStatus { TerminationPendingImmediate, // Terminated Postgres Terminated, + // A spec refresh is being requested + RefreshConfigurationPending, + // A spec refresh is being applied. We cannot refresh configuration again until the current + // refresh is done, i.e., signal_refresh_configuration() will return 500 error. + RefreshConfiguration, } #[derive(Deserialize, Serialize)] @@ -185,6 +189,10 @@ impl Display for ComputeStatus { match self { ComputeStatus::Empty => f.write_str("empty"), ComputeStatus::ConfigurationPending => f.write_str("configuration-pending"), + ComputeStatus::RefreshConfiguration => f.write_str("refresh-configuration"), + ComputeStatus::RefreshConfigurationPending => { + f.write_str("refresh-configuration-pending") + } ComputeStatus::Init => f.write_str("init"), ComputeStatus::Running => f.write_str("running"), ComputeStatus::Configuration => f.write_str("configuration"), diff --git a/libs/compute_api/src/spec.rs b/libs/compute_api/src/spec.rs index 061ac3e66d..6709c06fc6 100644 --- a/libs/compute_api/src/spec.rs +++ b/libs/compute_api/src/spec.rs @@ -193,6 +193,9 @@ pub struct ComputeSpec { /// /// We use this value to derive other values, such as the installed extensions metric. pub suspend_timeout_seconds: i64, + + // Databricks specific options for compute instance. + pub databricks_settings: Option, } /// Feature flag to signal `compute_ctl` to enable certain experimental functionality. 
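The two refresh states added to ComputeStatus above render through the new Display arms in libs/compute_api/src/responses.rs. A minimal illustrative check of that mapping (not part of this patch; it assumes the compute_api crate from this repo as a dependency and would live in a test module):

use compute_api::responses::ComputeStatus;

#[test]
fn refresh_states_render_as_kebab_case() {
    // Matches the Display arms added above for the two new refresh states.
    assert_eq!(
        ComputeStatus::RefreshConfigurationPending.to_string(),
        "refresh-configuration-pending"
    );
    assert_eq!(
        ComputeStatus::RefreshConfiguration.to_string(),
        "refresh-configuration"
    );
}
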
diff --git a/libs/http-utils/src/endpoint.rs b/libs/http-utils/src/endpoint.rs index f5a7735ad8..2b54ffbf12 100644 --- a/libs/http-utils/src/endpoint.rs +++ b/libs/http-utils/src/endpoint.rs @@ -558,11 +558,11 @@ async fn add_request_id_header_to_response( mut res: Response, req_info: RequestInfo, ) -> Result, ApiError> { - if let Some(request_id) = req_info.context::() { - if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) { - res.headers_mut() - .insert(&X_REQUEST_ID_HEADER, request_header_value); - }; + if let Some(request_id) = req_info.context::() + && let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) + { + res.headers_mut() + .insert(&X_REQUEST_ID_HEADER, request_header_value); }; Ok(res) diff --git a/libs/http-utils/src/server.rs b/libs/http-utils/src/server.rs index f93f71c962..ce90b8d710 100644 --- a/libs/http-utils/src/server.rs +++ b/libs/http-utils/src/server.rs @@ -72,10 +72,10 @@ impl Server { if err.is_incomplete_message() || err.is_closed() || err.is_timeout() { return true; } - if let Some(inner) = err.source() { - if let Some(io) = inner.downcast_ref::() { - return suppress_io_error(io); - } + if let Some(inner) = err.source() + && let Some(io) = inner.downcast_ref::() + { + return suppress_io_error(io); } false } diff --git a/libs/metrics/src/lib.rs b/libs/metrics/src/lib.rs index 41873cdcd6..6cf27abcaf 100644 --- a/libs/metrics/src/lib.rs +++ b/libs/metrics/src/lib.rs @@ -129,6 +129,12 @@ impl InfoMetric { } } +impl Default for InfoMetric { + fn default() -> Self { + InfoMetric::new(L::default()) + } +} + impl> InfoMetric { pub fn with_metric(label: L, metric: M) -> Self { Self { diff --git a/libs/neon-shmem/src/hash.rs b/libs/neon-shmem/src/hash.rs index 58726b9ba3..a58797d8fa 100644 --- a/libs/neon-shmem/src/hash.rs +++ b/libs/neon-shmem/src/hash.rs @@ -363,7 +363,7 @@ where // TODO: An Iterator might be nicer. The communicator's clock algorithm needs to // _slowly_ iterate through all buckets with its clock hand, without holding a lock. // If we switch to an Iterator, it must not hold the lock. - pub fn get_at_bucket(&self, pos: usize) -> Option> { + pub fn get_at_bucket(&self, pos: usize) -> Option> { let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read(); if pos >= map.buckets.len() { return None; diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index 7c7c65fb70..230c1f46ea 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -1500,6 +1500,7 @@ pub struct TimelineArchivalConfigRequest { #[derive(Serialize, Deserialize, PartialEq, Eq, Clone)] pub struct TimelinePatchIndexPartRequest { pub rel_size_migration: Option, + pub rel_size_migrated_at: Option, pub gc_compaction_last_completed_lsn: Option, pub applied_gc_cutoff_lsn: Option, #[serde(default)] @@ -1533,10 +1534,10 @@ pub enum RelSizeMigration { /// `None` is the same as `Some(RelSizeMigration::Legacy)`. Legacy, /// The tenant is migrating to the new rel_size format. Both old and new rel_size format are - /// persisted in the index part. The read path will read both formats and merge them. + /// persisted in the storage. The read path will read both formats and validate them. Migrating, /// The tenant has migrated to the new rel_size format. Only the new rel_size format is persisted - /// in the index part, and the read path will not read the old format. + /// in the storage, and the read path will not read the old format. 
Migrated, } @@ -1619,6 +1620,7 @@ pub struct TimelineInfo { /// The status of the rel_size migration. pub rel_size_migration: Option, + pub rel_size_migrated_at: Option, /// Whether the timeline is invisible in synthetic size calculations. pub is_invisible: Option, diff --git a/libs/postgres_ffi/Cargo.toml b/libs/postgres_ffi/Cargo.toml index d4fec6cbe9..fca75b7bc1 100644 --- a/libs/postgres_ffi/Cargo.toml +++ b/libs/postgres_ffi/Cargo.toml @@ -12,7 +12,6 @@ crc32c.workspace = true criterion.workspace = true once_cell.workspace = true log.workspace = true -memoffset.workspace = true pprof.workspace = true thiserror.workspace = true serde.workspace = true diff --git a/libs/postgres_ffi/src/controlfile_utils.rs b/libs/postgres_ffi/src/controlfile_utils.rs index eaa9450294..d6d17ce3fb 100644 --- a/libs/postgres_ffi/src/controlfile_utils.rs +++ b/libs/postgres_ffi/src/controlfile_utils.rs @@ -34,9 +34,8 @@ const SIZEOF_CONTROLDATA: usize = size_of::(); impl ControlFileData { /// Compute the offset of the `crc` field within the `ControlFileData` struct. /// Equivalent to offsetof(ControlFileData, crc) in C. - // Someday this can be const when the right compiler features land. - fn pg_control_crc_offset() -> usize { - memoffset::offset_of!(ControlFileData, crc) + const fn pg_control_crc_offset() -> usize { + std::mem::offset_of!(ControlFileData, crc) } /// diff --git a/proxy/subzero_core/.gitignore b/libs/proxy/subzero_core/.gitignore similarity index 100% rename from proxy/subzero_core/.gitignore rename to libs/proxy/subzero_core/.gitignore diff --git a/proxy/subzero_core/Cargo.toml b/libs/proxy/subzero_core/Cargo.toml similarity index 100% rename from proxy/subzero_core/Cargo.toml rename to libs/proxy/subzero_core/Cargo.toml diff --git a/proxy/subzero_core/src/lib.rs b/libs/proxy/subzero_core/src/lib.rs similarity index 100% rename from proxy/subzero_core/src/lib.rs rename to libs/proxy/subzero_core/src/lib.rs diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs index 068566e955..90ff39aff1 100644 --- a/libs/proxy/tokio-postgres2/src/client.rs +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -15,6 +15,7 @@ use tokio::sync::mpsc; use crate::cancel_token::RawCancelToken; use crate::codec::{BackendMessages, FrontendMessage, RecordNotices}; use crate::config::{Host, SslMode}; +use crate::connection::gc_bytesmut; use crate::query::RowStream; use crate::simple_query::SimpleQueryStream; use crate::types::{Oid, Type}; @@ -95,20 +96,13 @@ impl InnerClient { Ok(PartialQuery(Some(self))) } - // pub fn send_with_sync(&mut self, f: F) -> Result<&mut Responses, Error> - // where - // F: FnOnce(&mut BytesMut) -> Result<(), Error>, - // { - // self.start()?.send_with_sync(f) - // } - pub fn send_simple_query(&mut self, query: &str) -> Result<&mut Responses, Error> { self.responses.waiting += 1; self.buffer.clear(); // simple queries do not need sync. 
frontend::query(query, &mut self.buffer).map_err(Error::encode)?; - let buf = self.buffer.split().freeze(); + let buf = self.buffer.split(); self.send_message(FrontendMessage::Raw(buf)) } @@ -125,7 +119,7 @@ impl Drop for PartialQuery<'_> { if let Some(client) = self.0.take() { client.buffer.clear(); frontend::sync(&mut client.buffer); - let buf = client.buffer.split().freeze(); + let buf = client.buffer.split(); let _ = client.send_message(FrontendMessage::Raw(buf)); } } @@ -141,7 +135,7 @@ impl<'a> PartialQuery<'a> { client.buffer.clear(); f(&mut client.buffer)?; frontend::flush(&mut client.buffer); - let buf = client.buffer.split().freeze(); + let buf = client.buffer.split(); client.send_message(FrontendMessage::Raw(buf)) } @@ -154,7 +148,7 @@ impl<'a> PartialQuery<'a> { client.buffer.clear(); f(&mut client.buffer)?; frontend::sync(&mut client.buffer); - let buf = client.buffer.split().freeze(); + let buf = client.buffer.split(); let _ = client.send_message(FrontendMessage::Raw(buf)); Ok(&mut self.0.take().unwrap().responses) @@ -191,6 +185,7 @@ impl Client { ssl_mode: SslMode, process_id: i32, secret_key: i32, + write_buf: BytesMut, ) -> Client { Client { inner: InnerClient { @@ -201,7 +196,7 @@ impl Client { waiting: 0, received: 0, }, - buffer: Default::default(), + buffer: write_buf, }, cached_typeinfo: Default::default(), @@ -292,8 +287,35 @@ impl Client { simple_query::batch_execute(self.inner_mut(), query).await } - pub async fn discard_all(&mut self) -> Result { - self.batch_execute("discard all").await + /// Similar to `discard_all`, but it does not clear any query plans + /// + /// This runs in the background, so it can be executed without `await`ing. + pub fn reset_session_background(&mut self) -> Result<(), Error> { + // "CLOSE ALL": closes any cursors + // "SET SESSION AUTHORIZATION DEFAULT": resets the current_user back to the session_user + // "RESET ALL": resets any GUCs back to their session defaults. + // "DEALLOCATE ALL": deallocates any prepared statements + // "UNLISTEN *": stops listening on all channels + // "SELECT pg_advisory_unlock_all();": unlocks all advisory locks + // "DISCARD TEMP;": drops all temporary tables + // "DISCARD SEQUENCES;": deallocates all cached sequence state + + let _responses = self.inner_mut().send_simple_query( + "ROLLBACK; + CLOSE ALL; + SET SESSION AUTHORIZATION DEFAULT; + RESET ALL; + DEALLOCATE ALL; + UNLISTEN *; + SELECT pg_advisory_unlock_all(); + DISCARD TEMP; + DISCARD SEQUENCES;", + )?; + + // Clean up memory usage. + gc_bytesmut(&mut self.inner_mut().buffer); + + Ok(()) } /// Begins a new database transaction. 
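The switch from Bytes to BytesMut in the write path above is what lets the codec reclaim the client's write buffer instead of always copying it. A standalone sketch of the pattern using the bytes crate (illustrative only; the names write_buf and socket_buf are hypothetical):

use bytes::{BufMut, BytesMut};

fn main() {
    // The client encodes a message into its long-lived write buffer...
    let mut write_buf = BytesMut::with_capacity(8 * 1024);
    write_buf.put_slice(b"pretend this is an encoded frontend message");

    // ...and sends the filled half as a BytesMut (what FrontendMessage::Raw now carries).
    let msg: BytesMut = write_buf.split();

    // On the connection side, the encoder can absorb it with `unsplit`, reusing the
    // allocation when the destination is empty or contiguous with the message
    // (and falling back to a copy otherwise), instead of copying out of a frozen Bytes.
    let mut socket_buf = BytesMut::new();
    socket_buf.unsplit(msg);
    assert!(socket_buf.starts_with(b"pretend"));
}
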
diff --git a/libs/proxy/tokio-postgres2/src/codec.rs b/libs/proxy/tokio-postgres2/src/codec.rs index 813faa0e35..35f616d229 100644 --- a/libs/proxy/tokio-postgres2/src/codec.rs +++ b/libs/proxy/tokio-postgres2/src/codec.rs @@ -1,13 +1,13 @@ use std::io; -use bytes::{Bytes, BytesMut}; +use bytes::BytesMut; use fallible_iterator::FallibleIterator; use postgres_protocol2::message::backend; use tokio::sync::mpsc::UnboundedSender; use tokio_util::codec::{Decoder, Encoder}; pub enum FrontendMessage { - Raw(Bytes), + Raw(BytesMut), RecordNotices(RecordNotices), } @@ -17,7 +17,10 @@ pub struct RecordNotices { } pub enum BackendMessage { - Normal { messages: BackendMessages }, + Normal { + messages: BackendMessages, + ready: bool, + }, Async(backend::Message), } @@ -40,11 +43,11 @@ impl FallibleIterator for BackendMessages { pub struct PostgresCodec; -impl Encoder for PostgresCodec { +impl Encoder for PostgresCodec { type Error = io::Error; - fn encode(&mut self, item: Bytes, dst: &mut BytesMut) -> io::Result<()> { - dst.extend_from_slice(&item); + fn encode(&mut self, item: BytesMut, dst: &mut BytesMut) -> io::Result<()> { + dst.unsplit(item); Ok(()) } } @@ -56,6 +59,7 @@ impl Decoder for PostgresCodec { fn decode(&mut self, src: &mut BytesMut) -> Result, io::Error> { let mut idx = 0; + let mut ready = false; while let Some(header) = backend::Header::parse(&src[idx..])? { let len = header.len() as usize + 1; if src[idx..].len() < len { @@ -79,6 +83,7 @@ impl Decoder for PostgresCodec { idx += len; if header.tag() == backend::READY_FOR_QUERY_TAG { + ready = true; break; } } @@ -88,6 +93,7 @@ impl Decoder for PostgresCodec { } else { Ok(Some(BackendMessage::Normal { messages: BackendMessages(src.split_to(idx)), + ready, })) } } diff --git a/libs/proxy/tokio-postgres2/src/config.rs b/libs/proxy/tokio-postgres2/src/config.rs index c619f92d13..3579dd94a2 100644 --- a/libs/proxy/tokio-postgres2/src/config.rs +++ b/libs/proxy/tokio-postgres2/src/config.rs @@ -250,19 +250,20 @@ impl Config { { let stream = connect_tls(stream, self.ssl_mode, tls).await?; let mut stream = StartupStream::new(stream); - connect_raw::startup(&mut stream, self).await?; connect_raw::authenticate(&mut stream, self).await?; Ok(stream) } - pub async fn authenticate(&self, stream: &mut StartupStream) -> Result<(), Error> + pub fn authenticate( + &self, + stream: &mut StartupStream, + ) -> impl Future> where S: AsyncRead + AsyncWrite + Unpin, T: TlsStream + Unpin, { - connect_raw::startup(stream, self).await?; - connect_raw::authenticate(stream, self).await + connect_raw::authenticate(stream, self) } } diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs index 41d95c5f84..b1df87811e 100644 --- a/libs/proxy/tokio-postgres2/src/connect.rs +++ b/libs/proxy/tokio-postgres2/src/connect.rs @@ -7,7 +7,7 @@ use tokio::net::TcpStream; use tokio::sync::mpsc; use crate::client::SocketConfig; -use crate::config::Host; +use crate::config::{Host, SslMode}; use crate::connect_raw::StartupStream; use crate::connect_socket::connect_socket; use crate::tls::{MakeTlsConnect, TlsConnect}; @@ -45,28 +45,53 @@ where T: TlsConnect, { let socket = connect_socket(host_addr, host, port, config.connect_timeout).await?; - let mut stream = config.tls_and_authenticate(socket, tls).await?; + let stream = config.tls_and_authenticate(socket, tls).await?; + managed( + stream, + host_addr, + host.clone(), + port, + config.ssl_mode, + config.connect_timeout, + ) + .await +} + +pub async fn managed( + mut stream: 
StartupStream, + host_addr: Option, + host: Host, + port: u16, + ssl_mode: SslMode, + connect_timeout: Option, +) -> Result<(Client, Connection), Error> +where + TlsStream: AsyncRead + AsyncWrite + Unpin, +{ let (process_id, secret_key) = wait_until_ready(&mut stream).await?; let socket_config = SocketConfig { host_addr, - host: host.clone(), + host, port, - connect_timeout: config.connect_timeout, + connect_timeout, }; + let mut stream = stream.into_framed(); + let write_buf = std::mem::take(stream.write_buffer_mut()); + let (client_tx, conn_rx) = mpsc::unbounded_channel(); let (conn_tx, client_rx) = mpsc::channel(4); let client = Client::new( client_tx, client_rx, socket_config, - config.ssl_mode, + ssl_mode, process_id, secret_key, + write_buf, ); - let stream = stream.into_framed(); let connection = Connection::new(stream, conn_tx, conn_rx); Ok((client, connection)) diff --git a/libs/proxy/tokio-postgres2/src/connect_raw.rs b/libs/proxy/tokio-postgres2/src/connect_raw.rs index bc35cef339..17237eeef5 100644 --- a/libs/proxy/tokio-postgres2/src/connect_raw.rs +++ b/libs/proxy/tokio-postgres2/src/connect_raw.rs @@ -2,51 +2,28 @@ use std::io; use std::pin::Pin; use std::task::{Context, Poll, ready}; -use bytes::{Bytes, BytesMut}; +use bytes::BytesMut; use fallible_iterator::FallibleIterator; -use futures_util::{Sink, SinkExt, Stream, TryStreamExt}; +use futures_util::{SinkExt, Stream, TryStreamExt}; use postgres_protocol2::authentication::sasl; use postgres_protocol2::authentication::sasl::ScramSha256; use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message}; use postgres_protocol2::message::frontend; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio_util::codec::{Framed, FramedParts, FramedWrite}; +use tokio_util::codec::{Framed, FramedParts}; use crate::Error; use crate::codec::PostgresCodec; use crate::config::{self, AuthKeys, Config}; +use crate::connection::{GC_THRESHOLD, INITIAL_CAPACITY}; use crate::maybe_tls_stream::MaybeTlsStream; use crate::tls::TlsStream; pub struct StartupStream { - inner: FramedWrite, PostgresCodec>, + inner: Framed, PostgresCodec>, read_buf: BytesMut, } -impl Sink for StartupStream -where - S: AsyncRead + AsyncWrite + Unpin, - T: AsyncRead + AsyncWrite + Unpin, -{ - type Error = io::Error; - - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_ready(cx) - } - - fn start_send(mut self: Pin<&mut Self>, item: Bytes) -> io::Result<()> { - Pin::new(&mut self.inner).start_send(item) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_flush(cx) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_close(cx) - } -} - impl Stream for StartupStream where S: AsyncRead + AsyncWrite + Unpin, @@ -55,6 +32,8 @@ where type Item = io::Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // We don't use `self.inner.poll_next()` as that might over-read into the read buffer. + // read 1 byte tag, 4 bytes length. 
let header = ready!(self.as_mut().poll_fill_buf_exact(cx, 5)?); @@ -121,36 +100,28 @@ where } pub fn into_framed(mut self) -> Framed, PostgresCodec> { - let write_buf = std::mem::take(self.inner.write_buffer_mut()); - let io = self.inner.into_inner(); - let mut parts = FramedParts::new(io, PostgresCodec); - parts.read_buf = self.read_buf; - parts.write_buf = write_buf; - Framed::from_parts(parts) + *self.inner.read_buffer_mut() = self.read_buf; + self.inner } pub fn new(io: MaybeTlsStream) -> Self { + let mut parts = FramedParts::new(io, PostgresCodec); + parts.write_buf = BytesMut::with_capacity(INITIAL_CAPACITY); + + let mut inner = Framed::from_parts(parts); + + // This is the default already, but nice to be explicit. + // We divide by two because writes will overshoot the boundary. + // We don't want constant overshoots to cause us to constantly re-shrink the buffer. + inner.set_backpressure_boundary(GC_THRESHOLD / 2); + Self { - inner: FramedWrite::new(io, PostgresCodec), - read_buf: BytesMut::new(), + inner, + read_buf: BytesMut::with_capacity(INITIAL_CAPACITY), } } } -pub(crate) async fn startup( - stream: &mut StartupStream, - config: &Config, -) -> Result<(), Error> -where - S: AsyncRead + AsyncWrite + Unpin, - T: AsyncRead + AsyncWrite + Unpin, -{ - let mut buf = BytesMut::new(); - frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?; - - stream.send(buf.freeze()).await.map_err(Error::io) -} - pub(crate) async fn authenticate( stream: &mut StartupStream, config: &Config, @@ -159,6 +130,10 @@ where S: AsyncRead + AsyncWrite + Unpin, T: TlsStream + Unpin, { + frontend::startup_message(&config.server_params, stream.inner.write_buffer_mut()) + .map_err(Error::encode)?; + + stream.inner.flush().await.map_err(Error::io)?; match stream.try_next().await.map_err(Error::io)? { Some(Message::AuthenticationOk) => { can_skip_channel_binding(config)?; @@ -172,7 +147,8 @@ where .as_ref() .ok_or_else(|| Error::config("password missing".into()))?; - authenticate_password(stream, pass).await?; + frontend::password_message(pass, stream.inner.write_buffer_mut()) + .map_err(Error::encode)?; } Some(Message::AuthenticationSasl(body)) => { authenticate_sasl(stream, body, config).await?; @@ -191,6 +167,7 @@ where None => return Err(Error::closed()), } + stream.inner.flush().await.map_err(Error::io)?; match stream.try_next().await.map_err(Error::io)? 
{ Some(Message::AuthenticationOk) => Ok(()), Some(Message::ErrorResponse(body)) => Err(Error::db(body)), @@ -208,20 +185,6 @@ fn can_skip_channel_binding(config: &Config) -> Result<(), Error> { } } -async fn authenticate_password( - stream: &mut StartupStream, - password: &[u8], -) -> Result<(), Error> -where - S: AsyncRead + AsyncWrite + Unpin, - T: AsyncRead + AsyncWrite + Unpin, -{ - let mut buf = BytesMut::new(); - frontend::password_message(password, &mut buf).map_err(Error::encode)?; - - stream.send(buf.freeze()).await.map_err(Error::io) -} - async fn authenticate_sasl( stream: &mut StartupStream, body: AuthenticationSaslBody, @@ -276,10 +239,10 @@ where return Err(Error::config("password or auth keys missing".into())); }; - let mut buf = BytesMut::new(); - frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?; - stream.send(buf.freeze()).await.map_err(Error::io)?; + frontend::sasl_initial_response(mechanism, scram.message(), stream.inner.write_buffer_mut()) + .map_err(Error::encode)?; + stream.inner.flush().await.map_err(Error::io)?; let body = match stream.try_next().await.map_err(Error::io)? { Some(Message::AuthenticationSaslContinue(body)) => body, Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), @@ -292,10 +255,10 @@ where .await .map_err(|e| Error::authentication(e.into()))?; - let mut buf = BytesMut::new(); - frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?; - stream.send(buf.freeze()).await.map_err(Error::io)?; + frontend::sasl_response(scram.message(), stream.inner.write_buffer_mut()) + .map_err(Error::encode)?; + stream.inner.flush().await.map_err(Error::io)?; let body = match stream.try_next().await.map_err(Error::io)? { Some(Message::AuthenticationSaslFinal(body)) => body, Some(Message::ErrorResponse(body)) => return Err(Error::db(body)), diff --git a/libs/proxy/tokio-postgres2/src/connection.rs b/libs/proxy/tokio-postgres2/src/connection.rs index c43a22ffe7..303de71cfa 100644 --- a/libs/proxy/tokio-postgres2/src/connection.rs +++ b/libs/proxy/tokio-postgres2/src/connection.rs @@ -44,6 +44,27 @@ pub struct Connection { state: State, } +pub const INITIAL_CAPACITY: usize = 2 * 1024; +pub const GC_THRESHOLD: usize = 16 * 1024; + +/// Gargabe collect the [`BytesMut`] if it has too much spare capacity. +pub fn gc_bytesmut(buf: &mut BytesMut) { + // We use a different mode to shrink the buf when above the threshold. + // When above the threshold, we only re-allocate when the buf has 2x spare capacity. + let reclaim = GC_THRESHOLD.checked_sub(buf.len()).unwrap_or(buf.len()); + + // `try_reclaim` tries to get the capacity from any shared `BytesMut`s, + // before then comparing the length against the capacity. + if buf.try_reclaim(reclaim) { + let capacity = usize::max(buf.len(), INITIAL_CAPACITY); + + // Allocate a new `BytesMut` so that we deallocate the old version. + let mut new = BytesMut::with_capacity(capacity); + new.extend_from_slice(buf); + *buf = new; + } +} + pub enum Never {} impl Connection @@ -86,7 +107,14 @@ where continue; } BackendMessage::Async(_) => continue, - BackendMessage::Normal { messages } => messages, + BackendMessage::Normal { messages, ready } => { + // if we read a ReadyForQuery from postgres, let's try GC the read buffer. 
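
A small illustration of the `gc_bytesmut` heuristic defined above, assuming `gc_bytesmut`, `INITIAL_CAPACITY` and `GC_THRESHOLD` from `connection.rs` are in scope; the concrete sizes are only for the example:

```rust
use bytes::BytesMut;

fn demo() {
    // A buffer that ballooned well past GC_THRESHOLD while serving a large
    // query, but now holds only a few leftover bytes.
    let mut buf = BytesMut::with_capacity(256 * 1024);
    buf.extend_from_slice(b"leftover");

    // gc_bytesmut reallocates it down to max(len, INITIAL_CAPACITY), so an idle
    // connection no longer pins hundreds of KiB of buffer space.
    gc_bytesmut(&mut buf);
    assert!(buf.capacity() < 256 * 1024);
    assert_eq!(&buf[..], b"leftover");
}
```
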
+ if ready { + gc_bytesmut(self.stream.read_buffer_mut()); + } + + messages + } } } }; @@ -177,12 +205,7 @@ where // Send a terminate message to postgres Poll::Ready(None) => { trace!("poll_write: at eof, terminating"); - let mut request = BytesMut::new(); - frontend::terminate(&mut request); - - Pin::new(&mut self.stream) - .start_send(request.freeze()) - .map_err(Error::io)?; + frontend::terminate(self.stream.write_buffer_mut()); trace!("poll_write: sent eof, closing"); trace!("poll_write: done"); @@ -205,6 +228,13 @@ where { Poll::Ready(()) => { trace!("poll_flush: flushed"); + + // Since our codec prefers to share the buffer with the `Client`, + // if we don't release our share, then the `Client` would have to re-alloc + // the buffer when they next use it. + debug_assert!(self.stream.write_buffer().is_empty()); + *self.stream.write_buffer_mut() = BytesMut::new(); + Poll::Ready(Ok(())) } Poll::Pending => { diff --git a/libs/proxy/tokio-postgres2/src/lib.rs b/libs/proxy/tokio-postgres2/src/lib.rs index a858ddca39..da2665095c 100644 --- a/libs/proxy/tokio-postgres2/src/lib.rs +++ b/libs/proxy/tokio-postgres2/src/lib.rs @@ -48,7 +48,7 @@ mod cancel_token; mod client; mod codec; pub mod config; -mod connect; +pub mod connect; pub mod connect_raw; mod connect_socket; mod connect_tls; diff --git a/libs/tracing-utils/Cargo.toml b/libs/tracing-utils/Cargo.toml index 49a6055b1e..1f8d05ae80 100644 --- a/libs/tracing-utils/Cargo.toml +++ b/libs/tracing-utils/Cargo.toml @@ -8,7 +8,7 @@ license.workspace = true hyper0.workspace = true opentelemetry = { workspace = true, features = ["trace"] } opentelemetry_sdk = { workspace = true, features = ["rt-tokio"] } -opentelemetry-otlp = { workspace = true, default-features = false, features = ["http-proto", "trace", "http", "reqwest-client"] } +opentelemetry-otlp = { workspace = true, default-features = false, features = ["http-proto", "trace", "http", "reqwest-blocking-client"] } opentelemetry-semantic-conventions.workspace = true tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } tracing.workspace = true diff --git a/libs/tracing-utils/src/perf_span.rs b/libs/tracing-utils/src/perf_span.rs index 16f713c67e..4eec0829f7 100644 --- a/libs/tracing-utils/src/perf_span.rs +++ b/libs/tracing-utils/src/perf_span.rs @@ -49,7 +49,7 @@ impl PerfSpan { } } - pub fn enter(&self) -> PerfSpanEntered { + pub fn enter(&self) -> PerfSpanEntered<'_> { if let Some(ref id) = self.inner.id() { self.dispatch.enter(id); } diff --git a/libs/walproposer/src/api_bindings.rs b/libs/walproposer/src/api_bindings.rs index 825a137d0f..c3be1e1dae 100644 --- a/libs/walproposer/src/api_bindings.rs +++ b/libs/walproposer/src/api_bindings.rs @@ -429,9 +429,11 @@ pub fn empty_shmem() -> crate::bindings::WalproposerShmemState { }; let empty_wal_rate_limiter = crate::bindings::WalRateLimiter { + effective_max_wal_bytes_per_second: crate::bindings::pg_atomic_uint32 { value: 0 }, should_limit: crate::bindings::pg_atomic_uint32 { value: 0 }, sent_bytes: 0, - last_recorded_time_us: crate::bindings::pg_atomic_uint64 { value: 0 }, + batch_start_time_us: crate::bindings::pg_atomic_uint64 { value: 0 }, + batch_end_time_us: crate::bindings::pg_atomic_uint64 { value: 0 }, }; crate::bindings::WalproposerShmemState { diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 77ed56bd77..86a918b2e0 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -727,7 +727,7 @@ fn start_pageserver( disk_usage_eviction_state, 
deletion_queue.new_client(), secondary_controller, - feature_resolver, + feature_resolver.clone(), ) .context("Failed to initialize router state")?, ); @@ -853,14 +853,14 @@ fn start_pageserver( } else { None }, + feature_resolver.clone(), ); - // Spawn a Pageserver gRPC server task. It will spawn separate tasks for - // each stream/request. + // Spawn a Pageserver gRPC server task. It will spawn separate tasks for each request/stream. + // It uses a separate compute request Tokio runtime (COMPUTE_REQUEST_RUNTIME). // - // TODO: this uses a separate Tokio runtime for the page service. If we want - // other gRPC services, they will need their own port and runtime. Is this - // necessary? + // NB: this port is exposed to computes. It should only provide services that we're okay with + // computes accessing. Internal services should use a separate port. let mut page_service_grpc = None; if let Some(grpc_listener) = grpc_listener { page_service_grpc = Some(GrpcPageServiceHandler::spawn( diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index db01043413..5993a1e319 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -486,6 +486,8 @@ async fn build_timeline_info_common( *timeline.get_applied_gc_cutoff_lsn(), ); + let (rel_size_migration, rel_size_migrated_at) = timeline.get_rel_size_v2_status(); + let info = TimelineInfo { tenant_id: timeline.tenant_shard_id, timeline_id: timeline.timeline_id, @@ -517,7 +519,8 @@ async fn build_timeline_info_common( state, is_archived: Some(is_archived), - rel_size_migration: Some(timeline.get_rel_size_v2_status()), + rel_size_migration: Some(rel_size_migration), + rel_size_migrated_at, is_invisible: Some(is_invisible), walreceiver_status, @@ -941,9 +944,16 @@ async fn timeline_patch_index_part_handler( active_timeline_of_active_tenant(&state.tenant_manager, tenant_shard_id, timeline_id) .await?; + if request_data.rel_size_migration.is_none() && request_data.rel_size_migrated_at.is_some() + { + return Err(ApiError::BadRequest(anyhow!( + "updating rel_size_migrated_at without rel_size_migration is not allowed" + ))); + } + if let Some(rel_size_migration) = request_data.rel_size_migration { timeline - .update_rel_size_v2_status(rel_size_migration) + .update_rel_size_v2_status(rel_size_migration, request_data.rel_size_migrated_at) .map_err(ApiError::InternalServerError)?; } @@ -2006,6 +2016,10 @@ async fn put_tenant_location_config_handler( let state = get_state(&request); let conf = state.conf; + fail::fail_point!("put-location-conf-handler", |_| { + Err(ApiError::ResourceUnavailable("failpoint".into())) + }); + // The `Detached` state is special, it doesn't upsert a tenant, it removes // its local disk content and drops it from memory. 
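
The new `put-location-conf-handler` failpoint above lets tests simulate a temporarily unavailable location-config endpoint (the handler returns `ResourceUnavailable`). A hedged sketch of toggling it with the standard `fail` crate API; Neon's test suite typically drives failpoints through the pageserver management API instead:

```rust
// Activate the failpoint: the handler now fails with ResourceUnavailable.
fail::cfg("put-location-conf-handler", "return").unwrap();

// ... issue the location_config PUT request under test and assert the error ...

// Deactivate it again.
fail::cfg("put-location-conf-handler", "off").unwrap();
```
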
if let LocationConfigMode::Detached = request_data.config.mode { diff --git a/pageserver/src/import_datadir.rs b/pageserver/src/import_datadir.rs index 409cc2e3c5..5b674adbb3 100644 --- a/pageserver/src/import_datadir.rs +++ b/pageserver/src/import_datadir.rs @@ -57,7 +57,7 @@ pub async fn import_timeline_from_postgres_datadir( // TODO this shoud be start_lsn, which is not necessarily equal to end_lsn (aka lsn) // Then fishing out pg_control would be unnecessary - let mut modification = tline.begin_modification(pgdata_lsn); + let mut modification = tline.begin_modification_for_import(pgdata_lsn); modification.init_empty()?; // Import all but pg_wal @@ -309,7 +309,7 @@ async fn import_wal( waldecoder.feed_bytes(&buf); let mut nrecords = 0; - let mut modification = tline.begin_modification(last_lsn); + let mut modification = tline.begin_modification_for_import(last_lsn); while last_lsn <= endpoint { if let Some((lsn, recdata)) = waldecoder.poll_decode()? { let interpreted = InterpretedWalRecord::from_bytes_filtered( @@ -357,7 +357,7 @@ pub async fn import_basebackup_from_tar( ctx: &RequestContext, ) -> Result<()> { info!("importing base at {base_lsn}"); - let mut modification = tline.begin_modification(base_lsn); + let mut modification = tline.begin_modification_for_import(base_lsn); modification.init_empty()?; let mut pg_control: Option = None; @@ -457,7 +457,7 @@ pub async fn import_wal_from_tar( waldecoder.feed_bytes(&bytes[offset..]); - let mut modification = tline.begin_modification(last_lsn); + let mut modification = tline.begin_modification_for_import(last_lsn); while last_lsn <= end_lsn { if let Some((lsn, recdata)) = waldecoder.poll_decode()? { let interpreted = InterpretedWalRecord::from_bytes_filtered( diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 26a23da66f..a0998a7598 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -68,6 +68,7 @@ use crate::config::PageServerConf; use crate::context::{ DownloadBehavior, PerfInstrumentFutureExt, RequestContext, RequestContextBuilder, }; +use crate::feature_resolver::FeatureResolver; use crate::metrics::{ self, COMPUTE_COMMANDS_COUNTERS, ComputeCommandKind, GetPageBatchBreakReason, LIVE_CONNECTIONS, MISROUTED_PAGESTREAM_REQUESTS, PAGESTREAM_HANDLER_RESULTS_TOTAL, SmgrOpTimer, TimelineMetrics, @@ -139,6 +140,7 @@ pub fn spawn( perf_trace_dispatch: Option, tcp_listener: tokio::net::TcpListener, tls_config: Option>, + feature_resolver: FeatureResolver, ) -> Listener { let cancel = CancellationToken::new(); let libpq_ctx = RequestContext::todo_child( @@ -160,6 +162,7 @@ pub fn spawn( conf.pg_auth_type, tls_config, conf.page_service_pipelining.clone(), + feature_resolver, libpq_ctx, cancel.clone(), ) @@ -218,6 +221,7 @@ pub async fn libpq_listener_main( auth_type: AuthType, tls_config: Option>, pipelining_config: PageServicePipeliningConfig, + feature_resolver: FeatureResolver, listener_ctx: RequestContext, listener_cancel: CancellationToken, ) -> Connections { @@ -261,6 +265,7 @@ pub async fn libpq_listener_main( auth_type, tls_config.clone(), pipelining_config.clone(), + feature_resolver.clone(), connection_ctx, connections_cancel.child_token(), gate_guard, @@ -303,6 +308,7 @@ async fn page_service_conn_main( auth_type: AuthType, tls_config: Option>, pipelining_config: PageServicePipeliningConfig, + feature_resolver: FeatureResolver, connection_ctx: RequestContext, cancel: CancellationToken, gate_guard: GateGuard, @@ -370,6 +376,7 @@ async fn page_service_conn_main( 
perf_span_fields, connection_ctx, cancel.clone(), + feature_resolver.clone(), gate_guard, ); let pgbackend = @@ -421,6 +428,8 @@ struct PageServerHandler { pipelining_config: PageServicePipeliningConfig, get_vectored_concurrent_io: GetVectoredConcurrentIo, + feature_resolver: FeatureResolver, + gate_guard: GateGuard, } @@ -535,6 +544,7 @@ impl timeline::handle::TenantManager for TenantManagerWrappe match resolved { ShardResolveResult::Found(tenant_shard) => break tenant_shard, ShardResolveResult::NotFound => { + MISROUTED_PAGESTREAM_REQUESTS.inc(); return Err(GetActiveTimelineError::Tenant( GetActiveTenantError::NotFound(GetTenantError::NotFound(*tenant_id)), )); @@ -586,6 +596,15 @@ impl timeline::handle::TenantManager for TenantManagerWrappe } } +/// Whether to hold the applied GC cutoff guard when processing GetPage requests. +/// This is determined once at the start of pagestream subprotocol handling based on +/// feature flags, configuration, and test conditions. +#[derive(Debug, Clone, Copy)] +enum HoldAppliedGcCutoffGuard { + Yes, + No, +} + #[derive(thiserror::Error, Debug)] enum PageStreamError { /// We encountered an error that should prompt the client to reconnect: @@ -729,6 +748,7 @@ enum BatchedFeMessage { GetPage { span: Span, shard: WeakHandle, + applied_gc_cutoff_guard: Option>, pages: SmallVec<[BatchedGetPageRequest; 1]>, batch_break_reason: GetPageBatchBreakReason, }, @@ -908,6 +928,7 @@ impl PageServerHandler { perf_span_fields: ConnectionPerfSpanFields, connection_ctx: RequestContext, cancel: CancellationToken, + feature_resolver: FeatureResolver, gate_guard: GateGuard, ) -> Self { PageServerHandler { @@ -919,6 +940,7 @@ impl PageServerHandler { cancel, pipelining_config, get_vectored_concurrent_io, + feature_resolver, gate_guard, } } @@ -958,6 +980,7 @@ impl PageServerHandler { ctx: &RequestContext, protocol_version: PagestreamProtocolVersion, parent_span: Span, + hold_gc_cutoff_guard: HoldAppliedGcCutoffGuard, ) -> Result, QueryError> where IO: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, @@ -1195,19 +1218,27 @@ impl PageServerHandler { }) .await?; + let applied_gc_cutoff_guard = shard.get_applied_gc_cutoff_lsn(); // hold guard // We're holding the Handle let effective_lsn = match Self::effective_request_lsn( &shard, shard.get_last_record_lsn(), req.hdr.request_lsn, req.hdr.not_modified_since, - &shard.get_applied_gc_cutoff_lsn(), + &applied_gc_cutoff_guard, ) { Ok(lsn) => lsn, Err(e) => { return respond_error!(span, e); } }; + let applied_gc_cutoff_guard = match hold_gc_cutoff_guard { + HoldAppliedGcCutoffGuard::Yes => Some(applied_gc_cutoff_guard), + HoldAppliedGcCutoffGuard::No => { + drop(applied_gc_cutoff_guard); + None + } + }; let batch_wait_ctx = if ctx.has_perf_span() { Some( @@ -1228,6 +1259,7 @@ impl PageServerHandler { BatchedFeMessage::GetPage { span, shard: shard.downgrade(), + applied_gc_cutoff_guard, pages: smallvec![BatchedGetPageRequest { req, timer, @@ -1328,13 +1360,28 @@ impl PageServerHandler { match (eligible_batch, this_msg) { ( BatchedFeMessage::GetPage { - pages: accum_pages, .. + pages: accum_pages, + applied_gc_cutoff_guard: accum_applied_gc_cutoff_guard, + .. }, BatchedFeMessage::GetPage { - pages: this_pages, .. + pages: this_pages, + applied_gc_cutoff_guard: this_applied_gc_cutoff_guard, + .. 
}, ) => { accum_pages.extend(this_pages); + // the minimum of the two guards will keep data for both alive + match (&accum_applied_gc_cutoff_guard, this_applied_gc_cutoff_guard) { + (None, None) => (), + (None, Some(this)) => *accum_applied_gc_cutoff_guard = Some(this), + (Some(_), None) => (), + (Some(accum), Some(this)) => { + if **accum > *this { + *accum_applied_gc_cutoff_guard = Some(this); + } + } + }; Ok(()) } #[cfg(feature = "testing")] @@ -1649,6 +1696,7 @@ impl PageServerHandler { BatchedFeMessage::GetPage { span, shard, + applied_gc_cutoff_guard, pages, batch_break_reason, } => { @@ -1668,6 +1716,7 @@ impl PageServerHandler { .instrument(span.clone()) .await; assert_eq!(res.len(), npages); + drop(applied_gc_cutoff_guard); res }, span, @@ -1749,7 +1798,7 @@ impl PageServerHandler { /// Coding discipline within this function: all interaction with the `pgb` connection /// needs to be sensitive to connection shutdown, currently signalled via [`Self::cancel`]. /// This is so that we can shutdown page_service quickly. - #[instrument(skip_all)] + #[instrument(skip_all, fields(hold_gc_cutoff_guard))] async fn handle_pagerequests( &mut self, pgb: &mut PostgresBackend, @@ -1795,6 +1844,30 @@ impl PageServerHandler { .take() .expect("implementation error: timeline_handles should not be locked"); + // Evaluate the expensive feature resolver check once per pagestream subprotocol handling + // instead of once per GetPage request. This is shared between pipelined and serial paths. + let hold_gc_cutoff_guard = if cfg!(test) || cfg!(feature = "testing") { + HoldAppliedGcCutoffGuard::Yes + } else { + // Use the global feature resolver with the tenant ID directly, avoiding the need + // to get a timeline/shard which might not be available on this pageserver node. + let empty_properties = std::collections::HashMap::new(); + match self.feature_resolver.evaluate_boolean( + "page-service-getpage-hold-applied-gc-cutoff-guard", + tenant_id, + &empty_properties, + ) { + Ok(()) => HoldAppliedGcCutoffGuard::Yes, + Err(_) => HoldAppliedGcCutoffGuard::No, + } + }; + // record it in the span of handle_pagerequests so that both the request_span + // and the pipeline implementation spans contains the field. 
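
When two GetPage requests are merged into one batch, the code above keeps whichever applied-GC-cutoff guard holds the smaller LSN, since the older cutoff keeps every page version needed by both requests readable. A condensed sketch of that merge rule, with plain comparable values standing in for the real `RcuReadGuard<Lsn>`:

```rust
// `accum` is the guard already held by the batch, `this` the guard carried by
// the request being merged in; None means the guard was not retained.
fn merge_min_guard<T: PartialOrd>(accum: &mut Option<T>, this: Option<T>) {
    if let Some(t) = this {
        let replace = match accum {
            None => true,
            Some(a) => *a > t, // replace only if the incoming cutoff is older
        };
        if replace {
            *accum = Some(t);
        }
    }
}
```
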
+ Span::current().record( + "hold_gc_cutoff_guard", + tracing::field::debug(&hold_gc_cutoff_guard), + ); + let request_span = info_span!("request"); let ((pgb_reader, timeline_handles), result) = match self.pipelining_config.clone() { PageServicePipeliningConfig::Pipelined(pipelining_config) => { @@ -1808,6 +1881,7 @@ impl PageServerHandler { pipelining_config, protocol_version, io_concurrency, + hold_gc_cutoff_guard, &ctx, ) .await @@ -1822,6 +1896,7 @@ impl PageServerHandler { request_span, protocol_version, io_concurrency, + hold_gc_cutoff_guard, &ctx, ) .await @@ -1850,6 +1925,7 @@ impl PageServerHandler { request_span: Span, protocol_version: PagestreamProtocolVersion, io_concurrency: IoConcurrency, + hold_gc_cutoff_guard: HoldAppliedGcCutoffGuard, ctx: &RequestContext, ) -> ( (PostgresBackendReader, TimelineHandles), @@ -1871,6 +1947,7 @@ impl PageServerHandler { ctx, protocol_version, request_span.clone(), + hold_gc_cutoff_guard, ) .await; let msg = match msg { @@ -1918,6 +1995,7 @@ impl PageServerHandler { pipelining_config: PageServicePipeliningConfigPipelined, protocol_version: PagestreamProtocolVersion, io_concurrency: IoConcurrency, + hold_gc_cutoff_guard: HoldAppliedGcCutoffGuard, ctx: &RequestContext, ) -> ( (PostgresBackendReader, TimelineHandles), @@ -2021,6 +2099,7 @@ impl PageServerHandler { &ctx, protocol_version, request_span.clone(), + hold_gc_cutoff_guard, ) .await; let Some(read_res) = read_res.transpose() else { @@ -2067,6 +2146,7 @@ impl PageServerHandler { pages, span: _, shard: _, + applied_gc_cutoff_guard: _, batch_break_reason: _, } = &mut batch { @@ -3428,8 +3508,6 @@ impl GrpcPageServiceHandler { /// NB: errors returned from here are intercepted in get_pages(), and may be converted to a /// GetPageResponse with an appropriate status code to avoid terminating the stream. /// - /// TODO: verify that the requested pages belong to this shard. - /// /// TODO: get_vectored() currently enforces a batch limit of 32. Postgres will typically send /// batches up to effective_io_concurrency = 100. Either we have to accept large batches, or /// split them up in the client or server. @@ -3455,6 +3533,19 @@ impl GrpcPageServiceHandler { lsn = %req.read_lsn, ); + for &blkno in &req.block_numbers { + let shard = timeline.get_shard_identity(); + let key = rel_block_to_key(req.rel, blkno); + if !shard.is_key_local(&key) { + return Err(tonic::Status::invalid_argument(format!( + "block {blkno} of relation {} requested on wrong shard {} (is on {})", + req.rel, + timeline.get_shard_index(), + ShardIndex::new(shard.get_shard_number(&key), shard.count), + ))); + } + } + let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); // hold guard let effective_lsn = PageServerHandler::effective_request_lsn( &timeline, diff --git a/pageserver/src/pgdatadir_mapping.rs b/pageserver/src/pgdatadir_mapping.rs index c9f3184188..cedf77fb37 100644 --- a/pageserver/src/pgdatadir_mapping.rs +++ b/pageserver/src/pgdatadir_mapping.rs @@ -6,7 +6,7 @@ //! walingest.rs handles a few things like implicit relation creation and extension. //! Clarify that) //! 
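
The `pgdatadir_mapping.rs` hunks below split every reldir read into an explicit v1/v2 dispatch keyed on the migration status and on the new `rel_size_migrated_at` LSN. A condensed sketch of the decision table, with names simplified and `Lsn` reduced to `u64`:

```rust
enum RelSizeMigration { Legacy, Migrating, Migrated }

/// Which keyspace serves a read at `read_lsn`, given the persisted migration
/// status and the LSN at which the v2 keyspace was initialized.
/// (Condensed restatement of get_rel_exists_in_reldir / list_rels below.)
fn reldir_read_path(status: RelSizeMigration, migrated_at: Option<u64>, read_lsn: u64) -> &'static str {
    use RelSizeMigration::*;
    match status {
        Legacy => "v1 only",
        // Reads below the migration LSN predate the v2 keyspace and must still
        // be served from v1, even after the migration has started or finished.
        Migrating | Migrated if read_lsn < migrated_at.unwrap_or(0) => "v1 only",
        // While migrating, v1 stays authoritative; v2 is read as well and any
        // mismatch is only logged as a warning.
        Migrating => "v1 authoritative, v2 cross-checked",
        Migrated => "v2 only",
    }
}
```
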
-use std::collections::{HashMap, HashSet, hash_map}; +use std::collections::{BTreeSet, HashMap, HashSet, hash_map}; use std::ops::{ControlFlow, Range}; use std::sync::Arc; @@ -227,6 +227,25 @@ impl Timeline { pending_nblocks: 0, pending_directory_entries: Vec::new(), pending_metadata_bytes: 0, + is_importing_pgdata: false, + lsn, + } + } + + pub fn begin_modification_for_import(&self, lsn: Lsn) -> DatadirModification + where + Self: Sized, + { + DatadirModification { + tline: self, + pending_lsns: Vec::new(), + pending_metadata_pages: HashMap::new(), + pending_data_batch: None, + pending_deletions: Vec::new(), + pending_nblocks: 0, + pending_directory_entries: Vec::new(), + pending_metadata_bytes: 0, + is_importing_pgdata: true, lsn, } } @@ -596,6 +615,50 @@ impl Timeline { self.get_rel_exists_in_reldir(tag, version, None, ctx).await } + async fn get_rel_exists_in_reldir_v1( + &self, + tag: RelTag, + version: Version<'_>, + deserialized_reldir_v1: Option<(Key, &RelDirectory)>, + ctx: &RequestContext, + ) -> Result { + let key = rel_dir_to_key(tag.spcnode, tag.dbnode); + if let Some((cached_key, dir)) = deserialized_reldir_v1 { + if cached_key == key { + return Ok(dir.rels.contains(&(tag.relnode, tag.forknum))); + } else if cfg!(test) || cfg!(feature = "testing") { + panic!("cached reldir key mismatch: {cached_key} != {key}"); + } else { + warn!("cached reldir key mismatch: {cached_key} != {key}"); + } + // Fallback to reading the directory from the datadir. + } + + let buf = version.get(self, key, ctx).await?; + + let dir = RelDirectory::des(&buf)?; + Ok(dir.rels.contains(&(tag.relnode, tag.forknum))) + } + + async fn get_rel_exists_in_reldir_v2( + &self, + tag: RelTag, + version: Version<'_>, + ctx: &RequestContext, + ) -> Result { + let key = rel_tag_sparse_key(tag.spcnode, tag.dbnode, tag.relnode, tag.forknum); + let buf = RelDirExists::decode_option(version.sparse_get(self, key, ctx).await?).map_err( + |_| { + PageReconstructError::Other(anyhow::anyhow!( + "invalid reldir key: decode failed, {}", + key + )) + }, + )?; + let exists_v2 = buf == RelDirExists::Exists; + Ok(exists_v2) + } + /// Does the relation exist? With a cached deserialized `RelDirectory`. /// /// There are some cases where the caller loops across all relations. In that specific case, @@ -627,45 +690,134 @@ impl Timeline { return Ok(false); } - // Read path: first read the new reldir keyspace. Early return if the relation exists. - // Otherwise, read the old reldir keyspace. - // TODO: if IndexPart::rel_size_migration is `Migrated`, we only need to read from v2. + let (v2_status, migrated_lsn) = self.get_rel_size_v2_status(); - if let RelSizeMigration::Migrated | RelSizeMigration::Migrating = - self.get_rel_size_v2_status() - { - // fetch directory listing (new) - let key = rel_tag_sparse_key(tag.spcnode, tag.dbnode, tag.relnode, tag.forknum); - let buf = RelDirExists::decode_option(version.sparse_get(self, key, ctx).await?) - .map_err(|_| PageReconstructError::Other(anyhow::anyhow!("invalid reldir key")))?; - let exists_v2 = buf == RelDirExists::Exists; - // Fast path: if the relation exists in the new format, return true. - // TODO: we should have a verification mode that checks both keyspaces - // to ensure the relation only exists in one of them. 
- if exists_v2 { - return Ok(true); + match v2_status { + RelSizeMigration::Legacy => { + let v1_exists = self + .get_rel_exists_in_reldir_v1(tag, version, deserialized_reldir_v1, ctx) + .await?; + Ok(v1_exists) + } + RelSizeMigration::Migrating | RelSizeMigration::Migrated + if version.get_lsn() < migrated_lsn.unwrap_or(Lsn(0)) => + { + // For requests below the migrated LSN, we still use the v1 read path. + let v1_exists = self + .get_rel_exists_in_reldir_v1(tag, version, deserialized_reldir_v1, ctx) + .await?; + Ok(v1_exists) + } + RelSizeMigration::Migrating => { + let v1_exists = self + .get_rel_exists_in_reldir_v1(tag, version, deserialized_reldir_v1, ctx) + .await?; + let v2_exists_res = self.get_rel_exists_in_reldir_v2(tag, version, ctx).await; + match v2_exists_res { + Ok(v2_exists) if v1_exists == v2_exists => {} + Ok(v2_exists) => { + tracing::warn!( + "inconsistent v1/v2 reldir keyspace for rel {}: v1_exists={}, v2_exists={}", + tag, + v1_exists, + v2_exists + ); + } + Err(e) => { + tracing::warn!("failed to get rel exists in v2: {e}"); + } + } + Ok(v1_exists) + } + RelSizeMigration::Migrated => { + let v2_exists = self.get_rel_exists_in_reldir_v2(tag, version, ctx).await?; + Ok(v2_exists) } } + } - // fetch directory listing (old) - - let key = rel_dir_to_key(tag.spcnode, tag.dbnode); - - if let Some((cached_key, dir)) = deserialized_reldir_v1 { - if cached_key == key { - return Ok(dir.rels.contains(&(tag.relnode, tag.forknum))); - } else if cfg!(test) || cfg!(feature = "testing") { - panic!("cached reldir key mismatch: {cached_key} != {key}"); - } else { - warn!("cached reldir key mismatch: {cached_key} != {key}"); - } - // Fallback to reading the directory from the datadir. - } + async fn list_rels_v1( + &self, + spcnode: Oid, + dbnode: Oid, + version: Version<'_>, + ctx: &RequestContext, + ) -> Result, PageReconstructError> { + let key = rel_dir_to_key(spcnode, dbnode); let buf = version.get(self, key, ctx).await?; - let dir = RelDirectory::des(&buf)?; - let exists_v1 = dir.rels.contains(&(tag.relnode, tag.forknum)); - Ok(exists_v1) + let rels_v1: HashSet = + HashSet::from_iter(dir.rels.iter().map(|(relnode, forknum)| RelTag { + spcnode, + dbnode, + relnode: *relnode, + forknum: *forknum, + })); + Ok(rels_v1) + } + + async fn list_rels_v2( + &self, + spcnode: Oid, + dbnode: Oid, + version: Version<'_>, + ctx: &RequestContext, + ) -> Result, PageReconstructError> { + let key_range = rel_tag_sparse_key_range(spcnode, dbnode); + let io_concurrency = IoConcurrency::spawn_from_conf( + self.conf.get_vectored_concurrent_io, + self.gate + .enter() + .map_err(|_| PageReconstructError::Cancelled)?, + ); + let results = self + .scan( + KeySpace::single(key_range), + version.get_lsn(), + ctx, + io_concurrency, + ) + .await?; + let mut rels = HashSet::new(); + for (key, val) in results { + let val = RelDirExists::decode(&val?).map_err(|_| { + PageReconstructError::Other(anyhow::anyhow!( + "invalid reldir key: decode failed, {}", + key + )) + })?; + if key.field6 != 1 { + return Err(PageReconstructError::Other(anyhow::anyhow!( + "invalid reldir key: field6 != 1, {}", + key + ))); + } + if key.field2 != spcnode { + return Err(PageReconstructError::Other(anyhow::anyhow!( + "invalid reldir key: field2 != spcnode, {}", + key + ))); + } + if key.field3 != dbnode { + return Err(PageReconstructError::Other(anyhow::anyhow!( + "invalid reldir key: field3 != dbnode, {}", + key + ))); + } + let tag = RelTag { + spcnode, + dbnode, + relnode: key.field4, + forknum: key.field5, + }; + if val == 
RelDirExists::Removed { + debug_assert!(!rels.contains(&tag), "removed reltag in v2"); + continue; + } + let did_not_contain = rels.insert(tag); + debug_assert!(did_not_contain, "duplicate reltag in v2"); + } + Ok(rels) } /// Get a list of all existing relations in given tablespace and database. @@ -683,60 +835,45 @@ impl Timeline { version: Version<'_>, ctx: &RequestContext, ) -> Result, PageReconstructError> { - // fetch directory listing (old) - let key = rel_dir_to_key(spcnode, dbnode); - let buf = version.get(self, key, ctx).await?; + let (v2_status, migrated_lsn) = self.get_rel_size_v2_status(); - let dir = RelDirectory::des(&buf)?; - let rels_v1: HashSet = - HashSet::from_iter(dir.rels.iter().map(|(relnode, forknum)| RelTag { - spcnode, - dbnode, - relnode: *relnode, - forknum: *forknum, - })); - - if let RelSizeMigration::Legacy = self.get_rel_size_v2_status() { - return Ok(rels_v1); - } - - // scan directory listing (new), merge with the old results - let key_range = rel_tag_sparse_key_range(spcnode, dbnode); - let io_concurrency = IoConcurrency::spawn_from_conf( - self.conf.get_vectored_concurrent_io, - self.gate - .enter() - .map_err(|_| PageReconstructError::Cancelled)?, - ); - let results = self - .scan( - KeySpace::single(key_range), - version.get_lsn(), - ctx, - io_concurrency, - ) - .await?; - let mut rels = rels_v1; - for (key, val) in results { - let val = RelDirExists::decode(&val?) - .map_err(|_| PageReconstructError::Other(anyhow::anyhow!("invalid reldir key")))?; - assert_eq!(key.field6, 1); - assert_eq!(key.field2, spcnode); - assert_eq!(key.field3, dbnode); - let tag = RelTag { - spcnode, - dbnode, - relnode: key.field4, - forknum: key.field5, - }; - if val == RelDirExists::Removed { - debug_assert!(!rels.contains(&tag), "removed reltag in v2"); - continue; + match v2_status { + RelSizeMigration::Legacy => { + let rels_v1 = self.list_rels_v1(spcnode, dbnode, version, ctx).await?; + Ok(rels_v1) + } + RelSizeMigration::Migrating | RelSizeMigration::Migrated + if version.get_lsn() < migrated_lsn.unwrap_or(Lsn(0)) => + { + // For requests below the migrated LSN, we still use the v1 read path. + let rels_v1 = self.list_rels_v1(spcnode, dbnode, version, ctx).await?; + Ok(rels_v1) + } + RelSizeMigration::Migrating => { + let rels_v1 = self.list_rels_v1(spcnode, dbnode, version, ctx).await?; + let rels_v2_res = self.list_rels_v2(spcnode, dbnode, version, ctx).await; + match rels_v2_res { + Ok(rels_v2) if rels_v1 == rels_v2 => {} + Ok(rels_v2) => { + tracing::warn!( + "inconsistent v1/v2 reldir keyspace for db {} {}: v1_rels.len()={}, v2_rels.len()={}", + spcnode, + dbnode, + rels_v1.len(), + rels_v2.len() + ); + } + Err(e) => { + tracing::warn!("failed to list rels in v2: {e}"); + } + } + Ok(rels_v1) + } + RelSizeMigration::Migrated => { + let rels_v2 = self.list_rels_v2(spcnode, dbnode, version, ctx).await?; + Ok(rels_v2) } - let did_not_contain = rels.insert(tag); - debug_assert!(did_not_contain, "duplicate reltag in v2"); } - Ok(rels) } /// Get the whole SLRU segment @@ -1258,10 +1395,10 @@ impl Timeline { let mut dbdir_cnt = 0; let mut rel_cnt = 0; - for (spcnode, dbnode) in dbdir.dbdirs.keys() { + for &(spcnode, dbnode) in dbdir.dbdirs.keys() { dbdir_cnt += 1; for rel in self - .list_rels(*spcnode, *dbnode, Version::at(lsn), ctx) + .list_rels(spcnode, dbnode, Version::at(lsn), ctx) .await? { rel_cnt += 1; @@ -1566,6 +1703,9 @@ pub struct DatadirModification<'a> { /// An **approximation** of how many metadata bytes will be written to the EphemeralFile. 
pending_metadata_bytes: usize, + + /// Whether we are importing a pgdata directory. + is_importing_pgdata: bool, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -1578,6 +1718,14 @@ pub enum MetricsUpdate { Sub(u64), } +/// Controls the behavior of the reldir keyspace. +pub struct RelDirMode { + // Whether we can read the v2 keyspace or not. + current_status: RelSizeMigration, + // Whether we should initialize the v2 keyspace or not. + initialize: bool, +} + impl DatadirModification<'_> { // When a DatadirModification is committed, we do a monolithic serialization of all its contents. WAL records can // contain multiple pages, so the pageserver's record-based batch size isn't sufficient to bound this allocation: we @@ -1933,30 +2081,49 @@ impl DatadirModification<'_> { } /// Returns `true` if the rel_size_v2 write path is enabled. If it is the first time that - /// we enable it, we also need to persist it in `index_part.json`. - pub fn maybe_enable_rel_size_v2(&mut self) -> anyhow::Result { - let status = self.tline.get_rel_size_v2_status(); + /// we enable it, we also need to persist it in `index_part.json` (initialize is true). + /// + /// As this function is only used on the write path, we do not need to read the migrated_at + /// field. + pub fn maybe_enable_rel_size_v2(&mut self, is_create: bool) -> anyhow::Result { + // TODO: define the behavior of the tenant-level config flag and use feature flag to enable this feature + + let (status, _) = self.tline.get_rel_size_v2_status(); let config = self.tline.get_rel_size_v2_enabled(); match (config, status) { (false, RelSizeMigration::Legacy) => { // tenant config didn't enable it and we didn't write any reldir_v2 key yet - Ok(false) + Ok(RelDirMode { + current_status: RelSizeMigration::Legacy, + initialize: false, + }) } - (false, RelSizeMigration::Migrating | RelSizeMigration::Migrated) => { + (false, status @ RelSizeMigration::Migrating | status @ RelSizeMigration::Migrated) => { // index_part already persisted that the timeline has enabled rel_size_v2 - Ok(true) + Ok(RelDirMode { + current_status: status, + initialize: false, + }) } (true, RelSizeMigration::Legacy) => { // The first time we enable it, we need to persist it in `index_part.json` - self.tline - .update_rel_size_v2_status(RelSizeMigration::Migrating)?; - tracing::info!("enabled rel_size_v2"); - Ok(true) + // The caller should update the reldir status once the initialization is done. + // + // Only initialize the v2 keyspace on new relation creation. No initialization + // during `timeline_create` (TODO: fix this, we should allow, but currently it + // hits consistency issues). 
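
On the write path, `maybe_enable_rel_size_v2` now returns a `RelDirMode` instead of a bare boolean. The table it implements is easier to see in condensed form; this is a sketch mirroring the match shown here, reusing the `RelDirMode` struct defined above and the `RelSizeMigration` enum:

```rust
fn rel_dir_mode(config_enabled: bool, status: RelSizeMigration, is_create: bool, importing: bool) -> RelDirMode {
    match (config_enabled, status) {
        // Tenant config off: stay on whatever the persisted status says.
        (false, s) => RelDirMode { current_status: s, initialize: false },
        // Config on but still Legacy: the v2 keyspace is initialized lazily, and
        // only when a new relation is created outside of a pgdata import.
        (true, RelSizeMigration::Legacy) => RelDirMode {
            current_status: RelSizeMigration::Legacy,
            initialize: is_create && !importing,
        },
        // Already migrating or migrated: nothing left to initialize.
        (true, s) => RelDirMode { current_status: s, initialize: false },
    }
}
```
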
+ Ok(RelDirMode { + current_status: RelSizeMigration::Legacy, + initialize: is_create && !self.is_importing_pgdata, + }) } - (true, RelSizeMigration::Migrating | RelSizeMigration::Migrated) => { + (true, status @ RelSizeMigration::Migrating | status @ RelSizeMigration::Migrated) => { // index_part already persisted that the timeline has enabled rel_size_v2 // and we don't need to do anything - Ok(true) + Ok(RelDirMode { + current_status: status, + initialize: false, + }) } } } @@ -1969,8 +2136,8 @@ impl DatadirModification<'_> { img: Bytes, ctx: &RequestContext, ) -> Result<(), WalIngestError> { - let v2_enabled = self - .maybe_enable_rel_size_v2() + let v2_mode = self + .maybe_enable_rel_size_v2(false) .map_err(WalIngestErrorKind::MaybeRelSizeV2Error)?; // Add it to the directory (if it doesn't exist already) @@ -1986,17 +2153,19 @@ impl DatadirModification<'_> { self.put(DBDIR_KEY, Value::Image(buf.into())); } if r.is_none() { - // Create RelDirectory - // TODO: if we have fully migrated to v2, no need to create this directory + if v2_mode.current_status != RelSizeMigration::Legacy { + self.pending_directory_entries + .push((DirectoryKind::RelV2, MetricsUpdate::Set(0))); + } + + // Create RelDirectory in v1 keyspace. TODO: if we have fully migrated to v2, no need to create this directory. + // Some code path relies on this directory to be present. We should remove it once we starts to set tenants to + // `RelSizeMigration::Migrated` state (currently we don't, all tenants will have `RelSizeMigration::Migrating`). let buf = RelDirectory::ser(&RelDirectory { rels: HashSet::new(), })?; self.pending_directory_entries .push((DirectoryKind::Rel, MetricsUpdate::Set(0))); - if v2_enabled { - self.pending_directory_entries - .push((DirectoryKind::RelV2, MetricsUpdate::Set(0))); - } self.put( rel_dir_to_key(spcnode, dbnode), Value::Image(Bytes::from(buf)), @@ -2103,6 +2272,109 @@ impl DatadirModification<'_> { Ok(()) } + async fn initialize_rel_size_v2_keyspace( + &mut self, + ctx: &RequestContext, + dbdir: &DbDirectory, + ) -> Result<(), WalIngestError> { + // Copy everything from relv1 to relv2; TODO: check if there's any key in the v2 keyspace, if so, abort. 
+ tracing::info!("initializing rel_size_v2 keyspace"); + let mut rel_cnt = 0; + // relmap_exists (the value of dbdirs hashmap) does not affect the migration: we need to copy things over anyways + for &(spcnode, dbnode) in dbdir.dbdirs.keys() { + let rel_dir_key = rel_dir_to_key(spcnode, dbnode); + let rel_dir = RelDirectory::des(&self.get(rel_dir_key, ctx).await?)?; + for (relnode, forknum) in rel_dir.rels { + let sparse_rel_dir_key = rel_tag_sparse_key(spcnode, dbnode, relnode, forknum); + self.put( + sparse_rel_dir_key, + Value::Image(RelDirExists::Exists.encode()), + ); + tracing::info!( + "migrated rel_size_v2: {}", + RelTag { + spcnode, + dbnode, + relnode, + forknum + } + ); + rel_cnt += 1; + } + } + tracing::info!( + "initialized rel_size_v2 keyspace at lsn {}: migrated {} relations", + self.lsn, + rel_cnt + ); + self.tline + .update_rel_size_v2_status(RelSizeMigration::Migrating, Some(self.lsn)) + .map_err(WalIngestErrorKind::MaybeRelSizeV2Error)?; + Ok::<_, WalIngestError>(()) + } + + async fn put_rel_creation_v1( + &mut self, + rel: RelTag, + dbdir_exists: bool, + ctx: &RequestContext, + ) -> Result<(), WalIngestError> { + // Reldir v1 write path + let rel_dir_key = rel_dir_to_key(rel.spcnode, rel.dbnode); + let mut rel_dir = if !dbdir_exists { + // Create the RelDirectory + RelDirectory::default() + } else { + // reldir already exists, fetch it + RelDirectory::des(&self.get(rel_dir_key, ctx).await?)? + }; + + // Add the new relation to the rel directory entry, and write it back + if !rel_dir.rels.insert((rel.relnode, rel.forknum)) { + Err(WalIngestErrorKind::RelationAlreadyExists(rel))?; + } + if !dbdir_exists { + self.pending_directory_entries + .push((DirectoryKind::Rel, MetricsUpdate::Set(0))) + } + self.pending_directory_entries + .push((DirectoryKind::Rel, MetricsUpdate::Add(1))); + self.put( + rel_dir_key, + Value::Image(Bytes::from(RelDirectory::ser(&rel_dir)?)), + ); + Ok(()) + } + + async fn put_rel_creation_v2( + &mut self, + rel: RelTag, + dbdir_exists: bool, + ctx: &RequestContext, + ) -> Result<(), WalIngestError> { + // Reldir v2 write path + let sparse_rel_dir_key = + rel_tag_sparse_key(rel.spcnode, rel.dbnode, rel.relnode, rel.forknum); + // check if the rel_dir_key exists in v2 + let val = self.sparse_get(sparse_rel_dir_key, ctx).await?; + let val = RelDirExists::decode_option(val) + .map_err(|_| WalIngestErrorKind::InvalidRelDirKey(sparse_rel_dir_key))?; + if val == RelDirExists::Exists { + Err(WalIngestErrorKind::RelationAlreadyExists(rel))?; + } + self.put( + sparse_rel_dir_key, + Value::Image(RelDirExists::Exists.encode()), + ); + if !dbdir_exists { + self.pending_directory_entries + .push((DirectoryKind::RelV2, MetricsUpdate::Set(0))); + } + self.pending_directory_entries + .push((DirectoryKind::RelV2, MetricsUpdate::Add(1))); + Ok(()) + } + /// Create a relation fork. /// /// 'nblocks' is the initial size. @@ -2136,66 +2408,31 @@ impl DatadirModification<'_> { true }; - let rel_dir_key = rel_dir_to_key(rel.spcnode, rel.dbnode); - let mut rel_dir = if !dbdir_exists { - // Create the RelDirectory - RelDirectory::default() - } else { - // reldir already exists, fetch it - RelDirectory::des(&self.get(rel_dir_key, ctx).await?)? 
- }; - - let v2_enabled = self - .maybe_enable_rel_size_v2() + let mut v2_mode = self + .maybe_enable_rel_size_v2(true) .map_err(WalIngestErrorKind::MaybeRelSizeV2Error)?; - if v2_enabled { - if rel_dir.rels.contains(&(rel.relnode, rel.forknum)) { - Err(WalIngestErrorKind::RelationAlreadyExists(rel))?; + if v2_mode.initialize { + if let Err(e) = self.initialize_rel_size_v2_keyspace(ctx, &dbdir).await { + tracing::warn!("error initializing rel_size_v2 keyspace: {}", e); + // TODO: circuit breaker so that it won't retry forever + } else { + v2_mode.current_status = RelSizeMigration::Migrating; } - let sparse_rel_dir_key = - rel_tag_sparse_key(rel.spcnode, rel.dbnode, rel.relnode, rel.forknum); - // check if the rel_dir_key exists in v2 - let val = self.sparse_get(sparse_rel_dir_key, ctx).await?; - let val = RelDirExists::decode_option(val) - .map_err(|_| WalIngestErrorKind::InvalidRelDirKey(sparse_rel_dir_key))?; - if val == RelDirExists::Exists { - Err(WalIngestErrorKind::RelationAlreadyExists(rel))?; + } + + if v2_mode.current_status != RelSizeMigration::Migrated { + self.put_rel_creation_v1(rel, dbdir_exists, ctx).await?; + } + + if v2_mode.current_status != RelSizeMigration::Legacy { + let write_v2_res = self.put_rel_creation_v2(rel, dbdir_exists, ctx).await; + if let Err(e) = write_v2_res { + if v2_mode.current_status == RelSizeMigration::Migrated { + return Err(e); + } + tracing::warn!("error writing rel_size_v2 keyspace: {}", e); } - self.put( - sparse_rel_dir_key, - Value::Image(RelDirExists::Exists.encode()), - ); - if !dbdir_exists { - self.pending_directory_entries - .push((DirectoryKind::Rel, MetricsUpdate::Set(0))); - self.pending_directory_entries - .push((DirectoryKind::RelV2, MetricsUpdate::Set(0))); - // We don't write `rel_dir_key -> rel_dir.rels` back to the storage in the v2 path unless it's the initial creation. - // TODO: if we have fully migrated to v2, no need to create this directory. Otherwise, there - // will be key not found errors if we don't create an empty one for rel_size_v2. 
- self.put( - rel_dir_key, - Value::Image(Bytes::from(RelDirectory::ser(&RelDirectory::default())?)), - ); - } - self.pending_directory_entries - .push((DirectoryKind::RelV2, MetricsUpdate::Add(1))); - } else { - // Add the new relation to the rel directory entry, and write it back - if !rel_dir.rels.insert((rel.relnode, rel.forknum)) { - Err(WalIngestErrorKind::RelationAlreadyExists(rel))?; - } - if !dbdir_exists { - self.pending_directory_entries - .push((DirectoryKind::Rel, MetricsUpdate::Set(0))) - } - self.pending_directory_entries - .push((DirectoryKind::Rel, MetricsUpdate::Add(1))); - self.put( - rel_dir_key, - Value::Image(Bytes::from(RelDirectory::ser(&rel_dir)?)), - ); } // Put size @@ -2270,15 +2507,12 @@ impl DatadirModification<'_> { Ok(()) } - /// Drop some relations - pub(crate) async fn put_rel_drops( + async fn put_rel_drop_v1( &mut self, drop_relations: HashMap<(u32, u32), Vec>, ctx: &RequestContext, - ) -> Result<(), WalIngestError> { - let v2_enabled = self - .maybe_enable_rel_size_v2() - .map_err(WalIngestErrorKind::MaybeRelSizeV2Error)?; + ) -> Result, WalIngestError> { + let mut dropped_rels = BTreeSet::new(); for ((spc_node, db_node), rel_tags) in drop_relations { let dir_key = rel_dir_to_key(spc_node, db_node); let buf = self.get(dir_key, ctx).await?; @@ -2290,25 +2524,8 @@ impl DatadirModification<'_> { self.pending_directory_entries .push((DirectoryKind::Rel, MetricsUpdate::Sub(1))); dirty = true; + dropped_rels.insert(rel_tag); true - } else if v2_enabled { - // The rel is not found in the old reldir key, so we need to check the new sparse keyspace. - // Note that a relation can only exist in one of the two keyspaces (guaranteed by the ingestion - // logic). - let key = - rel_tag_sparse_key(spc_node, db_node, rel_tag.relnode, rel_tag.forknum); - let val = RelDirExists::decode_option(self.sparse_get(key, ctx).await?) - .map_err(|_| WalIngestErrorKind::InvalidKey(key, self.lsn))?; - if val == RelDirExists::Exists { - self.pending_directory_entries - .push((DirectoryKind::RelV2, MetricsUpdate::Sub(1))); - // put tombstone - self.put(key, Value::Image(RelDirExists::Removed.encode())); - // no need to set dirty to true - true - } else { - false - } } else { false }; @@ -2331,7 +2548,67 @@ impl DatadirModification<'_> { self.put(dir_key, Value::Image(Bytes::from(RelDirectory::ser(&dir)?))); } } + Ok(dropped_rels) + } + async fn put_rel_drop_v2( + &mut self, + drop_relations: HashMap<(u32, u32), Vec>, + ctx: &RequestContext, + ) -> Result, WalIngestError> { + let mut dropped_rels = BTreeSet::new(); + for ((spc_node, db_node), rel_tags) in drop_relations { + for rel_tag in rel_tags { + let key = rel_tag_sparse_key(spc_node, db_node, rel_tag.relnode, rel_tag.forknum); + let val = RelDirExists::decode_option(self.sparse_get(key, ctx).await?) 
+ .map_err(|_| WalIngestErrorKind::InvalidKey(key, self.lsn))?; + if val == RelDirExists::Exists { + dropped_rels.insert(rel_tag); + self.pending_directory_entries + .push((DirectoryKind::RelV2, MetricsUpdate::Sub(1))); + // put tombstone + self.put(key, Value::Image(RelDirExists::Removed.encode())); + } + } + } + Ok(dropped_rels) + } + + /// Drop some relations + pub(crate) async fn put_rel_drops( + &mut self, + drop_relations: HashMap<(u32, u32), Vec>, + ctx: &RequestContext, + ) -> Result<(), WalIngestError> { + let v2_mode = self + .maybe_enable_rel_size_v2(false) + .map_err(WalIngestErrorKind::MaybeRelSizeV2Error)?; + match v2_mode.current_status { + RelSizeMigration::Legacy => { + self.put_rel_drop_v1(drop_relations, ctx).await?; + } + RelSizeMigration::Migrating => { + let dropped_rels_v1 = self.put_rel_drop_v1(drop_relations.clone(), ctx).await?; + let dropped_rels_v2_res = self.put_rel_drop_v2(drop_relations, ctx).await; + match dropped_rels_v2_res { + Ok(dropped_rels_v2) => { + if dropped_rels_v1 != dropped_rels_v2 { + tracing::warn!( + "inconsistent v1/v2 rel drop: dropped_rels_v1.len()={}, dropped_rels_v2.len()={}", + dropped_rels_v1.len(), + dropped_rels_v2.len() + ); + } + } + Err(e) => { + tracing::warn!("error dropping rels: {}", e); + } + } + } + RelSizeMigration::Migrated => { + self.put_rel_drop_v2(drop_relations, ctx).await?; + } + } Ok(()) } diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index 4c8856c386..91b717a2e9 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -1205,6 +1205,7 @@ impl TenantShard { idempotency.clone(), index_part.gc_compaction.clone(), index_part.rel_size_migration.clone(), + index_part.rel_size_migrated_at, ctx, )?; let disk_consistent_lsn = timeline.get_disk_consistent_lsn(); @@ -2584,6 +2585,7 @@ impl TenantShard { initdb_lsn, None, None, + None, ctx, ) .await @@ -2913,6 +2915,7 @@ impl TenantShard { initdb_lsn, None, None, + None, ctx, ) .await @@ -4342,6 +4345,7 @@ impl TenantShard { create_idempotency: CreateTimelineIdempotency, gc_compaction_state: Option, rel_size_v2_status: Option, + rel_size_migrated_at: Option, ctx: &RequestContext, ) -> anyhow::Result<(Arc, RequestContext)> { let state = match cause { @@ -4376,6 +4380,7 @@ impl TenantShard { create_idempotency, gc_compaction_state, rel_size_v2_status, + rel_size_migrated_at, self.cancel.child_token(), ); @@ -5085,6 +5090,7 @@ impl TenantShard { src_timeline.pg_version, ); + let (rel_size_v2_status, rel_size_migrated_at) = src_timeline.get_rel_size_v2_status(); let (uninitialized_timeline, _timeline_ctx) = self .prepare_new_timeline( dst_id, @@ -5092,7 +5098,8 @@ impl TenantShard { timeline_create_guard, start_lsn + 1, Some(Arc::clone(src_timeline)), - Some(src_timeline.get_rel_size_v2_status()), + Some(rel_size_v2_status), + rel_size_migrated_at, ctx, ) .await?; @@ -5379,6 +5386,7 @@ impl TenantShard { pgdata_lsn, None, None, + None, ctx, ) .await?; @@ -5462,14 +5470,17 @@ impl TenantShard { start_lsn: Lsn, ancestor: Option>, rel_size_v2_status: Option, + rel_size_migrated_at: Option, ctx: &RequestContext, ) -> anyhow::Result<(UninitializedTimeline<'a>, RequestContext)> { let tenant_shard_id = self.tenant_shard_id; let resources = self.build_timeline_resources(new_timeline_id); - resources - .remote_client - .init_upload_queue_for_empty_remote(new_metadata, rel_size_v2_status.clone())?; + resources.remote_client.init_upload_queue_for_empty_remote( + new_metadata, + rel_size_v2_status.clone(), + rel_size_migrated_at, + )?; let (timeline_struct, 
timeline_ctx) = self .create_timeline_struct( @@ -5482,6 +5493,7 @@ impl TenantShard { create_guard.idempotency.clone(), None, rel_size_v2_status, + rel_size_migrated_at, ctx, ) .context("Failed to create timeline data structure")?; diff --git a/pageserver/src/tenant/remote_timeline_client.rs b/pageserver/src/tenant/remote_timeline_client.rs index fd65000379..6b650beb3f 100644 --- a/pageserver/src/tenant/remote_timeline_client.rs +++ b/pageserver/src/tenant/remote_timeline_client.rs @@ -443,7 +443,8 @@ impl RemoteTimelineClient { pub fn init_upload_queue_for_empty_remote( &self, local_metadata: &TimelineMetadata, - rel_size_v2_status: Option, + rel_size_v2_migration: Option, + rel_size_migrated_at: Option, ) -> anyhow::Result<()> { // Set the maximum number of inprogress tasks to the remote storage concurrency. There's // certainly no point in starting more upload tasks than this. @@ -455,7 +456,8 @@ impl RemoteTimelineClient { let mut upload_queue = self.upload_queue.lock().unwrap(); let initialized_queue = upload_queue.initialize_empty_remote(local_metadata, inprogress_limit)?; - initialized_queue.dirty.rel_size_migration = rel_size_v2_status; + initialized_queue.dirty.rel_size_migration = rel_size_v2_migration; + initialized_queue.dirty.rel_size_migrated_at = rel_size_migrated_at; self.update_remote_physical_size_gauge(None); info!("initialized upload queue as empty"); Ok(()) @@ -994,10 +996,12 @@ impl RemoteTimelineClient { pub(crate) fn schedule_index_upload_for_rel_size_v2_status_update( self: &Arc, rel_size_v2_status: RelSizeMigration, + rel_size_migrated_at: Option, ) -> anyhow::Result<()> { let mut guard = self.upload_queue.lock().unwrap(); let upload_queue = guard.initialized_mut()?; upload_queue.dirty.rel_size_migration = Some(rel_size_v2_status); + upload_queue.dirty.rel_size_migrated_at = rel_size_migrated_at; // TODO: allow this operation to bypass the validation check because we might upload the index part // with no layers but the flag updated. For now, we just modify the index part in memory and the next // upload will include the flag. diff --git a/pageserver/src/tenant/remote_timeline_client/index.rs b/pageserver/src/tenant/remote_timeline_client/index.rs index 6060c42cbb..9531e7f650 100644 --- a/pageserver/src/tenant/remote_timeline_client/index.rs +++ b/pageserver/src/tenant/remote_timeline_client/index.rs @@ -114,6 +114,11 @@ pub struct IndexPart { /// The timestamp when the timeline was marked invisible in synthetic size calculations. #[serde(skip_serializing_if = "Option::is_none", default)] pub(crate) marked_invisible_at: Option, + + /// The LSN at which we started the rel size migration. Accesses below this LSN should be + /// processed with the v1 read path. Usually this LSN should be set together with `rel_size_migration`. + #[serde(skip_serializing_if = "Option::is_none", default)] + pub(crate) rel_size_migrated_at: Option, } #[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] @@ -142,10 +147,12 @@ impl IndexPart { /// - 12: +l2_lsn /// - 13: +gc_compaction /// - 14: +marked_invisible_at - const LATEST_VERSION: usize = 14; + /// - 15: +rel_size_migrated_at + const LATEST_VERSION: usize = 15; // Versions we may see when reading from a bucket. 
- pub const KNOWN_VERSIONS: &'static [usize] = &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]; + pub const KNOWN_VERSIONS: &'static [usize] = + &[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]; pub const FILE_NAME: &'static str = "index_part.json"; @@ -165,6 +172,7 @@ impl IndexPart { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, } } @@ -475,6 +483,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -524,6 +533,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -574,6 +584,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -627,6 +638,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let empty_layers_parsed = IndexPart::from_json_bytes(empty_layers_json.as_bytes()).unwrap(); @@ -675,6 +687,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -726,6 +739,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -782,6 +796,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -843,6 +858,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -905,6 +921,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -972,6 +989,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -1052,6 +1070,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -1133,6 +1152,7 @@ mod tests { l2_lsn: None, gc_compaction: None, marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -1220,6 +1240,7 @@ mod tests { last_completed_lsn: "0/16960E8".parse::().unwrap(), }), marked_invisible_at: None, + rel_size_migrated_at: None, }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); @@ -1308,6 +1329,97 @@ mod tests { last_completed_lsn: "0/16960E8".parse::().unwrap(), }), marked_invisible_at: Some(parse_naive_datetime("2023-07-31T09:00:00.123000000")), + rel_size_migrated_at: None, + }; + + let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); + assert_eq!(part, expected); + } + + #[test] + fn v15_rel_size_migrated_at_is_parsed() { + let example = r#"{ + "version": 15, + "layer_metadata":{ + "000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9": { "file_size": 25600000 }, + 
"000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51": { "file_size": 9007199254741001 } + }, + "disk_consistent_lsn":"0/16960E8", + "metadata": { + "disk_consistent_lsn": "0/16960E8", + "prev_record_lsn": "0/1696070", + "ancestor_timeline": "e45a7f37d3ee2ff17dc14bf4f4e3f52e", + "ancestor_lsn": "0/0", + "latest_gc_cutoff_lsn": "0/1696070", + "initdb_lsn": "0/1696070", + "pg_version": 14 + }, + "gc_blocking": { + "started_at": "2024-07-19T09:00:00.123", + "reasons": ["DetachAncestor"] + }, + "import_pgdata": { + "V1": { + "Done": { + "idempotency_key": "specified-by-client-218a5213-5044-4562-a28d-d024c5f057f5", + "started_at": "2024-11-13T09:23:42.123", + "finished_at": "2024-11-13T09:42:23.123" + } + } + }, + "rel_size_migration": "legacy", + "l2_lsn": "0/16960E8", + "gc_compaction": { + "last_completed_lsn": "0/16960E8" + }, + "marked_invisible_at": "2023-07-31T09:00:00.123", + "rel_size_migrated_at": "0/16960E8" + }"#; + + let expected = IndexPart { + version: 15, + layer_metadata: HashMap::from([ + ("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__0000000001696070-00000000016960E9".parse().unwrap(), LayerFileMetadata { + file_size: 25600000, + generation: Generation::none(), + shard: ShardIndex::unsharded() + }), + ("000000000000000000000000000000000000-FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF__00000000016B59D8-00000000016B5A51".parse().unwrap(), LayerFileMetadata { + file_size: 9007199254741001, + generation: Generation::none(), + shard: ShardIndex::unsharded() + }) + ]), + disk_consistent_lsn: "0/16960E8".parse::().unwrap(), + metadata: TimelineMetadata::new( + Lsn::from_str("0/16960E8").unwrap(), + Some(Lsn::from_str("0/1696070").unwrap()), + Some(TimelineId::from_str("e45a7f37d3ee2ff17dc14bf4f4e3f52e").unwrap()), + Lsn::INVALID, + Lsn::from_str("0/1696070").unwrap(), + Lsn::from_str("0/1696070").unwrap(), + PgMajorVersion::PG14, + ).with_recalculated_checksum().unwrap(), + deleted_at: None, + lineage: Default::default(), + gc_blocking: Some(GcBlocking { + started_at: parse_naive_datetime("2024-07-19T09:00:00.123000000"), + reasons: enumset::EnumSet::from_iter([GcBlockingReason::DetachAncestor]), + }), + last_aux_file_policy: Default::default(), + archived_at: None, + import_pgdata: Some(import_pgdata::index_part_format::Root::V1(import_pgdata::index_part_format::V1::Done(import_pgdata::index_part_format::Done{ + started_at: parse_naive_datetime("2024-11-13T09:23:42.123000000"), + finished_at: parse_naive_datetime("2024-11-13T09:42:23.123000000"), + idempotency_key: import_pgdata::index_part_format::IdempotencyKey::new("specified-by-client-218a5213-5044-4562-a28d-d024c5f057f5".to_string()), + }))), + rel_size_migration: Some(RelSizeMigration::Legacy), + l2_lsn: Some("0/16960E8".parse::().unwrap()), + gc_compaction: Some(GcCompactionState { + last_completed_lsn: "0/16960E8".parse::().unwrap(), + }), + marked_invisible_at: Some(parse_naive_datetime("2023-07-31T09:00:00.123000000")), + rel_size_migrated_at: Some("0/16960E8".parse::().unwrap()), }; let part = IndexPart::from_json_bytes(example.as_bytes()).unwrap(); diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 3ef07aa414..2c70c5cfa5 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -70,7 +70,7 @@ use tracing::*; use utils::generation::Generation; use utils::guard_arc_swap::GuardArcSwap; use utils::id::TimelineId; -use utils::logging::{MonitorSlowFutureCallback, 
monitor_slow_future}; +use utils::logging::{MonitorSlowFutureCallback, log_slow, monitor_slow_future}; use utils::lsn::{AtomicLsn, Lsn, RecordLsn}; use utils::postgres_client::PostgresClientProtocol; use utils::rate_limit::RateLimit; @@ -441,7 +441,7 @@ pub struct Timeline { /// heatmap on demand. heatmap_layers_downloader: Mutex>, - pub(crate) rel_size_v2_status: ArcSwapOption, + pub(crate) rel_size_v2_status: ArcSwap<(Option, Option)>, wait_lsn_log_slow: tokio::sync::Semaphore, @@ -2894,12 +2894,9 @@ impl Timeline { .unwrap_or(self.conf.default_tenant_conf.rel_size_v2_enabled) } - pub(crate) fn get_rel_size_v2_status(&self) -> RelSizeMigration { - self.rel_size_v2_status - .load() - .as_ref() - .map(|s| s.as_ref().clone()) - .unwrap_or(RelSizeMigration::Legacy) + pub(crate) fn get_rel_size_v2_status(&self) -> (RelSizeMigration, Option) { + let (status, migrated_at) = self.rel_size_v2_status.load().as_ref().clone(); + (status.unwrap_or(RelSizeMigration::Legacy), migrated_at) } fn get_compaction_upper_limit(&self) -> usize { @@ -3174,6 +3171,7 @@ impl Timeline { create_idempotency: crate::tenant::CreateTimelineIdempotency, gc_compaction_state: Option, rel_size_v2_status: Option, + rel_size_migrated_at: Option, cancel: CancellationToken, ) -> Arc { let disk_consistent_lsn = metadata.disk_consistent_lsn(); @@ -3338,7 +3336,10 @@ impl Timeline { heatmap_layers_downloader: Mutex::new(None), - rel_size_v2_status: ArcSwapOption::from_pointee(rel_size_v2_status), + rel_size_v2_status: ArcSwap::from_pointee(( + rel_size_v2_status, + rel_size_migrated_at, + )), wait_lsn_log_slow: tokio::sync::Semaphore::new(1), @@ -3426,11 +3427,17 @@ impl Timeline { pub(crate) fn update_rel_size_v2_status( &self, rel_size_v2_status: RelSizeMigration, + rel_size_migrated_at: Option, ) -> anyhow::Result<()> { - self.rel_size_v2_status - .store(Some(Arc::new(rel_size_v2_status.clone()))); + self.rel_size_v2_status.store(Arc::new(( + Some(rel_size_v2_status.clone()), + rel_size_migrated_at, + ))); self.remote_client - .schedule_index_upload_for_rel_size_v2_status_update(rel_size_v2_status) + .schedule_index_upload_for_rel_size_v2_status_update( + rel_size_v2_status, + rel_size_migrated_at, + ) } pub(crate) fn get_gc_compaction_state(&self) -> Option { @@ -6891,7 +6898,13 @@ impl Timeline { write_guard.store_and_unlock(new_gc_cutoff) }; - waitlist.wait().await; + let waitlist_wait_fut = std::pin::pin!(waitlist.wait()); + log_slow( + "applied_gc_cutoff waitlist wait", + Duration::from_secs(30), + waitlist_wait_fut, + ) + .await; info!("GC starting"); diff --git a/pageserver/src/tenant/timeline/delete.rs b/pageserver/src/tenant/timeline/delete.rs index f7dc44be90..2f6eccdbf9 100644 --- a/pageserver/src/tenant/timeline/delete.rs +++ b/pageserver/src/tenant/timeline/delete.rs @@ -332,6 +332,7 @@ impl DeleteTimelineFlow { crate::tenant::CreateTimelineIdempotency::FailWithConflict, // doesn't matter what we put here None, // doesn't matter what we put here None, // doesn't matter what we put here + None, // doesn't matter what we put here ctx, ) .context("create_timeline_struct")?; diff --git a/pageserver/src/utilization.rs b/pageserver/src/utilization.rs index 0dafa5c4bb..cec28f8059 100644 --- a/pageserver/src/utilization.rs +++ b/pageserver/src/utilization.rs @@ -52,7 +52,7 @@ pub(crate) fn regenerate( }; // Express a static value for how many shards we may schedule on one node - const MAX_SHARDS: u32 = 5000; + const MAX_SHARDS: u32 = 2500; let mut doc = PageserverUtilization { disk_usage_bytes: used, diff --git 
a/pgxn/neon/Makefile b/pgxn/neon/Makefile index 34cabaca62..958ca5c378 100644 --- a/pgxn/neon/Makefile +++ b/pgxn/neon/Makefile @@ -33,6 +33,10 @@ SHLIB_LINK = -lcurl UNAME_S := $(shell uname -s) ifeq ($(UNAME_S), Darwin) SHLIB_LINK += -framework Security -framework CoreFoundation -framework SystemConfiguration + + # Link against object files for the current macOS version, to avoid spurious linker warnings. + MACOSX_DEPLOYMENT_TARGET := $(shell xcrun --sdk macosx --show-sdk-version) + export MACOSX_DEPLOYMENT_TARGET endif EXTENSION = neon diff --git a/pgxn/neon/communicator.c b/pgxn/neon/communicator.c index 5a08b3e331..4c03193d7e 100644 --- a/pgxn/neon/communicator.c +++ b/pgxn/neon/communicator.c @@ -79,10 +79,6 @@ #include "access/xlogrecovery.h" #endif -#if PG_VERSION_NUM < 160000 -typedef PGAlignedBlock PGIOAlignedBlock; -#endif - #define NEON_PANIC_CONNECTION_STATE(shard_no, elvl, message, ...) \ neon_shard_log(shard_no, elvl, "Broken connection state: " message, \ ##__VA_ARGS__) diff --git a/pgxn/neon/extension_server.c b/pgxn/neon/extension_server.c index 00dcb6920e..d64cd3e4af 100644 --- a/pgxn/neon/extension_server.c +++ b/pgxn/neon/extension_server.c @@ -14,7 +14,7 @@ #include "extension_server.h" #include "neon_utils.h" -static int extension_server_port = 0; +int hadron_extension_server_port = 0; static int extension_server_request_timeout = 60; static int extension_server_connect_timeout = 60; @@ -47,7 +47,7 @@ neon_download_extension_file_http(const char *filename, bool is_library) curl_easy_setopt(handle, CURLOPT_CONNECTTIMEOUT, (long)extension_server_connect_timeout /* seconds */ ); compute_ctl_url = psprintf("http://localhost:%d/extension_server/%s%s", - extension_server_port, filename, is_library ? "?is_library=true" : ""); + hadron_extension_server_port, filename, is_library ? "?is_library=true" : ""); elog(LOG, "Sending request to compute_ctl: %s", compute_ctl_url); @@ -82,7 +82,7 @@ pg_init_extension_server() DefineCustomIntVariable("neon.extension_server_port", "connection string to the compute_ctl", NULL, - &extension_server_port, + &hadron_extension_server_port, 0, 0, INT_MAX, PGC_POSTMASTER, 0, /* no flags required */ diff --git a/pgxn/neon/file_cache.c b/pgxn/neon/file_cache.c index 4da6c176cd..88086689c8 100644 --- a/pgxn/neon/file_cache.c +++ b/pgxn/neon/file_cache.c @@ -635,6 +635,11 @@ lfc_init(void) NULL); } +/* + * Dump a list of pages that are currently in the LFC + * + * This is used to get a snapshot that can be used to prewarm the LFC later. + */ FileCacheState* lfc_get_state(size_t max_entries) { @@ -2267,4 +2272,3 @@ get_prewarm_info(PG_FUNCTION_ARGS) PG_RETURN_DATUM(HeapTupleGetDatum(heap_form_tuple(tupdesc, values, nulls))); } - diff --git a/pgxn/neon/libpagestore.c b/pgxn/neon/libpagestore.c index caffdc9612..1031f185a6 100644 --- a/pgxn/neon/libpagestore.c +++ b/pgxn/neon/libpagestore.c @@ -13,6 +13,8 @@ #include #include +#include + #include "libpq-int.h" #include "access/xlog.h" @@ -86,6 +88,10 @@ static int pageserver_response_log_timeout = 10000; /* 2.5 minutes. 
A bit higher than highest default TCP retransmission timeout */ static int pageserver_response_disconnect_timeout = 150000; +static int conf_refresh_reconnect_attempt_threshold = 16; +// Hadron: timeout for refresh errors (1 minute) +static uint64 kRefreshErrorTimeoutUSec = 1 * USECS_PER_MINUTE; + typedef struct { char connstring[MAX_SHARDS][MAX_PAGESERVER_CONNSTRING_SIZE]; @@ -130,7 +136,7 @@ static uint64 pagestore_local_counter = 0; typedef enum PSConnectionState { PS_Disconnected, /* no connection yet */ PS_Connecting_Startup, /* connection starting up */ - PS_Connecting_PageStream, /* negotiating pagestream */ + PS_Connecting_PageStream, /* negotiating pagestream */ PS_Connected, /* connected, pagestream established */ } PSConnectionState; @@ -401,7 +407,7 @@ get_shard_number(BufferTag *tag) } static inline void -CLEANUP_AND_DISCONNECT(PageServer *shard) +CLEANUP_AND_DISCONNECT(PageServer *shard) { if (shard->wes_read) { @@ -423,7 +429,7 @@ CLEANUP_AND_DISCONNECT(PageServer *shard) * complete the connection (e.g. due to receiving an earlier cancellation * during connection start). * Returns true if successfully connected; false if the connection failed. - * + * * Throws errors in unrecoverable situations, or when this backend's query * is canceled. */ @@ -1030,6 +1036,101 @@ pageserver_disconnect_shard(shardno_t shard_no) shard->state = PS_Disconnected; } +// BEGIN HADRON +/* + * Nudge compute_ctl to refresh our configuration. Called when we suspect we may be + * connecting to the wrong pageservers due to a stale configuration. + * + * This is a best-effort operation. If we couldn't send the local loopback HTTP request + * to compute_ctl or if the request fails for any reason, we just log the error and move + * on. + */ + +extern int hadron_extension_server_port; + +// The timestamp (usec) of the first error that occurred while trying to refresh the configuration. +// Will be reset to 0 after a successful refresh. +static uint64 first_recorded_refresh_error_usec = 0; + +// Request compute_ctl to refresh the configuration. This operation may fail, e.g., if the compute_ctl +// is already in the configuration state. The function returns true if the caller needs to cancel the +// current query to avoid dead/live lock. 
+static bool +hadron_request_configuration_refresh() { + static CURL *handle = NULL; + CURLcode res; + char *compute_ctl_url; + bool cancel_query = false; + + if (!lakebase_mode) + return false; + + if (handle == NULL) + { + handle = alloc_curl_handle(); + + curl_easy_setopt(handle, CURLOPT_CUSTOMREQUEST, "POST"); + curl_easy_setopt(handle, CURLOPT_TIMEOUT, 3L /* seconds */ ); + curl_easy_setopt(handle, CURLOPT_POSTFIELDS, ""); + } + + // Set the URL + compute_ctl_url = psprintf("http://localhost:%d/refresh_configuration", hadron_extension_server_port); + + + elog(LOG, "Sending refresh configuration request to compute_ctl: %s", compute_ctl_url); + + curl_easy_setopt(handle, CURLOPT_URL, compute_ctl_url); + + res = curl_easy_perform(handle); + if (res != CURLE_OK ) + { + elog(WARNING, "refresh_configuration request failed: %s\n", curl_easy_strerror(res)); + } + else + { + long http_code = 0; + curl_easy_getinfo(handle, CURLINFO_RESPONSE_CODE, &http_code); + if ( res != CURLE_OK ) + { + elog(WARNING, "compute_ctl refresh_configuration request getinfo failed: %s\n", curl_easy_strerror(res)); + } + else + { + elog(LOG, "compute_ctl refresh_configuration got HTTP response: %ld\n", http_code); + if( http_code == 200 ) + { + first_recorded_refresh_error_usec = 0; + } + else + { + if (first_recorded_refresh_error_usec == 0) + { + first_recorded_refresh_error_usec = GetCurrentTimestamp(); + } + else if(GetCurrentTimestamp() - first_recorded_refresh_error_usec > kRefreshErrorTimeoutUSec) + { + { + first_recorded_refresh_error_usec = 0; + cancel_query = true; + } + } + } + } + } + + // In regular Postgres usage, it is not necessary to manually free memory allocated by palloc (psprintf) because + // it will be cleaned up after the "memory context" is reset (e.g. after the query or the transaction is finished). + // However, the number of times this function gets called during a single query/transaction can be unbounded due to + // the various retry loops around calls to pageservers. Therefore, we need to manually free this memory here. + if (compute_ctl_url != NULL) + { + pfree(compute_ctl_url); + } + return cancel_query; +} +// END HADRON + static bool pageserver_send(shardno_t shard_no, NeonRequest *request) { @@ -1064,6 +1165,11 @@ pageserver_send(shardno_t shard_no, NeonRequest *request) while (!pageserver_connect(shard_no, shard->n_reconnect_attempts < max_reconnect_attempts ? LOG : ERROR)) { shard->n_reconnect_attempts += 1; + if (shard->n_reconnect_attempts > conf_refresh_reconnect_attempt_threshold + && hadron_request_configuration_refresh() ) + { + neon_shard_log(shard_no, ERROR, "request failed too many times, cancelling query"); + } } shard->n_reconnect_attempts = 0; } else { @@ -1171,17 +1277,26 @@ pageserver_receive(shardno_t shard_no) pfree(msg); pageserver_disconnect(shard_no); resp = NULL; + + /* + * Always poke compute_ctl to request a configuration refresh if we have issues receiving data from pageservers after + * successfully connecting to it. It could be an indication that we are connecting to the wrong pageservers (e.g. PS + * is in secondary mode or otherwise refuses to respond our request). 
+ */ + hadron_request_configuration_refresh(); } else if (rc == -2) { char *msg = pchomp(PQerrorMessage(pageserver_conn)); pageserver_disconnect(shard_no); + hadron_request_configuration_refresh(); neon_shard_log(shard_no, ERROR, "pageserver_receive disconnect: could not read COPY data: %s", msg); } else { pageserver_disconnect(shard_no); + hadron_request_configuration_refresh(); neon_shard_log(shard_no, ERROR, "pageserver_receive disconnect: unexpected PQgetCopyData return value: %d", rc); } @@ -1249,21 +1364,34 @@ pageserver_try_receive(shardno_t shard_no) neon_shard_log(shard_no, LOG, "pageserver_receive disconnect: psql end of copy data: %s", pchomp(PQerrorMessage(pageserver_conn))); pageserver_disconnect(shard_no); resp = NULL; + hadron_request_configuration_refresh(); } else if (rc == -2) { char *msg = pchomp(PQerrorMessage(pageserver_conn)); pageserver_disconnect(shard_no); + hadron_request_configuration_refresh(); neon_shard_log(shard_no, LOG, "pageserver_receive disconnect: could not read COPY data: %s", msg); resp = NULL; } else { pageserver_disconnect(shard_no); + hadron_request_configuration_refresh(); neon_shard_log(shard_no, ERROR, "pageserver_receive disconnect: unexpected PQgetCopyData return value: %d", rc); } + /* + * Always poke compute_ctl to request a configuration refresh if we have issues receiving data from pageservers after + * successfully connecting to it. It could be an indication that we are connecting to the wrong pageservers (e.g. PS + * is in secondary mode or otherwise refuses to respond our request). + */ + if ( rc < 0 && hadron_request_configuration_refresh() ) + { + neon_shard_log(shard_no, ERROR, "refresh_configuration request failed, cancelling query"); + } + shard->nresponses_received++; return (NeonResponse *) resp; } @@ -1460,6 +1588,16 @@ pg_init_libpagestore(void) PGC_SU_BACKEND, 0, /* no flags required */ NULL, NULL, NULL); + DefineCustomIntVariable("hadron.conf_refresh_reconnect_attempt_threshold", + "Threshold of the number of consecutive failed pageserver " + "connection attempts (per shard) before signaling " + "compute_ctl for a configuration refresh.", + NULL, + &conf_refresh_reconnect_attempt_threshold, + 16, 0, INT_MAX, + PGC_USERSET, + 0, + NULL, NULL, NULL); DefineCustomIntVariable("neon.pageserver_response_log_timeout", "pageserver response log timeout", diff --git a/pgxn/neon/neon.c b/pgxn/neon/neon.c index 5b9c7d600c..6cd21cce39 100644 --- a/pgxn/neon/neon.c +++ b/pgxn/neon/neon.c @@ -1,7 +1,7 @@ /*------------------------------------------------------------------------- * * neon.c - * Main entry point into the neon exension + * Main entry point into the neon extension * *------------------------------------------------------------------------- */ @@ -48,6 +48,7 @@ PG_MODULE_MAGIC; void _PG_init(void); +bool lakebase_mode = false; static int running_xacts_overflow_policy; static bool monitor_query_exec_time = false; @@ -507,7 +508,7 @@ _PG_init(void) DefineCustomBoolVariable( "neon.disable_logical_replication_subscribers", - "Disables incomming logical replication", + "Disable incoming logical replication", NULL, &disable_logical_replication_subscribers, false, @@ -566,7 +567,7 @@ _PG_init(void) DefineCustomEnumVariable( "neon.debug_compare_local", - "Debug mode for compaing content of pages in prefetch ring/LFC/PS and local disk", + "Debug mode for comparing content of pages in prefetch ring/LFC/PS and local disk", NULL, &debug_compare_local, DEBUG_COMPARE_LOCAL_NONE, @@ -583,6 +584,16 @@ _PG_init(void) "neon_superuser", 
PGC_POSTMASTER, 0, NULL, NULL, NULL); + DefineCustomBoolVariable( + "neon.lakebase_mode", + "Is neon running in Lakebase?", + NULL, + &lakebase_mode, + false, + PGC_POSTMASTER, + 0, + NULL, NULL, NULL); + /* * Important: This must happen after other parts of the extension are * loaded, otherwise any settings to GUCs that were set before the @@ -724,7 +735,6 @@ neon_shmem_request_hook(void) static void neon_shmem_startup_hook(void) { - /* Initialize */ if (prev_shmem_startup_hook) prev_shmem_startup_hook(); diff --git a/pgxn/neon/neon.h b/pgxn/neon/neon.h index 20c850864a..e589d0cfba 100644 --- a/pgxn/neon/neon.h +++ b/pgxn/neon/neon.h @@ -21,6 +21,7 @@ extern int wal_acceptor_reconnect_timeout; extern int wal_acceptor_connection_timeout; extern int readahead_getpage_pull_timeout_ms; extern bool disable_wal_prev_lsn_checks; +extern bool lakebase_mode; extern bool AmPrewarmWorker; diff --git a/pgxn/neon/neon_perf_counters.h b/pgxn/neon/neon_perf_counters.h index 4b611b0636..bc4efddee5 100644 --- a/pgxn/neon/neon_perf_counters.h +++ b/pgxn/neon/neon_perf_counters.h @@ -167,11 +167,7 @@ extern neon_per_backend_counters *neon_per_backend_counters_shared; */ #define NUM_NEON_PERF_COUNTER_SLOTS (MaxBackends + NUM_AUXILIARY_PROCS) -#if PG_VERSION_NUM >= 170000 #define MyNeonCounters (&neon_per_backend_counters_shared[MyProcNumber]) -#else -#define MyNeonCounters (&neon_per_backend_counters_shared[MyProc->pgprocno]) -#endif extern void inc_getpage_wait(uint64 latency); extern void inc_page_cache_read_wait(uint64 latency); diff --git a/pgxn/neon/neon_pgversioncompat.h b/pgxn/neon/neon_pgversioncompat.h index 3ab8d3e5f5..dbe0e5aa3d 100644 --- a/pgxn/neon/neon_pgversioncompat.h +++ b/pgxn/neon/neon_pgversioncompat.h @@ -9,6 +9,10 @@ #include "fmgr.h" #include "storage/buf_internals.h" +#if PG_MAJORVERSION_NUM < 16 +typedef PGAlignedBlock PGIOAlignedBlock; +#endif + #if PG_MAJORVERSION_NUM < 17 #define NRelFileInfoBackendIsTemp(rinfo) (rinfo.backend != InvalidBackendId) #else @@ -158,6 +162,10 @@ InitBufferTag(BufferTag *tag, const RelFileNode *rnode, #define AmAutoVacuumWorkerProcess() (IsAutoVacuumWorkerProcess()) #endif +#if PG_MAJORVERSION_NUM < 17 +#define MyProcNumber (MyProc - &ProcGlobal->allProcs[0]) +#endif + #if PG_MAJORVERSION_NUM < 15 extern void InitMaterializedSRF(FunctionCallInfo fcinfo, bits32 flags); extern TimeLineID GetWALInsertionTimeLine(void); diff --git a/pgxn/neon/pagestore_smgr.c b/pgxn/neon/pagestore_smgr.c index 9d25266e10..d3e51ba682 100644 --- a/pgxn/neon/pagestore_smgr.c +++ b/pgxn/neon/pagestore_smgr.c @@ -72,10 +72,6 @@ #include "access/xlogrecovery.h" #endif -#if PG_VERSION_NUM < 160000 -typedef PGAlignedBlock PGIOAlignedBlock; -#endif - #include "access/nbtree.h" #include "storage/bufpage.h" #include "access/xlog_internal.h" diff --git a/pgxn/neon/relsize_cache.c b/pgxn/neon/relsize_cache.c index bf7961574a..c6b4aeb394 100644 --- a/pgxn/neon/relsize_cache.c +++ b/pgxn/neon/relsize_cache.c @@ -13,6 +13,7 @@ #include "neon.h" #include "neon_pgversioncompat.h" +#include "miscadmin.h" #include "pagestore_client.h" #include RELFILEINFO_HDR #include "storage/smgr.h" @@ -23,10 +24,6 @@ #include "utils/dynahash.h" #include "utils/guc.h" -#if PG_VERSION_NUM >= 150000 -#include "miscadmin.h" -#endif - typedef struct { NRelFileInfo rinfo; diff --git a/pgxn/neon/walproposer.h b/pgxn/neon/walproposer.h index 19d23925a5..5507294c3b 100644 --- a/pgxn/neon/walproposer.h +++ b/pgxn/neon/walproposer.h @@ -389,12 +389,21 @@ typedef struct PageserverFeedback */ typedef struct 
WalRateLimiter { - /* If the value is 1, PG backends will hit backpressure. */ + /* The effective wal write rate. Could be changed dynamically + based on whether PG has backpressure or not.*/ + pg_atomic_uint32 effective_max_wal_bytes_per_second; + /* If the value is 1, PG backends will hit backpressure until the time has past batch_end_time_us. */ pg_atomic_uint32 should_limit; /* The number of bytes sent in the current second. */ uint64 sent_bytes; - /* The last recorded time in microsecond. */ - pg_atomic_uint64 last_recorded_time_us; + /* The timestamp when the write starts in the current batch. A batch is a time interval (e.g., )that we + track and throttle writes. Most times a batch is 1s, but it could become larger if the PG overwrites the WALs + and we will adjust the batch accordingly to compensate (e.g., if PG writes 10MB at once and max WAL write rate + is 1MB/s, then the current batch will become 10s). */ + pg_atomic_uint64 batch_start_time_us; + /* The timestamp (in the future) that the current batch should end and accept more writes + (after should_limit is set to 1). */ + pg_atomic_uint64 batch_end_time_us; } WalRateLimiter; /* END_HADRON */ diff --git a/pgxn/neon/walproposer_pg.c b/pgxn/neon/walproposer_pg.c index d43d372c2e..b0f5828d39 100644 --- a/pgxn/neon/walproposer_pg.c +++ b/pgxn/neon/walproposer_pg.c @@ -68,6 +68,14 @@ int safekeeper_proto_version = 3; char *safekeeper_conninfo_options = ""; /* BEGIN_HADRON */ int databricks_max_wal_mb_per_second = -1; +// during throttling, we will limit the effective WAL write rate to 10KB. +// PG can still push some WAL to SK, but at a very low rate. +int databricks_throttled_max_wal_bytes_per_second = 10 * 1024; +// The max sleep time of a batch. This is to make sure the rate limiter does not +// overshoot too much and block PG for a very long time. +// This is set as 5 minuetes for now. PG can send as much as 10MB of WALs to SK in one batch, +// so this effectively caps the write rate to ~30KB/s in the worst case. +static uint64 kRateLimitMaxBatchUSecs = 300 * USECS_PER_SEC; /* END_HADRON */ /* Set to true in the walproposer bgw. */ @@ -86,6 +94,7 @@ static HotStandbyFeedback agg_hs_feedback; static void nwp_register_gucs(void); static void assign_neon_safekeepers(const char *newval, void *extra); static uint64 backpressure_lag_impl(void); +static uint64 hadron_backpressure_lag_impl(void); static uint64 startup_backpressure_wrap(void); static bool backpressure_throttling_impl(void); static void walprop_register_bgworker(void); @@ -110,9 +119,22 @@ static void rm_safekeeper_event_set(Safekeeper *to_remove, bool is_sk); static void CheckGracefulShutdown(WalProposer *wp); -// HADRON +/* BEGIN_HADRON */ shardno_t get_num_shards(void); +static int positive_mb_to_bytes(int mb) +{ + if (mb <= 0) + { + return mb; + } + else + { + return mb * 1024 * 1024; + } +} +/* END_HADRON */ + static void init_walprop_config(bool syncSafekeepers) { @@ -260,6 +282,16 @@ nwp_register_gucs(void) PGC_SUSET, GUC_UNIT_MB, NULL, NULL, NULL); + + DefineCustomIntVariable( + "databricks.throttled_max_wal_bytes_per_second", + "The maximum WAL bytes per second when PG is being throttled.", + NULL, + &databricks_throttled_max_wal_bytes_per_second, + 10 * 1024, 0, INT_MAX, + PGC_SUSET, + GUC_UNIT_BYTE, + NULL, NULL, NULL); /* END_HADRON */ } @@ -398,19 +430,65 @@ assign_neon_safekeepers(const char *newval, void *extra) pfree(oldval); } -/* Check if we need to suspend inserts because of lagging replication. 
*/ -static uint64 -backpressure_lag_impl(void) +/* BEGIN_HADRON */ +static uint64 hadron_backpressure_lag_impl(void) { struct WalproposerShmemState* state = NULL; + uint64 lag = 0; - /* BEGIN_HADRON */ if(max_cluster_size < 0){ // if max cluster size is not set, then we don't apply backpressure because we're reconfiguring PG return 0; } - /* END_HADRON */ + lag = backpressure_lag_impl(); + state = GetWalpropShmemState(); + if ( state != NULL && databricks_max_wal_mb_per_second != -1 ) + { + int old_limit = pg_atomic_read_u32(&state->wal_rate_limiter.effective_max_wal_bytes_per_second); + int new_limit = (lag == 0)? positive_mb_to_bytes(databricks_max_wal_mb_per_second) : databricks_throttled_max_wal_bytes_per_second; + if( old_limit != new_limit ) + { + uint64 batch_start_time = pg_atomic_read_u64(&state->wal_rate_limiter.batch_start_time_us); + uint64 batch_end_time = pg_atomic_read_u64(&state->wal_rate_limiter.batch_end_time_us); + // the rate limit has changed, we need to reset the rate limiter's batch end time + pg_atomic_write_u32(&state->wal_rate_limiter.effective_max_wal_bytes_per_second, new_limit); + pg_atomic_write_u64(&state->wal_rate_limiter.batch_end_time_us, Min(batch_start_time + USECS_PER_SEC, batch_end_time)); + } + if( new_limit == -1 ) + { + return 0; + } + + if (pg_atomic_read_u32(&state->wal_rate_limiter.should_limit) == true) + { + TimestampTz now = GetCurrentTimestamp(); + struct WalRateLimiter *limiter = &state->wal_rate_limiter; + uint64 batch_end_time = pg_atomic_read_u64(&limiter->batch_end_time_us); + if ( now >= batch_end_time ) + { + /* + * The backend has past the batch end time and it's time to push more WALs. + * If the backends are pushing WALs too fast, the wal proposer will rate limit them again. + */ + uint32 expected = true; + pg_atomic_compare_exchange_u32(&state->wal_rate_limiter.should_limit, &expected, false); + return 0; + } + return Max(lag, 1); + } + // rate limiter decides to not throttle, then return 0. + return 0; + } + + return lag; +} +/* END_HADRON */ + +/* Check if we need to suspend inserts because of lagging replication. */ +static uint64 +backpressure_lag_impl(void) +{ if (max_replication_apply_lag > 0 || max_replication_flush_lag > 0 || max_replication_write_lag > 0) { XLogRecPtr writePtr; @@ -429,45 +507,47 @@ backpressure_lag_impl(void) LSN_FORMAT_ARGS(flushPtr), LSN_FORMAT_ARGS(applyPtr)); - if ((writePtr != InvalidXLogRecPtr && max_replication_write_lag > 0 && myFlushLsn > writePtr + max_replication_write_lag * MB)) + if (lakebase_mode) { - return (myFlushLsn - writePtr - max_replication_write_lag * MB); - } + // in case PG does not have shard map initialized, we assume PG always has 1 shard at minimum. 
+ shardno_t num_shards = Max(1, get_num_shards()); + int tenant_max_replication_apply_lag = num_shards * max_replication_apply_lag; + int tenant_max_replication_flush_lag = num_shards * max_replication_flush_lag; + int tenant_max_replication_write_lag = num_shards * max_replication_write_lag; - if ((flushPtr != InvalidXLogRecPtr && max_replication_flush_lag > 0 && myFlushLsn > flushPtr + max_replication_flush_lag * MB)) - { - return (myFlushLsn - flushPtr - max_replication_flush_lag * MB); - } + if ((writePtr != InvalidXLogRecPtr && tenant_max_replication_write_lag > 0 && myFlushLsn > writePtr + tenant_max_replication_write_lag * MB)) + { + return (myFlushLsn - writePtr - tenant_max_replication_write_lag * MB); + } - if ((applyPtr != InvalidXLogRecPtr && max_replication_apply_lag > 0 && myFlushLsn > applyPtr + max_replication_apply_lag * MB)) + if ((flushPtr != InvalidXLogRecPtr && tenant_max_replication_flush_lag > 0 && myFlushLsn > flushPtr + tenant_max_replication_flush_lag * MB)) + { + return (myFlushLsn - flushPtr - tenant_max_replication_flush_lag * MB); + } + + if ((applyPtr != InvalidXLogRecPtr && tenant_max_replication_apply_lag > 0 && myFlushLsn > applyPtr + tenant_max_replication_apply_lag * MB)) + { + return (myFlushLsn - applyPtr - tenant_max_replication_apply_lag * MB); + } + } + else { - return (myFlushLsn - applyPtr - max_replication_apply_lag * MB); + if ((writePtr != InvalidXLogRecPtr && max_replication_write_lag > 0 && myFlushLsn > writePtr + max_replication_write_lag * MB)) + { + return (myFlushLsn - writePtr - max_replication_write_lag * MB); + } + + if ((flushPtr != InvalidXLogRecPtr && max_replication_flush_lag > 0 && myFlushLsn > flushPtr + max_replication_flush_lag * MB)) + { + return (myFlushLsn - flushPtr - max_replication_flush_lag * MB); + } + + if ((applyPtr != InvalidXLogRecPtr && max_replication_apply_lag > 0 && myFlushLsn > applyPtr + max_replication_apply_lag * MB)) + { + return (myFlushLsn - applyPtr - max_replication_apply_lag * MB); + } } } - - /* BEGIN_HADRON */ - if (databricks_max_wal_mb_per_second == -1) { - return 0; - } - - state = GetWalpropShmemState(); - if (state != NULL && !!pg_atomic_read_u32(&state->wal_rate_limiter.should_limit)) - { - TimestampTz now = GetCurrentTimestamp(); - struct WalRateLimiter *limiter = &state->wal_rate_limiter; - uint64 last_recorded_time = pg_atomic_read_u64(&limiter->last_recorded_time_us); - if (now - last_recorded_time > USECS_PER_SEC) - { - /* - * The backend has past 1 second since the last recorded time and it's time to push more WALs. - * If the backends are pushing WALs too fast, the wal proposer will rate limit them again. 
- */ - uint32 expected = true; - pg_atomic_compare_exchange_u32(&state->wal_rate_limiter.should_limit, &expected, false); - } - return 1; - } - /* END_HADRON */ return 0; } @@ -482,9 +562,9 @@ startup_backpressure_wrap(void) if (AmStartupProcess() || !IsUnderPostmaster) return 0; - delay_backend_us = &backpressure_lag_impl; + delay_backend_us = &hadron_backpressure_lag_impl; - return backpressure_lag_impl(); + return hadron_backpressure_lag_impl(); } /* @@ -514,8 +594,10 @@ WalproposerShmemInit(void) pg_atomic_init_u64(&walprop_shared->backpressureThrottlingTime, 0); pg_atomic_init_u64(&walprop_shared->currentClusterSize, 0); /* BEGIN_HADRON */ + pg_atomic_init_u32(&walprop_shared->wal_rate_limiter.effective_max_wal_bytes_per_second, -1); pg_atomic_init_u32(&walprop_shared->wal_rate_limiter.should_limit, 0); - pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.last_recorded_time_us, 0); + pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.batch_start_time_us, 0); + pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.batch_end_time_us, 0); /* END_HADRON */ } } @@ -530,8 +612,10 @@ WalproposerShmemInit_SyncSafekeeper(void) pg_atomic_init_u64(&walprop_shared->mineLastElectedTerm, 0); pg_atomic_init_u64(&walprop_shared->backpressureThrottlingTime, 0); /* BEGIN_HADRON */ + pg_atomic_init_u32(&walprop_shared->wal_rate_limiter.effective_max_wal_bytes_per_second, -1); pg_atomic_init_u32(&walprop_shared->wal_rate_limiter.should_limit, 0); - pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.last_recorded_time_us, 0); + pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.batch_start_time_us, 0); + pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.batch_end_time_us, 0); /* END_HADRON */ } @@ -563,7 +647,7 @@ backpressure_throttling_impl(void) return retry; /* Calculate replicas lag */ - lag = backpressure_lag_impl(); + lag = hadron_backpressure_lag_impl(); if (lag == 0) return retry; @@ -659,7 +743,7 @@ record_pageserver_feedback(PageserverFeedback *ps_feedback, shardno_t num_shards SpinLockAcquire(&walprop_shared->mutex); - // Hadron: Update the num_shards from the source-of-truth (shard map) lazily when we receive + // Hadron: Update the num_shards from the source-of-truth (shard map) lazily when we receive // a new pageserver feedback. 
walprop_shared->num_shards = Max(walprop_shared->num_shards, num_shards); @@ -1479,6 +1563,7 @@ XLogBroadcastWalProposer(WalProposer *wp) XLogRecPtr endptr; struct WalproposerShmemState *state = NULL; TimestampTz now = 0; + int effective_max_wal_bytes_per_second = 0; /* Start from the last sent position */ startptr = sentPtr; @@ -1533,22 +1618,36 @@ XLogBroadcastWalProposer(WalProposer *wp) /* BEGIN_HADRON */ state = GetWalpropShmemState(); - if (databricks_max_wal_mb_per_second != -1 && state != NULL) + effective_max_wal_bytes_per_second = pg_atomic_read_u32(&state->wal_rate_limiter.effective_max_wal_bytes_per_second); + if (effective_max_wal_bytes_per_second != -1 && state != NULL) { - uint64 max_wal_bytes = (uint64) databricks_max_wal_mb_per_second * 1024 * 1024; struct WalRateLimiter *limiter = &state->wal_rate_limiter; - uint64 last_recorded_time = pg_atomic_read_u64(&limiter->last_recorded_time_us); - if (now - last_recorded_time > USECS_PER_SEC) + uint64 batch_end_time = pg_atomic_read_u64(&limiter->batch_end_time_us); + if ( now >= batch_end_time ) { - /* Reset the rate limiter */ + // Reset the rate limiter to start a new batch limiter->sent_bytes = 0; - pg_atomic_write_u64(&limiter->last_recorded_time_us, now); pg_atomic_write_u32(&limiter->should_limit, false); + pg_atomic_write_u64(&limiter->batch_start_time_us, now); + /* tentatively assign the batch end time as 1s from now. This could result in one of the following cases: + 1. If sent_bytes does not reach effective_max_wal_bytes_per_second in 1s, + then we will reset the current batch and clear sent_bytes. No throttling happens. + 2. Otherwise, we will recompute the end time (below) based on how many bytes are actually written, + and throttle PG until the batch end time. */ + pg_atomic_write_u64(&limiter->batch_end_time_us, now + USECS_PER_SEC); } limiter->sent_bytes += (endptr - startptr); - if (limiter->sent_bytes > max_wal_bytes) + if (limiter->sent_bytes > effective_max_wal_bytes_per_second) { + uint64_t batch_start_time = pg_atomic_read_u64(&limiter->batch_start_time_us); + uint64 throttle_usecs = USECS_PER_SEC * limiter->sent_bytes / Max(effective_max_wal_bytes_per_second, 1); + if (throttle_usecs > kRateLimitMaxBatchUSecs){ + elog(LOG, "throttle_usecs %lu is too large, limiting to %lu", throttle_usecs, kRateLimitMaxBatchUSecs); + throttle_usecs = kRateLimitMaxBatchUSecs; + } + pg_atomic_write_u32(&limiter->should_limit, true); + pg_atomic_write_u64(&limiter->batch_end_time_us, batch_start_time + throttle_usecs); } } /* END_HADRON */ @@ -2052,7 +2151,7 @@ walprop_pg_process_safekeeper_feedback(WalProposer *wp, Safekeeper *sk) /* Only one main shard sends non-zero currentClusterSize */ if (sk->appendResponse.ps_feedback.currentClusterSize > 0) SetNeonCurrentClusterSize(sk->appendResponse.ps_feedback.currentClusterSize); - + if (min_feedback.disk_consistent_lsn != standby_apply_lsn) { standby_apply_lsn = min_feedback.disk_consistent_lsn; diff --git a/proxy/Cargo.toml b/proxy/Cargo.toml index 3c3f93c8e3..0ece79c329 100644 --- a/proxy/Cargo.toml +++ b/proxy/Cargo.toml @@ -33,7 +33,6 @@ env_logger.workspace = true framed-websockets.workspace = true futures.workspace = true hashbrown.workspace = true -hashlink.workspace = true hex.workspace = true hmac.workspace = true hostname.workspace = true @@ -54,6 +53,7 @@ json = { path = "../libs/proxy/json" } lasso = { workspace = true, features = ["multi-threaded"] } measured = { workspace = true, features = ["lasso"] } metrics.workspace = true +moka.workspace = true 
once_cell.workspace = true opentelemetry = { workspace = true, features = ["trace"] } papaya = "0.2.0" @@ -107,10 +107,11 @@ uuid.workspace = true x509-cert.workspace = true redis.workspace = true zerocopy.workspace = true +zeroize.workspace = true # uncomment this to use the real subzero-core crate # subzero-core = { git = "https://github.com/neondatabase/subzero", rev = "396264617e78e8be428682f87469bb25429af88a", features = ["postgresql"], optional = true } # this is a stub for the subzero-core crate -subzero-core = { path = "./subzero_core", features = ["postgresql"], optional = true} +subzero-core = { path = "../libs/proxy/subzero_core", features = ["postgresql"], optional = true} ouroboros = { version = "0.18", optional = true } # jwt stuff diff --git a/proxy/src/auth/backend/console_redirect.rs b/proxy/src/auth/backend/console_redirect.rs index b06ed3a0ae..2a02748a10 100644 --- a/proxy/src/auth/backend/console_redirect.rs +++ b/proxy/src/auth/backend/console_redirect.rs @@ -8,11 +8,12 @@ use tracing::{info, info_span}; use crate::auth::backend::ComputeUserInfo; use crate::cache::Cached; +use crate::cache::node_info::CachedNodeInfo; use crate::compute::AuthInfo; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::client::cplane_proxy_v1; -use crate::control_plane::{self, CachedNodeInfo, NodeInfo}; +use crate::control_plane::{self, NodeInfo}; use crate::error::{ReportableError, UserFacingError}; use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; diff --git a/proxy/src/auth/backend/hacks.rs b/proxy/src/auth/backend/hacks.rs index 1e5c076fb9..491f14b1b6 100644 --- a/proxy/src/auth/backend/hacks.rs +++ b/proxy/src/auth/backend/hacks.rs @@ -6,7 +6,7 @@ use crate::auth::{self, AuthFlow}; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::AuthSecret; -use crate::intern::EndpointIdInt; +use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::sasl; use crate::stream::{self, Stream}; @@ -25,13 +25,15 @@ pub(crate) async fn authenticate_cleartext( ctx.set_auth_method(crate::context::AuthMethod::Cleartext); let ep = EndpointIdInt::from(&info.endpoint); + let role = RoleNameInt::from(&info.user); let auth_flow = AuthFlow::new( client, auth::CleartextPassword { secret, endpoint: ep, - pool: config.thread_pool.clone(), + role, + pool: config.scram_thread_pool.clone(), }, ); let auth_outcome = { diff --git a/proxy/src/auth/backend/mod.rs b/proxy/src/auth/backend/mod.rs index e7805d8bfe..a6df2a7011 100644 --- a/proxy/src/auth/backend/mod.rs +++ b/proxy/src/auth/backend/mod.rs @@ -16,16 +16,16 @@ use tracing::{debug, info}; use crate::auth::{self, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange}; use crate::cache::Cached; +use crate::cache::node_info::CachedNodeInfo; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::client::ControlPlaneClient; use crate::control_plane::errors::GetAuthInfoError; use crate::control_plane::messages::EndpointRateLimitConfig; use crate::control_plane::{ - self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl, - RoleAccessControl, + self, AccessBlockerFlags, AuthSecret, ControlPlaneApi, EndpointAccessControl, RoleAccessControl, }; -use crate::intern::EndpointIdInt; +use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; use crate::proxy::wake_compute::WakeComputeBackend; @@ -273,9 +273,11 @@ async fn 
authenticate_with_secret( ) -> auth::Result { if let Some(password) = unauthenticated_password { let ep = EndpointIdInt::from(&info.endpoint); + let role = RoleNameInt::from(&info.user); let auth_outcome = - validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?; + validate_password_and_exchange(&config.scram_thread_pool, ep, role, &password, secret) + .await?; let keys = match auth_outcome { crate::sasl::Outcome::Success(key) => key, crate::sasl::Outcome::Failure(reason) => { @@ -433,11 +435,12 @@ mod tests { use super::auth_quirks; use super::jwt::JwkCache; use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern}; + use crate::cache::node_info::CachedNodeInfo; use crate::config::AuthenticationConfig; use crate::context::RequestContext; use crate::control_plane::messages::EndpointRateLimitConfig; use crate::control_plane::{ - self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl, + self, AccessBlockerFlags, EndpointAccessControl, RoleAccessControl, }; use crate::proxy::NeonOptions; use crate::rate_limiter::EndpointRateLimiter; @@ -498,7 +501,7 @@ mod tests { static CONFIG: Lazy = Lazy::new(|| AuthenticationConfig { jwks_cache: JwkCache::default(), - thread_pool: ThreadPool::new(1), + scram_thread_pool: ThreadPool::new(1), scram_protocol_timeout: std::time::Duration::from_secs(5), ip_allowlist_check_enabled: true, is_vpc_acccess_proxy: false, diff --git a/proxy/src/auth/flow.rs b/proxy/src/auth/flow.rs index c825d5bf4b..00cd274e99 100644 --- a/proxy/src/auth/flow.rs +++ b/proxy/src/auth/flow.rs @@ -10,7 +10,7 @@ use super::backend::ComputeCredentialKeys; use super::{AuthError, PasswordHackPayload}; use crate::context::RequestContext; use crate::control_plane::AuthSecret; -use crate::intern::EndpointIdInt; +use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage}; use crate::sasl; use crate::scram::threadpool::ThreadPool; @@ -46,6 +46,7 @@ pub(crate) struct PasswordHack; pub(crate) struct CleartextPassword { pub(crate) pool: Arc, pub(crate) endpoint: EndpointIdInt, + pub(crate) role: RoleNameInt, pub(crate) secret: AuthSecret, } @@ -111,6 +112,7 @@ impl AuthFlow<'_, S, CleartextPassword> { let outcome = validate_password_and_exchange( &self.state.pool, self.state.endpoint, + self.state.role, password, self.state.secret, ) @@ -165,13 +167,15 @@ impl AuthFlow<'_, S, Scram<'_>> { pub(crate) async fn validate_password_and_exchange( pool: &ThreadPool, endpoint: EndpointIdInt, + role: RoleNameInt, password: &[u8], secret: AuthSecret, ) -> super::Result> { match secret { // perform scram authentication as both client and server to validate the keys AuthSecret::Scram(scram_secret) => { - let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?; + let outcome = + crate::scram::exchange(pool, endpoint, role, &scram_secret, password).await?; let client_key = match outcome { sasl::Outcome::Success(client_key) => client_key, diff --git a/proxy/src/binary/local_proxy.rs b/proxy/src/binary/local_proxy.rs index 7b9012dc69..86b64c62c9 100644 --- a/proxy/src/binary/local_proxy.rs +++ b/proxy/src/binary/local_proxy.rs @@ -29,7 +29,7 @@ use crate::config::{ }; use crate::control_plane::locks::ApiLocks; use crate::http::health_server::AppMetrics; -use crate::metrics::{Metrics, ThreadPoolMetrics}; +use crate::metrics::{Metrics, ServiceInfo}; use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo}; use crate::scram::threadpool::ThreadPool; use 
crate::serverless::cancel_set::CancelSet; @@ -114,8 +114,6 @@ pub async fn run() -> anyhow::Result<()> { let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook(); let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]); - Metrics::install(Arc::new(ThreadPoolMetrics::new(0))); - // TODO: refactor these to use labels debug!("Version: {GIT_VERSION}"); debug!("Build_tag: {BUILD_TAG}"); @@ -207,6 +205,11 @@ pub async fn run() -> anyhow::Result<()> { endpoint_rate_limiter, ); + Metrics::get() + .service + .info + .set_label(ServiceInfo::running()); + match futures::future::select(pin!(maintenance_tasks.join_next()), pin!(task)).await { // exit immediately on maintenance task completion Either::Left((Some(res), _)) => match crate::error::flatten_err(res)? {}, @@ -279,7 +282,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig http_config, authentication_config: AuthenticationConfig { jwks_cache: JwkCache::default(), - thread_pool: ThreadPool::new(0), + scram_thread_pool: ThreadPool::new(0), scram_protocol_timeout: Duration::from_secs(10), ip_allowlist_check_enabled: true, is_vpc_acccess_proxy: false, diff --git a/proxy/src/binary/pg_sni_router.rs b/proxy/src/binary/pg_sni_router.rs index f3782312dc..cdbf0f09ac 100644 --- a/proxy/src/binary/pg_sni_router.rs +++ b/proxy/src/binary/pg_sni_router.rs @@ -26,7 +26,7 @@ use utils::project_git_version; use utils::sentry_init::init_sentry; use crate::context::RequestContext; -use crate::metrics::{Metrics, ThreadPoolMetrics}; +use crate::metrics::{Metrics, ServiceInfo}; use crate::pglb::TlsRequired; use crate::pqproto::FeStartupPacket; use crate::protocol2::ConnectionInfo; @@ -80,8 +80,6 @@ pub async fn run() -> anyhow::Result<()> { let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook(); let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]); - Metrics::install(Arc::new(ThreadPoolMetrics::new(0))); - let args = cli().get_matches(); let destination: String = args .get_one::("dest") @@ -135,6 +133,12 @@ pub async fn run() -> anyhow::Result<()> { cancellation_token.clone(), )) .map(crate::error::flatten_err); + + Metrics::get() + .service + .info + .set_label(ServiceInfo::running()); + let signals_task = tokio::spawn(crate::signals::handle(cancellation_token, || {})); // the signal task cant ever succeed. 
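The startup changes in local_proxy.rs and pg_sni_router.rs above (and in proxy.rs below) remove the early Metrics::install(...) call and, separately, set a service-info label to "running" via Metrics::get().service.info.set_label(ServiceInfo::running()) only after all listeners and maintenance tasks have been spawned, so a scrape that races with startup reports the service as still starting rather than silently missing. A rough, self-contained sketch of that pattern follows; ServiceState, ServiceMetrics and mark_running are illustrative stand-ins, not the proxy's actual Metrics/ServiceInfo API.

use std::sync::OnceLock;
use std::sync::atomic::{AtomicU8, Ordering};

// Coarse lifecycle states exported as a metric label.
#[repr(u8)]
enum ServiceState {
    Starting = 0,
    Running = 1,
}

// Illustrative stand-in for a global metrics registry.
struct ServiceMetrics {
    state: AtomicU8,
}

impl ServiceMetrics {
    fn get() -> &'static ServiceMetrics {
        static METRICS: OnceLock<ServiceMetrics> = OnceLock::new();
        METRICS.get_or_init(|| ServiceMetrics {
            // Until startup finishes, scrapes see "starting".
            state: AtomicU8::new(ServiceState::Starting as u8),
        })
    }

    // Called exactly once, after all listeners and background tasks exist.
    fn mark_running(&self) {
        self.state.store(ServiceState::Running as u8, Ordering::Relaxed);
    }
}

fn main() {
    // ... bind listeners, spawn maintenance tasks ...
    ServiceMetrics::get().mark_running();
}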
diff --git a/proxy/src/binary/proxy.rs b/proxy/src/binary/proxy.rs index 4148f4bc62..29b0ad53f2 100644 --- a/proxy/src/binary/proxy.rs +++ b/proxy/src/binary/proxy.rs @@ -40,7 +40,7 @@ use crate::config::{ }; use crate::context::parquet::ParquetUploadArgs; use crate::http::health_server::AppMetrics; -use crate::metrics::Metrics; +use crate::metrics::{Metrics, ServiceInfo}; use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo, WakeComputeRateLimiter}; use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider; use crate::redis::kv_ops::RedisKVClient; @@ -535,12 +535,7 @@ pub async fn run() -> anyhow::Result<()> { // add a task to flush the db_schema cache every 10 minutes #[cfg(feature = "rest_broker")] if let Some(db_schema_cache) = &config.rest_config.db_schema_cache { - maintenance_tasks.spawn(async move { - loop { - tokio::time::sleep(Duration::from_secs(600)).await; - db_schema_cache.flush(); - } - }); + maintenance_tasks.spawn(db_schema_cache.maintain()); } if let Some(metrics_config) = &config.metric_collection { @@ -590,6 +585,11 @@ pub async fn run() -> anyhow::Result<()> { } } + Metrics::get() + .service + .info + .set_label(ServiceInfo::running()); + let maintenance = loop { // get one complete task match futures::future::select( @@ -617,7 +617,12 @@ pub async fn run() -> anyhow::Result<()> { /// ProxyConfig is created at proxy startup, and lives forever. fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { let thread_pool = ThreadPool::new(args.scram_thread_pool_size); - Metrics::install(thread_pool.metrics.clone()); + Metrics::get() + .proxy + .scram_pool + .0 + .set(thread_pool.metrics.clone()) + .ok(); let tls_config = match (&args.tls_key, &args.tls_cert) { (Some(key_path), Some(cert_path)) => Some(config::configure_tls( @@ -690,7 +695,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { }; let authentication_config = AuthenticationConfig { jwks_cache: JwkCache::default(), - thread_pool, + scram_thread_pool: thread_pool, scram_protocol_timeout: args.scram_protocol_timeout, ip_allowlist_check_enabled: !args.is_private_access_proxy, is_vpc_acccess_proxy: args.is_private_access_proxy, @@ -711,12 +716,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> { info!("Using DbSchemaCache with options={db_schema_cache_config:?}"); let db_schema_cache = if args.is_rest_broker { - Some(DbSchemaCache::new( - "db_schema_cache", - db_schema_cache_config.size, - db_schema_cache_config.ttl, - true, - )) + Some(DbSchemaCache::new(db_schema_cache_config)) } else { None }; diff --git a/proxy/src/cache/common.rs b/proxy/src/cache/common.rs index b5caf94788..9a7d0d99cf 100644 --- a/proxy/src/cache/common.rs +++ b/proxy/src/cache/common.rs @@ -1,4 +1,16 @@ use std::ops::{Deref, DerefMut}; +use std::time::{Duration, Instant}; + +use moka::Expiry; +use moka::notification::RemovalCause; + +use crate::control_plane::messages::ControlPlaneErrorMessage; +use crate::metrics::{ + CacheEviction, CacheKind, CacheOutcome, CacheOutcomeGroup, CacheRemovalCause, Metrics, +}; + +/// Default TTL used when caching errors from control plane. +pub const DEFAULT_ERROR_TTL: Duration = Duration::from_secs(30); /// A generic trait which exposes types of cache's key and value, /// as well as the notion of cache entry invalidation. @@ -10,20 +22,16 @@ pub(crate) trait Cache { /// Entry's value. type Value; - /// Used for entry invalidation. 
- type LookupInfo; - /// Invalidate an entry using a lookup info. /// We don't have an empty default impl because it's error-prone. - fn invalidate(&self, _: &Self::LookupInfo); + fn invalidate(&self, _: &Self::Key); } impl Cache for &C { type Key = C::Key; type Value = C::Value; - type LookupInfo = C::LookupInfo; - fn invalidate(&self, info: &Self::LookupInfo) { + fn invalidate(&self, info: &Self::Key) { C::invalidate(self, info); } } @@ -31,7 +39,7 @@ impl Cache for &C { /// Wrapper for convenient entry invalidation. pub(crate) struct Cached::Value> { /// Cache + lookup info. - pub(crate) token: Option<(C, C::LookupInfo)>, + pub(crate) token: Option<(C, C::Key)>, /// The value itself. pub(crate) value: V, @@ -43,23 +51,6 @@ impl Cached { Self { token: None, value } } - pub(crate) fn take_value(self) -> (Cached, V) { - ( - Cached { - token: self.token, - value: (), - }, - self.value, - ) - } - - pub(crate) fn map(self, f: impl FnOnce(V) -> U) -> Cached { - Cached { - token: self.token, - value: f(self.value), - } - } - /// Drop this entry from a cache if it's still there. pub(crate) fn invalidate(self) -> V { if let Some((cache, info)) = &self.token { @@ -87,3 +78,91 @@ impl DerefMut for Cached { &mut self.value } } + +pub type ControlPlaneResult = Result>; + +#[derive(Clone, Copy)] +pub struct CplaneExpiry { + pub error: Duration, +} + +impl Default for CplaneExpiry { + fn default() -> Self { + Self { + error: DEFAULT_ERROR_TTL, + } + } +} + +impl CplaneExpiry { + pub fn expire_early( + &self, + value: &ControlPlaneResult, + updated: Instant, + ) -> Option { + match value { + Ok(_) => None, + Err(err) => Some(self.expire_err_early(err, updated)), + } + } + + pub fn expire_err_early(&self, err: &ControlPlaneErrorMessage, updated: Instant) -> Duration { + err.status + .as_ref() + .and_then(|s| s.details.retry_info.as_ref()) + .map_or(self.error, |r| r.retry_at.into_std() - updated) + } +} + +impl Expiry> for CplaneExpiry { + fn expire_after_create( + &self, + _key: &K, + value: &ControlPlaneResult, + created_at: Instant, + ) -> Option { + self.expire_early(value, created_at) + } + + fn expire_after_update( + &self, + _key: &K, + value: &ControlPlaneResult, + updated_at: Instant, + _duration_until_expiry: Option, + ) -> Option { + self.expire_early(value, updated_at) + } +} + +pub fn eviction_listener(kind: CacheKind, cause: RemovalCause) { + let cause = match cause { + RemovalCause::Expired => CacheRemovalCause::Expired, + RemovalCause::Explicit => CacheRemovalCause::Explicit, + RemovalCause::Replaced => CacheRemovalCause::Replaced, + RemovalCause::Size => CacheRemovalCause::Size, + }; + Metrics::get() + .cache + .evicted_total + .inc(CacheEviction { cache: kind, cause }); +} + +#[inline] +pub fn count_cache_outcome(kind: CacheKind, cache_result: Option) -> Option { + let outcome = if cache_result.is_some() { + CacheOutcome::Hit + } else { + CacheOutcome::Miss + }; + Metrics::get().cache.request_total.inc(CacheOutcomeGroup { + cache: kind, + outcome, + }); + cache_result +} + +#[inline] +pub fn count_cache_insert(kind: CacheKind) { + Metrics::get().cache.inserted_total.inc(kind); +} diff --git a/proxy/src/cache/mod.rs b/proxy/src/cache/mod.rs index ce7f781213..0a607a1409 100644 --- a/proxy/src/cache/mod.rs +++ b/proxy/src/cache/mod.rs @@ -1,6 +1,5 @@ pub(crate) mod common; +pub(crate) mod node_info; pub(crate) mod project_info; -mod timed_lru; -pub(crate) use common::{Cache, Cached}; -pub(crate) use timed_lru::TimedLru; +pub(crate) use common::{Cached, ControlPlaneResult, CplaneExpiry}; 
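The new cache/common.rs above introduces `CplaneExpiry`, a moka `Expiry` policy: successful control-plane responses follow the cache-wide TTL/TTI, while cached errors expire after a short default (30s) or exactly at the control plane's retry deadline. A self-contained sketch of the same idea, with a placeholder `RetryableErr` standing in for `ControlPlaneErrorMessage`:

```rust
use std::time::{Duration, Instant};

use moka::Expiry;

#[derive(Clone)]
struct RetryableErr {
    retry_at: Option<Instant>,
}

type CplaneResult<T> = Result<T, RetryableErr>;

#[derive(Clone, Copy)]
struct ErrExpiry {
    default_error_ttl: Duration,
}

impl ErrExpiry {
    fn error_ttl(&self, err: &RetryableErr, now: Instant) -> Duration {
        // Expire at the retry deadline if one was provided,
        // otherwise fall back to the default error TTL.
        err.retry_at
            .map_or(self.default_error_ttl, |at| at.saturating_duration_since(now))
    }
}

impl<K> Expiry<K, CplaneResult<String>> for ErrExpiry {
    fn expire_after_create(
        &self,
        _key: &K,
        value: &CplaneResult<String>,
        created_at: Instant,
    ) -> Option<Duration> {
        match value {
            Ok(_) => None, // defer to the cache-wide policy
            Err(e) => Some(self.error_ttl(e, created_at)),
        }
    }
}

fn main() {
    let cache: moka::sync::Cache<&'static str, CplaneResult<String>> =
        moka::sync::Cache::builder()
            .max_capacity(1_000)
            .time_to_live(Duration::from_secs(4 * 60))
            .expire_after(ErrExpiry {
                default_error_ttl: Duration::from_secs(30),
            })
            .build();

    cache.insert("ep-1", Ok("node-info".to_owned()));
    cache.insert("ep-2", Err(RetryableErr { retry_at: None }));
    assert!(cache.get("ep-1").is_some());
}
```

Combining a cache-wide `time_to_live` with a per-entry `expire_after` mirrors what the project_info.rs hunks below do.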
diff --git a/proxy/src/cache/node_info.rs b/proxy/src/cache/node_info.rs new file mode 100644 index 0000000000..47fc7a5b08 --- /dev/null +++ b/proxy/src/cache/node_info.rs @@ -0,0 +1,60 @@ +use crate::cache::common::{Cache, count_cache_insert, count_cache_outcome, eviction_listener}; +use crate::cache::{Cached, ControlPlaneResult, CplaneExpiry}; +use crate::config::CacheOptions; +use crate::control_plane::NodeInfo; +use crate::metrics::{CacheKind, Metrics}; +use crate::types::EndpointCacheKey; + +pub(crate) struct NodeInfoCache(moka::sync::Cache>); +pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; + +impl Cache for NodeInfoCache { + type Key = EndpointCacheKey; + type Value = ControlPlaneResult; + + fn invalidate(&self, info: &EndpointCacheKey) { + self.0.invalidate(info); + } +} + +impl NodeInfoCache { + pub fn new(config: CacheOptions) -> Self { + let builder = moka::sync::Cache::builder() + .name("node_info") + .expire_after(CplaneExpiry::default()); + let builder = config.moka(builder); + + if let Some(size) = config.size { + Metrics::get() + .cache + .capacity + .set(CacheKind::NodeInfo, size as i64); + } + + let builder = builder + .eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::NodeInfo, cause)); + + Self(builder.build()) + } + + pub fn insert(&self, key: EndpointCacheKey, value: ControlPlaneResult) { + count_cache_insert(CacheKind::NodeInfo); + self.0.insert(key, value); + } + + pub fn get(&self, key: &EndpointCacheKey) -> Option> { + count_cache_outcome(CacheKind::NodeInfo, self.0.get(key)) + } + + pub fn get_entry( + &'static self, + key: &EndpointCacheKey, + ) -> Option> { + self.get(key).map(|res| { + res.map(|value| Cached { + token: Some((self, key.clone())), + value, + }) + }) + } +} diff --git a/proxy/src/cache/project_info.rs b/proxy/src/cache/project_info.rs index a589dd175b..f8a38be287 100644 --- a/proxy/src/cache/project_info.rs +++ b/proxy/src/cache/project_info.rs @@ -1,84 +1,20 @@ -use std::collections::{HashMap, HashSet, hash_map}; +use std::collections::HashSet; use std::convert::Infallible; -use std::time::Duration; -use async_trait::async_trait; use clashmap::ClashMap; -use clashmap::mapref::one::Ref; -use rand::Rng; -use tokio::time::Instant; +use moka::sync::Cache; use tracing::{debug, info}; +use crate::cache::common::{ + ControlPlaneResult, CplaneExpiry, count_cache_insert, count_cache_outcome, eviction_listener, +}; use crate::config::ProjectInfoCacheOptions; use crate::control_plane::messages::{ControlPlaneErrorMessage, Reason}; use crate::control_plane::{EndpointAccessControl, RoleAccessControl}; use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; +use crate::metrics::{CacheKind, Metrics}; use crate::types::{EndpointId, RoleName}; -#[async_trait] -pub(crate) trait ProjectInfoCache { - fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt); - fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt); - fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt); - fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt); -} - -struct Entry { - expires_at: Instant, - value: T, -} - -impl Entry { - pub(crate) fn new(value: T, ttl: Duration) -> Self { - Self { - expires_at: Instant::now() + ttl, - value, - } - } - - pub(crate) fn get(&self) -> Option<&T> { - (!self.is_expired()).then_some(&self.value) - } - - fn is_expired(&self) -> bool { - self.expires_at <= Instant::now() - } -} - -struct EndpointInfo { - 
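The `NodeInfoCache` added above routes every lookup and insert through `count_cache_outcome` / `count_cache_insert`, so hit, miss, and insert rates are exported per cache kind. A rough sketch of that instrumentation pattern, with plain atomics standing in for the `measured` counter vectors (names are illustrative):

```rust
use std::sync::atomic::{AtomicU64, Ordering};

#[derive(Default)]
struct CacheCounters {
    hits: AtomicU64,
    misses: AtomicU64,
    inserts: AtomicU64,
}

impl CacheCounters {
    // Pass the lookup result through unchanged; counting is a side effect.
    fn count_outcome<T>(&self, lookup: Option<T>) -> Option<T> {
        if lookup.is_some() {
            self.hits.fetch_add(1, Ordering::Relaxed);
        } else {
            self.misses.fetch_add(1, Ordering::Relaxed);
        }
        lookup
    }

    fn count_insert(&self) {
        self.inserts.fetch_add(1, Ordering::Relaxed);
    }
}

fn main() {
    let counters = CacheCounters::default();
    let cache: moka::sync::Cache<String, String> = moka::sync::Cache::new(100);

    counters.count_insert();
    cache.insert("k".into(), "v".into());

    let hit = counters.count_outcome(cache.get("k"));
    assert_eq!(hit.as_deref(), Some("v"));
    assert_eq!(counters.hits.load(Ordering::Relaxed), 1);
}
```

The project_info.rs changes that follow apply the same counting helpers to the two new moka caches.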
role_controls: HashMap>>, - controls: Option>>, -} - -type ControlPlaneResult = Result>; - -impl EndpointInfo { - pub(crate) fn get_role_secret_with_ttl( - &self, - role_name: RoleNameInt, - ) -> Option<(ControlPlaneResult, Duration)> { - let entry = self.role_controls.get(&role_name)?; - let ttl = entry.expires_at - Instant::now(); - Some((entry.get()?.clone(), ttl)) - } - - pub(crate) fn get_controls_with_ttl( - &self, - ) -> Option<(ControlPlaneResult, Duration)> { - let entry = self.controls.as_ref()?; - let ttl = entry.expires_at - Instant::now(); - Some((entry.get()?.clone(), ttl)) - } - - pub(crate) fn invalidate_endpoint(&mut self) { - self.controls = None; - } - - pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) { - self.role_controls.remove(&role_name); - } -} - /// Cache for project info. /// This is used to cache auth data for endpoints. /// Invalidation is done by console notifications or by TTL (if console notifications are disabled). @@ -86,8 +22,9 @@ impl EndpointInfo { /// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data. /// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available? /// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache. -pub struct ProjectInfoCacheImpl { - cache: ClashMap, +pub struct ProjectInfoCache { + role_controls: Cache<(EndpointIdInt, RoleNameInt), ControlPlaneResult>, + ep_controls: Cache>, project2ep: ClashMap>, // FIXME(stefan): we need a way to GC the account2ep map. @@ -96,16 +33,13 @@ pub struct ProjectInfoCacheImpl { config: ProjectInfoCacheOptions, } -#[async_trait] -impl ProjectInfoCache for ProjectInfoCacheImpl { - fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) { +impl ProjectInfoCache { + pub fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) { info!("invalidating endpoint access for `{endpoint_id}`"); - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_endpoint(); - } + self.ep_controls.invalidate(&endpoint_id); } - fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) { + pub fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) { info!("invalidating endpoint access for project `{project_id}`"); let endpoints = self .project2ep @@ -113,13 +47,11 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { .map(|kv| kv.value().clone()) .unwrap_or_default(); for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_endpoint(); - } + self.ep_controls.invalidate(&endpoint_id); } } - fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) { + pub fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) { info!("invalidating endpoint access for org `{account_id}`"); let endpoints = self .account2ep @@ -127,13 +59,15 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { .map(|kv| kv.value().clone()) .unwrap_or_default(); for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_endpoint(); - } + self.ep_controls.invalidate(&endpoint_id); } } - fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) { + pub fn invalidate_role_secret_for_project( + &self, + project_id: ProjectIdInt, + role_name: RoleNameInt, + ) { info!( "invalidating 
role secret for project_id `{}` and role_name `{}`", project_id, role_name, @@ -144,47 +78,73 @@ impl ProjectInfoCache for ProjectInfoCacheImpl { .map(|kv| kv.value().clone()) .unwrap_or_default(); for endpoint_id in endpoints { - if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) { - endpoint_info.invalidate_role_secret(role_name); - } + self.role_controls.invalidate(&(endpoint_id, role_name)); } } } -impl ProjectInfoCacheImpl { +impl ProjectInfoCache { pub(crate) fn new(config: ProjectInfoCacheOptions) -> Self { + Metrics::get().cache.capacity.set( + CacheKind::ProjectInfoRoles, + (config.size * config.max_roles) as i64, + ); + Metrics::get() + .cache + .capacity + .set(CacheKind::ProjectInfoEndpoints, config.size as i64); + + // we cache errors for 30 seconds, unless retry_at is set. + let expiry = CplaneExpiry::default(); Self { - cache: ClashMap::new(), + role_controls: Cache::builder() + .name("project_info_roles") + .eviction_listener(|_k, _v, cause| { + eviction_listener(CacheKind::ProjectInfoRoles, cause); + }) + .max_capacity(config.size * config.max_roles) + .time_to_live(config.ttl) + .expire_after(expiry) + .build(), + ep_controls: Cache::builder() + .name("project_info_endpoints") + .eviction_listener(|_k, _v, cause| { + eviction_listener(CacheKind::ProjectInfoEndpoints, cause); + }) + .max_capacity(config.size) + .time_to_live(config.ttl) + .expire_after(expiry) + .build(), project2ep: ClashMap::new(), account2ep: ClashMap::new(), config, } } - fn get_endpoint_cache( - &self, - endpoint_id: &EndpointId, - ) -> Option> { - let endpoint_id = EndpointIdInt::get(endpoint_id)?; - self.cache.get(&endpoint_id) - } - - pub(crate) fn get_role_secret_with_ttl( + pub(crate) fn get_role_secret( &self, endpoint_id: &EndpointId, role_name: &RoleName, - ) -> Option<(ControlPlaneResult, Duration)> { + ) -> Option> { + let endpoint_id = EndpointIdInt::get(endpoint_id)?; let role_name = RoleNameInt::get(role_name)?; - let endpoint_info = self.get_endpoint_cache(endpoint_id)?; - endpoint_info.get_role_secret_with_ttl(role_name) + + count_cache_outcome( + CacheKind::ProjectInfoRoles, + self.role_controls.get(&(endpoint_id, role_name)), + ) } - pub(crate) fn get_endpoint_access_with_ttl( + pub(crate) fn get_endpoint_access( &self, endpoint_id: &EndpointId, - ) -> Option<(ControlPlaneResult, Duration)> { - let endpoint_info = self.get_endpoint_cache(endpoint_id)?; - endpoint_info.get_controls_with_ttl() + ) -> Option> { + let endpoint_id = EndpointIdInt::get(endpoint_id)?; + + count_cache_outcome( + CacheKind::ProjectInfoEndpoints, + self.ep_controls.get(&endpoint_id), + ) } pub(crate) fn insert_endpoint_access( @@ -203,34 +163,17 @@ impl ProjectInfoCacheImpl { self.insert_project2endpoint(project_id, endpoint_id); } - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. 
- return; - } - debug!( key = &*endpoint_id, "created a cache entry for endpoint access" ); - let controls = Some(Entry::new(Ok(controls), self.config.ttl)); - let role_controls = Entry::new(Ok(role_controls), self.config.ttl); + count_cache_insert(CacheKind::ProjectInfoEndpoints); + count_cache_insert(CacheKind::ProjectInfoRoles); - match self.cache.entry(endpoint_id) { - clashmap::Entry::Vacant(e) => { - e.insert(EndpointInfo { - role_controls: HashMap::from_iter([(role_name, role_controls)]), - controls, - }); - } - clashmap::Entry::Occupied(mut e) => { - let ep = e.get_mut(); - ep.controls = controls; - if ep.role_controls.len() < self.config.max_roles { - ep.role_controls.insert(role_name, role_controls); - } - } - } + self.ep_controls.insert(endpoint_id, Ok(controls)); + self.role_controls + .insert((endpoint_id, role_name), Ok(role_controls)); } pub(crate) fn insert_endpoint_access_err( @@ -238,55 +181,34 @@ impl ProjectInfoCacheImpl { endpoint_id: EndpointIdInt, role_name: RoleNameInt, msg: Box, - ttl: Option, ) { - if self.cache.len() >= self.config.size { - // If there are too many entries, wait until the next gc cycle. - return; - } - debug!( key = &*endpoint_id, "created a cache entry for an endpoint access error" ); - let ttl = ttl.unwrap_or(self.config.ttl); - - let controls = if msg.get_reason() == Reason::RoleProtected { - // RoleProtected is the only role-specific error that control plane can give us. - // If a given role name does not exist, it still returns a successful response, - // just with an empty secret. - None - } else { - // We can cache all the other errors in EndpointInfo.controls, - // because they don't depend on what role name we pass to control plane. - Some(Entry::new(Err(msg.clone()), ttl)) - }; - - let role_controls = Entry::new(Err(msg), ttl); - - match self.cache.entry(endpoint_id) { - clashmap::Entry::Vacant(e) => { - e.insert(EndpointInfo { - role_controls: HashMap::from_iter([(role_name, role_controls)]), - controls, + // RoleProtected is the only role-specific error that control plane can give us. + // If a given role name does not exist, it still returns a successful response, + // just with an empty secret. + if msg.get_reason() != Reason::RoleProtected { + // We can cache all the other errors in ep_controls because they don't + // depend on what role name we pass to control plane. + self.ep_controls + .entry(endpoint_id) + .and_compute_with(|entry| match entry { + // leave the entry alone if it's already Ok + Some(entry) if entry.value().is_ok() => moka::ops::compute::Op::Nop, + // replace the entry + _ => { + count_cache_insert(CacheKind::ProjectInfoEndpoints); + moka::ops::compute::Op::Put(Err(msg.clone())) + } }); - } - clashmap::Entry::Occupied(mut e) => { - let ep = e.get_mut(); - if let Some(entry) = &ep.controls - && !entry.is_expired() - && entry.value.is_ok() - { - // If we have cached non-expired, non-error controls, keep them. 
- } else { - ep.controls = controls; - } - if ep.role_controls.len() < self.config.max_roles { - ep.role_controls.insert(role_name, role_controls); - } - } } + + count_cache_insert(CacheKind::ProjectInfoRoles); + self.role_controls + .insert((endpoint_id, role_name), Err(msg)); } fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) { @@ -307,73 +229,35 @@ impl ProjectInfoCacheImpl { } } - pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) { - let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else { - return; - }; - let Some(role_name) = RoleNameInt::get(role_name) else { - return; - }; - - let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) else { - return; - }; - - let entry = endpoint_info.role_controls.entry(role_name); - let hash_map::Entry::Occupied(role_controls) = entry else { - return; - }; - - if role_controls.get().is_expired() { - role_controls.remove(); - } + pub fn maybe_invalidate_role_secret(&self, _endpoint_id: &EndpointId, _role_name: &RoleName) { + // TODO: Expire the value early if the key is idle. + // Currently not an issue as we would just use the TTL to decide, which is what already happens. } pub async fn gc_worker(&self) -> anyhow::Result { - let mut interval = - tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32); + let mut interval = tokio::time::interval(self.config.gc_interval); loop { interval.tick().await; - if self.cache.len() < self.config.size { - // If there are not too many entries, wait until the next gc cycle. - continue; - } - self.gc(); + self.ep_controls.run_pending_tasks(); + self.role_controls.run_pending_tasks(); } } - - fn gc(&self) { - let shard = rand::rng().random_range(0..self.project2ep.shards().len()); - debug!(shard, "project_info_cache: performing epoch reclamation"); - - // acquire a random shard lock - let mut removed = 0; - let shard = self.project2ep.shards()[shard].write(); - for (_, endpoints) in shard.iter() { - for endpoint in endpoints { - self.cache.remove(endpoint); - removed += 1; - } - } - // We can drop this shard only after making sure that all endpoints are removed. - drop(shard); - info!("project_info_cache: removed {removed} endpoints"); - } } #[cfg(test)] mod tests { + use std::sync::Arc; + use std::time::Duration; + use super::*; use crate::control_plane::messages::{Details, EndpointRateLimitConfig, ErrorInfo, Status}; use crate::control_plane::{AccessBlockerFlags, AuthSecret}; use crate::scram::ServerSecret; - use std::sync::Arc; #[tokio::test] async fn test_project_info_cache_settings() { - tokio::time::pause(); - let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { - size: 2, + let cache = ProjectInfoCache::new(ProjectInfoCacheOptions { + size: 1, max_roles: 2, ttl: Duration::from_secs(1), gc_interval: Duration::from_secs(600), @@ -423,22 +307,17 @@ mod tests { }, ); - let (cached, ttl) = cache - .get_role_secret_with_ttl(&endpoint_id, &user1) - .unwrap(); + let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap(); assert_eq!(cached.unwrap().secret, secret1); - assert_eq!(ttl, cache.config.ttl); - let (cached, ttl) = cache - .get_role_secret_with_ttl(&endpoint_id, &user2) - .unwrap(); + let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap(); assert_eq!(cached.unwrap().secret, secret2); - assert_eq!(ttl, cache.config.ttl); // Shouldn't add more than 2 roles. 
let user3: RoleName = "user3".into(); let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32]))); + cache.role_controls.run_pending_tasks(); cache.insert_endpoint_access( account_id, project_id, @@ -455,31 +334,18 @@ mod tests { }, ); - assert!( - cache - .get_role_secret_with_ttl(&endpoint_id, &user3) - .is_none() - ); + cache.role_controls.run_pending_tasks(); + assert_eq!(cache.role_controls.entry_count(), 2); - let cached = cache - .get_endpoint_access_with_ttl(&endpoint_id) - .unwrap() - .0 - .unwrap(); - assert_eq!(cached.allowed_ips, allowed_ips); + tokio::time::sleep(Duration::from_secs(2)).await; - tokio::time::advance(Duration::from_secs(2)).await; - let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user1); - assert!(cached.is_none()); - let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user2); - assert!(cached.is_none()); - let cached = cache.get_endpoint_access_with_ttl(&endpoint_id); - assert!(cached.is_none()); + cache.role_controls.run_pending_tasks(); + assert_eq!(cache.role_controls.entry_count(), 0); } #[tokio::test] async fn test_caching_project_info_errors() { - let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions { + let cache = ProjectInfoCache::new(ProjectInfoCacheOptions { size: 10, max_roles: 10, ttl: Duration::from_secs(1), @@ -519,34 +385,23 @@ mod tests { status: None, }); - let get_role_secret = |endpoint_id, role_name| { - cache - .get_role_secret_with_ttl(endpoint_id, role_name) - .unwrap() - .0 - }; - let get_endpoint_access = - |endpoint_id| cache.get_endpoint_access_with_ttl(endpoint_id).unwrap().0; + let get_role_secret = + |endpoint_id, role_name| cache.get_role_secret(endpoint_id, role_name).unwrap(); + let get_endpoint_access = |endpoint_id| cache.get_endpoint_access(endpoint_id).unwrap(); // stores role-specific errors only for get_role_secret - cache.insert_endpoint_access_err( - (&endpoint_id).into(), - (&user1).into(), - role_msg.clone(), - None, - ); + cache.insert_endpoint_access_err((&endpoint_id).into(), (&user1).into(), role_msg.clone()); assert_eq!( get_role_secret(&endpoint_id, &user1).unwrap_err().error, role_msg.error ); - assert!(cache.get_endpoint_access_with_ttl(&endpoint_id).is_none()); + assert!(cache.get_endpoint_access(&endpoint_id).is_none()); // stores non-role specific errors for both get_role_secret and get_endpoint_access cache.insert_endpoint_access_err( (&endpoint_id).into(), (&user1).into(), generic_msg.clone(), - None, ); assert_eq!( get_role_secret(&endpoint_id, &user1).unwrap_err().error, @@ -558,11 +413,7 @@ mod tests { ); // error isn't returned for other roles in the same endpoint - assert!( - cache - .get_role_secret_with_ttl(&endpoint_id, &user2) - .is_none() - ); + assert!(cache.get_role_secret(&endpoint_id, &user2).is_none()); // success for a role does not overwrite errors for other roles cache.insert_endpoint_access( @@ -590,7 +441,6 @@ mod tests { (&endpoint_id).into(), (&user2).into(), generic_msg.clone(), - None, ); assert!(get_role_secret(&endpoint_id, &user2).is_err()); assert!(get_endpoint_access(&endpoint_id).is_ok()); diff --git a/proxy/src/cache/timed_lru.rs b/proxy/src/cache/timed_lru.rs deleted file mode 100644 index 0a7fb40b0c..0000000000 --- a/proxy/src/cache/timed_lru.rs +++ /dev/null @@ -1,262 +0,0 @@ -use std::borrow::Borrow; -use std::hash::Hash; -use std::time::{Duration, Instant}; - -// This seems to make more sense than `lru` or `cached`: -// -// * `near/nearcore` ditched `cached` in favor of `lru` -// 
(https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed). -// -// * `lru` methods use an obscure `KeyRef` type in their contraints (which is deliberately excluded from docs). -// This severely hinders its usage both in terms of creating wrappers and supported key types. -// -// On the other hand, `hashlink` has good download stats and appears to be maintained. -use hashlink::{LruCache, linked_hash_map::RawEntryMut}; -use tracing::debug; - -use super::Cache; -use super::common::Cached; - -/// An implementation of timed LRU cache with fixed capacity. -/// Key properties: -/// -/// * Whenever a new entry is inserted, the least recently accessed one is evicted. -/// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`). -/// -/// * If `update_ttl_on_retrieval` is `true`. When the entry is about to be retrieved, we check its expiration timestamp. -/// If the entry has expired, we remove it from the cache; Otherwise we bump the -/// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong -/// its existence. -/// -/// * There's an API for immediate invalidation (removal) of a cache entry; -/// It's useful in case we know for sure that the entry is no longer correct. -/// See [`Cached`] for more information. -/// -/// * Expired entries are kept in the cache, until they are evicted by the LRU policy, -/// or by a successful lookup (i.e. the entry hasn't expired yet). -/// There is no background job to reap the expired records. -/// -/// * It's possible for an entry that has not yet expired entry to be evicted -/// before expired items. That's a bit wasteful, but probably fine in practice. -pub(crate) struct TimedLru { - /// Cache's name for tracing. - name: &'static str, - - /// The underlying cache implementation. - cache: parking_lot::Mutex>>, - - /// Default time-to-live of a single entry. - ttl: Duration, - - update_ttl_on_retrieval: bool, -} - -impl Cache for TimedLru { - type Key = K; - type Value = V; - type LookupInfo = Key; - - fn invalidate(&self, info: &Self::LookupInfo) { - self.invalidate_raw(info); - } -} - -struct Entry { - created_at: Instant, - expires_at: Instant, - ttl: Duration, - update_ttl_on_retrieval: bool, - value: T, -} - -impl TimedLru { - /// Construct a new LRU cache with timed entries. - pub(crate) fn new( - name: &'static str, - capacity: usize, - ttl: Duration, - update_ttl_on_retrieval: bool, - ) -> Self { - Self { - name, - cache: LruCache::new(capacity).into(), - ttl, - update_ttl_on_retrieval, - } - } - - /// Drop an entry from the cache if it's outdated. - #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] - fn invalidate_raw(&self, key: &K) { - // Do costly things before taking the lock. - let mut cache = self.cache.lock(); - let entry = match cache.raw_entry_mut().from_key(key) { - RawEntryMut::Vacant(_) => return, - RawEntryMut::Occupied(x) => x.remove(), - }; - drop(cache); // drop lock before logging - - let Entry { - created_at, - expires_at, - .. - } = entry; - - debug!( - ?created_at, - ?expires_at, - "processed a cache entry invalidation event" - ); - } - - /// Try retrieving an entry by its key, then execute `extract` if it exists. - #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] - fn get_raw(&self, key: &Q, extract: impl FnOnce(&K, &Entry) -> R) -> Option - where - K: Borrow, - Q: Hash + Eq + ?Sized, - { - let now = Instant::now(); - - // Do costly things before taking the lock. 
- let mut cache = self.cache.lock(); - let mut raw_entry = match cache.raw_entry_mut().from_key(key) { - RawEntryMut::Vacant(_) => return None, - RawEntryMut::Occupied(x) => x, - }; - - // Immeditely drop the entry if it has expired. - let entry = raw_entry.get(); - if entry.expires_at <= now { - raw_entry.remove(); - return None; - } - - let value = extract(raw_entry.key(), entry); - let (created_at, expires_at) = (entry.created_at, entry.expires_at); - - // Update the deadline and the entry's position in the LRU list. - let deadline = now.checked_add(raw_entry.get().ttl).expect("time overflow"); - if raw_entry.get().update_ttl_on_retrieval { - raw_entry.get_mut().expires_at = deadline; - } - raw_entry.to_back(); - - drop(cache); // drop lock before logging - debug!( - created_at = format_args!("{created_at:?}"), - old_expires_at = format_args!("{expires_at:?}"), - new_expires_at = format_args!("{deadline:?}"), - "accessed a cache entry" - ); - - Some(value) - } - - /// Insert an entry to the cache. If an entry with the same key already - /// existed, return the previous value and its creation timestamp. - #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] - fn insert_raw(&self, key: K, value: V) -> (Instant, Option) { - self.insert_raw_ttl(key, value, self.ttl, self.update_ttl_on_retrieval) - } - - /// Insert an entry to the cache. If an entry with the same key already - /// existed, return the previous value and its creation timestamp. - #[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)] - fn insert_raw_ttl( - &self, - key: K, - value: V, - ttl: Duration, - update: bool, - ) -> (Instant, Option) { - let created_at = Instant::now(); - let expires_at = created_at.checked_add(ttl).expect("time overflow"); - - let entry = Entry { - created_at, - expires_at, - ttl, - update_ttl_on_retrieval: update, - value, - }; - - // Do costly things before taking the lock. - let old = self - .cache - .lock() - .insert(key, entry) - .map(|entry| entry.value); - - debug!( - created_at = format_args!("{created_at:?}"), - expires_at = format_args!("{expires_at:?}"), - replaced = old.is_some(), - "created a cache entry" - ); - - (created_at, old) - } -} - -impl TimedLru { - pub(crate) fn insert_ttl(&self, key: K, value: V, ttl: Duration) { - self.insert_raw_ttl(key, value, ttl, false); - } - - #[cfg(feature = "rest_broker")] - pub(crate) fn insert(&self, key: K, value: V) { - self.insert_raw_ttl(key, value, self.ttl, self.update_ttl_on_retrieval); - } - - pub(crate) fn insert_unit(&self, key: K, value: V) -> (Option, Cached<&Self, ()>) { - let (_, old) = self.insert_raw(key.clone(), value); - - let cached = Cached { - token: Some((self, key)), - value: (), - }; - - (old, cached) - } - - #[cfg(feature = "rest_broker")] - pub(crate) fn flush(&self) { - let now = Instant::now(); - let mut cache = self.cache.lock(); - - // Collect keys of expired entries first - let expired_keys: Vec<_> = cache - .iter() - .filter_map(|(key, entry)| { - if entry.expires_at <= now { - Some(key.clone()) - } else { - None - } - }) - .collect(); - - // Remove expired entries - for key in expired_keys { - cache.remove(&key); - } - } -} - -impl TimedLru { - /// Retrieve a cached entry in convenient wrapper, alongside timing information. 
- pub(crate) fn get_with_created_at( - &self, - key: &Q, - ) -> Option::Value, Instant)>> - where - K: Borrow + Clone, - Q: Hash + Eq + ?Sized, - { - self.get_raw(key, |key, entry| Cached { - token: Some((self, key.clone())), - value: (entry.value.clone(), entry.created_at), - }) - } -} diff --git a/proxy/src/compute/mod.rs b/proxy/src/compute/mod.rs index 1e3631363e..ca784423ee 100644 --- a/proxy/src/compute/mod.rs +++ b/proxy/src/compute/mod.rs @@ -25,6 +25,7 @@ use crate::control_plane::messages::MetricsAuxInfo; use crate::error::{ReportableError, UserFacingError}; use crate::metrics::{Metrics, NumDbConnectionsGuard}; use crate::pqproto::StartupMessageParams; +use crate::proxy::connect_compute::TlsNegotiation; use crate::proxy::neon_option; use crate::types::Host; @@ -84,6 +85,14 @@ pub(crate) enum ConnectionError { #[error("error acquiring resource permit: {0}")] TooManyConnectionAttempts(#[from] ApiLockError), + + #[cfg(test)] + #[error("retryable: {retryable}, wakeable: {wakeable}, kind: {kind:?}")] + TestError { + retryable: bool, + wakeable: bool, + kind: crate::error::ErrorKind, + }, } impl UserFacingError for ConnectionError { @@ -94,6 +103,8 @@ impl UserFacingError for ConnectionError { "Failed to acquire permit to connect to the database. Too many database connection attempts are currently ongoing.".to_owned() } ConnectionError::TlsError(_) => COULD_NOT_CONNECT.to_owned(), + #[cfg(test)] + ConnectionError::TestError { .. } => self.to_string(), } } } @@ -104,6 +115,8 @@ impl ReportableError for ConnectionError { ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute, ConnectionError::WakeComputeError(e) => e.get_error_kind(), ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(), + #[cfg(test)] + ConnectionError::TestError { kind, .. } => *kind, } } } @@ -256,6 +269,7 @@ impl ConnectInfo { async fn connect_raw( &self, config: &ComputeConfig, + tls: TlsNegotiation, ) -> Result<(SocketAddr, MaybeTlsStream), TlsError> { let timeout = config.timeout; @@ -298,7 +312,7 @@ impl ConnectInfo { match connect_once(&*addrs).await { Ok((sockaddr, stream)) => Ok(( sockaddr, - tls::connect_tls(stream, self.ssl_mode, config, host).await?, + tls::connect_tls(stream, self.ssl_mode, config, host, tls).await?, )), Err(err) => { warn!("couldn't connect to compute node at {host}:{port}: {err}"); @@ -329,9 +343,10 @@ impl ConnectInfo { ctx: &RequestContext, aux: &MetricsAuxInfo, config: &ComputeConfig, + tls: TlsNegotiation, ) -> Result { let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let (socket_addr, stream) = self.connect_raw(config).await?; + let (socket_addr, stream) = self.connect_raw(config, tls).await?; drop(pause); tracing::Span::current().record("compute_id", tracing::field::display(&aux.compute_id)); diff --git a/proxy/src/compute/tls.rs b/proxy/src/compute/tls.rs index 000d75fca5..cc1c0d1658 100644 --- a/proxy/src/compute/tls.rs +++ b/proxy/src/compute/tls.rs @@ -7,6 +7,7 @@ use thiserror::Error; use tokio::io::{AsyncRead, AsyncWrite}; use crate::pqproto::request_tls; +use crate::proxy::connect_compute::TlsNegotiation; use crate::proxy::retry::CouldRetry; #[derive(Debug, Error)] @@ -35,6 +36,7 @@ pub async fn connect_tls( mode: SslMode, tls: &T, host: &str, + negotiation: TlsNegotiation, ) -> Result, TlsError> where S: AsyncRead + AsyncWrite + Unpin + Send, @@ -49,12 +51,15 @@ where SslMode::Prefer | SslMode::Require => {} } - if !request_tls(&mut stream).await? 
{ - if SslMode::Require == mode { - return Err(TlsError::Required); - } - - return Ok(MaybeTlsStream::Raw(stream)); + match negotiation { + // No TLS request needed + TlsNegotiation::Direct => {} + // TLS request successful + TlsNegotiation::Postgres if request_tls(&mut stream).await? => {} + // TLS request failed but is required + TlsNegotiation::Postgres if SslMode::Require == mode => return Err(TlsError::Required), + // TLS request failed but is not required + TlsNegotiation::Postgres => return Ok(MaybeTlsStream::Raw(stream)), } Ok(MaybeTlsStream::Tls( diff --git a/proxy/src/config.rs b/proxy/src/config.rs index 16b1dff5f4..22902dbcab 100644 --- a/proxy/src/config.rs +++ b/proxy/src/config.rs @@ -19,7 +19,7 @@ use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings}; use crate::ext::TaskExt; use crate::intern::RoleNameInt; use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig}; -use crate::scram::threadpool::ThreadPool; +use crate::scram; use crate::serverless::GlobalConnPoolOptions; use crate::serverless::cancel_set::CancelSet; #[cfg(feature = "rest_broker")] @@ -75,7 +75,7 @@ pub struct HttpConfig { } pub struct AuthenticationConfig { - pub thread_pool: Arc, + pub scram_thread_pool: Arc, pub scram_protocol_timeout: tokio::time::Duration, pub ip_allowlist_check_enabled: bool, pub is_vpc_acccess_proxy: bool, @@ -107,20 +107,23 @@ pub fn remote_storage_from_toml(s: &str) -> anyhow::Result #[derive(Debug)] pub struct CacheOptions { /// Max number of entries. - pub size: usize, + pub size: Option, /// Entry's time-to-live. - pub ttl: Duration, + pub absolute_ttl: Option, + /// Entry's time-to-idle. + pub idle_ttl: Option, } impl CacheOptions { - /// Default options for [`crate::control_plane::NodeInfoCache`]. - pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m"; + /// Default options for [`crate::cache::node_info::NodeInfoCache`]. + pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,idle_ttl=4m"; /// Parse cache options passed via cmdline. /// Example: [`Self::CACHE_DEFAULT_OPTIONS`]. fn parse(options: &str) -> anyhow::Result { let mut size = None; - let mut ttl = None; + let mut absolute_ttl = None; + let mut idle_ttl = None; for option in options.split(',') { let (key, value) = option @@ -129,21 +132,34 @@ impl CacheOptions { match key { "size" => size = Some(value.parse()?), - "ttl" => ttl = Some(humantime::parse_duration(value)?), + "absolute_ttl" | "ttl" => absolute_ttl = Some(humantime::parse_duration(value)?), + "idle_ttl" | "tti" => idle_ttl = Some(humantime::parse_duration(value)?), unknown => bail!("unknown key: {unknown}"), } } - // TTL doesn't matter if cache is always empty. - if let Some(0) = size { - ttl.get_or_insert(Duration::default()); - } - Ok(Self { - size: size.context("missing `size`")?, - ttl: ttl.context("missing `ttl`")?, + size, + absolute_ttl, + idle_ttl, }) } + + pub fn moka( + &self, + mut builder: moka::sync::CacheBuilder, + ) -> moka::sync::CacheBuilder { + if let Some(size) = self.size { + builder = builder.max_capacity(size); + } + if let Some(ttl) = self.absolute_ttl { + builder = builder.time_to_live(ttl); + } + if let Some(tti) = self.idle_ttl { + builder = builder.time_to_idle(tti); + } + builder + } } impl FromStr for CacheOptions { @@ -159,17 +175,17 @@ impl FromStr for CacheOptions { #[derive(Debug)] pub struct ProjectInfoCacheOptions { /// Max number of entries. - pub size: usize, + pub size: u64, /// Entry's time-to-live. pub ttl: Duration, /// Max number of roles per endpoint. 
- pub max_roles: usize, + pub max_roles: u64, /// Gc interval. pub gc_interval: Duration, } impl ProjectInfoCacheOptions { - /// Default options for [`crate::control_plane::NodeInfoCache`]. + /// Default options for [`crate::cache::project_info::ProjectInfoCache`]. pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=10000,ttl=4m,max_roles=10,gc_interval=60m"; @@ -496,21 +512,37 @@ mod tests { #[test] fn test_parse_cache_options() -> anyhow::Result<()> { - let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?; - assert_eq!(size, 4096); - assert_eq!(ttl, Duration::from_secs(5 * 60)); + let CacheOptions { + size, + absolute_ttl, + idle_ttl: _, + } = "size=4096,ttl=5min".parse()?; + assert_eq!(size, Some(4096)); + assert_eq!(absolute_ttl, Some(Duration::from_secs(5 * 60))); - let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?; - assert_eq!(size, 2); - assert_eq!(ttl, Duration::from_secs(4 * 60)); + let CacheOptions { + size, + absolute_ttl, + idle_ttl: _, + } = "ttl=4m,size=2".parse()?; + assert_eq!(size, Some(2)); + assert_eq!(absolute_ttl, Some(Duration::from_secs(4 * 60))); - let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?; - assert_eq!(size, 0); - assert_eq!(ttl, Duration::from_secs(1)); + let CacheOptions { + size, + absolute_ttl, + idle_ttl: _, + } = "size=0,ttl=1s".parse()?; + assert_eq!(size, Some(0)); + assert_eq!(absolute_ttl, Some(Duration::from_secs(1))); - let CacheOptions { size, ttl } = "size=0".parse()?; - assert_eq!(size, 0); - assert_eq!(ttl, Duration::default()); + let CacheOptions { + size, + absolute_ttl, + idle_ttl: _, + } = "size=0".parse()?; + assert_eq!(size, Some(0)); + assert_eq!(absolute_ttl, None); Ok(()) } diff --git a/proxy/src/console_redirect_proxy.rs b/proxy/src/console_redirect_proxy.rs index 639cd123e1..f947abebc0 100644 --- a/proxy/src/console_redirect_proxy.rs +++ b/proxy/src/console_redirect_proxy.rs @@ -16,8 +16,9 @@ use crate::pglb::ClientRequestError; use crate::pglb::handshake::{HandshakeData, handshake}; use crate::pglb::passthrough::ProxyPassthrough; use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol}; -use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute}; -use crate::proxy::{ErrorSource, forward_compute_params_to_client, send_client_greeting}; +use crate::proxy::{ + ErrorSource, connect_compute, forward_compute_params_to_client, send_client_greeting, +}; use crate::util::run_until_cancelled; pub async fn task_main( @@ -215,14 +216,11 @@ pub(crate) async fn handle_client( }; auth_info.set_startup_params(¶ms, true); - let mut node = connect_to_compute( + let mut node = connect_compute::connect_to_compute( ctx, - &TcpMechanism { - locks: &config.connect_compute_locks, - }, + config, &node_info, - config.wake_compute_retry_config, - &config.connect_to_compute, + connect_compute::TlsNegotiation::Postgres, ) .or_else(|e| async { Err(stream.throw_error(e, Some(ctx)).await) }) .await?; diff --git a/proxy/src/control_plane/client/cplane_proxy_v1.rs b/proxy/src/control_plane/client/cplane_proxy_v1.rs index 8a0403c0b0..b76b13e2c2 100644 --- a/proxy/src/control_plane/client/cplane_proxy_v1.rs +++ b/proxy/src/control_plane/client/cplane_proxy_v1.rs @@ -3,7 +3,6 @@ use std::net::IpAddr; use std::str::FromStr; use std::sync::Arc; -use std::time::Duration; use ::http::HeaderName; use ::http::header::AUTHORIZATION; @@ -17,6 +16,8 @@ use tracing::{Instrument, debug, info, info_span, warn}; use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute}; use 
crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; +use crate::cache::Cached; +use crate::cache::node_info::CachedNodeInfo; use crate::context::RequestContext; use crate::control_plane::caches::ApiCaches; use crate::control_plane::errors::{ @@ -25,8 +26,7 @@ use crate::control_plane::errors::{ use crate::control_plane::locks::ApiLocks; use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse}; use crate::control_plane::{ - AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo, - RoleAccessControl, + AccessBlockerFlags, AuthInfo, AuthSecret, EndpointAccessControl, NodeInfo, RoleAccessControl, }; use crate::metrics::Metrics; use crate::proxy::retry::CouldRetry; @@ -118,7 +118,6 @@ impl NeonControlPlaneClient { cache_key.into(), role.into(), msg.clone(), - retry_info.map(|r| Duration::from_millis(r.retry_delay_ms)), ); Err(err) @@ -347,18 +346,11 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { ) -> Result { let key = endpoint.normalize(); - if let Some((role_control, ttl)) = self - .caches - .project_info - .get_role_secret_with_ttl(&key, role) - { + if let Some(role_control) = self.caches.project_info.get_role_secret(&key, role) { return match role_control { - Err(mut msg) => { + Err(msg) => { info!(key = &*key, "found cached get_role_access_control error"); - // if retry_delay_ms is set change it to the remaining TTL - replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64); - Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg))) } Ok(role_control) => { @@ -383,17 +375,14 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { ) -> Result { let key = endpoint.normalize(); - if let Some((control, ttl)) = self.caches.project_info.get_endpoint_access_with_ttl(&key) { + if let Some(control) = self.caches.project_info.get_endpoint_access(&key) { return match control { - Err(mut msg) => { + Err(msg) => { info!( key = &*key, "found cached get_endpoint_access_control error" ); - // if retry_delay_ms is set change it to the remaining TTL - replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64); - Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg))) } Ok(control) => { @@ -426,17 +415,11 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { macro_rules! 
check_cache { () => { - if let Some(cached) = self.caches.node_info.get_with_created_at(&key) { - let (cached, (info, created_at)) = cached.take_value(); + if let Some(info) = self.caches.node_info.get_entry(&key) { return match info { - Err(mut msg) => { + Err(msg) => { info!(key = &*key, "found cached wake_compute error"); - // if retry_delay_ms is set, reduce it by the amount of time it spent in cache - replace_retry_delay_ms(&mut msg, |delay| { - delay.saturating_sub(created_at.elapsed().as_millis() as u64) - }); - Err(WakeComputeError::ControlPlane(ControlPlaneError::Message( msg, ))) @@ -444,7 +427,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { Ok(info) => { debug!(key = &*key, "found cached compute node info"); ctx.set_project(info.aux.clone()); - Ok(cached.map(|()| info)) + Ok(info) } }; } @@ -483,10 +466,12 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { let mut stored_node = node.clone(); // store the cached node as 'warm_cached' stored_node.aux.cold_start_info = ColdStartInfo::WarmCached; + self.caches.node_info.insert(key.clone(), Ok(stored_node)); - let (_, cached) = self.caches.node_info.insert_unit(key, Ok(stored_node)); - - Ok(cached.map(|()| node)) + Ok(Cached { + token: Some((&self.caches.node_info, key)), + value: node, + }) } Err(err) => match err { WakeComputeError::ControlPlane(ControlPlaneError::Message(ref msg)) => { @@ -503,11 +488,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { "created a cache entry for the wake compute error" ); - let ttl = retry_info.map_or(Duration::from_secs(30), |r| { - Duration::from_millis(r.retry_delay_ms) - }); - - self.caches.node_info.insert_ttl(key, Err(msg.clone()), ttl); + self.caches.node_info.insert(key, Err(msg.clone())); Err(err) } @@ -517,14 +498,6 @@ impl super::ControlPlaneApi for NeonControlPlaneClient { } } -fn replace_retry_delay_ms(msg: &mut ControlPlaneErrorMessage, f: impl FnOnce(u64) -> u64) { - if let Some(status) = &mut msg.status - && let Some(retry_info) = &mut status.details.retry_info - { - retry_info.retry_delay_ms = f(retry_info.retry_delay_ms); - } -} - /// Parse http response body, taking status code into account. 
fn parse_body serde::Deserialize<'a>>( status: StatusCode, diff --git a/proxy/src/control_plane/client/mock.rs b/proxy/src/control_plane/client/mock.rs index b84dba6b09..9e48d91340 100644 --- a/proxy/src/control_plane/client/mock.rs +++ b/proxy/src/control_plane/client/mock.rs @@ -15,6 +15,7 @@ use crate::auth::IpPattern; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; use crate::cache::Cached; +use crate::cache::node_info::CachedNodeInfo; use crate::compute::ConnectInfo; use crate::context::RequestContext; use crate::control_plane::errors::{ @@ -22,8 +23,7 @@ use crate::control_plane::errors::{ }; use crate::control_plane::messages::{EndpointRateLimitConfig, MetricsAuxInfo}; use crate::control_plane::{ - AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo, - RoleAccessControl, + AccessBlockerFlags, AuthInfo, AuthSecret, EndpointAccessControl, NodeInfo, RoleAccessControl, }; use crate::intern::RoleNameInt; use crate::scram; diff --git a/proxy/src/control_plane/client/mod.rs b/proxy/src/control_plane/client/mod.rs index ecd4db29b2..ec26746873 100644 --- a/proxy/src/control_plane/client/mod.rs +++ b/proxy/src/control_plane/client/mod.rs @@ -13,10 +13,11 @@ use tracing::{debug, info}; use super::{EndpointAccessControl, RoleAccessControl}; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError}; -use crate::cache::project_info::ProjectInfoCacheImpl; +use crate::cache::node_info::{CachedNodeInfo, NodeInfoCache}; +use crate::cache::project_info::ProjectInfoCache; use crate::config::{CacheOptions, ProjectInfoCacheOptions}; use crate::context::RequestContext; -use crate::control_plane::{CachedNodeInfo, ControlPlaneApi, NodeInfoCache, errors}; +use crate::control_plane::{ControlPlaneApi, errors}; use crate::error::ReportableError; use crate::metrics::ApiLockMetrics; use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token}; @@ -119,7 +120,7 @@ pub struct ApiCaches { /// Cache for the `wake_compute` API method. pub(crate) node_info: NodeInfoCache, /// Cache which stores project_id -> endpoint_ids mapping. 
- pub project_info: Arc, + pub project_info: Arc, } impl ApiCaches { @@ -128,13 +129,8 @@ impl ApiCaches { project_info_cache_config: ProjectInfoCacheOptions, ) -> Self { Self { - node_info: NodeInfoCache::new( - "node_info_cache", - wake_compute_cache_config.size, - wake_compute_cache_config.ttl, - true, - ), - project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)), + node_info: NodeInfoCache::new(wake_compute_cache_config), + project_info: Arc::new(ProjectInfoCache::new(project_info_cache_config)), } } } diff --git a/proxy/src/control_plane/messages.rs b/proxy/src/control_plane/messages.rs index d44d7efcc3..a23ddeb5c8 100644 --- a/proxy/src/control_plane/messages.rs +++ b/proxy/src/control_plane/messages.rs @@ -1,8 +1,10 @@ use std::fmt::{self, Display}; +use std::time::Duration; use measured::FixedCardinalityLabel; use serde::{Deserialize, Serialize}; use smol_str::SmolStr; +use tokio::time::Instant; use crate::auth::IpPattern; use crate::intern::{AccountIdInt, BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt}; @@ -231,7 +233,13 @@ impl Reason { #[derive(Copy, Clone, Debug, Deserialize)] #[allow(dead_code)] pub(crate) struct RetryInfo { - pub(crate) retry_delay_ms: u64, + #[serde(rename = "retry_delay_ms", deserialize_with = "milliseconds_from_now")] + pub(crate) retry_at: Instant, +} + +fn milliseconds_from_now<'de, D: serde::Deserializer<'de>>(d: D) -> Result { + let millis = u64::deserialize(d)?; + Ok(Instant::now() + Duration::from_millis(millis)) } #[derive(Debug, Deserialize, Clone)] diff --git a/proxy/src/control_plane/mod.rs b/proxy/src/control_plane/mod.rs index 9bbd3f4fb7..6f326d789a 100644 --- a/proxy/src/control_plane/mod.rs +++ b/proxy/src/control_plane/mod.rs @@ -16,14 +16,13 @@ use messages::EndpointRateLimitConfig; use crate::auth::backend::ComputeUserInfo; use crate::auth::backend::jwt::AuthRule; use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list}; -use crate::cache::{Cached, TimedLru}; -use crate::config::ComputeConfig; +use crate::cache::node_info::CachedNodeInfo; use crate::context::RequestContext; -use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo}; +use crate::control_plane::messages::MetricsAuxInfo; use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt}; use crate::protocol2::ConnectionInfoExtra; use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig}; -use crate::types::{EndpointCacheKey, EndpointId, RoleName}; +use crate::types::{EndpointId, RoleName}; use crate::{compute, scram}; /// Various cache-related types. 
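The messages.rs hunk above changes `RetryInfo` so the wire-level `retry_delay_ms` becomes an absolute `retry_at` deadline at deserialization time, which is what the new expiry policy consumes. A sketch of that serde pattern using std `Instant` and `serde_json` (the proxy itself uses tokio's `Instant`):

```rust
use std::time::{Duration, Instant};

use serde::{Deserialize, Deserializer};

// Convert a relative delay in milliseconds into an absolute deadline.
fn milliseconds_from_now<'de, D: Deserializer<'de>>(d: D) -> Result<Instant, D::Error> {
    let millis = u64::deserialize(d)?;
    Ok(Instant::now() + Duration::from_millis(millis))
}

#[derive(Deserialize)]
struct RetryInfo {
    #[serde(rename = "retry_delay_ms", deserialize_with = "milliseconds_from_now")]
    retry_at: Instant,
}

fn main() {
    let parsed: RetryInfo = serde_json::from_str(r#"{ "retry_delay_ms": 1500 }"#).unwrap();
    let remaining = parsed.retry_at.saturating_duration_since(Instant::now());
    assert!(remaining <= Duration::from_millis(1500));
    println!("retry in ~{remaining:?}");
}
```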
@@ -72,26 +71,12 @@ pub(crate) struct NodeInfo { pub(crate) aux: MetricsAuxInfo, } -impl NodeInfo { - pub(crate) async fn connect( - &self, - ctx: &RequestContext, - config: &ComputeConfig, - ) -> Result { - self.conn_info.connect(ctx, &self.aux, config).await - } -} - #[derive(Copy, Clone, Default, Debug)] pub(crate) struct AccessBlockerFlags { pub public_access_blocked: bool, pub vpc_access_blocked: bool, } -pub(crate) type NodeInfoCache = - TimedLru>>; -pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>; - #[derive(Clone, Debug)] pub struct RoleAccessControl { pub secret: Option, diff --git a/proxy/src/metrics.rs b/proxy/src/metrics.rs index 7524133093..905c9b5279 100644 --- a/proxy/src/metrics.rs +++ b/proxy/src/metrics.rs @@ -2,59 +2,60 @@ use std::sync::{Arc, OnceLock}; use lasso::ThreadedRodeo; use measured::label::{ - FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue, StaticLabelSet, + FixedCardinalitySet, LabelGroupSet, LabelGroupVisitor, LabelName, LabelSet, LabelValue, + StaticLabelSet, }; +use measured::metric::group::Encoding; use measured::metric::histogram::Thresholds; use measured::metric::name::MetricName; use measured::{ - Counter, CounterVec, FixedCardinalityLabel, Gauge, Histogram, HistogramVec, LabelGroup, - MetricGroup, + Counter, CounterVec, FixedCardinalityLabel, Gauge, GaugeVec, Histogram, HistogramVec, + LabelGroup, MetricGroup, }; -use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLogVec}; +use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLogVec, InfoMetric}; use tokio::time::{self, Instant}; use crate::control_plane::messages::ColdStartInfo; use crate::error::ErrorKind; #[derive(MetricGroup)] -#[metric(new(thread_pool: Arc))] +#[metric(new())] pub struct Metrics { #[metric(namespace = "proxy")] - #[metric(init = ProxyMetrics::new(thread_pool))] + #[metric(init = ProxyMetrics::new())] pub proxy: ProxyMetrics, #[metric(namespace = "wake_compute_lock")] pub wake_compute_lock: ApiLockMetrics, + + #[metric(namespace = "service")] + pub service: ServiceMetrics, + + #[metric(namespace = "cache")] + pub cache: CacheMetrics, } -static SELF: OnceLock = OnceLock::new(); impl Metrics { - pub fn install(thread_pool: Arc) { - let mut metrics = Metrics::new(thread_pool); - - metrics.proxy.errors_total.init_all_dense(); - metrics.proxy.redis_errors_total.init_all_dense(); - metrics.proxy.redis_events_count.init_all_dense(); - metrics.proxy.retries_metric.init_all_dense(); - metrics.proxy.connection_failures_total.init_all_dense(); - - SELF.set(metrics) - .ok() - .expect("proxy metrics must not be installed more than once"); - } - + #[track_caller] pub fn get() -> &'static Self { - #[cfg(test)] - return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0)))); + static SELF: OnceLock = OnceLock::new(); - #[cfg(not(test))] - SELF.get() - .expect("proxy metrics must be installed by the main() function") + SELF.get_or_init(|| { + let mut metrics = Metrics::new(); + + metrics.proxy.errors_total.init_all_dense(); + metrics.proxy.redis_errors_total.init_all_dense(); + metrics.proxy.redis_events_count.init_all_dense(); + metrics.proxy.retries_metric.init_all_dense(); + metrics.proxy.connection_failures_total.init_all_dense(); + + metrics + }) } } #[derive(MetricGroup)] -#[metric(new(thread_pool: Arc))] +#[metric(new())] pub struct ProxyMetrics { #[metric(flatten)] pub db_connections: CounterPairVec, @@ -127,6 +128,9 @@ pub struct ProxyMetrics { /// Number of TLS handshake failures pub tls_handshake_failures: Counter, + 
/// Number of SHA 256 rounds executed. + pub sha_rounds: Counter, + /// HLL approximate cardinality of endpoints that are connecting pub connecting_endpoints: HyperLogLogVec, 32>, @@ -144,8 +148,25 @@ pub struct ProxyMetrics { pub connect_compute_lock: ApiLockMetrics, #[metric(namespace = "scram_pool")] - #[metric(init = thread_pool)] - pub scram_pool: Arc, + pub scram_pool: OnceLockWrapper>, +} + +/// A Wrapper over [`OnceLock`] to implement [`MetricGroup`]. +pub struct OnceLockWrapper(pub OnceLock); + +impl Default for OnceLockWrapper { + fn default() -> Self { + Self(OnceLock::new()) + } +} + +impl> MetricGroup for OnceLockWrapper { + fn collect_group_into(&self, enc: &mut Enc) -> Result<(), Enc::Err> { + if let Some(inner) = self.0.get() { + inner.collect_group_into(enc)?; + } + Ok(()) + } } #[derive(MetricGroup)] @@ -215,13 +236,6 @@ pub enum Bool { False, } -#[derive(FixedCardinalityLabel, Copy, Clone)] -#[label(singleton = "outcome")] -pub enum CacheOutcome { - Hit, - Miss, -} - #[derive(LabelGroup)] #[label(set = ConsoleRequestSet)] pub struct ConsoleRequest<'a> { @@ -553,14 +567,6 @@ impl From for Bool { } } -#[derive(LabelGroup)] -#[label(set = InvalidEndpointsSet)] -pub struct InvalidEndpointsGroup { - pub protocol: Protocol, - pub rejected: Bool, - pub outcome: ConnectOutcome, -} - #[derive(LabelGroup)] #[label(set = RetriesMetricSet)] pub struct RetriesMetricGroup { @@ -660,3 +666,100 @@ pub struct ThreadPoolMetrics { #[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))] pub worker_task_skips_total: CounterVec, } + +#[derive(MetricGroup, Default)] +pub struct ServiceMetrics { + pub info: InfoMetric, +} + +#[derive(Default)] +pub struct ServiceInfo { + pub state: ServiceState, +} + +impl ServiceInfo { + pub const fn running() -> Self { + ServiceInfo { + state: ServiceState::Running, + } + } + + pub const fn terminating() -> Self { + ServiceInfo { + state: ServiceState::Terminating, + } + } +} + +impl LabelGroup for ServiceInfo { + fn visit_values(&self, v: &mut impl LabelGroupVisitor) { + const STATE: &LabelName = LabelName::from_str("state"); + v.write_value(STATE, &self.state); + } +} + +#[derive(FixedCardinalityLabel, Clone, Copy, Debug, Default)] +#[label(singleton = "state")] +pub enum ServiceState { + #[default] + Init, + Running, + Terminating, +} + +#[derive(MetricGroup)] +#[metric(new())] +pub struct CacheMetrics { + /// The capacity of the cache + pub capacity: GaugeVec>, + /// The total number of entries inserted into the cache + pub inserted_total: CounterVec>, + /// The total number of entries removed from the cache + pub evicted_total: CounterVec, + /// The total number of cache requests + pub request_total: CounterVec, +} + +impl Default for CacheMetrics { + fn default() -> Self { + Self::new() + } +} + +#[derive(FixedCardinalityLabel, Clone, Copy, Debug)] +#[label(singleton = "cache")] +pub enum CacheKind { + NodeInfo, + ProjectInfoEndpoints, + ProjectInfoRoles, + Schema, + Pbkdf2, +} + +#[derive(FixedCardinalityLabel, Clone, Copy, Debug)] +pub enum CacheRemovalCause { + Expired, + Explicit, + Replaced, + Size, +} + +#[derive(LabelGroup)] +#[label(set = CacheEvictionSet)] +pub struct CacheEviction { + pub cache: CacheKind, + pub cause: CacheRemovalCause, +} + +#[derive(FixedCardinalityLabel, Copy, Clone)] +pub enum CacheOutcome { + Hit, + Miss, +} + +#[derive(LabelGroup)] +#[label(set = CacheOutcomeSet)] +pub struct CacheOutcomeGroup { + pub cache: CacheKind, + pub outcome: CacheOutcome, +} diff --git a/proxy/src/proxy/connect_auth.rs 
b/proxy/src/proxy/connect_auth.rs new file mode 100644 index 0000000000..77578c71b1 --- /dev/null +++ b/proxy/src/proxy/connect_auth.rs @@ -0,0 +1,82 @@ +use thiserror::Error; + +use crate::auth::Backend; +use crate::auth::backend::ComputeUserInfo; +use crate::cache::common::Cache; +use crate::compute::{AuthInfo, ComputeConnection, ConnectionError, PostgresError}; +use crate::config::ProxyConfig; +use crate::context::RequestContext; +use crate::control_plane::client::ControlPlaneClient; +use crate::error::{ReportableError, UserFacingError}; +use crate::proxy::connect_compute::{TlsNegotiation, connect_to_compute}; +use crate::proxy::retry::ShouldRetryWakeCompute; + +#[derive(Debug, Error)] +pub enum AuthError { + #[error(transparent)] + Auth(#[from] PostgresError), + #[error(transparent)] + Connect(#[from] ConnectionError), +} + +impl UserFacingError for AuthError { + fn to_string_client(&self) -> String { + match self { + AuthError::Auth(postgres_error) => postgres_error.to_string_client(), + AuthError::Connect(connection_error) => connection_error.to_string_client(), + } + } +} + +impl ReportableError for AuthError { + fn get_error_kind(&self) -> crate::error::ErrorKind { + match self { + AuthError::Auth(postgres_error) => postgres_error.get_error_kind(), + AuthError::Connect(connection_error) => connection_error.get_error_kind(), + } + } +} + +/// Try to connect to the compute node, retrying if necessary. +#[tracing::instrument(skip_all)] +pub(crate) async fn connect_to_compute_and_auth( + ctx: &RequestContext, + config: &ProxyConfig, + user_info: &Backend<'_, ComputeUserInfo>, + auth_info: AuthInfo, + tls: TlsNegotiation, +) -> Result { + let mut attempt = 0; + + // NOTE: This is messy, but should hopefully be detangled with PGLB. + // We wanted to separate the concerns of **connect** to compute (a PGLB operation), + // from **authenticate** to compute (a NeonKeeper operation). + // + // This unfortunately removed retry handling for one error case where + // the compute was cached, and we connected, but the compute cache was actually stale + // and is associated with the wrong endpoint. We detect this when the **authentication** fails. + // As such, we retry once here if the `authenticate` function fails and the error is valid to retry. 
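+    // The loop below makes at most one extra attempt (`attempt < 2`): if authentication fails
+    // with an error where `should_retry_wake_compute()` holds, and the backend is a
+    // control-plane (ProxyV1) client, the cached node info for the endpoint is invalidated
+    // and the connect + authenticate sequence is tried once more.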
+ loop { + attempt += 1; + let mut node = connect_to_compute(ctx, config, user_info, tls).await?; + + let res = auth_info.authenticate(ctx, &mut node).await; + match res { + Ok(()) => return Ok(node), + Err(e) => { + if attempt < 2 + && let Backend::ControlPlane(cplane, user_info) = user_info + && let ControlPlaneClient::ProxyV1(cplane_proxy_v1) = &**cplane + && e.should_retry_wake_compute() + { + tracing::warn!(error = ?e, "retrying wake compute"); + let key = user_info.endpoint_cache_key(); + cplane_proxy_v1.caches.node_info.invalidate(&key); + continue; + } + + return Err(e)?; + } + } + } +} diff --git a/proxy/src/proxy/connect_compute.rs b/proxy/src/proxy/connect_compute.rs index ce9774e3eb..515f925236 100644 --- a/proxy/src/proxy/connect_compute.rs +++ b/proxy/src/proxy/connect_compute.rs @@ -1,18 +1,16 @@ -use async_trait::async_trait; use tokio::time; use tracing::{debug, info, warn}; +use crate::cache::node_info::CachedNodeInfo; use crate::compute::{self, COULD_NOT_CONNECT, ComputeConnection}; -use crate::config::{ComputeConfig, RetryConfig}; +use crate::config::{ComputeConfig, ProxyConfig, RetryConfig}; use crate::context::RequestContext; -use crate::control_plane::errors::WakeComputeError; +use crate::control_plane::NodeInfo; use crate::control_plane::locks::ApiLocks; -use crate::control_plane::{self, NodeInfo}; -use crate::error::ReportableError; use crate::metrics::{ ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType, }; -use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute, retry_after, should_retry}; +use crate::proxy::retry::{ShouldRetryWakeCompute, retry_after, should_retry}; use crate::proxy::wake_compute::{WakeComputeBackend, wake_compute}; use crate::types::Host; @@ -20,7 +18,7 @@ use crate::types::Host; /// (e.g. the compute node's address might've changed at the wrong time). /// Invalidate the cache entry (if any) to prevent subsequent errors. 
#[tracing::instrument(skip_all)] -pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> NodeInfo { +pub(crate) fn invalidate_cache(node_info: CachedNodeInfo) -> NodeInfo { let is_cached = node_info.cached(); if is_cached { warn!("invalidating stalled compute node info cache entry"); @@ -35,29 +33,32 @@ pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> Node node_info.invalidate() } -#[async_trait] pub(crate) trait ConnectMechanism { type Connection; - type ConnectError: ReportableError; - type Error: From; async fn connect_once( &self, ctx: &RequestContext, - node_info: &control_plane::CachedNodeInfo, + node_info: &CachedNodeInfo, config: &ComputeConfig, - ) -> Result; + ) -> Result; } -pub(crate) struct TcpMechanism { +struct TcpMechanism<'a> { /// connect_to_compute concurrency lock - pub(crate) locks: &'static ApiLocks, + locks: &'a ApiLocks, + tls: TlsNegotiation, } -#[async_trait] -impl ConnectMechanism for TcpMechanism { +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum TlsNegotiation { + /// TLS is assumed + Direct, + /// We must ask for TLS using the postgres SSLRequest message + Postgres, +} + +impl ConnectMechanism for TcpMechanism<'_> { type Connection = ComputeConnection; - type ConnectError = compute::ConnectionError; - type Error = compute::ConnectionError; #[tracing::instrument(skip_all, fields( pid = tracing::field::Empty, @@ -66,27 +67,49 @@ impl ConnectMechanism for TcpMechanism { async fn connect_once( &self, ctx: &RequestContext, - node_info: &control_plane::CachedNodeInfo, + node_info: &CachedNodeInfo, config: &ComputeConfig, - ) -> Result { + ) -> Result { let permit = self.locks.get_permit(&node_info.conn_info.host).await?; - permit.release_result(node_info.connect(ctx, config).await) + + permit.release_result( + node_info + .conn_info + .connect(ctx, &node_info.aux, config, self.tls) + .await, + ) } } /// Try to connect to the compute node, retrying if necessary. #[tracing::instrument(skip_all)] -pub(crate) async fn connect_to_compute( +pub(crate) async fn connect_to_compute( + ctx: &RequestContext, + config: &ProxyConfig, + user_info: &B, + tls: TlsNegotiation, +) -> Result { + connect_to_compute_inner( + ctx, + &TcpMechanism { + locks: &config.connect_compute_locks, + tls, + }, + user_info, + config.wake_compute_retry_config, + &config.connect_to_compute, + ) + .await +} + +/// Try to connect to the compute node, retrying if necessary. 
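+/// Unlike [`connect_to_compute`], this variant is generic over the [`ConnectMechanism`],
+/// so unit tests can drive it with a mock mechanism while production code goes through
+/// [`TcpMechanism`] via [`connect_to_compute`].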
+pub(crate) async fn connect_to_compute_inner( ctx: &RequestContext, mechanism: &M, user_info: &B, wake_compute_retry_config: RetryConfig, compute: &ComputeConfig, -) -> Result -where - M::ConnectError: CouldRetry + ShouldRetryWakeCompute + std::fmt::Debug, - M::Error: From, -{ +) -> Result { let mut num_retries = 0; let node_info = wake_compute(&mut num_retries, ctx, user_info, wake_compute_retry_config).await?; @@ -120,7 +143,7 @@ where }, num_retries.into(), ); - return Err(err.into()); + return Err(err); } node_info } else { @@ -161,7 +184,7 @@ where }, num_retries.into(), ); - return Err(e.into()); + return Err(e); } warn!(error = ?e, num_retries, retriable = true, COULD_NOT_CONNECT); diff --git a/proxy/src/proxy/mod.rs b/proxy/src/proxy/mod.rs index 053726505d..b42457cd95 100644 --- a/proxy/src/proxy/mod.rs +++ b/proxy/src/proxy/mod.rs @@ -1,6 +1,7 @@ #[cfg(test)] mod tests; +pub(crate) mod connect_auth; pub(crate) mod connect_compute; pub(crate) mod retry; pub(crate) mod wake_compute; @@ -23,17 +24,13 @@ use tokio::net::TcpStream; use tokio::sync::oneshot; use tracing::Instrument; -use crate::cache::Cache; use crate::cancellation::{CancelClosure, CancellationHandler}; use crate::compute::{ComputeConnection, PostgresError, RustlsStream}; use crate::config::ProxyConfig; use crate::context::RequestContext; -use crate::control_plane::client::ControlPlaneClient; pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute}; use crate::pglb::{ClientMode, ClientRequestError}; use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams}; -use crate::proxy::connect_compute::{TcpMechanism, connect_to_compute}; -use crate::proxy::retry::ShouldRetryWakeCompute; use crate::rate_limiter::EndpointRateLimiter; use crate::stream::{PqStream, Stream}; use crate::types::EndpointCacheKey; @@ -95,61 +92,24 @@ pub(crate) async fn handle_client( let mut auth_info = compute::AuthInfo::with_auth_keys(creds.keys); auth_info.set_startup_params(params, params_compat); - let mut node; - let mut attempt = 0; - let connect = TcpMechanism { - locks: &config.connect_compute_locks, - }; let backend = auth::Backend::ControlPlane(cplane, creds.info); - // NOTE: This is messy, but should hopefully be detangled with PGLB. - // We wanted to separate the concerns of **connect** to compute (a PGLB operation), - // from **authenticate** to compute (a NeonKeeper operation). - // - // This unfortunately removed retry handling for one error case where - // the compute was cached, and we connected, but the compute cache was actually stale - // and is associated with the wrong endpoint. We detect this when the **authentication** fails. - // As such, we retry once here if the `authenticate` function fails and the error is valid to retry. 
- loop { - attempt += 1; + // TODO: callback to pglb + let res = connect_auth::connect_to_compute_and_auth( + ctx, + config, + &backend, + auth_info, + connect_compute::TlsNegotiation::Postgres, + ) + .await; - // TODO: callback to pglb - let res = connect_to_compute( - ctx, - &connect, - &backend, - config.wake_compute_retry_config, - &config.connect_to_compute, - ) - .await; + let mut node = match res { + Ok(node) => node, + Err(e) => Err(client.throw_error(e, Some(ctx)).await)?, + }; - match res { - Ok(n) => node = n, - Err(e) => return Err(client.throw_error(e, Some(ctx)).await)?, - } - - let auth::Backend::ControlPlane(cplane, user_info) = &backend else { - unreachable!("ensured above"); - }; - - let res = auth_info.authenticate(ctx, &mut node).await; - match res { - Ok(()) => { - send_client_greeting(ctx, &config.greetings, client); - break; - } - Err(e) if attempt < 2 && e.should_retry_wake_compute() => { - tracing::warn!(error = ?e, "retrying wake compute"); - - #[allow(irrefutable_let_patterns)] - if let ControlPlaneClient::ProxyV1(cplane_proxy_v1) = &**cplane { - let key = user_info.endpoint_cache_key(); - cplane_proxy_v1.caches.node_info.invalidate(&key); - } - } - Err(e) => Err(client.throw_error(e, Some(ctx)).await)?, - } - } + send_client_greeting(ctx, &config.greetings, client); let auth::Backend::ControlPlane(_, user_info) = backend else { unreachable!("ensured above"); diff --git a/proxy/src/proxy/retry.rs b/proxy/src/proxy/retry.rs index b06c3be72c..876d252517 100644 --- a/proxy/src/proxy/retry.rs +++ b/proxy/src/proxy/retry.rs @@ -31,18 +31,6 @@ impl CouldRetry for io::Error { } } -impl CouldRetry for postgres_client::error::DbError { - fn could_retry(&self) -> bool { - use postgres_client::error::SqlState; - matches!( - self.code(), - &SqlState::CONNECTION_FAILURE - | &SqlState::CONNECTION_EXCEPTION - | &SqlState::CONNECTION_DOES_NOT_EXIST - | &SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION, - ) - } -} impl ShouldRetryWakeCompute for postgres_client::error::DbError { fn should_retry_wake_compute(&self) -> bool { use postgres_client::error::SqlState; @@ -73,17 +61,6 @@ impl ShouldRetryWakeCompute for postgres_client::error::DbError { } } -impl CouldRetry for postgres_client::Error { - fn could_retry(&self) -> bool { - if let Some(io_err) = self.source().and_then(|x| x.downcast_ref()) { - io::Error::could_retry(io_err) - } else if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) { - postgres_client::error::DbError::could_retry(db_err) - } else { - false - } - } -} impl ShouldRetryWakeCompute for postgres_client::Error { fn should_retry_wake_compute(&self) -> bool { if let Some(db_err) = self.source().and_then(|x| x.downcast_ref()) { @@ -102,6 +79,8 @@ impl CouldRetry for compute::ConnectionError { compute::ConnectionError::TlsError(err) => err.could_retry(), compute::ConnectionError::WakeComputeError(err) => err.could_retry(), compute::ConnectionError::TooManyConnectionAttempts(_) => false, + #[cfg(test)] + compute::ConnectionError::TestError { retryable, .. } => *retryable, } } } @@ -110,6 +89,8 @@ impl ShouldRetryWakeCompute for compute::ConnectionError { match self { // the cache entry was not checked for validity compute::ConnectionError::TooManyConnectionAttempts(_) => false, + #[cfg(test)] + compute::ConnectionError::TestError { wakeable, .. 
} => *wakeable, _ => true, } } diff --git a/proxy/src/proxy/tests/mod.rs b/proxy/src/proxy/tests/mod.rs index f8bff450e1..7e0710749e 100644 --- a/proxy/src/proxy/tests/mod.rs +++ b/proxy/src/proxy/tests/mod.rs @@ -15,22 +15,24 @@ use rstest::rstest; use rustls::crypto::ring; use rustls::pki_types; use tokio::io::{AsyncRead, AsyncWrite, DuplexStream}; +use tokio::time::Instant; use tracing_test::traced_test; use super::retry::CouldRetry; use crate::auth::backend::{ComputeUserInfo, MaybeOwned}; -use crate::config::{ComputeConfig, RetryConfig, TlsConfig}; +use crate::cache::node_info::{CachedNodeInfo, NodeInfoCache}; +use crate::config::{CacheOptions, ComputeConfig, RetryConfig, TlsConfig}; use crate::context::RequestContext; use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient}; use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status}; -use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache}; -use crate::error::{ErrorKind, ReportableError}; +use crate::control_plane::{self, NodeInfo}; +use crate::error::ErrorKind; use crate::pglb::ERR_INSECURE_CONNECTION; use crate::pglb::handshake::{HandshakeData, handshake}; use crate::pqproto::BeMessage; use crate::proxy::NeonOptions; -use crate::proxy::connect_compute::{ConnectMechanism, connect_to_compute}; -use crate::proxy::retry::{ShouldRetryWakeCompute, retry_after}; +use crate::proxy::connect_compute::{ConnectMechanism, connect_to_compute_inner}; +use crate::proxy::retry::retry_after; use crate::stream::{PqStream, Stream}; use crate::tls::client_config::compute_client_config_with_certs; use crate::tls::server_config::CertResolver; @@ -417,12 +419,11 @@ impl TestConnectMechanism { Self { counter: Arc::new(std::sync::Mutex::new(0)), sequence, - cache: Box::leak(Box::new(NodeInfoCache::new( - "test", - 1, - Duration::from_secs(100), - false, - ))), + cache: Box::leak(Box::new(NodeInfoCache::new(CacheOptions { + size: Some(1), + absolute_ttl: Some(Duration::from_secs(100)), + idle_ttl: None, + }))), } } } @@ -430,71 +431,36 @@ impl TestConnectMechanism { #[derive(Debug)] struct TestConnection; -#[derive(Debug)] -struct TestConnectError { - retryable: bool, - wakeable: bool, - kind: crate::error::ErrorKind, -} - -impl ReportableError for TestConnectError { - fn get_error_kind(&self) -> crate::error::ErrorKind { - self.kind - } -} - -impl std::fmt::Display for TestConnectError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{self:?}") - } -} - -impl std::error::Error for TestConnectError {} - -impl CouldRetry for TestConnectError { - fn could_retry(&self) -> bool { - self.retryable - } -} -impl ShouldRetryWakeCompute for TestConnectError { - fn should_retry_wake_compute(&self) -> bool { - self.wakeable - } -} - -#[async_trait] impl ConnectMechanism for TestConnectMechanism { type Connection = TestConnection; - type ConnectError = TestConnectError; - type Error = anyhow::Error; async fn connect_once( &self, _ctx: &RequestContext, - _node_info: &control_plane::CachedNodeInfo, + _node_info: &CachedNodeInfo, _config: &ComputeConfig, - ) -> Result { + ) -> Result { let mut counter = self.counter.lock().unwrap(); let action = self.sequence[*counter]; *counter += 1; match action { ConnectAction::Connect => Ok(TestConnection), - ConnectAction::Retry => Err(TestConnectError { + ConnectAction::Retry => Err(compute::ConnectionError::TestError { retryable: true, wakeable: true, kind: ErrorKind::Compute, }), - ConnectAction::RetryNoWake => 
Err(TestConnectError { + ConnectAction::RetryNoWake => Err(compute::ConnectionError::TestError { retryable: true, wakeable: false, kind: ErrorKind::Compute, }), - ConnectAction::Fail => Err(TestConnectError { + ConnectAction::Fail => Err(compute::ConnectionError::TestError { retryable: false, wakeable: true, kind: ErrorKind::Compute, }), - ConnectAction::FailNoWake => Err(TestConnectError { + ConnectAction::FailNoWake => Err(compute::ConnectionError::TestError { retryable: false, wakeable: false, kind: ErrorKind::Compute, @@ -536,7 +502,7 @@ impl TestControlPlaneClient for TestConnectMechanism { details: Details { error_info: None, retry_info: Some(control_plane::messages::RetryInfo { - retry_delay_ms: 1, + retry_at: Instant::now() + Duration::from_millis(1), }), user_facing_message: None, }, @@ -581,8 +547,11 @@ fn helper_create_uncached_node_info() -> NodeInfo { fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo { let node = helper_create_uncached_node_info(); - let (_, node2) = cache.insert_unit("key".into(), Ok(node.clone())); - node2.map(|()| node) + cache.insert("key".into(), Ok(node.clone())); + CachedNodeInfo { + token: Some((cache, "key".into())), + value: node, + } } fn helper_create_connect_info( @@ -620,7 +589,7 @@ async fn connect_to_compute_success() { let mechanism = TestConnectMechanism::new(vec![Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -634,7 +603,7 @@ async fn connect_to_compute_retry() { let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -649,7 +618,7 @@ async fn connect_to_compute_non_retry_1() { let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Fail]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap_err(); mechanism.verify(); @@ -664,7 +633,7 @@ async fn connect_to_compute_non_retry_2() { let mechanism = TestConnectMechanism::new(vec![Wake, Fail, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -686,7 +655,7 @@ async fn connect_to_compute_non_retry_3() { backoff_factor: 2.0, }; let config = config(); - connect_to_compute( + connect_to_compute_inner( &ctx, &mechanism, &user_info, @@ -707,7 +676,7 @@ async fn wake_retry() { let mechanism = TestConnectMechanism::new(vec![WakeRetry, Wake, Connect]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap(); mechanism.verify(); @@ -722,7 +691,7 @@ async fn wake_non_retry() { let mechanism = 
TestConnectMechanism::new(vec![WakeRetry, WakeFail]); let user_info = helper_create_connect_info(&mechanism); let config = config(); - connect_to_compute(&ctx, &mechanism, &user_info, config.retry, &config) + connect_to_compute_inner(&ctx, &mechanism, &user_info, config.retry, &config) .await .unwrap_err(); mechanism.verify(); @@ -741,7 +710,7 @@ async fn fail_but_wake_invalidates_cache() { let user = helper_create_connect_info(&mech); let cfg = config(); - connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg) + connect_to_compute_inner(&ctx, &mech, &user, cfg.retry, &cfg) .await .unwrap(); @@ -762,7 +731,7 @@ async fn fail_no_wake_skips_cache_invalidation() { let user = helper_create_connect_info(&mech); let cfg = config(); - connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg) + connect_to_compute_inner(&ctx, &mech, &user, cfg.retry, &cfg) .await .unwrap(); @@ -783,7 +752,7 @@ async fn retry_but_wake_invalidates_cache() { let user_info = helper_create_connect_info(&mechanism); let cfg = config(); - connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg) + connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg) .await .unwrap(); mechanism.verify(); @@ -806,7 +775,7 @@ async fn retry_no_wake_skips_invalidation() { let user_info = helper_create_connect_info(&mechanism); let cfg = config(); - connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg) + connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg) .await .unwrap_err(); mechanism.verify(); @@ -829,7 +798,7 @@ async fn retry_no_wake_error_fast() { let user_info = helper_create_connect_info(&mechanism); let cfg = config(); - connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg) + connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg) .await .unwrap_err(); mechanism.verify(); @@ -852,7 +821,7 @@ async fn retry_cold_wake_skips_invalidation() { let user_info = helper_create_connect_info(&mechanism); let cfg = config(); - connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg) + connect_to_compute_inner(&ctx, &mechanism, &user_info, cfg.retry, &cfg) .await .unwrap(); mechanism.verify(); diff --git a/proxy/src/proxy/wake_compute.rs b/proxy/src/proxy/wake_compute.rs index b8edf9fd5c..66f2df2af4 100644 --- a/proxy/src/proxy/wake_compute.rs +++ b/proxy/src/proxy/wake_compute.rs @@ -1,9 +1,9 @@ use async_trait::async_trait; use tracing::{error, info}; +use crate::cache::node_info::CachedNodeInfo; use crate::config::RetryConfig; use crate::context::RequestContext; -use crate::control_plane::CachedNodeInfo; use crate::control_plane::errors::{ControlPlaneError, WakeComputeError}; use crate::error::ReportableError; use crate::metrics::{ diff --git a/proxy/src/redis/notifications.rs b/proxy/src/redis/notifications.rs index 88d5550fff..a3a378c7e2 100644 --- a/proxy/src/redis/notifications.rs +++ b/proxy/src/redis/notifications.rs @@ -131,11 +131,11 @@ where Ok(()) } -struct MessageHandler { +struct MessageHandler { cache: Arc, } -impl Clone for MessageHandler { +impl Clone for MessageHandler { fn clone(&self) -> Self { Self { cache: self.cache.clone(), @@ -143,8 +143,8 @@ impl Clone for MessageHandler { } } -impl MessageHandler { - pub(crate) fn new(cache: Arc) -> Self { +impl MessageHandler { + pub(crate) fn new(cache: Arc) -> Self { Self { cache } } @@ -224,7 +224,7 @@ impl MessageHandler { } } -fn invalidate_cache(cache: Arc, msg: Notification) { +fn invalidate_cache(cache: Arc, msg: Notification) { match msg { Notification::EndpointSettingsUpdate(ids) => 
ids .iter() @@ -247,8 +247,8 @@ fn invalidate_cache(cache: Arc, msg: Notification) { } } -async fn handle_messages( - handler: MessageHandler, +async fn handle_messages( + handler: MessageHandler, redis: ConnectionWithCredentialsProvider, cancellation_token: CancellationToken, ) -> anyhow::Result<()> { @@ -284,13 +284,10 @@ async fn handle_messages( /// Handle console's invalidation messages. #[tracing::instrument(name = "redis_notifications", skip_all)] -pub async fn task_main( +pub async fn task_main( redis: ConnectionWithCredentialsProvider, - cache: Arc, -) -> anyhow::Result -where - C: ProjectInfoCache + Send + Sync + 'static, -{ + cache: Arc, +) -> anyhow::Result { let handler = MessageHandler::new(cache); // 6h - 1m. // There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost. diff --git a/proxy/src/scram/cache.rs b/proxy/src/scram/cache.rs new file mode 100644 index 0000000000..9ade7af458 --- /dev/null +++ b/proxy/src/scram/cache.rs @@ -0,0 +1,84 @@ +use tokio::time::Instant; +use zeroize::Zeroize as _; + +use super::pbkdf2; +use crate::cache::Cached; +use crate::cache::common::{Cache, count_cache_insert, count_cache_outcome, eviction_listener}; +use crate::intern::{EndpointIdInt, RoleNameInt}; +use crate::metrics::{CacheKind, Metrics}; + +pub(crate) struct Pbkdf2Cache(moka::sync::Cache<(EndpointIdInt, RoleNameInt), Pbkdf2CacheEntry>); +pub(crate) type CachedPbkdf2<'a> = Cached<&'a Pbkdf2Cache>; + +impl Cache for Pbkdf2Cache { + type Key = (EndpointIdInt, RoleNameInt); + type Value = Pbkdf2CacheEntry; + + fn invalidate(&self, info: &(EndpointIdInt, RoleNameInt)) { + self.0.invalidate(info); + } +} + +/// To speed up password hashing for more active customers, we store the tail results of the +/// PBKDF2 algorithm. If the output of PBKDF2 is U1 ^ U2 ^ ⋯ ^ Uc, then we store +/// suffix = U17 ^ U18 ^ ⋯ ^ Uc. We only need to calculate U1 ^ U2 ^ ⋯ ^ U15 ^ U16 +/// to determine the final result. +/// +/// The suffix alone isn't enough to crack the password. The stored_key is still required. +/// While both are cached in memory, given they're in different locations is makes it much +/// harder to exploit, even if any such memory exploit exists in proxy. +#[derive(Clone)] +pub struct Pbkdf2CacheEntry { + /// corresponds to [`super::ServerSecret::cached_at`] + pub(super) cached_from: Instant, + pub(super) suffix: pbkdf2::Block, +} + +impl Drop for Pbkdf2CacheEntry { + fn drop(&mut self) { + self.suffix.zeroize(); + } +} + +impl Pbkdf2Cache { + pub fn new() -> Self { + const SIZE: u64 = 100; + const TTL: std::time::Duration = std::time::Duration::from_secs(60); + + let builder = moka::sync::Cache::builder() + .name("pbkdf2") + .max_capacity(SIZE) + // We use time_to_live so we don't refresh the lifetime for an invalid password attempt. 
+ .time_to_live(TTL); + + Metrics::get() + .cache + .capacity + .set(CacheKind::Pbkdf2, SIZE as i64); + + let builder = + builder.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::Pbkdf2, cause)); + + Self(builder.build()) + } + + pub fn insert(&self, endpoint: EndpointIdInt, role: RoleNameInt, value: Pbkdf2CacheEntry) { + count_cache_insert(CacheKind::Pbkdf2); + self.0.insert((endpoint, role), value); + } + + fn get(&self, endpoint: EndpointIdInt, role: RoleNameInt) -> Option { + count_cache_outcome(CacheKind::Pbkdf2, self.0.get(&(endpoint, role))) + } + + pub fn get_entry( + &self, + endpoint: EndpointIdInt, + role: RoleNameInt, + ) -> Option> { + self.get(endpoint, role).map(|value| Cached { + token: Some((self, (endpoint, role))), + value, + }) + } +} diff --git a/proxy/src/scram/exchange.rs b/proxy/src/scram/exchange.rs index a0918fca9f..3f4b0d534b 100644 --- a/proxy/src/scram/exchange.rs +++ b/proxy/src/scram/exchange.rs @@ -4,10 +4,8 @@ use std::convert::Infallible; use base64::Engine as _; use base64::prelude::BASE64_STANDARD; -use hmac::{Hmac, Mac}; -use sha2::Sha256; +use tracing::{debug, trace}; -use super::ScramKey; use super::messages::{ ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN, }; @@ -15,8 +13,10 @@ use super::pbkdf2::Pbkdf2; use super::secret::ServerSecret; use super::signature::SignatureBuilder; use super::threadpool::ThreadPool; -use crate::intern::EndpointIdInt; +use super::{ScramKey, pbkdf2}; +use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::sasl::{self, ChannelBinding, Error as SaslError}; +use crate::scram::cache::Pbkdf2CacheEntry; /// The only channel binding mode we currently support. #[derive(Debug)] @@ -77,46 +77,113 @@ impl<'a> Exchange<'a> { } } -// copied from async fn derive_client_key( pool: &ThreadPool, endpoint: EndpointIdInt, password: &[u8], salt: &[u8], iterations: u32, -) -> ScramKey { - let salted_password = pool - .spawn_job(endpoint, Pbkdf2::start(password, salt, iterations)) - .await; - - let make_key = |name| { - let key = Hmac::::new_from_slice(&salted_password) - .expect("HMAC is able to accept all key sizes") - .chain_update(name) - .finalize(); - - <[u8; 32]>::from(key.into_bytes()) - }; - - make_key(b"Client Key").into() +) -> pbkdf2::Block { + pool.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations)) + .await } +/// For cleartext flow, we need to derive the client key to +/// 1. authenticate the client. +/// 2. authenticate with compute. pub(crate) async fn exchange( pool: &ThreadPool, endpoint: EndpointIdInt, + role: RoleNameInt, + secret: &ServerSecret, + password: &[u8], +) -> sasl::Result> { + if secret.iterations > CACHED_ROUNDS { + exchange_with_cache(pool, endpoint, role, secret, password).await + } else { + let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?; + let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await; + Ok(validate_pbkdf2(secret, &hash)) + } +} + +/// Compute the client key using a cache. We cache the suffix of the pbkdf2 result only, +/// which is not enough by itself to perform an offline brute force. 
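+///
+/// Concretely: if the full PBKDF2 output is H = U1 ^ U2 ^ ⋯ ^ Uc, the cache holds
+/// suffix = U{CACHED_ROUNDS+1} ^ ⋯ ^ Uc, each attempt recomputes only
+/// prefix = U1 ^ ⋯ ^ U{CACHED_ROUNDS}, and the full hash is recovered as
+/// prefix ^ suffix via [`pbkdf2::xor_assign`].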
+async fn exchange_with_cache( + pool: &ThreadPool, + endpoint: EndpointIdInt, + role: RoleNameInt, secret: &ServerSecret, password: &[u8], ) -> sasl::Result> { let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?; - let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await; + debug_assert!( + secret.iterations > CACHED_ROUNDS, + "we should not cache password data if there isn't enough rounds needed" + ); + + // compute the prefix of the pbkdf2 output. + let prefix = derive_client_key(pool, endpoint, password, &salt, CACHED_ROUNDS).await; + + if let Some(entry) = pool.cache.get_entry(endpoint, role) { + // hot path: let's check the threadpool cache + if secret.cached_at == entry.cached_from { + // cache is valid. compute the full hash by adding the prefix to the suffix. + let mut hash = prefix; + pbkdf2::xor_assign(&mut hash, &entry.suffix); + let outcome = validate_pbkdf2(secret, &hash); + + if matches!(outcome, sasl::Outcome::Success(_)) { + trace!("password validated from cache"); + } + + return Ok(outcome); + } + + // cached key is no longer valid. + debug!("invalidating cached password"); + entry.invalidate(); + } + + // slow path: full password hash. + let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await; + let outcome = validate_pbkdf2(secret, &hash); + + let client_key = match outcome { + sasl::Outcome::Success(client_key) => client_key, + sasl::Outcome::Failure(_) => return Ok(outcome), + }; + + trace!("storing cached password"); + + // time to cache, compute the suffix by subtracting the prefix from the hash. + let mut suffix = hash; + pbkdf2::xor_assign(&mut suffix, &prefix); + + pool.cache.insert( + endpoint, + role, + Pbkdf2CacheEntry { + cached_from: secret.cached_at, + suffix, + }, + ); + + Ok(sasl::Outcome::Success(client_key)) +} + +fn validate_pbkdf2(secret: &ServerSecret, hash: &pbkdf2::Block) -> sasl::Outcome { + let client_key = super::ScramKey::client_key(&(*hash).into()); if secret.is_password_invalid(&client_key).into() { - Ok(sasl::Outcome::Failure("password doesn't match")) + sasl::Outcome::Failure("password doesn't match") } else { - Ok(sasl::Outcome::Success(client_key)) + sasl::Outcome::Success(client_key) } } +const CACHED_ROUNDS: u32 = 16; + impl SaslInitial { fn transition( &self, diff --git a/proxy/src/scram/key.rs b/proxy/src/scram/key.rs index fe55ff493b..7dc52fd409 100644 --- a/proxy/src/scram/key.rs +++ b/proxy/src/scram/key.rs @@ -1,6 +1,12 @@ //! Tools for client/server/stored key management. +use hmac::Mac as _; +use sha2::Digest as _; use subtle::ConstantTimeEq; +use zeroize::Zeroize as _; + +use crate::metrics::Metrics; +use crate::scram::pbkdf2::Prf; /// Faithfully taken from PostgreSQL. pub(crate) const SCRAM_KEY_LEN: usize = 32; @@ -14,6 +20,12 @@ pub(crate) struct ScramKey { bytes: [u8; SCRAM_KEY_LEN], } +impl Drop for ScramKey { + fn drop(&mut self) { + self.bytes.zeroize(); + } +} + impl PartialEq for ScramKey { fn eq(&self, other: &Self) -> bool { self.ct_eq(other).into() @@ -28,12 +40,26 @@ impl ConstantTimeEq for ScramKey { impl ScramKey { pub(crate) fn sha256(&self) -> Self { - super::sha256([self.as_ref()]).into() + Metrics::get().proxy.sha_rounds.inc_by(1); + Self { + bytes: sha2::Sha256::digest(self.as_bytes()).into(), + } } pub(crate) fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] { self.bytes } + + pub(crate) fn client_key(b: &[u8; 32]) -> Self { + // Prf::new_from_slice will run 2 sha256 rounds. + // Update + Finalize run 2 sha256 rounds. 
+ Metrics::get().proxy.sha_rounds.inc_by(4); + + let mut prf = Prf::new_from_slice(b).expect("HMAC is able to accept all key sizes"); + prf.update(b"Client Key"); + let client_key: [u8; 32] = prf.finalize().into_bytes().into(); + client_key.into() + } } impl From<[u8; SCRAM_KEY_LEN]> for ScramKey { diff --git a/proxy/src/scram/mod.rs b/proxy/src/scram/mod.rs index 5f627e062c..04722d920b 100644 --- a/proxy/src/scram/mod.rs +++ b/proxy/src/scram/mod.rs @@ -6,6 +6,7 @@ //! * //! * +mod cache; mod countmin; mod exchange; mod key; @@ -18,10 +19,8 @@ pub mod threadpool; use base64::Engine as _; use base64::prelude::BASE64_STANDARD; pub(crate) use exchange::{Exchange, exchange}; -use hmac::{Hmac, Mac}; pub(crate) use key::ScramKey; pub(crate) use secret::ServerSecret; -use sha2::{Digest, Sha256}; const SCRAM_SHA_256: &str = "SCRAM-SHA-256"; const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS"; @@ -42,29 +41,13 @@ fn base64_decode_array(input: impl AsRef<[u8]>) -> Option<[u8; N Some(bytes) } -/// This function essentially is `Hmac(sha256, key, input)`. -/// Further reading: . -fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator) -> [u8; 32] { - let mut mac = Hmac::::new_from_slice(key).expect("bad key size"); - parts.into_iter().for_each(|s| mac.update(s)); - - mac.finalize().into_bytes().into() -} - -fn sha256<'a>(parts: impl IntoIterator) -> [u8; 32] { - let mut hasher = Sha256::new(); - parts.into_iter().for_each(|s| hasher.update(s)); - - hasher.finalize().into() -} - #[cfg(test)] mod tests { use super::threadpool::ThreadPool; use super::{Exchange, ServerSecret}; - use crate::intern::EndpointIdInt; + use crate::intern::{EndpointIdInt, RoleNameInt}; use crate::sasl::{Mechanism, Step}; - use crate::types::EndpointId; + use crate::types::{EndpointId, RoleName}; #[test] fn snapshot() { @@ -114,23 +97,34 @@ mod tests { ); } - async fn run_round_trip_test(server_password: &str, client_password: &str) { - let pool = ThreadPool::new(1); - + async fn check( + pool: &ThreadPool, + scram_secret: &ServerSecret, + password: &[u8], + ) -> Result<(), &'static str> { let ep = EndpointId::from("foo"); let ep = EndpointIdInt::from(ep); + let role = RoleName::from("user"); + let role = RoleNameInt::from(&role); - let scram_secret = ServerSecret::build(server_password).await.unwrap(); - let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes()) + let outcome = super::exchange(pool, ep, role, scram_secret, password) .await .unwrap(); match outcome { - crate::sasl::Outcome::Success(_) => {} - crate::sasl::Outcome::Failure(r) => panic!("{r}"), + crate::sasl::Outcome::Success(_) => Ok(()), + crate::sasl::Outcome::Failure(r) => Err(r), } } + async fn run_round_trip_test(server_password: &str, client_password: &str) { + let pool = ThreadPool::new(1); + let scram_secret = ServerSecret::build(server_password).await.unwrap(); + check(&pool, &scram_secret, client_password.as_bytes()) + .await + .unwrap(); + } + #[tokio::test] async fn round_trip() { run_round_trip_test("pencil", "pencil").await; @@ -141,4 +135,27 @@ mod tests { async fn failure() { run_round_trip_test("pencil", "eraser").await; } + + #[tokio::test] + #[tracing_test::traced_test] + async fn password_cache() { + let pool = ThreadPool::new(1); + let scram_secret = ServerSecret::build("password").await.unwrap(); + + // wrong passwords are not added to cache + check(&pool, &scram_secret, b"wrong").await.unwrap_err(); + assert!(!logs_contain("storing cached password")); + + // correct passwords get cached + check(&pool, 
&scram_secret, b"password").await.unwrap(); + assert!(logs_contain("storing cached password")); + + // wrong passwords do not match the cache + check(&pool, &scram_secret, b"wrong").await.unwrap_err(); + assert!(!logs_contain("password validated from cache")); + + // correct passwords match the cache + check(&pool, &scram_secret, b"password").await.unwrap(); + assert!(logs_contain("password validated from cache")); + } } diff --git a/proxy/src/scram/pbkdf2.rs b/proxy/src/scram/pbkdf2.rs index 7f48e00c41..1300310de2 100644 --- a/proxy/src/scram/pbkdf2.rs +++ b/proxy/src/scram/pbkdf2.rs @@ -1,25 +1,50 @@ +//! For postgres password authentication, we need to perform a PBKDF2 using +//! PRF=HMAC-SHA2-256, producing only 1 block (32 bytes) of output key. + +use hmac::Mac as _; use hmac::digest::consts::U32; use hmac::digest::generic_array::GenericArray; -use hmac::{Hmac, Mac}; -use sha2::Sha256; +use zeroize::Zeroize as _; + +use crate::metrics::Metrics; + +/// The Psuedo-random function used during PBKDF2 and the SCRAM-SHA-256 handshake. +pub type Prf = hmac::Hmac; +pub(crate) type Block = GenericArray; pub(crate) struct Pbkdf2 { - hmac: Hmac, - prev: GenericArray, - hi: GenericArray, + hmac: Prf, + /// U{r-1} for whatever iteration r we are currently on. + prev: Block, + /// the output of `fold(xor, U{1}..U{r})` for whatever iteration r we are currently on. + hi: Block, + /// number of iterations left iterations: u32, } +impl Drop for Pbkdf2 { + fn drop(&mut self) { + self.prev.zeroize(); + self.hi.zeroize(); + } +} + // inspired from impl Pbkdf2 { - pub(crate) fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self { + pub(crate) fn start(pw: &[u8], salt: &[u8], iterations: u32) -> Self { // key the HMAC and derive the first block in-place - let mut hmac = - Hmac::::new_from_slice(str).expect("HMAC is able to accept all key sizes"); + let mut hmac = Prf::new_from_slice(pw).expect("HMAC is able to accept all key sizes"); + + // U1 = PRF(Password, Salt + INT_32_BE(i)) + // i = 1 since we only need 1 block of output. hmac.update(salt); hmac.update(&1u32.to_be_bytes()); let init_block = hmac.finalize_reset().into_bytes(); + // Prf::new_from_slice will run 2 sha256 rounds. + // Our update + finalize run 2 sha256 rounds for each pbkdf2 round. + Metrics::get().proxy.sha_rounds.inc_by(4); + Self { hmac, // one iteration spent above @@ -33,7 +58,11 @@ impl Pbkdf2 { (self.iterations).clamp(0, 4096) } - pub(crate) fn turn(&mut self) -> std::task::Poll<[u8; 32]> { + /// For "fairness", we implement PBKDF2 with cooperative yielding, which is why we use this `turn` + /// function that only executes a fixed number of iterations before continuing. + /// + /// Task must be rescheuled if this returns [`std::task::Poll::Pending`]. + pub(crate) fn turn(&mut self) -> std::task::Poll { let Self { hmac, prev, @@ -44,25 +73,37 @@ impl Pbkdf2 { // only do up to 4096 iterations per turn for fairness let n = (*iterations).clamp(0, 4096); for _ in 0..n { - hmac.update(prev); - let block = hmac.finalize_reset().into_bytes(); - - for (hi_byte, &b) in hi.iter_mut().zip(block.iter()) { - *hi_byte ^= b; - } - - *prev = block; + let next = single_round(hmac, prev); + xor_assign(hi, &next); + *prev = next; } + // Our update + finalize run 2 sha256 rounds for each pbkdf2 round. 
+ Metrics::get().proxy.sha_rounds.inc_by(2 * n as u64); + *iterations -= n; if *iterations == 0 { - std::task::Poll::Ready((*hi).into()) + std::task::Poll::Ready(*hi) } else { std::task::Poll::Pending } } } +#[inline(always)] +pub fn xor_assign(x: &mut Block, y: &Block) { + for (x, &y) in std::iter::zip(x, y) { + *x ^= y; + } +} + +#[inline(always)] +fn single_round(prf: &mut Prf, ui: &Block) -> Block { + // Ui = PRF(Password, Ui-1) + prf.update(ui); + prf.finalize_reset().into_bytes() +} + #[cfg(test)] mod tests { use pbkdf2::pbkdf2_hmac_array; @@ -76,11 +117,11 @@ mod tests { let pass = b"Ne0n_!5_50_C007"; let mut job = Pbkdf2::start(pass, salt, 60000); - let hash = loop { + let hash: [u8; 32] = loop { let std::task::Poll::Ready(hash) = job.turn() else { continue; }; - break hash; + break hash.into(); }; let expected = pbkdf2_hmac_array::(pass, salt, 60000); diff --git a/proxy/src/scram/secret.rs b/proxy/src/scram/secret.rs index 0e070c2f27..a3a64f271c 100644 --- a/proxy/src/scram/secret.rs +++ b/proxy/src/scram/secret.rs @@ -3,6 +3,7 @@ use base64::Engine as _; use base64::prelude::BASE64_STANDARD; use subtle::{Choice, ConstantTimeEq}; +use tokio::time::Instant; use super::base64_decode_array; use super::key::ScramKey; @@ -11,6 +12,9 @@ use super::key::ScramKey; /// and is used throughout the authentication process. #[derive(Clone, Eq, PartialEq, Debug)] pub(crate) struct ServerSecret { + /// When this secret was cached. + pub(crate) cached_at: Instant, + /// Number of iterations for `PBKDF2` function. pub(crate) iterations: u32, /// Salt used to hash user's password. @@ -34,6 +38,7 @@ impl ServerSecret { params.split_once(':').zip(keys.split_once(':'))?; let secret = ServerSecret { + cached_at: Instant::now(), iterations: iterations.parse().ok()?, salt_base64: salt.into(), stored_key: base64_decode_array(stored_key)?.into(), @@ -54,6 +59,7 @@ impl ServerSecret { /// See `auth-scram.c : mock_scram_secret` for details. pub(crate) fn mock(nonce: [u8; 32]) -> Self { Self { + cached_at: Instant::now(), // this doesn't reveal much information as we're going to use // iteration count 1 for our generated passwords going forward. // PG16 users can set iteration count=1 already today. diff --git a/proxy/src/scram/signature.rs b/proxy/src/scram/signature.rs index a5b1c3e9f4..8e074272b6 100644 --- a/proxy/src/scram/signature.rs +++ b/proxy/src/scram/signature.rs @@ -1,6 +1,10 @@ //! Tools for client/server signature management. +use hmac::Mac as _; + use super::key::{SCRAM_KEY_LEN, ScramKey}; +use crate::metrics::Metrics; +use crate::scram::pbkdf2::Prf; /// A collection of message parts needed to derive the client's signature. #[derive(Debug)] @@ -12,15 +16,18 @@ pub(crate) struct SignatureBuilder<'a> { impl SignatureBuilder<'_> { pub(crate) fn build(&self, key: &ScramKey) -> Signature { - let parts = [ - self.client_first_message_bare.as_bytes(), - b",", - self.server_first_message.as_bytes(), - b",", - self.client_final_message_without_proof.as_bytes(), - ]; + // don't know exactly. 
this is a rough approx + Metrics::get().proxy.sha_rounds.inc_by(8); - super::hmac_sha256(key.as_ref(), parts).into() + let mut mac = Prf::new_from_slice(key.as_ref()).expect("HMAC accepts all key sizes"); + mac.update(self.client_first_message_bare.as_bytes()); + mac.update(b","); + mac.update(self.server_first_message.as_bytes()); + mac.update(b","); + mac.update(self.client_final_message_without_proof.as_bytes()); + Signature { + bytes: mac.finalize().into_bytes().into(), + } } } diff --git a/proxy/src/scram/threadpool.rs b/proxy/src/scram/threadpool.rs index ea2e29ede9..20a1df2b53 100644 --- a/proxy/src/scram/threadpool.rs +++ b/proxy/src/scram/threadpool.rs @@ -15,6 +15,8 @@ use futures::FutureExt; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; +use super::cache::Pbkdf2Cache; +use super::pbkdf2; use super::pbkdf2::Pbkdf2; use crate::intern::EndpointIdInt; use crate::metrics::{ThreadPoolMetrics, ThreadPoolWorkerId}; @@ -23,6 +25,10 @@ use crate::scram::countmin::CountMinSketch; pub struct ThreadPool { runtime: Option, pub metrics: Arc, + + // we hash a lot of passwords. + // we keep a cache of partial hashes for faster validation. + pub(super) cache: Pbkdf2Cache, } /// How often to reset the sketch values @@ -68,6 +74,7 @@ impl ThreadPool { Self { runtime: Some(runtime), metrics: Arc::new(ThreadPoolMetrics::new(n_workers as usize)), + cache: Pbkdf2Cache::new(), } }) } @@ -130,7 +137,7 @@ struct JobSpec { } impl Future for JobSpec { - type Output = [u8; 32]; + type Output = pbkdf2::Block; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { STATE.with_borrow_mut(|state| { @@ -166,10 +173,10 @@ impl Future for JobSpec { } } -pub(crate) struct JobHandle(tokio::task::JoinHandle<[u8; 32]>); +pub(crate) struct JobHandle(tokio::task::JoinHandle); impl Future for JobHandle { - type Output = [u8; 32]; + type Output = pbkdf2::Block; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { match self.0.poll_unpin(cx) { @@ -203,10 +210,10 @@ mod tests { .spawn_job(ep, Pbkdf2::start(b"password", &salt, 4096)) .await; - let expected = [ + let expected = &[ 10, 114, 73, 188, 140, 222, 196, 156, 214, 184, 79, 157, 119, 242, 16, 31, 53, 242, 178, 43, 95, 8, 225, 182, 122, 40, 219, 21, 89, 147, 64, 140, ]; - assert_eq!(actual, expected); + assert_eq!(actual.as_slice(), expected); } } diff --git a/proxy/src/serverless/backend.rs b/proxy/src/serverless/backend.rs index 59e4b09bc9..eb879f98e7 100644 --- a/proxy/src/serverless/backend.rs +++ b/proxy/src/serverless/backend.rs @@ -1,46 +1,40 @@ -use std::io; -use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use std::time::Duration; -use async_trait::async_trait; use ed25519_dalek::SigningKey; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use jose_jwk::jose_b64; -use postgres_client::config::SslMode; +use postgres_client::maybe_tls_stream::MaybeTlsStream; use rand_core::OsRng; -use rustls::pki_types::{DnsName, ServerName}; -use tokio::net::{TcpStream, lookup_host}; -use tokio_rustls::TlsConnector; use tracing::field::display; use tracing::{debug, info}; use super::AsyncRW; use super::conn_pool::poll_client; use super::conn_pool_lib::{Client, ConnInfo, EndpointConnPool, GlobalConnPool}; -use super::http_conn_pool::{self, HttpConnPool, Send, poll_http2_client}; +use super::http_conn_pool::{self, HttpConnPool, LocalProxyClient, poll_http2_client}; use super::local_conn_pool::{self, EXT_NAME, EXT_SCHEMA, EXT_VERSION, LocalConnPool}; use crate::auth::backend::local::StaticAuthRules; -use 
crate::auth::backend::{ComputeCredentialKeys, ComputeCredentials, ComputeUserInfo}; +use crate::auth::backend::{ComputeCredentials, ComputeUserInfo}; use crate::auth::{self, AuthError}; +use crate::compute; use crate::compute_ctl::{ ComputeCtlError, ExtensionInstallRequest, Privilege, SetRoleGrantsRequest, }; -use crate::config::{ComputeConfig, ProxyConfig}; +use crate::config::ProxyConfig; use crate::context::RequestContext; -use crate::control_plane::CachedNodeInfo; use crate::control_plane::client::ApiLockError; use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError}; -use crate::control_plane::locks::ApiLocks; use crate::error::{ErrorKind, ReportableError, UserFacingError}; -use crate::intern::EndpointIdInt; -use crate::proxy::connect_compute::ConnectMechanism; -use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute}; +use crate::intern::{EndpointIdInt, RoleNameInt}; +use crate::pqproto::StartupMessageParams; +use crate::proxy::{connect_auth, connect_compute}; use crate::rate_limiter::EndpointRateLimiter; -use crate::types::{EndpointId, Host, LOCAL_PROXY_SUFFIX}; +use crate::types::{EndpointId, LOCAL_PROXY_SUFFIX}; pub(crate) struct PoolingBackend { - pub(crate) http_conn_pool: Arc>>, + pub(crate) http_conn_pool: + Arc>>, pub(crate) local_pool: Arc>, pub(crate) pool: Arc>>, @@ -82,9 +76,11 @@ impl PoolingBackend { }; let ep = EndpointIdInt::from(&user_info.endpoint); + let role = RoleNameInt::from(&user_info.user); let auth_outcome = crate::auth::validate_password_and_exchange( - &self.config.authentication_config.thread_pool, + &self.config.authentication_config.scram_thread_pool, ep, + role, password, secret, ) @@ -185,20 +181,42 @@ impl PoolingBackend { tracing::Span::current().record("conn_id", display(conn_id)); info!(%conn_id, "pool: opening a new connection '{conn_info}'"); let backend = self.auth_backend.as_ref().map(|()| keys.info); - crate::proxy::connect_compute::connect_to_compute( + + let mut params = StartupMessageParams::default(); + params.insert("database", &conn_info.dbname); + params.insert("user", &conn_info.user_info.user); + + let mut auth_info = compute::AuthInfo::with_auth_keys(keys.keys); + auth_info.set_startup_params(¶ms, true); + + let node = connect_auth::connect_to_compute_and_auth( ctx, - &TokioMechanism { - conn_id, - conn_info, - pool: self.pool.clone(), - locks: &self.config.connect_compute_locks, - keys: keys.keys, - }, + self.config, &backend, - self.config.wake_compute_retry_config, - &self.config.connect_to_compute, + auth_info, + connect_compute::TlsNegotiation::Postgres, ) - .await + .await?; + + let (client, connection) = postgres_client::connect::managed( + node.stream, + Some(node.socket_addr.ip()), + postgres_client::config::Host::Tcp(node.hostname.to_string()), + node.socket_addr.port(), + node.ssl_mode, + Some(self.config.connect_to_compute.timeout), + ) + .await?; + + Ok(poll_client( + self.pool.clone(), + ctx, + conn_info, + client, + connection, + conn_id, + node.aux, + )) } // Wake up the destination if needed @@ -210,7 +228,7 @@ impl PoolingBackend { &self, ctx: &RequestContext, conn_info: ConnInfo, - ) -> Result, HttpConnError> { + ) -> Result, HttpConnError> { debug!("pool: looking for an existing connection"); if let Ok(Some(client)) = self.http_conn_pool.get(ctx, &conn_info) { return Ok(client); @@ -227,19 +245,38 @@ impl PoolingBackend { )), options: conn_info.user_info.options.clone(), }); - crate::proxy::connect_compute::connect_to_compute( + + let node = connect_compute::connect_to_compute( ctx, - 
&HyperMechanism { - conn_id, - conn_info, - pool: self.http_conn_pool.clone(), - locks: &self.config.connect_compute_locks, - }, + self.config, &backend, - self.config.wake_compute_retry_config, - &self.config.connect_to_compute, + connect_compute::TlsNegotiation::Direct, ) - .await + .await?; + + let stream = match node.stream.into_framed().into_inner() { + MaybeTlsStream::Raw(s) => Box::pin(s) as AsyncRW, + MaybeTlsStream::Tls(s) => Box::pin(s) as AsyncRW, + }; + + let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()) + .timer(TokioTimer::new()) + .keep_alive_interval(Duration::from_secs(20)) + .keep_alive_while_idle(true) + .keep_alive_timeout(Duration::from_secs(5)) + .handshake(TokioIo::new(stream)) + .await + .map_err(LocalProxyConnError::H2)?; + + Ok(poll_http2_client( + self.http_conn_pool.clone(), + ctx, + &conn_info, + client, + connection, + conn_id, + node.aux.clone(), + )) } /// Connect to postgres over localhost. @@ -379,6 +416,8 @@ fn create_random_jwk() -> (SigningKey, jose_jwk::Key) { pub(crate) enum HttpConnError { #[error("pooled connection closed at inconsistent state")] ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError), + #[error("could not connect to compute")] + ConnectError(#[from] compute::ConnectionError), #[error("could not connect to postgres in compute")] PostgresConnectionError(#[from] postgres_client::Error), #[error("could not connect to local-proxy in compute")] @@ -398,10 +437,19 @@ pub(crate) enum HttpConnError { TooManyConnectionAttempts(#[from] ApiLockError), } +impl From for HttpConnError { + fn from(value: connect_auth::AuthError) -> Self { + match value { + connect_auth::AuthError::Auth(compute::PostgresError::Postgres(error)) => { + Self::PostgresConnectionError(error) + } + connect_auth::AuthError::Connect(error) => Self::ConnectError(error), + } + } +} + #[derive(Debug, thiserror::Error)] pub(crate) enum LocalProxyConnError { - #[error("error with connection to local-proxy")] - Io(#[source] std::io::Error), #[error("could not establish h2 connection")] H2(#[from] hyper::Error), } @@ -409,6 +457,7 @@ pub(crate) enum LocalProxyConnError { impl ReportableError for HttpConnError { fn get_error_kind(&self) -> ErrorKind { match self { + HttpConnError::ConnectError(_) => ErrorKind::Compute, HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute, HttpConnError::PostgresConnectionError(p) => { if p.as_db_error().is_some() { @@ -433,6 +482,7 @@ impl ReportableError for HttpConnError { impl UserFacingError for HttpConnError { fn to_string_client(&self) -> String { match self { + HttpConnError::ConnectError(p) => p.to_string_client(), HttpConnError::ConnectionClosedAbruptly(_) => self.to_string(), HttpConnError::PostgresConnectionError(p) => p.to_string(), HttpConnError::LocalProxyConnectionError(p) => p.to_string(), @@ -448,36 +498,9 @@ impl UserFacingError for HttpConnError { } } -impl CouldRetry for HttpConnError { - fn could_retry(&self) -> bool { - match self { - HttpConnError::PostgresConnectionError(e) => e.could_retry(), - HttpConnError::LocalProxyConnectionError(e) => e.could_retry(), - HttpConnError::ComputeCtl(_) => false, - HttpConnError::ConnectionClosedAbruptly(_) => false, - HttpConnError::JwtPayloadError(_) => false, - HttpConnError::GetAuthInfo(_) => false, - HttpConnError::AuthError(_) => false, - HttpConnError::WakeCompute(_) => false, - HttpConnError::TooManyConnectionAttempts(_) => false, - } - } -} -impl ShouldRetryWakeCompute for HttpConnError { - fn 
should_retry_wake_compute(&self) -> bool { - match self { - HttpConnError::PostgresConnectionError(e) => e.should_retry_wake_compute(), - // we never checked cache validity - HttpConnError::TooManyConnectionAttempts(_) => false, - _ => true, - } - } -} - impl ReportableError for LocalProxyConnError { fn get_error_kind(&self) -> ErrorKind { match self { - LocalProxyConnError::Io(_) => ErrorKind::Compute, LocalProxyConnError::H2(_) => ErrorKind::Compute, } } @@ -488,209 +511,3 @@ impl UserFacingError for LocalProxyConnError { "Could not establish HTTP connection to the database".to_string() } } - -impl CouldRetry for LocalProxyConnError { - fn could_retry(&self) -> bool { - match self { - LocalProxyConnError::Io(_) => false, - LocalProxyConnError::H2(_) => false, - } - } -} -impl ShouldRetryWakeCompute for LocalProxyConnError { - fn should_retry_wake_compute(&self) -> bool { - match self { - LocalProxyConnError::Io(_) => false, - LocalProxyConnError::H2(_) => false, - } - } -} - -struct TokioMechanism { - pool: Arc>>, - conn_info: ConnInfo, - conn_id: uuid::Uuid, - keys: ComputeCredentialKeys, - - /// connect_to_compute concurrency lock - locks: &'static ApiLocks, -} - -#[async_trait] -impl ConnectMechanism for TokioMechanism { - type Connection = Client; - type ConnectError = HttpConnError; - type Error = HttpConnError; - - async fn connect_once( - &self, - ctx: &RequestContext, - node_info: &CachedNodeInfo, - compute_config: &ComputeConfig, - ) -> Result { - let permit = self.locks.get_permit(&node_info.conn_info.host).await?; - - let mut config = node_info.conn_info.to_postgres_client_config(); - let config = config - .user(&self.conn_info.user_info.user) - .dbname(&self.conn_info.dbname) - .connect_timeout(compute_config.timeout); - - if let ComputeCredentialKeys::AuthKeys(auth_keys) = self.keys { - config.auth_keys(auth_keys); - } - - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - let res = config.connect(compute_config).await; - drop(pause); - let (client, connection) = permit.release_result(res)?; - - tracing::Span::current().record("pid", tracing::field::display(client.get_process_id())); - tracing::Span::current().record( - "compute_id", - tracing::field::display(&node_info.aux.compute_id), - ); - - if let Some(query_id) = ctx.get_testodrome_id() { - info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id); - } - - Ok(poll_client( - self.pool.clone(), - ctx, - self.conn_info.clone(), - client, - connection, - self.conn_id, - node_info.aux.clone(), - )) - } -} - -struct HyperMechanism { - pool: Arc>>, - conn_info: ConnInfo, - conn_id: uuid::Uuid, - - /// connect_to_compute concurrency lock - locks: &'static ApiLocks, -} - -#[async_trait] -impl ConnectMechanism for HyperMechanism { - type Connection = http_conn_pool::Client; - type ConnectError = HttpConnError; - type Error = HttpConnError; - - async fn connect_once( - &self, - ctx: &RequestContext, - node_info: &CachedNodeInfo, - config: &ComputeConfig, - ) -> Result { - let host_addr = node_info.conn_info.host_addr; - let host = &node_info.conn_info.host; - let permit = self.locks.get_permit(host).await?; - - let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute); - - let tls = if node_info.conn_info.ssl_mode == SslMode::Disable { - None - } else { - Some(&config.tls) - }; - - let port = node_info.conn_info.port; - let res = connect_http2(host_addr, host, port, config.timeout, tls).await; - drop(pause); - let (client, connection) = permit.release_result(res)?; - - 
tracing::Span::current().record( - "compute_id", - tracing::field::display(&node_info.aux.compute_id), - ); - - if let Some(query_id) = ctx.get_testodrome_id() { - info!("latency={}, query_id={}", ctx.get_proxy_latency(), query_id); - } - - Ok(poll_http2_client( - self.pool.clone(), - ctx, - &self.conn_info, - client, - connection, - self.conn_id, - node_info.aux.clone(), - )) - } -} - -async fn connect_http2( - host_addr: Option, - host: &str, - port: u16, - timeout: Duration, - tls: Option<&Arc>, -) -> Result<(http_conn_pool::Send, http_conn_pool::Connect), LocalProxyConnError> { - let addrs = match host_addr { - Some(addr) => vec![SocketAddr::new(addr, port)], - None => lookup_host((host, port)) - .await - .map_err(LocalProxyConnError::Io)? - .collect(), - }; - let mut last_err = None; - - let mut addrs = addrs.into_iter(); - let stream = loop { - let Some(addr) = addrs.next() else { - return Err(last_err.unwrap_or_else(|| { - LocalProxyConnError::Io(io::Error::new( - io::ErrorKind::InvalidInput, - "could not resolve any addresses", - )) - })); - }; - - match tokio::time::timeout(timeout, TcpStream::connect(addr)).await { - Ok(Ok(stream)) => { - stream.set_nodelay(true).map_err(LocalProxyConnError::Io)?; - break stream; - } - Ok(Err(e)) => { - last_err = Some(LocalProxyConnError::Io(e)); - } - Err(e) => { - last_err = Some(LocalProxyConnError::Io(io::Error::new( - io::ErrorKind::TimedOut, - e, - ))); - } - } - }; - - let stream = if let Some(tls) = tls { - let host = DnsName::try_from(host) - .map_err(io::Error::other) - .map_err(LocalProxyConnError::Io)? - .to_owned(); - let stream = TlsConnector::from(tls.clone()) - .connect(ServerName::DnsName(host), stream) - .await - .map_err(LocalProxyConnError::Io)?; - Box::pin(stream) as AsyncRW - } else { - Box::pin(stream) as AsyncRW - }; - - let (client, connection) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()) - .timer(TokioTimer::new()) - .keep_alive_interval(Duration::from_secs(20)) - .keep_alive_while_idle(true) - .keep_alive_timeout(Duration::from_secs(5)) - .handshake(TokioIo::new(stream)) - .await?; - - Ok((client, connection)) -} diff --git a/proxy/src/serverless/conn_pool.rs b/proxy/src/serverless/conn_pool.rs index 015c46f787..17305e30f1 100644 --- a/proxy/src/serverless/conn_pool.rs +++ b/proxy/src/serverless/conn_pool.rs @@ -190,6 +190,9 @@ mod tests { fn get_process_id(&self) -> i32 { 0 } + fn reset(&mut self) -> Result<(), postgres_client::Error> { + Ok(()) + } } fn create_inner() -> ClientInnerCommon { diff --git a/proxy/src/serverless/conn_pool_lib.rs b/proxy/src/serverless/conn_pool_lib.rs index ed5cc0ea03..6adca49723 100644 --- a/proxy/src/serverless/conn_pool_lib.rs +++ b/proxy/src/serverless/conn_pool_lib.rs @@ -7,10 +7,9 @@ use std::time::Duration; use clashmap::ClashMap; use parking_lot::RwLock; -use postgres_client::ReadyForQueryStatus; use rand::Rng; use smol_str::ToSmolStr; -use tracing::{Span, debug, info}; +use tracing::{Span, debug, info, warn}; use super::backend::HttpConnError; use super::conn_pool::ClientDataRemote; @@ -188,7 +187,7 @@ impl EndpointConnPool { self.pools.get_mut(&db_user) } - pub(crate) fn put(pool: &RwLock, conn_info: &ConnInfo, client: ClientInnerCommon) { + pub(crate) fn put(pool: &RwLock, conn_info: &ConnInfo, mut client: ClientInnerCommon) { let conn_id = client.get_conn_id(); let (max_conn, conn_count, pool_name) = { let pool = pool.read(); @@ -201,12 +200,17 @@ impl EndpointConnPool { }; if client.inner.is_closed() { - info!(%conn_id, "{}: throwing away connection 
'{conn_info}' because connection is closed", pool_name); + info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because connection is closed"); + return; + } + + if let Err(error) = client.inner.reset() { + warn!(?error, %conn_id, "{pool_name}: throwing away connection '{conn_info}' because connection could not be reset"); return; } if conn_count >= max_conn { - info!(%conn_id, "{}: throwing away connection '{conn_info}' because pool is full", pool_name); + info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because pool is full"); return; } @@ -691,6 +695,7 @@ impl Deref for Client { pub(crate) trait ClientInnerExt: Sync + Send + 'static { fn is_closed(&self) -> bool; fn get_process_id(&self) -> i32; + fn reset(&mut self) -> Result<(), postgres_client::Error>; } impl ClientInnerExt for postgres_client::Client { @@ -701,15 +706,13 @@ impl ClientInnerExt for postgres_client::Client { fn get_process_id(&self) -> i32 { self.get_process_id() } + + fn reset(&mut self) -> Result<(), postgres_client::Error> { + self.reset_session_background() + } } impl Discard<'_, C> { - pub(crate) fn check_idle(&mut self, status: ReadyForQueryStatus) { - let conn_info = &self.conn_info; - if status != ReadyForQueryStatus::Idle && std::mem::take(self.pool).strong_count() > 0 { - info!("pool: throwing away connection '{conn_info}' because connection is not idle"); - } - } pub(crate) fn discard(&mut self) { let conn_info = &self.conn_info; if std::mem::take(self.pool).strong_count() > 0 { diff --git a/proxy/src/serverless/http_conn_pool.rs b/proxy/src/serverless/http_conn_pool.rs index 7acd816026..bf6b934d20 100644 --- a/proxy/src/serverless/http_conn_pool.rs +++ b/proxy/src/serverless/http_conn_pool.rs @@ -23,8 +23,8 @@ use crate::protocol2::ConnectionInfoExtra; use crate::types::EndpointCacheKey; use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS}; -pub(crate) type Send = http2::SendRequest>; -pub(crate) type Connect = +pub(crate) type LocalProxyClient = http2::SendRequest>; +pub(crate) type LocalProxyConnection = http2::Connection, BoxBody, TokioExecutor>; #[derive(Clone)] @@ -189,14 +189,14 @@ impl GlobalConnPool> { } pub(crate) fn poll_http2_client( - global_pool: Arc>>, + global_pool: Arc>>, ctx: &RequestContext, conn_info: &ConnInfo, - client: Send, - connection: Connect, + client: LocalProxyClient, + connection: LocalProxyConnection, conn_id: uuid::Uuid, aux: MetricsAuxInfo, -) -> Client { +) -> Client { let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol()); let session_id = ctx.session_id(); @@ -285,7 +285,7 @@ impl Client { } } -impl ClientInnerExt for Send { +impl ClientInnerExt for LocalProxyClient { fn is_closed(&self) -> bool { self.is_closed() } @@ -294,4 +294,10 @@ impl ClientInnerExt for Send { // ideally throw something meaningful -1 } + + fn reset(&mut self) -> Result<(), postgres_client::Error> { + // We use HTTP/2.0 to talk to local proxy. HTTP is stateless, + // so there's nothing to reset. 
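+        // Returning Ok(()) keeps the connection eligible for reuse; a reset error would
+        // cause the pool to discard it instead.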
+ Ok(()) + } } diff --git a/proxy/src/serverless/local_conn_pool.rs b/proxy/src/serverless/local_conn_pool.rs index f63d84d66b..b8a502c37e 100644 --- a/proxy/src/serverless/local_conn_pool.rs +++ b/proxy/src/serverless/local_conn_pool.rs @@ -269,11 +269,6 @@ impl ClientInnerCommon { local_data.jti += 1; let token = resign_jwt(&local_data.key, payload, local_data.jti)?; - self.inner - .discard_all() - .await - .map_err(SqlOverHttpError::InternalPostgres)?; - // initiates the auth session // this is safe from query injections as the jwt format free of any escape characters. let query = format!("select auth.jwt_session_init('{token}')"); diff --git a/proxy/src/serverless/rest.rs b/proxy/src/serverless/rest.rs index 173c2629f7..0c3d2c958d 100644 --- a/proxy/src/serverless/rest.rs +++ b/proxy/src/serverless/rest.rs @@ -1,5 +1,6 @@ use std::borrow::Cow; use std::collections::HashMap; +use std::convert::Infallible; use std::sync::Arc; use bytes::Bytes; @@ -12,6 +13,7 @@ use hyper::body::Incoming; use hyper::http::{HeaderName, HeaderValue}; use hyper::{Request, Response, StatusCode}; use indexmap::IndexMap; +use moka::sync::Cache; use ouroboros::self_referencing; use serde::de::DeserializeOwned; use serde::{Deserialize, Deserializer}; @@ -46,19 +48,19 @@ use super::backend::{HttpConnError, LocalProxyConnError, PoolingBackend}; use super::conn_pool::AuthData; use super::conn_pool_lib::ConnInfo; use super::error::{ConnInfoError, Credentials, HttpCodeError, ReadPayloadError}; -use super::http_conn_pool::{self, Send}; +use super::http_conn_pool::{self, LocalProxyClient}; use super::http_util::{ ALLOW_POOL, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_ISOLATION_LEVEL, TXN_READ_ONLY, get_conn_info, json_response, uuid_to_header_value, }; use super::json::JsonConversionError; use crate::auth::backend::ComputeCredentialKeys; -use crate::cache::{Cached, TimedLru}; +use crate::cache::common::{count_cache_insert, count_cache_outcome, eviction_listener}; use crate::config::ProxyConfig; use crate::context::RequestContext; use crate::error::{ErrorKind, ReportableError, UserFacingError}; use crate::http::read_body_with_limit; -use crate::metrics::Metrics; +use crate::metrics::{CacheKind, Metrics}; use crate::serverless::sql_over_http::HEADER_VALUE_TRUE; use crate::types::EndpointCacheKey; use crate::util::deserialize_json_string; @@ -138,19 +140,43 @@ pub struct ApiConfig { } // The DbSchemaCache is a cache of the ApiConfig and DbSchemaOwned for each endpoint -pub(crate) type DbSchemaCache = TimedLru>; +pub(crate) struct DbSchemaCache(Cache>); impl DbSchemaCache { + pub fn new(config: crate::config::CacheOptions) -> Self { + let builder = Cache::builder().name("schema"); + let builder = config.moka(builder); + + let metrics = &Metrics::get().cache; + if let Some(size) = config.size { + metrics.capacity.set(CacheKind::Schema, size as i64); + } + + let builder = + builder.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::Schema, cause)); + + Self(builder.build()) + } + + pub async fn maintain(&self) -> Result { + let mut ticker = tokio::time::interval(std::time::Duration::from_secs(60)); + loop { + ticker.tick().await; + self.0.run_pending_tasks(); + } + } + pub async fn get_cached_or_remote( &self, endpoint_id: &EndpointCacheKey, auth_header: &HeaderValue, connection_string: &str, - client: &mut http_conn_pool::Client, + client: &mut http_conn_pool::Client, ctx: &RequestContext, config: &'static ProxyConfig, ) -> Result, RestError> { - match self.get_with_created_at(endpoint_id) { - Some(Cached 
{ value: (v, _), .. }) => Ok(v), + let cache_result = count_cache_outcome(CacheKind::Schema, self.0.get(endpoint_id)); + match cache_result { + Some(v) => Ok(v), None => { info!("db_schema cache miss for endpoint: {:?}", endpoint_id); let remote_value = self @@ -173,7 +199,8 @@ impl DbSchemaCache { db_extra_search_path: None, }; let value = Arc::new((api_config, schema_owned)); - self.insert(endpoint_id.clone(), value); + count_cache_insert(CacheKind::Schema); + self.0.insert(endpoint_id.clone(), value); return Err(e); } Err(e) => { @@ -181,7 +208,8 @@ impl DbSchemaCache { } }; let value = Arc::new((api_config, schema_owned)); - self.insert(endpoint_id.clone(), value.clone()); + count_cache_insert(CacheKind::Schema); + self.0.insert(endpoint_id.clone(), value.clone()); Ok(value) } } @@ -190,7 +218,7 @@ impl DbSchemaCache { &self, auth_header: &HeaderValue, connection_string: &str, - client: &mut http_conn_pool::Client, + client: &mut http_conn_pool::Client, ctx: &RequestContext, config: &'static ProxyConfig, ) -> Result<(ApiConfig, DbSchemaOwned), RestError> { @@ -430,7 +458,7 @@ struct BatchQueryData<'a> { } async fn make_local_proxy_request( - client: &mut http_conn_pool::Client, + client: &mut http_conn_pool::Client, headers: impl IntoIterator, body: QueryData<'_>, max_len: usize, @@ -461,7 +489,7 @@ async fn make_local_proxy_request( } async fn make_raw_local_proxy_request( - client: &mut http_conn_pool::Client, + client: &mut http_conn_pool::Client, headers: impl IntoIterator, body: String, ) -> Result, RestError> { diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index f254b41b5b..26f65379e7 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -735,9 +735,7 @@ impl QueryData { match batch_result { // The query successfully completed. - Ok(status) => { - discard.check_idle(status); - + Ok(_) => { let json_output = String::from_utf8(json_buf).expect("json should be valid utf8"); Ok(json_output) } @@ -793,7 +791,7 @@ impl BatchQueryData { { Ok(json_output) => { info!("commit"); - let status = transaction + transaction .commit() .await .inspect_err(|_| { @@ -802,7 +800,6 @@ impl BatchQueryData { discard.discard(); }) .map_err(SqlOverHttpError::Postgres)?; - discard.check_idle(status); json_output } Err(SqlOverHttpError::Cancelled(_)) => { @@ -815,17 +812,6 @@ impl BatchQueryData { return Err(SqlOverHttpError::Cancelled(SqlOverHttpCancel::Postgres)); } Err(err) => { - info!("rollback"); - let status = transaction - .rollback() - .await - .inspect_err(|_| { - // if we cannot rollback - for now don't return connection to pool - // TODO: get a query status from the error - discard.discard(); - }) - .map_err(SqlOverHttpError::Postgres)?; - discard.check_idle(status); return Err(err); } }; @@ -1012,12 +998,6 @@ impl Client { } impl Discard<'_> { - fn check_idle(&mut self, status: ReadyForQueryStatus) { - match self { - Discard::Remote(discard) => discard.check_idle(status), - Discard::Local(discard) => discard.check_idle(status), - } - } fn discard(&mut self) { match self { Discard::Remote(discard) => discard.discard(), diff --git a/proxy/src/signals.rs b/proxy/src/signals.rs index 32b2344a1c..63ef7d1061 100644 --- a/proxy/src/signals.rs +++ b/proxy/src/signals.rs @@ -4,6 +4,8 @@ use anyhow::bail; use tokio_util::sync::CancellationToken; use tracing::{info, warn}; +use crate::metrics::{Metrics, ServiceInfo}; + /// Handle unix signals appropriately. 
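+/// Both signals flip the service-info metric to `terminating`: SIGINT makes the
+/// process exit immediately, while SIGTERM cancels the token so that existing
+/// connections can drain before shutdown.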
pub async fn handle( token: CancellationToken, @@ -28,10 +30,12 @@ where // Shut down the whole application. _ = interrupt.recv() => { warn!("received SIGINT, exiting immediately"); + Metrics::get().service.info.set_label(ServiceInfo::terminating()); bail!("interrupted"); } _ = terminate.recv() => { warn!("received SIGTERM, shutting down once all existing connections have closed"); + Metrics::get().service.info.set_label(ServiceInfo::terminating()); token.cancel(); } } diff --git a/proxy/src/stream.rs b/proxy/src/stream.rs index d6a43df188..9447b9623b 100644 --- a/proxy/src/stream.rs +++ b/proxy/src/stream.rs @@ -102,7 +102,7 @@ pub struct ReportedError { } impl ReportedError { - pub fn new(e: (impl UserFacingError + Into)) -> Self { + pub fn new(e: impl UserFacingError + Into) -> Self { let error_kind = e.get_error_kind(); Self { source: e.into(), diff --git a/storage_controller/src/reconciler.rs b/storage_controller/src/reconciler.rs index d1590ec75e..ff5a3831cd 100644 --- a/storage_controller/src/reconciler.rs +++ b/storage_controller/src/reconciler.rs @@ -981,6 +981,7 @@ impl Reconciler { )); } + let mut first_err = None; for (node, conf) in changes { if self.cancel.is_cancelled() { return Err(ReconcileError::Cancel); @@ -990,7 +991,12 @@ impl Reconciler { // shard _available_ (the attached location), and configuring secondary locations // can be done lazily when the node becomes available (via background reconciliation). if node.is_available() { - self.location_config(&node, conf, None, false).await?; + let res = self.location_config(&node, conf, None, false).await; + if let Err(err) = res { + if first_err.is_none() { + first_err = Some(err); + } + } } else { // If the node is unavailable, we skip and consider the reconciliation successful: this // is a common case where a pageserver is marked unavailable: we demote a location on @@ -1002,6 +1008,10 @@ impl Reconciler { } } + if let Some(err) = first_err { + return Err(err); + } + // The condition below identifies a detach. We must have no attached intent and // must have been attached to something previously. Pass this information to // the [`ComputeHook`] such that it can update its tenant-wide state. diff --git a/storage_controller/src/service.rs b/storage_controller/src/service.rs index bb2294831e..33c2705316 100644 --- a/storage_controller/src/service.rs +++ b/storage_controller/src/service.rs @@ -1536,10 +1536,19 @@ impl Service { // so that waiters will see the correct error after waiting. tenant.set_last_error(result.sequence, e); - // Skip deletions on reconcile failures - let upsert_deltas = - deltas.filter(|delta| matches!(delta, ObservedStateDelta::Upsert(_))); - tenant.apply_observed_deltas(upsert_deltas); + // If the reconciliation failed, don't clear the observed state for places where we + // detached. Instead, mark the observed state as uncertain. 
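+            // (An upsert with `conf: None` records the location as present but in an
+            // unknown configuration, rather than forgetting about it entirely.)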
+ let failed_reconcile_deltas = deltas.map(|delta| { + if let ObservedStateDelta::Delete(node_id) = delta { + ObservedStateDelta::Upsert(Box::new(( + node_id, + ObservedStateLocation { conf: None }, + ))) + } else { + delta + } + }); + tenant.apply_observed_deltas(failed_reconcile_deltas); } } diff --git a/storage_controller/src/tenant_shard.rs b/storage_controller/src/tenant_shard.rs index f60378470e..3eb54d714d 100644 --- a/storage_controller/src/tenant_shard.rs +++ b/storage_controller/src/tenant_shard.rs @@ -249,6 +249,10 @@ impl IntentState { } pub(crate) fn push_secondary(&mut self, scheduler: &mut Scheduler, new_secondary: NodeId) { + // Every assertion here should probably have a corresponding check in + // `validate_optimization` unless it is an invariant that should never be violated. Note + // that the lock is not held between planning optimizations and applying them so you have to + // assume any valid state transition of the intent state may have occurred assert!(!self.secondary.contains(&new_secondary)); assert!(self.attached != Some(new_secondary)); scheduler.update_node_ref_counts( @@ -1335,8 +1339,9 @@ impl TenantShard { true } - /// Check that the desired modifications to the intent state are compatible with - /// the current intent state + /// Check that the desired modifications to the intent state are compatible with the current + /// intent state. Note that the lock is not held between planning optimizations and applying + /// them so any valid state transition of the intent state may have occurred. fn validate_optimization(&self, optimization: &ScheduleOptimization) -> bool { match optimization.action { ScheduleOptimizationAction::MigrateAttachment(MigrateAttachment { @@ -1352,6 +1357,9 @@ impl TenantShard { }) => { // It's legal to remove a secondary that is not present in the intent state !self.intent.secondary.contains(&new_node_id) + // Ensure the secondary hasn't already been promoted to attached by a concurrent + // optimization/migration. 
+ && self.intent.attached != Some(new_node_id) } ScheduleOptimizationAction::CreateSecondary(new_node_id) => { !self.intent.secondary.contains(&new_node_id) diff --git a/test_runner/fixtures/compare_fixtures.py b/test_runner/fixtures/compare_fixtures.py index 150046b99a..3d248efc04 100644 --- a/test_runner/fixtures/compare_fixtures.py +++ b/test_runner/fixtures/compare_fixtures.py @@ -16,6 +16,7 @@ from typing_extensions import override from fixtures.benchmark_fixture import MetricReport, NeonBenchmarker from fixtures.log_helper import log from fixtures.neon_fixtures import ( + Endpoint, NeonEnv, PgBin, PgProtocol, @@ -129,6 +130,10 @@ class NeonCompare(PgCompare): # Start pg self._pg = self.env.endpoints.create_start("main", "main", self.tenant) + @property + def endpoint(self) -> Endpoint: + return self._pg + @property @override def pg(self) -> PgProtocol: diff --git a/test_runner/fixtures/endpoint/http.py b/test_runner/fixtures/endpoint/http.py index c43445e89d..d235ac2143 100644 --- a/test_runner/fixtures/endpoint/http.py +++ b/test_runner/fixtures/endpoint/http.py @@ -79,17 +79,28 @@ class EndpointHttpClient(requests.Session): return json def prewarm_lfc(self, from_endpoint_id: str | None = None): + """ + Prewarm LFC cache from given endpoint and wait till it finishes or errors + """ params = {"from_endpoint": from_endpoint_id} if from_endpoint_id else dict() self.post(self.prewarm_url, params=params).raise_for_status() self.prewarm_lfc_wait() def prewarm_lfc_wait(self): + """ + Wait till LFC prewarm returns with error or success. + If prewarm was not requested before calling this function, it will error + """ + statuses = "failed", "completed", "skipped" + def prewarmed(): json = self.prewarm_lfc_status() status, err = json["status"], json.get("error") - assert status == "completed", f"{status}, {err=}" + assert status in statuses, f"{status}, {err=}" wait_until(prewarmed, timeout=60) + res = self.prewarm_lfc_status() + assert res["status"] != "failed", res def offload_lfc_status(self) -> dict[str, str]: res = self.get(self.offload_url) @@ -98,26 +109,35 @@ class EndpointHttpClient(requests.Session): return json def offload_lfc(self): + """ + Offload LFC cache to endpoint storage and wait till offload finishes or errors + """ self.post(self.offload_url).raise_for_status() self.offload_lfc_wait() def offload_lfc_wait(self): + """ + Wait till LFC offload returns with error or success. 
+ If offload was not requested before calling this function, it will error + """ + def offloaded(): json = self.offload_lfc_status() status, err = json["status"], json.get("error") - assert status == "completed", f"{status}, {err=}" + assert status in ["failed", "completed"], f"{status}, {err=}" - wait_until(offloaded) + wait_until(offloaded, timeout=60) + res = self.offload_lfc_status() + assert res["status"] != "failed", res - def promote(self, safekeepers_lsn: dict[str, Any], disconnect: bool = False): + def promote(self, promote_spec: dict[str, Any], disconnect: bool = False): url = f"http://localhost:{self.external_port}/promote" if disconnect: try: # send first request to start promote and disconnect - self.post(url, data=safekeepers_lsn, timeout=0.001) + self.post(url, json=promote_spec, timeout=0.001) except ReadTimeout: pass # wait on second request which returns on promotion finish - res = self.post(url, data=safekeepers_lsn) - res.raise_for_status() + res = self.post(url, json=promote_spec) json: dict[str, str] = res.json() return json diff --git a/test_runner/fixtures/neon_api.py b/test_runner/fixtures/neon_api.py index 8d447c837f..69160dab20 100644 --- a/test_runner/fixtures/neon_api.py +++ b/test_runner/fixtures/neon_api.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re import time from typing import TYPE_CHECKING, cast, final @@ -13,6 +14,17 @@ if TYPE_CHECKING: from fixtures.pg_version import PgVersion +def connstr_to_env(connstr: str) -> dict[str, str]: + # postgresql://neondb_owner:npg_kuv6Rqi1cB@ep-old-silence-w26pxsvz-pooler.us-east-2.aws.neon.build/neondb?sslmode=require&channel_binding=...' + parts = re.split(r":|@|\/|\?", connstr.removeprefix("postgresql://")) + return { + "PGUSER": parts[0], + "PGPASSWORD": parts[1], + "PGHOST": parts[2], + "PGDATABASE": parts[3], + } + + def connection_parameters_to_env(params: dict[str, str]) -> dict[str, str]: return { "PGHOST": params["host"], diff --git a/test_runner/fixtures/neon_cli.py b/test_runner/fixtures/neon_cli.py index 5ad00d155e..390efe0309 100644 --- a/test_runner/fixtures/neon_cli.py +++ b/test_runner/fixtures/neon_cli.py @@ -212,11 +212,13 @@ class NeonLocalCli(AbstractNeonCli): pg_version, ] if conf is not None: - args.extend( - chain.from_iterable( - product(["-c"], (f"{key}:{value}" for key, value in conf.items())) - ) - ) + for key, value in conf.items(): + if isinstance(value, bool): + args.extend( + ["-c", f"{key}:{str(value).lower()}"] + ) # only accepts true/false not True/False + else: + args.extend(["-c", f"{key}:{value}"]) if set_default: args.append("--set-default") @@ -585,7 +587,9 @@ class NeonLocalCli(AbstractNeonCli): ] extra_env_vars = env or {} if basebackup_request_tries is not None: - extra_env_vars["NEON_COMPUTE_TESTING_BASEBACKUP_TRIES"] = str(basebackup_request_tries) + extra_env_vars["NEON_COMPUTE_TESTING_BASEBACKUP_RETRIES"] = str( + basebackup_request_tries + ) if remote_ext_base_url is not None: args.extend(["--remote-ext-base-url", remote_ext_base_url]) @@ -621,6 +625,7 @@ class NeonLocalCli(AbstractNeonCli): pageserver_id: int | None = None, safekeepers: list[int] | None = None, check_return_code=True, + timeout_sec: float | None = None, ) -> subprocess.CompletedProcess[str]: args = ["endpoint", "reconfigure", endpoint_id] if tenant_id is not None: @@ -629,7 +634,16 @@ class NeonLocalCli(AbstractNeonCli): args.extend(["--pageserver-id", str(pageserver_id)]) if safekeepers is not None: args.extend(["--safekeepers", (",".join(map(str, safekeepers)))]) - return 
self.raw_cli(args, check_return_code=check_return_code) + return self.raw_cli(args, check_return_code=check_return_code, timeout=timeout_sec) + + def endpoint_refresh_configuration( + self, + endpoint_id: str, + ) -> subprocess.CompletedProcess[str]: + args = ["endpoint", "refresh-configuration", endpoint_id] + res = self.raw_cli(args) + res.check_returncode() + return res def endpoint_stop( self, @@ -655,6 +669,22 @@ class NeonLocalCli(AbstractNeonCli): lsn: Lsn | None = None if lsn_str == "null" else Lsn(lsn_str) return lsn, proc + def endpoint_update_pageservers( + self, + endpoint_id: str, + pageserver_id: int | None = None, + ) -> subprocess.CompletedProcess[str]: + args = [ + "endpoint", + "update-pageservers", + endpoint_id, + ] + if pageserver_id is not None: + args.extend(["--pageserver-id", str(pageserver_id)]) + res = self.raw_cli(args) + res.check_returncode() + return res + def mappings_map_branch( self, name: str, tenant_id: TenantId, timeline_id: TimelineId ) -> subprocess.CompletedProcess[str]: diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index fbe7752a77..493a92cac7 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -3966,6 +3966,41 @@ class NeonProxy(PgProtocol): assert response.status_code == expected_code, f"response: {response.json()}" return response.json() + def http_multiquery(self, *queries, **kwargs): + # TODO maybe use default values if not provided + user = quote(kwargs["user"]) + password = quote(kwargs["password"]) + expected_code = kwargs.get("expected_code") + timeout = kwargs.get("timeout") + + json_queries = [] + for query in queries: + if type(query) is str: + json_queries.append({"query": query}) + else: + [query, params] = query + json_queries.append({"query": query, "params": params}) + + queries_str = [j["query"] for j in json_queries] + log.info(f"Executing http queries: {queries_str}") + + connstr = f"postgresql://{user}:{password}@{self.domain}:{self.proxy_port}/postgres" + response = requests.post( + f"https://{self.domain}:{self.external_http_port}/sql", + data=json.dumps({"queries": json_queries}), + headers={ + "Content-Type": "application/sql", + "Neon-Connection-String": connstr, + "Neon-Pool-Opt-In": "true", + }, + verify=str(self.test_output_dir / "proxy.crt"), + timeout=timeout, + ) + + if expected_code is not None: + assert response.status_code == expected_code, f"response: {response.json()}" + return response.json() + async def http2_query(self, query, args, **kwargs): # TODO maybe use default values if not provided user = kwargs["user"] @@ -4815,9 +4850,10 @@ class Endpoint(PgProtocol, LogUtils): m = re.search(r"=\s*(\S+)", line) assert m is not None, f"malformed config line {line}" size = m.group(1) - assert size_to_bytes(size) >= size_to_bytes("1MB"), ( - "LFC size cannot be set less than 1MB" - ) + if size_to_bytes(size) > 0: + assert size_to_bytes(size) >= size_to_bytes("1MB"), ( + "LFC size cannot be set less than 1MB" + ) lfc_path_escaped = str(lfc_path).replace("'", "''") config_lines = [ f"neon.file_cache_path = '{lfc_path_escaped}'", @@ -4950,15 +4986,38 @@ class Endpoint(PgProtocol, LogUtils): def is_running(self): return self._running._value > 0 - def reconfigure(self, pageserver_id: int | None = None, safekeepers: list[int] | None = None): + def reconfigure( + self, + pageserver_id: int | None = None, + safekeepers: list[int] | None = None, + timeout_sec: float = 120, + ): assert self.endpoint_id is not None # If `safekeepers` is not 
None, they are remember them as active and use # in the following commands. if safekeepers is not None: self.active_safekeepers = safekeepers - self.env.neon_cli.endpoint_reconfigure( - self.endpoint_id, self.tenant_id, pageserver_id, self.active_safekeepers - ) + + start_time = time.time() + while True: + try: + self.env.neon_cli.endpoint_reconfigure( + self.endpoint_id, + self.tenant_id, + pageserver_id, + self.active_safekeepers, + timeout_sec=timeout_sec, + ) + return + except RuntimeError as e: + if time.time() - start_time > timeout_sec: + raise e + log.warning(f"Reconfigure failed with error: {e}. Retrying...") + time.sleep(5) + + def refresh_configuration(self): + assert self.endpoint_id is not None + self.env.neon_cli.endpoint_refresh_configuration(self.endpoint_id) def respec(self, **kwargs: Any) -> None: """Update the endpoint.json file used by control_plane.""" @@ -4972,6 +5031,10 @@ class Endpoint(PgProtocol, LogUtils): log.debug(json.dumps(dict(data_dict, **kwargs))) json.dump(dict(data_dict, **kwargs), file, indent=4) + def get_compute_spec(self) -> dict[str, Any]: + out = json.loads((Path(self.endpoint_path()) / "config.json").read_text())["spec"] + return cast("dict[str, Any]", out) + def respec_deep(self, **kwargs: Any) -> None: """ Update the endpoint.json file taking into account nested keys. @@ -5002,6 +5065,10 @@ class Endpoint(PgProtocol, LogUtils): log.debug("Updating compute config to: %s", json.dumps(config, indent=4)) json.dump(config, file, indent=4) + def update_pageservers_in_config(self, pageserver_id: int | None = None): + assert self.endpoint_id is not None + self.env.neon_cli.endpoint_update_pageservers(self.endpoint_id, pageserver_id) + def wait_for_migrations(self, wait_for: int = NUM_COMPUTE_MIGRATIONS) -> None: """ Wait for all compute migrations to be ran. Remember that migrations only diff --git a/test_runner/fixtures/workload.py b/test_runner/fixtures/workload.py index e17a8e989b..3ac61b5d8c 100644 --- a/test_runner/fixtures/workload.py +++ b/test_runner/fixtures/workload.py @@ -78,6 +78,9 @@ class Workload: """ if self._endpoint is not None: with ENDPOINT_LOCK: + # It's important that we update config.json before issuing the reconfigure request to make sure + # that PG-initiated spec refresh doesn't mess things up by reverting to the old spec. + self._endpoint.update_pageservers_in_config() self._endpoint.reconfigure() def endpoint(self, pageserver_id: int | None = None) -> Endpoint: @@ -97,10 +100,10 @@ class Workload: self._endpoint.start(pageserver_id=pageserver_id) self._configured_pageserver = pageserver_id else: - if self._configured_pageserver != pageserver_id: - self._configured_pageserver = pageserver_id - self._endpoint.reconfigure(pageserver_id=pageserver_id) - self._endpoint_config = pageserver_id + # It's important that we update config.json before issuing the reconfigure request to make sure + # that PG-initiated spec refresh doesn't mess things up by reverting to the old spec. 
+ self._endpoint.update_pageservers_in_config(pageserver_id=pageserver_id) + self._endpoint.reconfigure(pageserver_id=pageserver_id) connstring = self._endpoint.safe_psql( "SELECT setting FROM pg_settings WHERE name='neon.pageserver_connstring'" diff --git a/test_runner/logical_repl/README.md b/test_runner/logical_repl/README.md index 449e56e21d..74af203c03 100644 --- a/test_runner/logical_repl/README.md +++ b/test_runner/logical_repl/README.md @@ -9,9 +9,10 @@ ```bash export BENCHMARK_CONNSTR=postgres://user:pass@ep-abc-xyz-123.us-east-2.aws.neon.build/neondb +export CLICKHOUSE_PASSWORD=ch_password123 docker compose -f test_runner/logical_repl/clickhouse/docker-compose.yml up -d -./scripts/pytest -m remote_cluster -k test_clickhouse +./scripts/pytest -m remote_cluster -k 'test_clickhouse[release-pg17]' docker compose -f test_runner/logical_repl/clickhouse/docker-compose.yml down ``` @@ -21,6 +22,6 @@ docker compose -f test_runner/logical_repl/clickhouse/docker-compose.yml down export BENCHMARK_CONNSTR=postgres://user:pass@ep-abc-xyz-123.us-east-2.aws.neon.build/neondb docker compose -f test_runner/logical_repl/debezium/docker-compose.yml up -d -./scripts/pytest -m remote_cluster -k test_debezium +./scripts/pytest -m remote_cluster -k 'test_debezium[release-pg17]' docker compose -f test_runner/logical_repl/debezium/docker-compose.yml down ``` diff --git a/test_runner/logical_repl/clickhouse/docker-compose.yml b/test_runner/logical_repl/clickhouse/docker-compose.yml index e00038b811..4131fbf0a5 100644 --- a/test_runner/logical_repl/clickhouse/docker-compose.yml +++ b/test_runner/logical_repl/clickhouse/docker-compose.yml @@ -1,9 +1,11 @@ services: clickhouse: - image: clickhouse/clickhouse-server + image: clickhouse/clickhouse-server:25.6 user: "101:101" container_name: clickhouse hostname: clickhouse + environment: + - CLICKHOUSE_PASSWORD=${CLICKHOUSE_PASSWORD:-ch_password123} ports: - 127.0.0.1:8123:8123 - 127.0.0.1:9000:9000 diff --git a/test_runner/logical_repl/debezium/docker-compose.yml b/test_runner/logical_repl/debezium/docker-compose.yml index fee127a2fd..fd446f173f 100644 --- a/test_runner/logical_repl/debezium/docker-compose.yml +++ b/test_runner/logical_repl/debezium/docker-compose.yml @@ -1,18 +1,28 @@ services: zookeeper: - image: quay.io/debezium/zookeeper:2.7 + image: quay.io/debezium/zookeeper:3.1.3.Final + ports: + - 127.0.0.1:2181:2181 + - 127.0.0.1:2888:2888 + - 127.0.0.1:3888:3888 kafka: - image: quay.io/debezium/kafka:2.7 + image: quay.io/debezium/kafka:3.1.3.Final + depends_on: [zookeeper] environment: ZOOKEEPER_CONNECT: "zookeeper:2181" - KAFKA_ADVERTISED_LISTENERS: PLAINTEXT://kafka:9092 + KAFKA_LISTENERS: INTERNAL://:9092,EXTERNAL://:29092 + KAFKA_ADVERTISED_LISTENERS: INTERNAL://kafka:9092,EXTERNAL://localhost:29092 + KAFKA_LISTENER_SECURITY_PROTOCOL_MAP: INTERNAL:PLAINTEXT,EXTERNAL:PLAINTEXT + KAFKA_INTER_BROKER_LISTENER_NAME: INTERNAL KAFKA_BROKER_ID: 1 KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR: 1 KAFKA_JMX_PORT: 9991 ports: - - 127.0.0.1:9092:9092 + - 9092:9092 + - 29092:29092 debezium: - image: quay.io/debezium/connect:2.7 + image: quay.io/debezium/connect:3.1.3.Final + depends_on: [kafka] environment: BOOTSTRAP_SERVERS: kafka:9092 GROUP_ID: 1 diff --git a/test_runner/logical_repl/test_clickhouse.py b/test_runner/logical_repl/test_clickhouse.py index c05684baf9..ef41ee6187 100644 --- a/test_runner/logical_repl/test_clickhouse.py +++ b/test_runner/logical_repl/test_clickhouse.py @@ -53,8 +53,13 @@ def test_clickhouse(remote_pg: RemotePostgres): cur.execute("CREATE 
TABLE table1 (id integer primary key, column1 varchar(10));") cur.execute("INSERT INTO table1 (id, column1) VALUES (1, 'abc'), (2, 'def');") conn.commit() - client = clickhouse_connect.get_client(host=clickhouse_host) + if "CLICKHOUSE_PASSWORD" not in os.environ: + raise RuntimeError("CLICKHOUSE_PASSWORD is not set") + client = clickhouse_connect.get_client( + host=clickhouse_host, password=os.environ["CLICKHOUSE_PASSWORD"] + ) client.command("SET allow_experimental_database_materialized_postgresql=1") + client.command("DROP DATABASE IF EXISTS db1_postgres") client.command( "CREATE DATABASE db1_postgres ENGINE = " f"MaterializedPostgreSQL('{conn_options['host']}', " diff --git a/test_runner/logical_repl/test_debezium.py b/test_runner/logical_repl/test_debezium.py index a53e6cef92..becdaffcf8 100644 --- a/test_runner/logical_repl/test_debezium.py +++ b/test_runner/logical_repl/test_debezium.py @@ -17,6 +17,7 @@ from fixtures.utils import wait_until if TYPE_CHECKING: from fixtures.neon_fixtures import RemotePostgres + from kafka import KafkaConsumer class DebeziumAPI: @@ -101,9 +102,13 @@ def debezium(remote_pg: RemotePostgres): assert len(dbz.list_connectors()) == 1 from kafka import KafkaConsumer + kafka_host = "kafka" if (os.getenv("CI", "false") == "true") else "127.0.0.1" + kafka_port = 9092 if (os.getenv("CI", "false") == "true") else 29092 + log.info("Connecting to Kafka: %s:%s", kafka_host, kafka_port) + consumer = KafkaConsumer( "dbserver1.inventory.customers", - bootstrap_servers=["kafka:9092"], + bootstrap_servers=[f"{kafka_host}:{kafka_port}"], auto_offset_reset="earliest", enable_auto_commit=False, ) @@ -112,7 +117,7 @@ def debezium(remote_pg: RemotePostgres): assert resp.status_code == 204 -def get_kafka_msg(consumer, ts_ms, before=None, after=None) -> None: +def get_kafka_msg(consumer: KafkaConsumer, ts_ms, before=None, after=None) -> None: """ Gets the message from Kafka and checks its validity Arguments: @@ -124,6 +129,7 @@ def get_kafka_msg(consumer, ts_ms, before=None, after=None) -> None: after: a dictionary, if not None, the after field from the kafka message must have the same values for the same keys """ + log.info("Bootstrap servers: %s", consumer.config["bootstrap_servers"]) msg = consumer.poll() assert msg, "Empty message" for val in msg.values(): diff --git a/test_runner/performance/benchbase_tpc_c_helpers/generate_diagrams.py b/test_runner/performance/benchbase_tpc_c_helpers/generate_diagrams.py new file mode 100644 index 0000000000..cf41a4ff59 --- /dev/null +++ b/test_runner/performance/benchbase_tpc_c_helpers/generate_diagrams.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +""" +Generate TPS and latency charts from BenchBase TPC-C results CSV files. + +This script reads a CSV file containing BenchBase results and generates two charts: +1. TPS (requests per second) over time +2. P95 and P99 latencies over time + +Both charts are combined in a single SVG file. 
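+
+Example invocation (paths are illustrative):
+
+    python3 generate_diagrams.py \
+        --input-csv path/to/results.csv \
+        --output-svg path/to/charts.svg \
+        --title-suffix 'Warmup'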
+""" + +import argparse +import sys +from pathlib import Path + +import matplotlib.pyplot as plt # type: ignore[import-not-found] +import pandas as pd # type: ignore[import-untyped] + + +def load_results_csv(csv_file_path): + """Load BenchBase results CSV file into a pandas DataFrame.""" + try: + df = pd.read_csv(csv_file_path) + + # Validate required columns exist + required_columns = [ + "Time (seconds)", + "Throughput (requests/second)", + "95th Percentile Latency (millisecond)", + "99th Percentile Latency (millisecond)", + ] + + missing_columns = [col for col in required_columns if col not in df.columns] + if missing_columns: + print(f"Error: Missing required columns: {missing_columns}") + sys.exit(1) + + return df + + except FileNotFoundError: + print(f"Error: CSV file not found: {csv_file_path}") + sys.exit(1) + except pd.errors.EmptyDataError: + print(f"Error: CSV file is empty: {csv_file_path}") + sys.exit(1) + except Exception as e: + print(f"Error reading CSV file: {e}") + sys.exit(1) + + +def generate_charts(df, input_filename, output_svg_path, title_suffix=None): + """Generate combined TPS and latency charts and save as SVG.""" + + # Get the filename without extension for chart titles + file_label = Path(input_filename).stem + + # Build title ending with optional suffix + if title_suffix: + title_ending = f"{title_suffix} - {file_label}" + else: + title_ending = file_label + + # Create figure with two subplots + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10)) + + # Chart 1: Time vs TPS + ax1.plot( + df["Time (seconds)"], + df["Throughput (requests/second)"], + linewidth=1, + color="blue", + alpha=0.7, + ) + ax1.set_xlabel("Time (seconds)") + ax1.set_ylabel("TPS (Requests Per Second)") + ax1.set_title(f"Benchbase TPC-C Like Throughput (TPS) - {title_ending}") + ax1.grid(True, alpha=0.3) + ax1.set_xlim(0, df["Time (seconds)"].max()) + + # Chart 2: Time vs P95 and P99 Latencies + ax2.plot( + df["Time (seconds)"], + df["95th Percentile Latency (millisecond)"], + linewidth=1, + color="orange", + alpha=0.7, + label="Latency P95", + ) + ax2.plot( + df["Time (seconds)"], + df["99th Percentile Latency (millisecond)"], + linewidth=1, + color="red", + alpha=0.7, + label="Latency P99", + ) + ax2.set_xlabel("Time (seconds)") + ax2.set_ylabel("Latency (ms)") + ax2.set_title(f"Benchbase TPC-C Like Latency - {title_ending}") + ax2.grid(True, alpha=0.3) + ax2.set_xlim(0, df["Time (seconds)"].max()) + ax2.legend() + + plt.tight_layout() + + # Save as SVG + try: + plt.savefig(output_svg_path, format="svg", dpi=300, bbox_inches="tight") + print(f"Charts saved to: {output_svg_path}") + except Exception as e: + print(f"Error saving SVG file: {e}") + sys.exit(1) + + +def main(): + """Main function to parse arguments and generate charts.""" + parser = argparse.ArgumentParser( + description="Generate TPS and latency charts from BenchBase TPC-C results CSV" + ) + parser.add_argument( + "--input-csv", type=str, required=True, help="Path to the input CSV results file" + ) + parser.add_argument( + "--output-svg", type=str, required=True, help="Path for the output SVG chart file" + ) + parser.add_argument( + "--title-suffix", + type=str, + required=False, + help="Optional suffix to add to chart titles (e.g., 'Warmup', 'Benchmark Phase')", + ) + + args = parser.parse_args() + + # Validate input file exists + if not Path(args.input_csv).exists(): + print(f"Error: Input CSV file does not exist: {args.input_csv}") + sys.exit(1) + + # Create output directory if it doesn't exist + output_path = 
Path(args.output_svg) + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Load data and generate charts + df = load_results_csv(args.input_csv) + generate_charts(df, args.input_csv, args.output_svg, args.title_suffix) + + print(f"Successfully generated charts from {len(df)} data points") + + +if __name__ == "__main__": + main() diff --git a/test_runner/performance/benchbase_tpc_c_helpers/generate_workload_size.py b/test_runner/performance/benchbase_tpc_c_helpers/generate_workload_size.py new file mode 100644 index 0000000000..1549c74b87 --- /dev/null +++ b/test_runner/performance/benchbase_tpc_c_helpers/generate_workload_size.py @@ -0,0 +1,339 @@ +import argparse +import html +import math +import os +import sys +from pathlib import Path + +CONFIGS_DIR = Path("../configs") +SCRIPTS_DIR = Path("../scripts") + +# Constants +## TODO increase times after testing +WARMUP_TIME_SECONDS = 1200 # 20 minutes +BENCHMARK_TIME_SECONDS = 3600 # 1 hour +RAMP_STEP_TIME_SECONDS = 300 # 5 minutes +BASE_TERMINALS = 130 +TERMINALS_PER_WAREHOUSE = 0.2 +OPTIMAL_RATE_FACTOR = 0.7 # 70% of max rate +BATCH_SIZE = 1000 +LOADER_THREADS = 4 +TRANSACTION_WEIGHTS = "45,43,4,4,4" # NewOrder, Payment, OrderStatus, Delivery, StockLevel +# Ramp-up rate multipliers +RAMP_RATE_FACTORS = [1.5, 1.1, 0.9, 0.7, 0.6, 0.4, 0.6, 0.7, 0.9, 1.1] + +# Templates for XML configs +WARMUP_XML = """ + + POSTGRES + org.postgresql.Driver + jdbc:postgresql://{hostname}/neondb?sslmode=require&ApplicationName=tpcc&reWriteBatchedInserts=true + neondb_owner + {password} + true + TRANSACTION_READ_COMMITTED + {batch_size} + {warehouses} + 0 + {terminals} + + + + {transaction_weights} + unlimited + POISSON + ZIPFIAN + + + + NewOrder + Payment + OrderStatus + Delivery + StockLevel + + +""" + +MAX_RATE_XML = """ + + POSTGRES + org.postgresql.Driver + jdbc:postgresql://{hostname}/neondb?sslmode=require&ApplicationName=tpcc&reWriteBatchedInserts=true + neondb_owner + {password} + true + TRANSACTION_READ_COMMITTED + {batch_size} + {warehouses} + 0 + {terminals} + + + + {transaction_weights} + unlimited + POISSON + ZIPFIAN + + + + NewOrder + Payment + OrderStatus + Delivery + StockLevel + + +""" + +OPT_RATE_XML = """ + + POSTGRES + org.postgresql.Driver + jdbc:postgresql://{hostname}/neondb?sslmode=require&ApplicationName=tpcc&reWriteBatchedInserts=true + neondb_owner + {password} + true + TRANSACTION_READ_COMMITTED + {batch_size} + {warehouses} + 0 + {terminals} + + + + {opt_rate} + {transaction_weights} + POISSON + ZIPFIAN + + + + NewOrder + Payment + OrderStatus + Delivery + StockLevel + + +""" + +RAMP_UP_XML = """ + + POSTGRES + org.postgresql.Driver + jdbc:postgresql://{hostname}/neondb?sslmode=require&ApplicationName=tpcc&reWriteBatchedInserts=true + neondb_owner + {password} + true + TRANSACTION_READ_COMMITTED + {batch_size} + {warehouses} + 0 + {terminals} + +{works} + + + NewOrder + Payment + OrderStatus + Delivery + StockLevel + + +""" + +WORK_TEMPLATE = f""" \n \n {{rate}}\n {TRANSACTION_WEIGHTS}\n POISSON\n ZIPFIAN\n \n""" + +# Templates for shell scripts +EXECUTE_SCRIPT = """# Create results directories +mkdir -p results_warmup +mkdir -p results_{suffix} +chmod 777 results_warmup results_{suffix} + +# Run warmup phase +docker run --network=host --rm \ + -v $(pwd)/configs:/configs \ + -v $(pwd)/results_warmup:/results \ + {docker_image}\ + -b tpcc \ + -c /configs/execute_{warehouses}_warehouses_warmup.xml \ + -d /results \ + --create=false --load=false --execute=true + +# Run benchmark phase +docker run --network=host --rm \ + -v 
$(pwd)/configs:/configs \ + -v $(pwd)/results_{suffix}:/results \ + {docker_image}\ + -b tpcc \ + -c /configs/execute_{warehouses}_warehouses_{suffix}.xml \ + -d /results \ + --create=false --load=false --execute=true\n""" + +LOAD_XML = """ + + POSTGRES + org.postgresql.Driver + jdbc:postgresql://{hostname}/neondb?sslmode=require&ApplicationName=tpcc&reWriteBatchedInserts=true + neondb_owner + {password} + true + TRANSACTION_READ_COMMITTED + {batch_size} + {warehouses} + {loader_threads} + +""" + +LOAD_SCRIPT = """# Create results directory for loading +mkdir -p results_load +chmod 777 results_load + +docker run --network=host --rm \ + -v $(pwd)/configs:/configs \ + -v $(pwd)/results_load:/results \ + {docker_image}\ + -b tpcc \ + -c /configs/load_{warehouses}_warehouses.xml \ + -d /results \ + --create=true --load=true --execute=false\n""" + + +def write_file(path, content): + path.parent.mkdir(parents=True, exist_ok=True) + try: + with open(path, "w") as f: + f.write(content) + except OSError as e: + print(f"Error writing {path}: {e}") + sys.exit(1) + # If it's a shell script, set executable permission + if str(path).endswith(".sh"): + os.chmod(path, 0o755) + + +def escape_xml_password(password): + """Escape XML special characters in password.""" + return html.escape(password, quote=True) + + +def get_docker_arch_tag(runner_arch): + """Map GitHub Actions runner.arch to Docker image architecture tag.""" + arch_mapping = {"X64": "amd64", "ARM64": "arm64"} + return arch_mapping.get(runner_arch, "amd64") # Default to amd64 + + +def main(): + parser = argparse.ArgumentParser(description="Generate BenchBase workload configs and scripts.") + parser.add_argument("--warehouses", type=int, required=True, help="Number of warehouses") + parser.add_argument("--max-rate", type=int, required=True, help="Max rate (TPS)") + parser.add_argument("--hostname", type=str, required=True, help="Database hostname") + parser.add_argument("--password", type=str, required=True, help="Database password") + parser.add_argument( + "--runner-arch", type=str, required=True, help="GitHub Actions runner architecture" + ) + args = parser.parse_args() + + warehouses = args.warehouses + max_rate = args.max_rate + hostname = args.hostname + password = args.password + runner_arch = args.runner_arch + + # Escape password for safe XML insertion + escaped_password = escape_xml_password(password) + + # Get the appropriate Docker architecture tag + docker_arch = get_docker_arch_tag(runner_arch) + docker_image = f"ghcr.io/neondatabase-labs/benchbase-postgres:latest-{docker_arch}" + + opt_rate = math.ceil(max_rate * OPTIMAL_RATE_FACTOR) + # Calculate terminals as next rounded integer of 40% of warehouses + terminals = math.ceil(BASE_TERMINALS + warehouses * TERMINALS_PER_WAREHOUSE) + ramp_rates = [math.ceil(max_rate * factor) for factor in RAMP_RATE_FACTORS] + + # Write configs + write_file( + CONFIGS_DIR / f"execute_{warehouses}_warehouses_warmup.xml", + WARMUP_XML.format( + warehouses=warehouses, + hostname=hostname, + password=escaped_password, + terminals=terminals, + batch_size=BATCH_SIZE, + warmup_time=WARMUP_TIME_SECONDS, + transaction_weights=TRANSACTION_WEIGHTS, + ), + ) + write_file( + CONFIGS_DIR / f"execute_{warehouses}_warehouses_max_rate.xml", + MAX_RATE_XML.format( + warehouses=warehouses, + hostname=hostname, + password=escaped_password, + terminals=terminals, + batch_size=BATCH_SIZE, + benchmark_time=BENCHMARK_TIME_SECONDS, + transaction_weights=TRANSACTION_WEIGHTS, + ), + ) + write_file( + CONFIGS_DIR / 
f"execute_{warehouses}_warehouses_opt_rate.xml", + OPT_RATE_XML.format( + warehouses=warehouses, + opt_rate=opt_rate, + hostname=hostname, + password=escaped_password, + terminals=terminals, + batch_size=BATCH_SIZE, + benchmark_time=BENCHMARK_TIME_SECONDS, + transaction_weights=TRANSACTION_WEIGHTS, + ), + ) + + ramp_works = "".join([WORK_TEMPLATE.format(rate=rate) for rate in ramp_rates]) + write_file( + CONFIGS_DIR / f"execute_{warehouses}_warehouses_ramp_up.xml", + RAMP_UP_XML.format( + warehouses=warehouses, + works=ramp_works, + hostname=hostname, + password=escaped_password, + terminals=terminals, + batch_size=BATCH_SIZE, + ), + ) + + # Loader config + write_file( + CONFIGS_DIR / f"load_{warehouses}_warehouses.xml", + LOAD_XML.format( + warehouses=warehouses, + hostname=hostname, + password=escaped_password, + batch_size=BATCH_SIZE, + loader_threads=LOADER_THREADS, + ), + ) + + # Write scripts + for suffix in ["max_rate", "opt_rate", "ramp_up"]: + script = EXECUTE_SCRIPT.format( + warehouses=warehouses, suffix=suffix, docker_image=docker_image + ) + write_file(SCRIPTS_DIR / f"execute_{warehouses}_warehouses_{suffix}.sh", script) + + # Loader script + write_file( + SCRIPTS_DIR / f"load_{warehouses}_warehouses.sh", + LOAD_SCRIPT.format(warehouses=warehouses, docker_image=docker_image), + ) + + print(f"Generated configs and scripts for {warehouses} warehouses and max rate {max_rate}.") + + +if __name__ == "__main__": + main() diff --git a/test_runner/performance/benchbase_tpc_c_helpers/upload_results_to_perf_test_results.py b/test_runner/performance/benchbase_tpc_c_helpers/upload_results_to_perf_test_results.py new file mode 100644 index 0000000000..3706d14fd4 --- /dev/null +++ b/test_runner/performance/benchbase_tpc_c_helpers/upload_results_to_perf_test_results.py @@ -0,0 +1,591 @@ +#!/usr/bin/env python3 +# ruff: noqa +# we exclude the file from ruff because on the github runner we have python 3.9 and ruff +# is running with newer python 3.12 which suggests changes incompatible with python 3.9 +""" +Upload BenchBase TPC-C results from summary.json and results.csv files to perf_test_results database. + +This script extracts metrics from BenchBase *.summary.json and *.results.csv files and uploads them +to a PostgreSQL database table for performance tracking and analysis. 
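+
+Example invocation (values are placeholders):
+
+    python3 upload_results_to_perf_test_results.py \
+        --summary-json path/to/run.summary.json \
+        --run-type opt-rate \
+        --min-cu 2 --max-cu 8 \
+        --project-id <neon_project_id> \
+        --revision <40-char-git-sha> \
+        --connection-string <perf_test_results_connstr>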
+""" + +import argparse +import json +import re +import sys +from datetime import datetime, timezone +from pathlib import Path + +import pandas as pd # type: ignore[import-untyped] +import psycopg2 + + +def load_summary_json(json_file_path): + """Load summary.json file and return parsed data.""" + try: + with open(json_file_path) as f: + return json.load(f) + except FileNotFoundError: + print(f"Error: Summary JSON file not found: {json_file_path}") + sys.exit(1) + except json.JSONDecodeError as e: + print(f"Error: Invalid JSON in file {json_file_path}: {e}") + sys.exit(1) + except Exception as e: + print(f"Error loading JSON file {json_file_path}: {e}") + sys.exit(1) + + +def get_metric_info(metric_name): + """Get metric unit and report type for a given metric name.""" + metrics_config = { + "Throughput": {"unit": "req/s", "report_type": "higher_is_better"}, + "Goodput": {"unit": "req/s", "report_type": "higher_is_better"}, + "Measured Requests": {"unit": "requests", "report_type": "higher_is_better"}, + "95th Percentile Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "Maximum Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "Median Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "Minimum Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "25th Percentile Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "90th Percentile Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "99th Percentile Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "75th Percentile Latency": {"unit": "µs", "report_type": "lower_is_better"}, + "Average Latency": {"unit": "µs", "report_type": "lower_is_better"}, + } + + return metrics_config.get(metric_name, {"unit": "", "report_type": "higher_is_better"}) + + +def extract_metrics(summary_data): + """Extract relevant metrics from summary JSON data.""" + metrics = [] + + # Direct top-level metrics + direct_metrics = { + "Throughput (requests/second)": "Throughput", + "Goodput (requests/second)": "Goodput", + "Measured Requests": "Measured Requests", + } + + for json_key, clean_name in direct_metrics.items(): + if json_key in summary_data: + metrics.append((clean_name, summary_data[json_key])) + + # Latency metrics from nested "Latency Distribution" object + if "Latency Distribution" in summary_data: + latency_data = summary_data["Latency Distribution"] + latency_metrics = { + "95th Percentile Latency (microseconds)": "95th Percentile Latency", + "Maximum Latency (microseconds)": "Maximum Latency", + "Median Latency (microseconds)": "Median Latency", + "Minimum Latency (microseconds)": "Minimum Latency", + "25th Percentile Latency (microseconds)": "25th Percentile Latency", + "90th Percentile Latency (microseconds)": "90th Percentile Latency", + "99th Percentile Latency (microseconds)": "99th Percentile Latency", + "75th Percentile Latency (microseconds)": "75th Percentile Latency", + "Average Latency (microseconds)": "Average Latency", + } + + for json_key, clean_name in latency_metrics.items(): + if json_key in latency_data: + metrics.append((clean_name, latency_data[json_key])) + + return metrics + + +def build_labels(summary_data, project_id): + """Build labels JSON object from summary data and project info.""" + labels = {} + + # Extract required label keys from summary data + label_keys = [ + "DBMS Type", + "DBMS Version", + "Benchmark Type", + "Final State", + "isolation", + "scalefactor", + "terminals", + ] + + for key in label_keys: + if key in summary_data: + labels[key] = 
summary_data[key] + + # Add project_id from workflow + labels["project_id"] = project_id + + return labels + + +def build_suit_name(scalefactor, terminals, run_type, min_cu, max_cu): + """Build the suit name according to specification.""" + return f"benchbase-tpc-c-{scalefactor}-{terminals}-{run_type}-{min_cu}-{max_cu}" + + +def convert_timestamp_to_utc(timestamp_ms): + """Convert millisecond timestamp to PostgreSQL-compatible UTC timestamp.""" + try: + dt = datetime.fromtimestamp(timestamp_ms / 1000.0, tz=timezone.utc) + return dt.isoformat() + except (ValueError, TypeError) as e: + print(f"Warning: Could not convert timestamp {timestamp_ms}: {e}") + return datetime.now(timezone.utc).isoformat() + + +def insert_metrics(conn, metrics_data): + """Insert metrics data into the perf_test_results table.""" + insert_query = """ + INSERT INTO perf_test_results + (suit, revision, platform, metric_name, metric_value, metric_unit, + metric_report_type, recorded_at_timestamp, labels) + VALUES (%(suit)s, %(revision)s, %(platform)s, %(metric_name)s, %(metric_value)s, + %(metric_unit)s, %(metric_report_type)s, %(recorded_at_timestamp)s, %(labels)s) + """ + + try: + with conn.cursor() as cursor: + cursor.executemany(insert_query, metrics_data) + conn.commit() + print(f"Successfully inserted {len(metrics_data)} metrics into perf_test_results") + + # Log some sample data for verification + if metrics_data: + print( + f"Sample metric: {metrics_data[0]['metric_name']} = {metrics_data[0]['metric_value']} {metrics_data[0]['metric_unit']}" + ) + + except Exception as e: + print(f"Error inserting metrics into database: {e}") + sys.exit(1) + + +def create_benchbase_results_details_table(conn): + """Create benchbase_results_details table if it doesn't exist.""" + create_table_query = """ + CREATE TABLE IF NOT EXISTS benchbase_results_details ( + id BIGSERIAL PRIMARY KEY, + suit TEXT, + revision CHAR(40), + platform TEXT, + recorded_at_timestamp TIMESTAMP WITH TIME ZONE, + requests_per_second NUMERIC, + average_latency_ms NUMERIC, + minimum_latency_ms NUMERIC, + p25_latency_ms NUMERIC, + median_latency_ms NUMERIC, + p75_latency_ms NUMERIC, + p90_latency_ms NUMERIC, + p95_latency_ms NUMERIC, + p99_latency_ms NUMERIC, + maximum_latency_ms NUMERIC + ); + + CREATE INDEX IF NOT EXISTS benchbase_results_details_recorded_at_timestamp_idx + ON benchbase_results_details USING BRIN (recorded_at_timestamp); + CREATE INDEX IF NOT EXISTS benchbase_results_details_suit_idx + ON benchbase_results_details USING BTREE (suit text_pattern_ops); + """ + + try: + with conn.cursor() as cursor: + cursor.execute(create_table_query) + conn.commit() + print("Successfully created/verified benchbase_results_details table") + except Exception as e: + print(f"Error creating benchbase_results_details table: {e}") + sys.exit(1) + + +def process_csv_results(csv_file_path, start_timestamp_ms, suit, revision, platform): + """Process CSV results and return data for database insertion.""" + try: + # Read CSV file + df = pd.read_csv(csv_file_path) + + # Validate required columns exist + required_columns = [ + "Time (seconds)", + "Throughput (requests/second)", + "Average Latency (millisecond)", + "Minimum Latency (millisecond)", + "25th Percentile Latency (millisecond)", + "Median Latency (millisecond)", + "75th Percentile Latency (millisecond)", + "90th Percentile Latency (millisecond)", + "95th Percentile Latency (millisecond)", + "99th Percentile Latency (millisecond)", + "Maximum Latency (millisecond)", + ] + + missing_columns = [col for col in 
required_columns if col not in df.columns] + if missing_columns: + print(f"Error: Missing required columns in CSV: {missing_columns}") + return [] + + csv_data = [] + + for _, row in df.iterrows(): + # Calculate timestamp: start_timestamp_ms + (time_seconds * 1000) + time_seconds = row["Time (seconds)"] + row_timestamp_ms = start_timestamp_ms + (time_seconds * 1000) + + # Convert to UTC timestamp + row_timestamp = datetime.fromtimestamp( + row_timestamp_ms / 1000.0, tz=timezone.utc + ).isoformat() + + csv_row = { + "suit": suit, + "revision": revision, + "platform": platform, + "recorded_at_timestamp": row_timestamp, + "requests_per_second": float(row["Throughput (requests/second)"]), + "average_latency_ms": float(row["Average Latency (millisecond)"]), + "minimum_latency_ms": float(row["Minimum Latency (millisecond)"]), + "p25_latency_ms": float(row["25th Percentile Latency (millisecond)"]), + "median_latency_ms": float(row["Median Latency (millisecond)"]), + "p75_latency_ms": float(row["75th Percentile Latency (millisecond)"]), + "p90_latency_ms": float(row["90th Percentile Latency (millisecond)"]), + "p95_latency_ms": float(row["95th Percentile Latency (millisecond)"]), + "p99_latency_ms": float(row["99th Percentile Latency (millisecond)"]), + "maximum_latency_ms": float(row["Maximum Latency (millisecond)"]), + } + csv_data.append(csv_row) + + print(f"Processed {len(csv_data)} rows from CSV file") + return csv_data + + except FileNotFoundError: + print(f"Error: CSV file not found: {csv_file_path}") + return [] + except Exception as e: + print(f"Error processing CSV file {csv_file_path}: {e}") + return [] + + +def insert_csv_results(conn, csv_data): + """Insert CSV results into benchbase_results_details table.""" + if not csv_data: + print("No CSV data to insert") + return + + insert_query = """ + INSERT INTO benchbase_results_details + (suit, revision, platform, recorded_at_timestamp, requests_per_second, + average_latency_ms, minimum_latency_ms, p25_latency_ms, median_latency_ms, + p75_latency_ms, p90_latency_ms, p95_latency_ms, p99_latency_ms, maximum_latency_ms) + VALUES (%(suit)s, %(revision)s, %(platform)s, %(recorded_at_timestamp)s, %(requests_per_second)s, + %(average_latency_ms)s, %(minimum_latency_ms)s, %(p25_latency_ms)s, %(median_latency_ms)s, + %(p75_latency_ms)s, %(p90_latency_ms)s, %(p95_latency_ms)s, %(p99_latency_ms)s, %(maximum_latency_ms)s) + """ + + try: + with conn.cursor() as cursor: + cursor.executemany(insert_query, csv_data) + conn.commit() + print( + f"Successfully inserted {len(csv_data)} detailed results into benchbase_results_details" + ) + + # Log some sample data for verification + sample = csv_data[0] + print( + f"Sample detail: {sample['requests_per_second']} req/s at {sample['recorded_at_timestamp']}" + ) + + except Exception as e: + print(f"Error inserting CSV results into database: {e}") + sys.exit(1) + + +def parse_load_log(log_file_path, scalefactor): + """Parse load log file and extract load metrics.""" + try: + with open(log_file_path) as f: + log_content = f.read() + + # Regex patterns to match the timestamp lines + loading_pattern = r"\[INFO \] (\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}),\d{3}.*Loading data into TPCC database" + finished_pattern = r"\[INFO \] (\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}),\d{3}.*Finished loading data into TPCC database" + + loading_match = re.search(loading_pattern, log_content) + finished_match = re.search(finished_pattern, log_content) + + if not loading_match or not finished_match: + print(f"Warning: Could not find loading 
timestamps in log file {log_file_path}") + return None + + # Parse timestamps + loading_time = datetime.strptime(loading_match.group(1), "%Y-%m-%d %H:%M:%S") + finished_time = datetime.strptime(finished_match.group(1), "%Y-%m-%d %H:%M:%S") + + # Calculate duration in seconds + duration_seconds = (finished_time - loading_time).total_seconds() + + # Calculate throughput: scalefactor/warehouses: 10 warehouses is approx. 1 GB of data + load_throughput = (scalefactor * 1024 / 10.0) / duration_seconds + + # Convert end time to UTC timestamp for database + finished_time_utc = finished_time.replace(tzinfo=timezone.utc).isoformat() + + print(f"Load metrics: Duration={duration_seconds}s, Throughput={load_throughput:.2f} MB/s") + + return { + "duration_seconds": duration_seconds, + "throughput_mb_per_sec": load_throughput, + "end_timestamp": finished_time_utc, + } + + except FileNotFoundError: + print(f"Warning: Load log file not found: {log_file_path}") + return None + except Exception as e: + print(f"Error parsing load log file {log_file_path}: {e}") + return None + + +def insert_load_metrics(conn, load_metrics, suit, revision, platform, labels_json): + """Insert load metrics into perf_test_results table.""" + if not load_metrics: + print("No load metrics to insert") + return + + load_metrics_data = [ + { + "suit": suit, + "revision": revision, + "platform": platform, + "metric_name": "load_duration_seconds", + "metric_value": load_metrics["duration_seconds"], + "metric_unit": "seconds", + "metric_report_type": "lower_is_better", + "recorded_at_timestamp": load_metrics["end_timestamp"], + "labels": labels_json, + }, + { + "suit": suit, + "revision": revision, + "platform": platform, + "metric_name": "load_throughput", + "metric_value": load_metrics["throughput_mb_per_sec"], + "metric_unit": "MB/second", + "metric_report_type": "higher_is_better", + "recorded_at_timestamp": load_metrics["end_timestamp"], + "labels": labels_json, + }, + ] + + insert_query = """ + INSERT INTO perf_test_results + (suit, revision, platform, metric_name, metric_value, metric_unit, + metric_report_type, recorded_at_timestamp, labels) + VALUES (%(suit)s, %(revision)s, %(platform)s, %(metric_name)s, %(metric_value)s, + %(metric_unit)s, %(metric_report_type)s, %(recorded_at_timestamp)s, %(labels)s) + """ + + try: + with conn.cursor() as cursor: + cursor.executemany(insert_query, load_metrics_data) + conn.commit() + print(f"Successfully inserted {len(load_metrics_data)} load metrics into perf_test_results") + + except Exception as e: + print(f"Error inserting load metrics into database: {e}") + sys.exit(1) + + +def main(): + """Main function to parse arguments and upload results.""" + parser = argparse.ArgumentParser( + description="Upload BenchBase TPC-C results to perf_test_results database" + ) + parser.add_argument( + "--summary-json", type=str, required=False, help="Path to the summary.json file" + ) + parser.add_argument( + "--run-type", + type=str, + required=True, + choices=["warmup", "opt-rate", "ramp-up", "load"], + help="Type of benchmark run", + ) + parser.add_argument("--min-cu", type=float, required=True, help="Minimum compute units") + parser.add_argument("--max-cu", type=float, required=True, help="Maximum compute units") + parser.add_argument("--project-id", type=str, required=True, help="Neon project ID") + parser.add_argument( + "--revision", type=str, required=True, help="Git commit hash (40 characters)" + ) + parser.add_argument( + "--connection-string", type=str, required=True, help="PostgreSQL 
connection string" + ) + parser.add_argument( + "--results-csv", + type=str, + required=False, + help="Path to the results.csv file for detailed metrics upload", + ) + parser.add_argument( + "--load-log", + type=str, + required=False, + help="Path to the load log file for load phase metrics", + ) + parser.add_argument( + "--warehouses", + type=int, + required=False, + help="Number of warehouses (scalefactor) for load metrics calculation", + ) + + args = parser.parse_args() + + # Validate inputs + if args.summary_json and not Path(args.summary_json).exists(): + print(f"Error: Summary JSON file does not exist: {args.summary_json}") + sys.exit(1) + + if not args.summary_json and not args.load_log: + print("Error: Either summary JSON or load log file must be provided") + sys.exit(1) + + if len(args.revision) != 40: + print(f"Warning: Revision should be 40 characters, got {len(args.revision)}") + + # Load and process summary data if provided + summary_data = None + metrics = [] + + if args.summary_json: + summary_data = load_summary_json(args.summary_json) + metrics = extract_metrics(summary_data) + if not metrics: + print("Warning: No metrics found in summary JSON") + + # Build common data for all metrics + if summary_data: + scalefactor = summary_data.get("scalefactor", "unknown") + terminals = summary_data.get("terminals", "unknown") + labels = build_labels(summary_data, args.project_id) + else: + # For load-only processing, use warehouses argument as scalefactor + scalefactor = args.warehouses if args.warehouses else "unknown" + terminals = "unknown" + labels = {"project_id": args.project_id} + + suit = build_suit_name(scalefactor, terminals, args.run_type, args.min_cu, args.max_cu) + platform = f"prod-us-east-2-{args.project_id}" + + # Convert timestamp - only needed for summary metrics and CSV processing + current_timestamp_ms = None + start_timestamp_ms = None + recorded_at = None + + if summary_data: + current_timestamp_ms = summary_data.get("Current Timestamp (milliseconds)") + start_timestamp_ms = summary_data.get("Start timestamp (milliseconds)") + + if current_timestamp_ms: + recorded_at = convert_timestamp_to_utc(current_timestamp_ms) + else: + print("Warning: No timestamp found in JSON, using current time") + recorded_at = datetime.now(timezone.utc).isoformat() + + if not start_timestamp_ms: + print("Warning: No start timestamp found in JSON, CSV upload may be incorrect") + start_timestamp_ms = ( + current_timestamp_ms or datetime.now(timezone.utc).timestamp() * 1000 + ) + + # Print Grafana dashboard link for cross-service endpoint debugging + if start_timestamp_ms and current_timestamp_ms: + grafana_url = ( + f"https://neonprod.grafana.net/d/cdya0okb81zwga/cross-service-endpoint-debugging" + f"?orgId=1&from={int(start_timestamp_ms)}&to={int(current_timestamp_ms)}" + f"&timezone=utc&var-env=prod&var-input_project_id={args.project_id}" + ) + print(f'Cross service endpoint dashboard for "{args.run_type}" phase: {grafana_url}') + + # Prepare metrics data for database insertion (only if we have summary metrics) + metrics_data = [] + if metrics and recorded_at: + for metric_name, metric_value in metrics: + metric_info = get_metric_info(metric_name) + + row = { + "suit": suit, + "revision": args.revision, + "platform": platform, + "metric_name": metric_name, + "metric_value": float(metric_value), # Ensure numeric type + "metric_unit": metric_info["unit"], + "metric_report_type": metric_info["report_type"], + "recorded_at_timestamp": recorded_at, + "labels": json.dumps(labels), # Convert 
to JSON string for JSONB column + } + metrics_data.append(row) + + print(f"Prepared {len(metrics_data)} summary metrics for upload to database") + print(f"Suit: {suit}") + print(f"Platform: {platform}") + + # Connect to database and insert metrics + try: + conn = psycopg2.connect(args.connection_string) + + # Insert summary metrics into perf_test_results (if any) + if metrics_data: + insert_metrics(conn, metrics_data) + else: + print("No summary metrics to upload") + + # Process and insert detailed CSV results if provided + if args.results_csv: + print(f"Processing detailed CSV results from: {args.results_csv}") + + # Create table if it doesn't exist + create_benchbase_results_details_table(conn) + + # Process CSV data + csv_data = process_csv_results( + args.results_csv, start_timestamp_ms, suit, args.revision, platform + ) + + # Insert CSV data + if csv_data: + insert_csv_results(conn, csv_data) + else: + print("No CSV data to upload") + else: + print("No CSV file provided, skipping detailed results upload") + + # Process and insert load metrics if provided + if args.load_log: + print(f"Processing load metrics from: {args.load_log}") + + # Parse load log and extract metrics + load_metrics = parse_load_log(args.load_log, scalefactor) + + # Insert load metrics + if load_metrics: + insert_load_metrics( + conn, load_metrics, suit, args.revision, platform, json.dumps(labels) + ) + else: + print("No load metrics to upload") + else: + print("No load log file provided, skipping load metrics upload") + + conn.close() + print("Database upload completed successfully") + + except psycopg2.Error as e: + print(f"Database connection/query error: {e}") + sys.exit(1) + except Exception as e: + print(f"Unexpected error: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/test_runner/performance/test_lfc_prewarm.py b/test_runner/performance/test_lfc_prewarm.py index 6c0083de95..d459f9f3bf 100644 --- a/test_runner/performance/test_lfc_prewarm.py +++ b/test_runner/performance/test_lfc_prewarm.py @@ -2,45 +2,48 @@ from __future__ import annotations import os import timeit -import traceback -from concurrent.futures import ThreadPoolExecutor as Exec from pathlib import Path +from threading import Thread from time import sleep -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, cast import pytest from fixtures.benchmark_fixture import NeonBenchmarker, PgBenchRunResult from fixtures.log_helper import log -from fixtures.neon_api import NeonAPI, connection_parameters_to_env +from fixtures.neon_api import NeonAPI, connstr_to_env + +from performance.test_perf_pgbench import utc_now_timestamp if TYPE_CHECKING: from fixtures.compare_fixtures import NeonCompare from fixtures.neon_fixtures import Endpoint, PgBin from fixtures.pg_version import PgVersion -from performance.test_perf_pgbench import utc_now_timestamp # These tests compare performance for a write-heavy and read-heavy workloads of an ordinary endpoint -# compared to the endpoint which saves its LFC and prewarms using it on startup. 
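As a usage illustration for the upload script defined above: it is driven entirely by the argparse flags in its main(), and a minimal sketch of one invocation after the opt-rate phase is shown here. This is not part of the diff; the script filename and the PERF_TEST_RESULT_CONNSTR variable are assumed names for illustration, while the flag names match the argparse definition above and GITHUB_SHA/PROJECT_ID follow the workflow environment.

# Sketch only, not part of this change. Script name and connection-string variable are assumptions.
import os
import subprocess

cmd = [
    "python3",
    "upload_benchbase_results.py",  # assumed filename
    "--summary-json", "results/summary.json",
    "--results-csv", "results/results.csv",
    "--run-type", "opt-rate",
    "--warehouses", "1000",
    "--min-cu", "2",
    "--max-cu", "16",
    "--project-id", os.environ["PROJECT_ID"],
    "--revision", os.environ["GITHUB_SHA"],  # 40-character commit hash
    "--connection-string", os.environ["PERF_TEST_RESULT_CONNSTR"],  # assumed variable name
]
subprocess.run(cmd, check=True)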
+# compared to the endpoint which saves its LFC and prewarms using it on startup def test_compare_prewarmed_pgbench_perf(neon_compare: NeonCompare): env = neon_compare.env - env.create_branch("normal") env.create_branch("prewarmed") pg_bin = neon_compare.pg_bin - ep_normal: Endpoint = env.endpoints.create_start("normal") - ep_prewarmed: Endpoint = env.endpoints.create_start("prewarmed", autoprewarm=True) + ep_ordinary: Endpoint = neon_compare.endpoint + ep_prewarmed: Endpoint = env.endpoints.create_start("prewarmed") - for ep in [ep_normal, ep_prewarmed]: + for ep in [ep_ordinary, ep_prewarmed]: connstr: str = ep.connstr() pg_bin.run(["pgbench", "-i", "-I", "dtGvp", connstr, "-s100"]) - ep.safe_psql("CREATE EXTENSION neon") - client = ep.http_client() - client.offload_lfc() - ep.stop() - ep.start() - client.prewarm_lfc_wait() + ep.safe_psql("CREATE SCHEMA neon; CREATE EXTENSION neon WITH SCHEMA neon") + if ep == ep_prewarmed: + client = ep.http_client() + client.offload_lfc() + ep.stop() + ep.start(autoprewarm=True) + client.prewarm_lfc_wait() + else: + ep.stop() + ep.start() run_start_timestamp = utc_now_timestamp() t0 = timeit.default_timer() @@ -59,6 +62,36 @@ def test_compare_prewarmed_pgbench_perf(neon_compare: NeonCompare): neon_compare.zenbenchmark.record_pg_bench_result(name, res) +def test_compare_prewarmed_read_perf(neon_compare: NeonCompare): + env = neon_compare.env + env.create_branch("prewarmed") + ep_ordinary: Endpoint = neon_compare.endpoint + ep_prewarmed: Endpoint = env.endpoints.create_start("prewarmed") + + sql = [ + "CREATE SCHEMA neon", + "CREATE EXTENSION neon WITH SCHEMA neon", + "CREATE TABLE foo(key serial primary key, t text default 'foooooooooooooooooooooooooooooooooooooooooooooooooooo')", + "INSERT INTO foo SELECT FROM generate_series(1,1000000)", + ] + sql_check = "SELECT count(*) from foo" + + ep_ordinary.safe_psql_many(sql) + ep_ordinary.stop() + ep_ordinary.start() + with neon_compare.record_duration("ordinary_run_duration"): + ep_ordinary.safe_psql(sql_check) + + ep_prewarmed.safe_psql_many(sql) + client = ep_prewarmed.http_client() + client.offload_lfc() + ep_prewarmed.stop() + ep_prewarmed.start(autoprewarm=True) + client.prewarm_lfc_wait() + with neon_compare.record_duration("prewarmed_run_duration"): + ep_prewarmed.safe_psql(sql_check) + + @pytest.mark.remote_cluster @pytest.mark.timeout(2 * 60 * 60) def test_compare_prewarmed_pgbench_perf_benchmark( @@ -67,67 +100,66 @@ def test_compare_prewarmed_pgbench_perf_benchmark( pg_version: PgVersion, zenbenchmark: NeonBenchmarker, ): - name = f"Test prewarmed pgbench performance, GITHUB_RUN_ID={os.getenv('GITHUB_RUN_ID')}" - project = neon_api.create_project(pg_version, name) - project_id = project["project"]["id"] - neon_api.wait_for_operation_to_finish(project_id) - err = False - try: - benchmark_impl(pg_bin, neon_api, project, zenbenchmark) - except Exception as e: - err = True - log.error(f"Caught exception: {e}") - log.error(traceback.format_exc()) - finally: - assert not err - neon_api.delete_project(project_id) + """ + Prewarm API is not public, so this test relies on a pre-created project + with pgbench size of 3424, pgbench -i -IdtGvp -s3424. 
Sleeping and + offloading constants are hardcoded to this size as well + """ + project_id = os.getenv("PROJECT_ID") + assert project_id + ordinary_branch_id = "" + prewarmed_branch_id = "" + for branch in neon_api.get_branches(project_id)["branches"]: + if branch["name"] == "ordinary": + ordinary_branch_id = branch["id"] + if branch["name"] == "prewarmed": + prewarmed_branch_id = branch["id"] + assert len(ordinary_branch_id) > 0 + assert len(prewarmed_branch_id) > 0 + + ep_ordinary = None + ep_prewarmed = None + for ep in neon_api.get_endpoints(project_id)["endpoints"]: + if ep["branch_id"] == ordinary_branch_id: + ep_ordinary = ep + if ep["branch_id"] == prewarmed_branch_id: + ep_prewarmed = ep + assert ep_ordinary + assert ep_prewarmed + ordinary_id = ep_ordinary["id"] + prewarmed_id = ep_prewarmed["id"] -def benchmark_impl( - pg_bin: PgBin, neon_api: NeonAPI, project: dict[str, Any], zenbenchmark: NeonBenchmarker -): - pgbench_size = int(os.getenv("PGBENCH_SIZE") or "3424") # 50GB offload_secs = 20 - test_duration_min = 5 + test_duration_min = 3 pgbench_duration = f"-T{test_duration_min * 60}" - # prewarm API is not publicly exposed. In order to test performance of a - # fully prewarmed endpoint, wait after it restarts. - # The number here is empirical, based on manual runs on staging + pgbench_init_cmd = ["pgbench", "-P10", "-n", "-c10", pgbench_duration, "-Mprepared"] + pgbench_perf_cmd = pgbench_init_cmd + ["-S"] prewarmed_sleep_secs = 180 - branch_id = project["branch"]["id"] - project_id = project["project"]["id"] - normal_env = connection_parameters_to_env( - project["connection_uris"][0]["connection_parameters"] - ) - normal_id = project["endpoints"][0]["id"] - - prewarmed_branch_id = neon_api.create_branch( - project_id, "prewarmed", parent_id=branch_id, add_endpoint=False - )["branch"]["id"] - neon_api.wait_for_operation_to_finish(project_id) - - ep_prewarmed = neon_api.create_endpoint( - project_id, - prewarmed_branch_id, - endpoint_type="read_write", - settings={"autoprewarm": True, "offload_lfc_interval_seconds": offload_secs}, - ) - neon_api.wait_for_operation_to_finish(project_id) - - prewarmed_env = normal_env.copy() - prewarmed_env["PGHOST"] = ep_prewarmed["endpoint"]["host"] - prewarmed_id = ep_prewarmed["endpoint"]["id"] + ordinary_uri = neon_api.get_connection_uri(project_id, ordinary_branch_id, ordinary_id)["uri"] + prewarmed_uri = neon_api.get_connection_uri(project_id, prewarmed_branch_id, prewarmed_id)[ + "uri" + ] def bench(endpoint_name, endpoint_id, env): - pg_bin.run(["pgbench", "-i", "-I", "dtGvp", f"-s{pgbench_size}"], env) - sleep(offload_secs * 2) # ensure LFC is offloaded after pgbench finishes - neon_api.restart_endpoint(project_id, endpoint_id) - sleep(prewarmed_sleep_secs) + log.info(f"Running pgbench for {pgbench_duration}s to warm up the cache") + pg_bin.run_capture(pgbench_init_cmd, env) # capture useful for debugging + log.info(f"Initialized {endpoint_name}") + if endpoint_name == "prewarmed": + log.info(f"sleeping {offload_secs * 2} to ensure LFC is offloaded") + sleep(offload_secs * 2) + neon_api.restart_endpoint(project_id, endpoint_id) + log.info(f"sleeping {prewarmed_sleep_secs} to ensure LFC is prewarmed") + sleep(prewarmed_sleep_secs) + else: + neon_api.restart_endpoint(project_id, endpoint_id) + + log.info(f"Starting benchmark for {endpoint_name}") run_start_timestamp = utc_now_timestamp() t0 = timeit.default_timer() - out = pg_bin.run_capture(["pgbench", "-c10", pgbench_duration, "-Mprepared"], env) + out = 
pg_bin.run_capture(pgbench_perf_cmd, env) run_duration = timeit.default_timer() - t0 run_end_timestamp = utc_now_timestamp() @@ -140,29 +172,9 @@ def benchmark_impl( ) zenbenchmark.record_pg_bench_result(endpoint_name, res) - with Exec(max_workers=2) as exe: - exe.submit(bench, "normal", normal_id, normal_env) - exe.submit(bench, "prewarmed", prewarmed_id, prewarmed_env) + prewarmed_args = ("prewarmed", prewarmed_id, connstr_to_env(prewarmed_uri)) + prewarmed_thread = Thread(target=bench, args=prewarmed_args) + prewarmed_thread.start() - -def test_compare_prewarmed_read_perf(neon_compare: NeonCompare): - env = neon_compare.env - env.create_branch("normal") - env.create_branch("prewarmed") - ep_normal: Endpoint = env.endpoints.create_start("normal") - ep_prewarmed: Endpoint = env.endpoints.create_start("prewarmed", autoprewarm=True) - - sql = [ - "CREATE EXTENSION neon", - "CREATE TABLE foo(key serial primary key, t text default 'foooooooooooooooooooooooooooooooooooooooooooooooooooo')", - "INSERT INTO foo SELECT FROM generate_series(1,1000000)", - ] - for ep in [ep_normal, ep_prewarmed]: - ep.safe_psql_many(sql) - client = ep.http_client() - client.offload_lfc() - ep.stop() - ep.start() - client.prewarm_lfc_wait() - with neon_compare.record_duration(f"{ep.branch_name}_run_duration"): - ep.safe_psql("SELECT count(*) from foo") + bench("ordinary", ordinary_id, connstr_to_env(ordinary_uri)) + prewarmed_thread.join() diff --git a/test_runner/performance/test_perf_many_relations.py b/test_runner/performance/test_perf_many_relations.py index 81dae53759..9204a6a740 100644 --- a/test_runner/performance/test_perf_many_relations.py +++ b/test_runner/performance/test_perf_many_relations.py @@ -5,7 +5,7 @@ import pytest from fixtures.benchmark_fixture import NeonBenchmarker from fixtures.compare_fixtures import RemoteCompare from fixtures.log_helper import log -from fixtures.neon_fixtures import NeonEnvBuilder +from fixtures.neon_fixtures import NeonEnvBuilder, wait_for_last_flush_lsn from fixtures.utils import shared_buffers_for_max_cu @@ -69,13 +69,18 @@ def test_perf_many_relations(remote_compare: RemoteCompare, num_relations: int): ) -def test_perf_simple_many_relations_reldir_v2( - neon_env_builder: NeonEnvBuilder, zenbenchmark: NeonBenchmarker +@pytest.mark.parametrize( + "reldir,num_relations", + [("v1", 10000), ("v1v2", 10000), ("v2", 10000), ("v2", 100000)], + ids=["v1-small", "v1v2-small", "v2-small", "v2-large"], +) +def test_perf_simple_many_relations_reldir( + neon_env_builder: NeonEnvBuilder, zenbenchmark: NeonBenchmarker, reldir: str, num_relations: int ): """ Test creating many relations in a single database. 
""" - env = neon_env_builder.init_start(initial_tenant_conf={"rel_size_v2_enabled": "true"}) + env = neon_env_builder.init_start(initial_tenant_conf={"rel_size_v2_enabled": reldir != "v1"}) ep = env.endpoints.create_start( "main", config_lines=[ @@ -85,14 +90,38 @@ def test_perf_simple_many_relations_reldir_v2( ], ) - assert ( - env.pageserver.http_client().timeline_detail(env.initial_tenant, env.initial_timeline)[ - "rel_size_migration" - ] - != "legacy" - ) + ep.safe_psql("CREATE TABLE IF NOT EXISTS initial_table (v1 int)") + wait_for_last_flush_lsn(env, ep, env.initial_tenant, env.initial_timeline) - n = 100000 + if reldir == "v1": + assert ( + env.pageserver.http_client().timeline_detail(env.initial_tenant, env.initial_timeline)[ + "rel_size_migration" + ] + == "legacy" + ) + elif reldir == "v1v2": + assert ( + env.pageserver.http_client().timeline_detail(env.initial_tenant, env.initial_timeline)[ + "rel_size_migration" + ] + == "migrating" + ) + elif reldir == "v2": + # only read/write to the v2 keyspace + env.pageserver.http_client().timeline_patch_index_part( + env.initial_tenant, env.initial_timeline, {"rel_size_migration": "migrated"} + ) + assert ( + env.pageserver.http_client().timeline_detail(env.initial_tenant, env.initial_timeline)[ + "rel_size_migration" + ] + == "migrated" + ) + else: + raise AssertionError(f"Invalid reldir config: {reldir}") + + n = num_relations step = 5000 # Create many relations log.info(f"Creating {n} relations...") diff --git a/test_runner/regress/test_bad_connection.py b/test_runner/regress/test_bad_connection.py index d31c0c95d3..3c30296e6e 100644 --- a/test_runner/regress/test_bad_connection.py +++ b/test_runner/regress/test_bad_connection.py @@ -26,7 +26,7 @@ def test_compute_pageserver_connection_stress(neon_env_builder: NeonEnvBuilder): # Enable failpoint before starting everything else up so that we exercise the retry # on fetching basebackup pageserver_http = env.pageserver.http_client() - pageserver_http.configure_failpoints(("simulated-bad-compute-connection", "50%return(15)")) + pageserver_http.configure_failpoints(("simulated-bad-compute-connection", "20%return(15)")) env.create_branch("test_compute_pageserver_connection_stress") endpoint = env.endpoints.create_start("test_compute_pageserver_connection_stress") diff --git a/test_runner/regress/test_change_pageserver.py b/test_runner/regress/test_change_pageserver.py index b004db310c..af736af825 100644 --- a/test_runner/regress/test_change_pageserver.py +++ b/test_runner/regress/test_change_pageserver.py @@ -3,14 +3,35 @@ from __future__ import annotations import asyncio from typing import TYPE_CHECKING +import pytest from fixtures.log_helper import log +from fixtures.neon_fixtures import NeonEnvBuilder from fixtures.remote_storage import RemoteStorageKind if TYPE_CHECKING: - from fixtures.neon_fixtures import NeonEnvBuilder + from fixtures.neon_fixtures import Endpoint, NeonEnvBuilder -def test_change_pageserver(neon_env_builder: NeonEnvBuilder): +def reconfigure_endpoint(endpoint: Endpoint, pageserver_id: int, use_explicit_reconfigure: bool): + # It's important that we always update config.json before issuing any reconfigure requests + # to make sure that PG-initiated config refresh doesn't mess things up by reverting to the old config. + endpoint.update_pageservers_in_config(pageserver_id=pageserver_id) + + # PG will automatically refresh its configuration if it detects connectivity issues with pageservers. 
+ # We also allow the test to explicitly request a reconfigure so that the test can be sure that the + # endpoint is running with the latest configuration. + # + # Note that explicit reconfiguration is not required for the system to function or for this test to pass. + # It is kept for reference as this is how this test used to work before the capability of initiating + # configuration refreshes was added to compute nodes. + if use_explicit_reconfigure: + endpoint.reconfigure(pageserver_id=pageserver_id) + + +@pytest.mark.parametrize("use_explicit_reconfigure_for_failover", [False, True]) +def test_change_pageserver( + neon_env_builder: NeonEnvBuilder, use_explicit_reconfigure_for_failover: bool +): """ A relatively low level test of reconfiguring a compute's pageserver at runtime. Usually this is all done via the storage controller, but this test will disable the storage controller's compute @@ -72,7 +93,10 @@ def test_change_pageserver(neon_env_builder: NeonEnvBuilder): execute("SELECT count(*) FROM foo") assert fetchone() == (100000,) - endpoint.reconfigure(pageserver_id=alt_pageserver_id) + # Reconfigure the endpoint to use the alt pageserver. We issue an explicit reconfigure request here + # regardless of test mode as this is testing the externally driven reconfiguration scenario, not the + # compute-initiated reconfiguration scenario upon detecting failures. + reconfigure_endpoint(endpoint, pageserver_id=alt_pageserver_id, use_explicit_reconfigure=True) # Verify that the neon.pageserver_connstring GUC is set to the correct thing execute("SELECT setting FROM pg_settings WHERE name='neon.pageserver_connstring'") @@ -100,6 +124,12 @@ def test_change_pageserver(neon_env_builder: NeonEnvBuilder): env.storage_controller.node_configure(env.pageservers[1].id, {"availability": "Offline"}) env.storage_controller.reconcile_until_idle() + reconfigure_endpoint( + endpoint, + pageserver_id=env.pageservers[0].id, + use_explicit_reconfigure=use_explicit_reconfigure_for_failover, + ) + endpoint.reconfigure(pageserver_id=env.pageservers[0].id) execute("SELECT count(*) FROM foo") @@ -116,7 +146,11 @@ def test_change_pageserver(neon_env_builder: NeonEnvBuilder): await asyncio.sleep( 1 ) # Sleep for 1 second just to make sure we actually started our count(*) query - endpoint.reconfigure(pageserver_id=env.pageservers[1].id) + reconfigure_endpoint( + endpoint, + pageserver_id=env.pageservers[1].id, + use_explicit_reconfigure=use_explicit_reconfigure_for_failover, + ) def execute_count(): execute("SELECT count(*) FROM FOO") diff --git a/test_runner/regress/test_compaction.py b/test_runner/regress/test_compaction.py index 76485c8321..be82ee806f 100644 --- a/test_runner/regress/test_compaction.py +++ b/test_runner/regress/test_compaction.py @@ -58,7 +58,7 @@ PREEMPT_GC_COMPACTION_TENANT_CONF = { "compaction_upper_limit": 6, "lsn_lease_length": "0s", # Enable gc-compaction - "gc_compaction_enabled": "true", + "gc_compaction_enabled": True, "gc_compaction_initial_threshold_kb": 1024, # At a small threshold "gc_compaction_ratio_percent": 1, # No PiTR interval and small GC horizon @@ -540,7 +540,7 @@ def test_pageserver_gc_compaction_trigger(neon_env_builder: NeonEnvBuilder): "pitr_interval": "0s", "gc_horizon": f"{1024 * 16}", "lsn_lease_length": "0s", - "gc_compaction_enabled": "true", + "gc_compaction_enabled": True, "gc_compaction_initial_threshold_kb": "16", "gc_compaction_ratio_percent": "50", # Do not generate image layers with create_image_layers diff --git a/test_runner/regress/test_compatibility.py 
b/test_runner/regress/test_compatibility.py index 734887c5b3..635a040800 100644 --- a/test_runner/regress/test_compatibility.py +++ b/test_runner/regress/test_compatibility.py @@ -538,6 +538,7 @@ def test_historic_storage_formats( neon_env_builder.enable_pageserver_remote_storage(s3_storage()) neon_env_builder.pg_version = dataset.pg_version env = neon_env_builder.init_configs() + env.start() assert isinstance(env.pageserver_remote_storage, S3Storage) @@ -576,6 +577,17 @@ def test_historic_storage_formats( # All our artifacts should contain at least one timeline assert len(timelines) > 0 + if dataset.name == "2025-04-08-tenant-manifest-v1": + # This dataset was created at a time where we decided to migrate to v2 reldir by simply disabling writes to v1 + # and starting writing to v2. This was too risky and we have reworked the migration plan. Therefore, we should + # opt in full relv2 mode for this dataset. + for timeline in timelines: + env.pageserver.http_client().timeline_patch_index_part( + dataset.tenant_id, + timeline["timeline_id"], + {"force_index_update": True, "rel_size_migration": "migrated"}, + ) + # Import tenant does not create the timeline on safekeepers, # because it is a debug handler and the timeline may have already been # created on some set of safekeepers. diff --git a/test_runner/regress/test_compute_termination.py b/test_runner/regress/test_compute_termination.py new file mode 100644 index 0000000000..2d62ccf20f --- /dev/null +++ b/test_runner/regress/test_compute_termination.py @@ -0,0 +1,369 @@ +from __future__ import annotations + +import json +import os +import shutil +import subprocess +import threading +import time +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import TYPE_CHECKING + +import requests +from fixtures.log_helper import log +from typing_extensions import override + +if TYPE_CHECKING: + from typing import Any + + from fixtures.common_types import TenantId, TimelineId + from fixtures.neon_fixtures import NeonEnv + from fixtures.port_distributor import PortDistributor + + +def launch_compute_ctl( + env: NeonEnv, + endpoint_name: str, + external_http_port: int, + internal_http_port: int, + pg_port: int, + control_plane_port: int, +) -> subprocess.Popen[str]: + """ + Helper function to launch compute_ctl process with common configuration. + Returns the Popen process object. + """ + # Create endpoint directory structure following the standard pattern + endpoint_path = env.repo_dir / "endpoints" / endpoint_name + + # Clean up any existing endpoint directory to avoid conflicts + if endpoint_path.exists(): + shutil.rmtree(endpoint_path) + + endpoint_path.mkdir(mode=0o755, parents=True, exist_ok=True) + + # pgdata path - compute_ctl will create this directory during basebackup + pgdata_path = endpoint_path / "pgdata" + + # Create log file in endpoint directory + log_file = endpoint_path / "compute.log" + log_handle = open(log_file, "w") + + # Start compute_ctl pointing to our control plane + compute_ctl_path = env.neon_binpath / "compute_ctl" + connstr = f"postgresql://cloud_admin@localhost:{pg_port}/postgres" + + # Find postgres binary path + pg_bin_path = env.pg_distrib_dir / env.pg_version.v_prefixed / "bin" / "postgres" + pg_lib_path = env.pg_distrib_dir / env.pg_version.v_prefixed / "lib" + + env_vars = { + "INSTANCE_ID": "lakebase-instance-id", + "LD_LIBRARY_PATH": str(pg_lib_path), # Linux, etc. 
+ "DYLD_LIBRARY_PATH": str(pg_lib_path), # macOS + } + + cmd = [ + str(compute_ctl_path), + "--external-http-port", + str(external_http_port), + "--internal-http-port", + str(internal_http_port), + "--pgdata", + str(pgdata_path), + "--connstr", + connstr, + "--pgbin", + str(pg_bin_path), + "--compute-id", + endpoint_name, # Use endpoint_name as compute-id + "--control-plane-uri", + f"http://127.0.0.1:{control_plane_port}", + "--lakebase-mode", + "true", + ] + + print(f"Launching compute_ctl with command: {cmd}") + + # Start compute_ctl + process = subprocess.Popen( + cmd, + env=env_vars, + stdout=log_handle, + stderr=subprocess.STDOUT, # Combine stderr with stdout + text=True, + ) + + return process + + +def wait_for_compute_status( + compute_process: subprocess.Popen[str], + http_port: int, + expected_status: str, + timeout_seconds: int = 10, +) -> None: + """ + Wait for compute_ctl to reach the expected status. + Raises an exception if timeout is reached or process exits unexpectedly. + """ + start_time = time.time() + while time.time() - start_time < timeout_seconds: + try: + # Try to connect to the HTTP endpoint + response = requests.get(f"http://localhost:{http_port}/status", timeout=0.5) + if response.status_code == 200: + status_json = response.json() + # Check if it's in expected status + if status_json.get("status") == expected_status: + return + except (requests.ConnectionError, requests.Timeout): + pass + + # Check if process has exited + if compute_process.poll() is not None: + raise Exception( + f"compute_ctl exited unexpectedly with code {compute_process.returncode}." + ) + + time.sleep(0.5) + + # Timeout reached + compute_process.terminate() + raise Exception( + f"compute_ctl failed to reach {expected_status} status within {timeout_seconds} seconds." + ) + + +class EmptySpecHandler(BaseHTTPRequestHandler): + """HTTP handler that returns an Empty compute spec response""" + + def do_GET(self): + if self.path.startswith("/compute/api/v2/computes/") and self.path.endswith("/spec"): + # Return empty status which will put compute in Empty state + response: dict[str, Any] = { + "status": "empty", + "spec": None, + "compute_ctl_config": {"jwks": {"keys": []}}, + } + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps(response).encode()) + else: + self.send_error(404) + + @override + def log_message(self, format: str, *args: Any): + # Suppress request logging + pass + + +def test_compute_terminate_empty(neon_simple_env: NeonEnv, port_distributor: PortDistributor): + """ + Test that terminating a compute in Empty status works correctly. + + This tests the bug fix where terminating an Empty compute would hang + waiting for a non-existent postgres process to terminate. 
+ """ + env = neon_simple_env + + # Get ports for our test + control_plane_port = port_distributor.get_port() + external_http_port = port_distributor.get_port() + internal_http_port = port_distributor.get_port() + pg_port = port_distributor.get_port() + + # Start a simple HTTP server that will serve the Empty spec + server = HTTPServer(("127.0.0.1", control_plane_port), EmptySpecHandler) + server_thread = threading.Thread(target=server.serve_forever) + server_thread.daemon = True + server_thread.start() + + compute_process = None + try: + # Start compute_ctl with ephemeral tenant ID + compute_process = launch_compute_ctl( + env, + "test-empty-compute", + external_http_port, + internal_http_port, + pg_port, + control_plane_port, + ) + + # Wait for compute_ctl to start and report "empty" status + wait_for_compute_status(compute_process, external_http_port, "empty") + + # Now send terminate request + response = requests.post(f"http://localhost:{external_http_port}/terminate") + + # Verify that the termination request sends back a 200 OK response and is not abruptly terminated. + assert response.status_code == 200, ( + f"Expected 200 OK, got {response.status_code}: {response.text}" + ) + + # Wait for compute_ctl to exit + exit_code = compute_process.wait(timeout=10) + assert exit_code == 0, f"compute_ctl exited with non-zero code: {exit_code}" + + finally: + # Clean up + server.shutdown() + if compute_process and compute_process.poll() is None: + compute_process.terminate() + compute_process.wait() + + +class SwitchableConfigHandler(BaseHTTPRequestHandler): + """HTTP handler that can switch between normal compute configs and compute configs without specs""" + + return_empty_spec: bool = False + tenant_id: TenantId | None = None + timeline_id: TimelineId | None = None + pageserver_port: int | None = None + safekeeper_connstrs: list[str] | None = None + + def do_GET(self): + if self.path.startswith("/compute/api/v2/computes/") and self.path.endswith("/spec"): + if self.return_empty_spec: + # Return empty status + response: dict[str, object | None] = { + "status": "empty", + "spec": None, + "compute_ctl_config": { + "jwks": {"keys": []}, + }, + } + else: + # Return normal attached spec + response = { + "status": "attached", + "spec": { + "format_version": 1.0, + "cluster": { + "roles": [], + "databases": [], + "postgresql_conf": "shared_preload_libraries='neon'", + }, + "tenant_id": str(self.tenant_id) if self.tenant_id else "", + "timeline_id": str(self.timeline_id) if self.timeline_id else "", + "pageserver_connstring": f"postgres://no_user@localhost:{self.pageserver_port}" + if self.pageserver_port + else "", + "safekeeper_connstrings": self.safekeeper_connstrs or [], + "mode": "Primary", + "skip_pg_catalog_updates": True, + "reconfigure_concurrency": 1, + "suspend_timeout_seconds": -1, + }, + "compute_ctl_config": { + "jwks": {"keys": []}, + }, + } + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps(response).encode()) + else: + self.send_error(404) + + @override + def log_message(self, format: str, *args: Any): + # Suppress request logging + pass + + +def test_compute_empty_spec_during_refresh_configuration( + neon_simple_env: NeonEnv, port_distributor: PortDistributor +): + """ + Test that compute exits when it receives an empty spec during refresh configuration state. + + This test: + 1. Start compute with a normal spec + 2. Change the spec handler to return empty spec + 3. 
Trigger some condition to force compute to refresh configuration + 4. Verify that compute_ctl exits + """ + env = neon_simple_env + + # Get ports for our test + control_plane_port = port_distributor.get_port() + external_http_port = port_distributor.get_port() + internal_http_port = port_distributor.get_port() + pg_port = port_distributor.get_port() + + # Set up handler class variables + SwitchableConfigHandler.tenant_id = env.initial_tenant + SwitchableConfigHandler.timeline_id = env.initial_timeline + SwitchableConfigHandler.pageserver_port = env.pageserver.service_port.pg + # Convert comma-separated string to list + safekeeper_connstrs = env.get_safekeeper_connstrs() + if safekeeper_connstrs: + SwitchableConfigHandler.safekeeper_connstrs = safekeeper_connstrs.split(",") + else: + SwitchableConfigHandler.safekeeper_connstrs = [] + SwitchableConfigHandler.return_empty_spec = False # Start with normal spec + + # Start HTTP server with switchable spec handler + server = HTTPServer(("127.0.0.1", control_plane_port), SwitchableConfigHandler) + server_thread = threading.Thread(target=server.serve_forever) + server_thread.daemon = True + server_thread.start() + + compute_process = None + try: + # Start compute_ctl with tenant and timeline IDs + # Use a unique endpoint name to avoid conflicts + endpoint_name = f"test-refresh-compute-{os.getpid()}" + compute_process = launch_compute_ctl( + env, + endpoint_name, + external_http_port, + internal_http_port, + pg_port, + control_plane_port, + ) + + # Wait for compute_ctl to start and report "running" status + wait_for_compute_status(compute_process, external_http_port, "running", timeout_seconds=30) + + log.info("Compute is running. Now returning empty spec and trigger configuration refresh.") + + # Switch spec fetch handler to return empty spec + SwitchableConfigHandler.return_empty_spec = True + + # Trigger a configuration refresh + try: + requests.post(f"http://localhost:{internal_http_port}/refresh_configuration") + except requests.RequestException as e: + log.info(f"Call to /refresh_configuration failed: {e}") + log.info( + "Ignoring the error, assuming that compute_ctl is already refreshing or has exited" + ) + + # Wait for compute_ctl to exit (it should exit when it gets an empty spec during refresh) + exit_start_time = time.time() + while time.time() - exit_start_time < 30: + if compute_process.poll() is not None: + # Process exited + break + time.sleep(0.5) + + # Verify that compute_ctl exited + exit_code = compute_process.poll() + if exit_code is None: + compute_process.terminate() + raise Exception("compute_ctl did not exit after receiving empty spec.") + + # The exit code might not be 0 in this case since it's an unexpected termination + # but we mainly care that it did exit + assert exit_code is not None, "compute_ctl should have exited" + + finally: + # Clean up + server.shutdown() + if compute_process and compute_process.poll() is None: + compute_process.terminate() + compute_process.wait() diff --git a/test_runner/regress/test_hadron_ps_connectivity_metrics.py b/test_runner/regress/test_hadron_ps_connectivity_metrics.py new file mode 100644 index 0000000000..ff1f37b634 --- /dev/null +++ b/test_runner/regress/test_hadron_ps_connectivity_metrics.py @@ -0,0 +1,137 @@ +import json +import shutil + +from fixtures.common_types import TenantShardId +from fixtures.log_helper import log +from fixtures.metrics import parse_metrics +from fixtures.neon_fixtures import Endpoint, NeonEnvBuilder, NeonPageserver +from requests.exceptions import 
ConnectionError + + +# Helper function to attempt reconfiguration of the compute to point to a new pageserver. Note that in these tests, +# we don't expect the reconfiguration attempts to go through, as we will be pointing the compute at a "wrong" pageserver. +def _attempt_reconfiguration(endpoint: Endpoint, new_pageserver_id: int, timeout_sec: float): + try: + endpoint.reconfigure(pageserver_id=new_pageserver_id, timeout_sec=timeout_sec) + except Exception as e: + log.info(f"reconfiguration failed with exception {e}") + pass + + +def read_misrouted_metric_value(pageserver: NeonPageserver) -> float: + return ( + pageserver.http_client() + .get_metrics() + .query_one("pageserver_misrouted_pagestream_requests_total") + .value + ) + + +def read_request_error_metric_value(endpoint: Endpoint) -> float: + return ( + parse_metrics(endpoint.http_client().metrics()) + .query_one("pg_cctl_pagestream_request_errors_total") + .value + ) + + +def test_misrouted_to_secondary( + neon_env_builder: NeonEnvBuilder, +): + """ + Tests that the following metrics are incremented when compute tries to talk to a secondary pageserver: + - On pageserver receiving the request: pageserver_misrouted_pagestream_requests_total + - On compute: pg_cctl_pagestream_request_errors_total + """ + neon_env_builder.num_pageservers = 2 + env = neon_env_builder.init_configs() + env.broker.start() + env.storage_controller.start() + for ps in env.pageservers: + ps.start() + for sk in env.safekeepers: + sk.start() + + # Create a tenant that has one primary and one secondary. Due to primary/secondary placement constraints, + # the primary and secondary pageservers will be different. + tenant_id, _ = env.create_tenant(shard_count=1, placement_policy=json.dumps({"Attached": 1})) + endpoint = env.endpoints.create( + "main", tenant_id=tenant_id, config_lines=["neon.lakebase_mode = true"] + ) + endpoint.respec(skip_pg_catalog_updates=False) + endpoint.start() + + # Get the primary pageserver serving the zero shard of the tenant, and detach it from the primary pageserver. + # This test operation configures tenant directly on the pageserver/does not go through the storage controller, + # so the compute does not get any notifications and will keep pointing at the detached pageserver. + tenant_zero_shard = TenantShardId(tenant_id, shard_number=0, shard_count=1) + + primary_ps = env.get_tenant_pageserver(tenant_zero_shard) + secondary_ps = ( + env.pageservers[1] if primary_ps.id == env.pageservers[0].id else env.pageservers[0] + ) + + # Now try to point the compute at the pageserver that is acting as secondary for the tenant. Test that the metrics + # on both compute_ctl and the pageserver register the misrouted requests following the reconfiguration attempt. + assert read_misrouted_metric_value(secondary_ps) == 0 + assert read_request_error_metric_value(endpoint) == 0 + _attempt_reconfiguration(endpoint, new_pageserver_id=secondary_ps.id, timeout_sec=2.0) + assert read_misrouted_metric_value(secondary_ps) > 0 + try: + assert read_request_error_metric_value(endpoint) > 0 + except ConnectionError: + # When configuring PG to use misconfigured pageserver, PG will cancel the query after certain number of failed + # reconfigure attempts. This will cause compute_ctl to exit. 
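The assertions in this file follow a read-metric, act, read-metric pattern. A small helper along these lines could express that pattern once; it is a sketch only, not part of this change, and relies on the reader functions defined above.

# Sketch only: assert that a counter read by read_fn grows after running action.
from collections.abc import Callable

def assert_counter_increases(read_fn: Callable[[], float], action: Callable[[], None]) -> None:
    before = read_fn()
    action()
    after = read_fn()
    assert after > before, f"expected counter to increase, got {before} -> {after}"

# Example usage with the helpers above:
# assert_counter_increases(
#     lambda: read_misrouted_metric_value(secondary_ps),
#     lambda: _attempt_reconfiguration(endpoint, new_pageserver_id=secondary_ps.id, timeout_sec=2.0),
# )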
+ log.info("Cannot connect to PG, ignoring") + pass + + +def test_misrouted_to_ps_not_hosting_tenant( + neon_env_builder: NeonEnvBuilder, +): + """ + Tests that the following metrics are incremented when compute tries to talk to a pageserver that does not host the tenant: + - On pageserver receiving the request: pageserver_misrouted_pagestream_requests_total + - On compute: pg_cctl_pagestream_request_errors_total + """ + neon_env_builder.num_pageservers = 2 + env = neon_env_builder.init_configs() + env.broker.start() + env.storage_controller.start(handle_ps_local_disk_loss=False) + for ps in env.pageservers: + ps.start() + for sk in env.safekeepers: + sk.start() + + tenant_id, _ = env.create_tenant(shard_count=1) + endpoint = env.endpoints.create( + "main", tenant_id=tenant_id, config_lines=["neon.lakebase_mode = true"] + ) + endpoint.respec(skip_pg_catalog_updates=False) + endpoint.start() + + tenant_ps_id = env.get_tenant_pageserver( + TenantShardId(tenant_id, shard_number=0, shard_count=1) + ).id + non_hosting_ps = ( + env.pageservers[1] if tenant_ps_id == env.pageservers[0].id else env.pageservers[0] + ) + + # Clear the disk of the non-hosting PS to make sure that it indeed doesn't have any information about the tenant. + non_hosting_ps.stop(immediate=True) + shutil.rmtree(non_hosting_ps.tenant_dir()) + non_hosting_ps.start() + + # Now try to point the compute to the non-hosting pageserver. Test that the metrics + # on both compute_ctl and the pageserver register the misrouted requests following the reconfiguration attempt. + assert read_misrouted_metric_value(non_hosting_ps) == 0 + assert read_request_error_metric_value(endpoint) == 0 + _attempt_reconfiguration(endpoint, new_pageserver_id=non_hosting_ps.id, timeout_sec=2.0) + assert read_misrouted_metric_value(non_hosting_ps) > 0 + try: + assert read_request_error_metric_value(endpoint) > 0 + except ConnectionError: + # When configuring PG to use misconfigured pageserver, PG will cancel the query after certain number of failed + # reconfigure attempts. This will cause compute_ctl to exit. + log.info("Cannot connect to PG, ignoring") + pass diff --git a/test_runner/regress/test_hot_standby.py b/test_runner/regress/test_hot_standby.py index 1ff61ce8dc..a329a5e842 100644 --- a/test_runner/regress/test_hot_standby.py +++ b/test_runner/regress/test_hot_standby.py @@ -133,6 +133,9 @@ def test_hot_standby_gc(neon_env_builder: NeonEnvBuilder, pause_apply: bool): tenant_conf = { # set PITR interval to be small, so we can do GC "pitr_interval": "0 s", + # we want to control gc and checkpoint frequency precisely + "gc_period": "0s", + "compaction_period": "0s", } env = neon_env_builder.init_start(initial_tenant_conf=tenant_conf) timeline_id = env.initial_timeline @@ -186,6 +189,23 @@ def test_hot_standby_gc(neon_env_builder: NeonEnvBuilder, pause_apply: bool): client = pageserver.http_client() client.timeline_checkpoint(tenant_shard_id, timeline_id) client.timeline_compact(tenant_shard_id, timeline_id) + # Wait for standby horizon to get propagated. + # This shouldn't be necessary, but the current mechanism for + # standby_horizon propagation is imperfect. 
Detailed
+    # description in https://databricks.atlassian.net/browse/LKB-2499
+    while True:
+        val = client.get_metric_value(
+            "pageserver_standby_horizon",
+            {
+                "tenant_id": str(tenant_shard_id.tenant_id),
+                "shard_id": str(tenant_shard_id.shard_index),
+                "timeline_id": str(timeline_id),
+            },
+        )
+        log.info(f"waiting for next standby_horizon push from safekeeper, {val=}")
+        if val != 0:
+            break
+        time.sleep(0.1)
     client.timeline_gc(tenant_shard_id, timeline_id, 0)
 
     # Re-execute the query. The GetPage requests that this
diff --git a/test_runner/regress/test_lfc_prewarm.py b/test_runner/regress/test_lfc_prewarm.py
index 0f0cf4cc6d..2bbe8c3e97 100644
--- a/test_runner/regress/test_lfc_prewarm.py
+++ b/test_runner/regress/test_lfc_prewarm.py
@@ -164,6 +164,25 @@ def test_lfc_prewarm(neon_simple_env: NeonEnv, method: PrewarmMethod):
         check_prewarmed(method, client, desired)
 
 
+@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping")
+def test_lfc_prewarm_empty(neon_simple_env: NeonEnv):
+    """
+    Test there are no errors when trying to offload or prewarm endpoint without cache using compute_ctl.
+    Endpoint without cache is simulated by turning off LFC manually, but in cloud/ setup this is
+    also reproduced on fresh endpoints
+    """
+    env = neon_simple_env
+    ep = env.endpoints.create_start("main", config_lines=["neon.file_cache_size_limit=0"])
+    client = ep.http_client()
+    conn = ep.connect()
+    cur = conn.cursor()
+    cur.execute("create schema neon; create extension neon with schema neon")
+    method = PrewarmMethod.COMPUTE_CTL
+    offload_lfc(method, client, cur)
+    prewarm_endpoint(method, client, cur, None)
+    assert client.prewarm_lfc_status()["status"] == "skipped"
+
+
 # autoprewarm isn't needed as we prewarm manually
 WORKLOAD_VALUES = METHOD_VALUES[:-1]
 WORKLOAD_IDS = METHOD_IDS[:-1]
diff --git a/test_runner/regress/test_ondemand_slru_download.py b/test_runner/regress/test_ondemand_slru_download.py
index f0f12290cc..607a2921a9 100644
--- a/test_runner/regress/test_ondemand_slru_download.py
+++ b/test_runner/regress/test_ondemand_slru_download.py
@@ -16,7 +16,7 @@ def test_ondemand_download_pg_xact(neon_env_builder: NeonEnvBuilder, shard_count
     neon_env_builder.num_pageservers = shard_count
 
     tenant_conf = {
-        "lazy_slru_download": "true",
+        "lazy_slru_download": True,
         # set PITR interval to be small, so we can do GC
         "pitr_interval": "0 s",
     }
@@ -82,7 +82,7 @@ def test_ondemand_download_replica(neon_env_builder: NeonEnvBuilder, shard_count
     neon_env_builder.num_pageservers = shard_count
 
     tenant_conf = {
-        "lazy_slru_download": "true",
+        "lazy_slru_download": True,
     }
     env = neon_env_builder.init_start(
         initial_tenant_conf=tenant_conf, initial_tenant_shard_count=shard_count
@@ -141,7 +141,7 @@ def test_ondemand_download_after_wal_switch(neon_env_builder: NeonEnvBuilder):
     """
 
     tenant_conf = {
-        "lazy_slru_download": "true",
+        "lazy_slru_download": True,
     }
     env = neon_env_builder.init_start(initial_tenant_conf=tenant_conf)
diff --git a/test_runner/regress/test_pg_regress.py b/test_runner/regress/test_pg_regress.py
index dd9c5437ad..cc7f736239 100644
--- a/test_runner/regress/test_pg_regress.py
+++ b/test_runner/regress/test_pg_regress.py
@@ -395,23 +395,6 @@ def test_max_wal_rate(neon_simple_env: NeonEnv):
     tuples = endpoint.safe_psql("SELECT backpressure_throttling_time();")
     assert tuples[0][0] == 0, "Backpressure throttling detected"
 
-    # 0 MB/s max_wal_rate. WAL proposer can still push some WALs but will be super slow.
- endpoint.safe_psql_many( - [ - "ALTER SYSTEM SET databricks.max_wal_mb_per_second = 0;", - "SELECT pg_reload_conf();", - ] - ) - - # Write ~10 KB data should hit backpressure. - with endpoint.cursor(dbname=DBNAME) as cur: - cur.execute("SET databricks.max_wal_mb_per_second = 0;") - for _ in range(0, 10): - cur.execute("INSERT INTO usertable SELECT random(), repeat('a', 1000);") - - tuples = endpoint.safe_psql("SELECT backpressure_throttling_time();") - assert tuples[0][0] > 0, "No backpressure throttling detected" - # 1 MB/s max_wal_rate. endpoint.safe_psql_many( [ @@ -457,21 +440,6 @@ def test_tx_abort_with_many_relations( ], ) - if reldir_type == "v1": - assert ( - env.pageserver.http_client().timeline_detail(env.initial_tenant, env.initial_timeline)[ - "rel_size_migration" - ] - == "legacy" - ) - else: - assert ( - env.pageserver.http_client().timeline_detail(env.initial_tenant, env.initial_timeline)[ - "rel_size_migration" - ] - != "legacy" - ) - # How many relations: this number is tuned to be long enough to take tens of seconds # if the rollback code path is buggy, tripping the test's timeout. n = 5000 @@ -556,3 +524,19 @@ def test_tx_abort_with_many_relations( except: exec.shutdown(wait=False, cancel_futures=True) raise + + # Do the check after everything is done, because the reldirv2 transition won't happen until create table. + if reldir_type == "v1": + assert ( + env.pageserver.http_client().timeline_detail(env.initial_tenant, env.initial_timeline)[ + "rel_size_migration" + ] + == "legacy" + ) + else: + assert ( + env.pageserver.http_client().timeline_detail(env.initial_tenant, env.initial_timeline)[ + "rel_size_migration" + ] + != "legacy" + ) diff --git a/test_runner/regress/test_proxy.py b/test_runner/regress/test_proxy.py index 9860658ba5..dadaf8a1cf 100644 --- a/test_runner/regress/test_proxy.py +++ b/test_runner/regress/test_proxy.py @@ -17,9 +17,6 @@ if TYPE_CHECKING: from typing import Any -GET_CONNECTION_PID_QUERY = "SELECT pid FROM pg_stat_activity WHERE state = 'active'" - - @pytest.mark.asyncio async def test_http_pool_begin_1(static_proxy: NeonProxy): static_proxy.safe_psql("create user http_auth with password 'http' superuser") @@ -479,7 +476,7 @@ def test_sql_over_http_pool(static_proxy: NeonProxy): def get_pid(status: int, pw: str, user="http_auth") -> Any: return static_proxy.http_query( - GET_CONNECTION_PID_QUERY, + "SELECT pg_backend_pid() as pid", [], user=user, password=pw, @@ -513,6 +510,35 @@ def test_sql_over_http_pool(static_proxy: NeonProxy): assert "password authentication failed for user" in res["message"] +def test_sql_over_http_pool_settings(static_proxy: NeonProxy): + static_proxy.safe_psql("create user http_auth with password 'http' superuser") + + def multiquery(*queries) -> Any: + results = static_proxy.http_multiquery( + *queries, + user="http_auth", + password="http", + expected_code=200, + ) + + return [result["rows"] for result in results["results"]] + + [[intervalstyle]] = static_proxy.safe_psql("SHOW IntervalStyle") + assert intervalstyle == "postgres", "'postgres' is the default IntervalStyle in postgres" + + result = multiquery("select '0 seconds'::interval as interval") + assert result[0][0]["interval"] == "00:00:00", "interval is expected in postgres format" + + result = multiquery( + "SET IntervalStyle = 'iso_8601'", + "select '0 seconds'::interval as interval", + ) + assert result[1][0]["interval"] == "PT0S", "interval is expected in ISO-8601 format" + + result = multiquery("select '0 seconds'::interval as interval") + assert 
result[0][0]["interval"] == "00:00:00", "interval is expected in postgres format" + + def test_sql_over_http_urlencoding(static_proxy: NeonProxy): static_proxy.safe_psql("create user \"http+auth$$\" with password '%+$^&*@!' superuser") @@ -544,23 +570,37 @@ def test_http_pool_begin(static_proxy: NeonProxy): query(200, "SELECT 1;") # Query that should succeed regardless of the transaction -def test_sql_over_http_pool_idle(static_proxy: NeonProxy): +def test_sql_over_http_pool_tx_reuse(static_proxy: NeonProxy): static_proxy.safe_psql("create user http_auth2 with password 'http' superuser") - def query(status: int, query: str) -> Any: + def query(status: int, query: str, *args) -> Any: return static_proxy.http_query( query, - [], + args, user="http_auth2", password="http", expected_code=status, ) - pid1 = query(200, GET_CONNECTION_PID_QUERY)["rows"][0]["pid"] + def query_pid_txid() -> Any: + result = query( + 200, + "SELECT pg_backend_pid() as pid, pg_current_xact_id() as txid", + ) + + return result["rows"][0] + + res0 = query_pid_txid() + time.sleep(0.02) query(200, "BEGIN") - pid2 = query(200, GET_CONNECTION_PID_QUERY)["rows"][0]["pid"] - assert pid1 != pid2 + + res1 = query_pid_txid() + res2 = query_pid_txid() + + assert res0["pid"] == res1["pid"], "connection should be reused" + assert res0["pid"] == res2["pid"], "connection should be reused" + assert res1["txid"] != res2["txid"], "txid should be different" @pytest.mark.timeout(60) diff --git a/test_runner/regress/test_relations.py b/test_runner/regress/test_relations.py index b2ddcb1c2e..6263ced9df 100644 --- a/test_runner/regress/test_relations.py +++ b/test_runner/regress/test_relations.py @@ -7,6 +7,8 @@ if TYPE_CHECKING: NeonEnvBuilder, ) +from fixtures.neon_fixtures import wait_for_last_flush_lsn + def test_pageserver_reldir_v2( neon_env_builder: NeonEnvBuilder, @@ -65,6 +67,8 @@ def test_pageserver_reldir_v2( endpoint.safe_psql("CREATE TABLE foo4 (id INTEGER PRIMARY KEY, val text)") # Delete a relation in v1 endpoint.safe_psql("DROP TABLE foo1") + # wait pageserver to apply the LSN + wait_for_last_flush_lsn(env, endpoint, env.initial_tenant, env.initial_timeline) # Check if both relations are still accessible endpoint.safe_psql("SELECT * FROM foo2") @@ -76,12 +80,16 @@ def test_pageserver_reldir_v2( # This will acquire a basebackup, which lists all relations. 
endpoint.start() - # Check if both relations are still accessible + # Check if both relations are still accessible after restart endpoint.safe_psql("DROP TABLE IF EXISTS foo1") endpoint.safe_psql("SELECT * FROM foo2") endpoint.safe_psql("SELECT * FROM foo3") endpoint.safe_psql("SELECT * FROM foo4") endpoint.safe_psql("DROP TABLE foo3") + # wait pageserver to apply the LSN + wait_for_last_flush_lsn(env, endpoint, env.initial_tenant, env.initial_timeline) + + # Restart the endpoint again endpoint.stop() endpoint.start() @@ -99,6 +107,9 @@ def test_pageserver_reldir_v2( }, ) + endpoint.stop() + endpoint.start() + # Check if the relation is still accessible endpoint.safe_psql("SELECT * FROM foo2") endpoint.safe_psql("SELECT * FROM foo4") @@ -111,3 +122,10 @@ def test_pageserver_reldir_v2( ] == "migrating" ) + + assert ( + env.pageserver.http_client().timeline_detail(env.initial_tenant, env.initial_timeline)[ + "rel_size_migrated_at" + ] + is not None + ) diff --git a/test_runner/regress/test_replica_promotes.py b/test_runner/regress/test_replica_promotes.py index 8d39ac123a..9415d6886c 100644 --- a/test_runner/regress/test_replica_promotes.py +++ b/test_runner/regress/test_replica_promotes.py @@ -90,6 +90,7 @@ def test_replica_promote(neon_simple_env: NeonEnv, method: PromoteMethod): secondary_cur.execute("select count(*) from t") assert secondary_cur.fetchone() == (100,) + primary_spec = primary.get_compute_spec() primary_endpoint_id = primary.endpoint_id stop_and_check_lsn(primary, expected_primary_lsn) @@ -99,10 +100,9 @@ def test_replica_promote(neon_simple_env: NeonEnv, method: PromoteMethod): if method == PromoteMethod.COMPUTE_CTL: client = secondary.http_client() client.prewarm_lfc(primary_endpoint_id) - # control plane knows safekeepers, simulate it by querying primary assert (lsn := primary.terminate_flush_lsn) - safekeepers_lsn = {"safekeepers": safekeepers, "wal_flush_lsn": lsn} - assert client.promote(safekeepers_lsn)["status"] == "completed" + promote_spec = {"spec": primary_spec, "wal_flush_lsn": str(lsn)} + assert client.promote(promote_spec)["status"] == "completed" else: promo_cur.execute(f"alter system set neon.safekeepers='{safekeepers}'") promo_cur.execute("select pg_reload_conf()") @@ -131,21 +131,35 @@ def test_replica_promote(neon_simple_env: NeonEnv, method: PromoteMethod): lsn_triple = get_lsn_triple(new_primary_cur) log.info(f"Secondary: LSN after workload is {lsn_triple}") - expected_promoted_lsn = Lsn(lsn_triple[2]) + expected_lsn = Lsn(lsn_triple[2]) with secondary.connect() as conn, conn.cursor() as new_primary_cur: new_primary_cur.execute("select payload from t") assert new_primary_cur.fetchall() == [(it,) for it in range(1, 201)] if method == PromoteMethod.COMPUTE_CTL: - # compute_ctl's /promote switches replica type to Primary so it syncs - # safekeepers on finish - stop_and_check_lsn(secondary, expected_promoted_lsn) + # compute_ctl's /promote switches replica type to Primary so it syncs safekeepers on finish + stop_and_check_lsn(secondary, expected_lsn) else: - # on testing postgres, we don't update replica type, secondaries don't - # sync so lsn should be None + # on testing postgres, we don't update replica type, secondaries don't sync so lsn should be None stop_and_check_lsn(secondary, None) + if method == PromoteMethod.COMPUTE_CTL: + secondary.stop() + # In production, compute ultimately receives new compute spec from cplane. 
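+        # The test approximates that here by locally respec'ing the secondary to Primary before restarting it.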
+ secondary.respec(mode="Primary") + secondary.start() + + with secondary.connect() as conn, conn.cursor() as new_primary_cur: + new_primary_cur.execute( + "INSERT INTO t (payload) SELECT generate_series(101, 200) RETURNING payload" + ) + assert new_primary_cur.fetchall() == [(it,) for it in range(101, 201)] + lsn_triple = get_lsn_triple(new_primary_cur) + log.info(f"Secondary: LSN after restart and workload is {lsn_triple}") + expected_lsn = Lsn(lsn_triple[2]) + stop_and_check_lsn(secondary, expected_lsn) + primary = env.endpoints.create_start(branch_name="main", endpoint_id="primary2") with primary.connect() as new_primary, new_primary.cursor() as new_primary_cur: @@ -154,10 +168,11 @@ def test_replica_promote(neon_simple_env: NeonEnv, method: PromoteMethod): log.info(f"New primary: Boot LSN is {lsn_triple}") new_primary_cur.execute("select count(*) from t") - assert new_primary_cur.fetchone() == (200,) + compute_ctl_count = 100 * (method == PromoteMethod.COMPUTE_CTL) + assert new_primary_cur.fetchone() == (200 + compute_ctl_count,) new_primary_cur.execute("INSERT INTO t (payload) SELECT generate_series(201, 300)") new_primary_cur.execute("select count(*) from t") - assert new_primary_cur.fetchone() == (300,) + assert new_primary_cur.fetchone() == (300 + compute_ctl_count,) stop_and_check_lsn(primary, expected_primary_lsn) @@ -175,18 +190,91 @@ def test_replica_promote_handler_disconnects(neon_simple_env: NeonEnv): cur.execute("create schema neon;create extension neon with schema neon") cur.execute("create table t(pk bigint GENERATED ALWAYS AS IDENTITY, payload integer)") cur.execute("INSERT INTO t(payload) SELECT generate_series(1, 100)") - cur.execute("show neon.safekeepers") - safekeepers = cur.fetchall()[0][0] primary.http_client().offload_lfc() + primary_spec = primary.get_compute_spec() primary_endpoint_id = primary.endpoint_id primary.stop(mode="immediate-terminate") assert (lsn := primary.terminate_flush_lsn) client = secondary.http_client() client.prewarm_lfc(primary_endpoint_id) - safekeepers_lsn = {"safekeepers": safekeepers, "wal_flush_lsn": lsn} - assert client.promote(safekeepers_lsn, disconnect=True)["status"] == "completed" + promote_spec = {"spec": primary_spec, "wal_flush_lsn": str(lsn)} + assert client.promote(promote_spec, disconnect=True)["status"] == "completed" + + with secondary.connect() as conn, conn.cursor() as cur: + cur.execute("select count(*) from t") + assert cur.fetchone() == (100,) + cur.execute("INSERT INTO t (payload) SELECT generate_series(101, 200) RETURNING payload") + cur.execute("select count(*) from t") + assert cur.fetchone() == (200,) + + +@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping") +def test_replica_promote_fails(neon_simple_env: NeonEnv): + """ + Test that if a /promote route fails, we can safely start primary back + """ + env: NeonEnv = neon_simple_env + primary: Endpoint = env.endpoints.create_start(branch_name="main", endpoint_id="primary") + secondary: Endpoint = env.endpoints.new_replica_start(origin=primary, endpoint_id="secondary") + secondary.stop() + secondary.start(env={"FAILPOINTS": "compute-promotion=return(0)"}) + + with primary.connect() as conn, conn.cursor() as cur: + cur.execute("create schema neon;create extension neon with schema neon") + cur.execute("create table t(pk bigint GENERATED ALWAYS AS IDENTITY, payload integer)") + cur.execute("INSERT INTO t(payload) SELECT generate_series(1, 100)") + + primary.http_client().offload_lfc() + primary_spec = primary.get_compute_spec() + primary_endpoint_id = 
primary.endpoint_id + primary.stop(mode="immediate-terminate") + assert (lsn := primary.terminate_flush_lsn) + + client = secondary.http_client() + client.prewarm_lfc(primary_endpoint_id) + promote_spec = {"spec": primary_spec, "wal_flush_lsn": str(lsn)} + assert client.promote(promote_spec)["status"] == "failed" + secondary.stop() + + primary.start() + with primary.connect() as conn, conn.cursor() as cur: + cur.execute("select count(*) from t") + assert cur.fetchone() == (100,) + cur.execute("INSERT INTO t (payload) SELECT generate_series(101, 200) RETURNING payload") + cur.execute("select count(*) from t") + assert cur.fetchone() == (200,) + + +@pytest.mark.skipif(not USE_LFC, reason="LFC is disabled, skipping") +def test_replica_promote_prewarm_fails(neon_simple_env: NeonEnv): + """ + Test that if /lfc/prewarm route fails, we are able to promote + """ + env: NeonEnv = neon_simple_env + primary: Endpoint = env.endpoints.create_start(branch_name="main", endpoint_id="primary") + secondary: Endpoint = env.endpoints.new_replica_start(origin=primary, endpoint_id="secondary") + secondary.stop() + secondary.start(env={"FAILPOINTS": "compute-prewarm=return(0)"}) + + with primary.connect() as conn, conn.cursor() as cur: + cur.execute("create schema neon;create extension neon with schema neon") + cur.execute("create table t(pk bigint GENERATED ALWAYS AS IDENTITY, payload integer)") + cur.execute("INSERT INTO t(payload) SELECT generate_series(1, 100)") + + primary.http_client().offload_lfc() + primary_spec = primary.get_compute_spec() + primary_endpoint_id = primary.endpoint_id + primary.stop(mode="immediate-terminate") + assert (lsn := primary.terminate_flush_lsn) + + client = secondary.http_client() + with pytest.raises(AssertionError): + client.prewarm_lfc(primary_endpoint_id) + assert client.prewarm_lfc_status()["status"] == "failed" + promote_spec = {"spec": primary_spec, "wal_flush_lsn": str(lsn)} + assert client.promote(promote_spec)["status"] == "completed" with secondary.connect() as conn, conn.cursor() as cur: cur.execute("select count(*) from t") diff --git a/test_runner/regress/test_sharding.py b/test_runner/regress/test_sharding.py index 2252c098c7..4e46b67988 100644 --- a/test_runner/regress/test_sharding.py +++ b/test_runner/regress/test_sharding.py @@ -1508,20 +1508,55 @@ def test_sharding_split_failures( env.storage_controller.consistency_check() -@pytest.mark.skip(reason="The backpressure change has not been merged yet.") +# HADRON +def test_create_tenant_after_split(neon_env_builder: NeonEnvBuilder): + """ + Tests creating a tenant and a timeline should fail after a tenant split. 
+ """ + env = neon_env_builder.init_start(initial_tenant_shard_count=4) + + env.storage_controller.allowed_errors.extend( + [ + ".*already exists with a different shard count.*", + ] + ) + + ep = env.endpoints.create_start("main", tenant_id=env.initial_tenant) + ep.safe_psql("CREATE TABLE usertable ( YCSB_KEY INT, FIELD0 TEXT);") + ep.safe_psql("INSERT INTO usertable VALUES (1, 'test1');") + ep.safe_psql("INSERT INTO usertable VALUES (2, 'test2');") + ep.safe_psql("INSERT INTO usertable VALUES (3, 'test3');") + + # Split the tenant + + env.storage_controller.tenant_shard_split(env.initial_tenant, shard_count=8) + + with pytest.raises(RuntimeError): + env.create_tenant(env.initial_tenant, env.initial_timeline, shard_count=4) + + # run more queries + ep.safe_psql("SELECT * FROM usertable;") + ep.safe_psql("UPDATE usertable set FIELD0 = 'test4';") + + ep.stop_and_destroy() + + +# HADRON def test_back_pressure_during_split(neon_env_builder: NeonEnvBuilder): """ - Test backpressure can ignore new shards during tenant split so that if we abort the split, - PG can continue without being blocked. + Test backpressure works correctly during a shard split, especially after a split is aborted, + PG will not be stuck forever. """ - DBNAME = "regression" - - init_shard_count = 4 + init_shard_count = 1 neon_env_builder.num_pageservers = init_shard_count stripe_size = 32 env = neon_env_builder.init_start( - initial_tenant_shard_count=init_shard_count, initial_tenant_shard_stripe_size=stripe_size + initial_tenant_shard_count=init_shard_count, + initial_tenant_shard_stripe_size=stripe_size, + initial_tenant_conf={ + "checkpoint_distance": 1024 * 1024 * 1024, + }, ) env.storage_controller.allowed_errors.extend( @@ -1537,19 +1572,31 @@ def test_back_pressure_during_split(neon_env_builder: NeonEnvBuilder): "main", config_lines=[ "max_replication_write_lag = 1MB", - "databricks.max_wal_mb_per_second = 1", "neon.max_cluster_size = 10GB", + "databricks.max_wal_mb_per_second=100", ], ) - endpoint.respec(skip_pg_catalog_updates=False) # Needed for databricks_system to get created. + endpoint.respec(skip_pg_catalog_updates=False) endpoint.start() - endpoint.safe_psql(f"CREATE DATABASE {DBNAME}") - - endpoint.safe_psql("CREATE TABLE usertable ( YCSB_KEY INT, FIELD0 TEXT);") + # generate 10MB of data + endpoint.safe_psql( + "CREATE TABLE usertable AS SELECT s AS KEY, repeat('a', 1000) as VALUE from generate_series(1, 10000) s;" + ) write_done = Event() - def write_data(write_done): + def get_write_lag(): + res = endpoint.safe_psql( + """ + SELECT + pg_wal_lsn_diff(pg_current_wal_flush_lsn(), received_lsn) as received_lsn_lag + FROM neon.backpressure_lsns(); + """, + log_query=False, + ) + return res[0][0] + + def write_data(write_done: Event): while not write_done.is_set(): endpoint.safe_psql( "INSERT INTO usertable SELECT random(), repeat('a', 1000);", log_query=False @@ -1560,35 +1607,39 @@ def test_back_pressure_during_split(neon_env_builder: NeonEnvBuilder): writer_thread.start() env.storage_controller.configure_failpoints(("shard-split-pre-complete", "return(1)")) + # sleep 10 seconds before re-activating the old shard when aborting the split. 
+ # this is to add some backpressures to PG + env.pageservers[0].http_client().configure_failpoints( + ("attach-before-activate-sleep", "return(10000)"), + ) # split the tenant with pytest.raises(StorageControllerApiException): - env.storage_controller.tenant_shard_split(env.initial_tenant, shard_count=16) + env.storage_controller.tenant_shard_split(env.initial_tenant, shard_count=4) + + def check_tenant_status(): + status = ( + env.pageservers[0].http_client().tenant_status(TenantShardId(env.initial_tenant, 0, 1)) + ) + assert status["state"]["slug"] == "Active" + + wait_until(check_tenant_status) write_done.set() writer_thread.join() + log.info(f"current write lag: {get_write_lag()}") + # writing more data to page servers after split is aborted - for _i in range(5000): - endpoint.safe_psql( - "INSERT INTO usertable SELECT random(), repeat('a', 1000);", log_query=False - ) + with endpoint.cursor() as cur: + for _i in range(1000): + cur.execute("INSERT INTO usertable SELECT random(), repeat('a', 1000);") # wait until write lag becomes 0 def check_write_lag_is_zero(): - res = endpoint.safe_psql( - """ - SELECT - pg_wal_lsn_diff(pg_current_wal_flush_lsn(), received_lsn) as received_lsn_lag - FROM neon.backpressure_lsns(); - """, - dbname="databricks_system", - log_query=False, - ) - log.info(f"received_lsn_lag = {res[0][0]}") - assert res[0][0] == 0 + res = get_write_lag() + assert res == 0 wait_until(check_write_lag_is_zero) - endpoint.stop_and_destroy() # BEGIN_HADRON @@ -1674,7 +1725,6 @@ def test_shard_resolve_during_split_abort(neon_env_builder: NeonEnvBuilder): # HADRON -@pytest.mark.skip(reason="The backpressure change has not been merged yet.") def test_back_pressure_per_shard(neon_env_builder: NeonEnvBuilder): """ Tests back pressure knobs are enforced on the per shard basis instead of at the tenant level. @@ -1701,22 +1751,19 @@ def test_back_pressure_per_shard(neon_env_builder: NeonEnvBuilder): "max_replication_apply_lag = 0", "max_replication_flush_lag = 15MB", "neon.max_cluster_size = 10GB", + "neon.lakebase_mode = true", ], ) - endpoint.respec(skip_pg_catalog_updates=False) # Needed for databricks_system to get created. 
+ endpoint.respec(skip_pg_catalog_updates=False) endpoint.start() # generate 20MB of data endpoint.safe_psql( "CREATE TABLE usertable AS SELECT s AS KEY, repeat('a', 1000) as VALUE from generate_series(1, 20000) s;" ) - res = endpoint.safe_psql( - "SELECT neon.backpressure_throttling_time() as throttling_time", dbname="databricks_system" - )[0] + res = endpoint.safe_psql("SELECT neon.backpressure_throttling_time() as throttling_time")[0] assert res[0] == 0, f"throttling_time should be 0, but got {res[0]}" - endpoint.stop() - # HADRON def test_shard_split_page_server_timeout(neon_env_builder: NeonEnvBuilder): @@ -1880,14 +1927,14 @@ def test_sharding_backpressure(neon_env_builder: NeonEnvBuilder): shards_info() for _write_iter in range(30): - # approximately 1MB of data - workload.write_rows(8000, upload=False) + # approximately 10MB of data + workload.write_rows(80000, upload=False) update_write_lsn() infos = shards_info() min_lsn = min(Lsn(info["last_record_lsn"]) for info in infos) max_lsn = max(Lsn(info["last_record_lsn"]) for info in infos) diff = max_lsn - min_lsn - assert diff < 2 * 1024 * 1024, f"LSN diff={diff}, expected diff < 2MB due to backpressure" + assert diff < 8 * 1024 * 1024, f"LSN diff={diff}, expected diff < 8MB due to backpressure" def test_sharding_unlogged_relation(neon_env_builder: NeonEnvBuilder): diff --git a/test_runner/regress/test_storage_controller.py b/test_runner/regress/test_storage_controller.py index dbd0388034..ea2bc3a62c 100644 --- a/test_runner/regress/test_storage_controller.py +++ b/test_runner/regress/test_storage_controller.py @@ -3312,6 +3312,7 @@ def test_ps_unavailable_after_delete( ps.allowed_errors.append(".*request was dropped before completing.*") env.storage_controller.node_delete(ps.id, force=True) wait_until(lambda: assert_nodes_count(2)) + env.storage_controller.reconcile_until_idle() elif deletion_api == DeletionAPIKind.OLD: env.storage_controller.node_delete_old(ps.id) assert_nodes_count(2) @@ -4962,3 +4963,49 @@ def test_storage_controller_forward_404(neon_env_builder: NeonEnvBuilder): env.storage_controller.configure_failpoints( ("reconciler-live-migrate-post-generation-inc", "off") ) + + +def test_re_attach_with_stuck_secondary(neon_env_builder: NeonEnvBuilder): + """ + This test assumes that the secondary location cannot be configured for whatever reason. + It then attempts to detach and and attach the tenant back again and, finally, checks + for observed state consistency by attempting to create a timeline. + + See LKB-204 for more details. 
+ """ + + neon_env_builder.num_pageservers = 2 + + env = neon_env_builder.init_configs() + env.start() + + env.storage_controller.allowed_errors.append(".*failpoint.*") + + tenant_id, _ = env.create_tenant(shard_count=1, placement_policy='{"Attached":1}') + env.storage_controller.reconcile_until_idle() + + locations = env.storage_controller.locate(tenant_id) + assert len(locations) == 1 + primary: int = locations[0]["node_id"] + + not_primary = [ps.id for ps in env.pageservers if ps.id != primary] + assert len(not_primary) == 1 + secondary = not_primary[0] + + env.get_pageserver(secondary).http_client().configure_failpoints( + ("put-location-conf-handler", "return(1)") + ) + + env.storage_controller.tenant_policy_update(tenant_id, {"placement": "Detached"}) + + with pytest.raises(Exception, match="failpoint"): + env.storage_controller.reconcile_all() + + env.storage_controller.tenant_policy_update(tenant_id, {"placement": {"Attached": 1}}) + + with pytest.raises(Exception, match="failpoint"): + env.storage_controller.reconcile_all() + + env.storage_controller.pageserver_api().timeline_create( + pg_version=PgVersion.NOT_SET, tenant_id=tenant_id, new_timeline_id=TimelineId.generate() + ) diff --git a/test_runner/regress/test_tenants.py b/test_runner/regress/test_tenants.py index 7f32f34d36..49bc02a3e7 100644 --- a/test_runner/regress/test_tenants.py +++ b/test_runner/regress/test_tenants.py @@ -298,15 +298,26 @@ def test_pageserver_metrics_removed_after_detach(neon_env_builder: NeonEnvBuilde assert post_detach_samples == set() -def test_pageserver_metrics_removed_after_offload(neon_env_builder: NeonEnvBuilder): +@pytest.mark.parametrize("compaction", ["compaction_enabled", "compaction_disabled"]) +def test_pageserver_metrics_removed_after_offload( + neon_env_builder: NeonEnvBuilder, compaction: str +): """Tests that when a timeline is offloaded, the tenant specific metrics are not left behind""" neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.MOCK_S3) - neon_env_builder.num_safekeepers = 3 env = neon_env_builder.init_start() - tenant_1, _ = env.create_tenant() + tenant_1, _ = env.create_tenant( + conf={ + # disable background compaction and GC so that we don't have leftover tasks + # after offloading. + "gc_period": "0s", + "compaction_period": "0s", + } + if compaction == "compaction_disabled" + else None + ) timeline_1 = env.create_timeline("test_metrics_removed_after_offload_1", tenant_id=tenant_1) timeline_2 = env.create_timeline("test_metrics_removed_after_offload_2", tenant_id=tenant_1) @@ -351,6 +362,23 @@ def test_pageserver_metrics_removed_after_offload(neon_env_builder: NeonEnvBuild state=TimelineArchivalState.ARCHIVED, ) env.pageserver.http_client().timeline_offload(tenant_1, timeline) + # We need to wait until all background jobs are finished before we can check the metrics. + # There're many of them: compaction, GC, etc. + wait_until( + lambda: all( + sample.value == 0 + for sample in env.pageserver.http_client() + .get_metrics() + .query_all("pageserver_background_loop_semaphore_waiting_tasks") + ) + and all( + sample.value == 0 + for sample in env.pageserver.http_client() + .get_metrics() + .query_all("pageserver_background_loop_semaphore_running_tasks") + ) + ) + post_offload_samples = set( [x.name for x in get_ps_metric_samples_for_timeline(tenant_1, timeline)] )