Compare commits

4 Commits

Author          SHA1        Message                                                Date
Conrad Ludgate  01475c9e75  fix accidental recursion                               2024-12-06 12:19:40 +00:00
Conrad Ludgate  c835bbba1f  refactor statements and the type cache to avoid arcs   2024-12-06 12:01:19 +00:00
Conrad Ludgate  f94dde4432  delete some more                                       2024-12-06 11:33:34 +00:00
Conrad Ludgate  4991a85704  delete some client methods and make client take &mut   2024-12-06 11:22:03 +00:00
150 changed files with 2067 additions and 9599 deletions

View File

@@ -21,7 +21,3 @@ config-variables:
- SLACK_UPCOMING_RELEASE_CHANNEL_ID
- DEV_AWS_OIDC_ROLE_ARN
- BENCHMARK_INGEST_TARGET_PROJECTID
- PGREGRESS_PG16_PROJECT_ID
- PGREGRESS_PG17_PROJECT_ID
- SLACK_ON_CALL_QA_STAGING_STREAM
- DEV_AWS_OIDC_ROLE_MANAGE_BENCHMARK_EC2_VMS_ARN

View File

@@ -15,21 +15,10 @@ inputs:
prefix:
description: "S3 prefix. Default is '${GITHUB_RUN_ID}/${GITHUB_RUN_ATTEMPT}'"
required: false
aws_oicd_role_arn:
description: "the OIDC role arn for aws auth"
required: false
default: ""
runs:
using: "composite"
steps:
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ inputs.aws_oicd_role_arn }}
role-duration-seconds: 3600
- name: Download artifact
id: download-artifact
shell: bash -euxo pipefail {0}

View File

@@ -62,7 +62,6 @@ runs:
with:
name: neon-${{ runner.os }}-${{ runner.arch }}-${{ inputs.build_type }}-artifact
path: /tmp/neon
aws_oicd_role_arn: ${{ inputs.aws_oicd_role_arn }}
- name: Download Neon binaries for the previous release
if: inputs.build_type != 'remote'
@@ -71,7 +70,6 @@ runs:
name: neon-${{ runner.os }}-${{ runner.arch }}-${{ inputs.build_type }}-artifact
path: /tmp/neon-previous
prefix: latest
aws_oicd_role_arn: ${{ inputs.aws_oicd_role_arn }}
- name: Download compatibility snapshot
if: inputs.build_type != 'remote'
@@ -83,7 +81,6 @@ runs:
# The lack of compatibility snapshot (for example, for the new Postgres version)
# shouldn't fail the whole job. Only relevant test should fail.
skip-if-does-not-exist: true
aws_oicd_role_arn: ${{ inputs.aws_oicd_role_arn }}
- name: Checkout
if: inputs.needs_postgres_source == 'true'
@@ -221,7 +218,6 @@ runs:
# The lack of compatibility snapshot shouldn't fail the job
# (for example if we didn't run the test for non build-and-test workflow)
skip-if-does-not-exist: true
aws_oicd_role_arn: ${{ inputs.aws_oicd_role_arn }}
- name: (Re-)configure AWS credentials # necessary to upload reports to S3 after a long-running test
if: ${{ !cancelled() && (inputs.aws_oicd_role_arn != '') }}
@@ -236,4 +232,3 @@ runs:
with:
report-dir: /tmp/test_output/allure/results
unique-key: ${{ inputs.build_type }}-${{ inputs.pg_version }}
aws_oicd_role_arn: ${{ inputs.aws_oicd_role_arn }}

View File

@@ -14,11 +14,9 @@ runs:
name: coverage-data-artifact
path: /tmp/coverage
skip-if-does-not-exist: true # skip if there's no previous coverage to download
aws_oicd_role_arn: ${{ inputs.aws_oicd_role_arn }}
- name: Upload coverage data
uses: ./.github/actions/upload
with:
name: coverage-data-artifact
path: /tmp/coverage
aws_oicd_role_arn: ${{ inputs.aws_oicd_role_arn }}

View File

@@ -14,10 +14,6 @@ inputs:
prefix:
description: "S3 prefix. Default is '${GITHUB_SHA}/${GITHUB_RUN_ID}/${GITHUB_RUN_ATTEMPT}'"
required: false
aws_oicd_role_arn:
description: "the OIDC role arn for aws auth"
required: false
default: ""
runs:
using: "composite"
@@ -57,13 +53,6 @@ runs:
echo 'SKIPPED=false' >> $GITHUB_OUTPUT
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ inputs.aws_oicd_role_arn }}
role-duration-seconds: 3600
- name: Upload artifact
if: ${{ steps.prepare-artifact.outputs.SKIPPED == 'false' }}
shell: bash -euxo pipefail {0}

View File

@@ -70,7 +70,6 @@ jobs:
name: neon-${{ runner.os }}-${{ runner.arch }}-release-artifact
path: /tmp/neon/
prefix: latest
aws_oicd_role_arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
# we create a table that has one row for each database that we want to restore with the status whether the restore is done
- name: Create benchmark_restore_status table if it does not exist

View File

@@ -31,13 +31,12 @@ defaults:
env:
RUST_BACKTRACE: 1
COPT: '-Werror'
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_DEV }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_KEY_DEV }}
jobs:
build-neon:
runs-on: ${{ fromJson(format('["self-hosted", "{0}"]', inputs.arch == 'arm64' && 'large-arm64' || 'large')) }}
permissions:
id-token: write # aws-actions/configure-aws-credentials
contents: read
container:
image: ${{ inputs.build-tools-image }}
credentials:
@@ -206,13 +205,6 @@ jobs:
done
fi
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 18000 # 5 hours
- name: Run rust tests
env:
NEXTEST_RETRIES: 3
@@ -264,7 +256,6 @@ jobs:
with:
name: neon-${{ runner.os }}-${{ runner.arch }}-${{ inputs.build-type }}-artifact
path: /tmp/neon
aws_oicd_role_arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
# XXX: keep this after the binaries.list is formed, so the coverage can properly work later
- name: Merge and upload coverage data
@@ -274,10 +265,6 @@ jobs:
regress-tests:
# Don't run regression tests on debug arm64 builds
if: inputs.build-type != 'debug' || inputs.arch != 'arm64'
permissions:
id-token: write # aws-actions/configure-aws-credentials
contents: read
statuses: write
needs: [ build-neon ]
runs-on: ${{ fromJson(format('["self-hosted", "{0}"]', inputs.arch == 'arm64' && 'large-arm64' || 'large')) }}
container:
@@ -296,7 +283,7 @@ jobs:
submodules: true
- name: Pytest regression tests
continue-on-error: ${{ matrix.lfc_state == 'with-lfc' && inputs.build-type == 'debug' }}
continue-on-error: ${{ matrix.lfc_state == 'with-lfc' }}
uses: ./.github/actions/run-python-test-set
timeout-minutes: 60
with:
@@ -308,7 +295,6 @@ jobs:
real_s3_region: eu-central-1
rerun_failed: true
pg_version: ${{ matrix.pg_version }}
aws_oicd_role_arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
env:
TEST_RESULT_CONNSTR: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}
CHECK_ONDISK_DATA_COMPATIBILITY: nonempty

View File

@@ -105,7 +105,6 @@ jobs:
name: neon-${{ runner.os }}-${{ runner.arch }}-release-artifact
path: /tmp/neon/
prefix: latest
aws_oicd_role_arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
- name: Create Neon Project
id: create-neon-project
@@ -205,7 +204,6 @@ jobs:
name: neon-${{ runner.os }}-${{ runner.arch }}-release-artifact
path: /tmp/neon/
prefix: latest
aws_oicd_role_arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
- name: Run Logical Replication benchmarks
uses: ./.github/actions/run-python-test-set
@@ -407,7 +405,6 @@ jobs:
name: neon-${{ runner.os }}-${{ runner.arch }}-release-artifact
path: /tmp/neon/
prefix: latest
aws_oicd_role_arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
- name: Create Neon Project
if: contains(fromJson('["neonvm-captest-new", "neonvm-captest-freetier", "neonvm-azure-captest-freetier", "neonvm-azure-captest-new"]'), matrix.platform)
@@ -711,7 +708,6 @@ jobs:
name: neon-${{ runner.os }}-${{ runner.arch }}-release-artifact
path: /tmp/neon/
prefix: latest
aws_oicd_role_arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
- name: Set up Connection String
id: set-up-connstr
@@ -822,7 +818,6 @@ jobs:
name: neon-${{ runner.os }}-${{ runner.arch }}-release-artifact
path: /tmp/neon/
prefix: latest
aws_oicd_role_arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
- name: Get Connstring Secret Name
run: |
@@ -931,7 +926,6 @@ jobs:
name: neon-${{ runner.os }}-${{ runner.arch }}-release-artifact
path: /tmp/neon/
prefix: latest
aws_oicd_role_arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
- name: Set up Connection String
id: set-up-connstr

View File

@@ -21,6 +21,8 @@ concurrency:
env:
RUST_BACKTRACE: 1
COPT: '-Werror'
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_DEV }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_KEY_DEV }}
# A concurrency group that we use for e2e-tests runs, matches `concurrency.group` above with `github.repository` as a prefix
E2E_CONCURRENCY_GROUP: ${{ github.repository }}-e2e-tests-${{ github.ref_name }}-${{ github.ref_name == 'main' && github.sha || 'anysha' }}
@@ -253,15 +255,15 @@ jobs:
build-tools-image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm
build-tag: ${{ needs.tag.outputs.build-tag }}
build-type: ${{ matrix.build-type }}
# Run tests on all Postgres versions in release builds and only on the latest version in debug builds.
# Run without LFC on v17 release and debug builds only. For all the other cases LFC is enabled.
# Run tests on all Postgres versions in release builds and only on the latest version in debug builds
# run without LFC on v17 release only
test-cfg: |
${{ matrix.build-type == 'release' && '[{"pg_version":"v14", "lfc_state": "with-lfc"},
{"pg_version":"v15", "lfc_state": "with-lfc"},
{"pg_version":"v16", "lfc_state": "with-lfc"},
{"pg_version":"v17", "lfc_state": "with-lfc"},
{"pg_version":"v17", "lfc_state": "without-lfc"}]'
|| '[{"pg_version":"v17", "lfc_state": "without-lfc" }]' }}
${{ matrix.build-type == 'release' && '[{"pg_version":"v14", "lfc_state": "without-lfc"},
{"pg_version":"v15", "lfc_state": "without-lfc"},
{"pg_version":"v16", "lfc_state": "without-lfc"},
{"pg_version":"v17", "lfc_state": "without-lfc"},
{"pg_version":"v17", "lfc_state": "with-lfc"}]'
|| '[{"pg_version":"v17", "lfc_state": "without-lfc"}]' }}
secrets: inherit
# Keep `benchmarks` job outside of `build-and-test-locally` workflow to make job failures non-blocking
@@ -358,11 +360,6 @@ jobs:
create-test-report:
needs: [ check-permissions, build-and-test-locally, coverage-report, build-build-tools-image, benchmarks ]
if: ${{ !cancelled() && contains(fromJSON('["skipped", "success"]'), needs.check-permissions.result) }}
permissions:
id-token: write # aws-actions/configure-aws-credentials
statuses: write
contents: write
pull-requests: write
outputs:
report-url: ${{ steps.create-allure-report.outputs.report-url }}
@@ -383,7 +380,6 @@ jobs:
uses: ./.github/actions/allure-report-generate
with:
store-test-results-into-db: true
aws_oicd_role_arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
env:
REGRESS_TEST_RESULT_CONNSTR_NEW: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}
@@ -415,10 +411,6 @@ jobs:
coverage-report:
if: ${{ !startsWith(github.ref_name, 'release') }}
needs: [ check-permissions, build-build-tools-image, build-and-test-locally ]
permissions:
id-token: write # aws-actions/configure-aws-credentials
statuses: write
contents: write
runs-on: [ self-hosted, small ]
container:
image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm
@@ -445,14 +437,12 @@ jobs:
with:
name: neon-${{ runner.os }}-${{ runner.arch }}-${{ matrix.build_type }}-artifact
path: /tmp/neon
aws_oicd_role_arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
- name: Get coverage artifact
uses: ./.github/actions/download
with:
name: coverage-data-artifact
path: /tmp/coverage
aws_oicd_role_arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
- name: Merge coverage data
run: scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/coverage merge
@@ -583,10 +573,6 @@ jobs:
neon-image:
needs: [ neon-image-arch, tag ]
runs-on: ubuntu-22.04
permissions:
id-token: write # aws-actions/configure-aws-credentials
statuses: write
contents: read
steps:
- uses: docker/login-action@v3
@@ -601,15 +587,11 @@ jobs:
neondatabase/neon:${{ needs.tag.outputs.build-tag }}-bookworm-x64 \
neondatabase/neon:${{ needs.tag.outputs.build-tag }}-bookworm-arm64
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
- uses: docker/login-action@v3
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 3600
- name: Login to Amazon Dev ECR
uses: aws-actions/amazon-ecr-login@v2
registry: 369495373322.dkr.ecr.eu-central-1.amazonaws.com
username: ${{ secrets.AWS_ACCESS_KEY_DEV }}
password: ${{ secrets.AWS_SECRET_KEY_DEV }}
- name: Push multi-arch image to ECR
run: |
@@ -618,10 +600,6 @@ jobs:
compute-node-image-arch:
needs: [ check-permissions, build-build-tools-image, tag ]
permissions:
id-token: write # aws-actions/configure-aws-credentials
statuses: write
contents: read
strategy:
fail-fast: false
matrix:
@@ -662,15 +640,11 @@ jobs:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
- uses: docker/login-action@v3
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 3600
- name: Login to Amazon Dev ECR
uses: aws-actions/amazon-ecr-login@v2
registry: 369495373322.dkr.ecr.eu-central-1.amazonaws.com
username: ${{ secrets.AWS_ACCESS_KEY_DEV }}
password: ${{ secrets.AWS_SECRET_KEY_DEV }}
- uses: docker/login-action@v3
with:
@@ -743,10 +717,6 @@ jobs:
compute-node-image:
needs: [ compute-node-image-arch, tag ]
permissions:
id-token: write # aws-actions/configure-aws-credentials
statuses: write
contents: read
runs-on: ubuntu-22.04
strategy:
@@ -791,15 +761,11 @@ jobs:
neondatabase/compute-tools:${{ needs.tag.outputs.build-tag }}-${{ matrix.version.debian }}-x64 \
neondatabase/compute-tools:${{ needs.tag.outputs.build-tag }}-${{ matrix.version.debian }}-arm64
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
- uses: docker/login-action@v3
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 3600
- name: Login to Amazon Dev ECR
uses: aws-actions/amazon-ecr-login@v2
registry: 369495373322.dkr.ecr.eu-central-1.amazonaws.com
username: ${{ secrets.AWS_ACCESS_KEY_DEV }}
password: ${{ secrets.AWS_SECRET_KEY_DEV }}
- name: Push multi-arch compute-node-${{ matrix.version.pg }} image to ECR
run: |
@@ -829,7 +795,7 @@ jobs:
- pg: v17
debian: bookworm
env:
VM_BUILDER_VERSION: v0.37.1
VM_BUILDER_VERSION: v0.35.0
steps:
- uses: actions/checkout@v4
@@ -924,9 +890,7 @@ jobs:
runs-on: ubuntu-22.04
permissions:
id-token: write # aws-actions/configure-aws-credentials
statuses: write
contents: read
id-token: write # for `aws-actions/configure-aws-credentials`
env:
VERSIONS: v14 v15 v16 v17
@@ -937,15 +901,12 @@ jobs:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
- name: Login to dev ECR
uses: docker/login-action@v3
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 3600
- name: Login to Amazon Dev ECR
uses: aws-actions/amazon-ecr-login@v2
registry: 369495373322.dkr.ecr.eu-central-1.amazonaws.com
username: ${{ secrets.AWS_ACCESS_KEY_DEV }}
password: ${{ secrets.AWS_SECRET_KEY_DEV }}
- name: Copy vm-compute-node images to ECR
run: |
@@ -1099,79 +1060,12 @@ jobs:
needs: [ check-permissions, promote-images, tag, build-and-test-locally, trigger-custom-extensions-build-and-wait, push-to-acr-dev, push-to-acr-prod ]
# `!failure() && !cancelled()` is required because the workflow depends on the job that can be skipped: `push-to-acr-dev` and `push-to-acr-prod`
if: (github.ref_name == 'main' || github.ref_name == 'release' || github.ref_name == 'release-proxy' || github.ref_name == 'release-compute') && !failure() && !cancelled()
permissions:
id-token: write # aws-actions/configure-aws-credentials
statuses: write
contents: write
runs-on: [ self-hosted, small ]
container: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/ansible:latest
steps:
- uses: actions/checkout@v4
- name: Create git tag and GitHub release
if: github.ref_name == 'release' || github.ref_name == 'release-proxy' || github.ref_name == 'release-compute'
uses: actions/github-script@v7
with:
retries: 5
script: |
const tag = "${{ needs.tag.outputs.build-tag }}";
try {
const existingRef = await github.rest.git.getRef({
owner: context.repo.owner,
repo: context.repo.repo,
ref: `tags/${tag}`,
});
if (existingRef.data.object.sha !== context.sha) {
throw new Error(`Tag ${tag} already exists but points to a different commit (expected: ${context.sha}, actual: ${existingRef.data.object.sha}).`);
}
console.log(`Tag ${tag} already exists and points to ${context.sha} as expected.`);
} catch (error) {
if (error.status !== 404) {
throw error;
}
console.log(`Tag ${tag} does not exist. Creating it...`);
await github.rest.git.createRef({
owner: context.repo.owner,
repo: context.repo.repo,
ref: `refs/tags/${tag}`,
sha: context.sha,
});
console.log(`Tag ${tag} created successfully.`);
}
// TODO: check how GitHub releases look for proxy/compute releases and enable them if they're ok
if (context.ref !== 'refs/heads/release') {
console.log(`GitHub release skipped for ${context.ref}.`);
return;
}
try {
const existingRelease = await github.rest.repos.getReleaseByTag({
owner: context.repo.owner,
repo: context.repo.repo,
tag: tag,
});
console.log(`Release for tag ${tag} already exists (ID: ${existingRelease.data.id}).`);
} catch (error) {
if (error.status !== 404) {
throw error;
}
console.log(`Release for tag ${tag} does not exist. Creating it...`);
await github.rest.repos.createRelease({
owner: context.repo.owner,
repo: context.repo.repo,
tag_name: tag,
generate_release_notes: true,
});
console.log(`Release for tag ${tag} created successfully.`);
}
- name: Trigger deploy workflow
env:
GH_TOKEN: ${{ secrets.CI_ACCESS_TOKEN }}
@@ -1221,13 +1115,38 @@ jobs:
exit 1
fi
- name: Create git tag
if: github.ref_name == 'release' || github.ref_name == 'release-proxy' || github.ref_name == 'release-compute'
uses: actions/github-script@v7
with:
# Retry script for 5XX server errors: https://github.com/actions/github-script#retries
retries: 5
script: |
await github.rest.git.createRef({
owner: context.repo.owner,
repo: context.repo.repo,
ref: "refs/tags/${{ needs.tag.outputs.build-tag }}",
sha: context.sha,
})
# TODO: check how GitHub releases look for proxy releases and enable it if it's ok
- name: Create GitHub release
if: github.ref_name == 'release'
uses: actions/github-script@v7
with:
# Retry script for 5XX server errors: https://github.com/actions/github-script#retries
retries: 5
script: |
await github.rest.repos.createRelease({
owner: context.repo.owner,
repo: context.repo.repo,
tag_name: "${{ needs.tag.outputs.build-tag }}",
generate_release_notes: true,
})
# The job runs on `release` branch and copies compatibility data and Neon artifact from the last *release PR* to the latest directory
promote-compatibility-data:
needs: [ deploy ]
permissions:
id-token: write # aws-actions/configure-aws-credentials
statuses: write
contents: read
# `!failure() && !cancelled()` is required because the workflow transitively depends on the job that can be skipped: `push-to-acr-dev` and `push-to-acr-prod`
if: github.ref_name == 'release' && !failure() && !cancelled()

View File

@@ -19,19 +19,15 @@ concurrency:
group: ${{ github.workflow }}
cancel-in-progress: true
permissions:
id-token: write # aws-actions/configure-aws-credentials
jobs:
regress:
env:
POSTGRES_DISTRIB_DIR: /tmp/neon/pg_install
DEFAULT_PG_VERSION: 16
TEST_OUTPUT: /tmp/test_output
BUILD_TYPE: remote
strategy:
fail-fast: false
matrix:
pg-version: [16, 17]
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_DEV }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_KEY_DEV }}
runs-on: us-east-2
container:
@@ -44,11 +40,9 @@ jobs:
submodules: true
- name: Patch the test
env:
PG_VERSION: ${{matrix.pg-version}}
run: |
cd "vendor/postgres-v${PG_VERSION}"
patch -p1 < "../../compute/patches/cloud_regress_pg${PG_VERSION}.patch"
cd "vendor/postgres-v${DEFAULT_PG_VERSION}"
patch -p1 < "../../compute/patches/cloud_regress_pg${DEFAULT_PG_VERSION}.patch"
- name: Generate a random password
id: pwgen
@@ -61,9 +55,8 @@ jobs:
- name: Change tests according to the generated password
env:
DBPASS: ${{ steps.pwgen.outputs.DBPASS }}
PG_VERSION: ${{matrix.pg-version}}
run: |
cd vendor/postgres-v"${PG_VERSION}"/src/test/regress
cd vendor/postgres-v"${DEFAULT_PG_VERSION}"/src/test/regress
for fname in sql/*.sql expected/*.out; do
sed -i.bak s/NEON_PASSWORD_PLACEHOLDER/"'${DBPASS}'"/ "${fname}"
done
@@ -79,44 +72,27 @@ jobs:
name: neon-${{ runner.os }}-${{ runner.arch }}-release-artifact
path: /tmp/neon/
prefix: latest
aws_oicd_role_arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
- name: Create a new branch
id: create-branch
uses: ./.github/actions/neon-branch-create
with:
api_key: ${{ secrets.NEON_STAGING_API_KEY }}
project_id: ${{ vars[format('PGREGRESS_PG{0}_PROJECT_ID', matrix.pg-version)] }}
- name: Run the regression tests
uses: ./.github/actions/run-python-test-set
with:
build_type: ${{ env.BUILD_TYPE }}
test_selection: cloud_regress
pg_version: ${{matrix.pg-version}}
pg_version: ${{ env.DEFAULT_PG_VERSION }}
extra_params: -m remote_cluster
env:
BENCHMARK_CONNSTR: ${{steps.create-branch.outputs.dsn}}
- name: Delete branch
uses: ./.github/actions/neon-branch-delete
with:
api_key: ${{ secrets.NEON_STAGING_API_KEY }}
project_id: ${{ vars[format('PGREGRESS_PG{0}_PROJECT_ID', matrix.pg-version)] }}
branch_id: ${{steps.create-branch.outputs.branch_id}}
BENCHMARK_CONNSTR: ${{ secrets.PG_REGRESS_CONNSTR }}
- name: Create Allure report
id: create-allure-report
if: ${{ !cancelled() }}
uses: ./.github/actions/allure-report-generate
with:
aws_oicd_role_arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
- name: Post to a Slack channel
if: ${{ github.event.schedule && failure() }}
uses: slackapi/slack-github-action@v1
with:
channel-id: ${{ vars.SLACK_ON_CALL_QA_STAGING_STREAM }}
channel-id: "C033QLM5P7D" # on-call-staging-stream
slack-message: |
Periodic pg_regress on staging: ${{ job.status }}
<${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}|GitHub Run>

View File

@@ -64,7 +64,6 @@ jobs:
name: neon-${{ runner.os }}-${{ runner.arch }}-release-artifact
path: /tmp/neon/
prefix: latest
aws_oicd_role_arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
- name: Create Neon Project
if: ${{ matrix.target_project == 'new_empty_project' }}

View File

@@ -143,10 +143,6 @@ jobs:
gather-rust-build-stats:
needs: [ check-permissions, build-build-tools-image ]
permissions:
id-token: write # aws-actions/configure-aws-credentials
statuses: write
contents: write
if: |
contains(github.event.pull_request.labels.*.name, 'run-extra-build-stats') ||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-*') ||
@@ -181,18 +177,13 @@ jobs:
- name: Produce the build stats
run: PQ_LIB_DIR=$(pwd)/pg_install/v17/lib cargo build --all --release --timings -j$(nproc)
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 3600
- name: Upload the build stats
id: upload-stats
env:
BUCKET: neon-github-public-dev
SHA: ${{ github.event.pull_request.head.sha || github.sha }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_DEV }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_KEY_DEV }}
run: |
REPORT_URL=https://${BUCKET}.s3.amazonaws.com/build-stats/${SHA}/${GITHUB_RUN_ID}/cargo-timing.html
aws s3 cp --only-show-errors ./target/cargo-timings/cargo-timing.html "s3://${BUCKET}/build-stats/${SHA}/${GITHUB_RUN_ID}/"

View File

@@ -21,9 +21,6 @@ defaults:
run:
shell: bash -euo pipefail {0}
permissions:
id-token: write # aws-actions/configure-aws-credentials
concurrency:
group: ${{ github.workflow }}
cancel-in-progress: false
@@ -41,6 +38,8 @@ jobs:
env:
API_KEY: ${{ secrets.PERIODIC_PAGEBENCH_EC2_RUNNER_API_KEY }}
RUN_ID: ${{ github.run_id }}
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_EC2_US_TEST_RUNNER_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY : ${{ secrets.AWS_EC2_US_TEST_RUNNER_ACCESS_KEY_SECRET }}
AWS_DEFAULT_REGION : "eu-central-1"
AWS_INSTANCE_ID : "i-02a59a3bf86bc7e74"
steps:
@@ -51,13 +50,6 @@ jobs:
- name: Show my own (github runner) external IP address - useful for IP allowlisting
run: curl https://ifconfig.me
- name: Assume AWS OIDC role that allows managing (start/stop/describe...) the EC2 machine
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_MANAGE_BENCHMARK_EC2_VMS_ARN }}
role-duration-seconds: 3600
- name: Start EC2 instance and wait for the instance to boot up
run: |
aws ec2 start-instances --instance-ids $AWS_INSTANCE_ID
@@ -132,10 +124,11 @@ jobs:
cat "test_log_${GITHUB_RUN_ID}"
- name: Create Allure report
env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_DEV }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_KEY_DEV }}
if: ${{ !cancelled() }}
uses: ./.github/actions/allure-report-generate
with:
aws_oicd_role_arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
- name: Post to a Slack channel
if: ${{ github.event.schedule && failure() }}
@@ -155,14 +148,6 @@ jobs:
-H "Authorization: Bearer $API_KEY" \
-d ''
- name: Assume AWS OIDC role that allows managing (start/stop/describe...) the EC2 machine
if: always() && steps.poll_step.outputs.too_many_runs != 'true'
uses: aws-actions/configure-aws-credentials@v4
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_MANAGE_BENCHMARK_EC2_VMS_ARN }}
role-duration-seconds: 3600
- name: Stop EC2 instance and wait for the instance to be stopped
if: always() && steps.poll_step.outputs.too_many_runs != 'true'
run: |

View File

@@ -25,13 +25,11 @@ defaults:
run:
shell: bash -euxo pipefail {0}
permissions:
id-token: write # aws-actions/configure-aws-credentials
statuses: write # require for posting a status update
env:
DEFAULT_PG_VERSION: 16
PLATFORM: neon-captest-new
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_DEV }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_KEY_DEV }}
AWS_DEFAULT_REGION: eu-central-1
jobs:
@@ -96,7 +94,6 @@ jobs:
name: neon-${{ runner.os }}-${{ runner.arch }}-release-artifact
path: /tmp/neon/
prefix: latest
aws_oicd_role_arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
- name: Create Neon Project
id: create-neon-project
@@ -129,7 +126,6 @@ jobs:
uses: ./.github/actions/allure-report-generate
with:
store-test-results-into-db: true
aws_oicd_role_arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
env:
REGRESS_TEST_RESULT_CONNSTR_NEW: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}
@@ -163,7 +159,6 @@ jobs:
name: neon-${{ runner.os }}-${{ runner.arch }}-release-artifact
path: /tmp/neon/
prefix: latest
aws_oicd_role_arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
- name: Create Neon Project
id: create-neon-project
@@ -196,7 +191,6 @@ jobs:
uses: ./.github/actions/allure-report-generate
with:
store-test-results-into-db: true
aws_oicd_role_arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
env:
REGRESS_TEST_RESULT_CONNSTR_NEW: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}

View File

@@ -67,7 +67,7 @@ jobs:
runs-on: ubuntu-22.04
permissions:
id-token: write # for `azure/login` and aws auth
id-token: write # for `azure/login`
steps:
- uses: docker/login-action@v3
@@ -75,15 +75,11 @@ jobs:
username: ${{ secrets.NEON_DOCKERHUB_USERNAME }}
password: ${{ secrets.NEON_DOCKERHUB_PASSWORD }}
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@v4
- uses: docker/login-action@v3
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 3600
- name: Login to Amazon Dev ECR
uses: aws-actions/amazon-ecr-login@v2
registry: 369495373322.dkr.ecr.eu-central-1.amazonaws.com
username: ${{ secrets.AWS_ACCESS_KEY_DEV }}
password: ${{ secrets.AWS_SECRET_KEY_DEV }}
- name: Azure login
uses: azure/login@6c251865b4e6290e7b78be643ea2d005bc51f69a # @v2.1.1

View File

@@ -63,7 +63,6 @@ jobs:
if: always()
permissions:
statuses: write # for `github.repos.createCommitStatus(...)`
contents: write
needs:
- get-changed-files
- check-codestyle-python

View File

@@ -3,7 +3,7 @@ name: Create Release Branch
on:
schedule:
# It should be kept in sync with if-condition in jobs
- cron: '0 6 * * FRI' # Storage release
- cron: '0 6 * * MON' # Storage release
- cron: '0 6 * * THU' # Proxy release
workflow_dispatch:
inputs:
@@ -29,7 +29,7 @@ defaults:
jobs:
create-storage-release-branch:
if: ${{ github.event.schedule == '0 6 * * FRI' || inputs.create-storage-release-branch }}
if: ${{ github.event.schedule == '0 6 * * MON' || inputs.create-storage-release-branch }}
permissions:
contents: write

View File

@@ -1,29 +1,16 @@
# Autoscaling
/libs/vm_monitor/ @neondatabase/autoscaling
# DevProd
/.github/ @neondatabase/developer-productivity
# Compute
/pgxn/ @neondatabase/compute
/vendor/ @neondatabase/compute
/compute/ @neondatabase/compute
/compute_tools/ @neondatabase/compute
# Proxy
/compute_tools/ @neondatabase/control-plane @neondatabase/compute
/libs/pageserver_api/ @neondatabase/storage
/libs/postgres_ffi/ @neondatabase/compute @neondatabase/storage
/libs/proxy/ @neondatabase/proxy
/proxy/ @neondatabase/proxy
# Storage
/libs/remote_storage/ @neondatabase/storage
/libs/safekeeper_api/ @neondatabase/storage
/libs/vm_monitor/ @neondatabase/autoscaling
/pageserver/ @neondatabase/storage
/pgxn/ @neondatabase/compute
/pgxn/neon/ @neondatabase/compute @neondatabase/storage
/proxy/ @neondatabase/proxy
/safekeeper/ @neondatabase/storage
/storage_controller @neondatabase/storage
/storage_scrubber @neondatabase/storage
/libs/pageserver_api/ @neondatabase/storage
/libs/remote_storage/ @neondatabase/storage
/libs/safekeeper_api/ @neondatabase/storage
# Shared
/pgxn/neon/ @neondatabase/compute @neondatabase/storage
/libs/compute_api/ @neondatabase/compute @neondatabase/control-plane
/libs/postgres_ffi/ @neondatabase/compute @neondatabase/storage
/vendor/ @neondatabase/compute

Cargo.lock (generated): 515 lines changed

File diff suppressed because it is too large

View File

@@ -51,6 +51,10 @@ anyhow = { version = "1.0", features = ["backtrace"] }
arc-swap = "1.6"
async-compression = { version = "0.4.0", features = ["tokio", "gzip", "zstd"] }
atomic-take = "1.1.0"
azure_core = { version = "0.19", default-features = false, features = ["enable_reqwest_rustls", "hmac_rust"] }
azure_identity = { version = "0.19", default-features = false, features = ["enable_reqwest_rustls"] }
azure_storage = { version = "0.19", default-features = false, features = ["enable_reqwest_rustls"] }
azure_storage_blobs = { version = "0.19", default-features = false, features = ["enable_reqwest_rustls"] }
flate2 = "1.0.26"
async-stream = "0.3"
async-trait = "0.1"
@@ -212,12 +216,6 @@ postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git",
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon" }
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon" }
## Azure SDK crates
azure_core = { git = "https://github.com/neondatabase/azure-sdk-for-rust.git", branch = "neon", default-features = false, features = ["enable_reqwest_rustls", "hmac_rust"] }
azure_identity = { git = "https://github.com/neondatabase/azure-sdk-for-rust.git", branch = "neon", default-features = false, features = ["enable_reqwest_rustls"] }
azure_storage = { git = "https://github.com/neondatabase/azure-sdk-for-rust.git", branch = "neon", default-features = false, features = ["enable_reqwest_rustls"] }
azure_storage_blobs = { git = "https://github.com/neondatabase/azure-sdk-for-rust.git", branch = "neon", default-features = false, features = ["enable_reqwest_rustls"] }
## Local libraries
compute_api = { version = "0.1", path = "./libs/compute_api/" }
consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" }

View File

@@ -115,7 +115,7 @@ RUN set -e \
# Keep the version the same as in compute/compute-node.Dockerfile and
# test_runner/regress/test_compute_metrics.py.
ENV SQL_EXPORTER_VERSION=0.16.0
ENV SQL_EXPORTER_VERSION=0.13.1
RUN curl -fsSL \
"https://github.com/burningalchemist/sql_exporter/releases/download/${SQL_EXPORTER_VERSION}/sql_exporter-${SQL_EXPORTER_VERSION}.linux-$(case "$(uname -m)" in x86_64) echo amd64;; aarch64) echo arm64;; esac).tar.gz" \
--output sql_exporter.tar.gz \

View File

@@ -1324,7 +1324,7 @@ FROM quay.io/prometheuscommunity/postgres-exporter:v0.12.1 AS postgres-exporter
# Keep the version the same as in build-tools.Dockerfile and
# test_runner/regress/test_compute_metrics.py.
FROM burningalchemist/sql_exporter:0.16.0 AS sql-exporter
FROM burningalchemist/sql_exporter:0.13.1 AS sql-exporter
#########################################################################################
#

View File

@@ -19,10 +19,3 @@ max_prepared_statements=0
admin_users=postgres
unix_socket_dir=/tmp/
unix_socket_mode=0777
;; Disable connection logging. It produces a lot of logs that no one looks at,
;; and we can get similar log entries from the proxy too. We had incidents in
;; the past where the logging significantly stressed the log device or pgbouncer
;; itself.
log_connections=0
log_disconnections=0

File diff suppressed because it is too large

View File

@@ -246,48 +246,47 @@ fn try_spec_from_cli(
let compute_id = matches.get_one::<String>("compute-id");
let control_plane_uri = matches.get_one::<String>("control-plane-uri");
// First, try to get cluster spec from the cli argument
if let Some(spec_json) = spec_json {
info!("got spec from cli argument {}", spec_json);
return Ok(CliSpecParams {
spec: Some(serde_json::from_str(spec_json)?),
live_config_allowed: false,
});
}
// Second, try to read it from the file if path is provided
if let Some(spec_path) = spec_path {
let file = File::open(Path::new(spec_path))?;
return Ok(CliSpecParams {
spec: Some(serde_json::from_reader(file)?),
live_config_allowed: true,
});
}
let Some(compute_id) = compute_id else {
panic!(
"compute spec should be provided by one of the following ways: \
--spec OR --spec-path OR --control-plane-uri and --compute-id"
);
};
let Some(control_plane_uri) = control_plane_uri else {
panic!("must specify both --control-plane-uri and --compute-id or none");
};
match get_spec_from_control_plane(control_plane_uri, compute_id) {
Ok(spec) => Ok(CliSpecParams {
spec,
live_config_allowed: true,
}),
Err(e) => {
error!(
"cannot get response from control plane: {}\n\
neither spec nor confirmation that compute is in the Empty state was received",
e
);
Err(e)
let spec;
let mut live_config_allowed = false;
match spec_json {
// First, try to get cluster spec from the cli argument
Some(json) => {
info!("got spec from cli argument {}", json);
spec = Some(serde_json::from_str(json)?);
}
}
None => {
// Second, try to read it from the file if path is provided
if let Some(sp) = spec_path {
let path = Path::new(sp);
let file = File::open(path)?;
spec = Some(serde_json::from_reader(file)?);
live_config_allowed = true;
} else if let Some(id) = compute_id {
if let Some(cp_base) = control_plane_uri {
live_config_allowed = true;
spec = match get_spec_from_control_plane(cp_base, id) {
Ok(s) => s,
Err(e) => {
error!("cannot get response from control plane: {}", e);
panic!("neither spec nor confirmation that compute is in the Empty state was received");
}
};
} else {
panic!("must specify both --control-plane-uri and --compute-id or none");
}
} else {
panic!(
"compute spec should be provided by one of the following ways: \
--spec OR --spec-path OR --control-plane-uri and --compute-id"
);
}
}
};
Ok(CliSpecParams {
spec,
live_config_allowed,
})
}
struct CliSpecParams {
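
Editor's note: the first variant of try_spec_from_cli above flattens the nested match into early returns with `let … else` guards. A minimal standalone sketch of that pattern (illustrative names, not the actual compute_ctl API):

```rust
// Sketch only: resolving a spec from one of several optional sources,
// bailing out early when a required pairing is missing.
fn resolve(
    spec_json: Option<&str>,
    spec_path: Option<&str>,
    control_plane_uri: Option<&str>,
    compute_id: Option<&str>,
) -> Result<String, String> {
    if let Some(json) = spec_json {
        return Ok(format!("from inline json: {json}"));
    }
    if let Some(path) = spec_path {
        return Ok(format!("from file: {path}"));
    }
    // `let … else` requires a diverging else block, which keeps the
    // happy path unindented compared to a nested `match`.
    let Some(id) = compute_id else {
        return Err("provide --spec, --spec-path, or --control-plane-uri with --compute-id".into());
    };
    let Some(uri) = control_plane_uri else {
        return Err("must specify both --control-plane-uri and --compute-id or none".into());
    };
    Ok(format!("from control plane {uri} for compute {id}"))
}

fn main() {
    println!("{:?}", resolve(None, Some("/tmp/spec.json"), None, None));
    println!("{:?}", resolve(None, None, Some("http://cp"), None)); // missing --compute-id
}
```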

View File

@@ -537,14 +537,12 @@ components:
properties:
extname:
type: string
version:
type: string
versions:
type: array
items:
type: string
n_databases:
type: integer
owned_by_superuser:
type: integer
SetRoleGrantsRequest:
type: object

View File

@@ -1,6 +1,7 @@
use compute_api::responses::{InstalledExtension, InstalledExtensions};
use metrics::proto::MetricFamily;
use std::collections::HashMap;
use std::collections::HashSet;
use anyhow::Result;
use postgres::{Client, NoTls};
@@ -37,77 +38,61 @@ fn list_dbs(client: &mut Client) -> Result<Vec<String>> {
/// Connect to every database (see list_dbs above) and get the list of installed extensions.
///
/// Same extension can be installed in multiple databases with different versions,
/// so we report a separate metric (number of databases where it is installed)
/// for each extension version.
/// we only keep the highest and lowest version across all databases.
pub fn get_installed_extensions(mut conf: postgres::config::Config) -> Result<InstalledExtensions> {
conf.application_name("compute_ctl:get_installed_extensions");
let mut client = conf.connect(NoTls)?;
let databases: Vec<String> = list_dbs(&mut client)?;
let mut extensions_map: HashMap<(String, String, String), InstalledExtension> = HashMap::new();
let mut extensions_map: HashMap<String, InstalledExtension> = HashMap::new();
for db in databases.iter() {
conf.dbname(db);
let mut db_client = conf.connect(NoTls)?;
let extensions: Vec<(String, String, i32)> = db_client
let extensions: Vec<(String, String)> = db_client
.query(
"SELECT extname, extversion, extowner::integer FROM pg_catalog.pg_extension",
"SELECT extname, extversion FROM pg_catalog.pg_extension;",
&[],
)?
.iter()
.map(|row| {
(
row.get("extname"),
row.get("extversion"),
row.get("extowner"),
)
})
.map(|row| (row.get("extname"), row.get("extversion")))
.collect();
for (extname, v, extowner) in extensions.iter() {
for (extname, v) in extensions.iter() {
let version = v.to_string();
// check if the extension is owned by superuser
// 10 is the oid of superuser
let owned_by_superuser = if *extowner == 10 { "1" } else { "0" };
// increment the number of databases where the version of extension is installed
INSTALLED_EXTENSIONS
.with_label_values(&[extname, &version])
.inc();
extensions_map
.entry((
extname.to_string(),
version.clone(),
owned_by_superuser.to_string(),
))
.entry(extname.to_string())
.and_modify(|e| {
e.versions.insert(version.clone());
// count the number of databases where the extension is installed
e.n_databases += 1;
})
.or_insert(InstalledExtension {
extname: extname.to_string(),
version: version.clone(),
versions: HashSet::from([version.clone()]),
n_databases: 1,
owned_by_superuser: owned_by_superuser.to_string(),
});
}
}
for (key, ext) in extensions_map.iter() {
let (extname, version, owned_by_superuser) = key;
let n_databases = ext.n_databases as u64;
INSTALLED_EXTENSIONS
.with_label_values(&[extname, version, owned_by_superuser])
.set(n_databases);
}
Ok(InstalledExtensions {
let res = InstalledExtensions {
extensions: extensions_map.into_values().collect(),
})
};
Ok(res)
}
static INSTALLED_EXTENSIONS: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"compute_installed_extensions",
"Number of databases where the version of extension is installed",
&["extension_name", "version", "owned_by_superuser"]
&["extension_name", "version"]
)
.expect("failed to define a metric")
});
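
Editor's note: both versions of get_installed_extensions above fold per-database rows into a map with the `HashMap` entry API. A minimal self-contained sketch of that aggregation shape (made-up data and illustrative types, not the compute_ctl structs):

```rust
use std::collections::{HashMap, HashSet};

struct Ext {
    versions: HashSet<String>,
    n_databases: u32,
}

fn main() {
    // (database, extname, extversion) rows, as if read from pg_extension per database.
    let rows = [
        ("db1", "pgvector", "0.5.1"),
        ("db2", "pgvector", "0.7.0"),
        ("db2", "plpgsql", "1.0"),
    ];
    let mut map: HashMap<&str, Ext> = HashMap::new();
    for (_db, name, version) in rows {
        map.entry(name)
            .and_modify(|e| {
                e.versions.insert(version.to_string());
                e.n_databases += 1; // count databases where the extension is installed
            })
            .or_insert_with(|| Ext {
                versions: HashSet::from([version.to_string()]),
                n_databases: 1,
            });
    }
    for (name, ext) in &map {
        println!("{name}: {} dbs, versions {:?}", ext.n_databases, ext.versions);
    }
}
```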

View File

@@ -274,7 +274,6 @@ fn fill_remote_storage_secrets_vars(mut cmd: &mut Command) -> &mut Command {
for env_key in [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_SESSION_TOKEN",
"AWS_PROFILE",
// HOME is needed in combination with `AWS_PROFILE` to pick up the SSO sessions.
"HOME",

View File

@@ -810,7 +810,7 @@ impl Endpoint {
}
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(120))
.timeout(Duration::from_secs(30))
.build()
.unwrap();
let response = client

View File

@@ -435,7 +435,7 @@ impl PageServerNode {
) -> anyhow::Result<()> {
let config = Self::parse_config(settings)?;
self.http_client
.set_tenant_config(&models::TenantConfigRequest { tenant_id, config })
.tenant_config(&models::TenantConfigRequest { tenant_id, config })
.await?;
Ok(())

View File

@@ -9,8 +9,8 @@ use pageserver_api::{
},
models::{
EvictionPolicy, EvictionPolicyLayerAccessThreshold, LocationConfigSecondary,
ShardParameters, TenantConfig, TenantConfigPatchRequest, TenantConfigRequest,
TenantShardSplitRequest, TenantShardSplitResponse,
ShardParameters, TenantConfig, TenantConfigRequest, TenantShardSplitRequest,
TenantShardSplitResponse,
},
shard::{ShardStripeSize, TenantShardId},
};
@@ -116,19 +116,9 @@ enum Command {
#[arg(long)]
tenant_shard_id: TenantShardId,
},
/// Set the pageserver tenant configuration of a tenant: this is the configuration structure
/// Modify the pageserver tenant configuration of a tenant: this is the configuration structure
/// that is passed through to pageservers, and does not affect storage controller behavior.
/// Any previous tenant configs are overwritten.
SetTenantConfig {
#[arg(long)]
tenant_id: TenantId,
#[arg(long)]
config: String,
},
/// Patch the pageserver tenant configuration of a tenant. Any fields with null values in the
/// provided JSON are unset from the tenant config and all fields with non-null values are set.
/// Unspecified fields are not changed.
PatchTenantConfig {
TenantConfig {
#[arg(long)]
tenant_id: TenantId,
#[arg(long)]
@@ -559,21 +549,11 @@ async fn main() -> anyhow::Result<()> {
)
.await?;
}
Command::SetTenantConfig { tenant_id, config } => {
Command::TenantConfig { tenant_id, config } => {
let tenant_conf = serde_json::from_str(&config)?;
vps_client
.set_tenant_config(&TenantConfigRequest {
tenant_id,
config: tenant_conf,
})
.await?;
}
Command::PatchTenantConfig { tenant_id, config } => {
let tenant_conf = serde_json::from_str(&config)?;
vps_client
.patch_tenant_config(&TenantConfigPatchRequest {
.tenant_config(&TenantConfigRequest {
tenant_id,
config: tenant_conf,
})
@@ -756,7 +736,7 @@ async fn main() -> anyhow::Result<()> {
threshold,
} => {
vps_client
.set_tenant_config(&TenantConfigRequest {
.tenant_config(&TenantConfigRequest {
tenant_id,
config: TenantConfig {
eviction_policy: Some(EvictionPolicy::LayerAccessThreshold(

View File

@@ -42,7 +42,6 @@ allow = [
"MPL-2.0",
"OpenSSL",
"Unicode-DFS-2016",
"Unicode-3.0",
]
confidence-threshold = 0.8
exceptions = [

View File

@@ -1,5 +1,6 @@
//! Structs representing the JSON formats used in the compute_ctl's HTTP API.
use std::collections::HashSet;
use std::fmt::Display;
use chrono::{DateTime, Utc};
@@ -162,9 +163,8 @@ pub enum ControlPlaneComputeStatus {
#[derive(Clone, Debug, Default, Serialize)]
pub struct InstalledExtension {
pub extname: String,
pub version: String,
pub versions: HashSet<String>,
pub n_databases: u32, // Number of databases using this extension
pub owned_by_superuser: String,
}
#[derive(Clone, Debug, Default, Serialize)]

View File

@@ -75,7 +75,7 @@ pub struct TenantPolicyRequest {
pub scheduling: Option<ShardSchedulingPolicy>,
}
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
#[derive(Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Debug)]
pub struct AvailabilityZone(pub String);
impl Display for AvailabilityZone {
@@ -245,17 +245,6 @@ impl From<NodeAvailability> for NodeAvailabilityWrapper {
}
}
/// Scheduling policy enables us to selectively disable some automatic actions that the
/// controller performs on a tenant shard. This is only set to a non-default value by
/// human intervention, and it is reset to the default value (Active) when the tenant's
/// placement policy is modified away from Attached.
///
/// The typical use of a non-Active scheduling policy is one of:
/// - Pinning a shard to a node (i.e. migrating it there & setting a non-Active scheduling policy)
/// - Working around a bug (e.g. if something is flapping and we need to stop it until the bug is fixed)
///
/// If you're not sure which policy to use to pin a shard to its current location, you probably
/// want Pause.
#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug)]
pub enum ShardSchedulingPolicy {
// Normal mode: the tenant's scheduled locations may be updated at will, including
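
Editor's note: the doc comment above describes how a non-Active scheduling policy pins a shard by opting it out of automatic controller actions. A toy sketch of those semantics (the enum and check are illustrative, not the storage controller's actual code):

```rust
// Toy illustration: a controller pass that leaves pinned shards alone.
#[derive(Clone, Copy, PartialEq, Debug)]
enum ShardSchedulingPolicy {
    Active,
    Pause,
}

fn maybe_reschedule(shard: &str, policy: ShardSchedulingPolicy) {
    if policy != ShardSchedulingPolicy::Active {
        println!("{shard}: pinned ({policy:?}), leaving placement alone");
        return;
    }
    println!("{shard}: eligible for automatic rescheduling");
}

fn main() {
    maybe_reschedule("shard-0", ShardSchedulingPolicy::Active);
    maybe_reschedule("shard-1", ShardSchedulingPolicy::Pause); // set by human intervention
}
```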

View File

@@ -24,7 +24,7 @@ pub struct Key {
/// When working with large numbers of Keys in-memory, it is more efficient to handle them as i128 than as
/// a struct of fields.
#[derive(Clone, Copy, Hash, PartialEq, Eq, Ord, PartialOrd, Serialize, Deserialize, Debug)]
#[derive(Clone, Copy, Hash, PartialEq, Eq, Ord, PartialOrd, Serialize, Deserialize)]
pub struct CompactKey(i128);
/// The storage key size.
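
Editor's note: the comment above motivates handling keys as a single i128. A generic sketch of the packing technique, with illustrative field names and widths (not the actual pageserver Key layout):

```rust
// Sketch: packing three fields into one i128 so large key sets are
// cheap to store, hash, and compare. Field widths are illustrative.
#[derive(Debug, PartialEq)]
struct Key {
    kind: u8,
    relation: u32,
    block: u32,
}

fn pack(k: &Key) -> i128 {
    ((k.kind as i128) << 64) | ((k.relation as i128) << 32) | k.block as i128
}

fn unpack(v: i128) -> Key {
    Key {
        kind: (v >> 64) as u8,
        relation: (v >> 32) as u32,
        block: v as u32,
    }
}

fn main() {
    let k = Key { kind: 1, relation: 42, block: 7 };
    let packed = pack(&k);
    assert_eq!(unpack(packed), k);
    // Comparing packed values orders by (kind, relation, block).
    assert!(pack(&Key { kind: 1, relation: 42, block: 8 }) > packed);
}
```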

View File

@@ -17,7 +17,7 @@ use std::{
use byteorder::{BigEndian, ReadBytesExt};
use postgres_ffi::BLCKSZ;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde::{Deserialize, Serialize};
use serde_with::serde_as;
use utils::{
completion,
@@ -325,115 +325,6 @@ impl Default for ShardParameters {
}
}
#[derive(Debug, Default, Clone, Eq, PartialEq)]
pub enum FieldPatch<T> {
Upsert(T),
Remove,
#[default]
Noop,
}
impl<T> FieldPatch<T> {
fn is_noop(&self) -> bool {
matches!(self, FieldPatch::Noop)
}
pub fn apply(self, target: &mut Option<T>) {
match self {
Self::Upsert(v) => *target = Some(v),
Self::Remove => *target = None,
Self::Noop => {}
}
}
pub fn map<U, E, F: FnOnce(T) -> Result<U, E>>(self, map: F) -> Result<FieldPatch<U>, E> {
match self {
Self::Upsert(v) => Ok(FieldPatch::<U>::Upsert(map(v)?)),
Self::Remove => Ok(FieldPatch::<U>::Remove),
Self::Noop => Ok(FieldPatch::<U>::Noop),
}
}
}
impl<'de, T: Deserialize<'de>> Deserialize<'de> for FieldPatch<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
Option::deserialize(deserializer).map(|opt| match opt {
None => FieldPatch::Remove,
Some(val) => FieldPatch::Upsert(val),
})
}
}
impl<T: Serialize> Serialize for FieldPatch<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match self {
FieldPatch::Upsert(val) => serializer.serialize_some(val),
FieldPatch::Remove => serializer.serialize_none(),
FieldPatch::Noop => unreachable!(),
}
}
}
#[derive(Serialize, Deserialize, Debug, Default, Clone, Eq, PartialEq)]
#[serde(default)]
pub struct TenantConfigPatch {
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub checkpoint_distance: FieldPatch<u64>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub checkpoint_timeout: FieldPatch<String>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub compaction_target_size: FieldPatch<u64>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub compaction_period: FieldPatch<String>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub compaction_threshold: FieldPatch<usize>,
// defer parsing compaction_algorithm, like eviction_policy
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub compaction_algorithm: FieldPatch<CompactionAlgorithmSettings>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub gc_horizon: FieldPatch<u64>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub gc_period: FieldPatch<String>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub image_creation_threshold: FieldPatch<usize>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub pitr_interval: FieldPatch<String>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub walreceiver_connect_timeout: FieldPatch<String>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub lagging_wal_timeout: FieldPatch<String>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub max_lsn_wal_lag: FieldPatch<NonZeroU64>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub eviction_policy: FieldPatch<EvictionPolicy>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub min_resident_size_override: FieldPatch<u64>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub evictions_low_residence_duration_metric_threshold: FieldPatch<String>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub heatmap_period: FieldPatch<String>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub lazy_slru_download: FieldPatch<bool>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub timeline_get_throttle: FieldPatch<ThrottleConfig>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub image_layer_creation_check_threshold: FieldPatch<u8>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub lsn_lease_length: FieldPatch<String>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub lsn_lease_length_for_ts: FieldPatch<String>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub timeline_offloading: FieldPatch<bool>,
#[serde(skip_serializing_if = "FieldPatch::is_noop")]
pub wal_receiver_protocol_override: FieldPatch<PostgresClientProtocol>,
}
/// An alternative representation of `pageserver::tenant::TenantConf` with
/// simpler types.
#[derive(Serialize, Deserialize, Debug, Default, Clone, Eq, PartialEq)]
@@ -465,107 +356,6 @@ pub struct TenantConfig {
pub wal_receiver_protocol_override: Option<PostgresClientProtocol>,
}
impl TenantConfig {
pub fn apply_patch(self, patch: TenantConfigPatch) -> TenantConfig {
let Self {
mut checkpoint_distance,
mut checkpoint_timeout,
mut compaction_target_size,
mut compaction_period,
mut compaction_threshold,
mut compaction_algorithm,
mut gc_horizon,
mut gc_period,
mut image_creation_threshold,
mut pitr_interval,
mut walreceiver_connect_timeout,
mut lagging_wal_timeout,
mut max_lsn_wal_lag,
mut eviction_policy,
mut min_resident_size_override,
mut evictions_low_residence_duration_metric_threshold,
mut heatmap_period,
mut lazy_slru_download,
mut timeline_get_throttle,
mut image_layer_creation_check_threshold,
mut lsn_lease_length,
mut lsn_lease_length_for_ts,
mut timeline_offloading,
mut wal_receiver_protocol_override,
} = self;
patch.checkpoint_distance.apply(&mut checkpoint_distance);
patch.checkpoint_timeout.apply(&mut checkpoint_timeout);
patch
.compaction_target_size
.apply(&mut compaction_target_size);
patch.compaction_period.apply(&mut compaction_period);
patch.compaction_threshold.apply(&mut compaction_threshold);
patch.compaction_algorithm.apply(&mut compaction_algorithm);
patch.gc_horizon.apply(&mut gc_horizon);
patch.gc_period.apply(&mut gc_period);
patch
.image_creation_threshold
.apply(&mut image_creation_threshold);
patch.pitr_interval.apply(&mut pitr_interval);
patch
.walreceiver_connect_timeout
.apply(&mut walreceiver_connect_timeout);
patch.lagging_wal_timeout.apply(&mut lagging_wal_timeout);
patch.max_lsn_wal_lag.apply(&mut max_lsn_wal_lag);
patch.eviction_policy.apply(&mut eviction_policy);
patch
.min_resident_size_override
.apply(&mut min_resident_size_override);
patch
.evictions_low_residence_duration_metric_threshold
.apply(&mut evictions_low_residence_duration_metric_threshold);
patch.heatmap_period.apply(&mut heatmap_period);
patch.lazy_slru_download.apply(&mut lazy_slru_download);
patch
.timeline_get_throttle
.apply(&mut timeline_get_throttle);
patch
.image_layer_creation_check_threshold
.apply(&mut image_layer_creation_check_threshold);
patch.lsn_lease_length.apply(&mut lsn_lease_length);
patch
.lsn_lease_length_for_ts
.apply(&mut lsn_lease_length_for_ts);
patch.timeline_offloading.apply(&mut timeline_offloading);
patch
.wal_receiver_protocol_override
.apply(&mut wal_receiver_protocol_override);
Self {
checkpoint_distance,
checkpoint_timeout,
compaction_target_size,
compaction_period,
compaction_threshold,
compaction_algorithm,
gc_horizon,
gc_period,
image_creation_threshold,
pitr_interval,
walreceiver_connect_timeout,
lagging_wal_timeout,
max_lsn_wal_lag,
eviction_policy,
min_resident_size_override,
evictions_low_residence_duration_metric_threshold,
heatmap_period,
lazy_slru_download,
timeline_get_throttle,
image_layer_creation_check_threshold,
lsn_lease_length,
lsn_lease_length_for_ts,
timeline_offloading,
wal_receiver_protocol_override,
}
}
}
/// The policy for the aux file storage.
///
/// It can be switched through `switch_aux_file_policy` tenant config.
@@ -896,14 +686,6 @@ impl TenantConfigRequest {
}
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct TenantConfigPatchRequest {
pub tenant_id: TenantId,
#[serde(flatten)]
pub config: TenantConfigPatch, // as we have a flattened field, we should reject all unknown fields in it
}
/// See [`TenantState::attachment_status`] and the OpenAPI docs for context.
#[derive(Serialize, Deserialize, Clone)]
#[serde(tag = "slug", content = "data", rename_all = "snake_case")]
@@ -1917,45 +1699,4 @@ mod tests {
);
}
}
#[test]
fn test_tenant_config_patch_request_serde() {
let patch_request = TenantConfigPatchRequest {
tenant_id: TenantId::from_str("17c6d121946a61e5ab0fe5a2fd4d8215").unwrap(),
config: TenantConfigPatch {
checkpoint_distance: FieldPatch::Upsert(42),
gc_horizon: FieldPatch::Remove,
compaction_threshold: FieldPatch::Noop,
..TenantConfigPatch::default()
},
};
let json = serde_json::to_string(&patch_request).unwrap();
let expected = r#"{"tenant_id":"17c6d121946a61e5ab0fe5a2fd4d8215","checkpoint_distance":42,"gc_horizon":null}"#;
assert_eq!(json, expected);
let decoded: TenantConfigPatchRequest = serde_json::from_str(&json).unwrap();
assert_eq!(decoded.tenant_id, patch_request.tenant_id);
assert_eq!(decoded.config, patch_request.config);
// Now apply the patch to a config to demonstrate semantics
let base = TenantConfig {
checkpoint_distance: Some(28),
gc_horizon: Some(100),
compaction_target_size: Some(1024),
..Default::default()
};
let expected = TenantConfig {
checkpoint_distance: Some(42),
gc_horizon: None,
..base.clone()
};
let patched = base.apply_patch(decoded.config);
assert_eq!(patched, expected);
}
}
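
Editor's note: the FieldPatch serde impls above encode a tri-state in JSON: an absent field stays Noop (via the container-level #[serde(default)]), an explicit null becomes Remove (Option::deserialize maps null to None), and a value becomes Upsert. A minimal self-contained sketch of the same trick, assuming serde and serde_json as dependencies:

```rust
use serde::{Deserialize, Deserializer};

#[derive(Debug, Default, PartialEq)]
enum Patch<T> {
    Upsert(T),
    Remove,
    #[default]
    Noop,
}

impl<'de, T: Deserialize<'de>> Deserialize<'de> for Patch<T> {
    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
        // A field that is present deserializes here: null -> Remove, value -> Upsert.
        Ok(match Option::<T>::deserialize(d)? {
            Some(v) => Patch::Upsert(v),
            None => Patch::Remove,
        })
    }
}

#[derive(Debug, Default, Deserialize)]
#[serde(default)] // absent fields fall back to Default, i.e. Noop
struct Req {
    gc_horizon: Patch<u64>,
    pitr_interval: Patch<String>,
}

fn main() {
    let r: Req = serde_json::from_str(r#"{"gc_horizon": null}"#).unwrap();
    assert_eq!(r.gc_horizon, Patch::Remove); // explicit null
    assert_eq!(r.pitr_interval, Patch::Noop); // field absent
}
```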

View File

@@ -4,23 +4,18 @@ use crate::config::Host;
use crate::config::SslMode;
use crate::connection::{Request, RequestMessages};
use crate::query::RowStream;
use crate::simple_query::SimpleQueryStream;
use crate::types::{Oid, ToSql, Type};
use crate::types::{Oid, Type};
use crate::{
prepare, query, simple_query, slice_iter, CancelToken, Error, ReadyForQueryStatus, Row,
SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder,
simple_query, CancelToken, Error, ReadyForQueryStatus, Statement, Transaction,
TransactionBuilder,
};
use bytes::BytesMut;
use fallible_iterator::FallibleIterator;
use futures_util::{future, ready, TryStreamExt};
use parking_lot::Mutex;
use futures_util::{future, ready};
use postgres_protocol2::message::{backend::Message, frontend};
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::sync::mpsc;
@@ -55,7 +50,7 @@ impl Responses {
/// A cache of type info and prepared statements for fetching type info
/// (corresponding to the queries in the [prepare] module).
#[derive(Default)]
struct CachedTypeInfo {
pub(crate) struct CachedTypeInfo {
/// A statement for basic information for a type from its
/// OID. Corresponds to [TYPEINFO_QUERY](prepare::TYPEINFO_QUERY) (or its
/// fallback).
@@ -71,13 +66,45 @@ struct CachedTypeInfo {
/// Cache of types already looked up.
types: HashMap<Oid, Type>,
}
impl CachedTypeInfo {
pub(crate) fn typeinfo(&mut self) -> Option<&Statement> {
self.typeinfo.as_ref()
}
pub(crate) fn set_typeinfo(&mut self, statement: Statement) -> &Statement {
self.typeinfo.insert(statement)
}
pub(crate) fn typeinfo_composite(&mut self) -> Option<&Statement> {
self.typeinfo_composite.as_ref()
}
pub(crate) fn set_typeinfo_composite(&mut self, statement: Statement) -> &Statement {
self.typeinfo_composite.insert(statement)
}
pub(crate) fn typeinfo_enum(&mut self) -> Option<&Statement> {
self.typeinfo_enum.as_ref()
}
pub(crate) fn set_typeinfo_enum(&mut self, statement: Statement) -> &Statement {
self.typeinfo_enum.insert(statement)
}
pub(crate) fn type_(&mut self, oid: Oid) -> Option<Type> {
self.types.get(&oid).cloned()
}
pub(crate) fn set_type(&mut self, oid: Oid, type_: &Type) {
self.types.insert(oid, type_.clone());
}
}
pub struct InnerClient {
sender: mpsc::UnboundedSender<Request>,
cached_typeinfo: Mutex<CachedTypeInfo>,
/// A buffer to use when writing out postgres commands.
buffer: Mutex<BytesMut>,
buffer: BytesMut,
}
impl InnerClient {
@@ -92,47 +119,14 @@ impl InnerClient {
})
}
pub fn typeinfo(&self) -> Option<Statement> {
self.cached_typeinfo.lock().typeinfo.clone()
}
pub fn set_typeinfo(&self, statement: &Statement) {
self.cached_typeinfo.lock().typeinfo = Some(statement.clone());
}
pub fn typeinfo_composite(&self) -> Option<Statement> {
self.cached_typeinfo.lock().typeinfo_composite.clone()
}
pub fn set_typeinfo_composite(&self, statement: &Statement) {
self.cached_typeinfo.lock().typeinfo_composite = Some(statement.clone());
}
pub fn typeinfo_enum(&self) -> Option<Statement> {
self.cached_typeinfo.lock().typeinfo_enum.clone()
}
pub fn set_typeinfo_enum(&self, statement: &Statement) {
self.cached_typeinfo.lock().typeinfo_enum = Some(statement.clone());
}
pub fn type_(&self, oid: Oid) -> Option<Type> {
self.cached_typeinfo.lock().types.get(&oid).cloned()
}
pub fn set_type(&self, oid: Oid, type_: &Type) {
self.cached_typeinfo.lock().types.insert(oid, type_.clone());
}
/// Call the given function with a buffer to be used when writing out
/// postgres commands.
pub fn with_buf<F, R>(&self, f: F) -> R
pub fn with_buf<F, R>(&mut self, f: F) -> R
where
F: FnOnce(&mut BytesMut) -> R,
{
let mut buffer = self.buffer.lock();
let r = f(&mut buffer);
buffer.clear();
let r = f(&mut self.buffer);
self.buffer.clear();
r
}
}
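
Editor's note: the hunk above captures the arc-avoidance theme of commits c835bbba1f and 4991a85704: the buffer moves out of a Mutex and with_buf takes &mut self. A minimal before/after sketch of that refactor pattern (simplified types, not the actual client internals):

```rust
use std::sync::{Arc, Mutex};

// Before: shared handle, every buffer access pays for an Arc and a lock.
#[allow(dead_code)]
struct SharedClient {
    buffer: Arc<Mutex<Vec<u8>>>,
}

// After: exclusive access through &mut self, no Arc, no lock.
struct ExclusiveClient {
    buffer: Vec<u8>,
}

impl ExclusiveClient {
    fn with_buf<R>(&mut self, f: impl FnOnce(&mut Vec<u8>) -> R) -> R {
        let r = f(&mut self.buffer);
        self.buffer.clear(); // reuse the allocation across calls
        r
    }
}

fn main() {
    let mut c = ExclusiveClient { buffer: Vec::new() };
    let len = c.with_buf(|b| {
        b.extend_from_slice(b"QUERY");
        b.len()
    });
    println!("wrote {len} bytes");
}
```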
@@ -150,7 +144,8 @@ pub struct SocketConfig {
/// The client is one half of what is returned when a connection is established. Users interact with the database
/// through this client object.
pub struct Client {
inner: Arc<InnerClient>,
pub(crate) inner: InnerClient,
pub(crate) cached_typeinfo: CachedTypeInfo,
socket_config: SocketConfig,
ssl_mode: SslMode,
@@ -167,11 +162,11 @@ impl Client {
secret_key: i32,
) -> Client {
Client {
inner: Arc::new(InnerClient {
inner: InnerClient {
sender,
cached_typeinfo: Default::default(),
buffer: Default::default(),
}),
},
cached_typeinfo: Default::default(),
socket_config,
ssl_mode,
@@ -185,161 +180,6 @@ impl Client {
self.process_id
}
pub(crate) fn inner(&self) -> &Arc<InnerClient> {
&self.inner
}
/// Creates a new prepared statement.
///
/// Prepared statements can be executed repeatedly, and may contain query parameters (indicated by `$1`, `$2`, etc),
/// which are set when executed. Prepared statements can only be used with the connection that created them.
pub async fn prepare(&self, query: &str) -> Result<Statement, Error> {
self.prepare_typed(query, &[]).await
}
/// Like `prepare`, but allows the types of query parameters to be explicitly specified.
///
/// The list of types may be smaller than the number of parameters - the types of the remaining parameters will be
/// inferred. For example, `client.prepare_typed(query, &[])` is equivalent to `client.prepare(query)`.
pub async fn prepare_typed(
&self,
query: &str,
parameter_types: &[Type],
) -> Result<Statement, Error> {
prepare::prepare(&self.inner, query, parameter_types).await
}
/// Executes a statement, returning a vector of the resulting rows.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
pub async fn query<T>(
&self,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<Vec<Row>, Error>
where
T: ?Sized + ToStatement,
{
self.query_raw(statement, slice_iter(params))
.await?
.try_collect()
.await
}
/// The maximally flexible version of [`query`].
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
///
/// [`query`]: #method.query
pub async fn query_raw<'a, T, I>(&self, statement: &T, params: I) -> Result<RowStream, Error>
where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
{
let statement = statement.__convert().into_statement(self).await?;
query::query(&self.inner, statement, params).await
}
/// Pass text directly to the Postgres backend to allow it to sort out typing itself and
/// to save a roundtrip.
pub async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
query::query_txt(&self.inner, statement, params).await
}
/// Executes a statement, returning the number of rows modified.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
pub async fn execute<T>(
&self,
statement: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<u64, Error>
where
T: ?Sized + ToStatement,
{
self.execute_raw(statement, slice_iter(params)).await
}
/// The maximally flexible version of [`execute`].
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `statement` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// # Panics
///
/// Panics if the number of parameters provided does not match the number expected.
///
/// [`execute`]: #method.execute
pub async fn execute_raw<'a, T, I>(&self, statement: &T, params: I) -> Result<u64, Error>
where
T: ?Sized + ToStatement,
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
{
let statement = statement.__convert().into_statement(self).await?;
query::execute(self.inner(), statement, params).await
}
/// Executes a sequence of SQL statements using the simple query protocol, returning the resulting rows.
///
/// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
/// point. The simple query protocol returns the values in rows as strings rather than in their binary encodings,
/// so the associated row type doesn't work with the `FromSql` trait. Rather than simply returning a list of the
/// rows, this method returns a list of enum values, each indicating either the completion of one of the commands
/// or a row of data. This preserves the framing between the separate statements in the request.
///
/// # Warning
///
/// Prepared statements should be used for any query which contains user-specified data, as they provide the
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
/// them to this method!
pub async fn simple_query(&self, query: &str) -> Result<Vec<SimpleQueryMessage>, Error> {
self.simple_query_raw(query).await?.try_collect().await
}
pub(crate) async fn simple_query_raw(&self, query: &str) -> Result<SimpleQueryStream, Error> {
simple_query::simple_query(self.inner(), query).await
}
/// Executes a sequence of SQL statements using the simple query protocol.
///
/// Statements should be separated by semicolons. If an error occurs, execution of the sequence will stop at that
@@ -350,8 +190,8 @@ impl Client {
/// Prepared statements should be used for any query which contains user-specified data, as they provide the
/// functionality to safely embed that data in the request. Do not form statements via string concatenation and pass
/// them to this method!
pub async fn batch_execute(&self, query: &str) -> Result<ReadyForQueryStatus, Error> {
simple_query::batch_execute(self.inner(), query).await
pub async fn batch_execute(&mut self, query: &str) -> Result<ReadyForQueryStatus, Error> {
simple_query::batch_execute(&mut self.inner, query).await
}
/// Begins a new database transaction.
@@ -359,7 +199,7 @@ impl Client {
/// The transaction will roll back by default - use the `commit` method to commit it.
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
struct RollbackIfNotDone<'me> {
client: &'me Client,
client: &'me mut Client,
done: bool,
}
@@ -369,13 +209,13 @@ impl Client {
return;
}
let buf = self.client.inner().with_buf(|buf| {
let buf = self.client.inner.with_buf(|buf| {
frontend::query("ROLLBACK", buf).unwrap();
buf.split().freeze()
});
let _ = self
.client
.inner()
.inner
.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
}
}
@@ -390,7 +230,7 @@ impl Client {
client: self,
done: false,
};
self.batch_execute("BEGIN").await?;
cleaner.client.batch_execute("BEGIN").await?;
cleaner.done = true;
}
@@ -416,11 +256,6 @@ impl Client {
}
}
/// Query for type information
pub async fn get_type(&self, oid: Oid) -> Result<Type, Error> {
crate::prepare::get_type(&self.inner, oid).await
}
/// Determines if the connection to the server has already closed.
///
/// In that case, all future queries will fail.
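Taking the client by `&mut self` is the core of this refactor: once access to `InnerClient` is exclusive, the scratch buffer no longer needs `Arc` plus `Mutex`, because exclusivity is proven by the borrow checker instead of a lock. A minimal standalone sketch of the resulting `with_buf` pattern (names here are illustrative, not the crate's):

use bytes::{Bytes, BytesMut};

struct Writer {
    buffer: BytesMut, // reusable scratch space; no Mutex needed under &mut
}

impl Writer {
    // Lend the buffer out, then clear it so the allocation is reused.
    fn with_buf<F, R>(&mut self, f: F) -> R
    where
        F: FnOnce(&mut BytesMut) -> R,
    {
        let r = f(&mut self.buffer);
        self.buffer.clear();
        r
    }
}

fn main() {
    let mut w = Writer { buffer: BytesMut::new() };
    let msg: Bytes = w.with_buf(|buf| {
        buf.extend_from_slice(b"ROLLBACK");
        buf.split().freeze() // detach the written bytes before the clear
    });
    assert_eq!(&msg[..], b"ROLLBACK");
}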

View File

@@ -1,4 +1,4 @@
use crate::query::RowStream;
use crate::query::{self, RowStream};
use crate::types::Type;
use crate::{Client, Error, Transaction};
use async_trait::async_trait;
@@ -13,33 +13,32 @@ mod private {
/// This trait is "sealed", and cannot be implemented outside of this crate.
#[async_trait]
pub trait GenericClient: private::Sealed {
/// Like `Client::query_raw_txt`.
async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send;
/// Query for type information
async fn get_type(&self, oid: Oid) -> Result<Type, Error>;
async fn get_type(&mut self, oid: Oid) -> Result<Type, Error>;
}
impl private::Sealed for Client {}
#[async_trait]
impl GenericClient for Client {
async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send,
{
self.query_raw_txt(statement, params).await
query::query_txt(&mut self.inner, statement, params).await
}
/// Query for type information
async fn get_type(&self, oid: Oid) -> Result<Type, Error> {
self.get_type(oid).await
async fn get_type(&mut self, oid: Oid) -> Result<Type, Error> {
crate::prepare::get_type(&mut self.inner, &mut self.cached_typeinfo, oid).await
}
}
@@ -48,17 +47,18 @@ impl private::Sealed for Transaction<'_> {}
#[async_trait]
#[allow(clippy::needless_lifetimes)]
impl GenericClient for Transaction<'_> {
async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
I::IntoIter: ExactSizeIterator + Sync + Send,
{
self.query_raw_txt(statement, params).await
query::query_txt(&mut self.client().inner, statement, params).await
}
/// Query for type information
async fn get_type(&self, oid: Oid) -> Result<Type, Error> {
self.client().get_type(oid).await
async fn get_type(&mut self, oid: Oid) -> Result<Type, Error> {
let client = self.client();
crate::prepare::get_type(&mut client.inner, &mut client.cached_typeinfo, oid).await
}
}
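With `GenericClient` now taking `&mut self` everywhere, implementations can reach mutable per-connection state (the write buffer and the type cache) without interior mutability. A toy `async_trait` trait with the same shape, assuming only the `async-trait` and `tokio` crates:

use async_trait::async_trait;

#[async_trait]
trait Exec {
    async fn exec(&mut self, sql: &str) -> u64;
}

struct Conn {
    calls: u64, // mutable state, reachable because the trait takes &mut self
}

#[async_trait]
impl Exec for Conn {
    async fn exec(&mut self, _sql: &str) -> u64 {
        self.calls += 1;
        self.calls
    }
}

#[tokio::main]
async fn main() {
    let mut conn = Conn { calls: 0 };
    assert_eq!(conn.exec("SELECT 1").await, 1);
    assert_eq!(conn.exec("SELECT 1").await, 2);
}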

View File

@@ -10,11 +10,10 @@ use crate::error::DbError;
pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
pub use crate::query::RowStream;
pub use crate::row::{Row, SimpleQueryRow};
pub use crate::simple_query::SimpleQueryStream;
pub use crate::row::Row;
pub use crate::statement::{Column, Statement};
pub use crate::tls::NoTls;
pub use crate::to_statement::ToStatement;
// pub use crate::to_statement::ToStatement;
pub use crate::transaction::Transaction;
pub use crate::transaction_builder::{IsolationLevel, TransactionBuilder};
use crate::types::ToSql;
@@ -65,7 +64,7 @@ pub mod row;
mod simple_query;
mod statement;
pub mod tls;
mod to_statement;
// mod to_statement;
mod transaction;
mod transaction_builder;
pub mod types;
@@ -98,7 +97,6 @@ impl Notification {
/// An asynchronous message from the server.
#[allow(clippy::large_enum_variant)]
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum AsyncMessage {
/// A notice.
///
@@ -110,18 +108,6 @@ pub enum AsyncMessage {
Notification(Notification),
}
/// Message returned by the `SimpleQuery` stream.
#[derive(Debug)]
#[non_exhaustive]
pub enum SimpleQueryMessage {
/// A row of data.
Row(SimpleQueryRow),
/// A statement in the query has completed.
///
/// The number of rows modified or selected is returned.
CommandComplete(u64),
}
fn slice_iter<'a>(
s: &'a [&'a (dyn ToSql + Sync)],
) -> impl ExactSizeIterator<Item = &'a (dyn ToSql + Sync)> + 'a {

View File

@@ -1,4 +1,4 @@
use crate::client::InnerClient;
use crate::client::{CachedTypeInfo, InnerClient};
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::error::SqlState;
@@ -7,14 +7,13 @@ use crate::{query, slice_iter};
use crate::{Column, Error, Statement};
use bytes::Bytes;
use fallible_iterator::FallibleIterator;
use futures_util::{pin_mut, TryStreamExt};
use futures_util::{pin_mut, StreamExt, TryStreamExt};
use log::debug;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use std::future::Future;
use std::pin::Pin;
use std::pin::{pin, Pin};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
pub(crate) const TYPEINFO_QUERY: &str = "\
SELECT t.typname, t.typtype, t.typelem, r.rngsubtype, t.typbasetype, n.nspname, t.typrelid
@@ -59,7 +58,8 @@ ORDER BY attnum
static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
pub async fn prepare(
client: &Arc<InnerClient>,
client: &mut InnerClient,
cache: &mut CachedTypeInfo,
query: &str,
types: &[Type],
) -> Result<Statement, Error> {
@@ -86,7 +86,7 @@ pub async fn prepare(
let mut parameters = vec![];
let mut it = parameter_description.parameters();
while let Some(oid) = it.next().map_err(Error::parse)? {
let type_ = get_type(client, oid).await?;
let type_ = get_type(client, cache, oid).await?;
parameters.push(type_);
}
@@ -94,24 +94,30 @@ pub async fn prepare(
if let Some(row_description) = row_description {
let mut it = row_description.fields();
while let Some(field) = it.next().map_err(Error::parse)? {
let type_ = get_type(client, field.type_oid()).await?;
let type_ = get_type(client, cache, field.type_oid()).await?;
let column = Column::new(field.name().to_string(), type_, field);
columns.push(column);
}
}
Ok(Statement::new(client, name, parameters, columns))
Ok(Statement::new(name, parameters, columns))
}
fn prepare_rec<'a>(
client: &'a Arc<InnerClient>,
client: &'a mut InnerClient,
cache: &'a mut CachedTypeInfo,
query: &'a str,
types: &'a [Type],
) -> Pin<Box<dyn Future<Output = Result<Statement, Error>> + 'a + Send>> {
Box::pin(prepare(client, query, types))
Box::pin(prepare(client, cache, query, types))
}
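`prepare_rec` and `get_type_rec` exist because `prepare` and `get_type` are mutually recursive: a recursive `async fn` would have an infinitely sized future, so the recursive call is routed through a boxed, type-erased future. The same workaround on a self-recursive function:

use std::future::Future;
use std::pin::Pin;

// A plain `async fn fib` could not call itself (its future would have to
// contain itself), so the recursion goes through Pin<Box<dyn Future>>.
fn fib(n: u64) -> Pin<Box<dyn Future<Output = u64> + Send>> {
    Box::pin(async move {
        if n < 2 {
            n
        } else {
            fib(n - 1).await + fib(n - 2).await
        }
    })
}

#[tokio::main]
async fn main() {
    assert_eq!(fib(10).await, 55);
}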
fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Result<Bytes, Error> {
fn encode(
client: &mut InnerClient,
name: &str,
query: &str,
types: &[Type],
) -> Result<Bytes, Error> {
if types.is_empty() {
debug!("preparing query {}: {}", name, query);
} else {
@@ -126,16 +132,20 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu
})
}
pub async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error> {
pub async fn get_type(
client: &mut InnerClient,
cache: &mut CachedTypeInfo,
oid: Oid,
) -> Result<Type, Error> {
if let Some(type_) = Type::from_oid(oid) {
return Ok(type_);
}
if let Some(type_) = client.type_(oid) {
if let Some(type_) = cache.type_(oid) {
return Ok(type_);
}
let stmt = typeinfo_statement(client).await?;
let stmt = typeinfo_statement(client, cache).await?;
let rows = query::query(client, stmt, slice_iter(&[&oid])).await?;
pin_mut!(rows);
@@ -145,118 +155,141 @@ pub async fn get_type(client: &Arc<InnerClient>, oid: Oid) -> Result<Type, Error
None => return Err(Error::unexpected_message()),
};
let name: String = row.try_get(0)?;
let type_: i8 = row.try_get(1)?;
let elem_oid: Oid = row.try_get(2)?;
let rngsubtype: Option<Oid> = row.try_get(3)?;
let basetype: Oid = row.try_get(4)?;
let schema: String = row.try_get(5)?;
let relid: Oid = row.try_get(6)?;
let name: String = row.try_get(stmt.columns(), 0)?;
let type_: i8 = row.try_get(stmt.columns(), 1)?;
let elem_oid: Oid = row.try_get(stmt.columns(), 2)?;
let rngsubtype: Option<Oid> = row.try_get(stmt.columns(), 3)?;
let basetype: Oid = row.try_get(stmt.columns(), 4)?;
let schema: String = row.try_get(stmt.columns(), 5)?;
let relid: Oid = row.try_get(stmt.columns(), 6)?;
let kind = if type_ == b'e' as i8 {
let variants = get_enum_variants(client, oid).await?;
let variants = get_enum_variants(client, cache, oid).await?;
Kind::Enum(variants)
} else if type_ == b'p' as i8 {
Kind::Pseudo
} else if basetype != 0 {
let type_ = get_type_rec(client, basetype).await?;
let type_ = get_type_rec(client, cache, basetype).await?;
Kind::Domain(type_)
} else if elem_oid != 0 {
let type_ = get_type_rec(client, elem_oid).await?;
let type_ = get_type_rec(client, cache, elem_oid).await?;
Kind::Array(type_)
} else if relid != 0 {
let fields = get_composite_fields(client, relid).await?;
let fields = get_composite_fields(client, cache, relid).await?;
Kind::Composite(fields)
} else if let Some(rngsubtype) = rngsubtype {
let type_ = get_type_rec(client, rngsubtype).await?;
let type_ = get_type_rec(client, cache, rngsubtype).await?;
Kind::Range(type_)
} else {
Kind::Simple
};
let type_ = Type::new(name, oid, kind, schema);
client.set_type(oid, &type_);
cache.set_type(oid, &type_);
Ok(type_)
}
fn get_type_rec<'a>(
client: &'a Arc<InnerClient>,
client: &'a mut InnerClient,
cache: &'a mut CachedTypeInfo,
oid: Oid,
) -> Pin<Box<dyn Future<Output = Result<Type, Error>> + Send + 'a>> {
Box::pin(get_type(client, oid))
Box::pin(get_type(client, cache, oid))
}
async fn typeinfo_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
if let Some(stmt) = client.typeinfo() {
return Ok(stmt);
async fn typeinfo_statement<'c>(
client: &mut InnerClient,
cache: &'c mut CachedTypeInfo,
) -> Result<&'c Statement, Error> {
if cache.typeinfo().is_some() {
// needed to get around a borrow checker limitation
return Ok(cache.typeinfo().unwrap());
}
let stmt = match prepare_rec(client, TYPEINFO_QUERY, &[]).await {
let stmt = match prepare_rec(client, cache, TYPEINFO_QUERY, &[]).await {
Ok(stmt) => stmt,
Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_TABLE) => {
prepare_rec(client, TYPEINFO_FALLBACK_QUERY, &[]).await?
prepare_rec(client, cache, TYPEINFO_FALLBACK_QUERY, &[]).await?
}
Err(e) => return Err(e),
};
client.set_typeinfo(&stmt);
Ok(stmt)
Ok(cache.set_typeinfo(stmt))
}
async fn get_enum_variants(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<String>, Error> {
let stmt = typeinfo_enum_statement(client).await?;
async fn get_enum_variants(
client: &mut InnerClient,
cache: &mut CachedTypeInfo,
oid: Oid,
) -> Result<Vec<String>, Error> {
let stmt = typeinfo_enum_statement(client, cache).await?;
query::query(client, stmt, slice_iter(&[&oid]))
.await?
.and_then(|row| async move { row.try_get(0) })
.try_collect()
.await
let mut out = vec![];
let mut rows = pin!(query::query(client, stmt, slice_iter(&[&oid])).await?);
while let Some(row) = rows.next().await {
out.push(row?.try_get(stmt.columns(), 0)?)
}
Ok(out)
}
async fn typeinfo_enum_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
if let Some(stmt) = client.typeinfo_enum() {
return Ok(stmt);
async fn typeinfo_enum_statement<'c>(
client: &mut InnerClient,
cache: &'c mut CachedTypeInfo,
) -> Result<&'c Statement, Error> {
if cache.typeinfo_enum().is_some() {
// needed to get around a borrow checker limitation
return Ok(cache.typeinfo_enum().unwrap());
}
let stmt = match prepare_rec(client, TYPEINFO_ENUM_QUERY, &[]).await {
let stmt = match prepare_rec(client, cache, TYPEINFO_ENUM_QUERY, &[]).await {
Ok(stmt) => stmt,
Err(ref e) if e.code() == Some(&SqlState::UNDEFINED_COLUMN) => {
prepare_rec(client, TYPEINFO_ENUM_FALLBACK_QUERY, &[]).await?
prepare_rec(client, cache, TYPEINFO_ENUM_FALLBACK_QUERY, &[]).await?
}
Err(e) => return Err(e),
};
client.set_typeinfo_enum(&stmt);
Ok(stmt)
Ok(cache.set_typeinfo_enum(stmt))
}
async fn get_composite_fields(client: &Arc<InnerClient>, oid: Oid) -> Result<Vec<Field>, Error> {
let stmt = typeinfo_composite_statement(client).await?;
async fn get_composite_fields(
client: &mut InnerClient,
cache: &mut CachedTypeInfo,
oid: Oid,
) -> Result<Vec<Field>, Error> {
let stmt = typeinfo_composite_statement(client, cache).await?;
let rows = query::query(client, stmt, slice_iter(&[&oid]))
.await?
.try_collect::<Vec<_>>()
.await?;
let mut rows = pin!(query::query(client, stmt, slice_iter(&[&oid])).await?);
let mut oids = vec![];
while let Some(row) = rows.next().await {
let row = row?;
let name = row.try_get(stmt.columns(), 0)?;
let oid = row.try_get(stmt.columns(), 1)?;
oids.push((name, oid));
}
let mut fields = vec![];
for row in rows {
let name = row.try_get(0)?;
let oid = row.try_get(1)?;
let type_ = get_type_rec(client, oid).await?;
for (name, oid) in oids {
let type_ = get_type_rec(client, cache, oid).await?;
fields.push(Field::new(name, type_));
}
Ok(fields)
}
async fn typeinfo_composite_statement(client: &Arc<InnerClient>) -> Result<Statement, Error> {
if let Some(stmt) = client.typeinfo_composite() {
return Ok(stmt);
async fn typeinfo_composite_statement<'c>(
client: &mut InnerClient,
cache: &'c mut CachedTypeInfo,
) -> Result<&'c Statement, Error> {
if cache.typeinfo_composite().is_some() {
// needed to get around a borrow checker limitation
return Ok(cache.typeinfo_composite().unwrap());
}
let stmt = prepare_rec(client, TYPEINFO_COMPOSITE_QUERY, &[]).await?;
let stmt = prepare_rec(client, cache, TYPEINFO_COMPOSITE_QUERY, &[]).await?;
client.set_typeinfo_composite(&stmt);
Ok(stmt)
Ok(cache.set_typeinfo_composite(stmt))
}
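The "borrow checker limitation" noted in the three statement getters above is the well-known NLL case where returning a borrow forces the loan to outlive the function's output lifetime, conflicting with any later mutation; the getters dodge it with an `is_some()` check followed by a second lookup inside the `return`. A minimal reproduction of that dodge (reportedly resolved by the in-progress Polonius borrow checker):

use std::collections::HashMap;

fn get_or_insert(map: &mut HashMap<u32, String>, key: u32) -> &String {
    // Returning directly from `if let Some(v) = map.get(&key)` would keep
    // `map` borrowed across the insert below, so check first, borrow again.
    if map.contains_key(&key) {
        return map.get(&key).unwrap();
    }
    map.insert(key, format!("value-{key}"));
    map.get(&key).unwrap()
}

fn main() {
    let mut m = HashMap::new();
    assert_eq!(get_or_insert(&mut m, 7), "value-7");
    assert_eq!(get_or_insert(&mut m, 7), "value-7"); // cached path
}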

View File

@@ -14,7 +14,6 @@ use postgres_types2::{Format, ToSql, Type};
use std::fmt;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
struct BorrowToSqlParamsDebug<'a>(&'a [&'a (dyn ToSql + Sync)]);
@@ -26,10 +25,10 @@ impl fmt::Debug for BorrowToSqlParamsDebug<'_> {
}
pub async fn query<'a, I>(
client: &InnerClient,
statement: Statement,
client: &mut InnerClient,
statement: &Statement,
params: I,
) -> Result<RowStream, Error>
) -> Result<RawRowStream, Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
@@ -41,13 +40,12 @@ where
statement.name(),
BorrowToSqlParamsDebug(params.as_slice()),
);
encode(client, &statement, params)?
encode(client, statement, params)?
} else {
encode(client, &statement, params)?
encode(client, statement, params)?
};
let responses = start(client, buf).await?;
Ok(RowStream {
statement,
Ok(RawRowStream {
responses,
command_tag: None,
status: ReadyForQueryStatus::Unknown,
@@ -57,7 +55,7 @@ where
}
pub async fn query_txt<S, I>(
client: &Arc<InnerClient>,
client: &mut InnerClient,
query: &str,
params: I,
) -> Result<RowStream, Error>
@@ -157,49 +155,6 @@ where
})
}
pub async fn execute<'a, I>(
client: &InnerClient,
statement: Statement,
params: I,
) -> Result<u64, Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
{
let buf = if log_enabled!(Level::Debug) {
let params = params.into_iter().collect::<Vec<_>>();
debug!(
"executing statement {} with parameters: {:?}",
statement.name(),
BorrowToSqlParamsDebug(params.as_slice()),
);
encode(client, &statement, params)?
} else {
encode(client, &statement, params)?
};
let mut responses = start(client, buf).await?;
let mut rows = 0;
loop {
match responses.next().await? {
Message::DataRow(_) => {}
Message::CommandComplete(body) => {
rows = body
.tag()
.map_err(Error::parse)?
.rsplit(' ')
.next()
.unwrap()
.parse()
.unwrap_or(0);
}
Message::EmptyQueryResponse => rows = 0,
Message::ReadyForQuery(_) => return Ok(rows),
_ => return Err(Error::unexpected_message()),
}
}
}
async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
@@ -211,7 +166,11 @@ async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
Ok(responses)
}
pub fn encode<'a, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
pub fn encode<'a, I>(
client: &mut InnerClient,
statement: &Statement,
params: I,
) -> Result<Bytes, Error>
where
I: IntoIterator<Item = &'a (dyn ToSql + Sync)>,
I::IntoIter: ExactSizeIterator,
@@ -296,11 +255,7 @@ impl Stream for RowStream {
loop {
match ready!(this.responses.poll_next(cx)?) {
Message::DataRow(body) => {
return Poll::Ready(Some(Ok(Row::new(
this.statement.clone(),
body,
*this.output_format,
)?)))
return Poll::Ready(Some(Ok(Row::new(body, *this.output_format)?)))
}
Message::EmptyQueryResponse | Message::PortalSuspended => {}
Message::CommandComplete(body) => {
@@ -338,3 +293,41 @@ impl RowStream {
self.status
}
}
pin_project! {
/// A stream of table rows.
pub struct RawRowStream {
responses: Responses,
command_tag: Option<String>,
output_format: Format,
status: ReadyForQueryStatus,
#[pin]
_p: PhantomPinned,
}
}
impl Stream for RawRowStream {
type Item = Result<Row, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
loop {
match ready!(this.responses.poll_next(cx)?) {
Message::DataRow(body) => {
return Poll::Ready(Some(Ok(Row::new(body, *this.output_format)?)))
}
Message::EmptyQueryResponse | Message::PortalSuspended => {}
Message::CommandComplete(body) => {
if let Ok(tag) = body.tag() {
*this.command_tag = Some(tag.to_string());
}
}
Message::ReadyForQuery(status) => {
*this.status = status.into();
return Poll::Ready(None);
}
_ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
}
}
}
}
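`RowStream` and the new `RawRowStream` are the same small state machine: loop over backend messages, yield each `DataRow`, stash the command tag, and end the stream on `ReadyForQuery`. The shape of that loop over a toy message enum (stand-in types, not the crate's):

use futures_util::{ready, stream, Stream, StreamExt};
use std::pin::Pin;
use std::task::{Context, Poll};

enum Msg {
    Data(u32),
    Complete(String),
    Ready,
}

struct Rows<S> {
    inner: S,
    tag: Option<String>,
}

impl<S: Stream<Item = Msg> + Unpin> Stream for Rows<S> {
    type Item = u32;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u32>> {
        loop {
            match ready!(self.inner.poll_next_unpin(cx)) {
                Some(Msg::Data(v)) => return Poll::Ready(Some(v)),
                Some(Msg::Complete(tag)) => self.tag = Some(tag), // not a row; keep polling
                Some(Msg::Ready) | None => return Poll::Ready(None),
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let msgs = stream::iter(vec![Msg::Data(1), Msg::Data(2), Msg::Complete("SELECT 2".into()), Msg::Ready]);
    let rows: Vec<u32> = Rows { inner: msgs, tag: None }.collect().await;
    assert_eq!(rows, vec![1, 2]);
}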

View File

@@ -1,103 +1,16 @@
//! Rows.
use crate::row::sealed::{AsName, Sealed};
use crate::simple_query::SimpleColumn;
use crate::statement::Column;
use crate::types::{FromSql, Type, WrongType};
use crate::{Error, Statement};
use crate::Error;
use fallible_iterator::FallibleIterator;
use postgres_protocol2::message::backend::DataRowBody;
use postgres_types2::{Format, WrongFormat};
use std::fmt;
use std::ops::Range;
use std::str;
use std::sync::Arc;
mod sealed {
pub trait Sealed {}
pub trait AsName {
fn as_name(&self) -> &str;
}
}
impl AsName for Column {
fn as_name(&self) -> &str {
self.name()
}
}
impl AsName for String {
fn as_name(&self) -> &str {
self
}
}
/// A trait implemented by types that can index into columns of a row.
///
/// This cannot be implemented outside of this crate.
pub trait RowIndex: Sealed {
#[doc(hidden)]
fn __idx<T>(&self, columns: &[T]) -> Option<usize>
where
T: AsName;
}
impl Sealed for usize {}
impl RowIndex for usize {
#[inline]
fn __idx<T>(&self, columns: &[T]) -> Option<usize>
where
T: AsName,
{
if *self >= columns.len() {
None
} else {
Some(*self)
}
}
}
impl Sealed for str {}
impl RowIndex for str {
#[inline]
fn __idx<T>(&self, columns: &[T]) -> Option<usize>
where
T: AsName,
{
if let Some(idx) = columns.iter().position(|d| d.as_name() == self) {
return Some(idx);
};
// FIXME ASCII-only case insensitivity isn't really the right thing to
// do. Postgres itself uses a dubious wrapper around tolower and JDBC
// uses the US locale.
columns
.iter()
.position(|d| d.as_name().eq_ignore_ascii_case(self))
}
}
impl<T> Sealed for &T where T: ?Sized + Sealed {}
impl<T> RowIndex for &T
where
T: ?Sized + RowIndex,
{
#[inline]
fn __idx<U>(&self, columns: &[U]) -> Option<usize>
where
U: AsName,
{
T::__idx(*self, columns)
}
}
/// A row of data returned from the database by a query.
pub struct Row {
statement: Statement,
output_format: Format,
body: DataRowBody,
ranges: Vec<Option<Range<usize>>>,
@@ -105,80 +18,33 @@ pub struct Row {
impl fmt::Debug for Row {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Row")
.field("columns", &self.columns())
.finish()
f.debug_struct("Row").finish()
}
}
impl Row {
pub(crate) fn new(
statement: Statement,
// statement: Statement,
body: DataRowBody,
output_format: Format,
) -> Result<Row, Error> {
let ranges = body.ranges().collect().map_err(Error::parse)?;
Ok(Row {
statement,
body,
ranges,
output_format,
})
}
/// Returns information about the columns of data in the row.
pub fn columns(&self) -> &[Column] {
self.statement.columns()
}
/// Determines if the row contains no values.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Returns the number of values in the row.
pub fn len(&self) -> usize {
self.columns().len()
}
/// Deserializes a value from the row.
///
/// The value can be specified either by its numeric index in the row, or by its column name.
///
/// # Panics
///
/// Panics if the index is out of bounds or if the value cannot be converted to the specified type.
pub fn get<'a, I, T>(&'a self, idx: I) -> T
pub(crate) fn try_get<'a, T>(&'a self, columns: &[Column], idx: usize) -> Result<T, Error>
where
I: RowIndex + fmt::Display,
T: FromSql<'a>,
{
match self.get_inner(&idx) {
Ok(ok) => ok,
Err(err) => panic!("error retrieving column {}: {}", idx, err),
}
}
/// Like `Row::get`, but returns a `Result` rather than panicking.
pub fn try_get<'a, I, T>(&'a self, idx: I) -> Result<T, Error>
where
I: RowIndex + fmt::Display,
T: FromSql<'a>,
{
self.get_inner(&idx)
}
fn get_inner<'a, I, T>(&'a self, idx: &I) -> Result<T, Error>
where
I: RowIndex + fmt::Display,
T: FromSql<'a>,
{
let idx = match idx.__idx(self.columns()) {
Some(idx) => idx,
None => return Err(Error::column(idx.to_string())),
let Some(column) = columns.get(idx) else {
return Err(Error::column(idx.to_string()));
};
let ty = self.columns()[idx].type_();
let ty = column.type_();
if !T::accepts(ty) {
return Err(Error::from_sql(
Box::new(WrongType::new::<T>(ty.clone())),
@@ -216,85 +82,3 @@ impl Row {
self.body.buffer().len()
}
}
impl AsName for SimpleColumn {
fn as_name(&self) -> &str {
self.name()
}
}
/// A row of data returned from the database by a simple query.
#[derive(Debug)]
pub struct SimpleQueryRow {
columns: Arc<[SimpleColumn]>,
body: DataRowBody,
ranges: Vec<Option<Range<usize>>>,
}
impl SimpleQueryRow {
#[allow(clippy::new_ret_no_self)]
pub(crate) fn new(
columns: Arc<[SimpleColumn]>,
body: DataRowBody,
) -> Result<SimpleQueryRow, Error> {
let ranges = body.ranges().collect().map_err(Error::parse)?;
Ok(SimpleQueryRow {
columns,
body,
ranges,
})
}
/// Returns information about the columns of data in the row.
pub fn columns(&self) -> &[SimpleColumn] {
&self.columns
}
/// Determines if the row contains no values.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
/// Returns the number of values in the row.
pub fn len(&self) -> usize {
self.columns.len()
}
/// Returns a value from the row.
///
/// The value can be specified either by its numeric index in the row, or by its column name.
///
/// # Panics
///
/// Panics if the index is out of bounds or if the value cannot be converted to the specified type.
pub fn get<I>(&self, idx: I) -> Option<&str>
where
I: RowIndex + fmt::Display,
{
match self.get_inner(&idx) {
Ok(ok) => ok,
Err(err) => panic!("error retrieving column {}: {}", idx, err),
}
}
/// Like `SimpleQueryRow::get`, but returns a `Result` rather than panicking.
pub fn try_get<I>(&self, idx: I) -> Result<Option<&str>, Error>
where
I: RowIndex + fmt::Display,
{
self.get_inner(&idx)
}
fn get_inner<I>(&self, idx: &I) -> Result<Option<&str>, Error>
where
I: RowIndex + fmt::Display,
{
let idx = match idx.__idx(&self.columns) {
Some(idx) => idx,
None => return Err(Error::column(idx.to_string())),
};
let buf = self.ranges[idx].clone().map(|r| &self.body.buffer()[r]);
FromSql::from_sql_nullable(&Type::TEXT, buf).map_err(|e| Error::from_sql(e, idx))
}
}
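With `RowIndex` gone, the new `try_get` resolves only positional indices against an externally supplied column slice. For reference, the deleted by-name lookup was an exact match followed by an ASCII case-insensitive fallback (the FIXME above notes Postgres itself does something subtler). That lookup in isolation:

fn idx_by_name(columns: &[&str], name: &str) -> Option<usize> {
    // Exact match first, then an ASCII case-insensitive fallback.
    columns
        .iter()
        .position(|c| *c == name)
        .or_else(|| columns.iter().position(|c| c.eq_ignore_ascii_case(name)))
}

fn main() {
    let cols = ["id", "UserName"];
    assert_eq!(idx_by_name(&cols, "username"), Some(1));
    assert_eq!(idx_by_name(&cols, "id"), Some(0));
    assert_eq!(idx_by_name(&cols, "missing"), None);
}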

View File

@@ -1,52 +1,14 @@
use crate::client::{InnerClient, Responses};
use crate::client::InnerClient;
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::{Error, ReadyForQueryStatus, SimpleQueryMessage, SimpleQueryRow};
use crate::{Error, ReadyForQueryStatus};
use bytes::Bytes;
use fallible_iterator::FallibleIterator;
use futures_util::{ready, Stream};
use log::debug;
use pin_project_lite::pin_project;
use postgres_protocol2::message::backend::Message;
use postgres_protocol2::message::frontend;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
/// Information about a column of a single query row.
#[derive(Debug)]
pub struct SimpleColumn {
name: String,
}
impl SimpleColumn {
pub(crate) fn new(name: String) -> SimpleColumn {
SimpleColumn { name }
}
/// Returns the name of the column.
pub fn name(&self) -> &str {
&self.name
}
}
pub async fn simple_query(client: &InnerClient, query: &str) -> Result<SimpleQueryStream, Error> {
debug!("executing simple query: {}", query);
let buf = encode(client, query)?;
let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
Ok(SimpleQueryStream {
responses,
columns: None,
status: ReadyForQueryStatus::Unknown,
_p: PhantomPinned,
})
}
pub async fn batch_execute(
client: &InnerClient,
client: &mut InnerClient,
query: &str,
) -> Result<ReadyForQueryStatus, Error> {
debug!("executing statement batch: {}", query);
@@ -66,77 +28,9 @@ pub async fn batch_execute(
}
}
pub(crate) fn encode(client: &InnerClient, query: &str) -> Result<Bytes, Error> {
pub(crate) fn encode(client: &mut InnerClient, query: &str) -> Result<Bytes, Error> {
client.with_buf(|buf| {
frontend::query(query, buf).map_err(Error::encode)?;
Ok(buf.split().freeze())
})
}
pin_project! {
/// A stream of simple query results.
pub struct SimpleQueryStream {
responses: Responses,
columns: Option<Arc<[SimpleColumn]>>,
status: ReadyForQueryStatus,
#[pin]
_p: PhantomPinned,
}
}
impl SimpleQueryStream {
/// Returns if the connection is ready for querying, with the status of the connection.
///
/// This might be available only after the stream has been exhausted.
pub fn ready_status(&self) -> ReadyForQueryStatus {
self.status
}
}
impl Stream for SimpleQueryStream {
type Item = Result<SimpleQueryMessage, Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.project();
loop {
match ready!(this.responses.poll_next(cx)?) {
Message::CommandComplete(body) => {
let rows = body
.tag()
.map_err(Error::parse)?
.rsplit(' ')
.next()
.unwrap()
.parse()
.unwrap_or(0);
return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(rows))));
}
Message::EmptyQueryResponse => {
return Poll::Ready(Some(Ok(SimpleQueryMessage::CommandComplete(0))));
}
Message::RowDescription(body) => {
let columns = body
.fields()
.map(|f| Ok(SimpleColumn::new(f.name().to_string())))
.collect::<Vec<_>>()
.map_err(Error::parse)?
.into();
*this.columns = Some(columns);
}
Message::DataRow(body) => {
let row = match &this.columns {
Some(columns) => SimpleQueryRow::new(columns.clone(), body)?,
None => return Poll::Ready(Some(Err(Error::unexpected_message()))),
};
return Poll::Ready(Some(Ok(SimpleQueryMessage::Row(row))));
}
Message::ReadyForQuery(s) => {
*this.status = s.into();
return Poll::Ready(None);
}
_ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
}
}
}
}
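The `CommandComplete` arms in this module recover the affected-row count from the tag's last space-separated token (`INSERT 0 5` gives `5`, `UPDATE 12` gives `12`), falling back to 0 for tags that carry no count. The parsing on its own:

fn rows_from_tag(tag: &str) -> u64 {
    // rsplit(' ') always yields at least one token, so next() cannot be None;
    // tags without a trailing count (e.g. "BEGIN") parse-fail into 0.
    tag.rsplit(' ').next().unwrap().parse().unwrap_or(0)
}

fn main() {
    assert_eq!(rows_from_tag("INSERT 0 5"), 5);
    assert_eq!(rows_from_tag("UPDATE 12"), 12);
    assert_eq!(rows_from_tag("BEGIN"), 0);
}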

View File

@@ -1,64 +1,33 @@
use crate::client::InnerClient;
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::types::Type;
use postgres_protocol2::{
message::{backend::Field, frontend},
Oid,
};
use std::{
fmt,
sync::{Arc, Weak},
};
use postgres_protocol2::{message::backend::Field, Oid};
use std::fmt;
struct StatementInner {
client: Weak<InnerClient>,
name: String,
params: Vec<Type>,
columns: Vec<Column>,
}
impl Drop for StatementInner {
fn drop(&mut self) {
if let Some(client) = self.client.upgrade() {
let buf = client.with_buf(|buf| {
frontend::close(b'S', &self.name, buf).unwrap();
frontend::sync(buf);
buf.split().freeze()
});
let _ = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
}
}
}
/// A prepared statement.
///
/// Prepared statements can only be used with the connection that created them.
#[derive(Clone)]
pub struct Statement(Arc<StatementInner>);
pub struct Statement(StatementInner);
impl Statement {
pub(crate) fn new(
inner: &Arc<InnerClient>,
name: String,
params: Vec<Type>,
columns: Vec<Column>,
) -> Statement {
Statement(Arc::new(StatementInner {
client: Arc::downgrade(inner),
pub(crate) fn new(name: String, params: Vec<Type>, columns: Vec<Column>) -> Statement {
Statement(StatementInner {
name,
params,
columns,
}))
})
}
pub(crate) fn new_anonymous(params: Vec<Type>, columns: Vec<Column>) -> Statement {
Statement(Arc::new(StatementInner {
client: Weak::new(),
Statement(StatementInner {
name: String::new(),
params,
columns,
}))
})
}
pub(crate) fn name(&self) -> &str {
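The deleted `Drop` impl is what made `Statement` need a `Weak<InnerClient>` in the first place: on drop it would send a `Close` for the server-side prepared statement, but only if the connection was still alive. The pattern, reduced to std types:

use std::sync::mpsc::{channel, Sender};
use std::sync::{Arc, Weak};

struct Conn {
    tx: Sender<String>,
}

struct Stmt {
    conn: Weak<Conn>, // weak: a statement must not keep the connection alive
    name: String,
}

impl Drop for Stmt {
    fn drop(&mut self) {
        // If the connection is already gone, dropping is silently a no-op.
        if let Some(conn) = self.conn.upgrade() {
            let _ = conn.tx.send(format!("CLOSE {}", self.name));
        }
    }
}

fn main() {
    let (tx, rx) = channel();
    let conn = Arc::new(Conn { tx });
    drop(Stmt { conn: Arc::downgrade(&conn), name: "s0".into() });
    assert_eq!(rx.recv().unwrap(), "CLOSE s0");
}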

View File

@@ -1,57 +0,0 @@
use crate::to_statement::private::{Sealed, ToStatementType};
use crate::Statement;
mod private {
use crate::{Client, Error, Statement};
pub trait Sealed {}
pub enum ToStatementType<'a> {
Statement(&'a Statement),
Query(&'a str),
}
impl<'a> ToStatementType<'a> {
pub async fn into_statement(self, client: &Client) -> Result<Statement, Error> {
match self {
ToStatementType::Statement(s) => Ok(s.clone()),
ToStatementType::Query(s) => client.prepare(s).await,
}
}
}
}
/// A trait abstracting over prepared and unprepared statements.
///
/// Many methods are generic over this bound, so that they support both a raw query string as well as a statement which
/// was prepared previously.
///
/// This trait is "sealed" and cannot be implemented by anything outside this crate.
pub trait ToStatement: Sealed {
#[doc(hidden)]
fn __convert(&self) -> ToStatementType<'_>;
}
impl ToStatement for Statement {
fn __convert(&self) -> ToStatementType<'_> {
ToStatementType::Statement(self)
}
}
impl Sealed for Statement {}
impl ToStatement for str {
fn __convert(&self) -> ToStatementType<'_> {
ToStatementType::Query(self)
}
}
impl Sealed for str {}
impl ToStatement for String {
fn __convert(&self) -> ToStatementType<'_> {
ToStatementType::Query(self)
}
}
impl Sealed for String {}

View File

@@ -1,6 +1,5 @@
use crate::codec::FrontendMessage;
use crate::connection::RequestMessages;
use crate::query::RowStream;
use crate::{CancelToken, Client, Error, ReadyForQueryStatus};
use postgres_protocol2::message::frontend;
@@ -19,13 +18,13 @@ impl Drop for Transaction<'_> {
return;
}
let buf = self.client.inner().with_buf(|buf| {
let buf = self.client.inner.with_buf(|buf| {
frontend::query("ROLLBACK", buf).unwrap();
buf.split().freeze()
});
let _ = self
.client
.inner()
.inner
.send(RequestMessages::Single(FrontendMessage::Raw(buf)));
}
}
@@ -52,23 +51,13 @@ impl<'a> Transaction<'a> {
self.client.batch_execute("ROLLBACK").await
}
/// Like `Client::query_raw_txt`.
pub async fn query_raw_txt<S, I>(&self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
I::IntoIter: ExactSizeIterator,
{
self.client.query_raw_txt(statement, params).await
}
/// Like `Client::cancel_token`.
pub fn cancel_token(&self) -> CancelToken {
self.client.cancel_token()
}
/// Returns a reference to the underlying `Client`.
pub fn client(&self) -> &Client {
pub fn client(&mut self) -> &mut Client {
self.client
}
}
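`Transaction::drop` and the `RollbackIfNotDone` guard inside `Client::transaction` share one RAII idea: queue a `ROLLBACK` unless the success path set a `done` flag first, so early returns and panics roll back automatically. Stripped of the protocol plumbing:

struct TxGuard<'a> {
    log: &'a mut Vec<&'static str>,
    done: bool,
}

impl Drop for TxGuard<'_> {
    fn drop(&mut self) {
        if !self.done {
            self.log.push("ROLLBACK"); // fires only if commit never happened
        }
    }
}

fn run(commit: bool, log: &mut Vec<&'static str>) {
    let mut tx = TxGuard { log, done: false };
    tx.log.push("BEGIN");
    if commit {
        tx.log.push("COMMIT");
        tx.done = true;
    }
} // tx drops here; an early return or panic would also trigger the rollback

fn main() {
    let mut log = Vec::new();
    run(false, &mut log);
    assert_eq!(log, ["BEGIN", "ROLLBACK"]);
    log.clear();
    run(true, &mut log);
    assert_eq!(log, ["BEGIN", "COMMIT"]);
}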

View File

@@ -8,14 +8,15 @@ use std::io;
use std::num::NonZeroU32;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::time::SystemTime;
use super::REMOTE_STORAGE_PREFIX_SEPARATOR;
use anyhow::Context;
use anyhow::Result;
use azure_core::request_options::{IfMatchCondition, MaxResults, Metadata, Range};
use azure_core::{Continuable, RetryOptions};
use azure_identity::DefaultAzureCredential;
use azure_storage::StorageCredentials;
use azure_storage_blobs::blob::CopyStatus;
use azure_storage_blobs::prelude::ClientBuilder;
@@ -75,9 +76,8 @@ impl AzureBlobStorage {
let credentials = if let Ok(access_key) = env::var("AZURE_STORAGE_ACCESS_KEY") {
StorageCredentials::access_key(account.clone(), access_key)
} else {
let token_credential = azure_identity::create_default_credential()
.context("trying to obtain Azure default credentials")?;
StorageCredentials::token_credential(token_credential)
let token_credential = DefaultAzureCredential::default();
StorageCredentials::token_credential(Arc::new(token_credential))
};
// we have an outer retry
@@ -624,10 +624,6 @@ impl RemoteStorage for AzureBlobStorage {
res
}
fn max_keys_per_delete(&self) -> usize {
super::MAX_KEYS_PER_DELETE_AZURE
}
async fn copy(
&self,
from: &RemotePath,

View File

@@ -70,14 +70,7 @@ pub const DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT: usize = 100;
pub const DEFAULT_MAX_KEYS_PER_LIST_RESPONSE: Option<i32> = None;
/// As defined in S3 docs
///
/// <https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObjects.html>
pub const MAX_KEYS_PER_DELETE_S3: usize = 1000;
/// As defined in Azure docs
///
/// <https://learn.microsoft.com/en-us/rest/api/storageservices/blob-batch>
pub const MAX_KEYS_PER_DELETE_AZURE: usize = 256;
pub const MAX_KEYS_PER_DELETE: usize = 1000;
const REMOTE_STORAGE_PREFIX_SEPARATOR: char = '/';
@@ -347,14 +340,6 @@ pub trait RemoteStorage: Send + Sync + 'static {
cancel: &CancellationToken,
) -> anyhow::Result<()>;
/// Returns the maximum number of keys that a call to [`Self::delete_objects`] can delete without chunking
///
/// The value returned is only an optimization hint; one can pass a larger number of objects to
/// `delete_objects` as well.
///
/// The value is guaranteed to be >= 1.
fn max_keys_per_delete(&self) -> usize;
/// Deletes all objects matching the given prefix.
///
/// NB: this uses NoDelimiter and will match partial prefixes. For example, the prefix /a/b will
@@ -548,16 +533,6 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
}
}
/// [`RemoteStorage::max_keys_per_delete`]
pub fn max_keys_per_delete(&self) -> usize {
match self {
Self::LocalFs(s) => s.max_keys_per_delete(),
Self::AwsS3(s) => s.max_keys_per_delete(),
Self::AzureBlob(s) => s.max_keys_per_delete(),
Self::Unreliable(s) => s.max_keys_per_delete(),
}
}
/// See [`RemoteStorage::delete_prefix`]
pub async fn delete_prefix(
&self,

View File

@@ -573,10 +573,6 @@ impl RemoteStorage for LocalFs {
Ok(())
}
fn max_keys_per_delete(&self) -> usize {
super::MAX_KEYS_PER_DELETE_S3
}
async fn copy(
&self,
from: &RemotePath,

View File

@@ -48,7 +48,7 @@ use crate::{
metrics::{start_counting_cancelled_wait, start_measuring_requests},
support::PermitCarrying,
ConcurrencyLimiter, Download, DownloadError, DownloadOpts, Listing, ListingMode, ListingObject,
RemotePath, RemoteStorage, TimeTravelError, TimeoutOrCancel, MAX_KEYS_PER_DELETE_S3,
RemotePath, RemoteStorage, TimeTravelError, TimeoutOrCancel, MAX_KEYS_PER_DELETE,
REMOTE_STORAGE_PREFIX_SEPARATOR,
};
@@ -355,7 +355,7 @@ impl S3Bucket {
let kind = RequestKind::Delete;
let mut cancel = std::pin::pin!(cancel.cancelled());
for chunk in delete_objects.chunks(MAX_KEYS_PER_DELETE_S3) {
for chunk in delete_objects.chunks(MAX_KEYS_PER_DELETE) {
let started_at = start_measuring_requests(kind);
let req = self
@@ -832,10 +832,6 @@ impl RemoteStorage for S3Bucket {
self.delete_oids(&permit, &delete_objects, cancel).await
}
fn max_keys_per_delete(&self) -> usize {
MAX_KEYS_PER_DELETE_S3
}
async fn delete(&self, path: &RemotePath, cancel: &CancellationToken) -> anyhow::Result<()> {
let paths = std::array::from_ref(path);
self.delete_objects(paths, cancel).await

View File

@@ -203,10 +203,6 @@ impl RemoteStorage for UnreliableWrapper {
Ok(())
}
fn max_keys_per_delete(&self) -> usize {
self.inner.max_keys_per_delete()
}
async fn copy(
&self,
from: &RemotePath,

View File

@@ -94,8 +94,6 @@ pub mod toml_edit_ext;
pub mod circuit_breaker;
pub mod try_rcu;
// Re-export used in macro. Avoids adding git-version as dep in target crates.
#[doc(hidden)]
pub use git_version;

View File

@@ -1,77 +0,0 @@
//! Try RCU extension lifted from <https://github.com/vorner/arc-swap/issues/94#issuecomment-1987154023>
pub trait ArcSwapExt<T> {
/// [`ArcSwap::rcu`](arc_swap::ArcSwap::rcu), but with Result that short-circuits on error.
fn try_rcu<R, F, E>(&self, f: F) -> Result<T, E>
where
F: FnMut(&T) -> Result<R, E>,
R: Into<T>;
}
impl<T, S> ArcSwapExt<T> for arc_swap::ArcSwapAny<T, S>
where
T: arc_swap::RefCnt,
S: arc_swap::strategy::CaS<T>,
{
fn try_rcu<R, F, E>(&self, mut f: F) -> Result<T, E>
where
F: FnMut(&T) -> Result<R, E>,
R: Into<T>,
{
fn ptr_eq<Base, A, B>(a: A, b: B) -> bool
where
A: arc_swap::AsRaw<Base>,
B: arc_swap::AsRaw<Base>,
{
let a = a.as_raw();
let b = b.as_raw();
std::ptr::eq(a, b)
}
let mut cur = self.load();
loop {
let new = f(&cur)?.into();
let prev = self.compare_and_swap(&*cur, new);
let swapped = ptr_eq(&*cur, &*prev);
if swapped {
return Ok(arc_swap::Guard::into_inner(prev));
} else {
cur = prev;
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use arc_swap::ArcSwap;
use std::sync::Arc;
#[test]
fn test_try_rcu_success() {
let swap = ArcSwap::from(Arc::new(42));
let result = swap.try_rcu(|value| -> Result<_, String> { Ok(**value + 1) });
assert!(result.is_ok());
assert_eq!(**swap.load(), 43);
}
#[test]
fn test_try_rcu_error() {
let swap = ArcSwap::from(Arc::new(42));
let result = swap.try_rcu(|value| -> Result<i32, _> {
if **value == 42 {
Err("err")
} else {
Ok(**value + 1)
}
});
assert!(result.is_err());
assert_eq!(result.unwrap_err(), "err");
assert_eq!(**swap.load(), 42);
}
}

View File

@@ -37,7 +37,7 @@ message ValueMeta {
}
message CompactKey {
uint64 high = 1;
uint64 low = 2;
int64 high = 1;
int64 low = 2;
}

View File

@@ -236,8 +236,8 @@ impl From<ValueMeta> for proto::ValueMeta {
impl From<CompactKey> for proto::CompactKey {
fn from(value: CompactKey) -> Self {
proto::CompactKey {
high: (value.raw() >> 64) as u64,
low: value.raw() as u64,
high: (value.raw() >> 64) as i64,
low: value.raw() as i64,
}
}
}
@@ -354,64 +354,3 @@ impl From<proto::CompactKey> for CompactKey {
(((value.high as i128) << 64) | (value.low as i128)).into()
}
}
#[test]
fn test_compact_key_with_large_relnode() {
use pageserver_api::key::Key;
let inputs = vec![
Key {
field1: 0,
field2: 0x100,
field3: 0x200,
field4: 0,
field5: 0x10,
field6: 0x5,
},
Key {
field1: 0,
field2: 0x100,
field3: 0x200,
field4: 0x007FFFFF,
field5: 0x10,
field6: 0x5,
},
Key {
field1: 0,
field2: 0x100,
field3: 0x200,
field4: 0x00800000,
field5: 0x10,
field6: 0x5,
},
Key {
field1: 0,
field2: 0x100,
field3: 0x200,
field4: 0x00800001,
field5: 0x10,
field6: 0x5,
},
Key {
field1: 0,
field2: 0xFFFFFFFF,
field3: 0xFFFFFFFF,
field4: 0xFFFFFFFF,
field5: 0x0,
field6: 0x0,
},
];
for input in inputs {
assert!(input.is_valid_key_on_write_path());
let compact = input.to_compact();
let proto: proto::CompactKey = compact.into();
let from_proto: CompactKey = proto.into();
assert_eq!(
compact, from_proto,
"Round trip failed for key with relnode={:#x}",
input.field4
);
}
}
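The removed round-trip test pinned down the one subtle point of splitting a 128-bit key across two 64-bit protobuf fields: when reassembling `(high << 64) | low`, the low half must be widened without sign extension, or a negative `low` smears 1-bits across `high`. A standalone sketch of a sign-safe round trip (the general technique, not necessarily the exact codec here):

fn split(v: i128) -> (i64, i64) {
    ((v >> 64) as i64, v as i64)
}

fn join(high: i64, low: i64) -> i128 {
    // Widen `low` through u64 first so a negative low half does not
    // sign-extend into the high 64 bits.
    ((high as i128) << 64) | (low as u64 as i128)
}

fn main() {
    for v in [0i128, 1, -1, i128::MAX, i128::MIN, 0x0080_0001_0000_0010] {
        let (h, l) = split(v);
        assert_eq!(join(h, l), v);
    }
}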

View File

@@ -30,9 +30,9 @@ fn main() -> anyhow::Result<()> {
let pgxn_neon = std::fs::canonicalize(pgxn_neon)?;
let pgxn_neon = pgxn_neon.to_str().ok_or(anyhow!("Bad non-UTF path"))?;
println!("cargo:rustc-link-lib=static=walproposer");
println!("cargo:rustc-link-lib=static=pgport");
println!("cargo:rustc-link-lib=static=pgcommon");
println!("cargo:rustc-link-lib=static=walproposer");
println!("cargo:rustc-link-search={walproposer_lib_search_str}");
// Rebuild crate when libwalproposer.a changes

View File

@@ -270,18 +270,12 @@ impl Client {
Ok(body)
}
pub async fn set_tenant_config(&self, req: &TenantConfigRequest) -> Result<()> {
pub async fn tenant_config(&self, req: &TenantConfigRequest) -> Result<()> {
let uri = format!("{}/v1/tenant/config", self.mgmt_api_endpoint);
self.request(Method::PUT, &uri, req).await?;
Ok(())
}
pub async fn patch_tenant_config(&self, req: &TenantConfigPatchRequest) -> Result<()> {
let uri = format!("{}/v1/tenant/config", self.mgmt_api_endpoint);
self.request(Method::PATCH, &uri, req).await?;
Ok(())
}
pub async fn tenant_secondary_download(
&self,
tenant_id: TenantShardId,

View File

@@ -64,7 +64,7 @@ async fn main_impl(args: Args) -> anyhow::Result<()> {
println!("operating on timeline {}", timeline);
mgmt_api_client
.set_tenant_config(&TenantConfigRequest {
.tenant_config(&TenantConfigRequest {
tenant_id: timeline.tenant_id,
config: TenantConfig::default(),
})

View File

@@ -9,6 +9,7 @@
use remote_storage::GenericRemoteStorage;
use remote_storage::RemotePath;
use remote_storage::TimeoutOrCancel;
use remote_storage::MAX_KEYS_PER_DELETE;
use std::time::Duration;
use tokio_util::sync::CancellationToken;
use tracing::info;
@@ -130,8 +131,7 @@ impl Deleter {
}
pub(super) async fn background(&mut self) -> Result<(), DeletionQueueError> {
let max_keys_per_delete = self.remote_storage.max_keys_per_delete();
self.accumulator.reserve(max_keys_per_delete);
self.accumulator.reserve(MAX_KEYS_PER_DELETE);
loop {
if self.cancel.is_cancelled() {
@@ -156,14 +156,14 @@ impl Deleter {
match msg {
DeleterMessage::Delete(mut list) => {
while !list.is_empty() || self.accumulator.len() == max_keys_per_delete {
if self.accumulator.len() == max_keys_per_delete {
while !list.is_empty() || self.accumulator.len() == MAX_KEYS_PER_DELETE {
if self.accumulator.len() == MAX_KEYS_PER_DELETE {
self.flush().await?;
// If we have received this number of keys, proceed with attempting to execute
assert_eq!(self.accumulator.len(), 0);
}
let available_slots = max_keys_per_delete - self.accumulator.len();
let available_slots = MAX_KEYS_PER_DELETE - self.accumulator.len();
let take_count = std::cmp::min(available_slots, list.len());
for path in list.drain(list.len() - take_count..) {
self.accumulator.push(path);

View File

@@ -767,27 +767,7 @@ paths:
/v1/tenant/config:
put:
description: |
Update tenant's config by setting it to the provided value.
Invalid fields in the tenant config will cause the request to be rejected with status 400.
requestBody:
content:
application/json:
schema:
$ref: "#/components/schemas/TenantConfigRequest"
responses:
"200":
description: OK
content:
application/json:
schema:
type: array
items:
$ref: "#/components/schemas/TenantInfo"
patch:
description: |
Update tenant's config additively by patching the updated fields provided.
Null values unset the field and non-null values upsert it.
Update tenant's config.
Invalid fields in the tenant config will cause the request to be rejected with status 400.
requestBody:

View File

@@ -28,7 +28,6 @@ use pageserver_api::models::LsnLease;
use pageserver_api::models::LsnLeaseRequest;
use pageserver_api::models::OffloadedTimelineInfo;
use pageserver_api::models::ShardParameters;
use pageserver_api::models::TenantConfigPatchRequest;
use pageserver_api::models::TenantDetails;
use pageserver_api::models::TenantLocationConfigRequest;
use pageserver_api::models::TenantLocationConfigResponse;
@@ -1696,47 +1695,7 @@ async fn update_tenant_config_handler(
crate::tenant::Tenant::persist_tenant_config(state.conf, &tenant_shard_id, &location_conf)
.await
.map_err(|e| ApiError::InternalServerError(anyhow::anyhow!(e)))?;
let _ = tenant
.update_tenant_config(|_crnt| Ok(new_tenant_conf.clone()))
.expect("Closure returns Ok()");
json_response(StatusCode::OK, ())
}
async fn patch_tenant_config_handler(
mut request: Request<Body>,
_cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
let request_data: TenantConfigPatchRequest = json_request(&mut request).await?;
let tenant_id = request_data.tenant_id;
check_permission(&request, Some(tenant_id))?;
let state = get_state(&request);
let tenant_shard_id = TenantShardId::unsharded(tenant_id);
let tenant = state
.tenant_manager
.get_attached_tenant_shard(tenant_shard_id)?;
tenant.wait_to_become_active(ACTIVE_TENANT_TIMEOUT).await?;
let updated = tenant
.update_tenant_config(|crnt| crnt.apply_patch(request_data.config.clone()))
.map_err(ApiError::BadRequest)?;
// This is a legacy API that only operates on attached tenants: the preferred
// API to use is the location_config/ endpoint, which lets the caller provide
// the full LocationConf.
let location_conf = LocationConf::attached_single(
updated,
tenant.get_generation(),
&ShardParameters::default(),
);
crate::tenant::Tenant::persist_tenant_config(state.conf, &tenant_shard_id, &location_conf)
.await
.map_err(|e| ApiError::InternalServerError(anyhow::anyhow!(e)))?;
tenant.set_new_tenant_config(new_tenant_conf);
json_response(StatusCode::OK, ())
}
@@ -2077,30 +2036,15 @@ async fn timeline_compact_handler(
parse_query_param::<_, bool>(&request, "wait_until_scheduled_compaction_done")?
.unwrap_or(false);
let sub_compaction = compact_request
.as_ref()
.map(|r| r.sub_compaction)
.unwrap_or(false);
let sub_compaction_max_job_size_mb = compact_request
.as_ref()
.and_then(|r| r.sub_compaction_max_job_size_mb);
let options = CompactOptions {
compact_key_range: compact_request
compact_range: compact_request
.as_ref()
.and_then(|r| r.compact_key_range.clone()),
compact_lsn_range: compact_request
.as_ref()
.and_then(|r| r.compact_lsn_range.clone()),
.and_then(|r| r.compact_range.clone()),
compact_below_lsn: compact_request.as_ref().and_then(|r| r.compact_below_lsn),
flags,
sub_compaction,
sub_compaction_max_job_size_mb,
};
let scheduled = compact_request
.as_ref()
.map(|r| r.scheduled)
.unwrap_or(false);
let scheduled = compact_request.map(|r| r.scheduled).unwrap_or(false);
async {
let ctx = RequestContext::new(TaskKind::MgmtRequest, DownloadBehavior::Download);
@@ -2109,7 +2053,7 @@ async fn timeline_compact_handler(
let tenant = state
.tenant_manager
.get_attached_tenant_shard(tenant_shard_id)?;
let rx = tenant.schedule_compaction(timeline_id, options).await.map_err(ApiError::InternalServerError)?;
let rx = tenant.schedule_compaction(timeline_id, options).await;
if wait_until_scheduled_compaction_done {
// It is possible that this will take a long time, dropping the HTTP request will not cancel the compaction.
rx.await.ok();
@@ -3336,9 +3280,6 @@ pub fn make_router(
.get("/v1/tenant/:tenant_shard_id/synthetic_size", |r| {
api_handler(r, tenant_size_handler)
})
.patch("/v1/tenant/config", |r| {
api_handler(r, patch_tenant_config_handler)
})
.put("/v1/tenant/config", |r| {
api_handler(r, update_tenant_config_handler)
})
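The removed PATCH handler applied merge-patch semantics to the tenant config: a null value unsets a field, a non-null value upserts it. That rule in miniature with `serde_json` (an illustrative sketch, not the pageserver's actual `apply_patch`):

use serde_json::{json, Map, Value};

// Apply merge-patch semantics to one level of a JSON object:
// null removes the key, anything else inserts or replaces it.
fn apply_patch(config: &mut Map<String, Value>, patch: Map<String, Value>) {
    for (key, value) in patch {
        if value.is_null() {
            config.remove(&key);
        } else {
            config.insert(key, value);
        }
    }
}

fn main() {
    let mut config = json!({"gc_period": "1h", "pitr_interval": "7d"});
    let patch = json!({"gc_period": null, "compaction_threshold": 10});
    if let (Value::Object(c), Value::Object(p)) = (&mut config, patch) {
        apply_patch(c, p);
    }
    assert_eq!(config, json!({"pitr_interval": "7d", "compaction_threshold": 10}));
}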

View File

@@ -16,6 +16,7 @@ use postgres_backend::{is_expected_io_error, QueryError};
use pq_proto::framed::ConnectionError;
use strum::{EnumCount, VariantNames};
use strum_macros::{IntoStaticStr, VariantNames};
use tracing::warn;
use utils::id::TimelineId;
/// Prometheus histogram buckets (in seconds) for operations in the critical
@@ -1222,163 +1223,54 @@ pub(crate) mod virtual_file_io_engine {
});
}
pub(crate) struct SmgrOpTimer(Option<SmgrOpTimerInner>);
pub(crate) struct SmgrOpTimerInner {
global_execution_latency_histo: Histogram,
per_timeline_execution_latency_histo: Option<Histogram>,
pub(crate) struct SmgrOpTimer {
global_latency_histo: Histogram,
global_batch_wait_time: Histogram,
per_timeline_batch_wait_time: Histogram,
// Optional because not all op types are tracked per-timeline
per_timeline_latency_histo: Option<Histogram>,
global_flush_in_progress_micros: IntCounter,
per_timeline_flush_in_progress_micros: IntCounter,
timings: SmgrOpTimerState,
}
#[derive(Debug)]
enum SmgrOpTimerState {
Received {
received_at: Instant,
},
ThrottleDoneExecutionStarting {
received_at: Instant,
throttle_started_at: Instant,
started_execution_at: Instant,
},
}
pub(crate) struct SmgrOpFlushInProgress {
flush_started_at: Instant,
global_micros: IntCounter,
per_timeline_micros: IntCounter,
start: Instant,
throttled: Duration,
op: SmgrQueryType,
}
impl SmgrOpTimer {
pub(crate) fn observe_throttle_done_execution_starting(&mut self, throttle: &ThrottleResult) {
let inner = self.0.as_mut().expect("other public methods consume self");
match (&mut inner.timings, throttle) {
(SmgrOpTimerState::Received { received_at }, throttle) => match throttle {
ThrottleResult::NotThrottled { start } => {
inner.timings = SmgrOpTimerState::ThrottleDoneExecutionStarting {
received_at: *received_at,
throttle_started_at: *start,
started_execution_at: *start,
};
}
ThrottleResult::Throttled { start, end } => {
inner.timings = SmgrOpTimerState::ThrottleDoneExecutionStarting {
received_at: *start,
throttle_started_at: *start,
started_execution_at: *end,
};
}
},
(x, _) => panic!("called in unexpected state: {x:?}"),
}
}
pub(crate) fn observe_smgr_op_completion_and_start_flushing(mut self) -> SmgrOpFlushInProgress {
let (flush_start, inner) = self
.smgr_op_end()
.expect("this method consume self, and the only other caller is drop handler");
let SmgrOpTimerInner {
global_flush_in_progress_micros,
per_timeline_flush_in_progress_micros,
..
} = inner;
SmgrOpFlushInProgress {
flush_started_at: flush_start,
global_micros: global_flush_in_progress_micros,
per_timeline_micros: per_timeline_flush_in_progress_micros,
}
}
/// Returns `None` if this method has already been called, `Some` otherwise.
fn smgr_op_end(&mut self) -> Option<(Instant, SmgrOpTimerInner)> {
let inner = self.0.take()?;
let now = Instant::now();
let batch;
let execution;
let throttle;
match inner.timings {
SmgrOpTimerState::Received { received_at } => {
batch = (now - received_at).as_secs_f64();
// TODO: use label for dropped requests.
// This is quite rare in practice, only during tenant or pageserver shutdown.
throttle = Duration::ZERO;
execution = Duration::ZERO.as_secs_f64();
}
SmgrOpTimerState::ThrottleDoneExecutionStarting {
received_at,
throttle_started_at,
started_execution_at,
} => {
batch = (throttle_started_at - received_at).as_secs_f64();
throttle = started_execution_at - throttle_started_at;
execution = (now - started_execution_at).as_secs_f64();
}
}
// update time spent in batching
inner.global_batch_wait_time.observe(batch);
inner.per_timeline_batch_wait_time.observe(batch);
// time spent in throttle metric is updated by throttle impl
let _ = throttle;
// update metrics for execution latency
inner.global_execution_latency_histo.observe(execution);
if let Some(per_timeline_execution_latency_histo) =
&inner.per_timeline_execution_latency_histo
{
per_timeline_execution_latency_histo.observe(execution);
}
Some((now, inner))
pub(crate) fn deduct_throttle(&mut self, throttle: &Option<Duration>) {
let Some(throttle) = throttle else {
return;
};
self.throttled += *throttle;
}
}
impl Drop for SmgrOpTimer {
fn drop(&mut self) {
self.smgr_op_end();
}
}
let elapsed = self.start.elapsed();
impl SmgrOpFlushInProgress {
pub(crate) async fn measure<Fut, O>(mut self, mut fut: Fut) -> O
where
Fut: std::future::Future<Output = O>,
{
let mut fut = std::pin::pin!(fut);
let now = Instant::now();
// Whenever observe_guard gets called, or dropped,
// it adds the time elapsed since its last call to metrics.
// Last call is tracked in `now`.
let mut observe_guard = scopeguard::guard(
|| {
let elapsed = now - self.flush_started_at;
self.global_micros
.inc_by(u64::try_from(elapsed.as_micros()).unwrap());
self.per_timeline_micros
.inc_by(u64::try_from(elapsed.as_micros()).unwrap());
self.flush_started_at = now;
},
|mut observe| {
observe();
},
);
loop {
match tokio::time::timeout(Duration::from_secs(10), &mut fut).await {
Ok(v) => return v,
Err(_timeout) => {
(*observe_guard)();
}
let elapsed = match elapsed.checked_sub(self.throttled) {
Some(elapsed) => elapsed,
None => {
use utils::rate_limit::RateLimit;
static LOGGED: Lazy<Mutex<enum_map::EnumMap<SmgrQueryType, RateLimit>>> =
Lazy::new(|| {
Mutex::new(enum_map::EnumMap::from_array(std::array::from_fn(|_| {
RateLimit::new(Duration::from_secs(10))
})))
});
let mut guard = LOGGED.lock().unwrap();
let rate_limit = &mut guard[self.op];
rate_limit.call(|| {
warn!(op=?self.op, ?elapsed, ?self.throttled, "implementation error: time spent throttled exceeds total request wall clock time");
});
elapsed // un-throttled time, more info than just saturating to 0
}
};
let elapsed = elapsed.as_secs_f64();
self.global_latency_histo.observe(elapsed);
if let Some(per_timeline_getpage_histo) = &self.per_timeline_latency_histo {
per_timeline_getpage_histo.observe(elapsed);
}
}
}
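For orientation, the leaner of the two timer shapes in this hunk reduces to a start instant, an accumulated throttled duration, and histogram handles, with the observation made in `Drop`. Below is a minimal sketch of that shape, assuming the `prometheus` crate and illustrative names; unlike the code above, it simply saturates to zero instead of emitting a rate-limited warning when the throttled time exceeds the elapsed time.

```rust
use std::time::{Duration, Instant};

struct OpTimer {
    histogram: prometheus::Histogram,
    start: Instant,
    throttled: Duration,
}

impl OpTimer {
    // Accumulate time the request spent waiting in the throttle, so it can be
    // subtracted from the wall-clock duration later.
    fn deduct_throttle(&mut self, throttle: &Option<Duration>) {
        if let Some(t) = throttle {
            self.throttled += *t;
        }
    }
}

impl Drop for OpTimer {
    fn drop(&mut self) {
        // Saturate to zero if bookkeeping over-counts the throttled time
        // (the code above logs a rate-limited warning in that case instead).
        let elapsed = self.start.elapsed().saturating_sub(self.throttled);
        self.histogram.observe(elapsed.as_secs_f64());
    }
}

fn main() -> Result<(), prometheus::Error> {
    let histogram = prometheus::register_histogram!(
        "example_op_seconds",
        "Example operation latency, excluding time spent throttled"
    )?;
    let mut timer = OpTimer {
        histogram,
        start: Instant::now(),
        throttled: Duration::ZERO,
    };
    timer.deduct_throttle(&Some(Duration::from_millis(5)));
    drop(timer); // the histogram sample is recorded here
    Ok(())
}
```

Because the observation lives in `Drop`, every exit path records a sample, which is also what lets the batched response path extend the timers' lifetime so that the flush is included in the measurement.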
@@ -1410,10 +1302,6 @@ pub(crate) struct SmgrQueryTimePerTimeline {
per_timeline_getpage_latency: Histogram,
global_batch_size: Histogram,
per_timeline_batch_size: Histogram,
global_flush_in_progress_micros: IntCounter,
per_timeline_flush_in_progress_micros: IntCounter,
global_batch_wait_time: Histogram,
per_timeline_batch_wait_time: Histogram,
}
static SMGR_QUERY_STARTED_GLOBAL: Lazy<IntCounterVec> = Lazy::new(|| {
@@ -1436,15 +1324,12 @@ static SMGR_QUERY_STARTED_PER_TENANT_TIMELINE: Lazy<IntCounterVec> = Lazy::new(|
.expect("failed to define a metric")
});
// Alias so all histograms recording per-timeline smgr timings use the same buckets.
static SMGR_QUERY_TIME_PER_TENANT_TIMELINE_BUCKETS: &[f64] = CRITICAL_OP_BUCKETS;
static SMGR_QUERY_TIME_PER_TENANT_TIMELINE: Lazy<HistogramVec> = Lazy::new(|| {
register_histogram_vec!(
"pageserver_smgr_query_seconds",
"Time spent _executing_ smgr query handling, excluding batch and throttle delays.",
"Time spent on smgr query handling, aggegated by query type and tenant/timeline.",
&["smgr_query_type", "tenant_id", "shard_id", "timeline_id"],
SMGR_QUERY_TIME_PER_TENANT_TIMELINE_BUCKETS.into(),
CRITICAL_OP_BUCKETS.into(),
)
.expect("failed to define a metric")
});
@@ -1502,7 +1387,7 @@ static SMGR_QUERY_TIME_GLOBAL_BUCKETS: Lazy<Vec<f64>> = Lazy::new(|| {
static SMGR_QUERY_TIME_GLOBAL: Lazy<HistogramVec> = Lazy::new(|| {
register_histogram_vec!(
"pageserver_smgr_query_seconds_global",
"Like pageserver_smgr_query_seconds, but aggregated to instance level.",
"Time spent on smgr query handling, aggregated by query type.",
&["smgr_query_type"],
SMGR_QUERY_TIME_GLOBAL_BUCKETS.clone(),
)
@@ -1579,45 +1464,6 @@ fn set_page_service_config_max_batch_size(conf: &PageServicePipeliningConfig) {
.set(value.try_into().unwrap());
}
static PAGE_SERVICE_SMGR_FLUSH_INPROGRESS_MICROS: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"pageserver_page_service_pagestream_flush_in_progress_micros",
"Counter that sums up the microseconds that a pagestream response was being flushed into the TCP connection. \
If the flush is particularly slow, this counter will be updated periodically to make slow flushes \
easily discoverable in monitoring. \
Hence, this is NOT a completion latency histogram.",
&["tenant_id", "shard_id", "timeline_id"],
)
.expect("failed to define a metric")
});
static PAGE_SERVICE_SMGR_FLUSH_INPROGRESS_MICROS_GLOBAL: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"pageserver_page_service_pagestream_flush_in_progress_micros_global",
"Like pageserver_page_service_pagestream_flush_in_progress_seconds, but instance-wide.",
)
.expect("failed to define a metric")
});
static PAGE_SERVICE_SMGR_BATCH_WAIT_TIME: Lazy<HistogramVec> = Lazy::new(|| {
register_histogram_vec!(
"pageserver_page_service_pagestream_batch_wait_time_seconds",
"Time a request spent waiting in its batch until the batch moved to throttle&execution.",
&["tenant_id", "shard_id", "timeline_id"],
SMGR_QUERY_TIME_PER_TENANT_TIMELINE_BUCKETS.into(),
)
.expect("failed to define a metric")
});
static PAGE_SERVICE_SMGR_BATCH_WAIT_TIME_GLOBAL: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"pageserver_page_service_pagestream_batch_wait_time_seconds_global",
"Like pageserver_page_service_pagestream_batch_wait_time_seconds, but aggregated to instance level.",
SMGR_QUERY_TIME_GLOBAL_BUCKETS.to_vec(),
)
.expect("failed to define a metric")
});
impl SmgrQueryTimePerTimeline {
pub(crate) fn new(tenant_shard_id: &TenantShardId, timeline_id: &TimelineId) -> Self {
let tenant_id = tenant_shard_id.tenant_id.to_string();
@@ -1658,17 +1504,6 @@ impl SmgrQueryTimePerTimeline {
.get_metric_with_label_values(&[&tenant_id, &shard_slug, &timeline_id])
.unwrap();
let global_batch_wait_time = PAGE_SERVICE_SMGR_BATCH_WAIT_TIME_GLOBAL.clone();
let per_timeline_batch_wait_time = PAGE_SERVICE_SMGR_BATCH_WAIT_TIME
.get_metric_with_label_values(&[&tenant_id, &shard_slug, &timeline_id])
.unwrap();
let global_flush_in_progress_micros =
PAGE_SERVICE_SMGR_FLUSH_INPROGRESS_MICROS_GLOBAL.clone();
let per_timeline_flush_in_progress_micros = PAGE_SERVICE_SMGR_FLUSH_INPROGRESS_MICROS
.get_metric_with_label_values(&[&tenant_id, &shard_slug, &timeline_id])
.unwrap();
Self {
global_started,
global_latency,
@@ -1676,13 +1511,9 @@ impl SmgrQueryTimePerTimeline {
per_timeline_getpage_started,
global_batch_size,
per_timeline_batch_size,
global_flush_in_progress_micros,
per_timeline_flush_in_progress_micros,
global_batch_wait_time,
per_timeline_batch_wait_time,
}
}
pub(crate) fn start_smgr_op(&self, op: SmgrQueryType, received_at: Instant) -> SmgrOpTimer {
pub(crate) fn start_smgr_op(&self, op: SmgrQueryType, started_at: Instant) -> SmgrOpTimer {
self.global_started[op as usize].inc();
let per_timeline_latency_histo = if matches!(op, SmgrQueryType::GetPageAtLsn) {
@@ -1692,17 +1523,13 @@ impl SmgrQueryTimePerTimeline {
None
};
SmgrOpTimer(Some(SmgrOpTimerInner {
global_execution_latency_histo: self.global_latency[op as usize].clone(),
per_timeline_execution_latency_histo: per_timeline_latency_histo,
timings: SmgrOpTimerState::Received { received_at },
global_flush_in_progress_micros: self.global_flush_in_progress_micros.clone(),
per_timeline_flush_in_progress_micros: self
.per_timeline_flush_in_progress_micros
.clone(),
global_batch_wait_time: self.global_batch_wait_time.clone(),
per_timeline_batch_wait_time: self.per_timeline_batch_wait_time.clone(),
}))
SmgrOpTimer {
global_latency_histo: self.global_latency[op as usize].clone(),
per_timeline_latency_histo,
start: started_at,
op,
throttled: Duration::ZERO,
}
}
pub(crate) fn observe_getpage_batch_start(&self, batch_size: usize) {
@@ -2377,15 +2204,6 @@ pub(crate) static WAL_INGEST: Lazy<WalIngestMetrics> = Lazy::new(|| WalIngestMet
.expect("failed to define a metric"),
});
pub(crate) static PAGESERVER_TIMELINE_WAL_RECORDS_RECEIVED: Lazy<IntCounterVec> = Lazy::new(|| {
register_int_counter_vec!(
"pageserver_timeline_wal_records_received",
"Number of WAL records received per shard",
&["tenant_id", "shard_id", "timeline_id"]
)
.expect("failed to define a metric")
});
pub(crate) static WAL_REDO_TIME: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"pageserver_wal_redo_seconds",
@@ -2613,7 +2431,6 @@ pub(crate) struct TimelineMetrics {
pub evictions_with_low_residence_duration: std::sync::RwLock<EvictionsWithLowResidenceDuration>,
/// Number of valid LSN leases.
pub valid_lsn_lease_count_gauge: UIntGauge,
pub wal_records_received: IntCounter,
shutdown: std::sync::atomic::AtomicBool,
}
@@ -2771,10 +2588,6 @@ impl TimelineMetrics {
.get_metric_with_label_values(&[&tenant_id, &shard_id, &timeline_id])
.unwrap();
let wal_records_received = PAGESERVER_TIMELINE_WAL_RECORDS_RECEIVED
.get_metric_with_label_values(&[&tenant_id, &shard_id, &timeline_id])
.unwrap();
TimelineMetrics {
tenant_id,
shard_id,
@@ -2807,7 +2620,6 @@ impl TimelineMetrics {
evictions_with_low_residence_duration,
),
valid_lsn_lease_count_gauge,
wal_records_received,
shutdown: std::sync::atomic::AtomicBool::default(),
}
}
@@ -2945,21 +2757,6 @@ impl TimelineMetrics {
shard_id,
timeline_id,
]);
let _ = PAGESERVER_TIMELINE_WAL_RECORDS_RECEIVED.remove_label_values(&[
tenant_id,
shard_id,
timeline_id,
]);
let _ = PAGE_SERVICE_SMGR_FLUSH_INPROGRESS_MICROS.remove_label_values(&[
tenant_id,
shard_id,
timeline_id,
]);
let _ = PAGE_SERVICE_SMGR_BATCH_WAIT_TIME.remove_label_values(&[
tenant_id,
shard_id,
timeline_id,
]);
}
}
@@ -2990,7 +2787,6 @@ use crate::context::{PageContentKind, RequestContext};
use crate::task_mgr::TaskKind;
use crate::tenant::mgr::TenantSlot;
use crate::tenant::tasks::BackgroundLoopKind;
use crate::tenant::throttle::ThrottleResult;
use crate::tenant::Timeline;
/// Maintain a per timeline gauge in addition to the global gauge.
@@ -3845,7 +3641,6 @@ pub fn preinitialize_metrics(conf: &'static PageServerConf) {
&REMOTE_ONDEMAND_DOWNLOADED_BYTES,
&CIRCUIT_BREAKERS_BROKEN,
&CIRCUIT_BREAKERS_UNBROKEN,
&PAGE_SERVICE_SMGR_FLUSH_INPROGRESS_MICROS_GLOBAL,
]
.into_iter()
.for_each(|c| {
@@ -3893,7 +3688,6 @@ pub fn preinitialize_metrics(conf: &'static PageServerConf) {
&WAL_REDO_BYTES_HISTOGRAM,
&WAL_REDO_PROCESS_LAUNCH_DURATION_HISTOGRAM,
&PAGE_SERVICE_BATCH_SIZE_GLOBAL,
&PAGE_SERVICE_SMGR_BATCH_WAIT_TIME_GLOBAL,
]
.into_iter()
.for_each(|h| {

View File

@@ -575,10 +575,7 @@ enum BatchedFeMessage {
}
impl BatchedFeMessage {
async fn throttle_and_record_start_processing(
&mut self,
cancel: &CancellationToken,
) -> Result<(), QueryError> {
async fn throttle(&mut self, cancel: &CancellationToken) -> Result<(), QueryError> {
let (shard, tokens, timers) = match self {
BatchedFeMessage::Exists { shard, timer, .. }
| BatchedFeMessage::Nblocks { shard, timer, .. }
@@ -606,7 +603,7 @@ impl BatchedFeMessage {
}
};
for timer in timers {
timer.observe_throttle_done_execution_starting(&throttled);
timer.deduct_throttle(&throttled);
}
Ok(())
}
@@ -1020,8 +1017,10 @@ impl PageServerHandler {
// Map handler result to protocol behavior.
// Some handler errors cause exit from pagestream protocol.
// Other handler errors are sent back as an error message and we stay in pagestream protocol.
let mut timers: smallvec::SmallVec<[_; 1]> =
smallvec::SmallVec::with_capacity(handler_results.len());
for handler_result in handler_results {
let (response_msg, timer) = match handler_result {
let response_msg = match handler_result {
Err(e) => match &e {
PageStreamError::Shutdown => {
// If we fail to fulfil a request during shutdown, which may be _because_ of
@@ -1045,66 +1044,34 @@ impl PageServerHandler {
span.in_scope(|| {
error!("error reading relation or page version: {full:#}")
});
(
PagestreamBeMessage::Error(PagestreamErrorResponse {
message: e.to_string(),
}),
None, // TODO: measure errors
)
PagestreamBeMessage::Error(PagestreamErrorResponse {
message: e.to_string(),
})
}
},
Ok((response_msg, timer)) => (response_msg, Some(timer)),
Ok((response_msg, timer)) => {
// Extending the lifetime of the timers so observations on drop
// include the flush time.
timers.push(timer);
response_msg
}
};
//
// marshal & transmit response message
//
pgb_writer.write_message_noflush(&BeMessage::CopyData(&response_msg.serialize()))?;
// We purposefully don't count flush time into the timer.
//
// The reason is that the current compute client will not perform protocol processing
// if the postgres backend process is doing things other than `->smgr_read()`.
// This is especially the case for prefetch.
//
// If the compute doesn't read from the connection, eventually TCP will backpressure
// all the way into our flush call below.
//
// The timer's underlying metric is used for a storage-internal latency SLO and
// we don't want to include latency in it that we can't control.
// And as pointed out above, in this case, we don't control the time that flush will take.
let flushing_timer =
timer.map(|timer| timer.observe_smgr_op_completion_and_start_flushing());
// what we want to do
let flush_fut = pgb_writer.flush();
// metric for how long flushing takes
let flush_fut = match flushing_timer {
Some(flushing_timer) => {
futures::future::Either::Left(flushing_timer.measure(flush_fut))
}
None => futures::future::Either::Right(flush_fut),
};
// do it while respecting cancellation
let _: () = async move {
tokio::select! {
biased;
_ = cancel.cancelled() => {
// We were requested to shut down.
info!("shutdown request received in page handler");
return Err(QueryError::Shutdown)
}
res = flush_fut => {
res?;
}
}
Ok(())
}
// and log the info! line inside the request span
.instrument(span.clone())
.await?;
}
tokio::select! {
biased;
_ = cancel.cancelled() => {
// We were requested to shut down.
info!("shutdown request received in page handler");
return Err(QueryError::Shutdown)
}
res = pgb_writer.flush() => {
res?;
}
}
drop(timers);
Ok(())
}
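Both shapes of the flush path above race the flush against the cancellation token, with the cancellation arm checked first. A minimal sketch of that select pattern, assuming `tokio` and `tokio_util`; the writer bound and error type are stand-ins rather than the actual `pq_proto` types.

```rust
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio_util::sync::CancellationToken;

async fn flush_or_shutdown<W: AsyncWrite + Unpin>(
    writer: &mut W,
    cancel: &CancellationToken,
) -> anyhow::Result<()> {
    tokio::select! {
        biased;
        _ = cancel.cancelled() => {
            // Shutdown was requested: give up on the flush and report it.
            anyhow::bail!("shutdown request received in page handler");
        }
        res = writer.flush() => {
            res?;
        }
    }
    Ok(())
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let mut writer = tokio::io::sink();
    let cancel = CancellationToken::new();
    flush_or_shutdown(&mut writer, &cancel).await
}
```

With `biased;`, the arms are polled in order, so a pending shutdown wins even when the flush future is also ready.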
@@ -1233,7 +1200,7 @@ impl PageServerHandler {
}
};
if let Err(cancelled) = msg.throttle_and_record_start_processing(&self.cancel).await {
if let Err(cancelled) = msg.throttle(&self.cancel).await {
break cancelled;
}
@@ -1400,9 +1367,7 @@ impl PageServerHandler {
return Err(e);
}
};
batch
.throttle_and_record_start_processing(&self.cancel)
.await?;
batch.throttle(&self.cancel).await?;
self.pagesteam_handle_batched_message(pgb_writer, batch, &cancel, &ctx)
.await?;
}

File diff suppressed because it is too large

View File

@@ -11,7 +11,7 @@
pub(crate) use pageserver_api::config::TenantConfigToml as TenantConf;
use pageserver_api::models::CompactionAlgorithmSettings;
use pageserver_api::models::EvictionPolicy;
use pageserver_api::models::{self, TenantConfigPatch, ThrottleConfig};
use pageserver_api::models::{self, ThrottleConfig};
use pageserver_api::shard::{ShardCount, ShardIdentity, ShardNumber, ShardStripeSize};
use serde::de::IntoDeserializer;
use serde::{Deserialize, Serialize};
@@ -427,129 +427,6 @@ impl TenantConfOpt {
.or(global_conf.wal_receiver_protocol_override),
}
}
pub fn apply_patch(self, patch: TenantConfigPatch) -> anyhow::Result<TenantConfOpt> {
let Self {
mut checkpoint_distance,
mut checkpoint_timeout,
mut compaction_target_size,
mut compaction_period,
mut compaction_threshold,
mut compaction_algorithm,
mut gc_horizon,
mut gc_period,
mut image_creation_threshold,
mut pitr_interval,
mut walreceiver_connect_timeout,
mut lagging_wal_timeout,
mut max_lsn_wal_lag,
mut eviction_policy,
mut min_resident_size_override,
mut evictions_low_residence_duration_metric_threshold,
mut heatmap_period,
mut lazy_slru_download,
mut timeline_get_throttle,
mut image_layer_creation_check_threshold,
mut lsn_lease_length,
mut lsn_lease_length_for_ts,
mut timeline_offloading,
mut wal_receiver_protocol_override,
} = self;
patch.checkpoint_distance.apply(&mut checkpoint_distance);
patch
.checkpoint_timeout
.map(|v| humantime::parse_duration(&v))?
.apply(&mut checkpoint_timeout);
patch
.compaction_target_size
.apply(&mut compaction_target_size);
patch
.compaction_period
.map(|v| humantime::parse_duration(&v))?
.apply(&mut compaction_period);
patch.compaction_threshold.apply(&mut compaction_threshold);
patch.compaction_algorithm.apply(&mut compaction_algorithm);
patch.gc_horizon.apply(&mut gc_horizon);
patch
.gc_period
.map(|v| humantime::parse_duration(&v))?
.apply(&mut gc_period);
patch
.image_creation_threshold
.apply(&mut image_creation_threshold);
patch
.pitr_interval
.map(|v| humantime::parse_duration(&v))?
.apply(&mut pitr_interval);
patch
.walreceiver_connect_timeout
.map(|v| humantime::parse_duration(&v))?
.apply(&mut walreceiver_connect_timeout);
patch
.lagging_wal_timeout
.map(|v| humantime::parse_duration(&v))?
.apply(&mut lagging_wal_timeout);
patch.max_lsn_wal_lag.apply(&mut max_lsn_wal_lag);
patch.eviction_policy.apply(&mut eviction_policy);
patch
.min_resident_size_override
.apply(&mut min_resident_size_override);
patch
.evictions_low_residence_duration_metric_threshold
.map(|v| humantime::parse_duration(&v))?
.apply(&mut evictions_low_residence_duration_metric_threshold);
patch
.heatmap_period
.map(|v| humantime::parse_duration(&v))?
.apply(&mut heatmap_period);
patch.lazy_slru_download.apply(&mut lazy_slru_download);
patch
.timeline_get_throttle
.apply(&mut timeline_get_throttle);
patch
.image_layer_creation_check_threshold
.apply(&mut image_layer_creation_check_threshold);
patch
.lsn_lease_length
.map(|v| humantime::parse_duration(&v))?
.apply(&mut lsn_lease_length);
patch
.lsn_lease_length_for_ts
.map(|v| humantime::parse_duration(&v))?
.apply(&mut lsn_lease_length_for_ts);
patch.timeline_offloading.apply(&mut timeline_offloading);
patch
.wal_receiver_protocol_override
.apply(&mut wal_receiver_protocol_override);
Ok(Self {
checkpoint_distance,
checkpoint_timeout,
compaction_target_size,
compaction_period,
compaction_threshold,
compaction_algorithm,
gc_horizon,
gc_period,
image_creation_threshold,
pitr_interval,
walreceiver_connect_timeout,
lagging_wal_timeout,
max_lsn_wal_lag,
eviction_policy,
min_resident_size_override,
evictions_low_residence_duration_metric_threshold,
heatmap_period,
lazy_slru_download,
timeline_get_throttle,
image_layer_creation_check_threshold,
lsn_lease_length,
lsn_lease_length_for_ts,
timeline_offloading,
wal_receiver_protocol_override,
})
}
}
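A recurring detail in the `apply_patch` body shown above is that duration-like fields arrive in the patch as strings and pass through `humantime::parse_duration` before overwriting the option. A small self-contained sketch of just that conversion; the field name is illustrative and the real patch type distinguishes more cases than a plain `Option<String>`.

```rust
// humantime accepts strings like "10m 30s" or "300s"; a patch field such as
// checkpoint_timeout would be parsed like this before being applied.
fn apply_duration_patch(
    patch: Option<String>,
    target: &mut Option<std::time::Duration>,
) -> anyhow::Result<()> {
    if let Some(s) = patch {
        *target = Some(humantime::parse_duration(&s)?);
    }
    Ok(())
}

fn main() -> anyhow::Result<()> {
    let mut checkpoint_timeout = None;
    apply_duration_patch(Some("10m 30s".to_string()), &mut checkpoint_timeout)?;
    assert_eq!(checkpoint_timeout, Some(std::time::Duration::from_secs(630)));
    Ok(())
}
```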
impl TryFrom<&'_ models::TenantConfig> for TenantConfOpt {

View File

@@ -1,4 +1,4 @@
use std::{collections::HashMap, sync::Arc};
use std::collections::HashMap;
use utils::id::TimelineId;
@@ -20,7 +20,7 @@ pub(crate) struct GcBlock {
/// Do not add any more features taking and forbidding taking this lock. It should be
/// `tokio::sync::Notify`, but that is rarely used. On the other side, [`GcBlock::insert`]
/// synchronizes with gc attempts by locking and unlocking this mutex.
blocking: Arc<tokio::sync::Mutex<()>>,
blocking: tokio::sync::Mutex<()>,
}
impl GcBlock {
@@ -30,7 +30,7 @@ impl GcBlock {
/// it's ending, or if not currently possible, a value describing the reasons why not.
///
/// Cancellation safe.
pub(super) async fn start(&self) -> Result<Guard, BlockingReasons> {
pub(super) async fn start(&self) -> Result<Guard<'_>, BlockingReasons> {
let reasons = {
let g = self.reasons.lock().unwrap();
@@ -44,7 +44,7 @@ impl GcBlock {
Err(reasons)
} else {
Ok(Guard {
_inner: self.blocking.clone().lock_owned().await,
_inner: self.blocking.lock().await,
})
}
}
@@ -170,8 +170,8 @@ impl GcBlock {
}
}
pub(crate) struct Guard {
_inner: tokio::sync::OwnedMutexGuard<()>,
pub(super) struct Guard<'a> {
_inner: tokio::sync::MutexGuard<'a, ()>,
}
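The hunk above shows two shapes for the gc-blocking guard: one holds an `OwnedMutexGuard` obtained by cloning an `Arc`, the other borrows a `MutexGuard<'a>` directly from the struct, which is why the `Guard` type gains a lifetime parameter. A sketch of both shapes with illustrative struct names:

```rust
use std::sync::Arc;
use tokio::sync::{Mutex, MutexGuard, OwnedMutexGuard};

struct OwnedStyle {
    blocking: Arc<Mutex<()>>,
}

impl OwnedStyle {
    async fn start(&self) -> OwnedMutexGuard<()> {
        // Requires cloning the Arc so the guard owns its lock.
        self.blocking.clone().lock_owned().await
    }
}

struct BorrowedStyle {
    blocking: Mutex<()>,
}

impl BorrowedStyle {
    async fn start(&self) -> MutexGuard<'_, ()> {
        // No Arc needed; the guard borrows from `self`.
        self.blocking.lock().await
    }
}

#[tokio::main]
async fn main() {
    let owned = OwnedStyle { blocking: Arc::new(Mutex::new(())) };
    let _g1 = owned.start().await;

    let borrowed = BorrowedStyle { blocking: Mutex::new(()) };
    let _g2 = borrowed.start().await;
}
```

The borrowed form drops the extra `Arc`, at the cost of tying the guard's lifetime to the borrow of the containing struct.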
#[derive(Debug)]

View File

@@ -58,11 +58,6 @@ pub struct Stats {
pub sum_throttled_usecs: u64,
}
pub enum ThrottleResult {
NotThrottled { start: Instant },
Throttled { start: Instant, end: Instant },
}
impl<M> Throttle<M>
where
M: Metric,
@@ -127,15 +122,15 @@ where
self.inner.load().rate_limiter.steady_rps()
}
pub async fn throttle(&self, key_count: usize) -> ThrottleResult {
pub async fn throttle(&self, key_count: usize) -> Option<Duration> {
let inner = self.inner.load_full(); // clones the `Inner` Arc
let start = std::time::Instant::now();
if !inner.enabled {
return ThrottleResult::NotThrottled { start };
return None;
}
let start = std::time::Instant::now();
self.metric.accounting_start();
self.count_accounted_start.fetch_add(1, Ordering::Relaxed);
let did_throttle = inner.rate_limiter.acquire(key_count).await;
@@ -150,9 +145,9 @@ where
.fetch_add(wait_time.as_micros() as u64, Ordering::Relaxed);
let observation = Observation { wait_time };
self.metric.observe_throttling(&observation);
ThrottleResult::Throttled { start, end: now }
Some(wait_time)
} else {
ThrottleResult::NotThrottled { start }
None
}
}
}
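The two signatures above differ only in how the wait is reported: `ThrottleResult` carries start and end instants, while `Option<Duration>` reports the wait directly and stays `None` when the limiter did not delay the caller. A sketch of the latter contract with a stand-in limiter future; none of these names are the pageserver's.

```rust
use std::time::{Duration, Instant};

async fn throttle_with_report<F>(enabled: bool, acquire: F) -> Option<Duration>
where
    F: std::future::Future<Output = bool>, // resolves to "did we actually wait?"
{
    if !enabled {
        return None;
    }
    let start = Instant::now();
    let did_wait = acquire.await;
    did_wait.then(|| start.elapsed())
}

#[tokio::main]
async fn main() {
    // A fake limiter that always sleeps briefly and reports that it waited.
    let waited = throttle_with_report(true, async {
        tokio::time::sleep(Duration::from_millis(5)).await;
        true
    })
    .await;
    assert!(waited.is_some());
}
```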

View File

@@ -780,90 +780,40 @@ pub(crate) enum CompactFlags {
#[serde_with::serde_as]
#[derive(Debug, Clone, serde::Deserialize)]
pub(crate) struct CompactRequest {
pub compact_key_range: Option<CompactKeyRange>,
pub compact_lsn_range: Option<CompactLsnRange>,
pub compact_range: Option<CompactRange>,
pub compact_below_lsn: Option<Lsn>,
/// Whether the compaction job should be scheduled.
#[serde(default)]
pub scheduled: bool,
/// Whether the compaction job should be split across key ranges.
#[serde(default)]
pub sub_compaction: bool,
/// Max job size for each subcompaction job.
pub sub_compaction_max_job_size_mb: Option<u64>,
}
#[serde_with::serde_as]
#[derive(Debug, Clone, serde::Deserialize)]
pub(crate) struct CompactLsnRange {
pub start: Lsn,
pub end: Lsn,
}
#[serde_with::serde_as]
#[derive(Debug, Clone, serde::Deserialize)]
pub(crate) struct CompactKeyRange {
pub(crate) struct CompactRange {
#[serde_as(as = "serde_with::DisplayFromStr")]
pub start: Key,
#[serde_as(as = "serde_with::DisplayFromStr")]
pub end: Key,
}
impl From<Range<Lsn>> for CompactLsnRange {
fn from(range: Range<Lsn>) -> Self {
Self {
start: range.start,
end: range.end,
}
}
}
impl From<Range<Key>> for CompactKeyRange {
impl From<Range<Key>> for CompactRange {
fn from(range: Range<Key>) -> Self {
Self {
CompactRange {
start: range.start,
end: range.end,
}
}
}
impl From<CompactLsnRange> for Range<Lsn> {
fn from(range: CompactLsnRange) -> Self {
range.start..range.end
}
}
impl From<CompactKeyRange> for Range<Key> {
fn from(range: CompactKeyRange) -> Self {
range.start..range.end
}
}
impl CompactLsnRange {
#[cfg(test)]
#[cfg(feature = "testing")]
pub fn above(lsn: Lsn) -> Self {
Self {
start: lsn,
end: Lsn::MAX,
}
}
}
#[derive(Debug, Clone, Default)]
pub(crate) struct CompactOptions {
pub flags: EnumSet<CompactFlags>,
/// If set, the compaction will only compact the key range specified by this option.
/// This option is only used by GC compaction. For the full explanation, see [`compaction::GcCompactJob`].
pub compact_key_range: Option<CompactKeyRange>,
/// If set, the compaction will only compact the LSN within this value.
/// This option is only used by GC compaction. For the full explanation, see [`compaction::GcCompactJob`].
pub compact_lsn_range: Option<CompactLsnRange>,
/// Enable sub-compaction (split compaction job across key ranges).
/// This option is only used by GC compaction.
pub sub_compaction: bool,
/// Set job size for the GC compaction.
pub compact_range: Option<CompactRange>,
/// If set, the compaction will only compact the LSN below this value.
/// This option is only used by GC compaction.
pub sub_compaction_max_job_size_mb: Option<u64>,
pub compact_below_lsn: Option<Lsn>,
}
impl std::fmt::Debug for Timeline {
@@ -1685,10 +1635,8 @@ impl Timeline {
cancel,
CompactOptions {
flags,
compact_key_range: None,
compact_lsn_range: None,
sub_compaction: false,
sub_compaction_max_job_size_mb: None,
compact_range: None,
compact_below_lsn: None,
},
ctx,
)
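Both `CompactRange` and `CompactKeyRange` above rely on `serde_with`'s `DisplayFromStr` so that the `Key` endpoints travel as strings in the JSON request body. A self-contained sketch of that pattern with a stand-in key type; `Hex32` is illustrative, not the pageserver's `Key`.

```rust
use serde_with::{serde_as, DisplayFromStr};

#[derive(Debug, PartialEq)]
struct Hex32(u32);

impl std::fmt::Display for Hex32 {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:08X}", self.0)
    }
}

impl std::str::FromStr for Hex32 {
    type Err = std::num::ParseIntError;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        u32::from_str_radix(s, 16).map(Hex32)
    }
}

// Range endpoints are (de)serialized via Display/FromStr, not as numbers.
#[serde_as]
#[derive(Debug, serde::Deserialize)]
struct KeyRange {
    #[serde_as(as = "DisplayFromStr")]
    start: Hex32,
    #[serde_as(as = "DisplayFromStr")]
    end: Hex32,
}

fn main() -> anyhow::Result<()> {
    let r: KeyRange = serde_json::from_str(r#"{"start":"0000000A","end":"000000FF"}"#)?;
    assert_eq!(r.start, Hex32(0x0A));
    assert_eq!(r.end, Hex32(0xFF));
    Ok(())
}
```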

View File

@@ -29,6 +29,7 @@ use utils::id::TimelineId;
use crate::context::{AccessStatsBehavior, RequestContext, RequestContextBuilder};
use crate::page_cache;
use crate::statvfs::Statvfs;
use crate::tenant::checks::check_valid_layermap;
use crate::tenant::remote_timeline_client::WaitCompletionError;
use crate::tenant::storage_layer::batch_split_writer::{
BatchWriterResult, SplitDeltaLayerWriter, SplitImageLayerWriter,
@@ -41,7 +42,7 @@ use crate::tenant::storage_layer::{
use crate::tenant::timeline::ImageLayerCreationOutcome;
use crate::tenant::timeline::{drop_rlock, DeltaLayerWriter, ImageLayerWriter};
use crate::tenant::timeline::{Layer, ResidentLayer};
use crate::tenant::{gc_block, DeltaLayer, MaybeOffloaded};
use crate::tenant::{DeltaLayer, MaybeOffloaded};
use crate::virtual_file::{MaybeFatalIo, VirtualFile};
use pageserver_api::config::tenant_conf_defaults::{
DEFAULT_CHECKPOINT_DISTANCE, DEFAULT_COMPACTION_THRESHOLD,
@@ -63,68 +64,21 @@ use super::CompactionError;
const COMPACTION_DELTA_THRESHOLD: usize = 5;
/// A scheduled compaction task.
pub(crate) struct ScheduledCompactionTask {
/// It's unfortunate that we need to store a compact options struct here because the only outer
/// API we can call here is `compact_with_options` which does a few setup calls before starting the
/// actual compaction job... We should refactor this to store `GcCompactionJob` in the future.
pub struct ScheduledCompactionTask {
pub options: CompactOptions,
/// The channel to send the compaction result. If this is a subcompaction, the last compaction job holds the sender.
pub result_tx: Option<tokio::sync::oneshot::Sender<()>>,
/// Hold the GC block. If this is a subcompaction, the last compaction job holds the gc block guard.
pub gc_block: Option<gc_block::Guard>,
}
/// A job description for the gc-compaction job. This structure describes the rectangle range that the job will
/// process. The exact layers that need to be compacted/rewritten will be generated when `compact_with_gc` gets
/// called.
#[derive(Debug, Clone)]
pub(crate) struct GcCompactJob {
pub dry_run: bool,
/// The key range to be compacted. The compaction algorithm will only regenerate key-value pairs within this range
/// [left inclusive, right exclusive), and other pairs will be rewritten into new files if necessary.
pub compact_key_range: Range<Key>,
/// The LSN range to be compacted. The compaction algorithm will use this range to determine the layers to be
/// selected for the compaction, and it does not guarantee the generated layers will have exactly the same LSN range
/// as specified here. The true range being compacted is `min_lsn/max_lsn` in [`GcCompactionJobDescription`].
/// min_lsn will always <= the lower bound specified here, and max_lsn will always >= the upper bound specified here.
pub compact_lsn_range: Range<Lsn>,
}
impl GcCompactJob {
pub fn from_compact_options(options: CompactOptions) -> Self {
GcCompactJob {
dry_run: options.flags.contains(CompactFlags::DryRun),
compact_key_range: options
.compact_key_range
.map(|x| x.into())
.unwrap_or(Key::MIN..Key::MAX),
compact_lsn_range: options
.compact_lsn_range
.map(|x| x.into())
.unwrap_or(Lsn::INVALID..Lsn::MAX),
}
}
}
/// A job description for the gc-compaction job. This structure is generated when `compact_with_gc` is called
/// and contains the exact layers we want to compact.
pub struct GcCompactionJobDescription {
/// All layers to read in the compaction job
selected_layers: Vec<Layer>,
/// GC cutoff of the job. This is the lowest LSN that will be accessed by the read/GC path and we need to
/// keep all deltas <= this LSN or generate an image == this LSN.
/// GC cutoff of the job
gc_cutoff: Lsn,
/// LSNs to retain for the job. Read path will use this LSN so we need to keep deltas <= this LSN or
/// generate an image == this LSN.
/// LSNs to retain for the job
retain_lsns_below_horizon: Vec<Lsn>,
/// Maximum layer LSN processed in this compaction, that is max(end_lsn of layers). Exclusive. All data
/// \>= this LSN will be kept and will not be rewritten.
/// Maximum layer LSN processed in this compaction
max_layer_lsn: Lsn,
/// Minimum layer LSN processed in this compaction, that is min(start_lsn of layers). Inclusive.
/// All access below (strict lower than `<`) this LSN will be routed through the normal read path instead of
/// k-merge within gc-compaction.
min_layer_lsn: Lsn,
/// Only compact layers overlapping with this range.
/// Only compact layers overlapping with this range
compaction_key_range: Range<Key>,
/// When partial compaction is enabled, these layers need to be rewritten to ensure no overlap.
/// This field is here solely for debugging. The field will not be read once the compaction
@@ -343,7 +297,7 @@ impl Timeline {
)));
}
if options.compact_key_range.is_some() || options.compact_lsn_range.is_some() {
if options.compact_range.is_some() {
// maybe useful in the future? could implement this at some point
return Err(CompactionError::Other(anyhow!(
"compaction range is not supported for legacy compaction for now"
@@ -1798,114 +1752,6 @@ impl Timeline {
Ok(())
}
/// Split a gc-compaction job into multiple compaction jobs. The split is based on the key range and the estimated size of the compaction job.
/// The function returns a list of compaction jobs that can be executed separately. If the upper bound of the compact LSN
/// range is not specified, we will use the latest gc_cutoff as the upper bound, so that all jobs in the job set act
/// like a full compaction of the specified keyspace.
pub(crate) async fn gc_compaction_split_jobs(
self: &Arc<Self>,
job: GcCompactJob,
sub_compaction_max_job_size_mb: Option<u64>,
) -> anyhow::Result<Vec<GcCompactJob>> {
let compact_below_lsn = if job.compact_lsn_range.end != Lsn::MAX {
job.compact_lsn_range.end
} else {
*self.get_latest_gc_cutoff_lsn() // use the real gc cutoff
};
// Split compaction job to about 4GB each
const GC_COMPACT_MAX_SIZE_MB: u64 = 4 * 1024;
let sub_compaction_max_job_size_mb =
sub_compaction_max_job_size_mb.unwrap_or(GC_COMPACT_MAX_SIZE_MB);
let mut compact_jobs = Vec::new();
// For now, we simply use the key partitioning information; we should do a more fine-grained partitioning
// by estimating the number of files read for a compaction job. We should also partition on LSN.
let Ok(partition) = self.partitioning.try_lock() else {
bail!("failed to acquire partition lock");
};
let ((dense_ks, sparse_ks), _) = &*partition;
// Truncate the key range to be within user specified compaction range.
fn truncate_to(
source_start: &Key,
source_end: &Key,
target_start: &Key,
target_end: &Key,
) -> Option<(Key, Key)> {
let start = source_start.max(target_start);
let end = source_end.min(target_end);
if start < end {
Some((*start, *end))
} else {
None
}
}
let mut split_key_ranges = Vec::new();
let ranges = dense_ks
.parts
.iter()
.map(|partition| partition.ranges.iter())
.chain(sparse_ks.parts.iter().map(|x| x.0.ranges.iter()))
.flatten()
.cloned()
.collect_vec();
for range in ranges.iter() {
let Some((start, end)) = truncate_to(
&range.start,
&range.end,
&job.compact_key_range.start,
&job.compact_key_range.end,
) else {
continue;
};
split_key_ranges.push((start, end));
}
split_key_ranges.sort();
let guard = self.layers.read().await;
let layer_map = guard.layer_map()?;
let mut current_start = None;
let ranges_num = split_key_ranges.len();
for (idx, (start, end)) in split_key_ranges.into_iter().enumerate() {
if current_start.is_none() {
current_start = Some(start);
}
let start = current_start.unwrap();
if start >= end {
// We have already processed this partition.
continue;
}
let res = layer_map.range_search(start..end, compact_below_lsn);
let total_size = res.found.keys().map(|x| x.layer.file_size()).sum::<u64>();
if total_size > sub_compaction_max_job_size_mb * 1024 * 1024 || ranges_num == idx + 1 {
// Try to extend the compaction range so that we include at least one full layer file.
let extended_end = res
.found
.keys()
.map(|layer| layer.layer.key_range.end)
.min();
// It is possible that the search range does not contain any layer files when we reach the end of the loop.
// In this case, we simply use the specified key range end.
let end = if let Some(extended_end) = extended_end {
extended_end.max(end)
} else {
end
};
info!(
"splitting compaction job: {}..{}, estimated_size={}",
start, end, total_size
);
compact_jobs.push(GcCompactJob {
dry_run: job.dry_run,
compact_key_range: start..end,
compact_lsn_range: job.compact_lsn_range.start..compact_below_lsn,
});
current_start = Some(end);
}
}
drop(guard);
Ok(compact_jobs)
}
/// An experimental compaction building block that combines compaction with garbage collection.
///
/// The current implementation picks all delta + image layers that are below or intersecting with
@@ -1920,43 +1766,13 @@ impl Timeline {
/// Key::MIN..Key::MAX to the function indicates a full compaction, though technically, `Key::MAX` is not
/// part of the range.
///
/// If `options.compact_lsn_range.end` is provided, the compaction will only compact layers below or intersect with
/// If `options.compact_below_lsn` is provided, the compaction will only compact layers below or intersect with
/// the LSN. Otherwise, it will use the gc cutoff by default.
pub(crate) async fn compact_with_gc(
self: &Arc<Self>,
cancel: &CancellationToken,
options: CompactOptions,
ctx: &RequestContext,
) -> anyhow::Result<()> {
let sub_compaction = options.sub_compaction;
let job = GcCompactJob::from_compact_options(options.clone());
if sub_compaction {
info!("running enhanced gc bottom-most compaction with sub-compaction, splitting compaction jobs");
let jobs = self
.gc_compaction_split_jobs(job, options.sub_compaction_max_job_size_mb)
.await?;
let jobs_len = jobs.len();
for (idx, job) in jobs.into_iter().enumerate() {
info!(
"running enhanced gc bottom-most compaction, sub-compaction {}/{}",
idx + 1,
jobs_len
);
self.compact_with_gc_inner(cancel, job, ctx).await?;
}
if jobs_len == 0 {
info!("no jobs to run, skipping gc bottom-most compaction");
}
return Ok(());
}
self.compact_with_gc_inner(cancel, job, ctx).await
}
async fn compact_with_gc_inner(
self: &Arc<Self>,
cancel: &CancellationToken,
job: GcCompactJob,
ctx: &RequestContext,
) -> anyhow::Result<()> {
// Block other compaction/GC tasks from running for now. GC-compaction could run along
// with legacy compaction tasks in the future. Always ensure the lock order is compaction -> gc.
@@ -1977,11 +1793,19 @@ impl Timeline {
)
.await?;
let dry_run = job.dry_run;
let compact_key_range = job.compact_key_range;
let compact_lsn_range = job.compact_lsn_range;
let flags = options.flags;
let compaction_key_range = options
.compact_range
.map(|range| range.start..range.end)
.unwrap_or_else(|| Key::MIN..Key::MAX);
info!("running enhanced gc bottom-most compaction, dry_run={dry_run}, compact_key_range={}..{}, compact_lsn_range={}..{}", compact_key_range.start, compact_key_range.end, compact_lsn_range.start, compact_lsn_range.end);
let dry_run = flags.contains(CompactFlags::DryRun);
if compaction_key_range == (Key::MIN..Key::MAX) {
info!("running enhanced gc bottom-most compaction, dry_run={dry_run}, compaction_key_range={}..{}", compaction_key_range.start, compaction_key_range.end);
} else {
info!("running enhanced gc bottom-most compaction, dry_run={dry_run}");
}
scopeguard::defer! {
info!("done enhanced gc bottom-most compaction");
@@ -1999,21 +1823,13 @@ impl Timeline {
let gc_info = self.gc_info.read().unwrap();
let mut retain_lsns_below_horizon = Vec::new();
let gc_cutoff = {
// Currently, gc-compaction only kicks in after the legacy gc has updated the gc_cutoff.
// Therefore, it can only clean up data that cannot be cleaned up with legacy gc, instead of
// cleaning everything that it theoretically could. In the future, it should use `self.gc_info`
// to get the truth data.
let real_gc_cutoff = *self.get_latest_gc_cutoff_lsn();
let real_gc_cutoff = gc_info.cutoffs.select_min();
// The compaction algorithm will keep all keys above the gc_cutoff while keeping only necessary keys below the gc_cutoff for
// each of the retain_lsn. Therefore, if the user-provided `compact_lsn_range.end` is larger than the real gc cutoff, we will use
// each of the retain_lsn. Therefore, if the user-provided `compact_below_lsn` is larger than the real gc cutoff, we will use
// the real cutoff.
let mut gc_cutoff = if compact_lsn_range.end == Lsn::MAX {
real_gc_cutoff
} else {
compact_lsn_range.end
};
let mut gc_cutoff = options.compact_below_lsn.unwrap_or(real_gc_cutoff);
if gc_cutoff > real_gc_cutoff {
warn!("provided compact_lsn_range.end={} is larger than the real_gc_cutoff={}, using the real gc cutoff", gc_cutoff, real_gc_cutoff);
warn!("provided compact_below_lsn={} is larger than the real_gc_cutoff={}, using the real gc cutoff", gc_cutoff, real_gc_cutoff);
gc_cutoff = real_gc_cutoff;
}
gc_cutoff
@@ -2030,7 +1846,7 @@ impl Timeline {
}
let mut selected_layers: Vec<Layer> = Vec::new();
drop(gc_info);
// Firstly, pick all the layers intersect or below the gc_cutoff, get the largest LSN in the selected layers.
// Pick all the layers intersect or below the gc_cutoff, get the largest LSN in the selected layers.
let Some(max_layer_lsn) = layers
.iter_historic_layers()
.filter(|desc| desc.get_lsn_range().start <= gc_cutoff)
@@ -2040,45 +1856,27 @@ impl Timeline {
info!("no layers to compact with gc: no historic layers below gc_cutoff, gc_cutoff={}", gc_cutoff);
return Ok(());
};
// Next, if the user specifies compact_lsn_range.start, we need to filter some layers out. All the layers (strictly) below
// the min_layer_lsn computed as below will be filtered out and the data will be accessed using the normal read path, as if
// it is a branch.
let Some(min_layer_lsn) = layers
.iter_historic_layers()
.filter(|desc| {
if compact_lsn_range.start == Lsn::INVALID {
true // select all layers below if start == Lsn(0)
} else {
desc.get_lsn_range().end > compact_lsn_range.start // strictly larger than compact_above_lsn
}
})
.map(|desc| desc.get_lsn_range().start)
.min()
else {
info!("no layers to compact with gc: no historic layers above compact_above_lsn, compact_above_lsn={}", compact_lsn_range.end);
return Ok(());
};
// Then, pick all the layers that are below the max_layer_lsn. This is to ensure we can pick all single-key
// layers to compact.
let mut rewrite_layers = Vec::new();
for desc in layers.iter_historic_layers() {
if desc.get_lsn_range().end <= max_layer_lsn
&& desc.get_lsn_range().start >= min_layer_lsn
&& overlaps_with(&desc.get_key_range(), &compact_key_range)
&& overlaps_with(&desc.get_key_range(), &compaction_key_range)
{
// If the layer overlaps with the compaction key range, we need to read it to obtain all keys within the range,
// even if it might contain extra keys
selected_layers.push(guard.get_from_desc(&desc));
// If the layer is not fully contained within the key range, we need to rewrite it if it's a delta layer (it's fine
// to overlap image layers)
if desc.is_delta() && !fully_contains(&compact_key_range, &desc.get_key_range())
if desc.is_delta()
&& !fully_contains(&compaction_key_range, &desc.get_key_range())
{
rewrite_layers.push(desc);
}
}
}
if selected_layers.is_empty() {
info!("no layers to compact with gc: no layers within the key range, gc_cutoff={}, key_range={}..{}", gc_cutoff, compact_key_range.start, compact_key_range.end);
info!("no layers to compact with gc: no layers within the key range, gc_cutoff={}, key_range={}..{}", gc_cutoff, compaction_key_range.start, compaction_key_range.end);
return Ok(());
}
retain_lsns_below_horizon.sort();
@@ -2086,20 +1884,13 @@ impl Timeline {
selected_layers,
gc_cutoff,
retain_lsns_below_horizon,
min_layer_lsn,
max_layer_lsn,
compaction_key_range: compact_key_range,
compaction_key_range,
rewrite_layers,
}
};
let (has_data_below, lowest_retain_lsn) = if compact_lsn_range.start != Lsn::INVALID {
// If we only compact above some LSN, we should get the history from the current branch below the specified LSN.
// We use job_desc.min_layer_lsn as if it's the lowest branch point.
(true, job_desc.min_layer_lsn)
} else if self.ancestor_timeline.is_some() {
// In theory, we can also use min_layer_lsn here, but using ancestor LSN makes sure the delta layers cover the
// LSN ranges all the way to the ancestor timeline.
(true, self.ancestor_lsn)
let lowest_retain_lsn = if self.ancestor_timeline.is_some() {
Lsn(self.ancestor_lsn.0 + 1)
} else {
let res = job_desc
.retain_lsns_below_horizon
@@ -2117,19 +1908,17 @@ impl Timeline {
.unwrap_or(job_desc.gc_cutoff)
);
}
(false, res)
res
};
info!(
"picked {} layers for compaction ({} layers need rewriting) with max_layer_lsn={} min_layer_lsn={} gc_cutoff={} lowest_retain_lsn={}, key_range={}..{}, has_data_below={}",
"picked {} layers for compaction ({} layers need rewriting) with max_layer_lsn={} gc_cutoff={} lowest_retain_lsn={}, key_range={}..{}",
job_desc.selected_layers.len(),
job_desc.rewrite_layers.len(),
job_desc.max_layer_lsn,
job_desc.min_layer_lsn,
job_desc.gc_cutoff,
lowest_retain_lsn,
job_desc.compaction_key_range.start,
job_desc.compaction_key_range.end,
has_data_below,
job_desc.compaction_key_range.end
);
for layer in &job_desc.selected_layers {
@@ -2154,15 +1943,14 @@ impl Timeline {
// Step 1: construct a k-merge iterator over all layers.
// Also, verify if the layer map can be split by drawing a horizontal line at every LSN start/end split point.
// disable the check for now because we need to adjust the check for partial compactions, will enable later.
// let layer_names = job_desc
// .selected_layers
// .iter()
// .map(|layer| layer.layer_desc().layer_name())
// .collect_vec();
// if let Some(err) = check_valid_layermap(&layer_names) {
// warn!("gc-compaction layer map check failed because {}, this is normal if partial compaction is not finished yet", err);
// }
let layer_names = job_desc
.selected_layers
.iter()
.map(|layer| layer.layer_desc().layer_name())
.collect_vec();
if let Some(err) = check_valid_layermap(&layer_names) {
warn!("gc-compaction layer map check failed because {}, this is normal if partial compaction is not finished yet", err);
}
// The maximum LSN we are processing in this compaction loop
let end_lsn = job_desc
.selected_layers
@@ -2173,22 +1961,10 @@ impl Timeline {
let mut delta_layers = Vec::new();
let mut image_layers = Vec::new();
let mut downloaded_layers = Vec::new();
let mut total_downloaded_size = 0;
let mut total_layer_size = 0;
for layer in &job_desc.selected_layers {
if layer.needs_download().await?.is_some() {
total_downloaded_size += layer.layer_desc().file_size;
}
total_layer_size += layer.layer_desc().file_size;
let resident_layer = layer.download_and_keep_resident().await?;
downloaded_layers.push(resident_layer);
}
info!(
"finish downloading layers, downloaded={}, total={}, ratio={:.2}",
total_downloaded_size,
total_layer_size,
total_downloaded_size as f64 / total_layer_size as f64
);
for resident_layer in &downloaded_layers {
if resident_layer.layer_desc().is_delta() {
let layer = resident_layer.get_as_delta(ctx).await?;
@@ -2211,7 +1987,7 @@ impl Timeline {
// Only create image layers when there are no ancestor branches. TODO: create a covering image layer
// when some condition is met.
let mut image_layer_writer = if !has_data_below {
let mut image_layer_writer = if self.ancestor_timeline.is_none() {
Some(
SplitImageLayerWriter::new(
self.conf,
@@ -2244,11 +2020,7 @@ impl Timeline {
}
let mut delta_layer_rewriters = HashMap::<Arc<PersistentLayerKey>, RewritingLayers>::new();
/// When compacting not at a bottom range (=`[0,X)`) of the root branch, we "have data below" (`has_data_below=true`).
/// The two cases are compaction in ancestor branches and when `compact_lsn_range.start` is set.
/// In those cases, we need to pull up data from below the LSN range we're compacting.
///
/// This function unifies the cases so that later code doesn't have to think about it.
/// Returns None if there is no ancestor branch. Throw an error when the key is not found.
///
/// Currently, we always get the ancestor image for each key in the child branch no matter whether the image
/// is needed for reconstruction. This should be fixed in the future.
@@ -2256,19 +2028,17 @@ impl Timeline {
/// Furthermore, we should do vectored get instead of a single get, or better, use k-merge for ancestor
/// images.
async fn get_ancestor_image(
this_tline: &Arc<Timeline>,
tline: &Arc<Timeline>,
key: Key,
ctx: &RequestContext,
has_data_below: bool,
history_lsn_point: Lsn,
) -> anyhow::Result<Option<(Key, Lsn, Bytes)>> {
if !has_data_below {
if tline.ancestor_timeline.is_none() {
return Ok(None);
};
// This function is implemented as a get of the current timeline at ancestor LSN, therefore reusing
// as much existing code as possible.
let img = this_tline.get(key, history_lsn_point, ctx).await?;
Ok(Some((key, history_lsn_point, img)))
let img = tline.get(key, tline.ancestor_lsn, ctx).await?;
Ok(Some((key, tline.ancestor_lsn, img)))
}
// Actually, we can decide not to write to the image layer at all at this point because
@@ -2352,8 +2122,7 @@ impl Timeline {
job_desc.gc_cutoff,
&job_desc.retain_lsns_below_horizon,
COMPACTION_DELTA_THRESHOLD,
get_ancestor_image(self, *last_key, ctx, has_data_below, lowest_retain_lsn)
.await?,
get_ancestor_image(self, *last_key, ctx).await?,
)
.await?;
retention
@@ -2382,7 +2151,7 @@ impl Timeline {
job_desc.gc_cutoff,
&job_desc.retain_lsns_below_horizon,
COMPACTION_DELTA_THRESHOLD,
get_ancestor_image(self, last_key, ctx, has_data_below, lowest_retain_lsn).await?,
get_ancestor_image(self, last_key, ctx).await?,
)
.await?;
retention

View File

@@ -369,13 +369,6 @@ pub(super) async fn handle_walreceiver_connection(
// advances it to its end LSN. 0 is just an initialization placeholder.
let mut modification = timeline.begin_modification(Lsn(0));
if !records.is_empty() {
timeline
.metrics
.wal_records_received
.inc_by(records.len() as u64);
}
for interpreted in records {
if matches!(interpreted.flush_uncommitted, FlushUncommittedRecords::Yes)
&& uncommitted_records > 0
@@ -517,7 +510,6 @@ pub(super) async fn handle_walreceiver_connection(
}
// Ingest the records without immediately committing them.
timeline.metrics.wal_records_received.inc();
let ingested = walingest
.ingest_record(interpreted, &mut modification, &ctx)
.await

View File

@@ -877,24 +877,22 @@ impl WalIngest {
// will block waiting for the last valid LSN to advance up to
// it. So we use the previous record's LSN in the get calls
// instead.
if modification.tline.get_shard_identity().is_shard_zero() {
for segno in modification
.tline
.list_slru_segments(SlruKind::Clog, Version::Modified(modification), ctx)
.await?
{
let segpage = segno * pg_constants::SLRU_PAGES_PER_SEGMENT;
for segno in modification
.tline
.list_slru_segments(SlruKind::Clog, Version::Modified(modification), ctx)
.await?
{
let segpage = segno * pg_constants::SLRU_PAGES_PER_SEGMENT;
let may_delete = dispatch_pgversion!(modification.tline.pg_version, {
pgv::nonrelfile_utils::slru_may_delete_clogsegment(segpage, pageno)
});
let may_delete = dispatch_pgversion!(modification.tline.pg_version, {
pgv::nonrelfile_utils::slru_may_delete_clogsegment(segpage, pageno)
});
if may_delete {
modification
.drop_slru_segment(SlruKind::Clog, segno, ctx)
.await?;
trace!("Drop CLOG segment {:>04X}", segno);
}
if may_delete {
modification
.drop_slru_segment(SlruKind::Clog, segno, ctx)
.await?;
trace!("Drop CLOG segment {:>04X}", segno);
}
}
@@ -1049,18 +1047,16 @@ impl WalIngest {
// Delete all the segments except the last one. The last segment can still
// contain, possibly partially, valid data.
if modification.tline.get_shard_identity().is_shard_zero() {
while segment != endsegment {
modification
.drop_slru_segment(SlruKind::MultiXactMembers, segment as u32, ctx)
.await?;
while segment != endsegment {
modification
.drop_slru_segment(SlruKind::MultiXactMembers, segment as u32, ctx)
.await?;
/* move to next segment, handling wraparound correctly */
if segment == maxsegment {
segment = 0;
} else {
segment += 1;
}
/* move to next segment, handling wraparound correctly */
if segment == maxsegment {
segment = 0;
} else {
segment += 1;
}
}

View File

@@ -22,7 +22,6 @@
#include "libpq/pqformat.h"
#include "miscadmin.h"
#include "pgstat.h"
#include "portability/instr_time.h"
#include "postmaster/interrupt.h"
#include "storage/buf_internals.h"
#include "storage/ipc.h"
@@ -119,11 +118,6 @@ typedef struct
*/
PSConnectionState state;
PGconn *conn;
/* request / response counters for debugging */
uint64 nrequests_sent;
uint64 nresponses_received;
/*---
* WaitEventSet containing:
* - WL_SOCKET_READABLE on 'conn'
@@ -634,8 +628,6 @@ pageserver_connect(shardno_t shard_no, int elevel)
}
shard->state = PS_Connected;
shard->nrequests_sent = 0;
shard->nresponses_received = 0;
}
/* FALLTHROUGH */
case PS_Connected:
@@ -664,27 +656,6 @@ call_PQgetCopyData(shardno_t shard_no, char **buffer)
int ret;
PageServer *shard = &page_servers[shard_no];
PGconn *pageserver_conn = shard->conn;
instr_time now,
start_ts,
since_start,
last_log_ts,
since_last_log;
bool logged = false;
/*
* As a debugging aid, if we don't get a response for a long time, print a
* log message.
*
* 10 s is a very generous threshold, normally we expect a response in a
* few milliseconds. We have metrics to track latencies in normal ranges,
* but in the cases that take exceptionally long, it's useful to log the
* exact timestamps.
*/
#define LOG_INTERVAL_US UINT64CONST(10 * 1000000)
INSTR_TIME_SET_CURRENT(now);
start_ts = last_log_ts = now;
INSTR_TIME_SET_ZERO(since_last_log);
retry:
ret = PQgetCopyData(pageserver_conn, buffer, 1 /* async */ );
@@ -692,12 +663,9 @@ retry:
if (ret == 0)
{
WaitEvent event;
long timeout;
timeout = Min(0, LOG_INTERVAL_US - INSTR_TIME_GET_MICROSEC(since_last_log));
/* Sleep until there's something to do */
(void) WaitEventSetWait(shard->wes_read, timeout, &event, 1,
(void) WaitEventSetWait(shard->wes_read, -1L, &event, 1,
WAIT_EVENT_NEON_PS_READ);
ResetLatch(MyLatch);
@@ -716,40 +684,9 @@ retry:
}
}
/*
* Print a message to the log if a long time has passed with no
* response.
*/
INSTR_TIME_SET_CURRENT(now);
since_last_log = now;
INSTR_TIME_SUBTRACT(since_last_log, last_log_ts);
if (INSTR_TIME_GET_MICROSEC(since_last_log) >= LOG_INTERVAL_US)
{
since_start = now;
INSTR_TIME_SUBTRACT(since_start, start_ts);
neon_shard_log(shard_no, LOG, "no response received from pageserver for %0.3f s, still waiting (sent " UINT64_FORMAT " requests, received " UINT64_FORMAT " responses)",
INSTR_TIME_GET_DOUBLE(since_start),
shard->nrequests_sent, shard->nresponses_received);
last_log_ts = now;
logged = true;
}
goto retry;
}
/*
* If we logged earlier that the response is taking a long time, log
* another message when the response is finally received.
*/
if (logged)
{
INSTR_TIME_SET_CURRENT(now);
since_start = now;
INSTR_TIME_SUBTRACT(since_start, start_ts);
neon_shard_log(shard_no, LOG, "received response from pageserver after %0.3f s",
INSTR_TIME_GET_DOUBLE(since_start));
}
return ret;
}
@@ -849,7 +786,6 @@ pageserver_send(shardno_t shard_no, NeonRequest *request)
* PGRES_POLLING_WRITING state. It's kinda dirty to disconnect at this
* point, but in the grand scheme of things it's only a small issue.
*/
shard->nrequests_sent++;
if (PQputCopyData(pageserver_conn, req_buff.data, req_buff.len) <= 0)
{
char *msg = pchomp(PQerrorMessage(pageserver_conn));
@@ -942,7 +878,6 @@ pageserver_receive(shardno_t shard_no)
neon_shard_log(shard_no, ERROR, "pageserver_receive disconnect: unexpected PQgetCopyData return value: %d", rc);
}
shard->nresponses_received++;
return (NeonResponse *) resp;
}

View File

@@ -423,11 +423,7 @@ readahead_buffer_resize(int newsize, void *extra)
* ensuring we have received all but the last n requests (n = newsize).
*/
if (MyPState->n_requests_inflight > newsize)
{
Assert(MyPState->ring_unused >= MyPState->n_requests_inflight - newsize);
prefetch_wait_for(MyPState->ring_unused - (MyPState->n_requests_inflight - newsize));
Assert(MyPState->n_requests_inflight <= newsize);
}
prefetch_wait_for(MyPState->ring_unused - newsize);
/* construct the new PrefetchState, and copy over the memory contexts */
newPState = MemoryContextAllocZero(TopMemoryContext, newprfs_size);
@@ -442,6 +438,7 @@ readahead_buffer_resize(int newsize, void *extra)
newPState->ring_last = newsize;
newPState->ring_unused = newsize;
newPState->ring_receive = newsize;
newPState->ring_flush = newsize;
newPState->max_shard_no = MyPState->max_shard_no;
memcpy(newPState->shard_bitmap, MyPState->shard_bitmap, sizeof(MyPState->shard_bitmap));
@@ -492,7 +489,6 @@ readahead_buffer_resize(int newsize, void *extra)
}
newPState->n_unused -= 1;
}
newPState->ring_flush = newPState->ring_receive;
MyNeonCounters->getpage_prefetches_buffered =
MyPState->n_responses_buffered;
@@ -502,7 +498,6 @@ readahead_buffer_resize(int newsize, void *extra)
for (; end >= MyPState->ring_last && end != UINT64_MAX; end -= 1)
{
PrefetchRequest *slot = GetPrfSlot(end);
Assert(slot->status != PRFS_REQUESTED);
if (slot->status == PRFS_RECEIVED)
{
pfree(slot->response);
@@ -615,9 +610,6 @@ prefetch_read(PrefetchRequest *slot)
{
NeonResponse *response;
MemoryContext old;
BufferTag buftag;
shardno_t shard_no;
uint64 my_ring_index;
Assert(slot->status == PRFS_REQUESTED);
Assert(slot->response == NULL);
@@ -631,29 +623,11 @@ prefetch_read(PrefetchRequest *slot)
slot->status, slot->response,
(long)slot->my_ring_index, (long)MyPState->ring_receive);
/*
* Copy the request info so that if an error happens and the prefetch
* queue is flushed during the receive call, we can print the original
* values in the error message
*/
buftag = slot->buftag;
shard_no = slot->shard_no;
my_ring_index = slot->my_ring_index;
old = MemoryContextSwitchTo(MyPState->errctx);
response = (NeonResponse *) page_server->receive(shard_no);
response = (NeonResponse *) page_server->receive(slot->shard_no);
MemoryContextSwitchTo(old);
if (response)
{
/* The slot should still be valid */
if (slot->status != PRFS_REQUESTED ||
slot->response != NULL ||
slot->my_ring_index != MyPState->ring_receive)
neon_shard_log(shard_no, ERROR,
"Incorrect prefetch slot state after receive: status=%d response=%p my=%lu receive=%lu",
slot->status, slot->response,
(long) slot->my_ring_index, (long) MyPState->ring_receive);
/* update prefetch state */
MyPState->n_responses_buffered += 1;
MyPState->n_requests_inflight -= 1;
@@ -668,15 +642,11 @@ prefetch_read(PrefetchRequest *slot)
}
else
{
/*
* Note: The slot might no longer be valid, if the connection was lost
* and the prefetch queue was flushed during the receive call
*/
neon_shard_log(shard_no, LOG,
neon_shard_log(slot->shard_no, LOG,
"No response from reading prefetch entry %lu: %u/%u/%u.%u block %u. This can be caused by a concurrent disconnect",
(long) my_ring_index,
RelFileInfoFmt(BufTagGetNRelFileInfo(buftag)),
buftag.forkNum, buftag.blockNum);
(long)slot->my_ring_index,
RelFileInfoFmt(BufTagGetNRelFileInfo(slot->buftag)),
slot->buftag.forkNum, slot->buftag.blockNum);
return false;
}
}

View File

@@ -74,6 +74,10 @@ impl std::fmt::Display for Backend<'_, ()> {
.debug_tuple("ControlPlane::ProxyV1")
.field(&endpoint.url())
.finish(),
ControlPlaneClient::Neon(endpoint) => fmt
.debug_tuple("ControlPlane::Neon")
.field(&endpoint.url())
.finish(),
#[cfg(any(test, feature = "testing"))]
ControlPlaneClient::PostgresMock(endpoint) => fmt
.debug_tuple("ControlPlane::PostgresMock")

View File

@@ -43,6 +43,9 @@ static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
#[derive(Clone, Debug, ValueEnum)]
enum AuthBackendType {
#[value(name("console"), alias("cplane"))]
ControlPlane,
#[value(name("cplane-v1"), alias("control-plane"))]
ControlPlaneV1,
@@ -485,7 +488,40 @@ async fn main() -> anyhow::Result<()> {
}
if let Either::Left(auth::Backend::ControlPlane(api, _)) = &auth_backend {
if let proxy::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api {
if let proxy::control_plane::client::ControlPlaneClient::Neon(api) = &**api {
match (redis_notifications_client, regional_redis_client.clone()) {
(None, None) => {}
(client1, client2) => {
let cache = api.caches.project_info.clone();
if let Some(client) = client1 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
}
if let Some(client) = client2 {
maintenance_tasks.spawn(notifications::task_main(
client,
cache.clone(),
cancel_map.clone(),
args.region.clone(),
));
}
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
}
}
if let Some(regional_redis_client) = regional_redis_client {
let cache = api.caches.endpoints_cache.clone();
let con = regional_redis_client;
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(con, cancellation_token.clone()).await }
.instrument(span),
);
}
} else if let proxy::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api {
match (redis_notifications_client, regional_redis_client.clone()) {
(None, None) => {}
(client1, client2) => {
@@ -721,6 +757,65 @@ fn build_auth_backend(
Ok(Either::Left(config))
}
AuthBackendType::ControlPlane => {
let wake_compute_cache_config: CacheOptions = args.wake_compute_cache.parse()?;
let project_info_cache_config: ProjectInfoCacheOptions =
args.project_info_cache.parse()?;
let endpoint_cache_config: config::EndpointCacheConfig =
args.endpoint_cache_config.parse()?;
info!("Using NodeInfoCache (wake_compute) with options={wake_compute_cache_config:?}");
info!(
"Using AllowedIpsCache (wake_compute) with options={project_info_cache_config:?}"
);
info!("Using EndpointCacheConfig with options={endpoint_cache_config:?}");
let caches = Box::leak(Box::new(control_plane::caches::ApiCaches::new(
wake_compute_cache_config,
project_info_cache_config,
endpoint_cache_config,
)));
let config::ConcurrencyLockOptions {
shards,
limiter,
epoch,
timeout,
} = args.wake_compute_lock.parse()?;
info!(?limiter, shards, ?epoch, "Using NodeLocks (wake_compute)");
let locks = Box::leak(Box::new(control_plane::locks::ApiLocks::new(
"wake_compute_lock",
limiter,
shards,
timeout,
epoch,
&Metrics::get().wake_compute_lock,
)?));
tokio::spawn(locks.garbage_collect_worker());
let url: proxy::url::ApiUrl = args.auth_endpoint.parse()?;
let endpoint = http::Endpoint::new(url, http::new_client());
let mut wake_compute_rps_limit = args.wake_compute_limit.clone();
RateBucketInfo::validate(&mut wake_compute_rps_limit)?;
let wake_compute_endpoint_rate_limiter =
Arc::new(WakeComputeRateLimiter::new(wake_compute_rps_limit));
let api = control_plane::client::neon::NeonControlPlaneClient::new(
endpoint,
args.control_plane_token.clone(),
caches,
locks,
wake_compute_endpoint_rate_limiter,
);
let api = control_plane::client::ControlPlaneClient::Neon(api);
let auth_backend = auth::Backend::ControlPlane(MaybeOwned::Owned(api), ());
let config = Box::leak(Box::new(auth_backend));
Ok(Either::Left(config))
}
#[cfg(feature = "testing")]
AuthBackendType::Postgres => {
let url = args.auth_endpoint.parse()?;

View File

@@ -115,8 +115,7 @@ impl<P: CancellationPublisher> CancellationHandler<P> {
IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new_assert(ip, 64).trunc()),
};
if !self.limiter.lock().unwrap().check(subnet_key, 1) {
// log only the subnet part of the IP address to know which subnet is rate limited
tracing::warn!("Rate limit exceeded. Skipping cancellation message, {subnet_key}");
tracing::debug!("Rate limit exceeded. Skipping cancellation message");
Metrics::get()
.proxy
.cancellation_requests_total
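For context, the subnet key logged above comes from truncating the peer address to a fixed prefix, so every client in the same network shares one rate-limit bucket. A minimal sketch of that derivation with the `ipnet` crate (only the IPv6 `/64` arm is visible in this hunk; the `/24` IPv4 arm here is an illustrative assumption):

```rust
use std::net::IpAddr;

use ipnet::{IpNet, Ipv4Net, Ipv6Net};

/// Map a peer address to its containing subnet so that rate limiting
/// applies per network rather than per individual address.
fn subnet_key(peer: IpAddr) -> IpNet {
    match peer {
        // Hypothetical IPv4 granularity; not shown in this diff.
        IpAddr::V4(ip) => IpNet::V4(Ipv4Net::new(ip, 24).expect("24 <= 32").trunc()),
        // A single operator typically controls an entire IPv6 /64.
        IpAddr::V6(ip) => IpNet::V6(Ipv6Net::new(ip, 64).expect("64 <= 128").trunc()),
    }
}
```

With this, `subnet_key("2001:db8::1".parse().unwrap())` yields `2001:db8::/64`, which is the granularity the rate limiter checks against.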

View File

@@ -163,36 +163,32 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, tls, record_handshake_error);
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let session_id = ctx.session_id();
let peer_ip = ctx.peer_addr();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?session_id);
cancel_span.follows_from(tracing::Span::current());
async move {
drop(
cancellation_handler_clone
.cancel_session(
cancel_key_data,
session_id,
peer_ip,
config.authentication_config.ip_allowlist_check_enabled,
)
.instrument(cancel_span)
.await,
);
}
});
let (mut stream, params) =
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let session_id = ctx.session_id();
let peer_ip = ctx.peer_addr();
async move {
drop(
cancellation_handler_clone
.cancel_session(
cancel_key_data,
session_id,
peer_ip,
config.authentication_config.ip_allowlist_check_enabled,
)
.await,
);
}
});
return Ok(None);
}
};
return Ok(None);
}
};
drop(pause);
ctx.set_db_options(params.clone());
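The arm being rewritten above previously created a detached root span and linked it with `follows_from` before spawning the cancellation task; the new version drops the span. For reference, the general fire-and-forget-under-a-detached-span pattern looks roughly like this (a sketch, not the proxy's actual code; assumes a tokio runtime is running):

```rust
use tracing::Instrument;

/// Spawn background work under a root span (no parent) that still records
/// causality back to the current request span via `follows_from`.
fn spawn_detached_cancel(session_id: u64) {
    let span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id);
    span.follows_from(tracing::Span::current());
    tokio::spawn(async move { /* cancellation work would go here */ }.instrument(span));
}
```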

View File

@@ -1,6 +1,7 @@
pub mod cplane_proxy_v1;
#[cfg(any(test, feature = "testing"))]
pub mod mock;
pub mod neon;
use std::hash::Hash;
use std::sync::Arc;
@@ -27,8 +28,10 @@ use crate::types::EndpointId;
#[non_exhaustive]
#[derive(Clone)]
pub enum ControlPlaneClient {
/// Proxy V1 control plane API
/// New Proxy V1 control plane API
ProxyV1(cplane_proxy_v1::NeonControlPlaneClient),
/// Current Management API (V2).
Neon(neon::NeonControlPlaneClient),
/// Local mock control plane.
#[cfg(any(test, feature = "testing"))]
PostgresMock(mock::MockControlPlane),
@@ -46,6 +49,7 @@ impl ControlPlaneApi for ControlPlaneClient {
) -> Result<CachedRoleSecret, errors::GetAuthInfoError> {
match self {
Self::ProxyV1(api) => api.get_role_secret(ctx, user_info).await,
Self::Neon(api) => api.get_role_secret(ctx, user_info).await,
#[cfg(any(test, feature = "testing"))]
Self::PostgresMock(api) => api.get_role_secret(ctx, user_info).await,
#[cfg(test)]
@@ -62,6 +66,7 @@ impl ControlPlaneApi for ControlPlaneClient {
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), errors::GetAuthInfoError> {
match self {
Self::ProxyV1(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
Self::Neon(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
#[cfg(any(test, feature = "testing"))]
Self::PostgresMock(api) => api.get_allowed_ips_and_secret(ctx, user_info).await,
#[cfg(test)]
@@ -76,6 +81,7 @@ impl ControlPlaneApi for ControlPlaneClient {
) -> Result<Vec<AuthRule>, errors::GetEndpointJwksError> {
match self {
Self::ProxyV1(api) => api.get_endpoint_jwks(ctx, endpoint).await,
Self::Neon(api) => api.get_endpoint_jwks(ctx, endpoint).await,
#[cfg(any(test, feature = "testing"))]
Self::PostgresMock(api) => api.get_endpoint_jwks(ctx, endpoint).await,
#[cfg(test)]
@@ -90,6 +96,7 @@ impl ControlPlaneApi for ControlPlaneClient {
) -> Result<CachedNodeInfo, errors::WakeComputeError> {
match self {
Self::ProxyV1(api) => api.wake_compute(ctx, user_info).await,
Self::Neon(api) => api.wake_compute(ctx, user_info).await,
#[cfg(any(test, feature = "testing"))]
Self::PostgresMock(api) => api.wake_compute(ctx, user_info).await,
#[cfg(test)]

View File

@@ -0,0 +1,511 @@
//! Stale console backend, remove after migrating to Proxy V1 API (#15245).
use std::sync::Arc;
use std::time::Duration;
use ::http::header::AUTHORIZATION;
use ::http::HeaderName;
use futures::TryFutureExt;
use postgres_client::config::SslMode;
use tokio::time::Instant;
use tracing::{debug, info, info_span, warn, Instrument};
use super::super::messages::{ControlPlaneErrorMessage, GetRoleSecret, WakeCompute};
use crate::auth::backend::jwt::AuthRule;
use crate::auth::backend::ComputeUserInfo;
use crate::cache::Cached;
use crate::context::RequestContext;
use crate::control_plane::caches::ApiCaches;
use crate::control_plane::errors::{
ControlPlaneError, GetAuthInfoError, GetEndpointJwksError, WakeComputeError,
};
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::messages::{ColdStartInfo, EndpointJwksResponse, Reason};
use crate::control_plane::{
AuthInfo, AuthSecret, CachedAllowedIps, CachedNodeInfo, CachedRoleSecret, NodeInfo,
};
use crate::metrics::{CacheOutcome, Metrics};
use crate::rate_limiter::WakeComputeRateLimiter;
use crate::types::{EndpointCacheKey, EndpointId};
use crate::{compute, http, scram};
const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id");
#[derive(Clone)]
pub struct NeonControlPlaneClient {
endpoint: http::Endpoint,
pub caches: &'static ApiCaches,
pub(crate) locks: &'static ApiLocks<EndpointCacheKey>,
pub(crate) wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
// put in a shared ref so we don't copy secrets all over memory
jwt: Arc<str>,
}
impl NeonControlPlaneClient {
/// Construct an API object containing the auth parameters.
pub fn new(
endpoint: http::Endpoint,
jwt: Arc<str>,
caches: &'static ApiCaches,
locks: &'static ApiLocks<EndpointCacheKey>,
wake_compute_endpoint_rate_limiter: Arc<WakeComputeRateLimiter>,
) -> Self {
Self {
endpoint,
caches,
locks,
wake_compute_endpoint_rate_limiter,
jwt,
}
}
pub(crate) fn url(&self) -> &str {
self.endpoint.url().as_str()
}
async fn do_get_auth_info(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<AuthInfo, GetAuthInfoError> {
if !self
.caches
.endpoints_cache
.is_valid(ctx, &user_info.endpoint.normalize())
{
// TODO: refactor this because it's weird
// this is a failure to authenticate but we return Ok.
info!("endpoint is not valid, skipping the request");
return Ok(AuthInfo::default());
}
let request_id = ctx.session_id().to_string();
let application_name = ctx.console_application_name();
async {
let request = self
.endpoint
.get_path("proxy_get_role_secret")
.header(X_REQUEST_ID, &request_id)
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
.query(&[
("application_name", application_name.as_str()),
("project", user_info.endpoint.as_str()),
("role", user_info.user.as_str()),
])
.build()?;
debug!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
let response = self.endpoint.execute(request).await?;
drop(pause);
info!(duration = ?start.elapsed(), "received http response");
let body = match parse_body::<GetRoleSecret>(response).await {
Ok(body) => body,
// Error 404 is special: it's ok not to have a secret.
// TODO(anna): retry
Err(e) => {
return if e.get_reason().is_not_found() {
// TODO: refactor this because it's weird
// this is a failure to authenticate but we return Ok.
Ok(AuthInfo::default())
} else {
Err(e.into())
};
}
};
let secret = if body.role_secret.is_empty() {
None
} else {
let secret = scram::ServerSecret::parse(&body.role_secret)
.map(AuthSecret::Scram)
.ok_or(GetAuthInfoError::BadSecret)?;
Some(secret)
};
let allowed_ips = body.allowed_ips.unwrap_or_default();
Metrics::get()
.proxy
.allowed_ips_number
.observe(allowed_ips.len() as f64);
Ok(AuthInfo {
secret,
allowed_ips,
project_id: body.project_id,
})
}
.inspect_err(|e| tracing::debug!(error = ?e))
.instrument(info_span!("do_get_auth_info"))
.await
}
async fn do_get_endpoint_jwks(
&self,
ctx: &RequestContext,
endpoint: EndpointId,
) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
if !self
.caches
.endpoints_cache
.is_valid(ctx, &endpoint.normalize())
{
return Err(GetEndpointJwksError::EndpointNotFound);
}
let request_id = ctx.session_id().to_string();
async {
let request = self
.endpoint
.get_with_url(|url| {
url.path_segments_mut()
.push("endpoints")
.push(endpoint.as_str())
.push("jwks");
})
.header(X_REQUEST_ID, &request_id)
.header(AUTHORIZATION, format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
.build()
.map_err(GetEndpointJwksError::RequestBuild)?;
debug!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
let response = self
.endpoint
.execute(request)
.await
.map_err(GetEndpointJwksError::RequestExecute)?;
drop(pause);
info!(duration = ?start.elapsed(), "received http response");
let body = parse_body::<EndpointJwksResponse>(response).await?;
let rules = body
.jwks
.into_iter()
.map(|jwks| AuthRule {
id: jwks.id,
jwks_url: jwks.jwks_url,
audience: jwks.jwt_audience,
role_names: jwks.role_names,
})
.collect();
Ok(rules)
}
.inspect_err(|e| tracing::debug!(error = ?e))
.instrument(info_span!("do_get_endpoint_jwks"))
.await
}
async fn do_wake_compute(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<NodeInfo, WakeComputeError> {
let request_id = ctx.session_id().to_string();
let application_name = ctx.console_application_name();
async {
let mut request_builder = self
.endpoint
.get_path("proxy_wake_compute")
.header("X-Request-ID", &request_id)
.header("Authorization", format!("Bearer {}", &self.jwt))
.query(&[("session_id", ctx.session_id())])
.query(&[
("application_name", application_name.as_str()),
("project", user_info.endpoint.as_str()),
]);
let options = user_info.options.to_deep_object();
if !options.is_empty() {
request_builder = request_builder.query(&options);
}
let request = request_builder.build()?;
debug!(url = request.url().as_str(), "sending http request");
let start = Instant::now();
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Cplane);
let response = self.endpoint.execute(request).await?;
drop(pause);
info!(duration = ?start.elapsed(), "received http response");
let body = parse_body::<WakeCompute>(response).await?;
// Unfortunately, ownership won't let us use `Option::ok_or` here.
let (host, port) = match parse_host_port(&body.address) {
None => return Err(WakeComputeError::BadComputeAddress(body.address)),
Some(x) => x,
};
// Don't set anything but host and port! This config will be cached.
// We'll set username and such later using the startup message.
// TODO: add more type safety (in progress).
let mut config = compute::ConnCfg::new(host.to_owned(), port);
config.ssl_mode(SslMode::Disable); // TLS is not configured on compute nodes.
let node = NodeInfo {
config,
aux: body.aux,
allow_self_signed_compute: false,
};
Ok(node)
}
.inspect_err(|e| tracing::debug!(error = ?e))
.instrument(info_span!("do_wake_compute"))
.await
}
}
impl super::ControlPlaneApi for NeonControlPlaneClient {
#[tracing::instrument(skip_all)]
async fn get_role_secret(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedRoleSecret, GetAuthInfoError> {
let normalized_ep = &user_info.endpoint.normalize();
let user = &user_info.user;
if let Some(role_secret) = self
.caches
.project_info
.get_role_secret(normalized_ep, user)
{
return Ok(role_secret);
}
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
self.caches.project_info.insert_role_secret(
project_id,
normalized_ep_int,
user.into(),
auth_info.secret.clone(),
);
self.caches.project_info.insert_allowed_ips(
project_id,
normalized_ep_int,
Arc::new(auth_info.allowed_ips),
);
ctx.set_project_id(project_id);
}
// When we just got a secret, we don't need to invalidate it.
Ok(Cached::new_uncached(auth_info.secret))
}
async fn get_allowed_ips_and_secret(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<(CachedAllowedIps, Option<CachedRoleSecret>), GetAuthInfoError> {
let normalized_ep = &user_info.endpoint.normalize();
if let Some(allowed_ips) = self.caches.project_info.get_allowed_ips(normalized_ep) {
Metrics::get()
.proxy
.allowed_ips_cache_misses
.inc(CacheOutcome::Hit);
return Ok((allowed_ips, None));
}
Metrics::get()
.proxy
.allowed_ips_cache_misses
.inc(CacheOutcome::Miss);
let auth_info = self.do_get_auth_info(ctx, user_info).await?;
let allowed_ips = Arc::new(auth_info.allowed_ips);
let user = &user_info.user;
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
self.caches.project_info.insert_role_secret(
project_id,
normalized_ep_int,
user.into(),
auth_info.secret.clone(),
);
self.caches.project_info.insert_allowed_ips(
project_id,
normalized_ep_int,
allowed_ips.clone(),
);
ctx.set_project_id(project_id);
}
Ok((
Cached::new_uncached(allowed_ips),
Some(Cached::new_uncached(auth_info.secret)),
))
}
#[tracing::instrument(skip_all)]
async fn get_endpoint_jwks(
&self,
ctx: &RequestContext,
endpoint: EndpointId,
) -> Result<Vec<AuthRule>, GetEndpointJwksError> {
self.do_get_endpoint_jwks(ctx, endpoint).await
}
#[tracing::instrument(skip_all)]
async fn wake_compute(
&self,
ctx: &RequestContext,
user_info: &ComputeUserInfo,
) -> Result<CachedNodeInfo, WakeComputeError> {
let key = user_info.endpoint_cache_key();
macro_rules! check_cache {
() => {
if let Some(cached) = self.caches.node_info.get(&key) {
let (cached, info) = cached.take_value();
let info = info.map_err(|c| {
info!(key = &*key, "found cached wake_compute error");
WakeComputeError::ControlPlane(ControlPlaneError::Message(Box::new(*c)))
})?;
debug!(key = &*key, "found cached compute node info");
ctx.set_project(info.aux.clone());
return Ok(cached.map(|()| info));
}
};
}
// Every time we do a wakeup HTTP request, the compute node will stay up
// for some time (how long depends on the console's scale-to-zero policy);
// the connection info remains the same during that period,
// which means we can cache it to reduce load and latency.
check_cache!();
let permit = self.locks.get_permit(&key).await?;
// After getting back a permit it's possible the cache was filled
// while we waited, so double-check.
if permit.should_check_cache() {
// TODO: if there is something in the cache, mark the permit as success.
check_cache!();
}
// check rate limit
if !self
.wake_compute_endpoint_rate_limiter
.check(user_info.endpoint.normalize_intern(), 1)
{
return Err(WakeComputeError::TooManyConnections);
}
let node = permit.release_result(self.do_wake_compute(ctx, user_info).await);
match node {
Ok(node) => {
ctx.set_project(node.aux.clone());
debug!(key = &*key, "created a cache entry for woken compute node");
let mut stored_node = node.clone();
// store the cached node as 'warm_cached'
stored_node.aux.cold_start_info = ColdStartInfo::WarmCached;
let (_, cached) = self.caches.node_info.insert_unit(key, Ok(stored_node));
Ok(cached.map(|()| node))
}
Err(err) => match err {
WakeComputeError::ControlPlane(ControlPlaneError::Message(err)) => {
let Some(status) = &err.status else {
return Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
err,
)));
};
let reason = status
.details
.error_info
.map_or(Reason::Unknown, |x| x.reason);
// if we can retry this error, do not cache it.
if reason.can_retry() {
return Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
err,
)));
}
// at this point, we should only have quota errors.
debug!(
key = &*key,
"created a cache entry for the wake compute error"
);
self.caches.node_info.insert_ttl(
key,
Err(err.clone()),
Duration::from_secs(30),
);
Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
err,
)))
}
err => return Err(err),
},
}
}
}
/// Parse http response body, taking status code into account.
async fn parse_body<T: for<'a> serde::Deserialize<'a>>(
response: http::Response,
) -> Result<T, ControlPlaneError> {
let status = response.status();
if status.is_success() {
// We shouldn't log raw body because it may contain secrets.
info!("request succeeded, processing the body");
return Ok(response.json().await?);
}
let s = response.bytes().await?;
// Log the plaintext to be able to detect whether there are cases not covered by the error struct.
info!("response_error plaintext: {:?}", s);
// Don't throw an error here because it's not as important
// as the fact that the request itself has failed.
let mut body = serde_json::from_slice(&s).unwrap_or_else(|e| {
warn!("failed to parse error body: {e}");
ControlPlaneErrorMessage {
error: "reason unclear (malformed error message)".into(),
http_status_code: status,
status: None,
}
});
body.http_status_code = status;
warn!("console responded with an error ({status}): {body:?}");
Err(ControlPlaneError::Message(Box::new(body)))
}
fn parse_host_port(input: &str) -> Option<(&str, u16)> {
let (host, port) = input.rsplit_once(':')?;
let ipv6_brackets: &[_] = &['[', ']'];
Some((host.trim_matches(ipv6_brackets), port.parse().ok()?))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_host_port_v4() {
let (host, port) = parse_host_port("127.0.0.1:5432").expect("failed to parse");
assert_eq!(host, "127.0.0.1");
assert_eq!(port, 5432);
}
#[test]
fn test_parse_host_port_v6() {
let (host, port) = parse_host_port("[2001:db8::1]:5432").expect("failed to parse");
assert_eq!(host, "2001:db8::1");
assert_eq!(port, 5432);
}
#[test]
fn test_parse_host_port_url() {
let (host, port) = parse_host_port("compute-foo-bar-1234.default.svc.cluster.local:5432")
.expect("failed to parse");
assert_eq!(host, "compute-foo-bar-1234.default.svc.cluster.local");
assert_eq!(port, 5432);
}
}
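The `wake_compute` implementation above is a double-checked cache: look up the node, take a concurrency permit, then look again, because another task may have filled the entry while this one waited. A generic sketch of that shape (simplified types; the real code uses `ApiLocks` and the `node_info` cache):

```rust
use std::collections::HashMap;
use std::sync::Arc;

use tokio::sync::{Mutex, Semaphore};

/// Double-checked caching around an expensive wake-up call.
async fn wake_cached(
    cache: Arc<Mutex<HashMap<String, String>>>,
    permits: Arc<Semaphore>,
    key: String,
) -> String {
    // Fast path: someone already woke this endpoint.
    if let Some(hit) = cache.lock().await.get(&key) {
        return hit.clone();
    }
    // Bound the number of concurrent wake-ups.
    let _permit = permits.acquire().await.expect("semaphore closed");
    // Re-check: the cache may have been filled while we waited for the permit.
    if let Some(hit) = cache.lock().await.get(&key) {
        return hit.clone();
    }
    let node = format!("woken:{key}"); // stands in for do_wake_compute()
    cache.lock().await.insert(key, node.clone());
    node
}
```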

View File

@@ -221,6 +221,15 @@ pub(crate) struct UserFacingMessage {
pub(crate) message: Box<str>,
}
/// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`].
/// Returned by the `/proxy_get_role_secret` API method.
#[derive(Deserialize)]
pub(crate) struct GetRoleSecret {
pub(crate) role_secret: Box<str>,
pub(crate) allowed_ips: Option<Vec<IpPattern>>,
pub(crate) project_id: Option<ProjectIdInt>,
}
/// Response which holds client's auth secret, e.g. [`crate::scram::ServerSecret`].
/// Returned by the `/get_endpoint_access_control` API method.
#[derive(Deserialize)]
@@ -231,6 +240,13 @@ pub(crate) struct GetEndpointAccessControl {
pub(crate) allowed_vpc_endpoint_ids: Option<Vec<EndpointIdInt>>,
}
// Manually implement debug to omit sensitive info.
impl fmt::Debug for GetRoleSecret {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("GetRoleSecret").finish_non_exhaustive()
}
}
/// Response which holds compute node's `host:port` pair.
/// Returned by the `/proxy_wake_compute` API method.
#[derive(Debug, Deserialize)]
@@ -461,18 +477,18 @@ mod tests {
let json = json!({
"role_secret": "secret",
});
serde_json::from_str::<GetEndpointAccessControl>(&json.to_string())?;
serde_json::from_str::<GetRoleSecret>(&json.to_string())?;
let json = json!({
"role_secret": "secret",
"allowed_ips": ["8.8.8.8"],
});
serde_json::from_str::<GetEndpointAccessControl>(&json.to_string())?;
serde_json::from_str::<GetRoleSecret>(&json.to_string())?;
let json = json!({
"role_secret": "secret",
"allowed_ips": ["8.8.8.8"],
"project_id": "project",
});
serde_json::from_str::<GetEndpointAccessControl>(&json.to_string())?;
serde_json::from_str::<GetRoleSecret>(&json.to_string())?;
Ok(())
}

View File

@@ -272,36 +272,32 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin>(
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Client);
let do_handshake = handshake(ctx, stream, mode.handshake_tls(tls), record_handshake_error);
let (mut stream, params) = match tokio::time::timeout(config.handshake_timeout, do_handshake)
.await??
{
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let session_id = ctx.session_id();
let peer_ip = ctx.peer_addr();
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?session_id);
cancel_span.follows_from(tracing::Span::current());
async move {
drop(
cancellation_handler_clone
.cancel_session(
cancel_key_data,
session_id,
peer_ip,
config.authentication_config.ip_allowlist_check_enabled,
)
.instrument(cancel_span)
.await,
);
}
});
let (mut stream, params) =
match tokio::time::timeout(config.handshake_timeout, do_handshake).await?? {
HandshakeData::Startup(stream, params) => (stream, params),
HandshakeData::Cancel(cancel_key_data) => {
// spawn a task to cancel the session, but don't wait for it
cancellations.spawn({
let cancellation_handler_clone = Arc::clone(&cancellation_handler);
let session_id = ctx.session_id();
let peer_ip = ctx.peer_addr();
async move {
drop(
cancellation_handler_clone
.cancel_session(
cancel_key_data,
session_id,
peer_ip,
config.authentication_config.ip_allowlist_check_enabled,
)
.await,
);
}
});
return Ok(None);
}
};
return Ok(None);
}
};
drop(pause);
ctx.set_db_options(params.clone());

View File

@@ -13,7 +13,6 @@ use crate::cache::project_info::ProjectInfoCache;
use crate::cancellation::{CancelMap, CancellationHandler};
use crate::intern::{ProjectIdInt, RoleNameInt};
use crate::metrics::{Metrics, RedisErrors, RedisEventsCount};
use tracing::Instrument;
const CPLANE_CHANNEL_NAME: &str = "neondb-proxy-ws-updates";
pub(crate) const PROXY_CHANNEL_NAME: &str = "neondb-proxy-to-proxy-updates";
@@ -144,8 +143,6 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
let peer_addr = cancel_session
.peer_addr
.unwrap_or(std::net::IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED));
let cancel_span = tracing::span!(parent: None, tracing::Level::INFO, "cancel_session", session_id = ?cancel_session.session_id);
cancel_span.follows_from(tracing::Span::current());
// This instance of cancellation_handler doesn't have a RedisPublisherClient so it can't publish the message.
match self
.cancellation_handler
@@ -155,7 +152,6 @@ impl<C: ProjectInfoCache + Send + Sync + 'static> MessageHandler<C> {
peer_addr,
cancel_session.peer_addr.is_some(),
)
.instrument(cancel_span)
.await
{
Ok(()) => {}

View File

@@ -340,7 +340,7 @@ impl PoolingBackend {
debug!("setting up backend session state");
// initiates the auth session
if let Err(e) = client.execute("select auth.init()", &[]).await {
if let Err(e) = client.batch_execute("select auth.init();").await {
discard.discard();
return Err(e.into());
}
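Presumably the switch matters because `execute` with a parameter slice goes through the extended query protocol (a prepared statement plus bind), while `batch_execute` uses the simple query protocol, which fits this series' removal of per-client statement caching. A tokio-postgres-style sketch of the two calls (assuming an already-connected `client`; the forked `postgres_client` API may differ):

```rust
async fn init_session(client: &tokio_postgres::Client) -> Result<(), tokio_postgres::Error> {
    // Extended protocol: prepares a statement and binds (zero) parameters.
    // client.execute("select auth.init()", &[]).await?;

    // Simple protocol: no prepared statement; several statements can be batched.
    client.batch_execute("select auth.init();").await
}
```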

View File

@@ -11,7 +11,7 @@ use smallvec::SmallVec;
use tokio::net::TcpStream;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use tracing::{error, info, info_span, warn, Instrument};
use tracing::{debug, error, info, info_span, Instrument};
#[cfg(test)]
use {
super::conn_pool_lib::GlobalConnPoolOptions,
@@ -125,13 +125,10 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
match message {
Some(Ok(AsyncMessage::Notice(notice))) => {
info!(%session_id, "notice: {}", notice);
debug!(%session_id, "notice: {}", notice);
}
Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
}
Some(Ok(_)) => {
warn!(%session_id, "unknown message");
debug!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
}
Some(Err(e)) => {
error!(%session_id, "connection error: {}", e);

View File

@@ -1,5 +1,5 @@
use postgres_client::types::{Kind, Type};
use postgres_client::Row;
use postgres_client::{Column, Row};
use serde_json::{Map, Value};
//
@@ -77,14 +77,14 @@ pub(crate) enum JsonConversionError {
//
pub(crate) fn pg_text_row_to_json(
row: &Row,
columns: &[Type],
columns: &[Column],
c_types: &[Type],
raw_output: bool,
array_mode: bool,
) -> Result<Value, JsonConversionError> {
let iter = row
.columns()
let iter = columns
.iter()
.zip(columns)
.zip(c_types)
.enumerate()
.map(|(i, (column, typ))| {
let name = column.name();

View File

@@ -23,14 +23,13 @@ use jose_jwk::jose_b64::base64ct::{Base64UrlUnpadded, Encoding};
use p256::ecdsa::{Signature, SigningKey};
use parking_lot::RwLock;
use postgres_client::tls::NoTlsStream;
use postgres_client::types::ToSql;
use postgres_client::AsyncMessage;
use serde_json::value::RawValue;
use signature::Signer;
use tokio::net::TcpStream;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, info_span, warn, Instrument};
use tracing::{debug, error, info, info_span, Instrument};
use super::backend::HttpConnError;
use super::conn_pool_lib::{
@@ -229,13 +228,10 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
match message {
Some(Ok(AsyncMessage::Notice(notice))) => {
info!(%session_id, "notice: {}", notice);
debug!(%session_id, "notice: {}", notice);
}
Some(Ok(AsyncMessage::Notification(notif))) => {
warn!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
}
Some(Ok(_)) => {
warn!(%session_id, "unknown message");
debug!(%session_id, pid = notif.process_id(), channel = notif.channel(), "notification received");
}
Some(Err(e)) => {
error!(%session_id, "connection error: {}", e);
@@ -287,12 +283,11 @@ impl ClientInnerCommon<postgres_client::Client> {
let token = resign_jwt(&local_data.key, payload, local_data.jti)?;
// initiates the auth session
self.inner.batch_execute("discard all").await?;
// the token contains only `[a-zA-Z0-9_.-]+`, so it cannot escape the string literal formatting.
self.inner
.execute(
"select auth.jwt_session_init($1)",
&[&&*token as &(dyn ToSql + Sync)],
)
.batch_execute(&format!(
"discard all; select auth.jwt_session_init('{token}');"
))
.await?;
let pid = self.inner.get_process_id();
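The safety argument in the comment above rests on the JWT alphabet: a compact JWT is three base64url segments joined by dots, so a single quote can never appear in it. A defensive sketch of that invariant (illustrative only, not the proxy's code):

```rust
/// Compact JWTs contain only base64url characters plus '.' separators,
/// so they cannot break out of a single-quoted SQL literal.
fn jwt_sql_literal(token: &str) -> Option<String> {
    let ok = token
        .bytes()
        .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'-' | b'_' | b'.'));
    ok.then(|| format!("discard all; select auth.jwt_session_init('{token}');"))
}

fn main() {
    assert!(jwt_sql_literal("eyJhbGciOiJFUzI1NiJ9.e30.sig").is_some());
    assert!(jwt_sql_literal("'; drop table users; --").is_none());
}
```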

View File

@@ -797,7 +797,13 @@ impl QueryData {
let cancel_token = inner.cancel_token();
let res = match select(
pin!(query_to_json(config, &*inner, self, &mut 0, parsed_headers)),
pin!(query_to_json(
config,
&mut *inner,
self,
&mut 0,
parsed_headers
)),
pin!(cancel.cancelled()),
)
.await
@@ -881,7 +887,7 @@ impl BatchQueryData {
builder = builder.deferrable(true);
}
let transaction = builder.start().await.inspect_err(|_| {
let mut transaction = builder.start().await.inspect_err(|_| {
// if we cannot start a transaction, we should return immediately
// and not return to the pool. connection is clearly broken
discard.discard();
@@ -890,7 +896,7 @@ impl BatchQueryData {
let json_output = match query_batch(
config,
cancel.child_token(),
&transaction,
&mut transaction,
self,
parsed_headers,
)
@@ -934,7 +940,7 @@ impl BatchQueryData {
async fn query_batch(
config: &'static HttpConfig,
cancel: CancellationToken,
transaction: &Transaction<'_>,
transaction: &mut Transaction<'_>,
queries: BatchQueryData,
parsed_headers: HttpHeaders,
) -> Result<String, SqlOverHttpError> {
@@ -972,7 +978,7 @@ async fn query_batch(
async fn query_to_json<T: GenericClient>(
config: &'static HttpConfig,
client: &T,
client: &mut T,
data: QueryData,
current_size: &mut usize,
parsed_headers: HttpHeaders,
@@ -1027,7 +1033,7 @@ async fn query_to_json<T: GenericClient>(
let columns_len = row_stream.columns().len();
let mut fields = Vec::with_capacity(columns_len);
let mut columns = Vec::with_capacity(columns_len);
let mut c_types = Vec::with_capacity(columns_len);
for c in row_stream.columns() {
fields.push(json!({
@@ -1039,7 +1045,7 @@ async fn query_to_json<T: GenericClient>(
"dataTypeModifier": c.type_modifier(),
"format": "text",
}));
columns.push(client.get_type(c.type_oid()).await?);
c_types.push(client.get_type(c.type_oid()).await?);
}
let array_mode = data.array_mode.unwrap_or(parsed_headers.default_array_mode);
@@ -1047,7 +1053,15 @@ async fn query_to_json<T: GenericClient>(
// convert rows to JSON
let rows = rows
.iter()
.map(|row| pg_text_row_to_json(row, &columns, parsed_headers.raw_output, array_mode))
.map(|row| {
pg_text_row_to_json(
row,
row_stream.columns(),
&c_types,
parsed_headers.raw_output,
array_mode,
)
})
.collect::<Result<Vec<_>, _>>()?;
// Resulting JSON format is based on the format of node-postgres result.

View File

@@ -83,20 +83,14 @@ impl Env {
node_id: NodeId,
ttid: TenantTimelineId,
) -> anyhow::Result<Arc<Timeline>> {
let conf = Arc::new(self.make_conf(node_id));
let conf = self.make_conf(node_id);
let timeline_dir = get_timeline_dir(&conf, &ttid);
let remote_path = remote_timeline_path(&ttid)?;
let safekeeper = self.make_safekeeper(node_id, ttid).await?;
let shared_state = SharedState::new(StateSK::Loaded(safekeeper));
let timeline = Timeline::new(
ttid,
&timeline_dir,
&remote_path,
shared_state,
conf.clone(),
);
let timeline = Timeline::new(ttid, &timeline_dir, &remote_path, shared_state);
timeline.bootstrap(
&mut timeline.write_shared_state().await,
&conf,

View File

@@ -338,7 +338,7 @@ async fn main() -> anyhow::Result<()> {
}
};
let conf = Arc::new(SafeKeeperConf {
let conf = SafeKeeperConf {
workdir,
my_id: id,
listen_pg_addr: args.listen_pg,
@@ -368,7 +368,7 @@ async fn main() -> anyhow::Result<()> {
control_file_save_interval: args.control_file_save_interval,
partial_backup_concurrency: args.partial_backup_concurrency,
eviction_min_resident: args.eviction_min_resident,
});
};
// initialize sentry if SENTRY_DSN is provided
let _sentry_guard = init_sentry(
@@ -382,7 +382,7 @@ async fn main() -> anyhow::Result<()> {
/// complete, e.g. panicked, inner is error produced by task itself.
type JoinTaskRes = Result<anyhow::Result<()>, JoinError>;
async fn start_safekeeper(conf: Arc<SafeKeeperConf>) -> Result<()> {
async fn start_safekeeper(conf: SafeKeeperConf) -> Result<()> {
// fsync the datadir to make sure we have a consistent state on disk.
if !conf.no_sync {
let dfd = File::open(&conf.workdir).context("open datadir for syncfs")?;
@@ -428,11 +428,9 @@ async fn start_safekeeper(conf: Arc<SafeKeeperConf>) -> Result<()> {
e
})?;
let global_timelines = Arc::new(GlobalTimelines::new(conf.clone()));
// Register metrics collector for active timelines. It's important to do this
// after daemonizing, otherwise the process collector will be upset.
let timeline_collector = safekeeper::metrics::TimelineCollector::new(global_timelines.clone());
let timeline_collector = safekeeper::metrics::TimelineCollector::new();
metrics::register_internal(Box::new(timeline_collector))?;
wal_backup::init_remote_storage(&conf).await;
@@ -449,8 +447,9 @@ async fn start_safekeeper(conf: Arc<SafeKeeperConf>) -> Result<()> {
.then(|| Handle::try_current().expect("no runtime in main"));
// Load all timelines from disk to memory.
global_timelines.init().await?;
GlobalTimelines::init(conf.clone()).await?;
let conf_ = conf.clone();
// Run everything in current thread rt, if asked.
if conf.current_thread_runtime {
info!("running in current thread runtime");
@@ -460,16 +459,14 @@ async fn start_safekeeper(conf: Arc<SafeKeeperConf>) -> Result<()> {
.as_ref()
.unwrap_or_else(|| WAL_SERVICE_RUNTIME.handle())
.spawn(wal_service::task_main(
conf.clone(),
conf_,
pg_listener,
Scope::SafekeeperData,
global_timelines.clone(),
))
// wrap with task name for error reporting
.map(|res| ("WAL service main".to_owned(), res));
tasks_handles.push(Box::pin(wal_service_handle));
let global_timelines_ = global_timelines.clone();
let timeline_housekeeping_handle = current_thread_rt
.as_ref()
.unwrap_or_else(|| WAL_SERVICE_RUNTIME.handle())
@@ -477,45 +474,40 @@ async fn start_safekeeper(conf: Arc<SafeKeeperConf>) -> Result<()> {
const TOMBSTONE_TTL: Duration = Duration::from_secs(3600 * 24);
loop {
tokio::time::sleep(TOMBSTONE_TTL).await;
global_timelines_.housekeeping(&TOMBSTONE_TTL);
GlobalTimelines::housekeeping(&TOMBSTONE_TTL);
}
})
.map(|res| ("Timeline map housekeeping".to_owned(), res));
tasks_handles.push(Box::pin(timeline_housekeeping_handle));
if let Some(pg_listener_tenant_only) = pg_listener_tenant_only {
let conf_ = conf.clone();
let wal_service_handle = current_thread_rt
.as_ref()
.unwrap_or_else(|| WAL_SERVICE_RUNTIME.handle())
.spawn(wal_service::task_main(
conf.clone(),
conf_,
pg_listener_tenant_only,
Scope::Tenant,
global_timelines.clone(),
))
// wrap with task name for error reporting
.map(|res| ("WAL service tenant only main".to_owned(), res));
tasks_handles.push(Box::pin(wal_service_handle));
}
let conf_ = conf.clone();
let http_handle = current_thread_rt
.as_ref()
.unwrap_or_else(|| HTTP_RUNTIME.handle())
.spawn(http::task_main(
conf.clone(),
http_listener,
global_timelines.clone(),
))
.spawn(http::task_main(conf_, http_listener))
.map(|res| ("HTTP service main".to_owned(), res));
tasks_handles.push(Box::pin(http_handle));
let conf_ = conf.clone();
let broker_task_handle = current_thread_rt
.as_ref()
.unwrap_or_else(|| BROKER_RUNTIME.handle())
.spawn(
broker::task_main(conf.clone(), global_timelines.clone())
.instrument(info_span!("broker")),
)
.spawn(broker::task_main(conf_).instrument(info_span!("broker")))
.map(|res| ("broker main".to_owned(), res));
tasks_handles.push(Box::pin(broker_task_handle));

View File

@@ -39,17 +39,14 @@ const RETRY_INTERVAL_MSEC: u64 = 1000;
const PUSH_INTERVAL_MSEC: u64 = 1000;
/// Push once in a while data about all active timelines to the broker.
async fn push_loop(
conf: Arc<SafeKeeperConf>,
global_timelines: Arc<GlobalTimelines>,
) -> anyhow::Result<()> {
async fn push_loop(conf: SafeKeeperConf) -> anyhow::Result<()> {
if conf.disable_periodic_broker_push {
info!("broker push_loop is disabled, doing nothing...");
futures::future::pending::<()>().await; // sleep forever
return Ok(());
}
let active_timelines_set = global_timelines.get_global_broker_active_set();
let active_timelines_set = GlobalTimelines::get_global_broker_active_set();
let mut client =
storage_broker::connect(conf.broker_endpoint.clone(), conf.broker_keepalive_interval)?;
@@ -90,13 +87,8 @@ async fn push_loop(
/// Subscribe and fetch all the interesting data from the broker.
#[instrument(name = "broker_pull", skip_all)]
async fn pull_loop(
conf: Arc<SafeKeeperConf>,
global_timelines: Arc<GlobalTimelines>,
stats: Arc<BrokerStats>,
) -> Result<()> {
let mut client =
storage_broker::connect(conf.broker_endpoint.clone(), conf.broker_keepalive_interval)?;
async fn pull_loop(conf: SafeKeeperConf, stats: Arc<BrokerStats>) -> Result<()> {
let mut client = storage_broker::connect(conf.broker_endpoint, conf.broker_keepalive_interval)?;
// TODO: subscribe only to local timelines instead of all
let request = SubscribeSafekeeperInfoRequest {
@@ -121,7 +113,7 @@ async fn pull_loop(
.as_ref()
.ok_or_else(|| anyhow!("missing tenant_timeline_id"))?;
let ttid = parse_proto_ttid(proto_ttid)?;
if let Ok(tli) = global_timelines.get(ttid) {
if let Ok(tli) = GlobalTimelines::get(ttid) {
// Note that we also receive *our own* info. That's
// important, as it is used as an indication of a live
// connection to the broker.
@@ -143,11 +135,7 @@ async fn pull_loop(
/// Process incoming discover requests. This is done in a separate task to avoid
/// interfering with the normal pull/push loops.
async fn discover_loop(
conf: Arc<SafeKeeperConf>,
global_timelines: Arc<GlobalTimelines>,
stats: Arc<BrokerStats>,
) -> Result<()> {
async fn discover_loop(conf: SafeKeeperConf, stats: Arc<BrokerStats>) -> Result<()> {
let mut client =
storage_broker::connect(conf.broker_endpoint.clone(), conf.broker_keepalive_interval)?;
@@ -183,7 +171,7 @@ async fn discover_loop(
.as_ref()
.ok_or_else(|| anyhow!("missing tenant_timeline_id"))?;
let ttid = parse_proto_ttid(proto_ttid)?;
if let Ok(tli) = global_timelines.get(ttid) {
if let Ok(tli) = GlobalTimelines::get(ttid) {
// we received a discovery request for a timeline we know about
discover_counter.inc();
@@ -222,10 +210,7 @@ async fn discover_loop(
bail!("end of stream");
}
pub async fn task_main(
conf: Arc<SafeKeeperConf>,
global_timelines: Arc<GlobalTimelines>,
) -> anyhow::Result<()> {
pub async fn task_main(conf: SafeKeeperConf) -> anyhow::Result<()> {
info!("started, broker endpoint {:?}", conf.broker_endpoint);
let mut ticker = tokio::time::interval(Duration::from_millis(RETRY_INTERVAL_MSEC));
@@ -276,13 +261,13 @@ pub async fn task_main(
},
_ = ticker.tick() => {
if push_handle.is_none() {
push_handle = Some(tokio::spawn(push_loop(conf.clone(), global_timelines.clone())));
push_handle = Some(tokio::spawn(push_loop(conf.clone())));
}
if pull_handle.is_none() {
pull_handle = Some(tokio::spawn(pull_loop(conf.clone(), global_timelines.clone(), stats.clone())));
pull_handle = Some(tokio::spawn(pull_loop(conf.clone(), stats.clone())));
}
if discover_handle.is_none() {
discover_handle = Some(tokio::spawn(discover_loop(conf.clone(), global_timelines.clone(), stats.clone())));
discover_handle = Some(tokio::spawn(discover_loop(conf.clone(), stats.clone())));
}
},
_ = &mut stats_task => {}

View File

@@ -1,7 +1,9 @@
use std::sync::Arc;
use anyhow::{bail, Result};
use camino::Utf8PathBuf;
use postgres_ffi::{MAX_SEND_SIZE, WAL_SEGMENT_SIZE};
use std::sync::Arc;
use tokio::{
fs::OpenOptions,
io::{AsyncSeekExt, AsyncWriteExt},
@@ -12,7 +14,7 @@ use utils::{id::TenantTimelineId, lsn::Lsn};
use crate::{
control_file::FileStorage,
state::TimelinePersistentState,
timeline::{TimelineError, WalResidentTimeline},
timeline::{Timeline, TimelineError, WalResidentTimeline},
timelines_global_map::{create_temp_timeline_dir, validate_temp_timeline},
wal_backup::copy_s3_segments,
wal_storage::{wal_file_paths, WalReader},
@@ -23,19 +25,16 @@ use crate::{
const MAX_BACKUP_LAG: u64 = 10 * WAL_SEGMENT_SIZE as u64;
pub struct Request {
pub source_ttid: TenantTimelineId,
pub source: Arc<Timeline>,
pub until_lsn: Lsn,
pub destination_ttid: TenantTimelineId,
}
pub async fn handle_request(
request: Request,
global_timelines: Arc<GlobalTimelines>,
) -> Result<()> {
pub async fn handle_request(request: Request) -> Result<()> {
// TODO: request.until_lsn MUST be a valid LSN, and we cannot check it :(
// If the LSN points to the middle of a WAL record, the timeline will end up in a "broken" state.
match global_timelines.get(request.destination_ttid) {
match GlobalTimelines::get(request.destination_ttid) {
// Timeline already exists. It would be good to check that this timeline is a copy
// of the source timeline, but it isn't obvious how to do that.
Ok(_) => return Ok(()),
@@ -47,10 +46,9 @@ pub async fn handle_request(
}
}
let source = global_timelines.get(request.source_ttid)?;
let source_tli = source.wal_residence_guard().await?;
let source_tli = request.source.wal_residence_guard().await?;
let conf = &global_timelines.get_global_config();
let conf = &GlobalTimelines::get_global_config();
let ttid = request.destination_ttid;
let (_tmp_dir, tli_dir_path) = create_temp_timeline_dir(conf, ttid).await?;
@@ -129,7 +127,7 @@ pub async fn handle_request(
copy_s3_segments(
wal_seg_size,
&request.source_ttid,
&request.source.ttid,
&request.destination_ttid,
first_segment,
first_ondisk_segment,
@@ -160,9 +158,7 @@ pub async fn handle_request(
// now we have a ready timeline in a temp directory
validate_temp_timeline(conf, request.destination_ttid, &tli_dir_path).await?;
global_timelines
.load_temp_timeline(request.destination_ttid, &tli_dir_path, true)
.await?;
GlobalTimelines::load_temp_timeline(request.destination_ttid, &tli_dir_path, true).await?;
Ok(())
}

View File

@@ -207,23 +207,23 @@ pub struct FileInfo {
}
/// Build debug dump response, using the provided [`Args`] filters.
pub async fn build(args: Args, global_timelines: Arc<GlobalTimelines>) -> Result<Response> {
pub async fn build(args: Args) -> Result<Response> {
let start_time = Utc::now();
let timelines_count = global_timelines.timelines_count();
let config = global_timelines.get_global_config();
let timelines_count = GlobalTimelines::timelines_count();
let config = GlobalTimelines::get_global_config();
let ptrs_snapshot = if args.tenant_id.is_some() && args.timeline_id.is_some() {
// If both tenant_id and timeline_id are specified, we can just get the
// timeline directly, without taking a snapshot of the whole list.
let ttid = TenantTimelineId::new(args.tenant_id.unwrap(), args.timeline_id.unwrap());
if let Ok(tli) = global_timelines.get(ttid) {
if let Ok(tli) = GlobalTimelines::get(ttid) {
vec![tli]
} else {
vec![]
}
} else {
// Otherwise, take a snapshot of the whole list.
global_timelines.get_all()
GlobalTimelines::get_all()
};
let mut timelines = Vec::new();
@@ -344,12 +344,12 @@ fn get_wal_last_modified(path: &Utf8Path) -> Result<Option<DateTime<Utc>>> {
/// Converts SafeKeeperConf to Config, filtering out the fields that are not
/// supposed to be exposed.
fn build_config(config: Arc<SafeKeeperConf>) -> Config {
fn build_config(config: SafeKeeperConf) -> Config {
Config {
id: config.my_id,
workdir: config.workdir.clone().into(),
listen_pg_addr: config.listen_pg_addr.clone(),
listen_http_addr: config.listen_http_addr.clone(),
workdir: config.workdir.into(),
listen_pg_addr: config.listen_pg_addr,
listen_http_addr: config.listen_http_addr,
no_sync: config.no_sync,
max_offloader_lag_bytes: config.max_offloader_lag_bytes,
wal_backup_enabled: config.wal_backup_enabled,

View File

@@ -33,7 +33,7 @@ use utils::{
/// Safekeeper handler of postgres commands
pub struct SafekeeperPostgresHandler {
pub conf: Arc<SafeKeeperConf>,
pub conf: SafeKeeperConf,
/// assigned application name
pub appname: Option<String>,
pub tenant_id: Option<TenantId>,
@@ -43,7 +43,6 @@ pub struct SafekeeperPostgresHandler {
pub protocol: Option<PostgresClientProtocol>,
/// Unique connection id is logged in spans for observability.
pub conn_id: ConnectionId,
pub global_timelines: Arc<GlobalTimelines>,
/// Auth scope allowed on the connections and public key used to check auth tokens. None if auth is not configured.
auth: Option<(Scope, Arc<JwtAuth>)>,
claims: Option<Claims>,
@@ -315,11 +314,10 @@ impl<IO: AsyncRead + AsyncWrite + Unpin + Send> postgres_backend::Handler<IO>
impl SafekeeperPostgresHandler {
pub fn new(
conf: Arc<SafeKeeperConf>,
conf: SafeKeeperConf,
conn_id: u32,
io_metrics: Option<TrafficMetrics>,
auth: Option<(Scope, Arc<JwtAuth>)>,
global_timelines: Arc<GlobalTimelines>,
) -> Self {
SafekeeperPostgresHandler {
conf,
@@ -333,7 +331,6 @@ impl SafekeeperPostgresHandler {
claims: None,
auth,
io_metrics,
global_timelines,
}
}
@@ -363,7 +360,7 @@ impl SafekeeperPostgresHandler {
pgb: &mut PostgresBackend<IO>,
) -> Result<(), QueryError> {
// Get timeline, handling "not found" error
let tli = match self.global_timelines.get(self.ttid) {
let tli = match GlobalTimelines::get(self.ttid) {
Ok(tli) => Ok(Some(tli)),
Err(TimelineError::NotFound(_)) => Ok(None),
Err(e) => Err(QueryError::Other(e.into())),
@@ -397,10 +394,7 @@ impl SafekeeperPostgresHandler {
&mut self,
pgb: &mut PostgresBackend<IO>,
) -> Result<(), QueryError> {
let tli = self
.global_timelines
.get(self.ttid)
.map_err(|e| QueryError::Other(e.into()))?;
let tli = GlobalTimelines::get(self.ttid).map_err(|e| QueryError::Other(e.into()))?;
let lsn = if self.is_walproposer_recovery() {
// walproposer should get all local WAL until flush_lsn

View File

@@ -3,16 +3,14 @@ pub mod routes;
pub use routes::make_router;
pub use safekeeper_api::models;
use std::sync::Arc;
use crate::{GlobalTimelines, SafeKeeperConf};
use crate::SafeKeeperConf;
pub async fn task_main(
conf: Arc<SafeKeeperConf>,
conf: SafeKeeperConf,
http_listener: std::net::TcpListener,
global_timelines: Arc<GlobalTimelines>,
) -> anyhow::Result<()> {
let router = make_router(conf, global_timelines)
let router = make_router(conf)
.build()
.map_err(|err| anyhow::anyhow!(err))?;
let service = utils::http::RouterService::new(router).unwrap();

View File

@@ -66,13 +66,6 @@ fn get_conf(request: &Request<Body>) -> &SafeKeeperConf {
.as_ref()
}
fn get_global_timelines(request: &Request<Body>) -> Arc<GlobalTimelines> {
request
.data::<Arc<GlobalTimelines>>()
.expect("unknown state type")
.clone()
}
/// Same as TermLsn, but serializes LSN using display serializer
/// in Postgres format, i.e. 0/FFFFFFFF. Used only for the API response.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
@@ -130,11 +123,9 @@ async fn tenant_delete_handler(mut request: Request<Body>) -> Result<Response<Bo
let only_local = parse_query_param(&request, "only_local")?.unwrap_or(false);
check_permission(&request, Some(tenant_id))?;
ensure_no_body(&mut request).await?;
let global_timelines = get_global_timelines(&request);
// FIXME: `delete_force_all_for_tenant` can return an error for multiple different reasons;
// using a blanket `InternalServerError` for all of them should be fixed once the types support it
let delete_info = global_timelines
.delete_force_all_for_tenant(&tenant_id, only_local)
let delete_info = GlobalTimelines::delete_force_all_for_tenant(&tenant_id, only_local)
.await
.map_err(ApiError::InternalServerError)?;
json_response(
@@ -165,9 +156,7 @@ async fn timeline_create_handler(mut request: Request<Body>) -> Result<Response<
.commit_lsn
.segment_lsn(server_info.wal_seg_size as usize)
});
let global_timelines = get_global_timelines(&request);
global_timelines
.create(ttid, server_info, request_data.commit_lsn, local_start_lsn)
GlobalTimelines::create(ttid, server_info, request_data.commit_lsn, local_start_lsn)
.await
.map_err(ApiError::InternalServerError)?;
@@ -178,9 +167,7 @@ async fn timeline_create_handler(mut request: Request<Body>) -> Result<Response<
/// Note: it is possible to do the same with debug_dump.
async fn timeline_list_handler(request: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permission(&request, None)?;
let global_timelines = get_global_timelines(&request);
let res: Vec<TenantTimelineId> = global_timelines
.get_all()
let res: Vec<TenantTimelineId> = GlobalTimelines::get_all()
.iter()
.map(|tli| tli.ttid)
.collect();
@@ -195,8 +182,7 @@ async fn timeline_status_handler(request: Request<Body>) -> Result<Response<Body
);
check_permission(&request, Some(ttid.tenant_id))?;
let global_timelines = get_global_timelines(&request);
let tli = global_timelines.get(ttid).map_err(ApiError::from)?;
let tli = GlobalTimelines::get(ttid).map_err(ApiError::from)?;
let (inmem, state) = tli.get_state().await;
let flush_lsn = tli.get_flush_lsn().await;
@@ -247,11 +233,9 @@ async fn timeline_delete_handler(mut request: Request<Body>) -> Result<Response<
let only_local = parse_query_param(&request, "only_local")?.unwrap_or(false);
check_permission(&request, Some(ttid.tenant_id))?;
ensure_no_body(&mut request).await?;
let global_timelines = get_global_timelines(&request);
// FIXME: `delete_force` can fail from both internal errors and bad requests. Add better
// error handling here when we're able to.
let resp = global_timelines
.delete(&ttid, only_local)
let resp = GlobalTimelines::delete(&ttid, only_local)
.await
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, resp)
@@ -263,9 +247,8 @@ async fn timeline_pull_handler(mut request: Request<Body>) -> Result<Response<Bo
let data: pull_timeline::Request = json_request(&mut request).await?;
let conf = get_conf(&request);
let global_timelines = get_global_timelines(&request);
let resp = pull_timeline::handle_request(data, conf.sk_auth_token.clone(), global_timelines)
let resp = pull_timeline::handle_request(data, conf.sk_auth_token.clone())
.await
.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, resp)
@@ -280,8 +263,7 @@ async fn timeline_snapshot_handler(request: Request<Body>) -> Result<Response<Bo
);
check_permission(&request, Some(ttid.tenant_id))?;
let global_timelines = get_global_timelines(&request);
let tli = global_timelines.get(ttid).map_err(ApiError::from)?;
let tli = GlobalTimelines::get(ttid).map_err(ApiError::from)?;
// To stream the body use wrap_stream which wants Stream of Result<Bytes>,
// so create the chan and write to it in another task.
@@ -311,19 +293,19 @@ async fn timeline_copy_handler(mut request: Request<Body>) -> Result<Response<Bo
check_permission(&request, None)?;
let request_data: TimelineCopyRequest = json_request(&mut request).await?;
let source_ttid = TenantTimelineId::new(
let ttid = TenantTimelineId::new(
parse_request_param(&request, "tenant_id")?,
parse_request_param(&request, "source_timeline_id")?,
);
let global_timelines = get_global_timelines(&request);
let source = GlobalTimelines::get(ttid)?;
copy_timeline::handle_request(copy_timeline::Request{
source_ttid,
source,
until_lsn: request_data.until_lsn,
destination_ttid: TenantTimelineId::new(source_ttid.tenant_id, request_data.target_timeline_id),
}, global_timelines)
.instrument(info_span!("copy_timeline", from=%source_ttid, to=%request_data.target_timeline_id, until_lsn=%request_data.until_lsn))
destination_ttid: TenantTimelineId::new(ttid.tenant_id, request_data.target_timeline_id),
})
.instrument(info_span!("copy_timeline", from=%ttid, to=%request_data.target_timeline_id, until_lsn=%request_data.until_lsn))
.await
.map_err(ApiError::InternalServerError)?;
@@ -340,8 +322,7 @@ async fn patch_control_file_handler(
parse_request_param(&request, "timeline_id")?,
);
let global_timelines = get_global_timelines(&request);
let tli = global_timelines.get(ttid).map_err(ApiError::from)?;
let tli = GlobalTimelines::get(ttid).map_err(ApiError::from)?;
let patch_request: patch_control_file::Request = json_request(&mut request).await?;
let response = patch_control_file::handle_request(tli, patch_request)
@@ -360,8 +341,7 @@ async fn timeline_checkpoint_handler(request: Request<Body>) -> Result<Response<
parse_request_param(&request, "timeline_id")?,
);
let global_timelines = get_global_timelines(&request);
let tli = global_timelines.get(ttid)?;
let tli = GlobalTimelines::get(ttid)?;
tli.write_shared_state()
.await
.sk
@@ -379,7 +359,6 @@ async fn timeline_digest_handler(request: Request<Body>) -> Result<Response<Body
);
check_permission(&request, Some(ttid.tenant_id))?;
let global_timelines = get_global_timelines(&request);
let from_lsn: Option<Lsn> = parse_query_param(&request, "from_lsn")?;
let until_lsn: Option<Lsn> = parse_query_param(&request, "until_lsn")?;
@@ -392,7 +371,7 @@ async fn timeline_digest_handler(request: Request<Body>) -> Result<Response<Body
)))?,
};
let tli = global_timelines.get(ttid).map_err(ApiError::from)?;
let tli = GlobalTimelines::get(ttid).map_err(ApiError::from)?;
let tli = tli
.wal_residence_guard()
.await
@@ -414,8 +393,7 @@ async fn timeline_backup_partial_reset(request: Request<Body>) -> Result<Respons
);
check_permission(&request, Some(ttid.tenant_id))?;
let global_timelines = get_global_timelines(&request);
let tli = global_timelines.get(ttid).map_err(ApiError::from)?;
let tli = GlobalTimelines::get(ttid).map_err(ApiError::from)?;
let response = tli
.backup_partial_reset()
@@ -437,8 +415,7 @@ async fn timeline_term_bump_handler(
let request_data: TimelineTermBumpRequest = json_request(&mut request).await?;
let global_timelines = get_global_timelines(&request);
let tli = global_timelines.get(ttid).map_err(ApiError::from)?;
let tli = GlobalTimelines::get(ttid).map_err(ApiError::from)?;
let response = tli
.term_bump(request_data.term)
.await
@@ -475,8 +452,7 @@ async fn record_safekeeper_info(mut request: Request<Body>) -> Result<Response<B
standby_horizon: sk_info.standby_horizon.0,
};
let global_timelines = get_global_timelines(&request);
let tli = global_timelines.get(ttid).map_err(ApiError::from)?;
let tli = GlobalTimelines::get(ttid).map_err(ApiError::from)?;
tli.record_safekeeper_info(proto_sk_info)
.await
.map_err(ApiError::InternalServerError)?;
@@ -530,8 +506,6 @@ async fn dump_debug_handler(mut request: Request<Body>) -> Result<Response<Body>
let dump_term_history = dump_term_history.unwrap_or(true);
let dump_wal_last_modified = dump_wal_last_modified.unwrap_or(dump_all);
let global_timelines = get_global_timelines(&request);
let args = debug_dump::Args {
dump_all,
dump_control_file,
@@ -543,7 +517,7 @@ async fn dump_debug_handler(mut request: Request<Body>) -> Result<Response<Body>
timeline_id,
};
let resp = debug_dump::build(args, global_timelines)
let resp = debug_dump::build(args)
.await
.map_err(ApiError::InternalServerError)?;
@@ -596,10 +570,7 @@ async fn dump_debug_handler(mut request: Request<Body>) -> Result<Response<Body>
}
/// Safekeeper http router.
pub fn make_router(
conf: Arc<SafeKeeperConf>,
global_timelines: Arc<GlobalTimelines>,
) -> RouterBuilder<hyper::Body, ApiError> {
pub fn make_router(conf: SafeKeeperConf) -> RouterBuilder<hyper::Body, ApiError> {
let mut router = endpoint::make_router();
if conf.http_auth.is_some() {
router = router.middleware(auth_middleware(|request| {
@@ -621,8 +592,7 @@ pub fn make_router(
// located nearby (/safekeeper/src/http/openapi_spec.yaml).
let auth = conf.http_auth.clone();
router
.data(conf)
.data(global_timelines)
.data(Arc::new(conf))
.data(auth)
.get("/metrics", |r| request_span(r, prometheus_metrics_handler))
.get("/profile/cpu", |r| request_span(r, profile_cpu_handler))

View File

@@ -11,6 +11,7 @@ use postgres_backend::QueryError;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncRead, AsyncWrite};
use tracing::*;
use utils::id::TenantTimelineId;
use crate::handler::SafekeeperPostgresHandler;
use crate::safekeeper::{AcceptorProposerMessage, AppendResponse, ServerInfo};
@@ -20,6 +21,7 @@ use crate::safekeeper::{
use crate::safekeeper::{Term, TermHistory, TermLsn};
use crate::state::TimelinePersistentState;
use crate::timeline::WalResidentTimeline;
use crate::GlobalTimelines;
use postgres_backend::PostgresBackend;
use postgres_ffi::encode_logical_message;
use postgres_ffi::WAL_SEGMENT_SIZE;
@@ -68,7 +70,7 @@ pub async fn handle_json_ctrl<IO: AsyncRead + AsyncWrite + Unpin>(
info!("JSON_CTRL request: {append_request:?}");
// need to init safekeeper state before AppendRequest
let tli = prepare_safekeeper(spg, append_request.pg_version).await?;
let tli = prepare_safekeeper(spg.ttid, append_request.pg_version).await?;
// if send_proposer_elected is true, we need to update local history
if append_request.send_proposer_elected {
@@ -97,22 +99,20 @@ pub async fn handle_json_ctrl<IO: AsyncRead + AsyncWrite + Unpin>(
/// Prepare safekeeper to process append requests without crashes,
/// by sending ProposerGreeting with default server.wal_seg_size.
async fn prepare_safekeeper(
spg: &SafekeeperPostgresHandler,
ttid: TenantTimelineId,
pg_version: u32,
) -> anyhow::Result<WalResidentTimeline> {
let tli = spg
.global_timelines
.create(
spg.ttid,
ServerInfo {
pg_version,
wal_seg_size: WAL_SEGMENT_SIZE as u32,
system_id: 0,
},
Lsn::INVALID,
Lsn::INVALID,
)
.await?;
let tli = GlobalTimelines::create(
ttid,
ServerInfo {
pg_version,
wal_seg_size: WAL_SEGMENT_SIZE as u32,
system_id: 0,
},
Lsn::INVALID,
Lsn::INVALID,
)
.await?;
tli.wal_residence_guard().await
}

View File

@@ -455,7 +455,6 @@ pub struct FullTimelineInfo {
/// Collects metrics for all active timelines.
pub struct TimelineCollector {
global_timelines: Arc<GlobalTimelines>,
descs: Vec<Desc>,
commit_lsn: GenericGaugeVec<AtomicU64>,
backup_lsn: GenericGaugeVec<AtomicU64>,
@@ -479,8 +478,14 @@ pub struct TimelineCollector {
active_timelines_count: IntGauge,
}
impl Default for TimelineCollector {
fn default() -> Self {
Self::new()
}
}
impl TimelineCollector {
pub fn new(global_timelines: Arc<GlobalTimelines>) -> TimelineCollector {
pub fn new() -> TimelineCollector {
let mut descs = Vec::new();
let commit_lsn = GenericGaugeVec::new(
@@ -671,7 +676,6 @@ impl TimelineCollector {
descs.extend(active_timelines_count.desc().into_iter().cloned());
TimelineCollector {
global_timelines,
descs,
commit_lsn,
backup_lsn,
@@ -724,18 +728,17 @@ impl Collector for TimelineCollector {
self.written_wal_seconds.reset();
self.flushed_wal_seconds.reset();
let timelines_count = self.global_timelines.get_all().len();
let timelines_count = GlobalTimelines::get_all().len();
let mut active_timelines_count = 0;
// Prometheus Collector is sync, and data is stored under async lock. To
// bridge the gap with a crutch, collect data in spawned thread with
// local tokio runtime.
let global_timelines = self.global_timelines.clone();
let infos = std::thread::spawn(|| {
let rt = tokio::runtime::Builder::new_current_thread()
.build()
.expect("failed to create rt");
rt.block_on(collect_timeline_metrics(global_timelines))
rt.block_on(collect_timeline_metrics())
})
.join()
.expect("collect_timeline_metrics thread panicked");
@@ -854,9 +857,9 @@ impl Collector for TimelineCollector {
}
}
async fn collect_timeline_metrics(global_timelines: Arc<GlobalTimelines>) -> Vec<FullTimelineInfo> {
async fn collect_timeline_metrics() -> Vec<FullTimelineInfo> {
let mut res = vec![];
let active_timelines = global_timelines.get_global_broker_active_set().get_all();
let active_timelines = GlobalTimelines::get_global_broker_active_set().get_all();
for tli in active_timelines {
if let Some(info) = tli.info_for_metrics().await {
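
The comment in this hunk spells out the bridge: Prometheus collection is synchronous, while the timeline data lives behind async locks, so the collector spawns a thread, builds a current-thread tokio runtime there, and blocks on the async gathering. A small sketch of that pattern, with gather_infos as a hypothetical stand-in for collect_timeline_metrics:

// Sync-to-async bridge: a synchronous collect() path blocks on async work
// by running it on a throwaway current-thread runtime in a helper thread.
use std::sync::Arc;
use tokio::sync::Mutex;

async fn gather_infos(data: Arc<Mutex<Vec<u64>>>) -> Vec<u64> {
    // The real code reads per-timeline state held under async locks.
    data.lock().await.clone()
}

fn collect_sync(data: Arc<Mutex<Vec<u64>>>) -> Vec<u64> {
    std::thread::spawn(move || {
        let rt = tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("failed to create runtime");
        rt.block_on(gather_infos(data))
    })
    .join()
    .expect("metrics thread panicked")
}

fn main() {
    let data = Arc::new(Mutex::new(vec![1, 2, 3]));
    let infos = collect_sync(data);
    println!("collected {} infos", infos.len());
}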

View File

@@ -409,9 +409,8 @@ pub struct DebugDumpResponse {
pub async fn handle_request(
request: Request,
sk_auth_token: Option<SecretString>,
global_timelines: Arc<GlobalTimelines>,
) -> Result<Response> {
let existing_tli = global_timelines.get(TenantTimelineId::new(
let existing_tli = GlobalTimelines::get(TenantTimelineId::new(
request.tenant_id,
request.timeline_id,
));
@@ -454,14 +453,13 @@ pub async fn handle_request(
assert!(status.tenant_id == request.tenant_id);
assert!(status.timeline_id == request.timeline_id);
pull_timeline(status, safekeeper_host, sk_auth_token, global_timelines).await
pull_timeline(status, safekeeper_host, sk_auth_token).await
}
async fn pull_timeline(
status: TimelineStatus,
host: String,
sk_auth_token: Option<SecretString>,
global_timelines: Arc<GlobalTimelines>,
) -> Result<Response> {
let ttid = TenantTimelineId::new(status.tenant_id, status.timeline_id);
info!(
@@ -474,7 +472,7 @@ async fn pull_timeline(
status.acceptor_state.epoch
);
let conf = &global_timelines.get_global_config();
let conf = &GlobalTimelines::get_global_config();
let (_tmp_dir, tli_dir_path) = create_temp_timeline_dir(conf, ttid).await?;
@@ -533,9 +531,7 @@ async fn pull_timeline(
assert!(status.commit_lsn <= status.flush_lsn);
// Finally, load the timeline.
let _tli = global_timelines
.load_temp_timeline(ttid, &tli_dir_path, false)
.await?;
let _tli = GlobalTimelines::load_temp_timeline(ttid, &tli_dir_path, false).await?;
Ok(Response {
safekeeper_host: host,
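
pull_timeline stages the fetched timeline under a temporary directory (create_temp_timeline_dir) and only then asks the registry to load it (load_temp_timeline). A rough sketch of that stage-then-install idea, assuming plain files and a same-filesystem rename; the paths, file names, and helpers here are illustrative, not the safekeeper's.

// Stage-then-install: write everything into a temp dir first, then move it
// into place in one step so a partially pulled timeline is never visible.
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};

fn stage_timeline(tmp_dir: &Path, segments: &[(&str, &[u8])]) -> std::io::Result<()> {
    fs::create_dir_all(tmp_dir)?;
    for (name, bytes) in segments {
        let mut f = fs::File::create(tmp_dir.join(name))?;
        f.write_all(bytes)?;
        f.sync_all()?; // make sure the data hits disk before we "install"
    }
    Ok(())
}

fn install_timeline(tmp_dir: &Path, final_dir: &Path) -> std::io::Result<()> {
    // A single rename makes the whole directory appear at once.
    fs::rename(tmp_dir, final_dir)
}

fn main() -> std::io::Result<()> {
    let tmp: PathBuf = std::env::temp_dir().join("timeline-incoming");
    let dst: PathBuf = std::env::temp_dir().join("timeline-0001");
    let _ = fs::remove_dir_all(&tmp);
    let _ = fs::remove_dir_all(&dst);
    stage_timeline(&tmp, &[("000000010000000000000001", &b"wal bytes"[..])])?;
    install_timeline(&tmp, &dst)?;
    println!("installed at {}", dst.display());
    Ok(())
}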

View File

@@ -267,7 +267,6 @@ impl SafekeeperPostgresHandler {
pgb_reader: &mut pgb_reader,
peer_addr,
acceptor_handle: &mut acceptor_handle,
global_timelines: self.global_timelines.clone(),
};
// Read first message and create timeline if needed.
@@ -332,7 +331,6 @@ struct NetworkReader<'a, IO> {
// WalAcceptor is spawned when we learn server info from walproposer and
// create timeline; handle is put here.
acceptor_handle: &'a mut Option<JoinHandle<anyhow::Result<()>>>,
global_timelines: Arc<GlobalTimelines>,
}
impl<'a, IO: AsyncRead + AsyncWrite + Unpin> NetworkReader<'a, IO> {
@@ -352,11 +350,10 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin> NetworkReader<'a, IO> {
system_id: greeting.system_id,
wal_seg_size: greeting.wal_seg_size,
};
let tli = self
.global_timelines
.create(self.ttid, server_info, Lsn::INVALID, Lsn::INVALID)
.await
.context("create timeline")?;
let tli =
GlobalTimelines::create(self.ttid, server_info, Lsn::INVALID, Lsn::INVALID)
.await
.context("create timeline")?;
tli.wal_residence_guard().await?
}
_ => {

View File

@@ -10,6 +10,7 @@ use crate::timeline::WalResidentTimeline;
use crate::wal_reader_stream::WalReaderStreamBuilder;
use crate::wal_service::ConnectionId;
use crate::wal_storage::WalReader;
use crate::GlobalTimelines;
use anyhow::{bail, Context as AnyhowContext};
use bytes::Bytes;
use futures::future::Either;
@@ -399,10 +400,7 @@ impl SafekeeperPostgresHandler {
start_pos: Lsn,
term: Option<Term>,
) -> Result<(), QueryError> {
let tli = self
.global_timelines
.get(self.ttid)
.map_err(|e| QueryError::Other(e.into()))?;
let tli = GlobalTimelines::get(self.ttid).map_err(|e| QueryError::Other(e.into()))?;
let residence_guard = tli.wal_residence_guard().await?;
if let Err(end) = self

Some files were not shown because too many files have changed in this diff.