Compare commits

..

2 Commits

Author SHA1 Message Date
Conrad Ludgate
6d2bbffdab only for console 2023-12-15 12:28:50 +00:00
Conrad Ludgate
7151bcc175 proxy console force http2 2023-12-15 12:26:51 +00:00
294 changed files with 8520 additions and 25211 deletions

View File

@@ -1,2 +0,0 @@
[profile.default]
slow-timeout = { period = "20s", terminate-after = 3 }

View File

@@ -11,7 +11,7 @@ on:
# │ │ ┌───────────── day of the month (1 - 31)
# │ │ │ ┌───────────── month (1 - 12 or JAN-DEC)
# │ │ │ │ ┌───────────── day of the week (0 - 6 or SUN-SAT)
- cron: '0 3 * * *' # run once a day, timezone is utc
- cron: '0 3 * * *' # run once a day, timezone is utc
workflow_dispatch: # adds ability to run this manually
inputs:
@@ -23,21 +23,6 @@ on:
type: boolean
description: 'Publish perf report. If not set, the report will be published only for the main branch'
required: false
collect_olap_explain:
type: boolean
description: 'Collect EXPLAIN ANALYZE for OLAP queries. If not set, EXPLAIN ANALYZE will not be collected'
required: false
default: false
collect_pg_stat_statements:
type: boolean
description: 'Collect pg_stat_statements for OLAP queries. If not set, pg_stat_statements will not be collected'
required: false
default: false
run_AWS_RDS_AND_AURORA:
type: boolean
description: 'AWS-RDS and AWS-AURORA normally only run on Saturday. Set this to true to run them on every workflow_dispatch'
required: false
default: false
defaults:
run:
@@ -128,8 +113,6 @@ jobs:
# - neon-captest-reuse: Reusing existing project
# - rds-aurora: Aurora Postgres Serverless v2 with autoscaling from 0.5 to 2 ACUs
# - rds-postgres: RDS Postgres db.m5.large instance (2 vCPU, 8 GiB) with gp3 EBS storage
env:
RUN_AWS_RDS_AND_AURORA: ${{ github.event.inputs.run_AWS_RDS_AND_AURORA || 'false' }}
runs-on: ubuntu-latest
outputs:
pgbench-compare-matrix: ${{ steps.pgbench-compare-matrix.outputs.matrix }}
@@ -169,7 +152,7 @@ jobs:
]
}'
if [ "$(date +%A)" = "Saturday" ] || [ ${RUN_AWS_RDS_AND_AURORA} = "true" ]; then
if [ "$(date +%A)" = "Saturday" ]; then
matrix=$(echo "$matrix" | jq '.include += [{ "platform": "rds-postgres" },
{ "platform": "rds-aurora" }]')
fi
@@ -188,9 +171,9 @@ jobs:
]
}'
if [ "$(date +%A)" = "Saturday" ] || [ ${RUN_AWS_RDS_AND_AURORA} = "true" ]; then
if [ "$(date +%A)" = "Saturday" ]; then
matrix=$(echo "$matrix" | jq '.include += [{ "platform": "rds-postgres", "scale": "10" },
{ "platform": "rds-aurora", "scale": "10" }]')
{ "platform": "rds-aurora", "scale": "10" }]')
fi
echo "matrix=$(echo "$matrix" | jq --compact-output '.')" >> $GITHUB_OUTPUT
@@ -354,8 +337,6 @@ jobs:
POSTGRES_DISTRIB_DIR: /tmp/neon/pg_install
DEFAULT_PG_VERSION: 14
TEST_OUTPUT: /tmp/test_output
TEST_OLAP_COLLECT_EXPLAIN: ${{ github.event.inputs.collect_olap_explain }}
TEST_OLAP_COLLECT_PG_STAT_STATEMENTS: ${{ github.event.inputs.collect_pg_stat_statements }}
BUILD_TYPE: remote
SAVE_PERF_REPORT: ${{ github.event.inputs.save_perf_report || ( github.ref_name == 'main' ) }}
PLATFORM: ${{ matrix.platform }}
@@ -418,8 +399,6 @@ jobs:
env:
VIP_VAP_ACCESS_TOKEN: "${{ secrets.VIP_VAP_ACCESS_TOKEN }}"
PERF_TEST_RESULT_CONNSTR: "${{ secrets.PERF_TEST_RESULT_CONNSTR }}"
TEST_OLAP_COLLECT_EXPLAIN: ${{ github.event.inputs.collect_olap_explain || 'false' }}
TEST_OLAP_COLLECT_PG_STAT_STATEMENTS: ${{ github.event.inputs.collect_pg_stat_statements || 'false' }}
BENCHMARK_CONNSTR: ${{ steps.set-up-connstr.outputs.connstr }}
TEST_OLAP_SCALE: 10

View File

@@ -1,105 +0,0 @@
name: Build and Push Docker Image
on:
workflow_call:
inputs:
dockerfile-path:
required: true
type: string
image-name:
required: true
type: string
outputs:
build-tools-tag:
description: "tag generated for build tools"
value: ${{ jobs.tag.outputs.build-tools-tag }}
jobs:
check-if-build-tools-dockerfile-changed:
runs-on: ubuntu-latest
outputs:
docker_file_changed: ${{ steps.dockerfile.outputs.docker_file_changed }}
steps:
- name: Check if Dockerfile.buildtools has changed
id: dockerfile
run: |
if [[ "$GITHUB_EVENT_NAME" != "pull_request" ]]; then
echo "docker_file_changed=false" >> $GITHUB_OUTPUT
exit
fi
updated_files=$(gh pr --repo neondatabase/neon diff ${{ github.event.pull_request.number }} --name-only)
if [[ $updated_files == *"Dockerfile.buildtools"* ]]; then
echo "docker_file_changed=true" >> $GITHUB_OUTPUT
fi
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
tag:
runs-on: ubuntu-latest
needs: [ check-if-build-tools-dockerfile-changed ]
outputs:
build-tools-tag: ${{steps.buildtools-tag.outputs.image_tag}}
steps:
- name: Get buildtools tag
env:
DOCKERFILE_CHANGED: ${{ needs.check-if-build-tools-dockerfile-changed.outputs.docker_file_changed }}
run: |
if [[ "$GITHUB_EVENT_NAME" == "pull_request" ]] && [[ "${DOCKERFILE_CHANGED}" == "true" ]]; then
IMAGE_TAG=$GITHUB_RUN_ID
else
IMAGE_TAG=pinned
fi
echo "image_tag=${IMAGE_TAG}" >> $GITHUB_OUTPUT
shell: bash
id: buildtools-tag
kaniko:
if: needs.check-if-build-tools-dockerfile-changed.outputs.docker_file_changed == 'true'
needs: [ tag, check-if-build-tools-dockerfile-changed ]
runs-on: [ self-hosted, dev, x64 ]
container: gcr.io/kaniko-project/executor:v1.7.0-debug
steps:
- name: Checkout
uses: actions/checkout@v1
- name: Configure ECR login
run: echo "{\"credsStore\":\"ecr-login\"}" > /kaniko/.docker/config.json
- name: Kaniko build
run: /kaniko/executor --reproducible --snapshotMode=redo --skip-unused-stages --dockerfile ${{ inputs.dockerfile-path }} --cache=true --cache-repo 369495373322.dkr.ecr.eu-central-1.amazonaws.com/cache --destination 369495373322.dkr.ecr.eu-central-1.amazonaws.com/${{ inputs.image-name }}:${{ needs.tag.outputs.build-tools-tag }}-amd64
kaniko-arm:
if: needs.check-if-build-tools-dockerfile-changed.outputs.docker_file_changed == 'true'
needs: [ tag, check-if-build-tools-dockerfile-changed ]
runs-on: [ self-hosted, dev, arm64 ]
container: gcr.io/kaniko-project/executor:v1.7.0-debug
steps:
- name: Checkout
uses: actions/checkout@v1
- name: Configure ECR login
run: echo "{\"credsStore\":\"ecr-login\"}" > /kaniko/.docker/config.json
- name: Kaniko build
run: /kaniko/executor --reproducible --snapshotMode=redo --skip-unused-stages --dockerfile ${{ inputs.dockerfile-path }} --cache=true --cache-repo 369495373322.dkr.ecr.eu-central-1.amazonaws.com/cache --destination 369495373322.dkr.ecr.eu-central-1.amazonaws.com/${{ inputs.image-name }}:${{ needs.tag.outputs.build-tools-tag }}-arm64
manifest:
if: needs.check-if-build-tools-dockerfile-changed.outputs.docker_file_changed == 'true'
name: 'manifest'
runs-on: [ self-hosted, dev, x64 ]
needs:
- tag
- kaniko
- kaniko-arm
- check-if-build-tools-dockerfile-changed
steps:
- name: Create manifest
run: docker manifest create 369495373322.dkr.ecr.eu-central-1.amazonaws.com/${{ inputs.image-name }}:${{ needs.tag.outputs.build-tools-tag }} --amend 369495373322.dkr.ecr.eu-central-1.amazonaws.com/${{ inputs.image-name }}:${{ needs.tag.outputs.build-tools-tag }}-amd64 --amend 369495373322.dkr.ecr.eu-central-1.amazonaws.com/${{ inputs.image-name }}:${{ needs.tag.outputs.build-tools-tag }}-arm64
- name: Push manifest
run: docker manifest push 369495373322.dkr.ecr.eu-central-1.amazonaws.com/${{ inputs.image-name }}:${{ needs.tag.outputs.build-tools-tag }}

View File

@@ -44,6 +44,7 @@ jobs:
exit 1
tag:
needs: [ check-permissions ]
runs-on: [ self-hosted, gen3, small ]
@@ -73,19 +74,11 @@ jobs:
shell: bash
id: build-tag
build-buildtools-image:
needs: [ check-permissions ]
uses: ./.github/workflows/build_and_push_docker_image.yml
with:
dockerfile-path: Dockerfile.buildtools
image-name: build-tools
secrets: inherit
check-codestyle-python:
needs: [ check-permissions, build-buildtools-image ]
needs: [ check-permissions ]
runs-on: [ self-hosted, gen3, small ]
container:
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:${{ needs.build-buildtools-image.outputs.build-tools-tag }}
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/rust:pinned
options: --init
steps:
@@ -105,20 +98,20 @@ jobs:
- name: Install Python deps
run: ./scripts/pysync
- name: Run `ruff check` to ensure code format
run: poetry run ruff check .
- name: Run ruff to ensure code format
run: poetry run ruff .
- name: Run `ruff format` to ensure code format
run: poetry run ruff format --check .
- name: Run black to ensure code format
run: poetry run black --diff --check .
- name: Run mypy to check types
run: poetry run mypy .
check-codestyle-rust:
needs: [ check-permissions, build-buildtools-image ]
needs: [ check-permissions ]
runs-on: [ self-hosted, gen3, large ]
container:
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:${{ needs.build-buildtools-image.outputs.build-tools-tag }}
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/rust:pinned
options: --init
steps:
@@ -182,10 +175,10 @@ jobs:
run: cargo deny check --hide-inclusion-graph
build-neon:
needs: [ check-permissions, tag, build-buildtools-image ]
needs: [ check-permissions, tag ]
runs-on: [ self-hosted, gen3, large ]
container:
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:${{ needs.build-buildtools-image.outputs.build-tools-tag }}
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/rust:pinned
options: --init
strategy:
fail-fast: false
@@ -339,16 +332,16 @@ jobs:
run: |
${cov_prefix} mold -run cargo build $CARGO_FLAGS $CARGO_FEATURES --bins --tests
- name: Run rust tests
- name: Run cargo test
run: |
${cov_prefix} cargo nextest run $CARGO_FLAGS $CARGO_FEATURES
${cov_prefix} cargo test $CARGO_FLAGS $CARGO_FEATURES
# Run separate tests for real S3
export ENABLE_REAL_S3_REMOTE_STORAGE=nonempty
export REMOTE_STORAGE_S3_BUCKET=neon-github-ci-tests
export REMOTE_STORAGE_S3_BUCKET=neon-github-public-dev
export REMOTE_STORAGE_S3_REGION=eu-central-1
# Avoid `$CARGO_FEATURES` since there's no `testing` feature in the e2e tests now
${cov_prefix} cargo nextest run $CARGO_FLAGS -E 'package(remote_storage)' -E 'test(test_real_s3)'
${cov_prefix} cargo test $CARGO_FLAGS --package remote_storage --test test_real_s3
# Run separate tests for real Azure Blob Storage
# XXX: replace region with `eu-central-1`-like region
@@ -358,7 +351,7 @@ jobs:
export REMOTE_STORAGE_AZURE_CONTAINER="${{ vars.REMOTE_STORAGE_AZURE_CONTAINER }}"
export REMOTE_STORAGE_AZURE_REGION="${{ vars.REMOTE_STORAGE_AZURE_REGION }}"
# Avoid `$CARGO_FEATURES` since there's no `testing` feature in the e2e tests now
${cov_prefix} cargo nextest run $CARGO_FLAGS -E 'package(remote_storage)' -E 'test(test_real_azure)'
${cov_prefix} cargo test $CARGO_FLAGS --package remote_storage --test test_real_azure
- name: Install rust binaries
run: |
@@ -415,10 +408,10 @@ jobs:
uses: ./.github/actions/save-coverage-data
regress-tests:
needs: [ check-permissions, build-neon, build-buildtools-image, tag ]
needs: [ check-permissions, build-neon, tag ]
runs-on: [ self-hosted, gen3, large ]
container:
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:${{ needs.build-buildtools-image.outputs.build-tools-tag }}
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/rust:pinned
# Default shared memory is 64mb
options: --init --shm-size=512mb
strategy:
@@ -454,10 +447,10 @@ jobs:
uses: ./.github/actions/save-coverage-data
benchmarks:
needs: [ check-permissions, build-neon, build-buildtools-image ]
needs: [ check-permissions, build-neon ]
runs-on: [ self-hosted, gen3, small ]
container:
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:${{ needs.build-buildtools-image.outputs.build-tools-tag }}
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/rust:pinned
# Default shared memory is 64mb
options: --init --shm-size=512mb
if: github.ref_name == 'main' || contains(github.event.pull_request.labels.*.name, 'run-benchmarks')
@@ -486,12 +479,12 @@ jobs:
# while coverage is currently collected for the debug ones
create-test-report:
needs: [ check-permissions, regress-tests, coverage-report, benchmarks, build-buildtools-image ]
needs: [ check-permissions, regress-tests, coverage-report, benchmarks ]
if: ${{ !cancelled() && contains(fromJSON('["skipped", "success"]'), needs.check-permissions.result) }}
runs-on: [ self-hosted, gen3, small ]
container:
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:${{ needs.build-buildtools-image.outputs.build-tools-tag }}
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/rust:pinned
options: --init
steps:
@@ -533,10 +526,11 @@ jobs:
})
coverage-report:
needs: [ check-permissions, regress-tests, build-buildtools-image ]
needs: [ check-permissions, regress-tests ]
runs-on: [ self-hosted, gen3, small ]
container:
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools:${{ needs.build-buildtools-image.outputs.build-tools-tag }}
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/rust:pinned
options: --init
strategy:
fail-fast: false
@@ -700,7 +694,7 @@ jobs:
}"
neon-image:
needs: [ check-permissions, build-buildtools-image, tag ]
needs: [ check-permissions, tag ]
runs-on: [ self-hosted, gen3, large ]
container: gcr.io/kaniko-project/executor:v1.9.2-debug
defaults:
@@ -739,7 +733,6 @@ jobs:
--context .
--build-arg GIT_VERSION=${{ github.event.pull_request.head.sha || github.sha }}
--build-arg BUILD_TAG=${{ needs.tag.outputs.build-tag }}
--build-arg TAG=${{ needs.build-buildtools-image.outputs.build-tools-tag }}
--build-arg REPOSITORY=369495373322.dkr.ecr.eu-central-1.amazonaws.com
--destination 369495373322.dkr.ecr.eu-central-1.amazonaws.com/neon:${{needs.tag.outputs.build-tag}}
--destination neondatabase/neon:${{needs.tag.outputs.build-tag}}
@@ -750,7 +743,7 @@ jobs:
compute-tools-image:
runs-on: [ self-hosted, gen3, large ]
needs: [ check-permissions, build-buildtools-image, tag ]
needs: [ check-permissions, tag ]
container: gcr.io/kaniko-project/executor:v1.9.2-debug
defaults:
run:
@@ -785,7 +778,6 @@ jobs:
--context .
--build-arg GIT_VERSION=${{ github.event.pull_request.head.sha || github.sha }}
--build-arg BUILD_TAG=${{needs.tag.outputs.build-tag}}
--build-arg TAG=${{needs.build-buildtools-image.outputs.build-tools-tag}}
--build-arg REPOSITORY=369495373322.dkr.ecr.eu-central-1.amazonaws.com
--dockerfile Dockerfile.compute-tools
--destination 369495373322.dkr.ecr.eu-central-1.amazonaws.com/compute-tools:${{needs.tag.outputs.build-tag}}
@@ -796,7 +788,7 @@ jobs:
run: rm -rf ~/.ecr
compute-node-image:
needs: [ check-permissions, build-buildtools-image, tag ]
needs: [ check-permissions, tag ]
runs-on: [ self-hosted, gen3, large ]
container:
image: gcr.io/kaniko-project/executor:v1.9.2-debug
@@ -844,7 +836,6 @@ jobs:
--build-arg GIT_VERSION=${{ github.event.pull_request.head.sha || github.sha }}
--build-arg PG_VERSION=${{ matrix.version }}
--build-arg BUILD_TAG=${{needs.tag.outputs.build-tag}}
--build-arg TAG=${{needs.build-buildtools-image.outputs.build-tools-tag}}
--build-arg REPOSITORY=369495373322.dkr.ecr.eu-central-1.amazonaws.com
--dockerfile Dockerfile.compute-node
--destination 369495373322.dkr.ecr.eu-central-1.amazonaws.com/compute-node-${{ matrix.version }}:${{needs.tag.outputs.build-tag}}
@@ -866,7 +857,7 @@ jobs:
run:
shell: sh -eu {0}
env:
VM_BUILDER_VERSION: v0.21.0
VM_BUILDER_VERSION: v0.19.0
steps:
- name: Checkout
@@ -1131,7 +1122,7 @@ jobs:
# TODO: move deployPreprodRegion to release (`"$GITHUB_REF_NAME" == "release"` block), once Staging support different compute tag prefixes for different regions
gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main -f branch=main -f dockerTag=${{needs.tag.outputs.build-tag}} -f deployPreprodRegion=true
elif [[ "$GITHUB_REF_NAME" == "release" ]]; then
gh workflow --repo neondatabase/aws run deploy-prod.yml --ref main -f branch=main -f dockerTag=${{needs.tag.outputs.build-tag}}
gh workflow --repo neondatabase/aws run deploy-prod.yml --ref main -f branch=main -f dockerTag=${{needs.tag.outputs.build-tag}} -f disclamerAcknowledged=true
else
echo "GITHUB_REF_NAME (value '$GITHUB_REF_NAME') is not set to either 'main' or 'release'"
exit 1

View File

@@ -218,7 +218,7 @@ jobs:
# Run separate tests for real S3
export ENABLE_REAL_S3_REMOTE_STORAGE=nonempty
export REMOTE_STORAGE_S3_BUCKET=neon-github-ci-tests
export REMOTE_STORAGE_S3_BUCKET=neon-github-public-dev
export REMOTE_STORAGE_S3_REGION=eu-central-1
# Avoid `$CARGO_FEATURES` since there's no `testing` feature in the e2e tests now
cargo test $CARGO_FLAGS --package remote_storage --test test_real_s3

View File

@@ -1,130 +0,0 @@
name: 'Update build tools image tag'
# This workflow it used to update tag of build tools in ECR.
# The most common use case is adding/moving `pinned` tag to `${GITHUB_RUN_IT}` image.
on:
workflow_dispatch:
inputs:
from-tag:
description: 'Source tag'
required: true
type: string
to-tag:
description: 'Destination tag'
required: true
type: string
default: 'pinned'
defaults:
run:
shell: bash -euo pipefail {0}
env:
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_DEV }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_KEY_DEV }}
permissions: {}
jobs:
tag-image:
runs-on: [ self-hosted, gen3, small ]
container: golang:1.19-bullseye
env:
IMAGE: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools
FROM_TAG: ${{ inputs.from-tag }}
TO_TAG: ${{ inputs.to-tag }}
outputs:
next-digest-buildtools: ${{ steps.next-digest.outputs.next-digest-buildtools }}
prev-digest-buildtools: ${{ steps.prev-digest.outputs.prev-digest-buildtools }}
steps:
- name: Install Crane & ECR helper
run: |
go install github.com/google/go-containerregistry/cmd/crane@a54d64203cffcbf94146e04069aae4a97f228ee2 # v0.16.1
go install github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cli/docker-credential-ecr-login@adf1bafd791ae7d4ff098108b1e91f36a4da5404 # v0.7.1
- name: Configure ECR login
run: |
mkdir /github/home/.docker/
echo "{\"credsStore\":\"ecr-login\"}" > /github/home/.docker/config.json
- name: Get source image digest
id: next-digest
run: |
NEXT_DIGEST=$(crane digest ${IMAGE}:${FROM_TAG} || true)
if [ -z "${NEXT_DIGEST}" ]; then
echo >&2 "Image ${IMAGE}:${FROM_TAG} does not exist"
exit 1
fi
echo "Current ${IMAGE}@${FROM_TAG} image is ${IMAGE}@${NEXT_DIGEST}"
echo "next-digest-buildtools=$NEXT_DIGEST" >> $GITHUB_OUTPUT
- name: Get destination image digest (if already exists)
id: prev-digest
run: |
PREV_DIGEST=$(crane digest ${IMAGE}:${TO_TAG} || true)
if [ -z "${PREV_DIGEST}" ]; then
echo >&2 "Image ${IMAGE}:${TO_TAG} does not exist (it's ok)"
else
echo >&2 "Current ${IMAGE}@${TO_TAG} image is ${IMAGE}@${PREV_DIGEST}"
echo "prev-digest-buildtools=$PREV_DIGEST" >> $GITHUB_OUTPUT
fi
- name: Tag image
run: |
crane tag "${IMAGE}:${FROM_TAG}" "${TO_TAG}"
rollback-tag-image:
needs: tag-image
if: ${{ !success() }}
runs-on: [ self-hosted, gen3, small ]
container: golang:1.19-bullseye
env:
IMAGE: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/build-tools
FROM_TAG: ${{ inputs.from-tag }}
TO_TAG: ${{ inputs.to-tag }}
steps:
- name: Install Crane & ECR helper
run: |
go install github.com/google/go-containerregistry/cmd/crane@a54d64203cffcbf94146e04069aae4a97f228ee2 # v0.16.1
go install github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cli/docker-credential-ecr-login@adf1bafd791ae7d4ff098108b1e91f36a4da5404 # v0.7.1
- name: Configure ECR login
run: |
mkdir /github/home/.docker/
echo "{\"credsStore\":\"ecr-login\"}" > /github/home/.docker/config.json
- name: Restore previous tag if needed
run: |
NEXT_DIGEST="${{ needs.tag-image.outputs.next-digest-buildtools }}"
PREV_DIGEST="${{ needs.tag-image.outputs.prev-digest-buildtools }}"
if [ -z "${NEXT_DIGEST}" ]; then
echo >&2 "Image ${IMAGE}:${FROM_TAG} does not exist, nothing to rollback"
exit 0
fi
if [ -z "${PREV_DIGEST}" ]; then
# I guess we should delete the tag here/untag the image, but crane does not support it
# - https://github.com/google/go-containerregistry/issues/999
echo >&2 "Image ${IMAGE}:${TO_TAG} did not exist, but it was created by the job, no need to rollback"
exit 0
fi
CURRENT_DIGEST=$(crane digest "${IMAGE}:${TO_TAG}")
if [ "${CURRENT_DIGEST}" == "${NEXT_DIGEST}" ]; then
crane tag "${IMAGE}@${PREV_DIGEST}" "${TO_TAG}"
echo >&2 "Successfully restored ${TO_TAG} tag from ${IMAGE}@${CURRENT_DIGEST} to ${IMAGE}@${PREV_DIGEST}"
else
echo >&2 "Image ${IMAGE}:${TO_TAG}@${CURRENT_DIGEST} is not required to be restored"
fi

1
.gitignore vendored
View File

@@ -6,7 +6,6 @@ __pycache__/
test_output/
.vscode
.idea
neon.iml
/.neon
/integration_tests/.neon

View File

@@ -70,17 +70,3 @@ We're using the following approach to make it work:
- The label gets removed automatically, so to run CI again with new changes, the label should be added again (after the review)
For details see [`approved-for-ci-run.yml`](.github/workflows/approved-for-ci-run.yml)
## How do I add the "pinned" tag to an buildtools image?
We use the `pinned` tag for `Dockerfile.buildtools` build images in our CI/CD setup, currently adding the `pinned` tag is a manual operation.
You can call it from GitHub UI: https://github.com/neondatabase/neon/actions/workflows/update_build_tools_image.yml,
or using GitHub CLI:
```bash
gh workflow -R neondatabase/neon run update_build_tools_image.yml \
-f from-tag=6254913013 \
-f to-tag=pinned \
# Default `-f to-tag` is `pinned`, so the parameter can be omitted.
```

499
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -3,11 +3,8 @@ resolver = "2"
members = [
"compute_tools",
"control_plane",
"control_plane/attachment_service",
"pageserver",
"pageserver/ctl",
"pageserver/client",
"pageserver/pagebench",
"proxy",
"safekeeper",
"storage_broker",
@@ -81,7 +78,6 @@ futures-util = "0.3"
git-version = "0.3"
hashbrown = "0.13"
hashlink = "0.8.1"
hdrhistogram = "7.5.2"
hex = "0.4"
hex-literal = "0.4"
hmac = "0.12.1"
@@ -94,7 +90,7 @@ hyper-tungstenite = "0.11"
inotify = "0.10.2"
ipnet = "2.9.0"
itertools = "0.10"
jsonwebtoken = "9"
jsonwebtoken = "8"
libc = "0.2"
md5 = "0.7.0"
memoffset = "0.8"
@@ -108,14 +104,11 @@ opentelemetry = "0.19.0"
opentelemetry-otlp = { version = "0.12.0", default_features=false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry-semantic-conventions = "0.11.0"
parking_lot = "0.12"
parquet = { version = "49.0.0", default-features = false, features = ["zstd"] }
parquet_derive = "49.0.0"
pbkdf2 = { version = "0.12.1", features = ["simple", "std"] }
pin-project-lite = "0.2"
prometheus = {version = "0.13", default_features=false, features = ["process"]} # removes protobuf dependency
prost = "0.11"
rand = "0.8"
redis = { version = "0.24.0", features = ["tokio-rustls-comp", "keep-alive"] }
regex = "1.10.2"
reqwest = { version = "0.11", default-features = false, features = ["rustls-tls"] }
reqwest-tracing = { version = "0.4.0", features = ["opentelemetry_0_19"] }
@@ -165,7 +158,7 @@ tracing-error = "0.2.0"
tracing-opentelemetry = "0.19.0"
tracing-subscriber = { version = "0.3", default_features = false, features = ["smallvec", "fmt", "tracing-log", "std", "env-filter", "json"] }
url = "2.2"
uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }
uuid = { version = "1.2", features = ["v4", "serde"] }
walkdir = "2.3.2"
webpki-roots = "0.25"
x509-parser = "0.15"
@@ -189,7 +182,6 @@ compute_api = { version = "0.1", path = "./libs/compute_api/" }
consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" }
metrics = { version = "0.1", path = "./libs/metrics/" }
pageserver_api = { version = "0.1", path = "./libs/pageserver_api/" }
pageserver_client = { path = "./pageserver/client" }
postgres_backend = { version = "0.1", path = "./libs/postgres_backend/" }
postgres_connection = { version = "0.1", path = "./libs/postgres_connection/" }
postgres_ffi = { version = "0.1", path = "./libs/postgres_ffi/" }
@@ -219,10 +211,6 @@ tonic-build = "0.9"
# TODO: we should probably fork `tokio-postgres-rustls` instead.
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch="neon" }
# bug fixes for UUID
parquet = { git = "https://github.com/neondatabase/arrow-rs", branch = "neon-fix-bugs" }
parquet_derive = { git = "https://github.com/neondatabase/arrow-rs", branch = "neon-fix-bugs" }
################# Binary contents sections
[profile.release]

View File

@@ -3,7 +3,7 @@
### By default, the binaries inside the image have some mock parameters and can start, but are not intended to be used
### inside this image in the real deployments.
ARG REPOSITORY=neondatabase
ARG IMAGE=build-tools
ARG IMAGE=rust
ARG TAG=pinned
# Build Postgres

View File

@@ -1,166 +0,0 @@
FROM debian:bullseye-slim
# Add nonroot user
RUN useradd -ms /bin/bash nonroot -b /home
SHELL ["/bin/bash", "-c"]
# System deps
RUN set -e \
&& apt update \
&& apt install -y \
autoconf \
automake \
bison \
build-essential \
ca-certificates \
cmake \
curl \
flex \
git \
gnupg \
gzip \
jq \
libcurl4-openssl-dev \
libbz2-dev \
libffi-dev \
liblzma-dev \
libncurses5-dev \
libncursesw5-dev \
libpq-dev \
libreadline-dev \
libseccomp-dev \
libsqlite3-dev \
libssl-dev \
libstdc++-10-dev \
libtool \
libxml2-dev \
libxmlsec1-dev \
libxxhash-dev \
lsof \
make \
netcat \
net-tools \
openssh-client \
parallel \
pkg-config \
unzip \
wget \
xz-utils \
zlib1g-dev \
zstd \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
# protobuf-compiler (protoc)
ENV PROTOC_VERSION 25.1
RUN curl -fsSL "https://github.com/protocolbuffers/protobuf/releases/download/v${PROTOC_VERSION}/protoc-${PROTOC_VERSION}-linux-$(uname -m | sed 's/aarch64/aarch_64/g').zip" -o "protoc.zip" \
&& unzip -q protoc.zip -d protoc \
&& mv protoc/bin/protoc /usr/local/bin/protoc \
&& mv protoc/include/google /usr/local/include/google \
&& rm -rf protoc.zip protoc
# LLVM
ENV LLVM_VERSION=17
RUN curl -fsSL 'https://apt.llvm.org/llvm-snapshot.gpg.key' | apt-key add - \
&& echo "deb http://apt.llvm.org/bullseye/ llvm-toolchain-bullseye-${LLVM_VERSION} main" > /etc/apt/sources.list.d/llvm.stable.list \
&& apt update \
&& apt install -y clang-${LLVM_VERSION} llvm-${LLVM_VERSION} \
&& bash -c 'for f in /usr/bin/clang*-${LLVM_VERSION} /usr/bin/llvm*-${LLVM_VERSION}; do ln -s "${f}" "${f%-${LLVM_VERSION}}"; done' \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
# PostgreSQL 14
RUN curl -fsSL 'https://www.postgresql.org/media/keys/ACCC4CF8.asc' | apt-key add - \
&& echo 'deb http://apt.postgresql.org/pub/repos/apt bullseye-pgdg main' > /etc/apt/sources.list.d/pgdg.list \
&& apt update \
&& apt install -y postgresql-client-14 \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
# AWS CLI
RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-$(uname -m).zip" -o "awscliv2.zip" \
&& unzip -q awscliv2.zip \
&& ./aws/install \
&& rm awscliv2.zip
# Mold: A Modern Linker
ENV MOLD_VERSION v2.4.0
RUN set -e \
&& git clone https://github.com/rui314/mold.git \
&& mkdir mold/build \
&& cd mold/build \
&& git checkout ${MOLD_VERSION} \
&& cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_COMPILER=clang++ .. \
&& cmake --build . -j $(nproc) \
&& cmake --install . \
&& cd .. \
&& rm -rf mold
# LCOV
# Build lcov from a fork:
# It includes several bug fixes on top on v2.0 release (https://github.com/linux-test-project/lcov/compare/v2.0...master)
# And patches from us:
# - Generates json file with code coverage summary (https://github.com/neondatabase/lcov/commit/426e7e7a22f669da54278e9b55e6d8caabd00af0.tar.gz)
RUN for package in Capture::Tiny DateTime Devel::Cover Digest::MD5 File::Spec JSON::XS Memory::Process Time::HiRes JSON; do yes | perl -MCPAN -e "CPAN::Shell->notest('install', '$package')"; done \
&& wget https://github.com/neondatabase/lcov/archive/426e7e7a22f669da54278e9b55e6d8caabd00af0.tar.gz -O lcov.tar.gz \
&& echo "61a22a62e20908b8b9e27d890bd0ea31f567a7b9668065589266371dcbca0992 lcov.tar.gz" | sha256sum --check \
&& mkdir -p lcov && tar -xzf lcov.tar.gz -C lcov --strip-components=1 \
&& cd lcov \
&& make install \
&& rm -rf ../lcov.tar.gz
# Switch to nonroot user
USER nonroot:nonroot
WORKDIR /home/nonroot
# Python
ENV PYTHON_VERSION=3.9.2 \
PYENV_ROOT=/home/nonroot/.pyenv \
PATH=/home/nonroot/.pyenv/shims:/home/nonroot/.pyenv/bin:/home/nonroot/.poetry/bin:$PATH
RUN set -e \
&& cd $HOME \
&& curl -sSO https://raw.githubusercontent.com/pyenv/pyenv-installer/master/bin/pyenv-installer \
&& chmod +x pyenv-installer \
&& ./pyenv-installer \
&& export PYENV_ROOT=/home/nonroot/.pyenv \
&& export PATH="$PYENV_ROOT/bin:$PATH" \
&& export PATH="$PYENV_ROOT/shims:$PATH" \
&& pyenv install ${PYTHON_VERSION} \
&& pyenv global ${PYTHON_VERSION} \
&& python --version \
&& pip install --upgrade pip \
&& pip --version \
&& pip install pipenv wheel poetry
# Switch to nonroot user (again)
USER nonroot:nonroot
WORKDIR /home/nonroot
# Rust
# Please keep the version of llvm (installed above) in sync with rust llvm (`rustc --version --verbose | grep LLVM`)
ENV RUSTC_VERSION=1.75.0
ENV RUSTUP_HOME="/home/nonroot/.rustup"
ENV PATH="/home/nonroot/.cargo/bin:${PATH}"
RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux-gnu/rustup-init && whoami && \
chmod +x rustup-init && \
./rustup-init -y --default-toolchain ${RUSTC_VERSION} && \
rm rustup-init && \
export PATH="$HOME/.cargo/bin:$PATH" && \
. "$HOME/.cargo/env" && \
cargo --version && rustup --version && \
rustup component add llvm-tools-preview rustfmt clippy && \
cargo install --git https://github.com/paritytech/cachepot && \
cargo install rustfilt && \
cargo install cargo-hakari && \
cargo install cargo-deny && \
cargo install cargo-hack && \
cargo install cargo-nextest && \
rm -rf /home/nonroot/.cargo/registry && \
rm -rf /home/nonroot/.cargo/git
ENV RUSTC_WRAPPER=cachepot
# Show versions
RUN whoami \
&& python --version \
&& pip --version \
&& cargo --version --verbose \
&& rustup --version --verbose \
&& rustc --version --verbose \
&& clang --version

View File

@@ -1,6 +1,6 @@
ARG PG_VERSION
ARG REPOSITORY=neondatabase
ARG IMAGE=build-tools
ARG IMAGE=rust
ARG TAG=pinned
ARG BUILD_TAG
@@ -48,29 +48,7 @@ RUN cd postgres && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgrowlocks.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgstattuple.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/refint.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/xml2.control && \
# We need to grant EXECUTE on pg_stat_statements_reset() to neon_superuser.
# In vanilla postgres this function is limited to Postgres role superuser.
# In neon we have neon_superuser role that is not a superuser but replaces superuser in some cases.
# We could add the additional grant statements to the postgres repository but it would be hard to maintain,
# whenever we need to pick up a new postgres version and we want to limit the changes in our postgres fork,
# so we do it here.
old_list="pg_stat_statements--1.0--1.1.sql pg_stat_statements--1.1--1.2.sql pg_stat_statements--1.2--1.3.sql pg_stat_statements--1.3--1.4.sql pg_stat_statements--1.4--1.5.sql pg_stat_statements--1.4.sql pg_stat_statements--1.5--1.6.sql"; \
# the first loop is for pg_stat_statement extension version <= 1.6
for file in /usr/local/pgsql/share/extension/pg_stat_statements--*.sql; do \
filename=$(basename "$file"); \
if echo "$old_list" | grep -q -F "$filename"; then \
echo 'GRANT EXECUTE ON FUNCTION pg_stat_statements_reset() TO neon_superuser;' >> $file; \
fi; \
done; \
# the second loop is for pg_stat_statement extension versions >= 1.7,
# where pg_stat_statement_reset() got 3 additional arguments
for file in /usr/local/pgsql/share/extension/pg_stat_statements--*.sql; do \
filename=$(basename "$file"); \
if ! echo "$old_list" | grep -q -F "$filename"; then \
echo 'GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) TO neon_superuser;' >> $file; \
fi; \
done
echo 'trusted = true' >> /usr/local/pgsql/share/extension/xml2.control
#########################################################################################
#
@@ -591,23 +569,6 @@ RUN wget https://github.com/ChenHuajun/pg_roaringbitmap/archive/refs/tags/v0.5.4
make -j $(getconf _NPROCESSORS_ONLN) install && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/roaringbitmap.control
#########################################################################################
#
# Layer "pg-semver-pg-build"
# compile pg_semver extension
#
#########################################################################################
FROM build-deps AS pg-semver-pg-build
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
ENV PATH "/usr/local/pgsql/bin/:$PATH"
RUN wget https://github.com/theory/pg-semver/archive/refs/tags/v0.32.1.tar.gz -O pg_semver.tar.gz && \
echo "fbdaf7512026d62eec03fad8687c15ed509b6ba395bff140acd63d2e4fbe25d7 pg_semver.tar.gz" | sha256sum --check && \
mkdir pg_semver-src && cd pg_semver-src && tar xvzf ../pg_semver.tar.gz --strip-components=1 -C . && \
make -j $(getconf _NPROCESSORS_ONLN) && \
make -j $(getconf _NPROCESSORS_ONLN) install && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/semver.control
#########################################################################################
#
# Layer "pg-embedding-pg-build"
@@ -807,7 +768,6 @@ COPY --from=pg-pgx-ulid-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=rdkit-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-uuidv7-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-roaringbitmap-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-semver-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=pg-embedding-pg-build /usr/local/pgsql/ /usr/local/pgsql/
COPY --from=wal2json-pg-build /usr/local/pgsql /usr/local/pgsql
COPY pgxn/ pgxn/
@@ -883,10 +843,8 @@ FROM debian:bullseye-slim
RUN mkdir /var/db && useradd -m -d /var/db/postgres postgres && \
echo "postgres:test_console_pass" | chpasswd && \
mkdir /var/db/postgres/compute && mkdir /var/db/postgres/specs && \
mkdir /var/db/postgres/pgbouncer && \
chown -R postgres:postgres /var/db/postgres && \
chmod 0750 /var/db/postgres/compute && \
chmod 0750 /var/db/postgres/pgbouncer && \
echo '/usr/local/lib' >> /etc/ld.so.conf && /sbin/ldconfig && \
# create folder for file cache
mkdir -p -m 777 /neon/cache

View File

@@ -1,7 +1,7 @@
# First transient image to build compute_tools binaries
# NB: keep in sync with rust image version in .github/workflows/build_and_test.yml
ARG REPOSITORY=neondatabase
ARG IMAGE=build-tools
ARG IMAGE=rust
ARG TAG=pinned
ARG BUILD_TAG

View File

@@ -29,14 +29,13 @@ See developer documentation in [SUMMARY.md](/docs/SUMMARY.md) for more informati
```bash
apt install build-essential libtool libreadline-dev zlib1g-dev flex bison libseccomp-dev \
libssl-dev clang pkg-config libpq-dev cmake postgresql-client protobuf-compiler \
libcurl4-openssl-dev openssl python3-poetry lsof libicu-dev
libcurl4-openssl-dev openssl python-poetry lsof libicu-dev
```
* On Fedora, these packages are needed:
```bash
dnf install flex bison readline-devel zlib-devel openssl-devel \
libseccomp-devel perl clang cmake postgresql postgresql-contrib protobuf-compiler \
protobuf-devel libcurl-devel openssl poetry lsof libicu-devel libpq-devel python3-devel \
libffi-devel
protobuf-devel libcurl-devel openssl poetry lsof libicu-devel
```
* On Arch based systems, these packages are needed:
```bash

View File

@@ -13,7 +13,6 @@ clap.workspace = true
flate2.workspace = true
futures.workspace = true
hyper = { workspace = true, features = ["full"] }
nix.workspace = true
notify.workspace = true
num_cpus.workspace = true
opentelemetry.workspace = true
@@ -21,7 +20,6 @@ postgres.workspace = true
regex.workspace = true
serde.workspace = true
serde_json.workspace = true
signal-hook.workspace = true
tar.workspace = true
reqwest = { workspace = true, features = ["json"] }
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
@@ -39,6 +37,5 @@ workspace_hack.workspace = true
toml_edit.workspace = true
remote_storage = { version = "0.1", path = "../libs/remote_storage/" }
vm_monitor = { version = "0.1", path = "../libs/vm_monitor/" }
zstd = "0.13"
zstd = "0.12.4"
bytes = "1.0"
rust-ini = "0.20.0"

View File

@@ -31,29 +31,25 @@
//! -C 'postgresql://cloud_admin@localhost/postgres' \
//! -S /var/db/postgres/specs/current.json \
//! -b /usr/local/bin/postgres \
//! -r http://pg-ext-s3-gateway \
//! -r http://pg-ext-s3-gateway
//! ```
//!
use std::collections::HashMap;
use std::fs::File;
use std::path::Path;
use std::process::exit;
use std::sync::atomic::Ordering;
use std::sync::{mpsc, Arc, Condvar, Mutex, RwLock};
use std::{thread, time::Duration};
use anyhow::{Context, Result};
use chrono::Utc;
use clap::Arg;
use nix::sys::signal::{kill, Signal};
use signal_hook::consts::{SIGQUIT, SIGTERM};
use signal_hook::{consts::SIGINT, iterator::Signals};
use tracing::{error, info};
use url::Url;
use compute_api::responses::ComputeStatus;
use compute_tools::compute::{ComputeNode, ComputeState, ParsedSpec, PG_PID, SYNC_SAFEKEEPERS_PID};
use compute_tools::compute::{ComputeNode, ComputeState, ParsedSpec};
use compute_tools::configurator::launch_configurator;
use compute_tools::extension_server::get_pg_version;
use compute_tools::http::api::launch_http_server;
@@ -69,13 +65,6 @@ const BUILD_TAG_DEFAULT: &str = "latest";
fn main() -> Result<()> {
init_tracing_and_logging(DEFAULT_LOG_LEVEL)?;
let mut signals = Signals::new([SIGINT, SIGTERM, SIGQUIT])?;
thread::spawn(move || {
for sig in signals.forever() {
handle_exit_signal(sig);
}
});
let build_tag = option_env!("BUILD_TAG")
.unwrap_or(BUILD_TAG_DEFAULT)
.to_string();
@@ -224,9 +213,9 @@ fn main() -> Result<()> {
let compute = Arc::new(compute_node);
// If this is a pooled VM, prewarm before starting HTTP server and becoming
// available for binding. Prewarming helps Postgres start quicker later,
// available for binding. Prewarming helps postgres start quicker later,
// because QEMU will already have it's memory allocated from the host, and
// the necessary binaries will already be cached.
// the necessary binaries will alreaady be cached.
if !spec_set {
compute.prewarm_postgres()?;
}
@@ -269,11 +258,6 @@ fn main() -> Result<()> {
state.status = ComputeStatus::Init;
compute.state_changed.notify_all();
info!(
"running compute with features: {:?}",
state.pspec.as_ref().unwrap().spec.features
);
drop(state);
// Launch remaining service threads
@@ -286,7 +270,7 @@ fn main() -> Result<()> {
let pg = match compute.start_compute(extension_server_port) {
Ok(pg) => Some(pg),
Err(err) => {
error!("could not start the compute node: {:#}", err);
error!("could not start the compute node: {:?}", err);
let mut state = compute.state.lock().unwrap();
state.error = Some(format!("{:?}", err));
state.status = ComputeStatus::Failed;
@@ -348,20 +332,13 @@ fn main() -> Result<()> {
// Wait for the child Postgres process forever. In this state Ctrl+C will
// propagate to Postgres and it will be shut down as well.
if let Some((mut pg, logs_handle)) = pg {
if let Some(mut pg) = pg {
// Startup is finished, exit the startup tracing span
drop(startup_context_guard);
let ecode = pg
.wait()
.expect("failed to start waiting on Postgres process");
PG_PID.store(0, Ordering::SeqCst);
// Process has exited, so we can join the logs thread.
let _ = logs_handle
.join()
.map_err(|e| tracing::error!("log thread panicked: {:?}", e));
info!("Postgres exited with code {}, shutting down", ecode);
exit_code = ecode.code()
}
@@ -518,24 +495,6 @@ fn cli() -> clap::Command {
)
}
/// When compute_ctl is killed, send also termination signal to sync-safekeepers
/// to prevent leakage. TODO: it is better to convert compute_ctl to async and
/// wait for termination which would be easy then.
fn handle_exit_signal(sig: i32) {
info!("received {sig} termination signal");
let ss_pid = SYNC_SAFEKEEPERS_PID.load(Ordering::SeqCst);
if ss_pid != 0 {
let ss_pid = nix::unistd::Pid::from_raw(ss_pid as i32);
kill(ss_pid, Signal::SIGTERM).ok();
}
let pg_pid = PG_PID.load(Ordering::SeqCst);
if pg_pid != 0 {
let pg_pid = nix::unistd::Pid::from_raw(pg_pid as i32);
kill(pg_pid, Signal::SIGTERM).ok();
}
exit(1);
}
#[test]
fn verify_cli() {
cli().debug_assert()

View File

@@ -6,10 +6,7 @@ use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use std::process::{Command, Stdio};
use std::str::FromStr;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering;
use std::sync::{Condvar, Mutex, RwLock};
use std::thread;
use std::time::Instant;
use anyhow::{Context, Result};
@@ -20,7 +17,7 @@ use futures::StreamExt;
use postgres::{Client, NoTls};
use tokio;
use tokio_postgres;
use tracing::{debug, error, info, instrument, warn};
use tracing::{error, info, instrument, warn};
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
@@ -31,15 +28,11 @@ use utils::measured_stream::MeasuredReader;
use remote_storage::{DownloadError, RemotePath};
use crate::checker::create_availability_check_data;
use crate::logger::inlinify;
use crate::pg_helpers::*;
use crate::spec::*;
use crate::sync_sk::{check_if_synced, ping_safekeeper};
use crate::{config, extension_server};
pub static SYNC_SAFEKEEPERS_PID: AtomicU32 = AtomicU32::new(0);
pub static PG_PID: AtomicU32 = AtomicU32::new(0);
/// Compute node info shared across several `compute_ctl` threads.
pub struct ComputeNode {
// Url type maintains proper escaping
@@ -276,7 +269,7 @@ fn create_neon_superuser(spec: &ComputeSpec, client: &mut Client) -> Result<()>
$$;"#,
roles_decl, database_decl,
);
info!("Neon superuser created: {}", inlinify(&query));
info!("Neon superuser created:\n{}", &query);
client
.simple_query(&query)
.map_err(|e| anyhow::anyhow!(e).context(query))?;
@@ -492,7 +485,7 @@ impl ComputeNode {
pub fn sync_safekeepers(&self, storage_auth_token: Option<String>) -> Result<Lsn> {
let start_time = Utc::now();
let mut sync_handle = maybe_cgexec(&self.pgbin)
let sync_handle = maybe_cgexec(&self.pgbin)
.args(["--sync-safekeepers"])
.env("PGDATA", &self.pgdata) // we cannot use -D in this mode
.envs(if let Some(storage_auth_token) = &storage_auth_token {
@@ -501,29 +494,15 @@ impl ComputeNode {
vec![]
})
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.expect("postgres --sync-safekeepers failed to start");
SYNC_SAFEKEEPERS_PID.store(sync_handle.id(), Ordering::SeqCst);
// `postgres --sync-safekeepers` will print all log output to stderr and
// final LSN to stdout. So we leave stdout to collect LSN, while stderr logs
// will be collected in a child thread.
let stderr = sync_handle
.stderr
.take()
.expect("stderr should be captured");
let logs_handle = handle_postgres_logs(stderr);
// final LSN to stdout. So we pipe only stdout, while stderr will be automatically
// redirected to the caller output.
let sync_output = sync_handle
.wait_with_output()
.expect("postgres --sync-safekeepers failed");
SYNC_SAFEKEEPERS_PID.store(0, Ordering::SeqCst);
// Process has exited, so we can join the logs thread.
let _ = logs_handle
.join()
.map_err(|e| tracing::error!("log thread panicked: {:?}", e));
if !sync_output.status.success() {
anyhow::bail!(
@@ -661,12 +640,11 @@ impl ComputeNode {
/// Start Postgres as a child process and manage DBs/roles.
/// After that this will hang waiting on the postmaster process to exit.
/// Returns a handle to the child process and a handle to the logs thread.
#[instrument(skip_all)]
pub fn start_postgres(
&self,
storage_auth_token: Option<String>,
) -> Result<(std::process::Child, std::thread::JoinHandle<()>)> {
) -> Result<std::process::Child> {
let pgdata_path = Path::new(&self.pgdata);
// Run postgres as a child process.
@@ -677,18 +655,12 @@ impl ComputeNode {
} else {
vec![]
})
.stderr(Stdio::piped())
.spawn()
.expect("cannot start postgres process");
PG_PID.store(pg.id(), Ordering::SeqCst);
// Start a thread to collect logs from stderr.
let stderr = pg.stderr.take().expect("stderr should be captured");
let logs_handle = handle_postgres_logs(stderr);
wait_for_postgres(&mut pg, pgdata_path)?;
Ok((pg, logs_handle))
Ok(pg)
}
/// Do initial configuration of the already started Postgres.
@@ -765,25 +737,6 @@ impl ComputeNode {
pub fn reconfigure(&self) -> Result<()> {
let spec = self.state.lock().unwrap().pspec.clone().unwrap().spec;
if let Some(ref pgbouncer_settings) = spec.pgbouncer_settings {
info!("tuning pgbouncer");
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("failed to create rt");
// Spawn a thread to do the tuning,
// so that we don't block the main thread that starts Postgres.
let pgbouncer_settings = pgbouncer_settings.clone();
let _handle = thread::spawn(move || {
let res = rt.block_on(tune_pgbouncer(pgbouncer_settings));
if let Err(err) = res {
error!("error while tuning pgbouncer: {err:?}");
}
});
}
// Write new config
let pgdata_path = Path::new(&self.pgdata);
let postgresql_conf_path = pgdata_path.join("postgresql.conf");
@@ -827,10 +780,7 @@ impl ComputeNode {
}
#[instrument(skip_all)]
pub fn start_compute(
&self,
extension_server_port: u16,
) -> Result<(std::process::Child, std::thread::JoinHandle<()>)> {
pub fn start_compute(&self, extension_server_port: u16) -> Result<std::process::Child> {
let compute_state = self.state.lock().unwrap().clone();
let pspec = compute_state.pspec.as_ref().expect("spec must be set");
info!(
@@ -841,26 +791,6 @@ impl ComputeNode {
pspec.timeline_id,
);
// tune pgbouncer
if let Some(pgbouncer_settings) = &pspec.spec.pgbouncer_settings {
info!("tuning pgbouncer");
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("failed to create rt");
// Spawn a thread to do the tuning,
// so that we don't block the main thread that starts Postgres.
let pgbouncer_settings = pgbouncer_settings.clone();
let _handle = thread::spawn(move || {
let res = rt.block_on(tune_pgbouncer(pgbouncer_settings));
if let Err(err) = res {
error!("error while tuning pgbouncer: {err:?}");
}
});
}
info!(
"start_compute spec.remote_extensions {:?}",
pspec.spec.remote_extensions
@@ -895,7 +825,7 @@ impl ComputeNode {
self.prepare_pgdata(&compute_state, extension_server_port)?;
let start_time = Utc::now();
let pg_process = self.start_postgres(pspec.storage_auth_token.clone())?;
let pg = self.start_postgres(pspec.storage_auth_token.clone())?;
let config_time = Utc::now();
if pspec.spec.mode == ComputeMode::Primary && !pspec.spec.skip_pg_catalog_updates {
@@ -945,17 +875,7 @@ impl ComputeNode {
};
info!(?metrics, "compute start finished");
Ok(pg_process)
}
/// Update the `last_active` in the shared state, but ensure that it's a more recent one.
pub fn update_last_active(&self, last_active: Option<DateTime<Utc>>) {
let mut state = self.state.lock().unwrap();
// NB: `Some(<DateTime>)` is always greater than `None`.
if last_active > state.last_active {
state.last_active = last_active;
debug!("set the last compute activity time to: {:?}", last_active);
}
Ok(pg)
}
// Look for core dumps and collect backtraces.

View File

@@ -38,9 +38,3 @@ pub fn init_tracing_and_logging(default_log_level: &str) -> anyhow::Result<()> {
Ok(())
}
/// Replace all newline characters with a special character to make it
/// easier to grep for log messages.
pub fn inlinify(s: &str) -> String {
s.replace('\n', "\u{200B}")
}

View File

@@ -3,165 +3,97 @@ use std::{thread, time::Duration};
use chrono::{DateTime, Utc};
use postgres::{Client, NoTls};
use tracing::{debug, error, info, warn};
use tracing::{debug, info};
use crate::compute::ComputeNode;
use compute_api::responses::ComputeStatus;
use compute_api::spec::ComputeFeature;
const MONITOR_CHECK_INTERVAL: Duration = Duration::from_millis(500);
// Spin in a loop and figure out the last activity time in the Postgres.
// Then update it in the shared state. This function never errors out.
// NB: the only expected panic is at `Mutex` unwrap(), all other errors
// should be handled gracefully.
// XXX: the only expected panic is at `RwLock` unwrap().
fn watch_compute_activity(compute: &ComputeNode) {
// Suppose that `connstr` doesn't change
let connstr = compute.connstr.as_str();
// During startup and configuration we connect to every Postgres database,
// but we don't want to count this as some user activity. So wait until
// the compute fully started before monitoring activity.
wait_for_postgres_start(compute);
// Define `client` outside of the loop to reuse existing connection if it's active.
let mut client = Client::connect(connstr, NoTls);
let mut sleep = false;
let mut prev_active_time: Option<f64> = None;
let mut prev_sessions: Option<i64> = None;
if compute.has_feature(ComputeFeature::ActivityMonitorExperimental) {
info!("starting experimental activity monitor for {}", connstr);
} else {
info!("starting activity monitor for {}", connstr);
}
info!("watching Postgres activity at {}", connstr);
loop {
// We use `continue` a lot, so it's more convenient to sleep at the top of the loop.
// But skip the first sleep, so we can connect to Postgres immediately.
if sleep {
// Should be outside of the mutex lock to allow others to read while we sleep.
thread::sleep(MONITOR_CHECK_INTERVAL);
} else {
sleep = true;
}
// Should be outside of the write lock to allow others to read while we sleep.
thread::sleep(MONITOR_CHECK_INTERVAL);
match &mut client {
Ok(cli) => {
if cli.is_closed() {
info!("connection to Postgres is closed, trying to reconnect");
info!("connection to postgres closed, trying to reconnect");
// Connection is closed, reconnect and try again.
client = Client::connect(connstr, NoTls);
continue;
}
// This is a new logic, only enable if the feature flag is set.
// TODO: remove this once we are sure that it works OR drop it altogether.
if compute.has_feature(ComputeFeature::ActivityMonitorExperimental) {
// First, check if the total active time or sessions across all databases has changed.
// If it did, it means that user executed some queries. In theory, it can even go down if
// some databases were dropped, but it's still a user activity.
match get_database_stats(cli) {
Ok((active_time, sessions)) => {
let mut detected_activity = false;
// Get all running client backends except ourself, use RFC3339 DateTime format.
let backends = cli
.query(
"SELECT state, to_char(state_change, 'YYYY-MM-DD\"T\"HH24:MI:SS.US\"Z\"') AS state_change
FROM pg_stat_activity
WHERE backend_type = 'client backend'
AND pid != pg_backend_pid()
AND usename != 'cloud_admin';", // XXX: find a better way to filter other monitors?
&[],
);
let mut last_active = compute.state.lock().unwrap().last_active;
prev_active_time = match prev_active_time {
Some(prev_active_time) => {
if active_time != prev_active_time {
detected_activity = true;
}
Some(active_time)
}
None => Some(active_time),
if let Ok(backs) = backends {
let mut idle_backs: Vec<DateTime<Utc>> = vec![];
for b in backs.into_iter() {
let state: String = match b.try_get("state") {
Ok(state) => state,
Err(_) => continue,
};
if state == "idle" {
let change: String = match b.try_get("state_change") {
Ok(state_change) => state_change,
Err(_) => continue,
};
prev_sessions = match prev_sessions {
Some(prev_sessions) => {
if sessions != prev_sessions {
detected_activity = true;
}
Some(sessions)
let change = DateTime::parse_from_rfc3339(&change);
match change {
Ok(t) => idle_backs.push(t.with_timezone(&Utc)),
Err(e) => {
info!("cannot parse backend state_change DateTime: {}", e);
continue;
}
None => Some(sessions),
};
if detected_activity {
// Update the last active time and continue, we don't need to
// check backends state change.
compute.update_last_active(Some(Utc::now()));
continue;
}
} else {
// Found non-idle backend, so the last activity is NOW.
// Save it and exit the for loop. Also clear the idle backend
// `state_change` timestamps array as it doesn't matter now.
last_active = Some(Utc::now());
idle_backs.clear();
break;
}
Err(e) => {
error!("could not get database statistics: {}", e);
continue;
}
}
// Get idle backend `state_change` with the max timestamp.
if let Some(last) = idle_backs.iter().max() {
last_active = Some(*last);
}
}
// Second, if database statistics is the same, check all backends state change,
// maybe there is some with more recent activity. `get_backends_state_change()`
// can return None or stale timestamp, so it's `compute.update_last_active()`
// responsibility to check if the new timestamp is more recent than the current one.
// This helps us to discover new sessions, that did nothing yet.
match get_backends_state_change(cli) {
Ok(last_active) => {
compute.update_last_active(last_active);
}
Err(e) => {
error!("could not get backends state change: {}", e);
}
}
// Finally, if there are existing (logical) walsenders, do not suspend.
//
// walproposer doesn't currently show up in pg_stat_replication,
// but protect if it will be
let ws_count_query = "select count(*) from pg_stat_replication where application_name != 'walproposer';";
match cli.query_one(ws_count_query, &[]) {
Ok(r) => match r.try_get::<&str, i64>("count") {
Ok(num_ws) => {
if num_ws > 0 {
compute.update_last_active(Some(Utc::now()));
continue;
}
}
Err(e) => {
warn!("failed to parse walsenders count: {:?}", e);
continue;
}
},
Err(e) => {
warn!("failed to get list of walsenders: {:?}", e);
continue;
}
}
//
// Do not suspend compute if autovacuum is running
//
let autovacuum_count_query = "select count(*) from pg_stat_activity where backend_type = 'autovacuum worker'";
match cli.query_one(autovacuum_count_query, &[]) {
Ok(r) => match r.try_get::<&str, i64>("count") {
Ok(num_workers) => {
if num_workers > 0 {
compute.update_last_active(Some(Utc::now()));
continue;
}
}
Err(e) => {
warn!("failed to parse autovacuum workers count: {:?}", e);
continue;
}
},
Err(e) => {
warn!("failed to get list of autovacuum workers: {:?}", e);
continue;
}
// Update the last activity in the shared state if we got a more recent one.
let mut state = compute.state.lock().unwrap();
// NB: `Some(<DateTime>)` is always greater than `None`.
if last_active > state.last_active {
state.last_active = last_active;
debug!("set the last compute activity time to: {:?}", last_active);
}
}
Err(e) => {
debug!("could not connect to Postgres: {}, retrying", e);
debug!("cannot connect to postgres: {}, retrying", e);
// Establish a new connection and try again.
client = Client::connect(connstr, NoTls);
@@ -170,124 +102,12 @@ fn watch_compute_activity(compute: &ComputeNode) {
}
}
// Hang on condition variable waiting until the compute status is `Running`.
fn wait_for_postgres_start(compute: &ComputeNode) {
let mut state = compute.state.lock().unwrap();
while state.status != ComputeStatus::Running {
info!("compute is not running, waiting before monitoring activity");
state = compute.state_changed.wait(state).unwrap();
if state.status == ComputeStatus::Running {
break;
}
}
}
// Figure out the total active time and sessions across all non-system databases.
// Returned tuple is `(active_time, sessions)`.
// It can return `0.0` active time or `0` sessions, which means no user databases exist OR
// it was a start with skipped `pg_catalog` updates and user didn't do any queries
// (or open any sessions) yet.
fn get_database_stats(cli: &mut Client) -> anyhow::Result<(f64, i64)> {
// Filter out `postgres` database as `compute_ctl` and other monitoring tools
// like `postgres_exporter` use it to query Postgres statistics.
// Use explicit 8 bytes type casts to match Rust types.
let stats = cli.query_one(
"SELECT coalesce(sum(active_time), 0.0)::float8 AS total_active_time,
coalesce(sum(sessions), 0)::bigint AS total_sessions
FROM pg_stat_database
WHERE datname NOT IN (
'postgres',
'template0',
'template1'
);",
&[],
);
let stats = match stats {
Ok(stats) => stats,
Err(e) => {
return Err(anyhow::anyhow!("could not query active_time: {}", e));
}
};
let active_time: f64 = match stats.try_get("total_active_time") {
Ok(active_time) => active_time,
Err(e) => return Err(anyhow::anyhow!("could not get total_active_time: {}", e)),
};
let sessions: i64 = match stats.try_get("total_sessions") {
Ok(sessions) => sessions,
Err(e) => return Err(anyhow::anyhow!("could not get total_sessions: {}", e)),
};
Ok((active_time, sessions))
}
// Figure out the most recent state change time across all client backends.
// If there is currently active backend, timestamp will be `Utc::now()`.
// It can return `None`, which means no client backends exist or we were
// unable to parse the timestamp.
fn get_backends_state_change(cli: &mut Client) -> anyhow::Result<Option<DateTime<Utc>>> {
let mut last_active: Option<DateTime<Utc>> = None;
// Get all running client backends except ourself, use RFC3339 DateTime format.
let backends = cli.query(
"SELECT state, to_char(state_change, 'YYYY-MM-DD\"T\"HH24:MI:SS.US\"Z\"') AS state_change
FROM pg_stat_activity
WHERE backend_type = 'client backend'
AND pid != pg_backend_pid()
AND usename != 'cloud_admin';", // XXX: find a better way to filter other monitors?
&[],
);
match backends {
Ok(backs) => {
let mut idle_backs: Vec<DateTime<Utc>> = vec![];
for b in backs.into_iter() {
let state: String = match b.try_get("state") {
Ok(state) => state,
Err(_) => continue,
};
if state == "idle" {
let change: String = match b.try_get("state_change") {
Ok(state_change) => state_change,
Err(_) => continue,
};
let change = DateTime::parse_from_rfc3339(&change);
match change {
Ok(t) => idle_backs.push(t.with_timezone(&Utc)),
Err(e) => {
info!("cannot parse backend state_change DateTime: {}", e);
continue;
}
}
} else {
// Found non-idle backend, so the last activity is NOW.
// Return immediately, no need to check other backends.
return Ok(Some(Utc::now()));
}
}
// Get idle backend `state_change` with the max timestamp.
if let Some(last) = idle_backs.iter().max() {
last_active = Some(*last);
}
}
Err(e) => {
return Err(anyhow::anyhow!("could not query backends: {}", e));
}
}
Ok(last_active)
}
/// Launch a separate compute monitor thread and return its `JoinHandle`.
pub fn launch_monitor(compute: &Arc<ComputeNode>) -> thread::JoinHandle<()> {
let compute = Arc::clone(compute);
pub fn launch_monitor(state: &Arc<ComputeNode>) -> thread::JoinHandle<()> {
let state = Arc::clone(state);
thread::Builder::new()
.name("compute-monitor".into())
.spawn(move || watch_compute_activity(&compute))
.spawn(move || watch_compute_activity(&state))
.expect("cannot launch compute monitor thread")
}

View File

@@ -6,17 +6,12 @@ use std::io::{BufRead, BufReader};
use std::os::unix::fs::PermissionsExt;
use std::path::Path;
use std::process::Child;
use std::thread::JoinHandle;
use std::time::{Duration, Instant};
use anyhow::{bail, Result};
use ini::Ini;
use notify::{RecursiveMode, Watcher};
use postgres::{Client, Transaction};
use tokio::io::AsyncBufReadExt;
use tokio::time::timeout;
use tokio_postgres::NoTls;
use tracing::{debug, error, info, instrument};
use tracing::{debug, instrument};
use compute_api::spec::{Database, GenericOption, GenericOptions, PgIdent, Role};
@@ -364,172 +359,3 @@ pub fn create_pgdata(pgdata: &str) -> Result<()> {
Ok(())
}
/// Update pgbouncer.ini with provided options
fn update_pgbouncer_ini(
pgbouncer_config: HashMap<String, String>,
pgbouncer_ini_path: &str,
) -> Result<()> {
let mut conf = Ini::load_from_file(pgbouncer_ini_path)?;
let section = conf.section_mut(Some("pgbouncer")).unwrap();
for (option_name, value) in pgbouncer_config.iter() {
section.insert(option_name, value);
debug!(
"Updating pgbouncer.ini with new values {}={}",
option_name, value
);
}
conf.write_to_file(pgbouncer_ini_path)?;
Ok(())
}
/// Tune pgbouncer.
/// 1. Apply new config using pgbouncer admin console
/// 2. Add new values to pgbouncer.ini to preserve them after restart
pub async fn tune_pgbouncer(pgbouncer_config: HashMap<String, String>) -> Result<()> {
let pgbouncer_connstr = if std::env::var_os("AUTOSCALING").is_some() {
// for VMs use pgbouncer specific way to connect to
// pgbouncer admin console without password
// when pgbouncer is running under the same user.
"host=/tmp port=6432 dbname=pgbouncer user=pgbouncer".to_string()
} else {
// for k8s use normal connection string with password
// to connect to pgbouncer admin console
let mut pgbouncer_connstr =
"host=localhost port=6432 dbname=pgbouncer user=postgres sslmode=disable".to_string();
if let Ok(pass) = std::env::var("PGBOUNCER_PASSWORD") {
pgbouncer_connstr.push_str(format!(" password={}", pass).as_str());
}
pgbouncer_connstr
};
info!(
"Connecting to pgbouncer with connection string: {}",
pgbouncer_connstr
);
// connect to pgbouncer, retrying several times
// because pgbouncer may not be ready yet
let mut retries = 3;
let client = loop {
match tokio_postgres::connect(&pgbouncer_connstr, NoTls).await {
Ok((client, connection)) => {
tokio::spawn(async move {
if let Err(e) = connection.await {
eprintln!("connection error: {}", e);
}
});
break client;
}
Err(e) => {
if retries == 0 {
return Err(e.into());
}
error!("Failed to connect to pgbouncer: pgbouncer_connstr {}", e);
retries -= 1;
tokio::time::sleep(Duration::from_secs(1)).await;
}
}
};
// Apply new config
for (option_name, value) in pgbouncer_config.iter() {
let query = format!("SET {}={}", option_name, value);
// keep this log line for debugging purposes
info!("Applying pgbouncer setting change: {}", query);
if let Err(err) = client.simple_query(&query).await {
// Don't fail on error, just print it into log
error!(
"Failed to apply pgbouncer setting change: {}, {}",
query, err
);
};
}
// save values to pgbouncer.ini
// so that they are preserved after pgbouncer restart
let pgbouncer_ini_path = if std::env::var_os("AUTOSCALING").is_some() {
// in VMs we use /etc/pgbouncer.ini
"/etc/pgbouncer.ini".to_string()
} else {
// in pods we use /var/db/postgres/pgbouncer/pgbouncer.ini
// this is a shared volume between pgbouncer and postgres containers
// FIXME: fix permissions for this file
"/var/db/postgres/pgbouncer/pgbouncer.ini".to_string()
};
update_pgbouncer_ini(pgbouncer_config, &pgbouncer_ini_path)?;
Ok(())
}
/// Spawn a thread that will read Postgres logs from `stderr`, join multiline logs
/// and send them to the logger. In the future we may also want to add context to
/// these logs.
pub fn handle_postgres_logs(stderr: std::process::ChildStderr) -> JoinHandle<()> {
std::thread::spawn(move || {
let runtime = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.expect("failed to build tokio runtime");
let res = runtime.block_on(async move {
let stderr = tokio::process::ChildStderr::from_std(stderr)?;
handle_postgres_logs_async(stderr).await
});
if let Err(e) = res {
tracing::error!("error while processing postgres logs: {}", e);
}
})
}
/// Read Postgres logs from `stderr` until EOF. Buffer is flushed on one of the following conditions:
/// - next line starts with timestamp
/// - EOF
/// - no new lines were written for the last second
async fn handle_postgres_logs_async(stderr: tokio::process::ChildStderr) -> Result<()> {
let mut lines = tokio::io::BufReader::new(stderr).lines();
let timeout_duration = Duration::from_millis(100);
let ts_regex =
regex::Regex::new(r"^\d+-\d{2}-\d{2} \d{2}:\d{2}:\d{2}").expect("regex is valid");
let mut buf = vec![];
loop {
let next_line = timeout(timeout_duration, lines.next_line()).await;
// we should flush lines from the buffer if we cannot continue reading multiline message
let should_flush_buf = match next_line {
// Flushing if new line starts with timestamp
Ok(Ok(Some(ref line))) => ts_regex.is_match(line),
// Flushing on EOF, timeout or error
_ => true,
};
if !buf.is_empty() && should_flush_buf {
// join multiline message into a single line, separated by unicode Zero Width Space.
// "PG:" suffix is used to distinguish postgres logs from other logs.
let combined = format!("PG:{}\n", buf.join("\u{200B}"));
buf.clear();
// sync write to stderr to avoid interleaving with other logs
use std::io::Write;
let res = std::io::stderr().lock().write_all(combined.as_bytes());
if let Err(e) = res {
tracing::error!("error while writing to stderr: {}", e);
}
}
// if not timeout, append line to the buffer
if next_line.is_ok() {
match next_line?? {
Some(line) => buf.push(line),
// EOF
None => break,
};
}
}
Ok(())
}

View File

@@ -9,7 +9,6 @@ use reqwest::StatusCode;
use tracing::{error, info, info_span, instrument, span_enabled, warn, Level};
use crate::config;
use crate::logger::inlinify;
use crate::params::PG_HBA_ALL_MD5;
use crate::pg_helpers::*;
@@ -190,20 +189,18 @@ pub fn handle_roles(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
// Print a list of existing Postgres roles (only in debug mode)
if span_enabled!(Level::INFO) {
let mut vec = Vec::new();
info!("postgres roles:");
for r in &existing_roles {
vec.push(format!(
"{}:{}",
info!(
" - {}:{}",
r.name,
if r.encrypted_password.is_some() {
"[FILTERED]"
} else {
"(null)"
}
));
);
}
info!("postgres roles (total {}): {:?}", vec.len(), vec);
}
// Process delta operations first
@@ -241,10 +238,7 @@ pub fn handle_roles(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
// Refresh Postgres roles info to handle possible roles renaming
let existing_roles: Vec<Role> = get_existing_roles(&mut xact)?;
info!(
"handling cluster spec roles (total {})",
spec.cluster.roles.len()
);
info!("cluster spec roles:");
for role in &spec.cluster.roles {
let name = &role.name;
// XXX: with a limited number of roles it is fine, but consider making it a HashMap
@@ -304,10 +298,10 @@ pub fn handle_roles(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
// safe to add more permissions here. BYPASSRLS and REPLICATION are inherited
// from neon_superuser.
let mut query: String = format!(
"CREATE ROLE {} INHERIT CREATEROLE CREATEDB BYPASSRLS REPLICATION IN ROLE neon_superuser",
"CREATE ROLE {} INHERIT CREATEROLE CREATEDB IN ROLE neon_superuser",
name.pg_quote()
);
info!("running role create query: '{}'", &query);
info!("role create query: '{}'", &query);
query.push_str(&role.to_pg_options());
xact.execute(query.as_str(), &[])?;
}
@@ -324,7 +318,7 @@ pub fn handle_roles(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
RoleAction::Create => " -> create",
RoleAction::Update => " -> update",
};
info!(" - {}:{}{}", name, pwd, action_str);
info!(" - {}:{}{}", name, pwd, action_str);
}
}
@@ -376,49 +370,33 @@ pub fn handle_role_deletions(spec: &ComputeSpec, connstr: &str, client: &mut Cli
Ok(())
}
fn reassign_owned_objects_in_one_db(
conf: Config,
role_name: &PgIdent,
db_owner: &PgIdent,
) -> Result<()> {
let mut client = conf.connect(NoTls)?;
// This will reassign all dependent objects to the db owner
let reassign_query = format!(
"REASSIGN OWNED BY {} TO {}",
role_name.pg_quote(),
db_owner.pg_quote()
);
info!(
"reassigning objects owned by '{}' in db '{}' to '{}'",
role_name,
conf.get_dbname().unwrap_or(""),
db_owner
);
client.simple_query(&reassign_query)?;
// This now will only drop privileges of the role
let drop_query = format!("DROP OWNED BY {}", role_name.pg_quote());
client.simple_query(&drop_query)?;
Ok(())
}
// Reassign all owned objects in all databases to the owner of the database.
fn reassign_owned_objects(spec: &ComputeSpec, connstr: &str, role_name: &PgIdent) -> Result<()> {
for db in &spec.cluster.databases {
if db.owner != *role_name {
let mut conf = Config::from_str(connstr)?;
conf.dbname(&db.name);
reassign_owned_objects_in_one_db(conf, role_name, &db.owner)?;
let mut client = conf.connect(NoTls)?;
// This will reassign all dependent objects to the db owner
let reassign_query = format!(
"REASSIGN OWNED BY {} TO {}",
role_name.pg_quote(),
db.owner.pg_quote()
);
info!(
"reassigning objects owned by '{}' in db '{}' to '{}'",
role_name, &db.name, &db.owner
);
client.simple_query(&reassign_query)?;
// This now will only drop privileges of the role
let drop_query = format!("DROP OWNED BY {}", role_name.pg_quote());
client.simple_query(&drop_query)?;
}
}
// Also handle case when there are no databases in the spec.
// In this case we need to reassign objects in the default database.
let conf = Config::from_str(connstr)?;
let db_owner = PgIdent::from_str("cloud_admin")?;
reassign_owned_objects_in_one_db(conf, role_name, &db_owner)?;
Ok(())
}
@@ -433,11 +411,10 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
// Print a list of existing Postgres databases (only in debug mode)
if span_enabled!(Level::INFO) {
let mut vec = Vec::new();
info!("postgres databases:");
for (dbname, db) in &existing_dbs {
vec.push(format!("{}:{}", dbname, db.owner));
info!(" {}:{}", dbname, db.owner);
}
info!("postgres databases (total {}): {:?}", vec.len(), vec);
}
// Process delta operations first
@@ -509,10 +486,7 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
// Refresh Postgres databases info to handle possible renames
let existing_dbs = get_existing_dbs(client)?;
info!(
"handling cluster spec databases (total {})",
spec.cluster.databases.len()
);
info!("cluster spec databases:");
for db in &spec.cluster.databases {
let name = &db.name;
let pg_db = existing_dbs.get(name);
@@ -571,7 +545,7 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
DatabaseAction::Create => " -> create",
DatabaseAction::Update => " -> update",
};
info!(" - {}:{}{}", db.name, db.owner, action_str);
info!(" - {}:{}{}", db.name, db.owner, action_str);
}
}
@@ -672,11 +646,7 @@ pub fn handle_grants(spec: &ComputeSpec, client: &mut Client, connstr: &str) ->
$$;"
.to_string();
info!(
"grant query for db {} : {}",
&db.name,
inlinify(&grant_query)
);
info!("grant query for db {} : {}", &db.name, &grant_query);
db_client.simple_query(&grant_query)?;
}

View File

@@ -6,11 +6,9 @@ license.workspace = true
[dependencies]
anyhow.workspace = true
async-trait.workspace = true
camino.workspace = true
clap.workspace = true
comfy-table.workspace = true
futures.workspace = true
git-version.workspace = true
nix.workspace = true
once_cell.workspace = true
@@ -26,11 +24,10 @@ tar.workspace = true
thiserror.workspace = true
toml.workspace = true
tokio.workspace = true
tokio-postgres.workspace = true
tokio-util.workspace = true
url.workspace = true
# Note: Do not directly depend on pageserver or safekeeper; use pageserver_api or safekeeper_api
# instead, so that recompile times are better.
pageserver_api.workspace = true
pageserver_client.workspace = true
postgres_backend.workspace = true
safekeeper_api.workspace = true
postgres_connection.workspace = true

View File

@@ -1,32 +0,0 @@
[package]
name = "attachment_service"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
anyhow.workspace = true
camino.workspace = true
clap.workspace = true
futures.workspace = true
git-version.workspace = true
hyper.workspace = true
pageserver_api.workspace = true
pageserver_client.workspace = true
postgres_connection.workspace = true
serde.workspace = true
serde_json.workspace = true
thiserror.workspace = true
tokio.workspace = true
tokio-util.workspace = true
tracing.workspace = true
# TODO: remove this after DB persistence is added, it is only used for
# a parsing function when loading pageservers from neon_local LocalEnv
postgres_backend.workspace = true
utils = { path = "../../libs/utils/" }
metrics = { path = "../../libs/metrics/" }
control_plane = { path = ".." }
workspace_hack = { version = "0.1", path = "../../workspace_hack" }

View File

@@ -1,116 +0,0 @@
use std::collections::HashMap;
use control_plane::endpoint::ComputeControlPlane;
use control_plane::local_env::LocalEnv;
use pageserver_api::shard::{ShardCount, ShardIndex, TenantShardId};
use postgres_connection::parse_host_port;
use utils::id::{NodeId, TenantId};
pub(super) struct ComputeHookTenant {
shards: Vec<(ShardIndex, NodeId)>,
}
impl ComputeHookTenant {
pub(super) async fn maybe_reconfigure(&mut self, tenant_id: TenantId) -> anyhow::Result<()> {
// Find the highest shard count and drop any shards that aren't
// for that shard count.
let shard_count = self.shards.iter().map(|(k, _v)| k.shard_count).max();
let Some(shard_count) = shard_count else {
// No shards, nothing to do.
tracing::info!("ComputeHookTenant::maybe_reconfigure: no shards");
return Ok(());
};
self.shards.retain(|(k, _v)| k.shard_count == shard_count);
self.shards
.sort_by_key(|(shard, _node_id)| shard.shard_number);
if self.shards.len() == shard_count.0 as usize || shard_count == ShardCount(0) {
// We have pageservers for all the shards: proceed to reconfigure compute
let env = match LocalEnv::load_config() {
Ok(e) => e,
Err(e) => {
tracing::warn!(
"Couldn't load neon_local config, skipping compute update ({e})"
);
return Ok(());
}
};
let cplane = ComputeControlPlane::load(env.clone())
.expect("Error loading compute control plane");
let compute_pageservers = self
.shards
.iter()
.map(|(_shard, node_id)| {
let ps_conf = env
.get_pageserver_conf(*node_id)
.expect("Unknown pageserver");
let (pg_host, pg_port) = parse_host_port(&ps_conf.listen_pg_addr)
.expect("Unable to parse listen_pg_addr");
(pg_host, pg_port.unwrap_or(5432))
})
.collect::<Vec<_>>();
for (endpoint_name, endpoint) in &cplane.endpoints {
if endpoint.tenant_id == tenant_id && endpoint.status() == "running" {
tracing::info!("🔁 Reconfiguring endpoint {}", endpoint_name,);
endpoint.reconfigure(compute_pageservers.clone()).await?;
}
}
} else {
tracing::info!(
"ComputeHookTenant::maybe_reconfigure: not enough shards ({}/{})",
self.shards.len(),
shard_count.0
);
}
Ok(())
}
}
/// The compute hook is a destination for notifications about changes to tenant:pageserver
/// mapping. It aggregates updates for the shards in a tenant, and when appropriate reconfigures
/// the compute connection string.
pub(super) struct ComputeHook {
state: tokio::sync::Mutex<HashMap<TenantId, ComputeHookTenant>>,
}
impl ComputeHook {
pub(super) fn new() -> Self {
Self {
state: Default::default(),
}
}
pub(super) async fn notify(
&self,
tenant_shard_id: TenantShardId,
node_id: NodeId,
) -> anyhow::Result<()> {
tracing::info!("ComputeHook::notify: {}->{}", tenant_shard_id, node_id);
let mut locked = self.state.lock().await;
let entry = locked
.entry(tenant_shard_id.tenant_id)
.or_insert_with(|| ComputeHookTenant { shards: Vec::new() });
let shard_index = ShardIndex {
shard_count: tenant_shard_id.shard_count,
shard_number: tenant_shard_id.shard_number,
};
let mut set = false;
for (existing_shard, existing_node) in &mut entry.shards {
if *existing_shard == shard_index {
*existing_node = node_id;
set = true;
}
}
if !set {
entry.shards.push((shard_index, node_id));
}
entry.maybe_reconfigure(tenant_shard_id.tenant_id).await
}
}

View File

@@ -1,218 +0,0 @@
use crate::reconciler::ReconcileError;
use crate::service::Service;
use hyper::{Body, Request, Response};
use hyper::{StatusCode, Uri};
use pageserver_api::models::{TenantCreateRequest, TimelineCreateRequest};
use pageserver_api::shard::TenantShardId;
use std::sync::Arc;
use utils::auth::SwappableJwtAuth;
use utils::http::endpoint::{auth_middleware, request_span};
use utils::http::request::parse_request_param;
use utils::id::TenantId;
use utils::{
http::{
endpoint::{self},
error::ApiError,
json::{json_request, json_response},
RequestExt, RouterBuilder,
},
id::NodeId,
};
use pageserver_api::control_api::{ReAttachRequest, ValidateRequest};
use control_plane::attachment_service::{
AttachHookRequest, InspectRequest, NodeConfigureRequest, NodeRegisterRequest,
TenantShardMigrateRequest,
};
/// State available to HTTP request handlers
#[derive(Clone)]
pub struct HttpState {
service: Arc<crate::service::Service>,
auth: Option<Arc<SwappableJwtAuth>>,
allowlist_routes: Vec<Uri>,
}
impl HttpState {
pub fn new(service: Arc<crate::service::Service>, auth: Option<Arc<SwappableJwtAuth>>) -> Self {
let allowlist_routes = ["/status"]
.iter()
.map(|v| v.parse().unwrap())
.collect::<Vec<_>>();
Self {
service,
auth,
allowlist_routes,
}
}
}
#[inline(always)]
fn get_state(request: &Request<Body>) -> &HttpState {
request
.data::<Arc<HttpState>>()
.expect("unknown state type")
.as_ref()
}
/// Pageserver calls into this on startup, to learn which tenants it should attach
async fn handle_re_attach(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
let reattach_req = json_request::<ReAttachRequest>(&mut req).await?;
let state = get_state(&req);
json_response(
StatusCode::OK,
state
.service
.re_attach(reattach_req)
.await
.map_err(ApiError::InternalServerError)?,
)
}
/// Pageserver calls into this before doing deletions, to confirm that it still
/// holds the latest generation for the tenants with deletions enqueued
async fn handle_validate(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
let validate_req = json_request::<ValidateRequest>(&mut req).await?;
let state = get_state(&req);
json_response(StatusCode::OK, state.service.validate(validate_req))
}
/// Call into this before attaching a tenant to a pageserver, to acquire a generation number
/// (in the real control plane this is unnecessary, because the same program is managing
/// generation numbers and doing attachments).
async fn handle_attach_hook(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
let attach_req = json_request::<AttachHookRequest>(&mut req).await?;
let state = get_state(&req);
json_response(
StatusCode::OK,
state
.service
.attach_hook(attach_req)
.await
.map_err(ApiError::InternalServerError)?,
)
}
async fn handle_inspect(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
let inspect_req = json_request::<InspectRequest>(&mut req).await?;
let state = get_state(&req);
json_response(StatusCode::OK, state.service.inspect(inspect_req))
}
async fn handle_tenant_create(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
let create_req = json_request::<TenantCreateRequest>(&mut req).await?;
let state = get_state(&req);
json_response(
StatusCode::OK,
state.service.tenant_create(create_req).await?,
)
}
async fn handle_tenant_timeline_create(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
let create_req = json_request::<TimelineCreateRequest>(&mut req).await?;
let state = get_state(&req);
json_response(
StatusCode::OK,
state
.service
.tenant_timeline_create(tenant_id, create_req)
.await?,
)
}
async fn handle_tenant_locate(req: Request<Body>) -> Result<Response<Body>, ApiError> {
let tenant_id: TenantId = parse_request_param(&req, "tenant_id")?;
let state = get_state(&req);
json_response(StatusCode::OK, state.service.tenant_locate(tenant_id)?)
}
async fn handle_node_register(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
let register_req = json_request::<NodeRegisterRequest>(&mut req).await?;
let state = get_state(&req);
state.service.node_register(register_req).await?;
json_response(StatusCode::OK, ())
}
async fn handle_node_configure(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
let node_id: NodeId = parse_request_param(&req, "node_id")?;
let config_req = json_request::<NodeConfigureRequest>(&mut req).await?;
if node_id != config_req.node_id {
return Err(ApiError::BadRequest(anyhow::anyhow!(
"Path and body node_id differ"
)));
}
let state = get_state(&req);
json_response(StatusCode::OK, state.service.node_configure(config_req)?)
}
async fn handle_tenant_shard_migrate(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
let tenant_shard_id: TenantShardId = parse_request_param(&req, "tenant_shard_id")?;
let migrate_req = json_request::<TenantShardMigrateRequest>(&mut req).await?;
let state = get_state(&req);
json_response(
StatusCode::OK,
state
.service
.tenant_shard_migrate(tenant_shard_id, migrate_req)
.await?,
)
}
/// Status endpoint is just used for checking that our HTTP listener is up
async fn handle_status(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
json_response(StatusCode::OK, ())
}
impl From<ReconcileError> for ApiError {
fn from(value: ReconcileError) -> Self {
ApiError::Conflict(format!("Reconciliation error: {}", value))
}
}
pub fn make_router(
service: Arc<Service>,
auth: Option<Arc<SwappableJwtAuth>>,
) -> RouterBuilder<hyper::Body, ApiError> {
let mut router = endpoint::make_router();
if auth.is_some() {
router = router.middleware(auth_middleware(|request| {
let state = get_state(request);
if state.allowlist_routes.contains(request.uri()) {
None
} else {
state.auth.as_deref()
}
}))
}
router
.data(Arc::new(HttpState::new(service, auth)))
.get("/status", |r| request_span(r, handle_status))
.post("/re-attach", |r| request_span(r, handle_re_attach))
.post("/validate", |r| request_span(r, handle_validate))
.post("/attach-hook", |r| request_span(r, handle_attach_hook))
.post("/inspect", |r| request_span(r, handle_inspect))
.post("/node", |r| request_span(r, handle_node_register))
.put("/node/:node_id/config", |r| {
request_span(r, handle_node_configure)
})
.post("/tenant", |r| request_span(r, handle_tenant_create))
.post("/tenant/:tenant_id/timeline", |r| {
request_span(r, handle_tenant_timeline_create)
})
.get("/tenant/:tenant_id/locate", |r| {
request_span(r, handle_tenant_locate)
})
.put("/tenant/:tenant_shard_id/migrate", |r| {
request_span(r, handle_tenant_shard_migrate)
})
}

View File

@@ -1,57 +0,0 @@
use serde::{Deserialize, Serialize};
use utils::seqwait::MonotonicCounter;
mod compute_hook;
pub mod http;
mod node;
pub mod persistence;
mod reconciler;
mod scheduler;
pub mod service;
mod tenant_state;
#[derive(Clone, Serialize, Deserialize)]
enum PlacementPolicy {
/// Cheapest way to attach a tenant: just one pageserver, no secondary
Single,
/// Production-ready way to attach a tenant: one attached pageserver and
/// some number of secondaries.
Double(usize),
}
#[derive(Ord, PartialOrd, Eq, PartialEq, Copy, Clone)]
struct Sequence(u64);
impl Sequence {
fn initial() -> Self {
Self(0)
}
}
impl std::fmt::Display for Sequence {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl MonotonicCounter<Sequence> for Sequence {
fn cnt_advance(&mut self, v: Sequence) {
assert!(*self <= v);
*self = v;
}
fn cnt_value(&self) -> Sequence {
*self
}
}
impl Sequence {
fn next(&self) -> Sequence {
Sequence(self.0 + 1)
}
}
impl Default for PlacementPolicy {
fn default() -> Self {
PlacementPolicy::Double(1)
}
}

View File

@@ -1,100 +0,0 @@
/// The attachment service mimics the aspects of the control plane API
/// that are required for a pageserver to operate.
///
/// This enables running & testing pageservers without a full-blown
/// deployment of the Neon cloud platform.
///
use anyhow::anyhow;
use attachment_service::http::make_router;
use attachment_service::persistence::Persistence;
use attachment_service::service::{Config, Service};
use camino::Utf8PathBuf;
use clap::Parser;
use metrics::launch_timestamp::LaunchTimestamp;
use std::sync::Arc;
use utils::auth::{JwtAuth, SwappableJwtAuth};
use utils::logging::{self, LogFormat};
use utils::signals::{ShutdownSignals, Signal};
use utils::{project_build_tag, project_git_version, tcp_listener};
project_git_version!(GIT_VERSION);
project_build_tag!(BUILD_TAG);
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
#[command(arg_required_else_help(true))]
struct Cli {
/// Host and port to listen on, like `127.0.0.1:1234`
#[arg(short, long)]
listen: std::net::SocketAddr,
/// Path to public key for JWT authentication of clients
#[arg(long)]
public_key: Option<camino::Utf8PathBuf>,
/// Token for authenticating this service with the pageservers it controls
#[arg(short, long)]
jwt_token: Option<String>,
/// Path to the .json file to store state (will be created if it doesn't exist)
#[arg(short, long)]
path: Utf8PathBuf,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let launch_ts = Box::leak(Box::new(LaunchTimestamp::generate()));
logging::init(
LogFormat::Plain,
logging::TracingErrorLayerEnablement::Disabled,
logging::Output::Stdout,
)?;
let args = Cli::parse();
tracing::info!(
"version: {}, launch_timestamp: {}, build_tag {}, state at {}, listening on {}",
GIT_VERSION,
launch_ts.to_string(),
BUILD_TAG,
args.path,
args.listen
);
let config = Config {
jwt_token: args.jwt_token,
};
let persistence = Arc::new(Persistence::new(&args.path).await);
let service = Service::spawn(config, persistence).await?;
let http_listener = tcp_listener::bind(args.listen)?;
let auth = if let Some(public_key_path) = &args.public_key {
let jwt_auth = JwtAuth::from_key_path(public_key_path)?;
Some(Arc::new(SwappableJwtAuth::new(jwt_auth)))
} else {
None
};
let router = make_router(service, auth)
.build()
.map_err(|err| anyhow!(err))?;
let service = utils::http::RouterService::new(router).unwrap();
let server = hyper::Server::from_tcp(http_listener)?.serve(service);
tracing::info!("Serving on {0}", args.listen);
tokio::task::spawn(server);
ShutdownSignals::handle(|signal| match signal {
Signal::Interrupt | Signal::Terminate | Signal::Quit => {
tracing::info!("Got {}. Terminating", signal.name());
// We're just a test helper: no graceful shutdown.
std::process::exit(0);
}
})?;
Ok(())
}

View File

@@ -1,37 +0,0 @@
use control_plane::attachment_service::{NodeAvailability, NodeSchedulingPolicy};
use utils::id::NodeId;
#[derive(Clone)]
pub(crate) struct Node {
pub(crate) id: NodeId,
pub(crate) availability: NodeAvailability,
pub(crate) scheduling: NodeSchedulingPolicy,
pub(crate) listen_http_addr: String,
pub(crate) listen_http_port: u16,
pub(crate) listen_pg_addr: String,
pub(crate) listen_pg_port: u16,
}
impl Node {
pub(crate) fn base_url(&self) -> String {
format!("http://{}:{}", self.listen_http_addr, self.listen_http_port)
}
/// Is this node elegible to have work scheduled onto it?
pub(crate) fn may_schedule(&self) -> bool {
match self.availability {
NodeAvailability::Active => {}
NodeAvailability::Offline => return false,
}
match self.scheduling {
NodeSchedulingPolicy::Active => true,
NodeSchedulingPolicy::Draining => false,
NodeSchedulingPolicy::Filling => true,
NodeSchedulingPolicy::Pause => false,
}
}
}

View File

@@ -1,272 +0,0 @@
use std::{collections::HashMap, str::FromStr};
use camino::{Utf8Path, Utf8PathBuf};
use control_plane::{
attachment_service::{NodeAvailability, NodeSchedulingPolicy},
local_env::LocalEnv,
};
use pageserver_api::{
models::TenantConfig,
shard::{ShardCount, ShardNumber, TenantShardId},
};
use postgres_connection::parse_host_port;
use serde::{Deserialize, Serialize};
use utils::{
generation::Generation,
id::{NodeId, TenantId},
};
use crate::{node::Node, PlacementPolicy};
/// Placeholder for storage. This will be replaced with a database client.
pub struct Persistence {
state: std::sync::Mutex<PersistentState>,
}
// Top level state available to all HTTP handlers
#[derive(Serialize, Deserialize)]
struct PersistentState {
tenants: HashMap<TenantShardId, TenantShardPersistence>,
#[serde(skip)]
path: Utf8PathBuf,
}
/// A convenience for serializing the state inside a sync lock, and then
/// writing it to disk outside of the lock. This will go away when switching
/// to a database backend.
struct PendingWrite {
bytes: Vec<u8>,
path: Utf8PathBuf,
}
impl PendingWrite {
async fn commit(&self) -> anyhow::Result<()> {
tokio::fs::write(&self.path, &self.bytes).await?;
Ok(())
}
}
impl PersistentState {
fn save(&self) -> PendingWrite {
PendingWrite {
bytes: serde_json::to_vec(self).expect("Serialization error"),
path: self.path.clone(),
}
}
async fn load(path: &Utf8Path) -> anyhow::Result<Self> {
let bytes = tokio::fs::read(path).await?;
let mut decoded = serde_json::from_slice::<Self>(&bytes)?;
decoded.path = path.to_owned();
for (tenant_id, tenant) in &mut decoded.tenants {
// Backward compat: an old attachments.json from before PR #6251, replace
// empty strings with proper defaults.
if tenant.tenant_id.is_empty() {
tenant.tenant_id = format!("{}", tenant_id);
tenant.config = serde_json::to_string(&TenantConfig::default())?;
tenant.placement_policy = serde_json::to_string(&PlacementPolicy::default())?;
}
}
Ok(decoded)
}
async fn load_or_new(path: &Utf8Path) -> Self {
match Self::load(path).await {
Ok(s) => {
tracing::info!("Loaded state file at {}", path);
s
}
Err(e)
if e.downcast_ref::<std::io::Error>()
.map(|e| e.kind() == std::io::ErrorKind::NotFound)
.unwrap_or(false) =>
{
tracing::info!("Will create state file at {}", path);
Self {
tenants: HashMap::new(),
path: path.to_owned(),
}
}
Err(e) => {
panic!("Failed to load state from '{}': {e:#} (maybe your .neon/ dir was written by an older version?)", path)
}
}
}
}
impl Persistence {
pub async fn new(path: &Utf8Path) -> Self {
let state = PersistentState::load_or_new(path).await;
Self {
state: std::sync::Mutex::new(state),
}
}
/// When registering a node, persist it so that on next start we will be able to
/// iterate over known nodes to synchronize their tenant shard states with our observed state.
pub(crate) async fn insert_node(&self, _node: &Node) -> anyhow::Result<()> {
// TODO: node persitence will come with database backend
Ok(())
}
/// At startup, we populate the service's list of nodes, and use this list to call into
/// each node to do an initial reconciliation of the state of the world with our in-memory
/// observed state.
pub(crate) async fn list_nodes(&self) -> anyhow::Result<Vec<Node>> {
let env = LocalEnv::load_config()?;
// TODO: node persitence will come with database backend
// XXX hack: enable test_backward_compatibility to work by populating our list of
// nodes from LocalEnv when it is not present in persistent storage. Otherwise at
// first startup in the compat test, we may have shards but no nodes.
let mut result = Vec::new();
tracing::info!(
"Loaded {} pageserver nodes from LocalEnv",
env.pageservers.len()
);
for ps_conf in env.pageservers {
let (pg_host, pg_port) =
parse_host_port(&ps_conf.listen_pg_addr).expect("Unable to parse listen_pg_addr");
let (http_host, http_port) = parse_host_port(&ps_conf.listen_http_addr)
.expect("Unable to parse listen_http_addr");
result.push(Node {
id: ps_conf.id,
listen_pg_addr: pg_host.to_string(),
listen_pg_port: pg_port.unwrap_or(5432),
listen_http_addr: http_host.to_string(),
listen_http_port: http_port.unwrap_or(80),
availability: NodeAvailability::Active,
scheduling: NodeSchedulingPolicy::Active,
});
}
Ok(result)
}
/// At startup, we populate our map of tenant shards from persistent storage.
pub(crate) async fn list_tenant_shards(&self) -> anyhow::Result<Vec<TenantShardPersistence>> {
let locked = self.state.lock().unwrap();
Ok(locked.tenants.values().cloned().collect())
}
/// Tenants must be persisted before we schedule them for the first time. This enables us
/// to correctly retain generation monotonicity, and the externally provided placement policy & config.
pub(crate) async fn insert_tenant_shards(
&self,
shards: Vec<TenantShardPersistence>,
) -> anyhow::Result<()> {
let write = {
let mut locked = self.state.lock().unwrap();
for shard in shards {
let tenant_shard_id = TenantShardId {
tenant_id: TenantId::from_str(shard.tenant_id.as_str())?,
shard_number: ShardNumber(shard.shard_number as u8),
shard_count: ShardCount(shard.shard_count as u8),
};
locked.tenants.insert(tenant_shard_id, shard);
}
locked.save()
};
write.commit().await?;
Ok(())
}
/// Reconciler calls this immediately before attaching to a new pageserver, to acquire a unique, monotonically
/// advancing generation number. We also store the NodeId for which the generation was issued, so that in
/// [`Self::re_attach`] we can do a bulk UPDATE on the generations for that node.
pub(crate) async fn increment_generation(
&self,
tenant_shard_id: TenantShardId,
node_id: Option<NodeId>,
) -> anyhow::Result<Generation> {
let (write, gen) = {
let mut locked = self.state.lock().unwrap();
let Some(shard) = locked.tenants.get_mut(&tenant_shard_id) else {
anyhow::bail!("Tried to increment generation of unknown shard");
};
// If we're called with a None pageserver, we need only update the generation
// record to disassociate it with this pageserver, not actually increment the number, as
// the increment is guaranteed to happen the next time this tenant is attached.
if node_id.is_some() {
shard.generation += 1;
}
shard.generation_pageserver = node_id;
let gen = Generation::new(shard.generation);
(locked.save(), gen)
};
write.commit().await?;
Ok(gen)
}
pub(crate) async fn re_attach(
&self,
node_id: NodeId,
) -> anyhow::Result<HashMap<TenantShardId, Generation>> {
let (write, result) = {
let mut result = HashMap::new();
let mut locked = self.state.lock().unwrap();
for (tenant_shard_id, shard) in locked.tenants.iter_mut() {
if shard.generation_pageserver == Some(node_id) {
shard.generation += 1;
result.insert(*tenant_shard_id, Generation::new(shard.generation));
}
}
(locked.save(), result)
};
write.commit().await?;
Ok(result)
}
// TODO: when we start shard splitting, we must durably mark the tenant so that
// on restart, we know that we must go through recovery (list shards that exist
// and pick up where we left off and/or revert to parent shards).
#[allow(dead_code)]
pub(crate) async fn begin_shard_split(&self, _tenant_id: TenantId) -> anyhow::Result<()> {
todo!();
}
// TODO: when we finish shard splitting, we must atomically clean up the old shards
// and insert the new shards, and clear the splitting marker.
#[allow(dead_code)]
pub(crate) async fn complete_shard_split(&self, _tenant_id: TenantId) -> anyhow::Result<()> {
todo!();
}
}
/// Parts of [`crate::tenant_state::TenantState`] that are stored durably
#[derive(Serialize, Deserialize, Clone)]
pub(crate) struct TenantShardPersistence {
#[serde(default)]
pub(crate) tenant_id: String,
#[serde(default)]
pub(crate) shard_number: i32,
#[serde(default)]
pub(crate) shard_count: i32,
#[serde(default)]
pub(crate) shard_stripe_size: i32,
// Currently attached pageserver
#[serde(rename = "pageserver")]
pub(crate) generation_pageserver: Option<NodeId>,
// Latest generation number: next time we attach, increment this
// and use the incremented number when attaching
pub(crate) generation: u32,
#[serde(default)]
pub(crate) placement_policy: String,
#[serde(default)]
pub(crate) config: String,
}

View File

@@ -1,495 +0,0 @@
use crate::persistence::Persistence;
use crate::service;
use control_plane::attachment_service::NodeAvailability;
use pageserver_api::models::{
LocationConfig, LocationConfigMode, LocationConfigSecondary, TenantConfig,
};
use pageserver_api::shard::{ShardIdentity, TenantShardId};
use pageserver_client::mgmt_api;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio_util::sync::CancellationToken;
use utils::generation::Generation;
use utils::id::{NodeId, TimelineId};
use utils::lsn::Lsn;
use crate::compute_hook::ComputeHook;
use crate::node::Node;
use crate::tenant_state::{IntentState, ObservedState, ObservedStateLocation};
/// Object with the lifetime of the background reconcile task that is created
/// for tenants which have a difference between their intent and observed states.
pub(super) struct Reconciler {
/// See [`crate::tenant_state::TenantState`] for the meanings of these fields: they are a snapshot
/// of a tenant's state from when we spawned a reconcile task.
pub(super) tenant_shard_id: TenantShardId,
pub(crate) shard: ShardIdentity,
pub(crate) generation: Generation,
pub(crate) intent: IntentState,
pub(crate) config: TenantConfig,
pub(crate) observed: ObservedState,
pub(crate) service_config: service::Config,
/// A snapshot of the pageservers as they were when we were asked
/// to reconcile.
pub(crate) pageservers: Arc<HashMap<NodeId, Node>>,
/// A hook to notify the running postgres instances when we change the location
/// of a tenant
pub(crate) compute_hook: Arc<ComputeHook>,
/// A means to abort background reconciliation: it is essential to
/// call this when something changes in the original TenantState that
/// will make this reconciliation impossible or unnecessary, for
/// example when a pageserver node goes offline, or the PlacementPolicy for
/// the tenant is changed.
pub(crate) cancel: CancellationToken,
/// Access to persistent storage for updating generation numbers
pub(crate) persistence: Arc<Persistence>,
}
#[derive(thiserror::Error, Debug)]
pub enum ReconcileError {
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl Reconciler {
async fn location_config(
&mut self,
node_id: NodeId,
config: LocationConfig,
flush_ms: Option<Duration>,
) -> anyhow::Result<()> {
let node = self
.pageservers
.get(&node_id)
.expect("Pageserver may not be removed while referenced");
self.observed
.locations
.insert(node.id, ObservedStateLocation { conf: None });
tracing::info!("location_config({}) calling: {:?}", node_id, config);
let client =
mgmt_api::Client::new(node.base_url(), self.service_config.jwt_token.as_deref());
client
.location_config(self.tenant_shard_id, config.clone(), flush_ms)
.await?;
tracing::info!("location_config({}) complete: {:?}", node_id, config);
self.observed
.locations
.insert(node.id, ObservedStateLocation { conf: Some(config) });
Ok(())
}
async fn maybe_live_migrate(&mut self) -> Result<(), ReconcileError> {
let destination = if let Some(node_id) = self.intent.attached {
match self.observed.locations.get(&node_id) {
Some(conf) => {
// We will do a live migration only if the intended destination is not
// currently in an attached state.
match &conf.conf {
Some(conf) if conf.mode == LocationConfigMode::Secondary => {
// Fall through to do a live migration
node_id
}
None | Some(_) => {
// Attached or uncertain: don't do a live migration, proceed
// with a general-case reconciliation
tracing::info!("maybe_live_migrate: destination is None or attached");
return Ok(());
}
}
}
None => {
// Our destination is not attached: maybe live migrate if some other
// node is currently attached. Fall through.
node_id
}
}
} else {
// No intent to be attached
tracing::info!("maybe_live_migrate: no attached intent");
return Ok(());
};
let mut origin = None;
for (node_id, state) in &self.observed.locations {
if let Some(observed_conf) = &state.conf {
if observed_conf.mode == LocationConfigMode::AttachedSingle {
let node = self
.pageservers
.get(node_id)
.expect("Nodes may not be removed while referenced");
// We will only attempt live migration if the origin is not offline: this
// avoids trying to do it while reconciling after responding to an HA failover.
if !matches!(node.availability, NodeAvailability::Offline) {
origin = Some(*node_id);
break;
}
}
}
}
let Some(origin) = origin else {
tracing::info!("maybe_live_migrate: no origin found");
return Ok(());
};
// We have an origin and a destination: proceed to do the live migration
tracing::info!("Live migrating {}->{}", origin, destination);
self.live_migrate(origin, destination).await?;
Ok(())
}
async fn get_lsns(
&self,
tenant_shard_id: TenantShardId,
node_id: &NodeId,
) -> anyhow::Result<HashMap<TimelineId, Lsn>> {
let node = self
.pageservers
.get(node_id)
.expect("Pageserver may not be removed while referenced");
let client =
mgmt_api::Client::new(node.base_url(), self.service_config.jwt_token.as_deref());
let timelines = client.timeline_list(&tenant_shard_id).await?;
Ok(timelines
.into_iter()
.map(|t| (t.timeline_id, t.last_record_lsn))
.collect())
}
async fn secondary_download(&self, tenant_shard_id: TenantShardId, node_id: &NodeId) {
let node = self
.pageservers
.get(node_id)
.expect("Pageserver may not be removed while referenced");
let client =
mgmt_api::Client::new(node.base_url(), self.service_config.jwt_token.as_deref());
match client.tenant_secondary_download(tenant_shard_id).await {
Ok(()) => {}
Err(_) => {
tracing::info!(" (skipping, destination wasn't in secondary mode)")
}
}
}
async fn await_lsn(
&self,
tenant_shard_id: TenantShardId,
pageserver_id: &NodeId,
baseline: HashMap<TimelineId, Lsn>,
) -> anyhow::Result<()> {
loop {
let latest = match self.get_lsns(tenant_shard_id, pageserver_id).await {
Ok(l) => l,
Err(e) => {
println!(
"🕑 Can't get LSNs on pageserver {} yet, waiting ({e})",
pageserver_id
);
std::thread::sleep(Duration::from_millis(500));
continue;
}
};
let mut any_behind: bool = false;
for (timeline_id, baseline_lsn) in &baseline {
match latest.get(timeline_id) {
Some(latest_lsn) => {
println!("🕑 LSN origin {baseline_lsn} vs destination {latest_lsn}");
if latest_lsn < baseline_lsn {
any_behind = true;
}
}
None => {
// Expected timeline isn't yet visible on migration destination.
// (IRL we would have to account for timeline deletion, but this
// is just test helper)
any_behind = true;
}
}
}
if !any_behind {
println!("✅ LSN caught up. Proceeding...");
break;
} else {
std::thread::sleep(Duration::from_millis(500));
}
}
Ok(())
}
pub async fn live_migrate(
&mut self,
origin_ps_id: NodeId,
dest_ps_id: NodeId,
) -> anyhow::Result<()> {
// `maybe_live_migrate` is responsibble for sanity of inputs
assert!(origin_ps_id != dest_ps_id);
fn build_location_config(
shard: &ShardIdentity,
config: &TenantConfig,
mode: LocationConfigMode,
generation: Option<Generation>,
secondary_conf: Option<LocationConfigSecondary>,
) -> LocationConfig {
LocationConfig {
mode,
generation: generation.map(|g| g.into().unwrap()),
secondary_conf,
tenant_conf: config.clone(),
shard_number: shard.number.0,
shard_count: shard.count.0,
shard_stripe_size: shard.stripe_size.0,
}
}
tracing::info!(
"🔁 Switching origin pageserver {} to stale mode",
origin_ps_id
);
// FIXME: it is incorrect to use self.generation here, we should use the generation
// from the ObservedState of the origin pageserver (it might be older than self.generation)
let stale_conf = build_location_config(
&self.shard,
&self.config,
LocationConfigMode::AttachedStale,
Some(self.generation),
None,
);
self.location_config(origin_ps_id, stale_conf, Some(Duration::from_secs(10)))
.await?;
let baseline_lsns = Some(self.get_lsns(self.tenant_shard_id, &origin_ps_id).await?);
// If we are migrating to a destination that has a secondary location, warm it up first
if let Some(destination_conf) = self.observed.locations.get(&dest_ps_id) {
if let Some(destination_conf) = &destination_conf.conf {
if destination_conf.mode == LocationConfigMode::Secondary {
tracing::info!(
"🔁 Downloading latest layers to destination pageserver {}",
dest_ps_id,
);
self.secondary_download(self.tenant_shard_id, &dest_ps_id)
.await;
}
}
}
// Increment generation before attaching to new pageserver
self.generation = self
.persistence
.increment_generation(self.tenant_shard_id, Some(dest_ps_id))
.await?;
let dest_conf = build_location_config(
&self.shard,
&self.config,
LocationConfigMode::AttachedMulti,
Some(self.generation),
None,
);
tracing::info!("🔁 Attaching to pageserver {}", dest_ps_id);
self.location_config(dest_ps_id, dest_conf, None).await?;
if let Some(baseline) = baseline_lsns {
tracing::info!("🕑 Waiting for LSN to catch up...");
self.await_lsn(self.tenant_shard_id, &dest_ps_id, baseline)
.await?;
}
tracing::info!("🔁 Notifying compute to use pageserver {}", dest_ps_id);
self.compute_hook
.notify(self.tenant_shard_id, dest_ps_id)
.await?;
// Downgrade the origin to secondary. If the tenant's policy is PlacementPolicy::Single, then
// this location will be deleted in the general case reconciliation that runs after this.
let origin_secondary_conf = build_location_config(
&self.shard,
&self.config,
LocationConfigMode::Secondary,
None,
Some(LocationConfigSecondary { warm: true }),
);
self.location_config(origin_ps_id, origin_secondary_conf.clone(), None)
.await?;
// TODO: we should also be setting the ObservedState on earlier API calls, in case we fail
// partway through. In fact, all location conf API calls should be in a wrapper that sets
// the observed state to None, then runs, then sets it to what we wrote.
self.observed.locations.insert(
origin_ps_id,
ObservedStateLocation {
conf: Some(origin_secondary_conf),
},
);
println!(
"🔁 Switching to AttachedSingle mode on pageserver {}",
dest_ps_id
);
let dest_final_conf = build_location_config(
&self.shard,
&self.config,
LocationConfigMode::AttachedSingle,
Some(self.generation),
None,
);
self.location_config(dest_ps_id, dest_final_conf.clone(), None)
.await?;
self.observed.locations.insert(
dest_ps_id,
ObservedStateLocation {
conf: Some(dest_final_conf),
},
);
println!("✅ Migration complete");
Ok(())
}
/// Reconciling a tenant makes API calls to pageservers until the observed state
/// matches the intended state.
///
/// First we apply special case handling (e.g. for live migrations), and then a
/// general case reconciliation where we walk through the intent by pageserver
/// and call out to the pageserver to apply the desired state.
pub(crate) async fn reconcile(&mut self) -> Result<(), ReconcileError> {
// TODO: if any of self.observed is None, call to remote pageservers
// to learn correct state.
// Special case: live migration
self.maybe_live_migrate().await?;
// If the attached pageserver is not attached, do so now.
if let Some(node_id) = self.intent.attached {
let mut wanted_conf =
attached_location_conf(self.generation, &self.shard, &self.config);
match self.observed.locations.get(&node_id) {
Some(conf) if conf.conf.as_ref() == Some(&wanted_conf) => {
// Nothing to do
tracing::info!("Observed configuration already correct.")
}
_ => {
// In all cases other than a matching observed configuration, we will
// reconcile this location. This includes locations with different configurations, as well
// as locations with unknown (None) observed state.
self.generation = self
.persistence
.increment_generation(self.tenant_shard_id, Some(node_id))
.await?;
wanted_conf.generation = self.generation.into();
tracing::info!("Observed configuration requires update.");
self.location_config(node_id, wanted_conf, None).await?;
if let Err(e) = self
.compute_hook
.notify(self.tenant_shard_id, node_id)
.await
{
tracing::warn!(
"Failed to notify compute of newly attached pageserver {node_id}: {e}"
);
}
}
}
}
// Configure secondary locations: if these were previously attached this
// implicitly downgrades them from attached to secondary.
let mut changes = Vec::new();
for node_id in &self.intent.secondary {
let wanted_conf = secondary_location_conf(&self.shard, &self.config);
match self.observed.locations.get(node_id) {
Some(conf) if conf.conf.as_ref() == Some(&wanted_conf) => {
// Nothing to do
tracing::info!(%node_id, "Observed configuration already correct.")
}
_ => {
// In all cases other than a matching observed configuration, we will
// reconcile this location.
tracing::info!(%node_id, "Observed configuration requires update.");
changes.push((*node_id, wanted_conf))
}
}
}
// Detach any extraneous pageservers that are no longer referenced
// by our intent.
let all_pageservers = self.intent.all_pageservers();
for node_id in self.observed.locations.keys() {
if all_pageservers.contains(node_id) {
// We are only detaching pageservers that aren't used at all.
continue;
}
changes.push((
*node_id,
LocationConfig {
mode: LocationConfigMode::Detached,
generation: None,
secondary_conf: None,
shard_number: self.shard.number.0,
shard_count: self.shard.count.0,
shard_stripe_size: self.shard.stripe_size.0,
tenant_conf: self.config.clone(),
},
));
}
for (node_id, conf) in changes {
self.location_config(node_id, conf, None).await?;
}
Ok(())
}
}
pub(crate) fn attached_location_conf(
generation: Generation,
shard: &ShardIdentity,
config: &TenantConfig,
) -> LocationConfig {
LocationConfig {
mode: LocationConfigMode::AttachedSingle,
generation: generation.into(),
secondary_conf: None,
shard_number: shard.number.0,
shard_count: shard.count.0,
shard_stripe_size: shard.stripe_size.0,
tenant_conf: config.clone(),
}
}
pub(crate) fn secondary_location_conf(
shard: &ShardIdentity,
config: &TenantConfig,
) -> LocationConfig {
LocationConfig {
mode: LocationConfigMode::Secondary,
generation: None,
secondary_conf: Some(LocationConfigSecondary { warm: true }),
shard_number: shard.number.0,
shard_count: shard.count.0,
shard_stripe_size: shard.stripe_size.0,
tenant_conf: config.clone(),
}
}

View File

@@ -1,89 +0,0 @@
use pageserver_api::shard::TenantShardId;
use std::collections::{BTreeMap, HashMap};
use utils::{http::error::ApiError, id::NodeId};
use crate::{node::Node, tenant_state::TenantState};
/// Scenarios in which we cannot find a suitable location for a tenant shard
#[derive(thiserror::Error, Debug)]
pub enum ScheduleError {
#[error("No pageservers found")]
NoPageservers,
#[error("No pageserver found matching constraint")]
ImpossibleConstraint,
}
impl From<ScheduleError> for ApiError {
fn from(value: ScheduleError) -> Self {
ApiError::Conflict(format!("Scheduling error: {}", value))
}
}
pub(crate) struct Scheduler {
tenant_counts: HashMap<NodeId, usize>,
}
impl Scheduler {
pub(crate) fn new(
tenants: &BTreeMap<TenantShardId, TenantState>,
nodes: &HashMap<NodeId, Node>,
) -> Self {
let mut tenant_counts = HashMap::new();
for node_id in nodes.keys() {
tenant_counts.insert(*node_id, 0);
}
for tenant in tenants.values() {
if let Some(ps) = tenant.intent.attached {
let entry = tenant_counts.entry(ps).or_insert(0);
*entry += 1;
}
}
for (node_id, node) in nodes {
if !node.may_schedule() {
tenant_counts.remove(node_id);
}
}
Self { tenant_counts }
}
pub(crate) fn schedule_shard(
&mut self,
hard_exclude: &[NodeId],
) -> Result<NodeId, ScheduleError> {
if self.tenant_counts.is_empty() {
return Err(ScheduleError::NoPageservers);
}
let mut tenant_counts: Vec<(NodeId, usize)> = self
.tenant_counts
.iter()
.filter_map(|(k, v)| {
if hard_exclude.contains(k) {
None
} else {
Some((*k, *v))
}
})
.collect();
// Sort by tenant count. Nodes with the same tenant count are sorted by ID.
tenant_counts.sort_by_key(|i| (i.1, i.0));
if tenant_counts.is_empty() {
// After applying constraints, no pageservers were left
return Err(ScheduleError::ImpossibleConstraint);
}
for (node_id, count) in &tenant_counts {
tracing::info!("tenant_counts[{node_id}]={count}");
}
let node_id = tenant_counts.first().unwrap().0;
tracing::info!("scheduler selected node {node_id}");
*self.tenant_counts.get_mut(&node_id).unwrap() += 1;
Ok(node_id)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,455 +0,0 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use control_plane::attachment_service::NodeAvailability;
use pageserver_api::{
models::{LocationConfig, LocationConfigMode, TenantConfig},
shard::{ShardIdentity, TenantShardId},
};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use utils::{
generation::Generation,
id::NodeId,
seqwait::{SeqWait, SeqWaitError},
};
use crate::{
compute_hook::ComputeHook,
node::Node,
persistence::Persistence,
reconciler::{attached_location_conf, secondary_location_conf, ReconcileError, Reconciler},
scheduler::{ScheduleError, Scheduler},
service, PlacementPolicy, Sequence,
};
pub(crate) struct TenantState {
pub(crate) tenant_shard_id: TenantShardId,
pub(crate) shard: ShardIdentity,
// Runtime only: sequence used to coordinate when updating this object while
// with background reconcilers may be running. A reconciler runs to a particular
// sequence.
pub(crate) sequence: Sequence,
// Latest generation number: next time we attach, increment this
// and use the incremented number when attaching
pub(crate) generation: Generation,
// High level description of how the tenant should be set up. Provided
// externally.
pub(crate) policy: PlacementPolicy,
// Low level description of exactly which pageservers should fulfil
// which role. Generated by `Self::schedule`.
pub(crate) intent: IntentState,
// Low level description of how the tenant is configured on pageservers:
// if this does not match `Self::intent` then the tenant needs reconciliation
// with `Self::reconcile`.
pub(crate) observed: ObservedState,
// Tenant configuration, passed through opaquely to the pageserver. Identical
// for all shards in a tenant.
pub(crate) config: TenantConfig,
/// If a reconcile task is currently in flight, it may be joined here (it is
/// only safe to join if either the result has been received or the reconciler's
/// cancellation token has been fired)
pub(crate) reconciler: Option<ReconcilerHandle>,
/// Optionally wait for reconciliation to complete up to a particular
/// sequence number.
pub(crate) waiter: std::sync::Arc<SeqWait<Sequence, Sequence>>,
/// Indicates sequence number for which we have encountered an error reconciling. If
/// this advances ahead of [`Self::waiter`] then a reconciliation error has occurred,
/// and callers should stop waiting for `waiter` and propagate the error.
pub(crate) error_waiter: std::sync::Arc<SeqWait<Sequence, Sequence>>,
/// The most recent error from a reconcile on this tenant
/// TODO: generalize to an array of recent events
/// TOOD: use a ArcSwap instead of mutex for faster reads?
pub(crate) last_error: std::sync::Arc<std::sync::Mutex<String>>,
}
#[derive(Default, Clone, Debug)]
pub(crate) struct IntentState {
pub(crate) attached: Option<NodeId>,
pub(crate) secondary: Vec<NodeId>,
}
#[derive(Default, Clone)]
pub(crate) struct ObservedState {
pub(crate) locations: HashMap<NodeId, ObservedStateLocation>,
}
/// Our latest knowledge of how this tenant is configured in the outside world.
///
/// Meaning:
/// * No instance of this type exists for a node: we are certain that we have nothing configured on that
/// node for this shard.
/// * Instance exists with conf==None: we *might* have some state on that node, but we don't know
/// what it is (e.g. we failed partway through configuring it)
/// * Instance exists with conf==Some: this tells us what we last successfully configured on this node,
/// and that configuration will still be present unless something external interfered.
#[derive(Clone)]
pub(crate) struct ObservedStateLocation {
/// If None, it means we do not know the status of this shard's location on this node, but
/// we know that we might have some state on this node.
pub(crate) conf: Option<LocationConfig>,
}
pub(crate) struct ReconcilerWaiter {
// For observability purposes, remember the ID of the shard we're
// waiting for.
pub(crate) tenant_shard_id: TenantShardId,
seq_wait: std::sync::Arc<SeqWait<Sequence, Sequence>>,
error_seq_wait: std::sync::Arc<SeqWait<Sequence, Sequence>>,
error: std::sync::Arc<std::sync::Mutex<String>>,
seq: Sequence,
}
#[derive(thiserror::Error, Debug)]
pub enum ReconcileWaitError {
#[error("Timeout waiting for shard {0}")]
Timeout(TenantShardId),
#[error("shutting down")]
Shutdown,
#[error("Reconcile error on shard {0}: {1}")]
Failed(TenantShardId, String),
}
impl ReconcilerWaiter {
pub(crate) async fn wait_timeout(&self, timeout: Duration) -> Result<(), ReconcileWaitError> {
tokio::select! {
result = self.seq_wait.wait_for_timeout(self.seq, timeout)=> {
result.map_err(|e| match e {
SeqWaitError::Timeout => ReconcileWaitError::Timeout(self.tenant_shard_id),
SeqWaitError::Shutdown => ReconcileWaitError::Shutdown
})?;
},
result = self.error_seq_wait.wait_for(self.seq) => {
result.map_err(|e| match e {
SeqWaitError::Shutdown => ReconcileWaitError::Shutdown,
SeqWaitError::Timeout => unreachable!()
})?;
return Err(ReconcileWaitError::Failed(self.tenant_shard_id, self.error.lock().unwrap().clone()))
}
}
Ok(())
}
}
/// Having spawned a reconciler task, the tenant shard's state will carry enough
/// information to optionally cancel & await it later.
pub(crate) struct ReconcilerHandle {
sequence: Sequence,
handle: JoinHandle<()>,
cancel: CancellationToken,
}
/// When a reconcile task completes, it sends this result object
/// to be applied to the primary TenantState.
pub(crate) struct ReconcileResult {
pub(crate) sequence: Sequence,
/// On errors, `observed` should be treated as an incompleted description
/// of state (i.e. any nodes present in the result should override nodes
/// present in the parent tenant state, but any unmentioned nodes should
/// not be removed from parent tenant state)
pub(crate) result: Result<(), ReconcileError>,
pub(crate) tenant_shard_id: TenantShardId,
pub(crate) generation: Generation,
pub(crate) observed: ObservedState,
}
impl IntentState {
pub(crate) fn new() -> Self {
Self {
attached: None,
secondary: vec![],
}
}
pub(crate) fn all_pageservers(&self) -> Vec<NodeId> {
let mut result = Vec::new();
if let Some(p) = self.attached {
result.push(p)
}
result.extend(self.secondary.iter().copied());
result
}
/// When a node goes offline, we update intents to avoid using it
/// as their attached pageserver.
///
/// Returns true if a change was made
pub(crate) fn notify_offline(&mut self, node_id: NodeId) -> bool {
if self.attached == Some(node_id) {
self.attached = None;
self.secondary.push(node_id);
true
} else {
false
}
}
}
impl ObservedState {
pub(crate) fn new() -> Self {
Self {
locations: HashMap::new(),
}
}
}
impl TenantState {
pub(crate) fn new(
tenant_shard_id: TenantShardId,
shard: ShardIdentity,
policy: PlacementPolicy,
) -> Self {
Self {
tenant_shard_id,
policy,
intent: IntentState::default(),
generation: Generation::new(0),
shard,
observed: ObservedState::default(),
config: TenantConfig::default(),
reconciler: None,
sequence: Sequence(1),
waiter: Arc::new(SeqWait::new(Sequence(0))),
error_waiter: Arc::new(SeqWait::new(Sequence(0))),
last_error: Arc::default(),
}
}
/// For use on startup when learning state from pageservers: generate my [`IntentState`] from my
/// [`ObservedState`], even if it violates my [`PlacementPolicy`]. Call [`Self::schedule`] next,
/// to get an intent state that complies with placement policy. The overall goal is to do scheduling
/// in a way that makes use of any configured locations that already exist in the outside world.
pub(crate) fn intent_from_observed(&mut self) {
// Choose an attached location by filtering observed locations, and then sorting to get the highest
// generation
let mut attached_locs = self
.observed
.locations
.iter()
.filter_map(|(node_id, l)| {
if let Some(conf) = &l.conf {
if conf.mode == LocationConfigMode::AttachedMulti
|| conf.mode == LocationConfigMode::AttachedSingle
|| conf.mode == LocationConfigMode::AttachedStale
{
Some((node_id, conf.generation))
} else {
None
}
} else {
None
}
})
.collect::<Vec<_>>();
attached_locs.sort_by_key(|i| i.1);
if let Some((node_id, _gen)) = attached_locs.into_iter().last() {
self.intent.attached = Some(*node_id);
}
// All remaining observed locations generate secondary intents. This includes None
// observations, as these may well have some local content on disk that is usable (this
// is an edge case that might occur if we restarted during a migration or other change)
self.observed.locations.keys().for_each(|node_id| {
if Some(*node_id) != self.intent.attached {
self.intent.secondary.push(*node_id);
}
});
}
pub(crate) fn schedule(&mut self, scheduler: &mut Scheduler) -> Result<(), ScheduleError> {
// TODO: before scheduling new nodes, check if any existing content in
// self.intent refers to pageservers that are offline, and pick other
// pageservers if so.
// Build the set of pageservers already in use by this tenant, to avoid scheduling
// more work on the same pageservers we're already using.
let mut used_pageservers = self.intent.all_pageservers();
let mut modified = false;
use PlacementPolicy::*;
match self.policy {
Single => {
// Should have exactly one attached, and zero secondaries
if self.intent.attached.is_none() {
let node_id = scheduler.schedule_shard(&used_pageservers)?;
self.intent.attached = Some(node_id);
used_pageservers.push(node_id);
modified = true;
}
if !self.intent.secondary.is_empty() {
self.intent.secondary.clear();
modified = true;
}
}
Double(secondary_count) => {
// Should have exactly one attached, and N secondaries
if self.intent.attached.is_none() {
let node_id = scheduler.schedule_shard(&used_pageservers)?;
self.intent.attached = Some(node_id);
used_pageservers.push(node_id);
modified = true;
}
while self.intent.secondary.len() < secondary_count {
let node_id = scheduler.schedule_shard(&used_pageservers)?;
self.intent.secondary.push(node_id);
used_pageservers.push(node_id);
modified = true;
}
}
}
if modified {
self.sequence.0 += 1;
}
Ok(())
}
fn dirty(&self) -> bool {
if let Some(node_id) = self.intent.attached {
let wanted_conf = attached_location_conf(self.generation, &self.shard, &self.config);
match self.observed.locations.get(&node_id) {
Some(conf) if conf.conf.as_ref() == Some(&wanted_conf) => {}
Some(_) | None => {
return true;
}
}
}
for node_id in &self.intent.secondary {
let wanted_conf = secondary_location_conf(&self.shard, &self.config);
match self.observed.locations.get(node_id) {
Some(conf) if conf.conf.as_ref() == Some(&wanted_conf) => {}
Some(_) | None => {
return true;
}
}
}
false
}
pub(crate) fn maybe_reconcile(
&mut self,
result_tx: tokio::sync::mpsc::UnboundedSender<ReconcileResult>,
pageservers: &Arc<HashMap<NodeId, Node>>,
compute_hook: &Arc<ComputeHook>,
service_config: &service::Config,
persistence: &Arc<Persistence>,
) -> Option<ReconcilerWaiter> {
// If there are any ambiguous observed states, and the nodes they refer to are available,
// we should reconcile to clean them up.
let mut dirty_observed = false;
for (node_id, observed_loc) in &self.observed.locations {
let node = pageservers
.get(node_id)
.expect("Nodes may not be removed while referenced");
if observed_loc.conf.is_none()
&& !matches!(node.availability, NodeAvailability::Offline)
{
dirty_observed = true;
break;
}
}
if !self.dirty() && !dirty_observed {
tracing::info!("Not dirty, no reconciliation needed.");
return None;
}
// Reconcile already in flight for the current sequence?
if let Some(handle) = &self.reconciler {
if handle.sequence == self.sequence {
return Some(ReconcilerWaiter {
tenant_shard_id: self.tenant_shard_id,
seq_wait: self.waiter.clone(),
error_seq_wait: self.error_waiter.clone(),
error: self.last_error.clone(),
seq: self.sequence,
});
}
}
// Reconcile in flight for a stale sequence? Our sequence's task will wait for it before
// doing our sequence's work.
let old_handle = self.reconciler.take();
let cancel = CancellationToken::new();
let mut reconciler = Reconciler {
tenant_shard_id: self.tenant_shard_id,
shard: self.shard,
generation: self.generation,
intent: self.intent.clone(),
config: self.config.clone(),
observed: self.observed.clone(),
pageservers: pageservers.clone(),
compute_hook: compute_hook.clone(),
service_config: service_config.clone(),
cancel: cancel.clone(),
persistence: persistence.clone(),
};
let reconcile_seq = self.sequence;
tracing::info!("Spawning Reconciler for sequence {}", self.sequence);
let join_handle = tokio::task::spawn(async move {
// Wait for any previous reconcile task to complete before we start
if let Some(old_handle) = old_handle {
old_handle.cancel.cancel();
if let Err(e) = old_handle.handle.await {
// We can't do much with this other than log it: the task is done, so
// we may proceed with our work.
tracing::error!("Unexpected join error waiting for reconcile task: {e}");
}
}
// Early check for cancellation before doing any work
// TODO: wrap all remote API operations in cancellation check
// as well.
if reconciler.cancel.is_cancelled() {
return;
}
let result = reconciler.reconcile().await;
result_tx
.send(ReconcileResult {
sequence: reconcile_seq,
result,
tenant_shard_id: reconciler.tenant_shard_id,
generation: reconciler.generation,
observed: reconciler.observed,
})
.ok();
});
self.reconciler = Some(ReconcilerHandle {
sequence: self.sequence,
handle: join_handle,
cancel,
});
Some(ReconcilerWaiter {
tenant_shard_id: self.tenant_shard_id,
seq_wait: self.waiter.clone(),
error_seq_wait: self.error_waiter.clone(),
error: self.last_error.clone(),
seq: self.sequence,
})
}
}

View File

@@ -1,35 +1,22 @@
use crate::{background_process, local_env::LocalEnv};
use anyhow::anyhow;
use camino::Utf8PathBuf;
use hyper::Method;
use pageserver_api::{
models::{ShardParameters, TenantCreateRequest, TimelineCreateRequest, TimelineInfo},
shard::TenantShardId,
};
use pageserver_client::mgmt_api::ResponseErrorMessageExt;
use postgres_backend::AuthType;
use postgres_connection::parse_host_port;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{path::PathBuf, process::Child, str::FromStr};
use tracing::instrument;
use utils::{
auth::{Claims, Scope},
id::{NodeId, TenantId},
};
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, process::Child};
use utils::id::{NodeId, TenantId};
pub struct AttachmentService {
env: LocalEnv,
listen: String,
path: PathBuf,
jwt_token: Option<String>,
public_key_path: Option<Utf8PathBuf>,
client: reqwest::Client,
client: reqwest::blocking::Client,
}
const COMMAND: &str = "attachment_service";
#[derive(Serialize, Deserialize)]
pub struct AttachHookRequest {
pub tenant_shard_id: TenantShardId,
pub tenant_id: TenantId,
pub node_id: Option<NodeId>,
}
@@ -40,7 +27,7 @@ pub struct AttachHookResponse {
#[derive(Serialize, Deserialize)]
pub struct InspectRequest {
pub tenant_shard_id: TenantShardId,
pub tenant_id: TenantId,
}
#[derive(Serialize, Deserialize)]
@@ -48,125 +35,6 @@ pub struct InspectResponse {
pub attachment: Option<(u32, NodeId)>,
}
#[derive(Serialize, Deserialize)]
pub struct TenantCreateResponseShard {
pub node_id: NodeId,
pub generation: u32,
}
#[derive(Serialize, Deserialize)]
pub struct TenantCreateResponse {
pub shards: Vec<TenantCreateResponseShard>,
}
#[derive(Serialize, Deserialize)]
pub struct NodeRegisterRequest {
pub node_id: NodeId,
pub listen_pg_addr: String,
pub listen_pg_port: u16,
pub listen_http_addr: String,
pub listen_http_port: u16,
}
#[derive(Serialize, Deserialize)]
pub struct NodeConfigureRequest {
pub node_id: NodeId,
pub availability: Option<NodeAvailability>,
pub scheduling: Option<NodeSchedulingPolicy>,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct TenantLocateResponseShard {
pub shard_id: TenantShardId,
pub node_id: NodeId,
pub listen_pg_addr: String,
pub listen_pg_port: u16,
pub listen_http_addr: String,
pub listen_http_port: u16,
}
#[derive(Serialize, Deserialize)]
pub struct TenantLocateResponse {
pub shards: Vec<TenantLocateResponseShard>,
pub shard_params: ShardParameters,
}
/// Explicitly migrating a particular shard is a low level operation
/// TODO: higher level "Reschedule tenant" operation where the request
/// specifies some constraints, e.g. asking it to get off particular node(s)
#[derive(Serialize, Deserialize, Debug)]
pub struct TenantShardMigrateRequest {
pub tenant_shard_id: TenantShardId,
pub node_id: NodeId,
}
#[derive(Serialize, Deserialize, Clone, Copy)]
pub enum NodeAvailability {
// Normal, happy state
Active,
// Offline: Tenants shouldn't try to attach here, but they may assume that their
// secondary locations on this node still exist. Newly added nodes are in this
// state until we successfully contact them.
Offline,
}
impl FromStr for NodeAvailability {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"active" => Ok(Self::Active),
"offline" => Ok(Self::Offline),
_ => Err(anyhow::anyhow!("Unknown availability state '{s}'")),
}
}
}
/// FIXME: this is a duplicate of the type in the attachment_service crate, because the
/// type needs to be defined with diesel traits in there.
#[derive(Serialize, Deserialize, Clone, Copy)]
pub enum NodeSchedulingPolicy {
Active,
Filling,
Pause,
Draining,
}
impl FromStr for NodeSchedulingPolicy {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"active" => Ok(Self::Active),
"filling" => Ok(Self::Filling),
"pause" => Ok(Self::Pause),
"draining" => Ok(Self::Draining),
_ => Err(anyhow::anyhow!("Unknown scheduling state '{s}'")),
}
}
}
impl From<NodeSchedulingPolicy> for String {
fn from(value: NodeSchedulingPolicy) -> String {
use NodeSchedulingPolicy::*;
match value {
Active => "active",
Filling => "filling",
Pause => "pause",
Draining => "draining",
}
.to_string()
}
}
#[derive(Serialize, Deserialize, Debug)]
pub struct TenantShardMigrateResponse {}
impl AttachmentService {
pub fn from_env(env: &LocalEnv) -> Self {
let path = env.base_data_dir.join("attachments.json");
@@ -181,35 +49,11 @@ impl AttachmentService {
listen_url.port().unwrap()
);
// Assume all pageservers have symmetric auth configuration: this service
// expects to use one JWT token to talk to all of them.
let ps_conf = env
.pageservers
.first()
.expect("Config is validated to contain at least one pageserver");
let (jwt_token, public_key_path) = match ps_conf.http_auth_type {
AuthType::Trust => (None, None),
AuthType::NeonJWT => {
let jwt_token = env
.generate_auth_token(&Claims::new(None, Scope::PageServerApi))
.unwrap();
// If pageserver auth is enabled, this implicitly enables auth for this service,
// using the same credentials.
let public_key_path =
camino::Utf8PathBuf::try_from(env.base_data_dir.join("auth_public_key.pem"))
.unwrap();
(Some(jwt_token), Some(public_key_path))
}
};
Self {
env: env.clone(),
path,
listen,
jwt_token,
public_key_path,
client: reqwest::ClientBuilder::new()
client: reqwest::blocking::ClientBuilder::new()
.build()
.expect("Failed to construct http client"),
}
@@ -220,202 +64,74 @@ impl AttachmentService {
.expect("non-Unicode path")
}
pub async fn start(&self) -> anyhow::Result<Child> {
pub fn start(&self) -> anyhow::Result<Child> {
let path_str = self.path.to_string_lossy();
let mut args = vec!["-l", &self.listen, "-p", &path_str]
.into_iter()
.map(|s| s.to_string())
.collect::<Vec<_>>();
if let Some(jwt_token) = &self.jwt_token {
args.push(format!("--jwt-token={jwt_token}"));
}
if let Some(public_key_path) = &self.public_key_path {
args.push(format!("--public-key={public_key_path}"));
}
let result = background_process::start_process(
background_process::start_process(
COMMAND,
&self.env.base_data_dir,
&self.env.attachment_service_bin(),
args,
[(
"NEON_REPO_DIR".to_string(),
self.env.base_data_dir.to_string_lossy().to_string(),
)],
background_process::InitialPidFile::Create(self.pid_file()),
|| async {
match self.status().await {
Ok(_) => Ok(true),
Err(_) => Ok(false),
}
},
["-l", &self.listen, "-p", &path_str],
[],
background_process::InitialPidFile::Create(&self.pid_file()),
// TODO: a real status check
|| Ok(true),
)
.await;
for ps_conf in &self.env.pageservers {
let (pg_host, pg_port) =
parse_host_port(&ps_conf.listen_pg_addr).expect("Unable to parse listen_pg_addr");
let (http_host, http_port) = parse_host_port(&ps_conf.listen_http_addr)
.expect("Unable to parse listen_http_addr");
self.node_register(NodeRegisterRequest {
node_id: ps_conf.id,
listen_pg_addr: pg_host.to_string(),
listen_pg_port: pg_port.unwrap_or(5432),
listen_http_addr: http_host.to_string(),
listen_http_port: http_port.unwrap_or(80),
})
.await?;
}
result
}
pub fn stop(&self, immediate: bool) -> anyhow::Result<()> {
background_process::stop_process(immediate, COMMAND, &self.pid_file())
}
/// Simple HTTP request wrapper for calling into attachment service
async fn dispatch<RQ, RS>(
/// Call into the attach_hook API, for use before handing out attachments to pageservers
pub fn attach_hook(
&self,
method: hyper::Method,
path: String,
body: Option<RQ>,
) -> anyhow::Result<RS>
where
RQ: Serialize + Sized,
RS: DeserializeOwned + Sized,
{
tenant_id: TenantId,
pageserver_id: NodeId,
) -> anyhow::Result<Option<u32>> {
use hyper::StatusCode;
let url = self
.env
.control_plane_api
.clone()
.unwrap()
.join(&path)
.join("attach-hook")
.unwrap();
let mut builder = self.client.request(method, url);
if let Some(body) = body {
builder = builder.json(&body)
}
if let Some(jwt_token) = &self.jwt_token {
builder = builder.header(
reqwest::header::AUTHORIZATION,
format!("Bearer {jwt_token}"),
);
}
let response = builder.send().await?;
let response = response.error_from_body().await?;
Ok(response
.json()
.await
.map_err(pageserver_client::mgmt_api::Error::ReceiveBody)?)
}
/// Call into the attach_hook API, for use before handing out attachments to pageservers
#[instrument(skip(self))]
pub async fn attach_hook(
&self,
tenant_shard_id: TenantShardId,
pageserver_id: NodeId,
) -> anyhow::Result<Option<u32>> {
let request = AttachHookRequest {
tenant_shard_id,
tenant_id,
node_id: Some(pageserver_id),
};
let response = self
.dispatch::<_, AttachHookResponse>(
Method::POST,
"attach-hook".to_string(),
Some(request),
)
.await?;
let response = self.client.post(url).json(&request).send()?;
if response.status() != StatusCode::OK {
return Err(anyhow!("Unexpected status {}", response.status()));
}
let response = response.json::<AttachHookResponse>()?;
Ok(response.gen)
}
#[instrument(skip(self))]
pub async fn inspect(
&self,
tenant_shard_id: TenantShardId,
) -> anyhow::Result<Option<(u32, NodeId)>> {
let request = InspectRequest { tenant_shard_id };
pub fn inspect(&self, tenant_id: TenantId) -> anyhow::Result<Option<(u32, NodeId)>> {
use hyper::StatusCode;
let response = self
.dispatch::<_, InspectResponse>(Method::POST, "inspect".to_string(), Some(request))
.await?;
let url = self
.env
.control_plane_api
.clone()
.unwrap()
.join("inspect")
.unwrap();
let request = InspectRequest { tenant_id };
let response = self.client.post(url).json(&request).send()?;
if response.status() != StatusCode::OK {
return Err(anyhow!("Unexpected status {}", response.status()));
}
let response = response.json::<InspectResponse>()?;
Ok(response.attachment)
}
#[instrument(skip(self))]
pub async fn tenant_create(
&self,
req: TenantCreateRequest,
) -> anyhow::Result<TenantCreateResponse> {
self.dispatch(Method::POST, "tenant".to_string(), Some(req))
.await
}
#[instrument(skip(self))]
pub async fn tenant_locate(&self, tenant_id: TenantId) -> anyhow::Result<TenantLocateResponse> {
self.dispatch::<(), _>(Method::GET, format!("tenant/{tenant_id}/locate"), None)
.await
}
#[instrument(skip(self))]
pub async fn tenant_migrate(
&self,
tenant_shard_id: TenantShardId,
node_id: NodeId,
) -> anyhow::Result<TenantShardMigrateResponse> {
self.dispatch(
Method::PUT,
format!("tenant/{tenant_shard_id}/migrate"),
Some(TenantShardMigrateRequest {
tenant_shard_id,
node_id,
}),
)
.await
}
#[instrument(skip_all, fields(node_id=%req.node_id))]
pub async fn node_register(&self, req: NodeRegisterRequest) -> anyhow::Result<()> {
self.dispatch::<_, ()>(Method::POST, "node".to_string(), Some(req))
.await
}
#[instrument(skip_all, fields(node_id=%req.node_id))]
pub async fn node_configure(&self, req: NodeConfigureRequest) -> anyhow::Result<()> {
self.dispatch::<_, ()>(
Method::PUT,
format!("node/{}/config", req.node_id),
Some(req),
)
.await
}
#[instrument(skip(self))]
pub async fn status(&self) -> anyhow::Result<()> {
self.dispatch::<(), ()>(Method::GET, "status".to_string(), None)
.await
}
#[instrument(skip_all, fields(%tenant_id, timeline_id=%req.new_timeline_id))]
pub async fn tenant_timeline_create(
&self,
tenant_id: TenantId,
req: TimelineCreateRequest,
) -> anyhow::Result<TimelineInfo> {
self.dispatch(
Method::POST,
format!("tenant/{tenant_id}/timeline"),
Some(req),
)
.await
}
}

View File

@@ -44,15 +44,15 @@ const NOTICE_AFTER_RETRIES: u64 = 50;
/// Argument to `start_process`, to indicate whether it should create pidfile or if the process creates
/// it itself.
pub enum InitialPidFile {
pub enum InitialPidFile<'t> {
/// Create a pidfile, to allow future CLI invocations to manipulate the process.
Create(Utf8PathBuf),
Create(&'t Utf8Path),
/// The process will create the pidfile itself, need to wait for that event.
Expect(Utf8PathBuf),
Expect(&'t Utf8Path),
}
/// Start a background child process using the parameters given.
pub async fn start_process<F, Fut, AI, A, EI>(
pub fn start_process<F, AI, A, EI>(
process_name: &str,
datadir: &Path,
command: &Path,
@@ -62,8 +62,7 @@ pub async fn start_process<F, Fut, AI, A, EI>(
process_status_check: F,
) -> anyhow::Result<Child>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = anyhow::Result<bool>>,
F: Fn() -> anyhow::Result<bool>,
AI: IntoIterator<Item = A>,
A: AsRef<OsStr>,
// Not generic AsRef<OsStr>, otherwise empty `envs` prevents type inference
@@ -90,7 +89,7 @@ where
let filled_cmd = fill_remote_storage_secrets_vars(fill_rust_env_vars(background_command));
filled_cmd.envs(envs);
let pid_file_to_check = match &initial_pid_file {
let pid_file_to_check = match initial_pid_file {
InitialPidFile::Create(path) => {
pre_exec_create_pidfile(filled_cmd, path);
path
@@ -108,7 +107,7 @@ where
);
for retries in 0..RETRIES {
match process_started(pid, pid_file_to_check, &process_status_check).await {
match process_started(pid, Some(pid_file_to_check), &process_status_check) {
Ok(true) => {
println!("\n{process_name} started, pid: {pid}");
return Ok(spawned_process);
@@ -317,20 +316,22 @@ where
cmd
}
async fn process_started<F, Fut>(
fn process_started<F>(
pid: Pid,
pid_file_to_check: &Utf8Path,
pid_file_to_check: Option<&Utf8Path>,
status_check: &F,
) -> anyhow::Result<bool>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = anyhow::Result<bool>>,
F: Fn() -> anyhow::Result<bool>,
{
match status_check().await {
Ok(true) => match pid_file::read(pid_file_to_check)? {
PidFileRead::NotExist => Ok(false),
PidFileRead::LockedByOtherProcess(pid_in_file) => Ok(pid_in_file == pid),
PidFileRead::NotHeldByAnyProcess(_) => Ok(false),
match status_check() {
Ok(true) => match pid_file_to_check {
Some(pid_file_path) => match pid_file::read(pid_file_path)? {
PidFileRead::NotExist => Ok(false),
PidFileRead::LockedByOtherProcess(pid_in_file) => Ok(pid_in_file == pid),
PidFileRead::NotHeldByAnyProcess(_) => Ok(false),
},
None => Ok(true),
},
Ok(false) => Ok(false),
Err(e) => anyhow::bail!("process failed to start: {e}"),

View File

@@ -0,0 +1,337 @@
/// The attachment service mimics the aspects of the control plane API
/// that are required for a pageserver to operate.
///
/// This enables running & testing pageservers without a full-blown
/// deployment of the Neon cloud platform.
///
use anyhow::anyhow;
use clap::Parser;
use hex::FromHex;
use hyper::StatusCode;
use hyper::{Body, Request, Response};
use pageserver_api::shard::TenantShardId;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::{collections::HashMap, sync::Arc};
use utils::http::endpoint::request_span;
use utils::logging::{self, LogFormat};
use utils::signals::{ShutdownSignals, Signal};
use utils::{
http::{
endpoint::{self},
error::ApiError,
json::{json_request, json_response},
RequestExt, RouterBuilder,
},
id::{NodeId, TenantId},
tcp_listener,
};
use pageserver_api::control_api::{
ReAttachRequest, ReAttachResponse, ReAttachResponseTenant, ValidateRequest, ValidateResponse,
ValidateResponseTenant,
};
use control_plane::attachment_service::{
AttachHookRequest, AttachHookResponse, InspectRequest, InspectResponse,
};
#[derive(Parser)]
#[command(author, version, about, long_about = None)]
#[command(arg_required_else_help(true))]
struct Cli {
/// Host and port to listen on, like `127.0.0.1:1234`
#[arg(short, long)]
listen: std::net::SocketAddr,
/// Path to the .json file to store state (will be created if it doesn't exist)
#[arg(short, long)]
path: PathBuf,
}
// The persistent state of each Tenant
#[derive(Serialize, Deserialize, Clone)]
struct TenantState {
// Currently attached pageserver
pageserver: Option<NodeId>,
// Latest generation number: next time we attach, increment this
// and use the incremented number when attaching
generation: u32,
}
fn to_hex_map<S, V>(input: &HashMap<TenantId, V>, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
V: Clone + Serialize,
{
let transformed = input.iter().map(|(k, v)| (hex::encode(k), v.clone()));
transformed
.collect::<HashMap<String, V>>()
.serialize(serializer)
}
fn from_hex_map<'de, D, V>(deserializer: D) -> Result<HashMap<TenantId, V>, D::Error>
where
D: serde::de::Deserializer<'de>,
V: Deserialize<'de>,
{
let hex_map = HashMap::<String, V>::deserialize(deserializer)?;
hex_map
.into_iter()
.map(|(k, v)| {
TenantId::from_hex(k)
.map(|k| (k, v))
.map_err(serde::de::Error::custom)
})
.collect()
}
// Top level state available to all HTTP handlers
#[derive(Serialize, Deserialize)]
struct PersistentState {
#[serde(serialize_with = "to_hex_map", deserialize_with = "from_hex_map")]
tenants: HashMap<TenantId, TenantState>,
#[serde(skip)]
path: PathBuf,
}
impl PersistentState {
async fn save(&self) -> anyhow::Result<()> {
let bytes = serde_json::to_vec(self)?;
tokio::fs::write(&self.path, &bytes).await?;
Ok(())
}
async fn load(path: &Path) -> anyhow::Result<Self> {
let bytes = tokio::fs::read(path).await?;
let mut decoded = serde_json::from_slice::<Self>(&bytes)?;
decoded.path = path.to_owned();
Ok(decoded)
}
async fn load_or_new(path: &Path) -> Self {
match Self::load(path).await {
Ok(s) => {
tracing::info!("Loaded state file at {}", path.display());
s
}
Err(e)
if e.downcast_ref::<std::io::Error>()
.map(|e| e.kind() == std::io::ErrorKind::NotFound)
.unwrap_or(false) =>
{
tracing::info!("Will create state file at {}", path.display());
Self {
tenants: HashMap::new(),
path: path.to_owned(),
}
}
Err(e) => {
panic!("Failed to load state from '{}': {e:#} (maybe your .neon/ dir was written by an older version?)", path.display())
}
}
}
}
/// State available to HTTP request handlers
#[derive(Clone)]
struct State {
inner: Arc<tokio::sync::RwLock<PersistentState>>,
}
impl State {
fn new(persistent_state: PersistentState) -> State {
Self {
inner: Arc::new(tokio::sync::RwLock::new(persistent_state)),
}
}
}
#[inline(always)]
fn get_state(request: &Request<Body>) -> &State {
request
.data::<Arc<State>>()
.expect("unknown state type")
.as_ref()
}
/// Pageserver calls into this on startup, to learn which tenants it should attach
async fn handle_re_attach(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
let reattach_req = json_request::<ReAttachRequest>(&mut req).await?;
let state = get_state(&req).inner.clone();
let mut locked = state.write().await;
let mut response = ReAttachResponse {
tenants: Vec::new(),
};
for (t, state) in &mut locked.tenants {
if state.pageserver == Some(reattach_req.node_id) {
state.generation += 1;
response.tenants.push(ReAttachResponseTenant {
// TODO(sharding): make this shard-aware
id: TenantShardId::unsharded(*t),
gen: state.generation,
});
}
}
locked.save().await.map_err(ApiError::InternalServerError)?;
json_response(StatusCode::OK, response)
}
/// Pageserver calls into this before doing deletions, to confirm that it still
/// holds the latest generation for the tenants with deletions enqueued
async fn handle_validate(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
let validate_req = json_request::<ValidateRequest>(&mut req).await?;
let locked = get_state(&req).inner.read().await;
let mut response = ValidateResponse {
tenants: Vec::new(),
};
for req_tenant in validate_req.tenants {
// TODO(sharding): make this shard-aware
if let Some(tenant_state) = locked.tenants.get(&req_tenant.id.tenant_id) {
let valid = tenant_state.generation == req_tenant.gen;
tracing::info!(
"handle_validate: {}(gen {}): valid={valid} (latest {})",
req_tenant.id,
req_tenant.gen,
tenant_state.generation
);
response.tenants.push(ValidateResponseTenant {
id: req_tenant.id,
valid,
});
}
}
json_response(StatusCode::OK, response)
}
/// Call into this before attaching a tenant to a pageserver, to acquire a generation number
/// (in the real control plane this is unnecessary, because the same program is managing
/// generation numbers and doing attachments).
async fn handle_attach_hook(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
let attach_req = json_request::<AttachHookRequest>(&mut req).await?;
let state = get_state(&req).inner.clone();
let mut locked = state.write().await;
let tenant_state = locked
.tenants
.entry(attach_req.tenant_id)
.or_insert_with(|| TenantState {
pageserver: attach_req.node_id,
generation: 0,
});
if let Some(attaching_pageserver) = attach_req.node_id.as_ref() {
tenant_state.generation += 1;
tracing::info!(
tenant_id = %attach_req.tenant_id,
ps_id = %attaching_pageserver,
generation = %tenant_state.generation,
"issuing",
);
} else if let Some(ps_id) = tenant_state.pageserver {
tracing::info!(
tenant_id = %attach_req.tenant_id,
%ps_id,
generation = %tenant_state.generation,
"dropping",
);
} else {
tracing::info!(
tenant_id = %attach_req.tenant_id,
"no-op: tenant already has no pageserver");
}
tenant_state.pageserver = attach_req.node_id;
let generation = tenant_state.generation;
tracing::info!(
"handle_attach_hook: tenant {} set generation {}, pageserver {}",
attach_req.tenant_id,
tenant_state.generation,
attach_req.node_id.unwrap_or(utils::id::NodeId(0xfffffff))
);
locked.save().await.map_err(ApiError::InternalServerError)?;
json_response(
StatusCode::OK,
AttachHookResponse {
gen: attach_req.node_id.map(|_| generation),
},
)
}
async fn handle_inspect(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
let inspect_req = json_request::<InspectRequest>(&mut req).await?;
let state = get_state(&req).inner.clone();
let locked = state.write().await;
let tenant_state = locked.tenants.get(&inspect_req.tenant_id);
json_response(
StatusCode::OK,
InspectResponse {
attachment: tenant_state.and_then(|s| s.pageserver.map(|ps| (s.generation, ps))),
},
)
}
fn make_router(persistent_state: PersistentState) -> RouterBuilder<hyper::Body, ApiError> {
endpoint::make_router()
.data(Arc::new(State::new(persistent_state)))
.post("/re-attach", |r| request_span(r, handle_re_attach))
.post("/validate", |r| request_span(r, handle_validate))
.post("/attach-hook", |r| request_span(r, handle_attach_hook))
.post("/inspect", |r| request_span(r, handle_inspect))
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
logging::init(
LogFormat::Plain,
logging::TracingErrorLayerEnablement::Disabled,
logging::Output::Stdout,
)?;
let args = Cli::parse();
tracing::info!(
"Starting, state at {}, listening on {}",
args.path.to_string_lossy(),
args.listen
);
let persistent_state = PersistentState::load_or_new(&args.path).await;
let http_listener = tcp_listener::bind(args.listen)?;
let router = make_router(persistent_state)
.build()
.map_err(|err| anyhow!(err))?;
let service = utils::http::RouterService::new(router).unwrap();
let server = hyper::Server::from_tcp(http_listener)?.serve(service);
tracing::info!("Serving on {0}", args.listen);
tokio::task::spawn(server);
ShutdownSignals::handle(|signal| match signal {
Signal::Interrupt | Signal::Terminate | Signal::Quit => {
tracing::info!("Got {}. Terminating", signal.name());
// We're just a test helper: no graceful shutdown.
std::process::exit(0);
}
})?;
Ok(())
}

View File

@@ -6,26 +6,21 @@
//! rely on `neon_local` to set up the environment for each test.
//!
use anyhow::{anyhow, bail, Context, Result};
use clap::{value_parser, Arg, ArgAction, ArgMatches, Command, ValueEnum};
use clap::{value_parser, Arg, ArgAction, ArgMatches, Command};
use compute_api::spec::ComputeMode;
use control_plane::attachment_service::{
AttachmentService, NodeAvailability, NodeConfigureRequest, NodeSchedulingPolicy,
};
use control_plane::attachment_service::AttachmentService;
use control_plane::endpoint::ComputeControlPlane;
use control_plane::local_env::{InitForceMode, LocalEnv};
use control_plane::local_env::LocalEnv;
use control_plane::pageserver::{PageServerNode, PAGESERVER_REMOTE_STORAGE_DIR};
use control_plane::safekeeper::SafekeeperNode;
use control_plane::tenant_migration::migrate_tenant;
use control_plane::{broker, local_env};
use pageserver_api::models::{
ShardParameters, TenantCreateRequest, TimelineCreateRequest, TimelineInfo,
};
use pageserver_api::shard::{ShardCount, ShardStripeSize, TenantShardId};
use pageserver_api::models::TimelineInfo;
use pageserver_api::{
DEFAULT_HTTP_LISTEN_PORT as DEFAULT_PAGESERVER_HTTP_PORT,
DEFAULT_PG_LISTEN_PORT as DEFAULT_PAGESERVER_PG_PORT,
};
use postgres_backend::AuthType;
use postgres_connection::parse_host_port;
use safekeeper_api::{
DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT,
DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT,
@@ -35,7 +30,6 @@ use std::path::PathBuf;
use std::process::exit;
use std::str::FromStr;
use storage_broker::DEFAULT_LISTEN_ADDR as DEFAULT_BROKER_ADDR;
use url::Host;
use utils::{
auth::{Claims, Scope},
id::{NodeId, TenantId, TenantTimelineId, TimelineId},
@@ -126,20 +120,15 @@ fn main() -> Result<()> {
let mut env = LocalEnv::load_config().context("Error loading config")?;
let original_env = env.clone();
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let subcommand_result = match sub_name {
"tenant" => rt.block_on(handle_tenant(sub_args, &mut env)),
"timeline" => rt.block_on(handle_timeline(sub_args, &mut env)),
"start" => rt.block_on(handle_start_all(sub_args, &env)),
"tenant" => handle_tenant(sub_args, &mut env),
"timeline" => handle_timeline(sub_args, &mut env),
"start" => handle_start_all(sub_args, &env),
"stop" => handle_stop_all(sub_args, &env),
"pageserver" => rt.block_on(handle_pageserver(sub_args, &env)),
"attachment_service" => rt.block_on(handle_attachment_service(sub_args, &env)),
"safekeeper" => rt.block_on(handle_safekeeper(sub_args, &env)),
"endpoint" => rt.block_on(handle_endpoint(sub_args, &env)),
"pageserver" => handle_pageserver(sub_args, &env),
"attachment_service" => handle_attachment_service(sub_args, &env),
"safekeeper" => handle_safekeeper(sub_args, &env),
"endpoint" => handle_endpoint(sub_args, &env),
"mappings" => handle_mappings(sub_args, &mut env),
"pg" => bail!("'pg' subcommand has been renamed to 'endpoint'"),
_ => bail!("unexpected subcommand {sub_name}"),
@@ -280,13 +269,12 @@ fn print_timeline(
/// Returns a map of timeline IDs to timeline_id@lsn strings.
/// Connects to the pageserver to query this information.
async fn get_timeline_infos(
fn get_timeline_infos(
env: &local_env::LocalEnv,
tenant_shard_id: &TenantShardId,
tenant_id: &TenantId,
) -> Result<HashMap<TimelineId, TimelineInfo>> {
Ok(get_default_pageserver(env)
.timeline_list(tenant_shard_id)
.await?
.timeline_list(tenant_id)?
.into_iter()
.map(|timeline_info| (timeline_info.timeline_id, timeline_info))
.collect())
@@ -303,20 +291,6 @@ fn get_tenant_id(sub_match: &ArgMatches, env: &local_env::LocalEnv) -> anyhow::R
}
}
// Helper function to parse --tenant_id option, for commands that accept a shard suffix
fn get_tenant_shard_id(
sub_match: &ArgMatches,
env: &local_env::LocalEnv,
) -> anyhow::Result<TenantShardId> {
if let Some(tenant_id_from_arguments) = parse_tenant_shard_id(sub_match).transpose() {
tenant_id_from_arguments
} else if let Some(default_id) = env.default_tenant_id {
Ok(TenantShardId::unsharded(default_id))
} else {
anyhow::bail!("No tenant shard id. Use --tenant-id, or set a default tenant");
}
}
fn parse_tenant_id(sub_match: &ArgMatches) -> anyhow::Result<Option<TenantId>> {
sub_match
.get_one::<String>("tenant-id")
@@ -325,14 +299,6 @@ fn parse_tenant_id(sub_match: &ArgMatches) -> anyhow::Result<Option<TenantId>> {
.context("Failed to parse tenant id from the argument string")
}
fn parse_tenant_shard_id(sub_match: &ArgMatches) -> anyhow::Result<Option<TenantShardId>> {
sub_match
.get_one::<String>("tenant-id")
.map(|id_str| TenantShardId::from_str(id_str))
.transpose()
.context("Failed to parse tenant shard id from the argument string")
}
fn parse_timeline_id(sub_match: &ArgMatches) -> anyhow::Result<Option<TimelineId>> {
sub_match
.get_one::<String>("timeline-id")
@@ -366,7 +332,7 @@ fn handle_init(init_match: &ArgMatches) -> anyhow::Result<LocalEnv> {
let mut env =
LocalEnv::parse_config(&toml_file).context("Failed to create neon configuration")?;
let force = init_match.get_one("force").expect("we set a default value");
let force = init_match.get_flag("force");
env.init(pg_version, force)
.context("Failed to initialize neon repository")?;
@@ -407,82 +373,52 @@ fn pageserver_config_overrides(init_match: &ArgMatches) -> Vec<&str> {
.collect()
}
async fn handle_tenant(
tenant_match: &ArgMatches,
env: &mut local_env::LocalEnv,
) -> anyhow::Result<()> {
fn handle_tenant(tenant_match: &ArgMatches, env: &mut local_env::LocalEnv) -> anyhow::Result<()> {
let pageserver = get_default_pageserver(env);
match tenant_match.subcommand() {
Some(("list", _)) => {
for t in pageserver.tenant_list().await? {
for t in pageserver.tenant_list()? {
println!("{} {:?}", t.id, t.state);
}
}
Some(("create", create_match)) => {
let tenant_conf: HashMap<_, _> = create_match
.get_many::<String>("config")
.map(|vals: clap::parser::ValuesRef<'_, String>| {
vals.flat_map(|c| c.split_once(':')).collect()
})
.map(|vals| vals.flat_map(|c| c.split_once(':')).collect())
.unwrap_or_default();
let shard_count: u8 = create_match
.get_one::<u8>("shard-count")
.cloned()
.unwrap_or(0);
let shard_stripe_size: Option<u32> =
create_match.get_one::<u32>("shard-stripe-size").cloned();
let tenant_conf = PageServerNode::parse_config(tenant_conf)?;
// If tenant ID was not specified, generate one
let tenant_id = parse_tenant_id(create_match)?.unwrap_or_else(TenantId::generate);
// We must register the tenant with the attachment service, so
// that when the pageserver restarts, it will be re-attached.
let attachment_service = AttachmentService::from_env(env);
attachment_service
.tenant_create(TenantCreateRequest {
// Note that ::unsharded here isn't actually because the tenant is unsharded, its because the
// attachment service expecfs a shard-naive tenant_id in this attribute, and the TenantCreateRequest
// type is used both in attachment service (for creating tenants) and in pageserver (for creating shards)
new_tenant_id: TenantShardId::unsharded(tenant_id),
generation: None,
shard_parameters: ShardParameters {
count: ShardCount(shard_count),
stripe_size: shard_stripe_size
.map(ShardStripeSize)
.unwrap_or(ShardParameters::DEFAULT_STRIPE_SIZE),
},
config: tenant_conf,
})
.await?;
let generation = if env.control_plane_api.is_some() {
// We must register the tenant with the attachment service, so
// that when the pageserver restarts, it will be re-attached.
let attachment_service = AttachmentService::from_env(env);
attachment_service.attach_hook(tenant_id, pageserver.conf.id)?
} else {
None
};
pageserver.tenant_create(tenant_id, generation, tenant_conf)?;
println!("tenant {tenant_id} successfully created on the pageserver");
// Create an initial timeline for the new tenant
let new_timeline_id =
parse_timeline_id(create_match)?.unwrap_or(TimelineId::generate());
let new_timeline_id = parse_timeline_id(create_match)?;
let pg_version = create_match
.get_one::<u32>("pg-version")
.copied()
.context("Failed to parse postgres version from the argument string")?;
// FIXME: passing None for ancestor_start_lsn is not kosher in a sharded world: we can't have
// different shards picking different start lsns. Maybe we have to teach attachment service
// to let shard 0 branch first and then propagate the chosen LSN to other shards.
attachment_service
.tenant_timeline_create(
tenant_id,
TimelineCreateRequest {
new_timeline_id,
ancestor_timeline_id: None,
ancestor_start_lsn: None,
existing_initdb_timeline_id: None,
pg_version: Some(pg_version),
},
)
.await?;
let timeline_info = pageserver.timeline_create(
tenant_id,
new_timeline_id,
None,
None,
Some(pg_version),
None,
)?;
let new_timeline_id = timeline_info.timeline_id;
let last_record_lsn = timeline_info.last_record_lsn;
env.register_branch_mapping(
DEFAULT_BRANCH_NAME.to_string(),
@@ -490,7 +426,9 @@ async fn handle_tenant(
new_timeline_id,
)?;
println!("Created an initial timeline '{new_timeline_id}' for tenant: {tenant_id}",);
println!(
"Created an initial timeline '{new_timeline_id}' at Lsn {last_record_lsn} for tenant: {tenant_id}",
);
if create_match.get_flag("set-default") {
println!("Setting tenant {tenant_id} as a default one");
@@ -512,84 +450,31 @@ async fn handle_tenant(
pageserver
.tenant_config(tenant_id, tenant_conf)
.await
.with_context(|| format!("Tenant config failed for tenant with id {tenant_id}"))?;
println!("tenant {tenant_id} successfully configured on the pageserver");
}
Some(("migrate", matches)) => {
let tenant_shard_id = get_tenant_shard_id(matches, env)?;
let tenant_id = get_tenant_id(matches, env)?;
let new_pageserver = get_pageserver(env, matches)?;
let new_pageserver_id = new_pageserver.conf.id;
let attachment_service = AttachmentService::from_env(env);
attachment_service
.tenant_migrate(tenant_shard_id, new_pageserver_id)
.await?;
println!("tenant {tenant_shard_id} migrated to {}", new_pageserver_id);
migrate_tenant(env, tenant_id, new_pageserver)?;
println!("tenant {tenant_id} migrated to {}", new_pageserver_id);
}
Some(("status", matches)) => {
let tenant_id = get_tenant_id(matches, env)?;
let mut shard_table = comfy_table::Table::new();
shard_table.set_header(["Shard", "Pageserver", "Physical Size"]);
let mut tenant_synthetic_size = None;
let attachment_service = AttachmentService::from_env(env);
for shard in attachment_service.tenant_locate(tenant_id).await?.shards {
let pageserver =
PageServerNode::from_env(env, env.get_pageserver_conf(shard.node_id)?);
let size = pageserver
.http_client
.tenant_details(shard.shard_id)
.await?
.tenant_info
.current_physical_size
.unwrap();
shard_table.add_row([
format!("{}", shard.shard_id.shard_slug()),
format!("{}", shard.node_id.0),
format!("{} MiB", size / (1024 * 1024)),
]);
if shard.shard_id.is_zero() {
tenant_synthetic_size =
Some(pageserver.tenant_synthetic_size(shard.shard_id).await?);
}
}
let Some(synthetic_size) = tenant_synthetic_size else {
bail!("Shard 0 not found")
};
let mut tenant_table = comfy_table::Table::new();
tenant_table.add_row(["Tenant ID".to_string(), tenant_id.to_string()]);
tenant_table.add_row([
"Synthetic size".to_string(),
format!("{} MiB", synthetic_size.size.unwrap_or(0) / (1024 * 1024)),
]);
println!("{tenant_table}");
println!("{shard_table}");
}
Some((sub_name, _)) => bail!("Unexpected tenant subcommand '{}'", sub_name),
None => bail!("no tenant subcommand provided"),
}
Ok(())
}
async fn handle_timeline(timeline_match: &ArgMatches, env: &mut local_env::LocalEnv) -> Result<()> {
fn handle_timeline(timeline_match: &ArgMatches, env: &mut local_env::LocalEnv) -> Result<()> {
let pageserver = get_default_pageserver(env);
match timeline_match.subcommand() {
Some(("list", list_match)) => {
// TODO(sharding): this command shouldn't have to specify a shard ID: we should ask the attachment service
// where shard 0 is attached, and query there.
let tenant_shard_id = get_tenant_shard_id(list_match, env)?;
let timelines = pageserver.timeline_list(&tenant_shard_id).await?;
let tenant_id = get_tenant_id(list_match, env)?;
let timelines = pageserver.timeline_list(&tenant_id)?;
print_timelines_tree(timelines, env.timeline_name_mappings())?;
}
Some(("create", create_match)) => {
@@ -604,19 +489,16 @@ async fn handle_timeline(timeline_match: &ArgMatches, env: &mut local_env::Local
.context("Failed to parse postgres version from the argument string")?;
let new_timeline_id_opt = parse_timeline_id(create_match)?;
let new_timeline_id = new_timeline_id_opt.unwrap_or(TimelineId::generate());
let attachment_service = AttachmentService::from_env(env);
let create_req = TimelineCreateRequest {
new_timeline_id,
ancestor_timeline_id: None,
existing_initdb_timeline_id: None,
ancestor_start_lsn: None,
pg_version: Some(pg_version),
};
let timeline_info = attachment_service
.tenant_timeline_create(tenant_id, create_req)
.await?;
let timeline_info = pageserver.timeline_create(
tenant_id,
new_timeline_id_opt,
None,
None,
Some(pg_version),
None,
)?;
let new_timeline_id = timeline_info.timeline_id;
let last_record_lsn = timeline_info.last_record_lsn;
env.register_branch_mapping(new_branch_name.to_string(), tenant_id, new_timeline_id)?;
@@ -660,9 +542,7 @@ async fn handle_timeline(timeline_match: &ArgMatches, env: &mut local_env::Local
let mut cplane = ComputeControlPlane::load(env.clone())?;
println!("Importing timeline into pageserver ...");
pageserver
.timeline_import(tenant_id, timeline_id, base, pg_wal, pg_version)
.await?;
pageserver.timeline_import(tenant_id, timeline_id, base, pg_wal, pg_version)?;
env.register_branch_mapping(name.to_string(), tenant_id, timeline_id)?;
println!("Creating endpoint for imported timeline ...");
@@ -674,6 +554,7 @@ async fn handle_timeline(timeline_match: &ArgMatches, env: &mut local_env::Local
None,
pg_version,
ComputeMode::Primary,
DEFAULT_PAGESERVER_ID,
)?;
println!("Done");
}
@@ -697,18 +578,15 @@ async fn handle_timeline(timeline_match: &ArgMatches, env: &mut local_env::Local
.map(|lsn_str| Lsn::from_str(lsn_str))
.transpose()
.context("Failed to parse ancestor start Lsn from the request")?;
let new_timeline_id = TimelineId::generate();
let attachment_service = AttachmentService::from_env(env);
let create_req = TimelineCreateRequest {
new_timeline_id,
ancestor_timeline_id: Some(ancestor_timeline_id),
existing_initdb_timeline_id: None,
ancestor_start_lsn: start_lsn,
pg_version: None,
};
let timeline_info = attachment_service
.tenant_timeline_create(tenant_id, create_req)
.await?;
let timeline_info = pageserver.timeline_create(
tenant_id,
None,
start_lsn,
Some(ancestor_timeline_id),
None,
None,
)?;
let new_timeline_id = timeline_info.timeline_id;
let last_record_lsn = timeline_info.last_record_lsn;
@@ -726,7 +604,7 @@ async fn handle_timeline(timeline_match: &ArgMatches, env: &mut local_env::Local
Ok(())
}
async fn handle_endpoint(ep_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<()> {
fn handle_endpoint(ep_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<()> {
let (sub_name, sub_args) = match ep_match.subcommand() {
Some(ep_subcommand_data) => ep_subcommand_data,
None => bail!("no endpoint subcommand provided"),
@@ -735,15 +613,11 @@ async fn handle_endpoint(ep_match: &ArgMatches, env: &local_env::LocalEnv) -> Re
match sub_name {
"list" => {
// TODO(sharding): this command shouldn't have to specify a shard ID: we should ask the attachment service
// where shard 0 is attached, and query there.
let tenant_shard_id = get_tenant_shard_id(sub_args, env)?;
let timeline_infos = get_timeline_infos(env, &tenant_shard_id)
.await
.unwrap_or_else(|e| {
eprintln!("Failed to load timeline info: {}", e);
HashMap::new()
});
let tenant_id = get_tenant_id(sub_args, env)?;
let timeline_infos = get_timeline_infos(env, &tenant_id).unwrap_or_else(|e| {
eprintln!("Failed to load timeline info: {}", e);
HashMap::new()
});
let timeline_name_mappings = env.timeline_name_mappings();
@@ -763,7 +637,7 @@ async fn handle_endpoint(ep_match: &ArgMatches, env: &local_env::LocalEnv) -> Re
for (endpoint_id, endpoint) in cplane
.endpoints
.iter()
.filter(|(_, endpoint)| endpoint.tenant_id == tenant_shard_id.tenant_id)
.filter(|(_, endpoint)| endpoint.tenant_id == tenant_id)
{
let lsn_str = match endpoint.mode {
ComputeMode::Static(lsn) => {
@@ -782,10 +656,7 @@ async fn handle_endpoint(ep_match: &ArgMatches, env: &local_env::LocalEnv) -> Re
};
let branch_name = timeline_name_mappings
.get(&TenantTimelineId::new(
tenant_shard_id.tenant_id,
endpoint.timeline_id,
))
.get(&TenantTimelineId::new(tenant_id, endpoint.timeline_id))
.map(|name| name.as_str())
.unwrap_or("?");
@@ -833,6 +704,13 @@ async fn handle_endpoint(ep_match: &ArgMatches, env: &local_env::LocalEnv) -> Re
.copied()
.unwrap_or(false);
let pageserver_id =
if let Some(id_str) = sub_args.get_one::<String>("endpoint-pageserver-id") {
NodeId(id_str.parse().context("while parsing pageserver id")?)
} else {
DEFAULT_PAGESERVER_ID
};
let mode = match (lsn, hot_standby) {
(Some(lsn), false) => ComputeMode::Static(lsn),
(None, true) => ComputeMode::Replica,
@@ -860,6 +738,7 @@ async fn handle_endpoint(ep_match: &ArgMatches, env: &local_env::LocalEnv) -> Re
http_port,
pg_version,
mode,
pageserver_id,
)?;
}
"start" => {
@@ -869,11 +748,9 @@ async fn handle_endpoint(ep_match: &ArgMatches, env: &local_env::LocalEnv) -> Re
let pageserver_id =
if let Some(id_str) = sub_args.get_one::<String>("endpoint-pageserver-id") {
Some(NodeId(
id_str.parse().context("while parsing pageserver id")?,
))
NodeId(id_str.parse().context("while parsing pageserver id")?)
} else {
None
DEFAULT_PAGESERVER_ID
};
let remote_ext_config = sub_args.get_one::<String>("remote-ext-config");
@@ -904,38 +781,7 @@ async fn handle_endpoint(ep_match: &ArgMatches, env: &local_env::LocalEnv) -> Re
endpoint.timeline_id,
)?;
let (pageservers, stripe_size) = if let Some(pageserver_id) = pageserver_id {
let conf = env.get_pageserver_conf(pageserver_id).unwrap();
let parsed = parse_host_port(&conf.listen_pg_addr).expect("Bad config");
(
vec![(parsed.0, parsed.1.unwrap_or(5432))],
// If caller is telling us what pageserver to use, this is not a tenant which is
// full managed by attachment service, therefore not sharded.
ShardParameters::DEFAULT_STRIPE_SIZE,
)
} else {
// Look up the currently attached location of the tenant, and its striping metadata,
// to pass these on to postgres.
let attachment_service = AttachmentService::from_env(env);
let locate_result = attachment_service.tenant_locate(endpoint.tenant_id).await?;
let pageservers = locate_result
.shards
.into_iter()
.map(|shard| {
(
Host::parse(&shard.listen_pg_addr)
.expect("Attachment service reported bad hostname"),
shard.listen_pg_port,
)
})
.collect::<Vec<_>>();
let stripe_size = locate_result.shard_params.stripe_size;
(pageservers, stripe_size)
};
assert!(!pageservers.is_empty());
let ps_conf = env.get_pageserver_conf(DEFAULT_PAGESERVER_ID)?;
let ps_conf = env.get_pageserver_conf(pageserver_id)?;
let auth_token = if matches!(ps_conf.pg_auth_type, AuthType::NeonJWT) {
let claims = Claims::new(Some(endpoint.tenant_id), Scope::Tenant);
@@ -945,15 +791,7 @@ async fn handle_endpoint(ep_match: &ArgMatches, env: &local_env::LocalEnv) -> Re
};
println!("Starting existing endpoint {endpoint_id}...");
endpoint
.start(
&auth_token,
safekeepers,
pageservers,
remote_ext_config,
stripe_size.0 as usize,
)
.await?;
endpoint.start(&auth_token, safekeepers, remote_ext_config)?;
}
"reconfigure" => {
let endpoint_id = sub_args
@@ -963,31 +801,15 @@ async fn handle_endpoint(ep_match: &ArgMatches, env: &local_env::LocalEnv) -> Re
.endpoints
.get(endpoint_id.as_str())
.with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?;
let pageservers =
let pageserver_id =
if let Some(id_str) = sub_args.get_one::<String>("endpoint-pageserver-id") {
let ps_id = NodeId(id_str.parse().context("while parsing pageserver id")?);
let pageserver = PageServerNode::from_env(env, env.get_pageserver_conf(ps_id)?);
vec![(
pageserver.pg_connection_config.host().clone(),
pageserver.pg_connection_config.port(),
)]
Some(NodeId(
id_str.parse().context("while parsing pageserver id")?,
))
} else {
let attachment_service = AttachmentService::from_env(env);
attachment_service
.tenant_locate(endpoint.tenant_id)
.await?
.shards
.into_iter()
.map(|shard| {
(
Host::parse(&shard.listen_pg_addr)
.expect("Attachment service reported malformed host"),
shard.listen_pg_port,
)
})
.collect::<Vec<_>>()
None
};
endpoint.reconfigure(pageservers).await?;
endpoint.reconfigure(pageserver_id)?;
}
"stop" => {
let endpoint_id = sub_args
@@ -1053,12 +875,11 @@ fn get_pageserver(env: &local_env::LocalEnv, args: &ArgMatches) -> Result<PageSe
))
}
async fn handle_pageserver(sub_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<()> {
fn handle_pageserver(sub_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<()> {
match sub_match.subcommand() {
Some(("start", subcommand_args)) => {
if let Err(e) = get_pageserver(env, subcommand_args)?
.start(&pageserver_config_overrides(subcommand_args))
.await
{
eprintln!("pageserver start failed: {e}");
exit(1);
@@ -1085,10 +906,7 @@ async fn handle_pageserver(sub_match: &ArgMatches, env: &local_env::LocalEnv) ->
exit(1);
}
if let Err(e) = pageserver
.start(&pageserver_config_overrides(subcommand_args))
.await
{
if let Err(e) = pageserver.start(&pageserver_config_overrides(subcommand_args)) {
eprintln!("pageserver start failed: {e}");
exit(1);
}
@@ -1102,32 +920,14 @@ async fn handle_pageserver(sub_match: &ArgMatches, env: &local_env::LocalEnv) ->
exit(1);
}
if let Err(e) = pageserver
.start(&pageserver_config_overrides(subcommand_args))
.await
{
if let Err(e) = pageserver.start(&pageserver_config_overrides(subcommand_args)) {
eprintln!("pageserver start failed: {e}");
exit(1);
}
}
Some(("set-state", subcommand_args)) => {
let pageserver = get_pageserver(env, subcommand_args)?;
let scheduling = subcommand_args.get_one("scheduling");
let availability = subcommand_args.get_one("availability");
let attachment_service = AttachmentService::from_env(env);
attachment_service
.node_configure(NodeConfigureRequest {
node_id: pageserver.conf.id,
scheduling: scheduling.cloned(),
availability: availability.cloned(),
})
.await?;
}
Some(("status", subcommand_args)) => {
match get_pageserver(env, subcommand_args)?.check_status().await {
match get_pageserver(env, subcommand_args)?.check_status() {
Ok(_) => println!("Page server is up and running"),
Err(err) => {
eprintln!("Page server is not available: {}", err);
@@ -1142,14 +942,11 @@ async fn handle_pageserver(sub_match: &ArgMatches, env: &local_env::LocalEnv) ->
Ok(())
}
async fn handle_attachment_service(
sub_match: &ArgMatches,
env: &local_env::LocalEnv,
) -> Result<()> {
fn handle_attachment_service(sub_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<()> {
let svc = AttachmentService::from_env(env);
match sub_match.subcommand() {
Some(("start", _start_match)) => {
if let Err(e) = svc.start().await {
if let Err(e) = svc.start() {
eprintln!("start failed: {e}");
exit(1);
}
@@ -1190,7 +987,7 @@ fn safekeeper_extra_opts(init_match: &ArgMatches) -> Vec<String> {
.collect()
}
async fn handle_safekeeper(sub_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<()> {
fn handle_safekeeper(sub_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<()> {
let (sub_name, sub_args) = match sub_match.subcommand() {
Some(safekeeper_command_data) => safekeeper_command_data,
None => bail!("no safekeeper subcommand provided"),
@@ -1208,7 +1005,7 @@ async fn handle_safekeeper(sub_match: &ArgMatches, env: &local_env::LocalEnv) ->
"start" => {
let extra_opts = safekeeper_extra_opts(sub_args);
if let Err(e) = safekeeper.start(extra_opts).await {
if let Err(e) = safekeeper.start(extra_opts) {
eprintln!("safekeeper start failed: {}", e);
exit(1);
}
@@ -1234,7 +1031,7 @@ async fn handle_safekeeper(sub_match: &ArgMatches, env: &local_env::LocalEnv) ->
}
let extra_opts = safekeeper_extra_opts(sub_args);
if let Err(e) = safekeeper.start(extra_opts).await {
if let Err(e) = safekeeper.start(extra_opts) {
eprintln!("safekeeper start failed: {}", e);
exit(1);
}
@@ -1247,15 +1044,15 @@ async fn handle_safekeeper(sub_match: &ArgMatches, env: &local_env::LocalEnv) ->
Ok(())
}
async fn handle_start_all(sub_match: &ArgMatches, env: &local_env::LocalEnv) -> anyhow::Result<()> {
fn handle_start_all(sub_match: &ArgMatches, env: &local_env::LocalEnv) -> anyhow::Result<()> {
// Endpoints are not started automatically
broker::start_broker_process(env).await?;
broker::start_broker_process(env)?;
// Only start the attachment service if the pageserver is configured to need it
if env.control_plane_api.is_some() {
let attachment_service = AttachmentService::from_env(env);
if let Err(e) = attachment_service.start().await {
if let Err(e) = attachment_service.start() {
eprintln!("attachment_service start failed: {:#}", e);
try_stop_all(env, true);
exit(1);
@@ -1264,10 +1061,7 @@ async fn handle_start_all(sub_match: &ArgMatches, env: &local_env::LocalEnv) ->
for ps_conf in &env.pageservers {
let pageserver = PageServerNode::from_env(env, ps_conf);
if let Err(e) = pageserver
.start(&pageserver_config_overrides(sub_match))
.await
{
if let Err(e) = pageserver.start(&pageserver_config_overrides(sub_match)) {
eprintln!("pageserver {} start failed: {:#}", ps_conf.id, e);
try_stop_all(env, true);
exit(1);
@@ -1276,7 +1070,7 @@ async fn handle_start_all(sub_match: &ArgMatches, env: &local_env::LocalEnv) ->
for node in env.safekeepers.iter() {
let safekeeper = SafekeeperNode::from_env(env, node);
if let Err(e) = safekeeper.start(vec![]).await {
if let Err(e) = safekeeper.start(vec![]) {
eprintln!("safekeeper {} start failed: {:#}", safekeeper.id, e);
try_stop_all(env, false);
exit(1);
@@ -1433,15 +1227,9 @@ fn cli() -> Command {
.required(false);
let force_arg = Arg::new("force")
.value_parser(value_parser!(InitForceMode))
.value_parser(value_parser!(bool))
.long("force")
.default_value(
InitForceMode::MustNotExist
.to_possible_value()
.unwrap()
.get_name()
.to_owned(),
)
.action(ArgAction::SetTrue)
.help("Force initialization even if the repository is not empty")
.required(false);
@@ -1525,8 +1313,6 @@ fn cli() -> Command {
.arg(pg_version_arg.clone())
.arg(Arg::new("set-default").long("set-default").action(ArgAction::SetTrue).required(false)
.help("Use this tenant in future CLI commands where tenant_id is needed, but not specified"))
.arg(Arg::new("shard-count").value_parser(value_parser!(u8)).long("shard-count").action(ArgAction::Set).help("Number of shards in the new tenant (default 1)"))
.arg(Arg::new("shard-stripe-size").value_parser(value_parser!(u32)).long("shard-stripe-size").action(ArgAction::Set).help("Sharding stripe size in pages"))
)
.subcommand(Command::new("set-default").arg(tenant_id_arg.clone().required(true))
.about("Set a particular tenant as default in future CLI commands where tenant_id is needed, but not specified"))
@@ -1537,9 +1323,6 @@ fn cli() -> Command {
.about("Migrate a tenant from one pageserver to another")
.arg(tenant_id_arg.clone())
.arg(pageserver_id_arg.clone()))
.subcommand(Command::new("status")
.about("Human readable summary of the tenant's shards and attachment locations")
.arg(tenant_id_arg.clone()))
)
.subcommand(
Command::new("pageserver")
@@ -1559,12 +1342,6 @@ fn cli() -> Command {
.about("Restart local pageserver")
.arg(pageserver_config_args.clone())
)
.subcommand(Command::new("set-state")
.arg(Arg::new("availability").value_parser(value_parser!(NodeAvailability)).long("availability").action(ArgAction::Set).help("Availability state: offline,active"))
.arg(Arg::new("scheduling").value_parser(value_parser!(NodeSchedulingPolicy)).long("scheduling").action(ArgAction::Set).help("Scheduling state: draining,pause,filling,active"))
.about("Set scheduling or availability state of pageserver node")
.arg(pageserver_config_args.clone())
)
)
.subcommand(
Command::new("attachment_service")

View File

@@ -11,7 +11,7 @@ use camino::Utf8PathBuf;
use crate::{background_process, local_env};
pub async fn start_broker_process(env: &local_env::LocalEnv) -> anyhow::Result<()> {
pub fn start_broker_process(env: &local_env::LocalEnv) -> anyhow::Result<()> {
let broker = &env.broker;
let listen_addr = &broker.listen_addr;
@@ -19,15 +19,15 @@ pub async fn start_broker_process(env: &local_env::LocalEnv) -> anyhow::Result<(
let args = [format!("--listen-addr={listen_addr}")];
let client = reqwest::Client::new();
let client = reqwest::blocking::Client::new();
background_process::start_process(
"storage_broker",
&env.base_data_dir,
&env.storage_broker_bin(),
args,
[],
background_process::InitialPidFile::Create(storage_broker_pid_file_path(env)),
|| async {
background_process::InitialPidFile::Create(&storage_broker_pid_file_path(env)),
|| {
let url = broker.client_url();
let status_url = url.join("status").with_context(|| {
format!("Failed to append /status path to broker endpoint {url}")
@@ -36,13 +36,12 @@ pub async fn start_broker_process(env: &local_env::LocalEnv) -> anyhow::Result<(
.get(status_url)
.build()
.with_context(|| format!("Failed to construct request to broker endpoint {url}"))?;
match client.execute(request).await {
match client.execute(request) {
Ok(resp) => Ok(resp.status().is_success()),
Err(_) => Ok(false),
}
},
)
.await
.context("Failed to spawn storage_broker subprocess")?;
Ok(())
}

View File

@@ -46,14 +46,11 @@ use std::time::Duration;
use anyhow::{anyhow, bail, Context, Result};
use compute_api::spec::RemoteExtSpec;
use nix::sys::signal::kill;
use nix::sys::signal::Signal;
use serde::{Deserialize, Serialize};
use url::Host;
use utils::id::{NodeId, TenantId, TimelineId};
use crate::attachment_service::AttachmentService;
use crate::local_env::LocalEnv;
use crate::pageserver::PageServerNode;
use crate::postgresql_conf::PostgresConf;
use compute_api::responses::{ComputeState, ComputeStatus};
@@ -70,6 +67,7 @@ pub struct EndpointConf {
http_port: u16,
pg_version: u32,
skip_pg_catalog_updates: bool,
pageserver_id: NodeId,
}
//
@@ -121,14 +119,19 @@ impl ComputeControlPlane {
http_port: Option<u16>,
pg_version: u32,
mode: ComputeMode,
pageserver_id: NodeId,
) -> Result<Arc<Endpoint>> {
let pg_port = pg_port.unwrap_or_else(|| self.get_port());
let http_port = http_port.unwrap_or_else(|| self.get_port() + 1);
let pageserver =
PageServerNode::from_env(&self.env, self.env.get_pageserver_conf(pageserver_id)?);
let ep = Arc::new(Endpoint {
endpoint_id: endpoint_id.to_owned(),
pg_address: SocketAddr::new("127.0.0.1".parse().unwrap(), pg_port),
http_address: SocketAddr::new("127.0.0.1".parse().unwrap(), http_port),
env: self.env.clone(),
pageserver,
timeline_id,
mode,
tenant_id,
@@ -154,6 +157,7 @@ impl ComputeControlPlane {
pg_port,
pg_version,
skip_pg_catalog_updates: true,
pageserver_id,
})?,
)?;
std::fs::write(
@@ -212,6 +216,7 @@ pub struct Endpoint {
// These are not part of the endpoint as such, but the environment
// the endpoint runs in.
pub env: LocalEnv,
pageserver: PageServerNode,
// Optimizations
skip_pg_catalog_updates: bool,
@@ -234,11 +239,15 @@ impl Endpoint {
let conf: EndpointConf =
serde_json::from_slice(&std::fs::read(entry.path().join("endpoint.json"))?)?;
let pageserver =
PageServerNode::from_env(env, env.get_pageserver_conf(conf.pageserver_id)?);
Ok(Endpoint {
pg_address: SocketAddr::new("127.0.0.1".parse().unwrap(), conf.pg_port),
http_address: SocketAddr::new("127.0.0.1".parse().unwrap(), conf.http_port),
endpoint_id,
env: env.clone(),
pageserver,
timeline_id: conf.timeline_id,
mode: conf.mode,
tenant_id: conf.tenant_id,
@@ -430,14 +439,11 @@ impl Endpoint {
Ok(())
}
fn wait_for_compute_ctl_to_exit(&self, send_sigterm: bool) -> Result<()> {
fn wait_for_compute_ctl_to_exit(&self) -> Result<()> {
// TODO use background_process::stop_process instead
let pidfile_path = self.endpoint_path().join("compute_ctl.pid");
let pid: u32 = std::fs::read_to_string(pidfile_path)?.parse()?;
let pid = nix::unistd::Pid::from_raw(pid as i32);
if send_sigterm {
kill(pid, Signal::SIGTERM).ok();
}
crate::background_process::wait_until_stopped("compute_ctl", pid)?;
Ok(())
}
@@ -458,21 +464,11 @@ impl Endpoint {
}
}
fn build_pageserver_connstr(pageservers: &[(Host, u16)]) -> String {
pageservers
.iter()
.map(|(host, port)| format!("postgresql://no_user@{host}:{port}"))
.collect::<Vec<_>>()
.join(",")
}
pub async fn start(
pub fn start(
&self,
auth_token: &Option<String>,
safekeepers: Vec<NodeId>,
pageservers: Vec<(Host, u16)>,
remote_ext_config: Option<&String>,
shard_stripe_size: usize,
) -> Result<()> {
if self.status() == "running" {
anyhow::bail!("The endpoint is already running");
@@ -486,9 +482,13 @@ impl Endpoint {
std::fs::remove_dir_all(self.pgdata())?;
}
let pageserver_connstring = Self::build_pageserver_connstr(&pageservers);
assert!(!pageserver_connstring.is_empty());
let pageserver_connstring = {
let config = &self.pageserver.pg_connection_config;
let (host, port) = (config.host(), config.port());
// NOTE: avoid spaces in connection string, because it is less error prone if we forward it somewhere.
format!("postgresql://no_user@{host}:{port}")
};
let mut safekeeper_connstrings = Vec::new();
if self.mode == ComputeMode::Primary {
for sk_id in safekeepers {
@@ -537,8 +537,6 @@ impl Endpoint {
safekeeper_connstrings,
storage_auth_token: auth_token.clone(),
remote_extensions,
pgbouncer_settings: None,
shard_stripe_size: Some(shard_stripe_size),
};
let spec_path = self.endpoint_path().join("spec.json");
std::fs::write(spec_path, serde_json::to_string_pretty(&spec)?)?;
@@ -589,7 +587,7 @@ impl Endpoint {
const MAX_ATTEMPTS: u32 = 10 * 30; // Wait up to 30 s
loop {
attempt += 1;
match self.get_status().await {
match self.get_status() {
Ok(state) => {
match state.status {
ComputeStatus::Init => {
@@ -631,8 +629,8 @@ impl Endpoint {
}
// Call the /status HTTP API
pub async fn get_status(&self) -> Result<ComputeState> {
let client = reqwest::Client::new();
pub fn get_status(&self) -> Result<ComputeState> {
let client = reqwest::blocking::Client::new();
let response = client
.request(
@@ -643,17 +641,16 @@ impl Endpoint {
self.http_address.port()
),
)
.send()
.await?;
.send()?;
// Interpret the response
let status = response.status();
if !(status.is_client_error() || status.is_server_error()) {
Ok(response.json().await?)
Ok(response.json()?)
} else {
// reqwest does not export its error construction utility functions, so let's craft the message ourselves
let url = response.url().to_owned();
let msg = match response.text().await {
let msg = match response.text() {
Ok(err_body) => format!("Error: {}", err_body),
Err(_) => format!("Http error ({}) at {}.", status.as_u16(), url),
};
@@ -661,7 +658,7 @@ impl Endpoint {
}
}
pub async fn reconfigure(&self, mut pageservers: Vec<(Host, u16)>) -> Result<()> {
pub fn reconfigure(&self, pageserver_id: Option<NodeId>) -> Result<()> {
let mut spec: ComputeSpec = {
let spec_path = self.endpoint_path().join("spec.json");
let file = std::fs::File::open(spec_path)?;
@@ -671,28 +668,26 @@ impl Endpoint {
let postgresql_conf = self.read_postgresql_conf()?;
spec.cluster.postgresql_conf = Some(postgresql_conf);
// If we weren't given explicit pageservers, query the attachment service
if pageservers.is_empty() {
let attachment_service = AttachmentService::from_env(&self.env);
let locate_result = attachment_service.tenant_locate(self.tenant_id).await?;
pageservers = locate_result
.shards
.into_iter()
.map(|shard| {
(
Host::parse(&shard.listen_pg_addr)
.expect("Attachment service reported bad hostname"),
shard.listen_pg_port,
)
})
.collect::<Vec<_>>();
if let Some(pageserver_id) = pageserver_id {
let endpoint_config_path = self.endpoint_path().join("endpoint.json");
let mut endpoint_conf: EndpointConf = {
let file = std::fs::File::open(&endpoint_config_path)?;
serde_json::from_reader(file)?
};
endpoint_conf.pageserver_id = pageserver_id;
std::fs::write(
endpoint_config_path,
serde_json::to_string_pretty(&endpoint_conf)?,
)?;
let pageserver =
PageServerNode::from_env(&self.env, self.env.get_pageserver_conf(pageserver_id)?);
let ps_http_conf = &pageserver.pg_connection_config;
let (host, port) = (ps_http_conf.host(), ps_http_conf.port());
spec.pageserver_connstring = Some(format!("postgresql://no_user@{host}:{port}"));
}
let pageserver_connstr = Self::build_pageserver_connstr(&pageservers);
assert!(!pageserver_connstr.is_empty());
spec.pageserver_connstring = Some(pageserver_connstr);
let client = reqwest::Client::new();
let client = reqwest::blocking::Client::new();
let response = client
.post(format!(
"http://{}:{}/configure",
@@ -703,15 +698,14 @@ impl Endpoint {
"{{\"spec\":{}}}",
serde_json::to_string_pretty(&spec)?
))
.send()
.await?;
.send()?;
let status = response.status();
if !(status.is_client_error() || status.is_server_error()) {
Ok(())
} else {
let url = response.url().to_owned();
let msg = match response.text().await {
let msg = match response.text() {
Ok(err_body) => format!("Error: {}", err_body),
Err(_) => format!("Http error ({}) at {}.", status.as_u16(), url),
};
@@ -736,15 +730,10 @@ impl Endpoint {
&None,
)?;
// Also wait for the compute_ctl process to die. It might have some
// cleanup work to do after postgres stops, like syncing safekeepers,
// etc.
// Also wait for the compute_ctl process to die. It might have some cleanup
// work to do after postgres stops, like syncing safekeepers, etc.
//
// If destroying, send it SIGTERM before waiting. Sometimes we do *not*
// want this cleanup: tests intentionally do stop when majority of
// safekeepers is down, so sync-safekeepers would hang otherwise. This
// could be a separate flag though.
self.wait_for_compute_ctl_to_exit(destroy)?;
self.wait_for_compute_ctl_to_exit()?;
if destroy {
println!(
"Destroying postgres data directory '{}'",

View File

@@ -14,3 +14,4 @@ pub mod local_env;
pub mod pageserver;
pub mod postgresql_conf;
pub mod safekeeper;
pub mod tenant_migration;

View File

@@ -5,7 +5,6 @@
use anyhow::{bail, ensure, Context};
use clap::ValueEnum;
use postgres_backend::AuthType;
use reqwest::Url;
use serde::{Deserialize, Serialize};
@@ -163,31 +162,6 @@ impl Default for SafekeeperConf {
}
}
#[derive(Clone, Copy)]
pub enum InitForceMode {
MustNotExist,
EmptyDirOk,
RemoveAllContents,
}
impl ValueEnum for InitForceMode {
fn value_variants<'a>() -> &'a [Self] {
&[
Self::MustNotExist,
Self::EmptyDirOk,
Self::RemoveAllContents,
]
}
fn to_possible_value(&self) -> Option<clap::builder::PossibleValue> {
Some(clap::builder::PossibleValue::new(match self {
InitForceMode::MustNotExist => "must-not-exist",
InitForceMode::EmptyDirOk => "empty-dir-ok",
InitForceMode::RemoveAllContents => "remove-all-contents",
}))
}
}
impl SafekeeperConf {
/// Compute is served by port on which only tenant scoped tokens allowed, if
/// it is configured.
@@ -251,13 +225,7 @@ impl LocalEnv {
if let Some(conf) = self.pageservers.iter().find(|node| node.id == id) {
Ok(conf)
} else {
let have_ids = self
.pageservers
.iter()
.map(|node| format!("{}:{}", node.id, node.listen_http_addr))
.collect::<Vec<_>>();
let joined = have_ids.join(",");
bail!("could not find pageserver {id}, have ids {joined}")
bail!("could not find pageserver {id}")
}
}
@@ -416,7 +384,7 @@ impl LocalEnv {
//
// Initialize a new Neon repository
//
pub fn init(&mut self, pg_version: u32, force: &InitForceMode) -> anyhow::Result<()> {
pub fn init(&mut self, pg_version: u32, force: bool) -> anyhow::Result<()> {
// check if config already exists
let base_path = &self.base_data_dir;
ensure!(
@@ -425,34 +393,25 @@ impl LocalEnv {
);
if base_path.exists() {
match force {
InitForceMode::MustNotExist => {
bail!(
"directory '{}' already exists. Perhaps already initialized?",
base_path.display()
);
}
InitForceMode::EmptyDirOk => {
if let Some(res) = std::fs::read_dir(base_path)?.next() {
res.context("check if directory is empty")?;
anyhow::bail!("directory not empty: {base_path:?}");
}
}
InitForceMode::RemoveAllContents => {
println!("removing all contents of '{}'", base_path.display());
// instead of directly calling `remove_dir_all`, we keep the original dir but removing
// all contents inside. This helps if the developer symbol links another directory (i.e.,
// S3 local SSD) to the `.neon` base directory.
for entry in std::fs::read_dir(base_path)? {
let entry = entry?;
let path = entry.path();
if path.is_dir() {
fs::remove_dir_all(&path)?;
} else {
fs::remove_file(&path)?;
}
if force {
println!("removing all contents of '{}'", base_path.display());
// instead of directly calling `remove_dir_all`, we keep the original dir but removing
// all contents inside. This helps if the developer symbol links another directory (i.e.,
// S3 local SSD) to the `.neon` base directory.
for entry in std::fs::read_dir(base_path)? {
let entry = entry?;
let path = entry.path();
if path.is_dir() {
fs::remove_dir_all(&path)?;
} else {
fs::remove_file(&path)?;
}
}
} else {
bail!(
"directory '{}' already exists. Perhaps already initialized? (Hint: use --force to remove all contents)",
base_path.display()
);
}
}

View File

@@ -6,26 +6,28 @@
//!
use std::borrow::Cow;
use std::collections::HashMap;
use std::io;
use std::io::Write;
use std::fs::File;
use std::io::{BufReader, Write};
use std::num::NonZeroU64;
use std::path::PathBuf;
use std::process::{Child, Command};
use std::time::Duration;
use std::{io, result};
use anyhow::{bail, Context};
use camino::Utf8PathBuf;
use futures::SinkExt;
use pageserver_api::models::{
self, LocationConfig, ShardParameters, TenantHistorySize, TenantInfo, TimelineInfo,
self, LocationConfig, TenantInfo, TenantLocationConfigRequest, TimelineInfo,
};
use pageserver_api::shard::TenantShardId;
use pageserver_client::mgmt_api;
use postgres_backend::AuthType;
use postgres_connection::{parse_host_port, PgConnectionConfig};
use reqwest::blocking::{Client, RequestBuilder, Response};
use reqwest::{IntoUrl, Method};
use thiserror::Error;
use utils::auth::{Claims, Scope};
use utils::{
http::error::HttpErrorBody,
id::{TenantId, TimelineId},
lsn::Lsn,
};
@@ -36,6 +38,45 @@ use crate::{background_process, local_env::LocalEnv};
/// Directory within .neon which will be used by default for LocalFs remote storage.
pub const PAGESERVER_REMOTE_STORAGE_DIR: &str = "local_fs_remote_storage/pageserver";
#[derive(Error, Debug)]
pub enum PageserverHttpError {
#[error("Reqwest error: {0}")]
Transport(#[from] reqwest::Error),
#[error("Error: {0}")]
Response(String),
}
impl From<anyhow::Error> for PageserverHttpError {
fn from(e: anyhow::Error) -> Self {
Self::Response(e.to_string())
}
}
type Result<T> = result::Result<T, PageserverHttpError>;
pub trait ResponseErrorMessageExt: Sized {
fn error_from_body(self) -> Result<Self>;
}
impl ResponseErrorMessageExt for Response {
fn error_from_body(self) -> Result<Self> {
let status = self.status();
if !(status.is_client_error() || status.is_server_error()) {
return Ok(self);
}
// reqwest does not export its error construction utility functions, so let's craft the message ourselves
let url = self.url().to_owned();
Err(PageserverHttpError::Response(
match self.json::<HttpErrorBody>() {
Ok(err_body) => format!("Error: {}", err_body.msg),
Err(_) => format!("Http error ({}) at {}.", status.as_u16(), url),
},
))
}
}
//
// Control routines for pageserver.
//
@@ -46,7 +87,8 @@ pub struct PageServerNode {
pub pg_connection_config: PgConnectionConfig,
pub conf: PageServerConf,
pub env: LocalEnv,
pub http_client: mgmt_api::Client,
pub http_client: Client,
pub http_base_url: String,
}
impl PageServerNode {
@@ -58,19 +100,8 @@ impl PageServerNode {
pg_connection_config: PgConnectionConfig::new_host_port(host, port),
conf: conf.clone(),
env: env.clone(),
http_client: mgmt_api::Client::new(
format!("http://{}", conf.listen_http_addr),
{
match conf.http_auth_type {
AuthType::Trust => None,
AuthType::NeonJWT => Some(
env.generate_auth_token(&Claims::new(None, Scope::PageServerApi))
.unwrap(),
),
}
}
.as_deref(),
),
http_client: Client::new(),
http_base_url: format!("http://{}/v1", conf.listen_http_addr),
}
}
@@ -108,16 +139,6 @@ impl PageServerNode {
"control_plane_api='{}'",
control_plane_api.as_str()
));
// Attachment service uses the same auth as pageserver: if JWT is enabled
// for us, we will also need it to talk to them.
if matches!(self.conf.http_auth_type, AuthType::NeonJWT) {
let jwt_token = self
.env
.generate_auth_token(&Claims::new(None, Scope::PageServerApi))
.unwrap();
overrides.push(format!("control_plane_api_token='{}'", jwt_token));
}
}
if !cli_overrides
@@ -161,8 +182,8 @@ impl PageServerNode {
.expect("non-Unicode path")
}
pub async fn start(&self, config_overrides: &[&str]) -> anyhow::Result<Child> {
self.start_node(config_overrides, false).await
pub fn start(&self, config_overrides: &[&str]) -> anyhow::Result<Child> {
self.start_node(config_overrides, false)
}
fn pageserver_init(&self, config_overrides: &[&str]) -> anyhow::Result<()> {
@@ -203,12 +224,7 @@ impl PageServerNode {
Ok(())
}
async fn start_node(
&self,
config_overrides: &[&str],
update_config: bool,
) -> anyhow::Result<Child> {
// TODO: using a thread here because start_process() is not async but we need to call check_status()
fn start_node(&self, config_overrides: &[&str], update_config: bool) -> anyhow::Result<Child> {
let datadir = self.repo_path();
print!(
"Starting pageserver node {} at '{}' in {:?}",
@@ -216,7 +232,7 @@ impl PageServerNode {
self.pg_connection_config.raw_address(),
datadir
);
io::stdout().flush().context("flush stdout")?;
io::stdout().flush()?;
let datadir_path_str = datadir.to_str().with_context(|| {
format!(
@@ -228,23 +244,20 @@ impl PageServerNode {
if update_config {
args.push(Cow::Borrowed("--update-config"));
}
background_process::start_process(
"pageserver",
&datadir,
&self.env.pageserver_bin(),
args.iter().map(Cow::as_ref),
self.pageserver_env_variables()?,
background_process::InitialPidFile::Expect(self.pid_file()),
|| async {
let st = self.check_status().await;
match st {
Ok(()) => Ok(true),
Err(mgmt_api::Error::ReceiveBody(_)) => Ok(false),
Err(e) => Err(anyhow::anyhow!("Failed to check node status: {e}")),
}
background_process::InitialPidFile::Expect(&self.pid_file()),
|| match self.check_status() {
Ok(()) => Ok(true),
Err(PageserverHttpError::Transport(_)) => Ok(false),
Err(e) => Err(anyhow::anyhow!("Failed to check node status: {e}")),
},
)
.await
}
fn pageserver_basic_args<'a>(
@@ -290,12 +303,7 @@ impl PageServerNode {
background_process::stop_process(immediate, "pageserver", &self.pid_file())
}
pub async fn page_server_psql_client(
&self,
) -> anyhow::Result<(
tokio_postgres::Client,
tokio_postgres::Connection<tokio_postgres::Socket, tokio_postgres::tls::NoTlsStream>,
)> {
pub fn page_server_psql_client(&self) -> anyhow::Result<postgres::Client> {
let mut config = self.pg_connection_config.clone();
if self.conf.pg_auth_type == AuthType::NeonJWT {
let token = self
@@ -303,18 +311,44 @@ impl PageServerNode {
.generate_auth_token(&Claims::new(None, Scope::PageServerApi))?;
config = config.set_password(Some(token));
}
Ok(config.connect_no_tls().await?)
Ok(config.connect_no_tls()?)
}
pub async fn check_status(&self) -> mgmt_api::Result<()> {
self.http_client.status().await
fn http_request<U: IntoUrl>(&self, method: Method, url: U) -> anyhow::Result<RequestBuilder> {
let mut builder = self.http_client.request(method, url);
if self.conf.http_auth_type == AuthType::NeonJWT {
let token = self
.env
.generate_auth_token(&Claims::new(None, Scope::PageServerApi))?;
builder = builder.bearer_auth(token)
}
Ok(builder)
}
pub async fn tenant_list(&self) -> mgmt_api::Result<Vec<TenantInfo>> {
self.http_client.list_tenants().await
pub fn check_status(&self) -> Result<()> {
self.http_request(Method::GET, format!("{}/status", self.http_base_url))?
.send()?
.error_from_body()?;
Ok(())
}
pub fn parse_config(mut settings: HashMap<&str, &str>) -> anyhow::Result<models::TenantConfig> {
let result = models::TenantConfig {
pub fn tenant_list(&self) -> Result<Vec<TenantInfo>> {
Ok(self
.http_request(Method::GET, format!("{}/tenant", self.http_base_url))?
.send()?
.error_from_body()?
.json()?)
}
pub fn tenant_create(
&self,
new_tenant_id: TenantId,
generation: Option<u32>,
settings: HashMap<&str, &str>,
) -> anyhow::Result<TenantId> {
let mut settings = settings.clone();
let config = models::TenantConfig {
checkpoint_distance: settings
.remove("checkpoint_distance")
.map(|x| x.parse::<u64>())
@@ -375,34 +409,32 @@ impl PageServerNode {
.context("Failed to parse 'gc_feedback' as bool")?,
heatmap_period: settings.remove("heatmap_period").map(|x| x.to_string()),
};
if !settings.is_empty() {
bail!("Unrecognized tenant settings: {settings:?}")
} else {
Ok(result)
}
}
pub async fn tenant_create(
&self,
new_tenant_id: TenantId,
generation: Option<u32>,
settings: HashMap<&str, &str>,
) -> anyhow::Result<TenantId> {
let config = Self::parse_config(settings.clone())?;
let request = models::TenantCreateRequest {
new_tenant_id: TenantShardId::unsharded(new_tenant_id),
generation,
config,
shard_parameters: ShardParameters::default(),
};
if !settings.is_empty() {
bail!("Unrecognized tenant settings: {settings:?}")
}
Ok(self.http_client.tenant_create(&request).await?)
self.http_request(Method::POST, format!("{}/tenant", self.http_base_url))?
.json(&request)
.send()?
.error_from_body()?
.json::<Option<String>>()
.with_context(|| {
format!("Failed to parse tenant creation response for tenant id: {new_tenant_id:?}")
})?
.context("No tenant id was found in the tenant creation response")
.and_then(|tenant_id_string| {
tenant_id_string.parse().with_context(|| {
format!("Failed to parse response string as tenant id: '{tenant_id_string}'")
})
})
}
pub async fn tenant_config(
pub fn tenant_config(
&self,
tenant_id: TenantId,
mut settings: HashMap<&str, &str>,
@@ -481,59 +513,87 @@ impl PageServerNode {
bail!("Unrecognized tenant settings: {settings:?}")
}
self.http_client
.tenant_config(&models::TenantConfigRequest { tenant_id, config })
.await?;
self.http_request(Method::PUT, format!("{}/tenant/config", self.http_base_url))?
.json(&models::TenantConfigRequest { tenant_id, config })
.send()?
.error_from_body()?;
Ok(())
}
pub async fn location_config(
pub fn location_config(
&self,
tenant_shard_id: TenantShardId,
tenant_id: TenantId,
config: LocationConfig,
flush_ms: Option<Duration>,
) -> anyhow::Result<()> {
Ok(self
.http_client
.location_config(tenant_shard_id, config, flush_ms)
.await?)
let req_body = TenantLocationConfigRequest { tenant_id, config };
let path = format!(
"{}/tenant/{}/location_config",
self.http_base_url, tenant_id
);
let path = if let Some(flush_ms) = flush_ms {
format!("{}?flush_ms={}", path, flush_ms.as_millis())
} else {
path
};
self.http_request(Method::PUT, path)?
.json(&req_body)
.send()?
.error_from_body()?;
Ok(())
}
pub async fn timeline_list(
pub fn timeline_list(&self, tenant_id: &TenantId) -> anyhow::Result<Vec<TimelineInfo>> {
let timeline_infos: Vec<TimelineInfo> = self
.http_request(
Method::GET,
format!("{}/tenant/{}/timeline", self.http_base_url, tenant_id),
)?
.send()?
.error_from_body()?
.json()?;
Ok(timeline_infos)
}
pub fn timeline_create(
&self,
tenant_shard_id: &TenantShardId,
) -> anyhow::Result<Vec<TimelineInfo>> {
Ok(self.http_client.list_timelines(*tenant_shard_id).await?)
}
pub async fn tenant_secondary_download(&self, tenant_id: &TenantShardId) -> anyhow::Result<()> {
Ok(self
.http_client
.tenant_secondary_download(*tenant_id)
.await?)
}
pub async fn timeline_create(
&self,
tenant_shard_id: TenantShardId,
new_timeline_id: TimelineId,
tenant_id: TenantId,
new_timeline_id: Option<TimelineId>,
ancestor_start_lsn: Option<Lsn>,
ancestor_timeline_id: Option<TimelineId>,
pg_version: Option<u32>,
existing_initdb_timeline_id: Option<TimelineId>,
) -> anyhow::Result<TimelineInfo> {
let req = models::TimelineCreateRequest {
// If timeline ID was not specified, generate one
let new_timeline_id = new_timeline_id.unwrap_or(TimelineId::generate());
self.http_request(
Method::POST,
format!("{}/tenant/{}/timeline", self.http_base_url, tenant_id),
)?
.json(&models::TimelineCreateRequest {
new_timeline_id,
ancestor_start_lsn,
ancestor_timeline_id,
pg_version,
existing_initdb_timeline_id,
};
Ok(self
.http_client
.timeline_create(tenant_shard_id, &req)
.await?)
})
.send()?
.error_from_body()?
.json::<Option<TimelineInfo>>()
.with_context(|| {
format!("Failed to parse timeline creation response for tenant id: {tenant_id}")
})?
.with_context(|| {
format!(
"No timeline id was found in the timeline creation response for tenant {tenant_id}"
)
})
}
/// Import a basebackup prepared using either:
@@ -545,7 +605,7 @@ impl PageServerNode {
/// * `timeline_id` - id to assign to imported timeline
/// * `base` - (start lsn of basebackup, path to `base.tar` file)
/// * `pg_wal` - if there's any wal to import: (end lsn, path to `pg_wal.tar`)
pub async fn timeline_import(
pub fn timeline_import(
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
@@ -553,72 +613,38 @@ impl PageServerNode {
pg_wal: Option<(Lsn, PathBuf)>,
pg_version: u32,
) -> anyhow::Result<()> {
let (client, conn) = self.page_server_psql_client().await?;
// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
tokio::spawn(async move {
if let Err(e) = conn.await {
eprintln!("connection error: {}", e);
}
});
tokio::pin!(client);
let mut client = self.page_server_psql_client()?;
// Init base reader
let (start_lsn, base_tarfile_path) = base;
let base_tarfile = tokio::fs::File::open(base_tarfile_path).await?;
let base_tarfile = tokio_util::io::ReaderStream::new(base_tarfile);
let base_tarfile = File::open(base_tarfile_path)?;
let mut base_reader = BufReader::new(base_tarfile);
// Init wal reader if necessary
let (end_lsn, wal_reader) = if let Some((end_lsn, wal_tarfile_path)) = pg_wal {
let wal_tarfile = tokio::fs::File::open(wal_tarfile_path).await?;
let wal_reader = tokio_util::io::ReaderStream::new(wal_tarfile);
let wal_tarfile = File::open(wal_tarfile_path)?;
let wal_reader = BufReader::new(wal_tarfile);
(end_lsn, Some(wal_reader))
} else {
(start_lsn, None)
};
let copy_in = |reader, cmd| {
let client = &client;
async move {
let writer = client.copy_in(&cmd).await?;
let writer = std::pin::pin!(writer);
let mut writer = writer.sink_map_err(|e| {
std::io::Error::new(std::io::ErrorKind::Other, format!("{e}"))
});
let mut reader = std::pin::pin!(reader);
writer.send_all(&mut reader).await?;
writer.into_inner().finish().await?;
anyhow::Ok(())
}
};
// Import base
copy_in(
base_tarfile,
format!(
"import basebackup {tenant_id} {timeline_id} {start_lsn} {end_lsn} {pg_version}"
),
)
.await?;
let import_cmd = format!(
"import basebackup {tenant_id} {timeline_id} {start_lsn} {end_lsn} {pg_version}"
);
let mut writer = client.copy_in(&import_cmd)?;
io::copy(&mut base_reader, &mut writer)?;
writer.finish()?;
// Import wal if necessary
if let Some(wal_reader) = wal_reader {
copy_in(
wal_reader,
format!("import wal {tenant_id} {timeline_id} {start_lsn} {end_lsn}"),
)
.await?;
if let Some(mut wal_reader) = wal_reader {
let import_cmd = format!("import wal {tenant_id} {timeline_id} {start_lsn} {end_lsn}");
let mut writer = client.copy_in(&import_cmd)?;
io::copy(&mut wal_reader, &mut writer)?;
writer.finish()?;
}
Ok(())
}
pub async fn tenant_synthetic_size(
&self,
tenant_shard_id: TenantShardId,
) -> anyhow::Result<TenantHistorySize> {
Ok(self
.http_client
.tenant_synthetic_size(tenant_shard_id)
.await?)
}
}

View File

@@ -13,6 +13,7 @@ use std::{io, result};
use anyhow::Context;
use camino::Utf8PathBuf;
use postgres_connection::PgConnectionConfig;
use reqwest::blocking::{Client, RequestBuilder, Response};
use reqwest::{IntoUrl, Method};
use thiserror::Error;
use utils::{http::error::HttpErrorBody, id::NodeId};
@@ -33,14 +34,12 @@ pub enum SafekeeperHttpError {
type Result<T> = result::Result<T, SafekeeperHttpError>;
#[async_trait::async_trait]
pub trait ResponseErrorMessageExt: Sized {
async fn error_from_body(self) -> Result<Self>;
fn error_from_body(self) -> Result<Self>;
}
#[async_trait::async_trait]
impl ResponseErrorMessageExt for reqwest::Response {
async fn error_from_body(self) -> Result<Self> {
impl ResponseErrorMessageExt for Response {
fn error_from_body(self) -> Result<Self> {
let status = self.status();
if !(status.is_client_error() || status.is_server_error()) {
return Ok(self);
@@ -49,7 +48,7 @@ impl ResponseErrorMessageExt for reqwest::Response {
// reqwest does not export its error construction utility functions, so let's craft the message ourselves
let url = self.url().to_owned();
Err(SafekeeperHttpError::Response(
match self.json::<HttpErrorBody>().await {
match self.json::<HttpErrorBody>() {
Ok(err_body) => format!("Error: {}", err_body.msg),
Err(_) => format!("Http error ({}) at {}.", status.as_u16(), url),
},
@@ -70,7 +69,7 @@ pub struct SafekeeperNode {
pub pg_connection_config: PgConnectionConfig,
pub env: LocalEnv,
pub http_client: reqwest::Client,
pub http_client: Client,
pub http_base_url: String,
}
@@ -81,7 +80,7 @@ impl SafekeeperNode {
conf: conf.clone(),
pg_connection_config: Self::safekeeper_connection_config(conf.pg_port),
env: env.clone(),
http_client: reqwest::Client::new(),
http_client: Client::new(),
http_base_url: format!("http://127.0.0.1:{}/v1", conf.http_port),
}
}
@@ -104,7 +103,7 @@ impl SafekeeperNode {
.expect("non-Unicode path")
}
pub async fn start(&self, extra_opts: Vec<String>) -> anyhow::Result<Child> {
pub fn start(&self, extra_opts: Vec<String>) -> anyhow::Result<Child> {
print!(
"Starting safekeeper at '{}' in '{}'",
self.pg_connection_config.raw_address(),
@@ -192,16 +191,13 @@ impl SafekeeperNode {
&self.env.safekeeper_bin(),
&args,
[],
background_process::InitialPidFile::Expect(self.pid_file()),
|| async {
match self.check_status().await {
Ok(()) => Ok(true),
Err(SafekeeperHttpError::Transport(_)) => Ok(false),
Err(e) => Err(anyhow::anyhow!("Failed to check node status: {e}")),
}
background_process::InitialPidFile::Expect(&self.pid_file()),
|| match self.check_status() {
Ok(()) => Ok(true),
Err(SafekeeperHttpError::Transport(_)) => Ok(false),
Err(e) => Err(anyhow::anyhow!("Failed to check node status: {e}")),
},
)
.await
}
///
@@ -220,7 +216,7 @@ impl SafekeeperNode {
)
}
fn http_request<U: IntoUrl>(&self, method: Method, url: U) -> reqwest::RequestBuilder {
fn http_request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
// TODO: authentication
//if self.env.auth_type == AuthType::NeonJWT {
// builder = builder.bearer_auth(&self.env.safekeeper_auth_token)
@@ -228,12 +224,10 @@ impl SafekeeperNode {
self.http_client.request(method, url)
}
pub async fn check_status(&self) -> Result<()> {
pub fn check_status(&self) -> Result<()> {
self.http_request(Method::GET, format!("{}/{}", self.http_base_url, "status"))
.send()
.await?
.error_from_body()
.await?;
.send()?
.error_from_body()?;
Ok(())
}
}

View File

@@ -0,0 +1,197 @@
//!
//! Functionality for migrating tenants across pageservers: unlike most of neon_local, this code
//! isn't scoped to a particular physical service, as it needs to update compute endpoints to
//! point to the new pageserver.
//!
use crate::local_env::LocalEnv;
use crate::{
attachment_service::AttachmentService, endpoint::ComputeControlPlane,
pageserver::PageServerNode,
};
use pageserver_api::models::{
LocationConfig, LocationConfigMode, LocationConfigSecondary, TenantConfig,
};
use std::collections::HashMap;
use std::time::Duration;
use utils::{
id::{TenantId, TimelineId},
lsn::Lsn,
};
/// Given an attached pageserver, retrieve the LSN for all timelines
fn get_lsns(
tenant_id: TenantId,
pageserver: &PageServerNode,
) -> anyhow::Result<HashMap<TimelineId, Lsn>> {
let timelines = pageserver.timeline_list(&tenant_id)?;
Ok(timelines
.into_iter()
.map(|t| (t.timeline_id, t.last_record_lsn))
.collect())
}
/// Wait for the timeline LSNs on `pageserver` to catch up with or overtake
/// `baseline`.
fn await_lsn(
tenant_id: TenantId,
pageserver: &PageServerNode,
baseline: HashMap<TimelineId, Lsn>,
) -> anyhow::Result<()> {
loop {
let latest = match get_lsns(tenant_id, pageserver) {
Ok(l) => l,
Err(e) => {
println!(
"🕑 Can't get LSNs on pageserver {} yet, waiting ({e})",
pageserver.conf.id
);
std::thread::sleep(Duration::from_millis(500));
continue;
}
};
let mut any_behind: bool = false;
for (timeline_id, baseline_lsn) in &baseline {
match latest.get(timeline_id) {
Some(latest_lsn) => {
println!("🕑 LSN origin {baseline_lsn} vs destination {latest_lsn}");
if latest_lsn < baseline_lsn {
any_behind = true;
}
}
None => {
// Expected timeline isn't yet visible on migration destination.
// (IRL we would have to account for timeline deletion, but this
// is just test helper)
any_behind = true;
}
}
}
if !any_behind {
println!("✅ LSN caught up. Proceeding...");
break;
} else {
std::thread::sleep(Duration::from_millis(500));
}
}
Ok(())
}
/// This function spans multiple services, to demonstrate live migration of a tenant
/// between pageservers:
/// - Coordinate attach/secondary/detach on pageservers
/// - call into attachment_service for generations
/// - reconfigure compute endpoints to point to new attached pageserver
pub fn migrate_tenant(
env: &LocalEnv,
tenant_id: TenantId,
dest_ps: PageServerNode,
) -> anyhow::Result<()> {
// Get a new generation
let attachment_service = AttachmentService::from_env(env);
fn build_location_config(
mode: LocationConfigMode,
generation: Option<u32>,
secondary_conf: Option<LocationConfigSecondary>,
) -> LocationConfig {
LocationConfig {
mode,
generation,
secondary_conf,
tenant_conf: TenantConfig::default(),
shard_number: 0,
shard_count: 0,
shard_stripe_size: 0,
}
}
let previous = attachment_service.inspect(tenant_id)?;
let mut baseline_lsns = None;
if let Some((generation, origin_ps_id)) = &previous {
let origin_ps = PageServerNode::from_env(env, env.get_pageserver_conf(*origin_ps_id)?);
if origin_ps_id == &dest_ps.conf.id {
println!("🔁 Already attached to {origin_ps_id}, freshening...");
let gen = attachment_service.attach_hook(tenant_id, dest_ps.conf.id)?;
let dest_conf = build_location_config(LocationConfigMode::AttachedSingle, gen, None);
dest_ps.location_config(tenant_id, dest_conf, None)?;
println!("✅ Migration complete");
return Ok(());
}
println!("🔁 Switching origin pageserver {origin_ps_id} to stale mode");
let stale_conf =
build_location_config(LocationConfigMode::AttachedStale, Some(*generation), None);
origin_ps.location_config(tenant_id, stale_conf, Some(Duration::from_secs(10)))?;
baseline_lsns = Some(get_lsns(tenant_id, &origin_ps)?);
}
let gen = attachment_service.attach_hook(tenant_id, dest_ps.conf.id)?;
let dest_conf = build_location_config(LocationConfigMode::AttachedMulti, gen, None);
println!("🔁 Attaching to pageserver {}", dest_ps.conf.id);
dest_ps.location_config(tenant_id, dest_conf, None)?;
if let Some(baseline) = baseline_lsns {
println!("🕑 Waiting for LSN to catch up...");
await_lsn(tenant_id, &dest_ps, baseline)?;
}
let cplane = ComputeControlPlane::load(env.clone())?;
for (endpoint_name, endpoint) in &cplane.endpoints {
if endpoint.tenant_id == tenant_id {
println!(
"🔁 Reconfiguring endpoint {} to use pageserver {}",
endpoint_name, dest_ps.conf.id
);
endpoint.reconfigure(Some(dest_ps.conf.id))?;
}
}
for other_ps_conf in &env.pageservers {
if other_ps_conf.id == dest_ps.conf.id {
continue;
}
let other_ps = PageServerNode::from_env(env, other_ps_conf);
let other_ps_tenants = other_ps.tenant_list()?;
// Check if this tenant is attached
let found = other_ps_tenants
.into_iter()
.map(|t| t.id)
.any(|i| i.tenant_id == tenant_id);
if !found {
continue;
}
// Downgrade to a secondary location
let secondary_conf = build_location_config(
LocationConfigMode::Secondary,
None,
Some(LocationConfigSecondary { warm: true }),
);
println!(
"💤 Switching to secondary mode on pageserver {}",
other_ps.conf.id
);
other_ps.location_config(tenant_id, secondary_conf, None)?;
}
println!(
"🔁 Switching to AttachedSingle mode on pageserver {}",
dest_ps.conf.id
);
let dest_conf = build_location_config(LocationConfigMode::AttachedSingle, gen, None);
dest_ps.location_config(tenant_id, dest_conf, None)?;
println!("✅ Migration complete");
Ok(())
}

View File

@@ -35,7 +35,6 @@ allow = [
"Artistic-2.0",
"BSD-2-Clause",
"BSD-3-Clause",
"CC0-1.0",
"ISC",
"MIT",
"MPL-2.0",

View File

@@ -1,197 +0,0 @@
# Per-Tenant GetPage@LSN Throttling
Author: Christian Schwarz
Date: Oct 24, 2023
## Summary
This RFC proposes per-tenant throttling of GetPage@LSN requests inside Pageserver
and the interactions with its client, i.e., the neon_smgr component in Compute.
The result of implementing & executing this RFC will be a fleet-wide upper limit for
**"the highest GetPage/second that Pageserver can support for a single tenant/shard"**.
## Background
### GetPage@LSN Request Flow
Pageserver exposes its `page_service.rs` as a libpq listener.
The Computes' `neon_smgr` module connects to that libpq listener.
Once a connection is established, the protocol allows Compute to request page images at a given LSN.
We call these requests GetPage@LSN requests, or GetPage requests for short.
Other request types can be sent, but these are low traffic compared to GetPage requests
and are not the concern of this RFC.
Pageserver associates one libpq connection with one tokio task.
Per connection/task, the pq protocol is handled by the common `postgres_backend` crate.
Its `run_message_loop` function invokes the `page_service` specific `impl<IO> postgres_backend::Handler<IO> for PageServerHandler`.
Requests are processed in the order in which they arrive via the TCP-based pq protocol.
So, there is no concurrent request processing within one connection/task.
There is a degree of natural pipelining:
Compute can "fill the pipe" by sending more than one GetPage request into the libpq TCP stream.
And Pageserver can fill the pipe with responses in the other direction.
Both directions are subject to the limit of tx/rx buffers, nodelay, TCP flow control, etc.
### GetPage@LSN Access Pattern
The Compute has its own hierarchy of caches, specifically `shared_buffers` and the `local file cache` (LFC).
Compute only issues GetPage requests to Pageserver if it encounters a miss in these caches.
If the working set stops fitting into Compute's caches, requests to Pageserver increase sharply -- the Compute starts *thrashing*.
## Motivation
In INC-69, a tenant issued 155k GetPage/second for a period of 10 minutes and 60k GetPage/second for a period of 3h,
then dropping to ca 18k GetPage/second for a period of 9h.
We noticed this because of an internal GetPage latency SLO burn rate alert, i.e.,
the request latency profile during this period significantly exceeded what was acceptable according to the internal SLO.
Sadly, we do not have the observability data to determine the impact of this tenant on other tenants on the same tenants.
However, here are some illustrative data points for the 155k period:
The tenant was responsible for >= 99% of the GetPage traffic and, frankly, the overall activity on this Pageserver instance.
We were serving pages at 10 Gb/s (`155k x 8 kbyte (PAGE_SZ) per second is 1.12GiB/s = 9.4Gb/s.`)
The CPU utilization of the instance was 75% user+system.
Pageserver page cache served 1.75M accesses/second at a hit rate of ca 90%.
The hit rate for materialized pages was ca. 40%.
Curiously, IOPS to the Instance Store NVMe were very low, rarely exceeding 100.
The fact that the IOPS were so low / the materialized page cache hit rate was so high suggests that **this tenant's compute's caches were thrashing**.
The compute was of type `k8s-pod`; hence, auto-scaling could/would not have helped remediate the thrashing by provisioning more RAM.
The consequence was that the **thrashing translated into excessive GetPage requests against Pageserver**.
My claim is that it was **unhealthy to serve this workload at the pace we did**:
* it is likely that other tenants were/would have experienced high latencies (again, we sadly don't have per-tenant latency data to confirm this)
* more importantly, it was **unsustainable** to serve traffic at this pace for multiple reasons:
* **predictability of performance**: when the working set grows, the pageserver materialized page cache hit rate drops.
At some point, we're bound by the EC2 Instance Store NVMe drive's IOPS limit.
The result is an **uneven** performance profile from the Compute perspective.
* **economics**: Neon currently does not charge for IOPS, only capacity.
**We cannot afford to undercut the market in IOPS/$ this drastically; it leads to adverse selection and perverse incentives.**
For example, the 155k IOPS, which we served for 10min, would cost ca. 6.5k$/month when provisioned as an io2 EBS volume.
Even the 18k IOPS, which we served for 9h, would cost ca. 1.1k$/month when provisioned as an io2 EBS volume.
We charge 0$.
It could be economically advantageous to keep using a low-DRAM compute because Pageserver IOPS are fast enough and free.
Note: It is helpful to think of Pageserver as a disk, because it's precisely where `neon_smgr` sits:
vanilla Postgres gets its pages from disk, Neon Postgres gets them from Pageserver.
So, regarding the above performance & economic arguments, it is fair to say that we currently provide an "as-fast-as-possible-IOPS" disk that we charge for only by capacity.
## Solution: Throttling GetPage Requests
**The consequence of the above analysis must be that Pageserver throttles GetPage@LSN requests**.
That is, unless we want to start charging for provisioned GetPage@LSN/second.
Throttling sets the correct incentive for a thrashing Compute to scale up its DRAM to the working set size.
Neon Autoscaling will make this easy, [eventually](https://github.com/neondatabase/neon/pull/3913).
## The Design Space
What that remains is the question about *policy* and *mechanism*:
**Policy** concerns itself with the question of what limit applies to a given connection|timeline|tenant.
Candidates are:
* hard limit, same limit value per connection|timeline|tenant
* Per-tenant will provide an upper bound for the impact of a tenant on a given Pageserver instance.
This is a major operational pain point / risk right now.
* hard limit, configurable per connection|timeline|tenant
* This outsources policy to console/control plane, with obvious advantages for flexible structuring of what service we offer to customers.
* Note that this is not a mechanism to guarantee a minium provisioned rate, i.e., this is not a mechanism to guarantee a certain QoS for a tenant.
* fair share among active connections|timelines|tenants per instance
* example: each connection|timeline|tenant gets a fair fraction of the machine's GetPage/second capacity
* NB: needs definition of "active", and knowledge of available GetPage/second capacity in advance
* ...
Regarding **mechanism**, it's clear that **backpressure** is the way to go.
However, we must choose between
* **implicit** backpressure through pq/TCP and
* **explicit** rejection of requests + retries with exponential backoff
Further, there is the question of how throttling GetPage@LSN will affect the **internal GetPage latency SLO**:
where do we measure the SLI for Pageserver's internal getpage latency SLO? Before or after the throttling?
And when we eventually move the measurement point into the Computes (to avoid coordinated omission),
how do we avoid counting throttling-induced latency toward the internal getpage latency SLI/SLO?
## Scope Of This RFC
**This RFC proposes introducing a hard GetPage@LSN/second limit per tenant, with the same value applying to each tenant on a Pageserver**.
This proposal is easy to implement and significantly de-risks operating large Pageservers,
based on the assumption that extremely-high-GetPage-rate-episodes like the one from the "Motivation" section are uncorrelated between tenants.
For example, suppose we pick a limit that allows up to 10 tenants to go at limit rate.
Suppose our Pageserver can serve 100k GetPage/second total at a 100% page cache miss rate.
If each tenant gets a hard limit of 10k GetPage/second, we can serve up to 10 tenants at limit speed without latency degradation.
The mechanism for backpressure will be TCP-based implicit backpressure.
The compute team isn't concerned about prefetch queue depth.
Pageserver will implement it by delaying the reading of requests from the libpq connection(s).
The rate limit will be implemented using a per-tenant token bucket.
The bucket will be be shared among all connections to the tenant.
The bucket implementation supports starvation-preventing `await`ing.
The current candidate for the implementation is [`leaky_bucket`](https://docs.rs/leaky-bucket/).
The getpage@lsn benchmark that's being added in https://github.com/neondatabase/neon/issues/5771
can be used to evaluate the overhead of sharing the bucket among connections of a tenant.
A possible technique to mitigate the impact of sharing the bucket would be to maintain a buffer of a few tokens per connection handler.
Regarding metrics / the internal GetPage latency SLO:
we will measure the GetPage latency SLO _after_ the throttler and introduce a new metric to measure the amount of throttling, quantified by:
- histogram that records the tenants' observations of queue depth before they start waiting (one such histogram per pageserver)
- histogram that records the tenants' observations of time spent waiting (one such histogram per pageserver)
Further observability measures:
- an INFO log message at frequency 1/min if the tenant/timeline/connection was throttled in that last minute.
The message will identify the tenant/timeline/connection to allow correlation with compute logs/stats.
Rollout will happen as follows:
- deploy 1: implementation + config: disabled by default, ability to enable it per tenant through tenant_conf
- experimentation in staging and later production to study impact & interaction with auto-scaling
- determination of a sensible global default value
- the value will be chosen as high as possible ...
- ... but low enough to work towards this RFC's goal that one tenant should not be able to dominate a pageserver instance.
- deploy 2: implementation fixes if any + config: enabled by default with the aforementioned global default
- reset of the experimental per-tenant overrides
- gain experience & lower the limit over time
- we stop lowering the limit as soon as this RFC's goal is achieved, i.e.,
once we decide that in practice the chosen value sufficiently de-risks operating large pageservers
The per-tenant override will remain for emergencies and testing.
But since Console doesn't preserve it during tenant migrations, it isn't durably configurable for the tenant.
Toward the upper layers of the Neon stack, the resulting limit will be
**"the highest GetPage/second that Pageserver can support for a single tenant"**.
### Rationale
We decided against error + retry because of worries about starvation.
## Future Work
Enable per-tenant emergency override of the limit via Console.
Should be part of a more general framework to specify tenant config overrides.
**NB:** this is **not** the right mechanism to _sell_ different max GetPage/second levels to users,
or _auto-scale_ the GetPage/second levels. Such functionality will require a separate RFC that
concerns itself with GetPage/second capacity planning.
Compute-side metrics for GetPage latency.
Back-channel to inform Compute/Autoscaling/ControlPlane that the project is being throttled.
Compute-side neon_smgr improvements to avoid sending the same GetPage request multiple times if multiple backends experience a cache miss.
Dealing with read-only endpoints: users use read-only endpoints to scale reads for a single tenant.
Possibly there are also assumptions around read-only endpoints not affecting the primary read-write endpoint's performance.
With per-tenant rate limiting, we will not meet that expectation.
However, we can currently only scale per tenant.
Soon, we will have sharding (#5505), which will apply the throttling on a per-shard basis.
But, that's orthogonal to scaling reads: if many endpoints hit one shard, they share the same throttling limit.
To solve this properly, I think we'll need replicas for tenants / shard.
To performance-isolate a tenant's endpoints from each other, we'd then route them to different replicas.

View File

@@ -1,142 +0,0 @@
# Vectored Timeline Get
Created on: 2024-01-02
Author: Christian Schwarz
# Summary
A brief RFC / GitHub Epic describing a vectored version of the `Timeline::get` method that is at the heart of Pageserver.
# Motivation
During basebackup, we issue many `Timeline::get` calls for SLRU pages that are *adjacent* in key space.
For an example, see
https://github.com/neondatabase/neon/blob/5c88213eaf1b1e29c610a078d0b380f69ed49a7e/pageserver/src/basebackup.rs#L281-L302.
Each of these `Timeline::get` calls must traverse the layer map to gather reconstruct data (`Timeline::get_reconstruct_data`) for the requested page number (`blknum` in the example).
For each layer visited by layer map traversal, we do a `DiskBtree` point lookup.
If it's negative (no entry), we resume layer map traversal.
If it's positive, we collect the result in our reconstruct data bag.
If the reconstruct data bag contents suffice to reconstruct the page, we're done with `get_reconstruct_data` and move on to walredo.
Otherwise, we resume layer map traversal.
Doing this many `Timeline::get` calls is quite inefficient because:
1. We do the layer map traversal repeatedly, even if, e.g., all the data sits in the same image layer at the bottom of the stack.
2. We may visit many DiskBtree inner pages multiple times for point lookup of different keys.
This is likely particularly bad for L0s which span the whole key space and hence must be visited by layer map traversal, but
may not contain the data we're looking for.
3. Anecdotally, keys adjacent in keyspace and written simultaneously also end up physically adjacent in the layer files [^1].
So, to provide the reconstruct data for N adjacent keys, we would actually only _need_ to issue a single large read to the filesystem, instead of the N reads we currently do.
The filesystem, in turn, ideally stores the layer file physically contiguously, so our large read will turn into one IOP toward the disk.
[^1]: https://www.notion.so/neondatabase/Christian-Investigation-Slow-Basebackups-Early-2023-12-34ea5c7dcdc1485d9ac3731da4d2a6fc?pvs=4#15ee4e143392461fa64590679c8f54c9
# Solution
We should have a vectored aka batched aka scatter-gather style alternative API for `Timeline::get`. Having such an API unlocks:
* more efficient basebackup
* batched IO during compaction (useful for strides of unchanged pages)
* page_service: expose vectored get_page_at_lsn for compute (=> good for seqscan / prefetch)
* if [on-demand SLRU downloads](https://github.com/neondatabase/neon/pull/6151) land before vectored Timeline::get, on-demand SLRU downloads will still benefit from this API
# DoD
There is a new variant of `Timeline::get`, called `Timeline::get_vectored`.
It takes as arguments an `lsn: Lsn` and a `src: &[KeyVec]` where `struct KeyVec { base: Key, count: usize }`.
It is up to the implementor to figure out a suitable and efficient way to return the reconstructed page images.
It is sufficient to simply return a `Vec<Bytes>`, but, likely more efficient solutions can be found after studying all the callers of `Timeline::get`.
Functionally, the behavior of `Timeline::get_vectored` is equivalent to
```rust
let mut keys_iter: impl Iterator<Item=Key>
= src.map(|KeyVec{ base, count }| (base..base+count)).flatten();
let mut out = Vec::new();
for key in keys_iter {
let data = Timeline::get(key, lsn)?;
out.push(data);
}
return out;
```
However, unlike above, an ideal solution will
* Visit each `struct Layer` at most once.
* For each visited layer, call `Layer::get_value_reconstruct_data` at most once.
* This means, read each `DiskBtree` page at most once.
* Facilitate merging of the reads we issue to the OS and eventually NVMe.
Each of these items above represents a signficant amount of work.
## Performance
Ideally, the **base performance** of a vectored get of a single page should be identical to the current `Timeline::get`.
A reasonable constant overhead over current `Timeline::get` is acceptable.
The performance improvement for the vectored use case is demonstrated in some way, e.g., using the `pagebench` basebackup benchmark against a tenant with a lot of SLRU segments.
# Implementation
High-level set of tasks / changes to be made:
- **Get clarity on API**:
- Define naive `Timeline::get_vectored` implementation & adopt it across pageserver.
- The tricky thing here will be the return type (e.g. `Vec<Bytes>` vs `impl Stream`).
- Start with something simple to explore the different usages of the API.
Then iterate with peers until we have something that is good enough.
- **Vectored Layer Map traversal**
- Vectored `LayerMap::search` (take 1 LSN and N `Key`s instead of just 1 LSN and 1 `Key`)
- Refactor `Timeline::get_reconstruct_data` to hold & return state for N `Key`s instead of 1
- The slightly tricky part here is what to do about `cont_lsn` [after we've found some reconstruct data for some keys](https://github.com/neondatabase/neon/blob/d066dad84b076daf3781cdf9a692098889d3974e/pageserver/src/tenant/timeline.rs#L2378-L2385)
but need more.
Likely we'll need to keep track of `cont_lsn` per key and continue next iteration at `max(cont_lsn)` of all keys that still need data.
- **Vectored `Layer::get_value_reconstruct_data` / `DiskBtree`**
- Current code calls it [here](https://github.com/neondatabase/neon/blob/d066dad84b076daf3781cdf9a692098889d3974e/pageserver/src/tenant/timeline.rs#L2378-L2384).
- Delta layers use `DiskBtreeReader::visit()` to collect the `(offset,len)` pairs for delta record blobs to load.
- Image layers use `DiskBtreeReader::get` to get the offset of the image blob to load. Underneath, that's just a `::visit()` call.
- What needs to happen to `DiskBtree::visit()`?
* Minimally
* take a single `KeyVec` instead of a single `Key` as argument, i.e., take a single contiguous key range to visit.
* Change the visit code to to invoke the callback for all values in the `KeyVec`'s key range
* This should be good enough for what we've seen when investigating basebackup slowness, because there, the key ranges are contiguous.
* Ideally:
* Take a `&[KeyVec]`, sort it;
* during Btree traversal, peek at the next `KeyVec` range to determine whether we need to descend or back out.
* NB: this should be a straight-forward extension of the minimal solution above, as we'll already be checking for "is there more key range in the requested `KeyVec`".
- **Facilitate merging of the reads we issue to the OS and eventually NVMe.**
- The `DiskBtree::visit` produces a set of offsets which we then read from a `VirtualFile` [here](https://github.com/neondatabase/neon/blob/292281c9dfb24152b728b1a846cc45105dac7fe0/pageserver/src/tenant/storage_layer/delta_layer.rs#L772-L804)
- [Delta layer reads](https://github.com/neondatabase/neon/blob/292281c9dfb24152b728b1a846cc45105dac7fe0/pageserver/src/tenant/storage_layer/delta_layer.rs#L772-L804)
- We hit (and rely) on `PageCache` and `VirtualFile here (not great under pressure)
- [Image layer reads](https://github.com/neondatabase/neon/blob/292281c9dfb24152b728b1a846cc45105dac7fe0/pageserver/src/tenant/storage_layer/image_layer.rs#L429-L435)
- What needs to happen is the **vectorization of the `blob_io` interface and then the `VirtualFile` API**.
- That is tricky because
- the `VirtualFile` API, which sits underneath `blob_io`, is being touched by ongoing [io_uring work](https://github.com/neondatabase/neon/pull/5824)
- there's the question how IO buffers will be managed; currently this area relies heavily on `PageCache`, but there's controversy around the future of `PageCache`.
- The guiding principle here should be to avoid coupling this work to the `PageCache`.
- I.e., treat `PageCache` as an extra hop in the I/O chain, rather than as an integral part of buffer management.
Let's see how we can improve by doing the first three items in above list first, then revisit.
## Rollout / Feature Flags
No feature flags are required for this epic.
At the end of this epic, `Timeline::get` forwards to `Timeline::get_vectored`, i.e., it's an all-or-nothing type of change.
It is encouraged to deliver this feature incrementally, i.e., do many small PRs over multiple weeks.
That will help isolate performance regressions across weekly releases.
# Interaction With Sharding
[Sharding](https://github.com/neondatabase/neon/pull/5432) splits up the key space, see functions `is_key_local` / `key_to_shard_number`.
Just as with `Timeline::get`, callers of `Timeline::get_vectored` are responsible for ensuring that they only ask for blocks of the given `struct Timeline`'s shard.
Given that this is already the case, there shouldn't be significant interaction/interference with sharding.
However, let's have a safety check for this constraint (error or assertion) because there are currently few affordances at the higher layers of Pageserver for sharding<=>keyspace interaction.
For example, `KeySpace` is not broken up by shard stripe, so if someone naively converted the compaction code to issue a vectored get for a keyspace range it would violate this constraint.

View File

@@ -129,13 +129,13 @@ Run `poetry shell` to activate the virtual environment.
Alternatively, use `poetry run` to run a single command in the venv, e.g. `poetry run pytest`.
### Obligatory checks
We force code formatting via `ruff`, and type hints via `mypy`.
We force code formatting via `black`, `ruff`, and type hints via `mypy`.
Run the following commands in the repository's root (next to `pyproject.toml`):
```bash
poetry run ruff format . # All code is reformatted
poetry run ruff check . # Python linter
poetry run mypy . # Ensure there are no typing errors
poetry run black . # All code is reformatted
poetry run ruff . # Python linter
poetry run mypy . # Ensure there are no typing errors
```
**WARNING**: do not run `mypy` from a directory other than the root of the repository.

View File

@@ -73,12 +73,6 @@ pub struct ComputeSpec {
// information about available remote extensions
pub remote_extensions: Option<RemoteExtSpec>,
pub pgbouncer_settings: Option<HashMap<String, String>>,
// Stripe size for pageserver sharding, in pages
#[serde(default)]
pub shard_stripe_size: Option<usize>,
}
/// Feature flag to signal `compute_ctl` to enable certain experimental functionality.
@@ -86,13 +80,10 @@ pub struct ComputeSpec {
#[serde(rename_all = "snake_case")]
pub enum ComputeFeature {
// XXX: Add more feature flags here.
/// Enable the experimental activity monitor logic, which uses `pg_stat_database` to
/// track short-lived connections as user activity.
ActivityMonitorExperimental,
/// This is a special feature flag that is used to represent unknown feature flags.
/// Basically all unknown to enum flags are represented as this one. See unit test
/// `parse_unknown_features()` for more details.
// This is a special feature flag that is used to represent unknown feature flags.
// Basically all unknown to enum flags are represented as this one. See unit test
// `parse_unknown_features()` for more details.
#[serde(other)]
UnknownFeature,
}
@@ -289,23 +280,4 @@ mod tests {
assert!(spec.features.contains(&ComputeFeature::UnknownFeature));
assert_eq!(spec.features, vec![ComputeFeature::UnknownFeature; 2]);
}
#[test]
fn parse_known_features() {
// Test that we can properly parse known feature flags.
let file = File::open("tests/cluster_spec.json").unwrap();
let mut json: serde_json::Value = serde_json::from_reader(file).unwrap();
let ob = json.as_object_mut().unwrap();
// Add known feature flags.
let features = vec!["activity_monitor_experimental"];
ob.insert("features".into(), features.into());
let spec: ComputeSpec = serde_json::from_value(json).unwrap();
assert_eq!(
spec.features,
vec![ComputeFeature::ActivityMonitorExperimental]
);
}
}

View File

@@ -243,9 +243,5 @@
"public_extensions": [
"postgis"
]
},
"pgbouncer_settings": {
"default_pool_size": "42",
"pool_mode": "session"
}
}

View File

@@ -19,10 +19,8 @@ strum.workspace = true
strum_macros.workspace = true
hex.workspace = true
thiserror.workspace = true
humantime-serde.workspace = true
workspace_hack.workspace = true
[dev-dependencies]
bincode.workspace = true
rand.workspace = true

View File

@@ -3,8 +3,6 @@ use byteorder::{ByteOrder, BE};
use serde::{Deserialize, Serialize};
use std::fmt;
use crate::reltag::{BlockNumber, RelTag};
/// Key used in the Repository kv-store.
///
/// The Repository treats this as an opaque struct, but see the code in pgdatadir_mapping.rs
@@ -143,57 +141,6 @@ impl Key {
}
}
#[inline(always)]
pub fn is_rel_block_key(key: &Key) -> bool {
key.field1 == 0x00 && key.field4 != 0 && key.field6 != 0xffffffff
}
/// Guaranteed to return `Ok()` if [[is_rel_block_key]] returns `true` for `key`.
pub fn key_to_rel_block(key: Key) -> anyhow::Result<(RelTag, BlockNumber)> {
Ok(match key.field1 {
0x00 => (
RelTag {
spcnode: key.field2,
dbnode: key.field3,
relnode: key.field4,
forknum: key.field5,
},
key.field6,
),
_ => anyhow::bail!("unexpected value kind 0x{:02x}", key.field1),
})
}
impl std::str::FromStr for Key {
type Err = anyhow::Error;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
Self::from_hex(s)
}
}
#[cfg(test)]
mod tests {
use std::str::FromStr;
use crate::key::Key;
use rand::Rng;
use rand::SeedableRng;
#[test]
fn display_fromstr_bijection() {
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let key = Key {
field1: rng.gen(),
field2: rng.gen(),
field3: rng.gen(),
field4: rng.gen(),
field5: rng.gen(),
field6: rng.gen(),
};
assert_eq!(key, Key::from_str(&format!("{key}")).unwrap());
}
key.field1 == 0x00 && key.field4 != 0
}

View File

@@ -5,7 +5,6 @@ use const_format::formatcp;
/// Public API types
pub mod control_api;
pub mod key;
pub mod keyspace;
pub mod models;
pub mod reltag;
pub mod shard;

View File

@@ -1,10 +1,7 @@
pub mod partitioning;
use std::{
collections::HashMap,
io::{BufRead, Read},
num::{NonZeroU64, NonZeroUsize},
time::{Duration, SystemTime},
time::SystemTime,
};
use byteorder::{BigEndian, ReadBytesExt};
@@ -18,12 +15,9 @@ use utils::{
lsn::Lsn,
};
use crate::{
reltag::RelTag,
shard::{ShardCount, ShardStripeSize, TenantShardId},
};
use crate::{reltag::RelTag, shard::TenantShardId};
use anyhow::bail;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use bytes::{BufMut, Bytes, BytesMut};
/// The state of a tenant in this pageserver.
///
@@ -191,31 +185,6 @@ pub struct TimelineCreateRequest {
pub pg_version: Option<u32>,
}
/// Parameters that apply to all shards in a tenant. Used during tenant creation.
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct ShardParameters {
pub count: ShardCount,
pub stripe_size: ShardStripeSize,
}
impl ShardParameters {
pub const DEFAULT_STRIPE_SIZE: ShardStripeSize = ShardStripeSize(256 * 1024 / 8);
pub fn is_unsharded(&self) -> bool {
self.count == ShardCount(0)
}
}
impl Default for ShardParameters {
fn default() -> Self {
Self {
count: ShardCount(0),
stripe_size: Self::DEFAULT_STRIPE_SIZE,
}
}
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct TenantCreateRequest {
@@ -223,12 +192,6 @@ pub struct TenantCreateRequest {
#[serde(default)]
#[serde(skip_serializing_if = "Option::is_none")]
pub generation: Option<u32>,
// If omitted, create a single shard with TenantShardId::unsharded()
#[serde(default)]
#[serde(skip_serializing_if = "ShardParameters::is_unsharded")]
pub shard_parameters: ShardParameters,
#[serde(flatten)]
pub config: TenantConfig, // as we have a flattened field, we should reject all unknown fields in it
}
@@ -251,7 +214,7 @@ impl std::ops::Deref for TenantCreateRequest {
/// An alternative representation of `pageserver::tenant::TenantConf` with
/// simpler types.
#[derive(Serialize, Deserialize, Debug, Default, Clone, Eq, PartialEq)]
#[derive(Serialize, Deserialize, Debug, Default)]
pub struct TenantConfig {
pub checkpoint_distance: Option<u64>,
pub checkpoint_timeout: Option<String>,
@@ -266,41 +229,21 @@ pub struct TenantConfig {
pub lagging_wal_timeout: Option<String>,
pub max_lsn_wal_lag: Option<NonZeroU64>,
pub trace_read_requests: Option<bool>,
pub eviction_policy: Option<EvictionPolicy>,
// We defer the parsing of the eviction_policy field to the request handler.
// Otherwise we'd have to move the types for eviction policy into this package.
// We might do that once the eviction feature has stabilizied.
// For now, this field is not even documented in the openapi_spec.yml.
pub eviction_policy: Option<serde_json::Value>,
pub min_resident_size_override: Option<u64>,
pub evictions_low_residence_duration_metric_threshold: Option<String>,
pub gc_feedback: Option<bool>,
pub heatmap_period: Option<String>,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind")]
pub enum EvictionPolicy {
NoEviction,
LayerAccessThreshold(EvictionPolicyLayerAccessThreshold),
}
impl EvictionPolicy {
pub fn discriminant_str(&self) -> &'static str {
match self {
EvictionPolicy::NoEviction => "NoEviction",
EvictionPolicy::LayerAccessThreshold(_) => "LayerAccessThreshold",
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct EvictionPolicyLayerAccessThreshold {
#[serde(with = "humantime_serde")]
pub period: Duration,
#[serde(with = "humantime_serde")]
pub threshold: Duration,
}
/// A flattened analog of a `pagesever::tenant::LocationMode`, which
/// lists out all possible states (and the virtual "Detached" state)
/// in a flat form rather than using rust-style enums.
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
#[derive(Serialize, Deserialize, Debug)]
pub enum LocationConfigMode {
AttachedSingle,
AttachedMulti,
@@ -309,21 +252,19 @@ pub enum LocationConfigMode {
Detached,
}
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
#[derive(Serialize, Deserialize, Debug)]
pub struct LocationConfigSecondary {
pub warm: bool,
}
/// An alternative representation of `pageserver::tenant::LocationConf`,
/// for use in external-facing APIs.
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
#[derive(Serialize, Deserialize, Debug)]
pub struct LocationConfig {
pub mode: LocationConfigMode,
/// If attaching, in what generation?
#[serde(default)]
pub generation: Option<u32>,
// If requesting mode `Secondary`, configuration for that.
#[serde(default)]
pub secondary_conf: Option<LocationConfigSecondary>,
@@ -336,17 +277,11 @@ pub struct LocationConfig {
#[serde(default)]
pub shard_stripe_size: u32,
// This configuration only affects attached mode, but should be provided irrespective
// of the mode, as a secondary location might transition on startup if the response
// to the `/re-attach` control plane API requests it.
// If requesting mode `Secondary`, configuration for that.
// Custom storage configuration for the tenant, if any
pub tenant_conf: TenantConfig,
}
#[derive(Serialize, Deserialize)]
pub struct LocationConfigListResponse {
pub tenant_shards: Vec<(TenantShardId, Option<LocationConfig>)>,
}
#[derive(Serialize, Deserialize)]
#[serde(transparent)]
pub struct TenantCreateResponse(pub TenantId);
@@ -359,7 +294,7 @@ pub struct StatusResponse {
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct TenantLocationConfigRequest {
pub tenant_id: TenantShardId,
pub tenant_id: TenantId,
#[serde(flatten)]
pub config: LocationConfig, // as we have a flattened field, we should reject all unknown fields in it
}
@@ -430,16 +365,6 @@ pub struct TenantInfo {
/// If a layer is present in both local FS and S3, it counts only once.
pub current_physical_size: Option<u64>, // physical size is only included in `tenant_status` endpoint
pub attachment_status: TenantAttachmentStatus,
#[serde(skip_serializing_if = "Option::is_none")]
pub generation: Option<u32>,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct TenantDetails {
#[serde(flatten)]
pub tenant_info: TenantInfo,
pub timelines: Vec<TimelineId>,
}
/// This represents the output of the "timeline_detail" and "timeline_list" API calls.
@@ -621,6 +546,19 @@ pub enum DownloadRemoteLayersTaskState {
ShutDown,
}
pub type ConfigureFailpointsRequest = Vec<FailpointConfig>;
/// Information for configuring a single fail point
#[derive(Debug, Serialize, Deserialize)]
pub struct FailpointConfig {
/// Name of the fail point
pub name: String,
/// List of actions to take, using the format described in `fail::cfg`
///
/// We also support `actions = "exit"` to cause the fail point to immediately exit.
pub actions: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct TimelineGcRequest {
pub gc_horizon: Option<u64>,
@@ -636,7 +574,6 @@ pub enum PagestreamFeMessage {
}
// Wrapped in libpq CopyData
#[derive(strum_macros::EnumProperty)]
pub enum PagestreamBeMessage {
Exists(PagestreamExistsResponse),
Nblocks(PagestreamNblocksResponse),
@@ -645,29 +582,6 @@ pub enum PagestreamBeMessage {
DbSize(PagestreamDbSizeResponse),
}
// Keep in sync with `pagestore_client.h`
#[repr(u8)]
enum PagestreamBeMessageTag {
Exists = 100,
Nblocks = 101,
GetPage = 102,
Error = 103,
DbSize = 104,
}
impl TryFrom<u8> for PagestreamBeMessageTag {
type Error = u8;
fn try_from(value: u8) -> Result<Self, u8> {
match value {
100 => Ok(PagestreamBeMessageTag::Exists),
101 => Ok(PagestreamBeMessageTag::Nblocks),
102 => Ok(PagestreamBeMessageTag::GetPage),
103 => Ok(PagestreamBeMessageTag::Error),
104 => Ok(PagestreamBeMessageTag::DbSize),
_ => Err(value),
}
}
}
#[derive(Debug, PartialEq, Eq)]
pub struct PagestreamExistsRequest {
pub latest: bool,
@@ -722,17 +636,6 @@ pub struct PagestreamDbSizeResponse {
pub db_size: i64,
}
// This is a cut-down version of TenantHistorySize from the pageserver crate, omitting fields
// that require pageserver-internal types. It is sufficient to get the total size.
#[derive(Serialize, Deserialize, Debug)]
pub struct TenantHistorySize {
pub id: TenantId,
/// Size is a mixture of WAL and logical size, so the unit is bytes.
///
/// Will be none if `?inputs_only=true` was given.
pub size: Option<u64>,
}
impl PagestreamFeMessage {
pub fn serialize(&self) -> Bytes {
let mut bytes = BytesMut::new();
@@ -834,92 +737,35 @@ impl PagestreamBeMessage {
pub fn serialize(&self) -> Bytes {
let mut bytes = BytesMut::new();
use PagestreamBeMessageTag as Tag;
match self {
Self::Exists(resp) => {
bytes.put_u8(Tag::Exists as u8);
bytes.put_u8(100); /* tag from pagestore_client.h */
bytes.put_u8(resp.exists as u8);
}
Self::Nblocks(resp) => {
bytes.put_u8(Tag::Nblocks as u8);
bytes.put_u8(101); /* tag from pagestore_client.h */
bytes.put_u32(resp.n_blocks);
}
Self::GetPage(resp) => {
bytes.put_u8(Tag::GetPage as u8);
bytes.put_u8(102); /* tag from pagestore_client.h */
bytes.put(&resp.page[..]);
}
Self::Error(resp) => {
bytes.put_u8(Tag::Error as u8);
bytes.put_u8(103); /* tag from pagestore_client.h */
bytes.put(resp.message.as_bytes());
bytes.put_u8(0); // null terminator
}
Self::DbSize(resp) => {
bytes.put_u8(Tag::DbSize as u8);
bytes.put_u8(104); /* tag from pagestore_client.h */
bytes.put_i64(resp.db_size);
}
}
bytes.into()
}
pub fn deserialize(buf: Bytes) -> anyhow::Result<Self> {
let mut buf = buf.reader();
let msg_tag = buf.read_u8()?;
use PagestreamBeMessageTag as Tag;
let ok =
match Tag::try_from(msg_tag).map_err(|tag: u8| anyhow::anyhow!("invalid tag {tag}"))? {
Tag::Exists => {
let exists = buf.read_u8()?;
Self::Exists(PagestreamExistsResponse {
exists: exists != 0,
})
}
Tag::Nblocks => {
let n_blocks = buf.read_u32::<BigEndian>()?;
Self::Nblocks(PagestreamNblocksResponse { n_blocks })
}
Tag::GetPage => {
let mut page = vec![0; 8192]; // TODO: use MaybeUninit
buf.read_exact(&mut page)?;
PagestreamBeMessage::GetPage(PagestreamGetPageResponse { page: page.into() })
}
Tag::Error => {
let mut msg = Vec::new();
buf.read_until(0, &mut msg)?;
let cstring = std::ffi::CString::from_vec_with_nul(msg)?;
let rust_str = cstring.to_str()?;
PagestreamBeMessage::Error(PagestreamErrorResponse {
message: rust_str.to_owned(),
})
}
Tag::DbSize => {
let db_size = buf.read_i64::<BigEndian>()?;
Self::DbSize(PagestreamDbSizeResponse { db_size })
}
};
let remaining = buf.into_inner();
if !remaining.is_empty() {
anyhow::bail!(
"remaining bytes in msg with tag={msg_tag}: {}",
remaining.len()
);
}
Ok(ok)
}
pub fn kind(&self) -> &'static str {
match self {
Self::Exists(_) => "Exists",
Self::Nblocks(_) => "Nblocks",
Self::GetPage(_) => "GetPage",
Self::Error(_) => "Error",
Self::DbSize(_) => "DbSize",
}
}
}
#[cfg(test)]
@@ -985,7 +831,6 @@ mod tests {
state: TenantState::Active,
current_physical_size: Some(42),
attachment_status: TenantAttachmentStatus::Attached,
generation: None,
};
let expected_active = json!({
"id": original_active.id.to_string(),
@@ -1006,7 +851,6 @@ mod tests {
},
current_physical_size: Some(42),
attachment_status: TenantAttachmentStatus::Attached,
generation: None,
};
let expected_broken = json!({
"id": original_broken.id.to_string(),

View File

@@ -1,151 +0,0 @@
use utils::lsn::Lsn;
#[derive(Debug, PartialEq, Eq)]
pub struct Partitioning {
pub keys: crate::keyspace::KeySpace,
pub at_lsn: Lsn,
}
impl serde::Serialize for Partitioning {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
pub struct KeySpace<'a>(&'a crate::keyspace::KeySpace);
impl<'a> serde::Serialize for KeySpace<'a> {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeSeq;
let mut seq = serializer.serialize_seq(Some(self.0.ranges.len()))?;
for kr in &self.0.ranges {
seq.serialize_element(&KeyRange(kr))?;
}
seq.end()
}
}
use serde::ser::SerializeMap;
let mut map = serializer.serialize_map(Some(2))?;
map.serialize_key("keys")?;
map.serialize_value(&KeySpace(&self.keys))?;
map.serialize_key("at_lsn")?;
map.serialize_value(&WithDisplay(&self.at_lsn))?;
map.end()
}
}
pub struct WithDisplay<'a, T>(&'a T);
impl<'a, T: std::fmt::Display> serde::Serialize for WithDisplay<'a, T> {
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.collect_str(&self.0)
}
}
pub struct KeyRange<'a>(&'a std::ops::Range<crate::key::Key>);
impl<'a> serde::Serialize for KeyRange<'a> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeTuple;
let mut t = serializer.serialize_tuple(2)?;
t.serialize_element(&WithDisplay(&self.0.start))?;
t.serialize_element(&WithDisplay(&self.0.end))?;
t.end()
}
}
impl<'a> serde::Deserialize<'a> for Partitioning {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'a>,
{
pub struct KeySpace(crate::keyspace::KeySpace);
impl<'de> serde::Deserialize<'de> for KeySpace {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[serde_with::serde_as]
#[derive(serde::Deserialize)]
#[serde(transparent)]
struct Key(#[serde_as(as = "serde_with::DisplayFromStr")] crate::key::Key);
#[serde_with::serde_as]
#[derive(serde::Deserialize)]
struct Range(Key, Key);
let ranges: Vec<Range> = serde::Deserialize::deserialize(deserializer)?;
Ok(Self(crate::keyspace::KeySpace {
ranges: ranges
.into_iter()
.map(|Range(start, end)| (start.0..end.0))
.collect(),
}))
}
}
#[serde_with::serde_as]
#[derive(serde::Deserialize)]
struct De {
keys: KeySpace,
#[serde_as(as = "serde_with::DisplayFromStr")]
at_lsn: Lsn,
}
let de: De = serde::Deserialize::deserialize(deserializer)?;
Ok(Self {
at_lsn: de.at_lsn,
keys: de.keys.0,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialization_roundtrip() {
let reference = r#"
{
"keys": [
[
"000000000000000000000000000000000000",
"000000000000000000000000000000000001"
],
[
"000000067F00000001000000000000000000",
"000000067F00000001000000000000000002"
],
[
"030000000000000000000000000000000000",
"030000000000000000000000000000000003"
]
],
"at_lsn": "0/2240160"
}
"#;
let de: Partitioning = serde_json::from_str(reference).unwrap();
let ser = serde_json::to_string(&de).unwrap();
let ser_de: serde_json::Value = serde_json::from_str(&ser).unwrap();
assert_eq!(
ser_de,
serde_json::from_str::<'_, serde_json::Value>(reference).unwrap()
);
}
}

View File

@@ -32,9 +32,6 @@ pub struct RelTag {
pub relnode: Oid,
}
/// Block number within a relation or SLRU. This matches PostgreSQL's BlockNumber type.
pub type BlockNumber = u32;
impl PartialOrd for RelTag {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))

View File

@@ -1,9 +1,6 @@
use std::{ops::RangeInclusive, str::FromStr};
use crate::{
key::{is_rel_block_key, Key},
models::ShardParameters,
};
use crate::key::{is_rel_block_key, Key};
use hex::FromHex;
use serde::{Deserialize, Serialize};
use thiserror;
@@ -84,16 +81,6 @@ impl TenantShardId {
pub fn is_zero(&self) -> bool {
self.shard_number == ShardNumber(0)
}
pub fn is_unsharded(&self) -> bool {
self.shard_number == ShardNumber(0) && self.shard_count == ShardCount(0)
}
pub fn to_index(&self) -> ShardIndex {
ShardIndex {
shard_number: self.shard_number,
shard_count: self.shard_count,
}
}
}
/// Formatting helper
@@ -172,7 +159,7 @@ impl From<[u8; 18]> for TenantShardId {
/// shard we're dealing with, but do not need to know the full ShardIdentity (because
/// we won't be doing any page->shard mapping), and do not need to know the fully qualified
/// TenantShardId.
#[derive(Eq, PartialEq, PartialOrd, Ord, Clone, Copy, Hash)]
#[derive(Eq, PartialEq, PartialOrd, Ord, Clone, Copy)]
pub struct ShardIndex {
pub shard_number: ShardNumber,
pub shard_count: ShardCount,
@@ -342,7 +329,7 @@ const DEFAULT_STRIPE_SIZE: ShardStripeSize = ShardStripeSize(256 * 1024 / 8);
pub struct ShardIdentity {
pub number: ShardNumber,
pub count: ShardCount,
pub stripe_size: ShardStripeSize,
stripe_size: ShardStripeSize,
layout: ShardLayout,
}
@@ -412,17 +399,6 @@ impl ShardIdentity {
}
}
/// For use when creating ShardIdentity instances for new shards, where a creation request
/// specifies the ShardParameters that apply to all shards.
pub fn from_params(number: ShardNumber, params: &ShardParameters) -> Self {
Self {
number,
count: params.count,
layout: LAYOUT_V1,
stripe_size: params.stripe_size,
}
}
fn is_broken(&self) -> bool {
self.layout == LAYOUT_BROKEN
}
@@ -442,21 +418,6 @@ impl ShardIdentity {
}
}
/// Return true if the key should be discarded if found in this shard's
/// data store, e.g. during compaction after a split
pub fn is_key_disposable(&self, key: &Key) -> bool {
if key_is_shard0(key) {
// Q: Why can't we dispose of shard0 content if we're not shard 0?
// A: because the WAL ingestion logic currently ingests some shard 0
// content on all shards, even though it's only read on shard 0. If we
// dropped it, then subsequent WAL ingest to these keys would encounter
// an error.
false
} else {
!self.is_key_local(key)
}
}
pub fn shard_slug(&self) -> String {
if self.count > ShardCount(0) {
format!("-{:02x}{:02x}", self.number.0, self.count.0)
@@ -550,7 +511,12 @@ fn key_is_shard0(key: &Key) -> bool {
// relation pages are distributed to shards other than shard zero. Everything else gets
// stored on shard 0. This guarantees that shard 0 can independently serve basebackup
// requests, and any request other than those for particular blocks in relations.
!is_rel_block_key(key)
//
// In this condition:
// - is_rel_block_key includes only relations, i.e. excludes SLRU data and
// all metadata.
// - field6 is set to -1 for relation size pages.
!(is_rel_block_key(key) && key.field6 != 0xffffffff)
}
/// Provide the same result as the function in postgres `hashfn.h` with the same name

View File

@@ -35,12 +35,6 @@ pub enum QueryError {
/// We were instructed to shutdown while processing the query
#[error("Shutting down")]
Shutdown,
/// Query handler indicated that client should reconnect
#[error("Server requested reconnect")]
Reconnect,
/// Query named an entity that was not found
#[error("Not found: {0}")]
NotFound(std::borrow::Cow<'static, str>),
/// Authentication failure
#[error("Unauthorized: {0}")]
Unauthorized(std::borrow::Cow<'static, str>),
@@ -60,9 +54,9 @@ impl From<io::Error> for QueryError {
impl QueryError {
pub fn pg_error_code(&self) -> &'static [u8; 5] {
match self {
Self::Disconnected(_) | Self::SimulatedConnectionError | Self::Reconnect => b"08006", // connection failure
Self::Disconnected(_) | Self::SimulatedConnectionError => b"08006", // connection failure
Self::Shutdown => SQLSTATE_ADMIN_SHUTDOWN,
Self::Unauthorized(_) | Self::NotFound(_) => SQLSTATE_INTERNAL_ERROR,
Self::Unauthorized(_) => SQLSTATE_INTERNAL_ERROR,
Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
}
}
@@ -431,11 +425,6 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
info!("Stopped due to shutdown");
Ok(())
}
Err(QueryError::Reconnect) => {
// Dropping out of this loop implicitly disconnects
info!("Stopped due to handler reconnect request");
Ok(())
}
Err(QueryError::Disconnected(e)) => {
info!("Disconnected ({e:#})");
// Disconnection is not an error: we just use it that way internally to drop
@@ -985,9 +974,7 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'a, I
pub fn short_error(e: &QueryError) -> String {
match e {
QueryError::Disconnected(connection_error) => connection_error.to_string(),
QueryError::Reconnect => "reconnect".to_string(),
QueryError::Shutdown => "shutdown".to_string(),
QueryError::NotFound(_) => "not found".to_string(),
QueryError::Unauthorized(_e) => "JWT authentication error".to_string(),
QueryError::SimulatedConnectionError => "simulated connection error".to_string(),
QueryError::Other(e) => format!("{e:#}"),
@@ -1009,15 +996,9 @@ fn log_query_error(query: &str, e: &QueryError) {
QueryError::SimulatedConnectionError => {
error!("query handler for query '{query}' failed due to a simulated connection error")
}
QueryError::Reconnect => {
info!("query handler for '{query}' requested client to reconnect")
}
QueryError::Shutdown => {
info!("query handler for '{query}' cancelled during tenant shutdown")
}
QueryError::NotFound(reason) => {
info!("query handler for '{query}' entity not found: {reason}")
}
QueryError::Unauthorized(e) => {
warn!("query handler for '{query}' failed with authentication error: {e}");
}

View File

@@ -163,18 +163,8 @@ impl PgConnectionConfig {
}
/// Connect using postgres protocol with TLS disabled.
pub async fn connect_no_tls(
&self,
) -> Result<
(
tokio_postgres::Client,
tokio_postgres::Connection<tokio_postgres::Socket, tokio_postgres::tls::NoTlsStream>,
),
postgres::Error,
> {
self.to_tokio_postgres_config()
.connect(postgres::NoTls)
.await
pub fn connect_no_tls(&self) -> Result<postgres::Client, postgres::Error> {
postgres::Config::from(self.to_tokio_postgres_config()).connect(postgres::NoTls)
}
}

View File

@@ -5,9 +5,7 @@ use std::collections::HashMap;
use std::env;
use std::num::NonZeroU32;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use super::REMOTE_STORAGE_PREFIX_SEPARATOR;
use anyhow::Result;
@@ -15,14 +13,12 @@ use azure_core::request_options::{MaxResults, Metadata, Range};
use azure_core::RetryOptions;
use azure_identity::DefaultAzureCredential;
use azure_storage::StorageCredentials;
use azure_storage_blobs::blob::CopyStatus;
use azure_storage_blobs::prelude::ClientBuilder;
use azure_storage_blobs::{blob::operations::GetBlobBuilder, prelude::ContainerClient};
use bytes::Bytes;
use futures::stream::Stream;
use futures_util::StreamExt;
use http_types::{StatusCode, Url};
use tokio::time::Instant;
use http_types::StatusCode;
use tracing::debug;
use crate::s3_bucket::RequestKind;
@@ -121,8 +117,6 @@ impl AzureBlobStorage {
) -> Result<Download, DownloadError> {
let mut response = builder.into_stream();
let mut etag = None;
let mut last_modified = None;
let mut metadata = HashMap::new();
// TODO give proper streaming response instead of buffering into RAM
// https://github.com/neondatabase/neon/issues/5563
@@ -130,13 +124,6 @@ impl AzureBlobStorage {
let mut bufs = Vec::new();
while let Some(part) = response.next().await {
let part = part.map_err(to_download_error)?;
let etag_str: &str = part.blob.properties.etag.as_ref();
if etag.is_none() {
etag = Some(etag.unwrap_or_else(|| etag_str.to_owned()));
}
if last_modified.is_none() {
last_modified = Some(part.blob.properties.last_modified.into());
}
if let Some(blob_meta) = part.blob.metadata {
metadata.extend(blob_meta.iter().map(|(k, v)| (k.to_owned(), v.to_owned())));
}
@@ -149,8 +136,6 @@ impl AzureBlobStorage {
}
Ok(Download {
download_stream: Box::pin(futures::stream::iter(bufs.into_iter().map(Ok))),
etag,
last_modified,
metadata: Some(StorageMetadata(metadata)),
})
}
@@ -326,51 +311,6 @@ impl RemoteStorage for AzureBlobStorage {
}
Ok(())
}
async fn copy(&self, from: &RemotePath, to: &RemotePath) -> anyhow::Result<()> {
let _permit = self.permit(RequestKind::Copy).await;
let blob_client = self.client.blob_client(self.relative_path_to_name(to));
let source_url = format!(
"{}/{}",
self.client.url()?,
self.relative_path_to_name(from)
);
let builder = blob_client.copy(Url::from_str(&source_url)?);
let result = builder.into_future().await?;
let mut copy_status = result.copy_status;
let start_time = Instant::now();
const MAX_WAIT_TIME: Duration = Duration::from_secs(60);
loop {
match copy_status {
CopyStatus::Aborted => {
anyhow::bail!("Received abort for copy from {from} to {to}.");
}
CopyStatus::Failed => {
anyhow::bail!("Received failure response for copy from {from} to {to}.");
}
CopyStatus::Success => return Ok(()),
CopyStatus::Pending => (),
}
// The copy is taking longer. Waiting a second and then re-trying.
// TODO estimate time based on copy_progress and adjust time based on that
tokio::time::sleep(Duration::from_millis(1000)).await;
let properties = blob_client.get_properties().into_future().await?;
let Some(status) = properties.blob.properties.copy_status else {
tracing::warn!("copy_status for copy is None!, from={from}, to={to}");
return Ok(());
};
if start_time.elapsed() > MAX_WAIT_TIME {
anyhow::bail!("Copy from from {from} to {to} took longer than limit MAX_WAIT_TIME={}s. copy_pogress={:?}.",
MAX_WAIT_TIME.as_secs_f32(),
properties.blob.properties.copy_progress,
);
}
copy_status = status;
}
}
}
pin_project_lite::pin_project! {

View File

@@ -14,9 +14,7 @@ mod local_fs;
mod s3_bucket;
mod simulate_failures;
use std::{
collections::HashMap, fmt::Debug, num::NonZeroUsize, pin::Pin, sync::Arc, time::SystemTime,
};
use std::{collections::HashMap, fmt::Debug, num::NonZeroUsize, pin::Pin, sync::Arc};
use anyhow::{bail, Context};
use camino::{Utf8Path, Utf8PathBuf};
@@ -207,18 +205,10 @@ pub trait RemoteStorage: Send + Sync + 'static {
async fn delete(&self, path: &RemotePath) -> anyhow::Result<()>;
async fn delete_objects<'a>(&self, paths: &'a [RemotePath]) -> anyhow::Result<()>;
/// Copy a remote object inside a bucket from one path to another.
async fn copy(&self, from: &RemotePath, to: &RemotePath) -> anyhow::Result<()>;
}
pub type DownloadStream = Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Unpin + Send + Sync>>;
pub struct Download {
pub download_stream: DownloadStream,
/// The last time the file was modified (`last-modified` HTTP header)
pub last_modified: Option<SystemTime>,
/// A way to identify this specific version of the resource (`etag` HTTP header)
pub etag: Option<String>,
pub download_stream: Pin<Box<dyn Stream<Item = std::io::Result<Bytes>> + Unpin + Send + Sync>>,
/// Extra key-value data, associated with the current remote file.
pub metadata: Option<StorageMetadata>,
}
@@ -377,15 +367,6 @@ impl GenericRemoteStorage {
Self::Unreliable(s) => s.delete_objects(paths).await,
}
}
pub async fn copy_object(&self, from: &RemotePath, to: &RemotePath) -> anyhow::Result<()> {
match self {
Self::LocalFs(s) => s.copy(from, to).await,
Self::AwsS3(s) => s.copy(from, to).await,
Self::AzureBlob(s) => s.copy(from, to).await,
Self::Unreliable(s) => s.copy(from, to).await,
}
}
}
impl GenericRemoteStorage {
@@ -672,7 +653,6 @@ impl ConcurrencyLimiter {
RequestKind::Put => &self.write,
RequestKind::List => &self.read,
RequestKind::Delete => &self.write,
RequestKind::Copy => &self.write,
}
}

View File

@@ -18,7 +18,7 @@ use tokio_util::io::ReaderStream;
use tracing::*;
use utils::{crashsafe::path_with_suffix_extension, fs_ext::is_directory_empty};
use crate::{Download, DownloadError, DownloadStream, Listing, ListingMode, RemotePath};
use crate::{Download, DownloadError, Listing, ListingMode, RemotePath};
use super::{RemoteStorage, StorageMetadata};
@@ -331,8 +331,6 @@ impl RemoteStorage for LocalFs {
.map_err(DownloadError::Other)?;
Ok(Download {
metadata,
last_modified: None,
etag: None,
download_stream: Box::pin(source),
})
} else {
@@ -374,17 +372,17 @@ impl RemoteStorage for LocalFs {
.await
.map_err(DownloadError::Other)?;
let download_stream: DownloadStream = match end_exclusive {
Some(end_exclusive) => Box::pin(ReaderStream::new(
source.take(end_exclusive - start_inclusive),
)),
None => Box::pin(ReaderStream::new(source)),
};
Ok(Download {
metadata,
last_modified: None,
etag: None,
download_stream,
Ok(match end_exclusive {
Some(end_exclusive) => Download {
metadata,
download_stream: Box::pin(ReaderStream::new(
source.take(end_exclusive - start_inclusive),
)),
},
None => Download {
metadata,
download_stream: Box::pin(ReaderStream::new(source)),
},
})
} else {
Err(DownloadError::NotFound)
@@ -409,20 +407,6 @@ impl RemoteStorage for LocalFs {
}
Ok(())
}
async fn copy(&self, from: &RemotePath, to: &RemotePath) -> anyhow::Result<()> {
let from_path = from.with_base(&self.storage_root);
let to_path = to.with_base(&self.storage_root);
create_target_directory(&to_path).await?;
fs::copy(&from_path, &to_path).await.with_context(|| {
format!(
"Failed to copy file from '{from_path}' to '{to_path}'",
from_path = from_path,
to_path = to_path
)
})?;
Ok(())
}
}
fn storage_metadata_path(original_path: &Utf8Path) -> Utf8PathBuf {

View File

@@ -16,7 +16,6 @@ use aws_config::{
environment::credentials::EnvironmentVariableCredentialsProvider,
imds::credentials::ImdsCredentialsProvider,
meta::credentials::CredentialsProviderChain,
profile::ProfileFileCredentialsProvider,
provider_config::ProviderConfig,
retry::{RetryConfigBuilder, RetryMode},
web_identity_token::WebIdentityTokenCredentialsProvider,
@@ -75,29 +74,20 @@ impl S3Bucket {
let region = Some(Region::new(aws_config.bucket_region.clone()));
let provider_conf = ProviderConfig::without_region().with_region(region.clone());
let credentials_provider = {
// uses "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"
CredentialsProviderChain::first_try(
"env",
EnvironmentVariableCredentialsProvider::new(),
)
// uses "AWS_PROFILE" / `aws sso login --profile <profile>`
.or_else(
"profile-sso",
ProfileFileCredentialsProvider::builder()
.configure(&provider_conf)
.build(),
)
// uses "AWS_WEB_IDENTITY_TOKEN_FILE", "AWS_ROLE_ARN", "AWS_ROLE_SESSION_NAME"
// needed to access remote extensions bucket
.or_else(
"token",
.or_else("token", {
let provider_conf = ProviderConfig::without_region().with_region(region.clone());
WebIdentityTokenCredentialsProvider::builder()
.configure(&provider_conf)
.build(),
)
.build()
})
// uses imds v2
.or_else("imds", ImdsCredentialsProvider::builder().build())
};
@@ -228,11 +218,17 @@ impl S3Bucket {
let started_at = ScopeGuard::into_inner(started_at);
if get_object.is_err() {
metrics::BUCKET_METRICS.req_seconds.observe_elapsed(
kind,
AttemptOutcome::Err,
started_at,
);
}
match get_object {
Ok(object_output) => {
let metadata = object_output.metadata().cloned().map(StorageMetadata);
let etag = object_output.e_tag.clone();
let last_modified = object_output.last_modified.and_then(|t| t.try_into().ok());
let body = object_output.body;
let body = ByteStreamAsStream::from(body);
@@ -241,33 +237,15 @@ impl S3Bucket {
Ok(Download {
metadata,
etag,
last_modified,
download_stream: Box::pin(body),
})
}
Err(SdkError::ServiceError(e)) if matches!(e.err(), GetObjectError::NoSuchKey(_)) => {
// Count this in the AttemptOutcome::Ok bucket, because 404 is not
// an error: we expect to sometimes fetch an object and find it missing,
// e.g. when probing for timeline indices.
metrics::BUCKET_METRICS.req_seconds.observe_elapsed(
kind,
AttemptOutcome::Ok,
started_at,
);
Err(DownloadError::NotFound)
}
Err(e) => {
metrics::BUCKET_METRICS.req_seconds.observe_elapsed(
kind,
AttemptOutcome::Err,
started_at,
);
Err(DownloadError::Other(
anyhow::Error::new(e).context("download s3 object"),
))
}
Err(e) => Err(DownloadError::Other(
anyhow::Error::new(e).context("download s3 object"),
)),
}
}
}
@@ -493,38 +471,6 @@ impl RemoteStorage for S3Bucket {
Ok(())
}
async fn copy(&self, from: &RemotePath, to: &RemotePath) -> anyhow::Result<()> {
let kind = RequestKind::Copy;
let _guard = self.permit(kind).await;
let started_at = start_measuring_requests(kind);
// we need to specify bucket_name as a prefix
let copy_source = format!(
"{}/{}",
self.bucket_name,
self.relative_path_to_s3_object(from)
);
let res = self
.client
.copy_object()
.bucket(self.bucket_name.clone())
.key(self.relative_path_to_s3_object(to))
.copy_source(copy_source)
.send()
.await;
let started_at = ScopeGuard::into_inner(started_at);
metrics::BUCKET_METRICS
.req_seconds
.observe_elapsed(kind, &res, started_at);
res?;
Ok(())
}
async fn download(&self, from: &RemotePath) -> Result<Download, DownloadError> {
// if prefix is not none then download file `prefix/from`
// if prefix is none then download file `from`

View File

@@ -11,7 +11,6 @@ pub(crate) enum RequestKind {
Put = 1,
Delete = 2,
List = 3,
Copy = 4,
}
use RequestKind::*;
@@ -23,7 +22,6 @@ impl RequestKind {
Put => "put_object",
Delete => "delete_object",
List => "list_objects",
Copy => "copy_object",
}
}
const fn as_index(&self) -> usize {
@@ -31,7 +29,7 @@ impl RequestKind {
}
}
pub(super) struct RequestTyped<C>([C; 5]);
pub(super) struct RequestTyped<C>([C; 4]);
impl<C> RequestTyped<C> {
pub(super) fn get(&self, kind: RequestKind) -> &C {
@@ -40,8 +38,8 @@ impl<C> RequestTyped<C> {
fn build_with(mut f: impl FnMut(RequestKind) -> C) -> Self {
use RequestKind::*;
let mut it = [Get, Put, Delete, List, Copy].into_iter();
let arr = std::array::from_fn::<C, 5, _>(|index| {
let mut it = [Get, Put, Delete, List].into_iter();
let arr = std::array::from_fn::<C, 4, _>(|index| {
let next = it.next().unwrap();
assert_eq!(index, next.as_index());
f(next)

View File

@@ -162,11 +162,4 @@ impl RemoteStorage for UnreliableWrapper {
}
Ok(())
}
async fn copy(&self, from: &RemotePath, to: &RemotePath) -> anyhow::Result<()> {
// copy is equivalent to download + upload
self.attempt(RemoteOp::Download(from.clone()))?;
self.attempt(RemoteOp::Upload(to.clone()))?;
self.inner.copy_object(from, to).await
}
}

View File

@@ -1,200 +0,0 @@
use std::collections::HashSet;
use std::ops::ControlFlow;
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::Context;
use bytes::Bytes;
use camino::Utf8Path;
use futures::stream::Stream;
use once_cell::sync::OnceCell;
use remote_storage::{Download, GenericRemoteStorage, RemotePath};
use tokio::task::JoinSet;
use tracing::{debug, error, info};
static LOGGING_DONE: OnceCell<()> = OnceCell::new();
pub(crate) fn upload_stream(
content: std::borrow::Cow<'static, [u8]>,
) -> (
impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
usize,
) {
use std::borrow::Cow;
let content = match content {
Cow::Borrowed(x) => Bytes::from_static(x),
Cow::Owned(vec) => Bytes::from(vec),
};
wrap_stream(content)
}
pub(crate) fn wrap_stream(
content: bytes::Bytes,
) -> (
impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
usize,
) {
let len = content.len();
let content = futures::future::ready(Ok(content));
(futures::stream::once(content), len)
}
pub(crate) async fn download_to_vec(dl: Download) -> anyhow::Result<Vec<u8>> {
let mut buf = Vec::new();
tokio::io::copy_buf(
&mut tokio_util::io::StreamReader::new(dl.download_stream),
&mut buf,
)
.await?;
Ok(buf)
}
// Uploads files `folder{j}/blob{i}.txt`. See test description for more details.
pub(crate) async fn upload_simple_remote_data(
client: &Arc<GenericRemoteStorage>,
upload_tasks_count: usize,
) -> ControlFlow<HashSet<RemotePath>, HashSet<RemotePath>> {
info!("Creating {upload_tasks_count} remote files");
let mut upload_tasks = JoinSet::new();
for i in 1..upload_tasks_count + 1 {
let task_client = Arc::clone(client);
upload_tasks.spawn(async move {
let blob_path = PathBuf::from(format!("folder{}/blob_{}.txt", i / 7, i));
let blob_path = RemotePath::new(
Utf8Path::from_path(blob_path.as_path()).expect("must be valid blob path"),
)
.with_context(|| format!("{blob_path:?} to RemotePath conversion"))?;
debug!("Creating remote item {i} at path {blob_path:?}");
let (data, len) = upload_stream(format!("remote blob data {i}").into_bytes().into());
task_client.upload(data, len, &blob_path, None).await?;
Ok::<_, anyhow::Error>(blob_path)
});
}
let mut upload_tasks_failed = false;
let mut uploaded_blobs = HashSet::with_capacity(upload_tasks_count);
while let Some(task_run_result) = upload_tasks.join_next().await {
match task_run_result
.context("task join failed")
.and_then(|task_result| task_result.context("upload task failed"))
{
Ok(upload_path) => {
uploaded_blobs.insert(upload_path);
}
Err(e) => {
error!("Upload task failed: {e:?}");
upload_tasks_failed = true;
}
}
}
if upload_tasks_failed {
ControlFlow::Break(uploaded_blobs)
} else {
ControlFlow::Continue(uploaded_blobs)
}
}
pub(crate) async fn cleanup(
client: &Arc<GenericRemoteStorage>,
objects_to_delete: HashSet<RemotePath>,
) {
info!(
"Removing {} objects from the remote storage during cleanup",
objects_to_delete.len()
);
let mut delete_tasks = JoinSet::new();
for object_to_delete in objects_to_delete {
let task_client = Arc::clone(client);
delete_tasks.spawn(async move {
debug!("Deleting remote item at path {object_to_delete:?}");
task_client
.delete(&object_to_delete)
.await
.with_context(|| format!("{object_to_delete:?} removal"))
});
}
while let Some(task_run_result) = delete_tasks.join_next().await {
match task_run_result {
Ok(task_result) => match task_result {
Ok(()) => {}
Err(e) => error!("Delete task failed: {e:?}"),
},
Err(join_err) => error!("Delete task did not finish correctly: {join_err}"),
}
}
}
pub(crate) struct Uploads {
pub(crate) prefixes: HashSet<RemotePath>,
pub(crate) blobs: HashSet<RemotePath>,
}
pub(crate) async fn upload_remote_data(
client: &Arc<GenericRemoteStorage>,
base_prefix_str: &'static str,
upload_tasks_count: usize,
) -> ControlFlow<Uploads, Uploads> {
info!("Creating {upload_tasks_count} remote files");
let mut upload_tasks = JoinSet::new();
for i in 1..upload_tasks_count + 1 {
let task_client = Arc::clone(client);
upload_tasks.spawn(async move {
let prefix = format!("{base_prefix_str}/sub_prefix_{i}/");
let blob_prefix = RemotePath::new(Utf8Path::new(&prefix))
.with_context(|| format!("{prefix:?} to RemotePath conversion"))?;
let blob_path = blob_prefix.join(Utf8Path::new(&format!("blob_{i}")));
debug!("Creating remote item {i} at path {blob_path:?}");
let (data, data_len) =
upload_stream(format!("remote blob data {i}").into_bytes().into());
task_client.upload(data, data_len, &blob_path, None).await?;
Ok::<_, anyhow::Error>((blob_prefix, blob_path))
});
}
let mut upload_tasks_failed = false;
let mut uploaded_prefixes = HashSet::with_capacity(upload_tasks_count);
let mut uploaded_blobs = HashSet::with_capacity(upload_tasks_count);
while let Some(task_run_result) = upload_tasks.join_next().await {
match task_run_result
.context("task join failed")
.and_then(|task_result| task_result.context("upload task failed"))
{
Ok((upload_prefix, upload_path)) => {
uploaded_prefixes.insert(upload_prefix);
uploaded_blobs.insert(upload_path);
}
Err(e) => {
error!("Upload task failed: {e:?}");
upload_tasks_failed = true;
}
}
}
let uploads = Uploads {
prefixes: uploaded_prefixes,
blobs: uploaded_blobs,
};
if upload_tasks_failed {
ControlFlow::Break(uploads)
} else {
ControlFlow::Continue(uploads)
}
}
pub(crate) fn ensure_logging_ready() {
LOGGING_DONE.get_or_init(|| {
utils::logging::init(
utils::logging::LogFormat::Test,
utils::logging::TracingErrorLayerEnablement::Disabled,
utils::logging::Output::Stdout,
)
.expect("logging init failed");
});
}

View File

@@ -1,288 +0,0 @@
use anyhow::Context;
use camino::Utf8Path;
use remote_storage::RemotePath;
use std::collections::HashSet;
use std::sync::Arc;
use test_context::test_context;
use tracing::debug;
use crate::common::{download_to_vec, upload_stream, wrap_stream};
use super::{
MaybeEnabledStorage, MaybeEnabledStorageWithSimpleTestBlobs, MaybeEnabledStorageWithTestBlobs,
};
/// Tests that S3 client can list all prefixes, even if the response come paginated and requires multiple S3 queries.
/// Uses real S3 and requires [`ENABLE_REAL_S3_REMOTE_STORAGE_ENV_VAR_NAME`] and related S3 cred env vars specified.
/// See the client creation in [`create_s3_client`] for details on the required env vars.
/// If real S3 tests are disabled, the test passes, skipping any real test run: currently, there's no way to mark the test ignored in runtime with the
/// deafult test framework, see https://github.com/rust-lang/rust/issues/68007 for details.
///
/// First, the test creates a set of S3 objects with keys `/${random_prefix_part}/${base_prefix_str}/sub_prefix_${i}/blob_${i}` in [`upload_remote_data`]
/// where
/// * `random_prefix_part` is set for the entire S3 client during the S3 client creation in [`create_s3_client`], to avoid multiple test runs interference
/// * `base_prefix_str` is a common prefix to use in the client requests: we would want to ensure that the client is able to list nested prefixes inside the bucket
///
/// Then, verifies that the client does return correct prefixes when queried:
/// * with no prefix, it lists everything after its `${random_prefix_part}/` — that should be `${base_prefix_str}` value only
/// * with `${base_prefix_str}/` prefix, it lists every `sub_prefix_${i}`
///
/// With the real S3 enabled and `#[cfg(test)]` Rust configuration used, the S3 client test adds a `max-keys` param to limit the response keys.
/// This way, we are able to test the pagination implicitly, by ensuring all results are returned from the remote storage and avoid uploading too many blobs to S3,
/// since current default AWS S3 pagination limit is 1000.
/// (see https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectsV2.html#API_ListObjectsV2_RequestSyntax)
///
/// Lastly, the test attempts to clean up and remove all uploaded S3 files.
/// If any errors appear during the clean up, they get logged, but the test is not failed or stopped until clean up is finished.
#[test_context(MaybeEnabledStorageWithTestBlobs)]
#[tokio::test]
async fn pagination_should_work(ctx: &mut MaybeEnabledStorageWithTestBlobs) -> anyhow::Result<()> {
let ctx = match ctx {
MaybeEnabledStorageWithTestBlobs::Enabled(ctx) => ctx,
MaybeEnabledStorageWithTestBlobs::Disabled => return Ok(()),
MaybeEnabledStorageWithTestBlobs::UploadsFailed(e, _) => {
anyhow::bail!("S3 init failed: {e:?}")
}
};
let test_client = Arc::clone(&ctx.enabled.client);
let expected_remote_prefixes = ctx.remote_prefixes.clone();
let base_prefix = RemotePath::new(Utf8Path::new(ctx.enabled.base_prefix))
.context("common_prefix construction")?;
let root_remote_prefixes = test_client
.list_prefixes(None)
.await
.context("client list root prefixes failure")?
.into_iter()
.collect::<HashSet<_>>();
assert_eq!(
root_remote_prefixes, HashSet::from([base_prefix.clone()]),
"remote storage root prefixes list mismatches with the uploads. Returned prefixes: {root_remote_prefixes:?}"
);
let nested_remote_prefixes = test_client
.list_prefixes(Some(&base_prefix))
.await
.context("client list nested prefixes failure")?
.into_iter()
.collect::<HashSet<_>>();
let remote_only_prefixes = nested_remote_prefixes
.difference(&expected_remote_prefixes)
.collect::<HashSet<_>>();
let missing_uploaded_prefixes = expected_remote_prefixes
.difference(&nested_remote_prefixes)
.collect::<HashSet<_>>();
assert_eq!(
remote_only_prefixes.len() + missing_uploaded_prefixes.len(), 0,
"remote storage nested prefixes list mismatches with the uploads. Remote only prefixes: {remote_only_prefixes:?}, missing uploaded prefixes: {missing_uploaded_prefixes:?}",
);
Ok(())
}
/// Tests that S3 client can list all files in a folder, even if the response comes paginated and requirees multiple S3 queries.
/// Uses real S3 and requires [`ENABLE_REAL_S3_REMOTE_STORAGE_ENV_VAR_NAME`] and related S3 cred env vars specified. Test will skip real code and pass if env vars not set.
/// See `s3_pagination_should_work` for more information.
///
/// First, create a set of S3 objects with keys `random_prefix/folder{j}/blob_{i}.txt` in [`upload_remote_data`]
/// Then performs the following queries:
/// 1. `list_files(None)`. This should return all files `random_prefix/folder{j}/blob_{i}.txt`
/// 2. `list_files("folder1")`. This should return all files `random_prefix/folder1/blob_{i}.txt`
#[test_context(MaybeEnabledStorageWithSimpleTestBlobs)]
#[tokio::test]
async fn list_files_works(ctx: &mut MaybeEnabledStorageWithSimpleTestBlobs) -> anyhow::Result<()> {
let ctx = match ctx {
MaybeEnabledStorageWithSimpleTestBlobs::Enabled(ctx) => ctx,
MaybeEnabledStorageWithSimpleTestBlobs::Disabled => return Ok(()),
MaybeEnabledStorageWithSimpleTestBlobs::UploadsFailed(e, _) => {
anyhow::bail!("S3 init failed: {e:?}")
}
};
let test_client = Arc::clone(&ctx.enabled.client);
let base_prefix =
RemotePath::new(Utf8Path::new("folder1")).context("common_prefix construction")?;
let root_files = test_client
.list_files(None)
.await
.context("client list root files failure")?
.into_iter()
.collect::<HashSet<_>>();
assert_eq!(
root_files,
ctx.remote_blobs.clone(),
"remote storage list_files on root mismatches with the uploads."
);
let nested_remote_files = test_client
.list_files(Some(&base_prefix))
.await
.context("client list nested files failure")?
.into_iter()
.collect::<HashSet<_>>();
let trim_remote_blobs: HashSet<_> = ctx
.remote_blobs
.iter()
.map(|x| x.get_path())
.filter(|x| x.starts_with("folder1"))
.map(|x| RemotePath::new(x).expect("must be valid path"))
.collect();
assert_eq!(
nested_remote_files, trim_remote_blobs,
"remote storage list_files on subdirrectory mismatches with the uploads."
);
Ok(())
}
#[test_context(MaybeEnabledStorage)]
#[tokio::test]
async fn delete_non_exising_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<()> {
let ctx = match ctx {
MaybeEnabledStorage::Enabled(ctx) => ctx,
MaybeEnabledStorage::Disabled => return Ok(()),
};
let path = RemotePath::new(Utf8Path::new(
format!("{}/for_sure_there_is_nothing_there_really", ctx.base_prefix).as_str(),
))
.with_context(|| "RemotePath conversion")?;
ctx.client.delete(&path).await.expect("should succeed");
Ok(())
}
#[test_context(MaybeEnabledStorage)]
#[tokio::test]
async fn delete_objects_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<()> {
let ctx = match ctx {
MaybeEnabledStorage::Enabled(ctx) => ctx,
MaybeEnabledStorage::Disabled => return Ok(()),
};
let path1 = RemotePath::new(Utf8Path::new(format!("{}/path1", ctx.base_prefix).as_str()))
.with_context(|| "RemotePath conversion")?;
let path2 = RemotePath::new(Utf8Path::new(format!("{}/path2", ctx.base_prefix).as_str()))
.with_context(|| "RemotePath conversion")?;
let path3 = RemotePath::new(Utf8Path::new(format!("{}/path3", ctx.base_prefix).as_str()))
.with_context(|| "RemotePath conversion")?;
let (data, len) = upload_stream("remote blob data1".as_bytes().into());
ctx.client.upload(data, len, &path1, None).await?;
let (data, len) = upload_stream("remote blob data2".as_bytes().into());
ctx.client.upload(data, len, &path2, None).await?;
let (data, len) = upload_stream("remote blob data3".as_bytes().into());
ctx.client.upload(data, len, &path3, None).await?;
ctx.client.delete_objects(&[path1, path2]).await?;
let prefixes = ctx.client.list_prefixes(None).await?;
assert_eq!(prefixes.len(), 1);
ctx.client.delete_objects(&[path3]).await?;
Ok(())
}
#[test_context(MaybeEnabledStorage)]
#[tokio::test]
async fn upload_download_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<()> {
let MaybeEnabledStorage::Enabled(ctx) = ctx else {
return Ok(());
};
let path = RemotePath::new(Utf8Path::new(format!("{}/file", ctx.base_prefix).as_str()))
.with_context(|| "RemotePath conversion")?;
let orig = bytes::Bytes::from_static("remote blob data here".as_bytes());
let (data, len) = wrap_stream(orig.clone());
ctx.client.upload(data, len, &path, None).await?;
// Normal download request
let dl = ctx.client.download(&path).await?;
let buf = download_to_vec(dl).await?;
assert_eq!(&buf, &orig);
// Full range (end specified)
let dl = ctx
.client
.download_byte_range(&path, 0, Some(len as u64))
.await?;
let buf = download_to_vec(dl).await?;
assert_eq!(&buf, &orig);
// partial range (end specified)
let dl = ctx.client.download_byte_range(&path, 4, Some(10)).await?;
let buf = download_to_vec(dl).await?;
assert_eq!(&buf, &orig[4..10]);
// partial range (end beyond real end)
let dl = ctx
.client
.download_byte_range(&path, 8, Some(len as u64 * 100))
.await?;
let buf = download_to_vec(dl).await?;
assert_eq!(&buf, &orig[8..]);
// Partial range (end unspecified)
let dl = ctx.client.download_byte_range(&path, 4, None).await?;
let buf = download_to_vec(dl).await?;
assert_eq!(&buf, &orig[4..]);
// Full range (end unspecified)
let dl = ctx.client.download_byte_range(&path, 0, None).await?;
let buf = download_to_vec(dl).await?;
assert_eq!(&buf, &orig);
debug!("Cleanup: deleting file at path {path:?}");
ctx.client
.delete(&path)
.await
.with_context(|| format!("{path:?} removal"))?;
Ok(())
}
#[test_context(MaybeEnabledStorage)]
#[tokio::test]
async fn copy_works(ctx: &mut MaybeEnabledStorage) -> anyhow::Result<()> {
let MaybeEnabledStorage::Enabled(ctx) = ctx else {
return Ok(());
};
let path = RemotePath::new(Utf8Path::new(
format!("{}/file_to_copy", ctx.base_prefix).as_str(),
))
.with_context(|| "RemotePath conversion")?;
let path_dest = RemotePath::new(Utf8Path::new(
format!("{}/file_dest", ctx.base_prefix).as_str(),
))
.with_context(|| "RemotePath conversion")?;
let orig = bytes::Bytes::from_static("remote blob data content".as_bytes());
let (data, len) = wrap_stream(orig.clone());
ctx.client.upload(data, len, &path, None).await?;
// Normal download request
ctx.client.copy_object(&path, &path_dest).await?;
let dl = ctx.client.download(&path_dest).await?;
let buf = download_to_vec(dl).await?;
assert_eq!(&buf, &orig);
debug!("Cleanup: deleting file at path {path:?}");
ctx.client
.delete_objects(&[path.clone(), path_dest.clone()])
.await
.with_context(|| format!("{path:?} removal"))?;
Ok(())
}

View File

@@ -2,27 +2,287 @@ use std::collections::HashSet;
use std::env;
use std::num::NonZeroUsize;
use std::ops::ControlFlow;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::UNIX_EPOCH;
use anyhow::Context;
use bytes::Bytes;
use camino::Utf8Path;
use futures::stream::Stream;
use once_cell::sync::OnceCell;
use remote_storage::{
AzureConfig, GenericRemoteStorage, RemotePath, RemoteStorageConfig, RemoteStorageKind,
AzureConfig, Download, GenericRemoteStorage, RemotePath, RemoteStorageConfig, RemoteStorageKind,
};
use test_context::AsyncTestContext;
use tracing::info;
use test_context::{test_context, AsyncTestContext};
use tokio::task::JoinSet;
use tracing::{debug, error, info};
mod common;
#[path = "common/tests.rs"]
mod tests_azure;
use common::{cleanup, ensure_logging_ready, upload_remote_data, upload_simple_remote_data};
static LOGGING_DONE: OnceCell<()> = OnceCell::new();
const ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME: &str = "ENABLE_REAL_AZURE_REMOTE_STORAGE";
const BASE_PREFIX: &str = "test";
/// Tests that the Azure client can list all prefixes, even if the response comes paginated and requires multiple HTTP queries.
/// Uses real Azure and requires [`ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME`] and related Azure cred env vars specified.
/// See the client creation in [`create_azure_client`] for details on the required env vars.
/// If real Azure tests are disabled, the test passes, skipping any real test run: currently, there's no way to mark the test ignored in runtime with the
/// deafult test framework, see https://github.com/rust-lang/rust/issues/68007 for details.
///
/// First, the test creates a set of Azure blobs with keys `/${random_prefix_part}/${base_prefix_str}/sub_prefix_${i}/blob_${i}` in [`upload_azure_data`]
/// where
/// * `random_prefix_part` is set for the entire Azure client during the Azure client creation in [`create_azure_client`], to avoid multiple test runs interference
/// * `base_prefix_str` is a common prefix to use in the client requests: we would want to ensure that the client is able to list nested prefixes inside the bucket
///
/// Then, verifies that the client does return correct prefixes when queried:
/// * with no prefix, it lists everything after its `${random_prefix_part}/` — that should be `${base_prefix_str}` value only
/// * with `${base_prefix_str}/` prefix, it lists every `sub_prefix_${i}`
///
/// With the real Azure enabled and `#[cfg(test)]` Rust configuration used, the Azure client test adds a `max-keys` param to limit the response keys.
/// This way, we are able to test the pagination implicitly, by ensuring all results are returned from the remote storage and avoid uploading too many blobs to Azure.
///
/// Lastly, the test attempts to clean up and remove all uploaded Azure files.
/// If any errors appear during the clean up, they get logged, but the test is not failed or stopped until clean up is finished.
#[test_context(MaybeEnabledAzureWithTestBlobs)]
#[tokio::test]
async fn azure_pagination_should_work(
ctx: &mut MaybeEnabledAzureWithTestBlobs,
) -> anyhow::Result<()> {
let ctx = match ctx {
MaybeEnabledAzureWithTestBlobs::Enabled(ctx) => ctx,
MaybeEnabledAzureWithTestBlobs::Disabled => return Ok(()),
MaybeEnabledAzureWithTestBlobs::UploadsFailed(e, _) => {
anyhow::bail!("Azure init failed: {e:?}")
}
};
let test_client = Arc::clone(&ctx.enabled.client);
let expected_remote_prefixes = ctx.remote_prefixes.clone();
let base_prefix = RemotePath::new(Utf8Path::new(ctx.enabled.base_prefix))
.context("common_prefix construction")?;
let root_remote_prefixes = test_client
.list_prefixes(None)
.await
.context("client list root prefixes failure")?
.into_iter()
.collect::<HashSet<_>>();
assert_eq!(
root_remote_prefixes, HashSet::from([base_prefix.clone()]),
"remote storage root prefixes list mismatches with the uploads. Returned prefixes: {root_remote_prefixes:?}"
);
let nested_remote_prefixes = test_client
.list_prefixes(Some(&base_prefix))
.await
.context("client list nested prefixes failure")?
.into_iter()
.collect::<HashSet<_>>();
let remote_only_prefixes = nested_remote_prefixes
.difference(&expected_remote_prefixes)
.collect::<HashSet<_>>();
let missing_uploaded_prefixes = expected_remote_prefixes
.difference(&nested_remote_prefixes)
.collect::<HashSet<_>>();
assert_eq!(
remote_only_prefixes.len() + missing_uploaded_prefixes.len(), 0,
"remote storage nested prefixes list mismatches with the uploads. Remote only prefixes: {remote_only_prefixes:?}, missing uploaded prefixes: {missing_uploaded_prefixes:?}",
);
Ok(())
}
/// Tests that Azure client can list all files in a folder, even if the response comes paginated and requirees multiple Azure queries.
/// Uses real Azure and requires [`ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME`] and related Azure cred env vars specified. Test will skip real code and pass if env vars not set.
/// See `Azure_pagination_should_work` for more information.
///
/// First, create a set of Azure objects with keys `random_prefix/folder{j}/blob_{i}.txt` in [`upload_azure_data`]
/// Then performs the following queries:
/// 1. `list_files(None)`. This should return all files `random_prefix/folder{j}/blob_{i}.txt`
/// 2. `list_files("folder1")`. This should return all files `random_prefix/folder1/blob_{i}.txt`
#[test_context(MaybeEnabledAzureWithSimpleTestBlobs)]
#[tokio::test]
async fn azure_list_files_works(
ctx: &mut MaybeEnabledAzureWithSimpleTestBlobs,
) -> anyhow::Result<()> {
let ctx = match ctx {
MaybeEnabledAzureWithSimpleTestBlobs::Enabled(ctx) => ctx,
MaybeEnabledAzureWithSimpleTestBlobs::Disabled => return Ok(()),
MaybeEnabledAzureWithSimpleTestBlobs::UploadsFailed(e, _) => {
anyhow::bail!("Azure init failed: {e:?}")
}
};
let test_client = Arc::clone(&ctx.enabled.client);
let base_prefix =
RemotePath::new(Utf8Path::new("folder1")).context("common_prefix construction")?;
let root_files = test_client
.list_files(None)
.await
.context("client list root files failure")?
.into_iter()
.collect::<HashSet<_>>();
assert_eq!(
root_files,
ctx.remote_blobs.clone(),
"remote storage list_files on root mismatches with the uploads."
);
let nested_remote_files = test_client
.list_files(Some(&base_prefix))
.await
.context("client list nested files failure")?
.into_iter()
.collect::<HashSet<_>>();
let trim_remote_blobs: HashSet<_> = ctx
.remote_blobs
.iter()
.map(|x| x.get_path())
.filter(|x| x.starts_with("folder1"))
.map(|x| RemotePath::new(x).expect("must be valid path"))
.collect();
assert_eq!(
nested_remote_files, trim_remote_blobs,
"remote storage list_files on subdirrectory mismatches with the uploads."
);
Ok(())
}
#[test_context(MaybeEnabledAzure)]
#[tokio::test]
async fn azure_delete_non_exising_works(ctx: &mut MaybeEnabledAzure) -> anyhow::Result<()> {
let ctx = match ctx {
MaybeEnabledAzure::Enabled(ctx) => ctx,
MaybeEnabledAzure::Disabled => return Ok(()),
};
let path = RemotePath::new(Utf8Path::new(
format!("{}/for_sure_there_is_nothing_there_really", ctx.base_prefix).as_str(),
))
.with_context(|| "RemotePath conversion")?;
ctx.client.delete(&path).await.expect("should succeed");
Ok(())
}
#[test_context(MaybeEnabledAzure)]
#[tokio::test]
async fn azure_delete_objects_works(ctx: &mut MaybeEnabledAzure) -> anyhow::Result<()> {
let ctx = match ctx {
MaybeEnabledAzure::Enabled(ctx) => ctx,
MaybeEnabledAzure::Disabled => return Ok(()),
};
let path1 = RemotePath::new(Utf8Path::new(format!("{}/path1", ctx.base_prefix).as_str()))
.with_context(|| "RemotePath conversion")?;
let path2 = RemotePath::new(Utf8Path::new(format!("{}/path2", ctx.base_prefix).as_str()))
.with_context(|| "RemotePath conversion")?;
let path3 = RemotePath::new(Utf8Path::new(format!("{}/path3", ctx.base_prefix).as_str()))
.with_context(|| "RemotePath conversion")?;
let (data, len) = upload_stream("remote blob data1".as_bytes().into());
ctx.client.upload(data, len, &path1, None).await?;
let (data, len) = upload_stream("remote blob data2".as_bytes().into());
ctx.client.upload(data, len, &path2, None).await?;
let (data, len) = upload_stream("remote blob data3".as_bytes().into());
ctx.client.upload(data, len, &path3, None).await?;
ctx.client.delete_objects(&[path1, path2]).await?;
let prefixes = ctx.client.list_prefixes(None).await?;
assert_eq!(prefixes.len(), 1);
ctx.client.delete_objects(&[path3]).await?;
Ok(())
}
#[test_context(MaybeEnabledAzure)]
#[tokio::test]
async fn azure_upload_download_works(ctx: &mut MaybeEnabledAzure) -> anyhow::Result<()> {
let MaybeEnabledAzure::Enabled(ctx) = ctx else {
return Ok(());
};
let path = RemotePath::new(Utf8Path::new(format!("{}/file", ctx.base_prefix).as_str()))
.with_context(|| "RemotePath conversion")?;
let orig = bytes::Bytes::from_static("remote blob data here".as_bytes());
let (data, len) = wrap_stream(orig.clone());
ctx.client.upload(data, len, &path, None).await?;
async fn download_and_compare(dl: Download) -> anyhow::Result<Vec<u8>> {
let mut buf = Vec::new();
tokio::io::copy_buf(
&mut tokio_util::io::StreamReader::new(dl.download_stream),
&mut buf,
)
.await?;
Ok(buf)
}
// Normal download request
let dl = ctx.client.download(&path).await?;
let buf = download_and_compare(dl).await?;
assert_eq!(&buf, &orig);
// Full range (end specified)
let dl = ctx
.client
.download_byte_range(&path, 0, Some(len as u64))
.await?;
let buf = download_and_compare(dl).await?;
assert_eq!(&buf, &orig);
// partial range (end specified)
let dl = ctx.client.download_byte_range(&path, 4, Some(10)).await?;
let buf = download_and_compare(dl).await?;
assert_eq!(&buf, &orig[4..10]);
// partial range (end beyond real end)
let dl = ctx
.client
.download_byte_range(&path, 8, Some(len as u64 * 100))
.await?;
let buf = download_and_compare(dl).await?;
assert_eq!(&buf, &orig[8..]);
// Partial range (end unspecified)
let dl = ctx.client.download_byte_range(&path, 4, None).await?;
let buf = download_and_compare(dl).await?;
assert_eq!(&buf, &orig[4..]);
// Full range (end unspecified)
let dl = ctx.client.download_byte_range(&path, 0, None).await?;
let buf = download_and_compare(dl).await?;
assert_eq!(&buf, &orig);
debug!("Cleanup: deleting file at path {path:?}");
ctx.client
.delete(&path)
.await
.with_context(|| format!("{path:?} removal"))?;
Ok(())
}
fn ensure_logging_ready() {
LOGGING_DONE.get_or_init(|| {
utils::logging::init(
utils::logging::LogFormat::Test,
utils::logging::TracingErrorLayerEnablement::Disabled,
utils::logging::Output::Stdout,
)
.expect("logging init failed");
});
}
struct EnabledAzure {
client: Arc<GenericRemoteStorage>,
base_prefix: &'static str,
@@ -41,13 +301,13 @@ impl EnabledAzure {
}
}
enum MaybeEnabledStorage {
enum MaybeEnabledAzure {
Enabled(EnabledAzure),
Disabled,
}
#[async_trait::async_trait]
impl AsyncTestContext for MaybeEnabledStorage {
impl AsyncTestContext for MaybeEnabledAzure {
async fn setup() -> Self {
ensure_logging_ready();
@@ -63,7 +323,7 @@ impl AsyncTestContext for MaybeEnabledStorage {
}
}
enum MaybeEnabledStorageWithTestBlobs {
enum MaybeEnabledAzureWithTestBlobs {
Enabled(AzureWithTestBlobs),
Disabled,
UploadsFailed(anyhow::Error, AzureWithTestBlobs),
@@ -76,7 +336,7 @@ struct AzureWithTestBlobs {
}
#[async_trait::async_trait]
impl AsyncTestContext for MaybeEnabledStorageWithTestBlobs {
impl AsyncTestContext for MaybeEnabledAzureWithTestBlobs {
async fn setup() -> Self {
ensure_logging_ready();
if env::var(ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME).is_err() {
@@ -92,7 +352,7 @@ impl AsyncTestContext for MaybeEnabledStorageWithTestBlobs {
let enabled = EnabledAzure::setup(Some(max_keys_in_list_response)).await;
match upload_remote_data(&enabled.client, enabled.base_prefix, upload_tasks_count).await {
match upload_azure_data(&enabled.client, enabled.base_prefix, upload_tasks_count).await {
ControlFlow::Continue(uploads) => {
info!("Remote objects created successfully");
@@ -127,7 +387,7 @@ impl AsyncTestContext for MaybeEnabledStorageWithTestBlobs {
// However, they are not idential. The list_prefixes function is concerned with listing prefixes,
// whereas the list_files function is concerned with listing files.
// See `RemoteStorage::list_files` documentation for more details
enum MaybeEnabledStorageWithSimpleTestBlobs {
enum MaybeEnabledAzureWithSimpleTestBlobs {
Enabled(AzureWithSimpleTestBlobs),
Disabled,
UploadsFailed(anyhow::Error, AzureWithSimpleTestBlobs),
@@ -138,7 +398,7 @@ struct AzureWithSimpleTestBlobs {
}
#[async_trait::async_trait]
impl AsyncTestContext for MaybeEnabledStorageWithSimpleTestBlobs {
impl AsyncTestContext for MaybeEnabledAzureWithSimpleTestBlobs {
async fn setup() -> Self {
ensure_logging_ready();
if env::var(ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME).is_err() {
@@ -154,7 +414,7 @@ impl AsyncTestContext for MaybeEnabledStorageWithSimpleTestBlobs {
let enabled = EnabledAzure::setup(Some(max_keys_in_list_response)).await;
match upload_simple_remote_data(&enabled.client, upload_tasks_count).await {
match upload_simple_azure_data(&enabled.client, upload_tasks_count).await {
ControlFlow::Continue(uploads) => {
info!("Remote objects created successfully");
@@ -218,3 +478,166 @@ fn create_azure_client(
GenericRemoteStorage::from_config(&remote_storage_config).context("remote storage init")?,
))
}
struct Uploads {
prefixes: HashSet<RemotePath>,
blobs: HashSet<RemotePath>,
}
async fn upload_azure_data(
client: &Arc<GenericRemoteStorage>,
base_prefix_str: &'static str,
upload_tasks_count: usize,
) -> ControlFlow<Uploads, Uploads> {
info!("Creating {upload_tasks_count} Azure files");
let mut upload_tasks = JoinSet::new();
for i in 1..upload_tasks_count + 1 {
let task_client = Arc::clone(client);
upload_tasks.spawn(async move {
let prefix = format!("{base_prefix_str}/sub_prefix_{i}/");
let blob_prefix = RemotePath::new(Utf8Path::new(&prefix))
.with_context(|| format!("{prefix:?} to RemotePath conversion"))?;
let blob_path = blob_prefix.join(Utf8Path::new(&format!("blob_{i}")));
debug!("Creating remote item {i} at path {blob_path:?}");
let (data, len) = upload_stream(format!("remote blob data {i}").into_bytes().into());
task_client.upload(data, len, &blob_path, None).await?;
Ok::<_, anyhow::Error>((blob_prefix, blob_path))
});
}
let mut upload_tasks_failed = false;
let mut uploaded_prefixes = HashSet::with_capacity(upload_tasks_count);
let mut uploaded_blobs = HashSet::with_capacity(upload_tasks_count);
while let Some(task_run_result) = upload_tasks.join_next().await {
match task_run_result
.context("task join failed")
.and_then(|task_result| task_result.context("upload task failed"))
{
Ok((upload_prefix, upload_path)) => {
uploaded_prefixes.insert(upload_prefix);
uploaded_blobs.insert(upload_path);
}
Err(e) => {
error!("Upload task failed: {e:?}");
upload_tasks_failed = true;
}
}
}
let uploads = Uploads {
prefixes: uploaded_prefixes,
blobs: uploaded_blobs,
};
if upload_tasks_failed {
ControlFlow::Break(uploads)
} else {
ControlFlow::Continue(uploads)
}
}
async fn cleanup(client: &Arc<GenericRemoteStorage>, objects_to_delete: HashSet<RemotePath>) {
info!(
"Removing {} objects from the remote storage during cleanup",
objects_to_delete.len()
);
let mut delete_tasks = JoinSet::new();
for object_to_delete in objects_to_delete {
let task_client = Arc::clone(client);
delete_tasks.spawn(async move {
debug!("Deleting remote item at path {object_to_delete:?}");
task_client
.delete(&object_to_delete)
.await
.with_context(|| format!("{object_to_delete:?} removal"))
});
}
while let Some(task_run_result) = delete_tasks.join_next().await {
match task_run_result {
Ok(task_result) => match task_result {
Ok(()) => {}
Err(e) => error!("Delete task failed: {e:?}"),
},
Err(join_err) => error!("Delete task did not finish correctly: {join_err}"),
}
}
}
// Uploads files `folder{j}/blob{i}.txt`. See test description for more details.
async fn upload_simple_azure_data(
client: &Arc<GenericRemoteStorage>,
upload_tasks_count: usize,
) -> ControlFlow<HashSet<RemotePath>, HashSet<RemotePath>> {
info!("Creating {upload_tasks_count} Azure files");
let mut upload_tasks = JoinSet::new();
for i in 1..upload_tasks_count + 1 {
let task_client = Arc::clone(client);
upload_tasks.spawn(async move {
let blob_path = PathBuf::from(format!("folder{}/blob_{}.txt", i / 7, i));
let blob_path = RemotePath::new(
Utf8Path::from_path(blob_path.as_path()).expect("must be valid blob path"),
)
.with_context(|| format!("{blob_path:?} to RemotePath conversion"))?;
debug!("Creating remote item {i} at path {blob_path:?}");
let (data, len) = upload_stream(format!("remote blob data {i}").into_bytes().into());
task_client.upload(data, len, &blob_path, None).await?;
Ok::<_, anyhow::Error>(blob_path)
});
}
let mut upload_tasks_failed = false;
let mut uploaded_blobs = HashSet::with_capacity(upload_tasks_count);
while let Some(task_run_result) = upload_tasks.join_next().await {
match task_run_result
.context("task join failed")
.and_then(|task_result| task_result.context("upload task failed"))
{
Ok(upload_path) => {
uploaded_blobs.insert(upload_path);
}
Err(e) => {
error!("Upload task failed: {e:?}");
upload_tasks_failed = true;
}
}
}
if upload_tasks_failed {
ControlFlow::Break(uploaded_blobs)
} else {
ControlFlow::Continue(uploaded_blobs)
}
}
// FIXME: copypasted from test_real_s3, can't remember how to share a module which is not compiled
// to binary
fn upload_stream(
content: std::borrow::Cow<'static, [u8]>,
) -> (
impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
usize,
) {
use std::borrow::Cow;
let content = match content {
Cow::Borrowed(x) => Bytes::from_static(x),
Cow::Owned(vec) => Bytes::from(vec),
};
wrap_stream(content)
}
fn wrap_stream(
content: bytes::Bytes,
) -> (
impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
usize,
) {
let len = content.len();
let content = futures::future::ready(Ok(content));
(futures::stream::once(content), len)
}

View File

@@ -2,27 +2,213 @@ use std::collections::HashSet;
use std::env;
use std::num::NonZeroUsize;
use std::ops::ControlFlow;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::UNIX_EPOCH;
use anyhow::Context;
use bytes::Bytes;
use camino::Utf8Path;
use futures::stream::Stream;
use once_cell::sync::OnceCell;
use remote_storage::{
GenericRemoteStorage, RemotePath, RemoteStorageConfig, RemoteStorageKind, S3Config,
};
use test_context::AsyncTestContext;
use tracing::info;
use test_context::{test_context, AsyncTestContext};
use tokio::task::JoinSet;
use tracing::{debug, error, info};
mod common;
#[path = "common/tests.rs"]
mod tests_s3;
use common::{cleanup, ensure_logging_ready, upload_remote_data, upload_simple_remote_data};
static LOGGING_DONE: OnceCell<()> = OnceCell::new();
const ENABLE_REAL_S3_REMOTE_STORAGE_ENV_VAR_NAME: &str = "ENABLE_REAL_S3_REMOTE_STORAGE";
const BASE_PREFIX: &str = "test";
/// Tests that S3 client can list all prefixes, even if the response come paginated and requires multiple S3 queries.
/// Uses real S3 and requires [`ENABLE_REAL_S3_REMOTE_STORAGE_ENV_VAR_NAME`] and related S3 cred env vars specified.
/// See the client creation in [`create_s3_client`] for details on the required env vars.
/// If real S3 tests are disabled, the test passes, skipping any real test run: currently, there's no way to mark the test ignored in runtime with the
/// deafult test framework, see https://github.com/rust-lang/rust/issues/68007 for details.
///
/// First, the test creates a set of S3 objects with keys `/${random_prefix_part}/${base_prefix_str}/sub_prefix_${i}/blob_${i}` in [`upload_s3_data`]
/// where
/// * `random_prefix_part` is set for the entire S3 client during the S3 client creation in [`create_s3_client`], to avoid multiple test runs interference
/// * `base_prefix_str` is a common prefix to use in the client requests: we would want to ensure that the client is able to list nested prefixes inside the bucket
///
/// Then, verifies that the client does return correct prefixes when queried:
/// * with no prefix, it lists everything after its `${random_prefix_part}/` — that should be `${base_prefix_str}` value only
/// * with `${base_prefix_str}/` prefix, it lists every `sub_prefix_${i}`
///
/// With the real S3 enabled and `#[cfg(test)]` Rust configuration used, the S3 client test adds a `max-keys` param to limit the response keys.
/// This way, we are able to test the pagination implicitly, by ensuring all results are returned from the remote storage and avoid uploading too many blobs to S3,
/// since current default AWS S3 pagination limit is 1000.
/// (see https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectsV2.html#API_ListObjectsV2_RequestSyntax)
///
/// Lastly, the test attempts to clean up and remove all uploaded S3 files.
/// If any errors appear during the clean up, they get logged, but the test is not failed or stopped until clean up is finished.
#[test_context(MaybeEnabledS3WithTestBlobs)]
#[tokio::test]
async fn s3_pagination_should_work(ctx: &mut MaybeEnabledS3WithTestBlobs) -> anyhow::Result<()> {
let ctx = match ctx {
MaybeEnabledS3WithTestBlobs::Enabled(ctx) => ctx,
MaybeEnabledS3WithTestBlobs::Disabled => return Ok(()),
MaybeEnabledS3WithTestBlobs::UploadsFailed(e, _) => anyhow::bail!("S3 init failed: {e:?}"),
};
let test_client = Arc::clone(&ctx.enabled.client);
let expected_remote_prefixes = ctx.remote_prefixes.clone();
let base_prefix = RemotePath::new(Utf8Path::new(ctx.enabled.base_prefix))
.context("common_prefix construction")?;
let root_remote_prefixes = test_client
.list_prefixes(None)
.await
.context("client list root prefixes failure")?
.into_iter()
.collect::<HashSet<_>>();
assert_eq!(
root_remote_prefixes, HashSet::from([base_prefix.clone()]),
"remote storage root prefixes list mismatches with the uploads. Returned prefixes: {root_remote_prefixes:?}"
);
let nested_remote_prefixes = test_client
.list_prefixes(Some(&base_prefix))
.await
.context("client list nested prefixes failure")?
.into_iter()
.collect::<HashSet<_>>();
let remote_only_prefixes = nested_remote_prefixes
.difference(&expected_remote_prefixes)
.collect::<HashSet<_>>();
let missing_uploaded_prefixes = expected_remote_prefixes
.difference(&nested_remote_prefixes)
.collect::<HashSet<_>>();
assert_eq!(
remote_only_prefixes.len() + missing_uploaded_prefixes.len(), 0,
"remote storage nested prefixes list mismatches with the uploads. Remote only prefixes: {remote_only_prefixes:?}, missing uploaded prefixes: {missing_uploaded_prefixes:?}",
);
Ok(())
}
/// Tests that S3 client can list all files in a folder, even if the response comes paginated and requirees multiple S3 queries.
/// Uses real S3 and requires [`ENABLE_REAL_S3_REMOTE_STORAGE_ENV_VAR_NAME`] and related S3 cred env vars specified. Test will skip real code and pass if env vars not set.
/// See `s3_pagination_should_work` for more information.
///
/// First, create a set of S3 objects with keys `random_prefix/folder{j}/blob_{i}.txt` in [`upload_s3_data`]
/// Then performs the following queries:
/// 1. `list_files(None)`. This should return all files `random_prefix/folder{j}/blob_{i}.txt`
/// 2. `list_files("folder1")`. This should return all files `random_prefix/folder1/blob_{i}.txt`
#[test_context(MaybeEnabledS3WithSimpleTestBlobs)]
#[tokio::test]
async fn s3_list_files_works(ctx: &mut MaybeEnabledS3WithSimpleTestBlobs) -> anyhow::Result<()> {
let ctx = match ctx {
MaybeEnabledS3WithSimpleTestBlobs::Enabled(ctx) => ctx,
MaybeEnabledS3WithSimpleTestBlobs::Disabled => return Ok(()),
MaybeEnabledS3WithSimpleTestBlobs::UploadsFailed(e, _) => {
anyhow::bail!("S3 init failed: {e:?}")
}
};
let test_client = Arc::clone(&ctx.enabled.client);
let base_prefix =
RemotePath::new(Utf8Path::new("folder1")).context("common_prefix construction")?;
let root_files = test_client
.list_files(None)
.await
.context("client list root files failure")?
.into_iter()
.collect::<HashSet<_>>();
assert_eq!(
root_files,
ctx.remote_blobs.clone(),
"remote storage list_files on root mismatches with the uploads."
);
let nested_remote_files = test_client
.list_files(Some(&base_prefix))
.await
.context("client list nested files failure")?
.into_iter()
.collect::<HashSet<_>>();
let trim_remote_blobs: HashSet<_> = ctx
.remote_blobs
.iter()
.map(|x| x.get_path())
.filter(|x| x.starts_with("folder1"))
.map(|x| RemotePath::new(x).expect("must be valid path"))
.collect();
assert_eq!(
nested_remote_files, trim_remote_blobs,
"remote storage list_files on subdirrectory mismatches with the uploads."
);
Ok(())
}
#[test_context(MaybeEnabledS3)]
#[tokio::test]
async fn s3_delete_non_exising_works(ctx: &mut MaybeEnabledS3) -> anyhow::Result<()> {
let ctx = match ctx {
MaybeEnabledS3::Enabled(ctx) => ctx,
MaybeEnabledS3::Disabled => return Ok(()),
};
let path = RemotePath::new(Utf8Path::new(
format!("{}/for_sure_there_is_nothing_there_really", ctx.base_prefix).as_str(),
))
.with_context(|| "RemotePath conversion")?;
ctx.client.delete(&path).await.expect("should succeed");
Ok(())
}
#[test_context(MaybeEnabledS3)]
#[tokio::test]
async fn s3_delete_objects_works(ctx: &mut MaybeEnabledS3) -> anyhow::Result<()> {
let ctx = match ctx {
MaybeEnabledS3::Enabled(ctx) => ctx,
MaybeEnabledS3::Disabled => return Ok(()),
};
let path1 = RemotePath::new(Utf8Path::new(format!("{}/path1", ctx.base_prefix).as_str()))
.with_context(|| "RemotePath conversion")?;
let path2 = RemotePath::new(Utf8Path::new(format!("{}/path2", ctx.base_prefix).as_str()))
.with_context(|| "RemotePath conversion")?;
let path3 = RemotePath::new(Utf8Path::new(format!("{}/path3", ctx.base_prefix).as_str()))
.with_context(|| "RemotePath conversion")?;
let (data, len) = upload_stream("remote blob data1".as_bytes().into());
ctx.client.upload(data, len, &path1, None).await?;
let (data, len) = upload_stream("remote blob data2".as_bytes().into());
ctx.client.upload(data, len, &path2, None).await?;
let (data, len) = upload_stream("remote blob data3".as_bytes().into());
ctx.client.upload(data, len, &path3, None).await?;
ctx.client.delete_objects(&[path1, path2]).await?;
let prefixes = ctx.client.list_prefixes(None).await?;
assert_eq!(prefixes.len(), 1);
ctx.client.delete_objects(&[path3]).await?;
Ok(())
}
fn ensure_logging_ready() {
LOGGING_DONE.get_or_init(|| {
utils::logging::init(
utils::logging::LogFormat::Test,
utils::logging::TracingErrorLayerEnablement::Disabled,
utils::logging::Output::Stdout,
)
.expect("logging init failed");
});
}
struct EnabledS3 {
client: Arc<GenericRemoteStorage>,
base_prefix: &'static str,
@@ -41,13 +227,13 @@ impl EnabledS3 {
}
}
enum MaybeEnabledStorage {
enum MaybeEnabledS3 {
Enabled(EnabledS3),
Disabled,
}
#[async_trait::async_trait]
impl AsyncTestContext for MaybeEnabledStorage {
impl AsyncTestContext for MaybeEnabledS3 {
async fn setup() -> Self {
ensure_logging_ready();
@@ -63,7 +249,7 @@ impl AsyncTestContext for MaybeEnabledStorage {
}
}
enum MaybeEnabledStorageWithTestBlobs {
enum MaybeEnabledS3WithTestBlobs {
Enabled(S3WithTestBlobs),
Disabled,
UploadsFailed(anyhow::Error, S3WithTestBlobs),
@@ -76,7 +262,7 @@ struct S3WithTestBlobs {
}
#[async_trait::async_trait]
impl AsyncTestContext for MaybeEnabledStorageWithTestBlobs {
impl AsyncTestContext for MaybeEnabledS3WithTestBlobs {
async fn setup() -> Self {
ensure_logging_ready();
if env::var(ENABLE_REAL_S3_REMOTE_STORAGE_ENV_VAR_NAME).is_err() {
@@ -92,7 +278,7 @@ impl AsyncTestContext for MaybeEnabledStorageWithTestBlobs {
let enabled = EnabledS3::setup(Some(max_keys_in_list_response)).await;
match upload_remote_data(&enabled.client, enabled.base_prefix, upload_tasks_count).await {
match upload_s3_data(&enabled.client, enabled.base_prefix, upload_tasks_count).await {
ControlFlow::Continue(uploads) => {
info!("Remote objects created successfully");
@@ -127,7 +313,7 @@ impl AsyncTestContext for MaybeEnabledStorageWithTestBlobs {
// However, they are not idential. The list_prefixes function is concerned with listing prefixes,
// whereas the list_files function is concerned with listing files.
// See `RemoteStorage::list_files` documentation for more details
enum MaybeEnabledStorageWithSimpleTestBlobs {
enum MaybeEnabledS3WithSimpleTestBlobs {
Enabled(S3WithSimpleTestBlobs),
Disabled,
UploadsFailed(anyhow::Error, S3WithSimpleTestBlobs),
@@ -138,7 +324,7 @@ struct S3WithSimpleTestBlobs {
}
#[async_trait::async_trait]
impl AsyncTestContext for MaybeEnabledStorageWithSimpleTestBlobs {
impl AsyncTestContext for MaybeEnabledS3WithSimpleTestBlobs {
async fn setup() -> Self {
ensure_logging_ready();
if env::var(ENABLE_REAL_S3_REMOTE_STORAGE_ENV_VAR_NAME).is_err() {
@@ -154,7 +340,7 @@ impl AsyncTestContext for MaybeEnabledStorageWithSimpleTestBlobs {
let enabled = EnabledS3::setup(Some(max_keys_in_list_response)).await;
match upload_simple_remote_data(&enabled.client, upload_tasks_count).await {
match upload_simple_s3_data(&enabled.client, upload_tasks_count).await {
ControlFlow::Continue(uploads) => {
info!("Remote objects created successfully");
@@ -217,3 +403,166 @@ fn create_s3_client(
GenericRemoteStorage::from_config(&remote_storage_config).context("remote storage init")?,
))
}
struct Uploads {
prefixes: HashSet<RemotePath>,
blobs: HashSet<RemotePath>,
}
async fn upload_s3_data(
client: &Arc<GenericRemoteStorage>,
base_prefix_str: &'static str,
upload_tasks_count: usize,
) -> ControlFlow<Uploads, Uploads> {
info!("Creating {upload_tasks_count} S3 files");
let mut upload_tasks = JoinSet::new();
for i in 1..upload_tasks_count + 1 {
let task_client = Arc::clone(client);
upload_tasks.spawn(async move {
let prefix = format!("{base_prefix_str}/sub_prefix_{i}/");
let blob_prefix = RemotePath::new(Utf8Path::new(&prefix))
.with_context(|| format!("{prefix:?} to RemotePath conversion"))?;
let blob_path = blob_prefix.join(Utf8Path::new(&format!("blob_{i}")));
debug!("Creating remote item {i} at path {blob_path:?}");
let (data, data_len) =
upload_stream(format!("remote blob data {i}").into_bytes().into());
task_client.upload(data, data_len, &blob_path, None).await?;
Ok::<_, anyhow::Error>((blob_prefix, blob_path))
});
}
let mut upload_tasks_failed = false;
let mut uploaded_prefixes = HashSet::with_capacity(upload_tasks_count);
let mut uploaded_blobs = HashSet::with_capacity(upload_tasks_count);
while let Some(task_run_result) = upload_tasks.join_next().await {
match task_run_result
.context("task join failed")
.and_then(|task_result| task_result.context("upload task failed"))
{
Ok((upload_prefix, upload_path)) => {
uploaded_prefixes.insert(upload_prefix);
uploaded_blobs.insert(upload_path);
}
Err(e) => {
error!("Upload task failed: {e:?}");
upload_tasks_failed = true;
}
}
}
let uploads = Uploads {
prefixes: uploaded_prefixes,
blobs: uploaded_blobs,
};
if upload_tasks_failed {
ControlFlow::Break(uploads)
} else {
ControlFlow::Continue(uploads)
}
}
async fn cleanup(client: &Arc<GenericRemoteStorage>, objects_to_delete: HashSet<RemotePath>) {
info!(
"Removing {} objects from the remote storage during cleanup",
objects_to_delete.len()
);
let mut delete_tasks = JoinSet::new();
for object_to_delete in objects_to_delete {
let task_client = Arc::clone(client);
delete_tasks.spawn(async move {
debug!("Deleting remote item at path {object_to_delete:?}");
task_client
.delete(&object_to_delete)
.await
.with_context(|| format!("{object_to_delete:?} removal"))
});
}
while let Some(task_run_result) = delete_tasks.join_next().await {
match task_run_result {
Ok(task_result) => match task_result {
Ok(()) => {}
Err(e) => error!("Delete task failed: {e:?}"),
},
Err(join_err) => error!("Delete task did not finish correctly: {join_err}"),
}
}
}
// Uploads files `folder{j}/blob{i}.txt`. See test description for more details.
async fn upload_simple_s3_data(
client: &Arc<GenericRemoteStorage>,
upload_tasks_count: usize,
) -> ControlFlow<HashSet<RemotePath>, HashSet<RemotePath>> {
info!("Creating {upload_tasks_count} S3 files");
let mut upload_tasks = JoinSet::new();
for i in 1..upload_tasks_count + 1 {
let task_client = Arc::clone(client);
upload_tasks.spawn(async move {
let blob_path = PathBuf::from(format!("folder{}/blob_{}.txt", i / 7, i));
let blob_path = RemotePath::new(
Utf8Path::from_path(blob_path.as_path()).expect("must be valid blob path"),
)
.with_context(|| format!("{blob_path:?} to RemotePath conversion"))?;
debug!("Creating remote item {i} at path {blob_path:?}");
let (data, data_len) =
upload_stream(format!("remote blob data {i}").into_bytes().into());
task_client.upload(data, data_len, &blob_path, None).await?;
Ok::<_, anyhow::Error>(blob_path)
});
}
let mut upload_tasks_failed = false;
let mut uploaded_blobs = HashSet::with_capacity(upload_tasks_count);
while let Some(task_run_result) = upload_tasks.join_next().await {
match task_run_result
.context("task join failed")
.and_then(|task_result| task_result.context("upload task failed"))
{
Ok(upload_path) => {
uploaded_blobs.insert(upload_path);
}
Err(e) => {
error!("Upload task failed: {e:?}");
upload_tasks_failed = true;
}
}
}
if upload_tasks_failed {
ControlFlow::Break(uploaded_blobs)
} else {
ControlFlow::Continue(uploaded_blobs)
}
}
fn upload_stream(
content: std::borrow::Cow<'static, [u8]>,
) -> (
impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
usize,
) {
use std::borrow::Cow;
let content = match content {
Cow::Borrowed(x) => Bytes::from_static(x),
Cow::Owned(vec) => Bytes::from(vec),
};
wrap_stream(content)
}
fn wrap_stream(
content: bytes::Bytes,
) -> (
impl Stream<Item = std::io::Result<Bytes>> + Send + Sync + 'static,
usize,
) {
let len = content.len();
let content = futures::future::ready(Ok(content));
(futures::stream::once(content), len)
}

View File

@@ -51,9 +51,3 @@ pub struct SkTimelineInfo {
#[serde(default)]
pub http_connstr: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TimelineCopyRequest {
pub target_timeline_id: TimelineId,
pub until_lsn: Lsn,
}

View File

@@ -4,12 +4,6 @@ version = "0.1.0"
edition.workspace = true
license.workspace = true
[features]
default = []
# Enables test-only APIs, incuding failpoints. In particular, enables the `fail_point!` macro,
# which adds some runtime cost to run tests on outage conditions
testing = ["fail/failpoints"]
[dependencies]
arc-swap.workspace = true
sentry.workspace = true
@@ -22,7 +16,6 @@ chrono.workspace = true
heapless.workspace = true
hex = { workspace = true, features = ["serde"] }
hyper = { workspace = true, features = ["full"] }
fail.workspace = true
futures = { workspace = true}
jsonwebtoken.workspace = true
nix.workspace = true

View File

@@ -1,177 +0,0 @@
//! Failpoint support code shared between pageserver and safekeepers.
use crate::http::{
error::ApiError,
json::{json_request, json_response},
};
use hyper::{Body, Request, Response, StatusCode};
use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use tracing::*;
/// use with fail::cfg("$name", "return(2000)")
///
/// The effect is similar to a "sleep(2000)" action, i.e. we sleep for the
/// specified time (in milliseconds). The main difference is that we use async
/// tokio sleep function. Another difference is that we print lines to the log,
/// which can be useful in tests to check that the failpoint was hit.
///
/// Optionally pass a cancellation token, and this failpoint will drop out of
/// its sleep when the cancellation token fires. This is useful for testing
/// cases where we would like to block something, but test its clean shutdown behavior.
#[macro_export]
macro_rules! __failpoint_sleep_millis_async {
($name:literal) => {{
// If the failpoint is used with a "return" action, set should_sleep to the
// returned value (as string). Otherwise it's set to None.
let should_sleep = (|| {
::fail::fail_point!($name, |x| x);
::std::option::Option::None
})();
// Sleep if the action was a returned value
if let ::std::option::Option::Some(duration_str) = should_sleep {
$crate::failpoint_support::failpoint_sleep_helper($name, duration_str).await
}
}};
($name:literal, $cancel:expr) => {{
// If the failpoint is used with a "return" action, set should_sleep to the
// returned value (as string). Otherwise it's set to None.
let should_sleep = (|| {
::fail::fail_point!($name, |x| x);
::std::option::Option::None
})();
// Sleep if the action was a returned value
if let ::std::option::Option::Some(duration_str) = should_sleep {
$crate::failpoint_support::failpoint_sleep_cancellable_helper(
$name,
duration_str,
$cancel,
)
.await
}
}};
}
pub use __failpoint_sleep_millis_async as sleep_millis_async;
// Helper function used by the macro. (A function has nicer scoping so we
// don't need to decorate everything with "::")
#[doc(hidden)]
pub async fn failpoint_sleep_helper(name: &'static str, duration_str: String) {
let millis = duration_str.parse::<u64>().unwrap();
let d = std::time::Duration::from_millis(millis);
tracing::info!("failpoint {:?}: sleeping for {:?}", name, d);
tokio::time::sleep(d).await;
tracing::info!("failpoint {:?}: sleep done", name);
}
// Helper function used by the macro. (A function has nicer scoping so we
// don't need to decorate everything with "::")
#[doc(hidden)]
pub async fn failpoint_sleep_cancellable_helper(
name: &'static str,
duration_str: String,
cancel: &CancellationToken,
) {
let millis = duration_str.parse::<u64>().unwrap();
let d = std::time::Duration::from_millis(millis);
tracing::info!("failpoint {:?}: sleeping for {:?}", name, d);
tokio::time::timeout(d, cancel.cancelled()).await.ok();
tracing::info!("failpoint {:?}: sleep done", name);
}
pub fn init() -> fail::FailScenario<'static> {
// The failpoints lib provides support for parsing the `FAILPOINTS` env var.
// We want non-default behavior for `exit`, though, so, we handle it separately.
//
// Format for FAILPOINTS is "name=actions" separated by ";".
let actions = std::env::var("FAILPOINTS");
if actions.is_ok() {
std::env::remove_var("FAILPOINTS");
} else {
// let the library handle non-utf8, or nothing for not present
}
let scenario = fail::FailScenario::setup();
if let Ok(val) = actions {
val.split(';')
.enumerate()
.map(|(i, s)| s.split_once('=').ok_or((i, s)))
.for_each(|res| {
let (name, actions) = match res {
Ok(t) => t,
Err((i, s)) => {
panic!(
"startup failpoints: missing action on the {}th failpoint; try `{s}=return`",
i + 1,
);
}
};
if let Err(e) = apply_failpoint(name, actions) {
panic!("startup failpoints: failed to apply failpoint {name}={actions}: {e}");
}
});
}
scenario
}
pub fn apply_failpoint(name: &str, actions: &str) -> Result<(), String> {
if actions == "exit" {
fail::cfg_callback(name, exit_failpoint)
} else {
fail::cfg(name, actions)
}
}
#[inline(never)]
fn exit_failpoint() {
tracing::info!("Exit requested by failpoint");
std::process::exit(1);
}
pub type ConfigureFailpointsRequest = Vec<FailpointConfig>;
/// Information for configuring a single fail point
#[derive(Debug, Serialize, Deserialize)]
pub struct FailpointConfig {
/// Name of the fail point
pub name: String,
/// List of actions to take, using the format described in `fail::cfg`
///
/// We also support `actions = "exit"` to cause the fail point to immediately exit.
pub actions: String,
}
/// Configure failpoints through http.
pub async fn failpoints_handler(
mut request: Request<Body>,
_cancel: CancellationToken,
) -> Result<Response<Body>, ApiError> {
if !fail::has_failpoints() {
return Err(ApiError::BadRequest(anyhow::anyhow!(
"Cannot manage failpoints because storage was compiled without failpoints support"
)));
}
let failpoints: ConfigureFailpointsRequest = json_request(&mut request).await?;
for fp in failpoints {
info!("cfg failpoint: {} {}", fp.name, fp.actions);
// We recognize one extra "action" that's not natively recognized
// by the failpoints crate: exit, to immediately kill the process
let cfg_result = apply_failpoint(&fp.name, &fp.actions);
if let Err(err_msg) = cfg_result {
return Err(ApiError::BadRequest(anyhow::anyhow!(
"Failed to configure failpoints: {err_msg}"
)));
}
}
json_response(StatusCode::OK, ())
}

View File

@@ -31,9 +31,6 @@ pub enum ApiError {
#[error("Shutting down")]
ShuttingDown,
#[error("Timeout")]
Timeout(Cow<'static, str>),
#[error(transparent)]
InternalServerError(anyhow::Error),
}
@@ -70,10 +67,6 @@ impl ApiError {
err.to_string(),
StatusCode::SERVICE_UNAVAILABLE,
),
ApiError::Timeout(err) => HttpErrorBody::response_from_msg_and_status(
err.to_string(),
StatusCode::REQUEST_TIMEOUT,
),
ApiError::InternalServerError(err) => HttpErrorBody::response_from_msg_and_status(
err.to_string(),
StatusCode::INTERNAL_SERVER_ERROR,

View File

@@ -1,4 +1,3 @@
use std::num::ParseIntError;
use std::{fmt, str::FromStr};
use anyhow::Context;
@@ -375,13 +374,6 @@ impl fmt::Display for NodeId {
}
}
impl FromStr for NodeId {
type Err = ParseIntError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(NodeId(u64::from_str(s)?))
}
}
#[cfg(test)]
mod tests {
use serde_assert::{Deserializer, Serializer, Token, Tokens};

View File

@@ -83,10 +83,6 @@ pub mod timeout;
pub mod sync;
pub mod failpoint_support;
pub mod yielding_loop;
/// This is a shortcut to embed git sha into binaries and avoid copying the same build script to all packages
///
/// we have several cases:

View File

@@ -366,49 +366,6 @@ impl MonotonicCounter<Lsn> for RecordLsn {
}
}
/// Implements [`rand::distributions::uniform::UniformSampler`] so we can sample [`Lsn`]s.
///
/// This is used by the `pagebench` pageserver benchmarking tool.
pub struct LsnSampler(<u64 as rand::distributions::uniform::SampleUniform>::Sampler);
impl rand::distributions::uniform::SampleUniform for Lsn {
type Sampler = LsnSampler;
}
impl rand::distributions::uniform::UniformSampler for LsnSampler {
type X = Lsn;
fn new<B1, B2>(low: B1, high: B2) -> Self
where
B1: rand::distributions::uniform::SampleBorrow<Self::X> + Sized,
B2: rand::distributions::uniform::SampleBorrow<Self::X> + Sized,
{
Self(
<u64 as rand::distributions::uniform::SampleUniform>::Sampler::new(
low.borrow().0,
high.borrow().0,
),
)
}
fn new_inclusive<B1, B2>(low: B1, high: B2) -> Self
where
B1: rand::distributions::uniform::SampleBorrow<Self::X> + Sized,
B2: rand::distributions::uniform::SampleBorrow<Self::X> + Sized,
{
Self(
<u64 as rand::distributions::uniform::SampleUniform>::Sampler::new_inclusive(
low.borrow().0,
high.borrow().0,
),
)
}
fn sample<R: rand::prelude::Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
Lsn(self.0.sample(rng))
}
}
#[cfg(test)]
mod tests {
use crate::bin_ser::BeSer;

View File

@@ -15,12 +15,6 @@ pub struct Gate {
name: String,
}
impl std::fmt::Debug for Gate {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Gate<{}>", self.name)
}
}
/// RAII guard for a [`Gate`]: as long as this exists, calls to [`Gate::close`] will
/// not complete.
#[derive(Debug)]

View File

@@ -2,11 +2,8 @@ use std::time::Duration;
use tokio_util::sync::CancellationToken;
#[derive(thiserror::Error, Debug)]
pub enum TimeoutCancellableError {
#[error("Timed out")]
Timeout,
#[error("Cancelled")]
Cancelled,
}

View File

@@ -1,35 +0,0 @@
use tokio_util::sync::CancellationToken;
#[derive(thiserror::Error, Debug)]
pub enum YieldingLoopError {
#[error("Cancelled")]
Cancelled,
}
/// Helper for long synchronous loops, e.g. over all tenants in the system. Periodically
/// yields to avoid blocking the executor, and after resuming checks the provided
/// cancellation token to drop out promptly on shutdown.
#[inline(always)]
pub async fn yielding_loop<I, T, F>(
interval: usize,
cancel: &CancellationToken,
iter: I,
mut visitor: F,
) -> Result<(), YieldingLoopError>
where
I: Iterator<Item = T>,
F: FnMut(T),
{
for (i, item) in iter.enumerate() {
visitor(item);
if i + 1 % interval == 0 {
tokio::task::yield_now().await;
if cancel.is_cancelled() {
return Err(YieldingLoopError::Cancelled);
}
}
}
Ok(())
}

View File

@@ -446,11 +446,12 @@ impl Runner {
if let Some(t) = self.last_upscale_request_at {
let elapsed = t.elapsed();
if elapsed < Duration::from_secs(1) {
// *Ideally* we'd like to log here that we're ignoring the fact the
// memory stats are too high, but in practice this can result in
// spamming the logs with repetitive messages about ignoring the signal
//
// See https://github.com/neondatabase/neon/issues/5865 for more.
info!(
elapsed_millis = elapsed.as_millis(),
avg_non_reclaimable = bytes_to_mebibytes(cgroup_mem_stat.avg_non_reclaimable),
threshold = bytes_to_mebibytes(cgroup.threshold),
"cgroup memory stats are high enough to upscale but too soon to forward the request, ignoring",
);
continue;
}
}

View File

@@ -1,2 +1 @@
#include "postgres.h"
#include "walproposer.h"

View File

@@ -1,6 +1,3 @@
//! Links with walproposer, pgcommon, pgport and runs bindgen on walproposer.h
//! to generate Rust bindings for it.
use std::{env, path::PathBuf, process::Command};
use anyhow::{anyhow, Context};

View File

@@ -1,6 +1,3 @@
//! A C-Rust shim: defines implementation of C walproposer API, assuming wp
//! callback_data stores Box to some Rust implementation.
#![allow(dead_code)]
use std::ffi::CStr;
@@ -8,12 +5,12 @@ use std::ffi::CString;
use crate::bindings::uint32;
use crate::bindings::walproposer_api;
use crate::bindings::NeonWALReadResult;
use crate::bindings::PGAsyncReadResult;
use crate::bindings::PGAsyncWriteResult;
use crate::bindings::Safekeeper;
use crate::bindings::Size;
use crate::bindings::StringInfoData;
use crate::bindings::TimeLineID;
use crate::bindings::TimestampTz;
use crate::bindings::WalProposer;
use crate::bindings::WalProposerConnStatusType;
@@ -178,11 +175,31 @@ extern "C" fn conn_blocking_write(
}
}
extern "C" fn recovery_download(wp: *mut WalProposer, sk: *mut Safekeeper) -> bool {
extern "C" fn recovery_download(
sk: *mut Safekeeper,
_timeline: TimeLineID,
startpos: XLogRecPtr,
endpos: XLogRecPtr,
) -> bool {
unsafe {
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).recovery_download(&mut (*wp), &mut (*sk))
(*api).recovery_download(&mut (*sk), startpos, endpos)
}
}
#[allow(clippy::unnecessary_cast)]
extern "C" fn wal_read(
sk: *mut Safekeeper,
buf: *mut ::std::os::raw::c_char,
startptr: XLogRecPtr,
count: Size,
) {
unsafe {
let buf = std::slice::from_raw_parts_mut(buf as *mut u8, count);
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).wal_read(&mut (*sk), buf, startptr)
}
}
@@ -194,28 +211,11 @@ extern "C" fn wal_reader_allocate(sk: *mut Safekeeper) {
}
}
#[allow(clippy::unnecessary_cast)]
extern "C" fn wal_read(
sk: *mut Safekeeper,
buf: *mut ::std::os::raw::c_char,
startptr: XLogRecPtr,
count: Size,
_errmsg: *mut *mut ::std::os::raw::c_char,
) -> NeonWALReadResult {
extern "C" fn free_event_set(wp: *mut WalProposer) {
unsafe {
let buf = std::slice::from_raw_parts_mut(buf as *mut u8, count);
let callback_data = (*(*(*sk).wp).config).callback_data;
let callback_data = (*(*wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
// TODO: errmsg is not forwarded
(*api).wal_read(&mut (*sk), buf, startptr)
}
}
extern "C" fn wal_reader_events(sk: *mut Safekeeper) -> uint32 {
unsafe {
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).wal_reader_events(&mut (*sk))
(*api).free_event_set(&mut (*wp));
}
}
@@ -235,14 +235,6 @@ extern "C" fn update_event_set(sk: *mut Safekeeper, events: uint32) {
}
}
extern "C" fn active_state_update_event_set(sk: *mut Safekeeper) {
unsafe {
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).active_state_update_event_set(&mut (*sk));
}
}
extern "C" fn add_safekeeper_event_set(sk: *mut Safekeeper, events: uint32) {
unsafe {
let callback_data = (*(*(*sk).wp).config).callback_data;
@@ -251,14 +243,6 @@ extern "C" fn add_safekeeper_event_set(sk: *mut Safekeeper, events: uint32) {
}
}
extern "C" fn rm_safekeeper_event_set(sk: *mut Safekeeper) {
unsafe {
let callback_data = (*(*(*sk).wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).rm_safekeeper_event_set(&mut (*sk));
}
}
extern "C" fn wait_event_set(
wp: *mut WalProposer,
timeout: ::std::os::raw::c_long,
@@ -326,6 +310,14 @@ extern "C" fn process_safekeeper_feedback(wp: *mut WalProposer, commit_lsn: XLog
}
}
extern "C" fn confirm_wal_streamed(wp: *mut WalProposer, lsn: XLogRecPtr) {
unsafe {
let callback_data = (*(*wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).confirm_wal_streamed(&mut (*wp), lsn)
}
}
extern "C" fn log_internal(
wp: *mut WalProposer,
level: ::std::os::raw::c_int,
@@ -340,6 +332,14 @@ extern "C" fn log_internal(
}
}
extern "C" fn after_election(wp: *mut WalProposer) {
unsafe {
let callback_data = (*(*wp).config).callback_data;
let api = callback_data as *mut Box<dyn ApiImpl>;
(*api).after_election(&mut (*wp))
}
}
#[derive(Debug)]
pub enum Level {
Debug5,
@@ -398,20 +398,20 @@ pub(crate) fn create_api() -> walproposer_api {
conn_async_write: Some(conn_async_write),
conn_blocking_write: Some(conn_blocking_write),
recovery_download: Some(recovery_download),
wal_reader_allocate: Some(wal_reader_allocate),
wal_read: Some(wal_read),
wal_reader_events: Some(wal_reader_events),
wal_reader_allocate: Some(wal_reader_allocate),
free_event_set: Some(free_event_set),
init_event_set: Some(init_event_set),
update_event_set: Some(update_event_set),
active_state_update_event_set: Some(active_state_update_event_set),
add_safekeeper_event_set: Some(add_safekeeper_event_set),
rm_safekeeper_event_set: Some(rm_safekeeper_event_set),
wait_event_set: Some(wait_event_set),
strong_random: Some(strong_random),
get_redo_start_lsn: Some(get_redo_start_lsn),
finish_sync_safekeepers: Some(finish_sync_safekeepers),
process_safekeeper_feedback: Some(process_safekeeper_feedback),
confirm_wal_streamed: Some(confirm_wal_streamed),
log_internal: Some(log_internal),
after_election: Some(after_election),
}
}

View File

@@ -6,8 +6,8 @@ use utils::id::TenantTimelineId;
use crate::{
api_bindings::{create_api, take_vec_u8, Level},
bindings::{
NeonWALReadResult, Safekeeper, WalProposer, WalProposerConfig, WalProposerCreate,
WalProposerFree, WalProposerStart,
Safekeeper, WalProposer, WalProposerConfig, WalProposerCreate, WalProposerFree,
WalProposerStart,
},
};
@@ -86,19 +86,19 @@ pub trait ApiImpl {
todo!()
}
fn recovery_download(&self, _wp: &mut WalProposer, _sk: &mut Safekeeper) -> bool {
fn recovery_download(&self, _sk: &mut Safekeeper, _startpos: u64, _endpos: u64) -> bool {
todo!()
}
fn wal_reader_allocate(&self, _sk: &mut Safekeeper) -> NeonWALReadResult {
fn wal_read(&self, _sk: &mut Safekeeper, _buf: &mut [u8], _startpos: u64) {
todo!()
}
fn wal_read(&self, _sk: &mut Safekeeper, _buf: &mut [u8], _startpos: u64) -> NeonWALReadResult {
fn wal_reader_allocate(&self, _sk: &mut Safekeeper) {
todo!()
}
fn wal_reader_events(&self, _sk: &mut Safekeeper) -> u32 {
fn free_event_set(&self, _wp: &mut WalProposer) {
todo!()
}
@@ -110,18 +110,10 @@ pub trait ApiImpl {
todo!()
}
fn active_state_update_event_set(&self, _sk: &mut Safekeeper) {
todo!()
}
fn add_safekeeper_event_set(&self, _sk: &mut Safekeeper, _events_mask: u32) {
todo!()
}
fn rm_safekeeper_event_set(&self, _sk: &mut Safekeeper) {
todo!()
}
fn wait_event_set(&self, _wp: &mut WalProposer, _timeout_millis: i64) -> WaitResult {
todo!()
}
@@ -142,6 +134,10 @@ pub trait ApiImpl {
todo!()
}
fn confirm_wal_streamed(&self, _wp: &mut WalProposer, _lsn: u64) {
todo!()
}
fn log_internal(&self, _wp: &mut WalProposer, _level: Level, _msg: &str) {
todo!()
}
@@ -244,7 +240,6 @@ impl Drop for Wrapper {
#[cfg(test)]
mod tests {
use core::panic;
use std::{
cell::Cell,
sync::{atomic::AtomicUsize, mpsc::sync_channel},
@@ -252,7 +247,7 @@ mod tests {
use utils::id::TenantTimelineId;
use crate::{api_bindings::Level, bindings::NeonWALReadResult, walproposer::Wrapper};
use crate::{api_bindings::Level, walproposer::Wrapper};
use super::ApiImpl;
@@ -360,17 +355,12 @@ mod tests {
true
}
fn recovery_download(
&self,
_wp: &mut crate::bindings::WalProposer,
_sk: &mut crate::bindings::Safekeeper,
) -> bool {
true
fn wal_reader_allocate(&self, _: &mut crate::bindings::Safekeeper) {
println!("wal_reader_allocate")
}
fn wal_reader_allocate(&self, _: &mut crate::bindings::Safekeeper) -> NeonWALReadResult {
println!("wal_reader_allocate");
crate::bindings::NeonWALReadResult_NEON_WALREAD_SUCCESS
fn free_event_set(&self, _: &mut crate::bindings::WalProposer) {
println!("free_event_set")
}
fn init_event_set(&self, _: &mut crate::bindings::WalProposer) {
@@ -393,13 +383,6 @@ mod tests {
self.wait_events.set(WaitEventsData { sk, event_mask });
}
fn rm_safekeeper_event_set(&self, sk: &mut crate::bindings::Safekeeper) {
println!(
"rm_safekeeper_event_set, sk={:?}",
sk as *mut crate::bindings::Safekeeper
);
}
fn wait_event_set(
&self,
_: &mut crate::bindings::WalProposer,
@@ -425,7 +408,7 @@ mod tests {
}
fn log_internal(&self, _wp: &mut crate::bindings::WalProposer, level: Level, msg: &str) {
println!("wp_log[{}] {}", level, msg);
println!("walprop_log[{}] {}", level, msg);
}
fn after_election(&self, _wp: &mut crate::bindings::WalProposer) {

View File

@@ -63,7 +63,6 @@ thiserror.workspace = true
tokio = { workspace = true, features = ["process", "sync", "fs", "rt", "io-util", "time"] }
tokio-io-timeout.workspace = true
tokio-postgres.workspace = true
tokio-stream.workspace = true
tokio-util.workspace = true
toml_edit = { workspace = true, features = [ "serde" ] }
tracing.workspace = true

View File

@@ -13,7 +13,6 @@ use bytes::{Buf, Bytes};
use pageserver::{
config::PageServerConf, repository::Key, walrecord::NeonWalRecord, walredo::PostgresRedoManager,
};
use pageserver_api::shard::TenantShardId;
use utils::{id::TenantId, lsn::Lsn};
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
@@ -27,9 +26,9 @@ fn redo_scenarios(c: &mut Criterion) {
let conf = PageServerConf::dummy_conf(repo_dir.path().to_path_buf());
let conf = Box::leak(Box::new(conf));
let tenant_shard_id = TenantShardId::unsharded(TenantId::generate());
let tenant_id = TenantId::generate();
let manager = PostgresRedoManager::new(conf, tenant_shard_id);
let manager = PostgresRedoManager::new(conf, tenant_id);
let manager = Arc::new(manager);

View File

@@ -1,22 +0,0 @@
[package]
name = "pageserver_client"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
pageserver_api.workspace = true
thiserror.workspace = true
async-trait.workspace = true
reqwest.workspace = true
utils.workspace = true
serde.workspace = true
workspace_hack = { version = "0.1", path = "../../workspace_hack" }
tokio-postgres.workspace = true
tokio-stream.workspace = true
tokio.workspace = true
futures.workspace = true
tokio-util.workspace = true
anyhow.workspace = true
postgres.workspace = true
bytes.workspace = true

View File

@@ -1,2 +0,0 @@
pub mod mgmt_api;
pub mod page_service;

View File

@@ -1,278 +0,0 @@
use pageserver_api::{models::*, shard::TenantShardId};
use reqwest::{IntoUrl, Method, StatusCode};
use utils::{
http::error::HttpErrorBody,
id::{TenantId, TimelineId},
};
pub mod util;
#[derive(Debug)]
pub struct Client {
mgmt_api_endpoint: String,
authorization_header: Option<String>,
client: reqwest::Client,
}
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("receive body: {0}")]
ReceiveBody(reqwest::Error),
#[error("receive error body: {0}")]
ReceiveErrorBody(String),
#[error("pageserver API: {1}")]
ApiError(StatusCode, String),
}
pub type Result<T> = std::result::Result<T, Error>;
pub trait ResponseErrorMessageExt: Sized {
fn error_from_body(self) -> impl std::future::Future<Output = Result<Self>> + Send;
}
impl ResponseErrorMessageExt for reqwest::Response {
async fn error_from_body(self) -> Result<Self> {
let status = self.status();
if !(status.is_client_error() || status.is_server_error()) {
return Ok(self);
}
let url = self.url().to_owned();
Err(match self.json::<HttpErrorBody>().await {
Ok(HttpErrorBody { msg }) => Error::ApiError(status, msg),
Err(_) => {
Error::ReceiveErrorBody(format!("Http error ({}) at {}.", status.as_u16(), url))
}
})
}
}
pub enum ForceAwaitLogicalSize {
Yes,
No,
}
impl Client {
pub fn new(mgmt_api_endpoint: String, jwt: Option<&str>) -> Self {
Self {
mgmt_api_endpoint,
authorization_header: jwt.map(|jwt| format!("Bearer {jwt}")),
client: reqwest::Client::new(),
}
}
pub async fn list_tenants(&self) -> Result<Vec<pageserver_api::models::TenantInfo>> {
let uri = format!("{}/v1/tenant", self.mgmt_api_endpoint);
let resp = self.get(&uri).await?;
resp.json().await.map_err(Error::ReceiveBody)
}
pub async fn tenant_details(
&self,
tenant_shard_id: TenantShardId,
) -> Result<pageserver_api::models::TenantDetails> {
let uri = format!("{}/v1/tenant/{tenant_shard_id}", self.mgmt_api_endpoint);
self.get(uri)
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
pub async fn list_timelines(
&self,
tenant_shard_id: TenantShardId,
) -> Result<Vec<pageserver_api::models::TimelineInfo>> {
let uri = format!(
"{}/v1/tenant/{tenant_shard_id}/timeline",
self.mgmt_api_endpoint
);
self.get(&uri)
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
pub async fn timeline_info(
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
force_await_logical_size: ForceAwaitLogicalSize,
) -> Result<pageserver_api::models::TimelineInfo> {
let uri = format!(
"{}/v1/tenant/{tenant_id}/timeline/{timeline_id}",
self.mgmt_api_endpoint
);
let uri = match force_await_logical_size {
ForceAwaitLogicalSize::Yes => format!("{}?force-await-logical-size={}", uri, true),
ForceAwaitLogicalSize::No => uri,
};
self.get(&uri)
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
pub async fn keyspace(
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
) -> Result<pageserver_api::models::partitioning::Partitioning> {
let uri = format!(
"{}/v1/tenant/{tenant_id}/timeline/{timeline_id}/keyspace",
self.mgmt_api_endpoint
);
self.get(&uri)
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
async fn get<U: IntoUrl>(&self, uri: U) -> Result<reqwest::Response> {
self.request(Method::GET, uri, ()).await
}
async fn request<B: serde::Serialize, U: reqwest::IntoUrl>(
&self,
method: Method,
uri: U,
body: B,
) -> Result<reqwest::Response> {
let req = self.client.request(method, uri);
let req = if let Some(value) = &self.authorization_header {
req.header(reqwest::header::AUTHORIZATION, value)
} else {
req
};
let res = req.json(&body).send().await.map_err(Error::ReceiveBody)?;
let response = res.error_from_body().await?;
Ok(response)
}
pub async fn status(&self) -> Result<()> {
let uri = format!("{}/v1/status", self.mgmt_api_endpoint);
self.get(&uri).await?;
Ok(())
}
pub async fn tenant_create(&self, req: &TenantCreateRequest) -> Result<TenantId> {
let uri = format!("{}/v1/tenant", self.mgmt_api_endpoint);
self.request(Method::POST, &uri, req)
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
pub async fn tenant_config(&self, req: &TenantConfigRequest) -> Result<()> {
let uri = format!("{}/v1/tenant/config", self.mgmt_api_endpoint);
self.request(Method::PUT, &uri, req).await?;
Ok(())
}
pub async fn tenant_secondary_download(&self, tenant_id: TenantShardId) -> Result<()> {
let uri = format!(
"{}/v1/tenant/{}/secondary/download",
self.mgmt_api_endpoint, tenant_id
);
self.request(Method::POST, &uri, ()).await?;
Ok(())
}
pub async fn location_config(
&self,
tenant_shard_id: TenantShardId,
config: LocationConfig,
flush_ms: Option<std::time::Duration>,
) -> Result<()> {
let req_body = TenantLocationConfigRequest {
tenant_id: tenant_shard_id,
config,
};
let path = format!(
"{}/v1/tenant/{}/location_config",
self.mgmt_api_endpoint, tenant_shard_id
);
let path = if let Some(flush_ms) = flush_ms {
format!("{}?flush_ms={}", path, flush_ms.as_millis())
} else {
path
};
self.request(Method::PUT, &path, &req_body).await?;
Ok(())
}
pub async fn list_location_config(&self) -> Result<LocationConfigListResponse> {
let path = format!("{}/v1/location_config", self.mgmt_api_endpoint);
self.request(Method::GET, &path, ())
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
pub async fn timeline_create(
&self,
tenant_shard_id: TenantShardId,
req: &TimelineCreateRequest,
) -> Result<TimelineInfo> {
let uri = format!(
"{}/v1/tenant/{}/timeline",
self.mgmt_api_endpoint, tenant_shard_id
);
self.request(Method::POST, &uri, req)
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
pub async fn tenant_reset(&self, tenant_shard_id: TenantShardId) -> Result<()> {
let uri = format!(
"{}/v1/tenant/{}/reset",
self.mgmt_api_endpoint, tenant_shard_id
);
self.request(Method::POST, &uri, ())
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
pub async fn timeline_list(
&self,
tenant_shard_id: &TenantShardId,
) -> Result<Vec<TimelineInfo>> {
let uri = format!(
"{}/v1/tenant/{}/timeline",
self.mgmt_api_endpoint, tenant_shard_id
);
self.get(&uri)
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
pub async fn tenant_synthetic_size(
&self,
tenant_shard_id: TenantShardId,
) -> Result<TenantHistorySize> {
let uri = format!(
"{}/v1/tenant/{}/synthetic_size",
self.mgmt_api_endpoint, tenant_shard_id
);
self.get(&uri)
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
}

View File

@@ -1,53 +0,0 @@
//! Helpers to do common higher-level tasks with the [`Client`].
use std::sync::Arc;
use pageserver_api::shard::TenantShardId;
use tokio::task::JoinSet;
use utils::id::{TenantId, TenantTimelineId};
use super::Client;
/// Retrieve a list of all of the pageserver's timelines.
///
/// Fails if there are sharded tenants present on the pageserver.
pub async fn get_pageserver_tenant_timelines_unsharded(
api_client: &Arc<Client>,
) -> anyhow::Result<Vec<TenantTimelineId>> {
let mut timelines: Vec<TenantTimelineId> = Vec::new();
let mut tenants: Vec<TenantId> = Vec::new();
for ti in api_client.list_tenants().await? {
if !ti.id.is_unsharded() {
anyhow::bail!(
"only unsharded tenants are supported at this time: {}",
ti.id
);
}
tenants.push(ti.id.tenant_id)
}
let mut js = JoinSet::new();
for tenant_id in tenants {
js.spawn({
let mgmt_api_client = Arc::clone(api_client);
async move {
(
tenant_id,
mgmt_api_client
.tenant_details(TenantShardId::unsharded(tenant_id))
.await
.unwrap(),
)
}
});
}
while let Some(res) = js.join_next().await {
let (tenant_id, details) = res.unwrap();
for timeline_id in details.timelines {
timelines.push(TenantTimelineId {
tenant_id,
timeline_id,
});
}
}
Ok(timelines)
}

View File

@@ -1,167 +0,0 @@
use std::pin::Pin;
use futures::SinkExt;
use pageserver_api::{
models::{
PagestreamBeMessage, PagestreamFeMessage, PagestreamGetPageRequest,
PagestreamGetPageResponse,
},
reltag::RelTag,
};
use tokio::task::JoinHandle;
use tokio_postgres::CopyOutStream;
use tokio_stream::StreamExt;
use tokio_util::sync::CancellationToken;
use utils::{
id::{TenantId, TimelineId},
lsn::Lsn,
};
pub struct Client {
client: tokio_postgres::Client,
cancel_on_client_drop: Option<tokio_util::sync::DropGuard>,
conn_task: JoinHandle<()>,
}
pub struct BasebackupRequest {
pub tenant_id: TenantId,
pub timeline_id: TimelineId,
pub lsn: Option<Lsn>,
pub gzip: bool,
}
impl Client {
pub async fn new(connstring: String) -> anyhow::Result<Self> {
let (client, connection) = tokio_postgres::connect(&connstring, postgres::NoTls).await?;
let conn_task_cancel = CancellationToken::new();
let conn_task = tokio::spawn({
let conn_task_cancel = conn_task_cancel.clone();
async move {
tokio::select! {
_ = conn_task_cancel.cancelled() => { }
res = connection => {
res.unwrap();
}
}
}
});
Ok(Self {
cancel_on_client_drop: Some(conn_task_cancel.drop_guard()),
conn_task,
client,
})
}
pub async fn pagestream(
self,
tenant_id: TenantId,
timeline_id: TimelineId,
) -> anyhow::Result<PagestreamClient> {
let copy_both: tokio_postgres::CopyBothDuplex<bytes::Bytes> = self
.client
.copy_both_simple(&format!("pagestream {tenant_id} {timeline_id}"))
.await?;
let Client {
cancel_on_client_drop,
conn_task,
client: _,
} = self;
Ok(PagestreamClient {
copy_both: Box::pin(copy_both),
conn_task,
cancel_on_client_drop,
})
}
pub async fn basebackup(&self, req: &BasebackupRequest) -> anyhow::Result<CopyOutStream> {
let BasebackupRequest {
tenant_id,
timeline_id,
lsn,
gzip,
} = req;
let mut args = Vec::with_capacity(5);
args.push("basebackup".to_string());
args.push(format!("{tenant_id}"));
args.push(format!("{timeline_id}"));
if let Some(lsn) = lsn {
args.push(format!("{lsn}"));
}
if *gzip {
args.push("--gzip".to_string())
}
Ok(self.client.copy_out(&args.join(" ")).await?)
}
}
/// Create using [`Client::pagestream`].
pub struct PagestreamClient {
copy_both: Pin<Box<tokio_postgres::CopyBothDuplex<bytes::Bytes>>>,
cancel_on_client_drop: Option<tokio_util::sync::DropGuard>,
conn_task: JoinHandle<()>,
}
pub struct RelTagBlockNo {
pub rel_tag: RelTag,
pub block_no: u32,
}
impl PagestreamClient {
pub async fn shutdown(self) {
let Self {
copy_both,
cancel_on_client_drop: cancel_conn_task,
conn_task,
} = self;
// The `copy_both` contains internal channel sender, the receiver of which is polled by `conn_task`.
// When `conn_task` observes the sender has been dropped, it sends a `FeMessage::CopyFail` into the connection.
// (see https://github.com/neondatabase/rust-postgres/blob/2005bf79573b8add5cf205b52a2b208e356cc8b0/tokio-postgres/src/copy_both.rs#L56).
//
// If we drop(copy_both) first, but then immediately drop the `cancel_on_client_drop`,
// the CopyFail mesage only makes it to the socket sometimes (i.e., it's a race).
//
// Further, the pageserver makes a lot of noise when it receives CopyFail.
// Computes don't send it in practice, they just hard-close the connection.
//
// So, let's behave like the computes and suppress the CopyFail as follows:
// kill the socket first, then drop copy_both.
//
// See also: https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-COPY
//
// NB: page_service doesn't have a use case to exit the `pagestream` mode currently.
// => https://github.com/neondatabase/neon/issues/6390
let _ = cancel_conn_task.unwrap();
conn_task.await.unwrap();
drop(copy_both);
}
pub async fn getpage(
&mut self,
req: PagestreamGetPageRequest,
) -> anyhow::Result<PagestreamGetPageResponse> {
let req = PagestreamFeMessage::GetPage(req);
let req: bytes::Bytes = req.serialize();
// let mut req = tokio_util::io::ReaderStream::new(&req);
let mut req = tokio_stream::once(Ok(req));
self.copy_both.send_all(&mut req).await?;
let next: Option<Result<bytes::Bytes, _>> = self.copy_both.next().await;
let next: bytes::Bytes = next.unwrap()?;
let msg = PagestreamBeMessage::deserialize(next)?;
match msg {
PagestreamBeMessage::GetPage(p) => Ok(p),
PagestreamBeMessage::Error(e) => anyhow::bail!("Error: {:?}", e),
PagestreamBeMessage::Exists(_)
| PagestreamBeMessage::Nblocks(_)
| PagestreamBeMessage::DbSize(_) => {
anyhow::bail!(
"unexpected be message kind in response to getpage request: {}",
msg.kind()
)
}
}
}
}

View File

@@ -1,27 +0,0 @@
[package]
name = "pagebench"
version = "0.1.0"
edition.workspace = true
license.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow.workspace = true
camino.workspace = true
clap.workspace = true
futures.workspace = true
hdrhistogram.workspace = true
humantime.workspace = true
humantime-serde.workspace = true
rand.workspace = true
serde.workspace = true
serde_json.workspace = true
tracing.workspace = true
tokio.workspace = true
tokio-util.workspace = true
pageserver_client.workspace = true
pageserver_api.workspace = true
utils = { path = "../../libs/utils/" }
workspace_hack = { version = "0.1", path = "../../workspace_hack" }

View File

@@ -1,275 +0,0 @@
use anyhow::Context;
use pageserver_client::mgmt_api::ForceAwaitLogicalSize;
use pageserver_client::page_service::BasebackupRequest;
use utils::id::TenantTimelineId;
use utils::lsn::Lsn;
use rand::prelude::*;
use tokio::sync::Barrier;
use tokio::task::JoinSet;
use tracing::{debug, info, instrument};
use std::collections::HashMap;
use std::num::NonZeroUsize;
use std::ops::Range;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::Instant;
use crate::util::tokio_thread_local_stats::AllThreadLocalStats;
use crate::util::{request_stats, tokio_thread_local_stats};
/// basebackup@LatestLSN
#[derive(clap::Parser)]
pub(crate) struct Args {
#[clap(long, default_value = "http://localhost:9898")]
mgmt_api_endpoint: String,
#[clap(long, default_value = "localhost:64000")]
page_service_host_port: String,
#[clap(long)]
pageserver_jwt: Option<String>,
#[clap(long, default_value = "1")]
num_clients: NonZeroUsize,
#[clap(long, default_value = "1.0")]
gzip_probability: f64,
#[clap(long)]
runtime: Option<humantime::Duration>,
#[clap(long)]
limit_to_first_n_targets: Option<usize>,
targets: Option<Vec<TenantTimelineId>>,
}
#[derive(Debug, Default)]
struct LiveStats {
completed_requests: AtomicU64,
}
impl LiveStats {
fn inc(&self) {
self.completed_requests.fetch_add(1, Ordering::Relaxed);
}
}
struct Target {
timeline: TenantTimelineId,
lsn_range: Option<Range<Lsn>>,
}
#[derive(serde::Serialize)]
struct Output {
total: request_stats::Output,
}
tokio_thread_local_stats::declare!(STATS: request_stats::Stats);
pub(crate) fn main(args: Args) -> anyhow::Result<()> {
tokio_thread_local_stats::main!(STATS, move |thread_local_stats| {
main_impl(args, thread_local_stats)
})
}
async fn main_impl(
args: Args,
all_thread_local_stats: AllThreadLocalStats<request_stats::Stats>,
) -> anyhow::Result<()> {
let args: &'static Args = Box::leak(Box::new(args));
let mgmt_api_client = Arc::new(pageserver_client::mgmt_api::Client::new(
args.mgmt_api_endpoint.clone(),
args.pageserver_jwt.as_deref(),
));
// discover targets
let timelines: Vec<TenantTimelineId> = crate::util::cli::targets::discover(
&mgmt_api_client,
crate::util::cli::targets::Spec {
limit_to_first_n_targets: args.limit_to_first_n_targets,
targets: args.targets.clone(),
},
)
.await?;
let mut js = JoinSet::new();
for timeline in &timelines {
js.spawn({
let timeline = *timeline;
let info = mgmt_api_client
.timeline_info(
timeline.tenant_id,
timeline.timeline_id,
ForceAwaitLogicalSize::No,
)
.await
.unwrap();
async move {
anyhow::Ok(Target {
timeline,
// TODO: support lsn_range != latest LSN
lsn_range: Some(info.last_record_lsn..(info.last_record_lsn + 1)),
})
}
});
}
let mut all_targets: Vec<Target> = Vec::new();
while let Some(res) = js.join_next().await {
all_targets.push(res.unwrap().unwrap());
}
let live_stats = Arc::new(LiveStats::default());
let num_client_tasks = timelines.len();
let num_live_stats_dump = 1;
let num_work_sender_tasks = 1;
let start_work_barrier = Arc::new(tokio::sync::Barrier::new(
num_client_tasks + num_live_stats_dump + num_work_sender_tasks,
));
let all_work_done_barrier = Arc::new(tokio::sync::Barrier::new(num_client_tasks));
tokio::spawn({
let stats = Arc::clone(&live_stats);
let start_work_barrier = Arc::clone(&start_work_barrier);
async move {
start_work_barrier.wait().await;
loop {
let start = std::time::Instant::now();
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
let completed_requests = stats.completed_requests.swap(0, Ordering::Relaxed);
let elapsed = start.elapsed();
info!(
"RPS: {:.0}",
completed_requests as f64 / elapsed.as_secs_f64()
);
}
}
});
let mut work_senders = HashMap::new();
let mut tasks = Vec::new();
for tl in &timelines {
let (sender, receiver) = tokio::sync::mpsc::channel(1); // TODO: not sure what the implications of this are
work_senders.insert(tl, sender);
tasks.push(tokio::spawn(client(
args,
*tl,
Arc::clone(&start_work_barrier),
receiver,
Arc::clone(&all_work_done_barrier),
Arc::clone(&live_stats),
)));
}
let work_sender = async move {
start_work_barrier.wait().await;
loop {
let (timeline, work) = {
let mut rng = rand::thread_rng();
let target = all_targets.choose(&mut rng).unwrap();
let lsn = target.lsn_range.clone().map(|r| rng.gen_range(r));
(
target.timeline,
Work {
lsn,
gzip: rng.gen_bool(args.gzip_probability),
},
)
};
let sender = work_senders.get(&timeline).unwrap();
// TODO: what if this blocks?
sender.send(work).await.ok().unwrap();
}
};
if let Some(runtime) = args.runtime {
match tokio::time::timeout(runtime.into(), work_sender).await {
Ok(()) => unreachable!("work sender never terminates"),
Err(_timeout) => {
// this implicitly drops the work_senders, making all the clients exit
}
}
} else {
work_sender.await;
unreachable!("work sender never terminates");
}
for t in tasks {
t.await.unwrap();
}
let output = Output {
total: {
let mut agg_stats = request_stats::Stats::new();
for stats in all_thread_local_stats.lock().unwrap().iter() {
let stats = stats.lock().unwrap();
agg_stats.add(&stats);
}
agg_stats.output()
},
};
let output = serde_json::to_string_pretty(&output).unwrap();
println!("{output}");
anyhow::Ok(())
}
#[derive(Copy, Clone)]
struct Work {
lsn: Option<Lsn>,
gzip: bool,
}
#[instrument(skip_all)]
async fn client(
args: &'static Args,
timeline: TenantTimelineId,
start_work_barrier: Arc<Barrier>,
mut work: tokio::sync::mpsc::Receiver<Work>,
all_work_done_barrier: Arc<Barrier>,
live_stats: Arc<LiveStats>,
) {
start_work_barrier.wait().await;
let client = pageserver_client::page_service::Client::new(crate::util::connstring::connstring(
&args.page_service_host_port,
args.pageserver_jwt.as_deref(),
))
.await
.unwrap();
while let Some(Work { lsn, gzip }) = work.recv().await {
let start = Instant::now();
let copy_out_stream = client
.basebackup(&BasebackupRequest {
tenant_id: timeline.tenant_id,
timeline_id: timeline.timeline_id,
lsn,
gzip,
})
.await
.with_context(|| format!("start basebackup for {timeline}"))
.unwrap();
use futures::StreamExt;
let size = Arc::new(AtomicUsize::new(0));
copy_out_stream
.for_each({
|r| {
let size = Arc::clone(&size);
async move {
let size = Arc::clone(&size);
size.fetch_add(r.unwrap().len(), Ordering::Relaxed);
}
}
})
.await;
debug!("basebackup size is {} bytes", size.load(Ordering::Relaxed));
let elapsed = start.elapsed();
live_stats.inc();
STATS.with(|stats| {
stats.borrow().lock().unwrap().observe(elapsed).unwrap();
});
}
all_work_done_barrier.wait().await;
}

View File

@@ -1,430 +0,0 @@
use anyhow::Context;
use camino::Utf8PathBuf;
use futures::future::join_all;
use pageserver_api::key::{is_rel_block_key, key_to_rel_block, Key};
use pageserver_api::keyspace::KeySpaceAccum;
use pageserver_api::models::PagestreamGetPageRequest;
use tokio_util::sync::CancellationToken;
use utils::id::TenantTimelineId;
use utils::lsn::Lsn;
use rand::prelude::*;
use tokio::sync::Barrier;
use tokio::task::JoinSet;
use tracing::{info, instrument};
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::num::NonZeroUsize;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use crate::util::tokio_thread_local_stats::AllThreadLocalStats;
use crate::util::{request_stats, tokio_thread_local_stats};
/// GetPage@LatestLSN, uniformly distributed across the compute-accessible keyspace.
#[derive(clap::Parser)]
pub(crate) struct Args {
#[clap(long, default_value = "http://localhost:9898")]
mgmt_api_endpoint: String,
#[clap(long, default_value = "postgres://postgres@localhost:64000")]
page_service_connstring: String,
#[clap(long)]
pageserver_jwt: Option<String>,
#[clap(long, default_value = "1")]
num_clients: NonZeroUsize,
#[clap(long)]
runtime: Option<humantime::Duration>,
#[clap(long)]
per_target_rate_limit: Option<usize>,
/// Probability for sending `latest=true` in the request (uniform distribution).
#[clap(long, default_value = "1")]
req_latest_probability: f64,
#[clap(long)]
limit_to_first_n_targets: Option<usize>,
/// For large pageserver installations, enumerating the keyspace takes a lot of time.
/// If specified, the specified path is used to maintain a cache of the keyspace enumeration result.
/// The cache is tagged and auto-invalided by the tenant/timeline ids only.
/// It doesn't get invalidated if the keyspace changes under the hood, e.g., due to new ingested data or compaction.
#[clap(long)]
keyspace_cache: Option<Utf8PathBuf>,
targets: Option<Vec<TenantTimelineId>>,
}
#[derive(Debug, Default)]
struct LiveStats {
completed_requests: AtomicU64,
}
impl LiveStats {
fn inc(&self) {
self.completed_requests.fetch_add(1, Ordering::Relaxed);
}
}
#[derive(Clone, serde::Serialize, serde::Deserialize)]
struct KeyRange {
timeline: TenantTimelineId,
timeline_lsn: Lsn,
start: i128,
end: i128,
}
impl KeyRange {
fn len(&self) -> i128 {
self.end - self.start
}
}
#[derive(serde::Serialize)]
struct Output {
total: request_stats::Output,
}
tokio_thread_local_stats::declare!(STATS: request_stats::Stats);
pub(crate) fn main(args: Args) -> anyhow::Result<()> {
tokio_thread_local_stats::main!(STATS, move |thread_local_stats| {
main_impl(args, thread_local_stats)
})
}
async fn main_impl(
args: Args,
all_thread_local_stats: AllThreadLocalStats<request_stats::Stats>,
) -> anyhow::Result<()> {
let args: &'static Args = Box::leak(Box::new(args));
let mgmt_api_client = Arc::new(pageserver_client::mgmt_api::Client::new(
args.mgmt_api_endpoint.clone(),
args.pageserver_jwt.as_deref(),
));
// discover targets
let timelines: Vec<TenantTimelineId> = crate::util::cli::targets::discover(
&mgmt_api_client,
crate::util::cli::targets::Spec {
limit_to_first_n_targets: args.limit_to_first_n_targets,
targets: args.targets.clone(),
},
)
.await?;
#[derive(serde::Deserialize)]
struct KeyspaceCacheDe {
tag: Vec<TenantTimelineId>,
data: Vec<KeyRange>,
}
#[derive(serde::Serialize)]
struct KeyspaceCacheSer<'a> {
tag: &'a [TenantTimelineId],
data: &'a [KeyRange],
}
let cache = args
.keyspace_cache
.as_ref()
.map(|keyspace_cache_file| {
let contents = match std::fs::read(keyspace_cache_file) {
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
return anyhow::Ok(None);
}
x => x.context("read keyspace cache file")?,
};
let cache: KeyspaceCacheDe =
serde_json::from_slice(&contents).context("deserialize cache file")?;
let tag_ok = HashSet::<TenantTimelineId>::from_iter(cache.tag.into_iter())
== HashSet::from_iter(timelines.iter().cloned());
info!("keyspace cache file matches tag: {tag_ok}");
anyhow::Ok(if tag_ok { Some(cache.data) } else { None })
})
.transpose()?
.flatten();
let all_ranges: Vec<KeyRange> = if let Some(cached) = cache {
info!("using keyspace cache file");
cached
} else {
let mut js = JoinSet::new();
for timeline in &timelines {
js.spawn({
let mgmt_api_client = Arc::clone(&mgmt_api_client);
let timeline = *timeline;
async move {
let partitioning = mgmt_api_client
.keyspace(timeline.tenant_id, timeline.timeline_id)
.await?;
let lsn = partitioning.at_lsn;
let start = Instant::now();
let mut filtered = KeySpaceAccum::new();
// let's hope this is inlined and vectorized...
// TODO: turn this loop into a is_rel_block_range() function.
for r in partitioning.keys.ranges.iter() {
let mut i = r.start;
while i != r.end {
if is_rel_block_key(&i) {
filtered.add_key(i);
}
i = i.next();
}
}
let filtered = filtered.to_keyspace();
let filter_duration = start.elapsed();
anyhow::Ok((
filter_duration,
filtered.ranges.into_iter().map(move |r| KeyRange {
timeline,
timeline_lsn: lsn,
start: r.start.to_i128(),
end: r.end.to_i128(),
}),
))
}
});
}
let mut total_filter_duration = Duration::from_secs(0);
let mut all_ranges: Vec<KeyRange> = Vec::new();
while let Some(res) = js.join_next().await {
let (filter_duration, range) = res.unwrap().unwrap();
all_ranges.extend(range);
total_filter_duration += filter_duration;
}
info!("filter duration: {}", total_filter_duration.as_secs_f64());
if let Some(cachefile) = args.keyspace_cache.as_ref() {
let cache = KeyspaceCacheSer {
tag: &timelines,
data: &all_ranges,
};
let bytes = serde_json::to_vec(&cache).context("serialize keyspace for cache file")?;
std::fs::write(cachefile, bytes).context("write keyspace cache file to disk")?;
info!("successfully wrote keyspace cache file");
}
all_ranges
};
let live_stats = Arc::new(LiveStats::default());
let num_client_tasks = timelines.len();
let num_live_stats_dump = 1;
let num_work_sender_tasks = 1;
let num_main_impl = 1;
let start_work_barrier = Arc::new(tokio::sync::Barrier::new(
num_client_tasks + num_live_stats_dump + num_work_sender_tasks + num_main_impl,
));
tokio::spawn({
let stats = Arc::clone(&live_stats);
let start_work_barrier = Arc::clone(&start_work_barrier);
async move {
start_work_barrier.wait().await;
loop {
let start = std::time::Instant::now();
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
let completed_requests = stats.completed_requests.swap(0, Ordering::Relaxed);
let elapsed = start.elapsed();
info!(
"RPS: {:.0}",
completed_requests as f64 / elapsed.as_secs_f64()
);
}
}
});
let cancel = CancellationToken::new();
let mut work_senders: HashMap<TenantTimelineId, _> = HashMap::new();
let mut tasks = Vec::new();
for tl in &timelines {
let (sender, receiver) = tokio::sync::mpsc::channel(10); // TODO: not sure what the implications of this are
work_senders.insert(*tl, sender);
tasks.push(tokio::spawn(client(
args,
*tl,
Arc::clone(&start_work_barrier),
receiver,
Arc::clone(&live_stats),
cancel.clone(),
)));
}
let work_sender: Pin<Box<dyn Send + Future<Output = ()>>> = {
let start_work_barrier = start_work_barrier.clone();
let cancel = cancel.clone();
match args.per_target_rate_limit {
None => Box::pin(async move {
let weights = rand::distributions::weighted::WeightedIndex::new(
all_ranges.iter().map(|v| v.len()),
)
.unwrap();
start_work_barrier.wait().await;
while !cancel.is_cancelled() {
let (timeline, req) = {
let mut rng = rand::thread_rng();
let r = &all_ranges[weights.sample(&mut rng)];
let key: i128 = rng.gen_range(r.start..r.end);
let key = Key::from_i128(key);
let (rel_tag, block_no) =
key_to_rel_block(key).expect("we filter non-rel-block keys out above");
(
r.timeline,
PagestreamGetPageRequest {
latest: rng.gen_bool(args.req_latest_probability),
lsn: r.timeline_lsn,
rel: rel_tag,
blkno: block_no,
},
)
};
let sender = work_senders.get(&timeline).unwrap();
// TODO: what if this blocks?
if sender.send(req).await.is_err() {
assert!(cancel.is_cancelled(), "client has gone away unexpectedly");
}
}
}),
Some(rps_limit) => Box::pin(async move {
let period = Duration::from_secs_f64(1.0 / (rps_limit as f64));
let make_timeline_task: &dyn Fn(
TenantTimelineId,
)
-> Pin<Box<dyn Send + Future<Output = ()>>> = &|timeline| {
let sender = work_senders.get(&timeline).unwrap();
let ranges: Vec<KeyRange> = all_ranges
.iter()
.filter(|r| r.timeline == timeline)
.cloned()
.collect();
let weights = rand::distributions::weighted::WeightedIndex::new(
ranges.iter().map(|v| v.len()),
)
.unwrap();
let cancel = cancel.clone();
Box::pin(async move {
let mut ticker = tokio::time::interval(period);
ticker.set_missed_tick_behavior(
/* TODO review this choice */
tokio::time::MissedTickBehavior::Burst,
);
while !cancel.is_cancelled() {
ticker.tick().await;
let req = {
let mut rng = rand::thread_rng();
let r = &ranges[weights.sample(&mut rng)];
let key: i128 = rng.gen_range(r.start..r.end);
let key = Key::from_i128(key);
assert!(is_rel_block_key(&key));
let (rel_tag, block_no) = key_to_rel_block(key)
.expect("we filter non-rel-block keys out above");
PagestreamGetPageRequest {
latest: rng.gen_bool(args.req_latest_probability),
lsn: r.timeline_lsn,
rel: rel_tag,
blkno: block_no,
}
};
if sender.send(req).await.is_err() {
assert!(cancel.is_cancelled(), "client has gone away unexpectedly");
}
}
})
};
let tasks: Vec<_> = work_senders
.keys()
.map(|tl| make_timeline_task(*tl))
.collect();
start_work_barrier.wait().await;
join_all(tasks).await;
}),
}
};
let work_sender_task = tokio::spawn(work_sender);
info!("waiting for everything to become ready");
start_work_barrier.wait().await;
info!("work started");
if let Some(runtime) = args.runtime {
tokio::time::sleep(runtime.into()).await;
info!("runtime over, signalling cancellation");
cancel.cancel();
work_sender_task.await.unwrap();
info!("work sender exited");
} else {
work_sender_task.await.unwrap();
unreachable!("work sender never terminates");
}
info!("joining clients");
for t in tasks {
t.await.unwrap();
}
info!("all clients stopped");
let output = Output {
total: {
let mut agg_stats = request_stats::Stats::new();
for stats in all_thread_local_stats.lock().unwrap().iter() {
let stats = stats.lock().unwrap();
agg_stats.add(&stats);
}
agg_stats.output()
},
};
let output = serde_json::to_string_pretty(&output).unwrap();
println!("{output}");
anyhow::Ok(())
}
#[instrument(skip_all)]
async fn client(
args: &'static Args,
timeline: TenantTimelineId,
start_work_barrier: Arc<Barrier>,
mut work: tokio::sync::mpsc::Receiver<PagestreamGetPageRequest>,
live_stats: Arc<LiveStats>,
cancel: CancellationToken,
) {
let client = pageserver_client::page_service::Client::new(args.page_service_connstring.clone())
.await
.unwrap();
let mut client = client
.pagestream(timeline.tenant_id, timeline.timeline_id)
.await
.unwrap();
let do_requests = async {
start_work_barrier.wait().await;
while let Some(req) = work.recv().await {
let start = Instant::now();
client
.getpage(req)
.await
.with_context(|| format!("getpage for {timeline}"))
.unwrap();
let elapsed = start.elapsed();
live_stats.inc();
STATS.with(|stats| {
stats.borrow().lock().unwrap().observe(elapsed).unwrap();
});
}
};
tokio::select! {
res = do_requests => { res },
_ = cancel.cancelled() => {
client.shutdown().await;
return;
}
}
}

View File

@@ -1,88 +0,0 @@
use std::sync::Arc;
use humantime::Duration;
use tokio::task::JoinSet;
use utils::id::TenantTimelineId;
use pageserver_client::mgmt_api::ForceAwaitLogicalSize;
#[derive(clap::Parser)]
pub(crate) struct Args {
#[clap(long, default_value = "http://localhost:9898")]
mgmt_api_endpoint: String,
#[clap(long, default_value = "localhost:64000")]
page_service_host_port: String,
#[clap(long)]
pageserver_jwt: Option<String>,
#[clap(
long,
help = "if specified, poll mgmt api to check whether init logical size calculation has completed"
)]
poll_for_completion: Option<Duration>,
#[clap(long)]
limit_to_first_n_targets: Option<usize>,
targets: Option<Vec<TenantTimelineId>>,
}
pub(crate) fn main(args: Args) -> anyhow::Result<()> {
let rt = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
let main_task = rt.spawn(main_impl(args));
rt.block_on(main_task).unwrap()
}
async fn main_impl(args: Args) -> anyhow::Result<()> {
let args: &'static Args = Box::leak(Box::new(args));
let mgmt_api_client = Arc::new(pageserver_client::mgmt_api::Client::new(
args.mgmt_api_endpoint.clone(),
args.pageserver_jwt.as_deref(),
));
// discover targets
let timelines: Vec<TenantTimelineId> = crate::util::cli::targets::discover(
&mgmt_api_client,
crate::util::cli::targets::Spec {
limit_to_first_n_targets: args.limit_to_first_n_targets,
targets: args.targets.clone(),
},
)
.await?;
// kick it off
let mut js = JoinSet::new();
for tl in timelines {
let mgmt_api_client = Arc::clone(&mgmt_api_client);
js.spawn(async move {
let info = mgmt_api_client
.timeline_info(tl.tenant_id, tl.timeline_id, ForceAwaitLogicalSize::Yes)
.await
.unwrap();
// Polling should not be strictly required here since we await
// for the initial logical size, however it's possible for the request
// to land before the timeline is initialised. This results in an approximate
// logical size.
if let Some(period) = args.poll_for_completion {
let mut ticker = tokio::time::interval(period.into());
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
let mut info = info;
while !info.current_logical_size_is_accurate {
ticker.tick().await;
info = mgmt_api_client
.timeline_info(tl.tenant_id, tl.timeline_id, ForceAwaitLogicalSize::Yes)
.await
.unwrap();
}
}
});
}
while let Some(res) = js.join_next().await {
let _: () = res.unwrap();
}
Ok(())
}

View File

@@ -1,49 +0,0 @@
use clap::Parser;
use utils::logging;
/// Re-usable pieces of code that aren't CLI-specific.
mod util {
pub(crate) mod connstring;
pub(crate) mod request_stats;
#[macro_use]
pub(crate) mod tokio_thread_local_stats;
/// Re-usable pieces of CLI-specific code.
pub(crate) mod cli {
pub(crate) mod targets;
}
}
/// The pagebench CLI sub-commands, dispatched in [`main`] below.
mod cmd {
pub(super) mod basebackup;
pub(super) mod getpage_latest_lsn;
pub(super) mod trigger_initial_size_calculation;
}
/// Component-level performance test for pageserver.
#[derive(clap::Parser)]
enum Args {
Basebackup(cmd::basebackup::Args),
GetPageLatestLsn(cmd::getpage_latest_lsn::Args),
TriggerInitialSizeCalculation(cmd::trigger_initial_size_calculation::Args),
}
fn main() {
logging::init(
logging::LogFormat::Plain,
logging::TracingErrorLayerEnablement::Disabled,
logging::Output::Stderr,
)
.unwrap();
logging::replace_panic_hook_with_tracing_panic_hook().forget();
let args = Args::parse();
match args {
Args::Basebackup(args) => cmd::basebackup::main(args),
Args::GetPageLatestLsn(args) => cmd::getpage_latest_lsn::main(args),
Args::TriggerInitialSizeCalculation(args) => {
cmd::trigger_initial_size_calculation::main(args)
}
}
.unwrap()
}

View File

@@ -1,34 +0,0 @@
use std::sync::Arc;
use pageserver_client::mgmt_api;
use tracing::info;
use utils::id::TenantTimelineId;
pub(crate) struct Spec {
pub(crate) limit_to_first_n_targets: Option<usize>,
pub(crate) targets: Option<Vec<TenantTimelineId>>,
}
pub(crate) async fn discover(
api_client: &Arc<mgmt_api::Client>,
spec: Spec,
) -> anyhow::Result<Vec<TenantTimelineId>> {
let mut timelines = if let Some(targets) = spec.targets {
targets
} else {
mgmt_api::util::get_pageserver_tenant_timelines_unsharded(api_client).await?
};
if let Some(limit) = spec.limit_to_first_n_targets {
timelines.sort(); // for determinism
timelines.truncate(limit);
if timelines.len() < limit {
anyhow::bail!("pageserver has less than limit_to_first_n_targets={limit} tenants");
}
}
info!("timelines:\n{:?}", timelines);
info!("number of timelines:\n{:?}", timelines.len());
Ok(timelines)
}

View File

@@ -1,8 +0,0 @@
pub(crate) fn connstring(host_port: &str, jwt: Option<&str>) -> String {
let colon_and_jwt = if let Some(jwt) = jwt {
format!(":{jwt}") // TODO: urlescape
} else {
String::new()
};
format!("postgres://postgres{colon_and_jwt}@{host_port}")
}

View File

@@ -1,88 +0,0 @@
use std::time::Duration;
use anyhow::Context;
pub(crate) struct Stats {
latency_histo: hdrhistogram::Histogram<u64>,
}
impl Stats {
pub(crate) fn new() -> Self {
Self {
// Initialize with fixed bounds so that we panic at runtime instead of resizing the histogram,
// which would skew the benchmark results.
latency_histo: hdrhistogram::Histogram::new_with_bounds(1, 1_000_000_000, 3).unwrap(),
}
}
pub(crate) fn observe(&mut self, latency: Duration) -> anyhow::Result<()> {
let micros: u64 = latency
.as_micros()
.try_into()
.context("latency greater than u64")?;
self.latency_histo
.record(micros)
.context("add to histogram")?;
Ok(())
}
pub(crate) fn output(&self) -> Output {
let latency_percentiles = std::array::from_fn(|idx| {
let micros = self
.latency_histo
.value_at_percentile(LATENCY_PERCENTILES[idx]);
Duration::from_micros(micros)
});
Output {
request_count: self.latency_histo.len(),
latency_mean: Duration::from_micros(self.latency_histo.mean() as u64),
latency_percentiles: LatencyPercentiles {
latency_percentiles,
},
}
}
pub(crate) fn add(&mut self, other: &Self) {
let Self {
ref mut latency_histo,
} = self;
latency_histo.add(&other.latency_histo).unwrap();
}
}
impl Default for Stats {
fn default() -> Self {
Self::new()
}
}
const LATENCY_PERCENTILES: [f64; 4] = [95.0, 99.00, 99.90, 99.99];
struct LatencyPercentiles {
latency_percentiles: [Duration; 4],
}
impl serde::Serialize for LatencyPercentiles {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
use serde::ser::SerializeMap;
let mut ser = serializer.serialize_map(Some(LATENCY_PERCENTILES.len()))?;
for p in LATENCY_PERCENTILES {
ser.serialize_entry(
&format!("p{p}"),
&format!(
"{}",
&humantime::format_duration(self.latency_percentiles[0])
),
)?;
}
ser.end()
}
}
#[derive(serde::Serialize)]
pub(crate) struct Output {
request_count: u64,
#[serde(with = "humantime_serde")]
latency_mean: Duration,
latency_percentiles: LatencyPercentiles,
}

Some files were not shown because too many files have changed in this diff Show More