mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-18 10:52:55 +00:00
Compare commits
1 Commits
hack/compu
...
proxy-prot
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fdf4a3c922 |
@@ -22,11 +22,5 @@ platforms = [
|
||||
# "x86_64-pc-windows-msvc",
|
||||
]
|
||||
|
||||
[final-excludes]
|
||||
# vm_monitor benefits from the same Cargo.lock as the rest of our artifacts, but
|
||||
# it is built primarly in separate repo neondatabase/autoscaling and thus is excluded
|
||||
# from depending on workspace-hack because most of the dependencies are not used.
|
||||
workspace-members = ["vm_monitor"]
|
||||
|
||||
# Write out exact versions rather than a semver range. (Defaults to false.)
|
||||
# exact-versions = true
|
||||
|
||||
5
.github/ISSUE_TEMPLATE/epic-template.md
vendored
5
.github/ISSUE_TEMPLATE/epic-template.md
vendored
@@ -17,9 +17,8 @@ assignees: ''
|
||||
## Implementation ideas
|
||||
|
||||
|
||||
```[tasklist]
|
||||
### Tasks
|
||||
```
|
||||
## Tasks
|
||||
- [ ]
|
||||
|
||||
|
||||
## Other related tasks and Epics
|
||||
|
||||
2
.github/PULL_REQUEST_TEMPLATE/release-pr.md
vendored
2
.github/PULL_REQUEST_TEMPLATE/release-pr.md
vendored
@@ -3,7 +3,7 @@
|
||||
**NB: this PR must be merged only by 'Create a merge commit'!**
|
||||
|
||||
### Checklist when preparing for release
|
||||
- [ ] Read or refresh [the release flow guide](https://www.notion.so/neondatabase/Release-general-flow-61f2e39fd45d4d14a70c7749604bd70b)
|
||||
- [ ] Read or refresh [the release flow guide](https://github.com/neondatabase/cloud/wiki/Release:-general-flow)
|
||||
- [ ] Ask in the [cloud Slack channel](https://neondb.slack.com/archives/C033A2WE6BZ) that you are going to rollout the release. Any blockers?
|
||||
- [ ] Does this release contain any db migrations? Destructive ones? What is the rollback plan?
|
||||
|
||||
|
||||
4
.github/actionlint.yml
vendored
4
.github/actionlint.yml
vendored
@@ -1,12 +1,8 @@
|
||||
self-hosted-runner:
|
||||
labels:
|
||||
- arm64
|
||||
- dev
|
||||
- gen3
|
||||
- large
|
||||
- small
|
||||
- us-east-2
|
||||
config-variables:
|
||||
- REMOTE_STORAGE_AZURE_CONTAINER
|
||||
- REMOTE_STORAGE_AZURE_REGION
|
||||
- SLACK_UPCOMING_RELEASE_CHANNEL_ID
|
||||
|
||||
@@ -76,8 +76,8 @@ runs:
|
||||
rm -f ${ALLURE_ZIP}
|
||||
fi
|
||||
env:
|
||||
ALLURE_VERSION: 2.24.0
|
||||
ALLURE_ZIP_SHA256: 60b1d6ce65d9ef24b23cf9c2c19fd736a123487c38e54759f1ed1a7a77353c90
|
||||
ALLURE_VERSION: 2.23.1
|
||||
ALLURE_ZIP_SHA256: 11141bfe727504b3fd80c0f9801eb317407fd0ac983ebb57e671f14bac4bcd86
|
||||
|
||||
# Potentially we could have several running build for the same key (for example, for the main branch), so we use improvised lock for this
|
||||
- name: Acquire lock
|
||||
@@ -203,10 +203,6 @@ runs:
|
||||
COMMIT_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
BASE_S3_URL: ${{ steps.generate-report.outputs.base-s3-url }}
|
||||
run: |
|
||||
if [ ! -d "${WORKDIR}/report/data/test-cases" ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
export DATABASE_URL=${REGRESS_TEST_RESULT_CONNSTR_NEW}
|
||||
|
||||
./scripts/pysync
|
||||
|
||||
37
.github/workflows/build_and_test.yml
vendored
37
.github/workflows/build_and_test.yml
vendored
@@ -172,10 +172,10 @@ jobs:
|
||||
# https://github.com/EmbarkStudios/cargo-deny
|
||||
- name: Check rust licenses/bans/advisories/sources
|
||||
if: ${{ !cancelled() }}
|
||||
run: cargo deny check --hide-inclusion-graph
|
||||
run: cargo deny check
|
||||
|
||||
build-neon:
|
||||
needs: [ check-permissions, tag ]
|
||||
needs: [ check-permissions ]
|
||||
runs-on: [ self-hosted, gen3, large ]
|
||||
container:
|
||||
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/rust:pinned
|
||||
@@ -187,7 +187,6 @@ jobs:
|
||||
env:
|
||||
BUILD_TYPE: ${{ matrix.build_type }}
|
||||
GIT_VERSION: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
BUILD_TAG: ${{ needs.tag.outputs.build-tag }}
|
||||
|
||||
steps:
|
||||
- name: Fix git ownership
|
||||
@@ -321,9 +320,6 @@ jobs:
|
||||
- name: Build neon extensions
|
||||
run: mold -run make neon-pg-ext -j$(nproc)
|
||||
|
||||
- name: Build walproposer-lib
|
||||
run: mold -run make walproposer-lib -j$(nproc)
|
||||
|
||||
- name: Run cargo build
|
||||
run: |
|
||||
${cov_prefix} mold -run cargo build $CARGO_FLAGS $CARGO_FEATURES --bins --tests
|
||||
@@ -339,16 +335,6 @@ jobs:
|
||||
# Avoid `$CARGO_FEATURES` since there's no `testing` feature in the e2e tests now
|
||||
${cov_prefix} cargo test $CARGO_FLAGS --package remote_storage --test test_real_s3
|
||||
|
||||
# Run separate tests for real Azure Blob Storage
|
||||
# XXX: replace region with `eu-central-1`-like region
|
||||
export ENABLE_REAL_AZURE_REMOTE_STORAGE=y
|
||||
export AZURE_STORAGE_ACCOUNT="${{ secrets.AZURE_STORAGE_ACCOUNT_DEV }}"
|
||||
export AZURE_STORAGE_ACCESS_KEY="${{ secrets.AZURE_STORAGE_ACCESS_KEY_DEV }}"
|
||||
export REMOTE_STORAGE_AZURE_CONTAINER="${{ vars.REMOTE_STORAGE_AZURE_CONTAINER }}"
|
||||
export REMOTE_STORAGE_AZURE_REGION="${{ vars.REMOTE_STORAGE_AZURE_REGION }}"
|
||||
# Avoid `$CARGO_FEATURES` since there's no `testing` feature in the e2e tests now
|
||||
${cov_prefix} cargo test $CARGO_FLAGS --package remote_storage --test test_real_azure
|
||||
|
||||
- name: Install rust binaries
|
||||
run: |
|
||||
# Install target binaries
|
||||
@@ -434,7 +420,7 @@ jobs:
|
||||
rerun_flaky: true
|
||||
pg_version: ${{ matrix.pg_version }}
|
||||
env:
|
||||
TEST_RESULT_CONNSTR: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}
|
||||
TEST_RESULT_CONNSTR: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR }}
|
||||
CHECK_ONDISK_DATA_COMPATIBILITY: nonempty
|
||||
|
||||
- name: Merge and upload coverage data
|
||||
@@ -469,7 +455,7 @@ jobs:
|
||||
env:
|
||||
VIP_VAP_ACCESS_TOKEN: "${{ secrets.VIP_VAP_ACCESS_TOKEN }}"
|
||||
PERF_TEST_RESULT_CONNSTR: "${{ secrets.PERF_TEST_RESULT_CONNSTR }}"
|
||||
TEST_RESULT_CONNSTR: "${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}"
|
||||
TEST_RESULT_CONNSTR: "${{ secrets.REGRESS_TEST_RESULT_CONNSTR }}"
|
||||
# XXX: no coverage data handling here, since benchmarks are run on release builds,
|
||||
# while coverage is currently collected for the debug ones
|
||||
|
||||
@@ -586,13 +572,10 @@ jobs:
|
||||
id: upload-coverage-report-new
|
||||
env:
|
||||
BUCKET: neon-github-public-dev
|
||||
# A differential coverage report is available only for PRs.
|
||||
# (i.e. for pushes into main/release branches we have a regular coverage report)
|
||||
COMMIT_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
BASE_SHA: ${{ github.event.pull_request.base.sha || github.sha }}
|
||||
run: |
|
||||
BASELINE="$(git merge-base HEAD origin/main)"
|
||||
CURRENT="${COMMIT_SHA}"
|
||||
BASELINE="$(git merge-base $BASE_SHA $CURRENT)"
|
||||
|
||||
cp /tmp/coverage/report/lcov.info ./${CURRENT}.info
|
||||
|
||||
@@ -727,7 +710,6 @@ jobs:
|
||||
--cache-repo 369495373322.dkr.ecr.eu-central-1.amazonaws.com/cache
|
||||
--context .
|
||||
--build-arg GIT_VERSION=${{ github.event.pull_request.head.sha || github.sha }}
|
||||
--build-arg BUILD_TAG=${{ needs.tag.outputs.build-tag }}
|
||||
--build-arg REPOSITORY=369495373322.dkr.ecr.eu-central-1.amazonaws.com
|
||||
--destination 369495373322.dkr.ecr.eu-central-1.amazonaws.com/neon:${{needs.tag.outputs.build-tag}}
|
||||
--destination neondatabase/neon:${{needs.tag.outputs.build-tag}}
|
||||
@@ -852,7 +834,7 @@ jobs:
|
||||
run:
|
||||
shell: sh -eu {0}
|
||||
env:
|
||||
VM_BUILDER_VERSION: v0.19.0
|
||||
VM_BUILDER_VERSION: v0.17.12
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -874,7 +856,8 @@ jobs:
|
||||
- name: Build vm image
|
||||
run: |
|
||||
./vm-builder \
|
||||
-spec=vm-image-spec.yaml \
|
||||
-enable-file-cache \
|
||||
-cgroup-uid=postgres \
|
||||
-src=369495373322.dkr.ecr.eu-central-1.amazonaws.com/compute-node-${{ matrix.version }}:${{needs.tag.outputs.build-tag}} \
|
||||
-dst=369495373322.dkr.ecr.eu-central-1.amazonaws.com/vm-compute-node-${{ matrix.version }}:${{needs.tag.outputs.build-tag}}
|
||||
|
||||
@@ -1109,10 +1092,8 @@ jobs:
|
||||
run: |
|
||||
if [[ "$GITHUB_REF_NAME" == "main" ]]; then
|
||||
gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main -f branch=main -f dockerTag=${{needs.tag.outputs.build-tag}} -f deployPreprodRegion=false
|
||||
|
||||
# TODO: move deployPreprodRegion to release (`"$GITHUB_REF_NAME" == "release"` block), once Staging support different compute tag prefixes for different regions
|
||||
gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main -f branch=main -f dockerTag=${{needs.tag.outputs.build-tag}} -f deployPreprodRegion=true
|
||||
elif [[ "$GITHUB_REF_NAME" == "release" ]]; then
|
||||
gh workflow --repo neondatabase/aws run deploy-dev.yml --ref main -f branch=main -f dockerTag=${{needs.tag.outputs.build-tag}} -f deployPreprodRegion=true
|
||||
gh workflow --repo neondatabase/aws run deploy-prod.yml --ref main -f branch=main -f dockerTag=${{needs.tag.outputs.build-tag}} -f disclamerAcknowledged=true
|
||||
else
|
||||
echo "GITHUB_REF_NAME (value '$GITHUB_REF_NAME') is not set to either 'main' or 'release'"
|
||||
|
||||
199
.github/workflows/neon_extra_builds.yml
vendored
199
.github/workflows/neon_extra_builds.yml
vendored
@@ -21,10 +21,7 @@ env:
|
||||
|
||||
jobs:
|
||||
check-macos-build:
|
||||
if: |
|
||||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-macos') ||
|
||||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-*') ||
|
||||
github.ref_name == 'main'
|
||||
if: github.ref_name == 'main' || contains(github.event.pull_request.labels.*.name, 'run-extra-build-macos')
|
||||
timeout-minutes: 90
|
||||
runs-on: macos-latest
|
||||
|
||||
@@ -35,7 +32,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
submodules: true
|
||||
fetch-depth: 1
|
||||
@@ -93,21 +90,18 @@ jobs:
|
||||
|
||||
- name: Build postgres v14
|
||||
if: steps.cache_pg_14.outputs.cache-hit != 'true'
|
||||
run: make postgres-v14 -j$(sysctl -n hw.ncpu)
|
||||
run: make postgres-v14 -j$(nproc)
|
||||
|
||||
- name: Build postgres v15
|
||||
if: steps.cache_pg_15.outputs.cache-hit != 'true'
|
||||
run: make postgres-v15 -j$(sysctl -n hw.ncpu)
|
||||
run: make postgres-v15 -j$(nproc)
|
||||
|
||||
- name: Build postgres v16
|
||||
if: steps.cache_pg_16.outputs.cache-hit != 'true'
|
||||
run: make postgres-v16 -j$(sysctl -n hw.ncpu)
|
||||
run: make postgres-v16 -j$(nproc)
|
||||
|
||||
- name: Build neon extensions
|
||||
run: make neon-pg-ext -j$(sysctl -n hw.ncpu)
|
||||
|
||||
- name: Build walproposer-lib
|
||||
run: make walproposer-lib -j$(sysctl -n hw.ncpu)
|
||||
run: make neon-pg-ext -j$(nproc)
|
||||
|
||||
- name: Run cargo build
|
||||
run: cargo build --all --release
|
||||
@@ -115,182 +109,8 @@ jobs:
|
||||
- name: Check that no warnings are produced
|
||||
run: ./run_clippy.sh
|
||||
|
||||
check-linux-arm-build:
|
||||
timeout-minutes: 90
|
||||
runs-on: [ self-hosted, dev, arm64 ]
|
||||
|
||||
env:
|
||||
# Use release build only, to have less debug info around
|
||||
# Hence keeping target/ (and general cache size) smaller
|
||||
BUILD_TYPE: release
|
||||
CARGO_FEATURES: --features testing
|
||||
CARGO_FLAGS: --locked --release
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_DEV }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_KEY_DEV }}
|
||||
|
||||
container:
|
||||
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/rust:pinned
|
||||
options: --init
|
||||
|
||||
steps:
|
||||
- name: Fix git ownership
|
||||
run: |
|
||||
# Workaround for `fatal: detected dubious ownership in repository at ...`
|
||||
#
|
||||
# Use both ${{ github.workspace }} and ${GITHUB_WORKSPACE} because they're different on host and in containers
|
||||
# Ref https://github.com/actions/checkout/issues/785
|
||||
#
|
||||
git config --global --add safe.directory ${{ github.workspace }}
|
||||
git config --global --add safe.directory ${GITHUB_WORKSPACE}
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: true
|
||||
fetch-depth: 1
|
||||
|
||||
- name: Set pg 14 revision for caching
|
||||
id: pg_v14_rev
|
||||
run: echo pg_rev=$(git rev-parse HEAD:vendor/postgres-v14) >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set pg 15 revision for caching
|
||||
id: pg_v15_rev
|
||||
run: echo pg_rev=$(git rev-parse HEAD:vendor/postgres-v15) >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set pg 16 revision for caching
|
||||
id: pg_v16_rev
|
||||
run: echo pg_rev=$(git rev-parse HEAD:vendor/postgres-v16) >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Set env variables
|
||||
run: |
|
||||
echo "CARGO_HOME=${GITHUB_WORKSPACE}/.cargo" >> $GITHUB_ENV
|
||||
|
||||
- name: Cache postgres v14 build
|
||||
id: cache_pg_14
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: pg_install/v14
|
||||
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-pg-${{ steps.pg_v14_rev.outputs.pg_rev }}-${{ hashFiles('Makefile') }}
|
||||
|
||||
- name: Cache postgres v15 build
|
||||
id: cache_pg_15
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: pg_install/v15
|
||||
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-pg-${{ steps.pg_v15_rev.outputs.pg_rev }}-${{ hashFiles('Makefile') }}
|
||||
|
||||
- name: Cache postgres v16 build
|
||||
id: cache_pg_16
|
||||
uses: actions/cache@v3
|
||||
with:
|
||||
path: pg_install/v16
|
||||
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-pg-${{ steps.pg_v16_rev.outputs.pg_rev }}-${{ hashFiles('Makefile') }}
|
||||
|
||||
- name: Build postgres v14
|
||||
if: steps.cache_pg_14.outputs.cache-hit != 'true'
|
||||
run: mold -run make postgres-v14 -j$(nproc)
|
||||
|
||||
- name: Build postgres v15
|
||||
if: steps.cache_pg_15.outputs.cache-hit != 'true'
|
||||
run: mold -run make postgres-v15 -j$(nproc)
|
||||
|
||||
- name: Build postgres v16
|
||||
if: steps.cache_pg_16.outputs.cache-hit != 'true'
|
||||
run: mold -run make postgres-v16 -j$(nproc)
|
||||
|
||||
- name: Build neon extensions
|
||||
run: mold -run make neon-pg-ext -j$(nproc)
|
||||
|
||||
- name: Build walproposer-lib
|
||||
run: mold -run make walproposer-lib -j$(nproc)
|
||||
|
||||
- name: Run cargo build
|
||||
run: |
|
||||
mold -run cargo build $CARGO_FLAGS $CARGO_FEATURES --bins --tests
|
||||
|
||||
- name: Run cargo test
|
||||
run: |
|
||||
cargo test $CARGO_FLAGS $CARGO_FEATURES
|
||||
|
||||
# Run separate tests for real S3
|
||||
export ENABLE_REAL_S3_REMOTE_STORAGE=nonempty
|
||||
export REMOTE_STORAGE_S3_BUCKET=neon-github-public-dev
|
||||
export REMOTE_STORAGE_S3_REGION=eu-central-1
|
||||
# Avoid `$CARGO_FEATURES` since there's no `testing` feature in the e2e tests now
|
||||
cargo test $CARGO_FLAGS --package remote_storage --test test_real_s3
|
||||
|
||||
# Run separate tests for real Azure Blob Storage
|
||||
# XXX: replace region with `eu-central-1`-like region
|
||||
export ENABLE_REAL_AZURE_REMOTE_STORAGE=y
|
||||
export AZURE_STORAGE_ACCOUNT="${{ secrets.AZURE_STORAGE_ACCOUNT_DEV }}"
|
||||
export AZURE_STORAGE_ACCESS_KEY="${{ secrets.AZURE_STORAGE_ACCESS_KEY_DEV }}"
|
||||
export REMOTE_STORAGE_AZURE_CONTAINER="${{ vars.REMOTE_STORAGE_AZURE_CONTAINER }}"
|
||||
export REMOTE_STORAGE_AZURE_REGION="${{ vars.REMOTE_STORAGE_AZURE_REGION }}"
|
||||
# Avoid `$CARGO_FEATURES` since there's no `testing` feature in the e2e tests now
|
||||
cargo test $CARGO_FLAGS --package remote_storage --test test_real_azure
|
||||
|
||||
check-codestyle-rust-arm:
|
||||
timeout-minutes: 90
|
||||
runs-on: [ self-hosted, dev, arm64 ]
|
||||
|
||||
container:
|
||||
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/rust:pinned
|
||||
options: --init
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: true
|
||||
fetch-depth: 1
|
||||
|
||||
# Some of our rust modules use FFI and need those to be checked
|
||||
- name: Get postgres headers
|
||||
run: make postgres-headers -j$(nproc)
|
||||
|
||||
# cargo hack runs the given cargo subcommand (clippy in this case) for all feature combinations.
|
||||
# This will catch compiler & clippy warnings in all feature combinations.
|
||||
# TODO: use cargo hack for build and test as well, but, that's quite expensive.
|
||||
# NB: keep clippy args in sync with ./run_clippy.sh
|
||||
- run: |
|
||||
CLIPPY_COMMON_ARGS="$( source .neon_clippy_args; echo "$CLIPPY_COMMON_ARGS")"
|
||||
if [ "$CLIPPY_COMMON_ARGS" = "" ]; then
|
||||
echo "No clippy args found in .neon_clippy_args"
|
||||
exit 1
|
||||
fi
|
||||
echo "CLIPPY_COMMON_ARGS=${CLIPPY_COMMON_ARGS}" >> $GITHUB_ENV
|
||||
- name: Run cargo clippy (debug)
|
||||
run: cargo hack --feature-powerset clippy $CLIPPY_COMMON_ARGS
|
||||
- name: Run cargo clippy (release)
|
||||
run: cargo hack --feature-powerset clippy --release $CLIPPY_COMMON_ARGS
|
||||
|
||||
- name: Check documentation generation
|
||||
run: cargo doc --workspace --no-deps --document-private-items
|
||||
env:
|
||||
RUSTDOCFLAGS: "-Dwarnings -Arustdoc::private_intra_doc_links"
|
||||
|
||||
# Use `${{ !cancelled() }}` to run quck tests after the longer clippy run
|
||||
- name: Check formatting
|
||||
if: ${{ !cancelled() }}
|
||||
run: cargo fmt --all -- --check
|
||||
|
||||
# https://github.com/facebookincubator/cargo-guppy/tree/bec4e0eb29dcd1faac70b1b5360267fc02bf830e/tools/cargo-hakari#2-keep-the-workspace-hack-up-to-date-in-ci
|
||||
- name: Check rust dependencies
|
||||
if: ${{ !cancelled() }}
|
||||
run: |
|
||||
cargo hakari generate --diff # workspace-hack Cargo.toml is up-to-date
|
||||
cargo hakari manage-deps --dry-run # all workspace crates depend on workspace-hack
|
||||
|
||||
# https://github.com/EmbarkStudios/cargo-deny
|
||||
- name: Check rust licenses/bans/advisories/sources
|
||||
if: ${{ !cancelled() }}
|
||||
run: cargo deny check
|
||||
|
||||
gather-rust-build-stats:
|
||||
if: |
|
||||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-stats') ||
|
||||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-*') ||
|
||||
github.ref_name == 'main'
|
||||
if: github.ref_name == 'main' || contains(github.event.pull_request.labels.*.name, 'run-extra-build-stats')
|
||||
runs-on: [ self-hosted, gen3, large ]
|
||||
container:
|
||||
image: 369495373322.dkr.ecr.eu-central-1.amazonaws.com/rust:pinned
|
||||
@@ -306,7 +126,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
submodules: true
|
||||
fetch-depth: 1
|
||||
@@ -315,9 +135,6 @@ jobs:
|
||||
- name: Get postgres headers
|
||||
run: make postgres-headers -j$(nproc)
|
||||
|
||||
- name: Build walproposer-lib
|
||||
run: make walproposer-lib -j$(nproc)
|
||||
|
||||
- name: Produce the build stats
|
||||
run: cargo build --all --release --timings
|
||||
|
||||
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -2,7 +2,7 @@ name: Create Release Branch
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 7 * * 5'
|
||||
- cron: '0 7 * * 2'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
/libs/remote_storage/ @neondatabase/storage
|
||||
/libs/safekeeper_api/ @neondatabase/safekeepers
|
||||
/libs/vm_monitor/ @neondatabase/autoscaling @neondatabase/compute
|
||||
/pageserver/ @neondatabase/storage
|
||||
/pageserver/ @neondatabase/compute @neondatabase/storage
|
||||
/pgxn/ @neondatabase/compute
|
||||
/proxy/ @neondatabase/proxy
|
||||
/safekeeper/ @neondatabase/safekeepers
|
||||
|
||||
569
Cargo.lock
generated
569
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
28
Cargo.toml
28
Cargo.toml
@@ -26,7 +26,6 @@ members = [
|
||||
"libs/tracing-utils",
|
||||
"libs/postgres_ffi/wal_craft",
|
||||
"libs/vm_monitor",
|
||||
"libs/walproposer",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
@@ -36,19 +35,13 @@ license = "Apache-2.0"
|
||||
## All dependency versions, used in the project
|
||||
[workspace.dependencies]
|
||||
anyhow = { version = "1.0", features = ["backtrace"] }
|
||||
arc-swap = "1.6"
|
||||
async-compression = { version = "0.4.0", features = ["tokio", "gzip"] }
|
||||
azure_core = "0.16"
|
||||
azure_identity = "0.16"
|
||||
azure_storage = "0.16"
|
||||
azure_storage_blobs = "0.16"
|
||||
flate2 = "1.0.26"
|
||||
async-stream = "0.3"
|
||||
async-trait = "0.1"
|
||||
aws-config = { version = "0.56", default-features = false, features=["rustls"] }
|
||||
aws-sdk-s3 = "0.29"
|
||||
aws-smithy-http = "0.56"
|
||||
aws-smithy-async = { version = "0.56", default-features = false, features=["rt-tokio"] }
|
||||
aws-credential-types = "0.56"
|
||||
aws-types = "0.56"
|
||||
axum = { version = "0.6.20", features = ["ws"] }
|
||||
@@ -58,7 +51,6 @@ bindgen = "0.65"
|
||||
bstr = "1.0"
|
||||
byteorder = "1.4"
|
||||
bytes = "1.0"
|
||||
camino = "1.1.6"
|
||||
cfg-if = "1.0.0"
|
||||
chrono = { version = "0.4", default-features = false, features = ["clock"] }
|
||||
clap = { version = "4.0", features = ["derive"] }
|
||||
@@ -67,7 +59,7 @@ comfy-table = "6.1"
|
||||
const_format = "0.2"
|
||||
crc32c = "0.6"
|
||||
crossbeam-utils = "0.8.5"
|
||||
dashmap = { version = "5.5.0", features = ["raw-api"] }
|
||||
dashmap = "5.5.0"
|
||||
either = "1.8"
|
||||
enum-map = "2.4.2"
|
||||
enumset = "1.0.12"
|
||||
@@ -83,7 +75,6 @@ hex = "0.4"
|
||||
hex-literal = "0.4"
|
||||
hmac = "0.12.1"
|
||||
hostname = "0.3.1"
|
||||
http-types = { version = "2", default-features = false }
|
||||
humantime = "2.1"
|
||||
humantime-serde = "1.1.1"
|
||||
hyper = "0.14"
|
||||
@@ -126,7 +117,6 @@ sentry = { version = "0.31", default-features = false, features = ["backtrace",
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
serde_with = "2.0"
|
||||
serde_assert = "0.5.0"
|
||||
sha2 = "0.10.2"
|
||||
signal-hook = "0.3"
|
||||
smallvec = "1.11"
|
||||
@@ -136,7 +126,6 @@ strum_macros = "0.24"
|
||||
svg_fmt = "0.4.1"
|
||||
sync_wrapper = "0.1.2"
|
||||
tar = "0.4"
|
||||
task-local-extensions = "0.1.4"
|
||||
test-context = "0.1"
|
||||
thiserror = "1.0"
|
||||
tls-listener = { version = "0.7", features = ["rustls", "hyper-h1"] }
|
||||
@@ -165,11 +154,11 @@ env_logger = "0.10"
|
||||
log = "0.4"
|
||||
|
||||
## Libraries from neondatabase/ git forks, ideally with changes to be upstreamed
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="ce7260db5998fe27167da42503905a12e7ad9048" }
|
||||
postgres-native-tls = { git = "https://github.com/neondatabase/rust-postgres.git", rev="ce7260db5998fe27167da42503905a12e7ad9048" }
|
||||
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="ce7260db5998fe27167da42503905a12e7ad9048" }
|
||||
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev="ce7260db5998fe27167da42503905a12e7ad9048" }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="ce7260db5998fe27167da42503905a12e7ad9048" }
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" }
|
||||
postgres-native-tls = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" }
|
||||
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" }
|
||||
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" }
|
||||
|
||||
## Other git libraries
|
||||
heapless = { default-features=false, features=[], git = "https://github.com/japaric/heapless.git", rev = "644653bf3b831c6bb4963be2de24804acf5e5001" } # upstream release pending
|
||||
@@ -190,7 +179,6 @@ tenant_size_model = { version = "0.1", path = "./libs/tenant_size_model/" }
|
||||
tracing-utils = { version = "0.1", path = "./libs/tracing-utils/" }
|
||||
utils = { version = "0.1", path = "./libs/utils/" }
|
||||
vm_monitor = { version = "0.1", path = "./libs/vm_monitor/" }
|
||||
walproposer = { version = "0.1", path = "./libs/walproposer/" }
|
||||
|
||||
## Common library dependency
|
||||
workspace_hack = { version = "0.1", path = "./workspace_hack/" }
|
||||
@@ -199,14 +187,14 @@ workspace_hack = { version = "0.1", path = "./workspace_hack/" }
|
||||
criterion = "0.5.1"
|
||||
rcgen = "0.11"
|
||||
rstest = "0.18"
|
||||
camino-tempfile = "1.0.2"
|
||||
tempfile = "3.4"
|
||||
tonic-build = "0.9"
|
||||
|
||||
[patch.crates-io]
|
||||
|
||||
# This is only needed for proxy's tests.
|
||||
# TODO: we should probably fork `tokio-postgres-rustls` instead.
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="ce7260db5998fe27167da42503905a12e7ad9048" }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" }
|
||||
|
||||
################# Binary contents sections
|
||||
|
||||
|
||||
@@ -27,7 +27,6 @@ RUN set -e \
|
||||
FROM $REPOSITORY/$IMAGE:$TAG AS build
|
||||
WORKDIR /home/nonroot
|
||||
ARG GIT_VERSION=local
|
||||
ARG BUILD_TAG
|
||||
|
||||
# Enable https://github.com/paritytech/cachepot to cache Rust crates' compilation results in Docker builds.
|
||||
# Set up cachepot to use an AWS S3 bucket for cache results, to reuse it between `docker build` invocations.
|
||||
@@ -79,9 +78,9 @@ COPY --from=build --chown=neon:neon /home/nonroot/target/release/pg_sni_router
|
||||
COPY --from=build --chown=neon:neon /home/nonroot/target/release/pageserver /usr/local/bin
|
||||
COPY --from=build --chown=neon:neon /home/nonroot/target/release/pagectl /usr/local/bin
|
||||
COPY --from=build --chown=neon:neon /home/nonroot/target/release/safekeeper /usr/local/bin
|
||||
COPY --from=build --chown=neon:neon /home/nonroot/target/release/storage_broker /usr/local/bin
|
||||
COPY --from=build --chown=neon:neon /home/nonroot/target/release/storage_broker /usr/local/bin
|
||||
COPY --from=build --chown=neon:neon /home/nonroot/target/release/proxy /usr/local/bin
|
||||
COPY --from=build --chown=neon:neon /home/nonroot/target/release/neon_local /usr/local/bin
|
||||
COPY --from=build --chown=neon:neon /home/nonroot/target/release/neon_local /usr/local/bin
|
||||
|
||||
COPY --from=pg-build /home/nonroot/pg_install/v14 /usr/local/v14/
|
||||
COPY --from=pg-build /home/nonroot/pg_install/v15 /usr/local/v15/
|
||||
|
||||
@@ -224,8 +224,8 @@ RUN wget https://github.com/df7cb/postgresql-unit/archive/refs/tags/7.7.tar.gz -
|
||||
FROM build-deps AS vector-pg-build
|
||||
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
|
||||
RUN wget https://github.com/pgvector/pgvector/archive/refs/tags/v0.5.1.tar.gz -O pgvector.tar.gz && \
|
||||
echo "cc7a8e034a96e30a819911ac79d32f6bc47bdd1aa2de4d7d4904e26b83209dc8 pgvector.tar.gz" | sha256sum --check && \
|
||||
RUN wget https://github.com/pgvector/pgvector/archive/refs/tags/v0.5.0.tar.gz -O pgvector.tar.gz && \
|
||||
echo "d8aa3504b215467ca528525a6de12c3f85f9891b091ce0e5864dd8a9b757f77b pgvector.tar.gz" | sha256sum --check && \
|
||||
mkdir pgvector-src && cd pgvector-src && tar xvzf ../pgvector.tar.gz --strip-components=1 -C . && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
|
||||
@@ -368,8 +368,8 @@ RUN wget https://github.com/citusdata/postgresql-hll/archive/refs/tags/v2.18.tar
|
||||
FROM build-deps AS plpgsql-check-pg-build
|
||||
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
|
||||
RUN wget https://github.com/okbob/plpgsql_check/archive/refs/tags/v2.5.3.tar.gz -O plpgsql_check.tar.gz && \
|
||||
echo "6631ec3e7fb3769eaaf56e3dfedb829aa761abf163d13dba354b4c218508e1c0 plpgsql_check.tar.gz" | sha256sum --check && \
|
||||
RUN wget https://github.com/okbob/plpgsql_check/archive/refs/tags/v2.4.0.tar.gz -O plpgsql_check.tar.gz && \
|
||||
echo "9ba58387a279b35a3bfa39ee611e5684e6cddb2ba046ddb2c5190b3bd2ca254a plpgsql_check.tar.gz" | sha256sum --check && \
|
||||
mkdir plpgsql_check-src && cd plpgsql_check-src && tar xvzf ../plpgsql_check.tar.gz --strip-components=1 -C . && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
|
||||
@@ -615,7 +615,11 @@ RUN wget https://gitlab.com/dalibo/postgresql_anonymizer/-/archive/1.1.0/postgre
|
||||
#########################################################################################
|
||||
#
|
||||
# Layer "rust extensions"
|
||||
# This layer is used to build `pgrx` deps
|
||||
# This layer is used to build `pgx` deps
|
||||
#
|
||||
# FIXME: This needs to be updated to latest version of 'pgrx' (it was renamed from
|
||||
# 'pgx' to 'pgrx') for PostgreSQL 16. And that in turn requires bumping the pgx
|
||||
# dependency on all the rust extension that depend on it, too.
|
||||
#
|
||||
#########################################################################################
|
||||
FROM build-deps AS rust-extensions-build
|
||||
@@ -631,12 +635,22 @@ USER nonroot
|
||||
WORKDIR /home/nonroot
|
||||
ARG PG_VERSION
|
||||
|
||||
RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux-gnu/rustup-init && \
|
||||
RUN case "${PG_VERSION}" in \
|
||||
"v14" | "v15") \
|
||||
;; \
|
||||
"v16") \
|
||||
echo "TODO: Not yet supported for PostgreSQL 16. Need to update pgrx dependencies" && exit 0 \
|
||||
;; \
|
||||
*) \
|
||||
echo "unexpected PostgreSQL version ${PG_VERSION}" && exit 1 \
|
||||
;; \
|
||||
esac && \
|
||||
curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux-gnu/rustup-init && \
|
||||
chmod +x rustup-init && \
|
||||
./rustup-init -y --no-modify-path --profile minimal --default-toolchain stable && \
|
||||
rm rustup-init && \
|
||||
cargo install --locked --version 0.10.2 cargo-pgrx && \
|
||||
/bin/bash -c 'cargo pgrx init --pg${PG_VERSION:1}=/usr/local/pgsql/bin/pg_config'
|
||||
cargo install --locked --version 0.7.3 cargo-pgx && \
|
||||
/bin/bash -c 'cargo pgx init --pg${PG_VERSION:1}=/usr/local/pgsql/bin/pg_config'
|
||||
|
||||
USER root
|
||||
|
||||
@@ -650,11 +664,23 @@ USER root
|
||||
FROM rust-extensions-build AS pg-jsonschema-pg-build
|
||||
ARG PG_VERSION
|
||||
|
||||
RUN wget https://github.com/supabase/pg_jsonschema/archive/refs/tags/v0.2.0.tar.gz -O pg_jsonschema.tar.gz && \
|
||||
echo "9118fc508a6e231e7a39acaa6f066fcd79af17a5db757b47d2eefbe14f7794f0 pg_jsonschema.tar.gz" | sha256sum --check && \
|
||||
# caeab60d70b2fd3ae421ec66466a3abbb37b7ee6 made on 06/03/2023
|
||||
# there is no release tag yet, but we need it due to the superuser fix in the control file, switch to git tag after release >= 0.1.5
|
||||
RUN case "${PG_VERSION}" in \
|
||||
"v14" | "v15") \
|
||||
;; \
|
||||
"v16") \
|
||||
echo "TODO: Not yet supported for PostgreSQL 16. Need to update pgrx dependencies" && exit 0 \
|
||||
;; \
|
||||
*) \
|
||||
echo "unexpected PostgreSQL version \"${PG_VERSION}\"" && exit 1 \
|
||||
;; \
|
||||
esac && \
|
||||
wget https://github.com/supabase/pg_jsonschema/archive/caeab60d70b2fd3ae421ec66466a3abbb37b7ee6.tar.gz -O pg_jsonschema.tar.gz && \
|
||||
echo "54129ce2e7ee7a585648dbb4cef6d73f795d94fe72f248ac01119992518469a4 pg_jsonschema.tar.gz" | sha256sum --check && \
|
||||
mkdir pg_jsonschema-src && cd pg_jsonschema-src && tar xvzf ../pg_jsonschema.tar.gz --strip-components=1 -C . && \
|
||||
sed -i 's/pgrx = "0.10.2"/pgrx = { version = "0.10.2", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
|
||||
cargo pgrx install --release && \
|
||||
sed -i 's/pgx = "0.7.1"/pgx = { version = "0.7.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
|
||||
cargo pgx install --release && \
|
||||
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_jsonschema.control
|
||||
|
||||
#########################################################################################
|
||||
@@ -667,11 +693,26 @@ RUN wget https://github.com/supabase/pg_jsonschema/archive/refs/tags/v0.2.0.tar.
|
||||
FROM rust-extensions-build AS pg-graphql-pg-build
|
||||
ARG PG_VERSION
|
||||
|
||||
RUN wget https://github.com/supabase/pg_graphql/archive/refs/tags/v1.4.0.tar.gz -O pg_graphql.tar.gz && \
|
||||
echo "bd8dc7230282b3efa9ae5baf053a54151ed0e66881c7c53750e2d0c765776edc pg_graphql.tar.gz" | sha256sum --check && \
|
||||
# b4988843647450a153439be367168ed09971af85 made on 22/02/2023 (from remove-pgx-contrib-spiext branch)
|
||||
# Currently pgx version bump to >= 0.7.2 causes "call to unsafe function" compliation errors in
|
||||
# pgx-contrib-spiext. There is a branch that removes that dependency, so use it. It is on the
|
||||
# same 1.1 version we've used before.
|
||||
RUN case "${PG_VERSION}" in \
|
||||
"v14" | "v15") \
|
||||
;; \
|
||||
"v16") \
|
||||
echo "TODO: Not yet supported for PostgreSQL 16. Need to update pgrx dependencies" && exit 0 \
|
||||
;; \
|
||||
*) \
|
||||
echo "unexpected PostgreSQL version" && exit 1 \
|
||||
;; \
|
||||
esac && \
|
||||
wget https://github.com/yrashk/pg_graphql/archive/b4988843647450a153439be367168ed09971af85.tar.gz -O pg_graphql.tar.gz && \
|
||||
echo "0c7b0e746441b2ec24187d0e03555faf935c2159e2839bddd14df6dafbc8c9bd pg_graphql.tar.gz" | sha256sum --check && \
|
||||
mkdir pg_graphql-src && cd pg_graphql-src && tar xvzf ../pg_graphql.tar.gz --strip-components=1 -C . && \
|
||||
sed -i 's/pgrx = "=0.10.2"/pgrx = { version = "0.10.2", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
|
||||
cargo pgrx install --release && \
|
||||
sed -i 's/pgx = "~0.7.1"/pgx = { version = "0.7.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
|
||||
sed -i 's/pgx-tests = "~0.7.1"/pgx-tests = "0.7.3"/g' Cargo.toml && \
|
||||
cargo pgx install --release && \
|
||||
# it's needed to enable extension because it uses untrusted C language
|
||||
sed -i 's/superuser = false/superuser = true/g' /usr/local/pgsql/share/extension/pg_graphql.control && \
|
||||
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_graphql.control
|
||||
@@ -686,11 +727,21 @@ RUN wget https://github.com/supabase/pg_graphql/archive/refs/tags/v1.4.0.tar.gz
|
||||
FROM rust-extensions-build AS pg-tiktoken-pg-build
|
||||
ARG PG_VERSION
|
||||
|
||||
# 26806147b17b60763039c6a6878884c41a262318 made on 26/09/2023
|
||||
RUN wget https://github.com/kelvich/pg_tiktoken/archive/26806147b17b60763039c6a6878884c41a262318.tar.gz -O pg_tiktoken.tar.gz && \
|
||||
echo "e64e55aaa38c259512d3e27c572da22c4637418cf124caba904cd50944e5004e pg_tiktoken.tar.gz" | sha256sum --check && \
|
||||
# 801f84f08c6881c8aa30f405fafbf00eec386a72 made on 10/03/2023
|
||||
RUN case "${PG_VERSION}" in \
|
||||
"v14" | "v15") \
|
||||
;; \
|
||||
"v16") \
|
||||
echo "TODO: Not yet supported for PostgreSQL 16. Need to update pgrx dependencies" && exit 0 \
|
||||
;; \
|
||||
*) \
|
||||
echo "unexpected PostgreSQL version" && exit 1 \
|
||||
;; \
|
||||
esac && \
|
||||
wget https://github.com/kelvich/pg_tiktoken/archive/801f84f08c6881c8aa30f405fafbf00eec386a72.tar.gz -O pg_tiktoken.tar.gz && \
|
||||
echo "52f60ac800993a49aa8c609961842b611b6b1949717b69ce2ec9117117e16e4a pg_tiktoken.tar.gz" | sha256sum --check && \
|
||||
mkdir pg_tiktoken-src && cd pg_tiktoken-src && tar xvzf ../pg_tiktoken.tar.gz --strip-components=1 -C . && \
|
||||
cargo pgrx install --release && \
|
||||
cargo pgx install --release && \
|
||||
echo "trusted = true" >> /usr/local/pgsql/share/extension/pg_tiktoken.control
|
||||
|
||||
#########################################################################################
|
||||
@@ -703,34 +754,23 @@ RUN wget https://github.com/kelvich/pg_tiktoken/archive/26806147b17b60763039c6a6
|
||||
FROM rust-extensions-build AS pg-pgx-ulid-build
|
||||
ARG PG_VERSION
|
||||
|
||||
RUN wget https://github.com/pksunkara/pgx_ulid/archive/refs/tags/v0.1.3.tar.gz -O pgx_ulid.tar.gz && \
|
||||
echo "ee5db82945d2d9f2d15597a80cf32de9dca67b897f605beb830561705f12683c pgx_ulid.tar.gz" | sha256sum --check && \
|
||||
RUN case "${PG_VERSION}" in \
|
||||
"v14" | "v15") \
|
||||
;; \
|
||||
"v16") \
|
||||
echo "TODO: Not yet supported for PostgreSQL 16. Need to update pgrx dependencies" && exit 0 \
|
||||
;; \
|
||||
*) \
|
||||
echo "unexpected PostgreSQL version" && exit 1 \
|
||||
;; \
|
||||
esac && \
|
||||
wget https://github.com/pksunkara/pgx_ulid/archive/refs/tags/v0.1.0.tar.gz -O pgx_ulid.tar.gz && \
|
||||
echo "908b7358e6f846e87db508ae5349fb56a88ee6305519074b12f3d5b0ff09f791 pgx_ulid.tar.gz" | sha256sum --check && \
|
||||
mkdir pgx_ulid-src && cd pgx_ulid-src && tar xvzf ../pgx_ulid.tar.gz --strip-components=1 -C . && \
|
||||
echo "******************* Apply a patch for Postgres 16 support; delete in the next release ******************" && \
|
||||
wget https://github.com/pksunkara/pgx_ulid/commit/f84954cf63fc8c80d964ac970d9eceed3c791196.patch && \
|
||||
patch -p1 < f84954cf63fc8c80d964ac970d9eceed3c791196.patch && \
|
||||
echo "********************************************************************************************************" && \
|
||||
sed -i 's/pgrx = "=0.10.2"/pgrx = { version = "=0.10.2", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
|
||||
cargo pgrx install --release && \
|
||||
sed -i 's/pgx = "=0.7.3"/pgx = { version = "0.7.3", features = [ "unsafe-postgres" ] }/g' Cargo.toml && \
|
||||
cargo pgx install --release && \
|
||||
echo "trusted = true" >> /usr/local/pgsql/share/extension/ulid.control
|
||||
|
||||
#########################################################################################
|
||||
#
|
||||
# Layer "pg-wait-sampling-pg-build"
|
||||
# compile pg_wait_sampling extension
|
||||
#
|
||||
#########################################################################################
|
||||
FROM build-deps AS pg-wait-sampling-pg-build
|
||||
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
|
||||
ENV PATH "/usr/local/pgsql/bin/:$PATH"
|
||||
RUN wget https://github.com/postgrespro/pg_wait_sampling/archive/refs/tags/v1.1.5.tar.gz -O pg_wait_sampling.tar.gz && \
|
||||
echo 'a03da6a413f5652ce470a3635ed6ebba528c74cb26aa4cfced8aff8a8441f81ec6dd657ff62cd6ce96a4e6ce02cad9f2519ae9525367ece60497aa20faafde5c pg_wait_sampling.tar.gz' | sha512sum -c && \
|
||||
mkdir pg_wait_sampling-src && cd pg_wait_sampling-src && tar xvzf ../pg_wait_sampling.tar.gz --strip-components=1 -C . && \
|
||||
make USE_PGXS=1 -j $(getconf _NPROCESSORS_ONLN) && \
|
||||
make USE_PGXS=1 -j $(getconf _NPROCESSORS_ONLN) install && \
|
||||
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pg_wait_sampling.control
|
||||
|
||||
#########################################################################################
|
||||
#
|
||||
# Layer "neon-pg-ext-build"
|
||||
@@ -767,7 +807,6 @@ COPY --from=rdkit-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY --from=pg-uuidv7-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY --from=pg-roaringbitmap-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY --from=pg-embedding-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY --from=pg-wait-sampling-pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
COPY pgxn/ pgxn/
|
||||
|
||||
RUN make -j $(getconf _NPROCESSORS_ONLN) \
|
||||
|
||||
42
Makefile
42
Makefile
@@ -62,7 +62,7 @@ all: neon postgres neon-pg-ext
|
||||
#
|
||||
# The 'postgres_ffi' depends on the Postgres headers.
|
||||
.PHONY: neon
|
||||
neon: postgres-headers walproposer-lib
|
||||
neon: postgres-headers
|
||||
+@echo "Compiling Neon"
|
||||
$(CARGO_CMD_PREFIX) cargo build $(CARGO_BUILD_FLAGS)
|
||||
|
||||
@@ -72,10 +72,6 @@ neon: postgres-headers walproposer-lib
|
||||
#
|
||||
$(POSTGRES_INSTALL_DIR)/build/%/config.status:
|
||||
+@echo "Configuring Postgres $* build"
|
||||
@test -s $(ROOT_PROJECT_DIR)/vendor/postgres-$*/configure || { \
|
||||
echo "\nPostgres submodule not found in $(ROOT_PROJECT_DIR)/vendor/postgres-$*/, execute "; \
|
||||
echo "'git submodule update --init --recursive --depth 2 --progress .' in project root.\n"; \
|
||||
exit 1; }
|
||||
mkdir -p $(POSTGRES_INSTALL_DIR)/build/$*
|
||||
(cd $(POSTGRES_INSTALL_DIR)/build/$* && \
|
||||
env PATH="$(EXTRA_PATH_OVERRIDES):$$PATH" $(ROOT_PROJECT_DIR)/vendor/postgres-$*/configure \
|
||||
@@ -172,42 +168,6 @@ neon-pg-ext-clean-%:
|
||||
-C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \
|
||||
-f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile clean
|
||||
|
||||
# Build walproposer as a static library. walproposer source code is located
|
||||
# in the pgxn/neon directory.
|
||||
#
|
||||
# We also need to include libpgport.a and libpgcommon.a, because walproposer
|
||||
# uses some functions from those libraries.
|
||||
#
|
||||
# Some object files are removed from libpgport.a and libpgcommon.a because
|
||||
# they depend on openssl and other libraries that are not included in our
|
||||
# Rust build.
|
||||
.PHONY: walproposer-lib
|
||||
walproposer-lib: neon-pg-ext-v16
|
||||
+@echo "Compiling walproposer-lib"
|
||||
mkdir -p $(POSTGRES_INSTALL_DIR)/build/walproposer-lib
|
||||
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v16/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
|
||||
-C $(POSTGRES_INSTALL_DIR)/build/walproposer-lib \
|
||||
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile walproposer-lib
|
||||
cp $(POSTGRES_INSTALL_DIR)/v16/lib/libpgport.a $(POSTGRES_INSTALL_DIR)/build/walproposer-lib
|
||||
cp $(POSTGRES_INSTALL_DIR)/v16/lib/libpgcommon.a $(POSTGRES_INSTALL_DIR)/build/walproposer-lib
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
$(AR) d $(POSTGRES_INSTALL_DIR)/build/walproposer-lib/libpgport.a \
|
||||
pg_strong_random.o
|
||||
$(AR) d $(POSTGRES_INSTALL_DIR)/build/walproposer-lib/libpgcommon.a \
|
||||
pg_crc32c.o \
|
||||
hmac_openssl.o \
|
||||
cryptohash_openssl.o \
|
||||
scram-common.o \
|
||||
md5_common.o \
|
||||
checksum_helper.o
|
||||
endif
|
||||
|
||||
.PHONY: walproposer-lib-clean
|
||||
walproposer-lib-clean:
|
||||
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v16/bin/pg_config \
|
||||
-C $(POSTGRES_INSTALL_DIR)/build/walproposer-lib \
|
||||
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile clean
|
||||
|
||||
.PHONY: neon-pg-ext
|
||||
neon-pg-ext: \
|
||||
neon-pg-ext-v14 \
|
||||
|
||||
4
NOTICE
4
NOTICE
@@ -1,5 +1,5 @@
|
||||
Neon
|
||||
Copyright 2022 Neon Inc.
|
||||
|
||||
The PostgreSQL submodules in vendor/ are licensed under the PostgreSQL license.
|
||||
See vendor/postgres-vX/COPYRIGHT for details.
|
||||
The PostgreSQL submodules in vendor/postgres-v14 and vendor/postgres-v15 are licensed under the
|
||||
PostgreSQL license. See vendor/postgres-v14/COPYRIGHT and vendor/postgres-v15/COPYRIGHT.
|
||||
|
||||
@@ -156,7 +156,6 @@ fn main() -> Result<()> {
|
||||
let path = Path::new(sp);
|
||||
let file = File::open(path)?;
|
||||
spec = Some(serde_json::from_reader(file)?);
|
||||
live_config_allowed = true;
|
||||
} else if let Some(id) = compute_id {
|
||||
if let Some(cp_base) = control_plane_uri {
|
||||
live_config_allowed = true;
|
||||
@@ -278,26 +277,32 @@ fn main() -> Result<()> {
|
||||
if #[cfg(target_os = "linux")] {
|
||||
use std::env;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
let vm_monitor_addr = matches
|
||||
.get_one::<String>("vm-monitor-addr")
|
||||
.expect("--vm-monitor-addr should always be set because it has a default arg");
|
||||
use tracing::warn;
|
||||
let vm_monitor_addr = matches.get_one::<String>("vm-monitor-addr");
|
||||
let file_cache_connstr = matches.get_one::<String>("filecache-connstr");
|
||||
let cgroup = matches.get_one::<String>("cgroup");
|
||||
let file_cache_on_disk = matches.get_flag("file-cache-on-disk");
|
||||
|
||||
// Only make a runtime if we need to.
|
||||
// Note: it seems like you can make a runtime in an inner scope and
|
||||
// if you start a task in it it won't be dropped. However, make it
|
||||
// in the outermost scope just to be safe.
|
||||
let rt = if env::var_os("AUTOSCALING").is_some() {
|
||||
Some(
|
||||
let rt = match (env::var_os("AUTOSCALING"), vm_monitor_addr) {
|
||||
(None, None) => None,
|
||||
(None, Some(_)) => {
|
||||
warn!("--vm-monitor-addr option set but AUTOSCALING env var not present");
|
||||
None
|
||||
}
|
||||
(Some(_), None) => {
|
||||
panic!("AUTOSCALING env var present but --vm-monitor-addr option not set")
|
||||
}
|
||||
(Some(_), Some(_)) => Some(
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.worker_threads(4)
|
||||
.enable_all()
|
||||
.build()
|
||||
.expect("failed to create tokio runtime for monitor")
|
||||
)
|
||||
} else {
|
||||
None
|
||||
.expect("failed to create tokio runtime for monitor"),
|
||||
),
|
||||
};
|
||||
|
||||
// This token is used internally by the monitor to clean up all threads
|
||||
@@ -308,7 +313,8 @@ fn main() -> Result<()> {
|
||||
Box::leak(Box::new(vm_monitor::Args {
|
||||
cgroup: cgroup.cloned(),
|
||||
pgconnstr: file_cache_connstr.cloned(),
|
||||
addr: vm_monitor_addr.clone(),
|
||||
addr: vm_monitor_addr.cloned().unwrap(),
|
||||
file_cache_on_disk,
|
||||
})),
|
||||
token.clone(),
|
||||
))
|
||||
@@ -480,8 +486,6 @@ fn cli() -> clap::Command {
|
||||
.value_name("FILECACHE_CONNSTR"),
|
||||
)
|
||||
.arg(
|
||||
// DEPRECATED, NO LONGER DOES ANYTHING.
|
||||
// See https://github.com/neondatabase/cloud/issues/7516
|
||||
Arg::new("file-cache-on-disk")
|
||||
.long("file-cache-on-disk")
|
||||
.action(clap::ArgAction::SetTrue),
|
||||
|
||||
@@ -252,7 +252,7 @@ fn create_neon_superuser(spec: &ComputeSpec, client: &mut Client) -> Result<()>
|
||||
IF NOT EXISTS (
|
||||
SELECT FROM pg_catalog.pg_roles WHERE rolname = 'neon_superuser')
|
||||
THEN
|
||||
CREATE ROLE neon_superuser CREATEDB CREATEROLE NOLOGIN REPLICATION IN ROLE pg_read_all_data, pg_write_all_data;
|
||||
CREATE ROLE neon_superuser CREATEDB CREATEROLE NOLOGIN IN ROLE pg_read_all_data, pg_write_all_data;
|
||||
IF array_length(roles, 1) IS NOT NULL THEN
|
||||
EXECUTE format('GRANT neon_superuser TO %s',
|
||||
array_to_string(ARRAY(SELECT quote_ident(x) FROM unnest(roles) as x), ', '));
|
||||
@@ -692,11 +692,10 @@ impl ComputeNode {
|
||||
// Proceed with post-startup configuration. Note, that order of operations is important.
|
||||
let spec = &compute_state.pspec.as_ref().expect("spec must be set").spec;
|
||||
create_neon_superuser(spec, &mut client)?;
|
||||
cleanup_instance(&mut client)?;
|
||||
handle_roles(spec, &mut client)?;
|
||||
handle_databases(spec, &mut client)?;
|
||||
handle_role_deletions(spec, self.connstr.as_str(), &mut client)?;
|
||||
handle_grants(spec, &mut client, self.connstr.as_str())?;
|
||||
handle_grants(spec, self.connstr.as_str())?;
|
||||
handle_extensions(spec, &mut client)?;
|
||||
create_availability_check_data(&mut client)?;
|
||||
|
||||
@@ -710,12 +709,8 @@ impl ComputeNode {
|
||||
// `pg_ctl` for start / stop, so this just seems much easier to do as we already
|
||||
// have opened connection to Postgres and superuser access.
|
||||
#[instrument(skip_all)]
|
||||
fn pg_reload_conf(&self) -> Result<()> {
|
||||
let pgctl_bin = Path::new(&self.pgbin).parent().unwrap().join("pg_ctl");
|
||||
Command::new(pgctl_bin)
|
||||
.args(["reload", "-D", &self.pgdata])
|
||||
.output()
|
||||
.expect("cannot run pg_ctl process");
|
||||
fn pg_reload_conf(&self, client: &mut Client) -> Result<()> {
|
||||
client.simple_query("SELECT pg_reload_conf()")?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -728,19 +723,18 @@ impl ComputeNode {
|
||||
// Write new config
|
||||
let pgdata_path = Path::new(&self.pgdata);
|
||||
config::write_postgres_conf(&pgdata_path.join("postgresql.conf"), &spec, None)?;
|
||||
self.pg_reload_conf()?;
|
||||
|
||||
let mut client = Client::connect(self.connstr.as_str(), NoTls)?;
|
||||
self.pg_reload_conf(&mut client)?;
|
||||
|
||||
// Proceed with post-startup configuration. Note, that order of operations is important.
|
||||
// Disable DDL forwarding because control plane already knows about these roles/databases.
|
||||
if spec.mode == ComputeMode::Primary {
|
||||
client.simple_query("SET neon.forward_ddl = false")?;
|
||||
cleanup_instance(&mut client)?;
|
||||
handle_roles(&spec, &mut client)?;
|
||||
handle_databases(&spec, &mut client)?;
|
||||
handle_role_deletions(&spec, self.connstr.as_str(), &mut client)?;
|
||||
handle_grants(&spec, &mut client, self.connstr.as_str())?;
|
||||
handle_grants(&spec, self.connstr.as_str())?;
|
||||
handle_extensions(&spec, &mut client)?;
|
||||
}
|
||||
|
||||
@@ -1045,7 +1039,7 @@ LIMIT 100",
|
||||
let remote_extensions = spec
|
||||
.remote_extensions
|
||||
.as_ref()
|
||||
.ok_or(anyhow::anyhow!("Remote extensions are not configured"))?;
|
||||
.ok_or(anyhow::anyhow!("Remote extensions are not configured",))?;
|
||||
|
||||
info!("parse shared_preload_libraries from spec.cluster.settings");
|
||||
let mut libs_vec = Vec::new();
|
||||
|
||||
@@ -78,7 +78,7 @@ use regex::Regex;
|
||||
use remote_storage::*;
|
||||
use serde_json;
|
||||
use std::io::Read;
|
||||
use std::num::NonZeroUsize;
|
||||
use std::num::{NonZeroU32, NonZeroUsize};
|
||||
use std::path::Path;
|
||||
use std::str;
|
||||
use tar::Archive;
|
||||
@@ -133,6 +133,45 @@ fn parse_pg_version(human_version: &str) -> &str {
|
||||
panic!("Unsuported postgres version {human_version}");
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::parse_pg_version;
|
||||
|
||||
#[test]
|
||||
fn test_parse_pg_version() {
|
||||
assert_eq!(parse_pg_version("PostgreSQL 15.4"), "v15");
|
||||
assert_eq!(parse_pg_version("PostgreSQL 15.14"), "v15");
|
||||
assert_eq!(
|
||||
parse_pg_version("PostgreSQL 15.4 (Ubuntu 15.4-0ubuntu0.23.04.1)"),
|
||||
"v15"
|
||||
);
|
||||
|
||||
assert_eq!(parse_pg_version("PostgreSQL 14.15"), "v14");
|
||||
assert_eq!(parse_pg_version("PostgreSQL 14.0"), "v14");
|
||||
assert_eq!(
|
||||
parse_pg_version("PostgreSQL 14.9 (Debian 14.9-1.pgdg120+1"),
|
||||
"v14"
|
||||
);
|
||||
|
||||
assert_eq!(parse_pg_version("PostgreSQL 16devel"), "v16");
|
||||
assert_eq!(parse_pg_version("PostgreSQL 16beta1"), "v16");
|
||||
assert_eq!(parse_pg_version("PostgreSQL 16rc2"), "v16");
|
||||
assert_eq!(parse_pg_version("PostgreSQL 16extra"), "v16");
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_parse_pg_unsupported_version() {
|
||||
parse_pg_version("PostgreSQL 13.14");
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_parse_pg_incorrect_version_format() {
|
||||
parse_pg_version("PostgreSQL 14");
|
||||
}
|
||||
}
|
||||
|
||||
// download the archive for a given extension,
|
||||
// unzip it, and place files in the appropriate locations (share/lib)
|
||||
pub async fn download_extension(
|
||||
@@ -242,46 +281,9 @@ pub fn init_remote_storage(remote_ext_config: &str) -> anyhow::Result<GenericRem
|
||||
max_keys_per_list_response: None,
|
||||
};
|
||||
let config = RemoteStorageConfig {
|
||||
max_concurrent_syncs: NonZeroUsize::new(100).expect("100 != 0"),
|
||||
max_sync_errors: NonZeroU32::new(100).expect("100 != 0"),
|
||||
storage: RemoteStorageKind::AwsS3(config),
|
||||
};
|
||||
GenericRemoteStorage::from_config(&config)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::parse_pg_version;
|
||||
|
||||
#[test]
|
||||
fn test_parse_pg_version() {
|
||||
assert_eq!(parse_pg_version("PostgreSQL 15.4"), "v15");
|
||||
assert_eq!(parse_pg_version("PostgreSQL 15.14"), "v15");
|
||||
assert_eq!(
|
||||
parse_pg_version("PostgreSQL 15.4 (Ubuntu 15.4-0ubuntu0.23.04.1)"),
|
||||
"v15"
|
||||
);
|
||||
|
||||
assert_eq!(parse_pg_version("PostgreSQL 14.15"), "v14");
|
||||
assert_eq!(parse_pg_version("PostgreSQL 14.0"), "v14");
|
||||
assert_eq!(
|
||||
parse_pg_version("PostgreSQL 14.9 (Debian 14.9-1.pgdg120+1"),
|
||||
"v14"
|
||||
);
|
||||
|
||||
assert_eq!(parse_pg_version("PostgreSQL 16devel"), "v16");
|
||||
assert_eq!(parse_pg_version("PostgreSQL 16beta1"), "v16");
|
||||
assert_eq!(parse_pg_version("PostgreSQL 16rc2"), "v16");
|
||||
assert_eq!(parse_pg_version("PostgreSQL 16extra"), "v16");
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_parse_pg_unsupported_version() {
|
||||
parse_pg_version("PostgreSQL 13.14");
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_parse_pg_incorrect_version_format() {
|
||||
parse_pg_version("PostgreSQL 14");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//!
|
||||
//! Various tools and helpers to handle cluster / compute node (Postgres)
|
||||
//! configuration.
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
//!
|
||||
pub mod checker;
|
||||
pub mod config;
|
||||
pub mod configurator;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::sync::Arc;
|
||||
use std::{thread, time::Duration};
|
||||
use std::{thread, time};
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use postgres::{Client, NoTls};
|
||||
@@ -7,7 +7,7 @@ use tracing::{debug, info};
|
||||
|
||||
use crate::compute::ComputeNode;
|
||||
|
||||
const MONITOR_CHECK_INTERVAL: Duration = Duration::from_millis(500);
|
||||
const MONITOR_CHECK_INTERVAL: u64 = 500; // milliseconds
|
||||
|
||||
// Spin in a loop and figure out the last activity time in the Postgres.
|
||||
// Then update it in the shared state. This function never errors out.
|
||||
@@ -17,12 +17,13 @@ fn watch_compute_activity(compute: &ComputeNode) {
|
||||
let connstr = compute.connstr.as_str();
|
||||
// Define `client` outside of the loop to reuse existing connection if it's active.
|
||||
let mut client = Client::connect(connstr, NoTls);
|
||||
let timeout = time::Duration::from_millis(MONITOR_CHECK_INTERVAL);
|
||||
|
||||
info!("watching Postgres activity at {}", connstr);
|
||||
|
||||
loop {
|
||||
// Should be outside of the write lock to allow others to read while we sleep.
|
||||
thread::sleep(MONITOR_CHECK_INTERVAL);
|
||||
thread::sleep(timeout);
|
||||
|
||||
match &mut client {
|
||||
Ok(cli) => {
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Write;
|
||||
use std::fs;
|
||||
use std::fs::File;
|
||||
@@ -193,16 +192,11 @@ impl Escaping for PgIdent {
|
||||
/// Build a list of existing Postgres roles
|
||||
pub fn get_existing_roles(xact: &mut Transaction<'_>) -> Result<Vec<Role>> {
|
||||
let postgres_roles = xact
|
||||
.query(
|
||||
"SELECT rolname, rolpassword, rolreplication, rolbypassrls FROM pg_catalog.pg_authid",
|
||||
&[],
|
||||
)?
|
||||
.query("SELECT rolname, rolpassword FROM pg_catalog.pg_authid", &[])?
|
||||
.iter()
|
||||
.map(|row| Role {
|
||||
name: row.get("rolname"),
|
||||
encrypted_password: row.get("rolpassword"),
|
||||
replication: Some(row.get("rolreplication")),
|
||||
bypassrls: Some(row.get("rolbypassrls")),
|
||||
options: None,
|
||||
})
|
||||
.collect();
|
||||
@@ -211,37 +205,22 @@ pub fn get_existing_roles(xact: &mut Transaction<'_>) -> Result<Vec<Role>> {
|
||||
}
|
||||
|
||||
/// Build a list of existing Postgres databases
|
||||
pub fn get_existing_dbs(client: &mut Client) -> Result<HashMap<String, Database>> {
|
||||
// `pg_database.datconnlimit = -2` means that the database is in the
|
||||
// invalid state. See:
|
||||
// https://github.com/postgres/postgres/commit/a4b4cc1d60f7e8ccfcc8ff8cb80c28ee411ad9a9
|
||||
let postgres_dbs: Vec<Database> = client
|
||||
pub fn get_existing_dbs(client: &mut Client) -> Result<Vec<Database>> {
|
||||
let postgres_dbs = client
|
||||
.query(
|
||||
"SELECT
|
||||
datname AS name,
|
||||
datdba::regrole::text AS owner,
|
||||
NOT datallowconn AS restrict_conn,
|
||||
datconnlimit = - 2 AS invalid
|
||||
FROM
|
||||
pg_catalog.pg_database;",
|
||||
"SELECT datname, datdba::regrole::text as owner
|
||||
FROM pg_catalog.pg_database;",
|
||||
&[],
|
||||
)?
|
||||
.iter()
|
||||
.map(|row| Database {
|
||||
name: row.get("name"),
|
||||
name: row.get("datname"),
|
||||
owner: row.get("owner"),
|
||||
restrict_conn: row.get("restrict_conn"),
|
||||
invalid: row.get("invalid"),
|
||||
options: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
let dbs_map = postgres_dbs
|
||||
.iter()
|
||||
.map(|db| (db.name.clone(), db.clone()))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
Ok(dbs_map)
|
||||
Ok(postgres_dbs)
|
||||
}
|
||||
|
||||
/// Wait for Postgres to become ready to accept connections. It's ready to
|
||||
|
||||
@@ -13,7 +13,7 @@ use crate::params::PG_HBA_ALL_MD5;
|
||||
use crate::pg_helpers::*;
|
||||
|
||||
use compute_api::responses::{ControlPlaneComputeStatus, ControlPlaneSpecResponse};
|
||||
use compute_api::spec::{ComputeSpec, PgIdent, Role};
|
||||
use compute_api::spec::{ComputeSpec, Database, PgIdent, Role};
|
||||
|
||||
// Do control plane request and return response if any. In case of error it
|
||||
// returns a bool flag indicating whether it makes sense to retry the request
|
||||
@@ -24,7 +24,7 @@ fn do_control_plane_request(
|
||||
) -> Result<ControlPlaneSpecResponse, (bool, String)> {
|
||||
let resp = reqwest::blocking::Client::new()
|
||||
.get(uri)
|
||||
.header("Authorization", format!("Bearer {}", jwt))
|
||||
.header("Authorization", jwt)
|
||||
.send()
|
||||
.map_err(|e| {
|
||||
(
|
||||
@@ -68,7 +68,7 @@ pub fn get_spec_from_control_plane(
|
||||
base_uri: &str,
|
||||
compute_id: &str,
|
||||
) -> Result<Option<ComputeSpec>> {
|
||||
let cp_uri = format!("{base_uri}/compute/api/v2/computes/{compute_id}/spec");
|
||||
let cp_uri = format!("{base_uri}/management/api/v2/computes/{compute_id}/spec");
|
||||
let jwt: String = match std::env::var("NEON_CONTROL_PLANE_TOKEN") {
|
||||
Ok(v) => v,
|
||||
Err(_) => "".to_string(),
|
||||
@@ -161,38 +161,6 @@ pub fn add_standby_signal(pgdata_path: &Path) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compute could be unexpectedly shut down, for example, during the
|
||||
/// database dropping. This leaves the database in the invalid state,
|
||||
/// which prevents new db creation with the same name. This function
|
||||
/// will clean it up before proceeding with catalog updates. All
|
||||
/// possible future cleanup operations may go here too.
|
||||
#[instrument(skip_all)]
|
||||
pub fn cleanup_instance(client: &mut Client) -> Result<()> {
|
||||
let existing_dbs = get_existing_dbs(client)?;
|
||||
|
||||
for (_, db) in existing_dbs {
|
||||
if db.invalid {
|
||||
// After recent commit in Postgres, interrupted DROP DATABASE
|
||||
// leaves the database in the invalid state. According to the
|
||||
// commit message, the only option for user is to drop it again.
|
||||
// See:
|
||||
// https://github.com/postgres/postgres/commit/a4b4cc1d60f7e8ccfcc8ff8cb80c28ee411ad9a9
|
||||
//
|
||||
// Postgres Neon extension is done the way, that db is de-registered
|
||||
// in the control plane metadata only after it is dropped. So there is
|
||||
// a chance that it still thinks that db should exist. This means
|
||||
// that it will be re-created by `handle_databases()`. Yet, it's fine
|
||||
// as user can just repeat drop (in vanilla Postgres they would need
|
||||
// to do the same, btw).
|
||||
let query = format!("DROP DATABASE IF EXISTS {}", db.name.pg_quote());
|
||||
info!("dropping invalid database {}", db.name);
|
||||
client.execute(query.as_str(), &[])?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Given a cluster spec json and open transaction it handles roles creation,
|
||||
/// deletion and update.
|
||||
#[instrument(skip_all)]
|
||||
@@ -265,8 +233,6 @@ pub fn handle_roles(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
|
||||
let action = if let Some(r) = pg_role {
|
||||
if (r.encrypted_password.is_none() && role.encrypted_password.is_some())
|
||||
|| (r.encrypted_password.is_some() && role.encrypted_password.is_none())
|
||||
|| !r.bypassrls.unwrap_or(false)
|
||||
|| !r.replication.unwrap_or(false)
|
||||
{
|
||||
RoleAction::Update
|
||||
} else if let Some(pg_pwd) = &r.encrypted_password {
|
||||
@@ -298,14 +264,13 @@ pub fn handle_roles(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
|
||||
match action {
|
||||
RoleAction::None => {}
|
||||
RoleAction::Update => {
|
||||
let mut query: String =
|
||||
format!("ALTER ROLE {} BYPASSRLS REPLICATION", name.pg_quote());
|
||||
let mut query: String = format!("ALTER ROLE {} ", name.pg_quote());
|
||||
query.push_str(&role.to_pg_options());
|
||||
xact.execute(query.as_str(), &[])?;
|
||||
}
|
||||
RoleAction::Create => {
|
||||
let mut query: String = format!(
|
||||
"CREATE ROLE {} CREATEROLE CREATEDB BYPASSRLS REPLICATION IN ROLE neon_superuser",
|
||||
"CREATE ROLE {} CREATEROLE CREATEDB BYPASSRLS IN ROLE neon_superuser",
|
||||
name.pg_quote()
|
||||
);
|
||||
info!("role create query: '{}'", &query);
|
||||
@@ -414,13 +379,13 @@ fn reassign_owned_objects(spec: &ComputeSpec, connstr: &str, role_name: &PgIdent
|
||||
/// which together provide us idempotency.
|
||||
#[instrument(skip_all)]
|
||||
pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
|
||||
let existing_dbs = get_existing_dbs(client)?;
|
||||
let existing_dbs: Vec<Database> = get_existing_dbs(client)?;
|
||||
|
||||
// Print a list of existing Postgres databases (only in debug mode)
|
||||
if span_enabled!(Level::INFO) {
|
||||
info!("postgres databases:");
|
||||
for (dbname, db) in &existing_dbs {
|
||||
info!(" {}:{}", dbname, db.owner);
|
||||
for r in &existing_dbs {
|
||||
info!(" {}:{}", r.name, r.owner);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -474,7 +439,8 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
|
||||
"rename_db" => {
|
||||
let new_name = op.new_name.as_ref().unwrap();
|
||||
|
||||
if existing_dbs.get(&op.name).is_some() {
|
||||
// XXX: with a limited number of roles it is fine, but consider making it a HashMap
|
||||
if existing_dbs.iter().any(|r| r.name == op.name) {
|
||||
let query: String = format!(
|
||||
"ALTER DATABASE {} RENAME TO {}",
|
||||
op.name.pg_quote(),
|
||||
@@ -491,12 +457,14 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
|
||||
}
|
||||
|
||||
// Refresh Postgres databases info to handle possible renames
|
||||
let existing_dbs = get_existing_dbs(client)?;
|
||||
let existing_dbs: Vec<Database> = get_existing_dbs(client)?;
|
||||
|
||||
info!("cluster spec databases:");
|
||||
for db in &spec.cluster.databases {
|
||||
let name = &db.name;
|
||||
let pg_db = existing_dbs.get(name);
|
||||
|
||||
// XXX: with a limited number of databases it is fine, but consider making it a HashMap
|
||||
let pg_db = existing_dbs.iter().find(|r| r.name == *name);
|
||||
|
||||
enum DatabaseAction {
|
||||
None,
|
||||
@@ -562,32 +530,13 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
|
||||
/// Grant CREATE ON DATABASE to the database owner and do some other alters and grants
|
||||
/// to allow users creating trusted extensions and re-creating `public` schema, for example.
|
||||
#[instrument(skip_all)]
|
||||
pub fn handle_grants(spec: &ComputeSpec, client: &mut Client, connstr: &str) -> Result<()> {
|
||||
info!("modifying database permissions");
|
||||
let existing_dbs = get_existing_dbs(client)?;
|
||||
pub fn handle_grants(spec: &ComputeSpec, connstr: &str) -> Result<()> {
|
||||
info!("cluster spec grants:");
|
||||
|
||||
// Do some per-database access adjustments. We'd better do this at db creation time,
|
||||
// but CREATE DATABASE isn't transactional. So we cannot create db + do some grants
|
||||
// atomically.
|
||||
for db in &spec.cluster.databases {
|
||||
match existing_dbs.get(&db.name) {
|
||||
Some(pg_db) => {
|
||||
if pg_db.restrict_conn || pg_db.invalid {
|
||||
info!(
|
||||
"skipping grants for db {} (invalid: {}, connections not allowed: {})",
|
||||
db.name, pg_db.invalid, pg_db.restrict_conn
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
None => {
|
||||
bail!(
|
||||
"database {} doesn't exist in Postgres after handle_databases()",
|
||||
db.name
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let mut conf = Config::from_str(connstr)?;
|
||||
conf.dbname(&db.name);
|
||||
|
||||
@@ -626,11 +575,6 @@ pub fn handle_grants(spec: &ComputeSpec, client: &mut Client, connstr: &str) ->
|
||||
|
||||
// Explicitly grant CREATE ON SCHEMA PUBLIC to the web_access user.
|
||||
// This is needed because since postgres 15 this privilege is removed by default.
|
||||
// TODO: web_access isn't created for almost 1 year. It could be that we have
|
||||
// active users of 1 year old projects, but hopefully not, so check it and
|
||||
// remove this code if possible. The worst thing that could happen is that
|
||||
// user won't be able to use public schema in NEW databases created in the
|
||||
// very OLD project.
|
||||
let grant_query = "DO $$\n\
|
||||
BEGIN\n\
|
||||
IF EXISTS(\n\
|
||||
|
||||
@@ -28,7 +28,7 @@ mod pg_helpers_tests {
|
||||
assert_eq!(
|
||||
spec.cluster.settings.as_pg_settings(),
|
||||
r#"fsync = off
|
||||
wal_level = logical
|
||||
wal_level = replica
|
||||
hot_standby = on
|
||||
neon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501'
|
||||
wal_log_hints = on
|
||||
|
||||
@@ -6,7 +6,6 @@ license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
camino.workspace = true
|
||||
clap.workspace = true
|
||||
comfy-table.workspace = true
|
||||
git-version.workspace = true
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::{background_process, local_env::LocalEnv};
|
||||
use anyhow::anyhow;
|
||||
use camino::Utf8PathBuf;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
use std::{path::PathBuf, process::Child};
|
||||
use utils::id::{NodeId, TenantId};
|
||||
|
||||
@@ -9,15 +9,16 @@ pub struct AttachmentService {
|
||||
env: LocalEnv,
|
||||
listen: String,
|
||||
path: PathBuf,
|
||||
client: reqwest::blocking::Client,
|
||||
}
|
||||
|
||||
const COMMAND: &str = "attachment_service";
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct AttachHookRequest {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub tenant_id: TenantId,
|
||||
pub node_id: Option<NodeId>,
|
||||
pub pageserver_id: Option<NodeId>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
@@ -25,16 +26,6 @@ pub struct AttachHookResponse {
|
||||
pub gen: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct InspectRequest {
|
||||
pub tenant_id: TenantId,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct InspectResponse {
|
||||
pub attachment: Option<(u32, NodeId)>,
|
||||
}
|
||||
|
||||
impl AttachmentService {
|
||||
pub fn from_env(env: &LocalEnv) -> Self {
|
||||
let path = env.base_data_dir.join("attachments.json");
|
||||
@@ -53,15 +44,11 @@ impl AttachmentService {
|
||||
env: env.clone(),
|
||||
path,
|
||||
listen,
|
||||
client: reqwest::blocking::ClientBuilder::new()
|
||||
.build()
|
||||
.expect("Failed to construct http client"),
|
||||
}
|
||||
}
|
||||
|
||||
fn pid_file(&self) -> Utf8PathBuf {
|
||||
Utf8PathBuf::from_path_buf(self.env.base_data_dir.join("attachment_service.pid"))
|
||||
.expect("non-Unicode path")
|
||||
fn pid_file(&self) -> PathBuf {
|
||||
self.env.base_data_dir.join("attachment_service.pid")
|
||||
}
|
||||
|
||||
pub fn start(&self) -> anyhow::Result<Child> {
|
||||
@@ -96,15 +83,18 @@ impl AttachmentService {
|
||||
.control_plane_api
|
||||
.clone()
|
||||
.unwrap()
|
||||
.join("attach-hook")
|
||||
.join("attach_hook")
|
||||
.unwrap();
|
||||
let client = reqwest::blocking::ClientBuilder::new()
|
||||
.build()
|
||||
.expect("Failed to construct http client");
|
||||
|
||||
let request = AttachHookRequest {
|
||||
tenant_id,
|
||||
node_id: Some(pageserver_id),
|
||||
pageserver_id: Some(pageserver_id),
|
||||
};
|
||||
|
||||
let response = self.client.post(url).json(&request).send()?;
|
||||
let response = client.post(url).json(&request).send()?;
|
||||
if response.status() != StatusCode::OK {
|
||||
return Err(anyhow!("Unexpected status {}", response.status()));
|
||||
}
|
||||
@@ -112,26 +102,4 @@ impl AttachmentService {
|
||||
let response = response.json::<AttachHookResponse>()?;
|
||||
Ok(response.gen)
|
||||
}
|
||||
|
||||
pub fn inspect(&self, tenant_id: TenantId) -> anyhow::Result<Option<(u32, NodeId)>> {
|
||||
use hyper::StatusCode;
|
||||
|
||||
let url = self
|
||||
.env
|
||||
.control_plane_api
|
||||
.clone()
|
||||
.unwrap()
|
||||
.join("inspect")
|
||||
.unwrap();
|
||||
|
||||
let request = InspectRequest { tenant_id };
|
||||
|
||||
let response = self.client.post(url).json(&request).send()?;
|
||||
if response.status() != StatusCode::OK {
|
||||
return Err(anyhow!("Unexpected status {}", response.status()));
|
||||
}
|
||||
|
||||
let response = response.json::<InspectResponse>()?;
|
||||
Ok(response.attachment)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,13 +16,12 @@ use std::ffi::OsStr;
|
||||
use std::io::Write;
|
||||
use std::os::unix::prelude::AsRawFd;
|
||||
use std::os::unix::process::CommandExt;
|
||||
use std::path::Path;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::{Child, Command};
|
||||
use std::time::Duration;
|
||||
use std::{fs, io, thread};
|
||||
|
||||
use anyhow::Context;
|
||||
use camino::{Utf8Path, Utf8PathBuf};
|
||||
use nix::errno::Errno;
|
||||
use nix::fcntl::{FcntlArg, FdFlag};
|
||||
use nix::sys::signal::{kill, Signal};
|
||||
@@ -46,9 +45,9 @@ const NOTICE_AFTER_RETRIES: u64 = 50;
|
||||
/// it itself.
|
||||
pub enum InitialPidFile<'t> {
|
||||
/// Create a pidfile, to allow future CLI invocations to manipulate the process.
|
||||
Create(&'t Utf8Path),
|
||||
Create(&'t Path),
|
||||
/// The process will create the pidfile itself, need to wait for that event.
|
||||
Expect(&'t Utf8Path),
|
||||
Expect(&'t Path),
|
||||
}
|
||||
|
||||
/// Start a background child process using the parameters given.
|
||||
@@ -86,7 +85,7 @@ where
|
||||
.stdout(process_log_file)
|
||||
.stderr(same_file_for_stderr)
|
||||
.args(args);
|
||||
let filled_cmd = fill_remote_storage_secrets_vars(fill_rust_env_vars(background_command));
|
||||
let filled_cmd = fill_aws_secrets_vars(fill_rust_env_vars(background_command));
|
||||
filled_cmd.envs(envs);
|
||||
|
||||
let pid_file_to_check = match initial_pid_file {
|
||||
@@ -138,11 +137,7 @@ where
|
||||
}
|
||||
|
||||
/// Stops the process, using the pid file given. Returns Ok also if the process is already not running.
|
||||
pub fn stop_process(
|
||||
immediate: bool,
|
||||
process_name: &str,
|
||||
pid_file: &Utf8Path,
|
||||
) -> anyhow::Result<()> {
|
||||
pub fn stop_process(immediate: bool, process_name: &str, pid_file: &Path) -> anyhow::Result<()> {
|
||||
let pid = match pid_file::read(pid_file)
|
||||
.with_context(|| format!("read pid_file {pid_file:?}"))?
|
||||
{
|
||||
@@ -238,13 +233,11 @@ fn fill_rust_env_vars(cmd: &mut Command) -> &mut Command {
|
||||
filled_cmd
|
||||
}
|
||||
|
||||
fn fill_remote_storage_secrets_vars(mut cmd: &mut Command) -> &mut Command {
|
||||
fn fill_aws_secrets_vars(mut cmd: &mut Command) -> &mut Command {
|
||||
for env_key in [
|
||||
"AWS_ACCESS_KEY_ID",
|
||||
"AWS_SECRET_ACCESS_KEY",
|
||||
"AWS_SESSION_TOKEN",
|
||||
"AZURE_STORAGE_ACCOUNT",
|
||||
"AZURE_STORAGE_ACCESS_KEY",
|
||||
] {
|
||||
if let Ok(value) = std::env::var(env_key) {
|
||||
cmd = cmd.env(env_key, value);
|
||||
@@ -259,10 +252,10 @@ fn fill_remote_storage_secrets_vars(mut cmd: &mut Command) -> &mut Command {
|
||||
/// will remain held until the cmd exits.
|
||||
fn pre_exec_create_pidfile<P>(cmd: &mut Command, path: P) -> &mut Command
|
||||
where
|
||||
P: Into<Utf8PathBuf>,
|
||||
P: Into<PathBuf>,
|
||||
{
|
||||
let path: Utf8PathBuf = path.into();
|
||||
// SAFETY:
|
||||
let path: PathBuf = path.into();
|
||||
// SAFETY
|
||||
// pre_exec is marked unsafe because it runs between fork and exec.
|
||||
// Why is that dangerous in various ways?
|
||||
// Long answer: https://github.com/rust-lang/rust/issues/39575
|
||||
@@ -318,7 +311,7 @@ where
|
||||
|
||||
fn process_started<F>(
|
||||
pid: Pid,
|
||||
pid_file_to_check: Option<&Utf8Path>,
|
||||
pid_file_to_check: Option<&Path>,
|
||||
status_check: &F,
|
||||
) -> anyhow::Result<bool>
|
||||
where
|
||||
|
||||
@@ -12,9 +12,7 @@ use hyper::{Body, Request, Response};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use utils::http::endpoint::request_span;
|
||||
use utils::logging::{self, LogFormat};
|
||||
use utils::signals::{ShutdownSignals, Signal};
|
||||
|
||||
use utils::{
|
||||
http::{
|
||||
@@ -32,9 +30,7 @@ use pageserver_api::control_api::{
|
||||
ValidateResponseTenant,
|
||||
};
|
||||
|
||||
use control_plane::attachment_service::{
|
||||
AttachHookRequest, AttachHookResponse, InspectRequest, InspectResponse,
|
||||
};
|
||||
use control_plane::attachment_service::{AttachHookRequest, AttachHookResponse};
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
@@ -174,7 +170,7 @@ async fn handle_re_attach(mut req: Request<Body>) -> Result<Response<Body>, ApiE
|
||||
state.generation += 1;
|
||||
response.tenants.push(ReAttachResponseTenant {
|
||||
id: *t,
|
||||
gen: state.generation,
|
||||
generation: state.generation,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -220,31 +216,14 @@ async fn handle_attach_hook(mut req: Request<Body>) -> Result<Response<Body>, Ap
|
||||
.tenants
|
||||
.entry(attach_req.tenant_id)
|
||||
.or_insert_with(|| TenantState {
|
||||
pageserver: attach_req.node_id,
|
||||
pageserver: attach_req.pageserver_id,
|
||||
generation: 0,
|
||||
});
|
||||
|
||||
if let Some(attaching_pageserver) = attach_req.node_id.as_ref() {
|
||||
if attach_req.pageserver_id.is_some() {
|
||||
tenant_state.generation += 1;
|
||||
tracing::info!(
|
||||
tenant_id = %attach_req.tenant_id,
|
||||
ps_id = %attaching_pageserver,
|
||||
generation = %tenant_state.generation,
|
||||
"issuing",
|
||||
);
|
||||
} else if let Some(ps_id) = tenant_state.pageserver {
|
||||
tracing::info!(
|
||||
tenant_id = %attach_req.tenant_id,
|
||||
%ps_id,
|
||||
generation = %tenant_state.generation,
|
||||
"dropping",
|
||||
);
|
||||
} else {
|
||||
tracing::info!(
|
||||
tenant_id = %attach_req.tenant_id,
|
||||
"no-op: tenant already has no pageserver");
|
||||
}
|
||||
tenant_state.pageserver = attach_req.node_id;
|
||||
tenant_state.pageserver = attach_req.pageserver_id;
|
||||
let generation = tenant_state.generation;
|
||||
|
||||
locked.save().await.map_err(ApiError::InternalServerError)?;
|
||||
@@ -252,22 +231,7 @@ async fn handle_attach_hook(mut req: Request<Body>) -> Result<Response<Body>, Ap
|
||||
json_response(
|
||||
StatusCode::OK,
|
||||
AttachHookResponse {
|
||||
gen: attach_req.node_id.map(|_| generation),
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle_inspect(mut req: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
let inspect_req = json_request::<InspectRequest>(&mut req).await?;
|
||||
|
||||
let state = get_state(&req).inner.clone();
|
||||
let locked = state.write().await;
|
||||
let tenant_state = locked.tenants.get(&inspect_req.tenant_id);
|
||||
|
||||
json_response(
|
||||
StatusCode::OK,
|
||||
InspectResponse {
|
||||
attachment: tenant_state.and_then(|s| s.pageserver.map(|ps| (s.generation, ps))),
|
||||
gen: attach_req.pageserver_id.map(|_| generation),
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -275,10 +239,9 @@ async fn handle_inspect(mut req: Request<Body>) -> Result<Response<Body>, ApiErr
|
||||
fn make_router(persistent_state: PersistentState) -> RouterBuilder<hyper::Body, ApiError> {
|
||||
endpoint::make_router()
|
||||
.data(Arc::new(State::new(persistent_state)))
|
||||
.post("/re-attach", |r| request_span(r, handle_re_attach))
|
||||
.post("/validate", |r| request_span(r, handle_validate))
|
||||
.post("/attach-hook", |r| request_span(r, handle_attach_hook))
|
||||
.post("/inspect", |r| request_span(r, handle_inspect))
|
||||
.post("/re-attach", handle_re_attach)
|
||||
.post("/validate", handle_validate)
|
||||
.post("/attach_hook", handle_attach_hook)
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
@@ -305,16 +268,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
let server = hyper::Server::from_tcp(http_listener)?.serve(service);
|
||||
|
||||
tracing::info!("Serving on {0}", args.listen);
|
||||
|
||||
tokio::task::spawn(server);
|
||||
|
||||
ShutdownSignals::handle(|signal| match signal {
|
||||
Signal::Interrupt | Signal::Terminate | Signal::Quit => {
|
||||
tracing::info!("Got {}. Terminating", signal.name());
|
||||
// We're just a test helper: no graceful shutdown.
|
||||
std::process::exit(0);
|
||||
}
|
||||
})?;
|
||||
server.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -11,14 +11,13 @@ use compute_api::spec::ComputeMode;
|
||||
use control_plane::attachment_service::AttachmentService;
|
||||
use control_plane::endpoint::ComputeControlPlane;
|
||||
use control_plane::local_env::LocalEnv;
|
||||
use control_plane::pageserver::{PageServerNode, PAGESERVER_REMOTE_STORAGE_DIR};
|
||||
use control_plane::pageserver::PageServerNode;
|
||||
use control_plane::safekeeper::SafekeeperNode;
|
||||
use control_plane::tenant_migration::migrate_tenant;
|
||||
use control_plane::{broker, local_env};
|
||||
use pageserver_api::models::TimelineInfo;
|
||||
use pageserver_api::{
|
||||
DEFAULT_HTTP_LISTEN_PORT as DEFAULT_PAGESERVER_HTTP_PORT,
|
||||
DEFAULT_PG_LISTEN_PORT as DEFAULT_PAGESERVER_PG_PORT,
|
||||
DEFAULT_HTTP_LISTEN_ADDR as DEFAULT_PAGESERVER_HTTP_ADDR,
|
||||
DEFAULT_PG_LISTEN_ADDR as DEFAULT_PAGESERVER_PG_ADDR,
|
||||
};
|
||||
use postgres_backend::AuthType;
|
||||
use safekeeper_api::{
|
||||
@@ -47,8 +46,8 @@ const DEFAULT_PG_VERSION: &str = "15";
|
||||
|
||||
const DEFAULT_PAGESERVER_CONTROL_PLANE_API: &str = "http://127.0.0.1:1234/";
|
||||
|
||||
fn default_conf(num_pageservers: u16) -> String {
|
||||
let mut template = format!(
|
||||
fn default_conf() -> String {
|
||||
format!(
|
||||
r#"
|
||||
# Default built-in configuration, defined in main.rs
|
||||
control_plane_api = '{DEFAULT_PAGESERVER_CONTROL_PLANE_API}'
|
||||
@@ -56,33 +55,21 @@ control_plane_api = '{DEFAULT_PAGESERVER_CONTROL_PLANE_API}'
|
||||
[broker]
|
||||
listen_addr = '{DEFAULT_BROKER_ADDR}'
|
||||
|
||||
[[pageservers]]
|
||||
id = {DEFAULT_PAGESERVER_ID}
|
||||
listen_pg_addr = '{DEFAULT_PAGESERVER_PG_ADDR}'
|
||||
listen_http_addr = '{DEFAULT_PAGESERVER_HTTP_ADDR}'
|
||||
pg_auth_type = '{trust_auth}'
|
||||
http_auth_type = '{trust_auth}'
|
||||
|
||||
[[safekeepers]]
|
||||
id = {DEFAULT_SAFEKEEPER_ID}
|
||||
pg_port = {DEFAULT_SAFEKEEPER_PG_PORT}
|
||||
http_port = {DEFAULT_SAFEKEEPER_HTTP_PORT}
|
||||
|
||||
"#,
|
||||
);
|
||||
|
||||
for i in 0..num_pageservers {
|
||||
let pageserver_id = NodeId(DEFAULT_PAGESERVER_ID.0 + i as u64);
|
||||
let pg_port = DEFAULT_PAGESERVER_PG_PORT + i;
|
||||
let http_port = DEFAULT_PAGESERVER_HTTP_PORT + i;
|
||||
|
||||
template += &format!(
|
||||
r#"
|
||||
[[pageservers]]
|
||||
id = {pageserver_id}
|
||||
listen_pg_addr = '127.0.0.1:{pg_port}'
|
||||
listen_http_addr = '127.0.0.1:{http_port}'
|
||||
pg_auth_type = '{trust_auth}'
|
||||
http_auth_type = '{trust_auth}'
|
||||
"#,
|
||||
trust_auth = AuthType::Trust,
|
||||
)
|
||||
}
|
||||
|
||||
template
|
||||
trust_auth = AuthType::Trust,
|
||||
)
|
||||
}
|
||||
|
||||
///
|
||||
@@ -129,7 +116,6 @@ fn main() -> Result<()> {
|
||||
"attachment_service" => handle_attachment_service(sub_args, &env),
|
||||
"safekeeper" => handle_safekeeper(sub_args, &env),
|
||||
"endpoint" => handle_endpoint(sub_args, &env),
|
||||
"mappings" => handle_mappings(sub_args, &mut env),
|
||||
"pg" => bail!("'pg' subcommand has been renamed to 'endpoint'"),
|
||||
_ => bail!("unexpected subcommand {sub_name}"),
|
||||
};
|
||||
@@ -308,9 +294,6 @@ fn parse_timeline_id(sub_match: &ArgMatches) -> anyhow::Result<Option<TimelineId
|
||||
}
|
||||
|
||||
fn handle_init(init_match: &ArgMatches) -> anyhow::Result<LocalEnv> {
|
||||
let num_pageservers = init_match
|
||||
.get_one::<u16>("num-pageservers")
|
||||
.expect("num-pageservers arg has a default");
|
||||
// Create config file
|
||||
let toml_file: String = if let Some(config_path) = init_match.get_one::<PathBuf>("config") {
|
||||
// load and parse the file
|
||||
@@ -322,7 +305,7 @@ fn handle_init(init_match: &ArgMatches) -> anyhow::Result<LocalEnv> {
|
||||
})?
|
||||
} else {
|
||||
// Built-in default config
|
||||
default_conf(*num_pageservers)
|
||||
default_conf()
|
||||
};
|
||||
|
||||
let pg_version = init_match
|
||||
@@ -336,9 +319,6 @@ fn handle_init(init_match: &ArgMatches) -> anyhow::Result<LocalEnv> {
|
||||
env.init(pg_version, force)
|
||||
.context("Failed to initialize neon repository")?;
|
||||
|
||||
// Create remote storage location for default LocalFs remote storage
|
||||
std::fs::create_dir_all(env.base_data_dir.join(PAGESERVER_REMOTE_STORAGE_DIR))?;
|
||||
|
||||
// Initialize pageserver, create initial tenant and timeline.
|
||||
for ps_conf in &env.pageservers {
|
||||
PageServerNode::from_env(&env, ps_conf)
|
||||
@@ -452,15 +432,6 @@ fn handle_tenant(tenant_match: &ArgMatches, env: &mut local_env::LocalEnv) -> an
|
||||
.with_context(|| format!("Tenant config failed for tenant with id {tenant_id}"))?;
|
||||
println!("tenant {tenant_id} successfully configured on the pageserver");
|
||||
}
|
||||
Some(("migrate", matches)) => {
|
||||
let tenant_id = get_tenant_id(matches, env)?;
|
||||
let new_pageserver = get_pageserver(env, matches)?;
|
||||
let new_pageserver_id = new_pageserver.conf.id;
|
||||
|
||||
migrate_tenant(env, tenant_id, new_pageserver)?;
|
||||
println!("tenant {tenant_id} migrated to {}", new_pageserver_id);
|
||||
}
|
||||
|
||||
Some((sub_name, _)) => bail!("Unexpected tenant subcommand '{}'", sub_name),
|
||||
None => bail!("no tenant subcommand provided"),
|
||||
}
|
||||
@@ -826,24 +797,6 @@ fn handle_endpoint(ep_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<(
|
||||
ep.start(&auth_token, safekeepers, remote_ext_config)?;
|
||||
}
|
||||
}
|
||||
"reconfigure" => {
|
||||
let endpoint_id = sub_args
|
||||
.get_one::<String>("endpoint_id")
|
||||
.ok_or_else(|| anyhow!("No endpoint ID provided to reconfigure"))?;
|
||||
let endpoint = cplane
|
||||
.endpoints
|
||||
.get(endpoint_id.as_str())
|
||||
.with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?;
|
||||
let pageserver_id =
|
||||
if let Some(id_str) = sub_args.get_one::<String>("endpoint-pageserver-id") {
|
||||
Some(NodeId(
|
||||
id_str.parse().context("while parsing pageserver id")?,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
endpoint.reconfigure(pageserver_id)?;
|
||||
}
|
||||
"stop" => {
|
||||
let endpoint_id = sub_args
|
||||
.get_one::<String>("endpoint_id")
|
||||
@@ -863,52 +816,20 @@ fn handle_endpoint(ep_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_mappings(sub_match: &ArgMatches, env: &mut local_env::LocalEnv) -> Result<()> {
|
||||
let (sub_name, sub_args) = match sub_match.subcommand() {
|
||||
Some(ep_subcommand_data) => ep_subcommand_data,
|
||||
None => bail!("no mappings subcommand provided"),
|
||||
};
|
||||
|
||||
match sub_name {
|
||||
"map" => {
|
||||
let branch_name = sub_args
|
||||
.get_one::<String>("branch-name")
|
||||
.expect("branch-name argument missing");
|
||||
|
||||
let tenant_id = sub_args
|
||||
.get_one::<String>("tenant-id")
|
||||
.map(|x| TenantId::from_str(x))
|
||||
.expect("tenant-id argument missing")
|
||||
.expect("malformed tenant-id arg");
|
||||
|
||||
let timeline_id = sub_args
|
||||
.get_one::<String>("timeline-id")
|
||||
.map(|x| TimelineId::from_str(x))
|
||||
.expect("timeline-id argument missing")
|
||||
.expect("malformed timeline-id arg");
|
||||
|
||||
env.register_branch_mapping(branch_name.to_owned(), tenant_id, timeline_id)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
other => unimplemented!("mappings subcommand {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_pageserver(env: &local_env::LocalEnv, args: &ArgMatches) -> Result<PageServerNode> {
|
||||
let node_id = if let Some(id_str) = args.get_one::<String>("pageserver-id") {
|
||||
NodeId(id_str.parse().context("while parsing pageserver id")?)
|
||||
} else {
|
||||
DEFAULT_PAGESERVER_ID
|
||||
};
|
||||
|
||||
Ok(PageServerNode::from_env(
|
||||
env,
|
||||
env.get_pageserver_conf(node_id)?,
|
||||
))
|
||||
}
|
||||
|
||||
fn handle_pageserver(sub_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<()> {
|
||||
fn get_pageserver(env: &local_env::LocalEnv, args: &ArgMatches) -> Result<PageServerNode> {
|
||||
let node_id = if let Some(id_str) = args.get_one::<String>("pageserver-id") {
|
||||
NodeId(id_str.parse().context("while parsing pageserver id")?)
|
||||
} else {
|
||||
DEFAULT_PAGESERVER_ID
|
||||
};
|
||||
|
||||
Ok(PageServerNode::from_env(
|
||||
env,
|
||||
env.get_pageserver_conf(node_id)?,
|
||||
))
|
||||
}
|
||||
|
||||
match sub_match.subcommand() {
|
||||
Some(("start", subcommand_args)) => {
|
||||
if let Err(e) = get_pageserver(env, subcommand_args)?
|
||||
@@ -945,20 +866,6 @@ fn handle_pageserver(sub_match: &ArgMatches, env: &local_env::LocalEnv) -> Resul
|
||||
}
|
||||
}
|
||||
|
||||
Some(("migrate", subcommand_args)) => {
|
||||
let pageserver = get_pageserver(env, subcommand_args)?;
|
||||
//TODO what shutdown strategy should we use here?
|
||||
if let Err(e) = pageserver.stop(false) {
|
||||
eprintln!("pageserver stop failed: {}", e);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if let Err(e) = pageserver.start(&pageserver_config_overrides(subcommand_args)) {
|
||||
eprintln!("pageserver start failed: {e}");
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
Some(("status", subcommand_args)) => {
|
||||
match get_pageserver(env, subcommand_args)?.check_status() {
|
||||
Ok(_) => println!("Page server is up and running"),
|
||||
@@ -1177,7 +1084,6 @@ fn cli() -> Command {
|
||||
// --id, when using a pageserver command
|
||||
let pageserver_id_arg = Arg::new("pageserver-id")
|
||||
.long("id")
|
||||
.global(true)
|
||||
.help("pageserver id")
|
||||
.required(false);
|
||||
// --pageserver-id when using a non-pageserver command
|
||||
@@ -1266,13 +1172,6 @@ fn cli() -> Command {
|
||||
.help("Force initialization even if the repository is not empty")
|
||||
.required(false);
|
||||
|
||||
let num_pageservers_arg = Arg::new("num-pageservers")
|
||||
.value_parser(value_parser!(u16))
|
||||
.long("num-pageservers")
|
||||
.help("How many pageservers to create (default 1)")
|
||||
.required(false)
|
||||
.default_value("1");
|
||||
|
||||
Command::new("Neon CLI")
|
||||
.arg_required_else_help(true)
|
||||
.version(GIT_VERSION)
|
||||
@@ -1280,7 +1179,6 @@ fn cli() -> Command {
|
||||
Command::new("init")
|
||||
.about("Initialize a new Neon repository, preparing configs for services to start with")
|
||||
.arg(pageserver_config_args.clone())
|
||||
.arg(num_pageservers_arg.clone())
|
||||
.arg(
|
||||
Arg::new("config")
|
||||
.long("config")
|
||||
@@ -1351,29 +1249,22 @@ fn cli() -> Command {
|
||||
.subcommand(Command::new("config")
|
||||
.arg(tenant_id_arg.clone())
|
||||
.arg(Arg::new("config").short('c').num_args(1).action(ArgAction::Append).required(false)))
|
||||
.subcommand(Command::new("migrate")
|
||||
.about("Migrate a tenant from one pageserver to another")
|
||||
.arg(tenant_id_arg.clone())
|
||||
.arg(pageserver_id_arg.clone()))
|
||||
)
|
||||
.subcommand(
|
||||
Command::new("pageserver")
|
||||
.arg_required_else_help(true)
|
||||
.about("Manage pageserver")
|
||||
.arg(pageserver_id_arg)
|
||||
.subcommand(Command::new("status"))
|
||||
.subcommand(Command::new("start")
|
||||
.about("Start local pageserver")
|
||||
.arg(pageserver_config_args.clone())
|
||||
)
|
||||
.subcommand(Command::new("stop")
|
||||
.about("Stop local pageserver")
|
||||
.arg(stop_mode_arg.clone())
|
||||
)
|
||||
.subcommand(Command::new("restart")
|
||||
.about("Restart local pageserver")
|
||||
.arg(pageserver_config_args.clone())
|
||||
)
|
||||
.arg(pageserver_id_arg.clone())
|
||||
.subcommand(Command::new("start").about("Start local pageserver")
|
||||
.arg(pageserver_id_arg.clone())
|
||||
.arg(pageserver_config_args.clone()))
|
||||
.subcommand(Command::new("stop").about("Stop local pageserver")
|
||||
.arg(pageserver_id_arg.clone())
|
||||
.arg(stop_mode_arg.clone()))
|
||||
.subcommand(Command::new("restart").about("Restart local pageserver")
|
||||
.arg(pageserver_id_arg.clone())
|
||||
.arg(pageserver_config_args.clone()))
|
||||
)
|
||||
.subcommand(
|
||||
Command::new("attachment_service")
|
||||
@@ -1430,8 +1321,8 @@ fn cli() -> Command {
|
||||
.about("Start postgres.\n If the endpoint doesn't exist yet, it is created.")
|
||||
.arg(endpoint_id_arg.clone())
|
||||
.arg(tenant_id_arg.clone())
|
||||
.arg(branch_name_arg.clone())
|
||||
.arg(timeline_id_arg.clone())
|
||||
.arg(branch_name_arg)
|
||||
.arg(timeline_id_arg)
|
||||
.arg(lsn_arg)
|
||||
.arg(pg_port_arg)
|
||||
.arg(http_port_arg)
|
||||
@@ -1441,16 +1332,10 @@ fn cli() -> Command {
|
||||
.arg(safekeepers_arg)
|
||||
.arg(remote_ext_config_args)
|
||||
)
|
||||
.subcommand(Command::new("reconfigure")
|
||||
.about("Reconfigure the endpoint")
|
||||
.arg(endpoint_pageserver_id_arg)
|
||||
.arg(endpoint_id_arg.clone())
|
||||
.arg(tenant_id_arg.clone())
|
||||
)
|
||||
.subcommand(
|
||||
Command::new("stop")
|
||||
.arg(endpoint_id_arg)
|
||||
.arg(tenant_id_arg.clone())
|
||||
.arg(tenant_id_arg)
|
||||
.arg(
|
||||
Arg::new("destroy")
|
||||
.help("Also delete data directory (now optional, should be default in future)")
|
||||
@@ -1461,18 +1346,6 @@ fn cli() -> Command {
|
||||
)
|
||||
|
||||
)
|
||||
.subcommand(
|
||||
Command::new("mappings")
|
||||
.arg_required_else_help(true)
|
||||
.about("Manage neon_local branch name mappings")
|
||||
.subcommand(
|
||||
Command::new("map")
|
||||
.about("Create new mapping which cannot exist already")
|
||||
.arg(branch_name_arg.clone())
|
||||
.arg(tenant_id_arg.clone())
|
||||
.arg(timeline_id_arg.clone())
|
||||
)
|
||||
)
|
||||
// Obsolete old name for 'endpoint'. We now just print an error if it's used.
|
||||
.subcommand(
|
||||
Command::new("pg")
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
//! ```
|
||||
use anyhow::Context;
|
||||
|
||||
use camino::Utf8PathBuf;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::{background_process, local_env};
|
||||
|
||||
@@ -30,7 +30,7 @@ pub fn start_broker_process(env: &local_env::LocalEnv) -> anyhow::Result<()> {
|
||||
|| {
|
||||
let url = broker.client_url();
|
||||
let status_url = url.join("status").with_context(|| {
|
||||
format!("Failed to append /status path to broker endpoint {url}")
|
||||
format!("Failed to append /status path to broker endpoint {url}",)
|
||||
})?;
|
||||
let request = client
|
||||
.get(status_url)
|
||||
@@ -50,7 +50,6 @@ pub fn stop_broker_process(env: &local_env::LocalEnv) -> anyhow::Result<()> {
|
||||
background_process::stop_process(true, "storage_broker", &storage_broker_pid_file_path(env))
|
||||
}
|
||||
|
||||
fn storage_broker_pid_file_path(env: &local_env::LocalEnv) -> Utf8PathBuf {
|
||||
Utf8PathBuf::from_path_buf(env.base_data_dir.join("storage_broker.pid"))
|
||||
.expect("non-Unicode path")
|
||||
fn storage_broker_pid_file_path(env: &local_env::LocalEnv) -> PathBuf {
|
||||
env.base_data_dir.join("storage_broker.pid")
|
||||
}
|
||||
|
||||
@@ -46,6 +46,7 @@ use std::time::Duration;
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
use utils::id::{NodeId, TenantId, TimelineId};
|
||||
|
||||
use crate::local_env::LocalEnv;
|
||||
@@ -56,10 +57,13 @@ use compute_api::responses::{ComputeState, ComputeStatus};
|
||||
use compute_api::spec::{Cluster, ComputeMode, ComputeSpec};
|
||||
|
||||
// contents of a endpoint.json file
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
|
||||
pub struct EndpointConf {
|
||||
endpoint_id: String,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
tenant_id: TenantId,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
timeline_id: TimelineId,
|
||||
mode: ComputeMode,
|
||||
pg_port: u16,
|
||||
@@ -249,7 +253,7 @@ impl Endpoint {
|
||||
conf.append("shared_buffers", "1MB");
|
||||
conf.append("fsync", "off");
|
||||
conf.append("max_connections", "100");
|
||||
conf.append("wal_level", "logical");
|
||||
conf.append("wal_level", "replica");
|
||||
// wal_sender_timeout is the maximum time to wait for WAL replication.
|
||||
// It also defines how often the walreciever will send a feedback message to the wal sender.
|
||||
conf.append("wal_sender_timeout", "5s");
|
||||
@@ -410,32 +414,16 @@ impl Endpoint {
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn wait_for_compute_ctl_to_exit(&self) -> Result<()> {
|
||||
// Also wait for the compute_ctl process to die. It might have some cleanup
|
||||
// work to do after postgres stops, like syncing safekeepers, etc.
|
||||
//
|
||||
// TODO use background_process::stop_process instead
|
||||
let pidfile_path = self.endpoint_path().join("compute_ctl.pid");
|
||||
let pid: u32 = std::fs::read_to_string(pidfile_path)?.parse()?;
|
||||
let pid = nix::unistd::Pid::from_raw(pid as i32);
|
||||
crate::background_process::wait_until_stopped("compute_ctl", pid)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn read_postgresql_conf(&self) -> Result<String> {
|
||||
// Slurp the endpoints/<endpoint id>/postgresql.conf file into
|
||||
// memory. We will include it in the spec file that we pass to
|
||||
// `compute_ctl`, and `compute_ctl` will write it to the postgresql.conf
|
||||
// in the data directory.
|
||||
let postgresql_conf_path = self.endpoint_path().join("postgresql.conf");
|
||||
match std::fs::read(&postgresql_conf_path) {
|
||||
Ok(content) => Ok(String::from_utf8(content)?),
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok("".to_string()),
|
||||
Err(e) => Err(anyhow::Error::new(e).context(format!(
|
||||
"failed to read config file in {}",
|
||||
postgresql_conf_path.to_str().unwrap()
|
||||
))),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn start(
|
||||
@@ -448,7 +436,21 @@ impl Endpoint {
|
||||
anyhow::bail!("The endpoint is already running");
|
||||
}
|
||||
|
||||
let postgresql_conf = self.read_postgresql_conf()?;
|
||||
// Slurp the endpoints/<endpoint id>/postgresql.conf file into
|
||||
// memory. We will include it in the spec file that we pass to
|
||||
// `compute_ctl`, and `compute_ctl` will write it to the postgresql.conf
|
||||
// in the data directory.
|
||||
let postgresql_conf_path = self.endpoint_path().join("postgresql.conf");
|
||||
let postgresql_conf = match std::fs::read(&postgresql_conf_path) {
|
||||
Ok(content) => String::from_utf8(content)?,
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => "".to_string(),
|
||||
Err(e) => {
|
||||
return Err(anyhow::Error::new(e).context(format!(
|
||||
"failed to read config file in {}",
|
||||
postgresql_conf_path.to_str().unwrap()
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
// We always start the compute node from scratch, so if the Postgres
|
||||
// data dir exists from a previous launch, remove it first.
|
||||
@@ -619,61 +621,6 @@ impl Endpoint {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reconfigure(&self, pageserver_id: Option<NodeId>) -> Result<()> {
|
||||
let mut spec: ComputeSpec = {
|
||||
let spec_path = self.endpoint_path().join("spec.json");
|
||||
let file = std::fs::File::open(spec_path)?;
|
||||
serde_json::from_reader(file)?
|
||||
};
|
||||
|
||||
let postgresql_conf = self.read_postgresql_conf()?;
|
||||
spec.cluster.postgresql_conf = Some(postgresql_conf);
|
||||
|
||||
if let Some(pageserver_id) = pageserver_id {
|
||||
let endpoint_config_path = self.endpoint_path().join("endpoint.json");
|
||||
let mut endpoint_conf: EndpointConf = {
|
||||
let file = std::fs::File::open(&endpoint_config_path)?;
|
||||
serde_json::from_reader(file)?
|
||||
};
|
||||
endpoint_conf.pageserver_id = pageserver_id;
|
||||
std::fs::write(
|
||||
endpoint_config_path,
|
||||
serde_json::to_string_pretty(&endpoint_conf)?,
|
||||
)?;
|
||||
|
||||
let pageserver =
|
||||
PageServerNode::from_env(&self.env, self.env.get_pageserver_conf(pageserver_id)?);
|
||||
let ps_http_conf = &pageserver.pg_connection_config;
|
||||
let (host, port) = (ps_http_conf.host(), ps_http_conf.port());
|
||||
spec.pageserver_connstring = Some(format!("postgresql://no_user@{host}:{port}"));
|
||||
}
|
||||
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let response = client
|
||||
.post(format!(
|
||||
"http://{}:{}/configure",
|
||||
self.http_address.ip(),
|
||||
self.http_address.port()
|
||||
))
|
||||
.body(format!(
|
||||
"{{\"spec\":{}}}",
|
||||
serde_json::to_string_pretty(&spec)?
|
||||
))
|
||||
.send()?;
|
||||
|
||||
let status = response.status();
|
||||
if !(status.is_client_error() || status.is_server_error()) {
|
||||
Ok(())
|
||||
} else {
|
||||
let url = response.url().to_owned();
|
||||
let msg = match response.text() {
|
||||
Ok(err_body) => format!("Error: {}", err_body),
|
||||
Err(_) => format!("Http error ({}) at {}.", status.as_u16(), url),
|
||||
};
|
||||
Err(anyhow::anyhow!(msg))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stop(&self, destroy: bool) -> Result<()> {
|
||||
// If we are going to destroy data directory,
|
||||
// use immediate shutdown mode, otherwise,
|
||||
@@ -682,25 +629,15 @@ impl Endpoint {
|
||||
// Postgres is always started from scratch, so stop
|
||||
// without destroy only used for testing and debugging.
|
||||
//
|
||||
self.pg_ctl(
|
||||
if destroy {
|
||||
&["-m", "immediate", "stop"]
|
||||
} else {
|
||||
&["stop"]
|
||||
},
|
||||
&None,
|
||||
)?;
|
||||
|
||||
// Also wait for the compute_ctl process to die. It might have some cleanup
|
||||
// work to do after postgres stops, like syncing safekeepers, etc.
|
||||
//
|
||||
self.wait_for_compute_ctl_to_exit()?;
|
||||
if destroy {
|
||||
self.pg_ctl(&["-m", "immediate", "stop"], &None)?;
|
||||
println!(
|
||||
"Destroying postgres data directory '{}'",
|
||||
self.pgdata().to_str().unwrap()
|
||||
);
|
||||
std::fs::remove_dir_all(self.endpoint_path())?;
|
||||
} else {
|
||||
self.pg_ctl(&["stop"], &None)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
//! Local control plane.
|
||||
//!
|
||||
//! Can start, configure and stop postgres instances running as a local processes.
|
||||
//!
|
||||
//! Intended to be used in integration tests and in CLI tools for
|
||||
//! local installations.
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
//
|
||||
// Local control plane.
|
||||
//
|
||||
// Can start, configure and stop postgres instances running as a local processes.
|
||||
//
|
||||
// Intended to be used in integration tests and in CLI tools for
|
||||
// local installations.
|
||||
//
|
||||
|
||||
pub mod attachment_service;
|
||||
mod background_process;
|
||||
@@ -14,4 +15,3 @@ pub mod local_env;
|
||||
pub mod pageserver;
|
||||
pub mod postgresql_conf;
|
||||
pub mod safekeeper;
|
||||
pub mod tenant_migration;
|
||||
|
||||
@@ -8,6 +8,7 @@ use anyhow::{bail, ensure, Context};
|
||||
use postgres_backend::AuthType;
|
||||
use reqwest::Url;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::fs;
|
||||
@@ -32,6 +33,7 @@ pub const DEFAULT_PG_VERSION: u32 = 15;
|
||||
// to 'neon_local init --config=<path>' option. See control_plane/simple.conf for
|
||||
// an example.
|
||||
//
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
|
||||
pub struct LocalEnv {
|
||||
// Base directory for all the nodes (the pageserver, safekeepers and
|
||||
@@ -57,6 +59,7 @@ pub struct LocalEnv {
|
||||
// Default tenant ID to use with the 'neon_local' command line utility, when
|
||||
// --tenant_id is not explicitly specified.
|
||||
#[serde(default)]
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
pub default_tenant_id: Option<TenantId>,
|
||||
|
||||
// used to issue tokens during e.g pg start
|
||||
@@ -81,6 +84,7 @@ pub struct LocalEnv {
|
||||
// A `HashMap<String, HashMap<TenantId, TimelineId>>` would be more appropriate here,
|
||||
// but deserialization into a generic toml object as `toml::Value::try_from` fails with an error.
|
||||
// https://toml.io/en/v1.0.0 does not contain a concept of "a table inside another table".
|
||||
#[serde_as(as = "HashMap<_, Vec<(DisplayFromStr, DisplayFromStr)>>")]
|
||||
branch_name_mappings: HashMap<String, Vec<(TenantId, TimelineId)>>,
|
||||
}
|
||||
|
||||
|
||||
@@ -14,11 +14,7 @@ use std::process::{Child, Command};
|
||||
use std::{io, result};
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use camino::Utf8PathBuf;
|
||||
use pageserver_api::models::{
|
||||
self, LocationConfig, TenantInfo, TenantLocationConfigRequest, TimelineInfo,
|
||||
};
|
||||
use pageserver_api::shard::TenantShardId;
|
||||
use pageserver_api::models::{self, TenantInfo, TimelineInfo};
|
||||
use postgres_backend::AuthType;
|
||||
use postgres_connection::{parse_host_port, PgConnectionConfig};
|
||||
use reqwest::blocking::{Client, RequestBuilder, Response};
|
||||
@@ -34,9 +30,6 @@ use utils::{
|
||||
use crate::local_env::PageServerConf;
|
||||
use crate::{background_process, local_env::LocalEnv};
|
||||
|
||||
/// Directory within .neon which will be used by default for LocalFs remote storage.
|
||||
pub const PAGESERVER_REMOTE_STORAGE_DIR: &str = "local_fs_remote_storage/pageserver";
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum PageserverHttpError {
|
||||
#[error("Reqwest error: {0}")]
|
||||
@@ -104,10 +97,8 @@ impl PageServerNode {
|
||||
}
|
||||
}
|
||||
|
||||
/// Merge overrides provided by the user on the command line with our default overides derived from neon_local configuration.
|
||||
///
|
||||
/// These all end up on the command line of the `pageserver` binary.
|
||||
fn neon_local_overrides(&self, cli_overrides: &[&str]) -> Vec<String> {
|
||||
// pageserver conf overrides defined by neon_local configuration.
|
||||
fn neon_local_overrides(&self) -> Vec<String> {
|
||||
let id = format!("id={}", self.conf.id);
|
||||
// FIXME: the paths should be shell-escaped to handle paths with spaces, quotas etc.
|
||||
let pg_distrib_dir_param = format!(
|
||||
@@ -140,25 +131,12 @@ impl PageServerNode {
|
||||
));
|
||||
}
|
||||
|
||||
if !cli_overrides
|
||||
.iter()
|
||||
.any(|c| c.starts_with("remote_storage"))
|
||||
{
|
||||
overrides.push(format!(
|
||||
"remote_storage={{local_path='../{PAGESERVER_REMOTE_STORAGE_DIR}'}}"
|
||||
));
|
||||
}
|
||||
|
||||
if self.conf.http_auth_type != AuthType::Trust || self.conf.pg_auth_type != AuthType::Trust
|
||||
{
|
||||
// Keys are generated in the toplevel repo dir, pageservers' workdirs
|
||||
// are one level below that, so refer to keys with ../
|
||||
overrides.push("auth_validation_public_key_path='../auth_public_key.pem'".to_owned());
|
||||
}
|
||||
|
||||
// Apply the user-provided overrides
|
||||
overrides.extend(cli_overrides.iter().map(|&c| c.to_owned()));
|
||||
|
||||
overrides
|
||||
}
|
||||
|
||||
@@ -166,7 +144,7 @@ impl PageServerNode {
|
||||
pub fn initialize(&self, config_overrides: &[&str]) -> anyhow::Result<()> {
|
||||
// First, run `pageserver --init` and wait for it to write a config into FS and exit.
|
||||
self.pageserver_init(config_overrides)
|
||||
.with_context(|| format!("Failed to run init for pageserver node {}", self.conf.id))
|
||||
.with_context(|| format!("Failed to run init for pageserver node {}", self.conf.id,))
|
||||
}
|
||||
|
||||
pub fn repo_path(&self) -> PathBuf {
|
||||
@@ -176,9 +154,8 @@ impl PageServerNode {
|
||||
/// The pid file is created by the pageserver process, with its pid stored inside.
|
||||
/// Other pageservers cannot lock the same file and overwrite it for as long as the current
|
||||
/// pageserver runs. (Unless someone removes the file manually; never do that!)
|
||||
fn pid_file(&self) -> Utf8PathBuf {
|
||||
Utf8PathBuf::from_path_buf(self.repo_path().join("pageserver.pid"))
|
||||
.expect("non-Unicode path")
|
||||
fn pid_file(&self) -> PathBuf {
|
||||
self.repo_path().join("pageserver.pid")
|
||||
}
|
||||
|
||||
pub fn start(&self, config_overrides: &[&str]) -> anyhow::Result<Child> {
|
||||
@@ -224,6 +201,9 @@ impl PageServerNode {
|
||||
}
|
||||
|
||||
fn start_node(&self, config_overrides: &[&str], update_config: bool) -> anyhow::Result<Child> {
|
||||
let mut overrides = self.neon_local_overrides();
|
||||
overrides.extend(config_overrides.iter().map(|&c| c.to_owned()));
|
||||
|
||||
let datadir = self.repo_path();
|
||||
print!(
|
||||
"Starting pageserver node {} at '{}' in {:?}",
|
||||
@@ -266,7 +246,8 @@ impl PageServerNode {
|
||||
) -> Vec<Cow<'a, str>> {
|
||||
let mut args = vec![Cow::Borrowed("-D"), Cow::Borrowed(datadir_path_str)];
|
||||
|
||||
let overrides = self.neon_local_overrides(config_overrides);
|
||||
let mut overrides = self.neon_local_overrides();
|
||||
overrides.extend(config_overrides.iter().map(|&c| c.to_owned()));
|
||||
for config_override in overrides {
|
||||
args.push(Cow::Borrowed("-c"));
|
||||
args.push(Cow::Owned(config_override));
|
||||
@@ -409,7 +390,7 @@ impl PageServerNode {
|
||||
};
|
||||
|
||||
let request = models::TenantCreateRequest {
|
||||
new_tenant_id: TenantShardId::unsharded(new_tenant_id),
|
||||
new_tenant_id,
|
||||
generation,
|
||||
config,
|
||||
};
|
||||
@@ -518,27 +499,6 @@ impl PageServerNode {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn location_config(
|
||||
&self,
|
||||
tenant_id: TenantId,
|
||||
config: LocationConfig,
|
||||
) -> anyhow::Result<()> {
|
||||
let req_body = TenantLocationConfigRequest { tenant_id, config };
|
||||
|
||||
self.http_request(
|
||||
Method::PUT,
|
||||
format!(
|
||||
"{}/tenant/{}/location_config",
|
||||
self.http_base_url, tenant_id
|
||||
),
|
||||
)?
|
||||
.json(&req_body)
|
||||
.send()?
|
||||
.error_from_body()?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn timeline_list(&self, tenant_id: &TenantId) -> anyhow::Result<Vec<TimelineInfo>> {
|
||||
let timeline_infos: Vec<TimelineInfo> = self
|
||||
.http_request(
|
||||
|
||||
@@ -11,7 +11,6 @@ use std::process::Child;
|
||||
use std::{io, result};
|
||||
|
||||
use anyhow::Context;
|
||||
use camino::Utf8PathBuf;
|
||||
use postgres_connection::PgConnectionConfig;
|
||||
use reqwest::blocking::{Client, RequestBuilder, Response};
|
||||
use reqwest::{IntoUrl, Method};
|
||||
@@ -98,9 +97,8 @@ impl SafekeeperNode {
|
||||
SafekeeperNode::datadir_path_by_id(&self.env, self.id)
|
||||
}
|
||||
|
||||
pub fn pid_file(&self) -> Utf8PathBuf {
|
||||
Utf8PathBuf::from_path_buf(self.datadir_path().join("safekeeper.pid"))
|
||||
.expect("non-Unicode path")
|
||||
pub fn pid_file(&self) -> PathBuf {
|
||||
self.datadir_path().join("safekeeper.pid")
|
||||
}
|
||||
|
||||
pub fn start(&self, extra_opts: Vec<String>) -> anyhow::Result<Child> {
|
||||
|
||||
@@ -1,202 +0,0 @@
|
||||
//!
|
||||
//! Functionality for migrating tenants across pageservers: unlike most of neon_local, this code
|
||||
//! isn't scoped to a particular physical service, as it needs to update compute endpoints to
|
||||
//! point to the new pageserver.
|
||||
//!
|
||||
use crate::local_env::LocalEnv;
|
||||
use crate::{
|
||||
attachment_service::AttachmentService, endpoint::ComputeControlPlane,
|
||||
pageserver::PageServerNode,
|
||||
};
|
||||
use pageserver_api::models::{
|
||||
LocationConfig, LocationConfigMode, LocationConfigSecondary, TenantConfig,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
use utils::{
|
||||
generation::Generation,
|
||||
id::{TenantId, TimelineId},
|
||||
lsn::Lsn,
|
||||
};
|
||||
|
||||
/// Given an attached pageserver, retrieve the LSN for all timelines
|
||||
fn get_lsns(
|
||||
tenant_id: TenantId,
|
||||
pageserver: &PageServerNode,
|
||||
) -> anyhow::Result<HashMap<TimelineId, Lsn>> {
|
||||
let timelines = pageserver.timeline_list(&tenant_id)?;
|
||||
Ok(timelines
|
||||
.into_iter()
|
||||
.map(|t| (t.timeline_id, t.last_record_lsn))
|
||||
.collect())
|
||||
}
|
||||
|
||||
/// Wait for the timeline LSNs on `pageserver` to catch up with or overtake
|
||||
/// `baseline`.
|
||||
fn await_lsn(
|
||||
tenant_id: TenantId,
|
||||
pageserver: &PageServerNode,
|
||||
baseline: HashMap<TimelineId, Lsn>,
|
||||
) -> anyhow::Result<()> {
|
||||
loop {
|
||||
let latest = match get_lsns(tenant_id, pageserver) {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
println!(
|
||||
"🕑 Can't get LSNs on pageserver {} yet, waiting ({e})",
|
||||
pageserver.conf.id
|
||||
);
|
||||
std::thread::sleep(Duration::from_millis(500));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let mut any_behind: bool = false;
|
||||
for (timeline_id, baseline_lsn) in &baseline {
|
||||
match latest.get(timeline_id) {
|
||||
Some(latest_lsn) => {
|
||||
println!("🕑 LSN origin {baseline_lsn} vs destination {latest_lsn}");
|
||||
if latest_lsn < baseline_lsn {
|
||||
any_behind = true;
|
||||
}
|
||||
}
|
||||
None => {
|
||||
// Expected timeline isn't yet visible on migration destination.
|
||||
// (IRL we would have to account for timeline deletion, but this
|
||||
// is just test helper)
|
||||
any_behind = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !any_behind {
|
||||
println!("✅ LSN caught up. Proceeding...");
|
||||
break;
|
||||
} else {
|
||||
std::thread::sleep(Duration::from_millis(500));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// This function spans multiple services, to demonstrate live migration of a tenant
|
||||
/// between pageservers:
|
||||
/// - Coordinate attach/secondary/detach on pageservers
|
||||
/// - call into attachment_service for generations
|
||||
/// - reconfigure compute endpoints to point to new attached pageserver
|
||||
pub fn migrate_tenant(
|
||||
env: &LocalEnv,
|
||||
tenant_id: TenantId,
|
||||
dest_ps: PageServerNode,
|
||||
) -> anyhow::Result<()> {
|
||||
// Get a new generation
|
||||
let attachment_service = AttachmentService::from_env(env);
|
||||
|
||||
let previous = attachment_service.inspect(tenant_id)?;
|
||||
let mut baseline_lsns = None;
|
||||
if let Some((generation, origin_ps_id)) = &previous {
|
||||
let origin_ps = PageServerNode::from_env(env, env.get_pageserver_conf(*origin_ps_id)?);
|
||||
|
||||
if origin_ps_id == &dest_ps.conf.id {
|
||||
println!("🔁 Already attached to {origin_ps_id}, freshening...");
|
||||
let gen = attachment_service.attach_hook(tenant_id, dest_ps.conf.id)?;
|
||||
let dest_conf = LocationConfig {
|
||||
mode: LocationConfigMode::AttachedSingle,
|
||||
generation: gen.map(Generation::new),
|
||||
secondary_conf: None,
|
||||
tenant_conf: TenantConfig::default(),
|
||||
};
|
||||
dest_ps.location_config(tenant_id, dest_conf)?;
|
||||
println!("✅ Migration complete");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
println!("🔁 Switching origin pageserver {origin_ps_id} to stale mode");
|
||||
|
||||
let stale_conf = LocationConfig {
|
||||
mode: LocationConfigMode::AttachedStale,
|
||||
generation: Some(Generation::new(*generation)),
|
||||
secondary_conf: None,
|
||||
tenant_conf: TenantConfig::default(),
|
||||
};
|
||||
origin_ps.location_config(tenant_id, stale_conf)?;
|
||||
|
||||
baseline_lsns = Some(get_lsns(tenant_id, &origin_ps)?);
|
||||
}
|
||||
|
||||
let gen = attachment_service.attach_hook(tenant_id, dest_ps.conf.id)?;
|
||||
let dest_conf = LocationConfig {
|
||||
mode: LocationConfigMode::AttachedMulti,
|
||||
generation: gen.map(Generation::new),
|
||||
secondary_conf: None,
|
||||
tenant_conf: TenantConfig::default(),
|
||||
};
|
||||
|
||||
println!("🔁 Attaching to pageserver {}", dest_ps.conf.id);
|
||||
dest_ps.location_config(tenant_id, dest_conf)?;
|
||||
|
||||
if let Some(baseline) = baseline_lsns {
|
||||
println!("🕑 Waiting for LSN to catch up...");
|
||||
await_lsn(tenant_id, &dest_ps, baseline)?;
|
||||
}
|
||||
|
||||
let cplane = ComputeControlPlane::load(env.clone())?;
|
||||
for (endpoint_name, endpoint) in &cplane.endpoints {
|
||||
if endpoint.tenant_id == tenant_id {
|
||||
println!(
|
||||
"🔁 Reconfiguring endpoint {} to use pageserver {}",
|
||||
endpoint_name, dest_ps.conf.id
|
||||
);
|
||||
endpoint.reconfigure(Some(dest_ps.conf.id))?;
|
||||
}
|
||||
}
|
||||
|
||||
for other_ps_conf in &env.pageservers {
|
||||
if other_ps_conf.id == dest_ps.conf.id {
|
||||
continue;
|
||||
}
|
||||
|
||||
let other_ps = PageServerNode::from_env(env, other_ps_conf);
|
||||
let other_ps_tenants = other_ps.tenant_list()?;
|
||||
|
||||
// Check if this tenant is attached
|
||||
let found = other_ps_tenants
|
||||
.into_iter()
|
||||
.map(|t| t.id)
|
||||
.any(|i| i == tenant_id);
|
||||
if !found {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Downgrade to a secondary location
|
||||
let secondary_conf = LocationConfig {
|
||||
mode: LocationConfigMode::Secondary,
|
||||
generation: None,
|
||||
secondary_conf: Some(LocationConfigSecondary { warm: true }),
|
||||
tenant_conf: TenantConfig::default(),
|
||||
};
|
||||
|
||||
println!(
|
||||
"💤 Switching to secondary mode on pageserver {}",
|
||||
other_ps.conf.id
|
||||
);
|
||||
other_ps.location_config(tenant_id, secondary_conf)?;
|
||||
}
|
||||
|
||||
println!(
|
||||
"🔁 Switching to AttachedSingle mode on pageserver {}",
|
||||
dest_ps.conf.id
|
||||
);
|
||||
let dest_conf = LocationConfig {
|
||||
mode: LocationConfigMode::AttachedSingle,
|
||||
generation: gen.map(Generation::new),
|
||||
secondary_conf: None,
|
||||
tenant_conf: TenantConfig::default(),
|
||||
};
|
||||
dest_ps.location_config(tenant_id, dest_conf)?;
|
||||
|
||||
println!("✅ Migration complete");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
24
deny.toml
24
deny.toml
@@ -23,7 +23,7 @@ vulnerability = "deny"
|
||||
unmaintained = "warn"
|
||||
yanked = "warn"
|
||||
notice = "warn"
|
||||
ignore = []
|
||||
ignore = ["RUSTSEC-2023-0052"]
|
||||
|
||||
# This section is considered when running `cargo deny check licenses`
|
||||
# More documentation for the licenses section can be found here:
|
||||
@@ -74,30 +74,10 @@ highlight = "all"
|
||||
workspace-default-features = "allow"
|
||||
external-default-features = "allow"
|
||||
allow = []
|
||||
|
||||
deny = []
|
||||
skip = []
|
||||
skip-tree = []
|
||||
|
||||
[[bans.deny]]
|
||||
# we use tokio, the same rationale applies for async-{io,waker,global-executor,executor,channel,lock}, smol
|
||||
# if you find yourself here while adding a dependency, try "default-features = false", ask around on #rust
|
||||
name = "async-std"
|
||||
|
||||
[[bans.deny]]
|
||||
name = "async-io"
|
||||
|
||||
[[bans.deny]]
|
||||
name = "async-waker"
|
||||
|
||||
[[bans.deny]]
|
||||
name = "async-global-executor"
|
||||
|
||||
[[bans.deny]]
|
||||
name = "async-executor"
|
||||
|
||||
[[bans.deny]]
|
||||
name = "smol"
|
||||
|
||||
# This section is considered when running `cargo deny check sources`.
|
||||
# More documentation about the 'sources' section can be found here:
|
||||
# https://embarkstudios.github.io/cargo-deny/checks/sources/cfg.html
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
},
|
||||
{
|
||||
"name": "wal_level",
|
||||
"value": "logical",
|
||||
"value": "replica",
|
||||
"vartype": "enum"
|
||||
},
|
||||
{
|
||||
|
||||
@@ -188,60 +188,11 @@ that.
|
||||
|
||||
## Error message style
|
||||
|
||||
### PostgreSQL extensions
|
||||
|
||||
PostgreSQL has a style guide for writing error messages:
|
||||
|
||||
https://www.postgresql.org/docs/current/error-style-guide.html
|
||||
|
||||
Follow that guide when writing error messages in the PostgreSQL
|
||||
extensions.
|
||||
|
||||
### Neon Rust code
|
||||
|
||||
#### Anyhow Context
|
||||
|
||||
When adding anyhow `context()`, use form `present-tense-verb+action`.
|
||||
|
||||
Example:
|
||||
- Bad: `file.metadata().context("could not get file metadata")?;`
|
||||
- Good: `file.metadata().context("get file metadata")?;`
|
||||
|
||||
#### Logging Errors
|
||||
|
||||
When logging any error `e`, use `could not {e:#}` or `failed to {e:#}`.
|
||||
|
||||
If `e` is an `anyhow` error and you want to log the backtrace that it contains,
|
||||
use `{e:?}` instead of `{e:#}`.
|
||||
|
||||
#### Rationale
|
||||
|
||||
The `{:#}` ("alternate Display") of an `anyhow` error chain is concatenation fo the contexts, using `: `.
|
||||
|
||||
For example, the following Rust code will result in output
|
||||
```
|
||||
ERROR failed to list users: load users from server: parse response: invalid json
|
||||
```
|
||||
|
||||
This is more concise / less noisy than what happens if you do `.context("could not ...")?` at each level, i.e.:
|
||||
|
||||
```
|
||||
ERROR could not list users: could not load users from server: could not parse response: invalid json
|
||||
```
|
||||
|
||||
|
||||
```rust
|
||||
fn main() {
|
||||
match list_users().context("list users") else {
|
||||
Ok(_) => ...,
|
||||
Err(e) => tracing::error!("failed to {e:#}"),
|
||||
}
|
||||
}
|
||||
fn list_users() {
|
||||
http_get_users().context("load users from server")?;
|
||||
}
|
||||
fn http_get_users() {
|
||||
let response = client....?;
|
||||
response.parse().context("parse response")?; // fails with serde error "invalid json"
|
||||
}
|
||||
```
|
||||
extension. We don't follow it strictly in the pageserver and
|
||||
safekeeper, but the advice in the PostgreSQL style guide is generally
|
||||
good, and you can't go wrong by following it.
|
||||
|
||||
@@ -96,16 +96,6 @@ prefix_in_bucket = '/test_prefix/'
|
||||
|
||||
`AWS_SECRET_ACCESS_KEY` and `AWS_ACCESS_KEY_ID` env variables can be used to specify the S3 credentials if needed.
|
||||
|
||||
or
|
||||
|
||||
```toml
|
||||
[remote_storage]
|
||||
container_name = 'some-container-name'
|
||||
container_region = 'us-east'
|
||||
prefix_in_container = '/test-prefix/'
|
||||
```
|
||||
|
||||
`AZURE_STORAGE_ACCOUNT` and `AZURE_STORAGE_ACCESS_KEY` env variables can be used to specify the azure credentials if needed.
|
||||
|
||||
## Repository background tasks
|
||||
|
||||
|
||||
@@ -177,7 +177,7 @@ I e during migration create_branch can be called on old pageserver and newly cre
|
||||
|
||||
The difference of simplistic approach from one described above is that it calls ignore on source tenant first and then calls attach on target pageserver. Approach above does it in opposite order thus opening a possibility for race conditions we strive to avoid.
|
||||
|
||||
The approach largely follows this guide: <https://www.notion.so/neondatabase/Cloud-Ad-hoc-tenant-relocation-f687474f7bfc42269e6214e3acba25c7>
|
||||
The approach largely follows this guide: <https://github.com/neondatabase/cloud/wiki/Cloud:-Ad-hoc-tenant-relocation>
|
||||
|
||||
The happy path sequence:
|
||||
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
# Updating Postgres
|
||||
|
||||
## Minor Versions
|
||||
|
||||
When upgrading to a new minor version of Postgres, please follow these steps:
|
||||
|
||||
_Example: 15.4 is the new minor version to upgrade to from 15.3._
|
||||
|
||||
1. Clone the Neon Postgres repository if you have not done so already.
|
||||
|
||||
```shell
|
||||
git clone git@github.com:neondatabase/postgres.git
|
||||
```
|
||||
|
||||
1. Add the Postgres upstream remote.
|
||||
|
||||
```shell
|
||||
git remote add upstream https://git.postgresql.org/git/postgresql.git
|
||||
```
|
||||
|
||||
1. Create a new branch based on the stable branch you are updating.
|
||||
|
||||
```shell
|
||||
git checkout -b my-branch REL_15_STABLE_neon
|
||||
```
|
||||
|
||||
1. Tag the last commit on the stable branch you are updating.
|
||||
|
||||
```shell
|
||||
git tag REL_15_3_neon
|
||||
```
|
||||
|
||||
1. Push the new tag to the Neon Postgres repository.
|
||||
|
||||
```shell
|
||||
git push origin REL_15_3_neon
|
||||
```
|
||||
|
||||
1. Find the release tags you're looking for. They are of the form `REL_X_Y`.
|
||||
|
||||
1. Rebase the branch you created on the tag and resolve any conflicts.
|
||||
|
||||
```shell
|
||||
git fetch upstream REL_15_4
|
||||
git rebase REL_15_4
|
||||
```
|
||||
|
||||
1. Run the Postgres test suite to make sure our commits have not affected
|
||||
Postgres in a negative way.
|
||||
|
||||
```shell
|
||||
make check
|
||||
# OR
|
||||
meson test -C builddir
|
||||
```
|
||||
|
||||
1. Push your branch to the Neon Postgres repository.
|
||||
|
||||
```shell
|
||||
git push origin my-branch
|
||||
```
|
||||
|
||||
1. Clone the Neon repository if you have not done so already.
|
||||
|
||||
```shell
|
||||
git clone git@github.com:neondatabase/neon.git
|
||||
```
|
||||
|
||||
1. Create a new branch.
|
||||
|
||||
1. Change the `revisions.json` file to point at the HEAD of your Postgres
|
||||
branch.
|
||||
|
||||
1. Update the Git submodule.
|
||||
|
||||
```shell
|
||||
git submodule set-branch --branch my-branch vendor/postgres-v15
|
||||
git submodule update --remote vendor/postgres-v15
|
||||
```
|
||||
|
||||
1. Run the Neon test suite to make sure that Neon is still good to go on this
|
||||
minor Postgres release.
|
||||
|
||||
```shell
|
||||
./scripts/poetry -k pg15
|
||||
```
|
||||
|
||||
1. Commit your changes.
|
||||
|
||||
1. Create a pull request, and wait for CI to go green.
|
||||
|
||||
1. Force push the rebased Postgres branches into the Neon Postgres repository.
|
||||
|
||||
```shell
|
||||
git push --force origin my-branch:REL_15_STABLE_neon
|
||||
```
|
||||
|
||||
It may require disabling various branch protections.
|
||||
|
||||
1. Update your Neon PR to point at the branches.
|
||||
|
||||
```shell
|
||||
git submodule set-branch --branch REL_15_STABLE_neon vendor/postgres-v15
|
||||
git commit --amend --no-edit
|
||||
git push --force origin
|
||||
```
|
||||
|
||||
1. Merge the pull request after getting approval(s) and CI completion.
|
||||
@@ -1,5 +1,3 @@
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
pub mod requests;
|
||||
pub mod responses;
|
||||
pub mod spec;
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
use utils::id::{TenantId, TimelineId};
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
@@ -18,6 +19,7 @@ pub type PgIdent = String;
|
||||
|
||||
/// Cluster spec or configuration represented as an optional number of
|
||||
/// delta operations + final cluster state description.
|
||||
#[serde_as]
|
||||
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
|
||||
pub struct ComputeSpec {
|
||||
pub format_version: f32,
|
||||
@@ -48,12 +50,12 @@ pub struct ComputeSpec {
|
||||
// these, and instead set the "neon.tenant_id", "neon.timeline_id",
|
||||
// etc. GUCs in cluster.settings. TODO: Once the control plane has been
|
||||
// updated to fill these fields, we can make these non optional.
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
pub tenant_id: Option<TenantId>,
|
||||
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
pub timeline_id: Option<TimelineId>,
|
||||
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
pub pageserver_connstring: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub safekeeper_connstrings: Vec<String>,
|
||||
|
||||
@@ -138,13 +140,14 @@ impl RemoteExtSpec {
|
||||
}
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
|
||||
pub enum ComputeMode {
|
||||
/// A read-write node
|
||||
#[default]
|
||||
Primary,
|
||||
/// A read-only node, pinned at a particular LSN
|
||||
Static(Lsn),
|
||||
Static(#[serde_as(as = "DisplayFromStr")] Lsn),
|
||||
/// A read-only node that follows the tip of the branch in hot standby mode
|
||||
///
|
||||
/// Future versions may want to distinguish between replicas with hot standby
|
||||
@@ -187,8 +190,6 @@ pub struct DeltaOp {
|
||||
pub struct Role {
|
||||
pub name: PgIdent,
|
||||
pub encrypted_password: Option<String>,
|
||||
pub replication: Option<bool>,
|
||||
pub bypassrls: Option<bool>,
|
||||
pub options: GenericOptions,
|
||||
}
|
||||
|
||||
@@ -199,12 +200,6 @@ pub struct Database {
|
||||
pub name: PgIdent,
|
||||
pub owner: PgIdent,
|
||||
pub options: GenericOptions,
|
||||
// These are derived flags, not present in the spec file.
|
||||
// They are never set by the control plane.
|
||||
#[serde(skip_deserializing, default)]
|
||||
pub restrict_conn: bool,
|
||||
#[serde(skip_deserializing, default)]
|
||||
pub invalid: bool,
|
||||
}
|
||||
|
||||
/// Common type representing both SQL statement params with or without value,
|
||||
|
||||
@@ -76,7 +76,7 @@
|
||||
},
|
||||
{
|
||||
"name": "wal_level",
|
||||
"value": "logical",
|
||||
"value": "replica",
|
||||
"vartype": "enum"
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
//!
|
||||
//! Shared code for consumption metics collection
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
//!
|
||||
use chrono::{DateTime, Utc};
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
//! make sure that we use the same dep version everywhere.
|
||||
//! Otherwise, we might not see all metrics registered via
|
||||
//! a default registry.
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
use once_cell::sync::Lazy;
|
||||
use prometheus::core::{AtomicU64, Collector, GenericGauge, GenericGaugeVec};
|
||||
pub use prometheus::opts;
|
||||
@@ -90,14 +89,14 @@ pub const DISK_WRITE_SECONDS_BUCKETS: &[f64] = &[
|
||||
0.000_050, 0.000_100, 0.000_500, 0.001, 0.003, 0.005, 0.01, 0.05, 0.1, 0.3, 0.5,
|
||||
];
|
||||
|
||||
pub fn set_build_info_metric(revision: &str, build_tag: &str) {
|
||||
pub fn set_build_info_metric(revision: &str) {
|
||||
let metric = register_int_gauge_vec!(
|
||||
"libmetrics_build_info",
|
||||
"Build/version information",
|
||||
&["revision", "build_tag"]
|
||||
&["revision"]
|
||||
)
|
||||
.expect("Failed to register build info metric");
|
||||
metric.with_label_values(&[revision, build_tag]).set(1);
|
||||
metric.with_label_values(&[revision]).set(1);
|
||||
}
|
||||
|
||||
// Records I/O stats in a "cross-platform" way.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::io::{Read, Result, Write};
|
||||
|
||||
/// A wrapper for an object implementing [Read]
|
||||
/// A wrapper for an object implementing [Read](std::io::Read)
|
||||
/// which allows a closure to observe the amount of bytes read.
|
||||
/// This is useful in conjunction with metrics (e.g. [IntCounter](crate::IntCounter)).
|
||||
///
|
||||
@@ -51,17 +51,17 @@ impl<'a, T> CountedReader<'a, T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get an immutable reference to the underlying [Read] implementor
|
||||
/// Get an immutable reference to the underlying [Read](std::io::Read) implementor
|
||||
pub fn inner(&self) -> &T {
|
||||
&self.reader
|
||||
}
|
||||
|
||||
/// Get a mutable reference to the underlying [Read] implementor
|
||||
/// Get a mutable reference to the underlying [Read](std::io::Read) implementor
|
||||
pub fn inner_mut(&mut self) -> &mut T {
|
||||
&mut self.reader
|
||||
}
|
||||
|
||||
/// Consume the wrapper and return the underlying [Read] implementor
|
||||
/// Consume the wrapper and return the underlying [Read](std::io::Read) implementor
|
||||
pub fn into_inner(self) -> T {
|
||||
self.reader
|
||||
}
|
||||
@@ -75,7 +75,7 @@ impl<T: Read> Read for CountedReader<'_, T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper for an object implementing [Write]
|
||||
/// A wrapper for an object implementing [Write](std::io::Write)
|
||||
/// which allows a closure to observe the amount of bytes written.
|
||||
/// This is useful in conjunction with metrics (e.g. [IntCounter](crate::IntCounter)).
|
||||
///
|
||||
@@ -122,17 +122,17 @@ impl<'a, T> CountedWriter<'a, T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get an immutable reference to the underlying [Write] implementor
|
||||
/// Get an immutable reference to the underlying [Write](std::io::Write) implementor
|
||||
pub fn inner(&self) -> &T {
|
||||
&self.writer
|
||||
}
|
||||
|
||||
/// Get a mutable reference to the underlying [Write] implementor
|
||||
/// Get a mutable reference to the underlying [Write](std::io::Write) implementor
|
||||
pub fn inner_mut(&mut self) -> &mut T {
|
||||
&mut self.writer
|
||||
}
|
||||
|
||||
/// Consume the wrapper and return the underlying [Write] implementor
|
||||
/// Consume the wrapper and return the underlying [Write](std::io::Write) implementor
|
||||
pub fn into_inner(self) -> T {
|
||||
self.writer
|
||||
}
|
||||
|
||||
@@ -17,9 +17,5 @@ postgres_ffi.workspace = true
|
||||
enum-map.workspace = true
|
||||
strum.workspace = true
|
||||
strum_macros.workspace = true
|
||||
hex.workspace = true
|
||||
|
||||
workspace_hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
bincode.workspace = true
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
//! See docs/rfcs/025-generation-numbers.md
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
use utils::id::{NodeId, TenantId};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
@@ -11,10 +12,12 @@ pub struct ReAttachRequest {
|
||||
pub node_id: NodeId,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ReAttachResponseTenant {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub id: TenantId,
|
||||
pub gen: u32,
|
||||
pub generation: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
@@ -22,8 +25,10 @@ pub struct ReAttachResponse {
|
||||
pub tenants: Vec<ReAttachResponseTenant>,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ValidateRequestTenant {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub id: TenantId,
|
||||
pub gen: u32,
|
||||
}
|
||||
@@ -38,8 +43,10 @@ pub struct ValidateResponse {
|
||||
pub tenants: Vec<ValidateResponseTenant>,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ValidateResponseTenant {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub id: TenantId,
|
||||
pub valid: bool,
|
||||
}
|
||||
|
||||
@@ -1,142 +0,0 @@
|
||||
use anyhow::{bail, Result};
|
||||
use byteorder::{ByteOrder, BE};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
/// Key used in the Repository kv-store.
|
||||
///
|
||||
/// The Repository treats this as an opaque struct, but see the code in pgdatadir_mapping.rs
|
||||
/// for what we actually store in these fields.
|
||||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Ord, PartialOrd, Serialize, Deserialize)]
|
||||
pub struct Key {
|
||||
pub field1: u8,
|
||||
pub field2: u32,
|
||||
pub field3: u32,
|
||||
pub field4: u32,
|
||||
pub field5: u8,
|
||||
pub field6: u32,
|
||||
}
|
||||
|
||||
pub const KEY_SIZE: usize = 18;
|
||||
|
||||
impl Key {
|
||||
/// 'field2' is used to store tablespaceid for relations and small enum numbers for other relish.
|
||||
/// As long as Neon does not support tablespace (because of lack of access to local file system),
|
||||
/// we can assume that only some predefined namespace OIDs are used which can fit in u16
|
||||
pub fn to_i128(&self) -> i128 {
|
||||
assert!(self.field2 < 0xFFFF || self.field2 == 0xFFFFFFFF || self.field2 == 0x22222222);
|
||||
(((self.field1 & 0xf) as i128) << 120)
|
||||
| (((self.field2 & 0xFFFF) as i128) << 104)
|
||||
| ((self.field3 as i128) << 72)
|
||||
| ((self.field4 as i128) << 40)
|
||||
| ((self.field5 as i128) << 32)
|
||||
| self.field6 as i128
|
||||
}
|
||||
|
||||
pub const fn from_i128(x: i128) -> Self {
|
||||
Key {
|
||||
field1: ((x >> 120) & 0xf) as u8,
|
||||
field2: ((x >> 104) & 0xFFFF) as u32,
|
||||
field3: (x >> 72) as u32,
|
||||
field4: (x >> 40) as u32,
|
||||
field5: (x >> 32) as u8,
|
||||
field6: x as u32,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn next(&self) -> Key {
|
||||
self.add(1)
|
||||
}
|
||||
|
||||
pub fn add(&self, x: u32) -> Key {
|
||||
let mut key = *self;
|
||||
|
||||
let r = key.field6.overflowing_add(x);
|
||||
key.field6 = r.0;
|
||||
if r.1 {
|
||||
let r = key.field5.overflowing_add(1);
|
||||
key.field5 = r.0;
|
||||
if r.1 {
|
||||
let r = key.field4.overflowing_add(1);
|
||||
key.field4 = r.0;
|
||||
if r.1 {
|
||||
let r = key.field3.overflowing_add(1);
|
||||
key.field3 = r.0;
|
||||
if r.1 {
|
||||
let r = key.field2.overflowing_add(1);
|
||||
key.field2 = r.0;
|
||||
if r.1 {
|
||||
let r = key.field1.overflowing_add(1);
|
||||
key.field1 = r.0;
|
||||
assert!(!r.1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
key
|
||||
}
|
||||
|
||||
pub fn from_slice(b: &[u8]) -> Self {
|
||||
Key {
|
||||
field1: b[0],
|
||||
field2: u32::from_be_bytes(b[1..5].try_into().unwrap()),
|
||||
field3: u32::from_be_bytes(b[5..9].try_into().unwrap()),
|
||||
field4: u32::from_be_bytes(b[9..13].try_into().unwrap()),
|
||||
field5: b[13],
|
||||
field6: u32::from_be_bytes(b[14..18].try_into().unwrap()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write_to_byte_slice(&self, buf: &mut [u8]) {
|
||||
buf[0] = self.field1;
|
||||
BE::write_u32(&mut buf[1..5], self.field2);
|
||||
BE::write_u32(&mut buf[5..9], self.field3);
|
||||
BE::write_u32(&mut buf[9..13], self.field4);
|
||||
buf[13] = self.field5;
|
||||
BE::write_u32(&mut buf[14..18], self.field6);
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for Key {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{:02X}{:08X}{:08X}{:08X}{:02X}{:08X}",
|
||||
self.field1, self.field2, self.field3, self.field4, self.field5, self.field6
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl Key {
|
||||
pub const MIN: Key = Key {
|
||||
field1: u8::MIN,
|
||||
field2: u32::MIN,
|
||||
field3: u32::MIN,
|
||||
field4: u32::MIN,
|
||||
field5: u8::MIN,
|
||||
field6: u32::MIN,
|
||||
};
|
||||
pub const MAX: Key = Key {
|
||||
field1: u8::MAX,
|
||||
field2: u32::MAX,
|
||||
field3: u32::MAX,
|
||||
field4: u32::MAX,
|
||||
field5: u8::MAX,
|
||||
field6: u32::MAX,
|
||||
};
|
||||
|
||||
pub fn from_hex(s: &str) -> Result<Self> {
|
||||
if s.len() != 36 {
|
||||
bail!("parse error");
|
||||
}
|
||||
Ok(Key {
|
||||
field1: u8::from_str_radix(&s[0..2], 16)?,
|
||||
field2: u32::from_str_radix(&s[2..10], 16)?,
|
||||
field3: u32::from_str_radix(&s[10..18], 16)?,
|
||||
field4: u32::from_str_radix(&s[18..26], 16)?,
|
||||
field5: u8::from_str_radix(&s[26..28], 16)?,
|
||||
field6: u32::from_str_radix(&s[28..36], 16)?,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,9 @@
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
use const_format::formatcp;
|
||||
|
||||
/// Public API types
|
||||
pub mod control_api;
|
||||
pub mod key;
|
||||
pub mod models;
|
||||
pub mod reltag;
|
||||
pub mod shard;
|
||||
|
||||
pub const DEFAULT_PG_LISTEN_PORT: u16 = 64000;
|
||||
pub const DEFAULT_PG_LISTEN_ADDR: &str = formatcp!("127.0.0.1:{DEFAULT_PG_LISTEN_PORT}");
|
||||
|
||||
@@ -6,17 +6,16 @@ use std::{
|
||||
|
||||
use byteorder::{BigEndian, ReadBytesExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::serde_as;
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
use strum_macros;
|
||||
use utils::{
|
||||
completion,
|
||||
generation::Generation,
|
||||
history_buffer::HistoryBufferWithDropCounter,
|
||||
id::{NodeId, TenantId, TimelineId},
|
||||
lsn::Lsn,
|
||||
};
|
||||
|
||||
use crate::{reltag::RelTag, shard::TenantShardId};
|
||||
use crate::reltag::RelTag;
|
||||
use anyhow::bail;
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
|
||||
@@ -110,6 +109,7 @@ impl TenantState {
|
||||
// So, return `Maybe` while Attaching, making Console wait for the attach task to finish.
|
||||
Self::Attaching | Self::Activating(ActivatingFrom::Attaching) => Maybe,
|
||||
// tenant mgr startup distinguishes attaching from loading via marker file.
|
||||
// If it's loading, there is no attach marker file, i.e., attach had finished in the past.
|
||||
Self::Loading | Self::Activating(ActivatingFrom::Loading) => Attached,
|
||||
// We only reach Active after successful load / attach.
|
||||
// So, call atttachment status Attached.
|
||||
@@ -174,20 +174,26 @@ pub enum TimelineState {
|
||||
Broken { reason: String, backtrace: String },
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct TimelineCreateRequest {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub new_timeline_id: TimelineId,
|
||||
#[serde(default)]
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
pub ancestor_timeline_id: Option<TimelineId>,
|
||||
#[serde(default)]
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
pub ancestor_start_lsn: Option<Lsn>,
|
||||
pub pg_version: Option<u32>,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct TenantCreateRequest {
|
||||
pub new_tenant_id: TenantShardId,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub new_tenant_id: TenantId,
|
||||
#[serde(default)]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub generation: Option<u32>,
|
||||
@@ -195,6 +201,7 @@ pub struct TenantCreateRequest {
|
||||
pub config: TenantConfig, // as we have a flattened field, we should reject all unknown fields in it
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct TenantLoadRequest {
|
||||
@@ -211,8 +218,6 @@ impl std::ops::Deref for TenantCreateRequest {
|
||||
}
|
||||
}
|
||||
|
||||
/// An alternative representation of `pageserver::tenant::TenantConf` with
|
||||
/// simpler types.
|
||||
#[derive(Serialize, Deserialize, Debug, Default)]
|
||||
pub struct TenantConfig {
|
||||
pub checkpoint_distance: Option<u64>,
|
||||
@@ -238,59 +243,21 @@ pub struct TenantConfig {
|
||||
pub gc_feedback: Option<bool>,
|
||||
}
|
||||
|
||||
/// A flattened analog of a `pagesever::tenant::LocationMode`, which
|
||||
/// lists out all possible states (and the virtual "Detached" state)
|
||||
/// in a flat form rather than using rust-style enums.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub enum LocationConfigMode {
|
||||
AttachedSingle,
|
||||
AttachedMulti,
|
||||
AttachedStale,
|
||||
Secondary,
|
||||
Detached,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct LocationConfigSecondary {
|
||||
pub warm: bool,
|
||||
}
|
||||
|
||||
/// An alternative representation of `pageserver::tenant::LocationConf`,
|
||||
/// for use in external-facing APIs.
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct LocationConfig {
|
||||
pub mode: LocationConfigMode,
|
||||
/// If attaching, in what generation?
|
||||
#[serde(default)]
|
||||
pub generation: Option<Generation>,
|
||||
#[serde(default)]
|
||||
pub secondary_conf: Option<LocationConfigSecondary>,
|
||||
|
||||
// If requesting mode `Secondary`, configuration for that.
|
||||
// Custom storage configuration for the tenant, if any
|
||||
pub tenant_conf: TenantConfig,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct TenantCreateResponse(pub TenantId);
|
||||
pub struct TenantCreateResponse(#[serde_as(as = "DisplayFromStr")] pub TenantId);
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct StatusResponse {
|
||||
pub id: NodeId,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct TenantLocationConfigRequest {
|
||||
pub tenant_id: TenantId,
|
||||
#[serde(flatten)]
|
||||
pub config: LocationConfig, // as we have a flattened field, we should reject all unknown fields in it
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct TenantConfigRequest {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub tenant_id: TenantId,
|
||||
#[serde(flatten)]
|
||||
pub config: TenantConfig, // as we have a flattened field, we should reject all unknown fields in it
|
||||
@@ -362,8 +329,10 @@ pub enum TenantAttachmentStatus {
|
||||
Failed { reason: String },
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
pub struct TenantInfo {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub id: TenantId,
|
||||
// NB: intentionally not part of OpenAPI, we don't want to commit to a specific set of TenantState's
|
||||
pub state: TenantState,
|
||||
@@ -374,22 +343,33 @@ pub struct TenantInfo {
|
||||
}
|
||||
|
||||
/// This represents the output of the "timeline_detail" and "timeline_list" API calls.
|
||||
#[serde_as]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct TimelineInfo {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub tenant_id: TenantId,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub timeline_id: TimelineId,
|
||||
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
pub ancestor_timeline_id: Option<TimelineId>,
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
pub ancestor_lsn: Option<Lsn>,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub last_record_lsn: Lsn,
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
pub prev_record_lsn: Option<Lsn>,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub latest_gc_cutoff_lsn: Lsn,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub disk_consistent_lsn: Lsn,
|
||||
|
||||
/// The LSN that we have succesfully uploaded to remote storage
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub remote_consistent_lsn: Lsn,
|
||||
|
||||
/// The LSN that we are advertizing to safekeepers
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub remote_consistent_lsn_visible: Lsn,
|
||||
|
||||
pub current_logical_size: Option<u64>, // is None when timeline is Unloaded
|
||||
@@ -401,6 +381,7 @@ pub struct TimelineInfo {
|
||||
pub timeline_dir_layer_file_size_sum: Option<u64>,
|
||||
|
||||
pub wal_source_connstr: Option<String>,
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
pub last_received_msg_lsn: Option<Lsn>,
|
||||
/// the timestamp (in microseconds) of the last received message
|
||||
pub last_received_msg_ts: Option<u128>,
|
||||
@@ -497,13 +478,23 @@ pub struct LayerAccessStats {
|
||||
pub residence_events_history: HistoryBufferWithDropCounter<LayerResidenceEvent, 16>,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(tag = "kind")]
|
||||
pub enum InMemoryLayerInfo {
|
||||
Open { lsn_start: Lsn },
|
||||
Frozen { lsn_start: Lsn, lsn_end: Lsn },
|
||||
Open {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
lsn_start: Lsn,
|
||||
},
|
||||
Frozen {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
lsn_start: Lsn,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
lsn_end: Lsn,
|
||||
},
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(tag = "kind")]
|
||||
pub enum HistoricLayerInfo {
|
||||
@@ -511,7 +502,9 @@ pub enum HistoricLayerInfo {
|
||||
layer_file_name: String,
|
||||
layer_file_size: u64,
|
||||
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
lsn_start: Lsn,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
lsn_end: Lsn,
|
||||
remote: bool,
|
||||
access_stats: LayerAccessStats,
|
||||
@@ -520,6 +513,7 @@ pub enum HistoricLayerInfo {
|
||||
layer_file_name: String,
|
||||
layer_file_size: u64,
|
||||
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
lsn_start: Lsn,
|
||||
remote: bool,
|
||||
access_stats: LayerAccessStats,
|
||||
|
||||
@@ -22,9 +22,9 @@ use postgres_ffi::Oid;
|
||||
/// [See more related comments here](https:///github.com/postgres/postgres/blob/99c5852e20a0987eca1c38ba0c09329d4076b6a0/src/include/storage/relfilenode.h#L57).
|
||||
///
|
||||
// FIXME: should move 'forknum' as last field to keep this consistent with Postgres.
|
||||
// Then we could replace the custom Ord and PartialOrd implementations below with
|
||||
// deriving them. This will require changes in walredoproc.c.
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize)]
|
||||
// Then we could replace the custo Ord and PartialOrd implementations below with
|
||||
// deriving them.
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct RelTag {
|
||||
pub forknum: u8,
|
||||
pub spcnode: Oid,
|
||||
@@ -40,9 +40,21 @@ impl PartialOrd for RelTag {
|
||||
|
||||
impl Ord for RelTag {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
// Custom ordering where we put forknum to the end of the list
|
||||
let other_tup = (other.spcnode, other.dbnode, other.relnode, other.forknum);
|
||||
(self.spcnode, self.dbnode, self.relnode, self.forknum).cmp(&other_tup)
|
||||
let mut cmp = self.spcnode.cmp(&other.spcnode);
|
||||
if cmp != Ordering::Equal {
|
||||
return cmp;
|
||||
}
|
||||
cmp = self.dbnode.cmp(&other.dbnode);
|
||||
if cmp != Ordering::Equal {
|
||||
return cmp;
|
||||
}
|
||||
cmp = self.relnode.cmp(&other.relnode);
|
||||
if cmp != Ordering::Equal {
|
||||
return cmp;
|
||||
}
|
||||
cmp = self.forknum.cmp(&other.forknum);
|
||||
|
||||
cmp
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,321 +0,0 @@
|
||||
use std::{ops::RangeInclusive, str::FromStr};
|
||||
|
||||
use hex::FromHex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use utils::id::TenantId;
|
||||
|
||||
#[derive(Ord, PartialOrd, Eq, PartialEq, Clone, Copy, Serialize, Deserialize, Debug)]
|
||||
pub struct ShardNumber(pub u8);
|
||||
|
||||
#[derive(Ord, PartialOrd, Eq, PartialEq, Clone, Copy, Serialize, Deserialize, Debug)]
|
||||
pub struct ShardCount(pub u8);
|
||||
|
||||
impl ShardCount {
|
||||
pub const MAX: Self = Self(u8::MAX);
|
||||
}
|
||||
|
||||
impl ShardNumber {
|
||||
pub const MAX: Self = Self(u8::MAX);
|
||||
}
|
||||
|
||||
/// TenantShardId identify the units of work for the Pageserver.
|
||||
///
|
||||
/// These are written as `<tenant_id>-<shard number><shard-count>`, for example:
|
||||
///
|
||||
/// # The second shard in a two-shard tenant
|
||||
/// 072f1291a5310026820b2fe4b2968934-0102
|
||||
///
|
||||
/// Historically, tenants could not have multiple shards, and were identified
|
||||
/// by TenantId. To support this, TenantShardId has a special legacy
|
||||
/// mode where `shard_count` is equal to zero: this represents a single-sharded
|
||||
/// tenant which should be written as a TenantId with no suffix.
|
||||
///
|
||||
/// The human-readable encoding of TenantShardId, such as used in API URLs,
|
||||
/// is both forward and backward compatible: a legacy TenantId can be
|
||||
/// decoded as a TenantShardId, and when re-encoded it will be parseable
|
||||
/// as a TenantId.
|
||||
///
|
||||
/// Note that the binary encoding is _not_ backward compatible, because
|
||||
/// at the time sharding is introduced, there are no existing binary structures
|
||||
/// containing TenantId that we need to handle.
|
||||
#[derive(Eq, PartialEq, PartialOrd, Ord, Clone, Copy)]
|
||||
pub struct TenantShardId {
|
||||
pub tenant_id: TenantId,
|
||||
pub shard_number: ShardNumber,
|
||||
pub shard_count: ShardCount,
|
||||
}
|
||||
|
||||
impl TenantShardId {
|
||||
pub fn unsharded(tenant_id: TenantId) -> Self {
|
||||
Self {
|
||||
tenant_id,
|
||||
shard_number: ShardNumber(0),
|
||||
shard_count: ShardCount(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// The range of all TenantShardId that belong to a particular TenantId. This is useful when
|
||||
/// you have a BTreeMap of TenantShardId, and are querying by TenantId.
|
||||
pub fn tenant_range(tenant_id: TenantId) -> RangeInclusive<Self> {
|
||||
RangeInclusive::new(
|
||||
Self {
|
||||
tenant_id,
|
||||
shard_number: ShardNumber(0),
|
||||
shard_count: ShardCount(0),
|
||||
},
|
||||
Self {
|
||||
tenant_id,
|
||||
shard_number: ShardNumber::MAX,
|
||||
shard_count: ShardCount::MAX,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
pub fn shard_slug(&self) -> String {
|
||||
format!("{:02x}{:02x}", self.shard_number.0, self.shard_count.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TenantShardId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
if self.shard_count != ShardCount(0) {
|
||||
write!(
|
||||
f,
|
||||
"{}-{:02x}{:02x}",
|
||||
self.tenant_id, self.shard_number.0, self.shard_count.0
|
||||
)
|
||||
} else {
|
||||
// Legacy case (shard_count == 0) -- format as just the tenant id. Note that this
|
||||
// is distinct from the normal single shard case (shard count == 1).
|
||||
self.tenant_id.fmt(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for TenantShardId {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// Debug is the same as Display: the compact hex representation
|
||||
write!(f, "{}", self)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::str::FromStr for TenantShardId {
|
||||
type Err = hex::FromHexError;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
// Expect format: 16 byte TenantId, '-', 1 byte shard number, 1 byte shard count
|
||||
if s.len() == 32 {
|
||||
// Legacy case: no shard specified
|
||||
Ok(Self {
|
||||
tenant_id: TenantId::from_str(s)?,
|
||||
shard_number: ShardNumber(0),
|
||||
shard_count: ShardCount(0),
|
||||
})
|
||||
} else if s.len() == 37 {
|
||||
let bytes = s.as_bytes();
|
||||
let tenant_id = TenantId::from_hex(&bytes[0..32])?;
|
||||
let mut shard_parts: [u8; 2] = [0u8; 2];
|
||||
hex::decode_to_slice(&bytes[33..37], &mut shard_parts)?;
|
||||
Ok(Self {
|
||||
tenant_id,
|
||||
shard_number: ShardNumber(shard_parts[0]),
|
||||
shard_count: ShardCount(shard_parts[1]),
|
||||
})
|
||||
} else {
|
||||
Err(hex::FromHexError::InvalidStringLength)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<[u8; 18]> for TenantShardId {
|
||||
fn from(b: [u8; 18]) -> Self {
|
||||
let tenant_id_bytes: [u8; 16] = b[0..16].try_into().unwrap();
|
||||
|
||||
Self {
|
||||
tenant_id: TenantId::from(tenant_id_bytes),
|
||||
shard_number: ShardNumber(b[16]),
|
||||
shard_count: ShardCount(b[17]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for TenantShardId {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
if serializer.is_human_readable() {
|
||||
serializer.collect_str(self)
|
||||
} else {
|
||||
let mut packed: [u8; 18] = [0; 18];
|
||||
packed[0..16].clone_from_slice(&self.tenant_id.as_arr());
|
||||
packed[16] = self.shard_number.0;
|
||||
packed[17] = self.shard_count.0;
|
||||
|
||||
packed.serialize(serializer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for TenantShardId {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
struct IdVisitor {
|
||||
is_human_readable_deserializer: bool,
|
||||
}
|
||||
|
||||
impl<'de> serde::de::Visitor<'de> for IdVisitor {
|
||||
type Value = TenantShardId;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
if self.is_human_readable_deserializer {
|
||||
formatter.write_str("value in form of hex string")
|
||||
} else {
|
||||
formatter.write_str("value in form of integer array([u8; 18])")
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: serde::de::SeqAccess<'de>,
|
||||
{
|
||||
let s = serde::de::value::SeqAccessDeserializer::new(seq);
|
||||
let id: [u8; 18] = Deserialize::deserialize(s)?;
|
||||
Ok(TenantShardId::from(id))
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
TenantShardId::from_str(v).map_err(E::custom)
|
||||
}
|
||||
}
|
||||
|
||||
if deserializer.is_human_readable() {
|
||||
deserializer.deserialize_str(IdVisitor {
|
||||
is_human_readable_deserializer: true,
|
||||
})
|
||||
} else {
|
||||
deserializer.deserialize_tuple(
|
||||
18,
|
||||
IdVisitor {
|
||||
is_human_readable_deserializer: false,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::str::FromStr;
|
||||
|
||||
use bincode;
|
||||
use utils::{id::TenantId, Hex};
|
||||
|
||||
use super::*;
|
||||
|
||||
const EXAMPLE_TENANT_ID: &str = "1f359dd625e519a1a4e8d7509690f6fc";
|
||||
|
||||
#[test]
|
||||
fn tenant_shard_id_string() -> Result<(), hex::FromHexError> {
|
||||
let example = TenantShardId {
|
||||
tenant_id: TenantId::from_str(EXAMPLE_TENANT_ID).unwrap(),
|
||||
shard_count: ShardCount(10),
|
||||
shard_number: ShardNumber(7),
|
||||
};
|
||||
|
||||
let encoded = format!("{example}");
|
||||
|
||||
let expected = format!("{EXAMPLE_TENANT_ID}-070a");
|
||||
assert_eq!(&encoded, &expected);
|
||||
|
||||
let decoded = TenantShardId::from_str(&encoded)?;
|
||||
|
||||
assert_eq!(example, decoded);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tenant_shard_id_binary() -> Result<(), hex::FromHexError> {
|
||||
let example = TenantShardId {
|
||||
tenant_id: TenantId::from_str(EXAMPLE_TENANT_ID).unwrap(),
|
||||
shard_count: ShardCount(10),
|
||||
shard_number: ShardNumber(7),
|
||||
};
|
||||
|
||||
let encoded = bincode::serialize(&example).unwrap();
|
||||
let expected: [u8; 18] = [
|
||||
0x1f, 0x35, 0x9d, 0xd6, 0x25, 0xe5, 0x19, 0xa1, 0xa4, 0xe8, 0xd7, 0x50, 0x96, 0x90,
|
||||
0xf6, 0xfc, 0x07, 0x0a,
|
||||
];
|
||||
assert_eq!(Hex(&encoded), Hex(&expected));
|
||||
|
||||
let decoded = bincode::deserialize(&encoded).unwrap();
|
||||
|
||||
assert_eq!(example, decoded);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tenant_shard_id_backward_compat() -> Result<(), hex::FromHexError> {
|
||||
// Test that TenantShardId can decode a TenantId in human
|
||||
// readable form
|
||||
let example = TenantId::from_str(EXAMPLE_TENANT_ID).unwrap();
|
||||
let encoded = format!("{example}");
|
||||
|
||||
assert_eq!(&encoded, EXAMPLE_TENANT_ID);
|
||||
|
||||
let decoded = TenantShardId::from_str(&encoded)?;
|
||||
|
||||
assert_eq!(example, decoded.tenant_id);
|
||||
assert_eq!(decoded.shard_count, ShardCount(0));
|
||||
assert_eq!(decoded.shard_number, ShardNumber(0));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tenant_shard_id_forward_compat() -> Result<(), hex::FromHexError> {
|
||||
// Test that a legacy TenantShardId encodes into a form that
|
||||
// can be decoded as TenantId
|
||||
let example_tenant_id = TenantId::from_str(EXAMPLE_TENANT_ID).unwrap();
|
||||
let example = TenantShardId::unsharded(example_tenant_id);
|
||||
let encoded = format!("{example}");
|
||||
|
||||
assert_eq!(&encoded, EXAMPLE_TENANT_ID);
|
||||
|
||||
let decoded = TenantId::from_str(&encoded)?;
|
||||
|
||||
assert_eq!(example_tenant_id, decoded);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tenant_shard_id_legacy_binary() -> Result<(), hex::FromHexError> {
|
||||
// Unlike in human readable encoding, binary encoding does not
|
||||
// do any special handling of legacy unsharded TenantIds: this test
|
||||
// is equivalent to the main test for binary encoding, just verifying
|
||||
// that the same behavior applies when we have used `unsharded()` to
|
||||
// construct a TenantShardId.
|
||||
let example = TenantShardId::unsharded(TenantId::from_str(EXAMPLE_TENANT_ID).unwrap());
|
||||
let encoded = bincode::serialize(&example).unwrap();
|
||||
|
||||
let expected: [u8; 18] = [
|
||||
0x1f, 0x35, 0x9d, 0xd6, 0x25, 0xe5, 0x19, 0xa1, 0xa4, 0xe8, 0xd7, 0x50, 0x96, 0x90,
|
||||
0xf6, 0xfc, 0x00, 0x00,
|
||||
];
|
||||
assert_eq!(Hex(&encoded), Hex(&expected));
|
||||
|
||||
let decoded = bincode::deserialize::<TenantShardId>(&encoded).unwrap();
|
||||
assert_eq!(example, decoded);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -2,8 +2,6 @@
|
||||
//! To use, create PostgresBackend and run() it, passing the Handler
|
||||
//! implementation determining how to process the queries. Currently its API
|
||||
//! is rather narrow, but we can extend it once required.
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
use anyhow::Context;
|
||||
use bytes::Bytes;
|
||||
use futures::pin_mut;
|
||||
@@ -17,12 +15,12 @@ use std::{fmt, io};
|
||||
use std::{future::Future, str::FromStr};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tracing::{debug, error, info, trace, warn};
|
||||
use tracing::{debug, error, info, trace};
|
||||
|
||||
use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
|
||||
use pq_proto::{
|
||||
BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_ADMIN_SHUTDOWN,
|
||||
SQLSTATE_INTERNAL_ERROR, SQLSTATE_SUCCESSFUL_COMPLETION,
|
||||
BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_INTERNAL_ERROR,
|
||||
SQLSTATE_SUCCESSFUL_COMPLETION,
|
||||
};
|
||||
|
||||
/// An error, occurred during query processing:
|
||||
@@ -32,14 +30,6 @@ pub enum QueryError {
|
||||
/// The connection was lost while processing the query.
|
||||
#[error(transparent)]
|
||||
Disconnected(#[from] ConnectionError),
|
||||
/// We were instructed to shutdown while processing the query
|
||||
#[error("Shutting down")]
|
||||
Shutdown,
|
||||
/// Authentication failure
|
||||
#[error("Unauthorized: {0}")]
|
||||
Unauthorized(std::borrow::Cow<'static, str>),
|
||||
#[error("Simulated Connection Error")]
|
||||
SimulatedConnectionError,
|
||||
/// Some other error
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
@@ -54,9 +44,7 @@ impl From<io::Error> for QueryError {
|
||||
impl QueryError {
|
||||
pub fn pg_error_code(&self) -> &'static [u8; 5] {
|
||||
match self {
|
||||
Self::Disconnected(_) | Self::SimulatedConnectionError => b"08006", // connection failure
|
||||
Self::Shutdown => SQLSTATE_ADMIN_SHUTDOWN,
|
||||
Self::Unauthorized(_) => SQLSTATE_INTERNAL_ERROR,
|
||||
Self::Disconnected(_) => b"08006", // connection failure
|
||||
Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
|
||||
}
|
||||
}
|
||||
@@ -250,7 +238,6 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> MaybeWriteOnly<IO> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Cancellation safe as long as the underlying IO is cancellation safe.
|
||||
async fn shutdown(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
MaybeWriteOnly::Full(framed) => framed.shutdown().await,
|
||||
@@ -402,37 +389,14 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
|
||||
shutdown_watcher: F,
|
||||
) -> Result<(), QueryError>
|
||||
where
|
||||
F: Fn() -> S + Clone,
|
||||
F: Fn() -> S,
|
||||
S: Future,
|
||||
{
|
||||
let ret = self
|
||||
.run_message_loop(handler, shutdown_watcher.clone())
|
||||
.await;
|
||||
|
||||
tokio::select! {
|
||||
_ = shutdown_watcher() => {
|
||||
// do nothing; we most likely got already stopped by shutdown and will log it next.
|
||||
}
|
||||
_ = self.framed.shutdown() => {
|
||||
// socket might be already closed, e.g. if previously received error,
|
||||
// so ignore result.
|
||||
},
|
||||
}
|
||||
|
||||
match ret {
|
||||
Ok(()) => Ok(()),
|
||||
Err(QueryError::Shutdown) => {
|
||||
info!("Stopped due to shutdown");
|
||||
Ok(())
|
||||
}
|
||||
Err(QueryError::Disconnected(e)) => {
|
||||
info!("Disconnected ({e:#})");
|
||||
// Disconnection is not an error: we just use it that way internally to drop
|
||||
// out of loops.
|
||||
Ok(())
|
||||
}
|
||||
e => e,
|
||||
}
|
||||
let ret = self.run_message_loop(handler, shutdown_watcher).await;
|
||||
// socket might be already closed, e.g. if previously received error,
|
||||
// so ignore result.
|
||||
self.framed.shutdown().await.ok();
|
||||
ret
|
||||
}
|
||||
|
||||
async fn run_message_loop<F, S>(
|
||||
@@ -452,11 +416,15 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
|
||||
_ = shutdown_watcher() => {
|
||||
// We were requested to shut down.
|
||||
tracing::info!("shutdown request received during handshake");
|
||||
return Err(QueryError::Shutdown)
|
||||
return Ok(())
|
||||
},
|
||||
|
||||
handshake_r = self.handshake(handler) => {
|
||||
handshake_r?;
|
||||
result = self.handshake(handler) => {
|
||||
// Handshake complete.
|
||||
result?;
|
||||
if self.state == ProtoState::Closed {
|
||||
return Ok(()); // EOF during handshake
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
@@ -467,34 +435,17 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
|
||||
_ = shutdown_watcher() => {
|
||||
// We were requested to shut down.
|
||||
tracing::info!("shutdown request received in run_message_loop");
|
||||
return Err(QueryError::Shutdown)
|
||||
Ok(None)
|
||||
},
|
||||
msg = self.read_message() => { msg },
|
||||
)? {
|
||||
trace!("got message {:?}", msg);
|
||||
|
||||
let result = self.process_message(handler, msg, &mut query_string).await;
|
||||
tokio::select!(
|
||||
biased;
|
||||
_ = shutdown_watcher() => {
|
||||
// We were requested to shut down.
|
||||
tracing::info!("shutdown request received during response flush");
|
||||
|
||||
// If we exited process_message with a shutdown error, there may be
|
||||
// some valid response content on in our transmit buffer: permit sending
|
||||
// this within a short timeout. This is a best effort thing so we don't
|
||||
// care about the result.
|
||||
tokio::time::timeout(std::time::Duration::from_millis(500), self.flush()).await.ok();
|
||||
|
||||
return Err(QueryError::Shutdown)
|
||||
},
|
||||
flush_r = self.flush() => {
|
||||
flush_r?;
|
||||
}
|
||||
);
|
||||
|
||||
self.flush().await?;
|
||||
match result? {
|
||||
ProcessMsgResult::Continue => {
|
||||
self.flush().await?;
|
||||
continue;
|
||||
}
|
||||
ProcessMsgResult::Break => break,
|
||||
@@ -599,9 +550,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
|
||||
self.peer_addr
|
||||
);
|
||||
self.state = ProtoState::Closed;
|
||||
return Err(QueryError::Disconnected(ConnectionError::Protocol(
|
||||
ProtocolError::Protocol("EOF during handshake".to_string()),
|
||||
)));
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -616,7 +565,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
|
||||
|
||||
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
&short_error(&e),
|
||||
&e.to_string(),
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
return Err(e);
|
||||
@@ -640,9 +589,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
|
||||
self.peer_addr
|
||||
);
|
||||
self.state = ProtoState::Closed;
|
||||
return Err(QueryError::Disconnected(ConnectionError::Protocol(
|
||||
ProtocolError::Protocol("EOF during auth".to_string()),
|
||||
)));
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -736,20 +683,12 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
|
||||
|
||||
trace!("got query {query_string:?}");
|
||||
if let Err(e) = handler.process_query(self, query_string).await {
|
||||
match e {
|
||||
QueryError::Shutdown => return Ok(ProcessMsgResult::Break),
|
||||
QueryError::SimulatedConnectionError => {
|
||||
return Err(QueryError::SimulatedConnectionError)
|
||||
}
|
||||
e => {
|
||||
log_query_error(query_string, &e);
|
||||
let short_error = short_error(&e);
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
&short_error,
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
}
|
||||
}
|
||||
log_query_error(query_string, &e);
|
||||
let short_error = short_error(&e);
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
&short_error,
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
}
|
||||
self.write_message_noflush(&BeMessage::ReadyForQuery)?;
|
||||
}
|
||||
@@ -974,9 +913,6 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'a, I
|
||||
pub fn short_error(e: &QueryError) -> String {
|
||||
match e {
|
||||
QueryError::Disconnected(connection_error) => connection_error.to_string(),
|
||||
QueryError::Shutdown => "shutdown".to_string(),
|
||||
QueryError::Unauthorized(_e) => "JWT authentication error".to_string(),
|
||||
QueryError::SimulatedConnectionError => "simulated connection error".to_string(),
|
||||
QueryError::Other(e) => format!("{e:#}"),
|
||||
}
|
||||
}
|
||||
@@ -993,15 +929,6 @@ fn log_query_error(query: &str, e: &QueryError) {
|
||||
QueryError::Disconnected(other_connection_error) => {
|
||||
error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
|
||||
}
|
||||
QueryError::SimulatedConnectionError => {
|
||||
error!("query handler for query '{query}' failed due to a simulated connection error")
|
||||
}
|
||||
QueryError::Shutdown => {
|
||||
info!("query handler for '{query}' cancelled during tenant shutdown")
|
||||
}
|
||||
QueryError::Unauthorized(e) => {
|
||||
warn!("query handler for '{query}' failed with authentication error: {e}");
|
||||
}
|
||||
QueryError::Other(e) => {
|
||||
error!("query handler for '{query}' failed: {e:?}");
|
||||
}
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
use anyhow::{bail, Context};
|
||||
use itertools::Itertools;
|
||||
use std::borrow::Cow;
|
||||
|
||||
@@ -8,7 +8,6 @@
|
||||
// modules included with the postgres_ffi macro depend on the types of the specific version's
|
||||
// types, and trigger a too eager lint.
|
||||
#![allow(clippy::duplicate_mod)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
|
||||
use bytes::Bytes;
|
||||
use utils::bin_ser::SerializeError;
|
||||
@@ -21,7 +20,6 @@ macro_rules! postgres_ffi {
|
||||
pub mod bindings {
|
||||
// bindgen generates bindings for a lot of stuff we don't need
|
||||
#![allow(dead_code)]
|
||||
#![allow(clippy::undocumented_unsafe_blocks)]
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
include!(concat!(
|
||||
@@ -133,7 +131,6 @@ pub const MAX_SEND_SIZE: usize = XLOG_BLCKSZ * 16;
|
||||
|
||||
// Export some version independent functions that are used outside of this mod
|
||||
pub use v14::xlog_utils::encode_logical_message;
|
||||
pub use v14::xlog_utils::from_pg_timestamp;
|
||||
pub use v14::xlog_utils::get_current_timestamp;
|
||||
pub use v14::xlog_utils::to_pg_timestamp;
|
||||
pub use v14::xlog_utils::XLogFileName;
|
||||
|
||||
@@ -220,10 +220,6 @@ pub const XLOG_CHECKPOINT_ONLINE: u8 = 0x10;
|
||||
pub const XLP_FIRST_IS_CONTRECORD: u16 = 0x0001;
|
||||
pub const XLP_LONG_HEADER: u16 = 0x0002;
|
||||
|
||||
/* From replication/slot.h */
|
||||
pub const REPL_SLOT_ON_DISK_OFFSETOF_RESTART_LSN: usize = 4*4 /* offset of `slotdata` in ReplicationSlotOnDisk */
|
||||
+ 64 /* NameData */ + 4*4;
|
||||
|
||||
/* From fsm_internals.h */
|
||||
const FSM_NODES_PER_PAGE: usize = BLCKSZ as usize - SIZEOF_PAGE_HEADER_DATA - 4;
|
||||
const FSM_NON_LEAF_NODES_PER_PAGE: usize = BLCKSZ as usize / 2 - 1;
|
||||
|
||||
@@ -136,42 +136,21 @@ pub fn get_current_timestamp() -> TimestampTz {
|
||||
to_pg_timestamp(SystemTime::now())
|
||||
}
|
||||
|
||||
// Module to reduce the scope of the constants
|
||||
mod timestamp_conversions {
|
||||
use std::time::Duration;
|
||||
|
||||
use super::*;
|
||||
|
||||
const UNIX_EPOCH_JDATE: u64 = 2440588; // == date2j(1970, 1, 1)
|
||||
const POSTGRES_EPOCH_JDATE: u64 = 2451545; // == date2j(2000, 1, 1)
|
||||
pub fn to_pg_timestamp(time: SystemTime) -> TimestampTz {
|
||||
const UNIX_EPOCH_JDATE: u64 = 2440588; /* == date2j(1970, 1, 1) */
|
||||
const POSTGRES_EPOCH_JDATE: u64 = 2451545; /* == date2j(2000, 1, 1) */
|
||||
const SECS_PER_DAY: u64 = 86400;
|
||||
const USECS_PER_SEC: u64 = 1000000;
|
||||
const SECS_DIFF_UNIX_TO_POSTGRES_EPOCH: u64 =
|
||||
(POSTGRES_EPOCH_JDATE - UNIX_EPOCH_JDATE) * SECS_PER_DAY;
|
||||
|
||||
pub fn to_pg_timestamp(time: SystemTime) -> TimestampTz {
|
||||
match time.duration_since(SystemTime::UNIX_EPOCH) {
|
||||
Ok(n) => {
|
||||
((n.as_secs() - SECS_DIFF_UNIX_TO_POSTGRES_EPOCH) * USECS_PER_SEC
|
||||
+ n.subsec_micros() as u64) as i64
|
||||
}
|
||||
Err(_) => panic!("SystemTime before UNIX EPOCH!"),
|
||||
match time.duration_since(SystemTime::UNIX_EPOCH) {
|
||||
Ok(n) => {
|
||||
((n.as_secs() - ((POSTGRES_EPOCH_JDATE - UNIX_EPOCH_JDATE) * SECS_PER_DAY))
|
||||
* USECS_PER_SEC
|
||||
+ n.subsec_micros() as u64) as i64
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_pg_timestamp(time: TimestampTz) -> SystemTime {
|
||||
let time: u64 = time
|
||||
.try_into()
|
||||
.expect("timestamp before millenium (postgres epoch)");
|
||||
let since_unix_epoch = time + SECS_DIFF_UNIX_TO_POSTGRES_EPOCH * USECS_PER_SEC;
|
||||
SystemTime::UNIX_EPOCH
|
||||
.checked_add(Duration::from_micros(since_unix_epoch))
|
||||
.expect("SystemTime overflow")
|
||||
Err(_) => panic!("SystemTime before UNIX EPOCH!"),
|
||||
}
|
||||
}
|
||||
|
||||
pub use timestamp_conversions::{from_pg_timestamp, to_pg_timestamp};
|
||||
|
||||
// Returns (aligned) end_lsn of the last record in data_dir with WAL segments.
|
||||
// start_lsn must point to some previously known record boundary (beginning of
|
||||
// the next record). If no valid record after is found, start_lsn is returned
|
||||
@@ -502,24 +481,4 @@ pub fn encode_logical_message(prefix: &str, message: &str) -> Vec<u8> {
|
||||
wal
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_ts_conversion() {
|
||||
let now = SystemTime::now();
|
||||
let round_trip = from_pg_timestamp(to_pg_timestamp(now));
|
||||
|
||||
let now_since = now.duration_since(SystemTime::UNIX_EPOCH).unwrap();
|
||||
let round_trip_since = round_trip.duration_since(SystemTime::UNIX_EPOCH).unwrap();
|
||||
assert_eq!(now_since.as_micros(), round_trip_since.as_micros());
|
||||
|
||||
let now_pg = get_current_timestamp();
|
||||
let round_trip_pg = to_pg_timestamp(from_pg_timestamp(now_pg));
|
||||
|
||||
assert_eq!(now_pg, round_trip_pg);
|
||||
}
|
||||
|
||||
// If you need to craft WAL and write tests for this module, put it at wal_craft crate.
|
||||
}
|
||||
// If you need to craft WAL and write tests for this module, put it at wal_craft crate.
|
||||
|
||||
@@ -12,7 +12,7 @@ log.workspace = true
|
||||
once_cell.workspace = true
|
||||
postgres.workspace = true
|
||||
postgres_ffi.workspace = true
|
||||
camino-tempfile.workspace = true
|
||||
tempfile.workspace = true
|
||||
|
||||
workspace_hack.workspace = true
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use anyhow::{bail, ensure};
|
||||
use camino_tempfile::{tempdir, Utf8TempDir};
|
||||
use log::*;
|
||||
use postgres::types::PgLsn;
|
||||
use postgres::Client;
|
||||
@@ -9,12 +8,12 @@ use std::cmp::Ordering;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
use std::time::{Duration, Instant};
|
||||
use tempfile::{tempdir, TempDir};
|
||||
|
||||
macro_rules! xlog_utils_test {
|
||||
($version:ident) => {
|
||||
#[path = "."]
|
||||
mod $version {
|
||||
#[allow(unused_imports)]
|
||||
pub use postgres_ffi::$version::wal_craft_test_export::*;
|
||||
#[allow(clippy::duplicate_mod)]
|
||||
#[cfg(test)]
|
||||
@@ -34,7 +33,7 @@ pub struct Conf {
|
||||
|
||||
pub struct PostgresServer {
|
||||
process: std::process::Child,
|
||||
_unix_socket_dir: Utf8TempDir,
|
||||
_unix_socket_dir: TempDir,
|
||||
client_config: postgres::Config,
|
||||
}
|
||||
|
||||
|
||||
@@ -214,24 +214,27 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// Cancellation safe as long as the AsyncWrite is cancellation safe.
|
||||
async fn flush<S: AsyncWrite + Unpin>(
|
||||
stream: &mut S,
|
||||
write_buf: &mut BytesMut,
|
||||
) -> Result<(), io::Error> {
|
||||
while write_buf.has_remaining() {
|
||||
let bytes_written = stream.write_buf(write_buf).await?;
|
||||
let bytes_written = stream.write(write_buf.chunk()).await?;
|
||||
if bytes_written == 0 {
|
||||
return Err(io::Error::new(
|
||||
ErrorKind::WriteZero,
|
||||
"failed to write message",
|
||||
));
|
||||
}
|
||||
// The advanced part will be garbage collected, likely during shifting
|
||||
// data left on next attempt to write to buffer when free space is not
|
||||
// enough.
|
||||
write_buf.advance(bytes_written);
|
||||
}
|
||||
write_buf.clear();
|
||||
stream.flush().await
|
||||
}
|
||||
|
||||
/// Cancellation safe as long as the AsyncWrite is cancellation safe.
|
||||
async fn shutdown<S: AsyncWrite + Unpin>(
|
||||
stream: &mut S,
|
||||
write_buf: &mut BytesMut,
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
//! Postgres protocol messages serialization-deserialization. See
|
||||
//! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
|
||||
//! on message formats.
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
|
||||
pub mod framed;
|
||||
|
||||
@@ -671,7 +670,6 @@ pub fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
|
||||
}
|
||||
|
||||
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
|
||||
pub const SQLSTATE_ADMIN_SHUTDOWN: &[u8; 5] = b"57P01";
|
||||
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
|
||||
|
||||
impl<'a> BeMessage<'a> {
|
||||
|
||||
@@ -8,14 +8,11 @@ license.workspace = true
|
||||
anyhow.workspace = true
|
||||
async-trait.workspace = true
|
||||
once_cell.workspace = true
|
||||
aws-smithy-async.workspace = true
|
||||
aws-smithy-http.workspace = true
|
||||
aws-types.workspace = true
|
||||
aws-config.workspace = true
|
||||
aws-sdk-s3.workspace = true
|
||||
aws-credential-types.workspace = true
|
||||
bytes.workspace = true
|
||||
camino.workspace = true
|
||||
hyper = { workspace = true, features = ["stream"] }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
@@ -28,15 +25,8 @@ metrics.workspace = true
|
||||
utils.workspace = true
|
||||
pin-project-lite.workspace = true
|
||||
workspace_hack.workspace = true
|
||||
azure_core.workspace = true
|
||||
azure_identity.workspace = true
|
||||
azure_storage.workspace = true
|
||||
azure_storage_blobs.workspace = true
|
||||
futures-util.workspace = true
|
||||
http-types.workspace = true
|
||||
itertools.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
camino-tempfile.workspace = true
|
||||
tempfile.workspace = true
|
||||
test-context.workspace = true
|
||||
rand.workspace = true
|
||||
|
||||
@@ -1,314 +0,0 @@
|
||||
//! Azure Blob Storage wrapper
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::num::NonZeroU32;
|
||||
use std::sync::Arc;
|
||||
use std::{borrow::Cow, io::Cursor};
|
||||
|
||||
use super::REMOTE_STORAGE_PREFIX_SEPARATOR;
|
||||
use anyhow::Result;
|
||||
use azure_core::request_options::{MaxResults, Metadata, Range};
|
||||
use azure_identity::DefaultAzureCredential;
|
||||
use azure_storage::StorageCredentials;
|
||||
use azure_storage_blobs::prelude::ClientBuilder;
|
||||
use azure_storage_blobs::{blob::operations::GetBlobBuilder, prelude::ContainerClient};
|
||||
use futures_util::StreamExt;
|
||||
use http_types::StatusCode;
|
||||
use tokio::io::AsyncRead;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::s3_bucket::RequestKind;
|
||||
use crate::{
|
||||
AzureConfig, ConcurrencyLimiter, Download, DownloadError, Listing, ListingMode, RemotePath,
|
||||
RemoteStorage, StorageMetadata,
|
||||
};
|
||||
|
||||
pub struct AzureBlobStorage {
|
||||
client: ContainerClient,
|
||||
prefix_in_container: Option<String>,
|
||||
max_keys_per_list_response: Option<NonZeroU32>,
|
||||
concurrency_limiter: ConcurrencyLimiter,
|
||||
}
|
||||
|
||||
impl AzureBlobStorage {
|
||||
pub fn new(azure_config: &AzureConfig) -> Result<Self> {
|
||||
debug!(
|
||||
"Creating azure remote storage for azure container {}",
|
||||
azure_config.container_name
|
||||
);
|
||||
|
||||
let account = env::var("AZURE_STORAGE_ACCOUNT").expect("missing AZURE_STORAGE_ACCOUNT");
|
||||
|
||||
// If the `AZURE_STORAGE_ACCESS_KEY` env var has an access key, use that,
|
||||
// otherwise try the token based credentials.
|
||||
let credentials = if let Ok(access_key) = env::var("AZURE_STORAGE_ACCESS_KEY") {
|
||||
StorageCredentials::access_key(account.clone(), access_key)
|
||||
} else {
|
||||
let token_credential = DefaultAzureCredential::default();
|
||||
StorageCredentials::token_credential(Arc::new(token_credential))
|
||||
};
|
||||
|
||||
let builder = ClientBuilder::new(account, credentials);
|
||||
|
||||
let client = builder.container_client(azure_config.container_name.to_owned());
|
||||
|
||||
let max_keys_per_list_response =
|
||||
if let Some(limit) = azure_config.max_keys_per_list_response {
|
||||
Some(
|
||||
NonZeroU32::new(limit as u32)
|
||||
.ok_or_else(|| anyhow::anyhow!("max_keys_per_list_response can't be 0"))?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(AzureBlobStorage {
|
||||
client,
|
||||
prefix_in_container: azure_config.prefix_in_container.to_owned(),
|
||||
max_keys_per_list_response,
|
||||
concurrency_limiter: ConcurrencyLimiter::new(azure_config.concurrency_limit.get()),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn relative_path_to_name(&self, path: &RemotePath) -> String {
|
||||
assert_eq!(std::path::MAIN_SEPARATOR, REMOTE_STORAGE_PREFIX_SEPARATOR);
|
||||
let path_string = path
|
||||
.get_path()
|
||||
.as_str()
|
||||
.trim_end_matches(REMOTE_STORAGE_PREFIX_SEPARATOR);
|
||||
match &self.prefix_in_container {
|
||||
Some(prefix) => {
|
||||
if prefix.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
|
||||
prefix.clone() + path_string
|
||||
} else {
|
||||
format!("{prefix}{REMOTE_STORAGE_PREFIX_SEPARATOR}{path_string}")
|
||||
}
|
||||
}
|
||||
None => path_string.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn name_to_relative_path(&self, key: &str) -> RemotePath {
|
||||
let relative_path =
|
||||
match key.strip_prefix(self.prefix_in_container.as_deref().unwrap_or_default()) {
|
||||
Some(stripped) => stripped,
|
||||
// we rely on Azure to return properly prefixed paths
|
||||
// for requests with a certain prefix
|
||||
None => panic!(
|
||||
"Key {key} does not start with container prefix {:?}",
|
||||
self.prefix_in_container
|
||||
),
|
||||
};
|
||||
RemotePath(
|
||||
relative_path
|
||||
.split(REMOTE_STORAGE_PREFIX_SEPARATOR)
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
async fn download_for_builder(
|
||||
&self,
|
||||
builder: GetBlobBuilder,
|
||||
) -> Result<Download, DownloadError> {
|
||||
let mut response = builder.into_stream();
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
// TODO give proper streaming response instead of buffering into RAM
|
||||
// https://github.com/neondatabase/neon/issues/5563
|
||||
let mut buf = Vec::new();
|
||||
while let Some(part) = response.next().await {
|
||||
let part = part.map_err(to_download_error)?;
|
||||
if let Some(blob_meta) = part.blob.metadata {
|
||||
metadata.extend(blob_meta.iter().map(|(k, v)| (k.to_owned(), v.to_owned())));
|
||||
}
|
||||
let data = part
|
||||
.data
|
||||
.collect()
|
||||
.await
|
||||
.map_err(|e| DownloadError::Other(e.into()))?;
|
||||
buf.extend_from_slice(&data.slice(..));
|
||||
}
|
||||
Ok(Download {
|
||||
download_stream: Box::pin(Cursor::new(buf)),
|
||||
metadata: Some(StorageMetadata(metadata)),
|
||||
})
|
||||
}
|
||||
|
||||
async fn permit(&self, kind: RequestKind) -> tokio::sync::SemaphorePermit<'_> {
|
||||
self.concurrency_limiter
|
||||
.acquire(kind)
|
||||
.await
|
||||
.expect("semaphore is never closed")
|
||||
}
|
||||
}
|
||||
|
||||
fn to_azure_metadata(metadata: StorageMetadata) -> Metadata {
|
||||
let mut res = Metadata::new();
|
||||
for (k, v) in metadata.0.into_iter() {
|
||||
res.insert(k, v);
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
fn to_download_error(error: azure_core::Error) -> DownloadError {
|
||||
if let Some(http_err) = error.as_http_error() {
|
||||
match http_err.status() {
|
||||
StatusCode::NotFound => DownloadError::NotFound,
|
||||
StatusCode::BadRequest => DownloadError::BadInput(anyhow::Error::new(error)),
|
||||
_ => DownloadError::Other(anyhow::Error::new(error)),
|
||||
}
|
||||
} else {
|
||||
DownloadError::Other(error.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl RemoteStorage for AzureBlobStorage {
|
||||
async fn list(
|
||||
&self,
|
||||
prefix: Option<&RemotePath>,
|
||||
mode: ListingMode,
|
||||
) -> anyhow::Result<Listing, DownloadError> {
|
||||
// get the passed prefix or if it is not set use prefix_in_bucket value
|
||||
let list_prefix = prefix
|
||||
.map(|p| self.relative_path_to_name(p))
|
||||
.or_else(|| self.prefix_in_container.clone())
|
||||
.map(|mut p| {
|
||||
// required to end with a separator
|
||||
// otherwise request will return only the entry of a prefix
|
||||
if matches!(mode, ListingMode::WithDelimiter)
|
||||
&& !p.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR)
|
||||
{
|
||||
p.push(REMOTE_STORAGE_PREFIX_SEPARATOR);
|
||||
}
|
||||
p
|
||||
});
|
||||
|
||||
let mut builder = self.client.list_blobs();
|
||||
|
||||
if let ListingMode::WithDelimiter = mode {
|
||||
builder = builder.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string());
|
||||
}
|
||||
|
||||
if let Some(prefix) = list_prefix {
|
||||
builder = builder.prefix(Cow::from(prefix.to_owned()));
|
||||
}
|
||||
|
||||
if let Some(limit) = self.max_keys_per_list_response {
|
||||
builder = builder.max_results(MaxResults::new(limit));
|
||||
}
|
||||
|
||||
let mut response = builder.into_stream();
|
||||
let mut res = Listing::default();
|
||||
while let Some(l) = response.next().await {
|
||||
let entry = l.map_err(to_download_error)?;
|
||||
let prefix_iter = entry
|
||||
.blobs
|
||||
.prefixes()
|
||||
.map(|prefix| self.name_to_relative_path(&prefix.name));
|
||||
res.prefixes.extend(prefix_iter);
|
||||
|
||||
let blob_iter = entry
|
||||
.blobs
|
||||
.blobs()
|
||||
.map(|k| self.name_to_relative_path(&k.name));
|
||||
res.keys.extend(blob_iter);
|
||||
}
|
||||
Ok(res)
|
||||
}
|
||||
async fn upload(
|
||||
&self,
|
||||
mut from: impl AsyncRead + Unpin + Send + Sync + 'static,
|
||||
data_size_bytes: usize,
|
||||
to: &RemotePath,
|
||||
metadata: Option<StorageMetadata>,
|
||||
) -> anyhow::Result<()> {
|
||||
let _permit = self.permit(RequestKind::Put).await;
|
||||
let blob_client = self.client.blob_client(self.relative_path_to_name(to));
|
||||
|
||||
// TODO FIX THIS UGLY HACK and don't buffer the entire object
|
||||
// into RAM here, but use the streaming interface. For that,
|
||||
// we'd have to change the interface though...
|
||||
// https://github.com/neondatabase/neon/issues/5563
|
||||
let mut buf = Vec::with_capacity(data_size_bytes);
|
||||
tokio::io::copy(&mut from, &mut buf).await?;
|
||||
let body = azure_core::Body::Bytes(buf.into());
|
||||
|
||||
let mut builder = blob_client.put_block_blob(body);
|
||||
|
||||
if let Some(metadata) = metadata {
|
||||
builder = builder.metadata(to_azure_metadata(metadata));
|
||||
}
|
||||
|
||||
let _response = builder.into_future().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn download(&self, from: &RemotePath) -> Result<Download, DownloadError> {
|
||||
let _permit = self.permit(RequestKind::Get).await;
|
||||
let blob_client = self.client.blob_client(self.relative_path_to_name(from));
|
||||
|
||||
let builder = blob_client.get();
|
||||
|
||||
self.download_for_builder(builder).await
|
||||
}
|
||||
|
||||
async fn download_byte_range(
|
||||
&self,
|
||||
from: &RemotePath,
|
||||
start_inclusive: u64,
|
||||
end_exclusive: Option<u64>,
|
||||
) -> Result<Download, DownloadError> {
|
||||
let _permit = self.permit(RequestKind::Get).await;
|
||||
let blob_client = self.client.blob_client(self.relative_path_to_name(from));
|
||||
|
||||
let mut builder = blob_client.get();
|
||||
|
||||
if let Some(end_exclusive) = end_exclusive {
|
||||
builder = builder.range(Range::new(start_inclusive, end_exclusive));
|
||||
} else {
|
||||
// Open ranges are not supported by the SDK so we work around
|
||||
// by setting the upper limit extremely high (but high enough
|
||||
// to still be representable by signed 64 bit integers).
|
||||
// TODO remove workaround once the SDK adds open range support
|
||||
// https://github.com/Azure/azure-sdk-for-rust/issues/1438
|
||||
let end_exclusive = u64::MAX / 4;
|
||||
builder = builder.range(Range::new(start_inclusive, end_exclusive));
|
||||
}
|
||||
|
||||
self.download_for_builder(builder).await
|
||||
}
|
||||
|
||||
async fn delete(&self, path: &RemotePath) -> anyhow::Result<()> {
|
||||
let _permit = self.permit(RequestKind::Delete).await;
|
||||
let blob_client = self.client.blob_client(self.relative_path_to_name(path));
|
||||
|
||||
let builder = blob_client.delete();
|
||||
|
||||
match builder.into_future().await {
|
||||
Ok(_response) => Ok(()),
|
||||
Err(e) => {
|
||||
if let Some(http_err) = e.as_http_error() {
|
||||
if http_err.status() == StatusCode::NotFound {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Err(anyhow::Error::new(e))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn delete_objects<'a>(&self, paths: &'a [RemotePath]) -> anyhow::Result<()> {
|
||||
// Permit is already obtained by inner delete function
|
||||
|
||||
// TODO batch requests are also not supported by the SDK
|
||||
// https://github.com/Azure/azure-sdk-for-rust/issues/1068
|
||||
// https://github.com/Azure/azure-sdk-for-rust/issues/1249
|
||||
for path in paths {
|
||||
self.delete(path).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -4,43 +4,41 @@
|
||||
//! [`RemoteStorage`] trait a CRUD-like generic abstraction to use for adapting external storages with a few implementations:
|
||||
//! * [`local_fs`] allows to use local file system as an external storage
|
||||
//! * [`s3_bucket`] uses AWS S3 bucket as an external storage
|
||||
//! * [`azure_blob`] allows to use Azure Blob storage as an external storage
|
||||
//!
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
|
||||
mod azure_blob;
|
||||
mod local_fs;
|
||||
mod s3_bucket;
|
||||
mod simulate_failures;
|
||||
|
||||
use std::{collections::HashMap, fmt::Debug, num::NonZeroUsize, pin::Pin, sync::Arc};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fmt::Debug,
|
||||
num::{NonZeroU32, NonZeroUsize},
|
||||
path::{Path, PathBuf},
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use camino::{Utf8Path, Utf8PathBuf};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::{io, sync::Semaphore};
|
||||
use tokio::io;
|
||||
use toml_edit::Item;
|
||||
use tracing::info;
|
||||
|
||||
pub use self::{
|
||||
azure_blob::AzureBlobStorage, local_fs::LocalFs, s3_bucket::S3Bucket,
|
||||
simulate_failures::UnreliableWrapper,
|
||||
};
|
||||
use s3_bucket::RequestKind;
|
||||
pub use self::{local_fs::LocalFs, s3_bucket::S3Bucket, simulate_failures::UnreliableWrapper};
|
||||
|
||||
/// How many different timelines can be processed simultaneously when synchronizing layers with the remote storage.
|
||||
/// During regular work, pageserver produces one layer file per timeline checkpoint, with bursts of concurrency
|
||||
/// during start (where local and remote timelines are compared and initial sync tasks are scheduled) and timeline attach.
|
||||
/// Both cases may trigger timeline download, that might download a lot of layers. This concurrency is limited by the clients internally, if needed.
|
||||
pub const DEFAULT_REMOTE_STORAGE_MAX_CONCURRENT_SYNCS: usize = 50;
|
||||
pub const DEFAULT_REMOTE_STORAGE_MAX_SYNC_ERRORS: u32 = 10;
|
||||
/// Currently, sync happens with AWS S3, that has two limits on requests per second:
|
||||
/// ~200 RPS for IAM services
|
||||
/// <https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/UsingWithRDS.IAMDBAuth.html>
|
||||
/// ~3500 PUT/COPY/POST/DELETE or 5500 GET/HEAD S3 requests
|
||||
/// <https://aws.amazon.com/premiumsupport/knowledge-center/s3-request-limit-avoid-throttling/>
|
||||
pub const DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT: usize = 100;
|
||||
/// We set this a little bit low as we currently buffer the entire file into RAM
|
||||
///
|
||||
/// Here, a limit of max 20k concurrent connections was noted.
|
||||
/// <https://learn.microsoft.com/en-us/answers/questions/1301863/is-there-any-limitation-to-concurrent-connections>
|
||||
pub const DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT: usize = 30;
|
||||
/// No limits on the client side, which currenltly means 1000 for AWS S3.
|
||||
/// <https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectsV2.html#API_ListObjectsV2_RequestSyntax>
|
||||
pub const DEFAULT_MAX_KEYS_PER_LIST_RESPONSE: Option<i32> = None;
|
||||
@@ -54,7 +52,7 @@ const REMOTE_STORAGE_PREFIX_SEPARATOR: char = '/';
|
||||
/// The prefix is an implementation detail, that allows representing local paths
|
||||
/// as the remote ones, stripping the local storage prefix away.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub struct RemotePath(Utf8PathBuf);
|
||||
pub struct RemotePath(PathBuf);
|
||||
|
||||
impl Serialize for RemotePath {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
@@ -71,18 +69,18 @@ impl<'de> Deserialize<'de> for RemotePath {
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let str = String::deserialize(deserializer)?;
|
||||
Ok(Self(Utf8PathBuf::from(&str)))
|
||||
Ok(Self(PathBuf::from(&str)))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for RemotePath {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
std::fmt::Display::fmt(&self.0, f)
|
||||
write!(f, "{}", self.0.display())
|
||||
}
|
||||
}
|
||||
|
||||
impl RemotePath {
|
||||
pub fn new(relative_path: &Utf8Path) -> anyhow::Result<Self> {
|
||||
pub fn new(relative_path: &Path) -> anyhow::Result<Self> {
|
||||
anyhow::ensure!(
|
||||
relative_path.is_relative(),
|
||||
"Path {relative_path:?} is not relative"
|
||||
@@ -91,50 +89,34 @@ impl RemotePath {
|
||||
}
|
||||
|
||||
pub fn from_string(relative_path: &str) -> anyhow::Result<Self> {
|
||||
Self::new(Utf8Path::new(relative_path))
|
||||
Self::new(Path::new(relative_path))
|
||||
}
|
||||
|
||||
pub fn with_base(&self, base_path: &Utf8Path) -> Utf8PathBuf {
|
||||
pub fn with_base(&self, base_path: &Path) -> PathBuf {
|
||||
base_path.join(&self.0)
|
||||
}
|
||||
|
||||
pub fn object_name(&self) -> Option<&str> {
|
||||
self.0.file_name()
|
||||
self.0.file_name().and_then(|os_str| os_str.to_str())
|
||||
}
|
||||
|
||||
pub fn join(&self, segment: &Utf8Path) -> Self {
|
||||
pub fn join(&self, segment: &Path) -> Self {
|
||||
Self(self.0.join(segment))
|
||||
}
|
||||
|
||||
pub fn get_path(&self) -> &Utf8PathBuf {
|
||||
pub fn get_path(&self) -> &PathBuf {
|
||||
&self.0
|
||||
}
|
||||
|
||||
pub fn extension(&self) -> Option<&str> {
|
||||
self.0.extension()
|
||||
self.0.extension()?.to_str()
|
||||
}
|
||||
|
||||
pub fn strip_prefix(&self, p: &RemotePath) -> Result<&Utf8Path, std::path::StripPrefixError> {
|
||||
pub fn strip_prefix(&self, p: &RemotePath) -> Result<&Path, std::path::StripPrefixError> {
|
||||
self.0.strip_prefix(&p.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// We don't need callers to be able to pass arbitrary delimiters: just control
|
||||
/// whether listings will use a '/' separator or not.
|
||||
///
|
||||
/// The WithDelimiter mode will populate `prefixes` and `keys` in the result. The
|
||||
/// NoDelimiter mode will only populate `keys`.
|
||||
pub enum ListingMode {
|
||||
WithDelimiter,
|
||||
NoDelimiter,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Listing {
|
||||
pub prefixes: Vec<RemotePath>,
|
||||
pub keys: Vec<RemotePath>,
|
||||
}
|
||||
|
||||
/// Storage (potentially remote) API to manage its state.
|
||||
/// This storage tries to be unaware of any layered repository context,
|
||||
/// providing basic CRUD operations for storage files.
|
||||
@@ -147,13 +129,8 @@ pub trait RemoteStorage: Send + Sync + 'static {
|
||||
async fn list_prefixes(
|
||||
&self,
|
||||
prefix: Option<&RemotePath>,
|
||||
) -> Result<Vec<RemotePath>, DownloadError> {
|
||||
let result = self
|
||||
.list(prefix, ListingMode::WithDelimiter)
|
||||
.await?
|
||||
.prefixes;
|
||||
Ok(result)
|
||||
}
|
||||
) -> Result<Vec<RemotePath>, DownloadError>;
|
||||
|
||||
/// Lists all files in directory "recursively"
|
||||
/// (not really recursively, because AWS has a flat namespace)
|
||||
/// Note: This is subtely different than list_prefixes,
|
||||
@@ -165,16 +142,7 @@ pub trait RemoteStorage: Send + Sync + 'static {
|
||||
/// whereas,
|
||||
/// list_prefixes("foo/bar/") = ["cat", "dog"]
|
||||
/// See `test_real_s3.rs` for more details.
|
||||
async fn list_files(&self, prefix: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>> {
|
||||
let result = self.list(prefix, ListingMode::NoDelimiter).await?.keys;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn list(
|
||||
&self,
|
||||
prefix: Option<&RemotePath>,
|
||||
_mode: ListingMode,
|
||||
) -> anyhow::Result<Listing, DownloadError>;
|
||||
async fn list_files(&self, folder: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>>;
|
||||
|
||||
/// Streams the local file contents into remote into the remote storage entry.
|
||||
async fn upload(
|
||||
@@ -225,9 +193,6 @@ pub enum DownloadError {
|
||||
BadInput(anyhow::Error),
|
||||
/// The file was not found in the remote storage.
|
||||
NotFound,
|
||||
/// A cancellation token aborted the download, typically during
|
||||
/// tenant detach or process shutdown.
|
||||
Cancelled,
|
||||
/// The file was found in the remote storage, but the download failed.
|
||||
Other(anyhow::Error),
|
||||
}
|
||||
@@ -238,7 +203,6 @@ impl std::fmt::Display for DownloadError {
|
||||
DownloadError::BadInput(e) => {
|
||||
write!(f, "Failed to download a remote file due to user input: {e}")
|
||||
}
|
||||
DownloadError::Cancelled => write!(f, "Cancelled, shutting down"),
|
||||
DownloadError::NotFound => write!(f, "No file found for the remote object id given"),
|
||||
DownloadError::Other(e) => write!(f, "Failed to download a remote file: {e:?}"),
|
||||
}
|
||||
@@ -253,24 +217,10 @@ impl std::error::Error for DownloadError {}
|
||||
pub enum GenericRemoteStorage {
|
||||
LocalFs(LocalFs),
|
||||
AwsS3(Arc<S3Bucket>),
|
||||
AzureBlob(Arc<AzureBlobStorage>),
|
||||
Unreliable(Arc<UnreliableWrapper>),
|
||||
}
|
||||
|
||||
impl GenericRemoteStorage {
|
||||
pub async fn list(
|
||||
&self,
|
||||
prefix: Option<&RemotePath>,
|
||||
mode: ListingMode,
|
||||
) -> anyhow::Result<Listing, DownloadError> {
|
||||
match self {
|
||||
Self::LocalFs(s) => s.list(prefix, mode).await,
|
||||
Self::AwsS3(s) => s.list(prefix, mode).await,
|
||||
Self::AzureBlob(s) => s.list(prefix, mode).await,
|
||||
Self::Unreliable(s) => s.list(prefix, mode).await,
|
||||
}
|
||||
}
|
||||
|
||||
// A function for listing all the files in a "directory"
|
||||
// Example:
|
||||
// list_files("foo/bar") = ["foo/bar/a.txt", "foo/bar/b.txt"]
|
||||
@@ -278,7 +228,6 @@ impl GenericRemoteStorage {
|
||||
match self {
|
||||
Self::LocalFs(s) => s.list_files(folder).await,
|
||||
Self::AwsS3(s) => s.list_files(folder).await,
|
||||
Self::AzureBlob(s) => s.list_files(folder).await,
|
||||
Self::Unreliable(s) => s.list_files(folder).await,
|
||||
}
|
||||
}
|
||||
@@ -293,7 +242,6 @@ impl GenericRemoteStorage {
|
||||
match self {
|
||||
Self::LocalFs(s) => s.list_prefixes(prefix).await,
|
||||
Self::AwsS3(s) => s.list_prefixes(prefix).await,
|
||||
Self::AzureBlob(s) => s.list_prefixes(prefix).await,
|
||||
Self::Unreliable(s) => s.list_prefixes(prefix).await,
|
||||
}
|
||||
}
|
||||
@@ -308,7 +256,6 @@ impl GenericRemoteStorage {
|
||||
match self {
|
||||
Self::LocalFs(s) => s.upload(from, data_size_bytes, to, metadata).await,
|
||||
Self::AwsS3(s) => s.upload(from, data_size_bytes, to, metadata).await,
|
||||
Self::AzureBlob(s) => s.upload(from, data_size_bytes, to, metadata).await,
|
||||
Self::Unreliable(s) => s.upload(from, data_size_bytes, to, metadata).await,
|
||||
}
|
||||
}
|
||||
@@ -317,7 +264,6 @@ impl GenericRemoteStorage {
|
||||
match self {
|
||||
Self::LocalFs(s) => s.download(from).await,
|
||||
Self::AwsS3(s) => s.download(from).await,
|
||||
Self::AzureBlob(s) => s.download(from).await,
|
||||
Self::Unreliable(s) => s.download(from).await,
|
||||
}
|
||||
}
|
||||
@@ -337,10 +283,6 @@ impl GenericRemoteStorage {
|
||||
s.download_byte_range(from, start_inclusive, end_exclusive)
|
||||
.await
|
||||
}
|
||||
Self::AzureBlob(s) => {
|
||||
s.download_byte_range(from, start_inclusive, end_exclusive)
|
||||
.await
|
||||
}
|
||||
Self::Unreliable(s) => {
|
||||
s.download_byte_range(from, start_inclusive, end_exclusive)
|
||||
.await
|
||||
@@ -352,7 +294,6 @@ impl GenericRemoteStorage {
|
||||
match self {
|
||||
Self::LocalFs(s) => s.delete(path).await,
|
||||
Self::AwsS3(s) => s.delete(path).await,
|
||||
Self::AzureBlob(s) => s.delete(path).await,
|
||||
Self::Unreliable(s) => s.delete(path).await,
|
||||
}
|
||||
}
|
||||
@@ -361,7 +302,6 @@ impl GenericRemoteStorage {
|
||||
match self {
|
||||
Self::LocalFs(s) => s.delete_objects(paths).await,
|
||||
Self::AwsS3(s) => s.delete_objects(paths).await,
|
||||
Self::AzureBlob(s) => s.delete_objects(paths).await,
|
||||
Self::Unreliable(s) => s.delete_objects(paths).await,
|
||||
}
|
||||
}
|
||||
@@ -371,7 +311,7 @@ impl GenericRemoteStorage {
|
||||
pub fn from_config(storage_config: &RemoteStorageConfig) -> anyhow::Result<Self> {
|
||||
Ok(match &storage_config.storage {
|
||||
RemoteStorageKind::LocalFs(root) => {
|
||||
info!("Using fs root '{root}' as a remote storage");
|
||||
info!("Using fs root '{}' as a remote storage", root.display());
|
||||
Self::LocalFs(LocalFs::new(root.clone())?)
|
||||
}
|
||||
RemoteStorageKind::AwsS3(s3_config) => {
|
||||
@@ -379,11 +319,6 @@ impl GenericRemoteStorage {
|
||||
s3_config.bucket_name, s3_config.bucket_region, s3_config.prefix_in_bucket, s3_config.endpoint);
|
||||
Self::AwsS3(Arc::new(S3Bucket::new(s3_config)?))
|
||||
}
|
||||
RemoteStorageKind::AzureContainer(azure_config) => {
|
||||
info!("Using azure container '{}' in region '{}' as a remote storage, prefix in container: '{:?}'",
|
||||
azure_config.container_name, azure_config.container_region, azure_config.prefix_in_container);
|
||||
Self::AzureBlob(Arc::new(AzureBlobStorage::new(azure_config)?))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -431,6 +366,10 @@ pub struct StorageMetadata(HashMap<String, String>);
|
||||
/// External backup storage configuration, enough for creating a client for that storage.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct RemoteStorageConfig {
|
||||
/// Max allowed number of concurrent sync operations between the API user and the remote storage.
|
||||
pub max_concurrent_syncs: NonZeroUsize,
|
||||
/// Max allowed errors before the sync task is considered failed and evicted.
|
||||
pub max_sync_errors: NonZeroU32,
|
||||
/// The storage connection configuration.
|
||||
pub storage: RemoteStorageKind,
|
||||
}
|
||||
@@ -440,13 +379,10 @@ pub struct RemoteStorageConfig {
|
||||
pub enum RemoteStorageKind {
|
||||
/// Storage based on local file system.
|
||||
/// Specify a root folder to place all stored files into.
|
||||
LocalFs(Utf8PathBuf),
|
||||
LocalFs(PathBuf),
|
||||
/// AWS S3 based storage, storing all files in the S3 bucket
|
||||
/// specified by the config
|
||||
AwsS3(S3Config),
|
||||
/// Azure Blob based storage, storing all files in the container
|
||||
/// specified by the config
|
||||
AzureContainer(AzureConfig),
|
||||
}
|
||||
|
||||
/// AWS S3 bucket coordinates and access credentials to manage the bucket contents (read and write).
|
||||
@@ -486,53 +422,27 @@ impl Debug for S3Config {
|
||||
}
|
||||
}
|
||||
|
||||
/// Azure bucket coordinates and access credentials to manage the bucket contents (read and write).
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub struct AzureConfig {
|
||||
/// Name of the container to connect to.
|
||||
pub container_name: String,
|
||||
/// The region where the bucket is located at.
|
||||
pub container_region: String,
|
||||
/// A "subfolder" in the container, to use the same container separately by multiple remote storage users at once.
|
||||
pub prefix_in_container: Option<String>,
|
||||
/// Azure has various limits on its API calls, we need not to exceed those.
|
||||
/// See [`DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT`] for more details.
|
||||
pub concurrency_limit: NonZeroUsize,
|
||||
pub max_keys_per_list_response: Option<i32>,
|
||||
}
|
||||
|
||||
impl Debug for AzureConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("AzureConfig")
|
||||
.field("bucket_name", &self.container_name)
|
||||
.field("bucket_region", &self.container_region)
|
||||
.field("prefix_in_bucket", &self.prefix_in_container)
|
||||
.field("concurrency_limit", &self.concurrency_limit)
|
||||
.field(
|
||||
"max_keys_per_list_response",
|
||||
&self.max_keys_per_list_response,
|
||||
)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl RemoteStorageConfig {
|
||||
pub fn from_toml(toml: &toml_edit::Item) -> anyhow::Result<Option<RemoteStorageConfig>> {
|
||||
let local_path = toml.get("local_path");
|
||||
let bucket_name = toml.get("bucket_name");
|
||||
let bucket_region = toml.get("bucket_region");
|
||||
let container_name = toml.get("container_name");
|
||||
let container_region = toml.get("container_region");
|
||||
|
||||
let use_azure = container_name.is_some() && container_region.is_some();
|
||||
let max_concurrent_syncs = NonZeroUsize::new(
|
||||
parse_optional_integer("max_concurrent_syncs", toml)?
|
||||
.unwrap_or(DEFAULT_REMOTE_STORAGE_MAX_CONCURRENT_SYNCS),
|
||||
)
|
||||
.context("Failed to parse 'max_concurrent_syncs' as a positive integer")?;
|
||||
|
||||
let max_sync_errors = NonZeroU32::new(
|
||||
parse_optional_integer("max_sync_errors", toml)?
|
||||
.unwrap_or(DEFAULT_REMOTE_STORAGE_MAX_SYNC_ERRORS),
|
||||
)
|
||||
.context("Failed to parse 'max_sync_errors' as a positive integer")?;
|
||||
|
||||
let default_concurrency_limit = if use_azure {
|
||||
DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT
|
||||
} else {
|
||||
DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT
|
||||
};
|
||||
let concurrency_limit = NonZeroUsize::new(
|
||||
parse_optional_integer("concurrency_limit", toml)?.unwrap_or(default_concurrency_limit),
|
||||
parse_optional_integer("concurrency_limit", toml)?
|
||||
.unwrap_or(DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT),
|
||||
)
|
||||
.context("Failed to parse 'concurrency_limit' as a positive integer")?;
|
||||
|
||||
@@ -541,73 +451,40 @@ impl RemoteStorageConfig {
|
||||
.context("Failed to parse 'max_keys_per_list_response' as a positive integer")?
|
||||
.or(DEFAULT_MAX_KEYS_PER_LIST_RESPONSE);
|
||||
|
||||
let endpoint = toml
|
||||
.get("endpoint")
|
||||
.map(|endpoint| parse_toml_string("endpoint", endpoint))
|
||||
.transpose()?;
|
||||
|
||||
let storage = match (
|
||||
local_path,
|
||||
bucket_name,
|
||||
bucket_region,
|
||||
container_name,
|
||||
container_region,
|
||||
) {
|
||||
let storage = match (local_path, bucket_name, bucket_region) {
|
||||
// no 'local_path' nor 'bucket_name' options are provided, consider this remote storage disabled
|
||||
(None, None, None, None, None) => return Ok(None),
|
||||
(_, Some(_), None, ..) => {
|
||||
(None, None, None) => return Ok(None),
|
||||
(_, Some(_), None) => {
|
||||
bail!("'bucket_region' option is mandatory if 'bucket_name' is given ")
|
||||
}
|
||||
(_, None, Some(_), ..) => {
|
||||
(_, None, Some(_)) => {
|
||||
bail!("'bucket_name' option is mandatory if 'bucket_region' is given ")
|
||||
}
|
||||
(None, Some(bucket_name), Some(bucket_region), ..) => {
|
||||
RemoteStorageKind::AwsS3(S3Config {
|
||||
bucket_name: parse_toml_string("bucket_name", bucket_name)?,
|
||||
bucket_region: parse_toml_string("bucket_region", bucket_region)?,
|
||||
prefix_in_bucket: toml
|
||||
.get("prefix_in_bucket")
|
||||
.map(|prefix_in_bucket| {
|
||||
parse_toml_string("prefix_in_bucket", prefix_in_bucket)
|
||||
})
|
||||
.transpose()?,
|
||||
endpoint,
|
||||
concurrency_limit,
|
||||
max_keys_per_list_response,
|
||||
})
|
||||
}
|
||||
(_, _, _, Some(_), None) => {
|
||||
bail!("'container_name' option is mandatory if 'container_region' is given ")
|
||||
}
|
||||
(_, _, _, None, Some(_)) => {
|
||||
bail!("'container_name' option is mandatory if 'container_region' is given ")
|
||||
}
|
||||
(None, None, None, Some(container_name), Some(container_region)) => {
|
||||
RemoteStorageKind::AzureContainer(AzureConfig {
|
||||
container_name: parse_toml_string("container_name", container_name)?,
|
||||
container_region: parse_toml_string("container_region", container_region)?,
|
||||
prefix_in_container: toml
|
||||
.get("prefix_in_container")
|
||||
.map(|prefix_in_container| {
|
||||
parse_toml_string("prefix_in_container", prefix_in_container)
|
||||
})
|
||||
.transpose()?,
|
||||
concurrency_limit,
|
||||
max_keys_per_list_response,
|
||||
})
|
||||
}
|
||||
(Some(local_path), None, None, None, None) => RemoteStorageKind::LocalFs(
|
||||
Utf8PathBuf::from(parse_toml_string("local_path", local_path)?),
|
||||
),
|
||||
(Some(_), Some(_), ..) => {
|
||||
bail!("'local_path' and 'bucket_name' are mutually exclusive")
|
||||
}
|
||||
(Some(_), _, _, Some(_), Some(_)) => {
|
||||
bail!("local_path and 'container_name' are mutually exclusive")
|
||||
}
|
||||
(None, Some(bucket_name), Some(bucket_region)) => RemoteStorageKind::AwsS3(S3Config {
|
||||
bucket_name: parse_toml_string("bucket_name", bucket_name)?,
|
||||
bucket_region: parse_toml_string("bucket_region", bucket_region)?,
|
||||
prefix_in_bucket: toml
|
||||
.get("prefix_in_bucket")
|
||||
.map(|prefix_in_bucket| parse_toml_string("prefix_in_bucket", prefix_in_bucket))
|
||||
.transpose()?,
|
||||
endpoint: toml
|
||||
.get("endpoint")
|
||||
.map(|endpoint| parse_toml_string("endpoint", endpoint))
|
||||
.transpose()?,
|
||||
concurrency_limit,
|
||||
max_keys_per_list_response,
|
||||
}),
|
||||
(Some(local_path), None, None) => RemoteStorageKind::LocalFs(PathBuf::from(
|
||||
parse_toml_string("local_path", local_path)?,
|
||||
)),
|
||||
(Some(_), Some(_), _) => bail!("local_path and bucket_name are mutually exclusive"),
|
||||
};
|
||||
|
||||
Ok(Some(RemoteStorageConfig { storage }))
|
||||
Ok(Some(RemoteStorageConfig {
|
||||
max_concurrent_syncs,
|
||||
max_sync_errors,
|
||||
storage,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -636,69 +513,29 @@ fn parse_toml_string(name: &str, item: &Item) -> anyhow::Result<String> {
|
||||
Ok(s.to_string())
|
||||
}
|
||||
|
||||
struct ConcurrencyLimiter {
|
||||
// Every request to S3 can be throttled or cancelled, if a certain number of requests per second is exceeded.
|
||||
// Same goes to IAM, which is queried before every S3 request, if enabled. IAM has even lower RPS threshold.
|
||||
// The helps to ensure we don't exceed the thresholds.
|
||||
write: Arc<Semaphore>,
|
||||
read: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
impl ConcurrencyLimiter {
|
||||
fn for_kind(&self, kind: RequestKind) -> &Arc<Semaphore> {
|
||||
match kind {
|
||||
RequestKind::Get => &self.read,
|
||||
RequestKind::Put => &self.write,
|
||||
RequestKind::List => &self.read,
|
||||
RequestKind::Delete => &self.write,
|
||||
}
|
||||
}
|
||||
|
||||
async fn acquire(
|
||||
&self,
|
||||
kind: RequestKind,
|
||||
) -> Result<tokio::sync::SemaphorePermit<'_>, tokio::sync::AcquireError> {
|
||||
self.for_kind(kind).acquire().await
|
||||
}
|
||||
|
||||
async fn acquire_owned(
|
||||
&self,
|
||||
kind: RequestKind,
|
||||
) -> Result<tokio::sync::OwnedSemaphorePermit, tokio::sync::AcquireError> {
|
||||
Arc::clone(self.for_kind(kind)).acquire_owned().await
|
||||
}
|
||||
|
||||
fn new(limit: usize) -> ConcurrencyLimiter {
|
||||
Self {
|
||||
read: Arc::new(Semaphore::new(limit)),
|
||||
write: Arc::new(Semaphore::new(limit)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_object_name() {
|
||||
let k = RemotePath::new(Utf8Path::new("a/b/c")).unwrap();
|
||||
let k = RemotePath::new(Path::new("a/b/c")).unwrap();
|
||||
assert_eq!(k.object_name(), Some("c"));
|
||||
|
||||
let k = RemotePath::new(Utf8Path::new("a/b/c/")).unwrap();
|
||||
let k = RemotePath::new(Path::new("a/b/c/")).unwrap();
|
||||
assert_eq!(k.object_name(), Some("c"));
|
||||
|
||||
let k = RemotePath::new(Utf8Path::new("a/")).unwrap();
|
||||
let k = RemotePath::new(Path::new("a/")).unwrap();
|
||||
assert_eq!(k.object_name(), Some("a"));
|
||||
|
||||
// XXX is it impossible to have an empty key?
|
||||
let k = RemotePath::new(Utf8Path::new("")).unwrap();
|
||||
let k = RemotePath::new(Path::new("")).unwrap();
|
||||
assert_eq!(k.object_name(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rempte_path_cannot_be_created_from_absolute_ones() {
|
||||
let err = RemotePath::new(Utf8Path::new("/")).expect_err("Should fail on absolute paths");
|
||||
let err = RemotePath::new(Path::new("/")).expect_err("Should fail on absolute paths");
|
||||
assert_eq!(err.to_string(), "Path \"/\" is not relative");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,10 +4,15 @@
|
||||
//! This storage used in tests, but can also be used in cases when a certain persistent
|
||||
//! volume is mounted to the local FS.
|
||||
|
||||
use std::{borrow::Cow, future::Future, io::ErrorKind, pin::Pin};
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
future::Future,
|
||||
io::ErrorKind,
|
||||
path::{Path, PathBuf},
|
||||
pin::Pin,
|
||||
};
|
||||
|
||||
use anyhow::{bail, ensure, Context};
|
||||
use camino::{Utf8Path, Utf8PathBuf};
|
||||
use tokio::{
|
||||
fs,
|
||||
io::{self, AsyncReadExt, AsyncSeekExt, AsyncWriteExt},
|
||||
@@ -15,7 +20,7 @@ use tokio::{
|
||||
use tracing::*;
|
||||
use utils::{crashsafe::path_with_suffix_extension, fs_ext::is_directory_empty};
|
||||
|
||||
use crate::{Download, DownloadError, Listing, ListingMode, RemotePath};
|
||||
use crate::{Download, DownloadError, RemotePath};
|
||||
|
||||
use super::{RemoteStorage, StorageMetadata};
|
||||
|
||||
@@ -23,20 +28,20 @@ const LOCAL_FS_TEMP_FILE_SUFFIX: &str = "___temp";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LocalFs {
|
||||
storage_root: Utf8PathBuf,
|
||||
storage_root: PathBuf,
|
||||
}
|
||||
|
||||
impl LocalFs {
|
||||
/// Attempts to create local FS storage, along with its root directory.
|
||||
/// Storage root will be created (if does not exist) and transformed into an absolute path (if passed as relative).
|
||||
pub fn new(mut storage_root: Utf8PathBuf) -> anyhow::Result<Self> {
|
||||
pub fn new(mut storage_root: PathBuf) -> anyhow::Result<Self> {
|
||||
if !storage_root.exists() {
|
||||
std::fs::create_dir_all(&storage_root).with_context(|| {
|
||||
format!("Failed to create all directories in the given root path {storage_root:?}")
|
||||
})?;
|
||||
}
|
||||
if !storage_root.is_absolute() {
|
||||
storage_root = storage_root.canonicalize_utf8().with_context(|| {
|
||||
storage_root = storage_root.canonicalize().with_context(|| {
|
||||
format!("Failed to represent path {storage_root:?} as an absolute path")
|
||||
})?;
|
||||
}
|
||||
@@ -45,7 +50,7 @@ impl LocalFs {
|
||||
}
|
||||
|
||||
// mirrors S3Bucket::s3_object_to_relative_path
|
||||
fn local_file_to_relative_path(&self, key: Utf8PathBuf) -> RemotePath {
|
||||
fn local_file_to_relative_path(&self, key: PathBuf) -> RemotePath {
|
||||
let relative_path = key
|
||||
.strip_prefix(&self.storage_root)
|
||||
.expect("relative path must contain storage_root as prefix");
|
||||
@@ -54,18 +59,22 @@ impl LocalFs {
|
||||
|
||||
async fn read_storage_metadata(
|
||||
&self,
|
||||
file_path: &Utf8Path,
|
||||
file_path: &Path,
|
||||
) -> anyhow::Result<Option<StorageMetadata>> {
|
||||
let metadata_path = storage_metadata_path(file_path);
|
||||
if metadata_path.exists() && metadata_path.is_file() {
|
||||
let metadata_string = fs::read_to_string(&metadata_path).await.with_context(|| {
|
||||
format!("Failed to read metadata from the local storage at '{metadata_path}'")
|
||||
format!(
|
||||
"Failed to read metadata from the local storage at '{}'",
|
||||
metadata_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
serde_json::from_str(&metadata_string)
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Failed to deserialize metadata from the local storage at '{metadata_path}'",
|
||||
"Failed to deserialize metadata from the local storage at '{}'",
|
||||
metadata_path.display()
|
||||
)
|
||||
})
|
||||
.map(|metadata| Some(StorageMetadata(metadata)))
|
||||
@@ -75,7 +84,7 @@ impl LocalFs {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn list_all(&self) -> anyhow::Result<Vec<RemotePath>> {
|
||||
async fn list(&self) -> anyhow::Result<Vec<RemotePath>> {
|
||||
Ok(get_all_files(&self.storage_root, true)
|
||||
.await?
|
||||
.into_iter()
|
||||
@@ -89,10 +98,52 @@ impl LocalFs {
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl RemoteStorage for LocalFs {
|
||||
async fn list_prefixes(
|
||||
&self,
|
||||
prefix: Option<&RemotePath>,
|
||||
) -> Result<Vec<RemotePath>, DownloadError> {
|
||||
let path = match prefix {
|
||||
Some(prefix) => Cow::Owned(prefix.with_base(&self.storage_root)),
|
||||
None => Cow::Borrowed(&self.storage_root),
|
||||
};
|
||||
|
||||
let prefixes_to_filter = get_all_files(path.as_ref(), false)
|
||||
.await
|
||||
.map_err(DownloadError::Other)?;
|
||||
|
||||
let mut prefixes = Vec::with_capacity(prefixes_to_filter.len());
|
||||
|
||||
// filter out empty directories to mirror s3 behavior.
|
||||
for prefix in prefixes_to_filter {
|
||||
if prefix.is_dir()
|
||||
&& is_directory_empty(&prefix)
|
||||
.await
|
||||
.map_err(DownloadError::Other)?
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
prefixes.push(
|
||||
prefix
|
||||
.strip_prefix(&self.storage_root)
|
||||
.context("Failed to strip prefix")
|
||||
.and_then(RemotePath::new)
|
||||
.expect(
|
||||
"We list files for storage root, hence should be able to remote the prefix",
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
Ok(prefixes)
|
||||
}
|
||||
|
||||
// recursively lists all files in a directory,
|
||||
// mirroring the `list_files` for `s3_bucket`
|
||||
async fn list_recursive(&self, folder: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>> {
|
||||
async fn list_files(&self, folder: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>> {
|
||||
let full_path = match folder {
|
||||
Some(folder) => folder.with_base(&self.storage_root),
|
||||
None => self.storage_root.clone(),
|
||||
@@ -120,21 +171,25 @@ impl LocalFs {
|
||||
}
|
||||
}
|
||||
|
||||
// Note that Utf8PathBuf starts_with only considers full path segments, but
|
||||
// Note that PathBuf starts_with only considers full path segments, but
|
||||
// object prefixes are arbitrary strings, so we need the strings for doing
|
||||
// starts_with later.
|
||||
let prefix = full_path.as_str();
|
||||
let prefix = full_path.to_string_lossy();
|
||||
|
||||
let mut files = vec![];
|
||||
let mut directory_queue = vec![initial_dir];
|
||||
let mut directory_queue = vec![initial_dir.clone()];
|
||||
while let Some(cur_folder) = directory_queue.pop() {
|
||||
let mut entries = cur_folder.read_dir_utf8()?;
|
||||
while let Some(Ok(entry)) = entries.next() {
|
||||
let file_name = entry.file_name();
|
||||
let full_file_name = cur_folder.join(file_name);
|
||||
if full_file_name.as_str().starts_with(prefix) {
|
||||
let mut entries = fs::read_dir(cur_folder.clone()).await?;
|
||||
while let Some(entry) = entries.next_entry().await? {
|
||||
let file_name: PathBuf = entry.file_name().into();
|
||||
let full_file_name = cur_folder.clone().join(&file_name);
|
||||
if full_file_name
|
||||
.to_str()
|
||||
.map(|s| s.starts_with(prefix.as_ref()))
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let file_remote_path = self.local_file_to_relative_path(full_file_name.clone());
|
||||
files.push(file_remote_path);
|
||||
files.push(file_remote_path.clone());
|
||||
if full_file_name.is_dir() {
|
||||
directory_queue.push(full_file_name);
|
||||
}
|
||||
@@ -144,70 +199,6 @@ impl LocalFs {
|
||||
|
||||
Ok(files)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl RemoteStorage for LocalFs {
|
||||
async fn list(
|
||||
&self,
|
||||
prefix: Option<&RemotePath>,
|
||||
mode: ListingMode,
|
||||
) -> Result<Listing, DownloadError> {
|
||||
let mut result = Listing::default();
|
||||
|
||||
if let ListingMode::NoDelimiter = mode {
|
||||
let keys = self
|
||||
.list_recursive(prefix)
|
||||
.await
|
||||
.map_err(DownloadError::Other)?;
|
||||
|
||||
result.keys = keys
|
||||
.into_iter()
|
||||
.filter(|k| {
|
||||
let path = k.with_base(&self.storage_root);
|
||||
!path.is_dir()
|
||||
})
|
||||
.collect();
|
||||
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
let path = match prefix {
|
||||
Some(prefix) => Cow::Owned(prefix.with_base(&self.storage_root)),
|
||||
None => Cow::Borrowed(&self.storage_root),
|
||||
};
|
||||
|
||||
let prefixes_to_filter = get_all_files(path.as_ref(), false)
|
||||
.await
|
||||
.map_err(DownloadError::Other)?;
|
||||
|
||||
// filter out empty directories to mirror s3 behavior.
|
||||
for prefix in prefixes_to_filter {
|
||||
if prefix.is_dir()
|
||||
&& is_directory_empty(&prefix)
|
||||
.await
|
||||
.map_err(DownloadError::Other)?
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let stripped = prefix
|
||||
.strip_prefix(&self.storage_root)
|
||||
.context("Failed to strip prefix")
|
||||
.and_then(RemotePath::new)
|
||||
.expect(
|
||||
"We list files for storage root, hence should be able to remote the prefix",
|
||||
);
|
||||
|
||||
if prefix.is_dir() {
|
||||
result.prefixes.push(stripped);
|
||||
} else {
|
||||
result.keys.push(stripped);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn upload(
|
||||
&self,
|
||||
@@ -239,7 +230,10 @@ impl RemoteStorage for LocalFs {
|
||||
.open(&temp_file_path)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!("Failed to open target fs destination at '{target_file_path}'")
|
||||
format!(
|
||||
"Failed to open target fs destination at '{}'",
|
||||
target_file_path.display()
|
||||
)
|
||||
})?,
|
||||
);
|
||||
|
||||
@@ -250,7 +244,8 @@ impl RemoteStorage for LocalFs {
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Failed to upload file (write temp) to the local storage at '{temp_file_path}'",
|
||||
"Failed to upload file (write temp) to the local storage at '{}'",
|
||||
temp_file_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
@@ -267,7 +262,8 @@ impl RemoteStorage for LocalFs {
|
||||
|
||||
destination.flush().await.with_context(|| {
|
||||
format!(
|
||||
"Failed to upload (flush temp) file to the local storage at '{temp_file_path}'",
|
||||
"Failed to upload (flush temp) file to the local storage at '{}'",
|
||||
temp_file_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
@@ -275,7 +271,8 @@ impl RemoteStorage for LocalFs {
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Failed to upload (rename) file to the local storage at '{target_file_path}'",
|
||||
"Failed to upload (rename) file to the local storage at '{}'",
|
||||
target_file_path.display()
|
||||
)
|
||||
})?;
|
||||
|
||||
@@ -289,7 +286,8 @@ impl RemoteStorage for LocalFs {
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Failed to write metadata to the local storage at '{storage_metadata_path}'",
|
||||
"Failed to write metadata to the local storage at '{}'",
|
||||
storage_metadata_path.display()
|
||||
)
|
||||
})?;
|
||||
}
|
||||
@@ -395,16 +393,16 @@ impl RemoteStorage for LocalFs {
|
||||
}
|
||||
}
|
||||
|
||||
fn storage_metadata_path(original_path: &Utf8Path) -> Utf8PathBuf {
|
||||
fn storage_metadata_path(original_path: &Path) -> PathBuf {
|
||||
path_with_suffix_extension(original_path, "metadata")
|
||||
}
|
||||
|
||||
fn get_all_files<'a, P>(
|
||||
directory_path: P,
|
||||
recursive: bool,
|
||||
) -> Pin<Box<dyn Future<Output = anyhow::Result<Vec<Utf8PathBuf>>> + Send + Sync + 'a>>
|
||||
) -> Pin<Box<dyn Future<Output = anyhow::Result<Vec<PathBuf>>> + Send + Sync + 'a>>
|
||||
where
|
||||
P: AsRef<Utf8Path> + Send + Sync + 'a,
|
||||
P: AsRef<Path> + Send + Sync + 'a,
|
||||
{
|
||||
Box::pin(async move {
|
||||
let directory_path = directory_path.as_ref();
|
||||
@@ -414,13 +412,7 @@ where
|
||||
let mut dir_contents = fs::read_dir(directory_path).await?;
|
||||
while let Some(dir_entry) = dir_contents.next_entry().await? {
|
||||
let file_type = dir_entry.file_type().await?;
|
||||
let entry_path =
|
||||
Utf8PathBuf::from_path_buf(dir_entry.path()).map_err(|pb| {
|
||||
anyhow::Error::msg(format!(
|
||||
"non-Unicode path: {}",
|
||||
pb.to_string_lossy()
|
||||
))
|
||||
})?;
|
||||
let entry_path = dir_entry.path();
|
||||
if file_type.is_symlink() {
|
||||
debug!("{entry_path:?} is a symlink, skipping")
|
||||
} else if file_type.is_dir() {
|
||||
@@ -443,10 +435,13 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
async fn create_target_directory(target_file_path: &Utf8Path) -> anyhow::Result<()> {
|
||||
async fn create_target_directory(target_file_path: &Path) -> anyhow::Result<()> {
|
||||
let target_dir = match target_file_path.parent() {
|
||||
Some(parent_dir) => parent_dir,
|
||||
None => bail!("File path '{target_file_path}' has no parent directory"),
|
||||
None => bail!(
|
||||
"File path '{}' has no parent directory",
|
||||
target_file_path.display()
|
||||
),
|
||||
};
|
||||
if !target_dir.exists() {
|
||||
fs::create_dir_all(target_dir).await?;
|
||||
@@ -454,9 +449,13 @@ async fn create_target_directory(target_file_path: &Utf8Path) -> anyhow::Result<
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn file_exists(file_path: &Utf8Path) -> anyhow::Result<bool> {
|
||||
fn file_exists(file_path: &Path) -> anyhow::Result<bool> {
|
||||
if file_path.exists() {
|
||||
ensure!(file_path.is_file(), "file path '{file_path}' is not a file");
|
||||
ensure!(
|
||||
file_path.is_file(),
|
||||
"file path '{}' is not a file",
|
||||
file_path.display()
|
||||
);
|
||||
Ok(true)
|
||||
} else {
|
||||
Ok(false)
|
||||
@@ -467,13 +466,13 @@ fn file_exists(file_path: &Utf8Path) -> anyhow::Result<bool> {
|
||||
mod fs_tests {
|
||||
use super::*;
|
||||
|
||||
use camino_tempfile::tempdir;
|
||||
use std::{collections::HashMap, io::Write};
|
||||
use tempfile::tempdir;
|
||||
|
||||
async fn read_and_assert_remote_file_contents(
|
||||
storage: &LocalFs,
|
||||
#[allow(clippy::ptr_arg)]
|
||||
// have to use &Utf8PathBuf due to `storage.local_path` parameter requirements
|
||||
// have to use &PathBuf due to `storage.local_path` parameter requirements
|
||||
remote_storage_path: &RemotePath,
|
||||
expected_metadata: Option<&StorageMetadata>,
|
||||
) -> anyhow::Result<String> {
|
||||
@@ -501,7 +500,7 @@ mod fs_tests {
|
||||
|
||||
let target_path_1 = upload_dummy_file(&storage, "upload_1", None).await?;
|
||||
assert_eq!(
|
||||
storage.list_all().await?,
|
||||
storage.list().await?,
|
||||
vec![target_path_1.clone()],
|
||||
"Should list a single file after first upload"
|
||||
);
|
||||
@@ -520,7 +519,7 @@ mod fs_tests {
|
||||
async fn upload_file_negatives() -> anyhow::Result<()> {
|
||||
let storage = create_storage()?;
|
||||
|
||||
let id = RemotePath::new(Utf8Path::new("dummy"))?;
|
||||
let id = RemotePath::new(Path::new("dummy"))?;
|
||||
let content = std::io::Cursor::new(b"12345");
|
||||
|
||||
// Check that you get an error if the size parameter doesn't match the actual
|
||||
@@ -545,8 +544,7 @@ mod fs_tests {
|
||||
}
|
||||
|
||||
fn create_storage() -> anyhow::Result<LocalFs> {
|
||||
let storage_root = tempdir()?.path().to_path_buf();
|
||||
LocalFs::new(storage_root)
|
||||
LocalFs::new(tempdir()?.path().to_owned())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -563,7 +561,7 @@ mod fs_tests {
|
||||
);
|
||||
|
||||
let non_existing_path = "somewhere/else";
|
||||
match storage.download(&RemotePath::new(Utf8Path::new(non_existing_path))?).await {
|
||||
match storage.download(&RemotePath::new(Path::new(non_existing_path))?).await {
|
||||
Err(DownloadError::NotFound) => {} // Should get NotFound for non existing keys
|
||||
other => panic!("Should get a NotFound error when downloading non-existing storage files, but got: {other:?}"),
|
||||
}
|
||||
@@ -689,7 +687,7 @@ mod fs_tests {
|
||||
let upload_target = upload_dummy_file(&storage, upload_name, None).await?;
|
||||
|
||||
storage.delete(&upload_target).await?;
|
||||
assert!(storage.list_all().await?.is_empty());
|
||||
assert!(storage.list().await?.is_empty());
|
||||
|
||||
storage
|
||||
.delete(&upload_target)
|
||||
@@ -747,43 +745,6 @@ mod fs_tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list() -> anyhow::Result<()> {
|
||||
// No delimiter: should recursively list everything
|
||||
let storage = create_storage()?;
|
||||
let child = upload_dummy_file(&storage, "grandparent/parent/child", None).await?;
|
||||
let uncle = upload_dummy_file(&storage, "grandparent/uncle", None).await?;
|
||||
|
||||
let listing = storage.list(None, ListingMode::NoDelimiter).await?;
|
||||
assert!(listing.prefixes.is_empty());
|
||||
assert_eq!(listing.keys, [uncle.clone(), child.clone()].to_vec());
|
||||
|
||||
// Delimiter: should only go one deep
|
||||
let listing = storage.list(None, ListingMode::WithDelimiter).await?;
|
||||
|
||||
assert_eq!(
|
||||
listing.prefixes,
|
||||
[RemotePath::from_string("timelines").unwrap()].to_vec()
|
||||
);
|
||||
assert!(listing.keys.is_empty());
|
||||
|
||||
// Delimiter & prefix
|
||||
let listing = storage
|
||||
.list(
|
||||
Some(&RemotePath::from_string("timelines/some_timeline/grandparent").unwrap()),
|
||||
ListingMode::WithDelimiter,
|
||||
)
|
||||
.await?;
|
||||
assert_eq!(
|
||||
listing.prefixes,
|
||||
[RemotePath::from_string("timelines/some_timeline/grandparent/parent").unwrap()]
|
||||
.to_vec()
|
||||
);
|
||||
assert_eq!(listing.keys, [uncle.clone()].to_vec());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn upload_dummy_file(
|
||||
storage: &LocalFs,
|
||||
name: &str,
|
||||
@@ -814,7 +775,7 @@ mod fs_tests {
|
||||
}
|
||||
|
||||
async fn create_file_for_upload(
|
||||
path: &Utf8Path,
|
||||
path: &Path,
|
||||
contents: &str,
|
||||
) -> anyhow::Result<(io::BufReader<fs::File>, usize)> {
|
||||
std::fs::create_dir_all(path.parent().unwrap())?;
|
||||
@@ -836,7 +797,7 @@ mod fs_tests {
|
||||
}
|
||||
|
||||
async fn list_files_sorted(storage: &LocalFs) -> anyhow::Result<Vec<RemotePath>> {
|
||||
let mut files = storage.list_all().await?;
|
||||
let mut files = storage.list().await?;
|
||||
files.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
Ok(files)
|
||||
}
|
||||
|
||||
@@ -4,44 +4,42 @@
|
||||
//! allowing multiple api users to independently work with the same S3 bucket, if
|
||||
//! their bucket prefixes are both specified and different.
|
||||
|
||||
use std::{borrow::Cow, sync::Arc};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context;
|
||||
use aws_config::{
|
||||
environment::credentials::EnvironmentVariableCredentialsProvider,
|
||||
imds::credentials::ImdsCredentialsProvider,
|
||||
meta::credentials::CredentialsProviderChain,
|
||||
provider_config::ProviderConfig,
|
||||
retry::{RetryConfigBuilder, RetryMode},
|
||||
web_identity_token::WebIdentityTokenCredentialsProvider,
|
||||
imds::credentials::ImdsCredentialsProvider, meta::credentials::CredentialsProviderChain,
|
||||
provider_config::ProviderConfig, web_identity_token::WebIdentityTokenCredentialsProvider,
|
||||
};
|
||||
use aws_credential_types::cache::CredentialsCache;
|
||||
use aws_sdk_s3::{
|
||||
config::{AsyncSleep, Config, Region, SharedAsyncSleep},
|
||||
config::{Config, Region},
|
||||
error::SdkError,
|
||||
operation::get_object::GetObjectError,
|
||||
primitives::ByteStream,
|
||||
types::{Delete, ObjectIdentifier},
|
||||
Client,
|
||||
};
|
||||
use aws_smithy_async::rt::sleep::TokioSleep;
|
||||
use aws_smithy_http::body::SdkBody;
|
||||
use hyper::Body;
|
||||
use scopeguard::ScopeGuard;
|
||||
use tokio::io::{self, AsyncRead};
|
||||
use tokio::{
|
||||
io::{self, AsyncRead},
|
||||
sync::Semaphore,
|
||||
};
|
||||
use tokio_util::io::ReaderStream;
|
||||
use tracing::debug;
|
||||
|
||||
use super::StorageMetadata;
|
||||
use crate::{
|
||||
ConcurrencyLimiter, Download, DownloadError, Listing, ListingMode, RemotePath, RemoteStorage,
|
||||
S3Config, MAX_KEYS_PER_DELETE, REMOTE_STORAGE_PREFIX_SEPARATOR,
|
||||
Download, DownloadError, RemotePath, RemoteStorage, S3Config, MAX_KEYS_PER_DELETE,
|
||||
REMOTE_STORAGE_PREFIX_SEPARATOR,
|
||||
};
|
||||
|
||||
pub(super) mod metrics;
|
||||
|
||||
use self::metrics::AttemptOutcome;
|
||||
pub(super) use self::metrics::RequestKind;
|
||||
use self::metrics::{AttemptOutcome, RequestKind};
|
||||
|
||||
/// AWS S3 storage.
|
||||
pub struct S3Bucket {
|
||||
@@ -49,7 +47,10 @@ pub struct S3Bucket {
|
||||
bucket_name: String,
|
||||
prefix_in_bucket: Option<String>,
|
||||
max_keys_per_list_response: Option<i32>,
|
||||
concurrency_limiter: ConcurrencyLimiter,
|
||||
// Every request to S3 can be throttled or cancelled, if a certain number of requests per second is exceeded.
|
||||
// Same goes to IAM, which is queried before every S3 request, if enabled. IAM has even lower RPS threshold.
|
||||
// The helps to ensure we don't exceed the thresholds.
|
||||
concurrency_limiter: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -87,23 +88,10 @@ impl S3Bucket {
|
||||
.or_else("imds", ImdsCredentialsProvider::builder().build())
|
||||
};
|
||||
|
||||
// AWS SDK requires us to specify how the RetryConfig should sleep when it wants to back off
|
||||
let sleep_impl: Arc<dyn AsyncSleep> = Arc::new(TokioSleep::new());
|
||||
|
||||
// We do our own retries (see [`backoff::retry`]). However, for the AWS SDK to enable rate limiting in response to throttling
|
||||
// responses (e.g. 429 on too many ListObjectsv2 requests), we must provide a retry config. We set it to use at most one
|
||||
// attempt, and enable 'Adaptive' mode, which causes rate limiting to be enabled.
|
||||
let mut retry_config = RetryConfigBuilder::new();
|
||||
retry_config
|
||||
.set_max_attempts(Some(1))
|
||||
.set_mode(Some(RetryMode::Adaptive));
|
||||
|
||||
let mut config_builder = Config::builder()
|
||||
.region(region)
|
||||
.credentials_cache(CredentialsCache::lazy())
|
||||
.credentials_provider(credentials_provider)
|
||||
.sleep_impl(SharedAsyncSleep::from(sleep_impl))
|
||||
.retry_config(retry_config.build());
|
||||
.credentials_provider(credentials_provider);
|
||||
|
||||
if let Some(custom_endpoint) = aws_config.endpoint.clone() {
|
||||
config_builder = config_builder
|
||||
@@ -129,7 +117,7 @@ impl S3Bucket {
|
||||
bucket_name: aws_config.bucket_name.clone(),
|
||||
max_keys_per_list_response: aws_config.max_keys_per_list_response,
|
||||
prefix_in_bucket,
|
||||
concurrency_limiter: ConcurrencyLimiter::new(aws_config.concurrency_limit.get()),
|
||||
concurrency_limiter: Arc::new(Semaphore::new(aws_config.concurrency_limit.get())),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -155,11 +143,12 @@ impl S3Bucket {
|
||||
assert_eq!(std::path::MAIN_SEPARATOR, REMOTE_STORAGE_PREFIX_SEPARATOR);
|
||||
let path_string = path
|
||||
.get_path()
|
||||
.as_str()
|
||||
.trim_end_matches(REMOTE_STORAGE_PREFIX_SEPARATOR);
|
||||
.to_string_lossy()
|
||||
.trim_end_matches(REMOTE_STORAGE_PREFIX_SEPARATOR)
|
||||
.to_string();
|
||||
match &self.prefix_in_bucket {
|
||||
Some(prefix) => prefix.clone() + "/" + path_string,
|
||||
None => path_string.to_string(),
|
||||
Some(prefix) => prefix.clone() + "/" + &path_string,
|
||||
None => path_string,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -167,7 +156,7 @@ impl S3Bucket {
|
||||
let started_at = start_counting_cancelled_wait(kind);
|
||||
let permit = self
|
||||
.concurrency_limiter
|
||||
.acquire(kind)
|
||||
.acquire()
|
||||
.await
|
||||
.expect("semaphore is never closed");
|
||||
|
||||
@@ -183,7 +172,8 @@ impl S3Bucket {
|
||||
let started_at = start_counting_cancelled_wait(kind);
|
||||
let permit = self
|
||||
.concurrency_limiter
|
||||
.acquire_owned(kind)
|
||||
.clone()
|
||||
.acquire_owned()
|
||||
.await
|
||||
.expect("semaphore is never closed");
|
||||
|
||||
@@ -316,13 +306,13 @@ impl<S: AsyncRead> AsyncRead for TimedDownload<S> {
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl RemoteStorage for S3Bucket {
|
||||
async fn list(
|
||||
/// See the doc for `RemoteStorage::list_prefixes`
|
||||
/// Note: it wont include empty "directories"
|
||||
async fn list_prefixes(
|
||||
&self,
|
||||
prefix: Option<&RemotePath>,
|
||||
mode: ListingMode,
|
||||
) -> Result<Listing, DownloadError> {
|
||||
) -> Result<Vec<RemotePath>, DownloadError> {
|
||||
let kind = RequestKind::List;
|
||||
let mut result = Listing::default();
|
||||
|
||||
// get the passed prefix or if it is not set use prefix_in_bucket value
|
||||
let list_prefix = prefix
|
||||
@@ -331,33 +321,28 @@ impl RemoteStorage for S3Bucket {
|
||||
.map(|mut p| {
|
||||
// required to end with a separator
|
||||
// otherwise request will return only the entry of a prefix
|
||||
if matches!(mode, ListingMode::WithDelimiter)
|
||||
&& !p.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR)
|
||||
{
|
||||
if !p.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
|
||||
p.push(REMOTE_STORAGE_PREFIX_SEPARATOR);
|
||||
}
|
||||
p
|
||||
});
|
||||
|
||||
let mut document_keys = Vec::new();
|
||||
|
||||
let mut continuation_token = None;
|
||||
|
||||
loop {
|
||||
let _guard = self.permit(kind).await;
|
||||
let started_at = start_measuring_requests(kind);
|
||||
|
||||
let mut request = self
|
||||
let fetch_response = self
|
||||
.client
|
||||
.list_objects_v2()
|
||||
.bucket(self.bucket_name.clone())
|
||||
.set_prefix(list_prefix.clone())
|
||||
.set_continuation_token(continuation_token)
|
||||
.set_max_keys(self.max_keys_per_list_response);
|
||||
|
||||
if let ListingMode::WithDelimiter = mode {
|
||||
request = request.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string());
|
||||
}
|
||||
|
||||
let response = request
|
||||
.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string())
|
||||
.set_max_keys(self.max_keys_per_list_response)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to list S3 prefixes")
|
||||
@@ -367,35 +352,71 @@ impl RemoteStorage for S3Bucket {
|
||||
|
||||
metrics::BUCKET_METRICS
|
||||
.req_seconds
|
||||
.observe_elapsed(kind, &response, started_at);
|
||||
.observe_elapsed(kind, &fetch_response, started_at);
|
||||
|
||||
let response = response?;
|
||||
let fetch_response = fetch_response?;
|
||||
|
||||
let keys = response.contents().unwrap_or_default();
|
||||
let empty = Vec::new();
|
||||
let prefixes = response.common_prefixes.as_ref().unwrap_or(&empty);
|
||||
|
||||
tracing::info!("list: {} prefixes, {} keys", prefixes.len(), keys.len());
|
||||
|
||||
for object in keys {
|
||||
let object_path = object.key().expect("response does not contain a key");
|
||||
let remote_path = self.s3_object_to_relative_path(object_path);
|
||||
result.keys.push(remote_path);
|
||||
}
|
||||
|
||||
result.prefixes.extend(
|
||||
prefixes
|
||||
.iter()
|
||||
document_keys.extend(
|
||||
fetch_response
|
||||
.common_prefixes
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.filter_map(|o| Some(self.s3_object_to_relative_path(o.prefix()?))),
|
||||
);
|
||||
|
||||
continuation_token = match response.next_continuation_token {
|
||||
continuation_token = match fetch_response.next_continuation_token {
|
||||
Some(new_token) => Some(new_token),
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
Ok(document_keys)
|
||||
}
|
||||
|
||||
/// See the doc for `RemoteStorage::list_files`
|
||||
async fn list_files(&self, folder: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>> {
|
||||
let kind = RequestKind::List;
|
||||
|
||||
let folder_name = folder
|
||||
.map(|p| self.relative_path_to_s3_object(p))
|
||||
.or_else(|| self.prefix_in_bucket.clone());
|
||||
|
||||
// AWS may need to break the response into several parts
|
||||
let mut continuation_token = None;
|
||||
let mut all_files = vec![];
|
||||
loop {
|
||||
let _guard = self.permit(kind).await;
|
||||
let started_at = start_measuring_requests(kind);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.list_objects_v2()
|
||||
.bucket(self.bucket_name.clone())
|
||||
.set_prefix(folder_name.clone())
|
||||
.set_continuation_token(continuation_token)
|
||||
.set_max_keys(self.max_keys_per_list_response)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to list files in S3 bucket");
|
||||
|
||||
let started_at = ScopeGuard::into_inner(started_at);
|
||||
metrics::BUCKET_METRICS
|
||||
.req_seconds
|
||||
.observe_elapsed(kind, &response, started_at);
|
||||
|
||||
let response = response?;
|
||||
|
||||
for object in response.contents().unwrap_or_default() {
|
||||
let object_path = object.key().expect("response does not contain a key");
|
||||
let remote_path = self.s3_object_to_relative_path(object_path);
|
||||
all_files.push(remote_path);
|
||||
}
|
||||
match response.next_continuation_token {
|
||||
Some(new_token) => continuation_token = Some(new_token),
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
Ok(all_files)
|
||||
}
|
||||
|
||||
async fn upload(
|
||||
@@ -500,20 +521,6 @@ impl RemoteStorage for S3Bucket {
|
||||
.deleted_objects_total
|
||||
.inc_by(chunk.len() as u64);
|
||||
if let Some(errors) = resp.errors {
|
||||
// Log a bounded number of the errors within the response:
|
||||
// these requests can carry 1000 keys so logging each one
|
||||
// would be too verbose, especially as errors may lead us
|
||||
// to retry repeatedly.
|
||||
const LOG_UP_TO_N_ERRORS: usize = 10;
|
||||
for e in errors.iter().take(LOG_UP_TO_N_ERRORS) {
|
||||
tracing::warn!(
|
||||
"DeleteObjects key {} failed: {}: {}",
|
||||
e.key.as_ref().map(Cow::from).unwrap_or("".into()),
|
||||
e.code.as_ref().map(Cow::from).unwrap_or("".into()),
|
||||
e.message.as_ref().map(Cow::from).unwrap_or("".into())
|
||||
);
|
||||
}
|
||||
|
||||
return Err(anyhow::format_err!(
|
||||
"Failed to delete {} objects",
|
||||
errors.len()
|
||||
@@ -558,8 +565,8 @@ fn start_measuring_requests(
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use camino::Utf8Path;
|
||||
use std::num::NonZeroUsize;
|
||||
use std::path::Path;
|
||||
|
||||
use crate::{RemotePath, S3Bucket, S3Config};
|
||||
|
||||
@@ -568,7 +575,7 @@ mod tests {
|
||||
let all_paths = ["", "some/path", "some/path/"];
|
||||
let all_paths: Vec<RemotePath> = all_paths
|
||||
.iter()
|
||||
.map(|x| RemotePath::new(Utf8Path::new(x)).expect("bad path"))
|
||||
.map(|x| RemotePath::new(Path::new(x)).expect("bad path"))
|
||||
.collect();
|
||||
let prefixes = [
|
||||
None,
|
||||
|
||||
@@ -6,7 +6,7 @@ use once_cell::sync::Lazy;
|
||||
pub(super) static BUCKET_METRICS: Lazy<BucketMetrics> = Lazy::new(Default::default);
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub(crate) enum RequestKind {
|
||||
pub(super) enum RequestKind {
|
||||
Get = 0,
|
||||
Put = 1,
|
||||
Delete = 2,
|
||||
|
||||
@@ -5,9 +5,7 @@ use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use crate::{
|
||||
Download, DownloadError, Listing, ListingMode, RemotePath, RemoteStorage, StorageMetadata,
|
||||
};
|
||||
use crate::{Download, DownloadError, RemotePath, RemoteStorage, StorageMetadata};
|
||||
|
||||
pub struct UnreliableWrapper {
|
||||
inner: crate::GenericRemoteStorage,
|
||||
@@ -97,15 +95,6 @@ impl RemoteStorage for UnreliableWrapper {
|
||||
self.inner.list_files(folder).await
|
||||
}
|
||||
|
||||
async fn list(
|
||||
&self,
|
||||
prefix: Option<&RemotePath>,
|
||||
mode: ListingMode,
|
||||
) -> Result<Listing, DownloadError> {
|
||||
self.attempt(RemoteOp::ListPrefixes(prefix.cloned()))?;
|
||||
self.inner.list(prefix, mode).await
|
||||
}
|
||||
|
||||
async fn upload(
|
||||
&self,
|
||||
data: impl tokio::io::AsyncRead + Unpin + Send + Sync + 'static,
|
||||
|
||||
@@ -1,623 +0,0 @@
|
||||
use std::collections::HashSet;
|
||||
use std::env;
|
||||
use std::num::NonZeroUsize;
|
||||
use std::ops::ControlFlow;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::UNIX_EPOCH;
|
||||
|
||||
use anyhow::Context;
|
||||
use camino::Utf8Path;
|
||||
use once_cell::sync::OnceCell;
|
||||
use remote_storage::{
|
||||
AzureConfig, Download, GenericRemoteStorage, RemotePath, RemoteStorageConfig, RemoteStorageKind,
|
||||
};
|
||||
use test_context::{test_context, AsyncTestContext};
|
||||
use tokio::task::JoinSet;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
static LOGGING_DONE: OnceCell<()> = OnceCell::new();
|
||||
|
||||
const ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME: &str = "ENABLE_REAL_AZURE_REMOTE_STORAGE";
|
||||
|
||||
const BASE_PREFIX: &str = "test";
|
||||
|
||||
/// Tests that the Azure client can list all prefixes, even if the response comes paginated and requires multiple HTTP queries.
|
||||
/// Uses real Azure and requires [`ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME`] and related Azure cred env vars specified.
|
||||
/// See the client creation in [`create_azure_client`] for details on the required env vars.
|
||||
/// If real Azure tests are disabled, the test passes, skipping any real test run: currently, there's no way to mark the test ignored in runtime with the
|
||||
/// deafult test framework, see https://github.com/rust-lang/rust/issues/68007 for details.
|
||||
///
|
||||
/// First, the test creates a set of Azure blobs with keys `/${random_prefix_part}/${base_prefix_str}/sub_prefix_${i}/blob_${i}` in [`upload_azure_data`]
|
||||
/// where
|
||||
/// * `random_prefix_part` is set for the entire Azure client during the Azure client creation in [`create_azure_client`], to avoid multiple test runs interference
|
||||
/// * `base_prefix_str` is a common prefix to use in the client requests: we would want to ensure that the client is able to list nested prefixes inside the bucket
|
||||
///
|
||||
/// Then, verifies that the client does return correct prefixes when queried:
|
||||
/// * with no prefix, it lists everything after its `${random_prefix_part}/` — that should be `${base_prefix_str}` value only
|
||||
/// * with `${base_prefix_str}/` prefix, it lists every `sub_prefix_${i}`
|
||||
///
|
||||
/// With the real Azure enabled and `#[cfg(test)]` Rust configuration used, the Azure client test adds a `max-keys` param to limit the response keys.
|
||||
/// This way, we are able to test the pagination implicitly, by ensuring all results are returned from the remote storage and avoid uploading too many blobs to Azure.
|
||||
///
|
||||
/// Lastly, the test attempts to clean up and remove all uploaded Azure files.
|
||||
/// If any errors appear during the clean up, they get logged, but the test is not failed or stopped until clean up is finished.
|
||||
#[test_context(MaybeEnabledAzureWithTestBlobs)]
|
||||
#[tokio::test]
|
||||
async fn azure_pagination_should_work(
|
||||
ctx: &mut MaybeEnabledAzureWithTestBlobs,
|
||||
) -> anyhow::Result<()> {
|
||||
let ctx = match ctx {
|
||||
MaybeEnabledAzureWithTestBlobs::Enabled(ctx) => ctx,
|
||||
MaybeEnabledAzureWithTestBlobs::Disabled => return Ok(()),
|
||||
MaybeEnabledAzureWithTestBlobs::UploadsFailed(e, _) => {
|
||||
anyhow::bail!("Azure init failed: {e:?}")
|
||||
}
|
||||
};
|
||||
|
||||
let test_client = Arc::clone(&ctx.enabled.client);
|
||||
let expected_remote_prefixes = ctx.remote_prefixes.clone();
|
||||
|
||||
let base_prefix = RemotePath::new(Utf8Path::new(ctx.enabled.base_prefix))
|
||||
.context("common_prefix construction")?;
|
||||
let root_remote_prefixes = test_client
|
||||
.list_prefixes(None)
|
||||
.await
|
||||
.context("client list root prefixes failure")?
|
||||
.into_iter()
|
||||
.collect::<HashSet<_>>();
|
||||
assert_eq!(
|
||||
root_remote_prefixes, HashSet::from([base_prefix.clone()]),
|
||||
"remote storage root prefixes list mismatches with the uploads. Returned prefixes: {root_remote_prefixes:?}"
|
||||
);
|
||||
|
||||
let nested_remote_prefixes = test_client
|
||||
.list_prefixes(Some(&base_prefix))
|
||||
.await
|
||||
.context("client list nested prefixes failure")?
|
||||
.into_iter()
|
||||
.collect::<HashSet<_>>();
|
||||
let remote_only_prefixes = nested_remote_prefixes
|
||||
.difference(&expected_remote_prefixes)
|
||||
.collect::<HashSet<_>>();
|
||||
let missing_uploaded_prefixes = expected_remote_prefixes
|
||||
.difference(&nested_remote_prefixes)
|
||||
.collect::<HashSet<_>>();
|
||||
assert_eq!(
|
||||
remote_only_prefixes.len() + missing_uploaded_prefixes.len(), 0,
|
||||
"remote storage nested prefixes list mismatches with the uploads. Remote only prefixes: {remote_only_prefixes:?}, missing uploaded prefixes: {missing_uploaded_prefixes:?}",
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Tests that Azure client can list all files in a folder, even if the response comes paginated and requirees multiple Azure queries.
|
||||
/// Uses real Azure and requires [`ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME`] and related Azure cred env vars specified. Test will skip real code and pass if env vars not set.
|
||||
/// See `Azure_pagination_should_work` for more information.
|
||||
///
|
||||
/// First, create a set of Azure objects with keys `random_prefix/folder{j}/blob_{i}.txt` in [`upload_azure_data`]
|
||||
/// Then performs the following queries:
|
||||
/// 1. `list_files(None)`. This should return all files `random_prefix/folder{j}/blob_{i}.txt`
|
||||
/// 2. `list_files("folder1")`. This should return all files `random_prefix/folder1/blob_{i}.txt`
|
||||
#[test_context(MaybeEnabledAzureWithSimpleTestBlobs)]
|
||||
#[tokio::test]
|
||||
async fn azure_list_files_works(
|
||||
ctx: &mut MaybeEnabledAzureWithSimpleTestBlobs,
|
||||
) -> anyhow::Result<()> {
|
||||
let ctx = match ctx {
|
||||
MaybeEnabledAzureWithSimpleTestBlobs::Enabled(ctx) => ctx,
|
||||
MaybeEnabledAzureWithSimpleTestBlobs::Disabled => return Ok(()),
|
||||
MaybeEnabledAzureWithSimpleTestBlobs::UploadsFailed(e, _) => {
|
||||
anyhow::bail!("Azure init failed: {e:?}")
|
||||
}
|
||||
};
|
||||
let test_client = Arc::clone(&ctx.enabled.client);
|
||||
let base_prefix =
|
||||
RemotePath::new(Utf8Path::new("folder1")).context("common_prefix construction")?;
|
||||
let root_files = test_client
|
||||
.list_files(None)
|
||||
.await
|
||||
.context("client list root files failure")?
|
||||
.into_iter()
|
||||
.collect::<HashSet<_>>();
|
||||
assert_eq!(
|
||||
root_files,
|
||||
ctx.remote_blobs.clone(),
|
||||
"remote storage list_files on root mismatches with the uploads."
|
||||
);
|
||||
let nested_remote_files = test_client
|
||||
.list_files(Some(&base_prefix))
|
||||
.await
|
||||
.context("client list nested files failure")?
|
||||
.into_iter()
|
||||
.collect::<HashSet<_>>();
|
||||
let trim_remote_blobs: HashSet<_> = ctx
|
||||
.remote_blobs
|
||||
.iter()
|
||||
.map(|x| x.get_path())
|
||||
.filter(|x| x.starts_with("folder1"))
|
||||
.map(|x| RemotePath::new(x).expect("must be valid path"))
|
||||
.collect();
|
||||
assert_eq!(
|
||||
nested_remote_files, trim_remote_blobs,
|
||||
"remote storage list_files on subdirrectory mismatches with the uploads."
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test_context(MaybeEnabledAzure)]
|
||||
#[tokio::test]
|
||||
async fn azure_delete_non_exising_works(ctx: &mut MaybeEnabledAzure) -> anyhow::Result<()> {
|
||||
let ctx = match ctx {
|
||||
MaybeEnabledAzure::Enabled(ctx) => ctx,
|
||||
MaybeEnabledAzure::Disabled => return Ok(()),
|
||||
};
|
||||
|
||||
let path = RemotePath::new(Utf8Path::new(
|
||||
format!("{}/for_sure_there_is_nothing_there_really", ctx.base_prefix).as_str(),
|
||||
))
|
||||
.with_context(|| "RemotePath conversion")?;
|
||||
|
||||
ctx.client.delete(&path).await.expect("should succeed");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test_context(MaybeEnabledAzure)]
|
||||
#[tokio::test]
|
||||
async fn azure_delete_objects_works(ctx: &mut MaybeEnabledAzure) -> anyhow::Result<()> {
|
||||
let ctx = match ctx {
|
||||
MaybeEnabledAzure::Enabled(ctx) => ctx,
|
||||
MaybeEnabledAzure::Disabled => return Ok(()),
|
||||
};
|
||||
|
||||
let path1 = RemotePath::new(Utf8Path::new(format!("{}/path1", ctx.base_prefix).as_str()))
|
||||
.with_context(|| "RemotePath conversion")?;
|
||||
|
||||
let path2 = RemotePath::new(Utf8Path::new(format!("{}/path2", ctx.base_prefix).as_str()))
|
||||
.with_context(|| "RemotePath conversion")?;
|
||||
|
||||
let path3 = RemotePath::new(Utf8Path::new(format!("{}/path3", ctx.base_prefix).as_str()))
|
||||
.with_context(|| "RemotePath conversion")?;
|
||||
|
||||
let data1 = "remote blob data1".as_bytes();
|
||||
let data1_len = data1.len();
|
||||
let data2 = "remote blob data2".as_bytes();
|
||||
let data2_len = data2.len();
|
||||
let data3 = "remote blob data3".as_bytes();
|
||||
let data3_len = data3.len();
|
||||
ctx.client
|
||||
.upload(std::io::Cursor::new(data1), data1_len, &path1, None)
|
||||
.await?;
|
||||
|
||||
ctx.client
|
||||
.upload(std::io::Cursor::new(data2), data2_len, &path2, None)
|
||||
.await?;
|
||||
|
||||
ctx.client
|
||||
.upload(std::io::Cursor::new(data3), data3_len, &path3, None)
|
||||
.await?;
|
||||
|
||||
ctx.client.delete_objects(&[path1, path2]).await?;
|
||||
|
||||
let prefixes = ctx.client.list_prefixes(None).await?;
|
||||
|
||||
assert_eq!(prefixes.len(), 1);
|
||||
|
||||
ctx.client.delete_objects(&[path3]).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test_context(MaybeEnabledAzure)]
|
||||
#[tokio::test]
|
||||
async fn azure_upload_download_works(ctx: &mut MaybeEnabledAzure) -> anyhow::Result<()> {
|
||||
let MaybeEnabledAzure::Enabled(ctx) = ctx else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let path = RemotePath::new(Utf8Path::new(format!("{}/file", ctx.base_prefix).as_str()))
|
||||
.with_context(|| "RemotePath conversion")?;
|
||||
|
||||
let data = "remote blob data here".as_bytes();
|
||||
let data_len = data.len() as u64;
|
||||
|
||||
ctx.client
|
||||
.upload(std::io::Cursor::new(data), data.len(), &path, None)
|
||||
.await?;
|
||||
|
||||
async fn download_and_compare(mut dl: Download) -> anyhow::Result<Vec<u8>> {
|
||||
let mut buf = Vec::new();
|
||||
tokio::io::copy(&mut dl.download_stream, &mut buf).await?;
|
||||
Ok(buf)
|
||||
}
|
||||
// Normal download request
|
||||
let dl = ctx.client.download(&path).await?;
|
||||
let buf = download_and_compare(dl).await?;
|
||||
assert_eq!(buf, data);
|
||||
|
||||
// Full range (end specified)
|
||||
let dl = ctx
|
||||
.client
|
||||
.download_byte_range(&path, 0, Some(data_len))
|
||||
.await?;
|
||||
let buf = download_and_compare(dl).await?;
|
||||
assert_eq!(buf, data);
|
||||
|
||||
// partial range (end specified)
|
||||
let dl = ctx.client.download_byte_range(&path, 4, Some(10)).await?;
|
||||
let buf = download_and_compare(dl).await?;
|
||||
assert_eq!(buf, data[4..10]);
|
||||
|
||||
// partial range (end beyond real end)
|
||||
let dl = ctx
|
||||
.client
|
||||
.download_byte_range(&path, 8, Some(data_len * 100))
|
||||
.await?;
|
||||
let buf = download_and_compare(dl).await?;
|
||||
assert_eq!(buf, data[8..]);
|
||||
|
||||
// Partial range (end unspecified)
|
||||
let dl = ctx.client.download_byte_range(&path, 4, None).await?;
|
||||
let buf = download_and_compare(dl).await?;
|
||||
assert_eq!(buf, data[4..]);
|
||||
|
||||
// Full range (end unspecified)
|
||||
let dl = ctx.client.download_byte_range(&path, 0, None).await?;
|
||||
let buf = download_and_compare(dl).await?;
|
||||
assert_eq!(buf, data);
|
||||
|
||||
debug!("Cleanup: deleting file at path {path:?}");
|
||||
ctx.client
|
||||
.delete(&path)
|
||||
.await
|
||||
.with_context(|| format!("{path:?} removal"))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ensure_logging_ready() {
|
||||
LOGGING_DONE.get_or_init(|| {
|
||||
utils::logging::init(
|
||||
utils::logging::LogFormat::Test,
|
||||
utils::logging::TracingErrorLayerEnablement::Disabled,
|
||||
)
|
||||
.expect("logging init failed");
|
||||
});
|
||||
}
|
||||
|
||||
struct EnabledAzure {
|
||||
client: Arc<GenericRemoteStorage>,
|
||||
base_prefix: &'static str,
|
||||
}
|
||||
|
||||
impl EnabledAzure {
|
||||
async fn setup(max_keys_in_list_response: Option<i32>) -> Self {
|
||||
let client = create_azure_client(max_keys_in_list_response)
|
||||
.context("Azure client creation")
|
||||
.expect("Azure client creation failed");
|
||||
|
||||
EnabledAzure {
|
||||
client,
|
||||
base_prefix: BASE_PREFIX,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum MaybeEnabledAzure {
|
||||
Enabled(EnabledAzure),
|
||||
Disabled,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl AsyncTestContext for MaybeEnabledAzure {
|
||||
async fn setup() -> Self {
|
||||
ensure_logging_ready();
|
||||
|
||||
if env::var(ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME).is_err() {
|
||||
info!(
|
||||
"`{}` env variable is not set, skipping the test",
|
||||
ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME
|
||||
);
|
||||
return Self::Disabled;
|
||||
}
|
||||
|
||||
Self::Enabled(EnabledAzure::setup(None).await)
|
||||
}
|
||||
}
|
||||
|
||||
enum MaybeEnabledAzureWithTestBlobs {
|
||||
Enabled(AzureWithTestBlobs),
|
||||
Disabled,
|
||||
UploadsFailed(anyhow::Error, AzureWithTestBlobs),
|
||||
}
|
||||
|
||||
struct AzureWithTestBlobs {
|
||||
enabled: EnabledAzure,
|
||||
remote_prefixes: HashSet<RemotePath>,
|
||||
remote_blobs: HashSet<RemotePath>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl AsyncTestContext for MaybeEnabledAzureWithTestBlobs {
|
||||
async fn setup() -> Self {
|
||||
ensure_logging_ready();
|
||||
if env::var(ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME).is_err() {
|
||||
info!(
|
||||
"`{}` env variable is not set, skipping the test",
|
||||
ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME
|
||||
);
|
||||
return Self::Disabled;
|
||||
}
|
||||
|
||||
let max_keys_in_list_response = 10;
|
||||
let upload_tasks_count = 1 + (2 * usize::try_from(max_keys_in_list_response).unwrap());
|
||||
|
||||
let enabled = EnabledAzure::setup(Some(max_keys_in_list_response)).await;
|
||||
|
||||
match upload_azure_data(&enabled.client, enabled.base_prefix, upload_tasks_count).await {
|
||||
ControlFlow::Continue(uploads) => {
|
||||
info!("Remote objects created successfully");
|
||||
|
||||
Self::Enabled(AzureWithTestBlobs {
|
||||
enabled,
|
||||
remote_prefixes: uploads.prefixes,
|
||||
remote_blobs: uploads.blobs,
|
||||
})
|
||||
}
|
||||
ControlFlow::Break(uploads) => Self::UploadsFailed(
|
||||
anyhow::anyhow!("One or multiple blobs failed to upload to Azure"),
|
||||
AzureWithTestBlobs {
|
||||
enabled,
|
||||
remote_prefixes: uploads.prefixes,
|
||||
remote_blobs: uploads.blobs,
|
||||
},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
async fn teardown(self) {
|
||||
match self {
|
||||
Self::Disabled => {}
|
||||
Self::Enabled(ctx) | Self::UploadsFailed(_, ctx) => {
|
||||
cleanup(&ctx.enabled.client, ctx.remote_blobs).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: the setups for the list_prefixes test and the list_files test are very similar
|
||||
// However, they are not idential. The list_prefixes function is concerned with listing prefixes,
|
||||
// whereas the list_files function is concerned with listing files.
|
||||
// See `RemoteStorage::list_files` documentation for more details
|
||||
enum MaybeEnabledAzureWithSimpleTestBlobs {
|
||||
Enabled(AzureWithSimpleTestBlobs),
|
||||
Disabled,
|
||||
UploadsFailed(anyhow::Error, AzureWithSimpleTestBlobs),
|
||||
}
|
||||
struct AzureWithSimpleTestBlobs {
|
||||
enabled: EnabledAzure,
|
||||
remote_blobs: HashSet<RemotePath>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl AsyncTestContext for MaybeEnabledAzureWithSimpleTestBlobs {
|
||||
async fn setup() -> Self {
|
||||
ensure_logging_ready();
|
||||
if env::var(ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME).is_err() {
|
||||
info!(
|
||||
"`{}` env variable is not set, skipping the test",
|
||||
ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME
|
||||
);
|
||||
return Self::Disabled;
|
||||
}
|
||||
|
||||
let max_keys_in_list_response = 10;
|
||||
let upload_tasks_count = 1 + (2 * usize::try_from(max_keys_in_list_response).unwrap());
|
||||
|
||||
let enabled = EnabledAzure::setup(Some(max_keys_in_list_response)).await;
|
||||
|
||||
match upload_simple_azure_data(&enabled.client, upload_tasks_count).await {
|
||||
ControlFlow::Continue(uploads) => {
|
||||
info!("Remote objects created successfully");
|
||||
|
||||
Self::Enabled(AzureWithSimpleTestBlobs {
|
||||
enabled,
|
||||
remote_blobs: uploads,
|
||||
})
|
||||
}
|
||||
ControlFlow::Break(uploads) => Self::UploadsFailed(
|
||||
anyhow::anyhow!("One or multiple blobs failed to upload to Azure"),
|
||||
AzureWithSimpleTestBlobs {
|
||||
enabled,
|
||||
remote_blobs: uploads,
|
||||
},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
async fn teardown(self) {
|
||||
match self {
|
||||
Self::Disabled => {}
|
||||
Self::Enabled(ctx) | Self::UploadsFailed(_, ctx) => {
|
||||
cleanup(&ctx.enabled.client, ctx.remote_blobs).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_azure_client(
|
||||
max_keys_per_list_response: Option<i32>,
|
||||
) -> anyhow::Result<Arc<GenericRemoteStorage>> {
|
||||
use rand::Rng;
|
||||
|
||||
let remote_storage_azure_container = env::var("REMOTE_STORAGE_AZURE_CONTAINER").context(
|
||||
"`REMOTE_STORAGE_AZURE_CONTAINER` env var is not set, but real Azure tests are enabled",
|
||||
)?;
|
||||
let remote_storage_azure_region = env::var("REMOTE_STORAGE_AZURE_REGION").context(
|
||||
"`REMOTE_STORAGE_AZURE_REGION` env var is not set, but real Azure tests are enabled",
|
||||
)?;
|
||||
|
||||
// due to how time works, we've had test runners use the same nanos as bucket prefixes.
|
||||
// millis is just a debugging aid for easier finding the prefix later.
|
||||
let millis = std::time::SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.context("random Azure test prefix part calculation")?
|
||||
.as_millis();
|
||||
|
||||
// because nanos can be the same for two threads so can millis, add randomness
|
||||
let random = rand::thread_rng().gen::<u32>();
|
||||
|
||||
let remote_storage_config = RemoteStorageConfig {
|
||||
storage: RemoteStorageKind::AzureContainer(AzureConfig {
|
||||
container_name: remote_storage_azure_container,
|
||||
container_region: remote_storage_azure_region,
|
||||
prefix_in_container: Some(format!("test_{millis}_{random:08x}/")),
|
||||
concurrency_limit: NonZeroUsize::new(100).unwrap(),
|
||||
max_keys_per_list_response,
|
||||
}),
|
||||
};
|
||||
Ok(Arc::new(
|
||||
GenericRemoteStorage::from_config(&remote_storage_config).context("remote storage init")?,
|
||||
))
|
||||
}
|
||||
|
||||
struct Uploads {
|
||||
prefixes: HashSet<RemotePath>,
|
||||
blobs: HashSet<RemotePath>,
|
||||
}
|
||||
|
||||
async fn upload_azure_data(
|
||||
client: &Arc<GenericRemoteStorage>,
|
||||
base_prefix_str: &'static str,
|
||||
upload_tasks_count: usize,
|
||||
) -> ControlFlow<Uploads, Uploads> {
|
||||
info!("Creating {upload_tasks_count} Azure files");
|
||||
let mut upload_tasks = JoinSet::new();
|
||||
for i in 1..upload_tasks_count + 1 {
|
||||
let task_client = Arc::clone(client);
|
||||
upload_tasks.spawn(async move {
|
||||
let prefix = format!("{base_prefix_str}/sub_prefix_{i}/");
|
||||
let blob_prefix = RemotePath::new(Utf8Path::new(&prefix))
|
||||
.with_context(|| format!("{prefix:?} to RemotePath conversion"))?;
|
||||
let blob_path = blob_prefix.join(Utf8Path::new(&format!("blob_{i}")));
|
||||
debug!("Creating remote item {i} at path {blob_path:?}");
|
||||
|
||||
let data = format!("remote blob data {i}").into_bytes();
|
||||
let data_len = data.len();
|
||||
task_client
|
||||
.upload(std::io::Cursor::new(data), data_len, &blob_path, None)
|
||||
.await?;
|
||||
|
||||
Ok::<_, anyhow::Error>((blob_prefix, blob_path))
|
||||
});
|
||||
}
|
||||
|
||||
let mut upload_tasks_failed = false;
|
||||
let mut uploaded_prefixes = HashSet::with_capacity(upload_tasks_count);
|
||||
let mut uploaded_blobs = HashSet::with_capacity(upload_tasks_count);
|
||||
while let Some(task_run_result) = upload_tasks.join_next().await {
|
||||
match task_run_result
|
||||
.context("task join failed")
|
||||
.and_then(|task_result| task_result.context("upload task failed"))
|
||||
{
|
||||
Ok((upload_prefix, upload_path)) => {
|
||||
uploaded_prefixes.insert(upload_prefix);
|
||||
uploaded_blobs.insert(upload_path);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Upload task failed: {e:?}");
|
||||
upload_tasks_failed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let uploads = Uploads {
|
||||
prefixes: uploaded_prefixes,
|
||||
blobs: uploaded_blobs,
|
||||
};
|
||||
if upload_tasks_failed {
|
||||
ControlFlow::Break(uploads)
|
||||
} else {
|
||||
ControlFlow::Continue(uploads)
|
||||
}
|
||||
}
|
||||
|
||||
async fn cleanup(client: &Arc<GenericRemoteStorage>, objects_to_delete: HashSet<RemotePath>) {
|
||||
info!(
|
||||
"Removing {} objects from the remote storage during cleanup",
|
||||
objects_to_delete.len()
|
||||
);
|
||||
let mut delete_tasks = JoinSet::new();
|
||||
for object_to_delete in objects_to_delete {
|
||||
let task_client = Arc::clone(client);
|
||||
delete_tasks.spawn(async move {
|
||||
debug!("Deleting remote item at path {object_to_delete:?}");
|
||||
task_client
|
||||
.delete(&object_to_delete)
|
||||
.await
|
||||
.with_context(|| format!("{object_to_delete:?} removal"))
|
||||
});
|
||||
}
|
||||
|
||||
while let Some(task_run_result) = delete_tasks.join_next().await {
|
||||
match task_run_result {
|
||||
Ok(task_result) => match task_result {
|
||||
Ok(()) => {}
|
||||
Err(e) => error!("Delete task failed: {e:?}"),
|
||||
},
|
||||
Err(join_err) => error!("Delete task did not finish correctly: {join_err}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Uploads files `folder{j}/blob{i}.txt`. See test description for more details.
|
||||
async fn upload_simple_azure_data(
|
||||
client: &Arc<GenericRemoteStorage>,
|
||||
upload_tasks_count: usize,
|
||||
) -> ControlFlow<HashSet<RemotePath>, HashSet<RemotePath>> {
|
||||
info!("Creating {upload_tasks_count} Azure files");
|
||||
let mut upload_tasks = JoinSet::new();
|
||||
for i in 1..upload_tasks_count + 1 {
|
||||
let task_client = Arc::clone(client);
|
||||
upload_tasks.spawn(async move {
|
||||
let blob_path = PathBuf::from(format!("folder{}/blob_{}.txt", i / 7, i));
|
||||
let blob_path = RemotePath::new(
|
||||
Utf8Path::from_path(blob_path.as_path()).expect("must be valid blob path"),
|
||||
)
|
||||
.with_context(|| format!("{blob_path:?} to RemotePath conversion"))?;
|
||||
debug!("Creating remote item {i} at path {blob_path:?}");
|
||||
|
||||
let data = format!("remote blob data {i}").into_bytes();
|
||||
let data_len = data.len();
|
||||
task_client
|
||||
.upload(std::io::Cursor::new(data), data_len, &blob_path, None)
|
||||
.await?;
|
||||
|
||||
Ok::<_, anyhow::Error>(blob_path)
|
||||
});
|
||||
}
|
||||
|
||||
let mut upload_tasks_failed = false;
|
||||
let mut uploaded_blobs = HashSet::with_capacity(upload_tasks_count);
|
||||
while let Some(task_run_result) = upload_tasks.join_next().await {
|
||||
match task_run_result
|
||||
.context("task join failed")
|
||||
.and_then(|task_result| task_result.context("upload task failed"))
|
||||
{
|
||||
Ok(upload_path) => {
|
||||
uploaded_blobs.insert(upload_path);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Upload task failed: {e:?}");
|
||||
upload_tasks_failed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if upload_tasks_failed {
|
||||
ControlFlow::Break(uploaded_blobs)
|
||||
} else {
|
||||
ControlFlow::Continue(uploaded_blobs)
|
||||
}
|
||||
}
|
||||
@@ -1,13 +1,12 @@
|
||||
use std::collections::HashSet;
|
||||
use std::env;
|
||||
use std::num::NonZeroUsize;
|
||||
use std::num::{NonZeroU32, NonZeroUsize};
|
||||
use std::ops::ControlFlow;
|
||||
use std::path::PathBuf;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::time::UNIX_EPOCH;
|
||||
|
||||
use anyhow::Context;
|
||||
use camino::Utf8Path;
|
||||
use once_cell::sync::OnceCell;
|
||||
use remote_storage::{
|
||||
GenericRemoteStorage, RemotePath, RemoteStorageConfig, RemoteStorageKind, S3Config,
|
||||
@@ -56,7 +55,7 @@ async fn s3_pagination_should_work(ctx: &mut MaybeEnabledS3WithTestBlobs) -> any
|
||||
let test_client = Arc::clone(&ctx.enabled.client);
|
||||
let expected_remote_prefixes = ctx.remote_prefixes.clone();
|
||||
|
||||
let base_prefix = RemotePath::new(Utf8Path::new(ctx.enabled.base_prefix))
|
||||
let base_prefix = RemotePath::new(Path::new(ctx.enabled.base_prefix))
|
||||
.context("common_prefix construction")?;
|
||||
let root_remote_prefixes = test_client
|
||||
.list_prefixes(None)
|
||||
@@ -109,7 +108,7 @@ async fn s3_list_files_works(ctx: &mut MaybeEnabledS3WithSimpleTestBlobs) -> any
|
||||
};
|
||||
let test_client = Arc::clone(&ctx.enabled.client);
|
||||
let base_prefix =
|
||||
RemotePath::new(Utf8Path::new("folder1")).context("common_prefix construction")?;
|
||||
RemotePath::new(Path::new("folder1")).context("common_prefix construction")?;
|
||||
let root_files = test_client
|
||||
.list_files(None)
|
||||
.await
|
||||
@@ -130,9 +129,9 @@ async fn s3_list_files_works(ctx: &mut MaybeEnabledS3WithSimpleTestBlobs) -> any
|
||||
let trim_remote_blobs: HashSet<_> = ctx
|
||||
.remote_blobs
|
||||
.iter()
|
||||
.map(|x| x.get_path())
|
||||
.map(|x| x.get_path().to_str().expect("must be valid name"))
|
||||
.filter(|x| x.starts_with("folder1"))
|
||||
.map(|x| RemotePath::new(x).expect("must be valid path"))
|
||||
.map(|x| RemotePath::new(Path::new(x)).expect("must be valid name"))
|
||||
.collect();
|
||||
assert_eq!(
|
||||
nested_remote_files, trim_remote_blobs,
|
||||
@@ -149,9 +148,10 @@ async fn s3_delete_non_exising_works(ctx: &mut MaybeEnabledS3) -> anyhow::Result
|
||||
MaybeEnabledS3::Disabled => return Ok(()),
|
||||
};
|
||||
|
||||
let path = RemotePath::new(Utf8Path::new(
|
||||
format!("{}/for_sure_there_is_nothing_there_really", ctx.base_prefix).as_str(),
|
||||
))
|
||||
let path = RemotePath::new(&PathBuf::from(format!(
|
||||
"{}/for_sure_there_is_nothing_there_really",
|
||||
ctx.base_prefix,
|
||||
)))
|
||||
.with_context(|| "RemotePath conversion")?;
|
||||
|
||||
ctx.client.delete(&path).await.expect("should succeed");
|
||||
@@ -167,13 +167,13 @@ async fn s3_delete_objects_works(ctx: &mut MaybeEnabledS3) -> anyhow::Result<()>
|
||||
MaybeEnabledS3::Disabled => return Ok(()),
|
||||
};
|
||||
|
||||
let path1 = RemotePath::new(Utf8Path::new(format!("{}/path1", ctx.base_prefix).as_str()))
|
||||
let path1 = RemotePath::new(&PathBuf::from(format!("{}/path1", ctx.base_prefix,)))
|
||||
.with_context(|| "RemotePath conversion")?;
|
||||
|
||||
let path2 = RemotePath::new(Utf8Path::new(format!("{}/path2", ctx.base_prefix).as_str()))
|
||||
let path2 = RemotePath::new(&PathBuf::from(format!("{}/path2", ctx.base_prefix,)))
|
||||
.with_context(|| "RemotePath conversion")?;
|
||||
|
||||
let path3 = RemotePath::new(Utf8Path::new(format!("{}/path3", ctx.base_prefix).as_str()))
|
||||
let path3 = RemotePath::new(&PathBuf::from(format!("{}/path3", ctx.base_prefix,)))
|
||||
.with_context(|| "RemotePath conversion")?;
|
||||
|
||||
let data1 = "remote blob data1".as_bytes();
|
||||
@@ -396,6 +396,8 @@ fn create_s3_client(
|
||||
let random = rand::thread_rng().gen::<u32>();
|
||||
|
||||
let remote_storage_config = RemoteStorageConfig {
|
||||
max_concurrent_syncs: NonZeroUsize::new(100).unwrap(),
|
||||
max_sync_errors: NonZeroU32::new(5).unwrap(),
|
||||
storage: RemoteStorageKind::AwsS3(S3Config {
|
||||
bucket_name: remote_storage_s3_bucket,
|
||||
bucket_region: remote_storage_s3_region,
|
||||
@@ -425,10 +427,10 @@ async fn upload_s3_data(
|
||||
for i in 1..upload_tasks_count + 1 {
|
||||
let task_client = Arc::clone(client);
|
||||
upload_tasks.spawn(async move {
|
||||
let prefix = format!("{base_prefix_str}/sub_prefix_{i}/");
|
||||
let blob_prefix = RemotePath::new(Utf8Path::new(&prefix))
|
||||
let prefix = PathBuf::from(format!("{base_prefix_str}/sub_prefix_{i}/"));
|
||||
let blob_prefix = RemotePath::new(&prefix)
|
||||
.with_context(|| format!("{prefix:?} to RemotePath conversion"))?;
|
||||
let blob_path = blob_prefix.join(Utf8Path::new(&format!("blob_{i}")));
|
||||
let blob_path = blob_prefix.join(Path::new(&format!("blob_{i}")));
|
||||
debug!("Creating remote item {i} at path {blob_path:?}");
|
||||
|
||||
let data = format!("remote blob data {i}").into_bytes();
|
||||
@@ -510,10 +512,8 @@ async fn upload_simple_s3_data(
|
||||
let task_client = Arc::clone(client);
|
||||
upload_tasks.spawn(async move {
|
||||
let blob_path = PathBuf::from(format!("folder{}/blob_{}.txt", i / 7, i));
|
||||
let blob_path = RemotePath::new(
|
||||
Utf8Path::from_path(blob_path.as_path()).expect("must be valid blob path"),
|
||||
)
|
||||
.with_context(|| format!("{blob_path:?} to RemotePath conversion"))?;
|
||||
let blob_path = RemotePath::new(&blob_path)
|
||||
.with_context(|| format!("{blob_path:?} to RemotePath conversion"))?;
|
||||
debug!("Creating remote item {i} at path {blob_path:?}");
|
||||
|
||||
let data = format!("remote blob data {i}").into_bytes();
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
use const_format::formatcp;
|
||||
|
||||
/// Public API types
|
||||
|
||||
@@ -1,18 +1,23 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
|
||||
use utils::{
|
||||
id::{NodeId, TenantId, TimelineId},
|
||||
lsn::Lsn,
|
||||
};
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct TimelineCreateRequest {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub tenant_id: TenantId,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub timeline_id: TimelineId,
|
||||
pub peer_ids: Option<Vec<NodeId>>,
|
||||
pub pg_version: u32,
|
||||
pub system_id: Option<u64>,
|
||||
pub wal_seg_size: Option<u32>,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub commit_lsn: Lsn,
|
||||
// If not passed, it is assigned to the beginning of commit_lsn segment.
|
||||
pub local_start_lsn: Option<Lsn>,
|
||||
@@ -23,6 +28,7 @@ fn lsn_invalid() -> Lsn {
|
||||
}
|
||||
|
||||
/// Data about safekeeper's timeline, mirrors broker.proto.
|
||||
#[serde_as]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct SkTimelineInfo {
|
||||
/// Term.
|
||||
@@ -30,19 +36,25 @@ pub struct SkTimelineInfo {
|
||||
/// Term of the last entry.
|
||||
pub last_log_term: Option<u64>,
|
||||
/// LSN of the last record.
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
#[serde(default = "lsn_invalid")]
|
||||
pub flush_lsn: Lsn,
|
||||
/// Up to which LSN safekeeper regards its WAL as committed.
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
#[serde(default = "lsn_invalid")]
|
||||
pub commit_lsn: Lsn,
|
||||
/// LSN up to which safekeeper has backed WAL.
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
#[serde(default = "lsn_invalid")]
|
||||
pub backup_lsn: Lsn,
|
||||
/// LSN of last checkpoint uploaded by pageserver.
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
#[serde(default = "lsn_invalid")]
|
||||
pub remote_consistent_lsn: Lsn,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
#[serde(default = "lsn_invalid")]
|
||||
pub peer_horizon_lsn: Lsn,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
#[serde(default = "lsn_invalid")]
|
||||
pub local_start_lsn: Lsn,
|
||||
/// A connection string to use for WAL receiving.
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
//! Synthetic size calculation
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
|
||||
mod calculation;
|
||||
pub mod svg;
|
||||
|
||||
@@ -32,8 +32,6 @@
|
||||
//! .init();
|
||||
//! }
|
||||
//! ```
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
|
||||
use opentelemetry::sdk::Resource;
|
||||
use opentelemetry::KeyValue;
|
||||
|
||||
@@ -5,13 +5,11 @@ edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
arc-swap.workspace = true
|
||||
sentry.workspace = true
|
||||
async-trait.workspace = true
|
||||
anyhow.workspace = true
|
||||
bincode.workspace = true
|
||||
bytes.workspace = true
|
||||
camino.workspace = true
|
||||
chrono.workspace = true
|
||||
heapless.workspace = true
|
||||
hex = { workspace = true, features = ["serde"] }
|
||||
@@ -55,8 +53,7 @@ byteorder.workspace = true
|
||||
bytes.workspace = true
|
||||
criterion.workspace = true
|
||||
hex-literal.workspace = true
|
||||
camino-tempfile.workspace = true
|
||||
serde_assert.workspace = true
|
||||
tempfile.workspace = true
|
||||
|
||||
[[bench]]
|
||||
name = "benchmarks"
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
// For details about authentication see docs/authentication.md
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use serde;
|
||||
use std::{borrow::Cow, fmt::Display, fs, sync::Arc};
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
|
||||
use anyhow::Result;
|
||||
use camino::Utf8Path;
|
||||
use jsonwebtoken::{
|
||||
decode, encode, Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
|
||||
use crate::{http::error::ApiError, id::TenantId};
|
||||
use crate::id::TenantId;
|
||||
|
||||
/// Algorithm to use. We require EdDSA.
|
||||
const STORAGE_TOKEN_ALGORITHM: Algorithm = Algorithm::EdDSA;
|
||||
@@ -32,9 +32,11 @@ pub enum Scope {
|
||||
}
|
||||
|
||||
/// JWT payload. See docs/authentication.md for the format
|
||||
#[serde_as]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
|
||||
pub struct Claims {
|
||||
#[serde(default)]
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
pub tenant_id: Option<TenantId>,
|
||||
pub scope: Scope,
|
||||
}
|
||||
@@ -45,106 +47,31 @@ impl Claims {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SwappableJwtAuth(ArcSwap<JwtAuth>);
|
||||
|
||||
impl SwappableJwtAuth {
|
||||
pub fn new(jwt_auth: JwtAuth) -> Self {
|
||||
SwappableJwtAuth(ArcSwap::new(Arc::new(jwt_auth)))
|
||||
}
|
||||
pub fn swap(&self, jwt_auth: JwtAuth) {
|
||||
self.0.swap(Arc::new(jwt_auth));
|
||||
}
|
||||
pub fn decode(&self, token: &str) -> std::result::Result<TokenData<Claims>, AuthError> {
|
||||
self.0.load().decode(token)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for SwappableJwtAuth {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Swappable({:?})", self.0.load())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
|
||||
pub struct AuthError(pub Cow<'static, str>);
|
||||
|
||||
impl Display for AuthError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<AuthError> for ApiError {
|
||||
fn from(_value: AuthError) -> Self {
|
||||
// Don't pass on the value of the AuthError as a precautionary measure.
|
||||
// Being intentionally vague in public error communication hurts debugability
|
||||
// but it is more secure.
|
||||
ApiError::Forbidden("JWT authentication error".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct JwtAuth {
|
||||
decoding_keys: Vec<DecodingKey>,
|
||||
decoding_key: DecodingKey,
|
||||
validation: Validation,
|
||||
}
|
||||
|
||||
impl JwtAuth {
|
||||
pub fn new(decoding_keys: Vec<DecodingKey>) -> Self {
|
||||
pub fn new(decoding_key: DecodingKey) -> Self {
|
||||
let mut validation = Validation::default();
|
||||
validation.algorithms = vec![STORAGE_TOKEN_ALGORITHM];
|
||||
// The default 'required_spec_claims' is 'exp'. But we don't want to require
|
||||
// expiration.
|
||||
validation.required_spec_claims = [].into();
|
||||
Self {
|
||||
decoding_keys,
|
||||
decoding_key,
|
||||
validation,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_key_path(key_path: &Utf8Path) -> Result<Self> {
|
||||
let metadata = key_path.metadata()?;
|
||||
let decoding_keys = if metadata.is_dir() {
|
||||
let mut keys = Vec::new();
|
||||
for entry in fs::read_dir(key_path)? {
|
||||
let path = entry?.path();
|
||||
if !path.is_file() {
|
||||
// Ignore directories (don't recurse)
|
||||
continue;
|
||||
}
|
||||
let public_key = fs::read(path)?;
|
||||
keys.push(DecodingKey::from_ed_pem(&public_key)?);
|
||||
}
|
||||
keys
|
||||
} else if metadata.is_file() {
|
||||
let public_key = fs::read(key_path)?;
|
||||
vec![DecodingKey::from_ed_pem(&public_key)?]
|
||||
} else {
|
||||
anyhow::bail!("path is neither a directory or a file")
|
||||
};
|
||||
if decoding_keys.is_empty() {
|
||||
anyhow::bail!("Configured for JWT auth with zero decoding keys. All JWT gated requests would be rejected.");
|
||||
}
|
||||
Ok(Self::new(decoding_keys))
|
||||
pub fn from_key_path(key_path: &Path) -> Result<Self> {
|
||||
let public_key = fs::read(key_path)?;
|
||||
Ok(Self::new(DecodingKey::from_ed_pem(&public_key)?))
|
||||
}
|
||||
|
||||
/// Attempt to decode the token with the internal decoding keys.
|
||||
///
|
||||
/// The function tries the stored decoding keys in succession,
|
||||
/// and returns the first yielding a successful result.
|
||||
/// If there is no working decoding key, it returns the last error.
|
||||
pub fn decode(&self, token: &str) -> std::result::Result<TokenData<Claims>, AuthError> {
|
||||
let mut res = None;
|
||||
for decoding_key in &self.decoding_keys {
|
||||
res = Some(decode(token, decoding_key, &self.validation));
|
||||
if let Some(Ok(res)) = res {
|
||||
return Ok(res);
|
||||
}
|
||||
}
|
||||
if let Some(res) = res {
|
||||
res.map_err(|e| AuthError(Cow::Owned(e.to_string())))
|
||||
} else {
|
||||
Err(AuthError(Cow::Borrowed("no JWT decoding keys configured")))
|
||||
}
|
||||
pub fn decode(&self, token: &str) -> Result<TokenData<Claims>> {
|
||||
Ok(decode(token, &self.decoding_key, &self.validation)?)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -184,9 +111,9 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
|
||||
"#;
|
||||
|
||||
#[test]
|
||||
fn test_decode() {
|
||||
fn test_decode() -> Result<(), anyhow::Error> {
|
||||
let expected_claims = Claims {
|
||||
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
|
||||
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081")?),
|
||||
scope: Scope::Tenant,
|
||||
};
|
||||
|
||||
@@ -205,24 +132,28 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
|
||||
let encoded_eddsa = "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJzY29wZSI6InRlbmFudCIsInRlbmFudF9pZCI6IjNkMWY3NTk1YjQ2ODIzMDMwNGUwYjczY2VjYmNiMDgxIiwiaXNzIjoibmVvbi5jb250cm9scGxhbmUiLCJleHAiOjE3MDkyMDA4NzksImlhdCI6MTY3ODQ0MjQ3OX0.U3eA8j-uU-JnhzeO3EDHRuXLwkAUFCPxtGHEgw6p7Ccc3YRbFs2tmCdbD9PZEXP-XsxSeBQi1FY0YPcT3NXADw";
|
||||
|
||||
// Check it can be validated with the public key
|
||||
let auth = JwtAuth::new(vec![DecodingKey::from_ed_pem(TEST_PUB_KEY_ED25519).unwrap()]);
|
||||
let claims_from_token = auth.decode(encoded_eddsa).unwrap().claims;
|
||||
let auth = JwtAuth::new(DecodingKey::from_ed_pem(TEST_PUB_KEY_ED25519)?);
|
||||
let claims_from_token = auth.decode(encoded_eddsa)?.claims;
|
||||
assert_eq!(claims_from_token, expected_claims);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode() {
|
||||
fn test_encode() -> Result<(), anyhow::Error> {
|
||||
let claims = Claims {
|
||||
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
|
||||
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081")?),
|
||||
scope: Scope::Tenant,
|
||||
};
|
||||
|
||||
let encoded = encode_from_key_file(&claims, TEST_PRIV_KEY_ED25519).unwrap();
|
||||
let encoded = encode_from_key_file(&claims, TEST_PRIV_KEY_ED25519)?;
|
||||
|
||||
// decode it back
|
||||
let auth = JwtAuth::new(vec![DecodingKey::from_ed_pem(TEST_PUB_KEY_ED25519).unwrap()]);
|
||||
let decoded = auth.decode(&encoded).unwrap();
|
||||
let auth = JwtAuth::new(DecodingKey::from_ed_pem(TEST_PUB_KEY_ED25519)?);
|
||||
let decoded = auth.decode(&encoded)?;
|
||||
|
||||
assert_eq!(decoded.claims, claims);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
ffi::OsStr,
|
||||
fs::{self, File},
|
||||
io,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use camino::{Utf8Path, Utf8PathBuf};
|
||||
|
||||
/// Similar to [`std::fs::create_dir`], except we fsync the
|
||||
/// created directory and its parent.
|
||||
pub fn create_dir(path: impl AsRef<Utf8Path>) -> io::Result<()> {
|
||||
pub fn create_dir(path: impl AsRef<Path>) -> io::Result<()> {
|
||||
let path = path.as_ref();
|
||||
|
||||
fs::create_dir(path)?;
|
||||
@@ -18,7 +18,7 @@ pub fn create_dir(path: impl AsRef<Utf8Path>) -> io::Result<()> {
|
||||
|
||||
/// Similar to [`std::fs::create_dir_all`], except we fsync all
|
||||
/// newly created directories and the pre-existing parent.
|
||||
pub fn create_dir_all(path: impl AsRef<Utf8Path>) -> io::Result<()> {
|
||||
pub fn create_dir_all(path: impl AsRef<Path>) -> io::Result<()> {
|
||||
let mut path = path.as_ref();
|
||||
|
||||
let mut dirs_to_create = Vec::new();
|
||||
@@ -30,7 +30,7 @@ pub fn create_dir_all(path: impl AsRef<Utf8Path>) -> io::Result<()> {
|
||||
Ok(_) => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::AlreadyExists,
|
||||
format!("non-directory found in path: {path}"),
|
||||
format!("non-directory found in path: {}", path.display()),
|
||||
));
|
||||
}
|
||||
Err(ref e) if e.kind() == io::ErrorKind::NotFound => {}
|
||||
@@ -44,7 +44,7 @@ pub fn create_dir_all(path: impl AsRef<Utf8Path>) -> io::Result<()> {
|
||||
None => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
format!("can't find parent of path '{path}'"),
|
||||
format!("can't find parent of path '{}'", path.display()).as_str(),
|
||||
));
|
||||
}
|
||||
}
|
||||
@@ -70,18 +70,21 @@ pub fn create_dir_all(path: impl AsRef<Utf8Path>) -> io::Result<()> {
|
||||
|
||||
/// Adds a suffix to the file(directory) name, either appending the suffix to the end of its extension,
|
||||
/// or if there's no extension, creates one and puts a suffix there.
|
||||
pub fn path_with_suffix_extension(
|
||||
original_path: impl AsRef<Utf8Path>,
|
||||
suffix: &str,
|
||||
) -> Utf8PathBuf {
|
||||
let new_extension = match original_path.as_ref().extension() {
|
||||
pub fn path_with_suffix_extension(original_path: impl AsRef<Path>, suffix: &str) -> PathBuf {
|
||||
let new_extension = match original_path
|
||||
.as_ref()
|
||||
.extension()
|
||||
.map(OsStr::to_string_lossy)
|
||||
{
|
||||
Some(extension) => Cow::Owned(format!("{extension}.{suffix}")),
|
||||
None => Cow::Borrowed(suffix),
|
||||
};
|
||||
original_path.as_ref().with_extension(new_extension)
|
||||
original_path
|
||||
.as_ref()
|
||||
.with_extension(new_extension.as_ref())
|
||||
}
|
||||
|
||||
pub fn fsync_file_and_parent(file_path: &Utf8Path) -> io::Result<()> {
|
||||
pub fn fsync_file_and_parent(file_path: &Path) -> io::Result<()> {
|
||||
let parent = file_path.parent().ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::Other,
|
||||
@@ -94,7 +97,7 @@ pub fn fsync_file_and_parent(file_path: &Utf8Path) -> io::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn fsync(path: &Utf8Path) -> io::Result<()> {
|
||||
pub fn fsync(path: &Path) -> io::Result<()> {
|
||||
File::open(path)
|
||||
.map_err(|e| io::Error::new(e.kind(), format!("Failed to open the file {path:?}: {e}")))
|
||||
.and_then(|file| {
|
||||
@@ -108,18 +111,19 @@ pub fn fsync(path: &Utf8Path) -> io::Result<()> {
|
||||
.map_err(|e| io::Error::new(e.kind(), format!("Failed to fsync file {path:?}: {e}")))
|
||||
}
|
||||
|
||||
pub async fn fsync_async(path: impl AsRef<Utf8Path>) -> Result<(), std::io::Error> {
|
||||
tokio::fs::File::open(path.as_ref()).await?.sync_all().await
|
||||
pub async fn fsync_async(path: impl AsRef<std::path::Path>) -> Result<(), std::io::Error> {
|
||||
tokio::fs::File::open(path).await?.sync_all().await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use tempfile::tempdir;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_create_dir_fsyncd() {
|
||||
let dir = camino_tempfile::tempdir().unwrap();
|
||||
let dir = tempdir().unwrap();
|
||||
|
||||
let existing_dir_path = dir.path();
|
||||
let err = create_dir(existing_dir_path).unwrap_err();
|
||||
@@ -135,7 +139,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_create_dir_all_fsyncd() {
|
||||
let dir = camino_tempfile::tempdir().unwrap();
|
||||
let dir = tempdir().unwrap();
|
||||
|
||||
let existing_dir_path = dir.path();
|
||||
create_dir_all(existing_dir_path).unwrap();
|
||||
@@ -162,29 +166,29 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_path_with_suffix_extension() {
|
||||
let p = Utf8PathBuf::from("/foo/bar");
|
||||
let p = PathBuf::from("/foo/bar");
|
||||
assert_eq!(
|
||||
&path_with_suffix_extension(p, "temp").to_string(),
|
||||
&path_with_suffix_extension(p, "temp").to_string_lossy(),
|
||||
"/foo/bar.temp"
|
||||
);
|
||||
let p = Utf8PathBuf::from("/foo/bar");
|
||||
let p = PathBuf::from("/foo/bar");
|
||||
assert_eq!(
|
||||
&path_with_suffix_extension(p, "temp.temp").to_string(),
|
||||
&path_with_suffix_extension(p, "temp.temp").to_string_lossy(),
|
||||
"/foo/bar.temp.temp"
|
||||
);
|
||||
let p = Utf8PathBuf::from("/foo/bar.baz");
|
||||
let p = PathBuf::from("/foo/bar.baz");
|
||||
assert_eq!(
|
||||
&path_with_suffix_extension(p, "temp.temp").to_string(),
|
||||
&path_with_suffix_extension(p, "temp.temp").to_string_lossy(),
|
||||
"/foo/bar.baz.temp.temp"
|
||||
);
|
||||
let p = Utf8PathBuf::from("/foo/bar.baz");
|
||||
let p = PathBuf::from("/foo/bar.baz");
|
||||
assert_eq!(
|
||||
&path_with_suffix_extension(p, ".temp").to_string(),
|
||||
&path_with_suffix_extension(p, ".temp").to_string_lossy(),
|
||||
"/foo/bar.baz..temp"
|
||||
);
|
||||
let p = Utf8PathBuf::from("/foo/bar/dir/");
|
||||
let p = PathBuf::from("/foo/bar/dir/");
|
||||
assert_eq!(
|
||||
&path_with_suffix_extension(p, ".temp").to_string(),
|
||||
&path_with_suffix_extension(p, ".temp").to_string_lossy(),
|
||||
"/foo/bar/dir..temp"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -55,6 +55,8 @@ where
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::path::PathBuf;
|
||||
|
||||
use crate::fs_ext::{is_directory_empty, list_dir};
|
||||
|
||||
use super::ignore_absent_files;
|
||||
@@ -63,7 +65,7 @@ mod test {
|
||||
fn is_empty_dir() {
|
||||
use super::PathExt;
|
||||
|
||||
let dir = camino_tempfile::tempdir().unwrap();
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let dir_path = dir.path();
|
||||
|
||||
// test positive case
|
||||
@@ -73,7 +75,7 @@ mod test {
|
||||
);
|
||||
|
||||
// invoke on a file to ensure it returns an error
|
||||
let file_path = dir_path.join("testfile");
|
||||
let file_path: PathBuf = dir_path.join("testfile");
|
||||
let f = std::fs::File::create(&file_path).unwrap();
|
||||
drop(f);
|
||||
assert!(file_path.is_empty_dir().is_err());
|
||||
@@ -85,7 +87,7 @@ mod test {
|
||||
|
||||
#[tokio::test]
|
||||
async fn is_empty_dir_async() {
|
||||
let dir = camino_tempfile::tempdir().unwrap();
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let dir_path = dir.path();
|
||||
|
||||
// test positive case
|
||||
@@ -95,7 +97,7 @@ mod test {
|
||||
);
|
||||
|
||||
// invoke on a file to ensure it returns an error
|
||||
let file_path = dir_path.join("testfile");
|
||||
let file_path: PathBuf = dir_path.join("testfile");
|
||||
let f = std::fs::File::create(&file_path).unwrap();
|
||||
drop(f);
|
||||
assert!(is_directory_empty(&file_path).await.is_err());
|
||||
@@ -107,9 +109,10 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn ignore_absent_files_works() {
|
||||
let dir = camino_tempfile::tempdir().unwrap();
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let dir_path = dir.path();
|
||||
|
||||
let file_path = dir.path().join("testfile");
|
||||
let file_path: PathBuf = dir_path.join("testfile");
|
||||
|
||||
ignore_absent_files(|| std::fs::remove_file(&file_path)).expect("should execute normally");
|
||||
|
||||
@@ -123,17 +126,17 @@ mod test {
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_dir_works() {
|
||||
let dir = camino_tempfile::tempdir().unwrap();
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let dir_path = dir.path();
|
||||
|
||||
assert!(list_dir(dir_path).await.unwrap().is_empty());
|
||||
|
||||
let file_path = dir_path.join("testfile");
|
||||
let file_path: PathBuf = dir_path.join("testfile");
|
||||
let _ = std::fs::File::create(&file_path).unwrap();
|
||||
|
||||
assert_eq!(&list_dir(dir_path).await.unwrap(), &["testfile"]);
|
||||
|
||||
let another_dir_path = dir_path.join("testdir");
|
||||
let another_dir_path: PathBuf = dir_path.join("testdir");
|
||||
std::fs::create_dir(another_dir_path).unwrap();
|
||||
|
||||
let expected = &["testdir", "testfile"];
|
||||
|
||||
@@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
|
||||
///
|
||||
/// See docs/rfcs/025-generation-numbers.md for detail on how generation
|
||||
/// numbers are used.
|
||||
#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
|
||||
#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord)]
|
||||
pub enum Generation {
|
||||
// Generations with this magic value will not add a suffix to S3 keys, and will not
|
||||
// be included in persisted index_part.json. This value is only to be used
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
/// Useful type for asserting that expected bytes match reporting the bytes more readable
|
||||
/// array-syntax compatible hex bytes.
|
||||
///
|
||||
/// # Usage
|
||||
///
|
||||
/// ```
|
||||
/// use utils::Hex;
|
||||
///
|
||||
/// let actual = serialize_something();
|
||||
/// let expected = [0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64];
|
||||
///
|
||||
/// // the type implements PartialEq and on mismatch, both sides are printed in 16 wide multiline
|
||||
/// // output suffixed with an array style length for easier comparisons.
|
||||
/// assert_eq!(Hex(&actual), Hex(&expected));
|
||||
///
|
||||
/// // with `let expected = [0x68];` the error would had been:
|
||||
/// // assertion `left == right` failed
|
||||
/// // left: [0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64; 11]
|
||||
/// // right: [0x68; 1]
|
||||
/// # fn serialize_something() -> Vec<u8> { "hello world".as_bytes().to_vec() }
|
||||
/// ```
|
||||
#[derive(PartialEq)]
|
||||
pub struct Hex<'a>(pub &'a [u8]);
|
||||
|
||||
impl std::fmt::Debug for Hex<'_> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "[")?;
|
||||
for (i, c) in self.0.chunks(16).enumerate() {
|
||||
if i > 0 && !c.is_empty() {
|
||||
writeln!(f, ", ")?;
|
||||
}
|
||||
for (j, b) in c.iter().enumerate() {
|
||||
if j > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "0x{b:02x}")?;
|
||||
}
|
||||
}
|
||||
write!(f, "; {}]", self.0.len())
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::auth::{AuthError, Claims, SwappableJwtAuth};
|
||||
use crate::auth::{Claims, JwtAuth};
|
||||
use crate::http::error::{api_error_handler, route_error_handler, ApiError};
|
||||
use anyhow::Context;
|
||||
use hyper::header::{HeaderName, AUTHORIZATION};
|
||||
@@ -14,11 +14,6 @@ use tracing::{self, debug, info, info_span, warn, Instrument};
|
||||
use std::future::Future;
|
||||
use std::str::FromStr;
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use std::io::Write as _;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
static SERVE_METRICS_COUNT: Lazy<IntCounter> = Lazy::new(|| {
|
||||
register_int_counter!(
|
||||
"libmetrics_metric_handler_requests_total",
|
||||
@@ -151,89 +146,94 @@ impl Drop for RequestCancelled {
|
||||
}
|
||||
}
|
||||
|
||||
/// An [`std::io::Write`] implementation on top of a channel sending [`bytes::Bytes`] chunks.
|
||||
pub struct ChannelWriter {
|
||||
buffer: BytesMut,
|
||||
pub tx: mpsc::Sender<std::io::Result<Bytes>>,
|
||||
written: usize,
|
||||
}
|
||||
|
||||
impl ChannelWriter {
|
||||
pub fn new(buf_len: usize, tx: mpsc::Sender<std::io::Result<Bytes>>) -> Self {
|
||||
assert_ne!(buf_len, 0);
|
||||
ChannelWriter {
|
||||
// split about half off the buffer from the start, because we flush depending on
|
||||
// capacity. first flush will come sooner than without this, but now resizes will
|
||||
// have better chance of picking up the "other" half. not guaranteed of course.
|
||||
buffer: BytesMut::with_capacity(buf_len).split_off(buf_len / 2),
|
||||
tx,
|
||||
written: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn flush0(&mut self) -> std::io::Result<usize> {
|
||||
let n = self.buffer.len();
|
||||
if n == 0 {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
tracing::trace!(n, "flushing");
|
||||
let ready = self.buffer.split().freeze();
|
||||
|
||||
// not ideal to call from blocking code to block_on, but we are sure that this
|
||||
// operation does not spawn_blocking other tasks
|
||||
let res: Result<(), ()> = tokio::runtime::Handle::current().block_on(async {
|
||||
self.tx.send(Ok(ready)).await.map_err(|_| ())?;
|
||||
|
||||
// throttle sending to allow reuse of our buffer in `write`.
|
||||
self.tx.reserve().await.map_err(|_| ())?;
|
||||
|
||||
// now the response task has picked up the buffer and hopefully started
|
||||
// sending it to the client.
|
||||
Ok(())
|
||||
});
|
||||
if res.is_err() {
|
||||
return Err(std::io::ErrorKind::BrokenPipe.into());
|
||||
}
|
||||
self.written += n;
|
||||
Ok(n)
|
||||
}
|
||||
|
||||
pub fn flushed_bytes(&self) -> usize {
|
||||
self.written
|
||||
}
|
||||
}
|
||||
|
||||
impl std::io::Write for ChannelWriter {
|
||||
fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
|
||||
let remaining = self.buffer.capacity() - self.buffer.len();
|
||||
|
||||
let out_of_space = remaining < buf.len();
|
||||
|
||||
let original_len = buf.len();
|
||||
|
||||
if out_of_space {
|
||||
let can_still_fit = buf.len() - remaining;
|
||||
self.buffer.extend_from_slice(&buf[..can_still_fit]);
|
||||
buf = &buf[can_still_fit..];
|
||||
self.flush0()?;
|
||||
}
|
||||
|
||||
// assume that this will often under normal operation just move the pointer back to the
|
||||
// beginning of allocation, because previous split off parts are already sent and
|
||||
// dropped.
|
||||
self.buffer.extend_from_slice(buf);
|
||||
Ok(original_len)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> std::io::Result<()> {
|
||||
self.flush0().map(|_| ())
|
||||
}
|
||||
}
|
||||
|
||||
async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use std::io::Write as _;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
SERVE_METRICS_COUNT.inc();
|
||||
|
||||
/// An [`std::io::Write`] implementation on top of a channel sending [`bytes::Bytes`] chunks.
|
||||
struct ChannelWriter {
|
||||
buffer: BytesMut,
|
||||
tx: mpsc::Sender<std::io::Result<Bytes>>,
|
||||
written: usize,
|
||||
}
|
||||
|
||||
impl ChannelWriter {
|
||||
fn new(buf_len: usize, tx: mpsc::Sender<std::io::Result<Bytes>>) -> Self {
|
||||
assert_ne!(buf_len, 0);
|
||||
ChannelWriter {
|
||||
// split about half off the buffer from the start, because we flush depending on
|
||||
// capacity. first flush will come sooner than without this, but now resizes will
|
||||
// have better chance of picking up the "other" half. not guaranteed of course.
|
||||
buffer: BytesMut::with_capacity(buf_len).split_off(buf_len / 2),
|
||||
tx,
|
||||
written: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn flush0(&mut self) -> std::io::Result<usize> {
|
||||
let n = self.buffer.len();
|
||||
if n == 0 {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
tracing::trace!(n, "flushing");
|
||||
let ready = self.buffer.split().freeze();
|
||||
|
||||
// not ideal to call from blocking code to block_on, but we are sure that this
|
||||
// operation does not spawn_blocking other tasks
|
||||
let res: Result<(), ()> = tokio::runtime::Handle::current().block_on(async {
|
||||
self.tx.send(Ok(ready)).await.map_err(|_| ())?;
|
||||
|
||||
// throttle sending to allow reuse of our buffer in `write`.
|
||||
self.tx.reserve().await.map_err(|_| ())?;
|
||||
|
||||
// now the response task has picked up the buffer and hopefully started
|
||||
// sending it to the client.
|
||||
Ok(())
|
||||
});
|
||||
if res.is_err() {
|
||||
return Err(std::io::ErrorKind::BrokenPipe.into());
|
||||
}
|
||||
self.written += n;
|
||||
Ok(n)
|
||||
}
|
||||
|
||||
fn flushed_bytes(&self) -> usize {
|
||||
self.written
|
||||
}
|
||||
}
|
||||
|
||||
impl std::io::Write for ChannelWriter {
|
||||
fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
|
||||
let remaining = self.buffer.capacity() - self.buffer.len();
|
||||
|
||||
let out_of_space = remaining < buf.len();
|
||||
|
||||
let original_len = buf.len();
|
||||
|
||||
if out_of_space {
|
||||
let can_still_fit = buf.len() - remaining;
|
||||
self.buffer.extend_from_slice(&buf[..can_still_fit]);
|
||||
buf = &buf[can_still_fit..];
|
||||
self.flush0()?;
|
||||
}
|
||||
|
||||
// assume that this will often under normal operation just move the pointer back to the
|
||||
// beginning of allocation, because previous split off parts are already sent and
|
||||
// dropped.
|
||||
self.buffer.extend_from_slice(buf);
|
||||
Ok(original_len)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> std::io::Result<()> {
|
||||
self.flush0().map(|_| ())
|
||||
}
|
||||
}
|
||||
|
||||
let started_at = std::time::Instant::now();
|
||||
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
@@ -389,7 +389,7 @@ fn parse_token(header_value: &str) -> Result<&str, ApiError> {
|
||||
}
|
||||
|
||||
pub fn auth_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
|
||||
provide_auth: fn(&Request<Body>) -> Option<&SwappableJwtAuth>,
|
||||
provide_auth: fn(&Request<Body>) -> Option<&JwtAuth>,
|
||||
) -> Middleware<B, ApiError> {
|
||||
Middleware::pre(move |req| async move {
|
||||
if let Some(auth) = provide_auth(&req) {
|
||||
@@ -400,11 +400,9 @@ pub fn auth_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
|
||||
})?;
|
||||
let token = parse_token(header_value)?;
|
||||
|
||||
let data = auth.decode(token).map_err(|err| {
|
||||
warn!("Authentication error: {err}");
|
||||
// Rely on From<AuthError> for ApiError impl
|
||||
err
|
||||
})?;
|
||||
let data = auth
|
||||
.decode(token)
|
||||
.map_err(|_| ApiError::Unauthorized("malformed jwt token".to_string()))?;
|
||||
req.set_context(data.claims);
|
||||
}
|
||||
None => {
|
||||
@@ -452,11 +450,12 @@ where
|
||||
|
||||
pub fn check_permission_with(
|
||||
req: &Request<Body>,
|
||||
check_permission: impl Fn(&Claims) -> Result<(), AuthError>,
|
||||
check_permission: impl Fn(&Claims) -> Result<(), anyhow::Error>,
|
||||
) -> Result<(), ApiError> {
|
||||
match req.context::<Claims>() {
|
||||
Some(claims) => Ok(check_permission(&claims)
|
||||
.map_err(|_err| ApiError::Forbidden("JWT authentication error".to_string()))?),
|
||||
Some(claims) => {
|
||||
Ok(check_permission(&claims).map_err(|err| ApiError::Forbidden(err.to_string()))?)
|
||||
}
|
||||
None => Ok(()), // claims is None because auth is disabled
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
use hyper::{header, Body, Response, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::borrow::Cow;
|
||||
use std::error::Error as StdError;
|
||||
use thiserror::Error;
|
||||
use tracing::{error, info, warn};
|
||||
use tracing::error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ApiError {
|
||||
@@ -25,9 +24,6 @@ pub enum ApiError {
|
||||
#[error("Precondition failed: {0}")]
|
||||
PreconditionFailed(Box<str>),
|
||||
|
||||
#[error("Resource temporarily unavailable: {0}")]
|
||||
ResourceUnavailable(Cow<'static, str>),
|
||||
|
||||
#[error("Shutting down")]
|
||||
ShuttingDown,
|
||||
|
||||
@@ -63,10 +59,6 @@ impl ApiError {
|
||||
"Shutting down".to_string(),
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
),
|
||||
ApiError::ResourceUnavailable(err) => HttpErrorBody::response_from_msg_and_status(
|
||||
err.to_string(),
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
),
|
||||
ApiError::InternalServerError(err) => HttpErrorBody::response_from_msg_and_status(
|
||||
err.to_string(),
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
@@ -116,15 +108,10 @@ pub async fn route_error_handler(err: routerify::RouteError) -> Response<Body> {
|
||||
|
||||
pub fn api_error_handler(api_error: ApiError) -> Response<Body> {
|
||||
// Print a stack trace for Internal Server errors
|
||||
|
||||
match api_error {
|
||||
ApiError::Forbidden(_) | ApiError::Unauthorized(_) => {
|
||||
warn!("Error processing HTTP request: {api_error:#}")
|
||||
}
|
||||
ApiError::ResourceUnavailable(_) => info!("Error processing HTTP request: {api_error:#}"),
|
||||
ApiError::NotFound(_) => info!("Error processing HTTP request: {api_error:#}"),
|
||||
ApiError::InternalServerError(_) => error!("Error processing HTTP request: {api_error:?}"),
|
||||
_ => error!("Error processing HTTP request: {api_error:#}"),
|
||||
if let ApiError::InternalServerError(_) = api_error {
|
||||
error!("Error processing HTTP request: {api_error:?}");
|
||||
} else {
|
||||
error!("Error processing HTTP request: {api_error:#}");
|
||||
}
|
||||
|
||||
api_error.into_response()
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use std::ffi::OsStr;
|
||||
use std::{fmt, str::FromStr};
|
||||
|
||||
use anyhow::Context;
|
||||
use hex::FromHex;
|
||||
use rand::Rng;
|
||||
use serde::de::Visitor;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
@@ -18,74 +18,12 @@ pub enum IdError {
|
||||
///
|
||||
/// NOTE: It (de)serializes as an array of hex bytes, so the string representation would look
|
||||
/// like `[173,80,132,115,129,226,72,254,170,201,135,108,199,26,228,24]`.
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||
///
|
||||
/// Use `#[serde_as(as = "DisplayFromStr")]` to (de)serialize it as hex string instead: `ad50847381e248feaac9876cc71ae418`.
|
||||
/// Check the `serde_with::serde_as` documentation for options for more complex types.
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
struct Id([u8; 16]);
|
||||
|
||||
impl Serialize for Id {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
if serializer.is_human_readable() {
|
||||
serializer.collect_str(self)
|
||||
} else {
|
||||
self.0.serialize(serializer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Id {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
struct IdVisitor {
|
||||
is_human_readable_deserializer: bool,
|
||||
}
|
||||
|
||||
impl<'de> Visitor<'de> for IdVisitor {
|
||||
type Value = Id;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
if self.is_human_readable_deserializer {
|
||||
formatter.write_str("value in form of hex string")
|
||||
} else {
|
||||
formatter.write_str("value in form of integer array([u8; 16])")
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: serde::de::SeqAccess<'de>,
|
||||
{
|
||||
let s = serde::de::value::SeqAccessDeserializer::new(seq);
|
||||
let id: [u8; 16] = Deserialize::deserialize(s)?;
|
||||
Ok(Id::from(id))
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Id::from_str(v).map_err(E::custom)
|
||||
}
|
||||
}
|
||||
|
||||
if deserializer.is_human_readable() {
|
||||
deserializer.deserialize_str(IdVisitor {
|
||||
is_human_readable_deserializer: true,
|
||||
})
|
||||
} else {
|
||||
deserializer.deserialize_tuple(
|
||||
16,
|
||||
IdVisitor {
|
||||
is_human_readable_deserializer: false,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Id {
|
||||
pub fn get_from_buf(buf: &mut impl bytes::Buf) -> Id {
|
||||
let mut arr = [0u8; 16];
|
||||
@@ -120,8 +58,6 @@ impl Id {
|
||||
chunk[0] = HEX[((b >> 4) & 0xf) as usize];
|
||||
chunk[1] = HEX[(b & 0xf) as usize];
|
||||
}
|
||||
|
||||
// SAFETY: vec constructed out of `HEX`, it can only be ascii
|
||||
unsafe { String::from_utf8_unchecked(buf) }
|
||||
}
|
||||
}
|
||||
@@ -279,11 +215,12 @@ pub struct TimelineId(Id);
|
||||
|
||||
id_newtype!(TimelineId);
|
||||
|
||||
impl TryFrom<Option<&str>> for TimelineId {
|
||||
impl TryFrom<Option<&OsStr>> for TimelineId {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: Option<&str>) -> Result<Self, Self::Error> {
|
||||
fn try_from(value: Option<&OsStr>) -> Result<Self, Self::Error> {
|
||||
value
|
||||
.and_then(OsStr::to_str)
|
||||
.unwrap_or_default()
|
||||
.parse::<TimelineId>()
|
||||
.with_context(|| format!("Could not parse timeline id from {:?}", value))
|
||||
@@ -373,112 +310,3 @@ impl fmt::Display for NodeId {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use serde_assert::{Deserializer, Serializer, Token, Tokens};
|
||||
|
||||
use crate::bin_ser::BeSer;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_id_serde_non_human_readable() {
|
||||
let original_id = Id([
|
||||
173, 80, 132, 115, 129, 226, 72, 254, 170, 201, 135, 108, 199, 26, 228, 24,
|
||||
]);
|
||||
let expected_tokens = Tokens(vec![
|
||||
Token::Tuple { len: 16 },
|
||||
Token::U8(173),
|
||||
Token::U8(80),
|
||||
Token::U8(132),
|
||||
Token::U8(115),
|
||||
Token::U8(129),
|
||||
Token::U8(226),
|
||||
Token::U8(72),
|
||||
Token::U8(254),
|
||||
Token::U8(170),
|
||||
Token::U8(201),
|
||||
Token::U8(135),
|
||||
Token::U8(108),
|
||||
Token::U8(199),
|
||||
Token::U8(26),
|
||||
Token::U8(228),
|
||||
Token::U8(24),
|
||||
Token::TupleEnd,
|
||||
]);
|
||||
|
||||
let serializer = Serializer::builder().is_human_readable(false).build();
|
||||
let serialized_tokens = original_id.serialize(&serializer).unwrap();
|
||||
assert_eq!(serialized_tokens, expected_tokens);
|
||||
|
||||
let mut deserializer = Deserializer::builder()
|
||||
.is_human_readable(false)
|
||||
.tokens(serialized_tokens)
|
||||
.build();
|
||||
let deserialized_id = Id::deserialize(&mut deserializer).unwrap();
|
||||
assert_eq!(deserialized_id, original_id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_id_serde_human_readable() {
|
||||
let original_id = Id([
|
||||
173, 80, 132, 115, 129, 226, 72, 254, 170, 201, 135, 108, 199, 26, 228, 24,
|
||||
]);
|
||||
let expected_tokens = Tokens(vec![Token::Str(String::from(
|
||||
"ad50847381e248feaac9876cc71ae418",
|
||||
))]);
|
||||
|
||||
let serializer = Serializer::builder().is_human_readable(true).build();
|
||||
let serialized_tokens = original_id.serialize(&serializer).unwrap();
|
||||
assert_eq!(serialized_tokens, expected_tokens);
|
||||
|
||||
let mut deserializer = Deserializer::builder()
|
||||
.is_human_readable(true)
|
||||
.tokens(Tokens(vec![Token::Str(String::from(
|
||||
"ad50847381e248feaac9876cc71ae418",
|
||||
))]))
|
||||
.build();
|
||||
assert_eq!(Id::deserialize(&mut deserializer).unwrap(), original_id);
|
||||
}
|
||||
|
||||
macro_rules! roundtrip_type {
|
||||
($type:ty, $expected_bytes:expr) => {{
|
||||
let expected_bytes: [u8; 16] = $expected_bytes;
|
||||
let original_id = <$type>::from(expected_bytes);
|
||||
|
||||
let ser_bytes = original_id.ser().unwrap();
|
||||
assert_eq!(ser_bytes, expected_bytes);
|
||||
|
||||
let des_id = <$type>::des(&ser_bytes).unwrap();
|
||||
assert_eq!(des_id, original_id);
|
||||
}};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_id_bincode_serde() {
|
||||
let expected_bytes = [
|
||||
173, 80, 132, 115, 129, 226, 72, 254, 170, 201, 135, 108, 199, 26, 228, 24,
|
||||
];
|
||||
|
||||
roundtrip_type!(Id, expected_bytes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tenant_id_bincode_serde() {
|
||||
let expected_bytes = [
|
||||
173, 80, 132, 115, 129, 226, 72, 254, 170, 201, 135, 108, 199, 26, 228, 24,
|
||||
];
|
||||
|
||||
roundtrip_type!(TenantId, expected_bytes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timeline_id_bincode_serde() {
|
||||
let expected_bytes = [
|
||||
173, 80, 132, 115, 129, 226, 72, 254, 170, 201, 135, 108, 199, 26, 228, 24,
|
||||
];
|
||||
|
||||
roundtrip_type!(TimelineId, expected_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
//! `utils` is intended to be a place to put code that is shared
|
||||
//! between other crates in this repository.
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
|
||||
pub mod backoff;
|
||||
|
||||
@@ -25,10 +24,6 @@ pub mod auth;
|
||||
|
||||
// utility functions and helper traits for unified unique id generation/serialization etc.
|
||||
pub mod id;
|
||||
|
||||
mod hex;
|
||||
pub use hex::Hex;
|
||||
|
||||
// http endpoint utils
|
||||
pub mod http;
|
||||
|
||||
@@ -78,11 +73,6 @@ pub mod completion;
|
||||
/// Reporting utilities
|
||||
pub mod error;
|
||||
|
||||
/// async timeout helper
|
||||
pub mod timeout;
|
||||
|
||||
pub mod sync;
|
||||
|
||||
/// This is a shortcut to embed git sha into binaries and avoid copying the same build script to all packages
|
||||
///
|
||||
/// we have several cases:
|
||||
@@ -138,21 +128,6 @@ macro_rules! project_git_version {
|
||||
};
|
||||
}
|
||||
|
||||
/// This is a shortcut to embed build tag into binaries and avoid copying the same build script to all packages
|
||||
#[macro_export]
|
||||
macro_rules! project_build_tag {
|
||||
($const_identifier:ident) => {
|
||||
const $const_identifier: &::core::primitive::str = {
|
||||
const __ARG: &[&::core::primitive::str; 2] = &match ::core::option_env!("BUILD_TAG") {
|
||||
::core::option::Option::Some(x) => ["build_tag-env:", x],
|
||||
::core::option::Option::None => ["build_tag:", ""],
|
||||
};
|
||||
|
||||
$crate::__const_format::concatcp!(__ARG[0], __ARG[1])
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
/// Re-export for `project_git_version` macro
|
||||
#[doc(hidden)]
|
||||
pub use const_format as __const_format;
|
||||
|
||||
@@ -11,10 +11,10 @@ use std::{
|
||||
io::{Read, Write},
|
||||
ops::Deref,
|
||||
os::unix::prelude::AsRawFd,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use camino::{Utf8Path, Utf8PathBuf};
|
||||
use nix::{errno::Errno::EAGAIN, fcntl};
|
||||
|
||||
use crate::crashsafe;
|
||||
@@ -23,7 +23,7 @@ use crate::crashsafe;
|
||||
/// Returned by [`create_exclusive`].
|
||||
#[must_use]
|
||||
pub struct UnwrittenLockFile {
|
||||
path: Utf8PathBuf,
|
||||
path: PathBuf,
|
||||
file: fs::File,
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ impl UnwrittenLockFile {
|
||||
///
|
||||
/// It is not an error if the file already exists.
|
||||
/// It is an error if the file is already locked.
|
||||
pub fn create_exclusive(lock_file_path: &Utf8Path) -> anyhow::Result<UnwrittenLockFile> {
|
||||
pub fn create_exclusive(lock_file_path: &Path) -> anyhow::Result<UnwrittenLockFile> {
|
||||
let lock_file = fs::OpenOptions::new()
|
||||
.create(true) // O_CREAT
|
||||
.write(true)
|
||||
@@ -101,7 +101,7 @@ pub enum LockFileRead {
|
||||
/// Open & try to lock the lock file at the given `path`, returning a [handle][`LockFileRead`] to
|
||||
/// inspect its content. It is not an `Err(...)` if the file does not exist or is already locked.
|
||||
/// Check the [`LockFileRead`] variants for details.
|
||||
pub fn read_and_hold_lock_file(path: &Utf8Path) -> anyhow::Result<LockFileRead> {
|
||||
pub fn read_and_hold_lock_file(path: &Path) -> anyhow::Result<LockFileRead> {
|
||||
let res = fs::OpenOptions::new().read(true).open(path);
|
||||
let mut lock_file = match res {
|
||||
Ok(f) => f,
|
||||
|
||||
@@ -228,12 +228,6 @@ impl SecretString {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for SecretString {
|
||||
fn from(s: String) -> Self {
|
||||
Self(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for SecretString {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "[SECRET]")
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
#![warn(missing_docs)]
|
||||
|
||||
use camino::Utf8Path;
|
||||
use serde::{de::Visitor, Deserialize, Serialize};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
use std::ops::{Add, AddAssign};
|
||||
use std::path::Path;
|
||||
use std::str::FromStr;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
@@ -13,114 +13,10 @@ use crate::seqwait::MonotonicCounter;
|
||||
pub const XLOG_BLCKSZ: u32 = 8192;
|
||||
|
||||
/// A Postgres LSN (Log Sequence Number), also known as an XLogRecPtr
|
||||
#[derive(Clone, Copy, Eq, Ord, PartialEq, PartialOrd, Hash)]
|
||||
#[derive(Clone, Copy, Eq, Ord, PartialEq, PartialOrd, Hash, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct Lsn(pub u64);
|
||||
|
||||
impl Serialize for Lsn {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
if serializer.is_human_readable() {
|
||||
serializer.collect_str(self)
|
||||
} else {
|
||||
self.0.serialize(serializer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Lsn {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
struct LsnVisitor {
|
||||
is_human_readable_deserializer: bool,
|
||||
}
|
||||
|
||||
impl<'de> Visitor<'de> for LsnVisitor {
|
||||
type Value = Lsn;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
if self.is_human_readable_deserializer {
|
||||
formatter.write_str(
|
||||
"value in form of hex string({upper_u32_hex}/{lower_u32_hex}) representing u64 integer",
|
||||
)
|
||||
} else {
|
||||
formatter.write_str("value in form of integer(u64)")
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(Lsn(v))
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Lsn::from_str(v).map_err(|e| E::custom(e))
|
||||
}
|
||||
}
|
||||
|
||||
if deserializer.is_human_readable() {
|
||||
deserializer.deserialize_str(LsnVisitor {
|
||||
is_human_readable_deserializer: true,
|
||||
})
|
||||
} else {
|
||||
deserializer.deserialize_u64(LsnVisitor {
|
||||
is_human_readable_deserializer: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Allows (de)serialization of an `Lsn` always as `u64`.
|
||||
///
|
||||
/// ### Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # use serde::{Serialize, Deserialize};
|
||||
/// use utils::lsn::Lsn;
|
||||
///
|
||||
/// #[derive(PartialEq, Serialize, Deserialize, Debug)]
|
||||
/// struct Foo {
|
||||
/// #[serde(with = "utils::lsn::serde_as_u64")]
|
||||
/// always_u64: Lsn,
|
||||
/// }
|
||||
///
|
||||
/// let orig = Foo { always_u64: Lsn(1234) };
|
||||
///
|
||||
/// let res = serde_json::to_string(&orig).unwrap();
|
||||
/// assert_eq!(res, r#"{"always_u64":1234}"#);
|
||||
///
|
||||
/// let foo = serde_json::from_str::<Foo>(&res).unwrap();
|
||||
/// assert_eq!(foo, orig);
|
||||
/// ```
|
||||
///
|
||||
pub mod serde_as_u64 {
|
||||
use super::Lsn;
|
||||
|
||||
/// Serializes the Lsn as u64 disregarding the human readability of the format.
|
||||
///
|
||||
/// Meant to be used via `#[serde(with = "...")]` or `#[serde(serialize_with = "...")]`.
|
||||
pub fn serialize<S: serde::Serializer>(lsn: &Lsn, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
use serde::Serialize;
|
||||
lsn.0.serialize(serializer)
|
||||
}
|
||||
|
||||
/// Deserializes the Lsn as u64 disregarding the human readability of the format.
|
||||
///
|
||||
/// Meant to be used via `#[serde(with = "...")]` or `#[serde(deserialize_with = "...")]`.
|
||||
pub fn deserialize<'de, D: serde::Deserializer<'de>>(deserializer: D) -> Result<Lsn, D::Error> {
|
||||
use serde::Deserialize;
|
||||
u64::deserialize(deserializer).map(Lsn)
|
||||
}
|
||||
}
|
||||
|
||||
/// We tried to parse an LSN from a string, but failed
|
||||
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
|
||||
#[error("LsnParseError")]
|
||||
@@ -148,9 +44,11 @@ impl Lsn {
|
||||
/// Parse an LSN from a filename in the form `0000000000000000`
|
||||
pub fn from_filename<F>(filename: F) -> Result<Self, LsnParseError>
|
||||
where
|
||||
F: AsRef<Utf8Path>,
|
||||
F: AsRef<Path>,
|
||||
{
|
||||
Lsn::from_hex(filename.as_ref().as_str())
|
||||
let filename: &Path = filename.as_ref();
|
||||
let filename = filename.to_str().ok_or(LsnParseError)?;
|
||||
Lsn::from_hex(filename)
|
||||
}
|
||||
|
||||
/// Parse an LSN from a string in the form `0000000000000000`
|
||||
@@ -368,13 +266,8 @@ impl MonotonicCounter<Lsn> for RecordLsn {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::bin_ser::BeSer;
|
||||
|
||||
use super::*;
|
||||
|
||||
use serde::ser::Serialize;
|
||||
use serde_assert::{Deserializer, Serializer, Token, Tokens};
|
||||
|
||||
#[test]
|
||||
fn test_lsn_strings() {
|
||||
assert_eq!("12345678/AAAA5555".parse(), Ok(Lsn(0x12345678AAAA5555)));
|
||||
@@ -450,95 +343,4 @@ mod tests {
|
||||
assert_eq!(lsn.fetch_max(Lsn(6000)), Lsn(5678));
|
||||
assert_eq!(lsn.fetch_max(Lsn(5000)), Lsn(6000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lsn_serde() {
|
||||
let original_lsn = Lsn(0x0123456789abcdef);
|
||||
let expected_readable_tokens = Tokens(vec![Token::U64(0x0123456789abcdef)]);
|
||||
let expected_non_readable_tokens =
|
||||
Tokens(vec![Token::Str(String::from("1234567/89ABCDEF"))]);
|
||||
|
||||
// Testing human_readable ser/de
|
||||
let serializer = Serializer::builder().is_human_readable(false).build();
|
||||
let readable_ser_tokens = original_lsn.serialize(&serializer).unwrap();
|
||||
assert_eq!(readable_ser_tokens, expected_readable_tokens);
|
||||
|
||||
let mut deserializer = Deserializer::builder()
|
||||
.is_human_readable(false)
|
||||
.tokens(readable_ser_tokens)
|
||||
.build();
|
||||
let des_lsn = Lsn::deserialize(&mut deserializer).unwrap();
|
||||
assert_eq!(des_lsn, original_lsn);
|
||||
|
||||
// Testing NON human_readable ser/de
|
||||
let serializer = Serializer::builder().is_human_readable(true).build();
|
||||
let non_readable_ser_tokens = original_lsn.serialize(&serializer).unwrap();
|
||||
assert_eq!(non_readable_ser_tokens, expected_non_readable_tokens);
|
||||
|
||||
let mut deserializer = Deserializer::builder()
|
||||
.is_human_readable(true)
|
||||
.tokens(non_readable_ser_tokens)
|
||||
.build();
|
||||
let des_lsn = Lsn::deserialize(&mut deserializer).unwrap();
|
||||
assert_eq!(des_lsn, original_lsn);
|
||||
|
||||
// Testing mismatching ser/de
|
||||
let serializer = Serializer::builder().is_human_readable(false).build();
|
||||
let non_readable_ser_tokens = original_lsn.serialize(&serializer).unwrap();
|
||||
|
||||
let mut deserializer = Deserializer::builder()
|
||||
.is_human_readable(true)
|
||||
.tokens(non_readable_ser_tokens)
|
||||
.build();
|
||||
Lsn::deserialize(&mut deserializer).unwrap_err();
|
||||
|
||||
let serializer = Serializer::builder().is_human_readable(true).build();
|
||||
let readable_ser_tokens = original_lsn.serialize(&serializer).unwrap();
|
||||
|
||||
let mut deserializer = Deserializer::builder()
|
||||
.is_human_readable(false)
|
||||
.tokens(readable_ser_tokens)
|
||||
.build();
|
||||
Lsn::deserialize(&mut deserializer).unwrap_err();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lsn_ensure_roundtrip() {
|
||||
let original_lsn = Lsn(0xaaaabbbb);
|
||||
|
||||
let serializer = Serializer::builder().is_human_readable(false).build();
|
||||
let ser_tokens = original_lsn.serialize(&serializer).unwrap();
|
||||
|
||||
let mut deserializer = Deserializer::builder()
|
||||
.is_human_readable(false)
|
||||
.tokens(ser_tokens)
|
||||
.build();
|
||||
|
||||
let des_lsn = Lsn::deserialize(&mut deserializer).unwrap();
|
||||
assert_eq!(des_lsn, original_lsn);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lsn_bincode_serde() {
|
||||
let lsn = Lsn(0x0123456789abcdef);
|
||||
let expected_bytes = [0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef];
|
||||
|
||||
let ser_bytes = lsn.ser().unwrap();
|
||||
assert_eq!(ser_bytes, expected_bytes);
|
||||
|
||||
let des_lsn = Lsn::des(&ser_bytes).unwrap();
|
||||
assert_eq!(des_lsn, lsn);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lsn_bincode_ensure_roundtrip() {
|
||||
let original_lsn = Lsn(0x01_02_03_04_05_06_07_08);
|
||||
let expected_bytes = vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
|
||||
|
||||
let ser_bytes = original_lsn.ser().unwrap();
|
||||
assert_eq!(ser_bytes, expected_bytes);
|
||||
|
||||
let des_lsn = Lsn::des(&ser_bytes).unwrap();
|
||||
assert_eq!(des_lsn, original_lsn);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::time::{Duration, SystemTime};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use pq_proto::{read_cstr, PG_EPOCH};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
use tracing::{trace, warn};
|
||||
|
||||
use crate::lsn::Lsn;
|
||||
@@ -14,17 +15,21 @@ use crate::lsn::Lsn;
|
||||
///
|
||||
/// serde Serialize is used only for human readable dump to json (e.g. in
|
||||
/// safekeepers debug_dump).
|
||||
#[serde_as]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct PageserverFeedback {
|
||||
/// Last known size of the timeline. Used to enforce timeline size limit.
|
||||
pub current_timeline_size: u64,
|
||||
/// LSN last received and ingested by the pageserver. Controls backpressure.
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub last_received_lsn: Lsn,
|
||||
/// LSN up to which data is persisted by the pageserver to its local disc.
|
||||
/// Controls backpressure.
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub disk_consistent_lsn: Lsn,
|
||||
/// LSN up to which data is persisted by the pageserver on s3; safekeepers
|
||||
/// consider WAL before it can be removed.
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub remote_consistent_lsn: Lsn,
|
||||
// Serialize with RFC3339 format.
|
||||
#[serde(with = "serde_systemtime")]
|
||||
|
||||
@@ -49,10 +49,9 @@
|
||||
//! At this point, `B` and `C` are running, which is hazardous.
|
||||
//! Morale of the story: don't unlink pidfiles, ever.
|
||||
|
||||
use std::ops::Deref;
|
||||
use std::{ops::Deref, path::Path};
|
||||
|
||||
use anyhow::Context;
|
||||
use camino::Utf8Path;
|
||||
use nix::unistd::Pid;
|
||||
|
||||
use crate::lock_file::{self, LockFileRead};
|
||||
@@ -85,7 +84,7 @@ impl Deref for PidFileGuard {
|
||||
/// The claim ends as soon as the returned guard object is dropped.
|
||||
/// To maintain the claim for the remaining lifetime of the current process,
|
||||
/// use [`std::mem::forget`] or similar.
|
||||
pub fn claim_for_current_process(path: &Utf8Path) -> anyhow::Result<PidFileGuard> {
|
||||
pub fn claim_for_current_process(path: &Path) -> anyhow::Result<PidFileGuard> {
|
||||
let unwritten_lock_file = lock_file::create_exclusive(path).context("lock file")?;
|
||||
// if any of the next steps fail, we drop the file descriptor and thereby release the lock
|
||||
let guard = unwritten_lock_file
|
||||
@@ -133,7 +132,7 @@ pub enum PidFileRead {
|
||||
///
|
||||
/// On success, this function returns a [`PidFileRead`].
|
||||
/// Check its docs for a description of the meaning of its different variants.
|
||||
pub fn read(pidfile: &Utf8Path) -> anyhow::Result<PidFileRead> {
|
||||
pub fn read(pidfile: &Path) -> anyhow::Result<PidFileRead> {
|
||||
let res = lock_file::read_and_hold_lock_file(pidfile).context("read and hold pid file")?;
|
||||
let ret = match res {
|
||||
LockFileRead::NotExist => PidFileRead::NotExist,
|
||||
|
||||
@@ -58,7 +58,7 @@ where
|
||||
// to get that.
|
||||
impl<T: Ord> PartialOrd for Waiter<T> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(other))
|
||||
other.wake_num.partial_cmp(&self.wake_num)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,9 +125,6 @@ where
|
||||
// Wake everyone with an error.
|
||||
let mut internal = self.internal.lock().unwrap();
|
||||
|
||||
// Block any future waiters from starting
|
||||
internal.shutdown = true;
|
||||
|
||||
// This will steal the entire waiters map.
|
||||
// When we drop it all waiters will be woken.
|
||||
mem::take(&mut internal.waiters)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
/// Immediately terminate the calling process without calling
|
||||
/// atexit callbacks, C runtime destructors etc. We mainly use
|
||||
/// this to protect coverage data from concurrent writes.
|
||||
pub fn exit_now(code: u8) -> ! {
|
||||
// SAFETY: exiting is safe, the ffi is not safe
|
||||
pub fn exit_now(code: u8) {
|
||||
unsafe { nix::libc::_exit(code as _) };
|
||||
}
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
pub mod heavier_once_cell;
|
||||
|
||||
pub mod gate;
|
||||
@@ -1,158 +0,0 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
/// Gates are a concurrency helper, primarily used for implementing safe shutdown.
|
||||
///
|
||||
/// Users of a resource call `enter()` to acquire a GateGuard, and the owner of
|
||||
/// the resource calls `close()` when they want to ensure that all holders of guards
|
||||
/// have released them, and that no future guards will be issued.
|
||||
pub struct Gate {
|
||||
/// Each caller of enter() takes one unit from the semaphore. In close(), we
|
||||
/// take all the units to ensure all GateGuards are destroyed.
|
||||
sem: Arc<tokio::sync::Semaphore>,
|
||||
|
||||
/// For observability only: a name that will be used to log warnings if a particular
|
||||
/// gate is holding up shutdown
|
||||
name: String,
|
||||
}
|
||||
|
||||
/// RAII guard for a [`Gate`]: as long as this exists, calls to [`Gate::close`] will
|
||||
/// not complete.
|
||||
#[derive(Debug)]
|
||||
pub struct GateGuard(tokio::sync::OwnedSemaphorePermit);
|
||||
|
||||
/// Observability helper: every `warn_period`, emit a log warning that we're still waiting on this gate
|
||||
async fn warn_if_stuck<Fut: std::future::Future>(
|
||||
fut: Fut,
|
||||
name: &str,
|
||||
warn_period: std::time::Duration,
|
||||
) -> <Fut as std::future::Future>::Output {
|
||||
let started = std::time::Instant::now();
|
||||
|
||||
let mut fut = std::pin::pin!(fut);
|
||||
|
||||
loop {
|
||||
match tokio::time::timeout(warn_period, &mut fut).await {
|
||||
Ok(ret) => return ret,
|
||||
Err(_) => {
|
||||
tracing::warn!(
|
||||
gate = name,
|
||||
elapsed_ms = started.elapsed().as_millis(),
|
||||
"still waiting, taking longer than expected..."
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum GateError {
|
||||
GateClosed,
|
||||
}
|
||||
|
||||
impl Gate {
|
||||
const MAX_UNITS: u32 = u32::MAX;
|
||||
|
||||
pub fn new(name: String) -> Self {
|
||||
Self {
|
||||
sem: Arc::new(tokio::sync::Semaphore::new(Self::MAX_UNITS as usize)),
|
||||
name,
|
||||
}
|
||||
}
|
||||
|
||||
/// Acquire a guard that will prevent close() calls from completing. If close()
|
||||
/// was already called, this will return an error which should be interpreted
|
||||
/// as "shutting down".
|
||||
///
|
||||
/// This function would typically be used from e.g. request handlers. While holding
|
||||
/// the guard returned from this function, it is important to respect a CancellationToken
|
||||
/// to avoid blocking close() indefinitely: typically types that contain a Gate will
|
||||
/// also contain a CancellationToken.
|
||||
pub fn enter(&self) -> Result<GateGuard, GateError> {
|
||||
self.sem
|
||||
.clone()
|
||||
.try_acquire_owned()
|
||||
.map(GateGuard)
|
||||
.map_err(|_| GateError::GateClosed)
|
||||
}
|
||||
|
||||
/// Types with a shutdown() method and a gate should call this method at the
|
||||
/// end of shutdown, to ensure that all GateGuard holders are done.
|
||||
///
|
||||
/// This will wait for all guards to be destroyed. For this to complete promptly, it is
|
||||
/// important that the holders of such guards are respecting a CancellationToken which has
|
||||
/// been cancelled before entering this function.
|
||||
pub async fn close(&self) {
|
||||
warn_if_stuck(self.do_close(), &self.name, Duration::from_millis(1000)).await
|
||||
}
|
||||
|
||||
/// Check if [`Self::close()`] has finished waiting for all [`Self::enter()`] users to finish. This
|
||||
/// is usually analoguous for "Did shutdown finish?" for types that include a Gate, whereas checking
|
||||
/// the CancellationToken on such types is analogous to "Did shutdown start?"
|
||||
pub fn close_complete(&self) -> bool {
|
||||
self.sem.is_closed()
|
||||
}
|
||||
|
||||
async fn do_close(&self) {
|
||||
tracing::debug!(gate = self.name, "Closing Gate...");
|
||||
match self.sem.acquire_many(Self::MAX_UNITS).await {
|
||||
Ok(_units) => {
|
||||
// While holding all units, close the semaphore. All subsequent calls to enter() will fail.
|
||||
self.sem.close();
|
||||
}
|
||||
Err(_) => {
|
||||
// Semaphore closed: we are the only function that can do this, so it indicates a double-call.
|
||||
// This is legal. Timeline::shutdown for example is not protected from being called more than
|
||||
// once.
|
||||
tracing::debug!(gate = self.name, "Double close")
|
||||
}
|
||||
}
|
||||
tracing::debug!(gate = self.name, "Closed Gate.")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use futures::FutureExt;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_idle_gate() {
|
||||
// Having taken no gates, we should not be blocked in close
|
||||
let gate = Gate::new("test".to_string());
|
||||
gate.close().await;
|
||||
|
||||
// If a guard is dropped before entering, close should not be blocked
|
||||
let gate = Gate::new("test".to_string());
|
||||
let guard = gate.enter().unwrap();
|
||||
drop(guard);
|
||||
gate.close().await;
|
||||
|
||||
// Entering a closed guard fails
|
||||
gate.enter().expect_err("enter should fail after close");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_busy_gate() {
|
||||
let gate = Gate::new("test".to_string());
|
||||
|
||||
let guard = gate.enter().unwrap();
|
||||
|
||||
let mut close_fut = std::pin::pin!(gate.close());
|
||||
|
||||
// Close should be blocked
|
||||
assert!(close_fut.as_mut().now_or_never().is_none());
|
||||
|
||||
// Attempting to enter() should fail, even though close isn't done yet.
|
||||
gate.enter()
|
||||
.expect_err("enter should fail after entering close");
|
||||
|
||||
drop(guard);
|
||||
|
||||
// Guard is gone, close should finish
|
||||
assert!(close_fut.as_mut().now_or_never().is_some());
|
||||
|
||||
// Attempting to enter() is still forbidden
|
||||
gate.enter().expect_err("enter should fail finishing close");
|
||||
}
|
||||
}
|
||||
@@ -1,383 +0,0 @@
|
||||
use std::sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc, Mutex, MutexGuard,
|
||||
};
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
/// Custom design like [`tokio::sync::OnceCell`] but using [`OwnedSemaphorePermit`] instead of
|
||||
/// `SemaphorePermit`, allowing use of `take` which does not require holding an outer mutex guard
|
||||
/// for the duration of initialization.
|
||||
///
|
||||
/// Has no unsafe, builds upon [`tokio::sync::Semaphore`] and [`std::sync::Mutex`].
|
||||
///
|
||||
/// [`OwnedSemaphorePermit`]: tokio::sync::OwnedSemaphorePermit
|
||||
pub struct OnceCell<T> {
|
||||
inner: Mutex<Inner<T>>,
|
||||
initializers: AtomicUsize,
|
||||
}
|
||||
|
||||
impl<T> Default for OnceCell<T> {
|
||||
/// Create new uninitialized [`OnceCell`].
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
inner: Default::default(),
|
||||
initializers: AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Semaphore is the current state:
|
||||
/// - open semaphore means the value is `None`, not yet initialized
|
||||
/// - closed semaphore means the value has been initialized
|
||||
#[derive(Debug)]
|
||||
struct Inner<T> {
|
||||
init_semaphore: Arc<Semaphore>,
|
||||
value: Option<T>,
|
||||
}
|
||||
|
||||
impl<T> Default for Inner<T> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
init_semaphore: Arc::new(Semaphore::new(1)),
|
||||
value: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> OnceCell<T> {
|
||||
/// Creates an already initialized `OnceCell` with the given value.
|
||||
pub fn new(value: T) -> Self {
|
||||
let sem = Semaphore::new(1);
|
||||
sem.close();
|
||||
Self {
|
||||
inner: Mutex::new(Inner {
|
||||
init_semaphore: Arc::new(sem),
|
||||
value: Some(value),
|
||||
}),
|
||||
initializers: AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a guard to an existing initialized value, or uniquely initializes the value before
|
||||
/// returning the guard.
|
||||
///
|
||||
/// Initializing might wait on any existing [`Guard::take_and_deinit`] deinitialization.
|
||||
///
|
||||
/// Initialization is panic-safe and cancellation-safe.
|
||||
pub async fn get_or_init<F, Fut, E>(&self, factory: F) -> Result<Guard<'_, T>, E>
|
||||
where
|
||||
F: FnOnce(InitPermit) -> Fut,
|
||||
Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
|
||||
{
|
||||
let sem = {
|
||||
let guard = self.inner.lock().unwrap();
|
||||
if guard.value.is_some() {
|
||||
return Ok(Guard(guard));
|
||||
}
|
||||
guard.init_semaphore.clone()
|
||||
};
|
||||
|
||||
let permit = {
|
||||
// increment the count for the duration of queued
|
||||
let _guard = CountWaitingInitializers::start(self);
|
||||
sem.acquire_owned().await
|
||||
};
|
||||
|
||||
match permit {
|
||||
Ok(permit) => {
|
||||
let permit = InitPermit(permit);
|
||||
let (value, _permit) = factory(permit).await?;
|
||||
|
||||
let guard = self.inner.lock().unwrap();
|
||||
|
||||
Ok(Self::set0(value, guard))
|
||||
}
|
||||
Err(_closed) => {
|
||||
let guard = self.inner.lock().unwrap();
|
||||
assert!(
|
||||
guard.value.is_some(),
|
||||
"semaphore got closed, must be initialized"
|
||||
);
|
||||
return Ok(Guard(guard));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Assuming a permit is held after previous call to [`Guard::take_and_deinit`], it can be used
|
||||
/// to complete initializing the inner value.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the inner has already been initialized.
|
||||
pub fn set(&self, value: T, _permit: InitPermit) -> Guard<'_, T> {
|
||||
let guard = self.inner.lock().unwrap();
|
||||
|
||||
// cannot assert that this permit is for self.inner.semaphore, but we can assert it cannot
|
||||
// give more permits right now.
|
||||
if guard.init_semaphore.try_acquire().is_ok() {
|
||||
drop(guard);
|
||||
panic!("permit is of wrong origin");
|
||||
}
|
||||
|
||||
Self::set0(value, guard)
|
||||
}
|
||||
|
||||
fn set0(value: T, mut guard: std::sync::MutexGuard<'_, Inner<T>>) -> Guard<'_, T> {
|
||||
if guard.value.is_some() {
|
||||
drop(guard);
|
||||
unreachable!("we won permit, must not be initialized");
|
||||
}
|
||||
guard.value = Some(value);
|
||||
guard.init_semaphore.close();
|
||||
Guard(guard)
|
||||
}
|
||||
|
||||
/// Returns a guard to an existing initialized value, if any.
|
||||
pub fn get(&self) -> Option<Guard<'_, T>> {
|
||||
let guard = self.inner.lock().unwrap();
|
||||
if guard.value.is_some() {
|
||||
Some(Guard(guard))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the number of [`Self::get_or_init`] calls waiting for initialization to complete.
|
||||
pub fn initializer_count(&self) -> usize {
|
||||
self.initializers.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
/// DropGuard counter for queued tasks waiting to initialize, mainly accessible for the
|
||||
/// initializing task for example at the end of initialization.
|
||||
struct CountWaitingInitializers<'a, T>(&'a OnceCell<T>);
|
||||
|
||||
impl<'a, T> CountWaitingInitializers<'a, T> {
|
||||
fn start(target: &'a OnceCell<T>) -> Self {
|
||||
target.initializers.fetch_add(1, Ordering::Relaxed);
|
||||
CountWaitingInitializers(target)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Drop for CountWaitingInitializers<'a, T> {
|
||||
fn drop(&mut self) {
|
||||
self.0.initializers.fetch_sub(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
/// Uninteresting guard object to allow short-lived access to inspect or clone the held,
|
||||
/// initialized value.
|
||||
#[derive(Debug)]
|
||||
pub struct Guard<'a, T>(MutexGuard<'a, Inner<T>>);
|
||||
|
||||
impl<T> std::ops::Deref for Guard<'_, T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.0
|
||||
.value
|
||||
.as_ref()
|
||||
.expect("guard is not created unless value has been initialized")
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::ops::DerefMut for Guard<'_, T> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.0
|
||||
.value
|
||||
.as_mut()
|
||||
.expect("guard is not created unless value has been initialized")
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Guard<'a, T> {
|
||||
/// Take the current value, and a new permit for it's deinitialization.
|
||||
///
|
||||
/// The permit will be on a semaphore part of the new internal value, and any following
|
||||
/// [`OnceCell::get_or_init`] will wait on it to complete.
|
||||
pub fn take_and_deinit(&mut self) -> (T, InitPermit) {
|
||||
let mut swapped = Inner::default();
|
||||
let permit = swapped
|
||||
.init_semaphore
|
||||
.clone()
|
||||
.try_acquire_owned()
|
||||
.expect("we just created this");
|
||||
std::mem::swap(&mut *self.0, &mut swapped);
|
||||
swapped
|
||||
.value
|
||||
.map(|v| (v, InitPermit(permit)))
|
||||
.expect("guard is not created unless value has been initialized")
|
||||
}
|
||||
}
|
||||
|
||||
/// Type held by OnceCell (de)initializing task.
|
||||
pub struct InitPermit(tokio::sync::OwnedSemaphorePermit);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn many_initializers() {
|
||||
#[derive(Default, Debug)]
|
||||
struct Counters {
|
||||
factory_got_to_run: AtomicUsize,
|
||||
future_polled: AtomicUsize,
|
||||
winners: AtomicUsize,
|
||||
}
|
||||
|
||||
let initializers = 100;
|
||||
|
||||
let cell = Arc::new(OnceCell::default());
|
||||
let counters = Arc::new(Counters::default());
|
||||
let barrier = Arc::new(tokio::sync::Barrier::new(initializers + 1));
|
||||
|
||||
let mut js = tokio::task::JoinSet::new();
|
||||
for i in 0..initializers {
|
||||
js.spawn({
|
||||
let cell = cell.clone();
|
||||
let counters = counters.clone();
|
||||
let barrier = barrier.clone();
|
||||
|
||||
async move {
|
||||
barrier.wait().await;
|
||||
let won = {
|
||||
let g = cell
|
||||
.get_or_init(|permit| {
|
||||
counters.factory_got_to_run.fetch_add(1, Ordering::Relaxed);
|
||||
async {
|
||||
counters.future_polled.fetch_add(1, Ordering::Relaxed);
|
||||
Ok::<_, Infallible>((i, permit))
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
*g == i
|
||||
};
|
||||
|
||||
if won {
|
||||
counters.winners.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
barrier.wait().await;
|
||||
|
||||
while let Some(next) = js.join_next().await {
|
||||
next.expect("no panics expected");
|
||||
}
|
||||
|
||||
let mut counters = Arc::try_unwrap(counters).unwrap();
|
||||
|
||||
assert_eq!(*counters.factory_got_to_run.get_mut(), 1);
|
||||
assert_eq!(*counters.future_polled.get_mut(), 1);
|
||||
assert_eq!(*counters.winners.get_mut(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn reinit_waits_for_deinit() {
|
||||
// with the tokio::time paused, we will "sleep" for 1s while holding the reinitialization
|
||||
let sleep_for = Duration::from_secs(1);
|
||||
let initial = 42;
|
||||
let reinit = 1;
|
||||
let cell = Arc::new(OnceCell::new(initial));
|
||||
|
||||
let deinitialization_started = Arc::new(tokio::sync::Barrier::new(2));
|
||||
|
||||
let jh = tokio::spawn({
|
||||
let cell = cell.clone();
|
||||
let deinitialization_started = deinitialization_started.clone();
|
||||
async move {
|
||||
let (answer, _permit) = cell.get().expect("initialized to value").take_and_deinit();
|
||||
assert_eq!(answer, initial);
|
||||
|
||||
deinitialization_started.wait().await;
|
||||
tokio::time::sleep(sleep_for).await;
|
||||
}
|
||||
});
|
||||
|
||||
deinitialization_started.wait().await;
|
||||
|
||||
let started_at = tokio::time::Instant::now();
|
||||
cell.get_or_init(|permit| async { Ok::<_, Infallible>((reinit, permit)) })
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let elapsed = started_at.elapsed();
|
||||
assert!(
|
||||
elapsed >= sleep_for,
|
||||
"initialization should had taken at least the time time slept with permit"
|
||||
);
|
||||
|
||||
jh.await.unwrap();
|
||||
|
||||
assert_eq!(*cell.get().unwrap(), reinit);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reinit_with_deinit_permit() {
|
||||
let cell = Arc::new(OnceCell::new(42));
|
||||
|
||||
let (mol, permit) = cell.get().unwrap().take_and_deinit();
|
||||
cell.set(5, permit);
|
||||
assert_eq!(*cell.get().unwrap(), 5);
|
||||
|
||||
let (five, permit) = cell.get().unwrap().take_and_deinit();
|
||||
assert_eq!(5, five);
|
||||
cell.set(mol, permit);
|
||||
assert_eq!(*cell.get().unwrap(), 42);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn initialization_attemptable_until_ok() {
|
||||
let cell = OnceCell::default();
|
||||
|
||||
for _ in 0..10 {
|
||||
cell.get_or_init(|_permit| async { Err("whatever error") })
|
||||
.await
|
||||
.unwrap_err();
|
||||
}
|
||||
|
||||
let g = cell
|
||||
.get_or_init(|permit| async { Ok::<_, Infallible>(("finally success", permit)) })
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(*g, "finally success");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn initialization_is_cancellation_safe() {
|
||||
let cell = OnceCell::default();
|
||||
|
||||
let barrier = tokio::sync::Barrier::new(2);
|
||||
|
||||
let initializer = cell.get_or_init(|permit| async {
|
||||
barrier.wait().await;
|
||||
futures::future::pending::<()>().await;
|
||||
|
||||
Ok::<_, Infallible>(("never reached", permit))
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = initializer => { unreachable!("cannot complete; stuck in pending().await") },
|
||||
_ = barrier.wait() => {}
|
||||
};
|
||||
|
||||
// now initializer is dropped
|
||||
|
||||
assert!(cell.get().is_none());
|
||||
|
||||
let g = cell
|
||||
.get_or_init(|permit| async { Ok::<_, Infallible>(("now initialized", permit)) })
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(*g, "now initialized");
|
||||
}
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
pub enum TimeoutCancellableError {
|
||||
Timeout,
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
/// Wrap [`tokio::time::timeout`] with a CancellationToken.
|
||||
///
|
||||
/// This wrapper is appropriate for any long running operation in a task
|
||||
/// that ought to respect a CancellationToken (which means most tasks).
|
||||
///
|
||||
/// The only time you should use a bare tokio::timeout is when the future `F`
|
||||
/// itself respects a CancellationToken: otherwise, always use this wrapper
|
||||
/// with your CancellationToken to ensure that your task does not hold up
|
||||
/// graceful shutdown.
|
||||
pub async fn timeout_cancellable<F>(
|
||||
duration: Duration,
|
||||
cancel: &CancellationToken,
|
||||
future: F,
|
||||
) -> Result<F::Output, TimeoutCancellableError>
|
||||
where
|
||||
F: std::future::Future,
|
||||
{
|
||||
tokio::select!(
|
||||
r = tokio::time::timeout(duration, future) => {
|
||||
r.map_err(|_| TimeoutCancellableError::Timeout)
|
||||
|
||||
},
|
||||
_ = cancel.cancelled() => {
|
||||
Err(TimeoutCancellableError::Cancelled)
|
||||
|
||||
}
|
||||
)
|
||||
}
|
||||
@@ -19,12 +19,13 @@ inotify.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
sysinfo.workspace = true
|
||||
tokio = { workspace = true, features = ["rt-multi-thread"] }
|
||||
tokio.workspace = true
|
||||
tokio-postgres.workspace = true
|
||||
tokio-stream.workspace = true
|
||||
tokio-util.workspace = true
|
||||
tracing.workspace = true
|
||||
tracing-subscriber.workspace = true
|
||||
workspace_hack = { version = "0.1", path = "../../workspace_hack" }
|
||||
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
cgroups-rs = "0.3.3"
|
||||
|
||||
@@ -27,8 +27,8 @@ and old one if it exists.
|
||||
* the filecache: a struct that allows communication with the Postgres file cache.
|
||||
On startup, we connect to the filecache and hold on to the connection for the
|
||||
entire monitor lifetime.
|
||||
* the cgroup watcher: the `CgroupWatcher` polls the `neon-postgres` cgroup's memory
|
||||
usage and sends rolling aggregates to the runner.
|
||||
* the cgroup watcher: the `CgroupWatcher` manages the `neon-postgres` cgroup by
|
||||
listening for `memory.high` events and setting its `memory.{high,max}` values.
|
||||
* the runner: the runner marries the filecache and cgroup watcher together,
|
||||
communicating with the agent throught the `Dispatcher`, and then calling filecache
|
||||
and cgroup watcher functions as needed to upscale and downscale
|
||||
|
||||
@@ -1,38 +1,161 @@
|
||||
use std::fmt::{self, Debug, Formatter};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use cgroups_rs::{
|
||||
hierarchies::{self, is_cgroup2_unified_mode},
|
||||
memory::MemController,
|
||||
Subsystem,
|
||||
use std::{
|
||||
fmt::{Debug, Display},
|
||||
fs,
|
||||
pin::pin,
|
||||
sync::atomic::{AtomicU64, Ordering},
|
||||
};
|
||||
use tokio::sync::watch;
|
||||
|
||||
use anyhow::{anyhow, bail, Context};
|
||||
use cgroups_rs::{
|
||||
freezer::FreezerController,
|
||||
hierarchies::{self, is_cgroup2_unified_mode, UNIFIED_MOUNTPOINT},
|
||||
memory::MemController,
|
||||
MaxValue,
|
||||
Subsystem::{Freezer, Mem},
|
||||
};
|
||||
use inotify::{EventStream, Inotify, WatchMask};
|
||||
use tokio::sync::mpsc::{self, error::TryRecvError};
|
||||
use tokio::time::{Duration, Instant};
|
||||
use tokio_stream::{Stream, StreamExt};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::protocol::Resources;
|
||||
use crate::MiB;
|
||||
|
||||
/// Monotonically increasing counter of the number of memory.high events
|
||||
/// the cgroup has experienced.
|
||||
///
|
||||
/// We use this to determine if a modification to the `memory.events` file actually
|
||||
/// changed the `high` field. If not, we don't care about the change. When we
|
||||
/// read the file, we check the `high` field in the file against `MEMORY_EVENT_COUNT`
|
||||
/// to see if it changed since last time.
|
||||
pub static MEMORY_EVENT_COUNT: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
/// Monotonically increasing counter that gives each cgroup event a unique id.
|
||||
///
|
||||
/// This allows us to answer questions like "did this upscale arrive before this
|
||||
/// memory.high?". This static is also used by the `Sequenced` type to "tag" values
|
||||
/// with a sequence number. As such, prefer to used the `Sequenced` type rather
|
||||
/// than this static directly.
|
||||
static EVENT_SEQUENCE_NUMBER: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
/// A memory event type reported in memory.events.
|
||||
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
|
||||
pub enum MemoryEvent {
|
||||
Low,
|
||||
High,
|
||||
Max,
|
||||
Oom,
|
||||
OomKill,
|
||||
OomGroupKill,
|
||||
}
|
||||
|
||||
impl MemoryEvent {
|
||||
fn as_str(&self) -> &str {
|
||||
match self {
|
||||
MemoryEvent::Low => "low",
|
||||
MemoryEvent::High => "high",
|
||||
MemoryEvent::Max => "max",
|
||||
MemoryEvent::Oom => "oom",
|
||||
MemoryEvent::OomKill => "oom_kill",
|
||||
MemoryEvent::OomGroupKill => "oom_group_kill",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for MemoryEvent {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for a `CgroupWatcher`
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
/// Interval at which we should be fetching memory statistics
|
||||
memory_poll_interval: Duration,
|
||||
// The target difference between the total memory reserved for the cgroup
|
||||
// and the value of the cgroup's memory.high.
|
||||
//
|
||||
// In other words, memory.high + oom_buffer_bytes will equal the total memory that the cgroup may
|
||||
// use (equal to system memory, minus whatever's taken out for the file cache).
|
||||
oom_buffer_bytes: u64,
|
||||
|
||||
/// The number of samples used in constructing aggregated memory statistics
|
||||
memory_history_len: usize,
|
||||
/// The number of most recent samples that will be periodically logged.
|
||||
///
|
||||
/// Each sample is logged exactly once. Increasing this value means that recent samples will be
|
||||
/// logged less frequently, and vice versa.
|
||||
///
|
||||
/// For simplicity, this value must be greater than or equal to `memory_history_len`.
|
||||
memory_history_log_interval: usize,
|
||||
// The amount of memory, in bytes, below a proposed new value for
|
||||
// memory.high that the cgroup's memory usage must be for us to downscale
|
||||
//
|
||||
// In other words, we can downscale only when:
|
||||
//
|
||||
// memory.current + memory_high_buffer_bytes < (proposed) memory.high
|
||||
//
|
||||
// TODO: there's some minor issues with this approach -- in particular, that we might have
|
||||
// memory in use by the kernel's page cache that we're actually ok with getting rid of.
|
||||
pub(crate) memory_high_buffer_bytes: u64,
|
||||
|
||||
// The maximum duration, in milliseconds, that we're allowed to pause
|
||||
// the cgroup for while waiting for the autoscaler-agent to upscale us
|
||||
max_upscale_wait: Duration,
|
||||
|
||||
// The required minimum time, in milliseconds, that we must wait before re-freezing
|
||||
// the cgroup while waiting for the autoscaler-agent to upscale us.
|
||||
do_not_freeze_more_often_than: Duration,
|
||||
|
||||
// The amount of memory, in bytes, that we should periodically increase memory.high
|
||||
// by while waiting for the autoscaler-agent to upscale us.
|
||||
//
|
||||
// This exists to avoid the excessive throttling that happens when a cgroup is above its
|
||||
// memory.high for too long. See more here:
|
||||
// https://github.com/neondatabase/autoscaling/issues/44#issuecomment-1522487217
|
||||
memory_high_increase_by_bytes: u64,
|
||||
|
||||
// The period, in milliseconds, at which we should repeatedly increase the value
|
||||
// of the cgroup's memory.high while we're waiting on upscaling and memory.high
|
||||
// is still being hit.
|
||||
//
|
||||
// Technically speaking, this actually serves as a rate limit to moderate responding to
|
||||
// memory.high events, but these are roughly equivalent if the process is still allocating
|
||||
// memory.
|
||||
memory_high_increase_every: Duration,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Calculate the new value for the cgroups memory.high based on system memory
|
||||
pub fn calculate_memory_high_value(&self, total_system_mem: u64) -> u64 {
|
||||
total_system_mem.saturating_sub(self.oom_buffer_bytes)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
memory_poll_interval: Duration::from_millis(100),
|
||||
memory_history_len: 5, // use 500ms of history for decision-making
|
||||
memory_history_log_interval: 20, // but only log every ~2s (otherwise it's spammy)
|
||||
oom_buffer_bytes: 100 * MiB,
|
||||
memory_high_buffer_bytes: 100 * MiB,
|
||||
// while waiting for upscale, don't freeze for more than 20ms every 1s
|
||||
max_upscale_wait: Duration::from_millis(20),
|
||||
do_not_freeze_more_often_than: Duration::from_millis(1000),
|
||||
// while waiting for upscale, increase memory.high by 10MiB every 25ms
|
||||
memory_high_increase_by_bytes: 10 * MiB,
|
||||
memory_high_increase_every: Duration::from_millis(25),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Used to represent data that is associated with a certain point in time, such
|
||||
/// as an upscale request or memory.high event.
|
||||
///
|
||||
/// Internally, creating a `Sequenced` uses a static atomic counter to obtain
|
||||
/// a unique sequence number. Sequence numbers are monotonically increasing,
|
||||
/// allowing us to answer questions like "did this upscale happen after this
|
||||
/// memory.high event?" by comparing the sequence numbers of the two events.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Sequenced<T> {
|
||||
seqnum: u64,
|
||||
data: T,
|
||||
}
|
||||
|
||||
impl<T> Sequenced<T> {
|
||||
pub fn new(data: T) -> Self {
|
||||
Self {
|
||||
seqnum: EVENT_SEQUENCE_NUMBER.fetch_add(1, Ordering::AcqRel),
|
||||
data,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -47,14 +170,74 @@ impl Default for Config {
|
||||
pub struct CgroupWatcher {
|
||||
pub config: Config,
|
||||
|
||||
/// The sequence number of the last upscale.
|
||||
///
|
||||
/// If we receive a memory.high event that has a _lower_ sequence number than
|
||||
/// `last_upscale_seqnum`, then we know it occured before the upscale, and we
|
||||
/// can safely ignore it.
|
||||
///
|
||||
/// Note: Like the `events` field, this doesn't _need_ interior mutability but we
|
||||
/// use it anyways so that methods take `&self`, not `&mut self`.
|
||||
last_upscale_seqnum: AtomicU64,
|
||||
|
||||
/// A channel on which we send messages to request upscale from the dispatcher.
|
||||
upscale_requester: mpsc::Sender<()>,
|
||||
|
||||
/// The actual cgroup we are watching and managing.
|
||||
cgroup: cgroups_rs::Cgroup,
|
||||
}
|
||||
|
||||
/// Read memory.events for the desired event type.
|
||||
///
|
||||
/// `path` specifies the path to the desired `memory.events` file.
|
||||
/// For more info, see the `memory.events` section of the [kernel docs]
|
||||
/// <https://docs.kernel.org/admin-guide/cgroup-v2.html#memory-interface-files>
|
||||
fn get_event_count(path: &str, event: MemoryEvent) -> anyhow::Result<u64> {
|
||||
let contents = fs::read_to_string(path)
|
||||
.with_context(|| format!("failed to read memory.events from {path}"))?;
|
||||
|
||||
// Then contents of the file look like:
|
||||
// low 42
|
||||
// high 101
|
||||
// ...
|
||||
contents
|
||||
.lines()
|
||||
.filter_map(|s| s.split_once(' '))
|
||||
.find(|(e, _)| *e == event.as_str())
|
||||
.ok_or_else(|| anyhow!("failed to find entry for memory.{event} events in {path}"))
|
||||
.and_then(|(_, count)| {
|
||||
count
|
||||
.parse::<u64>()
|
||||
.with_context(|| format!("failed to parse memory.{event} as u64"))
|
||||
})
|
||||
}
|
||||
|
||||
/// Create an event stream that produces events whenever the file at the provided
|
||||
/// path is modified.
|
||||
fn create_file_watcher(path: &str) -> anyhow::Result<EventStream<[u8; 1024]>> {
|
||||
info!("creating file watcher for {path}");
|
||||
let inotify = Inotify::init().context("failed to initialize file watcher")?;
|
||||
inotify
|
||||
.watches()
|
||||
.add(path, WatchMask::MODIFY)
|
||||
.with_context(|| format!("failed to start watching {path}"))?;
|
||||
inotify
|
||||
// The inotify docs use [0u8; 1024] so we'll just copy them. We only need
|
||||
// to store one event at a time - if the event gets written over, that's
|
||||
// ok. We still see that there is an event. For more information, see:
|
||||
// https://man7.org/linux/man-pages/man7/inotify.7.html
|
||||
.into_event_stream([0u8; 1024])
|
||||
.context("failed to start inotify event stream")
|
||||
}
|
||||
|
||||
impl CgroupWatcher {
|
||||
/// Create a new `CgroupWatcher`.
|
||||
#[tracing::instrument(skip_all, fields(%name))]
|
||||
pub fn new(name: String) -> anyhow::Result<Self> {
|
||||
pub fn new(
|
||||
name: String,
|
||||
// A channel on which to send upscale requests
|
||||
upscale_requester: mpsc::Sender<()>,
|
||||
) -> anyhow::Result<(Self, impl Stream<Item = Sequenced<u64>>)> {
|
||||
// TODO: clarify exactly why we need v2
|
||||
// Make sure cgroups v2 (aka unified) are supported
|
||||
if !is_cgroup2_unified_mode() {
|
||||
@@ -62,203 +245,410 @@ impl CgroupWatcher {
|
||||
}
|
||||
let cgroup = cgroups_rs::Cgroup::load(hierarchies::auto(), &name);
|
||||
|
||||
Ok(Self {
|
||||
cgroup,
|
||||
config: Default::default(),
|
||||
})
|
||||
// Start monitoring the cgroup for memory events. In general, for
|
||||
// cgroups v2 (aka unified), metrics are reported in files like
|
||||
// > `/sys/fs/cgroup/{name}/{metric}`
|
||||
// We are looking for `memory.high` events, which are stored in the
|
||||
// file `memory.events`. For more info, see the `memory.events` section
|
||||
// of https://docs.kernel.org/admin-guide/cgroup-v2.html#memory-interface-files
|
||||
let path = format!("{}/{}/memory.events", UNIFIED_MOUNTPOINT, &name);
|
||||
let memory_events = create_file_watcher(&path)
|
||||
.with_context(|| format!("failed to create event watcher for {path}"))?
|
||||
// This would be nice with with .inspect_err followed by .ok
|
||||
.filter_map(move |_| match get_event_count(&path, MemoryEvent::High) {
|
||||
Ok(high) => Some(high),
|
||||
Err(error) => {
|
||||
// TODO: Might want to just panic here
|
||||
warn!(?error, "failed to read high events count from {}", &path);
|
||||
None
|
||||
}
|
||||
})
|
||||
// Only report the event if the memory.high count increased
|
||||
.filter_map(|high| {
|
||||
if MEMORY_EVENT_COUNT.fetch_max(high, Ordering::AcqRel) < high {
|
||||
Some(high)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.map(Sequenced::new);
|
||||
|
||||
let initial_count = get_event_count(
|
||||
&format!("{}/{}/memory.events", UNIFIED_MOUNTPOINT, &name),
|
||||
MemoryEvent::High,
|
||||
)?;
|
||||
|
||||
info!(initial_count, "initial memory.high event count");
|
||||
|
||||
// Hard update `MEMORY_EVENT_COUNT` since there could have been processes
|
||||
// running in the cgroup before that caused it to be non-zero.
|
||||
MEMORY_EVENT_COUNT.fetch_max(initial_count, Ordering::AcqRel);
|
||||
|
||||
Ok((
|
||||
Self {
|
||||
cgroup,
|
||||
upscale_requester,
|
||||
last_upscale_seqnum: AtomicU64::new(0),
|
||||
config: Default::default(),
|
||||
},
|
||||
memory_events,
|
||||
))
|
||||
}
|
||||
|
||||
/// The entrypoint for the `CgroupWatcher`.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn watch(
|
||||
pub async fn watch<E>(
|
||||
&self,
|
||||
updates: watch::Sender<(Instant, MemoryHistory)>,
|
||||
) -> anyhow::Result<()> {
|
||||
// this requirement makes the code a bit easier to work with; see the config for more.
|
||||
assert!(self.config.memory_history_len <= self.config.memory_history_log_interval);
|
||||
// These are ~dependency injected~ (fancy, I know) because this function
|
||||
// should never return.
|
||||
// -> therefore: when we tokio::spawn it, we don't await the JoinHandle.
|
||||
// -> therefore: if we want to stick it in an Arc so many threads can access
|
||||
// it, methods can never take mutable access.
|
||||
// - note: we use the Arc strategy so that a) we can call this function
|
||||
// right here and b) the runner can call the set/get_memory methods
|
||||
// -> since calling recv() on a tokio::sync::mpsc::Receiver takes &mut self,
|
||||
// we just pass them in here instead of holding them in fields, as that
|
||||
// would require this method to take &mut self.
|
||||
mut upscales: mpsc::Receiver<Sequenced<Resources>>,
|
||||
events: E,
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
E: Stream<Item = Sequenced<u64>>,
|
||||
{
|
||||
let mut wait_to_freeze = pin!(tokio::time::sleep(Duration::ZERO));
|
||||
let mut last_memory_high_increase_at: Option<Instant> = None;
|
||||
let mut events = pin!(events);
|
||||
|
||||
let mut ticker = tokio::time::interval(self.config.memory_poll_interval);
|
||||
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
// ticker.reset_immediately(); // FIXME: enable this once updating to tokio >= 1.30.0
|
||||
// Are we waiting to be upscaled? Could be true if we request upscale due
|
||||
// to a memory.high event and it does not arrive in time.
|
||||
let mut waiting_on_upscale = false;
|
||||
|
||||
let mem_controller = self.memory()?;
|
||||
loop {
|
||||
tokio::select! {
|
||||
upscale = upscales.recv() => {
|
||||
let Sequenced { seqnum, data } = upscale
|
||||
.context("failed to listen on upscale notification channel")?;
|
||||
waiting_on_upscale = false;
|
||||
last_memory_high_increase_at = None;
|
||||
self.last_upscale_seqnum.store(seqnum, Ordering::Release);
|
||||
info!(cpu = data.cpu, mem_bytes = data.mem, "received upscale");
|
||||
}
|
||||
event = events.next() => {
|
||||
let Some(Sequenced { seqnum, .. }) = event else {
|
||||
bail!("failed to listen for memory.high events")
|
||||
};
|
||||
// The memory.high came before our last upscale, so we consider
|
||||
// it resolved
|
||||
if self.last_upscale_seqnum.fetch_max(seqnum, Ordering::AcqRel) > seqnum {
|
||||
info!(
|
||||
"received memory.high event, but it came before our last upscale -> ignoring it"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// buffer for samples that will be logged. once full, it remains so.
|
||||
let history_log_len = self.config.memory_history_log_interval;
|
||||
let mut history_log_buf = vec![MemoryStatus::zeroed(); history_log_len];
|
||||
// The memory.high came after our latest upscale. We don't
|
||||
// want to do anything yet, so peek the next event in hopes
|
||||
// that it's an upscale.
|
||||
if let Some(upscale_num) = self
|
||||
.upscaled(&mut upscales)
|
||||
.context("failed to check if we were upscaled")?
|
||||
{
|
||||
if upscale_num > seqnum {
|
||||
info!(
|
||||
"received memory.high event, but it came before our last upscale -> ignoring it"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
for t in 0_u64.. {
|
||||
ticker.tick().await;
|
||||
// If it's been long enough since we last froze, freeze the
|
||||
// cgroup and request upscale
|
||||
if wait_to_freeze.is_elapsed() {
|
||||
info!("received memory.high event -> requesting upscale");
|
||||
waiting_on_upscale = self
|
||||
.handle_memory_high_event(&mut upscales)
|
||||
.await
|
||||
.context("failed to handle upscale")?;
|
||||
wait_to_freeze
|
||||
.as_mut()
|
||||
.reset(Instant::now() + self.config.do_not_freeze_more_often_than);
|
||||
continue;
|
||||
}
|
||||
|
||||
let now = Instant::now();
|
||||
let mem = Self::memory_usage(mem_controller);
|
||||
// Ok, we can't freeze, just request upscale
|
||||
if !waiting_on_upscale {
|
||||
info!("received memory.high event, but too soon to refreeze -> requesting upscale");
|
||||
|
||||
let i = t as usize % history_log_len;
|
||||
history_log_buf[i] = mem;
|
||||
// Make check to make sure we haven't been upscaled in the
|
||||
// meantine (can happen if the agent independently decides
|
||||
// to upscale us again)
|
||||
if self
|
||||
.upscaled(&mut upscales)
|
||||
.context("failed to check if we were upscaled")?
|
||||
.is_some()
|
||||
{
|
||||
info!("no need to request upscaling because we got upscaled");
|
||||
continue;
|
||||
}
|
||||
self.upscale_requester
|
||||
.send(())
|
||||
.await
|
||||
.context("failed to request upscale")?;
|
||||
waiting_on_upscale = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// We're taking *at most* memory_history_len values; we may be bounded by the total
|
||||
// number of samples that have come in so far.
|
||||
let samples_count = (t + 1).min(self.config.memory_history_len as u64) as usize;
|
||||
// NB: in `ring_buf_recent_values_iter`, `i` is *inclusive*, which matches the fact
|
||||
// that we just inserted a value there, so the end of the iterator will *include* the
|
||||
// value at i, rather than stopping just short of it.
|
||||
let samples = ring_buf_recent_values_iter(&history_log_buf, i, samples_count);
|
||||
// Shoot, we can't freeze or and we're still waiting on upscale,
|
||||
// increase memory.high to reduce throttling
|
||||
let can_increase_memory_high = match last_memory_high_increase_at {
|
||||
None => true,
|
||||
Some(t) => t.elapsed() > self.config.memory_high_increase_every,
|
||||
};
|
||||
if can_increase_memory_high {
|
||||
info!(
|
||||
"received memory.high event, \
|
||||
but too soon to refreeze and already requested upscale \
|
||||
-> increasing memory.high"
|
||||
);
|
||||
|
||||
let summary = MemoryHistory {
|
||||
avg_non_reclaimable: samples.map(|h| h.non_reclaimable).sum::<u64>()
|
||||
/ samples_count as u64,
|
||||
samples_count,
|
||||
samples_span: self.config.memory_poll_interval * (samples_count - 1) as u32,
|
||||
// Make check to make sure we haven't been upscaled in the
|
||||
// meantine (can happen if the agent independently decides
|
||||
// to upscale us again)
|
||||
if self
|
||||
.upscaled(&mut upscales)
|
||||
.context("failed to check if we were upscaled")?
|
||||
.is_some()
|
||||
{
|
||||
info!("no need to increase memory.high because got upscaled");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Request upscale anyways (the agent will handle deduplicating
|
||||
// requests)
|
||||
self.upscale_requester
|
||||
.send(())
|
||||
.await
|
||||
.context("failed to request upscale")?;
|
||||
|
||||
let memory_high =
|
||||
self.get_memory_high_bytes().context("failed to get memory.high")?;
|
||||
let new_high = memory_high + self.config.memory_high_increase_by_bytes;
|
||||
info!(
|
||||
current_high_bytes = memory_high,
|
||||
new_high_bytes = new_high,
|
||||
"updating memory.high"
|
||||
);
|
||||
self.set_memory_high_bytes(new_high)
|
||||
.context("failed to set memory.high")?;
|
||||
last_memory_high_increase_at = Some(Instant::now());
|
||||
continue;
|
||||
}
|
||||
|
||||
info!("received memory.high event, but can't do anything");
|
||||
}
|
||||
};
|
||||
|
||||
// Log the current history if it's time to do so. Because `history_log_buf` has length
|
||||
// equal to the logging interval, we can just log the entire buffer every time we set
|
||||
// the last entry, which also means that for this log line, we can ignore that it's a
|
||||
// ring buffer (because all the entries are in order of increasing time).
|
||||
if i == history_log_len - 1 {
|
||||
info!(
|
||||
history = ?MemoryStatus::debug_slice(&history_log_buf),
|
||||
summary = ?summary,
|
||||
"Recent cgroup memory statistics history"
|
||||
);
|
||||
}
|
||||
|
||||
updates
|
||||
.send((now, summary))
|
||||
.context("failed to send MemoryHistory")?;
|
||||
}
|
||||
}
|
||||
|
||||
unreachable!()
|
||||
/// Handle a `memory.high`, returning whether we are still waiting on upscale
|
||||
/// by the time the function returns.
|
||||
///
|
||||
/// The general plan for handling a `memory.high` event is as follows:
|
||||
/// 1. Freeze the cgroup
|
||||
/// 2. Start a timer for `self.config.max_upscale_wait`
|
||||
/// 3. Request upscale
|
||||
/// 4. After the timer elapses or we receive upscale, thaw the cgroup.
|
||||
/// 5. Return whether or not we are still waiting for upscale. If we are,
|
||||
/// we'll increase the cgroups memory.high to avoid getting oom killed
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn handle_memory_high_event(
|
||||
&self,
|
||||
upscales: &mut mpsc::Receiver<Sequenced<Resources>>,
|
||||
) -> anyhow::Result<bool> {
|
||||
// Immediately freeze the cgroup before doing anything else.
|
||||
info!("received memory.high event -> freezing cgroup");
|
||||
self.freeze().context("failed to freeze cgroup")?;
|
||||
|
||||
// We'll use this for logging durations
|
||||
let start_time = Instant::now();
|
||||
|
||||
// Await the upscale until we have to unfreeze
|
||||
let timed =
|
||||
tokio::time::timeout(self.config.max_upscale_wait, self.await_upscale(upscales));
|
||||
|
||||
// Request the upscale
|
||||
info!(
|
||||
wait = ?self.config.max_upscale_wait,
|
||||
"sending request for immediate upscaling",
|
||||
);
|
||||
self.upscale_requester
|
||||
.send(())
|
||||
.await
|
||||
.context("failed to request upscale")?;
|
||||
|
||||
let waiting_on_upscale = match timed.await {
|
||||
Ok(Ok(())) => {
|
||||
info!(elapsed = ?start_time.elapsed(), "received upscale in time");
|
||||
false
|
||||
}
|
||||
// **important**: unfreeze the cgroup before ?-reporting the error
|
||||
Ok(Err(e)) => {
|
||||
info!("error waiting for upscale -> thawing cgroup");
|
||||
self.thaw()
|
||||
.context("failed to thaw cgroup after errored waiting for upscale")?;
|
||||
Err(e.context("failed to await upscale"))?
|
||||
}
|
||||
Err(_) => {
|
||||
info!(elapsed = ?self.config.max_upscale_wait, "timed out waiting for upscale");
|
||||
true
|
||||
}
|
||||
};
|
||||
|
||||
info!("thawing cgroup");
|
||||
self.thaw().context("failed to thaw cgroup")?;
|
||||
|
||||
Ok(waiting_on_upscale)
|
||||
}
|
||||
|
||||
/// Checks whether we were just upscaled, returning the upscale's sequence
|
||||
/// number if so.
|
||||
#[tracing::instrument(skip_all)]
|
||||
fn upscaled(
|
||||
&self,
|
||||
upscales: &mut mpsc::Receiver<Sequenced<Resources>>,
|
||||
) -> anyhow::Result<Option<u64>> {
|
||||
let Sequenced { seqnum, data } = match upscales.try_recv() {
|
||||
Ok(upscale) => upscale,
|
||||
Err(TryRecvError::Empty) => return Ok(None),
|
||||
Err(TryRecvError::Disconnected) => {
|
||||
bail!("upscale notification channel was disconnected")
|
||||
}
|
||||
};
|
||||
|
||||
// Make sure to update the last upscale sequence number
|
||||
self.last_upscale_seqnum.store(seqnum, Ordering::Release);
|
||||
info!(cpu = data.cpu, mem_bytes = data.mem, "received upscale");
|
||||
Ok(Some(seqnum))
|
||||
}
|
||||
|
||||
/// Await an upscale event, discarding any `memory.high` events received in
|
||||
/// the process.
|
||||
///
|
||||
/// This is used in `handle_memory_high_event`, where we need to listen
|
||||
/// for upscales in particular so we know if we can thaw the cgroup early.
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn await_upscale(
|
||||
&self,
|
||||
upscales: &mut mpsc::Receiver<Sequenced<Resources>>,
|
||||
) -> anyhow::Result<()> {
|
||||
let Sequenced { seqnum, .. } = upscales
|
||||
.recv()
|
||||
.await
|
||||
.context("error listening for upscales")?;
|
||||
|
||||
self.last_upscale_seqnum.store(seqnum, Ordering::Release);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the cgroup's name.
|
||||
pub fn path(&self) -> &str {
|
||||
self.cgroup.path()
|
||||
}
|
||||
}
|
||||
|
||||
// Methods for manipulating the actual cgroup
|
||||
impl CgroupWatcher {
|
||||
/// Get a handle on the freezer subsystem.
|
||||
fn freezer(&self) -> anyhow::Result<&FreezerController> {
|
||||
if let Some(Freezer(freezer)) = self
|
||||
.cgroup
|
||||
.subsystems()
|
||||
.iter()
|
||||
.find(|sub| matches!(sub, Freezer(_)))
|
||||
{
|
||||
Ok(freezer)
|
||||
} else {
|
||||
anyhow::bail!("could not find freezer subsystem")
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempt to freeze the cgroup.
|
||||
pub fn freeze(&self) -> anyhow::Result<()> {
|
||||
self.freezer()
|
||||
.context("failed to get freezer subsystem")?
|
||||
.freeze()
|
||||
.context("failed to freeze")
|
||||
}
|
||||
|
||||
/// Attempt to thaw the cgroup.
|
||||
pub fn thaw(&self) -> anyhow::Result<()> {
|
||||
self.freezer()
|
||||
.context("failed to get freezer subsystem")?
|
||||
.thaw()
|
||||
.context("failed to thaw")
|
||||
}
|
||||
|
||||
/// Get a handle on the memory subsystem.
|
||||
///
|
||||
/// Note: this method does not require `self.memory_update_lock` because
|
||||
/// getting a handle to the subsystem does not access any of the files we
|
||||
/// care about, such as memory.high and memory.events
|
||||
fn memory(&self) -> anyhow::Result<&MemController> {
|
||||
self.cgroup
|
||||
if let Some(Mem(memory)) = self
|
||||
.cgroup
|
||||
.subsystems()
|
||||
.iter()
|
||||
.find_map(|sub| match sub {
|
||||
Subsystem::Mem(c) => Some(c),
|
||||
_ => None,
|
||||
.find(|sub| matches!(sub, Mem(_)))
|
||||
{
|
||||
Ok(memory)
|
||||
} else {
|
||||
anyhow::bail!("could not find memory subsystem")
|
||||
}
|
||||
}
|
||||
|
||||
/// Get cgroup current memory usage.
|
||||
pub fn current_memory_usage(&self) -> anyhow::Result<u64> {
|
||||
Ok(self
|
||||
.memory()
|
||||
.context("failed to get memory subsystem")?
|
||||
.memory_stat()
|
||||
.usage_in_bytes)
|
||||
}
|
||||
|
||||
/// Set cgroup memory.high threshold.
|
||||
pub fn set_memory_high_bytes(&self, bytes: u64) -> anyhow::Result<()> {
|
||||
self.set_memory_high_internal(MaxValue::Value(u64::min(bytes, i64::MAX as u64) as i64))
|
||||
}
|
||||
|
||||
/// Set the cgroup's memory.high to 'max', disabling it.
|
||||
pub fn unset_memory_high(&self) -> anyhow::Result<()> {
|
||||
self.set_memory_high_internal(MaxValue::Max)
|
||||
}
|
||||
|
||||
fn set_memory_high_internal(&self, value: MaxValue) -> anyhow::Result<()> {
|
||||
self.memory()
|
||||
.context("failed to get memory subsystem")?
|
||||
.set_mem(cgroups_rs::memory::SetMemory {
|
||||
low: None,
|
||||
high: Some(value),
|
||||
min: None,
|
||||
max: None,
|
||||
})
|
||||
.ok_or_else(|| anyhow!("could not find memory subsystem"))
|
||||
.map_err(anyhow::Error::from)
|
||||
}
|
||||
|
||||
/// Given a handle on the memory subsystem, returns the current memory information
|
||||
fn memory_usage(mem_controller: &MemController) -> MemoryStatus {
|
||||
let stat = mem_controller.memory_stat().stat;
|
||||
MemoryStatus {
|
||||
non_reclaimable: stat.active_anon + stat.inactive_anon,
|
||||
/// Get memory.high threshold.
|
||||
pub fn get_memory_high_bytes(&self) -> anyhow::Result<u64> {
|
||||
let high = self
|
||||
.memory()
|
||||
.context("failed to get memory subsystem while getting memory statistics")?
|
||||
.get_mem()
|
||||
.map(|mem| mem.high)
|
||||
.context("failed to get memory statistics from subsystem")?;
|
||||
match high {
|
||||
Some(MaxValue::Max) => Ok(i64::MAX as u64),
|
||||
Some(MaxValue::Value(high)) => Ok(high as u64),
|
||||
None => anyhow::bail!("failed to read memory.high from memory subsystem"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function for `CgroupWatcher::watch`
|
||||
fn ring_buf_recent_values_iter<T>(
|
||||
buf: &[T],
|
||||
last_value_idx: usize,
|
||||
count: usize,
|
||||
) -> impl '_ + Iterator<Item = &T> {
|
||||
// Assertion carried over from `CgroupWatcher::watch`, to make the logic in this function
|
||||
// easier (we only have to add `buf.len()` once, rather than a dynamic number of times).
|
||||
assert!(count <= buf.len());
|
||||
|
||||
buf.iter()
|
||||
// 'cycle' because the values could wrap around
|
||||
.cycle()
|
||||
// with 'cycle', this skip is more like 'offset', and functionally this is
|
||||
// offsettting by 'last_value_idx - count (mod buf.len())', but we have to be
|
||||
// careful to avoid underflow, so we pre-add buf.len().
|
||||
// The '+ 1' is because `last_value_idx` is inclusive, rather than exclusive.
|
||||
.skip((buf.len() + last_value_idx + 1 - count) % buf.len())
|
||||
.take(count)
|
||||
}
|
||||
|
||||
/// Summary of recent memory usage
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct MemoryHistory {
|
||||
/// Rolling average of non-reclaimable memory usage samples over the last `history_period`
|
||||
pub avg_non_reclaimable: u64,
|
||||
|
||||
/// The number of samples used to construct this summary
|
||||
pub samples_count: usize,
|
||||
/// Total timespan between the first and last sample used for this summary
|
||||
pub samples_span: Duration,
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct MemoryStatus {
|
||||
non_reclaimable: u64,
|
||||
}
|
||||
|
||||
impl MemoryStatus {
|
||||
fn zeroed() -> Self {
|
||||
MemoryStatus { non_reclaimable: 0 }
|
||||
}
|
||||
|
||||
fn debug_slice(slice: &[Self]) -> impl '_ + Debug {
|
||||
struct DS<'a>(&'a [MemoryStatus]);
|
||||
|
||||
impl<'a> Debug for DS<'a> {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
f.debug_struct("[MemoryStatus]")
|
||||
.field(
|
||||
"non_reclaimable[..]",
|
||||
&Fields(self.0, |stat: &MemoryStatus| {
|
||||
BytesToGB(stat.non_reclaimable)
|
||||
}),
|
||||
)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
struct Fields<'a, F>(&'a [MemoryStatus], F);
|
||||
|
||||
impl<'a, F: Fn(&MemoryStatus) -> T, T: Debug> Debug for Fields<'a, F> {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
f.debug_list().entries(self.0.iter().map(&self.1)).finish()
|
||||
}
|
||||
}
|
||||
|
||||
struct BytesToGB(u64);
|
||||
|
||||
impl Debug for BytesToGB {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
f.write_fmt(format_args!(
|
||||
"{:.3}Gi",
|
||||
self.0 as f64 / (1_u64 << 30) as f64
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
DS(slice)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[test]
|
||||
fn ring_buf_iter() {
|
||||
let buf = vec![0_i32, 1, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||
|
||||
let values = |offset, count| {
|
||||
super::ring_buf_recent_values_iter(&buf, offset, count)
|
||||
.copied()
|
||||
.collect::<Vec<i32>>()
|
||||
};
|
||||
|
||||
// Boundary conditions: start, end, and entire thing:
|
||||
assert_eq!(values(0, 1), [0]);
|
||||
assert_eq!(values(3, 4), [0, 1, 2, 3]);
|
||||
assert_eq!(values(9, 4), [6, 7, 8, 9]);
|
||||
assert_eq!(values(9, 10), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
|
||||
|
||||
// "normal" operation: no wraparound
|
||||
assert_eq!(values(7, 4), [4, 5, 6, 7]);
|
||||
|
||||
// wraparound:
|
||||
assert_eq!(values(0, 4), [7, 8, 9, 0]);
|
||||
assert_eq!(values(1, 4), [8, 9, 0, 1]);
|
||||
assert_eq!(values(2, 4), [9, 0, 1, 2]);
|
||||
assert_eq!(values(2, 10), [3, 4, 5, 6, 7, 8, 9, 0, 1, 2]);
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user