Merge branch 'main' into thesuhas/brc-3051

This commit is contained in:
Suhas Thalanki
2025-07-29 17:34:21 -04:00
136 changed files with 5414 additions and 1832 deletions

View File

@@ -31,7 +31,7 @@ config-variables:
- NEON_PROD_AWS_ACCOUNT_ID
- PGREGRESS_PG16_PROJECT_ID
- PGREGRESS_PG17_PROJECT_ID
- PREWARM_PGBENCH_SIZE
- PREWARM_PROJECT_ID
- REMOTE_STORAGE_AZURE_CONTAINER
- REMOTE_STORAGE_AZURE_REGION
- SLACK_CICD_CHANNEL_ID

384
.github/workflows/benchbase_tpcc.yml vendored Normal file
View File

@@ -0,0 +1,384 @@
name: TPC-C like benchmark using benchbase
on:
schedule:
# * is a special character in YAML so you have to quote this string
# ┌───────────── minute (0 - 59)
# │ ┌───────────── hour (0 - 23)
# │ │ ┌───────────── day of the month (1 - 31)
# │ │ │ ┌───────────── month (1 - 12 or JAN-DEC)
# │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT)
- cron: '0 6 * * *' # run once a day at 6 AM UTC
workflow_dispatch: # adds ability to run this manually
defaults:
run:
shell: bash -euxo pipefail {0}
concurrency:
# Allow only one workflow globally because we do not want to be too noisy in production environment
group: benchbase-tpcc-workflow
cancel-in-progress: false
permissions:
contents: read
jobs:
benchbase-tpcc:
strategy:
fail-fast: false # allow other variants to continue even if one fails
matrix:
include:
- warehouses: 50 # defines number of warehouses and is used to compute number of terminals
max_rate: 800 # measured max TPS at scale factor based on experiments. Adjust if performance is better/worse
min_cu: 0.25 # simulate free tier plan (0.25 -2 CU)
max_cu: 2
- warehouses: 500 # serverless plan (2-8 CU)
max_rate: 2000
min_cu: 2
max_cu: 8
- warehouses: 1000 # business plan (2-16 CU)
max_rate: 2900
min_cu: 2
max_cu: 16
max-parallel: 1 # we want to run each workload size sequentially to avoid noisy neighbors
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
env:
PG_CONFIG: /tmp/neon/pg_install/v17/bin/pg_config
PSQL: /tmp/neon/pg_install/v17/bin/psql
PG_17_LIB_PATH: /tmp/neon/pg_install/v17/lib
POSTGRES_VERSION: 17
runs-on: [ self-hosted, us-east-2, x64 ]
timeout-minutes: 1440
steps:
- name: Harden the runner (Audit all outbound calls)
uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0
with:
egress-policy: audit
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Configure AWS credentials # necessary to download artefacts
uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 18000 # 5 hours is currently max associated with IAM role
- name: Download Neon artifact
uses: ./.github/actions/download
with:
name: neon-${{ runner.os }}-${{ runner.arch }}-release-artifact
path: /tmp/neon/
prefix: latest
aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
- name: Create Neon Project
id: create-neon-project-tpcc
uses: ./.github/actions/neon-project-create
with:
region_id: aws-us-east-2
postgres_version: ${{ env.POSTGRES_VERSION }}
compute_units: '[${{ matrix.min_cu }}, ${{ matrix.max_cu }}]'
api_key: ${{ secrets.NEON_PRODUCTION_API_KEY_4_BENCHMARKS }}
api_host: console.neon.tech # production (!)
- name: Initialize Neon project
env:
BENCHMARK_TPCC_CONNSTR: ${{ steps.create-neon-project-tpcc.outputs.dsn }}
PROJECT_ID: ${{ steps.create-neon-project-tpcc.outputs.project_id }}
run: |
echo "Initializing Neon project with project_id: ${PROJECT_ID}"
export LD_LIBRARY_PATH=${PG_17_LIB_PATH}
# Retry logic for psql connection with 1 minute sleep between attempts
for attempt in {1..3}; do
echo "Attempt ${attempt}/3: Creating extensions in Neon project"
if ${PSQL} "${BENCHMARK_TPCC_CONNSTR}" -c "CREATE EXTENSION IF NOT EXISTS neon; CREATE EXTENSION IF NOT EXISTS neon_utils;"; then
echo "Successfully created extensions"
break
else
echo "Failed to create extensions on attempt ${attempt}"
if [ ${attempt} -lt 3 ]; then
echo "Waiting 60 seconds before retry..."
sleep 60
else
echo "All attempts failed, exiting"
exit 1
fi
fi
done
echo "BENCHMARK_TPCC_CONNSTR=${BENCHMARK_TPCC_CONNSTR}" >> $GITHUB_ENV
- name: Generate BenchBase workload configuration
env:
WAREHOUSES: ${{ matrix.warehouses }}
MAX_RATE: ${{ matrix.max_rate }}
run: |
echo "Generating BenchBase configs for warehouses: ${WAREHOUSES}, max_rate: ${MAX_RATE}"
# Extract hostname and password from connection string
# Format: postgresql://username:password@hostname/database?params (no port for Neon)
HOSTNAME=$(echo "${BENCHMARK_TPCC_CONNSTR}" | sed -n 's|.*://[^:]*:[^@]*@\([^/]*\)/.*|\1|p')
PASSWORD=$(echo "${BENCHMARK_TPCC_CONNSTR}" | sed -n 's|.*://[^:]*:\([^@]*\)@.*|\1|p')
echo "Extracted hostname: ${HOSTNAME}"
# Use runner temp (NVMe) as working directory
cd "${RUNNER_TEMP}"
# Copy the generator script
cp "${GITHUB_WORKSPACE}/test_runner/performance/benchbase_tpc_c_helpers/generate_workload_size.py" .
# Generate configs and scripts
python3 generate_workload_size.py \
--warehouses ${WAREHOUSES} \
--max-rate ${MAX_RATE} \
--hostname ${HOSTNAME} \
--password ${PASSWORD} \
--runner-arch ${{ runner.arch }}
# Fix path mismatch: move generated configs and scripts to expected locations
mv ../configs ./configs
mv ../scripts ./scripts
- name: Prepare database (load data)
env:
WAREHOUSES: ${{ matrix.warehouses }}
run: |
cd "${RUNNER_TEMP}"
echo "Loading ${WAREHOUSES} warehouses into database..."
# Run the loader script and capture output to log file while preserving stdout/stderr
./scripts/load_${WAREHOUSES}_warehouses.sh 2>&1 | tee "load_${WAREHOUSES}_warehouses.log"
echo "Database loading completed"
- name: Run TPC-C benchmark (warmup phase, then benchmark at 70% of configuredmax TPS)
env:
WAREHOUSES: ${{ matrix.warehouses }}
run: |
cd "${RUNNER_TEMP}"
echo "Running TPC-C benchmark with ${WAREHOUSES} warehouses..."
# Run the optimal rate benchmark
./scripts/execute_${WAREHOUSES}_warehouses_opt_rate.sh
echo "Benchmark execution completed"
- name: Run TPC-C benchmark (warmup phase, then ramp down TPS and up again in 5 minute intervals)
env:
WAREHOUSES: ${{ matrix.warehouses }}
run: |
cd "${RUNNER_TEMP}"
echo "Running TPC-C ramp-down-up with ${WAREHOUSES} warehouses..."
# Run the optimal rate benchmark
./scripts/execute_${WAREHOUSES}_warehouses_ramp_up.sh
echo "Benchmark execution completed"
- name: Process results (upload to test results database and generate diagrams)
env:
WAREHOUSES: ${{ matrix.warehouses }}
MIN_CU: ${{ matrix.min_cu }}
MAX_CU: ${{ matrix.max_cu }}
PROJECT_ID: ${{ steps.create-neon-project-tpcc.outputs.project_id }}
REVISION: ${{ github.sha }}
PERF_DB_CONNSTR: ${{ secrets.PERF_TEST_RESULT_CONNSTR }}
run: |
cd "${RUNNER_TEMP}"
echo "Creating temporary Python environment for results processing..."
# Create temporary virtual environment
python3 -m venv temp_results_env
source temp_results_env/bin/activate
# Install required packages in virtual environment
pip install matplotlib pandas psycopg2-binary
echo "Copying results processing scripts..."
# Copy both processing scripts
cp "${GITHUB_WORKSPACE}/test_runner/performance/benchbase_tpc_c_helpers/generate_diagrams.py" .
cp "${GITHUB_WORKSPACE}/test_runner/performance/benchbase_tpc_c_helpers/upload_results_to_perf_test_results.py" .
echo "Processing load phase metrics..."
# Find and process load log
LOAD_LOG=$(find . -name "load_${WAREHOUSES}_warehouses.log" -type f | head -1)
if [ -n "$LOAD_LOG" ]; then
echo "Processing load metrics from: $LOAD_LOG"
python upload_results_to_perf_test_results.py \
--load-log "$LOAD_LOG" \
--run-type "load" \
--warehouses "${WAREHOUSES}" \
--min-cu "${MIN_CU}" \
--max-cu "${MAX_CU}" \
--project-id "${PROJECT_ID}" \
--revision "${REVISION}" \
--connection-string "${PERF_DB_CONNSTR}"
else
echo "Warning: Load log file not found: load_${WAREHOUSES}_warehouses.log"
fi
echo "Processing warmup results for optimal rate..."
# Find and process warmup results
WARMUP_CSV=$(find results_warmup -name "*.results.csv" -type f | head -1)
WARMUP_JSON=$(find results_warmup -name "*.summary.json" -type f | head -1)
if [ -n "$WARMUP_CSV" ] && [ -n "$WARMUP_JSON" ]; then
echo "Generating warmup diagram from: $WARMUP_CSV"
python generate_diagrams.py \
--input-csv "$WARMUP_CSV" \
--output-svg "warmup_${WAREHOUSES}_warehouses_performance.svg" \
--title-suffix "Warmup at max TPS"
echo "Uploading warmup metrics from: $WARMUP_JSON"
python upload_results_to_perf_test_results.py \
--summary-json "$WARMUP_JSON" \
--results-csv "$WARMUP_CSV" \
--run-type "warmup" \
--min-cu "${MIN_CU}" \
--max-cu "${MAX_CU}" \
--project-id "${PROJECT_ID}" \
--revision "${REVISION}" \
--connection-string "${PERF_DB_CONNSTR}"
else
echo "Warning: Missing warmup results files (CSV: $WARMUP_CSV, JSON: $WARMUP_JSON)"
fi
echo "Processing optimal rate results..."
# Find and process optimal rate results
OPTRATE_CSV=$(find results_opt_rate -name "*.results.csv" -type f | head -1)
OPTRATE_JSON=$(find results_opt_rate -name "*.summary.json" -type f | head -1)
if [ -n "$OPTRATE_CSV" ] && [ -n "$OPTRATE_JSON" ]; then
echo "Generating optimal rate diagram from: $OPTRATE_CSV"
python generate_diagrams.py \
--input-csv "$OPTRATE_CSV" \
--output-svg "benchmark_${WAREHOUSES}_warehouses_performance.svg" \
--title-suffix "70% of max TPS"
echo "Uploading optimal rate metrics from: $OPTRATE_JSON"
python upload_results_to_perf_test_results.py \
--summary-json "$OPTRATE_JSON" \
--results-csv "$OPTRATE_CSV" \
--run-type "opt-rate" \
--min-cu "${MIN_CU}" \
--max-cu "${MAX_CU}" \
--project-id "${PROJECT_ID}" \
--revision "${REVISION}" \
--connection-string "${PERF_DB_CONNSTR}"
else
echo "Warning: Missing optimal rate results files (CSV: $OPTRATE_CSV, JSON: $OPTRATE_JSON)"
fi
echo "Processing warmup 2 results for ramp down/up phase..."
# Find and process warmup results
WARMUP_CSV=$(find results_warmup -name "*.results.csv" -type f | tail -1)
WARMUP_JSON=$(find results_warmup -name "*.summary.json" -type f | tail -1)
if [ -n "$WARMUP_CSV" ] && [ -n "$WARMUP_JSON" ]; then
echo "Generating warmup diagram from: $WARMUP_CSV"
python generate_diagrams.py \
--input-csv "$WARMUP_CSV" \
--output-svg "warmup_2_${WAREHOUSES}_warehouses_performance.svg" \
--title-suffix "Warmup at max TPS"
echo "Uploading warmup metrics from: $WARMUP_JSON"
python upload_results_to_perf_test_results.py \
--summary-json "$WARMUP_JSON" \
--results-csv "$WARMUP_CSV" \
--run-type "warmup" \
--min-cu "${MIN_CU}" \
--max-cu "${MAX_CU}" \
--project-id "${PROJECT_ID}" \
--revision "${REVISION}" \
--connection-string "${PERF_DB_CONNSTR}"
else
echo "Warning: Missing warmup results files (CSV: $WARMUP_CSV, JSON: $WARMUP_JSON)"
fi
echo "Processing ramp results..."
# Find and process ramp results
RAMPUP_CSV=$(find results_ramp_up -name "*.results.csv" -type f | head -1)
RAMPUP_JSON=$(find results_ramp_up -name "*.summary.json" -type f | head -1)
if [ -n "$RAMPUP_CSV" ] && [ -n "$RAMPUP_JSON" ]; then
echo "Generating ramp diagram from: $RAMPUP_CSV"
python generate_diagrams.py \
--input-csv "$RAMPUP_CSV" \
--output-svg "ramp_${WAREHOUSES}_warehouses_performance.svg" \
--title-suffix "ramp TPS down and up in 5 minute intervals"
echo "Uploading ramp metrics from: $RAMPUP_JSON"
python upload_results_to_perf_test_results.py \
--summary-json "$RAMPUP_JSON" \
--results-csv "$RAMPUP_CSV" \
--run-type "ramp-up" \
--min-cu "${MIN_CU}" \
--max-cu "${MAX_CU}" \
--project-id "${PROJECT_ID}" \
--revision "${REVISION}" \
--connection-string "${PERF_DB_CONNSTR}"
else
echo "Warning: Missing ramp results files (CSV: $RAMPUP_CSV, JSON: $RAMPUP_JSON)"
fi
# Deactivate and clean up virtual environment
deactivate
rm -rf temp_results_env
rm upload_results_to_perf_test_results.py
echo "Results processing completed and environment cleaned up"
- name: Set date for upload
id: set-date
run: echo "date=$(date +%Y-%m-%d)" >> $GITHUB_OUTPUT
- name: Configure AWS credentials # necessary to upload results
uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2
with:
aws-region: us-east-2
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 900 # 900 is minimum value
- name: Upload benchmark results to S3
env:
S3_BUCKET: neon-public-benchmark-results
S3_PREFIX: benchbase-tpc-c/${{ steps.set-date.outputs.date }}/${{ github.run_id }}/${{ matrix.warehouses }}-warehouses
run: |
echo "Redacting passwords from configuration files before upload..."
# Mask all passwords in XML config files
find "${RUNNER_TEMP}/configs" -name "*.xml" -type f -exec sed -i 's|<password>[^<]*</password>|<password>redacted</password>|g' {} \;
echo "Uploading benchmark results to s3://${S3_BUCKET}/${S3_PREFIX}/"
# Upload the entire benchmark directory recursively
aws s3 cp --only-show-errors --recursive "${RUNNER_TEMP}" s3://${S3_BUCKET}/${S3_PREFIX}/
echo "Upload completed"
- name: Delete Neon Project
if: ${{ always() }}
uses: ./.github/actions/neon-project-delete
with:
project_id: ${{ steps.create-neon-project-tpcc.outputs.project_id }}
api_key: ${{ secrets.NEON_PRODUCTION_API_KEY_4_BENCHMARKS }}
api_host: console.neon.tech # production (!)

View File

@@ -418,7 +418,7 @@ jobs:
statuses: write
id-token: write # aws-actions/configure-aws-credentials
env:
PGBENCH_SIZE: ${{ vars.PREWARM_PGBENCH_SIZE }}
PROJECT_ID: ${{ vars.PREWARM_PROJECT_ID }}
POSTGRES_DISTRIB_DIR: /tmp/neon/pg_install
DEFAULT_PG_VERSION: 17
TEST_OUTPUT: /tmp/test_output

View File

@@ -146,7 +146,9 @@ jobs:
with:
file: build-tools/Dockerfile
context: .
provenance: false
attests: |
type=provenance,mode=max
type=sbom,generator=docker.io/docker/buildkit-syft-scanner:1
push: true
pull: true
build-args: |

View File

@@ -634,7 +634,9 @@ jobs:
DEBIAN_VERSION=bookworm
secrets: |
SUBZERO_ACCESS_TOKEN=${{ secrets.CI_ACCESS_TOKEN }}
provenance: false
attests: |
type=provenance,mode=max
type=sbom,generator=docker.io/docker/buildkit-syft-scanner:1
push: true
pull: true
file: Dockerfile
@@ -747,7 +749,9 @@ jobs:
PG_VERSION=${{ matrix.version.pg }}
BUILD_TAG=${{ needs.meta.outputs.release-tag || needs.meta.outputs.build-tag }}
DEBIAN_VERSION=${{ matrix.version.debian }}
provenance: false
attests: |
type=provenance,mode=max
type=sbom,generator=docker.io/docker/buildkit-syft-scanner:1
push: true
pull: true
file: compute/compute-node.Dockerfile
@@ -766,7 +770,9 @@ jobs:
PG_VERSION=${{ matrix.version.pg }}
BUILD_TAG=${{ needs.meta.outputs.release-tag || needs.meta.outputs.build-tag }}
DEBIAN_VERSION=${{ matrix.version.debian }}
provenance: false
attests: |
type=provenance,mode=max
type=sbom,generator=docker.io/docker/buildkit-syft-scanner:1
push: true
pull: true
file: compute/compute-node.Dockerfile

View File

@@ -48,8 +48,20 @@ jobs:
uses: ./.github/workflows/build-build-tools-image.yml
secrets: inherit
generate-ch-tmppw:
runs-on: ubuntu-22.04
outputs:
tmp_val: ${{ steps.pwgen.outputs.tmp_val }}
steps:
- name: Generate a random password
id: pwgen
run: |
set +x
p=$(dd if=/dev/random bs=14 count=1 2>/dev/null | base64)
echo tmp_val="${p//\//}" >> "${GITHUB_OUTPUT}"
test-logical-replication:
needs: [ build-build-tools-image ]
needs: [ build-build-tools-image, generate-ch-tmppw ]
runs-on: ubuntu-22.04
container:
@@ -60,16 +72,21 @@ jobs:
options: --init --user root
services:
clickhouse:
image: clickhouse/clickhouse-server:24.6.3.64
image: clickhouse/clickhouse-server:25.6
env:
CLICKHOUSE_PASSWORD: ${{ needs.generate-ch-tmppw.outputs.tmp_val }}
PGSSLCERT: /tmp/postgresql.crt
ports:
- 9000:9000
- 8123:8123
zookeeper:
image: quay.io/debezium/zookeeper:2.7
image: quay.io/debezium/zookeeper:3.1.3.Final
ports:
- 2181:2181
- 2888:2888
- 3888:3888
kafka:
image: quay.io/debezium/kafka:2.7
image: quay.io/debezium/kafka:3.1.3.Final
env:
ZOOKEEPER_CONNECT: "zookeeper:2181"
KAFKA_ADVERTISED_LISTENERS: PLAINTEXT://kafka:9092
@@ -79,7 +96,7 @@ jobs:
ports:
- 9092:9092
debezium:
image: quay.io/debezium/connect:2.7
image: quay.io/debezium/connect:3.1.3.Final
env:
BOOTSTRAP_SERVERS: kafka:9092
GROUP_ID: 1
@@ -125,6 +142,7 @@ jobs:
aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
env:
BENCHMARK_CONNSTR: ${{ steps.create-neon-project.outputs.dsn }}
CLICKHOUSE_PASSWORD: ${{ needs.generate-ch-tmppw.outputs.tmp_val }}
- name: Delete Neon Project
if: always()

203
Cargo.lock generated
View File

@@ -211,11 +211,11 @@ dependencies = [
[[package]]
name = "async-lock"
version = "3.2.0"
version = "3.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7125e42787d53db9dd54261812ef17e937c95a51e4d291373b670342fa44310c"
checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18"
dependencies = [
"event-listener 4.0.0",
"event-listener 5.4.0",
"event-listener-strategy",
"pin-project-lite",
]
@@ -1404,9 +1404,9 @@ dependencies = [
[[package]]
name = "concurrent-queue"
version = "2.3.0"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f057a694a54f12365049b0958a1685bb52d567f5593b355fbf685838e873d400"
checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973"
dependencies = [
"crossbeam-utils",
]
@@ -2232,9 +2232,9 @@ checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0"
[[package]]
name = "event-listener"
version = "4.0.0"
version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "770d968249b5d99410d61f5bf89057f3199a077a04d087092f58e7d10692baae"
checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae"
dependencies = [
"concurrent-queue",
"parking",
@@ -2243,11 +2243,11 @@ dependencies = [
[[package]]
name = "event-listener-strategy"
version = "0.4.0"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3"
checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93"
dependencies = [
"event-listener 4.0.0",
"event-listener 5.4.0",
"pin-project-lite",
]
@@ -2516,6 +2516,20 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "304de19db7028420975a296ab0fcbbc8e69438c4ed254a1e41e2a7f37d5f0e0a"
[[package]]
name = "generator"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d18470a76cb7f8ff746cf1f7470914f900252ec36bbc40b569d74b1258446827"
dependencies = [
"cc",
"cfg-if",
"libc",
"log",
"rustversion",
"windows 0.61.3",
]
[[package]]
name = "generic-array"
version = "0.14.7"
@@ -2834,7 +2848,7 @@ checksum = "f9c7c7c8ac16c798734b8a24560c1362120597c40d5e1459f09498f8f6c8f2ba"
dependencies = [
"cfg-if",
"libc",
"windows",
"windows 0.52.0",
]
[[package]]
@@ -3105,7 +3119,7 @@ dependencies = [
"iana-time-zone-haiku",
"js-sys",
"wasm-bindgen",
"windows-core",
"windows-core 0.52.0",
]
[[package]]
@@ -3656,6 +3670,19 @@ version = "0.4.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e"
[[package]]
name = "loom"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca"
dependencies = [
"cfg-if",
"generator",
"scoped-tls",
"tracing",
"tracing-subscriber",
]
[[package]]
name = "lru"
version = "0.12.3"
@@ -3872,6 +3899,25 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "moka"
version = "0.12.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9321642ca94a4282428e6ea4af8cc2ca4eac48ac7a6a4ea8f33f76d0ce70926"
dependencies = [
"crossbeam-channel",
"crossbeam-epoch",
"crossbeam-utils",
"loom",
"parking_lot 0.12.1",
"portable-atomic",
"rustc_version",
"smallvec",
"tagptr",
"thiserror 1.0.69",
"uuid",
]
[[package]]
name = "multimap"
version = "0.8.3"
@@ -5031,8 +5077,6 @@ dependencies = [
"crc32c",
"criterion",
"env_logger",
"log",
"memoffset 0.9.0",
"once_cell",
"postgres",
"postgres_ffi_types",
@@ -5385,7 +5429,6 @@ dependencies = [
"futures",
"gettid",
"hashbrown 0.14.5",
"hashlink",
"hex",
"hmac",
"hostname",
@@ -5407,6 +5450,7 @@ dependencies = [
"lasso",
"measured",
"metrics",
"moka",
"once_cell",
"opentelemetry",
"ouroboros",
@@ -5473,6 +5517,7 @@ dependencies = [
"workspace_hack",
"x509-cert",
"zerocopy 0.8.24",
"zeroize",
]
[[package]]
@@ -6420,6 +6465,12 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "scoped-tls"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294"
[[package]]
name = "scopeguard"
version = "1.1.0"
@@ -7269,6 +7320,12 @@ dependencies = [
"winapi",
]
[[package]]
name = "tagptr"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417"
[[package]]
name = "tar"
version = "0.4.40"
@@ -8638,10 +8695,32 @@ version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e48a53791691ab099e5e2ad123536d0fff50652600abaf43bbf952894110d0be"
dependencies = [
"windows-core",
"windows-core 0.52.0",
"windows-targets 0.52.6",
]
[[package]]
name = "windows"
version = "0.61.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893"
dependencies = [
"windows-collections",
"windows-core 0.61.2",
"windows-future",
"windows-link",
"windows-numerics",
]
[[package]]
name = "windows-collections"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8"
dependencies = [
"windows-core 0.61.2",
]
[[package]]
name = "windows-core"
version = "0.52.0"
@@ -8651,6 +8730,86 @@ dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-core"
version = "0.61.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3"
dependencies = [
"windows-implement",
"windows-interface",
"windows-link",
"windows-result",
"windows-strings",
]
[[package]]
name = "windows-future"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e"
dependencies = [
"windows-core 0.61.2",
"windows-link",
"windows-threading",
]
[[package]]
name = "windows-implement"
version = "0.60.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "windows-interface"
version = "0.59.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "windows-link"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a"
[[package]]
name = "windows-numerics"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1"
dependencies = [
"windows-core 0.61.2",
"windows-link",
]
[[package]]
name = "windows-result"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6"
dependencies = [
"windows-link",
]
[[package]]
name = "windows-strings"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57"
dependencies = [
"windows-link",
]
[[package]]
name = "windows-sys"
version = "0.48.0"
@@ -8709,6 +8868,15 @@ dependencies = [
"windows_x86_64_msvc 0.52.6",
]
[[package]]
name = "windows-threading"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6"
dependencies = [
"windows-link",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.48.0"
@@ -8845,6 +9013,8 @@ dependencies = [
"clap",
"clap_builder",
"const-oid",
"crossbeam-epoch",
"crossbeam-utils",
"crypto-bigint 0.5.5",
"der 0.7.8",
"deranged",
@@ -8890,6 +9060,7 @@ dependencies = [
"once_cell",
"p256 0.13.2",
"parquet",
"portable-atomic",
"prettyplease",
"proc-macro2",
"prost 0.13.5",

View File

@@ -46,10 +46,10 @@ members = [
"libs/proxy/json",
"libs/proxy/postgres-protocol2",
"libs/proxy/postgres-types2",
"libs/proxy/subzero_core",
"libs/proxy/tokio-postgres2",
"endpoint_storage",
"pgxn/neon/communicator",
"proxy/subzero_core",
]
[workspace.package]
@@ -135,7 +135,7 @@ lock_api = "0.4.13"
md5 = "0.7.0"
measured = { version = "0.0.22", features=["lasso"] }
measured-process = { version = "0.0.22" }
memoffset = "0.9"
moka = { version = "0.12", features = ["sync"] }
nix = { version = "0.30.1", features = ["dir", "fs", "mman", "process", "socket", "signal", "poll"] }
# Do not update to >= 7.0.0, at least. The update will have a significant impact
# on compute startup metrics (start_postgres_ms), >= 25% degradation.
@@ -233,9 +233,10 @@ uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }
walkdir = "2.3.2"
rustls-native-certs = "0.8"
whoami = "1.5.1"
zerocopy = { version = "0.8", features = ["derive", "simd"] }
json-structural-diff = { version = "0.2.0" }
x509-cert = { version = "0.2.5" }
zerocopy = { version = "0.8", features = ["derive", "simd"] }
zeroize = "1.8"
## TODO replace this with tracing
env_logger = "0.11"

View File

@@ -103,7 +103,7 @@ RUN --mount=type=secret,uid=1000,id=SUBZERO_ACCESS_TOKEN \
&& if [ -s /run/secrets/SUBZERO_ACCESS_TOKEN ]; then \
export CARGO_FEATURES="rest_broker"; \
fi \
&& RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo build \
&& RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo auditable build \
--features $CARGO_FEATURES \
--bin pg_sni_router \
--bin pageserver \

View File

@@ -39,13 +39,13 @@ COPY build-tools/patches/pgcopydbv017.patch /pgcopydbv017.patch
RUN if [ "${DEBIAN_VERSION}" = "bookworm" ]; then \
set -e && \
apt update && \
apt install -y --no-install-recommends \
apt-get update && \
apt-get install -y --no-install-recommends \
ca-certificates wget gpg && \
wget -qO - https://www.postgresql.org/media/keys/ACCC4CF8.asc | gpg --dearmor -o /usr/share/keyrings/postgresql-keyring.gpg && \
echo "deb [signed-by=/usr/share/keyrings/postgresql-keyring.gpg] http://apt.postgresql.org/pub/repos/apt bookworm-pgdg main" > /etc/apt/sources.list.d/pgdg.list && \
apt-get update && \
apt install -y --no-install-recommends \
apt-get install -y --no-install-recommends \
build-essential \
autotools-dev \
libedit-dev \
@@ -89,8 +89,7 @@ RUN useradd -ms /bin/bash nonroot -b /home
# Use strict mode for bash to catch errors early
SHELL ["/bin/bash", "-euo", "pipefail", "-c"]
RUN mkdir -p /pgcopydb/bin && \
mkdir -p /pgcopydb/lib && \
RUN mkdir -p /pgcopydb/{bin,lib} && \
chmod -R 755 /pgcopydb && \
chown -R nonroot:nonroot /pgcopydb
@@ -106,8 +105,8 @@ RUN echo 'Acquire::Retries "5";' > /etc/apt/apt.conf.d/80-retries && \
# 'gdb' is included so that we get backtraces of core dumps produced in
# regression tests
RUN set -e \
&& apt update \
&& apt install -y \
&& apt-get update \
&& apt-get install -y --no-install-recommends \
autoconf \
automake \
bison \
@@ -183,22 +182,22 @@ RUN curl -sL "https://github.com/peak/s5cmd/releases/download/v${S5CMD_VERSION}/
ENV LLVM_VERSION=20
RUN curl -fsSL 'https://apt.llvm.org/llvm-snapshot.gpg.key' | apt-key add - \
&& echo "deb http://apt.llvm.org/${DEBIAN_VERSION}/ llvm-toolchain-${DEBIAN_VERSION}-${LLVM_VERSION} main" > /etc/apt/sources.list.d/llvm.stable.list \
&& apt update \
&& apt install -y clang-${LLVM_VERSION} llvm-${LLVM_VERSION} \
&& apt-get update \
&& apt-get install -y --no-install-recommends clang-${LLVM_VERSION} llvm-${LLVM_VERSION} \
&& bash -c 'for f in /usr/bin/clang*-${LLVM_VERSION} /usr/bin/llvm*-${LLVM_VERSION}; do ln -s "${f}" "${f%-${LLVM_VERSION}}"; done' \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
# Install node
ENV NODE_VERSION=24
RUN curl -fsSL https://deb.nodesource.com/setup_${NODE_VERSION}.x | bash - \
&& apt install -y nodejs \
&& apt-get install -y --no-install-recommends nodejs \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
# Install docker
RUN curl -fsSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /usr/share/keyrings/docker-archive-keyring.gpg \
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/debian ${DEBIAN_VERSION} stable" > /etc/apt/sources.list.d/docker.list \
&& apt update \
&& apt install -y docker-ce docker-ce-cli \
&& apt-get update \
&& apt-get install -y --no-install-recommends docker-ce docker-ce-cli \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
# Configure sudo & docker
@@ -215,12 +214,11 @@ RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-$(uname -m).zip" -o "aws
# Mold: A Modern Linker
ENV MOLD_VERSION=v2.37.1
RUN set -e \
&& git clone https://github.com/rui314/mold.git \
&& git clone -b "${MOLD_VERSION}" --depth 1 https://github.com/rui314/mold.git \
&& mkdir mold/build \
&& cd mold/build \
&& git checkout ${MOLD_VERSION} \
&& cd mold/build \
&& cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_COMPILER=clang++ .. \
&& cmake --build . -j $(nproc) \
&& cmake --build . -j "$(nproc)" \
&& cmake --install . \
&& cd .. \
&& rm -rf mold
@@ -254,7 +252,7 @@ ENV ICU_VERSION=67.1
ENV ICU_PREFIX=/usr/local/icu
# Download and build static ICU
RUN wget -O /tmp/libicu-${ICU_VERSION}.tgz https://github.com/unicode-org/icu/releases/download/release-${ICU_VERSION//./-}/icu4c-${ICU_VERSION//./_}-src.tgz && \
RUN wget -O "/tmp/libicu-${ICU_VERSION}.tgz" https://github.com/unicode-org/icu/releases/download/release-${ICU_VERSION//./-}/icu4c-${ICU_VERSION//./_}-src.tgz && \
echo "94a80cd6f251a53bd2a997f6f1b5ac6653fe791dfab66e1eb0227740fb86d5dc /tmp/libicu-${ICU_VERSION}.tgz" | sha256sum --check && \
mkdir /tmp/icu && \
pushd /tmp/icu && \
@@ -265,8 +263,7 @@ RUN wget -O /tmp/libicu-${ICU_VERSION}.tgz https://github.com/unicode-org/icu/re
make install && \
popd && \
rm -rf icu && \
rm -f /tmp/libicu-${ICU_VERSION}.tgz && \
popd
rm -f /tmp/libicu-${ICU_VERSION}.tgz
# Switch to nonroot user
USER nonroot:nonroot
@@ -279,19 +276,19 @@ ENV PYTHON_VERSION=3.11.12 \
PYENV_ROOT=/home/nonroot/.pyenv \
PATH=/home/nonroot/.pyenv/shims:/home/nonroot/.pyenv/bin:/home/nonroot/.poetry/bin:$PATH
RUN set -e \
&& cd $HOME \
&& cd "$HOME" \
&& curl -sSO https://raw.githubusercontent.com/pyenv/pyenv-installer/master/bin/pyenv-installer \
&& chmod +x pyenv-installer \
&& ./pyenv-installer \
&& export PYENV_ROOT=/home/nonroot/.pyenv \
&& export PATH="$PYENV_ROOT/bin:$PATH" \
&& export PATH="$PYENV_ROOT/shims:$PATH" \
&& pyenv install ${PYTHON_VERSION} \
&& pyenv global ${PYTHON_VERSION} \
&& pyenv install "${PYTHON_VERSION}" \
&& pyenv global "${PYTHON_VERSION}" \
&& python --version \
&& pip install --upgrade pip \
&& pip install --no-cache-dir --upgrade pip \
&& pip --version \
&& pip install pipenv wheel poetry
&& pip install --no-cache-dir pipenv wheel poetry
# Switch to nonroot user (again)
USER nonroot:nonroot
@@ -302,6 +299,7 @@ WORKDIR /home/nonroot
ENV RUSTC_VERSION=1.88.0
ENV RUSTUP_HOME="/home/nonroot/.rustup"
ENV PATH="/home/nonroot/.cargo/bin:${PATH}"
ARG CARGO_AUDITABLE_VERSION=0.7.0
ARG RUSTFILT_VERSION=0.2.1
ARG CARGO_HAKARI_VERSION=0.9.36
ARG CARGO_DENY_VERSION=0.18.2
@@ -317,14 +315,16 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux
. "$HOME/.cargo/env" && \
cargo --version && rustup --version && \
rustup component add llvm-tools rustfmt clippy && \
cargo install rustfilt --locked --version ${RUSTFILT_VERSION} && \
cargo install cargo-hakari --locked --version ${CARGO_HAKARI_VERSION} && \
cargo install cargo-deny --locked --version ${CARGO_DENY_VERSION} && \
cargo install cargo-hack --locked --version ${CARGO_HACK_VERSION} && \
cargo install cargo-nextest --locked --version ${CARGO_NEXTEST_VERSION} && \
cargo install cargo-chef --locked --version ${CARGO_CHEF_VERSION} && \
cargo install diesel_cli --locked --version ${CARGO_DIESEL_CLI_VERSION} \
--features postgres-bundled --no-default-features && \
cargo install cargo-auditable --locked --version "${CARGO_AUDITABLE_VERSION}" && \
cargo auditable install cargo-auditable --locked --version "${CARGO_AUDITABLE_VERSION}" --force && \
cargo auditable install rustfilt --version "${RUSTFILT_VERSION}" && \
cargo auditable install cargo-hakari --locked --version "${CARGO_HAKARI_VERSION}" && \
cargo auditable install cargo-deny --locked --version "${CARGO_DENY_VERSION}" && \
cargo auditable install cargo-hack --locked --version "${CARGO_HACK_VERSION}" && \
cargo auditable install cargo-nextest --locked --version "${CARGO_NEXTEST_VERSION}" && \
cargo auditable install cargo-chef --locked --version "${CARGO_CHEF_VERSION}" && \
cargo auditable install diesel_cli --locked --version "${CARGO_DIESEL_CLI_VERSION}" \
--features postgres-bundled --no-default-features && \
rm -rf /home/nonroot/.cargo/registry && \
rm -rf /home/nonroot/.cargo/git

View File

@@ -1,5 +1,11 @@
commit 5eb393810cf7c7bafa4e394dad2e349e2a8cb2cb
Author: Alexey Masterov <alexey.masterov@databricks.com>
Date: Mon Jul 28 18:11:02 2025 +0200
Patch for pg_repack
diff --git a/regress/Makefile b/regress/Makefile
index bf6edcb..89b4c7f 100644
index bf6edcb..110e734 100644
--- a/regress/Makefile
+++ b/regress/Makefile
@@ -17,7 +17,7 @@ INTVERSION := $(shell echo $$(($$(echo $(VERSION).0 | sed 's/\([[:digit:]]\{1,\}
@@ -7,18 +13,36 @@ index bf6edcb..89b4c7f 100644
#
-REGRESS := init-extension repack-setup repack-run error-on-invalid-idx no-error-on-invalid-idx after-schema repack-check nosuper tablespace get_order_by trigger
+REGRESS := init-extension repack-setup repack-run error-on-invalid-idx no-error-on-invalid-idx after-schema repack-check nosuper get_order_by trigger
+REGRESS := init-extension noautovacuum repack-setup repack-run error-on-invalid-idx no-error-on-invalid-idx after-schema repack-check nosuper get_order_by trigger autovacuum
USE_PGXS = 1 # use pgxs if not in contrib directory
PGXS := $(shell $(PG_CONFIG) --pgxs)
diff --git a/regress/expected/init-extension.out b/regress/expected/init-extension.out
index 9f2e171..f6e4f8d 100644
--- a/regress/expected/init-extension.out
+++ b/regress/expected/init-extension.out
@@ -1,3 +1,2 @@
SET client_min_messages = warning;
CREATE EXTENSION pg_repack;
-RESET client_min_messages;
diff --git a/regress/expected/autovacuum.out b/regress/expected/autovacuum.out
new file mode 100644
index 0000000..e7f2363
--- /dev/null
+++ b/regress/expected/autovacuum.out
@@ -0,0 +1,7 @@
+ALTER SYSTEM SET autovacuum='on';
+SELECT pg_reload_conf();
+ pg_reload_conf
+----------------
+ t
+(1 row)
+
diff --git a/regress/expected/noautovacuum.out b/regress/expected/noautovacuum.out
new file mode 100644
index 0000000..fc7978e
--- /dev/null
+++ b/regress/expected/noautovacuum.out
@@ -0,0 +1,7 @@
+ALTER SYSTEM SET autovacuum='off';
+SELECT pg_reload_conf();
+ pg_reload_conf
+----------------
+ t
+(1 row)
+
diff --git a/regress/expected/nosuper.out b/regress/expected/nosuper.out
index 8d0a94e..63b68bf 100644
--- a/regress/expected/nosuper.out
@@ -50,14 +74,22 @@ index 8d0a94e..63b68bf 100644
INFO: repacking table "public.tbl_cluster"
ERROR: query failed: ERROR: current transaction is aborted, commands ignored until end of transaction block
DETAIL: query was: RESET lock_timeout
diff --git a/regress/sql/init-extension.sql b/regress/sql/init-extension.sql
index 9f2e171..f6e4f8d 100644
--- a/regress/sql/init-extension.sql
+++ b/regress/sql/init-extension.sql
@@ -1,3 +1,2 @@
SET client_min_messages = warning;
CREATE EXTENSION pg_repack;
-RESET client_min_messages;
diff --git a/regress/sql/autovacuum.sql b/regress/sql/autovacuum.sql
new file mode 100644
index 0000000..a8eda63
--- /dev/null
+++ b/regress/sql/autovacuum.sql
@@ -0,0 +1,2 @@
+ALTER SYSTEM SET autovacuum='on';
+SELECT pg_reload_conf();
diff --git a/regress/sql/noautovacuum.sql b/regress/sql/noautovacuum.sql
new file mode 100644
index 0000000..13d4836
--- /dev/null
+++ b/regress/sql/noautovacuum.sql
@@ -0,0 +1,2 @@
+ALTER SYSTEM SET autovacuum='off';
+SELECT pg_reload_conf();
diff --git a/regress/sql/nosuper.sql b/regress/sql/nosuper.sql
index 072f0fa..dbe60f8 100644
--- a/regress/sql/nosuper.sql

View File

@@ -54,11 +54,11 @@ stateDiagram-v2
Running --> TerminationPendingImmediate : Requested termination
Running --> ConfigurationPending : Received a /configure request with spec
Running --> RefreshConfigurationPending : Received a /refresh_configuration request, compute node will pull a new spec and reconfigure
RefreshConfigurationPending --> Running : Compute has been re-configured
RefreshConfigurationPending --> RefreshConfiguration: Received compute spec and started configuration
RefreshConfiguration --> Running : Compute has been re-configured
RefreshConfiguration --> RefreshConfigurationPending : Configuration failed and to be retried
TerminationPendingFast --> Terminated compute with 30s delay for cplane to inspect status
TerminationPendingImmediate --> Terminated : Terminated compute immediately
Running --> TerminationPending : Requested termination
TerminationPending --> Terminated : Terminated compute
Failed --> RefreshConfigurationPending : Received a /refresh_configuration request
Failed --> [*] : Compute exited
Terminated --> [*] : Compute exited

View File

@@ -49,10 +49,10 @@ use compute_tools::compute::{
BUILD_TAG, ComputeNode, ComputeNodeParams, forward_termination_signal,
};
use compute_tools::extension_server::get_pg_version_string;
use compute_tools::logger::*;
use compute_tools::params::*;
use compute_tools::pg_isready::get_pg_isready_bin;
use compute_tools::spec::*;
use compute_tools::{hadron_metrics, installed_extensions, logger::*};
use rlimit::{Resource, setrlimit};
use signal_hook::consts::{SIGINT, SIGQUIT, SIGTERM};
use signal_hook::iterator::Signals;
@@ -82,6 +82,15 @@ struct Cli {
#[arg(long, default_value_t = 3081)]
pub internal_http_port: u16,
/// Backwards-compatible --http-port for Hadron deployments. Functionally the
/// same as --external-http-port.
#[arg(
long,
conflicts_with = "external_http_port",
conflicts_with = "internal_http_port"
)]
pub http_port: Option<u16>,
#[arg(short = 'D', long, value_name = "DATADIR")]
pub pgdata: String,
@@ -181,6 +190,26 @@ impl Cli {
}
}
// Hadron helpers to get compatible compute_ctl http ports from Cli. The old `--http-port`
// arg is used and acts the same as `--external-http-port`. The internal http port is defined
// to be http_port + 1. Hadron runs in the dblet environment which uses the host network, so
// we need to be careful with the ports to choose.
fn get_external_http_port(cli: &Cli) -> u16 {
if cli.lakebase_mode {
return cli.http_port.unwrap_or(cli.external_http_port);
}
cli.external_http_port
}
fn get_internal_http_port(cli: &Cli) -> u16 {
if cli.lakebase_mode {
return cli
.http_port
.map(|p| p + 1)
.unwrap_or(cli.internal_http_port);
}
cli.internal_http_port
}
fn main() -> Result<()> {
let cli = Cli::parse();
@@ -205,10 +234,18 @@ fn main() -> Result<()> {
// enable core dumping for all child processes
setrlimit(Resource::CORE, rlimit::INFINITY, rlimit::INFINITY)?;
if cli.lakebase_mode {
installed_extensions::initialize_metrics();
hadron_metrics::initialize_metrics();
}
let connstr = Url::parse(&cli.connstr).context("cannot parse connstr as a URL")?;
let config = get_config(&cli)?;
let external_http_port = get_external_http_port(&cli);
let internal_http_port = get_internal_http_port(&cli);
let compute_node = ComputeNode::new(
ComputeNodeParams {
compute_id: cli.compute_id,
@@ -217,8 +254,8 @@ fn main() -> Result<()> {
pgdata: cli.pgdata.clone(),
pgbin: cli.pgbin.clone(),
pgversion: get_pg_version_string(&cli.pgbin),
external_http_port: cli.external_http_port,
internal_http_port: cli.internal_http_port,
external_http_port,
internal_http_port,
remote_ext_base_url: cli.remote_ext_base_url.clone(),
resize_swap_on_bind: cli.resize_swap_on_bind,
set_disk_quota_for_fs: cli.set_disk_quota_for_fs,

View File

@@ -6,7 +6,8 @@ use compute_api::responses::{
LfcPrewarmState, PromoteState, TlsConfig,
};
use compute_api::spec::{
ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PageserverProtocol, PgIdent,
ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, GenericOption,
PageserverProtocol, PgIdent, Role,
};
use futures::StreamExt;
use futures::future::join_all;
@@ -41,8 +42,9 @@ use utils::shard::{ShardCount, ShardIndex, ShardNumber};
use crate::configurator::launch_configurator;
use crate::disk_quota::set_disk_quota;
use crate::hadron_metrics::COMPUTE_ATTACHED;
use crate::installed_extensions::get_installed_extensions;
use crate::logger::startup_context_from_env;
use crate::logger::{self, startup_context_from_env};
use crate::lsn_lease::launch_lsn_lease_bg_task_for_static;
use crate::metrics::COMPUTE_CTL_UP;
use crate::monitor::launch_monitor;
@@ -412,6 +414,130 @@ struct StartVmMonitorResult {
vm_monitor: Option<JoinHandle<Result<()>>>,
}
// BEGIN_HADRON
/// This function creates roles that are used by Databricks.
/// These roles are not needs to be botostrapped at PG Compute provisioning time.
/// The auth method for these roles are configured in databricks_pg_hba.conf in universe repository.
pub(crate) fn create_databricks_roles() -> Vec<String> {
let roles = vec![
// Role for prometheus_stats_exporter
Role {
name: "databricks_monitor".to_string(),
// This uses "local" connection and auth method for that is "trust", so no password is needed.
encrypted_password: None,
options: Some(vec![GenericOption {
name: "IN ROLE pg_monitor".to_string(),
value: None,
vartype: "string".to_string(),
}]),
},
// Role for brickstore control plane
Role {
name: "databricks_control_plane".to_string(),
// Certificate user does not need password.
encrypted_password: None,
options: Some(vec![GenericOption {
name: "SUPERUSER".to_string(),
value: None,
vartype: "string".to_string(),
}]),
},
// Role for brickstore httpgateway.
Role {
name: "databricks_gateway".to_string(),
// Certificate user does not need password.
encrypted_password: None,
options: None,
},
];
roles
.into_iter()
.map(|role| {
let query = format!(
r#"
DO $$
BEGIN
IF NOT EXISTS (
SELECT FROM pg_catalog.pg_roles WHERE rolname = '{}')
THEN
CREATE ROLE {} {};
END IF;
END
$$;"#,
role.name,
role.name.pg_quote(),
role.to_pg_options(),
);
query
})
.collect()
}
/// Databricks-specific environment variables to be passed to the `postgres` sub-process.
pub struct DatabricksEnvVars {
/// The Databricks "endpoint ID" of the compute instance. Used by `postgres` to check
/// the token scopes of internal auth tokens.
pub endpoint_id: String,
/// Hostname of the Databricks workspace URL this compute instance belongs to.
/// Used by postgres to verify Databricks PAT tokens.
pub workspace_host: String,
pub lakebase_mode: bool,
}
impl DatabricksEnvVars {
pub fn new(
compute_spec: &ComputeSpec,
compute_id: Option<&String>,
instance_id: Option<String>,
lakebase_mode: bool,
) -> Self {
let endpoint_id = if let Some(instance_id) = instance_id {
// Use instance_id as endpoint_id if it is set. This code path is for PuPr model.
instance_id
} else {
// Use compute_id as endpoint_id if instance_id is not set. The code path is for PrPr model.
// compute_id is a string format of "{endpoint_id}/{compute_idx}"
// endpoint_id is a uuid. We only need to pass down endpoint_id to postgres.
// Panics if compute_id is not set or not in the expected format.
compute_id.unwrap().split('/').next().unwrap().to_string()
};
let workspace_host = compute_spec
.databricks_settings
.as_ref()
.map(|s| s.databricks_workspace_host.clone())
.unwrap_or("".to_string());
Self {
endpoint_id,
workspace_host,
lakebase_mode,
}
}
/// Constants for the names of Databricks-specific postgres environment variables.
const DATABRICKS_ENDPOINT_ID_ENVVAR: &'static str = "DATABRICKS_ENDPOINT_ID";
const DATABRICKS_WORKSPACE_HOST_ENVVAR: &'static str = "DATABRICKS_WORKSPACE_HOST";
/// Convert DatabricksEnvVars to a list of string pairs that can be passed as env vars. Consumes `self`.
pub fn to_env_var_list(self) -> Vec<(String, String)> {
if !self.lakebase_mode {
// In neon env, we don't need to pass down the env vars to postgres.
return vec![];
}
vec![
(
Self::DATABRICKS_ENDPOINT_ID_ENVVAR.to_string(),
self.endpoint_id.clone(),
),
(
Self::DATABRICKS_WORKSPACE_HOST_ENVVAR.to_string(),
self.workspace_host.clone(),
),
]
}
}
impl ComputeNode {
pub fn new(params: ComputeNodeParams, config: ComputeConfig) -> Result<Self> {
let connstr = params.connstr.as_str();
@@ -448,7 +574,11 @@ impl ComputeNode {
let mut new_state = ComputeState::new();
if let Some(spec) = config.spec {
let pspec = ParsedSpec::try_from(spec).map_err(|msg| anyhow::anyhow!(msg))?;
new_state.pspec = Some(pspec);
if params.lakebase_mode {
ComputeNode::set_spec(&params, &mut new_state, pspec);
} else {
new_state.pspec = Some(pspec);
}
}
Ok(ComputeNode {
@@ -1046,7 +1176,14 @@ impl ComputeNode {
// If it is something different then create_dir() will error out anyway.
let pgdata = &self.params.pgdata;
let _ok = fs::remove_dir_all(pgdata);
fs::create_dir(pgdata)?;
if self.params.lakebase_mode {
// Ignore creation errors if the directory already exists (e.g. mounting it ahead of time).
// If it is something different then PG startup will error out anyway.
let _ok = fs::create_dir(pgdata);
} else {
fs::create_dir(pgdata)?;
}
fs::set_permissions(pgdata, fs::Permissions::from_mode(0o700))?;
Ok(())
@@ -1410,6 +1547,8 @@ impl ComputeNode {
let pgdata_path = Path::new(&self.params.pgdata);
let tls_config = self.tls_config(&pspec.spec);
let databricks_settings = spec.databricks_settings.as_ref();
let postgres_port = self.params.connstr.port();
// Remove/create an empty pgdata directory and put configuration there.
self.create_pgdata()?;
@@ -1417,8 +1556,11 @@ impl ComputeNode {
pgdata_path,
&self.params,
&pspec.spec,
postgres_port,
self.params.internal_http_port,
tls_config,
databricks_settings,
self.params.lakebase_mode,
)?;
// Syncing safekeepers is only safe with primary nodes: if a primary
@@ -1458,8 +1600,28 @@ impl ComputeNode {
)
})?;
// Update pg_hba.conf received with basebackup.
update_pg_hba(pgdata_path, None)?;
if let Some(settings) = databricks_settings {
copy_tls_certificates(
&settings.pg_compute_tls_settings.key_file,
&settings.pg_compute_tls_settings.cert_file,
pgdata_path,
)?;
// Update pg_hba.conf received with basebackup including additional databricks settings.
update_pg_hba(pgdata_path, Some(&settings.databricks_pg_hba))?;
update_pg_ident(pgdata_path, Some(&settings.databricks_pg_ident))?;
} else {
// Update pg_hba.conf received with basebackup.
update_pg_hba(pgdata_path, None)?;
}
if let Some(databricks_settings) = spec.databricks_settings.as_ref() {
copy_tls_certificates(
&databricks_settings.pg_compute_tls_settings.key_file,
&databricks_settings.pg_compute_tls_settings.cert_file,
pgdata_path,
)?;
}
// Place pg_dynshmem under /dev/shm. This allows us to use
// 'dynamic_shared_memory_type = mmap' so that the files are placed in
@@ -1500,7 +1662,7 @@ impl ComputeNode {
// symlink doesn't affect anything.
//
// See https://github.com/neondatabase/autoscaling/issues/800
std::fs::remove_dir(pgdata_path.join("pg_dynshmem"))?;
std::fs::remove_dir_all(pgdata_path.join("pg_dynshmem"))?;
symlink("/dev/shm/", pgdata_path.join("pg_dynshmem"))?;
match spec.mode {
@@ -1515,6 +1677,12 @@ impl ComputeNode {
/// Start and stop a postgres process to warm up the VM for startup.
pub fn prewarm_postgres_vm_memory(&self) -> Result<()> {
if self.params.lakebase_mode {
// We are running in Hadron mode. Disabling this prewarming step for now as it could run
// into dblet port conflicts and also doesn't add much value with our current infra.
info!("Skipping postgres prewarming in Hadron mode");
return Ok(());
}
info!("prewarming VM memory");
// Create pgdata
@@ -1572,14 +1740,36 @@ impl ComputeNode {
pub fn start_postgres(&self, storage_auth_token: Option<String>) -> Result<PostgresHandle> {
let pgdata_path = Path::new(&self.params.pgdata);
let env_vars: Vec<(String, String)> = if self.params.lakebase_mode {
let databricks_env_vars = {
let state = self.state.lock().unwrap();
let spec = &state.pspec.as_ref().unwrap().spec;
DatabricksEnvVars::new(
spec,
Some(&self.params.compute_id),
self.params.instance_id.clone(),
self.params.lakebase_mode,
)
};
info!(
"Starting Postgres for databricks endpoint id: {}",
&databricks_env_vars.endpoint_id
);
let mut env_vars = databricks_env_vars.to_env_var_list();
env_vars.extend(storage_auth_token.map(|t| ("NEON_AUTH_TOKEN".to_string(), t)));
env_vars
} else if let Some(storage_auth_token) = &storage_auth_token {
vec![("NEON_AUTH_TOKEN".to_owned(), storage_auth_token.to_owned())]
} else {
vec![]
};
// Run postgres as a child process.
let mut pg = maybe_cgexec(&self.params.pgbin)
.args(["-D", &self.params.pgdata])
.envs(if let Some(storage_auth_token) = &storage_auth_token {
vec![("NEON_AUTH_TOKEN", storage_auth_token)]
} else {
vec![]
})
.envs(env_vars)
.stderr(Stdio::piped())
.spawn()
.expect("cannot start postgres process");
@@ -1731,7 +1921,15 @@ impl ComputeNode {
/// Do initial configuration of the already started Postgres.
#[instrument(skip_all)]
pub fn apply_config(&self, compute_state: &ComputeState) -> Result<()> {
let conf = self.get_tokio_conn_conf(Some("compute_ctl:apply_config"));
let mut conf = self.get_tokio_conn_conf(Some("compute_ctl:apply_config"));
if self.params.lakebase_mode {
// Set a 2-minute statement_timeout for the session applying config. The individual SQL statements
// used in apply_spec_sql() should not take long (they are just creating users and installing
// extensions). If any of them are stuck for an extended period of time it usually indicates a
// pageserver connectivity problem and we should bail out.
conf.options("-c statement_timeout=2min");
}
let conf = Arc::new(conf);
let spec = Arc::new(
@@ -1882,12 +2080,16 @@ impl ComputeNode {
// Write new config
let pgdata_path = Path::new(&self.params.pgdata);
let postgres_port = self.params.connstr.port();
config::write_postgres_conf(
pgdata_path,
&self.params,
&spec,
postgres_port,
self.params.internal_http_port,
tls_config,
spec.databricks_settings.as_ref(),
self.params.lakebase_mode,
)?;
self.pg_reload_conf()?;
@@ -1993,6 +2195,7 @@ impl ComputeNode {
// wait
ComputeStatus::Init
| ComputeStatus::Configuration
| ComputeStatus::RefreshConfiguration
| ComputeStatus::RefreshConfigurationPending
| ComputeStatus::Empty => {
state = self.state_changed.wait(state).unwrap();
@@ -2044,7 +2247,17 @@ impl ComputeNode {
pub fn check_for_core_dumps(&self) -> Result<()> {
let core_dump_dir = match std::env::consts::OS {
"macos" => Path::new("/cores/"),
_ => Path::new(&self.params.pgdata),
// BEGIN HADRON
// NB: Read core dump files from a fixed location outside of
// the data directory since `compute_ctl` wipes the data directory
// across container restarts.
_ => {
if self.params.lakebase_mode {
Path::new("/databricks/logs/brickstore")
} else {
Path::new(&self.params.pgdata)
}
} // END HADRON
};
// Collect core dump paths if any
@@ -2357,7 +2570,7 @@ LIMIT 100",
if let Some(libs) = spec.cluster.settings.find("shared_preload_libraries") {
libs_vec = libs
.split(&[',', '\'', ' '])
.filter(|s| *s != "neon" && !s.is_empty())
.filter(|s| *s != "neon" && *s != "databricks_auth" && !s.is_empty())
.map(str::to_string)
.collect();
}
@@ -2376,7 +2589,7 @@ LIMIT 100",
if let Some(libs) = shared_preload_libraries_line.split("='").nth(1) {
preload_libs_vec = libs
.split(&[',', '\'', ' '])
.filter(|s| *s != "neon" && !s.is_empty())
.filter(|s| *s != "neon" && *s != "databricks_auth" && !s.is_empty())
.map(str::to_string)
.collect();
}
@@ -2550,6 +2763,34 @@ LIMIT 100",
);
}
}
/// Set the compute spec and update related metrics.
/// This is the central place where pspec is updated.
pub fn set_spec(params: &ComputeNodeParams, state: &mut ComputeState, pspec: ParsedSpec) {
state.pspec = Some(pspec);
ComputeNode::update_attached_metric(params, state);
let _ = logger::update_ids(&params.instance_id, &Some(params.compute_id.clone()));
}
pub fn update_attached_metric(params: &ComputeNodeParams, state: &mut ComputeState) {
// Update the pg_cctl_attached gauge when all identifiers are available.
if let Some(instance_id) = &params.instance_id {
if let Some(pspec) = &state.pspec {
// Clear all values in the metric
COMPUTE_ATTACHED.reset();
// Set new metric value
COMPUTE_ATTACHED
.with_label_values(&[
&params.compute_id,
instance_id,
&pspec.tenant_id.to_string(),
&pspec.timeline_id.to_string(),
])
.set(1);
}
}
}
}
pub async fn installed_extensions(conf: tokio_postgres::Config) -> Result<()> {

View File

@@ -7,11 +7,14 @@ use std::io::prelude::*;
use std::path::Path;
use compute_api::responses::TlsConfig;
use compute_api::spec::{ComputeAudit, ComputeMode, ComputeSpec, GenericOption};
use compute_api::spec::{
ComputeAudit, ComputeMode, ComputeSpec, DatabricksSettings, GenericOption,
};
use crate::compute::ComputeNodeParams;
use crate::pg_helpers::{
GenericOptionExt, GenericOptionsSearch, PgOptionsSerialize, escape_conf_value,
DatabricksSettingsExt as _, GenericOptionExt, GenericOptionsSearch, PgOptionsSerialize,
escape_conf_value,
};
use crate::tls::{self, SERVER_CRT, SERVER_KEY};
@@ -40,12 +43,16 @@ pub fn line_in_file(path: &Path, line: &str) -> Result<bool> {
}
/// Create or completely rewrite configuration file specified by `path`
#[allow(clippy::too_many_arguments)]
pub fn write_postgres_conf(
pgdata_path: &Path,
params: &ComputeNodeParams,
spec: &ComputeSpec,
postgres_port: Option<u16>,
extension_server_port: u16,
tls_config: &Option<TlsConfig>,
databricks_settings: Option<&DatabricksSettings>,
lakebase_mode: bool,
) -> Result<()> {
let path = pgdata_path.join("postgresql.conf");
// File::create() destroys the file content if it exists.
@@ -285,6 +292,24 @@ pub fn write_postgres_conf(
writeln!(file, "log_destination='stderr,syslog'")?;
}
if lakebase_mode {
// Explicitly set the port based on the connstr, overriding any previous port setting.
// Note: It is important that we don't specify a different port again after this.
let port = postgres_port.expect("port must be present in connstr");
writeln!(file, "port = {port}")?;
// This is databricks specific settings.
// This should be at the end of the file but before `compute_ctl_temp_override.conf` below
// so that it can override any settings above.
// `compute_ctl_temp_override.conf` is intended to override any settings above during specific operations.
// To prevent potential breakage in the future, we keep it above `compute_ctl_temp_override.conf`.
writeln!(file, "# Databricks settings start")?;
if let Some(settings) = databricks_settings {
writeln!(file, "{}", settings.as_pg_settings())?;
}
writeln!(file, "# Databricks settings end")?;
}
// This is essential to keep this line at the end of the file,
// because it is intended to override any settings above.
writeln!(file, "include_if_exists = 'compute_ctl_temp_override.conf'")?;

View File

@@ -2,6 +2,7 @@ use std::fs::File;
use std::thread;
use std::{path::Path, sync::Arc};
use anyhow::Result;
use compute_api::responses::{ComputeConfig, ComputeStatus};
use tracing::{error, info, instrument};
@@ -13,6 +14,10 @@ fn configurator_main_loop(compute: &Arc<ComputeNode>) {
info!("waiting for reconfiguration requests");
loop {
let mut state = compute.state.lock().unwrap();
/* BEGIN_HADRON */
// RefreshConfiguration should only be used inside the loop
assert_ne!(state.status, ComputeStatus::RefreshConfiguration);
/* END_HADRON */
if compute.params.lakebase_mode {
while state.status != ComputeStatus::ConfigurationPending
@@ -54,53 +59,81 @@ fn configurator_main_loop(compute: &Arc<ComputeNode>) {
info!(
"compute node suspects its configuration is out of date, now refreshing configuration"
);
// Drop the lock guard here to avoid holding the lock while downloading spec from the control plane / HCC.
// This is the only thread that can move compute_ctl out of the `RefreshConfigurationPending` state, so it
state.set_status(ComputeStatus::RefreshConfiguration, &compute.state_changed);
// Drop the lock guard here to avoid holding the lock while downloading config from the control plane / HCC.
// This is the only thread that can move compute_ctl out of the `RefreshConfiguration` state, so it
// is safe to drop the lock like this.
drop(state);
let spec = if let Some(config_path) = &compute.params.config_path_test_only {
// This path is only to make testing easier. In production we always get the spec from the HCC.
info!(
"reloading config.json from path: {}",
config_path.to_string_lossy()
);
let path = Path::new(config_path);
if let Ok(file) = File::open(path) {
match serde_json::from_reader::<File, ComputeConfig>(file) {
Ok(config) => config.spec,
Err(e) => {
error!("could not parse spec file: {}", e);
None
}
}
} else {
error!(
"could not open config file at path: {}",
let get_config_result: anyhow::Result<ComputeConfig> =
if let Some(config_path) = &compute.params.config_path_test_only {
// This path is only to make testing easier. In production we always get the config from the HCC.
info!(
"reloading config.json from path: {}",
config_path.to_string_lossy()
);
None
}
} else if let Some(control_plane_uri) = &compute.params.control_plane_uri {
match get_config_from_control_plane(control_plane_uri, &compute.params.compute_id) {
Ok(config) => config.spec,
Err(e) => {
error!("could not get config from control plane: {}", e);
None
let path = Path::new(config_path);
if let Ok(file) = File::open(path) {
match serde_json::from_reader::<File, ComputeConfig>(file) {
Ok(config) => Ok(config),
Err(e) => {
error!("could not parse config file: {}", e);
Err(anyhow::anyhow!("could not parse config file: {}", e))
}
}
} else {
error!(
"could not open config file at path: {:?}",
config_path.to_string_lossy()
);
Err(anyhow::anyhow!(
"could not open config file at path: {}",
config_path.to_string_lossy()
))
}
}
} else {
None
};
} else if let Some(control_plane_uri) = &compute.params.control_plane_uri {
get_config_from_control_plane(control_plane_uri, &compute.params.compute_id)
} else {
Err(anyhow::anyhow!("config_path_test_only is not set"))
};
if let Some(spec) = spec {
if let Ok(pspec) = ParsedSpec::try_from(spec) {
// Parse any received ComputeSpec and transpose the result into a Result<Option<ParsedSpec>>.
let parsed_spec_result: Result<Option<ParsedSpec>> =
get_config_result.and_then(|config| {
if let Some(spec) = config.spec {
if let Ok(pspec) = ParsedSpec::try_from(spec) {
Ok(Some(pspec))
} else {
Err(anyhow::anyhow!("could not parse spec"))
}
} else {
Ok(None)
}
});
let new_status: ComputeStatus;
match parsed_spec_result {
// Control plane (HCM) returned a spec and we were able to parse it.
Ok(Some(pspec)) => {
{
let mut state = compute.state.lock().unwrap();
// Defensive programming to make sure this thread is indeed the only one that can move the compute
// node out of the `RefreshConfigurationPending` state. Would be nice if we can encode this invariant
// node out of the `RefreshConfiguration` state. Would be nice if we can encode this invariant
// into the type system.
assert_eq!(state.status, ComputeStatus::RefreshConfigurationPending);
assert_eq!(state.status, ComputeStatus::RefreshConfiguration);
if state.pspec.as_ref().map(|ps| ps.pageserver_connstr.clone())
== Some(pspec.pageserver_connstr.clone())
{
info!(
"Refresh configuration: Retrieved spec is the same as the current spec. Waiting for control plane to update the spec before attempting reconfiguration."
);
state.status = ComputeStatus::Running;
compute.state_changed.notify_all();
drop(state);
std::thread::sleep(std::time::Duration::from_secs(5));
continue;
}
// state.pspec is consumed by compute.reconfigure() below. Note that compute.reconfigure() will acquire
// the compute.state lock again so we need to have the lock guard go out of scope here. We could add a
// "locked" variant of compute.reconfigure() that takes the lock guard as an argument to make this cleaner,
@@ -110,20 +143,45 @@ fn configurator_main_loop(compute: &Arc<ComputeNode>) {
match compute.reconfigure() {
Ok(_) => {
info!("Refresh configuration: compute node configured");
compute.set_status(ComputeStatus::Running);
new_status = ComputeStatus::Running;
}
Err(e) => {
error!(
"Refresh configuration: could not configure compute node: {}",
e
);
// Leave the compute node in the `RefreshConfigurationPending` state if the configuration
// Set the compute node back to the `RefreshConfigurationPending` state if the configuration
// was not successful. It should be okay to treat this situation the same as if the loop
// hasn't executed yet as long as the detection side keeps notifying.
new_status = ComputeStatus::RefreshConfigurationPending;
}
}
}
// Control plane (HCM)'s response does not contain a spec. This is the "Empty" attachment case.
Ok(None) => {
info!(
"Compute Manager signaled that this compute is no longer attached to any storage. Exiting."
);
// We just immediately terminate the whole compute_ctl in this case. It's not necessary to attempt a
// clean shutdown as Postgres is probably not responding anyway (which is why we are in this refresh
// configuration state).
std::process::exit(1);
}
// Various error cases:
// - The request to the control plane (HCM) either failed or returned a malformed spec.
// - compute_ctl itself is configured incorrectly (e.g., compute_id is not set).
Err(e) => {
error!(
"Refresh configuration: error getting a parsed spec: {:?}",
e
);
new_status = ComputeStatus::RefreshConfigurationPending;
// We may be dealing with an overloaded HCM if we end up in this path. Backoff 5 seconds before
// retrying to avoid hammering the HCM.
std::thread::sleep(std::time::Duration::from_secs(5));
}
}
compute.set_status(new_status);
} else if state.status == ComputeStatus::Failed {
info!("compute node is now in Failed state, exiting");
break;

View File

@@ -43,7 +43,12 @@ pub(in crate::http) async fn configure(
// configure request for tracing purposes.
state.startup_span = Some(tracing::Span::current());
state.pspec = Some(pspec);
if compute.params.lakebase_mode {
ComputeNode::set_spec(&compute.params, &mut state, pspec);
} else {
state.pspec = Some(pspec);
}
state.set_status(ComputeStatus::ConfigurationPending, &compute.state_changed);
drop(state);
}

View File

@@ -13,6 +13,7 @@ use metrics::{Encoder, TextEncoder};
use crate::communicator_socket_client::connect_communicator_socket;
use crate::compute::ComputeNode;
use crate::hadron_metrics;
use crate::http::JsonResponse;
use crate::metrics::collect;
@@ -21,11 +22,18 @@ pub(in crate::http) async fn get_metrics() -> Response {
// When we call TextEncoder::encode() below, it will immediately return an
// error if a metric family has no metrics, so we need to preemptively
// filter out metric families with no metrics.
let metrics = collect()
let mut metrics = collect()
.into_iter()
.filter(|m| !m.get_metric().is_empty())
.collect::<Vec<MetricFamily>>();
// Add Hadron metrics.
let hadron_metrics: Vec<MetricFamily> = hadron_metrics::collect()
.into_iter()
.filter(|m| !m.get_metric().is_empty())
.collect();
metrics.extend(hadron_metrics);
let encoder = TextEncoder::new();
let mut buffer = vec![];

View File

@@ -9,7 +9,7 @@ use axum::{
use http::StatusCode;
use crate::compute::ComputeNode;
// use crate::hadron_metrics::POSTGRES_PAGESTREAM_REQUEST_ERRORS;
use crate::hadron_metrics::POSTGRES_PAGESTREAM_REQUEST_ERRORS;
use crate::http::JsonResponse;
/// The /refresh_configuration POST method is used to nudge compute_ctl to pull a new spec
@@ -21,6 +21,7 @@ use crate::http::JsonResponse;
pub(in crate::http) async fn refresh_configuration(
State(compute): State<Arc<ComputeNode>>,
) -> Response {
POSTGRES_PAGESTREAM_REQUEST_ERRORS.inc();
match compute.signal_refresh_configuration().await {
Ok(_) => StatusCode::OK.into_response(),
Err(e) => JsonResponse::error(StatusCode::INTERNAL_SERVER_ERROR, e),

View File

@@ -1,7 +1,7 @@
use crate::compute::{ComputeNode, forward_termination_signal};
use crate::http::JsonResponse;
use axum::extract::State;
use axum::response::Response;
use axum::response::{IntoResponse, Response};
use axum_extra::extract::OptionalQuery;
use compute_api::responses::{ComputeStatus, TerminateMode, TerminateResponse};
use http::StatusCode;
@@ -33,7 +33,29 @@ pub(in crate::http) async fn terminate(
if !matches!(state.status, ComputeStatus::Empty | ComputeStatus::Running) {
return JsonResponse::invalid_status(state.status);
}
// If compute is Empty, there's no Postgres to terminate. The regular compute_ctl termination path
// assumes Postgres to be configured and running, so we just special-handle this case by exiting
// the process directly.
if compute.params.lakebase_mode && state.status == ComputeStatus::Empty {
drop(state);
info!("terminating empty compute - will exit process");
// Queue a task to exit the process after 5 seconds. The 5-second delay aims to
// give enough time for the HTTP response to be sent so that HCM doesn't get an abrupt
// connection termination.
tokio::spawn(async {
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
info!("exiting process after terminating empty compute");
std::process::exit(0);
});
return StatusCode::OK.into_response();
}
// For Running status, proceed with normal termination
state.set_status(mode.into(), &compute.state_changed);
drop(state);
}
forward_termination_signal(false);

View File

@@ -142,7 +142,7 @@ pub fn update_pg_hba(pgdata_path: &Path, databricks_pg_hba: Option<&String>) ->
// Update pg_hba to contains databricks specfic settings before adding neon settings
// PG uses the first record that matches to perform authentication, so we need to have
// our rules before the default ones from neon.
// See https://www.postgresql.org/docs/16/auth-pg-hba-conf.html
// See https://www.postgresql.org/docs/current/auth-pg-hba-conf.html
if let Some(databricks_pg_hba) = databricks_pg_hba {
if config::line_in_file(
&pghba_path,

View File

@@ -13,17 +13,19 @@ use tokio_postgres::Client;
use tokio_postgres::error::SqlState;
use tracing::{Instrument, debug, error, info, info_span, instrument, warn};
use crate::compute::{ComputeNode, ComputeNodeParams, ComputeState};
use crate::compute::{ComputeNode, ComputeNodeParams, ComputeState, create_databricks_roles};
use crate::hadron_metrics::COMPUTE_CONFIGURE_STATEMENT_TIMEOUT_ERRORS;
use crate::pg_helpers::{
DatabaseExt, Escaping, GenericOptionsSearch, RoleExt, get_existing_dbs_async,
get_existing_roles_async,
};
use crate::spec_apply::ApplySpecPhase::{
CreateAndAlterDatabases, CreateAndAlterRoles, CreateAvailabilityCheck, CreatePgauditExtension,
AddDatabricksGrants, AlterDatabricksRoles, CreateAndAlterDatabases, CreateAndAlterRoles,
CreateAvailabilityCheck, CreateDatabricksMisc, CreateDatabricksRoles, CreatePgauditExtension,
CreatePgauditlogtofileExtension, CreatePrivilegedRole, CreateSchemaNeon,
DisablePostgresDBPgAudit, DropInvalidDatabases, DropRoles, FinalizeDropLogicalSubscriptions,
HandleNeonExtension, HandleOtherExtensions, RenameAndDeleteDatabases, RenameRoles,
RunInEachDatabase,
HandleDatabricksAuthExtension, HandleNeonExtension, HandleOtherExtensions,
RenameAndDeleteDatabases, RenameRoles, RunInEachDatabase,
};
use crate::spec_apply::PerDatabasePhase::{
ChangeSchemaPerms, DeleteDBRoleReferences, DropLogicalSubscriptions,
@@ -166,6 +168,7 @@ impl ComputeNode {
concurrency_token.clone(),
db,
[DropLogicalSubscriptions].to_vec(),
self.params.lakebase_mode,
);
Ok(tokio::spawn(fut))
@@ -186,15 +189,33 @@ impl ComputeNode {
};
}
for phase in [
CreatePrivilegedRole,
let phases = if self.params.lakebase_mode {
vec![
CreatePrivilegedRole,
// BEGIN_HADRON
CreateDatabricksRoles,
AlterDatabricksRoles,
// END_HADRON
DropInvalidDatabases,
RenameRoles,
CreateAndAlterRoles,
RenameAndDeleteDatabases,
CreateAndAlterDatabases,
CreateSchemaNeon,
] {
]
} else {
vec![
CreatePrivilegedRole,
DropInvalidDatabases,
RenameRoles,
CreateAndAlterRoles,
RenameAndDeleteDatabases,
CreateAndAlterDatabases,
CreateSchemaNeon,
]
};
for phase in phases {
info!("Applying phase {:?}", &phase);
apply_operations(
params.clone(),
@@ -203,6 +224,7 @@ impl ComputeNode {
jwks_roles.clone(),
phase,
|| async { Ok(&client) },
self.params.lakebase_mode,
)
.await?;
}
@@ -254,6 +276,7 @@ impl ComputeNode {
concurrency_token.clone(),
db,
phases,
self.params.lakebase_mode,
);
Ok(tokio::spawn(fut))
@@ -265,12 +288,28 @@ impl ComputeNode {
handle.await??;
}
let mut phases = vec![
let mut phases = if self.params.lakebase_mode {
vec![
HandleOtherExtensions,
HandleNeonExtension, // This step depends on CreateSchemaNeon
// BEGIN_HADRON
HandleDatabricksAuthExtension,
// END_HADRON
CreateAvailabilityCheck,
DropRoles,
// BEGIN_HADRON
AddDatabricksGrants,
CreateDatabricksMisc,
// END_HADRON
]
} else {
vec![
HandleOtherExtensions,
HandleNeonExtension, // This step depends on CreateSchemaNeon
CreateAvailabilityCheck,
DropRoles,
];
]
};
// This step depends on CreateSchemaNeon
if spec.drop_subscriptions_before_start && !drop_subscriptions_done {
@@ -303,6 +342,7 @@ impl ComputeNode {
jwks_roles.clone(),
phase,
|| async { Ok(&client) },
self.params.lakebase_mode,
)
.await?;
}
@@ -328,6 +368,7 @@ impl ComputeNode {
concurrency_token: Arc<tokio::sync::Semaphore>,
db: DB,
subphases: Vec<PerDatabasePhase>,
lakebase_mode: bool,
) -> Result<()> {
let _permit = concurrency_token.acquire().await?;
@@ -355,6 +396,7 @@ impl ComputeNode {
let client = client_conn.as_ref().unwrap();
Ok(client)
},
lakebase_mode,
)
.await?;
}
@@ -477,6 +519,10 @@ pub enum PerDatabasePhase {
#[derive(Clone, Debug)]
pub enum ApplySpecPhase {
CreatePrivilegedRole,
// BEGIN_HADRON
CreateDatabricksRoles,
AlterDatabricksRoles,
// END_HADRON
DropInvalidDatabases,
RenameRoles,
CreateAndAlterRoles,
@@ -489,7 +535,14 @@ pub enum ApplySpecPhase {
DisablePostgresDBPgAudit,
HandleOtherExtensions,
HandleNeonExtension,
// BEGIN_HADRON
HandleDatabricksAuthExtension,
// END_HADRON
CreateAvailabilityCheck,
// BEGIN_HADRON
AddDatabricksGrants,
CreateDatabricksMisc,
// END_HADRON
DropRoles,
FinalizeDropLogicalSubscriptions,
}
@@ -525,6 +578,7 @@ pub async fn apply_operations<'a, Fut, F>(
jwks_roles: Arc<HashSet<String>>,
apply_spec_phase: ApplySpecPhase,
client: F,
lakebase_mode: bool,
) -> Result<()>
where
F: FnOnce() -> Fut,
@@ -571,6 +625,23 @@ where
},
query
);
if !lakebase_mode {
return res;
}
// BEGIN HADRON
if let Err(e) = res.as_ref() {
if let Some(sql_state) = e.code() {
if sql_state.code() == "57014" {
// SQL State 57014 (ERRCODE_QUERY_CANCELED) is used for statement timeouts.
// Increment the counter whenever a statement timeout occurs. Timeouts on
// this configuration path can only occur due to PS connectivity problems that
// Postgres failed to recover from.
COMPUTE_CONFIGURE_STATEMENT_TIMEOUT_ERRORS.inc();
}
}
}
// END HADRON
res
}
.instrument(inspan)
@@ -612,6 +683,35 @@ async fn get_operations<'a>(
),
comment: None,
}))),
// BEGIN_HADRON
// New Hadron phase
ApplySpecPhase::CreateDatabricksRoles => {
let queries = create_databricks_roles();
let operations = queries.into_iter().map(|query| Operation {
query,
comment: None,
});
Ok(Box::new(operations))
}
// Backfill existing databricks_reader_* roles with statement timeout from GUC
ApplySpecPhase::AlterDatabricksRoles => {
let query = String::from(include_str!(
"sql/alter_databricks_reader_roles_timeout.sql"
));
let operations = once(Operation {
query,
comment: Some(
"Backfill existing databricks_reader_* roles with statement timeout"
.to_string(),
),
});
Ok(Box::new(operations))
}
// End of new Hadron Phase
// END_HADRON
ApplySpecPhase::DropInvalidDatabases => {
let mut ctx = ctx.write().await;
let databases = &mut ctx.dbs;
@@ -981,7 +1081,10 @@ async fn get_operations<'a>(
// N.B. this has to be properly dollar-escaped with `pg_quote_dollar()`
role_name = escaped_role,
outer_tag = outer_tag,
),
)
// HADRON change:
.replace("neon_superuser", &params.privileged_role_name),
// HADRON change end ,
comment: None,
},
// This now will only drop privileges of the role
@@ -1017,7 +1120,8 @@ async fn get_operations<'a>(
comment: None,
},
Operation {
query: String::from(include_str!("sql/default_grants.sql")),
query: String::from(include_str!("sql/default_grants.sql"))
.replace("neon_superuser", &params.privileged_role_name),
comment: None,
},
]
@@ -1086,6 +1190,28 @@ async fn get_operations<'a>(
Ok(Box::new(operations))
}
// BEGIN_HADRON
// Note: we may want to version the extension someday, but for now we just drop it and recreate it.
ApplySpecPhase::HandleDatabricksAuthExtension => {
let operations = vec![
Operation {
query: String::from("DROP EXTENSION IF EXISTS databricks_auth"),
comment: Some(String::from("dropping existing databricks_auth extension")),
},
Operation {
query: String::from("CREATE EXTENSION databricks_auth"),
comment: Some(String::from("creating databricks_auth extension")),
},
Operation {
query: String::from("GRANT SELECT ON databricks_auth_metrics TO pg_monitor"),
comment: Some(String::from("grant select on databricks auth counters")),
},
]
.into_iter();
Ok(Box::new(operations))
}
// END_HADRON
ApplySpecPhase::CreateAvailabilityCheck => Ok(Box::new(once(Operation {
query: String::from(include_str!("sql/add_availabilitycheck_tables.sql")),
comment: None,
@@ -1103,6 +1229,63 @@ async fn get_operations<'a>(
Ok(Box::new(operations))
}
// BEGIN_HADRON
// New Hadron phases
//
// Grants permissions to roles that are used by Databricks.
ApplySpecPhase::AddDatabricksGrants => {
let operations = vec![
Operation {
query: String::from("GRANT USAGE ON SCHEMA neon TO databricks_monitor"),
comment: Some(String::from(
"Permissions needed to execute neon.* functions (in the postgres database)",
)),
},
Operation {
query: String::from(
"GRANT SELECT, INSERT, UPDATE ON health_check TO databricks_monitor",
),
comment: Some(String::from("Permissions needed for read and write probes")),
},
Operation {
query: String::from(
"GRANT EXECUTE ON FUNCTION pg_ls_dir(text) TO databricks_monitor",
),
comment: Some(String::from(
"Permissions needed to monitor .snap file counts",
)),
},
Operation {
query: String::from(
"GRANT SELECT ON neon.neon_perf_counters TO databricks_monitor",
),
comment: Some(String::from(
"Permissions needed to access neon performance counters view",
)),
},
Operation {
query: String::from(
"GRANT EXECUTE ON FUNCTION neon.get_perf_counters() TO databricks_monitor",
),
comment: Some(String::from(
"Permissions needed to execute the underlying performance counters function",
)),
},
]
.into_iter();
Ok(Box::new(operations))
}
// Creates minor objects that are used by Databricks.
ApplySpecPhase::CreateDatabricksMisc => Ok(Box::new(once(Operation {
query: String::from(include_str!("sql/create_databricks_misc.sql")),
comment: Some(String::from(
"The function databricks_monitor uses to convert exception to 0 or 1",
)),
}))),
// End of new Hadron phases
// END_HADRON
ApplySpecPhase::FinalizeDropLogicalSubscriptions => Ok(Box::new(once(Operation {
query: String::from(include_str!("sql/finalize_drop_subscriptions.sql")),
comment: None,

View File

@@ -0,0 +1,25 @@
DO $$
DECLARE
reader_role RECORD;
timeout_value TEXT;
BEGIN
-- Get the current GUC setting for reader statement timeout
SELECT current_setting('databricks.reader_statement_timeout', true) INTO timeout_value;
-- Only proceed if timeout_value is not null/empty and not '0' (disabled)
IF timeout_value IS NOT NULL AND timeout_value != '' AND timeout_value != '0' THEN
-- Find all databricks_reader_* roles and update their statement_timeout
FOR reader_role IN
SELECT r.rolname
FROM pg_roles r
WHERE r.rolname ~ '^databricks_reader_\d+$'
LOOP
-- Apply the timeout setting to the role (will overwrite existing setting)
EXECUTE format('ALTER ROLE %I SET statement_timeout = %L',
reader_role.rolname, timeout_value);
RAISE LOG 'Updated statement_timeout = % for role %', timeout_value, reader_role.rolname;
END LOOP;
END IF;
END
$$;

View File

@@ -0,0 +1,15 @@
ALTER ROLE databricks_monitor SET statement_timeout = '60s';
CREATE OR REPLACE FUNCTION health_check_write_succeeds()
RETURNS INTEGER AS $$
BEGIN
INSERT INTO health_check VALUES (1, now())
ON CONFLICT (id) DO UPDATE
SET updated_at = now();
RETURN 1;
EXCEPTION WHEN OTHERS THEN
RAISE EXCEPTION '[DATABRICKS_SMGR] health_check failed: [%] %', SQLSTATE, SQLERRM;
RETURN 0;
END;
$$ LANGUAGE plpgsql;

View File

@@ -793,6 +793,7 @@ impl Endpoint {
autoprewarm: args.autoprewarm,
offload_lfc_interval_seconds: args.offload_lfc_interval_seconds,
suspend_timeout_seconds: -1, // Only used in neon_local.
databricks_settings: None,
};
// this strange code is needed to support respec() in tests
@@ -938,7 +939,8 @@ impl Endpoint {
| ComputeStatus::TerminationPendingFast
| ComputeStatus::TerminationPendingImmediate
| ComputeStatus::Terminated
| ComputeStatus::RefreshConfigurationPending => {
| ComputeStatus::RefreshConfigurationPending
| ComputeStatus::RefreshConfiguration => {
bail!("unexpected compute status: {:?}", state.status)
}
}

View File

@@ -174,6 +174,9 @@ pub enum ComputeStatus {
Terminated,
// A spec refresh is being requested
RefreshConfigurationPending,
// A spec refresh is being applied. We cannot refresh configuration again until the current
// refresh is done, i.e., signal_refresh_configuration() will return 500 error.
RefreshConfiguration,
}
#[derive(Deserialize, Serialize)]
@@ -186,6 +189,10 @@ impl Display for ComputeStatus {
match self {
ComputeStatus::Empty => f.write_str("empty"),
ComputeStatus::ConfigurationPending => f.write_str("configuration-pending"),
ComputeStatus::RefreshConfiguration => f.write_str("refresh-configuration"),
ComputeStatus::RefreshConfigurationPending => {
f.write_str("refresh-configuration-pending")
}
ComputeStatus::Init => f.write_str("init"),
ComputeStatus::Running => f.write_str("running"),
ComputeStatus::Configuration => f.write_str("configuration"),
@@ -195,9 +202,6 @@ impl Display for ComputeStatus {
f.write_str("termination-pending-immediate")
}
ComputeStatus::Terminated => f.write_str("terminated"),
ComputeStatus::RefreshConfigurationPending => {
f.write_str("refresh-configuration-pending")
}
}
}
}

View File

@@ -193,6 +193,9 @@ pub struct ComputeSpec {
///
/// We use this value to derive other values, such as the installed extensions metric.
pub suspend_timeout_seconds: i64,
// Databricks specific options for compute instance.
pub databricks_settings: Option<DatabricksSettings>,
}
/// Feature flag to signal `compute_ctl` to enable certain experimental functionality.

View File

@@ -558,11 +558,11 @@ async fn add_request_id_header_to_response(
mut res: Response<Body>,
req_info: RequestInfo,
) -> Result<Response<Body>, ApiError> {
if let Some(request_id) = req_info.context::<RequestId>() {
if let Ok(request_header_value) = HeaderValue::from_str(&request_id.0) {
res.headers_mut()
.insert(&X_REQUEST_ID_HEADER, request_header_value);
};
if let Some(request_id) = req_info.context::<RequestId>()
&& let Ok(request_header_value) = HeaderValue::from_str(&request_id.0)
{
res.headers_mut()
.insert(&X_REQUEST_ID_HEADER, request_header_value);
};
Ok(res)

View File

@@ -72,10 +72,10 @@ impl Server {
if err.is_incomplete_message() || err.is_closed() || err.is_timeout() {
return true;
}
if let Some(inner) = err.source() {
if let Some(io) = inner.downcast_ref::<std::io::Error>() {
return suppress_io_error(io);
}
if let Some(inner) = err.source()
&& let Some(io) = inner.downcast_ref::<std::io::Error>()
{
return suppress_io_error(io);
}
false
}

View File

@@ -129,6 +129,12 @@ impl<L: LabelGroup> InfoMetric<L> {
}
}
impl<L: LabelGroup + Default> Default for InfoMetric<L, GaugeState> {
fn default() -> Self {
InfoMetric::new(L::default())
}
}
impl<L: LabelGroup, M: MetricType<Metadata = ()>> InfoMetric<L, M> {
pub fn with_metric(label: L, metric: M) -> Self {
Self {

View File

@@ -363,7 +363,7 @@ where
// TODO: An Iterator might be nicer. The communicator's clock algorithm needs to
// _slowly_ iterate through all buckets with its clock hand, without holding a lock.
// If we switch to an Iterator, it must not hold the lock.
pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<(K, V)>> {
pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<'_, (K, V)>> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
if pos >= map.buckets.len() {
return None;

View File

@@ -9,10 +9,7 @@ regex.workspace = true
bytes.workspace = true
anyhow.workspace = true
crc32c.workspace = true
criterion.workspace = true
once_cell.workspace = true
log.workspace = true
memoffset.workspace = true
pprof.workspace = true
thiserror.workspace = true
serde.workspace = true
@@ -22,6 +19,7 @@ tracing.workspace = true
postgres_versioninfo.workspace = true
[dev-dependencies]
criterion.workspace = true
env_logger.workspace = true
postgres.workspace = true

View File

@@ -34,9 +34,8 @@ const SIZEOF_CONTROLDATA: usize = size_of::<ControlFileData>();
impl ControlFileData {
/// Compute the offset of the `crc` field within the `ControlFileData` struct.
/// Equivalent to offsetof(ControlFileData, crc) in C.
// Someday this can be const when the right compiler features land.
fn pg_control_crc_offset() -> usize {
memoffset::offset_of!(ControlFileData, crc)
const fn pg_control_crc_offset() -> usize {
std::mem::offset_of!(ControlFileData, crc)
}
///

View File

@@ -4,12 +4,11 @@
use crate::pg_constants;
use crate::transaction_id_precedes;
use bytes::BytesMut;
use log::*;
use super::bindings::MultiXactId;
pub fn transaction_id_set_status(xid: u32, status: u8, page: &mut BytesMut) {
trace!(
tracing::trace!(
"handle_apply_request for RM_XACT_ID-{} (1-commit, 2-abort, 3-sub_commit)",
status
);

View File

@@ -14,7 +14,6 @@ use super::xlog_utils::*;
use crate::WAL_SEGMENT_SIZE;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use crc32c::*;
use log::*;
use std::cmp::min;
use std::num::NonZeroU32;
use utils::lsn::Lsn;
@@ -236,7 +235,7 @@ impl WalStreamDecoderHandler for WalStreamDecoder {
// XLOG_SWITCH records are special. If we see one, we need to skip
// to the next WAL segment.
let next_lsn = if xlogrec.is_xlog_switch_record() {
trace!("saw xlog switch record at {}", self.lsn);
tracing::trace!("saw xlog switch record at {}", self.lsn);
self.lsn + self.lsn.calc_padding(WAL_SEGMENT_SIZE as u64)
} else {
// Pad to an 8-byte boundary

View File

@@ -23,8 +23,6 @@ use crate::{WAL_SEGMENT_SIZE, XLOG_BLCKSZ};
use bytes::BytesMut;
use bytes::{Buf, Bytes};
use log::*;
use serde::Serialize;
use std::ffi::{CString, OsStr};
use std::fs::File;
@@ -235,7 +233,7 @@ pub fn find_end_of_wal(
let mut curr_lsn = start_lsn;
let mut buf = [0u8; XLOG_BLCKSZ];
let pg_version = MY_PGVERSION;
debug!("find_end_of_wal PG_VERSION: {}", pg_version);
tracing::debug!("find_end_of_wal PG_VERSION: {}", pg_version);
let mut decoder = WalStreamDecoder::new(start_lsn, pg_version);
@@ -247,7 +245,7 @@ pub fn find_end_of_wal(
match open_wal_segment(&seg_file_path)? {
None => {
// no more segments
debug!(
tracing::debug!(
"find_end_of_wal reached end at {:?}, segment {:?} doesn't exist",
result, seg_file_path
);
@@ -260,7 +258,7 @@ pub fn find_end_of_wal(
while curr_lsn.segment_number(wal_seg_size) == segno {
let bytes_read = segment.read(&mut buf)?;
if bytes_read == 0 {
debug!(
tracing::debug!(
"find_end_of_wal reached end at {:?}, EOF in segment {:?} at offset {}",
result,
seg_file_path,
@@ -276,7 +274,7 @@ pub fn find_end_of_wal(
match decoder.poll_decode() {
Ok(Some(record)) => result = record.0,
Err(e) => {
debug!(
tracing::debug!(
"find_end_of_wal reached end at {:?}, decode error: {:?}",
result, e
);

View File

@@ -185,6 +185,7 @@ impl Client {
ssl_mode: SslMode,
process_id: i32,
secret_key: i32,
write_buf: BytesMut,
) -> Client {
Client {
inner: InnerClient {
@@ -195,7 +196,7 @@ impl Client {
waiting: 0,
received: 0,
},
buffer: Default::default(),
buffer: write_buf,
},
cached_typeinfo: Default::default(),

View File

@@ -47,14 +47,7 @@ impl Encoder<BytesMut> for PostgresCodec {
type Error = io::Error;
fn encode(&mut self, item: BytesMut, dst: &mut BytesMut) -> io::Result<()> {
// When it comes to request/response workflows, we usually flush the entire write
// buffer in order to wait for the response before we send a new request.
// Therefore we can avoid the copy and just replace the buffer.
if dst.is_empty() {
*dst = item;
} else {
dst.extend_from_slice(&item);
}
dst.unsplit(item);
Ok(())
}
}

View File

@@ -77,6 +77,9 @@ where
connect_timeout,
};
let mut stream = stream.into_framed();
let write_buf = std::mem::take(stream.write_buffer_mut());
let (client_tx, conn_rx) = mpsc::unbounded_channel();
let (conn_tx, client_rx) = mpsc::channel(4);
let client = Client::new(
@@ -86,9 +89,9 @@ where
ssl_mode,
process_id,
secret_key,
write_buf,
);
let stream = stream.into_framed();
let connection = Connection::new(stream, conn_tx, conn_rx);
Ok((client, connection))

View File

@@ -229,8 +229,11 @@ where
Poll::Ready(()) => {
trace!("poll_flush: flushed");
// GC the write buffer if we managed to flush
gc_bytesmut(self.stream.write_buffer_mut());
// Since our codec prefers to share the buffer with the `Client`,
// if we don't release our share, then the `Client` would have to re-alloc
// the buffer when they next use it.
debug_assert!(self.stream.write_buffer().is_empty());
*self.stream.write_buffer_mut() = BytesMut::new();
Poll::Ready(Ok(()))
}

View File

@@ -9,7 +9,7 @@ use postgres_protocol2::message::backend::{ErrorFields, ErrorResponseBody};
pub use self::sqlstate::*;
#[allow(clippy::unreadable_literal)]
mod sqlstate;
pub mod sqlstate;
/// The severity of a Postgres error or notice.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]

View File

@@ -49,7 +49,7 @@ impl PerfSpan {
}
}
pub fn enter(&self) -> PerfSpanEntered {
pub fn enter(&self) -> PerfSpanEntered<'_> {
if let Some(ref id) = self.inner.id() {
self.dispatch.enter(id);
}

View File

@@ -14,9 +14,9 @@ use utils::logging::warn_slow;
use crate::pool::{ChannelPool, ClientGuard, ClientPool, StreamGuard, StreamPool};
use crate::retry::Retry;
use crate::split::GetPageSplitter;
use compute_api::spec::PageserverProtocol;
use pageserver_page_api as page_api;
use pageserver_page_api::GetPageSplitter;
use utils::id::{TenantId, TimelineId};
use utils::shard::{ShardCount, ShardIndex, ShardNumber, ShardStripeSize};
@@ -230,16 +230,14 @@ impl PageserverClient {
) -> tonic::Result<page_api::GetPageResponse> {
// Fast path: request is for a single shard.
if let Some(shard_id) =
GetPageSplitter::for_single_shard(&req, shards.count, shards.stripe_size)
.map_err(|err| tonic::Status::internal(err.to_string()))?
GetPageSplitter::for_single_shard(&req, shards.count, shards.stripe_size)?
{
return Self::get_page_with_shard(req, shards.get(shard_id)?).await;
}
// Request spans multiple shards. Split it, dispatch concurrent per-shard requests, and
// reassemble the responses.
let mut splitter = GetPageSplitter::split(req, shards.count, shards.stripe_size)
.map_err(|err| tonic::Status::internal(err.to_string()))?;
let mut splitter = GetPageSplitter::split(req, shards.count, shards.stripe_size)?;
let mut shard_requests = FuturesUnordered::new();
for (shard_id, shard_req) in splitter.drain_requests() {
@@ -249,14 +247,10 @@ impl PageserverClient {
}
while let Some((shard_id, shard_response)) = shard_requests.next().await.transpose()? {
splitter
.add_response(shard_id, shard_response)
.map_err(|err| tonic::Status::internal(err.to_string()))?;
splitter.add_response(shard_id, shard_response)?;
}
splitter
.get_response()
.map_err(|err| tonic::Status::internal(err.to_string()))
Ok(splitter.collect_response()?)
}
/// Fetches pages on the given shard. Does not retry internally.

View File

@@ -1,6 +1,5 @@
mod client;
mod pool;
mod retry;
mod split;
pub use client::{PageserverClient, ShardSpec};

View File

@@ -19,7 +19,9 @@ pub mod proto {
}
mod client;
pub use client::Client;
mod model;
mod split;
pub use client::Client;
pub use model::*;
pub use split::{GetPageSplitter, SplitError};

View File

@@ -1,20 +1,19 @@
use std::collections::HashMap;
use anyhow::anyhow;
use bytes::Bytes;
use crate::model::*;
use pageserver_api::key::rel_block_to_key;
use pageserver_api::shard::key_to_shard_number;
use pageserver_page_api as page_api;
use utils::shard::{ShardCount, ShardIndex, ShardStripeSize};
/// Splits GetPageRequests that straddle shard boundaries and assembles the responses.
/// TODO: add tests for this.
pub struct GetPageSplitter {
/// Split requests by shard index.
requests: HashMap<ShardIndex, page_api::GetPageRequest>,
requests: HashMap<ShardIndex, GetPageRequest>,
/// The response being assembled. Preallocated with empty pages, to be filled in.
response: page_api::GetPageResponse,
response: GetPageResponse,
/// Maps the offset in `request.block_numbers` and `response.pages` to the owning shard. Used
/// to assemble the response pages in the same order as the original request.
block_shards: Vec<ShardIndex>,
@@ -24,22 +23,22 @@ impl GetPageSplitter {
/// Checks if the given request only touches a single shard, and returns the shard ID. This is
/// the common case, so we check first in order to avoid unnecessary allocations and overhead.
pub fn for_single_shard(
req: &page_api::GetPageRequest,
req: &GetPageRequest,
count: ShardCount,
stripe_size: Option<ShardStripeSize>,
) -> anyhow::Result<Option<ShardIndex>> {
) -> Result<Option<ShardIndex>, SplitError> {
// Fast path: unsharded tenant.
if count.is_unsharded() {
return Ok(Some(ShardIndex::unsharded()));
}
let Some(stripe_size) = stripe_size else {
return Err(anyhow!("stripe size must be given for sharded tenants"));
return Err("stripe size must be given for sharded tenants".into());
};
// Find the first page's shard, for comparison.
let Some(&first_page) = req.block_numbers.first() else {
return Err(anyhow!("no block numbers in request"));
return Err("no block numbers in request".into());
};
let key = rel_block_to_key(req.rel, first_page);
let shard_number = key_to_shard_number(count, stripe_size, &key);
@@ -57,10 +56,10 @@ impl GetPageSplitter {
/// Splits the given request.
pub fn split(
req: page_api::GetPageRequest,
req: GetPageRequest,
count: ShardCount,
stripe_size: Option<ShardStripeSize>,
) -> anyhow::Result<Self> {
) -> Result<Self, SplitError> {
// The caller should make sure we don't split requests unnecessarily.
debug_assert!(
Self::for_single_shard(&req, count, stripe_size)?.is_none(),
@@ -68,10 +67,10 @@ impl GetPageSplitter {
);
if count.is_unsharded() {
return Err(anyhow!("unsharded tenant, no point in splitting request"));
return Err("unsharded tenant, no point in splitting request".into());
}
let Some(stripe_size) = stripe_size else {
return Err(anyhow!("stripe size must be given for sharded tenants"));
return Err("stripe size must be given for sharded tenants".into());
};
// Split the requests by shard index.
@@ -84,7 +83,7 @@ impl GetPageSplitter {
requests
.entry(shard_id)
.or_insert_with(|| page_api::GetPageRequest {
.or_insert_with(|| GetPageRequest {
request_id: req.request_id,
request_class: req.request_class,
rel: req.rel,
@@ -98,16 +97,16 @@ impl GetPageSplitter {
// Construct a response to be populated by shard responses. Preallocate empty page slots
// with the expected block numbers.
let response = page_api::GetPageResponse {
let response = GetPageResponse {
request_id: req.request_id,
status_code: page_api::GetPageStatusCode::Ok,
status_code: GetPageStatusCode::Ok,
reason: None,
rel: req.rel,
pages: req
.block_numbers
.into_iter()
.map(|block_number| {
page_api::Page {
Page {
block_number,
image: Bytes::new(), // empty page slot to be filled in
}
@@ -123,43 +122,38 @@ impl GetPageSplitter {
}
/// Drains the per-shard requests, moving them out of the splitter to avoid extra allocations.
pub fn drain_requests(
&mut self,
) -> impl Iterator<Item = (ShardIndex, page_api::GetPageRequest)> {
pub fn drain_requests(&mut self) -> impl Iterator<Item = (ShardIndex, GetPageRequest)> {
self.requests.drain()
}
/// Adds a response from the given shard. The response must match the request ID and have an OK
/// status code. A response must not already exist for the given shard ID.
#[allow(clippy::result_large_err)]
pub fn add_response(
&mut self,
shard_id: ShardIndex,
response: page_api::GetPageResponse,
) -> anyhow::Result<()> {
response: GetPageResponse,
) -> Result<(), SplitError> {
// The caller should already have converted status codes into tonic::Status.
if response.status_code != page_api::GetPageStatusCode::Ok {
return Err(anyhow!(
if response.status_code != GetPageStatusCode::Ok {
return Err(SplitError(format!(
"unexpected non-OK response for shard {shard_id}: {} {}",
response.status_code,
response.reason.unwrap_or_default()
));
)));
}
if response.request_id != self.response.request_id {
return Err(anyhow!(
return Err(SplitError(format!(
"response ID mismatch for shard {shard_id}: expected {}, got {}",
self.response.request_id,
response.request_id
));
self.response.request_id, response.request_id
)));
}
if response.request_id != self.response.request_id {
return Err(anyhow!(
return Err(SplitError(format!(
"response ID mismatch for shard {shard_id}: expected {}, got {}",
self.response.request_id,
response.request_id
));
self.response.request_id, response.request_id
)));
}
// Place the shard response pages into the assembled response, in request order.
@@ -171,26 +165,27 @@ impl GetPageSplitter {
}
let Some(slot) = self.response.pages.get_mut(i) else {
return Err(anyhow!("no block_shards slot {i} for shard {shard_id}"));
return Err(SplitError(format!(
"no block_shards slot {i} for shard {shard_id}"
)));
};
let Some(page) = pages.next() else {
return Err(anyhow!(
return Err(SplitError(format!(
"missing page {} in shard {shard_id} response",
slot.block_number
));
)));
};
if page.block_number != slot.block_number {
return Err(anyhow!(
return Err(SplitError(format!(
"shard {shard_id} returned wrong page at index {i}, expected {} got {}",
slot.block_number,
page.block_number
));
slot.block_number, page.block_number
)));
}
if !slot.image.is_empty() {
return Err(anyhow!(
return Err(SplitError(format!(
"shard {shard_id} returned duplicate page {} at index {i}",
slot.block_number
));
)));
}
*slot = page;
@@ -198,32 +193,54 @@ impl GetPageSplitter {
// Make sure we've consumed all pages from the shard response.
if let Some(extra_page) = pages.next() {
return Err(anyhow!(
return Err(SplitError(format!(
"shard {shard_id} returned extra page: {}",
extra_page.block_number
));
)));
}
Ok(())
}
/// Fetches the final, assembled response.
#[allow(clippy::result_large_err)]
pub fn get_response(self) -> anyhow::Result<page_api::GetPageResponse> {
/// Collects the final, assembled response.
pub fn collect_response(self) -> Result<GetPageResponse, SplitError> {
// Check that the response is complete.
for (i, page) in self.response.pages.iter().enumerate() {
if page.image.is_empty() {
return Err(anyhow!(
return Err(SplitError(format!(
"missing page {} for shard {}",
page.block_number,
self.block_shards
.get(i)
.map(|s| s.to_string())
.unwrap_or_else(|| "?".to_string())
));
)));
}
}
Ok(self.response)
}
}
/// A GetPageSplitter error.
#[derive(Debug, thiserror::Error)]
#[error("{0}")]
pub struct SplitError(String);
impl From<&str> for SplitError {
fn from(err: &str) -> Self {
SplitError(err.to_string())
}
}
impl From<String> for SplitError {
fn from(err: String) -> Self {
SplitError(err)
}
}
impl From<SplitError> for tonic::Status {
fn from(err: SplitError) -> Self {
tonic::Status::internal(err.0)
}
}

View File

@@ -715,7 +715,7 @@ fn start_pageserver(
disk_usage_eviction_state,
deletion_queue.new_client(),
secondary_controller,
feature_resolver,
feature_resolver.clone(),
)
.context("Failed to initialize router state")?,
);
@@ -841,14 +841,14 @@ fn start_pageserver(
} else {
None
},
feature_resolver.clone(),
);
// Spawn a Pageserver gRPC server task. It will spawn separate tasks for
// each stream/request.
// Spawn a Pageserver gRPC server task. It will spawn separate tasks for each request/stream.
// It uses a separate compute request Tokio runtime (COMPUTE_REQUEST_RUNTIME).
//
// TODO: this uses a separate Tokio runtime for the page service. If we want
// other gRPC services, they will need their own port and runtime. Is this
// necessary?
// NB: this port is exposed to computes. It should only provide services that we're okay with
// computes accessing. Internal services should use a separate port.
let mut page_service_grpc = None;
if let Some(grpc_listener) = grpc_listener {
page_service_grpc = Some(GrpcPageServiceHandler::spawn(

View File

@@ -2005,6 +2005,10 @@ async fn put_tenant_location_config_handler(
let state = get_state(&request);
let conf = state.conf;
fail::fail_point!("put-location-conf-handler", |_| {
Err(ApiError::ResourceUnavailable("failpoint".into()))
});
// The `Detached` state is special, it doesn't upsert a tenant, it removes
// its local disk content and drops it from memory.
if let LocationConfigMode::Detached = request_data.config.mode {

View File

@@ -16,7 +16,8 @@ use anyhow::{Context as _, bail};
use bytes::{Buf as _, BufMut as _, BytesMut};
use chrono::Utc;
use futures::future::BoxFuture;
use futures::{FutureExt, Stream};
use futures::stream::FuturesUnordered;
use futures::{FutureExt, Stream, StreamExt as _};
use itertools::Itertools;
use jsonwebtoken::TokenData;
use once_cell::sync::OnceCell;
@@ -35,8 +36,8 @@ use pageserver_api::pagestream_api::{
};
use pageserver_api::reltag::SlruKind;
use pageserver_api::shard::TenantShardId;
use pageserver_page_api as page_api;
use pageserver_page_api::proto;
use pageserver_page_api::{self as page_api, GetPageSplitter};
use postgres_backend::{
AuthType, PostgresBackend, PostgresBackendReader, QueryError, is_expected_io_error,
};
@@ -68,6 +69,7 @@ use crate::config::PageServerConf;
use crate::context::{
DownloadBehavior, PerfInstrumentFutureExt, RequestContext, RequestContextBuilder,
};
use crate::feature_resolver::FeatureResolver;
use crate::metrics::{
self, COMPUTE_COMMANDS_COUNTERS, ComputeCommandKind, GetPageBatchBreakReason, LIVE_CONNECTIONS,
MISROUTED_PAGESTREAM_REQUESTS, PAGESTREAM_HANDLER_RESULTS_TOTAL, SmgrOpTimer, TimelineMetrics,
@@ -139,6 +141,7 @@ pub fn spawn(
perf_trace_dispatch: Option<Dispatch>,
tcp_listener: tokio::net::TcpListener,
tls_config: Option<Arc<rustls::ServerConfig>>,
feature_resolver: FeatureResolver,
) -> Listener {
let cancel = CancellationToken::new();
let libpq_ctx = RequestContext::todo_child(
@@ -160,6 +163,7 @@ pub fn spawn(
conf.pg_auth_type,
tls_config,
conf.page_service_pipelining.clone(),
feature_resolver,
libpq_ctx,
cancel.clone(),
)
@@ -218,6 +222,7 @@ pub async fn libpq_listener_main(
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
pipelining_config: PageServicePipeliningConfig,
feature_resolver: FeatureResolver,
listener_ctx: RequestContext,
listener_cancel: CancellationToken,
) -> Connections {
@@ -261,6 +266,7 @@ pub async fn libpq_listener_main(
auth_type,
tls_config.clone(),
pipelining_config.clone(),
feature_resolver.clone(),
connection_ctx,
connections_cancel.child_token(),
gate_guard,
@@ -303,6 +309,7 @@ async fn page_service_conn_main(
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
pipelining_config: PageServicePipeliningConfig,
feature_resolver: FeatureResolver,
connection_ctx: RequestContext,
cancel: CancellationToken,
gate_guard: GateGuard,
@@ -370,6 +377,7 @@ async fn page_service_conn_main(
perf_span_fields,
connection_ctx,
cancel.clone(),
feature_resolver.clone(),
gate_guard,
);
let pgbackend =
@@ -421,6 +429,8 @@ struct PageServerHandler {
pipelining_config: PageServicePipeliningConfig,
get_vectored_concurrent_io: GetVectoredConcurrentIo,
feature_resolver: FeatureResolver,
gate_guard: GateGuard,
}
@@ -457,13 +467,6 @@ impl TimelineHandles {
self.handles
.get(timeline_id, shard_selector, &self.wrapper)
.await
.map_err(|e| match e {
timeline::handle::GetError::TenantManager(e) => e,
timeline::handle::GetError::PerTimelineStateShutDown => {
trace!("per-timeline state shut down");
GetActiveTimelineError::Timeline(GetTimelineError::ShuttingDown)
}
})
}
fn tenant_id(&self) -> Option<TenantId> {
@@ -479,11 +482,9 @@ pub(crate) struct TenantManagerWrapper {
tenant_id: once_cell::sync::OnceCell<TenantId>,
}
#[derive(Debug)]
pub(crate) struct TenantManagerTypes;
impl timeline::handle::Types for TenantManagerTypes {
type TenantManagerError = GetActiveTimelineError;
type TenantManager = TenantManagerWrapper;
type Timeline = TenantManagerCacheItem;
}
@@ -535,6 +536,7 @@ impl timeline::handle::TenantManager<TenantManagerTypes> for TenantManagerWrappe
match resolved {
ShardResolveResult::Found(tenant_shard) => break tenant_shard,
ShardResolveResult::NotFound => {
MISROUTED_PAGESTREAM_REQUESTS.inc();
return Err(GetActiveTimelineError::Tenant(
GetActiveTenantError::NotFound(GetTenantError::NotFound(*tenant_id)),
));
@@ -586,6 +588,15 @@ impl timeline::handle::TenantManager<TenantManagerTypes> for TenantManagerWrappe
}
}
/// Whether to hold the applied GC cutoff guard when processing GetPage requests.
/// This is determined once at the start of pagestream subprotocol handling based on
/// feature flags, configuration, and test conditions.
#[derive(Debug, Clone, Copy)]
enum HoldAppliedGcCutoffGuard {
Yes,
No,
}
#[derive(thiserror::Error, Debug)]
enum PageStreamError {
/// We encountered an error that should prompt the client to reconnect:
@@ -729,6 +740,7 @@ enum BatchedFeMessage {
GetPage {
span: Span,
shard: WeakHandle<TenantManagerTypes>,
applied_gc_cutoff_guard: Option<RcuReadGuard<Lsn>>,
pages: SmallVec<[BatchedGetPageRequest; 1]>,
batch_break_reason: GetPageBatchBreakReason,
},
@@ -908,6 +920,7 @@ impl PageServerHandler {
perf_span_fields: ConnectionPerfSpanFields,
connection_ctx: RequestContext,
cancel: CancellationToken,
feature_resolver: FeatureResolver,
gate_guard: GateGuard,
) -> Self {
PageServerHandler {
@@ -919,6 +932,7 @@ impl PageServerHandler {
cancel,
pipelining_config,
get_vectored_concurrent_io,
feature_resolver,
gate_guard,
}
}
@@ -958,6 +972,7 @@ impl PageServerHandler {
ctx: &RequestContext,
protocol_version: PagestreamProtocolVersion,
parent_span: Span,
hold_gc_cutoff_guard: HoldAppliedGcCutoffGuard,
) -> Result<Option<BatchedFeMessage>, QueryError>
where
IO: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static,
@@ -1195,19 +1210,27 @@ impl PageServerHandler {
})
.await?;
let applied_gc_cutoff_guard = shard.get_applied_gc_cutoff_lsn(); // hold guard
// We're holding the Handle
let effective_lsn = match Self::effective_request_lsn(
&shard,
shard.get_last_record_lsn(),
req.hdr.request_lsn,
req.hdr.not_modified_since,
&shard.get_applied_gc_cutoff_lsn(),
&applied_gc_cutoff_guard,
) {
Ok(lsn) => lsn,
Err(e) => {
return respond_error!(span, e);
}
};
let applied_gc_cutoff_guard = match hold_gc_cutoff_guard {
HoldAppliedGcCutoffGuard::Yes => Some(applied_gc_cutoff_guard),
HoldAppliedGcCutoffGuard::No => {
drop(applied_gc_cutoff_guard);
None
}
};
let batch_wait_ctx = if ctx.has_perf_span() {
Some(
@@ -1228,6 +1251,7 @@ impl PageServerHandler {
BatchedFeMessage::GetPage {
span,
shard: shard.downgrade(),
applied_gc_cutoff_guard,
pages: smallvec![BatchedGetPageRequest {
req,
timer,
@@ -1328,13 +1352,28 @@ impl PageServerHandler {
match (eligible_batch, this_msg) {
(
BatchedFeMessage::GetPage {
pages: accum_pages, ..
pages: accum_pages,
applied_gc_cutoff_guard: accum_applied_gc_cutoff_guard,
..
},
BatchedFeMessage::GetPage {
pages: this_pages, ..
pages: this_pages,
applied_gc_cutoff_guard: this_applied_gc_cutoff_guard,
..
},
) => {
accum_pages.extend(this_pages);
// the minimum of the two guards will keep data for both alive
match (&accum_applied_gc_cutoff_guard, this_applied_gc_cutoff_guard) {
(None, None) => (),
(None, Some(this)) => *accum_applied_gc_cutoff_guard = Some(this),
(Some(_), None) => (),
(Some(accum), Some(this)) => {
if **accum > *this {
*accum_applied_gc_cutoff_guard = Some(this);
}
}
};
Ok(())
}
#[cfg(feature = "testing")]
@@ -1649,6 +1688,7 @@ impl PageServerHandler {
BatchedFeMessage::GetPage {
span,
shard,
applied_gc_cutoff_guard,
pages,
batch_break_reason,
} => {
@@ -1668,6 +1708,7 @@ impl PageServerHandler {
.instrument(span.clone())
.await;
assert_eq!(res.len(), npages);
drop(applied_gc_cutoff_guard);
res
},
span,
@@ -1749,7 +1790,7 @@ impl PageServerHandler {
/// Coding discipline within this function: all interaction with the `pgb` connection
/// needs to be sensitive to connection shutdown, currently signalled via [`Self::cancel`].
/// This is so that we can shutdown page_service quickly.
#[instrument(skip_all)]
#[instrument(skip_all, fields(hold_gc_cutoff_guard))]
async fn handle_pagerequests<IO>(
&mut self,
pgb: &mut PostgresBackend<IO>,
@@ -1795,6 +1836,30 @@ impl PageServerHandler {
.take()
.expect("implementation error: timeline_handles should not be locked");
// Evaluate the expensive feature resolver check once per pagestream subprotocol handling
// instead of once per GetPage request. This is shared between pipelined and serial paths.
let hold_gc_cutoff_guard = if cfg!(test) || cfg!(feature = "testing") {
HoldAppliedGcCutoffGuard::Yes
} else {
// Use the global feature resolver with the tenant ID directly, avoiding the need
// to get a timeline/shard which might not be available on this pageserver node.
let empty_properties = std::collections::HashMap::new();
match self.feature_resolver.evaluate_boolean(
"page-service-getpage-hold-applied-gc-cutoff-guard",
tenant_id,
&empty_properties,
) {
Ok(()) => HoldAppliedGcCutoffGuard::Yes,
Err(_) => HoldAppliedGcCutoffGuard::No,
}
};
// record it in the span of handle_pagerequests so that both the request_span
// and the pipeline implementation spans contains the field.
Span::current().record(
"hold_gc_cutoff_guard",
tracing::field::debug(&hold_gc_cutoff_guard),
);
let request_span = info_span!("request");
let ((pgb_reader, timeline_handles), result) = match self.pipelining_config.clone() {
PageServicePipeliningConfig::Pipelined(pipelining_config) => {
@@ -1808,6 +1873,7 @@ impl PageServerHandler {
pipelining_config,
protocol_version,
io_concurrency,
hold_gc_cutoff_guard,
&ctx,
)
.await
@@ -1822,6 +1888,7 @@ impl PageServerHandler {
request_span,
protocol_version,
io_concurrency,
hold_gc_cutoff_guard,
&ctx,
)
.await
@@ -1850,6 +1917,7 @@ impl PageServerHandler {
request_span: Span,
protocol_version: PagestreamProtocolVersion,
io_concurrency: IoConcurrency,
hold_gc_cutoff_guard: HoldAppliedGcCutoffGuard,
ctx: &RequestContext,
) -> (
(PostgresBackendReader<IO>, TimelineHandles),
@@ -1871,6 +1939,7 @@ impl PageServerHandler {
ctx,
protocol_version,
request_span.clone(),
hold_gc_cutoff_guard,
)
.await;
let msg = match msg {
@@ -1918,6 +1987,7 @@ impl PageServerHandler {
pipelining_config: PageServicePipeliningConfigPipelined,
protocol_version: PagestreamProtocolVersion,
io_concurrency: IoConcurrency,
hold_gc_cutoff_guard: HoldAppliedGcCutoffGuard,
ctx: &RequestContext,
) -> (
(PostgresBackendReader<IO>, TimelineHandles),
@@ -2021,6 +2091,7 @@ impl PageServerHandler {
&ctx,
protocol_version,
request_span.clone(),
hold_gc_cutoff_guard,
)
.await;
let Some(read_res) = read_res.transpose() else {
@@ -2067,6 +2138,7 @@ impl PageServerHandler {
pages,
span: _,
shard: _,
applied_gc_cutoff_guard: _,
batch_break_reason: _,
} = &mut batch
{
@@ -3352,18 +3424,6 @@ impl GrpcPageServiceHandler {
Ok(CancellableTask { task, cancel })
}
/// Errors if the request is executed on a non-zero shard. Only shard 0 has a complete view of
/// relations and their sizes, as well as SLRU segments and similar data.
#[allow(clippy::result_large_err)]
fn ensure_shard_zero(timeline: &Handle<TenantManagerTypes>) -> Result<(), tonic::Status> {
match timeline.get_shard_index().shard_number.0 {
0 => Ok(()),
shard => Err(tonic::Status::invalid_argument(format!(
"request must execute on shard zero (is shard {shard})",
))),
}
}
/// Generates a PagestreamRequest header from a ReadLsn and request ID.
fn make_hdr(
read_lsn: page_api::ReadLsn,
@@ -3378,30 +3438,72 @@ impl GrpcPageServiceHandler {
}
}
/// Acquires a timeline handle for the given request.
/// Acquires a timeline handle for the given request. The shard index must match a local shard.
///
/// TODO: during shard splits, the compute may still be sending requests to the parent shard
/// until the entire split is committed and the compute is notified. Consider installing a
/// temporary shard router from the parent to the children while the split is in progress.
///
/// TODO: consider moving this to a middleware layer; all requests need it. Needs to manage
/// the TimelineHandles lifecycle.
///
/// TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to avoid
/// the unnecessary overhead.
/// NB: this will fail during shard splits, see comment on [`Self::maybe_split_get_page`].
async fn get_request_timeline(
&self,
req: &tonic::Request<impl Any>,
) -> Result<Handle<TenantManagerTypes>, GetActiveTimelineError> {
let ttid = *extract::<TenantTimelineId>(req);
let TenantTimelineId {
tenant_id,
timeline_id,
} = *extract::<TenantTimelineId>(req);
let shard_index = *extract::<ShardIndex>(req);
let shard_selector = ShardSelector::Known(shard_index);
// TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to
// avoid the unnecessary overhead.
TimelineHandles::new(self.tenant_manager.clone())
.get(ttid.tenant_id, ttid.timeline_id, shard_selector)
.get(tenant_id, timeline_id, ShardSelector::Known(shard_index))
.await
}
/// Acquires a timeline handle for the given request, which must be for shard zero. Most
/// metadata requests are only valid on shard zero.
///
/// NB: during an ongoing shard split, the compute will keep talking to the parent shard until
/// the split is committed, but the parent shard may have been removed in the meanwhile. In that
/// case, we reroute the request to the new child shard. See [`Self::maybe_split_get_page`].
///
/// TODO: revamp the split protocol to avoid this child routing.
async fn get_request_timeline_shard_zero(
&self,
req: &tonic::Request<impl Any>,
) -> Result<Handle<TenantManagerTypes>, tonic::Status> {
let TenantTimelineId {
tenant_id,
timeline_id,
} = *extract::<TenantTimelineId>(req);
let shard_index = *extract::<ShardIndex>(req);
if shard_index.shard_number.0 != 0 {
return Err(tonic::Status::invalid_argument(format!(
"request only valid on shard zero (requested shard {shard_index})",
)));
}
// TODO: untangle acquisition from TenantManagerWrapper::resolve() and Cache::get(), to
// avoid the unnecessary overhead.
let mut handles = TimelineHandles::new(self.tenant_manager.clone());
match handles
.get(tenant_id, timeline_id, ShardSelector::Known(shard_index))
.await
{
Ok(timeline) => Ok(timeline),
Err(err) => {
// We may be in the middle of a shard split. Try to find a child shard 0.
if let Ok(timeline) = handles
.get(tenant_id, timeline_id, ShardSelector::Zero)
.await
&& timeline.get_shard_index().shard_count > shard_index.shard_count
{
return Ok(timeline);
}
Err(err.into())
}
}
}
/// Starts a SmgrOpTimer at received_at, throttles the request, and records execution start.
/// Only errors if the timeline is shutting down.
///
@@ -3428,32 +3530,37 @@ impl GrpcPageServiceHandler {
/// NB: errors returned from here are intercepted in get_pages(), and may be converted to a
/// GetPageResponse with an appropriate status code to avoid terminating the stream.
///
/// TODO: verify that the requested pages belong to this shard.
///
/// TODO: get_vectored() currently enforces a batch limit of 32. Postgres will typically send
/// batches up to effective_io_concurrency = 100. Either we have to accept large batches, or
/// split them up in the client or server.
#[instrument(skip_all, fields(req_id, rel, blkno, blks, req_lsn, mod_lsn))]
#[instrument(skip_all, fields(
req_id = %req.request_id,
rel = %req.rel,
blkno = %req.block_numbers[0],
blks = %req.block_numbers.len(),
lsn = %req.read_lsn,
))]
async fn get_page(
ctx: &RequestContext,
timeline: &WeakHandle<TenantManagerTypes>,
req: proto::GetPageRequest,
timeline: Handle<TenantManagerTypes>,
req: page_api::GetPageRequest,
io_concurrency: IoConcurrency,
) -> Result<proto::GetPageResponse, tonic::Status> {
let received_at = Instant::now();
let timeline = timeline.upgrade()?;
received_at: Instant,
) -> Result<page_api::GetPageResponse, tonic::Status> {
let ctx = ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
let req = page_api::GetPageRequest::try_from(req)?;
span_record!(
req_id = %req.request_id,
rel = %req.rel,
blkno = %req.block_numbers[0],
blks = %req.block_numbers.len(),
lsn = %req.read_lsn,
);
for &blkno in &req.block_numbers {
let shard = timeline.get_shard_identity();
let key = rel_block_to_key(req.rel, blkno);
if !shard.is_key_local(&key) {
return Err(tonic::Status::invalid_argument(format!(
"block {blkno} of relation {} requested on wrong shard {} (is on {})",
req.rel,
timeline.get_shard_index(),
ShardIndex::new(shard.get_shard_number(&key), shard.count),
)));
}
}
let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn(); // hold guard
let effective_lsn = PageServerHandler::effective_request_lsn(
@@ -3529,7 +3636,89 @@ impl GrpcPageServiceHandler {
};
}
Ok(resp.into())
Ok(resp)
}
/// Processes a GetPage request when there is a potential shard split in progress. We have to
/// reroute the request to any local child shards, and split batch requests that straddle
/// multiple child shards.
///
/// Parent shards are split and removed incrementally (there may be many parent shards when
/// splitting an already-sharded tenant), but the compute is only notified once the overall
/// split commits, which can take several minutes. In the meanwhile, the compute will be sending
/// requests to the parent shards.
///
/// TODO: add test infrastructure to provoke this situation frequently and for long periods of
/// time, to properly exercise it.
///
/// TODO: revamp the split protocol to avoid this, e.g.:
/// * Keep the parent shard until the split commits and the compute is notified.
/// * Notify the compute about each subsplit.
/// * Return an error that updates the compute's shard map.
#[instrument(skip_all)]
#[allow(clippy::too_many_arguments)]
async fn maybe_split_get_page(
ctx: &RequestContext,
handles: &mut TimelineHandles,
tenant_id: TenantId,
timeline_id: TimelineId,
parent: ShardIndex,
req: page_api::GetPageRequest,
io_concurrency: IoConcurrency,
received_at: Instant,
) -> Result<page_api::GetPageResponse, tonic::Status> {
// Check the first page to see if we have any child shards at all. Otherwise, the compute is
// just talking to the wrong Pageserver. If the parent has been split, the shard now owning
// the page must have a higher shard count.
let timeline = handles
.get(
tenant_id,
timeline_id,
ShardSelector::Page(rel_block_to_key(req.rel, req.block_numbers[0])),
)
.await?;
let shard_id = timeline.get_shard_identity();
if shard_id.count <= parent.shard_count {
return Err(HandleUpgradeError::ShutDown.into()); // emulate original error
}
// Fast path: the request fits in a single shard.
if let Some(shard_index) =
GetPageSplitter::for_single_shard(&req, shard_id.count, Some(shard_id.stripe_size))?
{
// We got the shard ID from the first page, so these must be equal.
assert_eq!(shard_index.shard_number, shard_id.number);
assert_eq!(shard_index.shard_count, shard_id.count);
return Self::get_page(ctx, timeline, req, io_concurrency, received_at).await;
}
// The request spans multiple shards; split it and dispatch parallel requests. All pages
// were originally in the parent shard, and during a split all children are local, so we
// expect to find local shards for all pages.
let mut splitter = GetPageSplitter::split(req, shard_id.count, Some(shard_id.stripe_size))?;
let mut shard_requests = FuturesUnordered::new();
for (shard_index, shard_req) in splitter.drain_requests() {
let timeline = handles
.get(tenant_id, timeline_id, ShardSelector::Known(shard_index))
.await?;
let future = Self::get_page(
ctx,
timeline,
shard_req,
io_concurrency.clone(),
received_at,
)
.map(move |result| result.map(|resp| (shard_index, resp)));
shard_requests.push(future);
}
while let Some((shard_index, shard_response)) = shard_requests.next().await.transpose()? {
splitter.add_response(shard_index, shard_response)?;
}
Ok(splitter.collect_response()?)
}
}
@@ -3558,11 +3747,10 @@ impl proto::PageService for GrpcPageServiceHandler {
// to be the sweet spot where throughput is saturated.
const CHUNK_SIZE: usize = 256 * 1024;
let timeline = self.get_request_timeline(&req).await?;
let timeline = self.get_request_timeline_shard_zero(&req).await?;
let ctx = self.ctx.with_scope_timeline(&timeline);
// Validate the request and decorate the span.
Self::ensure_shard_zero(&timeline)?;
if timeline.is_archived() == Some(true) {
return Err(tonic::Status::failed_precondition("timeline is archived"));
}
@@ -3678,11 +3866,10 @@ impl proto::PageService for GrpcPageServiceHandler {
req: tonic::Request<proto::GetDbSizeRequest>,
) -> Result<tonic::Response<proto::GetDbSizeResponse>, tonic::Status> {
let received_at = extract::<ReceivedAt>(&req).0;
let timeline = self.get_request_timeline(&req).await?;
let timeline = self.get_request_timeline_shard_zero(&req).await?;
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
Self::ensure_shard_zero(&timeline)?;
let req: page_api::GetDbSizeRequest = req.into_inner().try_into()?;
span_record!(db_oid=%req.db_oid, lsn=%req.read_lsn);
@@ -3711,14 +3898,29 @@ impl proto::PageService for GrpcPageServiceHandler {
req: tonic::Request<tonic::Streaming<proto::GetPageRequest>>,
) -> Result<tonic::Response<Self::GetPagesStream>, tonic::Status> {
// Extract the timeline from the request and check that it exists.
let ttid = *extract::<TenantTimelineId>(&req);
//
// NB: during shard splits, the compute may still send requests to the parent shard. We'll
// reroute requests to the child shards below, but we also detect the common cases here
// where either the shard exists or no shards exist at all. If we have a child shard, we
// can't acquire a weak handle because we don't know which child shard to use yet.
let TenantTimelineId {
tenant_id,
timeline_id,
} = *extract::<TenantTimelineId>(&req);
let shard_index = *extract::<ShardIndex>(&req);
let shard_selector = ShardSelector::Known(shard_index);
let mut handles = TimelineHandles::new(self.tenant_manager.clone());
handles
.get(ttid.tenant_id, ttid.timeline_id, shard_selector)
.await?;
let timeline = match handles
.get(tenant_id, timeline_id, ShardSelector::Known(shard_index))
.await
{
// The timeline shard exists. Keep a weak handle to reuse for each request.
Ok(timeline) => Some(timeline.downgrade()),
// The shard doesn't exist, but a child shard does. We'll reroute requests later.
Err(_) if self.tenant_manager.has_child_shard(tenant_id, shard_index) => None,
// Failed to fetch the timeline, and no child shard exists. Error out.
Err(err) => return Err(err.into()),
};
// Spawn an IoConcurrency sidecar, if enabled.
let gate_guard = self
@@ -3735,11 +3937,9 @@ impl proto::PageService for GrpcPageServiceHandler {
let mut reqs = req.into_inner();
let resps = async_stream::try_stream! {
let timeline = handles
.get(ttid.tenant_id, ttid.timeline_id, shard_selector)
.await?
.downgrade();
loop {
// Wait for the next client request.
//
// NB: Tonic considers the entire stream to be an in-flight request and will wait
// for it to complete before shutting down. React to cancellation between requests.
let req = tokio::select! {
@@ -3752,16 +3952,44 @@ impl proto::PageService for GrpcPageServiceHandler {
Err(err) => Err(err),
},
}?;
let received_at = Instant::now();
let req_id = req.request_id.map(page_api::RequestID::from).unwrap_or_default();
let result = Self::get_page(&ctx, &timeline, req, io_concurrency.clone())
// Process the request, using a closure to capture errors.
let process_request = async || {
let req = page_api::GetPageRequest::try_from(req)?;
// Fast path: use the pre-acquired timeline handle.
if let Some(Ok(timeline)) = timeline.as_ref().map(|t| t.upgrade()) {
return Self::get_page(&ctx, timeline, req, io_concurrency.clone(), received_at)
.instrument(span.clone()) // propagate request span
.await
}
// The timeline handle is stale. During shard splits, the compute may still be
// sending requests to the parent shard. Try to re-route requests to the child
// shards, and split any batch requests that straddle multiple child shards.
Self::maybe_split_get_page(
&ctx,
&mut handles,
tenant_id,
timeline_id,
shard_index,
req,
io_concurrency.clone(),
received_at,
)
.instrument(span.clone()) // propagate request span
.await;
yield match result {
Ok(resp) => resp,
// Convert per-request errors to GetPageResponses as appropriate, or terminate
// the stream with a tonic::Status. Log the error regardless, since
// ObservabilityLayer can't automatically log stream errors.
.await
};
// Return the response. Convert per-request errors to GetPageResponses if
// appropriate, or terminate the stream with a tonic::Status.
yield match process_request().await {
Ok(resp) => resp.into(),
Err(status) => {
// Log the error, since ObservabilityLayer won't see stream errors.
// TODO: it would be nice if we could propagate the get_page() fields here.
span.in_scope(|| {
warn!("request failed with {:?}: {}", status.code(), status.message());
@@ -3781,11 +4009,10 @@ impl proto::PageService for GrpcPageServiceHandler {
req: tonic::Request<proto::GetRelSizeRequest>,
) -> Result<tonic::Response<proto::GetRelSizeResponse>, tonic::Status> {
let received_at = extract::<ReceivedAt>(&req).0;
let timeline = self.get_request_timeline(&req).await?;
let timeline = self.get_request_timeline_shard_zero(&req).await?;
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
Self::ensure_shard_zero(&timeline)?;
let req: page_api::GetRelSizeRequest = req.into_inner().try_into()?;
let allow_missing = req.allow_missing;
@@ -3818,11 +4045,10 @@ impl proto::PageService for GrpcPageServiceHandler {
req: tonic::Request<proto::GetSlruSegmentRequest>,
) -> Result<tonic::Response<proto::GetSlruSegmentResponse>, tonic::Status> {
let received_at = extract::<ReceivedAt>(&req).0;
let timeline = self.get_request_timeline(&req).await?;
let timeline = self.get_request_timeline_shard_zero(&req).await?;
let ctx = self.ctx.with_scope_page_service_pagestream(&timeline);
// Validate the request, decorate the span, and convert it to a Pagestream request.
Self::ensure_shard_zero(&timeline)?;
let req: page_api::GetSlruSegmentRequest = req.into_inner().try_into()?;
span_record!(kind=%req.kind, segno=%req.segno, lsn=%req.read_lsn);
@@ -3852,6 +4078,10 @@ impl proto::PageService for GrpcPageServiceHandler {
&self,
req: tonic::Request<proto::LeaseLsnRequest>,
) -> Result<tonic::Response<proto::LeaseLsnResponse>, tonic::Status> {
// TODO: this won't work during shard splits, as the request is directed at a specific shard
// but the parent shard is removed before the split commits and the compute is notified
// (which can take several minutes for large tenants). That's also the case for the libpq
// implementation, so we keep the behavior for now.
let timeline = self.get_request_timeline(&req).await?;
let ctx = self.ctx.with_scope_timeline(&timeline);

View File

@@ -826,6 +826,18 @@ impl TenantManager {
peek_slot.is_some()
}
/// Returns whether a local shard exists that's a child of the given tenant shard. Note that
/// this just checks for any shard with a larger shard count, and it may not be a direct child
/// of the given shard (their keyspace may not overlap).
pub(crate) fn has_child_shard(&self, tenant_id: TenantId, shard_index: ShardIndex) -> bool {
match &*self.tenants.read().unwrap() {
TenantsMap::Initializing => false,
TenantsMap::Open(slots) | TenantsMap::ShuttingDown(slots) => slots
.range(TenantShardId::tenant_range(tenant_id))
.any(|(tsid, _)| tsid.shard_count > shard_index.shard_count),
}
}
#[instrument(skip_all, fields(tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug()))]
pub(crate) async fn upsert_location(
&self,
@@ -1522,6 +1534,13 @@ impl TenantManager {
self.resources.deletion_queue_client.flush_advisory();
// Phase 2: Put the parent shard to InProgress and grab a reference to the parent Tenant
//
// TODO: keeping the parent as InProgress while spawning the children causes read
// unavailability, as we can't acquire a new timeline handle for it (existing handles appear
// to still work though, even downgraded ones). The parent should be available for reads
// until the children are ready -- potentially until *all* subsplits across all parent
// shards are complete and the compute has been notified. See:
// <https://databricks.atlassian.net/browse/LKB-672>.
drop(tenant);
let mut parent_slot_guard =
self.tenant_map_acquire_slot(&tenant_shard_id, TenantSlotAcquireMode::Any)?;

View File

@@ -70,7 +70,7 @@ use tracing::*;
use utils::generation::Generation;
use utils::guard_arc_swap::GuardArcSwap;
use utils::id::TimelineId;
use utils::logging::{MonitorSlowFutureCallback, monitor_slow_future};
use utils::logging::{MonitorSlowFutureCallback, log_slow, monitor_slow_future};
use utils::lsn::{AtomicLsn, Lsn, RecordLsn};
use utils::postgres_client::PostgresClientProtocol;
use utils::rate_limit::RateLimit;
@@ -6915,7 +6915,13 @@ impl Timeline {
write_guard.store_and_unlock(new_gc_cutoff)
};
waitlist.wait().await;
let waitlist_wait_fut = std::pin::pin!(waitlist.wait());
log_slow(
"applied_gc_cutoff waitlist wait",
Duration::from_secs(30),
waitlist_wait_fut,
)
.await;
info!("GC starting");

View File

@@ -224,11 +224,11 @@ use tracing::{instrument, trace};
use utils::id::TimelineId;
use utils::shard::{ShardIndex, ShardNumber};
use crate::tenant::mgr::ShardSelector;
use crate::page_service::GetActiveTimelineError;
use crate::tenant::GetTimelineError;
use crate::tenant::mgr::{GetActiveTenantError, ShardSelector};
/// The requirement for Debug is so that #[derive(Debug)] works in some places.
pub(crate) trait Types: Sized + std::fmt::Debug {
type TenantManagerError: Sized + std::fmt::Debug;
pub(crate) trait Types: Sized {
type TenantManager: TenantManager<Self> + Sized;
type Timeline: Timeline<Self> + Sized;
}
@@ -307,12 +307,11 @@ impl<T: Types> Default for PerTimelineState<T> {
/// Abstract view of [`crate::tenant::mgr`], for testability.
pub(crate) trait TenantManager<T: Types> {
/// Invoked by [`Cache::get`] to resolve a [`ShardTimelineId`] to a [`Types::Timeline`].
/// Errors are returned as [`GetError::TenantManager`].
async fn resolve(
&self,
timeline_id: TimelineId,
shard_selector: ShardSelector,
) -> Result<T::Timeline, T::TenantManagerError>;
) -> Result<T::Timeline, GetActiveTimelineError>;
}
/// Abstract view of an [`Arc<Timeline>`], for testability.
@@ -322,13 +321,6 @@ pub(crate) trait Timeline<T: Types> {
fn per_timeline_state(&self) -> &PerTimelineState<T>;
}
/// Errors returned by [`Cache::get`].
#[derive(Debug)]
pub(crate) enum GetError<T: Types> {
TenantManager(T::TenantManagerError),
PerTimelineStateShutDown,
}
/// Internal type used in [`Cache::get`].
enum RoutingResult<T: Types> {
FastPath(Handle<T>),
@@ -345,7 +337,7 @@ impl<T: Types> Cache<T> {
timeline_id: TimelineId,
shard_selector: ShardSelector,
tenant_manager: &T::TenantManager,
) -> Result<Handle<T>, GetError<T>> {
) -> Result<Handle<T>, GetActiveTimelineError> {
const GET_MAX_RETRIES: usize = 10;
const RETRY_BACKOFF: Duration = Duration::from_millis(100);
let mut attempt = 0;
@@ -356,7 +348,11 @@ impl<T: Types> Cache<T> {
.await
{
Ok(handle) => return Ok(handle),
Err(e) => {
Err(
e @ GetActiveTimelineError::Tenant(GetActiveTenantError::WaitForActiveTimeout {
..
}),
) => {
// Retry on tenant manager error to handle tenant split more gracefully
if attempt < GET_MAX_RETRIES {
tokio::time::sleep(RETRY_BACKOFF).await;
@@ -370,6 +366,7 @@ impl<T: Types> Cache<T> {
return Err(e);
}
}
Err(err) => return Err(err),
}
}
}
@@ -388,7 +385,7 @@ impl<T: Types> Cache<T> {
timeline_id: TimelineId,
shard_selector: ShardSelector,
tenant_manager: &T::TenantManager,
) -> Result<Handle<T>, GetError<T>> {
) -> Result<Handle<T>, GetActiveTimelineError> {
// terminates because when every iteration we remove an element from the map
let miss: ShardSelector = loop {
let routing_state = self.shard_routing(timeline_id, shard_selector);
@@ -468,60 +465,50 @@ impl<T: Types> Cache<T> {
timeline_id: TimelineId,
shard_selector: ShardSelector,
tenant_manager: &T::TenantManager,
) -> Result<Handle<T>, GetError<T>> {
match tenant_manager.resolve(timeline_id, shard_selector).await {
Ok(timeline) => {
let key = timeline.shard_timeline_id();
match &shard_selector {
ShardSelector::Zero => assert_eq!(key.shard_index.shard_number, ShardNumber(0)),
ShardSelector::Page(_) => (), // gotta trust tenant_manager
ShardSelector::Known(idx) => assert_eq!(idx, &key.shard_index),
}
trace!("creating new HandleInner");
let timeline = Arc::new(timeline);
let handle_inner_arc =
Arc::new(Mutex::new(HandleInner::Open(Arc::clone(&timeline))));
let handle_weak = WeakHandle {
inner: Arc::downgrade(&handle_inner_arc),
};
let handle = handle_weak
.upgrade()
.ok()
.expect("we just created it and it's not linked anywhere yet");
{
let mut lock_guard = timeline
.per_timeline_state()
.handles
.lock()
.expect("mutex poisoned");
match &mut *lock_guard {
Some(per_timeline_state) => {
let replaced =
per_timeline_state.insert(self.id, Arc::clone(&handle_inner_arc));
assert!(replaced.is_none(), "some earlier code left a stale handle");
match self.map.entry(key) {
hash_map::Entry::Occupied(_o) => {
// This cannot not happen because
// 1. we're the _miss_ handle, i.e., `self.map` didn't contain an entry and
// 2. we were holding &mut self during .resolve().await above, so, no other thread can have inserted a handle
// while we were waiting for the tenant manager.
unreachable!()
}
hash_map::Entry::Vacant(v) => {
v.insert(handle_weak);
}
}
}
None => {
return Err(GetError::PerTimelineStateShutDown);
}
}
}
Ok(handle)
}
Err(e) => Err(GetError::TenantManager(e)),
) -> Result<Handle<T>, GetActiveTimelineError> {
let timeline = tenant_manager.resolve(timeline_id, shard_selector).await?;
let key = timeline.shard_timeline_id();
match &shard_selector {
ShardSelector::Zero => assert_eq!(key.shard_index.shard_number, ShardNumber(0)),
ShardSelector::Page(_) => (), // gotta trust tenant_manager
ShardSelector::Known(idx) => assert_eq!(idx, &key.shard_index),
}
trace!("creating new HandleInner");
let timeline = Arc::new(timeline);
let handle_inner_arc = Arc::new(Mutex::new(HandleInner::Open(Arc::clone(&timeline))));
let handle_weak = WeakHandle {
inner: Arc::downgrade(&handle_inner_arc),
};
let handle = handle_weak
.upgrade()
.ok()
.expect("we just created it and it's not linked anywhere yet");
let mut lock_guard = timeline
.per_timeline_state()
.handles
.lock()
.expect("mutex poisoned");
let Some(per_timeline_state) = &mut *lock_guard else {
return Err(GetActiveTimelineError::Timeline(
GetTimelineError::ShuttingDown,
));
};
let replaced = per_timeline_state.insert(self.id, Arc::clone(&handle_inner_arc));
assert!(replaced.is_none(), "some earlier code left a stale handle");
match self.map.entry(key) {
hash_map::Entry::Occupied(_o) => {
// This cannot not happen because
// 1. we're the _miss_ handle, i.e., `self.map` didn't contain an entry and
// 2. we were holding &mut self during .resolve().await above, so, no other thread can have inserted a handle
// while we were waiting for the tenant manager.
unreachable!()
}
hash_map::Entry::Vacant(v) => {
v.insert(handle_weak);
}
}
Ok(handle)
}
}
@@ -655,7 +642,8 @@ mod tests {
use pageserver_api::models::ShardParameters;
use pageserver_api::reltag::RelTag;
use pageserver_api::shard::DEFAULT_STRIPE_SIZE;
use utils::shard::ShardCount;
use utils::id::TenantId;
use utils::shard::{ShardCount, TenantShardId};
use utils::sync::gate::GateGuard;
use super::*;
@@ -665,7 +653,6 @@ mod tests {
#[derive(Debug)]
struct TestTypes;
impl Types for TestTypes {
type TenantManagerError = anyhow::Error;
type TenantManager = StubManager;
type Timeline = Entered;
}
@@ -716,40 +703,48 @@ mod tests {
&self,
timeline_id: TimelineId,
shard_selector: ShardSelector,
) -> anyhow::Result<Entered> {
) -> Result<Entered, GetActiveTimelineError> {
fn enter_gate(
timeline: &StubTimeline,
) -> Result<Arc<GateGuard>, GetActiveTimelineError> {
Ok(Arc::new(timeline.gate.enter().map_err(|_| {
GetActiveTimelineError::Timeline(GetTimelineError::ShuttingDown)
})?))
}
for timeline in &self.shards {
if timeline.id == timeline_id {
let enter_gate = || {
let gate_guard = timeline.gate.enter()?;
let gate_guard = Arc::new(gate_guard);
anyhow::Ok(gate_guard)
};
match &shard_selector {
ShardSelector::Zero if timeline.shard.is_shard_zero() => {
return Ok(Entered {
timeline: Arc::clone(timeline),
gate_guard: enter_gate()?,
gate_guard: enter_gate(timeline)?,
});
}
ShardSelector::Zero => continue,
ShardSelector::Page(key) if timeline.shard.is_key_local(key) => {
return Ok(Entered {
timeline: Arc::clone(timeline),
gate_guard: enter_gate()?,
gate_guard: enter_gate(timeline)?,
});
}
ShardSelector::Page(_) => continue,
ShardSelector::Known(idx) if idx == &timeline.shard.shard_index() => {
return Ok(Entered {
timeline: Arc::clone(timeline),
gate_guard: enter_gate()?,
gate_guard: enter_gate(timeline)?,
});
}
ShardSelector::Known(_) => continue,
}
}
}
anyhow::bail!("not found")
Err(GetActiveTimelineError::Timeline(
GetTimelineError::NotFound {
tenant_id: TenantShardId::unsharded(TenantId::from([0; 16])),
timeline_id,
},
))
}
}

View File

@@ -52,7 +52,7 @@ pub(crate) fn regenerate(
};
// Express a static value for how many shards we may schedule on one node
const MAX_SHARDS: u32 = 5000;
const MAX_SHARDS: u32 = 2500;
let mut doc = PageserverUtilization {
disk_usage_bytes: used,

View File

@@ -79,10 +79,6 @@
#include "access/xlogrecovery.h"
#endif
#if PG_VERSION_NUM < 160000
typedef PGAlignedBlock PGIOAlignedBlock;
#endif
#define NEON_PANIC_CONNECTION_STATE(shard_no, elvl, message, ...) \
neon_shard_log(shard_no, elvl, "Broken connection state: " message, \
##__VA_ARGS__)

View File

@@ -14,7 +14,7 @@
#include "extension_server.h"
#include "neon_utils.h"
static int extension_server_port = 0;
int hadron_extension_server_port = 0;
static int extension_server_request_timeout = 60;
static int extension_server_connect_timeout = 60;
@@ -47,7 +47,7 @@ neon_download_extension_file_http(const char *filename, bool is_library)
curl_easy_setopt(handle, CURLOPT_CONNECTTIMEOUT, (long)extension_server_connect_timeout /* seconds */ );
compute_ctl_url = psprintf("http://localhost:%d/extension_server/%s%s",
extension_server_port, filename, is_library ? "?is_library=true" : "");
hadron_extension_server_port, filename, is_library ? "?is_library=true" : "");
elog(LOG, "Sending request to compute_ctl: %s", compute_ctl_url);
@@ -82,7 +82,7 @@ pg_init_extension_server()
DefineCustomIntVariable("neon.extension_server_port",
"connection string to the compute_ctl",
NULL,
&extension_server_port,
&hadron_extension_server_port,
0, 0, INT_MAX,
PGC_POSTMASTER,
0, /* no flags required */

View File

@@ -635,6 +635,11 @@ lfc_init(void)
NULL);
}
/*
* Dump a list of pages that are currently in the LFC
*
* This is used to get a snapshot that can be used to prewarm the LFC later.
*/
FileCacheState*
lfc_get_state(size_t max_entries)
{
@@ -1827,125 +1832,46 @@ lfc_writev(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
LWLockRelease(lfc_lock);
}
typedef struct
/*
* Return metrics about the LFC.
*
* The return format is a palloc'd array of LfcStatsEntrys. The size
* of the returned array is returned in *num_entries.
*/
LfcStatsEntry *
lfc_get_stats(size_t *num_entries)
{
TupleDesc tupdesc;
} NeonGetStatsCtx;
LfcStatsEntry *entries;
size_t n = 0;
#define NUM_NEON_GET_STATS_COLS 2
#define MAX_ENTRIES 10
entries = palloc(sizeof(LfcStatsEntry) * MAX_ENTRIES);
PG_FUNCTION_INFO_V1(neon_get_lfc_stats);
Datum
neon_get_lfc_stats(PG_FUNCTION_ARGS)
{
FuncCallContext *funcctx;
NeonGetStatsCtx *fctx;
MemoryContext oldcontext;
TupleDesc tupledesc;
Datum result;
HeapTuple tuple;
char const *key;
uint64 value = 0;
Datum values[NUM_NEON_GET_STATS_COLS];
bool nulls[NUM_NEON_GET_STATS_COLS];
entries[n++] = (LfcStatsEntry) {"file_cache_chunk_size_pages", lfc_ctl == NULL,
lfc_ctl ? lfc_blocks_per_chunk : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_misses", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->misses : 0};
entries[n++] = (LfcStatsEntry) {"file_cache_hits", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->hits : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_used", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->used : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_writes", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->writes : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_size", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->size : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_used_pages", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->used_pages : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_evicted_pages", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->evicted_pages : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_limit", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->limit : 0 };
entries[n++] = (LfcStatsEntry) {"file_cache_chunks_pinned", lfc_ctl == NULL,
lfc_ctl ? lfc_ctl->pinned : 0 };
Assert(n <= MAX_ENTRIES);
#undef MAX_ENTRIES
if (SRF_IS_FIRSTCALL())
{
funcctx = SRF_FIRSTCALL_INIT();
/* Switch context when allocating stuff to be used in later calls */
oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
/* Create a user function context for cross-call persistence */
fctx = (NeonGetStatsCtx *) palloc(sizeof(NeonGetStatsCtx));
/* Construct a tuple descriptor for the result rows. */
tupledesc = CreateTemplateTupleDesc(NUM_NEON_GET_STATS_COLS);
TupleDescInitEntry(tupledesc, (AttrNumber) 1, "lfc_key",
TEXTOID, -1, 0);
TupleDescInitEntry(tupledesc, (AttrNumber) 2, "lfc_value",
INT8OID, -1, 0);
fctx->tupdesc = BlessTupleDesc(tupledesc);
funcctx->user_fctx = fctx;
/* Return to original context when allocating transient memory */
MemoryContextSwitchTo(oldcontext);
}
funcctx = SRF_PERCALL_SETUP();
/* Get the saved state */
fctx = (NeonGetStatsCtx *) funcctx->user_fctx;
switch (funcctx->call_cntr)
{
case 0:
key = "file_cache_misses";
if (lfc_ctl)
value = lfc_ctl->misses;
break;
case 1:
key = "file_cache_hits";
if (lfc_ctl)
value = lfc_ctl->hits;
break;
case 2:
key = "file_cache_used";
if (lfc_ctl)
value = lfc_ctl->used;
break;
case 3:
key = "file_cache_writes";
if (lfc_ctl)
value = lfc_ctl->writes;
break;
case 4:
key = "file_cache_size";
if (lfc_ctl)
value = lfc_ctl->size;
break;
case 5:
key = "file_cache_used_pages";
if (lfc_ctl)
value = lfc_ctl->used_pages;
break;
case 6:
key = "file_cache_evicted_pages";
if (lfc_ctl)
value = lfc_ctl->evicted_pages;
break;
case 7:
key = "file_cache_limit";
if (lfc_ctl)
value = lfc_ctl->limit;
break;
case 8:
key = "file_cache_chunk_size_pages";
value = lfc_blocks_per_chunk;
break;
case 9:
key = "file_cache_chunks_pinned";
if (lfc_ctl)
value = lfc_ctl->pinned;
break;
default:
SRF_RETURN_DONE(funcctx);
}
values[0] = PointerGetDatum(cstring_to_text(key));
nulls[0] = false;
if (lfc_ctl)
{
nulls[1] = false;
values[1] = Int64GetDatum(value);
}
else
nulls[1] = true;
tuple = heap_form_tuple(fctx->tupdesc, values, nulls);
result = HeapTupleGetDatum(tuple);
SRF_RETURN_NEXT(funcctx, result);
*num_entries = n;
return entries;
}
@@ -1953,193 +1879,86 @@ neon_get_lfc_stats(PG_FUNCTION_ARGS)
* Function returning data from the local file cache
* relation node/tablespace/database/blocknum and access_counter
*/
PG_FUNCTION_INFO_V1(local_cache_pages);
/*
* Record structure holding the to be exposed cache data.
*/
typedef struct
LocalCachePagesRec *
lfc_local_cache_pages(size_t *num_entries)
{
uint32 pageoffs;
Oid relfilenode;
Oid reltablespace;
Oid reldatabase;
ForkNumber forknum;
BlockNumber blocknum;
uint16 accesscount;
} LocalCachePagesRec;
HASH_SEQ_STATUS status;
FileCacheEntry *entry;
size_t n_pages;
size_t n;
LocalCachePagesRec *result;
/*
* Function context for data persisting over repeated calls.
*/
typedef struct
{
TupleDesc tupdesc;
LocalCachePagesRec *record;
} LocalCachePagesContext;
#define NUM_LOCALCACHE_PAGES_ELEM 7
Datum
local_cache_pages(PG_FUNCTION_ARGS)
{
FuncCallContext *funcctx;
Datum result;
MemoryContext oldcontext;
LocalCachePagesContext *fctx; /* User function context. */
TupleDesc tupledesc;
TupleDesc expected_tupledesc;
HeapTuple tuple;
if (SRF_IS_FIRSTCALL())
if (!lfc_ctl)
{
HASH_SEQ_STATUS status;
FileCacheEntry *entry;
uint32 n_pages = 0;
*num_entries = 0;
return NULL;
}
funcctx = SRF_FIRSTCALL_INIT();
LWLockAcquire(lfc_lock, LW_SHARED);
if (!LFC_ENABLED())
{
LWLockRelease(lfc_lock);
*num_entries = 0;
return NULL;
}
/* Switch context when allocating stuff to be used in later calls */
oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
/* Create a user function context for cross-call persistence */
fctx = (LocalCachePagesContext *) palloc(sizeof(LocalCachePagesContext));
/*
* To smoothly support upgrades from version 1.0 of this extension
* transparently handle the (non-)existence of the pinning_backends
* column. We unfortunately have to get the result type for that... -
* we can't use the result type determined by the function definition
* without potentially crashing when somebody uses the old (or even
* wrong) function definition though.
*/
if (get_call_result_type(fcinfo, NULL, &expected_tupledesc) != TYPEFUNC_COMPOSITE)
neon_log(ERROR, "return type must be a row type");
if (expected_tupledesc->natts != NUM_LOCALCACHE_PAGES_ELEM)
neon_log(ERROR, "incorrect number of output arguments");
/* Construct a tuple descriptor for the result rows. */
tupledesc = CreateTemplateTupleDesc(expected_tupledesc->natts);
TupleDescInitEntry(tupledesc, (AttrNumber) 1, "pageoffs",
INT8OID, -1, 0);
#if PG_MAJORVERSION_NUM < 16
TupleDescInitEntry(tupledesc, (AttrNumber) 2, "relfilenode",
OIDOID, -1, 0);
#else
TupleDescInitEntry(tupledesc, (AttrNumber) 2, "relfilenumber",
OIDOID, -1, 0);
#endif
TupleDescInitEntry(tupledesc, (AttrNumber) 3, "reltablespace",
OIDOID, -1, 0);
TupleDescInitEntry(tupledesc, (AttrNumber) 4, "reldatabase",
OIDOID, -1, 0);
TupleDescInitEntry(tupledesc, (AttrNumber) 5, "relforknumber",
INT2OID, -1, 0);
TupleDescInitEntry(tupledesc, (AttrNumber) 6, "relblocknumber",
INT8OID, -1, 0);
TupleDescInitEntry(tupledesc, (AttrNumber) 7, "accesscount",
INT4OID, -1, 0);
fctx->tupdesc = BlessTupleDesc(tupledesc);
if (lfc_ctl)
/* Count the pages first */
n_pages = 0;
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
/* Skip hole tags */
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
{
LWLockAcquire(lfc_lock, LW_SHARED);
for (int i = 0; i < lfc_blocks_per_chunk; i++)
n_pages += GET_STATE(entry, i) == AVAILABLE;
}
}
if (LFC_ENABLED())
if (n_pages == 0)
{
LWLockRelease(lfc_lock);
*num_entries = 0;
return NULL;
}
result = (LocalCachePagesRec *)
MemoryContextAllocHuge(CurrentMemoryContext,
sizeof(LocalCachePagesRec) * n_pages);
/*
* Scan through all the cache entries, saving the relevant fields
* in the result structure.
*/
n = 0;
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
{
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
{
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
if (GET_STATE(entry, i) == AVAILABLE)
{
/* Skip hole tags */
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
n_pages += GET_STATE(entry, i) == AVAILABLE;
}
result[n].pageoffs = entry->offset * lfc_blocks_per_chunk + i;
result[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key));
result[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(entry->key));
result[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(entry->key));
result[n].forknum = entry->key.forkNum;
result[n].blocknum = entry->key.blockNum + i;
result[n].accesscount = entry->access_count;
n += 1;
}
}
}
fctx->record = (LocalCachePagesRec *)
MemoryContextAllocHuge(CurrentMemoryContext,
sizeof(LocalCachePagesRec) * n_pages);
/* Set max calls and remember the user function context. */
funcctx->max_calls = n_pages;
funcctx->user_fctx = fctx;
/* Return to original context when allocating transient memory */
MemoryContextSwitchTo(oldcontext);
if (n_pages != 0)
{
/*
* Scan through all the cache entries, saving the relevant fields
* in the fctx->record structure.
*/
uint32 n = 0;
hash_seq_init(&status, lfc_hash);
while ((entry = hash_seq_search(&status)) != NULL)
{
for (int i = 0; i < lfc_blocks_per_chunk; i++)
{
if (NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key)) != 0)
{
if (GET_STATE(entry, i) == AVAILABLE)
{
fctx->record[n].pageoffs = entry->offset * lfc_blocks_per_chunk + i;
fctx->record[n].relfilenode = NInfoGetRelNumber(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].reltablespace = NInfoGetSpcOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].reldatabase = NInfoGetDbOid(BufTagGetNRelFileInfo(entry->key));
fctx->record[n].forknum = entry->key.forkNum;
fctx->record[n].blocknum = entry->key.blockNum + i;
fctx->record[n].accesscount = entry->access_count;
n += 1;
}
}
}
}
Assert(n_pages == n);
}
if (lfc_ctl)
LWLockRelease(lfc_lock);
}
Assert(n_pages == n);
LWLockRelease(lfc_lock);
funcctx = SRF_PERCALL_SETUP();
/* Get the saved state */
fctx = funcctx->user_fctx;
if (funcctx->call_cntr < funcctx->max_calls)
{
uint32 i = funcctx->call_cntr;
Datum values[NUM_LOCALCACHE_PAGES_ELEM];
bool nulls[NUM_LOCALCACHE_PAGES_ELEM] = {
false, false, false, false, false, false, false
};
values[0] = Int64GetDatum((int64) fctx->record[i].pageoffs);
values[1] = ObjectIdGetDatum(fctx->record[i].relfilenode);
values[2] = ObjectIdGetDatum(fctx->record[i].reltablespace);
values[3] = ObjectIdGetDatum(fctx->record[i].reldatabase);
values[4] = ObjectIdGetDatum(fctx->record[i].forknum);
values[5] = Int64GetDatum((int64) fctx->record[i].blocknum);
values[6] = Int32GetDatum(fctx->record[i].accesscount);
/* Build and return the tuple. */
tuple = heap_form_tuple(fctx->tupdesc, values, nulls);
result = HeapTupleGetDatum(tuple);
SRF_RETURN_NEXT(funcctx, result);
}
else
SRF_RETURN_DONE(funcctx);
*num_entries = n_pages;
return result;
}
/*
* Internal implementation of the approximate_working_set_size_seconds()
* function.
@@ -2267,4 +2086,3 @@ get_prewarm_info(PG_FUNCTION_ARGS)
PG_RETURN_DATUM(HeapTupleGetDatum(heap_form_tuple(tupdesc, values, nulls)));
}

View File

@@ -47,6 +47,26 @@ extern bool lfc_prefetch(NRelFileInfo rinfo, ForkNumber forknum, BlockNumber blk
extern FileCacheState* lfc_get_state(size_t max_entries);
extern void lfc_prewarm(FileCacheState* fcs, uint32 n_workers);
typedef struct LfcStatsEntry
{
const char *metric_name;
bool isnull;
uint64 value;
} LfcStatsEntry;
extern LfcStatsEntry *lfc_get_stats(size_t *num_entries);
typedef struct
{
uint32 pageoffs;
Oid relfilenode;
Oid reltablespace;
Oid reldatabase;
ForkNumber forknum;
BlockNumber blocknum;
uint16 accesscount;
} LocalCachePagesRec;
extern LocalCachePagesRec *lfc_local_cache_pages(size_t *num_entries);
extern int32 lfc_approximate_working_set_size_seconds(time_t duration, bool reset);

View File

@@ -13,6 +13,8 @@
#include <math.h>
#include <sys/socket.h>
#include <curl/curl.h>
#include "libpq-int.h"
#include "access/xlog.h"
@@ -86,6 +88,10 @@ static int pageserver_response_log_timeout = 10000;
/* 2.5 minutes. A bit higher than highest default TCP retransmission timeout */
static int pageserver_response_disconnect_timeout = 150000;
static int conf_refresh_reconnect_attempt_threshold = 16;
// Hadron: timeout for refresh errors (1 minute)
static uint64 kRefreshErrorTimeoutUSec = 1 * USECS_PER_MINUTE;
typedef struct
{
char connstring[MAX_SHARDS][MAX_PAGESERVER_CONNSTRING_SIZE];
@@ -130,7 +136,7 @@ static uint64 pagestore_local_counter = 0;
typedef enum PSConnectionState {
PS_Disconnected, /* no connection yet */
PS_Connecting_Startup, /* connection starting up */
PS_Connecting_PageStream, /* negotiating pagestream */
PS_Connecting_PageStream, /* negotiating pagestream */
PS_Connected, /* connected, pagestream established */
} PSConnectionState;
@@ -401,7 +407,7 @@ get_shard_number(BufferTag *tag)
}
static inline void
CLEANUP_AND_DISCONNECT(PageServer *shard)
CLEANUP_AND_DISCONNECT(PageServer *shard)
{
if (shard->wes_read)
{
@@ -423,7 +429,7 @@ CLEANUP_AND_DISCONNECT(PageServer *shard)
* complete the connection (e.g. due to receiving an earlier cancellation
* during connection start).
* Returns true if successfully connected; false if the connection failed.
*
*
* Throws errors in unrecoverable situations, or when this backend's query
* is canceled.
*/
@@ -1030,6 +1036,101 @@ pageserver_disconnect_shard(shardno_t shard_no)
shard->state = PS_Disconnected;
}
// BEGIN HADRON
/*
* Nudge compute_ctl to refresh our configuration. Called when we suspect we may be
* connecting to the wrong pageservers due to a stale configuration.
*
* This is a best-effort operation. If we couldn't send the local loopback HTTP request
* to compute_ctl or if the request fails for any reason, we just log the error and move
* on.
*/
extern int hadron_extension_server_port;
// The timestamp (usec) of the first error that occurred while trying to refresh the configuration.
// Will be reset to 0 after a successful refresh.
static uint64 first_recorded_refresh_error_usec = 0;
// Request compute_ctl to refresh the configuration. This operation may fail, e.g., if the compute_ctl
// is already in the configuration state. The function returns true if the caller needs to cancel the
// current query to avoid dead/live lock.
static bool
hadron_request_configuration_refresh() {
static CURL *handle = NULL;
CURLcode res;
char *compute_ctl_url;
bool cancel_query = false;
if (!lakebase_mode)
return false;
if (handle == NULL)
{
handle = alloc_curl_handle();
curl_easy_setopt(handle, CURLOPT_CUSTOMREQUEST, "POST");
curl_easy_setopt(handle, CURLOPT_TIMEOUT, 3L /* seconds */ );
curl_easy_setopt(handle, CURLOPT_POSTFIELDS, "");
}
// Set the URL
compute_ctl_url = psprintf("http://localhost:%d/refresh_configuration", hadron_extension_server_port);
elog(LOG, "Sending refresh configuration request to compute_ctl: %s", compute_ctl_url);
curl_easy_setopt(handle, CURLOPT_URL, compute_ctl_url);
res = curl_easy_perform(handle);
if (res != CURLE_OK )
{
elog(WARNING, "refresh_configuration request failed: %s\n", curl_easy_strerror(res));
}
else
{
long http_code = 0;
curl_easy_getinfo(handle, CURLINFO_RESPONSE_CODE, &http_code);
if ( res != CURLE_OK )
{
elog(WARNING, "compute_ctl refresh_configuration request getinfo failed: %s\n", curl_easy_strerror(res));
}
else
{
elog(LOG, "compute_ctl refresh_configuration got HTTP response: %ld\n", http_code);
if( http_code == 200 )
{
first_recorded_refresh_error_usec = 0;
}
else
{
if (first_recorded_refresh_error_usec == 0)
{
first_recorded_refresh_error_usec = GetCurrentTimestamp();
}
else if(GetCurrentTimestamp() - first_recorded_refresh_error_usec > kRefreshErrorTimeoutUSec)
{
{
first_recorded_refresh_error_usec = 0;
cancel_query = true;
}
}
}
}
}
// In regular Postgres usage, it is not necessary to manually free memory allocated by palloc (psprintf) because
// it will be cleaned up after the "memory context" is reset (e.g. after the query or the transaction is finished).
// However, the number of times this function gets called during a single query/transaction can be unbounded due to
// the various retry loops around calls to pageservers. Therefore, we need to manually free this memory here.
if (compute_ctl_url != NULL)
{
pfree(compute_ctl_url);
}
return cancel_query;
}
// END HADRON
static bool
pageserver_send(shardno_t shard_no, NeonRequest *request)
{
@@ -1064,6 +1165,11 @@ pageserver_send(shardno_t shard_no, NeonRequest *request)
while (!pageserver_connect(shard_no, shard->n_reconnect_attempts < max_reconnect_attempts ? LOG : ERROR))
{
shard->n_reconnect_attempts += 1;
if (shard->n_reconnect_attempts > conf_refresh_reconnect_attempt_threshold
&& hadron_request_configuration_refresh() )
{
neon_shard_log(shard_no, ERROR, "request failed too many times, cancelling query");
}
}
shard->n_reconnect_attempts = 0;
} else {
@@ -1171,17 +1277,26 @@ pageserver_receive(shardno_t shard_no)
pfree(msg);
pageserver_disconnect(shard_no);
resp = NULL;
/*
* Always poke compute_ctl to request a configuration refresh if we have issues receiving data from pageservers after
* successfully connecting to it. It could be an indication that we are connecting to the wrong pageservers (e.g. PS
* is in secondary mode or otherwise refuses to respond our request).
*/
hadron_request_configuration_refresh();
}
else if (rc == -2)
{
char *msg = pchomp(PQerrorMessage(pageserver_conn));
pageserver_disconnect(shard_no);
hadron_request_configuration_refresh();
neon_shard_log(shard_no, ERROR, "pageserver_receive disconnect: could not read COPY data: %s", msg);
}
else
{
pageserver_disconnect(shard_no);
hadron_request_configuration_refresh();
neon_shard_log(shard_no, ERROR, "pageserver_receive disconnect: unexpected PQgetCopyData return value: %d", rc);
}
@@ -1249,21 +1364,34 @@ pageserver_try_receive(shardno_t shard_no)
neon_shard_log(shard_no, LOG, "pageserver_receive disconnect: psql end of copy data: %s", pchomp(PQerrorMessage(pageserver_conn)));
pageserver_disconnect(shard_no);
resp = NULL;
hadron_request_configuration_refresh();
}
else if (rc == -2)
{
char *msg = pchomp(PQerrorMessage(pageserver_conn));
pageserver_disconnect(shard_no);
hadron_request_configuration_refresh();
neon_shard_log(shard_no, LOG, "pageserver_receive disconnect: could not read COPY data: %s", msg);
resp = NULL;
}
else
{
pageserver_disconnect(shard_no);
hadron_request_configuration_refresh();
neon_shard_log(shard_no, ERROR, "pageserver_receive disconnect: unexpected PQgetCopyData return value: %d", rc);
}
/*
* Always poke compute_ctl to request a configuration refresh if we have issues receiving data from pageservers after
* successfully connecting to it. It could be an indication that we are connecting to the wrong pageservers (e.g. PS
* is in secondary mode or otherwise refuses to respond our request).
*/
if ( rc < 0 && hadron_request_configuration_refresh() )
{
neon_shard_log(shard_no, ERROR, "refresh_configuration request failed, cancelling query");
}
shard->nresponses_received++;
return (NeonResponse *) resp;
}
@@ -1459,6 +1587,16 @@ pg_init_libpagestore(void)
PGC_SU_BACKEND,
0, /* no flags required */
NULL, NULL, NULL);
DefineCustomIntVariable("hadron.conf_refresh_reconnect_attempt_threshold",
"Threshold of the number of consecutive failed pageserver "
"connection attempts (per shard) before signaling "
"compute_ctl for a configuration refresh.",
NULL,
&conf_refresh_reconnect_attempt_threshold,
16, 0, INT_MAX,
PGC_USERSET,
0,
NULL, NULL, NULL);
DefineCustomIntVariable("neon.pageserver_response_log_timeout",
"pageserver response log timeout",

View File

@@ -1,7 +1,7 @@
/*-------------------------------------------------------------------------
*
* neon.c
* Main entry point into the neon exension
* Main entry point into the neon extension
*
*-------------------------------------------------------------------------
*/
@@ -51,6 +51,7 @@ void _PG_init(void);
bool lakebase_mode = false;
static int running_xacts_overflow_policy;
static emit_log_hook_type prev_emit_log_hook;
static bool monitor_query_exec_time = false;
static ExecutorStart_hook_type prev_ExecutorStart = NULL;
@@ -447,17 +448,19 @@ ReportSearchPath(void)
static int neon_pgstat_file_size_limit;
#endif
#if PG_VERSION_NUM >= 160000 && PG_VERSION_NUM < 170000
static void DatabricksSqlErrorHookImpl(int sqlerrcode) {
if (sqlerrcode == ERRCODE_DATA_CORRUPTED) {
static void DatabricksSqlErrorHookImpl(ErrorData *edata) {
if (prev_emit_log_hook != NULL) {
prev_emit_log_hook(edata);
}
if (edata->sqlerrcode == ERRCODE_DATA_CORRUPTED) {
pg_atomic_fetch_add_u32(&databricks_metrics_shared->data_corruption_count, 1);
} else if (sqlerrcode == ERRCODE_INDEX_CORRUPTED) {
} else if (edata->sqlerrcode == ERRCODE_INDEX_CORRUPTED) {
pg_atomic_fetch_add_u32(&databricks_metrics_shared->index_corruption_count, 1);
} else if (sqlerrcode == ERRCODE_INTERNAL_ERROR) {
} else if (edata->sqlerrcode == ERRCODE_INTERNAL_ERROR) {
pg_atomic_fetch_add_u32(&databricks_metrics_shared->internal_error_count, 1);
}
}
#endif
void
_PG_init(void)
@@ -470,11 +473,10 @@ _PG_init(void)
load_file("$libdir/neon_rmgr", false);
#endif
#if PG_VERSION_NUM >= 160000 && PG_VERSION_NUM < 170000
if (lakebase_mode) {
SqlErrorCode_hook = DatabricksSqlErrorHookImpl;
prev_emit_log_hook = emit_log_hook;
emit_log_hook = DatabricksSqlErrorHookImpl;
}
#endif
/*
* Initializing a pre-loaded Postgres extension happens in three stages:
@@ -528,7 +530,7 @@ _PG_init(void)
DefineCustomBoolVariable(
"neon.disable_logical_replication_subscribers",
"Disables incomming logical replication",
"Disable incoming logical replication",
NULL,
&disable_logical_replication_subscribers,
false,
@@ -587,7 +589,7 @@ _PG_init(void)
DefineCustomEnumVariable(
"neon.debug_compare_local",
"Debug mode for compaing content of pages in prefetch ring/LFC/PS and local disk",
"Debug mode for comparing content of pages in prefetch ring/LFC/PS and local disk",
NULL,
&debug_compare_local,
DEBUG_COMPARE_LOCAL_NONE,
@@ -658,11 +660,15 @@ _PG_init(void)
ExecutorEnd_hook = neon_ExecutorEnd;
}
/* Various functions exposed at SQL level */
PG_FUNCTION_INFO_V1(pg_cluster_size);
PG_FUNCTION_INFO_V1(backpressure_lsns);
PG_FUNCTION_INFO_V1(backpressure_throttling_time);
PG_FUNCTION_INFO_V1(approximate_working_set_size_seconds);
PG_FUNCTION_INFO_V1(approximate_working_set_size);
PG_FUNCTION_INFO_V1(neon_get_lfc_stats);
PG_FUNCTION_INFO_V1(local_cache_pages);
Datum
pg_cluster_size(PG_FUNCTION_ARGS)
@@ -737,6 +743,76 @@ approximate_working_set_size(PG_FUNCTION_ARGS)
PG_RETURN_INT32(dc);
}
Datum
neon_get_lfc_stats(PG_FUNCTION_ARGS)
{
#define NUM_NEON_GET_STATS_COLS 2
ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo;
LfcStatsEntry *entries;
size_t num_entries;
InitMaterializedSRF(fcinfo, 0);
/* lfc_get_stats() does all the heavy lifting */
entries = lfc_get_stats(&num_entries);
/* Convert the LfcStatsEntrys to a result set */
for (size_t i = 0; i < num_entries; i++)
{
LfcStatsEntry *entry = &entries[i];
Datum values[NUM_NEON_GET_STATS_COLS];
bool nulls[NUM_NEON_GET_STATS_COLS];
values[0] = CStringGetTextDatum(entry->metric_name);
nulls[0] = false;
values[1] = Int64GetDatum(entry->isnull ? 0 : entry->value);
nulls[1] = entry->isnull;
tuplestore_putvalues(rsinfo->setResult, rsinfo->setDesc, values, nulls);
}
PG_RETURN_VOID();
#undef NUM_NEON_GET_STATS_COLS
}
Datum
local_cache_pages(PG_FUNCTION_ARGS)
{
#define NUM_LOCALCACHE_PAGES_COLS 7
ReturnSetInfo *rsinfo = (ReturnSetInfo *) fcinfo->resultinfo;
LocalCachePagesRec *entries;
size_t num_entries;
InitMaterializedSRF(fcinfo, 0);
/* lfc_local_cache_pages() does all the heavy lifting */
entries = lfc_local_cache_pages(&num_entries);
/* Convert the LocalCachePagesRec structs to a result set */
for (size_t i = 0; i < num_entries; i++)
{
LocalCachePagesRec *entry = &entries[i];
Datum values[NUM_LOCALCACHE_PAGES_COLS];
bool nulls[NUM_LOCALCACHE_PAGES_COLS] = {
false, false, false, false, false, false, false
};
values[0] = Int64GetDatum((int64) entry->pageoffs);
values[1] = ObjectIdGetDatum(entry->relfilenode);
values[2] = ObjectIdGetDatum(entry->reltablespace);
values[3] = ObjectIdGetDatum(entry->reldatabase);
values[4] = ObjectIdGetDatum(entry->forknum);
values[5] = Int64GetDatum((int64) entry->blocknum);
values[6] = Int32GetDatum(entry->accesscount);
/* Build and return the tuple. */
tuplestore_putvalues(rsinfo->setResult, rsinfo->setDesc, values, nulls);
}
PG_RETURN_VOID();
#undef NUM_LOCALCACHE_PAGES_COLS
}
/*
* Initialization stage 2: make requests for the amount of shared memory we
* will need.
@@ -768,7 +844,6 @@ neon_shmem_request_hook(void)
static void
neon_shmem_startup_hook(void)
{
/* Initialize */
if (prev_shmem_startup_hook)
prev_shmem_startup_hook();

View File

@@ -20,7 +20,6 @@
#include "neon.h"
#include "neon_perf_counters.h"
#include "walproposer.h"
#include "walproposer.h"
/* BEGIN_HADRON */
databricks_metrics *databricks_metrics_shared;

View File

@@ -167,11 +167,7 @@ extern neon_per_backend_counters *neon_per_backend_counters_shared;
*/
#define NUM_NEON_PERF_COUNTER_SLOTS (MaxBackends + NUM_AUXILIARY_PROCS)
#if PG_VERSION_NUM >= 170000
#define MyNeonCounters (&neon_per_backend_counters_shared[MyProcNumber])
#else
#define MyNeonCounters (&neon_per_backend_counters_shared[MyProc->pgprocno])
#endif
extern void inc_getpage_wait(uint64 latency);
extern void inc_page_cache_read_wait(uint64 latency);

View File

@@ -9,6 +9,10 @@
#include "fmgr.h"
#include "storage/buf_internals.h"
#if PG_MAJORVERSION_NUM < 16
typedef PGAlignedBlock PGIOAlignedBlock;
#endif
#if PG_MAJORVERSION_NUM < 17
#define NRelFileInfoBackendIsTemp(rinfo) (rinfo.backend != InvalidBackendId)
#else
@@ -158,6 +162,10 @@ InitBufferTag(BufferTag *tag, const RelFileNode *rnode,
#define AmAutoVacuumWorkerProcess() (IsAutoVacuumWorkerProcess())
#endif
#if PG_MAJORVERSION_NUM < 17
#define MyProcNumber (MyProc - &ProcGlobal->allProcs[0])
#endif
#if PG_MAJORVERSION_NUM < 15
extern void InitMaterializedSRF(FunctionCallInfo fcinfo, bits32 flags);
extern TimeLineID GetWALInsertionTimeLine(void);

View File

@@ -72,10 +72,6 @@
#include "access/xlogrecovery.h"
#endif
#if PG_VERSION_NUM < 160000
typedef PGAlignedBlock PGIOAlignedBlock;
#endif
#include "access/nbtree.h"
#include "storage/bufpage.h"
#include "access/xlog_internal.h"

View File

@@ -13,6 +13,7 @@
#include "neon.h"
#include "neon_pgversioncompat.h"
#include "miscadmin.h"
#include "pagestore_client.h"
#include RELFILEINFO_HDR
#include "storage/smgr.h"
@@ -23,10 +24,6 @@
#include "utils/dynahash.h"
#include "utils/guc.h"
#if PG_VERSION_NUM >= 150000
#include "miscadmin.h"
#endif
typedef struct
{
NRelFileInfo rinfo;

View File

@@ -508,19 +508,45 @@ backpressure_lag_impl(void)
LSN_FORMAT_ARGS(flushPtr),
LSN_FORMAT_ARGS(applyPtr));
if ((writePtr != InvalidXLogRecPtr && max_replication_write_lag > 0 && myFlushLsn > writePtr + max_replication_write_lag * MB))
if (lakebase_mode)
{
return (myFlushLsn - writePtr - max_replication_write_lag * MB);
}
// in case PG does not have shard map initialized, we assume PG always has 1 shard at minimum.
shardno_t num_shards = Max(1, get_num_shards());
int tenant_max_replication_apply_lag = num_shards * max_replication_apply_lag;
int tenant_max_replication_flush_lag = num_shards * max_replication_flush_lag;
int tenant_max_replication_write_lag = num_shards * max_replication_write_lag;
if ((flushPtr != InvalidXLogRecPtr && max_replication_flush_lag > 0 && myFlushLsn > flushPtr + max_replication_flush_lag * MB))
{
return (myFlushLsn - flushPtr - max_replication_flush_lag * MB);
}
if ((writePtr != InvalidXLogRecPtr && tenant_max_replication_write_lag > 0 && myFlushLsn > writePtr + tenant_max_replication_write_lag * MB))
{
return (myFlushLsn - writePtr - tenant_max_replication_write_lag * MB);
}
if ((applyPtr != InvalidXLogRecPtr && max_replication_apply_lag > 0 && myFlushLsn > applyPtr + max_replication_apply_lag * MB))
if ((flushPtr != InvalidXLogRecPtr && tenant_max_replication_flush_lag > 0 && myFlushLsn > flushPtr + tenant_max_replication_flush_lag * MB))
{
return (myFlushLsn - flushPtr - tenant_max_replication_flush_lag * MB);
}
if ((applyPtr != InvalidXLogRecPtr && tenant_max_replication_apply_lag > 0 && myFlushLsn > applyPtr + tenant_max_replication_apply_lag * MB))
{
return (myFlushLsn - applyPtr - tenant_max_replication_apply_lag * MB);
}
}
else
{
return (myFlushLsn - applyPtr - max_replication_apply_lag * MB);
if ((writePtr != InvalidXLogRecPtr && max_replication_write_lag > 0 && myFlushLsn > writePtr + max_replication_write_lag * MB))
{
return (myFlushLsn - writePtr - max_replication_write_lag * MB);
}
if ((flushPtr != InvalidXLogRecPtr && max_replication_flush_lag > 0 && myFlushLsn > flushPtr + max_replication_flush_lag * MB))
{
return (myFlushLsn - flushPtr - max_replication_flush_lag * MB);
}
if ((applyPtr != InvalidXLogRecPtr && max_replication_apply_lag > 0 && myFlushLsn > applyPtr + max_replication_apply_lag * MB))
{
return (myFlushLsn - applyPtr - max_replication_apply_lag * MB);
}
}
}
return 0;

View File

@@ -33,7 +33,6 @@ env_logger.workspace = true
framed-websockets.workspace = true
futures.workspace = true
hashbrown.workspace = true
hashlink.workspace = true
hex.workspace = true
hmac.workspace = true
hostname.workspace = true
@@ -54,6 +53,7 @@ json = { path = "../libs/proxy/json" }
lasso = { workspace = true, features = ["multi-threaded"] }
measured = { workspace = true, features = ["lasso"] }
metrics.workspace = true
moka.workspace = true
once_cell.workspace = true
opentelemetry = { workspace = true, features = ["trace"] }
papaya = "0.2.0"
@@ -107,10 +107,11 @@ uuid.workspace = true
x509-cert.workspace = true
redis.workspace = true
zerocopy.workspace = true
zeroize.workspace = true
# uncomment this to use the real subzero-core crate
# subzero-core = { git = "https://github.com/neondatabase/subzero", rev = "396264617e78e8be428682f87469bb25429af88a", features = ["postgresql"], optional = true }
# this is a stub for the subzero-core crate
subzero-core = { path = "./subzero_core", features = ["postgresql"], optional = true}
subzero-core = { path = "../libs/proxy/subzero_core", features = ["postgresql"], optional = true}
ouroboros = { version = "0.18", optional = true }
# jwt stuff

View File

@@ -8,11 +8,12 @@ use tracing::{info, info_span};
use crate::auth::backend::ComputeUserInfo;
use crate::cache::Cached;
use crate::cache::node_info::CachedNodeInfo;
use crate::compute::AuthInfo;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::client::cplane_proxy_v1;
use crate::control_plane::{self, CachedNodeInfo, NodeInfo};
use crate::control_plane::{self, NodeInfo};
use crate::error::{ReportableError, UserFacingError};
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;

View File

@@ -6,7 +6,7 @@ use crate::auth::{self, AuthFlow};
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::sasl;
use crate::stream::{self, Stream};
@@ -25,13 +25,15 @@ pub(crate) async fn authenticate_cleartext(
ctx.set_auth_method(crate::context::AuthMethod::Cleartext);
let ep = EndpointIdInt::from(&info.endpoint);
let role = RoleNameInt::from(&info.user);
let auth_flow = AuthFlow::new(
client,
auth::CleartextPassword {
secret,
endpoint: ep,
pool: config.thread_pool.clone(),
role,
pool: config.scram_thread_pool.clone(),
},
);
let auth_outcome = {

View File

@@ -16,16 +16,16 @@ use tracing::{debug, info};
use crate::auth::{self, ComputeUserInfoMaybeEndpoint, validate_password_and_exchange};
use crate::cache::Cached;
use crate::cache::node_info::CachedNodeInfo;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ControlPlaneClient;
use crate::control_plane::errors::GetAuthInfoError;
use crate::control_plane::messages::EndpointRateLimitConfig;
use crate::control_plane::{
self, AccessBlockerFlags, AuthSecret, CachedNodeInfo, ControlPlaneApi, EndpointAccessControl,
RoleAccessControl,
self, AccessBlockerFlags, AuthSecret, ControlPlaneApi, EndpointAccessControl, RoleAccessControl,
};
use crate::intern::EndpointIdInt;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::pqproto::BeMessage;
use crate::proxy::NeonOptions;
use crate::proxy::wake_compute::WakeComputeBackend;
@@ -273,9 +273,11 @@ async fn authenticate_with_secret(
) -> auth::Result<ComputeCredentials> {
if let Some(password) = unauthenticated_password {
let ep = EndpointIdInt::from(&info.endpoint);
let role = RoleNameInt::from(&info.user);
let auth_outcome =
validate_password_and_exchange(&config.thread_pool, ep, &password, secret).await?;
validate_password_and_exchange(&config.scram_thread_pool, ep, role, &password, secret)
.await?;
let keys = match auth_outcome {
crate::sasl::Outcome::Success(key) => key,
crate::sasl::Outcome::Failure(reason) => {
@@ -433,11 +435,12 @@ mod tests {
use super::auth_quirks;
use super::jwt::JwkCache;
use crate::auth::{ComputeUserInfoMaybeEndpoint, IpPattern};
use crate::cache::node_info::CachedNodeInfo;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::EndpointRateLimitConfig;
use crate::control_plane::{
self, AccessBlockerFlags, CachedNodeInfo, EndpointAccessControl, RoleAccessControl,
self, AccessBlockerFlags, EndpointAccessControl, RoleAccessControl,
};
use crate::proxy::NeonOptions;
use crate::rate_limiter::EndpointRateLimiter;
@@ -498,7 +501,7 @@ mod tests {
static CONFIG: Lazy<AuthenticationConfig> = Lazy::new(|| AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(1),
scram_thread_pool: ThreadPool::new(1),
scram_protocol_timeout: std::time::Duration::from_secs(5),
ip_allowlist_check_enabled: true,
is_vpc_acccess_proxy: false,

View File

@@ -10,7 +10,7 @@ use super::backend::ComputeCredentialKeys;
use super::{AuthError, PasswordHackPayload};
use crate::context::RequestContext;
use crate::control_plane::AuthSecret;
use crate::intern::EndpointIdInt;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::pqproto::{BeAuthenticationSaslMessage, BeMessage};
use crate::sasl;
use crate::scram::threadpool::ThreadPool;
@@ -46,6 +46,7 @@ pub(crate) struct PasswordHack;
pub(crate) struct CleartextPassword {
pub(crate) pool: Arc<ThreadPool>,
pub(crate) endpoint: EndpointIdInt,
pub(crate) role: RoleNameInt,
pub(crate) secret: AuthSecret,
}
@@ -111,6 +112,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, CleartextPassword> {
let outcome = validate_password_and_exchange(
&self.state.pool,
self.state.endpoint,
self.state.role,
password,
self.state.secret,
)
@@ -165,13 +167,15 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
pub(crate) async fn validate_password_and_exchange(
pool: &ThreadPool,
endpoint: EndpointIdInt,
role: RoleNameInt,
password: &[u8],
secret: AuthSecret,
) -> super::Result<sasl::Outcome<ComputeCredentialKeys>> {
match secret {
// perform scram authentication as both client and server to validate the keys
AuthSecret::Scram(scram_secret) => {
let outcome = crate::scram::exchange(pool, endpoint, &scram_secret, password).await?;
let outcome =
crate::scram::exchange(pool, endpoint, role, &scram_secret, password).await?;
let client_key = match outcome {
sasl::Outcome::Success(client_key) => client_key,

View File

@@ -29,7 +29,7 @@ use crate::config::{
};
use crate::control_plane::locks::ApiLocks;
use crate::http::health_server::AppMetrics;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::metrics::{Metrics, ServiceInfo};
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::cancel_set::CancelSet;
@@ -114,8 +114,6 @@ pub async fn run() -> anyhow::Result<()> {
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
// TODO: refactor these to use labels
debug!("Version: {GIT_VERSION}");
debug!("Build_tag: {BUILD_TAG}");
@@ -207,6 +205,11 @@ pub async fn run() -> anyhow::Result<()> {
endpoint_rate_limiter,
);
Metrics::get()
.service
.info
.set_label(ServiceInfo::running());
match futures::future::select(pin!(maintenance_tasks.join_next()), pin!(task)).await {
// exit immediately on maintenance task completion
Either::Left((Some(res), _)) => match crate::error::flatten_err(res)? {},
@@ -279,7 +282,7 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
http_config,
authentication_config: AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool: ThreadPool::new(0),
scram_thread_pool: ThreadPool::new(0),
scram_protocol_timeout: Duration::from_secs(10),
ip_allowlist_check_enabled: true,
is_vpc_acccess_proxy: false,

View File

@@ -26,7 +26,7 @@ use utils::project_git_version;
use utils::sentry_init::init_sentry;
use crate::context::RequestContext;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::metrics::{Metrics, ServiceInfo};
use crate::pglb::TlsRequired;
use crate::pqproto::FeStartupPacket;
use crate::protocol2::ConnectionInfo;
@@ -80,8 +80,6 @@ pub async fn run() -> anyhow::Result<()> {
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
let args = cli().get_matches();
let destination: String = args
.get_one::<String>("dest")
@@ -135,6 +133,12 @@ pub async fn run() -> anyhow::Result<()> {
cancellation_token.clone(),
))
.map(crate::error::flatten_err);
Metrics::get()
.service
.info
.set_label(ServiceInfo::running());
let signals_task = tokio::spawn(crate::signals::handle(cancellation_token, || {}));
// the signal task cant ever succeed.

View File

@@ -40,7 +40,7 @@ use crate::config::{
};
use crate::context::parquet::ParquetUploadArgs;
use crate::http::health_server::AppMetrics;
use crate::metrics::Metrics;
use crate::metrics::{Metrics, ServiceInfo};
use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::redis::kv_ops::RedisKVClient;
@@ -535,12 +535,7 @@ pub async fn run() -> anyhow::Result<()> {
// add a task to flush the db_schema cache every 10 minutes
#[cfg(feature = "rest_broker")]
if let Some(db_schema_cache) = &config.rest_config.db_schema_cache {
maintenance_tasks.spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(600)).await;
db_schema_cache.flush();
}
});
maintenance_tasks.spawn(db_schema_cache.maintain());
}
if let Some(metrics_config) = &config.metric_collection {
@@ -590,6 +585,11 @@ pub async fn run() -> anyhow::Result<()> {
}
}
Metrics::get()
.service
.info
.set_label(ServiceInfo::running());
let maintenance = loop {
// get one complete task
match futures::future::select(
@@ -617,7 +617,12 @@ pub async fn run() -> anyhow::Result<()> {
/// ProxyConfig is created at proxy startup, and lives forever.
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
Metrics::install(thread_pool.metrics.clone());
Metrics::get()
.proxy
.scram_pool
.0
.set(thread_pool.metrics.clone())
.ok();
let tls_config = match (&args.tls_key, &args.tls_cert) {
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(
@@ -690,7 +695,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
};
let authentication_config = AuthenticationConfig {
jwks_cache: JwkCache::default(),
thread_pool,
scram_thread_pool: thread_pool,
scram_protocol_timeout: args.scram_protocol_timeout,
ip_allowlist_check_enabled: !args.is_private_access_proxy,
is_vpc_acccess_proxy: args.is_private_access_proxy,
@@ -711,12 +716,7 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
info!("Using DbSchemaCache with options={db_schema_cache_config:?}");
let db_schema_cache = if args.is_rest_broker {
Some(DbSchemaCache::new(
"db_schema_cache",
db_schema_cache_config.size,
db_schema_cache_config.ttl,
true,
))
Some(DbSchemaCache::new(db_schema_cache_config))
} else {
None
};

View File

@@ -1,4 +1,16 @@
use std::ops::{Deref, DerefMut};
use std::time::{Duration, Instant};
use moka::Expiry;
use moka::notification::RemovalCause;
use crate::control_plane::messages::ControlPlaneErrorMessage;
use crate::metrics::{
CacheEviction, CacheKind, CacheOutcome, CacheOutcomeGroup, CacheRemovalCause, Metrics,
};
/// Default TTL used when caching errors from control plane.
pub const DEFAULT_ERROR_TTL: Duration = Duration::from_secs(30);
/// A generic trait which exposes types of cache's key and value,
/// as well as the notion of cache entry invalidation.
@@ -10,20 +22,16 @@ pub(crate) trait Cache {
/// Entry's value.
type Value;
/// Used for entry invalidation.
type LookupInfo<Key>;
/// Invalidate an entry using a lookup info.
/// We don't have an empty default impl because it's error-prone.
fn invalidate(&self, _: &Self::LookupInfo<Self::Key>);
fn invalidate(&self, _: &Self::Key);
}
impl<C: Cache> Cache for &C {
type Key = C::Key;
type Value = C::Value;
type LookupInfo<Key> = C::LookupInfo<Key>;
fn invalidate(&self, info: &Self::LookupInfo<Self::Key>) {
fn invalidate(&self, info: &Self::Key) {
C::invalidate(self, info);
}
}
@@ -31,7 +39,7 @@ impl<C: Cache> Cache for &C {
/// Wrapper for convenient entry invalidation.
pub(crate) struct Cached<C: Cache, V = <C as Cache>::Value> {
/// Cache + lookup info.
pub(crate) token: Option<(C, C::LookupInfo<C::Key>)>,
pub(crate) token: Option<(C, C::Key)>,
/// The value itself.
pub(crate) value: V,
@@ -43,23 +51,6 @@ impl<C: Cache, V> Cached<C, V> {
Self { token: None, value }
}
pub(crate) fn take_value(self) -> (Cached<C, ()>, V) {
(
Cached {
token: self.token,
value: (),
},
self.value,
)
}
pub(crate) fn map<U>(self, f: impl FnOnce(V) -> U) -> Cached<C, U> {
Cached {
token: self.token,
value: f(self.value),
}
}
/// Drop this entry from a cache if it's still there.
pub(crate) fn invalidate(self) -> V {
if let Some((cache, info)) = &self.token {
@@ -87,3 +78,91 @@ impl<C: Cache, V> DerefMut for Cached<C, V> {
&mut self.value
}
}
pub type ControlPlaneResult<T> = Result<T, Box<ControlPlaneErrorMessage>>;
#[derive(Clone, Copy)]
pub struct CplaneExpiry {
pub error: Duration,
}
impl Default for CplaneExpiry {
fn default() -> Self {
Self {
error: DEFAULT_ERROR_TTL,
}
}
}
impl CplaneExpiry {
pub fn expire_early<V>(
&self,
value: &ControlPlaneResult<V>,
updated: Instant,
) -> Option<Duration> {
match value {
Ok(_) => None,
Err(err) => Some(self.expire_err_early(err, updated)),
}
}
pub fn expire_err_early(&self, err: &ControlPlaneErrorMessage, updated: Instant) -> Duration {
err.status
.as_ref()
.and_then(|s| s.details.retry_info.as_ref())
.map_or(self.error, |r| r.retry_at.into_std() - updated)
}
}
impl<K, V> Expiry<K, ControlPlaneResult<V>> for CplaneExpiry {
fn expire_after_create(
&self,
_key: &K,
value: &ControlPlaneResult<V>,
created_at: Instant,
) -> Option<Duration> {
self.expire_early(value, created_at)
}
fn expire_after_update(
&self,
_key: &K,
value: &ControlPlaneResult<V>,
updated_at: Instant,
_duration_until_expiry: Option<Duration>,
) -> Option<Duration> {
self.expire_early(value, updated_at)
}
}
pub fn eviction_listener(kind: CacheKind, cause: RemovalCause) {
let cause = match cause {
RemovalCause::Expired => CacheRemovalCause::Expired,
RemovalCause::Explicit => CacheRemovalCause::Explicit,
RemovalCause::Replaced => CacheRemovalCause::Replaced,
RemovalCause::Size => CacheRemovalCause::Size,
};
Metrics::get()
.cache
.evicted_total
.inc(CacheEviction { cache: kind, cause });
}
#[inline]
pub fn count_cache_outcome<T>(kind: CacheKind, cache_result: Option<T>) -> Option<T> {
let outcome = if cache_result.is_some() {
CacheOutcome::Hit
} else {
CacheOutcome::Miss
};
Metrics::get().cache.request_total.inc(CacheOutcomeGroup {
cache: kind,
outcome,
});
cache_result
}
#[inline]
pub fn count_cache_insert(kind: CacheKind) {
Metrics::get().cache.inserted_total.inc(kind);
}

View File

@@ -1,6 +1,5 @@
pub(crate) mod common;
pub(crate) mod node_info;
pub(crate) mod project_info;
mod timed_lru;
pub(crate) use common::{Cache, Cached};
pub(crate) use timed_lru::TimedLru;
pub(crate) use common::{Cached, ControlPlaneResult, CplaneExpiry};

60
proxy/src/cache/node_info.rs vendored Normal file
View File

@@ -0,0 +1,60 @@
use crate::cache::common::{Cache, count_cache_insert, count_cache_outcome, eviction_listener};
use crate::cache::{Cached, ControlPlaneResult, CplaneExpiry};
use crate::config::CacheOptions;
use crate::control_plane::NodeInfo;
use crate::metrics::{CacheKind, Metrics};
use crate::types::EndpointCacheKey;
pub(crate) struct NodeInfoCache(moka::sync::Cache<EndpointCacheKey, ControlPlaneResult<NodeInfo>>);
pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
impl Cache for NodeInfoCache {
type Key = EndpointCacheKey;
type Value = ControlPlaneResult<NodeInfo>;
fn invalidate(&self, info: &EndpointCacheKey) {
self.0.invalidate(info);
}
}
impl NodeInfoCache {
pub fn new(config: CacheOptions) -> Self {
let builder = moka::sync::Cache::builder()
.name("node_info")
.expire_after(CplaneExpiry::default());
let builder = config.moka(builder);
if let Some(size) = config.size {
Metrics::get()
.cache
.capacity
.set(CacheKind::NodeInfo, size as i64);
}
let builder = builder
.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::NodeInfo, cause));
Self(builder.build())
}
pub fn insert(&self, key: EndpointCacheKey, value: ControlPlaneResult<NodeInfo>) {
count_cache_insert(CacheKind::NodeInfo);
self.0.insert(key, value);
}
pub fn get(&self, key: &EndpointCacheKey) -> Option<ControlPlaneResult<NodeInfo>> {
count_cache_outcome(CacheKind::NodeInfo, self.0.get(key))
}
pub fn get_entry(
&'static self,
key: &EndpointCacheKey,
) -> Option<ControlPlaneResult<CachedNodeInfo>> {
self.get(key).map(|res| {
res.map(|value| Cached {
token: Some((self, key.clone())),
value,
})
})
}
}

View File

@@ -1,84 +1,20 @@
use std::collections::{HashMap, HashSet, hash_map};
use std::collections::HashSet;
use std::convert::Infallible;
use std::time::Duration;
use async_trait::async_trait;
use clashmap::ClashMap;
use clashmap::mapref::one::Ref;
use rand::Rng;
use tokio::time::Instant;
use moka::sync::Cache;
use tracing::{debug, info};
use crate::cache::common::{
ControlPlaneResult, CplaneExpiry, count_cache_insert, count_cache_outcome, eviction_listener,
};
use crate::config::ProjectInfoCacheOptions;
use crate::control_plane::messages::{ControlPlaneErrorMessage, Reason};
use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::metrics::{CacheKind, Metrics};
use crate::types::{EndpointId, RoleName};
#[async_trait]
pub(crate) trait ProjectInfoCache {
fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt);
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt);
fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt);
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
}
struct Entry<T> {
expires_at: Instant,
value: T,
}
impl<T> Entry<T> {
pub(crate) fn new(value: T, ttl: Duration) -> Self {
Self {
expires_at: Instant::now() + ttl,
value,
}
}
pub(crate) fn get(&self) -> Option<&T> {
(!self.is_expired()).then_some(&self.value)
}
fn is_expired(&self) -> bool {
self.expires_at <= Instant::now()
}
}
struct EndpointInfo {
role_controls: HashMap<RoleNameInt, Entry<ControlPlaneResult<RoleAccessControl>>>,
controls: Option<Entry<ControlPlaneResult<EndpointAccessControl>>>,
}
type ControlPlaneResult<T> = Result<T, Box<ControlPlaneErrorMessage>>;
impl EndpointInfo {
pub(crate) fn get_role_secret_with_ttl(
&self,
role_name: RoleNameInt,
) -> Option<(ControlPlaneResult<RoleAccessControl>, Duration)> {
let entry = self.role_controls.get(&role_name)?;
let ttl = entry.expires_at - Instant::now();
Some((entry.get()?.clone(), ttl))
}
pub(crate) fn get_controls_with_ttl(
&self,
) -> Option<(ControlPlaneResult<EndpointAccessControl>, Duration)> {
let entry = self.controls.as_ref()?;
let ttl = entry.expires_at - Instant::now();
Some((entry.get()?.clone(), ttl))
}
pub(crate) fn invalidate_endpoint(&mut self) {
self.controls = None;
}
pub(crate) fn invalidate_role_secret(&mut self, role_name: RoleNameInt) {
self.role_controls.remove(&role_name);
}
}
/// Cache for project info.
/// This is used to cache auth data for endpoints.
/// Invalidation is done by console notifications or by TTL (if console notifications are disabled).
@@ -86,8 +22,9 @@ impl EndpointInfo {
/// We also store endpoint-to-project mapping in the cache, to be able to access per-endpoint data.
/// One may ask, why the data is stored per project, when on the user request there is only data about the endpoint available?
/// On the cplane side updates are done per project (or per branch), so it's easier to invalidate the whole project cache.
pub struct ProjectInfoCacheImpl {
cache: ClashMap<EndpointIdInt, EndpointInfo>,
pub struct ProjectInfoCache {
role_controls: Cache<(EndpointIdInt, RoleNameInt), ControlPlaneResult<RoleAccessControl>>,
ep_controls: Cache<EndpointIdInt, ControlPlaneResult<EndpointAccessControl>>,
project2ep: ClashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
// FIXME(stefan): we need a way to GC the account2ep map.
@@ -96,16 +33,13 @@ pub struct ProjectInfoCacheImpl {
config: ProjectInfoCacheOptions,
}
#[async_trait]
impl ProjectInfoCache for ProjectInfoCacheImpl {
fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) {
impl ProjectInfoCache {
pub fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) {
info!("invalidating endpoint access for `{endpoint_id}`");
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_endpoint();
}
self.ep_controls.invalidate(&endpoint_id);
}
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
pub fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt) {
info!("invalidating endpoint access for project `{project_id}`");
let endpoints = self
.project2ep
@@ -113,13 +47,11 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_endpoint();
}
self.ep_controls.invalidate(&endpoint_id);
}
}
fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) {
pub fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt) {
info!("invalidating endpoint access for org `{account_id}`");
let endpoints = self
.account2ep
@@ -127,13 +59,15 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_endpoint();
}
self.ep_controls.invalidate(&endpoint_id);
}
}
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt) {
pub fn invalidate_role_secret_for_project(
&self,
project_id: ProjectIdInt,
role_name: RoleNameInt,
) {
info!(
"invalidating role secret for project_id `{}` and role_name `{}`",
project_id, role_name,
@@ -144,47 +78,73 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
.map(|kv| kv.value().clone())
.unwrap_or_default();
for endpoint_id in endpoints {
if let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) {
endpoint_info.invalidate_role_secret(role_name);
}
self.role_controls.invalidate(&(endpoint_id, role_name));
}
}
}
impl ProjectInfoCacheImpl {
impl ProjectInfoCache {
pub(crate) fn new(config: ProjectInfoCacheOptions) -> Self {
Metrics::get().cache.capacity.set(
CacheKind::ProjectInfoRoles,
(config.size * config.max_roles) as i64,
);
Metrics::get()
.cache
.capacity
.set(CacheKind::ProjectInfoEndpoints, config.size as i64);
// we cache errors for 30 seconds, unless retry_at is set.
let expiry = CplaneExpiry::default();
Self {
cache: ClashMap::new(),
role_controls: Cache::builder()
.name("project_info_roles")
.eviction_listener(|_k, _v, cause| {
eviction_listener(CacheKind::ProjectInfoRoles, cause);
})
.max_capacity(config.size * config.max_roles)
.time_to_live(config.ttl)
.expire_after(expiry)
.build(),
ep_controls: Cache::builder()
.name("project_info_endpoints")
.eviction_listener(|_k, _v, cause| {
eviction_listener(CacheKind::ProjectInfoEndpoints, cause);
})
.max_capacity(config.size)
.time_to_live(config.ttl)
.expire_after(expiry)
.build(),
project2ep: ClashMap::new(),
account2ep: ClashMap::new(),
config,
}
}
fn get_endpoint_cache(
&self,
endpoint_id: &EndpointId,
) -> Option<Ref<'_, EndpointIdInt, EndpointInfo>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
self.cache.get(&endpoint_id)
}
pub(crate) fn get_role_secret_with_ttl(
pub(crate) fn get_role_secret(
&self,
endpoint_id: &EndpointId,
role_name: &RoleName,
) -> Option<(ControlPlaneResult<RoleAccessControl>, Duration)> {
) -> Option<ControlPlaneResult<RoleAccessControl>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
let role_name = RoleNameInt::get(role_name)?;
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
endpoint_info.get_role_secret_with_ttl(role_name)
count_cache_outcome(
CacheKind::ProjectInfoRoles,
self.role_controls.get(&(endpoint_id, role_name)),
)
}
pub(crate) fn get_endpoint_access_with_ttl(
pub(crate) fn get_endpoint_access(
&self,
endpoint_id: &EndpointId,
) -> Option<(ControlPlaneResult<EndpointAccessControl>, Duration)> {
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
endpoint_info.get_controls_with_ttl()
) -> Option<ControlPlaneResult<EndpointAccessControl>> {
let endpoint_id = EndpointIdInt::get(endpoint_id)?;
count_cache_outcome(
CacheKind::ProjectInfoEndpoints,
self.ep_controls.get(&endpoint_id),
)
}
pub(crate) fn insert_endpoint_access(
@@ -203,34 +163,17 @@ impl ProjectInfoCacheImpl {
self.insert_project2endpoint(project_id, endpoint_id);
}
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
debug!(
key = &*endpoint_id,
"created a cache entry for endpoint access"
);
let controls = Some(Entry::new(Ok(controls), self.config.ttl));
let role_controls = Entry::new(Ok(role_controls), self.config.ttl);
count_cache_insert(CacheKind::ProjectInfoEndpoints);
count_cache_insert(CacheKind::ProjectInfoRoles);
match self.cache.entry(endpoint_id) {
clashmap::Entry::Vacant(e) => {
e.insert(EndpointInfo {
role_controls: HashMap::from_iter([(role_name, role_controls)]),
controls,
});
}
clashmap::Entry::Occupied(mut e) => {
let ep = e.get_mut();
ep.controls = controls;
if ep.role_controls.len() < self.config.max_roles {
ep.role_controls.insert(role_name, role_controls);
}
}
}
self.ep_controls.insert(endpoint_id, Ok(controls));
self.role_controls
.insert((endpoint_id, role_name), Ok(role_controls));
}
pub(crate) fn insert_endpoint_access_err(
@@ -238,55 +181,34 @@ impl ProjectInfoCacheImpl {
endpoint_id: EndpointIdInt,
role_name: RoleNameInt,
msg: Box<ControlPlaneErrorMessage>,
ttl: Option<Duration>,
) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
debug!(
key = &*endpoint_id,
"created a cache entry for an endpoint access error"
);
let ttl = ttl.unwrap_or(self.config.ttl);
let controls = if msg.get_reason() == Reason::RoleProtected {
// RoleProtected is the only role-specific error that control plane can give us.
// If a given role name does not exist, it still returns a successful response,
// just with an empty secret.
None
} else {
// We can cache all the other errors in EndpointInfo.controls,
// because they don't depend on what role name we pass to control plane.
Some(Entry::new(Err(msg.clone()), ttl))
};
let role_controls = Entry::new(Err(msg), ttl);
match self.cache.entry(endpoint_id) {
clashmap::Entry::Vacant(e) => {
e.insert(EndpointInfo {
role_controls: HashMap::from_iter([(role_name, role_controls)]),
controls,
// RoleProtected is the only role-specific error that control plane can give us.
// If a given role name does not exist, it still returns a successful response,
// just with an empty secret.
if msg.get_reason() != Reason::RoleProtected {
// We can cache all the other errors in ep_controls because they don't
// depend on what role name we pass to control plane.
self.ep_controls
.entry(endpoint_id)
.and_compute_with(|entry| match entry {
// leave the entry alone if it's already Ok
Some(entry) if entry.value().is_ok() => moka::ops::compute::Op::Nop,
// replace the entry
_ => {
count_cache_insert(CacheKind::ProjectInfoEndpoints);
moka::ops::compute::Op::Put(Err(msg.clone()))
}
});
}
clashmap::Entry::Occupied(mut e) => {
let ep = e.get_mut();
if let Some(entry) = &ep.controls
&& !entry.is_expired()
&& entry.value.is_ok()
{
// If we have cached non-expired, non-error controls, keep them.
} else {
ep.controls = controls;
}
if ep.role_controls.len() < self.config.max_roles {
ep.role_controls.insert(role_name, role_controls);
}
}
}
count_cache_insert(CacheKind::ProjectInfoRoles);
self.role_controls
.insert((endpoint_id, role_name), Err(msg));
}
fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
@@ -307,73 +229,35 @@ impl ProjectInfoCacheImpl {
}
}
pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) {
let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else {
return;
};
let Some(role_name) = RoleNameInt::get(role_name) else {
return;
};
let Some(mut endpoint_info) = self.cache.get_mut(&endpoint_id) else {
return;
};
let entry = endpoint_info.role_controls.entry(role_name);
let hash_map::Entry::Occupied(role_controls) = entry else {
return;
};
if role_controls.get().is_expired() {
role_controls.remove();
}
pub fn maybe_invalidate_role_secret(&self, _endpoint_id: &EndpointId, _role_name: &RoleName) {
// TODO: Expire the value early if the key is idle.
// Currently not an issue as we would just use the TTL to decide, which is what already happens.
}
pub async fn gc_worker(&self) -> anyhow::Result<Infallible> {
let mut interval =
tokio::time::interval(self.config.gc_interval / (self.cache.shards().len()) as u32);
let mut interval = tokio::time::interval(self.config.gc_interval);
loop {
interval.tick().await;
if self.cache.len() < self.config.size {
// If there are not too many entries, wait until the next gc cycle.
continue;
}
self.gc();
self.ep_controls.run_pending_tasks();
self.role_controls.run_pending_tasks();
}
}
fn gc(&self) {
let shard = rand::rng().random_range(0..self.project2ep.shards().len());
debug!(shard, "project_info_cache: performing epoch reclamation");
// acquire a random shard lock
let mut removed = 0;
let shard = self.project2ep.shards()[shard].write();
for (_, endpoints) in shard.iter() {
for endpoint in endpoints {
self.cache.remove(endpoint);
removed += 1;
}
}
// We can drop this shard only after making sure that all endpoints are removed.
drop(shard);
info!("project_info_cache: removed {removed} endpoints");
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::time::Duration;
use super::*;
use crate::control_plane::messages::{Details, EndpointRateLimitConfig, ErrorInfo, Status};
use crate::control_plane::{AccessBlockerFlags, AuthSecret};
use crate::scram::ServerSecret;
use std::sync::Arc;
#[tokio::test]
async fn test_project_info_cache_settings() {
tokio::time::pause();
let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
size: 2,
let cache = ProjectInfoCache::new(ProjectInfoCacheOptions {
size: 1,
max_roles: 2,
ttl: Duration::from_secs(1),
gc_interval: Duration::from_secs(600),
@@ -423,22 +307,17 @@ mod tests {
},
);
let (cached, ttl) = cache
.get_role_secret_with_ttl(&endpoint_id, &user1)
.unwrap();
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
assert_eq!(cached.unwrap().secret, secret1);
assert_eq!(ttl, cache.config.ttl);
let (cached, ttl) = cache
.get_role_secret_with_ttl(&endpoint_id, &user2)
.unwrap();
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
assert_eq!(cached.unwrap().secret, secret2);
assert_eq!(ttl, cache.config.ttl);
// Shouldn't add more than 2 roles.
let user3: RoleName = "user3".into();
let secret3 = Some(AuthSecret::Scram(ServerSecret::mock([3; 32])));
cache.role_controls.run_pending_tasks();
cache.insert_endpoint_access(
account_id,
project_id,
@@ -455,31 +334,18 @@ mod tests {
},
);
assert!(
cache
.get_role_secret_with_ttl(&endpoint_id, &user3)
.is_none()
);
cache.role_controls.run_pending_tasks();
assert_eq!(cache.role_controls.entry_count(), 2);
let cached = cache
.get_endpoint_access_with_ttl(&endpoint_id)
.unwrap()
.0
.unwrap();
assert_eq!(cached.allowed_ips, allowed_ips);
tokio::time::sleep(Duration::from_secs(2)).await;
tokio::time::advance(Duration::from_secs(2)).await;
let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user1);
assert!(cached.is_none());
let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user2);
assert!(cached.is_none());
let cached = cache.get_endpoint_access_with_ttl(&endpoint_id);
assert!(cached.is_none());
cache.role_controls.run_pending_tasks();
assert_eq!(cache.role_controls.entry_count(), 0);
}
#[tokio::test]
async fn test_caching_project_info_errors() {
let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
let cache = ProjectInfoCache::new(ProjectInfoCacheOptions {
size: 10,
max_roles: 10,
ttl: Duration::from_secs(1),
@@ -519,34 +385,23 @@ mod tests {
status: None,
});
let get_role_secret = |endpoint_id, role_name| {
cache
.get_role_secret_with_ttl(endpoint_id, role_name)
.unwrap()
.0
};
let get_endpoint_access =
|endpoint_id| cache.get_endpoint_access_with_ttl(endpoint_id).unwrap().0;
let get_role_secret =
|endpoint_id, role_name| cache.get_role_secret(endpoint_id, role_name).unwrap();
let get_endpoint_access = |endpoint_id| cache.get_endpoint_access(endpoint_id).unwrap();
// stores role-specific errors only for get_role_secret
cache.insert_endpoint_access_err(
(&endpoint_id).into(),
(&user1).into(),
role_msg.clone(),
None,
);
cache.insert_endpoint_access_err((&endpoint_id).into(), (&user1).into(), role_msg.clone());
assert_eq!(
get_role_secret(&endpoint_id, &user1).unwrap_err().error,
role_msg.error
);
assert!(cache.get_endpoint_access_with_ttl(&endpoint_id).is_none());
assert!(cache.get_endpoint_access(&endpoint_id).is_none());
// stores non-role specific errors for both get_role_secret and get_endpoint_access
cache.insert_endpoint_access_err(
(&endpoint_id).into(),
(&user1).into(),
generic_msg.clone(),
None,
);
assert_eq!(
get_role_secret(&endpoint_id, &user1).unwrap_err().error,
@@ -558,11 +413,7 @@ mod tests {
);
// error isn't returned for other roles in the same endpoint
assert!(
cache
.get_role_secret_with_ttl(&endpoint_id, &user2)
.is_none()
);
assert!(cache.get_role_secret(&endpoint_id, &user2).is_none());
// success for a role does not overwrite errors for other roles
cache.insert_endpoint_access(
@@ -590,7 +441,6 @@ mod tests {
(&endpoint_id).into(),
(&user2).into(),
generic_msg.clone(),
None,
);
assert!(get_role_secret(&endpoint_id, &user2).is_err());
assert!(get_endpoint_access(&endpoint_id).is_ok());

View File

@@ -1,262 +0,0 @@
use std::borrow::Borrow;
use std::hash::Hash;
use std::time::{Duration, Instant};
// This seems to make more sense than `lru` or `cached`:
//
// * `near/nearcore` ditched `cached` in favor of `lru`
// (https://github.com/near/nearcore/issues?q=is%3Aissue+lru+is%3Aclosed).
//
// * `lru` methods use an obscure `KeyRef` type in their contraints (which is deliberately excluded from docs).
// This severely hinders its usage both in terms of creating wrappers and supported key types.
//
// On the other hand, `hashlink` has good download stats and appears to be maintained.
use hashlink::{LruCache, linked_hash_map::RawEntryMut};
use tracing::debug;
use super::Cache;
use super::common::Cached;
/// An implementation of timed LRU cache with fixed capacity.
/// Key properties:
///
/// * Whenever a new entry is inserted, the least recently accessed one is evicted.
/// The cache also keeps track of entry's insertion time (`created_at`) and TTL (`expires_at`).
///
/// * If `update_ttl_on_retrieval` is `true`. When the entry is about to be retrieved, we check its expiration timestamp.
/// If the entry has expired, we remove it from the cache; Otherwise we bump the
/// expiration timestamp (e.g. +5mins) and change its place in LRU list to prolong
/// its existence.
///
/// * There's an API for immediate invalidation (removal) of a cache entry;
/// It's useful in case we know for sure that the entry is no longer correct.
/// See [`Cached`] for more information.
///
/// * Expired entries are kept in the cache, until they are evicted by the LRU policy,
/// or by a successful lookup (i.e. the entry hasn't expired yet).
/// There is no background job to reap the expired records.
///
/// * It's possible for an entry that has not yet expired entry to be evicted
/// before expired items. That's a bit wasteful, but probably fine in practice.
pub(crate) struct TimedLru<K, V> {
/// Cache's name for tracing.
name: &'static str,
/// The underlying cache implementation.
cache: parking_lot::Mutex<LruCache<K, Entry<V>>>,
/// Default time-to-live of a single entry.
ttl: Duration,
update_ttl_on_retrieval: bool,
}
impl<K: Hash + Eq, V> Cache for TimedLru<K, V> {
type Key = K;
type Value = V;
type LookupInfo<Key> = Key;
fn invalidate(&self, info: &Self::LookupInfo<K>) {
self.invalidate_raw(info);
}
}
struct Entry<T> {
created_at: Instant,
expires_at: Instant,
ttl: Duration,
update_ttl_on_retrieval: bool,
value: T,
}
impl<K: Hash + Eq, V> TimedLru<K, V> {
/// Construct a new LRU cache with timed entries.
pub(crate) fn new(
name: &'static str,
capacity: usize,
ttl: Duration,
update_ttl_on_retrieval: bool,
) -> Self {
Self {
name,
cache: LruCache::new(capacity).into(),
ttl,
update_ttl_on_retrieval,
}
}
/// Drop an entry from the cache if it's outdated.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn invalidate_raw(&self, key: &K) {
// Do costly things before taking the lock.
let mut cache = self.cache.lock();
let entry = match cache.raw_entry_mut().from_key(key) {
RawEntryMut::Vacant(_) => return,
RawEntryMut::Occupied(x) => x.remove(),
};
drop(cache); // drop lock before logging
let Entry {
created_at,
expires_at,
..
} = entry;
debug!(
?created_at,
?expires_at,
"processed a cache entry invalidation event"
);
}
/// Try retrieving an entry by its key, then execute `extract` if it exists.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn get_raw<Q, R>(&self, key: &Q, extract: impl FnOnce(&K, &Entry<V>) -> R) -> Option<R>
where
K: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
let now = Instant::now();
// Do costly things before taking the lock.
let mut cache = self.cache.lock();
let mut raw_entry = match cache.raw_entry_mut().from_key(key) {
RawEntryMut::Vacant(_) => return None,
RawEntryMut::Occupied(x) => x,
};
// Immeditely drop the entry if it has expired.
let entry = raw_entry.get();
if entry.expires_at <= now {
raw_entry.remove();
return None;
}
let value = extract(raw_entry.key(), entry);
let (created_at, expires_at) = (entry.created_at, entry.expires_at);
// Update the deadline and the entry's position in the LRU list.
let deadline = now.checked_add(raw_entry.get().ttl).expect("time overflow");
if raw_entry.get().update_ttl_on_retrieval {
raw_entry.get_mut().expires_at = deadline;
}
raw_entry.to_back();
drop(cache); // drop lock before logging
debug!(
created_at = format_args!("{created_at:?}"),
old_expires_at = format_args!("{expires_at:?}"),
new_expires_at = format_args!("{deadline:?}"),
"accessed a cache entry"
);
Some(value)
}
/// Insert an entry to the cache. If an entry with the same key already
/// existed, return the previous value and its creation timestamp.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn insert_raw(&self, key: K, value: V) -> (Instant, Option<V>) {
self.insert_raw_ttl(key, value, self.ttl, self.update_ttl_on_retrieval)
}
/// Insert an entry to the cache. If an entry with the same key already
/// existed, return the previous value and its creation timestamp.
#[tracing::instrument(level = "debug", fields(cache = self.name), skip_all)]
fn insert_raw_ttl(
&self,
key: K,
value: V,
ttl: Duration,
update: bool,
) -> (Instant, Option<V>) {
let created_at = Instant::now();
let expires_at = created_at.checked_add(ttl).expect("time overflow");
let entry = Entry {
created_at,
expires_at,
ttl,
update_ttl_on_retrieval: update,
value,
};
// Do costly things before taking the lock.
let old = self
.cache
.lock()
.insert(key, entry)
.map(|entry| entry.value);
debug!(
created_at = format_args!("{created_at:?}"),
expires_at = format_args!("{expires_at:?}"),
replaced = old.is_some(),
"created a cache entry"
);
(created_at, old)
}
}
impl<K: Hash + Eq + Clone, V: Clone> TimedLru<K, V> {
pub(crate) fn insert_ttl(&self, key: K, value: V, ttl: Duration) {
self.insert_raw_ttl(key, value, ttl, false);
}
#[cfg(feature = "rest_broker")]
pub(crate) fn insert(&self, key: K, value: V) {
self.insert_raw_ttl(key, value, self.ttl, self.update_ttl_on_retrieval);
}
pub(crate) fn insert_unit(&self, key: K, value: V) -> (Option<V>, Cached<&Self, ()>) {
let (_, old) = self.insert_raw(key.clone(), value);
let cached = Cached {
token: Some((self, key)),
value: (),
};
(old, cached)
}
#[cfg(feature = "rest_broker")]
pub(crate) fn flush(&self) {
let now = Instant::now();
let mut cache = self.cache.lock();
// Collect keys of expired entries first
let expired_keys: Vec<_> = cache
.iter()
.filter_map(|(key, entry)| {
if entry.expires_at <= now {
Some(key.clone())
} else {
None
}
})
.collect();
// Remove expired entries
for key in expired_keys {
cache.remove(&key);
}
}
}
impl<K: Hash + Eq, V: Clone> TimedLru<K, V> {
/// Retrieve a cached entry in convenient wrapper, alongside timing information.
pub(crate) fn get_with_created_at<Q>(
&self,
key: &Q,
) -> Option<Cached<&Self, (<Self as Cache>::Value, Instant)>>
where
K: Borrow<Q> + Clone,
Q: Hash + Eq + ?Sized,
{
self.get_raw(key, |key, entry| Cached {
token: Some((self, key.clone())),
value: (entry.value.clone(), entry.created_at),
})
}
}

View File

@@ -8,6 +8,7 @@ use futures::{FutureExt, TryFutureExt};
use itertools::Itertools;
use postgres_client::config::{AuthKeys, ChannelBinding, SslMode};
use postgres_client::connect_raw::StartupStream;
use postgres_client::error::SqlState;
use postgres_client::maybe_tls_stream::MaybeTlsStream;
use postgres_client::tls::MakeTlsConnect;
use thiserror::Error;
@@ -22,7 +23,7 @@ use crate::context::RequestContext;
use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::error::{ReportableError, UserFacingError};
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumDbConnectionsGuard};
use crate::pqproto::StartupMessageParams;
use crate::proxy::connect_compute::TlsNegotiation;
@@ -65,12 +66,13 @@ impl UserFacingError for PostgresError {
}
impl ReportableError for PostgresError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
fn get_error_kind(&self) -> ErrorKind {
match self {
PostgresError::Postgres(e) if e.as_db_error().is_some() => {
crate::error::ErrorKind::Postgres
}
PostgresError::Postgres(_) => crate::error::ErrorKind::Compute,
PostgresError::Postgres(err) => match err.as_db_error() {
Some(err) if err.code() == &SqlState::INVALID_CATALOG_NAME => ErrorKind::User,
Some(_) => ErrorKind::Postgres,
None => ErrorKind::Compute,
},
}
}
}
@@ -110,9 +112,9 @@ impl UserFacingError for ConnectionError {
}
impl ReportableError for ConnectionError {
fn get_error_kind(&self) -> crate::error::ErrorKind {
fn get_error_kind(&self) -> ErrorKind {
match self {
ConnectionError::TlsError(_) => crate::error::ErrorKind::Compute,
ConnectionError::TlsError(_) => ErrorKind::Compute,
ConnectionError::WakeComputeError(e) => e.get_error_kind(),
ConnectionError::TooManyConnectionAttempts(e) => e.get_error_kind(),
#[cfg(test)]

View File

@@ -19,7 +19,7 @@ use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use crate::ext::TaskExt;
use crate::intern::RoleNameInt;
use crate::rate_limiter::{RateLimitAlgorithm, RateLimiterConfig};
use crate::scram::threadpool::ThreadPool;
use crate::scram;
use crate::serverless::GlobalConnPoolOptions;
use crate::serverless::cancel_set::CancelSet;
#[cfg(feature = "rest_broker")]
@@ -75,7 +75,7 @@ pub struct HttpConfig {
}
pub struct AuthenticationConfig {
pub thread_pool: Arc<ThreadPool>,
pub scram_thread_pool: Arc<scram::threadpool::ThreadPool>,
pub scram_protocol_timeout: tokio::time::Duration,
pub ip_allowlist_check_enabled: bool,
pub is_vpc_acccess_proxy: bool,
@@ -107,20 +107,23 @@ pub fn remote_storage_from_toml(s: &str) -> anyhow::Result<RemoteStorageConfig>
#[derive(Debug)]
pub struct CacheOptions {
/// Max number of entries.
pub size: usize,
pub size: Option<u64>,
/// Entry's time-to-live.
pub ttl: Duration,
pub absolute_ttl: Option<Duration>,
/// Entry's time-to-idle.
pub idle_ttl: Option<Duration>,
}
impl CacheOptions {
/// Default options for [`crate::control_plane::NodeInfoCache`].
pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,ttl=4m";
/// Default options for [`crate::cache::node_info::NodeInfoCache`].
pub const CACHE_DEFAULT_OPTIONS: &'static str = "size=4000,idle_ttl=4m";
/// Parse cache options passed via cmdline.
/// Example: [`Self::CACHE_DEFAULT_OPTIONS`].
fn parse(options: &str) -> anyhow::Result<Self> {
let mut size = None;
let mut ttl = None;
let mut absolute_ttl = None;
let mut idle_ttl = None;
for option in options.split(',') {
let (key, value) = option
@@ -129,21 +132,34 @@ impl CacheOptions {
match key {
"size" => size = Some(value.parse()?),
"ttl" => ttl = Some(humantime::parse_duration(value)?),
"absolute_ttl" | "ttl" => absolute_ttl = Some(humantime::parse_duration(value)?),
"idle_ttl" | "tti" => idle_ttl = Some(humantime::parse_duration(value)?),
unknown => bail!("unknown key: {unknown}"),
}
}
// TTL doesn't matter if cache is always empty.
if let Some(0) = size {
ttl.get_or_insert(Duration::default());
}
Ok(Self {
size: size.context("missing `size`")?,
ttl: ttl.context("missing `ttl`")?,
size,
absolute_ttl,
idle_ttl,
})
}
pub fn moka<K, V, C>(
&self,
mut builder: moka::sync::CacheBuilder<K, V, C>,
) -> moka::sync::CacheBuilder<K, V, C> {
if let Some(size) = self.size {
builder = builder.max_capacity(size);
}
if let Some(ttl) = self.absolute_ttl {
builder = builder.time_to_live(ttl);
}
if let Some(tti) = self.idle_ttl {
builder = builder.time_to_idle(tti);
}
builder
}
}
impl FromStr for CacheOptions {
@@ -159,17 +175,17 @@ impl FromStr for CacheOptions {
#[derive(Debug)]
pub struct ProjectInfoCacheOptions {
/// Max number of entries.
pub size: usize,
pub size: u64,
/// Entry's time-to-live.
pub ttl: Duration,
/// Max number of roles per endpoint.
pub max_roles: usize,
pub max_roles: u64,
/// Gc interval.
pub gc_interval: Duration,
}
impl ProjectInfoCacheOptions {
/// Default options for [`crate::control_plane::NodeInfoCache`].
/// Default options for [`crate::cache::project_info::ProjectInfoCache`].
pub const CACHE_DEFAULT_OPTIONS: &'static str =
"size=10000,ttl=4m,max_roles=10,gc_interval=60m";
@@ -496,21 +512,37 @@ mod tests {
#[test]
fn test_parse_cache_options() -> anyhow::Result<()> {
let CacheOptions { size, ttl } = "size=4096,ttl=5min".parse()?;
assert_eq!(size, 4096);
assert_eq!(ttl, Duration::from_secs(5 * 60));
let CacheOptions {
size,
absolute_ttl,
idle_ttl: _,
} = "size=4096,ttl=5min".parse()?;
assert_eq!(size, Some(4096));
assert_eq!(absolute_ttl, Some(Duration::from_secs(5 * 60)));
let CacheOptions { size, ttl } = "ttl=4m,size=2".parse()?;
assert_eq!(size, 2);
assert_eq!(ttl, Duration::from_secs(4 * 60));
let CacheOptions {
size,
absolute_ttl,
idle_ttl: _,
} = "ttl=4m,size=2".parse()?;
assert_eq!(size, Some(2));
assert_eq!(absolute_ttl, Some(Duration::from_secs(4 * 60)));
let CacheOptions { size, ttl } = "size=0,ttl=1s".parse()?;
assert_eq!(size, 0);
assert_eq!(ttl, Duration::from_secs(1));
let CacheOptions {
size,
absolute_ttl,
idle_ttl: _,
} = "size=0,ttl=1s".parse()?;
assert_eq!(size, Some(0));
assert_eq!(absolute_ttl, Some(Duration::from_secs(1)));
let CacheOptions { size, ttl } = "size=0".parse()?;
assert_eq!(size, 0);
assert_eq!(ttl, Duration::default());
let CacheOptions {
size,
absolute_ttl,
idle_ttl: _,
} = "size=0".parse()?;
assert_eq!(size, Some(0));
assert_eq!(absolute_ttl, None);
Ok(())
}

View File

@@ -3,7 +3,6 @@
use std::net::IpAddr;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use ::http::HeaderName;
use ::http::header::AUTHORIZATION;
@@ -17,6 +16,8 @@ use tracing::{Instrument, debug, info, info_span, warn};
use super::super::messages::{ControlPlaneErrorMessage, GetEndpointAccessControl, WakeCompute};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::cache::Cached;
use crate::cache::node_info::CachedNodeInfo;
use crate::context::RequestContext;
use crate::control_plane::caches::ApiCaches;
use crate::control_plane::errors::{
@@ -25,8 +26,7 @@ use crate::control_plane::errors::{
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse};
use crate::control_plane::{
AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
RoleAccessControl,
AccessBlockerFlags, AuthInfo, AuthSecret, EndpointAccessControl, NodeInfo, RoleAccessControl,
};
use crate::metrics::Metrics;
use crate::proxy::retry::CouldRetry;
@@ -118,7 +118,6 @@ impl NeonControlPlaneClient {
cache_key.into(),
role.into(),
msg.clone(),
retry_info.map(|r| Duration::from_millis(r.retry_delay_ms)),
);
Err(err)
@@ -347,18 +346,11 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
) -> Result<RoleAccessControl, GetAuthInfoError> {
let key = endpoint.normalize();
if let Some((role_control, ttl)) = self
.caches
.project_info
.get_role_secret_with_ttl(&key, role)
{
if let Some(role_control) = self.caches.project_info.get_role_secret(&key, role) {
return match role_control {
Err(mut msg) => {
Err(msg) => {
info!(key = &*key, "found cached get_role_access_control error");
// if retry_delay_ms is set change it to the remaining TTL
replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64);
Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg)))
}
Ok(role_control) => {
@@ -383,17 +375,14 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
) -> Result<EndpointAccessControl, GetAuthInfoError> {
let key = endpoint.normalize();
if let Some((control, ttl)) = self.caches.project_info.get_endpoint_access_with_ttl(&key) {
if let Some(control) = self.caches.project_info.get_endpoint_access(&key) {
return match control {
Err(mut msg) => {
Err(msg) => {
info!(
key = &*key,
"found cached get_endpoint_access_control error"
);
// if retry_delay_ms is set change it to the remaining TTL
replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64);
Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg)))
}
Ok(control) => {
@@ -426,17 +415,11 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
macro_rules! check_cache {
() => {
if let Some(cached) = self.caches.node_info.get_with_created_at(&key) {
let (cached, (info, created_at)) = cached.take_value();
if let Some(info) = self.caches.node_info.get_entry(&key) {
return match info {
Err(mut msg) => {
Err(msg) => {
info!(key = &*key, "found cached wake_compute error");
// if retry_delay_ms is set, reduce it by the amount of time it spent in cache
replace_retry_delay_ms(&mut msg, |delay| {
delay.saturating_sub(created_at.elapsed().as_millis() as u64)
});
Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
msg,
)))
@@ -444,7 +427,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
Ok(info) => {
debug!(key = &*key, "found cached compute node info");
ctx.set_project(info.aux.clone());
Ok(cached.map(|()| info))
Ok(info)
}
};
}
@@ -483,10 +466,12 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
let mut stored_node = node.clone();
// store the cached node as 'warm_cached'
stored_node.aux.cold_start_info = ColdStartInfo::WarmCached;
self.caches.node_info.insert(key.clone(), Ok(stored_node));
let (_, cached) = self.caches.node_info.insert_unit(key, Ok(stored_node));
Ok(cached.map(|()| node))
Ok(Cached {
token: Some((&self.caches.node_info, key)),
value: node,
})
}
Err(err) => match err {
WakeComputeError::ControlPlane(ControlPlaneError::Message(ref msg)) => {
@@ -503,11 +488,7 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
"created a cache entry for the wake compute error"
);
let ttl = retry_info.map_or(Duration::from_secs(30), |r| {
Duration::from_millis(r.retry_delay_ms)
});
self.caches.node_info.insert_ttl(key, Err(msg.clone()), ttl);
self.caches.node_info.insert(key, Err(msg.clone()));
Err(err)
}
@@ -517,14 +498,6 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
}
}
fn replace_retry_delay_ms(msg: &mut ControlPlaneErrorMessage, f: impl FnOnce(u64) -> u64) {
if let Some(status) = &mut msg.status
&& let Some(retry_info) = &mut status.details.retry_info
{
retry_info.retry_delay_ms = f(retry_info.retry_delay_ms);
}
}
/// Parse http response body, taking status code into account.
fn parse_body<T: for<'a> serde::Deserialize<'a>>(
status: StatusCode,

View File

@@ -15,6 +15,7 @@ use crate::auth::IpPattern;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::cache::Cached;
use crate::cache::node_info::CachedNodeInfo;
use crate::compute::ConnectInfo;
use crate::context::RequestContext;
use crate::control_plane::errors::{
@@ -22,8 +23,7 @@ use crate::control_plane::errors::{
};
use crate::control_plane::messages::{EndpointRateLimitConfig, MetricsAuxInfo};
use crate::control_plane::{
AccessBlockerFlags, AuthInfo, AuthSecret, CachedNodeInfo, EndpointAccessControl, NodeInfo,
RoleAccessControl,
AccessBlockerFlags, AuthInfo, AuthSecret, EndpointAccessControl, NodeInfo, RoleAccessControl,
};
use crate::intern::RoleNameInt;
use crate::scram;

View File

@@ -13,10 +13,11 @@ use tracing::{debug, info};
use super::{EndpointAccessControl, RoleAccessControl};
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::{AuthRule, FetchAuthRules, FetchAuthRulesError};
use crate::cache::project_info::ProjectInfoCacheImpl;
use crate::cache::node_info::{CachedNodeInfo, NodeInfoCache};
use crate::cache::project_info::ProjectInfoCache;
use crate::config::{CacheOptions, ProjectInfoCacheOptions};
use crate::context::RequestContext;
use crate::control_plane::{CachedNodeInfo, ControlPlaneApi, NodeInfoCache, errors};
use crate::control_plane::{ControlPlaneApi, errors};
use crate::error::ReportableError;
use crate::metrics::ApiLockMetrics;
use crate::rate_limiter::{DynamicLimiter, Outcome, RateLimiterConfig, Token};
@@ -119,7 +120,7 @@ pub struct ApiCaches {
/// Cache for the `wake_compute` API method.
pub(crate) node_info: NodeInfoCache,
/// Cache which stores project_id -> endpoint_ids mapping.
pub project_info: Arc<ProjectInfoCacheImpl>,
pub project_info: Arc<ProjectInfoCache>,
}
impl ApiCaches {
@@ -128,13 +129,8 @@ impl ApiCaches {
project_info_cache_config: ProjectInfoCacheOptions,
) -> Self {
Self {
node_info: NodeInfoCache::new(
"node_info_cache",
wake_compute_cache_config.size,
wake_compute_cache_config.ttl,
true,
),
project_info: Arc::new(ProjectInfoCacheImpl::new(project_info_cache_config)),
node_info: NodeInfoCache::new(wake_compute_cache_config),
project_info: Arc::new(ProjectInfoCache::new(project_info_cache_config)),
}
}
}

View File

@@ -1,8 +1,10 @@
use std::fmt::{self, Display};
use std::time::Duration;
use measured::FixedCardinalityLabel;
use serde::{Deserialize, Serialize};
use smol_str::SmolStr;
use tokio::time::Instant;
use crate::auth::IpPattern;
use crate::intern::{AccountIdInt, BranchIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
@@ -231,7 +233,13 @@ impl Reason {
#[derive(Copy, Clone, Debug, Deserialize)]
#[allow(dead_code)]
pub(crate) struct RetryInfo {
pub(crate) retry_delay_ms: u64,
#[serde(rename = "retry_delay_ms", deserialize_with = "milliseconds_from_now")]
pub(crate) retry_at: Instant,
}
fn milliseconds_from_now<'de, D: serde::Deserializer<'de>>(d: D) -> Result<Instant, D::Error> {
let millis = u64::deserialize(d)?;
Ok(Instant::now() + Duration::from_millis(millis))
}
#[derive(Debug, Deserialize, Clone)]

View File

@@ -16,13 +16,13 @@ use messages::EndpointRateLimitConfig;
use crate::auth::backend::ComputeUserInfo;
use crate::auth::backend::jwt::AuthRule;
use crate::auth::{AuthError, IpPattern, check_peer_addr_is_in_list};
use crate::cache::{Cached, TimedLru};
use crate::cache::node_info::CachedNodeInfo;
use crate::context::RequestContext;
use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
use crate::control_plane::messages::MetricsAuxInfo;
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt};
use crate::protocol2::ConnectionInfoExtra;
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig};
use crate::types::{EndpointCacheKey, EndpointId, RoleName};
use crate::types::{EndpointId, RoleName};
use crate::{compute, scram};
/// Various cache-related types.
@@ -77,10 +77,6 @@ pub(crate) struct AccessBlockerFlags {
pub vpc_access_blocked: bool,
}
pub(crate) type NodeInfoCache =
TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ControlPlaneErrorMessage>>>;
pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
#[derive(Clone, Debug)]
pub struct RoleAccessControl {
pub secret: Option<AuthSecret>,

View File

@@ -2,59 +2,60 @@ use std::sync::{Arc, OnceLock};
use lasso::ThreadedRodeo;
use measured::label::{
FixedCardinalitySet, LabelGroupSet, LabelName, LabelSet, LabelValue, StaticLabelSet,
FixedCardinalitySet, LabelGroupSet, LabelGroupVisitor, LabelName, LabelSet, LabelValue,
StaticLabelSet,
};
use measured::metric::group::Encoding;
use measured::metric::histogram::Thresholds;
use measured::metric::name::MetricName;
use measured::{
Counter, CounterVec, FixedCardinalityLabel, Gauge, Histogram, HistogramVec, LabelGroup,
MetricGroup,
Counter, CounterVec, FixedCardinalityLabel, Gauge, GaugeVec, Histogram, HistogramVec,
LabelGroup, MetricGroup,
};
use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLogVec};
use metrics::{CounterPairAssoc, CounterPairVec, HyperLogLogVec, InfoMetric};
use tokio::time::{self, Instant};
use crate::control_plane::messages::ColdStartInfo;
use crate::error::ErrorKind;
#[derive(MetricGroup)]
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
#[metric(new())]
pub struct Metrics {
#[metric(namespace = "proxy")]
#[metric(init = ProxyMetrics::new(thread_pool))]
#[metric(init = ProxyMetrics::new())]
pub proxy: ProxyMetrics,
#[metric(namespace = "wake_compute_lock")]
pub wake_compute_lock: ApiLockMetrics,
#[metric(namespace = "service")]
pub service: ServiceMetrics,
#[metric(namespace = "cache")]
pub cache: CacheMetrics,
}
static SELF: OnceLock<Metrics> = OnceLock::new();
impl Metrics {
pub fn install(thread_pool: Arc<ThreadPoolMetrics>) {
let mut metrics = Metrics::new(thread_pool);
metrics.proxy.errors_total.init_all_dense();
metrics.proxy.redis_errors_total.init_all_dense();
metrics.proxy.redis_events_count.init_all_dense();
metrics.proxy.retries_metric.init_all_dense();
metrics.proxy.connection_failures_total.init_all_dense();
SELF.set(metrics)
.ok()
.expect("proxy metrics must not be installed more than once");
}
#[track_caller]
pub fn get() -> &'static Self {
#[cfg(test)]
return SELF.get_or_init(|| Metrics::new(Arc::new(ThreadPoolMetrics::new(0))));
static SELF: OnceLock<Metrics> = OnceLock::new();
#[cfg(not(test))]
SELF.get()
.expect("proxy metrics must be installed by the main() function")
SELF.get_or_init(|| {
let mut metrics = Metrics::new();
metrics.proxy.errors_total.init_all_dense();
metrics.proxy.redis_errors_total.init_all_dense();
metrics.proxy.redis_events_count.init_all_dense();
metrics.proxy.retries_metric.init_all_dense();
metrics.proxy.connection_failures_total.init_all_dense();
metrics
})
}
}
#[derive(MetricGroup)]
#[metric(new(thread_pool: Arc<ThreadPoolMetrics>))]
#[metric(new())]
pub struct ProxyMetrics {
#[metric(flatten)]
pub db_connections: CounterPairVec<NumDbConnectionsGauge>,
@@ -127,6 +128,9 @@ pub struct ProxyMetrics {
/// Number of TLS handshake failures
pub tls_handshake_failures: Counter,
/// Number of SHA 256 rounds executed.
pub sha_rounds: Counter,
/// HLL approximate cardinality of endpoints that are connecting
pub connecting_endpoints: HyperLogLogVec<StaticLabelSet<Protocol>, 32>,
@@ -144,8 +148,25 @@ pub struct ProxyMetrics {
pub connect_compute_lock: ApiLockMetrics,
#[metric(namespace = "scram_pool")]
#[metric(init = thread_pool)]
pub scram_pool: Arc<ThreadPoolMetrics>,
pub scram_pool: OnceLockWrapper<Arc<ThreadPoolMetrics>>,
}
/// A Wrapper over [`OnceLock`] to implement [`MetricGroup`].
pub struct OnceLockWrapper<T>(pub OnceLock<T>);
impl<T> Default for OnceLockWrapper<T> {
fn default() -> Self {
Self(OnceLock::new())
}
}
impl<Enc: Encoding, T: MetricGroup<Enc>> MetricGroup<Enc> for OnceLockWrapper<T> {
fn collect_group_into(&self, enc: &mut Enc) -> Result<(), Enc::Err> {
if let Some(inner) = self.0.get() {
inner.collect_group_into(enc)?;
}
Ok(())
}
}
#[derive(MetricGroup)]
@@ -215,13 +236,6 @@ pub enum Bool {
False,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "outcome")]
pub enum CacheOutcome {
Hit,
Miss,
}
#[derive(LabelGroup)]
#[label(set = ConsoleRequestSet)]
pub struct ConsoleRequest<'a> {
@@ -553,14 +567,6 @@ impl From<bool> for Bool {
}
}
#[derive(LabelGroup)]
#[label(set = InvalidEndpointsSet)]
pub struct InvalidEndpointsGroup {
pub protocol: Protocol,
pub rejected: Bool,
pub outcome: ConnectOutcome,
}
#[derive(LabelGroup)]
#[label(set = RetriesMetricSet)]
pub struct RetriesMetricGroup {
@@ -660,3 +666,100 @@ pub struct ThreadPoolMetrics {
#[metric(init = CounterVec::with_label_set(ThreadPoolWorkers(workers)))]
pub worker_task_skips_total: CounterVec<ThreadPoolWorkers>,
}
#[derive(MetricGroup, Default)]
pub struct ServiceMetrics {
pub info: InfoMetric<ServiceInfo>,
}
#[derive(Default)]
pub struct ServiceInfo {
pub state: ServiceState,
}
impl ServiceInfo {
pub const fn running() -> Self {
ServiceInfo {
state: ServiceState::Running,
}
}
pub const fn terminating() -> Self {
ServiceInfo {
state: ServiceState::Terminating,
}
}
}
impl LabelGroup for ServiceInfo {
fn visit_values(&self, v: &mut impl LabelGroupVisitor) {
const STATE: &LabelName = LabelName::from_str("state");
v.write_value(STATE, &self.state);
}
}
#[derive(FixedCardinalityLabel, Clone, Copy, Debug, Default)]
#[label(singleton = "state")]
pub enum ServiceState {
#[default]
Init,
Running,
Terminating,
}
#[derive(MetricGroup)]
#[metric(new())]
pub struct CacheMetrics {
/// The capacity of the cache
pub capacity: GaugeVec<StaticLabelSet<CacheKind>>,
/// The total number of entries inserted into the cache
pub inserted_total: CounterVec<StaticLabelSet<CacheKind>>,
/// The total number of entries removed from the cache
pub evicted_total: CounterVec<CacheEvictionSet>,
/// The total number of cache requests
pub request_total: CounterVec<CacheOutcomeSet>,
}
impl Default for CacheMetrics {
fn default() -> Self {
Self::new()
}
}
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
#[label(singleton = "cache")]
pub enum CacheKind {
NodeInfo,
ProjectInfoEndpoints,
ProjectInfoRoles,
Schema,
Pbkdf2,
}
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
pub enum CacheRemovalCause {
Expired,
Explicit,
Replaced,
Size,
}
#[derive(LabelGroup)]
#[label(set = CacheEvictionSet)]
pub struct CacheEviction {
pub cache: CacheKind,
pub cause: CacheRemovalCause,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
pub enum CacheOutcome {
Hit,
Miss,
}
#[derive(LabelGroup)]
#[label(set = CacheOutcomeSet)]
pub struct CacheOutcomeGroup {
pub cache: CacheKind,
pub outcome: CacheOutcome,
}

View File

@@ -2,7 +2,7 @@ use thiserror::Error;
use crate::auth::Backend;
use crate::auth::backend::ComputeUserInfo;
use crate::cache::Cache;
use crate::cache::common::Cache;
use crate::compute::{AuthInfo, ComputeConnection, ConnectionError, PostgresError};
use crate::config::ProxyConfig;
use crate::context::RequestContext;

View File

@@ -1,11 +1,12 @@
use tokio::time;
use tracing::{debug, info, warn};
use crate::cache::node_info::CachedNodeInfo;
use crate::compute::{self, COULD_NOT_CONNECT, ComputeConnection};
use crate::config::{ComputeConfig, ProxyConfig, RetryConfig};
use crate::context::RequestContext;
use crate::control_plane::NodeInfo;
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::{self, NodeInfo};
use crate::metrics::{
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
};
@@ -17,7 +18,7 @@ use crate::types::Host;
/// (e.g. the compute node's address might've changed at the wrong time).
/// Invalidate the cache entry (if any) to prevent subsequent errors.
#[tracing::instrument(skip_all)]
pub(crate) fn invalidate_cache(node_info: control_plane::CachedNodeInfo) -> NodeInfo {
pub(crate) fn invalidate_cache(node_info: CachedNodeInfo) -> NodeInfo {
let is_cached = node_info.cached();
if is_cached {
warn!("invalidating stalled compute node info cache entry");
@@ -37,7 +38,7 @@ pub(crate) trait ConnectMechanism {
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
node_info: &CachedNodeInfo,
config: &ComputeConfig,
) -> Result<Self::Connection, compute::ConnectionError>;
}
@@ -66,7 +67,7 @@ impl ConnectMechanism for TcpMechanism<'_> {
async fn connect_once(
&self,
ctx: &RequestContext,
node_info: &control_plane::CachedNodeInfo,
node_info: &CachedNodeInfo,
config: &ComputeConfig,
) -> Result<ComputeConnection, compute::ConnectionError> {
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;

View File

@@ -15,15 +15,17 @@ use rstest::rstest;
use rustls::crypto::ring;
use rustls::pki_types;
use tokio::io::{AsyncRead, AsyncWrite, DuplexStream};
use tokio::time::Instant;
use tracing_test::traced_test;
use super::retry::CouldRetry;
use crate::auth::backend::{ComputeUserInfo, MaybeOwned};
use crate::config::{ComputeConfig, RetryConfig, TlsConfig};
use crate::cache::node_info::{CachedNodeInfo, NodeInfoCache};
use crate::config::{CacheOptions, ComputeConfig, RetryConfig, TlsConfig};
use crate::context::RequestContext;
use crate::control_plane::client::{ControlPlaneClient, TestControlPlaneClient};
use crate::control_plane::messages::{ControlPlaneErrorMessage, Details, MetricsAuxInfo, Status};
use crate::control_plane::{self, CachedNodeInfo, NodeInfo, NodeInfoCache};
use crate::control_plane::{self, NodeInfo};
use crate::error::ErrorKind;
use crate::pglb::ERR_INSECURE_CONNECTION;
use crate::pglb::handshake::{HandshakeData, handshake};
@@ -417,12 +419,11 @@ impl TestConnectMechanism {
Self {
counter: Arc::new(std::sync::Mutex::new(0)),
sequence,
cache: Box::leak(Box::new(NodeInfoCache::new(
"test",
1,
Duration::from_secs(100),
false,
))),
cache: Box::leak(Box::new(NodeInfoCache::new(CacheOptions {
size: Some(1),
absolute_ttl: Some(Duration::from_secs(100)),
idle_ttl: None,
}))),
}
}
}
@@ -436,7 +437,7 @@ impl ConnectMechanism for TestConnectMechanism {
async fn connect_once(
&self,
_ctx: &RequestContext,
_node_info: &control_plane::CachedNodeInfo,
_node_info: &CachedNodeInfo,
_config: &ComputeConfig,
) -> Result<Self::Connection, compute::ConnectionError> {
let mut counter = self.counter.lock().unwrap();
@@ -501,7 +502,7 @@ impl TestControlPlaneClient for TestConnectMechanism {
details: Details {
error_info: None,
retry_info: Some(control_plane::messages::RetryInfo {
retry_delay_ms: 1,
retry_at: Instant::now() + Duration::from_millis(1),
}),
user_facing_message: None,
},
@@ -546,8 +547,11 @@ fn helper_create_uncached_node_info() -> NodeInfo {
fn helper_create_cached_node_info(cache: &'static NodeInfoCache) -> CachedNodeInfo {
let node = helper_create_uncached_node_info();
let (_, node2) = cache.insert_unit("key".into(), Ok(node.clone()));
node2.map(|()| node)
cache.insert("key".into(), Ok(node.clone()));
CachedNodeInfo {
token: Some((cache, "key".into())),
value: node,
}
}
fn helper_create_connect_info(

View File

@@ -1,9 +1,9 @@
use async_trait::async_trait;
use tracing::{error, info};
use crate::cache::node_info::CachedNodeInfo;
use crate::config::RetryConfig;
use crate::context::RequestContext;
use crate::control_plane::CachedNodeInfo;
use crate::control_plane::errors::{ControlPlaneError, WakeComputeError};
use crate::error::ReportableError;
use crate::metrics::{

View File

@@ -131,11 +131,11 @@ where
Ok(())
}
struct MessageHandler<C: ProjectInfoCache + Send + Sync + 'static> {
struct MessageHandler<C: Send + Sync + 'static> {
cache: Arc<C>,
}
impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
impl<C: Send + Sync + 'static> Clone for MessageHandler<C> {
fn clone(&self) -> Self {
Self {
cache: self.cache.clone(),
@@ -143,8 +143,8 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> Clone for MessageHandler<C> {
}
}
impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
pub(crate) fn new(cache: Arc<C>) -> Self {
impl MessageHandler<ProjectInfoCache> {
pub(crate) fn new(cache: Arc<ProjectInfoCache>) -> Self {
Self { cache }
}
@@ -224,7 +224,7 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
}
}
fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
fn invalidate_cache(cache: Arc<ProjectInfoCache>, msg: Notification) {
match msg {
Notification::EndpointSettingsUpdate(ids) => ids
.iter()
@@ -247,8 +247,8 @@ fn invalidate_cache<C: ProjectInfoCache>(cache: Arc<C>, msg: Notification) {
}
}
async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
handler: MessageHandler<C>,
async fn handle_messages(
handler: MessageHandler<ProjectInfoCache>,
redis: ConnectionWithCredentialsProvider,
cancellation_token: CancellationToken,
) -> anyhow::Result<()> {
@@ -284,13 +284,10 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
/// Handle console's invalidation messages.
#[tracing::instrument(name = "redis_notifications", skip_all)]
pub async fn task_main<C>(
pub async fn task_main(
redis: ConnectionWithCredentialsProvider,
cache: Arc<C>,
) -> anyhow::Result<Infallible>
where
C: ProjectInfoCache + Send + Sync + 'static,
{
cache: Arc<ProjectInfoCache>,
) -> anyhow::Result<Infallible> {
let handler = MessageHandler::new(cache);
// 6h - 1m.
// There will be 1 minute overlap between two tasks. But at least we can be sure that no message is lost.

84
proxy/src/scram/cache.rs Normal file
View File

@@ -0,0 +1,84 @@
use tokio::time::Instant;
use zeroize::Zeroize as _;
use super::pbkdf2;
use crate::cache::Cached;
use crate::cache::common::{Cache, count_cache_insert, count_cache_outcome, eviction_listener};
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::metrics::{CacheKind, Metrics};
pub(crate) struct Pbkdf2Cache(moka::sync::Cache<(EndpointIdInt, RoleNameInt), Pbkdf2CacheEntry>);
pub(crate) type CachedPbkdf2<'a> = Cached<&'a Pbkdf2Cache>;
impl Cache for Pbkdf2Cache {
type Key = (EndpointIdInt, RoleNameInt);
type Value = Pbkdf2CacheEntry;
fn invalidate(&self, info: &(EndpointIdInt, RoleNameInt)) {
self.0.invalidate(info);
}
}
/// To speed up password hashing for more active customers, we store the tail results of the
/// PBKDF2 algorithm. If the output of PBKDF2 is U1 ^ U2 ^ ⋯ ^ Uc, then we store
/// suffix = U17 ^ U18 ^ ⋯ ^ Uc. We only need to calculate U1 ^ U2 ^ ⋯ ^ U15 ^ U16
/// to determine the final result.
///
/// The suffix alone isn't enough to crack the password. The stored_key is still required.
/// While both are cached in memory, given they're in different locations is makes it much
/// harder to exploit, even if any such memory exploit exists in proxy.
#[derive(Clone)]
pub struct Pbkdf2CacheEntry {
/// corresponds to [`super::ServerSecret::cached_at`]
pub(super) cached_from: Instant,
pub(super) suffix: pbkdf2::Block,
}
impl Drop for Pbkdf2CacheEntry {
fn drop(&mut self) {
self.suffix.zeroize();
}
}
impl Pbkdf2Cache {
pub fn new() -> Self {
const SIZE: u64 = 100;
const TTL: std::time::Duration = std::time::Duration::from_secs(60);
let builder = moka::sync::Cache::builder()
.name("pbkdf2")
.max_capacity(SIZE)
// We use time_to_live so we don't refresh the lifetime for an invalid password attempt.
.time_to_live(TTL);
Metrics::get()
.cache
.capacity
.set(CacheKind::Pbkdf2, SIZE as i64);
let builder =
builder.eviction_listener(|_k, _v, cause| eviction_listener(CacheKind::Pbkdf2, cause));
Self(builder.build())
}
pub fn insert(&self, endpoint: EndpointIdInt, role: RoleNameInt, value: Pbkdf2CacheEntry) {
count_cache_insert(CacheKind::Pbkdf2);
self.0.insert((endpoint, role), value);
}
fn get(&self, endpoint: EndpointIdInt, role: RoleNameInt) -> Option<Pbkdf2CacheEntry> {
count_cache_outcome(CacheKind::Pbkdf2, self.0.get(&(endpoint, role)))
}
pub fn get_entry(
&self,
endpoint: EndpointIdInt,
role: RoleNameInt,
) -> Option<CachedPbkdf2<'_>> {
self.get(endpoint, role).map(|value| Cached {
token: Some((self, (endpoint, role))),
value,
})
}
}

View File

@@ -4,10 +4,8 @@ use std::convert::Infallible;
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use tracing::{debug, trace};
use super::ScramKey;
use super::messages::{
ClientFinalMessage, ClientFirstMessage, OwnedServerFirstMessage, SCRAM_RAW_NONCE_LEN,
};
@@ -15,8 +13,10 @@ use super::pbkdf2::Pbkdf2;
use super::secret::ServerSecret;
use super::signature::SignatureBuilder;
use super::threadpool::ThreadPool;
use crate::intern::EndpointIdInt;
use super::{ScramKey, pbkdf2};
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::sasl::{self, ChannelBinding, Error as SaslError};
use crate::scram::cache::Pbkdf2CacheEntry;
/// The only channel binding mode we currently support.
#[derive(Debug)]
@@ -77,46 +77,113 @@ impl<'a> Exchange<'a> {
}
}
// copied from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L236-L248>
async fn derive_client_key(
pool: &ThreadPool,
endpoint: EndpointIdInt,
password: &[u8],
salt: &[u8],
iterations: u32,
) -> ScramKey {
let salted_password = pool
.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
.await;
let make_key = |name| {
let key = Hmac::<Sha256>::new_from_slice(&salted_password)
.expect("HMAC is able to accept all key sizes")
.chain_update(name)
.finalize();
<[u8; 32]>::from(key.into_bytes())
};
make_key(b"Client Key").into()
) -> pbkdf2::Block {
pool.spawn_job(endpoint, Pbkdf2::start(password, salt, iterations))
.await
}
/// For cleartext flow, we need to derive the client key to
/// 1. authenticate the client.
/// 2. authenticate with compute.
pub(crate) async fn exchange(
pool: &ThreadPool,
endpoint: EndpointIdInt,
role: RoleNameInt,
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
if secret.iterations > CACHED_ROUNDS {
exchange_with_cache(pool, endpoint, role, secret, password).await
} else {
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
Ok(validate_pbkdf2(secret, &hash))
}
}
/// Compute the client key using a cache. We cache the suffix of the pbkdf2 result only,
/// which is not enough by itself to perform an offline brute force.
async fn exchange_with_cache(
pool: &ThreadPool,
endpoint: EndpointIdInt,
role: RoleNameInt,
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
debug_assert!(
secret.iterations > CACHED_ROUNDS,
"we should not cache password data if there isn't enough rounds needed"
);
// compute the prefix of the pbkdf2 output.
let prefix = derive_client_key(pool, endpoint, password, &salt, CACHED_ROUNDS).await;
if let Some(entry) = pool.cache.get_entry(endpoint, role) {
// hot path: let's check the threadpool cache
if secret.cached_at == entry.cached_from {
// cache is valid. compute the full hash by adding the prefix to the suffix.
let mut hash = prefix;
pbkdf2::xor_assign(&mut hash, &entry.suffix);
let outcome = validate_pbkdf2(secret, &hash);
if matches!(outcome, sasl::Outcome::Success(_)) {
trace!("password validated from cache");
}
return Ok(outcome);
}
// cached key is no longer valid.
debug!("invalidating cached password");
entry.invalidate();
}
// slow path: full password hash.
let hash = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
let outcome = validate_pbkdf2(secret, &hash);
let client_key = match outcome {
sasl::Outcome::Success(client_key) => client_key,
sasl::Outcome::Failure(_) => return Ok(outcome),
};
trace!("storing cached password");
// time to cache, compute the suffix by subtracting the prefix from the hash.
let mut suffix = hash;
pbkdf2::xor_assign(&mut suffix, &prefix);
pool.cache.insert(
endpoint,
role,
Pbkdf2CacheEntry {
cached_from: secret.cached_at,
suffix,
},
);
Ok(sasl::Outcome::Success(client_key))
}
fn validate_pbkdf2(secret: &ServerSecret, hash: &pbkdf2::Block) -> sasl::Outcome<ScramKey> {
let client_key = super::ScramKey::client_key(&(*hash).into());
if secret.is_password_invalid(&client_key).into() {
Ok(sasl::Outcome::Failure("password doesn't match"))
sasl::Outcome::Failure("password doesn't match")
} else {
Ok(sasl::Outcome::Success(client_key))
sasl::Outcome::Success(client_key)
}
}
const CACHED_ROUNDS: u32 = 16;
impl SaslInitial {
fn transition(
&self,

View File

@@ -1,6 +1,12 @@
//! Tools for client/server/stored key management.
use hmac::Mac as _;
use sha2::Digest as _;
use subtle::ConstantTimeEq;
use zeroize::Zeroize as _;
use crate::metrics::Metrics;
use crate::scram::pbkdf2::Prf;
/// Faithfully taken from PostgreSQL.
pub(crate) const SCRAM_KEY_LEN: usize = 32;
@@ -14,6 +20,12 @@ pub(crate) struct ScramKey {
bytes: [u8; SCRAM_KEY_LEN],
}
impl Drop for ScramKey {
fn drop(&mut self) {
self.bytes.zeroize();
}
}
impl PartialEq for ScramKey {
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).into()
@@ -28,12 +40,26 @@ impl ConstantTimeEq for ScramKey {
impl ScramKey {
pub(crate) fn sha256(&self) -> Self {
super::sha256([self.as_ref()]).into()
Metrics::get().proxy.sha_rounds.inc_by(1);
Self {
bytes: sha2::Sha256::digest(self.as_bytes()).into(),
}
}
pub(crate) fn as_bytes(&self) -> [u8; SCRAM_KEY_LEN] {
self.bytes
}
pub(crate) fn client_key(b: &[u8; 32]) -> Self {
// Prf::new_from_slice will run 2 sha256 rounds.
// Update + Finalize run 2 sha256 rounds.
Metrics::get().proxy.sha_rounds.inc_by(4);
let mut prf = Prf::new_from_slice(b).expect("HMAC is able to accept all key sizes");
prf.update(b"Client Key");
let client_key: [u8; 32] = prf.finalize().into_bytes().into();
client_key.into()
}
}
impl From<[u8; SCRAM_KEY_LEN]> for ScramKey {

View File

@@ -6,6 +6,7 @@
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/backend/libpq/auth-scram.c>
//! * <https://github.com/postgres/postgres/blob/94226d4506e66d6e7cbf4b391f1e7393c1962841/src/interfaces/libpq/fe-auth-scram.c>
mod cache;
mod countmin;
mod exchange;
mod key;
@@ -18,10 +19,8 @@ pub mod threadpool;
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
pub(crate) use exchange::{Exchange, exchange};
use hmac::{Hmac, Mac};
pub(crate) use key::ScramKey;
pub(crate) use secret::ServerSecret;
use sha2::{Digest, Sha256};
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
@@ -42,29 +41,13 @@ fn base64_decode_array<const N: usize>(input: impl AsRef<[u8]>) -> Option<[u8; N
Some(bytes)
}
/// This function essentially is `Hmac(sha256, key, input)`.
/// Further reading: <https://datatracker.ietf.org/doc/html/rfc2104>.
fn hmac_sha256<'a>(key: &[u8], parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("bad key size");
parts.into_iter().for_each(|s| mac.update(s));
mac.finalize().into_bytes().into()
}
fn sha256<'a>(parts: impl IntoIterator<Item = &'a [u8]>) -> [u8; 32] {
let mut hasher = Sha256::new();
parts.into_iter().for_each(|s| hasher.update(s));
hasher.finalize().into()
}
#[cfg(test)]
mod tests {
use super::threadpool::ThreadPool;
use super::{Exchange, ServerSecret};
use crate::intern::EndpointIdInt;
use crate::intern::{EndpointIdInt, RoleNameInt};
use crate::sasl::{Mechanism, Step};
use crate::types::EndpointId;
use crate::types::{EndpointId, RoleName};
#[test]
fn snapshot() {
@@ -114,23 +97,34 @@ mod tests {
);
}
async fn run_round_trip_test(server_password: &str, client_password: &str) {
let pool = ThreadPool::new(1);
async fn check(
pool: &ThreadPool,
scram_secret: &ServerSecret,
password: &[u8],
) -> Result<(), &'static str> {
let ep = EndpointId::from("foo");
let ep = EndpointIdInt::from(ep);
let role = RoleName::from("user");
let role = RoleNameInt::from(&role);
let scram_secret = ServerSecret::build(server_password).await.unwrap();
let outcome = super::exchange(&pool, ep, &scram_secret, client_password.as_bytes())
let outcome = super::exchange(pool, ep, role, scram_secret, password)
.await
.unwrap();
match outcome {
crate::sasl::Outcome::Success(_) => {}
crate::sasl::Outcome::Failure(r) => panic!("{r}"),
crate::sasl::Outcome::Success(_) => Ok(()),
crate::sasl::Outcome::Failure(r) => Err(r),
}
}
async fn run_round_trip_test(server_password: &str, client_password: &str) {
let pool = ThreadPool::new(1);
let scram_secret = ServerSecret::build(server_password).await.unwrap();
check(&pool, &scram_secret, client_password.as_bytes())
.await
.unwrap();
}
#[tokio::test]
async fn round_trip() {
run_round_trip_test("pencil", "pencil").await;
@@ -141,4 +135,27 @@ mod tests {
async fn failure() {
run_round_trip_test("pencil", "eraser").await;
}
#[tokio::test]
#[tracing_test::traced_test]
async fn password_cache() {
let pool = ThreadPool::new(1);
let scram_secret = ServerSecret::build("password").await.unwrap();
// wrong passwords are not added to cache
check(&pool, &scram_secret, b"wrong").await.unwrap_err();
assert!(!logs_contain("storing cached password"));
// correct passwords get cached
check(&pool, &scram_secret, b"password").await.unwrap();
assert!(logs_contain("storing cached password"));
// wrong passwords do not match the cache
check(&pool, &scram_secret, b"wrong").await.unwrap_err();
assert!(!logs_contain("password validated from cache"));
// correct passwords match the cache
check(&pool, &scram_secret, b"password").await.unwrap();
assert!(logs_contain("password validated from cache"));
}
}

View File

@@ -1,25 +1,50 @@
//! For postgres password authentication, we need to perform a PBKDF2 using
//! PRF=HMAC-SHA2-256, producing only 1 block (32 bytes) of output key.
use hmac::Mac as _;
use hmac::digest::consts::U32;
use hmac::digest::generic_array::GenericArray;
use hmac::{Hmac, Mac};
use sha2::Sha256;
use zeroize::Zeroize as _;
use crate::metrics::Metrics;
/// The Psuedo-random function used during PBKDF2 and the SCRAM-SHA-256 handshake.
pub type Prf = hmac::Hmac<sha2::Sha256>;
pub(crate) type Block = GenericArray<u8, U32>;
pub(crate) struct Pbkdf2 {
hmac: Hmac<Sha256>,
prev: GenericArray<u8, U32>,
hi: GenericArray<u8, U32>,
hmac: Prf,
/// U{r-1} for whatever iteration r we are currently on.
prev: Block,
/// the output of `fold(xor, U{1}..U{r})` for whatever iteration r we are currently on.
hi: Block,
/// number of iterations left
iterations: u32,
}
impl Drop for Pbkdf2 {
fn drop(&mut self) {
self.prev.zeroize();
self.hi.zeroize();
}
}
// inspired from <https://github.com/neondatabase/rust-postgres/blob/20031d7a9ee1addeae6e0968e3899ae6bf01cee2/postgres-protocol/src/authentication/sasl.rs#L36-L61>
impl Pbkdf2 {
pub(crate) fn start(str: &[u8], salt: &[u8], iterations: u32) -> Self {
pub(crate) fn start(pw: &[u8], salt: &[u8], iterations: u32) -> Self {
// key the HMAC and derive the first block in-place
let mut hmac =
Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
let mut hmac = Prf::new_from_slice(pw).expect("HMAC is able to accept all key sizes");
// U1 = PRF(Password, Salt + INT_32_BE(i))
// i = 1 since we only need 1 block of output.
hmac.update(salt);
hmac.update(&1u32.to_be_bytes());
let init_block = hmac.finalize_reset().into_bytes();
// Prf::new_from_slice will run 2 sha256 rounds.
// Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
Metrics::get().proxy.sha_rounds.inc_by(4);
Self {
hmac,
// one iteration spent above
@@ -33,7 +58,11 @@ impl Pbkdf2 {
(self.iterations).clamp(0, 4096)
}
pub(crate) fn turn(&mut self) -> std::task::Poll<[u8; 32]> {
/// For "fairness", we implement PBKDF2 with cooperative yielding, which is why we use this `turn`
/// function that only executes a fixed number of iterations before continuing.
///
/// Task must be rescheuled if this returns [`std::task::Poll::Pending`].
pub(crate) fn turn(&mut self) -> std::task::Poll<Block> {
let Self {
hmac,
prev,
@@ -44,25 +73,37 @@ impl Pbkdf2 {
// only do up to 4096 iterations per turn for fairness
let n = (*iterations).clamp(0, 4096);
for _ in 0..n {
hmac.update(prev);
let block = hmac.finalize_reset().into_bytes();
for (hi_byte, &b) in hi.iter_mut().zip(block.iter()) {
*hi_byte ^= b;
}
*prev = block;
let next = single_round(hmac, prev);
xor_assign(hi, &next);
*prev = next;
}
// Our update + finalize run 2 sha256 rounds for each pbkdf2 round.
Metrics::get().proxy.sha_rounds.inc_by(2 * n as u64);
*iterations -= n;
if *iterations == 0 {
std::task::Poll::Ready((*hi).into())
std::task::Poll::Ready(*hi)
} else {
std::task::Poll::Pending
}
}
}
#[inline(always)]
pub fn xor_assign(x: &mut Block, y: &Block) {
for (x, &y) in std::iter::zip(x, y) {
*x ^= y;
}
}
#[inline(always)]
fn single_round(prf: &mut Prf, ui: &Block) -> Block {
// Ui = PRF(Password, Ui-1)
prf.update(ui);
prf.finalize_reset().into_bytes()
}
#[cfg(test)]
mod tests {
use pbkdf2::pbkdf2_hmac_array;
@@ -76,11 +117,11 @@ mod tests {
let pass = b"Ne0n_!5_50_C007";
let mut job = Pbkdf2::start(pass, salt, 60000);
let hash = loop {
let hash: [u8; 32] = loop {
let std::task::Poll::Ready(hash) = job.turn() else {
continue;
};
break hash;
break hash.into();
};
let expected = pbkdf2_hmac_array::<Sha256, 32>(pass, salt, 60000);

View File

@@ -3,6 +3,7 @@
use base64::Engine as _;
use base64::prelude::BASE64_STANDARD;
use subtle::{Choice, ConstantTimeEq};
use tokio::time::Instant;
use super::base64_decode_array;
use super::key::ScramKey;
@@ -11,6 +12,9 @@ use super::key::ScramKey;
/// and is used throughout the authentication process.
#[derive(Clone, Eq, PartialEq, Debug)]
pub(crate) struct ServerSecret {
/// When this secret was cached.
pub(crate) cached_at: Instant,
/// Number of iterations for `PBKDF2` function.
pub(crate) iterations: u32,
/// Salt used to hash user's password.
@@ -34,6 +38,7 @@ impl ServerSecret {
params.split_once(':').zip(keys.split_once(':'))?;
let secret = ServerSecret {
cached_at: Instant::now(),
iterations: iterations.parse().ok()?,
salt_base64: salt.into(),
stored_key: base64_decode_array(stored_key)?.into(),
@@ -54,6 +59,7 @@ impl ServerSecret {
/// See `auth-scram.c : mock_scram_secret` for details.
pub(crate) fn mock(nonce: [u8; 32]) -> Self {
Self {
cached_at: Instant::now(),
// this doesn't reveal much information as we're going to use
// iteration count 1 for our generated passwords going forward.
// PG16 users can set iteration count=1 already today.

Some files were not shown because too many files have changed in this diff Show More