diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index a887db2ab1..9f2fa3d52c 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -314,7 +314,8 @@ jobs: test_selection: performance run_in_parallel: false save_perf_report: ${{ github.ref_name == 'main' }} - extra_params: --splits 5 --group ${{ matrix.pytest_split_group }} + # test_pageserver_max_throughput_getpage_at_latest_lsn is run in separate workflow periodic_pagebench.yml because it needs snapshots + extra_params: --splits 5 --group ${{ matrix.pytest_split_group }} --ignore=test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py benchmark_durations: ${{ needs.get-benchmarks-durations.outputs.json }} pg_version: v16 aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} diff --git a/.github/workflows/periodic_pagebench.yml b/.github/workflows/periodic_pagebench.yml index 532da435c2..317db94052 100644 --- a/.github/workflows/periodic_pagebench.yml +++ b/.github/workflows/periodic_pagebench.yml @@ -1,4 +1,4 @@ -name: Periodic pagebench performance test on dedicated EC2 machine in eu-central-1 region +name: Periodic pagebench performance test on unit-perf hetzner runner on: schedule: @@ -8,7 +8,7 @@ on: # │ │ ┌───────────── day of the month (1 - 31) # │ │ │ ┌───────────── month (1 - 12 or JAN-DEC) # │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT) - - cron: '0 */3 * * *' # Runs every 3 hours + - cron: '0 */4 * * *' # Runs every 4 hours workflow_dispatch: # Allows manual triggering of the workflow inputs: commit_hash: @@ -16,6 +16,11 @@ on: description: 'The long neon repo commit hash for the system under test (pageserver) to be tested.' required: false default: '' + recreate_snapshots: + type: boolean + description: 'Recreate snapshots - !!!WARNING!!! We should only recreate snapshots if the previous ones are no longer compatible. Otherwise benchmarking results are not comparable across runs.' + required: false + default: false defaults: run: @@ -29,13 +34,13 @@ permissions: contents: read jobs: - trigger_bench_on_ec2_machine_in_eu_central_1: + run_periodic_pagebench_test: permissions: id-token: write # aws-actions/configure-aws-credentials statuses: write contents: write pull-requests: write - runs-on: [ self-hosted, small ] + runs-on: [ self-hosted, unit-perf ] container: image: ghcr.io/neondatabase/build-tools:pinned-bookworm credentials: @@ -44,10 +49,13 @@ jobs: options: --init timeout-minutes: 360 # Set the timeout to 6 hours env: - API_KEY: ${{ secrets.PERIODIC_PAGEBENCH_EC2_RUNNER_API_KEY }} RUN_ID: ${{ github.run_id }} - AWS_DEFAULT_REGION : "eu-central-1" - AWS_INSTANCE_ID : "i-02a59a3bf86bc7e74" + DEFAULT_PG_VERSION: 16 + BUILD_TYPE: release + RUST_BACKTRACE: 1 + # NEON_ENV_BUILDER_USE_OVERLAYFS_FOR_SNAPSHOTS: 1 - doesn't work without root in container + S3_BUCKET: neon-github-public-dev + PERF_TEST_RESULT_CONNSTR: "${{ secrets.PERF_TEST_RESULT_CONNSTR }}" steps: # we don't need the neon source code because we run everything remotely # however we still need the local github actions to run the allure step below @@ -56,99 +64,194 @@ jobs: with: egress-policy: audit - - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set up the environment which depends on $RUNNER_TEMP on nvme drive + id: set-env + shell: bash -euxo pipefail {0} + run: | + { + echo "NEON_DIR=${RUNNER_TEMP}/neon" + echo "NEON_BIN=${RUNNER_TEMP}/neon/bin" + echo "POSTGRES_DISTRIB_DIR=${RUNNER_TEMP}/neon/pg_install" + echo "LD_LIBRARY_PATH=${RUNNER_TEMP}/neon/pg_install/v${DEFAULT_PG_VERSION}/lib" + echo "BACKUP_DIR=${RUNNER_TEMP}/instance_store/saved_snapshots" + echo "TEST_OUTPUT=${RUNNER_TEMP}/neon/test_output" + echo "PERF_REPORT_DIR=${RUNNER_TEMP}/neon/test_output/perf-report-local" + echo "ALLURE_DIR=${RUNNER_TEMP}/neon/test_output/allure-results" + echo "ALLURE_RESULTS_DIR=${RUNNER_TEMP}/neon/test_output/allure-results/results" + } >> "$GITHUB_ENV" - - name: Show my own (github runner) external IP address - usefull for IP allowlisting - run: curl https://ifconfig.me + echo "allure_results_dir=${RUNNER_TEMP}/neon/test_output/allure-results/results" >> "$GITHUB_OUTPUT" - - name: Assume AWS OIDC role that allows to manage (start/stop/describe... EC machine) - uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2 + - uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2 with: aws-region: eu-central-1 - role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_MANAGE_BENCHMARK_EC2_VMS_ARN }} - role-duration-seconds: 3600 - - - name: Start EC2 instance and wait for the instance to boot up - run: | - aws ec2 start-instances --instance-ids $AWS_INSTANCE_ID - aws ec2 wait instance-running --instance-ids $AWS_INSTANCE_ID - sleep 60 # sleep some time to allow cloudinit and our API server to start up - - - name: Determine public IP of the EC2 instance and set env variable EC2_MACHINE_URL_US - run: | - public_ip=$(aws ec2 describe-instances --instance-ids $AWS_INSTANCE_ID --query 'Reservations[*].Instances[*].PublicIpAddress' --output text) - echo "Public IP of the EC2 instance: $public_ip" - echo "EC2_MACHINE_URL_US=https://${public_ip}:8443" >> $GITHUB_ENV - + role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} + role-duration-seconds: 18000 # max 5 hours (needed in case commit hash is still being built) - name: Determine commit hash + id: commit_hash + shell: bash -euxo pipefail {0} env: INPUT_COMMIT_HASH: ${{ github.event.inputs.commit_hash }} run: | - if [ -z "$INPUT_COMMIT_HASH" ]; then - echo "COMMIT_HASH=$(curl -s https://api.github.com/repos/neondatabase/neon/commits/main | jq -r '.sha')" >> $GITHUB_ENV + if [[ -z "${INPUT_COMMIT_HASH}" ]]; then + COMMIT_HASH=$(curl -s https://api.github.com/repos/neondatabase/neon/commits/main | jq -r '.sha') + echo "COMMIT_HASH=$COMMIT_HASH" >> $GITHUB_ENV + echo "commit_hash=$COMMIT_HASH" >> "$GITHUB_OUTPUT" echo "COMMIT_HASH_TYPE=latest" >> $GITHUB_ENV else - echo "COMMIT_HASH=$INPUT_COMMIT_HASH" >> $GITHUB_ENV + COMMIT_HASH="${INPUT_COMMIT_HASH}" + echo "COMMIT_HASH=$COMMIT_HASH" >> $GITHUB_ENV + echo "commit_hash=$COMMIT_HASH" >> "$GITHUB_OUTPUT" echo "COMMIT_HASH_TYPE=manual" >> $GITHUB_ENV fi + - name: Checkout the neon repository at given commit hash + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + ref: ${{ steps.commit_hash.outputs.commit_hash }} - - name: Start Bench with run_id + # does not reuse ./.github/actions/download because we need to download the artifact for the given commit hash + # example artifact + # s3://neon-github-public-dev/artifacts/48b870bc078bd2c450eb7b468e743b9c118549bf/15036827400/1/neon-Linux-X64-release-artifact.tar.zst /instance_store/artifacts/neon-Linux-release-artifact.tar.zst + - name: Determine artifact S3_KEY for given commit hash and download and extract artifact + id: artifact_prefix + shell: bash -euxo pipefail {0} + env: + ARCHIVE: ${{ runner.temp }}/downloads/neon-${{ runner.os }}-${{ runner.arch }}-release-artifact.tar.zst + COMMIT_HASH: ${{ env.COMMIT_HASH }} + COMMIT_HASH_TYPE: ${{ env.COMMIT_HASH_TYPE }} run: | - curl -k -X 'POST' \ - "${EC2_MACHINE_URL_US}/start_test/${GITHUB_RUN_ID}" \ - -H 'accept: application/json' \ - -H 'Content-Type: application/json' \ - -H "Authorization: Bearer $API_KEY" \ - -d "{\"neonRepoCommitHash\": \"${COMMIT_HASH}\", \"neonRepoCommitHashType\": \"${COMMIT_HASH_TYPE}\"}" + attempt=0 + max_attempts=24 # 5 minutes * 24 = 2 hours - - name: Poll Test Status - id: poll_step - run: | - status="" - while [[ "$status" != "failure" && "$status" != "success" ]]; do - response=$(curl -k -X 'GET' \ - "${EC2_MACHINE_URL_US}/test_status/${GITHUB_RUN_ID}" \ - -H 'accept: application/json' \ - -H "Authorization: Bearer $API_KEY") - echo "Response: $response" - set +x - status=$(echo $response | jq -r '.status') - echo "Test status: $status" - if [[ "$status" == "failure" ]]; then - echo "Test failed" - exit 1 # Fail the job step if status is failure - elif [[ "$status" == "success" || "$status" == "null" ]]; then + while [[ $attempt -lt $max_attempts ]]; do + # the following command will fail until the artifacts are available ... + S3_KEY=$(aws s3api list-objects-v2 --bucket "$S3_BUCKET" --prefix "artifacts/$COMMIT_HASH/" \ + | jq -r '.Contents[]?.Key' \ + | grep "neon-${{ runner.os }}-${{ runner.arch }}-release-artifact.tar.zst" \ + | sort --version-sort \ + | tail -1) || true # ... thus ignore errors from the command + if [[ -n "${S3_KEY}" ]]; then + echo "Artifact found: $S3_KEY" + echo "S3_KEY=$S3_KEY" >> $GITHUB_ENV break - elif [[ "$status" == "too_many_runs" ]]; then - echo "Too many runs already running" - echo "too_many_runs=true" >> "$GITHUB_OUTPUT" - exit 1 fi - - sleep 60 # Poll every 60 seconds + + # Increment attempt counter and sleep for 5 minutes + attempt=$((attempt + 1)) + echo "Attempt $attempt of $max_attempts to find artifacts in S3 bucket s3://$S3_BUCKET/artifacts/$COMMIT_HASH failed. Retrying in 5 minutes..." + sleep 300 # Sleep for 5 minutes done - - name: Retrieve Test Logs - if: always() && steps.poll_step.outputs.too_many_runs != 'true' - run: | - curl -k -X 'GET' \ - "${EC2_MACHINE_URL_US}/test_log/${GITHUB_RUN_ID}" \ - -H 'accept: application/gzip' \ - -H "Authorization: Bearer $API_KEY" \ - --output "test_log_${GITHUB_RUN_ID}.gz" + if [[ -z "${S3_KEY}" ]]; then + echo "Error: artifact not found in S3 bucket s3://$S3_BUCKET/artifacts/$COMMIT_HASH" after 2 hours + else + mkdir -p $(dirname $ARCHIVE) + time aws s3 cp --only-show-errors s3://$S3_BUCKET/${S3_KEY} ${ARCHIVE} + mkdir -p ${NEON_DIR} + time tar -xf ${ARCHIVE} -C ${NEON_DIR} + rm -f ${ARCHIVE} + fi - - name: Unzip Test Log and Print it into this job's log - if: always() && steps.poll_step.outputs.too_many_runs != 'true' + - name: Download snapshots from S3 + if: ${{ github.event_name != 'workflow_dispatch' || github.event.inputs.recreate_snapshots == 'false' || github.event.inputs.recreate_snapshots == '' }} + id: download_snapshots + shell: bash -euxo pipefail {0} run: | - gzip -d "test_log_${GITHUB_RUN_ID}.gz" - cat "test_log_${GITHUB_RUN_ID}" + # Download the snapshots from S3 + mkdir -p ${TEST_OUTPUT} + mkdir -p $BACKUP_DIR + cd $BACKUP_DIR + mkdir parts + cd parts + PART=$(aws s3api list-objects-v2 --bucket $S3_BUCKET --prefix performance/pagebench/ \ + | jq -r '.Contents[]?.Key' \ + | grep -E 'shared-snapshots-[0-9]{4}-[0-9]{2}-[0-9]{2}' \ + | sort \ + | tail -1) + echo "Latest PART: $PART" + if [[ -z "$PART" ]]; then + echo "ERROR: No matching S3 key found" >&2 + exit 1 + fi + S3_KEY=$(dirname $PART) + time aws s3 cp --only-show-errors --recursive s3://${S3_BUCKET}/$S3_KEY/ . + cd $TEST_OUTPUT + time cat $BACKUP_DIR/parts/* | zstdcat | tar --extract --preserve-permissions + rm -rf ${BACKUP_DIR} + + - name: Cache poetry deps + uses: actions/cache@v4 + with: + path: ~/.cache/pypoetry/virtualenvs + key: v2-${{ runner.os }}-${{ runner.arch }}-python-deps-bookworm-${{ hashFiles('poetry.lock') }} + + - name: Install Python deps + shell: bash -euxo pipefail {0} + run: ./scripts/pysync + + # we need high number of open files for pagebench + - name: show ulimits + shell: bash -euxo pipefail {0} + run: | + ulimit -a + + - name: Run pagebench testcase + shell: bash -euxo pipefail {0} + env: + CI: false # need to override this env variable set by github to enforce using snapshots + run: | + export PLATFORM=hetzner-unit-perf-${COMMIT_HASH_TYPE} + # report the commit hash of the neon repository in the revision of the test results + export GITHUB_SHA=${COMMIT_HASH} + rm -rf ${PERF_REPORT_DIR} + rm -rf ${ALLURE_RESULTS_DIR} + mkdir -p ${PERF_REPORT_DIR} + mkdir -p ${ALLURE_RESULTS_DIR} + PARAMS="--alluredir=${ALLURE_RESULTS_DIR} --tb=short --verbose -rA" + EXTRA_PARAMS="--out-dir ${PERF_REPORT_DIR} --durations-path $TEST_OUTPUT/benchmark_durations.json" + # run only two selected tests + # environment set by parent: + # RUST_BACKTRACE=1 DEFAULT_PG_VERSION=16 BUILD_TYPE=release + ./scripts/pytest ${PARAMS} test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py::test_pageserver_characterize_throughput_with_n_tenants ${EXTRA_PARAMS} + ./scripts/pytest ${PARAMS} test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py::test_pageserver_characterize_latencies_with_1_client_and_throughput_with_many_clients_one_tenant ${EXTRA_PARAMS} + + - name: upload the performance metrics to the Neon performance database which is used by grafana dashboards to display the results + shell: bash -euxo pipefail {0} + run: | + export REPORT_FROM="$PERF_REPORT_DIR" + export GITHUB_SHA=${COMMIT_HASH} + time ./scripts/generate_and_push_perf_report.sh + + - name: Upload test results + if: ${{ !cancelled() }} + uses: ./.github/actions/allure-report-store + with: + report-dir: ${{ steps.set-env.outputs.allure_results_dir }} + unique-key: ${{ env.BUILD_TYPE }}-${{ env.DEFAULT_PG_VERSION }}-${{ runner.arch }} + aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} - name: Create Allure report + id: create-allure-report if: ${{ !cancelled() }} uses: ./.github/actions/allure-report-generate with: aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }} + - name: Upload snapshots + if: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.recreate_snapshots != 'false' && github.event.inputs.recreate_snapshots != '' }} + id: upload_snapshots + shell: bash -euxo pipefail {0} + run: | + mkdir -p $BACKUP_DIR + cd $TEST_OUTPUT + tar --create --preserve-permissions --file - shared-snapshots | zstd -o $BACKUP_DIR/shared_snapshots.tar.zst + cd $BACKUP_DIR + mkdir parts + split -b 1G shared_snapshots.tar.zst ./parts/shared_snapshots.tar.zst.part. + SNAPSHOT_DATE=$(date +%F) # YYYY-MM-DD + cd parts + time aws s3 cp --recursive . s3://${S3_BUCKET}/performance/pagebench/shared-snapshots-${SNAPSHOT_DATE}/ + - name: Post to a Slack channel if: ${{ github.event.schedule && failure() }} uses: slackapi/slack-github-action@fcfb566f8b0aab22203f066d80ca1d7e4b5d05b3 # v1.27.1 @@ -157,26 +260,22 @@ jobs: slack-message: "Periodic pagebench testing on dedicated hardware: ${{ job.status }}\n${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}" env: SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }} - + - name: Cleanup Test Resources if: always() + shell: bash -euxo pipefail {0} + env: + ARCHIVE: ${{ runner.temp }}/downloads/neon-${{ runner.os }}-${{ runner.arch }}-release-artifact.tar.zst run: | - curl -k -X 'POST' \ - "${EC2_MACHINE_URL_US}/cleanup_test/${GITHUB_RUN_ID}" \ - -H 'accept: application/json' \ - -H "Authorization: Bearer $API_KEY" \ - -d '' + # Cleanup the test resources + if [[ -d "${BACKUP_DIR}" ]]; then + rm -rf ${BACKUP_DIR} + fi + if [[ -d "${TEST_OUTPUT}" ]]; then + rm -rf ${TEST_OUTPUT} + fi + if [[ -d "${NEON_DIR}" ]]; then + rm -rf ${NEON_DIR} + fi + rm -rf $(dirname $ARCHIVE) - - name: Assume AWS OIDC role that allows to manage (start/stop/describe... EC machine) - if: always() && steps.poll_step.outputs.too_many_runs != 'true' - uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2 - with: - aws-region: eu-central-1 - role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_MANAGE_BENCHMARK_EC2_VMS_ARN }} - role-duration-seconds: 3600 - - - name: Stop EC2 instance and wait for the instance to be stopped - if: always() && steps.poll_step.outputs.too_many_runs != 'true' - run: | - aws ec2 stop-instances --instance-ids $AWS_INSTANCE_ID - aws ec2 wait instance-stopped --instance-ids $AWS_INSTANCE_ID diff --git a/Cargo.lock b/Cargo.lock index b52ecec128..89351432c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1276,7 +1276,7 @@ version = "0.1.0" dependencies = [ "anyhow", "chrono", - "indexmap 2.0.1", + "indexmap 2.9.0", "jsonwebtoken", "regex", "remote_storage", @@ -1308,7 +1308,7 @@ dependencies = [ "flate2", "futures", "http 1.1.0", - "indexmap 2.0.1", + "indexmap 2.9.0", "itertools 0.10.5", "jsonwebtoken", "metrics", @@ -2597,7 +2597,7 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.9", - "indexmap 2.0.1", + "indexmap 2.9.0", "slab", "tokio", "tokio-util", @@ -2616,7 +2616,7 @@ dependencies = [ "futures-sink", "futures-util", "http 1.1.0", - "indexmap 2.0.1", + "indexmap 2.9.0", "slab", "tokio", "tokio-util", @@ -2863,14 +2863,14 @@ dependencies = [ "pprof", "regex", "routerify", - "rustls 0.23.18", + "rustls 0.23.27", "rustls-pemfile 2.1.1", "serde", "serde_json", "serde_path_to_error", "thiserror 1.0.69", "tokio", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.2", "tokio-stream", "tokio-util", "tracing", @@ -3200,12 +3200,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.0.1" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad227c3af19d4914570ad36d30409928b75967c298feb9ea1969db3a610bb14e" +checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" dependencies = [ "equivalent", - "hashbrown 0.14.5", + "hashbrown 0.15.2", "serde", ] @@ -3228,7 +3228,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "232929e1d75fe899576a3d5c7416ad0d88dbfbb3c3d6aa00873a7408a50ddb88" dependencies = [ "ahash", - "indexmap 2.0.1", + "indexmap 2.9.0", "is-terminal", "itoa", "log", @@ -3251,7 +3251,7 @@ dependencies = [ "crossbeam-utils", "dashmap 6.1.0", "env_logger", - "indexmap 2.0.1", + "indexmap 2.9.0", "itoa", "log", "num-format", @@ -4112,7 +4112,7 @@ dependencies = [ "opentelemetry-http", "opentelemetry-proto", "opentelemetry_sdk", - "prost 0.13.3", + "prost 0.13.5", "reqwest", "thiserror 1.0.69", ] @@ -4125,8 +4125,8 @@ checksum = "a6e05acbfada5ec79023c85368af14abd0b307c015e9064d249b2a950ef459a6" dependencies = [ "opentelemetry", "opentelemetry_sdk", - "prost 0.13.3", - "tonic", + "prost 0.13.5", + "tonic 0.12.3", ] [[package]] @@ -4321,6 +4321,7 @@ dependencies = [ "pageserver_api", "pageserver_client", "pageserver_compaction", + "pageserver_page_api", "pem", "pin-project-lite", "postgres-protocol", @@ -4329,6 +4330,7 @@ dependencies = [ "postgres_connection", "postgres_ffi", "postgres_initdb", + "posthog_client_lite", "pprof", "pq_proto", "procfs", @@ -4339,7 +4341,7 @@ dependencies = [ "reqwest", "rpds", "rstest", - "rustls 0.23.18", + "rustls 0.23.27", "scopeguard", "send-future", "serde", @@ -4358,11 +4360,13 @@ dependencies = [ "tokio-epoll-uring", "tokio-io-timeout", "tokio-postgres", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.2", "tokio-stream", "tokio-tar", "tokio-util", "toml_edit", + "tonic 0.13.1", + "tonic-reflection", "tracing", "tracing-utils", "twox-hash", @@ -4455,9 +4459,15 @@ dependencies = [ name = "pageserver_page_api" version = "0.1.0" dependencies = [ - "prost 0.13.3", - "tonic", + "bytes", + "pageserver_api", + "postgres_ffi", + "prost 0.13.5", + "smallvec", + "thiserror 1.0.69", + "tonic 0.13.1", "tonic-build", + "utils", "workspace_hack", ] @@ -4837,14 +4847,14 @@ dependencies = [ "bytes", "once_cell", "pq_proto", - "rustls 0.23.18", + "rustls 0.23.27", "rustls-pemfile 2.1.1", "serde", "thiserror 1.0.69", "tokio", "tokio-postgres", "tokio-postgres-rustls", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.2", "tokio-util", "tracing", ] @@ -4898,11 +4908,16 @@ name = "posthog_client_lite" version = "0.1.0" dependencies = [ "anyhow", + "arc-swap", "reqwest", "serde", "serde_json", "sha2", "thiserror 1.0.69", + "tokio", + "tokio-util", + "tracing", + "tracing-utils", "workspace_hack", ] @@ -4951,7 +4966,7 @@ dependencies = [ "inferno 0.12.0", "num", "paste", - "prost 0.13.3", + "prost 0.13.5", ] [[package]] @@ -5056,12 +5071,12 @@ dependencies = [ [[package]] name = "prost" -version = "0.13.3" +version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b0487d90e047de87f984913713b85c601c05609aad5b0df4b4573fbf69aa13f" +checksum = "2796faa41db3ec313a31f7624d9286acf277b52de526150b7e69f3debf891ee5" dependencies = [ "bytes", - "prost-derive 0.13.3", + "prost-derive 0.13.5", ] [[package]] @@ -5099,7 +5114,7 @@ dependencies = [ "once_cell", "petgraph", "prettyplease", - "prost 0.13.3", + "prost 0.13.5", "prost-types 0.13.3", "regex", "syn 2.0.100", @@ -5121,9 +5136,9 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.13.3" +version = "0.13.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9552f850d5f0964a4e4d0bf306459ac29323ddfbae05e35a7c0d35cb0803cc5" +checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" dependencies = [ "anyhow", "itertools 0.12.1", @@ -5147,7 +5162,7 @@ version = "0.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4759aa0d3a6232fb8dbdb97b61de2c20047c68aca932c7ed76da9d788508d670" dependencies = [ - "prost 0.13.3", + "prost 0.13.5", ] [[package]] @@ -5195,7 +5210,7 @@ dependencies = [ "hyper 0.14.30", "hyper 1.4.1", "hyper-util", - "indexmap 2.0.1", + "indexmap 2.9.0", "ipnet", "itertools 0.10.5", "itoa", @@ -5229,7 +5244,7 @@ dependencies = [ "rsa", "rstest", "rustc-hash 1.1.0", - "rustls 0.23.18", + "rustls 0.23.27", "rustls-native-certs 0.8.0", "rustls-pemfile 2.1.1", "scopeguard", @@ -5248,7 +5263,7 @@ dependencies = [ "tokio", "tokio-postgres", "tokio-postgres2", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.2", "tokio-tungstenite 0.21.0", "tokio-util", "tracing", @@ -5472,13 +5487,13 @@ dependencies = [ "num-bigint", "percent-encoding", "pin-project-lite", - "rustls 0.23.18", + "rustls 0.23.27", "rustls-native-certs 0.8.0", "ryu", "sha1_smol", "socket2", "tokio", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.2", "tokio-util", "url", ] @@ -5926,15 +5941,15 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.18" +version = "0.23.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c9cc1d47e243d655ace55ed38201c19ae02c148ae56412ab8750e8f0166ab7f" +checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321" dependencies = [ "log", "once_cell", "ring", "rustls-pki-types", - "rustls-webpki 0.102.8", + "rustls-webpki 0.103.3", "subtle", "zeroize", ] @@ -6023,6 +6038,17 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustls-webpki" +version = "0.103.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4a72fe2bcf7a6ac6fd7d0b9e5cb68aeb7d4c0a0271730218b3e92d43b4eb435" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.12" @@ -6074,7 +6100,7 @@ dependencies = [ "regex", "remote_storage", "reqwest", - "rustls 0.23.18", + "rustls 0.23.27", "safekeeper_api", "safekeeper_client", "scopeguard", @@ -6091,7 +6117,7 @@ dependencies = [ "tokio", "tokio-io-timeout", "tokio-postgres", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.2", "tokio-stream", "tokio-tar", "tokio-util", @@ -6263,7 +6289,7 @@ checksum = "255914a8e53822abd946e2ce8baa41d4cded6b8e938913b7f7b9da5b7ab44335" dependencies = [ "httpdate", "reqwest", - "rustls 0.23.18", + "rustls 0.23.27", "sentry-backtrace", "sentry-contexts", "sentry-core", @@ -6692,11 +6718,11 @@ dependencies = [ "metrics", "once_cell", "parking_lot 0.12.1", - "prost 0.13.3", - "rustls 0.23.18", + "prost 0.13.5", + "rustls 0.23.27", "tokio", - "tokio-rustls 0.26.0", - "tonic", + "tokio-rustls 0.26.2", + "tonic 0.13.1", "tonic-build", "tracing", "utils", @@ -6738,7 +6764,7 @@ dependencies = [ "regex", "reqwest", "routerify", - "rustls 0.23.18", + "rustls 0.23.27", "rustls-native-certs 0.8.0", "safekeeper_api", "safekeeper_client", @@ -6753,7 +6779,7 @@ dependencies = [ "tokio", "tokio-postgres", "tokio-postgres-rustls", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.2", "tokio-util", "tracing", "utils", @@ -6791,7 +6817,7 @@ dependencies = [ "postgres_ffi", "remote_storage", "reqwest", - "rustls 0.23.18", + "rustls 0.23.27", "rustls-native-certs 0.8.0", "serde", "serde_json", @@ -7325,10 +7351,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04fb792ccd6bbcd4bba408eb8a292f70fc4a3589e5d793626f45190e6454b6ab" dependencies = [ "ring", - "rustls 0.23.18", + "rustls 0.23.27", "tokio", "tokio-postgres", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.2", "x509-certificate", ] @@ -7372,12 +7398,11 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.0" +version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" dependencies = [ - "rustls 0.23.18", - "rustls-pki-types", + "rustls 0.23.27", "tokio", ] @@ -7475,7 +7500,7 @@ version = "0.22.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f21c7aaf97f1bd9ca9d4f9e73b0a6c74bd5afef56f2bc931943a6e1c37e04e38" dependencies = [ - "indexmap 2.0.1", + "indexmap 2.9.0", "serde", "serde_spanned", "toml_datetime", @@ -7494,18 +7519,41 @@ dependencies = [ "http 1.1.0", "http-body 1.0.0", "http-body-util", + "percent-encoding", + "pin-project", + "prost 0.13.5", + "tokio-stream", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tonic" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e581ba15a835f4d9ea06c55ab1bd4dce26fc53752c69a04aac00703bfb49ba9" +dependencies = [ + "async-trait", + "axum", + "base64 0.22.1", + "bytes", + "h2 0.4.4", + "http 1.1.0", + "http-body 1.0.0", + "http-body-util", "hyper 1.4.1", "hyper-timeout", "hyper-util", "percent-encoding", "pin-project", - "prost 0.13.3", + "prost 0.13.5", "rustls-native-certs 0.8.0", - "rustls-pemfile 2.1.1", + "socket2", "tokio", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.2", "tokio-stream", - "tower 0.4.13", + "tower 0.5.2", "tower-layer", "tower-service", "tracing", @@ -7513,9 +7561,9 @@ dependencies = [ [[package]] name = "tonic-build" -version = "0.12.3" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9557ce109ea773b399c9b9e5dca39294110b74f1f342cb347a80d1fce8c26a11" +checksum = "eac6f67be712d12f0b41328db3137e0d0757645d8904b4cb7d51cd9c2279e847" dependencies = [ "prettyplease", "proc-macro2", @@ -7525,6 +7573,19 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "tonic-reflection" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9687bd5bfeafebdded2356950f278bba8226f0b32109537c4253406e09aafe1" +dependencies = [ + "prost 0.13.5", + "prost-types 0.13.3", + "tokio", + "tokio-stream", + "tonic 0.13.1", +] + [[package]] name = "tower" version = "0.4.13" @@ -7533,16 +7594,11 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c" dependencies = [ "futures-core", "futures-util", - "indexmap 1.9.3", "pin-project", "pin-project-lite", - "rand 0.8.5", - "slab", "tokio", - "tokio-util", "tower-layer", "tower-service", - "tracing", ] [[package]] @@ -7553,9 +7609,12 @@ checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" dependencies = [ "futures-core", "futures-util", + "indexmap 2.9.0", "pin-project-lite", + "slab", "sync_wrapper 1.0.1", "tokio", + "tokio-util", "tower-layer", "tower-service", "tracing", @@ -7883,7 +7942,7 @@ dependencies = [ "base64 0.22.1", "log", "once_cell", - "rustls 0.23.18", + "rustls 0.23.27", "rustls-pki-types", "url", "webpki-roots", @@ -8078,7 +8137,7 @@ dependencies = [ "pageserver_api", "postgres_ffi", "pprof", - "prost 0.13.3", + "prost 0.13.5", "remote_storage", "serde", "serde_json", @@ -8498,6 +8557,8 @@ dependencies = [ "ahash", "anstream", "anyhow", + "axum", + "axum-core", "base64 0.13.1", "base64 0.21.7", "base64ct", @@ -8520,10 +8581,8 @@ dependencies = [ "fail", "form_urlencoded", "futures-channel", - "futures-core", "futures-executor", "futures-io", - "futures-task", "futures-util", "generic-array", "getrandom 0.2.11", @@ -8534,8 +8593,7 @@ dependencies = [ "hyper 0.14.30", "hyper 1.4.1", "hyper-util", - "indexmap 1.9.3", - "indexmap 2.0.1", + "indexmap 2.9.0", "itertools 0.12.1", "lazy_static", "libc", @@ -8554,19 +8612,18 @@ dependencies = [ "once_cell", "p256 0.13.2", "parquet", - "percent-encoding", "prettyplease", "proc-macro2", - "prost 0.13.3", + "prost 0.13.5", "quote", "rand 0.8.5", "regex", "regex-automata 0.4.3", "regex-syntax 0.8.2", "reqwest", - "rustls 0.23.18", + "rustls 0.23.27", "rustls-pki-types", - "rustls-webpki 0.102.8", + "rustls-webpki 0.103.3", "scopeguard", "sec1 0.7.3", "serde", @@ -8584,12 +8641,11 @@ dependencies = [ "time", "time-macros", "tokio", - "tokio-rustls 0.26.0", + "tokio-rustls 0.26.2", "tokio-stream", "tokio-util", "toml_edit", - "tonic", - "tower 0.4.13", + "tower 0.5.2", "tracing", "tracing-core", "tracing-log", diff --git a/Cargo.toml b/Cargo.toml index a280c446b9..a040010fb7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -149,7 +149,7 @@ pin-project-lite = "0.2" pprof = { version = "0.14", features = ["criterion", "flamegraph", "frame-pointer", "prost-codec"] } procfs = "0.16" prometheus = {version = "0.13", default-features=false, features = ["process"]} # removes protobuf dependency -prost = "0.13" +prost = "0.13.5" rand = "0.8" redis = { version = "0.29.2", features = ["tokio-rustls-comp", "keep-alive"] } regex = "1.10.2" @@ -199,7 +199,8 @@ tokio-tar = "0.3" tokio-util = { version = "0.7.10", features = ["io", "rt"] } toml = "0.8" toml_edit = "0.22" -tonic = {version = "0.12.3", default-features = false, features = ["channel", "tls", "tls-roots"]} +tonic = { version = "0.13.1", default-features = false, features = ["channel", "codegen", "prost", "router", "server", "tls-ring", "tls-native-roots"] } +tonic-reflection = { version = "0.13.1", features = ["server"] } tower = { version = "0.5.2", default-features = false } tower-http = { version = "0.6.2", features = ["auth", "request-id", "trace"] } @@ -246,6 +247,7 @@ azure_storage_blobs = { git = "https://github.com/neondatabase/azure-sdk-for-rus ## Local libraries compute_api = { version = "0.1", path = "./libs/compute_api/" } consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" } +desim = { version = "0.1", path = "./libs/desim" } endpoint_storage = { version = "0.0.1", path = "./endpoint_storage/" } http-utils = { version = "0.1", path = "./libs/http-utils/" } metrics = { version = "0.1", path = "./libs/metrics/" } @@ -258,19 +260,19 @@ postgres_backend = { version = "0.1", path = "./libs/postgres_backend/" } postgres_connection = { version = "0.1", path = "./libs/postgres_connection/" } postgres_ffi = { version = "0.1", path = "./libs/postgres_ffi/" } postgres_initdb = { path = "./libs/postgres_initdb" } +posthog_client_lite = { version = "0.1", path = "./libs/posthog_client_lite" } pq_proto = { version = "0.1", path = "./libs/pq_proto/" } remote_storage = { version = "0.1", path = "./libs/remote_storage/" } safekeeper_api = { version = "0.1", path = "./libs/safekeeper_api" } safekeeper_client = { path = "./safekeeper/client" } -desim = { version = "0.1", path = "./libs/desim" } storage_broker = { version = "0.1", path = "./storage_broker/" } # Note: main broker code is inside the binary crate, so linking with the library shouldn't be heavy. storage_controller_client = { path = "./storage_controller/client" } tenant_size_model = { version = "0.1", path = "./libs/tenant_size_model/" } tracing-utils = { version = "0.1", path = "./libs/tracing-utils/" } utils = { version = "0.1", path = "./libs/utils/" } vm_monitor = { version = "0.1", path = "./libs/vm_monitor/" } -walproposer = { version = "0.1", path = "./libs/walproposer/" } wal_decoder = { version = "0.1", path = "./libs/wal_decoder" } +walproposer = { version = "0.1", path = "./libs/walproposer/" } ## Common library dependency workspace_hack = { version = "0.1", path = "./workspace_hack/" } @@ -280,7 +282,7 @@ criterion = "0.5.1" rcgen = "0.13" rstest = "0.18" camino-tempfile = "1.0.2" -tonic-build = "0.12" +tonic-build = "0.13.1" [patch.crates-io] diff --git a/build-tools.Dockerfile b/build-tools.Dockerfile index 1933fd19d8..9d4c93e1cd 100644 --- a/build-tools.Dockerfile +++ b/build-tools.Dockerfile @@ -155,7 +155,7 @@ RUN set -e \ # Keep the version the same as in compute/compute-node.Dockerfile and # test_runner/regress/test_compute_metrics.py. -ENV SQL_EXPORTER_VERSION=0.17.0 +ENV SQL_EXPORTER_VERSION=0.17.3 RUN curl -fsSL \ "https://github.com/burningalchemist/sql_exporter/releases/download/${SQL_EXPORTER_VERSION}/sql_exporter-${SQL_EXPORTER_VERSION}.linux-$(case "$(uname -m)" in x86_64) echo amd64;; aarch64) echo arm64;; esac).tar.gz" \ --output sql_exporter.tar.gz \ diff --git a/compute/compute-node.Dockerfile b/compute/compute-node.Dockerfile index 3e2c09493f..55d008acea 100644 --- a/compute/compute-node.Dockerfile +++ b/compute/compute-node.Dockerfile @@ -582,38 +582,6 @@ RUN make -j $(getconf _NPROCESSORS_ONLN) && \ make -j $(getconf _NPROCESSORS_ONLN) install && \ echo 'trusted = true' >> /usr/local/pgsql/share/extension/hypopg.control -######################################################################################### -# -# Layer "online_advisor-build" -# compile online_advisor extension -# -######################################################################################### -FROM build-deps AS online_advisor-src -ARG PG_VERSION - -# online_advisor supports all Postgres version starting from PG14, but prior to PG17 has to be included in preload_shared_libraries -# last release 1.0 - May 15, 2025 -WORKDIR /ext-src -RUN case "${PG_VERSION:?}" in \ - "v17") \ - ;; \ - *) \ - echo "skipping the version of online_advistor for $PG_VERSION" && exit 0 \ - ;; \ - esac && \ - wget https://github.com/knizhnik/online_advisor/archive/refs/tags/1.0.tar.gz -O online_advisor.tar.gz && \ - echo "059b7d9e5a90013a58bdd22e9505b88406ce05790675eb2d8434e5b215652d54 online_advisor.tar.gz" | sha256sum --check && \ - mkdir online_advisor-src && cd online_advisor-src && tar xzf ../online_advisor.tar.gz --strip-components=1 -C . - -FROM pg-build AS online_advisor-build -COPY --from=online_advisor-src /ext-src/ /ext-src/ -WORKDIR /ext-src/ -RUN if [ -d online_advisor-src ]; then \ - cd online_advisor-src && \ - make -j install && \ - echo 'trusted = true' >> /usr/local/pgsql/share/extension/online_advisor.control; \ - fi - ######################################################################################### # # Layer "pg_hashids-build" @@ -1680,7 +1648,6 @@ COPY --from=pg_jsonschema-build /usr/local/pgsql/ /usr/local/pgsql/ COPY --from=pg_graphql-build /usr/local/pgsql/ /usr/local/pgsql/ COPY --from=pg_tiktoken-build /usr/local/pgsql/ /usr/local/pgsql/ COPY --from=hypopg-build /usr/local/pgsql/ /usr/local/pgsql/ -COPY --from=online_advisor-build /usr/local/pgsql/ /usr/local/pgsql/ COPY --from=pg_hashids-build /usr/local/pgsql/ /usr/local/pgsql/ COPY --from=rum-build /usr/local/pgsql/ /usr/local/pgsql/ COPY --from=pgtap-build /usr/local/pgsql/ /usr/local/pgsql/ @@ -1784,17 +1751,17 @@ ARG TARGETARCH RUN if [ "$TARGETARCH" = "amd64" ]; then\ postgres_exporter_sha256='59aa4a7bb0f7d361f5e05732f5ed8c03cc08f78449cef5856eadec33a627694b';\ pgbouncer_exporter_sha256='c9f7cf8dcff44f0472057e9bf52613d93f3ffbc381ad7547a959daa63c5e84ac';\ - sql_exporter_sha256='38e439732bbf6e28ca4a94d7bc3686d3fa1abdb0050773d5617a9efdb9e64d08';\ + sql_exporter_sha256='9a41127a493e8bfebfe692bf78c7ed2872a58a3f961ee534d1b0da9ae584aaab';\ else\ postgres_exporter_sha256='d1dedea97f56c6d965837bfd1fbb3e35a3b4a4556f8cccee8bd513d8ee086124';\ pgbouncer_exporter_sha256='217c4afd7e6492ae904055bc14fe603552cf9bac458c063407e991d68c519da3';\ - sql_exporter_sha256='11918b00be6e2c3a67564adfdb2414fdcbb15a5db76ea17d1d1a944237a893c6';\ + sql_exporter_sha256='530e6afc77c043497ed965532c4c9dfa873bc2a4f0b3047fad367715c0081d6a';\ fi\ && curl -sL https://github.com/prometheus-community/postgres_exporter/releases/download/v0.17.1/postgres_exporter-0.17.1.linux-${TARGETARCH}.tar.gz\ | tar xzf - --strip-components=1 -C.\ && curl -sL https://github.com/prometheus-community/pgbouncer_exporter/releases/download/v0.10.2/pgbouncer_exporter-0.10.2.linux-${TARGETARCH}.tar.gz\ | tar xzf - --strip-components=1 -C.\ - && curl -sL https://github.com/burningalchemist/sql_exporter/releases/download/0.17.0/sql_exporter-0.17.0.linux-${TARGETARCH}.tar.gz\ + && curl -sL https://github.com/burningalchemist/sql_exporter/releases/download/0.17.3/sql_exporter-0.17.3.linux-${TARGETARCH}.tar.gz\ | tar xzf - --strip-components=1 -C.\ && echo "${postgres_exporter_sha256} postgres_exporter" | sha256sum -c -\ && echo "${pgbouncer_exporter_sha256} pgbouncer_exporter" | sha256sum -c -\ @@ -1847,7 +1814,7 @@ COPY docker-compose/ext-src/ /ext-src/ COPY --from=pg-build /postgres /postgres #COPY --from=postgis-src /ext-src/ /ext-src/ COPY --from=plv8-src /ext-src/ /ext-src/ -#COPY --from=h3-pg-src /ext-src/ /ext-src/ +COPY --from=h3-pg-src /ext-src/h3-pg-src /ext-src/h3-pg-src COPY --from=postgresql-unit-src /ext-src/ /ext-src/ COPY --from=pgvector-src /ext-src/ /ext-src/ COPY --from=pgjwt-src /ext-src/ /ext-src/ @@ -1856,7 +1823,6 @@ COPY --from=pgjwt-src /ext-src/ /ext-src/ COPY --from=pg_graphql-src /ext-src/ /ext-src/ #COPY --from=pg_tiktoken-src /ext-src/ /ext-src/ COPY --from=hypopg-src /ext-src/ /ext-src/ -COPY --from=online_advisor-src /ext-src/ /ext-src/ COPY --from=pg_hashids-src /ext-src/ /ext-src/ COPY --from=rum-src /ext-src/ /ext-src/ COPY --from=pgtap-src /ext-src/ /ext-src/ diff --git a/compute_tools/src/bin/compute_ctl.rs b/compute_tools/src/bin/compute_ctl.rs index 20b5e567a8..02339f752c 100644 --- a/compute_tools/src/bin/compute_ctl.rs +++ b/compute_tools/src/bin/compute_ctl.rs @@ -136,6 +136,10 @@ struct Cli { requires = "compute-id" )] pub control_plane_uri: Option, + + /// Interval in seconds for collecting installed extensions statistics + #[arg(long, default_value = "3600")] + pub installed_extensions_collection_interval: u64, } fn main() -> Result<()> { @@ -179,6 +183,7 @@ fn main() -> Result<()> { cgroup: cli.cgroup, #[cfg(target_os = "linux")] vm_monitor_addr: cli.vm_monitor_addr, + installed_extensions_collection_interval: cli.installed_extensions_collection_interval, }, config, )?; diff --git a/compute_tools/src/compute.rs b/compute_tools/src/compute.rs index f494e2444a..ff49c737f0 100644 --- a/compute_tools/src/compute.rs +++ b/compute_tools/src/compute.rs @@ -97,6 +97,9 @@ pub struct ComputeNodeParams { /// the address of extension storage proxy gateway pub remote_ext_base_url: Option, + + /// Interval for installed extensions collection + pub installed_extensions_collection_interval: u64, } /// Compute node info shared across several `compute_ctl` threads. @@ -695,25 +698,18 @@ impl ComputeNode { let log_directory_path = Path::new(&self.params.pgdata).join("log"); let log_directory_path = log_directory_path.to_string_lossy().to_string(); - // Add project_id,endpoint_id tag to identify the logs. + // Add project_id,endpoint_id to identify the logs. // // These ids are passed from cplane, - // for backwards compatibility (old computes that don't have them), - // we set them to None. - // TODO: Clean up this code when all computes have them. - let tag: Option = match ( - pspec.spec.project_id.as_deref(), - pspec.spec.endpoint_id.as_deref(), - ) { - (Some(project_id), Some(endpoint_id)) => { - Some(format!("{project_id}/{endpoint_id}")) - } - (Some(project_id), None) => Some(format!("{project_id}/None")), - (None, Some(endpoint_id)) => Some(format!("None,{endpoint_id}")), - (None, None) => None, - }; + let endpoint_id = pspec.spec.endpoint_id.as_deref().unwrap_or(""); + let project_id = pspec.spec.project_id.as_deref().unwrap_or(""); - configure_audit_rsyslog(log_directory_path.clone(), tag, &remote_endpoint)?; + configure_audit_rsyslog( + log_directory_path.clone(), + endpoint_id, + project_id, + &remote_endpoint, + )?; // Launch a background task to clean up the audit logs launch_pgaudit_gc(log_directory_path); @@ -749,17 +745,7 @@ impl ComputeNode { let conf = self.get_tokio_conn_conf(None); tokio::task::spawn(async { - let res = get_installed_extensions(conf).await; - match res { - Ok(extensions) => { - info!( - "[NEON_EXT_STAT] {}", - serde_json::to_string(&extensions) - .expect("failed to serialize extensions list") - ); - } - Err(err) => error!("could not get installed extensions: {err:?}"), - } + let _ = installed_extensions(conf).await; }); } @@ -789,6 +775,9 @@ impl ComputeNode { // Log metrics so that we can search for slow operations in logs info!(?metrics, postmaster_pid = %postmaster_pid, "compute start finished"); + // Spawn the extension stats background task + self.spawn_extension_stats_task(); + if pspec.spec.prewarm_lfc_on_startup { self.prewarm_lfc(); } @@ -2199,6 +2188,41 @@ LIMIT 100", info!("Pageserver config changed"); } } + + pub fn spawn_extension_stats_task(&self) { + let conf = self.tokio_conn_conf.clone(); + let installed_extensions_collection_interval = + self.params.installed_extensions_collection_interval; + tokio::spawn(async move { + // An initial sleep is added to ensure that two collections don't happen at the same time. + // The first collection happens during compute startup. + tokio::time::sleep(tokio::time::Duration::from_secs( + installed_extensions_collection_interval, + )) + .await; + let mut interval = tokio::time::interval(tokio::time::Duration::from_secs( + installed_extensions_collection_interval, + )); + loop { + interval.tick().await; + let _ = installed_extensions(conf.clone()).await; + } + }); + } +} + +pub async fn installed_extensions(conf: tokio_postgres::Config) -> Result<()> { + let res = get_installed_extensions(conf).await; + match res { + Ok(extensions) => { + info!( + "[NEON_EXT_STAT] {}", + serde_json::to_string(&extensions).expect("failed to serialize extensions list") + ); + } + Err(err) => error!("could not get installed extensions: {err:?}"), + } + Ok(()) } pub fn forward_termination_signal() { diff --git a/compute_tools/src/config_template/compute_audit_rsyslog_template.conf b/compute_tools/src/config_template/compute_audit_rsyslog_template.conf index 9ca7e36738..48b1a6f5c3 100644 --- a/compute_tools/src/config_template/compute_audit_rsyslog_template.conf +++ b/compute_tools/src/config_template/compute_audit_rsyslog_template.conf @@ -2,10 +2,24 @@ module(load="imfile") # Input configuration for log files in the specified directory -# Replace {log_directory} with the directory containing the log files -input(type="imfile" File="{log_directory}/*.log" Tag="{tag}" Severity="info" Facility="local0") +# The messages can be multiline. The start of the message is a timestamp +# in "%Y-%m-%d %H:%M:%S.%3N GMT" (so timezone hardcoded). +# Replace log_directory with the directory containing the log files +input(type="imfile" File="{log_directory}/*.log" + Tag="pgaudit_log" Severity="info" Facility="local5" + startmsg.regex="^[[:digit:]]{{4}}-[[:digit:]]{{2}}-[[:digit:]]{{2}} [[:digit:]]{{2}}:[[:digit:]]{{2}}:[[:digit:]]{{2}}.[[:digit:]]{{3}} GMT,") + # the directory to store rsyslog state files global(workDirectory="/var/log/rsyslog") -# Forward logs to remote syslog server -*.* @@{remote_endpoint} +# Construct json, endpoint_id and project_id as additional metadata +set $.json_log!endpoint_id = "{endpoint_id}"; +set $.json_log!project_id = "{project_id}"; +set $.json_log!msg = $msg; + +# Template suitable for rfc5424 syslog format +template(name="PgAuditLog" type="string" + string="<%PRI%>1 %TIMESTAMP:::date-rfc3339% %HOSTNAME% - - - - %$.json_log%") + +# Forward to remote syslog receiver (@@:;format +local5.info @@{remote_endpoint};PgAuditLog diff --git a/compute_tools/src/rsyslog.rs b/compute_tools/src/rsyslog.rs index c873697623..3bc2e72b19 100644 --- a/compute_tools/src/rsyslog.rs +++ b/compute_tools/src/rsyslog.rs @@ -84,13 +84,15 @@ fn restart_rsyslog() -> Result<()> { pub fn configure_audit_rsyslog( log_directory: String, - tag: Option, + endpoint_id: &str, + project_id: &str, remote_endpoint: &str, ) -> Result<()> { let config_content: String = format!( include_str!("config_template/compute_audit_rsyslog_template.conf"), log_directory = log_directory, - tag = tag.unwrap_or("".to_string()), + endpoint_id = endpoint_id, + project_id = project_id, remote_endpoint = remote_endpoint ); diff --git a/control_plane/safekeepers.conf b/control_plane/safekeepers.conf index 576cc4a3a9..a73e274dfa 100644 --- a/control_plane/safekeepers.conf +++ b/control_plane/safekeepers.conf @@ -2,8 +2,10 @@ [pageserver] listen_pg_addr = '127.0.0.1:64000' listen_http_addr = '127.0.0.1:9898' +listen_grpc_addr = '127.0.0.1:51051' pg_auth_type = 'Trust' http_auth_type = 'Trust' +grpc_auth_type = 'Trust' [[safekeepers]] id = 1 diff --git a/control_plane/simple.conf b/control_plane/simple.conf index 0ad90a4618..1eb21f846e 100644 --- a/control_plane/simple.conf +++ b/control_plane/simple.conf @@ -4,8 +4,10 @@ id=1 listen_pg_addr = '127.0.0.1:64000' listen_http_addr = '127.0.0.1:9898' +listen_grpc_addr = '127.0.0.1:51051' pg_auth_type = 'Trust' http_auth_type = 'Trust' +grpc_auth_type = 'Trust' [[safekeepers]] id = 1 diff --git a/control_plane/src/bin/neon_local.rs b/control_plane/src/bin/neon_local.rs index 98ab6e5657..ef6985d697 100644 --- a/control_plane/src/bin/neon_local.rs +++ b/control_plane/src/bin/neon_local.rs @@ -32,6 +32,7 @@ use control_plane::storage_controller::{ }; use nix::fcntl::{Flock, FlockArg}; use pageserver_api::config::{ + DEFAULT_GRPC_LISTEN_PORT as DEFAULT_PAGESERVER_GRPC_PORT, DEFAULT_HTTP_LISTEN_PORT as DEFAULT_PAGESERVER_HTTP_PORT, DEFAULT_PG_LISTEN_PORT as DEFAULT_PAGESERVER_PG_PORT, }; @@ -1007,13 +1008,16 @@ fn handle_init(args: &InitCmdArgs) -> anyhow::Result { let pageserver_id = NodeId(DEFAULT_PAGESERVER_ID.0 + i as u64); let pg_port = DEFAULT_PAGESERVER_PG_PORT + i; let http_port = DEFAULT_PAGESERVER_HTTP_PORT + i; + let grpc_port = DEFAULT_PAGESERVER_GRPC_PORT + i; NeonLocalInitPageserverConf { id: pageserver_id, listen_pg_addr: format!("127.0.0.1:{pg_port}"), listen_http_addr: format!("127.0.0.1:{http_port}"), listen_https_addr: None, + listen_grpc_addr: Some(format!("127.0.0.1:{grpc_port}")), pg_auth_type: AuthType::Trust, http_auth_type: AuthType::Trust, + grpc_auth_type: AuthType::Trust, other: Default::default(), // Typical developer machines use disks with slow fsync, and we don't care // about data integrity: disable disk syncs. @@ -1275,6 +1279,7 @@ async fn handle_timeline(cmd: &TimelineCmd, env: &mut local_env::LocalEnv) -> Re mode: pageserver_api::models::TimelineCreateRequestMode::Branch { ancestor_timeline_id, ancestor_start_lsn: start_lsn, + read_only: false, pg_version: None, }, }; diff --git a/control_plane/src/local_env.rs b/control_plane/src/local_env.rs index 4a8892c6de..47b77f0720 100644 --- a/control_plane/src/local_env.rs +++ b/control_plane/src/local_env.rs @@ -278,8 +278,10 @@ pub struct PageServerConf { pub listen_pg_addr: String, pub listen_http_addr: String, pub listen_https_addr: Option, + pub listen_grpc_addr: Option, pub pg_auth_type: AuthType, pub http_auth_type: AuthType, + pub grpc_auth_type: AuthType, pub no_sync: bool, } @@ -290,8 +292,10 @@ impl Default for PageServerConf { listen_pg_addr: String::new(), listen_http_addr: String::new(), listen_https_addr: None, + listen_grpc_addr: None, pg_auth_type: AuthType::Trust, http_auth_type: AuthType::Trust, + grpc_auth_type: AuthType::Trust, no_sync: false, } } @@ -306,8 +310,10 @@ pub struct NeonLocalInitPageserverConf { pub listen_pg_addr: String, pub listen_http_addr: String, pub listen_https_addr: Option, + pub listen_grpc_addr: Option, pub pg_auth_type: AuthType, pub http_auth_type: AuthType, + pub grpc_auth_type: AuthType, #[serde(default, skip_serializing_if = "std::ops::Not::not")] pub no_sync: bool, #[serde(flatten)] @@ -321,8 +327,10 @@ impl From<&NeonLocalInitPageserverConf> for PageServerConf { listen_pg_addr, listen_http_addr, listen_https_addr, + listen_grpc_addr, pg_auth_type, http_auth_type, + grpc_auth_type, no_sync, other: _, } = conf; @@ -331,7 +339,9 @@ impl From<&NeonLocalInitPageserverConf> for PageServerConf { listen_pg_addr: listen_pg_addr.clone(), listen_http_addr: listen_http_addr.clone(), listen_https_addr: listen_https_addr.clone(), + listen_grpc_addr: listen_grpc_addr.clone(), pg_auth_type: *pg_auth_type, + grpc_auth_type: *grpc_auth_type, http_auth_type: *http_auth_type, no_sync: *no_sync, } @@ -707,8 +717,10 @@ impl LocalEnv { listen_pg_addr: String, listen_http_addr: String, listen_https_addr: Option, + listen_grpc_addr: Option, pg_auth_type: AuthType, http_auth_type: AuthType, + grpc_auth_type: AuthType, #[serde(default)] no_sync: bool, } @@ -732,8 +744,10 @@ impl LocalEnv { listen_pg_addr, listen_http_addr, listen_https_addr, + listen_grpc_addr, pg_auth_type, http_auth_type, + grpc_auth_type, no_sync, } = config_toml; let IdentityTomlSubset { @@ -750,8 +764,10 @@ impl LocalEnv { listen_pg_addr, listen_http_addr, listen_https_addr, + listen_grpc_addr, pg_auth_type, http_auth_type, + grpc_auth_type, no_sync, }; pageservers.push(conf); diff --git a/control_plane/src/pageserver.rs b/control_plane/src/pageserver.rs index 756f2b02db..29314dab9e 100644 --- a/control_plane/src/pageserver.rs +++ b/control_plane/src/pageserver.rs @@ -129,7 +129,9 @@ impl PageServerNode { )); } - if conf.http_auth_type != AuthType::Trust || conf.pg_auth_type != AuthType::Trust { + if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type] + .contains(&AuthType::NeonJWT) + { // Keys are generated in the toplevel repo dir, pageservers' workdirs // are one level below that, so refer to keys with ../ overrides.push("auth_validation_public_key_path='../auth_public_key.pem'".to_owned()); diff --git a/docker-compose/compute_wrapper/shell/compute.sh b/docker-compose/compute_wrapper/shell/compute.sh index 20a1ffb7a0..ab8d74d355 100755 --- a/docker-compose/compute_wrapper/shell/compute.sh +++ b/docker-compose/compute_wrapper/shell/compute.sh @@ -20,7 +20,7 @@ first_path="$(ldconfig --verbose 2>/dev/null \ | grep --invert-match ^$'\t' \ | cut --delimiter=: --fields=1 \ | head --lines=1)" -test "$first_path" == '/usr/local/lib' || true # Remove the || true in a follow-up PR. Needed for backwards compat. +test "$first_path" == '/usr/local/lib' echo "Waiting pageserver become ready." while ! nc -z pageserver 6400; do diff --git a/docker-compose/ext-src/h3-pg-src/neon-test.sh b/docker-compose/ext-src/h3-pg-src/neon-test.sh new file mode 100755 index 0000000000..e2ab22f03e --- /dev/null +++ b/docker-compose/ext-src/h3-pg-src/neon-test.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env bash +set -ex +cd "$(dirname "${0}")" +PG_REGRESS=$(dirname "$(pg_config --pgxs)")/../test/regress/pg_regress +dropdb --if-exists contrib_regression +createdb contrib_regression +cd h3_postgis/test +psql -d contrib_regression -c "CREATE EXTENSION postgis" -c "CREATE EXTENSION postgis_raster" -c "CREATE EXTENSION h3" -c "CREATE EXTENSION h3_postgis" +TESTS=$(echo sql/* | sed 's|sql/||g; s|\.sql||g') +${PG_REGRESS} --use-existing --dbname contrib_regression ${TESTS} +cd ../../h3/test +TESTS=$(echo sql/* | sed 's|sql/||g; s|\.sql||g') +dropdb --if-exists contrib_regression +createdb contrib_regression +psql -d contrib_regression -c "CREATE EXTENSION h3" +${PG_REGRESS} --use-existing --dbname contrib_regression ${TESTS} diff --git a/docker-compose/ext-src/h3-pg-src/test-upgrade.sh b/docker-compose/ext-src/h3-pg-src/test-upgrade.sh new file mode 100755 index 0000000000..72d7040966 --- /dev/null +++ b/docker-compose/ext-src/h3-pg-src/test-upgrade.sh @@ -0,0 +1,7 @@ +#!/bin/sh +set -ex +cd "$(dirname ${0})" +PG_REGRESS=$(dirname "$(pg_config --pgxs)")/../test/regress/pg_regress +cd h3/test +TESTS=$(echo sql/* | sed 's|sql/||g; s|\.sql||g') +${PG_REGRESS} --use-existing --inputdir=./ --bindir='/usr/local/pgsql/bin' --dbname=contrib_regression ${TESTS} \ No newline at end of file diff --git a/docker-compose/ext-src/online_advisor-src/neon-test.sh b/docker-compose/ext-src/online_advisor-src/neon-test.sh new file mode 100755 index 0000000000..db5c2821fa --- /dev/null +++ b/docker-compose/ext-src/online_advisor-src/neon-test.sh @@ -0,0 +1,6 @@ +#!/bin/sh +set -ex +cd "$(dirname "${0}")" +if [ -f Makefile ]; then + make installcheck +fi diff --git a/docker-compose/ext-src/online_advisor-src/regular-test.sh b/docker-compose/ext-src/online_advisor-src/regular-test.sh new file mode 100755 index 0000000000..e94f03aa70 --- /dev/null +++ b/docker-compose/ext-src/online_advisor-src/regular-test.sh @@ -0,0 +1,9 @@ +#!/bin/sh +set -ex +cd "$(dirname ${0})" +[ -f Makefile ] || exit 0 +dropdb --if-exist contrib_regression +createdb contrib_regression +PG_REGRESS=$(dirname "$(pg_config --pgxs)")/../test/regress/pg_regress +TESTS=$(echo sql/* | sed 's|sql/||g; s|\.sql||g') +${PG_REGRESS} --use-existing --inputdir=./ --bindir='/usr/local/pgsql/bin' --dbname=contrib_regression ${TESTS} diff --git a/docker-compose/test_extensions_upgrade.sh b/docker-compose/test_extensions_upgrade.sh index 51d1e40802..f1cf17f531 100755 --- a/docker-compose/test_extensions_upgrade.sh +++ b/docker-compose/test_extensions_upgrade.sh @@ -82,7 +82,8 @@ EXTENSIONS='[ {"extname": "pg_ivm", "extdir": "pg_ivm-src"}, {"extname": "pgjwt", "extdir": "pgjwt-src"}, {"extname": "pgtap", "extdir": "pgtap-src"}, -{"extname": "pg_repack", "extdir": "pg_repack-src"} +{"extname": "pg_repack", "extdir": "pg_repack-src"}, +{"extname": "h3", "extdir": "h3-pg-src"} ]' EXTNAMES=$(echo ${EXTENSIONS} | jq -r '.[].extname' | paste -sd ' ' -) COMPUTE_TAG=${NEW_COMPUTE_TAG} docker compose --profile test-extensions up --quiet-pull --build -d diff --git a/libs/pageserver_api/src/config.rs b/libs/pageserver_api/src/config.rs index 0fb2ff38ff..012c020fb1 100644 --- a/libs/pageserver_api/src/config.rs +++ b/libs/pageserver_api/src/config.rs @@ -8,6 +8,8 @@ pub const DEFAULT_PG_LISTEN_PORT: u16 = 64000; pub const DEFAULT_PG_LISTEN_ADDR: &str = formatcp!("127.0.0.1:{DEFAULT_PG_LISTEN_PORT}"); pub const DEFAULT_HTTP_LISTEN_PORT: u16 = 9898; pub const DEFAULT_HTTP_LISTEN_ADDR: &str = formatcp!("127.0.0.1:{DEFAULT_HTTP_LISTEN_PORT}"); +// TODO: gRPC is disabled by default for now, but the port is used in neon_local. +pub const DEFAULT_GRPC_LISTEN_PORT: u16 = 51051; // storage-broker already uses 50051 use std::collections::HashMap; use std::num::{NonZeroU64, NonZeroUsize}; @@ -43,6 +45,21 @@ pub struct NodeMetadata { pub other: HashMap, } +/// PostHog integration config. +#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct PostHogConfig { + /// PostHog project ID + pub project_id: String, + /// Server-side (private) API key + pub server_api_key: String, + /// Client-side (public) API key + pub client_api_key: String, + /// Private API URL + pub private_api_url: String, + /// Public API URL + pub public_api_url: String, +} + /// `pageserver.toml` /// /// We use serde derive with `#[serde(default)]` to generate a deserializer @@ -104,6 +121,7 @@ pub struct ConfigToml { pub listen_pg_addr: String, pub listen_http_addr: String, pub listen_https_addr: Option, + pub listen_grpc_addr: Option, pub ssl_key_file: Utf8PathBuf, pub ssl_cert_file: Utf8PathBuf, #[serde(with = "humantime_serde")] @@ -123,6 +141,7 @@ pub struct ConfigToml { pub http_auth_type: AuthType, #[serde_as(as = "serde_with::DisplayFromStr")] pub pg_auth_type: AuthType, + pub grpc_auth_type: AuthType, pub auth_validation_public_key_path: Option, pub remote_storage: Option, pub tenant_config: TenantConfigToml, @@ -182,6 +201,8 @@ pub struct ConfigToml { pub tracing: Option, pub enable_tls_page_service_api: bool, pub dev_mode: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub posthog_config: Option, pub timeline_import_config: TimelineImportConfig, #[serde(skip_serializing_if = "Option::is_none")] pub basebackup_cache_config: Option, @@ -588,6 +609,7 @@ impl Default for ConfigToml { listen_pg_addr: (DEFAULT_PG_LISTEN_ADDR.to_string()), listen_http_addr: (DEFAULT_HTTP_LISTEN_ADDR.to_string()), listen_https_addr: (None), + listen_grpc_addr: None, // TODO: default to 127.0.0.1:51051 ssl_key_file: Utf8PathBuf::from(DEFAULT_SSL_KEY_FILE), ssl_cert_file: Utf8PathBuf::from(DEFAULT_SSL_CERT_FILE), ssl_cert_reload_period: Duration::from_secs(60), @@ -604,6 +626,7 @@ impl Default for ConfigToml { pg_distrib_dir: None, // Utf8PathBuf::from("./pg_install"), // TODO: formely, this was std::env::current_dir() http_auth_type: (AuthType::Trust), pg_auth_type: (AuthType::Trust), + grpc_auth_type: (AuthType::Trust), auth_validation_public_key_path: (None), remote_storage: None, broker_endpoint: (storage_broker::DEFAULT_ENDPOINT @@ -695,6 +718,7 @@ impl Default for ConfigToml { import_job_checkpoint_threshold: NonZeroUsize::new(128).unwrap(), }, basebackup_cache_config: None, + posthog_config: None, } } } diff --git a/libs/pageserver_api/src/models.rs b/libs/pageserver_api/src/models.rs index 383939a13f..9f3736d57a 100644 --- a/libs/pageserver_api/src/models.rs +++ b/libs/pageserver_api/src/models.rs @@ -402,6 +402,8 @@ pub enum TimelineCreateRequestMode { // using a flattened enum, so, it was an accepted field, and // we continue to accept it by having it here. pg_version: Option, + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + read_only: bool, }, ImportPgdata { import_pgdata: TimelineCreateRequestModeImportPgdata, diff --git a/libs/posthog_client_lite/Cargo.toml b/libs/posthog_client_lite/Cargo.toml index 7c19bf2ccb..05a3a9774e 100644 --- a/libs/posthog_client_lite/Cargo.toml +++ b/libs/posthog_client_lite/Cargo.toml @@ -6,9 +6,14 @@ license.workspace = true [dependencies] anyhow.workspace = true +arc-swap.workspace = true reqwest.workspace = true -serde.workspace = true serde_json.workspace = true +serde.workspace = true sha2.workspace = true -workspace_hack.workspace = true thiserror.workspace = true +tokio = { workspace = true, features = ["process", "sync", "fs", "rt", "io-util", "time"] } +tokio-util.workspace = true +tracing-utils.workspace = true +tracing.workspace = true +workspace_hack.workspace = true diff --git a/libs/posthog_client_lite/src/background_loop.rs b/libs/posthog_client_lite/src/background_loop.rs new file mode 100644 index 0000000000..9ffcda3728 --- /dev/null +++ b/libs/posthog_client_lite/src/background_loop.rs @@ -0,0 +1,59 @@ +//! A background loop that fetches feature flags from PostHog and updates the feature store. + +use std::{sync::Arc, time::Duration}; + +use arc_swap::ArcSwap; +use tokio_util::sync::CancellationToken; + +use crate::{FeatureStore, PostHogClient, PostHogClientConfig}; + +/// A background loop that fetches feature flags from PostHog and updates the feature store. +pub struct FeatureResolverBackgroundLoop { + posthog_client: PostHogClient, + feature_store: ArcSwap, + cancel: CancellationToken, +} + +impl FeatureResolverBackgroundLoop { + pub fn new(config: PostHogClientConfig, shutdown_pageserver: CancellationToken) -> Self { + Self { + posthog_client: PostHogClient::new(config), + feature_store: ArcSwap::new(Arc::new(FeatureStore::new())), + cancel: shutdown_pageserver, + } + } + + pub fn spawn(self: Arc, handle: &tokio::runtime::Handle, refresh_period: Duration) { + let this = self.clone(); + let cancel = self.cancel.clone(); + handle.spawn(async move { + tracing::info!("Starting PostHog feature resolver"); + let mut ticker = tokio::time::interval(refresh_period); + ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + loop { + tokio::select! { + _ = ticker.tick() => {} + _ = cancel.cancelled() => break + } + let resp = match this + .posthog_client + .get_feature_flags_local_evaluation() + .await + { + Ok(resp) => resp, + Err(e) => { + tracing::warn!("Cannot get feature flags: {}", e); + continue; + } + }; + let feature_store = FeatureStore::new_with_flags(resp.flags); + this.feature_store.store(Arc::new(feature_store)); + } + tracing::info!("PostHog feature resolver stopped"); + }); + } + + pub fn feature_store(&self) -> Arc { + self.feature_store.load_full() + } +} diff --git a/libs/posthog_client_lite/src/lib.rs b/libs/posthog_client_lite/src/lib.rs index 53deb26ab7..8aa8da2898 100644 --- a/libs/posthog_client_lite/src/lib.rs +++ b/libs/posthog_client_lite/src/lib.rs @@ -1,5 +1,9 @@ //! A lite version of the PostHog client that only supports local evaluation of feature flags. +mod background_loop; + +pub use background_loop::FeatureResolverBackgroundLoop; + use std::collections::HashMap; use serde::{Deserialize, Serialize}; @@ -20,8 +24,7 @@ pub enum PostHogEvaluationError { #[derive(Deserialize)] pub struct LocalEvaluationResponse { - #[allow(dead_code)] - flags: Vec, + pub flags: Vec, } #[derive(Deserialize)] @@ -34,7 +37,7 @@ pub struct LocalEvaluationFlag { #[derive(Deserialize)] pub struct LocalEvaluationFlagFilters { groups: Vec, - multivariate: LocalEvaluationFlagMultivariate, + multivariate: Option, } #[derive(Deserialize)] @@ -94,6 +97,12 @@ impl FeatureStore { } } + pub fn new_with_flags(flags: Vec) -> Self { + let mut store = Self::new(); + store.set_flags(flags); + store + } + pub fn set_flags(&mut self, flags: Vec) { self.flags.clear(); for flag in flags { @@ -245,7 +254,7 @@ impl FeatureStore { } } - /// Evaluate a multivariate feature flag. Returns `None` if the flag is not available or if there are errors + /// Evaluate a multivariate feature flag. Returns an error if the flag is not available or if there are errors /// during the evaluation. /// /// The parsing logic is as follows: @@ -263,10 +272,15 @@ impl FeatureStore { /// Example: we have a multivariate flag with 3 groups of the configured global rollout percentage: A (10%), B (20%), C (70%). /// There is a single group with a condition that has a rollout percentage of 10% and it does not have a variant override. /// Then, we will have 1% of the users evaluated to A, 2% to B, and 7% to C. + /// + /// Error handling: the caller should inspect the error and decide the behavior when a feature flag + /// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be + /// propagated beyond where the feature flag gets resolved. pub fn evaluate_multivariate( &self, flag_key: &str, user_id: &str, + properties: &HashMap, ) -> Result { let hash_on_global_rollout_percentage = Self::consistent_hash(user_id, flag_key, "multivariate"); @@ -276,10 +290,39 @@ impl FeatureStore { flag_key, hash_on_global_rollout_percentage, hash_on_group_rollout_percentage, - &HashMap::new(), + properties, ) } + /// Evaluate a boolean feature flag. Returns an error if the flag is not available or if there are errors + /// during the evaluation. + /// + /// The parsing logic is as follows: + /// + /// * Generate a consistent hash for the tenant-feature. + /// * Match each filter group. + /// - If a group is matched, it will first determine whether the user is in the range of the rollout + /// percentage. + /// - If the hash falls within the group's rollout percentage, return true. + /// * Otherwise, continue with the next group until all groups are evaluated and no group is within the + /// rollout percentage. + /// * If there are no matching groups, return an error. + /// + /// Returns `Ok(())` if the feature flag evaluates to true. In the future, it will return a payload. + /// + /// Error handling: the caller should inspect the error and decide the behavior when a feature flag + /// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be + /// propagated beyond where the feature flag gets resolved. + pub fn evaluate_boolean( + &self, + flag_key: &str, + user_id: &str, + properties: &HashMap, + ) -> Result<(), PostHogEvaluationError> { + let hash_on_global_rollout_percentage = Self::consistent_hash(user_id, flag_key, "boolean"); + self.evaluate_boolean_inner(flag_key, hash_on_global_rollout_percentage, properties) + } + /// Evaluate a multivariate feature flag. Note that we directly take the mapped user ID /// (a consistent hash ranging from 0 to 1) so that it is easier to use it in the tests /// and avoid duplicate computations. @@ -306,6 +349,11 @@ impl FeatureStore { flag_key ))); } + let Some(ref multivariate) = flag_config.filters.multivariate else { + return Err(PostHogEvaluationError::Internal(format!( + "No multivariate available, should use evaluate_boolean?: {flag_key}" + ))); + }; // TODO: sort the groups so that variant overrides always get evaluated first and it follows the PostHog // Python SDK behavior; for now we do not configure conditions without variant overrides in Neon so it // does not matter. @@ -314,7 +362,7 @@ impl FeatureStore { GroupEvaluationResult::MatchedAndOverride(variant) => return Ok(variant), GroupEvaluationResult::MatchedAndEvaluate => { let mut percentage = 0; - for variant in &flag_config.filters.multivariate.variants { + for variant in &multivariate.variants { percentage += variant.rollout_percentage; if self .evaluate_percentage(hash_on_global_rollout_percentage, percentage) @@ -342,6 +390,77 @@ impl FeatureStore { ))) } } + + /// Evaluate a multivariate feature flag. Note that we directly take the mapped user ID + /// (a consistent hash ranging from 0 to 1) so that it is easier to use it in the tests + /// and avoid duplicate computations. + /// + /// Use a different consistent hash for evaluating the group rollout percentage. + /// The behavior: if the condition is set to rolling out to 10% of the users, and + /// we set the variant A to 20% in the global config, then 2% of the total users will + /// be evaluated to variant A. + /// + /// Note that the hash to determine group rollout percentage is shared across all groups. So if we have two + /// exactly-the-same conditions with 10% and 20% rollout percentage respectively, a total of 20% of the users + /// will be evaluated (versus 30% if group evaluation is done independently). + pub(crate) fn evaluate_boolean_inner( + &self, + flag_key: &str, + hash_on_global_rollout_percentage: f64, + properties: &HashMap, + ) -> Result<(), PostHogEvaluationError> { + if let Some(flag_config) = self.flags.get(flag_key) { + if !flag_config.active { + return Err(PostHogEvaluationError::NotAvailable(format!( + "The feature flag is not active: {}", + flag_key + ))); + } + if flag_config.filters.multivariate.is_some() { + return Err(PostHogEvaluationError::Internal(format!( + "This looks like a multivariate flag, should use evaluate_multivariate?: {flag_key}" + ))); + }; + // TODO: sort the groups so that variant overrides always get evaluated first and it follows the PostHog + // Python SDK behavior; for now we do not configure conditions without variant overrides in Neon so it + // does not matter. + for group in &flag_config.filters.groups { + match self.evaluate_group(group, hash_on_global_rollout_percentage, properties)? { + GroupEvaluationResult::MatchedAndOverride(_) => { + return Err(PostHogEvaluationError::Internal(format!( + "Boolean flag cannot have overrides: {}", + flag_key + ))); + } + GroupEvaluationResult::MatchedAndEvaluate => { + return Ok(()); + } + GroupEvaluationResult::Unmatched => continue, + } + } + // If no group is matched, the feature is not available, and up to the caller to decide what to do. + Err(PostHogEvaluationError::NoConditionGroupMatched) + } else { + // The feature flag is not available yet + Err(PostHogEvaluationError::NotAvailable(format!( + "Not found in the local evaluation spec: {}", + flag_key + ))) + } + } +} + +pub struct PostHogClientConfig { + /// The server API key. + pub server_api_key: String, + /// The client API key. + pub client_api_key: String, + /// The project ID. + pub project_id: String, + /// The private API URL. + pub private_api_url: String, + /// The public API URL. + pub public_api_url: String, } /// A lite PostHog client. @@ -360,37 +479,16 @@ impl FeatureStore { /// want to report the feature flag usage back to PostHog. The current plan is to use PostHog only as an UI to /// configure feature flags so it is very likely that the client API will not be used. pub struct PostHogClient { - /// The server API key. - server_api_key: String, - /// The client API key. - client_api_key: String, - /// The project ID. - project_id: String, - /// The private API URL. - private_api_url: String, - /// The public API URL. - public_api_url: String, + /// The config. + config: PostHogClientConfig, /// The HTTP client. client: reqwest::Client, } impl PostHogClient { - pub fn new( - server_api_key: String, - client_api_key: String, - project_id: String, - private_api_url: String, - public_api_url: String, - ) -> Self { + pub fn new(config: PostHogClientConfig) -> Self { let client = reqwest::Client::new(); - Self { - server_api_key, - client_api_key, - project_id, - private_api_url, - public_api_url, - client, - } + Self { config, client } } pub fn new_with_us_region( @@ -398,13 +496,13 @@ impl PostHogClient { client_api_key: String, project_id: String, ) -> Self { - Self::new( + Self::new(PostHogClientConfig { server_api_key, client_api_key, project_id, - "https://us.posthog.com".to_string(), - "https://us.i.posthog.com".to_string(), - ) + private_api_url: "https://us.posthog.com".to_string(), + public_api_url: "https://us.i.posthog.com".to_string(), + }) } /// Fetch the feature flag specs from the server. @@ -422,12 +520,12 @@ impl PostHogClient { // with bearer token of self.server_api_key let url = format!( "{}/api/projects/{}/feature_flags/local_evaluation", - self.private_api_url, self.project_id + self.config.private_api_url, self.config.project_id ); let response = self .client .get(url) - .bearer_auth(&self.server_api_key) + .bearer_auth(&self.config.server_api_key) .send() .await?; let body = response.text().await?; @@ -446,11 +544,11 @@ impl PostHogClient { ) -> anyhow::Result<()> { // PUBLIC_URL/capture/ // with bearer token of self.client_api_key - let url = format!("{}/capture/", self.public_api_url); + let url = format!("{}/capture/", self.config.public_api_url); self.client .post(url) .body(serde_json::to_string(&json!({ - "api_key": self.client_api_key, + "api_key": self.config.client_api_key, "distinct_id": distinct_id, "event": event, "properties": properties, @@ -467,95 +565,162 @@ mod tests { fn data() -> &'static str { r#"{ - "flags": [ - { - "id": 132794, - "team_id": 152860, - "name": "", - "key": "gc-compaction", - "filters": { - "groups": [ - { - "variant": "enabled-stage-2", - "properties": [ - { - "key": "plan_type", - "type": "person", - "value": [ - "free" - ], - "operator": "exact" - }, - { - "key": "pageserver_remote_size", - "type": "person", - "value": "10000000", - "operator": "lt" - } - ], - "rollout_percentage": 50 - }, - { - "properties": [ - { - "key": "plan_type", - "type": "person", - "value": [ - "free" - ], - "operator": "exact" - }, - { - "key": "pageserver_remote_size", - "type": "person", - "value": "10000000", - "operator": "lt" - } - ], - "rollout_percentage": 80 - } - ], - "payloads": {}, - "multivariate": { - "variants": [ - { - "key": "disabled", - "name": "", - "rollout_percentage": 90 - }, - { - "key": "enabled-stage-1", - "name": "", - "rollout_percentage": 10 - }, - { - "key": "enabled-stage-2", - "name": "", - "rollout_percentage": 0 - }, - { - "key": "enabled-stage-3", - "name": "", - "rollout_percentage": 0 - }, - { - "key": "enabled", - "name": "", - "rollout_percentage": 0 - } - ] - } - }, - "deleted": false, - "active": true, - "ensure_experience_continuity": false, - "has_encrypted_payloads": false, - "version": 6 - } + "flags": [ + { + "id": 141807, + "team_id": 152860, + "name": "", + "key": "image-compaction-boundary", + "filters": { + "groups": [ + { + "variant": null, + "properties": [ + { + "key": "plan_type", + "type": "person", + "value": [ + "free" + ], + "operator": "exact" + } ], - "group_type_mapping": {}, - "cohorts": {} - }"# + "rollout_percentage": 40 + }, + { + "variant": null, + "properties": [], + "rollout_percentage": 10 + } + ], + "payloads": {}, + "multivariate": null + }, + "deleted": false, + "active": true, + "ensure_experience_continuity": false, + "has_encrypted_payloads": false, + "version": 1 + }, + { + "id": 135586, + "team_id": 152860, + "name": "", + "key": "boolean-flag", + "filters": { + "groups": [ + { + "variant": null, + "properties": [ + { + "key": "plan_type", + "type": "person", + "value": [ + "free" + ], + "operator": "exact" + } + ], + "rollout_percentage": 47 + } + ], + "payloads": {}, + "multivariate": null + }, + "deleted": false, + "active": true, + "ensure_experience_continuity": false, + "has_encrypted_payloads": false, + "version": 1 + }, + { + "id": 132794, + "team_id": 152860, + "name": "", + "key": "gc-compaction", + "filters": { + "groups": [ + { + "variant": "enabled-stage-2", + "properties": [ + { + "key": "plan_type", + "type": "person", + "value": [ + "free" + ], + "operator": "exact" + }, + { + "key": "pageserver_remote_size", + "type": "person", + "value": "10000000", + "operator": "lt" + } + ], + "rollout_percentage": 50 + }, + { + "properties": [ + { + "key": "plan_type", + "type": "person", + "value": [ + "free" + ], + "operator": "exact" + }, + { + "key": "pageserver_remote_size", + "type": "person", + "value": "10000000", + "operator": "lt" + } + ], + "rollout_percentage": 80 + } + ], + "payloads": {}, + "multivariate": { + "variants": [ + { + "key": "disabled", + "name": "", + "rollout_percentage": 90 + }, + { + "key": "enabled-stage-1", + "name": "", + "rollout_percentage": 10 + }, + { + "key": "enabled-stage-2", + "name": "", + "rollout_percentage": 0 + }, + { + "key": "enabled-stage-3", + "name": "", + "rollout_percentage": 0 + }, + { + "key": "enabled", + "name": "", + "rollout_percentage": 0 + } + ] + } + }, + "deleted": false, + "active": true, + "ensure_experience_continuity": false, + "has_encrypted_payloads": false, + "version": 7 + } + ], + "group_type_mapping": {}, + "cohorts": {} +}"# } #[test] @@ -631,4 +796,125 @@ mod tests { Err(PostHogEvaluationError::NoConditionGroupMatched) ),); } + + #[test] + fn evaluate_boolean_1() { + // The `boolean-flag` feature flag only has one group that matches on the free user. + + let mut store = FeatureStore::new(); + let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap(); + store.set_flags(response.flags); + + // This lacks the required properties and cannot be evaluated. + let variant = store.evaluate_boolean_inner("boolean-flag", 1.00, &HashMap::new()); + assert!(matches!( + variant, + Err(PostHogEvaluationError::NotAvailable(_)) + ),); + + let properties_unmatched = HashMap::from([ + ( + "plan_type".to_string(), + PostHogFlagFilterPropertyValue::String("paid".to_string()), + ), + ( + "pageserver_remote_size".to_string(), + PostHogFlagFilterPropertyValue::Number(1000.0), + ), + ]); + + // This does not match any group so there will be an error. + let variant = store.evaluate_boolean_inner("boolean-flag", 1.00, &properties_unmatched); + assert!(matches!( + variant, + Err(PostHogEvaluationError::NoConditionGroupMatched) + ),); + + let properties = HashMap::from([ + ( + "plan_type".to_string(), + PostHogFlagFilterPropertyValue::String("free".to_string()), + ), + ( + "pageserver_remote_size".to_string(), + PostHogFlagFilterPropertyValue::Number(1000.0), + ), + ]); + + // It matches the first group as 0.10 <= 0.50 and the properties are matched. Then it gets evaluated to the variant override. + let variant = store.evaluate_boolean_inner("boolean-flag", 0.10, &properties); + assert!(variant.is_ok()); + + // It matches the group conditions but not the group rollout percentage. + let variant = store.evaluate_boolean_inner("boolean-flag", 1.00, &properties); + assert!(matches!( + variant, + Err(PostHogEvaluationError::NoConditionGroupMatched) + ),); + } + + #[test] + fn evaluate_boolean_2() { + // The `image-compaction-boundary` feature flag has one group that matches on the free user and a group that matches on all users. + + let mut store = FeatureStore::new(); + let response: LocalEvaluationResponse = serde_json::from_str(data()).unwrap(); + store.set_flags(response.flags); + + // This lacks the required properties and cannot be evaluated. + let variant = + store.evaluate_boolean_inner("image-compaction-boundary", 1.00, &HashMap::new()); + assert!(matches!( + variant, + Err(PostHogEvaluationError::NotAvailable(_)) + ),); + + let properties_unmatched = HashMap::from([ + ( + "plan_type".to_string(), + PostHogFlagFilterPropertyValue::String("paid".to_string()), + ), + ( + "pageserver_remote_size".to_string(), + PostHogFlagFilterPropertyValue::Number(1000.0), + ), + ]); + + // This does not match the filtered group but the all user group. + let variant = + store.evaluate_boolean_inner("image-compaction-boundary", 1.00, &properties_unmatched); + assert!(matches!( + variant, + Err(PostHogEvaluationError::NoConditionGroupMatched) + ),); + let variant = + store.evaluate_boolean_inner("image-compaction-boundary", 0.05, &properties_unmatched); + assert!(variant.is_ok()); + + let properties = HashMap::from([ + ( + "plan_type".to_string(), + PostHogFlagFilterPropertyValue::String("free".to_string()), + ), + ( + "pageserver_remote_size".to_string(), + PostHogFlagFilterPropertyValue::Number(1000.0), + ), + ]); + + // It matches the first group as 0.30 <= 0.40 and the properties are matched. Then it gets evaluated to the variant override. + let variant = store.evaluate_boolean_inner("image-compaction-boundary", 0.30, &properties); + assert!(variant.is_ok()); + + // It matches the group conditions but not the group rollout percentage. + let variant = store.evaluate_boolean_inner("image-compaction-boundary", 1.00, &properties); + assert!(matches!( + variant, + Err(PostHogEvaluationError::NoConditionGroupMatched) + ),); + + // It matches the second "all" group conditions. + let variant = store.evaluate_boolean_inner("image-compaction-boundary", 0.09, &properties); + assert!(variant.is_ok()); + } } diff --git a/libs/proxy/postgres-protocol2/src/message/frontend.rs b/libs/proxy/postgres-protocol2/src/message/frontend.rs index b447290ea8..9faed2c065 100644 --- a/libs/proxy/postgres-protocol2/src/message/frontend.rs +++ b/libs/proxy/postgres-protocol2/src/message/frontend.rs @@ -25,6 +25,7 @@ where Ok(()) } +#[derive(Debug)] pub enum BindError { Conversion(Box), Serialization(io::Error), @@ -288,6 +289,12 @@ pub fn sync(buf: &mut BytesMut) { write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); } +#[inline] +pub fn flush(buf: &mut BytesMut) { + buf.put_u8(b'H'); + write_body(buf, |_| Ok::<(), io::Error>(())).unwrap(); +} + #[inline] pub fn terminate(buf: &mut BytesMut) { buf.put_u8(b'X'); diff --git a/libs/proxy/postgres-types2/src/lib.rs b/libs/proxy/postgres-types2/src/lib.rs index b6bcabc922..7c9874bda3 100644 --- a/libs/proxy/postgres-types2/src/lib.rs +++ b/libs/proxy/postgres-types2/src/lib.rs @@ -9,7 +9,6 @@ use std::error::Error; use std::fmt; use std::sync::Arc; -use bytes::BytesMut; use fallible_iterator::FallibleIterator; #[doc(inline)] pub use postgres_protocol2::Oid; @@ -27,41 +26,6 @@ macro_rules! accepts { ) } -/// Generates an implementation of `ToSql::to_sql_checked`. -/// -/// All `ToSql` implementations should use this macro. -macro_rules! to_sql_checked { - () => { - fn to_sql_checked( - &self, - ty: &$crate::Type, - out: &mut $crate::private::BytesMut, - ) -> ::std::result::Result< - $crate::IsNull, - Box, - > { - $crate::__to_sql_checked(self, ty, out) - } - }; -} - -// WARNING: this function is not considered part of this crate's public API. -// It is subject to change at any time. -#[doc(hidden)] -pub fn __to_sql_checked( - v: &T, - ty: &Type, - out: &mut BytesMut, -) -> Result> -where - T: ToSql, -{ - if !T::accepts(ty) { - return Err(Box::new(WrongType::new::(ty.clone()))); - } - v.to_sql(ty, out) -} - // mod pg_lsn; #[doc(hidden)] pub mod private; @@ -142,7 +106,7 @@ pub enum Kind { /// An array type along with the type of its elements. Array(Type), /// A range type along with the type of its elements. - Range(Type), + Range(Oid), /// A multirange type along with the type of its elements. Multirange(Type), /// A domain type along with its underlying type. @@ -377,43 +341,6 @@ pub enum IsNull { No, } -/// A trait for types that can be converted into Postgres values. -pub trait ToSql: fmt::Debug { - /// Converts the value of `self` into the binary format of the specified - /// Postgres `Type`, appending it to `out`. - /// - /// The caller of this method is responsible for ensuring that this type - /// is compatible with the Postgres `Type`. - /// - /// The return value indicates if this value should be represented as - /// `NULL`. If this is the case, implementations **must not** write - /// anything to `out`. - fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result> - where - Self: Sized; - - /// Determines if a value of this type can be converted to the specified - /// Postgres `Type`. - fn accepts(ty: &Type) -> bool - where - Self: Sized; - - /// An adaptor method used internally by Rust-Postgres. - /// - /// *All* implementations of this method should be generated by the - /// `to_sql_checked!()` macro. - fn to_sql_checked( - &self, - ty: &Type, - out: &mut BytesMut, - ) -> Result>; - - /// Specify the encode format - fn encode_format(&self, _ty: &Type) -> Format { - Format::Binary - } -} - /// Supported Postgres message format types /// /// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8` @@ -424,52 +351,3 @@ pub enum Format { /// Compact, typed binary format Binary, } - -impl ToSql for &str { - fn to_sql(&self, ty: &Type, w: &mut BytesMut) -> Result> { - match *ty { - ref ty if ty.name() == "ltree" => types::ltree_to_sql(self, w), - ref ty if ty.name() == "lquery" => types::lquery_to_sql(self, w), - ref ty if ty.name() == "ltxtquery" => types::ltxtquery_to_sql(self, w), - _ => types::text_to_sql(self, w), - } - Ok(IsNull::No) - } - - fn accepts(ty: &Type) -> bool { - match *ty { - Type::VARCHAR | Type::TEXT | Type::BPCHAR | Type::NAME | Type::UNKNOWN => true, - ref ty - if (ty.name() == "citext" - || ty.name() == "ltree" - || ty.name() == "lquery" - || ty.name() == "ltxtquery") => - { - true - } - _ => false, - } - } - - to_sql_checked!(); -} - -macro_rules! simple_to { - ($t:ty, $f:ident, $($expected:ident),+) => { - impl ToSql for $t { - fn to_sql(&self, - _: &Type, - w: &mut BytesMut) - -> Result> { - types::$f(*self, w); - Ok(IsNull::No) - } - - accepts!($($expected),+); - - to_sql_checked!(); - } - } -} - -simple_to!(u32, oid_to_sql, OID); diff --git a/libs/proxy/postgres-types2/src/type_gen.rs b/libs/proxy/postgres-types2/src/type_gen.rs index a1bc3f85c0..6e6163e343 100644 --- a/libs/proxy/postgres-types2/src/type_gen.rs +++ b/libs/proxy/postgres-types2/src/type_gen.rs @@ -393,7 +393,7 @@ impl Inner { } } - pub fn oid(&self) -> Oid { + pub const fn const_oid(&self) -> Oid { match *self { Inner::Bool => 16, Inner::Bytea => 17, @@ -580,7 +580,14 @@ impl Inner { Inner::TstzmultiRangeArray => 6153, Inner::DatemultiRangeArray => 6155, Inner::Int8multiRangeArray => 6157, + Inner::Other(_) => u32::MAX, + } + } + + pub fn oid(&self) -> Oid { + match *self { Inner::Other(ref u) => u.oid, + _ => self.const_oid(), } } @@ -727,17 +734,17 @@ impl Inner { Inner::JsonbArray => &Kind::Array(Type(Inner::Jsonb)), Inner::AnyRange => &Kind::Pseudo, Inner::EventTrigger => &Kind::Pseudo, - Inner::Int4Range => &Kind::Range(Type(Inner::Int4)), + Inner::Int4Range => &const { Kind::Range(Inner::Int4.const_oid()) }, Inner::Int4RangeArray => &Kind::Array(Type(Inner::Int4Range)), - Inner::NumRange => &Kind::Range(Type(Inner::Numeric)), + Inner::NumRange => &const { Kind::Range(Inner::Numeric.const_oid()) }, Inner::NumRangeArray => &Kind::Array(Type(Inner::NumRange)), - Inner::TsRange => &Kind::Range(Type(Inner::Timestamp)), + Inner::TsRange => &const { Kind::Range(Inner::Timestamp.const_oid()) }, Inner::TsRangeArray => &Kind::Array(Type(Inner::TsRange)), - Inner::TstzRange => &Kind::Range(Type(Inner::Timestamptz)), + Inner::TstzRange => &const { Kind::Range(Inner::Timestamptz.const_oid()) }, Inner::TstzRangeArray => &Kind::Array(Type(Inner::TstzRange)), - Inner::DateRange => &Kind::Range(Type(Inner::Date)), + Inner::DateRange => &const { Kind::Range(Inner::Date.const_oid()) }, Inner::DateRangeArray => &Kind::Array(Type(Inner::DateRange)), - Inner::Int8Range => &Kind::Range(Type(Inner::Int8)), + Inner::Int8Range => &const { Kind::Range(Inner::Int8.const_oid()) }, Inner::Int8RangeArray => &Kind::Array(Type(Inner::Int8Range)), Inner::Jsonpath => &Kind::Simple, Inner::JsonpathArray => &Kind::Array(Type(Inner::Jsonpath)), diff --git a/libs/proxy/tokio-postgres2/src/client.rs b/libs/proxy/tokio-postgres2/src/client.rs index 186eb07000..a7edfc076a 100644 --- a/libs/proxy/tokio-postgres2/src/client.rs +++ b/libs/proxy/tokio-postgres2/src/client.rs @@ -1,14 +1,12 @@ use std::collections::HashMap; use std::fmt; use std::net::IpAddr; -use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; use bytes::BytesMut; use fallible_iterator::FallibleIterator; use futures_util::{TryStreamExt, future, ready}; -use parking_lot::Mutex; use postgres_protocol2::message::backend::Message; use postgres_protocol2::message::frontend; use serde::{Deserialize, Serialize}; @@ -16,29 +14,52 @@ use tokio::sync::mpsc; use crate::codec::{BackendMessages, FrontendMessage}; use crate::config::{Host, SslMode}; -use crate::connection::{Request, RequestMessages}; use crate::query::RowStream; use crate::simple_query::SimpleQueryStream; use crate::types::{Oid, Type}; use crate::{ - CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, Statement, Transaction, - TransactionBuilder, query, simple_query, + CancelToken, Error, ReadyForQueryStatus, SimpleQueryMessage, Transaction, TransactionBuilder, + query, simple_query, }; pub struct Responses { + /// new messages from conn receiver: mpsc::Receiver, + /// current batch of messages cur: BackendMessages, + /// number of total queries sent. + waiting: usize, + /// number of ReadyForQuery messages received. + received: usize, } impl Responses { pub fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll> { loop { - match self.cur.next().map_err(Error::parse)? { - Some(Message::ErrorResponse(body)) => return Poll::Ready(Err(Error::db(body))), - Some(message) => return Poll::Ready(Ok(message)), - None => {} + // get the next saved message + if let Some(message) = self.cur.next().map_err(Error::parse)? { + let received = self.received; + + // increase the query head if this is the last message. + if let Message::ReadyForQuery(_) = message { + self.received += 1; + } + + // check if the client has skipped this query. + if received + 1 < self.waiting { + // grab the next message. + continue; + } + + // convenience: turn the error messaage into a proper error. + let res = match message { + Message::ErrorResponse(body) => Err(Error::db(body)), + message => Ok(message), + }; + return Poll::Ready(res); } + // get the next batch of messages. match ready!(self.receiver.poll_recv(cx)) { Some(messages) => self.cur = messages, None => return Poll::Ready(Err(Error::closed())), @@ -55,44 +76,87 @@ impl Responses { /// (corresponding to the queries in the [crate::prepare] module). #[derive(Default)] pub(crate) struct CachedTypeInfo { - /// A statement for basic information for a type from its - /// OID. Corresponds to [TYPEINFO_QUERY](crate::prepare::TYPEINFO_QUERY) (or its - /// fallback). - pub(crate) typeinfo: Option, - /// Cache of types already looked up. pub(crate) types: HashMap, } pub struct InnerClient { - sender: mpsc::UnboundedSender, + sender: mpsc::UnboundedSender, + responses: Responses, /// A buffer to use when writing out postgres commands. - buffer: Mutex, + buffer: BytesMut, } impl InnerClient { - pub fn send(&self, messages: RequestMessages) -> Result { - let (sender, receiver) = mpsc::channel(1); - let request = Request { messages, sender }; - self.sender.send(request).map_err(|_| Error::closed())?; - - Ok(Responses { - receiver, - cur: BackendMessages::empty(), - }) + pub fn start(&mut self) -> Result { + self.responses.waiting += 1; + Ok(PartialQuery(Some(self))) } - /// Call the given function with a buffer to be used when writing out - /// postgres commands. - pub fn with_buf(&self, f: F) -> R + // pub fn send_with_sync(&mut self, f: F) -> Result<&mut Responses, Error> + // where + // F: FnOnce(&mut BytesMut) -> Result<(), Error>, + // { + // self.start()?.send_with_sync(f) + // } + + pub fn send_simple_query(&mut self, query: &str) -> Result<&mut Responses, Error> { + self.responses.waiting += 1; + + self.buffer.clear(); + // simple queries do not need sync. + frontend::query(query, &mut self.buffer).map_err(Error::encode)?; + let buf = self.buffer.split().freeze(); + self.send_message(FrontendMessage::Raw(buf)) + } + + fn send_message(&mut self, messages: FrontendMessage) -> Result<&mut Responses, Error> { + self.sender.send(messages).map_err(|_| Error::closed())?; + Ok(&mut self.responses) + } +} + +pub struct PartialQuery<'a>(Option<&'a mut InnerClient>); + +impl Drop for PartialQuery<'_> { + fn drop(&mut self) { + if let Some(client) = self.0.take() { + client.buffer.clear(); + frontend::sync(&mut client.buffer); + let buf = client.buffer.split().freeze(); + let _ = client.send_message(FrontendMessage::Raw(buf)); + } + } +} + +impl<'a> PartialQuery<'a> { + pub fn send_with_flush(&mut self, f: F) -> Result<&mut Responses, Error> where - F: FnOnce(&mut BytesMut) -> R, + F: FnOnce(&mut BytesMut) -> Result<(), Error>, { - let mut buffer = self.buffer.lock(); - let r = f(&mut buffer); - buffer.clear(); - r + let client = self.0.as_deref_mut().unwrap(); + + client.buffer.clear(); + f(&mut client.buffer)?; + frontend::flush(&mut client.buffer); + let buf = client.buffer.split().freeze(); + client.send_message(FrontendMessage::Raw(buf)) + } + + pub fn send_with_sync(mut self, f: F) -> Result<&'a mut Responses, Error> + where + F: FnOnce(&mut BytesMut) -> Result<(), Error>, + { + let client = self.0.as_deref_mut().unwrap(); + + client.buffer.clear(); + f(&mut client.buffer)?; + frontend::sync(&mut client.buffer); + let buf = client.buffer.split().freeze(); + let _ = client.send_message(FrontendMessage::Raw(buf)); + + Ok(&mut self.0.take().unwrap().responses) } } @@ -109,7 +173,7 @@ pub struct SocketConfig { /// The client is one half of what is returned when a connection is established. Users interact with the database /// through this client object. pub struct Client { - inner: Arc, + inner: InnerClient, cached_typeinfo: CachedTypeInfo, socket_config: SocketConfig, @@ -120,17 +184,24 @@ pub struct Client { impl Client { pub(crate) fn new( - sender: mpsc::UnboundedSender, + sender: mpsc::UnboundedSender, + receiver: mpsc::Receiver, socket_config: SocketConfig, ssl_mode: SslMode, process_id: i32, secret_key: i32, ) -> Client { Client { - inner: Arc::new(InnerClient { + inner: InnerClient { sender, + responses: Responses { + receiver, + cur: BackendMessages::empty(), + waiting: 0, + received: 0, + }, buffer: Default::default(), - }), + }, cached_typeinfo: Default::default(), socket_config, @@ -145,19 +216,29 @@ impl Client { self.process_id } - pub(crate) fn inner(&self) -> &Arc { - &self.inner + pub(crate) fn inner_mut(&mut self) -> &mut InnerClient { + &mut self.inner } /// Pass text directly to the Postgres backend to allow it to sort out typing itself and /// to save a roundtrip - pub async fn query_raw_txt(&self, statement: &str, params: I) -> Result + pub async fn query_raw_txt( + &mut self, + statement: &str, + params: I, + ) -> Result where S: AsRef, I: IntoIterator>, I::IntoIter: ExactSizeIterator, { - query::query_txt(&self.inner, statement, params).await + query::query_txt( + &mut self.inner, + &mut self.cached_typeinfo, + statement, + params, + ) + .await } /// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows. @@ -173,12 +254,15 @@ impl Client { /// Prepared statements should be use for any query which contains user-specified data, as they provided the /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass /// them to this method! - pub async fn simple_query(&self, query: &str) -> Result, Error> { + pub async fn simple_query(&mut self, query: &str) -> Result, Error> { self.simple_query_raw(query).await?.try_collect().await } - pub(crate) async fn simple_query_raw(&self, query: &str) -> Result { - simple_query::simple_query(self.inner(), query).await + pub(crate) async fn simple_query_raw( + &mut self, + query: &str, + ) -> Result { + simple_query::simple_query(self.inner_mut(), query).await } /// Executes a sequence of SQL statements using the simple query protocol. @@ -191,15 +275,11 @@ impl Client { /// Prepared statements should be use for any query which contains user-specified data, as they provided the /// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass /// them to this method! - pub async fn batch_execute(&self, query: &str) -> Result { - simple_query::batch_execute(self.inner(), query).await + pub async fn batch_execute(&mut self, query: &str) -> Result { + simple_query::batch_execute(self.inner_mut(), query).await } pub async fn discard_all(&mut self) -> Result { - // clear the prepared statements that are about to be nuked from the postgres session - - self.cached_typeinfo.typeinfo = None; - self.batch_execute("discard all").await } @@ -208,7 +288,7 @@ impl Client { /// The transaction will roll back by default - use the `commit` method to commit it. pub async fn transaction(&mut self) -> Result, Error> { struct RollbackIfNotDone<'me> { - client: &'me Client, + client: &'me mut Client, done: bool, } @@ -218,14 +298,7 @@ impl Client { return; } - let buf = self.client.inner().with_buf(|buf| { - frontend::query("ROLLBACK", buf).unwrap(); - buf.split().freeze() - }); - let _ = self - .client - .inner() - .send(RequestMessages::Single(FrontendMessage::Raw(buf))); + let _ = self.client.inner.send_simple_query("ROLLBACK"); } } @@ -239,7 +312,7 @@ impl Client { client: self, done: false, }; - self.batch_execute("BEGIN").await?; + cleaner.client.batch_execute("BEGIN").await?; cleaner.done = true; } @@ -265,11 +338,6 @@ impl Client { } } - /// Query for type information - pub(crate) async fn get_type_inner(&mut self, oid: Oid) -> Result { - crate::prepare::get_type(&self.inner, &mut self.cached_typeinfo, oid).await - } - /// Determines if the connection to the server has already closed. /// /// In that case, all future queries will fail. diff --git a/libs/proxy/tokio-postgres2/src/codec.rs b/libs/proxy/tokio-postgres2/src/codec.rs index f1fd9b47b3..daa5371426 100644 --- a/libs/proxy/tokio-postgres2/src/codec.rs +++ b/libs/proxy/tokio-postgres2/src/codec.rs @@ -1,21 +1,16 @@ use std::io; -use bytes::{Buf, Bytes, BytesMut}; +use bytes::{Bytes, BytesMut}; use fallible_iterator::FallibleIterator; use postgres_protocol2::message::backend; -use postgres_protocol2::message::frontend::CopyData; use tokio_util::codec::{Decoder, Encoder}; pub enum FrontendMessage { Raw(Bytes), - CopyData(CopyData>), } pub enum BackendMessage { - Normal { - messages: BackendMessages, - request_complete: bool, - }, + Normal { messages: BackendMessages }, Async(backend::Message), } @@ -44,7 +39,6 @@ impl Encoder for PostgresCodec { fn encode(&mut self, item: FrontendMessage, dst: &mut BytesMut) -> io::Result<()> { match item { FrontendMessage::Raw(buf) => dst.extend_from_slice(&buf), - FrontendMessage::CopyData(data) => data.write(dst), } Ok(()) @@ -57,7 +51,6 @@ impl Decoder for PostgresCodec { fn decode(&mut self, src: &mut BytesMut) -> Result, io::Error> { let mut idx = 0; - let mut request_complete = false; while let Some(header) = backend::Header::parse(&src[idx..])? { let len = header.len() as usize + 1; @@ -82,7 +75,6 @@ impl Decoder for PostgresCodec { idx += len; if header.tag() == backend::READY_FOR_QUERY_TAG { - request_complete = true; break; } } @@ -92,7 +84,6 @@ impl Decoder for PostgresCodec { } else { Ok(Some(BackendMessage::Normal { messages: BackendMessages(src.split_to(idx)), - request_complete, })) } } diff --git a/libs/proxy/tokio-postgres2/src/connect.rs b/libs/proxy/tokio-postgres2/src/connect.rs index 7c3a358bba..39a0a87c74 100644 --- a/libs/proxy/tokio-postgres2/src/connect.rs +++ b/libs/proxy/tokio-postgres2/src/connect.rs @@ -59,9 +59,11 @@ where connect_timeout: config.connect_timeout, }; - let (sender, receiver) = mpsc::unbounded_channel(); + let (client_tx, conn_rx) = mpsc::unbounded_channel(); + let (conn_tx, client_rx) = mpsc::channel(4); let client = Client::new( - sender, + client_tx, + client_rx, socket_config, config.ssl_mode, process_id, @@ -74,7 +76,7 @@ where .map(|m| BackendMessage::Async(Message::NoticeResponse(m))) .collect(); - let connection = Connection::new(stream, delayed, parameters, receiver); + let connection = Connection::new(stream, delayed, parameters, conn_tx, conn_rx); Ok((client, connection)) } diff --git a/libs/proxy/tokio-postgres2/src/connection.rs b/libs/proxy/tokio-postgres2/src/connection.rs index 99d6f3f8e2..fe0372b266 100644 --- a/libs/proxy/tokio-postgres2/src/connection.rs +++ b/libs/proxy/tokio-postgres2/src/connection.rs @@ -4,7 +4,6 @@ use std::pin::Pin; use std::task::{Context, Poll}; use bytes::BytesMut; -use fallible_iterator::FallibleIterator; use futures_util::{Sink, Stream, ready}; use postgres_protocol2::message::backend::Message; use postgres_protocol2::message::frontend; @@ -19,30 +18,12 @@ use crate::error::DbError; use crate::maybe_tls_stream::MaybeTlsStream; use crate::{AsyncMessage, Error, Notification}; -pub enum RequestMessages { - Single(FrontendMessage), -} - -pub struct Request { - pub messages: RequestMessages, - pub sender: mpsc::Sender, -} - -pub struct Response { - sender: PollSender, -} - #[derive(PartialEq, Debug)] enum State { Active, Closing, } -enum WriteReady { - Terminating, - WaitingOnRead, -} - /// A connection to a PostgreSQL database. /// /// This is one half of what is returned when a new connection is established. It performs the actual IO with the @@ -56,9 +37,11 @@ pub struct Connection { pub stream: Framed, PostgresCodec>, /// HACK: we need this in the Neon Proxy to forward params. pub parameters: HashMap, - receiver: mpsc::UnboundedReceiver, + + sender: PollSender, + receiver: mpsc::UnboundedReceiver, + pending_responses: VecDeque, - responses: VecDeque, state: State, } @@ -71,14 +54,15 @@ where stream: Framed, PostgresCodec>, pending_responses: VecDeque, parameters: HashMap, - receiver: mpsc::UnboundedReceiver, + sender: mpsc::Sender, + receiver: mpsc::UnboundedReceiver, ) -> Connection { Connection { stream, parameters, + sender: PollSender::new(sender), receiver, pending_responses, - responses: VecDeque::new(), state: State::Active, } } @@ -110,7 +94,7 @@ where } }; - let (mut messages, request_complete) = match message { + let messages = match message { BackendMessage::Async(Message::NoticeResponse(body)) => { let error = DbError::parse(&mut body.fields()).map_err(Error::parse)?; return Poll::Ready(Ok(AsyncMessage::Notice(error))); @@ -131,41 +115,19 @@ where continue; } BackendMessage::Async(_) => unreachable!(), - BackendMessage::Normal { - messages, - request_complete, - } => (messages, request_complete), + BackendMessage::Normal { messages } => messages, }; - let mut response = match self.responses.pop_front() { - Some(response) => response, - None => match messages.next().map_err(Error::parse)? { - Some(Message::ErrorResponse(error)) => { - return Poll::Ready(Err(Error::db(error))); - } - _ => return Poll::Ready(Err(Error::unexpected_message())), - }, - }; - - match response.sender.poll_reserve(cx) { + match self.sender.poll_reserve(cx) { Poll::Ready(Ok(())) => { - let _ = response.sender.send_item(messages); - if !request_complete { - self.responses.push_front(response); - } + let _ = self.sender.send_item(messages); } Poll::Ready(Err(_)) => { - // we need to keep paging through the rest of the messages even if the receiver's hung up - if !request_complete { - self.responses.push_front(response); - } + return Poll::Ready(Err(Error::closed())); } Poll::Pending => { - self.responses.push_front(response); - self.pending_responses.push_back(BackendMessage::Normal { - messages, - request_complete, - }); + self.pending_responses + .push_back(BackendMessage::Normal { messages }); trace!("poll_read: waiting on sender"); return Poll::Pending; } @@ -174,7 +136,7 @@ where } /// Fetch the next client request and enqueue the response sender. - fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_request(&mut self, cx: &mut Context<'_>) -> Poll> { if self.receiver.is_closed() { return Poll::Ready(None); } @@ -182,10 +144,7 @@ where match self.receiver.poll_recv(cx) { Poll::Ready(Some(request)) => { trace!("polled new request"); - self.responses.push_back(Response { - sender: PollSender::new(request.sender), - }); - Poll::Ready(Some(request.messages)) + Poll::Ready(Some(request)) } Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, @@ -194,7 +153,7 @@ where /// Process client requests and write them to the postgres connection, flushing if necessary. /// client -> postgres - fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll> { + fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll> { loop { if Pin::new(&mut self.stream) .poll_ready(cx) @@ -209,14 +168,14 @@ where match self.poll_request(cx) { // send the message to postgres - Poll::Ready(Some(RequestMessages::Single(request))) => { + Poll::Ready(Some(request)) => { Pin::new(&mut self.stream) .start_send(request) .map_err(Error::io)?; } // No more messages from the client, and no more responses to wait for. // Send a terminate message to postgres - Poll::Ready(None) if self.responses.is_empty() => { + Poll::Ready(None) => { trace!("poll_write: at eof, terminating"); let mut request = BytesMut::new(); frontend::terminate(&mut request); @@ -228,16 +187,7 @@ where trace!("poll_write: sent eof, closing"); trace!("poll_write: done"); - return Poll::Ready(Ok(WriteReady::Terminating)); - } - // No more messages from the client, but there are still some responses to wait for. - Poll::Ready(None) => { - trace!( - "poll_write: at eof, pending responses {}", - self.responses.len() - ); - ready!(self.poll_flush(cx))?; - return Poll::Ready(Ok(WriteReady::WaitingOnRead)); + return Poll::Ready(Ok(())); } // Still waiting for a message from the client. Poll::Pending => { @@ -298,7 +248,7 @@ where // if the state is still active, try read from and write to postgres. let message = self.poll_read(cx)?; let closing = self.poll_write(cx)?; - if let Poll::Ready(WriteReady::Terminating) = closing { + if let Poll::Ready(()) = closing { self.state = State::Closing; } diff --git a/libs/proxy/tokio-postgres2/src/generic_client.rs b/libs/proxy/tokio-postgres2/src/generic_client.rs index 8e28843347..eeefb45d26 100644 --- a/libs/proxy/tokio-postgres2/src/generic_client.rs +++ b/libs/proxy/tokio-postgres2/src/generic_client.rs @@ -1,9 +1,6 @@ #![allow(async_fn_in_trait)] -use postgres_protocol2::Oid; - use crate::query::RowStream; -use crate::types::Type; use crate::{Client, Error, Transaction}; mod private { @@ -15,20 +12,17 @@ mod private { /// This trait is "sealed", and cannot be implemented outside of this crate. pub trait GenericClient: private::Sealed { /// Like `Client::query_raw_txt`. - async fn query_raw_txt(&self, statement: &str, params: I) -> Result + async fn query_raw_txt(&mut self, statement: &str, params: I) -> Result where S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, I::IntoIter: ExactSizeIterator + Sync + Send; - - /// Query for type information - async fn get_type(&mut self, oid: Oid) -> Result; } impl private::Sealed for Client {} impl GenericClient for Client { - async fn query_raw_txt(&self, statement: &str, params: I) -> Result + async fn query_raw_txt(&mut self, statement: &str, params: I) -> Result where S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, @@ -36,17 +30,12 @@ impl GenericClient for Client { { self.query_raw_txt(statement, params).await } - - /// Query for type information - async fn get_type(&mut self, oid: Oid) -> Result { - self.get_type_inner(oid).await - } } impl private::Sealed for Transaction<'_> {} impl GenericClient for Transaction<'_> { - async fn query_raw_txt(&self, statement: &str, params: I) -> Result + async fn query_raw_txt(&mut self, statement: &str, params: I) -> Result where S: AsRef + Sync + Send, I: IntoIterator> + Sync + Send, @@ -54,9 +43,4 @@ impl GenericClient for Transaction<'_> { { self.query_raw_txt(statement, params).await } - - /// Query for type information - async fn get_type(&mut self, oid: Oid) -> Result { - self.client_mut().get_type(oid).await - } } diff --git a/libs/proxy/tokio-postgres2/src/lib.rs b/libs/proxy/tokio-postgres2/src/lib.rs index c8ebba5487..9556070ed5 100644 --- a/libs/proxy/tokio-postgres2/src/lib.rs +++ b/libs/proxy/tokio-postgres2/src/lib.rs @@ -18,7 +18,6 @@ pub use crate::statement::{Column, Statement}; pub use crate::tls::NoTls; pub use crate::transaction::Transaction; pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder}; -use crate::types::ToSql; /// After executing a query, the connection will be in one of these states #[derive(Clone, Copy, Debug, PartialEq)] @@ -120,9 +119,3 @@ pub enum SimpleQueryMessage { /// The number of rows modified or selected is returned. CommandComplete(u64), } - -fn slice_iter<'a>( - s: &'a [&'a (dyn ToSql + Sync)], -) -> impl ExactSizeIterator + 'a { - s.iter().map(|s| *s as _) -} diff --git a/libs/proxy/tokio-postgres2/src/prepare.rs b/libs/proxy/tokio-postgres2/src/prepare.rs index b27eabcb0e..16b9cf66f4 100644 --- a/libs/proxy/tokio-postgres2/src/prepare.rs +++ b/libs/proxy/tokio-postgres2/src/prepare.rs @@ -1,19 +1,14 @@ -use std::future::Future; -use std::pin::Pin; -use std::sync::Arc; - -use bytes::Bytes; +use bytes::BytesMut; use fallible_iterator::FallibleIterator; -use futures_util::{TryStreamExt, pin_mut}; -use postgres_protocol2::message::backend::Message; +use postgres_protocol2::IsNull; +use postgres_protocol2::message::backend::{Message, RowDescriptionBody}; use postgres_protocol2::message::frontend; -use tracing::debug; +use postgres_protocol2::types::oid_to_sql; +use postgres_types2::Format; -use crate::client::{CachedTypeInfo, InnerClient}; -use crate::codec::FrontendMessage; -use crate::connection::RequestMessages; +use crate::client::{CachedTypeInfo, PartialQuery, Responses}; use crate::types::{Kind, Oid, Type}; -use crate::{Column, Error, Statement, query, slice_iter}; +use crate::{Column, Error, Row, Statement}; pub(crate) const TYPEINFO_QUERY: &str = "\ SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid @@ -23,22 +18,51 @@ INNER JOIN pg_catalog.pg_namespace n ON t.typnamespace = n.oid WHERE t.oid = $1 "; +/// we need to make sure we close this prepared statement. +struct CloseStmt<'a, 'b> { + client: Option<&'a mut PartialQuery<'b>>, + name: &'static str, +} + +impl<'a> CloseStmt<'a, '_> { + fn close(mut self) -> Result<&'a mut Responses, Error> { + let client = self.client.take().unwrap(); + client.send_with_flush(|buf| { + frontend::close(b'S', self.name, buf).map_err(Error::encode)?; + Ok(()) + }) + } +} + +impl Drop for CloseStmt<'_, '_> { + fn drop(&mut self) { + if let Some(client) = self.client.take() { + let _ = client.send_with_flush(|buf| { + frontend::close(b'S', self.name, buf).map_err(Error::encode)?; + Ok(()) + }); + } + } +} + async fn prepare_typecheck( - client: &Arc, + client: &mut PartialQuery<'_>, name: &'static str, query: &str, - types: &[Type], ) -> Result { - let buf = encode(client, name, query, types)?; - let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + let responses = client.send_with_flush(|buf| { + frontend::parse(name, query, [], buf).map_err(Error::encode)?; + frontend::describe(b'S', name, buf).map_err(Error::encode)?; + Ok(()) + })?; match responses.next().await? { Message::ParseComplete => {} _ => return Err(Error::unexpected_message()), } - let parameter_description = match responses.next().await? { - Message::ParameterDescription(body) => body, + match responses.next().await? { + Message::ParameterDescription(_) => {} _ => return Err(Error::unexpected_message()), }; @@ -48,13 +72,6 @@ async fn prepare_typecheck( _ => return Err(Error::unexpected_message()), }; - let mut parameters = vec![]; - let mut it = parameter_description.parameters(); - while let Some(oid) = it.next().map_err(Error::parse)? { - let type_ = Type::from_oid(oid).ok_or_else(Error::unexpected_message)?; - parameters.push(type_); - } - let mut columns = vec![]; if let Some(row_description) = row_description { let mut it = row_description.fields(); @@ -65,98 +82,168 @@ async fn prepare_typecheck( } } - Ok(Statement::new(client, name, parameters, columns)) + Ok(Statement::new(name, columns)) } -fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result { - if types.is_empty() { - debug!("preparing query {}: {}", name, query); - } else { - debug!("preparing query {} with types {:?}: {}", name, types, query); - } - - client.with_buf(|buf| { - frontend::parse(name, query, types.iter().map(Type::oid), buf).map_err(Error::encode)?; - frontend::describe(b'S', name, buf).map_err(Error::encode)?; - frontend::sync(buf); - Ok(buf.split().freeze()) - }) -} - -pub async fn get_type( - client: &Arc, - typecache: &mut CachedTypeInfo, - oid: Oid, -) -> Result { +fn try_from_cache(typecache: &CachedTypeInfo, oid: Oid) -> Option { if let Some(type_) = Type::from_oid(oid) { - return Ok(type_); + return Some(type_); } if let Some(type_) = typecache.types.get(&oid) { - return Ok(type_.clone()); + return Some(type_.clone()); }; - let stmt = typeinfo_statement(client, typecache).await?; + None +} - let rows = query::query(client, stmt, slice_iter(&[&oid])).await?; - pin_mut!(rows); +pub async fn parse_row_description( + client: &mut PartialQuery<'_>, + typecache: &mut CachedTypeInfo, + row_description: Option, +) -> Result, Error> { + let mut columns = vec![]; - let row = match rows.try_next().await? { - Some(row) => row, - None => return Err(Error::unexpected_message()), + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? { + let type_ = try_from_cache(typecache, field.type_oid()).unwrap_or(Type::UNKNOWN); + let column = Column::new(field.name().to_string(), type_, field); + columns.push(column); + } + } + + let all_known = columns.iter().all(|c| c.type_ != Type::UNKNOWN); + if all_known { + // all known, return early. + return Ok(columns); + } + + let typeinfo = "neon_proxy_typeinfo"; + + // make sure to close the typeinfo statement before exiting. + let mut guard = CloseStmt { + name: typeinfo, + client: None, + }; + let client = guard.client.insert(client); + + // get the typeinfo statement. + let stmt = prepare_typecheck(client, typeinfo, TYPEINFO_QUERY).await?; + + for column in &mut columns { + column.type_ = get_type(client, typecache, &stmt, column.type_oid()).await?; + } + + // cancel the close guard. + let responses = guard.close()?; + + match responses.next().await? { + Message::CloseComplete => {} + _ => return Err(Error::unexpected_message()), + } + + Ok(columns) +} + +async fn get_type( + client: &mut PartialQuery<'_>, + typecache: &mut CachedTypeInfo, + stmt: &Statement, + mut oid: Oid, +) -> Result { + let mut stack = vec![]; + let mut type_ = loop { + if let Some(type_) = try_from_cache(typecache, oid) { + break type_; + } + + let row = exec(client, stmt, oid).await?; + if stack.len() > 8 { + return Err(Error::unexpected_message()); + } + + let name: String = row.try_get(0)?; + let type_: i8 = row.try_get(1)?; + let elem_oid: Oid = row.try_get(2)?; + let rngsubtype: Option = row.try_get(3)?; + let basetype: Oid = row.try_get(4)?; + let schema: String = row.try_get(5)?; + let relid: Oid = row.try_get(6)?; + + let kind = if type_ == b'e' as i8 { + Kind::Enum + } else if type_ == b'p' as i8 { + Kind::Pseudo + } else if basetype != 0 { + Kind::Domain(basetype) + } else if elem_oid != 0 { + stack.push((name, oid, schema)); + oid = elem_oid; + continue; + } else if relid != 0 { + Kind::Composite(relid) + } else if let Some(rngsubtype) = rngsubtype { + Kind::Range(rngsubtype) + } else { + Kind::Simple + }; + + let type_ = Type::new(name, oid, kind, schema); + typecache.types.insert(oid, type_.clone()); + break type_; }; - let name: String = row.try_get(0)?; - let type_: i8 = row.try_get(1)?; - let elem_oid: Oid = row.try_get(2)?; - let rngsubtype: Option = row.try_get(3)?; - let basetype: Oid = row.try_get(4)?; - let schema: String = row.try_get(5)?; - let relid: Oid = row.try_get(6)?; - - let kind = if type_ == b'e' as i8 { - Kind::Enum - } else if type_ == b'p' as i8 { - Kind::Pseudo - } else if basetype != 0 { - Kind::Domain(basetype) - } else if elem_oid != 0 { - let type_ = get_type_rec(client, typecache, elem_oid).await?; - Kind::Array(type_) - } else if relid != 0 { - Kind::Composite(relid) - } else if let Some(rngsubtype) = rngsubtype { - let type_ = get_type_rec(client, typecache, rngsubtype).await?; - Kind::Range(type_) - } else { - Kind::Simple - }; - - let type_ = Type::new(name, oid, kind, schema); - typecache.types.insert(oid, type_.clone()); + while let Some((name, oid, schema)) = stack.pop() { + type_ = Type::new(name, oid, Kind::Array(type_), schema); + typecache.types.insert(oid, type_.clone()); + } Ok(type_) } -fn get_type_rec<'a>( - client: &'a Arc, - typecache: &'a mut CachedTypeInfo, - oid: Oid, -) -> Pin> + Send + 'a>> { - Box::pin(get_type(client, typecache, oid)) -} +/// exec the typeinfo statement returning one row. +async fn exec( + client: &mut PartialQuery<'_>, + statement: &Statement, + param: Oid, +) -> Result { + let responses = client.send_with_flush(|buf| { + encode_bind(statement, param, "", buf); + frontend::execute("", 0, buf).map_err(Error::encode)?; + Ok(()) + })?; -async fn typeinfo_statement( - client: &Arc, - typecache: &mut CachedTypeInfo, -) -> Result { - if let Some(stmt) = &typecache.typeinfo { - return Ok(stmt.clone()); + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), } - let typeinfo = "neon_proxy_typeinfo"; - let stmt = prepare_typecheck(client, typeinfo, TYPEINFO_QUERY, &[]).await?; + let row = match responses.next().await? { + Message::DataRow(body) => Row::new(statement.clone(), body, Format::Binary)?, + _ => return Err(Error::unexpected_message()), + }; - typecache.typeinfo = Some(stmt.clone()); - Ok(stmt) + match responses.next().await? { + Message::CommandComplete(_) => {} + _ => return Err(Error::unexpected_message()), + }; + + Ok(row) +} + +fn encode_bind(statement: &Statement, param: Oid, portal: &str, buf: &mut BytesMut) { + frontend::bind( + portal, + statement.name(), + [Format::Binary as i16], + [param], + |param, buf| { + oid_to_sql(param, buf); + Ok(IsNull::No) + }, + [Format::Binary as i16], + buf, + ) + .unwrap(); } diff --git a/libs/proxy/tokio-postgres2/src/query.rs b/libs/proxy/tokio-postgres2/src/query.rs index 106bc69d49..5f3ed8ef5a 100644 --- a/libs/proxy/tokio-postgres2/src/query.rs +++ b/libs/proxy/tokio-postgres2/src/query.rs @@ -1,76 +1,43 @@ -use std::fmt; -use std::marker::PhantomPinned; use std::pin::Pin; -use std::sync::Arc; use std::task::{Context, Poll}; -use bytes::{BufMut, Bytes, BytesMut}; -use fallible_iterator::FallibleIterator; +use bytes::BufMut; use futures_util::{Stream, ready}; -use pin_project_lite::pin_project; use postgres_protocol2::message::backend::Message; use postgres_protocol2::message::frontend; -use postgres_types2::{Format, ToSql, Type}; -use tracing::debug; +use postgres_types2::Format; -use crate::client::{InnerClient, Responses}; -use crate::codec::FrontendMessage; -use crate::connection::RequestMessages; -use crate::types::IsNull; -use crate::{Column, Error, ReadyForQueryStatus, Row, Statement}; +use crate::client::{CachedTypeInfo, InnerClient, Responses}; +use crate::{Error, ReadyForQueryStatus, Row, Statement}; -struct BorrowToSqlParamsDebug<'a>(&'a [&'a (dyn ToSql + Sync)]); - -impl fmt::Debug for BorrowToSqlParamsDebug<'_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_list().entries(self.0.iter()).finish() - } -} - -pub async fn query<'a, I>( - client: &InnerClient, - statement: Statement, - params: I, -) -> Result -where - I: IntoIterator, - I::IntoIter: ExactSizeIterator, -{ - let buf = if tracing::enabled!(tracing::Level::DEBUG) { - let params = params.into_iter().collect::>(); - debug!( - "executing statement {} with parameters: {:?}", - statement.name(), - BorrowToSqlParamsDebug(params.as_slice()), - ); - encode(client, &statement, params)? - } else { - encode(client, &statement, params)? - }; - let responses = start(client, buf).await?; - Ok(RowStream { - statement, - responses, - command_tag: None, - status: ReadyForQueryStatus::Unknown, - output_format: Format::Binary, - _p: PhantomPinned, - }) -} - -pub async fn query_txt( - client: &Arc, +pub async fn query_txt<'a, S, I>( + client: &'a mut InnerClient, + typecache: &mut CachedTypeInfo, query: &str, params: I, -) -> Result +) -> Result, Error> where S: AsRef, I: IntoIterator>, I::IntoIter: ExactSizeIterator, { let params = params.into_iter(); + let mut client = client.start()?; - let buf = client.with_buf(|buf| { + // Flow: + // 1. Parse the query + // 2. Inspect the row description for OIDs + // 3. If there's any OIDs we don't already know about, perform the typeinfo routine + // 4. Execute the query + // 5. Sync. + // + // The typeinfo routine: + // 1. Parse the typeinfo query + // 2. Execute the query on each OID + // 3. If the result does not match an OID we know, repeat 2. + + // parse the query and get type info + let responses = client.send_with_flush(|buf| { frontend::parse( "", // unnamed prepared statement query, // query to parse @@ -79,7 +46,30 @@ where ) .map_err(Error::encode)?; frontend::describe(b'S', "", buf).map_err(Error::encode)?; - // Bind, pass params as text, retrieve as binary + Ok(()) + })?; + + match responses.next().await? { + Message::ParseComplete => {} + _ => return Err(Error::unexpected_message()), + } + + match responses.next().await? { + Message::ParameterDescription(_) => {} + _ => return Err(Error::unexpected_message()), + }; + + let row_description = match responses.next().await? { + Message::RowDescription(body) => Some(body), + Message::NoData => None, + _ => return Err(Error::unexpected_message()), + }; + + let columns = + crate::prepare::parse_row_description(&mut client, typecache, row_description).await?; + + let responses = client.send_with_sync(|buf| { + // Bind, pass params as text, retrieve as text match frontend::bind( "", // empty string selects the unnamed portal "", // unnamed prepared statement @@ -102,173 +92,55 @@ where // Execute frontend::execute("", 0, buf).map_err(Error::encode)?; - // Sync - frontend::sync(buf); - Ok(buf.split().freeze()) + Ok(()) })?; - // now read the responses - let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; - - match responses.next().await? { - Message::ParseComplete => {} - _ => return Err(Error::unexpected_message()), - } - - let parameter_description = match responses.next().await? { - Message::ParameterDescription(body) => body, - _ => return Err(Error::unexpected_message()), - }; - - let row_description = match responses.next().await? { - Message::RowDescription(body) => Some(body), - Message::NoData => None, - _ => return Err(Error::unexpected_message()), - }; - match responses.next().await? { Message::BindComplete => {} _ => return Err(Error::unexpected_message()), } - let mut parameters = vec![]; - let mut it = parameter_description.parameters(); - while let Some(oid) = it.next().map_err(Error::parse)? { - let type_ = Type::from_oid(oid).unwrap_or(Type::UNKNOWN); - parameters.push(type_); - } - - let mut columns = vec![]; - if let Some(row_description) = row_description { - let mut it = row_description.fields(); - while let Some(field) = it.next().map_err(Error::parse)? { - let type_ = Type::from_oid(field.type_oid()).unwrap_or(Type::UNKNOWN); - let column = Column::new(field.name().to_string(), type_, field); - columns.push(column); - } - } - Ok(RowStream { - statement: Statement::new_anonymous(parameters, columns), responses, + statement: Statement::new("", columns), command_tag: None, status: ReadyForQueryStatus::Unknown, output_format: Format::Text, - _p: PhantomPinned, }) } -async fn start(client: &InnerClient, buf: Bytes) -> Result { - let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; - - match responses.next().await? { - Message::BindComplete => {} - _ => return Err(Error::unexpected_message()), - } - - Ok(responses) +/// A stream of table rows. +pub struct RowStream<'a> { + responses: &'a mut Responses, + output_format: Format, + pub statement: Statement, + pub command_tag: Option, + pub status: ReadyForQueryStatus, } -pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result -where - I: IntoIterator, - I::IntoIter: ExactSizeIterator, -{ - client.with_buf(|buf| { - encode_bind(statement, params, "", buf)?; - frontend::execute("", 0, buf).map_err(Error::encode)?; - frontend::sync(buf); - Ok(buf.split().freeze()) - }) -} - -pub fn encode_bind<'a, I>( - statement: &Statement, - params: I, - portal: &str, - buf: &mut BytesMut, -) -> Result<(), Error> -where - I: IntoIterator, - I::IntoIter: ExactSizeIterator, -{ - let param_types = statement.params(); - let params = params.into_iter(); - - assert!( - param_types.len() == params.len(), - "expected {} parameters but got {}", - param_types.len(), - params.len() - ); - - let (param_formats, params): (Vec<_>, Vec<_>) = params - .zip(param_types.iter()) - .map(|(p, ty)| (p.encode_format(ty) as i16, p)) - .unzip(); - - let params = params.into_iter(); - - let mut error_idx = 0; - let r = frontend::bind( - portal, - statement.name(), - param_formats, - params.zip(param_types).enumerate(), - |(idx, (param, ty)), buf| match param.to_sql_checked(ty, buf) { - Ok(IsNull::No) => Ok(postgres_protocol2::IsNull::No), - Ok(IsNull::Yes) => Ok(postgres_protocol2::IsNull::Yes), - Err(e) => { - error_idx = idx; - Err(e) - } - }, - Some(1), - buf, - ); - match r { - Ok(()) => Ok(()), - Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)), - Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), - } -} - -pin_project! { - /// A stream of table rows. - pub struct RowStream { - statement: Statement, - responses: Responses, - command_tag: Option, - output_format: Format, - status: ReadyForQueryStatus, - #[pin] - _p: PhantomPinned, - } -} - -impl Stream for RowStream { +impl Stream for RowStream<'_> { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = self.project(); + let this = self.get_mut(); loop { match ready!(this.responses.poll_next(cx)?) { Message::DataRow(body) => { return Poll::Ready(Some(Ok(Row::new( this.statement.clone(), body, - *this.output_format, + this.output_format, )?))); } Message::EmptyQueryResponse | Message::PortalSuspended => {} Message::CommandComplete(body) => { if let Ok(tag) = body.tag() { - *this.command_tag = Some(tag.to_string()); + this.command_tag = Some(tag.to_string()); } } Message::ReadyForQuery(status) => { - *this.status = status.into(); + this.status = status.into(); return Poll::Ready(None); } _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), @@ -276,24 +148,3 @@ impl Stream for RowStream { } } } - -impl RowStream { - /// Returns information about the columns of data in the row. - pub fn columns(&self) -> &[Column] { - self.statement.columns() - } - - /// Returns the command tag of this query. - /// - /// This is only available after the stream has been exhausted. - pub fn command_tag(&self) -> Option { - self.command_tag.clone() - } - - /// Returns if the connection is ready for querying, with the status of the connection. - /// - /// This might be available only after the stream has been exhausted. - pub fn ready_status(&self) -> ReadyForQueryStatus { - self.status - } -} diff --git a/libs/proxy/tokio-postgres2/src/simple_query.rs b/libs/proxy/tokio-postgres2/src/simple_query.rs index 2cf17188cf..e1ed48cdaf 100644 --- a/libs/proxy/tokio-postgres2/src/simple_query.rs +++ b/libs/proxy/tokio-postgres2/src/simple_query.rs @@ -1,19 +1,14 @@ -use std::marker::PhantomPinned; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use bytes::Bytes; use fallible_iterator::FallibleIterator; use futures_util::{Stream, ready}; use pin_project_lite::pin_project; use postgres_protocol2::message::backend::Message; -use postgres_protocol2::message::frontend; use tracing::debug; use crate::client::{InnerClient, Responses}; -use crate::codec::FrontendMessage; -use crate::connection::RequestMessages; use crate::{Error, ReadyForQueryStatus, SimpleQueryMessage, SimpleQueryRow}; /// Information about a column of a single query row. @@ -33,28 +28,28 @@ impl SimpleColumn { } } -pub async fn simple_query(client: &InnerClient, query: &str) -> Result { +pub async fn simple_query<'a>( + client: &'a mut InnerClient, + query: &str, +) -> Result, Error> { debug!("executing simple query: {}", query); - let buf = encode(client, query)?; - let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + let responses = client.send_simple_query(query)?; Ok(SimpleQueryStream { responses, columns: None, status: ReadyForQueryStatus::Unknown, - _p: PhantomPinned, }) } pub async fn batch_execute( - client: &InnerClient, + client: &mut InnerClient, query: &str, ) -> Result { debug!("executing statement batch: {}", query); - let buf = encode(client, query)?; - let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + let responses = client.send_simple_query(query)?; loop { match responses.next().await? { @@ -68,25 +63,16 @@ pub async fn batch_execute( } } -pub(crate) fn encode(client: &InnerClient, query: &str) -> Result { - client.with_buf(|buf| { - frontend::query(query, buf).map_err(Error::encode)?; - Ok(buf.split().freeze()) - }) -} - pin_project! { /// A stream of simple query results. - pub struct SimpleQueryStream { - responses: Responses, + pub struct SimpleQueryStream<'a> { + responses: &'a mut Responses, columns: Option>, status: ReadyForQueryStatus, - #[pin] - _p: PhantomPinned, } } -impl SimpleQueryStream { +impl SimpleQueryStream<'_> { /// Returns if the connection is ready for querying, with the status of the connection. /// /// This might be available only after the stream has been exhausted. @@ -95,7 +81,7 @@ impl SimpleQueryStream { } } -impl Stream for SimpleQueryStream { +impl Stream for SimpleQueryStream<'_> { type Item = Result; fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { diff --git a/libs/proxy/tokio-postgres2/src/statement.rs b/libs/proxy/tokio-postgres2/src/statement.rs index e4828db712..1f22d87fd7 100644 --- a/libs/proxy/tokio-postgres2/src/statement.rs +++ b/libs/proxy/tokio-postgres2/src/statement.rs @@ -1,35 +1,15 @@ use std::fmt; -use std::sync::{Arc, Weak}; +use std::sync::Arc; +use crate::types::Type; use postgres_protocol2::Oid; use postgres_protocol2::message::backend::Field; -use postgres_protocol2::message::frontend; - -use crate::client::InnerClient; -use crate::codec::FrontendMessage; -use crate::connection::RequestMessages; -use crate::types::Type; struct StatementInner { - client: Weak, name: &'static str, - params: Vec, columns: Vec, } -impl Drop for StatementInner { - fn drop(&mut self) { - if let Some(client) = self.client.upgrade() { - let buf = client.with_buf(|buf| { - frontend::close(b'S', self.name, buf).unwrap(); - frontend::sync(buf); - buf.split().freeze() - }); - let _ = client.send(RequestMessages::Single(FrontendMessage::Raw(buf))); - } - } -} - /// A prepared statement. /// /// Prepared statements can only be used with the connection that created them. @@ -37,38 +17,14 @@ impl Drop for StatementInner { pub struct Statement(Arc); impl Statement { - pub(crate) fn new( - inner: &Arc, - name: &'static str, - params: Vec, - columns: Vec, - ) -> Statement { - Statement(Arc::new(StatementInner { - client: Arc::downgrade(inner), - name, - params, - columns, - })) - } - - pub(crate) fn new_anonymous(params: Vec, columns: Vec) -> Statement { - Statement(Arc::new(StatementInner { - client: Weak::new(), - name: "", - params, - columns, - })) + pub(crate) fn new(name: &'static str, columns: Vec) -> Statement { + Statement(Arc::new(StatementInner { name, columns })) } pub(crate) fn name(&self) -> &str { self.0.name } - /// Returns the expected types of the statement's parameters. - pub fn params(&self) -> &[Type] { - &self.0.params - } - /// Returns information about the columns returned when the statement is queried. pub fn columns(&self) -> &[Column] { &self.0.columns @@ -78,7 +34,7 @@ impl Statement { /// Information about a column of a query. pub struct Column { name: String, - type_: Type, + pub(crate) type_: Type, // raw fields from RowDescription table_oid: Oid, diff --git a/libs/proxy/tokio-postgres2/src/transaction.rs b/libs/proxy/tokio-postgres2/src/transaction.rs index f32603470f..12fe0737d4 100644 --- a/libs/proxy/tokio-postgres2/src/transaction.rs +++ b/libs/proxy/tokio-postgres2/src/transaction.rs @@ -1,7 +1,3 @@ -use postgres_protocol2::message::frontend; - -use crate::codec::FrontendMessage; -use crate::connection::RequestMessages; use crate::query::RowStream; use crate::{CancelToken, Client, Error, ReadyForQueryStatus}; @@ -20,14 +16,7 @@ impl Drop for Transaction<'_> { return; } - let buf = self.client.inner().with_buf(|buf| { - frontend::query("ROLLBACK", buf).unwrap(); - buf.split().freeze() - }); - let _ = self - .client - .inner() - .send(RequestMessages::Single(FrontendMessage::Raw(buf))); + let _ = self.client.inner_mut().send_simple_query("ROLLBACK"); } } @@ -54,7 +43,11 @@ impl<'a> Transaction<'a> { } /// Like `Client::query_raw_txt`. - pub async fn query_raw_txt(&self, statement: &str, params: I) -> Result + pub async fn query_raw_txt( + &mut self, + statement: &str, + params: I, + ) -> Result where S: AsRef, I: IntoIterator>, diff --git a/libs/walproposer/src/walproposer.rs b/libs/walproposer/src/walproposer.rs index 4e50c21fca..e95494297c 100644 --- a/libs/walproposer/src/walproposer.rs +++ b/libs/walproposer/src/walproposer.rs @@ -1,6 +1,7 @@ #![allow(clippy::todo)] use std::ffi::CString; +use std::str::FromStr; use postgres_ffi::WAL_SEGMENT_SIZE; use utils::id::TenantTimelineId; @@ -173,6 +174,8 @@ pub struct Config { pub ttid: TenantTimelineId, /// List of safekeepers in format `host:port` pub safekeepers_list: Vec, + /// libpq connection info options + pub safekeeper_conninfo_options: String, /// Safekeeper reconnect timeout in milliseconds pub safekeeper_reconnect_timeout: i32, /// Safekeeper connection timeout in milliseconds @@ -202,6 +205,9 @@ impl Wrapper { .into_bytes_with_nul(); assert!(safekeepers_list_vec.len() == safekeepers_list_vec.capacity()); let safekeepers_list = safekeepers_list_vec.as_mut_ptr() as *mut std::ffi::c_char; + let safekeeper_conninfo_options = CString::from_str(&config.safekeeper_conninfo_options) + .unwrap() + .into_raw(); let callback_data = Box::into_raw(Box::new(api)) as *mut ::std::os::raw::c_void; @@ -209,6 +215,7 @@ impl Wrapper { neon_tenant, neon_timeline, safekeepers_list, + safekeeper_conninfo_options, safekeeper_reconnect_timeout: config.safekeeper_reconnect_timeout, safekeeper_connection_timeout: config.safekeeper_connection_timeout, wal_segment_size: WAL_SEGMENT_SIZE as i32, // default 16MB @@ -576,6 +583,7 @@ mod tests { let config = crate::walproposer::Config { ttid, safekeepers_list: vec!["localhost:5000".to_string()], + safekeeper_conninfo_options: String::new(), safekeeper_reconnect_timeout: 1000, safekeeper_connection_timeout: 10000, sync_safekeepers: true, diff --git a/pageserver/Cargo.toml b/pageserver/Cargo.toml index 6a9a5a292a..c4d6d58945 100644 --- a/pageserver/Cargo.toml +++ b/pageserver/Cargo.toml @@ -17,50 +17,69 @@ anyhow.workspace = true arc-swap.workspace = true async-compression.workspace = true async-stream.workspace = true -bit_field.workspace = true bincode.workspace = true +bit_field.workspace = true byteorder.workspace = true bytes.workspace = true -camino.workspace = true camino-tempfile.workspace = true +camino.workspace = true chrono = { workspace = true, features = ["serde"] } clap = { workspace = true, features = ["string"] } consumption_metrics.workspace = true crc32c.workspace = true either.workspace = true +enum-map.workspace = true +enumset = { workspace = true, features = ["serde"]} fail.workspace = true futures.workspace = true hashlink.workspace = true hex.workspace = true -humantime.workspace = true +http-utils.workspace = true humantime-serde.workspace = true +humantime.workspace = true hyper0.workspace = true itertools.workspace = true jsonwebtoken.workspace = true md5.workspace = true +metrics.workspace = true nix.workspace = true -# hack to get the number of worker threads tokio uses -num_cpus.workspace = true +num_cpus.workspace = true # hack to get the number of worker threads tokio uses num-traits.workspace = true once_cell.workspace = true +pageserver_api.workspace = true +pageserver_client.workspace = true # for ResponseErrorMessageExt TOOD refactor that +pageserver_compaction.workspace = true +pageserver_page_api.workspace = true +pem.workspace = true pin-project-lite.workspace = true postgres_backend.workspace = true +postgres_connection.workspace = true +postgres_ffi.workspace = true +postgres_initdb.workspace = true postgres-protocol.workspace = true postgres-types.workspace = true -postgres_initdb.workspace = true +posthog_client_lite.workspace = true pprof.workspace = true +pq_proto.workspace = true rand.workspace = true range-set-blaze = { version = "0.1.16", features = ["alloc"] } regex.workspace = true +remote_storage.workspace = true +reqwest.workspace = true +rpds.workspace = true rustls.workspace = true scopeguard.workspace = true send-future.workspace = true -serde.workspace = true serde_json = { workspace = true, features = ["raw_value"] } serde_path_to_error.workspace = true serde_with.workspace = true +serde.workspace = true +smallvec.workspace = true +storage_broker.workspace = true +strum_macros.workspace = true +strum.workspace = true sysinfo.workspace = true -tokio-tar.workspace = true +tenant_size_model.workspace = true thiserror.workspace = true tikv-jemallocator.workspace = true tokio = { workspace = true, features = ["process", "sync", "fs", "rt", "io-util", "time"] } @@ -69,34 +88,18 @@ tokio-io-timeout.workspace = true tokio-postgres.workspace = true tokio-rustls.workspace = true tokio-stream.workspace = true +tokio-tar.workspace = true tokio-util.workspace = true toml_edit = { workspace = true, features = [ "serde" ] } +tonic.workspace = true +tonic-reflection.workspace = true tracing.workspace = true tracing-utils.workspace = true url.workspace = true -walkdir.workspace = true -metrics.workspace = true -pageserver_api.workspace = true -pageserver_client.workspace = true # for ResponseErrorMessageExt TOOD refactor that -pageserver_compaction.workspace = true -pem.workspace = true -postgres_connection.workspace = true -postgres_ffi.workspace = true -pq_proto.workspace = true -remote_storage.workspace = true -storage_broker.workspace = true -tenant_size_model.workspace = true -http-utils.workspace = true utils.workspace = true -workspace_hack.workspace = true -reqwest.workspace = true -rpds.workspace = true -enum-map.workspace = true -enumset = { workspace = true, features = ["serde"]} -strum.workspace = true -strum_macros.workspace = true wal_decoder.workspace = true -smallvec.workspace = true +walkdir.workspace = true +workspace_hack.workspace = true twox-hash.workspace = true [target.'cfg(target_os = "linux")'.dependencies] diff --git a/pageserver/page_api/Cargo.toml b/pageserver/page_api/Cargo.toml index c237949226..4f62c77eb2 100644 --- a/pageserver/page_api/Cargo.toml +++ b/pageserver/page_api/Cargo.toml @@ -5,8 +5,14 @@ edition.workspace = true license.workspace = true [dependencies] +bytes.workspace = true +pageserver_api.workspace = true +postgres_ffi.workspace = true prost.workspace = true +smallvec.workspace = true +thiserror.workspace = true tonic.workspace = true +utils.workspace = true workspace_hack.workspace = true [build-dependencies] diff --git a/pageserver/page_api/proto/page_service.proto b/pageserver/page_api/proto/page_service.proto index f6acb3eeeb..44976084bf 100644 --- a/pageserver/page_api/proto/page_service.proto +++ b/pageserver/page_api/proto/page_service.proto @@ -54,9 +54,9 @@ service PageService { // RPCs use regular unary requests, since they are not as frequent and // performance-critical, and this simplifies implementation. // - // NB: a status response (e.g. errors) will terminate the stream. The stream - // may be shared by e.g. multiple Postgres backends, so we should avoid this. - // Most errors are therefore sent as GetPageResponse.status instead. + // NB: a gRPC status response (e.g. errors) will terminate the stream. The + // stream may be shared by multiple Postgres backends, so we avoid this by + // sending them as GetPageResponse.status_code instead. rpc GetPages (stream GetPageRequest) returns (stream GetPageResponse); // Returns the size of a relation, as # of blocks. @@ -159,8 +159,8 @@ message GetPageRequest { // A GetPageRequest class. Primarily intended for observability, but may also be // used for prioritization in the future. enum GetPageClass { - // Unknown class. For forwards compatibility: used when the client sends a - // class that the server doesn't know about. + // Unknown class. For backwards compatibility: used when an older client version sends a class + // that a newer server version has removed. GET_PAGE_CLASS_UNKNOWN = 0; // A normal request. This is the default. GET_PAGE_CLASS_NORMAL = 1; @@ -180,31 +180,37 @@ message GetPageResponse { // The original request's ID. uint64 request_id = 1; // The response status code. - GetPageStatus status = 2; + GetPageStatusCode status_code = 2; // A string describing the status, if any. string reason = 3; - // The 8KB page images, in the same order as the request. Empty if status != OK. + // The 8KB page images, in the same order as the request. Empty if status_code != OK. repeated bytes page_image = 4; } -// A GetPageResponse status code. Since we use a bidirectional stream, we don't -// want to send errors as gRPC statuses, since this would terminate the stream. -enum GetPageStatus { - // Unknown status. For forwards compatibility: used when the server sends a - // status code that the client doesn't know about. - GET_PAGE_STATUS_UNKNOWN = 0; +// A GetPageResponse status code. +// +// These are effectively equivalent to gRPC statuses. However, we use a bidirectional stream +// (potentially shared by many backends), and a gRPC status response would terminate the stream so +// we send GetPageResponse messages with these codes instead. +enum GetPageStatusCode { + // Unknown status. For forwards compatibility: used when an older client version receives a new + // status code from a newer server version. + GET_PAGE_STATUS_CODE_UNKNOWN = 0; // The request was successful. - GET_PAGE_STATUS_OK = 1; + GET_PAGE_STATUS_CODE_OK = 1; // The page did not exist. The tenant/timeline/shard has already been // validated during stream setup. - GET_PAGE_STATUS_NOT_FOUND = 2; + GET_PAGE_STATUS_CODE_NOT_FOUND = 2; // The request was invalid. - GET_PAGE_STATUS_INVALID = 3; + GET_PAGE_STATUS_CODE_INVALID_REQUEST = 3; + // The request failed due to an internal server error. + GET_PAGE_STATUS_CODE_INTERNAL_ERROR = 4; // The tenant is rate limited. Slow down and retry later. - GET_PAGE_STATUS_SLOW_DOWN = 4; - // TODO: consider adding a GET_PAGE_STATUS_LAYER_DOWNLOAD in the case of a - // layer download. This could free up the server task to process other - // requests while the layer download is in progress. + GET_PAGE_STATUS_CODE_SLOW_DOWN = 5; + // NB: shutdown errors are emitted as a gRPC Unavailable status. + // + // TODO: consider adding a GET_PAGE_STATUS_CODE_LAYER_DOWNLOAD in the case of a layer download. + // This could free up the server task to process other requests while the download is in progress. } // Fetches the size of a relation at a given LSN, as # of blocks. Only valid on diff --git a/pageserver/page_api/src/lib.rs b/pageserver/page_api/src/lib.rs index 0b68d03aaa..f515f27f3e 100644 --- a/pageserver/page_api/src/lib.rs +++ b/pageserver/page_api/src/lib.rs @@ -17,3 +17,7 @@ pub mod proto { pub use page_service_client::PageServiceClient; pub use page_service_server::{PageService, PageServiceServer}; } + +mod model; + +pub use model::*; diff --git a/pageserver/page_api/src/model.rs b/pageserver/page_api/src/model.rs new file mode 100644 index 0000000000..7ab97a994e --- /dev/null +++ b/pageserver/page_api/src/model.rs @@ -0,0 +1,595 @@ +//! Structs representing the canonical page service API. +//! +//! These mirror the autogenerated Protobuf types. The differences are: +//! +//! - Types that are in fact required by the API are not Options. The protobuf "required" +//! attribute is deprecated and 'prost' marks a lot of members as optional because of that. +//! (See for a gripe on this) +//! +//! - Use more precise datatypes, e.g. Lsn and uints shorter than 32 bits. +//! +//! - Validate protocol invariants, via try_from() and try_into(). + +use bytes::Bytes; +use postgres_ffi::Oid; +use smallvec::SmallVec; +// TODO: split out Lsn, RelTag, SlruKind, Oid and other basic types to a separate crate, to avoid +// pulling in all of their other crate dependencies when building the client. +use utils::lsn::Lsn; + +use crate::proto; + +/// A protocol error. Typically returned via try_from() or try_into(). +#[derive(thiserror::Error, Debug)] +pub enum ProtocolError { + #[error("field '{0}' has invalid value '{1}'")] + Invalid(&'static str, String), + #[error("required field '{0}' is missing")] + Missing(&'static str), +} + +impl ProtocolError { + /// Helper to generate a new ProtocolError::Invalid for the given field and value. + pub fn invalid(field: &'static str, value: impl std::fmt::Debug) -> Self { + Self::Invalid(field, format!("{value:?}")) + } +} + +impl From for tonic::Status { + fn from(err: ProtocolError) -> Self { + tonic::Status::invalid_argument(format!("{err}")) + } +} + +/// The LSN a request should read at. +#[derive(Clone, Copy, Debug)] +pub struct ReadLsn { + /// The request's read LSN. + pub request_lsn: Lsn, + /// If given, the caller guarantees that the page has not been modified since this LSN. Must be + /// smaller than or equal to request_lsn. This allows the Pageserver to serve an old page + /// without waiting for the request LSN to arrive. Valid for all request types. + /// + /// It is undefined behaviour to make a request such that the page was, in fact, modified + /// between request_lsn and not_modified_since_lsn. The Pageserver might detect it and return an + /// error, or it might return the old page version or the new page version. Setting + /// not_modified_since_lsn equal to request_lsn is always safe, but can lead to unnecessary + /// waiting. + pub not_modified_since_lsn: Option, +} + +impl ReadLsn { + /// Validates the ReadLsn. + pub fn validate(&self) -> Result<(), ProtocolError> { + if self.request_lsn == Lsn::INVALID { + return Err(ProtocolError::invalid("request_lsn", self.request_lsn)); + } + if self.not_modified_since_lsn > Some(self.request_lsn) { + return Err(ProtocolError::invalid( + "not_modified_since_lsn", + self.not_modified_since_lsn, + )); + } + Ok(()) + } +} + +impl TryFrom for ReadLsn { + type Error = ProtocolError; + + fn try_from(pb: proto::ReadLsn) -> Result { + let read_lsn = Self { + request_lsn: Lsn(pb.request_lsn), + not_modified_since_lsn: match pb.not_modified_since_lsn { + 0 => None, + lsn => Some(Lsn(lsn)), + }, + }; + read_lsn.validate()?; + Ok(read_lsn) + } +} + +impl TryFrom for proto::ReadLsn { + type Error = ProtocolError; + + fn try_from(read_lsn: ReadLsn) -> Result { + read_lsn.validate()?; + Ok(Self { + request_lsn: read_lsn.request_lsn.0, + not_modified_since_lsn: read_lsn.not_modified_since_lsn.unwrap_or_default().0, + }) + } +} + +// RelTag is defined in pageserver_api::reltag. +pub type RelTag = pageserver_api::reltag::RelTag; + +impl TryFrom for RelTag { + type Error = ProtocolError; + + fn try_from(pb: proto::RelTag) -> Result { + Ok(Self { + spcnode: pb.spc_oid, + dbnode: pb.db_oid, + relnode: pb.rel_number, + forknum: pb + .fork_number + .try_into() + .map_err(|_| ProtocolError::invalid("fork_number", pb.fork_number))?, + }) + } +} + +impl From for proto::RelTag { + fn from(rel_tag: RelTag) -> Self { + Self { + spc_oid: rel_tag.spcnode, + db_oid: rel_tag.dbnode, + rel_number: rel_tag.relnode, + fork_number: rel_tag.forknum as u32, + } + } +} + +/// Checks whether a relation exists, at the given LSN. Only valid on shard 0, other shards error. +#[derive(Clone, Copy, Debug)] +pub struct CheckRelExistsRequest { + pub read_lsn: ReadLsn, + pub rel: RelTag, +} + +impl TryFrom for CheckRelExistsRequest { + type Error = ProtocolError; + + fn try_from(pb: proto::CheckRelExistsRequest) -> Result { + Ok(Self { + read_lsn: pb + .read_lsn + .ok_or(ProtocolError::Missing("read_lsn"))? + .try_into()?, + rel: pb.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?, + }) + } +} + +pub type CheckRelExistsResponse = bool; + +impl From for CheckRelExistsResponse { + fn from(pb: proto::CheckRelExistsResponse) -> Self { + pb.exists + } +} + +impl From for proto::CheckRelExistsResponse { + fn from(exists: CheckRelExistsResponse) -> Self { + Self { exists } + } +} + +/// Requests a base backup at a given LSN. +#[derive(Clone, Copy, Debug)] +pub struct GetBaseBackupRequest { + /// The LSN to fetch a base backup at. + pub read_lsn: ReadLsn, + /// If true, logical replication slots will not be created. + pub replica: bool, +} + +impl TryFrom for GetBaseBackupRequest { + type Error = ProtocolError; + + fn try_from(pb: proto::GetBaseBackupRequest) -> Result { + Ok(Self { + read_lsn: pb + .read_lsn + .ok_or(ProtocolError::Missing("read_lsn"))? + .try_into()?, + replica: pb.replica, + }) + } +} + +impl TryFrom for proto::GetBaseBackupRequest { + type Error = ProtocolError; + + fn try_from(request: GetBaseBackupRequest) -> Result { + Ok(Self { + read_lsn: Some(request.read_lsn.try_into()?), + replica: request.replica, + }) + } +} + +pub type GetBaseBackupResponseChunk = Bytes; + +impl TryFrom for GetBaseBackupResponseChunk { + type Error = ProtocolError; + + fn try_from(pb: proto::GetBaseBackupResponseChunk) -> Result { + if pb.chunk.is_empty() { + return Err(ProtocolError::Missing("chunk")); + } + Ok(pb.chunk) + } +} + +impl TryFrom for proto::GetBaseBackupResponseChunk { + type Error = ProtocolError; + + fn try_from(chunk: GetBaseBackupResponseChunk) -> Result { + if chunk.is_empty() { + return Err(ProtocolError::Missing("chunk")); + } + Ok(Self { chunk }) + } +} + +/// Requests the size of a database, as # of bytes. Only valid on shard 0, other shards will error. +#[derive(Clone, Copy, Debug)] +pub struct GetDbSizeRequest { + pub read_lsn: ReadLsn, + pub db_oid: Oid, +} + +impl TryFrom for GetDbSizeRequest { + type Error = ProtocolError; + + fn try_from(pb: proto::GetDbSizeRequest) -> Result { + Ok(Self { + read_lsn: pb + .read_lsn + .ok_or(ProtocolError::Missing("read_lsn"))? + .try_into()?, + db_oid: pb.db_oid, + }) + } +} + +impl TryFrom for proto::GetDbSizeRequest { + type Error = ProtocolError; + + fn try_from(request: GetDbSizeRequest) -> Result { + Ok(Self { + read_lsn: Some(request.read_lsn.try_into()?), + db_oid: request.db_oid, + }) + } +} + +pub type GetDbSizeResponse = u64; + +impl From for GetDbSizeResponse { + fn from(pb: proto::GetDbSizeResponse) -> Self { + pb.num_bytes + } +} + +impl From for proto::GetDbSizeResponse { + fn from(num_bytes: GetDbSizeResponse) -> Self { + Self { num_bytes } + } +} + +/// Requests one or more pages. +#[derive(Clone, Debug)] +pub struct GetPageRequest { + /// A request ID. Will be included in the response. Should be unique for in-flight requests on + /// the stream. + pub request_id: RequestID, + /// The request class. + pub request_class: GetPageClass, + /// The LSN to read at. + pub read_lsn: ReadLsn, + /// The relation to read from. + pub rel: RelTag, + /// Page numbers to read. Must belong to the remote shard. + /// + /// Multiple pages will be executed as a single batch by the Pageserver, amortizing layer access + /// costs and parallelizing them. This may increase the latency of any individual request, but + /// improves the overall latency and throughput of the batch as a whole. + pub block_numbers: SmallVec<[u32; 1]>, +} + +impl TryFrom for GetPageRequest { + type Error = ProtocolError; + + fn try_from(pb: proto::GetPageRequest) -> Result { + if pb.block_number.is_empty() { + return Err(ProtocolError::Missing("block_number")); + } + Ok(Self { + request_id: pb.request_id, + request_class: pb.request_class.into(), + read_lsn: pb + .read_lsn + .ok_or(ProtocolError::Missing("read_lsn"))? + .try_into()?, + rel: pb.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?, + block_numbers: pb.block_number.into(), + }) + } +} + +impl TryFrom for proto::GetPageRequest { + type Error = ProtocolError; + + fn try_from(request: GetPageRequest) -> Result { + if request.block_numbers.is_empty() { + return Err(ProtocolError::Missing("block_number")); + } + Ok(Self { + request_id: request.request_id, + request_class: request.request_class.into(), + read_lsn: Some(request.read_lsn.try_into()?), + rel: Some(request.rel.into()), + block_number: request.block_numbers.into_vec(), + }) + } +} + +/// A GetPage request ID. +pub type RequestID = u64; + +/// A GetPage request class. +#[derive(Clone, Copy, Debug)] +pub enum GetPageClass { + /// Unknown class. For backwards compatibility: used when an older client version sends a class + /// that a newer server version has removed. + Unknown, + /// A normal request. This is the default. + Normal, + /// A prefetch request. NB: can only be classified on pg < 18. + Prefetch, + /// A background request (e.g. vacuum). + Background, +} + +impl From for GetPageClass { + fn from(pb: proto::GetPageClass) -> Self { + match pb { + proto::GetPageClass::Unknown => Self::Unknown, + proto::GetPageClass::Normal => Self::Normal, + proto::GetPageClass::Prefetch => Self::Prefetch, + proto::GetPageClass::Background => Self::Background, + } + } +} + +impl From for GetPageClass { + fn from(class: i32) -> Self { + proto::GetPageClass::try_from(class) + .unwrap_or(proto::GetPageClass::Unknown) + .into() + } +} + +impl From for proto::GetPageClass { + fn from(class: GetPageClass) -> Self { + match class { + GetPageClass::Unknown => Self::Unknown, + GetPageClass::Normal => Self::Normal, + GetPageClass::Prefetch => Self::Prefetch, + GetPageClass::Background => Self::Background, + } + } +} + +impl From for i32 { + fn from(class: GetPageClass) -> Self { + proto::GetPageClass::from(class).into() + } +} + +/// A GetPage response. +/// +/// A batch response will contain all of the requested pages. We could eagerly emit individual pages +/// as soon as they are ready, but on a readv() Postgres holds buffer pool locks on all pages in the +/// batch and we'll only return once the entire batch is ready, so no one can make use of the +/// individual pages. +#[derive(Clone, Debug)] +pub struct GetPageResponse { + /// The original request's ID. + pub request_id: RequestID, + /// The response status code. + pub status_code: GetPageStatusCode, + /// A string describing the status, if any. + pub reason: Option, + /// The 8KB page images, in the same order as the request. Empty if status != OK. + pub page_images: SmallVec<[Bytes; 1]>, +} + +impl From for GetPageResponse { + fn from(pb: proto::GetPageResponse) -> Self { + Self { + request_id: pb.request_id, + status_code: pb.status_code.into(), + reason: Some(pb.reason).filter(|r| !r.is_empty()), + page_images: pb.page_image.into(), + } + } +} + +impl From for proto::GetPageResponse { + fn from(response: GetPageResponse) -> Self { + Self { + request_id: response.request_id, + status_code: response.status_code.into(), + reason: response.reason.unwrap_or_default(), + page_image: response.page_images.into_vec(), + } + } +} + +/// A GetPage response status code. +/// +/// These are effectively equivalent to gRPC statuses. However, we use a bidirectional stream +/// (potentially shared by many backends), and a gRPC status response would terminate the stream so +/// we send GetPageResponse messages with these codes instead. +#[derive(Clone, Copy, Debug)] +pub enum GetPageStatusCode { + /// Unknown status. For forwards compatibility: used when an older client version receives a new + /// status code from a newer server version. + Unknown, + /// The request was successful. + Ok, + /// The page did not exist. The tenant/timeline/shard has already been validated during stream + /// setup. + NotFound, + /// The request was invalid. + InvalidRequest, + /// The request failed due to an internal server error. + InternalError, + /// The tenant is rate limited. Slow down and retry later. + SlowDown, +} + +impl From for GetPageStatusCode { + fn from(pb: proto::GetPageStatusCode) -> Self { + match pb { + proto::GetPageStatusCode::Unknown => Self::Unknown, + proto::GetPageStatusCode::Ok => Self::Ok, + proto::GetPageStatusCode::NotFound => Self::NotFound, + proto::GetPageStatusCode::InvalidRequest => Self::InvalidRequest, + proto::GetPageStatusCode::InternalError => Self::InternalError, + proto::GetPageStatusCode::SlowDown => Self::SlowDown, + } + } +} + +impl From for GetPageStatusCode { + fn from(status_code: i32) -> Self { + proto::GetPageStatusCode::try_from(status_code) + .unwrap_or(proto::GetPageStatusCode::Unknown) + .into() + } +} + +impl From for proto::GetPageStatusCode { + fn from(status_code: GetPageStatusCode) -> Self { + match status_code { + GetPageStatusCode::Unknown => Self::Unknown, + GetPageStatusCode::Ok => Self::Ok, + GetPageStatusCode::NotFound => Self::NotFound, + GetPageStatusCode::InvalidRequest => Self::InvalidRequest, + GetPageStatusCode::InternalError => Self::InternalError, + GetPageStatusCode::SlowDown => Self::SlowDown, + } + } +} + +impl From for i32 { + fn from(status_code: GetPageStatusCode) -> Self { + proto::GetPageStatusCode::from(status_code).into() + } +} + +// Fetches the size of a relation at a given LSN, as # of blocks. Only valid on shard 0, other +// shards will error. +pub struct GetRelSizeRequest { + pub read_lsn: ReadLsn, + pub rel: RelTag, +} + +impl TryFrom for GetRelSizeRequest { + type Error = ProtocolError; + + fn try_from(proto: proto::GetRelSizeRequest) -> Result { + Ok(Self { + read_lsn: proto + .read_lsn + .ok_or(ProtocolError::Missing("read_lsn"))? + .try_into()?, + rel: proto.rel.ok_or(ProtocolError::Missing("rel"))?.try_into()?, + }) + } +} + +impl TryFrom for proto::GetRelSizeRequest { + type Error = ProtocolError; + + fn try_from(request: GetRelSizeRequest) -> Result { + Ok(Self { + read_lsn: Some(request.read_lsn.try_into()?), + rel: Some(request.rel.into()), + }) + } +} + +pub type GetRelSizeResponse = u32; + +impl From for GetRelSizeResponse { + fn from(proto: proto::GetRelSizeResponse) -> Self { + proto.num_blocks + } +} + +impl From for proto::GetRelSizeResponse { + fn from(num_blocks: GetRelSizeResponse) -> Self { + Self { num_blocks } + } +} + +/// Requests an SLRU segment. Only valid on shard 0, other shards will error. +pub struct GetSlruSegmentRequest { + pub read_lsn: ReadLsn, + pub kind: SlruKind, + pub segno: u32, +} + +impl TryFrom for GetSlruSegmentRequest { + type Error = ProtocolError; + + fn try_from(pb: proto::GetSlruSegmentRequest) -> Result { + Ok(Self { + read_lsn: pb + .read_lsn + .ok_or(ProtocolError::Missing("read_lsn"))? + .try_into()?, + kind: u8::try_from(pb.kind) + .ok() + .and_then(SlruKind::from_repr) + .ok_or_else(|| ProtocolError::invalid("slru_kind", pb.kind))?, + segno: pb.segno, + }) + } +} + +impl TryFrom for proto::GetSlruSegmentRequest { + type Error = ProtocolError; + + fn try_from(request: GetSlruSegmentRequest) -> Result { + Ok(Self { + read_lsn: Some(request.read_lsn.try_into()?), + kind: request.kind as u32, + segno: request.segno, + }) + } +} + +pub type GetSlruSegmentResponse = Bytes; + +impl TryFrom for GetSlruSegmentResponse { + type Error = ProtocolError; + + fn try_from(pb: proto::GetSlruSegmentResponse) -> Result { + if pb.segment.is_empty() { + return Err(ProtocolError::Missing("segment")); + } + Ok(pb.segment) + } +} + +impl TryFrom for proto::GetSlruSegmentResponse { + type Error = ProtocolError; + + fn try_from(segment: GetSlruSegmentResponse) -> Result { + if segment.is_empty() { + return Err(ProtocolError::Missing("segment")); + } + Ok(Self { segment }) + } +} + +// SlruKind is defined in pageserver_api::reltag. +pub type SlruKind = pageserver_api::reltag::SlruKind; diff --git a/pageserver/src/bin/pageserver.rs b/pageserver/src/bin/pageserver.rs index 6001ea0345..df3c045145 100644 --- a/pageserver/src/bin/pageserver.rs +++ b/pageserver/src/bin/pageserver.rs @@ -21,6 +21,7 @@ use pageserver::config::{PageServerConf, PageserverIdentity, ignored_fields}; use pageserver::controller_upcall_client::StorageControllerUpcallClient; use pageserver::deletion_queue::DeletionQueue; use pageserver::disk_usage_eviction_task::{self, launch_disk_usage_global_eviction_task}; +use pageserver::feature_resolver::FeatureResolver; use pageserver::metrics::{STARTUP_DURATION, STARTUP_IS_LOADING}; use pageserver::task_mgr::{ BACKGROUND_RUNTIME, COMPUTE_REQUEST_RUNTIME, MGMT_REQUEST_RUNTIME, WALRECEIVER_RUNTIME, @@ -388,23 +389,30 @@ fn start_pageserver( // We need to release the lock file only when the process exits. std::mem::forget(lock_file); - // Bind the HTTP and libpq ports early, so that if they are in use by some other - // process, we error out early. - let http_addr = &conf.listen_http_addr; - info!("Starting pageserver http handler on {http_addr}"); - let http_listener = tcp_listener::bind(http_addr)?; + // Bind the HTTP, libpq, and gRPC ports early, to error out if they are + // already in use. + info!( + "Starting pageserver http handler on {} with auth {:#?}", + conf.listen_http_addr, conf.http_auth_type + ); + let http_listener = tcp_listener::bind(&conf.listen_http_addr)?; let https_listener = match conf.listen_https_addr.as_ref() { Some(https_addr) => { - info!("Starting pageserver https handler on {https_addr}"); + info!( + "Starting pageserver https handler on {https_addr} with auth {:#?}", + conf.http_auth_type + ); Some(tcp_listener::bind(https_addr)?) } None => None, }; - let pg_addr = &conf.listen_pg_addr; - info!("Starting pageserver pg protocol handler on {pg_addr}"); - let pageserver_listener = tcp_listener::bind(pg_addr)?; + info!( + "Starting pageserver pg protocol handler on {} with auth {:#?}", + conf.listen_pg_addr, conf.pg_auth_type, + ); + let pageserver_listener = tcp_listener::bind(&conf.listen_pg_addr)?; // Enable SO_KEEPALIVE on the socket, to detect dead connections faster. // These are configured via net.ipv4.tcp_keepalive_* sysctls. @@ -413,6 +421,15 @@ fn start_pageserver( // support enabling keepalives while using the default OS sysctls. setsockopt(&pageserver_listener, sockopt::KeepAlive, &true)?; + let mut grpc_listener = None; + if let Some(grpc_addr) = &conf.listen_grpc_addr { + info!( + "Starting pageserver gRPC handler on {grpc_addr} with auth {:#?}", + conf.grpc_auth_type + ); + grpc_listener = Some(tcp_listener::bind(grpc_addr).map_err(|e| anyhow!("{e}"))?); + } + // Launch broker client // The storage_broker::connect call needs to happen inside a tokio runtime thread. let broker_client = WALRECEIVER_RUNTIME @@ -440,7 +457,8 @@ fn start_pageserver( // Initialize authentication for incoming connections let http_auth; let pg_auth; - if conf.http_auth_type == AuthType::NeonJWT || conf.pg_auth_type == AuthType::NeonJWT { + let grpc_auth; + if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type].contains(&AuthType::NeonJWT) { // unwrap is ok because check is performed when creating config, so path is set and exists let key_path = conf.auth_validation_public_key_path.as_ref().unwrap(); info!("Loading public key(s) for verifying JWT tokens from {key_path:?}"); @@ -448,20 +466,23 @@ fn start_pageserver( let jwt_auth = JwtAuth::from_key_path(key_path)?; let auth: Arc = Arc::new(SwappableJwtAuth::new(jwt_auth)); - http_auth = match &conf.http_auth_type { + http_auth = match conf.http_auth_type { AuthType::Trust => None, AuthType::NeonJWT => Some(auth.clone()), }; - pg_auth = match &conf.pg_auth_type { + pg_auth = match conf.pg_auth_type { + AuthType::Trust => None, + AuthType::NeonJWT => Some(auth.clone()), + }; + grpc_auth = match conf.grpc_auth_type { AuthType::Trust => None, AuthType::NeonJWT => Some(auth), }; } else { http_auth = None; pg_auth = None; + grpc_auth = None; } - info!("Using auth for http API: {:#?}", conf.http_auth_type); - info!("Using auth for pg connections: {:#?}", conf.pg_auth_type); let tls_server_config = if conf.listen_https_addr.is_some() || conf.enable_tls_page_service_api { @@ -502,6 +523,12 @@ fn start_pageserver( // Set up remote storage client let remote_storage = BACKGROUND_RUNTIME.block_on(create_remote_storage_client(conf))?; + let feature_resolver = create_feature_resolver( + conf, + shutdown_pageserver.clone(), + BACKGROUND_RUNTIME.handle(), + )?; + // Set up deletion queue let (deletion_queue, deletion_workers) = DeletionQueue::new( remote_storage.clone(), @@ -555,6 +582,7 @@ fn start_pageserver( deletion_queue_client, l0_flush_global_state, basebackup_prepare_sender, + feature_resolver, }, order, shutdown_pageserver.clone(), @@ -776,9 +804,27 @@ fn start_pageserver( } else { None }, - basebackup_cache, + basebackup_cache.clone(), ); + // Spawn a Pageserver gRPC server task. It will spawn separate tasks for + // each stream/request. + // + // TODO: this uses a separate Tokio runtime for the page service. If we want + // other gRPC services, they will need their own port and runtime. Is this + // necessary? + let mut page_service_grpc = None; + if let Some(grpc_listener) = grpc_listener { + page_service_grpc = Some(page_service::spawn_grpc( + conf, + tenant_manager.clone(), + grpc_auth, + otel_guard.as_ref().map(|g| g.dispatch.clone()), + grpc_listener, + basebackup_cache, + )?); + } + // All started up! Now just sit and wait for shutdown signal. BACKGROUND_RUNTIME.block_on(async move { let signal_token = CancellationToken::new(); @@ -797,6 +843,7 @@ fn start_pageserver( http_endpoint_listener, https_endpoint_listener, page_service, + page_service_grpc, consumption_metrics_tasks, disk_usage_eviction_task, &tenant_manager, @@ -810,6 +857,14 @@ fn start_pageserver( }) } +fn create_feature_resolver( + conf: &'static PageServerConf, + shutdown_pageserver: CancellationToken, + handle: &tokio::runtime::Handle, +) -> anyhow::Result { + FeatureResolver::spawn(conf, shutdown_pageserver, handle) +} + async fn create_remote_storage_client( conf: &'static PageServerConf, ) -> anyhow::Result { diff --git a/pageserver/src/config.rs b/pageserver/src/config.rs index e8b3b7b3ab..89f7539722 100644 --- a/pageserver/src/config.rs +++ b/pageserver/src/config.rs @@ -14,7 +14,7 @@ use std::time::Duration; use anyhow::{Context, bail, ensure}; use camino::{Utf8Path, Utf8PathBuf}; use once_cell::sync::OnceCell; -use pageserver_api::config::{DiskUsageEvictionTaskConfig, MaxVectoredReadBytes}; +use pageserver_api::config::{DiskUsageEvictionTaskConfig, MaxVectoredReadBytes, PostHogConfig}; use pageserver_api::models::ImageCompressionAlgorithm; use pageserver_api::shard::TenantShardId; use pem::Pem; @@ -58,11 +58,16 @@ pub struct PageServerConf { pub listen_http_addr: String, /// Example: 127.0.0.1:9899 pub listen_https_addr: Option, + /// If set, expose a gRPC API on this address. + /// Example: 127.0.0.1:51051 + /// + /// EXPERIMENTAL: this protocol is unstable and under active development. + pub listen_grpc_addr: Option, - /// Path to a file with certificate's private key for https API. + /// Path to a file with certificate's private key for https and gRPC API. /// Default: server.key pub ssl_key_file: Utf8PathBuf, - /// Path to a file with a X509 certificate for https API. + /// Path to a file with a X509 certificate for https and gRPC API. /// Default: server.crt pub ssl_cert_file: Utf8PathBuf, /// Period to reload certificate and private key from files. @@ -100,6 +105,8 @@ pub struct PageServerConf { pub http_auth_type: AuthType, /// authentication method for libpq connections from compute pub pg_auth_type: AuthType, + /// authentication method for gRPC connections from compute + pub grpc_auth_type: AuthType, /// Path to a file or directory containing public key(s) for verifying JWT tokens. /// Used for both mgmt and compute auth, if enabled. pub auth_validation_public_key_path: Option, @@ -231,6 +238,9 @@ pub struct PageServerConf { /// This is insecure and should only be used in development environments. pub dev_mode: bool, + /// PostHog integration config. + pub posthog_config: Option, + pub timeline_import_config: pageserver_api::config::TimelineImportConfig, pub basebackup_cache_config: Option, @@ -355,6 +365,7 @@ impl PageServerConf { listen_pg_addr, listen_http_addr, listen_https_addr, + listen_grpc_addr, ssl_key_file, ssl_cert_file, ssl_cert_reload_period, @@ -369,6 +380,7 @@ impl PageServerConf { pg_distrib_dir, http_auth_type, pg_auth_type, + grpc_auth_type, auth_validation_public_key_path, remote_storage, broker_endpoint, @@ -412,6 +424,7 @@ impl PageServerConf { tracing, enable_tls_page_service_api, dev_mode, + posthog_config, timeline_import_config, basebackup_cache_config, } = config_toml; @@ -423,6 +436,7 @@ impl PageServerConf { listen_pg_addr, listen_http_addr, listen_https_addr, + listen_grpc_addr, ssl_key_file, ssl_cert_file, ssl_cert_reload_period, @@ -435,6 +449,7 @@ impl PageServerConf { max_file_descriptors, http_auth_type, pg_auth_type, + grpc_auth_type, auth_validation_public_key_path, remote_storage_config: remote_storage, broker_endpoint, @@ -525,13 +540,16 @@ impl PageServerConf { } None => Vec::new(), }, + posthog_config, }; // ------------------------------------------------------------ // custom validation code that covers more than one field in isolation // ------------------------------------------------------------ - if conf.http_auth_type == AuthType::NeonJWT || conf.pg_auth_type == AuthType::NeonJWT { + if [conf.http_auth_type, conf.pg_auth_type, conf.grpc_auth_type] + .contains(&AuthType::NeonJWT) + { let auth_validation_public_key_path = conf .auth_validation_public_key_path .get_or_insert_with(|| workdir.join("auth_public_key.pem")); diff --git a/pageserver/src/feature_resolver.rs b/pageserver/src/feature_resolver.rs new file mode 100644 index 0000000000..2b0f368079 --- /dev/null +++ b/pageserver/src/feature_resolver.rs @@ -0,0 +1,94 @@ +use std::{collections::HashMap, sync::Arc, time::Duration}; + +use posthog_client_lite::{ + FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError, +}; +use tokio_util::sync::CancellationToken; +use utils::id::TenantId; + +use crate::config::PageServerConf; + +#[derive(Clone)] +pub struct FeatureResolver { + inner: Option>, +} + +impl FeatureResolver { + pub fn new_disabled() -> Self { + Self { inner: None } + } + + pub fn spawn( + conf: &PageServerConf, + shutdown_pageserver: CancellationToken, + handle: &tokio::runtime::Handle, + ) -> anyhow::Result { + // DO NOT block in this function: make it return as fast as possible to avoid startup delays. + if let Some(posthog_config) = &conf.posthog_config { + let inner = FeatureResolverBackgroundLoop::new( + PostHogClientConfig { + server_api_key: posthog_config.server_api_key.clone(), + client_api_key: posthog_config.client_api_key.clone(), + project_id: posthog_config.project_id.clone(), + private_api_url: posthog_config.private_api_url.clone(), + public_api_url: posthog_config.public_api_url.clone(), + }, + shutdown_pageserver, + ); + let inner = Arc::new(inner); + // TODO: make this configurable + inner.clone().spawn(handle, Duration::from_secs(60)); + Ok(FeatureResolver { inner: Some(inner) }) + } else { + Ok(FeatureResolver { inner: None }) + } + } + + /// Evaluate a multivariate feature flag. Currently, we do not support any properties. + /// + /// Error handling: the caller should inspect the error and decide the behavior when a feature flag + /// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be + /// propagated beyond where the feature flag gets resolved. + pub fn evaluate_multivariate( + &self, + flag_key: &str, + tenant_id: TenantId, + ) -> Result { + if let Some(inner) = &self.inner { + inner.feature_store().evaluate_multivariate( + flag_key, + &tenant_id.to_string(), + &HashMap::new(), + ) + } else { + Err(PostHogEvaluationError::NotAvailable( + "PostHog integration is not enabled".to_string(), + )) + } + } + + /// Evaluate a boolean feature flag. Currently, we do not support any properties. + /// + /// Returns `Ok(())` if the flag is evaluated to true, otherwise returns an error. + /// + /// Error handling: the caller should inspect the error and decide the behavior when a feature flag + /// cannot be evaluated (i.e., default to false if it cannot be resolved). The error should *not* be + /// propagated beyond where the feature flag gets resolved. + pub fn evaluate_boolean( + &self, + flag_key: &str, + tenant_id: TenantId, + ) -> Result<(), PostHogEvaluationError> { + if let Some(inner) = &self.inner { + inner.feature_store().evaluate_boolean( + flag_key, + &tenant_id.to_string(), + &HashMap::new(), + ) + } else { + Err(PostHogEvaluationError::NotAvailable( + "PostHog integration is not enabled".to_string(), + )) + } + } +} diff --git a/pageserver/src/http/openapi_spec.yml b/pageserver/src/http/openapi_spec.yml index 7ea148971f..e8d1367d6c 100644 --- a/pageserver/src/http/openapi_spec.yml +++ b/pageserver/src/http/openapi_spec.yml @@ -353,6 +353,33 @@ paths: "200": description: OK + /v1/tenant/{tenant_shard_id}/timeline/{timeline_id}/mark_invisible: + parameters: + - name: tenant_shard_id + in: path + required: true + schema: + type: string + - name: timeline_id + in: path + required: true + schema: + type: string + format: hex + put: + requestBody: + content: + application/json: + schema: + type: object + properties: + is_visible: + type: boolean + default: false + responses: + "200": + description: OK + /v1/tenant/{tenant_shard_id}/location_config: parameters: - name: tenant_shard_id @@ -626,6 +653,8 @@ paths: format: hex pg_version: type: integer + read_only: + type: boolean existing_initdb_timeline_id: type: string format: hex diff --git a/pageserver/src/http/routes.rs b/pageserver/src/http/routes.rs index 0d6791cddd..c449e3373f 100644 --- a/pageserver/src/http/routes.rs +++ b/pageserver/src/http/routes.rs @@ -370,6 +370,18 @@ impl From for ApiError { } } +impl From for ApiError { + fn from(err: crate::tenant::FinalizeTimelineImportError) -> ApiError { + use crate::tenant::FinalizeTimelineImportError::*; + match err { + ImportTaskStillRunning => { + ApiError::ResourceUnavailable("Import task still running".into()) + } + ShuttingDown => ApiError::ShuttingDown, + } + } +} + // Helper function to construct a TimelineInfo struct for a timeline async fn build_timeline_info( timeline: &Arc, @@ -572,6 +584,7 @@ async fn timeline_create_handler( TimelineCreateRequestMode::Branch { ancestor_timeline_id, ancestor_start_lsn, + read_only: _, pg_version: _, } => tenant::CreateTimelineParams::Branch(tenant::CreateTimelineParamsBranch { new_timeline_id, @@ -3532,10 +3545,7 @@ async fn activate_post_import_handler( tenant.wait_to_become_active(ACTIVE_TENANT_TIMEOUT).await?; - tenant - .finalize_importing_timeline(timeline_id) - .await - .map_err(ApiError::InternalServerError)?; + tenant.finalize_importing_timeline(timeline_id).await?; match tenant.get_timeline(timeline_id, false) { Ok(_timeline) => { diff --git a/pageserver/src/lib.rs b/pageserver/src/lib.rs index 71d9c6603f..ae7cbf1d6b 100644 --- a/pageserver/src/lib.rs +++ b/pageserver/src/lib.rs @@ -10,6 +10,7 @@ pub mod context; pub mod controller_upcall_client; pub mod deletion_queue; pub mod disk_usage_eviction_task; +pub mod feature_resolver; pub mod http; pub mod import_datadir; pub mod l0_flush; @@ -84,6 +85,7 @@ pub async fn shutdown_pageserver( http_listener: HttpEndpointListener, https_listener: Option, page_service: page_service::Listener, + grpc_task: Option, consumption_metrics_worker: ConsumptionMetricsTasks, disk_usage_eviction_task: Option, tenant_manager: &TenantManager, @@ -177,6 +179,16 @@ pub async fn shutdown_pageserver( ) .await; + // Shut down the gRPC server task, including request handlers. + if let Some(grpc_task) = grpc_task { + timed( + grpc_task.shutdown(), + "shutdown gRPC PageRequestHandler", + Duration::from_secs(3), + ) + .await; + } + // Shut down all the tenants. This flushes everything to disk and kills // the checkpoint and GC tasks. timed( diff --git a/pageserver/src/metrics.rs b/pageserver/src/metrics.rs index 3076c7f1d6..0ff31dcb8a 100644 --- a/pageserver/src/metrics.rs +++ b/pageserver/src/metrics.rs @@ -2234,8 +2234,10 @@ impl BasebackupQueryTimeOngoingRecording<'_> { // If you want to change categorize of a specific error, also change it in `log_query_error`. let metric = match res { Ok(_) => &self.parent.ok, - Err(QueryError::Shutdown) => { - // Do not observe ok/err for shutdown + Err(QueryError::Shutdown) | Err(QueryError::Reconnect) => { + // Do not observe ok/err for shutdown/reconnect. + // Reconnect error might be raised when the operation is waiting for LSN and the tenant shutdown interrupts + // the operation. A reconnect error will be issued and the client will retry. return; } Err(QueryError::Disconnected(ConnectionError::Io(io_error))) diff --git a/pageserver/src/page_service.rs b/pageserver/src/page_service.rs index 69519dfa87..e96787e027 100644 --- a/pageserver/src/page_service.rs +++ b/pageserver/src/page_service.rs @@ -4,6 +4,7 @@ use std::borrow::Cow; use std::num::NonZeroUsize; use std::os::fd::AsRawFd; +use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; use std::time::{Duration, Instant, SystemTime}; @@ -12,7 +13,7 @@ use std::{io, str}; use anyhow::{Context, bail}; use async_compression::tokio::write::GzipEncoder; use bytes::Buf; -use futures::FutureExt; +use futures::{FutureExt, Stream}; use itertools::Itertools; use jsonwebtoken::TokenData; use once_cell::sync::OnceCell; @@ -30,6 +31,7 @@ use pageserver_api::models::{ }; use pageserver_api::reltag::SlruKind; use pageserver_api::shard::TenantShardId; +use pageserver_page_api::proto; use postgres_backend::{ AuthType, PostgresBackend, PostgresBackendReader, QueryError, is_expected_io_error, }; @@ -41,19 +43,20 @@ use strum_macros::IntoStaticStr; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufWriter}; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; +use tonic::service::Interceptor as _; use tracing::*; use utils::auth::{Claims, Scope, SwappableJwtAuth}; use utils::failpoint_support; -use utils::id::{TenantId, TimelineId}; +use utils::id::{TenantId, TenantTimelineId, TimelineId}; use utils::logging::log_slow; use utils::lsn::Lsn; +use utils::shard::ShardIndex; use utils::simple_rcu::RcuReadGuard; use utils::sync::gate::{Gate, GateGuard}; use utils::sync::spsc_fold; -use crate::PERF_TRACE_TARGET; use crate::auth::check_permission; -use crate::basebackup::BasebackupError; +use crate::basebackup::{self, BasebackupError}; use crate::basebackup_cache::BasebackupCache; use crate::config::PageServerConf; use crate::context::{ @@ -75,7 +78,7 @@ use crate::tenant::mgr::{ use crate::tenant::storage_layer::IoConcurrency; use crate::tenant::timeline::{self, WaitLsnError}; use crate::tenant::{GetTimelineError, PageReconstructError, Timeline}; -use crate::{basebackup, timed_after_cancellation}; +use crate::{CancellableTask, PERF_TRACE_TARGET, timed_after_cancellation}; /// How long we may wait for a [`crate::tenant::mgr::TenantSlot::InProgress`]` and/or a [`crate::tenant::TenantShard`] which /// is not yet in state [`TenantState::Active`]. @@ -86,6 +89,26 @@ const ACTIVE_TENANT_TIMEOUT: Duration = Duration::from_millis(30000); /// Threshold at which to log slow GetPage requests. const LOG_SLOW_GETPAGE_THRESHOLD: Duration = Duration::from_secs(30); +/// The idle time before sending TCP keepalive probes for gRPC connections. The +/// interval and timeout between each probe is configured via sysctl. This +/// allows detecting dead connections sooner. +const GRPC_TCP_KEEPALIVE_TIME: Duration = Duration::from_secs(60); + +/// Whether to enable TCP nodelay for gRPC connections. This disables Nagle's +/// algorithm, which can cause latency spikes for small messages. +const GRPC_TCP_NODELAY: bool = true; + +/// The interval between HTTP2 keepalive pings. This allows shutting down server +/// tasks when clients are unresponsive. +const GRPC_HTTP2_KEEPALIVE_INTERVAL: Duration = Duration::from_secs(30); + +/// The timeout for HTTP2 keepalive pings. Should be <= GRPC_KEEPALIVE_INTERVAL. +const GRPC_HTTP2_KEEPALIVE_TIMEOUT: Duration = Duration::from_secs(20); + +/// Number of concurrent gRPC streams per TCP connection. We expect something +/// like 8 GetPage streams per connections, plus any unary requests. +const GRPC_MAX_CONCURRENT_STREAMS: u32 = 256; + /////////////////////////////////////////////////////////////////////////////// pub struct Listener { @@ -140,6 +163,94 @@ pub fn spawn( Listener { cancel, task } } +/// Spawns a gRPC server for the page service. +/// +/// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we +/// need to reimplement the TCP+TLS accept loop ourselves. +pub fn spawn_grpc( + conf: &'static PageServerConf, + tenant_manager: Arc, + auth: Option>, + perf_trace_dispatch: Option, + listener: std::net::TcpListener, + basebackup_cache: Arc, +) -> anyhow::Result { + let cancel = CancellationToken::new(); + let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler) + .download_behavior(DownloadBehavior::Download) + .perf_span_dispatch(perf_trace_dispatch) + .detached_child(); + let gate = Gate::default(); + + // Set up the TCP socket. We take a preconfigured TcpListener to bind the + // port early during startup. + let incoming = { + let _runtime = COMPUTE_REQUEST_RUNTIME.enter(); // required by TcpListener::from_std + listener.set_nonblocking(true)?; + tonic::transport::server::TcpIncoming::from(tokio::net::TcpListener::from_std(listener)?) + .with_nodelay(Some(GRPC_TCP_NODELAY)) + .with_keepalive(Some(GRPC_TCP_KEEPALIVE_TIME)) + }; + + // Set up the gRPC server. + // + // TODO: consider tuning window sizes. + // TODO: wire up tracing. + let mut server = tonic::transport::Server::builder() + .http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL)) + .http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT)) + .max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS)); + + // Main page service. + let page_service_handler = PageServerHandler::new( + tenant_manager, + auth.clone(), + PageServicePipeliningConfig::Serial, // TODO: unused with gRPC + conf.get_vectored_concurrent_io, + ConnectionPerfSpanFields::default(), + basebackup_cache, + ctx, + cancel.clone(), + gate.enter().expect("just created"), + ); + + let mut tenant_interceptor = TenantMetadataInterceptor; + let mut auth_interceptor = TenantAuthInterceptor::new(auth); + let interceptors = move |mut req: tonic::Request<()>| { + req = tenant_interceptor.call(req)?; + req = auth_interceptor.call(req)?; + Ok(req) + }; + + let page_service = + proto::PageServiceServer::with_interceptor(page_service_handler, interceptors); + let server = server.add_service(page_service); + + // Reflection service for use with e.g. grpcurl. + let reflection_service = tonic_reflection::server::Builder::configure() + .register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET) + .build_v1()?; + let server = server.add_service(reflection_service); + + // Spawn server task. + let task_cancel = cancel.clone(); + let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error( + "grpc listener", + async move { + let result = server + .serve_with_incoming_shutdown(incoming, task_cancel.cancelled()) + .await; + if result.is_ok() { + // TODO: revisit shutdown logic once page service is implemented. + gate.close().await; + } + result + }, + )); + + Ok(CancellableTask { task, cancel }) +} + impl Listener { pub async fn stop_accepting(self) -> Connections { self.cancel.cancel(); @@ -259,7 +370,7 @@ type ConnectionHandlerResult = anyhow::Result<()>; /// Perf root spans start at the per-request level, after shard routing. /// This struct carries connection-level information to the root perf span definition. -#[derive(Clone)] +#[derive(Clone, Default)] struct ConnectionPerfSpanFields { peer_addr: String, application_name: Option, @@ -377,6 +488,11 @@ async fn page_service_conn_main( } } +/// Page service connection handler. +/// +/// TODO: for gRPC, this will be shared by all requests from all connections. +/// Decompose it into global state and per-connection/request state, and make +/// libpq-specific options (e.g. pipelining) separate. struct PageServerHandler { auth: Option>, claims: Option, @@ -653,6 +769,9 @@ struct BatchedGetPageRequest { timer: SmgrOpTimer, lsn_range: LsnRange, ctx: RequestContext, + // If the request is perf enabled, this contains a context + // with a perf span tracking the time spent waiting for the executor. + batch_wait_ctx: Option, } #[cfg(feature = "testing")] @@ -665,6 +784,7 @@ struct BatchedTestRequest { /// so that we don't keep the [`Timeline::gate`] open while the batch /// is being built up inside the [`spsc_fold`] (pagestream pipelining). #[derive(IntoStaticStr)] +#[allow(clippy::large_enum_variant)] enum BatchedFeMessage { Exists { span: Span, @@ -1182,6 +1302,22 @@ impl PageServerHandler { } }; + let batch_wait_ctx = if ctx.has_perf_span() { + Some( + RequestContextBuilder::from(&ctx) + .perf_span(|crnt_perf_span| { + info_span!( + target: PERF_TRACE_TARGET, + parent: crnt_perf_span, + "WAIT_EXECUTOR", + ) + }) + .attached_child(), + ) + } else { + None + }; + BatchedFeMessage::GetPage { span, shard: shard.downgrade(), @@ -1193,6 +1329,7 @@ impl PageServerHandler { request_lsn: req.hdr.request_lsn }, ctx, + batch_wait_ctx, }], // The executor grabs the batch when it becomes idle. // Hence, [`GetPageBatchBreakReason::ExecutorSteal`] is the @@ -1348,7 +1485,7 @@ impl PageServerHandler { let mut flush_timers = Vec::with_capacity(handler_results.len()); for handler_result in &mut handler_results { let flush_timer = match handler_result { - Ok((_, timer)) => Some( + Ok((_response, timer, _ctx)) => Some( timer .observe_execution_end(flushing_start_time) .expect("we are the first caller"), @@ -1368,7 +1505,7 @@ impl PageServerHandler { // Some handler errors cause exit from pagestream protocol. // Other handler errors are sent back as an error message and we stay in pagestream protocol. for (handler_result, flushing_timer) in handler_results.into_iter().zip(flush_timers) { - let response_msg = match handler_result { + let (response_msg, ctx) = match handler_result { Err(e) => match &e.err { PageStreamError::Shutdown => { // If we fail to fulfil a request during shutdown, which may be _because_ of @@ -1393,15 +1530,30 @@ impl PageServerHandler { error!("error reading relation or page version: {full:#}") }); - PagestreamBeMessage::Error(PagestreamErrorResponse { - req: e.req, - message: e.err.to_string(), - }) + ( + PagestreamBeMessage::Error(PagestreamErrorResponse { + req: e.req, + message: e.err.to_string(), + }), + None, + ) } }, - Ok((response_msg, _op_timer_already_observed)) => response_msg, + Ok((response_msg, _op_timer_already_observed, ctx)) => (response_msg, Some(ctx)), }; + let ctx = ctx.map(|req_ctx| { + RequestContextBuilder::from(&req_ctx) + .perf_span(|crnt_perf_span| { + info_span!( + target: PERF_TRACE_TARGET, + parent: crnt_perf_span, + "FLUSH_RESPONSE", + ) + }) + .attached_child() + }); + // // marshal & transmit response message // @@ -1424,6 +1576,17 @@ impl PageServerHandler { )), None => futures::future::Either::Right(flush_fut), }; + + let flush_fut = if let Some(req_ctx) = ctx.as_ref() { + futures::future::Either::Left( + flush_fut.maybe_perf_instrument(req_ctx, |current_perf_span| { + current_perf_span.clone() + }), + ) + } else { + futures::future::Either::Right(flush_fut) + }; + // do it while respecting cancellation let _: () = async move { tokio::select! { @@ -1453,7 +1616,7 @@ impl PageServerHandler { ctx: &RequestContext, ) -> Result< ( - Vec>, + Vec>, Span, ), QueryError, @@ -1480,7 +1643,7 @@ impl PageServerHandler { self.handle_get_rel_exists_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer)) + .map(|msg| (msg, timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1499,7 +1662,7 @@ impl PageServerHandler { self.handle_get_nblocks_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer)) + .map(|msg| (msg, timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1546,7 +1709,7 @@ impl PageServerHandler { self.handle_db_size_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer)) + .map(|msg| (msg, timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1565,7 +1728,7 @@ impl PageServerHandler { self.handle_get_slru_segment_request(&shard, &req, &ctx) .instrument(span.clone()) .await - .map(|msg| (msg, timer)) + .map(|msg| (msg, timer, ctx)) .map_err(|err| BatchedPageStreamError { err, req: req.hdr }), ], span, @@ -1917,12 +2080,25 @@ impl PageServerHandler { return Ok(()); } }; - let batch = match batch { + let mut batch = match batch { Ok(batch) => batch, Err(e) => { return Err(e); } }; + + if let BatchedFeMessage::GetPage { + pages, + span: _, + shard: _, + batch_break_reason: _, + } = &mut batch + { + for req in pages { + req.batch_wait_ctx.take(); + } + } + self.pagestream_handle_batched_message( pgb_writer, batch, @@ -2235,7 +2411,8 @@ impl PageServerHandler { io_concurrency: IoConcurrency, batch_break_reason: GetPageBatchBreakReason, ctx: &RequestContext, - ) -> Vec> { + ) -> Vec> + { debug_assert_current_span_has_tenant_and_timeline_id(); timeline @@ -2342,6 +2519,7 @@ impl PageServerHandler { page, }), req.timer, + req.ctx, ) }) .map_err(|e| BatchedPageStreamError { @@ -2386,7 +2564,8 @@ impl PageServerHandler { timeline: &Timeline, requests: Vec, _ctx: &RequestContext, - ) -> Vec> { + ) -> Vec> + { // real requests would do something with the timeline let mut results = Vec::with_capacity(requests.len()); for _req in requests.iter() { @@ -2413,6 +2592,10 @@ impl PageServerHandler { req: req.req.clone(), }), req.timer, + RequestContext::new( + TaskKind::PageRequestHandler, + DownloadBehavior::Warn, + ), ) }) .map_err(|e| BatchedPageStreamError { @@ -3117,6 +3300,60 @@ where } } +/// Implements the page service over gRPC. +/// +/// TODO: not yet implemented, all methods return unimplemented. +#[tonic::async_trait] +impl proto::PageService for PageServerHandler { + type GetBaseBackupStream = Pin< + Box> + Send>, + >; + type GetPagesStream = + Pin> + Send>>; + + async fn check_rel_exists( + &self, + _: tonic::Request, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented")) + } + + async fn get_base_backup( + &self, + _: tonic::Request, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented")) + } + + async fn get_db_size( + &self, + _: tonic::Request, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented")) + } + + async fn get_pages( + &self, + _: tonic::Request>, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented")) + } + + async fn get_rel_size( + &self, + _: tonic::Request, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented")) + } + + async fn get_slru_segment( + &self, + _: tonic::Request, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented")) + } +} + impl From for QueryError { fn from(e: GetActiveTenantError) -> Self { match e { @@ -3133,6 +3370,104 @@ impl From for QueryError { } } +/// gRPC interceptor that decodes tenant metadata and stores it as request extensions of type +/// TenantTimelineId and ShardIndex. +/// +/// TODO: consider looking up the timeline handle here and storing it. +#[derive(Clone)] +struct TenantMetadataInterceptor; + +impl tonic::service::Interceptor for TenantMetadataInterceptor { + fn call(&mut self, mut req: tonic::Request<()>) -> Result, tonic::Status> { + // Decode the tenant ID. + let tenant_id = req + .metadata() + .get("neon-tenant-id") + .ok_or_else(|| tonic::Status::invalid_argument("missing neon-tenant-id"))? + .to_str() + .map_err(|_| tonic::Status::invalid_argument("invalid neon-tenant-id"))?; + let tenant_id = TenantId::from_str(tenant_id) + .map_err(|_| tonic::Status::invalid_argument("invalid neon-tenant-id"))?; + + // Decode the timeline ID. + let timeline_id = req + .metadata() + .get("neon-timeline-id") + .ok_or_else(|| tonic::Status::invalid_argument("missing neon-timeline-id"))? + .to_str() + .map_err(|_| tonic::Status::invalid_argument("invalid neon-timeline-id"))?; + let timeline_id = TimelineId::from_str(timeline_id) + .map_err(|_| tonic::Status::invalid_argument("invalid neon-timeline-id"))?; + + // Decode the shard ID. + let shard_index = req + .metadata() + .get("neon-shard-id") + .ok_or_else(|| tonic::Status::invalid_argument("missing neon-shard-id"))? + .to_str() + .map_err(|_| tonic::Status::invalid_argument("invalid neon-shard-id"))?; + let shard_index = ShardIndex::from_str(shard_index) + .map_err(|_| tonic::Status::invalid_argument("invalid neon-shard-id"))?; + + // Stash them in the request. + let extensions = req.extensions_mut(); + extensions.insert(TenantTimelineId::new(tenant_id, timeline_id)); + extensions.insert(shard_index); + + Ok(req) + } +} + +/// Authenticates gRPC page service requests. Must run after TenantMetadataInterceptor. +#[derive(Clone)] +struct TenantAuthInterceptor { + auth: Option>, +} + +impl TenantAuthInterceptor { + fn new(auth: Option>) -> Self { + Self { auth } + } +} + +impl tonic::service::Interceptor for TenantAuthInterceptor { + fn call(&mut self, req: tonic::Request<()>) -> Result, tonic::Status> { + // Do nothing if auth is disabled. + let Some(auth) = self.auth.as_ref() else { + return Ok(req); + }; + + // Fetch the tenant ID that's been set by TenantMetadataInterceptor. + let ttid = req + .extensions() + .get::() + .expect("TenantMetadataInterceptor must run before TenantAuthInterceptor"); + + // Fetch and decode the JWT token. + let jwt = req + .metadata() + .get("authorization") + .ok_or_else(|| tonic::Status::unauthenticated("no authorization header"))? + .to_str() + .map_err(|_| tonic::Status::invalid_argument("invalid authorization header"))? + .strip_prefix("Bearer ") + .ok_or_else(|| tonic::Status::invalid_argument("invalid authorization header"))? + .trim(); + let jwtdata: TokenData = auth + .decode(jwt) + .map_err(|err| tonic::Status::invalid_argument(format!("invalid JWT token: {err}")))?; + let claims = jwtdata.claims; + + // Check if the token is valid for this tenant. + check_permission(&claims, Some(ttid.tenant_id)) + .map_err(|err| tonic::Status::permission_denied(err.to_string()))?; + + // TODO: consider stashing the claims in the request extensions, if needed. + + Ok(req) + } +} + #[derive(Debug, thiserror::Error)] pub(crate) enum GetActiveTimelineError { #[error(transparent)] diff --git a/pageserver/src/task_mgr.rs b/pageserver/src/task_mgr.rs index 55272b2125..29897af642 100644 --- a/pageserver/src/task_mgr.rs +++ b/pageserver/src/task_mgr.rs @@ -276,9 +276,10 @@ pub enum TaskKind { // HTTP endpoint listener. HttpEndpointListener, - // Task that handles a single connection. A PageRequestHandler task - // starts detached from any particular tenant or timeline, but it can be - // associated with one later, after receiving a command from the client. + /// Task that handles a single page service connection. A PageRequestHandler + /// task starts detached from any particular tenant or timeline, but it can + /// be associated with one later, after receiving a command from the client. + /// Also used for the gRPC page service API, including the main server task. PageRequestHandler, /// Manages the WAL receiver connection for one timeline. diff --git a/pageserver/src/tenant.rs b/pageserver/src/tenant.rs index bf3f71e35a..58b766933d 100644 --- a/pageserver/src/tenant.rs +++ b/pageserver/src/tenant.rs @@ -84,6 +84,7 @@ use crate::context; use crate::context::RequestContextBuilder; use crate::context::{DownloadBehavior, RequestContext}; use crate::deletion_queue::{DeletionQueueClient, DeletionQueueError}; +use crate::feature_resolver::FeatureResolver; use crate::l0_flush::L0FlushGlobalState; use crate::metrics::{ BROKEN_TENANTS_SET, CIRCUIT_BREAKERS_BROKEN, CIRCUIT_BREAKERS_UNBROKEN, CONCURRENT_INITDBS, @@ -159,6 +160,7 @@ pub struct TenantSharedResources { pub deletion_queue_client: DeletionQueueClient, pub l0_flush_global_state: L0FlushGlobalState, pub basebackup_prepare_sender: BasebackupPrepareSender, + pub feature_resolver: FeatureResolver, } /// A [`TenantShard`] is really an _attached_ tenant. The configuration @@ -380,6 +382,8 @@ pub struct TenantShard { pub(crate) gc_block: gc_block::GcBlock, l0_flush_global_state: L0FlushGlobalState, + + feature_resolver: FeatureResolver, } impl std::fmt::Debug for TenantShard { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -860,6 +864,14 @@ impl Debug for SetStoppingError { } } +#[derive(thiserror::Error, Debug)] +pub(crate) enum FinalizeTimelineImportError { + #[error("Import task not done yet")] + ImportTaskStillRunning, + #[error("Shutting down")] + ShuttingDown, +} + /// Arguments to [`TenantShard::create_timeline`]. /// /// Not usable as an idempotency key for timeline creation because if [`CreateTimelineParamsBranch::ancestor_start_lsn`] @@ -1146,10 +1158,20 @@ impl TenantShard { ctx, )?; let disk_consistent_lsn = timeline.get_disk_consistent_lsn(); - anyhow::ensure!( - disk_consistent_lsn.is_valid(), - "Timeline {tenant_id}/{timeline_id} has invalid disk_consistent_lsn" - ); + + if !disk_consistent_lsn.is_valid() { + // As opposed to normal timelines which get initialised with a disk consitent LSN + // via initdb, imported timelines start from 0. If the import task stops before + // it advances disk consitent LSN, allow it to resume. + let in_progress_import = import_pgdata + .as_ref() + .map(|import| !import.is_done()) + .unwrap_or(false); + if !in_progress_import { + anyhow::bail!("Timeline {tenant_id}/{timeline_id} has invalid disk_consistent_lsn"); + } + } + assert_eq!( disk_consistent_lsn, metadata.disk_consistent_lsn(), @@ -1243,20 +1265,25 @@ impl TenantShard { } } - // Sanity check: a timeline should have some content. - anyhow::ensure!( - ancestor.is_some() - || timeline - .layers - .read() - .await - .layer_map() - .expect("currently loading, layer manager cannot be shutdown already") - .iter_historic_layers() - .next() - .is_some(), - "Timeline has no ancestor and no layer files" - ); + if disk_consistent_lsn.is_valid() { + // Sanity check: a timeline should have some content. + // Exception: importing timelines might not yet have any + anyhow::ensure!( + ancestor.is_some() + || timeline + .layers + .read() + .await + .layer_map() + .expect( + "currently loading, layer manager cannot be shutdown already" + ) + .iter_historic_layers() + .next() + .is_some(), + "Timeline has no ancestor and no layer files" + ); + } Ok(TimelineInitAndSyncResult::ReadyToActivate) } @@ -1292,6 +1319,7 @@ impl TenantShard { deletion_queue_client, l0_flush_global_state, basebackup_prepare_sender, + feature_resolver, } = resources; let attach_mode = attached_conf.location.attach_mode; @@ -1308,6 +1336,7 @@ impl TenantShard { deletion_queue_client, l0_flush_global_state, basebackup_prepare_sender, + feature_resolver, )); // The attach task will carry a GateGuard, so that shutdown() reliably waits for it to drop out if @@ -2854,13 +2883,13 @@ impl TenantShard { pub(crate) async fn finalize_importing_timeline( &self, timeline_id: TimelineId, - ) -> anyhow::Result<()> { + ) -> Result<(), FinalizeTimelineImportError> { let timeline = { let locked = self.timelines_importing.lock().unwrap(); match locked.get(&timeline_id) { Some(importing_timeline) => { if !importing_timeline.import_task_handle.is_finished() { - return Err(anyhow::anyhow!("Import task not done yet")); + return Err(FinalizeTimelineImportError::ImportTaskStillRunning); } importing_timeline.timeline.clone() @@ -2873,8 +2902,13 @@ impl TenantShard { timeline .remote_client - .schedule_index_upload_for_import_pgdata_finalize()?; - timeline.remote_client.wait_completion().await?; + .schedule_index_upload_for_import_pgdata_finalize() + .map_err(|_err| FinalizeTimelineImportError::ShuttingDown)?; + timeline + .remote_client + .wait_completion() + .await + .map_err(|_err| FinalizeTimelineImportError::ShuttingDown)?; self.timelines_importing .lock() @@ -3135,11 +3169,18 @@ impl TenantShard { .or_insert_with(|| Arc::new(GcCompactionQueue::new())) .clone() }; + let gc_compaction_strategy = self + .feature_resolver + .evaluate_multivariate("gc-comapction-strategy", self.tenant_shard_id.tenant_id) + .ok(); + let span = if let Some(gc_compaction_strategy) = gc_compaction_strategy { + info_span!("gc_compact_timeline", timeline_id = %timeline.timeline_id, strategy = %gc_compaction_strategy) + } else { + info_span!("gc_compact_timeline", timeline_id = %timeline.timeline_id) + }; outcome = queue .iteration(cancel, ctx, &self.gc_block, &timeline) - .instrument( - info_span!("gc_compact_timeline", timeline_id = %timeline.timeline_id), - ) + .instrument(span) .await?; } @@ -3471,8 +3512,9 @@ impl TenantShard { let mut timelines_importing = self.timelines_importing.lock().unwrap(); timelines_importing .drain() - .for_each(|(_timeline_id, importing_timeline)| { - importing_timeline.shutdown(); + .for_each(|(timeline_id, importing_timeline)| { + let span = tracing::info_span!("importing_timeline_shutdown", %timeline_id); + js.spawn(async move { importing_timeline.shutdown().instrument(span).await }); }); } // test_long_timeline_create_then_tenant_delete is leaning on this message @@ -4247,6 +4289,7 @@ impl TenantShard { deletion_queue_client: DeletionQueueClient, l0_flush_global_state: L0FlushGlobalState, basebackup_prepare_sender: BasebackupPrepareSender, + feature_resolver: FeatureResolver, ) -> TenantShard { assert!(!attached_conf.location.generation.is_none()); @@ -4351,6 +4394,7 @@ impl TenantShard { gc_block: Default::default(), l0_flush_global_state, basebackup_prepare_sender, + feature_resolver, } } @@ -5271,6 +5315,7 @@ impl TenantShard { l0_compaction_trigger: self.l0_compaction_trigger.clone(), l0_flush_global_state: self.l0_flush_global_state.clone(), basebackup_prepare_sender: self.basebackup_prepare_sender.clone(), + feature_resolver: self.feature_resolver.clone(), } } @@ -5873,6 +5918,7 @@ pub(crate) mod harness { // TODO: ideally we should run all unit tests with both configs L0FlushGlobalState::new(L0FlushConfig::default()), basebackup_requst_sender, + FeatureResolver::new_disabled(), )); let preload = tenant @@ -8314,10 +8360,24 @@ mod tests { } tline.freeze_and_flush().await?; + // Force layers to L1 + tline + .compact( + &cancel, + { + let mut flags = EnumSet::new(); + flags.insert(CompactFlags::ForceL0Compaction); + flags + }, + &ctx, + ) + .await?; if iter % 5 == 0 { + let scan_lsn = Lsn(lsn.0 + 1); + info!("scanning at {}", scan_lsn); let (_, before_delta_file_accessed) = - scan_with_statistics(&tline, &keyspace, lsn, &ctx, io_concurrency.clone()) + scan_with_statistics(&tline, &keyspace, scan_lsn, &ctx, io_concurrency.clone()) .await?; tline .compact( @@ -8326,13 +8386,14 @@ mod tests { let mut flags = EnumSet::new(); flags.insert(CompactFlags::ForceImageLayerCreation); flags.insert(CompactFlags::ForceRepartition); + flags.insert(CompactFlags::ForceL0Compaction); flags }, &ctx, ) .await?; let (_, after_delta_file_accessed) = - scan_with_statistics(&tline, &keyspace, lsn, &ctx, io_concurrency.clone()) + scan_with_statistics(&tline, &keyspace, scan_lsn, &ctx, io_concurrency.clone()) .await?; assert!( after_delta_file_accessed < before_delta_file_accessed, @@ -8773,6 +8834,8 @@ mod tests { let cancel = CancellationToken::new(); + // Image layer creation happens on the disk_consistent_lsn so we need to force set it now. + tline.force_set_disk_consistent_lsn(Lsn(0x40)); tline .compact( &cancel, @@ -8786,8 +8849,7 @@ mod tests { ) .await .unwrap(); - - // Image layers are created at last_record_lsn + // Image layers are created at repartition LSN let images = tline .inspect_image_layers(Lsn(0x40), &ctx, io_concurrency.clone()) .await diff --git a/pageserver/src/tenant/timeline.rs b/pageserver/src/tenant/timeline.rs index 54dc3b2d0b..71765b9197 100644 --- a/pageserver/src/tenant/timeline.rs +++ b/pageserver/src/tenant/timeline.rs @@ -103,6 +103,7 @@ use crate::context::{ DownloadBehavior, PerfInstrumentFutureExt, RequestContext, RequestContextBuilder, }; use crate::disk_usage_eviction_task::{DiskUsageEvictionInfo, EvictionCandidate, finite_f32}; +use crate::feature_resolver::FeatureResolver; use crate::keyspace::{KeyPartitioning, KeySpace}; use crate::l0_flush::{self, L0FlushGlobalState}; use crate::metrics::{ @@ -198,6 +199,7 @@ pub struct TimelineResources { pub l0_compaction_trigger: Arc, pub l0_flush_global_state: l0_flush::L0FlushGlobalState, pub basebackup_prepare_sender: BasebackupPrepareSender, + pub feature_resolver: FeatureResolver, } pub struct Timeline { @@ -444,6 +446,8 @@ pub struct Timeline { /// A channel to send async requests to prepare a basebackup for the basebackup cache. basebackup_prepare_sender: BasebackupPrepareSender, + + feature_resolver: FeatureResolver, } pub(crate) enum PreviousHeatmap { @@ -3072,6 +3076,8 @@ impl Timeline { wait_lsn_log_slow: tokio::sync::Semaphore::new(1), basebackup_prepare_sender: resources.basebackup_prepare_sender, + + feature_resolver: resources.feature_resolver, }; result.repartition_threshold = @@ -4906,6 +4912,7 @@ impl Timeline { LastImageLayerCreationStatus::Initial, false, // don't yield for L0, we're flushing L0 ) + .instrument(info_span!("create_image_layers", mode = %ImageLayerCreationMode::Initial, partition_mode = "initial", lsn = %self.initdb_lsn)) .await?; debug_assert!( matches!(is_complete, LastImageLayerCreationStatus::Complete), @@ -5462,7 +5469,8 @@ impl Timeline { /// Returns the image layers generated and an enum indicating whether the process is fully completed. /// true = we have generate all image layers, false = we preempt the process for L0 compaction. - #[tracing::instrument(skip_all, fields(%lsn, %mode))] + /// + /// `partition_mode` is only for logging purpose and is not used anywhere in this function. async fn create_image_layers( self: &Arc, partitioning: &KeyPartitioning, diff --git a/pageserver/src/tenant/timeline/compaction.rs b/pageserver/src/tenant/timeline/compaction.rs index 0e4b14c3e4..143c2e0865 100644 --- a/pageserver/src/tenant/timeline/compaction.rs +++ b/pageserver/src/tenant/timeline/compaction.rs @@ -1278,11 +1278,55 @@ impl Timeline { } let gc_cutoff = *self.applied_gc_cutoff_lsn.read(); + let l0_l1_boundary_lsn = { + // We do the repartition on the L0-L1 boundary. All data below the boundary + // are compacted by L0 with low read amplification, thus making the `repartition` + // function run fast. + let guard = self.layers.read().await; + guard + .all_persistent_layers() + .iter() + .map(|x| { + // Use the end LSN of delta layers OR the start LSN of image layers. + if x.is_delta { + x.lsn_range.end + } else { + x.lsn_range.start + } + }) + .max() + }; + + let (partition_mode, partition_lsn) = if cfg!(test) + || cfg!(feature = "testing") + || self + .feature_resolver + .evaluate_boolean("image-compaction-boundary", self.tenant_shard_id.tenant_id) + .is_ok() + { + let last_repartition_lsn = self.partitioning.read().1; + let lsn = match l0_l1_boundary_lsn { + Some(boundary) => gc_cutoff + .max(boundary) + .max(last_repartition_lsn) + .max(self.initdb_lsn) + .max(self.ancestor_lsn), + None => self.get_last_record_lsn(), + }; + if lsn <= self.initdb_lsn || lsn <= self.ancestor_lsn { + // Do not attempt to create image layers below the initdb or ancestor LSN -- no data below it + ("l0_l1_boundary", self.get_last_record_lsn()) + } else { + ("l0_l1_boundary", lsn) + } + } else { + ("latest_record", self.get_last_record_lsn()) + }; // 2. Repartition and create image layers if necessary match self .repartition( - self.get_last_record_lsn(), + partition_lsn, self.get_compaction_target_size(), options.flags, ctx, @@ -1301,18 +1345,19 @@ impl Timeline { .extend(sparse_partitioning.into_dense().parts); // 3. Create new image layers for partitions that have been modified "enough". + let mode = if options + .flags + .contains(CompactFlags::ForceImageLayerCreation) + { + ImageLayerCreationMode::Force + } else { + ImageLayerCreationMode::Try + }; let (image_layers, outcome) = self .create_image_layers( &partitioning, lsn, - if options - .flags - .contains(CompactFlags::ForceImageLayerCreation) - { - ImageLayerCreationMode::Force - } else { - ImageLayerCreationMode::Try - }, + mode, &image_ctx, self.last_image_layer_creation_status .load() @@ -1320,6 +1365,7 @@ impl Timeline { .clone(), options.flags.contains(CompactFlags::YieldForL0), ) + .instrument(info_span!("create_image_layers", mode = %mode, partition_mode = %partition_mode, lsn = %lsn)) .await .inspect_err(|err| { if let CreateImageLayersError::GetVectoredError( @@ -1344,7 +1390,8 @@ impl Timeline { } Ok(_) => { - info!("skipping repartitioning due to image compaction LSN being below GC cutoff"); + // This happens very frequently so we don't want to log it. + debug!("skipping repartitioning due to image compaction LSN being below GC cutoff"); } // Suppress errors when cancelled. diff --git a/pageserver/src/tenant/timeline/import_pgdata.rs b/pageserver/src/tenant/timeline/import_pgdata.rs index 658d867c18..db62e9000c 100644 --- a/pageserver/src/tenant/timeline/import_pgdata.rs +++ b/pageserver/src/tenant/timeline/import_pgdata.rs @@ -25,8 +25,11 @@ pub(crate) struct ImportingTimeline { } impl ImportingTimeline { - pub(crate) fn shutdown(self) { + pub(crate) async fn shutdown(self) { self.import_task_handle.abort(); + let _ = self.import_task_handle.await; + + self.timeline.remote_client.shutdown().await; } } @@ -93,6 +96,11 @@ pub async fn doit( ); } + timeline + .remote_client + .schedule_index_upload_for_file_changes()?; + timeline.remote_client.wait_completion().await?; + // Communicate that shard is done. // Ensure at-least-once delivery of the upcall to storage controller // before we mark the task as done and never come here again. diff --git a/pageserver/src/tenant/timeline/import_pgdata/flow.rs b/pageserver/src/tenant/timeline/import_pgdata/flow.rs index 3e10a4e6d6..2ba4ca69ac 100644 --- a/pageserver/src/tenant/timeline/import_pgdata/flow.rs +++ b/pageserver/src/tenant/timeline/import_pgdata/flow.rs @@ -113,14 +113,14 @@ async fn run_v1( let plan_hash = hasher.finish(); if let Some(progress) = &import_progress { - if plan_hash != progress.import_plan_hash { - anyhow::bail!("Import plan does not match storcon metadata"); - } - // Handle collisions on jobs of unequal length if progress.jobs != plan.jobs.len() { anyhow::bail!("Import plan job length does not match storcon metadata") } + + if plan_hash != progress.import_plan_hash { + anyhow::bail!("Import plan does not match storcon metadata"); + } } pausable_failpoint!("import-timeline-pre-execute-pausable"); @@ -218,6 +218,19 @@ impl Planner { checkpoint_buf, ))); + // Sort the tasks by the key ranges they handle. + // The plan being generated here needs to be stable across invocations + // of this method. + self.tasks.sort_by_key(|task| match task { + AnyImportTask::SingleKey(key) => (key.key, key.key.next()), + AnyImportTask::RelBlocks(rel_blocks) => { + (rel_blocks.key_range.start, rel_blocks.key_range.end) + } + AnyImportTask::SlruBlocks(slru_blocks) => { + (slru_blocks.key_range.start, slru_blocks.key_range.end) + } + }); + // Assigns parts of key space to later parallel jobs let mut last_end_key = Key::MIN; let mut current_chunk = Vec::new(); @@ -426,6 +439,8 @@ impl Plan { })); }, maybe_complete_job_idx = work.next() => { + pausable_failpoint!("import-task-complete-pausable"); + match maybe_complete_job_idx { Some(Ok((job_idx, res))) => { assert!(last_completed_job_idx.checked_add(1).unwrap() == job_idx); @@ -440,6 +455,9 @@ impl Plan { import_plan_hash, }; + timeline.remote_client.schedule_index_upload_for_file_changes()?; + timeline.remote_client.wait_completion().await?; + storcon_client.put_timeline_import_status( timeline.tenant_shard_id, timeline.timeline_id, @@ -640,7 +658,11 @@ impl Hash for ImportSingleKeyTask { let ImportSingleKeyTask { key, buf } = self; key.hash(state); - buf.hash(state); + // The key value might not have a stable binary representation. + // For instance, the db directory uses an unstable hash-map. + // To work around this we are a bit lax here and only hash the + // size of the buffer which must be consistent. + buf.len().hash(state); } } @@ -915,7 +937,7 @@ impl ChunkProcessingJob { let guard = timeline.layers.read().await; let existing_layer = guard.try_get_from_key(&desc.key()); if let Some(layer) = existing_layer { - if layer.metadata().generation != timeline.generation { + if layer.metadata().generation == timeline.generation { return Err(anyhow::anyhow!( "Import attempted to rewrite layer file in the same generation: {}", layer.local_path() diff --git a/pgxn/neon/communicator.c b/pgxn/neon/communicator.c index 9609f186b9..2655a45bcc 100644 --- a/pgxn/neon/communicator.c +++ b/pgxn/neon/communicator.c @@ -717,7 +717,7 @@ prefetch_read(PrefetchRequest *slot) Assert(slot->status == PRFS_REQUESTED); Assert(slot->response == NULL); Assert(slot->my_ring_index == MyPState->ring_receive); - Assert(readpage_reentrant_guard); + Assert(readpage_reentrant_guard || AmPrewarmWorker); if (slot->status != PRFS_REQUESTED || slot->response != NULL || @@ -800,7 +800,7 @@ communicator_prefetch_receive(BufferTag tag) PrfHashEntry *entry; PrefetchRequest hashkey; - Assert(readpage_reentrant_guard); + Assert(readpage_reentrant_guard || AmPrewarmWorker); /* do not pump prefetch state in prewarm worker */ hashkey.buftag = tag; entry = prfh_lookup(MyPState->prf_hash, &hashkey); if (entry != NULL && prefetch_wait_for(entry->slot->my_ring_index)) @@ -2450,6 +2450,7 @@ void communicator_reconfigure_timeout_if_needed(void) { bool needs_set = MyPState->ring_receive != MyPState->ring_unused && + !AmPrewarmWorker && /* do not pump prefetch state in prewarm worker */ readahead_getpage_pull_timeout_ms > 0; if (needs_set != timeout_set) diff --git a/pgxn/neon/file_cache.c b/pgxn/neon/file_cache.c index 176fd9643f..45a4695495 100644 --- a/pgxn/neon/file_cache.c +++ b/pgxn/neon/file_cache.c @@ -201,6 +201,8 @@ static shmem_request_hook_type prev_shmem_request_hook; bool lfc_store_prefetch_result; bool lfc_prewarm_update_ws_estimation; +bool AmPrewarmWorker; + #define LFC_ENABLED() (lfc_ctl->limit != 0) /* @@ -845,6 +847,8 @@ lfc_prewarm_main(Datum main_arg) PrewarmWorkerState* ws; uint32 worker_id = DatumGetInt32(main_arg); + AmPrewarmWorker = true; + pqsignal(SIGTERM, die); BackgroundWorkerUnblockSignals(); diff --git a/pgxn/neon/neon.h b/pgxn/neon/neon.h index a2e81feb5f..431dacb708 100644 --- a/pgxn/neon/neon.h +++ b/pgxn/neon/neon.h @@ -23,6 +23,8 @@ extern int wal_acceptor_connection_timeout; extern int readahead_getpage_pull_timeout_ms; extern bool disable_wal_prev_lsn_checks; +extern bool AmPrewarmWorker; + #if PG_MAJORVERSION_NUM >= 17 extern uint32 WAIT_EVENT_NEON_LFC_MAINTENANCE; extern uint32 WAIT_EVENT_NEON_LFC_READ; diff --git a/pgxn/neon/walproposer.c b/pgxn/neon/walproposer.c index 3befb42030..f42103c7cd 100644 --- a/pgxn/neon/walproposer.c +++ b/pgxn/neon/walproposer.c @@ -155,8 +155,9 @@ WalProposerCreate(WalProposerConfig *config, walproposer_api api) int written = 0; written = snprintf((char *) &sk->conninfo, MAXCONNINFO, - "host=%s port=%s dbname=replication options='-c timeline_id=%s tenant_id=%s'", - sk->host, sk->port, wp->config->neon_timeline, wp->config->neon_tenant); + "%s host=%s port=%s dbname=replication options='-c timeline_id=%s tenant_id=%s'", + wp->config->safekeeper_conninfo_options, sk->host, sk->port, + wp->config->neon_timeline, wp->config->neon_tenant); if (written > MAXCONNINFO || written < 0) wp_log(FATAL, "could not create connection string for safekeeper %s:%s", sk->host, sk->port); } diff --git a/pgxn/neon/walproposer.h b/pgxn/neon/walproposer.h index 83ef72d3d7..cca20e746b 100644 --- a/pgxn/neon/walproposer.h +++ b/pgxn/neon/walproposer.h @@ -714,6 +714,9 @@ typedef struct WalProposerConfig */ char *safekeepers_list; + /* libpq connection info options. */ + char *safekeeper_conninfo_options; + /* * WalProposer reconnects to offline safekeepers once in this interval. * Time is in milliseconds. diff --git a/pgxn/neon/walproposer_pg.c b/pgxn/neon/walproposer_pg.c index 17582405db..d15bf91d24 100644 --- a/pgxn/neon/walproposer_pg.c +++ b/pgxn/neon/walproposer_pg.c @@ -64,6 +64,7 @@ char *wal_acceptors_list = ""; int wal_acceptor_reconnect_timeout = 1000; int wal_acceptor_connection_timeout = 10000; int safekeeper_proto_version = 3; +char *safekeeper_conninfo_options = ""; /* Set to true in the walproposer bgw. */ static bool am_walproposer; @@ -119,6 +120,7 @@ init_walprop_config(bool syncSafekeepers) walprop_config.neon_timeline = neon_timeline; /* WalProposerCreate scribbles directly on it, so pstrdup */ walprop_config.safekeepers_list = pstrdup(wal_acceptors_list); + walprop_config.safekeeper_conninfo_options = pstrdup(safekeeper_conninfo_options); walprop_config.safekeeper_reconnect_timeout = wal_acceptor_reconnect_timeout; walprop_config.safekeeper_connection_timeout = wal_acceptor_connection_timeout; walprop_config.wal_segment_size = wal_segment_size; @@ -203,6 +205,16 @@ nwp_register_gucs(void) * GUC_LIST_QUOTE */ NULL, assign_neon_safekeepers, NULL); + DefineCustomStringVariable( + "neon.safekeeper_conninfo_options", + "libpq keyword parameters and values to apply to safekeeper connections", + NULL, + &safekeeper_conninfo_options, + "", + PGC_POSTMASTER, + 0, + NULL, NULL, NULL); + DefineCustomIntVariable( "neon.safekeeper_reconnect_timeout", "Walproposer reconnects to offline safekeepers once in this interval.", diff --git a/proxy/src/serverless/sql_over_http.rs b/proxy/src/serverless/sql_over_http.rs index dfaeedaeae..1c5bb64480 100644 --- a/proxy/src/serverless/sql_over_http.rs +++ b/proxy/src/serverless/sql_over_http.rs @@ -14,7 +14,9 @@ use hyper::http::{HeaderName, HeaderValue}; use hyper::{HeaderMap, Request, Response, StatusCode, header}; use indexmap::IndexMap; use postgres_client::error::{DbError, ErrorPosition, SqlState}; -use postgres_client::{GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, Transaction}; +use postgres_client::{ + GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction, +}; use pq_proto::StartupMessageParamsBuilder; use serde::Serialize; use serde_json::Value; @@ -1092,22 +1094,41 @@ async fn query_to_json( let query_start = Instant::now(); let query_params = data.params; - let mut row_stream = std::pin::pin!( - client - .query_raw_txt(&data.query, query_params) - .await - .map_err(SqlOverHttpError::Postgres)? - ); + let mut row_stream = client + .query_raw_txt(&data.query, query_params) + .await + .map_err(SqlOverHttpError::Postgres)?; let query_acknowledged = Instant::now(); + let columns_len = row_stream.statement.columns().len(); + let mut fields = Vec::with_capacity(columns_len); + let mut types = Vec::with_capacity(columns_len); + + for c in row_stream.statement.columns() { + fields.push(json!({ + "name": c.name().to_owned(), + "dataTypeID": c.type_().oid(), + "tableID": c.table_oid(), + "columnID": c.column_id(), + "dataTypeSize": c.type_size(), + "dataTypeModifier": c.type_modifier(), + "format": "text", + })); + + types.push(c.type_().clone()); + } + + let raw_output = parsed_headers.raw_output; + let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode); + // Manually drain the stream into a vector to leave row_stream hanging // around to get a command tag. Also check that the response is not too // big. - let mut rows: Vec = Vec::new(); + let mut rows = Vec::new(); while let Some(row) = row_stream.next().await { let row = row.map_err(SqlOverHttpError::Postgres)?; *current_size += row.body_len(); - rows.push(row); + // we don't have a streaming response support yet so this is to prevent OOM // from a malicious query (eg a cross join) if *current_size > config.max_response_size_bytes { @@ -1115,13 +1136,26 @@ async fn query_to_json( config.max_response_size_bytes, )); } + + let row = pg_text_row_to_json(&row, &types, raw_output, array_mode)?; + rows.push(row); + + // assumption: parsing pg text and converting to json takes CPU time. + // let's assume it is slightly expensive, so we should consume some cooperative budget. + // Especially considering that `RowStream::next` might be pulling from a batch + // of rows and never hit the tokio mpsc for a long time (although unlikely). + tokio::task::consume_budget().await; } let query_resp_end = Instant::now(); - let ready = row_stream.ready_status(); + let RowStream { + command_tag, + status: ready, + .. + } = row_stream; // grab the command tag and number of rows affected - let command_tag = row_stream.command_tag().unwrap_or_default(); + let command_tag = command_tag.unwrap_or_default(); let mut command_tag_split = command_tag.split(' '); let command_tag_name = command_tag_split.next().unwrap_or_default(); let command_tag_count = if command_tag_name == "INSERT" { @@ -1142,38 +1176,6 @@ async fn query_to_json( "finished executing query" ); - let columns_len = row_stream.columns().len(); - let mut fields = Vec::with_capacity(columns_len); - let mut columns = Vec::with_capacity(columns_len); - - for c in row_stream.columns() { - fields.push(json!({ - "name": c.name().to_owned(), - "dataTypeID": c.type_().oid(), - "tableID": c.table_oid(), - "columnID": c.column_id(), - "dataTypeSize": c.type_size(), - "dataTypeModifier": c.type_modifier(), - "format": "text", - })); - - match client.get_type(c.type_oid()).await { - Ok(t) => columns.push(t), - Err(err) => { - tracing::warn!(?err, "unable to query type information"); - return Err(SqlOverHttpError::InternalPostgres(err)); - } - } - } - - let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode); - - // convert rows to JSON - let rows = rows - .iter() - .map(|row| pg_text_row_to_json(row, &columns, parsed_headers.raw_output, array_mode)) - .collect::, _>>()?; - // Resulting JSON format is based on the format of node-postgres result. let results = json!({ "command": command_tag_name.to_string(), diff --git a/safekeeper/tests/walproposer_sim/simulation.rs b/safekeeper/tests/walproposer_sim/simulation.rs index f314143952..70fecfbe22 100644 --- a/safekeeper/tests/walproposer_sim/simulation.rs +++ b/safekeeper/tests/walproposer_sim/simulation.rs @@ -87,6 +87,7 @@ impl WalProposer { let config = Config { ttid, safekeepers_list: addrs, + safekeeper_conninfo_options: String::new(), safekeeper_reconnect_timeout: 1000, safekeeper_connection_timeout: 5000, sync_safekeepers, diff --git a/scripts/benchmark_durations.py b/scripts/benchmark_durations.py index a9a90c7370..c74ef9d899 100755 --- a/scripts/benchmark_durations.py +++ b/scripts/benchmark_durations.py @@ -32,12 +32,6 @@ BENCHMARKS_DURATION_QUERY = """ # the total duration varies from 8 to 40 minutes. # We use some pre-collected durations as a fallback to have a better distribution. FALLBACK_DURATION = { - "test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py::test_pageserver_max_throughput_getpage_at_latest_lsn[1-13-30]": 400.15, - "test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py::test_pageserver_max_throughput_getpage_at_latest_lsn[1-6-30]": 372.521, - "test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py::test_pageserver_max_throughput_getpage_at_latest_lsn[10-13-30]": 420.017, - "test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py::test_pageserver_max_throughput_getpage_at_latest_lsn[10-6-30]": 373.769, - "test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py::test_pageserver_max_throughput_getpage_at_latest_lsn[100-13-30]": 678.742, - "test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py::test_pageserver_max_throughput_getpage_at_latest_lsn[100-6-30]": 512.135, "test_runner/performance/test_branch_creation.py::test_branch_creation_heavy_write[20]": 58.036, "test_runner/performance/test_branch_creation.py::test_branch_creation_many_relations": 22.104, "test_runner/performance/test_branch_creation.py::test_branch_creation_many[1024]": 126.073, diff --git a/storage_broker/src/bin/storage_broker.rs b/storage_broker/src/bin/storage_broker.rs index 476d5f03ea..bae5ccb36c 100644 --- a/storage_broker/src/bin/storage_broker.rs +++ b/storage_broker/src/bin/storage_broker.rs @@ -17,12 +17,14 @@ use std::pin::Pin; use std::sync::Arc; use std::time::Duration; +use bytes::Bytes; use camino::Utf8PathBuf; use clap::{Parser, command}; use futures::future::OptionFuture; use futures_core::Stream; use futures_util::StreamExt; -use http_body_util::Full; +use http_body_util::combinators::BoxBody; +use http_body_util::{Empty, Full}; use http_utils::tls_certs::ReloadingCertificateResolver; use hyper::body::Incoming; use hyper::header::CONTENT_TYPE; @@ -46,7 +48,6 @@ use tokio::net::TcpListener; use tokio::sync::broadcast; use tokio::sync::broadcast::error::RecvError; use tokio::time; -use tonic::body::{self, BoxBody, empty_body}; use tonic::codegen::Service; use tonic::{Code, Request, Response, Status}; use tracing::*; @@ -634,7 +635,7 @@ impl BrokerService for Broker { // We serve only metrics and healthcheck through http1. async fn http1_handler( req: hyper::Request, -) -> Result, Infallible> { +) -> Result>, Infallible> { let resp = match (req.method(), req.uri().path()) { (&Method::GET, "/metrics") => { let mut buffer = vec![]; @@ -645,16 +646,16 @@ async fn http1_handler( hyper::Response::builder() .status(StatusCode::OK) .header(CONTENT_TYPE, encoder.format_type()) - .body(body::boxed(Full::new(bytes::Bytes::from(buffer)))) + .body(BoxBody::new(Full::new(Bytes::from(buffer)))) .unwrap() } (&Method::GET, "/status") => hyper::Response::builder() .status(StatusCode::OK) - .body(empty_body()) + .body(BoxBody::new(Empty::new())) .unwrap(), _ => hyper::Response::builder() .status(StatusCode::NOT_FOUND) - .body(empty_body()) + .body(BoxBody::new(Empty::new())) .unwrap(), }; Ok(resp) diff --git a/storage_controller/src/service.rs b/storage_controller/src/service.rs index 7e4bb627af..d284747f73 100644 --- a/storage_controller/src/service.rs +++ b/storage_controller/src/service.rs @@ -3823,6 +3823,13 @@ impl Service { .await; failpoint_support::sleep_millis_async!("tenant-create-timeline-shared-lock"); let is_import = create_req.is_import(); + let read_only = matches!( + create_req.mode, + models::TimelineCreateRequestMode::Branch { + read_only: true, + .. + } + ); if is_import { // Ensure that there is no split on-going. @@ -3895,13 +3902,13 @@ impl Service { } None - } else if safekeepers { + } else if safekeepers || read_only { // Note that for imported timelines, we do not create the timeline on the safekeepers // straight away. Instead, we do it once the import finalized such that we know what // start LSN to provide for the safekeepers. This is done in // [`Self::finalize_timeline_import`]. let res = self - .tenant_timeline_create_safekeepers(tenant_id, &timeline_info) + .tenant_timeline_create_safekeepers(tenant_id, &timeline_info, read_only) .instrument(tracing::info_span!("timeline_create_safekeepers", %tenant_id, timeline_id=%timeline_info.timeline_id)) .await?; Some(res) @@ -3915,6 +3922,11 @@ impl Service { }) } + #[instrument(skip_all, fields( + tenant_id=%req.tenant_shard_id.tenant_id, + shard_id=%req.tenant_shard_id.shard_slug(), + timeline_id=%req.timeline_id, + ))] pub(crate) async fn handle_timeline_shard_import_progress( self: &Arc, req: TimelineImportStatusRequest, @@ -3964,6 +3976,11 @@ impl Service { }) } + #[instrument(skip_all, fields( + tenant_id=%req.tenant_shard_id.tenant_id, + shard_id=%req.tenant_shard_id.shard_slug(), + timeline_id=%req.timeline_id, + ))] pub(crate) async fn handle_timeline_shard_import_progress_upcall( self: &Arc, req: PutTimelineImportStatusRequest, diff --git a/storage_controller/src/service/safekeeper_service.rs b/storage_controller/src/service/safekeeper_service.rs index cd5ace449d..1f673fe445 100644 --- a/storage_controller/src/service/safekeeper_service.rs +++ b/storage_controller/src/service/safekeeper_service.rs @@ -208,6 +208,7 @@ impl Service { self: &Arc, tenant_id: TenantId, timeline_info: &TimelineInfo, + read_only: bool, ) -> Result { let timeline_id = timeline_info.timeline_id; let pg_version = timeline_info.pg_version * 10000; @@ -220,7 +221,11 @@ impl Service { let start_lsn = timeline_info.last_record_lsn; // Choose initial set of safekeepers respecting affinity - let sks = self.safekeepers_for_new_timeline().await?; + let sks = if !read_only { + self.safekeepers_for_new_timeline().await? + } else { + Vec::new() + }; let sks_persistence = sks.iter().map(|sk| sk.id.0 as i64).collect::>(); // Add timeline to db let mut timeline_persist = TimelinePersistence { @@ -253,6 +258,16 @@ impl Service { ))); } } + let ret = SafekeepersInfo { + generation: timeline_persist.generation as u32, + safekeepers: sks.clone(), + tenant_id, + timeline_id, + }; + if read_only { + return Ok(ret); + } + // Create the timeline on a quorum of safekeepers let remaining = self .tenant_timeline_create_safekeepers_quorum( @@ -316,12 +331,7 @@ impl Service { } } - Ok(SafekeepersInfo { - generation: timeline_persist.generation as u32, - safekeepers: sks, - tenant_id, - timeline_id, - }) + Ok(ret) } pub(crate) async fn tenant_timeline_create_safekeepers_until_success( @@ -336,8 +346,10 @@ impl Service { return Err(TimelineImportFinalizeError::ShuttingDown); } + // This function is only used in non-read-only scenarios + let read_only = false; let res = self - .tenant_timeline_create_safekeepers(tenant_id, &timeline_info) + .tenant_timeline_create_safekeepers(tenant_id, &timeline_info, read_only) .await; match res { @@ -410,6 +422,18 @@ impl Service { .chain(tl.sk_set.iter()) .collect::>(); + // The timeline has no safekeepers: we need to delete it from the db manually, + // as no safekeeper reconciler will get to it + if all_sks.is_empty() { + if let Err(err) = self + .persistence + .delete_timeline(tenant_id, timeline_id) + .await + { + tracing::warn!(%tenant_id, %timeline_id, "couldn't delete timeline from db: {err}"); + } + } + // Schedule reconciliations for &sk_id in all_sks.iter() { let pending_op = TimelinePendingOpPersistence { diff --git a/test_runner/fixtures/neon_fixtures.py b/test_runner/fixtures/neon_fixtures.py index e413b3c6d2..7f4150b580 100644 --- a/test_runner/fixtures/neon_fixtures.py +++ b/test_runner/fixtures/neon_fixtures.py @@ -404,6 +404,29 @@ class PageserverTracingConfig: return ("tracing", value) +@dataclass +class PageserverImportConfig: + import_job_concurrency: int + import_job_soft_size_limit: int + import_job_checkpoint_threshold: int + + @staticmethod + def default() -> PageserverImportConfig: + return PageserverImportConfig( + import_job_concurrency=4, + import_job_soft_size_limit=512 * 1024, + import_job_checkpoint_threshold=4, + ) + + def to_config_key_value(self) -> tuple[str, dict[str, Any]]: + value = { + "import_job_concurrency": self.import_job_concurrency, + "import_job_soft_size_limit": self.import_job_soft_size_limit, + "import_job_checkpoint_threshold": self.import_job_checkpoint_threshold, + } + return ("timeline_import_config", value) + + class NeonEnvBuilder: """ Builder object to create a Neon runtime environment @@ -454,6 +477,7 @@ class NeonEnvBuilder: pageserver_wal_receiver_protocol: PageserverWalReceiverProtocol | None = None, pageserver_get_vectored_concurrent_io: str | None = None, pageserver_tracing_config: PageserverTracingConfig | None = None, + pageserver_import_config: PageserverImportConfig | None = None, ): self.repo_dir = repo_dir self.rust_log_override = rust_log_override @@ -511,6 +535,7 @@ class NeonEnvBuilder: ) self.pageserver_tracing_config = pageserver_tracing_config + self.pageserver_import_config = pageserver_import_config self.pageserver_default_tenant_config_compaction_algorithm: dict[str, Any] | None = ( pageserver_default_tenant_config_compaction_algorithm @@ -682,7 +707,7 @@ class NeonEnvBuilder: log.info( f"Copying pageserver tenants directory {tenants_from_dir} to {tenants_to_dir}" ) - shutil.copytree(tenants_from_dir, tenants_to_dir) + subprocess.run(["cp", "-a", tenants_from_dir, tenants_to_dir], check=True) else: log.info( f"Creating overlayfs mount of pageserver tenants directory {tenants_from_dir} to {tenants_to_dir}" @@ -698,8 +723,9 @@ class NeonEnvBuilder: shutil.rmtree(self.repo_dir / "local_fs_remote_storage", ignore_errors=True) if self.test_overlay_dir is None: log.info("Copying local_fs_remote_storage directory from snapshot") - shutil.copytree( - repo_dir / "local_fs_remote_storage", self.repo_dir / "local_fs_remote_storage" + subprocess.run( + ["cp", "-a", f"{repo_dir / 'local_fs_remote_storage'}", f"{self.repo_dir}"], + check=True, ) else: log.info("Creating overlayfs mount of local_fs_remote_storage directory from snapshot") @@ -1178,6 +1204,10 @@ class NeonEnv: self.pageserver_wal_receiver_protocol = config.pageserver_wal_receiver_protocol self.pageserver_get_vectored_concurrent_io = config.pageserver_get_vectored_concurrent_io self.pageserver_tracing_config = config.pageserver_tracing_config + if config.pageserver_import_config is None: + self.pageserver_import_config = PageserverImportConfig.default() + else: + self.pageserver_import_config = config.pageserver_import_config # Create the neon_local's `NeonLocalInitConf` cfg: dict[str, Any] = { @@ -1223,6 +1253,7 @@ class NeonEnv: # Create config for pageserver http_auth_type = "NeonJWT" if config.auth_enabled else "Trust" pg_auth_type = "NeonJWT" if config.auth_enabled else "Trust" + grpc_auth_type = "NeonJWT" if config.auth_enabled else "Trust" for ps_id in range( self.BASE_PAGESERVER_ID, self.BASE_PAGESERVER_ID + config.num_pageservers ): @@ -1249,18 +1280,13 @@ class NeonEnv: else None, "pg_auth_type": pg_auth_type, "http_auth_type": http_auth_type, + "grpc_auth_type": grpc_auth_type, "availability_zone": availability_zone, # Disable pageserver disk syncs in tests: when running tests concurrently, this avoids # the pageserver taking a long time to start up due to syncfs flushing other tests' data "no_sync": True, # Look for gaps in WAL received from safekeepeers "validate_wal_contiguity": True, - # TODO(vlad): make these configurable through the builder - "timeline_import_config": { - "import_job_concurrency": 4, - "import_job_soft_size_limit": 512 * 1024, - "import_job_checkpoint_threshold": 4, - }, } # Batching (https://github.com/neondatabase/neon/issues/9377): @@ -1322,6 +1348,12 @@ class NeonEnv: ps_cfg[key] = value + if self.pageserver_import_config is not None: + key, value = self.pageserver_import_config.to_config_key_value() + + if key not in ps_cfg: + ps_cfg[key] = value + # Create a corresponding NeonPageserver object ps = NeonPageserver( self, ps_id, port=pageserver_port, az_id=ps_cfg["availability_zone"] diff --git a/test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py b/test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py index 8874fe663b..41696bf887 100644 --- a/test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py +++ b/test_runner/performance/pageserver/pagebench/test_pageserver_max_throughput_getpage_at_latest_lsn.py @@ -14,7 +14,7 @@ from fixtures.neon_fixtures import ( PgBin, wait_for_last_flush_lsn, ) -from fixtures.utils import get_scale_for_db, humantime_to_ms, skip_on_ci +from fixtures.utils import get_scale_for_db, humantime_to_ms from performance.pageserver.util import setup_pageserver_with_tenants @@ -36,9 +36,6 @@ if TYPE_CHECKING: @pytest.mark.parametrize("pgbench_scale", [get_scale_for_db(200)]) @pytest.mark.parametrize("n_tenants", [500]) @pytest.mark.timeout(10000) -@skip_on_ci( - "This test needs lot of resources and should run on dedicated HW, not in github action runners as part of CI" -) def test_pageserver_characterize_throughput_with_n_tenants( neon_env_builder: NeonEnvBuilder, zenbenchmark: NeonBenchmarker, @@ -63,9 +60,6 @@ def test_pageserver_characterize_throughput_with_n_tenants( @pytest.mark.parametrize("n_clients", [1, 64]) @pytest.mark.parametrize("n_tenants", [1]) @pytest.mark.timeout(2400) -@skip_on_ci( - "This test needs lot of resources and should run on dedicated HW, not in github action runners as part of CI" -) def test_pageserver_characterize_latencies_with_1_client_and_throughput_with_many_clients_one_tenant( neon_env_builder: NeonEnvBuilder, zenbenchmark: NeonBenchmarker, diff --git a/test_runner/regress/test_compute_metrics.py b/test_runner/regress/test_compute_metrics.py index 5e3f8671a2..2cb2ee7b58 100644 --- a/test_runner/regress/test_compute_metrics.py +++ b/test_runner/regress/test_compute_metrics.py @@ -217,11 +217,11 @@ if SQL_EXPORTER is None: self, logs_dir: Path, config_file: Path, collector_file: Path, port: int ) -> None: # NOTE: Keep the version the same as in - # compute/Dockerfile.compute-node and Dockerfile.build-tools. + # compute/compute-node.Dockerfile and build-tools.Dockerfile. # # The "host" network mode allows sql_exporter to talk to the # endpoint which is running on the host. - super().__init__("docker.io/burningalchemist/sql_exporter:0.17.0", network_mode="host") + super().__init__("docker.io/burningalchemist/sql_exporter:0.17.3", network_mode="host") self.__logs_dir = logs_dir self.__port = port diff --git a/test_runner/regress/test_import_pgdata.py b/test_runner/regress/test_import_pgdata.py index 0472b92145..69cbdec5b0 100644 --- a/test_runner/regress/test_import_pgdata.py +++ b/test_runner/regress/test_import_pgdata.py @@ -1,7 +1,10 @@ import base64 +import concurrent.futures import json +import random +import threading import time -from enum import Enum +from enum import Enum, StrEnum from pathlib import Path from threading import Event @@ -11,7 +14,14 @@ import pytest from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId from fixtures.fast_import import FastImport from fixtures.log_helper import log -from fixtures.neon_fixtures import NeonEnvBuilder, PgBin, PgProtocol, VanillaPostgres +from fixtures.neon_fixtures import ( + NeonEnvBuilder, + PageserverImportConfig, + PgBin, + PgProtocol, + StorageControllerMigrationConfig, + VanillaPostgres, +) from fixtures.pageserver.http import ( ImportPgdataIdemptencyKey, ) @@ -494,6 +504,259 @@ def test_import_respects_tenant_shutdown( wait_until(cplane_notified) +@skip_in_debug_build("Validation query takes too long in debug builds") +def test_import_chaos( + neon_env_builder: NeonEnvBuilder, vanilla_pg: VanillaPostgres, make_httpserver: HTTPServer +): + """ + Perform a timeline import while injecting chaos in the environment. + We expect that the import completes eventually, that it passes validation and + the resulting timeline can be written to. + """ + TARGET_RELBOCK_SIZE = 512 * 1024 * 1024 # 512 MiB + ALLOWED_IMPORT_RUNTIME = 90 # seconds + SHARD_COUNT = 4 + + neon_env_builder.num_pageservers = SHARD_COUNT + neon_env_builder.pageserver_import_config = PageserverImportConfig( + import_job_concurrency=1, + import_job_soft_size_limit=64 * 1024, + import_job_checkpoint_threshold=4, + ) + + # Set up mock control plane HTTP server to listen for import completions + import_completion_signaled = Event() + # There's some Python magic at play here. A list can be updated from the + # handler thread, but an optional cannot. Hence, use a list with one element. + import_error = [] + + def handler(request: Request) -> Response: + assert request.json is not None + + body = request.json + if "error" in body: + if body["error"]: + import_error.append(body["error"]) + + log.info(f"control plane /import_complete request: {request.json}") + import_completion_signaled.set() + return Response(json.dumps({}), status=200) + + cplane_mgmt_api_server = make_httpserver + cplane_mgmt_api_server.expect_request( + "/storage/api/v1/import_complete", method="PUT" + ).respond_with_handler(handler) + + # Plug the cplane mock in + neon_env_builder.control_plane_hooks_api = ( + f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/storage/api/v1/" + ) + + # The import will specifiy a local filesystem path mocking remote storage + neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS) + + vanilla_pg.start() + vanilla_pg.safe_psql("create user cloud_admin with password 'postgres' superuser") + vanilla_pg.safe_psql("""create table t (data char(900)) with (fillfactor = 10)""") + + nrows = 0 + while True: + relblock_size = vanilla_pg.safe_psql_scalar("select pg_relation_size('t')") + log.info( + f"relblock size: {relblock_size / 8192} pages (target: {TARGET_RELBOCK_SIZE // 8192}) pages" + ) + if relblock_size >= TARGET_RELBOCK_SIZE: + break + addrows = int((TARGET_RELBOCK_SIZE - relblock_size) // 8192) + assert addrows >= 1, "forward progress" + vanilla_pg.safe_psql( + f"insert into t select generate_series({nrows + 1}, {nrows + addrows})" + ) + nrows += addrows + + vanilla_pg.stop() + + env = neon_env_builder.init_configs() + env.start() + + # Pause after every import task to extend the test runtime and allow + # for more chaos injection. + for ps in env.pageservers: + ps.add_persistent_failpoint("import-task-complete-pausable", "sleep(5)") + + env.storage_controller.allowed_errors.extend( + [ + # The shard might have moved or the pageserver hosting the shard restarted + ".*Call to node.*management API.*failed.*", + # Migrations have their time outs set to 0 + ".*Timed out after.*downloading layers.*", + ".*Failed to prepare by downloading layers.*", + # The test may kill the storage controller or pageservers + ".*request was dropped before completing.*", + ] + ) + for ps in env.pageservers: + ps.allowed_errors.extend( + [ + # We might re-write a layer in a different generation if the import + # needs to redo some of the progress since not each job is checkpointed. + ".*was unlinked but was not dangling.*", + # The test may kill the storage controller or pageservers + ".*request was dropped before completing.*", + # Test can SIGTERM pageserver while it is downloading + ".*removing local file.*temp_download.*", + ".*Failed to flush heatmap.*", + # Test can SIGTERM the storage controller while pageserver + # is attempting to upcall. + ".*storage controller upcall failed.*timeline_import_status.*", + # TODO(vlad): TenantManager::reset_tenant returns a blanked anyhow error. + # It should return ResourceUnavailable or something that doesn't error log. + ".*activate_post_import.*InternalServerError.*tenant map is shutting down.*", + # TODO(vlad): How can this happen? + ".*Failed to download a remote file: deserialize index part file.*", + ".*Cancelled request finished with an error.*", + ] + ) + + importbucket_path = neon_env_builder.repo_dir / "test_import_chaos_bucket" + mock_import_bucket(vanilla_pg, importbucket_path) + + tenant_id = TenantId.generate() + timeline_id = TimelineId.generate() + idempotency = ImportPgdataIdemptencyKey.random() + + env.storage_controller.tenant_create( + tenant_id, shard_count=SHARD_COUNT, placement_policy={"Attached": 1} + ) + env.storage_controller.reconcile_until_idle() + + env.storage_controller.timeline_create( + tenant_id, + { + "new_timeline_id": str(timeline_id), + "import_pgdata": { + "idempotency_key": str(idempotency), + "location": {"LocalFs": {"path": str(importbucket_path.absolute())}}, + }, + }, + ) + + def chaos(stop_chaos: threading.Event): + class ChaosType(StrEnum): + MIGRATE_SHARD = "migrate_shard" + RESTART_IMMEDIATE = "restart_immediate" + RESTART = "restart" + STORCON_RESTART_IMMEDIATE = "storcon_restart_immediate" + + while not stop_chaos.is_set(): + chaos_type = random.choices( + population=[ + ChaosType.MIGRATE_SHARD, + ChaosType.RESTART, + ChaosType.RESTART_IMMEDIATE, + ChaosType.STORCON_RESTART_IMMEDIATE, + ], + weights=[0.25, 0.25, 0.25, 0.25], + k=1, + )[0] + + try: + if chaos_type == ChaosType.MIGRATE_SHARD: + target_shard_number = random.randint(0, SHARD_COUNT - 1) + target_shard = TenantShardId(tenant_id, target_shard_number, SHARD_COUNT) + + placements = env.storage_controller.get_tenants_placement() + log.info(f"{placements=}") + target_ps = placements[str(target_shard)]["intent"]["attached"] + if len(placements[str(target_shard)]["intent"]["secondary"]) == 0: + dest_ps = None + else: + dest_ps = placements[str(target_shard)]["intent"]["secondary"][0] + + if target_ps is None or dest_ps is None: + continue + + config = StorageControllerMigrationConfig( + secondary_warmup_timeout="0s", + secondary_download_request_timeout="0s", + prewarm=False, + ) + env.storage_controller.tenant_shard_migrate(target_shard, dest_ps, config) + + log.info( + f"CHAOS: Migrating shard {target_shard} from pageserver {target_ps} to {dest_ps}" + ) + elif chaos_type == ChaosType.RESTART_IMMEDIATE: + target_ps = random.choice(env.pageservers) + log.info(f"CHAOS: Immediate restart of pageserver {target_ps.id}") + target_ps.stop(immediate=True) + target_ps.start() + elif chaos_type == ChaosType.RESTART: + target_ps = random.choice(env.pageservers) + log.info(f"CHAOS: Normal restart of pageserver {target_ps.id}") + target_ps.stop(immediate=False) + target_ps.start() + elif chaos_type == ChaosType.STORCON_RESTART_IMMEDIATE: + log.info("CHAOS: Immediate restart of storage controller") + env.storage_controller.stop(immediate=True) + env.storage_controller.start() + except Exception as e: + log.warning(f"CHAOS: Error during chaos operation {chaos_type}: {e}") + + # Sleep before next chaos event + time.sleep(1) + + log.info("Chaos injector stopped") + + def wait_for_import_completion(): + start = time.time() + done = import_completion_signaled.wait(ALLOWED_IMPORT_RUNTIME) + if not done: + raise TimeoutError(f"Import did not signal completion within {ALLOWED_IMPORT_RUNTIME}") + + end = time.time() + + log.info(f"Import completion signalled after {end - start}s {import_error=}") + + if import_error: + raise RuntimeError(f"Import error: {import_error}") + + with concurrent.futures.ThreadPoolExecutor() as executor: + stop_chaos = threading.Event() + + wait_for_import_completion_fut = executor.submit(wait_for_import_completion) + chaos_fut = executor.submit(chaos, stop_chaos) + + try: + wait_for_import_completion_fut.result() + except Exception as e: + raise e + finally: + stop_chaos.set() + chaos_fut.result() + + import_branch_name = "imported" + env.neon_cli.mappings_map_branch(import_branch_name, tenant_id, timeline_id) + endpoint = env.endpoints.create_start(branch_name=import_branch_name, tenant_id=tenant_id) + + # Validate the imported data is legit + assert endpoint.safe_psql_many( + [ + "set effective_io_concurrency=32;", + "SET statement_timeout='300s';", + "select count(*), sum(data::bigint)::bigint from t", + ] + ) == [[], [], [(nrows, nrows * (nrows + 1) // 2)]] + + endpoint.stop() + + # Validate writes + workload = Workload(env, tenant_id, timeline_id, branch_name=import_branch_name) + workload.init() + workload.write_rows(64) + workload.validate() + + def test_fast_import_with_pageserver_ingest( test_output_dir, vanilla_pg: VanillaPostgres, diff --git a/test_runner/regress/test_layers_from_future.py b/test_runner/regress/test_layers_from_future.py index b4eba2779d..f3fcdb0d14 100644 --- a/test_runner/regress/test_layers_from_future.py +++ b/test_runner/regress/test_layers_from_future.py @@ -20,6 +20,9 @@ from fixtures.remote_storage import LocalFsStorage, RemoteStorageKind from fixtures.utils import query_scalar, wait_until +@pytest.mark.skip( + reason="We won't create future layers any more after https://github.com/neondatabase/neon/pull/10548" +) @pytest.mark.parametrize( "attach_mode", ["default_generation", "same_generation"], diff --git a/test_runner/regress/test_storage_controller.py b/test_runner/regress/test_storage_controller.py index af018f7b5d..d07fb38c5a 100644 --- a/test_runner/regress/test_storage_controller.py +++ b/test_runner/regress/test_storage_controller.py @@ -4158,17 +4158,12 @@ def test_storcon_create_delete_sk_down( env.storage_controller.stop() env.storage_controller.start() - config_lines = [ - "neon.safekeeper_proto_version = 3", - ] - with env.endpoints.create("main", tenant_id=tenant_id, config_lines=config_lines) as ep: + with env.endpoints.create("main", tenant_id=tenant_id) as ep: # endpoint should start. ep.start(safekeeper_generation=1, safekeepers=[1, 2, 3]) ep.safe_psql("CREATE TABLE IF NOT EXISTS t(key int, value text)") - with env.endpoints.create( - "child_of_main", tenant_id=tenant_id, config_lines=config_lines - ) as ep: + with env.endpoints.create("child_of_main", tenant_id=tenant_id) as ep: # endpoint should start. ep.start(safekeeper_generation=1, safekeepers=[1, 2, 3]) ep.safe_psql("CREATE TABLE IF NOT EXISTS t(key int, value text)") @@ -4249,17 +4244,12 @@ def test_storcon_few_sk( env.safekeepers[0].assert_log_contains(f"creating new timeline {tenant_id}/{timeline_id}") - config_lines = [ - "neon.safekeeper_proto_version = 3", - ] - with env.endpoints.create("main", tenant_id=tenant_id, config_lines=config_lines) as ep: + with env.endpoints.create("main", tenant_id=tenant_id) as ep: # endpoint should start. ep.start(safekeeper_generation=1, safekeepers=safekeeper_list) ep.safe_psql("CREATE TABLE IF NOT EXISTS t(key int, value text)") - with env.endpoints.create( - "child_of_main", tenant_id=tenant_id, config_lines=config_lines - ) as ep: + with env.endpoints.create("child_of_main", tenant_id=tenant_id) as ep: # endpoint should start. ep.start(safekeeper_generation=1, safekeepers=safekeeper_list) ep.safe_psql("CREATE TABLE IF NOT EXISTS t(key int, value text)") diff --git a/test_runner/regress/test_timeline_detach_ancestor.py b/test_runner/regress/test_timeline_detach_ancestor.py index d42c5d403e..f0810270b1 100644 --- a/test_runner/regress/test_timeline_detach_ancestor.py +++ b/test_runner/regress/test_timeline_detach_ancestor.py @@ -10,6 +10,7 @@ from queue import Empty, Queue from threading import Barrier import pytest +import requests from fixtures.common_types import Lsn, TimelineArchivalState, TimelineId from fixtures.log_helper import log from fixtures.neon_fixtures import ( @@ -401,8 +402,25 @@ def test_ancestor_detach_behavior_v2(neon_env_builder: NeonEnvBuilder, snapshots "earlier", ancestor_branch_name="main", ancestor_start_lsn=branchpoint_pipe ) - snapshot_branchpoint_old = env.create_branch( - "snapshot_branchpoint_old", ancestor_branch_name="main", ancestor_start_lsn=branchpoint_y + snapshot_branchpoint_old = TimelineId.generate() + + env.storage_controller.timeline_create( + env.initial_tenant, + { + "new_timeline_id": str(snapshot_branchpoint_old), + "ancestor_start_lsn": str(branchpoint_y), + "ancestor_timeline_id": str(env.initial_timeline), + "read_only": True, + }, + ) + sk = env.safekeepers[0] + assert sk + with pytest.raises(requests.exceptions.HTTPError, match="Not Found"): + sk.http_client().timeline_status( + tenant_id=env.initial_tenant, timeline_id=snapshot_branchpoint_old + ) + env.neon_cli.mappings_map_branch( + "snapshot_branchpoint_old", env.initial_tenant, snapshot_branchpoint_old ) snapshot_branchpoint = env.create_branch( diff --git a/test_runner/regress/test_wal_acceptor.py b/test_runner/regress/test_wal_acceptor.py index a9a6699e5c..6a7c7a8bef 100644 --- a/test_runner/regress/test_wal_acceptor.py +++ b/test_runner/regress/test_wal_acceptor.py @@ -2012,10 +2012,7 @@ def test_explicit_timeline_creation(neon_env_builder: NeonEnvBuilder): tenant_id = env.initial_tenant timeline_id = env.initial_timeline - config_lines = [ - "neon.safekeeper_proto_version = 3", - ] - ep = env.endpoints.create("main", config_lines=config_lines) + ep = env.endpoints.create("main") # expected to fail because timeline is not created on safekeepers with pytest.raises(Exception, match=r".*timed out.*"): @@ -2043,10 +2040,7 @@ def test_explicit_timeline_creation_storcon(neon_env_builder: NeonEnvBuilder): } env = neon_env_builder.init_start() - config_lines = [ - "neon.safekeeper_proto_version = 3", - ] - ep = env.endpoints.create("main", config_lines=config_lines) + ep = env.endpoints.create("main") # endpoint should start. ep.start(safekeeper_generation=1, safekeepers=[1, 2, 3]) diff --git a/test_runner/regress/test_wal_acceptor_async.py b/test_runner/regress/test_wal_acceptor_async.py index c5dd34f64f..4070f99568 100644 --- a/test_runner/regress/test_wal_acceptor_async.py +++ b/test_runner/regress/test_wal_acceptor_async.py @@ -637,10 +637,7 @@ async def quorum_sanity_single( # create timeline on `members_sks` Safekeeper.create_timeline(tenant_id, timeline_id, env.pageservers[0], mconf, members_sks) - config_lines = [ - "neon.safekeeper_proto_version = 3", - ] - ep = env.endpoints.create(branch_name, config_lines=config_lines) + ep = env.endpoints.create(branch_name) ep.start(safekeeper_generation=1, safekeepers=compute_sks_ids) ep.safe_psql("create table t(key int, value text)") diff --git a/workspace_hack/Cargo.toml b/workspace_hack/Cargo.toml index 87d0092fb2..2b07889871 100644 --- a/workspace_hack/Cargo.toml +++ b/workspace_hack/Cargo.toml @@ -18,6 +18,8 @@ license.workspace = true ahash = { version = "0.8" } anstream = { version = "0.6" } anyhow = { version = "1", features = ["backtrace"] } +axum = { version = "0.8", features = ["ws"] } +axum-core = { version = "0.5", default-features = false, features = ["tracing"] } base64-594e8ee84c453af0 = { package = "base64", version = "0.13", features = ["alloc"] } base64-647d43efb71741da = { package = "base64", version = "0.21" } base64ct = { version = "1", default-features = false, features = ["std"] } @@ -39,10 +41,8 @@ env_logger = { version = "0.11" } fail = { version = "0.5", default-features = false, features = ["failpoints"] } form_urlencoded = { version = "1" } futures-channel = { version = "0.3", features = ["sink"] } -futures-core = { version = "0.3" } futures-executor = { version = "0.3" } futures-io = { version = "0.3" } -futures-task = { version = "0.3", default-features = false, features = ["std"] } futures-util = { version = "0.3", features = ["channel", "io", "sink"] } generic-array = { version = "0.14", default-features = false, features = ["more_lengths", "zeroize"] } getrandom = { version = "0.2", default-features = false, features = ["std"] } @@ -52,9 +52,8 @@ hex = { version = "0.4", features = ["serde"] } hmac = { version = "0.12", default-features = false, features = ["reset"] } hyper-582f2526e08bb6a0 = { package = "hyper", version = "0.14", features = ["client", "http1", "http2", "runtime", "server", "stream"] } hyper-dff4ba8e3ae991db = { package = "hyper", version = "1", features = ["full"] } -hyper-util = { version = "0.1", features = ["client-legacy", "http1", "http2", "server", "service"] } -indexmap-dff4ba8e3ae991db = { package = "indexmap", version = "1", default-features = false, features = ["std"] } -indexmap-f595c2ba2a3f28df = { package = "indexmap", version = "2", features = ["serde"] } +hyper-util = { version = "0.1", features = ["client-legacy", "server-auto", "service"] } +indexmap = { version = "2", features = ["serde"] } itertools = { version = "0.12" } lazy_static = { version = "1", default-features = false, features = ["spin_no_std"] } libc = { version = "0.2", features = ["extra_traits", "use_std"] } @@ -73,7 +72,6 @@ num-traits = { version = "0.2", features = ["i128", "libm"] } once_cell = { version = "1" } p256 = { version = "0.13", features = ["jwk"] } parquet = { version = "53", default-features = false, features = ["zstd"] } -percent-encoding = { version = "2" } prost = { version = "0.13", features = ["no-recursion-limit", "prost-derive"] } rand = { version = "0.8", features = ["small_rng"] } regex = { version = "1" } @@ -82,7 +80,7 @@ regex-syntax = { version = "0.8" } reqwest = { version = "0.12", default-features = false, features = ["blocking", "json", "rustls-tls", "rustls-tls-native-roots", "stream"] } rustls = { version = "0.23", default-features = false, features = ["logging", "ring", "std", "tls12"] } rustls-pki-types = { version = "1", features = ["std"] } -rustls-webpki = { version = "0.102", default-features = false, features = ["ring", "std"] } +rustls-webpki = { version = "0.103", default-features = false, features = ["ring", "std"] } scopeguard = { version = "1" } sec1 = { version = "0.7", features = ["pem", "serde", "std", "subtle"] } serde = { version = "1", features = ["alloc", "derive"] } @@ -99,11 +97,10 @@ tikv-jemalloc-sys = { version = "0.6", features = ["profiling", "stats", "unpref time = { version = "0.3", features = ["macros", "serde-well-known"] } tokio = { version = "1", features = ["full", "test-util"] } tokio-rustls = { version = "0.26", default-features = false, features = ["logging", "ring", "tls12"] } -tokio-stream = { version = "0.1" } +tokio-stream = { version = "0.1", features = ["net"] } tokio-util = { version = "0.7", features = ["codec", "compat", "io", "rt"] } toml_edit = { version = "0.22", features = ["serde"] } -tonic = { version = "0.12", default-features = false, features = ["codegen", "prost", "tls-roots"] } -tower = { version = "0.4", default-features = false, features = ["balance", "buffer", "limit", "util"] } +tower = { version = "0.5", default-features = false, features = ["balance", "buffer", "limit", "log"] } tracing = { version = "0.1", features = ["log"] } tracing-core = { version = "0.1" } tracing-log = { version = "0.2" } @@ -125,8 +122,7 @@ either = { version = "1" } getrandom = { version = "0.2", default-features = false, features = ["std"] } half = { version = "2", default-features = false, features = ["num-traits"] } hashbrown = { version = "0.14", features = ["raw"] } -indexmap-dff4ba8e3ae991db = { package = "indexmap", version = "1", default-features = false, features = ["std"] } -indexmap-f595c2ba2a3f28df = { package = "indexmap", version = "2", features = ["serde"] } +indexmap = { version = "2", features = ["serde"] } itertools = { version = "0.12" } libc = { version = "0.2", features = ["extra_traits", "use_std"] } log = { version = "0.4", default-features = false, features = ["std"] }