Compare commits

..

1 Commits

Author SHA1 Message Date
Conrad Ludgate
13d41b51a2 simplify error handling for query encoding 2025-05-21 20:07:43 +01:00
110 changed files with 1624 additions and 5386 deletions

View File

@@ -314,8 +314,7 @@ jobs:
test_selection: performance
run_in_parallel: false
save_perf_report: ${{ github.ref_name == 'main' }}
# test_pageserver_max_throughput_getpage_at_latest_lsn is run in separate workflow periodic_pagebench.yml because it needs snapshots
extra_params: --splits 5 --group ${{ matrix.pytest_split_group }} --ignore=test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py
extra_params: --splits 5 --group ${{ matrix.pytest_split_group }}
benchmark_durations: ${{ needs.get-benchmarks-durations.outputs.json }}
pg_version: v16
aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}

View File

@@ -1,4 +1,4 @@
name: Periodic pagebench performance test on unit-perf hetzner runner
name: Periodic pagebench performance test on dedicated EC2 machine in eu-central-1 region
on:
schedule:
@@ -8,7 +8,7 @@ on:
# │ │ ┌───────────── day of the month (1 - 31)
# │ │ │ ┌───────────── month (1 - 12 or JAN-DEC)
# │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT)
- cron: '0 */4 * * *' # Runs every 4 hours
- cron: '0 */3 * * *' # Runs every 3 hours
workflow_dispatch: # Allows manual triggering of the workflow
inputs:
commit_hash:
@@ -16,11 +16,6 @@ on:
description: 'The long neon repo commit hash for the system under test (pageserver) to be tested.'
required: false
default: ''
recreate_snapshots:
type: boolean
description: 'Recreate snapshots - !!!WARNING!!! We should only recreate snapshots if the previous ones are no longer compatible. Otherwise benchmarking results are not comparable across runs.'
required: false
default: false
defaults:
run:
@@ -34,13 +29,13 @@ permissions:
contents: read
jobs:
run_periodic_pagebench_test:
trigger_bench_on_ec2_machine_in_eu_central_1:
permissions:
id-token: write # aws-actions/configure-aws-credentials
statuses: write
contents: write
pull-requests: write
runs-on: [ self-hosted, unit-perf ]
runs-on: [ self-hosted, small ]
container:
image: ghcr.io/neondatabase/build-tools:pinned-bookworm
credentials:
@@ -49,13 +44,10 @@ jobs:
options: --init
timeout-minutes: 360 # Set the timeout to 6 hours
env:
API_KEY: ${{ secrets.PERIODIC_PAGEBENCH_EC2_RUNNER_API_KEY }}
RUN_ID: ${{ github.run_id }}
DEFAULT_PG_VERSION: 16
BUILD_TYPE: release
RUST_BACKTRACE: 1
# NEON_ENV_BUILDER_USE_OVERLAYFS_FOR_SNAPSHOTS: 1 - doesn't work without root in container
S3_BUCKET: neon-github-public-dev
PERF_TEST_RESULT_CONNSTR: "${{ secrets.PERF_TEST_RESULT_CONNSTR }}"
AWS_DEFAULT_REGION : "eu-central-1"
AWS_INSTANCE_ID : "i-02a59a3bf86bc7e74"
steps:
# we don't need the neon source code because we run everything remotely
# however we still need the local github actions to run the allure step below
@@ -64,194 +56,99 @@ jobs:
with:
egress-policy: audit
- name: Set up the environment which depends on $RUNNER_TEMP on nvme drive
id: set-env
shell: bash -euxo pipefail {0}
run: |
{
echo "NEON_DIR=${RUNNER_TEMP}/neon"
echo "NEON_BIN=${RUNNER_TEMP}/neon/bin"
echo "POSTGRES_DISTRIB_DIR=${RUNNER_TEMP}/neon/pg_install"
echo "LD_LIBRARY_PATH=${RUNNER_TEMP}/neon/pg_install/v${DEFAULT_PG_VERSION}/lib"
echo "BACKUP_DIR=${RUNNER_TEMP}/instance_store/saved_snapshots"
echo "TEST_OUTPUT=${RUNNER_TEMP}/neon/test_output"
echo "PERF_REPORT_DIR=${RUNNER_TEMP}/neon/test_output/perf-report-local"
echo "ALLURE_DIR=${RUNNER_TEMP}/neon/test_output/allure-results"
echo "ALLURE_RESULTS_DIR=${RUNNER_TEMP}/neon/test_output/allure-results/results"
} >> "$GITHUB_ENV"
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
echo "allure_results_dir=${RUNNER_TEMP}/neon/test_output/allure-results/results" >> "$GITHUB_OUTPUT"
- name: Show my own (github runner) external IP address - usefull for IP allowlisting
run: curl https://ifconfig.me
- uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2
- name: Assume AWS OIDC role that allows to manage (start/stop/describe... EC machine)
uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 18000 # max 5 hours (needed in case commit hash is still being built)
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_MANAGE_BENCHMARK_EC2_VMS_ARN }}
role-duration-seconds: 3600
- name: Start EC2 instance and wait for the instance to boot up
run: |
aws ec2 start-instances --instance-ids $AWS_INSTANCE_ID
aws ec2 wait instance-running --instance-ids $AWS_INSTANCE_ID
sleep 60 # sleep some time to allow cloudinit and our API server to start up
- name: Determine public IP of the EC2 instance and set env variable EC2_MACHINE_URL_US
run: |
public_ip=$(aws ec2 describe-instances --instance-ids $AWS_INSTANCE_ID --query 'Reservations[*].Instances[*].PublicIpAddress' --output text)
echo "Public IP of the EC2 instance: $public_ip"
echo "EC2_MACHINE_URL_US=https://${public_ip}:8443" >> $GITHUB_ENV
- name: Determine commit hash
id: commit_hash
shell: bash -euxo pipefail {0}
env:
INPUT_COMMIT_HASH: ${{ github.event.inputs.commit_hash }}
run: |
if [[ -z "${INPUT_COMMIT_HASH}" ]]; then
COMMIT_HASH=$(curl -s https://api.github.com/repos/neondatabase/neon/commits/main | jq -r '.sha')
echo "COMMIT_HASH=$COMMIT_HASH" >> $GITHUB_ENV
echo "commit_hash=$COMMIT_HASH" >> "$GITHUB_OUTPUT"
if [ -z "$INPUT_COMMIT_HASH" ]; then
echo "COMMIT_HASH=$(curl -s https://api.github.com/repos/neondatabase/neon/commits/main | jq -r '.sha')" >> $GITHUB_ENV
echo "COMMIT_HASH_TYPE=latest" >> $GITHUB_ENV
else
COMMIT_HASH="${INPUT_COMMIT_HASH}"
echo "COMMIT_HASH=$COMMIT_HASH" >> $GITHUB_ENV
echo "commit_hash=$COMMIT_HASH" >> "$GITHUB_OUTPUT"
echo "COMMIT_HASH=$INPUT_COMMIT_HASH" >> $GITHUB_ENV
echo "COMMIT_HASH_TYPE=manual" >> $GITHUB_ENV
fi
- name: Checkout the neon repository at given commit hash
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
ref: ${{ steps.commit_hash.outputs.commit_hash }}
# does not reuse ./.github/actions/download because we need to download the artifact for the given commit hash
# example artifact
# s3://neon-github-public-dev/artifacts/48b870bc078bd2c450eb7b468e743b9c118549bf/15036827400/1/neon-Linux-X64-release-artifact.tar.zst /instance_store/artifacts/neon-Linux-release-artifact.tar.zst
- name: Determine artifact S3_KEY for given commit hash and download and extract artifact
id: artifact_prefix
shell: bash -euxo pipefail {0}
env:
ARCHIVE: ${{ runner.temp }}/downloads/neon-${{ runner.os }}-${{ runner.arch }}-release-artifact.tar.zst
COMMIT_HASH: ${{ env.COMMIT_HASH }}
COMMIT_HASH_TYPE: ${{ env.COMMIT_HASH_TYPE }}
- name: Start Bench with run_id
run: |
attempt=0
max_attempts=24 # 5 minutes * 24 = 2 hours
curl -k -X 'POST' \
"${EC2_MACHINE_URL_US}/start_test/${GITHUB_RUN_ID}" \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-H "Authorization: Bearer $API_KEY" \
-d "{\"neonRepoCommitHash\": \"${COMMIT_HASH}\", \"neonRepoCommitHashType\": \"${COMMIT_HASH_TYPE}\"}"
while [[ $attempt -lt $max_attempts ]]; do
# the following command will fail until the artifacts are available ...
S3_KEY=$(aws s3api list-objects-v2 --bucket "$S3_BUCKET" --prefix "artifacts/$COMMIT_HASH/" \
| jq -r '.Contents[]?.Key' \
| grep "neon-${{ runner.os }}-${{ runner.arch }}-release-artifact.tar.zst" \
| sort --version-sort \
| tail -1) || true # ... thus ignore errors from the command
if [[ -n "${S3_KEY}" ]]; then
echo "Artifact found: $S3_KEY"
echo "S3_KEY=$S3_KEY" >> $GITHUB_ENV
- name: Poll Test Status
id: poll_step
run: |
status=""
while [[ "$status" != "failure" && "$status" != "success" ]]; do
response=$(curl -k -X 'GET' \
"${EC2_MACHINE_URL_US}/test_status/${GITHUB_RUN_ID}" \
-H 'accept: application/json' \
-H "Authorization: Bearer $API_KEY")
echo "Response: $response"
set +x
status=$(echo $response | jq -r '.status')
echo "Test status: $status"
if [[ "$status" == "failure" ]]; then
echo "Test failed"
exit 1 # Fail the job step if status is failure
elif [[ "$status" == "success" || "$status" == "null" ]]; then
break
elif [[ "$status" == "too_many_runs" ]]; then
echo "Too many runs already running"
echo "too_many_runs=true" >> "$GITHUB_OUTPUT"
exit 1
fi
# Increment attempt counter and sleep for 5 minutes
attempt=$((attempt + 1))
echo "Attempt $attempt of $max_attempts to find artifacts in S3 bucket s3://$S3_BUCKET/artifacts/$COMMIT_HASH failed. Retrying in 5 minutes..."
sleep 300 # Sleep for 5 minutes
sleep 60 # Poll every 60 seconds
done
if [[ -z "${S3_KEY}" ]]; then
echo "Error: artifact not found in S3 bucket s3://$S3_BUCKET/artifacts/$COMMIT_HASH" after 2 hours
else
mkdir -p $(dirname $ARCHIVE)
time aws s3 cp --only-show-errors s3://$S3_BUCKET/${S3_KEY} ${ARCHIVE}
mkdir -p ${NEON_DIR}
time tar -xf ${ARCHIVE} -C ${NEON_DIR}
rm -f ${ARCHIVE}
fi
- name: Download snapshots from S3
if: ${{ github.event_name != 'workflow_dispatch' || github.event.inputs.recreate_snapshots == 'false' || github.event.inputs.recreate_snapshots == '' }}
id: download_snapshots
shell: bash -euxo pipefail {0}
- name: Retrieve Test Logs
if: always() && steps.poll_step.outputs.too_many_runs != 'true'
run: |
# Download the snapshots from S3
mkdir -p ${TEST_OUTPUT}
mkdir -p $BACKUP_DIR
cd $BACKUP_DIR
mkdir parts
cd parts
PART=$(aws s3api list-objects-v2 --bucket $S3_BUCKET --prefix performance/pagebench/ \
| jq -r '.Contents[]?.Key' \
| grep -E 'shared-snapshots-[0-9]{4}-[0-9]{2}-[0-9]{2}' \
| sort \
| tail -1)
echo "Latest PART: $PART"
if [[ -z "$PART" ]]; then
echo "ERROR: No matching S3 key found" >&2
exit 1
fi
S3_KEY=$(dirname $PART)
time aws s3 cp --only-show-errors --recursive s3://${S3_BUCKET}/$S3_KEY/ .
cd $TEST_OUTPUT
time cat $BACKUP_DIR/parts/* | zstdcat | tar --extract --preserve-permissions
rm -rf ${BACKUP_DIR}
curl -k -X 'GET' \
"${EC2_MACHINE_URL_US}/test_log/${GITHUB_RUN_ID}" \
-H 'accept: application/gzip' \
-H "Authorization: Bearer $API_KEY" \
--output "test_log_${GITHUB_RUN_ID}.gz"
- name: Cache poetry deps
uses: actions/cache@v4
with:
path: ~/.cache/pypoetry/virtualenvs
key: v2-${{ runner.os }}-${{ runner.arch }}-python-deps-bookworm-${{ hashFiles('poetry.lock') }}
- name: Install Python deps
shell: bash -euxo pipefail {0}
run: ./scripts/pysync
# we need high number of open files for pagebench
- name: show ulimits
shell: bash -euxo pipefail {0}
- name: Unzip Test Log and Print it into this job's log
if: always() && steps.poll_step.outputs.too_many_runs != 'true'
run: |
ulimit -a
- name: Run pagebench testcase
shell: bash -euxo pipefail {0}
env:
CI: false # need to override this env variable set by github to enforce using snapshots
run: |
export PLATFORM=hetzner-unit-perf-${COMMIT_HASH_TYPE}
# report the commit hash of the neon repository in the revision of the test results
export GITHUB_SHA=${COMMIT_HASH}
rm -rf ${PERF_REPORT_DIR}
rm -rf ${ALLURE_RESULTS_DIR}
mkdir -p ${PERF_REPORT_DIR}
mkdir -p ${ALLURE_RESULTS_DIR}
PARAMS="--alluredir=${ALLURE_RESULTS_DIR} --tb=short --verbose -rA"
EXTRA_PARAMS="--out-dir ${PERF_REPORT_DIR} --durations-path $TEST_OUTPUT/benchmark_durations.json"
# run only two selected tests
# environment set by parent:
# RUST_BACKTRACE=1 DEFAULT_PG_VERSION=16 BUILD_TYPE=release
./scripts/pytest ${PARAMS} test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py::test_pageserver_characterize_throughput_with_n_tenants ${EXTRA_PARAMS}
./scripts/pytest ${PARAMS} test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py::test_pageserver_characterize_latencies_with_1_client_and_throughput_with_many_clients_one_tenant ${EXTRA_PARAMS}
- name: upload the performance metrics to the Neon performance database which is used by grafana dashboards to display the results
shell: bash -euxo pipefail {0}
run: |
export REPORT_FROM="$PERF_REPORT_DIR"
export GITHUB_SHA=${COMMIT_HASH}
time ./scripts/generate_and_push_perf_report.sh
- name: Upload test results
if: ${{ !cancelled() }}
uses: ./.github/actions/allure-report-store
with:
report-dir: ${{ steps.set-env.outputs.allure_results_dir }}
unique-key: ${{ env.BUILD_TYPE }}-${{ env.DEFAULT_PG_VERSION }}-${{ runner.arch }}
aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
gzip -d "test_log_${GITHUB_RUN_ID}.gz"
cat "test_log_${GITHUB_RUN_ID}"
- name: Create Allure report
id: create-allure-report
if: ${{ !cancelled() }}
uses: ./.github/actions/allure-report-generate
with:
aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
- name: Upload snapshots
if: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.recreate_snapshots != 'false' && github.event.inputs.recreate_snapshots != '' }}
id: upload_snapshots
shell: bash -euxo pipefail {0}
run: |
mkdir -p $BACKUP_DIR
cd $TEST_OUTPUT
tar --create --preserve-permissions --file - shared-snapshots | zstd -o $BACKUP_DIR/shared_snapshots.tar.zst
cd $BACKUP_DIR
mkdir parts
split -b 1G shared_snapshots.tar.zst ./parts/shared_snapshots.tar.zst.part.
SNAPSHOT_DATE=$(date +%F) # YYYY-MM-DD
cd parts
time aws s3 cp --recursive . s3://${S3_BUCKET}/performance/pagebench/shared-snapshots-${SNAPSHOT_DATE}/
- name: Post to a Slack channel
if: ${{ github.event.schedule && failure() }}
uses: slackapi/slack-github-action@fcfb566f8b0aab22203f066d80ca1d7e4b5d05b3 # v1.27.1
@@ -260,22 +157,26 @@ jobs:
slack-message: "Periodic pagebench testing on dedicated hardware: ${{ job.status }}\n${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}"
env:
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
- name: Cleanup Test Resources
if: always()
shell: bash -euxo pipefail {0}
env:
ARCHIVE: ${{ runner.temp }}/downloads/neon-${{ runner.os }}-${{ runner.arch }}-release-artifact.tar.zst
run: |
# Cleanup the test resources
if [[ -d "${BACKUP_DIR}" ]]; then
rm -rf ${BACKUP_DIR}
fi
if [[ -d "${TEST_OUTPUT}" ]]; then
rm -rf ${TEST_OUTPUT}
fi
if [[ -d "${NEON_DIR}" ]]; then
rm -rf ${NEON_DIR}
fi
rm -rf $(dirname $ARCHIVE)
curl -k -X 'POST' \
"${EC2_MACHINE_URL_US}/cleanup_test/${GITHUB_RUN_ID}" \
-H 'accept: application/json' \
-H "Authorization: Bearer $API_KEY" \
-d ''
- name: Assume AWS OIDC role that allows to manage (start/stop/describe... EC machine)
if: always() && steps.poll_step.outputs.too_many_runs != 'true'
uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_MANAGE_BENCHMARK_EC2_VMS_ARN }}
role-duration-seconds: 3600
- name: Stop EC2 instance and wait for the instance to be stopped
if: always() && steps.poll_step.outputs.too_many_runs != 'true'
run: |
aws ec2 stop-instances --instance-ids $AWS_INSTANCE_ID
aws ec2 wait instance-stopped --instance-ids $AWS_INSTANCE_ID

250
Cargo.lock generated
View File

@@ -1276,7 +1276,7 @@ version = "0.1.0"
dependencies = [
"anyhow",
"chrono",
"indexmap 2.9.0",
"indexmap 2.0.1",
"jsonwebtoken",
"regex",
"remote_storage",
@@ -1308,7 +1308,7 @@ dependencies = [
"flate2",
"futures",
"http 1.1.0",
"indexmap 2.9.0",
"indexmap 2.0.1",
"itertools 0.10.5",
"jsonwebtoken",
"metrics",
@@ -2597,7 +2597,7 @@ dependencies = [
"futures-sink",
"futures-util",
"http 0.2.9",
"indexmap 2.9.0",
"indexmap 2.0.1",
"slab",
"tokio",
"tokio-util",
@@ -2616,7 +2616,7 @@ dependencies = [
"futures-sink",
"futures-util",
"http 1.1.0",
"indexmap 2.9.0",
"indexmap 2.0.1",
"slab",
"tokio",
"tokio-util",
@@ -2863,14 +2863,14 @@ dependencies = [
"pprof",
"regex",
"routerify",
"rustls 0.23.27",
"rustls 0.23.18",
"rustls-pemfile 2.1.1",
"serde",
"serde_json",
"serde_path_to_error",
"thiserror 1.0.69",
"tokio",
"tokio-rustls 0.26.2",
"tokio-rustls 0.26.0",
"tokio-stream",
"tokio-util",
"tracing",
@@ -3200,12 +3200,12 @@ dependencies = [
[[package]]
name = "indexmap"
version = "2.9.0"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e"
checksum = "ad227c3af19d4914570ad36d30409928b75967c298feb9ea1969db3a610bb14e"
dependencies = [
"equivalent",
"hashbrown 0.15.2",
"hashbrown 0.14.5",
"serde",
]
@@ -3228,7 +3228,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "232929e1d75fe899576a3d5c7416ad0d88dbfbb3c3d6aa00873a7408a50ddb88"
dependencies = [
"ahash",
"indexmap 2.9.0",
"indexmap 2.0.1",
"is-terminal",
"itoa",
"log",
@@ -3251,7 +3251,7 @@ dependencies = [
"crossbeam-utils",
"dashmap 6.1.0",
"env_logger",
"indexmap 2.9.0",
"indexmap 2.0.1",
"itoa",
"log",
"num-format",
@@ -3898,16 +3898,6 @@ dependencies = [
"winapi",
]
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
dependencies = [
"overload",
"winapi",
]
[[package]]
name = "num"
version = "0.4.1"
@@ -4112,7 +4102,7 @@ dependencies = [
"opentelemetry-http",
"opentelemetry-proto",
"opentelemetry_sdk",
"prost 0.13.5",
"prost 0.13.3",
"reqwest",
"thiserror 1.0.69",
]
@@ -4125,8 +4115,8 @@ checksum = "a6e05acbfada5ec79023c85368af14abd0b307c015e9064d249b2a950ef459a6"
dependencies = [
"opentelemetry",
"opentelemetry_sdk",
"prost 0.13.5",
"tonic 0.12.3",
"prost 0.13.3",
"tonic",
]
[[package]]
@@ -4192,12 +4182,6 @@ version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4030760ffd992bef45b0ae3f10ce1aba99e33464c90d14dd7c039884963ddc7a"
[[package]]
name = "overload"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
[[package]]
name = "p256"
version = "0.11.1"
@@ -4321,7 +4305,6 @@ dependencies = [
"pageserver_api",
"pageserver_client",
"pageserver_compaction",
"pageserver_page_api",
"pem",
"pin-project-lite",
"postgres-protocol",
@@ -4330,7 +4313,6 @@ dependencies = [
"postgres_connection",
"postgres_ffi",
"postgres_initdb",
"posthog_client_lite",
"pprof",
"pq_proto",
"procfs",
@@ -4341,7 +4323,7 @@ dependencies = [
"reqwest",
"rpds",
"rstest",
"rustls 0.23.27",
"rustls 0.23.18",
"scopeguard",
"send-future",
"serde",
@@ -4360,13 +4342,11 @@ dependencies = [
"tokio-epoll-uring",
"tokio-io-timeout",
"tokio-postgres",
"tokio-rustls 0.26.2",
"tokio-rustls 0.26.0",
"tokio-stream",
"tokio-tar",
"tokio-util",
"toml_edit",
"tonic 0.13.1",
"tonic-reflection",
"tracing",
"tracing-utils",
"twox-hash",
@@ -4459,15 +4439,9 @@ dependencies = [
name = "pageserver_page_api"
version = "0.1.0"
dependencies = [
"bytes",
"pageserver_api",
"postgres_ffi",
"prost 0.13.5",
"smallvec",
"thiserror 1.0.69",
"tonic 0.13.1",
"prost 0.13.3",
"tonic",
"tonic-build",
"utils",
"workspace_hack",
]
@@ -4847,14 +4821,14 @@ dependencies = [
"bytes",
"once_cell",
"pq_proto",
"rustls 0.23.27",
"rustls 0.23.18",
"rustls-pemfile 2.1.1",
"serde",
"thiserror 1.0.69",
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls 0.26.2",
"tokio-rustls 0.26.0",
"tokio-util",
"tracing",
]
@@ -4908,16 +4882,11 @@ name = "posthog_client_lite"
version = "0.1.0"
dependencies = [
"anyhow",
"arc-swap",
"reqwest",
"serde",
"serde_json",
"sha2",
"thiserror 1.0.69",
"tokio",
"tokio-util",
"tracing",
"tracing-utils",
"workspace_hack",
]
@@ -4966,7 +4935,7 @@ dependencies = [
"inferno 0.12.0",
"num",
"paste",
"prost 0.13.5",
"prost 0.13.3",
]
[[package]]
@@ -5071,12 +5040,12 @@ dependencies = [
[[package]]
name = "prost"
version = "0.13.5"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5"
checksum = "7b0487d90e047de87f984913713b85c601c05609aad5b0df4b4573fbf69aa13f"
dependencies = [
"bytes",
"prost-derive 0.13.5",
"prost-derive 0.13.3",
]
[[package]]
@@ -5114,7 +5083,7 @@ dependencies = [
"once_cell",
"petgraph",
"prettyplease",
"prost 0.13.5",
"prost 0.13.3",
"prost-types 0.13.3",
"regex",
"syn 2.0.100",
@@ -5136,9 +5105,9 @@ dependencies = [
[[package]]
name = "prost-derive"
version = "0.13.5"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d"
checksum = "e9552f850d5f0964a4e4d0bf306459ac29323ddfbae05e35a7c0d35cb0803cc5"
dependencies = [
"anyhow",
"itertools 0.12.1",
@@ -5162,7 +5131,7 @@ version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4759aa0d3a6232fb8dbdb97b61de2c20047c68aca932c7ed76da9d788508d670"
dependencies = [
"prost 0.13.5",
"prost 0.13.3",
]
[[package]]
@@ -5210,7 +5179,7 @@ dependencies = [
"hyper 0.14.30",
"hyper 1.4.1",
"hyper-util",
"indexmap 2.9.0",
"indexmap 2.0.1",
"ipnet",
"itertools 0.10.5",
"itoa",
@@ -5244,7 +5213,7 @@ dependencies = [
"rsa",
"rstest",
"rustc-hash 1.1.0",
"rustls 0.23.27",
"rustls 0.23.18",
"rustls-native-certs 0.8.0",
"rustls-pemfile 2.1.1",
"scopeguard",
@@ -5263,14 +5232,13 @@ dependencies = [
"tokio",
"tokio-postgres",
"tokio-postgres2",
"tokio-rustls 0.26.2",
"tokio-rustls 0.26.0",
"tokio-tungstenite 0.21.0",
"tokio-util",
"tracing",
"tracing-log",
"tracing-opentelemetry",
"tracing-subscriber",
"tracing-test",
"tracing-utils",
"try-lock",
"typed-json",
@@ -5487,13 +5455,13 @@ dependencies = [
"num-bigint",
"percent-encoding",
"pin-project-lite",
"rustls 0.23.27",
"rustls 0.23.18",
"rustls-native-certs 0.8.0",
"ryu",
"sha1_smol",
"socket2",
"tokio",
"tokio-rustls 0.26.2",
"tokio-rustls 0.26.0",
"tokio-util",
"url",
]
@@ -5941,15 +5909,15 @@ dependencies = [
[[package]]
name = "rustls"
version = "0.23.27"
version = "0.23.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321"
checksum = "9c9cc1d47e243d655ace55ed38201c19ae02c148ae56412ab8750e8f0166ab7f"
dependencies = [
"log",
"once_cell",
"ring",
"rustls-pki-types",
"rustls-webpki 0.103.3",
"rustls-webpki 0.102.8",
"subtle",
"zeroize",
]
@@ -6038,17 +6006,6 @@ dependencies = [
"untrusted",
]
[[package]]
name = "rustls-webpki"
version = "0.103.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4a72fe2bcf7a6ac6fd7d0b9e5cb68aeb7d4c0a0271730218b3e92d43b4eb435"
dependencies = [
"ring",
"rustls-pki-types",
"untrusted",
]
[[package]]
name = "rustversion"
version = "1.0.12"
@@ -6100,7 +6057,7 @@ dependencies = [
"regex",
"remote_storage",
"reqwest",
"rustls 0.23.27",
"rustls 0.23.18",
"safekeeper_api",
"safekeeper_client",
"scopeguard",
@@ -6117,7 +6074,7 @@ dependencies = [
"tokio",
"tokio-io-timeout",
"tokio-postgres",
"tokio-rustls 0.26.2",
"tokio-rustls 0.26.0",
"tokio-stream",
"tokio-tar",
"tokio-util",
@@ -6289,7 +6246,7 @@ checksum = "255914a8e53822abd946e2ce8baa41d4cded6b8e938913b7f7b9da5b7ab44335"
dependencies = [
"httpdate",
"reqwest",
"rustls 0.23.27",
"rustls 0.23.18",
"sentry-backtrace",
"sentry-contexts",
"sentry-core",
@@ -6718,11 +6675,11 @@ dependencies = [
"metrics",
"once_cell",
"parking_lot 0.12.1",
"prost 0.13.5",
"rustls 0.23.27",
"prost 0.13.3",
"rustls 0.23.18",
"tokio",
"tokio-rustls 0.26.2",
"tonic 0.13.1",
"tokio-rustls 0.26.0",
"tonic",
"tonic-build",
"tracing",
"utils",
@@ -6764,7 +6721,7 @@ dependencies = [
"regex",
"reqwest",
"routerify",
"rustls 0.23.27",
"rustls 0.23.18",
"rustls-native-certs 0.8.0",
"safekeeper_api",
"safekeeper_client",
@@ -6779,7 +6736,7 @@ dependencies = [
"tokio",
"tokio-postgres",
"tokio-postgres-rustls",
"tokio-rustls 0.26.2",
"tokio-rustls 0.26.0",
"tokio-util",
"tracing",
"utils",
@@ -6817,7 +6774,7 @@ dependencies = [
"postgres_ffi",
"remote_storage",
"reqwest",
"rustls 0.23.27",
"rustls 0.23.18",
"rustls-native-certs 0.8.0",
"serde",
"serde_json",
@@ -7351,10 +7308,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04fb792ccd6bbcd4bba408eb8a292f70fc4a3589e5d793626f45190e6454b6ab"
dependencies = [
"ring",
"rustls 0.23.27",
"rustls 0.23.18",
"tokio",
"tokio-postgres",
"tokio-rustls 0.26.2",
"tokio-rustls 0.26.0",
"x509-certificate",
]
@@ -7398,11 +7355,12 @@ dependencies = [
[[package]]
name = "tokio-rustls"
version = "0.26.2"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b"
checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4"
dependencies = [
"rustls 0.23.27",
"rustls 0.23.18",
"rustls-pki-types",
"tokio",
]
@@ -7500,7 +7458,7 @@ version = "0.22.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f21c7aaf97f1bd9ca9d4f9e73b0a6c74bd5afef56f2bc931943a6e1c37e04e38"
dependencies = [
"indexmap 2.9.0",
"indexmap 2.0.1",
"serde",
"serde_spanned",
"toml_datetime",
@@ -7519,41 +7477,18 @@ dependencies = [
"http 1.1.0",
"http-body 1.0.0",
"http-body-util",
"percent-encoding",
"pin-project",
"prost 0.13.5",
"tokio-stream",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tonic"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e581ba15a835f4d9ea06c55ab1bd4dce26fc53752c69a04aac00703bfb49ba9"
dependencies = [
"async-trait",
"axum",
"base64 0.22.1",
"bytes",
"h2 0.4.4",
"http 1.1.0",
"http-body 1.0.0",
"http-body-util",
"hyper 1.4.1",
"hyper-timeout",
"hyper-util",
"percent-encoding",
"pin-project",
"prost 0.13.5",
"prost 0.13.3",
"rustls-native-certs 0.8.0",
"socket2",
"rustls-pemfile 2.1.1",
"tokio",
"tokio-rustls 0.26.2",
"tokio-rustls 0.26.0",
"tokio-stream",
"tower 0.5.2",
"tower 0.4.13",
"tower-layer",
"tower-service",
"tracing",
@@ -7561,9 +7496,9 @@ dependencies = [
[[package]]
name = "tonic-build"
version = "0.13.1"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eac6f67be712d12f0b41328db3137e0d0757645d8904b4cb7d51cd9c2279e847"
checksum = "9557ce109ea773b399c9b9e5dca39294110b74f1f342cb347a80d1fce8c26a11"
dependencies = [
"prettyplease",
"proc-macro2",
@@ -7573,19 +7508,6 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "tonic-reflection"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9687bd5bfeafebdded2356950f278bba8226f0b32109537c4253406e09aafe1"
dependencies = [
"prost 0.13.5",
"prost-types 0.13.3",
"tokio",
"tokio-stream",
"tonic 0.13.1",
]
[[package]]
name = "tower"
version = "0.4.13"
@@ -7594,11 +7516,16 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
dependencies = [
"futures-core",
"futures-util",
"indexmap 1.9.3",
"pin-project",
"pin-project-lite",
"rand 0.8.5",
"slab",
"tokio",
"tokio-util",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
@@ -7609,12 +7536,9 @@ checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9"
dependencies = [
"futures-core",
"futures-util",
"indexmap 2.9.0",
"pin-project-lite",
"slab",
"sync_wrapper 1.0.1",
"tokio",
"tokio-util",
"tower-layer",
"tower-service",
"tracing",
@@ -7765,7 +7689,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex",
"serde",
@@ -7779,27 +7702,6 @@ dependencies = [
"tracing-serde",
]
[[package]]
name = "tracing-test"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "557b891436fe0d5e0e363427fc7f217abf9ccd510d5136549847bdcbcd011d68"
dependencies = [
"tracing-core",
"tracing-subscriber",
"tracing-test-macro",
]
[[package]]
name = "tracing-test-macro"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04659ddb06c87d233c566112c1c9c5b9e98256d9af50ec3bc9c8327f873a7568"
dependencies = [
"quote",
"syn 2.0.100",
]
[[package]]
name = "tracing-utils"
version = "0.1.0"
@@ -7942,7 +7844,7 @@ dependencies = [
"base64 0.22.1",
"log",
"once_cell",
"rustls 0.23.27",
"rustls 0.23.18",
"rustls-pki-types",
"url",
"webpki-roots",
@@ -8137,7 +8039,7 @@ dependencies = [
"pageserver_api",
"postgres_ffi",
"pprof",
"prost 0.13.5",
"prost 0.13.3",
"remote_storage",
"serde",
"serde_json",
@@ -8557,8 +8459,6 @@ dependencies = [
"ahash",
"anstream",
"anyhow",
"axum",
"axum-core",
"base64 0.13.1",
"base64 0.21.7",
"base64ct",
@@ -8581,8 +8481,10 @@ dependencies = [
"fail",
"form_urlencoded",
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-task",
"futures-util",
"generic-array",
"getrandom 0.2.11",
@@ -8593,7 +8495,8 @@ dependencies = [
"hyper 0.14.30",
"hyper 1.4.1",
"hyper-util",
"indexmap 2.9.0",
"indexmap 1.9.3",
"indexmap 2.0.1",
"itertools 0.12.1",
"lazy_static",
"libc",
@@ -8612,18 +8515,19 @@ dependencies = [
"once_cell",
"p256 0.13.2",
"parquet",
"percent-encoding",
"prettyplease",
"proc-macro2",
"prost 0.13.5",
"prost 0.13.3",
"quote",
"rand 0.8.5",
"regex",
"regex-automata 0.4.3",
"regex-syntax 0.8.2",
"reqwest",
"rustls 0.23.27",
"rustls 0.23.18",
"rustls-pki-types",
"rustls-webpki 0.103.3",
"rustls-webpki 0.102.8",
"scopeguard",
"sec1 0.7.3",
"serde",
@@ -8641,15 +8545,15 @@ dependencies = [
"time",
"time-macros",
"tokio",
"tokio-rustls 0.26.2",
"tokio-rustls 0.26.0",
"tokio-stream",
"tokio-util",
"toml_edit",
"tower 0.5.2",
"tonic",
"tower 0.4.13",
"tracing",
"tracing-core",
"tracing-log",
"tracing-subscriber",
"url",
"uuid",
"zeroize",

View File

@@ -149,7 +149,7 @@ pin-project-lite = "0.2"
pprof = { version = "0.14", features = ["criterion", "flamegraph", "frame-pointer", "prost-codec"] }
procfs = "0.16"
prometheus = {version = "0.13", default-features=false, features = ["process"]} # removes protobuf dependency
prost = "0.13.5"
prost = "0.13"
rand = "0.8"
redis = { version = "0.29.2", features = ["tokio-rustls-comp", "keep-alive"] }
regex = "1.10.2"
@@ -199,8 +199,7 @@ tokio-tar = "0.3"
tokio-util = { version = "0.7.10", features = ["io", "rt"] }
toml = "0.8"
toml_edit = "0.22"
tonic = { version = "0.13.1", default-features = false, features = ["channel", "codegen", "prost", "router", "server", "tls-ring", "tls-native-roots"] }
tonic-reflection = { version = "0.13.1", features = ["server"] }
tonic = {version = "0.12.3", default-features = false, features = ["channel", "tls", "tls-roots"]}
tower = { version = "0.5.2", default-features = false }
tower-http = { version = "0.6.2", features = ["auth", "request-id", "trace"] }
@@ -247,7 +246,6 @@ azure_storage_blobs = { git = "https://github.com/neondatabase/azure-sdk-for-rus
## Local libraries
compute_api = { version = "0.1", path = "./libs/compute_api/" }
consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" }
desim = { version = "0.1", path = "./libs/desim" }
endpoint_storage = { version = "0.0.1", path = "./endpoint_storage/" }
http-utils = { version = "0.1", path = "./libs/http-utils/" }
metrics = { version = "0.1", path = "./libs/metrics/" }
@@ -260,19 +258,19 @@ postgres_backend = { version = "0.1", path = "./libs/postgres_backend/" }
postgres_connection = { version = "0.1", path = "./libs/postgres_connection/" }
postgres_ffi = { version = "0.1", path = "./libs/postgres_ffi/" }
postgres_initdb = { path = "./libs/postgres_initdb" }
posthog_client_lite = { version = "0.1", path = "./libs/posthog_client_lite" }
pq_proto = { version = "0.1", path = "./libs/pq_proto/" }
remote_storage = { version = "0.1", path = "./libs/remote_storage/" }
safekeeper_api = { version = "0.1", path = "./libs/safekeeper_api" }
safekeeper_client = { path = "./safekeeper/client" }
desim = { version = "0.1", path = "./libs/desim" }
storage_broker = { version = "0.1", path = "./storage_broker/" } # Note: main broker code is inside the binary crate, so linking with the library shouldn't be heavy.
storage_controller_client = { path = "./storage_controller/client" }
tenant_size_model = { version = "0.1", path = "./libs/tenant_size_model/" }
tracing-utils = { version = "0.1", path = "./libs/tracing-utils/" }
utils = { version = "0.1", path = "./libs/utils/" }
vm_monitor = { version = "0.1", path = "./libs/vm_monitor/" }
wal_decoder = { version = "0.1", path = "./libs/wal_decoder" }
walproposer = { version = "0.1", path = "./libs/walproposer/" }
wal_decoder = { version = "0.1", path = "./libs/wal_decoder" }
## Common library dependency
workspace_hack = { version = "0.1", path = "./workspace_hack/" }
@@ -282,7 +280,7 @@ criterion = "0.5.1"
rcgen = "0.13"
rstest = "0.18"
camino-tempfile = "1.0.2"
tonic-build = "0.13.1"
tonic-build = "0.12"
[patch.crates-io]

View File

@@ -155,7 +155,7 @@ RUN set -e \
# Keep the version the same as in compute/compute-node.Dockerfile and
# test_runner/regress/test_compute_metrics.py.
ENV SQL_EXPORTER_VERSION=0.17.3
ENV SQL_EXPORTER_VERSION=0.17.0
RUN curl -fsSL \
"https://github.com/burningalchemist/sql_exporter/releases/download/${SQL_EXPORTER_VERSION}/sql_exporter-${SQL_EXPORTER_VERSION}.linux-$(case "$(uname -m)" in x86_64) echo amd64;; aarch64) echo arm64;; esac).tar.gz" \
--output sql_exporter.tar.gz \

View File

@@ -1751,17 +1751,17 @@ ARG TARGETARCH
RUN if [ "$TARGETARCH" = "amd64" ]; then\
postgres_exporter_sha256='59aa4a7bb0f7d361f5e05732f5ed8c03cc08f78449cef5856eadec33a627694b';\
pgbouncer_exporter_sha256='c9f7cf8dcff44f0472057e9bf52613d93f3ffbc381ad7547a959daa63c5e84ac';\
sql_exporter_sha256='9a41127a493e8bfebfe692bf78c7ed2872a58a3f961ee534d1b0da9ae584aaab';\
sql_exporter_sha256='38e439732bbf6e28ca4a94d7bc3686d3fa1abdb0050773d5617a9efdb9e64d08';\
else\
postgres_exporter_sha256='d1dedea97f56c6d965837bfd1fbb3e35a3b4a4556f8cccee8bd513d8ee086124';\
pgbouncer_exporter_sha256='217c4afd7e6492ae904055bc14fe603552cf9bac458c063407e991d68c519da3';\
sql_exporter_sha256='530e6afc77c043497ed965532c4c9dfa873bc2a4f0b3047fad367715c0081d6a';\
sql_exporter_sha256='11918b00be6e2c3a67564adfdb2414fdcbb15a5db76ea17d1d1a944237a893c6';\
fi\
&& curl -sL https://github.com/prometheus-community/postgres_exporter/releases/download/v0.17.1/postgres_exporter-0.17.1.linux-${TARGETARCH}.tar.gz\
| tar xzf - --strip-components=1 -C.\
&& curl -sL https://github.com/prometheus-community/pgbouncer_exporter/releases/download/v0.10.2/pgbouncer_exporter-0.10.2.linux-${TARGETARCH}.tar.gz\
| tar xzf - --strip-components=1 -C.\
&& curl -sL https://github.com/burningalchemist/sql_exporter/releases/download/0.17.3/sql_exporter-0.17.3.linux-${TARGETARCH}.tar.gz\
&& curl -sL https://github.com/burningalchemist/sql_exporter/releases/download/0.17.0/sql_exporter-0.17.0.linux-${TARGETARCH}.tar.gz\
| tar xzf - --strip-components=1 -C.\
&& echo "${postgres_exporter_sha256} postgres_exporter" | sha256sum -c -\
&& echo "${pgbouncer_exporter_sha256} pgbouncer_exporter" | sha256sum -c -\
@@ -1814,7 +1814,7 @@ COPY docker-compose/ext-src/ /ext-src/
COPY --from=pg-build /postgres /postgres
#COPY --from=postgis-src /ext-src/ /ext-src/
COPY --from=plv8-src /ext-src/ /ext-src/
COPY --from=h3-pg-src /ext-src/h3-pg-src /ext-src/h3-pg-src
#COPY --from=h3-pg-src /ext-src/ /ext-src/
COPY --from=postgresql-unit-src /ext-src/ /ext-src/
COPY --from=pgvector-src /ext-src/ /ext-src/
COPY --from=pgjwt-src /ext-src/ /ext-src/

View File

@@ -136,10 +136,6 @@ struct Cli {
requires = "compute-id"
)]
pub control_plane_uri: Option<String>,
/// Interval in seconds for collecting installed extensions statistics
#[arg(long, default_value = "3600")]
pub installed_extensions_collection_interval: u64,
}
fn main() -> Result<()> {
@@ -183,7 +179,6 @@ fn main() -> Result<()> {
cgroup: cli.cgroup,
#[cfg(target_os = "linux")]
vm_monitor_addr: cli.vm_monitor_addr,
installed_extensions_collection_interval: cli.installed_extensions_collection_interval,
},
config,
)?;

View File

@@ -97,9 +97,6 @@ pub struct ComputeNodeParams {
/// the address of extension storage proxy gateway
pub remote_ext_base_url: Option<String>,
/// Interval for installed extensions collection
pub installed_extensions_collection_interval: u64,
}
/// Compute node info shared across several `compute_ctl` threads.
@@ -698,18 +695,25 @@ impl ComputeNode {
let log_directory_path = Path::new(&self.params.pgdata).join("log");
let log_directory_path = log_directory_path.to_string_lossy().to_string();
// Add project_id,endpoint_id to identify the logs.
// Add project_id,endpoint_id tag to identify the logs.
//
// These ids are passed from cplane,
let endpoint_id = pspec.spec.endpoint_id.as_deref().unwrap_or("");
let project_id = pspec.spec.project_id.as_deref().unwrap_or("");
// for backwards compatibility (old computes that don't have them),
// we set them to None.
// TODO: Clean up this code when all computes have them.
let tag: Option<String> = match (
pspec.spec.project_id.as_deref(),
pspec.spec.endpoint_id.as_deref(),
) {
(Some(project_id), Some(endpoint_id)) => {
Some(format!("{project_id}/{endpoint_id}"))
}
(Some(project_id), None) => Some(format!("{project_id}/None")),
(None, Some(endpoint_id)) => Some(format!("None,{endpoint_id}")),
(None, None) => None,
};
configure_audit_rsyslog(
log_directory_path.clone(),
endpoint_id,
project_id,
&remote_endpoint,
)?;
configure_audit_rsyslog(log_directory_path.clone(), tag, &remote_endpoint)?;
// Launch a background task to clean up the audit logs
launch_pgaudit_gc(log_directory_path);
@@ -745,7 +749,17 @@ impl ComputeNode {
let conf = self.get_tokio_conn_conf(None);
tokio::task::spawn(async {
let _ = installed_extensions(conf).await;
let res = get_installed_extensions(conf).await;
match res {
Ok(extensions) => {
info!(
"[NEON_EXT_STAT] {}",
serde_json::to_string(&extensions)
.expect("failed to serialize extensions list")
);
}
Err(err) => error!("could not get installed extensions: {err:?}"),
}
});
}
@@ -775,9 +789,6 @@ impl ComputeNode {
// Log metrics so that we can search for slow operations in logs
info!(?metrics, postmaster_pid = %postmaster_pid, "compute start finished");
// Spawn the extension stats background task
self.spawn_extension_stats_task();
if pspec.spec.prewarm_lfc_on_startup {
self.prewarm_lfc();
}
@@ -2188,41 +2199,6 @@ LIMIT 100",
info!("Pageserver config changed");
}
}
pub fn spawn_extension_stats_task(&self) {
let conf = self.tokio_conn_conf.clone();
let installed_extensions_collection_interval =
self.params.installed_extensions_collection_interval;
tokio::spawn(async move {
// An initial sleep is added to ensure that two collections don't happen at the same time.
// The first collection happens during compute startup.
tokio::time::sleep(tokio::time::Duration::from_secs(
installed_extensions_collection_interval,
))
.await;
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(
installed_extensions_collection_interval,
));
loop {
interval.tick().await;
let _ = installed_extensions(conf.clone()).await;
}
});
}
}
pub async fn installed_extensions(conf: tokio_postgres::Config) -> Result<()> {
let res = get_installed_extensions(conf).await;
match res {
Ok(extensions) => {
info!(
"[NEON_EXT_STAT] {}",
serde_json::to_string(&extensions).expect("failed to serialize extensions list")
);
}
Err(err) => error!("could not get installed extensions: {err:?}"),
}
Ok(())
}
pub fn forward_termination_signal() {

View File

@@ -2,24 +2,10 @@
module(load="imfile")
# Input configuration for log files in the specified directory
# The messages can be multiline. The start of the message is a timestamp
# in "%Y-%m-%d %H:%M:%S.%3N GMT" (so timezone hardcoded).
# Replace log_directory with the directory containing the log files
input(type="imfile" File="{log_directory}/*.log"
Tag="pgaudit_log" Severity="info" Facility="local5"
startmsg.regex="^[[:digit:]]{{4}}-[[:digit:]]{{2}}-[[:digit:]]{{2}} [[:digit:]]{{2}}:[[:digit:]]{{2}}:[[:digit:]]{{2}}.[[:digit:]]{{3}} GMT,")
# Replace {log_directory} with the directory containing the log files
input(type="imfile" File="{log_directory}/*.log" Tag="{tag}" Severity="info" Facility="local0")
# the directory to store rsyslog state files
global(workDirectory="/var/log/rsyslog")
# Construct json, endpoint_id and project_id as additional metadata
set $.json_log!endpoint_id = "{endpoint_id}";
set $.json_log!project_id = "{project_id}";
set $.json_log!msg = $msg;
# Template suitable for rfc5424 syslog format
template(name="PgAuditLog" type="string"
string="<%PRI%>1 %TIMESTAMP:::date-rfc3339% %HOSTNAME% - - - - %$.json_log%")
# Forward to remote syslog receiver (@@<hostname>:<port>;format
local5.info @@{remote_endpoint};PgAuditLog
# Forward logs to remote syslog server
*.* @@{remote_endpoint}

View File

@@ -27,40 +27,6 @@ fn get_rsyslog_pid() -> Option<String> {
}
}
fn wait_for_rsyslog_pid() -> Result<String, anyhow::Error> {
const MAX_WAIT: Duration = Duration::from_secs(5);
const INITIAL_SLEEP: Duration = Duration::from_millis(2);
let mut sleep_duration = INITIAL_SLEEP;
let start = std::time::Instant::now();
let mut attempts = 1;
for attempt in 1.. {
attempts = attempt;
match get_rsyslog_pid() {
Some(pid) => return Ok(pid),
None => {
if start.elapsed() >= MAX_WAIT {
break;
}
info!(
"rsyslogd is not running, attempt {}. Sleeping for {} ms",
attempt,
sleep_duration.as_millis()
);
std::thread::sleep(sleep_duration);
sleep_duration *= 2;
}
}
}
Err(anyhow::anyhow!(
"rsyslogd is not running after waiting for {} seconds and {} attempts",
attempts,
start.elapsed().as_secs()
))
}
// Restart rsyslogd to apply the new configuration.
// This is necessary, because there is no other way to reload the rsyslog configuration.
//
@@ -70,29 +36,27 @@ fn wait_for_rsyslog_pid() -> Result<String, anyhow::Error> {
// TODO: test it properly
//
fn restart_rsyslog() -> Result<()> {
let old_pid = get_rsyslog_pid().context("rsyslogd is not running")?;
info!("rsyslogd is running with pid: {}, restart it", old_pid);
// kill it to restart
let _ = Command::new("pkill")
.arg("rsyslogd")
.output()
.context("Failed to restart rsyslogd")?;
// ensure rsyslogd is running
wait_for_rsyslog_pid()?;
.context("Failed to stop rsyslogd")?;
Ok(())
}
pub fn configure_audit_rsyslog(
log_directory: String,
endpoint_id: &str,
project_id: &str,
tag: Option<String>,
remote_endpoint: &str,
) -> Result<()> {
let config_content: String = format!(
include_str!("config_template/compute_audit_rsyslog_template.conf"),
log_directory = log_directory,
endpoint_id = endpoint_id,
project_id = project_id,
tag = tag.unwrap_or("".to_string()),
remote_endpoint = remote_endpoint
);
@@ -167,11 +131,15 @@ pub fn configure_postgres_logs_export(conf: PostgresLogsRsyslogConfig) -> Result
return Ok(());
}
// Nothing to configure
// When new config is empty we can simply remove the configuration file.
if new_config.is_empty() {
// When the configuration is removed, PostgreSQL will stop sending data
// to the files watched by rsyslog, so restarting rsyslog is more effort
// than just ignoring this change.
info!("removing rsyslog config file: {}", POSTGRES_LOGS_CONF_PATH);
match std::fs::remove_file(POSTGRES_LOGS_CONF_PATH) {
Ok(_) => {}
Err(err) if err.kind() == ErrorKind::NotFound => {}
Err(err) => return Err(err.into()),
}
restart_rsyslog()?;
return Ok(());
}

View File

@@ -2,10 +2,8 @@
[pageserver]
listen_pg_addr = '127.0.0.1:64000'
listen_http_addr = '127.0.0.1:9898'
listen_grpc_addr = '127.0.0.1:51051'
pg_auth_type = 'Trust'
http_auth_type = 'Trust'
grpc_auth_type = 'Trust'
[[safekeepers]]
id = 1

View File

@@ -4,10 +4,8 @@
id=1
listen_pg_addr = '127.0.0.1:64000'
listen_http_addr = '127.0.0.1:9898'
listen_grpc_addr = '127.0.0.1:51051'
pg_auth_type = 'Trust'
http_auth_type = 'Trust'
grpc_auth_type = 'Trust'
[[safekeepers]]
id = 1

View File

@@ -32,7 +32,6 @@ use control_plane::storage_controller::{
};
use nix::fcntl::{Flock, FlockArg};
use pageserver_api::config::{
DEFAULT_GRPC_LISTEN_PORT as DEFAULT_PAGESERVER_GRPC_PORT,
DEFAULT_HTTP_LISTEN_PORT as DEFAULT_PAGESERVER_HTTP_PORT,
DEFAULT_PG_LISTEN_PORT as DEFAULT_PAGESERVER_PG_PORT,
};
@@ -1008,16 +1007,13 @@ fn handle_init(args: &InitCmdArgs) -> anyhow::Result<LocalEnv> {
let pageserver_id = NodeId(DEFAULT_PAGESERVER_ID.0 + i as u64);
let pg_port = DEFAULT_PAGESERVER_PG_PORT + i;
let http_port = DEFAULT_PAGESERVER_HTTP_PORT + i;
let grpc_port = DEFAULT_PAGESERVER_GRPC_PORT + i;
NeonLocalInitPageserverConf {
id: pageserver_id,
listen_pg_addr: format!("127.0.0.1:{pg_port}"),
listen_http_addr: format!("127.0.0.1:{http_port}"),
listen_https_addr: None,
listen_grpc_addr: Some(format!("127.0.0.1:{grpc_port}")),
pg_auth_type: AuthType::Trust,
http_auth_type: AuthType::Trust,
grpc_auth_type: AuthType::Trust,
other: Default::default(),
// Typical developer machines use disks with slow fsync, and we don't care
// about data integrity: disable disk syncs.
@@ -1279,7 +1275,6 @@ async fn handle_timeline(cmd: &TimelineCmd, env: &mut local_env::LocalEnv) -> Re
mode: pageserver_api::models::TimelineCreateRequestMode::Branch {
ancestor_timeline_id,
ancestor_start_lsn: start_lsn,
read_only: false,
pg_version: None,
},
};

View File

@@ -278,10 +278,8 @@ pub struct PageServerConf {
pub listen_pg_addr: String,
pub listen_http_addr: String,
pub listen_https_addr: Option<String>,
pub listen_grpc_addr: Option<String>,
pub pg_auth_type: AuthType,
pub http_auth_type: AuthType,
pub grpc_auth_type: AuthType,
pub no_sync: bool,
}
@@ -292,10 +290,8 @@ impl Default for PageServerConf {
listen_pg_addr: String::new(),
listen_http_addr: String::new(),
listen_https_addr: None,
listen_grpc_addr: None,
pg_auth_type: AuthType::Trust,
http_auth_type: AuthType::Trust,
grpc_auth_type: AuthType::Trust,
no_sync: false,
}
}
@@ -310,10 +306,8 @@ pub struct NeonLocalInitPageserverConf {
pub listen_pg_addr: String,
pub listen_http_addr: String,
pub listen_https_addr: Option<String>,
pub listen_grpc_addr: Option<String>,
pub pg_auth_type: AuthType,
pub http_auth_type: AuthType,
pub grpc_auth_type: AuthType,
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
pub no_sync: bool,
#[serde(flatten)]
@@ -327,10 +321,8 @@ impl From<&NeonLocalInitPageserverConf> for PageServerConf {
listen_pg_addr,
listen_http_addr,
listen_https_addr,
listen_grpc_addr,
pg_auth_type,
http_auth_type,
grpc_auth_type,
no_sync,
other: _,
} = conf;
@@ -339,9 +331,7 @@ impl From<&NeonLocalInitPageserverConf> for PageServerConf {
listen_pg_addr: listen_pg_addr.clone(),
listen_http_addr: listen_http_addr.clone(),
listen_https_addr: listen_https_addr.clone(),
listen_grpc_addr: listen_grpc_addr.clone(),
pg_auth_type: *pg_auth_type,
grpc_auth_type: *grpc_auth_type,
http_auth_type: *http_auth_type,
no_sync: *no_sync,
}
@@ -717,10 +707,8 @@ impl LocalEnv {
listen_pg_addr: String,
listen_http_addr: String,
listen_https_addr: Option<String>,
listen_grpc_addr: Option<String>,
pg_auth_type: AuthType,
http_auth_type: AuthType,
grpc_auth_type: AuthType,
#[serde(default)]
no_sync: bool,
}
@@ -744,10 +732,8 @@ impl LocalEnv {
listen_pg_addr,
listen_http_addr,
listen_https_addr,
listen_grpc_addr,
pg_auth_type,
http_auth_type,
grpc_auth_type,
no_sync,
} = config_toml;
let IdentityTomlSubset {
@@ -764,10 +750,8 @@ impl LocalEnv {
listen_pg_addr,
listen_http_addr,
listen_https_addr,
listen_grpc_addr,
pg_auth_type,
http_auth_type,
grpc_auth_type,
no_sync,
};
pageservers.push(conf);

View File

@@ -129,9 +129,7 @@ impl PageServerNode {
));
}
if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type]
.contains(&AuthType::NeonJWT)
{
if conf.http_auth_type != AuthType::Trust || conf.pg_auth_type != AuthType::Trust {
// Keys are generated in the toplevel repo dir, pageservers' workdirs
// are one level below that, so refer to keys with ../
overrides.push("auth_validation_public_key_path='../auth_public_key.pem'".to_owned());
@@ -553,11 +551,6 @@ impl PageServerNode {
.map(|x| x.parse::<usize>())
.transpose()
.context("Falied to parse 'relsize_snapshot_cache_capacity' as integer")?,
basebackup_cache_enabled: settings
.remove("basebackup_cache_enabled")
.map(|x| x.parse::<bool>())
.transpose()
.context("Failed to parse 'basebackup_cache_enabled' as bool")?,
};
if !settings.is_empty() {
bail!("Unrecognized tenant settings: {settings:?}")

View File

@@ -20,7 +20,7 @@ first_path="$(ldconfig --verbose 2>/dev/null \
| grep --invert-match ^$'\t' \
| cut --delimiter=: --fields=1 \
| head --lines=1)"
test "$first_path" == '/usr/local/lib'
test "$first_path" == '/usr/local/lib' || true # Remove the || true in a follow-up PR. Needed for backwards compat.
echo "Waiting pageserver become ready."
while ! nc -z pageserver 6400; do

View File

@@ -1,16 +0,0 @@
#!/usr/bin/env bash
set -ex
cd "$(dirname "${0}")"
PG_REGRESS=$(dirname "$(pg_config --pgxs)")/../test/regress/pg_regress
dropdb --if-exists contrib_regression
createdb contrib_regression
cd h3_postgis/test
psql -d contrib_regression -c "CREATE EXTENSION postgis" -c "CREATE EXTENSION postgis_raster" -c "CREATE EXTENSION h3" -c "CREATE EXTENSION h3_postgis"
TESTS=$(echo sql/* | sed 's|sql/||g; s|\.sql||g')
${PG_REGRESS} --use-existing --dbname contrib_regression ${TESTS}
cd ../../h3/test
TESTS=$(echo sql/* | sed 's|sql/||g; s|\.sql||g')
dropdb --if-exists contrib_regression
createdb contrib_regression
psql -d contrib_regression -c "CREATE EXTENSION h3"
${PG_REGRESS} --use-existing --dbname contrib_regression ${TESTS}

View File

@@ -1,7 +0,0 @@
#!/bin/sh
set -ex
cd "$(dirname ${0})"
PG_REGRESS=$(dirname "$(pg_config --pgxs)")/../test/regress/pg_regress
cd h3/test
TESTS=$(echo sql/* | sed 's|sql/||g; s|\.sql||g')
${PG_REGRESS} --use-existing --inputdir=./ --bindir='/usr/local/pgsql/bin' --dbname=contrib_regression ${TESTS}

View File

@@ -1,6 +0,0 @@
#!/bin/sh
set -ex
cd "$(dirname "${0}")"
if [ -f Makefile ]; then
make installcheck
fi

View File

@@ -1,9 +0,0 @@
#!/bin/sh
set -ex
cd "$(dirname ${0})"
[ -f Makefile ] || exit 0
dropdb --if-exist contrib_regression
createdb contrib_regression
PG_REGRESS=$(dirname "$(pg_config --pgxs)")/../test/regress/pg_regress
TESTS=$(echo sql/* | sed 's|sql/||g; s|\.sql||g')
${PG_REGRESS} --use-existing --inputdir=./ --bindir='/usr/local/pgsql/bin' --dbname=contrib_regression ${TESTS}

View File

@@ -82,8 +82,7 @@ EXTENSIONS='[
{"extname": "pg_ivm", "extdir": "pg_ivm-src"},
{"extname": "pgjwt", "extdir": "pgjwt-src"},
{"extname": "pgtap", "extdir": "pgtap-src"},
{"extname": "pg_repack", "extdir": "pg_repack-src"},
{"extname": "h3", "extdir": "h3-pg-src"}
{"extname": "pg_repack", "extdir": "pg_repack-src"}
]'
EXTNAMES=$(echo ${EXTENSIONS} | jq -r '.[].extname' | paste -sd ' ' -)
COMPUTE_TAG=${NEW_COMPUTE_TAG} docker compose --profile test-extensions up --quiet-pull --build -d

View File

@@ -8,8 +8,6 @@ pub const DEFAULT_PG_LISTEN_PORT: u16 = 64000;
pub const DEFAULT_PG_LISTEN_ADDR: &str = formatcp!("127.0.0.1:{DEFAULT_PG_LISTEN_PORT}");
pub const DEFAULT_HTTP_LISTEN_PORT: u16 = 9898;
pub const DEFAULT_HTTP_LISTEN_ADDR: &str = formatcp!("127.0.0.1:{DEFAULT_HTTP_LISTEN_PORT}");
// TODO: gRPC is disabled by default for now, but the port is used in neon_local.
pub const DEFAULT_GRPC_LISTEN_PORT: u16 = 51051; // storage-broker already uses 50051
use std::collections::HashMap;
use std::num::{NonZeroU64, NonZeroUsize};
@@ -45,21 +43,6 @@ pub struct NodeMetadata {
pub other: HashMap<String, serde_json::Value>,
}
/// PostHog integration config.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct PostHogConfig {
/// PostHog project ID
pub project_id: String,
/// Server-side (private) API key
pub server_api_key: String,
/// Client-side (public) API key
pub client_api_key: String,
/// Private API URL
pub private_api_url: String,
/// Public API URL
pub public_api_url: String,
}
/// `pageserver.toml`
///
/// We use serde derive with `#[serde(default)]` to generate a deserializer
@@ -121,7 +104,6 @@ pub struct ConfigToml {
pub listen_pg_addr: String,
pub listen_http_addr: String,
pub listen_https_addr: Option<String>,
pub listen_grpc_addr: Option<String>,
pub ssl_key_file: Utf8PathBuf,
pub ssl_cert_file: Utf8PathBuf,
#[serde(with = "humantime_serde")]
@@ -141,7 +123,6 @@ pub struct ConfigToml {
pub http_auth_type: AuthType,
#[serde_as(as = "serde_with::DisplayFromStr")]
pub pg_auth_type: AuthType,
pub grpc_auth_type: AuthType,
pub auth_validation_public_key_path: Option<Utf8PathBuf>,
pub remote_storage: Option<RemoteStorageConfig>,
pub tenant_config: TenantConfigToml,
@@ -201,11 +182,7 @@ pub struct ConfigToml {
pub tracing: Option<Tracing>,
pub enable_tls_page_service_api: bool,
pub dev_mode: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub posthog_config: Option<PostHogConfig>,
pub timeline_import_config: TimelineImportConfig,
#[serde(skip_serializing_if = "Option::is_none")]
pub basebackup_cache_config: Option<BasebackupCacheConfig>,
}
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
@@ -331,26 +308,6 @@ pub struct TimelineImportConfig {
pub import_job_checkpoint_threshold: NonZeroUsize,
}
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[serde(default)]
pub struct BasebackupCacheConfig {
#[serde(with = "humantime_serde")]
pub cleanup_period: Duration,
// FIXME: Support max_size_bytes.
// pub max_size_bytes: usize,
pub max_size_entries: i64,
}
impl Default for BasebackupCacheConfig {
fn default() -> Self {
Self {
cleanup_period: Duration::from_secs(60),
// max_size_bytes: 1024 * 1024 * 1024, // 1 GiB
max_size_entries: 1000,
}
}
}
pub mod statvfs {
pub mod mock {
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
@@ -534,14 +491,8 @@ pub struct TenantConfigToml {
/// Tenant level performance sampling ratio override. Controls the ratio of get page requests
/// that will get perf sampling for the tenant.
pub sampling_ratio: Option<Ratio>,
/// Capacity of relsize snapshot cache (used by replicas).
pub relsize_snapshot_cache_capacity: usize,
/// Enable preparing basebackup on XLOG_CHECKPOINT_SHUTDOWN and using it in basebackup requests.
// FIXME: Remove skip_serializing_if when the feature is stable.
#[serde(skip_serializing_if = "std::ops::Not::not")]
pub basebackup_cache_enabled: bool,
}
pub mod defaults {
@@ -609,7 +560,6 @@ impl Default for ConfigToml {
listen_pg_addr: (DEFAULT_PG_LISTEN_ADDR.to_string()),
listen_http_addr: (DEFAULT_HTTP_LISTEN_ADDR.to_string()),
listen_https_addr: (None),
listen_grpc_addr: None, // TODO: default to 127.0.0.1:51051
ssl_key_file: Utf8PathBuf::from(DEFAULT_SSL_KEY_FILE),
ssl_cert_file: Utf8PathBuf::from(DEFAULT_SSL_CERT_FILE),
ssl_cert_reload_period: Duration::from_secs(60),
@@ -626,7 +576,6 @@ impl Default for ConfigToml {
pg_distrib_dir: None, // Utf8PathBuf::from("./pg_install"), // TODO: formely, this was std::env::current_dir()
http_auth_type: (AuthType::Trust),
pg_auth_type: (AuthType::Trust),
grpc_auth_type: (AuthType::Trust),
auth_validation_public_key_path: (None),
remote_storage: None,
broker_endpoint: (storage_broker::DEFAULT_ENDPOINT
@@ -717,8 +666,6 @@ impl Default for ConfigToml {
import_job_soft_size_limit: NonZeroUsize::new(1024 * 1024 * 1024).unwrap(),
import_job_checkpoint_threshold: NonZeroUsize::new(128).unwrap(),
},
basebackup_cache_config: None,
posthog_config: None,
}
}
}
@@ -844,7 +791,6 @@ impl Default for TenantConfigToml {
gc_compaction_ratio_percent: DEFAULT_GC_COMPACTION_RATIO_PERCENT,
sampling_ratio: None,
relsize_snapshot_cache_capacity: DEFAULT_RELSIZE_SNAPSHOT_CACHE_CAPACITY,
basebackup_cache_enabled: false,
}
}
}

View File

@@ -354,9 +354,6 @@ pub struct ShardImportProgressV1 {
pub completed: usize,
/// Hash of the plan
pub import_plan_hash: u64,
/// Soft limit for the job size
/// This needs to remain constant throughout the import
pub job_soft_size_limit: usize,
}
impl ShardImportStatus {
@@ -405,8 +402,6 @@ pub enum TimelineCreateRequestMode {
// using a flattened enum, so, it was an accepted field, and
// we continue to accept it by having it here.
pg_version: Option<u32>,
#[serde(default, skip_serializing_if = "std::ops::Not::not")]
read_only: bool,
},
ImportPgdata {
import_pgdata: TimelineCreateRequestModeImportPgdata,
@@ -637,8 +632,6 @@ pub struct TenantConfigPatch {
pub sampling_ratio: FieldPatch<Option<Ratio>>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub relsize_snapshot_cache_capacity: FieldPatch<usize>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub basebackup_cache_enabled: FieldPatch<bool>,
}
/// Like [`crate::config::TenantConfigToml`], but preserves the information
@@ -771,9 +764,6 @@ pub struct TenantConfig {
#[serde(skip_serializing_if = "Option::is_none")]
pub relsize_snapshot_cache_capacity: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub basebackup_cache_enabled: Option<bool>,
}
impl TenantConfig {
@@ -820,7 +810,6 @@ impl TenantConfig {
mut gc_compaction_ratio_percent,
mut sampling_ratio,
mut relsize_snapshot_cache_capacity,
mut basebackup_cache_enabled,
} = self;
patch.checkpoint_distance.apply(&mut checkpoint_distance);
@@ -925,9 +914,6 @@ impl TenantConfig {
patch
.relsize_snapshot_cache_capacity
.apply(&mut relsize_snapshot_cache_capacity);
patch
.basebackup_cache_enabled
.apply(&mut basebackup_cache_enabled);
Ok(Self {
checkpoint_distance,
@@ -968,7 +954,6 @@ impl TenantConfig {
gc_compaction_ratio_percent,
sampling_ratio,
relsize_snapshot_cache_capacity,
basebackup_cache_enabled,
})
}
@@ -1080,9 +1065,6 @@ impl TenantConfig {
relsize_snapshot_cache_capacity: self
.relsize_snapshot_cache_capacity
.unwrap_or(global_conf.relsize_snapshot_cache_capacity),
basebackup_cache_enabled: self
.basebackup_cache_enabled
.unwrap_or(global_conf.basebackup_cache_enabled),
}
}
}

View File

@@ -6,14 +6,9 @@ license.workspace = true
[dependencies]
anyhow.workspace = true
arc-swap.workspace = true
reqwest.workspace = true
serde_json.workspace = true
serde.workspace = true
serde_json.workspace = true
sha2.workspace = true
thiserror.workspace = true
tokio = { workspace = true, features = ["process", "sync", "fs", "rt", "io-util", "time"] }
tokio-util.workspace = true
tracing-utils.workspace = true
tracing.workspace = true
workspace_hack.workspace = true
thiserror.workspace = true

View File

@@ -1,59 +0,0 @@
//! A background loop that fetches feature flags from PostHog and updates the feature store.
use std::{sync::Arc, time::Duration};
use arc_swap::ArcSwap;
use tokio_util::sync::CancellationToken;
use crate::{FeatureStore, PostHogClient, PostHogClientConfig};
/// A background loop that fetches feature flags from PostHog and updates the feature store.
pub struct FeatureResolverBackgroundLoop {
posthog_client: PostHogClient,
feature_store: ArcSwap<FeatureStore>,
cancel: CancellationToken,
}
impl FeatureResolverBackgroundLoop {
pub fn new(config: PostHogClientConfig, shutdown_pageserver: CancellationToken) -> Self {
Self {
posthog_client: PostHogClient::new(config),
feature_store: ArcSwap::new(Arc::new(FeatureStore::new())),
cancel: shutdown_pageserver,
}
}
pub fn spawn(self: Arc<Self>, handle: &tokio::runtime::Handle, refresh_period: Duration) {
let this = self.clone();
let cancel = self.cancel.clone();
handle.spawn(async move {
tracing::info!("Starting PostHog feature resolver");
let mut ticker = tokio::time::interval(refresh_period);
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
_ = ticker.tick() => {}
_ = cancel.cancelled() => break
}
let resp = match this
.posthog_client
.get_feature_flags_local_evaluation()
.await
{
Ok(resp) => resp,
Err(e) => {
tracing::warn!("Cannot get feature flags: {}", e);
continue;
}
};
let feature_store = FeatureStore::new_with_flags(resp.flags);
this.feature_store.store(Arc::new(feature_store));
}
tracing::info!("PostHog feature resolver stopped");
});
}
pub fn feature_store(&self) -> Arc<FeatureStore> {
self.feature_store.load_full()
}
}

View File

@@ -1,9 +1,5 @@
//! A lite version of the PostHog client that only supports local evaluation of feature flags.
mod background_loop;
pub use background_loop::FeatureResolverBackgroundLoop;
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
@@ -24,7 +20,8 @@ pub enum PostHogEvaluationError {
#[derive(Deserialize)]
pub struct LocalEvaluationResponse {
pub flags: Vec<LocalEvaluationFlag>,
#[allow(dead_code)]
flags: Vec<LocalEvaluationFlag>,
}
#[derive(Deserialize)]
@@ -37,7 +34,7 @@ pub struct LocalEvaluationFlag {
#[derive(Deserialize)]
pub struct LocalEvaluationFlagFilters {
groups: Vec<LocalEvaluationFlagFilterGroup>,
multivariate: Option<LocalEvaluationFlagMultivariate>,
multivariate: LocalEvaluationFlagMultivariate,
}
#[derive(Deserialize)]
@@ -97,12 +94,6 @@ impl FeatureStore {
}
}
pub fn new_with_flags(flags: Vec<LocalEvaluationFlag>) -> Self {
let mut store = Self::new();
store.set_flags(flags);
store
}
pub fn set_flags(&mut self, flags: Vec<LocalEvaluationFlag>) {
self.flags.clear();
for flag in flags {
@@ -254,7 +245,7 @@ impl FeatureStore {
}
}
/// Evaluate a multivariate feature flag. Returns an error if the flag is not available or if there are errors
/// Evaluate a multivariate feature flag. Returns `None` if the flag is not available or if there are errors
/// during the evaluation.
///
/// The parsing logic is as follows:
@@ -272,15 +263,10 @@ impl FeatureStore {
/// Example: we have a multivariate flag with 3 groups of the configured global rollout percentage: A (10%), B (20%), C (70%).
/// There is a single group with a condition that has a rollout percentage of 10% and it does not have a variant override.
/// Then, we will have 1% of the users evaluated to A, 2% to B, and 7% to C.
///
/// Error handling: the caller should inspect the error and decide the behavior when a feature flag
/// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be
/// propagated beyond where the feature flag gets resolved.
pub fn evaluate_multivariate(
&self,
flag_key: &str,
user_id: &str,
properties: &HashMap<String, PostHogFlagFilterPropertyValue>,
) -> Result<String, PostHogEvaluationError> {
let hash_on_global_rollout_percentage =
Self::consistent_hash(user_id, flag_key, "multivariate");
@@ -290,39 +276,10 @@ impl FeatureStore {
flag_key,
hash_on_global_rollout_percentage,
hash_on_group_rollout_percentage,
properties,
&HashMap::new(),
)
}
/// Evaluate a boolean feature flag. Returns an error if the flag is not available or if there are errors
/// during the evaluation.
///
/// The parsing logic is as follows:
///
/// * Generate a consistent hash for the tenant-feature.
/// * Match each filter group.
/// - If a group is matched, it will first determine whether the user is in the range of the rollout
/// percentage.
/// - If the hash falls within the group's rollout percentage, return true.
/// * Otherwise, continue with the next group until all groups are evaluated and no group is within the
/// rollout percentage.
/// * If there are no matching groups, return an error.
///
/// Returns `Ok(())` if the feature flag evaluates to true. In the future, it will return a payload.
///
/// Error handling: the caller should inspect the error and decide the behavior when a feature flag
/// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be
/// propagated beyond where the feature flag gets resolved.
pub fn evaluate_boolean(
&self,
flag_key: &str,
user_id: &str,
properties: &HashMap<String, PostHogFlagFilterPropertyValue>,
) -> Result<(), PostHogEvaluationError> {
let hash_on_global_rollout_percentage = Self::consistent_hash(user_id, flag_key, "boolean");
self.evaluate_boolean_inner(flag_key, hash_on_global_rollout_percentage, properties)
}
/// Evaluate a multivariate feature flag. Note that we directly take the mapped user ID
/// (a consistent hash ranging from 0 to 1) so that it is easier to use it in the tests
/// and avoid duplicate computations.
@@ -349,11 +306,6 @@ impl FeatureStore {
flag_key
)));
}
let Some(ref multivariate) = flag_config.filters.multivariate else {
return Err(PostHogEvaluationError::Internal(format!(
"No multivariate available, should use evaluate_boolean?: {flag_key}"
)));
};
// TODO: sort the groups so that variant overrides always get evaluated first and it follows the PostHog
// Python SDK behavior; for now we do not configure conditions without variant overrides in Neon so it
// does not matter.
@@ -362,7 +314,7 @@ impl FeatureStore {
GroupEvaluationResult::MatchedAndOverride(variant) => return Ok(variant),
GroupEvaluationResult::MatchedAndEvaluate => {
let mut percentage = 0;
for variant in &multivariate.variants {
for variant in &flag_config.filters.multivariate.variants {
percentage += variant.rollout_percentage;
if self
.evaluate_percentage(hash_on_global_rollout_percentage, percentage)
@@ -390,77 +342,6 @@ impl FeatureStore {
)))
}
}
/// Evaluate a multivariate feature flag. Note that we directly take the mapped user ID
/// (a consistent hash ranging from 0 to 1) so that it is easier to use it in the tests
/// and avoid duplicate computations.
///
/// Use a different consistent hash for evaluating the group rollout percentage.
/// The behavior: if the condition is set to rolling out to 10% of the users, and
/// we set the variant A to 20% in the global config, then 2% of the total users will
/// be evaluated to variant A.
///
/// Note that the hash to determine group rollout percentage is shared across all groups. So if we have two
/// exactly-the-same conditions with 10% and 20% rollout percentage respectively, a total of 20% of the users
/// will be evaluated (versus 30% if group evaluation is done independently).
pub(crate) fn evaluate_boolean_inner(
&self,
flag_key: &str,
hash_on_global_rollout_percentage: f64,
properties: &HashMap<String, PostHogFlagFilterPropertyValue>,
) -> Result<(), PostHogEvaluationError> {
if let Some(flag_config) = self.flags.get(flag_key) {
if !flag_config.active {
return Err(PostHogEvaluationError::NotAvailable(format!(
"The feature flag is not active: {}",
flag_key
)));
}
if flag_config.filters.multivariate.is_some() {
return Err(PostHogEvaluationError::Internal(format!(
"This looks like a multivariate flag, should use evaluate_multivariate?: {flag_key}"
)));
};
// TODO: sort the groups so that variant overrides always get evaluated first and it follows the PostHog
// Python SDK behavior; for now we do not configure conditions without variant overrides in Neon so it
// does not matter.
for group in &flag_config.filters.groups {
match self.evaluate_group(group, hash_on_global_rollout_percentage, properties)? {
GroupEvaluationResult::MatchedAndOverride(_) => {
return Err(PostHogEvaluationError::Internal(format!(
"Boolean flag cannot have overrides: {}",
flag_key
)));
}
GroupEvaluationResult::MatchedAndEvaluate => {
return Ok(());
}
GroupEvaluationResult::Unmatched => continue,
}
}
// If no group is matched, the feature is not available, and up to the caller to decide what to do.
Err(PostHogEvaluationError::NoConditionGroupMatched)
} else {
// The feature flag is not available yet
Err(PostHogEvaluationError::NotAvailable(format!(
"Not found in the local evaluation spec: {}",
flag_key
)))
}
}
}
pub struct PostHogClientConfig {
/// The server API key.
pub server_api_key: String,
/// The client API key.
pub client_api_key: String,
/// The project ID.
pub project_id: String,
/// The private API URL.
pub private_api_url: String,
/// The public API URL.
pub public_api_url: String,
}
/// A lite PostHog client.
@@ -479,16 +360,37 @@ pub struct PostHogClientConfig {
/// want to report the feature flag usage back to PostHog. The current plan is to use PostHog only as an UI to
/// configure feature flags so it is very likely that the client API will not be used.
pub struct PostHogClient {
/// The config.
config: PostHogClientConfig,
/// The server API key.
server_api_key: String,
/// The client API key.
client_api_key: String,
/// The project ID.
project_id: String,
/// The private API URL.
private_api_url: String,
/// The public API URL.
public_api_url: String,
/// The HTTP client.
client: reqwest::Client,
}
impl PostHogClient {
pub fn new(config: PostHogClientConfig) -> Self {
pub fn new(
server_api_key: String,
client_api_key: String,
project_id: String,
private_api_url: String,
public_api_url: String,
) -> Self {
let client = reqwest::Client::new();
Self { config, client }
Self {
server_api_key,
client_api_key,
project_id,
private_api_url,
public_api_url,
client,
}
}
pub fn new_with_us_region(
@@ -496,13 +398,13 @@ impl PostHogClient {
client_api_key: String,
project_id: String,
) -> Self {
Self::new(PostHogClientConfig {
Self::new(
server_api_key,
client_api_key,
project_id,
private_api_url: "https://us.posthog.com".to_string(),
public_api_url: "https://us.i.posthog.com".to_string(),
})
"https://us.posthog.com".to_string(),
"https://us.i.posthog.com".to_string(),
)
}
/// Fetch the feature flag specs from the server.
@@ -520,12 +422,12 @@ impl PostHogClient {
// with bearer token of self.server_api_key
let url = format!(
"{}/api/projects/{}/feature_flags/local_evaluation",
self.config.private_api_url, self.config.project_id
self.private_api_url, self.project_id
);
let response = self
.client
.get(url)
.bearer_auth(&self.config.server_api_key)
.bearer_auth(&self.server_api_key)
.send()
.await?;
let body = response.text().await?;
@@ -544,11 +446,11 @@ impl PostHogClient {
) -> anyhow::Result<()> {
// PUBLIC_URL/capture/
// with bearer token of self.client_api_key
let url = format!("{}/capture/", self.config.public_api_url);
let url = format!("{}/capture/", self.public_api_url);
self.client
.post(url)
.body(serde_json::to_string(&json!({
"api_key": self.config.client_api_key,
"api_key": self.client_api_key,
"distinct_id": distinct_id,
"event": event,
"properties": properties,
@@ -565,162 +467,95 @@ mod tests {
fn data() -> &'static str {
r#"{
"flags": [
{
"id": 141807,
"team_id": 152860,
"name": "",
"key": "image-compaction-boundary",
"filters": {
"groups": [
{
"variant": null,
"properties": [
{
"key": "plan_type",
"type": "person",
"value": [
"free"
],
"operator": "exact"
}
"flags": [
{
"id": 132794,
"team_id": 152860,
"name": "",
"key": "gc-compaction",
"filters": {
"groups": [
{
"variant": "enabled-stage-2",
"properties": [
{
"key": "plan_type",
"type": "person",
"value": [
"free"
],
"operator": "exact"
},
{
"key": "pageserver_remote_size",
"type": "person",
"value": "10000000",
"operator": "lt"
}
],
"rollout_percentage": 50
},
{
"properties": [
{
"key": "plan_type",
"type": "person",
"value": [
"free"
],
"operator": "exact"
},
{
"key": "pageserver_remote_size",
"type": "person",
"value": "10000000",
"operator": "lt"
}
],
"rollout_percentage": 80
}
],
"payloads": {},
"multivariate": {
"variants": [
{
"key": "disabled",
"name": "",
"rollout_percentage": 90
},
{
"key": "enabled-stage-1",
"name": "",
"rollout_percentage": 10
},
{
"key": "enabled-stage-2",
"name": "",
"rollout_percentage": 0
},
{
"key": "enabled-stage-3",
"name": "",
"rollout_percentage": 0
},
{
"key": "enabled",
"name": "",
"rollout_percentage": 0
}
]
}
},
"deleted": false,
"active": true,
"ensure_experience_continuity": false,
"has_encrypted_payloads": false,
"version": 6
}
],
"rollout_percentage": 40
},
{
"variant": null,
"properties": [],
"rollout_percentage": 10
}
],
"payloads": {},
"multivariate": null
},
"deleted": false,
"active": true,
"ensure_experience_continuity": false,
"has_encrypted_payloads": false,
"version": 1
},
{
"id": 135586,
"team_id": 152860,
"name": "",
"key": "boolean-flag",
"filters": {
"groups": [
{
"variant": null,
"properties": [
{
"key": "plan_type",
"type": "person",
"value": [
"free"
],
"operator": "exact"
}
],
"rollout_percentage": 47
}
],
"payloads": {},
"multivariate": null
},
"deleted": false,
"active": true,
"ensure_experience_continuity": false,
"has_encrypted_payloads": false,
"version": 1
},
{
"id": 132794,
"team_id": 152860,
"name": "",
"key": "gc-compaction",
"filters": {
"groups": [
{
"variant": "enabled-stage-2",
"properties": [
{
"key": "plan_type",
"type": "person",
"value": [
"free"
],
"operator": "exact"
},
{
"key": "pageserver_remote_size",
"type": "person",
"value": "10000000",
"operator": "lt"
}
],
"rollout_percentage": 50
},
{
"properties": [
{
"key": "plan_type",
"type": "person",
"value": [
"free"
],
"operator": "exact"
},
{
"key": "pageserver_remote_size",
"type": "person",
"value": "10000000",
"operator": "lt"
}
],
"rollout_percentage": 80
}
],
"payloads": {},
"multivariate": {
"variants": [
{
"key": "disabled",
"name": "",
"rollout_percentage": 90
},
{
"key": "enabled-stage-1",
"name": "",
"rollout_percentage": 10
},
{
"key": "enabled-stage-2",
"name": "",
"rollout_percentage": 0
},
{
"key": "enabled-stage-3",
"name": "",
"rollout_percentage": 0
},
{
"key": "enabled",
"name": "",
"rollout_percentage": 0
}
]
}
},
"deleted": false,
"active": true,
"ensure_experience_continuity": false,
"has_encrypted_payloads": false,
"version": 7
}
],
"group_type_mapping": {},
"cohorts": {}
}"#
"group_type_mapping": {},
"cohorts": {}
}"#
}
#[test]
@@ -796,125 +631,4 @@ mod tests {
Err(PostHogEvaluationError::NoConditionGroupMatched)
),);
}
#[test]
fn evaluate_boolean_1() {
// The `boolean-flag` feature flag only has one group that matches on the free user.
let mut store = FeatureStore::new();
let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap();
store.set_flags(response.flags);
// This lacks the required properties and cannot be evaluated.
let variant = store.evaluate_boolean_inner("boolean-flag", 1.00, &HashMap::new());
assert!(matches!(
variant,
Err(PostHogEvaluationError::NotAvailable(_))
),);
let properties_unmatched = HashMap::from([
(
"plan_type".to_string(),
PostHogFlagFilterPropertyValue::String("paid".to_string()),
),
(
"pageserver_remote_size".to_string(),
PostHogFlagFilterPropertyValue::Number(1000.0),
),
]);
// This does not match any group so there will be an error.
let variant = store.evaluate_boolean_inner("boolean-flag", 1.00, &properties_unmatched);
assert!(matches!(
variant,
Err(PostHogEvaluationError::NoConditionGroupMatched)
),);
let properties = HashMap::from([
(
"plan_type".to_string(),
PostHogFlagFilterPropertyValue::String("free".to_string()),
),
(
"pageserver_remote_size".to_string(),
PostHogFlagFilterPropertyValue::Number(1000.0),
),
]);
// It matches the first group as 0.10 <= 0.50 and the properties are matched. Then it gets evaluated to the variant override.
let variant = store.evaluate_boolean_inner("boolean-flag", 0.10, &properties);
assert!(variant.is_ok());
// It matches the group conditions but not the group rollout percentage.
let variant = store.evaluate_boolean_inner("boolean-flag", 1.00, &properties);
assert!(matches!(
variant,
Err(PostHogEvaluationError::NoConditionGroupMatched)
),);
}
#[test]
fn evaluate_boolean_2() {
// The `image-compaction-boundary` feature flag has one group that matches on the free user and a group that matches on all users.
let mut store = FeatureStore::new();
let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap();
store.set_flags(response.flags);
// This lacks the required properties and cannot be evaluated.
let variant =
store.evaluate_boolean_inner("image-compaction-boundary", 1.00, &HashMap::new());
assert!(matches!(
variant,
Err(PostHogEvaluationError::NotAvailable(_))
),);
let properties_unmatched = HashMap::from([
(
"plan_type".to_string(),
PostHogFlagFilterPropertyValue::String("paid".to_string()),
),
(
"pageserver_remote_size".to_string(),
PostHogFlagFilterPropertyValue::Number(1000.0),
),
]);
// This does not match the filtered group but the all user group.
let variant =
store.evaluate_boolean_inner("image-compaction-boundary", 1.00, &properties_unmatched);
assert!(matches!(
variant,
Err(PostHogEvaluationError::NoConditionGroupMatched)
),);
let variant =
store.evaluate_boolean_inner("image-compaction-boundary", 0.05, &properties_unmatched);
assert!(variant.is_ok());
let properties = HashMap::from([
(
"plan_type".to_string(),
PostHogFlagFilterPropertyValue::String("free".to_string()),
),
(
"pageserver_remote_size".to_string(),
PostHogFlagFilterPropertyValue::Number(1000.0),
),
]);
// It matches the first group as 0.30 <= 0.40 and the properties are matched. Then it gets evaluated to the variant override.
let variant = store.evaluate_boolean_inner("image-compaction-boundary", 0.30, &properties);
assert!(variant.is_ok());
// It matches the group conditions but not the group rollout percentage.
let variant = store.evaluate_boolean_inner("image-compaction-boundary", 1.00, &properties);
assert!(matches!(
variant,
Err(PostHogEvaluationError::NoConditionGroupMatched)
),);
// It matches the second "all" group conditions.
let variant = store.evaluate_boolean_inner("image-compaction-boundary", 0.09, &properties);
assert!(variant.is_ok());
}
}

View File

@@ -9,12 +9,14 @@ use sha2::digest::FixedOutput;
use sha2::{Digest, Sha256};
use tokio::task::yield_now;
use crate::CSafeStr;
const NONCE_LENGTH: usize = 24;
/// The identifier of the SCRAM-SHA-256 SASL authentication mechanism.
pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
pub const SCRAM_SHA_256: &CSafeStr = CSafeStr::from_cstr(c"SCRAM-SHA-256");
/// The identifier of the SCRAM-SHA-256-PLUS SASL authentication mechanism.
pub const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
pub const SCRAM_SHA_256_PLUS: &CSafeStr = CSafeStr::from_cstr(c"SCRAM-SHA-256-PLUS");
// since postgres passwords are not required to exclude saslprep-prohibited
// characters or even be valid UTF8, we run saslprep if possible and otherwise

View File

@@ -11,7 +11,7 @@
//! set to `UTF8`. It will most likely not behave properly if that is not the case.
#![warn(missing_docs, clippy::all)]
use std::io;
use std::{ffi::CStr, io};
use byteorder::{BigEndian, ByteOrder};
use bytes::{BufMut, BytesMut};
@@ -36,20 +36,67 @@ pub enum IsNull {
No,
}
fn write_nullable<F, E>(serializer: F, buf: &mut BytesMut) -> Result<(), E>
/// A [`std::ffi::CStr`] but without the null byte.
#[repr(transparent)]
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct CSafeStr([u8]);
impl CSafeStr {
/// Create a new `CSafeStr`, erroring if the bytes contains a null.
pub fn new(bytes: &[u8]) -> Result<&Self, io::Error> {
let nul_pos = memchr::memchr(0, bytes);
match nul_pos {
Some(nul_pos) => Err(io::Error::other(format!(
"unexpected null byte at position {nul_pos}"
))),
None => {
// Safety: CSafeStr is transparent over [u8].
Ok(unsafe { std::mem::transmute(bytes) })
}
}
}
/// Create a new `CSafeStr` up until the next null.
pub fn take<'a>(bytes: &mut &'a [u8]) -> &'a Self {
let nul_pos = memchr::memchr(0, bytes).unwrap_or(bytes.len());
let bytes = bytes
.split_off(..nul_pos)
.expect("nul_pos should be in-bounds");
// Safety: CSafeStr is transparent over [u8].
unsafe { std::mem::transmute(bytes) }
}
/// Get the bytes of this CSafeStr.
pub const fn as_bytes(&self) -> &[u8] {
&self.0
}
/// Create a new `CSafeStr`
pub const fn from_cstr(s: &CStr) -> &CSafeStr {
// Safety: CSafeStr is transparent over [u8].
unsafe { std::mem::transmute(s.to_bytes()) }
}
}
impl<'a> From<&'a CStr> for &'a CSafeStr {
fn from(s: &'a CStr) -> &'a CSafeStr {
CSafeStr::from_cstr(s)
}
}
fn write_nullable<F>(serializer: F, buf: &mut BytesMut)
where
F: FnOnce(&mut BytesMut) -> Result<IsNull, E>,
E: From<io::Error>,
F: FnOnce(&mut BytesMut) -> IsNull,
{
let base = buf.len();
buf.put_i32(0);
let size = match serializer(buf)? {
IsNull::No => i32::from_usize(buf.len() - base - 4)?,
let size = match serializer(buf) {
// this is an unreasonable enough case that I think a panic is acceptable.
IsNull::No => i32::from_usize(buf.len() - base - 4)
.expect("buffer size should not be larger than i32::MAX"),
IsNull::Yes => -1,
};
BigEndian::write_i32(&mut buf[base..], size);
Ok(())
}
trait FromUsize: Sized {

View File

@@ -9,7 +9,7 @@ use bytes::{Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use memchr::memchr;
use crate::Oid;
use crate::{CSafeStr, Oid};
// top-level message tags
const PARSE_COMPLETE_TAG: u8 = b'1';
@@ -332,25 +332,24 @@ impl AuthenticationSaslBody {
pub struct SaslMechanisms<'a>(&'a [u8]);
impl<'a> FallibleIterator for SaslMechanisms<'a> {
type Item = &'a str;
type Item = &'a CSafeStr;
type Error = io::Error;
#[inline]
fn next(&mut self) -> io::Result<Option<&'a str>> {
let value_end = find_null(self.0, 0)?;
if value_end == 0 {
if self.0.len() != 1 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid message length: expected to be at end of iterator for sasl",
));
}
Ok(None)
} else {
let value = get_str(&self.0[..value_end])?;
self.0 = &self.0[value_end + 1..];
Ok(Some(value))
fn next(&mut self) -> io::Result<Option<&'a CSafeStr>> {
if self.0 == b"0" {
return Ok(None);
}
let value = CSafeStr::take(&mut self.0);
if value.as_bytes().len() == 0 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"invalid message length: expected to be at end of iterator for sasl",
));
}
Ok(Some(value))
}
}

View File

@@ -1,115 +1,73 @@
//! Frontend message serialization.
#![allow(missing_docs)]
use std::error::Error;
use std::{io, marker};
use std::io;
use byteorder::{BigEndian, ByteOrder};
use bytes::{Buf, BufMut, BytesMut};
use crate::{FromUsize, IsNull, Oid, write_nullable};
use crate::{CSafeStr, FromUsize, IsNull, Oid, write_nullable};
#[inline]
fn write_body<F, E>(buf: &mut BytesMut, f: F) -> Result<(), E>
fn write_body<F>(buf: &mut BytesMut, f: F)
where
F: FnOnce(&mut BytesMut) -> Result<(), E>,
E: From<io::Error>,
F: FnOnce(&mut BytesMut),
{
let base = buf.len();
buf.extend_from_slice(&[0; 4]);
f(buf)?;
f(buf);
let size = i32::from_usize(buf.len() - base)?;
let size =
i32::from_usize(buf.len() - base).expect("buffer size should not be larger than i32::MAX");
BigEndian::write_i32(&mut buf[base..], size);
Ok(())
}
#[derive(Debug)]
pub enum BindError {
Conversion(Box<dyn Error + marker::Sync + Send>),
Serialization(io::Error),
}
impl From<Box<dyn Error + marker::Sync + Send>> for BindError {
#[inline]
fn from(e: Box<dyn Error + marker::Sync + Send>) -> BindError {
BindError::Conversion(e)
}
}
impl From<io::Error> for BindError {
#[inline]
fn from(e: io::Error) -> BindError {
BindError::Serialization(e)
}
}
#[inline]
pub fn bind<I, J, F, T, K>(
portal: &str,
statement: &str,
portal: &CSafeStr,
statement: &CSafeStr,
formats: I,
values: J,
mut serializer: F,
result_formats: K,
buf: &mut BytesMut,
) -> Result<(), BindError>
where
) where
I: IntoIterator<Item = i16>,
J: IntoIterator<Item = T>,
F: FnMut(T, &mut BytesMut) -> Result<IsNull, Box<dyn Error + marker::Sync + Send>>,
F: FnMut(T, &mut BytesMut) -> IsNull,
K: IntoIterator<Item = i16>,
{
buf.put_u8(b'B');
write_body(buf, |buf| {
write_cstr(portal.as_bytes(), buf)?;
write_cstr(statement.as_bytes(), buf)?;
write_counted(
formats,
|f, buf| {
buf.put_i16(f);
Ok::<_, io::Error>(())
},
buf,
)?;
write_cstr(portal, buf);
write_cstr(statement, buf);
write_counted(formats, |f, buf| buf.put_i16(f), buf);
write_counted(
values,
|v, buf| write_nullable(|buf| serializer(v, buf), buf),
buf,
)?;
write_counted(
result_formats,
|f, buf| {
buf.put_i16(f);
Ok::<_, io::Error>(())
},
buf,
)?;
Ok(())
);
write_counted(result_formats, |f, buf| buf.put_i16(f), buf);
})
}
#[inline]
fn write_counted<I, T, F, E>(items: I, mut serializer: F, buf: &mut BytesMut) -> Result<(), E>
fn write_counted<I, T, F>(items: I, mut serializer: F, buf: &mut BytesMut)
where
I: IntoIterator<Item = T>,
F: FnMut(T, &mut BytesMut) -> Result<(), E>,
E: From<io::Error>,
F: FnMut(T, &mut BytesMut),
{
let base = buf.len();
buf.extend_from_slice(&[0; 2]);
let mut count = 0;
for item in items {
serializer(item, buf)?;
serializer(item, buf);
count += 1;
}
let count = i16::from_usize(count)?;
let count = i16::from_usize(count).expect("list should not exceed 32767 items");
BigEndian::write_i16(&mut buf[base..], count);
Ok(())
}
#[inline]
@@ -118,17 +76,15 @@ pub fn cancel_request(process_id: i32, secret_key: i32, buf: &mut BytesMut) {
buf.put_i32(80_877_102);
buf.put_i32(process_id);
buf.put_i32(secret_key);
Ok::<_, io::Error>(())
})
.unwrap();
}
#[inline]
pub fn close(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
pub fn close(variant: u8, name: &CSafeStr, buf: &mut BytesMut) {
buf.put_u8(b'C');
write_body(buf, |buf| {
buf.put_u8(variant);
write_cstr(name.as_bytes(), buf)
write_cstr(name, buf)
})
}
@@ -163,85 +119,75 @@ where
#[inline]
pub fn copy_done(buf: &mut BytesMut) {
buf.put_u8(b'c');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
write_body(buf, |_| {});
}
#[inline]
pub fn copy_fail(message: &str, buf: &mut BytesMut) -> io::Result<()> {
pub fn copy_fail(message: &CSafeStr, buf: &mut BytesMut) {
buf.put_u8(b'f');
write_body(buf, |buf| write_cstr(message.as_bytes(), buf))
write_body(buf, |buf| write_cstr(message, buf))
}
#[inline]
pub fn describe(variant: u8, name: &str, buf: &mut BytesMut) -> io::Result<()> {
pub fn describe(variant: u8, name: &CSafeStr, buf: &mut BytesMut) {
buf.put_u8(b'D');
write_body(buf, |buf| {
buf.put_u8(variant);
write_cstr(name.as_bytes(), buf)
write_cstr(name, buf)
})
}
#[inline]
pub fn execute(portal: &str, max_rows: i32, buf: &mut BytesMut) -> io::Result<()> {
pub fn execute(portal: &CSafeStr, max_rows: i32, buf: &mut BytesMut) {
buf.put_u8(b'E');
write_body(buf, |buf| {
write_cstr(portal.as_bytes(), buf)?;
write_cstr(portal, buf);
buf.put_i32(max_rows);
Ok(())
})
}
#[inline]
pub fn parse<I>(name: &str, query: &str, param_types: I, buf: &mut BytesMut) -> io::Result<()>
pub fn parse<I>(name: &CSafeStr, query: &CSafeStr, param_types: I, buf: &mut BytesMut)
where
I: IntoIterator<Item = Oid>,
{
buf.put_u8(b'P');
write_body(buf, |buf| {
write_cstr(name.as_bytes(), buf)?;
write_cstr(query.as_bytes(), buf)?;
write_counted(
param_types,
|t, buf| {
buf.put_u32(t);
Ok::<_, io::Error>(())
},
buf,
)?;
Ok(())
write_cstr(name, buf);
write_cstr(query, buf);
write_counted(param_types, |t, buf| buf.put_u32(t), buf);
})
}
#[inline]
pub fn password_message(password: &[u8], buf: &mut BytesMut) -> io::Result<()> {
pub fn password_message(password: &CSafeStr, buf: &mut BytesMut) {
buf.put_u8(b'p');
write_body(buf, |buf| write_cstr(password, buf))
}
#[inline]
pub fn query(query: &str, buf: &mut BytesMut) -> io::Result<()> {
pub fn query(query: &CSafeStr, buf: &mut BytesMut) {
buf.put_u8(b'Q');
write_body(buf, |buf| write_cstr(query.as_bytes(), buf))
write_body(buf, |buf| write_cstr(query, buf))
}
#[inline]
pub fn sasl_initial_response(mechanism: &str, data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
pub fn sasl_initial_response(mechanism: &CSafeStr, data: &[u8], buf: &mut BytesMut) {
buf.put_u8(b'p');
write_body(buf, |buf| {
write_cstr(mechanism.as_bytes(), buf)?;
let len = i32::from_usize(data.len())?;
write_cstr(mechanism, buf);
let len =
i32::from_usize(data.len()).expect("sasl data should not be larger than i32::MAX");
buf.put_i32(len);
buf.put_slice(data);
Ok(())
})
}
#[inline]
pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
pub fn sasl_response(data: &[u8], buf: &mut BytesMut) {
buf.put_u8(b'p');
write_body(buf, |buf| {
buf.put_slice(data);
Ok(())
})
}
@@ -249,19 +195,16 @@ pub fn sasl_response(data: &[u8], buf: &mut BytesMut) -> io::Result<()> {
pub fn ssl_request(buf: &mut BytesMut) {
write_body(buf, |buf| {
buf.put_i32(80_877_103);
Ok::<_, io::Error>(())
})
.unwrap();
});
}
#[inline]
pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) -> io::Result<()> {
pub fn startup_message(parameters: &StartupMessageParams, buf: &mut BytesMut) {
write_body(buf, |buf| {
// postgres protocol version 3.0(196608) in bigger-endian
buf.put_i32(0x00_03_00_00);
buf.put_slice(&parameters.params);
buf.put_u8(0);
Ok(())
})
}
@@ -272,10 +215,7 @@ pub struct StartupMessageParams {
impl StartupMessageParams {
/// Set parameter's value by its name.
pub fn insert(&mut self, name: &str, value: &str) {
if name.contains('\0') || value.contains('\0') {
panic!("startup parameter name or value contained a null")
}
pub fn insert(&mut self, name: &CSafeStr, value: &CSafeStr) {
self.params.put_slice(name.as_bytes());
self.params.put_u8(0);
self.params.put_slice(value.as_bytes());
@@ -286,30 +226,17 @@ impl StartupMessageParams {
#[inline]
pub fn sync(buf: &mut BytesMut) {
buf.put_u8(b'S');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
}
#[inline]
pub fn flush(buf: &mut BytesMut) {
buf.put_u8(b'H');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
write_body(buf, |_| {});
}
#[inline]
pub fn terminate(buf: &mut BytesMut) {
buf.put_u8(b'X');
write_body(buf, |_| Ok::<(), io::Error>(())).unwrap();
write_body(buf, |_| {});
}
#[inline]
fn write_cstr(s: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> {
if s.contains(&0) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"string contains embedded null",
));
}
buf.put_slice(s);
fn write_cstr(s: &CSafeStr, buf: &mut BytesMut) {
buf.put_slice(s.as_bytes());
buf.put_u8(0);
Ok(())
}

View File

@@ -9,10 +9,11 @@ use std::error::Error;
use std::fmt;
use std::sync::Arc;
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
#[doc(inline)]
pub use postgres_protocol2::Oid;
use postgres_protocol2::types;
use postgres_protocol2::{IsNull, types};
use crate::type_gen::{Inner, Other};
@@ -26,6 +27,20 @@ macro_rules! accepts {
)
}
/// Generates an implementation of `ToSql::to_sql_checked`.
///
/// All `ToSql` implementations should use this macro.
macro_rules! to_sql_checked {
() => {
fn check(&self, ty: &Type) -> ::std::result::Result<(), $crate::WrongType> {
if !<Self as $crate::ToSql>::accepts(ty) {
return Err($crate::WrongType::new::<Self>(ty.clone()));
}
Ok(())
}
};
}
// mod pg_lsn;
#[doc(hidden)]
pub mod private;
@@ -106,7 +121,7 @@ pub enum Kind {
/// An array type along with the type of its elements.
Array(Type),
/// A range type along with the type of its elements.
Range(Oid),
Range(Type),
/// A multirange type along with the type of its elements.
Multirange(Type),
/// A domain type along with its underlying type.
@@ -333,12 +348,35 @@ macro_rules! simple_from {
simple_from!(i8, char_from_sql, CHAR);
simple_from!(u32, oid_from_sql, OID);
/// An enum representing the nullability of a Postgres value.
pub enum IsNull {
/// The value is NULL.
Yes,
/// The value is not NULL.
No,
/// A trait for types that can be converted into Postgres values.
pub trait ToSql: fmt::Debug {
/// Converts the value of `self` into the binary format of the specified
/// Postgres `Type`, appending it to `out`.
///
/// The caller of this method is responsible for ensuring that this type
/// is compatible with the Postgres `Type`.
///
/// The return value indicates if this value should be represented as
/// `NULL`. If this is the case, implementations **must not** write
/// anything to `out`.
fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> IsNull;
/// Determines if a value of this type can be converted to the specified
/// Postgres `Type`.
fn accepts(ty: &Type) -> bool
where
Self: Sized;
/// An adaptor method used internally by Rust-Postgres.
///
/// *All* implementations of this method should be generated by the
/// `to_sql_checked!()` macro.
fn check(&self, ty: &Type) -> Result<(), WrongType>;
/// Specify the encode format
fn encode_format(&self, _ty: &Type) -> Format {
Format::Binary
}
}
/// Supported Postgres message format types
@@ -351,3 +389,49 @@ pub enum Format {
/// Compact, typed binary format
Binary,
}
impl ToSql for &str {
fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> IsNull {
match *ty {
ref ty if ty.name() == "ltree" => types::ltree_to_sql(self, w),
ref ty if ty.name() == "lquery" => types::lquery_to_sql(self, w),
ref ty if ty.name() == "ltxtquery" => types::ltxtquery_to_sql(self, w),
_ => types::text_to_sql(self, w),
}
IsNull::No
}
fn accepts(ty: &Type) -> bool {
match *ty {
Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true,
ref ty
if (ty.name() == "citext"
|| ty.name() == "ltree"
|| ty.name() == "lquery"
|| ty.name() == "ltxtquery") =>
{
true
}
_ => false,
}
}
to_sql_checked!();
}
macro_rules! simple_to {
($t:ty, $f:ident, $($expected:ident),+) => {
impl ToSql for $t {
fn to_sql(&self, _: &Type, w: &mut BytesMut) -> IsNull {
types::$f(*self, w);
IsNull::No
}
accepts!($($expected),+);
to_sql_checked!();
}
}
}
simple_to!(u32, oid_to_sql, OID);

View File

@@ -393,7 +393,7 @@ impl Inner {
}
}
pub const fn const_oid(&self) -> Oid {
pub fn oid(&self) -> Oid {
match *self {
Inner::Bool => 16,
Inner::Bytea => 17,
@@ -580,14 +580,7 @@ impl Inner {
Inner::TstzmultiRangeArray => 6153,
Inner::DatemultiRangeArray => 6155,
Inner::Int8multiRangeArray => 6157,
Inner::Other(_) => u32::MAX,
}
}
pub fn oid(&self) -> Oid {
match *self {
Inner::Other(ref u) => u.oid,
_ => self.const_oid(),
}
}
@@ -734,17 +727,17 @@ impl Inner {
Inner::JsonbArray => &Kind::Array(Type(Inner::Jsonb)),
Inner::AnyRange => &Kind::Pseudo,
Inner::EventTrigger => &Kind::Pseudo,
Inner::Int4Range => &const { Kind::Range(Inner::Int4.const_oid()) },
Inner::Int4Range => &Kind::Range(Type(Inner::Int4)),
Inner::Int4RangeArray => &Kind::Array(Type(Inner::Int4Range)),
Inner::NumRange => &const { Kind::Range(Inner::Numeric.const_oid()) },
Inner::NumRange => &Kind::Range(Type(Inner::Numeric)),
Inner::NumRangeArray => &Kind::Array(Type(Inner::NumRange)),
Inner::TsRange => &const { Kind::Range(Inner::Timestamp.const_oid()) },
Inner::TsRange => &Kind::Range(Type(Inner::Timestamp)),
Inner::TsRangeArray => &Kind::Array(Type(Inner::TsRange)),
Inner::TstzRange => &const { Kind::Range(Inner::Timestamptz.const_oid()) },
Inner::TstzRange => &Kind::Range(Type(Inner::Timestamptz)),
Inner::TstzRangeArray => &Kind::Array(Type(Inner::TstzRange)),
Inner::DateRange => &const { Kind::Range(Inner::Date.const_oid()) },
Inner::DateRange => &Kind::Range(Type(Inner::Date)),
Inner::DateRangeArray => &Kind::Array(Type(Inner::DateRange)),
Inner::Int8Range => &const { Kind::Range(Inner::Int8.const_oid()) },
Inner::Int8Range => &Kind::Range(Type(Inner::Int8)),
Inner::Int8RangeArray => &Kind::Array(Type(Inner::Int8Range)),
Inner::Jsonpath => &Kind::Simple,
Inner::JsonpathArray => &Kind::Array(Type(Inner::Jsonpath)),

View File

@@ -1,12 +1,14 @@
use std::collections::HashMap;
use std::fmt;
use std::net::IpAddr;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{TryStreamExt, future, ready};
use parking_lot::Mutex;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use serde::{Deserialize, Serialize};
@@ -14,52 +16,29 @@ use tokio::sync::mpsc;
use crate::codec::{BackendMessages, FrontendMessage};
use crate::config::{Host, SslMode};
use crate::connection::{Request, RequestMessages};
use crate::query::RowStream;
use crate::simple_query::SimpleQueryStream;
use crate::types::{Oid, Type};
use crate::{
CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, Transaction, TransactionBuilder,
query, simple_query,
CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, Statement, Transaction,
TransactionBuilder, query, simple_query,
};
pub struct Responses {
/// new messages from conn
receiver: mpsc::Receiver<BackendMessages>,
/// current batch of messages
cur: BackendMessages,
/// number of total queries sent.
waiting: usize,
/// number of ReadyForQuery messages received.
received: usize,
}
impl Responses {
pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<Result<Message, Error>> {
loop {
// get the next saved message
if let Some(message) = self.cur.next().map_err(Error::parse)? {
let received = self.received;
// increase the query head if this is the last message.
if let Message::ReadyForQuery(_) = message {
self.received += 1;
}
// check if the client has skipped this query.
if received + 1 < self.waiting {
// grab the next message.
continue;
}
// convenience: turn the error messaage into a proper error.
let res = match message {
Message::ErrorResponse(body) => Err(Error::db(body)),
message => Ok(message),
};
return Poll::Ready(res);
match self.cur.next().map_err(Error::parse)? {
Some(Message::ErrorResponse(body)) => return Poll::Ready(Err(Error::db(body))),
Some(message) => return Poll::Ready(Ok(message)),
None => {}
}
// get the next batch of messages.
match ready!(self.receiver.poll_recv(cx)) {
Some(messages) => self.cur = messages,
None => return Poll::Ready(Err(Error::closed())),
@@ -76,87 +55,44 @@ impl Responses {
/// (corresponding to the queries in the [crate::prepare] module).
#[derive(Default)]
pub(crate) struct CachedTypeInfo {
/// A statement for basic information for a type from its
/// OID. Corresponds to [TYPEINFO_QUERY](crate::prepare::TYPEINFO_QUERY) (or its
/// fallback).
pub(crate) typeinfo: Option<Statement>,
/// Cache of types already looked up.
pub(crate) types: HashMap<Oid, Type>,
}
pub struct InnerClient {
sender: mpsc::UnboundedSender<FrontendMessage>,
responses: Responses,
sender: mpsc::UnboundedSender<Request>,
/// A buffer to use when writing out postgres commands.
buffer: BytesMut,
buffer: Mutex<BytesMut>,
}
impl InnerClient {
pub fn start(&mut self) -> Result<PartialQuery, Error> {
self.responses.waiting += 1;
Ok(PartialQuery(Some(self)))
pub fn send(&self, messages: RequestMessages) -> Result<Responses, Error> {
let (sender, receiver) = mpsc::channel(1);
let request = Request { messages, sender };
self.sender.send(request).map_err(|_| Error::closed())?;
Ok(Responses {
receiver,
cur: BackendMessages::empty(),
})
}
// pub fn send_with_sync<F>(&mut self, f: F) -> Result<&mut Responses, Error>
// where
// F: FnOnce(&mut BytesMut) -> Result<(), Error>,
// {
// self.start()?.send_with_sync(f)
// }
pub fn send_simple_query(&mut self, query: &str) -> Result<&mut Responses, Error> {
self.responses.waiting += 1;
self.buffer.clear();
// simple queries do not need sync.
frontend::query(query, &mut self.buffer).map_err(Error::encode)?;
let buf = self.buffer.split().freeze();
self.send_message(FrontendMessage::Raw(buf))
}
fn send_message(&mut self, messages: FrontendMessage) -> Result<&mut Responses, Error> {
self.sender.send(messages).map_err(|_| Error::closed())?;
Ok(&mut self.responses)
}
}
pub struct PartialQuery<'a>(Option<&'a mut InnerClient>);
impl Drop for PartialQuery<'_> {
fn drop(&mut self) {
if let Some(client) = self.0.take() {
client.buffer.clear();
frontend::sync(&mut client.buffer);
let buf = client.buffer.split().freeze();
let _ = client.send_message(FrontendMessage::Raw(buf));
}
}
}
impl<'a> PartialQuery<'a> {
pub fn send_with_flush<F>(&mut self, f: F) -> Result<&mut Responses, Error>
/// Call the given function with a buffer to be used when writing out
/// postgres commands.
pub fn with_buf<F, R>(&self, f: F) -> R
where
F: FnOnce(&mut BytesMut) -> Result<(), Error>,
F: FnOnce(&mut BytesMut) -> R,
{
let client = self.0.as_deref_mut().unwrap();
client.buffer.clear();
f(&mut client.buffer)?;
frontend::flush(&mut client.buffer);
let buf = client.buffer.split().freeze();
client.send_message(FrontendMessage::Raw(buf))
}
pub fn send_with_sync<F>(mut self, f: F) -> Result<&'a mut Responses, Error>
where
F: FnOnce(&mut BytesMut) -> Result<(), Error>,
{
let client = self.0.as_deref_mut().unwrap();
client.buffer.clear();
f(&mut client.buffer)?;
frontend::sync(&mut client.buffer);
let buf = client.buffer.split().freeze();
let _ = client.send_message(FrontendMessage::Raw(buf));
Ok(&mut self.0.take().unwrap().responses)
let mut buffer = self.buffer.lock();
let r = f(&mut buffer);
buffer.clear();
r
}
}
@@ -173,7 +109,7 @@ pub struct SocketConfig {
/// The client is one half of what is returned when a connection is established. Users interact with the database
/// through this client object.
pub struct Client {
inner: InnerClient,
inner: Arc<InnerClient>,
cached_typeinfo: CachedTypeInfo,
socket_config: SocketConfig,
@@ -184,24 +120,17 @@ pub struct Client {
impl Client {
pub(crate) fn new(
sender: mpsc::UnboundedSender<FrontendMessage>,
receiver: mpsc::Receiver<BackendMessages>,
sender: mpsc::UnboundedSender<Request>,
socket_config: SocketConfig,
ssl_mode: SslMode,
process_id: i32,
secret_key: i32,
) -> Client {
Client {
inner: InnerClient {
inner: Arc::new(InnerClient {
sender,
responses: Responses {
receiver,
cur: BackendMessages::empty(),
waiting: 0,
received: 0,
},
buffer: Default::default(),
},
}),
cached_typeinfo: Default::default(),
socket_config,
@@ -216,29 +145,19 @@ impl Client {
self.process_id
}
pub(crate) fn inner_mut(&mut self) -> &mut InnerClient {
&mut self.inner
pub(crate) fn inner(&self) -> &Arc<InnerClient> {
&self.inner
}
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
/// to save a roundtrip
pub async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream, Error>
pub async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
query::query_txt(
&mut self.inner,
&mut self.cached_typeinfo,
statement,
params,
)
.await
query::query_txt(&self.inner, statement, params).await
}
/// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
@@ -254,15 +173,12 @@ impl Client {
/// Prepared statements should be use for any query which contains user-specified data, as they provided the
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
/// them to this method!
pub async fn simple_query(&mut self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
self.simple_query_raw(query).await?.try_collect().await
}
pub(crate) async fn simple_query_raw(
&mut self,
query: &str,
) -> Result<SimpleQueryStream, Error> {
simple_query::simple_query(self.inner_mut(), query).await
pub(crate) async fn simple_query_raw(&self, query: &str) -> Result<SimpleQueryStream, Error> {
simple_query::simple_query(self.inner(), query).await
}
/// Executes a sequence of SQL statements using the simple query protocol.
@@ -275,11 +191,15 @@ impl Client {
/// Prepared statements should be use for any query which contains user-specified data, as they provided the
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
/// them to this method!
pub async fn batch_execute(&mut self, query: &str) -> Result<ReadyForQueryStatus, Error> {
simple_query::batch_execute(self.inner_mut(), query).await
pub async fn batch_execute(&self, query: &str) -> Result<ReadyForQueryStatus, Error> {
simple_query::batch_execute(self.inner(), query).await
}
pub async fn discard_all(&mut self) -> Result<ReadyForQueryStatus, Error> {
// clear the prepared statements that are about to be nuked from the postgres session
self.cached_typeinfo.typeinfo = None;
self.batch_execute("discard all").await
}
@@ -288,7 +208,7 @@ impl Client {
/// The transaction will roll back by default - use the `commit` method to commit it.
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
struct RollbackIfNotDone<'me> {
client: &'me mut Client,
client: &'me Client,
done: bool,
}
@@ -298,7 +218,14 @@ impl Client {
return;
}
let _ = self.client.inner.send_simple_query("ROLLBACK");
let buf = self.client.inner().with_buf(|buf| {
frontend::query(c"ROLLBACK".into(), buf);
buf.split().freeze()
});
let _ = self
.client
.inner()
.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
}
}
@@ -312,7 +239,7 @@ impl Client {
client: self,
done: false,
};
cleaner.client.batch_execute("BEGIN").await?;
self.batch_execute("BEGIN").await?;
cleaner.done = true;
}
@@ -338,6 +265,11 @@ impl Client {
}
}
/// Query for type information
pub(crate) async fn get_type_inner(&mut self, oid: Oid) -> Result<Type, Error> {
crate::prepare::get_type(&self.inner, &mut self.cached_typeinfo, oid).await
}
/// Determines if the connection to the server has already closed.
///
/// In that case, all future queries will fail.

View File

@@ -1,16 +1,21 @@
use std::io;
use bytes::{Bytes, BytesMut};
use bytes::{Buf, Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use postgres_protocol2::message::backend;
use postgres_protocol2::message::frontend::CopyData;
use tokio_util::codec::{Decoder, Encoder};
pub enum FrontendMessage {
Raw(Bytes),
CopyData(CopyData<Box<dyn Buf + Send>>),
}
pub enum BackendMessage {
Normal { messages: BackendMessages },
Normal {
messages: BackendMessages,
request_complete: bool,
},
Async(backend::Message),
}
@@ -39,6 +44,7 @@ impl Encoder<FrontendMessage> for PostgresCodec {
fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> {
match item {
FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf),
FrontendMessage::CopyData(data) => data.write(dst),
}
Ok(())
@@ -51,6 +57,7 @@ impl Decoder for PostgresCodec {
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<BackendMessage>, io::Error> {
let mut idx = 0;
let mut request_complete = false;
while let Some(header) = backend::Header::parse(&src[idx..])? {
let len = header.len() as usize + 1;
@@ -75,6 +82,7 @@ impl Decoder for PostgresCodec {
idx += len;
if header.tag() == backend::READY_FOR_QUERY_TAG {
request_complete = true;
break;
}
}
@@ -84,6 +92,7 @@ impl Decoder for PostgresCodec {
} else {
Ok(Some(BackendMessage::Normal {
messages: BackendMessages(src.split_to(idx)),
request_complete,
}))
}
}

View File

@@ -4,6 +4,7 @@ use std::net::IpAddr;
use std::time::Duration;
use std::{fmt, str};
use postgres_protocol2::CSafeStr;
pub use postgres_protocol2::authentication::sasl::ScramKeys;
use postgres_protocol2::message::frontend::StartupMessageParams;
use serde::{Deserialize, Serialize};
@@ -162,7 +163,10 @@ impl Config {
self.username = true;
}
self.server_params.insert(name, value);
self.server_params.insert(
CSafeStr::new(name.as_bytes()).expect("param name should not contain a null"),
CSafeStr::new(value.as_bytes()).expect("param name should not contain a null"),
);
self
}

View File

@@ -59,11 +59,9 @@ where
connect_timeout: config.connect_timeout,
};
let (client_tx, conn_rx) = mpsc::unbounded_channel();
let (conn_tx, client_rx) = mpsc::channel(4);
let (sender, receiver) = mpsc::unbounded_channel();
let client = Client::new(
client_tx,
client_rx,
sender,
socket_config,
config.ssl_mode,
process_id,
@@ -76,7 +74,7 @@ where
.map(|m| BackendMessage::Async(Message::NoticeResponse(m)))
.collect();
let connection = Connection::new(stream, delayed, parameters, conn_tx, conn_rx);
let connection = Connection::new(stream, delayed, parameters, receiver);
Ok((client, connection))
}

View File

@@ -6,6 +6,7 @@ use std::task::{Context, Poll};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{Sink, SinkExt, Stream, TryStreamExt, ready};
use postgres_protocol2::CSafeStr;
use postgres_protocol2::authentication::sasl;
use postgres_protocol2::authentication::sasl::ScramSha256;
use postgres_protocol2::message::backend::{AuthenticationSaslBody, Message, NoticeResponseBody};
@@ -122,7 +123,7 @@ where
T: AsyncRead + AsyncWrite + Unpin,
{
let mut buf = BytesMut::new();
frontend::startup_message(&config.server_params, &mut buf).map_err(Error::encode)?;
frontend::startup_message(&config.server_params, &mut buf);
stream
.send(FrontendMessage::Raw(buf.freeze()))
@@ -193,7 +194,7 @@ where
T: AsyncRead + AsyncWrite + Unpin,
{
let mut buf = BytesMut::new();
frontend::password_message(password, &mut buf).map_err(Error::encode)?;
frontend::password_message(CSafeStr::new(password).map_err(Error::encode)?, &mut buf);
stream
.send(FrontendMessage::Raw(buf.freeze()))
@@ -214,10 +215,10 @@ where
let mut has_scram_plus = false;
let mut mechanisms = body.mechanisms();
while let Some(mechanism) = mechanisms.next().map_err(Error::parse)? {
match mechanism {
sasl::SCRAM_SHA_256 => has_scram = true,
sasl::SCRAM_SHA_256_PLUS => has_scram_plus = true,
_ => {}
if mechanism == sasl::SCRAM_SHA_256 {
has_scram = true;
} else if mechanism == sasl::SCRAM_SHA_256_PLUS {
has_scram_plus = true;
}
}
@@ -256,7 +257,7 @@ where
};
let mut buf = BytesMut::new();
frontend::sasl_initial_response(mechanism, scram.message(), &mut buf).map_err(Error::encode)?;
frontend::sasl_initial_response(mechanism, scram.message(), &mut buf);
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await
@@ -275,7 +276,7 @@ where
.map_err(|e| Error::authentication(e.into()))?;
let mut buf = BytesMut::new();
frontend::sasl_response(scram.message(), &mut buf).map_err(Error::encode)?;
frontend::sasl_response(scram.message(), &mut buf);
stream
.send(FrontendMessage::Raw(buf.freeze()))
.await

View File

@@ -4,6 +4,7 @@ use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{Sink, Stream, ready};
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
@@ -18,12 +19,30 @@ use crate::error::DbError;
use crate::maybe_tls_stream::MaybeTlsStream;
use crate::{AsyncMessage, Error, Notification};
pub enum RequestMessages {
Single(FrontendMessage),
}
pub struct Request {
pub messages: RequestMessages,
pub sender: mpsc::Sender<BackendMessages>,
}
pub struct Response {
sender: PollSender<BackendMessages>,
}
#[derive(PartialEq, Debug)]
enum State {
Active,
Closing,
}
enum WriteReady {
Terminating,
WaitingOnRead,
}
/// A connection to a PostgreSQL database.
///
/// This is one half of what is returned when a new connection is established. It performs the actual IO with the
@@ -37,11 +56,9 @@ pub struct Connection<S, T> {
pub stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
/// HACK: we need this in the Neon Proxy to forward params.
pub parameters: HashMap<String, String>,
sender: PollSender<BackendMessages>,
receiver: mpsc::UnboundedReceiver<FrontendMessage>,
receiver: mpsc::UnboundedReceiver<Request>,
pending_responses: VecDeque<BackendMessage>,
responses: VecDeque<Response>,
state: State,
}
@@ -54,15 +71,14 @@ where
stream: Framed<MaybeTlsStream<S, T>, PostgresCodec>,
pending_responses: VecDeque<BackendMessage>,
parameters: HashMap<String, String>,
sender: mpsc::Sender<BackendMessages>,
receiver: mpsc::UnboundedReceiver<FrontendMessage>,
receiver: mpsc::UnboundedReceiver<Request>,
) -> Connection<S, T> {
Connection {
stream,
parameters,
sender: PollSender::new(sender),
receiver,
pending_responses,
responses: VecDeque::new(),
state: State::Active,
}
}
@@ -94,7 +110,7 @@ where
}
};
let messages = match message {
let (mut messages, request_complete) = match message {
BackendMessage::Async(Message::NoticeResponse(body)) => {
let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?;
return Poll::Ready(Ok(AsyncMessage::Notice(error)));
@@ -115,19 +131,41 @@ where
continue;
}
BackendMessage::Async(_) => unreachable!(),
BackendMessage::Normal { messages } => messages,
BackendMessage::Normal {
messages,
request_complete,
} => (messages, request_complete),
};
match self.sender.poll_reserve(cx) {
let mut response = match self.responses.pop_front() {
Some(response) => response,
None => match messages.next().map_err(Error::parse)? {
Some(Message::ErrorResponse(error)) => {
return Poll::Ready(Err(Error::db(error)));
}
_ => return Poll::Ready(Err(Error::unexpected_message())),
},
};
match response.sender.poll_reserve(cx) {
Poll::Ready(Ok(())) => {
let _ = self.sender.send_item(messages);
let _ = response.sender.send_item(messages);
if !request_complete {
self.responses.push_front(response);
}
}
Poll::Ready(Err(_)) => {
return Poll::Ready(Err(Error::closed()));
// we need to keep paging through the rest of the messages even if the receiver's hung up
if !request_complete {
self.responses.push_front(response);
}
}
Poll::Pending => {
self.pending_responses
.push_back(BackendMessage::Normal { messages });
self.responses.push_front(response);
self.pending_responses.push_back(BackendMessage::Normal {
messages,
request_complete,
});
trace!("poll_read: waiting on sender");
return Poll::Pending;
}
@@ -136,7 +174,7 @@ where
}
/// Fetch the next client request and enqueue the response sender.
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<FrontendMessage>> {
fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll<Option<RequestMessages>> {
if self.receiver.is_closed() {
return Poll::Ready(None);
}
@@ -144,7 +182,10 @@ where
match self.receiver.poll_recv(cx) {
Poll::Ready(Some(request)) => {
trace!("polled new request");
Poll::Ready(Some(request))
self.responses.push_back(Response {
sender: PollSender::new(request.sender),
});
Poll::Ready(Some(request.messages))
}
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
@@ -153,7 +194,7 @@ where
/// Process client requests and write them to the postgres connection, flushing if necessary.
/// client -> postgres
fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<Result<WriteReady, Error>> {
loop {
if Pin::new(&mut self.stream)
.poll_ready(cx)
@@ -168,14 +209,14 @@ where
match self.poll_request(cx) {
// send the message to postgres
Poll::Ready(Some(request)) => {
Poll::Ready(Some(RequestMessages::Single(request))) => {
Pin::new(&mut self.stream)
.start_send(request)
.map_err(Error::io)?;
}
// No more messages from the client, and no more responses to wait for.
// Send a terminate message to postgres
Poll::Ready(None) => {
Poll::Ready(None) if self.responses.is_empty() => {
trace!("poll_write: at eof, terminating");
let mut request = BytesMut::new();
frontend::terminate(&mut request);
@@ -187,7 +228,16 @@ where
trace!("poll_write: sent eof, closing");
trace!("poll_write: done");
return Poll::Ready(Ok(()));
return Poll::Ready(Ok(WriteReady::Terminating));
}
// No more messages from the client, but there are still some responses to wait for.
Poll::Ready(None) => {
trace!(
"poll_write: at eof, pending responses {}",
self.responses.len()
);
ready!(self.poll_flush(cx))?;
return Poll::Ready(Ok(WriteReady::WaitingOnRead));
}
// Still waiting for a message from the client.
Poll::Pending => {
@@ -248,7 +298,7 @@ where
// if the state is still active, try read from and write to postgres.
let message = self.poll_read(cx)?;
let closing = self.poll_write(cx)?;
if let Poll::Ready(()) = closing {
if let Poll::Ready(WriteReady::Terminating) = closing {
self.state = State::Closing;
}

View File

@@ -86,27 +86,6 @@ pub struct DbError {
}
impl DbError {
pub fn new_test_error(code: SqlState, message: String) -> Self {
DbError {
severity: "ERROR".to_string(),
parsed_severity: Some(Severity::Error),
code,
message,
detail: None,
hint: None,
position: None,
where_: None,
schema: None,
table: None,
column: None,
datatype: None,
constraint: None,
file: None,
line: None,
routine: None,
}
}
pub(crate) fn parse(fields: &mut ErrorFields<'_>) -> io::Result<DbError> {
let mut severity = None;
let mut parsed_severity = None;

View File

@@ -1,6 +1,9 @@
#![allow(async_fn_in_trait)]
use postgres_protocol2::Oid;
use crate::query::RowStream;
use crate::types::Type;
use crate::{Client, Error, Transaction};
mod private {
@@ -12,17 +15,20 @@ mod private {
/// This trait is "sealed", and cannot be implemented outside of this crate.
pub trait GenericClient: private::Sealed {
/// Like `Client::query_raw_txt`.
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send;
/// Query for type information
async fn get_type(&mut self, oid: Oid) -> Result<Type, Error>;
}
impl private::Sealed for Client {}
impl GenericClient for Client {
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
@@ -30,12 +36,17 @@ impl GenericClient for Client {
{
self.query_raw_txt(statement, params).await
}
/// Query for type information
async fn get_type(&mut self, oid: Oid) -> Result<Type, Error> {
self.get_type_inner(oid).await
}
}
impl private::Sealed for Transaction<'_> {}
impl GenericClient for Transaction<'_> {
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
@@ -43,4 +54,9 @@ impl GenericClient for Transaction<'_> {
{
self.query_raw_txt(statement, params).await
}
/// Query for type information
async fn get_type(&mut self, oid: Oid) -> Result<Type, Error> {
self.client_mut().get_type(oid).await
}
}

View File

@@ -18,6 +18,7 @@ pub use crate::statement::{Column, Statement};
pub use crate::tls::NoTls;
pub use crate::transaction::Transaction;
pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder};
use crate::types::ToSql;
/// After executing a query, the connection will be in one of these states
#[derive(Clone, Copy, Debug, PartialEq)]
@@ -119,3 +120,9 @@ pub enum SimpleQueryMessage {
/// The number of rows modified or selected is returned.
CommandComplete(u64),
}
fn slice_iter<'a>(
s: &'a [&'a (dyn ToSql + Sync)],
) -> impl ExactSizeIterator<Item = &'a (dyn ToSql + Sync)> + Clone + 'a {
s.iter().map(|s| *s as _)
}

View File

@@ -1,16 +1,22 @@
use bytes::BytesMut;
use std::ffi::CStr;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use bytes::Bytes;
use fallible_iterator::FallibleIterator;
use postgres_protocol2::IsNull;
use postgres_protocol2::message::backend::{Message, RowDescriptionBody};
use futures_util::{TryStreamExt, pin_mut};
use postgres_protocol2::CSafeStr;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use postgres_protocol2::types::oid_to_sql;
use postgres_types2::Format;
use crate::client::{CachedTypeInfo, PartialQuery, Responses};
use crate::client::{CachedTypeInfo, InnerClient};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::types::{Kind, Oid, Type};
use crate::{Column, Error, Row, Statement};
use crate::{Column, Error, Statement, query, slice_iter};
pub(crate) const TYPEINFO_QUERY: &str = "\
pub(crate) const TYPEINFO_QUERY: &CStr = c"\
SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid
FROM pg_catalog.pg_type t
LEFT OUTER JOIN pg_catalog.pg_range r ON r.rngtypid = t.oid
@@ -18,51 +24,22 @@ INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid
WHERE t.oid = $1
";
/// we need to make sure we close this prepared statement.
struct CloseStmt<'a, 'b> {
client: Option<&'a mut PartialQuery<'b>>,
name: &'static str,
}
impl<'a> CloseStmt<'a, '_> {
fn close(mut self) -> Result<&'a mut Responses, Error> {
let client = self.client.take().unwrap();
client.send_with_flush(|buf| {
frontend::close(b'S', self.name, buf).map_err(Error::encode)?;
Ok(())
})
}
}
impl Drop for CloseStmt<'_, '_> {
fn drop(&mut self) {
if let Some(client) = self.client.take() {
let _ = client.send_with_flush(|buf| {
frontend::close(b'S', self.name, buf).map_err(Error::encode)?;
Ok(())
});
}
}
}
async fn prepare_typecheck(
client: &mut PartialQuery<'_>,
name: &'static str,
query: &str,
client: &Arc<InnerClient>,
name: &'static CStr,
query: &CSafeStr,
types: &[Type],
) -> Result<Statement, Error> {
let responses = client.send_with_flush(|buf| {
frontend::parse(name, query, [], buf).map_err(Error::encode)?;
frontend::describe(b'S', name, buf).map_err(Error::encode)?;
Ok(())
})?;
let buf = encode(client, name.into(), query, types)?;
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? {
Message::ParseComplete => {}
_ => return Err(Error::unexpected_message()),
}
match responses.next().await? {
Message::ParameterDescription(_) => {}
let parameter_description = match responses.next().await? {
Message::ParameterDescription(body) => body,
_ => return Err(Error::unexpected_message()),
};
@@ -72,6 +49,13 @@ async fn prepare_typecheck(
_ => return Err(Error::unexpected_message()),
};
let mut parameters = vec![];
let mut it = parameter_description.parameters();
while let Some(oid) = it.next().map_err(Error::parse)? {
let type_ = Type::from_oid(oid).ok_or_else(Error::unexpected_message)?;
parameters.push(type_);
}
let mut columns = vec![];
if let Some(row_description) = row_description {
let mut it = row_description.fields();
@@ -82,168 +66,103 @@ async fn prepare_typecheck(
}
}
Ok(Statement::new(name, columns))
Ok(Statement::new(client, name, parameters, columns))
}
fn try_from_cache(typecache: &CachedTypeInfo, oid: Oid) -> Option<Type> {
fn encode(
client: &InnerClient,
name: &CSafeStr,
query: &CSafeStr,
types: &[Type],
) -> Result<Bytes, Error> {
// if types.is_empty() {
// debug!("preparing query {}: {}", name, query);
// } else {
// debug!("preparing query {} with types {:?}: {}", name, types, query);
// }
client.with_buf(|buf| {
frontend::parse(name, query, types.iter().map(Type::oid), buf);
frontend::describe(b'S', name, buf);
frontend::sync(buf);
Ok(buf.split().freeze())
})
}
pub async fn get_type(
client: &Arc<InnerClient>,
typecache: &mut CachedTypeInfo,
oid: Oid,
) -> Result<Type, Error> {
if let Some(type_) = Type::from_oid(oid) {
return Some(type_);
return Ok(type_);
}
if let Some(type_) = typecache.types.get(&oid) {
return Some(type_.clone());
return Ok(type_.clone());
};
None
}
let stmt = typeinfo_statement(client, typecache).await?;
pub async fn parse_row_description(
client: &mut PartialQuery<'_>,
typecache: &mut CachedTypeInfo,
row_description: Option<RowDescriptionBody>,
) -> Result<Vec<Column>, Error> {
let mut columns = vec![];
let rows = query::query(client, stmt, slice_iter(&[&oid])).await?;
pin_mut!(rows);
if let Some(row_description) = row_description {
let mut it = row_description.fields();
while let Some(field) = it.next().map_err(Error::parse)? {
let type_ = try_from_cache(typecache, field.type_oid()).unwrap_or(Type::UNKNOWN);
let column = Column::new(field.name().to_string(), type_, field);
columns.push(column);
}
}
let all_known = columns.iter().all(|c| c.type_ != Type::UNKNOWN);
if all_known {
// all known, return early.
return Ok(columns);
}
let typeinfo = "neon_proxy_typeinfo";
// make sure to close the typeinfo statement before exiting.
let mut guard = CloseStmt {
name: typeinfo,
client: None,
};
let client = guard.client.insert(client);
// get the typeinfo statement.
let stmt = prepare_typecheck(client, typeinfo, TYPEINFO_QUERY).await?;
for column in &mut columns {
column.type_ = get_type(client, typecache, &stmt, column.type_oid()).await?;
}
// cancel the close guard.
let responses = guard.close()?;
match responses.next().await? {
Message::CloseComplete => {}
_ => return Err(Error::unexpected_message()),
}
Ok(columns)
}
async fn get_type(
client: &mut PartialQuery<'_>,
typecache: &mut CachedTypeInfo,
stmt: &Statement,
mut oid: Oid,
) -> Result<Type, Error> {
let mut stack = vec![];
let mut type_ = loop {
if let Some(type_) = try_from_cache(typecache, oid) {
break type_;
}
let row = exec(client, stmt, oid).await?;
if stack.len() > 8 {
return Err(Error::unexpected_message());
}
let name: String = row.try_get(0)?;
let type_: i8 = row.try_get(1)?;
let elem_oid: Oid = row.try_get(2)?;
let rngsubtype: Option<Oid> = row.try_get(3)?;
let basetype: Oid = row.try_get(4)?;
let schema: String = row.try_get(5)?;
let relid: Oid = row.try_get(6)?;
let kind = if type_ == b'e' as i8 {
Kind::Enum
} else if type_ == b'p' as i8 {
Kind::Pseudo
} else if basetype != 0 {
Kind::Domain(basetype)
} else if elem_oid != 0 {
stack.push((name, oid, schema));
oid = elem_oid;
continue;
} else if relid != 0 {
Kind::Composite(relid)
} else if let Some(rngsubtype) = rngsubtype {
Kind::Range(rngsubtype)
} else {
Kind::Simple
};
let type_ = Type::new(name, oid, kind, schema);
typecache.types.insert(oid, type_.clone());
break type_;
let row = match rows.try_next().await? {
Some(row) => row,
None => return Err(Error::unexpected_message()),
};
while let Some((name, oid, schema)) = stack.pop() {
type_ = Type::new(name, oid, Kind::Array(type_), schema);
typecache.types.insert(oid, type_.clone());
}
let name: String = row.try_get(0)?;
let type_: i8 = row.try_get(1)?;
let elem_oid: Oid = row.try_get(2)?;
let rngsubtype: Option<Oid> = row.try_get(3)?;
let basetype: Oid = row.try_get(4)?;
let schema: String = row.try_get(5)?;
let relid: Oid = row.try_get(6)?;
let kind = if type_ == b'e' as i8 {
Kind::Enum
} else if type_ == b'p' as i8 {
Kind::Pseudo
} else if basetype != 0 {
Kind::Domain(basetype)
} else if elem_oid != 0 {
let type_ = get_type_rec(client, typecache, elem_oid).await?;
Kind::Array(type_)
} else if relid != 0 {
Kind::Composite(relid)
} else if let Some(rngsubtype) = rngsubtype {
let type_ = get_type_rec(client, typecache, rngsubtype).await?;
Kind::Range(type_)
} else {
Kind::Simple
};
let type_ = Type::new(name, oid, kind, schema);
typecache.types.insert(oid, type_.clone());
Ok(type_)
}
/// exec the typeinfo statement returning one row.
async fn exec(
client: &mut PartialQuery<'_>,
statement: &Statement,
param: Oid,
) -> Result<Row, Error> {
let responses = client.send_with_flush(|buf| {
encode_bind(statement, param, "", buf);
frontend::execute("", 0, buf).map_err(Error::encode)?;
Ok(())
})?;
fn get_type_rec<'a>(
client: &'a Arc<InnerClient>,
typecache: &'a mut CachedTypeInfo,
oid: Oid,
) -> Pin<Box<dyn Future<Output = Result<Type, Error>> + Send + 'a>> {
Box::pin(get_type(client, typecache, oid))
}
match responses.next().await? {
Message::BindComplete => {}
_ => return Err(Error::unexpected_message()),
async fn typeinfo_statement(
client: &Arc<InnerClient>,
typecache: &mut CachedTypeInfo,
) -> Result<Statement, Error> {
if let Some(stmt) = &typecache.typeinfo {
return Ok(stmt.clone());
}
let row = match responses.next().await? {
Message::DataRow(body) => Row::new(statement.clone(), body, Format::Binary)?,
_ => return Err(Error::unexpected_message()),
};
let typeinfo = c"neon_proxy_typeinfo";
let stmt = prepare_typecheck(client, typeinfo, TYPEINFO_QUERY.into(), &[]).await?;
match responses.next().await? {
Message::CommandComplete(_) => {}
_ => return Err(Error::unexpected_message()),
};
Ok(row)
}
fn encode_bind(statement: &Statement, param: Oid, portal: &str, buf: &mut BytesMut) {
frontend::bind(
portal,
statement.name(),
[Format::Binary as i16],
[param],
|param, buf| {
oid_to_sql(param, buf);
Ok(IsNull::No)
},
[Format::Binary as i16],
buf,
)
.unwrap();
typecache.typeinfo = Some(stmt.clone());
Ok(stmt)
}

View File

@@ -1,61 +1,104 @@
use std::iter;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::BufMut;
use bytes::{BufMut, Bytes, BytesMut};
use fallible_iterator::FallibleIterator;
use futures_util::{Stream, ready};
use pin_project_lite::pin_project;
use postgres_protocol2::CSafeStr;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use postgres_types2::Format;
use postgres_types2::{Format, ToSql, Type};
use crate::client::{CachedTypeInfo, InnerClient, Responses};
use crate::{Error, ReadyForQueryStatus, Row, Statement};
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::{Column, Error, ReadyForQueryStatus, Row, Statement};
pub async fn query_txt<'a, S, I>(
client: &'a mut InnerClient,
typecache: &mut CachedTypeInfo,
pub async fn query<'a, I>(
client: &InnerClient,
statement: Statement,
params: I,
) -> Result<RowStream, Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator + Clone,
{
let buf = encode(client, &statement, params)?;
let responses = start(client, buf).await?;
Ok(RowStream {
statement,
responses,
command_tag: None,
status: ReadyForQueryStatus::Unknown,
output_format: Format::Binary,
_p: PhantomPinned,
})
}
pub async fn query_txt<S, I>(
client: &Arc<InnerClient>,
query: &str,
params: I,
) -> Result<RowStream<'a>, Error>
) -> Result<RowStream, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
let query = CSafeStr::new(query.as_bytes()).map_err(Error::encode)?;
let params = params.into_iter();
let mut client = client.start()?;
// Flow:
// 1. Parse the query
// 2. Inspect the row description for OIDs
// 3. If there's any OIDs we don't already know about, perform the typeinfo routine
// 4. Execute the query
// 5. Sync.
//
// The typeinfo routine:
// 1. Parse the typeinfo query
// 2. Execute the query on each OID
// 3. If the result does not match an OID we know, repeat 2.
let portal = c"".into(); // unnamed portal
let statement = c"".into(); // unnamed prepared statement
// parse the query and get type info
let responses = client.send_with_flush(|buf| {
let buf = client.with_buf(|buf| {
frontend::parse(
"", // unnamed prepared statement
query, // query to parse
std::iter::empty(), // give no type info
statement,
query, // query to parse
iter::empty(), // give no type info
buf,
)
.map_err(Error::encode)?;
frontend::describe(b'S', "", buf).map_err(Error::encode)?;
Ok(())
})?;
);
frontend::describe(b'S', statement, buf);
// Bind, pass params as text, retrieve as test
frontend::bind(
portal,
statement,
iter::empty(), // all parameters use the default format (text)
params,
|param, buf| match param {
Some(param) => {
buf.put_slice(param.as_ref().as_bytes());
postgres_protocol2::IsNull::No
}
None => postgres_protocol2::IsNull::Yes,
},
Some(0), // all text
buf,
);
// Execute
frontend::execute(portal, 0, buf);
// Sync
frontend::sync(buf);
buf.split().freeze()
});
// now read the responses
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? {
Message::ParseComplete => {}
_ => return Err(Error::unexpected_message()),
}
match responses.next().await? {
Message::ParameterDescription(_) => {}
let parameter_description = match responses.next().await? {
Message::ParameterDescription(body) => body,
_ => return Err(Error::unexpected_message()),
};
@@ -65,82 +108,139 @@ where
_ => return Err(Error::unexpected_message()),
};
let columns =
crate::prepare::parse_row_description(&mut client, typecache, row_description).await?;
match responses.next().await? {
Message::BindComplete => {}
_ => return Err(Error::unexpected_message()),
}
let responses = client.send_with_sync(|buf| {
// Bind, pass params as text, retrieve as text
match frontend::bind(
"", // empty string selects the unnamed portal
"", // unnamed prepared statement
std::iter::empty(), // all parameters use the default format (text)
params,
|param, buf| match param {
Some(param) => {
buf.put_slice(param.as_ref().as_bytes());
Ok(postgres_protocol2::IsNull::No)
}
None => Ok(postgres_protocol2::IsNull::Yes),
},
Some(0), // all text
buf,
) {
Ok(()) => Ok(()),
Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)),
Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
}?;
let mut parameters = vec![];
let mut it = parameter_description.parameters();
while let Some(oid) = it.next().map_err(Error::parse)? {
let type_ = Type::from_oid(oid).unwrap_or(Type::UNKNOWN);
parameters.push(type_);
}
// Execute
frontend::execute("", 0, buf).map_err(Error::encode)?;
let mut columns = vec![];
if let Some(row_description) = row_description {
let mut it = row_description.fields();
while let Some(field) = it.next().map_err(Error::parse)? {
let type_ = Type::from_oid(field.type_oid()).unwrap_or(Type::UNKNOWN);
let column = Column::new(field.name().to_string(), type_, field);
columns.push(column);
}
}
Ok(())
})?;
Ok(RowStream {
statement: Statement::new_anonymous(parameters, columns),
responses,
command_tag: None,
status: ReadyForQueryStatus::Unknown,
output_format: Format::Text,
_p: PhantomPinned,
})
}
async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
match responses.next().await? {
Message::BindComplete => {}
_ => return Err(Error::unexpected_message()),
}
Ok(RowStream {
responses,
statement: Statement::new("", columns),
command_tag: None,
status: ReadyForQueryStatus::Unknown,
output_format: Format::Text,
Ok(responses)
}
pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator + Clone,
{
let portal = c"".into(); // unnamed portal
client.with_buf(|buf| {
encode_bind(statement, params, portal, buf)?;
frontend::execute(portal, 0, buf);
frontend::sync(buf);
Ok(buf.split().freeze())
})
}
/// A stream of table rows.
pub struct RowStream<'a> {
responses: &'a mut Responses,
output_format: Format,
pub statement: Statement,
pub command_tag: Option<String>,
pub status: ReadyForQueryStatus,
pub fn encode_bind<'a, I>(
statement: &Statement,
params: I,
portal: &CSafeStr,
buf: &mut BytesMut,
) -> Result<(), Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator + Clone,
{
let param_types = statement.params();
let params = iter::zip(params.into_iter(), param_types);
assert!(
param_types.len() == params.len(),
"expected {} parameters but got {}",
param_types.len(),
params.len()
);
// check encodings
for (i, (p, ty)) in params.clone().enumerate() {
p.check(ty).map_err(|e| Error::to_sql(Box::new(e), i))?
}
let param_formats = params.clone().map(|(p, ty)| p.encode_format(ty) as i16);
frontend::bind(
portal,
statement.name(),
param_formats,
params,
|(param, ty), buf| param.to_sql(ty, buf),
Some(1),
buf,
);
Ok(())
}
impl Stream for RowStream<'_> {
pin_project! {
/// A stream of table rows.
pub struct RowStream {
statement: Statement,
responses: Responses,
command_tag: Option<String>,
output_format: Format,
status: ReadyForQueryStatus,
#[pin]
_p: PhantomPinned,
}
}
impl Stream for RowStream {
type Item = Result<Row, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();
let this = self.project();
loop {
match ready!(this.responses.poll_next(cx)?) {
Message::DataRow(body) => {
return Poll::Ready(Some(Ok(Row::new(
this.statement.clone(),
body,
this.output_format,
*this.output_format,
)?)));
}
Message::EmptyQueryResponse | Message::PortalSuspended => {}
Message::CommandComplete(body) => {
if let Ok(tag) = body.tag() {
this.command_tag = Some(tag.to_string());
*this.command_tag = Some(tag.to_string());
}
}
Message::ReadyForQuery(status) => {
this.status = status.into();
*this.status = status.into();
return Poll::Ready(None);
}
_ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
@@ -148,3 +248,24 @@ impl Stream for RowStream<'_> {
}
}
}
impl RowStream {
/// Returns information about the columns of data in the row.
pub fn columns(&self) -> &[Column] {
self.statement.columns()
}
/// Returns the command tag of this query.
///
/// This is only available after the stream has been exhausted.
pub fn command_tag(&self) -> Option<String> {
self.command_tag.clone()
}
/// Returns if the connection is ready for querying, with the status of the connection.
///
/// This might be available only after the stream has been exhausted.
pub fn ready_status(&self) -> ReadyForQueryStatus {
self.status
}
}

View File

@@ -1,14 +1,20 @@
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::Bytes;
use fallible_iterator::FallibleIterator;
use futures_util::{Stream, ready};
use pin_project_lite::pin_project;
use postgres_protocol2::CSafeStr;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use tracing::debug;
use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::{Error, ReadyForQueryStatus, SimpleQueryMessage, SimpleQueryRow};
/// Information about a column of a single query row.
@@ -28,28 +34,28 @@ impl SimpleColumn {
}
}
pub async fn simple_query<'a>(
client: &'a mut InnerClient,
query: &str,
) -> Result<SimpleQueryStream<'a>, Error> {
pub async fn simple_query(client: &InnerClient, query: &str) -> Result<SimpleQueryStream, Error> {
debug!("executing simple query: {}", query);
let responses = client.send_simple_query(query)?;
let buf = encode(client, query)?;
let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
Ok(SimpleQueryStream {
responses,
columns: None,
status: ReadyForQueryStatus::Unknown,
_p: PhantomPinned,
})
}
pub async fn batch_execute(
client: &mut InnerClient,
client: &InnerClient,
query: &str,
) -> Result<ReadyForQueryStatus, Error> {
debug!("executing statement batch: {}", query);
let responses = client.send_simple_query(query)?;
let buf = encode(client, query)?;
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
loop {
match responses.next().await? {
@@ -63,16 +69,26 @@ pub async fn batch_execute(
}
}
pub(crate) fn encode(client: &InnerClient, query: &str) -> Result<Bytes, Error> {
let query = CSafeStr::new(query.as_bytes()).map_err(Error::encode)?;
client.with_buf(|buf| {
frontend::query(query, buf);
Ok(buf.split().freeze())
})
}
pin_project! {
/// A stream of simple query results.
pub struct SimpleQueryStream<'a> {
responses: &'a mut Responses,
pub struct SimpleQueryStream {
responses: Responses,
columns: Option<Arc<[SimpleColumn]>>,
status: ReadyForQueryStatus,
#[pin]
_p: PhantomPinned,
}
}
impl SimpleQueryStream<'_> {
impl SimpleQueryStream {
/// Returns if the connection is ready for querying, with the status of the connection.
///
/// This might be available only after the stream has been exhausted.
@@ -81,7 +97,7 @@ impl SimpleQueryStream<'_> {
}
}
impl Stream for SimpleQueryStream<'_> {
impl Stream for SimpleQueryStream {
type Item = Result<SimpleQueryMessage, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {

View File

@@ -1,15 +1,36 @@
use std::ffi::CStr;
use std::fmt;
use std::sync::Arc;
use std::sync::{Arc, Weak};
use crate::types::Type;
use postgres_protocol2::Oid;
use postgres_protocol2::message::backend::Field;
use postgres_protocol2::message::frontend;
use postgres_protocol2::{CSafeStr, Oid};
use crate::client::InnerClient;
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::types::Type;
struct StatementInner {
name: &'static str,
client: Weak<InnerClient>,
name: &'static CStr,
params: Vec<Type>,
columns: Vec<Column>,
}
impl Drop for StatementInner {
fn drop(&mut self) {
if let Some(client) = self.client.upgrade() {
let buf = client.with_buf(|buf| {
frontend::close(b'S', self.name.into(), buf);
frontend::sync(buf);
buf.split().freeze()
});
let _ = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
}
}
}
/// A prepared statement.
///
/// Prepared statements can only be used with the connection that created them.
@@ -17,12 +38,36 @@ struct StatementInner {
pub struct Statement(Arc<StatementInner>);
impl Statement {
pub(crate) fn new(name: &'static str, columns: Vec<Column>) -> Statement {
Statement(Arc::new(StatementInner { name, columns }))
pub(crate) fn new(
inner: &Arc<InnerClient>,
name: &'static CStr,
params: Vec<Type>,
columns: Vec<Column>,
) -> Statement {
Statement(Arc::new(StatementInner {
client: Arc::downgrade(inner),
name,
params,
columns,
}))
}
pub(crate) fn name(&self) -> &str {
self.0.name
pub(crate) fn new_anonymous(params: Vec<Type>, columns: Vec<Column>) -> Statement {
Statement(Arc::new(StatementInner {
client: Weak::new(),
name: c"<anonymous>",
params,
columns,
}))
}
pub(crate) fn name(&self) -> &CSafeStr {
self.0.name.into()
}
/// Returns the expected types of the statement's parameters.
pub fn params(&self) -> &[Type] {
&self.0.params
}
/// Returns information about the columns returned when the statement is queried.
@@ -34,7 +79,7 @@ impl Statement {
/// Information about a column of a query.
pub struct Column {
name: String,
pub(crate) type_: Type,
type_: Type,
// raw fields from RowDescription
table_oid: Oid,

View File

@@ -1,3 +1,7 @@
use postgres_protocol2::message::frontend;
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::query::RowStream;
use crate::{CancelToken, Client, Error, ReadyForQueryStatus};
@@ -16,7 +20,14 @@ impl Drop for Transaction<'_> {
return;
}
let _ = self.client.inner_mut().send_simple_query("ROLLBACK");
let buf = self.client.inner().with_buf(|buf| {
frontend::query(c"ROLLBACK".into(), buf);
buf.split().freeze()
});
let _ = self
.client
.inner()
.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
}
}
@@ -43,11 +54,7 @@ impl<'a> Transaction<'a> {
}
/// Like `Client::query_raw_txt`.
pub async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream, Error>
pub async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,

View File

@@ -1,7 +1,6 @@
#![allow(clippy::todo)]
use std::ffi::CString;
use std::str::FromStr;
use postgres_ffi::WAL_SEGMENT_SIZE;
use utils::id::TenantTimelineId;
@@ -174,8 +173,6 @@ pub struct Config {
pub ttid: TenantTimelineId,
/// List of safekeepers in format `host:port`
pub safekeepers_list: Vec<String>,
/// libpq connection info options
pub safekeeper_conninfo_options: String,
/// Safekeeper reconnect timeout in milliseconds
pub safekeeper_reconnect_timeout: i32,
/// Safekeeper connection timeout in milliseconds
@@ -205,9 +202,6 @@ impl Wrapper {
.into_bytes_with_nul();
assert!(safekeepers_list_vec.len() == safekeepers_list_vec.capacity());
let safekeepers_list = safekeepers_list_vec.as_mut_ptr() as *mut std::ffi::c_char;
let safekeeper_conninfo_options = CString::from_str(&config.safekeeper_conninfo_options)
.unwrap()
.into_raw();
let callback_data = Box::into_raw(Box::new(api)) as *mut ::std::os::raw::c_void;
@@ -215,7 +209,6 @@ impl Wrapper {
neon_tenant,
neon_timeline,
safekeepers_list,
safekeeper_conninfo_options,
safekeeper_reconnect_timeout: config.safekeeper_reconnect_timeout,
safekeeper_connection_timeout: config.safekeeper_connection_timeout,
wal_segment_size: WAL_SEGMENT_SIZE as i32, // default 16MB
@@ -583,7 +576,6 @@ mod tests {
let config = crate::walproposer::Config {
ttid,
safekeepers_list: vec!["localhost:5000".to_string()],
safekeeper_conninfo_options: String::new(),
safekeeper_reconnect_timeout: 1000,
safekeeper_connection_timeout: 10000,
sync_safekeepers: true,

View File

@@ -17,69 +17,50 @@ anyhow.workspace = true
arc-swap.workspace = true
async-compression.workspace = true
async-stream.workspace = true
bincode.workspace = true
bit_field.workspace = true
bincode.workspace = true
byteorder.workspace = true
bytes.workspace = true
camino-tempfile.workspace = true
camino.workspace = true
camino-tempfile.workspace = true
chrono = { workspace = true, features = ["serde"] }
clap = { workspace = true, features = ["string"] }
consumption_metrics.workspace = true
crc32c.workspace = true
either.workspace = true
enum-map.workspace = true
enumset = { workspace = true, features = ["serde"]}
fail.workspace = true
futures.workspace = true
hashlink.workspace = true
hex.workspace = true
http-utils.workspace = true
humantime-serde.workspace = true
humantime.workspace = true
humantime-serde.workspace = true
hyper0.workspace = true
itertools.workspace = true
jsonwebtoken.workspace = true
md5.workspace = true
metrics.workspace = true
nix.workspace = true
num_cpus.workspace = true # hack to get the number of worker threads tokio uses
# hack to get the number of worker threads tokio uses
num_cpus.workspace = true
num-traits.workspace = true
once_cell.workspace = true
pageserver_api.workspace = true
pageserver_client.workspace = true # for ResponseErrorMessageExt TOOD refactor that
pageserver_compaction.workspace = true
pageserver_page_api.workspace = true
pem.workspace = true
pin-project-lite.workspace = true
postgres_backend.workspace = true
postgres_connection.workspace = true
postgres_ffi.workspace = true
postgres_initdb.workspace = true
postgres-protocol.workspace = true
postgres-types.workspace = true
posthog_client_lite.workspace = true
postgres_initdb.workspace = true
pprof.workspace = true
pq_proto.workspace = true
rand.workspace = true
range-set-blaze = { version = "0.1.16", features = ["alloc"] }
regex.workspace = true
remote_storage.workspace = true
reqwest.workspace = true
rpds.workspace = true
rustls.workspace = true
scopeguard.workspace = true
send-future.workspace = true
serde.workspace = true
serde_json = { workspace = true, features = ["raw_value"] }
serde_path_to_error.workspace = true
serde_with.workspace = true
serde.workspace = true
smallvec.workspace = true
storage_broker.workspace = true
strum_macros.workspace = true
strum.workspace = true
sysinfo.workspace = true
tenant_size_model.workspace = true
tokio-tar.workspace = true
thiserror.workspace = true
tikv-jemallocator.workspace = true
tokio = { workspace = true, features = ["process", "sync", "fs", "rt", "io-util", "time"] }
@@ -88,18 +69,34 @@ tokio-io-timeout.workspace = true
tokio-postgres.workspace = true
tokio-rustls.workspace = true
tokio-stream.workspace = true
tokio-tar.workspace = true
tokio-util.workspace = true
toml_edit = { workspace = true, features = [ "serde" ] }
tonic.workspace = true
tonic-reflection.workspace = true
tracing.workspace = true
tracing-utils.workspace = true
url.workspace = true
utils.workspace = true
wal_decoder.workspace = true
walkdir.workspace = true
metrics.workspace = true
pageserver_api.workspace = true
pageserver_client.workspace = true # for ResponseErrorMessageExt TOOD refactor that
pageserver_compaction.workspace = true
pem.workspace = true
postgres_connection.workspace = true
postgres_ffi.workspace = true
pq_proto.workspace = true
remote_storage.workspace = true
storage_broker.workspace = true
tenant_size_model.workspace = true
http-utils.workspace = true
utils.workspace = true
workspace_hack.workspace = true
reqwest.workspace = true
rpds.workspace = true
enum-map.workspace = true
enumset = { workspace = true, features = ["serde"]}
strum.workspace = true
strum_macros.workspace = true
wal_decoder.workspace = true
smallvec.workspace = true
twox-hash.workspace = true
[target.'cfg(target_os = "linux")'.dependencies]

View File

@@ -5,14 +5,8 @@ edition.workspace = true
license.workspace = true
[dependencies]
bytes.workspace = true
pageserver_api.workspace = true
postgres_ffi.workspace = true
prost.workspace = true
smallvec.workspace = true
thiserror.workspace = true
tonic.workspace = true
utils.workspace = true
workspace_hack.workspace = true
[build-dependencies]

View File

@@ -54,9 +54,9 @@ service PageService {
// RPCs use regular unary requests, since they are not as frequent and
// performance-critical, and this simplifies implementation.
//
// NB: a gRPC status response (e.g. errors) will terminate the stream. The
// stream may be shared by multiple Postgres backends, so we avoid this by
// sending them as GetPageResponse.status_code instead.
// NB: a status response (e.g. errors) will terminate the stream. The stream
// may be shared by e.g. multiple Postgres backends, so we should avoid this.
// Most errors are therefore sent as GetPageResponse.status instead.
rpc GetPages (stream GetPageRequest) returns (stream GetPageResponse);
// Returns the size of a relation, as # of blocks.
@@ -159,8 +159,8 @@ message GetPageRequest {
// A GetPageRequest class. Primarily intended for observability, but may also be
// used for prioritization in the future.
enum GetPageClass {
// Unknown class. For backwards compatibility: used when an older client version sends a class
// that a newer server version has removed.
// Unknown class. For forwards compatibility: used when the client sends a
// class that the server doesn't know about.
GET_PAGE_CLASS_UNKNOWN = 0;
// A normal request. This is the default.
GET_PAGE_CLASS_NORMAL = 1;
@@ -180,37 +180,31 @@ message GetPageResponse {
// The original request's ID.
uint64 request_id = 1;
// The response status code.
GetPageStatusCode status_code = 2;
GetPageStatus status = 2;
// A string describing the status, if any.
string reason = 3;
// The 8KB page images, in the same order as the request. Empty if status_code != OK.
// The 8KB page images, in the same order as the request. Empty if status != OK.
repeated bytes page_image = 4;
}
// A GetPageResponse status code.
//
// These are effectively equivalent to gRPC statuses. However, we use a bidirectional stream
// (potentially shared by many backends), and a gRPC status response would terminate the stream so
// we send GetPageResponse messages with these codes instead.
enum GetPageStatusCode {
// Unknown status. For forwards compatibility: used when an older client version receives a new
// status code from a newer server version.
GET_PAGE_STATUS_CODE_UNKNOWN = 0;
// A GetPageResponse status code. Since we use a bidirectional stream, we don't
// want to send errors as gRPC statuses, since this would terminate the stream.
enum GetPageStatus {
// Unknown status. For forwards compatibility: used when the server sends a
// status code that the client doesn't know about.
GET_PAGE_STATUS_UNKNOWN = 0;
// The request was successful.
GET_PAGE_STATUS_CODE_OK = 1;
GET_PAGE_STATUS_OK = 1;
// The page did not exist. The tenant/timeline/shard has already been
// validated during stream setup.
GET_PAGE_STATUS_CODE_NOT_FOUND = 2;
GET_PAGE_STATUS_NOT_FOUND = 2;
// The request was invalid.
GET_PAGE_STATUS_CODE_INVALID_REQUEST = 3;
// The request failed due to an internal server error.
GET_PAGE_STATUS_CODE_INTERNAL_ERROR = 4;
GET_PAGE_STATUS_INVALID = 3;
// The tenant is rate limited. Slow down and retry later.
GET_PAGE_STATUS_CODE_SLOW_DOWN = 5;
// NB: shutdown errors are emitted as a gRPC Unavailable status.
//
// TODO: consider adding a GET_PAGE_STATUS_CODE_LAYER_DOWNLOAD in the case of a layer download.
// This could free up the server task to process other requests while the download is in progress.
GET_PAGE_STATUS_SLOW_DOWN = 4;
// TODO: consider adding a GET_PAGE_STATUS_LAYER_DOWNLOAD in the case of a
// layer download. This could free up the server task to process other
// requests while the layer download is in progress.
}
// Fetches the size of a relation at a given LSN, as # of blocks. Only valid on

View File

@@ -17,7 +17,3 @@ pub mod proto {
pub use page_service_client::PageServiceClient;
pub use page_service_server::{PageService, PageServiceServer};
}
mod model;
pub use model::*;

View File

@@ -1,595 +0,0 @@
//! Structs representing the canonical page service API.
//!
//! These mirror the autogenerated Protobuf types. The differences are:
//!
//! - Types that are in fact required by the API are not Options. The protobuf "required"
//! attribute is deprecated and 'prost' marks a lot of members as optional because of that.
//! (See <https://github.com/tokio-rs/prost/issues/800> for a gripe on this)
//!
//! - Use more precise datatypes, e.g. Lsn and uints shorter than 32 bits.
//!
//! - Validate protocol invariants, via try_from() and try_into().
use bytes::Bytes;
use postgres_ffi::Oid;
use smallvec::SmallVec;
// TODO: split out Lsn, RelTag, SlruKind, Oid and other basic types to a separate crate, to avoid
// pulling in all of their other crate dependencies when building the client.
use utils::lsn::Lsn;
use crate::proto;
/// A protocol error. Typically returned via try_from() or try_into().
#[derive(thiserror::Error, Debug)]
pub enum ProtocolError {
#[error("field '{0}' has invalid value '{1}'")]
Invalid(&'static str, String),
#[error("required field '{0}' is missing")]
Missing(&'static str),
}
impl ProtocolError {
/// Helper to generate a new ProtocolError::Invalid for the given field and value.
pub fn invalid(field: &'static str, value: impl std::fmt::Debug) -> Self {
Self::Invalid(field, format!("{value:?}"))
}
}
impl From<ProtocolError> for tonic::Status {
fn from(err: ProtocolError) -> Self {
tonic::Status::invalid_argument(format!("{err}"))
}
}
/// The LSN a request should read at.
#[derive(Clone, Copy, Debug)]
pub struct ReadLsn {
/// The request's read LSN.
pub request_lsn: Lsn,
/// If given, the caller guarantees that the page has not been modified since this LSN. Must be
/// smaller than or equal to request_lsn. This allows the Pageserver to serve an old page
/// without waiting for the request LSN to arrive. Valid for all request types.
///
/// It is undefined behaviour to make a request such that the page was, in fact, modified
/// between request_lsn and not_modified_since_lsn. The Pageserver might detect it and return an
/// error, or it might return the old page version or the new page version. Setting
/// not_modified_since_lsn equal to request_lsn is always safe, but can lead to unnecessary
/// waiting.
pub not_modified_since_lsn: Option<Lsn>,
}
impl ReadLsn {
/// Validates the ReadLsn.
pub fn validate(&self) -> Result<(), ProtocolError> {
if self.request_lsn == Lsn::INVALID {
return Err(ProtocolError::invalid("request_lsn", self.request_lsn));
}
if self.not_modified_since_lsn > Some(self.request_lsn) {
return Err(ProtocolError::invalid(
"not_modified_since_lsn",
self.not_modified_since_lsn,
));
}
Ok(())
}
}
impl TryFrom<proto::ReadLsn> for ReadLsn {
type Error = ProtocolError;
fn try_from(pb: proto::ReadLsn) -> Result<Self, Self::Error> {
let read_lsn = Self {
request_lsn: Lsn(pb.request_lsn),
not_modified_since_lsn: match pb.not_modified_since_lsn {
0 => None,
lsn => Some(Lsn(lsn)),
},
};
read_lsn.validate()?;
Ok(read_lsn)
}
}
impl TryFrom<ReadLsn> for proto::ReadLsn {
type Error = ProtocolError;
fn try_from(read_lsn: ReadLsn) -> Result<Self, Self::Error> {
read_lsn.validate()?;
Ok(Self {
request_lsn: read_lsn.request_lsn.0,
not_modified_since_lsn: read_lsn.not_modified_since_lsn.unwrap_or_default().0,
})
}
}
// RelTag is defined in pageserver_api::reltag.
pub type RelTag = pageserver_api::reltag::RelTag;
impl TryFrom<proto::RelTag> for RelTag {
type Error = ProtocolError;
fn try_from(pb: proto::RelTag) -> Result<Self, Self::Error> {
Ok(Self {
spcnode: pb.spc_oid,
dbnode: pb.db_oid,
relnode: pb.rel_number,
forknum: pb
.fork_number
.try_into()
.map_err(|_| ProtocolError::invalid("fork_number", pb.fork_number))?,
})
}
}
impl From<RelTag> for proto::RelTag {
fn from(rel_tag: RelTag) -> Self {
Self {
spc_oid: rel_tag.spcnode,
db_oid: rel_tag.dbnode,
rel_number: rel_tag.relnode,
fork_number: rel_tag.forknum as u32,
}
}
}
/// Checks whether a relation exists, at the given LSN. Only valid on shard 0, other shards error.
#[derive(Clone, Copy, Debug)]
pub struct CheckRelExistsRequest {
pub read_lsn: ReadLsn,
pub rel: RelTag,
}
impl TryFrom<proto::CheckRelExistsRequest> for CheckRelExistsRequest {
type Error = ProtocolError;
fn try_from(pb: proto::CheckRelExistsRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: pb
.read_lsn
.ok_or(ProtocolError::Missing("read_lsn"))?
.try_into()?,
rel: pb.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?,
})
}
}
pub type CheckRelExistsResponse = bool;
impl From<proto::CheckRelExistsResponse> for CheckRelExistsResponse {
fn from(pb: proto::CheckRelExistsResponse) -> Self {
pb.exists
}
}
impl From<CheckRelExistsResponse> for proto::CheckRelExistsResponse {
fn from(exists: CheckRelExistsResponse) -> Self {
Self { exists }
}
}
/// Requests a base backup at a given LSN.
#[derive(Clone, Copy, Debug)]
pub struct GetBaseBackupRequest {
/// The LSN to fetch a base backup at.
pub read_lsn: ReadLsn,
/// If true, logical replication slots will not be created.
pub replica: bool,
}
impl TryFrom<proto::GetBaseBackupRequest> for GetBaseBackupRequest {
type Error = ProtocolError;
fn try_from(pb: proto::GetBaseBackupRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: pb
.read_lsn
.ok_or(ProtocolError::Missing("read_lsn"))?
.try_into()?,
replica: pb.replica,
})
}
}
impl TryFrom<GetBaseBackupRequest> for proto::GetBaseBackupRequest {
type Error = ProtocolError;
fn try_from(request: GetBaseBackupRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: Some(request.read_lsn.try_into()?),
replica: request.replica,
})
}
}
pub type GetBaseBackupResponseChunk = Bytes;
impl TryFrom<proto::GetBaseBackupResponseChunk> for GetBaseBackupResponseChunk {
type Error = ProtocolError;
fn try_from(pb: proto::GetBaseBackupResponseChunk) -> Result<Self, Self::Error> {
if pb.chunk.is_empty() {
return Err(ProtocolError::Missing("chunk"));
}
Ok(pb.chunk)
}
}
impl TryFrom<GetBaseBackupResponseChunk> for proto::GetBaseBackupResponseChunk {
type Error = ProtocolError;
fn try_from(chunk: GetBaseBackupResponseChunk) -> Result<Self, Self::Error> {
if chunk.is_empty() {
return Err(ProtocolError::Missing("chunk"));
}
Ok(Self { chunk })
}
}
/// Requests the size of a database, as # of bytes. Only valid on shard 0, other shards will error.
#[derive(Clone, Copy, Debug)]
pub struct GetDbSizeRequest {
pub read_lsn: ReadLsn,
pub db_oid: Oid,
}
impl TryFrom<proto::GetDbSizeRequest> for GetDbSizeRequest {
type Error = ProtocolError;
fn try_from(pb: proto::GetDbSizeRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: pb
.read_lsn
.ok_or(ProtocolError::Missing("read_lsn"))?
.try_into()?,
db_oid: pb.db_oid,
})
}
}
impl TryFrom<GetDbSizeRequest> for proto::GetDbSizeRequest {
type Error = ProtocolError;
fn try_from(request: GetDbSizeRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: Some(request.read_lsn.try_into()?),
db_oid: request.db_oid,
})
}
}
pub type GetDbSizeResponse = u64;
impl From<proto::GetDbSizeResponse> for GetDbSizeResponse {
fn from(pb: proto::GetDbSizeResponse) -> Self {
pb.num_bytes
}
}
impl From<GetDbSizeResponse> for proto::GetDbSizeResponse {
fn from(num_bytes: GetDbSizeResponse) -> Self {
Self { num_bytes }
}
}
/// Requests one or more pages.
#[derive(Clone, Debug)]
pub struct GetPageRequest {
/// A request ID. Will be included in the response. Should be unique for in-flight requests on
/// the stream.
pub request_id: RequestID,
/// The request class.
pub request_class: GetPageClass,
/// The LSN to read at.
pub read_lsn: ReadLsn,
/// The relation to read from.
pub rel: RelTag,
/// Page numbers to read. Must belong to the remote shard.
///
/// Multiple pages will be executed as a single batch by the Pageserver, amortizing layer access
/// costs and parallelizing them. This may increase the latency of any individual request, but
/// improves the overall latency and throughput of the batch as a whole.
pub block_numbers: SmallVec<[u32; 1]>,
}
impl TryFrom<proto::GetPageRequest> for GetPageRequest {
type Error = ProtocolError;
fn try_from(pb: proto::GetPageRequest) -> Result<Self, Self::Error> {
if pb.block_number.is_empty() {
return Err(ProtocolError::Missing("block_number"));
}
Ok(Self {
request_id: pb.request_id,
request_class: pb.request_class.into(),
read_lsn: pb
.read_lsn
.ok_or(ProtocolError::Missing("read_lsn"))?
.try_into()?,
rel: pb.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?,
block_numbers: pb.block_number.into(),
})
}
}
impl TryFrom<GetPageRequest> for proto::GetPageRequest {
type Error = ProtocolError;
fn try_from(request: GetPageRequest) -> Result<Self, Self::Error> {
if request.block_numbers.is_empty() {
return Err(ProtocolError::Missing("block_number"));
}
Ok(Self {
request_id: request.request_id,
request_class: request.request_class.into(),
read_lsn: Some(request.read_lsn.try_into()?),
rel: Some(request.rel.into()),
block_number: request.block_numbers.into_vec(),
})
}
}
/// A GetPage request ID.
pub type RequestID = u64;
/// A GetPage request class.
#[derive(Clone, Copy, Debug)]
pub enum GetPageClass {
/// Unknown class. For backwards compatibility: used when an older client version sends a class
/// that a newer server version has removed.
Unknown,
/// A normal request. This is the default.
Normal,
/// A prefetch request. NB: can only be classified on pg < 18.
Prefetch,
/// A background request (e.g. vacuum).
Background,
}
impl From<proto::GetPageClass> for GetPageClass {
fn from(pb: proto::GetPageClass) -> Self {
match pb {
proto::GetPageClass::Unknown => Self::Unknown,
proto::GetPageClass::Normal => Self::Normal,
proto::GetPageClass::Prefetch => Self::Prefetch,
proto::GetPageClass::Background => Self::Background,
}
}
}
impl From<i32> for GetPageClass {
fn from(class: i32) -> Self {
proto::GetPageClass::try_from(class)
.unwrap_or(proto::GetPageClass::Unknown)
.into()
}
}
impl From<GetPageClass> for proto::GetPageClass {
fn from(class: GetPageClass) -> Self {
match class {
GetPageClass::Unknown => Self::Unknown,
GetPageClass::Normal => Self::Normal,
GetPageClass::Prefetch => Self::Prefetch,
GetPageClass::Background => Self::Background,
}
}
}
impl From<GetPageClass> for i32 {
fn from(class: GetPageClass) -> Self {
proto::GetPageClass::from(class).into()
}
}
/// A GetPage response.
///
/// A batch response will contain all of the requested pages. We could eagerly emit individual pages
/// as soon as they are ready, but on a readv() Postgres holds buffer pool locks on all pages in the
/// batch and we'll only return once the entire batch is ready, so no one can make use of the
/// individual pages.
#[derive(Clone, Debug)]
pub struct GetPageResponse {
/// The original request's ID.
pub request_id: RequestID,
/// The response status code.
pub status_code: GetPageStatusCode,
/// A string describing the status, if any.
pub reason: Option<String>,
/// The 8KB page images, in the same order as the request. Empty if status != OK.
pub page_images: SmallVec<[Bytes; 1]>,
}
impl From<proto::GetPageResponse> for GetPageResponse {
fn from(pb: proto::GetPageResponse) -> Self {
Self {
request_id: pb.request_id,
status_code: pb.status_code.into(),
reason: Some(pb.reason).filter(|r| !r.is_empty()),
page_images: pb.page_image.into(),
}
}
}
impl From<GetPageResponse> for proto::GetPageResponse {
fn from(response: GetPageResponse) -> Self {
Self {
request_id: response.request_id,
status_code: response.status_code.into(),
reason: response.reason.unwrap_or_default(),
page_image: response.page_images.into_vec(),
}
}
}
/// A GetPage response status code.
///
/// These are effectively equivalent to gRPC statuses. However, we use a bidirectional stream
/// (potentially shared by many backends), and a gRPC status response would terminate the stream so
/// we send GetPageResponse messages with these codes instead.
#[derive(Clone, Copy, Debug)]
pub enum GetPageStatusCode {
/// Unknown status. For forwards compatibility: used when an older client version receives a new
/// status code from a newer server version.
Unknown,
/// The request was successful.
Ok,
/// The page did not exist. The tenant/timeline/shard has already been validated during stream
/// setup.
NotFound,
/// The request was invalid.
InvalidRequest,
/// The request failed due to an internal server error.
InternalError,
/// The tenant is rate limited. Slow down and retry later.
SlowDown,
}
impl From<proto::GetPageStatusCode> for GetPageStatusCode {
fn from(pb: proto::GetPageStatusCode) -> Self {
match pb {
proto::GetPageStatusCode::Unknown => Self::Unknown,
proto::GetPageStatusCode::Ok => Self::Ok,
proto::GetPageStatusCode::NotFound => Self::NotFound,
proto::GetPageStatusCode::InvalidRequest => Self::InvalidRequest,
proto::GetPageStatusCode::InternalError => Self::InternalError,
proto::GetPageStatusCode::SlowDown => Self::SlowDown,
}
}
}
impl From<i32> for GetPageStatusCode {
fn from(status_code: i32) -> Self {
proto::GetPageStatusCode::try_from(status_code)
.unwrap_or(proto::GetPageStatusCode::Unknown)
.into()
}
}
impl From<GetPageStatusCode> for proto::GetPageStatusCode {
fn from(status_code: GetPageStatusCode) -> Self {
match status_code {
GetPageStatusCode::Unknown => Self::Unknown,
GetPageStatusCode::Ok => Self::Ok,
GetPageStatusCode::NotFound => Self::NotFound,
GetPageStatusCode::InvalidRequest => Self::InvalidRequest,
GetPageStatusCode::InternalError => Self::InternalError,
GetPageStatusCode::SlowDown => Self::SlowDown,
}
}
}
impl From<GetPageStatusCode> for i32 {
fn from(status_code: GetPageStatusCode) -> Self {
proto::GetPageStatusCode::from(status_code).into()
}
}
// Fetches the size of a relation at a given LSN, as # of blocks. Only valid on shard 0, other
// shards will error.
pub struct GetRelSizeRequest {
pub read_lsn: ReadLsn,
pub rel: RelTag,
}
impl TryFrom<proto::GetRelSizeRequest> for GetRelSizeRequest {
type Error = ProtocolError;
fn try_from(proto: proto::GetRelSizeRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: proto
.read_lsn
.ok_or(ProtocolError::Missing("read_lsn"))?
.try_into()?,
rel: proto.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?,
})
}
}
impl TryFrom<GetRelSizeRequest> for proto::GetRelSizeRequest {
type Error = ProtocolError;
fn try_from(request: GetRelSizeRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: Some(request.read_lsn.try_into()?),
rel: Some(request.rel.into()),
})
}
}
pub type GetRelSizeResponse = u32;
impl From<proto::GetRelSizeResponse> for GetRelSizeResponse {
fn from(proto: proto::GetRelSizeResponse) -> Self {
proto.num_blocks
}
}
impl From<GetRelSizeResponse> for proto::GetRelSizeResponse {
fn from(num_blocks: GetRelSizeResponse) -> Self {
Self { num_blocks }
}
}
/// Requests an SLRU segment. Only valid on shard 0, other shards will error.
pub struct GetSlruSegmentRequest {
pub read_lsn: ReadLsn,
pub kind: SlruKind,
pub segno: u32,
}
impl TryFrom<proto::GetSlruSegmentRequest> for GetSlruSegmentRequest {
type Error = ProtocolError;
fn try_from(pb: proto::GetSlruSegmentRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: pb
.read_lsn
.ok_or(ProtocolError::Missing("read_lsn"))?
.try_into()?,
kind: u8::try_from(pb.kind)
.ok()
.and_then(SlruKind::from_repr)
.ok_or_else(|| ProtocolError::invalid("slru_kind", pb.kind))?,
segno: pb.segno,
})
}
}
impl TryFrom<GetSlruSegmentRequest> for proto::GetSlruSegmentRequest {
type Error = ProtocolError;
fn try_from(request: GetSlruSegmentRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: Some(request.read_lsn.try_into()?),
kind: request.kind as u32,
segno: request.segno,
})
}
}
pub type GetSlruSegmentResponse = Bytes;
impl TryFrom<proto::GetSlruSegmentResponse> for GetSlruSegmentResponse {
type Error = ProtocolError;
fn try_from(pb: proto::GetSlruSegmentResponse) -> Result<Self, Self::Error> {
if pb.segment.is_empty() {
return Err(ProtocolError::Missing("segment"));
}
Ok(pb.segment)
}
}
impl TryFrom<GetSlruSegmentResponse> for proto::GetSlruSegmentResponse {
type Error = ProtocolError;
fn try_from(segment: GetSlruSegmentResponse) -> Result<Self, Self::Error> {
if segment.is_empty() {
return Err(ProtocolError::Missing("segment"));
}
Ok(Self { segment })
}
}
// SlruKind is defined in pageserver_api::reltag.
pub type SlruKind = pageserver_api::reltag::SlruKind;

View File

@@ -1,518 +0,0 @@
use std::{collections::HashMap, sync::Arc};
use async_compression::tokio::write::GzipEncoder;
use camino::{Utf8Path, Utf8PathBuf};
use metrics::core::{AtomicU64, GenericCounter};
use pageserver_api::{config::BasebackupCacheConfig, models::TenantState};
use tokio::{
io::{AsyncWriteExt, BufWriter},
sync::mpsc::{UnboundedReceiver, UnboundedSender},
};
use tokio_util::sync::CancellationToken;
use utils::{
id::{TenantId, TenantTimelineId, TimelineId},
lsn::Lsn,
shard::TenantShardId,
};
use crate::{
basebackup::send_basebackup_tarball,
context::{DownloadBehavior, RequestContext},
metrics::{BASEBACKUP_CACHE_ENTRIES, BASEBACKUP_CACHE_PREPARE, BASEBACKUP_CACHE_READ},
task_mgr::TaskKind,
tenant::{
Timeline,
mgr::{TenantManager, TenantSlot},
},
};
pub struct BasebackupPrepareRequest {
pub tenant_shard_id: TenantShardId,
pub timeline_id: TimelineId,
pub lsn: Lsn,
}
pub type BasebackupPrepareSender = UnboundedSender<BasebackupPrepareRequest>;
pub type BasebackupPrepareReceiver = UnboundedReceiver<BasebackupPrepareRequest>;
type BasebackupRemoveEntrySender = UnboundedSender<Utf8PathBuf>;
type BasebackupRemoveEntryReceiver = UnboundedReceiver<Utf8PathBuf>;
/// BasebackupCache stores cached basebackup archives for timelines on local disk.
///
/// The main purpose of this cache is to speed up the startup process of compute nodes
/// after scaling to zero.
/// Thus, the basebackup is stored only for the latest LSN of the timeline and with
/// fixed set of parameters (gzip=true, full_backup=false, replica=false, prev_lsn=none).
///
/// The cache receives prepare requests through the `BasebackupPrepareSender` channel,
/// generates a basebackup from the timeline in the background, and stores it on disk.
///
/// Basebackup requests are pretty rare. We expect ~thousands of entries in the cache
/// and ~1 RPS for get requests.
pub struct BasebackupCache {
data_dir: Utf8PathBuf,
config: BasebackupCacheConfig,
tenant_manager: Arc<TenantManager>,
remove_entry_sender: BasebackupRemoveEntrySender,
entries: std::sync::Mutex<HashMap<TenantTimelineId, Lsn>>,
cancel: CancellationToken,
read_hit_count: GenericCounter<AtomicU64>,
read_miss_count: GenericCounter<AtomicU64>,
read_err_count: GenericCounter<AtomicU64>,
prepare_ok_count: GenericCounter<AtomicU64>,
prepare_skip_count: GenericCounter<AtomicU64>,
prepare_err_count: GenericCounter<AtomicU64>,
}
impl BasebackupCache {
/// Creates a BasebackupCache and spawns the background task.
/// The initialization of the cache is performed in the background and does not
/// block the caller. The cache will return `None` for any get requests until
/// initialization is complete.
pub fn spawn(
runtime_handle: &tokio::runtime::Handle,
data_dir: Utf8PathBuf,
config: Option<BasebackupCacheConfig>,
prepare_receiver: BasebackupPrepareReceiver,
tenant_manager: Arc<TenantManager>,
cancel: CancellationToken,
) -> Arc<Self> {
let (remove_entry_sender, remove_entry_receiver) = tokio::sync::mpsc::unbounded_channel();
let enabled = config.is_some();
let cache = Arc::new(BasebackupCache {
data_dir,
config: config.unwrap_or_default(),
tenant_manager,
remove_entry_sender,
entries: std::sync::Mutex::new(HashMap::new()),
cancel,
read_hit_count: BASEBACKUP_CACHE_READ.with_label_values(&["hit"]),
read_miss_count: BASEBACKUP_CACHE_READ.with_label_values(&["miss"]),
read_err_count: BASEBACKUP_CACHE_READ.with_label_values(&["error"]),
prepare_ok_count: BASEBACKUP_CACHE_PREPARE.with_label_values(&["ok"]),
prepare_skip_count: BASEBACKUP_CACHE_PREPARE.with_label_values(&["skip"]),
prepare_err_count: BASEBACKUP_CACHE_PREPARE.with_label_values(&["error"]),
});
if enabled {
runtime_handle.spawn(
cache
.clone()
.background(prepare_receiver, remove_entry_receiver),
);
}
cache
}
/// Gets a basebackup entry from the cache.
/// If the entry is found, opens a file with the basebackup archive and returns it.
/// The open file descriptor will prevent the file system from deleting the file
/// even if the entry is removed from the cache in the background.
pub async fn get(
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
lsn: Lsn,
) -> Option<tokio::fs::File> {
// Fast path. Check if the entry exists using the in-memory state.
let tti = TenantTimelineId::new(tenant_id, timeline_id);
if self.entries.lock().unwrap().get(&tti) != Some(&lsn) {
self.read_miss_count.inc();
return None;
}
let path = self.entry_path(tenant_id, timeline_id, lsn);
match tokio::fs::File::open(path).await {
Ok(file) => {
self.read_hit_count.inc();
Some(file)
}
Err(e) => {
if e.kind() == std::io::ErrorKind::NotFound {
// We may end up here if the basebackup was concurrently removed by the cleanup task.
self.read_miss_count.inc();
} else {
self.read_err_count.inc();
tracing::warn!("Unexpected error opening basebackup cache file: {:?}", e);
}
None
}
}
}
// Private methods.
fn entry_filename(tenant_id: TenantId, timeline_id: TimelineId, lsn: Lsn) -> String {
// The default format for LSN is 0/ABCDEF.
// The backslash is not filename friendly, so serialize it as plain hex.
let lsn = lsn.0;
format!("basebackup_{tenant_id}_{timeline_id}_{lsn:016X}.tar.gz")
}
fn entry_path(&self, tenant_id: TenantId, timeline_id: TimelineId, lsn: Lsn) -> Utf8PathBuf {
self.data_dir
.join(Self::entry_filename(tenant_id, timeline_id, lsn))
}
fn entry_tmp_path(
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
lsn: Lsn,
) -> Utf8PathBuf {
self.data_dir
.join("tmp")
.join(Self::entry_filename(tenant_id, timeline_id, lsn))
}
fn parse_entry_filename(filename: &str) -> Option<(TenantId, TimelineId, Lsn)> {
let parts: Vec<&str> = filename
.strip_prefix("basebackup_")?
.strip_suffix(".tar.gz")?
.split('_')
.collect();
if parts.len() != 3 {
return None;
}
let tenant_id = parts[0].parse::<TenantId>().ok()?;
let timeline_id = parts[1].parse::<TimelineId>().ok()?;
let lsn = Lsn(u64::from_str_radix(parts[2], 16).ok()?);
Some((tenant_id, timeline_id, lsn))
}
async fn cleanup(&self) -> anyhow::Result<()> {
// Cleanup tmp directory.
let tmp_dir = self.data_dir.join("tmp");
let mut tmp_dir = tokio::fs::read_dir(&tmp_dir).await?;
while let Some(dir_entry) = tmp_dir.next_entry().await? {
if let Err(e) = tokio::fs::remove_file(dir_entry.path()).await {
tracing::warn!("Failed to remove basebackup cache tmp file: {:#}", e);
}
}
// Remove outdated entries.
let entries_old = self.entries.lock().unwrap().clone();
let mut entries_new = HashMap::new();
for (tenant_shard_id, tenant_slot) in self.tenant_manager.list() {
if !tenant_shard_id.is_shard_zero() {
continue;
}
let TenantSlot::Attached(tenant) = tenant_slot else {
continue;
};
let tenant_id = tenant_shard_id.tenant_id;
for timeline in tenant.list_timelines() {
let tti = TenantTimelineId::new(tenant_id, timeline.timeline_id);
if let Some(&entry_lsn) = entries_old.get(&tti) {
if timeline.get_last_record_lsn() <= entry_lsn {
entries_new.insert(tti, entry_lsn);
}
}
}
}
for (&tti, &lsn) in entries_old.iter() {
if !entries_new.contains_key(&tti) {
self.remove_entry_sender
.send(self.entry_path(tti.tenant_id, tti.timeline_id, lsn))
.unwrap();
}
}
BASEBACKUP_CACHE_ENTRIES.set(entries_new.len() as i64);
*self.entries.lock().unwrap() = entries_new;
Ok(())
}
async fn on_startup(&self) -> anyhow::Result<()> {
// Create data_dir and tmp directory if they do not exist.
tokio::fs::create_dir_all(&self.data_dir.join("tmp"))
.await
.map_err(|e| {
anyhow::anyhow!(
"Failed to create basebackup cache data_dir {:?}: {:?}",
self.data_dir,
e
)
})?;
// Read existing entries from the data_dir and add them to in-memory state.
let mut entries = HashMap::new();
let mut dir = tokio::fs::read_dir(&self.data_dir).await?;
while let Some(dir_entry) = dir.next_entry().await? {
let filename = dir_entry.file_name();
if filename == "tmp" {
// Skip the tmp directory.
continue;
}
let parsed = Self::parse_entry_filename(filename.to_string_lossy().as_ref());
let Some((tenant_id, timeline_id, lsn)) = parsed else {
tracing::warn!("Invalid basebackup cache file name: {:?}", filename);
continue;
};
let tti = TenantTimelineId::new(tenant_id, timeline_id);
use std::collections::hash_map::Entry::*;
match entries.entry(tti) {
Occupied(mut entry) => {
let entry_lsn = *entry.get();
// Leave only the latest entry, remove the old one.
if lsn < entry_lsn {
self.remove_entry_sender.send(self.entry_path(
tenant_id,
timeline_id,
lsn,
))?;
} else if lsn > entry_lsn {
self.remove_entry_sender.send(self.entry_path(
tenant_id,
timeline_id,
entry_lsn,
))?;
entry.insert(lsn);
} else {
// Two different filenames parsed to the same timline_id and LSN.
// Should never happen.
return Err(anyhow::anyhow!(
"Duplicate basebackup cache entry with the same LSN: {:?}",
filename
));
}
}
Vacant(entry) => {
entry.insert(lsn);
}
}
}
BASEBACKUP_CACHE_ENTRIES.set(entries.len() as i64);
*self.entries.lock().unwrap() = entries;
Ok(())
}
async fn background(
self: Arc<Self>,
mut prepare_receiver: BasebackupPrepareReceiver,
mut remove_entry_receiver: BasebackupRemoveEntryReceiver,
) {
// Panic in the background is a safe fallback.
// It will drop receivers and the cache will be effectively disabled.
self.on_startup()
.await
.expect("Failed to initialize basebackup cache");
let mut cleanup_ticker = tokio::time::interval(self.config.cleanup_period);
cleanup_ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
Some(req) = prepare_receiver.recv() => {
if let Err(err) = self.prepare_basebackup(
req.tenant_shard_id,
req.timeline_id,
req.lsn,
).await {
tracing::info!("Failed to prepare basebackup: {:#}", err);
self.prepare_err_count.inc();
continue;
}
}
Some(req) = remove_entry_receiver.recv() => {
if let Err(e) = tokio::fs::remove_file(req).await {
tracing::warn!("Failed to remove basebackup cache file: {:#}", e);
}
}
_ = cleanup_ticker.tick() => {
self.cleanup().await.unwrap_or_else(|e| {
tracing::warn!("Failed to clean up basebackup cache: {:#}", e);
});
}
_ = self.cancel.cancelled() => {
tracing::info!("BasebackupCache background task cancelled");
break;
}
}
}
}
/// Prepare a basebackup for the given timeline.
///
/// If the basebackup already exists with a higher LSN or the timeline already
/// has a higher last_record_lsn, skip the preparation.
///
/// The basebackup is prepared in a temporary directory and then moved to the final
/// location to make the operation atomic.
async fn prepare_basebackup(
&self,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
req_lsn: Lsn,
) -> anyhow::Result<()> {
tracing::info!(
tenant_id = %tenant_shard_id.tenant_id,
%timeline_id,
%req_lsn,
"Preparing basebackup for timeline",
);
let tti = TenantTimelineId::new(tenant_shard_id.tenant_id, timeline_id);
{
let entries = self.entries.lock().unwrap();
if let Some(&entry_lsn) = entries.get(&tti) {
if entry_lsn >= req_lsn {
tracing::info!(
%timeline_id,
%req_lsn,
%entry_lsn,
"Basebackup entry already exists for timeline with higher LSN, skipping basebackup",
);
self.prepare_skip_count.inc();
return Ok(());
}
}
if entries.len() as i64 >= self.config.max_size_entries {
tracing::info!(
%timeline_id,
%req_lsn,
"Basebackup cache is full, skipping basebackup",
);
self.prepare_skip_count.inc();
return Ok(());
}
}
let tenant = self
.tenant_manager
.get_attached_tenant_shard(tenant_shard_id)?;
let tenant_state = tenant.current_state();
if tenant_state != TenantState::Active {
anyhow::bail!(
"Tenant {} is not active, current state: {:?}",
tenant_shard_id.tenant_id,
tenant_state
)
}
let timeline = tenant.get_timeline(timeline_id, true)?;
let last_record_lsn = timeline.get_last_record_lsn();
if last_record_lsn > req_lsn {
tracing::info!(
%timeline_id,
%req_lsn,
%last_record_lsn,
"Timeline has a higher LSN than the requested one, skipping basebackup",
);
self.prepare_skip_count.inc();
return Ok(());
}
let entry_tmp_path = self.entry_tmp_path(tenant_shard_id.tenant_id, timeline_id, req_lsn);
let res = self
.prepare_basebackup_tmp(&entry_tmp_path, &timeline, req_lsn)
.await;
if let Err(err) = res {
tracing::info!("Failed to prepare basebackup tmp file: {:#}", err);
// Try to clean up tmp file. If we fail, the background clean up task will take care of it.
match tokio::fs::remove_file(&entry_tmp_path).await {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
Err(e) => {
tracing::info!("Failed to remove basebackup tmp file: {:?}", e);
}
}
return Err(err);
}
// Move the tmp file to the final location atomically.
let entry_path = self.entry_path(tenant_shard_id.tenant_id, timeline_id, req_lsn);
tokio::fs::rename(&entry_tmp_path, &entry_path).await?;
let mut entries = self.entries.lock().unwrap();
if let Some(old_lsn) = entries.insert(tti, req_lsn) {
// Remove the old entry if it exists.
self.remove_entry_sender
.send(self.entry_path(tenant_shard_id.tenant_id, timeline_id, old_lsn))
.unwrap();
}
BASEBACKUP_CACHE_ENTRIES.set(entries.len() as i64);
self.prepare_ok_count.inc();
Ok(())
}
/// Prepares a basebackup in a temporary file.
async fn prepare_basebackup_tmp(
&self,
emptry_tmp_path: &Utf8Path,
timeline: &Arc<Timeline>,
req_lsn: Lsn,
) -> anyhow::Result<()> {
let ctx = RequestContext::new(TaskKind::BasebackupCache, DownloadBehavior::Download);
let ctx = ctx.with_scope_timeline(timeline);
let file = tokio::fs::File::create(emptry_tmp_path).await?;
let mut writer = BufWriter::new(file);
let mut encoder = GzipEncoder::with_quality(
&mut writer,
// Level::Best because compression is not on the hot path of basebackup requests.
// The decompression is almost not affected by the compression level.
async_compression::Level::Best,
);
// We may receive a request before the WAL record is applied to the timeline.
// Wait for the requested LSN to be applied.
timeline
.wait_lsn(
req_lsn,
crate::tenant::timeline::WaitLsnWaiter::BaseBackupCache,
crate::tenant::timeline::WaitLsnTimeout::Default,
&ctx,
)
.await?;
send_basebackup_tarball(
&mut encoder,
timeline,
Some(req_lsn),
None,
false,
false,
&ctx,
)
.await?;
encoder.shutdown().await?;
writer.flush().await?;
writer.into_inner().sync_all().await?;
Ok(())
}
}

View File

@@ -16,12 +16,10 @@ use http_utils::tls_certs::ReloadingCertificateResolver;
use metrics::launch_timestamp::{LaunchTimestamp, set_launch_timestamp_metric};
use metrics::set_build_info_metric;
use nix::sys::socket::{setsockopt, sockopt};
use pageserver::basebackup_cache::BasebackupCache;
use pageserver::config::{PageServerConf, PageserverIdentity, ignored_fields};
use pageserver::controller_upcall_client::StorageControllerUpcallClient;
use pageserver::deletion_queue::DeletionQueue;
use pageserver::disk_usage_eviction_task::{self, launch_disk_usage_global_eviction_task};
use pageserver::feature_resolver::FeatureResolver;
use pageserver::metrics::{STARTUP_DURATION, STARTUP_IS_LOADING};
use pageserver::task_mgr::{
BACKGROUND_RUNTIME, COMPUTE_REQUEST_RUNTIME, MGMT_REQUEST_RUNTIME, WALRECEIVER_RUNTIME,
@@ -389,30 +387,23 @@ fn start_pageserver(
// We need to release the lock file only when the process exits.
std::mem::forget(lock_file);
// Bind the HTTP, libpq, and gRPC ports early, to error out if they are
// already in use.
info!(
"Starting pageserver http handler on {} with auth {:#?}",
conf.listen_http_addr, conf.http_auth_type
);
let http_listener = tcp_listener::bind(&conf.listen_http_addr)?;
// Bind the HTTP and libpq ports early, so that if they are in use by some other
// process, we error out early.
let http_addr = &conf.listen_http_addr;
info!("Starting pageserver http handler on {http_addr}");
let http_listener = tcp_listener::bind(http_addr)?;
let https_listener = match conf.listen_https_addr.as_ref() {
Some(https_addr) => {
info!(
"Starting pageserver https handler on {https_addr} with auth {:#?}",
conf.http_auth_type
);
info!("Starting pageserver https handler on {https_addr}");
Some(tcp_listener::bind(https_addr)?)
}
None => None,
};
info!(
"Starting pageserver pg protocol handler on {} with auth {:#?}",
conf.listen_pg_addr, conf.pg_auth_type,
);
let pageserver_listener = tcp_listener::bind(&conf.listen_pg_addr)?;
let pg_addr = &conf.listen_pg_addr;
info!("Starting pageserver pg protocol handler on {pg_addr}");
let pageserver_listener = tcp_listener::bind(pg_addr)?;
// Enable SO_KEEPALIVE on the socket, to detect dead connections faster.
// These are configured via net.ipv4.tcp_keepalive_* sysctls.
@@ -421,15 +412,6 @@ fn start_pageserver(
// support enabling keepalives while using the default OS sysctls.
setsockopt(&pageserver_listener, sockopt::KeepAlive, &true)?;
let mut grpc_listener = None;
if let Some(grpc_addr) = &conf.listen_grpc_addr {
info!(
"Starting pageserver gRPC handler on {grpc_addr} with auth {:#?}",
conf.grpc_auth_type
);
grpc_listener = Some(tcp_listener::bind(grpc_addr).map_err(|e| anyhow!("{e}"))?);
}
// Launch broker client
// The storage_broker::connect call needs to happen inside a tokio runtime thread.
let broker_client = WALRECEIVER_RUNTIME
@@ -457,8 +439,7 @@ fn start_pageserver(
// Initialize authentication for incoming connections
let http_auth;
let pg_auth;
let grpc_auth;
if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type].contains(&AuthType::NeonJWT) {
if conf.http_auth_type == AuthType::NeonJWT || conf.pg_auth_type == AuthType::NeonJWT {
// unwrap is ok because check is performed when creating config, so path is set and exists
let key_path = conf.auth_validation_public_key_path.as_ref().unwrap();
info!("Loading public key(s) for verifying JWT tokens from {key_path:?}");
@@ -466,23 +447,20 @@ fn start_pageserver(
let jwt_auth = JwtAuth::from_key_path(key_path)?;
let auth: Arc<SwappableJwtAuth> = Arc::new(SwappableJwtAuth::new(jwt_auth));
http_auth = match conf.http_auth_type {
http_auth = match &conf.http_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => Some(auth.clone()),
};
pg_auth = match conf.pg_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => Some(auth.clone()),
};
grpc_auth = match conf.grpc_auth_type {
pg_auth = match &conf.pg_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => Some(auth),
};
} else {
http_auth = None;
pg_auth = None;
grpc_auth = None;
}
info!("Using auth for http API: {:#?}", conf.http_auth_type);
info!("Using auth for pg connections: {:#?}", conf.pg_auth_type);
let tls_server_config = if conf.listen_https_addr.is_some() || conf.enable_tls_page_service_api
{
@@ -523,12 +501,6 @@ fn start_pageserver(
// Set up remote storage client
let remote_storage = BACKGROUND_RUNTIME.block_on(create_remote_storage_client(conf))?;
let feature_resolver = create_feature_resolver(
conf,
shutdown_pageserver.clone(),
BACKGROUND_RUNTIME.handle(),
)?;
// Set up deletion queue
let (deletion_queue, deletion_workers) = DeletionQueue::new(
remote_storage.clone(),
@@ -569,8 +541,6 @@ fn start_pageserver(
pageserver::l0_flush::L0FlushGlobalState::new(conf.l0_flush.clone());
// Scan the local 'tenants/' directory and start loading the tenants
let (basebackup_prepare_sender, basebackup_prepare_receiver) =
tokio::sync::mpsc::unbounded_channel();
let deletion_queue_client = deletion_queue.new_client();
let background_purges = mgr::BackgroundPurges::default();
let tenant_manager = BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(
@@ -581,23 +551,12 @@ fn start_pageserver(
remote_storage: remote_storage.clone(),
deletion_queue_client,
l0_flush_global_state,
basebackup_prepare_sender,
feature_resolver,
},
order,
shutdown_pageserver.clone(),
))?;
let tenant_manager = Arc::new(tenant_manager);
let basebackup_cache = BasebackupCache::spawn(
BACKGROUND_RUNTIME.handle(),
conf.basebackup_cache_dir(),
conf.basebackup_cache_config.clone(),
basebackup_prepare_receiver,
Arc::clone(&tenant_manager),
shutdown_pageserver.child_token(),
);
BACKGROUND_RUNTIME.spawn({
let shutdown_pageserver = shutdown_pageserver.clone();
let drive_init = async move {
@@ -804,27 +763,8 @@ fn start_pageserver(
} else {
None
},
basebackup_cache.clone(),
);
// Spawn a Pageserver gRPC server task. It will spawn separate tasks for
// each stream/request.
//
// TODO: this uses a separate Tokio runtime for the page service. If we want
// other gRPC services, they will need their own port and runtime. Is this
// necessary?
let mut page_service_grpc = None;
if let Some(grpc_listener) = grpc_listener {
page_service_grpc = Some(page_service::spawn_grpc(
conf,
tenant_manager.clone(),
grpc_auth,
otel_guard.as_ref().map(|g| g.dispatch.clone()),
grpc_listener,
basebackup_cache,
)?);
}
// All started up! Now just sit and wait for shutdown signal.
BACKGROUND_RUNTIME.block_on(async move {
let signal_token = CancellationToken::new();
@@ -843,7 +783,6 @@ fn start_pageserver(
http_endpoint_listener,
https_endpoint_listener,
page_service,
page_service_grpc,
consumption_metrics_tasks,
disk_usage_eviction_task,
&tenant_manager,
@@ -857,14 +796,6 @@ fn start_pageserver(
})
}
fn create_feature_resolver(
conf: &'static PageServerConf,
shutdown_pageserver: CancellationToken,
handle: &tokio::runtime::Handle,
) -> anyhow::Result<FeatureResolver> {
FeatureResolver::spawn(conf, shutdown_pageserver, handle)
}
async fn create_remote_storage_client(
conf: &'static PageServerConf,
) -> anyhow::Result<GenericRemoteStorage> {

View File

@@ -14,7 +14,7 @@ use std::time::Duration;
use anyhow::{Context, bail, ensure};
use camino::{Utf8Path, Utf8PathBuf};
use once_cell::sync::OnceCell;
use pageserver_api::config::{DiskUsageEvictionTaskConfig, MaxVectoredReadBytes, PostHogConfig};
use pageserver_api::config::{DiskUsageEvictionTaskConfig, MaxVectoredReadBytes};
use pageserver_api::models::ImageCompressionAlgorithm;
use pageserver_api::shard::TenantShardId;
use pem::Pem;
@@ -58,16 +58,11 @@ pub struct PageServerConf {
pub listen_http_addr: String,
/// Example: 127.0.0.1:9899
pub listen_https_addr: Option<String>,
/// If set, expose a gRPC API on this address.
/// Example: 127.0.0.1:51051
///
/// EXPERIMENTAL: this protocol is unstable and under active development.
pub listen_grpc_addr: Option<String>,
/// Path to a file with certificate's private key for https and gRPC API.
/// Path to a file with certificate's private key for https API.
/// Default: server.key
pub ssl_key_file: Utf8PathBuf,
/// Path to a file with a X509 certificate for https and gRPC API.
/// Path to a file with a X509 certificate for https API.
/// Default: server.crt
pub ssl_cert_file: Utf8PathBuf,
/// Period to reload certificate and private key from files.
@@ -105,8 +100,6 @@ pub struct PageServerConf {
pub http_auth_type: AuthType,
/// authentication method for libpq connections from compute
pub pg_auth_type: AuthType,
/// authentication method for gRPC connections from compute
pub grpc_auth_type: AuthType,
/// Path to a file or directory containing public key(s) for verifying JWT tokens.
/// Used for both mgmt and compute auth, if enabled.
pub auth_validation_public_key_path: Option<Utf8PathBuf>,
@@ -238,12 +231,7 @@ pub struct PageServerConf {
/// This is insecure and should only be used in development environments.
pub dev_mode: bool,
/// PostHog integration config.
pub posthog_config: Option<PostHogConfig>,
pub timeline_import_config: pageserver_api::config::TimelineImportConfig,
pub basebackup_cache_config: Option<pageserver_api::config::BasebackupCacheConfig>,
}
/// Token for authentication to safekeepers
@@ -273,10 +261,6 @@ impl PageServerConf {
self.workdir.join("metadata.json")
}
pub fn basebackup_cache_dir(&self) -> Utf8PathBuf {
self.workdir.join("basebackup_cache")
}
pub fn deletion_list_path(&self, sequence: u64) -> Utf8PathBuf {
// Encode a version in the filename, so that if we ever switch away from JSON we can
// increment this.
@@ -365,7 +349,6 @@ impl PageServerConf {
listen_pg_addr,
listen_http_addr,
listen_https_addr,
listen_grpc_addr,
ssl_key_file,
ssl_cert_file,
ssl_cert_reload_period,
@@ -380,7 +363,6 @@ impl PageServerConf {
pg_distrib_dir,
http_auth_type,
pg_auth_type,
grpc_auth_type,
auth_validation_public_key_path,
remote_storage,
broker_endpoint,
@@ -424,9 +406,7 @@ impl PageServerConf {
tracing,
enable_tls_page_service_api,
dev_mode,
posthog_config,
timeline_import_config,
basebackup_cache_config,
} = config_toml;
let mut conf = PageServerConf {
@@ -436,7 +416,6 @@ impl PageServerConf {
listen_pg_addr,
listen_http_addr,
listen_https_addr,
listen_grpc_addr,
ssl_key_file,
ssl_cert_file,
ssl_cert_reload_period,
@@ -449,7 +428,6 @@ impl PageServerConf {
max_file_descriptors,
http_auth_type,
pg_auth_type,
grpc_auth_type,
auth_validation_public_key_path,
remote_storage_config: remote_storage,
broker_endpoint,
@@ -483,7 +461,6 @@ impl PageServerConf {
enable_tls_page_service_api,
dev_mode,
timeline_import_config,
basebackup_cache_config,
// ------------------------------------------------------------
// fields that require additional validation or custom handling
@@ -540,16 +517,13 @@ impl PageServerConf {
}
None => Vec::new(),
},
posthog_config,
};
// ------------------------------------------------------------
// custom validation code that covers more than one field in isolation
// ------------------------------------------------------------
if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type]
.contains(&AuthType::NeonJWT)
{
if conf.http_auth_type == AuthType::NeonJWT || conf.pg_auth_type == AuthType::NeonJWT {
let auth_validation_public_key_path = conf
.auth_validation_public_key_path
.get_or_insert_with(|| workdir.join("auth_public_key.pem"));
@@ -570,23 +544,6 @@ impl PageServerConf {
ratio.numerator, ratio.denominator
)
);
let url = Url::parse(&tracing_config.export_config.endpoint)
.map_err(anyhow::Error::msg)
.with_context(|| {
format!(
"tracing endpoint URL is invalid : {}",
tracing_config.export_config.endpoint
)
})?;
ensure!(
url.scheme() == "http" || url.scheme() == "https",
format!(
"tracing endpoint URL must start with http:// or https://: {}",
tracing_config.export_config.endpoint
)
);
}
IndexEntry::validate_checkpoint_distance(conf.default_tenant_conf.checkpoint_distance)
@@ -703,25 +660,4 @@ mod tests {
PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir)
.expect("parse_and_validate");
}
#[test]
fn test_config_tracing_endpoint_is_invalid() {
let input = r#"
control_plane_api = "http://localhost:6666"
[tracing]
sampling_ratio = { numerator = 1, denominator = 0 }
[tracing.export_config]
endpoint = "localhost:4317"
protocol = "http-binary"
timeout = "1ms"
"#;
let config_toml = toml_edit::de::from_str::<pageserver_api::config::ConfigToml>(input)
.expect("config has valid fields");
let workdir = Utf8PathBuf::from("/nonexistent");
PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir)
.expect_err("parse_and_validate should fail for endpoint without scheme");
}
}

View File

@@ -18,25 +18,12 @@ use crate::tenant::timeline::logical_size::CurrentLogicalSize;
// management.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub(super) enum Name {
/// Timeline last_record_lsn, absolute.
/// Timeline last_record_lsn, absolute
#[serde(rename = "written_size")]
WrittenSize,
/// Timeline last_record_lsn, incremental
#[serde(rename = "written_data_bytes_delta")]
WrittenSizeDelta,
/// Written bytes only on this timeline (not including ancestors):
/// written_size - ancestor_lsn
///
/// On the root branch, this is equivalent to `written_size`.
#[serde(rename = "written_size_since_parent")]
WrittenSizeSinceParent,
/// PITR history size only on this timeline (not including ancestors):
/// last_record_lsn - max(pitr_cutoff, ancestor_lsn).
///
/// On the root branch, this is its entire PITR history size. Not emitted if GC hasn't computed
/// the PITR cutoff yet. 0 if PITR is disabled.
#[serde(rename = "pitr_history_size_since_parent")]
PitrHistorySizeSinceParent,
/// Timeline logical size
#[serde(rename = "timeline_logical_size")]
LogicalSize,
@@ -170,32 +157,6 @@ impl MetricsKey {
.incremental_values()
}
/// `written_size` - `ancestor_lsn`.
const fn written_size_since_parent(
tenant_id: TenantId,
timeline_id: TimelineId,
) -> AbsoluteValueFactory {
MetricsKey {
tenant_id,
timeline_id: Some(timeline_id),
metric: Name::WrittenSizeSinceParent,
}
.absolute_values()
}
/// `written_size` - max(`pitr_cutoff`, `ancestor_lsn`).
const fn pitr_history_size_since_parent(
tenant_id: TenantId,
timeline_id: TimelineId,
) -> AbsoluteValueFactory {
MetricsKey {
tenant_id,
timeline_id: Some(timeline_id),
metric: Name::PitrHistorySizeSinceParent,
}
.absolute_values()
}
/// Exact [`Timeline::get_current_logical_size`].
///
/// [`Timeline::get_current_logical_size`]: crate::tenant::Timeline::get_current_logical_size
@@ -373,13 +334,7 @@ impl TenantSnapshot {
struct TimelineSnapshot {
loaded_at: (Lsn, SystemTime),
last_record_lsn: Lsn,
ancestor_lsn: Lsn,
current_exact_logical_size: Option<u64>,
/// Whether PITR is enabled (pitr_interval > 0).
pitr_enabled: bool,
/// The PITR cutoff LSN. None if not yet initialized. If PITR is disabled, this is approximately
/// Some(last_record_lsn), but may lag behind it since it's computed periodically.
pitr_cutoff: Option<Lsn>,
}
impl TimelineSnapshot {
@@ -399,9 +354,6 @@ impl TimelineSnapshot {
} else {
let loaded_at = t.loaded_at;
let last_record_lsn = t.get_last_record_lsn();
let ancestor_lsn = t.get_ancestor_lsn();
let pitr_enabled = !t.get_pitr_interval().is_zero();
let pitr_cutoff = t.gc_info.read().unwrap().cutoffs.time;
let current_exact_logical_size = {
let span = tracing::info_span!("collect_metrics_iteration", tenant_id = %t.tenant_shard_id.tenant_id, timeline_id = %t.timeline_id);
@@ -421,10 +373,7 @@ impl TimelineSnapshot {
Ok(Some(TimelineSnapshot {
loaded_at,
last_record_lsn,
ancestor_lsn,
current_exact_logical_size,
pitr_enabled,
pitr_cutoff,
}))
}
}
@@ -475,8 +424,6 @@ impl TimelineSnapshot {
let up_to = now;
let written_size_last = written_size_now.value.max(prev.1); // don't regress
if let Some(delta) = written_size_now.value.checked_sub(prev.1) {
let key_value = written_size_delta_key.from_until(prev.0, up_to, delta);
// written_size_delta
@@ -494,27 +441,6 @@ impl TimelineSnapshot {
});
}
// Compute the branch-local written size.
let written_size_since_parent_key =
MetricsKey::written_size_since_parent(tenant_id, timeline_id);
metrics.push(
written_size_since_parent_key
.at(now, written_size_last.saturating_sub(self.ancestor_lsn.0)),
);
// Compute the branch-local PITR history size. Not emitted if GC hasn't yet computed the
// PITR cutoff. 0 if PITR is disabled.
let pitr_history_size_since_parent_key =
MetricsKey::pitr_history_size_since_parent(tenant_id, timeline_id);
if !self.pitr_enabled {
metrics.push(pitr_history_size_since_parent_key.at(now, 0));
} else if let Some(pitr_cutoff) = self.pitr_cutoff {
metrics.push(pitr_history_size_since_parent_key.at(
now,
written_size_last.saturating_sub(pitr_cutoff.max(self.ancestor_lsn).0),
));
}
{
let factory = MetricsKey::timeline_logical_size(tenant_id, timeline_id);
let current_or_previous = self

View File

@@ -12,17 +12,12 @@ fn startup_collected_timeline_metrics_before_advancing() {
let cache = HashMap::new();
let initdb_lsn = Lsn(0x10000);
let pitr_cutoff = Lsn(0x11000);
let disk_consistent_lsn = Lsn(initdb_lsn.0 * 2);
let logical_size = 0x42000;
let snap = TimelineSnapshot {
loaded_at: (disk_consistent_lsn, SystemTime::now()),
last_record_lsn: disk_consistent_lsn,
ancestor_lsn: Lsn(0),
current_exact_logical_size: Some(logical_size),
pitr_enabled: true,
pitr_cutoff: Some(pitr_cutoff),
current_exact_logical_size: Some(0x42000),
};
let now = DateTime::<Utc>::from(SystemTime::now());
@@ -38,11 +33,7 @@ fn startup_collected_timeline_metrics_before_advancing() {
0
),
MetricsKey::written_size(tenant_id, timeline_id).at(now, disk_consistent_lsn.0),
MetricsKey::written_size_since_parent(tenant_id, timeline_id)
.at(now, disk_consistent_lsn.0),
MetricsKey::pitr_history_size_since_parent(tenant_id, timeline_id)
.at(now, disk_consistent_lsn.0 - pitr_cutoff.0),
MetricsKey::timeline_logical_size(tenant_id, timeline_id).at(now, logical_size)
MetricsKey::timeline_logical_size(tenant_id, timeline_id).at(now, 0x42000)
]
);
}
@@ -58,9 +49,7 @@ fn startup_collected_timeline_metrics_second_round() {
let before = DateTime::<Utc>::from(before);
let initdb_lsn = Lsn(0x10000);
let pitr_cutoff = Lsn(0x11000);
let disk_consistent_lsn = Lsn(initdb_lsn.0 * 2);
let logical_size = 0x42000;
let mut metrics = Vec::new();
let cache = HashMap::from([MetricsKey::written_size(tenant_id, timeline_id)
@@ -70,10 +59,7 @@ fn startup_collected_timeline_metrics_second_round() {
let snap = TimelineSnapshot {
loaded_at: (disk_consistent_lsn, init),
last_record_lsn: disk_consistent_lsn,
ancestor_lsn: Lsn(0),
current_exact_logical_size: Some(logical_size),
pitr_enabled: true,
pitr_cutoff: Some(pitr_cutoff),
current_exact_logical_size: Some(0x42000),
};
snap.to_metrics(tenant_id, timeline_id, now, &mut metrics, &cache);
@@ -83,11 +69,7 @@ fn startup_collected_timeline_metrics_second_round() {
&[
MetricsKey::written_size_delta(tenant_id, timeline_id).from_until(before, now, 0),
MetricsKey::written_size(tenant_id, timeline_id).at(now, disk_consistent_lsn.0),
MetricsKey::written_size_since_parent(tenant_id, timeline_id)
.at(now, disk_consistent_lsn.0),
MetricsKey::pitr_history_size_since_parent(tenant_id, timeline_id)
.at(now, disk_consistent_lsn.0 - pitr_cutoff.0),
MetricsKey::timeline_logical_size(tenant_id, timeline_id).at(now, logical_size)
MetricsKey::timeline_logical_size(tenant_id, timeline_id).at(now, 0x42000)
]
);
}
@@ -104,9 +86,7 @@ fn startup_collected_timeline_metrics_nth_round_at_same_lsn() {
let before = DateTime::<Utc>::from(before);
let initdb_lsn = Lsn(0x10000);
let pitr_cutoff = Lsn(0x11000);
let disk_consistent_lsn = Lsn(initdb_lsn.0 * 2);
let logical_size = 0x42000;
let mut metrics = Vec::new();
let cache = HashMap::from([
@@ -123,10 +103,7 @@ fn startup_collected_timeline_metrics_nth_round_at_same_lsn() {
let snap = TimelineSnapshot {
loaded_at: (disk_consistent_lsn, init),
last_record_lsn: disk_consistent_lsn,
ancestor_lsn: Lsn(0),
current_exact_logical_size: Some(logical_size),
pitr_enabled: true,
pitr_cutoff: Some(pitr_cutoff),
current_exact_logical_size: Some(0x42000),
};
snap.to_metrics(tenant_id, timeline_id, now, &mut metrics, &cache);
@@ -136,18 +113,16 @@ fn startup_collected_timeline_metrics_nth_round_at_same_lsn() {
&[
MetricsKey::written_size_delta(tenant_id, timeline_id).from_until(just_before, now, 0),
MetricsKey::written_size(tenant_id, timeline_id).at(now, disk_consistent_lsn.0),
MetricsKey::written_size_since_parent(tenant_id, timeline_id)
.at(now, disk_consistent_lsn.0),
MetricsKey::pitr_history_size_since_parent(tenant_id, timeline_id)
.at(now, disk_consistent_lsn.0 - pitr_cutoff.0),
MetricsKey::timeline_logical_size(tenant_id, timeline_id).at(now, logical_size)
MetricsKey::timeline_logical_size(tenant_id, timeline_id).at(now, 0x42000)
]
);
}
/// Tests that written sizes do not regress across restarts.
#[test]
fn post_restart_written_sizes_with_rolled_back_last_record_lsn() {
// it can happen that we lose the inmemorylayer but have previously sent metrics and we
// should never go backwards
let tenant_id = TenantId::generate();
let timeline_id = TimelineId::generate();
@@ -165,10 +140,7 @@ fn post_restart_written_sizes_with_rolled_back_last_record_lsn() {
let snap = TimelineSnapshot {
loaded_at: (Lsn(50), at_restart),
last_record_lsn: Lsn(50),
ancestor_lsn: Lsn(0),
current_exact_logical_size: None,
pitr_enabled: true,
pitr_cutoff: Some(Lsn(20)),
};
let mut cache = HashMap::from([
@@ -197,8 +169,6 @@ fn post_restart_written_sizes_with_rolled_back_last_record_lsn() {
0
),
MetricsKey::written_size(tenant_id, timeline_id).at(now, 100),
MetricsKey::written_size_since_parent(tenant_id, timeline_id).at(now, 100),
MetricsKey::pitr_history_size_since_parent(tenant_id, timeline_id).at(now, 80),
]
);
@@ -213,157 +183,6 @@ fn post_restart_written_sizes_with_rolled_back_last_record_lsn() {
&[
MetricsKey::written_size_delta(tenant_id, timeline_id).from_until(now, later, 0),
MetricsKey::written_size(tenant_id, timeline_id).at(later, 100),
MetricsKey::written_size_since_parent(tenant_id, timeline_id).at(later, 100),
MetricsKey::pitr_history_size_since_parent(tenant_id, timeline_id).at(later, 80),
]
);
}
/// Tests that written sizes do not regress across restarts, even on child branches.
#[test]
fn post_restart_written_sizes_with_rolled_back_last_record_lsn_and_ancestor_lsn() {
let tenant_id = TenantId::generate();
let timeline_id = TimelineId::generate();
let [later, now, at_restart] = time_backwards();
// FIXME: tests would be so much easier if we did not need to juggle back and forth
// SystemTime and DateTime::<Utc> ... Could do the conversion only at upload time?
let now = DateTime::<Utc>::from(now);
let later = DateTime::<Utc>::from(later);
let before_restart = at_restart - std::time::Duration::from_secs(5 * 60);
let way_before = before_restart - std::time::Duration::from_secs(10 * 60);
let before_restart = DateTime::<Utc>::from(before_restart);
let way_before = DateTime::<Utc>::from(way_before);
let snap = TimelineSnapshot {
loaded_at: (Lsn(50), at_restart),
last_record_lsn: Lsn(50),
ancestor_lsn: Lsn(40),
current_exact_logical_size: None,
pitr_enabled: true,
pitr_cutoff: Some(Lsn(20)),
};
let mut cache = HashMap::from([
MetricsKey::written_size(tenant_id, timeline_id)
.at(before_restart, 100)
.to_kv_pair(),
MetricsKey::written_size_delta(tenant_id, timeline_id)
.from_until(
way_before,
before_restart,
// not taken into account, but the timestamps are important
999_999_999,
)
.to_kv_pair(),
]);
let mut metrics = Vec::new();
snap.to_metrics(tenant_id, timeline_id, now, &mut metrics, &cache);
assert_eq!(
metrics,
&[
MetricsKey::written_size_delta(tenant_id, timeline_id).from_until(
before_restart,
now,
0
),
MetricsKey::written_size(tenant_id, timeline_id).at(now, 100),
MetricsKey::written_size_since_parent(tenant_id, timeline_id).at(now, 60),
MetricsKey::pitr_history_size_since_parent(tenant_id, timeline_id).at(now, 60),
]
);
// now if we cache these metrics, and re-run while "still in recovery"
cache.extend(metrics.drain(..).map(|x| x.to_kv_pair()));
// "still in recovery", because our snapshot did not change
snap.to_metrics(tenant_id, timeline_id, later, &mut metrics, &cache);
assert_eq!(
metrics,
&[
MetricsKey::written_size_delta(tenant_id, timeline_id).from_until(now, later, 0),
MetricsKey::written_size(tenant_id, timeline_id).at(later, 100),
MetricsKey::written_size_since_parent(tenant_id, timeline_id).at(later, 60),
MetricsKey::pitr_history_size_since_parent(tenant_id, timeline_id).at(later, 60),
]
);
}
/// Tests that written sizes do not regress across restarts, even on child branches and
/// with a PITR cutoff after the branch point.
#[test]
fn post_restart_written_sizes_with_rolled_back_last_record_lsn_and_ancestor_lsn_and_pitr_cutoff() {
let tenant_id = TenantId::generate();
let timeline_id = TimelineId::generate();
let [later, now, at_restart] = time_backwards();
// FIXME: tests would be so much easier if we did not need to juggle back and forth
// SystemTime and DateTime::<Utc> ... Could do the conversion only at upload time?
let now = DateTime::<Utc>::from(now);
let later = DateTime::<Utc>::from(later);
let before_restart = at_restart - std::time::Duration::from_secs(5 * 60);
let way_before = before_restart - std::time::Duration::from_secs(10 * 60);
let before_restart = DateTime::<Utc>::from(before_restart);
let way_before = DateTime::<Utc>::from(way_before);
let snap = TimelineSnapshot {
loaded_at: (Lsn(50), at_restart),
last_record_lsn: Lsn(50),
ancestor_lsn: Lsn(30),
current_exact_logical_size: None,
pitr_enabled: true,
pitr_cutoff: Some(Lsn(40)),
};
let mut cache = HashMap::from([
MetricsKey::written_size(tenant_id, timeline_id)
.at(before_restart, 100)
.to_kv_pair(),
MetricsKey::written_size_delta(tenant_id, timeline_id)
.from_until(
way_before,
before_restart,
// not taken into account, but the timestamps are important
999_999_999,
)
.to_kv_pair(),
]);
let mut metrics = Vec::new();
snap.to_metrics(tenant_id, timeline_id, now, &mut metrics, &cache);
assert_eq!(
metrics,
&[
MetricsKey::written_size_delta(tenant_id, timeline_id).from_until(
before_restart,
now,
0
),
MetricsKey::written_size(tenant_id, timeline_id).at(now, 100),
MetricsKey::written_size_since_parent(tenant_id, timeline_id).at(now, 70),
MetricsKey::pitr_history_size_since_parent(tenant_id, timeline_id).at(now, 60),
]
);
// now if we cache these metrics, and re-run while "still in recovery"
cache.extend(metrics.drain(..).map(|x| x.to_kv_pair()));
// "still in recovery", because our snapshot did not change
snap.to_metrics(tenant_id, timeline_id, later, &mut metrics, &cache);
assert_eq!(
metrics,
&[
MetricsKey::written_size_delta(tenant_id, timeline_id).from_until(now, later, 0),
MetricsKey::written_size(tenant_id, timeline_id).at(later, 100),
MetricsKey::written_size_since_parent(tenant_id, timeline_id).at(later, 70),
MetricsKey::pitr_history_size_since_parent(tenant_id, timeline_id).at(later, 60),
]
);
}
@@ -382,10 +201,7 @@ fn post_restart_current_exact_logical_size_uses_cached() {
let snap = TimelineSnapshot {
loaded_at: (Lsn(50), at_restart),
last_record_lsn: Lsn(50),
ancestor_lsn: Lsn(0),
current_exact_logical_size: None,
pitr_enabled: true,
pitr_cutoff: None,
};
let cache = HashMap::from([MetricsKey::timeline_logical_size(tenant_id, timeline_id)
@@ -470,101 +286,16 @@ fn time_backwards<const N: usize>() -> [std::time::SystemTime; N] {
times
}
/// Tests that disabled PITR history does not yield any history size, even when the PITR cutoff
/// indicates otherwise.
#[test]
fn pitr_disabled_yields_no_history_size() {
let tenant_id = TenantId::generate();
let timeline_id = TimelineId::generate();
let mut metrics = Vec::new();
let cache = HashMap::new();
let initdb_lsn = Lsn(0x10000);
let pitr_cutoff = Lsn(0x11000);
let disk_consistent_lsn = Lsn(initdb_lsn.0 * 2);
let snap = TimelineSnapshot {
loaded_at: (disk_consistent_lsn, SystemTime::now()),
last_record_lsn: disk_consistent_lsn,
ancestor_lsn: Lsn(0),
current_exact_logical_size: None,
pitr_enabled: false,
pitr_cutoff: Some(pitr_cutoff),
};
let now = DateTime::<Utc>::from(SystemTime::now());
snap.to_metrics(tenant_id, timeline_id, now, &mut metrics, &cache);
assert_eq!(
metrics,
&[
MetricsKey::written_size_delta(tenant_id, timeline_id).from_until(
snap.loaded_at.1.into(),
now,
0
),
MetricsKey::written_size(tenant_id, timeline_id).at(now, disk_consistent_lsn.0),
MetricsKey::written_size_since_parent(tenant_id, timeline_id)
.at(now, disk_consistent_lsn.0),
MetricsKey::pitr_history_size_since_parent(tenant_id, timeline_id).at(now, 0),
]
);
}
/// Tests that uninitialized PITR cutoff does not emit any history size metric at all.
#[test]
fn pitr_uninitialized_does_not_emit_history_size() {
let tenant_id = TenantId::generate();
let timeline_id = TimelineId::generate();
let mut metrics = Vec::new();
let cache = HashMap::new();
let initdb_lsn = Lsn(0x10000);
let disk_consistent_lsn = Lsn(initdb_lsn.0 * 2);
let snap = TimelineSnapshot {
loaded_at: (disk_consistent_lsn, SystemTime::now()),
last_record_lsn: disk_consistent_lsn,
ancestor_lsn: Lsn(0),
current_exact_logical_size: None,
pitr_enabled: true,
pitr_cutoff: None,
};
let now = DateTime::<Utc>::from(SystemTime::now());
snap.to_metrics(tenant_id, timeline_id, now, &mut metrics, &cache);
assert_eq!(
metrics,
&[
MetricsKey::written_size_delta(tenant_id, timeline_id).from_until(
snap.loaded_at.1.into(),
now,
0
),
MetricsKey::written_size(tenant_id, timeline_id).at(now, disk_consistent_lsn.0),
MetricsKey::written_size_since_parent(tenant_id, timeline_id)
.at(now, disk_consistent_lsn.0),
]
);
}
pub(crate) const fn metric_examples_old(
tenant_id: TenantId,
timeline_id: TimelineId,
now: DateTime<Utc>,
before: DateTime<Utc>,
) -> [RawMetric; 7] {
) -> [RawMetric; 5] {
[
MetricsKey::written_size(tenant_id, timeline_id).at_old_format(now, 0),
MetricsKey::written_size_delta(tenant_id, timeline_id)
.from_until_old_format(before, now, 0),
MetricsKey::written_size_since_parent(tenant_id, timeline_id).at_old_format(now, 0),
MetricsKey::pitr_history_size_since_parent(tenant_id, timeline_id).at_old_format(now, 0),
MetricsKey::timeline_logical_size(tenant_id, timeline_id).at_old_format(now, 0),
MetricsKey::remote_storage_size(tenant_id).at_old_format(now, 0),
MetricsKey::synthetic_size(tenant_id).at_old_format(now, 1),
@@ -576,12 +307,10 @@ pub(crate) const fn metric_examples(
timeline_id: TimelineId,
now: DateTime<Utc>,
before: DateTime<Utc>,
) -> [NewRawMetric; 7] {
) -> [NewRawMetric; 5] {
[
MetricsKey::written_size(tenant_id, timeline_id).at(now, 0),
MetricsKey::written_size_delta(tenant_id, timeline_id).from_until(before, now, 0),
MetricsKey::written_size_since_parent(tenant_id, timeline_id).at(now, 0),
MetricsKey::pitr_history_size_since_parent(tenant_id, timeline_id).at(now, 0),
MetricsKey::timeline_logical_size(tenant_id, timeline_id).at(now, 0),
MetricsKey::remote_storage_size(tenant_id).at(now, 0),
MetricsKey::synthetic_size(tenant_id).at(now, 1),

View File

@@ -513,14 +513,6 @@ mod tests {
line!(),
r#"{"type":"incremental","start_time":"2023-09-14T00:00:00.123456789Z","stop_time":"2023-09-15T00:00:00.123456789Z","metric":"written_data_bytes_delta","idempotency_key":"2023-09-15 00:00:00.123456789 UTC-1-0000","value":0,"tenant_id":"00000000000000000000000000000000","timeline_id":"ffffffffffffffffffffffffffffffff"}"#,
),
(
line!(),
r#"{"type":"absolute","time":"2023-09-15T00:00:00.123456789Z","metric":"written_size_since_parent","idempotency_key":"2023-09-15 00:00:00.123456789 UTC-1-0000","value":0,"tenant_id":"00000000000000000000000000000000","timeline_id":"ffffffffffffffffffffffffffffffff"}"#,
),
(
line!(),
r#"{"type":"absolute","time":"2023-09-15T00:00:00.123456789Z","metric":"pitr_history_size_since_parent","idempotency_key":"2023-09-15 00:00:00.123456789 UTC-1-0000","value":0,"tenant_id":"00000000000000000000000000000000","timeline_id":"ffffffffffffffffffffffffffffffff"}"#,
),
(
line!(),
r#"{"type":"absolute","time":"2023-09-15T00:00:00.123456789Z","metric":"timeline_logical_size","idempotency_key":"2023-09-15 00:00:00.123456789 UTC-1-0000","value":0,"tenant_id":"00000000000000000000000000000000","timeline_id":"ffffffffffffffffffffffffffffffff"}"#,
@@ -568,7 +560,7 @@ mod tests {
assert_eq!(upgraded_samples, new_samples);
}
fn metric_samples_old() -> [RawMetric; 7] {
fn metric_samples_old() -> [RawMetric; 5] {
let tenant_id = TenantId::from_array([0; 16]);
let timeline_id = TimelineId::from_array([0xff; 16]);
@@ -580,7 +572,7 @@ mod tests {
super::super::metrics::metric_examples_old(tenant_id, timeline_id, now, before)
}
fn metric_samples() -> [NewRawMetric; 7] {
fn metric_samples() -> [NewRawMetric; 5] {
let tenant_id = TenantId::from_array([0; 16]);
let timeline_id = TimelineId::from_array([0xff; 16]);

View File

@@ -1,94 +0,0 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use posthog_client_lite::{
FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError,
};
use tokio_util::sync::CancellationToken;
use utils::id::TenantId;
use crate::config::PageServerConf;
#[derive(Clone)]
pub struct FeatureResolver {
inner: Option<Arc<FeatureResolverBackgroundLoop>>,
}
impl FeatureResolver {
pub fn new_disabled() -> Self {
Self { inner: None }
}
pub fn spawn(
conf: &PageServerConf,
shutdown_pageserver: CancellationToken,
handle: &tokio::runtime::Handle,
) -> anyhow::Result<Self> {
// DO NOT block in this function: make it return as fast as possible to avoid startup delays.
if let Some(posthog_config) = &conf.posthog_config {
let inner = FeatureResolverBackgroundLoop::new(
PostHogClientConfig {
server_api_key: posthog_config.server_api_key.clone(),
client_api_key: posthog_config.client_api_key.clone(),
project_id: posthog_config.project_id.clone(),
private_api_url: posthog_config.private_api_url.clone(),
public_api_url: posthog_config.public_api_url.clone(),
},
shutdown_pageserver,
);
let inner = Arc::new(inner);
// TODO: make this configurable
inner.clone().spawn(handle, Duration::from_secs(60));
Ok(FeatureResolver { inner: Some(inner) })
} else {
Ok(FeatureResolver { inner: None })
}
}
/// Evaluate a multivariate feature flag. Currently, we do not support any properties.
///
/// Error handling: the caller should inspect the error and decide the behavior when a feature flag
/// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be
/// propagated beyond where the feature flag gets resolved.
pub fn evaluate_multivariate(
&self,
flag_key: &str,
tenant_id: TenantId,
) -> Result<String, PostHogEvaluationError> {
if let Some(inner) = &self.inner {
inner.feature_store().evaluate_multivariate(
flag_key,
&tenant_id.to_string(),
&HashMap::new(),
)
} else {
Err(PostHogEvaluationError::NotAvailable(
"PostHog integration is not enabled".to_string(),
))
}
}
/// Evaluate a boolean feature flag. Currently, we do not support any properties.
///
/// Returns `Ok(())` if the flag is evaluated to true, otherwise returns an error.
///
/// Error handling: the caller should inspect the error and decide the behavior when a feature flag
/// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be
/// propagated beyond where the feature flag gets resolved.
pub fn evaluate_boolean(
&self,
flag_key: &str,
tenant_id: TenantId,
) -> Result<(), PostHogEvaluationError> {
if let Some(inner) = &self.inner {
inner.feature_store().evaluate_boolean(
flag_key,
&tenant_id.to_string(),
&HashMap::new(),
)
} else {
Err(PostHogEvaluationError::NotAvailable(
"PostHog integration is not enabled".to_string(),
))
}
}
}

View File

@@ -353,33 +353,6 @@ paths:
"200":
description: OK
/v1/tenant/{tenant_shard_id}/timeline/{timeline_id}/mark_invisible:
parameters:
- name: tenant_shard_id
in: path
required: true
schema:
type: string
- name: timeline_id
in: path
required: true
schema:
type: string
format: hex
put:
requestBody:
content:
application/json:
schema:
type: object
properties:
is_visible:
type: boolean
default: false
responses:
"200":
description: OK
/v1/tenant/{tenant_shard_id}/location_config:
parameters:
- name: tenant_shard_id
@@ -653,8 +626,6 @@ paths:
format: hex
pg_version:
type: integer
read_only:
type: boolean
existing_initdb_timeline_id:
type: string
format: hex

View File

@@ -370,18 +370,6 @@ impl From<crate::tenant::secondary::SecondaryTenantError> for ApiError {
}
}
impl From<crate::tenant::FinalizeTimelineImportError> for ApiError {
fn from(err: crate::tenant::FinalizeTimelineImportError) -> ApiError {
use crate::tenant::FinalizeTimelineImportError::*;
match err {
ImportTaskStillRunning => {
ApiError::ResourceUnavailable("Import task still running".into())
}
ShuttingDown => ApiError::ShuttingDown,
}
}
}
// Helper function to construct a TimelineInfo struct for a timeline
async fn build_timeline_info(
timeline: &Arc<Timeline>,
@@ -461,7 +449,7 @@ async fn build_timeline_info_common(
// Internally we distinguish between the planned GC cutoff (PITR point) and the "applied" GC cutoff (where we
// actually trimmed data to), which can pass each other when PITR is changed.
let min_readable_lsn = std::cmp::max(
timeline.get_gc_cutoff_lsn().unwrap_or_default(),
timeline.get_gc_cutoff_lsn(),
*timeline.get_applied_gc_cutoff_lsn(),
);
@@ -584,7 +572,6 @@ async fn timeline_create_handler(
TimelineCreateRequestMode::Branch {
ancestor_timeline_id,
ancestor_start_lsn,
read_only: _,
pg_version: _,
} => tenant::CreateTimelineParams::Branch(tenant::CreateTimelineParamsBranch {
new_timeline_id,
@@ -3545,7 +3532,10 @@ async fn activate_post_import_handler(
tenant.wait_to_become_active(ACTIVE_TENANT_TIMEOUT).await?;
tenant.finalize_importing_timeline(timeline_id).await?;
tenant
.finalize_importing_timeline(timeline_id)
.await
.map_err(ApiError::InternalServerError)?;
match tenant.get_timeline(timeline_id, false) {
Ok(_timeline) => {

View File

@@ -3,14 +3,12 @@
mod auth;
pub mod basebackup;
pub mod basebackup_cache;
pub mod config;
pub mod consumption_metrics;
pub mod context;
pub mod controller_upcall_client;
pub mod deletion_queue;
pub mod disk_usage_eviction_task;
pub mod feature_resolver;
pub mod http;
pub mod import_datadir;
pub mod l0_flush;
@@ -85,7 +83,6 @@ pub async fn shutdown_pageserver(
http_listener: HttpEndpointListener,
https_listener: Option<HttpsEndpointListener>,
page_service: page_service::Listener,
grpc_task: Option<CancellableTask>,
consumption_metrics_worker: ConsumptionMetricsTasks,
disk_usage_eviction_task: Option<DiskUsageEvictionTask>,
tenant_manager: &TenantManager,
@@ -179,16 +176,6 @@ pub async fn shutdown_pageserver(
)
.await;
// Shut down the gRPC server task, including request handlers.
if let Some(grpc_task) = grpc_task {
timed(
grpc_task.shutdown(),
"shutdown gRPC PageRequestHandler",
Duration::from_secs(3),
)
.await;
}
// Shut down all the tenants. This flushes everything to disk and kills
// the checkpoint and GC tasks.
timed(

View File

@@ -1066,15 +1066,6 @@ pub(crate) static TENANT_SYNTHETIC_SIZE_METRIC: Lazy<UIntGaugeVec> = Lazy::new(|
.expect("Failed to register pageserver_tenant_synthetic_cached_size_bytes metric")
});
pub(crate) static TENANT_OFFLOADED_TIMELINES: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"pageserver_tenant_offloaded_timelines",
"Number of offloaded timelines of a tenant",
&["tenant_id", "shard_id"]
)
.expect("Failed to register pageserver_tenant_offloaded_timelines metric")
});
pub(crate) static EVICTION_ITERATION_DURATION: Lazy<HistogramVec> = Lazy::new(|| {
register_histogram_vec!(
"pageserver_eviction_iteration_duration_seconds_global",
@@ -2234,10 +2225,8 @@ impl BasebackupQueryTimeOngoingRecording<'_> {
// If you want to change categorize of a specific error, also change it in `log_query_error`.
let metric = match res {
Ok(_) => &self.parent.ok,
Err(QueryError::Shutdown) | Err(QueryError::Reconnect) => {
// Do not observe ok/err for shutdown/reconnect.
// Reconnect error might be raised when the operation is waiting for LSN and the tenant shutdown interrupts
// the operation. A reconnect error will be issued and the client will retry.
Err(QueryError::Shutdown) => {
// Do not observe ok/err for shutdown
return;
}
Err(QueryError::Disconnected(ConnectionError::Io(io_error)))
@@ -3562,14 +3551,11 @@ impl TimelineMetrics {
}
pub(crate) fn remove_tenant_metrics(tenant_shard_id: &TenantShardId) {
let tid = tenant_shard_id.tenant_id.to_string();
let shard_id = tenant_shard_id.shard_slug().to_string();
// Only shard zero deals in synthetic sizes
if tenant_shard_id.is_shard_zero() {
let tid = tenant_shard_id.tenant_id.to_string();
let _ = TENANT_SYNTHETIC_SIZE_METRIC.remove_label_values(&[&tid]);
}
let _ = TENANT_OFFLOADED_TIMELINES.remove_label_values(&[&tid, &shard_id]);
tenant_throttling::remove_tenant_metrics(tenant_shard_id);
@@ -4361,42 +4347,6 @@ pub(crate) fn set_tokio_runtime_setup(setup: &str, num_threads: NonZeroUsize) {
.set(u64::try_from(num_threads.get()).unwrap());
}
pub(crate) static BASEBACKUP_CACHE_READ: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"pageserver_basebackup_cache_read_total",
"Number of read accesses to the basebackup cache grouped by hit/miss/error",
&["result"]
)
.expect("failed to define a metric")
});
pub(crate) static BASEBACKUP_CACHE_PREPARE: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"pageserver_basebackup_cache_prepare_total",
"Number of prepare requests processed by the basebackup cache grouped by ok/skip/error",
&["result"]
)
.expect("failed to define a metric")
});
pub(crate) static BASEBACKUP_CACHE_ENTRIES: Lazy<IntGauge> = Lazy::new(|| {
register_int_gauge!(
"pageserver_basebackup_cache_entries_total",
"Number of entries in the basebackup cache"
)
.expect("failed to define a metric")
});
// FIXME: Support basebackup cache size metrics.
#[allow(dead_code)]
pub(crate) static BASEBACKUP_CACHE_SIZE: Lazy<IntGauge> = Lazy::new(|| {
register_int_gauge!(
"pageserver_basebackup_cache_size_bytes",
"Total size of all basebackup cache entries on disk in bytes"
)
.expect("failed to define a metric")
});
static PAGESERVER_CONFIG_IGNORED_ITEMS: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"pageserver_config_ignored_items",

View File

@@ -4,16 +4,16 @@
use std::borrow::Cow;
use std::num::NonZeroUsize;
use std::os::fd::AsRawFd;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime};
use std::{io, str};
use crate::PERF_TRACE_TARGET;
use anyhow::{Context, bail};
use async_compression::tokio::write::GzipEncoder;
use bytes::Buf;
use futures::{FutureExt, Stream};
use futures::FutureExt;
use itertools::Itertools;
use jsonwebtoken::TokenData;
use once_cell::sync::OnceCell;
@@ -31,7 +31,6 @@ use pageserver_api::models::{
};
use pageserver_api::reltag::SlruKind;
use pageserver_api::shard::TenantShardId;
use pageserver_page_api::proto;
use postgres_backend::{
AuthType, PostgresBackend, PostgresBackendReader, QueryError, is_expected_io_error,
};
@@ -43,21 +42,18 @@ use strum_macros::IntoStaticStr;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufWriter};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tonic::service::Interceptor as _;
use tracing::*;
use utils::auth::{Claims, Scope, SwappableJwtAuth};
use utils::failpoint_support;
use utils::id::{TenantId, TenantTimelineId, TimelineId};
use utils::id::{TenantId, TimelineId};
use utils::logging::log_slow;
use utils::lsn::Lsn;
use utils::shard::ShardIndex;
use utils::simple_rcu::RcuReadGuard;
use utils::sync::gate::{Gate, GateGuard};
use utils::sync::spsc_fold;
use crate::auth::check_permission;
use crate::basebackup::{self, BasebackupError};
use crate::basebackup_cache::BasebackupCache;
use crate::basebackup::BasebackupError;
use crate::config::PageServerConf;
use crate::context::{
DownloadBehavior, PerfInstrumentFutureExt, RequestContext, RequestContextBuilder,
@@ -78,7 +74,7 @@ use crate::tenant::mgr::{
use crate::tenant::storage_layer::IoConcurrency;
use crate::tenant::timeline::{self, WaitLsnError};
use crate::tenant::{GetTimelineError, PageReconstructError, Timeline};
use crate::{CancellableTask, PERF_TRACE_TARGET, timed_after_cancellation};
use crate::{basebackup, timed_after_cancellation};
/// How long we may wait for a [`crate::tenant::mgr::TenantSlot::InProgress`]` and/or a [`crate::tenant::TenantShard`] which
/// is not yet in state [`TenantState::Active`].
@@ -89,26 +85,6 @@ const ACTIVE_TENANT_TIMEOUT: Duration = Duration::from_millis(30000);
/// Threshold at which to log slow GetPage requests.
const LOG_SLOW_GETPAGE_THRESHOLD: Duration = Duration::from_secs(30);
/// The idle time before sending TCP keepalive probes for gRPC connections. The
/// interval and timeout between each probe is configured via sysctl. This
/// allows detecting dead connections sooner.
const GRPC_TCP_KEEPALIVE_TIME: Duration = Duration::from_secs(60);
/// Whether to enable TCP nodelay for gRPC connections. This disables Nagle's
/// algorithm, which can cause latency spikes for small messages.
const GRPC_TCP_NODELAY: bool = true;
/// The interval between HTTP2 keepalive pings. This allows shutting down server
/// tasks when clients are unresponsive.
const GRPC_HTTP2_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30);
/// The timeout for HTTP2 keepalive pings. Should be <= GRPC_KEEPALIVE_INTERVAL.
const GRPC_HTTP2_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(20);
/// Number of concurrent gRPC streams per TCP connection. We expect something
/// like 8 GetPage streams per connections, plus any unary requests.
const GRPC_MAX_CONCURRENT_STREAMS: u32 = 256;
///////////////////////////////////////////////////////////////////////////////
pub struct Listener {
@@ -131,7 +107,6 @@ pub fn spawn(
perf_trace_dispatch: Option<Dispatch>,
tcp_listener: tokio::net::TcpListener,
tls_config: Option<Arc<rustls::ServerConfig>>,
basebackup_cache: Arc<BasebackupCache>,
) -> Listener {
let cancel = CancellationToken::new();
let libpq_ctx = RequestContext::todo_child(
@@ -153,7 +128,6 @@ pub fn spawn(
conf.pg_auth_type,
tls_config,
conf.page_service_pipelining.clone(),
basebackup_cache,
libpq_ctx,
cancel.clone(),
)
@@ -163,94 +137,6 @@ pub fn spawn(
Listener { cancel, task }
}
/// Spawns a gRPC server for the page service.
///
/// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we
/// need to reimplement the TCP+TLS accept loop ourselves.
pub fn spawn_grpc(
conf: &'static PageServerConf,
tenant_manager: Arc<TenantManager>,
auth: Option<Arc<SwappableJwtAuth>>,
perf_trace_dispatch: Option<Dispatch>,
listener: std::net::TcpListener,
basebackup_cache: Arc<BasebackupCache>,
) -> anyhow::Result<CancellableTask> {
let cancel = CancellationToken::new();
let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler)
.download_behavior(DownloadBehavior::Download)
.perf_span_dispatch(perf_trace_dispatch)
.detached_child();
let gate = Gate::default();
// Set up the TCP socket. We take a preconfigured TcpListener to bind the
// port early during startup.
let incoming = {
let _runtime = COMPUTE_REQUEST_RUNTIME.enter(); // required by TcpListener::from_std
listener.set_nonblocking(true)?;
tonic::transport::server::TcpIncoming::from(tokio::net::TcpListener::from_std(listener)?)
.with_nodelay(Some(GRPC_TCP_NODELAY))
.with_keepalive(Some(GRPC_TCP_KEEPALIVE_TIME))
};
// Set up the gRPC server.
//
// TODO: consider tuning window sizes.
// TODO: wire up tracing.
let mut server = tonic::transport::Server::builder()
.http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL))
.http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT))
.max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS));
// Main page service.
let page_service_handler = PageServerHandler::new(
tenant_manager,
auth.clone(),
PageServicePipeliningConfig::Serial, // TODO: unused with gRPC
conf.get_vectored_concurrent_io,
ConnectionPerfSpanFields::default(),
basebackup_cache,
ctx,
cancel.clone(),
gate.enter().expect("just created"),
);
let mut tenant_interceptor = TenantMetadataInterceptor;
let mut auth_interceptor = TenantAuthInterceptor::new(auth);
let interceptors = move |mut req: tonic::Request<()>| {
req = tenant_interceptor.call(req)?;
req = auth_interceptor.call(req)?;
Ok(req)
};
let page_service =
proto::PageServiceServer::with_interceptor(page_service_handler, interceptors);
let server = server.add_service(page_service);
// Reflection service for use with e.g. grpcurl.
let reflection_service = tonic_reflection::server::Builder::configure()
.register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET)
.build_v1()?;
let server = server.add_service(reflection_service);
// Spawn server task.
let task_cancel = cancel.clone();
let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error(
"grpc listener",
async move {
let result = server
.serve_with_incoming_shutdown(incoming, task_cancel.cancelled())
.await;
if result.is_ok() {
// TODO: revisit shutdown logic once page service is implemented.
gate.close().await;
}
result
},
));
Ok(CancellableTask { task, cancel })
}
impl Listener {
pub async fn stop_accepting(self) -> Connections {
self.cancel.cancel();
@@ -300,7 +186,6 @@ pub async fn libpq_listener_main(
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
pipelining_config: PageServicePipeliningConfig,
basebackup_cache: Arc<BasebackupCache>,
listener_ctx: RequestContext,
listener_cancel: CancellationToken,
) -> Connections {
@@ -344,7 +229,6 @@ pub async fn libpq_listener_main(
auth_type,
tls_config.clone(),
pipelining_config.clone(),
Arc::clone(&basebackup_cache),
connection_ctx,
connections_cancel.child_token(),
gate_guard,
@@ -370,7 +254,7 @@ type ConnectionHandlerResult = anyhow::Result<()>;
/// Perf root spans start at the per-request level, after shard routing.
/// This struct carries connection-level information to the root perf span definition.
#[derive(Clone, Default)]
#[derive(Clone)]
struct ConnectionPerfSpanFields {
peer_addr: String,
application_name: Option<String>,
@@ -387,7 +271,6 @@ async fn page_service_conn_main(
auth_type: AuthType,
tls_config: Option<Arc<rustls::ServerConfig>>,
pipelining_config: PageServicePipeliningConfig,
basebackup_cache: Arc<BasebackupCache>,
connection_ctx: RequestContext,
cancel: CancellationToken,
gate_guard: GateGuard,
@@ -453,7 +336,6 @@ async fn page_service_conn_main(
pipelining_config,
conf.get_vectored_concurrent_io,
perf_span_fields,
basebackup_cache,
connection_ctx,
cancel.clone(),
gate_guard,
@@ -488,11 +370,6 @@ async fn page_service_conn_main(
}
}
/// Page service connection handler.
///
/// TODO: for gRPC, this will be shared by all requests from all connections.
/// Decompose it into global state and per-connection/request state, and make
/// libpq-specific options (e.g. pipelining) separate.
struct PageServerHandler {
auth: Option<Arc<SwappableJwtAuth>>,
claims: Option<Claims>,
@@ -513,8 +390,6 @@ struct PageServerHandler {
pipelining_config: PageServicePipeliningConfig,
get_vectored_concurrent_io: GetVectoredConcurrentIo,
basebackup_cache: Arc<BasebackupCache>,
gate_guard: GateGuard,
}
@@ -769,9 +644,6 @@ struct BatchedGetPageRequest {
timer: SmgrOpTimer,
lsn_range: LsnRange,
ctx: RequestContext,
// If the request is perf enabled, this contains a context
// with a perf span tracking the time spent waiting for the executor.
batch_wait_ctx: Option<RequestContext>,
}
#[cfg(feature = "testing")]
@@ -784,7 +656,6 @@ struct BatchedTestRequest {
/// so that we don't keep the [`Timeline::gate`] open while the batch
/// is being built up inside the [`spsc_fold`] (pagestream pipelining).
#[derive(IntoStaticStr)]
#[allow(clippy::large_enum_variant)]
enum BatchedFeMessage {
Exists {
span: Span,
@@ -978,7 +849,6 @@ impl PageServerHandler {
pipelining_config: PageServicePipeliningConfig,
get_vectored_concurrent_io: GetVectoredConcurrentIo,
perf_span_fields: ConnectionPerfSpanFields,
basebackup_cache: Arc<BasebackupCache>,
connection_ctx: RequestContext,
cancel: CancellationToken,
gate_guard: GateGuard,
@@ -992,7 +862,6 @@ impl PageServerHandler {
cancel,
pipelining_config,
get_vectored_concurrent_io,
basebackup_cache,
gate_guard,
}
}
@@ -1302,22 +1171,6 @@ impl PageServerHandler {
}
};
let batch_wait_ctx = if ctx.has_perf_span() {
Some(
RequestContextBuilder::from(&ctx)
.perf_span(|crnt_perf_span| {
info_span!(
target: PERF_TRACE_TARGET,
parent: crnt_perf_span,
"WAIT_EXECUTOR",
)
})
.attached_child(),
)
} else {
None
};
BatchedFeMessage::GetPage {
span,
shard: shard.downgrade(),
@@ -1329,7 +1182,6 @@ impl PageServerHandler {
request_lsn: req.hdr.request_lsn
},
ctx,
batch_wait_ctx,
}],
// The executor grabs the batch when it becomes idle.
// Hence, [`GetPageBatchBreakReason::ExecutorSteal`] is the
@@ -1485,7 +1337,7 @@ impl PageServerHandler {
let mut flush_timers = Vec::with_capacity(handler_results.len());
for handler_result in &mut handler_results {
let flush_timer = match handler_result {
Ok((_response, timer, _ctx)) => Some(
Ok((_, timer)) => Some(
timer
.observe_execution_end(flushing_start_time)
.expect("we are the first caller"),
@@ -1505,7 +1357,7 @@ impl PageServerHandler {
// Some handler errors cause exit from pagestream protocol.
// Other handler errors are sent back as an error message and we stay in pagestream protocol.
for (handler_result, flushing_timer) in handler_results.into_iter().zip(flush_timers) {
let (response_msg, ctx) = match handler_result {
let response_msg = match handler_result {
Err(e) => match &e.err {
PageStreamError::Shutdown => {
// If we fail to fulfil a request during shutdown, which may be _because_ of
@@ -1530,30 +1382,15 @@ impl PageServerHandler {
error!("error reading relation or page version: {full:#}")
});
(
PagestreamBeMessage::Error(PagestreamErrorResponse {
req: e.req,
message: e.err.to_string(),
}),
None,
)
PagestreamBeMessage::Error(PagestreamErrorResponse {
req: e.req,
message: e.err.to_string(),
})
}
},
Ok((response_msg, _op_timer_already_observed, ctx)) => (response_msg, Some(ctx)),
Ok((response_msg, _op_timer_already_observed)) => response_msg,
};
let ctx = ctx.map(|req_ctx| {
RequestContextBuilder::from(&req_ctx)
.perf_span(|crnt_perf_span| {
info_span!(
target: PERF_TRACE_TARGET,
parent: crnt_perf_span,
"FLUSH_RESPONSE",
)
})
.attached_child()
});
//
// marshal & transmit response message
//
@@ -1576,17 +1413,6 @@ impl PageServerHandler {
)),
None => futures::future::Either::Right(flush_fut),
};
let flush_fut = if let Some(req_ctx) = ctx.as_ref() {
futures::future::Either::Left(
flush_fut.maybe_perf_instrument(req_ctx, |current_perf_span| {
current_perf_span.clone()
}),
)
} else {
futures::future::Either::Right(flush_fut)
};
// do it while respecting cancellation
let _: () = async move {
tokio::select! {
@@ -1616,7 +1442,7 @@ impl PageServerHandler {
ctx: &RequestContext,
) -> Result<
(
Vec<Result<(PagestreamBeMessage, SmgrOpTimer, RequestContext), BatchedPageStreamError>>,
Vec<Result<(PagestreamBeMessage, SmgrOpTimer), BatchedPageStreamError>>,
Span,
),
QueryError,
@@ -1643,7 +1469,7 @@ impl PageServerHandler {
self.handle_get_rel_exists_request(&shard, &req, &ctx)
.instrument(span.clone())
.await
.map(|msg| (msg, timer, ctx))
.map(|msg| (msg, timer))
.map_err(|err| BatchedPageStreamError { err, req: req.hdr }),
],
span,
@@ -1662,7 +1488,7 @@ impl PageServerHandler {
self.handle_get_nblocks_request(&shard, &req, &ctx)
.instrument(span.clone())
.await
.map(|msg| (msg, timer, ctx))
.map(|msg| (msg, timer))
.map_err(|err| BatchedPageStreamError { err, req: req.hdr }),
],
span,
@@ -1709,7 +1535,7 @@ impl PageServerHandler {
self.handle_db_size_request(&shard, &req, &ctx)
.instrument(span.clone())
.await
.map(|msg| (msg, timer, ctx))
.map(|msg| (msg, timer))
.map_err(|err| BatchedPageStreamError { err, req: req.hdr }),
],
span,
@@ -1728,7 +1554,7 @@ impl PageServerHandler {
self.handle_get_slru_segment_request(&shard, &req, &ctx)
.instrument(span.clone())
.await
.map(|msg| (msg, timer, ctx))
.map(|msg| (msg, timer))
.map_err(|err| BatchedPageStreamError { err, req: req.hdr }),
],
span,
@@ -2080,25 +1906,12 @@ impl PageServerHandler {
return Ok(());
}
};
let mut batch = match batch {
let batch = match batch {
Ok(batch) => batch,
Err(e) => {
return Err(e);
}
};
if let BatchedFeMessage::GetPage {
pages,
span: _,
shard: _,
batch_break_reason: _,
} = &mut batch
{
for req in pages {
req.batch_wait_ctx.take();
}
}
self.pagestream_handle_batched_message(
pgb_writer,
batch,
@@ -2411,8 +2224,7 @@ impl PageServerHandler {
io_concurrency: IoConcurrency,
batch_break_reason: GetPageBatchBreakReason,
ctx: &RequestContext,
) -> Vec<Result<(PagestreamBeMessage, SmgrOpTimer, RequestContext), BatchedPageStreamError>>
{
) -> Vec<Result<(PagestreamBeMessage, SmgrOpTimer), BatchedPageStreamError>> {
debug_assert_current_span_has_tenant_and_timeline_id();
timeline
@@ -2519,7 +2331,6 @@ impl PageServerHandler {
page,
}),
req.timer,
req.ctx,
)
})
.map_err(|e| BatchedPageStreamError {
@@ -2564,8 +2375,7 @@ impl PageServerHandler {
timeline: &Timeline,
requests: Vec<BatchedTestRequest>,
_ctx: &RequestContext,
) -> Vec<Result<(PagestreamBeMessage, SmgrOpTimer, RequestContext), BatchedPageStreamError>>
{
) -> Vec<Result<(PagestreamBeMessage, SmgrOpTimer), BatchedPageStreamError>> {
// real requests would do something with the timeline
let mut results = Vec::with_capacity(requests.len());
for _req in requests.iter() {
@@ -2592,10 +2402,6 @@ impl PageServerHandler {
req: req.req.clone(),
}),
req.timer,
RequestContext::new(
TaskKind::PageRequestHandler,
DownloadBehavior::Warn,
),
)
})
.map_err(|e| BatchedPageStreamError {
@@ -2687,8 +2493,6 @@ impl PageServerHandler {
.map_err(QueryError::Disconnected)?;
self.flush_cancellable(pgb, &self.cancel).await?;
let mut from_cache = false;
// Send a tarball of the latest layer on the timeline. Compress if not
// fullbackup. TODO Compress in that case too (tests need to be updated)
if full_backup {
@@ -2706,33 +2510,7 @@ impl PageServerHandler {
.map_err(map_basebackup_error)?;
} else {
let mut writer = BufWriter::new(pgb.copyout_writer());
let cached = {
// Basebackup is cached only for this combination of parameters.
if timeline.is_basebackup_cache_enabled()
&& gzip
&& lsn.is_some()
&& prev_lsn.is_none()
{
self.basebackup_cache
.get(tenant_id, timeline_id, lsn.unwrap())
.await
} else {
None
}
};
if let Some(mut cached) = cached {
from_cache = true;
tokio::io::copy(&mut cached, &mut writer)
.await
.map_err(|e| {
map_basebackup_error(BasebackupError::Client(
e,
"handle_basebackup_request,cached,copy",
))
})?;
} else if gzip {
if gzip {
let mut encoder = GzipEncoder::with_quality(
&mut writer,
// NOTE using fast compression because it's on the critical path
@@ -2791,7 +2569,6 @@ impl PageServerHandler {
info!(
lsn_await_millis = lsn_awaited_after.as_millis(),
basebackup_millis = basebackup_after.as_millis(),
%from_cache,
"basebackup complete"
);
@@ -3300,60 +3077,6 @@ where
}
}
/// Implements the page service over gRPC.
///
/// TODO: not yet implemented, all methods return unimplemented.
#[tonic::async_trait]
impl proto::PageService for PageServerHandler {
type GetBaseBackupStream = Pin<
Box<dyn Stream<Item = Result<proto::GetBaseBackupResponseChunk, tonic::Status>> + Send>,
>;
type GetPagesStream =
Pin<Box<dyn Stream<Item = Result<proto::GetPageResponse, tonic::Status>> + Send>>;
async fn check_rel_exists(
&self,
_: tonic::Request<proto::CheckRelExistsRequest>,
) -> Result<tonic::Response<proto::CheckRelExistsResponse>, tonic::Status> {
Err(tonic::Status::unimplemented("not implemented"))
}
async fn get_base_backup(
&self,
_: tonic::Request<proto::GetBaseBackupRequest>,
) -> Result<tonic::Response<Self::GetBaseBackupStream>, tonic::Status> {
Err(tonic::Status::unimplemented("not implemented"))
}
async fn get_db_size(
&self,
_: tonic::Request<proto::GetDbSizeRequest>,
) -> Result<tonic::Response<proto::GetDbSizeResponse>, tonic::Status> {
Err(tonic::Status::unimplemented("not implemented"))
}
async fn get_pages(
&self,
_: tonic::Request<tonic::Streaming<proto::GetPageRequest>>,
) -> Result<tonic::Response<Self::GetPagesStream>, tonic::Status> {
Err(tonic::Status::unimplemented("not implemented"))
}
async fn get_rel_size(
&self,
_: tonic::Request<proto::GetRelSizeRequest>,
) -> Result<tonic::Response<proto::GetRelSizeResponse>, tonic::Status> {
Err(tonic::Status::unimplemented("not implemented"))
}
async fn get_slru_segment(
&self,
_: tonic::Request<proto::GetSlruSegmentRequest>,
) -> Result<tonic::Response<proto::GetSlruSegmentResponse>, tonic::Status> {
Err(tonic::Status::unimplemented("not implemented"))
}
}
impl From<GetActiveTenantError> for QueryError {
fn from(e: GetActiveTenantError) -> Self {
match e {
@@ -3370,104 +3093,6 @@ impl From<GetActiveTenantError> for QueryError {
}
}
/// gRPC interceptor that decodes tenant metadata and stores it as request extensions of type
/// TenantTimelineId and ShardIndex.
///
/// TODO: consider looking up the timeline handle here and storing it.
#[derive(Clone)]
struct TenantMetadataInterceptor;
impl tonic::service::Interceptor for TenantMetadataInterceptor {
fn call(&mut self, mut req: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
// Decode the tenant ID.
let tenant_id = req
.metadata()
.get("neon-tenant-id")
.ok_or_else(|| tonic::Status::invalid_argument("missing neon-tenant-id"))?
.to_str()
.map_err(|_| tonic::Status::invalid_argument("invalid neon-tenant-id"))?;
let tenant_id = TenantId::from_str(tenant_id)
.map_err(|_| tonic::Status::invalid_argument("invalid neon-tenant-id"))?;
// Decode the timeline ID.
let timeline_id = req
.metadata()
.get("neon-timeline-id")
.ok_or_else(|| tonic::Status::invalid_argument("missing neon-timeline-id"))?
.to_str()
.map_err(|_| tonic::Status::invalid_argument("invalid neon-timeline-id"))?;
let timeline_id = TimelineId::from_str(timeline_id)
.map_err(|_| tonic::Status::invalid_argument("invalid neon-timeline-id"))?;
// Decode the shard ID.
let shard_index = req
.metadata()
.get("neon-shard-id")
.ok_or_else(|| tonic::Status::invalid_argument("missing neon-shard-id"))?
.to_str()
.map_err(|_| tonic::Status::invalid_argument("invalid neon-shard-id"))?;
let shard_index = ShardIndex::from_str(shard_index)
.map_err(|_| tonic::Status::invalid_argument("invalid neon-shard-id"))?;
// Stash them in the request.
let extensions = req.extensions_mut();
extensions.insert(TenantTimelineId::new(tenant_id, timeline_id));
extensions.insert(shard_index);
Ok(req)
}
}
/// Authenticates gRPC page service requests. Must run after TenantMetadataInterceptor.
#[derive(Clone)]
struct TenantAuthInterceptor {
auth: Option<Arc<SwappableJwtAuth>>,
}
impl TenantAuthInterceptor {
fn new(auth: Option<Arc<SwappableJwtAuth>>) -> Self {
Self { auth }
}
}
impl tonic::service::Interceptor for TenantAuthInterceptor {
fn call(&mut self, req: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
// Do nothing if auth is disabled.
let Some(auth) = self.auth.as_ref() else {
return Ok(req);
};
// Fetch the tenant ID that's been set by TenantMetadataInterceptor.
let ttid = req
.extensions()
.get::<TenantTimelineId>()
.expect("TenantMetadataInterceptor must run before TenantAuthInterceptor");
// Fetch and decode the JWT token.
let jwt = req
.metadata()
.get("authorization")
.ok_or_else(|| tonic::Status::unauthenticated("no authorization header"))?
.to_str()
.map_err(|_| tonic::Status::invalid_argument("invalid authorization header"))?
.strip_prefix("Bearer ")
.ok_or_else(|| tonic::Status::invalid_argument("invalid authorization header"))?
.trim();
let jwtdata: TokenData<Claims> = auth
.decode(jwt)
.map_err(|err| tonic::Status::invalid_argument(format!("invalid JWT token: {err}")))?;
let claims = jwtdata.claims;
// Check if the token is valid for this tenant.
check_permission(&claims, Some(ttid.tenant_id))
.map_err(|err| tonic::Status::permission_denied(err.to_string()))?;
// TODO: consider stashing the claims in the request extensions, if needed.
Ok(req)
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum GetActiveTimelineError {
#[error(transparent)]

View File

@@ -276,10 +276,9 @@ pub enum TaskKind {
// HTTP endpoint listener.
HttpEndpointListener,
/// Task that handles a single page service connection. A PageRequestHandler
/// task starts detached from any particular tenant or timeline, but it can
/// be associated with one later, after receiving a command from the client.
/// Also used for the gRPC page service API, including the main server task.
// Task that handles a single connection. A PageRequestHandler task
// starts detached from any particular tenant or timeline, but it can be
// associated with one later, after receiving a command from the client.
PageRequestHandler,
/// Manages the WAL receiver connection for one timeline.
@@ -381,10 +380,6 @@ pub enum TaskKind {
DetachAncestor,
ImportPgdata,
/// Background task of [`crate::basebackup_cache::BasebackupCache`].
/// Prepares basebackups and clears outdated entries.
BasebackupCache,
}
#[derive(Default)]

View File

@@ -78,18 +78,16 @@ use self::timeline::uninit::{TimelineCreateGuard, TimelineExclusionError, Uninit
use self::timeline::{
EvictionTaskTenantState, GcCutoffs, TimelineDeleteProgress, TimelineResources, WaitLsnError,
};
use crate::basebackup_cache::BasebackupPrepareSender;
use crate::config::PageServerConf;
use crate::context;
use crate::context::RequestContextBuilder;
use crate::context::{DownloadBehavior, RequestContext};
use crate::deletion_queue::{DeletionQueueClient, DeletionQueueError};
use crate::feature_resolver::FeatureResolver;
use crate::l0_flush::L0FlushGlobalState;
use crate::metrics::{
BROKEN_TENANTS_SET, CIRCUIT_BREAKERS_BROKEN, CIRCUIT_BREAKERS_UNBROKEN, CONCURRENT_INITDBS,
INITDB_RUN_TIME, INITDB_SEMAPHORE_ACQUISITION_TIME, TENANT, TENANT_OFFLOADED_TIMELINES,
TENANT_STATE_METRIC, TENANT_SYNTHETIC_SIZE_METRIC, remove_tenant_metrics,
INITDB_RUN_TIME, INITDB_SEMAPHORE_ACQUISITION_TIME, TENANT, TENANT_STATE_METRIC,
TENANT_SYNTHETIC_SIZE_METRIC, remove_tenant_metrics,
};
use crate::task_mgr::TaskKind;
use crate::tenant::config::LocationMode;
@@ -159,8 +157,6 @@ pub struct TenantSharedResources {
pub remote_storage: GenericRemoteStorage,
pub deletion_queue_client: DeletionQueueClient,
pub l0_flush_global_state: L0FlushGlobalState,
pub basebackup_prepare_sender: BasebackupPrepareSender,
pub feature_resolver: FeatureResolver,
}
/// A [`TenantShard`] is really an _attached_ tenant. The configuration
@@ -321,15 +317,12 @@ pub struct TenantShard {
gc_cs: tokio::sync::Mutex<()>,
walredo_mgr: Option<Arc<WalRedoManager>>,
/// Provides access to timeline data sitting in the remote storage.
// provides access to timeline data sitting in the remote storage
pub(crate) remote_storage: GenericRemoteStorage,
/// Access to global deletion queue for when this tenant wants to schedule a deletion.
// Access to global deletion queue for when this tenant wants to schedule a deletion
deletion_queue_client: DeletionQueueClient,
/// A channel to send async requests to prepare a basebackup for the basebackup cache.
basebackup_prepare_sender: BasebackupPrepareSender,
/// Cached logical sizes updated updated on each [`TenantShard::gather_size_inputs`].
cached_logical_sizes: tokio::sync::Mutex<HashMap<(TimelineId, Lsn), u64>>,
cached_synthetic_tenant_size: Arc<AtomicU64>,
@@ -382,8 +375,6 @@ pub struct TenantShard {
pub(crate) gc_block: gc_block::GcBlock,
l0_flush_global_state: L0FlushGlobalState,
feature_resolver: FeatureResolver,
}
impl std::fmt::Debug for TenantShard {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -864,14 +855,6 @@ impl Debug for SetStoppingError {
}
}
#[derive(thiserror::Error, Debug)]
pub(crate) enum FinalizeTimelineImportError {
#[error("Import task not done yet")]
ImportTaskStillRunning,
#[error("Shutting down")]
ShuttingDown,
}
/// Arguments to [`TenantShard::create_timeline`].
///
/// Not usable as an idempotency key for timeline creation because if [`CreateTimelineParamsBranch::ancestor_start_lsn`]
@@ -1158,20 +1141,10 @@ impl TenantShard {
ctx,
)?;
let disk_consistent_lsn = timeline.get_disk_consistent_lsn();
if !disk_consistent_lsn.is_valid() {
// As opposed to normal timelines which get initialised with a disk consitent LSN
// via initdb, imported timelines start from 0. If the import task stops before
// it advances disk consitent LSN, allow it to resume.
let in_progress_import = import_pgdata
.as_ref()
.map(|import| !import.is_done())
.unwrap_or(false);
if !in_progress_import {
anyhow::bail!("Timeline {tenant_id}/{timeline_id} has invalid disk_consistent_lsn");
}
}
anyhow::ensure!(
disk_consistent_lsn.is_valid(),
"Timeline {tenant_id}/{timeline_id} has invalid disk_consistent_lsn"
);
assert_eq!(
disk_consistent_lsn,
metadata.disk_consistent_lsn(),
@@ -1265,25 +1238,20 @@ impl TenantShard {
}
}
if disk_consistent_lsn.is_valid() {
// Sanity check: a timeline should have some content.
// Exception: importing timelines might not yet have any
anyhow::ensure!(
ancestor.is_some()
|| timeline
.layers
.read()
.await
.layer_map()
.expect(
"currently loading, layer manager cannot be shutdown already"
)
.iter_historic_layers()
.next()
.is_some(),
"Timeline has no ancestor and no layer files"
);
}
// Sanity check: a timeline should have some content.
anyhow::ensure!(
ancestor.is_some()
|| timeline
.layers
.read()
.await
.layer_map()
.expect("currently loading, layer manager cannot be shutdown already")
.iter_historic_layers()
.next()
.is_some(),
"Timeline has no ancestor and no layer files"
);
Ok(TimelineInitAndSyncResult::ReadyToActivate)
}
@@ -1318,8 +1286,6 @@ impl TenantShard {
remote_storage,
deletion_queue_client,
l0_flush_global_state,
basebackup_prepare_sender,
feature_resolver,
} = resources;
let attach_mode = attached_conf.location.attach_mode;
@@ -1335,8 +1301,6 @@ impl TenantShard {
remote_storage.clone(),
deletion_queue_client,
l0_flush_global_state,
basebackup_prepare_sender,
feature_resolver,
));
// The attach task will carry a GateGuard, so that shutdown() reliably waits for it to drop out if
@@ -2883,13 +2847,13 @@ impl TenantShard {
pub(crate) async fn finalize_importing_timeline(
&self,
timeline_id: TimelineId,
) -> Result<(), FinalizeTimelineImportError> {
) -> anyhow::Result<()> {
let timeline = {
let locked = self.timelines_importing.lock().unwrap();
match locked.get(&timeline_id) {
Some(importing_timeline) => {
if !importing_timeline.import_task_handle.is_finished() {
return Err(FinalizeTimelineImportError::ImportTaskStillRunning);
return Err(anyhow::anyhow!("Import task not done yet"));
}
importing_timeline.timeline.clone()
@@ -2902,13 +2866,8 @@ impl TenantShard {
timeline
.remote_client
.schedule_index_upload_for_import_pgdata_finalize()
.map_err(|_err| FinalizeTimelineImportError::ShuttingDown)?;
timeline
.remote_client
.wait_completion()
.await
.map_err(|_err| FinalizeTimelineImportError::ShuttingDown)?;
.schedule_index_upload_for_import_pgdata_finalize()?;
timeline.remote_client.wait_completion().await?;
self.timelines_importing
.lock()
@@ -3169,18 +3128,11 @@ impl TenantShard {
.or_insert_with(|| Arc::new(GcCompactionQueue::new()))
.clone()
};
let gc_compaction_strategy = self
.feature_resolver
.evaluate_multivariate("gc-comapction-strategy", self.tenant_shard_id.tenant_id)
.ok();
let span = if let Some(gc_compaction_strategy) = gc_compaction_strategy {
info_span!("gc_compact_timeline", timeline_id = %timeline.timeline_id, strategy = %gc_compaction_strategy)
} else {
info_span!("gc_compact_timeline", timeline_id = %timeline.timeline_id)
};
outcome = queue
.iteration(cancel, ctx, &self.gc_block, &timeline)
.instrument(span)
.instrument(
info_span!("gc_compact_timeline", timeline_id = %timeline.timeline_id),
)
.await?;
}
@@ -3396,13 +3348,6 @@ impl TenantShard {
activated_timelines += 1;
}
let tid = self.tenant_shard_id.tenant_id.to_string();
let shard_id = self.tenant_shard_id.shard_slug().to_string();
let offloaded_timeline_count = timelines_offloaded_accessor.len();
TENANT_OFFLOADED_TIMELINES
.with_label_values(&[&tid, &shard_id])
.set(offloaded_timeline_count as u64);
self.state.send_modify(move |current_state| {
assert!(
matches!(current_state, TenantState::Activating(_)),
@@ -3512,9 +3457,8 @@ impl TenantShard {
let mut timelines_importing = self.timelines_importing.lock().unwrap();
timelines_importing
.drain()
.for_each(|(timeline_id, importing_timeline)| {
let span = tracing::info_span!("importing_timeline_shutdown", %timeline_id);
js.spawn(async move { importing_timeline.shutdown().instrument(span).await });
.for_each(|(_timeline_id, importing_timeline)| {
importing_timeline.shutdown();
});
}
// test_long_timeline_create_then_tenant_delete is leaning on this message
@@ -4288,8 +4232,6 @@ impl TenantShard {
remote_storage: GenericRemoteStorage,
deletion_queue_client: DeletionQueueClient,
l0_flush_global_state: L0FlushGlobalState,
basebackup_prepare_sender: BasebackupPrepareSender,
feature_resolver: FeatureResolver,
) -> TenantShard {
assert!(!attached_conf.location.generation.is_none());
@@ -4393,8 +4335,6 @@ impl TenantShard {
ongoing_timeline_detach: std::sync::Mutex::default(),
gc_block: Default::default(),
l0_flush_global_state,
basebackup_prepare_sender,
feature_resolver,
}
}
@@ -4647,7 +4587,7 @@ impl TenantShard {
target.cutoffs = GcCutoffs {
space: space_cutoff,
time: None,
time: Lsn::INVALID,
};
}
}
@@ -4731,7 +4671,7 @@ impl TenantShard {
if let Some(ancestor_id) = timeline.get_ancestor_timeline_id() {
if let Some(ancestor_gc_cutoffs) = gc_cutoffs.get(&ancestor_id) {
target.within_ancestor_pitr =
Some(timeline.get_ancestor_lsn()) >= ancestor_gc_cutoffs.time;
timeline.get_ancestor_lsn() >= ancestor_gc_cutoffs.time;
}
}
@@ -4744,15 +4684,13 @@ impl TenantShard {
} else {
0
});
if let Some(time_cutoff) = target.cutoffs.time {
timeline.metrics.pitr_history_size.set(
timeline
.get_last_record_lsn()
.checked_sub(time_cutoff)
.unwrap_or_default()
.0,
);
}
timeline.metrics.pitr_history_size.set(
timeline
.get_last_record_lsn()
.checked_sub(target.cutoffs.time)
.unwrap_or(Lsn(0))
.0,
);
// Apply the cutoffs we found to the Timeline's GcInfo. Why might we _not_ have cutoffs for a timeline?
// - this timeline was created while we were finding cutoffs
@@ -4761,8 +4699,8 @@ impl TenantShard {
let original_cutoffs = target.cutoffs.clone();
// GC cutoffs should never go back
target.cutoffs = GcCutoffs {
space: cutoffs.space.max(original_cutoffs.space),
time: cutoffs.time.max(original_cutoffs.time),
space: Lsn(cutoffs.space.0.max(original_cutoffs.space.0)),
time: Lsn(cutoffs.time.0.max(original_cutoffs.time.0)),
}
}
}
@@ -5314,8 +5252,6 @@ impl TenantShard {
pagestream_throttle_metrics: self.pagestream_throttle_metrics.clone(),
l0_compaction_trigger: self.l0_compaction_trigger.clone(),
l0_flush_global_state: self.l0_flush_global_state.clone(),
basebackup_prepare_sender: self.basebackup_prepare_sender.clone(),
feature_resolver: self.feature_resolver.clone(),
}
}
@@ -5624,14 +5560,6 @@ impl TenantShard {
}
}
// Update metrics
let tid = self.tenant_shard_id.to_string();
let shard_id = self.tenant_shard_id.shard_slug().to_string();
let set_key = &[tid.as_str(), shard_id.as_str()][..];
TENANT_OFFLOADED_TIMELINES
.with_label_values(set_key)
.set(manifest.offloaded_timelines.len() as u64);
// Upload the manifest. Remote storage does no retries internally, so retry here.
match backoff::retry(
|| async {
@@ -5898,8 +5826,6 @@ pub(crate) mod harness {
) -> anyhow::Result<Arc<TenantShard>> {
let walredo_mgr = Arc::new(WalRedoManager::from(TestRedoManager));
let (basebackup_requst_sender, _) = tokio::sync::mpsc::unbounded_channel();
let tenant = Arc::new(TenantShard::new(
TenantState::Attaching,
self.conf,
@@ -5917,8 +5843,6 @@ pub(crate) mod harness {
self.deletion_queue.new_client(),
// TODO: ideally we should run all unit tests with both configs
L0FlushGlobalState::new(L0FlushConfig::default()),
basebackup_requst_sender,
FeatureResolver::new_disabled(),
));
let preload = tenant
@@ -8360,24 +8284,10 @@ mod tests {
}
tline.freeze_and_flush().await?;
// Force layers to L1
tline
.compact(
&cancel,
{
let mut flags = EnumSet::new();
flags.insert(CompactFlags::ForceL0Compaction);
flags
},
&ctx,
)
.await?;
if iter % 5 == 0 {
let scan_lsn = Lsn(lsn.0 + 1);
info!("scanning at {}", scan_lsn);
let (_, before_delta_file_accessed) =
scan_with_statistics(&tline, &keyspace, scan_lsn, &ctx, io_concurrency.clone())
scan_with_statistics(&tline, &keyspace, lsn, &ctx, io_concurrency.clone())
.await?;
tline
.compact(
@@ -8386,14 +8296,13 @@ mod tests {
let mut flags = EnumSet::new();
flags.insert(CompactFlags::ForceImageLayerCreation);
flags.insert(CompactFlags::ForceRepartition);
flags.insert(CompactFlags::ForceL0Compaction);
flags
},
&ctx,
)
.await?;
let (_, after_delta_file_accessed) =
scan_with_statistics(&tline, &keyspace, scan_lsn, &ctx, io_concurrency.clone())
scan_with_statistics(&tline, &keyspace, lsn, &ctx, io_concurrency.clone())
.await?;
assert!(
after_delta_file_accessed < before_delta_file_accessed,
@@ -8834,8 +8743,6 @@ mod tests {
let cancel = CancellationToken::new();
// Image layer creation happens on the disk_consistent_lsn so we need to force set it now.
tline.force_set_disk_consistent_lsn(Lsn(0x40));
tline
.compact(
&cancel,
@@ -8849,7 +8756,8 @@ mod tests {
)
.await
.unwrap();
// Image layers are created at repartition LSN
// Image layers are created at last_record_lsn
let images = tline
.inspect_image_layers(Lsn(0x40), &ctx, io_concurrency.clone())
.await
@@ -9029,7 +8937,7 @@ mod tests {
.await;
// Update GC info
let mut guard = tline.gc_info.write().unwrap();
guard.cutoffs.time = Some(Lsn(0x30));
guard.cutoffs.time = Lsn(0x30);
guard.cutoffs.space = Lsn(0x30);
}
@@ -9137,7 +9045,7 @@ mod tests {
.await;
// Update GC info
let mut guard = tline.gc_info.write().unwrap();
guard.cutoffs.time = Some(Lsn(0x40));
guard.cutoffs.time = Lsn(0x40);
guard.cutoffs.space = Lsn(0x40);
}
tline
@@ -9555,7 +9463,7 @@ mod tests {
*guard = GcInfo {
retain_lsns: vec![],
cutoffs: GcCutoffs {
time: Some(Lsn(0x30)),
time: Lsn(0x30),
space: Lsn(0x30),
},
leases: Default::default(),
@@ -9639,7 +9547,7 @@ mod tests {
.await;
// Update GC info
let mut guard = tline.gc_info.write().unwrap();
guard.cutoffs.time = Some(Lsn(0x40));
guard.cutoffs.time = Lsn(0x40);
guard.cutoffs.space = Lsn(0x40);
}
tline
@@ -10110,7 +10018,7 @@ mod tests {
(Lsn(0x20), tline.timeline_id, MaybeOffloaded::No),
],
cutoffs: GcCutoffs {
time: Some(Lsn(0x30)),
time: Lsn(0x30),
space: Lsn(0x30),
},
leases: Default::default(),
@@ -10173,7 +10081,7 @@ mod tests {
let verify_result = || async {
let gc_horizon = {
let gc_info = tline.gc_info.read().unwrap();
gc_info.cutoffs.time.unwrap_or_default()
gc_info.cutoffs.time
};
for idx in 0..10 {
assert_eq!(
@@ -10251,7 +10159,7 @@ mod tests {
.await;
// Update GC info
let mut guard = tline.gc_info.write().unwrap();
guard.cutoffs.time = Some(Lsn(0x38));
guard.cutoffs.time = Lsn(0x38);
guard.cutoffs.space = Lsn(0x38);
}
tline
@@ -10359,7 +10267,7 @@ mod tests {
(Lsn(0x20), tline.timeline_id, MaybeOffloaded::No),
],
cutoffs: GcCutoffs {
time: Some(Lsn(0x30)),
time: Lsn(0x30),
space: Lsn(0x30),
},
leases: Default::default(),
@@ -10422,7 +10330,7 @@ mod tests {
let verify_result = || async {
let gc_horizon = {
let gc_info = tline.gc_info.read().unwrap();
gc_info.cutoffs.time.unwrap_or_default()
gc_info.cutoffs.time
};
for idx in 0..10 {
assert_eq!(
@@ -10608,7 +10516,7 @@ mod tests {
*guard = GcInfo {
retain_lsns: vec![(Lsn(0x18), branch_tline.timeline_id, MaybeOffloaded::No)],
cutoffs: GcCutoffs {
time: Some(Lsn(0x10)),
time: Lsn(0x10),
space: Lsn(0x10),
},
leases: Default::default(),
@@ -10628,7 +10536,7 @@ mod tests {
*guard = GcInfo {
retain_lsns: vec![(Lsn(0x40), branch_tline.timeline_id, MaybeOffloaded::No)],
cutoffs: GcCutoffs {
time: Some(Lsn(0x50)),
time: Lsn(0x50),
space: Lsn(0x50),
},
leases: Default::default(),
@@ -11349,7 +11257,7 @@ mod tests {
*guard = GcInfo {
retain_lsns: vec![(Lsn(0x20), tline.timeline_id, MaybeOffloaded::No)],
cutoffs: GcCutoffs {
time: Some(Lsn(0x30)),
time: Lsn(0x30),
space: Lsn(0x30),
},
leases: Default::default(),
@@ -11738,7 +11646,7 @@ mod tests {
(Lsn(0x20), tline.timeline_id, MaybeOffloaded::No),
],
cutoffs: GcCutoffs {
time: Some(Lsn(0x30)),
time: Lsn(0x30),
space: Lsn(0x30),
},
leases: Default::default(),
@@ -11801,7 +11709,7 @@ mod tests {
let verify_result = || async {
let gc_horizon = {
let gc_info = tline.gc_info.read().unwrap();
gc_info.cutoffs.time.unwrap_or_default()
gc_info.cutoffs.time
};
for idx in 0..10 {
assert_eq!(
@@ -11990,7 +11898,7 @@ mod tests {
(Lsn(0x20), tline.timeline_id, MaybeOffloaded::No),
],
cutoffs: GcCutoffs {
time: Some(Lsn(0x30)),
time: Lsn(0x30),
space: Lsn(0x30),
},
leases: Default::default(),
@@ -12053,7 +11961,7 @@ mod tests {
let verify_result = || async {
let gc_horizon = {
let gc_info = tline.gc_info.read().unwrap();
gc_info.cutoffs.time.unwrap_or_default()
gc_info.cutoffs.time
};
for idx in 0..10 {
assert_eq!(
@@ -12316,7 +12224,7 @@ mod tests {
*guard = GcInfo {
retain_lsns: vec![],
cutoffs: GcCutoffs {
time: Some(Lsn(0x30)),
time: Lsn(0x30),
space: Lsn(0x30),
},
leases: Default::default(),

View File

@@ -235,7 +235,7 @@ pub(super) async fn gather_inputs(
// than our internal space cutoff. This means that if someone drops a database and waits for their
// PITR interval, they will see synthetic size decrease, even if we are still storing data inside
// the space cutoff.
let mut next_pitr_cutoff = gc_info.cutoffs.time.unwrap_or_default(); // TODO: handle None
let mut next_pitr_cutoff = gc_info.cutoffs.time;
// If the caller provided a shorter retention period, use that instead of the GC cutoff.
let retention_param_cutoff = if let Some(max_retention_period) = max_retention_period {

View File

@@ -63,28 +63,7 @@ pub struct InMemoryLayer {
opened_at: Instant,
/// All versions of all pages in the layer are kept here. Indexed
/// by block number and LSN. The [`IndexEntry`] is an offset into the
/// ephemeral file where the page version is stored.
///
/// We use a separate lock for the index to reduce the critical section
/// during which reads cannot be planned.
///
/// If you need access to both the index and the underlying file at the same time,
/// respect the following locking order to avoid deadlocks:
/// 1. [`InMemoryLayer::inner`]
/// 2. [`InMemoryLayer::index`]
///
/// Note that the file backing [`InMemoryLayer::inner`] is append-only,
/// so it is not necessary to hold simultaneous locks on index.
/// This avoids holding index locks across IO, and is crucial for avoiding read tail latency.
/// In particular:
/// 1. It is safe to read and release [`InMemoryLayer::index`] before locking and reading from [`InMemoryLayer::inner`].
/// 2. It is safe to write and release [`InMemoryLayer::inner`] before locking and updating [`InMemoryLayer::index`].
index: RwLock<BTreeMap<CompactKey, VecMap<Lsn, IndexEntry>>>,
/// The above fields never change, except for `end_lsn`, which is only set once,
/// and `index` (see rationale there).
/// The above fields never change, except for `end_lsn`, which is only set once.
/// All other changing parts are in `inner`, and protected by a mutex.
inner: RwLock<InMemoryLayerInner>,
@@ -102,6 +81,11 @@ impl std::fmt::Debug for InMemoryLayer {
}
pub struct InMemoryLayerInner {
/// All versions of all pages in the layer are kept here. Indexed
/// by block number and LSN. The [`IndexEntry`] is an offset into the
/// ephemeral file where the page version is stored.
index: BTreeMap<CompactKey, VecMap<Lsn, IndexEntry>>,
/// The values are stored in a serialized format in this file.
/// Each serialized Value is preceded by a 'u32' length field.
/// PerSeg::page_versions map stores offsets into this file.
@@ -121,7 +105,7 @@ const MAX_SUPPORTED_BLOB_LEN_BITS: usize = {
trailing_ones
};
/// See [`InMemoryLayer::index`].
/// See [`InMemoryLayerInner::index`].
///
/// For memory efficiency, the data is packed into a u64.
///
@@ -441,7 +425,7 @@ impl InMemoryLayer {
.page_content_kind(PageContentKind::InMemoryLayer)
.attached_child();
let index = self.index.read().await;
let inner = self.inner.read().await;
struct ValueRead {
entry_lsn: Lsn,
@@ -451,7 +435,10 @@ impl InMemoryLayer {
let mut ios: HashMap<(Key, Lsn), OnDiskValueIo> = Default::default();
for range in keyspace.ranges.iter() {
for (key, vec_map) in index.range(range.start.to_compact()..range.end.to_compact()) {
for (key, vec_map) in inner
.index
.range(range.start.to_compact()..range.end.to_compact())
{
let key = Key::from_compact(*key);
let slice = vec_map.slice_range(lsn_range.clone());
@@ -479,7 +466,7 @@ impl InMemoryLayer {
}
}
}
drop(index); // release the lock before we spawn the IO; if it's serial-mode IO we will deadlock on the read().await below
drop(inner); // release the lock before we spawn the IO; if it's serial-mode IO we will deadlock on the read().await below
let read_from = Arc::clone(self);
let read_ctx = ctx.attached_child();
reconstruct_state
@@ -586,8 +573,8 @@ impl InMemoryLayer {
start_lsn,
end_lsn: OnceLock::new(),
opened_at: Instant::now(),
index: RwLock::new(BTreeMap::new()),
inner: RwLock::new(InMemoryLayerInner {
index: BTreeMap::new(),
file,
resource_units: GlobalResourceUnits::new(),
}),
@@ -605,39 +592,31 @@ impl InMemoryLayer {
serialized_batch: SerializedValueBatch,
ctx: &RequestContext,
) -> anyhow::Result<()> {
let (base_offset, metadata) = {
let mut inner = self.inner.write().await;
self.assert_writable();
let mut inner = self.inner.write().await;
self.assert_writable();
let base_offset = inner.file.len();
let base_offset = inner.file.len();
let SerializedValueBatch {
raw,
metadata,
max_lsn: _,
len: _,
} = serialized_batch;
let SerializedValueBatch {
raw,
metadata,
max_lsn: _,
len: _,
} = serialized_batch;
// Write the batch to the file
inner.file.write_raw(&raw, ctx).await?;
let new_size = inner.file.len();
// Write the batch to the file
inner.file.write_raw(&raw, ctx).await?;
let new_size = inner.file.len();
let expected_new_len = base_offset
.checked_add(raw.len().into_u64())
// write_raw would error if we were to overflow u64.
// also IndexEntry and higher levels in
//the code don't allow the file to grow that large
.unwrap();
assert_eq!(new_size, expected_new_len);
inner.resource_units.maybe_publish_size(new_size);
(base_offset, metadata)
};
let expected_new_len = base_offset
.checked_add(raw.len().into_u64())
// write_raw would error if we were to overflow u64.
// also IndexEntry and higher levels in
//the code don't allow the file to grow that large
.unwrap();
assert_eq!(new_size, expected_new_len);
// Update the index with the new entries
let mut index = self.index.write().await;
for meta in metadata {
let SerializedValueMeta {
key,
@@ -660,7 +639,7 @@ impl InMemoryLayer {
will_init,
})?;
let vec_map = index.entry(key).or_default();
let vec_map = inner.index.entry(key).or_default();
let old = vec_map.append_or_update_last(lsn, index_entry).unwrap().0;
if old.is_some() {
// This should not break anything, but is unexpected: ingestion code aims to filter out
@@ -679,6 +658,8 @@ impl InMemoryLayer {
);
}
inner.resource_units.maybe_publish_size(new_size);
Ok(())
}
@@ -699,18 +680,6 @@ impl InMemoryLayer {
/// Records the end_lsn for non-dropped layers.
/// `end_lsn` is exclusive
///
/// A note on locking:
/// The current API of [`InMemoryLayer`] does not ensure that there's no ongoing
/// writes while freezing the layer. This is enforced at a higher level via
/// [`crate::tenant::Timeline::write_lock`]. Freeze might be called via two code paths:
/// 1. Via the active [`crate::tenant::timeline::TimelineWriter`]. This holds the
/// Timeline::write_lock for its lifetime. The rolling is handled in
/// [`crate::tenant::timeline::TimelineWriter::put_batch`]. It's a &mut self function
/// so can't be called from different threads.
/// 2. In the background via [`crate::tenant::Timeline::maybe_freeze_ephemeral_layer`].
/// This only proceeds if try_lock on Timeline::write_lock succeeds (i.e. there's no active writer),
/// hence there can be no concurrent writes
pub async fn freeze(&self, end_lsn: Lsn) {
assert!(
self.start_lsn < end_lsn,
@@ -731,8 +700,8 @@ impl InMemoryLayer {
#[cfg(debug_assertions)]
{
let index = self.index.read().await;
for vec_map in index.values() {
let inner = self.inner.write().await;
for vec_map in inner.index.values() {
for (lsn, _) in vec_map.as_slice() {
assert!(*lsn < end_lsn);
}
@@ -755,11 +724,14 @@ impl InMemoryLayer {
) -> Result<Option<(PersistentLayerDesc, Utf8PathBuf)>> {
// Grab the lock in read-mode. We hold it over the I/O, but because this
// layer is not writeable anymore, no one should be trying to acquire the
// write lock on it, so we shouldn't block anyone. See the comment on
// [`InMemoryLayer::freeze`] to understand how locking between the append path
// and layer flushing works.
// write lock on it, so we shouldn't block anyone. There's one exception
// though: another thread might have grabbed a reference to this layer
// in `get_layer_for_write' just before the checkpointer called
// `freeze`, and then `write_to_disk` on it. When the thread gets the
// lock, it will see that it's not writeable anymore and retry, but it
// would have to wait until we release it. That race condition is very
// rare though, so we just accept the potential latency hit for now.
let inner = self.inner.read().await;
let index = self.index.read().await;
use l0_flush::Inner;
let _concurrency_permit = match l0_flush_global_state {
@@ -771,9 +743,13 @@ impl InMemoryLayer {
let key_count = if let Some(key_range) = key_range {
let key_range = key_range.start.to_compact()..key_range.end.to_compact();
index.iter().filter(|(k, _)| key_range.contains(k)).count()
inner
.index
.iter()
.filter(|(k, _)| key_range.contains(k))
.count()
} else {
index.len()
inner.index.len()
};
if key_count == 0 {
return Ok(None);
@@ -796,7 +772,7 @@ impl InMemoryLayer {
let file_contents = inner.file.load_to_io_buf(ctx).await?;
let file_contents = file_contents.freeze();
for (key, vec_map) in index.iter() {
for (key, vec_map) in inner.index.iter() {
// Write all page versions
for (lsn, entry) in vec_map
.as_slice()

View File

@@ -24,6 +24,8 @@ use std::sync::atomic::{AtomicBool, AtomicU64, Ordering as AtomicOrdering};
use std::sync::{Arc, Mutex, OnceLock, RwLock, Weak};
use std::time::{Duration, Instant, SystemTime};
use crate::PERF_TRACE_TARGET;
use crate::walredo::RedoAttemptType;
use anyhow::{Context, Result, anyhow, bail, ensure};
use arc_swap::{ArcSwap, ArcSwapOption};
use bytes::Bytes;
@@ -92,18 +94,15 @@ use super::storage_layer::{LayerFringe, LayerVisibilityHint, ReadableLayer};
use super::tasks::log_compaction_error;
use super::upload_queue::NotInitialized;
use super::{
AttachedTenantConf, BasebackupPrepareSender, GcError, HeatMapTimeline, MaybeOffloaded,
AttachedTenantConf, GcError, HeatMapTimeline, MaybeOffloaded,
debug_assert_current_span_has_tenant_and_timeline_id,
};
use crate::PERF_TRACE_TARGET;
use crate::aux_file::AuxFileSizeEstimator;
use crate::basebackup_cache::BasebackupPrepareRequest;
use crate::config::PageServerConf;
use crate::context::{
DownloadBehavior, PerfInstrumentFutureExt, RequestContext, RequestContextBuilder,
};
use crate::disk_usage_eviction_task::{DiskUsageEvictionInfo, EvictionCandidate, finite_f32};
use crate::feature_resolver::FeatureResolver;
use crate::keyspace::{KeyPartitioning, KeySpace};
use crate::l0_flush::{self, L0FlushGlobalState};
use crate::metrics::{
@@ -132,7 +131,6 @@ use crate::tenant::tasks::BackgroundLoopKind;
use crate::tenant::timeline::logical_size::CurrentLogicalSize;
use crate::virtual_file::{MaybeFatalIo, VirtualFile};
use crate::walingest::WalLagCooldown;
use crate::walredo::RedoAttemptType;
use crate::{ZERO_PAGE, task_mgr, walredo};
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
@@ -198,8 +196,6 @@ pub struct TimelineResources {
pub pagestream_throttle_metrics: Arc<crate::metrics::tenant_throttling::Pagestream>,
pub l0_compaction_trigger: Arc<Notify>,
pub l0_flush_global_state: l0_flush::L0FlushGlobalState,
pub basebackup_prepare_sender: BasebackupPrepareSender,
pub feature_resolver: FeatureResolver,
}
pub struct Timeline {
@@ -443,11 +439,6 @@ pub struct Timeline {
pub(crate) rel_size_v2_status: ArcSwapOption<RelSizeMigration>,
wait_lsn_log_slow: tokio::sync::Semaphore,
/// A channel to send async requests to prepare a basebackup for the basebackup cache.
basebackup_prepare_sender: BasebackupPrepareSender,
feature_resolver: FeatureResolver,
}
pub(crate) enum PreviousHeatmap {
@@ -538,24 +529,29 @@ impl GcInfo {
/// The `GcInfo` component describing which Lsns need to be retained. Functionally, this
/// is a single number (the oldest LSN which we must retain), but it internally distinguishes
/// between time-based and space-based retention for observability and consumption metrics purposes.
#[derive(Clone, Debug, Default)]
#[derive(Debug, Clone)]
pub(crate) struct GcCutoffs {
/// Calculated from the [`pageserver_api::models::TenantConfig::gc_horizon`], this LSN indicates how much
/// history we must keep to retain a specified number of bytes of WAL.
pub(crate) space: Lsn,
/// Calculated from [`pageserver_api::models::TenantConfig::pitr_interval`], this LSN indicates
/// how much history we must keep to enable reading back at least the PITR interval duration.
///
/// None indicates that the PITR cutoff has not been computed. A PITR interval of 0 will yield
/// Some(last_record_lsn).
pub(crate) time: Option<Lsn>,
/// Calculated from [`pageserver_api::models::TenantConfig::pitr_interval`], this LSN indicates how much
/// history we must keep to enable reading back at least the PITR interval duration.
pub(crate) time: Lsn,
}
impl Default for GcCutoffs {
fn default() -> Self {
Self {
space: Lsn::INVALID,
time: Lsn::INVALID,
}
}
}
impl GcCutoffs {
fn select_min(&self) -> Lsn {
// NB: if we haven't computed the PITR cutoff yet, we can't GC anything.
self.space.min(self.time.unwrap_or_default())
std::cmp::min(self.space, self.time)
}
}
@@ -1037,7 +1033,6 @@ pub(crate) enum WaitLsnWaiter<'a> {
Tenant,
PageService,
HttpEndpoint,
BaseBackupCache,
}
/// Argument to [`Timeline::shutdown`].
@@ -1093,14 +1088,11 @@ impl Timeline {
/// Get the bytes written since the PITR cutoff on this branch, and
/// whether this branch's ancestor_lsn is within its parent's PITR.
pub(crate) fn get_pitr_history_stats(&self) -> (u64, bool) {
// TODO: for backwards compatibility, we return the full history back to 0 when the PITR
// cutoff has not yet been initialized. This should return None instead, but this is exposed
// in external HTTP APIs and callers may not handle a null value.
let gc_info = self.gc_info.read().unwrap();
let history = self
.get_last_record_lsn()
.checked_sub(gc_info.cutoffs.time.unwrap_or_default())
.unwrap_or_default()
.checked_sub(gc_info.cutoffs.time)
.unwrap_or(Lsn(0))
.0;
(history, gc_info.within_ancestor_pitr)
}
@@ -1110,10 +1102,9 @@ impl Timeline {
self.applied_gc_cutoff_lsn.read()
}
/// Read timeline's planned GC cutoff: this is the logical end of history that users are allowed
/// to read (based on configured PITR), even if physically we have more history. Returns None
/// if the PITR cutoff has not yet been initialized.
pub(crate) fn get_gc_cutoff_lsn(&self) -> Option<Lsn> {
/// Read timeline's planned GC cutoff: this is the logical end of history that users
/// are allowed to read (based on configured PITR), even if physically we have more history.
pub(crate) fn get_gc_cutoff_lsn(&self) -> Lsn {
self.gc_info.read().unwrap().cutoffs.time
}
@@ -1564,8 +1555,7 @@ impl Timeline {
}
WaitLsnWaiter::Tenant
| WaitLsnWaiter::PageService
| WaitLsnWaiter::HttpEndpoint
| WaitLsnWaiter::BaseBackupCache => unreachable!(
| WaitLsnWaiter::HttpEndpoint => unreachable!(
"tenant or page_service context are not expected to have task kind {:?}",
ctx.task_kind()
),
@@ -2470,41 +2460,6 @@ impl Timeline {
false
}
}
pub(crate) fn is_basebackup_cache_enabled(&self) -> bool {
let tenant_conf = self.tenant_conf.load();
tenant_conf
.tenant_conf
.basebackup_cache_enabled
.unwrap_or(self.conf.default_tenant_conf.basebackup_cache_enabled)
}
/// Prepare basebackup for the given LSN and store it in the basebackup cache.
/// The method is asynchronous and returns immediately.
/// The actual basebackup preparation is performed in the background
/// by the basebackup cache on a best-effort basis.
pub(crate) fn prepare_basebackup(&self, lsn: Lsn) {
if !self.is_basebackup_cache_enabled() {
return;
}
if !self.tenant_shard_id.is_shard_zero() {
// In theory we should never get here, but just in case check it.
// Preparing basebackup doesn't make sense for shards other than shard zero.
return;
}
let res = self
.basebackup_prepare_sender
.send(BasebackupPrepareRequest {
tenant_shard_id: self.tenant_shard_id,
timeline_id: self.timeline_id,
lsn,
});
if let Err(e) = res {
// May happen during shutdown, it's not critical.
info!("Failed to send shutdown checkpoint: {e:#}");
}
}
}
/// Number of times we will compute partition within a checkpoint distance.
@@ -2582,13 +2537,6 @@ impl Timeline {
.unwrap_or(self.conf.default_tenant_conf.checkpoint_timeout)
}
pub(crate) fn get_pitr_interval(&self) -> Duration {
let tenant_conf = &self.tenant_conf.load().tenant_conf;
tenant_conf
.pitr_interval
.unwrap_or(self.conf.default_tenant_conf.pitr_interval)
}
fn get_compaction_period(&self) -> Duration {
let tenant_conf = self.tenant_conf.load().tenant_conf.clone();
tenant_conf
@@ -3074,10 +3022,6 @@ impl Timeline {
rel_size_v2_status: ArcSwapOption::from_pointee(rel_size_v2_status),
wait_lsn_log_slow: tokio::sync::Semaphore::new(1),
basebackup_prepare_sender: resources.basebackup_prepare_sender,
feature_resolver: resources.feature_resolver,
};
result.repartition_threshold =
@@ -4912,7 +4856,6 @@ impl Timeline {
LastImageLayerCreationStatus::Initial,
false, // don't yield for L0, we're flushing L0
)
.instrument(info_span!("create_image_layers", mode = %ImageLayerCreationMode::Initial, partition_mode = "initial", lsn = %self.initdb_lsn))
.await?;
debug_assert!(
matches!(is_complete, LastImageLayerCreationStatus::Complete),
@@ -5469,8 +5412,7 @@ impl Timeline {
/// Returns the image layers generated and an enum indicating whether the process is fully completed.
/// true = we have generate all image layers, false = we preempt the process for L0 compaction.
///
/// `partition_mode` is only for logging purpose and is not used anywhere in this function.
#[tracing::instrument(skip_all, fields(%lsn, %mode))]
async fn create_image_layers(
self: &Arc<Timeline>,
partitioning: &KeyPartitioning,
@@ -6293,12 +6235,14 @@ impl Timeline {
pausable_failpoint!("Timeline::find_gc_cutoffs-pausable");
if cfg!(test) && pitr == Duration::ZERO {
if cfg!(test) {
// Unit tests which specify zero PITR interval expect to avoid doing any I/O for timestamp lookup
return Ok(GcCutoffs {
time: Some(self.get_last_record_lsn()),
space: space_cutoff,
});
if pitr == Duration::ZERO {
return Ok(GcCutoffs {
time: self.get_last_record_lsn(),
space: space_cutoff,
});
}
}
// Calculate a time-based limit on how much to retain:
@@ -6312,14 +6256,14 @@ impl Timeline {
// PITR is not set. Retain the size-based limit, or the default time retention,
// whichever requires less data.
GcCutoffs {
time: Some(self.get_last_record_lsn()),
time: self.get_last_record_lsn(),
space: std::cmp::max(time_cutoff, space_cutoff),
}
}
(Duration::ZERO, None) => {
// PITR is not set, and time lookup failed
GcCutoffs {
time: Some(self.get_last_record_lsn()),
time: self.get_last_record_lsn(),
space: space_cutoff,
}
}
@@ -6327,7 +6271,7 @@ impl Timeline {
// PITR interval is set & we didn't look up a timestamp successfully. Conservatively assume PITR
// cannot advance beyond what was already GC'd, and respect space-based retention
GcCutoffs {
time: Some(*self.get_applied_gc_cutoff_lsn()),
time: *self.get_applied_gc_cutoff_lsn(),
space: space_cutoff,
}
}
@@ -6335,7 +6279,7 @@ impl Timeline {
// PITR interval is set and we looked up timestamp successfully. Ignore
// size based retention and make time cutoff authoritative
GcCutoffs {
time: Some(time_cutoff),
time: time_cutoff,
space: time_cutoff,
}
}
@@ -6388,7 +6332,7 @@ impl Timeline {
)
};
let mut new_gc_cutoff = space_cutoff.min(time_cutoff.unwrap_or_default());
let mut new_gc_cutoff = Lsn::min(space_cutoff, time_cutoff);
let standby_horizon = self.standby_horizon.load();
// Hold GC for the standby, but as a safety guard do it only within some
// reasonable lag.
@@ -6437,7 +6381,7 @@ impl Timeline {
async fn gc_timeline(
&self,
space_cutoff: Lsn,
time_cutoff: Option<Lsn>, // None if uninitialized
time_cutoff: Lsn,
retain_lsns: Vec<Lsn>,
max_lsn_with_valid_lease: Option<Lsn>,
new_gc_cutoff: Lsn,
@@ -6456,12 +6400,6 @@ impl Timeline {
return Ok(result);
}
let Some(time_cutoff) = time_cutoff else {
// The GC cutoff should have been computed by now, but let's be defensive.
info!("Nothing to GC: time_cutoff not yet computed");
return Ok(result);
};
// We need to ensure that no one tries to read page versions or create
// branches at a point before latest_gc_cutoff_lsn. See branch_timeline()
// for details. This will block until the old value is no longer in use.

View File

@@ -1278,55 +1278,11 @@ impl Timeline {
}
let gc_cutoff = *self.applied_gc_cutoff_lsn.read();
let l0_l1_boundary_lsn = {
// We do the repartition on the L0-L1 boundary. All data below the boundary
// are compacted by L0 with low read amplification, thus making the `repartition`
// function run fast.
let guard = self.layers.read().await;
guard
.all_persistent_layers()
.iter()
.map(|x| {
// Use the end LSN of delta layers OR the start LSN of image layers.
if x.is_delta {
x.lsn_range.end
} else {
x.lsn_range.start
}
})
.max()
};
let (partition_mode, partition_lsn) = if cfg!(test)
|| cfg!(feature = "testing")
|| self
.feature_resolver
.evaluate_boolean("image-compaction-boundary", self.tenant_shard_id.tenant_id)
.is_ok()
{
let last_repartition_lsn = self.partitioning.read().1;
let lsn = match l0_l1_boundary_lsn {
Some(boundary) => gc_cutoff
.max(boundary)
.max(last_repartition_lsn)
.max(self.initdb_lsn)
.max(self.ancestor_lsn),
None => self.get_last_record_lsn(),
};
if lsn <= self.initdb_lsn || lsn <= self.ancestor_lsn {
// Do not attempt to create image layers below the initdb or ancestor LSN -- no data below it
("l0_l1_boundary", self.get_last_record_lsn())
} else {
("l0_l1_boundary", lsn)
}
} else {
("latest_record", self.get_last_record_lsn())
};
// 2. Repartition and create image layers if necessary
match self
.repartition(
partition_lsn,
self.get_last_record_lsn(),
self.get_compaction_target_size(),
options.flags,
ctx,
@@ -1345,19 +1301,18 @@ impl Timeline {
.extend(sparse_partitioning.into_dense().parts);
// 3. Create new image layers for partitions that have been modified "enough".
let mode = if options
.flags
.contains(CompactFlags::ForceImageLayerCreation)
{
ImageLayerCreationMode::Force
} else {
ImageLayerCreationMode::Try
};
let (image_layers, outcome) = self
.create_image_layers(
&partitioning,
lsn,
mode,
if options
.flags
.contains(CompactFlags::ForceImageLayerCreation)
{
ImageLayerCreationMode::Force
} else {
ImageLayerCreationMode::Try
},
&image_ctx,
self.last_image_layer_creation_status
.load()
@@ -1365,7 +1320,6 @@ impl Timeline {
.clone(),
options.flags.contains(CompactFlags::YieldForL0),
)
.instrument(info_span!("create_image_layers", mode = %mode, partition_mode = %partition_mode, lsn = %lsn))
.await
.inspect_err(|err| {
if let CreateImageLayersError::GetVectoredError(
@@ -1390,8 +1344,7 @@ impl Timeline {
}
Ok(_) => {
// This happens very frequently so we don't want to log it.
debug!("skipping repartitioning due to image compaction LSN being below GC cutoff");
info!("skipping repartitioning due to image compaction LSN being below GC cutoff");
}
// Suppress errors when cancelled.
@@ -1573,7 +1526,7 @@ impl Timeline {
info!(
"starting shard ancestor compaction, rewriting {} layers and dropping {} layers, \
checked {layers_checked}/{layers_total} layers \
(latest_gc_cutoff={} pitr_cutoff={:?})",
(latest_gc_cutoff={} pitr_cutoff={})",
layers_to_rewrite.len(),
drop_layers.len(),
*latest_gc_cutoff,

View File

@@ -25,11 +25,8 @@ pub(crate) struct ImportingTimeline {
}
impl ImportingTimeline {
pub(crate) async fn shutdown(self) {
pub(crate) fn shutdown(self) {
self.import_task_handle.abort();
let _ = self.import_task_handle.await;
self.timeline.remote_client.shutdown().await;
}
}
@@ -96,11 +93,6 @@ pub async fn doit(
);
}
timeline
.remote_client
.schedule_index_upload_for_file_changes()?;
timeline.remote_client.wait_completion().await?;
// Communicate that shard is done.
// Ensure at-least-once delivery of the upcall to storage controller
// before we mark the task as done and never come here again.

View File

@@ -30,7 +30,6 @@
use std::collections::HashSet;
use std::hash::{Hash, Hasher};
use std::num::NonZeroUsize;
use std::ops::Range;
use std::sync::Arc;
@@ -101,24 +100,8 @@ async fn run_v1(
tasks: Vec::default(),
};
// Use the job size limit encoded in the progress if we are resuming an import.
// This ensures that imports have stable plans even if the pageserver config changes.
let import_config = {
match &import_progress {
Some(progress) => {
let base = &timeline.conf.timeline_import_config;
TimelineImportConfig {
import_job_soft_size_limit: NonZeroUsize::new(progress.job_soft_size_limit)
.unwrap(),
import_job_concurrency: base.import_job_concurrency,
import_job_checkpoint_threshold: base.import_job_checkpoint_threshold,
}
}
None => timeline.conf.timeline_import_config.clone(),
}
};
let plan = planner.plan(&import_config).await?;
let import_config = &timeline.conf.timeline_import_config;
let plan = planner.plan(import_config).await?;
// Hash the plan and compare with the hash of the plan we got back from the storage controller.
// If the two match, it means that the planning stage had the same output.
@@ -130,20 +113,20 @@ async fn run_v1(
let plan_hash = hasher.finish();
if let Some(progress) = &import_progress {
if plan_hash != progress.import_plan_hash {
anyhow::bail!("Import plan does not match storcon metadata");
}
// Handle collisions on jobs of unequal length
if progress.jobs != plan.jobs.len() {
anyhow::bail!("Import plan job length does not match storcon metadata")
}
if plan_hash != progress.import_plan_hash {
anyhow::bail!("Import plan does not match storcon metadata");
}
}
pausable_failpoint!("import-timeline-pre-execute-pausable");
let start_from_job_idx = import_progress.map(|progress| progress.completed);
plan.execute(timeline, start_from_job_idx, plan_hash, &import_config, ctx)
plan.execute(timeline, start_from_job_idx, plan_hash, import_config, ctx)
.await
}
@@ -235,19 +218,6 @@ impl Planner {
checkpoint_buf,
)));
// Sort the tasks by the key ranges they handle.
// The plan being generated here needs to be stable across invocations
// of this method.
self.tasks.sort_by_key(|task| match task {
AnyImportTask::SingleKey(key) => (key.key, key.key.next()),
AnyImportTask::RelBlocks(rel_blocks) => {
(rel_blocks.key_range.start, rel_blocks.key_range.end)
}
AnyImportTask::SlruBlocks(slru_blocks) => {
(slru_blocks.key_range.start, slru_blocks.key_range.end)
}
});
// Assigns parts of key space to later parallel jobs
let mut last_end_key = Key::MIN;
let mut current_chunk = Vec::new();
@@ -456,8 +426,6 @@ impl Plan {
}));
},
maybe_complete_job_idx = work.next() => {
pausable_failpoint!("import-task-complete-pausable");
match maybe_complete_job_idx {
Some(Ok((job_idx, res))) => {
assert!(last_completed_job_idx.checked_add(1).unwrap() == job_idx);
@@ -470,12 +438,8 @@ impl Plan {
jobs: jobs_in_plan,
completed: last_completed_job_idx,
import_plan_hash,
job_soft_size_limit: import_config.import_job_soft_size_limit.into(),
};
timeline.remote_client.schedule_index_upload_for_file_changes()?;
timeline.remote_client.wait_completion().await?;
storcon_client.put_timeline_import_status(
timeline.tenant_shard_id,
timeline.timeline_id,
@@ -676,11 +640,7 @@ impl Hash for ImportSingleKeyTask {
let ImportSingleKeyTask { key, buf } = self;
key.hash(state);
// The key value might not have a stable binary representation.
// For instance, the db directory uses an unstable hash-map.
// To work around this we are a bit lax here and only hash the
// size of the buffer which must be consistent.
buf.len().hash(state);
buf.hash(state);
}
}
@@ -955,7 +915,7 @@ impl ChunkProcessingJob {
let guard = timeline.layers.read().await;
let existing_layer = guard.try_get_from_key(&desc.key());
if let Some(layer) = existing_layer {
if layer.metadata().generation == timeline.generation {
if layer.metadata().generation != timeline.generation {
return Err(anyhow::anyhow!(
"Import attempted to rewrite layer file in the same generation: {}",
layer.local_path()

View File

@@ -1316,10 +1316,6 @@ impl WalIngest {
}
});
if info == pg_constants::XLOG_CHECKPOINT_SHUTDOWN {
modification.tline.prepare_basebackup(lsn);
}
Ok(())
}

View File

@@ -717,7 +717,7 @@ prefetch_read(PrefetchRequest *slot)
Assert(slot->status == PRFS_REQUESTED);
Assert(slot->response == NULL);
Assert(slot->my_ring_index == MyPState->ring_receive);
Assert(readpage_reentrant_guard || AmPrewarmWorker);
Assert(readpage_reentrant_guard);
if (slot->status != PRFS_REQUESTED ||
slot->response != NULL ||
@@ -800,7 +800,7 @@ communicator_prefetch_receive(BufferTag tag)
PrfHashEntry *entry;
PrefetchRequest hashkey;
Assert(readpage_reentrant_guard || AmPrewarmWorker); /* do not pump prefetch state in prewarm worker */
Assert(readpage_reentrant_guard);
hashkey.buftag = tag;
entry = prfh_lookup(MyPState->prf_hash, &hashkey);
if (entry != NULL && prefetch_wait_for(entry->slot->my_ring_index))
@@ -2450,7 +2450,6 @@ void
communicator_reconfigure_timeout_if_needed(void)
{
bool needs_set = MyPState->ring_receive != MyPState->ring_unused &&
!AmPrewarmWorker && /* do not pump prefetch state in prewarm worker */
readahead_getpage_pull_timeout_ms > 0;
if (needs_set != timeout_set)

View File

@@ -201,8 +201,6 @@ static shmem_request_hook_type prev_shmem_request_hook;
bool lfc_store_prefetch_result;
bool lfc_prewarm_update_ws_estimation;
bool AmPrewarmWorker;
#define LFC_ENABLED() (lfc_ctl->limit != 0)
/*
@@ -847,8 +845,6 @@ lfc_prewarm_main(Datum main_arg)
PrewarmWorkerState* ws;
uint32 worker_id = DatumGetInt32(main_arg);
AmPrewarmWorker = true;
pqsignal(SIGTERM, die);
BackgroundWorkerUnblockSignals();

View File

@@ -23,8 +23,6 @@ extern int wal_acceptor_connection_timeout;
extern int readahead_getpage_pull_timeout_ms;
extern bool disable_wal_prev_lsn_checks;
extern bool AmPrewarmWorker;
#if PG_MAJORVERSION_NUM >= 17
extern uint32 WAIT_EVENT_NEON_LFC_MAINTENANCE;
extern uint32 WAIT_EVENT_NEON_LFC_READ;

View File

@@ -155,9 +155,8 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api)
int written = 0;
written = snprintf((char *) &sk->conninfo, MAXCONNINFO,
"%s host=%s port=%s dbname=replication options='-c timeline_id=%s tenant_id=%s'",
wp->config->safekeeper_conninfo_options, sk->host, sk->port,
wp->config->neon_timeline, wp->config->neon_tenant);
"host=%s port=%s dbname=replication options='-c timeline_id=%s tenant_id=%s'",
sk->host, sk->port, wp->config->neon_timeline, wp->config->neon_tenant);
if (written > MAXCONNINFO || written < 0)
wp_log(FATAL, "could not create connection string for safekeeper %s:%s", sk->host, sk->port);
}

View File

@@ -714,9 +714,6 @@ typedef struct WalProposerConfig
*/
char *safekeepers_list;
/* libpq connection info options. */
char *safekeeper_conninfo_options;
/*
* WalProposer reconnects to offline safekeepers once in this interval.
* Time is in milliseconds.

View File

@@ -64,7 +64,6 @@ char *wal_acceptors_list = "";
int wal_acceptor_reconnect_timeout = 1000;
int wal_acceptor_connection_timeout = 10000;
int safekeeper_proto_version = 3;
char *safekeeper_conninfo_options = "";
/* Set to true in the walproposer bgw. */
static bool am_walproposer;
@@ -120,7 +119,6 @@ init_walprop_config(bool syncSafekeepers)
walprop_config.neon_timeline = neon_timeline;
/* WalProposerCreate scribbles directly on it, so pstrdup */
walprop_config.safekeepers_list = pstrdup(wal_acceptors_list);
walprop_config.safekeeper_conninfo_options = pstrdup(safekeeper_conninfo_options);
walprop_config.safekeeper_reconnect_timeout = wal_acceptor_reconnect_timeout;
walprop_config.safekeeper_connection_timeout = wal_acceptor_connection_timeout;
walprop_config.wal_segment_size = wal_segment_size;
@@ -205,16 +203,6 @@ nwp_register_gucs(void)
* GUC_LIST_QUOTE */
NULL, assign_neon_safekeepers, NULL);
DefineCustomStringVariable(
"neon.safekeeper_conninfo_options",
"libpq keyword parameters and values to apply to safekeeper connections",
NULL,
&safekeeper_conninfo_options,
"",
PGC_POSTMASTER,
0,
NULL, NULL, NULL);
DefineCustomIntVariable(
"neon.safekeeper_reconnect_timeout",
"Walproposer reconnects to offline safekeepers once in this interval.",

17
poetry.lock generated
View File

@@ -3170,24 +3170,19 @@ pbr = "*"
[[package]]
name = "setuptools"
version = "78.1.1"
version = "70.0.0"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
optional = false
python-versions = ">=3.9"
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "setuptools-78.1.1-py3-none-any.whl", hash = "sha256:c3a9c4211ff4c309edb8b8c4f1cbfa7ae324c4ba9f91ff254e3d305b9fd54561"},
{file = "setuptools-78.1.1.tar.gz", hash = "sha256:fcc17fd9cd898242f6b4adfaca46137a9edef687f43e6f78469692a5e70d851d"},
{file = "setuptools-70.0.0-py3-none-any.whl", hash = "sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4"},
{file = "setuptools-70.0.0.tar.gz", hash = "sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0"},
]
[package.extras]
check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""]
core = ["importlib_metadata (>=6) ; python_version < \"3.10\"", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1) ; python_version < \"3.11\"", "wheel (>=0.43.0)"]
cover = ["pytest-cov"]
doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]
enabler = ["pytest-enabler (>=2.2)"]
test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"]
type = ["importlib_metadata (>=7.0.2) ; python_version < \"3.10\"", "jaraco.develop (>=7.21) ; sys_platform != \"cygwin\"", "mypy (==1.14.*)", "pytest-mypy"]
docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier"]
testing = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "importlib-metadata", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "mypy (==1.9)", "packaging (>=23.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.1)", "pytest-checkdocs (>=2.4)", "pytest-cov ; platform_python_implementation != \"PyPy\"", "pytest-enabler (>=2.2)", "pytest-home (>=0.5)", "pytest-mypy", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel"]
[[package]]
name = "six"

View File

@@ -127,4 +127,3 @@ rstest.workspace = true
walkdir.workspace = true
rand_distr = "0.4"
tokio-postgres.workspace = true
tracing-test = "0.2"

View File

@@ -80,22 +80,10 @@ impl std::fmt::Display for Backend<'_, ()> {
.field(&endpoint.url())
.finish(),
#[cfg(any(test, feature = "testing"))]
ControlPlaneClient::PostgresMock(endpoint) => {
let url = endpoint.url();
match url::Url::parse(url) {
Ok(mut url) => {
let _ = url.set_password(Some("_redacted_"));
let url = url.as_str();
fmt.debug_tuple("ControlPlane::PostgresMock")
.field(&url)
.finish()
}
Err(_) => fmt
.debug_tuple("ControlPlane::PostgresMock")
.field(&url)
.finish(),
}
}
ControlPlaneClient::PostgresMock(endpoint) => fmt
.debug_tuple("ControlPlane::PostgresMock")
.field(&endpoint.url())
.finish(),
#[cfg(test)]
ControlPlaneClient::Test(_) => fmt.debug_tuple("ControlPlane::Test").finish(),
},
@@ -548,7 +536,8 @@ mod tests {
use control_plane::AuthSecret;
use fallible_iterator::FallibleIterator;
use once_cell::sync::Lazy;
use postgres_protocol::authentication::sasl::{ChannelBinding, ScramSha256};
use postgres_protocol::CSafeStr;
use postgres_protocol::authentication::sasl::{ChannelBinding, SCRAM_SHA_256, ScramSha256};
use postgres_protocol::message::backend::Message as PgMessage;
use postgres_protocol::message::frontend;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
@@ -726,15 +715,15 @@ mod tests {
// server should offer scram
match read_message(&mut client, &mut read).await {
PgMessage::AuthenticationSasl(a) => {
let options: Vec<&str> = a.mechanisms().collect().unwrap();
assert_eq!(options, ["SCRAM-SHA-256"]);
let options: Vec<&CSafeStr> = a.mechanisms().collect().unwrap();
assert_eq!(options, [SCRAM_SHA_256]);
}
_ => panic!("wrong message"),
}
// client sends client-first-message
let mut write = BytesMut::new();
frontend::sasl_initial_response("SCRAM-SHA-256", scram.message(), &mut write).unwrap();
frontend::sasl_initial_response(SCRAM_SHA_256, scram.message(), &mut write);
client.write_all(&write).await.unwrap();
// server response with server-first-message
@@ -747,7 +736,7 @@ mod tests {
// client response with client-final-message
write.clear();
frontend::sasl_response(scram.message(), &mut write).unwrap();
frontend::sasl_response(scram.message(), &mut write);
client.write_all(&write).await.unwrap();
// server response with server-final-message
@@ -812,7 +801,7 @@ mod tests {
// client responds with password
write.clear();
frontend::password_message(b"my-secret-password", &mut write).unwrap();
frontend::password_message(c"my-secret-password".into(), &mut write);
client.write_all(&write).await.unwrap();
});
let endpoint_rate_limiter = Arc::new(EndpointRateLimiter::new_with_shards(
@@ -865,8 +854,10 @@ mod tests {
// client responds with password
let mut write = BytesMut::new();
frontend::password_message(b"endpoint=my-endpoint;my-secret-password", &mut write)
.unwrap();
frontend::password_message(
c"endpoint=my-endpoint;my-secret-password".into(),
&mut write,
);
client.write_all(&write).await.unwrap();
});

View File

@@ -3,7 +3,6 @@
use std::io;
use std::sync::Arc;
use postgres_protocol::authentication::sasl::{SCRAM_SHA_256, SCRAM_SHA_256_PLUS};
use pq_proto::{BeAuthenticationSaslMessage, BeMessage, BeMessage as Be};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::info;
@@ -174,8 +173,10 @@ impl<S: AsyncRead + AsyncWrite + Unpin> AuthFlow<'_, S, Scram<'_>> {
}
match sasl.method {
SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
SCRAM_SHA_256_PLUS => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus),
scram::SCRAM_SHA_256 => ctx.set_auth_method(crate::context::AuthMethod::ScramSha256),
scram::SCRAM_SHA_256_PLUS => {
ctx.set_auth_method(crate::context::AuthMethod::ScramSha256Plus)
}
_ => {}
}

View File

@@ -1,13 +1,9 @@
#[cfg(any(test, feature = "testing"))]
use std::env;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::pin::pin;
use std::sync::Arc;
use std::time::Duration;
#[cfg(any(test, feature = "testing"))]
use anyhow::Context;
use anyhow::{bail, ensure};
use arc_swap::ArcSwapOption;
use futures::future::Either;
@@ -39,8 +35,6 @@ use crate::scram::threadpool::ThreadPool;
use crate::serverless::GlobalConnPoolOptions;
use crate::serverless::cancel_set::CancelSet;
use crate::tls::client_config::compute_client_config_with_root_certs;
#[cfg(any(test, feature = "testing"))]
use crate::url::ApiUrl;
use crate::{auth, control_plane, http, serverless, usage_metrics};
project_git_version!(GIT_VERSION);
@@ -783,13 +777,7 @@ fn build_auth_backend(
#[cfg(any(test, feature = "testing"))]
AuthBackendType::Postgres => {
let mut url: ApiUrl = args.auth_endpoint.parse()?;
if url.password().is_none() {
let password = env::var("PGPASSWORD")
.with_context(|| "auth-endpoint does not contain a password and environment variable `PGPASSWORD` is not set")?;
url.set_password(Some(&password))
.expect("Failed to set password");
}
let url = args.auth_endpoint.parse()?;
let api = control_plane::client::mock::MockControlPlane::new(
url,
!args.is_private_access_proxy,

View File

@@ -48,7 +48,7 @@ impl ShouldRetryWakeCompute for postgres_client::error::DbError {
use postgres_client::error::SqlState;
// Here are errors that happens after the user successfully authenticated to the database.
// TODO: there are pgbouncer errors that should be retried, but they are not listed here.
let non_retriable_pg_errors = matches!(
!matches!(
self.code(),
&SqlState::TOO_MANY_CONNECTIONS
| &SqlState::OUT_OF_MEMORY
@@ -56,20 +56,8 @@ impl ShouldRetryWakeCompute for postgres_client::error::DbError {
| &SqlState::T_R_SERIALIZATION_FAILURE
| &SqlState::INVALID_CATALOG_NAME
| &SqlState::INVALID_SCHEMA_NAME
| &SqlState::INVALID_PARAMETER_VALUE,
);
if non_retriable_pg_errors {
return false;
}
// PGBouncer errors that should not trigger a wake_compute retry.
if self.code() == &SqlState::PROTOCOL_VIOLATION {
// Source for the error message:
// https://github.com/pgbouncer/pgbouncer/blob/f15997fe3effe3a94ba8bcc1ea562e6117d1a131/src/client.c#L1070
return !self
.message()
.contains("no more connections allowed (max_client_conn)");
}
true
| &SqlState::INVALID_PARAMETER_VALUE
)
}
}
@@ -122,55 +110,3 @@ pub(crate) fn retry_after(num_retries: u32, config: RetryConfig) -> time::Durati
.base_delay
.mul_f64(config.backoff_factor.powi((num_retries as i32) - 1))
}
#[cfg(test)]
mod tests {
use super::ShouldRetryWakeCompute;
use postgres_client::error::{DbError, SqlState};
#[test]
fn should_retry_wake_compute_for_db_error() {
// These SQLStates should NOT trigger a wake_compute retry.
let non_retry_states = [
SqlState::TOO_MANY_CONNECTIONS,
SqlState::OUT_OF_MEMORY,
SqlState::SYNTAX_ERROR,
SqlState::T_R_SERIALIZATION_FAILURE,
SqlState::INVALID_CATALOG_NAME,
SqlState::INVALID_SCHEMA_NAME,
SqlState::INVALID_PARAMETER_VALUE,
];
for state in non_retry_states {
let err = DbError::new_test_error(state.clone(), "oops".to_string());
assert!(
!err.should_retry_wake_compute(),
"State {state:?} unexpectedly retried"
);
}
// Errors coming from pgbouncer should not trigger a wake_compute retry
let non_retry_pgbouncer_errors = ["no more connections allowed (max_client_conn)"];
for error in non_retry_pgbouncer_errors {
let err = DbError::new_test_error(SqlState::PROTOCOL_VIOLATION, error.to_string());
assert!(
!err.should_retry_wake_compute(),
"PGBouncer error {error:?} unexpectedly retried"
);
}
// These SQLStates should trigger a wake_compute retry.
let retry_states = [
SqlState::CONNECTION_FAILURE,
SqlState::CONNECTION_EXCEPTION,
SqlState::CONNECTION_DOES_NOT_EXIST,
SqlState::SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION,
];
for state in retry_states {
let err = DbError::new_test_error(state.clone(), "oops".to_string());
assert!(
err.should_retry_wake_compute(),
"State {state:?} unexpectedly skipped retry"
);
}
}
}

View File

@@ -9,7 +9,7 @@ use std::fmt::Debug;
use bytes::{Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use postgres_client::tls::TlsConnect;
use postgres_protocol::message::frontend;
use postgres_protocol::{authentication::sasl::SCRAM_SHA_256, message::frontend};
use tokio::io::{AsyncReadExt, DuplexStream};
use tokio_util::codec::{Decoder, Encoder};
@@ -60,8 +60,7 @@ async fn proxy_mitm(
params: startup.params.into(),
},
&mut buf,
)
.unwrap();
);
end_server.send(buf.freeze()).await.unwrap();
// proxy messages between end_client and end_server
@@ -90,7 +89,7 @@ async fn proxy_mitm(
new_message.extend_from_slice(sasl_message.strip_prefix(b"p=tls-server-end-point,,").unwrap());
let mut buf = BytesMut::new();
frontend::sasl_initial_response("SCRAM-SHA-256", &new_message, &mut buf).unwrap();
frontend::sasl_initial_response(SCRAM_SHA_256, &new_message, &mut buf);
end_server.send(buf.freeze()).await.unwrap();
continue;

View File

@@ -15,7 +15,6 @@ use rstest::rstest;
use rustls::crypto::ring;
use rustls::pki_types;
use tokio::io::DuplexStream;
use tracing_test::traced_test;
use super::connect_compute::ConnectMechanism;
use super::retry::CouldRetry;
@@ -382,14 +381,8 @@ enum ConnectAction {
WakeFail,
WakeRetry,
Connect,
// connect_once -> Err, could_retry = true, should_retry_wake_compute = true
Retry,
// connect_once -> Err, could_retry = true, should_retry_wake_compute = false
RetryNoWake,
// connect_once -> Err, could_retry = false, should_retry_wake_compute = true
Fail,
// connect_once -> Err, could_retry = false, should_retry_wake_compute = false
FailNoWake,
}
#[derive(Clone)]
@@ -431,7 +424,6 @@ struct TestConnection;
#[derive(Debug)]
struct TestConnectError {
retryable: bool,
wakeable: bool,
kind: crate::error::ErrorKind,
}
@@ -456,7 +448,7 @@ impl CouldRetry for TestConnectError {
}
impl ShouldRetryWakeCompute for TestConnectError {
fn should_retry_wake_compute(&self) -> bool {
self.wakeable
true
}
}
@@ -479,22 +471,10 @@ impl ConnectMechanism for TestConnectMechanism {
ConnectAction::Connect => Ok(TestConnection),
ConnectAction::Retry => Err(TestConnectError {
retryable: true,
wakeable: true,
kind: ErrorKind::Compute,
}),
ConnectAction::RetryNoWake => Err(TestConnectError {
retryable: true,
wakeable: false,
kind: ErrorKind::Compute,
}),
ConnectAction::Fail => Err(TestConnectError {
retryable: false,
wakeable: true,
kind: ErrorKind::Compute,
}),
ConnectAction::FailNoWake => Err(TestConnectError {
retryable: false,
wakeable: false,
kind: ErrorKind::Compute,
}),
x => panic!("expecting action {x:?}, connect is called instead"),
@@ -729,92 +709,3 @@ async fn wake_non_retry() {
.unwrap_err();
mechanism.verify();
}
#[tokio::test]
#[traced_test]
async fn fail_but_wake_invalidates_cache() {
let ctx = RequestContext::test();
let mech = TestConnectMechanism::new(vec![
ConnectAction::Wake,
ConnectAction::Fail,
ConnectAction::Wake,
ConnectAction::Connect,
]);
let user = helper_create_connect_info(&mech);
let cfg = config();
connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg)
.await
.unwrap();
assert!(logs_contain(
"invalidating stalled compute node info cache entry"
));
}
#[tokio::test]
#[traced_test]
async fn fail_no_wake_skips_cache_invalidation() {
let ctx = RequestContext::test();
let mech = TestConnectMechanism::new(vec![
ConnectAction::Wake,
ConnectAction::FailNoWake,
ConnectAction::Connect,
]);
let user = helper_create_connect_info(&mech);
let cfg = config();
connect_to_compute(&ctx, &mech, &user, cfg.retry, &cfg)
.await
.unwrap();
assert!(!logs_contain(
"invalidating stalled compute node info cache entry"
));
}
#[tokio::test]
#[traced_test]
async fn retry_but_wake_invalidates_cache() {
let _ = env_logger::try_init();
use ConnectAction::*;
let ctx = RequestContext::test();
// Wake → Retry (retryable + wakeable) → Wake → Connect
let mechanism = TestConnectMechanism::new(vec![Wake, Retry, Wake, Connect]);
let user_info = helper_create_connect_info(&mechanism);
let cfg = config();
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
.await
.unwrap();
mechanism.verify();
// Because Retry has wakeable=true, we should see invalidate_cache
assert!(logs_contain(
"invalidating stalled compute node info cache entry"
));
}
#[tokio::test]
#[traced_test]
async fn retry_no_wake_skips_invalidation() {
let _ = env_logger::try_init();
use ConnectAction::*;
let ctx = RequestContext::test();
// Wake → RetryNoWake (retryable + NOT wakeable)
let mechanism = TestConnectMechanism::new(vec![Wake, RetryNoWake]);
let user_info = helper_create_connect_info(&mechanism);
let cfg = config();
connect_to_compute(&ctx, &mechanism, &user_info, cfg.retry, &cfg)
.await
.unwrap_err();
mechanism.verify();
// Because RetryNoWake has wakeable=false, we must NOT see invalidate_cache
assert!(!logs_contain(
"invalidating stalled compute node info cache entry"
));
}

View File

@@ -21,8 +21,8 @@ pub(crate) use key::ScramKey;
pub(crate) use secret::ServerSecret;
use sha2::{Digest, Sha256};
const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
pub(crate) const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
pub(crate) const SCRAM_SHA_256_PLUS: &str = "SCRAM-SHA-256-PLUS";
/// A list of supported SCRAM methods.
pub(crate) const METHODS: &[&str] = &[SCRAM_SHA_256_PLUS, SCRAM_SHA_256];

View File

@@ -14,9 +14,7 @@ use hyper::http::{HeaderName, HeaderValue};
use hyper::{HeaderMap, Request, Response, StatusCode, header};
use indexmap::IndexMap;
use postgres_client::error::{DbError, ErrorPosition, SqlState};
use postgres_client::{
GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction,
};
use postgres_client::{GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, Transaction};
use pq_proto::StartupMessageParamsBuilder;
use serde::Serialize;
use serde_json::Value;
@@ -1094,41 +1092,22 @@ async fn query_to_json<T: GenericClient>(
let query_start = Instant::now();
let query_params = data.params;
let mut row_stream = client
.query_raw_txt(&data.query, query_params)
.await
.map_err(SqlOverHttpError::Postgres)?;
let mut row_stream = std::pin::pin!(
client
.query_raw_txt(&data.query, query_params)
.await
.map_err(SqlOverHttpError::Postgres)?
);
let query_acknowledged = Instant::now();
let columns_len = row_stream.statement.columns().len();
let mut fields = Vec::with_capacity(columns_len);
let mut types = Vec::with_capacity(columns_len);
for c in row_stream.statement.columns() {
fields.push(json!({
"name": c.name().to_owned(),
"dataTypeID": c.type_().oid(),
"tableID": c.table_oid(),
"columnID": c.column_id(),
"dataTypeSize": c.type_size(),
"dataTypeModifier": c.type_modifier(),
"format": "text",
}));
types.push(c.type_().clone());
}
let raw_output = parsed_headers.raw_output;
let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode);
// Manually drain the stream into a vector to leave row_stream hanging
// around to get a command tag. Also check that the response is not too
// big.
let mut rows = Vec::new();
let mut rows: Vec<postgres_client::Row> = Vec::new();
while let Some(row) = row_stream.next().await {
let row = row.map_err(SqlOverHttpError::Postgres)?;
*current_size += row.body_len();
rows.push(row);
// we don't have a streaming response support yet so this is to prevent OOM
// from a malicious query (eg a cross join)
if *current_size > config.max_response_size_bytes {
@@ -1136,26 +1115,13 @@ async fn query_to_json<T: GenericClient>(
config.max_response_size_bytes,
));
}
let row = pg_text_row_to_json(&row, &types, raw_output, array_mode)?;
rows.push(row);
// assumption: parsing pg text and converting to json takes CPU time.
// let's assume it is slightly expensive, so we should consume some cooperative budget.
// Especially considering that `RowStream::next` might be pulling from a batch
// of rows and never hit the tokio mpsc for a long time (although unlikely).
tokio::task::consume_budget().await;
}
let query_resp_end = Instant::now();
let RowStream {
command_tag,
status: ready,
..
} = row_stream;
let ready = row_stream.ready_status();
// grab the command tag and number of rows affected
let command_tag = command_tag.unwrap_or_default();
let command_tag = row_stream.command_tag().unwrap_or_default();
let mut command_tag_split = command_tag.split(' ');
let command_tag_name = command_tag_split.next().unwrap_or_default();
let command_tag_count = if command_tag_name == "INSERT" {
@@ -1176,6 +1142,38 @@ async fn query_to_json<T: GenericClient>(
"finished executing query"
);
let columns_len = row_stream.columns().len();
let mut fields = Vec::with_capacity(columns_len);
let mut columns = Vec::with_capacity(columns_len);
for c in row_stream.columns() {
fields.push(json!({
"name": c.name().to_owned(),
"dataTypeID": c.type_().oid(),
"tableID": c.table_oid(),
"columnID": c.column_id(),
"dataTypeSize": c.type_size(),
"dataTypeModifier": c.type_modifier(),
"format": "text",
}));
match client.get_type(c.type_oid()).await {
Ok(t) => columns.push(t),
Err(err) => {
tracing::warn!(?err, "unable to query type information");
return Err(SqlOverHttpError::InternalPostgres(err));
}
}
}
let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode);
// convert rows to JSON
let rows = rows
.iter()
.map(|row| pg_text_row_to_json(row, &columns, parsed_headers.raw_output, array_mode))
.collect::<Result<Vec<_>, _>>()?;
// Resulting JSON format is based on the format of node-postgres result.
let results = json!({
"command": command_tag_name.to_string(),

View File

@@ -43,12 +43,6 @@ impl std::ops::Deref for ApiUrl {
}
}
impl std::ops::DerefMut for ApiUrl {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl std::fmt::Display for ApiUrl {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)

View File

@@ -44,7 +44,6 @@ struct GlobalTimelinesState {
// on-demand timeline creation from recreating deleted timelines. This is only soft-enforced, as
// this map is dropped on restart.
tombstones: HashMap<TenantTimelineId, Instant>,
tenant_tombstones: HashMap<TenantId, Instant>,
conf: Arc<SafeKeeperConf>,
broker_active_set: Arc<TimelinesSet>,
@@ -82,25 +81,10 @@ impl GlobalTimelinesState {
}
}
fn has_tombstone(&self, ttid: &TenantTimelineId) -> bool {
self.tombstones.contains_key(ttid) || self.tenant_tombstones.contains_key(&ttid.tenant_id)
}
/// Removes all blocking tombstones for the given timeline ID.
/// Returns `true` if there have been actual changes.
fn remove_tombstone(&mut self, ttid: &TenantTimelineId) -> bool {
self.tombstones.remove(ttid).is_some()
|| self.tenant_tombstones.remove(&ttid.tenant_id).is_some()
}
fn delete(&mut self, ttid: TenantTimelineId) {
self.timelines.remove(&ttid);
self.tombstones.insert(ttid, Instant::now());
}
fn add_tenant_tombstone(&mut self, tenant_id: TenantId) {
self.tenant_tombstones.insert(tenant_id, Instant::now());
}
}
/// A struct used to manage access to the global timelines map.
@@ -115,7 +99,6 @@ impl GlobalTimelines {
state: Mutex::new(GlobalTimelinesState {
timelines: HashMap::new(),
tombstones: HashMap::new(),
tenant_tombstones: HashMap::new(),
conf,
broker_active_set: Arc::new(TimelinesSet::default()),
global_rate_limiter: RateLimiter::new(1, 1),
@@ -262,7 +245,7 @@ impl GlobalTimelines {
return Ok(timeline);
}
if state.has_tombstone(&ttid) {
if state.tombstones.contains_key(&ttid) {
anyhow::bail!("Timeline {ttid} is deleted, refusing to recreate");
}
@@ -312,14 +295,13 @@ impl GlobalTimelines {
_ => {}
}
if check_tombstone {
if state.has_tombstone(&ttid) {
if state.tombstones.contains_key(&ttid) {
anyhow::bail!("timeline {ttid} is deleted, refusing to recreate");
}
} else {
// We may be have been asked to load a timeline that was previously deleted (e.g. from `pull_timeline.rs`). We trust
// that the human doing this manual intervention knows what they are doing, and remove its tombstone.
// It's also possible that we enter this when the tenant has been deleted, even if the timeline itself has never existed.
if state.remove_tombstone(&ttid) {
if state.tombstones.remove(&ttid).is_some() {
warn!("un-deleted timeline {ttid}");
}
}
@@ -500,7 +482,6 @@ impl GlobalTimelines {
let tli_res = {
let state = self.state.lock().unwrap();
// Do NOT check tenant tombstones here: those were set earlier
if state.tombstones.contains_key(ttid) {
// Presence of a tombstone guarantees that a previous deletion has completed and there is no work to do.
info!("Timeline {ttid} was already deleted");
@@ -576,10 +557,6 @@ impl GlobalTimelines {
action: DeleteOrExclude,
) -> Result<HashMap<TenantTimelineId, TimelineDeleteResult>> {
info!("deleting all timelines for tenant {}", tenant_id);
// Adding a tombstone before getting the timelines to prevent new timeline additions
self.state.lock().unwrap().add_tenant_tombstone(*tenant_id);
let to_delete = self.get_all_for_tenant(*tenant_id);
let mut err = None;
@@ -623,9 +600,6 @@ impl GlobalTimelines {
state
.tombstones
.retain(|_, v| now.duration_since(*v) < *tombstone_ttl);
state
.tenant_tombstones
.retain(|_, v| now.duration_since(*v) < *tombstone_ttl);
}
}

View File

@@ -87,7 +87,6 @@ impl WalProposer {
let config = Config {
ttid,
safekeepers_list: addrs,
safekeeper_conninfo_options: String::new(),
safekeeper_reconnect_timeout: 1000,
safekeeper_connection_timeout: 5000,
sync_safekeepers,

View File

@@ -32,6 +32,12 @@ BENCHMARKS_DURATION_QUERY = """
# the total duration varies from 8 to 40 minutes.
# We use some pre-collected durations as a fallback to have a better distribution.
FALLBACK_DURATION = {
"test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py::test_pageserver_max_throughput_getpage_at_latest_lsn[1-13-30]": 400.15,
"test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py::test_pageserver_max_throughput_getpage_at_latest_lsn[1-6-30]": 372.521,
"test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py::test_pageserver_max_throughput_getpage_at_latest_lsn[10-13-30]": 420.017,
"test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py::test_pageserver_max_throughput_getpage_at_latest_lsn[10-6-30]": 373.769,
"test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py::test_pageserver_max_throughput_getpage_at_latest_lsn[100-13-30]": 678.742,
"test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py::test_pageserver_max_throughput_getpage_at_latest_lsn[100-6-30]": 512.135,
"test_runner/performance/test_branch_creation.py::test_branch_creation_heavy_write[20]": 58.036,
"test_runner/performance/test_branch_creation.py::test_branch_creation_many_relations": 22.104,
"test_runner/performance/test_branch_creation.py::test_branch_creation_many[1024]": 126.073,

View File

@@ -17,14 +17,12 @@ use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use camino::Utf8PathBuf;
use clap::{Parser, command};
use futures::future::OptionFuture;
use futures_core::Stream;
use futures_util::StreamExt;
use http_body_util::combinators::BoxBody;
use http_body_util::{Empty, Full};
use http_body_util::Full;
use http_utils::tls_certs::ReloadingCertificateResolver;
use hyper::body::Incoming;
use hyper::header::CONTENT_TYPE;
@@ -48,6 +46,7 @@ use tokio::net::TcpListener;
use tokio::sync::broadcast;
use tokio::sync::broadcast::error::RecvError;
use tokio::time;
use tonic::body::{self, BoxBody, empty_body};
use tonic::codegen::Service;
use tonic::{Code, Request, Response, Status};
use tracing::*;
@@ -635,7 +634,7 @@ impl BrokerService for Broker {
// We serve only metrics and healthcheck through http1.
async fn http1_handler(
req: hyper::Request<Incoming>,
) -> Result<hyper::Response<BoxBody<Bytes, Infallible>>, Infallible> {
) -> Result<hyper::Response<BoxBody>, Infallible> {
let resp = match (req.method(), req.uri().path()) {
(&Method::GET, "/metrics") => {
let mut buffer = vec![];
@@ -646,16 +645,16 @@ async fn http1_handler(
hyper::Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, encoder.format_type())
.body(BoxBody::new(Full::new(Bytes::from(buffer))))
.body(body::boxed(Full::new(bytes::Bytes::from(buffer))))
.unwrap()
}
(&Method::GET, "/status") => hyper::Response::builder()
.status(StatusCode::OK)
.body(BoxBody::new(Empty::new()))
.body(empty_body())
.unwrap(),
_ => hyper::Response::builder()
.status(StatusCode::NOT_FOUND)
.body(BoxBody::new(Empty::new()))
.body(empty_body())
.unwrap(),
};
Ok(resp)

View File

@@ -3823,13 +3823,6 @@ impl Service {
.await;
failpoint_support::sleep_millis_async!("tenant-create-timeline-shared-lock");
let is_import = create_req.is_import();
let read_only = matches!(
create_req.mode,
models::TimelineCreateRequestMode::Branch {
read_only: true,
..
}
);
if is_import {
// Ensure that there is no split on-going.
@@ -3902,13 +3895,13 @@ impl Service {
}
None
} else if safekeepers || read_only {
} else if safekeepers {
// Note that for imported timelines, we do not create the timeline on the safekeepers
// straight away. Instead, we do it once the import finalized such that we know what
// start LSN to provide for the safekeepers. This is done in
// [`Self::finalize_timeline_import`].
let res = self
.tenant_timeline_create_safekeepers(tenant_id, &timeline_info, read_only)
.tenant_timeline_create_safekeepers(tenant_id, &timeline_info)
.instrument(tracing::info_span!("timeline_create_safekeepers", %tenant_id, timeline_id=%timeline_info.timeline_id))
.await?;
Some(res)
@@ -3922,11 +3915,6 @@ impl Service {
})
}
#[instrument(skip_all, fields(
tenant_id=%req.tenant_shard_id.tenant_id,
shard_id=%req.tenant_shard_id.shard_slug(),
timeline_id=%req.timeline_id,
))]
pub(crate) async fn handle_timeline_shard_import_progress(
self: &Arc<Self>,
req: TimelineImportStatusRequest,
@@ -3976,11 +3964,6 @@ impl Service {
})
}
#[instrument(skip_all, fields(
tenant_id=%req.tenant_shard_id.tenant_id,
shard_id=%req.tenant_shard_id.shard_slug(),
timeline_id=%req.timeline_id,
))]
pub(crate) async fn handle_timeline_shard_import_progress_upcall(
self: &Arc<Self>,
req: PutTimelineImportStatusRequest,

View File

@@ -208,7 +208,6 @@ impl Service {
self: &Arc<Self>,
tenant_id: TenantId,
timeline_info: &TimelineInfo,
read_only: bool,
) -> Result<SafekeepersInfo, ApiError> {
let timeline_id = timeline_info.timeline_id;
let pg_version = timeline_info.pg_version * 10000;
@@ -221,11 +220,7 @@ impl Service {
let start_lsn = timeline_info.last_record_lsn;
// Choose initial set of safekeepers respecting affinity
let sks = if !read_only {
self.safekeepers_for_new_timeline().await?
} else {
Vec::new()
};
let sks = self.safekeepers_for_new_timeline().await?;
let sks_persistence = sks.iter().map(|sk| sk.id.0 as i64).collect::<Vec<_>>();
// Add timeline to db
let mut timeline_persist = TimelinePersistence {
@@ -258,16 +253,6 @@ impl Service {
)));
}
}
let ret = SafekeepersInfo {
generation: timeline_persist.generation as u32,
safekeepers: sks.clone(),
tenant_id,
timeline_id,
};
if read_only {
return Ok(ret);
}
// Create the timeline on a quorum of safekeepers
let remaining = self
.tenant_timeline_create_safekeepers_quorum(
@@ -331,7 +316,12 @@ impl Service {
}
}
Ok(ret)
Ok(SafekeepersInfo {
generation: timeline_persist.generation as u32,
safekeepers: sks,
tenant_id,
timeline_id,
})
}
pub(crate) async fn tenant_timeline_create_safekeepers_until_success(
@@ -346,10 +336,8 @@ impl Service {
return Err(TimelineImportFinalizeError::ShuttingDown);
}
// This function is only used in non-read-only scenarios
let read_only = false;
let res = self
.tenant_timeline_create_safekeepers(tenant_id, &timeline_info, read_only)
.tenant_timeline_create_safekeepers(tenant_id, &timeline_info)
.await;
match res {
@@ -422,18 +410,6 @@ impl Service {
.chain(tl.sk_set.iter())
.collect::<HashSet<_>>();
// The timeline has no safekeepers: we need to delete it from the db manually,
// as no safekeeper reconciler will get to it
if all_sks.is_empty() {
if let Err(err) = self
.persistence
.delete_timeline(tenant_id, timeline_id)
.await
{
tracing::warn!(%tenant_id, %timeline_id, "couldn't delete timeline from db: {err}");
}
}
// Schedule reconciliations
for &sk_id in all_sks.iter() {
let pending_op = TimelinePendingOpPersistence {

View File

@@ -184,7 +184,6 @@ PAGESERVER_PER_TENANT_METRICS: tuple[str, ...] = (
"pageserver_evictions_with_low_residence_duration_total",
"pageserver_aux_file_estimated_size",
"pageserver_valid_lsn_lease_count",
"pageserver_tenant_offloaded_timelines",
counter("pageserver_tenant_throttling_count_accounted_start"),
counter("pageserver_tenant_throttling_count_accounted_finish"),
counter("pageserver_tenant_throttling_wait_usecs_sum"),

View File

@@ -404,29 +404,6 @@ class PageserverTracingConfig:
return ("tracing", value)
@dataclass
class PageserverImportConfig:
import_job_concurrency: int
import_job_soft_size_limit: int
import_job_checkpoint_threshold: int
@staticmethod
def default() -> PageserverImportConfig:
return PageserverImportConfig(
import_job_concurrency=4,
import_job_soft_size_limit=512 * 1024,
import_job_checkpoint_threshold=4,
)
def to_config_key_value(self) -> tuple[str, dict[str, Any]]:
value = {
"import_job_concurrency": self.import_job_concurrency,
"import_job_soft_size_limit": self.import_job_soft_size_limit,
"import_job_checkpoint_threshold": self.import_job_checkpoint_threshold,
}
return ("timeline_import_config", value)
class NeonEnvBuilder:
"""
Builder object to create a Neon runtime environment
@@ -477,7 +454,6 @@ class NeonEnvBuilder:
pageserver_wal_receiver_protocol: PageserverWalReceiverProtocol | None = None,
pageserver_get_vectored_concurrent_io: str | None = None,
pageserver_tracing_config: PageserverTracingConfig | None = None,
pageserver_import_config: PageserverImportConfig | None = None,
):
self.repo_dir = repo_dir
self.rust_log_override = rust_log_override
@@ -535,7 +511,6 @@ class NeonEnvBuilder:
)
self.pageserver_tracing_config = pageserver_tracing_config
self.pageserver_import_config = pageserver_import_config
self.pageserver_default_tenant_config_compaction_algorithm: dict[str, Any] | None = (
pageserver_default_tenant_config_compaction_algorithm
@@ -707,7 +682,7 @@ class NeonEnvBuilder:
log.info(
f"Copying pageserver tenants directory {tenants_from_dir} to {tenants_to_dir}"
)
subprocess.run(["cp", "-a", tenants_from_dir, tenants_to_dir], check=True)
shutil.copytree(tenants_from_dir, tenants_to_dir)
else:
log.info(
f"Creating overlayfs mount of pageserver tenants directory {tenants_from_dir} to {tenants_to_dir}"
@@ -723,9 +698,8 @@ class NeonEnvBuilder:
shutil.rmtree(self.repo_dir / "local_fs_remote_storage", ignore_errors=True)
if self.test_overlay_dir is None:
log.info("Copying local_fs_remote_storage directory from snapshot")
subprocess.run(
["cp", "-a", f"{repo_dir / 'local_fs_remote_storage'}", f"{self.repo_dir}"],
check=True,
shutil.copytree(
repo_dir / "local_fs_remote_storage", self.repo_dir / "local_fs_remote_storage"
)
else:
log.info("Creating overlayfs mount of local_fs_remote_storage directory from snapshot")
@@ -1204,10 +1178,6 @@ class NeonEnv:
self.pageserver_wal_receiver_protocol = config.pageserver_wal_receiver_protocol
self.pageserver_get_vectored_concurrent_io = config.pageserver_get_vectored_concurrent_io
self.pageserver_tracing_config = config.pageserver_tracing_config
if config.pageserver_import_config is None:
self.pageserver_import_config = PageserverImportConfig.default()
else:
self.pageserver_import_config = config.pageserver_import_config
# Create the neon_local's `NeonLocalInitConf`
cfg: dict[str, Any] = {
@@ -1253,7 +1223,6 @@ class NeonEnv:
# Create config for pageserver
http_auth_type = "NeonJWT" if config.auth_enabled else "Trust"
pg_auth_type = "NeonJWT" if config.auth_enabled else "Trust"
grpc_auth_type = "NeonJWT" if config.auth_enabled else "Trust"
for ps_id in range(
self.BASE_PAGESERVER_ID, self.BASE_PAGESERVER_ID + config.num_pageservers
):
@@ -1280,13 +1249,18 @@ class NeonEnv:
else None,
"pg_auth_type": pg_auth_type,
"http_auth_type": http_auth_type,
"grpc_auth_type": grpc_auth_type,
"availability_zone": availability_zone,
# Disable pageserver disk syncs in tests: when running tests concurrently, this avoids
# the pageserver taking a long time to start up due to syncfs flushing other tests' data
"no_sync": True,
# Look for gaps in WAL received from safekeepeers
"validate_wal_contiguity": True,
# TODO(vlad): make these configurable through the builder
"timeline_import_config": {
"import_job_concurrency": 4,
"import_job_soft_size_limit": 512 * 1024,
"import_job_checkpoint_threshold": 4,
},
}
# Batching (https://github.com/neondatabase/neon/issues/9377):
@@ -1348,12 +1322,6 @@ class NeonEnv:
ps_cfg[key] = value
if self.pageserver_import_config is not None:
key, value = self.pageserver_import_config.to_config_key_value()
if key not in ps_cfg:
ps_cfg[key] = value
# Create a corresponding NeonPageserver object
ps = NeonPageserver(
self, ps_id, port=pageserver_port, az_id=ps_cfg["availability_zone"]

View File

@@ -14,7 +14,7 @@ from fixtures.neon_fixtures import (
PgBin,
wait_for_last_flush_lsn,
)
from fixtures.utils import get_scale_for_db, humantime_to_ms
from fixtures.utils import get_scale_for_db, humantime_to_ms, skip_on_ci
from performance.pageserver.util import setup_pageserver_with_tenants
@@ -36,6 +36,9 @@ if TYPE_CHECKING:
@pytest.mark.parametrize("pgbench_scale", [get_scale_for_db(200)])
@pytest.mark.parametrize("n_tenants", [500])
@pytest.mark.timeout(10000)
@skip_on_ci(
"This test needs lot of resources and should run on dedicated HW, not in github action runners as part of CI"
)
def test_pageserver_characterize_throughput_with_n_tenants(
neon_env_builder: NeonEnvBuilder,
zenbenchmark: NeonBenchmarker,
@@ -60,6 +63,9 @@ def test_pageserver_characterize_throughput_with_n_tenants(
@pytest.mark.parametrize("n_clients", [1, 64])
@pytest.mark.parametrize("n_tenants", [1])
@pytest.mark.timeout(2400)
@skip_on_ci(
"This test needs lot of resources and should run on dedicated HW, not in github action runners as part of CI"
)
def test_pageserver_characterize_latencies_with_1_client_and_throughput_with_many_clients_one_tenant(
neon_env_builder: NeonEnvBuilder,
zenbenchmark: NeonBenchmarker,

View File

@@ -1,77 +0,0 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from fixtures.utils import wait_until
if TYPE_CHECKING:
from fixtures.neon_fixtures import NeonEnvBuilder
def test_basebackup_cache(neon_env_builder: NeonEnvBuilder):
"""
Simple test for basebackup cache.
1. Check that we always hit the cache after compute restart.
2. Check that we eventually delete old basebackup files, but not the latest one.
3. Check that we delete basebackup file for timeline with active compute.
"""
neon_env_builder.pageserver_config_override = """
tenant_config = { basebackup_cache_enabled = true }
basebackup_cache_config = { cleanup_period = '1s' }
"""
env = neon_env_builder.init_start()
ep = env.endpoints.create("main")
ps = env.pageserver
ps_http = ps.http_client()
# 1. Check that we always hit the cache after compute restart.
for i in range(3):
ep.start()
ep.stop()
def check_metrics(i=i):
metrics = ps_http.get_metrics()
# Never miss.
# The first time compute_ctl sends `get_basebackup` with lsn=None, we do not cache such requests.
# All other requests should be a hit
assert (
metrics.query_one(
"pageserver_basebackup_cache_read_total", {"result": "miss"}
).value
== 0
)
# All but the first requests are hits.
assert (
metrics.query_one("pageserver_basebackup_cache_read_total", {"result": "hit"}).value
== i
)
# Every compute shut down should trigger a prepare reuest.
assert (
metrics.query_one(
"pageserver_basebackup_cache_prepare_total", {"result": "ok"}
).value
== i + 1
)
wait_until(check_metrics)
# 2. Check that we eventually delete old basebackup files, but not the latest one.
def check_bb_file_count():
bb_files = list(ps.workdir.joinpath("basebackup_cache").iterdir())
# tmp dir + 1 basebackup file.
assert len(bb_files) == 2
wait_until(check_bb_file_count)
# 3. Check that we delete basebackup file for timeline with active compute.
ep.start()
ep.safe_psql("create table t1 as select generate_series(1, 10) as n")
def check_bb_dir_empty():
bb_files = list(ps.workdir.joinpath("basebackup_cache").iterdir())
# only tmp dir.
assert len(bb_files) == 1
wait_until(check_bb_dir_empty)

Some files were not shown because too many files have changed in this diff Show More