mirror of
https://github.com/neondatabase/neon.git
synced 2026-01-17 18:32:56 +00:00
Compare commits
186 Commits
lfc_fixes2
...
remove_ini
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bd235a5fe3 | ||
|
|
f95f001b8b | ||
|
|
e0821e1eab | ||
|
|
4469b1a62c | ||
|
|
842223b47f | ||
|
|
893616051d | ||
|
|
7cdde285a5 | ||
|
|
9c30883c4b | ||
|
|
0495798591 | ||
|
|
87389bc933 | ||
|
|
ea118a238a | ||
|
|
e9b227a11e | ||
|
|
40441f8ada | ||
|
|
a8a39cd464 | ||
|
|
b989ad1922 | ||
|
|
acef742a6e | ||
|
|
11d9d801b5 | ||
|
|
fc47af156f | ||
|
|
e310533ed3 | ||
|
|
1d68f52b57 | ||
|
|
4cd47b7d4b | ||
|
|
0141c95788 | ||
|
|
0ac4cf67a6 | ||
|
|
4be6bc7251 | ||
|
|
a394f49e0d | ||
|
|
c00651ff9b | ||
|
|
bea8efac24 | ||
|
|
ad5b02e175 | ||
|
|
b09a851705 | ||
|
|
85cd97af61 | ||
|
|
e6470ee92e | ||
|
|
dc72567288 | ||
|
|
6defa2b5d5 | ||
|
|
b3d3a2587d | ||
|
|
b85fc39bdb | ||
|
|
09b5954526 | ||
|
|
306c4f9967 | ||
|
|
5ceccdc7de | ||
|
|
cdcaa329bf | ||
|
|
27bdbf5e36 | ||
|
|
4c7fa12a2a | ||
|
|
367971a0e9 | ||
|
|
51570114ea | ||
|
|
098d3111a5 | ||
|
|
3737fe3a4b | ||
|
|
5650138532 | ||
|
|
2dca4c03fc | ||
|
|
0b790b6d00 | ||
|
|
e82d1ad6b8 | ||
|
|
4f0a8e92ad | ||
|
|
5952f350cb | ||
|
|
726c8e6730 | ||
|
|
f7067a38b7 | ||
|
|
896347f307 | ||
|
|
e5c81fef86 | ||
|
|
7ebe9ca1ac | ||
|
|
1588601503 | ||
|
|
9c35e1e6e5 | ||
|
|
d8c21ec70d | ||
|
|
ad99fa5f03 | ||
|
|
e675f4cec8 | ||
|
|
4db8efb2cf | ||
|
|
07c2b29895 | ||
|
|
9cdffd164a | ||
|
|
87db4b441c | ||
|
|
964c5c56b7 | ||
|
|
bd59349af3 | ||
|
|
2bd79906d9 | ||
|
|
493b47e1da | ||
|
|
c13e932c3b | ||
|
|
a5292f7e67 | ||
|
|
262348e41b | ||
|
|
68f15cf967 | ||
|
|
39f8fd6945 | ||
|
|
83567f9e4e | ||
|
|
71611f4ab3 | ||
|
|
7c16b5215e | ||
|
|
39b148b74e | ||
|
|
116c342cad | ||
|
|
ba4fe9e10f | ||
|
|
de90bf4663 | ||
|
|
8360307ea0 | ||
|
|
6129077d31 | ||
|
|
e0ebdfc7ce | ||
|
|
c508d3b5fa | ||
|
|
acda65d7d4 | ||
|
|
378daa358b | ||
|
|
85f4514e7d | ||
|
|
f70019797c | ||
|
|
325258413a | ||
|
|
4ddbc0e46d | ||
|
|
a673e4e7a9 | ||
|
|
c155cc0c3f | ||
|
|
32126d705b | ||
|
|
5683ae9eab | ||
|
|
4778b6a12e | ||
|
|
8b8be7bed4 | ||
|
|
a461c459d8 | ||
|
|
4ae2d1390d | ||
|
|
c5949e1fd6 | ||
|
|
127837abb0 | ||
|
|
b2c96047d0 | ||
|
|
44202eeb3b | ||
|
|
4bef977c56 | ||
|
|
a0b862a8bd | ||
|
|
767ef29390 | ||
|
|
a8a800af51 | ||
|
|
1e250cd90a | ||
|
|
eaaa18f6ed | ||
|
|
188f67e1df | ||
|
|
7e805200bb | ||
|
|
c6ca1d76d2 | ||
|
|
94b4e76e13 | ||
|
|
b514da90cb | ||
|
|
7d17f1719f | ||
|
|
41ee75bc71 | ||
|
|
11e523f503 | ||
|
|
b1a1126152 | ||
|
|
a8899e1e0f | ||
|
|
2fbd5ab075 | ||
|
|
702382e99a | ||
|
|
1b53b3e200 | ||
|
|
b332268cec | ||
|
|
76c702219c | ||
|
|
ba856140e7 | ||
|
|
2cf6a47cca | ||
|
|
5a8bcdccb0 | ||
|
|
2c8741a5ed | ||
|
|
893b7bac9a | ||
|
|
66f8f5f1c8 | ||
|
|
3a19da1066 | ||
|
|
572eda44ee | ||
|
|
b1d6af5ebe | ||
|
|
f842b22b90 | ||
|
|
d444d4dcea | ||
|
|
c8637f3736 | ||
|
|
ecf759be6d | ||
|
|
9a9d9eba42 | ||
|
|
1f4805baf8 | ||
|
|
5c88213eaf | ||
|
|
607d19f0e0 | ||
|
|
1fa0478980 | ||
|
|
9da67c4f19 | ||
|
|
16c87b5bda | ||
|
|
9fe5cc6a82 | ||
|
|
543b8153c6 | ||
|
|
3a8959a4c4 | ||
|
|
4a50483861 | ||
|
|
f775928dfc | ||
|
|
ea648cfbc6 | ||
|
|
093f8c5f45 | ||
|
|
00c71bb93a | ||
|
|
9256788273 | ||
|
|
9e1449353d | ||
|
|
b06dffe3dc | ||
|
|
b08a0ee186 | ||
|
|
3666df6342 | ||
|
|
0ca342260c | ||
|
|
ded7f48565 | ||
|
|
e09d5ada6a | ||
|
|
8c522ea034 | ||
|
|
44b1c4c456 | ||
|
|
99c15907c1 | ||
|
|
c3626e3432 | ||
|
|
dd6990567f | ||
|
|
21deb81acb | ||
|
|
dbb21d6592 | ||
|
|
ddceb9e6cd | ||
|
|
0fc3708de2 | ||
|
|
e0c8ad48d4 | ||
|
|
39e144696f | ||
|
|
653044f754 | ||
|
|
80dcdfa8bf | ||
|
|
685add2009 | ||
|
|
d4dc86f8e3 | ||
|
|
5158de70f3 | ||
|
|
aec9188d36 | ||
|
|
acefee9a32 | ||
|
|
bf065aabdf | ||
|
|
fe74fac276 | ||
|
|
b91ac670e1 | ||
|
|
b3195afd20 | ||
|
|
7eaa7a496b | ||
|
|
4772cd6c93 | ||
|
|
010b4d0d5c | ||
|
|
477cb3717b |
@@ -22,5 +22,11 @@ platforms = [
|
||||
# "x86_64-pc-windows-msvc",
|
||||
]
|
||||
|
||||
[final-excludes]
|
||||
# vm_monitor benefits from the same Cargo.lock as the rest of our artifacts, but
|
||||
# it is built primarly in separate repo neondatabase/autoscaling and thus is excluded
|
||||
# from depending on workspace-hack because most of the dependencies are not used.
|
||||
workspace-members = ["vm_monitor"]
|
||||
|
||||
# Write out exact versions rather than a semver range. (Defaults to false.)
|
||||
# exact-versions = true
|
||||
|
||||
5
.github/ISSUE_TEMPLATE/epic-template.md
vendored
5
.github/ISSUE_TEMPLATE/epic-template.md
vendored
@@ -17,8 +17,9 @@ assignees: ''
|
||||
## Implementation ideas
|
||||
|
||||
|
||||
## Tasks
|
||||
- [ ]
|
||||
```[tasklist]
|
||||
### Tasks
|
||||
```
|
||||
|
||||
|
||||
## Other related tasks and Epics
|
||||
|
||||
2
.github/actionlint.yml
vendored
2
.github/actionlint.yml
vendored
@@ -5,4 +5,6 @@ self-hosted-runner:
|
||||
- small
|
||||
- us-east-2
|
||||
config-variables:
|
||||
- REMOTE_STORAGE_AZURE_CONTAINER
|
||||
- REMOTE_STORAGE_AZURE_REGION
|
||||
- SLACK_UPCOMING_RELEASE_CHANNEL_ID
|
||||
|
||||
@@ -203,6 +203,10 @@ runs:
|
||||
COMMIT_SHA: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
BASE_S3_URL: ${{ steps.generate-report.outputs.base-s3-url }}
|
||||
run: |
|
||||
if [ ! -d "${WORKDIR}/report/data/test-cases" ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
export DATABASE_URL=${REGRESS_TEST_RESULT_CONNSTR_NEW}
|
||||
|
||||
./scripts/pysync
|
||||
|
||||
20
.github/workflows/build_and_test.yml
vendored
20
.github/workflows/build_and_test.yml
vendored
@@ -320,6 +320,9 @@ jobs:
|
||||
- name: Build neon extensions
|
||||
run: mold -run make neon-pg-ext -j$(nproc)
|
||||
|
||||
- name: Build walproposer-lib
|
||||
run: mold -run make walproposer-lib -j$(nproc)
|
||||
|
||||
- name: Run cargo build
|
||||
run: |
|
||||
${cov_prefix} mold -run cargo build $CARGO_FLAGS $CARGO_FEATURES --bins --tests
|
||||
@@ -335,6 +338,16 @@ jobs:
|
||||
# Avoid `$CARGO_FEATURES` since there's no `testing` feature in the e2e tests now
|
||||
${cov_prefix} cargo test $CARGO_FLAGS --package remote_storage --test test_real_s3
|
||||
|
||||
# Run separate tests for real Azure Blob Storage
|
||||
# XXX: replace region with `eu-central-1`-like region
|
||||
export ENABLE_REAL_AZURE_REMOTE_STORAGE=y
|
||||
export AZURE_STORAGE_ACCOUNT="${{ secrets.AZURE_STORAGE_ACCOUNT_DEV }}"
|
||||
export AZURE_STORAGE_ACCESS_KEY="${{ secrets.AZURE_STORAGE_ACCESS_KEY_DEV }}"
|
||||
export REMOTE_STORAGE_AZURE_CONTAINER="${{ vars.REMOTE_STORAGE_AZURE_CONTAINER }}"
|
||||
export REMOTE_STORAGE_AZURE_REGION="${{ vars.REMOTE_STORAGE_AZURE_REGION }}"
|
||||
# Avoid `$CARGO_FEATURES` since there's no `testing` feature in the e2e tests now
|
||||
${cov_prefix} cargo test $CARGO_FLAGS --package remote_storage --test test_real_azure
|
||||
|
||||
- name: Install rust binaries
|
||||
run: |
|
||||
# Install target binaries
|
||||
@@ -420,7 +433,7 @@ jobs:
|
||||
rerun_flaky: true
|
||||
pg_version: ${{ matrix.pg_version }}
|
||||
env:
|
||||
TEST_RESULT_CONNSTR: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR }}
|
||||
TEST_RESULT_CONNSTR: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}
|
||||
CHECK_ONDISK_DATA_COMPATIBILITY: nonempty
|
||||
|
||||
- name: Merge and upload coverage data
|
||||
@@ -455,7 +468,7 @@ jobs:
|
||||
env:
|
||||
VIP_VAP_ACCESS_TOKEN: "${{ secrets.VIP_VAP_ACCESS_TOKEN }}"
|
||||
PERF_TEST_RESULT_CONNSTR: "${{ secrets.PERF_TEST_RESULT_CONNSTR }}"
|
||||
TEST_RESULT_CONNSTR: "${{ secrets.REGRESS_TEST_RESULT_CONNSTR }}"
|
||||
TEST_RESULT_CONNSTR: "${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}"
|
||||
# XXX: no coverage data handling here, since benchmarks are run on release builds,
|
||||
# while coverage is currently collected for the debug ones
|
||||
|
||||
@@ -710,6 +723,7 @@ jobs:
|
||||
--cache-repo 369495373322.dkr.ecr.eu-central-1.amazonaws.com/cache
|
||||
--context .
|
||||
--build-arg GIT_VERSION=${{ github.event.pull_request.head.sha || github.sha }}
|
||||
--build-arg BUILD_TAG=${{ needs.tag.outputs.build-tag }}
|
||||
--build-arg REPOSITORY=369495373322.dkr.ecr.eu-central-1.amazonaws.com
|
||||
--destination 369495373322.dkr.ecr.eu-central-1.amazonaws.com/neon:${{needs.tag.outputs.build-tag}}
|
||||
--destination neondatabase/neon:${{needs.tag.outputs.build-tag}}
|
||||
@@ -834,7 +848,7 @@ jobs:
|
||||
run:
|
||||
shell: sh -eu {0}
|
||||
env:
|
||||
VM_BUILDER_VERSION: v0.17.12
|
||||
VM_BUILDER_VERSION: v0.18.5
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
18
.github/workflows/neon_extra_builds.yml
vendored
18
.github/workflows/neon_extra_builds.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: true
|
||||
fetch-depth: 1
|
||||
@@ -90,18 +90,21 @@ jobs:
|
||||
|
||||
- name: Build postgres v14
|
||||
if: steps.cache_pg_14.outputs.cache-hit != 'true'
|
||||
run: make postgres-v14 -j$(nproc)
|
||||
run: make postgres-v14 -j$(sysctl -n hw.ncpu)
|
||||
|
||||
- name: Build postgres v15
|
||||
if: steps.cache_pg_15.outputs.cache-hit != 'true'
|
||||
run: make postgres-v15 -j$(nproc)
|
||||
run: make postgres-v15 -j$(sysctl -n hw.ncpu)
|
||||
|
||||
- name: Build postgres v16
|
||||
if: steps.cache_pg_16.outputs.cache-hit != 'true'
|
||||
run: make postgres-v16 -j$(nproc)
|
||||
run: make postgres-v16 -j$(sysctl -n hw.ncpu)
|
||||
|
||||
- name: Build neon extensions
|
||||
run: make neon-pg-ext -j$(nproc)
|
||||
run: make neon-pg-ext -j$(sysctl -n hw.ncpu)
|
||||
|
||||
- name: Build walproposer-lib
|
||||
run: make walproposer-lib -j$(sysctl -n hw.ncpu)
|
||||
|
||||
- name: Run cargo build
|
||||
run: cargo build --all --release
|
||||
@@ -126,7 +129,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
submodules: true
|
||||
fetch-depth: 1
|
||||
@@ -135,6 +138,9 @@ jobs:
|
||||
- name: Get postgres headers
|
||||
run: make postgres-headers -j$(nproc)
|
||||
|
||||
- name: Build walproposer-lib
|
||||
run: make walproposer-lib -j$(nproc)
|
||||
|
||||
- name: Produce the build stats
|
||||
run: cargo build --all --release --timings
|
||||
|
||||
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -2,7 +2,7 @@ name: Create Release Branch
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 7 * * 2'
|
||||
- cron: '0 7 * * 5'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
|
||||
1061
Cargo.lock
generated
1061
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
24
Cargo.toml
24
Cargo.toml
@@ -26,6 +26,7 @@ members = [
|
||||
"libs/tracing-utils",
|
||||
"libs/postgres_ffi/wal_craft",
|
||||
"libs/vm_monitor",
|
||||
"libs/walproposer",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
@@ -35,13 +36,19 @@ license = "Apache-2.0"
|
||||
## All dependency versions, used in the project
|
||||
[workspace.dependencies]
|
||||
anyhow = { version = "1.0", features = ["backtrace"] }
|
||||
arc-swap = "1.6"
|
||||
async-compression = { version = "0.4.0", features = ["tokio", "gzip"] }
|
||||
azure_core = "0.16"
|
||||
azure_identity = "0.16"
|
||||
azure_storage = "0.16"
|
||||
azure_storage_blobs = "0.16"
|
||||
flate2 = "1.0.26"
|
||||
async-stream = "0.3"
|
||||
async-trait = "0.1"
|
||||
aws-config = { version = "0.56", default-features = false, features=["rustls"] }
|
||||
aws-sdk-s3 = "0.29"
|
||||
aws-smithy-http = "0.56"
|
||||
aws-smithy-async = { version = "0.56", default-features = false, features=["rt-tokio"] }
|
||||
aws-credential-types = "0.56"
|
||||
aws-types = "0.56"
|
||||
axum = { version = "0.6.20", features = ["ws"] }
|
||||
@@ -60,7 +67,7 @@ comfy-table = "6.1"
|
||||
const_format = "0.2"
|
||||
crc32c = "0.6"
|
||||
crossbeam-utils = "0.8.5"
|
||||
dashmap = "5.5.0"
|
||||
dashmap = { version = "5.5.0", features = ["raw-api"] }
|
||||
either = "1.8"
|
||||
enum-map = "2.4.2"
|
||||
enumset = "1.0.12"
|
||||
@@ -76,6 +83,7 @@ hex = "0.4"
|
||||
hex-literal = "0.4"
|
||||
hmac = "0.12.1"
|
||||
hostname = "0.3.1"
|
||||
http-types = "2"
|
||||
humantime = "2.1"
|
||||
humantime-serde = "1.1.1"
|
||||
hyper = "0.14"
|
||||
@@ -118,6 +126,7 @@ sentry = { version = "0.31", default-features = false, features = ["backtrace",
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
serde_with = "2.0"
|
||||
serde_assert = "0.5.0"
|
||||
sha2 = "0.10.2"
|
||||
signal-hook = "0.3"
|
||||
smallvec = "1.11"
|
||||
@@ -155,11 +164,11 @@ env_logger = "0.10"
|
||||
log = "0.4"
|
||||
|
||||
## Libraries from neondatabase/ git forks, ideally with changes to be upstreamed
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" }
|
||||
postgres-native-tls = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" }
|
||||
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" }
|
||||
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" }
|
||||
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="ce7260db5998fe27167da42503905a12e7ad9048" }
|
||||
postgres-native-tls = { git = "https://github.com/neondatabase/rust-postgres.git", rev="ce7260db5998fe27167da42503905a12e7ad9048" }
|
||||
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", rev="ce7260db5998fe27167da42503905a12e7ad9048" }
|
||||
postgres-types = { git = "https://github.com/neondatabase/rust-postgres.git", rev="ce7260db5998fe27167da42503905a12e7ad9048" }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="ce7260db5998fe27167da42503905a12e7ad9048" }
|
||||
|
||||
## Other git libraries
|
||||
heapless = { default-features=false, features=[], git = "https://github.com/japaric/heapless.git", rev = "644653bf3b831c6bb4963be2de24804acf5e5001" } # upstream release pending
|
||||
@@ -180,6 +189,7 @@ tenant_size_model = { version = "0.1", path = "./libs/tenant_size_model/" }
|
||||
tracing-utils = { version = "0.1", path = "./libs/tracing-utils/" }
|
||||
utils = { version = "0.1", path = "./libs/utils/" }
|
||||
vm_monitor = { version = "0.1", path = "./libs/vm_monitor/" }
|
||||
walproposer = { version = "0.1", path = "./libs/walproposer/" }
|
||||
|
||||
## Common library dependency
|
||||
workspace_hack = { version = "0.1", path = "./workspace_hack/" }
|
||||
@@ -195,7 +205,7 @@ tonic-build = "0.9"
|
||||
|
||||
# This is only needed for proxy's tests.
|
||||
# TODO: we should probably fork `tokio-postgres-rustls` instead.
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="9011f7110db12b5e15afaf98f8ac834501d50ddc" }
|
||||
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", rev="ce7260db5998fe27167da42503905a12e7ad9048" }
|
||||
|
||||
################# Binary contents sections
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ RUN set -e \
|
||||
FROM $REPOSITORY/$IMAGE:$TAG AS build
|
||||
WORKDIR /home/nonroot
|
||||
ARG GIT_VERSION=local
|
||||
ARG BUILD_TAG
|
||||
|
||||
# Enable https://github.com/paritytech/cachepot to cache Rust crates' compilation results in Docker builds.
|
||||
# Set up cachepot to use an AWS S3 bucket for cache results, to reuse it between `docker build` invocations.
|
||||
@@ -78,9 +79,9 @@ COPY --from=build --chown=neon:neon /home/nonroot/target/release/pg_sni_router
|
||||
COPY --from=build --chown=neon:neon /home/nonroot/target/release/pageserver /usr/local/bin
|
||||
COPY --from=build --chown=neon:neon /home/nonroot/target/release/pagectl /usr/local/bin
|
||||
COPY --from=build --chown=neon:neon /home/nonroot/target/release/safekeeper /usr/local/bin
|
||||
COPY --from=build --chown=neon:neon /home/nonroot/target/release/storage_broker /usr/local/bin
|
||||
COPY --from=build --chown=neon:neon /home/nonroot/target/release/storage_broker /usr/local/bin
|
||||
COPY --from=build --chown=neon:neon /home/nonroot/target/release/proxy /usr/local/bin
|
||||
COPY --from=build --chown=neon:neon /home/nonroot/target/release/neon_local /usr/local/bin
|
||||
COPY --from=build --chown=neon:neon /home/nonroot/target/release/neon_local /usr/local/bin
|
||||
|
||||
COPY --from=pg-build /home/nonroot/pg_install/v14 /usr/local/v14/
|
||||
COPY --from=pg-build /home/nonroot/pg_install/v15 /usr/local/v15/
|
||||
|
||||
@@ -224,8 +224,8 @@ RUN wget https://github.com/df7cb/postgresql-unit/archive/refs/tags/7.7.tar.gz -
|
||||
FROM build-deps AS vector-pg-build
|
||||
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
|
||||
RUN wget https://github.com/pgvector/pgvector/archive/refs/tags/v0.5.0.tar.gz -O pgvector.tar.gz && \
|
||||
echo "d8aa3504b215467ca528525a6de12c3f85f9891b091ce0e5864dd8a9b757f77b pgvector.tar.gz" | sha256sum --check && \
|
||||
RUN wget https://github.com/pgvector/pgvector/archive/refs/tags/v0.5.1.tar.gz -O pgvector.tar.gz && \
|
||||
echo "cc7a8e034a96e30a819911ac79d32f6bc47bdd1aa2de4d7d4904e26b83209dc8 pgvector.tar.gz" | sha256sum --check && \
|
||||
mkdir pgvector-src && cd pgvector-src && tar xvzf ../pgvector.tar.gz --strip-components=1 -C . && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config && \
|
||||
@@ -368,8 +368,8 @@ RUN wget https://github.com/citusdata/postgresql-hll/archive/refs/tags/v2.18.tar
|
||||
FROM build-deps AS plpgsql-check-pg-build
|
||||
COPY --from=pg-build /usr/local/pgsql/ /usr/local/pgsql/
|
||||
|
||||
RUN wget https://github.com/okbob/plpgsql_check/archive/refs/tags/v2.4.0.tar.gz -O plpgsql_check.tar.gz && \
|
||||
echo "9ba58387a279b35a3bfa39ee611e5684e6cddb2ba046ddb2c5190b3bd2ca254a plpgsql_check.tar.gz" | sha256sum --check && \
|
||||
RUN wget https://github.com/okbob/plpgsql_check/archive/refs/tags/v2.5.3.tar.gz -O plpgsql_check.tar.gz && \
|
||||
echo "6631ec3e7fb3769eaaf56e3dfedb829aa761abf163d13dba354b4c218508e1c0 plpgsql_check.tar.gz" | sha256sum --check && \
|
||||
mkdir plpgsql_check-src && cd plpgsql_check-src && tar xvzf ../plpgsql_check.tar.gz --strip-components=1 -C . && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
|
||||
make -j $(getconf _NPROCESSORS_ONLN) install PG_CONFIG=/usr/local/pgsql/bin/pg_config USE_PGXS=1 && \
|
||||
|
||||
42
Makefile
42
Makefile
@@ -62,7 +62,7 @@ all: neon postgres neon-pg-ext
|
||||
#
|
||||
# The 'postgres_ffi' depends on the Postgres headers.
|
||||
.PHONY: neon
|
||||
neon: postgres-headers
|
||||
neon: postgres-headers walproposer-lib
|
||||
+@echo "Compiling Neon"
|
||||
$(CARGO_CMD_PREFIX) cargo build $(CARGO_BUILD_FLAGS)
|
||||
|
||||
@@ -72,6 +72,10 @@ neon: postgres-headers
|
||||
#
|
||||
$(POSTGRES_INSTALL_DIR)/build/%/config.status:
|
||||
+@echo "Configuring Postgres $* build"
|
||||
@test -s $(ROOT_PROJECT_DIR)/vendor/postgres-$*/configure || { \
|
||||
echo "\nPostgres submodule not found in $(ROOT_PROJECT_DIR)/vendor/postgres-$*/, execute "; \
|
||||
echo "'git submodule update --init --recursive --depth 2 --progress .' in project root.\n"; \
|
||||
exit 1; }
|
||||
mkdir -p $(POSTGRES_INSTALL_DIR)/build/$*
|
||||
(cd $(POSTGRES_INSTALL_DIR)/build/$* && \
|
||||
env PATH="$(EXTRA_PATH_OVERRIDES):$$PATH" $(ROOT_PROJECT_DIR)/vendor/postgres-$*/configure \
|
||||
@@ -168,6 +172,42 @@ neon-pg-ext-clean-%:
|
||||
-C $(POSTGRES_INSTALL_DIR)/build/neon-utils-$* \
|
||||
-f $(ROOT_PROJECT_DIR)/pgxn/neon_utils/Makefile clean
|
||||
|
||||
# Build walproposer as a static library. walproposer source code is located
|
||||
# in the pgxn/neon directory.
|
||||
#
|
||||
# We also need to include libpgport.a and libpgcommon.a, because walproposer
|
||||
# uses some functions from those libraries.
|
||||
#
|
||||
# Some object files are removed from libpgport.a and libpgcommon.a because
|
||||
# they depend on openssl and other libraries that are not included in our
|
||||
# Rust build.
|
||||
.PHONY: walproposer-lib
|
||||
walproposer-lib: neon-pg-ext-v16
|
||||
+@echo "Compiling walproposer-lib"
|
||||
mkdir -p $(POSTGRES_INSTALL_DIR)/build/walproposer-lib
|
||||
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v16/bin/pg_config CFLAGS='$(PG_CFLAGS) $(COPT)' \
|
||||
-C $(POSTGRES_INSTALL_DIR)/build/walproposer-lib \
|
||||
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile walproposer-lib
|
||||
cp $(POSTGRES_INSTALL_DIR)/v16/lib/libpgport.a $(POSTGRES_INSTALL_DIR)/build/walproposer-lib
|
||||
cp $(POSTGRES_INSTALL_DIR)/v16/lib/libpgcommon.a $(POSTGRES_INSTALL_DIR)/build/walproposer-lib
|
||||
ifeq ($(UNAME_S),Linux)
|
||||
$(AR) d $(POSTGRES_INSTALL_DIR)/build/walproposer-lib/libpgport.a \
|
||||
pg_strong_random.o
|
||||
$(AR) d $(POSTGRES_INSTALL_DIR)/build/walproposer-lib/libpgcommon.a \
|
||||
pg_crc32c.o \
|
||||
hmac_openssl.o \
|
||||
cryptohash_openssl.o \
|
||||
scram-common.o \
|
||||
md5_common.o \
|
||||
checksum_helper.o
|
||||
endif
|
||||
|
||||
.PHONY: walproposer-lib-clean
|
||||
walproposer-lib-clean:
|
||||
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/v16/bin/pg_config \
|
||||
-C $(POSTGRES_INSTALL_DIR)/build/walproposer-lib \
|
||||
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile clean
|
||||
|
||||
.PHONY: neon-pg-ext
|
||||
neon-pg-ext: \
|
||||
neon-pg-ext-v14 \
|
||||
|
||||
4
NOTICE
4
NOTICE
@@ -1,5 +1,5 @@
|
||||
Neon
|
||||
Copyright 2022 Neon Inc.
|
||||
|
||||
The PostgreSQL submodules in vendor/postgres-v14 and vendor/postgres-v15 are licensed under the
|
||||
PostgreSQL license. See vendor/postgres-v14/COPYRIGHT and vendor/postgres-v15/COPYRIGHT.
|
||||
The PostgreSQL submodules in vendor/ are licensed under the PostgreSQL license.
|
||||
See vendor/postgres-vX/COPYRIGHT for details.
|
||||
|
||||
@@ -156,6 +156,7 @@ fn main() -> Result<()> {
|
||||
let path = Path::new(sp);
|
||||
let file = File::open(path)?;
|
||||
spec = Some(serde_json::from_reader(file)?);
|
||||
live_config_allowed = true;
|
||||
} else if let Some(id) = compute_id {
|
||||
if let Some(cp_base) = control_plane_uri {
|
||||
live_config_allowed = true;
|
||||
@@ -277,32 +278,26 @@ fn main() -> Result<()> {
|
||||
if #[cfg(target_os = "linux")] {
|
||||
use std::env;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::warn;
|
||||
let vm_monitor_addr = matches.get_one::<String>("vm-monitor-addr");
|
||||
let vm_monitor_addr = matches
|
||||
.get_one::<String>("vm-monitor-addr")
|
||||
.expect("--vm-monitor-addr should always be set because it has a default arg");
|
||||
let file_cache_connstr = matches.get_one::<String>("filecache-connstr");
|
||||
let cgroup = matches.get_one::<String>("cgroup");
|
||||
let file_cache_on_disk = matches.get_flag("file-cache-on-disk");
|
||||
|
||||
// Only make a runtime if we need to.
|
||||
// Note: it seems like you can make a runtime in an inner scope and
|
||||
// if you start a task in it it won't be dropped. However, make it
|
||||
// in the outermost scope just to be safe.
|
||||
let rt = match (env::var_os("AUTOSCALING"), vm_monitor_addr) {
|
||||
(None, None) => None,
|
||||
(None, Some(_)) => {
|
||||
warn!("--vm-monitor-addr option set but AUTOSCALING env var not present");
|
||||
None
|
||||
}
|
||||
(Some(_), None) => {
|
||||
panic!("AUTOSCALING env var present but --vm-monitor-addr option not set")
|
||||
}
|
||||
(Some(_), Some(_)) => Some(
|
||||
let rt = if env::var_os("AUTOSCALING").is_some() {
|
||||
Some(
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.worker_threads(4)
|
||||
.enable_all()
|
||||
.build()
|
||||
.expect("failed to create tokio runtime for monitor"),
|
||||
),
|
||||
.expect("failed to create tokio runtime for monitor")
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// This token is used internally by the monitor to clean up all threads
|
||||
@@ -313,8 +308,7 @@ fn main() -> Result<()> {
|
||||
Box::leak(Box::new(vm_monitor::Args {
|
||||
cgroup: cgroup.cloned(),
|
||||
pgconnstr: file_cache_connstr.cloned(),
|
||||
addr: vm_monitor_addr.cloned().unwrap(),
|
||||
file_cache_on_disk,
|
||||
addr: vm_monitor_addr.clone(),
|
||||
})),
|
||||
token.clone(),
|
||||
))
|
||||
@@ -486,6 +480,8 @@ fn cli() -> clap::Command {
|
||||
.value_name("FILECACHE_CONNSTR"),
|
||||
)
|
||||
.arg(
|
||||
// DEPRECATED, NO LONGER DOES ANYTHING.
|
||||
// See https://github.com/neondatabase/cloud/issues/7516
|
||||
Arg::new("file-cache-on-disk")
|
||||
.long("file-cache-on-disk")
|
||||
.action(clap::ArgAction::SetTrue),
|
||||
|
||||
@@ -252,7 +252,7 @@ fn create_neon_superuser(spec: &ComputeSpec, client: &mut Client) -> Result<()>
|
||||
IF NOT EXISTS (
|
||||
SELECT FROM pg_catalog.pg_roles WHERE rolname = 'neon_superuser')
|
||||
THEN
|
||||
CREATE ROLE neon_superuser CREATEDB CREATEROLE NOLOGIN IN ROLE pg_read_all_data, pg_write_all_data;
|
||||
CREATE ROLE neon_superuser CREATEDB CREATEROLE NOLOGIN REPLICATION IN ROLE pg_read_all_data, pg_write_all_data;
|
||||
IF array_length(roles, 1) IS NOT NULL THEN
|
||||
EXECUTE format('GRANT neon_superuser TO %s',
|
||||
array_to_string(ARRAY(SELECT quote_ident(x) FROM unnest(roles) as x), ', '));
|
||||
@@ -692,10 +692,11 @@ impl ComputeNode {
|
||||
// Proceed with post-startup configuration. Note, that order of operations is important.
|
||||
let spec = &compute_state.pspec.as_ref().expect("spec must be set").spec;
|
||||
create_neon_superuser(spec, &mut client)?;
|
||||
cleanup_instance(&mut client)?;
|
||||
handle_roles(spec, &mut client)?;
|
||||
handle_databases(spec, &mut client)?;
|
||||
handle_role_deletions(spec, self.connstr.as_str(), &mut client)?;
|
||||
handle_grants(spec, self.connstr.as_str())?;
|
||||
handle_grants(spec, &mut client, self.connstr.as_str())?;
|
||||
handle_extensions(spec, &mut client)?;
|
||||
create_availability_check_data(&mut client)?;
|
||||
|
||||
@@ -709,8 +710,12 @@ impl ComputeNode {
|
||||
// `pg_ctl` for start / stop, so this just seems much easier to do as we already
|
||||
// have opened connection to Postgres and superuser access.
|
||||
#[instrument(skip_all)]
|
||||
fn pg_reload_conf(&self, client: &mut Client) -> Result<()> {
|
||||
client.simple_query("SELECT pg_reload_conf()")?;
|
||||
fn pg_reload_conf(&self) -> Result<()> {
|
||||
let pgctl_bin = Path::new(&self.pgbin).parent().unwrap().join("pg_ctl");
|
||||
Command::new(pgctl_bin)
|
||||
.args(["reload", "-D", &self.pgdata])
|
||||
.output()
|
||||
.expect("cannot run pg_ctl process");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -723,18 +728,19 @@ impl ComputeNode {
|
||||
// Write new config
|
||||
let pgdata_path = Path::new(&self.pgdata);
|
||||
config::write_postgres_conf(&pgdata_path.join("postgresql.conf"), &spec, None)?;
|
||||
self.pg_reload_conf()?;
|
||||
|
||||
let mut client = Client::connect(self.connstr.as_str(), NoTls)?;
|
||||
self.pg_reload_conf(&mut client)?;
|
||||
|
||||
// Proceed with post-startup configuration. Note, that order of operations is important.
|
||||
// Disable DDL forwarding because control plane already knows about these roles/databases.
|
||||
if spec.mode == ComputeMode::Primary {
|
||||
client.simple_query("SET neon.forward_ddl = false")?;
|
||||
cleanup_instance(&mut client)?;
|
||||
handle_roles(&spec, &mut client)?;
|
||||
handle_databases(&spec, &mut client)?;
|
||||
handle_role_deletions(&spec, self.connstr.as_str(), &mut client)?;
|
||||
handle_grants(&spec, self.connstr.as_str())?;
|
||||
handle_grants(&spec, &mut client, self.connstr.as_str())?;
|
||||
handle_extensions(&spec, &mut client)?;
|
||||
}
|
||||
|
||||
|
||||
@@ -78,7 +78,7 @@ use regex::Regex;
|
||||
use remote_storage::*;
|
||||
use serde_json;
|
||||
use std::io::Read;
|
||||
use std::num::{NonZeroU32, NonZeroUsize};
|
||||
use std::num::NonZeroUsize;
|
||||
use std::path::Path;
|
||||
use std::str;
|
||||
use tar::Archive;
|
||||
@@ -281,8 +281,6 @@ pub fn init_remote_storage(remote_ext_config: &str) -> anyhow::Result<GenericRem
|
||||
max_keys_per_list_response: None,
|
||||
};
|
||||
let config = RemoteStorageConfig {
|
||||
max_concurrent_syncs: NonZeroUsize::new(100).expect("100 != 0"),
|
||||
max_sync_errors: NonZeroU32::new(100).expect("100 != 0"),
|
||||
storage: RemoteStorageKind::AwsS3(config),
|
||||
};
|
||||
GenericRemoteStorage::from_config(&config)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//!
|
||||
//! Various tools and helpers to handle cluster / compute node (Postgres)
|
||||
//! configuration.
|
||||
//!
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
pub mod checker;
|
||||
pub mod config;
|
||||
pub mod configurator;
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Write;
|
||||
use std::fs;
|
||||
use std::fs::File;
|
||||
@@ -192,11 +193,16 @@ impl Escaping for PgIdent {
|
||||
/// Build a list of existing Postgres roles
|
||||
pub fn get_existing_roles(xact: &mut Transaction<'_>) -> Result<Vec<Role>> {
|
||||
let postgres_roles = xact
|
||||
.query("SELECT rolname, rolpassword FROM pg_catalog.pg_authid", &[])?
|
||||
.query(
|
||||
"SELECT rolname, rolpassword, rolreplication, rolbypassrls FROM pg_catalog.pg_authid",
|
||||
&[],
|
||||
)?
|
||||
.iter()
|
||||
.map(|row| Role {
|
||||
name: row.get("rolname"),
|
||||
encrypted_password: row.get("rolpassword"),
|
||||
replication: Some(row.get("rolreplication")),
|
||||
bypassrls: Some(row.get("rolbypassrls")),
|
||||
options: None,
|
||||
})
|
||||
.collect();
|
||||
@@ -205,22 +211,37 @@ pub fn get_existing_roles(xact: &mut Transaction<'_>) -> Result<Vec<Role>> {
|
||||
}
|
||||
|
||||
/// Build a list of existing Postgres databases
|
||||
pub fn get_existing_dbs(client: &mut Client) -> Result<Vec<Database>> {
|
||||
let postgres_dbs = client
|
||||
pub fn get_existing_dbs(client: &mut Client) -> Result<HashMap<String, Database>> {
|
||||
// `pg_database.datconnlimit = -2` means that the database is in the
|
||||
// invalid state. See:
|
||||
// https://github.com/postgres/postgres/commit/a4b4cc1d60f7e8ccfcc8ff8cb80c28ee411ad9a9
|
||||
let postgres_dbs: Vec<Database> = client
|
||||
.query(
|
||||
"SELECT datname, datdba::regrole::text as owner
|
||||
FROM pg_catalog.pg_database;",
|
||||
"SELECT
|
||||
datname AS name,
|
||||
datdba::regrole::text AS owner,
|
||||
NOT datallowconn AS restrict_conn,
|
||||
datconnlimit = - 2 AS invalid
|
||||
FROM
|
||||
pg_catalog.pg_database;",
|
||||
&[],
|
||||
)?
|
||||
.iter()
|
||||
.map(|row| Database {
|
||||
name: row.get("datname"),
|
||||
name: row.get("name"),
|
||||
owner: row.get("owner"),
|
||||
restrict_conn: row.get("restrict_conn"),
|
||||
invalid: row.get("invalid"),
|
||||
options: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(postgres_dbs)
|
||||
let dbs_map = postgres_dbs
|
||||
.iter()
|
||||
.map(|db| (db.name.clone(), db.clone()))
|
||||
.collect::<HashMap<_, _>>();
|
||||
|
||||
Ok(dbs_map)
|
||||
}
|
||||
|
||||
/// Wait for Postgres to become ready to accept connections. It's ready to
|
||||
|
||||
@@ -13,7 +13,7 @@ use crate::params::PG_HBA_ALL_MD5;
|
||||
use crate::pg_helpers::*;
|
||||
|
||||
use compute_api::responses::{ControlPlaneComputeStatus, ControlPlaneSpecResponse};
|
||||
use compute_api::spec::{ComputeSpec, Database, PgIdent, Role};
|
||||
use compute_api::spec::{ComputeSpec, PgIdent, Role};
|
||||
|
||||
// Do control plane request and return response if any. In case of error it
|
||||
// returns a bool flag indicating whether it makes sense to retry the request
|
||||
@@ -24,7 +24,7 @@ fn do_control_plane_request(
|
||||
) -> Result<ControlPlaneSpecResponse, (bool, String)> {
|
||||
let resp = reqwest::blocking::Client::new()
|
||||
.get(uri)
|
||||
.header("Authorization", jwt)
|
||||
.header("Authorization", format!("Bearer {}", jwt))
|
||||
.send()
|
||||
.map_err(|e| {
|
||||
(
|
||||
@@ -68,7 +68,7 @@ pub fn get_spec_from_control_plane(
|
||||
base_uri: &str,
|
||||
compute_id: &str,
|
||||
) -> Result<Option<ComputeSpec>> {
|
||||
let cp_uri = format!("{base_uri}/management/api/v2/computes/{compute_id}/spec");
|
||||
let cp_uri = format!("{base_uri}/compute/api/v2/computes/{compute_id}/spec");
|
||||
let jwt: String = match std::env::var("NEON_CONTROL_PLANE_TOKEN") {
|
||||
Ok(v) => v,
|
||||
Err(_) => "".to_string(),
|
||||
@@ -161,6 +161,38 @@ pub fn add_standby_signal(pgdata_path: &Path) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Compute could be unexpectedly shut down, for example, during the
|
||||
/// database dropping. This leaves the database in the invalid state,
|
||||
/// which prevents new db creation with the same name. This function
|
||||
/// will clean it up before proceeding with catalog updates. All
|
||||
/// possible future cleanup operations may go here too.
|
||||
#[instrument(skip_all)]
|
||||
pub fn cleanup_instance(client: &mut Client) -> Result<()> {
|
||||
let existing_dbs = get_existing_dbs(client)?;
|
||||
|
||||
for (_, db) in existing_dbs {
|
||||
if db.invalid {
|
||||
// After recent commit in Postgres, interrupted DROP DATABASE
|
||||
// leaves the database in the invalid state. According to the
|
||||
// commit message, the only option for user is to drop it again.
|
||||
// See:
|
||||
// https://github.com/postgres/postgres/commit/a4b4cc1d60f7e8ccfcc8ff8cb80c28ee411ad9a9
|
||||
//
|
||||
// Postgres Neon extension is done the way, that db is de-registered
|
||||
// in the control plane metadata only after it is dropped. So there is
|
||||
// a chance that it still thinks that db should exist. This means
|
||||
// that it will be re-created by `handle_databases()`. Yet, it's fine
|
||||
// as user can just repeat drop (in vanilla Postgres they would need
|
||||
// to do the same, btw).
|
||||
let query = format!("DROP DATABASE IF EXISTS {}", db.name.pg_quote());
|
||||
info!("dropping invalid database {}", db.name);
|
||||
client.execute(query.as_str(), &[])?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Given a cluster spec json and open transaction it handles roles creation,
|
||||
/// deletion and update.
|
||||
#[instrument(skip_all)]
|
||||
@@ -233,6 +265,8 @@ pub fn handle_roles(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
|
||||
let action = if let Some(r) = pg_role {
|
||||
if (r.encrypted_password.is_none() && role.encrypted_password.is_some())
|
||||
|| (r.encrypted_password.is_some() && role.encrypted_password.is_none())
|
||||
|| !r.bypassrls.unwrap_or(false)
|
||||
|| !r.replication.unwrap_or(false)
|
||||
{
|
||||
RoleAction::Update
|
||||
} else if let Some(pg_pwd) = &r.encrypted_password {
|
||||
@@ -264,13 +298,14 @@ pub fn handle_roles(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
|
||||
match action {
|
||||
RoleAction::None => {}
|
||||
RoleAction::Update => {
|
||||
let mut query: String = format!("ALTER ROLE {} ", name.pg_quote());
|
||||
let mut query: String =
|
||||
format!("ALTER ROLE {} BYPASSRLS REPLICATION", name.pg_quote());
|
||||
query.push_str(&role.to_pg_options());
|
||||
xact.execute(query.as_str(), &[])?;
|
||||
}
|
||||
RoleAction::Create => {
|
||||
let mut query: String = format!(
|
||||
"CREATE ROLE {} CREATEROLE CREATEDB BYPASSRLS IN ROLE neon_superuser",
|
||||
"CREATE ROLE {} CREATEROLE CREATEDB BYPASSRLS REPLICATION IN ROLE neon_superuser",
|
||||
name.pg_quote()
|
||||
);
|
||||
info!("role create query: '{}'", &query);
|
||||
@@ -379,13 +414,13 @@ fn reassign_owned_objects(spec: &ComputeSpec, connstr: &str, role_name: &PgIdent
|
||||
/// which together provide us idempotency.
|
||||
#[instrument(skip_all)]
|
||||
pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
|
||||
let existing_dbs: Vec<Database> = get_existing_dbs(client)?;
|
||||
let existing_dbs = get_existing_dbs(client)?;
|
||||
|
||||
// Print a list of existing Postgres databases (only in debug mode)
|
||||
if span_enabled!(Level::INFO) {
|
||||
info!("postgres databases:");
|
||||
for r in &existing_dbs {
|
||||
info!(" {}:{}", r.name, r.owner);
|
||||
for (dbname, db) in &existing_dbs {
|
||||
info!(" {}:{}", dbname, db.owner);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -439,8 +474,7 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
|
||||
"rename_db" => {
|
||||
let new_name = op.new_name.as_ref().unwrap();
|
||||
|
||||
// XXX: with a limited number of roles it is fine, but consider making it a HashMap
|
||||
if existing_dbs.iter().any(|r| r.name == op.name) {
|
||||
if existing_dbs.get(&op.name).is_some() {
|
||||
let query: String = format!(
|
||||
"ALTER DATABASE {} RENAME TO {}",
|
||||
op.name.pg_quote(),
|
||||
@@ -457,14 +491,12 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
|
||||
}
|
||||
|
||||
// Refresh Postgres databases info to handle possible renames
|
||||
let existing_dbs: Vec<Database> = get_existing_dbs(client)?;
|
||||
let existing_dbs = get_existing_dbs(client)?;
|
||||
|
||||
info!("cluster spec databases:");
|
||||
for db in &spec.cluster.databases {
|
||||
let name = &db.name;
|
||||
|
||||
// XXX: with a limited number of databases it is fine, but consider making it a HashMap
|
||||
let pg_db = existing_dbs.iter().find(|r| r.name == *name);
|
||||
let pg_db = existing_dbs.get(name);
|
||||
|
||||
enum DatabaseAction {
|
||||
None,
|
||||
@@ -530,13 +562,32 @@ pub fn handle_databases(spec: &ComputeSpec, client: &mut Client) -> Result<()> {
|
||||
/// Grant CREATE ON DATABASE to the database owner and do some other alters and grants
|
||||
/// to allow users creating trusted extensions and re-creating `public` schema, for example.
|
||||
#[instrument(skip_all)]
|
||||
pub fn handle_grants(spec: &ComputeSpec, connstr: &str) -> Result<()> {
|
||||
info!("cluster spec grants:");
|
||||
pub fn handle_grants(spec: &ComputeSpec, client: &mut Client, connstr: &str) -> Result<()> {
|
||||
info!("modifying database permissions");
|
||||
let existing_dbs = get_existing_dbs(client)?;
|
||||
|
||||
// Do some per-database access adjustments. We'd better do this at db creation time,
|
||||
// but CREATE DATABASE isn't transactional. So we cannot create db + do some grants
|
||||
// atomically.
|
||||
for db in &spec.cluster.databases {
|
||||
match existing_dbs.get(&db.name) {
|
||||
Some(pg_db) => {
|
||||
if pg_db.restrict_conn || pg_db.invalid {
|
||||
info!(
|
||||
"skipping grants for db {} (invalid: {}, connections not allowed: {})",
|
||||
db.name, pg_db.invalid, pg_db.restrict_conn
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
None => {
|
||||
bail!(
|
||||
"database {} doesn't exist in Postgres after handle_databases()",
|
||||
db.name
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let mut conf = Config::from_str(connstr)?;
|
||||
conf.dbname(&db.name);
|
||||
|
||||
@@ -575,6 +626,11 @@ pub fn handle_grants(spec: &ComputeSpec, connstr: &str) -> Result<()> {
|
||||
|
||||
// Explicitly grant CREATE ON SCHEMA PUBLIC to the web_access user.
|
||||
// This is needed because since postgres 15 this privilege is removed by default.
|
||||
// TODO: web_access isn't created for almost 1 year. It could be that we have
|
||||
// active users of 1 year old projects, but hopefully not, so check it and
|
||||
// remove this code if possible. The worst thing that could happen is that
|
||||
// user won't be able to use public schema in NEW databases created in the
|
||||
// very OLD project.
|
||||
let grant_query = "DO $$\n\
|
||||
BEGIN\n\
|
||||
IF EXISTS(\n\
|
||||
|
||||
@@ -28,7 +28,7 @@ mod pg_helpers_tests {
|
||||
assert_eq!(
|
||||
spec.cluster.settings.as_pg_settings(),
|
||||
r#"fsync = off
|
||||
wal_level = replica
|
||||
wal_level = logical
|
||||
hot_standby = on
|
||||
neon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501'
|
||||
wal_log_hints = on
|
||||
|
||||
@@ -2,7 +2,6 @@ use crate::{background_process, local_env::LocalEnv};
|
||||
use anyhow::anyhow;
|
||||
use camino::Utf8PathBuf;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
use std::{path::PathBuf, process::Child};
|
||||
use utils::id::{NodeId, TenantId};
|
||||
|
||||
@@ -14,12 +13,10 @@ pub struct AttachmentService {
|
||||
|
||||
const COMMAND: &str = "attachment_service";
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct AttachHookRequest {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub tenant_id: TenantId,
|
||||
pub pageserver_id: Option<NodeId>,
|
||||
pub node_id: Option<NodeId>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
@@ -85,7 +82,7 @@ impl AttachmentService {
|
||||
.control_plane_api
|
||||
.clone()
|
||||
.unwrap()
|
||||
.join("attach_hook")
|
||||
.join("attach-hook")
|
||||
.unwrap();
|
||||
let client = reqwest::blocking::ClientBuilder::new()
|
||||
.build()
|
||||
@@ -93,7 +90,7 @@ impl AttachmentService {
|
||||
|
||||
let request = AttachHookRequest {
|
||||
tenant_id,
|
||||
pageserver_id: Some(pageserver_id),
|
||||
node_id: Some(pageserver_id),
|
||||
};
|
||||
|
||||
let response = client.post(url).json(&request).send()?;
|
||||
|
||||
@@ -86,7 +86,7 @@ where
|
||||
.stdout(process_log_file)
|
||||
.stderr(same_file_for_stderr)
|
||||
.args(args);
|
||||
let filled_cmd = fill_aws_secrets_vars(fill_rust_env_vars(background_command));
|
||||
let filled_cmd = fill_remote_storage_secrets_vars(fill_rust_env_vars(background_command));
|
||||
filled_cmd.envs(envs);
|
||||
|
||||
let pid_file_to_check = match initial_pid_file {
|
||||
@@ -238,11 +238,13 @@ fn fill_rust_env_vars(cmd: &mut Command) -> &mut Command {
|
||||
filled_cmd
|
||||
}
|
||||
|
||||
fn fill_aws_secrets_vars(mut cmd: &mut Command) -> &mut Command {
|
||||
fn fill_remote_storage_secrets_vars(mut cmd: &mut Command) -> &mut Command {
|
||||
for env_key in [
|
||||
"AWS_ACCESS_KEY_ID",
|
||||
"AWS_SECRET_ACCESS_KEY",
|
||||
"AWS_SESSION_TOKEN",
|
||||
"AZURE_STORAGE_ACCOUNT",
|
||||
"AZURE_STORAGE_ACCESS_KEY",
|
||||
] {
|
||||
if let Ok(value) = std::env::var(env_key) {
|
||||
cmd = cmd.env(env_key, value);
|
||||
@@ -260,7 +262,7 @@ where
|
||||
P: Into<Utf8PathBuf>,
|
||||
{
|
||||
let path: Utf8PathBuf = path.into();
|
||||
// SAFETY
|
||||
// SAFETY:
|
||||
// pre_exec is marked unsafe because it runs between fork and exec.
|
||||
// Why is that dangerous in various ways?
|
||||
// Long answer: https://github.com/rust-lang/rust/issues/39575
|
||||
|
||||
@@ -12,7 +12,9 @@ use hyper::{Body, Request, Response};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use utils::http::endpoint::request_span;
|
||||
use utils::logging::{self, LogFormat};
|
||||
use utils::signals::{ShutdownSignals, Signal};
|
||||
|
||||
use utils::{
|
||||
http::{
|
||||
@@ -170,7 +172,7 @@ async fn handle_re_attach(mut req: Request<Body>) -> Result<Response<Body>, ApiE
|
||||
state.generation += 1;
|
||||
response.tenants.push(ReAttachResponseTenant {
|
||||
id: *t,
|
||||
generation: state.generation,
|
||||
gen: state.generation,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -216,14 +218,31 @@ async fn handle_attach_hook(mut req: Request<Body>) -> Result<Response<Body>, Ap
|
||||
.tenants
|
||||
.entry(attach_req.tenant_id)
|
||||
.or_insert_with(|| TenantState {
|
||||
pageserver: attach_req.pageserver_id,
|
||||
pageserver: attach_req.node_id,
|
||||
generation: 0,
|
||||
});
|
||||
|
||||
if attach_req.pageserver_id.is_some() {
|
||||
if let Some(attaching_pageserver) = attach_req.node_id.as_ref() {
|
||||
tenant_state.generation += 1;
|
||||
tracing::info!(
|
||||
tenant_id = %attach_req.tenant_id,
|
||||
ps_id = %attaching_pageserver,
|
||||
generation = %tenant_state.generation,
|
||||
"issuing",
|
||||
);
|
||||
} else if let Some(ps_id) = tenant_state.pageserver {
|
||||
tracing::info!(
|
||||
tenant_id = %attach_req.tenant_id,
|
||||
%ps_id,
|
||||
generation = %tenant_state.generation,
|
||||
"dropping",
|
||||
);
|
||||
} else {
|
||||
tracing::info!(
|
||||
tenant_id = %attach_req.tenant_id,
|
||||
"no-op: tenant already has no pageserver");
|
||||
}
|
||||
tenant_state.pageserver = attach_req.pageserver_id;
|
||||
tenant_state.pageserver = attach_req.node_id;
|
||||
let generation = tenant_state.generation;
|
||||
|
||||
locked.save().await.map_err(ApiError::InternalServerError)?;
|
||||
@@ -231,7 +250,7 @@ async fn handle_attach_hook(mut req: Request<Body>) -> Result<Response<Body>, Ap
|
||||
json_response(
|
||||
StatusCode::OK,
|
||||
AttachHookResponse {
|
||||
gen: attach_req.pageserver_id.map(|_| generation),
|
||||
gen: attach_req.node_id.map(|_| generation),
|
||||
},
|
||||
)
|
||||
}
|
||||
@@ -239,9 +258,9 @@ async fn handle_attach_hook(mut req: Request<Body>) -> Result<Response<Body>, Ap
|
||||
fn make_router(persistent_state: PersistentState) -> RouterBuilder<hyper::Body, ApiError> {
|
||||
endpoint::make_router()
|
||||
.data(Arc::new(State::new(persistent_state)))
|
||||
.post("/re-attach", handle_re_attach)
|
||||
.post("/validate", handle_validate)
|
||||
.post("/attach_hook", handle_attach_hook)
|
||||
.post("/re-attach", |r| request_span(r, handle_re_attach))
|
||||
.post("/validate", |r| request_span(r, handle_validate))
|
||||
.post("/attach-hook", |r| request_span(r, handle_attach_hook))
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
@@ -268,7 +287,16 @@ async fn main() -> anyhow::Result<()> {
|
||||
let server = hyper::Server::from_tcp(http_listener)?.serve(service);
|
||||
|
||||
tracing::info!("Serving on {0}", args.listen);
|
||||
server.await?;
|
||||
|
||||
tokio::task::spawn(server);
|
||||
|
||||
ShutdownSignals::handle(|signal| match signal {
|
||||
Signal::Interrupt | Signal::Terminate | Signal::Quit => {
|
||||
tracing::info!("Got {}. Terminating", signal.name());
|
||||
// We're just a test helper: no graceful shutdown.
|
||||
std::process::exit(0);
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -116,6 +116,7 @@ fn main() -> Result<()> {
|
||||
"attachment_service" => handle_attachment_service(sub_args, &env),
|
||||
"safekeeper" => handle_safekeeper(sub_args, &env),
|
||||
"endpoint" => handle_endpoint(sub_args, &env),
|
||||
"mappings" => handle_mappings(sub_args, &mut env),
|
||||
"pg" => bail!("'pg' subcommand has been renamed to 'endpoint'"),
|
||||
_ => bail!("unexpected subcommand {sub_name}"),
|
||||
};
|
||||
@@ -797,6 +798,24 @@ fn handle_endpoint(ep_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<(
|
||||
ep.start(&auth_token, safekeepers, remote_ext_config)?;
|
||||
}
|
||||
}
|
||||
"reconfigure" => {
|
||||
let endpoint_id = sub_args
|
||||
.get_one::<String>("endpoint_id")
|
||||
.ok_or_else(|| anyhow!("No endpoint ID provided to reconfigure"))?;
|
||||
let endpoint = cplane
|
||||
.endpoints
|
||||
.get(endpoint_id.as_str())
|
||||
.with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?;
|
||||
let pageserver_id =
|
||||
if let Some(id_str) = sub_args.get_one::<String>("endpoint-pageserver-id") {
|
||||
Some(NodeId(
|
||||
id_str.parse().context("while parsing pageserver id")?,
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
endpoint.reconfigure(pageserver_id)?;
|
||||
}
|
||||
"stop" => {
|
||||
let endpoint_id = sub_args
|
||||
.get_one::<String>("endpoint_id")
|
||||
@@ -816,6 +835,38 @@ fn handle_endpoint(ep_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_mappings(sub_match: &ArgMatches, env: &mut local_env::LocalEnv) -> Result<()> {
|
||||
let (sub_name, sub_args) = match sub_match.subcommand() {
|
||||
Some(ep_subcommand_data) => ep_subcommand_data,
|
||||
None => bail!("no mappings subcommand provided"),
|
||||
};
|
||||
|
||||
match sub_name {
|
||||
"map" => {
|
||||
let branch_name = sub_args
|
||||
.get_one::<String>("branch-name")
|
||||
.expect("branch-name argument missing");
|
||||
|
||||
let tenant_id = sub_args
|
||||
.get_one::<String>("tenant-id")
|
||||
.map(|x| TenantId::from_str(x))
|
||||
.expect("tenant-id argument missing")
|
||||
.expect("malformed tenant-id arg");
|
||||
|
||||
let timeline_id = sub_args
|
||||
.get_one::<String>("timeline-id")
|
||||
.map(|x| TimelineId::from_str(x))
|
||||
.expect("timeline-id argument missing")
|
||||
.expect("malformed timeline-id arg");
|
||||
|
||||
env.register_branch_mapping(branch_name.to_owned(), tenant_id, timeline_id)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
other => unimplemented!("mappings subcommand {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_pageserver(sub_match: &ArgMatches, env: &local_env::LocalEnv) -> Result<()> {
|
||||
fn get_pageserver(env: &local_env::LocalEnv, args: &ArgMatches) -> Result<PageServerNode> {
|
||||
let node_id = if let Some(id_str) = args.get_one::<String>("pageserver-id") {
|
||||
@@ -1084,6 +1135,7 @@ fn cli() -> Command {
|
||||
// --id, when using a pageserver command
|
||||
let pageserver_id_arg = Arg::new("pageserver-id")
|
||||
.long("id")
|
||||
.global(true)
|
||||
.help("pageserver id")
|
||||
.required(false);
|
||||
// --pageserver-id when using a non-pageserver command
|
||||
@@ -1254,17 +1306,20 @@ fn cli() -> Command {
|
||||
Command::new("pageserver")
|
||||
.arg_required_else_help(true)
|
||||
.about("Manage pageserver")
|
||||
.arg(pageserver_id_arg)
|
||||
.subcommand(Command::new("status"))
|
||||
.arg(pageserver_id_arg.clone())
|
||||
.subcommand(Command::new("start").about("Start local pageserver")
|
||||
.arg(pageserver_id_arg.clone())
|
||||
.arg(pageserver_config_args.clone()))
|
||||
.subcommand(Command::new("stop").about("Stop local pageserver")
|
||||
.arg(pageserver_id_arg.clone())
|
||||
.arg(stop_mode_arg.clone()))
|
||||
.subcommand(Command::new("restart").about("Restart local pageserver")
|
||||
.arg(pageserver_id_arg.clone())
|
||||
.arg(pageserver_config_args.clone()))
|
||||
.subcommand(Command::new("start")
|
||||
.about("Start local pageserver")
|
||||
.arg(pageserver_config_args.clone())
|
||||
)
|
||||
.subcommand(Command::new("stop")
|
||||
.about("Stop local pageserver")
|
||||
.arg(stop_mode_arg.clone())
|
||||
)
|
||||
.subcommand(Command::new("restart")
|
||||
.about("Restart local pageserver")
|
||||
.arg(pageserver_config_args.clone())
|
||||
)
|
||||
)
|
||||
.subcommand(
|
||||
Command::new("attachment_service")
|
||||
@@ -1321,8 +1376,8 @@ fn cli() -> Command {
|
||||
.about("Start postgres.\n If the endpoint doesn't exist yet, it is created.")
|
||||
.arg(endpoint_id_arg.clone())
|
||||
.arg(tenant_id_arg.clone())
|
||||
.arg(branch_name_arg)
|
||||
.arg(timeline_id_arg)
|
||||
.arg(branch_name_arg.clone())
|
||||
.arg(timeline_id_arg.clone())
|
||||
.arg(lsn_arg)
|
||||
.arg(pg_port_arg)
|
||||
.arg(http_port_arg)
|
||||
@@ -1332,10 +1387,16 @@ fn cli() -> Command {
|
||||
.arg(safekeepers_arg)
|
||||
.arg(remote_ext_config_args)
|
||||
)
|
||||
.subcommand(Command::new("reconfigure")
|
||||
.about("Reconfigure the endpoint")
|
||||
.arg(endpoint_pageserver_id_arg)
|
||||
.arg(endpoint_id_arg.clone())
|
||||
.arg(tenant_id_arg.clone())
|
||||
)
|
||||
.subcommand(
|
||||
Command::new("stop")
|
||||
.arg(endpoint_id_arg)
|
||||
.arg(tenant_id_arg)
|
||||
.arg(tenant_id_arg.clone())
|
||||
.arg(
|
||||
Arg::new("destroy")
|
||||
.help("Also delete data directory (now optional, should be default in future)")
|
||||
@@ -1346,6 +1407,18 @@ fn cli() -> Command {
|
||||
)
|
||||
|
||||
)
|
||||
.subcommand(
|
||||
Command::new("mappings")
|
||||
.arg_required_else_help(true)
|
||||
.about("Manage neon_local branch name mappings")
|
||||
.subcommand(
|
||||
Command::new("map")
|
||||
.about("Create new mapping which cannot exist already")
|
||||
.arg(branch_name_arg.clone())
|
||||
.arg(tenant_id_arg.clone())
|
||||
.arg(timeline_id_arg.clone())
|
||||
)
|
||||
)
|
||||
// Obsolete old name for 'endpoint'. We now just print an error if it's used.
|
||||
.subcommand(
|
||||
Command::new("pg")
|
||||
|
||||
@@ -46,7 +46,6 @@ use std::time::Duration;
|
||||
|
||||
use anyhow::{anyhow, bail, Context, Result};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
use utils::id::{NodeId, TenantId, TimelineId};
|
||||
|
||||
use crate::local_env::LocalEnv;
|
||||
@@ -57,13 +56,10 @@ use compute_api::responses::{ComputeState, ComputeStatus};
|
||||
use compute_api::spec::{Cluster, ComputeMode, ComputeSpec};
|
||||
|
||||
// contents of a endpoint.json file
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
|
||||
pub struct EndpointConf {
|
||||
endpoint_id: String,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
tenant_id: TenantId,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
timeline_id: TimelineId,
|
||||
mode: ComputeMode,
|
||||
pg_port: u16,
|
||||
@@ -253,7 +249,7 @@ impl Endpoint {
|
||||
conf.append("shared_buffers", "1MB");
|
||||
conf.append("fsync", "off");
|
||||
conf.append("max_connections", "100");
|
||||
conf.append("wal_level", "replica");
|
||||
conf.append("wal_level", "logical");
|
||||
// wal_sender_timeout is the maximum time to wait for WAL replication.
|
||||
// It also defines how often the walreciever will send a feedback message to the wal sender.
|
||||
conf.append("wal_sender_timeout", "5s");
|
||||
@@ -414,18 +410,34 @@ impl Endpoint {
|
||||
);
|
||||
}
|
||||
|
||||
// Also wait for the compute_ctl process to die. It might have some cleanup
|
||||
// work to do after postgres stops, like syncing safekeepers, etc.
|
||||
//
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn wait_for_compute_ctl_to_exit(&self) -> Result<()> {
|
||||
// TODO use background_process::stop_process instead
|
||||
let pidfile_path = self.endpoint_path().join("compute_ctl.pid");
|
||||
let pid: u32 = std::fs::read_to_string(pidfile_path)?.parse()?;
|
||||
let pid = nix::unistd::Pid::from_raw(pid as i32);
|
||||
crate::background_process::wait_until_stopped("compute_ctl", pid)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn read_postgresql_conf(&self) -> Result<String> {
|
||||
// Slurp the endpoints/<endpoint id>/postgresql.conf file into
|
||||
// memory. We will include it in the spec file that we pass to
|
||||
// `compute_ctl`, and `compute_ctl` will write it to the postgresql.conf
|
||||
// in the data directory.
|
||||
let postgresql_conf_path = self.endpoint_path().join("postgresql.conf");
|
||||
match std::fs::read(&postgresql_conf_path) {
|
||||
Ok(content) => Ok(String::from_utf8(content)?),
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok("".to_string()),
|
||||
Err(e) => Err(anyhow::Error::new(e).context(format!(
|
||||
"failed to read config file in {}",
|
||||
postgresql_conf_path.to_str().unwrap()
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start(
|
||||
&self,
|
||||
auth_token: &Option<String>,
|
||||
@@ -436,21 +448,7 @@ impl Endpoint {
|
||||
anyhow::bail!("The endpoint is already running");
|
||||
}
|
||||
|
||||
// Slurp the endpoints/<endpoint id>/postgresql.conf file into
|
||||
// memory. We will include it in the spec file that we pass to
|
||||
// `compute_ctl`, and `compute_ctl` will write it to the postgresql.conf
|
||||
// in the data directory.
|
||||
let postgresql_conf_path = self.endpoint_path().join("postgresql.conf");
|
||||
let postgresql_conf = match std::fs::read(&postgresql_conf_path) {
|
||||
Ok(content) => String::from_utf8(content)?,
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => "".to_string(),
|
||||
Err(e) => {
|
||||
return Err(anyhow::Error::new(e).context(format!(
|
||||
"failed to read config file in {}",
|
||||
postgresql_conf_path.to_str().unwrap()
|
||||
)))
|
||||
}
|
||||
};
|
||||
let postgresql_conf = self.read_postgresql_conf()?;
|
||||
|
||||
// We always start the compute node from scratch, so if the Postgres
|
||||
// data dir exists from a previous launch, remove it first.
|
||||
@@ -621,6 +619,61 @@ impl Endpoint {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn reconfigure(&self, pageserver_id: Option<NodeId>) -> Result<()> {
|
||||
let mut spec: ComputeSpec = {
|
||||
let spec_path = self.endpoint_path().join("spec.json");
|
||||
let file = std::fs::File::open(spec_path)?;
|
||||
serde_json::from_reader(file)?
|
||||
};
|
||||
|
||||
let postgresql_conf = self.read_postgresql_conf()?;
|
||||
spec.cluster.postgresql_conf = Some(postgresql_conf);
|
||||
|
||||
if let Some(pageserver_id) = pageserver_id {
|
||||
let endpoint_config_path = self.endpoint_path().join("endpoint.json");
|
||||
let mut endpoint_conf: EndpointConf = {
|
||||
let file = std::fs::File::open(&endpoint_config_path)?;
|
||||
serde_json::from_reader(file)?
|
||||
};
|
||||
endpoint_conf.pageserver_id = pageserver_id;
|
||||
std::fs::write(
|
||||
endpoint_config_path,
|
||||
serde_json::to_string_pretty(&endpoint_conf)?,
|
||||
)?;
|
||||
|
||||
let pageserver =
|
||||
PageServerNode::from_env(&self.env, self.env.get_pageserver_conf(pageserver_id)?);
|
||||
let ps_http_conf = &pageserver.pg_connection_config;
|
||||
let (host, port) = (ps_http_conf.host(), ps_http_conf.port());
|
||||
spec.pageserver_connstring = Some(format!("postgresql://no_user@{host}:{port}"));
|
||||
}
|
||||
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let response = client
|
||||
.post(format!(
|
||||
"http://{}:{}/configure",
|
||||
self.http_address.ip(),
|
||||
self.http_address.port()
|
||||
))
|
||||
.body(format!(
|
||||
"{{\"spec\":{}}}",
|
||||
serde_json::to_string_pretty(&spec)?
|
||||
))
|
||||
.send()?;
|
||||
|
||||
let status = response.status();
|
||||
if !(status.is_client_error() || status.is_server_error()) {
|
||||
Ok(())
|
||||
} else {
|
||||
let url = response.url().to_owned();
|
||||
let msg = match response.text() {
|
||||
Ok(err_body) => format!("Error: {}", err_body),
|
||||
Err(_) => format!("Http error ({}) at {}.", status.as_u16(), url),
|
||||
};
|
||||
Err(anyhow::anyhow!(msg))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stop(&self, destroy: bool) -> Result<()> {
|
||||
// If we are going to destroy data directory,
|
||||
// use immediate shutdown mode, otherwise,
|
||||
@@ -629,15 +682,25 @@ impl Endpoint {
|
||||
// Postgres is always started from scratch, so stop
|
||||
// without destroy only used for testing and debugging.
|
||||
//
|
||||
self.pg_ctl(
|
||||
if destroy {
|
||||
&["-m", "immediate", "stop"]
|
||||
} else {
|
||||
&["stop"]
|
||||
},
|
||||
&None,
|
||||
)?;
|
||||
|
||||
// Also wait for the compute_ctl process to die. It might have some cleanup
|
||||
// work to do after postgres stops, like syncing safekeepers, etc.
|
||||
//
|
||||
self.wait_for_compute_ctl_to_exit()?;
|
||||
if destroy {
|
||||
self.pg_ctl(&["-m", "immediate", "stop"], &None)?;
|
||||
println!(
|
||||
"Destroying postgres data directory '{}'",
|
||||
self.pgdata().to_str().unwrap()
|
||||
);
|
||||
std::fs::remove_dir_all(self.endpoint_path())?;
|
||||
} else {
|
||||
self.pg_ctl(&["stop"], &None)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
//
|
||||
// Local control plane.
|
||||
//
|
||||
// Can start, configure and stop postgres instances running as a local processes.
|
||||
//
|
||||
// Intended to be used in integration tests and in CLI tools for
|
||||
// local installations.
|
||||
//
|
||||
//! Local control plane.
|
||||
//!
|
||||
//! Can start, configure and stop postgres instances running as a local processes.
|
||||
//!
|
||||
//! Intended to be used in integration tests and in CLI tools for
|
||||
//! local installations.
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
|
||||
pub mod attachment_service;
|
||||
mod background_process;
|
||||
|
||||
@@ -8,7 +8,6 @@ use anyhow::{bail, ensure, Context};
|
||||
use postgres_backend::AuthType;
|
||||
use reqwest::Url;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::fs;
|
||||
@@ -33,7 +32,6 @@ pub const DEFAULT_PG_VERSION: u32 = 15;
|
||||
// to 'neon_local init --config=<path>' option. See control_plane/simple.conf for
|
||||
// an example.
|
||||
//
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
|
||||
pub struct LocalEnv {
|
||||
// Base directory for all the nodes (the pageserver, safekeepers and
|
||||
@@ -59,7 +57,6 @@ pub struct LocalEnv {
|
||||
// Default tenant ID to use with the 'neon_local' command line utility, when
|
||||
// --tenant_id is not explicitly specified.
|
||||
#[serde(default)]
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
pub default_tenant_id: Option<TenantId>,
|
||||
|
||||
// used to issue tokens during e.g pg start
|
||||
@@ -84,7 +81,6 @@ pub struct LocalEnv {
|
||||
// A `HashMap<String, HashMap<TenantId, TimelineId>>` would be more appropriate here,
|
||||
// but deserialization into a generic toml object as `toml::Value::try_from` fails with an error.
|
||||
// https://toml.io/en/v1.0.0 does not contain a concept of "a table inside another table".
|
||||
#[serde_as(as = "HashMap<_, Vec<(DisplayFromStr, DisplayFromStr)>>")]
|
||||
branch_name_mappings: HashMap<String, Vec<(TenantId, TimelineId)>>,
|
||||
}
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
},
|
||||
{
|
||||
"name": "wal_level",
|
||||
"value": "replica",
|
||||
"value": "logical",
|
||||
"vartype": "enum"
|
||||
},
|
||||
{
|
||||
|
||||
@@ -188,11 +188,60 @@ that.
|
||||
|
||||
## Error message style
|
||||
|
||||
### PostgreSQL extensions
|
||||
|
||||
PostgreSQL has a style guide for writing error messages:
|
||||
|
||||
https://www.postgresql.org/docs/current/error-style-guide.html
|
||||
|
||||
Follow that guide when writing error messages in the PostgreSQL
|
||||
extension. We don't follow it strictly in the pageserver and
|
||||
safekeeper, but the advice in the PostgreSQL style guide is generally
|
||||
good, and you can't go wrong by following it.
|
||||
extensions.
|
||||
|
||||
### Neon Rust code
|
||||
|
||||
#### Anyhow Context
|
||||
|
||||
When adding anyhow `context()`, use form `present-tense-verb+action`.
|
||||
|
||||
Example:
|
||||
- Bad: `file.metadata().context("could not get file metadata")?;`
|
||||
- Good: `file.metadata().context("get file metadata")?;`
|
||||
|
||||
#### Logging Errors
|
||||
|
||||
When logging any error `e`, use `could not {e:#}` or `failed to {e:#}`.
|
||||
|
||||
If `e` is an `anyhow` error and you want to log the backtrace that it contains,
|
||||
use `{e:?}` instead of `{e:#}`.
|
||||
|
||||
#### Rationale
|
||||
|
||||
The `{:#}` ("alternate Display") of an `anyhow` error chain is concatenation fo the contexts, using `: `.
|
||||
|
||||
For example, the following Rust code will result in output
|
||||
```
|
||||
ERROR failed to list users: load users from server: parse response: invalid json
|
||||
```
|
||||
|
||||
This is more concise / less noisy than what happens if you do `.context("could not ...")?` at each level, i.e.:
|
||||
|
||||
```
|
||||
ERROR could not list users: could not load users from server: could not parse response: invalid json
|
||||
```
|
||||
|
||||
|
||||
```rust
|
||||
fn main() {
|
||||
match list_users().context("list users") else {
|
||||
Ok(_) => ...,
|
||||
Err(e) => tracing::error!("failed to {e:#}"),
|
||||
}
|
||||
}
|
||||
fn list_users() {
|
||||
http_get_users().context("load users from server")?;
|
||||
}
|
||||
fn http_get_users() {
|
||||
let response = client....?;
|
||||
response.parse().context("parse response")?; // fails with serde error "invalid json"
|
||||
}
|
||||
```
|
||||
|
||||
@@ -96,6 +96,16 @@ prefix_in_bucket = '/test_prefix/'
|
||||
|
||||
`AWS_SECRET_ACCESS_KEY` and `AWS_ACCESS_KEY_ID` env variables can be used to specify the S3 credentials if needed.
|
||||
|
||||
or
|
||||
|
||||
```toml
|
||||
[remote_storage]
|
||||
container_name = 'some-container-name'
|
||||
container_region = 'us-east'
|
||||
prefix_in_container = '/test-prefix/'
|
||||
```
|
||||
|
||||
`AZURE_STORAGE_ACCOUNT` and `AZURE_STORAGE_ACCESS_KEY` env variables can be used to specify the azure credentials if needed.
|
||||
|
||||
## Repository background tasks
|
||||
|
||||
|
||||
108
docs/updating-postgres.md
Normal file
108
docs/updating-postgres.md
Normal file
@@ -0,0 +1,108 @@
|
||||
# Updating Postgres
|
||||
|
||||
## Minor Versions
|
||||
|
||||
When upgrading to a new minor version of Postgres, please follow these steps:
|
||||
|
||||
_Example: 15.4 is the new minor version to upgrade to from 15.3._
|
||||
|
||||
1. Clone the Neon Postgres repository if you have not done so already.
|
||||
|
||||
```shell
|
||||
git clone git@github.com:neondatabase/postgres.git
|
||||
```
|
||||
|
||||
1. Add the Postgres upstream remote.
|
||||
|
||||
```shell
|
||||
git remote add upstream https://git.postgresql.org/git/postgresql.git
|
||||
```
|
||||
|
||||
1. Create a new branch based on the stable branch you are updating.
|
||||
|
||||
```shell
|
||||
git checkout -b my-branch REL_15_STABLE_neon
|
||||
```
|
||||
|
||||
1. Tag the last commit on the stable branch you are updating.
|
||||
|
||||
```shell
|
||||
git tag REL_15_3_neon
|
||||
```
|
||||
|
||||
1. Push the new tag to the Neon Postgres repository.
|
||||
|
||||
```shell
|
||||
git push origin REL_15_3_neon
|
||||
```
|
||||
|
||||
1. Find the release tags you're looking for. They are of the form `REL_X_Y`.
|
||||
|
||||
1. Rebase the branch you created on the tag and resolve any conflicts.
|
||||
|
||||
```shell
|
||||
git fetch upstream REL_15_4
|
||||
git rebase REL_15_4
|
||||
```
|
||||
|
||||
1. Run the Postgres test suite to make sure our commits have not affected
|
||||
Postgres in a negative way.
|
||||
|
||||
```shell
|
||||
make check
|
||||
# OR
|
||||
meson test -C builddir
|
||||
```
|
||||
|
||||
1. Push your branch to the Neon Postgres repository.
|
||||
|
||||
```shell
|
||||
git push origin my-branch
|
||||
```
|
||||
|
||||
1. Clone the Neon repository if you have not done so already.
|
||||
|
||||
```shell
|
||||
git clone git@github.com:neondatabase/neon.git
|
||||
```
|
||||
|
||||
1. Create a new branch.
|
||||
|
||||
1. Change the `revisions.json` file to point at the HEAD of your Postgres
|
||||
branch.
|
||||
|
||||
1. Update the Git submodule.
|
||||
|
||||
```shell
|
||||
git submodule set-branch --branch my-branch vendor/postgres-v15
|
||||
git submodule update --remote vendor/postgres-v15
|
||||
```
|
||||
|
||||
1. Run the Neon test suite to make sure that Neon is still good to go on this
|
||||
minor Postgres release.
|
||||
|
||||
```shell
|
||||
./scripts/poetry -k pg15
|
||||
```
|
||||
|
||||
1. Commit your changes.
|
||||
|
||||
1. Create a pull request, and wait for CI to go green.
|
||||
|
||||
1. Force push the rebased Postgres branches into the Neon Postgres repository.
|
||||
|
||||
```shell
|
||||
git push --force origin my-branch:REL_15_STABLE_neon
|
||||
```
|
||||
|
||||
It may require disabling various branch protections.
|
||||
|
||||
1. Update your Neon PR to point at the branches.
|
||||
|
||||
```shell
|
||||
git submodule set-branch --branch REL_15_STABLE_neon vendor/postgres-v15
|
||||
git commit --amend --no-edit
|
||||
git push --force origin
|
||||
```
|
||||
|
||||
1. Merge the pull request after getting approval(s) and CI completion.
|
||||
@@ -1,3 +1,5 @@
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
pub mod requests;
|
||||
pub mod responses;
|
||||
pub mod spec;
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
use utils::id::{TenantId, TimelineId};
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
@@ -19,7 +18,6 @@ pub type PgIdent = String;
|
||||
|
||||
/// Cluster spec or configuration represented as an optional number of
|
||||
/// delta operations + final cluster state description.
|
||||
#[serde_as]
|
||||
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
|
||||
pub struct ComputeSpec {
|
||||
pub format_version: f32,
|
||||
@@ -50,12 +48,12 @@ pub struct ComputeSpec {
|
||||
// these, and instead set the "neon.tenant_id", "neon.timeline_id",
|
||||
// etc. GUCs in cluster.settings. TODO: Once the control plane has been
|
||||
// updated to fill these fields, we can make these non optional.
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
pub tenant_id: Option<TenantId>,
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
|
||||
pub timeline_id: Option<TimelineId>,
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
|
||||
pub pageserver_connstring: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub safekeeper_connstrings: Vec<String>,
|
||||
|
||||
@@ -140,14 +138,13 @@ impl RemoteExtSpec {
|
||||
}
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
|
||||
pub enum ComputeMode {
|
||||
/// A read-write node
|
||||
#[default]
|
||||
Primary,
|
||||
/// A read-only node, pinned at a particular LSN
|
||||
Static(#[serde_as(as = "DisplayFromStr")] Lsn),
|
||||
Static(Lsn),
|
||||
/// A read-only node that follows the tip of the branch in hot standby mode
|
||||
///
|
||||
/// Future versions may want to distinguish between replicas with hot standby
|
||||
@@ -190,6 +187,8 @@ pub struct DeltaOp {
|
||||
pub struct Role {
|
||||
pub name: PgIdent,
|
||||
pub encrypted_password: Option<String>,
|
||||
pub replication: Option<bool>,
|
||||
pub bypassrls: Option<bool>,
|
||||
pub options: GenericOptions,
|
||||
}
|
||||
|
||||
@@ -200,6 +199,12 @@ pub struct Database {
|
||||
pub name: PgIdent,
|
||||
pub owner: PgIdent,
|
||||
pub options: GenericOptions,
|
||||
// These are derived flags, not present in the spec file.
|
||||
// They are never set by the control plane.
|
||||
#[serde(skip_deserializing, default)]
|
||||
pub restrict_conn: bool,
|
||||
#[serde(skip_deserializing, default)]
|
||||
pub invalid: bool,
|
||||
}
|
||||
|
||||
/// Common type representing both SQL statement params with or without value,
|
||||
|
||||
@@ -76,7 +76,7 @@
|
||||
},
|
||||
{
|
||||
"name": "wal_level",
|
||||
"value": "replica",
|
||||
"value": "logical",
|
||||
"vartype": "enum"
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
//!
|
||||
//! Shared code for consumption metics collection
|
||||
//!
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
use chrono::{DateTime, Utc};
|
||||
use rand::Rng;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
//! make sure that we use the same dep version everywhere.
|
||||
//! Otherwise, we might not see all metrics registered via
|
||||
//! a default registry.
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
use once_cell::sync::Lazy;
|
||||
use prometheus::core::{AtomicU64, Collector, GenericGauge, GenericGaugeVec};
|
||||
pub use prometheus::opts;
|
||||
@@ -89,14 +90,14 @@ pub const DISK_WRITE_SECONDS_BUCKETS: &[f64] = &[
|
||||
0.000_050, 0.000_100, 0.000_500, 0.001, 0.003, 0.005, 0.01, 0.05, 0.1, 0.3, 0.5,
|
||||
];
|
||||
|
||||
pub fn set_build_info_metric(revision: &str) {
|
||||
pub fn set_build_info_metric(revision: &str, build_tag: &str) {
|
||||
let metric = register_int_gauge_vec!(
|
||||
"libmetrics_build_info",
|
||||
"Build/version information",
|
||||
&["revision"]
|
||||
&["revision", "build_tag"]
|
||||
)
|
||||
.expect("Failed to register build info metric");
|
||||
metric.with_label_values(&[revision]).set(1);
|
||||
metric.with_label_values(&[revision, build_tag]).set(1);
|
||||
}
|
||||
|
||||
// Records I/O stats in a "cross-platform" way.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::io::{Read, Result, Write};
|
||||
|
||||
/// A wrapper for an object implementing [Read](std::io::Read)
|
||||
/// A wrapper for an object implementing [Read]
|
||||
/// which allows a closure to observe the amount of bytes read.
|
||||
/// This is useful in conjunction with metrics (e.g. [IntCounter](crate::IntCounter)).
|
||||
///
|
||||
@@ -51,17 +51,17 @@ impl<'a, T> CountedReader<'a, T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get an immutable reference to the underlying [Read](std::io::Read) implementor
|
||||
/// Get an immutable reference to the underlying [Read] implementor
|
||||
pub fn inner(&self) -> &T {
|
||||
&self.reader
|
||||
}
|
||||
|
||||
/// Get a mutable reference to the underlying [Read](std::io::Read) implementor
|
||||
/// Get a mutable reference to the underlying [Read] implementor
|
||||
pub fn inner_mut(&mut self) -> &mut T {
|
||||
&mut self.reader
|
||||
}
|
||||
|
||||
/// Consume the wrapper and return the underlying [Read](std::io::Read) implementor
|
||||
/// Consume the wrapper and return the underlying [Read] implementor
|
||||
pub fn into_inner(self) -> T {
|
||||
self.reader
|
||||
}
|
||||
@@ -75,7 +75,7 @@ impl<T: Read> Read for CountedReader<'_, T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper for an object implementing [Write](std::io::Write)
|
||||
/// A wrapper for an object implementing [Write]
|
||||
/// which allows a closure to observe the amount of bytes written.
|
||||
/// This is useful in conjunction with metrics (e.g. [IntCounter](crate::IntCounter)).
|
||||
///
|
||||
@@ -122,17 +122,17 @@ impl<'a, T> CountedWriter<'a, T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get an immutable reference to the underlying [Write](std::io::Write) implementor
|
||||
/// Get an immutable reference to the underlying [Write] implementor
|
||||
pub fn inner(&self) -> &T {
|
||||
&self.writer
|
||||
}
|
||||
|
||||
/// Get a mutable reference to the underlying [Write](std::io::Write) implementor
|
||||
/// Get a mutable reference to the underlying [Write] implementor
|
||||
pub fn inner_mut(&mut self) -> &mut T {
|
||||
&mut self.writer
|
||||
}
|
||||
|
||||
/// Consume the wrapper and return the underlying [Write](std::io::Write) implementor
|
||||
/// Consume the wrapper and return the underlying [Write] implementor
|
||||
pub fn into_inner(self) -> T {
|
||||
self.writer
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
//! See docs/rfcs/025-generation-numbers.md
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
use utils::id::{NodeId, TenantId};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
@@ -12,12 +11,10 @@ pub struct ReAttachRequest {
|
||||
pub node_id: NodeId,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ReAttachResponseTenant {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub id: TenantId,
|
||||
pub generation: u32,
|
||||
pub gen: u32,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
@@ -25,10 +22,8 @@ pub struct ReAttachResponse {
|
||||
pub tenants: Vec<ReAttachResponseTenant>,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ValidateRequestTenant {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub id: TenantId,
|
||||
pub gen: u32,
|
||||
}
|
||||
@@ -43,10 +38,8 @@ pub struct ValidateResponse {
|
||||
pub tenants: Vec<ValidateResponseTenant>,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ValidateResponseTenant {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub id: TenantId,
|
||||
pub valid: bool,
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
use const_format::formatcp;
|
||||
|
||||
/// Public API types
|
||||
|
||||
@@ -6,7 +6,7 @@ use std::{
|
||||
|
||||
use byteorder::{BigEndian, ReadBytesExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
use serde_with::serde_as;
|
||||
use strum_macros;
|
||||
use utils::{
|
||||
completion,
|
||||
@@ -110,7 +110,6 @@ impl TenantState {
|
||||
// So, return `Maybe` while Attaching, making Console wait for the attach task to finish.
|
||||
Self::Attaching | Self::Activating(ActivatingFrom::Attaching) => Maybe,
|
||||
// tenant mgr startup distinguishes attaching from loading via marker file.
|
||||
// If it's loading, there is no attach marker file, i.e., attach had finished in the past.
|
||||
Self::Loading | Self::Activating(ActivatingFrom::Loading) => Attached,
|
||||
// We only reach Active after successful load / attach.
|
||||
// So, call atttachment status Attached.
|
||||
@@ -175,25 +174,19 @@ pub enum TimelineState {
|
||||
Broken { reason: String, backtrace: String },
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct TimelineCreateRequest {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub new_timeline_id: TimelineId,
|
||||
#[serde(default)]
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
pub ancestor_timeline_id: Option<TimelineId>,
|
||||
#[serde(default)]
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
pub ancestor_start_lsn: Option<Lsn>,
|
||||
pub pg_version: Option<u32>,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct TenantCreateRequest {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub new_tenant_id: TenantId,
|
||||
#[serde(default)]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
@@ -202,7 +195,6 @@ pub struct TenantCreateRequest {
|
||||
pub config: TenantConfig, // as we have a flattened field, we should reject all unknown fields in it
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct TenantLoadRequest {
|
||||
@@ -279,31 +271,26 @@ pub struct LocationConfig {
|
||||
pub tenant_conf: TenantConfig,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct TenantCreateResponse(#[serde_as(as = "DisplayFromStr")] pub TenantId);
|
||||
pub struct TenantCreateResponse(pub TenantId);
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct StatusResponse {
|
||||
pub id: NodeId,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct TenantLocationConfigRequest {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub tenant_id: TenantId,
|
||||
#[serde(flatten)]
|
||||
pub config: LocationConfig, // as we have a flattened field, we should reject all unknown fields in it
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct TenantConfigRequest {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub tenant_id: TenantId,
|
||||
#[serde(flatten)]
|
||||
pub config: TenantConfig, // as we have a flattened field, we should reject all unknown fields in it
|
||||
@@ -375,10 +362,8 @@ pub enum TenantAttachmentStatus {
|
||||
Failed { reason: String },
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, Clone)]
|
||||
pub struct TenantInfo {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub id: TenantId,
|
||||
// NB: intentionally not part of OpenAPI, we don't want to commit to a specific set of TenantState's
|
||||
pub state: TenantState,
|
||||
@@ -389,33 +374,22 @@ pub struct TenantInfo {
|
||||
}
|
||||
|
||||
/// This represents the output of the "timeline_detail" and "timeline_list" API calls.
|
||||
#[serde_as]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct TimelineInfo {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub tenant_id: TenantId,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub timeline_id: TimelineId,
|
||||
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
pub ancestor_timeline_id: Option<TimelineId>,
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
pub ancestor_lsn: Option<Lsn>,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub last_record_lsn: Lsn,
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
pub prev_record_lsn: Option<Lsn>,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub latest_gc_cutoff_lsn: Lsn,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub disk_consistent_lsn: Lsn,
|
||||
|
||||
/// The LSN that we have succesfully uploaded to remote storage
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub remote_consistent_lsn: Lsn,
|
||||
|
||||
/// The LSN that we are advertizing to safekeepers
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub remote_consistent_lsn_visible: Lsn,
|
||||
|
||||
pub current_logical_size: Option<u64>, // is None when timeline is Unloaded
|
||||
@@ -427,7 +401,6 @@ pub struct TimelineInfo {
|
||||
pub timeline_dir_layer_file_size_sum: Option<u64>,
|
||||
|
||||
pub wal_source_connstr: Option<String>,
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
pub last_received_msg_lsn: Option<Lsn>,
|
||||
/// the timestamp (in microseconds) of the last received message
|
||||
pub last_received_msg_ts: Option<u128>,
|
||||
@@ -524,23 +497,13 @@ pub struct LayerAccessStats {
|
||||
pub residence_events_history: HistoryBufferWithDropCounter<LayerResidenceEvent, 16>,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(tag = "kind")]
|
||||
pub enum InMemoryLayerInfo {
|
||||
Open {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
lsn_start: Lsn,
|
||||
},
|
||||
Frozen {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
lsn_start: Lsn,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
lsn_end: Lsn,
|
||||
},
|
||||
Open { lsn_start: Lsn },
|
||||
Frozen { lsn_start: Lsn, lsn_end: Lsn },
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
#[serde(tag = "kind")]
|
||||
pub enum HistoricLayerInfo {
|
||||
@@ -548,9 +511,7 @@ pub enum HistoricLayerInfo {
|
||||
layer_file_name: String,
|
||||
layer_file_size: u64,
|
||||
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
lsn_start: Lsn,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
lsn_end: Lsn,
|
||||
remote: bool,
|
||||
access_stats: LayerAccessStats,
|
||||
@@ -559,7 +520,6 @@ pub enum HistoricLayerInfo {
|
||||
layer_file_name: String,
|
||||
layer_file_size: u64,
|
||||
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
lsn_start: Lsn,
|
||||
remote: bool,
|
||||
access_stats: LayerAccessStats,
|
||||
|
||||
@@ -22,9 +22,9 @@ use postgres_ffi::Oid;
|
||||
/// [See more related comments here](https:///github.com/postgres/postgres/blob/99c5852e20a0987eca1c38ba0c09329d4076b6a0/src/include/storage/relfilenode.h#L57).
|
||||
///
|
||||
// FIXME: should move 'forknum' as last field to keep this consistent with Postgres.
|
||||
// Then we could replace the custo Ord and PartialOrd implementations below with
|
||||
// deriving them.
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize, Deserialize)]
|
||||
// Then we could replace the custom Ord and PartialOrd implementations below with
|
||||
// deriving them. This will require changes in walredoproc.c.
|
||||
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Serialize)]
|
||||
pub struct RelTag {
|
||||
pub forknum: u8,
|
||||
pub spcnode: Oid,
|
||||
@@ -40,21 +40,9 @@ impl PartialOrd for RelTag {
|
||||
|
||||
impl Ord for RelTag {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
let mut cmp = self.spcnode.cmp(&other.spcnode);
|
||||
if cmp != Ordering::Equal {
|
||||
return cmp;
|
||||
}
|
||||
cmp = self.dbnode.cmp(&other.dbnode);
|
||||
if cmp != Ordering::Equal {
|
||||
return cmp;
|
||||
}
|
||||
cmp = self.relnode.cmp(&other.relnode);
|
||||
if cmp != Ordering::Equal {
|
||||
return cmp;
|
||||
}
|
||||
cmp = self.forknum.cmp(&other.forknum);
|
||||
|
||||
cmp
|
||||
// Custom ordering where we put forknum to the end of the list
|
||||
let other_tup = (other.spcnode, other.dbnode, other.relnode, other.forknum);
|
||||
(self.spcnode, self.dbnode, self.relnode, self.forknum).cmp(&other_tup)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
//! To use, create PostgresBackend and run() it, passing the Handler
|
||||
//! implementation determining how to process the queries. Currently its API
|
||||
//! is rather narrow, but we can extend it once required.
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
use anyhow::Context;
|
||||
use bytes::Bytes;
|
||||
use futures::pin_mut;
|
||||
@@ -15,12 +17,12 @@ use std::{fmt, io};
|
||||
use std::{future::Future, str::FromStr};
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tracing::{debug, error, info, trace};
|
||||
use tracing::{debug, error, info, trace, warn};
|
||||
|
||||
use pq_proto::framed::{ConnectionError, Framed, FramedReader, FramedWriter};
|
||||
use pq_proto::{
|
||||
BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_INTERNAL_ERROR,
|
||||
SQLSTATE_SUCCESSFUL_COMPLETION,
|
||||
BeMessage, FeMessage, FeStartupPacket, ProtocolError, SQLSTATE_ADMIN_SHUTDOWN,
|
||||
SQLSTATE_INTERNAL_ERROR, SQLSTATE_SUCCESSFUL_COMPLETION,
|
||||
};
|
||||
|
||||
/// An error, occurred during query processing:
|
||||
@@ -30,6 +32,14 @@ pub enum QueryError {
|
||||
/// The connection was lost while processing the query.
|
||||
#[error(transparent)]
|
||||
Disconnected(#[from] ConnectionError),
|
||||
/// We were instructed to shutdown while processing the query
|
||||
#[error("Shutting down")]
|
||||
Shutdown,
|
||||
/// Authentication failure
|
||||
#[error("Unauthorized: {0}")]
|
||||
Unauthorized(std::borrow::Cow<'static, str>),
|
||||
#[error("Simulated Connection Error")]
|
||||
SimulatedConnectionError,
|
||||
/// Some other error
|
||||
#[error(transparent)]
|
||||
Other(#[from] anyhow::Error),
|
||||
@@ -44,7 +54,9 @@ impl From<io::Error> for QueryError {
|
||||
impl QueryError {
|
||||
pub fn pg_error_code(&self) -> &'static [u8; 5] {
|
||||
match self {
|
||||
Self::Disconnected(_) => b"08006", // connection failure
|
||||
Self::Disconnected(_) | Self::SimulatedConnectionError => b"08006", // connection failure
|
||||
Self::Shutdown => SQLSTATE_ADMIN_SHUTDOWN,
|
||||
Self::Unauthorized(_) => SQLSTATE_INTERNAL_ERROR,
|
||||
Self::Other(_) => SQLSTATE_INTERNAL_ERROR, // internal error
|
||||
}
|
||||
}
|
||||
@@ -238,6 +250,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> MaybeWriteOnly<IO> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Cancellation safe as long as the underlying IO is cancellation safe.
|
||||
async fn shutdown(&mut self) -> io::Result<()> {
|
||||
match self {
|
||||
MaybeWriteOnly::Full(framed) => framed.shutdown().await,
|
||||
@@ -389,14 +402,37 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
|
||||
shutdown_watcher: F,
|
||||
) -> Result<(), QueryError>
|
||||
where
|
||||
F: Fn() -> S,
|
||||
F: Fn() -> S + Clone,
|
||||
S: Future,
|
||||
{
|
||||
let ret = self.run_message_loop(handler, shutdown_watcher).await;
|
||||
// socket might be already closed, e.g. if previously received error,
|
||||
// so ignore result.
|
||||
self.framed.shutdown().await.ok();
|
||||
ret
|
||||
let ret = self
|
||||
.run_message_loop(handler, shutdown_watcher.clone())
|
||||
.await;
|
||||
|
||||
tokio::select! {
|
||||
_ = shutdown_watcher() => {
|
||||
// do nothing; we most likely got already stopped by shutdown and will log it next.
|
||||
}
|
||||
_ = self.framed.shutdown() => {
|
||||
// socket might be already closed, e.g. if previously received error,
|
||||
// so ignore result.
|
||||
},
|
||||
}
|
||||
|
||||
match ret {
|
||||
Ok(()) => Ok(()),
|
||||
Err(QueryError::Shutdown) => {
|
||||
info!("Stopped due to shutdown");
|
||||
Ok(())
|
||||
}
|
||||
Err(QueryError::Disconnected(e)) => {
|
||||
info!("Disconnected ({e:#})");
|
||||
// Disconnection is not an error: we just use it that way internally to drop
|
||||
// out of loops.
|
||||
Ok(())
|
||||
}
|
||||
e => e,
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_message_loop<F, S>(
|
||||
@@ -416,15 +452,11 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
|
||||
_ = shutdown_watcher() => {
|
||||
// We were requested to shut down.
|
||||
tracing::info!("shutdown request received during handshake");
|
||||
return Ok(())
|
||||
return Err(QueryError::Shutdown)
|
||||
},
|
||||
|
||||
result = self.handshake(handler) => {
|
||||
// Handshake complete.
|
||||
result?;
|
||||
if self.state == ProtoState::Closed {
|
||||
return Ok(()); // EOF during handshake
|
||||
}
|
||||
handshake_r = self.handshake(handler) => {
|
||||
handshake_r?;
|
||||
}
|
||||
);
|
||||
|
||||
@@ -435,17 +467,34 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
|
||||
_ = shutdown_watcher() => {
|
||||
// We were requested to shut down.
|
||||
tracing::info!("shutdown request received in run_message_loop");
|
||||
Ok(None)
|
||||
return Err(QueryError::Shutdown)
|
||||
},
|
||||
msg = self.read_message() => { msg },
|
||||
)? {
|
||||
trace!("got message {:?}", msg);
|
||||
|
||||
let result = self.process_message(handler, msg, &mut query_string).await;
|
||||
self.flush().await?;
|
||||
tokio::select!(
|
||||
biased;
|
||||
_ = shutdown_watcher() => {
|
||||
// We were requested to shut down.
|
||||
tracing::info!("shutdown request received during response flush");
|
||||
|
||||
// If we exited process_message with a shutdown error, there may be
|
||||
// some valid response content on in our transmit buffer: permit sending
|
||||
// this within a short timeout. This is a best effort thing so we don't
|
||||
// care about the result.
|
||||
tokio::time::timeout(std::time::Duration::from_millis(500), self.flush()).await.ok();
|
||||
|
||||
return Err(QueryError::Shutdown)
|
||||
},
|
||||
flush_r = self.flush() => {
|
||||
flush_r?;
|
||||
}
|
||||
);
|
||||
|
||||
match result? {
|
||||
ProcessMsgResult::Continue => {
|
||||
self.flush().await?;
|
||||
continue;
|
||||
}
|
||||
ProcessMsgResult::Break => break,
|
||||
@@ -550,7 +599,9 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
|
||||
self.peer_addr
|
||||
);
|
||||
self.state = ProtoState::Closed;
|
||||
return Ok(());
|
||||
return Err(QueryError::Disconnected(ConnectionError::Protocol(
|
||||
ProtocolError::Protocol("EOF during handshake".to_string()),
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -565,7 +616,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
|
||||
|
||||
if let Err(e) = handler.check_auth_jwt(self, jwt_response) {
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
&e.to_string(),
|
||||
&short_error(&e),
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
return Err(e);
|
||||
@@ -589,7 +640,9 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
|
||||
self.peer_addr
|
||||
);
|
||||
self.state = ProtoState::Closed;
|
||||
return Ok(());
|
||||
return Err(QueryError::Disconnected(ConnectionError::Protocol(
|
||||
ProtocolError::Protocol("EOF during auth".to_string()),
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -683,12 +736,20 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
|
||||
|
||||
trace!("got query {query_string:?}");
|
||||
if let Err(e) = handler.process_query(self, query_string).await {
|
||||
log_query_error(query_string, &e);
|
||||
let short_error = short_error(&e);
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
&short_error,
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
match e {
|
||||
QueryError::Shutdown => return Ok(ProcessMsgResult::Break),
|
||||
QueryError::SimulatedConnectionError => {
|
||||
return Err(QueryError::SimulatedConnectionError)
|
||||
}
|
||||
e => {
|
||||
log_query_error(query_string, &e);
|
||||
let short_error = short_error(&e);
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
&short_error,
|
||||
Some(e.pg_error_code()),
|
||||
))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
self.write_message_noflush(&BeMessage::ReadyForQuery)?;
|
||||
}
|
||||
@@ -913,6 +974,9 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin> AsyncWrite for CopyDataWriter<'a, I
|
||||
pub fn short_error(e: &QueryError) -> String {
|
||||
match e {
|
||||
QueryError::Disconnected(connection_error) => connection_error.to_string(),
|
||||
QueryError::Shutdown => "shutdown".to_string(),
|
||||
QueryError::Unauthorized(_e) => "JWT authentication error".to_string(),
|
||||
QueryError::SimulatedConnectionError => "simulated connection error".to_string(),
|
||||
QueryError::Other(e) => format!("{e:#}"),
|
||||
}
|
||||
}
|
||||
@@ -929,6 +993,15 @@ fn log_query_error(query: &str, e: &QueryError) {
|
||||
QueryError::Disconnected(other_connection_error) => {
|
||||
error!("query handler for '{query}' failed with connection error: {other_connection_error:?}")
|
||||
}
|
||||
QueryError::SimulatedConnectionError => {
|
||||
error!("query handler for query '{query}' failed due to a simulated connection error")
|
||||
}
|
||||
QueryError::Shutdown => {
|
||||
info!("query handler for '{query}' cancelled during tenant shutdown")
|
||||
}
|
||||
QueryError::Unauthorized(e) => {
|
||||
warn!("query handler for '{query}' failed with authentication error: {e}");
|
||||
}
|
||||
QueryError::Other(e) => {
|
||||
error!("query handler for '{query}' failed: {e:?}");
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
use anyhow::{bail, Context};
|
||||
use itertools::Itertools;
|
||||
use std::borrow::Cow;
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
// modules included with the postgres_ffi macro depend on the types of the specific version's
|
||||
// types, and trigger a too eager lint.
|
||||
#![allow(clippy::duplicate_mod)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
|
||||
use bytes::Bytes;
|
||||
use utils::bin_ser::SerializeError;
|
||||
@@ -20,6 +21,7 @@ macro_rules! postgres_ffi {
|
||||
pub mod bindings {
|
||||
// bindgen generates bindings for a lot of stuff we don't need
|
||||
#![allow(dead_code)]
|
||||
#![allow(clippy::undocumented_unsafe_blocks)]
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
include!(concat!(
|
||||
@@ -131,6 +133,7 @@ pub const MAX_SEND_SIZE: usize = XLOG_BLCKSZ * 16;
|
||||
|
||||
// Export some version independent functions that are used outside of this mod
|
||||
pub use v14::xlog_utils::encode_logical_message;
|
||||
pub use v14::xlog_utils::from_pg_timestamp;
|
||||
pub use v14::xlog_utils::get_current_timestamp;
|
||||
pub use v14::xlog_utils::to_pg_timestamp;
|
||||
pub use v14::xlog_utils::XLogFileName;
|
||||
|
||||
@@ -220,6 +220,10 @@ pub const XLOG_CHECKPOINT_ONLINE: u8 = 0x10;
|
||||
pub const XLP_FIRST_IS_CONTRECORD: u16 = 0x0001;
|
||||
pub const XLP_LONG_HEADER: u16 = 0x0002;
|
||||
|
||||
/* From replication/slot.h */
|
||||
pub const REPL_SLOT_ON_DISK_OFFSETOF_RESTART_LSN: usize = 4*4 /* offset of `slotdata` in ReplicationSlotOnDisk */
|
||||
+ 64 /* NameData */ + 4*4;
|
||||
|
||||
/* From fsm_internals.h */
|
||||
const FSM_NODES_PER_PAGE: usize = BLCKSZ as usize - SIZEOF_PAGE_HEADER_DATA - 4;
|
||||
const FSM_NON_LEAF_NODES_PER_PAGE: usize = BLCKSZ as usize / 2 - 1;
|
||||
|
||||
@@ -136,21 +136,42 @@ pub fn get_current_timestamp() -> TimestampTz {
|
||||
to_pg_timestamp(SystemTime::now())
|
||||
}
|
||||
|
||||
pub fn to_pg_timestamp(time: SystemTime) -> TimestampTz {
|
||||
const UNIX_EPOCH_JDATE: u64 = 2440588; /* == date2j(1970, 1, 1) */
|
||||
const POSTGRES_EPOCH_JDATE: u64 = 2451545; /* == date2j(2000, 1, 1) */
|
||||
// Module to reduce the scope of the constants
|
||||
mod timestamp_conversions {
|
||||
use std::time::Duration;
|
||||
|
||||
use super::*;
|
||||
|
||||
const UNIX_EPOCH_JDATE: u64 = 2440588; // == date2j(1970, 1, 1)
|
||||
const POSTGRES_EPOCH_JDATE: u64 = 2451545; // == date2j(2000, 1, 1)
|
||||
const SECS_PER_DAY: u64 = 86400;
|
||||
const USECS_PER_SEC: u64 = 1000000;
|
||||
match time.duration_since(SystemTime::UNIX_EPOCH) {
|
||||
Ok(n) => {
|
||||
((n.as_secs() - ((POSTGRES_EPOCH_JDATE - UNIX_EPOCH_JDATE) * SECS_PER_DAY))
|
||||
* USECS_PER_SEC
|
||||
+ n.subsec_micros() as u64) as i64
|
||||
const SECS_DIFF_UNIX_TO_POSTGRES_EPOCH: u64 =
|
||||
(POSTGRES_EPOCH_JDATE - UNIX_EPOCH_JDATE) * SECS_PER_DAY;
|
||||
|
||||
pub fn to_pg_timestamp(time: SystemTime) -> TimestampTz {
|
||||
match time.duration_since(SystemTime::UNIX_EPOCH) {
|
||||
Ok(n) => {
|
||||
((n.as_secs() - SECS_DIFF_UNIX_TO_POSTGRES_EPOCH) * USECS_PER_SEC
|
||||
+ n.subsec_micros() as u64) as i64
|
||||
}
|
||||
Err(_) => panic!("SystemTime before UNIX EPOCH!"),
|
||||
}
|
||||
Err(_) => panic!("SystemTime before UNIX EPOCH!"),
|
||||
}
|
||||
|
||||
pub fn from_pg_timestamp(time: TimestampTz) -> SystemTime {
|
||||
let time: u64 = time
|
||||
.try_into()
|
||||
.expect("timestamp before millenium (postgres epoch)");
|
||||
let since_unix_epoch = time + SECS_DIFF_UNIX_TO_POSTGRES_EPOCH * USECS_PER_SEC;
|
||||
SystemTime::UNIX_EPOCH
|
||||
.checked_add(Duration::from_micros(since_unix_epoch))
|
||||
.expect("SystemTime overflow")
|
||||
}
|
||||
}
|
||||
|
||||
pub use timestamp_conversions::{from_pg_timestamp, to_pg_timestamp};
|
||||
|
||||
// Returns (aligned) end_lsn of the last record in data_dir with WAL segments.
|
||||
// start_lsn must point to some previously known record boundary (beginning of
|
||||
// the next record). If no valid record after is found, start_lsn is returned
|
||||
@@ -481,4 +502,24 @@ pub fn encode_logical_message(prefix: &str, message: &str) -> Vec<u8> {
|
||||
wal
|
||||
}
|
||||
|
||||
// If you need to craft WAL and write tests for this module, put it at wal_craft crate.
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_ts_conversion() {
|
||||
let now = SystemTime::now();
|
||||
let round_trip = from_pg_timestamp(to_pg_timestamp(now));
|
||||
|
||||
let now_since = now.duration_since(SystemTime::UNIX_EPOCH).unwrap();
|
||||
let round_trip_since = round_trip.duration_since(SystemTime::UNIX_EPOCH).unwrap();
|
||||
assert_eq!(now_since.as_micros(), round_trip_since.as_micros());
|
||||
|
||||
let now_pg = get_current_timestamp();
|
||||
let round_trip_pg = to_pg_timestamp(from_pg_timestamp(now_pg));
|
||||
|
||||
assert_eq!(now_pg, round_trip_pg);
|
||||
}
|
||||
|
||||
// If you need to craft WAL and write tests for this module, put it at wal_craft crate.
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ macro_rules! xlog_utils_test {
|
||||
($version:ident) => {
|
||||
#[path = "."]
|
||||
mod $version {
|
||||
#[allow(unused_imports)]
|
||||
pub use postgres_ffi::$version::wal_craft_test_export::*;
|
||||
#[allow(clippy::duplicate_mod)]
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -214,27 +214,24 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// Cancellation safe as long as the AsyncWrite is cancellation safe.
|
||||
async fn flush<S: AsyncWrite + Unpin>(
|
||||
stream: &mut S,
|
||||
write_buf: &mut BytesMut,
|
||||
) -> Result<(), io::Error> {
|
||||
while write_buf.has_remaining() {
|
||||
let bytes_written = stream.write(write_buf.chunk()).await?;
|
||||
let bytes_written = stream.write_buf(write_buf).await?;
|
||||
if bytes_written == 0 {
|
||||
return Err(io::Error::new(
|
||||
ErrorKind::WriteZero,
|
||||
"failed to write message",
|
||||
));
|
||||
}
|
||||
// The advanced part will be garbage collected, likely during shifting
|
||||
// data left on next attempt to write to buffer when free space is not
|
||||
// enough.
|
||||
write_buf.advance(bytes_written);
|
||||
}
|
||||
write_buf.clear();
|
||||
stream.flush().await
|
||||
}
|
||||
|
||||
/// Cancellation safe as long as the AsyncWrite is cancellation safe.
|
||||
async fn shutdown<S: AsyncWrite + Unpin>(
|
||||
stream: &mut S,
|
||||
write_buf: &mut BytesMut,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
//! Postgres protocol messages serialization-deserialization. See
|
||||
//! <https://www.postgresql.org/docs/devel/protocol-message-formats.html>
|
||||
//! on message formats.
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
|
||||
pub mod framed;
|
||||
|
||||
@@ -670,6 +671,7 @@ pub fn read_cstr(buf: &mut Bytes) -> Result<Bytes, ProtocolError> {
|
||||
}
|
||||
|
||||
pub const SQLSTATE_INTERNAL_ERROR: &[u8; 5] = b"XX000";
|
||||
pub const SQLSTATE_ADMIN_SHUTDOWN: &[u8; 5] = b"57P01";
|
||||
pub const SQLSTATE_SUCCESSFUL_COMPLETION: &[u8; 5] = b"00000";
|
||||
|
||||
impl<'a> BeMessage<'a> {
|
||||
|
||||
@@ -8,11 +8,13 @@ license.workspace = true
|
||||
anyhow.workspace = true
|
||||
async-trait.workspace = true
|
||||
once_cell.workspace = true
|
||||
aws-smithy-async.workspace = true
|
||||
aws-smithy-http.workspace = true
|
||||
aws-types.workspace = true
|
||||
aws-config.workspace = true
|
||||
aws-sdk-s3.workspace = true
|
||||
aws-credential-types.workspace = true
|
||||
bytes.workspace = true
|
||||
camino.workspace = true
|
||||
hyper = { workspace = true, features = ["stream"] }
|
||||
serde.workspace = true
|
||||
@@ -26,6 +28,13 @@ metrics.workspace = true
|
||||
utils.workspace = true
|
||||
pin-project-lite.workspace = true
|
||||
workspace_hack.workspace = true
|
||||
azure_core.workspace = true
|
||||
azure_identity.workspace = true
|
||||
azure_storage.workspace = true
|
||||
azure_storage_blobs.workspace = true
|
||||
futures-util.workspace = true
|
||||
http-types.workspace = true
|
||||
itertools.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
camino-tempfile.workspace = true
|
||||
|
||||
314
libs/remote_storage/src/azure_blob.rs
Normal file
314
libs/remote_storage/src/azure_blob.rs
Normal file
@@ -0,0 +1,314 @@
|
||||
//! Azure Blob Storage wrapper
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::num::NonZeroU32;
|
||||
use std::sync::Arc;
|
||||
use std::{borrow::Cow, io::Cursor};
|
||||
|
||||
use super::REMOTE_STORAGE_PREFIX_SEPARATOR;
|
||||
use anyhow::Result;
|
||||
use azure_core::request_options::{MaxResults, Metadata, Range};
|
||||
use azure_identity::DefaultAzureCredential;
|
||||
use azure_storage::StorageCredentials;
|
||||
use azure_storage_blobs::prelude::ClientBuilder;
|
||||
use azure_storage_blobs::{blob::operations::GetBlobBuilder, prelude::ContainerClient};
|
||||
use futures_util::StreamExt;
|
||||
use http_types::StatusCode;
|
||||
use tokio::io::AsyncRead;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::s3_bucket::RequestKind;
|
||||
use crate::{
|
||||
AzureConfig, ConcurrencyLimiter, Download, DownloadError, Listing, ListingMode, RemotePath,
|
||||
RemoteStorage, StorageMetadata,
|
||||
};
|
||||
|
||||
pub struct AzureBlobStorage {
|
||||
client: ContainerClient,
|
||||
prefix_in_container: Option<String>,
|
||||
max_keys_per_list_response: Option<NonZeroU32>,
|
||||
concurrency_limiter: ConcurrencyLimiter,
|
||||
}
|
||||
|
||||
impl AzureBlobStorage {
|
||||
pub fn new(azure_config: &AzureConfig) -> Result<Self> {
|
||||
debug!(
|
||||
"Creating azure remote storage for azure container {}",
|
||||
azure_config.container_name
|
||||
);
|
||||
|
||||
let account = env::var("AZURE_STORAGE_ACCOUNT").expect("missing AZURE_STORAGE_ACCOUNT");
|
||||
|
||||
// If the `AZURE_STORAGE_ACCESS_KEY` env var has an access key, use that,
|
||||
// otherwise try the token based credentials.
|
||||
let credentials = if let Ok(access_key) = env::var("AZURE_STORAGE_ACCESS_KEY") {
|
||||
StorageCredentials::access_key(account.clone(), access_key)
|
||||
} else {
|
||||
let token_credential = DefaultAzureCredential::default();
|
||||
StorageCredentials::token_credential(Arc::new(token_credential))
|
||||
};
|
||||
|
||||
let builder = ClientBuilder::new(account, credentials);
|
||||
|
||||
let client = builder.container_client(azure_config.container_name.to_owned());
|
||||
|
||||
let max_keys_per_list_response =
|
||||
if let Some(limit) = azure_config.max_keys_per_list_response {
|
||||
Some(
|
||||
NonZeroU32::new(limit as u32)
|
||||
.ok_or_else(|| anyhow::anyhow!("max_keys_per_list_response can't be 0"))?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(AzureBlobStorage {
|
||||
client,
|
||||
prefix_in_container: azure_config.prefix_in_container.to_owned(),
|
||||
max_keys_per_list_response,
|
||||
concurrency_limiter: ConcurrencyLimiter::new(azure_config.concurrency_limit.get()),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn relative_path_to_name(&self, path: &RemotePath) -> String {
|
||||
assert_eq!(std::path::MAIN_SEPARATOR, REMOTE_STORAGE_PREFIX_SEPARATOR);
|
||||
let path_string = path
|
||||
.get_path()
|
||||
.as_str()
|
||||
.trim_end_matches(REMOTE_STORAGE_PREFIX_SEPARATOR);
|
||||
match &self.prefix_in_container {
|
||||
Some(prefix) => {
|
||||
if prefix.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
|
||||
prefix.clone() + path_string
|
||||
} else {
|
||||
format!("{prefix}{REMOTE_STORAGE_PREFIX_SEPARATOR}{path_string}")
|
||||
}
|
||||
}
|
||||
None => path_string.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn name_to_relative_path(&self, key: &str) -> RemotePath {
|
||||
let relative_path =
|
||||
match key.strip_prefix(self.prefix_in_container.as_deref().unwrap_or_default()) {
|
||||
Some(stripped) => stripped,
|
||||
// we rely on Azure to return properly prefixed paths
|
||||
// for requests with a certain prefix
|
||||
None => panic!(
|
||||
"Key {key} does not start with container prefix {:?}",
|
||||
self.prefix_in_container
|
||||
),
|
||||
};
|
||||
RemotePath(
|
||||
relative_path
|
||||
.split(REMOTE_STORAGE_PREFIX_SEPARATOR)
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
async fn download_for_builder(
|
||||
&self,
|
||||
builder: GetBlobBuilder,
|
||||
) -> Result<Download, DownloadError> {
|
||||
let mut response = builder.into_stream();
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
// TODO give proper streaming response instead of buffering into RAM
|
||||
// https://github.com/neondatabase/neon/issues/5563
|
||||
let mut buf = Vec::new();
|
||||
while let Some(part) = response.next().await {
|
||||
let part = part.map_err(to_download_error)?;
|
||||
if let Some(blob_meta) = part.blob.metadata {
|
||||
metadata.extend(blob_meta.iter().map(|(k, v)| (k.to_owned(), v.to_owned())));
|
||||
}
|
||||
let data = part
|
||||
.data
|
||||
.collect()
|
||||
.await
|
||||
.map_err(|e| DownloadError::Other(e.into()))?;
|
||||
buf.extend_from_slice(&data.slice(..));
|
||||
}
|
||||
Ok(Download {
|
||||
download_stream: Box::pin(Cursor::new(buf)),
|
||||
metadata: Some(StorageMetadata(metadata)),
|
||||
})
|
||||
}
|
||||
|
||||
async fn permit(&self, kind: RequestKind) -> tokio::sync::SemaphorePermit<'_> {
|
||||
self.concurrency_limiter
|
||||
.acquire(kind)
|
||||
.await
|
||||
.expect("semaphore is never closed")
|
||||
}
|
||||
}
|
||||
|
||||
fn to_azure_metadata(metadata: StorageMetadata) -> Metadata {
|
||||
let mut res = Metadata::new();
|
||||
for (k, v) in metadata.0.into_iter() {
|
||||
res.insert(k, v);
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
fn to_download_error(error: azure_core::Error) -> DownloadError {
|
||||
if let Some(http_err) = error.as_http_error() {
|
||||
match http_err.status() {
|
||||
StatusCode::NotFound => DownloadError::NotFound,
|
||||
StatusCode::BadRequest => DownloadError::BadInput(anyhow::Error::new(error)),
|
||||
_ => DownloadError::Other(anyhow::Error::new(error)),
|
||||
}
|
||||
} else {
|
||||
DownloadError::Other(error.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl RemoteStorage for AzureBlobStorage {
|
||||
async fn list(
|
||||
&self,
|
||||
prefix: Option<&RemotePath>,
|
||||
mode: ListingMode,
|
||||
) -> anyhow::Result<Listing, DownloadError> {
|
||||
// get the passed prefix or if it is not set use prefix_in_bucket value
|
||||
let list_prefix = prefix
|
||||
.map(|p| self.relative_path_to_name(p))
|
||||
.or_else(|| self.prefix_in_container.clone())
|
||||
.map(|mut p| {
|
||||
// required to end with a separator
|
||||
// otherwise request will return only the entry of a prefix
|
||||
if matches!(mode, ListingMode::WithDelimiter)
|
||||
&& !p.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR)
|
||||
{
|
||||
p.push(REMOTE_STORAGE_PREFIX_SEPARATOR);
|
||||
}
|
||||
p
|
||||
});
|
||||
|
||||
let mut builder = self.client.list_blobs();
|
||||
|
||||
if let ListingMode::WithDelimiter = mode {
|
||||
builder = builder.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string());
|
||||
}
|
||||
|
||||
if let Some(prefix) = list_prefix {
|
||||
builder = builder.prefix(Cow::from(prefix.to_owned()));
|
||||
}
|
||||
|
||||
if let Some(limit) = self.max_keys_per_list_response {
|
||||
builder = builder.max_results(MaxResults::new(limit));
|
||||
}
|
||||
|
||||
let mut response = builder.into_stream();
|
||||
let mut res = Listing::default();
|
||||
while let Some(l) = response.next().await {
|
||||
let entry = l.map_err(to_download_error)?;
|
||||
let prefix_iter = entry
|
||||
.blobs
|
||||
.prefixes()
|
||||
.map(|prefix| self.name_to_relative_path(&prefix.name));
|
||||
res.prefixes.extend(prefix_iter);
|
||||
|
||||
let blob_iter = entry
|
||||
.blobs
|
||||
.blobs()
|
||||
.map(|k| self.name_to_relative_path(&k.name));
|
||||
res.keys.extend(blob_iter);
|
||||
}
|
||||
Ok(res)
|
||||
}
|
||||
async fn upload(
|
||||
&self,
|
||||
mut from: impl AsyncRead + Unpin + Send + Sync + 'static,
|
||||
data_size_bytes: usize,
|
||||
to: &RemotePath,
|
||||
metadata: Option<StorageMetadata>,
|
||||
) -> anyhow::Result<()> {
|
||||
let _permit = self.permit(RequestKind::Put).await;
|
||||
let blob_client = self.client.blob_client(self.relative_path_to_name(to));
|
||||
|
||||
// TODO FIX THIS UGLY HACK and don't buffer the entire object
|
||||
// into RAM here, but use the streaming interface. For that,
|
||||
// we'd have to change the interface though...
|
||||
// https://github.com/neondatabase/neon/issues/5563
|
||||
let mut buf = Vec::with_capacity(data_size_bytes);
|
||||
tokio::io::copy(&mut from, &mut buf).await?;
|
||||
let body = azure_core::Body::Bytes(buf.into());
|
||||
|
||||
let mut builder = blob_client.put_block_blob(body);
|
||||
|
||||
if let Some(metadata) = metadata {
|
||||
builder = builder.metadata(to_azure_metadata(metadata));
|
||||
}
|
||||
|
||||
let _response = builder.into_future().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn download(&self, from: &RemotePath) -> Result<Download, DownloadError> {
|
||||
let _permit = self.permit(RequestKind::Get).await;
|
||||
let blob_client = self.client.blob_client(self.relative_path_to_name(from));
|
||||
|
||||
let builder = blob_client.get();
|
||||
|
||||
self.download_for_builder(builder).await
|
||||
}
|
||||
|
||||
async fn download_byte_range(
|
||||
&self,
|
||||
from: &RemotePath,
|
||||
start_inclusive: u64,
|
||||
end_exclusive: Option<u64>,
|
||||
) -> Result<Download, DownloadError> {
|
||||
let _permit = self.permit(RequestKind::Get).await;
|
||||
let blob_client = self.client.blob_client(self.relative_path_to_name(from));
|
||||
|
||||
let mut builder = blob_client.get();
|
||||
|
||||
if let Some(end_exclusive) = end_exclusive {
|
||||
builder = builder.range(Range::new(start_inclusive, end_exclusive));
|
||||
} else {
|
||||
// Open ranges are not supported by the SDK so we work around
|
||||
// by setting the upper limit extremely high (but high enough
|
||||
// to still be representable by signed 64 bit integers).
|
||||
// TODO remove workaround once the SDK adds open range support
|
||||
// https://github.com/Azure/azure-sdk-for-rust/issues/1438
|
||||
let end_exclusive = u64::MAX / 4;
|
||||
builder = builder.range(Range::new(start_inclusive, end_exclusive));
|
||||
}
|
||||
|
||||
self.download_for_builder(builder).await
|
||||
}
|
||||
|
||||
async fn delete(&self, path: &RemotePath) -> anyhow::Result<()> {
|
||||
let _permit = self.permit(RequestKind::Delete).await;
|
||||
let blob_client = self.client.blob_client(self.relative_path_to_name(path));
|
||||
|
||||
let builder = blob_client.delete();
|
||||
|
||||
match builder.into_future().await {
|
||||
Ok(_response) => Ok(()),
|
||||
Err(e) => {
|
||||
if let Some(http_err) = e.as_http_error() {
|
||||
if http_err.status() == StatusCode::NotFound {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Err(anyhow::Error::new(e))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn delete_objects<'a>(&self, paths: &'a [RemotePath]) -> anyhow::Result<()> {
|
||||
// Permit is already obtained by inner delete function
|
||||
|
||||
// TODO batch requests are also not supported by the SDK
|
||||
// https://github.com/Azure/azure-sdk-for-rust/issues/1068
|
||||
// https://github.com/Azure/azure-sdk-for-rust/issues/1249
|
||||
for path in paths {
|
||||
self.delete(path).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -4,41 +4,43 @@
|
||||
//! [`RemoteStorage`] trait a CRUD-like generic abstraction to use for adapting external storages with a few implementations:
|
||||
//! * [`local_fs`] allows to use local file system as an external storage
|
||||
//! * [`s3_bucket`] uses AWS S3 bucket as an external storage
|
||||
//! * [`azure_blob`] allows to use Azure Blob storage as an external storage
|
||||
//!
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
|
||||
mod azure_blob;
|
||||
mod local_fs;
|
||||
mod s3_bucket;
|
||||
mod simulate_failures;
|
||||
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fmt::Debug,
|
||||
num::{NonZeroU32, NonZeroUsize},
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
};
|
||||
use std::{collections::HashMap, fmt::Debug, num::NonZeroUsize, pin::Pin, sync::Arc};
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use camino::{Utf8Path, Utf8PathBuf};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::io;
|
||||
use tokio::{io, sync::Semaphore};
|
||||
use toml_edit::Item;
|
||||
use tracing::info;
|
||||
|
||||
pub use self::{local_fs::LocalFs, s3_bucket::S3Bucket, simulate_failures::UnreliableWrapper};
|
||||
pub use self::{
|
||||
azure_blob::AzureBlobStorage, local_fs::LocalFs, s3_bucket::S3Bucket,
|
||||
simulate_failures::UnreliableWrapper,
|
||||
};
|
||||
use s3_bucket::RequestKind;
|
||||
|
||||
/// How many different timelines can be processed simultaneously when synchronizing layers with the remote storage.
|
||||
/// During regular work, pageserver produces one layer file per timeline checkpoint, with bursts of concurrency
|
||||
/// during start (where local and remote timelines are compared and initial sync tasks are scheduled) and timeline attach.
|
||||
/// Both cases may trigger timeline download, that might download a lot of layers. This concurrency is limited by the clients internally, if needed.
|
||||
pub const DEFAULT_REMOTE_STORAGE_MAX_CONCURRENT_SYNCS: usize = 50;
|
||||
pub const DEFAULT_REMOTE_STORAGE_MAX_SYNC_ERRORS: u32 = 10;
|
||||
/// Currently, sync happens with AWS S3, that has two limits on requests per second:
|
||||
/// ~200 RPS for IAM services
|
||||
/// <https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/UsingWithRDS.IAMDBAuth.html>
|
||||
/// ~3500 PUT/COPY/POST/DELETE or 5500 GET/HEAD S3 requests
|
||||
/// <https://aws.amazon.com/premiumsupport/knowledge-center/s3-request-limit-avoid-throttling/>
|
||||
pub const DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT: usize = 100;
|
||||
/// We set this a little bit low as we currently buffer the entire file into RAM
|
||||
///
|
||||
/// Here, a limit of max 20k concurrent connections was noted.
|
||||
/// <https://learn.microsoft.com/en-us/answers/questions/1301863/is-there-any-limitation-to-concurrent-connections>
|
||||
pub const DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT: usize = 30;
|
||||
/// No limits on the client side, which currenltly means 1000 for AWS S3.
|
||||
/// <https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectsV2.html#API_ListObjectsV2_RequestSyntax>
|
||||
pub const DEFAULT_MAX_KEYS_PER_LIST_RESPONSE: Option<i32> = None;
|
||||
@@ -117,6 +119,22 @@ impl RemotePath {
|
||||
}
|
||||
}
|
||||
|
||||
/// We don't need callers to be able to pass arbitrary delimiters: just control
|
||||
/// whether listings will use a '/' separator or not.
|
||||
///
|
||||
/// The WithDelimiter mode will populate `prefixes` and `keys` in the result. The
|
||||
/// NoDelimiter mode will only populate `keys`.
|
||||
pub enum ListingMode {
|
||||
WithDelimiter,
|
||||
NoDelimiter,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct Listing {
|
||||
pub prefixes: Vec<RemotePath>,
|
||||
pub keys: Vec<RemotePath>,
|
||||
}
|
||||
|
||||
/// Storage (potentially remote) API to manage its state.
|
||||
/// This storage tries to be unaware of any layered repository context,
|
||||
/// providing basic CRUD operations for storage files.
|
||||
@@ -129,8 +147,13 @@ pub trait RemoteStorage: Send + Sync + 'static {
|
||||
async fn list_prefixes(
|
||||
&self,
|
||||
prefix: Option<&RemotePath>,
|
||||
) -> Result<Vec<RemotePath>, DownloadError>;
|
||||
|
||||
) -> Result<Vec<RemotePath>, DownloadError> {
|
||||
let result = self
|
||||
.list(prefix, ListingMode::WithDelimiter)
|
||||
.await?
|
||||
.prefixes;
|
||||
Ok(result)
|
||||
}
|
||||
/// Lists all files in directory "recursively"
|
||||
/// (not really recursively, because AWS has a flat namespace)
|
||||
/// Note: This is subtely different than list_prefixes,
|
||||
@@ -142,7 +165,16 @@ pub trait RemoteStorage: Send + Sync + 'static {
|
||||
/// whereas,
|
||||
/// list_prefixes("foo/bar/") = ["cat", "dog"]
|
||||
/// See `test_real_s3.rs` for more details.
|
||||
async fn list_files(&self, folder: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>>;
|
||||
async fn list_files(&self, prefix: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>> {
|
||||
let result = self.list(prefix, ListingMode::NoDelimiter).await?.keys;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn list(
|
||||
&self,
|
||||
prefix: Option<&RemotePath>,
|
||||
_mode: ListingMode,
|
||||
) -> anyhow::Result<Listing, DownloadError>;
|
||||
|
||||
/// Streams the local file contents into remote into the remote storage entry.
|
||||
async fn upload(
|
||||
@@ -193,6 +225,9 @@ pub enum DownloadError {
|
||||
BadInput(anyhow::Error),
|
||||
/// The file was not found in the remote storage.
|
||||
NotFound,
|
||||
/// A cancellation token aborted the download, typically during
|
||||
/// tenant detach or process shutdown.
|
||||
Cancelled,
|
||||
/// The file was found in the remote storage, but the download failed.
|
||||
Other(anyhow::Error),
|
||||
}
|
||||
@@ -203,6 +238,7 @@ impl std::fmt::Display for DownloadError {
|
||||
DownloadError::BadInput(e) => {
|
||||
write!(f, "Failed to download a remote file due to user input: {e}")
|
||||
}
|
||||
DownloadError::Cancelled => write!(f, "Cancelled, shutting down"),
|
||||
DownloadError::NotFound => write!(f, "No file found for the remote object id given"),
|
||||
DownloadError::Other(e) => write!(f, "Failed to download a remote file: {e:?}"),
|
||||
}
|
||||
@@ -217,10 +253,24 @@ impl std::error::Error for DownloadError {}
|
||||
pub enum GenericRemoteStorage {
|
||||
LocalFs(LocalFs),
|
||||
AwsS3(Arc<S3Bucket>),
|
||||
AzureBlob(Arc<AzureBlobStorage>),
|
||||
Unreliable(Arc<UnreliableWrapper>),
|
||||
}
|
||||
|
||||
impl GenericRemoteStorage {
|
||||
pub async fn list(
|
||||
&self,
|
||||
prefix: Option<&RemotePath>,
|
||||
mode: ListingMode,
|
||||
) -> anyhow::Result<Listing, DownloadError> {
|
||||
match self {
|
||||
Self::LocalFs(s) => s.list(prefix, mode).await,
|
||||
Self::AwsS3(s) => s.list(prefix, mode).await,
|
||||
Self::AzureBlob(s) => s.list(prefix, mode).await,
|
||||
Self::Unreliable(s) => s.list(prefix, mode).await,
|
||||
}
|
||||
}
|
||||
|
||||
// A function for listing all the files in a "directory"
|
||||
// Example:
|
||||
// list_files("foo/bar") = ["foo/bar/a.txt", "foo/bar/b.txt"]
|
||||
@@ -228,6 +278,7 @@ impl GenericRemoteStorage {
|
||||
match self {
|
||||
Self::LocalFs(s) => s.list_files(folder).await,
|
||||
Self::AwsS3(s) => s.list_files(folder).await,
|
||||
Self::AzureBlob(s) => s.list_files(folder).await,
|
||||
Self::Unreliable(s) => s.list_files(folder).await,
|
||||
}
|
||||
}
|
||||
@@ -242,6 +293,7 @@ impl GenericRemoteStorage {
|
||||
match self {
|
||||
Self::LocalFs(s) => s.list_prefixes(prefix).await,
|
||||
Self::AwsS3(s) => s.list_prefixes(prefix).await,
|
||||
Self::AzureBlob(s) => s.list_prefixes(prefix).await,
|
||||
Self::Unreliable(s) => s.list_prefixes(prefix).await,
|
||||
}
|
||||
}
|
||||
@@ -256,6 +308,7 @@ impl GenericRemoteStorage {
|
||||
match self {
|
||||
Self::LocalFs(s) => s.upload(from, data_size_bytes, to, metadata).await,
|
||||
Self::AwsS3(s) => s.upload(from, data_size_bytes, to, metadata).await,
|
||||
Self::AzureBlob(s) => s.upload(from, data_size_bytes, to, metadata).await,
|
||||
Self::Unreliable(s) => s.upload(from, data_size_bytes, to, metadata).await,
|
||||
}
|
||||
}
|
||||
@@ -264,6 +317,7 @@ impl GenericRemoteStorage {
|
||||
match self {
|
||||
Self::LocalFs(s) => s.download(from).await,
|
||||
Self::AwsS3(s) => s.download(from).await,
|
||||
Self::AzureBlob(s) => s.download(from).await,
|
||||
Self::Unreliable(s) => s.download(from).await,
|
||||
}
|
||||
}
|
||||
@@ -283,6 +337,10 @@ impl GenericRemoteStorage {
|
||||
s.download_byte_range(from, start_inclusive, end_exclusive)
|
||||
.await
|
||||
}
|
||||
Self::AzureBlob(s) => {
|
||||
s.download_byte_range(from, start_inclusive, end_exclusive)
|
||||
.await
|
||||
}
|
||||
Self::Unreliable(s) => {
|
||||
s.download_byte_range(from, start_inclusive, end_exclusive)
|
||||
.await
|
||||
@@ -294,6 +352,7 @@ impl GenericRemoteStorage {
|
||||
match self {
|
||||
Self::LocalFs(s) => s.delete(path).await,
|
||||
Self::AwsS3(s) => s.delete(path).await,
|
||||
Self::AzureBlob(s) => s.delete(path).await,
|
||||
Self::Unreliable(s) => s.delete(path).await,
|
||||
}
|
||||
}
|
||||
@@ -302,6 +361,7 @@ impl GenericRemoteStorage {
|
||||
match self {
|
||||
Self::LocalFs(s) => s.delete_objects(paths).await,
|
||||
Self::AwsS3(s) => s.delete_objects(paths).await,
|
||||
Self::AzureBlob(s) => s.delete_objects(paths).await,
|
||||
Self::Unreliable(s) => s.delete_objects(paths).await,
|
||||
}
|
||||
}
|
||||
@@ -319,6 +379,11 @@ impl GenericRemoteStorage {
|
||||
s3_config.bucket_name, s3_config.bucket_region, s3_config.prefix_in_bucket, s3_config.endpoint);
|
||||
Self::AwsS3(Arc::new(S3Bucket::new(s3_config)?))
|
||||
}
|
||||
RemoteStorageKind::AzureContainer(azure_config) => {
|
||||
info!("Using azure container '{}' in region '{}' as a remote storage, prefix in container: '{:?}'",
|
||||
azure_config.container_name, azure_config.container_region, azure_config.prefix_in_container);
|
||||
Self::AzureBlob(Arc::new(AzureBlobStorage::new(azure_config)?))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -366,10 +431,6 @@ pub struct StorageMetadata(HashMap<String, String>);
|
||||
/// External backup storage configuration, enough for creating a client for that storage.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct RemoteStorageConfig {
|
||||
/// Max allowed number of concurrent sync operations between the API user and the remote storage.
|
||||
pub max_concurrent_syncs: NonZeroUsize,
|
||||
/// Max allowed errors before the sync task is considered failed and evicted.
|
||||
pub max_sync_errors: NonZeroU32,
|
||||
/// The storage connection configuration.
|
||||
pub storage: RemoteStorageKind,
|
||||
}
|
||||
@@ -383,6 +444,9 @@ pub enum RemoteStorageKind {
|
||||
/// AWS S3 based storage, storing all files in the S3 bucket
|
||||
/// specified by the config
|
||||
AwsS3(S3Config),
|
||||
/// Azure Blob based storage, storing all files in the container
|
||||
/// specified by the config
|
||||
AzureContainer(AzureConfig),
|
||||
}
|
||||
|
||||
/// AWS S3 bucket coordinates and access credentials to manage the bucket contents (read and write).
|
||||
@@ -422,27 +486,53 @@ impl Debug for S3Config {
|
||||
}
|
||||
}
|
||||
|
||||
/// Azure bucket coordinates and access credentials to manage the bucket contents (read and write).
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub struct AzureConfig {
|
||||
/// Name of the container to connect to.
|
||||
pub container_name: String,
|
||||
/// The region where the bucket is located at.
|
||||
pub container_region: String,
|
||||
/// A "subfolder" in the container, to use the same container separately by multiple remote storage users at once.
|
||||
pub prefix_in_container: Option<String>,
|
||||
/// Azure has various limits on its API calls, we need not to exceed those.
|
||||
/// See [`DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT`] for more details.
|
||||
pub concurrency_limit: NonZeroUsize,
|
||||
pub max_keys_per_list_response: Option<i32>,
|
||||
}
|
||||
|
||||
impl Debug for AzureConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("AzureConfig")
|
||||
.field("bucket_name", &self.container_name)
|
||||
.field("bucket_region", &self.container_region)
|
||||
.field("prefix_in_bucket", &self.prefix_in_container)
|
||||
.field("concurrency_limit", &self.concurrency_limit)
|
||||
.field(
|
||||
"max_keys_per_list_response",
|
||||
&self.max_keys_per_list_response,
|
||||
)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl RemoteStorageConfig {
|
||||
pub fn from_toml(toml: &toml_edit::Item) -> anyhow::Result<Option<RemoteStorageConfig>> {
|
||||
let local_path = toml.get("local_path");
|
||||
let bucket_name = toml.get("bucket_name");
|
||||
let bucket_region = toml.get("bucket_region");
|
||||
let container_name = toml.get("container_name");
|
||||
let container_region = toml.get("container_region");
|
||||
|
||||
let max_concurrent_syncs = NonZeroUsize::new(
|
||||
parse_optional_integer("max_concurrent_syncs", toml)?
|
||||
.unwrap_or(DEFAULT_REMOTE_STORAGE_MAX_CONCURRENT_SYNCS),
|
||||
)
|
||||
.context("Failed to parse 'max_concurrent_syncs' as a positive integer")?;
|
||||
|
||||
let max_sync_errors = NonZeroU32::new(
|
||||
parse_optional_integer("max_sync_errors", toml)?
|
||||
.unwrap_or(DEFAULT_REMOTE_STORAGE_MAX_SYNC_ERRORS),
|
||||
)
|
||||
.context("Failed to parse 'max_sync_errors' as a positive integer")?;
|
||||
let use_azure = container_name.is_some() && container_region.is_some();
|
||||
|
||||
let default_concurrency_limit = if use_azure {
|
||||
DEFAULT_REMOTE_STORAGE_AZURE_CONCURRENCY_LIMIT
|
||||
} else {
|
||||
DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT
|
||||
};
|
||||
let concurrency_limit = NonZeroUsize::new(
|
||||
parse_optional_integer("concurrency_limit", toml)?
|
||||
.unwrap_or(DEFAULT_REMOTE_STORAGE_S3_CONCURRENCY_LIMIT),
|
||||
parse_optional_integer("concurrency_limit", toml)?.unwrap_or(default_concurrency_limit),
|
||||
)
|
||||
.context("Failed to parse 'concurrency_limit' as a positive integer")?;
|
||||
|
||||
@@ -451,40 +541,73 @@ impl RemoteStorageConfig {
|
||||
.context("Failed to parse 'max_keys_per_list_response' as a positive integer")?
|
||||
.or(DEFAULT_MAX_KEYS_PER_LIST_RESPONSE);
|
||||
|
||||
let storage = match (local_path, bucket_name, bucket_region) {
|
||||
let endpoint = toml
|
||||
.get("endpoint")
|
||||
.map(|endpoint| parse_toml_string("endpoint", endpoint))
|
||||
.transpose()?;
|
||||
|
||||
let storage = match (
|
||||
local_path,
|
||||
bucket_name,
|
||||
bucket_region,
|
||||
container_name,
|
||||
container_region,
|
||||
) {
|
||||
// no 'local_path' nor 'bucket_name' options are provided, consider this remote storage disabled
|
||||
(None, None, None) => return Ok(None),
|
||||
(_, Some(_), None) => {
|
||||
(None, None, None, None, None) => return Ok(None),
|
||||
(_, Some(_), None, ..) => {
|
||||
bail!("'bucket_region' option is mandatory if 'bucket_name' is given ")
|
||||
}
|
||||
(_, None, Some(_)) => {
|
||||
(_, None, Some(_), ..) => {
|
||||
bail!("'bucket_name' option is mandatory if 'bucket_region' is given ")
|
||||
}
|
||||
(None, Some(bucket_name), Some(bucket_region)) => RemoteStorageKind::AwsS3(S3Config {
|
||||
bucket_name: parse_toml_string("bucket_name", bucket_name)?,
|
||||
bucket_region: parse_toml_string("bucket_region", bucket_region)?,
|
||||
prefix_in_bucket: toml
|
||||
.get("prefix_in_bucket")
|
||||
.map(|prefix_in_bucket| parse_toml_string("prefix_in_bucket", prefix_in_bucket))
|
||||
.transpose()?,
|
||||
endpoint: toml
|
||||
.get("endpoint")
|
||||
.map(|endpoint| parse_toml_string("endpoint", endpoint))
|
||||
.transpose()?,
|
||||
concurrency_limit,
|
||||
max_keys_per_list_response,
|
||||
}),
|
||||
(Some(local_path), None, None) => RemoteStorageKind::LocalFs(Utf8PathBuf::from(
|
||||
parse_toml_string("local_path", local_path)?,
|
||||
)),
|
||||
(Some(_), Some(_), _) => bail!("local_path and bucket_name are mutually exclusive"),
|
||||
(None, Some(bucket_name), Some(bucket_region), ..) => {
|
||||
RemoteStorageKind::AwsS3(S3Config {
|
||||
bucket_name: parse_toml_string("bucket_name", bucket_name)?,
|
||||
bucket_region: parse_toml_string("bucket_region", bucket_region)?,
|
||||
prefix_in_bucket: toml
|
||||
.get("prefix_in_bucket")
|
||||
.map(|prefix_in_bucket| {
|
||||
parse_toml_string("prefix_in_bucket", prefix_in_bucket)
|
||||
})
|
||||
.transpose()?,
|
||||
endpoint,
|
||||
concurrency_limit,
|
||||
max_keys_per_list_response,
|
||||
})
|
||||
}
|
||||
(_, _, _, Some(_), None) => {
|
||||
bail!("'container_name' option is mandatory if 'container_region' is given ")
|
||||
}
|
||||
(_, _, _, None, Some(_)) => {
|
||||
bail!("'container_name' option is mandatory if 'container_region' is given ")
|
||||
}
|
||||
(None, None, None, Some(container_name), Some(container_region)) => {
|
||||
RemoteStorageKind::AzureContainer(AzureConfig {
|
||||
container_name: parse_toml_string("container_name", container_name)?,
|
||||
container_region: parse_toml_string("container_region", container_region)?,
|
||||
prefix_in_container: toml
|
||||
.get("prefix_in_container")
|
||||
.map(|prefix_in_container| {
|
||||
parse_toml_string("prefix_in_container", prefix_in_container)
|
||||
})
|
||||
.transpose()?,
|
||||
concurrency_limit,
|
||||
max_keys_per_list_response,
|
||||
})
|
||||
}
|
||||
(Some(local_path), None, None, None, None) => RemoteStorageKind::LocalFs(
|
||||
Utf8PathBuf::from(parse_toml_string("local_path", local_path)?),
|
||||
),
|
||||
(Some(_), Some(_), ..) => {
|
||||
bail!("'local_path' and 'bucket_name' are mutually exclusive")
|
||||
}
|
||||
(Some(_), _, _, Some(_), Some(_)) => {
|
||||
bail!("local_path and 'container_name' are mutually exclusive")
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Some(RemoteStorageConfig {
|
||||
max_concurrent_syncs,
|
||||
max_sync_errors,
|
||||
storage,
|
||||
}))
|
||||
Ok(Some(RemoteStorageConfig { storage }))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -513,6 +636,46 @@ fn parse_toml_string(name: &str, item: &Item) -> anyhow::Result<String> {
|
||||
Ok(s.to_string())
|
||||
}
|
||||
|
||||
struct ConcurrencyLimiter {
|
||||
// Every request to S3 can be throttled or cancelled, if a certain number of requests per second is exceeded.
|
||||
// Same goes to IAM, which is queried before every S3 request, if enabled. IAM has even lower RPS threshold.
|
||||
// The helps to ensure we don't exceed the thresholds.
|
||||
write: Arc<Semaphore>,
|
||||
read: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
impl ConcurrencyLimiter {
|
||||
fn for_kind(&self, kind: RequestKind) -> &Arc<Semaphore> {
|
||||
match kind {
|
||||
RequestKind::Get => &self.read,
|
||||
RequestKind::Put => &self.write,
|
||||
RequestKind::List => &self.read,
|
||||
RequestKind::Delete => &self.write,
|
||||
}
|
||||
}
|
||||
|
||||
async fn acquire(
|
||||
&self,
|
||||
kind: RequestKind,
|
||||
) -> Result<tokio::sync::SemaphorePermit<'_>, tokio::sync::AcquireError> {
|
||||
self.for_kind(kind).acquire().await
|
||||
}
|
||||
|
||||
async fn acquire_owned(
|
||||
&self,
|
||||
kind: RequestKind,
|
||||
) -> Result<tokio::sync::OwnedSemaphorePermit, tokio::sync::AcquireError> {
|
||||
Arc::clone(self.for_kind(kind)).acquire_owned().await
|
||||
}
|
||||
|
||||
fn new(limit: usize) -> ConcurrencyLimiter {
|
||||
Self {
|
||||
read: Arc::new(Semaphore::new(limit)),
|
||||
write: Arc::new(Semaphore::new(limit)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -15,7 +15,7 @@ use tokio::{
|
||||
use tracing::*;
|
||||
use utils::{crashsafe::path_with_suffix_extension, fs_ext::is_directory_empty};
|
||||
|
||||
use crate::{Download, DownloadError, RemotePath};
|
||||
use crate::{Download, DownloadError, Listing, ListingMode, RemotePath};
|
||||
|
||||
use super::{RemoteStorage, StorageMetadata};
|
||||
|
||||
@@ -75,7 +75,7 @@ impl LocalFs {
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
async fn list(&self) -> anyhow::Result<Vec<RemotePath>> {
|
||||
async fn list_all(&self) -> anyhow::Result<Vec<RemotePath>> {
|
||||
Ok(get_all_files(&self.storage_root, true)
|
||||
.await?
|
||||
.into_iter()
|
||||
@@ -89,52 +89,10 @@ impl LocalFs {
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl RemoteStorage for LocalFs {
|
||||
async fn list_prefixes(
|
||||
&self,
|
||||
prefix: Option<&RemotePath>,
|
||||
) -> Result<Vec<RemotePath>, DownloadError> {
|
||||
let path = match prefix {
|
||||
Some(prefix) => Cow::Owned(prefix.with_base(&self.storage_root)),
|
||||
None => Cow::Borrowed(&self.storage_root),
|
||||
};
|
||||
|
||||
let prefixes_to_filter = get_all_files(path.as_ref(), false)
|
||||
.await
|
||||
.map_err(DownloadError::Other)?;
|
||||
|
||||
let mut prefixes = Vec::with_capacity(prefixes_to_filter.len());
|
||||
|
||||
// filter out empty directories to mirror s3 behavior.
|
||||
for prefix in prefixes_to_filter {
|
||||
if prefix.is_dir()
|
||||
&& is_directory_empty(&prefix)
|
||||
.await
|
||||
.map_err(DownloadError::Other)?
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
prefixes.push(
|
||||
prefix
|
||||
.strip_prefix(&self.storage_root)
|
||||
.context("Failed to strip prefix")
|
||||
.and_then(RemotePath::new)
|
||||
.expect(
|
||||
"We list files for storage root, hence should be able to remote the prefix",
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
Ok(prefixes)
|
||||
}
|
||||
|
||||
// recursively lists all files in a directory,
|
||||
// mirroring the `list_files` for `s3_bucket`
|
||||
async fn list_files(&self, folder: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>> {
|
||||
async fn list_recursive(&self, folder: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>> {
|
||||
let full_path = match folder {
|
||||
Some(folder) => folder.with_base(&self.storage_root),
|
||||
None => self.storage_root.clone(),
|
||||
@@ -186,6 +144,70 @@ impl RemoteStorage for LocalFs {
|
||||
|
||||
Ok(files)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl RemoteStorage for LocalFs {
|
||||
async fn list(
|
||||
&self,
|
||||
prefix: Option<&RemotePath>,
|
||||
mode: ListingMode,
|
||||
) -> Result<Listing, DownloadError> {
|
||||
let mut result = Listing::default();
|
||||
|
||||
if let ListingMode::NoDelimiter = mode {
|
||||
let keys = self
|
||||
.list_recursive(prefix)
|
||||
.await
|
||||
.map_err(DownloadError::Other)?;
|
||||
|
||||
result.keys = keys
|
||||
.into_iter()
|
||||
.filter(|k| {
|
||||
let path = k.with_base(&self.storage_root);
|
||||
!path.is_dir()
|
||||
})
|
||||
.collect();
|
||||
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
let path = match prefix {
|
||||
Some(prefix) => Cow::Owned(prefix.with_base(&self.storage_root)),
|
||||
None => Cow::Borrowed(&self.storage_root),
|
||||
};
|
||||
|
||||
let prefixes_to_filter = get_all_files(path.as_ref(), false)
|
||||
.await
|
||||
.map_err(DownloadError::Other)?;
|
||||
|
||||
// filter out empty directories to mirror s3 behavior.
|
||||
for prefix in prefixes_to_filter {
|
||||
if prefix.is_dir()
|
||||
&& is_directory_empty(&prefix)
|
||||
.await
|
||||
.map_err(DownloadError::Other)?
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let stripped = prefix
|
||||
.strip_prefix(&self.storage_root)
|
||||
.context("Failed to strip prefix")
|
||||
.and_then(RemotePath::new)
|
||||
.expect(
|
||||
"We list files for storage root, hence should be able to remote the prefix",
|
||||
);
|
||||
|
||||
if prefix.is_dir() {
|
||||
result.prefixes.push(stripped);
|
||||
} else {
|
||||
result.keys.push(stripped);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn upload(
|
||||
&self,
|
||||
@@ -479,7 +501,7 @@ mod fs_tests {
|
||||
|
||||
let target_path_1 = upload_dummy_file(&storage, "upload_1", None).await?;
|
||||
assert_eq!(
|
||||
storage.list().await?,
|
||||
storage.list_all().await?,
|
||||
vec![target_path_1.clone()],
|
||||
"Should list a single file after first upload"
|
||||
);
|
||||
@@ -667,7 +689,7 @@ mod fs_tests {
|
||||
let upload_target = upload_dummy_file(&storage, upload_name, None).await?;
|
||||
|
||||
storage.delete(&upload_target).await?;
|
||||
assert!(storage.list().await?.is_empty());
|
||||
assert!(storage.list_all().await?.is_empty());
|
||||
|
||||
storage
|
||||
.delete(&upload_target)
|
||||
@@ -725,6 +747,43 @@ mod fs_tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list() -> anyhow::Result<()> {
|
||||
// No delimiter: should recursively list everything
|
||||
let storage = create_storage()?;
|
||||
let child = upload_dummy_file(&storage, "grandparent/parent/child", None).await?;
|
||||
let uncle = upload_dummy_file(&storage, "grandparent/uncle", None).await?;
|
||||
|
||||
let listing = storage.list(None, ListingMode::NoDelimiter).await?;
|
||||
assert!(listing.prefixes.is_empty());
|
||||
assert_eq!(listing.keys, [uncle.clone(), child.clone()].to_vec());
|
||||
|
||||
// Delimiter: should only go one deep
|
||||
let listing = storage.list(None, ListingMode::WithDelimiter).await?;
|
||||
|
||||
assert_eq!(
|
||||
listing.prefixes,
|
||||
[RemotePath::from_string("timelines").unwrap()].to_vec()
|
||||
);
|
||||
assert!(listing.keys.is_empty());
|
||||
|
||||
// Delimiter & prefix
|
||||
let listing = storage
|
||||
.list(
|
||||
Some(&RemotePath::from_string("timelines/some_timeline/grandparent").unwrap()),
|
||||
ListingMode::WithDelimiter,
|
||||
)
|
||||
.await?;
|
||||
assert_eq!(
|
||||
listing.prefixes,
|
||||
[RemotePath::from_string("timelines/some_timeline/grandparent/parent").unwrap()]
|
||||
.to_vec()
|
||||
);
|
||||
assert_eq!(listing.keys, [uncle.clone()].to_vec());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn upload_dummy_file(
|
||||
storage: &LocalFs,
|
||||
name: &str,
|
||||
@@ -777,7 +836,7 @@ mod fs_tests {
|
||||
}
|
||||
|
||||
async fn list_files_sorted(storage: &LocalFs) -> anyhow::Result<Vec<RemotePath>> {
|
||||
let mut files = storage.list().await?;
|
||||
let mut files = storage.list_all().await?;
|
||||
files.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
Ok(files)
|
||||
}
|
||||
|
||||
@@ -4,42 +4,44 @@
|
||||
//! allowing multiple api users to independently work with the same S3 bucket, if
|
||||
//! their bucket prefixes are both specified and different.
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::{borrow::Cow, sync::Arc};
|
||||
|
||||
use anyhow::Context;
|
||||
use aws_config::{
|
||||
environment::credentials::EnvironmentVariableCredentialsProvider,
|
||||
imds::credentials::ImdsCredentialsProvider, meta::credentials::CredentialsProviderChain,
|
||||
provider_config::ProviderConfig, web_identity_token::WebIdentityTokenCredentialsProvider,
|
||||
imds::credentials::ImdsCredentialsProvider,
|
||||
meta::credentials::CredentialsProviderChain,
|
||||
provider_config::ProviderConfig,
|
||||
retry::{RetryConfigBuilder, RetryMode},
|
||||
web_identity_token::WebIdentityTokenCredentialsProvider,
|
||||
};
|
||||
use aws_credential_types::cache::CredentialsCache;
|
||||
use aws_sdk_s3::{
|
||||
config::{Config, Region},
|
||||
config::{AsyncSleep, Config, Region, SharedAsyncSleep},
|
||||
error::SdkError,
|
||||
operation::get_object::GetObjectError,
|
||||
primitives::ByteStream,
|
||||
types::{Delete, ObjectIdentifier},
|
||||
Client,
|
||||
};
|
||||
use aws_smithy_async::rt::sleep::TokioSleep;
|
||||
use aws_smithy_http::body::SdkBody;
|
||||
use hyper::Body;
|
||||
use scopeguard::ScopeGuard;
|
||||
use tokio::{
|
||||
io::{self, AsyncRead},
|
||||
sync::Semaphore,
|
||||
};
|
||||
use tokio::io::{self, AsyncRead};
|
||||
use tokio_util::io::ReaderStream;
|
||||
use tracing::debug;
|
||||
|
||||
use super::StorageMetadata;
|
||||
use crate::{
|
||||
Download, DownloadError, RemotePath, RemoteStorage, S3Config, MAX_KEYS_PER_DELETE,
|
||||
REMOTE_STORAGE_PREFIX_SEPARATOR,
|
||||
ConcurrencyLimiter, Download, DownloadError, Listing, ListingMode, RemotePath, RemoteStorage,
|
||||
S3Config, MAX_KEYS_PER_DELETE, REMOTE_STORAGE_PREFIX_SEPARATOR,
|
||||
};
|
||||
|
||||
pub(super) mod metrics;
|
||||
|
||||
use self::metrics::{AttemptOutcome, RequestKind};
|
||||
use self::metrics::AttemptOutcome;
|
||||
pub(super) use self::metrics::RequestKind;
|
||||
|
||||
/// AWS S3 storage.
|
||||
pub struct S3Bucket {
|
||||
@@ -50,46 +52,6 @@ pub struct S3Bucket {
|
||||
concurrency_limiter: ConcurrencyLimiter,
|
||||
}
|
||||
|
||||
struct ConcurrencyLimiter {
|
||||
// Every request to S3 can be throttled or cancelled, if a certain number of requests per second is exceeded.
|
||||
// Same goes to IAM, which is queried before every S3 request, if enabled. IAM has even lower RPS threshold.
|
||||
// The helps to ensure we don't exceed the thresholds.
|
||||
write: Arc<Semaphore>,
|
||||
read: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
impl ConcurrencyLimiter {
|
||||
fn for_kind(&self, kind: RequestKind) -> &Arc<Semaphore> {
|
||||
match kind {
|
||||
RequestKind::Get => &self.read,
|
||||
RequestKind::Put => &self.write,
|
||||
RequestKind::List => &self.read,
|
||||
RequestKind::Delete => &self.write,
|
||||
}
|
||||
}
|
||||
|
||||
async fn acquire(
|
||||
&self,
|
||||
kind: RequestKind,
|
||||
) -> Result<tokio::sync::SemaphorePermit<'_>, tokio::sync::AcquireError> {
|
||||
self.for_kind(kind).acquire().await
|
||||
}
|
||||
|
||||
async fn acquire_owned(
|
||||
&self,
|
||||
kind: RequestKind,
|
||||
) -> Result<tokio::sync::OwnedSemaphorePermit, tokio::sync::AcquireError> {
|
||||
Arc::clone(self.for_kind(kind)).acquire_owned().await
|
||||
}
|
||||
|
||||
fn new(limit: usize) -> ConcurrencyLimiter {
|
||||
Self {
|
||||
read: Arc::new(Semaphore::new(limit)),
|
||||
write: Arc::new(Semaphore::new(limit)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct GetObjectRequest {
|
||||
bucket: String,
|
||||
@@ -125,10 +87,23 @@ impl S3Bucket {
|
||||
.or_else("imds", ImdsCredentialsProvider::builder().build())
|
||||
};
|
||||
|
||||
// AWS SDK requires us to specify how the RetryConfig should sleep when it wants to back off
|
||||
let sleep_impl: Arc<dyn AsyncSleep> = Arc::new(TokioSleep::new());
|
||||
|
||||
// We do our own retries (see [`backoff::retry`]). However, for the AWS SDK to enable rate limiting in response to throttling
|
||||
// responses (e.g. 429 on too many ListObjectsv2 requests), we must provide a retry config. We set it to use at most one
|
||||
// attempt, and enable 'Adaptive' mode, which causes rate limiting to be enabled.
|
||||
let mut retry_config = RetryConfigBuilder::new();
|
||||
retry_config
|
||||
.set_max_attempts(Some(1))
|
||||
.set_mode(Some(RetryMode::Adaptive));
|
||||
|
||||
let mut config_builder = Config::builder()
|
||||
.region(region)
|
||||
.credentials_cache(CredentialsCache::lazy())
|
||||
.credentials_provider(credentials_provider);
|
||||
.credentials_provider(credentials_provider)
|
||||
.sleep_impl(SharedAsyncSleep::from(sleep_impl))
|
||||
.retry_config(retry_config.build());
|
||||
|
||||
if let Some(custom_endpoint) = aws_config.endpoint.clone() {
|
||||
config_builder = config_builder
|
||||
@@ -341,13 +316,13 @@ impl<S: AsyncRead> AsyncRead for TimedDownload<S> {
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl RemoteStorage for S3Bucket {
|
||||
/// See the doc for `RemoteStorage::list_prefixes`
|
||||
/// Note: it wont include empty "directories"
|
||||
async fn list_prefixes(
|
||||
async fn list(
|
||||
&self,
|
||||
prefix: Option<&RemotePath>,
|
||||
) -> Result<Vec<RemotePath>, DownloadError> {
|
||||
mode: ListingMode,
|
||||
) -> Result<Listing, DownloadError> {
|
||||
let kind = RequestKind::List;
|
||||
let mut result = Listing::default();
|
||||
|
||||
// get the passed prefix or if it is not set use prefix_in_bucket value
|
||||
let list_prefix = prefix
|
||||
@@ -356,28 +331,33 @@ impl RemoteStorage for S3Bucket {
|
||||
.map(|mut p| {
|
||||
// required to end with a separator
|
||||
// otherwise request will return only the entry of a prefix
|
||||
if !p.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR) {
|
||||
if matches!(mode, ListingMode::WithDelimiter)
|
||||
&& !p.ends_with(REMOTE_STORAGE_PREFIX_SEPARATOR)
|
||||
{
|
||||
p.push(REMOTE_STORAGE_PREFIX_SEPARATOR);
|
||||
}
|
||||
p
|
||||
});
|
||||
|
||||
let mut document_keys = Vec::new();
|
||||
|
||||
let mut continuation_token = None;
|
||||
|
||||
loop {
|
||||
let _guard = self.permit(kind).await;
|
||||
let started_at = start_measuring_requests(kind);
|
||||
|
||||
let fetch_response = self
|
||||
let mut request = self
|
||||
.client
|
||||
.list_objects_v2()
|
||||
.bucket(self.bucket_name.clone())
|
||||
.set_prefix(list_prefix.clone())
|
||||
.set_continuation_token(continuation_token)
|
||||
.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string())
|
||||
.set_max_keys(self.max_keys_per_list_response)
|
||||
.set_max_keys(self.max_keys_per_list_response);
|
||||
|
||||
if let ListingMode::WithDelimiter = mode {
|
||||
request = request.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string());
|
||||
}
|
||||
|
||||
let response = request
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to list S3 prefixes")
|
||||
@@ -387,71 +367,35 @@ impl RemoteStorage for S3Bucket {
|
||||
|
||||
metrics::BUCKET_METRICS
|
||||
.req_seconds
|
||||
.observe_elapsed(kind, &fetch_response, started_at);
|
||||
.observe_elapsed(kind, &response, started_at);
|
||||
|
||||
let fetch_response = fetch_response?;
|
||||
let response = response?;
|
||||
|
||||
document_keys.extend(
|
||||
fetch_response
|
||||
.common_prefixes
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
let keys = response.contents().unwrap_or_default();
|
||||
let empty = Vec::new();
|
||||
let prefixes = response.common_prefixes.as_ref().unwrap_or(&empty);
|
||||
|
||||
tracing::info!("list: {} prefixes, {} keys", prefixes.len(), keys.len());
|
||||
|
||||
for object in keys {
|
||||
let object_path = object.key().expect("response does not contain a key");
|
||||
let remote_path = self.s3_object_to_relative_path(object_path);
|
||||
result.keys.push(remote_path);
|
||||
}
|
||||
|
||||
result.prefixes.extend(
|
||||
prefixes
|
||||
.iter()
|
||||
.filter_map(|o| Some(self.s3_object_to_relative_path(o.prefix()?))),
|
||||
);
|
||||
|
||||
continuation_token = match fetch_response.next_continuation_token {
|
||||
continuation_token = match response.next_continuation_token {
|
||||
Some(new_token) => Some(new_token),
|
||||
None => break,
|
||||
};
|
||||
}
|
||||
|
||||
Ok(document_keys)
|
||||
}
|
||||
|
||||
/// See the doc for `RemoteStorage::list_files`
|
||||
async fn list_files(&self, folder: Option<&RemotePath>) -> anyhow::Result<Vec<RemotePath>> {
|
||||
let kind = RequestKind::List;
|
||||
|
||||
let folder_name = folder
|
||||
.map(|p| self.relative_path_to_s3_object(p))
|
||||
.or_else(|| self.prefix_in_bucket.clone());
|
||||
|
||||
// AWS may need to break the response into several parts
|
||||
let mut continuation_token = None;
|
||||
let mut all_files = vec![];
|
||||
loop {
|
||||
let _guard = self.permit(kind).await;
|
||||
let started_at = start_measuring_requests(kind);
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.list_objects_v2()
|
||||
.bucket(self.bucket_name.clone())
|
||||
.set_prefix(folder_name.clone())
|
||||
.set_continuation_token(continuation_token)
|
||||
.set_max_keys(self.max_keys_per_list_response)
|
||||
.send()
|
||||
.await
|
||||
.context("Failed to list files in S3 bucket");
|
||||
|
||||
let started_at = ScopeGuard::into_inner(started_at);
|
||||
metrics::BUCKET_METRICS
|
||||
.req_seconds
|
||||
.observe_elapsed(kind, &response, started_at);
|
||||
|
||||
let response = response?;
|
||||
|
||||
for object in response.contents().unwrap_or_default() {
|
||||
let object_path = object.key().expect("response does not contain a key");
|
||||
let remote_path = self.s3_object_to_relative_path(object_path);
|
||||
all_files.push(remote_path);
|
||||
}
|
||||
match response.next_continuation_token {
|
||||
Some(new_token) => continuation_token = Some(new_token),
|
||||
None => break,
|
||||
}
|
||||
}
|
||||
Ok(all_files)
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
async fn upload(
|
||||
@@ -556,6 +500,20 @@ impl RemoteStorage for S3Bucket {
|
||||
.deleted_objects_total
|
||||
.inc_by(chunk.len() as u64);
|
||||
if let Some(errors) = resp.errors {
|
||||
// Log a bounded number of the errors within the response:
|
||||
// these requests can carry 1000 keys so logging each one
|
||||
// would be too verbose, especially as errors may lead us
|
||||
// to retry repeatedly.
|
||||
const LOG_UP_TO_N_ERRORS: usize = 10;
|
||||
for e in errors.iter().take(LOG_UP_TO_N_ERRORS) {
|
||||
tracing::warn!(
|
||||
"DeleteObjects key {} failed: {}: {}",
|
||||
e.key.as_ref().map(Cow::from).unwrap_or("".into()),
|
||||
e.code.as_ref().map(Cow::from).unwrap_or("".into()),
|
||||
e.message.as_ref().map(Cow::from).unwrap_or("".into())
|
||||
);
|
||||
}
|
||||
|
||||
return Err(anyhow::format_err!(
|
||||
"Failed to delete {} objects",
|
||||
errors.len()
|
||||
|
||||
@@ -6,7 +6,7 @@ use once_cell::sync::Lazy;
|
||||
pub(super) static BUCKET_METRICS: Lazy<BucketMetrics> = Lazy::new(Default::default);
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub(super) enum RequestKind {
|
||||
pub(crate) enum RequestKind {
|
||||
Get = 0,
|
||||
Put = 1,
|
||||
Delete = 2,
|
||||
|
||||
@@ -5,7 +5,9 @@ use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use crate::{Download, DownloadError, RemotePath, RemoteStorage, StorageMetadata};
|
||||
use crate::{
|
||||
Download, DownloadError, Listing, ListingMode, RemotePath, RemoteStorage, StorageMetadata,
|
||||
};
|
||||
|
||||
pub struct UnreliableWrapper {
|
||||
inner: crate::GenericRemoteStorage,
|
||||
@@ -95,6 +97,15 @@ impl RemoteStorage for UnreliableWrapper {
|
||||
self.inner.list_files(folder).await
|
||||
}
|
||||
|
||||
async fn list(
|
||||
&self,
|
||||
prefix: Option<&RemotePath>,
|
||||
mode: ListingMode,
|
||||
) -> Result<Listing, DownloadError> {
|
||||
self.attempt(RemoteOp::ListPrefixes(prefix.cloned()))?;
|
||||
self.inner.list(prefix, mode).await
|
||||
}
|
||||
|
||||
async fn upload(
|
||||
&self,
|
||||
data: impl tokio::io::AsyncRead + Unpin + Send + Sync + 'static,
|
||||
|
||||
623
libs/remote_storage/tests/test_real_azure.rs
Normal file
623
libs/remote_storage/tests/test_real_azure.rs
Normal file
@@ -0,0 +1,623 @@
|
||||
use std::collections::HashSet;
|
||||
use std::env;
|
||||
use std::num::NonZeroUsize;
|
||||
use std::ops::ControlFlow;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::UNIX_EPOCH;
|
||||
|
||||
use anyhow::Context;
|
||||
use camino::Utf8Path;
|
||||
use once_cell::sync::OnceCell;
|
||||
use remote_storage::{
|
||||
AzureConfig, Download, GenericRemoteStorage, RemotePath, RemoteStorageConfig, RemoteStorageKind,
|
||||
};
|
||||
use test_context::{test_context, AsyncTestContext};
|
||||
use tokio::task::JoinSet;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
static LOGGING_DONE: OnceCell<()> = OnceCell::new();
|
||||
|
||||
const ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME: &str = "ENABLE_REAL_AZURE_REMOTE_STORAGE";
|
||||
|
||||
const BASE_PREFIX: &str = "test";
|
||||
|
||||
/// Tests that the Azure client can list all prefixes, even if the response comes paginated and requires multiple HTTP queries.
|
||||
/// Uses real Azure and requires [`ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME`] and related Azure cred env vars specified.
|
||||
/// See the client creation in [`create_azure_client`] for details on the required env vars.
|
||||
/// If real Azure tests are disabled, the test passes, skipping any real test run: currently, there's no way to mark the test ignored in runtime with the
|
||||
/// deafult test framework, see https://github.com/rust-lang/rust/issues/68007 for details.
|
||||
///
|
||||
/// First, the test creates a set of Azure blobs with keys `/${random_prefix_part}/${base_prefix_str}/sub_prefix_${i}/blob_${i}` in [`upload_azure_data`]
|
||||
/// where
|
||||
/// * `random_prefix_part` is set for the entire Azure client during the Azure client creation in [`create_azure_client`], to avoid multiple test runs interference
|
||||
/// * `base_prefix_str` is a common prefix to use in the client requests: we would want to ensure that the client is able to list nested prefixes inside the bucket
|
||||
///
|
||||
/// Then, verifies that the client does return correct prefixes when queried:
|
||||
/// * with no prefix, it lists everything after its `${random_prefix_part}/` — that should be `${base_prefix_str}` value only
|
||||
/// * with `${base_prefix_str}/` prefix, it lists every `sub_prefix_${i}`
|
||||
///
|
||||
/// With the real Azure enabled and `#[cfg(test)]` Rust configuration used, the Azure client test adds a `max-keys` param to limit the response keys.
|
||||
/// This way, we are able to test the pagination implicitly, by ensuring all results are returned from the remote storage and avoid uploading too many blobs to Azure.
|
||||
///
|
||||
/// Lastly, the test attempts to clean up and remove all uploaded Azure files.
|
||||
/// If any errors appear during the clean up, they get logged, but the test is not failed or stopped until clean up is finished.
|
||||
#[test_context(MaybeEnabledAzureWithTestBlobs)]
|
||||
#[tokio::test]
|
||||
async fn azure_pagination_should_work(
|
||||
ctx: &mut MaybeEnabledAzureWithTestBlobs,
|
||||
) -> anyhow::Result<()> {
|
||||
let ctx = match ctx {
|
||||
MaybeEnabledAzureWithTestBlobs::Enabled(ctx) => ctx,
|
||||
MaybeEnabledAzureWithTestBlobs::Disabled => return Ok(()),
|
||||
MaybeEnabledAzureWithTestBlobs::UploadsFailed(e, _) => {
|
||||
anyhow::bail!("Azure init failed: {e:?}")
|
||||
}
|
||||
};
|
||||
|
||||
let test_client = Arc::clone(&ctx.enabled.client);
|
||||
let expected_remote_prefixes = ctx.remote_prefixes.clone();
|
||||
|
||||
let base_prefix = RemotePath::new(Utf8Path::new(ctx.enabled.base_prefix))
|
||||
.context("common_prefix construction")?;
|
||||
let root_remote_prefixes = test_client
|
||||
.list_prefixes(None)
|
||||
.await
|
||||
.context("client list root prefixes failure")?
|
||||
.into_iter()
|
||||
.collect::<HashSet<_>>();
|
||||
assert_eq!(
|
||||
root_remote_prefixes, HashSet::from([base_prefix.clone()]),
|
||||
"remote storage root prefixes list mismatches with the uploads. Returned prefixes: {root_remote_prefixes:?}"
|
||||
);
|
||||
|
||||
let nested_remote_prefixes = test_client
|
||||
.list_prefixes(Some(&base_prefix))
|
||||
.await
|
||||
.context("client list nested prefixes failure")?
|
||||
.into_iter()
|
||||
.collect::<HashSet<_>>();
|
||||
let remote_only_prefixes = nested_remote_prefixes
|
||||
.difference(&expected_remote_prefixes)
|
||||
.collect::<HashSet<_>>();
|
||||
let missing_uploaded_prefixes = expected_remote_prefixes
|
||||
.difference(&nested_remote_prefixes)
|
||||
.collect::<HashSet<_>>();
|
||||
assert_eq!(
|
||||
remote_only_prefixes.len() + missing_uploaded_prefixes.len(), 0,
|
||||
"remote storage nested prefixes list mismatches with the uploads. Remote only prefixes: {remote_only_prefixes:?}, missing uploaded prefixes: {missing_uploaded_prefixes:?}",
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Tests that Azure client can list all files in a folder, even if the response comes paginated and requirees multiple Azure queries.
|
||||
/// Uses real Azure and requires [`ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME`] and related Azure cred env vars specified. Test will skip real code and pass if env vars not set.
|
||||
/// See `Azure_pagination_should_work` for more information.
|
||||
///
|
||||
/// First, create a set of Azure objects with keys `random_prefix/folder{j}/blob_{i}.txt` in [`upload_azure_data`]
|
||||
/// Then performs the following queries:
|
||||
/// 1. `list_files(None)`. This should return all files `random_prefix/folder{j}/blob_{i}.txt`
|
||||
/// 2. `list_files("folder1")`. This should return all files `random_prefix/folder1/blob_{i}.txt`
|
||||
#[test_context(MaybeEnabledAzureWithSimpleTestBlobs)]
|
||||
#[tokio::test]
|
||||
async fn azure_list_files_works(
|
||||
ctx: &mut MaybeEnabledAzureWithSimpleTestBlobs,
|
||||
) -> anyhow::Result<()> {
|
||||
let ctx = match ctx {
|
||||
MaybeEnabledAzureWithSimpleTestBlobs::Enabled(ctx) => ctx,
|
||||
MaybeEnabledAzureWithSimpleTestBlobs::Disabled => return Ok(()),
|
||||
MaybeEnabledAzureWithSimpleTestBlobs::UploadsFailed(e, _) => {
|
||||
anyhow::bail!("Azure init failed: {e:?}")
|
||||
}
|
||||
};
|
||||
let test_client = Arc::clone(&ctx.enabled.client);
|
||||
let base_prefix =
|
||||
RemotePath::new(Utf8Path::new("folder1")).context("common_prefix construction")?;
|
||||
let root_files = test_client
|
||||
.list_files(None)
|
||||
.await
|
||||
.context("client list root files failure")?
|
||||
.into_iter()
|
||||
.collect::<HashSet<_>>();
|
||||
assert_eq!(
|
||||
root_files,
|
||||
ctx.remote_blobs.clone(),
|
||||
"remote storage list_files on root mismatches with the uploads."
|
||||
);
|
||||
let nested_remote_files = test_client
|
||||
.list_files(Some(&base_prefix))
|
||||
.await
|
||||
.context("client list nested files failure")?
|
||||
.into_iter()
|
||||
.collect::<HashSet<_>>();
|
||||
let trim_remote_blobs: HashSet<_> = ctx
|
||||
.remote_blobs
|
||||
.iter()
|
||||
.map(|x| x.get_path())
|
||||
.filter(|x| x.starts_with("folder1"))
|
||||
.map(|x| RemotePath::new(x).expect("must be valid path"))
|
||||
.collect();
|
||||
assert_eq!(
|
||||
nested_remote_files, trim_remote_blobs,
|
||||
"remote storage list_files on subdirrectory mismatches with the uploads."
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test_context(MaybeEnabledAzure)]
|
||||
#[tokio::test]
|
||||
async fn azure_delete_non_exising_works(ctx: &mut MaybeEnabledAzure) -> anyhow::Result<()> {
|
||||
let ctx = match ctx {
|
||||
MaybeEnabledAzure::Enabled(ctx) => ctx,
|
||||
MaybeEnabledAzure::Disabled => return Ok(()),
|
||||
};
|
||||
|
||||
let path = RemotePath::new(Utf8Path::new(
|
||||
format!("{}/for_sure_there_is_nothing_there_really", ctx.base_prefix).as_str(),
|
||||
))
|
||||
.with_context(|| "RemotePath conversion")?;
|
||||
|
||||
ctx.client.delete(&path).await.expect("should succeed");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test_context(MaybeEnabledAzure)]
|
||||
#[tokio::test]
|
||||
async fn azure_delete_objects_works(ctx: &mut MaybeEnabledAzure) -> anyhow::Result<()> {
|
||||
let ctx = match ctx {
|
||||
MaybeEnabledAzure::Enabled(ctx) => ctx,
|
||||
MaybeEnabledAzure::Disabled => return Ok(()),
|
||||
};
|
||||
|
||||
let path1 = RemotePath::new(Utf8Path::new(format!("{}/path1", ctx.base_prefix).as_str()))
|
||||
.with_context(|| "RemotePath conversion")?;
|
||||
|
||||
let path2 = RemotePath::new(Utf8Path::new(format!("{}/path2", ctx.base_prefix).as_str()))
|
||||
.with_context(|| "RemotePath conversion")?;
|
||||
|
||||
let path3 = RemotePath::new(Utf8Path::new(format!("{}/path3", ctx.base_prefix).as_str()))
|
||||
.with_context(|| "RemotePath conversion")?;
|
||||
|
||||
let data1 = "remote blob data1".as_bytes();
|
||||
let data1_len = data1.len();
|
||||
let data2 = "remote blob data2".as_bytes();
|
||||
let data2_len = data2.len();
|
||||
let data3 = "remote blob data3".as_bytes();
|
||||
let data3_len = data3.len();
|
||||
ctx.client
|
||||
.upload(std::io::Cursor::new(data1), data1_len, &path1, None)
|
||||
.await?;
|
||||
|
||||
ctx.client
|
||||
.upload(std::io::Cursor::new(data2), data2_len, &path2, None)
|
||||
.await?;
|
||||
|
||||
ctx.client
|
||||
.upload(std::io::Cursor::new(data3), data3_len, &path3, None)
|
||||
.await?;
|
||||
|
||||
ctx.client.delete_objects(&[path1, path2]).await?;
|
||||
|
||||
let prefixes = ctx.client.list_prefixes(None).await?;
|
||||
|
||||
assert_eq!(prefixes.len(), 1);
|
||||
|
||||
ctx.client.delete_objects(&[path3]).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test_context(MaybeEnabledAzure)]
|
||||
#[tokio::test]
|
||||
async fn azure_upload_download_works(ctx: &mut MaybeEnabledAzure) -> anyhow::Result<()> {
|
||||
let MaybeEnabledAzure::Enabled(ctx) = ctx else {
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let path = RemotePath::new(Utf8Path::new(format!("{}/file", ctx.base_prefix).as_str()))
|
||||
.with_context(|| "RemotePath conversion")?;
|
||||
|
||||
let data = "remote blob data here".as_bytes();
|
||||
let data_len = data.len() as u64;
|
||||
|
||||
ctx.client
|
||||
.upload(std::io::Cursor::new(data), data.len(), &path, None)
|
||||
.await?;
|
||||
|
||||
async fn download_and_compare(mut dl: Download) -> anyhow::Result<Vec<u8>> {
|
||||
let mut buf = Vec::new();
|
||||
tokio::io::copy(&mut dl.download_stream, &mut buf).await?;
|
||||
Ok(buf)
|
||||
}
|
||||
// Normal download request
|
||||
let dl = ctx.client.download(&path).await?;
|
||||
let buf = download_and_compare(dl).await?;
|
||||
assert_eq!(buf, data);
|
||||
|
||||
// Full range (end specified)
|
||||
let dl = ctx
|
||||
.client
|
||||
.download_byte_range(&path, 0, Some(data_len))
|
||||
.await?;
|
||||
let buf = download_and_compare(dl).await?;
|
||||
assert_eq!(buf, data);
|
||||
|
||||
// partial range (end specified)
|
||||
let dl = ctx.client.download_byte_range(&path, 4, Some(10)).await?;
|
||||
let buf = download_and_compare(dl).await?;
|
||||
assert_eq!(buf, data[4..10]);
|
||||
|
||||
// partial range (end beyond real end)
|
||||
let dl = ctx
|
||||
.client
|
||||
.download_byte_range(&path, 8, Some(data_len * 100))
|
||||
.await?;
|
||||
let buf = download_and_compare(dl).await?;
|
||||
assert_eq!(buf, data[8..]);
|
||||
|
||||
// Partial range (end unspecified)
|
||||
let dl = ctx.client.download_byte_range(&path, 4, None).await?;
|
||||
let buf = download_and_compare(dl).await?;
|
||||
assert_eq!(buf, data[4..]);
|
||||
|
||||
// Full range (end unspecified)
|
||||
let dl = ctx.client.download_byte_range(&path, 0, None).await?;
|
||||
let buf = download_and_compare(dl).await?;
|
||||
assert_eq!(buf, data);
|
||||
|
||||
debug!("Cleanup: deleting file at path {path:?}");
|
||||
ctx.client
|
||||
.delete(&path)
|
||||
.await
|
||||
.with_context(|| format!("{path:?} removal"))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ensure_logging_ready() {
|
||||
LOGGING_DONE.get_or_init(|| {
|
||||
utils::logging::init(
|
||||
utils::logging::LogFormat::Test,
|
||||
utils::logging::TracingErrorLayerEnablement::Disabled,
|
||||
)
|
||||
.expect("logging init failed");
|
||||
});
|
||||
}
|
||||
|
||||
struct EnabledAzure {
|
||||
client: Arc<GenericRemoteStorage>,
|
||||
base_prefix: &'static str,
|
||||
}
|
||||
|
||||
impl EnabledAzure {
|
||||
async fn setup(max_keys_in_list_response: Option<i32>) -> Self {
|
||||
let client = create_azure_client(max_keys_in_list_response)
|
||||
.context("Azure client creation")
|
||||
.expect("Azure client creation failed");
|
||||
|
||||
EnabledAzure {
|
||||
client,
|
||||
base_prefix: BASE_PREFIX,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum MaybeEnabledAzure {
|
||||
Enabled(EnabledAzure),
|
||||
Disabled,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl AsyncTestContext for MaybeEnabledAzure {
|
||||
async fn setup() -> Self {
|
||||
ensure_logging_ready();
|
||||
|
||||
if env::var(ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME).is_err() {
|
||||
info!(
|
||||
"`{}` env variable is not set, skipping the test",
|
||||
ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME
|
||||
);
|
||||
return Self::Disabled;
|
||||
}
|
||||
|
||||
Self::Enabled(EnabledAzure::setup(None).await)
|
||||
}
|
||||
}
|
||||
|
||||
enum MaybeEnabledAzureWithTestBlobs {
|
||||
Enabled(AzureWithTestBlobs),
|
||||
Disabled,
|
||||
UploadsFailed(anyhow::Error, AzureWithTestBlobs),
|
||||
}
|
||||
|
||||
struct AzureWithTestBlobs {
|
||||
enabled: EnabledAzure,
|
||||
remote_prefixes: HashSet<RemotePath>,
|
||||
remote_blobs: HashSet<RemotePath>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl AsyncTestContext for MaybeEnabledAzureWithTestBlobs {
|
||||
async fn setup() -> Self {
|
||||
ensure_logging_ready();
|
||||
if env::var(ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME).is_err() {
|
||||
info!(
|
||||
"`{}` env variable is not set, skipping the test",
|
||||
ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME
|
||||
);
|
||||
return Self::Disabled;
|
||||
}
|
||||
|
||||
let max_keys_in_list_response = 10;
|
||||
let upload_tasks_count = 1 + (2 * usize::try_from(max_keys_in_list_response).unwrap());
|
||||
|
||||
let enabled = EnabledAzure::setup(Some(max_keys_in_list_response)).await;
|
||||
|
||||
match upload_azure_data(&enabled.client, enabled.base_prefix, upload_tasks_count).await {
|
||||
ControlFlow::Continue(uploads) => {
|
||||
info!("Remote objects created successfully");
|
||||
|
||||
Self::Enabled(AzureWithTestBlobs {
|
||||
enabled,
|
||||
remote_prefixes: uploads.prefixes,
|
||||
remote_blobs: uploads.blobs,
|
||||
})
|
||||
}
|
||||
ControlFlow::Break(uploads) => Self::UploadsFailed(
|
||||
anyhow::anyhow!("One or multiple blobs failed to upload to Azure"),
|
||||
AzureWithTestBlobs {
|
||||
enabled,
|
||||
remote_prefixes: uploads.prefixes,
|
||||
remote_blobs: uploads.blobs,
|
||||
},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
async fn teardown(self) {
|
||||
match self {
|
||||
Self::Disabled => {}
|
||||
Self::Enabled(ctx) | Self::UploadsFailed(_, ctx) => {
|
||||
cleanup(&ctx.enabled.client, ctx.remote_blobs).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: the setups for the list_prefixes test and the list_files test are very similar
|
||||
// However, they are not idential. The list_prefixes function is concerned with listing prefixes,
|
||||
// whereas the list_files function is concerned with listing files.
|
||||
// See `RemoteStorage::list_files` documentation for more details
|
||||
enum MaybeEnabledAzureWithSimpleTestBlobs {
|
||||
Enabled(AzureWithSimpleTestBlobs),
|
||||
Disabled,
|
||||
UploadsFailed(anyhow::Error, AzureWithSimpleTestBlobs),
|
||||
}
|
||||
struct AzureWithSimpleTestBlobs {
|
||||
enabled: EnabledAzure,
|
||||
remote_blobs: HashSet<RemotePath>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl AsyncTestContext for MaybeEnabledAzureWithSimpleTestBlobs {
|
||||
async fn setup() -> Self {
|
||||
ensure_logging_ready();
|
||||
if env::var(ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME).is_err() {
|
||||
info!(
|
||||
"`{}` env variable is not set, skipping the test",
|
||||
ENABLE_REAL_AZURE_REMOTE_STORAGE_ENV_VAR_NAME
|
||||
);
|
||||
return Self::Disabled;
|
||||
}
|
||||
|
||||
let max_keys_in_list_response = 10;
|
||||
let upload_tasks_count = 1 + (2 * usize::try_from(max_keys_in_list_response).unwrap());
|
||||
|
||||
let enabled = EnabledAzure::setup(Some(max_keys_in_list_response)).await;
|
||||
|
||||
match upload_simple_azure_data(&enabled.client, upload_tasks_count).await {
|
||||
ControlFlow::Continue(uploads) => {
|
||||
info!("Remote objects created successfully");
|
||||
|
||||
Self::Enabled(AzureWithSimpleTestBlobs {
|
||||
enabled,
|
||||
remote_blobs: uploads,
|
||||
})
|
||||
}
|
||||
ControlFlow::Break(uploads) => Self::UploadsFailed(
|
||||
anyhow::anyhow!("One or multiple blobs failed to upload to Azure"),
|
||||
AzureWithSimpleTestBlobs {
|
||||
enabled,
|
||||
remote_blobs: uploads,
|
||||
},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
async fn teardown(self) {
|
||||
match self {
|
||||
Self::Disabled => {}
|
||||
Self::Enabled(ctx) | Self::UploadsFailed(_, ctx) => {
|
||||
cleanup(&ctx.enabled.client, ctx.remote_blobs).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_azure_client(
|
||||
max_keys_per_list_response: Option<i32>,
|
||||
) -> anyhow::Result<Arc<GenericRemoteStorage>> {
|
||||
use rand::Rng;
|
||||
|
||||
let remote_storage_azure_container = env::var("REMOTE_STORAGE_AZURE_CONTAINER").context(
|
||||
"`REMOTE_STORAGE_AZURE_CONTAINER` env var is not set, but real Azure tests are enabled",
|
||||
)?;
|
||||
let remote_storage_azure_region = env::var("REMOTE_STORAGE_AZURE_REGION").context(
|
||||
"`REMOTE_STORAGE_AZURE_REGION` env var is not set, but real Azure tests are enabled",
|
||||
)?;
|
||||
|
||||
// due to how time works, we've had test runners use the same nanos as bucket prefixes.
|
||||
// millis is just a debugging aid for easier finding the prefix later.
|
||||
let millis = std::time::SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.context("random Azure test prefix part calculation")?
|
||||
.as_millis();
|
||||
|
||||
// because nanos can be the same for two threads so can millis, add randomness
|
||||
let random = rand::thread_rng().gen::<u32>();
|
||||
|
||||
let remote_storage_config = RemoteStorageConfig {
|
||||
storage: RemoteStorageKind::AzureContainer(AzureConfig {
|
||||
container_name: remote_storage_azure_container,
|
||||
container_region: remote_storage_azure_region,
|
||||
prefix_in_container: Some(format!("test_{millis}_{random:08x}/")),
|
||||
concurrency_limit: NonZeroUsize::new(100).unwrap(),
|
||||
max_keys_per_list_response,
|
||||
}),
|
||||
};
|
||||
Ok(Arc::new(
|
||||
GenericRemoteStorage::from_config(&remote_storage_config).context("remote storage init")?,
|
||||
))
|
||||
}
|
||||
|
||||
struct Uploads {
|
||||
prefixes: HashSet<RemotePath>,
|
||||
blobs: HashSet<RemotePath>,
|
||||
}
|
||||
|
||||
async fn upload_azure_data(
|
||||
client: &Arc<GenericRemoteStorage>,
|
||||
base_prefix_str: &'static str,
|
||||
upload_tasks_count: usize,
|
||||
) -> ControlFlow<Uploads, Uploads> {
|
||||
info!("Creating {upload_tasks_count} Azure files");
|
||||
let mut upload_tasks = JoinSet::new();
|
||||
for i in 1..upload_tasks_count + 1 {
|
||||
let task_client = Arc::clone(client);
|
||||
upload_tasks.spawn(async move {
|
||||
let prefix = format!("{base_prefix_str}/sub_prefix_{i}/");
|
||||
let blob_prefix = RemotePath::new(Utf8Path::new(&prefix))
|
||||
.with_context(|| format!("{prefix:?} to RemotePath conversion"))?;
|
||||
let blob_path = blob_prefix.join(Utf8Path::new(&format!("blob_{i}")));
|
||||
debug!("Creating remote item {i} at path {blob_path:?}");
|
||||
|
||||
let data = format!("remote blob data {i}").into_bytes();
|
||||
let data_len = data.len();
|
||||
task_client
|
||||
.upload(std::io::Cursor::new(data), data_len, &blob_path, None)
|
||||
.await?;
|
||||
|
||||
Ok::<_, anyhow::Error>((blob_prefix, blob_path))
|
||||
});
|
||||
}
|
||||
|
||||
let mut upload_tasks_failed = false;
|
||||
let mut uploaded_prefixes = HashSet::with_capacity(upload_tasks_count);
|
||||
let mut uploaded_blobs = HashSet::with_capacity(upload_tasks_count);
|
||||
while let Some(task_run_result) = upload_tasks.join_next().await {
|
||||
match task_run_result
|
||||
.context("task join failed")
|
||||
.and_then(|task_result| task_result.context("upload task failed"))
|
||||
{
|
||||
Ok((upload_prefix, upload_path)) => {
|
||||
uploaded_prefixes.insert(upload_prefix);
|
||||
uploaded_blobs.insert(upload_path);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Upload task failed: {e:?}");
|
||||
upload_tasks_failed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let uploads = Uploads {
|
||||
prefixes: uploaded_prefixes,
|
||||
blobs: uploaded_blobs,
|
||||
};
|
||||
if upload_tasks_failed {
|
||||
ControlFlow::Break(uploads)
|
||||
} else {
|
||||
ControlFlow::Continue(uploads)
|
||||
}
|
||||
}
|
||||
|
||||
async fn cleanup(client: &Arc<GenericRemoteStorage>, objects_to_delete: HashSet<RemotePath>) {
|
||||
info!(
|
||||
"Removing {} objects from the remote storage during cleanup",
|
||||
objects_to_delete.len()
|
||||
);
|
||||
let mut delete_tasks = JoinSet::new();
|
||||
for object_to_delete in objects_to_delete {
|
||||
let task_client = Arc::clone(client);
|
||||
delete_tasks.spawn(async move {
|
||||
debug!("Deleting remote item at path {object_to_delete:?}");
|
||||
task_client
|
||||
.delete(&object_to_delete)
|
||||
.await
|
||||
.with_context(|| format!("{object_to_delete:?} removal"))
|
||||
});
|
||||
}
|
||||
|
||||
while let Some(task_run_result) = delete_tasks.join_next().await {
|
||||
match task_run_result {
|
||||
Ok(task_result) => match task_result {
|
||||
Ok(()) => {}
|
||||
Err(e) => error!("Delete task failed: {e:?}"),
|
||||
},
|
||||
Err(join_err) => error!("Delete task did not finish correctly: {join_err}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Uploads files `folder{j}/blob{i}.txt`. See test description for more details.
|
||||
async fn upload_simple_azure_data(
|
||||
client: &Arc<GenericRemoteStorage>,
|
||||
upload_tasks_count: usize,
|
||||
) -> ControlFlow<HashSet<RemotePath>, HashSet<RemotePath>> {
|
||||
info!("Creating {upload_tasks_count} Azure files");
|
||||
let mut upload_tasks = JoinSet::new();
|
||||
for i in 1..upload_tasks_count + 1 {
|
||||
let task_client = Arc::clone(client);
|
||||
upload_tasks.spawn(async move {
|
||||
let blob_path = PathBuf::from(format!("folder{}/blob_{}.txt", i / 7, i));
|
||||
let blob_path = RemotePath::new(
|
||||
Utf8Path::from_path(blob_path.as_path()).expect("must be valid blob path"),
|
||||
)
|
||||
.with_context(|| format!("{blob_path:?} to RemotePath conversion"))?;
|
||||
debug!("Creating remote item {i} at path {blob_path:?}");
|
||||
|
||||
let data = format!("remote blob data {i}").into_bytes();
|
||||
let data_len = data.len();
|
||||
task_client
|
||||
.upload(std::io::Cursor::new(data), data_len, &blob_path, None)
|
||||
.await?;
|
||||
|
||||
Ok::<_, anyhow::Error>(blob_path)
|
||||
});
|
||||
}
|
||||
|
||||
let mut upload_tasks_failed = false;
|
||||
let mut uploaded_blobs = HashSet::with_capacity(upload_tasks_count);
|
||||
while let Some(task_run_result) = upload_tasks.join_next().await {
|
||||
match task_run_result
|
||||
.context("task join failed")
|
||||
.and_then(|task_result| task_result.context("upload task failed"))
|
||||
{
|
||||
Ok(upload_path) => {
|
||||
uploaded_blobs.insert(upload_path);
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Upload task failed: {e:?}");
|
||||
upload_tasks_failed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if upload_tasks_failed {
|
||||
ControlFlow::Break(uploaded_blobs)
|
||||
} else {
|
||||
ControlFlow::Continue(uploaded_blobs)
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::collections::HashSet;
|
||||
use std::env;
|
||||
use std::num::{NonZeroU32, NonZeroUsize};
|
||||
use std::num::NonZeroUsize;
|
||||
use std::ops::ControlFlow;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
@@ -396,8 +396,6 @@ fn create_s3_client(
|
||||
let random = rand::thread_rng().gen::<u32>();
|
||||
|
||||
let remote_storage_config = RemoteStorageConfig {
|
||||
max_concurrent_syncs: NonZeroUsize::new(100).unwrap(),
|
||||
max_sync_errors: NonZeroU32::new(5).unwrap(),
|
||||
storage: RemoteStorageKind::AwsS3(S3Config {
|
||||
bucket_name: remote_storage_s3_bucket,
|
||||
bucket_region: remote_storage_s3_region,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
use const_format::formatcp;
|
||||
|
||||
/// Public API types
|
||||
|
||||
@@ -1,23 +1,18 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
|
||||
use utils::{
|
||||
id::{NodeId, TenantId, TimelineId},
|
||||
lsn::Lsn,
|
||||
};
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct TimelineCreateRequest {
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub tenant_id: TenantId,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub timeline_id: TimelineId,
|
||||
pub peer_ids: Option<Vec<NodeId>>,
|
||||
pub pg_version: u32,
|
||||
pub system_id: Option<u64>,
|
||||
pub wal_seg_size: Option<u32>,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub commit_lsn: Lsn,
|
||||
// If not passed, it is assigned to the beginning of commit_lsn segment.
|
||||
pub local_start_lsn: Option<Lsn>,
|
||||
@@ -28,7 +23,6 @@ fn lsn_invalid() -> Lsn {
|
||||
}
|
||||
|
||||
/// Data about safekeeper's timeline, mirrors broker.proto.
|
||||
#[serde_as]
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct SkTimelineInfo {
|
||||
/// Term.
|
||||
@@ -36,25 +30,19 @@ pub struct SkTimelineInfo {
|
||||
/// Term of the last entry.
|
||||
pub last_log_term: Option<u64>,
|
||||
/// LSN of the last record.
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
#[serde(default = "lsn_invalid")]
|
||||
pub flush_lsn: Lsn,
|
||||
/// Up to which LSN safekeeper regards its WAL as committed.
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
#[serde(default = "lsn_invalid")]
|
||||
pub commit_lsn: Lsn,
|
||||
/// LSN up to which safekeeper has backed WAL.
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
#[serde(default = "lsn_invalid")]
|
||||
pub backup_lsn: Lsn,
|
||||
/// LSN of last checkpoint uploaded by pageserver.
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
#[serde(default = "lsn_invalid")]
|
||||
pub remote_consistent_lsn: Lsn,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
#[serde(default = "lsn_invalid")]
|
||||
pub peer_horizon_lsn: Lsn,
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
#[serde(default = "lsn_invalid")]
|
||||
pub local_start_lsn: Lsn,
|
||||
/// A connection string to use for WAL receiving.
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
//! Synthetic size calculation
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
|
||||
mod calculation;
|
||||
pub mod svg;
|
||||
|
||||
@@ -32,6 +32,8 @@
|
||||
//! .init();
|
||||
//! }
|
||||
//! ```
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
|
||||
use opentelemetry::sdk::Resource;
|
||||
use opentelemetry::KeyValue;
|
||||
|
||||
@@ -5,6 +5,7 @@ edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
arc-swap.workspace = true
|
||||
sentry.workspace = true
|
||||
async-trait.workspace = true
|
||||
anyhow.workspace = true
|
||||
@@ -55,6 +56,7 @@ bytes.workspace = true
|
||||
criterion.workspace = true
|
||||
hex-literal.workspace = true
|
||||
camino-tempfile.workspace = true
|
||||
serde_assert.workspace = true
|
||||
|
||||
[[bench]]
|
||||
name = "benchmarks"
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
// For details about authentication see docs/authentication.md
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use serde;
|
||||
use std::fs;
|
||||
use std::{borrow::Cow, fmt::Display, fs, sync::Arc};
|
||||
|
||||
use anyhow::Result;
|
||||
use camino::Utf8Path;
|
||||
@@ -9,9 +10,8 @@ use jsonwebtoken::{
|
||||
decode, encode, Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
|
||||
use crate::id::TenantId;
|
||||
use crate::{http::error::ApiError, id::TenantId};
|
||||
|
||||
/// Algorithm to use. We require EdDSA.
|
||||
const STORAGE_TOKEN_ALGORITHM: Algorithm = Algorithm::EdDSA;
|
||||
@@ -32,11 +32,9 @@ pub enum Scope {
|
||||
}
|
||||
|
||||
/// JWT payload. See docs/authentication.md for the format
|
||||
#[serde_as]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
|
||||
pub struct Claims {
|
||||
#[serde(default)]
|
||||
#[serde_as(as = "Option<DisplayFromStr>")]
|
||||
pub tenant_id: Option<TenantId>,
|
||||
pub scope: Scope,
|
||||
}
|
||||
@@ -47,31 +45,106 @@ impl Claims {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SwappableJwtAuth(ArcSwap<JwtAuth>);
|
||||
|
||||
impl SwappableJwtAuth {
|
||||
pub fn new(jwt_auth: JwtAuth) -> Self {
|
||||
SwappableJwtAuth(ArcSwap::new(Arc::new(jwt_auth)))
|
||||
}
|
||||
pub fn swap(&self, jwt_auth: JwtAuth) {
|
||||
self.0.swap(Arc::new(jwt_auth));
|
||||
}
|
||||
pub fn decode(&self, token: &str) -> std::result::Result<TokenData<Claims>, AuthError> {
|
||||
self.0.load().decode(token)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for SwappableJwtAuth {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Swappable({:?})", self.0.load())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
|
||||
pub struct AuthError(pub Cow<'static, str>);
|
||||
|
||||
impl Display for AuthError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<AuthError> for ApiError {
|
||||
fn from(_value: AuthError) -> Self {
|
||||
// Don't pass on the value of the AuthError as a precautionary measure.
|
||||
// Being intentionally vague in public error communication hurts debugability
|
||||
// but it is more secure.
|
||||
ApiError::Forbidden("JWT authentication error".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct JwtAuth {
|
||||
decoding_key: DecodingKey,
|
||||
decoding_keys: Vec<DecodingKey>,
|
||||
validation: Validation,
|
||||
}
|
||||
|
||||
impl JwtAuth {
|
||||
pub fn new(decoding_key: DecodingKey) -> Self {
|
||||
pub fn new(decoding_keys: Vec<DecodingKey>) -> Self {
|
||||
let mut validation = Validation::default();
|
||||
validation.algorithms = vec![STORAGE_TOKEN_ALGORITHM];
|
||||
// The default 'required_spec_claims' is 'exp'. But we don't want to require
|
||||
// expiration.
|
||||
validation.required_spec_claims = [].into();
|
||||
Self {
|
||||
decoding_key,
|
||||
decoding_keys,
|
||||
validation,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_key_path(key_path: &Utf8Path) -> Result<Self> {
|
||||
let public_key = fs::read(key_path)?;
|
||||
Ok(Self::new(DecodingKey::from_ed_pem(&public_key)?))
|
||||
let metadata = key_path.metadata()?;
|
||||
let decoding_keys = if metadata.is_dir() {
|
||||
let mut keys = Vec::new();
|
||||
for entry in fs::read_dir(key_path)? {
|
||||
let path = entry?.path();
|
||||
if !path.is_file() {
|
||||
// Ignore directories (don't recurse)
|
||||
continue;
|
||||
}
|
||||
let public_key = fs::read(path)?;
|
||||
keys.push(DecodingKey::from_ed_pem(&public_key)?);
|
||||
}
|
||||
keys
|
||||
} else if metadata.is_file() {
|
||||
let public_key = fs::read(key_path)?;
|
||||
vec![DecodingKey::from_ed_pem(&public_key)?]
|
||||
} else {
|
||||
anyhow::bail!("path is neither a directory or a file")
|
||||
};
|
||||
if decoding_keys.is_empty() {
|
||||
anyhow::bail!("Configured for JWT auth with zero decoding keys. All JWT gated requests would be rejected.");
|
||||
}
|
||||
Ok(Self::new(decoding_keys))
|
||||
}
|
||||
|
||||
pub fn decode(&self, token: &str) -> Result<TokenData<Claims>> {
|
||||
Ok(decode(token, &self.decoding_key, &self.validation)?)
|
||||
/// Attempt to decode the token with the internal decoding keys.
|
||||
///
|
||||
/// The function tries the stored decoding keys in succession,
|
||||
/// and returns the first yielding a successful result.
|
||||
/// If there is no working decoding key, it returns the last error.
|
||||
pub fn decode(&self, token: &str) -> std::result::Result<TokenData<Claims>, AuthError> {
|
||||
let mut res = None;
|
||||
for decoding_key in &self.decoding_keys {
|
||||
res = Some(decode(token, decoding_key, &self.validation));
|
||||
if let Some(Ok(res)) = res {
|
||||
return Ok(res);
|
||||
}
|
||||
}
|
||||
if let Some(res) = res {
|
||||
res.map_err(|e| AuthError(Cow::Owned(e.to_string())))
|
||||
} else {
|
||||
Err(AuthError(Cow::Borrowed("no JWT decoding keys configured")))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,9 +184,9 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
|
||||
"#;
|
||||
|
||||
#[test]
|
||||
fn test_decode() -> Result<(), anyhow::Error> {
|
||||
fn test_decode() {
|
||||
let expected_claims = Claims {
|
||||
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081")?),
|
||||
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
|
||||
scope: Scope::Tenant,
|
||||
};
|
||||
|
||||
@@ -132,28 +205,24 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
|
||||
let encoded_eddsa = "eyJhbGciOiJFZERTQSIsInR5cCI6IkpXVCJ9.eyJzY29wZSI6InRlbmFudCIsInRlbmFudF9pZCI6IjNkMWY3NTk1YjQ2ODIzMDMwNGUwYjczY2VjYmNiMDgxIiwiaXNzIjoibmVvbi5jb250cm9scGxhbmUiLCJleHAiOjE3MDkyMDA4NzksImlhdCI6MTY3ODQ0MjQ3OX0.U3eA8j-uU-JnhzeO3EDHRuXLwkAUFCPxtGHEgw6p7Ccc3YRbFs2tmCdbD9PZEXP-XsxSeBQi1FY0YPcT3NXADw";
|
||||
|
||||
// Check it can be validated with the public key
|
||||
let auth = JwtAuth::new(DecodingKey::from_ed_pem(TEST_PUB_KEY_ED25519)?);
|
||||
let claims_from_token = auth.decode(encoded_eddsa)?.claims;
|
||||
let auth = JwtAuth::new(vec![DecodingKey::from_ed_pem(TEST_PUB_KEY_ED25519).unwrap()]);
|
||||
let claims_from_token = auth.decode(encoded_eddsa).unwrap().claims;
|
||||
assert_eq!(claims_from_token, expected_claims);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode() -> Result<(), anyhow::Error> {
|
||||
fn test_encode() {
|
||||
let claims = Claims {
|
||||
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081")?),
|
||||
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
|
||||
scope: Scope::Tenant,
|
||||
};
|
||||
|
||||
let encoded = encode_from_key_file(&claims, TEST_PRIV_KEY_ED25519)?;
|
||||
let encoded = encode_from_key_file(&claims, TEST_PRIV_KEY_ED25519).unwrap();
|
||||
|
||||
// decode it back
|
||||
let auth = JwtAuth::new(DecodingKey::from_ed_pem(TEST_PUB_KEY_ED25519)?);
|
||||
let decoded = auth.decode(&encoded)?;
|
||||
let auth = JwtAuth::new(vec![DecodingKey::from_ed_pem(TEST_PUB_KEY_ED25519).unwrap()]);
|
||||
let decoded = auth.decode(&encoded).unwrap();
|
||||
|
||||
assert_eq!(decoded.claims, claims);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ use serde::{Deserialize, Serialize};
|
||||
///
|
||||
/// See docs/rfcs/025-generation-numbers.md for detail on how generation
|
||||
/// numbers are used.
|
||||
#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord)]
|
||||
#[derive(Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
|
||||
pub enum Generation {
|
||||
// Generations with this magic value will not add a suffix to S3 keys, and will not
|
||||
// be included in persisted index_part.json. This value is only to be used
|
||||
|
||||
41
libs/utils/src/hex.rs
Normal file
41
libs/utils/src/hex.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
/// Useful type for asserting that expected bytes match reporting the bytes more readable
|
||||
/// array-syntax compatible hex bytes.
|
||||
///
|
||||
/// # Usage
|
||||
///
|
||||
/// ```
|
||||
/// use utils::Hex;
|
||||
///
|
||||
/// let actual = serialize_something();
|
||||
/// let expected = [0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64];
|
||||
///
|
||||
/// // the type implements PartialEq and on mismatch, both sides are printed in 16 wide multiline
|
||||
/// // output suffixed with an array style length for easier comparisons.
|
||||
/// assert_eq!(Hex(&actual), Hex(&expected));
|
||||
///
|
||||
/// // with `let expected = [0x68];` the error would had been:
|
||||
/// // assertion `left == right` failed
|
||||
/// // left: [0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x20, 0x77, 0x6f, 0x72, 0x6c, 0x64; 11]
|
||||
/// // right: [0x68; 1]
|
||||
/// # fn serialize_something() -> Vec<u8> { "hello world".as_bytes().to_vec() }
|
||||
/// ```
|
||||
#[derive(PartialEq)]
|
||||
pub struct Hex<'a>(pub &'a [u8]);
|
||||
|
||||
impl std::fmt::Debug for Hex<'_> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "[")?;
|
||||
for (i, c) in self.0.chunks(16).enumerate() {
|
||||
if i > 0 && !c.is_empty() {
|
||||
writeln!(f, ", ")?;
|
||||
}
|
||||
for (j, b) in c.iter().enumerate() {
|
||||
if j > 0 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
write!(f, "0x{b:02x}")?;
|
||||
}
|
||||
}
|
||||
write!(f, "; {}]", self.0.len())
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::auth::{Claims, JwtAuth};
|
||||
use crate::auth::{AuthError, Claims, SwappableJwtAuth};
|
||||
use crate::http::error::{api_error_handler, route_error_handler, ApiError};
|
||||
use anyhow::Context;
|
||||
use hyper::header::{HeaderName, AUTHORIZATION};
|
||||
@@ -14,6 +14,11 @@ use tracing::{self, debug, info, info_span, warn, Instrument};
|
||||
use std::future::Future;
|
||||
use std::str::FromStr;
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use std::io::Write as _;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
static SERVE_METRICS_COUNT: Lazy<IntCounter> = Lazy::new(|| {
|
||||
register_int_counter!(
|
||||
"libmetrics_metric_handler_requests_total",
|
||||
@@ -146,94 +151,89 @@ impl Drop for RequestCancelled {
|
||||
}
|
||||
}
|
||||
|
||||
/// An [`std::io::Write`] implementation on top of a channel sending [`bytes::Bytes`] chunks.
|
||||
pub struct ChannelWriter {
|
||||
buffer: BytesMut,
|
||||
pub tx: mpsc::Sender<std::io::Result<Bytes>>,
|
||||
written: usize,
|
||||
}
|
||||
|
||||
impl ChannelWriter {
|
||||
pub fn new(buf_len: usize, tx: mpsc::Sender<std::io::Result<Bytes>>) -> Self {
|
||||
assert_ne!(buf_len, 0);
|
||||
ChannelWriter {
|
||||
// split about half off the buffer from the start, because we flush depending on
|
||||
// capacity. first flush will come sooner than without this, but now resizes will
|
||||
// have better chance of picking up the "other" half. not guaranteed of course.
|
||||
buffer: BytesMut::with_capacity(buf_len).split_off(buf_len / 2),
|
||||
tx,
|
||||
written: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn flush0(&mut self) -> std::io::Result<usize> {
|
||||
let n = self.buffer.len();
|
||||
if n == 0 {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
tracing::trace!(n, "flushing");
|
||||
let ready = self.buffer.split().freeze();
|
||||
|
||||
// not ideal to call from blocking code to block_on, but we are sure that this
|
||||
// operation does not spawn_blocking other tasks
|
||||
let res: Result<(), ()> = tokio::runtime::Handle::current().block_on(async {
|
||||
self.tx.send(Ok(ready)).await.map_err(|_| ())?;
|
||||
|
||||
// throttle sending to allow reuse of our buffer in `write`.
|
||||
self.tx.reserve().await.map_err(|_| ())?;
|
||||
|
||||
// now the response task has picked up the buffer and hopefully started
|
||||
// sending it to the client.
|
||||
Ok(())
|
||||
});
|
||||
if res.is_err() {
|
||||
return Err(std::io::ErrorKind::BrokenPipe.into());
|
||||
}
|
||||
self.written += n;
|
||||
Ok(n)
|
||||
}
|
||||
|
||||
pub fn flushed_bytes(&self) -> usize {
|
||||
self.written
|
||||
}
|
||||
}
|
||||
|
||||
impl std::io::Write for ChannelWriter {
|
||||
fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
|
||||
let remaining = self.buffer.capacity() - self.buffer.len();
|
||||
|
||||
let out_of_space = remaining < buf.len();
|
||||
|
||||
let original_len = buf.len();
|
||||
|
||||
if out_of_space {
|
||||
let can_still_fit = buf.len() - remaining;
|
||||
self.buffer.extend_from_slice(&buf[..can_still_fit]);
|
||||
buf = &buf[can_still_fit..];
|
||||
self.flush0()?;
|
||||
}
|
||||
|
||||
// assume that this will often under normal operation just move the pointer back to the
|
||||
// beginning of allocation, because previous split off parts are already sent and
|
||||
// dropped.
|
||||
self.buffer.extend_from_slice(buf);
|
||||
Ok(original_len)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> std::io::Result<()> {
|
||||
self.flush0().map(|_| ())
|
||||
}
|
||||
}
|
||||
|
||||
async fn prometheus_metrics_handler(_req: Request<Body>) -> Result<Response<Body>, ApiError> {
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use std::io::Write as _;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
SERVE_METRICS_COUNT.inc();
|
||||
|
||||
/// An [`std::io::Write`] implementation on top of a channel sending [`bytes::Bytes`] chunks.
|
||||
struct ChannelWriter {
|
||||
buffer: BytesMut,
|
||||
tx: mpsc::Sender<std::io::Result<Bytes>>,
|
||||
written: usize,
|
||||
}
|
||||
|
||||
impl ChannelWriter {
|
||||
fn new(buf_len: usize, tx: mpsc::Sender<std::io::Result<Bytes>>) -> Self {
|
||||
assert_ne!(buf_len, 0);
|
||||
ChannelWriter {
|
||||
// split about half off the buffer from the start, because we flush depending on
|
||||
// capacity. first flush will come sooner than without this, but now resizes will
|
||||
// have better chance of picking up the "other" half. not guaranteed of course.
|
||||
buffer: BytesMut::with_capacity(buf_len).split_off(buf_len / 2),
|
||||
tx,
|
||||
written: 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn flush0(&mut self) -> std::io::Result<usize> {
|
||||
let n = self.buffer.len();
|
||||
if n == 0 {
|
||||
return Ok(0);
|
||||
}
|
||||
|
||||
tracing::trace!(n, "flushing");
|
||||
let ready = self.buffer.split().freeze();
|
||||
|
||||
// not ideal to call from blocking code to block_on, but we are sure that this
|
||||
// operation does not spawn_blocking other tasks
|
||||
let res: Result<(), ()> = tokio::runtime::Handle::current().block_on(async {
|
||||
self.tx.send(Ok(ready)).await.map_err(|_| ())?;
|
||||
|
||||
// throttle sending to allow reuse of our buffer in `write`.
|
||||
self.tx.reserve().await.map_err(|_| ())?;
|
||||
|
||||
// now the response task has picked up the buffer and hopefully started
|
||||
// sending it to the client.
|
||||
Ok(())
|
||||
});
|
||||
if res.is_err() {
|
||||
return Err(std::io::ErrorKind::BrokenPipe.into());
|
||||
}
|
||||
self.written += n;
|
||||
Ok(n)
|
||||
}
|
||||
|
||||
fn flushed_bytes(&self) -> usize {
|
||||
self.written
|
||||
}
|
||||
}
|
||||
|
||||
impl std::io::Write for ChannelWriter {
|
||||
fn write(&mut self, mut buf: &[u8]) -> std::io::Result<usize> {
|
||||
let remaining = self.buffer.capacity() - self.buffer.len();
|
||||
|
||||
let out_of_space = remaining < buf.len();
|
||||
|
||||
let original_len = buf.len();
|
||||
|
||||
if out_of_space {
|
||||
let can_still_fit = buf.len() - remaining;
|
||||
self.buffer.extend_from_slice(&buf[..can_still_fit]);
|
||||
buf = &buf[can_still_fit..];
|
||||
self.flush0()?;
|
||||
}
|
||||
|
||||
// assume that this will often under normal operation just move the pointer back to the
|
||||
// beginning of allocation, because previous split off parts are already sent and
|
||||
// dropped.
|
||||
self.buffer.extend_from_slice(buf);
|
||||
Ok(original_len)
|
||||
}
|
||||
|
||||
fn flush(&mut self) -> std::io::Result<()> {
|
||||
self.flush0().map(|_| ())
|
||||
}
|
||||
}
|
||||
|
||||
let started_at = std::time::Instant::now();
|
||||
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
@@ -389,7 +389,7 @@ fn parse_token(header_value: &str) -> Result<&str, ApiError> {
|
||||
}
|
||||
|
||||
pub fn auth_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
|
||||
provide_auth: fn(&Request<Body>) -> Option<&JwtAuth>,
|
||||
provide_auth: fn(&Request<Body>) -> Option<&SwappableJwtAuth>,
|
||||
) -> Middleware<B, ApiError> {
|
||||
Middleware::pre(move |req| async move {
|
||||
if let Some(auth) = provide_auth(&req) {
|
||||
@@ -400,9 +400,11 @@ pub fn auth_middleware<B: hyper::body::HttpBody + Send + Sync + 'static>(
|
||||
})?;
|
||||
let token = parse_token(header_value)?;
|
||||
|
||||
let data = auth
|
||||
.decode(token)
|
||||
.map_err(|_| ApiError::Unauthorized("malformed jwt token".to_string()))?;
|
||||
let data = auth.decode(token).map_err(|err| {
|
||||
warn!("Authentication error: {err}");
|
||||
// Rely on From<AuthError> for ApiError impl
|
||||
err
|
||||
})?;
|
||||
req.set_context(data.claims);
|
||||
}
|
||||
None => {
|
||||
@@ -450,12 +452,11 @@ where
|
||||
|
||||
pub fn check_permission_with(
|
||||
req: &Request<Body>,
|
||||
check_permission: impl Fn(&Claims) -> Result<(), anyhow::Error>,
|
||||
check_permission: impl Fn(&Claims) -> Result<(), AuthError>,
|
||||
) -> Result<(), ApiError> {
|
||||
match req.context::<Claims>() {
|
||||
Some(claims) => {
|
||||
Ok(check_permission(&claims).map_err(|err| ApiError::Forbidden(err.to_string()))?)
|
||||
}
|
||||
Some(claims) => Ok(check_permission(&claims)
|
||||
.map_err(|_err| ApiError::Forbidden("JWT authentication error".to_string()))?),
|
||||
None => Ok(()), // claims is None because auth is disabled
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};
|
||||
use std::borrow::Cow;
|
||||
use std::error::Error as StdError;
|
||||
use thiserror::Error;
|
||||
use tracing::{error, info};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ApiError {
|
||||
@@ -118,7 +118,11 @@ pub fn api_error_handler(api_error: ApiError) -> Response<Body> {
|
||||
// Print a stack trace for Internal Server errors
|
||||
|
||||
match api_error {
|
||||
ApiError::Forbidden(_) | ApiError::Unauthorized(_) => {
|
||||
warn!("Error processing HTTP request: {api_error:#}")
|
||||
}
|
||||
ApiError::ResourceUnavailable(_) => info!("Error processing HTTP request: {api_error:#}"),
|
||||
ApiError::NotFound(_) => info!("Error processing HTTP request: {api_error:#}"),
|
||||
ApiError::InternalServerError(_) => error!("Error processing HTTP request: {api_error:?}"),
|
||||
_ => error!("Error processing HTTP request: {api_error:#}"),
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::{fmt, str::FromStr};
|
||||
use anyhow::Context;
|
||||
use hex::FromHex;
|
||||
use rand::Rng;
|
||||
use serde::de::Visitor;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
@@ -17,12 +18,74 @@ pub enum IdError {
|
||||
///
|
||||
/// NOTE: It (de)serializes as an array of hex bytes, so the string representation would look
|
||||
/// like `[173,80,132,115,129,226,72,254,170,201,135,108,199,26,228,24]`.
|
||||
///
|
||||
/// Use `#[serde_as(as = "DisplayFromStr")]` to (de)serialize it as hex string instead: `ad50847381e248feaac9876cc71ae418`.
|
||||
/// Check the `serde_with::serde_as` documentation for options for more complex types.
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, PartialOrd, Ord)]
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
|
||||
struct Id([u8; 16]);
|
||||
|
||||
impl Serialize for Id {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
if serializer.is_human_readable() {
|
||||
serializer.collect_str(self)
|
||||
} else {
|
||||
self.0.serialize(serializer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Id {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
struct IdVisitor {
|
||||
is_human_readable_deserializer: bool,
|
||||
}
|
||||
|
||||
impl<'de> Visitor<'de> for IdVisitor {
|
||||
type Value = Id;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
if self.is_human_readable_deserializer {
|
||||
formatter.write_str("value in form of hex string")
|
||||
} else {
|
||||
formatter.write_str("value in form of integer array([u8; 16])")
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: serde::de::SeqAccess<'de>,
|
||||
{
|
||||
let s = serde::de::value::SeqAccessDeserializer::new(seq);
|
||||
let id: [u8; 16] = Deserialize::deserialize(s)?;
|
||||
Ok(Id::from(id))
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Id::from_str(v).map_err(E::custom)
|
||||
}
|
||||
}
|
||||
|
||||
if deserializer.is_human_readable() {
|
||||
deserializer.deserialize_str(IdVisitor {
|
||||
is_human_readable_deserializer: true,
|
||||
})
|
||||
} else {
|
||||
deserializer.deserialize_tuple(
|
||||
16,
|
||||
IdVisitor {
|
||||
is_human_readable_deserializer: false,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Id {
|
||||
pub fn get_from_buf(buf: &mut impl bytes::Buf) -> Id {
|
||||
let mut arr = [0u8; 16];
|
||||
@@ -57,6 +120,8 @@ impl Id {
|
||||
chunk[0] = HEX[((b >> 4) & 0xf) as usize];
|
||||
chunk[1] = HEX[(b & 0xf) as usize];
|
||||
}
|
||||
|
||||
// SAFETY: vec constructed out of `HEX`, it can only be ascii
|
||||
unsafe { String::from_utf8_unchecked(buf) }
|
||||
}
|
||||
}
|
||||
@@ -308,3 +373,112 @@ impl fmt::Display for NodeId {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use serde_assert::{Deserializer, Serializer, Token, Tokens};
|
||||
|
||||
use crate::bin_ser::BeSer;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_id_serde_non_human_readable() {
|
||||
let original_id = Id([
|
||||
173, 80, 132, 115, 129, 226, 72, 254, 170, 201, 135, 108, 199, 26, 228, 24,
|
||||
]);
|
||||
let expected_tokens = Tokens(vec![
|
||||
Token::Tuple { len: 16 },
|
||||
Token::U8(173),
|
||||
Token::U8(80),
|
||||
Token::U8(132),
|
||||
Token::U8(115),
|
||||
Token::U8(129),
|
||||
Token::U8(226),
|
||||
Token::U8(72),
|
||||
Token::U8(254),
|
||||
Token::U8(170),
|
||||
Token::U8(201),
|
||||
Token::U8(135),
|
||||
Token::U8(108),
|
||||
Token::U8(199),
|
||||
Token::U8(26),
|
||||
Token::U8(228),
|
||||
Token::U8(24),
|
||||
Token::TupleEnd,
|
||||
]);
|
||||
|
||||
let serializer = Serializer::builder().is_human_readable(false).build();
|
||||
let serialized_tokens = original_id.serialize(&serializer).unwrap();
|
||||
assert_eq!(serialized_tokens, expected_tokens);
|
||||
|
||||
let mut deserializer = Deserializer::builder()
|
||||
.is_human_readable(false)
|
||||
.tokens(serialized_tokens)
|
||||
.build();
|
||||
let deserialized_id = Id::deserialize(&mut deserializer).unwrap();
|
||||
assert_eq!(deserialized_id, original_id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_id_serde_human_readable() {
|
||||
let original_id = Id([
|
||||
173, 80, 132, 115, 129, 226, 72, 254, 170, 201, 135, 108, 199, 26, 228, 24,
|
||||
]);
|
||||
let expected_tokens = Tokens(vec![Token::Str(String::from(
|
||||
"ad50847381e248feaac9876cc71ae418",
|
||||
))]);
|
||||
|
||||
let serializer = Serializer::builder().is_human_readable(true).build();
|
||||
let serialized_tokens = original_id.serialize(&serializer).unwrap();
|
||||
assert_eq!(serialized_tokens, expected_tokens);
|
||||
|
||||
let mut deserializer = Deserializer::builder()
|
||||
.is_human_readable(true)
|
||||
.tokens(Tokens(vec![Token::Str(String::from(
|
||||
"ad50847381e248feaac9876cc71ae418",
|
||||
))]))
|
||||
.build();
|
||||
assert_eq!(Id::deserialize(&mut deserializer).unwrap(), original_id);
|
||||
}
|
||||
|
||||
macro_rules! roundtrip_type {
|
||||
($type:ty, $expected_bytes:expr) => {{
|
||||
let expected_bytes: [u8; 16] = $expected_bytes;
|
||||
let original_id = <$type>::from(expected_bytes);
|
||||
|
||||
let ser_bytes = original_id.ser().unwrap();
|
||||
assert_eq!(ser_bytes, expected_bytes);
|
||||
|
||||
let des_id = <$type>::des(&ser_bytes).unwrap();
|
||||
assert_eq!(des_id, original_id);
|
||||
}};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_id_bincode_serde() {
|
||||
let expected_bytes = [
|
||||
173, 80, 132, 115, 129, 226, 72, 254, 170, 201, 135, 108, 199, 26, 228, 24,
|
||||
];
|
||||
|
||||
roundtrip_type!(Id, expected_bytes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tenant_id_bincode_serde() {
|
||||
let expected_bytes = [
|
||||
173, 80, 132, 115, 129, 226, 72, 254, 170, 201, 135, 108, 199, 26, 228, 24,
|
||||
];
|
||||
|
||||
roundtrip_type!(TenantId, expected_bytes);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timeline_id_bincode_serde() {
|
||||
let expected_bytes = [
|
||||
173, 80, 132, 115, 129, 226, 72, 254, 170, 201, 135, 108, 199, 26, 228, 24,
|
||||
];
|
||||
|
||||
roundtrip_type!(TimelineId, expected_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
//! `utils` is intended to be a place to put code that is shared
|
||||
//! between other crates in this repository.
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
|
||||
pub mod backoff;
|
||||
|
||||
@@ -24,6 +25,10 @@ pub mod auth;
|
||||
|
||||
// utility functions and helper traits for unified unique id generation/serialization etc.
|
||||
pub mod id;
|
||||
|
||||
mod hex;
|
||||
pub use hex::Hex;
|
||||
|
||||
// http endpoint utils
|
||||
pub mod http;
|
||||
|
||||
@@ -73,6 +78,11 @@ pub mod completion;
|
||||
/// Reporting utilities
|
||||
pub mod error;
|
||||
|
||||
/// async timeout helper
|
||||
pub mod timeout;
|
||||
|
||||
pub mod sync;
|
||||
|
||||
/// This is a shortcut to embed git sha into binaries and avoid copying the same build script to all packages
|
||||
///
|
||||
/// we have several cases:
|
||||
@@ -128,6 +138,21 @@ macro_rules! project_git_version {
|
||||
};
|
||||
}
|
||||
|
||||
/// This is a shortcut to embed build tag into binaries and avoid copying the same build script to all packages
|
||||
#[macro_export]
|
||||
macro_rules! project_build_tag {
|
||||
($const_identifier:ident) => {
|
||||
const $const_identifier: &::core::primitive::str = {
|
||||
const __ARG: &[&::core::primitive::str; 2] = &match ::core::option_env!("BUILD_TAG") {
|
||||
::core::option::Option::Some(x) => ["build_tag-env:", x],
|
||||
::core::option::Option::None => ["build_tag:", ""],
|
||||
};
|
||||
|
||||
$crate::__const_format::concatcp!(__ARG[0], __ARG[1])
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
/// Re-export for `project_git_version` macro
|
||||
#[doc(hidden)]
|
||||
pub use const_format as __const_format;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#![warn(missing_docs)]
|
||||
|
||||
use camino::Utf8Path;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::{de::Visitor, Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
use std::ops::{Add, AddAssign};
|
||||
use std::str::FromStr;
|
||||
@@ -13,10 +13,114 @@ use crate::seqwait::MonotonicCounter;
|
||||
pub const XLOG_BLCKSZ: u32 = 8192;
|
||||
|
||||
/// A Postgres LSN (Log Sequence Number), also known as an XLogRecPtr
|
||||
#[derive(Clone, Copy, Eq, Ord, PartialEq, PartialOrd, Hash, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
#[derive(Clone, Copy, Eq, Ord, PartialEq, PartialOrd, Hash)]
|
||||
pub struct Lsn(pub u64);
|
||||
|
||||
impl Serialize for Lsn {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
if serializer.is_human_readable() {
|
||||
serializer.collect_str(self)
|
||||
} else {
|
||||
self.0.serialize(serializer)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Lsn {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
struct LsnVisitor {
|
||||
is_human_readable_deserializer: bool,
|
||||
}
|
||||
|
||||
impl<'de> Visitor<'de> for LsnVisitor {
|
||||
type Value = Lsn;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
if self.is_human_readable_deserializer {
|
||||
formatter.write_str(
|
||||
"value in form of hex string({upper_u32_hex}/{lower_u32_hex}) representing u64 integer",
|
||||
)
|
||||
} else {
|
||||
formatter.write_str("value in form of integer(u64)")
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(Lsn(v))
|
||||
}
|
||||
|
||||
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Lsn::from_str(v).map_err(|e| E::custom(e))
|
||||
}
|
||||
}
|
||||
|
||||
if deserializer.is_human_readable() {
|
||||
deserializer.deserialize_str(LsnVisitor {
|
||||
is_human_readable_deserializer: true,
|
||||
})
|
||||
} else {
|
||||
deserializer.deserialize_u64(LsnVisitor {
|
||||
is_human_readable_deserializer: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Allows (de)serialization of an `Lsn` always as `u64`.
|
||||
///
|
||||
/// ### Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # use serde::{Serialize, Deserialize};
|
||||
/// use utils::lsn::Lsn;
|
||||
///
|
||||
/// #[derive(PartialEq, Serialize, Deserialize, Debug)]
|
||||
/// struct Foo {
|
||||
/// #[serde(with = "utils::lsn::serde_as_u64")]
|
||||
/// always_u64: Lsn,
|
||||
/// }
|
||||
///
|
||||
/// let orig = Foo { always_u64: Lsn(1234) };
|
||||
///
|
||||
/// let res = serde_json::to_string(&orig).unwrap();
|
||||
/// assert_eq!(res, r#"{"always_u64":1234}"#);
|
||||
///
|
||||
/// let foo = serde_json::from_str::<Foo>(&res).unwrap();
|
||||
/// assert_eq!(foo, orig);
|
||||
/// ```
|
||||
///
|
||||
pub mod serde_as_u64 {
|
||||
use super::Lsn;
|
||||
|
||||
/// Serializes the Lsn as u64 disregarding the human readability of the format.
|
||||
///
|
||||
/// Meant to be used via `#[serde(with = "...")]` or `#[serde(serialize_with = "...")]`.
|
||||
pub fn serialize<S: serde::Serializer>(lsn: &Lsn, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
use serde::Serialize;
|
||||
lsn.0.serialize(serializer)
|
||||
}
|
||||
|
||||
/// Deserializes the Lsn as u64 disregarding the human readability of the format.
|
||||
///
|
||||
/// Meant to be used via `#[serde(with = "...")]` or `#[serde(deserialize_with = "...")]`.
|
||||
pub fn deserialize<'de, D: serde::Deserializer<'de>>(deserializer: D) -> Result<Lsn, D::Error> {
|
||||
use serde::Deserialize;
|
||||
u64::deserialize(deserializer).map(Lsn)
|
||||
}
|
||||
}
|
||||
|
||||
/// We tried to parse an LSN from a string, but failed
|
||||
#[derive(Debug, PartialEq, Eq, thiserror::Error)]
|
||||
#[error("LsnParseError")]
|
||||
@@ -264,8 +368,13 @@ impl MonotonicCounter<Lsn> for RecordLsn {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::bin_ser::BeSer;
|
||||
|
||||
use super::*;
|
||||
|
||||
use serde::ser::Serialize;
|
||||
use serde_assert::{Deserializer, Serializer, Token, Tokens};
|
||||
|
||||
#[test]
|
||||
fn test_lsn_strings() {
|
||||
assert_eq!("12345678/AAAA5555".parse(), Ok(Lsn(0x12345678AAAA5555)));
|
||||
@@ -341,4 +450,95 @@ mod tests {
|
||||
assert_eq!(lsn.fetch_max(Lsn(6000)), Lsn(5678));
|
||||
assert_eq!(lsn.fetch_max(Lsn(5000)), Lsn(6000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lsn_serde() {
|
||||
let original_lsn = Lsn(0x0123456789abcdef);
|
||||
let expected_readable_tokens = Tokens(vec![Token::U64(0x0123456789abcdef)]);
|
||||
let expected_non_readable_tokens =
|
||||
Tokens(vec![Token::Str(String::from("1234567/89ABCDEF"))]);
|
||||
|
||||
// Testing human_readable ser/de
|
||||
let serializer = Serializer::builder().is_human_readable(false).build();
|
||||
let readable_ser_tokens = original_lsn.serialize(&serializer).unwrap();
|
||||
assert_eq!(readable_ser_tokens, expected_readable_tokens);
|
||||
|
||||
let mut deserializer = Deserializer::builder()
|
||||
.is_human_readable(false)
|
||||
.tokens(readable_ser_tokens)
|
||||
.build();
|
||||
let des_lsn = Lsn::deserialize(&mut deserializer).unwrap();
|
||||
assert_eq!(des_lsn, original_lsn);
|
||||
|
||||
// Testing NON human_readable ser/de
|
||||
let serializer = Serializer::builder().is_human_readable(true).build();
|
||||
let non_readable_ser_tokens = original_lsn.serialize(&serializer).unwrap();
|
||||
assert_eq!(non_readable_ser_tokens, expected_non_readable_tokens);
|
||||
|
||||
let mut deserializer = Deserializer::builder()
|
||||
.is_human_readable(true)
|
||||
.tokens(non_readable_ser_tokens)
|
||||
.build();
|
||||
let des_lsn = Lsn::deserialize(&mut deserializer).unwrap();
|
||||
assert_eq!(des_lsn, original_lsn);
|
||||
|
||||
// Testing mismatching ser/de
|
||||
let serializer = Serializer::builder().is_human_readable(false).build();
|
||||
let non_readable_ser_tokens = original_lsn.serialize(&serializer).unwrap();
|
||||
|
||||
let mut deserializer = Deserializer::builder()
|
||||
.is_human_readable(true)
|
||||
.tokens(non_readable_ser_tokens)
|
||||
.build();
|
||||
Lsn::deserialize(&mut deserializer).unwrap_err();
|
||||
|
||||
let serializer = Serializer::builder().is_human_readable(true).build();
|
||||
let readable_ser_tokens = original_lsn.serialize(&serializer).unwrap();
|
||||
|
||||
let mut deserializer = Deserializer::builder()
|
||||
.is_human_readable(false)
|
||||
.tokens(readable_ser_tokens)
|
||||
.build();
|
||||
Lsn::deserialize(&mut deserializer).unwrap_err();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lsn_ensure_roundtrip() {
|
||||
let original_lsn = Lsn(0xaaaabbbb);
|
||||
|
||||
let serializer = Serializer::builder().is_human_readable(false).build();
|
||||
let ser_tokens = original_lsn.serialize(&serializer).unwrap();
|
||||
|
||||
let mut deserializer = Deserializer::builder()
|
||||
.is_human_readable(false)
|
||||
.tokens(ser_tokens)
|
||||
.build();
|
||||
|
||||
let des_lsn = Lsn::deserialize(&mut deserializer).unwrap();
|
||||
assert_eq!(des_lsn, original_lsn);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lsn_bincode_serde() {
|
||||
let lsn = Lsn(0x0123456789abcdef);
|
||||
let expected_bytes = [0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef];
|
||||
|
||||
let ser_bytes = lsn.ser().unwrap();
|
||||
assert_eq!(ser_bytes, expected_bytes);
|
||||
|
||||
let des_lsn = Lsn::des(&ser_bytes).unwrap();
|
||||
assert_eq!(des_lsn, lsn);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lsn_bincode_ensure_roundtrip() {
|
||||
let original_lsn = Lsn(0x01_02_03_04_05_06_07_08);
|
||||
let expected_bytes = vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
|
||||
|
||||
let ser_bytes = original_lsn.ser().unwrap();
|
||||
assert_eq!(ser_bytes, expected_bytes);
|
||||
|
||||
let des_lsn = Lsn::des(&ser_bytes).unwrap();
|
||||
assert_eq!(des_lsn, original_lsn);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ use std::time::{Duration, SystemTime};
|
||||
use bytes::{Buf, BufMut, Bytes, BytesMut};
|
||||
use pq_proto::{read_cstr, PG_EPOCH};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{serde_as, DisplayFromStr};
|
||||
use tracing::{trace, warn};
|
||||
|
||||
use crate::lsn::Lsn;
|
||||
@@ -15,21 +14,17 @@ use crate::lsn::Lsn;
|
||||
///
|
||||
/// serde Serialize is used only for human readable dump to json (e.g. in
|
||||
/// safekeepers debug_dump).
|
||||
#[serde_as]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct PageserverFeedback {
|
||||
/// Last known size of the timeline. Used to enforce timeline size limit.
|
||||
pub current_timeline_size: u64,
|
||||
/// LSN last received and ingested by the pageserver. Controls backpressure.
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub last_received_lsn: Lsn,
|
||||
/// LSN up to which data is persisted by the pageserver to its local disc.
|
||||
/// Controls backpressure.
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub disk_consistent_lsn: Lsn,
|
||||
/// LSN up to which data is persisted by the pageserver on s3; safekeepers
|
||||
/// consider WAL before it can be removed.
|
||||
#[serde_as(as = "DisplayFromStr")]
|
||||
pub remote_consistent_lsn: Lsn,
|
||||
// Serialize with RFC3339 format.
|
||||
#[serde(with = "serde_systemtime")]
|
||||
|
||||
@@ -125,6 +125,9 @@ where
|
||||
// Wake everyone with an error.
|
||||
let mut internal = self.internal.lock().unwrap();
|
||||
|
||||
// Block any future waiters from starting
|
||||
internal.shutdown = true;
|
||||
|
||||
// This will steal the entire waiters map.
|
||||
// When we drop it all waiters will be woken.
|
||||
mem::take(&mut internal.waiters)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
/// Immediately terminate the calling process without calling
|
||||
/// atexit callbacks, C runtime destructors etc. We mainly use
|
||||
/// this to protect coverage data from concurrent writes.
|
||||
pub fn exit_now(code: u8) {
|
||||
pub fn exit_now(code: u8) -> ! {
|
||||
// SAFETY: exiting is safe, the ffi is not safe
|
||||
unsafe { nix::libc::_exit(code as _) };
|
||||
}
|
||||
|
||||
3
libs/utils/src/sync.rs
Normal file
3
libs/utils/src/sync.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod heavier_once_cell;
|
||||
|
||||
pub mod gate;
|
||||
158
libs/utils/src/sync/gate.rs
Normal file
158
libs/utils/src/sync/gate.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
/// Gates are a concurrency helper, primarily used for implementing safe shutdown.
|
||||
///
|
||||
/// Users of a resource call `enter()` to acquire a GateGuard, and the owner of
|
||||
/// the resource calls `close()` when they want to ensure that all holders of guards
|
||||
/// have released them, and that no future guards will be issued.
|
||||
pub struct Gate {
|
||||
/// Each caller of enter() takes one unit from the semaphore. In close(), we
|
||||
/// take all the units to ensure all GateGuards are destroyed.
|
||||
sem: Arc<tokio::sync::Semaphore>,
|
||||
|
||||
/// For observability only: a name that will be used to log warnings if a particular
|
||||
/// gate is holding up shutdown
|
||||
name: String,
|
||||
}
|
||||
|
||||
/// RAII guard for a [`Gate`]: as long as this exists, calls to [`Gate::close`] will
|
||||
/// not complete.
|
||||
#[derive(Debug)]
|
||||
pub struct GateGuard(tokio::sync::OwnedSemaphorePermit);
|
||||
|
||||
/// Observability helper: every `warn_period`, emit a log warning that we're still waiting on this gate
|
||||
async fn warn_if_stuck<Fut: std::future::Future>(
|
||||
fut: Fut,
|
||||
name: &str,
|
||||
warn_period: std::time::Duration,
|
||||
) -> <Fut as std::future::Future>::Output {
|
||||
let started = std::time::Instant::now();
|
||||
|
||||
let mut fut = std::pin::pin!(fut);
|
||||
|
||||
loop {
|
||||
match tokio::time::timeout(warn_period, &mut fut).await {
|
||||
Ok(ret) => return ret,
|
||||
Err(_) => {
|
||||
tracing::warn!(
|
||||
gate = name,
|
||||
elapsed_ms = started.elapsed().as_millis(),
|
||||
"still waiting, taking longer than expected..."
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum GateError {
|
||||
GateClosed,
|
||||
}
|
||||
|
||||
impl Gate {
|
||||
const MAX_UNITS: u32 = u32::MAX;
|
||||
|
||||
pub fn new(name: String) -> Self {
|
||||
Self {
|
||||
sem: Arc::new(tokio::sync::Semaphore::new(Self::MAX_UNITS as usize)),
|
||||
name,
|
||||
}
|
||||
}
|
||||
|
||||
/// Acquire a guard that will prevent close() calls from completing. If close()
|
||||
/// was already called, this will return an error which should be interpreted
|
||||
/// as "shutting down".
|
||||
///
|
||||
/// This function would typically be used from e.g. request handlers. While holding
|
||||
/// the guard returned from this function, it is important to respect a CancellationToken
|
||||
/// to avoid blocking close() indefinitely: typically types that contain a Gate will
|
||||
/// also contain a CancellationToken.
|
||||
pub fn enter(&self) -> Result<GateGuard, GateError> {
|
||||
self.sem
|
||||
.clone()
|
||||
.try_acquire_owned()
|
||||
.map(GateGuard)
|
||||
.map_err(|_| GateError::GateClosed)
|
||||
}
|
||||
|
||||
/// Types with a shutdown() method and a gate should call this method at the
|
||||
/// end of shutdown, to ensure that all GateGuard holders are done.
|
||||
///
|
||||
/// This will wait for all guards to be destroyed. For this to complete promptly, it is
|
||||
/// important that the holders of such guards are respecting a CancellationToken which has
|
||||
/// been cancelled before entering this function.
|
||||
pub async fn close(&self) {
|
||||
warn_if_stuck(self.do_close(), &self.name, Duration::from_millis(1000)).await
|
||||
}
|
||||
|
||||
/// Check if [`Self::close()`] has finished waiting for all [`Self::enter()`] users to finish. This
|
||||
/// is usually analoguous for "Did shutdown finish?" for types that include a Gate, whereas checking
|
||||
/// the CancellationToken on such types is analogous to "Did shutdown start?"
|
||||
pub fn close_complete(&self) -> bool {
|
||||
self.sem.is_closed()
|
||||
}
|
||||
|
||||
async fn do_close(&self) {
|
||||
tracing::debug!(gate = self.name, "Closing Gate...");
|
||||
match self.sem.acquire_many(Self::MAX_UNITS).await {
|
||||
Ok(_units) => {
|
||||
// While holding all units, close the semaphore. All subsequent calls to enter() will fail.
|
||||
self.sem.close();
|
||||
}
|
||||
Err(_) => {
|
||||
// Semaphore closed: we are the only function that can do this, so it indicates a double-call.
|
||||
// This is legal. Timeline::shutdown for example is not protected from being called more than
|
||||
// once.
|
||||
tracing::debug!(gate = self.name, "Double close")
|
||||
}
|
||||
}
|
||||
tracing::debug!(gate = self.name, "Closed Gate.")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use futures::FutureExt;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_idle_gate() {
|
||||
// Having taken no gates, we should not be blocked in close
|
||||
let gate = Gate::new("test".to_string());
|
||||
gate.close().await;
|
||||
|
||||
// If a guard is dropped before entering, close should not be blocked
|
||||
let gate = Gate::new("test".to_string());
|
||||
let guard = gate.enter().unwrap();
|
||||
drop(guard);
|
||||
gate.close().await;
|
||||
|
||||
// Entering a closed guard fails
|
||||
gate.enter().expect_err("enter should fail after close");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_busy_gate() {
|
||||
let gate = Gate::new("test".to_string());
|
||||
|
||||
let guard = gate.enter().unwrap();
|
||||
|
||||
let mut close_fut = std::pin::pin!(gate.close());
|
||||
|
||||
// Close should be blocked
|
||||
assert!(close_fut.as_mut().now_or_never().is_none());
|
||||
|
||||
// Attempting to enter() should fail, even though close isn't done yet.
|
||||
gate.enter()
|
||||
.expect_err("enter should fail after entering close");
|
||||
|
||||
drop(guard);
|
||||
|
||||
// Guard is gone, close should finish
|
||||
assert!(close_fut.as_mut().now_or_never().is_some());
|
||||
|
||||
// Attempting to enter() is still forbidden
|
||||
gate.enter().expect_err("enter should fail finishing close");
|
||||
}
|
||||
}
|
||||
383
libs/utils/src/sync/heavier_once_cell.rs
Normal file
383
libs/utils/src/sync/heavier_once_cell.rs
Normal file
@@ -0,0 +1,383 @@
|
||||
use std::sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc, Mutex, MutexGuard,
|
||||
};
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
/// Custom design like [`tokio::sync::OnceCell`] but using [`OwnedSemaphorePermit`] instead of
|
||||
/// `SemaphorePermit`, allowing use of `take` which does not require holding an outer mutex guard
|
||||
/// for the duration of initialization.
|
||||
///
|
||||
/// Has no unsafe, builds upon [`tokio::sync::Semaphore`] and [`std::sync::Mutex`].
|
||||
///
|
||||
/// [`OwnedSemaphorePermit`]: tokio::sync::OwnedSemaphorePermit
|
||||
pub struct OnceCell<T> {
|
||||
inner: Mutex<Inner<T>>,
|
||||
initializers: AtomicUsize,
|
||||
}
|
||||
|
||||
impl<T> Default for OnceCell<T> {
|
||||
/// Create new uninitialized [`OnceCell`].
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
inner: Default::default(),
|
||||
initializers: AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Semaphore is the current state:
|
||||
/// - open semaphore means the value is `None`, not yet initialized
|
||||
/// - closed semaphore means the value has been initialized
|
||||
#[derive(Debug)]
|
||||
struct Inner<T> {
|
||||
init_semaphore: Arc<Semaphore>,
|
||||
value: Option<T>,
|
||||
}
|
||||
|
||||
impl<T> Default for Inner<T> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
init_semaphore: Arc::new(Semaphore::new(1)),
|
||||
value: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> OnceCell<T> {
|
||||
/// Creates an already initialized `OnceCell` with the given value.
|
||||
pub fn new(value: T) -> Self {
|
||||
let sem = Semaphore::new(1);
|
||||
sem.close();
|
||||
Self {
|
||||
inner: Mutex::new(Inner {
|
||||
init_semaphore: Arc::new(sem),
|
||||
value: Some(value),
|
||||
}),
|
||||
initializers: AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a guard to an existing initialized value, or uniquely initializes the value before
|
||||
/// returning the guard.
|
||||
///
|
||||
/// Initializing might wait on any existing [`Guard::take_and_deinit`] deinitialization.
|
||||
///
|
||||
/// Initialization is panic-safe and cancellation-safe.
|
||||
pub async fn get_or_init<F, Fut, E>(&self, factory: F) -> Result<Guard<'_, T>, E>
|
||||
where
|
||||
F: FnOnce(InitPermit) -> Fut,
|
||||
Fut: std::future::Future<Output = Result<(T, InitPermit), E>>,
|
||||
{
|
||||
let sem = {
|
||||
let guard = self.inner.lock().unwrap();
|
||||
if guard.value.is_some() {
|
||||
return Ok(Guard(guard));
|
||||
}
|
||||
guard.init_semaphore.clone()
|
||||
};
|
||||
|
||||
let permit = {
|
||||
// increment the count for the duration of queued
|
||||
let _guard = CountWaitingInitializers::start(self);
|
||||
sem.acquire_owned().await
|
||||
};
|
||||
|
||||
match permit {
|
||||
Ok(permit) => {
|
||||
let permit = InitPermit(permit);
|
||||
let (value, _permit) = factory(permit).await?;
|
||||
|
||||
let guard = self.inner.lock().unwrap();
|
||||
|
||||
Ok(Self::set0(value, guard))
|
||||
}
|
||||
Err(_closed) => {
|
||||
let guard = self.inner.lock().unwrap();
|
||||
assert!(
|
||||
guard.value.is_some(),
|
||||
"semaphore got closed, must be initialized"
|
||||
);
|
||||
return Ok(Guard(guard));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Assuming a permit is held after previous call to [`Guard::take_and_deinit`], it can be used
|
||||
/// to complete initializing the inner value.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the inner has already been initialized.
|
||||
pub fn set(&self, value: T, _permit: InitPermit) -> Guard<'_, T> {
|
||||
let guard = self.inner.lock().unwrap();
|
||||
|
||||
// cannot assert that this permit is for self.inner.semaphore, but we can assert it cannot
|
||||
// give more permits right now.
|
||||
if guard.init_semaphore.try_acquire().is_ok() {
|
||||
drop(guard);
|
||||
panic!("permit is of wrong origin");
|
||||
}
|
||||
|
||||
Self::set0(value, guard)
|
||||
}
|
||||
|
||||
fn set0(value: T, mut guard: std::sync::MutexGuard<'_, Inner<T>>) -> Guard<'_, T> {
|
||||
if guard.value.is_some() {
|
||||
drop(guard);
|
||||
unreachable!("we won permit, must not be initialized");
|
||||
}
|
||||
guard.value = Some(value);
|
||||
guard.init_semaphore.close();
|
||||
Guard(guard)
|
||||
}
|
||||
|
||||
/// Returns a guard to an existing initialized value, if any.
|
||||
pub fn get(&self) -> Option<Guard<'_, T>> {
|
||||
let guard = self.inner.lock().unwrap();
|
||||
if guard.value.is_some() {
|
||||
Some(Guard(guard))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the number of [`Self::get_or_init`] calls waiting for initialization to complete.
|
||||
pub fn initializer_count(&self) -> usize {
|
||||
self.initializers.load(Ordering::Relaxed)
|
||||
}
|
||||
}
|
||||
|
||||
/// DropGuard counter for queued tasks waiting to initialize, mainly accessible for the
|
||||
/// initializing task for example at the end of initialization.
|
||||
struct CountWaitingInitializers<'a, T>(&'a OnceCell<T>);
|
||||
|
||||
impl<'a, T> CountWaitingInitializers<'a, T> {
|
||||
fn start(target: &'a OnceCell<T>) -> Self {
|
||||
target.initializers.fetch_add(1, Ordering::Relaxed);
|
||||
CountWaitingInitializers(target)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Drop for CountWaitingInitializers<'a, T> {
|
||||
fn drop(&mut self) {
|
||||
self.0.initializers.fetch_sub(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
|
||||
/// Uninteresting guard object to allow short-lived access to inspect or clone the held,
|
||||
/// initialized value.
|
||||
#[derive(Debug)]
|
||||
pub struct Guard<'a, T>(MutexGuard<'a, Inner<T>>);
|
||||
|
||||
impl<T> std::ops::Deref for Guard<'_, T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.0
|
||||
.value
|
||||
.as_ref()
|
||||
.expect("guard is not created unless value has been initialized")
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> std::ops::DerefMut for Guard<'_, T> {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
self.0
|
||||
.value
|
||||
.as_mut()
|
||||
.expect("guard is not created unless value has been initialized")
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T> Guard<'a, T> {
|
||||
/// Take the current value, and a new permit for it's deinitialization.
|
||||
///
|
||||
/// The permit will be on a semaphore part of the new internal value, and any following
|
||||
/// [`OnceCell::get_or_init`] will wait on it to complete.
|
||||
pub fn take_and_deinit(&mut self) -> (T, InitPermit) {
|
||||
let mut swapped = Inner::default();
|
||||
let permit = swapped
|
||||
.init_semaphore
|
||||
.clone()
|
||||
.try_acquire_owned()
|
||||
.expect("we just created this");
|
||||
std::mem::swap(&mut *self.0, &mut swapped);
|
||||
swapped
|
||||
.value
|
||||
.map(|v| (v, InitPermit(permit)))
|
||||
.expect("guard is not created unless value has been initialized")
|
||||
}
|
||||
}
|
||||
|
||||
/// Type held by OnceCell (de)initializing task.
|
||||
pub struct InitPermit(tokio::sync::OwnedSemaphorePermit);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::{
|
||||
convert::Infallible,
|
||||
sync::atomic::{AtomicUsize, Ordering},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn many_initializers() {
|
||||
#[derive(Default, Debug)]
|
||||
struct Counters {
|
||||
factory_got_to_run: AtomicUsize,
|
||||
future_polled: AtomicUsize,
|
||||
winners: AtomicUsize,
|
||||
}
|
||||
|
||||
let initializers = 100;
|
||||
|
||||
let cell = Arc::new(OnceCell::default());
|
||||
let counters = Arc::new(Counters::default());
|
||||
let barrier = Arc::new(tokio::sync::Barrier::new(initializers + 1));
|
||||
|
||||
let mut js = tokio::task::JoinSet::new();
|
||||
for i in 0..initializers {
|
||||
js.spawn({
|
||||
let cell = cell.clone();
|
||||
let counters = counters.clone();
|
||||
let barrier = barrier.clone();
|
||||
|
||||
async move {
|
||||
barrier.wait().await;
|
||||
let won = {
|
||||
let g = cell
|
||||
.get_or_init(|permit| {
|
||||
counters.factory_got_to_run.fetch_add(1, Ordering::Relaxed);
|
||||
async {
|
||||
counters.future_polled.fetch_add(1, Ordering::Relaxed);
|
||||
Ok::<_, Infallible>((i, permit))
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
*g == i
|
||||
};
|
||||
|
||||
if won {
|
||||
counters.winners.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
barrier.wait().await;
|
||||
|
||||
while let Some(next) = js.join_next().await {
|
||||
next.expect("no panics expected");
|
||||
}
|
||||
|
||||
let mut counters = Arc::try_unwrap(counters).unwrap();
|
||||
|
||||
assert_eq!(*counters.factory_got_to_run.get_mut(), 1);
|
||||
assert_eq!(*counters.future_polled.get_mut(), 1);
|
||||
assert_eq!(*counters.winners.get_mut(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn reinit_waits_for_deinit() {
|
||||
// with the tokio::time paused, we will "sleep" for 1s while holding the reinitialization
|
||||
let sleep_for = Duration::from_secs(1);
|
||||
let initial = 42;
|
||||
let reinit = 1;
|
||||
let cell = Arc::new(OnceCell::new(initial));
|
||||
|
||||
let deinitialization_started = Arc::new(tokio::sync::Barrier::new(2));
|
||||
|
||||
let jh = tokio::spawn({
|
||||
let cell = cell.clone();
|
||||
let deinitialization_started = deinitialization_started.clone();
|
||||
async move {
|
||||
let (answer, _permit) = cell.get().expect("initialized to value").take_and_deinit();
|
||||
assert_eq!(answer, initial);
|
||||
|
||||
deinitialization_started.wait().await;
|
||||
tokio::time::sleep(sleep_for).await;
|
||||
}
|
||||
});
|
||||
|
||||
deinitialization_started.wait().await;
|
||||
|
||||
let started_at = tokio::time::Instant::now();
|
||||
cell.get_or_init(|permit| async { Ok::<_, Infallible>((reinit, permit)) })
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let elapsed = started_at.elapsed();
|
||||
assert!(
|
||||
elapsed >= sleep_for,
|
||||
"initialization should had taken at least the time time slept with permit"
|
||||
);
|
||||
|
||||
jh.await.unwrap();
|
||||
|
||||
assert_eq!(*cell.get().unwrap(), reinit);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reinit_with_deinit_permit() {
|
||||
let cell = Arc::new(OnceCell::new(42));
|
||||
|
||||
let (mol, permit) = cell.get().unwrap().take_and_deinit();
|
||||
cell.set(5, permit);
|
||||
assert_eq!(*cell.get().unwrap(), 5);
|
||||
|
||||
let (five, permit) = cell.get().unwrap().take_and_deinit();
|
||||
assert_eq!(5, five);
|
||||
cell.set(mol, permit);
|
||||
assert_eq!(*cell.get().unwrap(), 42);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn initialization_attemptable_until_ok() {
|
||||
let cell = OnceCell::default();
|
||||
|
||||
for _ in 0..10 {
|
||||
cell.get_or_init(|_permit| async { Err("whatever error") })
|
||||
.await
|
||||
.unwrap_err();
|
||||
}
|
||||
|
||||
let g = cell
|
||||
.get_or_init(|permit| async { Ok::<_, Infallible>(("finally success", permit)) })
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(*g, "finally success");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn initialization_is_cancellation_safe() {
|
||||
let cell = OnceCell::default();
|
||||
|
||||
let barrier = tokio::sync::Barrier::new(2);
|
||||
|
||||
let initializer = cell.get_or_init(|permit| async {
|
||||
barrier.wait().await;
|
||||
futures::future::pending::<()>().await;
|
||||
|
||||
Ok::<_, Infallible>(("never reached", permit))
|
||||
});
|
||||
|
||||
tokio::select! {
|
||||
_ = initializer => { unreachable!("cannot complete; stuck in pending().await") },
|
||||
_ = barrier.wait() => {}
|
||||
};
|
||||
|
||||
// now initializer is dropped
|
||||
|
||||
assert!(cell.get().is_none());
|
||||
|
||||
let g = cell
|
||||
.get_or_init(|permit| async { Ok::<_, Infallible>(("now initialized", permit)) })
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(*g, "now initialized");
|
||||
}
|
||||
}
|
||||
37
libs/utils/src/timeout.rs
Normal file
37
libs/utils/src/timeout.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
pub enum TimeoutCancellableError {
|
||||
Timeout,
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
/// Wrap [`tokio::time::timeout`] with a CancellationToken.
|
||||
///
|
||||
/// This wrapper is appropriate for any long running operation in a task
|
||||
/// that ought to respect a CancellationToken (which means most tasks).
|
||||
///
|
||||
/// The only time you should use a bare tokio::timeout is when the future `F`
|
||||
/// itself respects a CancellationToken: otherwise, always use this wrapper
|
||||
/// with your CancellationToken to ensure that your task does not hold up
|
||||
/// graceful shutdown.
|
||||
pub async fn timeout_cancellable<F>(
|
||||
duration: Duration,
|
||||
cancel: &CancellationToken,
|
||||
future: F,
|
||||
) -> Result<F::Output, TimeoutCancellableError>
|
||||
where
|
||||
F: std::future::Future,
|
||||
{
|
||||
tokio::select!(
|
||||
r = tokio::time::timeout(duration, future) => {
|
||||
r.map_err(|_| TimeoutCancellableError::Timeout)
|
||||
|
||||
},
|
||||
_ = cancel.cancelled() => {
|
||||
Err(TimeoutCancellableError::Cancelled)
|
||||
|
||||
}
|
||||
)
|
||||
}
|
||||
@@ -19,13 +19,12 @@ inotify.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
sysinfo.workspace = true
|
||||
tokio.workspace = true
|
||||
tokio = { workspace = true, features = ["rt-multi-thread"] }
|
||||
tokio-postgres.workspace = true
|
||||
tokio-stream.workspace = true
|
||||
tokio-util.workspace = true
|
||||
tracing.workspace = true
|
||||
tracing-subscriber.workspace = true
|
||||
workspace_hack = { version = "0.1", path = "../../workspace_hack" }
|
||||
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
cgroups-rs = "0.3.3"
|
||||
|
||||
@@ -27,8 +27,8 @@ and old one if it exists.
|
||||
* the filecache: a struct that allows communication with the Postgres file cache.
|
||||
On startup, we connect to the filecache and hold on to the connection for the
|
||||
entire monitor lifetime.
|
||||
* the cgroup watcher: the `CgroupWatcher` manages the `neon-postgres` cgroup by
|
||||
listening for `memory.high` events and setting its `memory.{high,max}` values.
|
||||
* the cgroup watcher: the `CgroupWatcher` polls the `neon-postgres` cgroup's memory
|
||||
usage and sends rolling aggregates to the runner.
|
||||
* the runner: the runner marries the filecache and cgroup watcher together,
|
||||
communicating with the agent throught the `Dispatcher`, and then calling filecache
|
||||
and cgroup watcher functions as needed to upscale and downscale
|
||||
|
||||
@@ -1,161 +1,38 @@
|
||||
use std::{
|
||||
fmt::{Debug, Display},
|
||||
fs,
|
||||
pin::pin,
|
||||
sync::atomic::{AtomicU64, Ordering},
|
||||
};
|
||||
use std::fmt::{self, Debug, Formatter};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use anyhow::{anyhow, bail, Context};
|
||||
use anyhow::{anyhow, Context};
|
||||
use cgroups_rs::{
|
||||
freezer::FreezerController,
|
||||
hierarchies::{self, is_cgroup2_unified_mode, UNIFIED_MOUNTPOINT},
|
||||
hierarchies::{self, is_cgroup2_unified_mode},
|
||||
memory::MemController,
|
||||
MaxValue,
|
||||
Subsystem::{Freezer, Mem},
|
||||
Subsystem,
|
||||
};
|
||||
use inotify::{EventStream, Inotify, WatchMask};
|
||||
use tokio::sync::mpsc::{self, error::TryRecvError};
|
||||
use tokio::time::{Duration, Instant};
|
||||
use tokio_stream::{Stream, StreamExt};
|
||||
use tokio::sync::watch;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::protocol::Resources;
|
||||
use crate::MiB;
|
||||
|
||||
/// Monotonically increasing counter of the number of memory.high events
|
||||
/// the cgroup has experienced.
|
||||
///
|
||||
/// We use this to determine if a modification to the `memory.events` file actually
|
||||
/// changed the `high` field. If not, we don't care about the change. When we
|
||||
/// read the file, we check the `high` field in the file against `MEMORY_EVENT_COUNT`
|
||||
/// to see if it changed since last time.
|
||||
pub static MEMORY_EVENT_COUNT: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
/// Monotonically increasing counter that gives each cgroup event a unique id.
|
||||
///
|
||||
/// This allows us to answer questions like "did this upscale arrive before this
|
||||
/// memory.high?". This static is also used by the `Sequenced` type to "tag" values
|
||||
/// with a sequence number. As such, prefer to used the `Sequenced` type rather
|
||||
/// than this static directly.
|
||||
static EVENT_SEQUENCE_NUMBER: AtomicU64 = AtomicU64::new(0);
|
||||
|
||||
/// A memory event type reported in memory.events.
|
||||
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
|
||||
pub enum MemoryEvent {
|
||||
Low,
|
||||
High,
|
||||
Max,
|
||||
Oom,
|
||||
OomKill,
|
||||
OomGroupKill,
|
||||
}
|
||||
|
||||
impl MemoryEvent {
|
||||
fn as_str(&self) -> &str {
|
||||
match self {
|
||||
MemoryEvent::Low => "low",
|
||||
MemoryEvent::High => "high",
|
||||
MemoryEvent::Max => "max",
|
||||
MemoryEvent::Oom => "oom",
|
||||
MemoryEvent::OomKill => "oom_kill",
|
||||
MemoryEvent::OomGroupKill => "oom_group_kill",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for MemoryEvent {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for a `CgroupWatcher`
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
// The target difference between the total memory reserved for the cgroup
|
||||
// and the value of the cgroup's memory.high.
|
||||
//
|
||||
// In other words, memory.high + oom_buffer_bytes will equal the total memory that the cgroup may
|
||||
// use (equal to system memory, minus whatever's taken out for the file cache).
|
||||
oom_buffer_bytes: u64,
|
||||
/// Interval at which we should be fetching memory statistics
|
||||
memory_poll_interval: Duration,
|
||||
|
||||
// The amount of memory, in bytes, below a proposed new value for
|
||||
// memory.high that the cgroup's memory usage must be for us to downscale
|
||||
//
|
||||
// In other words, we can downscale only when:
|
||||
//
|
||||
// memory.current + memory_high_buffer_bytes < (proposed) memory.high
|
||||
//
|
||||
// TODO: there's some minor issues with this approach -- in particular, that we might have
|
||||
// memory in use by the kernel's page cache that we're actually ok with getting rid of.
|
||||
pub(crate) memory_high_buffer_bytes: u64,
|
||||
|
||||
// The maximum duration, in milliseconds, that we're allowed to pause
|
||||
// the cgroup for while waiting for the autoscaler-agent to upscale us
|
||||
max_upscale_wait: Duration,
|
||||
|
||||
// The required minimum time, in milliseconds, that we must wait before re-freezing
|
||||
// the cgroup while waiting for the autoscaler-agent to upscale us.
|
||||
do_not_freeze_more_often_than: Duration,
|
||||
|
||||
// The amount of memory, in bytes, that we should periodically increase memory.high
|
||||
// by while waiting for the autoscaler-agent to upscale us.
|
||||
//
|
||||
// This exists to avoid the excessive throttling that happens when a cgroup is above its
|
||||
// memory.high for too long. See more here:
|
||||
// https://github.com/neondatabase/autoscaling/issues/44#issuecomment-1522487217
|
||||
memory_high_increase_by_bytes: u64,
|
||||
|
||||
// The period, in milliseconds, at which we should repeatedly increase the value
|
||||
// of the cgroup's memory.high while we're waiting on upscaling and memory.high
|
||||
// is still being hit.
|
||||
//
|
||||
// Technically speaking, this actually serves as a rate limit to moderate responding to
|
||||
// memory.high events, but these are roughly equivalent if the process is still allocating
|
||||
// memory.
|
||||
memory_high_increase_every: Duration,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Calculate the new value for the cgroups memory.high based on system memory
|
||||
pub fn calculate_memory_high_value(&self, total_system_mem: u64) -> u64 {
|
||||
total_system_mem.saturating_sub(self.oom_buffer_bytes)
|
||||
}
|
||||
/// The number of samples used in constructing aggregated memory statistics
|
||||
memory_history_len: usize,
|
||||
/// The number of most recent samples that will be periodically logged.
|
||||
///
|
||||
/// Each sample is logged exactly once. Increasing this value means that recent samples will be
|
||||
/// logged less frequently, and vice versa.
|
||||
///
|
||||
/// For simplicity, this value must be greater than or equal to `memory_history_len`.
|
||||
memory_history_log_interval: usize,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
oom_buffer_bytes: 100 * MiB,
|
||||
memory_high_buffer_bytes: 100 * MiB,
|
||||
// while waiting for upscale, don't freeze for more than 20ms every 1s
|
||||
max_upscale_wait: Duration::from_millis(20),
|
||||
do_not_freeze_more_often_than: Duration::from_millis(1000),
|
||||
// while waiting for upscale, increase memory.high by 10MiB every 25ms
|
||||
memory_high_increase_by_bytes: 10 * MiB,
|
||||
memory_high_increase_every: Duration::from_millis(25),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Used to represent data that is associated with a certain point in time, such
|
||||
/// as an upscale request or memory.high event.
|
||||
///
|
||||
/// Internally, creating a `Sequenced` uses a static atomic counter to obtain
|
||||
/// a unique sequence number. Sequence numbers are monotonically increasing,
|
||||
/// allowing us to answer questions like "did this upscale happen after this
|
||||
/// memory.high event?" by comparing the sequence numbers of the two events.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Sequenced<T> {
|
||||
seqnum: u64,
|
||||
data: T,
|
||||
}
|
||||
|
||||
impl<T> Sequenced<T> {
|
||||
pub fn new(data: T) -> Self {
|
||||
Self {
|
||||
seqnum: EVENT_SEQUENCE_NUMBER.fetch_add(1, Ordering::AcqRel),
|
||||
data,
|
||||
memory_poll_interval: Duration::from_millis(100),
|
||||
memory_history_len: 5, // use 500ms of history for decision-making
|
||||
memory_history_log_interval: 20, // but only log every ~2s (otherwise it's spammy)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -170,74 +47,14 @@ impl<T> Sequenced<T> {
|
||||
pub struct CgroupWatcher {
|
||||
pub config: Config,
|
||||
|
||||
/// The sequence number of the last upscale.
|
||||
///
|
||||
/// If we receive a memory.high event that has a _lower_ sequence number than
|
||||
/// `last_upscale_seqnum`, then we know it occured before the upscale, and we
|
||||
/// can safely ignore it.
|
||||
///
|
||||
/// Note: Like the `events` field, this doesn't _need_ interior mutability but we
|
||||
/// use it anyways so that methods take `&self`, not `&mut self`.
|
||||
last_upscale_seqnum: AtomicU64,
|
||||
|
||||
/// A channel on which we send messages to request upscale from the dispatcher.
|
||||
upscale_requester: mpsc::Sender<()>,
|
||||
|
||||
/// The actual cgroup we are watching and managing.
|
||||
cgroup: cgroups_rs::Cgroup,
|
||||
}
|
||||
|
||||
/// Read memory.events for the desired event type.
|
||||
///
|
||||
/// `path` specifies the path to the desired `memory.events` file.
|
||||
/// For more info, see the `memory.events` section of the [kernel docs]
|
||||
/// <https://docs.kernel.org/admin-guide/cgroup-v2.html#memory-interface-files>
|
||||
fn get_event_count(path: &str, event: MemoryEvent) -> anyhow::Result<u64> {
|
||||
let contents = fs::read_to_string(path)
|
||||
.with_context(|| format!("failed to read memory.events from {path}"))?;
|
||||
|
||||
// Then contents of the file look like:
|
||||
// low 42
|
||||
// high 101
|
||||
// ...
|
||||
contents
|
||||
.lines()
|
||||
.filter_map(|s| s.split_once(' '))
|
||||
.find(|(e, _)| *e == event.as_str())
|
||||
.ok_or_else(|| anyhow!("failed to find entry for memory.{event} events in {path}"))
|
||||
.and_then(|(_, count)| {
|
||||
count
|
||||
.parse::<u64>()
|
||||
.with_context(|| format!("failed to parse memory.{event} as u64"))
|
||||
})
|
||||
}
|
||||
|
||||
/// Create an event stream that produces events whenever the file at the provided
|
||||
/// path is modified.
|
||||
fn create_file_watcher(path: &str) -> anyhow::Result<EventStream<[u8; 1024]>> {
|
||||
info!("creating file watcher for {path}");
|
||||
let inotify = Inotify::init().context("failed to initialize file watcher")?;
|
||||
inotify
|
||||
.watches()
|
||||
.add(path, WatchMask::MODIFY)
|
||||
.with_context(|| format!("failed to start watching {path}"))?;
|
||||
inotify
|
||||
// The inotify docs use [0u8; 1024] so we'll just copy them. We only need
|
||||
// to store one event at a time - if the event gets written over, that's
|
||||
// ok. We still see that there is an event. For more information, see:
|
||||
// https://man7.org/linux/man-pages/man7/inotify.7.html
|
||||
.into_event_stream([0u8; 1024])
|
||||
.context("failed to start inotify event stream")
|
||||
}
|
||||
|
||||
impl CgroupWatcher {
|
||||
/// Create a new `CgroupWatcher`.
|
||||
#[tracing::instrument(skip_all, fields(%name))]
|
||||
pub fn new(
|
||||
name: String,
|
||||
// A channel on which to send upscale requests
|
||||
upscale_requester: mpsc::Sender<()>,
|
||||
) -> anyhow::Result<(Self, impl Stream<Item = Sequenced<u64>>)> {
|
||||
pub fn new(name: String) -> anyhow::Result<Self> {
|
||||
// TODO: clarify exactly why we need v2
|
||||
// Make sure cgroups v2 (aka unified) are supported
|
||||
if !is_cgroup2_unified_mode() {
|
||||
@@ -245,410 +62,203 @@ impl CgroupWatcher {
|
||||
}
|
||||
let cgroup = cgroups_rs::Cgroup::load(hierarchies::auto(), &name);
|
||||
|
||||
// Start monitoring the cgroup for memory events. In general, for
|
||||
// cgroups v2 (aka unified), metrics are reported in files like
|
||||
// > `/sys/fs/cgroup/{name}/{metric}`
|
||||
// We are looking for `memory.high` events, which are stored in the
|
||||
// file `memory.events`. For more info, see the `memory.events` section
|
||||
// of https://docs.kernel.org/admin-guide/cgroup-v2.html#memory-interface-files
|
||||
let path = format!("{}/{}/memory.events", UNIFIED_MOUNTPOINT, &name);
|
||||
let memory_events = create_file_watcher(&path)
|
||||
.with_context(|| format!("failed to create event watcher for {path}"))?
|
||||
// This would be nice with with .inspect_err followed by .ok
|
||||
.filter_map(move |_| match get_event_count(&path, MemoryEvent::High) {
|
||||
Ok(high) => Some(high),
|
||||
Err(error) => {
|
||||
// TODO: Might want to just panic here
|
||||
warn!(?error, "failed to read high events count from {}", &path);
|
||||
None
|
||||
}
|
||||
})
|
||||
// Only report the event if the memory.high count increased
|
||||
.filter_map(|high| {
|
||||
if MEMORY_EVENT_COUNT.fetch_max(high, Ordering::AcqRel) < high {
|
||||
Some(high)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.map(Sequenced::new);
|
||||
|
||||
let initial_count = get_event_count(
|
||||
&format!("{}/{}/memory.events", UNIFIED_MOUNTPOINT, &name),
|
||||
MemoryEvent::High,
|
||||
)?;
|
||||
|
||||
info!(initial_count, "initial memory.high event count");
|
||||
|
||||
// Hard update `MEMORY_EVENT_COUNT` since there could have been processes
|
||||
// running in the cgroup before that caused it to be non-zero.
|
||||
MEMORY_EVENT_COUNT.fetch_max(initial_count, Ordering::AcqRel);
|
||||
|
||||
Ok((
|
||||
Self {
|
||||
cgroup,
|
||||
upscale_requester,
|
||||
last_upscale_seqnum: AtomicU64::new(0),
|
||||
config: Default::default(),
|
||||
},
|
||||
memory_events,
|
||||
))
|
||||
Ok(Self {
|
||||
cgroup,
|
||||
config: Default::default(),
|
||||
})
|
||||
}
|
||||
|
||||
/// The entrypoint for the `CgroupWatcher`.
|
||||
#[tracing::instrument(skip_all)]
|
||||
pub async fn watch<E>(
|
||||
pub async fn watch(
|
||||
&self,
|
||||
// These are ~dependency injected~ (fancy, I know) because this function
|
||||
// should never return.
|
||||
// -> therefore: when we tokio::spawn it, we don't await the JoinHandle.
|
||||
// -> therefore: if we want to stick it in an Arc so many threads can access
|
||||
// it, methods can never take mutable access.
|
||||
// - note: we use the Arc strategy so that a) we can call this function
|
||||
// right here and b) the runner can call the set/get_memory methods
|
||||
// -> since calling recv() on a tokio::sync::mpsc::Receiver takes &mut self,
|
||||
// we just pass them in here instead of holding them in fields, as that
|
||||
// would require this method to take &mut self.
|
||||
mut upscales: mpsc::Receiver<Sequenced<Resources>>,
|
||||
events: E,
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
E: Stream<Item = Sequenced<u64>>,
|
||||
{
|
||||
let mut wait_to_freeze = pin!(tokio::time::sleep(Duration::ZERO));
|
||||
let mut last_memory_high_increase_at: Option<Instant> = None;
|
||||
let mut events = pin!(events);
|
||||
|
||||
// Are we waiting to be upscaled? Could be true if we request upscale due
|
||||
// to a memory.high event and it does not arrive in time.
|
||||
let mut waiting_on_upscale = false;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
upscale = upscales.recv() => {
|
||||
let Sequenced { seqnum, data } = upscale
|
||||
.context("failed to listen on upscale notification channel")?;
|
||||
waiting_on_upscale = false;
|
||||
last_memory_high_increase_at = None;
|
||||
self.last_upscale_seqnum.store(seqnum, Ordering::Release);
|
||||
info!(cpu = data.cpu, mem_bytes = data.mem, "received upscale");
|
||||
}
|
||||
event = events.next() => {
|
||||
let Some(Sequenced { seqnum, .. }) = event else {
|
||||
bail!("failed to listen for memory.high events")
|
||||
};
|
||||
// The memory.high came before our last upscale, so we consider
|
||||
// it resolved
|
||||
if self.last_upscale_seqnum.fetch_max(seqnum, Ordering::AcqRel) > seqnum {
|
||||
info!(
|
||||
"received memory.high event, but it came before our last upscale -> ignoring it"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// The memory.high came after our latest upscale. We don't
|
||||
// want to do anything yet, so peek the next event in hopes
|
||||
// that it's an upscale.
|
||||
if let Some(upscale_num) = self
|
||||
.upscaled(&mut upscales)
|
||||
.context("failed to check if we were upscaled")?
|
||||
{
|
||||
if upscale_num > seqnum {
|
||||
info!(
|
||||
"received memory.high event, but it came before our last upscale -> ignoring it"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// If it's been long enough since we last froze, freeze the
|
||||
// cgroup and request upscale
|
||||
if wait_to_freeze.is_elapsed() {
|
||||
info!("received memory.high event -> requesting upscale");
|
||||
waiting_on_upscale = self
|
||||
.handle_memory_high_event(&mut upscales)
|
||||
.await
|
||||
.context("failed to handle upscale")?;
|
||||
wait_to_freeze
|
||||
.as_mut()
|
||||
.reset(Instant::now() + self.config.do_not_freeze_more_often_than);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Ok, we can't freeze, just request upscale
|
||||
if !waiting_on_upscale {
|
||||
info!("received memory.high event, but too soon to refreeze -> requesting upscale");
|
||||
|
||||
// Make check to make sure we haven't been upscaled in the
|
||||
// meantine (can happen if the agent independently decides
|
||||
// to upscale us again)
|
||||
if self
|
||||
.upscaled(&mut upscales)
|
||||
.context("failed to check if we were upscaled")?
|
||||
.is_some()
|
||||
{
|
||||
info!("no need to request upscaling because we got upscaled");
|
||||
continue;
|
||||
}
|
||||
self.upscale_requester
|
||||
.send(())
|
||||
.await
|
||||
.context("failed to request upscale")?;
|
||||
waiting_on_upscale = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Shoot, we can't freeze or and we're still waiting on upscale,
|
||||
// increase memory.high to reduce throttling
|
||||
let can_increase_memory_high = match last_memory_high_increase_at {
|
||||
None => true,
|
||||
Some(t) => t.elapsed() > self.config.memory_high_increase_every,
|
||||
};
|
||||
if can_increase_memory_high {
|
||||
info!(
|
||||
"received memory.high event, \
|
||||
but too soon to refreeze and already requested upscale \
|
||||
-> increasing memory.high"
|
||||
);
|
||||
|
||||
// Make check to make sure we haven't been upscaled in the
|
||||
// meantine (can happen if the agent independently decides
|
||||
// to upscale us again)
|
||||
if self
|
||||
.upscaled(&mut upscales)
|
||||
.context("failed to check if we were upscaled")?
|
||||
.is_some()
|
||||
{
|
||||
info!("no need to increase memory.high because got upscaled");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Request upscale anyways (the agent will handle deduplicating
|
||||
// requests)
|
||||
self.upscale_requester
|
||||
.send(())
|
||||
.await
|
||||
.context("failed to request upscale")?;
|
||||
|
||||
let memory_high =
|
||||
self.get_memory_high_bytes().context("failed to get memory.high")?;
|
||||
let new_high = memory_high + self.config.memory_high_increase_by_bytes;
|
||||
info!(
|
||||
current_high_bytes = memory_high,
|
||||
new_high_bytes = new_high,
|
||||
"updating memory.high"
|
||||
);
|
||||
self.set_memory_high_bytes(new_high)
|
||||
.context("failed to set memory.high")?;
|
||||
last_memory_high_increase_at = Some(Instant::now());
|
||||
continue;
|
||||
}
|
||||
|
||||
info!("received memory.high event, but can't do anything");
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a `memory.high`, returning whether we are still waiting on upscale
|
||||
/// by the time the function returns.
|
||||
///
|
||||
/// The general plan for handling a `memory.high` event is as follows:
|
||||
/// 1. Freeze the cgroup
|
||||
/// 2. Start a timer for `self.config.max_upscale_wait`
|
||||
/// 3. Request upscale
|
||||
/// 4. After the timer elapses or we receive upscale, thaw the cgroup.
|
||||
/// 5. Return whether or not we are still waiting for upscale. If we are,
|
||||
/// we'll increase the cgroups memory.high to avoid getting oom killed
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn handle_memory_high_event(
|
||||
&self,
|
||||
upscales: &mut mpsc::Receiver<Sequenced<Resources>>,
|
||||
) -> anyhow::Result<bool> {
|
||||
// Immediately freeze the cgroup before doing anything else.
|
||||
info!("received memory.high event -> freezing cgroup");
|
||||
self.freeze().context("failed to freeze cgroup")?;
|
||||
|
||||
// We'll use this for logging durations
|
||||
let start_time = Instant::now();
|
||||
|
||||
// Await the upscale until we have to unfreeze
|
||||
let timed =
|
||||
tokio::time::timeout(self.config.max_upscale_wait, self.await_upscale(upscales));
|
||||
|
||||
// Request the upscale
|
||||
info!(
|
||||
wait = ?self.config.max_upscale_wait,
|
||||
"sending request for immediate upscaling",
|
||||
);
|
||||
self.upscale_requester
|
||||
.send(())
|
||||
.await
|
||||
.context("failed to request upscale")?;
|
||||
|
||||
let waiting_on_upscale = match timed.await {
|
||||
Ok(Ok(())) => {
|
||||
info!(elapsed = ?start_time.elapsed(), "received upscale in time");
|
||||
false
|
||||
}
|
||||
// **important**: unfreeze the cgroup before ?-reporting the error
|
||||
Ok(Err(e)) => {
|
||||
info!("error waiting for upscale -> thawing cgroup");
|
||||
self.thaw()
|
||||
.context("failed to thaw cgroup after errored waiting for upscale")?;
|
||||
Err(e.context("failed to await upscale"))?
|
||||
}
|
||||
Err(_) => {
|
||||
info!(elapsed = ?self.config.max_upscale_wait, "timed out waiting for upscale");
|
||||
true
|
||||
}
|
||||
};
|
||||
|
||||
info!("thawing cgroup");
|
||||
self.thaw().context("failed to thaw cgroup")?;
|
||||
|
||||
Ok(waiting_on_upscale)
|
||||
}
|
||||
|
||||
/// Checks whether we were just upscaled, returning the upscale's sequence
|
||||
/// number if so.
|
||||
#[tracing::instrument(skip_all)]
|
||||
fn upscaled(
|
||||
&self,
|
||||
upscales: &mut mpsc::Receiver<Sequenced<Resources>>,
|
||||
) -> anyhow::Result<Option<u64>> {
|
||||
let Sequenced { seqnum, data } = match upscales.try_recv() {
|
||||
Ok(upscale) => upscale,
|
||||
Err(TryRecvError::Empty) => return Ok(None),
|
||||
Err(TryRecvError::Disconnected) => {
|
||||
bail!("upscale notification channel was disconnected")
|
||||
}
|
||||
};
|
||||
|
||||
// Make sure to update the last upscale sequence number
|
||||
self.last_upscale_seqnum.store(seqnum, Ordering::Release);
|
||||
info!(cpu = data.cpu, mem_bytes = data.mem, "received upscale");
|
||||
Ok(Some(seqnum))
|
||||
}
|
||||
|
||||
/// Await an upscale event, discarding any `memory.high` events received in
|
||||
/// the process.
|
||||
///
|
||||
/// This is used in `handle_memory_high_event`, where we need to listen
|
||||
/// for upscales in particular so we know if we can thaw the cgroup early.
|
||||
#[tracing::instrument(skip_all)]
|
||||
async fn await_upscale(
|
||||
&self,
|
||||
upscales: &mut mpsc::Receiver<Sequenced<Resources>>,
|
||||
updates: watch::Sender<(Instant, MemoryHistory)>,
|
||||
) -> anyhow::Result<()> {
|
||||
let Sequenced { seqnum, .. } = upscales
|
||||
.recv()
|
||||
.await
|
||||
.context("error listening for upscales")?;
|
||||
// this requirement makes the code a bit easier to work with; see the config for more.
|
||||
assert!(self.config.memory_history_len <= self.config.memory_history_log_interval);
|
||||
|
||||
self.last_upscale_seqnum.store(seqnum, Ordering::Release);
|
||||
Ok(())
|
||||
}
|
||||
let mut ticker = tokio::time::interval(self.config.memory_poll_interval);
|
||||
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
// ticker.reset_immediately(); // FIXME: enable this once updating to tokio >= 1.30.0
|
||||
|
||||
/// Get the cgroup's name.
|
||||
pub fn path(&self) -> &str {
|
||||
self.cgroup.path()
|
||||
}
|
||||
}
|
||||
let mem_controller = self.memory()?;
|
||||
|
||||
// Methods for manipulating the actual cgroup
|
||||
impl CgroupWatcher {
|
||||
/// Get a handle on the freezer subsystem.
|
||||
fn freezer(&self) -> anyhow::Result<&FreezerController> {
|
||||
if let Some(Freezer(freezer)) = self
|
||||
.cgroup
|
||||
.subsystems()
|
||||
.iter()
|
||||
.find(|sub| matches!(sub, Freezer(_)))
|
||||
{
|
||||
Ok(freezer)
|
||||
} else {
|
||||
anyhow::bail!("could not find freezer subsystem")
|
||||
// buffer for samples that will be logged. once full, it remains so.
|
||||
let history_log_len = self.config.memory_history_log_interval;
|
||||
let mut history_log_buf = vec![MemoryStatus::zeroed(); history_log_len];
|
||||
|
||||
for t in 0_u64.. {
|
||||
ticker.tick().await;
|
||||
|
||||
let now = Instant::now();
|
||||
let mem = Self::memory_usage(mem_controller);
|
||||
|
||||
let i = t as usize % history_log_len;
|
||||
history_log_buf[i] = mem;
|
||||
|
||||
// We're taking *at most* memory_history_len values; we may be bounded by the total
|
||||
// number of samples that have come in so far.
|
||||
let samples_count = (t + 1).min(self.config.memory_history_len as u64) as usize;
|
||||
// NB: in `ring_buf_recent_values_iter`, `i` is *inclusive*, which matches the fact
|
||||
// that we just inserted a value there, so the end of the iterator will *include* the
|
||||
// value at i, rather than stopping just short of it.
|
||||
let samples = ring_buf_recent_values_iter(&history_log_buf, i, samples_count);
|
||||
|
||||
let summary = MemoryHistory {
|
||||
avg_non_reclaimable: samples.map(|h| h.non_reclaimable).sum::<u64>()
|
||||
/ samples_count as u64,
|
||||
samples_count,
|
||||
samples_span: self.config.memory_poll_interval * (samples_count - 1) as u32,
|
||||
};
|
||||
|
||||
// Log the current history if it's time to do so. Because `history_log_buf` has length
|
||||
// equal to the logging interval, we can just log the entire buffer every time we set
|
||||
// the last entry, which also means that for this log line, we can ignore that it's a
|
||||
// ring buffer (because all the entries are in order of increasing time).
|
||||
if i == history_log_len - 1 {
|
||||
info!(
|
||||
history = ?MemoryStatus::debug_slice(&history_log_buf),
|
||||
summary = ?summary,
|
||||
"Recent cgroup memory statistics history"
|
||||
);
|
||||
}
|
||||
|
||||
updates
|
||||
.send((now, summary))
|
||||
.context("failed to send MemoryHistory")?;
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempt to freeze the cgroup.
|
||||
pub fn freeze(&self) -> anyhow::Result<()> {
|
||||
self.freezer()
|
||||
.context("failed to get freezer subsystem")?
|
||||
.freeze()
|
||||
.context("failed to freeze")
|
||||
}
|
||||
|
||||
/// Attempt to thaw the cgroup.
|
||||
pub fn thaw(&self) -> anyhow::Result<()> {
|
||||
self.freezer()
|
||||
.context("failed to get freezer subsystem")?
|
||||
.thaw()
|
||||
.context("failed to thaw")
|
||||
unreachable!()
|
||||
}
|
||||
|
||||
/// Get a handle on the memory subsystem.
|
||||
///
|
||||
/// Note: this method does not require `self.memory_update_lock` because
|
||||
/// getting a handle to the subsystem does not access any of the files we
|
||||
/// care about, such as memory.high and memory.events
|
||||
fn memory(&self) -> anyhow::Result<&MemController> {
|
||||
if let Some(Mem(memory)) = self
|
||||
.cgroup
|
||||
self.cgroup
|
||||
.subsystems()
|
||||
.iter()
|
||||
.find(|sub| matches!(sub, Mem(_)))
|
||||
{
|
||||
Ok(memory)
|
||||
} else {
|
||||
anyhow::bail!("could not find memory subsystem")
|
||||
}
|
||||
}
|
||||
|
||||
/// Get cgroup current memory usage.
|
||||
pub fn current_memory_usage(&self) -> anyhow::Result<u64> {
|
||||
Ok(self
|
||||
.memory()
|
||||
.context("failed to get memory subsystem")?
|
||||
.memory_stat()
|
||||
.usage_in_bytes)
|
||||
}
|
||||
|
||||
/// Set cgroup memory.high threshold.
|
||||
pub fn set_memory_high_bytes(&self, bytes: u64) -> anyhow::Result<()> {
|
||||
self.set_memory_high_internal(MaxValue::Value(u64::min(bytes, i64::MAX as u64) as i64))
|
||||
}
|
||||
|
||||
/// Set the cgroup's memory.high to 'max', disabling it.
|
||||
pub fn unset_memory_high(&self) -> anyhow::Result<()> {
|
||||
self.set_memory_high_internal(MaxValue::Max)
|
||||
}
|
||||
|
||||
fn set_memory_high_internal(&self, value: MaxValue) -> anyhow::Result<()> {
|
||||
self.memory()
|
||||
.context("failed to get memory subsystem")?
|
||||
.set_mem(cgroups_rs::memory::SetMemory {
|
||||
low: None,
|
||||
high: Some(value),
|
||||
min: None,
|
||||
max: None,
|
||||
.find_map(|sub| match sub {
|
||||
Subsystem::Mem(c) => Some(c),
|
||||
_ => None,
|
||||
})
|
||||
.map_err(anyhow::Error::from)
|
||||
.ok_or_else(|| anyhow!("could not find memory subsystem"))
|
||||
}
|
||||
|
||||
/// Get memory.high threshold.
|
||||
pub fn get_memory_high_bytes(&self) -> anyhow::Result<u64> {
|
||||
let high = self
|
||||
.memory()
|
||||
.context("failed to get memory subsystem while getting memory statistics")?
|
||||
.get_mem()
|
||||
.map(|mem| mem.high)
|
||||
.context("failed to get memory statistics from subsystem")?;
|
||||
match high {
|
||||
Some(MaxValue::Max) => Ok(i64::MAX as u64),
|
||||
Some(MaxValue::Value(high)) => Ok(high as u64),
|
||||
None => anyhow::bail!("failed to read memory.high from memory subsystem"),
|
||||
/// Given a handle on the memory subsystem, returns the current memory information
|
||||
fn memory_usage(mem_controller: &MemController) -> MemoryStatus {
|
||||
let stat = mem_controller.memory_stat().stat;
|
||||
MemoryStatus {
|
||||
non_reclaimable: stat.active_anon + stat.inactive_anon,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function for `CgroupWatcher::watch`
|
||||
fn ring_buf_recent_values_iter<T>(
|
||||
buf: &[T],
|
||||
last_value_idx: usize,
|
||||
count: usize,
|
||||
) -> impl '_ + Iterator<Item = &T> {
|
||||
// Assertion carried over from `CgroupWatcher::watch`, to make the logic in this function
|
||||
// easier (we only have to add `buf.len()` once, rather than a dynamic number of times).
|
||||
assert!(count <= buf.len());
|
||||
|
||||
buf.iter()
|
||||
// 'cycle' because the values could wrap around
|
||||
.cycle()
|
||||
// with 'cycle', this skip is more like 'offset', and functionally this is
|
||||
// offsettting by 'last_value_idx - count (mod buf.len())', but we have to be
|
||||
// careful to avoid underflow, so we pre-add buf.len().
|
||||
// The '+ 1' is because `last_value_idx` is inclusive, rather than exclusive.
|
||||
.skip((buf.len() + last_value_idx + 1 - count) % buf.len())
|
||||
.take(count)
|
||||
}
|
||||
|
||||
/// Summary of recent memory usage
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct MemoryHistory {
|
||||
/// Rolling average of non-reclaimable memory usage samples over the last `history_period`
|
||||
pub avg_non_reclaimable: u64,
|
||||
|
||||
/// The number of samples used to construct this summary
|
||||
pub samples_count: usize,
|
||||
/// Total timespan between the first and last sample used for this summary
|
||||
pub samples_span: Duration,
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
pub struct MemoryStatus {
|
||||
non_reclaimable: u64,
|
||||
}
|
||||
|
||||
impl MemoryStatus {
|
||||
fn zeroed() -> Self {
|
||||
MemoryStatus { non_reclaimable: 0 }
|
||||
}
|
||||
|
||||
fn debug_slice(slice: &[Self]) -> impl '_ + Debug {
|
||||
struct DS<'a>(&'a [MemoryStatus]);
|
||||
|
||||
impl<'a> Debug for DS<'a> {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
f.debug_struct("[MemoryStatus]")
|
||||
.field(
|
||||
"non_reclaimable[..]",
|
||||
&Fields(self.0, |stat: &MemoryStatus| {
|
||||
BytesToGB(stat.non_reclaimable)
|
||||
}),
|
||||
)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
struct Fields<'a, F>(&'a [MemoryStatus], F);
|
||||
|
||||
impl<'a, F: Fn(&MemoryStatus) -> T, T: Debug> Debug for Fields<'a, F> {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
f.debug_list().entries(self.0.iter().map(&self.1)).finish()
|
||||
}
|
||||
}
|
||||
|
||||
struct BytesToGB(u64);
|
||||
|
||||
impl Debug for BytesToGB {
|
||||
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
|
||||
f.write_fmt(format_args!(
|
||||
"{:.3}Gi",
|
||||
self.0 as f64 / (1_u64 << 30) as f64
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
DS(slice)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
#[test]
|
||||
fn ring_buf_iter() {
|
||||
let buf = vec![0_i32, 1, 2, 3, 4, 5, 6, 7, 8, 9];
|
||||
|
||||
let values = |offset, count| {
|
||||
super::ring_buf_recent_values_iter(&buf, offset, count)
|
||||
.copied()
|
||||
.collect::<Vec<i32>>()
|
||||
};
|
||||
|
||||
// Boundary conditions: start, end, and entire thing:
|
||||
assert_eq!(values(0, 1), [0]);
|
||||
assert_eq!(values(3, 4), [0, 1, 2, 3]);
|
||||
assert_eq!(values(9, 4), [6, 7, 8, 9]);
|
||||
assert_eq!(values(9, 10), [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
|
||||
|
||||
// "normal" operation: no wraparound
|
||||
assert_eq!(values(7, 4), [4, 5, 6, 7]);
|
||||
|
||||
// wraparound:
|
||||
assert_eq!(values(0, 4), [7, 8, 9, 0]);
|
||||
assert_eq!(values(1, 4), [8, 9, 0, 1]);
|
||||
assert_eq!(values(2, 4), [9, 0, 1, 2]);
|
||||
assert_eq!(values(2, 10), [3, 4, 5, 6, 7, 8, 9, 0, 1, 2]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,12 +12,10 @@ use futures::{
|
||||
stream::{SplitSink, SplitStream},
|
||||
SinkExt, StreamExt,
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::info;
|
||||
|
||||
use crate::cgroup::Sequenced;
|
||||
use crate::protocol::{
|
||||
OutboundMsg, ProtocolRange, ProtocolResponse, ProtocolVersion, Resources, PROTOCOL_MAX_VERSION,
|
||||
OutboundMsg, ProtocolRange, ProtocolResponse, ProtocolVersion, PROTOCOL_MAX_VERSION,
|
||||
PROTOCOL_MIN_VERSION,
|
||||
};
|
||||
|
||||
@@ -36,13 +34,6 @@ pub struct Dispatcher {
|
||||
/// We send messages to the agent through `sink`
|
||||
sink: SplitSink<WebSocket, Message>,
|
||||
|
||||
/// Used to notify the cgroup when we are upscaled.
|
||||
pub(crate) notify_upscale_events: mpsc::Sender<Sequenced<Resources>>,
|
||||
|
||||
/// When the cgroup requests upscale it will send on this channel. In response
|
||||
/// we send an `UpscaleRequst` to the agent.
|
||||
pub(crate) request_upscale_events: mpsc::Receiver<()>,
|
||||
|
||||
/// The protocol version we have agreed to use with the agent. This is negotiated
|
||||
/// during the creation of the dispatcher, and should be the highest shared protocol
|
||||
/// version.
|
||||
@@ -61,11 +52,7 @@ impl Dispatcher {
|
||||
/// 1. Wait for the agent to sent the range of protocols it supports.
|
||||
/// 2. Send a protocol version that works for us as well, or an error if there
|
||||
/// is no compatible version.
|
||||
pub async fn new(
|
||||
stream: WebSocket,
|
||||
notify_upscale_events: mpsc::Sender<Sequenced<Resources>>,
|
||||
request_upscale_events: mpsc::Receiver<()>,
|
||||
) -> anyhow::Result<Self> {
|
||||
pub async fn new(stream: WebSocket) -> anyhow::Result<Self> {
|
||||
let (mut sink, mut source) = stream.split();
|
||||
|
||||
// Figure out the highest protocol version we both support
|
||||
@@ -119,22 +106,10 @@ impl Dispatcher {
|
||||
Ok(Self {
|
||||
sink,
|
||||
source,
|
||||
notify_upscale_events,
|
||||
request_upscale_events,
|
||||
proto_version: highest_shared_version,
|
||||
})
|
||||
}
|
||||
|
||||
/// Notify the cgroup manager that we have received upscale and wait for
|
||||
/// the acknowledgement.
|
||||
#[tracing::instrument(skip_all, fields(?resources))]
|
||||
pub async fn notify_upscale(&self, resources: Sequenced<Resources>) -> anyhow::Result<()> {
|
||||
self.notify_upscale_events
|
||||
.send(resources)
|
||||
.await
|
||||
.context("failed to send resources and oneshot sender across channel")
|
||||
}
|
||||
|
||||
/// Send a message to the agent.
|
||||
///
|
||||
/// Although this function is small, it has one major benefit: it is the only
|
||||
|
||||
@@ -21,11 +21,6 @@ pub struct FileCacheState {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FileCacheConfig {
|
||||
/// Whether the file cache is *actually* stored in memory (e.g. by writing to
|
||||
/// a tmpfs or shmem file). If true, the size of the file cache will be counted against the
|
||||
/// memory available for the cgroup.
|
||||
pub(crate) in_memory: bool,
|
||||
|
||||
/// The size of the file cache, in terms of the size of the resource it consumes
|
||||
/// (currently: only memory)
|
||||
///
|
||||
@@ -59,22 +54,9 @@ pub struct FileCacheConfig {
|
||||
spread_factor: f64,
|
||||
}
|
||||
|
||||
impl FileCacheConfig {
|
||||
pub fn default_in_memory() -> Self {
|
||||
impl Default for FileCacheConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
in_memory: true,
|
||||
// 75 %
|
||||
resource_multiplier: 0.75,
|
||||
// 640 MiB; (512 + 128)
|
||||
min_remaining_after_cache: NonZeroU64::new(640 * MiB).unwrap(),
|
||||
// ensure any increase in file cache size is split 90-10 with 10% to other memory
|
||||
spread_factor: 0.1,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn default_on_disk() -> Self {
|
||||
Self {
|
||||
in_memory: false,
|
||||
resource_multiplier: 0.75,
|
||||
// 256 MiB - lower than when in memory because overcommitting is safe; if we don't have
|
||||
// memory, the kernel will just evict from its page cache, rather than e.g. killing
|
||||
@@ -83,7 +65,9 @@ impl FileCacheConfig {
|
||||
spread_factor: 0.1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FileCacheConfig {
|
||||
/// Make sure fields of the config are consistent.
|
||||
pub fn validate(&self) -> anyhow::Result<()> {
|
||||
// Single field validity
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#![deny(unsafe_code)]
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
#![cfg(target_os = "linux")]
|
||||
|
||||
use anyhow::Context;
|
||||
@@ -39,16 +41,6 @@ pub struct Args {
|
||||
#[arg(short, long)]
|
||||
pub pgconnstr: Option<String>,
|
||||
|
||||
/// Flag to signal that the Postgres file cache is on disk (i.e. not in memory aside from the
|
||||
/// kernel's page cache), and therefore should not count against available memory.
|
||||
//
|
||||
// NB: Ideally this flag would directly refer to whether the file cache is in memory (rather
|
||||
// than a roundabout way, via whether it's on disk), but in order to be backwards compatible
|
||||
// during the switch away from an in-memory file cache, we had to default to the previous
|
||||
// behavior.
|
||||
#[arg(long)]
|
||||
pub file_cache_on_disk: bool,
|
||||
|
||||
/// The address we should listen on for connection requests. For the
|
||||
/// agent, this is 0.0.0.0:10301. For the informant, this is 127.0.0.1:10369.
|
||||
#[arg(short, long)]
|
||||
|
||||
@@ -5,18 +5,16 @@
|
||||
//! all functionality.
|
||||
|
||||
use std::fmt::Debug;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use anyhow::{bail, Context};
|
||||
use axum::extract::ws::{Message, WebSocket};
|
||||
use futures::StreamExt;
|
||||
use tokio::sync::broadcast;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::{broadcast, watch};
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use crate::cgroup::{CgroupWatcher, Sequenced};
|
||||
use crate::cgroup::{self, CgroupWatcher};
|
||||
use crate::dispatcher::Dispatcher;
|
||||
use crate::filecache::{FileCacheConfig, FileCacheState};
|
||||
use crate::protocol::{InboundMsg, InboundMsgKind, OutboundMsg, OutboundMsgKind, Resources};
|
||||
@@ -28,7 +26,7 @@ use crate::{bytes_to_mebibytes, get_total_system_memory, spawn_with_cancel, Args
|
||||
pub struct Runner {
|
||||
config: Config,
|
||||
filecache: Option<FileCacheState>,
|
||||
cgroup: Option<Arc<CgroupWatcher>>,
|
||||
cgroup: Option<CgroupState>,
|
||||
dispatcher: Dispatcher,
|
||||
|
||||
/// We "mint" new message ids by incrementing this counter and taking the value.
|
||||
@@ -45,6 +43,14 @@ pub struct Runner {
|
||||
kill: broadcast::Receiver<()>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CgroupState {
|
||||
watcher: watch::Receiver<(Instant, cgroup::MemoryHistory)>,
|
||||
/// If [`cgroup::MemoryHistory::avg_non_reclaimable`] exceeds `threshold`, we send upscale
|
||||
/// requests.
|
||||
threshold: u64,
|
||||
}
|
||||
|
||||
/// Configuration for a `Runner`
|
||||
#[derive(Debug)]
|
||||
pub struct Config {
|
||||
@@ -62,16 +68,56 @@ pub struct Config {
|
||||
/// upscale resource amounts (because we might not *actually* have been upscaled yet). This field
|
||||
/// should be removed once we have a better solution there.
|
||||
sys_buffer_bytes: u64,
|
||||
|
||||
/// Minimum fraction of total system memory reserved *before* the the cgroup threshold; in
|
||||
/// other words, providing a ceiling for the highest value of the threshold by enforcing that
|
||||
/// there's at least `cgroup_min_overhead_fraction` of the total memory remaining beyond the
|
||||
/// threshold.
|
||||
///
|
||||
/// For example, a value of `0.1` means that 10% of total memory must remain after exceeding
|
||||
/// the threshold, so the value of the cgroup threshold would always be capped at 90% of total
|
||||
/// memory.
|
||||
///
|
||||
/// The default value of `0.15` means that we *guarantee* sending upscale requests if the
|
||||
/// cgroup is using more than 85% of total memory (even if we're *not* separately reserving
|
||||
/// memory for the file cache).
|
||||
cgroup_min_overhead_fraction: f64,
|
||||
|
||||
cgroup_downscale_threshold_buffer_bytes: u64,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
sys_buffer_bytes: 100 * MiB,
|
||||
cgroup_min_overhead_fraction: 0.15,
|
||||
cgroup_downscale_threshold_buffer_bytes: 100 * MiB,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
fn cgroup_threshold(&self, total_mem: u64, file_cache_disk_size: u64) -> u64 {
|
||||
// If the file cache is in tmpfs, then it will count towards shmem usage of the cgroup,
|
||||
// and thus be non-reclaimable, so we should allow for additional memory usage.
|
||||
//
|
||||
// If the file cache sits on disk, our desired stable system state is for it to be fully
|
||||
// page cached (its contents should only be paged to/from disk in situations where we can't
|
||||
// upscale fast enough). Page-cached memory is reclaimable, so we need to lower the
|
||||
// threshold for non-reclaimable memory so we scale up *before* the kernel starts paging
|
||||
// out the file cache.
|
||||
let memory_remaining_for_cgroup = total_mem.saturating_sub(file_cache_disk_size);
|
||||
|
||||
// Even if we're not separately making room for the file cache (if it's in tmpfs), we still
|
||||
// want our threshold to be met gracefully instead of letting postgres get OOM-killed.
|
||||
// So we guarantee that there's at least `cgroup_min_overhead_fraction` of total memory
|
||||
// remaining above the threshold.
|
||||
let max_threshold = (total_mem as f64 * (1.0 - self.cgroup_min_overhead_fraction)) as u64;
|
||||
|
||||
memory_remaining_for_cgroup.min(max_threshold)
|
||||
}
|
||||
}
|
||||
|
||||
impl Runner {
|
||||
/// Create a new monitor.
|
||||
#[tracing::instrument(skip_all, fields(?config, ?args))]
|
||||
@@ -87,12 +133,7 @@ impl Runner {
|
||||
"invalid monitor Config: sys_buffer_bytes cannot be 0"
|
||||
);
|
||||
|
||||
// *NOTE*: the dispatcher and cgroup manager talk through these channels
|
||||
// so make sure they each get the correct half, nothing is droppped, etc.
|
||||
let (notified_send, notified_recv) = mpsc::channel(1);
|
||||
let (requesting_send, requesting_recv) = mpsc::channel(1);
|
||||
|
||||
let dispatcher = Dispatcher::new(ws, notified_send, requesting_recv)
|
||||
let dispatcher = Dispatcher::new(ws)
|
||||
.await
|
||||
.context("error creating new dispatcher")?;
|
||||
|
||||
@@ -106,57 +147,18 @@ impl Runner {
|
||||
kill,
|
||||
};
|
||||
|
||||
// If we have both the cgroup and file cache integrations enabled, it's possible for
|
||||
// temporary failures to result in cgroup throttling (from memory.high), that in turn makes
|
||||
// it near-impossible to connect to the file cache (because it times out). Unfortunately,
|
||||
// we *do* still want to determine the file cache size before setting the cgroup's
|
||||
// memory.high, so it's not as simple as just swapping the order.
|
||||
//
|
||||
// Instead, the resolution here is that on vm-monitor startup (note: happens on each
|
||||
// connection from autoscaler-agent, possibly multiple times per compute_ctl lifecycle), we
|
||||
// temporarily unset memory.high, to allow any existing throttling to dissipate. It's a bit
|
||||
// of a hacky solution, but helps with reliability.
|
||||
if let Some(name) = &args.cgroup {
|
||||
// Best not to set up cgroup stuff more than once, so we'll initialize cgroup state
|
||||
// now, and then set limits later.
|
||||
info!("initializing cgroup");
|
||||
|
||||
let (cgroup, cgroup_event_stream) = CgroupWatcher::new(name.clone(), requesting_send)
|
||||
.context("failed to create cgroup manager")?;
|
||||
|
||||
info!("temporarily unsetting memory.high");
|
||||
|
||||
// Temporarily un-set cgroup memory.high; see above.
|
||||
cgroup
|
||||
.unset_memory_high()
|
||||
.context("failed to unset memory.high")?;
|
||||
|
||||
let cgroup = Arc::new(cgroup);
|
||||
|
||||
let cgroup_clone = Arc::clone(&cgroup);
|
||||
spawn_with_cancel(
|
||||
token.clone(),
|
||||
|_| error!("cgroup watcher terminated"),
|
||||
async move { cgroup_clone.watch(notified_recv, cgroup_event_stream).await },
|
||||
);
|
||||
|
||||
state.cgroup = Some(cgroup);
|
||||
}
|
||||
|
||||
let mut file_cache_reserved_bytes = 0;
|
||||
let mem = get_total_system_memory();
|
||||
|
||||
let mut file_cache_disk_size = 0;
|
||||
|
||||
// We need to process file cache initialization before cgroup initialization, so that the memory
|
||||
// allocated to the file cache is appropriately taken into account when we decide the cgroup's
|
||||
// memory limits.
|
||||
if let Some(connstr) = &args.pgconnstr {
|
||||
info!("initializing file cache");
|
||||
let config = match args.file_cache_on_disk {
|
||||
true => FileCacheConfig::default_on_disk(),
|
||||
false => FileCacheConfig::default_in_memory(),
|
||||
};
|
||||
let config = FileCacheConfig::default();
|
||||
|
||||
let mut file_cache = FileCacheState::new(connstr, config, token)
|
||||
let mut file_cache = FileCacheState::new(connstr, config, token.clone())
|
||||
.await
|
||||
.context("failed to create file cache")?;
|
||||
|
||||
@@ -181,23 +183,37 @@ impl Runner {
|
||||
if actual_size != new_size {
|
||||
info!("file cache size actually got set to {actual_size}")
|
||||
}
|
||||
// Mark the resources given to the file cache as reserved, but only if it's in memory.
|
||||
if !args.file_cache_on_disk {
|
||||
file_cache_reserved_bytes = actual_size;
|
||||
}
|
||||
|
||||
file_cache_disk_size = actual_size;
|
||||
state.filecache = Some(file_cache);
|
||||
}
|
||||
|
||||
if let Some(cgroup) = &state.cgroup {
|
||||
let available = mem - file_cache_reserved_bytes;
|
||||
let value = cgroup.config.calculate_memory_high_value(available);
|
||||
if let Some(name) = &args.cgroup {
|
||||
// Best not to set up cgroup stuff more than once, so we'll initialize cgroup state
|
||||
// now, and then set limits later.
|
||||
info!("initializing cgroup");
|
||||
|
||||
info!(value, "setting memory.high");
|
||||
let cgroup =
|
||||
CgroupWatcher::new(name.clone()).context("failed to create cgroup manager")?;
|
||||
|
||||
cgroup
|
||||
.set_memory_high_bytes(value)
|
||||
.context("failed to set cgroup memory.high")?;
|
||||
let init_value = cgroup::MemoryHistory {
|
||||
avg_non_reclaimable: 0,
|
||||
samples_count: 0,
|
||||
samples_span: Duration::ZERO,
|
||||
};
|
||||
let (hist_tx, hist_rx) = watch::channel((Instant::now(), init_value));
|
||||
|
||||
spawn_with_cancel(token, |_| error!("cgroup watcher terminated"), async move {
|
||||
cgroup.watch(hist_tx).await
|
||||
});
|
||||
|
||||
let threshold = state.config.cgroup_threshold(mem, file_cache_disk_size);
|
||||
info!(threshold, "set initial cgroup threshold",);
|
||||
|
||||
state.cgroup = Some(CgroupState {
|
||||
watcher: hist_rx,
|
||||
threshold,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(state)
|
||||
@@ -217,28 +233,45 @@ impl Runner {
|
||||
|
||||
let requested_mem = target.mem;
|
||||
let usable_system_memory = requested_mem.saturating_sub(self.config.sys_buffer_bytes);
|
||||
let expected_file_cache_mem_usage = self
|
||||
let expected_file_cache_size = self
|
||||
.filecache
|
||||
.as_ref()
|
||||
.map(|file_cache| file_cache.config.calculate_cache_size(usable_system_memory))
|
||||
.unwrap_or(0);
|
||||
let mut new_cgroup_mem_high = 0;
|
||||
if let Some(cgroup) = &self.cgroup {
|
||||
new_cgroup_mem_high = cgroup
|
||||
let (last_time, last_history) = *cgroup.watcher.borrow();
|
||||
|
||||
// NB: The ordering of these conditions is intentional. During startup, we should deny
|
||||
// downscaling until we have enough information to determine that it's safe to do so
|
||||
// (i.e. enough samples have come in). But if it's been a while and we *still* haven't
|
||||
// received any information, we should *fail* instead of just denying downscaling.
|
||||
//
|
||||
// `last_time` is set to `Instant::now()` on startup, so checking `last_time.elapsed()`
|
||||
// serves double-duty: it trips if we haven't received *any* metrics for long enough,
|
||||
// OR if we haven't received metrics *recently enough*.
|
||||
//
|
||||
// TODO: make the duration here configurable.
|
||||
if last_time.elapsed() > Duration::from_secs(5) {
|
||||
bail!("haven't gotten cgroup memory stats recently enough to determine downscaling information");
|
||||
} else if last_history.samples_count <= 1 {
|
||||
let status = "haven't received enough cgroup memory stats yet";
|
||||
info!(status, "discontinuing downscale");
|
||||
return Ok((false, status.to_owned()));
|
||||
}
|
||||
|
||||
let new_threshold = self
|
||||
.config
|
||||
.calculate_memory_high_value(usable_system_memory - expected_file_cache_mem_usage);
|
||||
.cgroup_threshold(usable_system_memory, expected_file_cache_size);
|
||||
|
||||
let current = cgroup
|
||||
.current_memory_usage()
|
||||
.context("failed to fetch cgroup memory")?;
|
||||
let current = last_history.avg_non_reclaimable;
|
||||
|
||||
if new_cgroup_mem_high < current + cgroup.config.memory_high_buffer_bytes {
|
||||
if new_threshold < current + self.config.cgroup_downscale_threshold_buffer_bytes {
|
||||
let status = format!(
|
||||
"{}: {} MiB (new high) < {} (current usage) + {} (buffer)",
|
||||
"calculated memory.high too low",
|
||||
bytes_to_mebibytes(new_cgroup_mem_high),
|
||||
"{}: {} MiB (new threshold) < {} (current usage) + {} (downscale buffer)",
|
||||
"calculated memory threshold too low",
|
||||
bytes_to_mebibytes(new_threshold),
|
||||
bytes_to_mebibytes(current),
|
||||
bytes_to_mebibytes(cgroup.config.memory_high_buffer_bytes)
|
||||
bytes_to_mebibytes(self.config.cgroup_downscale_threshold_buffer_bytes)
|
||||
);
|
||||
|
||||
info!(status, "discontinuing downscale");
|
||||
@@ -249,42 +282,33 @@ impl Runner {
|
||||
|
||||
// The downscaling has been approved. Downscale the file cache, then the cgroup.
|
||||
let mut status = vec![];
|
||||
let mut file_cache_mem_usage = 0;
|
||||
let mut file_cache_disk_size = 0;
|
||||
if let Some(file_cache) = &mut self.filecache {
|
||||
let actual_usage = file_cache
|
||||
.set_file_cache_size(expected_file_cache_mem_usage)
|
||||
.set_file_cache_size(expected_file_cache_size)
|
||||
.await
|
||||
.context("failed to set file cache size")?;
|
||||
if file_cache.config.in_memory {
|
||||
file_cache_mem_usage = actual_usage;
|
||||
}
|
||||
file_cache_disk_size = actual_usage;
|
||||
let message = format!(
|
||||
"set file cache size to {} MiB (in memory = {})",
|
||||
"set file cache size to {} MiB",
|
||||
bytes_to_mebibytes(actual_usage),
|
||||
file_cache.config.in_memory,
|
||||
);
|
||||
info!("downscale: {message}");
|
||||
status.push(message);
|
||||
}
|
||||
|
||||
if let Some(cgroup) = &self.cgroup {
|
||||
let available_memory = usable_system_memory - file_cache_mem_usage;
|
||||
|
||||
if file_cache_mem_usage != expected_file_cache_mem_usage {
|
||||
new_cgroup_mem_high = cgroup.config.calculate_memory_high_value(available_memory);
|
||||
}
|
||||
|
||||
// new_cgroup_mem_high is initialized to 0 but it is guaranteed to not be here
|
||||
// since it is properly initialized in the previous cgroup if let block
|
||||
cgroup
|
||||
.set_memory_high_bytes(new_cgroup_mem_high)
|
||||
.context("failed to set cgroup memory.high")?;
|
||||
if let Some(cgroup) = &mut self.cgroup {
|
||||
let new_threshold = self
|
||||
.config
|
||||
.cgroup_threshold(usable_system_memory, file_cache_disk_size);
|
||||
|
||||
let message = format!(
|
||||
"set cgroup memory.high to {} MiB, of new max {} MiB",
|
||||
bytes_to_mebibytes(new_cgroup_mem_high),
|
||||
bytes_to_mebibytes(available_memory)
|
||||
"set cgroup memory threshold from {} MiB to {} MiB, of new total {} MiB",
|
||||
bytes_to_mebibytes(cgroup.threshold),
|
||||
bytes_to_mebibytes(new_threshold),
|
||||
bytes_to_mebibytes(usable_system_memory)
|
||||
);
|
||||
cgroup.threshold = new_threshold;
|
||||
info!("downscale: {message}");
|
||||
status.push(message);
|
||||
}
|
||||
@@ -305,8 +329,7 @@ impl Runner {
|
||||
let new_mem = resources.mem;
|
||||
let usable_system_memory = new_mem.saturating_sub(self.config.sys_buffer_bytes);
|
||||
|
||||
// Get the file cache's expected contribution to the memory usage
|
||||
let mut file_cache_mem_usage = 0;
|
||||
let mut file_cache_disk_size = 0;
|
||||
if let Some(file_cache) = &mut self.filecache {
|
||||
let expected_usage = file_cache.config.calculate_cache_size(usable_system_memory);
|
||||
info!(
|
||||
@@ -319,9 +342,7 @@ impl Runner {
|
||||
.set_file_cache_size(expected_usage)
|
||||
.await
|
||||
.context("failed to set file cache size")?;
|
||||
if file_cache.config.in_memory {
|
||||
file_cache_mem_usage = actual_usage;
|
||||
}
|
||||
file_cache_disk_size = actual_usage;
|
||||
|
||||
if actual_usage != expected_usage {
|
||||
warn!(
|
||||
@@ -332,18 +353,18 @@ impl Runner {
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(cgroup) = &self.cgroup {
|
||||
let available_memory = usable_system_memory - file_cache_mem_usage;
|
||||
let new_cgroup_mem_high = cgroup.config.calculate_memory_high_value(available_memory);
|
||||
if let Some(cgroup) = &mut self.cgroup {
|
||||
let new_threshold = self
|
||||
.config
|
||||
.cgroup_threshold(usable_system_memory, file_cache_disk_size);
|
||||
|
||||
info!(
|
||||
target = bytes_to_mebibytes(new_cgroup_mem_high),
|
||||
total = bytes_to_mebibytes(new_mem),
|
||||
name = cgroup.path(),
|
||||
"updating cgroup memory.high",
|
||||
"set cgroup memory threshold from {} MiB to {} MiB of new total {} MiB",
|
||||
bytes_to_mebibytes(cgroup.threshold),
|
||||
bytes_to_mebibytes(new_threshold),
|
||||
bytes_to_mebibytes(usable_system_memory)
|
||||
);
|
||||
cgroup
|
||||
.set_memory_high_bytes(new_cgroup_mem_high)
|
||||
.context("failed to set cgroup memory.high")?;
|
||||
cgroup.threshold = new_threshold;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -361,10 +382,6 @@ impl Runner {
|
||||
self.handle_upscale(granted)
|
||||
.await
|
||||
.context("failed to handle upscale")?;
|
||||
self.dispatcher
|
||||
.notify_upscale(Sequenced::new(granted))
|
||||
.await
|
||||
.context("failed to notify notify cgroup of upscale")?;
|
||||
Ok(Some(OutboundMsg::new(
|
||||
OutboundMsgKind::UpscaleConfirmation {},
|
||||
id,
|
||||
@@ -408,33 +425,53 @@ impl Runner {
|
||||
Err(e) => bail!("failed to receive kill signal: {e}")
|
||||
}
|
||||
}
|
||||
// we need to propagate an upscale request
|
||||
request = self.dispatcher.request_upscale_events.recv(), if self.cgroup.is_some() => {
|
||||
if request.is_none() {
|
||||
bail!("failed to listen for upscale event from cgroup")
|
||||
|
||||
// New memory stats from the cgroup, *may* need to request upscaling, if we've
|
||||
// exceeded the threshold
|
||||
result = self.cgroup.as_mut().unwrap().watcher.changed(), if self.cgroup.is_some() => {
|
||||
result.context("failed to receive from cgroup memory stats watcher")?;
|
||||
|
||||
let cgroup = self.cgroup.as_ref().unwrap();
|
||||
|
||||
let (_time, cgroup_mem_stat) = *cgroup.watcher.borrow();
|
||||
|
||||
// If we haven't exceeded the threshold, then we're all ok
|
||||
if cgroup_mem_stat.avg_non_reclaimable < cgroup.threshold {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If it's been less than 1 second since the last time we requested upscaling,
|
||||
// ignore the event, to avoid spamming the agent (otherwise, this can happen
|
||||
// ~1k times per second).
|
||||
// Otherwise, we generally want upscaling. But, if it's been less than 1 second
|
||||
// since the last time we requested upscaling, ignore the event, to avoid
|
||||
// spamming the agent.
|
||||
if let Some(t) = self.last_upscale_request_at {
|
||||
let elapsed = t.elapsed();
|
||||
if elapsed < Duration::from_secs(1) {
|
||||
info!(elapsed_millis = elapsed.as_millis(), "cgroup asked for upscale but too soon to forward the request, ignoring");
|
||||
info!(
|
||||
elapsed_millis = elapsed.as_millis(),
|
||||
avg_non_reclaimable = bytes_to_mebibytes(cgroup_mem_stat.avg_non_reclaimable),
|
||||
threshold = bytes_to_mebibytes(cgroup.threshold),
|
||||
"cgroup memory stats are high enough to upscale but too soon to forward the request, ignoring",
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
self.last_upscale_request_at = Some(Instant::now());
|
||||
|
||||
info!("cgroup asking for upscale; forwarding request");
|
||||
info!(
|
||||
avg_non_reclaimable = bytes_to_mebibytes(cgroup_mem_stat.avg_non_reclaimable),
|
||||
threshold = bytes_to_mebibytes(cgroup.threshold),
|
||||
"cgroup memory stats are high enough to upscale, requesting upscale",
|
||||
);
|
||||
|
||||
self.counter += 2; // Increment, preserving parity (i.e. keep the
|
||||
// counter odd). See the field comment for more.
|
||||
self.dispatcher
|
||||
.send(OutboundMsg::new(OutboundMsgKind::UpscaleRequest {}, self.counter))
|
||||
.await
|
||||
.context("failed to send message")?;
|
||||
}
|
||||
},
|
||||
|
||||
// there is a message from the agent
|
||||
msg = self.dispatcher.source.next() => {
|
||||
if let Some(msg) = msg {
|
||||
@@ -462,11 +499,14 @@ impl Runner {
|
||||
Ok(Some(out)) => out,
|
||||
Ok(None) => continue,
|
||||
Err(e) => {
|
||||
let error = e.to_string();
|
||||
warn!(?error, "error handling message");
|
||||
// use {:#} for our logging because the display impl only
|
||||
// gives the outermost cause, and the debug impl
|
||||
// pretty-prints the error, whereas {:#} contains all the
|
||||
// causes, but is compact (no newlines).
|
||||
warn!(error = format!("{e:#}"), "error handling message");
|
||||
OutboundMsg::new(
|
||||
OutboundMsgKind::InternalError {
|
||||
error
|
||||
error: e.to_string(),
|
||||
},
|
||||
message.id
|
||||
)
|
||||
|
||||
16
libs/walproposer/Cargo.toml
Normal file
16
libs/walproposer/Cargo.toml
Normal file
@@ -0,0 +1,16 @@
|
||||
[package]
|
||||
name = "walproposer"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
utils.workspace = true
|
||||
postgres_ffi.workspace = true
|
||||
|
||||
workspace_hack.workspace = true
|
||||
|
||||
[build-dependencies]
|
||||
anyhow.workspace = true
|
||||
bindgen.workspace = true
|
||||
1
libs/walproposer/bindgen_deps.h
Normal file
1
libs/walproposer/bindgen_deps.h
Normal file
@@ -0,0 +1 @@
|
||||
#include "walproposer.h"
|
||||
113
libs/walproposer/build.rs
Normal file
113
libs/walproposer/build.rs
Normal file
@@ -0,0 +1,113 @@
|
||||
use std::{env, path::PathBuf, process::Command};
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use bindgen::CargoCallbacks;
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
// Tell cargo to invalidate the built crate whenever the wrapper changes
|
||||
println!("cargo:rerun-if-changed=bindgen_deps.h");
|
||||
|
||||
// Finding the location of built libraries and Postgres C headers:
|
||||
// - if POSTGRES_INSTALL_DIR is set look into it, otherwise look into `<project_root>/pg_install`
|
||||
// - if there's a `bin/pg_config` file use it for getting include server, otherwise use `<project_root>/pg_install/{PG_MAJORVERSION}/include/postgresql/server`
|
||||
let pg_install_dir = if let Some(postgres_install_dir) = env::var_os("POSTGRES_INSTALL_DIR") {
|
||||
postgres_install_dir.into()
|
||||
} else {
|
||||
PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../pg_install")
|
||||
};
|
||||
|
||||
let pg_install_abs = std::fs::canonicalize(pg_install_dir)?;
|
||||
let walproposer_lib_dir = pg_install_abs.join("build/walproposer-lib");
|
||||
let walproposer_lib_search_str = walproposer_lib_dir
|
||||
.to_str()
|
||||
.ok_or(anyhow!("Bad non-UTF path"))?;
|
||||
|
||||
let pgxn_neon = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../pgxn/neon");
|
||||
let pgxn_neon = std::fs::canonicalize(pgxn_neon)?;
|
||||
let pgxn_neon = pgxn_neon.to_str().ok_or(anyhow!("Bad non-UTF path"))?;
|
||||
|
||||
println!("cargo:rustc-link-lib=static=pgport");
|
||||
println!("cargo:rustc-link-lib=static=pgcommon");
|
||||
println!("cargo:rustc-link-lib=static=walproposer");
|
||||
println!("cargo:rustc-link-search={walproposer_lib_search_str}");
|
||||
|
||||
let pg_config_bin = pg_install_abs.join("v16").join("bin").join("pg_config");
|
||||
let inc_server_path: String = if pg_config_bin.exists() {
|
||||
let output = Command::new(pg_config_bin)
|
||||
.arg("--includedir-server")
|
||||
.output()
|
||||
.context("failed to execute `pg_config --includedir-server`")?;
|
||||
|
||||
if !output.status.success() {
|
||||
panic!("`pg_config --includedir-server` failed")
|
||||
}
|
||||
|
||||
String::from_utf8(output.stdout)
|
||||
.context("pg_config output is not UTF-8")?
|
||||
.trim_end()
|
||||
.into()
|
||||
} else {
|
||||
let server_path = pg_install_abs
|
||||
.join("v16")
|
||||
.join("include")
|
||||
.join("postgresql")
|
||||
.join("server")
|
||||
.into_os_string();
|
||||
server_path
|
||||
.into_string()
|
||||
.map_err(|s| anyhow!("Bad postgres server path {s:?}"))?
|
||||
};
|
||||
|
||||
// The bindgen::Builder is the main entry point
|
||||
// to bindgen, and lets you build up options for
|
||||
// the resulting bindings.
|
||||
let bindings = bindgen::Builder::default()
|
||||
// The input header we would like to generate
|
||||
// bindings for.
|
||||
.header("bindgen_deps.h")
|
||||
// Tell cargo to invalidate the built crate whenever any of the
|
||||
// included header files changed.
|
||||
.parse_callbacks(Box::new(CargoCallbacks))
|
||||
.allowlist_type("WalProposer")
|
||||
.allowlist_type("WalProposerConfig")
|
||||
.allowlist_type("walproposer_api")
|
||||
.allowlist_function("WalProposerCreate")
|
||||
.allowlist_function("WalProposerStart")
|
||||
.allowlist_function("WalProposerBroadcast")
|
||||
.allowlist_function("WalProposerPoll")
|
||||
.allowlist_function("WalProposerFree")
|
||||
.allowlist_var("DEBUG5")
|
||||
.allowlist_var("DEBUG4")
|
||||
.allowlist_var("DEBUG3")
|
||||
.allowlist_var("DEBUG2")
|
||||
.allowlist_var("DEBUG1")
|
||||
.allowlist_var("LOG")
|
||||
.allowlist_var("INFO")
|
||||
.allowlist_var("NOTICE")
|
||||
.allowlist_var("WARNING")
|
||||
.allowlist_var("ERROR")
|
||||
.allowlist_var("FATAL")
|
||||
.allowlist_var("PANIC")
|
||||
.allowlist_var("WPEVENT")
|
||||
.allowlist_var("WL_LATCH_SET")
|
||||
.allowlist_var("WL_SOCKET_READABLE")
|
||||
.allowlist_var("WL_SOCKET_WRITEABLE")
|
||||
.allowlist_var("WL_TIMEOUT")
|
||||
.allowlist_var("WL_SOCKET_CLOSED")
|
||||
.allowlist_var("WL_SOCKET_MASK")
|
||||
.clang_arg("-DWALPROPOSER_LIB")
|
||||
.clang_arg(format!("-I{pgxn_neon}"))
|
||||
.clang_arg(format!("-I{inc_server_path}"))
|
||||
// Finish the builder and generate the bindings.
|
||||
.generate()
|
||||
// Unwrap the Result and panic on failure.
|
||||
.expect("Unable to generate bindings");
|
||||
|
||||
// Write the bindings to the $OUT_DIR/bindings.rs file.
|
||||
let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()).join("bindings.rs");
|
||||
bindings
|
||||
.write_to_file(out_path)
|
||||
.expect("Couldn't write bindings!");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
457
libs/walproposer/src/api_bindings.rs
Normal file
457
libs/walproposer/src/api_bindings.rs
Normal file
@@ -0,0 +1,457 @@
|
||||
#![allow(dead_code)]
|
||||
|
||||
use std::ffi::CStr;
|
||||
use std::ffi::CString;
|
||||
|
||||
use crate::bindings::uint32;
|
||||
use crate::bindings::walproposer_api;
|
||||
use crate::bindings::PGAsyncReadResult;
|
||||
use crate::bindings::PGAsyncWriteResult;
|
||||
use crate::bindings::Safekeeper;
|
||||
use crate::bindings::Size;
|
||||
use crate::bindings::StringInfoData;
|
||||
use crate::bindings::TimeLineID;
|
||||
use crate::bindings::TimestampTz;
|
||||
use crate::bindings::WalProposer;
|
||||
use crate::bindings::WalProposerConnStatusType;
|
||||
use crate::bindings::WalProposerConnectPollStatusType;
|
||||
use crate::bindings::WalProposerExecStatusType;
|
||||
use crate::bindings::WalproposerShmemState;
|
||||
use crate::bindings::XLogRecPtr;
|
||||
use crate::walproposer::ApiImpl;
|
||||
use crate::walproposer::WaitResult;
|
||||
|
||||
extern "C" fn get_shmem_state(wp: *mut WalProposer) -> *mut WalproposerShmemState {
|
||||
unsafe {
|
||||
let callback_data = (*(*wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).get_shmem_state()
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn start_streaming(wp: *mut WalProposer, startpos: XLogRecPtr) {
|
||||
unsafe {
|
||||
let callback_data = (*(*wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).start_streaming(startpos)
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn get_flush_rec_ptr(wp: *mut WalProposer) -> XLogRecPtr {
|
||||
unsafe {
|
||||
let callback_data = (*(*wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).get_flush_rec_ptr()
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn get_current_timestamp(wp: *mut WalProposer) -> TimestampTz {
|
||||
unsafe {
|
||||
let callback_data = (*(*wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).get_current_timestamp()
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn conn_error_message(sk: *mut Safekeeper) -> *mut ::std::os::raw::c_char {
|
||||
unsafe {
|
||||
let callback_data = (*(*(*sk).wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
let msg = (*api).conn_error_message(&mut (*sk));
|
||||
let msg = CString::new(msg).unwrap();
|
||||
// TODO: fix leaking error message
|
||||
msg.into_raw()
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn conn_status(sk: *mut Safekeeper) -> WalProposerConnStatusType {
|
||||
unsafe {
|
||||
let callback_data = (*(*(*sk).wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).conn_status(&mut (*sk))
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn conn_connect_start(sk: *mut Safekeeper) {
|
||||
unsafe {
|
||||
let callback_data = (*(*(*sk).wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).conn_connect_start(&mut (*sk))
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn conn_connect_poll(sk: *mut Safekeeper) -> WalProposerConnectPollStatusType {
|
||||
unsafe {
|
||||
let callback_data = (*(*(*sk).wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).conn_connect_poll(&mut (*sk))
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn conn_send_query(sk: *mut Safekeeper, query: *mut ::std::os::raw::c_char) -> bool {
|
||||
let query = unsafe { CStr::from_ptr(query) };
|
||||
let query = query.to_str().unwrap();
|
||||
|
||||
unsafe {
|
||||
let callback_data = (*(*(*sk).wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).conn_send_query(&mut (*sk), query)
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn conn_get_query_result(sk: *mut Safekeeper) -> WalProposerExecStatusType {
|
||||
unsafe {
|
||||
let callback_data = (*(*(*sk).wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).conn_get_query_result(&mut (*sk))
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn conn_flush(sk: *mut Safekeeper) -> ::std::os::raw::c_int {
|
||||
unsafe {
|
||||
let callback_data = (*(*(*sk).wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).conn_flush(&mut (*sk))
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn conn_finish(sk: *mut Safekeeper) {
|
||||
unsafe {
|
||||
let callback_data = (*(*(*sk).wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).conn_finish(&mut (*sk))
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn conn_async_read(
|
||||
sk: *mut Safekeeper,
|
||||
buf: *mut *mut ::std::os::raw::c_char,
|
||||
amount: *mut ::std::os::raw::c_int,
|
||||
) -> PGAsyncReadResult {
|
||||
unsafe {
|
||||
let callback_data = (*(*(*sk).wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
let (res, result) = (*api).conn_async_read(&mut (*sk));
|
||||
|
||||
// This function has guarantee that returned buf will be valid until
|
||||
// the next call. So we can store a Vec in each Safekeeper and reuse
|
||||
// it on the next call.
|
||||
let mut inbuf = take_vec_u8(&mut (*sk).inbuf).unwrap_or_default();
|
||||
|
||||
inbuf.clear();
|
||||
inbuf.extend_from_slice(res);
|
||||
|
||||
// Put a Vec back to sk->inbuf and return data ptr.
|
||||
*buf = store_vec_u8(&mut (*sk).inbuf, inbuf);
|
||||
*amount = res.len() as i32;
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn conn_async_write(
|
||||
sk: *mut Safekeeper,
|
||||
buf: *const ::std::os::raw::c_void,
|
||||
size: usize,
|
||||
) -> PGAsyncWriteResult {
|
||||
unsafe {
|
||||
let buf = std::slice::from_raw_parts(buf as *const u8, size);
|
||||
let callback_data = (*(*(*sk).wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).conn_async_write(&mut (*sk), buf)
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn conn_blocking_write(
|
||||
sk: *mut Safekeeper,
|
||||
buf: *const ::std::os::raw::c_void,
|
||||
size: usize,
|
||||
) -> bool {
|
||||
unsafe {
|
||||
let buf = std::slice::from_raw_parts(buf as *const u8, size);
|
||||
let callback_data = (*(*(*sk).wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).conn_blocking_write(&mut (*sk), buf)
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn recovery_download(
|
||||
sk: *mut Safekeeper,
|
||||
_timeline: TimeLineID,
|
||||
startpos: XLogRecPtr,
|
||||
endpos: XLogRecPtr,
|
||||
) -> bool {
|
||||
unsafe {
|
||||
let callback_data = (*(*(*sk).wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).recovery_download(&mut (*sk), startpos, endpos)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::unnecessary_cast)]
|
||||
extern "C" fn wal_read(
|
||||
sk: *mut Safekeeper,
|
||||
buf: *mut ::std::os::raw::c_char,
|
||||
startptr: XLogRecPtr,
|
||||
count: Size,
|
||||
) {
|
||||
unsafe {
|
||||
let buf = std::slice::from_raw_parts_mut(buf as *mut u8, count);
|
||||
let callback_data = (*(*(*sk).wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).wal_read(&mut (*sk), buf, startptr)
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn wal_reader_allocate(sk: *mut Safekeeper) {
|
||||
unsafe {
|
||||
let callback_data = (*(*(*sk).wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).wal_reader_allocate(&mut (*sk));
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn free_event_set(wp: *mut WalProposer) {
|
||||
unsafe {
|
||||
let callback_data = (*(*wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).free_event_set(&mut (*wp));
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn init_event_set(wp: *mut WalProposer) {
|
||||
unsafe {
|
||||
let callback_data = (*(*wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).init_event_set(&mut (*wp));
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn update_event_set(sk: *mut Safekeeper, events: uint32) {
|
||||
unsafe {
|
||||
let callback_data = (*(*(*sk).wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).update_event_set(&mut (*sk), events);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn add_safekeeper_event_set(sk: *mut Safekeeper, events: uint32) {
|
||||
unsafe {
|
||||
let callback_data = (*(*(*sk).wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).add_safekeeper_event_set(&mut (*sk), events);
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn wait_event_set(
|
||||
wp: *mut WalProposer,
|
||||
timeout: ::std::os::raw::c_long,
|
||||
event_sk: *mut *mut Safekeeper,
|
||||
events: *mut uint32,
|
||||
) -> ::std::os::raw::c_int {
|
||||
unsafe {
|
||||
let callback_data = (*(*wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
let result = (*api).wait_event_set(&mut (*wp), timeout);
|
||||
match result {
|
||||
WaitResult::Latch => {
|
||||
*event_sk = std::ptr::null_mut();
|
||||
*events = crate::bindings::WL_LATCH_SET;
|
||||
1
|
||||
}
|
||||
WaitResult::Timeout => {
|
||||
*event_sk = std::ptr::null_mut();
|
||||
*events = crate::bindings::WL_TIMEOUT;
|
||||
0
|
||||
}
|
||||
WaitResult::Network(sk, event_mask) => {
|
||||
*event_sk = sk;
|
||||
*events = event_mask;
|
||||
1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn strong_random(
|
||||
wp: *mut WalProposer,
|
||||
buf: *mut ::std::os::raw::c_void,
|
||||
len: usize,
|
||||
) -> bool {
|
||||
unsafe {
|
||||
let buf = std::slice::from_raw_parts_mut(buf as *mut u8, len);
|
||||
let callback_data = (*(*wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).strong_random(buf)
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn get_redo_start_lsn(wp: *mut WalProposer) -> XLogRecPtr {
|
||||
unsafe {
|
||||
let callback_data = (*(*wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).get_redo_start_lsn()
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn finish_sync_safekeepers(wp: *mut WalProposer, lsn: XLogRecPtr) {
|
||||
unsafe {
|
||||
let callback_data = (*(*wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).finish_sync_safekeepers(lsn)
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn process_safekeeper_feedback(wp: *mut WalProposer, commit_lsn: XLogRecPtr) {
|
||||
unsafe {
|
||||
let callback_data = (*(*wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).process_safekeeper_feedback(&mut (*wp), commit_lsn)
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn confirm_wal_streamed(wp: *mut WalProposer, lsn: XLogRecPtr) {
|
||||
unsafe {
|
||||
let callback_data = (*(*wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).confirm_wal_streamed(&mut (*wp), lsn)
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn log_internal(
|
||||
wp: *mut WalProposer,
|
||||
level: ::std::os::raw::c_int,
|
||||
line: *const ::std::os::raw::c_char,
|
||||
) {
|
||||
unsafe {
|
||||
let callback_data = (*(*wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
let line = CStr::from_ptr(line);
|
||||
let line = line.to_str().unwrap();
|
||||
(*api).log_internal(&mut (*wp), Level::from(level as u32), line)
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" fn after_election(wp: *mut WalProposer) {
|
||||
unsafe {
|
||||
let callback_data = (*(*wp).config).callback_data;
|
||||
let api = callback_data as *mut Box<dyn ApiImpl>;
|
||||
(*api).after_election(&mut (*wp))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Level {
|
||||
Debug5,
|
||||
Debug4,
|
||||
Debug3,
|
||||
Debug2,
|
||||
Debug1,
|
||||
Log,
|
||||
Info,
|
||||
Notice,
|
||||
Warning,
|
||||
Error,
|
||||
Fatal,
|
||||
Panic,
|
||||
WPEvent,
|
||||
}
|
||||
|
||||
impl Level {
|
||||
pub fn from(elevel: u32) -> Level {
|
||||
use crate::bindings::*;
|
||||
|
||||
match elevel {
|
||||
DEBUG5 => Level::Debug5,
|
||||
DEBUG4 => Level::Debug4,
|
||||
DEBUG3 => Level::Debug3,
|
||||
DEBUG2 => Level::Debug2,
|
||||
DEBUG1 => Level::Debug1,
|
||||
LOG => Level::Log,
|
||||
INFO => Level::Info,
|
||||
NOTICE => Level::Notice,
|
||||
WARNING => Level::Warning,
|
||||
ERROR => Level::Error,
|
||||
FATAL => Level::Fatal,
|
||||
PANIC => Level::Panic,
|
||||
WPEVENT => Level::WPEvent,
|
||||
_ => panic!("unknown log level {}", elevel),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn create_api() -> walproposer_api {
|
||||
walproposer_api {
|
||||
get_shmem_state: Some(get_shmem_state),
|
||||
start_streaming: Some(start_streaming),
|
||||
get_flush_rec_ptr: Some(get_flush_rec_ptr),
|
||||
get_current_timestamp: Some(get_current_timestamp),
|
||||
conn_error_message: Some(conn_error_message),
|
||||
conn_status: Some(conn_status),
|
||||
conn_connect_start: Some(conn_connect_start),
|
||||
conn_connect_poll: Some(conn_connect_poll),
|
||||
conn_send_query: Some(conn_send_query),
|
||||
conn_get_query_result: Some(conn_get_query_result),
|
||||
conn_flush: Some(conn_flush),
|
||||
conn_finish: Some(conn_finish),
|
||||
conn_async_read: Some(conn_async_read),
|
||||
conn_async_write: Some(conn_async_write),
|
||||
conn_blocking_write: Some(conn_blocking_write),
|
||||
recovery_download: Some(recovery_download),
|
||||
wal_read: Some(wal_read),
|
||||
wal_reader_allocate: Some(wal_reader_allocate),
|
||||
free_event_set: Some(free_event_set),
|
||||
init_event_set: Some(init_event_set),
|
||||
update_event_set: Some(update_event_set),
|
||||
add_safekeeper_event_set: Some(add_safekeeper_event_set),
|
||||
wait_event_set: Some(wait_event_set),
|
||||
strong_random: Some(strong_random),
|
||||
get_redo_start_lsn: Some(get_redo_start_lsn),
|
||||
finish_sync_safekeepers: Some(finish_sync_safekeepers),
|
||||
process_safekeeper_feedback: Some(process_safekeeper_feedback),
|
||||
confirm_wal_streamed: Some(confirm_wal_streamed),
|
||||
log_internal: Some(log_internal),
|
||||
after_election: Some(after_election),
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Level {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(f, "{:?}", self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Take ownership of `Vec<u8>` from StringInfoData.
|
||||
#[allow(clippy::unnecessary_cast)]
|
||||
pub(crate) fn take_vec_u8(pg: &mut StringInfoData) -> Option<Vec<u8>> {
|
||||
if pg.data.is_null() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let ptr = pg.data as *mut u8;
|
||||
let length = pg.len as usize;
|
||||
let capacity = pg.maxlen as usize;
|
||||
|
||||
pg.data = std::ptr::null_mut();
|
||||
pg.len = 0;
|
||||
pg.maxlen = 0;
|
||||
|
||||
unsafe { Some(Vec::from_raw_parts(ptr, length, capacity)) }
|
||||
}
|
||||
|
||||
/// Store `Vec<u8>` in StringInfoData.
|
||||
fn store_vec_u8(pg: &mut StringInfoData, vec: Vec<u8>) -> *mut ::std::os::raw::c_char {
|
||||
let ptr = vec.as_ptr() as *mut ::std::os::raw::c_char;
|
||||
let length = vec.len();
|
||||
let capacity = vec.capacity();
|
||||
|
||||
assert!(pg.data.is_null());
|
||||
|
||||
pg.data = ptr;
|
||||
pg.len = length as i32;
|
||||
pg.maxlen = capacity as i32;
|
||||
|
||||
std::mem::forget(vec);
|
||||
|
||||
ptr
|
||||
}
|
||||
14
libs/walproposer/src/lib.rs
Normal file
14
libs/walproposer/src/lib.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
pub mod bindings {
|
||||
#![allow(non_upper_case_globals)]
|
||||
#![allow(non_camel_case_types)]
|
||||
#![allow(non_snake_case)]
|
||||
// bindgen creates some unsafe code with no doc comments.
|
||||
#![allow(clippy::missing_safety_doc)]
|
||||
// noted at 1.63 that in many cases there's a u32 -> u32 transmutes in bindgen code.
|
||||
#![allow(clippy::useless_transmute)]
|
||||
|
||||
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
|
||||
}
|
||||
|
||||
pub mod api_bindings;
|
||||
pub mod walproposer;
|
||||
485
libs/walproposer/src/walproposer.rs
Normal file
485
libs/walproposer/src/walproposer.rs
Normal file
@@ -0,0 +1,485 @@
|
||||
use std::ffi::CString;
|
||||
|
||||
use postgres_ffi::WAL_SEGMENT_SIZE;
|
||||
use utils::id::TenantTimelineId;
|
||||
|
||||
use crate::{
|
||||
api_bindings::{create_api, take_vec_u8, Level},
|
||||
bindings::{
|
||||
Safekeeper, WalProposer, WalProposerConfig, WalProposerCreate, WalProposerFree,
|
||||
WalProposerStart,
|
||||
},
|
||||
};
|
||||
|
||||
/// Rust high-level wrapper for C walproposer API. Many methods are not required
|
||||
/// for simple cases, hence todo!() in default implementations.
|
||||
///
|
||||
/// Refer to `pgxn/neon/walproposer.h` for documentation.
|
||||
pub trait ApiImpl {
|
||||
fn get_shmem_state(&self) -> &mut crate::bindings::WalproposerShmemState {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn start_streaming(&self, _startpos: u64) {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn get_flush_rec_ptr(&self) -> u64 {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn get_current_timestamp(&self) -> i64 {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn conn_error_message(&self, _sk: &mut Safekeeper) -> String {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn conn_status(&self, _sk: &mut Safekeeper) -> crate::bindings::WalProposerConnStatusType {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn conn_connect_start(&self, _sk: &mut Safekeeper) {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn conn_connect_poll(
|
||||
&self,
|
||||
_sk: &mut Safekeeper,
|
||||
) -> crate::bindings::WalProposerConnectPollStatusType {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn conn_send_query(&self, _sk: &mut Safekeeper, _query: &str) -> bool {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn conn_get_query_result(
|
||||
&self,
|
||||
_sk: &mut Safekeeper,
|
||||
) -> crate::bindings::WalProposerExecStatusType {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn conn_flush(&self, _sk: &mut Safekeeper) -> i32 {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn conn_finish(&self, _sk: &mut Safekeeper) {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn conn_async_read(&self, _sk: &mut Safekeeper) -> (&[u8], crate::bindings::PGAsyncReadResult) {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn conn_async_write(
|
||||
&self,
|
||||
_sk: &mut Safekeeper,
|
||||
_buf: &[u8],
|
||||
) -> crate::bindings::PGAsyncWriteResult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn conn_blocking_write(&self, _sk: &mut Safekeeper, _buf: &[u8]) -> bool {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn recovery_download(&self, _sk: &mut Safekeeper, _startpos: u64, _endpos: u64) -> bool {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn wal_read(&self, _sk: &mut Safekeeper, _buf: &mut [u8], _startpos: u64) {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn wal_reader_allocate(&self, _sk: &mut Safekeeper) {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn free_event_set(&self, _wp: &mut WalProposer) {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn init_event_set(&self, _wp: &mut WalProposer) {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn update_event_set(&self, _sk: &mut Safekeeper, _events_mask: u32) {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn add_safekeeper_event_set(&self, _sk: &mut Safekeeper, _events_mask: u32) {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn wait_event_set(&self, _wp: &mut WalProposer, _timeout_millis: i64) -> WaitResult {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn strong_random(&self, _buf: &mut [u8]) -> bool {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn get_redo_start_lsn(&self) -> u64 {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn finish_sync_safekeepers(&self, _lsn: u64) {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn process_safekeeper_feedback(&self, _wp: &mut WalProposer, _commit_lsn: u64) {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn confirm_wal_streamed(&self, _wp: &mut WalProposer, _lsn: u64) {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn log_internal(&self, _wp: &mut WalProposer, _level: Level, _msg: &str) {
|
||||
todo!()
|
||||
}
|
||||
|
||||
fn after_election(&self, _wp: &mut WalProposer) {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
pub enum WaitResult {
|
||||
Latch,
|
||||
Timeout,
|
||||
Network(*mut Safekeeper, u32),
|
||||
}
|
||||
|
||||
pub struct Config {
|
||||
/// Tenant and timeline id
|
||||
pub ttid: TenantTimelineId,
|
||||
/// List of safekeepers in format `host:port`
|
||||
pub safekeepers_list: Vec<String>,
|
||||
/// Safekeeper reconnect timeout in milliseconds
|
||||
pub safekeeper_reconnect_timeout: i32,
|
||||
/// Safekeeper connection timeout in milliseconds
|
||||
pub safekeeper_connection_timeout: i32,
|
||||
/// walproposer mode, finish when all safekeepers are synced or subscribe
|
||||
/// to WAL streaming
|
||||
pub sync_safekeepers: bool,
|
||||
}
|
||||
|
||||
/// WalProposer main struct. C methods are reexported as Rust functions.
|
||||
pub struct Wrapper {
|
||||
wp: *mut WalProposer,
|
||||
_safekeepers_list_vec: Vec<u8>,
|
||||
}
|
||||
|
||||
impl Wrapper {
|
||||
pub fn new(api: Box<dyn ApiImpl>, config: Config) -> Wrapper {
|
||||
let neon_tenant = CString::new(config.ttid.tenant_id.to_string())
|
||||
.unwrap()
|
||||
.into_raw();
|
||||
let neon_timeline = CString::new(config.ttid.timeline_id.to_string())
|
||||
.unwrap()
|
||||
.into_raw();
|
||||
|
||||
let mut safekeepers_list_vec = CString::new(config.safekeepers_list.join(","))
|
||||
.unwrap()
|
||||
.into_bytes_with_nul();
|
||||
assert!(safekeepers_list_vec.len() == safekeepers_list_vec.capacity());
|
||||
let safekeepers_list = safekeepers_list_vec.as_mut_ptr() as *mut std::ffi::c_char;
|
||||
|
||||
let callback_data = Box::into_raw(Box::new(api)) as *mut ::std::os::raw::c_void;
|
||||
|
||||
let c_config = WalProposerConfig {
|
||||
neon_tenant,
|
||||
neon_timeline,
|
||||
safekeepers_list,
|
||||
safekeeper_reconnect_timeout: config.safekeeper_reconnect_timeout,
|
||||
safekeeper_connection_timeout: config.safekeeper_connection_timeout,
|
||||
wal_segment_size: WAL_SEGMENT_SIZE as i32, // default 16MB
|
||||
syncSafekeepers: config.sync_safekeepers,
|
||||
systemId: 0,
|
||||
pgTimeline: 1,
|
||||
callback_data,
|
||||
};
|
||||
let c_config = Box::into_raw(Box::new(c_config));
|
||||
|
||||
let api = create_api();
|
||||
let wp = unsafe { WalProposerCreate(c_config, api) };
|
||||
Wrapper {
|
||||
wp,
|
||||
_safekeepers_list_vec: safekeepers_list_vec,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start(&self) {
|
||||
unsafe { WalProposerStart(self.wp) }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Wrapper {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
let config = (*self.wp).config;
|
||||
drop(Box::from_raw(
|
||||
(*config).callback_data as *mut Box<dyn ApiImpl>,
|
||||
));
|
||||
drop(CString::from_raw((*config).neon_tenant));
|
||||
drop(CString::from_raw((*config).neon_timeline));
|
||||
drop(Box::from_raw(config));
|
||||
|
||||
for i in 0..(*self.wp).n_safekeepers {
|
||||
let sk = &mut (*self.wp).safekeeper[i as usize];
|
||||
take_vec_u8(&mut sk.inbuf);
|
||||
}
|
||||
|
||||
WalProposerFree(self.wp);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::{
|
||||
cell::Cell,
|
||||
sync::{atomic::AtomicUsize, mpsc::sync_channel},
|
||||
};
|
||||
|
||||
use utils::id::TenantTimelineId;
|
||||
|
||||
use crate::{api_bindings::Level, walproposer::Wrapper};
|
||||
|
||||
use super::ApiImpl;
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
struct WaitEventsData {
|
||||
sk: *mut crate::bindings::Safekeeper,
|
||||
event_mask: u32,
|
||||
}
|
||||
|
||||
struct MockImpl {
|
||||
// data to return from wait_event_set
|
||||
wait_events: Cell<WaitEventsData>,
|
||||
// walproposer->safekeeper messages
|
||||
expected_messages: Vec<Vec<u8>>,
|
||||
expected_ptr: AtomicUsize,
|
||||
// safekeeper->walproposer messages
|
||||
safekeeper_replies: Vec<Vec<u8>>,
|
||||
replies_ptr: AtomicUsize,
|
||||
// channel to send LSN to the main thread
|
||||
sync_channel: std::sync::mpsc::SyncSender<u64>,
|
||||
}
|
||||
|
||||
impl MockImpl {
|
||||
fn check_walproposer_msg(&self, msg: &[u8]) {
|
||||
let ptr = self
|
||||
.expected_ptr
|
||||
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||||
|
||||
if ptr >= self.expected_messages.len() {
|
||||
panic!("unexpected message from walproposer");
|
||||
}
|
||||
|
||||
let expected_msg = &self.expected_messages[ptr];
|
||||
assert_eq!(msg, expected_msg.as_slice());
|
||||
}
|
||||
|
||||
fn next_safekeeper_reply(&self) -> &[u8] {
|
||||
let ptr = self
|
||||
.replies_ptr
|
||||
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
|
||||
|
||||
if ptr >= self.safekeeper_replies.len() {
|
||||
panic!("no more safekeeper replies");
|
||||
}
|
||||
|
||||
&self.safekeeper_replies[ptr]
|
||||
}
|
||||
}
|
||||
|
||||
impl ApiImpl for MockImpl {
|
||||
fn get_current_timestamp(&self) -> i64 {
|
||||
println!("get_current_timestamp");
|
||||
0
|
||||
}
|
||||
|
||||
fn conn_status(
|
||||
&self,
|
||||
_: &mut crate::bindings::Safekeeper,
|
||||
) -> crate::bindings::WalProposerConnStatusType {
|
||||
println!("conn_status");
|
||||
crate::bindings::WalProposerConnStatusType_WP_CONNECTION_OK
|
||||
}
|
||||
|
||||
fn conn_connect_start(&self, _: &mut crate::bindings::Safekeeper) {
|
||||
println!("conn_connect_start");
|
||||
}
|
||||
|
||||
fn conn_connect_poll(
|
||||
&self,
|
||||
_: &mut crate::bindings::Safekeeper,
|
||||
) -> crate::bindings::WalProposerConnectPollStatusType {
|
||||
println!("conn_connect_poll");
|
||||
crate::bindings::WalProposerConnectPollStatusType_WP_CONN_POLLING_OK
|
||||
}
|
||||
|
||||
fn conn_send_query(&self, _: &mut crate::bindings::Safekeeper, query: &str) -> bool {
|
||||
println!("conn_send_query: {}", query);
|
||||
true
|
||||
}
|
||||
|
||||
fn conn_get_query_result(
|
||||
&self,
|
||||
_: &mut crate::bindings::Safekeeper,
|
||||
) -> crate::bindings::WalProposerExecStatusType {
|
||||
println!("conn_get_query_result");
|
||||
crate::bindings::WalProposerExecStatusType_WP_EXEC_SUCCESS_COPYBOTH
|
||||
}
|
||||
|
||||
fn conn_async_read(
|
||||
&self,
|
||||
_: &mut crate::bindings::Safekeeper,
|
||||
) -> (&[u8], crate::bindings::PGAsyncReadResult) {
|
||||
println!("conn_async_read");
|
||||
let reply = self.next_safekeeper_reply();
|
||||
println!("conn_async_read result: {:?}", reply);
|
||||
(
|
||||
reply,
|
||||
crate::bindings::PGAsyncReadResult_PG_ASYNC_READ_SUCCESS,
|
||||
)
|
||||
}
|
||||
|
||||
fn conn_blocking_write(&self, _: &mut crate::bindings::Safekeeper, buf: &[u8]) -> bool {
|
||||
println!("conn_blocking_write: {:?}", buf);
|
||||
self.check_walproposer_msg(buf);
|
||||
true
|
||||
}
|
||||
|
||||
fn wal_reader_allocate(&self, _: &mut crate::bindings::Safekeeper) {
|
||||
println!("wal_reader_allocate")
|
||||
}
|
||||
|
||||
fn free_event_set(&self, _: &mut crate::bindings::WalProposer) {
|
||||
println!("free_event_set")
|
||||
}
|
||||
|
||||
fn init_event_set(&self, _: &mut crate::bindings::WalProposer) {
|
||||
println!("init_event_set")
|
||||
}
|
||||
|
||||
fn update_event_set(&self, sk: &mut crate::bindings::Safekeeper, event_mask: u32) {
|
||||
println!(
|
||||
"update_event_set, sk={:?}, events_mask={:#b}",
|
||||
sk as *mut crate::bindings::Safekeeper, event_mask
|
||||
);
|
||||
self.wait_events.set(WaitEventsData { sk, event_mask });
|
||||
}
|
||||
|
||||
fn add_safekeeper_event_set(&self, sk: &mut crate::bindings::Safekeeper, event_mask: u32) {
|
||||
println!(
|
||||
"add_safekeeper_event_set, sk={:?}, events_mask={:#b}",
|
||||
sk as *mut crate::bindings::Safekeeper, event_mask
|
||||
);
|
||||
self.wait_events.set(WaitEventsData { sk, event_mask });
|
||||
}
|
||||
|
||||
fn wait_event_set(
|
||||
&self,
|
||||
_: &mut crate::bindings::WalProposer,
|
||||
timeout_millis: i64,
|
||||
) -> super::WaitResult {
|
||||
let data = self.wait_events.get();
|
||||
println!(
|
||||
"wait_event_set, timeout_millis={}, res={:?}",
|
||||
timeout_millis, data
|
||||
);
|
||||
super::WaitResult::Network(data.sk, data.event_mask)
|
||||
}
|
||||
|
||||
fn strong_random(&self, buf: &mut [u8]) -> bool {
|
||||
println!("strong_random");
|
||||
buf.fill(0);
|
||||
true
|
||||
}
|
||||
|
||||
fn finish_sync_safekeepers(&self, lsn: u64) {
|
||||
self.sync_channel.send(lsn).unwrap();
|
||||
panic!("sync safekeepers finished at lsn={}", lsn);
|
||||
}
|
||||
|
||||
fn log_internal(&self, _wp: &mut crate::bindings::WalProposer, level: Level, msg: &str) {
|
||||
println!("walprop_log[{}] {}", level, msg);
|
||||
}
|
||||
|
||||
fn after_election(&self, _wp: &mut crate::bindings::WalProposer) {
|
||||
println!("after_election");
|
||||
}
|
||||
}
|
||||
|
||||
/// Test that walproposer can successfully connect to safekeeper and finish
|
||||
/// sync_safekeepers. API is mocked in MockImpl.
|
||||
///
|
||||
/// Run this test with valgrind to detect leaks:
|
||||
/// `valgrind --leak-check=full target/debug/deps/walproposer-<build>`
|
||||
#[test]
|
||||
fn test_simple_sync_safekeepers() -> anyhow::Result<()> {
|
||||
let ttid = TenantTimelineId::new(
|
||||
"9e4c8f36063c6c6e93bc20d65a820f3d".parse()?,
|
||||
"9e4c8f36063c6c6e93bc20d65a820f3d".parse()?,
|
||||
);
|
||||
|
||||
let (sender, receiver) = sync_channel(1);
|
||||
|
||||
let my_impl: Box<dyn ApiImpl> = Box::new(MockImpl {
|
||||
wait_events: Cell::new(WaitEventsData {
|
||||
sk: std::ptr::null_mut(),
|
||||
event_mask: 0,
|
||||
}),
|
||||
expected_messages: vec![
|
||||
// Greeting(ProposerGreeting { protocol_version: 2, pg_version: 160000, proposer_id: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], system_id: 0, timeline_id: 9e4c8f36063c6c6e93bc20d65a820f3d, tenant_id: 9e4c8f36063c6c6e93bc20d65a820f3d, tli: 1, wal_seg_size: 16777216 })
|
||||
vec![
|
||||
103, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 113, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 158, 76, 143, 54, 6, 60, 108, 110,
|
||||
147, 188, 32, 214, 90, 130, 15, 61, 158, 76, 143, 54, 6, 60, 108, 110, 147,
|
||||
188, 32, 214, 90, 130, 15, 61, 1, 0, 0, 0, 0, 0, 0, 1,
|
||||
],
|
||||
// VoteRequest(VoteRequest { term: 3 })
|
||||
vec![
|
||||
118, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||
0, 0, 0, 0, 0, 0,
|
||||
],
|
||||
],
|
||||
expected_ptr: AtomicUsize::new(0),
|
||||
safekeeper_replies: vec![
|
||||
// Greeting(AcceptorGreeting { term: 2, node_id: NodeId(1) })
|
||||
vec![
|
||||
103, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
|
||||
],
|
||||
// VoteResponse(VoteResponse { term: 3, vote_given: 1, flush_lsn: 0/539, truncate_lsn: 0/539, term_history: [(2, 0/539)], timeline_start_lsn: 0/539 })
|
||||
vec![
|
||||
118, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 57,
|
||||
5, 0, 0, 0, 0, 0, 0, 57, 5, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0,
|
||||
0, 57, 5, 0, 0, 0, 0, 0, 0, 57, 5, 0, 0, 0, 0, 0, 0,
|
||||
],
|
||||
],
|
||||
replies_ptr: AtomicUsize::new(0),
|
||||
sync_channel: sender,
|
||||
});
|
||||
let config = crate::walproposer::Config {
|
||||
ttid,
|
||||
safekeepers_list: vec!["localhost:5000".to_string()],
|
||||
safekeeper_reconnect_timeout: 1000,
|
||||
safekeeper_connection_timeout: 10000,
|
||||
sync_safekeepers: true,
|
||||
};
|
||||
|
||||
let wp = Wrapper::new(my_impl, config);
|
||||
|
||||
// walproposer will panic when it finishes sync_safekeepers
|
||||
std::panic::catch_unwind(|| wp.start()).unwrap_err();
|
||||
// validate the resulting LSN
|
||||
assert_eq!(receiver.recv()?, 1337);
|
||||
Ok(())
|
||||
// drop() will free up resources here
|
||||
}
|
||||
}
|
||||
@@ -11,10 +11,7 @@ use std::sync::{Arc, Barrier};
|
||||
|
||||
use bytes::{Buf, Bytes};
|
||||
use pageserver::{
|
||||
config::PageServerConf,
|
||||
repository::Key,
|
||||
walrecord::NeonWalRecord,
|
||||
walredo::{PostgresRedoManager, WalRedoError},
|
||||
config::PageServerConf, repository::Key, walrecord::NeonWalRecord, walredo::PostgresRedoManager,
|
||||
};
|
||||
use utils::{id::TenantId, lsn::Lsn};
|
||||
|
||||
@@ -35,9 +32,15 @@ fn redo_scenarios(c: &mut Criterion) {
|
||||
|
||||
let manager = Arc::new(manager);
|
||||
|
||||
tracing::info!("executing first");
|
||||
short().execute(&manager).unwrap();
|
||||
tracing::info!("first executed");
|
||||
{
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap();
|
||||
tracing::info!("executing first");
|
||||
short().execute(rt.handle(), &manager).unwrap();
|
||||
tracing::info!("first executed");
|
||||
}
|
||||
|
||||
let thread_counts = [1, 2, 4, 8, 16];
|
||||
|
||||
@@ -80,9 +83,14 @@ fn add_multithreaded_walredo_requesters(
|
||||
assert_ne!(threads, 0);
|
||||
|
||||
if threads == 1 {
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap();
|
||||
let handle = rt.handle();
|
||||
b.iter_batched_ref(
|
||||
|| Some(input_factory()),
|
||||
|input| execute_all(input.take(), manager),
|
||||
|input| execute_all(input.take(), handle, manager),
|
||||
criterion::BatchSize::PerIteration,
|
||||
);
|
||||
} else {
|
||||
@@ -98,19 +106,26 @@ fn add_multithreaded_walredo_requesters(
|
||||
let manager = manager.clone();
|
||||
let barrier = barrier.clone();
|
||||
let work_rx = work_rx.clone();
|
||||
move || loop {
|
||||
// queue up and wait if we want to go another round
|
||||
if work_rx.lock().unwrap().recv().is_err() {
|
||||
break;
|
||||
move || {
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap();
|
||||
let handle = rt.handle();
|
||||
loop {
|
||||
// queue up and wait if we want to go another round
|
||||
if work_rx.lock().unwrap().recv().is_err() {
|
||||
break;
|
||||
}
|
||||
|
||||
let input = Some(input_factory());
|
||||
|
||||
barrier.wait();
|
||||
|
||||
execute_all(input, handle, &manager).unwrap();
|
||||
|
||||
barrier.wait();
|
||||
}
|
||||
|
||||
let input = Some(input_factory());
|
||||
|
||||
barrier.wait();
|
||||
|
||||
execute_all(input, &manager).unwrap();
|
||||
|
||||
barrier.wait();
|
||||
}
|
||||
})
|
||||
})
|
||||
@@ -152,15 +167,19 @@ impl Drop for JoinOnDrop {
|
||||
}
|
||||
}
|
||||
|
||||
fn execute_all<I>(input: I, manager: &PostgresRedoManager) -> Result<(), WalRedoError>
|
||||
fn execute_all<I>(
|
||||
input: I,
|
||||
handle: &tokio::runtime::Handle,
|
||||
manager: &PostgresRedoManager,
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
I: IntoIterator<Item = Request>,
|
||||
{
|
||||
// just fire all requests as fast as possible
|
||||
input.into_iter().try_for_each(|req| {
|
||||
let page = req.execute(manager)?;
|
||||
let page = req.execute(handle, manager)?;
|
||||
assert_eq!(page.remaining(), 8192);
|
||||
Ok::<_, WalRedoError>(())
|
||||
anyhow::Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -473,9 +492,11 @@ struct Request {
|
||||
}
|
||||
|
||||
impl Request {
|
||||
fn execute(self, manager: &PostgresRedoManager) -> Result<Bytes, WalRedoError> {
|
||||
use pageserver::walredo::WalRedoManager;
|
||||
|
||||
fn execute(
|
||||
self,
|
||||
rt: &tokio::runtime::Handle,
|
||||
manager: &PostgresRedoManager,
|
||||
) -> anyhow::Result<Bytes> {
|
||||
let Request {
|
||||
key,
|
||||
lsn,
|
||||
@@ -484,6 +505,6 @@ impl Request {
|
||||
pg_version,
|
||||
} = self;
|
||||
|
||||
manager.request_redo(key, lsn, base_img, records, pg_version)
|
||||
rt.block_on(manager.request_redo(key, lsn, base_img, records, pg_version))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,22 +1,21 @@
|
||||
use anyhow::{bail, Result};
|
||||
use utils::auth::{Claims, Scope};
|
||||
use utils::auth::{AuthError, Claims, Scope};
|
||||
use utils::id::TenantId;
|
||||
|
||||
pub fn check_permission(claims: &Claims, tenant_id: Option<TenantId>) -> Result<()> {
|
||||
pub fn check_permission(claims: &Claims, tenant_id: Option<TenantId>) -> Result<(), AuthError> {
|
||||
match (&claims.scope, tenant_id) {
|
||||
(Scope::Tenant, None) => {
|
||||
bail!("Attempt to access management api with tenant scope. Permission denied")
|
||||
}
|
||||
(Scope::Tenant, None) => Err(AuthError(
|
||||
"Attempt to access management api with tenant scope. Permission denied".into(),
|
||||
)),
|
||||
(Scope::Tenant, Some(tenant_id)) => {
|
||||
if claims.tenant_id.unwrap() != tenant_id {
|
||||
bail!("Tenant id mismatch. Permission denied")
|
||||
return Err(AuthError("Tenant id mismatch. Permission denied".into()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
(Scope::PageServerApi, None) => Ok(()), // access to management api for PageServerApi scope
|
||||
(Scope::PageServerApi, Some(_)) => Ok(()), // access to tenant api using PageServerApi scope
|
||||
(Scope::SafekeeperData, _) => {
|
||||
bail!("SafekeeperData scope makes no sense for Pageserver")
|
||||
}
|
||||
(Scope::SafekeeperData, _) => Err(AuthError(
|
||||
"SafekeeperData scope makes no sense for Pageserver".into(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
use anyhow::{anyhow, bail, ensure, Context};
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use fail::fail_point;
|
||||
use postgres_ffi::pg_constants;
|
||||
use std::fmt::Write as FmtWrite;
|
||||
use std::time::SystemTime;
|
||||
use tokio::io;
|
||||
@@ -180,6 +181,7 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
let mut min_restart_lsn: Lsn = Lsn::MAX;
|
||||
// Create tablespace directories
|
||||
for ((spcnode, dbnode), has_relmap_file) in
|
||||
self.timeline.list_dbdirs(self.lsn, self.ctx).await?
|
||||
@@ -213,6 +215,34 @@ where
|
||||
self.add_rel(rel, rel).await?;
|
||||
}
|
||||
}
|
||||
|
||||
for (path, content) in self.timeline.list_aux_files(self.lsn, self.ctx).await? {
|
||||
if path.starts_with("pg_replslot") {
|
||||
let offs = pg_constants::REPL_SLOT_ON_DISK_OFFSETOF_RESTART_LSN;
|
||||
let restart_lsn = Lsn(u64::from_le_bytes(
|
||||
content[offs..offs + 8].try_into().unwrap(),
|
||||
));
|
||||
info!("Replication slot {} restart LSN={}", path, restart_lsn);
|
||||
min_restart_lsn = Lsn::min(min_restart_lsn, restart_lsn);
|
||||
}
|
||||
let header = new_tar_header(&path, content.len() as u64)?;
|
||||
self.ar
|
||||
.append(&header, &*content)
|
||||
.await
|
||||
.context("could not add aux file to basebackup tarball")?;
|
||||
}
|
||||
}
|
||||
if min_restart_lsn != Lsn::MAX {
|
||||
info!(
|
||||
"Min restart LSN for logical replication is {}",
|
||||
min_restart_lsn
|
||||
);
|
||||
let data = min_restart_lsn.0.to_le_bytes();
|
||||
let header = new_tar_header("restart.lsn", data.len() as u64)?;
|
||||
self.ar
|
||||
.append(&header, &data[..])
|
||||
.await
|
||||
.context("could not add restart.lsn file to basebackup tarball")?;
|
||||
}
|
||||
for xid in self
|
||||
.timeline
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
use std::env::{var, VarError};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{env, ops::ControlFlow, str::FromStr};
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
@@ -33,11 +34,15 @@ use postgres_backend::AuthType;
|
||||
use utils::logging::TracingErrorLayerEnablement;
|
||||
use utils::signals::ShutdownSignals;
|
||||
use utils::{
|
||||
auth::JwtAuth, logging, project_git_version, sentry_init::init_sentry, signals::Signal,
|
||||
auth::{JwtAuth, SwappableJwtAuth},
|
||||
logging, project_build_tag, project_git_version,
|
||||
sentry_init::init_sentry,
|
||||
signals::Signal,
|
||||
tcp_listener,
|
||||
};
|
||||
|
||||
project_git_version!(GIT_VERSION);
|
||||
project_build_tag!(BUILD_TAG);
|
||||
|
||||
const PID_FILE_NAME: &str = "pageserver.pid";
|
||||
|
||||
@@ -200,6 +205,51 @@ fn initialize_config(
|
||||
})
|
||||
}
|
||||
|
||||
struct WaitForPhaseResult<F: std::future::Future + Unpin> {
|
||||
timeout_remaining: Duration,
|
||||
skipped: Option<F>,
|
||||
}
|
||||
|
||||
/// During startup, we apply a timeout to our waits for readiness, to avoid
|
||||
/// stalling the whole service if one Tenant experiences some problem. Each
|
||||
/// phase may consume some of the timeout: this function returns the updated
|
||||
/// timeout for use in the next call.
|
||||
async fn wait_for_phase<F>(phase: &str, mut fut: F, timeout: Duration) -> WaitForPhaseResult<F>
|
||||
where
|
||||
F: std::future::Future + Unpin,
|
||||
{
|
||||
let initial_t = Instant::now();
|
||||
let skipped = match tokio::time::timeout(timeout, &mut fut).await {
|
||||
Ok(_) => None,
|
||||
Err(_) => {
|
||||
tracing::info!(
|
||||
timeout_millis = timeout.as_millis(),
|
||||
%phase,
|
||||
"Startup phase timed out, proceeding anyway"
|
||||
);
|
||||
Some(fut)
|
||||
}
|
||||
};
|
||||
|
||||
WaitForPhaseResult {
|
||||
timeout_remaining: timeout
|
||||
.checked_sub(Instant::now().duration_since(initial_t))
|
||||
.unwrap_or(Duration::ZERO),
|
||||
skipped,
|
||||
}
|
||||
}
|
||||
|
||||
fn startup_checkpoint(started_at: Instant, phase: &str, human_phase: &str) {
|
||||
let elapsed = started_at.elapsed();
|
||||
let secs = elapsed.as_secs_f64();
|
||||
STARTUP_DURATION.with_label_values(&[phase]).set(secs);
|
||||
|
||||
info!(
|
||||
elapsed_ms = elapsed.as_millis(),
|
||||
"{human_phase} ({secs:.3}s since start)"
|
||||
)
|
||||
}
|
||||
|
||||
fn start_pageserver(
|
||||
launch_ts: &'static LaunchTimestamp,
|
||||
conf: &'static PageServerConf,
|
||||
@@ -207,26 +257,17 @@ fn start_pageserver(
|
||||
// Monotonic time for later calculating startup duration
|
||||
let started_startup_at = Instant::now();
|
||||
|
||||
let startup_checkpoint = move |phase: &str, human_phase: &str| {
|
||||
let elapsed = started_startup_at.elapsed();
|
||||
let secs = elapsed.as_secs_f64();
|
||||
STARTUP_DURATION.with_label_values(&[phase]).set(secs);
|
||||
info!(
|
||||
elapsed_ms = elapsed.as_millis(),
|
||||
"{human_phase} ({secs:.3}s since start)"
|
||||
)
|
||||
};
|
||||
|
||||
// Print version and launch timestamp to the log,
|
||||
// and expose them as prometheus metrics.
|
||||
// A changed version string indicates changed software.
|
||||
// A changed launch timestamp indicates a pageserver restart.
|
||||
info!(
|
||||
"version: {} launch_timestamp: {}",
|
||||
"version: {} launch_timestamp: {} build_tag: {}",
|
||||
version(),
|
||||
launch_ts.to_string()
|
||||
launch_ts.to_string(),
|
||||
BUILD_TAG,
|
||||
);
|
||||
set_build_info_metric(GIT_VERSION);
|
||||
set_build_info_metric(GIT_VERSION, BUILD_TAG);
|
||||
set_launch_timestamp_metric(launch_ts);
|
||||
pageserver::preinitialize_metrics();
|
||||
|
||||
@@ -283,13 +324,12 @@ fn start_pageserver(
|
||||
let http_auth;
|
||||
let pg_auth;
|
||||
if conf.http_auth_type == AuthType::NeonJWT || conf.pg_auth_type == AuthType::NeonJWT {
|
||||
// unwrap is ok because check is performed when creating config, so path is set and file exists
|
||||
// unwrap is ok because check is performed when creating config, so path is set and exists
|
||||
let key_path = conf.auth_validation_public_key_path.as_ref().unwrap();
|
||||
info!(
|
||||
"Loading public key for verifying JWT tokens from {:#?}",
|
||||
key_path
|
||||
);
|
||||
let auth: Arc<JwtAuth> = Arc::new(JwtAuth::from_key_path(key_path)?);
|
||||
info!("Loading public key(s) for verifying JWT tokens from {key_path:?}");
|
||||
|
||||
let jwt_auth = JwtAuth::from_key_path(key_path)?;
|
||||
let auth: Arc<SwappableJwtAuth> = Arc::new(SwappableJwtAuth::new(jwt_auth));
|
||||
|
||||
http_auth = match &conf.http_auth_type {
|
||||
AuthType::Trust => None,
|
||||
@@ -341,7 +381,7 @@ fn start_pageserver(
|
||||
|
||||
// Up to this point no significant I/O has been done: this should have been fast. Record
|
||||
// duration prior to starting I/O intensive phase of startup.
|
||||
startup_checkpoint("initial", "Starting loading tenants");
|
||||
startup_checkpoint(started_startup_at, "initial", "Starting loading tenants");
|
||||
STARTUP_IS_LOADING.set(1);
|
||||
|
||||
// Startup staging or optimizing:
|
||||
@@ -355,6 +395,7 @@ fn start_pageserver(
|
||||
// consumer side) will be dropped once we can start the background jobs. Currently it is behind
|
||||
// completing all initial logical size calculations (init_logical_size_done_rx) and a timeout
|
||||
// (background_task_maximum_delay).
|
||||
let (init_remote_done_tx, init_remote_done_rx) = utils::completion::channel();
|
||||
let (init_done_tx, init_done_rx) = utils::completion::channel();
|
||||
|
||||
let (init_logical_size_done_tx, init_logical_size_done_rx) = utils::completion::channel();
|
||||
@@ -362,7 +403,8 @@ fn start_pageserver(
|
||||
let (background_jobs_can_start, background_jobs_barrier) = utils::completion::channel();
|
||||
|
||||
let order = pageserver::InitializationOrder {
|
||||
initial_tenant_load: Some(init_done_tx),
|
||||
initial_tenant_load_remote: Some(init_done_tx),
|
||||
initial_tenant_load: Some(init_remote_done_tx),
|
||||
initial_logical_size_can_start: init_done_rx.clone(),
|
||||
initial_logical_size_attempt: Some(init_logical_size_done_tx),
|
||||
background_jobs_can_start: background_jobs_barrier.clone(),
|
||||
@@ -370,7 +412,7 @@ fn start_pageserver(
|
||||
|
||||
// Scan the local 'tenants/' directory and start loading the tenants
|
||||
let deletion_queue_client = deletion_queue.new_client();
|
||||
BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(
|
||||
let tenant_manager = BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(
|
||||
conf,
|
||||
TenantSharedResources {
|
||||
broker_client: broker_client.clone(),
|
||||
@@ -380,61 +422,100 @@ fn start_pageserver(
|
||||
order,
|
||||
shutdown_pageserver.clone(),
|
||||
))?;
|
||||
let tenant_manager = Arc::new(tenant_manager);
|
||||
|
||||
BACKGROUND_RUNTIME.spawn({
|
||||
let init_done_rx = init_done_rx;
|
||||
let shutdown_pageserver = shutdown_pageserver.clone();
|
||||
let drive_init = async move {
|
||||
// NOTE: unlike many futures in pageserver, this one is cancellation-safe
|
||||
let guard = scopeguard::guard_on_success((), |_| tracing::info!("Cancelled before initial load completed"));
|
||||
let guard = scopeguard::guard_on_success((), |_| {
|
||||
tracing::info!("Cancelled before initial load completed")
|
||||
});
|
||||
|
||||
init_done_rx.wait().await;
|
||||
startup_checkpoint("initial_tenant_load", "Initial load completed");
|
||||
STARTUP_IS_LOADING.set(0);
|
||||
let timeout = conf.background_task_maximum_delay;
|
||||
|
||||
let init_remote_done = std::pin::pin!(async {
|
||||
init_remote_done_rx.wait().await;
|
||||
startup_checkpoint(
|
||||
started_startup_at,
|
||||
"initial_tenant_load_remote",
|
||||
"Remote part of initial load completed",
|
||||
);
|
||||
});
|
||||
|
||||
let WaitForPhaseResult {
|
||||
timeout_remaining: timeout,
|
||||
skipped: init_remote_skipped,
|
||||
} = wait_for_phase("initial_tenant_load_remote", init_remote_done, timeout).await;
|
||||
|
||||
let init_load_done = std::pin::pin!(async {
|
||||
init_done_rx.wait().await;
|
||||
startup_checkpoint(
|
||||
started_startup_at,
|
||||
"initial_tenant_load",
|
||||
"Initial load completed",
|
||||
);
|
||||
STARTUP_IS_LOADING.set(0);
|
||||
});
|
||||
|
||||
let WaitForPhaseResult {
|
||||
timeout_remaining: timeout,
|
||||
skipped: init_load_skipped,
|
||||
} = wait_for_phase("initial_tenant_load", init_load_done, timeout).await;
|
||||
|
||||
// initial logical sizes can now start, as they were waiting on init_done_rx.
|
||||
|
||||
scopeguard::ScopeGuard::into_inner(guard);
|
||||
|
||||
let mut init_sizes_done = std::pin::pin!(init_logical_size_done_rx.wait());
|
||||
let guard = scopeguard::guard_on_success((), |_| {
|
||||
tracing::info!("Cancelled before initial logical sizes completed")
|
||||
});
|
||||
|
||||
let timeout = conf.background_task_maximum_delay;
|
||||
let logical_sizes_done = std::pin::pin!(async {
|
||||
init_logical_size_done_rx.wait().await;
|
||||
startup_checkpoint(
|
||||
started_startup_at,
|
||||
"initial_logical_sizes",
|
||||
"Initial logical sizes completed",
|
||||
);
|
||||
});
|
||||
|
||||
let guard = scopeguard::guard_on_success((), |_| tracing::info!("Cancelled before initial logical sizes completed"));
|
||||
|
||||
let init_sizes_done = match tokio::time::timeout(timeout, &mut init_sizes_done).await {
|
||||
Ok(_) => {
|
||||
startup_checkpoint("initial_logical_sizes", "Initial logical sizes completed");
|
||||
None
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::info!(
|
||||
timeout_millis = timeout.as_millis(),
|
||||
"Initial logical size timeout elapsed; starting background jobs"
|
||||
);
|
||||
Some(init_sizes_done)
|
||||
}
|
||||
};
|
||||
let WaitForPhaseResult {
|
||||
timeout_remaining: _,
|
||||
skipped: logical_sizes_skipped,
|
||||
} = wait_for_phase("initial_logical_sizes", logical_sizes_done, timeout).await;
|
||||
|
||||
scopeguard::ScopeGuard::into_inner(guard);
|
||||
|
||||
// allow background jobs to start
|
||||
// allow background jobs to start: we either completed prior stages, or they reached timeout
|
||||
// and were skipped. It is important that we do not let them block background jobs indefinitely,
|
||||
// because things like consumption metrics for billing are blocked by this barrier.
|
||||
drop(background_jobs_can_start);
|
||||
startup_checkpoint("background_jobs_can_start", "Starting background jobs");
|
||||
|
||||
if let Some(init_sizes_done) = init_sizes_done {
|
||||
// ending up here is not a bug; at the latest logical sizes will be queried by
|
||||
// consumption metrics.
|
||||
let guard = scopeguard::guard_on_success((), |_| tracing::info!("Cancelled before initial logical sizes completed"));
|
||||
init_sizes_done.await;
|
||||
|
||||
scopeguard::ScopeGuard::into_inner(guard);
|
||||
|
||||
startup_checkpoint("initial_logical_sizes", "Initial logical sizes completed after timeout (background jobs already started)");
|
||||
startup_checkpoint(
|
||||
started_startup_at,
|
||||
"background_jobs_can_start",
|
||||
"Starting background jobs",
|
||||
);
|
||||
|
||||
// We are done. If we skipped any phases due to timeout, run them to completion here so that
|
||||
// they will eventually update their startup_checkpoint, and so that we do not declare the
|
||||
// 'complete' stage until all the other stages are really done.
|
||||
let guard = scopeguard::guard_on_success((), |_| {
|
||||
tracing::info!("Cancelled before waiting for skipped phases done")
|
||||
});
|
||||
if let Some(f) = init_remote_skipped {
|
||||
f.await;
|
||||
}
|
||||
if let Some(f) = init_load_skipped {
|
||||
f.await;
|
||||
}
|
||||
if let Some(f) = logical_sizes_skipped {
|
||||
f.await;
|
||||
}
|
||||
scopeguard::ScopeGuard::into_inner(guard);
|
||||
|
||||
startup_checkpoint("complete", "Startup complete");
|
||||
startup_checkpoint(started_startup_at, "complete", "Startup complete");
|
||||
};
|
||||
|
||||
async move {
|
||||
@@ -470,6 +551,7 @@ fn start_pageserver(
|
||||
let router_state = Arc::new(
|
||||
http::routes::State::new(
|
||||
conf,
|
||||
tenant_manager,
|
||||
http_auth.clone(),
|
||||
remote_storage.clone(),
|
||||
broker_client.clone(),
|
||||
@@ -574,6 +656,7 @@ fn start_pageserver(
|
||||
pageserver_listener,
|
||||
conf.pg_auth_type,
|
||||
libpq_ctx,
|
||||
task_mgr::shutdown_token(),
|
||||
)
|
||||
.await
|
||||
},
|
||||
|
||||
@@ -33,8 +33,7 @@ use crate::disk_usage_eviction_task::DiskUsageEvictionTaskConfig;
|
||||
use crate::tenant::config::TenantConf;
|
||||
use crate::tenant::config::TenantConfOpt;
|
||||
use crate::tenant::{
|
||||
TENANTS_SEGMENT_NAME, TENANT_ATTACHING_MARKER_FILENAME, TENANT_DELETED_MARKER_FILE_NAME,
|
||||
TIMELINES_SEGMENT_NAME,
|
||||
TENANTS_SEGMENT_NAME, TENANT_DELETED_MARKER_FILE_NAME, TIMELINES_SEGMENT_NAME,
|
||||
};
|
||||
use crate::{
|
||||
IGNORED_TENANT_FILE_NAME, METADATA_FILE_NAME, TENANT_CONFIG_NAME, TENANT_LOCATION_CONFIG_NAME,
|
||||
@@ -162,7 +161,7 @@ pub struct PageServerConf {
|
||||
pub http_auth_type: AuthType,
|
||||
/// authentication method for libpq connections from compute
|
||||
pub pg_auth_type: AuthType,
|
||||
/// Path to a file containing public key for verifying JWT tokens.
|
||||
/// Path to a file or directory containing public key(s) for verifying JWT tokens.
|
||||
/// Used for both mgmt and compute auth, if enabled.
|
||||
pub auth_validation_public_key_path: Option<Utf8PathBuf>,
|
||||
|
||||
@@ -633,11 +632,6 @@ impl PageServerConf {
|
||||
self.tenants_path().join(tenant_id.to_string())
|
||||
}
|
||||
|
||||
pub fn tenant_attaching_mark_file_path(&self, tenant_id: &TenantId) -> Utf8PathBuf {
|
||||
self.tenant_path(tenant_id)
|
||||
.join(TENANT_ATTACHING_MARKER_FILENAME)
|
||||
}
|
||||
|
||||
pub fn tenant_ignore_mark_file_path(&self, tenant_id: &TenantId) -> Utf8PathBuf {
|
||||
self.tenant_path(tenant_id).join(IGNORED_TENANT_FILE_NAME)
|
||||
}
|
||||
@@ -1320,12 +1314,6 @@ broker_endpoint = '{broker_endpoint}'
|
||||
assert_eq!(
|
||||
parsed_remote_storage_config,
|
||||
RemoteStorageConfig {
|
||||
max_concurrent_syncs: NonZeroUsize::new(
|
||||
remote_storage::DEFAULT_REMOTE_STORAGE_MAX_CONCURRENT_SYNCS
|
||||
)
|
||||
.unwrap(),
|
||||
max_sync_errors: NonZeroU32::new(remote_storage::DEFAULT_REMOTE_STORAGE_MAX_SYNC_ERRORS)
|
||||
.unwrap(),
|
||||
storage: RemoteStorageKind::LocalFs(local_storage_path.clone()),
|
||||
},
|
||||
"Remote storage config should correctly parse the local FS config and fill other storage defaults"
|
||||
@@ -1386,8 +1374,6 @@ broker_endpoint = '{broker_endpoint}'
|
||||
assert_eq!(
|
||||
parsed_remote_storage_config,
|
||||
RemoteStorageConfig {
|
||||
max_concurrent_syncs,
|
||||
max_sync_errors,
|
||||
storage: RemoteStorageKind::AwsS3(S3Config {
|
||||
bucket_name: bucket_name.clone(),
|
||||
bucket_region: bucket_region.clone(),
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
//! and push them to a HTTP endpoint.
|
||||
use crate::context::{DownloadBehavior, RequestContext};
|
||||
use crate::task_mgr::{self, TaskKind, BACKGROUND_RUNTIME};
|
||||
use crate::tenant::tasks::BackgroundLoopKind;
|
||||
use crate::tenant::{mgr, LogicalSizeCalculationCause};
|
||||
use camino::Utf8PathBuf;
|
||||
use consumption_metrics::EventType;
|
||||
@@ -10,6 +11,7 @@ use reqwest::Url;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, SystemTime};
|
||||
use tokio::time::Instant;
|
||||
use tracing::*;
|
||||
use utils::id::NodeId;
|
||||
|
||||
@@ -87,22 +89,12 @@ pub async fn collect_metrics(
|
||||
|
||||
let node_id = node_id.to_string();
|
||||
|
||||
// reminder: ticker is ready immediatedly
|
||||
let mut ticker = tokio::time::interval(metric_collection_interval);
|
||||
|
||||
loop {
|
||||
let tick_at = tokio::select! {
|
||||
_ = cancel.cancelled() => return Ok(()),
|
||||
tick_at = ticker.tick() => tick_at,
|
||||
};
|
||||
let started_at = Instant::now();
|
||||
|
||||
// these are point in time, with variable "now"
|
||||
let metrics = metrics::collect_all_metrics(&cached_metrics, &ctx).await;
|
||||
|
||||
if metrics.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let metrics = Arc::new(metrics);
|
||||
|
||||
// why not race cancellation here? because we are one of the last tasks, and if we are
|
||||
@@ -141,10 +133,19 @@ pub async fn collect_metrics(
|
||||
let (_, _) = tokio::join!(flush, upload);
|
||||
|
||||
crate::tenant::tasks::warn_when_period_overrun(
|
||||
tick_at.elapsed(),
|
||||
started_at.elapsed(),
|
||||
metric_collection_interval,
|
||||
"consumption_metrics_collect_metrics",
|
||||
BackgroundLoopKind::ConsumptionMetricsCollectMetrics,
|
||||
);
|
||||
|
||||
let res = tokio::time::timeout_at(
|
||||
started_at + metric_collection_interval,
|
||||
task_mgr::shutdown_token().cancelled(),
|
||||
)
|
||||
.await;
|
||||
if res.is_ok() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -243,16 +244,14 @@ async fn calculate_synthetic_size_worker(
|
||||
ctx: &RequestContext,
|
||||
) -> anyhow::Result<()> {
|
||||
info!("starting calculate_synthetic_size_worker");
|
||||
scopeguard::defer! {
|
||||
info!("calculate_synthetic_size_worker stopped");
|
||||
};
|
||||
|
||||
// reminder: ticker is ready immediatedly
|
||||
let mut ticker = tokio::time::interval(synthetic_size_calculation_interval);
|
||||
let cause = LogicalSizeCalculationCause::ConsumptionMetricsSyntheticSize;
|
||||
|
||||
loop {
|
||||
let tick_at = tokio::select! {
|
||||
_ = task_mgr::shutdown_watcher() => return Ok(()),
|
||||
tick_at = ticker.tick() => tick_at,
|
||||
};
|
||||
let started_at = Instant::now();
|
||||
|
||||
let tenants = match mgr::list_tenants().await {
|
||||
Ok(tenants) => tenants,
|
||||
@@ -267,7 +266,12 @@ async fn calculate_synthetic_size_worker(
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Ok(tenant) = mgr::get_tenant(tenant_id, true).await {
|
||||
if let Ok(tenant) = mgr::get_tenant(tenant_id, true) {
|
||||
// TODO should we use concurrent_background_tasks_rate_limit() here, like the other background tasks?
|
||||
// We can put in some prioritization for consumption metrics.
|
||||
// Same for the loop that fetches computed metrics.
|
||||
// By using the same limiter, we centralize metrics collection for "start" and "finished" counters,
|
||||
// which turns out is really handy to understand the system.
|
||||
if let Err(e) = tenant.calculate_synthetic_size(cause, ctx).await {
|
||||
error!("failed to calculate synthetic size for tenant {tenant_id}: {e:#}");
|
||||
}
|
||||
@@ -275,9 +279,18 @@ async fn calculate_synthetic_size_worker(
|
||||
}
|
||||
|
||||
crate::tenant::tasks::warn_when_period_overrun(
|
||||
tick_at.elapsed(),
|
||||
started_at.elapsed(),
|
||||
synthetic_size_calculation_interval,
|
||||
"consumption_metrics_synthetic_size_worker",
|
||||
BackgroundLoopKind::ConsumptionMetricsSyntheticSizeWorker,
|
||||
);
|
||||
|
||||
let res = tokio::time::timeout_at(
|
||||
started_at + synthetic_size_calculation_interval,
|
||||
task_mgr::shutdown_token().cancelled(),
|
||||
)
|
||||
.await;
|
||||
if res.is_ok() {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ use anyhow::Context;
|
||||
use chrono::{DateTime, Utc};
|
||||
use consumption_metrics::EventType;
|
||||
use futures::stream::StreamExt;
|
||||
use serde_with::serde_as;
|
||||
use std::{sync::Arc, time::SystemTime};
|
||||
use utils::{
|
||||
id::{TenantId, TimelineId},
|
||||
@@ -42,13 +41,10 @@ pub(super) enum Name {
|
||||
///
|
||||
/// This is a denormalization done at the MetricsKey const methods; these should not be constructed
|
||||
/// elsewhere.
|
||||
#[serde_with::serde_as]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
|
||||
pub(crate) struct MetricsKey {
|
||||
#[serde_as(as = "serde_with::DisplayFromStr")]
|
||||
pub(super) tenant_id: TenantId,
|
||||
|
||||
#[serde_as(as = "Option<serde_with::DisplayFromStr>")]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(super) timeline_id: Option<TimelineId>,
|
||||
|
||||
@@ -206,7 +202,6 @@ pub(super) async fn collect_all_metrics(
|
||||
None
|
||||
} else {
|
||||
crate::tenant::mgr::get_tenant(id, true)
|
||||
.await
|
||||
.ok()
|
||||
.map(|tenant| (id, tenant))
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use consumption_metrics::{Event, EventChunk, IdempotencyKey, CHUNK_SIZE};
|
||||
use serde_with::serde_as;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::Instrument;
|
||||
|
||||
@@ -7,12 +6,9 @@ use super::{metrics::Name, Cache, MetricsKey, RawMetric};
|
||||
use utils::id::{TenantId, TimelineId};
|
||||
|
||||
/// How the metrics from pageserver are identified.
|
||||
#[serde_with::serde_as]
|
||||
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, Copy, PartialEq)]
|
||||
struct Ids {
|
||||
#[serde_as(as = "serde_with::DisplayFromStr")]
|
||||
pub(super) tenant_id: TenantId,
|
||||
#[serde_as(as = "Option<serde_with::DisplayFromStr>")]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub(super) timeline_id: Option<TimelineId>,
|
||||
}
|
||||
|
||||
@@ -57,7 +57,10 @@ impl ControlPlaneClient {
|
||||
|
||||
if let Some(jwt) = &conf.control_plane_api_token {
|
||||
let mut headers = hyper::HeaderMap::new();
|
||||
headers.insert("Authorization", jwt.get_contents().parse().unwrap());
|
||||
headers.insert(
|
||||
"Authorization",
|
||||
format!("Bearer {}", jwt.get_contents()).parse().unwrap(),
|
||||
);
|
||||
client = client.default_headers(headers);
|
||||
}
|
||||
|
||||
@@ -144,7 +147,7 @@ impl ControlPlaneGenerationsApi for ControlPlaneClient {
|
||||
Ok(response
|
||||
.tenants
|
||||
.into_iter()
|
||||
.map(|t| (t.id, Generation::new(t.generation)))
|
||||
.map(|t| (t.id, Generation::new(t.gen)))
|
||||
.collect::<HashMap<_, _>>())
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user