Merge branch 'refs/heads/main' into amasterov/cloud-extensions-test

# Conflicts:
#	docker-compose/docker_compose_test.sh
This commit is contained in:
Alexey Masterov
2025-04-25 12:09:21 +02:00
103 changed files with 4520 additions and 2336 deletions

View File

@@ -19,7 +19,7 @@
!pageserver/
!pgxn/
!proxy/
!object_storage/
!endpoint_storage/
!storage_scrubber/
!safekeeper/
!storage_broker/

View File

@@ -275,7 +275,7 @@ jobs:
for io_mode in buffered direct direct-rw ; do
NEON_PAGESERVER_UNIT_TEST_GET_VECTORED_CONCURRENT_IO=$get_vectored_concurrent_io \
NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IOENGINE=$io_engine \
NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IOMODE=$io_mode \
NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IO_MODE=$io_mode \
${cov_prefix} \
cargo nextest run $CARGO_FLAGS $CARGO_FEATURES -E 'package(pageserver)'
done
@@ -395,7 +395,7 @@ jobs:
BUILD_TAG: ${{ inputs.build-tag }}
PAGESERVER_VIRTUAL_FILE_IO_ENGINE: tokio-epoll-uring
PAGESERVER_GET_VECTORED_CONCURRENT_IO: sidecar-task
PAGESERVER_VIRTUAL_FILE_IO_MODE: direct
PAGESERVER_VIRTUAL_FILE_IO_MODE: direct-rw
USE_LFC: ${{ matrix.lfc_state == 'with-lfc' && 'true' || 'false' }}
# Temporary disable this step until we figure out why it's so flaky

View File

@@ -63,13 +63,8 @@ jobs:
- name: Cache postgres ${{ matrix.postgres-version }} build
id: cache_pg
uses: tespkg/actions-cache@b7bf5fcc2f98a52ac6080eb0fd282c2f752074b1 # v1.8.0
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
with:
endpoint: ${{ vars.HETZNER_CACHE_REGION }}.${{ vars.HETZNER_CACHE_ENDPOINT }}
bucket: ${{ vars.HETZNER_CACHE_BUCKET }}
accessKey: ${{ secrets.HETZNER_CACHE_ACCESS_KEY }}
secretKey: ${{ secrets.HETZNER_CACHE_SECRET_KEY }}
use-fallback: false
path: pg_install/${{ matrix.postgres-version }}
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-pg-${{ matrix.postgres-version }}-${{ steps.pg_rev.outputs.pg_rev }}-${{ hashFiles('Makefile') }}
@@ -134,25 +129,15 @@ jobs:
- name: Cache postgres v17 build
id: cache_pg
uses: tespkg/actions-cache@b7bf5fcc2f98a52ac6080eb0fd282c2f752074b1 # v1.8.0
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
with:
endpoint: ${{ vars.HETZNER_CACHE_REGION }}.${{ vars.HETZNER_CACHE_ENDPOINT }}
bucket: ${{ vars.HETZNER_CACHE_BUCKET }}
accessKey: ${{ secrets.HETZNER_CACHE_ACCESS_KEY }}
secretKey: ${{ secrets.HETZNER_CACHE_SECRET_KEY }}
use-fallback: false
path: pg_install/v17
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-pg-v17-${{ steps.pg_rev.outputs.pg_rev }}-${{ hashFiles('Makefile') }}
- name: Cache walproposer-lib
id: cache_walproposer_lib
uses: tespkg/actions-cache@b7bf5fcc2f98a52ac6080eb0fd282c2f752074b1 # v1.8.0
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
with:
endpoint: ${{ vars.HETZNER_CACHE_REGION }}.${{ vars.HETZNER_CACHE_ENDPOINT }}
bucket: ${{ vars.HETZNER_CACHE_BUCKET }}
accessKey: ${{ secrets.HETZNER_CACHE_ACCESS_KEY }}
secretKey: ${{ secrets.HETZNER_CACHE_SECRET_KEY }}
use-fallback: false
path: pg_install/build/walproposer-lib
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-walproposer_lib-v17-${{ steps.pg_rev.outputs.pg_rev }}-${{ hashFiles('Makefile') }}
@@ -218,57 +203,32 @@ jobs:
- name: Cache postgres v14 build
id: cache_pg
uses: tespkg/actions-cache@b7bf5fcc2f98a52ac6080eb0fd282c2f752074b1 # v1.8.0
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
with:
endpoint: ${{ vars.HETZNER_CACHE_REGION }}.${{ vars.HETZNER_CACHE_ENDPOINT }}
bucket: ${{ vars.HETZNER_CACHE_BUCKET }}
accessKey: ${{ secrets.HETZNER_CACHE_ACCESS_KEY }}
secretKey: ${{ secrets.HETZNER_CACHE_SECRET_KEY }}
use-fallback: false
path: pg_install/v14
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-pg-v14-${{ steps.pg_rev_v14.outputs.pg_rev }}-${{ hashFiles('Makefile') }}
- name: Cache postgres v15 build
id: cache_pg_v15
uses: tespkg/actions-cache@b7bf5fcc2f98a52ac6080eb0fd282c2f752074b1 # v1.8.0
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
with:
endpoint: ${{ vars.HETZNER_CACHE_REGION }}.${{ vars.HETZNER_CACHE_ENDPOINT }}
bucket: ${{ vars.HETZNER_CACHE_BUCKET }}
accessKey: ${{ secrets.HETZNER_CACHE_ACCESS_KEY }}
secretKey: ${{ secrets.HETZNER_CACHE_SECRET_KEY }}
use-fallback: false
path: pg_install/v15
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-pg-v15-${{ steps.pg_rev_v15.outputs.pg_rev }}-${{ hashFiles('Makefile') }}
- name: Cache postgres v16 build
id: cache_pg_v16
uses: tespkg/actions-cache@b7bf5fcc2f98a52ac6080eb0fd282c2f752074b1 # v1.8.0
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
with:
endpoint: ${{ vars.HETZNER_CACHE_REGION }}.${{ vars.HETZNER_CACHE_ENDPOINT }}
bucket: ${{ vars.HETZNER_CACHE_BUCKET }}
accessKey: ${{ secrets.HETZNER_CACHE_ACCESS_KEY }}
secretKey: ${{ secrets.HETZNER_CACHE_SECRET_KEY }}
use-fallback: false
path: pg_install/v16
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-pg-v16-${{ steps.pg_rev_v16.outputs.pg_rev }}-${{ hashFiles('Makefile') }}
- name: Cache postgres v17 build
id: cache_pg_v17
uses: tespkg/actions-cache@b7bf5fcc2f98a52ac6080eb0fd282c2f752074b1 # v1.8.0
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
with:
endpoint: ${{ vars.HETZNER_CACHE_REGION }}.${{ vars.HETZNER_CACHE_ENDPOINT }}
bucket: ${{ vars.HETZNER_CACHE_BUCKET }}
accessKey: ${{ secrets.HETZNER_CACHE_ACCESS_KEY }}
secretKey: ${{ secrets.HETZNER_CACHE_SECRET_KEY }}
use-fallback: false
path: pg_install/v17
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-pg-v17-${{ steps.pg_rev_v17.outputs.pg_rev }}-${{ hashFiles('Makefile') }}
- name: Cache cargo deps (only for v17)
uses: tespkg/actions-cache@b7bf5fcc2f98a52ac6080eb0fd282c2f752074b1 # v1.8.0
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
with:
endpoint: ${{ vars.HETZNER_CACHE_REGION }}.${{ vars.HETZNER_CACHE_ENDPOINT }}
bucket: ${{ vars.HETZNER_CACHE_BUCKET }}
accessKey: ${{ secrets.HETZNER_CACHE_ACCESS_KEY }}
secretKey: ${{ secrets.HETZNER_CACHE_SECRET_KEY }}
use-fallback: false
path: |
~/.cargo/registry
!~/.cargo/registry/src
@@ -278,13 +238,8 @@ jobs:
- name: Cache walproposer-lib
id: cache_walproposer_lib
uses: tespkg/actions-cache@b7bf5fcc2f98a52ac6080eb0fd282c2f752074b1 # v1.8.0
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
with:
endpoint: ${{ vars.HETZNER_CACHE_REGION }}.${{ vars.HETZNER_CACHE_ENDPOINT }}
bucket: ${{ vars.HETZNER_CACHE_BUCKET }}
accessKey: ${{ secrets.HETZNER_CACHE_ACCESS_KEY }}
secretKey: ${{ secrets.HETZNER_CACHE_SECRET_KEY }}
use-fallback: false
path: pg_install/build/walproposer-lib
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-walproposer_lib-v17-${{ steps.pg_rev_v17.outputs.pg_rev }}-${{ hashFiles('Makefile') }}

View File

@@ -324,7 +324,7 @@ jobs:
TEST_RESULT_CONNSTR: "${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}"
PAGESERVER_VIRTUAL_FILE_IO_ENGINE: tokio-epoll-uring
PAGESERVER_GET_VECTORED_CONCURRENT_IO: sidecar-task
PAGESERVER_VIRTUAL_FILE_IO_MODE: direct
PAGESERVER_VIRTUAL_FILE_IO_MODE: direct-rw
SYNC_BETWEEN_TESTS: true
# XXX: no coverage data handling here, since benchmarks are run on release builds,
# while coverage is currently collected for the debug ones
@@ -1238,7 +1238,7 @@ jobs:
env:
GH_TOKEN: ${{ secrets.CI_ACCESS_TOKEN }}
run: |
TIMEOUT=1800 # 30 minutes, usually it takes ~2-3 minutes, but if runners are busy, it might take longer
TIMEOUT=5400 # 90 minutes, usually it takes ~2-3 minutes, but if runners are busy, it might take longer
INTERVAL=15 # try each N seconds
last_status="" # a variable to carry the last status of the "build-and-upload-extensions" context

134
Cargo.lock generated
View File

@@ -40,7 +40,7 @@ dependencies = [
"getrandom 0.2.11",
"once_cell",
"version_check",
"zerocopy",
"zerocopy 0.7.31",
]
[[package]]
@@ -1323,7 +1323,6 @@ dependencies = [
"serde_json",
"serde_with",
"signal-hook",
"spki 0.7.3",
"tar",
"thiserror 1.0.69",
"tokio",
@@ -2037,6 +2036,33 @@ dependencies = [
"zeroize",
]
[[package]]
name = "endpoint_storage"
version = "0.0.1"
dependencies = [
"anyhow",
"axum",
"axum-extra",
"camino",
"camino-tempfile",
"futures",
"http-body-util",
"itertools 0.10.5",
"jsonwebtoken",
"prometheus",
"rand 0.8.5",
"remote_storage",
"serde",
"serde_json",
"test-log",
"tokio",
"tokio-util",
"tower 0.5.2",
"tracing",
"utils",
"workspace_hack",
]
[[package]]
name = "enum-map"
version = "2.5.0"
@@ -3998,33 +4024,6 @@ dependencies = [
"memchr",
]
[[package]]
name = "object_storage"
version = "0.0.1"
dependencies = [
"anyhow",
"axum",
"axum-extra",
"camino",
"camino-tempfile",
"futures",
"http-body-util",
"itertools 0.10.5",
"jsonwebtoken",
"prometheus",
"rand 0.8.5",
"remote_storage",
"serde",
"serde_json",
"test-log",
"tokio",
"tokio-util",
"tower 0.5.2",
"tracing",
"utils",
"workspace_hack",
]
[[package]]
name = "once_cell"
version = "1.20.2"
@@ -4302,6 +4301,7 @@ dependencies = [
"remote_storage",
"reqwest",
"rpds",
"rstest",
"rustls 0.23.18",
"scopeguard",
"send-future",
@@ -4415,9 +4415,9 @@ dependencies = [
[[package]]
name = "papaya"
version = "0.2.0"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aab21828b6b5952fdadd6c377728ffae53ec3a21b2febc47319ab65741f7e2fd"
checksum = "6827e3fc394523c21d4464d02c0bb1c19966ea4a58a9844ad6d746214179d2bc"
dependencies = [
"equivalent",
"seize",
@@ -5204,7 +5204,7 @@ dependencies = [
"walkdir",
"workspace_hack",
"x509-cert",
"zerocopy",
"zerocopy 0.8.24",
]
[[package]]
@@ -5594,7 +5594,7 @@ dependencies = [
"wasm-bindgen-futures",
"wasm-streams",
"web-sys",
"webpki-roots 0.26.1",
"webpki-roots",
"winreg",
]
@@ -6195,13 +6195,13 @@ checksum = "224e328af6e080cddbab3c770b1cf50f0351ba0577091ef2410c3951d835ff87"
[[package]]
name = "sentry"
version = "0.32.3"
version = "0.37.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00421ed8fa0c995f07cde48ba6c89e80f2b312f74ff637326f392fbfd23abe02"
checksum = "255914a8e53822abd946e2ce8baa41d4cded6b8e938913b7f7b9da5b7ab44335"
dependencies = [
"httpdate",
"reqwest",
"rustls 0.21.12",
"rustls 0.23.18",
"sentry-backtrace",
"sentry-contexts",
"sentry-core",
@@ -6209,14 +6209,14 @@ dependencies = [
"sentry-tracing",
"tokio",
"ureq",
"webpki-roots 0.25.2",
"webpki-roots",
]
[[package]]
name = "sentry-backtrace"
version = "0.32.3"
version = "0.37.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a79194074f34b0cbe5dd33896e5928bbc6ab63a889bd9df2264af5acb186921e"
checksum = "00293cd332a859961f24fd69258f7e92af736feaeb91020cff84dac4188a4302"
dependencies = [
"backtrace",
"once_cell",
@@ -6226,9 +6226,9 @@ dependencies = [
[[package]]
name = "sentry-contexts"
version = "0.32.3"
version = "0.37.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eba8870c5dba2bfd9db25c75574a11429f6b95957b0a78ac02e2970dd7a5249a"
checksum = "961990f9caa76476c481de130ada05614cd7f5aa70fb57c2142f0e09ad3fb2aa"
dependencies = [
"hostname",
"libc",
@@ -6240,9 +6240,9 @@ dependencies = [
[[package]]
name = "sentry-core"
version = "0.32.3"
version = "0.37.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46a75011ea1c0d5c46e9e57df03ce81f5c7f0a9e199086334a1f9c0a541e0826"
checksum = "1a6409d845707d82415c800290a5d63be5e3df3c2e417b0997c60531dfbd35ef"
dependencies = [
"once_cell",
"rand 0.8.5",
@@ -6253,9 +6253,9 @@ dependencies = [
[[package]]
name = "sentry-panic"
version = "0.32.3"
version = "0.37.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2eaa3ecfa3c8750c78dcfd4637cfa2598b95b52897ed184b4dc77fcf7d95060d"
checksum = "609b1a12340495ce17baeec9e08ff8ed423c337c1a84dffae36a178c783623f3"
dependencies = [
"sentry-backtrace",
"sentry-core",
@@ -6263,9 +6263,9 @@ dependencies = [
[[package]]
name = "sentry-tracing"
version = "0.32.3"
version = "0.37.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f715932bf369a61b7256687c6f0554141b7ce097287e30e3f7ed6e9de82498fe"
checksum = "49f4e86402d5c50239dc7d8fd3f6d5e048221d5fcb4e026d8d50ab57fe4644cb"
dependencies = [
"sentry-backtrace",
"sentry-core",
@@ -6275,9 +6275,9 @@ dependencies = [
[[package]]
name = "sentry-types"
version = "0.32.3"
version = "0.37.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4519c900ce734f7a0eb7aba0869dfb225a7af8820634a7dd51449e3b093cfb7c"
checksum = "3d3f117b8755dbede8260952de2aeb029e20f432e72634e8969af34324591631"
dependencies = [
"debugid",
"hex",
@@ -6711,8 +6711,6 @@ version = "0.1.0"
dependencies = [
"anyhow",
"async-stream",
"aws-config",
"aws-sdk-s3",
"camino",
"chrono",
"clap",
@@ -7801,7 +7799,7 @@ dependencies = [
"rustls 0.23.18",
"rustls-pki-types",
"url",
"webpki-roots 0.26.1",
"webpki-roots",
]
[[package]]
@@ -8169,12 +8167,6 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "webpki-roots"
version = "0.25.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "14247bb57be4f377dfb94c72830b8ce8fc6beac03cf4bf7b9732eadd414123fc"
[[package]]
name = "webpki-roots"
version = "0.26.1"
@@ -8482,6 +8474,8 @@ dependencies = [
"regex-syntax 0.8.2",
"reqwest",
"rustls 0.23.18",
"rustls-pki-types",
"rustls-webpki 0.102.8",
"scopeguard",
"sec1 0.7.3",
"serde",
@@ -8510,7 +8504,6 @@ dependencies = [
"tracing-log",
"url",
"uuid",
"zerocopy",
"zeroize",
"zstd",
"zstd-safe",
@@ -8614,8 +8607,16 @@ version = "0.7.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1c4061bedbb353041c12f413700357bec76df2c7e2ca8e4df8bac24c6bf68e3d"
dependencies = [
"byteorder",
"zerocopy-derive",
"zerocopy-derive 0.7.31",
]
[[package]]
name = "zerocopy"
version = "0.8.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2586fea28e186957ef732a5f8b3be2da217d65c5969d4b1e17f973ebbe876879"
dependencies = [
"zerocopy-derive 0.8.24",
]
[[package]]
@@ -8629,6 +8630,17 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "zerocopy-derive"
version = "0.8.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a996a8f63c5c4448cd959ac1bab0aaa3306ccfd060472f85943ee0750f0169be"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "zerofrom"
version = "0.1.5"

View File

@@ -40,7 +40,7 @@ members = [
"libs/proxy/postgres-protocol2",
"libs/proxy/postgres-types2",
"libs/proxy/tokio-postgres2",
"object_storage",
"endpoint_storage",
]
[workspace.package]
@@ -164,7 +164,7 @@ scopeguard = "1.1"
sysinfo = "0.29.2"
sd-notify = "0.4.1"
send-future = "0.1.0"
sentry = { version = "0.32", default-features = false, features = ["backtrace", "contexts", "panic", "rustls", "reqwest" ] }
sentry = { version = "0.37", default-features = false, features = ["backtrace", "contexts", "panic", "rustls", "reqwest" ] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1"
serde_path_to_error = "0.1"
@@ -220,7 +220,7 @@ uuid = { version = "1.6.1", features = ["v4", "v7", "serde"] }
walkdir = "2.3.2"
rustls-native-certs = "0.8"
whoami = "1.5.1"
zerocopy = { version = "0.7", features = ["derive"] }
zerocopy = { version = "0.8", features = ["derive", "simd"] }
json-structural-diff = { version = "0.2.0" }
x509-cert = { version = "0.2.5" }

View File

@@ -89,7 +89,7 @@ RUN set -e \
--bin storage_broker \
--bin storage_controller \
--bin proxy \
--bin object_storage \
--bin endpoint_storage \
--bin neon_local \
--bin storage_scrubber \
--locked --release
@@ -122,7 +122,7 @@ COPY --from=build --chown=neon:neon /home/nonroot/target/release/safekeeper
COPY --from=build --chown=neon:neon /home/nonroot/target/release/storage_broker /usr/local/bin
COPY --from=build --chown=neon:neon /home/nonroot/target/release/storage_controller /usr/local/bin
COPY --from=build --chown=neon:neon /home/nonroot/target/release/proxy /usr/local/bin
COPY --from=build --chown=neon:neon /home/nonroot/target/release/object_storage /usr/local/bin
COPY --from=build --chown=neon:neon /home/nonroot/target/release/endpoint_storage /usr/local/bin
COPY --from=build --chown=neon:neon /home/nonroot/target/release/neon_local /usr/local/bin
COPY --from=build --chown=neon:neon /home/nonroot/target/release/storage_scrubber /usr/local/bin

View File

@@ -173,7 +173,7 @@ RUN curl -fsSL "https://github.com/protocolbuffers/protobuf/releases/download/v$
&& rm -rf protoc.zip protoc
# s5cmd
ENV S5CMD_VERSION=2.2.2
ENV S5CMD_VERSION=2.3.0
RUN curl -sL "https://github.com/peak/s5cmd/releases/download/v${S5CMD_VERSION}/s5cmd_${S5CMD_VERSION}_Linux-$(uname -m | sed 's/x86_64/64bit/g' | sed 's/aarch64/arm64/g').tar.gz" | tar zxvf - s5cmd \
&& chmod +x s5cmd \
&& mv s5cmd /usr/local/bin/s5cmd
@@ -206,7 +206,7 @@ RUN curl "https://awscli.amazonaws.com/awscli-exe-linux-$(uname -m).zip" -o "aws
&& rm awscliv2.zip
# Mold: A Modern Linker
ENV MOLD_VERSION=v2.34.1
ENV MOLD_VERSION=v2.37.1
RUN set -e \
&& git clone https://github.com/rui314/mold.git \
&& mkdir mold/build \
@@ -268,7 +268,7 @@ WORKDIR /home/nonroot
RUN echo -e "--retry-connrefused\n--connect-timeout 15\n--retry 5\n--max-time 300\n" > /home/nonroot/.curlrc
# Python
ENV PYTHON_VERSION=3.11.10 \
ENV PYTHON_VERSION=3.11.12 \
PYENV_ROOT=/home/nonroot/.pyenv \
PATH=/home/nonroot/.pyenv/shims:/home/nonroot/.pyenv/bin:/home/nonroot/.poetry/bin:$PATH
RUN set -e \
@@ -296,12 +296,12 @@ ENV RUSTC_VERSION=1.86.0
ENV RUSTUP_HOME="/home/nonroot/.rustup"
ENV PATH="/home/nonroot/.cargo/bin:${PATH}"
ARG RUSTFILT_VERSION=0.2.1
ARG CARGO_HAKARI_VERSION=0.9.33
ARG CARGO_DENY_VERSION=0.16.2
ARG CARGO_HACK_VERSION=0.6.33
ARG CARGO_NEXTEST_VERSION=0.9.85
ARG CARGO_HAKARI_VERSION=0.9.36
ARG CARGO_DENY_VERSION=0.18.2
ARG CARGO_HACK_VERSION=0.6.36
ARG CARGO_NEXTEST_VERSION=0.9.94
ARG CARGO_CHEF_VERSION=0.1.71
ARG CARGO_DIESEL_CLI_VERSION=2.2.6
ARG CARGO_DIESEL_CLI_VERSION=2.2.9
RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux-gnu/rustup-init && whoami && \
chmod +x rustup-init && \
./rustup-init -y --default-toolchain ${RUSTC_VERSION} && \

View File

@@ -44,7 +44,6 @@ serde.workspace = true
serde_with.workspace = true
serde_json.workspace = true
signal-hook.workspace = true
spki = { version = "0.7.3", features = ["std"] }
tar.workspace = true
tower.workspace = true
tower-http.workspace = true

View File

@@ -57,13 +57,24 @@ use tracing::{error, info};
use url::Url;
use utils::failpoint_support;
// Compatibility hack: if the control plane specified any remote-ext-config
// use the default value for extension storage proxy gateway.
// Remove this once the control plane is updated to pass the gateway URL
fn parse_remote_ext_config(arg: &str) -> Result<String> {
if arg.starts_with("http") {
Ok(arg.trim_end_matches('/').to_string())
} else {
Ok("http://pg-ext-s3-gateway".to_string())
}
}
#[derive(Parser)]
#[command(rename_all = "kebab-case")]
struct Cli {
#[arg(short = 'b', long, default_value = "postgres", env = "POSTGRES_PATH")]
pub pgbin: String,
#[arg(short = 'r', long)]
#[arg(short = 'r', long, value_parser = parse_remote_ext_config)]
pub remote_ext_config: Option<String>,
/// The port to bind the external listening HTTP server to. Clients running

View File

@@ -1,8 +1,8 @@
use metrics::core::{AtomicF64, Collector, GenericGauge};
use metrics::core::{AtomicF64, AtomicU64, Collector, GenericCounter, GenericGauge};
use metrics::proto::MetricFamily;
use metrics::{
IntCounterVec, IntGaugeVec, UIntGaugeVec, register_gauge, register_int_counter_vec,
register_int_gauge_vec, register_uint_gauge_vec,
IntCounterVec, IntGaugeVec, UIntGaugeVec, register_gauge, register_int_counter,
register_int_counter_vec, register_int_gauge_vec, register_uint_gauge_vec,
};
use once_cell::sync::Lazy;
@@ -81,6 +81,22 @@ pub(crate) static COMPUTE_CTL_UP: Lazy<IntGaugeVec> = Lazy::new(|| {
.expect("failed to define a metric")
});
pub(crate) static PG_CURR_DOWNTIME_MS: Lazy<GenericGauge<AtomicF64>> = Lazy::new(|| {
register_gauge!(
"compute_pg_current_downtime_ms",
"Non-cumulative duration of Postgres downtime in ms; resets after successful check",
)
.expect("failed to define a metric")
});
pub(crate) static PG_TOTAL_DOWNTIME_MS: Lazy<GenericCounter<AtomicU64>> = Lazy::new(|| {
register_int_counter!(
"compute_pg_downtime_ms_total",
"Cumulative duration of Postgres downtime in ms",
)
.expect("failed to define a metric")
});
pub fn collect() -> Vec<MetricFamily> {
let mut metrics = COMPUTE_CTL_UP.collect();
metrics.extend(INSTALLED_EXTENSIONS.collect());
@@ -88,5 +104,7 @@ pub fn collect() -> Vec<MetricFamily> {
metrics.extend(REMOTE_EXT_REQUESTS_TOTAL.collect());
metrics.extend(DB_MIGRATION_FAILED.collect());
metrics.extend(AUDIT_LOG_DIR_SIZE.collect());
metrics.extend(PG_CURR_DOWNTIME_MS.collect());
metrics.extend(PG_TOTAL_DOWNTIME_MS.collect());
metrics
}

View File

@@ -6,197 +6,294 @@ use chrono::{DateTime, Utc};
use compute_api::responses::ComputeStatus;
use compute_api::spec::ComputeFeature;
use postgres::{Client, NoTls};
use tracing::{debug, error, info, warn};
use tracing::{Level, error, info, instrument, span};
use crate::compute::ComputeNode;
use crate::metrics::{PG_CURR_DOWNTIME_MS, PG_TOTAL_DOWNTIME_MS};
const MONITOR_CHECK_INTERVAL: Duration = Duration::from_millis(500);
// Spin in a loop and figure out the last activity time in the Postgres.
// Then update it in the shared state. This function never errors out.
// NB: the only expected panic is at `Mutex` unwrap(), all other errors
// should be handled gracefully.
fn watch_compute_activity(compute: &ComputeNode) {
// Suppose that `connstr` doesn't change
let connstr = compute.params.connstr.clone();
let conf = compute.get_conn_conf(Some("compute_ctl:activity_monitor"));
struct ComputeMonitor {
compute: Arc<ComputeNode>,
// During startup and configuration we connect to every Postgres database,
// but we don't want to count this as some user activity. So wait until
// the compute fully started before monitoring activity.
wait_for_postgres_start(compute);
/// The moment when Postgres had some activity,
/// that should prevent compute from being suspended.
last_active: Option<DateTime<Utc>>,
// Define `client` outside of the loop to reuse existing connection if it's active.
let mut client = conf.connect(NoTls);
/// The moment when we last tried to check Postgres.
last_checked: DateTime<Utc>,
/// The last moment we did a successful Postgres check.
last_up: DateTime<Utc>,
let mut sleep = false;
let mut prev_active_time: Option<f64> = None;
let mut prev_sessions: Option<i64> = None;
/// Only used for internal statistics change tracking
/// between monitor runs and can be outdated.
active_time: Option<f64>,
/// Only used for internal statistics change tracking
/// between monitor runs and can be outdated.
sessions: Option<i64>,
if compute.has_feature(ComputeFeature::ActivityMonitorExperimental) {
info!("starting experimental activity monitor for {}", connstr);
} else {
info!("starting activity monitor for {}", connstr);
/// Use experimental statistics-based activity monitor. It's no longer
/// 'experimental' per se, as it's enabled for everyone, but we still
/// keep the flag as an option to turn it off in some cases if it will
/// misbehave.
experimental: bool,
}
impl ComputeMonitor {
fn report_down(&self) {
let now = Utc::now();
// Calculate and report current downtime
// (since the last time Postgres was up)
let downtime = now.signed_duration_since(self.last_up);
PG_CURR_DOWNTIME_MS.set(downtime.num_milliseconds() as f64);
// Calculate and update total downtime
// (cumulative duration of Postgres downtime in ms)
let inc = now
.signed_duration_since(self.last_checked)
.num_milliseconds();
PG_TOTAL_DOWNTIME_MS.inc_by(inc as u64);
}
loop {
// We use `continue` a lot, so it's more convenient to sleep at the top of the loop.
// But skip the first sleep, so we can connect to Postgres immediately.
if sleep {
// Should be outside of the mutex lock to allow others to read while we sleep.
thread::sleep(MONITOR_CHECK_INTERVAL);
} else {
sleep = true;
}
fn report_up(&mut self) {
self.last_up = Utc::now();
PG_CURR_DOWNTIME_MS.set(0.0);
}
match &mut client {
Ok(cli) => {
if cli.is_closed() {
info!("connection to Postgres is closed, trying to reconnect");
fn downtime_info(&self) -> String {
format!(
"total_ms: {}, current_ms: {}, last_up: {}",
PG_TOTAL_DOWNTIME_MS.get(),
PG_CURR_DOWNTIME_MS.get(),
self.last_up
)
}
// Connection is closed, reconnect and try again.
client = conf.connect(NoTls);
continue;
}
/// Spin in a loop and figure out the last activity time in the Postgres.
/// Then update it in the shared state. This function never errors out.
/// NB: the only expected panic is at `Mutex` unwrap(), all other errors
/// should be handled gracefully.
#[instrument(skip_all)]
pub fn run(&mut self) {
// Suppose that `connstr` doesn't change
let connstr = self.compute.params.connstr.clone();
let conf = self
.compute
.get_conn_conf(Some("compute_ctl:compute_monitor"));
// This is a new logic, only enable if the feature flag is set.
// TODO: remove this once we are sure that it works OR drop it altogether.
if compute.has_feature(ComputeFeature::ActivityMonitorExperimental) {
// First, check if the total active time or sessions across all databases has changed.
// If it did, it means that user executed some queries. In theory, it can even go down if
// some databases were dropped, but it's still a user activity.
match get_database_stats(cli) {
Ok((active_time, sessions)) => {
let mut detected_activity = false;
// During startup and configuration we connect to every Postgres database,
// but we don't want to count this as some user activity. So wait until
// the compute fully started before monitoring activity.
wait_for_postgres_start(&self.compute);
prev_active_time = match prev_active_time {
Some(prev_active_time) => {
if active_time != prev_active_time {
detected_activity = true;
}
Some(active_time)
}
None => Some(active_time),
};
prev_sessions = match prev_sessions {
Some(prev_sessions) => {
if sessions != prev_sessions {
detected_activity = true;
}
Some(sessions)
}
None => Some(sessions),
};
// Define `client` outside of the loop to reuse existing connection if it's active.
let mut client = conf.connect(NoTls);
if detected_activity {
// Update the last active time and continue, we don't need to
// check backends state change.
compute.update_last_active(Some(Utc::now()));
continue;
}
}
Err(e) => {
error!("could not get database statistics: {}", e);
continue;
}
}
}
info!("starting compute monitor for {}", connstr);
// Second, if database statistics is the same, check all backends state change,
// maybe there is some with more recent activity. `get_backends_state_change()`
// can return None or stale timestamp, so it's `compute.update_last_active()`
// responsibility to check if the new timestamp is more recent than the current one.
// This helps us to discover new sessions, that did nothing yet.
match get_backends_state_change(cli) {
Ok(last_active) => {
compute.update_last_active(last_active);
}
Err(e) => {
error!("could not get backends state change: {}", e);
}
}
// Finally, if there are existing (logical) walsenders, do not suspend.
//
// walproposer doesn't currently show up in pg_stat_replication,
// but protect if it will be
let ws_count_query = "select count(*) from pg_stat_replication where application_name != 'walproposer';";
match cli.query_one(ws_count_query, &[]) {
Ok(r) => match r.try_get::<&str, i64>("count") {
Ok(num_ws) => {
if num_ws > 0 {
compute.update_last_active(Some(Utc::now()));
continue;
}
}
Err(e) => {
warn!("failed to parse walsenders count: {:?}", e);
continue;
}
},
Err(e) => {
warn!("failed to get list of walsenders: {:?}", e);
continue;
}
}
//
// Don't suspend compute if there is an active logical replication subscription
//
// `where pid is not null` to filter out read only computes and subscription on branches
//
let logical_subscriptions_query =
"select count(*) from pg_stat_subscription where pid is not null;";
match cli.query_one(logical_subscriptions_query, &[]) {
Ok(row) => match row.try_get::<&str, i64>("count") {
Ok(num_subscribers) => {
if num_subscribers > 0 {
compute.update_last_active(Some(Utc::now()));
continue;
}
}
Err(e) => {
warn!("failed to parse `pg_stat_subscription` count: {:?}", e);
continue;
}
},
Err(e) => {
warn!(
"failed to get list of active logical replication subscriptions: {:?}",
e
loop {
match &mut client {
Ok(cli) => {
if cli.is_closed() {
info!(
downtime_info = self.downtime_info(),
"connection to Postgres is closed, trying to reconnect"
);
continue;
}
}
//
// Do not suspend compute if autovacuum is running
//
let autovacuum_count_query = "select count(*) from pg_stat_activity where backend_type = 'autovacuum worker'";
match cli.query_one(autovacuum_count_query, &[]) {
Ok(r) => match r.try_get::<&str, i64>("count") {
Ok(num_workers) => {
if num_workers > 0 {
compute.update_last_active(Some(Utc::now()));
continue;
self.report_down();
// Connection is closed, reconnect and try again.
client = conf.connect(NoTls);
} else {
match self.check(cli) {
Ok(_) => {
self.report_up();
self.compute.update_last_active(self.last_active);
}
Err(e) => {
// Although we have many places where we can return errors in `check()`,
// normally it shouldn't happen. I.e., we will likely return error if
// connection got broken, query timed out, Postgres returned invalid data, etc.
// In all such cases it's suspicious, so let's report this as downtime.
self.report_down();
error!(
downtime_info = self.downtime_info(),
"could not check Postgres: {}", e
);
// Reconnect to Postgres just in case. During tests, I noticed
// that queries in `check()` can fail with `connection closed`,
// but `cli.is_closed()` above doesn't detect it. Even if old
// connection is still alive, it will be dropped when we reassign
// `client` to a new connection.
client = conf.connect(NoTls);
}
}
Err(e) => {
warn!("failed to parse autovacuum workers count: {:?}", e);
continue;
}
},
Err(e) => {
warn!("failed to get list of autovacuum workers: {:?}", e);
continue;
}
}
}
Err(e) => {
debug!("could not connect to Postgres: {}, retrying", e);
Err(e) => {
info!(
downtime_info = self.downtime_info(),
"could not connect to Postgres: {}, retrying", e
);
self.report_down();
// Establish a new connection and try again.
client = conf.connect(NoTls);
// Establish a new connection and try again.
client = conf.connect(NoTls);
}
}
// Reset the `last_checked` timestamp and sleep before the next iteration.
self.last_checked = Utc::now();
thread::sleep(MONITOR_CHECK_INTERVAL);
}
}
#[instrument(skip_all)]
fn check(&mut self, cli: &mut Client) -> anyhow::Result<()> {
// This is new logic, only enable if the feature flag is set.
// TODO: remove this once we are sure that it works OR drop it altogether.
if self.experimental {
// Check if the total active time or sessions across all databases has changed.
// If it did, it means that user executed some queries. In theory, it can even go down if
// some databases were dropped, but it's still user activity.
match get_database_stats(cli) {
Ok((active_time, sessions)) => {
let mut detected_activity = false;
if let Some(prev_active_time) = self.active_time {
if active_time != prev_active_time {
detected_activity = true;
}
}
self.active_time = Some(active_time);
if let Some(prev_sessions) = self.sessions {
if sessions != prev_sessions {
detected_activity = true;
}
}
self.sessions = Some(sessions);
if detected_activity {
// Update the last active time and continue, we don't need to
// check backends state change.
self.last_active = Some(Utc::now());
return Ok(());
}
}
Err(e) => {
return Err(anyhow::anyhow!("could not get database statistics: {}", e));
}
}
}
// If database statistics are the same, check all backends for state changes.
// Maybe there are some with more recent activity. `get_backends_state_change()`
// can return None or stale timestamp, so it's `compute.update_last_active()`
// responsibility to check if the new timestamp is more recent than the current one.
// This helps us to discover new sessions that have not done anything yet.
match get_backends_state_change(cli) {
Ok(last_active) => match (last_active, self.last_active) {
(Some(last_active), Some(prev_last_active)) => {
if last_active > prev_last_active {
self.last_active = Some(last_active);
return Ok(());
}
}
(Some(last_active), None) => {
self.last_active = Some(last_active);
return Ok(());
}
_ => {}
},
Err(e) => {
return Err(anyhow::anyhow!(
"could not get backends state change: {}",
e
));
}
}
// If there are existing (logical) walsenders, do not suspend.
//
// N.B. walproposer doesn't currently show up in pg_stat_replication,
// but protect if it will.
const WS_COUNT_QUERY: &str =
"select count(*) from pg_stat_replication where application_name != 'walproposer';";
match cli.query_one(WS_COUNT_QUERY, &[]) {
Ok(r) => match r.try_get::<&str, i64>("count") {
Ok(num_ws) => {
if num_ws > 0 {
self.last_active = Some(Utc::now());
return Ok(());
}
}
Err(e) => {
let err: anyhow::Error = e.into();
return Err(err.context("failed to parse walsenders count"));
}
},
Err(e) => {
return Err(anyhow::anyhow!("failed to get list of walsenders: {}", e));
}
}
// Don't suspend compute if there is an active logical replication subscription
//
// `where pid is not null` to filter out read only computes and subscription on branches
const LOGICAL_SUBSCRIPTIONS_QUERY: &str =
"select count(*) from pg_stat_subscription where pid is not null;";
match cli.query_one(LOGICAL_SUBSCRIPTIONS_QUERY, &[]) {
Ok(row) => match row.try_get::<&str, i64>("count") {
Ok(num_subscribers) => {
if num_subscribers > 0 {
self.last_active = Some(Utc::now());
return Ok(());
}
}
Err(e) => {
return Err(anyhow::anyhow!(
"failed to parse 'pg_stat_subscription' count: {}",
e
));
}
},
Err(e) => {
return Err(anyhow::anyhow!(
"failed to get list of active logical replication subscriptions: {}",
e
));
}
}
// Do not suspend compute if autovacuum is running
const AUTOVACUUM_COUNT_QUERY: &str =
"select count(*) from pg_stat_activity where backend_type = 'autovacuum worker'";
match cli.query_one(AUTOVACUUM_COUNT_QUERY, &[]) {
Ok(r) => match r.try_get::<&str, i64>("count") {
Ok(num_workers) => {
if num_workers > 0 {
self.last_active = Some(Utc::now());
return Ok(());
};
}
Err(e) => {
return Err(anyhow::anyhow!(
"failed to parse autovacuum workers count: {}",
e
));
}
},
Err(e) => {
return Err(anyhow::anyhow!(
"failed to get list of autovacuum workers: {}",
e
));
}
}
Ok(())
}
}
@@ -315,9 +412,24 @@ fn get_backends_state_change(cli: &mut Client) -> anyhow::Result<Option<DateTime
/// Launch a separate compute monitor thread and return its `JoinHandle`.
pub fn launch_monitor(compute: &Arc<ComputeNode>) -> thread::JoinHandle<()> {
let compute = Arc::clone(compute);
let experimental = compute.has_feature(ComputeFeature::ActivityMonitorExperimental);
let now = Utc::now();
let mut monitor = ComputeMonitor {
compute,
last_active: None,
last_checked: now,
last_up: now,
active_time: None,
sessions: None,
experimental,
};
let span = span!(Level::INFO, "compute_monitor");
thread::Builder::new()
.name("compute-monitor".into())
.spawn(move || watch_compute_activity(&compute))
.spawn(move || {
let _enter = span.enter();
monitor.run();
})
.expect("cannot launch compute monitor thread")
}

View File

@@ -3,7 +3,6 @@ use std::{io::Write, os::unix::fs::OpenOptionsExt, path::Path, time::Duration};
use anyhow::{Context, Result, bail};
use compute_api::responses::TlsConfig;
use ring::digest;
use spki::der::{Decode, PemReader};
use x509_cert::Certificate;
#[derive(Clone, Copy)]
@@ -52,7 +51,7 @@ pub fn update_key_path_blocking(pg_data: &Path, tls_config: &TlsConfig) {
match try_update_key_path_blocking(pg_data, tls_config) {
Ok(()) => break,
Err(e) => {
tracing::error!("could not create key file {e:?}");
tracing::error!(error = ?e, "could not create key file");
std::thread::sleep(Duration::from_secs(1))
}
}
@@ -92,8 +91,14 @@ fn try_update_key_path_blocking(pg_data: &Path, tls_config: &TlsConfig) -> Resul
fn verify_key_cert(key: &str, cert: &str) -> Result<()> {
use x509_cert::der::oid::db::rfc5912::ECDSA_WITH_SHA_256;
let cert = Certificate::decode(&mut PemReader::new(cert.as_bytes()).context("pem reader")?)
.context("decode cert")?;
let certs = Certificate::load_pem_chain(cert.as_bytes())
.context("decoding PEM encoded certificates")?;
// First certificate is our server-cert,
// all the rest of the certs are the CA cert chain.
let Some(cert) = certs.first() else {
bail!("no certificates found");
};
match cert.signature_algorithm.oid {
ECDSA_WITH_SHA_256 => {
@@ -115,3 +120,82 @@ fn verify_key_cert(key: &str, cert: &str) -> Result<()> {
Ok(())
}
#[cfg(test)]
mod tests {
use super::verify_key_cert;
/// Real certificate chain file, generated by cert-manager in dev.
/// The server auth certificate has expired since 2025-04-24T15:41:35Z.
const CERT: &str = "
-----BEGIN CERTIFICATE-----
MIICCDCCAa+gAwIBAgIQKhLomFcNULbZA/bPdGzaSzAKBggqhkjOPQQDAjBEMQsw
CQYDVQQGEwJVUzESMBAGA1UEChMJTmVvbiBJbmMuMSEwHwYDVQQDExhOZW9uIEs4
cyBJbnRlcm1lZGlhdGUgQ0EwHhcNMjUwNDIzMTU0MTM1WhcNMjUwNDI0MTU0MTM1
WjBBMT8wPQYDVQQDEzZjb21wdXRlLXdpc3B5LWdyYXNzLXcwY21laWp3LmRlZmF1
bHQuc3ZjLmNsdXN0ZXIubG9jYWwwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAATF
QCcG2m/EVHAiZtSsYgVnHgoTjUL/Jtwfdrpvz2t0bVRZmBmSKhlo53uPV9Y5eKFG
AmR54p9/gT2eO3xU7vAgo4GFMIGCMA4GA1UdDwEB/wQEAwIFoDAMBgNVHRMBAf8E
AjAAMB8GA1UdIwQYMBaAFFR2JAhXkeiNQNEixTvAYIwxUu3QMEEGA1UdEQQ6MDiC
NmNvbXB1dGUtd2lzcHktZ3Jhc3MtdzBjbWVpancuZGVmYXVsdC5zdmMuY2x1c3Rl
ci5sb2NhbDAKBggqhkjOPQQDAgNHADBEAiBLG22wKG8XS9e9RxBT+kmUx/kIThcP
DIpp7jx0PrFcdQIgEMTdnXpx5Cv/Z0NIEDxtMHUD7G0vuRPfztki36JuakM=
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIICFzCCAb6gAwIBAgIUbbX98N2Ip6lWAONRk8dU9hSz+YIwCgYIKoZIzj0EAwIw
RDELMAkGA1UEBhMCVVMxEjAQBgNVBAoTCU5lb24gSW5jLjEhMB8GA1UEAxMYTmVv
biBBV1MgSW50ZXJtZWRpYXRlIENBMB4XDTI1MDQyMjE1MTAxMFoXDTI1MDcyMTE1
MTAxMFowRDELMAkGA1UEBhMCVVMxEjAQBgNVBAoTCU5lb24gSW5jLjEhMB8GA1UE
AxMYTmVvbiBLOHMgSW50ZXJtZWRpYXRlIENBMFkwEwYHKoZIzj0CAQYIKoZIzj0D
AQcDQgAE5++m5owqNI4BPMTVNIUQH0qvU7pYhdpHGVGhdj/Lgars6ROvE6uSNQV4
SAmJN5HBzj5/6kLQaTPWpXW7EHXjK6OBjTCBijAOBgNVHQ8BAf8EBAMCAQYwEgYD
VR0TAQH/BAgwBgEB/wIBADAdBgNVHQ4EFgQUVHYkCFeR6I1A0SLFO8BgjDFS7dAw
HwYDVR0jBBgwFoAUgHfNXfyKtHO0V9qoLOWCjkNiaI8wJAYDVR0eAQH/BBowGKAW
MBSCEi5zdmMuY2x1c3Rlci5sb2NhbDAKBggqhkjOPQQDAgNHADBEAiBObVFFdXaL
QpOXmN60dYUNnQRwjKreFduEkQgOdOlssgIgVAdJJQFgvlrvEOBhY8j5WyeKRwUN
k/ALs6KpgaFBCGY=
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIB4jCCAYegAwIBAgIUFlxWFn/11yoGdmD+6gf+yQMToS0wCgYIKoZIzj0EAwIw
ODELMAkGA1UEBhMCVVMxEjAQBgNVBAoTCU5lb24gSW5jLjEVMBMGA1UEAxMMTmVv
biBSb290IENBMB4XDTI1MDQwMzA3MTUyMloXDTI2MDQwMzA3MTUyMlowRDELMAkG
A1UEBhMCVVMxEjAQBgNVBAoTCU5lb24gSW5jLjEhMB8GA1UEAxMYTmVvbiBBV1Mg
SW50ZXJtZWRpYXRlIENBMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEqonG/IQ6
ZxtEtOUTkkoNopPieXDO5CBKUkNFTGeJEB7OxRlSpYJgsBpaYIaD6Vc4sVk3thIF
p+pLw52idQOIN6NjMGEwDgYDVR0PAQH/BAQDAgEGMA8GA1UdEwEB/wQFMAMBAf8w
HQYDVR0OBBYEFIB3zV38irRztFfaqCzlgo5DYmiPMB8GA1UdIwQYMBaAFKh7M4/G
FHvr/ORDQZt4bMLlJvHCMAoGCCqGSM49BAMCA0kAMEYCIQCbS4x7QPslONzBYbjC
UQaQ0QLDW4CJHvQ4u4gbWFG87wIhAJMsHQHjP9qTT27Q65zQCR7O8QeLAfha1jrH
Ag/LsxSr
-----END CERTIFICATE-----
";
/// The key corresponding to [`CERT`]
const KEY: &str = "
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIDnAnrqmIJjndCLWP1iIO5X3X63Aia48TGpGuMXwvm6IoAoGCCqGSM49
AwEHoUQDQgAExUAnBtpvxFRwImbUrGIFZx4KE41C/ybcH3a6b89rdG1UWZgZkioZ
aOd7j1fWOXihRgJkeeKff4E9njt8VO7wIA==
-----END EC PRIVATE KEY-----
";
/// An incorrect key.
const INCORRECT_KEY: &str = "
-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIL6WqqBDyvM0HWz7Ir5M5+jhFWB7IzOClGn26OPrzHCXoAoGCCqGSM49
AwEHoUQDQgAE7XVvdOy5lfwtNKb+gJEUtnG+DrnnXLY5LsHDeGQKV9PTRcEMeCrG
YZzHyML4P6Sr4yi2ts+4B9i47uvAG8+XwQ==
-----END EC PRIVATE KEY-----
";
#[test]
fn certificate_verification() {
verify_key_cert(KEY, CERT).unwrap();
}
#[test]
#[should_panic(expected = "private key file does not match certificate")]
fn certificate_verification_fail() {
verify_key_cert(INCORRECT_KEY, CERT).unwrap();
}
}

View File

@@ -18,12 +18,11 @@ use anyhow::{Context, Result, anyhow, bail};
use clap::Parser;
use compute_api::spec::ComputeMode;
use control_plane::endpoint::ComputeControlPlane;
use control_plane::endpoint_storage::{ENDPOINT_STORAGE_DEFAULT_PORT, EndpointStorage};
use control_plane::local_env::{
InitForceMode, LocalEnv, NeonBroker, NeonLocalInitConf, NeonLocalInitPageserverConf,
ObjectStorageConf, SafekeeperConf,
EndpointStorageConf, InitForceMode, LocalEnv, NeonBroker, NeonLocalInitConf,
NeonLocalInitPageserverConf, SafekeeperConf,
};
use control_plane::object_storage::OBJECT_STORAGE_DEFAULT_PORT;
use control_plane::object_storage::ObjectStorage;
use control_plane::pageserver::PageServerNode;
use control_plane::safekeeper::SafekeeperNode;
use control_plane::storage_controller::{
@@ -93,7 +92,7 @@ enum NeonLocalCmd {
#[command(subcommand)]
Safekeeper(SafekeeperCmd),
#[command(subcommand)]
ObjectStorage(ObjectStorageCmd),
EndpointStorage(EndpointStorageCmd),
#[command(subcommand)]
Endpoint(EndpointCmd),
#[command(subcommand)]
@@ -460,14 +459,14 @@ enum SafekeeperCmd {
#[derive(clap::Subcommand)]
#[clap(about = "Manage object storage")]
enum ObjectStorageCmd {
Start(ObjectStorageStartCmd),
Stop(ObjectStorageStopCmd),
enum EndpointStorageCmd {
Start(EndpointStorageStartCmd),
Stop(EndpointStorageStopCmd),
}
#[derive(clap::Args)]
#[clap(about = "Start object storage")]
struct ObjectStorageStartCmd {
struct EndpointStorageStartCmd {
#[clap(short = 't', long, help = "timeout until we fail the command")]
#[arg(default_value = "10s")]
start_timeout: humantime::Duration,
@@ -475,7 +474,7 @@ struct ObjectStorageStartCmd {
#[derive(clap::Args)]
#[clap(about = "Stop object storage")]
struct ObjectStorageStopCmd {
struct EndpointStorageStopCmd {
#[arg(value_enum, default_value = "fast")]
#[clap(
short = 'm',
@@ -797,7 +796,9 @@ fn main() -> Result<()> {
}
NeonLocalCmd::StorageBroker(subcmd) => rt.block_on(handle_storage_broker(&subcmd, env)),
NeonLocalCmd::Safekeeper(subcmd) => rt.block_on(handle_safekeeper(&subcmd, env)),
NeonLocalCmd::ObjectStorage(subcmd) => rt.block_on(handle_object_storage(&subcmd, env)),
NeonLocalCmd::EndpointStorage(subcmd) => {
rt.block_on(handle_endpoint_storage(&subcmd, env))
}
NeonLocalCmd::Endpoint(subcmd) => rt.block_on(handle_endpoint(&subcmd, env)),
NeonLocalCmd::Mappings(subcmd) => handle_mappings(&subcmd, env),
};
@@ -1014,8 +1015,8 @@ fn handle_init(args: &InitCmdArgs) -> anyhow::Result<LocalEnv> {
}
})
.collect(),
object_storage: ObjectStorageConf {
port: OBJECT_STORAGE_DEFAULT_PORT,
endpoint_storage: EndpointStorageConf {
port: ENDPOINT_STORAGE_DEFAULT_PORT,
},
pg_distrib_dir: None,
neon_distrib_dir: None,
@@ -1735,12 +1736,15 @@ async fn handle_safekeeper(subcmd: &SafekeeperCmd, env: &local_env::LocalEnv) ->
Ok(())
}
async fn handle_object_storage(subcmd: &ObjectStorageCmd, env: &local_env::LocalEnv) -> Result<()> {
use ObjectStorageCmd::*;
let storage = ObjectStorage::from_env(env);
async fn handle_endpoint_storage(
subcmd: &EndpointStorageCmd,
env: &local_env::LocalEnv,
) -> Result<()> {
use EndpointStorageCmd::*;
let storage = EndpointStorage::from_env(env);
// In tests like test_forward_compatibility or test_graceful_cluster_restart
// old neon binaries (without object_storage) are present
// old neon binaries (without endpoint_storage) are present
if !storage.bin.exists() {
eprintln!(
"{} binary not found. Ignore if this is a compatibility test",
@@ -1750,13 +1754,13 @@ async fn handle_object_storage(subcmd: &ObjectStorageCmd, env: &local_env::Local
}
match subcmd {
Start(ObjectStorageStartCmd { start_timeout }) => {
Start(EndpointStorageStartCmd { start_timeout }) => {
if let Err(e) = storage.start(start_timeout).await {
eprintln!("object_storage start failed: {e}");
eprintln!("endpoint_storage start failed: {e}");
exit(1);
}
}
Stop(ObjectStorageStopCmd { stop_mode }) => {
Stop(EndpointStorageStopCmd { stop_mode }) => {
let immediate = match stop_mode {
StopMode::Fast => false,
StopMode::Immediate => true,
@@ -1866,10 +1870,10 @@ async fn handle_start_all_impl(
}
js.spawn(async move {
ObjectStorage::from_env(env)
EndpointStorage::from_env(env)
.start(&retry_timeout)
.await
.map_err(|e| e.context("start object_storage"))
.map_err(|e| e.context("start endpoint_storage"))
});
})();
@@ -1968,9 +1972,9 @@ async fn try_stop_all(env: &local_env::LocalEnv, immediate: bool) {
}
}
let storage = ObjectStorage::from_env(env);
let storage = EndpointStorage::from_env(env);
if let Err(e) = storage.stop(immediate) {
eprintln!("object_storage stop failed: {:#}", e);
eprintln!("endpoint_storage stop failed: {:#}", e);
}
for ps_conf in &env.pageservers {

View File

@@ -1,34 +1,33 @@
use crate::background_process::{self, start_process, stop_process};
use crate::local_env::LocalEnv;
use anyhow::anyhow;
use anyhow::{Context, Result};
use camino::Utf8PathBuf;
use std::io::Write;
use std::time::Duration;
/// Directory within .neon which will be used by default for LocalFs remote storage.
pub const OBJECT_STORAGE_REMOTE_STORAGE_DIR: &str = "local_fs_remote_storage/object_storage";
pub const OBJECT_STORAGE_DEFAULT_PORT: u16 = 9993;
pub const ENDPOINT_STORAGE_REMOTE_STORAGE_DIR: &str = "local_fs_remote_storage/endpoint_storage";
pub const ENDPOINT_STORAGE_DEFAULT_PORT: u16 = 9993;
pub struct ObjectStorage {
pub struct EndpointStorage {
pub bin: Utf8PathBuf,
pub data_dir: Utf8PathBuf,
pub pemfile: Utf8PathBuf,
pub port: u16,
}
impl ObjectStorage {
pub fn from_env(env: &LocalEnv) -> ObjectStorage {
ObjectStorage {
bin: Utf8PathBuf::from_path_buf(env.object_storage_bin()).unwrap(),
data_dir: Utf8PathBuf::from_path_buf(env.object_storage_data_dir()).unwrap(),
impl EndpointStorage {
pub fn from_env(env: &LocalEnv) -> EndpointStorage {
EndpointStorage {
bin: Utf8PathBuf::from_path_buf(env.endpoint_storage_bin()).unwrap(),
data_dir: Utf8PathBuf::from_path_buf(env.endpoint_storage_data_dir()).unwrap(),
pemfile: Utf8PathBuf::from_path_buf(env.public_key_path.clone()).unwrap(),
port: env.object_storage.port,
port: env.endpoint_storage.port,
}
}
fn config_path(&self) -> Utf8PathBuf {
self.data_dir.join("object_storage.json")
self.data_dir.join("endpoint_storage.json")
}
fn listen_addr(&self) -> Utf8PathBuf {
@@ -49,7 +48,7 @@ impl ObjectStorage {
let cfg = Cfg {
listen: self.listen_addr(),
pemfile: parent.join(self.pemfile.clone()),
local_path: parent.join(OBJECT_STORAGE_REMOTE_STORAGE_DIR),
local_path: parent.join(ENDPOINT_STORAGE_REMOTE_STORAGE_DIR),
r#type: "LocalFs".to_string(),
};
std::fs::create_dir_all(self.config_path().parent().unwrap())?;
@@ -59,24 +58,19 @@ impl ObjectStorage {
}
pub async fn start(&self, retry_timeout: &Duration) -> Result<()> {
println!("Starting s3 proxy at {}", self.listen_addr());
println!("Starting endpoint_storage at {}", self.listen_addr());
std::io::stdout().flush().context("flush stdout")?;
let process_status_check = || async {
tokio::time::sleep(Duration::from_millis(500)).await;
let res = reqwest::Client::new()
.get(format!("http://{}/metrics", self.listen_addr()))
.send()
.await;
match res {
Ok(response) if response.status().is_success() => Ok(true),
Ok(_) => Err(anyhow!("Failed to query /metrics")),
Err(e) => Err(anyhow!("Failed to check node status: {e}")),
let res = reqwest::Client::new().get(format!("http://{}/metrics", self.listen_addr()));
match res.send().await {
Ok(res) => Ok(res.status().is_success()),
Err(_) => Ok(false),
}
};
let res = start_process(
"object_storage",
"endpoint_storage",
&self.data_dir.clone().into_std_path_buf(),
&self.bin.clone().into_std_path_buf(),
vec![self.config_path().to_string()],
@@ -94,14 +88,14 @@ impl ObjectStorage {
}
pub fn stop(&self, immediate: bool) -> anyhow::Result<()> {
stop_process(immediate, "object_storage", &self.pid_file())
stop_process(immediate, "endpoint_storage", &self.pid_file())
}
fn log_file(&self) -> Utf8PathBuf {
self.data_dir.join("object_storage.log")
self.data_dir.join("endpoint_storage.log")
}
fn pid_file(&self) -> Utf8PathBuf {
self.data_dir.join("object_storage.pid")
self.data_dir.join("endpoint_storage.pid")
}
}

View File

@@ -9,8 +9,8 @@
mod background_process;
pub mod broker;
pub mod endpoint;
pub mod endpoint_storage;
pub mod local_env;
pub mod object_storage;
pub mod pageserver;
pub mod postgresql_conf;
pub mod safekeeper;

View File

@@ -19,7 +19,7 @@ use serde::{Deserialize, Serialize};
use utils::auth::encode_from_key_file;
use utils::id::{NodeId, TenantId, TenantTimelineId, TimelineId};
use crate::object_storage::{OBJECT_STORAGE_REMOTE_STORAGE_DIR, ObjectStorage};
use crate::endpoint_storage::{ENDPOINT_STORAGE_REMOTE_STORAGE_DIR, EndpointStorage};
use crate::pageserver::{PAGESERVER_REMOTE_STORAGE_DIR, PageServerNode};
use crate::safekeeper::SafekeeperNode;
@@ -72,7 +72,7 @@ pub struct LocalEnv {
pub safekeepers: Vec<SafekeeperConf>,
pub object_storage: ObjectStorageConf,
pub endpoint_storage: EndpointStorageConf,
// Control plane upcall API for pageserver: if None, we will not run storage_controller If set, this will
// be propagated into each pageserver's configuration.
@@ -110,7 +110,7 @@ pub struct OnDiskConfig {
)]
pub pageservers: Vec<PageServerConf>,
pub safekeepers: Vec<SafekeeperConf>,
pub object_storage: ObjectStorageConf,
pub endpoint_storage: EndpointStorageConf,
pub control_plane_api: Option<Url>,
pub control_plane_hooks_api: Option<Url>,
pub control_plane_compute_hook_api: Option<Url>,
@@ -144,7 +144,7 @@ pub struct NeonLocalInitConf {
pub storage_controller: Option<NeonStorageControllerConf>,
pub pageservers: Vec<NeonLocalInitPageserverConf>,
pub safekeepers: Vec<SafekeeperConf>,
pub object_storage: ObjectStorageConf,
pub endpoint_storage: EndpointStorageConf,
pub control_plane_api: Option<Url>,
pub control_plane_hooks_api: Option<Url>,
pub generate_local_ssl_certs: bool,
@@ -152,7 +152,7 @@ pub struct NeonLocalInitConf {
#[derive(Serialize, Default, Deserialize, PartialEq, Eq, Clone, Debug)]
#[serde(default)]
pub struct ObjectStorageConf {
pub struct EndpointStorageConf {
pub port: u16,
}
@@ -413,8 +413,8 @@ impl LocalEnv {
self.pg_dir(pg_version, "lib")
}
pub fn object_storage_bin(&self) -> PathBuf {
self.neon_distrib_dir.join("object_storage")
pub fn endpoint_storage_bin(&self) -> PathBuf {
self.neon_distrib_dir.join("endpoint_storage")
}
pub fn pageserver_bin(&self) -> PathBuf {
@@ -450,8 +450,8 @@ impl LocalEnv {
self.base_data_dir.join("safekeepers").join(data_dir_name)
}
pub fn object_storage_data_dir(&self) -> PathBuf {
self.base_data_dir.join("object_storage")
pub fn endpoint_storage_data_dir(&self) -> PathBuf {
self.base_data_dir.join("endpoint_storage")
}
pub fn get_pageserver_conf(&self, id: NodeId) -> anyhow::Result<&PageServerConf> {
@@ -615,7 +615,7 @@ impl LocalEnv {
control_plane_compute_hook_api: _,
branch_name_mappings,
generate_local_ssl_certs,
object_storage,
endpoint_storage,
} = on_disk_config;
LocalEnv {
base_data_dir: repopath.to_owned(),
@@ -632,7 +632,7 @@ impl LocalEnv {
control_plane_hooks_api,
branch_name_mappings,
generate_local_ssl_certs,
object_storage,
endpoint_storage,
}
};
@@ -742,7 +742,7 @@ impl LocalEnv {
control_plane_compute_hook_api: None,
branch_name_mappings: self.branch_name_mappings.clone(),
generate_local_ssl_certs: self.generate_local_ssl_certs,
object_storage: self.object_storage.clone(),
endpoint_storage: self.endpoint_storage.clone(),
},
)
}
@@ -849,7 +849,7 @@ impl LocalEnv {
control_plane_api,
generate_local_ssl_certs,
control_plane_hooks_api,
object_storage,
endpoint_storage,
} = conf;
// Find postgres binaries.
@@ -901,7 +901,7 @@ impl LocalEnv {
control_plane_hooks_api,
branch_name_mappings: Default::default(),
generate_local_ssl_certs,
object_storage,
endpoint_storage,
};
if generate_local_ssl_certs {
@@ -929,13 +929,13 @@ impl LocalEnv {
.context("pageserver init failed")?;
}
ObjectStorage::from_env(&env)
EndpointStorage::from_env(&env)
.init()
.context("object storage init failed")?;
// setup remote remote location for default LocalFs remote storage
std::fs::create_dir_all(env.base_data_dir.join(PAGESERVER_REMOTE_STORAGE_DIR))?;
std::fs::create_dir_all(env.base_data_dir.join(OBJECT_STORAGE_REMOTE_STORAGE_DIR))?;
std::fs::create_dir_all(env.base_data_dir.join(ENDPOINT_STORAGE_REMOTE_STORAGE_DIR))?;
env.persist_config()
}

View File

@@ -45,9 +45,7 @@ allow = [
"ISC",
"MIT",
"MPL-2.0",
"OpenSSL",
"Unicode-3.0",
"Zlib",
]
confidence-threshold = 0.8
exceptions = [
@@ -56,14 +54,6 @@ exceptions = [
{ allow = ["Zlib"], name = "const_format", version = "*" },
]
[[licenses.clarify]]
name = "ring"
version = "*"
expression = "MIT AND ISC AND OpenSSL"
license-files = [
{ path = "LICENSE", hash = 0xbd0eed23 }
]
[licenses.private]
ignore = true
registries = []
@@ -116,7 +106,11 @@ name = "openssl"
unknown-registry = "warn"
unknown-git = "warn"
allow-registry = ["https://github.com/rust-lang/crates.io-index"]
allow-git = []
allow-git = [
# Crate pinned to commit in origin repo due to opentelemetry version.
# TODO: Remove this once crate is fetched from crates.io again.
"https://github.com/mattiapenati/tower-otel",
]
[sources.allow-org]
github = [

View File

@@ -9,21 +9,20 @@
# to verify custom image builds (e.g pre-published ones).
#
# A test script for postgres extensions
# Currently supports only v16
# Currently supports only v16+
#
set -eux -o pipefail
COMPOSE_FILE='docker-compose.yml'
cd $(dirname $0)
COMPUTE_CONTAINER_NAME=docker-compose-compute-1
TEST_CONTAINER_NAME=docker-compose-neon-test-extensions-1
export COMPOSE_FILE='docker-compose.yml'
export COMPOSE_PROFILES=test-extensions
cd "$(dirname "${0}")"
PSQL_OPTION="-h localhost -U cloud_admin -p 55433 -d postgres"
cleanup() {
function cleanup() {
echo "show container information"
docker ps
echo "stop containers..."
docker compose --profile test-extensions -f $COMPOSE_FILE down
docker compose down
}
for pg_version in ${TEST_VERSION_ONLY-14 15 16 17}; do
@@ -31,50 +30,57 @@ for pg_version in ${TEST_VERSION_ONLY-14 15 16 17}; do
echo "clean up containers if exists"
cleanup
PG_TEST_VERSION=$((pg_version < 16 ? 16 : pg_version))
PG_VERSION=$pg_version PG_TEST_VERSION=$PG_TEST_VERSION docker compose --profile test-extensions -f $COMPOSE_FILE up --quiet-pull --build -d
PG_VERSION=${pg_version} PG_TEST_VERSION=${PG_TEST_VERSION} docker compose up --quiet-pull --build -d
echo "wait until the compute is ready. timeout after 60s. "
cnt=0
while sleep 3; do
# check timeout
cnt=`expr $cnt + 3`
if [ $cnt -gt 60 ]; then
(( cnt += 3 ))
if [[ ${cnt} -gt 60 ]]; then
echo "timeout before the compute is ready."
exit 1
fi
if docker compose --profile test-extensions -f $COMPOSE_FILE logs "compute_is_ready" | grep -q "accepting connections"; then
if docker compose logs "compute_is_ready" | grep -q "accepting connections"; then
echo "OK. The compute is ready to connect."
echo "execute simple queries."
docker exec $COMPUTE_CONTAINER_NAME /bin/bash -c "psql $PSQL_OPTION"
docker compose exec compute /bin/bash -c "psql ${PSQL_OPTION} -c 'SELECT 1'"
break
fi
done
if [ $pg_version -ge 16 ]; then
if [[ ${pg_version} -ge 16 ]]; then
# This is required for the pg_hint_plan test, to prevent flaky log message causing the test to fail
# It cannot be moved to Dockerfile now because the database directory is created after the start of the container
echo Adding dummy config
docker exec $COMPUTE_CONTAINER_NAME touch /var/db/postgres/compute/compute_ctl_temp_override.conf
docker compose exec compute touch /var/db/postgres/compute/compute_ctl_temp_override.conf
# The following block copies the files for the pg_hintplan test to the compute node for the extension test in an isolated docker-compose environment
TMPDIR=$(mktemp -d)
docker cp $TEST_CONTAINER_NAME:/ext-src/pg_hint_plan-src/data $TMPDIR/data
docker cp $TMPDIR/data $COMPUTE_CONTAINER_NAME:/ext-src/pg_hint_plan-src/
rm -rf $TMPDIR
docker compose cp neon-test-extensions:/ext-src/pg_hint_plan-src/data "${TMPDIR}/data"
docker compose cp "${TMPDIR}/data" compute:/ext-src/pg_hint_plan-src/
rm -rf "${TMPDIR}"
# The following block does the same for the contrib/file_fdw test
TMPDIR=$(mktemp -d)
docker cp $TEST_CONTAINER_NAME:/postgres/contrib/file_fdw/data $TMPDIR/data
docker cp $TMPDIR/data $COMPUTE_CONTAINER_NAME:/postgres/contrib/file_fdw/data
rm -rf $TMPDIR
docker compose cp neon-test-extensions:/postgres/contrib/file_fdw/data "${TMPDIR}/data"
docker compose cp "${TMPDIR}/data" compute:/postgres/contrib/file_fdw/data
rm -rf "${TMPDIR}"
# Apply patches
cat ../compute/patches/contrib_pg${pg_version}.patch | docker exec -i $TEST_CONTAINER_NAME bash -c "(cd /postgres && patch -p1)"
docker compose exec -i neon-test-extensions bash -c "(cd /postgres && patch -p1)" <"../compute/patches/contrib_pg${pg_version}.patch"
# We are running tests now
rm -f testout.txt testout_contrib.txt
docker exec -e USE_PGXS=1 -e SKIP=timescaledb-src,rdkit-src,postgis-src,pg_jsonschema-src,kq_imcx-src,wal2json_2_5-src,rag_jina_reranker_v1_tiny_en-src,rag_bge_small_en_v15-src \
$TEST_CONTAINER_NAME /run-tests.sh /ext-src | tee testout.txt && EXT_SUCCESS=1 || EXT_SUCCESS=0
docker exec -e SKIP=start-scripts,postgres_fdw,ltree_plpython,jsonb_plpython,jsonb_plperl,hstore_plpython,hstore_plperl,dblink,bool_plperl \
$TEST_CONTAINER_NAME /run-tests.sh /postgres/contrib | tee testout_contrib.txt && CONTRIB_SUCCESS=1 || CONTRIB_SUCCESS=0
if [ $EXT_SUCCESS -eq 0 ] || [ $CONTRIB_SUCCESS -eq 0 ]; then
exit 1
docker compose exec -e USE_PGXS=1 -e SKIP=timescaledb-src,rdkit-src,postgis-src,pg_jsonschema-src,kq_imcx-src,wal2json_2_5-src,rag_jina_reranker_v1_tiny_en-src,rag_bge_small_en_v15-src \
neon-test-extensions /run-tests.sh /ext-src | tee testout.txt && EXT_SUCCESS=1 || EXT_SUCCESS=0
docker compose exec -e SKIP=start-scripts,postgres_fdw,ltree_plpython,jsonb_plpython,jsonb_plperl,hstore_plpython,hstore_plperl,dblink,bool_plperl \
neon-test-extensions /run-tests.sh /postgres/contrib | tee testout_contrib.txt && CONTRIB_SUCCESS=1 || CONTRIB_SUCCESS=0
if [[ ${EXT_SUCCESS} -eq 0 || ${CONTRIB_SUCCESS} -eq 0 ]]; then
CONTRIB_FAILED=
FAILED=
[[ ${EXT_SUCCESS} -eq 0 ]] && FAILED=$(tail -1 testout.txt | awk '{for(i=1;i<=NF;i++){print "/ext-src/"$i;}}')
[[ ${CONTRIB_SUCCESS} -eq 0 ]] && CONTRIB_FAILED=$(tail -1 testout_contrib.txt | awk '{for(i=0;i<=NF;i++){print "/postgres/contrib/"$i;}}')
for d in ${FAILED} ${CONTRIB_FAILED}; do
docker compose exec neon-test-extensions bash -c 'for file in $(find '"${d}"' -name regression.diffs -o -name regression.out); do cat ${file}; done' || [[ ${?} -eq 1 ]]
done
exit 1
fi
fi
done

View File

@@ -1,5 +1,5 @@
[package]
name = "object_storage"
name = "endpoint_storage"
version = "0.0.1"
edition.workspace = true
license.workspace = true

View File

@@ -2,7 +2,7 @@ use anyhow::anyhow;
use axum::body::{Body, Bytes};
use axum::response::{IntoResponse, Response};
use axum::{Router, http::StatusCode};
use object_storage::{PrefixS3Path, S3Path, Storage, bad_request, internal_error, not_found, ok};
use endpoint_storage::{PrefixS3Path, S3Path, Storage, bad_request, internal_error, not_found, ok};
use remote_storage::TimeoutOrCancel;
use remote_storage::{DownloadError, DownloadOpts, GenericRemoteStorage, RemotePath};
use std::{sync::Arc, time::SystemTime, time::UNIX_EPOCH};
@@ -46,12 +46,12 @@ async fn metrics() -> Result {
async fn get(S3Path { path }: S3Path, state: State) -> Result {
info!(%path, "downloading");
let download_err = |e| {
if let DownloadError::NotFound = e {
info!(%path, %e, "downloading"); // 404 is not an issue of _this_ service
let download_err = |err| {
if let DownloadError::NotFound = err {
info!(%path, %err, "downloading"); // 404 is not an issue of _this_ service
return not_found(&path);
}
internal_error(e, &path, "downloading")
internal_error(err, &path, "downloading")
};
let cancel = state.cancel.clone();
let opts = &DownloadOpts::default();
@@ -249,7 +249,7 @@ mod tests {
};
let proxy = Storage {
auth: object_storage::JwtAuth::new(TEST_PUB_KEY_ED25519).unwrap(),
auth: endpoint_storage::JwtAuth::new(TEST_PUB_KEY_ED25519).unwrap(),
storage,
cancel: cancel.clone(),
max_upload_file_limit: usize::MAX,
@@ -343,14 +343,14 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
TimelineId::from_array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3, 4, 5, 7]);
const ENDPOINT_ID: &str = "ep-winter-frost-a662z3vg";
fn token() -> String {
let claims = object_storage::Claims {
let claims = endpoint_storage::Claims {
tenant_id: TENANT_ID,
timeline_id: TIMELINE_ID,
endpoint_id: ENDPOINT_ID.into(),
exp: u64::MAX,
};
let key = jsonwebtoken::EncodingKey::from_ed_pem(TEST_PRIV_KEY_ED25519).unwrap();
let header = jsonwebtoken::Header::new(object_storage::VALIDATION_ALGO);
let header = jsonwebtoken::Header::new(endpoint_storage::VALIDATION_ALGO);
jsonwebtoken::encode(&header, &claims, &key).unwrap()
}
@@ -364,7 +364,10 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
vec![TIMELINE_ID.to_string(), TimelineId::generate().to_string()],
vec![ENDPOINT_ID, "ep-ololo"]
)
.skip(1);
// first one is fully valid path, second path is valid for GET as
// read paths may have different endpoint if tenant and timeline matches
// (needed for prewarming RO->RW replica)
.skip(2);
for ((uri, method), (tenant, timeline, endpoint)) in iproduct!(routes(), args) {
info!(%uri, %method, %tenant, %timeline, %endpoint);
@@ -475,6 +478,16 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
requests_chain(chain.into_iter(), |_| token()).await;
}
#[testlog(tokio::test)]
async fn read_other_endpoint_data() {
let uri = format!("/{TENANT_ID}/{TIMELINE_ID}/other_endpoint/key");
let chain = vec![
(uri.clone(), "GET", "", StatusCode::NOT_FOUND, false),
(uri.clone(), "PUT", "", StatusCode::UNAUTHORIZED, false),
];
requests_chain(chain.into_iter(), |_| token()).await;
}
fn delete_prefix_token(uri: &str) -> String {
use serde::Serialize;
let parts = uri.split("/").collect::<Vec<&str>>();
@@ -482,7 +495,7 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
struct PrefixClaims {
tenant_id: TenantId,
timeline_id: Option<TimelineId>,
endpoint_id: Option<object_storage::EndpointId>,
endpoint_id: Option<endpoint_storage::EndpointId>,
exp: u64,
}
let claims = PrefixClaims {
@@ -492,7 +505,7 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
exp: u64::MAX,
};
let key = jsonwebtoken::EncodingKey::from_ed_pem(TEST_PRIV_KEY_ED25519).unwrap();
let header = jsonwebtoken::Header::new(object_storage::VALIDATION_ALGO);
let header = jsonwebtoken::Header::new(endpoint_storage::VALIDATION_ALGO);
jsonwebtoken::encode(&header, &claims, &key).unwrap()
}

View File

@@ -169,10 +169,19 @@ impl FromRequestParts<Arc<Storage>> for S3Path {
.auth
.decode(bearer.token())
.map_err(|e| bad_request(e, "decoding token"))?;
// Read paths may have different endpoint ids. For readonly -> readwrite replica
// prewarming, endpoint must read other endpoint's data.
let endpoint_id = if parts.method == axum::http::Method::GET {
claims.endpoint_id.clone()
} else {
path.endpoint_id.clone()
};
let route = Claims {
tenant_id: path.tenant_id,
timeline_id: path.timeline_id,
endpoint_id: path.endpoint_id.clone(),
endpoint_id,
exp: claims.exp,
};
if route != claims {

View File

@@ -1,4 +1,4 @@
//! `object_storage` is a service which provides API for uploading and downloading
//! `endpoint_storage` is a service which provides API for uploading and downloading
//! files. It is used by compute and control plane for accessing LFC prewarm data.
//! This service is deployed either as a separate component or as part of compute image
//! for large computes.
@@ -33,7 +33,7 @@ async fn main() -> anyhow::Result<()> {
let config: String = std::env::args().skip(1).take(1).collect();
if config.is_empty() {
anyhow::bail!("Usage: object_storage config.json")
anyhow::bail!("Usage: endpoint_storage config.json")
}
info!("Reading config from {config}");
let config = std::fs::read_to_string(config.clone())?;
@@ -41,7 +41,7 @@ async fn main() -> anyhow::Result<()> {
info!("Reading pemfile from {}", config.pemfile.clone());
let pemfile = std::fs::read(config.pemfile.clone())?;
info!("Loading public key from {}", config.pemfile.clone());
let auth = object_storage::JwtAuth::new(&pemfile)?;
let auth = endpoint_storage::JwtAuth::new(&pemfile)?;
let listener = tokio::net::TcpListener::bind(config.listen).await.unwrap();
info!("listening on {}", listener.local_addr().unwrap());
@@ -50,7 +50,7 @@ async fn main() -> anyhow::Result<()> {
let cancel = tokio_util::sync::CancellationToken::new();
app::check_storage_permissions(&storage, cancel.clone()).await?;
let proxy = std::sync::Arc::new(object_storage::Storage {
let proxy = std::sync::Arc::new(endpoint_storage::Storage {
auth,
storage,
cancel: cancel.clone(),

View File

@@ -181,6 +181,7 @@ pub struct ConfigToml {
pub generate_unarchival_heatmap: Option<bool>,
pub tracing: Option<Tracing>,
pub enable_tls_page_service_api: bool,
pub dev_mode: bool,
}
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
@@ -657,6 +658,7 @@ impl Default for ConfigToml {
generate_unarchival_heatmap: None,
tracing: None,
enable_tls_page_service_api: false,
dev_mode: false,
}
}
}

View File

@@ -320,6 +320,35 @@ pub struct TimelineCreateRequest {
pub mode: TimelineCreateRequestMode,
}
impl TimelineCreateRequest {
pub fn mode_tag(&self) -> &'static str {
match &self.mode {
TimelineCreateRequestMode::Branch { .. } => "branch",
TimelineCreateRequestMode::ImportPgdata { .. } => "import",
TimelineCreateRequestMode::Bootstrap { .. } => "bootstrap",
}
}
pub fn is_import(&self) -> bool {
matches!(self.mode, TimelineCreateRequestMode::ImportPgdata { .. })
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub enum ShardImportStatus {
InProgress,
Done,
Error(String),
}
impl ShardImportStatus {
pub fn is_terminal(&self) -> bool {
match self {
ShardImportStatus::InProgress => false,
ShardImportStatus::Done | ShardImportStatus::Error(_) => true,
}
}
}
/// Storage controller specific extensions to [`TimelineInfo`].
#[derive(Serialize, Deserialize, Clone)]
pub struct TimelineCreateResponseStorcon {
@@ -1774,6 +1803,8 @@ pub struct TopTenantShardsResponse {
}
pub mod virtual_file {
use std::sync::LazyLock;
#[derive(
Copy,
Clone,
@@ -1811,35 +1842,33 @@ pub mod virtual_file {
pub enum IoMode {
/// Uses buffered IO.
Buffered,
/// Uses direct IO, error out if the operation fails.
/// Uses direct IO for reads only.
#[cfg(target_os = "linux")]
Direct,
/// Use direct IO for reads and writes.
#[cfg(target_os = "linux")]
DirectRw,
}
impl IoMode {
pub fn preferred() -> Self {
// The default behavior when running Rust unit tests without any further
// flags is to use the newest behavior if available on the platform (Direct).
// flags is to use the newest behavior (DirectRw).
// The CI uses the following environment variable to unit tests for all
// different modes.
// NB: the Python regression & perf tests have their own defaults management
// that writes pageserver.toml; they do not use this variable.
if cfg!(test) {
use once_cell::sync::Lazy;
static CACHED: Lazy<IoMode> = Lazy::new(|| {
static CACHED: LazyLock<IoMode> = LazyLock::new(|| {
utils::env::var_serde_json_string(
"NEON_PAGESERVER_UNIT_TEST_VIRTUAL_FILE_IO_MODE",
)
.unwrap_or({
.unwrap_or(
#[cfg(target_os = "linux")]
{
IoMode::Direct
}
IoMode::DirectRw,
#[cfg(not(target_os = "linux"))]
{
IoMode::Buffered
}
})
IoMode::Buffered,
)
});
*CACHED
} else {
@@ -1856,6 +1885,8 @@ pub mod virtual_file {
v if v == (IoMode::Buffered as u8) => IoMode::Buffered,
#[cfg(target_os = "linux")]
v if v == (IoMode::Direct as u8) => IoMode::Direct,
#[cfg(target_os = "linux")]
v if v == (IoMode::DirectRw as u8) => IoMode::DirectRw,
x => return Err(x),
})
}

View File

@@ -4,10 +4,10 @@
//! See docs/rfcs/025-generation-numbers.md
use serde::{Deserialize, Serialize};
use utils::id::NodeId;
use utils::id::{NodeId, TimelineId};
use crate::controller_api::NodeRegisterRequest;
use crate::models::LocationConfigMode;
use crate::models::{LocationConfigMode, ShardImportStatus};
use crate::shard::TenantShardId;
/// Upcall message sent by the pageserver to the configured `control_plane_api` on
@@ -62,3 +62,10 @@ pub struct ValidateResponseTenant {
pub id: TenantShardId,
pub valid: bool,
}
#[derive(Serialize, Deserialize)]
pub struct PutTimelineImportStatusRequest {
pub tenant_shard_id: TenantShardId,
pub timeline_id: TimelineId,
pub status: ShardImportStatus,
}

View File

@@ -14,8 +14,9 @@ use anyhow::{Context, Result};
use azure_core::request_options::{IfMatchCondition, MaxResults, Metadata, Range};
use azure_core::{Continuable, HttpClient, RetryOptions, TransportOptions};
use azure_storage::StorageCredentials;
use azure_storage_blobs::blob::CopyStatus;
use azure_storage_blobs::blob::operations::GetBlobBuilder;
use azure_storage_blobs::blob::{Blob, CopyStatus};
use azure_storage_blobs::container::operations::ListBlobsBuilder;
use azure_storage_blobs::prelude::{ClientBuilder, ContainerClient};
use bytes::Bytes;
use futures::FutureExt;
@@ -253,53 +254,15 @@ impl AzureBlobStorage {
download
}
async fn permit(
&self,
kind: RequestKind,
cancel: &CancellationToken,
) -> Result<tokio::sync::SemaphorePermit<'_>, Cancelled> {
let acquire = self.concurrency_limiter.acquire(kind);
tokio::select! {
permit = acquire => Ok(permit.expect("never closed")),
_ = cancel.cancelled() => Err(Cancelled),
}
}
pub fn container_name(&self) -> &str {
&self.container_name
}
}
fn to_azure_metadata(metadata: StorageMetadata) -> Metadata {
let mut res = Metadata::new();
for (k, v) in metadata.0.into_iter() {
res.insert(k, v);
}
res
}
fn to_download_error(error: azure_core::Error) -> DownloadError {
if let Some(http_err) = error.as_http_error() {
match http_err.status() {
StatusCode::NotFound => DownloadError::NotFound,
StatusCode::NotModified => DownloadError::Unmodified,
StatusCode::BadRequest => DownloadError::BadInput(anyhow::Error::new(error)),
_ => DownloadError::Other(anyhow::Error::new(error)),
}
} else {
DownloadError::Other(error.into())
}
}
impl RemoteStorage for AzureBlobStorage {
fn list_streaming(
fn list_streaming_for_fn<T: Default + ListingCollector>(
&self,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
cancel: &CancellationToken,
) -> impl Stream<Item = Result<Listing, DownloadError>> {
request_kind: RequestKind,
customize_builder: impl Fn(ListBlobsBuilder) -> ListBlobsBuilder,
) -> impl Stream<Item = Result<T, DownloadError>> {
// get the passed prefix or if it is not set use prefix_in_bucket value
let list_prefix = prefix.map(|p| self.relative_path_to_name(p)).or_else(|| {
self.prefix_in_container.clone().map(|mut s| {
@@ -311,7 +274,7 @@ impl RemoteStorage for AzureBlobStorage {
});
async_stream::stream! {
let _permit = self.permit(RequestKind::List, cancel).await?;
let _permit = self.permit(request_kind, cancel).await?;
let mut builder = self.client.list_blobs();
@@ -327,6 +290,8 @@ impl RemoteStorage for AzureBlobStorage {
builder = builder.max_results(MaxResults::new(limit));
}
builder = customize_builder(builder);
let mut next_marker = None;
let mut timeout_try_cnt = 1;
@@ -382,26 +347,20 @@ impl RemoteStorage for AzureBlobStorage {
break;
};
let mut res = Listing::default();
let mut res = T::default();
next_marker = entry.continuation();
let prefix_iter = entry
.blobs
.prefixes()
.map(|prefix| self.name_to_relative_path(&prefix.name));
res.prefixes.extend(prefix_iter);
res.add_prefixes(self, prefix_iter);
let blob_iter = entry
.blobs
.blobs()
.map(|k| ListingObject{
key: self.name_to_relative_path(&k.name),
last_modified: k.properties.last_modified.into(),
size: k.properties.content_length,
}
);
.blobs();
for key in blob_iter {
res.keys.push(key);
res.add_blob(self, key);
if let Some(mut mk) = max_keys {
assert!(mk > 0);
@@ -423,6 +382,128 @@ impl RemoteStorage for AzureBlobStorage {
}
}
async fn permit(
&self,
kind: RequestKind,
cancel: &CancellationToken,
) -> Result<tokio::sync::SemaphorePermit<'_>, Cancelled> {
let acquire = self.concurrency_limiter.acquire(kind);
tokio::select! {
permit = acquire => Ok(permit.expect("never closed")),
_ = cancel.cancelled() => Err(Cancelled),
}
}
pub fn container_name(&self) -> &str {
&self.container_name
}
}
trait ListingCollector {
fn add_prefixes(&mut self, abs: &AzureBlobStorage, prefix_it: impl Iterator<Item = RemotePath>);
fn add_blob(&mut self, abs: &AzureBlobStorage, blob: &Blob);
}
impl ListingCollector for Listing {
fn add_prefixes(
&mut self,
_abs: &AzureBlobStorage,
prefix_it: impl Iterator<Item = RemotePath>,
) {
self.prefixes.extend(prefix_it);
}
fn add_blob(&mut self, abs: &AzureBlobStorage, blob: &Blob) {
self.keys.push(ListingObject {
key: abs.name_to_relative_path(&blob.name),
last_modified: blob.properties.last_modified.into(),
size: blob.properties.content_length,
});
}
}
impl ListingCollector for crate::VersionListing {
fn add_prefixes(
&mut self,
_abs: &AzureBlobStorage,
_prefix_it: impl Iterator<Item = RemotePath>,
) {
// nothing
}
fn add_blob(&mut self, abs: &AzureBlobStorage, blob: &Blob) {
let id = crate::VersionId(blob.version_id.clone().expect("didn't find version ID"));
self.versions.push(crate::Version {
key: abs.name_to_relative_path(&blob.name),
last_modified: blob.properties.last_modified.into(),
kind: crate::VersionKind::Version(id),
});
}
}
fn to_azure_metadata(metadata: StorageMetadata) -> Metadata {
let mut res = Metadata::new();
for (k, v) in metadata.0.into_iter() {
res.insert(k, v);
}
res
}
fn to_download_error(error: azure_core::Error) -> DownloadError {
if let Some(http_err) = error.as_http_error() {
match http_err.status() {
StatusCode::NotFound => DownloadError::NotFound,
StatusCode::NotModified => DownloadError::Unmodified,
StatusCode::BadRequest => DownloadError::BadInput(anyhow::Error::new(error)),
_ => DownloadError::Other(anyhow::Error::new(error)),
}
} else {
DownloadError::Other(error.into())
}
}
impl RemoteStorage for AzureBlobStorage {
fn list_streaming(
&self,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
cancel: &CancellationToken,
) -> impl Stream<Item = Result<Listing, DownloadError>> {
let customize_builder = |builder| builder;
let kind = RequestKind::ListVersions;
self.list_streaming_for_fn(prefix, mode, max_keys, cancel, kind, customize_builder)
}
async fn list_versions(
&self,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
cancel: &CancellationToken,
) -> std::result::Result<crate::VersionListing, DownloadError> {
let customize_builder = |mut builder: ListBlobsBuilder| {
builder = builder.include_versions(true);
builder
};
let kind = RequestKind::ListVersions;
let mut stream = std::pin::pin!(self.list_streaming_for_fn(
prefix,
mode,
max_keys,
cancel,
kind,
customize_builder
));
let mut combined: crate::VersionListing =
stream.next().await.expect("At least one item required")?;
while let Some(list) = stream.next().await {
let list = list?;
combined.versions.extend(list.versions.into_iter());
}
Ok(combined)
}
async fn head_object(
&self,
key: &RemotePath,
@@ -532,7 +613,12 @@ impl RemoteStorage for AzureBlobStorage {
let mut builder = blob_client.get();
if let Some(ref etag) = opts.etag {
builder = builder.if_match(IfMatchCondition::NotMatch(etag.to_string()))
builder = builder.if_match(IfMatchCondition::NotMatch(etag.to_string()));
}
if let Some(ref version_id) = opts.version_id {
let version_id = azure_storage_blobs::prelude::VersionId::new(version_id.0.clone());
builder = builder.blob_versioning(version_id);
}
if let Some((start, end)) = opts.byte_range() {

View File

@@ -176,6 +176,32 @@ pub struct Listing {
pub keys: Vec<ListingObject>,
}
#[derive(Default)]
pub struct VersionListing {
pub versions: Vec<Version>,
}
pub struct Version {
pub key: RemotePath,
pub last_modified: SystemTime,
pub kind: VersionKind,
}
impl Version {
pub fn version_id(&self) -> Option<&VersionId> {
match &self.kind {
VersionKind::Version(id) => Some(id),
VersionKind::DeletionMarker => None,
}
}
}
#[derive(Debug)]
pub enum VersionKind {
DeletionMarker,
Version(VersionId),
}
/// Options for downloads. The default value is a plain GET.
pub struct DownloadOpts {
/// If given, returns [`DownloadError::Unmodified`] if the object still has
@@ -186,6 +212,8 @@ pub struct DownloadOpts {
/// The end of the byte range to download, or unbounded. Must be after the
/// start bound.
pub byte_end: Bound<u64>,
/// Optionally request a specific version of a key
pub version_id: Option<VersionId>,
/// Indicate whether we're downloading something small or large: this indirectly controls
/// timeouts: for something like an index/manifest/heatmap, we should time out faster than
/// for layer files
@@ -197,12 +225,16 @@ pub enum DownloadKind {
Small,
}
#[derive(Debug, Clone)]
pub struct VersionId(pub String);
impl Default for DownloadOpts {
fn default() -> Self {
Self {
etag: Default::default(),
byte_start: Bound::Unbounded,
byte_end: Bound::Unbounded,
version_id: None,
kind: DownloadKind::Large,
}
}
@@ -295,6 +327,14 @@ pub trait RemoteStorage: Send + Sync + 'static {
Ok(combined)
}
async fn list_versions(
&self,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
cancel: &CancellationToken,
) -> Result<VersionListing, DownloadError>;
/// Obtain metadata information about an object.
async fn head_object(
&self,
@@ -475,6 +515,22 @@ impl<Other: RemoteStorage> GenericRemoteStorage<Arc<Other>> {
}
}
// See [`RemoteStorage::list_versions`].
pub async fn list_versions<'a>(
&'a self,
prefix: Option<&'a RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
cancel: &'a CancellationToken,
) -> Result<VersionListing, DownloadError> {
match self {
Self::LocalFs(s) => s.list_versions(prefix, mode, max_keys, cancel).await,
Self::AwsS3(s) => s.list_versions(prefix, mode, max_keys, cancel).await,
Self::AzureBlob(s) => s.list_versions(prefix, mode, max_keys, cancel).await,
Self::Unreliable(s) => s.list_versions(prefix, mode, max_keys, cancel).await,
}
}
// See [`RemoteStorage::head_object`].
pub async fn head_object(
&self,
@@ -727,6 +783,7 @@ impl ConcurrencyLimiter {
RequestKind::Copy => &self.write,
RequestKind::TimeTravel => &self.write,
RequestKind::Head => &self.read,
RequestKind::ListVersions => &self.read,
}
}

View File

@@ -445,6 +445,16 @@ impl RemoteStorage for LocalFs {
}
}
async fn list_versions(
&self,
_prefix: Option<&RemotePath>,
_mode: ListingMode,
_max_keys: Option<NonZeroU32>,
_cancel: &CancellationToken,
) -> Result<crate::VersionListing, DownloadError> {
unimplemented!()
}
async fn head_object(
&self,
key: &RemotePath,

View File

@@ -14,6 +14,7 @@ pub(crate) enum RequestKind {
Copy = 4,
TimeTravel = 5,
Head = 6,
ListVersions = 7,
}
use RequestKind::*;
@@ -29,6 +30,7 @@ impl RequestKind {
Copy => "copy_object",
TimeTravel => "time_travel_recover",
Head => "head_object",
ListVersions => "list_versions",
}
}
const fn as_index(&self) -> usize {
@@ -36,7 +38,10 @@ impl RequestKind {
}
}
const REQUEST_KIND_COUNT: usize = 7;
const REQUEST_KIND_LIST: &[RequestKind] =
&[Get, Put, Delete, List, Copy, TimeTravel, Head, ListVersions];
const REQUEST_KIND_COUNT: usize = REQUEST_KIND_LIST.len();
pub(crate) struct RequestTyped<C>([C; REQUEST_KIND_COUNT]);
impl<C> RequestTyped<C> {
@@ -45,12 +50,11 @@ impl<C> RequestTyped<C> {
}
fn build_with(mut f: impl FnMut(RequestKind) -> C) -> Self {
use RequestKind::*;
let mut it = [Get, Put, Delete, List, Copy, TimeTravel, Head].into_iter();
let mut it = REQUEST_KIND_LIST.iter();
let arr = std::array::from_fn::<C, REQUEST_KIND_COUNT, _>(|index| {
let next = it.next().unwrap();
assert_eq!(index, next.as_index());
f(next)
f(*next)
});
if let Some(next) = it.next() {

View File

@@ -21,9 +21,8 @@ use aws_sdk_s3::config::{AsyncSleep, IdentityCache, Region, SharedAsyncSleep};
use aws_sdk_s3::error::SdkError;
use aws_sdk_s3::operation::get_object::GetObjectError;
use aws_sdk_s3::operation::head_object::HeadObjectError;
use aws_sdk_s3::types::{Delete, DeleteMarkerEntry, ObjectIdentifier, ObjectVersion, StorageClass};
use aws_sdk_s3::types::{Delete, ObjectIdentifier, StorageClass};
use aws_smithy_async::rt::sleep::TokioSleep;
use aws_smithy_types::DateTime;
use aws_smithy_types::body::SdkBody;
use aws_smithy_types::byte_stream::ByteStream;
use aws_smithy_types::date_time::ConversionError;
@@ -46,7 +45,7 @@ use crate::support::PermitCarrying;
use crate::{
ConcurrencyLimiter, Download, DownloadError, DownloadOpts, Listing, ListingMode, ListingObject,
MAX_KEYS_PER_DELETE_S3, REMOTE_STORAGE_PREFIX_SEPARATOR, RemotePath, RemoteStorage,
TimeTravelError, TimeoutOrCancel,
TimeTravelError, TimeoutOrCancel, Version, VersionId, VersionKind, VersionListing,
};
/// AWS S3 storage.
@@ -66,6 +65,7 @@ struct GetObjectRequest {
key: String,
etag: Option<String>,
range: Option<String>,
version_id: Option<String>,
}
impl S3Bucket {
/// Creates the S3 storage, errors if incorrect AWS S3 configuration provided.
@@ -251,6 +251,7 @@ impl S3Bucket {
.get_object()
.bucket(request.bucket)
.key(request.key)
.set_version_id(request.version_id)
.set_range(request.range);
if let Some(etag) = request.etag {
@@ -405,6 +406,124 @@ impl S3Bucket {
Ok(())
}
async fn list_versions_with_permit(
&self,
_permit: &tokio::sync::SemaphorePermit<'_>,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
cancel: &CancellationToken,
) -> Result<crate::VersionListing, DownloadError> {
// get the passed prefix or if it is not set use prefix_in_bucket value
let prefix = prefix
.map(|p| self.relative_path_to_s3_object(p))
.or_else(|| self.prefix_in_bucket.clone());
let warn_threshold = 3;
let max_retries = 10;
let is_permanent = |e: &_| matches!(e, DownloadError::Cancelled);
let mut key_marker = None;
let mut version_id_marker = None;
let mut versions_and_deletes = Vec::new();
loop {
let response = backoff::retry(
|| async {
let mut request = self
.client
.list_object_versions()
.bucket(self.bucket_name.clone())
.set_prefix(prefix.clone())
.set_key_marker(key_marker.clone())
.set_version_id_marker(version_id_marker.clone());
if let ListingMode::WithDelimiter = mode {
request = request.delimiter(REMOTE_STORAGE_PREFIX_SEPARATOR.to_string());
}
let op = request.send();
tokio::select! {
res = op => res.map_err(|e| DownloadError::Other(e.into())),
_ = cancel.cancelled() => Err(DownloadError::Cancelled),
}
},
is_permanent,
warn_threshold,
max_retries,
"listing object versions",
cancel,
)
.await
.ok_or_else(|| DownloadError::Cancelled)
.and_then(|x| x)?;
tracing::trace!(
" Got List response version_id_marker={:?}, key_marker={:?}",
response.version_id_marker,
response.key_marker
);
let versions = response
.versions
.unwrap_or_default()
.into_iter()
.map(|version| {
let key = version.key.expect("response does not contain a key");
let key = self.s3_object_to_relative_path(&key);
let version_id = VersionId(version.version_id.expect("needing version id"));
let last_modified =
SystemTime::try_from(version.last_modified.expect("no last_modified"))?;
Ok(Version {
key,
last_modified,
kind: crate::VersionKind::Version(version_id),
})
});
let deletes = response
.delete_markers
.unwrap_or_default()
.into_iter()
.map(|version| {
let key = version.key.expect("response does not contain a key");
let key = self.s3_object_to_relative_path(&key);
let last_modified =
SystemTime::try_from(version.last_modified.expect("no last_modified"))?;
Ok(Version {
key,
last_modified,
kind: crate::VersionKind::DeletionMarker,
})
});
itertools::process_results(versions.chain(deletes), |n_vds| {
versions_and_deletes.extend(n_vds)
})
.map_err(DownloadError::Other)?;
fn none_if_empty(v: Option<String>) -> Option<String> {
v.filter(|v| !v.is_empty())
}
version_id_marker = none_if_empty(response.next_version_id_marker);
key_marker = none_if_empty(response.next_key_marker);
if version_id_marker.is_none() {
// The final response is not supposed to be truncated
if response.is_truncated.unwrap_or_default() {
return Err(DownloadError::Other(anyhow::anyhow!(
"Received truncated ListObjectVersions response for prefix={prefix:?}"
)));
}
break;
}
if let Some(max_keys) = max_keys {
if versions_and_deletes.len() >= max_keys.get().try_into().unwrap() {
return Err(DownloadError::Other(anyhow::anyhow!("too many versions")));
}
}
}
Ok(VersionListing {
versions: versions_and_deletes,
})
}
pub fn bucket_name(&self) -> &str {
&self.bucket_name
}
@@ -621,6 +740,19 @@ impl RemoteStorage for S3Bucket {
}
}
async fn list_versions(
&self,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
cancel: &CancellationToken,
) -> Result<crate::VersionListing, DownloadError> {
let kind = RequestKind::ListVersions;
let permit = self.permit(kind, cancel).await?;
self.list_versions_with_permit(&permit, prefix, mode, max_keys, cancel)
.await
}
async fn head_object(
&self,
key: &RemotePath,
@@ -801,6 +933,7 @@ impl RemoteStorage for S3Bucket {
key: self.relative_path_to_s3_object(from),
etag: opts.etag.as_ref().map(|e| e.to_string()),
range: opts.byte_range_header(),
version_id: opts.version_id.as_ref().map(|v| v.0.to_owned()),
},
cancel,
)
@@ -845,94 +978,25 @@ impl RemoteStorage for S3Bucket {
let kind = RequestKind::TimeTravel;
let permit = self.permit(kind, cancel).await?;
let timestamp = DateTime::from(timestamp);
let done_if_after = DateTime::from(done_if_after);
tracing::trace!("Target time: {timestamp:?}, done_if_after {done_if_after:?}");
// get the passed prefix or if it is not set use prefix_in_bucket value
let prefix = prefix
.map(|p| self.relative_path_to_s3_object(p))
.or_else(|| self.prefix_in_bucket.clone());
// Limit the number of versions deletions, mostly so that we don't
// keep requesting forever if the list is too long, as we'd put the
// list in RAM.
// Building a list of 100k entries that reaches the limit roughly takes
// 40 seconds, and roughly corresponds to tenants of 2 TiB physical size.
const COMPLEXITY_LIMIT: Option<NonZeroU32> = NonZeroU32::new(100_000);
let warn_threshold = 3;
let max_retries = 10;
let is_permanent = |e: &_| matches!(e, TimeTravelError::Cancelled);
let mut key_marker = None;
let mut version_id_marker = None;
let mut versions_and_deletes = Vec::new();
loop {
let response = backoff::retry(
|| async {
let op = self
.client
.list_object_versions()
.bucket(self.bucket_name.clone())
.set_prefix(prefix.clone())
.set_key_marker(key_marker.clone())
.set_version_id_marker(version_id_marker.clone())
.send();
tokio::select! {
res = op => res.map_err(|e| TimeTravelError::Other(e.into())),
_ = cancel.cancelled() => Err(TimeTravelError::Cancelled),
}
},
is_permanent,
warn_threshold,
max_retries,
"listing object versions for time_travel_recover",
cancel,
)
let mode = ListingMode::NoDelimiter;
let version_listing = self
.list_versions_with_permit(&permit, prefix, mode, COMPLEXITY_LIMIT, cancel)
.await
.ok_or_else(|| TimeTravelError::Cancelled)
.and_then(|x| x)?;
tracing::trace!(
" Got List response version_id_marker={:?}, key_marker={:?}",
response.version_id_marker,
response.key_marker
);
let versions = response
.versions
.unwrap_or_default()
.into_iter()
.map(VerOrDelete::from_version);
let deletes = response
.delete_markers
.unwrap_or_default()
.into_iter()
.map(VerOrDelete::from_delete_marker);
itertools::process_results(versions.chain(deletes), |n_vds| {
versions_and_deletes.extend(n_vds)
})
.map_err(TimeTravelError::Other)?;
fn none_if_empty(v: Option<String>) -> Option<String> {
v.filter(|v| !v.is_empty())
}
version_id_marker = none_if_empty(response.next_version_id_marker);
key_marker = none_if_empty(response.next_key_marker);
if version_id_marker.is_none() {
// The final response is not supposed to be truncated
if response.is_truncated.unwrap_or_default() {
return Err(TimeTravelError::Other(anyhow::anyhow!(
"Received truncated ListObjectVersions response for prefix={prefix:?}"
)));
}
break;
}
// Limit the number of versions deletions, mostly so that we don't
// keep requesting forever if the list is too long, as we'd put the
// list in RAM.
// Building a list of 100k entries that reaches the limit roughly takes
// 40 seconds, and roughly corresponds to tenants of 2 TiB physical size.
const COMPLEXITY_LIMIT: usize = 100_000;
if versions_and_deletes.len() >= COMPLEXITY_LIMIT {
return Err(TimeTravelError::TooManyVersions);
}
}
.map_err(|err| match err {
DownloadError::Other(e) => TimeTravelError::Other(e),
DownloadError::Cancelled => TimeTravelError::Cancelled,
other => TimeTravelError::Other(other.into()),
})?;
let versions_and_deletes = version_listing.versions;
tracing::info!(
"Built list for time travel with {} versions and deletions",
@@ -948,24 +1012,26 @@ impl RemoteStorage for S3Bucket {
let mut vds_for_key = HashMap::<_, Vec<_>>::new();
for vd in &versions_and_deletes {
let VerOrDelete {
version_id, key, ..
} = &vd;
if version_id == "null" {
let Version { key, .. } = &vd;
let version_id = vd.version_id().map(|v| v.0.as_str());
if version_id == Some("null") {
return Err(TimeTravelError::Other(anyhow!(
"Received ListVersions response for key={key} with version_id='null', \
indicating either disabled versioning, or legacy objects with null version id values"
)));
}
tracing::trace!(
"Parsing version key={key} version_id={version_id} kind={:?}",
vd.kind
);
tracing::trace!("Parsing version key={key} kind={:?}", vd.kind);
vds_for_key.entry(key).or_default().push(vd);
}
let warn_threshold = 3;
let max_retries = 10;
let is_permanent = |e: &_| matches!(e, TimeTravelError::Cancelled);
for (key, versions) in vds_for_key {
let last_vd = versions.last().unwrap();
let key = self.relative_path_to_s3_object(key);
if last_vd.last_modified > done_if_after {
tracing::trace!("Key {key} has version later than done_if_after, skipping");
continue;
@@ -990,11 +1056,11 @@ impl RemoteStorage for S3Bucket {
do_delete = true;
} else {
match &versions[version_to_restore_to - 1] {
VerOrDelete {
kind: VerOrDeleteKind::Version,
version_id,
Version {
kind: VersionKind::Version(version_id),
..
} => {
let version_id = &version_id.0;
tracing::trace!("Copying old version {version_id} for {key}...");
// Restore the state to the last version by copying
let source_id =
@@ -1006,7 +1072,7 @@ impl RemoteStorage for S3Bucket {
.client
.copy_object()
.bucket(self.bucket_name.clone())
.key(key)
.key(&key)
.set_storage_class(self.upload_storage_class.clone())
.copy_source(&source_id)
.send();
@@ -1027,8 +1093,8 @@ impl RemoteStorage for S3Bucket {
.and_then(|x| x)?;
tracing::info!(%version_id, %key, "Copied old version in S3");
}
VerOrDelete {
kind: VerOrDeleteKind::DeleteMarker,
Version {
kind: VersionKind::DeletionMarker,
..
} => {
do_delete = true;
@@ -1036,7 +1102,7 @@ impl RemoteStorage for S3Bucket {
}
};
if do_delete {
if matches!(last_vd.kind, VerOrDeleteKind::DeleteMarker) {
if matches!(last_vd.kind, VersionKind::DeletionMarker) {
// Key has since been deleted (but there was some history), no need to do anything
tracing::trace!("Key {key} already deleted, skipping.");
} else {
@@ -1064,62 +1130,6 @@ impl RemoteStorage for S3Bucket {
}
}
// Save RAM and only store the needed data instead of the entire ObjectVersion/DeleteMarkerEntry
struct VerOrDelete {
kind: VerOrDeleteKind,
last_modified: DateTime,
version_id: String,
key: String,
}
#[derive(Debug)]
enum VerOrDeleteKind {
Version,
DeleteMarker,
}
impl VerOrDelete {
fn with_kind(
kind: VerOrDeleteKind,
last_modified: Option<DateTime>,
version_id: Option<String>,
key: Option<String>,
) -> anyhow::Result<Self> {
let lvk = (last_modified, version_id, key);
let (Some(last_modified), Some(version_id), Some(key)) = lvk else {
anyhow::bail!(
"One (or more) of last_modified, key, and id is None. \
Is versioning enabled in the bucket? last_modified={:?}, version_id={:?}, key={:?}",
lvk.0,
lvk.1,
lvk.2,
);
};
Ok(Self {
kind,
last_modified,
version_id,
key,
})
}
fn from_version(v: ObjectVersion) -> anyhow::Result<Self> {
Self::with_kind(
VerOrDeleteKind::Version,
v.last_modified,
v.version_id,
v.key,
)
}
fn from_delete_marker(v: DeleteMarkerEntry) -> anyhow::Result<Self> {
Self::with_kind(
VerOrDeleteKind::DeleteMarker,
v.last_modified,
v.version_id,
v.key,
)
}
}
#[cfg(test)]
mod tests {
use std::num::NonZeroUsize;

View File

@@ -139,6 +139,20 @@ impl RemoteStorage for UnreliableWrapper {
self.inner.list(prefix, mode, max_keys, cancel).await
}
async fn list_versions(
&self,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
cancel: &CancellationToken,
) -> Result<crate::VersionListing, DownloadError> {
self.attempt(RemoteOp::ListPrefixes(prefix.cloned()))
.map_err(DownloadError::Other)?;
self.inner
.list_versions(prefix, mode, max_keys, cancel)
.await
}
async fn head_object(
&self,
key: &RemotePath,

View File

@@ -106,6 +106,7 @@ hex-literal.workspace = true
tokio = { workspace = true, features = ["process", "sync", "fs", "rt", "io-util", "time", "test-util"] }
indoc.workspace = true
uuid.workspace = true
rstest.workspace = true
[[bench]]
name = "bench_layer_map"

View File

@@ -11,6 +11,7 @@ use pageserver::task_mgr::TaskKind;
use pageserver::tenant::storage_layer::InMemoryLayer;
use pageserver::{page_cache, virtual_file};
use pageserver_api::key::Key;
use pageserver_api::models::virtual_file::IoMode;
use pageserver_api::shard::TenantShardId;
use pageserver_api::value::Value;
use tokio_util::sync::CancellationToken;
@@ -28,6 +29,7 @@ fn murmurhash32(mut h: u32) -> u32 {
h
}
#[derive(serde::Serialize, Clone, Copy, Debug)]
enum KeyLayout {
/// Sequential unique keys
Sequential,
@@ -37,6 +39,7 @@ enum KeyLayout {
RandomReuse(u32),
}
#[derive(serde::Serialize, Clone, Copy, Debug)]
enum WriteDelta {
Yes,
No,
@@ -138,12 +141,15 @@ async fn ingest(
/// Wrapper to instantiate a tokio runtime
fn ingest_main(
conf: &'static PageServerConf,
io_mode: IoMode,
put_size: usize,
put_count: usize,
key_layout: KeyLayout,
write_delta: WriteDelta,
) {
let runtime = tokio::runtime::Builder::new_current_thread()
pageserver::virtual_file::set_io_mode(io_mode);
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
@@ -174,93 +180,245 @@ fn criterion_benchmark(c: &mut Criterion) {
virtual_file::init(
16384,
virtual_file::io_engine_for_bench(),
// immaterial, each `ingest_main` invocation below overrides this
conf.virtual_file_io_mode,
// without actually doing syncs, buffered writes have an unfair advantage over direct IO writes
virtual_file::SyncMode::Sync,
);
page_cache::init(conf.page_cache_size);
{
let mut group = c.benchmark_group("ingest-small-values");
let put_size = 100usize;
let put_count = 128 * 1024 * 1024 / put_size;
group.throughput(criterion::Throughput::Bytes((put_size * put_count) as u64));
group.sample_size(10);
group.bench_function("ingest 128MB/100b seq", |b| {
b.iter(|| {
ingest_main(
conf,
put_size,
put_count,
KeyLayout::Sequential,
WriteDelta::Yes,
)
})
});
group.bench_function("ingest 128MB/100b rand", |b| {
b.iter(|| {
ingest_main(
conf,
put_size,
put_count,
KeyLayout::Random,
WriteDelta::Yes,
)
})
});
group.bench_function("ingest 128MB/100b rand-1024keys", |b| {
b.iter(|| {
ingest_main(
conf,
put_size,
put_count,
KeyLayout::RandomReuse(0x3ff),
WriteDelta::Yes,
)
})
});
group.bench_function("ingest 128MB/100b seq, no delta", |b| {
b.iter(|| {
ingest_main(
conf,
put_size,
put_count,
KeyLayout::Sequential,
WriteDelta::No,
)
})
});
#[derive(serde::Serialize)]
struct ExplodedParameters {
io_mode: IoMode,
volume_mib: usize,
key_size: usize,
key_layout: KeyLayout,
write_delta: WriteDelta,
}
{
let mut group = c.benchmark_group("ingest-big-values");
let put_size = 8192usize;
let put_count = 128 * 1024 * 1024 / put_size;
group.throughput(criterion::Throughput::Bytes((put_size * put_count) as u64));
#[derive(Clone)]
struct HandPickedParameters {
volume_mib: usize,
key_size: usize,
key_layout: KeyLayout,
write_delta: WriteDelta,
}
let expect = vec![
// Small values (100b) tests
HandPickedParameters {
volume_mib: 128,
key_size: 100,
key_layout: KeyLayout::Sequential,
write_delta: WriteDelta::Yes,
},
HandPickedParameters {
volume_mib: 128,
key_size: 100,
key_layout: KeyLayout::Random,
write_delta: WriteDelta::Yes,
},
HandPickedParameters {
volume_mib: 128,
key_size: 100,
key_layout: KeyLayout::RandomReuse(0x3ff),
write_delta: WriteDelta::Yes,
},
HandPickedParameters {
volume_mib: 128,
key_size: 100,
key_layout: KeyLayout::Sequential,
write_delta: WriteDelta::No,
},
// Large values (8k) tests
HandPickedParameters {
volume_mib: 128,
key_size: 8192,
key_layout: KeyLayout::Sequential,
write_delta: WriteDelta::Yes,
},
HandPickedParameters {
volume_mib: 128,
key_size: 8192,
key_layout: KeyLayout::Sequential,
write_delta: WriteDelta::No,
},
];
let exploded_parameters = {
let mut out = Vec::new();
for io_mode in [
IoMode::Buffered,
#[cfg(target_os = "linux")]
IoMode::Direct,
#[cfg(target_os = "linux")]
IoMode::DirectRw,
] {
for param in expect.clone() {
let HandPickedParameters {
volume_mib,
key_size,
key_layout,
write_delta,
} = param;
out.push(ExplodedParameters {
io_mode,
volume_mib,
key_size,
key_layout,
write_delta,
});
}
}
out
};
impl ExplodedParameters {
fn benchmark_id(&self) -> String {
let ExplodedParameters {
io_mode,
volume_mib,
key_size,
key_layout,
write_delta,
} = self;
format!(
"io_mode={io_mode:?} volume_mib={volume_mib:?} key_size_bytes={key_size:?} key_layout={key_layout:?} write_delta={write_delta:?}"
)
}
}
let mut group = c.benchmark_group("ingest");
for params in exploded_parameters {
let id = params.benchmark_id();
let ExplodedParameters {
io_mode,
volume_mib,
key_size,
key_layout,
write_delta,
} = params;
let put_count = volume_mib * 1024 * 1024 / key_size;
group.throughput(criterion::Throughput::Bytes((key_size * put_count) as u64));
group.sample_size(10);
group.bench_function("ingest 128MB/8k seq", |b| {
b.iter(|| {
ingest_main(
conf,
put_size,
put_count,
KeyLayout::Sequential,
WriteDelta::Yes,
)
})
});
group.bench_function("ingest 128MB/8k seq, no delta", |b| {
b.iter(|| {
ingest_main(
conf,
put_size,
put_count,
KeyLayout::Sequential,
WriteDelta::No,
)
})
group.bench_function(id, |b| {
b.iter(|| ingest_main(conf, io_mode, key_size, put_count, key_layout, write_delta))
});
}
}
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
/*
cargo bench --bench bench_ingest
im4gn.2xlarge:
ingest/io_mode=Buffered volume_mib=128 key_size_bytes=100 key_layout=Sequential write_delta=Yes
time: [1.2901 s 1.2943 s 1.2991 s]
thrpt: [98.533 MiB/s 98.892 MiB/s 99.220 MiB/s]
ingest/io_mode=Buffered volume_mib=128 key_size_bytes=100 key_layout=Random write_delta=Yes
time: [2.1387 s 2.1623 s 2.1845 s]
thrpt: [58.595 MiB/s 59.197 MiB/s 59.851 MiB/s]
ingest/io_mode=Buffered volume_mib=128 key_size_bytes=100 key_layout=RandomReuse(1023) write_delta=Y...
time: [1.2036 s 1.2074 s 1.2122 s]
thrpt: [105.60 MiB/s 106.01 MiB/s 106.35 MiB/s]
ingest/io_mode=Buffered volume_mib=128 key_size_bytes=100 key_layout=Sequential write_delta=No
time: [520.55 ms 521.46 ms 522.57 ms]
thrpt: [244.94 MiB/s 245.47 MiB/s 245.89 MiB/s]
ingest/io_mode=Buffered volume_mib=128 key_size_bytes=8192 key_layout=Sequential write_delta=Yes
time: [440.33 ms 442.24 ms 444.10 ms]
thrpt: [288.22 MiB/s 289.43 MiB/s 290.69 MiB/s]
ingest/io_mode=Buffered volume_mib=128 key_size_bytes=8192 key_layout=Sequential write_delta=No
time: [168.78 ms 169.42 ms 170.18 ms]
thrpt: [752.16 MiB/s 755.52 MiB/s 758.40 MiB/s]
ingest/io_mode=Direct volume_mib=128 key_size_bytes=100 key_layout=Sequential write_delta=Yes
time: [1.2978 s 1.3094 s 1.3227 s]
thrpt: [96.775 MiB/s 97.758 MiB/s 98.632 MiB/s]
ingest/io_mode=Direct volume_mib=128 key_size_bytes=100 key_layout=Random write_delta=Yes
time: [2.1976 s 2.2067 s 2.2154 s]
thrpt: [57.777 MiB/s 58.006 MiB/s 58.245 MiB/s]
ingest/io_mode=Direct volume_mib=128 key_size_bytes=100 key_layout=RandomReuse(1023) write_delta=Yes
time: [1.2103 s 1.2160 s 1.2233 s]
thrpt: [104.64 MiB/s 105.26 MiB/s 105.76 MiB/s]
ingest/io_mode=Direct volume_mib=128 key_size_bytes=100 key_layout=Sequential write_delta=No
time: [525.05 ms 526.37 ms 527.79 ms]
thrpt: [242.52 MiB/s 243.17 MiB/s 243.79 MiB/s]
ingest/io_mode=Direct volume_mib=128 key_size_bytes=8192 key_layout=Sequential write_delta=Yes
time: [443.06 ms 444.88 ms 447.15 ms]
thrpt: [286.26 MiB/s 287.72 MiB/s 288.90 MiB/s]
ingest/io_mode=Direct volume_mib=128 key_size_bytes=8192 key_layout=Sequential write_delta=No
time: [169.40 ms 169.80 ms 170.17 ms]
thrpt: [752.21 MiB/s 753.81 MiB/s 755.60 MiB/s]
ingest/io_mode=DirectRw volume_mib=128 key_size_bytes=100 key_layout=Sequential write_delta=Yes
time: [1.2844 s 1.2915 s 1.2990 s]
thrpt: [98.536 MiB/s 99.112 MiB/s 99.657 MiB/s]
ingest/io_mode=DirectRw volume_mib=128 key_size_bytes=100 key_layout=Random write_delta=Yes
time: [2.1431 s 2.1663 s 2.1900 s]
thrpt: [58.446 MiB/s 59.087 MiB/s 59.726 MiB/s]
ingest/io_mode=DirectRw volume_mib=128 key_size_bytes=100 key_layout=RandomReuse(1023) write_delta=Y...
time: [1.1906 s 1.1926 s 1.1947 s]
thrpt: [107.14 MiB/s 107.33 MiB/s 107.51 MiB/s]
ingest/io_mode=DirectRw volume_mib=128 key_size_bytes=100 key_layout=Sequential write_delta=No
time: [516.86 ms 518.25 ms 519.47 ms]
thrpt: [246.40 MiB/s 246.98 MiB/s 247.65 MiB/s]
ingest/io_mode=DirectRw volume_mib=128 key_size_bytes=8192 key_layout=Sequential write_delta=Yes
time: [536.50 ms 536.53 ms 536.60 ms]
thrpt: [238.54 MiB/s 238.57 MiB/s 238.59 MiB/s]
ingest/io_mode=DirectRw volume_mib=128 key_size_bytes=8192 key_layout=Sequential write_delta=No
time: [267.77 ms 267.90 ms 268.04 ms]
thrpt: [477.53 MiB/s 477.79 MiB/s 478.02 MiB/s]
Hetzner AX102:
ingest/io_mode=Buffered volume_mib=128 key_size_bytes=100 key_layout=Sequential write_delta=Yes
time: [836.58 ms 861.93 ms 886.57 ms]
thrpt: [144.38 MiB/s 148.50 MiB/s 153.00 MiB/s]
ingest/io_mode=Buffered volume_mib=128 key_size_bytes=100 key_layout=Random write_delta=Yes
time: [1.2782 s 1.3191 s 1.3665 s]
thrpt: [93.668 MiB/s 97.037 MiB/s 100.14 MiB/s]
ingest/io_mode=Buffered volume_mib=128 key_size_bytes=100 key_layout=RandomReuse(1023) write_delta=Y...
time: [791.27 ms 807.08 ms 822.95 ms]
thrpt: [155.54 MiB/s 158.60 MiB/s 161.77 MiB/s]
ingest/io_mode=Buffered volume_mib=128 key_size_bytes=100 key_layout=Sequential write_delta=No
time: [310.78 ms 314.66 ms 318.47 ms]
thrpt: [401.92 MiB/s 406.79 MiB/s 411.87 MiB/s]
ingest/io_mode=Buffered volume_mib=128 key_size_bytes=8192 key_layout=Sequential write_delta=Yes
time: [377.11 ms 387.77 ms 399.21 ms]
thrpt: [320.63 MiB/s 330.10 MiB/s 339.42 MiB/s]
ingest/io_mode=Buffered volume_mib=128 key_size_bytes=8192 key_layout=Sequential write_delta=No
time: [128.37 ms 132.96 ms 138.55 ms]
thrpt: [923.83 MiB/s 962.69 MiB/s 997.11 MiB/s]
ingest/io_mode=Direct volume_mib=128 key_size_bytes=100 key_layout=Sequential write_delta=Yes
time: [900.38 ms 914.88 ms 928.86 ms]
thrpt: [137.80 MiB/s 139.91 MiB/s 142.16 MiB/s]
ingest/io_mode=Direct volume_mib=128 key_size_bytes=100 key_layout=Random write_delta=Yes
time: [1.2538 s 1.2936 s 1.3313 s]
thrpt: [96.149 MiB/s 98.946 MiB/s 102.09 MiB/s]
ingest/io_mode=Direct volume_mib=128 key_size_bytes=100 key_layout=RandomReuse(1023) write_delta=Yes
time: [787.17 ms 803.89 ms 820.63 ms]
thrpt: [155.98 MiB/s 159.23 MiB/s 162.61 MiB/s]
ingest/io_mode=Direct volume_mib=128 key_size_bytes=100 key_layout=Sequential write_delta=No
time: [318.78 ms 321.89 ms 324.74 ms]
thrpt: [394.16 MiB/s 397.65 MiB/s 401.53 MiB/s]
ingest/io_mode=Direct volume_mib=128 key_size_bytes=8192 key_layout=Sequential write_delta=Yes
time: [374.01 ms 383.45 ms 393.20 ms]
thrpt: [325.53 MiB/s 333.81 MiB/s 342.24 MiB/s]
ingest/io_mode=Direct volume_mib=128 key_size_bytes=8192 key_layout=Sequential write_delta=No
time: [137.98 ms 141.31 ms 143.57 ms]
thrpt: [891.58 MiB/s 905.79 MiB/s 927.66 MiB/s]
ingest/io_mode=DirectRw volume_mib=128 key_size_bytes=100 key_layout=Sequential write_delta=Yes
time: [613.69 ms 622.48 ms 630.97 ms]
thrpt: [202.86 MiB/s 205.63 MiB/s 208.57 MiB/s]
ingest/io_mode=DirectRw volume_mib=128 key_size_bytes=100 key_layout=Random write_delta=Yes
time: [1.0299 s 1.0766 s 1.1273 s]
thrpt: [113.55 MiB/s 118.90 MiB/s 124.29 MiB/s]
ingest/io_mode=DirectRw volume_mib=128 key_size_bytes=100 key_layout=RandomReuse(1023) write_delta=Y...
time: [637.80 ms 647.78 ms 658.01 ms]
thrpt: [194.53 MiB/s 197.60 MiB/s 200.69 MiB/s]
ingest/io_mode=DirectRw volume_mib=128 key_size_bytes=100 key_layout=Sequential write_delta=No
time: [266.09 ms 267.20 ms 268.31 ms]
thrpt: [477.06 MiB/s 479.04 MiB/s 481.04 MiB/s]
ingest/io_mode=DirectRw volume_mib=128 key_size_bytes=8192 key_layout=Sequential write_delta=Yes
time: [269.34 ms 273.27 ms 277.69 ms]
thrpt: [460.95 MiB/s 468.40 MiB/s 475.24 MiB/s]
ingest/io_mode=DirectRw volume_mib=128 key_size_bytes=8192 key_layout=Sequential write_delta=No
time: [123.18 ms 124.24 ms 125.15 ms]
thrpt: [1022.8 MiB/s 1.0061 GiB/s 1.0148 GiB/s]
*/

View File

@@ -419,6 +419,23 @@ impl Client {
}
}
pub async fn timeline_detail(
&self,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
) -> Result<TimelineInfo> {
let uri = format!(
"{}/v1/tenant/{tenant_shard_id}/timeline/{timeline_id}",
self.mgmt_api_endpoint
);
self.request(Method::GET, &uri, ())
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
pub async fn timeline_archival_config(
&self,
tenant_shard_id: TenantShardId,

View File

@@ -225,6 +225,11 @@ pub struct PageServerConf {
/// Does not force TLS: the client negotiates TLS usage during the handshake.
/// Uses key and certificate from ssl_key_file/ssl_cert_file.
pub enable_tls_page_service_api: bool,
/// Run in development mode, which disables certain safety checks
/// such as authentication requirements for HTTP and PostgreSQL APIs.
/// This is insecure and should only be used in development environments.
pub dev_mode: bool,
}
/// Token for authentication to safekeepers
@@ -398,6 +403,7 @@ impl PageServerConf {
generate_unarchival_heatmap,
tracing,
enable_tls_page_service_api,
dev_mode,
} = config_toml;
let mut conf = PageServerConf {
@@ -449,6 +455,7 @@ impl PageServerConf {
get_vectored_concurrent_io,
tracing,
enable_tls_page_service_api,
dev_mode,
// ------------------------------------------------------------
// fields that require additional validation or custom handling

View File

@@ -3,10 +3,11 @@ use std::collections::HashMap;
use futures::Future;
use pageserver_api::config::NodeMetadata;
use pageserver_api::controller_api::{AvailabilityZone, NodeRegisterRequest};
use pageserver_api::models::ShardImportStatus;
use pageserver_api::shard::TenantShardId;
use pageserver_api::upcall_api::{
ReAttachRequest, ReAttachResponse, ReAttachResponseTenant, ValidateRequest,
ValidateRequestTenant, ValidateResponse,
PutTimelineImportStatusRequest, ReAttachRequest, ReAttachResponse, ReAttachResponseTenant,
ValidateRequest, ValidateRequestTenant, ValidateResponse,
};
use reqwest::Certificate;
use serde::Serialize;
@@ -14,7 +15,7 @@ use serde::de::DeserializeOwned;
use tokio_util::sync::CancellationToken;
use url::Url;
use utils::generation::Generation;
use utils::id::NodeId;
use utils::id::{NodeId, TimelineId};
use utils::{backoff, failpoint_support};
use crate::config::PageServerConf;
@@ -46,6 +47,12 @@ pub trait StorageControllerUpcallApi {
&self,
tenants: Vec<(TenantShardId, Generation)>,
) -> impl Future<Output = Result<HashMap<TenantShardId, bool>, RetryForeverError>> + Send;
fn put_timeline_import_status(
&self,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
status: ShardImportStatus,
) -> impl Future<Output = Result<(), RetryForeverError>> + Send;
}
impl StorageControllerUpcallClient {
@@ -273,4 +280,30 @@ impl StorageControllerUpcallApi for StorageControllerUpcallClient {
Ok(result.into_iter().collect())
}
/// Send a shard import status to the storage controller
///
/// The implementation must have at-least-once delivery semantics.
/// To this end, we retry the request until it succeeds. If the pageserver
/// restarts or crashes, the shard import will start again from the beggining.
#[tracing::instrument(skip_all)] // so that warning logs from retry_http_forever have context
async fn put_timeline_import_status(
&self,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
status: ShardImportStatus,
) -> Result<(), RetryForeverError> {
let url = self
.base_url
.join("timeline_import_status")
.expect("Failed to build path");
let request = PutTimelineImportStatusRequest {
tenant_shard_id,
timeline_id,
status,
};
self.retry_http_forever(&url, request).await
}
}

View File

@@ -787,6 +787,15 @@ mod test {
Ok(result)
}
async fn put_timeline_import_status(
&self,
_tenant_shard_id: TenantShardId,
_timeline_id: TimelineId,
_status: pageserver_api::models::ShardImportStatus,
) -> Result<(), RetryForeverError> {
unimplemented!()
}
}
async fn setup(test_name: &str) -> anyhow::Result<TestSetup> {

View File

@@ -1289,6 +1289,7 @@ pub(crate) enum StorageIoOperation {
Seek,
Fsync,
Metadata,
SetLen,
}
impl StorageIoOperation {
@@ -1303,6 +1304,7 @@ impl StorageIoOperation {
StorageIoOperation::Seek => "seek",
StorageIoOperation::Fsync => "fsync",
StorageIoOperation::Metadata => "metadata",
StorageIoOperation::SetLen => "set_len",
}
}
}

View File

@@ -15,21 +15,23 @@
//! len >= 128: 1CCCXXXX XXXXXXXX XXXXXXXX XXXXXXXX
//!
use std::cmp::min;
use std::io::Error;
use anyhow::Context;
use async_compression::Level;
use bytes::{BufMut, BytesMut};
use pageserver_api::models::ImageCompressionAlgorithm;
use tokio::io::AsyncWriteExt;
use tokio_epoll_uring::{BoundedBuf, IoBuf, Slice};
use tokio_epoll_uring::IoBuf;
use tokio_util::sync::CancellationToken;
use tracing::warn;
use crate::context::RequestContext;
use crate::page_cache::PAGE_SZ;
use crate::tenant::block_io::BlockCursor;
use crate::virtual_file::VirtualFile;
use crate::virtual_file::IoBufferMut;
use crate::virtual_file::owned_buffers_io::io_buf_ext::{FullSlice, IoBufExt};
use crate::virtual_file::owned_buffers_io::write::{BufferedWriter, FlushTaskError};
use crate::virtual_file::owned_buffers_io::write::{BufferedWriterShutdownMode, OwnedAsyncWriter};
#[derive(Copy, Clone, Debug)]
pub struct CompressionInfo {
@@ -50,12 +52,9 @@ pub struct Header {
impl Header {
/// Decodes a header from a byte slice.
pub fn decode(bytes: &[u8]) -> Result<Self, std::io::Error> {
pub fn decode(bytes: &[u8]) -> anyhow::Result<Self> {
let Some(&first_header_byte) = bytes.first() else {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"zero-length blob header",
));
anyhow::bail!("zero-length blob header");
};
// If the first bit is 0, this is just a 1-byte length prefix up to 128 bytes.
@@ -69,12 +68,9 @@ impl Header {
// Otherwise, this is a 4-byte header containing compression information and length.
const HEADER_LEN: usize = 4;
let mut header_buf: [u8; HEADER_LEN] = bytes[0..HEADER_LEN].try_into().map_err(|_| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("blob header too short: {bytes:?}"),
)
})?;
let mut header_buf: [u8; HEADER_LEN] = bytes[0..HEADER_LEN]
.try_into()
.map_err(|_| anyhow::anyhow!("blob header too short: {bytes:?}"))?;
// TODO: verify the compression bits and convert to an enum.
let compression_bits = header_buf[0] & LEN_COMPRESSION_BIT_MASK;
@@ -94,6 +90,16 @@ impl Header {
}
}
#[derive(Debug, thiserror::Error)]
pub enum WriteBlobError {
#[error(transparent)]
Flush(FlushTaskError),
#[error("blob too large ({len} bytes)")]
BlobTooLarge { len: usize },
#[error(transparent)]
WriteBlobRaw(anyhow::Error),
}
impl BlockCursor<'_> {
/// Read a blob into a new buffer.
pub async fn read_blob(
@@ -213,143 +219,64 @@ pub(super) const BYTE_UNCOMPRESSED: u8 = 0x80;
pub(super) const BYTE_ZSTD: u8 = BYTE_UNCOMPRESSED | 0x10;
/// A wrapper of `VirtualFile` that allows users to write blobs.
///
/// If a `BlobWriter` is dropped, the internal buffer will be
/// discarded. You need to call [`flush_buffer`](Self::flush_buffer)
/// manually before dropping.
pub struct BlobWriter<const BUFFERED: bool> {
inner: VirtualFile,
offset: u64,
/// A buffer to save on write calls, only used if BUFFERED=true
buf: Vec<u8>,
pub struct BlobWriter<W> {
/// We do tiny writes for the length headers; they need to be in an owned buffer;
io_buf: Option<BytesMut>,
writer: BufferedWriter<IoBufferMut, W>,
offset: u64,
}
impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
impl<W> BlobWriter<W>
where
W: OwnedAsyncWriter + std::fmt::Debug + Send + Sync + 'static,
{
/// See [`BufferedWriter`] struct-level doc comment for semantics of `start_offset`.
pub fn new(
inner: VirtualFile,
file: W,
start_offset: u64,
_gate: &utils::sync::gate::Gate,
_cancel: CancellationToken,
_ctx: &RequestContext,
) -> Self {
Self {
inner,
offset: start_offset,
buf: Vec::with_capacity(Self::CAPACITY),
gate: &utils::sync::gate::Gate,
cancel: CancellationToken,
ctx: &RequestContext,
flush_task_span: tracing::Span,
) -> anyhow::Result<Self> {
Ok(Self {
io_buf: Some(BytesMut::new()),
}
writer: BufferedWriter::new(
file,
start_offset,
|| IoBufferMut::with_capacity(Self::CAPACITY),
gate.enter()?,
cancel,
ctx,
flush_task_span,
),
offset: start_offset,
})
}
pub fn size(&self) -> u64 {
self.offset
}
const CAPACITY: usize = if BUFFERED { 64 * 1024 } else { 0 };
const CAPACITY: usize = 64 * 1024;
/// Writes the given buffer directly to the underlying `VirtualFile`.
/// You need to make sure that the internal buffer is empty, otherwise
/// data will be written in wrong order.
#[inline(always)]
async fn write_all_unbuffered<Buf: IoBuf + Send>(
&mut self,
src_buf: FullSlice<Buf>,
ctx: &RequestContext,
) -> (FullSlice<Buf>, Result<(), Error>) {
let (src_buf, res) = self.inner.write_all(src_buf, ctx).await;
let nbytes = match res {
Ok(nbytes) => nbytes,
Err(e) => return (src_buf, Err(e)),
};
self.offset += nbytes as u64;
(src_buf, Ok(()))
}
#[inline(always)]
/// Flushes the internal buffer to the underlying `VirtualFile`.
pub async fn flush_buffer(&mut self, ctx: &RequestContext) -> Result<(), Error> {
let buf = std::mem::take(&mut self.buf);
let (slice, res) = self.inner.write_all(buf.slice_len(), ctx).await;
res?;
let mut buf = slice.into_raw_slice().into_inner();
buf.clear();
self.buf = buf;
Ok(())
}
#[inline(always)]
/// Writes as much of `src_buf` into the internal buffer as it fits
fn write_into_buffer(&mut self, src_buf: &[u8]) -> usize {
let remaining = Self::CAPACITY - self.buf.len();
let to_copy = src_buf.len().min(remaining);
self.buf.extend_from_slice(&src_buf[..to_copy]);
self.offset += to_copy as u64;
to_copy
}
/// Internal, possibly buffered, write function
/// Writes `src_buf` to the file at the current offset.
async fn write_all<Buf: IoBuf + Send>(
&mut self,
src_buf: FullSlice<Buf>,
ctx: &RequestContext,
) -> (FullSlice<Buf>, Result<(), Error>) {
let src_buf = src_buf.into_raw_slice();
let src_buf_bounds = src_buf.bounds();
let restore = move |src_buf_slice: Slice<_>| {
FullSlice::must_new(Slice::from_buf_bounds(
src_buf_slice.into_inner(),
src_buf_bounds,
))
};
) -> (FullSlice<Buf>, Result<(), FlushTaskError>) {
let res = self
.writer
// TODO: why are we taking a FullSlice if we're going to pass a borrow downstack?
// Can remove all the complexity around owned buffers upstack
.write_buffered_borrowed(&src_buf, ctx)
.await
.map(|len| {
self.offset += len as u64;
});
if !BUFFERED {
assert!(self.buf.is_empty());
return self
.write_all_unbuffered(FullSlice::must_new(src_buf), ctx)
.await;
}
let remaining = Self::CAPACITY - self.buf.len();
let src_buf_len = src_buf.bytes_init();
if src_buf_len == 0 {
return (restore(src_buf), Ok(()));
}
let mut src_buf = src_buf.slice(0..src_buf_len);
// First try to copy as much as we can into the buffer
if remaining > 0 {
let copied = self.write_into_buffer(&src_buf);
src_buf = src_buf.slice(copied..);
}
// Then, if the buffer is full, flush it out
if self.buf.len() == Self::CAPACITY {
if let Err(e) = self.flush_buffer(ctx).await {
return (restore(src_buf), Err(e));
}
}
// Finally, write the tail of src_buf:
// If it wholly fits into the buffer without
// completely filling it, then put it there.
// If not, write it out directly.
let src_buf = if !src_buf.is_empty() {
assert_eq!(self.buf.len(), 0);
if src_buf.len() < Self::CAPACITY {
let copied = self.write_into_buffer(&src_buf);
// We just verified above that src_buf fits into our internal buffer.
assert_eq!(copied, src_buf.len());
restore(src_buf)
} else {
let (src_buf, res) = self
.write_all_unbuffered(FullSlice::must_new(src_buf), ctx)
.await;
if let Err(e) = res {
return (src_buf, Err(e));
}
src_buf
}
} else {
restore(src_buf)
};
(src_buf, Ok(()))
(src_buf, res)
}
/// Write a blob of data. Returns the offset that it was written to,
@@ -358,7 +285,7 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
&mut self,
srcbuf: FullSlice<Buf>,
ctx: &RequestContext,
) -> (FullSlice<Buf>, Result<u64, Error>) {
) -> (FullSlice<Buf>, Result<u64, WriteBlobError>) {
let (buf, res) = self
.write_blob_maybe_compressed(srcbuf, ctx, ImageCompressionAlgorithm::Disabled)
.await;
@@ -372,7 +299,10 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
srcbuf: FullSlice<Buf>,
ctx: &RequestContext,
algorithm: ImageCompressionAlgorithm,
) -> (FullSlice<Buf>, Result<(u64, CompressionInfo), Error>) {
) -> (
FullSlice<Buf>,
Result<(u64, CompressionInfo), WriteBlobError>,
) {
let offset = self.offset;
let mut compression_info = CompressionInfo {
written_compressed: false,
@@ -388,14 +318,16 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
if len < 128 {
// Short blob. Write a 1-byte length header
io_buf.put_u8(len as u8);
(self.write_all(io_buf.slice_len(), ctx).await, srcbuf)
let (slice, res) = self.write_all(io_buf.slice_len(), ctx).await;
let res = res.map_err(WriteBlobError::Flush);
((slice, res), srcbuf)
} else {
// Write a 4-byte length header
if len > MAX_SUPPORTED_BLOB_LEN {
return (
(
io_buf.slice_len(),
Err(Error::other(format!("blob too large ({len} bytes)"))),
Err(WriteBlobError::BlobTooLarge { len }),
),
srcbuf,
);
@@ -429,7 +361,9 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
assert_eq!(len_buf[0] & 0xf0, 0);
len_buf[0] |= high_bit_mask;
io_buf.extend_from_slice(&len_buf[..]);
(self.write_all(io_buf.slice_len(), ctx).await, srcbuf)
let (slice, res) = self.write_all(io_buf.slice_len(), ctx).await;
let res = res.map_err(WriteBlobError::Flush);
((slice, res), srcbuf)
}
}
.await;
@@ -444,6 +378,7 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
} else {
self.write_all(srcbuf, ctx).await
};
let res = res.map_err(WriteBlobError::Flush);
(srcbuf, res.map(|_| (offset, compression_info)))
}
@@ -452,9 +387,12 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
&mut self,
raw_with_header: FullSlice<Buf>,
ctx: &RequestContext,
) -> (FullSlice<Buf>, Result<u64, Error>) {
) -> (FullSlice<Buf>, Result<u64, WriteBlobError>) {
// Verify the header, to ensure we don't write invalid/corrupt data.
let header = match Header::decode(&raw_with_header) {
let header = match Header::decode(&raw_with_header)
.context("decoding blob header")
.map_err(WriteBlobError::WriteBlobRaw)
{
Ok(header) => header,
Err(err) => return (raw_with_header, Err(err)),
};
@@ -463,42 +401,26 @@ impl<const BUFFERED: bool> BlobWriter<BUFFERED> {
let raw_len = raw_with_header.len();
return (
raw_with_header,
Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("header length mismatch: {header_total_len} != {raw_len}"),
)),
Err(WriteBlobError::WriteBlobRaw(anyhow::anyhow!(
"header length mismatch: {header_total_len} != {raw_len}"
))),
);
}
let offset = self.offset;
let (raw_with_header, result) = self.write_all(raw_with_header, ctx).await;
let result = result.map_err(WriteBlobError::Flush);
(raw_with_header, result.map(|_| offset))
}
}
impl BlobWriter<true> {
/// Access the underlying `VirtualFile`.
///
/// This function flushes the internal buffer before giving access
/// to the underlying `VirtualFile`.
pub async fn into_inner(mut self, ctx: &RequestContext) -> Result<VirtualFile, Error> {
self.flush_buffer(ctx).await?;
Ok(self.inner)
}
/// Access the underlying `VirtualFile`.
///
/// Unlike [`into_inner`](Self::into_inner), this doesn't flush
/// the internal buffer before giving access.
pub fn into_inner_no_flush(self) -> VirtualFile {
self.inner
}
}
impl BlobWriter<false> {
/// Access the underlying `VirtualFile`.
pub fn into_inner(self) -> VirtualFile {
self.inner
/// Finish this blob writer and return the underlying `W`.
pub async fn shutdown(
self,
mode: BufferedWriterShutdownMode,
ctx: &RequestContext,
) -> Result<W, FlushTaskError> {
let (_, file) = self.writer.shutdown(mode, ctx).await?;
Ok(file)
}
}
@@ -507,21 +429,25 @@ pub(crate) mod tests {
use camino::Utf8PathBuf;
use camino_tempfile::Utf8TempDir;
use rand::{Rng, SeedableRng};
use tracing::info_span;
use super::*;
use crate::context::DownloadBehavior;
use crate::task_mgr::TaskKind;
use crate::tenant::block_io::BlockReaderRef;
use crate::virtual_file;
use crate::virtual_file::TempVirtualFile;
use crate::virtual_file::VirtualFile;
async fn round_trip_test<const BUFFERED: bool>(blobs: &[Vec<u8>]) -> Result<(), Error> {
round_trip_test_compressed::<BUFFERED>(blobs, false).await
async fn round_trip_test(blobs: &[Vec<u8>]) -> anyhow::Result<()> {
round_trip_test_compressed(blobs, false).await
}
pub(crate) async fn write_maybe_compressed<const BUFFERED: bool>(
pub(crate) async fn write_maybe_compressed(
blobs: &[Vec<u8>],
compression: bool,
ctx: &RequestContext,
) -> Result<(Utf8TempDir, Utf8PathBuf, Vec<u64>), Error> {
) -> anyhow::Result<(Utf8TempDir, Utf8PathBuf, Vec<u64>)> {
let temp_dir = camino_tempfile::tempdir()?;
let pathbuf = temp_dir.path().join("file");
let gate = utils::sync::gate::Gate::default();
@@ -530,8 +456,19 @@ pub(crate) mod tests {
// Write part (in block to drop the file)
let mut offsets = Vec::new();
{
let file = VirtualFile::create(pathbuf.as_path(), ctx).await?;
let mut wtr = BlobWriter::<BUFFERED>::new(file, 0, &gate, cancel.clone(), ctx);
let file = TempVirtualFile::new(
VirtualFile::open_with_options_v2(
pathbuf.as_path(),
virtual_file::OpenOptions::new()
.create_new(true)
.write(true),
ctx,
)
.await?,
gate.enter()?,
);
let mut wtr =
BlobWriter::new(file, 0, &gate, cancel.clone(), ctx, info_span!("test")).unwrap();
for blob in blobs.iter() {
let (_, res) = if compression {
let res = wtr
@@ -548,26 +485,28 @@ pub(crate) mod tests {
let offs = res?;
offsets.push(offs);
}
// Write out one page worth of zeros so that we can
// read again with read_blk
let (_, res) = wtr.write_blob(vec![0; PAGE_SZ].slice_len(), ctx).await;
let offs = res?;
println!("Writing final blob at offs={offs}");
wtr.flush_buffer(ctx).await?;
}
let file = wtr
.shutdown(
BufferedWriterShutdownMode::ZeroPadToNextMultiple(PAGE_SZ),
ctx,
)
.await?;
file.disarm_into_inner()
};
Ok((temp_dir, pathbuf, offsets))
}
async fn round_trip_test_compressed<const BUFFERED: bool>(
async fn round_trip_test_compressed(
blobs: &[Vec<u8>],
compression: bool,
) -> Result<(), Error> {
) -> anyhow::Result<()> {
let ctx =
RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error).with_scope_unit_test();
let (_temp_dir, pathbuf, offsets) =
write_maybe_compressed::<BUFFERED>(blobs, compression, &ctx).await?;
write_maybe_compressed(blobs, compression, &ctx).await?;
let file = VirtualFile::open(pathbuf, &ctx).await?;
println!("Done writing!");
let file = VirtualFile::open_v2(pathbuf, &ctx).await?;
let rdr = BlockReaderRef::VirtualFile(&file);
let rdr = BlockCursor::new_with_compression(rdr, compression);
for (idx, (blob, offset)) in blobs.iter().zip(offsets.iter()).enumerate() {
@@ -586,30 +525,27 @@ pub(crate) mod tests {
}
#[tokio::test]
async fn test_one() -> Result<(), Error> {
async fn test_one() -> anyhow::Result<()> {
let blobs = &[vec![12, 21, 22]];
round_trip_test::<false>(blobs).await?;
round_trip_test::<true>(blobs).await?;
round_trip_test(blobs).await?;
Ok(())
}
#[tokio::test]
async fn test_hello_simple() -> Result<(), Error> {
async fn test_hello_simple() -> anyhow::Result<()> {
let blobs = &[
vec![0, 1, 2, 3],
b"Hello, World!".to_vec(),
Vec::new(),
b"foobar".to_vec(),
];
round_trip_test::<false>(blobs).await?;
round_trip_test::<true>(blobs).await?;
round_trip_test_compressed::<false>(blobs, true).await?;
round_trip_test_compressed::<true>(blobs, true).await?;
round_trip_test(blobs).await?;
round_trip_test_compressed(blobs, true).await?;
Ok(())
}
#[tokio::test]
async fn test_really_big_array() -> Result<(), Error> {
async fn test_really_big_array() -> anyhow::Result<()> {
let blobs = &[
b"test".to_vec(),
random_array(10 * PAGE_SZ),
@@ -618,25 +554,22 @@ pub(crate) mod tests {
vec![0xf3; 24 * PAGE_SZ],
b"foobar".to_vec(),
];
round_trip_test::<false>(blobs).await?;
round_trip_test::<true>(blobs).await?;
round_trip_test_compressed::<false>(blobs, true).await?;
round_trip_test_compressed::<true>(blobs, true).await?;
round_trip_test(blobs).await?;
round_trip_test_compressed(blobs, true).await?;
Ok(())
}
#[tokio::test]
async fn test_arrays_inc() -> Result<(), Error> {
async fn test_arrays_inc() -> anyhow::Result<()> {
let blobs = (0..PAGE_SZ / 8)
.map(|v| random_array(v * 16))
.collect::<Vec<_>>();
round_trip_test::<false>(&blobs).await?;
round_trip_test::<true>(&blobs).await?;
round_trip_test(&blobs).await?;
Ok(())
}
#[tokio::test]
async fn test_arrays_random_size() -> Result<(), Error> {
async fn test_arrays_random_size() -> anyhow::Result<()> {
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
let blobs = (0..1024)
.map(|_| {
@@ -648,20 +581,18 @@ pub(crate) mod tests {
random_array(sz.into())
})
.collect::<Vec<_>>();
round_trip_test::<false>(&blobs).await?;
round_trip_test::<true>(&blobs).await?;
round_trip_test(&blobs).await?;
Ok(())
}
#[tokio::test]
async fn test_arrays_page_boundary() -> Result<(), Error> {
async fn test_arrays_page_boundary() -> anyhow::Result<()> {
let blobs = &[
random_array(PAGE_SZ - 4),
random_array(PAGE_SZ - 4),
random_array(PAGE_SZ - 4),
];
round_trip_test::<false>(blobs).await?;
round_trip_test::<true>(blobs).await?;
round_trip_test(blobs).await?;
Ok(())
}
}

View File

@@ -4,14 +4,12 @@
use std::ops::Deref;
use bytes::Bytes;
use super::storage_layer::delta_layer::{Adapter, DeltaLayerInner};
use crate::context::RequestContext;
use crate::page_cache::{self, FileId, PAGE_SZ, PageReadGuard, PageWriteGuard, ReadBufResult};
#[cfg(test)]
use crate::virtual_file::IoBufferMut;
use crate::virtual_file::VirtualFile;
use crate::virtual_file::{IoBuffer, VirtualFile};
/// This is implemented by anything that can read 8 kB (PAGE_SZ)
/// blocks, using the page cache
@@ -247,17 +245,17 @@ pub trait BlockWriter {
/// 'buf' must be of size PAGE_SZ. Returns the block number the page was
/// written to.
///
fn write_blk(&mut self, buf: Bytes) -> Result<u32, std::io::Error>;
fn write_blk(&mut self, buf: IoBuffer) -> Result<u32, std::io::Error>;
}
///
/// A simple in-memory buffer of blocks.
///
pub struct BlockBuf {
pub blocks: Vec<Bytes>,
pub blocks: Vec<IoBuffer>,
}
impl BlockWriter for BlockBuf {
fn write_blk(&mut self, buf: Bytes) -> Result<u32, std::io::Error> {
fn write_blk(&mut self, buf: IoBuffer) -> Result<u32, std::io::Error> {
assert!(buf.len() == PAGE_SZ);
let blknum = self.blocks.len();
self.blocks.push(buf);

View File

@@ -25,7 +25,7 @@ use std::{io, result};
use async_stream::try_stream;
use byteorder::{BE, ReadBytesExt};
use bytes::{BufMut, Bytes, BytesMut};
use bytes::BufMut;
use either::Either;
use futures::{Stream, StreamExt};
use hex;
@@ -34,6 +34,7 @@ use tracing::error;
use crate::context::RequestContext;
use crate::tenant::block_io::{BlockReader, BlockWriter};
use crate::virtual_file::{IoBuffer, IoBufferMut, owned_buffers_io::write::Buffer};
// The maximum size of a value stored in the B-tree. 5 bytes is enough currently.
pub const VALUE_SZ: usize = 5;
@@ -787,12 +788,12 @@ impl<const L: usize> BuildNode<L> {
///
/// Serialize the node to on-disk format.
///
fn pack(&self) -> Bytes {
fn pack(&self) -> IoBuffer {
assert!(self.keys.len() == self.num_children as usize * self.suffix_len);
assert!(self.values.len() == self.num_children as usize * VALUE_SZ);
assert!(self.num_children > 0);
let mut buf = BytesMut::new();
let mut buf = IoBufferMut::with_capacity(PAGE_SZ);
buf.put_u16(self.num_children);
buf.put_u8(self.level);
@@ -805,7 +806,7 @@ impl<const L: usize> BuildNode<L> {
assert!(buf.len() == self.size);
assert!(buf.len() <= PAGE_SZ);
buf.resize(PAGE_SZ, 0);
buf.extend_with(0, PAGE_SZ - buf.len());
buf.freeze()
}
@@ -839,7 +840,7 @@ pub(crate) mod tests {
#[derive(Clone, Default)]
pub(crate) struct TestDisk {
blocks: Vec<Bytes>,
blocks: Vec<IoBuffer>,
}
impl TestDisk {
fn new() -> Self {
@@ -857,7 +858,7 @@ pub(crate) mod tests {
}
}
impl BlockWriter for &mut TestDisk {
fn write_blk(&mut self, buf: Bytes) -> io::Result<u32> {
fn write_blk(&mut self, buf: IoBuffer) -> io::Result<u32> {
let blknum = self.blocks.len();
self.blocks.push(buf);
Ok(blknum as u32)

View File

@@ -12,6 +12,7 @@ use tokio_epoll_uring::{BoundedBuf, Slice};
use tokio_util::sync::CancellationToken;
use tracing::{error, info_span};
use utils::id::TimelineId;
use utils::sync::gate::GateGuard;
use crate::assert_u64_eq_usize::{U64IsUsize, UsizeIsU64};
use crate::config::PageServerConf;
@@ -21,16 +22,33 @@ use crate::tenant::storage_layer::inmemory_layer::vectored_dio_read::File;
use crate::virtual_file::owned_buffers_io::io_buf_aligned::IoBufAlignedMut;
use crate::virtual_file::owned_buffers_io::slice::SliceMutExt;
use crate::virtual_file::owned_buffers_io::write::{Buffer, FlushTaskError};
use crate::virtual_file::{self, IoBufferMut, VirtualFile, owned_buffers_io};
use crate::virtual_file::{self, IoBufferMut, TempVirtualFile, VirtualFile, owned_buffers_io};
use self::owned_buffers_io::write::OwnedAsyncWriter;
pub struct EphemeralFile {
_tenant_shard_id: TenantShardId,
_timeline_id: TimelineId,
page_cache_file_id: page_cache::FileId,
bytes_written: u64,
buffered_writer: owned_buffers_io::write::BufferedWriter<IoBufferMut, VirtualFile>,
/// Gate guard is held on as long as we need to do operations in the path (delete on drop)
_gate_guard: utils::sync::gate::GateGuard,
file: TempVirtualFileCoOwnedByEphemeralFileAndBufferedWriter,
buffered_writer: BufferedWriter,
}
type BufferedWriter = owned_buffers_io::write::BufferedWriter<
IoBufferMut,
TempVirtualFileCoOwnedByEphemeralFileAndBufferedWriter,
>;
/// A TempVirtualFile that is co-owned by the [`EphemeralFile`]` and [`BufferedWriter`].
///
/// (Actually [`BufferedWriter`] internally is just a client to a background flush task.
/// The co-ownership is between [`EphemeralFile`] and that flush task.)
///
/// Co-ownership allows us to serve reads for data that has already been flushed by the [`BufferedWriter`].
#[derive(Debug, Clone)]
struct TempVirtualFileCoOwnedByEphemeralFileAndBufferedWriter {
inner: Arc<TempVirtualFile>,
}
const TAIL_SZ: usize = 64 * 1024;
@@ -44,9 +62,12 @@ impl EphemeralFile {
cancel: &CancellationToken,
ctx: &RequestContext,
) -> anyhow::Result<EphemeralFile> {
static NEXT_FILENAME: AtomicU64 = AtomicU64::new(1);
// TempVirtualFile requires us to never reuse a filename while an old
// instance of TempVirtualFile created with that filename is not done dropping yet.
// So, we use a monotonic counter to disambiguate the filenames.
static NEXT_TEMP_DISAMBIGUATOR: AtomicU64 = AtomicU64::new(1);
let filename_disambiguator =
NEXT_FILENAME.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
NEXT_TEMP_DISAMBIGUATOR.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let filename = conf
.timeline_path(&tenant_shard_id, &timeline_id)
@@ -54,16 +75,17 @@ impl EphemeralFile {
"ephemeral-{filename_disambiguator}"
)));
let file = Arc::new(
let file = TempVirtualFileCoOwnedByEphemeralFileAndBufferedWriter::new(
VirtualFile::open_with_options_v2(
&filename,
virtual_file::OpenOptions::new()
.create_new(true)
.read(true)
.write(true)
.create(true),
.write(true),
ctx,
)
.await?,
gate.enter()?,
);
let page_cache_file_id = page_cache::next_file_id(); // XXX get rid, we're not page-caching anymore
@@ -73,37 +95,60 @@ impl EphemeralFile {
_timeline_id: timeline_id,
page_cache_file_id,
bytes_written: 0,
buffered_writer: owned_buffers_io::write::BufferedWriter::new(
file: file.clone(),
buffered_writer: BufferedWriter::new(
file,
0,
|| IoBufferMut::with_capacity(TAIL_SZ),
gate.enter()?,
cancel.child_token(),
ctx,
info_span!(parent: None, "ephemeral_file_buffered_writer", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), timeline_id=%timeline_id, path = %filename),
),
_gate_guard: gate.enter()?,
})
}
}
impl Drop for EphemeralFile {
fn drop(&mut self) {
// unlink the file
// we are clear to do this, because we have entered a gate
let path = self.buffered_writer.as_inner().path();
let res = std::fs::remove_file(path);
if let Err(e) = res {
if e.kind() != std::io::ErrorKind::NotFound {
// just never log the not found errors, we cannot do anything for them; on detach
// the tenant directory is already gone.
//
// not found files might also be related to https://github.com/neondatabase/neon/issues/2442
error!("could not remove ephemeral file '{path}': {e}");
}
impl TempVirtualFileCoOwnedByEphemeralFileAndBufferedWriter {
fn new(file: VirtualFile, gate_guard: GateGuard) -> Self {
Self {
inner: Arc::new(TempVirtualFile::new(file, gate_guard)),
}
}
}
impl OwnedAsyncWriter for TempVirtualFileCoOwnedByEphemeralFileAndBufferedWriter {
fn write_all_at<Buf: owned_buffers_io::io_buf_aligned::IoBufAligned + Send>(
&self,
buf: owned_buffers_io::io_buf_ext::FullSlice<Buf>,
offset: u64,
ctx: &RequestContext,
) -> impl std::future::Future<
Output = (
owned_buffers_io::io_buf_ext::FullSlice<Buf>,
std::io::Result<()>,
),
> + Send {
self.inner.write_all_at(buf, offset, ctx)
}
fn set_len(
&self,
len: u64,
ctx: &RequestContext,
) -> impl Future<Output = std::io::Result<()>> + Send {
self.inner.set_len(len, ctx)
}
}
impl std::ops::Deref for TempVirtualFileCoOwnedByEphemeralFileAndBufferedWriter {
type Target = VirtualFile;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum EphemeralFileWriteError {
#[error("{0}")]
@@ -262,9 +307,9 @@ impl super::storage_layer::inmemory_layer::vectored_dio_read::File for Ephemeral
let mutable_range = Range(std::cmp::max(start, submitted_offset), end);
let dst = if written_range.len() > 0 {
let file: &VirtualFile = self.buffered_writer.as_inner();
let bounds = dst.bounds();
let slice = file
let slice = self
.file
.read_exact_at(dst.slice(0..written_range.len().into_usize()), start, ctx)
.await?;
Slice::from_buf_bounds(Slice::into_inner(slice), bounds)
@@ -456,7 +501,7 @@ mod tests {
assert_eq!(&buf, &content[range]);
}
let file_contents = std::fs::read(file.buffered_writer.as_inner().path()).unwrap();
let file_contents = std::fs::read(file.file.path()).unwrap();
assert!(file_contents == content[0..cap * 2]);
let maybe_flushed_buffer_contents = file.buffered_writer.inspect_maybe_flushed().unwrap();
@@ -489,7 +534,7 @@ mod tests {
// assert the state is as this test expects it to be
let load_io_buf_res = file.load_to_io_buf(&ctx).await.unwrap();
assert_eq!(&load_io_buf_res[..], &content[0..cap * 2 + cap / 2]);
let md = file.buffered_writer.as_inner().path().metadata().unwrap();
let md = file.file.path().metadata().unwrap();
assert_eq!(
md.len(),
2 * cap.into_u64(),

View File

@@ -6,6 +6,7 @@
use std::collections::HashSet;
use std::future::Future;
use std::str::FromStr;
use std::sync::atomic::AtomicU64;
use std::time::SystemTime;
use anyhow::{Context, anyhow};
@@ -15,7 +16,7 @@ use remote_storage::{
DownloadError, DownloadKind, DownloadOpts, GenericRemoteStorage, ListingMode, RemotePath,
};
use tokio::fs::{self, File, OpenOptions};
use tokio::io::{AsyncSeekExt, AsyncWriteExt};
use tokio::io::AsyncSeekExt;
use tokio_util::io::StreamReader;
use tokio_util::sync::CancellationToken;
use tracing::warn;
@@ -40,7 +41,10 @@ use crate::span::{
use crate::tenant::Generation;
use crate::tenant::remote_timeline_client::{remote_layer_path, remote_timelines_path};
use crate::tenant::storage_layer::LayerName;
use crate::virtual_file::{MaybeFatalIo, VirtualFile, on_fatal_io_error};
use crate::virtual_file;
use crate::virtual_file::owned_buffers_io::write::FlushTaskError;
use crate::virtual_file::{IoBufferMut, MaybeFatalIo, VirtualFile};
use crate::virtual_file::{TempVirtualFile, owned_buffers_io};
///
/// If 'metadata' is given, we will validate that the downloaded file's size matches that
@@ -72,21 +76,34 @@ pub async fn download_layer_file<'a>(
layer_metadata.generation,
);
// Perform a rename inspired by durable_rename from file_utils.c.
// The sequence:
// write(tmp)
// fsync(tmp)
// rename(tmp, new)
// fsync(new)
// fsync(parent)
// For more context about durable_rename check this email from postgres mailing list:
// https://www.postgresql.org/message-id/56583BDD.9060302@2ndquadrant.com
// If pageserver crashes the temp file will be deleted on startup and re-downloaded.
let temp_file_path = path_with_suffix_extension(local_path, TEMP_DOWNLOAD_EXTENSION);
let bytes_amount = download_retry(
let (bytes_amount, temp_file) = download_retry(
|| async {
download_object(storage, &remote_path, &temp_file_path, gate, cancel, ctx).await
// TempVirtualFile requires us to never reuse a filename while an old
// instance of TempVirtualFile created with that filename is not done dropping yet.
// So, we use a monotonic counter to disambiguate the filenames.
static NEXT_TEMP_DISAMBIGUATOR: AtomicU64 = AtomicU64::new(1);
let filename_disambiguator =
NEXT_TEMP_DISAMBIGUATOR.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
let temp_file_path = path_with_suffix_extension(
local_path,
&format!("{filename_disambiguator:x}.{TEMP_DOWNLOAD_EXTENSION}"),
);
let temp_file = TempVirtualFile::new(
VirtualFile::open_with_options_v2(
&temp_file_path,
virtual_file::OpenOptions::new()
.create_new(true)
.write(true),
ctx,
)
.await
.with_context(|| format!("create a temp file for layer download: {temp_file_path}"))
.map_err(DownloadError::Other)?,
gate.enter().map_err(|_| DownloadError::Cancelled)?,
);
download_object(storage, &remote_path, temp_file, gate, cancel, ctx).await
},
&format!("download {remote_path:?}"),
cancel,
@@ -96,7 +113,8 @@ pub async fn download_layer_file<'a>(
let expected = layer_metadata.file_size;
if expected != bytes_amount {
return Err(DownloadError::Other(anyhow!(
"According to layer file metadata should have downloaded {expected} bytes but downloaded {bytes_amount} bytes into file {temp_file_path:?}",
"According to layer file metadata should have downloaded {expected} bytes but downloaded {bytes_amount} bytes into file {:?}",
temp_file.path()
)));
}
@@ -106,11 +124,28 @@ pub async fn download_layer_file<'a>(
)))
});
fs::rename(&temp_file_path, &local_path)
// Try rename before disarming the temp file.
// That way, if rename fails for whatever reason, we clean up the temp file on the return path.
fs::rename(temp_file.path(), &local_path)
.await
.with_context(|| format!("rename download layer file to {local_path}"))
.map_err(DownloadError::Other)?;
// The temp file's VirtualFile points to the temp_file_path which we moved above.
// Drop it immediately, it's invalid.
// This will get better in https://github.com/neondatabase/neon/issues/11692
let _: VirtualFile = temp_file.disarm_into_inner();
// NB: The gate guard that was stored in `temp_file` is dropped but we continue
// to operate on it and on the parent timeline directory.
// Those operations are safe to do because higher-level code is holding another gate guard:
// - attached mode: the download task spawned by struct Layer is holding the gate guard
// - secondary mode: The TenantDownloader::download holds the gate open
// The rename above is not durable yet.
// It doesn't matter for crash consistency because pageserver startup deletes temp
// files and we'll re-download on demand if necessary.
// We use fatal_err() below because the after the rename above,
// the in-memory state of the filesystem already has the layer file in its final place,
// and subsequent pageserver code could think it's durable while it really isn't.
@@ -146,147 +181,64 @@ pub async fn download_layer_file<'a>(
async fn download_object(
storage: &GenericRemoteStorage,
src_path: &RemotePath,
dst_path: &Utf8PathBuf,
#[cfg_attr(target_os = "macos", allow(unused_variables))] gate: &utils::sync::gate::Gate,
destination_file: TempVirtualFile,
gate: &utils::sync::gate::Gate,
cancel: &CancellationToken,
#[cfg_attr(target_os = "macos", allow(unused_variables))] ctx: &RequestContext,
) -> Result<u64, DownloadError> {
let res = match crate::virtual_file::io_engine::get() {
crate::virtual_file::io_engine::IoEngine::NotSet => panic!("unset"),
crate::virtual_file::io_engine::IoEngine::StdFs => {
async {
let destination_file = tokio::fs::File::create(dst_path)
.await
.with_context(|| format!("create a destination file for layer '{dst_path}'"))
.map_err(DownloadError::Other)?;
ctx: &RequestContext,
) -> Result<(u64, TempVirtualFile), DownloadError> {
let mut download = storage
.download(src_path, &DownloadOpts::default(), cancel)
.await?;
let download = storage
.download(src_path, &DownloadOpts::default(), cancel)
.await?;
pausable_failpoint!("before-downloading-layer-stream-pausable");
pausable_failpoint!("before-downloading-layer-stream-pausable");
let dst_path = destination_file.path().to_owned();
let mut buffered = owned_buffers_io::write::BufferedWriter::<IoBufferMut, _>::new(
destination_file,
0,
|| IoBufferMut::with_capacity(super::BUFFER_SIZE),
gate.enter().map_err(|_| DownloadError::Cancelled)?,
cancel.child_token(),
ctx,
tracing::info_span!(parent: None, "download_object_buffered_writer", %dst_path),
);
let mut buf_writer =
tokio::io::BufWriter::with_capacity(super::BUFFER_SIZE, destination_file);
let mut reader = tokio_util::io::StreamReader::new(download.download_stream);
let bytes_amount = tokio::io::copy_buf(&mut reader, &mut buf_writer).await?;
buf_writer.flush().await?;
let mut destination_file = buf_writer.into_inner();
// Tokio doc here: https://docs.rs/tokio/1.17.0/tokio/fs/struct.File.html states that:
// A file will not be closed immediately when it goes out of scope if there are any IO operations
// that have not yet completed. To ensure that a file is closed immediately when it is dropped,
// you should call flush before dropping it.
//
// From the tokio code I see that it waits for pending operations to complete. There shouldt be any because
// we assume that `destination_file` file is fully written. I e there is no pending .write(...).await operations.
// But for additional safety lets check/wait for any pending operations.
destination_file
.flush()
.await
.maybe_fatal_err("download_object sync_all")
.with_context(|| format!("flush source file at {dst_path}"))
.map_err(DownloadError::Other)?;
// not using sync_data because it can lose file size update
destination_file
.sync_all()
.await
.maybe_fatal_err("download_object sync_all")
.with_context(|| format!("failed to fsync source file at {dst_path}"))
.map_err(DownloadError::Other)?;
Ok(bytes_amount)
}
// TODO: use vectored write (writev) once supported by tokio-epoll-uring.
// There's chunks_vectored() on the stream.
let (bytes_amount, destination_file) = async {
while let Some(res) = futures::StreamExt::next(&mut download.download_stream).await {
let chunk = match res {
Ok(chunk) => chunk,
Err(e) => return Err(DownloadError::from(e)),
};
buffered
.write_buffered_borrowed(&chunk, ctx)
.await
.map_err(|e| match e {
FlushTaskError::Cancelled => DownloadError::Cancelled,
})?;
}
buffered
.shutdown(
owned_buffers_io::write::BufferedWriterShutdownMode::PadThenTruncate,
ctx,
)
.await
}
#[cfg(target_os = "linux")]
crate::virtual_file::io_engine::IoEngine::TokioEpollUring => {
use crate::virtual_file::owned_buffers_io::write::FlushTaskError;
use std::sync::Arc;
use crate::virtual_file::{IoBufferMut, owned_buffers_io};
async {
let destination_file = Arc::new(
VirtualFile::create(dst_path, ctx)
.await
.with_context(|| {
format!("create a destination file for layer '{dst_path}'")
})
.map_err(DownloadError::Other)?,
);
let mut download = storage
.download(src_path, &DownloadOpts::default(), cancel)
.await?;
pausable_failpoint!("before-downloading-layer-stream-pausable");
let mut buffered = owned_buffers_io::write::BufferedWriter::<IoBufferMut, _>::new(
destination_file,
|| IoBufferMut::with_capacity(super::BUFFER_SIZE),
gate.enter().map_err(|_| DownloadError::Cancelled)?,
cancel.child_token(),
ctx,
tracing::info_span!(parent: None, "download_object_buffered_writer", %dst_path),
);
// TODO: use vectored write (writev) once supported by tokio-epoll-uring.
// There's chunks_vectored() on the stream.
let (bytes_amount, destination_file) = async {
while let Some(res) =
futures::StreamExt::next(&mut download.download_stream).await
{
let chunk = match res {
Ok(chunk) => chunk,
Err(e) => return Err(DownloadError::from(e)),
};
buffered
.write_buffered_borrowed(&chunk, ctx)
.await
.map_err(|e| match e {
FlushTaskError::Cancelled => DownloadError::Cancelled,
})?;
}
let inner = buffered
.flush_and_into_inner(ctx)
.await
.map_err(|e| match e {
FlushTaskError::Cancelled => DownloadError::Cancelled,
})?;
Ok(inner)
}
.await?;
// not using sync_data because it can lose file size update
destination_file
.sync_all()
.await
.maybe_fatal_err("download_object sync_all")
.with_context(|| format!("failed to fsync source file at {dst_path}"))
.map_err(DownloadError::Other)?;
Ok(bytes_amount)
}
.await
}
};
// in case the download failed, clean up
match res {
Ok(bytes_amount) => Ok(bytes_amount),
Err(e) => {
if let Err(e) = tokio::fs::remove_file(dst_path).await {
if e.kind() != std::io::ErrorKind::NotFound {
on_fatal_io_error(&e, &format!("Removing temporary file {dst_path}"));
}
}
Err(e)
}
.map_err(|e| match e {
FlushTaskError::Cancelled => DownloadError::Cancelled,
})
}
.await?;
// not using sync_data because it can lose file size update
destination_file
.sync_all()
.await
.maybe_fatal_err("download_object sync_all")
.with_context(|| format!("failed to fsync source file at {dst_path}"))
.map_err(DownloadError::Other)?;
Ok((bytes_amount, destination_file))
}
const TEMP_DOWNLOAD_EXTENSION: &str = "temp_download";

View File

@@ -646,7 +646,7 @@ enum UpdateError {
NoData,
#[error("Insufficient local storage space")]
NoSpace,
#[error("Failed to download")]
#[error("Failed to download: {0}")]
DownloadError(DownloadError),
#[error(transparent)]
Deserialize(#[from] serde_json::Error),
@@ -1521,12 +1521,11 @@ async fn load_heatmap(
path: &Utf8PathBuf,
ctx: &RequestContext,
) -> Result<Option<HeatMapTenant>, anyhow::Error> {
let mut file = match VirtualFile::open(path, ctx).await {
Ok(file) => file,
let st = match VirtualFile::read_to_string(path, ctx).await {
Ok(st) => st,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
Err(e) => Err(e)?,
};
let st = file.read_to_string(ctx).await?;
let htm = serde_json::from_str(&st)?;
Ok(Some(htm))
}

View File

@@ -29,11 +29,11 @@
//!
use std::collections::{HashMap, VecDeque};
use std::fs::File;
use std::io::SeekFrom;
use std::ops::Range;
use std::os::unix::fs::FileExt;
use std::str::FromStr;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use anyhow::{Context, Result, bail, ensure};
use camino::{Utf8Path, Utf8PathBuf};
@@ -45,14 +45,13 @@ use pageserver_api::keyspace::KeySpace;
use pageserver_api::models::ImageCompressionAlgorithm;
use pageserver_api::shard::TenantShardId;
use pageserver_api::value::Value;
use rand::Rng;
use rand::distributions::Alphanumeric;
use serde::{Deserialize, Serialize};
use tokio::sync::OnceCell;
use tokio_epoll_uring::IoBuf;
use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::bin_ser::BeSer;
use utils::bin_ser::SerializeError;
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
@@ -74,8 +73,10 @@ use crate::tenant::vectored_blob_io::{
BlobFlag, BufView, StreamingVectoredReadPlanner, VectoredBlobReader, VectoredRead,
VectoredReadPlanner,
};
use crate::virtual_file::TempVirtualFile;
use crate::virtual_file::owned_buffers_io::io_buf_ext::{FullSlice, IoBufExt};
use crate::virtual_file::{self, IoBufferMut, MaybeFatalIo, VirtualFile};
use crate::virtual_file::owned_buffers_io::write::{Buffer, BufferedWriterShutdownMode};
use crate::virtual_file::{self, IoBuffer, IoBufferMut, MaybeFatalIo, VirtualFile};
use crate::{DELTA_FILE_MAGIC, STORAGE_FORMAT_VERSION, TEMP_FILE_SUFFIX};
///
@@ -113,6 +114,15 @@ impl From<&DeltaLayer> for Summary {
}
impl Summary {
/// Serializes the summary header into an aligned buffer of lenth `PAGE_SZ`.
pub fn ser_into_page(&self) -> Result<IoBuffer, SerializeError> {
let mut buf = IoBufferMut::with_capacity(PAGE_SZ);
Self::ser_into(self, &mut buf)?;
// Pad zeroes to the buffer so the length is a multiple of the alignment.
buf.extend_with(0, buf.capacity() - buf.len());
Ok(buf.freeze())
}
pub(super) fn expected(
tenant_id: TenantId,
timeline_id: TimelineId,
@@ -288,19 +298,20 @@ impl DeltaLayer {
key_start: Key,
lsn_range: &Range<Lsn>,
) -> Utf8PathBuf {
let rand_string: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(8)
.map(char::from)
.collect();
// TempVirtualFile requires us to never reuse a filename while an old
// instance of TempVirtualFile created with that filename is not done dropping yet.
// So, we use a monotonic counter to disambiguate the filenames.
static NEXT_TEMP_DISAMBIGUATOR: AtomicU64 = AtomicU64::new(1);
let filename_disambiguator =
NEXT_TEMP_DISAMBIGUATOR.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
conf.timeline_path(tenant_shard_id, timeline_id)
.join(format!(
"{}-XXX__{:016X}-{:016X}.{}.{}",
"{}-XXX__{:016X}-{:016X}.{:x}.{}",
key_start,
u64::from(lsn_range.start),
u64::from(lsn_range.end),
rand_string,
filename_disambiguator,
TEMP_FILE_SUFFIX,
))
}
@@ -391,7 +402,7 @@ struct DeltaLayerWriterInner {
tree: DiskBtreeBuilder<BlockBuf, DELTA_KEY_SIZE>,
blob_writer: BlobWriter<true>,
blob_writer: BlobWriter<TempVirtualFile>,
// Number of key-lsns in the layer.
num_keys: usize,
@@ -415,16 +426,29 @@ impl DeltaLayerWriterInner {
// Create the file initially with a temporary filename. We don't know
// the end key yet, so we cannot form the final filename yet. We will
// rename it when we're done.
//
// Note: This overwrites any existing file. There shouldn't be any.
// FIXME: throw an error instead?
let path =
DeltaLayer::temp_path_for(conf, &tenant_shard_id, &timeline_id, key_start, &lsn_range);
let file = TempVirtualFile::new(
VirtualFile::open_with_options_v2(
&path,
virtual_file::OpenOptions::new()
.create_new(true)
.write(true),
ctx,
)
.await?,
gate.enter()?,
);
let mut file = VirtualFile::create(&path, ctx).await?;
// make room for the header block
file.seek(SeekFrom::Start(PAGE_SZ as u64)).await?;
let blob_writer = BlobWriter::new(file, PAGE_SZ as u64, gate, cancel, ctx);
// Start at PAGE_SZ, make room for the header block
let blob_writer = BlobWriter::new(
file,
PAGE_SZ as u64,
gate,
cancel,
ctx,
info_span!(parent: None, "delta_layer_writer_flush_task", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), timeline_id=%timeline_id, path = %path),
)?;
// Initialize the b-tree index builder
let block_buf = BlockBuf::new();
@@ -515,34 +539,27 @@ impl DeltaLayerWriterInner {
self,
key_end: Key,
ctx: &RequestContext,
) -> anyhow::Result<(PersistentLayerDesc, Utf8PathBuf)> {
let temp_path = self.path.clone();
let result = self.finish0(key_end, ctx).await;
if let Err(ref e) = result {
tracing::info!(%temp_path, "cleaning up temporary file after error during writing: {e}");
if let Err(e) = std::fs::remove_file(&temp_path) {
tracing::warn!(error=%e, %temp_path, "error cleaning up temporary layer file after error during writing");
}
}
result
}
async fn finish0(
self,
key_end: Key,
ctx: &RequestContext,
) -> anyhow::Result<(PersistentLayerDesc, Utf8PathBuf)> {
let index_start_blk = self.blob_writer.size().div_ceil(PAGE_SZ as u64) as u32;
let mut file = self.blob_writer.into_inner(ctx).await?;
let file = self
.blob_writer
.shutdown(
BufferedWriterShutdownMode::ZeroPadToNextMultiple(PAGE_SZ),
ctx,
)
.await?;
// Write out the index
let (index_root_blk, block_buf) = self.tree.finish()?;
file.seek(SeekFrom::Start(index_start_blk as u64 * PAGE_SZ as u64))
.await?;
let mut offset = index_start_blk as u64 * PAGE_SZ as u64;
// TODO(yuchen): https://github.com/neondatabase/neon/issues/10092
// Should we just replace BlockBuf::blocks with one big buffer
for buf in block_buf.blocks {
let (_buf, res) = file.write_all(buf.slice_len(), ctx).await;
let (_buf, res) = file.write_all_at(buf.slice_len(), offset, ctx).await;
res?;
offset += PAGE_SZ as u64;
}
assert!(self.lsn_range.start < self.lsn_range.end);
// Fill in the summary on blk 0
@@ -557,11 +574,9 @@ impl DeltaLayerWriterInner {
index_root_blk,
};
let mut buf = Vec::with_capacity(PAGE_SZ);
// TODO: could use smallvec here but it's a pain with Slice<T>
Summary::ser_into(&summary, &mut buf)?;
file.seek(SeekFrom::Start(0)).await?;
let (_buf, res) = file.write_all(buf.slice_len(), ctx).await;
// Writes summary at the first block (offset 0).
let buf = summary.ser_into_page()?;
let (_buf, res) = file.write_all_at(buf.slice_len(), 0, ctx).await;
res?;
let metadata = file
@@ -598,6 +613,10 @@ impl DeltaLayerWriterInner {
trace!("created delta layer {}", self.path);
// The gate guard stored in `destination_file` is dropped. Callers (e.g.. flush loop or compaction)
// keep the gate open also, so that it's safe for them to rename the file to its final destination.
file.disarm_into_inner();
Ok((desc, self.path))
}
}
@@ -726,17 +745,6 @@ impl DeltaLayerWriter {
}
}
impl Drop for DeltaLayerWriter {
fn drop(&mut self) {
if let Some(inner) = self.inner.take() {
// We want to remove the virtual file here, so it's fine to not
// having completely flushed unwritten data.
let vfile = inner.blob_writer.into_inner_no_flush();
vfile.remove();
}
}
}
#[derive(thiserror::Error, Debug)]
pub enum RewriteSummaryError {
#[error("magic mismatch")]
@@ -760,7 +768,7 @@ impl DeltaLayer {
where
F: Fn(Summary) -> Summary,
{
let mut file = VirtualFile::open_with_options(
let file = VirtualFile::open_with_options_v2(
path,
virtual_file::OpenOptions::new().read(true).write(true),
ctx,
@@ -777,11 +785,8 @@ impl DeltaLayer {
let new_summary = rewrite(actual_summary);
let mut buf = Vec::with_capacity(PAGE_SZ);
// TODO: could use smallvec here, but it's a pain with Slice<T>
Summary::ser_into(&new_summary, &mut buf).context("serialize")?;
file.seek(SeekFrom::Start(0)).await?;
let (_buf, res) = file.write_all(buf.slice_len(), ctx).await;
let buf = new_summary.ser_into_page().context("serialize")?;
let (_buf, res) = file.write_all_at(buf.slice_len(), 0, ctx).await;
res?;
Ok(())
}
@@ -1437,6 +1442,19 @@ impl DeltaLayerInner {
}
pub fn iter<'a>(&'a self, ctx: &'a RequestContext) -> DeltaLayerIterator<'a> {
self.iter_with_options(
ctx,
1024 * 8192, // The default value. Unit tests might use a different value. 1024 * 8K = 8MB buffer.
1024, // The default value. Unit tests might use a different value
)
}
pub fn iter_with_options<'a>(
&'a self,
ctx: &'a RequestContext,
max_read_size: u64,
max_batch_size: usize,
) -> DeltaLayerIterator<'a> {
let block_reader = FileBlockReader::new(&self.file, self.file_id);
let tree_reader =
DiskBtreeReader::new(self.index_start_blk, self.index_root_blk, block_reader);
@@ -1446,10 +1464,7 @@ impl DeltaLayerInner {
index_iter: tree_reader.iter(&[0; DELTA_KEY_SIZE], ctx),
key_values_batch: std::collections::VecDeque::new(),
is_end: false,
planner: StreamingVectoredReadPlanner::new(
1024 * 8192, // The default value. Unit tests might use a different value. 1024 * 8K = 8MB buffer.
1024, // The default value. Unit tests might use a different value
),
planner: StreamingVectoredReadPlanner::new(max_read_size, max_batch_size),
}
}
@@ -1609,8 +1624,8 @@ pub(crate) mod test {
use bytes::Bytes;
use itertools::MinMaxResult;
use pageserver_api::value::Value;
use rand::RngCore;
use rand::prelude::{SeedableRng, SliceRandom, StdRng};
use rand::{Rng, RngCore};
use super::*;
use crate::DEFAULT_PG_VERSION;

View File

@@ -27,11 +27,11 @@
//! actual page images are stored in the "values" part.
use std::collections::{HashMap, VecDeque};
use std::fs::File;
use std::io::SeekFrom;
use std::ops::Range;
use std::os::unix::prelude::FileExt;
use std::str::FromStr;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use anyhow::{Context, Result, bail, ensure};
use bytes::Bytes;
@@ -43,14 +43,13 @@ use pageserver_api::key::{DBDIR_KEY, KEY_SIZE, Key};
use pageserver_api::keyspace::KeySpace;
use pageserver_api::shard::{ShardIdentity, TenantShardId};
use pageserver_api::value::Value;
use rand::Rng;
use rand::distributions::Alphanumeric;
use serde::{Deserialize, Serialize};
use tokio::sync::OnceCell;
use tokio_stream::StreamExt;
use tokio_util::sync::CancellationToken;
use tracing::*;
use utils::bin_ser::BeSer;
use utils::bin_ser::SerializeError;
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
@@ -72,8 +71,10 @@ use crate::tenant::vectored_blob_io::{
BlobFlag, BufView, StreamingVectoredReadPlanner, VectoredBlobReader, VectoredRead,
VectoredReadPlanner,
};
use crate::virtual_file::TempVirtualFile;
use crate::virtual_file::owned_buffers_io::io_buf_ext::IoBufExt;
use crate::virtual_file::{self, IoBufferMut, MaybeFatalIo, VirtualFile};
use crate::virtual_file::owned_buffers_io::write::{Buffer, BufferedWriterShutdownMode};
use crate::virtual_file::{self, IoBuffer, IoBufferMut, MaybeFatalIo, VirtualFile};
use crate::{IMAGE_FILE_MAGIC, STORAGE_FORMAT_VERSION, TEMP_FILE_SUFFIX};
///
@@ -112,6 +113,15 @@ impl From<&ImageLayer> for Summary {
}
impl Summary {
/// Serializes the summary header into an aligned buffer of lenth `PAGE_SZ`.
pub fn ser_into_page(&self) -> Result<IoBuffer, SerializeError> {
let mut buf = IoBufferMut::with_capacity(PAGE_SZ);
Self::ser_into(self, &mut buf)?;
// Pad zeroes to the buffer so the length is a multiple of the alignment.
buf.extend_with(0, buf.capacity() - buf.len());
Ok(buf.freeze())
}
pub(super) fn expected(
tenant_id: TenantId,
timeline_id: TimelineId,
@@ -252,14 +262,18 @@ impl ImageLayer {
tenant_shard_id: TenantShardId,
fname: &ImageLayerName,
) -> Utf8PathBuf {
let rand_string: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(8)
.map(char::from)
.collect();
// TempVirtualFile requires us to never reuse a filename while an old
// instance of TempVirtualFile created with that filename is not done dropping yet.
// So, we use a monotonic counter to disambiguate the filenames.
static NEXT_TEMP_DISAMBIGUATOR: AtomicU64 = AtomicU64::new(1);
let filename_disambiguator =
NEXT_TEMP_DISAMBIGUATOR.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
conf.timeline_path(&tenant_shard_id, &timeline_id)
.join(format!("{fname}.{rand_string}.{TEMP_FILE_SUFFIX}"))
.join(format!(
"{fname}.{:x}.{TEMP_FILE_SUFFIX}",
filename_disambiguator
))
}
///
@@ -349,7 +363,7 @@ impl ImageLayer {
where
F: Fn(Summary) -> Summary,
{
let mut file = VirtualFile::open_with_options(
let file = VirtualFile::open_with_options_v2(
path,
virtual_file::OpenOptions::new().read(true).write(true),
ctx,
@@ -366,11 +380,8 @@ impl ImageLayer {
let new_summary = rewrite(actual_summary);
let mut buf = Vec::with_capacity(PAGE_SZ);
// TODO: could use smallvec here but it's a pain with Slice<T>
Summary::ser_into(&new_summary, &mut buf).context("serialize")?;
file.seek(SeekFrom::Start(0)).await?;
let (_buf, res) = file.write_all(buf.slice_len(), ctx).await;
let buf = new_summary.ser_into_page().context("serialize")?;
let (_buf, res) = file.write_all_at(buf.slice_len(), 0, ctx).await;
res?;
Ok(())
}
@@ -674,6 +685,19 @@ impl ImageLayerInner {
}
pub(crate) fn iter<'a>(&'a self, ctx: &'a RequestContext) -> ImageLayerIterator<'a> {
self.iter_with_options(
ctx,
1024 * 8192, // The default value. Unit tests might use a different value. 1024 * 8K = 8MB buffer.
1024, // The default value. Unit tests might use a different value
)
}
pub(crate) fn iter_with_options<'a>(
&'a self,
ctx: &'a RequestContext,
max_read_size: u64,
max_batch_size: usize,
) -> ImageLayerIterator<'a> {
let block_reader = FileBlockReader::new(&self.file, self.file_id);
let tree_reader =
DiskBtreeReader::new(self.index_start_blk, self.index_root_blk, block_reader);
@@ -683,10 +707,7 @@ impl ImageLayerInner {
index_iter: tree_reader.iter(&[0; KEY_SIZE], ctx),
key_values_batch: VecDeque::new(),
is_end: false,
planner: StreamingVectoredReadPlanner::new(
1024 * 8192, // The default value. Unit tests might use a different value. 1024 * 8K = 8MB buffer.
1024, // The default value. Unit tests might use a different value
),
planner: StreamingVectoredReadPlanner::new(max_read_size, max_batch_size),
}
}
@@ -739,7 +760,7 @@ struct ImageLayerWriterInner {
// Number of keys in the layer.
num_keys: usize,
blob_writer: BlobWriter<false>,
blob_writer: BlobWriter<TempVirtualFile>,
tree: DiskBtreeBuilder<BlockBuf, KEY_SIZE>,
#[cfg(feature = "testing")]
@@ -773,19 +794,27 @@ impl ImageLayerWriterInner {
},
);
trace!("creating image layer {}", path);
let mut file = {
VirtualFile::open_with_options(
let file = TempVirtualFile::new(
VirtualFile::open_with_options_v2(
&path,
virtual_file::OpenOptions::new()
.write(true)
.create_new(true),
.create_new(true)
.write(true),
ctx,
)
.await?
};
// make room for the header block
file.seek(SeekFrom::Start(PAGE_SZ as u64)).await?;
let blob_writer = BlobWriter::new(file, PAGE_SZ as u64, gate, cancel, ctx);
.await?,
gate.enter()?,
);
// Start at `PAGE_SZ` to make room for the header block.
let blob_writer = BlobWriter::new(
file,
PAGE_SZ as u64,
gate,
cancel,
ctx,
info_span!(parent: None, "image_layer_writer_flush_task", tenant_id=%tenant_shard_id.tenant_id, shard_id=%tenant_shard_id.shard_slug(), timeline_id=%timeline_id, path = %path),
)?;
// Initialize the b-tree index builder
let block_buf = BlockBuf::new();
@@ -896,25 +925,6 @@ impl ImageLayerWriterInner {
self,
ctx: &RequestContext,
end_key: Option<Key>,
) -> anyhow::Result<(PersistentLayerDesc, Utf8PathBuf)> {
let temp_path = self.path.clone();
let result = self.finish0(ctx, end_key).await;
if let Err(ref e) = result {
tracing::info!(%temp_path, "cleaning up temporary file after error during writing: {e}");
if let Err(e) = std::fs::remove_file(&temp_path) {
tracing::warn!(error=%e, %temp_path, "error cleaning up temporary layer file after error during writing");
}
}
result
}
///
/// Finish writing the image layer.
///
async fn finish0(
self,
ctx: &RequestContext,
end_key: Option<Key>,
) -> anyhow::Result<(PersistentLayerDesc, Utf8PathBuf)> {
let index_start_blk = self.blob_writer.size().div_ceil(PAGE_SZ as u64) as u32;
@@ -932,15 +942,24 @@ impl ImageLayerWriterInner {
crate::metrics::COMPRESSION_IMAGE_OUTPUT_BYTES.inc_by(compressed_size);
};
let mut file = self.blob_writer.into_inner();
let file = self
.blob_writer
.shutdown(
BufferedWriterShutdownMode::ZeroPadToNextMultiple(PAGE_SZ),
ctx,
)
.await?;
// Write out the index
file.seek(SeekFrom::Start(index_start_blk as u64 * PAGE_SZ as u64))
.await?;
let mut offset = index_start_blk as u64 * PAGE_SZ as u64;
let (index_root_blk, block_buf) = self.tree.finish()?;
// TODO(yuchen): https://github.com/neondatabase/neon/issues/10092
// Should we just replace BlockBuf::blocks with one big buffer?
for buf in block_buf.blocks {
let (_buf, res) = file.write_all(buf.slice_len(), ctx).await;
let (_buf, res) = file.write_all_at(buf.slice_len(), offset, ctx).await;
res?;
offset += PAGE_SZ as u64;
}
let final_key_range = if let Some(end_key) = end_key {
@@ -961,11 +980,9 @@ impl ImageLayerWriterInner {
index_root_blk,
};
let mut buf = Vec::with_capacity(PAGE_SZ);
// TODO: could use smallvec here but it's a pain with Slice<T>
Summary::ser_into(&summary, &mut buf)?;
file.seek(SeekFrom::Start(0)).await?;
let (_buf, res) = file.write_all(buf.slice_len(), ctx).await;
// Writes summary at the first block (offset 0).
let buf = summary.ser_into_page()?;
let (_buf, res) = file.write_all_at(buf.slice_len(), 0, ctx).await;
res?;
let metadata = file
@@ -1000,6 +1017,10 @@ impl ImageLayerWriterInner {
trace!("created image layer {}", self.path);
// The gate guard stored in `destination_file` is dropped. Callers (e.g.. flush loop or compaction)
// keep the gate open also, so that it's safe for them to rename the file to its final destination.
file.disarm_into_inner();
Ok((desc, self.path))
}
}
@@ -1125,14 +1146,6 @@ impl ImageLayerWriter {
}
}
impl Drop for ImageLayerWriter {
fn drop(&mut self) {
if let Some(inner) = self.inner.take() {
inner.blob_writer.into_inner().remove();
}
}
}
pub struct ImageLayerIterator<'a> {
image_layer: &'a ImageLayerInner,
ctx: &'a RequestContext,

View File

@@ -19,6 +19,7 @@ pub(crate) enum LayerRef<'a> {
}
impl<'a> LayerRef<'a> {
#[allow(dead_code)]
fn iter(self, ctx: &'a RequestContext) -> LayerIterRef<'a> {
match self {
Self::Image(x) => LayerIterRef::Image(x.iter(ctx)),
@@ -26,6 +27,22 @@ impl<'a> LayerRef<'a> {
}
}
fn iter_with_options(
self,
ctx: &'a RequestContext,
max_read_size: u64,
max_batch_size: usize,
) -> LayerIterRef<'a> {
match self {
Self::Image(x) => {
LayerIterRef::Image(x.iter_with_options(ctx, max_read_size, max_batch_size))
}
Self::Delta(x) => {
LayerIterRef::Delta(x.iter_with_options(ctx, max_read_size, max_batch_size))
}
}
}
fn layer_dbg_info(&self) -> String {
match self {
Self::Image(x) => x.layer_dbg_info(),
@@ -66,6 +83,8 @@ pub(crate) enum IteratorWrapper<'a> {
first_key_lower_bound: (Key, Lsn),
layer: LayerRef<'a>,
source_desc: Arc<PersistentLayerKey>,
max_read_size: u64,
max_batch_size: usize,
},
Loaded {
iter: PeekableLayerIterRef<'a>,
@@ -146,6 +165,8 @@ impl<'a> IteratorWrapper<'a> {
pub fn create_from_image_layer(
image_layer: &'a ImageLayerInner,
ctx: &'a RequestContext,
max_read_size: u64,
max_batch_size: usize,
) -> Self {
Self::NotLoaded {
layer: LayerRef::Image(image_layer),
@@ -157,12 +178,16 @@ impl<'a> IteratorWrapper<'a> {
is_delta: false,
}
.into(),
max_read_size,
max_batch_size,
}
}
pub fn create_from_delta_layer(
delta_layer: &'a DeltaLayerInner,
ctx: &'a RequestContext,
max_read_size: u64,
max_batch_size: usize,
) -> Self {
Self::NotLoaded {
layer: LayerRef::Delta(delta_layer),
@@ -174,6 +199,8 @@ impl<'a> IteratorWrapper<'a> {
is_delta: true,
}
.into(),
max_read_size,
max_batch_size,
}
}
@@ -204,11 +231,13 @@ impl<'a> IteratorWrapper<'a> {
first_key_lower_bound,
layer,
source_desc,
max_read_size,
max_batch_size,
} = self
else {
unreachable!()
};
let iter = layer.iter(ctx);
let iter = layer.iter_with_options(ctx, *max_read_size, *max_batch_size);
let iter = PeekableLayerIterRef::create(iter).await?;
if let Some((k1, l1, _)) = iter.peek() {
let (k2, l2) = first_key_lower_bound;
@@ -293,21 +322,41 @@ impl MergeIteratorItem for ((Key, Lsn, Value), Arc<PersistentLayerKey>) {
}
impl<'a> MergeIterator<'a> {
pub fn create_with_options(
deltas: &[&'a DeltaLayerInner],
images: &[&'a ImageLayerInner],
ctx: &'a RequestContext,
max_read_size: u64,
max_batch_size: usize,
) -> Self {
let mut heap = Vec::with_capacity(images.len() + deltas.len());
for image in images {
heap.push(IteratorWrapper::create_from_image_layer(
image,
ctx,
max_read_size,
max_batch_size,
));
}
for delta in deltas {
heap.push(IteratorWrapper::create_from_delta_layer(
delta,
ctx,
max_read_size,
max_batch_size,
));
}
Self {
heap: BinaryHeap::from(heap),
}
}
pub fn create(
deltas: &[&'a DeltaLayerInner],
images: &[&'a ImageLayerInner],
ctx: &'a RequestContext,
) -> Self {
let mut heap = Vec::with_capacity(images.len() + deltas.len());
for image in images {
heap.push(IteratorWrapper::create_from_image_layer(image, ctx));
}
for delta in deltas {
heap.push(IteratorWrapper::create_from_delta_layer(delta, ctx));
}
Self {
heap: BinaryHeap::from(heap),
}
Self::create_with_options(deltas, images, ctx, 1024 * 8192, 1024)
}
pub(crate) async fn next_inner<R: MergeIteratorItem>(&mut self) -> anyhow::Result<Option<R>> {

View File

@@ -2828,6 +2828,41 @@ impl Timeline {
Ok(())
}
/// Check if the memory usage is within the limit.
async fn check_memory_usage(
self: &Arc<Self>,
layer_selection: &[Layer],
) -> Result<(), CompactionError> {
let mut estimated_memory_usage_mb = 0.0;
let mut num_image_layers = 0;
let mut num_delta_layers = 0;
let target_layer_size_bytes = 256 * 1024 * 1024;
for layer in layer_selection {
let layer_desc = layer.layer_desc();
if layer_desc.is_delta() {
// Delta layers at most have 1MB buffer; 3x to make it safe (there're deltas as large as 16KB).
// Multiply the layer size so that tests can pass.
estimated_memory_usage_mb +=
3.0 * (layer_desc.file_size / target_layer_size_bytes) as f64;
num_delta_layers += 1;
} else {
// Image layers at most have 1MB buffer but it might be compressed; assume 5x compression ratio.
estimated_memory_usage_mb +=
5.0 * (layer_desc.file_size / target_layer_size_bytes) as f64;
num_image_layers += 1;
}
}
if estimated_memory_usage_mb > 1024.0 {
return Err(CompactionError::Other(anyhow!(
"estimated memory usage is too high: {}MB, giving up compaction; num_image_layers={}, num_delta_layers={}",
estimated_memory_usage_mb,
num_image_layers,
num_delta_layers
)));
}
Ok(())
}
/// Get a watermark for gc-compaction, that is the lowest LSN that we can use as the `gc_horizon` for
/// the compaction algorithm. It is min(space_cutoff, time_cutoff, latest_gc_cutoff, standby_horizon).
/// Leases and retain_lsns are considered in the gc-compaction job itself so we don't need to account for them
@@ -3264,6 +3299,17 @@ impl Timeline {
self.check_compaction_space(&job_desc.selected_layers)
.await?;
self.check_memory_usage(&job_desc.selected_layers).await?;
if job_desc.selected_layers.len() > 100
&& job_desc.rewrite_layers.len() as f64 >= job_desc.selected_layers.len() as f64 * 0.7
{
return Err(CompactionError::Other(anyhow!(
"too many layers to rewrite: {} / {}, giving up compaction",
job_desc.rewrite_layers.len(),
job_desc.selected_layers.len()
)));
}
// Generate statistics for the compaction
for layer in &job_desc.selected_layers {
let desc = layer.layer_desc();
@@ -3359,7 +3405,13 @@ impl Timeline {
.context("failed to collect gc compaction keyspace")
.map_err(CompactionError::Other)?;
let mut merge_iter = FilterIterator::create(
MergeIterator::create(&delta_layers, &image_layers, ctx),
MergeIterator::create_with_options(
&delta_layers,
&image_layers,
ctx,
128 * 8192, /* 1MB buffer for each of the inner iterators */
128,
),
dense_ks,
sparse_ks,
)

View File

@@ -1,20 +1,21 @@
use std::sync::Arc;
use anyhow::{Context, bail};
use pageserver_api::models::ShardImportStatus;
use remote_storage::RemotePath;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, info, info_span};
use tracing::info;
use utils::lsn::Lsn;
use super::Timeline;
use crate::context::RequestContext;
use crate::controller_upcall_client::{StorageControllerUpcallApi, StorageControllerUpcallClient};
use crate::tenant::metadata::TimelineMetadata;
mod flow;
mod importbucket_client;
mod importbucket_format;
pub(crate) mod index_part_format;
pub(crate) mod upcall_api;
pub async fn doit(
timeline: &Arc<Timeline>,
@@ -34,23 +35,6 @@ pub async fn doit(
let storage = importbucket_client::new(timeline.conf, &location, cancel.clone()).await?;
info!("get spec early so we know we'll be able to upcall when done");
let Some(spec) = storage.get_spec().await? else {
bail!("spec not found")
};
let upcall_client =
upcall_api::Client::new(timeline.conf, cancel.clone()).context("create upcall client")?;
//
// send an early progress update to clean up k8s job early and generate potentially useful logs
//
info!("send early progress update");
upcall_client
.send_progress_until_success(&spec)
.instrument(info_span!("early_progress_update"))
.await?;
let status_prefix = RemotePath::from_string("status").unwrap();
//
@@ -176,7 +160,21 @@ pub async fn doit(
//
// Communicate that shard is done.
// Ensure at-least-once delivery of the upcall to storage controller
// before we mark the task as done and never come here again.
//
let storcon_client = StorageControllerUpcallClient::new(timeline.conf, &cancel)?
.expect("storcon configured");
storcon_client
.put_timeline_import_status(
timeline.tenant_shard_id,
timeline.timeline_id,
// TODO(vlad): What about import errors?
ShardImportStatus::Done,
)
.await
.map_err(|_err| anyhow::anyhow!("Shut down while putting timeline import status"))?;
storage
.put_json(
&shard_status_key,
@@ -186,16 +184,6 @@ pub async fn doit(
.context("put shard status")?;
}
//
// Ensure at-least-once deliver of the upcall to cplane
// before we mark the task as done and never come here again.
//
info!("send final progress update");
upcall_client
.send_progress_until_success(&spec)
.instrument(info_span!("final_progress_update"))
.await?;
//
// Mark as done in index_part.
// This makes subsequent timeline loads enter the normal load code path

View File

@@ -13,7 +13,7 @@ use tokio_util::sync::CancellationToken;
use tracing::{debug, info, instrument};
use utils::lsn::Lsn;
use super::{importbucket_format, index_part_format};
use super::index_part_format;
use crate::assert_u64_eq_usize::U64IsUsize;
use crate::config::PageServerConf;
@@ -173,12 +173,6 @@ impl RemoteStorageWrapper {
res
}
pub async fn get_spec(&self) -> Result<Option<importbucket_format::Spec>, anyhow::Error> {
self.get_json(&RemotePath::from_string("spec.json").unwrap())
.await
.context("get spec")
}
#[instrument(level = tracing::Level::DEBUG, skip_all, fields(%path))]
pub async fn get_json<T: DeserializeOwned>(
&self,
@@ -244,7 +238,8 @@ impl RemoteStorageWrapper {
kind: DownloadKind::Large,
etag: None,
byte_start: Bound::Included(start_inclusive),
byte_end: Bound::Excluded(end_exclusive)
byte_end: Bound::Excluded(end_exclusive),
version_id: None,
},
&self.cancel)
.await?;

View File

@@ -11,10 +11,3 @@ pub struct ShardStatus {
pub done: bool,
// TODO: remaining fields
}
// TODO: dedupe with fast_import code
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
pub struct Spec {
pub project_id: String,
pub branch_id: String,
}

View File

@@ -1,124 +0,0 @@
//! FIXME: most of this is copy-paste from mgmt_api.rs ; dedupe into a `reqwest_utils::Client` crate.
use pageserver_client::mgmt_api::{Error, ResponseErrorMessageExt};
use reqwest::{Certificate, Method};
use serde::{Deserialize, Serialize};
use tokio_util::sync::CancellationToken;
use tracing::error;
use super::importbucket_format::Spec;
use crate::config::PageServerConf;
pub struct Client {
base_url: String,
authorization_header: Option<String>,
client: reqwest::Client,
cancel: CancellationToken,
}
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Serialize, Deserialize, Debug)]
struct ImportProgressRequest {
// no fields yet, not sure if there every will be any
}
#[derive(Serialize, Deserialize, Debug)]
struct ImportProgressResponse {
// we don't care
}
impl Client {
pub fn new(conf: &PageServerConf, cancel: CancellationToken) -> anyhow::Result<Self> {
let Some(ref base_url) = conf.import_pgdata_upcall_api else {
anyhow::bail!("import_pgdata_upcall_api is not configured")
};
let mut http_client = reqwest::Client::builder();
for cert in &conf.ssl_ca_certs {
http_client = http_client.add_root_certificate(Certificate::from_der(cert.contents())?);
}
let http_client = http_client.build()?;
Ok(Self {
base_url: base_url.to_string(),
client: http_client,
cancel,
authorization_header: conf
.import_pgdata_upcall_api_token
.as_ref()
.map(|secret_string| secret_string.get_contents())
.map(|jwt| format!("Bearer {jwt}")),
})
}
fn start_request<U: reqwest::IntoUrl>(
&self,
method: Method,
uri: U,
) -> reqwest::RequestBuilder {
let req = self.client.request(method, uri);
if let Some(value) = &self.authorization_header {
req.header(reqwest::header::AUTHORIZATION, value)
} else {
req
}
}
async fn request_noerror<B: serde::Serialize, U: reqwest::IntoUrl>(
&self,
method: Method,
uri: U,
body: B,
) -> Result<reqwest::Response> {
self.start_request(method, uri)
.json(&body)
.send()
.await
.map_err(Error::ReceiveBody)
}
async fn request<B: serde::Serialize, U: reqwest::IntoUrl>(
&self,
method: Method,
uri: U,
body: B,
) -> Result<reqwest::Response> {
let res = self.request_noerror(method, uri, body).await?;
let response = res.error_from_body().await?;
Ok(response)
}
pub async fn send_progress_once(&self, spec: &Spec) -> Result<()> {
let url = format!(
"{}/projects/{}/branches/{}/import_progress",
self.base_url, spec.project_id, spec.branch_id
);
let ImportProgressResponse {} = self
.request(Method::POST, url, &ImportProgressRequest {})
.await?
.json()
.await
.map_err(Error::ReceiveBody)?;
Ok(())
}
pub async fn send_progress_until_success(&self, spec: &Spec) -> anyhow::Result<()> {
loop {
match self.send_progress_once(spec).await {
Ok(()) => return Ok(()),
Err(Error::Cancelled) => return Err(anyhow::anyhow!("cancelled")),
Err(err) => {
error!(?err, "error sending progress, retrying");
if tokio::time::timeout(
std::time::Duration::from_secs(10),
self.cancel.cancelled(),
)
.await
.is_ok()
{
anyhow::bail!("cancelled while sending early progress update");
}
}
}
}
}
}

View File

@@ -507,7 +507,9 @@ impl<'a> VectoredBlobReader<'a> {
for (blob_start, meta) in blobs_at.iter().copied() {
let header_start = (blob_start - read.start) as usize;
let header = Header::decode(&buf[header_start..])?;
let header = Header::decode(&buf[header_start..]).map_err(|anyhow_err| {
std::io::Error::new(std::io::ErrorKind::InvalidData, anyhow_err)
})?;
let data_start = header_start + header.header_len;
let end = data_start + header.data_len;
let compression_bits = header.compression_bits;
@@ -662,7 +664,6 @@ impl StreamingVectoredReadPlanner {
#[cfg(test)]
mod tests {
use anyhow::Error;
use super::super::blob_io::tests::{random_array, write_maybe_compressed};
use super::*;
@@ -945,13 +946,16 @@ mod tests {
}
}
async fn round_trip_test_compressed(blobs: &[Vec<u8>], compression: bool) -> Result<(), Error> {
async fn round_trip_test_compressed(
blobs: &[Vec<u8>],
compression: bool,
) -> anyhow::Result<()> {
let ctx =
RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error).with_scope_unit_test();
let (_temp_dir, pathbuf, offsets) =
write_maybe_compressed::<true>(blobs, compression, &ctx).await?;
write_maybe_compressed(blobs, compression, &ctx).await?;
let file = VirtualFile::open(&pathbuf, &ctx).await?;
let file = VirtualFile::open_v2(&pathbuf, &ctx).await?;
let file_len = std::fs::metadata(&pathbuf)?.len();
// Multiply by two (compressed data might need more space), and add a few bytes for the header
@@ -997,7 +1001,7 @@ mod tests {
}
#[tokio::test]
async fn test_really_big_array() -> Result<(), Error> {
async fn test_really_big_array() -> anyhow::Result<()> {
let blobs = &[
b"test".to_vec(),
random_array(10 * PAGE_SZ),
@@ -1012,7 +1016,7 @@ mod tests {
}
#[tokio::test]
async fn test_arrays_inc() -> Result<(), Error> {
async fn test_arrays_inc() -> anyhow::Result<()> {
let blobs = (0..PAGE_SZ / 8)
.map(|v| random_array(v * 16))
.collect::<Vec<_>>();

View File

@@ -12,10 +12,11 @@
//! src/backend/storage/file/fd.c
//!
use std::fs::File;
use std::io::{Error, ErrorKind, Seek, SeekFrom};
use std::io::{Error, ErrorKind};
use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd, RawFd};
#[cfg(target_os = "linux")]
use std::os::unix::fs::OpenOptionsExt;
use std::sync::LazyLock;
use std::sync::atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering};
use camino::{Utf8Path, Utf8PathBuf};
@@ -25,29 +26,31 @@ use owned_buffers_io::aligned_buffer::{AlignedBufferMut, AlignedSlice, ConstAlig
use owned_buffers_io::io_buf_aligned::{IoBufAligned, IoBufAlignedMut};
use owned_buffers_io::io_buf_ext::FullSlice;
use pageserver_api::config::defaults::DEFAULT_IO_BUFFER_ALIGNMENT;
pub use pageserver_api::models::virtual_file as api;
use tokio::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use tokio::time::Instant;
use tokio_epoll_uring::{BoundedBuf, IoBuf, IoBufMut, Slice};
use self::owned_buffers_io::write::OwnedAsyncWriter;
use crate::assert_u64_eq_usize::UsizeIsU64;
use crate::context::RequestContext;
use crate::metrics::{STORAGE_IO_TIME_METRIC, StorageIoOperation};
use crate::page_cache::{PAGE_SZ, PageWriteGuard};
pub(crate) mod io_engine;
pub(crate) use api::IoMode;
pub(crate) use io_engine::IoEngineKind;
pub use io_engine::{
FeatureTestResult as IoEngineFeatureTestResult, feature_test as io_engine_feature_test,
io_engine_for_bench,
};
mod metadata;
mod open_options;
pub(crate) use api::IoMode;
pub(crate) use io_engine::IoEngineKind;
pub(crate) use metadata::Metadata;
pub(crate) use open_options::*;
pub use pageserver_api::models::virtual_file as api;
pub use temporary::TempVirtualFile;
use self::owned_buffers_io::write::OwnedAsyncWriter;
pub(crate) mod io_engine;
mod metadata;
mod open_options;
mod temporary;
pub(crate) mod owned_buffers_io {
//! Abstractions for IO with owned buffers.
//!
@@ -94,69 +97,38 @@ impl VirtualFile {
Self::open_with_options_v2(path.as_ref(), OpenOptions::new().read(true), ctx).await
}
pub async fn create<P: AsRef<Utf8Path>>(
path: P,
ctx: &RequestContext,
) -> Result<Self, std::io::Error> {
let inner = VirtualFileInner::create(path, ctx).await?;
Ok(VirtualFile {
inner,
_mode: IoMode::Buffered,
})
}
pub async fn create_v2<P: AsRef<Utf8Path>>(
path: P,
ctx: &RequestContext,
) -> Result<Self, std::io::Error> {
VirtualFile::open_with_options_v2(
path.as_ref(),
OpenOptions::new().write(true).create(true).truncate(true),
ctx,
)
.await
}
pub async fn open_with_options<P: AsRef<Utf8Path>>(
path: P,
open_options: &OpenOptions,
ctx: &RequestContext,
) -> Result<Self, std::io::Error> {
let inner = VirtualFileInner::open_with_options(path, open_options, ctx).await?;
Ok(VirtualFile {
inner,
_mode: IoMode::Buffered,
})
}
pub async fn open_with_options_v2<P: AsRef<Utf8Path>>(
path: P,
open_options: &OpenOptions,
ctx: &RequestContext,
) -> Result<Self, std::io::Error> {
let file = match get_io_mode() {
IoMode::Buffered => {
let inner = VirtualFileInner::open_with_options(path, open_options, ctx).await?;
VirtualFile {
inner,
_mode: IoMode::Buffered,
}
}
let mode = get_io_mode();
let set_o_direct = match (mode, open_options.is_write()) {
(IoMode::Buffered, _) => false,
#[cfg(target_os = "linux")]
IoMode::Direct => {
let inner = VirtualFileInner::open_with_options(
path,
open_options.clone().custom_flags(nix::libc::O_DIRECT),
ctx,
)
.await?;
VirtualFile {
inner,
_mode: IoMode::Direct,
}
}
(IoMode::Direct, false) => true,
#[cfg(target_os = "linux")]
(IoMode::Direct, true) => false,
#[cfg(target_os = "linux")]
(IoMode::DirectRw, _) => true,
};
Ok(file)
let open_options = open_options.clone();
let open_options = if set_o_direct {
#[cfg(target_os = "linux")]
{
let mut open_options = open_options;
open_options.custom_flags(nix::libc::O_DIRECT);
open_options
}
#[cfg(not(target_os = "linux"))]
unreachable!(
"O_DIRECT is not supported on this platform, IoMode's that result in set_o_direct=true shouldn't even be defined"
);
} else {
open_options
};
let inner = VirtualFileInner::open_with_options(path, open_options, ctx).await?;
Ok(VirtualFile { inner, _mode: mode })
}
pub fn path(&self) -> &Utf8Path {
@@ -185,18 +157,14 @@ impl VirtualFile {
self.inner.sync_data().await
}
pub async fn set_len(&self, len: u64, ctx: &RequestContext) -> Result<(), Error> {
self.inner.set_len(len, ctx).await
}
pub async fn metadata(&self) -> Result<Metadata, Error> {
self.inner.metadata().await
}
pub fn remove(self) {
self.inner.remove();
}
pub async fn seek(&mut self, pos: SeekFrom) -> Result<u64, Error> {
self.inner.seek(pos).await
}
pub async fn read_exact_at<Buf>(
&self,
slice: Slice<Buf>,
@@ -227,25 +195,31 @@ impl VirtualFile {
self.inner.write_all_at(buf, offset, ctx).await
}
pub async fn write_all<Buf: IoBuf + Send>(
&mut self,
buf: FullSlice<Buf>,
pub(crate) async fn read_to_string<P: AsRef<Utf8Path>>(
path: P,
ctx: &RequestContext,
) -> (FullSlice<Buf>, Result<usize, Error>) {
self.inner.write_all(buf, ctx).await
}
async fn read_to_end(&mut self, buf: &mut Vec<u8>, ctx: &RequestContext) -> Result<(), Error> {
self.inner.read_to_end(buf, ctx).await
}
pub(crate) async fn read_to_string(
&mut self,
ctx: &RequestContext,
) -> Result<String, anyhow::Error> {
) -> std::io::Result<String> {
let file = VirtualFile::open(path, ctx).await?; // TODO: open_v2
let mut buf = Vec::new();
self.read_to_end(&mut buf, ctx).await?;
Ok(String::from_utf8(buf)?)
let mut tmp = vec![0; 128];
let mut pos: u64 = 0;
loop {
let slice = tmp.slice(..128);
let (slice, res) = file.inner.read_at(slice, pos, ctx).await;
match res {
Ok(0) => break,
Ok(n) => {
pos += n as u64;
buf.extend_from_slice(&slice[..n]);
}
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
Err(e) => return Err(e),
}
tmp = slice.into_inner();
}
String::from_utf8(buf).map_err(|_| {
std::io::Error::new(ErrorKind::InvalidData, "file contents are not valid UTF-8")
})
}
}
@@ -292,9 +266,6 @@ pub struct VirtualFileInner {
/// belongs to a different VirtualFile.
handle: RwLock<SlotHandle>,
/// Current file position
pos: u64,
/// File path and options to use to open it.
///
/// Note: this only contains the options needed to re-open it. For example,
@@ -559,21 +530,7 @@ impl VirtualFileInner {
path: P,
ctx: &RequestContext,
) -> Result<VirtualFileInner, std::io::Error> {
Self::open_with_options(path.as_ref(), OpenOptions::new().read(true), ctx).await
}
/// Create a new file for writing. If the file exists, it will be truncated.
/// Like File::create.
pub async fn create<P: AsRef<Utf8Path>>(
path: P,
ctx: &RequestContext,
) -> Result<VirtualFileInner, std::io::Error> {
Self::open_with_options(
path.as_ref(),
OpenOptions::new().write(true).create(true).truncate(true),
ctx,
)
.await
Self::open_with_options(path.as_ref(), OpenOptions::new().read(true).clone(), ctx).await
}
/// Open a file with given options.
@@ -583,7 +540,7 @@ impl VirtualFileInner {
/// on the first time. Make sure that's sane!
pub async fn open_with_options<P: AsRef<Utf8Path>>(
path: P,
open_options: &OpenOptions,
open_options: OpenOptions,
_ctx: &RequestContext,
) -> Result<VirtualFileInner, std::io::Error> {
let path = path.as_ref();
@@ -608,7 +565,6 @@ impl VirtualFileInner {
let vfile = VirtualFileInner {
handle: RwLock::new(handle),
pos: 0,
path: path.to_owned(),
open_options: reopen_options,
};
@@ -675,6 +631,13 @@ impl VirtualFileInner {
})
}
pub async fn set_len(&self, len: u64, _ctx: &RequestContext) -> Result<(), Error> {
with_file!(self, StorageIoOperation::SetLen, |file_guard| {
let (_file_guard, res) = io_engine::get().set_len(file_guard, len).await;
res.maybe_fatal_err("set_len")
})
}
/// Helper function internal to `VirtualFile` that looks up the underlying File,
/// opens it and evicts some other File if necessary. The passed parameter is
/// assumed to be a function available for the physical `File`.
@@ -742,38 +705,6 @@ impl VirtualFileInner {
})
}
pub fn remove(self) {
let path = self.path.clone();
drop(self);
std::fs::remove_file(path).expect("failed to remove the virtual file");
}
pub async fn seek(&mut self, pos: SeekFrom) -> Result<u64, Error> {
match pos {
SeekFrom::Start(offset) => {
self.pos = offset;
}
SeekFrom::End(offset) => {
self.pos = with_file!(self, StorageIoOperation::Seek, |mut file_guard| file_guard
.with_std_file_mut(|std_file| std_file.seek(SeekFrom::End(offset))))?
}
SeekFrom::Current(offset) => {
let pos = self.pos as i128 + offset as i128;
if pos < 0 {
return Err(Error::new(
ErrorKind::InvalidInput,
"offset would be negative",
));
}
if pos > u64::MAX as i128 {
return Err(Error::new(ErrorKind::InvalidInput, "offset overflow"));
}
self.pos = pos as u64;
}
}
Ok(self.pos)
}
/// Read the file contents in range `offset..(offset + slice.bytes_total())` into `slice[0..slice.bytes_total()]`.
///
/// The returned `Slice<Buf>` is equivalent to the input `slice`, i.e., it's the same view into the same buffer.
@@ -857,59 +788,7 @@ impl VirtualFileInner {
(restore(buf), Ok(()))
}
/// Writes `buf` to the file at the current offset.
///
/// Panics if there is an uninitialized range in `buf`, as that is most likely a bug in the caller.
pub async fn write_all<Buf: IoBuf + Send>(
&mut self,
buf: FullSlice<Buf>,
ctx: &RequestContext,
) -> (FullSlice<Buf>, Result<usize, Error>) {
let buf = buf.into_raw_slice();
let bounds = buf.bounds();
let restore =
|buf: Slice<_>| FullSlice::must_new(Slice::from_buf_bounds(buf.into_inner(), bounds));
let nbytes = buf.len();
let mut buf = buf;
while !buf.is_empty() {
let (tmp, res) = self.write(FullSlice::must_new(buf), ctx).await;
buf = tmp.into_raw_slice();
match res {
Ok(0) => {
return (
restore(buf),
Err(Error::new(
std::io::ErrorKind::WriteZero,
"failed to write whole buffer",
)),
);
}
Ok(n) => {
buf = buf.slice(n..);
}
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
Err(e) => return (restore(buf), Err(e)),
}
}
(restore(buf), Ok(nbytes))
}
async fn write<B: IoBuf + Send>(
&mut self,
buf: FullSlice<B>,
ctx: &RequestContext,
) -> (FullSlice<B>, Result<usize, std::io::Error>) {
let pos = self.pos;
let (buf, res) = self.write_at(buf, pos, ctx).await;
let n = match res {
Ok(n) => n,
Err(e) => return (buf, Err(e)),
};
self.pos += n as u64;
(buf, Ok(n))
}
pub(crate) async fn read_at<Buf>(
pub(super) async fn read_at<Buf>(
&self,
buf: tokio_epoll_uring::Slice<Buf>,
offset: u64,
@@ -937,23 +816,11 @@ impl VirtualFileInner {
})
}
/// The function aborts the process if the error is fatal.
async fn write_at<B: IoBuf + Send>(
&self,
buf: FullSlice<B>,
offset: u64,
ctx: &RequestContext,
) -> (FullSlice<B>, Result<usize, Error>) {
let (slice, result) = self.write_at_inner(buf, offset, ctx).await;
let result = result.maybe_fatal_err("write_at");
(slice, result)
}
async fn write_at_inner<B: IoBuf + Send>(
&self,
buf: FullSlice<B>,
offset: u64,
ctx: &RequestContext,
) -> (FullSlice<B>, Result<usize, Error>) {
let file_guard = match self.lock_file().await {
Ok(file_guard) => file_guard,
@@ -962,30 +829,13 @@ impl VirtualFileInner {
observe_duration!(StorageIoOperation::Write, {
let ((_file_guard, buf), result) =
io_engine::get().write_at(file_guard, offset, buf).await;
let result = result.maybe_fatal_err("write_at");
if let Ok(size) = result {
ctx.io_size_metrics().write.add(size.into_u64());
}
(buf, result)
})
}
async fn read_to_end(&mut self, buf: &mut Vec<u8>, ctx: &RequestContext) -> Result<(), Error> {
let mut tmp = vec![0; 128];
loop {
let slice = tmp.slice(..128);
let (slice, res) = self.read_at(slice, self.pos, ctx).await;
match res {
Ok(0) => return Ok(()),
Ok(n) => {
self.pos += n as u64;
buf.extend_from_slice(&slice[..n]);
}
Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
Err(e) => return Err(e),
}
tmp = slice.into_inner();
}
}
}
// Adapted from https://doc.rust-lang.org/1.72.0/src/std/os/unix/fs.rs.html#117-135
@@ -1200,19 +1050,6 @@ impl FileGuard {
let _ = file.into_raw_fd();
res
}
/// Soft deprecation: we'll move VirtualFile to async APIs and remove this function eventually.
fn with_std_file_mut<F, R>(&mut self, with: F) -> R
where
F: FnOnce(&mut File) -> R,
{
// SAFETY:
// - lifetime of the fd: `file` doesn't outlive the OwnedFd stored in `self`.
// - &mut usage below: `self` is `&mut`, hence this call is the only task/thread that has control over the underlying fd
let mut file = unsafe { File::from_raw_fd(self.as_ref().as_raw_fd()) };
let res = with(&mut file);
let _ = file.into_raw_fd();
res
}
}
impl tokio_epoll_uring::IoFd for FileGuard {
@@ -1302,6 +1139,9 @@ impl OwnedAsyncWriter for VirtualFile {
) -> (FullSlice<Buf>, std::io::Result<()>) {
VirtualFile::write_all_at(self, buf, offset, ctx).await
}
async fn set_len(&self, len: u64, ctx: &RequestContext) -> std::io::Result<()> {
VirtualFile::set_len(self, len, ctx).await
}
}
impl OpenFiles {
@@ -1366,10 +1206,9 @@ pub(crate) type IoBuffer = AlignedBuffer<ConstAlign<{ get_io_buffer_alignment()
pub(crate) type IoPageSlice<'a> =
AlignedSlice<'a, PAGE_SZ, ConstAlign<{ get_io_buffer_alignment() }>>;
static IO_MODE: once_cell::sync::Lazy<AtomicU8> =
once_cell::sync::Lazy::new(|| AtomicU8::new(IoMode::preferred() as u8));
static IO_MODE: LazyLock<AtomicU8> = LazyLock::new(|| AtomicU8::new(IoMode::preferred() as u8));
pub(crate) fn set_io_mode(mode: IoMode) {
pub fn set_io_mode(mode: IoMode) {
IO_MODE.store(mode as u8, std::sync::atomic::Ordering::Relaxed);
}
@@ -1381,7 +1220,6 @@ static SYNC_MODE: AtomicU8 = AtomicU8::new(SyncMode::Sync as u8);
#[cfg(test)]
mod tests {
use std::io::Write;
use std::os::unix::fs::FileExt;
use std::sync::Arc;
@@ -1434,43 +1272,6 @@ mod tests {
MaybeVirtualFile::File(file) => file.write_all_at(&buf[..], offset),
}
}
async fn seek(&mut self, pos: SeekFrom) -> Result<u64, Error> {
match self {
MaybeVirtualFile::VirtualFile(file) => file.seek(pos).await,
MaybeVirtualFile::File(file) => file.seek(pos),
}
}
async fn write_all<Buf: IoBuf + Send>(
&mut self,
buf: FullSlice<Buf>,
ctx: &RequestContext,
) -> Result<(), Error> {
match self {
MaybeVirtualFile::VirtualFile(file) => {
let (_buf, res) = file.write_all(buf, ctx).await;
res.map(|_| ())
}
MaybeVirtualFile::File(file) => file.write_all(&buf[..]),
}
}
// Helper function to slurp contents of a file, starting at the current position,
// into a string
async fn read_string(&mut self, ctx: &RequestContext) -> Result<String, Error> {
use std::io::Read;
let mut buf = String::new();
match self {
MaybeVirtualFile::VirtualFile(file) => {
let mut buf = Vec::new();
file.read_to_end(&mut buf, ctx).await?;
return Ok(String::from_utf8(buf).unwrap());
}
MaybeVirtualFile::File(file) => {
file.read_to_string(&mut buf)?;
}
}
Ok(buf)
}
// Helper function to slurp a portion of a file into a string
async fn read_string_at(
@@ -1506,7 +1307,7 @@ mod tests {
opts: OpenOptions,
ctx: &RequestContext,
) -> Result<MaybeVirtualFile, anyhow::Error> {
let vf = VirtualFile::open_with_options(&path, &opts, ctx).await?;
let vf = VirtualFile::open_with_options_v2(&path, &opts, ctx).await?;
Ok(MaybeVirtualFile::VirtualFile(vf))
}
}
@@ -1566,48 +1367,23 @@ mod tests {
.await?;
file_a
.write_all(b"foobar".to_vec().slice_len(), &ctx)
.write_all_at(IoBuffer::from(b"foobar").slice_len(), 0, &ctx)
.await?;
// cannot read from a file opened in write-only mode
let _ = file_a.read_string(&ctx).await.unwrap_err();
let _ = file_a.read_string_at(0, 1, &ctx).await.unwrap_err();
// Close the file and re-open for reading
let mut file_a = A::open(path_a, OpenOptions::new().read(true).to_owned(), &ctx).await?;
// cannot write to a file opened in read-only mode
let _ = file_a
.write_all(b"bar".to_vec().slice_len(), &ctx)
.write_all_at(IoBuffer::from(b"bar").slice_len(), 0, &ctx)
.await
.unwrap_err();
// Try simple read
assert_eq!("foobar", file_a.read_string(&ctx).await?);
// It's positioned at the EOF now.
assert_eq!("", file_a.read_string(&ctx).await?);
// Test seeks.
assert_eq!(file_a.seek(SeekFrom::Start(1)).await?, 1);
assert_eq!("oobar", file_a.read_string(&ctx).await?);
assert_eq!(file_a.seek(SeekFrom::End(-2)).await?, 4);
assert_eq!("ar", file_a.read_string(&ctx).await?);
assert_eq!(file_a.seek(SeekFrom::Start(1)).await?, 1);
assert_eq!(file_a.seek(SeekFrom::Current(2)).await?, 3);
assert_eq!("bar", file_a.read_string(&ctx).await?);
assert_eq!(file_a.seek(SeekFrom::Current(-5)).await?, 1);
assert_eq!("oobar", file_a.read_string(&ctx).await?);
// Test erroneous seeks to before byte 0
file_a.seek(SeekFrom::End(-7)).await.unwrap_err();
assert_eq!(file_a.seek(SeekFrom::Start(1)).await?, 1);
file_a.seek(SeekFrom::Current(-2)).await.unwrap_err();
// the erroneous seek should have left the position unchanged
assert_eq!("oobar", file_a.read_string(&ctx).await?);
assert_eq!("foobar", file_a.read_string_at(0, 6, &ctx).await?);
// Create another test file, and try FileExt functions on it.
let path_b = testdir.join("file_b");
@@ -1633,9 +1409,6 @@ mod tests {
// Open a lot of files, enough to cause some evictions. (Or to be precise,
// open the same file many times. The effect is the same.)
//
// leave file_a positioned at offset 1 before we start
assert_eq!(file_a.seek(SeekFrom::Start(1)).await?, 1);
let mut vfiles = Vec::new();
for _ in 0..100 {
@@ -1645,7 +1418,7 @@ mod tests {
&ctx,
)
.await?;
assert_eq!("FOOBAR", vfile.read_string(&ctx).await?);
assert_eq!("FOOBAR", vfile.read_string_at(0, 6, &ctx).await?);
vfiles.push(vfile);
}
@@ -1653,8 +1426,8 @@ mod tests {
assert!(vfiles.len() > TEST_MAX_FILE_DESCRIPTORS * 2);
// The underlying file descriptor for 'file_a' should be closed now. Try to read
// from it again. We left the file positioned at offset 1 above.
assert_eq!("oobar", file_a.read_string(&ctx).await?);
// from it again.
assert_eq!("foobar", file_a.read_string_at(0, 6, &ctx).await?);
// Check that all the other FDs still work too. Use them in random order for
// good measure.
@@ -1693,7 +1466,7 @@ mod tests {
for _ in 0..VIRTUAL_FILES {
let f = VirtualFileInner::open_with_options(
&test_file_path,
OpenOptions::new().read(true),
OpenOptions::new().read(true).clone(),
&ctx,
)
.await?;
@@ -1748,7 +1521,7 @@ mod tests {
.await
.unwrap();
let mut file = MaybeVirtualFile::from(VirtualFile::open(&path, &ctx).await.unwrap());
let post = file.read_string(&ctx).await.unwrap();
let post = file.read_string_at(0, 3, &ctx).await.unwrap();
assert_eq!(post, "foo");
assert!(!tmp_path.exists());
drop(file);
@@ -1757,7 +1530,7 @@ mod tests {
.await
.unwrap();
let mut file = MaybeVirtualFile::from(VirtualFile::open(&path, &ctx).await.unwrap());
let post = file.read_string(&ctx).await.unwrap();
let post = file.read_string_at(0, 3, &ctx).await.unwrap();
assert_eq!(post, "bar");
assert!(!tmp_path.exists());
drop(file);
@@ -1782,7 +1555,7 @@ mod tests {
.unwrap();
let mut file = MaybeVirtualFile::from(VirtualFile::open(&path, &ctx).await.unwrap());
let post = file.read_string(&ctx).await.unwrap();
let post = file.read_string_at(0, 3, &ctx).await.unwrap();
assert_eq!(post, "foo");
assert!(!tmp_path.exists());
drop(file);

View File

@@ -209,6 +209,27 @@ impl IoEngine {
}
}
}
pub(super) async fn set_len(
&self,
file_guard: FileGuard,
len: u64,
) -> (FileGuard, std::io::Result<()>) {
match self {
IoEngine::NotSet => panic!("not initialized"),
IoEngine::StdFs => {
let res = file_guard.with_std_file(|std_file| std_file.set_len(len));
(file_guard, res)
}
#[cfg(target_os = "linux")]
IoEngine::TokioEpollUring => {
// TODO: ftruncate op for tokio-epoll-uring
let res = file_guard.with_std_file(|std_file| std_file.set_len(len));
(file_guard, res)
}
}
}
pub(super) async fn write_at<B: IoBuf + Send>(
&self,
file_guard: FileGuard,

View File

@@ -6,7 +6,12 @@ use std::path::Path;
use super::io_engine::IoEngine;
#[derive(Debug, Clone)]
pub enum OpenOptions {
pub struct OpenOptions {
write: bool,
inner: Inner,
}
#[derive(Debug, Clone)]
enum Inner {
StdFs(std::fs::OpenOptions),
#[cfg(target_os = "linux")]
TokioEpollUring(tokio_epoll_uring::ops::open_at::OpenOptions),
@@ -14,13 +19,17 @@ pub enum OpenOptions {
impl Default for OpenOptions {
fn default() -> Self {
match super::io_engine::get() {
let inner = match super::io_engine::get() {
IoEngine::NotSet => panic!("io engine not set"),
IoEngine::StdFs => Self::StdFs(std::fs::OpenOptions::new()),
IoEngine::StdFs => Inner::StdFs(std::fs::OpenOptions::new()),
#[cfg(target_os = "linux")]
IoEngine::TokioEpollUring => {
Self::TokioEpollUring(tokio_epoll_uring::ops::open_at::OpenOptions::new())
Inner::TokioEpollUring(tokio_epoll_uring::ops::open_at::OpenOptions::new())
}
};
Self {
write: false,
inner,
}
}
}
@@ -30,13 +39,17 @@ impl OpenOptions {
Self::default()
}
pub(super) fn is_write(&self) -> bool {
self.write
}
pub fn read(&mut self, read: bool) -> &mut OpenOptions {
match self {
OpenOptions::StdFs(x) => {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.read(read);
}
#[cfg(target_os = "linux")]
OpenOptions::TokioEpollUring(x) => {
Inner::TokioEpollUring(x) => {
let _ = x.read(read);
}
}
@@ -44,12 +57,13 @@ impl OpenOptions {
}
pub fn write(&mut self, write: bool) -> &mut OpenOptions {
match self {
OpenOptions::StdFs(x) => {
self.write = write;
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.write(write);
}
#[cfg(target_os = "linux")]
OpenOptions::TokioEpollUring(x) => {
Inner::TokioEpollUring(x) => {
let _ = x.write(write);
}
}
@@ -57,12 +71,12 @@ impl OpenOptions {
}
pub fn create(&mut self, create: bool) -> &mut OpenOptions {
match self {
OpenOptions::StdFs(x) => {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.create(create);
}
#[cfg(target_os = "linux")]
OpenOptions::TokioEpollUring(x) => {
Inner::TokioEpollUring(x) => {
let _ = x.create(create);
}
}
@@ -70,12 +84,12 @@ impl OpenOptions {
}
pub fn create_new(&mut self, create_new: bool) -> &mut OpenOptions {
match self {
OpenOptions::StdFs(x) => {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.create_new(create_new);
}
#[cfg(target_os = "linux")]
OpenOptions::TokioEpollUring(x) => {
Inner::TokioEpollUring(x) => {
let _ = x.create_new(create_new);
}
}
@@ -83,12 +97,12 @@ impl OpenOptions {
}
pub fn truncate(&mut self, truncate: bool) -> &mut OpenOptions {
match self {
OpenOptions::StdFs(x) => {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.truncate(truncate);
}
#[cfg(target_os = "linux")]
OpenOptions::TokioEpollUring(x) => {
Inner::TokioEpollUring(x) => {
let _ = x.truncate(truncate);
}
}
@@ -96,10 +110,10 @@ impl OpenOptions {
}
pub(in crate::virtual_file) async fn open(&self, path: &Path) -> std::io::Result<OwnedFd> {
match self {
OpenOptions::StdFs(x) => x.open(path).map(|file| file.into()),
match &self.inner {
Inner::StdFs(x) => x.open(path).map(|file| file.into()),
#[cfg(target_os = "linux")]
OpenOptions::TokioEpollUring(x) => {
Inner::TokioEpollUring(x) => {
let system = super::io_engine::tokio_epoll_uring_ext::thread_local_system().await;
system.open(path, x).await.map_err(|e| match e {
tokio_epoll_uring::Error::Op(e) => e,
@@ -114,12 +128,12 @@ impl OpenOptions {
impl std::os::unix::prelude::OpenOptionsExt for OpenOptions {
fn mode(&mut self, mode: u32) -> &mut OpenOptions {
match self {
OpenOptions::StdFs(x) => {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.mode(mode);
}
#[cfg(target_os = "linux")]
OpenOptions::TokioEpollUring(x) => {
Inner::TokioEpollUring(x) => {
let _ = x.mode(mode);
}
}
@@ -127,12 +141,12 @@ impl std::os::unix::prelude::OpenOptionsExt for OpenOptions {
}
fn custom_flags(&mut self, flags: i32) -> &mut OpenOptions {
match self {
OpenOptions::StdFs(x) => {
match &mut self.inner {
Inner::StdFs(x) => {
let _ = x.custom_flags(flags);
}
#[cfg(target_os = "linux")]
OpenOptions::TokioEpollUring(x) => {
Inner::TokioEpollUring(x) => {
let _ = x.custom_flags(flags);
}
}

View File

@@ -282,6 +282,17 @@ unsafe impl<A: Alignment> tokio_epoll_uring::IoBufMut for AlignedBufferMut<A> {
}
}
impl<A: Alignment> std::io::Write for AlignedBufferMut<A> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.extend_from_slice(buf);
Ok(buf.len())
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
#[cfg(test)]
mod tests {

View File

@@ -1,15 +1,19 @@
mod flush;
use std::sync::Arc;
use bytes::BufMut;
pub(crate) use flush::FlushControl;
use flush::FlushHandle;
pub(crate) use flush::FlushTaskError;
use flush::ShutdownRequest;
use tokio_epoll_uring::IoBuf;
use tokio_util::sync::CancellationToken;
use tracing::trace;
use super::io_buf_aligned::IoBufAligned;
use super::io_buf_aligned::IoBufAlignedMut;
use super::io_buf_ext::{FullSlice, IoBufExt};
use crate::context::RequestContext;
use crate::virtual_file::UsizeIsU64;
use crate::virtual_file::{IoBuffer, IoBufferMut};
pub(crate) trait CheapCloneForRead {
@@ -34,14 +38,50 @@ pub trait OwnedAsyncWriter {
offset: u64,
ctx: &RequestContext,
) -> impl std::future::Future<Output = (FullSlice<Buf>, std::io::Result<()>)> + Send;
fn set_len(
&self,
len: u64,
ctx: &RequestContext,
) -> impl Future<Output = std::io::Result<()>> + Send;
}
/// A wrapper aorund an [`OwnedAsyncWriter`] that uses a [`Buffer`] to batch
/// small writes into larger writes of size [`Buffer::cap`].
///
/// The buffer is flushed if and only if it is full ([`Buffer::pending`] == [`Buffer::cap`]).
/// This guarantees that writes to the filesystem happen
/// - at offsets that are multiples of [`Buffer::cap`]
/// - in lengths that are multiples of [`Buffer::cap`]
///
/// Above property is useful for Direct IO, where whatever the
/// effectively dominating disk-sector/filesystem-block/memory-page size
/// determines the requirements on
/// - the alignment of the pointer passed to the read/write operation
/// - the value of `count` (i.e., the length of the read/write operation)
/// which must be a multiple of the dominating sector/block/page size.
///
/// See [`BufferedWriter::shutdown`] / [`BufferedWriterShutdownMode`] for different
/// ways of dealing with the special case that the buffer is not full by the time
/// we are done writing.
///
/// The first flush to the underlying `W` happens at offset `start_offset` (arg of [`BufferedWriter::new`]).
/// The next flush is to offset `start_offset + Buffer::cap`. The one after at `start_offset + 2 * Buffer::cap` and so on.
///
/// TODO: decouple buffer capacity from alignment requirement.
/// Right now we assume [`Buffer::cap`] is the alignment requirement,
/// but actually [`Buffer::cap`] should only determine how often we flush
/// while writing, while a separate alignment requirement argument should
/// be passed to determine alignment requirement. This could be used by
/// [`BufferedWriterShutdownMode::PadThenTruncate`] to avoid excessive
/// padding of zeroes. For example, today, with a capacity of 64KiB, we
/// would pad up to 64KiB-1 bytes of zeroes, then truncate off 64KiB-1.
/// This is wasteful, e.g., if the alignment requirement is 4KiB, we only
/// need to pad & truncate up to 4KiB-1 bytes of zeroes
///
// TODO(yuchen): For large write, implementing buffer bypass for aligned parts of the write could be beneficial to throughput,
// since we would avoid copying majority of the data into the internal buffer.
// https://github.com/neondatabase/neon/issues/10101
pub struct BufferedWriter<B: Buffer, W> {
writer: Arc<W>,
/// Clone of the buffer that was last submitted to the flush loop.
/// `None` if no flush request has been submitted, Some forever after.
pub(super) maybe_flushed: Option<FullSlice<B::IoBuf>>,
@@ -62,9 +102,24 @@ pub struct BufferedWriter<B: Buffer, W> {
bytes_submitted: u64,
}
/// How [`BufferedWriter::shutdown`] should deal with pending (=not-yet-flushed) data.
///
/// Cf the [`BufferedWriter`] comment's paragraph for context on why we need to think about this.
pub enum BufferedWriterShutdownMode {
/// Drop pending data, don't write back to file.
DropTail,
/// Pad the pending data with zeroes (cf [`usize::next_multiple_of`]).
ZeroPadToNextMultiple(usize),
/// Fill the IO buffer with zeroes, flush to disk, the `ftruncate` the
/// file to the exact number of bytes written to [`Self`].
///
/// TODO: see in [`BufferedWriter`] comment about decoupling buffer capacity from alignment requirement.
PadThenTruncate,
}
impl<B, Buf, W> BufferedWriter<B, W>
where
B: Buffer<IoBuf = Buf> + Send + 'static,
B: IoBufAlignedMut + Buffer<IoBuf = Buf> + Send + 'static,
Buf: IoBufAligned + Send + Sync + CheapCloneForRead,
W: OwnedAsyncWriter + Send + Sync + 'static + std::fmt::Debug,
{
@@ -72,7 +127,8 @@ where
///
/// The `buf_new` function provides a way to initialize the owned buffers used by this writer.
pub fn new(
writer: Arc<W>,
writer: W,
start_offset: u64,
buf_new: impl Fn() -> B,
gate_guard: utils::sync::gate::GateGuard,
cancel: CancellationToken,
@@ -80,7 +136,6 @@ where
flush_task_span: tracing::Span,
) -> Self {
Self {
writer: writer.clone(),
mutable: Some(buf_new()),
maybe_flushed: None,
flush_handle: FlushHandle::spawn_new(
@@ -91,14 +146,10 @@ where
ctx.attached_child(),
flush_task_span,
),
bytes_submitted: 0,
bytes_submitted: start_offset,
}
}
pub fn as_inner(&self) -> &W {
&self.writer
}
/// Returns the number of bytes submitted to the background flush task.
pub fn bytes_submitted(&self) -> u64 {
self.bytes_submitted
@@ -116,22 +167,80 @@ where
}
#[cfg_attr(target_os = "macos", allow(dead_code))]
pub async fn flush_and_into_inner(
pub async fn shutdown(
mut self,
mode: BufferedWriterShutdownMode,
ctx: &RequestContext,
) -> Result<(u64, Arc<W>), FlushTaskError> {
self.flush(ctx).await?;
) -> Result<(u64, W), FlushTaskError> {
let mut mutable = self.mutable.take().expect("must not use after an error");
let unpadded_pending = mutable.pending();
let final_len: u64;
let shutdown_req;
match mode {
BufferedWriterShutdownMode::DropTail => {
trace!(pending=%mutable.pending(), "dropping pending data");
drop(mutable);
final_len = self.bytes_submitted;
shutdown_req = ShutdownRequest { set_len: None };
}
BufferedWriterShutdownMode::ZeroPadToNextMultiple(next_multiple) => {
let len = mutable.pending();
let cap = mutable.cap();
assert!(
len <= cap,
"buffer impl ensures this, but let's check because the extend_with below would panic if we go beyond"
);
let padded_len = len.next_multiple_of(next_multiple);
assert!(
padded_len <= cap,
"caller specified a multiple that is larger than the buffer capacity"
);
let count = padded_len - len;
mutable.extend_with(0, count);
trace!(count, "padding with zeros");
self.mutable = Some(mutable);
final_len = self.bytes_submitted + padded_len.into_u64();
shutdown_req = ShutdownRequest { set_len: None };
}
BufferedWriterShutdownMode::PadThenTruncate => {
let len = mutable.pending();
let cap = mutable.cap();
// TODO: see struct comment TODO on decoupling buffer capacity from alignment requirement.
let alignment_requirement = cap;
assert!(len <= cap, "buffer impl should ensure this");
let padding_end_offset = len.next_multiple_of(alignment_requirement);
assert!(
padding_end_offset <= cap,
"{padding_end_offset} <= {cap} ({alignment_requirement})"
);
let count = padding_end_offset - len;
mutable.extend_with(0, count);
trace!(count, "padding with zeros");
self.mutable = Some(mutable);
final_len = self.bytes_submitted + len.into_u64();
shutdown_req = ShutdownRequest {
// Avoid set_len call if we didn't need to pad anything.
set_len: if count > 0 { Some(final_len) } else { None },
};
}
};
let padded_pending = self.mutable.as_ref().map(|b| b.pending());
trace!(unpadded_pending, padded_pending, "padding done");
if self.mutable.is_some() {
self.flush(ctx).await?;
}
let Self {
mutable: buf,
mutable: _,
maybe_flushed: _,
writer,
mut flush_handle,
bytes_submitted: bytes_amount,
bytes_submitted: _,
} = self;
flush_handle.shutdown().await?;
assert!(buf.is_some());
Ok((bytes_amount, writer))
let writer = flush_handle.shutdown(shutdown_req).await?;
Ok((final_len, writer))
}
#[cfg(test)]
@@ -235,6 +344,10 @@ pub trait Buffer {
/// panics if `other.len() > self.cap() - self.pending()`.
fn extend_from_slice(&mut self, other: &[u8]);
/// Add `count` bytes `val` into `self`.
/// Panics if `count > self.cap() - self.pending()`.
fn extend_with(&mut self, val: u8, count: usize);
/// Number of bytes in the buffer.
fn pending(&self) -> usize;
@@ -262,6 +375,14 @@ impl Buffer for IoBufferMut {
IoBufferMut::extend_from_slice(self, other);
}
fn extend_with(&mut self, val: u8, count: usize) {
if self.len() + count > self.cap() {
panic!("Buffer capacity exceeded");
}
IoBufferMut::put_bytes(self, val, count);
}
fn pending(&self) -> usize {
self.len()
}
@@ -284,26 +405,22 @@ impl Buffer for IoBufferMut {
mod tests {
use std::sync::Mutex;
use rstest::rstest;
use super::*;
use crate::context::{DownloadBehavior, RequestContext};
use crate::task_mgr::TaskKind;
#[derive(Debug, PartialEq, Eq)]
enum Op {
Write { buf: Vec<u8>, offset: u64 },
SetLen { len: u64 },
}
#[derive(Default, Debug)]
struct RecorderWriter {
/// record bytes and write offsets.
writes: Mutex<Vec<(Vec<u8>, u64)>>,
}
impl RecorderWriter {
/// Gets recorded bytes and write offsets.
fn get_writes(&self) -> Vec<Vec<u8>> {
self.writes
.lock()
.unwrap()
.iter()
.map(|(buf, _)| buf.clone())
.collect()
}
recording: Mutex<Vec<Op>>,
}
impl OwnedAsyncWriter for RecorderWriter {
@@ -313,28 +430,42 @@ mod tests {
offset: u64,
_: &RequestContext,
) -> (FullSlice<Buf>, std::io::Result<()>) {
self.writes
.lock()
.unwrap()
.push((Vec::from(&buf[..]), offset));
self.recording.lock().unwrap().push(Op::Write {
buf: Vec::from(&buf[..]),
offset,
});
(buf, Ok(()))
}
async fn set_len(&self, len: u64, _ctx: &RequestContext) -> std::io::Result<()> {
self.recording.lock().unwrap().push(Op::SetLen { len });
Ok(())
}
}
fn test_ctx() -> RequestContext {
RequestContext::new(TaskKind::UnitTest, DownloadBehavior::Error)
}
#[rstest]
#[tokio::test]
async fn test_write_all_borrowed_always_goes_through_buffer() -> anyhow::Result<()> {
async fn test_write_all_borrowed_always_goes_through_buffer(
#[values(
BufferedWriterShutdownMode::DropTail,
BufferedWriterShutdownMode::ZeroPadToNextMultiple(2),
BufferedWriterShutdownMode::PadThenTruncate
)]
mode: BufferedWriterShutdownMode,
) -> anyhow::Result<()> {
let ctx = test_ctx();
let ctx = &ctx;
let recorder = Arc::new(RecorderWriter::default());
let recorder = RecorderWriter::default();
let gate = utils::sync::gate::Gate::default();
let cancel = CancellationToken::new();
let cap = 4;
let mut writer = BufferedWriter::<_, RecorderWriter>::new(
recorder,
|| IoBufferMut::with_capacity(2),
0,
|| IoBufferMut::with_capacity(cap),
gate.enter()?,
cancel,
ctx,
@@ -344,23 +475,89 @@ mod tests {
writer.write_buffered_borrowed(b"abc", ctx).await?;
writer.write_buffered_borrowed(b"", ctx).await?;
writer.write_buffered_borrowed(b"d", ctx).await?;
writer.write_buffered_borrowed(b"e", ctx).await?;
writer.write_buffered_borrowed(b"fg", ctx).await?;
writer.write_buffered_borrowed(b"hi", ctx).await?;
writer.write_buffered_borrowed(b"j", ctx).await?;
writer.write_buffered_borrowed(b"klmno", ctx).await?;
writer.write_buffered_borrowed(b"efg", ctx).await?;
writer.write_buffered_borrowed(b"hijklm", ctx).await?;
let (_, recorder) = writer.flush_and_into_inner(ctx).await?;
assert_eq!(
recorder.get_writes(),
{
let expect: &[&[u8]] = &[b"ab", b"cd", b"ef", b"gh", b"ij", b"kl", b"mn", b"o"];
expect
let mut expect = {
[(0, b"abcd"), (4, b"efgh"), (8, b"ijkl")]
.into_iter()
.map(|(offset, v)| Op::Write {
offset,
buf: v[..].to_vec(),
})
.collect::<Vec<_>>()
};
let expect_next_offset = 12;
match &mode {
BufferedWriterShutdownMode::DropTail => (),
// We test the case with padding to next multiple of 2 so that it's different
// from the alignment requirement of 4 inferred from buffer capacity.
// See TODOs in the `BufferedWriter` struct comment on decoupling buffer capacity from alignment requirement.
BufferedWriterShutdownMode::ZeroPadToNextMultiple(2) => {
expect.push(Op::Write {
offset: expect_next_offset,
// it's legitimate for pad-to-next multiple 2 to be < alignment requirement 4 inferred from buffer capacity
buf: b"m\0".to_vec(),
});
}
.iter()
.map(|v| v[..].to_vec())
.collect::<Vec<_>>()
BufferedWriterShutdownMode::ZeroPadToNextMultiple(_) => unimplemented!(),
BufferedWriterShutdownMode::PadThenTruncate => {
expect.push(Op::Write {
offset: expect_next_offset,
buf: b"m\0\0\0".to_vec(),
});
expect.push(Op::SetLen { len: 13 });
}
}
let (_, recorder) = writer.shutdown(mode, ctx).await?;
assert_eq!(&*recorder.recording.lock().unwrap(), &expect);
Ok(())
}
#[tokio::test]
async fn test_set_len_is_skipped_if_not_needed() -> anyhow::Result<()> {
let ctx = test_ctx();
let ctx = &ctx;
let recorder = RecorderWriter::default();
let gate = utils::sync::gate::Gate::default();
let cancel = CancellationToken::new();
let cap = 4;
let mut writer = BufferedWriter::<_, RecorderWriter>::new(
recorder,
0,
|| IoBufferMut::with_capacity(cap),
gate.enter()?,
cancel,
ctx,
tracing::Span::none(),
);
// write a multiple of `cap`
writer.write_buffered_borrowed(b"abc", ctx).await?;
writer.write_buffered_borrowed(b"defgh", ctx).await?;
let (_, recorder) = writer
.shutdown(BufferedWriterShutdownMode::PadThenTruncate, ctx)
.await?;
let expect = {
[(0, b"abcd"), (4, b"efgh")]
.into_iter()
.map(|(offset, v)| Op::Write {
offset,
buf: v[..].to_vec(),
})
.collect::<Vec<_>>()
};
assert_eq!(
&*recorder.recording.lock().unwrap(),
&expect,
"set_len should not be called if the buffer is already aligned"
);
Ok(())
}
}

View File

@@ -1,8 +1,7 @@
use std::ops::ControlFlow;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, info, info_span, warn};
use tracing::{Instrument, info_span, warn};
use utils::sync::duplex;
use super::{Buffer, CheapCloneForRead, OwnedAsyncWriter};
@@ -19,18 +18,36 @@ pub struct FlushHandle<Buf, W> {
pub struct FlushHandleInner<Buf, W> {
/// A bi-directional channel that sends (buffer, offset) for writes,
/// and receives recyled buffer.
channel: duplex::mpsc::Duplex<FlushRequest<Buf>, FullSlice<Buf>>,
channel: duplex::mpsc::Duplex<Request<Buf>, FullSlice<Buf>>,
/// Join handle for the background flush task.
join_handle: tokio::task::JoinHandle<Result<Arc<W>, FlushTaskError>>,
join_handle: tokio::task::JoinHandle<Result<W, FlushTaskError>>,
}
struct FlushRequest<Buf> {
slice: FullSlice<Buf>,
offset: u64,
#[cfg(test)]
ready_to_flush_rx: tokio::sync::oneshot::Receiver<()>,
ready_to_flush_rx: Option<tokio::sync::oneshot::Receiver<()>>,
#[cfg(test)]
done_flush_tx: tokio::sync::oneshot::Sender<()>,
done_flush_tx: Option<tokio::sync::oneshot::Sender<()>>,
}
pub struct ShutdownRequest {
pub set_len: Option<u64>,
}
enum Request<Buf> {
Flush(FlushRequest<Buf>),
Shutdown(ShutdownRequest),
}
impl<Buf> Request<Buf> {
fn op_str(&self) -> &'static str {
match self {
Request::Flush(_) => "flush",
Request::Shutdown(_) => "shutdown",
}
}
}
/// Constructs a request and a control object for a new flush operation.
@@ -52,8 +69,8 @@ fn new_flush_op<Buf>(slice: FullSlice<Buf>, offset: u64) -> (FlushRequest<Buf>,
let request = FlushRequest {
slice,
offset,
ready_to_flush_rx,
done_flush_tx,
ready_to_flush_rx: Some(ready_to_flush_rx),
done_flush_tx: Some(done_flush_tx),
};
(request, control)
}
@@ -120,7 +137,7 @@ where
/// The queue depth is 1, and the passed-in `buf` seeds the queue depth.
/// I.e., the passed-in buf is immediately available to the handle as a recycled buffer.
pub fn spawn_new<B>(
file: Arc<W>,
file: W,
buf: B,
gate_guard: utils::sync::gate::GateGuard,
cancel: CancellationToken,
@@ -160,10 +177,7 @@ where
let (request, flush_control) = new_flush_op(slice, offset);
// Submits the buffer to the background task.
let submit = self.inner_mut().channel.send(request).await;
if submit.is_err() {
return self.handle_error().await;
}
self.send(Request::Flush(request)).await?;
// Wait for an available buffer from the background flush task.
// This is the BACKPRESSURE mechanism: if the flush task can't keep up,
@@ -175,15 +189,28 @@ where
Ok((recycled, flush_control))
}
/// Sends poison pill to flush task and waits for it to exit.
pub async fn shutdown(&mut self, req: ShutdownRequest) -> Result<W, FlushTaskError> {
self.send(Request::Shutdown(req)).await?;
self.wait().await
}
async fn send(&mut self, request: Request<Buf>) -> Result<(), FlushTaskError> {
let submit = self.inner_mut().channel.send(request).await;
if submit.is_err() {
return self.handle_error().await;
}
Ok(())
}
async fn handle_error<T>(&mut self) -> Result<T, FlushTaskError> {
Err(self
.shutdown()
.wait()
.await
.expect_err("flush task only disconnects duplex if it exits with an error"))
}
/// Cleans up the channel, join the flush task.
pub async fn shutdown(&mut self) -> Result<Arc<W>, FlushTaskError> {
async fn wait(&mut self) -> Result<W, FlushTaskError> {
let handle = self
.inner
.take()
@@ -205,9 +232,9 @@ where
pub struct FlushBackgroundTask<Buf, W> {
/// A bi-directional channel that receives (buffer, offset) for writes,
/// and send back recycled buffer.
channel: duplex::mpsc::Duplex<FullSlice<Buf>, FlushRequest<Buf>>,
channel: duplex::mpsc::Duplex<FullSlice<Buf>, Request<Buf>>,
/// A writter for persisting data to disk.
writer: Arc<W>,
writer: W,
ctx: RequestContext,
cancel: CancellationToken,
/// Prevent timeline from shuting down until the flush background task finishes flushing all remaining buffers to disk.
@@ -227,8 +254,8 @@ where
{
/// Creates a new background flush task.
fn new(
channel: duplex::mpsc::Duplex<FullSlice<Buf>, FlushRequest<Buf>>,
file: Arc<W>,
channel: duplex::mpsc::Duplex<FullSlice<Buf>, Request<Buf>>,
file: W,
gate_guard: utils::sync::gate::GateGuard,
cancel: CancellationToken,
ctx: RequestContext,
@@ -243,18 +270,12 @@ where
}
/// Runs the background flush task.
async fn run(mut self) -> Result<Arc<W>, FlushTaskError> {
async fn run(mut self) -> Result<W, FlushTaskError> {
// Exit condition: channel is closed and there is no remaining buffer to be flushed
while let Some(request) = self.channel.recv().await {
#[cfg(test)]
{
// In test, wait for control to signal that we are ready to flush.
if request.ready_to_flush_rx.await.is_err() {
tracing::debug!("control dropped");
}
}
let op_kind = request.op_str();
// Write slice to disk at `offset`.
// Perform the requested operation.
//
// Error handling happens according to the current policy of crashing
// on fatal IO errors and retrying in place otherwise (deeming all other errors retryable).
@@ -263,52 +284,112 @@ where
//
// TODO: use utils::backoff::retry once async closures are actually usable
//
let mut slice_storage = Some(request.slice);
let mut request_storage = Some(request);
for attempt in 1.. {
if self.cancel.is_cancelled() {
return Err(FlushTaskError::Cancelled);
}
let result = async {
if attempt > 1 {
info!("retrying flush");
}
let slice = slice_storage.take().expect(
let request: Request<Buf> = request_storage .take().expect(
"likely previous invocation of this future didn't get polled to completion",
);
// Don't cancel this write by doing tokio::select with self.cancel.cancelled().
match &request {
Request::Shutdown(ShutdownRequest { set_len: None }) => {
request_storage = Some(request);
return ControlFlow::Break(());
},
Request::Flush(_) | Request::Shutdown(ShutdownRequest { set_len: Some(_) }) => {
},
}
if attempt > 1 {
warn!(op=%request.op_str(), "retrying");
}
// borrows so we can async move the requests into async block while not moving these borrows here
let writer = &self.writer;
let request_storage = &mut request_storage;
let ctx = &self.ctx;
let io_fut = match request {
Request::Flush(FlushRequest { slice, offset, #[cfg(test)] ready_to_flush_rx, #[cfg(test)] done_flush_tx }) => futures::future::Either::Left(async move {
#[cfg(test)]
if let Some(ready_to_flush_rx) = ready_to_flush_rx {
{
// In test, wait for control to signal that we are ready to flush.
if ready_to_flush_rx.await.is_err() {
tracing::debug!("control dropped");
}
}
}
let (slice, res) = writer.write_all_at(slice, offset, ctx).await;
*request_storage = Some(Request::Flush(FlushRequest {
slice,
offset,
#[cfg(test)]
ready_to_flush_rx: None, // the contract is that we notify before first attempt
#[cfg(test)]
done_flush_tx
}));
res
}),
Request::Shutdown(ShutdownRequest { set_len }) => futures::future::Either::Right(async move {
let set_len = set_len.expect("we filter out the None case above");
let res = writer.set_len(set_len, ctx).await;
*request_storage = Some(Request::Shutdown(ShutdownRequest {
set_len: Some(set_len),
}));
res
}),
};
// Don't cancel the io_fut by doing tokio::select with self.cancel.cancelled().
// The underlying tokio-epoll-uring slot / kernel operation is still ongoing and occupies resources.
// If we retry indefinitely, we'll deplete those resources.
// Future: teach tokio-epoll-uring io_uring operation cancellation, but still,
// wait for cancelled ops to complete and discard their error.
let (slice, res) = self.writer.write_all_at(slice, request.offset, &self.ctx).await;
slice_storage = Some(slice);
let res = io_fut.await;
let res = res.maybe_fatal_err("owned_buffers_io flush");
let Err(err) = res else {
if attempt > 1 {
warn!(op=%op_kind, "retry succeeded");
}
return ControlFlow::Break(());
};
warn!(%err, "error flushing buffered writer buffer to disk, retrying after backoff");
utils::backoff::exponential_backoff(attempt, 1.0, 10.0, &self.cancel).await;
ControlFlow::Continue(())
}
.instrument(info_span!("flush_attempt", %attempt))
.instrument(info_span!("attempt", %attempt, %op_kind))
.await;
match result {
ControlFlow::Break(()) => break,
ControlFlow::Continue(()) => continue,
}
}
let slice = slice_storage.expect("loop must have run at least once");
let request = request_storage.expect("loop must have run at least once");
#[cfg(test)]
{
// In test, tell control we are done flushing buffer.
if request.done_flush_tx.send(()).is_err() {
tracing::debug!("control dropped");
let slice = match request {
Request::Flush(FlushRequest {
slice,
#[cfg(test)]
mut done_flush_tx,
..
}) => {
#[cfg(test)]
{
// In test, tell control we are done flushing buffer.
if done_flush_tx.take().expect("always Some").send(()).is_err() {
tracing::debug!("control dropped");
}
}
slice
}
}
Request::Shutdown(_) => {
// next iteration will observe recv() returning None
continue;
}
};
// Sends the buffer back to the handle for reuse. The handle is in charged of cleaning the buffer.
if self.channel.send(slice).await.is_err() {
let send_res = self.channel.send(slice).await;
if send_res.is_err() {
// Although channel is closed. Still need to finish flushing the remaining buffers.
continue;
}

View File

@@ -0,0 +1,110 @@
use tracing::error;
use utils::sync::gate::GateGuard;
use crate::context::RequestContext;
use super::{
MaybeFatalIo, VirtualFile,
owned_buffers_io::{
io_buf_aligned::IoBufAligned, io_buf_ext::FullSlice, write::OwnedAsyncWriter,
},
};
/// A wrapper around [`super::VirtualFile`] that deletes the file on drop.
/// For use as a [`OwnedAsyncWriter`] in [`super::owned_buffers_io::write::BufferedWriter`].
#[derive(Debug)]
pub struct TempVirtualFile {
inner: Option<Inner>,
}
#[derive(Debug)]
struct Inner {
file: VirtualFile,
/// Gate guard is held on as long as we need to do operations in the path (delete on drop)
_gate_guard: GateGuard,
}
impl OwnedAsyncWriter for TempVirtualFile {
fn write_all_at<Buf: IoBufAligned + Send>(
&self,
buf: FullSlice<Buf>,
offset: u64,
ctx: &RequestContext,
) -> impl std::future::Future<Output = (FullSlice<Buf>, std::io::Result<()>)> + Send {
VirtualFile::write_all_at(self, buf, offset, ctx)
}
async fn set_len(&self, len: u64, ctx: &RequestContext) -> std::io::Result<()> {
VirtualFile::set_len(self, len, ctx).await
}
}
impl Drop for TempVirtualFile {
fn drop(&mut self) {
let Some(Inner { file, _gate_guard }) = self.inner.take() else {
return;
};
let path = file.path();
if let Err(e) =
std::fs::remove_file(path).maybe_fatal_err("failed to remove the virtual file")
{
error!(err=%e, path=%path, "failed to remove");
}
drop(_gate_guard);
}
}
impl std::ops::Deref for TempVirtualFile {
type Target = VirtualFile;
fn deref(&self) -> &Self::Target {
&self
.inner
.as_ref()
.expect("only None after into_inner or drop")
.file
}
}
impl std::ops::DerefMut for TempVirtualFile {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self
.inner
.as_mut()
.expect("only None after into_inner or drop")
.file
}
}
impl TempVirtualFile {
/// The caller is responsible for ensuring that the path of `virtual_file` is not reused
/// until after this TempVirtualFile's `Drop` impl has completed.
/// Failure to do so will result in unlinking of the reused path by the original instance's Drop impl.
/// The best way to do so is by using a monotonic counter as a disambiguator.
/// TODO: centralize this disambiguator pattern inside this struct.
/// => <https://github.com/neondatabase/neon/pull/11549#issuecomment-2824592831>
pub fn new(virtual_file: VirtualFile, gate_guard: GateGuard) -> Self {
Self {
inner: Some(Inner {
file: virtual_file,
_gate_guard: gate_guard,
}),
}
}
/// Dismantle this wrapper and return the underlying [`VirtualFile`].
/// This disables auto-unlinking functionality that is the essence of this wrapper.
///
/// The gate guard is dropped as well; it is the callers responsibility to ensure filesystem
/// operations after calls to this functions are still gated by some other gate guard.
///
/// TODO:
/// - centralize the common usage pattern of callers (sync_all(self), rename(self, dst), sync_all(dst.parent))
/// => <https://github.com/neondatabase/neon/pull/11549#issuecomment-2824592831>
pub fn disarm_into_inner(mut self) -> VirtualFile {
self.inner
.take()
.expect("only None after into_inner or drop, and we are into_inner, and we consume")
.file
}
}

18
poetry.lock generated
View File

@@ -1274,14 +1274,14 @@ files = [
[[package]]
name = "h11"
version = "0.14.0"
version = "0.16.0"
description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
optional = false
python-versions = ">=3.7"
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"},
{file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"},
{file = "h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86"},
{file = "h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1"},
]
[[package]]
@@ -1314,25 +1314,25 @@ files = [
[[package]]
name = "httpcore"
version = "1.0.3"
version = "1.0.9"
description = "A minimal low-level HTTP client."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "httpcore-1.0.3-py3-none-any.whl", hash = "sha256:9a6a501c3099307d9fd76ac244e08503427679b1e81ceb1d922485e2f2462ad2"},
{file = "httpcore-1.0.3.tar.gz", hash = "sha256:5c0f9546ad17dac4d0772b0808856eb616eb8b48ce94f49ed819fd6982a8a544"},
{file = "httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55"},
{file = "httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8"},
]
[package.dependencies]
certifi = "*"
h11 = ">=0.13,<0.15"
h11 = ">=0.16"
[package.extras]
asyncio = ["anyio (>=4.0,<5.0)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
trio = ["trio (>=0.22.0,<0.24.0)"]
trio = ["trio (>=0.22.0,<1.0)"]
[[package]]
name = "httpx"

View File

@@ -91,6 +91,7 @@ mod jemalloc;
mod logging;
mod metrics;
mod parse;
mod pglb;
mod protocol2;
mod proxy;
mod rate_limiter;

193
proxy/src/pglb/inprocess.rs Normal file
View File

@@ -0,0 +1,193 @@
#![allow(dead_code, reason = "TODO: work in progress")]
use std::pin::{Pin, pin};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, Poll};
use std::{fmt, io};
use tokio::io::{AsyncRead, AsyncWrite, DuplexStream, ReadBuf};
use tokio::sync::mpsc;
const STREAM_CHANNEL_SIZE: usize = 16;
const MAX_STREAM_BUFFER_SIZE: usize = 4096;
#[derive(Debug)]
pub struct Connection {
stream_sender: mpsc::Sender<Stream>,
stream_receiver: mpsc::Receiver<Stream>,
stream_id_counter: Arc<AtomicUsize>,
}
impl Connection {
pub fn new() -> (Connection, Connection) {
let (sender_a, receiver_a) = mpsc::channel(STREAM_CHANNEL_SIZE);
let (sender_b, receiver_b) = mpsc::channel(STREAM_CHANNEL_SIZE);
let stream_id_counter = Arc::new(AtomicUsize::new(1));
let conn_a = Connection {
stream_sender: sender_a,
stream_receiver: receiver_b,
stream_id_counter: Arc::clone(&stream_id_counter),
};
let conn_b = Connection {
stream_sender: sender_b,
stream_receiver: receiver_a,
stream_id_counter,
};
(conn_a, conn_b)
}
#[inline]
fn next_stream_id(&self) -> StreamId {
StreamId(self.stream_id_counter.fetch_add(1, Ordering::Relaxed))
}
#[tracing::instrument(skip_all, fields(stream_id = tracing::field::Empty, err))]
pub async fn open_stream(&self) -> io::Result<Stream> {
let (local, remote) = tokio::io::duplex(MAX_STREAM_BUFFER_SIZE);
let stream_id = self.next_stream_id();
tracing::Span::current().record("stream_id", stream_id.0);
let local = Stream {
inner: local,
id: stream_id,
};
let remote = Stream {
inner: remote,
id: stream_id,
};
self.stream_sender
.send(remote)
.await
.map_err(io::Error::other)?;
Ok(local)
}
#[tracing::instrument(skip_all, fields(stream_id = tracing::field::Empty, err))]
pub async fn accept_stream(&mut self) -> io::Result<Option<Stream>> {
Ok(self.stream_receiver.recv().await.inspect(|stream| {
tracing::Span::current().record("stream_id", stream.id.0);
}))
}
}
#[derive(Copy, Clone, Debug)]
pub struct StreamId(usize);
impl fmt::Display for StreamId {
#[inline]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
// TODO: Proper closing. Currently Streams can outlive their Connections.
// Carry WeakSender and check strong_count?
#[derive(Debug)]
pub struct Stream {
inner: DuplexStream,
id: StreamId,
}
impl Stream {
#[inline]
pub fn id(&self) -> StreamId {
self.id
}
}
impl AsyncRead for Stream {
#[tracing::instrument(level = "debug", skip_all, fields(stream_id = %self.id))]
#[inline]
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
pin!(&mut self.inner).poll_read(cx, buf)
}
}
impl AsyncWrite for Stream {
#[tracing::instrument(level = "debug", skip_all, fields(stream_id = %self.id))]
#[inline]
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
pin!(&mut self.inner).poll_write(cx, buf)
}
#[tracing::instrument(level = "debug", skip_all, fields(stream_id = %self.id))]
#[inline]
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
pin!(&mut self.inner).poll_flush(cx)
}
#[tracing::instrument(level = "debug", skip_all, fields(stream_id = %self.id))]
#[inline]
fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
pin!(&mut self.inner).poll_shutdown(cx)
}
#[tracing::instrument(level = "debug", skip_all, fields(stream_id = %self.id))]
#[inline]
fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
pin!(&mut self.inner).poll_write_vectored(cx, bufs)
}
#[inline]
fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
}
#[cfg(test)]
mod tests {
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use super::*;
#[tokio::test]
async fn test_simple_roundtrip() {
let (client, mut server) = Connection::new();
let server_task = tokio::spawn(async move {
while let Some(mut stream) = server.accept_stream().await.unwrap() {
tokio::spawn(async move {
let mut buf = [0; 64];
loop {
match stream.read(&mut buf).await.unwrap() {
0 => break,
n => stream.write(&buf[..n]).await.unwrap(),
};
}
});
}
});
let mut stream = client.open_stream().await.unwrap();
stream.write_all(b"hello!").await.unwrap();
let mut buf = [0; 64];
let n = stream.read(&mut buf).await.unwrap();
assert_eq!(n, 6);
assert_eq!(&buf[..n], b"hello!");
drop(stream);
drop(client);
server_task.await.unwrap();
}
}

1
proxy/src/pglb/mod.rs Normal file
View File

@@ -0,0 +1 @@
pub mod inprocess;

View File

@@ -12,7 +12,7 @@ use pin_project_lite::pin_project;
use smol_str::SmolStr;
use strum_macros::FromRepr;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
use zerocopy::{FromBytes, FromZeroes};
use zerocopy::{FromBytes, Immutable, KnownLayout, Unaligned, network_endian};
pin_project! {
/// A chained [`AsyncRead`] with [`AsyncWrite`] passthrough
@@ -339,49 +339,49 @@ trait BufExt: Sized {
}
impl BufExt for BytesMut {
fn try_get<T: FromBytes>(&mut self) -> Option<T> {
let res = T::read_from_prefix(self)?;
let (res, _) = T::read_from_prefix(self).ok()?;
self.advance(size_of::<T>());
Some(res)
}
}
#[derive(FromBytes, FromZeroes, Copy, Clone)]
#[repr(C)]
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C, packed)]
struct ProxyProtocolV2Header {
signature: [u8; 12],
version_and_command: u8,
protocol_and_family: u8,
len: zerocopy::byteorder::network_endian::U16,
len: network_endian::U16,
}
#[derive(FromBytes, FromZeroes, Copy, Clone)]
#[repr(C)]
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C, packed)]
struct ProxyProtocolV2HeaderV4 {
src_addr: NetworkEndianIpv4,
dst_addr: NetworkEndianIpv4,
src_port: zerocopy::byteorder::network_endian::U16,
dst_port: zerocopy::byteorder::network_endian::U16,
src_port: network_endian::U16,
dst_port: network_endian::U16,
}
#[derive(FromBytes, FromZeroes, Copy, Clone)]
#[repr(C)]
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C, packed)]
struct ProxyProtocolV2HeaderV6 {
src_addr: NetworkEndianIpv6,
dst_addr: NetworkEndianIpv6,
src_port: zerocopy::byteorder::network_endian::U16,
dst_port: zerocopy::byteorder::network_endian::U16,
src_port: network_endian::U16,
dst_port: network_endian::U16,
}
#[derive(FromBytes, FromZeroes, Copy, Clone)]
#[repr(C)]
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(C, packed)]
struct TlvHeader {
kind: u8,
len: zerocopy::byteorder::network_endian::U16,
len: network_endian::U16,
}
#[derive(FromBytes, FromZeroes, Copy, Clone)]
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(transparent)]
struct NetworkEndianIpv4(zerocopy::byteorder::network_endian::U32);
struct NetworkEndianIpv4(network_endian::U32);
impl NetworkEndianIpv4 {
#[inline]
fn get(self) -> Ipv4Addr {
@@ -389,9 +389,9 @@ impl NetworkEndianIpv4 {
}
}
#[derive(FromBytes, FromZeroes, Copy, Clone)]
#[derive(FromBytes, KnownLayout, Immutable, Unaligned, Copy, Clone)]
#[repr(transparent)]
struct NetworkEndianIpv6(zerocopy::byteorder::network_endian::U128);
struct NetworkEndianIpv6(network_endian::U128);
impl NetworkEndianIpv6 {
#[inline]
fn get(self) -> Ipv6Addr {

View File

@@ -226,11 +226,16 @@ struct Args {
/// Path to the JWT auth token used to authenticate with other safekeepers.
#[arg(long)]
auth_token_path: Option<Utf8PathBuf>,
/// Enable TLS in WAL service API.
/// Does not force TLS: the client negotiates TLS usage during the handshake.
/// Uses key and certificate from ssl_key_file/ssl_cert_file.
#[arg(long)]
enable_tls_wal_service_api: bool,
/// Run in development mode (disables security checks)
#[arg(long, help = "Run in development mode (disables security checks)")]
dev: bool,
}
// Like PathBufValueParser, but allows empty string.

View File

@@ -0,0 +1 @@
DROP TABLE timeline_imports;

View File

@@ -0,0 +1,6 @@
CREATE TABLE timeline_imports (
tenant_id VARCHAR NOT NULL,
timeline_id VARCHAR NOT NULL,
shard_statuses JSONB NOT NULL,
PRIMARY KEY(tenant_id, timeline_id)
);

View File

@@ -30,7 +30,9 @@ use pageserver_api::models::{
TimelineArchivalConfigRequest, TimelineCreateRequest,
};
use pageserver_api::shard::TenantShardId;
use pageserver_api::upcall_api::{ReAttachRequest, ValidateRequest};
use pageserver_api::upcall_api::{
PutTimelineImportStatusRequest, ReAttachRequest, ValidateRequest,
};
use pageserver_client::{BlockUnblock, mgmt_api};
use routerify::Middleware;
use tokio_util::sync::CancellationToken;
@@ -70,6 +72,7 @@ impl HttpState {
neon_metrics: NeonMetrics::new(build_info),
allowlist_routes: &[
"/status",
"/live",
"/ready",
"/metrics",
"/profile/cpu",
@@ -154,6 +157,28 @@ async fn handle_validate(req: Request<Body>) -> Result<Response<Body>, ApiError>
json_response(StatusCode::OK, state.service.validate(validate_req).await?)
}
async fn handle_put_timeline_import_status(req: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permissions(&req, Scope::GenerationsApi)?;
let mut req = match maybe_forward(req).await {
ForwardOutcome::Forwarded(res) => {
return res;
}
ForwardOutcome::NotForwarded(req) => req,
};
let put_req = json_request::<PutTimelineImportStatusRequest>(&mut req).await?;
let state = get_state(&req);
json_response(
StatusCode::OK,
state
.service
.handle_timeline_shard_import_progress_upcall(put_req)
.await?,
)
}
/// Call into this before attaching a tenant to a pageserver, to acquire a generation number
/// (in the real control plane this is unnecessary, because the same program is managing
/// generation numbers and doing attachments).
@@ -1236,16 +1261,8 @@ async fn handle_step_down(req: Request<Body>) -> Result<Response<Body>, ApiError
ForwardOutcome::NotForwarded(req) => req,
};
// Spawn a background task: once we start stepping down, we must finish: if the client drops
// their request we should avoid stopping in some part-stepped-down state.
let handle = tokio::spawn(async move {
let state = get_state(&req);
state.service.step_down().await
});
let result = handle
.await
.map_err(|e| ApiError::InternalServerError(e.into()))?;
let state = get_state(&req);
let result = state.service.step_down().await;
json_response(StatusCode::OK, result)
}
@@ -1377,6 +1394,8 @@ async fn handle_reconcile_all(req: Request<Body>) -> Result<Response<Body>, ApiE
}
/// Status endpoint is just used for checking that our HTTP listener is up
///
/// This serves as our k8s startup probe.
async fn handle_status(req: Request<Body>) -> Result<Response<Body>, ApiError> {
match maybe_forward(req).await {
ForwardOutcome::Forwarded(res) => {
@@ -1388,6 +1407,30 @@ async fn handle_status(req: Request<Body>) -> Result<Response<Body>, ApiError> {
json_response(StatusCode::OK, ())
}
/// Liveness endpoint indicates that this storage controller is in a state
/// where it can fulfill it's responsibilties. Namely, startup has finished
/// and it is the current leader.
///
/// This serves as our k8s liveness probe.
async fn handle_live(req: Request<Body>) -> Result<Response<Body>, ApiError> {
let req = match maybe_forward(req).await {
ForwardOutcome::Forwarded(res) => {
return res;
}
ForwardOutcome::NotForwarded(req) => req,
};
let state = get_state(&req);
let live = state.service.startup_complete.is_ready()
&& state.service.get_leadership_status() == LeadershipStatus::Leader;
if live {
json_response(StatusCode::OK, ())
} else {
json_response(StatusCode::SERVICE_UNAVAILABLE, ())
}
}
/// Readiness endpoint indicates when we're done doing startup I/O (e.g. reconciling
/// with remote pageserver nodes). This is intended for use as a kubernetes readiness probe.
async fn handle_ready(req: Request<Body>) -> Result<Response<Body>, ApiError> {
@@ -1721,6 +1764,7 @@ async fn maybe_forward(req: Request<Body>) -> ForwardOutcome {
const NOT_FOR_FORWARD: &[&str] = &[
"/control/v1/step_down",
"/status",
"/live",
"/ready",
"/metrics",
"/profile/cpu",
@@ -1945,6 +1989,9 @@ pub fn make_router(
.get("/status", |r| {
named_request_span(r, handle_status, RequestName("status"))
})
.get("/live", |r| {
named_request_span(r, handle_live, RequestName("live"))
})
.get("/ready", |r| {
named_request_span(r, handle_ready, RequestName("ready"))
})
@@ -1961,6 +2008,13 @@ pub fn make_router(
.post("/upcall/v1/validate", |r| {
named_request_span(r, handle_validate, RequestName("upcall_v1_validate"))
})
.post("/upcall/v1/timeline_import_status", |r| {
named_request_span(
r,
handle_put_timeline_import_status,
RequestName("upcall_v1_timeline_import_status"),
)
})
// Test/dev/debug endpoints
.post("/debug/v1/attach-hook", |r| {
named_request_span(r, handle_attach_hook, RequestName("debug_v1_attach_hook"))

View File

@@ -43,6 +43,19 @@ impl Leadership {
&self,
) -> Result<(Option<ControllerPersistence>, Option<GlobalObservedState>)> {
let leader = self.current_leader().await?;
if leader.as_ref().map(|l| &l.address)
== self
.config
.address_for_peers
.as_ref()
.map(Uri::to_string)
.as_ref()
{
// We already are the current leader. This is a restart.
return Ok((leader, None));
}
let leader_step_down_state = if let Some(ref leader) = leader {
if self.config.start_as_candidate {
self.request_step_down(leader).await

View File

@@ -23,6 +23,7 @@ mod scheduler;
mod schema;
pub mod service;
mod tenant_shard;
mod timeline_import;
#[derive(Ord, PartialOrd, Eq, PartialEq, Copy, Clone, Serialize)]
struct Sequence(u64);

View File

@@ -212,6 +212,21 @@ impl PageserverClient {
)
}
pub(crate) async fn timeline_detail(
&self,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
) -> Result<TimelineInfo> {
measured_request!(
"timeline_detail",
crate::metrics::Method::Get,
&self.node_id_label,
self.inner
.timeline_detail(tenant_shard_id, timeline_id)
.await
)
}
pub(crate) async fn tenant_shard_split(
&self,
tenant_shard_id: TenantShardId,

View File

@@ -55,9 +55,12 @@ impl ResponseErrorMessageExt for reqwest::Response {
}
}
#[derive(Serialize, Deserialize, Debug, Default)]
#[derive(Serialize, Deserialize, Debug, Default, Clone)]
pub(crate) struct GlobalObservedState(pub(crate) HashMap<TenantShardId, ObservedState>);
const STEP_DOWN_RETRIES: u32 = 8;
const STEP_DOWN_TIMEOUT: Duration = Duration::from_secs(1);
impl PeerClient {
pub(crate) fn new(http_client: reqwest::Client, uri: Uri, jwt: Option<String>) -> Self {
Self {
@@ -76,7 +79,7 @@ impl PeerClient {
req
};
let req = req.timeout(Duration::from_secs(2));
let req = req.timeout(STEP_DOWN_TIMEOUT);
let res = req
.send()
@@ -94,8 +97,7 @@ impl PeerClient {
}
/// Request the peer to step down and return its current observed state
/// All errors are retried with exponential backoff for a maximum of 4 attempts.
/// Assuming all retries are performed, the function times out after roughly 4 seconds.
/// All errors are re-tried
pub(crate) async fn step_down(
&self,
cancel: &CancellationToken,
@@ -104,7 +106,7 @@ impl PeerClient {
|| self.request_step_down(),
|_e| false,
2,
4,
STEP_DOWN_RETRIES,
"Send step down request",
cancel,
)

View File

@@ -22,7 +22,7 @@ use pageserver_api::controller_api::{
AvailabilityZone, MetadataHealthRecord, NodeSchedulingPolicy, PlacementPolicy,
SafekeeperDescribeResponse, ShardSchedulingPolicy, SkSchedulingPolicy,
};
use pageserver_api::models::TenantConfig;
use pageserver_api::models::{ShardImportStatus, TenantConfig};
use pageserver_api::shard::{
ShardConfigError, ShardCount, ShardIdentity, ShardNumber, ShardStripeSize, TenantShardId,
};
@@ -40,6 +40,9 @@ use crate::metrics::{
DatabaseQueryErrorLabelGroup, DatabaseQueryLatencyLabelGroup, METRICS_REGISTRY,
};
use crate::node::Node;
use crate::timeline_import::{
TimelineImport, TimelineImportUpdateError, TimelineImportUpdateFollowUp,
};
const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations");
/// ## What do we store?
@@ -127,6 +130,10 @@ pub(crate) enum DatabaseOperation {
RemoveTimelineReconcile,
ListTimelineReconcile,
ListTimelineReconcileStartup,
InsertTimelineImport,
UpdateTimelineImport,
DeleteTimelineImport,
ListTimelineImports,
}
#[must_use]
@@ -1614,6 +1621,158 @@ impl Persistence {
Ok(())
}
pub(crate) async fn insert_timeline_import(
&self,
import: TimelineImportPersistence,
) -> DatabaseResult<bool> {
self.with_measured_conn(DatabaseOperation::InsertTimelineImport, move |conn| {
Box::pin({
let import = import.clone();
async move {
let inserted = diesel::insert_into(crate::schema::timeline_imports::table)
.values(import)
.execute(conn)
.await?;
Ok(inserted == 1)
}
})
})
.await
}
pub(crate) async fn list_complete_timeline_imports(
&self,
) -> DatabaseResult<Vec<TimelineImport>> {
use crate::schema::timeline_imports::dsl;
let persistent = self
.with_measured_conn(DatabaseOperation::ListTimelineImports, move |conn| {
Box::pin(async move {
let from_db: Vec<TimelineImportPersistence> =
dsl::timeline_imports.load(conn).await?;
Ok(from_db)
})
})
.await?;
let imports: Result<Vec<TimelineImport>, _> = persistent
.into_iter()
.map(TimelineImport::from_persistent)
.collect();
match imports {
Ok(ok) => Ok(ok
.into_iter()
.filter(|import| import.is_complete())
.collect()),
Err(err) => Err(DatabaseError::Logical(format!(
"failed to deserialize import: {err}"
))),
}
}
pub(crate) async fn delete_timeline_import(
&self,
tenant_id: TenantId,
timeline_id: TimelineId,
) -> DatabaseResult<()> {
use crate::schema::timeline_imports::dsl;
self.with_measured_conn(DatabaseOperation::DeleteTimelineImport, move |conn| {
Box::pin(async move {
diesel::delete(crate::schema::timeline_imports::table)
.filter(
dsl::tenant_id
.eq(tenant_id.to_string())
.and(dsl::timeline_id.eq(timeline_id.to_string())),
)
.execute(conn)
.await?;
Ok(())
})
})
.await
}
/// Idempotently update the status of one shard for an ongoing timeline import
///
/// If the update was persisted to the database, then the current state of the
/// import is returned to the caller. In case of logical errors a bespoke
/// [`TimelineImportUpdateError`] instance is returned. Other database errors
/// are covered by the outer [`DatabaseError`].
pub(crate) async fn update_timeline_import(
&self,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
shard_status: ShardImportStatus,
) -> DatabaseResult<Result<Option<TimelineImport>, TimelineImportUpdateError>> {
use crate::schema::timeline_imports::dsl;
self.with_measured_conn(DatabaseOperation::UpdateTimelineImport, move |conn| {
Box::pin({
let shard_status = shard_status.clone();
async move {
// Load the current state from the database
let mut from_db: Vec<TimelineImportPersistence> = dsl::timeline_imports
.filter(
dsl::tenant_id
.eq(tenant_shard_id.tenant_id.to_string())
.and(dsl::timeline_id.eq(timeline_id.to_string())),
)
.load(conn)
.await?;
assert!(from_db.len() <= 1);
let mut status = match from_db.pop() {
Some(some) => TimelineImport::from_persistent(some).unwrap(),
None => {
return Ok(Err(TimelineImportUpdateError::ImportNotFound {
tenant_id: tenant_shard_id.tenant_id,
timeline_id,
}));
}
};
// Perform the update in-memory
let follow_up = match status.update(tenant_shard_id.to_index(), shard_status) {
Ok(ok) => ok,
Err(err) => {
return Ok(Err(err));
}
};
let new_persistent = status.to_persistent();
// Write back if required (in the same transaction)
match follow_up {
TimelineImportUpdateFollowUp::Persist => {
let updated = diesel::update(dsl::timeline_imports)
.filter(
dsl::tenant_id
.eq(tenant_shard_id.tenant_id.to_string())
.and(dsl::timeline_id.eq(timeline_id.to_string())),
)
.set(dsl::shard_statuses.eq(new_persistent.shard_statuses))
.execute(conn)
.await?;
if updated != 1 {
return Ok(Err(TimelineImportUpdateError::ImportNotFound {
tenant_id: tenant_shard_id.tenant_id,
timeline_id,
}));
}
Ok(Ok(Some(status)))
}
TimelineImportUpdateFollowUp::None => Ok(Ok(None)),
}
}
})
})
.await
}
}
pub(crate) fn load_certs() -> anyhow::Result<Arc<rustls::RootCertStore>> {
@@ -2171,3 +2330,11 @@ impl ToSql<diesel::sql_types::VarChar, Pg> for SafekeeperTimelineOpKind {
.map_err(Into::into)
}
}
#[derive(Serialize, Deserialize, Queryable, Selectable, Insertable, Eq, PartialEq, Clone)]
#[diesel(table_name = crate::schema::timeline_imports)]
pub(crate) struct TimelineImportPersistence {
pub(crate) tenant_id: String,
pub(crate) timeline_id: String,
pub(crate) shard_statuses: serde_json::Value,
}

View File

@@ -76,6 +76,14 @@ diesel::table! {
}
}
diesel::table! {
timeline_imports (tenant_id, timeline_id) {
tenant_id -> Varchar,
timeline_id -> Varchar,
shard_statuses -> Jsonb,
}
}
diesel::table! {
use diesel::sql_types::*;
use super::sql_types::PgLsn;
@@ -99,5 +107,6 @@ diesel::allow_tables_to_appear_in_same_query!(
safekeeper_timeline_pending_ops,
safekeepers,
tenant_shards,
timeline_imports,
timelines,
);

View File

@@ -11,7 +11,7 @@ use std::num::NonZeroU32;
use std::ops::{Deref, DerefMut};
use std::path::PathBuf;
use std::str::FromStr;
use std::sync::Arc;
use std::sync::{Arc, OnceLock};
use std::time::{Duration, Instant, SystemTime};
use anyhow::Context;
@@ -40,14 +40,14 @@ use pageserver_api::models::{
TenantLocationConfigResponse, TenantShardLocation, TenantShardSplitRequest,
TenantShardSplitResponse, TenantSorting, TenantTimeTravelRequest,
TimelineArchivalConfigRequest, TimelineCreateRequest, TimelineCreateResponseStorcon,
TimelineInfo, TopTenantShardItem, TopTenantShardsRequest,
TimelineInfo, TimelineState, TopTenantShardItem, TopTenantShardsRequest,
};
use pageserver_api::shard::{
DEFAULT_STRIPE_SIZE, ShardCount, ShardIdentity, ShardNumber, ShardStripeSize, TenantShardId,
};
use pageserver_api::upcall_api::{
ReAttachRequest, ReAttachResponse, ReAttachResponseTenant, ValidateRequest, ValidateResponse,
ValidateResponseTenant,
PutTimelineImportStatusRequest, ReAttachRequest, ReAttachResponse, ReAttachResponseTenant,
ValidateRequest, ValidateResponse, ValidateResponseTenant,
};
use pageserver_client::{BlockUnblock, mgmt_api};
use reqwest::{Certificate, StatusCode};
@@ -97,6 +97,7 @@ use crate::tenant_shard::{
ReconcileNeeded, ReconcileResult, ReconcileWaitError, ReconcilerStatus, ReconcilerWaiter,
ScheduleOptimization, ScheduleOptimizationAction, TenantShard,
};
use crate::timeline_import::{ShardImportStatuses, TimelineImport, UpcallClient};
const WAITER_FILL_DRAIN_POLL_TIMEOUT: Duration = Duration::from_millis(500);
@@ -523,6 +524,9 @@ pub struct Service {
/// HTTP client with proper CA certs.
http_client: reqwest::Client,
/// Handle for the step down background task if one was ever requested
step_down_barrier: OnceLock<tokio::sync::watch::Receiver<Option<GlobalObservedState>>>,
}
impl From<ReconcileWaitError> for ApiError {
@@ -874,6 +878,22 @@ impl Service {
});
}
// Fetch the list of completed imports and attempt to finalize them in the background.
// This handles the case where the previous storage controller instance shut down
// whilst finalizing imports.
let complete_imports = self.persistence.list_complete_timeline_imports().await;
match complete_imports {
Ok(ok) => {
tokio::task::spawn({
let finalize_imports_self = self.clone();
async move { finalize_imports_self.finalize_timeline_imports(ok).await }
});
}
Err(err) => {
tracing::error!("Could not retrieve completed imports from database: {err}");
}
}
tracing::info!(
"Startup complete, spawned {reconcile_tasks} reconciliation tasks ({shard_count} shards total)"
);
@@ -1744,6 +1764,7 @@ impl Service {
tenant_op_locks: Default::default(),
node_op_locks: Default::default(),
http_client,
step_down_barrier: Default::default(),
});
let result_task_this = this.clone();
@@ -3732,11 +3753,14 @@ impl Service {
create_req: TimelineCreateRequest,
) -> Result<TimelineCreateResponseStorcon, ApiError> {
let safekeepers = self.config.timelines_onto_safekeepers;
let timeline_id = create_req.new_timeline_id;
tracing::info!(
mode=%create_req.mode_tag(),
%safekeepers,
"Creating timeline {}/{}",
tenant_id,
create_req.new_timeline_id,
timeline_id,
);
let _tenant_lock = trace_shared_lock(
@@ -3746,15 +3770,62 @@ impl Service {
)
.await;
failpoint_support::sleep_millis_async!("tenant-create-timeline-shared-lock");
let create_mode = create_req.mode.clone();
let is_import = create_req.is_import();
let timeline_info = self
.tenant_timeline_create_pageservers(tenant_id, create_req)
.await?;
let safekeepers = if safekeepers {
let selected_safekeepers = if is_import {
let shards = {
let locked = self.inner.read().unwrap();
locked
.tenants
.range(TenantShardId::tenant_range(tenant_id))
.map(|(ts_id, _)| ts_id.to_index())
.collect::<Vec<_>>()
};
if !shards
.iter()
.map(|shard_index| shard_index.shard_count)
.all_equal()
{
return Err(ApiError::InternalServerError(anyhow::anyhow!(
"Inconsistent shard count"
)));
}
let import = TimelineImport {
tenant_id,
timeline_id,
shard_statuses: ShardImportStatuses::new(shards),
};
let inserted = self
.persistence
.insert_timeline_import(import.to_persistent())
.await
.context("timeline import insert")
.map_err(ApiError::InternalServerError)?;
match inserted {
true => {
tracing::info!(%tenant_id, %timeline_id, "Inserted timeline import");
}
false => {
tracing::info!(%tenant_id, %timeline_id, "Timeline import entry already present");
}
}
None
} else if safekeepers {
// Note that we do not support creating the timeline on the safekeepers
// for imported timelines. The `start_lsn` of the timeline is not known
// until the import finshes.
// https://github.com/neondatabase/neon/issues/11569
let res = self
.tenant_timeline_create_safekeepers(tenant_id, &timeline_info, create_mode)
.tenant_timeline_create_safekeepers(tenant_id, &timeline_info)
.instrument(tracing::info_span!("timeline_create_safekeepers", %tenant_id, timeline_id=%timeline_info.timeline_id))
.await?;
Some(res)
@@ -3764,10 +3835,174 @@ impl Service {
Ok(TimelineCreateResponseStorcon {
timeline_info,
safekeepers,
safekeepers: selected_safekeepers,
})
}
pub(crate) async fn handle_timeline_shard_import_progress_upcall(
self: &Arc<Self>,
req: PutTimelineImportStatusRequest,
) -> Result<(), ApiError> {
let res = self
.persistence
.update_timeline_import(req.tenant_shard_id, req.timeline_id, req.status)
.await;
let timeline_import = match res {
Ok(Ok(Some(timeline_import))) => timeline_import,
Ok(Ok(None)) => {
// Idempotency: we've already seen and handled this update.
return Ok(());
}
Ok(Err(logical_err)) => {
return Err(logical_err.into());
}
Err(db_err) => {
return Err(db_err.into());
}
};
tracing::info!(
tenant_id=%req.tenant_shard_id.tenant_id,
timeline_id=%req.timeline_id,
shard_id=%req.tenant_shard_id.shard_slug(),
"Updated timeline import status to: {timeline_import:?}");
if timeline_import.is_complete() {
tokio::task::spawn({
let this = self.clone();
async move { this.finalize_timeline_import(timeline_import).await }
});
}
Ok(())
}
#[instrument(skip_all, fields(
tenant_id=%import.tenant_id,
shard_id=%import.timeline_id,
))]
async fn finalize_timeline_import(
self: &Arc<Self>,
import: TimelineImport,
) -> anyhow::Result<()> {
tracing::info!("Finalizing timeline import");
pausable_failpoint!("timeline-import-pre-cplane-notification");
let import_failed = import.completion_error().is_some();
if !import_failed {
loop {
if self.cancel.is_cancelled() {
anyhow::bail!("Shut down requested while finalizing import");
}
let active = self.timeline_active_on_all_shards(&import).await?;
match active {
true => {
tracing::info!("Timeline became active on all shards");
break;
}
false => {
tracing::info!("Timeline not active on all shards yet");
tokio::select! {
_ = self.cancel.cancelled() => {
anyhow::bail!("Shut down requested while finalizing import");
},
_ = tokio::time::sleep(Duration::from_secs(5)) => {}
};
}
}
}
}
tracing::info!(%import_failed, "Notifying cplane of import completion");
let client = UpcallClient::new(self.get_config(), self.cancel.child_token());
client.notify_import_complete(&import).await?;
if let Err(err) = self
.persistence
.delete_timeline_import(import.tenant_id, import.timeline_id)
.await
{
tracing::warn!("Failed to delete timeline import entry from database: {err}");
}
// TODO(vlad): Timeline creations in import mode do not return a correct initdb lsn,
// so we can't create the timeline on the safekeepers. Fix by moving creation here.
// https://github.com/neondatabase/neon/issues/11569
tracing::info!(%import_failed, "Timeline import complete");
Ok(())
}
async fn finalize_timeline_imports(self: &Arc<Self>, imports: Vec<TimelineImport>) {
futures::future::join_all(
imports
.into_iter()
.map(|import| self.finalize_timeline_import(import)),
)
.await;
}
async fn timeline_active_on_all_shards(
self: &Arc<Self>,
import: &TimelineImport,
) -> anyhow::Result<bool> {
let targets = {
let locked = self.inner.read().unwrap();
let mut targets = Vec::new();
for (tenant_shard_id, shard) in locked
.tenants
.range(TenantShardId::tenant_range(import.tenant_id))
{
if !import
.shard_statuses
.0
.contains_key(&tenant_shard_id.to_index())
{
anyhow::bail!("Shard layout change detected on completion");
}
if let Some(node_id) = shard.intent.get_attached() {
let node = locked
.nodes
.get(node_id)
.expect("Pageservers may not be deleted while referenced");
targets.push((*tenant_shard_id, node.clone()));
} else {
return Ok(false);
}
}
targets
};
let results = self
.tenant_for_shards_api(
targets,
|tenant_shard_id, client| async move {
client
.timeline_detail(tenant_shard_id, import.timeline_id)
.await
},
1,
1,
SHORT_RECONCILE_TIMEOUT,
&self.cancel,
)
.await;
Ok(results.into_iter().all(|res| match res {
Ok(info) => info.state == TimelineState::Active,
Err(_) => false,
}))
}
pub(crate) async fn tenant_timeline_archival_config(
&self,
tenant_id: TenantId,
@@ -8677,27 +8912,59 @@ impl Service {
self.inner.read().unwrap().get_leadership_status()
}
pub(crate) async fn step_down(&self) -> GlobalObservedState {
/// Handler for step down requests
///
/// Step down runs in separate task since once it's called it should
/// be driven to completion. Subsequent requests will wait on the same
/// step down task.
pub(crate) async fn step_down(self: &Arc<Self>) -> GlobalObservedState {
let handle = self.step_down_barrier.get_or_init(|| {
let step_down_self = self.clone();
let (tx, rx) = tokio::sync::watch::channel::<Option<GlobalObservedState>>(None);
tokio::spawn(async move {
let state = step_down_self.step_down_task().await;
tx.send(Some(state))
.expect("Task Arc<Service> keeps receiver alive");
});
rx
});
handle
.clone()
.wait_for(|observed_state| observed_state.is_some())
.await
.expect("Task Arc<Service> keeps sender alive")
.deref()
.clone()
.expect("Checked above")
}
async fn step_down_task(&self) -> GlobalObservedState {
tracing::info!("Received step down request from peer");
failpoint_support::sleep_millis_async!("sleep-on-step-down-handling");
self.inner.write().unwrap().step_down();
// Wait for reconciliations to stop, or terminate this process if they
// fail to stop in time (this indicates a bug in shutdown)
tokio::select! {
_ = self.stop_reconciliations(StopReconciliationsReason::SteppingDown) => {
tracing::info!("Reconciliations stopped, proceeding with step down");
}
_ = async {
failpoint_support::sleep_millis_async!("step-down-delay-timeout");
tokio::time::sleep(Duration::from_secs(10)).await
} => {
tracing::warn!("Step down timed out while waiting for reconciliation gate, terminating process");
let stop_reconciliations =
self.stop_reconciliations(StopReconciliationsReason::SteppingDown);
let mut stop_reconciliations = std::pin::pin!(stop_reconciliations);
// The caller may proceed to act as leader when it sees this request fail: reduce the chance
// of a split-brain situation by terminating this controller instead of leaving it up in a partially-shut-down state.
std::process::exit(1);
let started_at = Instant::now();
// Wait for reconciliations to stop and warn if that's taking a long time
loop {
tokio::select! {
_ = &mut stop_reconciliations => {
tracing::info!("Reconciliations stopped, proceeding with step down");
break;
}
_ = tokio::time::sleep(Duration::from_secs(10)) => {
tracing::warn!(
elapsed_sec=%started_at.elapsed().as_secs(),
"Stopping reconciliations during step down is taking too long"
);
}
}
}

View File

@@ -15,7 +15,7 @@ use http_utils::error::ApiError;
use pageserver_api::controller_api::{
SafekeeperDescribeResponse, SkSchedulingPolicy, TimelineImportRequest,
};
use pageserver_api::models::{self, SafekeeperInfo, SafekeepersInfo, TimelineInfo};
use pageserver_api::models::{SafekeeperInfo, SafekeepersInfo, TimelineInfo};
use safekeeper_api::membership::{MemberSet, SafekeeperId};
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
@@ -207,7 +207,6 @@ impl Service {
self: &Arc<Self>,
tenant_id: TenantId,
timeline_info: &TimelineInfo,
create_mode: models::TimelineCreateRequestMode,
) -> Result<SafekeepersInfo, ApiError> {
let timeline_id = timeline_info.timeline_id;
let pg_version = timeline_info.pg_version * 10000;
@@ -217,15 +216,8 @@ impl Service {
// previously existed as on retries in theory endpoint might have
// already written some data and advanced last_record_lsn, while we want
// safekeepers to have consistent start_lsn.
let start_lsn = match create_mode {
models::TimelineCreateRequestMode::Bootstrap { .. } => timeline_info.last_record_lsn,
models::TimelineCreateRequestMode::Branch { .. } => timeline_info.last_record_lsn,
models::TimelineCreateRequestMode::ImportPgdata { .. } => {
return Err(ApiError::InternalServerError(anyhow::anyhow!(
"import pgdata doesn't specify the start lsn, aborting creation on safekeepers"
)))?;
}
};
let start_lsn = timeline_info.last_record_lsn;
// Choose initial set of safekeepers respecting affinity
let sks = self.safekeepers_for_new_timeline().await?;
let sks_persistence = sks.iter().map(|sk| sk.id.0 as i64).collect::<Vec<_>>();

View File

@@ -0,0 +1,260 @@
use std::time::Duration;
use std::{collections::HashMap, str::FromStr};
use http_utils::error::ApiError;
use reqwest::Method;
use serde::{Deserialize, Serialize};
use pageserver_api::models::ShardImportStatus;
use tokio_util::sync::CancellationToken;
use utils::{
id::{TenantId, TimelineId},
shard::ShardIndex,
};
use crate::{persistence::TimelineImportPersistence, service::Config};
#[derive(Serialize, Deserialize, Clone, Debug)]
pub(crate) struct ShardImportStatuses(pub(crate) HashMap<ShardIndex, ShardImportStatus>);
impl ShardImportStatuses {
pub(crate) fn new(shards: Vec<ShardIndex>) -> Self {
ShardImportStatuses(
shards
.into_iter()
.map(|ts_id| (ts_id, ShardImportStatus::InProgress))
.collect(),
)
}
}
#[derive(Debug)]
pub(crate) struct TimelineImport {
pub(crate) tenant_id: TenantId,
pub(crate) timeline_id: TimelineId,
pub(crate) shard_statuses: ShardImportStatuses,
}
pub(crate) enum TimelineImportUpdateFollowUp {
Persist,
None,
}
pub(crate) enum TimelineImportUpdateError {
ImportNotFound {
tenant_id: TenantId,
timeline_id: TimelineId,
},
MismatchedShards,
UnexpectedUpdate,
}
impl From<TimelineImportUpdateError> for ApiError {
fn from(err: TimelineImportUpdateError) -> ApiError {
match err {
TimelineImportUpdateError::ImportNotFound {
tenant_id,
timeline_id,
} => ApiError::NotFound(
anyhow::anyhow!("Import for {tenant_id}/{timeline_id} not found").into(),
),
TimelineImportUpdateError::MismatchedShards => {
ApiError::InternalServerError(anyhow::anyhow!(
"Import shards do not match update request, likely a shard split happened during import, this is a bug"
))
}
TimelineImportUpdateError::UnexpectedUpdate => {
ApiError::InternalServerError(anyhow::anyhow!("Update request is unexpected"))
}
}
}
}
impl TimelineImport {
pub(crate) fn from_persistent(persistent: TimelineImportPersistence) -> anyhow::Result<Self> {
let tenant_id = TenantId::from_str(persistent.tenant_id.as_str())?;
let timeline_id = TimelineId::from_str(persistent.timeline_id.as_str())?;
let shard_statuses = serde_json::from_value(persistent.shard_statuses)?;
Ok(TimelineImport {
tenant_id,
timeline_id,
shard_statuses,
})
}
pub(crate) fn to_persistent(&self) -> TimelineImportPersistence {
TimelineImportPersistence {
tenant_id: self.tenant_id.to_string(),
timeline_id: self.timeline_id.to_string(),
shard_statuses: serde_json::to_value(self.shard_statuses.clone()).unwrap(),
}
}
pub(crate) fn update(
&mut self,
shard: ShardIndex,
status: ShardImportStatus,
) -> Result<TimelineImportUpdateFollowUp, TimelineImportUpdateError> {
use std::collections::hash_map::Entry::*;
match self.shard_statuses.0.entry(shard) {
Occupied(mut occ) => {
let crnt = occ.get_mut();
if *crnt == status {
Ok(TimelineImportUpdateFollowUp::None)
} else if crnt.is_terminal() && *crnt != status {
Err(TimelineImportUpdateError::UnexpectedUpdate)
} else {
*crnt = status;
Ok(TimelineImportUpdateFollowUp::Persist)
}
}
Vacant(_) => Err(TimelineImportUpdateError::MismatchedShards),
}
}
pub(crate) fn is_complete(&self) -> bool {
self.shard_statuses
.0
.values()
.all(|status| status.is_terminal())
}
pub(crate) fn completion_error(&self) -> Option<String> {
assert!(self.is_complete());
let shard_errors: HashMap<_, _> = self
.shard_statuses
.0
.iter()
.filter_map(|(shard, status)| {
if let ShardImportStatus::Error(err) = status {
Some((*shard, err.clone()))
} else {
None
}
})
.collect();
if shard_errors.is_empty() {
None
} else {
Some(serde_json::to_string(&shard_errors).unwrap())
}
}
}
pub(crate) struct UpcallClient {
authorization_header: Option<String>,
client: reqwest::Client,
cancel: CancellationToken,
base_url: String,
}
const IMPORT_COMPLETE_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);
#[derive(Serialize, Deserialize, Debug)]
struct ImportCompleteRequest {
tenant_id: TenantId,
timeline_id: TimelineId,
error: Option<String>,
}
impl UpcallClient {
pub(crate) fn new(config: &Config, cancel: CancellationToken) -> Self {
let authorization_header = config
.control_plane_jwt_token
.clone()
.map(|jwt| format!("Bearer {}", jwt));
let client = reqwest::ClientBuilder::new()
.timeout(IMPORT_COMPLETE_REQUEST_TIMEOUT)
.build()
.expect("Failed to construct HTTP client");
let base_url = config
.control_plane_url
.clone()
.expect("must be configured");
Self {
authorization_header,
client,
cancel,
base_url,
}
}
/// Notify control plane of a completed import
///
/// This method guarantees at least once delivery semantics assuming
/// eventual cplane availability. The cplane API is idempotent.
pub(crate) async fn notify_import_complete(
&self,
import: &TimelineImport,
) -> anyhow::Result<()> {
let endpoint = if self.base_url.ends_with('/') {
format!("{}import_complete", self.base_url)
} else {
format!("{}/import_complete", self.base_url)
};
tracing::info!("Endpoint is {endpoint}");
let request = self
.client
.request(Method::PUT, endpoint)
.json(&ImportCompleteRequest {
tenant_id: import.tenant_id,
timeline_id: import.timeline_id,
error: import.completion_error(),
})
.timeout(IMPORT_COMPLETE_REQUEST_TIMEOUT);
let request = if let Some(auth) = &self.authorization_header {
request.header(reqwest::header::AUTHORIZATION, auth)
} else {
request
};
const RETRY_DELAY: Duration = Duration::from_secs(1);
let mut attempt = 1;
loop {
if self.cancel.is_cancelled() {
return Err(anyhow::anyhow!(
"Shutting down while notifying cplane of import completion"
));
}
match request.try_clone().unwrap().send().await {
Ok(response) if response.status().is_success() => {
return Ok(());
}
Ok(response) => {
tracing::warn!(
"Import complete notification failed with status {}, attempt {}",
response.status(),
attempt
);
}
Err(e) => {
tracing::warn!(
"Import complete notification failed with error: {}, attempt {}",
e,
attempt
);
}
}
tokio::select! {
_ = tokio::time::sleep(RETRY_DELAY) => {}
_ = self.cancel.cancelled() => {
return Err(anyhow::anyhow!("Shutting down while notifying cplane of import completion"));
}
}
attempt += 1;
}
}
}

View File

@@ -5,8 +5,6 @@ edition = "2024"
license.workspace = true
[dependencies]
aws-config.workspace = true
aws-sdk-s3.workspace = true
either.workspace = true
anyhow.workspace = true
hex.workspace = true

View File

@@ -12,14 +12,9 @@ pub mod tenant_snapshot;
use std::env;
use std::fmt::Display;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use anyhow::Context;
use aws_config::retry::{RetryConfigBuilder, RetryMode};
use aws_sdk_s3::Client;
use aws_sdk_s3::config::Region;
use aws_sdk_s3::error::DisplayErrorContext;
use camino::{Utf8Path, Utf8PathBuf};
use clap::ValueEnum;
use futures::{Stream, StreamExt};
@@ -28,7 +23,7 @@ use pageserver::tenant::remote_timeline_client::{remote_tenant_path, remote_time
use pageserver_api::shard::TenantShardId;
use remote_storage::{
DownloadOpts, GenericRemoteStorage, Listing, ListingMode, RemotePath, RemoteStorageConfig,
RemoteStorageKind, S3Config,
RemoteStorageKind, VersionId,
};
use reqwest::Url;
use serde::{Deserialize, Serialize};
@@ -351,21 +346,6 @@ pub fn init_logging(file_name: &str) -> Option<WorkerGuard> {
}
}
async fn init_s3_client(bucket_region: Region) -> Client {
let mut retry_config_builder = RetryConfigBuilder::new();
retry_config_builder
.set_max_attempts(Some(3))
.set_mode(Some(RetryMode::Adaptive));
let config = aws_config::defaults(aws_config::BehaviorVersion::v2024_03_28())
.region(bucket_region)
.retry_config(retry_config_builder.build())
.load()
.await;
Client::new(&config)
}
fn default_prefix_in_bucket(node_kind: NodeKind) -> &'static str {
match node_kind {
NodeKind::Pageserver => "pageserver/v1/",
@@ -385,23 +365,6 @@ fn make_root_target(desc_str: String, prefix_in_bucket: String, node_kind: NodeK
}
}
async fn init_remote_s3(
bucket_config: S3Config,
node_kind: NodeKind,
) -> anyhow::Result<(Arc<Client>, RootTarget)> {
let bucket_region = Region::new(bucket_config.bucket_region);
let s3_client = Arc::new(init_s3_client(bucket_region).await);
let default_prefix = default_prefix_in_bucket(node_kind).to_string();
let s3_root = make_root_target(
bucket_config.bucket_name,
bucket_config.prefix_in_bucket.unwrap_or(default_prefix),
node_kind,
);
Ok((s3_client, s3_root))
}
async fn init_remote(
mut storage_config: BucketConfig,
node_kind: NodeKind,
@@ -499,7 +462,7 @@ async fn list_objects_with_retries(
remote_client.bucket_name().unwrap_or_default(),
s3_target.prefix_in_bucket,
s3_target.delimiter,
DisplayErrorContext(e),
e,
);
let backoff_time = 1 << trial.min(5);
tokio::time::sleep(Duration::from_secs(backoff_time)).await;
@@ -549,14 +512,18 @@ async fn download_object_with_retries(
anyhow::bail!("Failed to download objects with key {key} {MAX_RETRIES} times")
}
async fn download_object_to_file_s3(
s3_client: &Client,
bucket_name: &str,
key: &str,
version_id: Option<&str>,
async fn download_object_to_file(
remote_storage: &GenericRemoteStorage,
key: &RemotePath,
version_id: Option<VersionId>,
local_path: &Utf8Path,
) -> anyhow::Result<()> {
let opts = DownloadOpts {
version_id: version_id.clone(),
..Default::default()
};
let tmp_path = Utf8PathBuf::from(format!("{local_path}.tmp"));
let cancel = CancellationToken::new();
for _ in 0..MAX_RETRIES {
tokio::fs::remove_file(&tmp_path)
.await
@@ -566,28 +533,24 @@ async fn download_object_to_file_s3(
.await
.context("Opening output file")?;
let request = s3_client.get_object().bucket(bucket_name).key(key);
let res = remote_storage.download(key, &opts, &cancel).await;
let request = match version_id {
Some(version_id) => request.version_id(version_id),
None => request,
};
let response_stream = match request.send().await {
let download = match res {
Ok(response) => response,
Err(e) => {
error!(
"Failed to download object for key {key} version {}: {e:#}",
version_id.unwrap_or("")
"Failed to download object for key {key} version {:?}: {e:#}",
&version_id.as_ref().unwrap_or(&VersionId(String::new()))
);
tokio::time::sleep(Duration::from_secs(1)).await;
continue;
}
};
let mut read_stream = response_stream.body.into_async_read();
//response_stream.download_stream
tokio::io::copy(&mut read_stream, &mut file).await?;
let mut body = tokio_util::io::StreamReader::new(download.download_stream);
tokio::io::copy(&mut body, &mut file).await?;
tokio::fs::rename(&tmp_path, local_path).await?;
return Ok(());

View File

@@ -1,31 +1,30 @@
use std::collections::HashMap;
use std::sync::Arc;
use anyhow::Context;
use async_stream::stream;
use aws_sdk_s3::Client;
use camino::Utf8PathBuf;
use futures::{StreamExt, TryStreamExt};
use pageserver::tenant::IndexPart;
use pageserver::tenant::remote_timeline_client::index::LayerFileMetadata;
use pageserver::tenant::remote_timeline_client::remote_layer_path;
use pageserver::tenant::storage_layer::LayerName;
use pageserver_api::shard::TenantShardId;
use remote_storage::{GenericRemoteStorage, S3Config};
use remote_storage::GenericRemoteStorage;
use tokio_util::sync::CancellationToken;
use utils::generation::Generation;
use utils::id::TenantId;
use crate::checks::{BlobDataParseResult, RemoteTimelineBlobData, list_timeline_blobs};
use crate::metadata_stream::{stream_tenant_shards, stream_tenant_timelines};
use crate::{
BucketConfig, NodeKind, RootTarget, TenantShardTimelineId, download_object_to_file_s3,
init_remote, init_remote_s3,
BucketConfig, NodeKind, RootTarget, TenantShardTimelineId, download_object_to_file, init_remote,
};
pub struct SnapshotDownloader {
s3_client: Arc<Client>,
s3_root: RootTarget,
remote_client: GenericRemoteStorage,
#[allow(dead_code)]
target: RootTarget,
bucket_config: BucketConfig,
bucket_config_s3: S3Config,
tenant_id: TenantId,
output_path: Utf8PathBuf,
concurrency: usize,
@@ -38,17 +37,13 @@ impl SnapshotDownloader {
output_path: Utf8PathBuf,
concurrency: usize,
) -> anyhow::Result<Self> {
let bucket_config_s3 = match &bucket_config.0.storage {
remote_storage::RemoteStorageKind::AwsS3(config) => config.clone(),
_ => panic!("only S3 configuration is supported for snapshot downloading"),
};
let (s3_client, s3_root) =
init_remote_s3(bucket_config_s3.clone(), NodeKind::Pageserver).await?;
let (remote_client, target) =
init_remote(bucket_config.clone(), NodeKind::Pageserver).await?;
Ok(Self {
s3_client,
s3_root,
remote_client,
target,
bucket_config,
bucket_config_s3,
tenant_id,
output_path,
concurrency,
@@ -61,6 +56,7 @@ impl SnapshotDownloader {
layer_name: LayerName,
layer_metadata: LayerFileMetadata,
) -> anyhow::Result<(LayerName, LayerFileMetadata)> {
let cancel = CancellationToken::new();
// Note this is local as in a local copy of S3 data, not local as in the pageserver's local format. They use
// different layer names (remote-style has the generation suffix)
let local_path = self.output_path.join(format!(
@@ -82,30 +78,27 @@ impl SnapshotDownloader {
} else {
tracing::debug!("{} requires download...", local_path);
let timeline_root = self.s3_root.timeline_root(&ttid);
let remote_layer_path = format!(
"{}{}{}",
timeline_root.prefix_in_bucket,
layer_name,
layer_metadata.generation.get_suffix()
let remote_path = remote_layer_path(
&ttid.tenant_shard_id.tenant_id,
&ttid.timeline_id,
layer_metadata.shard,
&layer_name,
layer_metadata.generation,
);
let mode = remote_storage::ListingMode::NoDelimiter;
// List versions: the object might be deleted.
let versions = self
.s3_client
.list_object_versions()
.bucket(self.bucket_config_s3.bucket_name.clone())
.prefix(&remote_layer_path)
.send()
.remote_client
.list_versions(Some(&remote_path), mode, None, &cancel)
.await?;
let Some(version) = versions.versions.as_ref().and_then(|v| v.first()) else {
return Err(anyhow::anyhow!("No versions found for {remote_layer_path}"));
let Some(version) = versions.versions.first() else {
return Err(anyhow::anyhow!("No versions found for {remote_path}"));
};
download_object_to_file_s3(
&self.s3_client,
&self.bucket_config_s3.bucket_name,
&remote_layer_path,
version.version_id.as_deref(),
download_object_to_file(
&self.remote_client,
&remote_path,
version.version_id().cloned(),
&local_path,
)
.await?;

View File

@@ -16,4 +16,5 @@ pytest_plugins = (
"fixtures.slow",
"fixtures.reruns",
"fixtures.fast_import",
"fixtures.pg_config",
)

View File

@@ -417,14 +417,14 @@ class NeonLocalCli(AbstractNeonCli):
cmd.append(f"--instance-id={instance_id}")
return self.raw_cli(cmd)
def object_storage_start(self, timeout_in_seconds: int | None = None):
cmd = ["object-storage", "start"]
def endpoint_storage_start(self, timeout_in_seconds: int | None = None):
cmd = ["endpoint-storage", "start"]
if timeout_in_seconds is not None:
cmd.append(f"--start-timeout={timeout_in_seconds}s")
return self.raw_cli(cmd)
def object_storage_stop(self, immediate: bool):
cmd = ["object-storage", "stop"]
def endpoint_storage_stop(self, immediate: bool):
cmd = ["endpoint-storage", "stop"]
if immediate:
cmd.extend(["-m", "immediate"])
return self.raw_cli(cmd)

View File

@@ -1029,7 +1029,7 @@ class NeonEnvBuilder:
self.env.broker.assert_no_errors()
self.env.object_storage.assert_no_errors()
self.env.endpoint_storage.assert_no_errors()
try:
self.overlay_cleanup_teardown()
@@ -1126,7 +1126,7 @@ class NeonEnv:
pagectl_env_vars["RUST_LOG"] = self.rust_log_override
self.pagectl = Pagectl(extra_env=pagectl_env_vars, binpath=self.neon_binpath)
self.object_storage = ObjectStorage(self)
self.endpoint_storage = EndpointStorage(self)
# The URL for the pageserver to use as its control_plane_api config
if config.storage_controller_port_override is not None:
@@ -1183,7 +1183,7 @@ class NeonEnv:
},
"safekeepers": [],
"pageservers": [],
"object_storage": {"port": self.port_distributor.get_port()},
"endpoint_storage": {"port": self.port_distributor.get_port()},
"generate_local_ssl_certs": self.generate_local_ssl_certs,
}
@@ -1291,7 +1291,11 @@ class NeonEnv:
ps_cfg[key] = value
if self.pageserver_virtual_file_io_mode is not None:
ps_cfg["virtual_file_io_mode"] = self.pageserver_virtual_file_io_mode
# TODO(christian): https://github.com/neondatabase/neon/issues/11598
if not config.test_may_use_compatibility_snapshot_binaries:
ps_cfg["virtual_file_io_mode"] = self.pageserver_virtual_file_io_mode
else:
log.info("ignoring virtual_file_io_mode parametrization for compatibility test")
if self.pageserver_wal_receiver_protocol is not None:
key, value = PageserverWalReceiverProtocol.to_config_key_value(
@@ -1420,7 +1424,7 @@ class NeonEnv:
self.storage_controller.on_safekeeper_deploy(sk_id, body)
self.storage_controller.safekeeper_scheduling_policy(sk_id, "Active")
self.object_storage.start(timeout_in_seconds=timeout_in_seconds)
self.endpoint_storage.start(timeout_in_seconds=timeout_in_seconds)
def stop(self, immediate=False, ps_assert_metric_no_errors=False, fail_on_endpoint_errors=True):
"""
@@ -1439,7 +1443,7 @@ class NeonEnv:
except Exception as e:
raise_later = e
self.object_storage.stop(immediate=immediate)
self.endpoint_storage.stop(immediate=immediate)
# Stop storage controller before pageservers: we don't want it to spuriously
# detect a pageserver "failure" during test teardown
@@ -2660,24 +2664,24 @@ class NeonStorageController(MetricsGetter, LogUtils):
self.stop(immediate=True)
class ObjectStorage(LogUtils):
class EndpointStorage(LogUtils):
def __init__(self, env: NeonEnv):
service_dir = env.repo_dir / "object_storage"
super().__init__(logfile=service_dir / "object_storage.log")
self.conf_path = service_dir / "object_storage.json"
service_dir = env.repo_dir / "endpoint_storage"
super().__init__(logfile=service_dir / "endpoint_storage.log")
self.conf_path = service_dir / "endpoint_storage.json"
self.env = env
def base_url(self):
return json.loads(self.conf_path.read_text())["listen"]
def start(self, timeout_in_seconds: int | None = None):
self.env.neon_cli.object_storage_start(timeout_in_seconds)
self.env.neon_cli.endpoint_storage_start(timeout_in_seconds)
def stop(self, immediate: bool = False):
self.env.neon_cli.object_storage_stop(immediate)
self.env.neon_cli.endpoint_storage_stop(immediate)
def assert_no_errors(self):
assert_no_errors(self.logfile, "object_storage", [])
assert_no_errors(self.logfile, "endpoint_storage", [])
class NeonProxiedStorageController(NeonStorageController):
@@ -3380,6 +3384,9 @@ class VanillaPostgres(PgProtocol):
"""Return size of pgdatadir subdirectory in bytes."""
return get_dir_size(self.pgdatadir / subdir)
def is_running(self) -> bool:
return self.running
def __enter__(self) -> Self:
return self

View File

@@ -0,0 +1,249 @@
from __future__ import annotations
import shlex
from enum import StrEnum
from pathlib import Path
from typing import TYPE_CHECKING, cast, final
import pytest
if TYPE_CHECKING:
from collections.abc import Iterator
from typing import IO
from fixtures.neon_fixtures import PgBin
@final
class PgConfigKey(StrEnum):
BINDIR = "BINDIR"
DOCDIR = "DOCDIR"
HTMLDIR = "HTMLDIR"
INCLUDEDIR = "INCLUDEDIR"
PKGINCLUDEDIR = "PKGINCLUDEDIR"
INCLUDEDIR_SERVER = "INCLUDEDIR-SERVER"
LIBDIR = "LIBDIR"
PKGLIBDIR = "PKGLIBDIR"
LOCALEDIR = "LOCALEDIR"
MANDIR = "MANDIR"
SHAREDIR = "SHAREDIR"
SYSCONFDIR = "SYSCONFDIR"
PGXS = "PGXS"
CONFIGURE = "CONFIGURE"
CC = "CC"
CPPFLAGS = "CPPFLAGS"
CFLAGS = "CFLAGS"
CFLAGS_SL = "CFLAGS_SL"
LDFLAGS = "LDFLAGS"
LDFLAGS_EX = "LDFLAGS_EX"
LDFLAGS_SL = "LDFLAGS_SL"
LIBS = "LIBS"
VERSION = "VERSION"
if TYPE_CHECKING:
# TODO: This could become a TypedDict if Python ever allows StrEnums to be
# keys.
PgConfig = dict[PgConfigKey, str | Path | list[str]]
def __get_pg_config(pg_bin: PgBin) -> PgConfig:
"""Get pg_config values by invoking the command"""
cmd = pg_bin.run_nonblocking(["pg_config"])
cmd.wait()
if cmd.returncode != 0:
pytest.exit("")
assert cmd.stdout
stdout = cast("IO[str]", cmd.stdout)
# Parse the output into a dictionary
values: PgConfig = {}
for line in stdout.readlines():
if "=" in line:
key, value = line.split("=", 1)
value = value.strip()
match PgConfigKey(key.strip()):
case (
(
PgConfigKey.CC
| PgConfigKey.CPPFLAGS
| PgConfigKey.CFLAGS
| PgConfigKey.CFLAGS_SL
| PgConfigKey.LDFLAGS
| PgConfigKey.LDFLAGS_EX
| PgConfigKey.LDFLAGS_SL
| PgConfigKey.LIBS
) as k
):
values[k] = shlex.split(value)
case (
(
PgConfigKey.BINDIR
| PgConfigKey.DOCDIR
| PgConfigKey.HTMLDIR
| PgConfigKey.INCLUDEDIR
| PgConfigKey.PKGINCLUDEDIR
| PgConfigKey.INCLUDEDIR_SERVER
| PgConfigKey.LIBDIR
| PgConfigKey.PKGLIBDIR
| PgConfigKey.LOCALEDIR
| PgConfigKey.MANDIR
| PgConfigKey.SHAREDIR
| PgConfigKey.SYSCONFDIR
| PgConfigKey.PGXS
) as k
):
values[k] = Path(value)
case _ as k:
values[k] = value
return values
@pytest.fixture(scope="function")
def pg_config(pg_bin: PgBin) -> Iterator[PgConfig]:
"""Dictionary of all pg_config values from the system"""
yield __get_pg_config(pg_bin)
@pytest.fixture(scope="function")
def pg_config_bindir(pg_config: PgConfig) -> Iterator[Path]:
"""BINDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.BINDIR])
@pytest.fixture(scope="function")
def pg_config_docdir(pg_config: PgConfig) -> Iterator[Path]:
"""DOCDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.DOCDIR])
@pytest.fixture(scope="function")
def pg_config_htmldir(pg_config: PgConfig) -> Iterator[Path]:
"""HTMLDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.HTMLDIR])
@pytest.fixture(scope="function")
def pg_config_includedir(
pg_config: dict[PgConfigKey, str | Path | list[str]],
) -> Iterator[Path]:
"""INCLUDEDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.INCLUDEDIR])
@pytest.fixture(scope="function")
def pg_config_pkgincludedir(pg_config: PgConfig) -> Iterator[Path]:
"""PKGINCLUDEDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.PKGINCLUDEDIR])
@pytest.fixture(scope="function")
def pg_config_includedir_server(pg_config: PgConfig) -> Iterator[Path]:
"""INCLUDEDIR-SERVER value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.INCLUDEDIR_SERVER])
@pytest.fixture(scope="function")
def pg_config_libdir(pg_config: PgConfig) -> Iterator[Path]:
"""LIBDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.LIBDIR])
@pytest.fixture(scope="function")
def pg_config_pkglibdir(pg_config: PgConfig) -> Iterator[Path]:
"""PKGLIBDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.PKGLIBDIR])
@pytest.fixture(scope="function")
def pg_config_localedir(pg_config: PgConfig) -> Iterator[Path]:
"""LOCALEDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.LOCALEDIR])
@pytest.fixture(scope="function")
def pg_config_mandir(pg_config: PgConfig) -> Iterator[Path]:
"""MANDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.MANDIR])
@pytest.fixture(scope="function")
def pg_config_sharedir(pg_config: PgConfig) -> Iterator[Path]:
"""SHAREDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.SHAREDIR])
@pytest.fixture(scope="function")
def pg_config_sysconfdir(pg_config: PgConfig) -> Iterator[Path]:
"""SYSCONFDIR value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.SYSCONFDIR])
@pytest.fixture(scope="function")
def pg_config_pgxs(pg_config: PgConfig) -> Iterator[Path]:
"""PGXS value from pg_config"""
yield cast("Path", pg_config[PgConfigKey.PGXS])
@pytest.fixture(scope="function")
def pg_config_configure(pg_config: PgConfig) -> Iterator[str]:
"""CONFIGURE value from pg_config"""
yield cast("str", pg_config[PgConfigKey.CONFIGURE])
@pytest.fixture(scope="function")
def pg_config_cc(pg_config: PgConfig) -> Iterator[list[str]]:
"""CC value from pg_config"""
yield cast("list[str]", pg_config[PgConfigKey.CC])
@pytest.fixture(scope="function")
def pg_config_cppflags(pg_config: PgConfig) -> Iterator[list[str]]:
"""CPPFLAGS value from pg_config"""
yield cast("list[str]", pg_config[PgConfigKey.CPPFLAGS])
@pytest.fixture(scope="function")
def pg_config_cflags(pg_config: PgConfig) -> Iterator[list[str]]:
"""CFLAGS value from pg_config"""
yield cast("list[str]", pg_config[PgConfigKey.CFLAGS])
@pytest.fixture(scope="function")
def pg_config_cflags_sl(pg_config: PgConfig) -> Iterator[list[str]]:
"""CFLAGS_SL value from pg_config"""
yield cast("list[str]", pg_config[PgConfigKey.CFLAGS_SL])
@pytest.fixture(scope="function")
def pg_config_ldflags(pg_config: PgConfig) -> Iterator[list[str]]:
"""LDFLAGS value from pg_config"""
yield cast("list[str]", pg_config[PgConfigKey.LDFLAGS])
@pytest.fixture(scope="function")
def pg_config_ldflags_ex(pg_config: PgConfig) -> Iterator[list[str]]:
"""LDFLAGS_EX value from pg_config"""
yield cast("list[str]", pg_config[PgConfigKey.LDFLAGS_EX])
@pytest.fixture(scope="function")
def pg_config_ldflags_sl(pg_config: PgConfig) -> Iterator[list[str]]:
"""LDFLAGS_SL value from pg_config"""
yield cast("list[str]", pg_config[PgConfigKey.LDFLAGS_SL])
@pytest.fixture(scope="function")
def pg_config_libs(pg_config: PgConfig) -> Iterator[list[str]]:
"""LIBS value from pg_config"""
yield cast("list[str]", pg_config[PgConfigKey.LIBS])
@pytest.fixture(scope="function")
def pg_config_version(pg_config: PgConfig) -> Iterator[str]:
"""VERSION value from pg_config"""
yield cast("str", pg_config[PgConfigKey.VERSION])

View File

@@ -65,7 +65,7 @@ def test_ro_replica_lag(
project = neon_api.create_project(pg_version)
project_id = project["project"]["id"]
log.info("Project ID: %s", project_id)
log.info("Primary endpoint ID: %s", project["project"]["endpoints"][0]["id"])
log.info("Primary endpoint ID: %s", project["endpoints"][0]["id"])
neon_api.wait_for_operation_to_finish(project_id)
error_occurred = False
try:
@@ -198,7 +198,7 @@ def test_replication_start_stop(
project = neon_api.create_project(pg_version)
project_id = project["project"]["id"]
log.info("Project ID: %s", project_id)
log.info("Primary endpoint ID: %s", project["project"]["endpoints"][0]["id"])
log.info("Primary endpoint ID: %s", project["endpoints"][0]["id"])
neon_api.wait_for_operation_to_finish(project_id)
try:
branch_id = project["branch"]["id"]

View File

@@ -1,12 +0,0 @@
\echo Use "CREATE EXTENSION test_extension" to load this file. \quit
CREATE SCHEMA test_extension;
CREATE FUNCTION test_extension.motd()
RETURNS void
IMMUTABLE LEAKPROOF PARALLEL SAFE
AS $$
BEGIN
RAISE NOTICE 'Have a great day';
END;
$$ LANGUAGE 'plpgsql';

View File

@@ -1,6 +1,6 @@
\echo Use "ALTER EXTENSION test_extension UPDATE TO '1.1'" to load this file. \quit
\echo Use "ALTER EXTENSION test_extension_sql_only UPDATE TO '1.1'" to load this file. \quit
CREATE FUNCTION test_extension.fun_fact()
CREATE FUNCTION test_extension_sql_only.fun_fact()
RETURNS void
IMMUTABLE LEAKPROOF PARALLEL SAFE
AS $$

View File

@@ -0,0 +1,12 @@
\echo Use "CREATE EXTENSION test_extension_sql_only" to load this file. \quit
CREATE SCHEMA test_extension_sql_only;
CREATE FUNCTION test_extension_sql_only.motd()
RETURNS void
IMMUTABLE LEAKPROOF PARALLEL SAFE
AS $$
BEGIN
RAISE NOTICE 'Have a great day';
END;
$$ LANGUAGE 'plpgsql';

View File

@@ -0,0 +1 @@
comment = 'Test extension SQL only'

View File

@@ -0,0 +1,6 @@
\echo Use "ALTER EXTENSION test_extension_with_lib UPDATE TO '1.1'" to load this file. \quit
CREATE FUNCTION test_extension_with_lib.fun_fact()
RETURNS void
IMMUTABLE LEAKPROOF PARALLEL SAFE
AS 'MODULE_PATHNAME', 'fun_fact' LANGUAGE C;

View File

@@ -0,0 +1,8 @@
\echo Use "CREATE EXTENSION test_extension_with_lib" to load this file. \quit
CREATE SCHEMA test_extension_with_lib;
CREATE FUNCTION test_extension_with_lib.motd()
RETURNS void
IMMUTABLE LEAKPROOF PARALLEL SAFE
AS 'MODULE_PATHNAME', 'motd' LANGUAGE C;

View File

@@ -0,0 +1,34 @@
#include <postgres.h>
#include <fmgr.h>
PG_MODULE_MAGIC;
PG_FUNCTION_INFO_V1(motd);
PG_FUNCTION_INFO_V1(fun_fact);
/* Old versions of Postgres didn't pre-declare this in fmgr.h */
#if PG_MAJORVERSION_NUM <= 15
void _PG_init(void);
#endif
void
_PG_init(void)
{
}
Datum
motd(PG_FUNCTION_ARGS)
{
elog(NOTICE, "Have a great day");
PG_RETURN_VOID();
}
Datum
fun_fact(PG_FUNCTION_ARGS)
{
elog(NOTICE, "Neon has a melting point of -246.08 C");
PG_RETURN_VOID();
}

View File

@@ -0,0 +1,2 @@
comment = 'Test extension with lib'
module_pathname = '$libdir/test_extension_with_lib'

View File

@@ -0,0 +1,70 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from fixtures.metrics import parse_metrics
from fixtures.utils import wait_until
if TYPE_CHECKING:
from fixtures.neon_fixtures import NeonEnv
def test_compute_monitor(neon_simple_env: NeonEnv):
"""
Test that compute_ctl can detect Postgres going down (unresponsive) and
reconnect when it comes back online. Also check that the downtime metrics
are properly emitted.
"""
TEST_DB = "test_compute_monitor"
env = neon_simple_env
endpoint = env.endpoints.create_start("main")
# Check that default postgres database is present
with endpoint.cursor() as cursor:
cursor.execute("SELECT datname FROM pg_database WHERE datname = 'postgres'")
catalog_db = cursor.fetchone()
assert catalog_db is not None
assert len(catalog_db) == 1
# Create a new database
cursor.execute(f"CREATE DATABASE {TEST_DB}")
# Drop database 'postgres'
with endpoint.cursor(dbname=TEST_DB) as cursor:
# Use FORCE to terminate all connections to the database
cursor.execute("DROP DATABASE postgres WITH (FORCE)")
client = endpoint.http_client()
def check_metrics_down():
raw_metrics = client.metrics()
metrics = parse_metrics(raw_metrics)
compute_pg_current_downtime_ms = metrics.query_all("compute_pg_current_downtime_ms")
assert len(compute_pg_current_downtime_ms) == 1
assert compute_pg_current_downtime_ms[0].value > 0
compute_pg_downtime_ms_total = metrics.query_all("compute_pg_downtime_ms_total")
assert len(compute_pg_downtime_ms_total) == 1
assert compute_pg_downtime_ms_total[0].value > 0
wait_until(check_metrics_down)
# Recreate postgres database
with endpoint.cursor(dbname=TEST_DB) as cursor:
cursor.execute("CREATE DATABASE postgres")
# Current downtime should reset to 0, but not total downtime
def check_metrics_up():
raw_metrics = client.metrics()
metrics = parse_metrics(raw_metrics)
compute_pg_current_downtime_ms = metrics.query_all("compute_pg_current_downtime_ms")
assert len(compute_pg_current_downtime_ms) == 1
assert compute_pg_current_downtime_ms[0].value == 0
compute_pg_downtime_ms_total = metrics.query_all("compute_pg_downtime_ms_total")
assert len(compute_pg_downtime_ms_total) == 1
assert compute_pg_downtime_ms_total[0].value > 0
wait_until(check_metrics_up)
# Just a sanity check that we log the downtime info
endpoint.log_contains("downtime_info")

View File

@@ -4,12 +4,17 @@ import os
import platform
import shutil
import tarfile
from typing import TYPE_CHECKING
from enum import StrEnum
from pathlib import Path
from typing import TYPE_CHECKING, cast, final
import pytest
import zstandard
from fixtures.log_helper import log
from fixtures.metrics import parse_metrics
from fixtures.paths import BASE_DIR
from fixtures.pg_config import PgConfigKey
from fixtures.utils import subprocess_capture
from werkzeug.wrappers.response import Response
if TYPE_CHECKING:
@@ -20,6 +25,7 @@ if TYPE_CHECKING:
from fixtures.neon_fixtures import (
NeonEnvBuilder,
)
from fixtures.pg_config import PgConfig
from fixtures.pg_version import PgVersion
from pytest_httpserver import HTTPServer
from werkzeug.wrappers.request import Request
@@ -46,46 +52,108 @@ def neon_env_builder_local(
return neon_env_builder
@final
class RemoteExtension(StrEnum):
SQL_ONLY = "test_extension_sql_only"
WITH_LIB = "test_extension_with_lib"
@property
def compressed_tarball_name(self) -> str:
return f"{self.tarball_name}.zst"
@property
def control_file_name(self) -> str:
return f"{self}.control"
@property
def directory(self) -> Path:
return BASE_DIR / "test_runner" / "regress" / "data" / "test_remote_extensions" / self
@property
def shared_library_name(self) -> str:
return f"{self}.so"
@property
def tarball_name(self) -> str:
return f"{self}.tar"
def archive_route(self, build_tag: str, arch: str, pg_version: PgVersion) -> str:
return f"{build_tag}/{arch}/v{pg_version}/extensions/{self.compressed_tarball_name}"
def build(self, pg_config: PgConfig, output_dir: Path) -> None:
if self is not RemoteExtension.WITH_LIB:
return
cmd: list[str] = [
*cast("list[str]", pg_config[PgConfigKey.CC]),
*cast("list[str]", pg_config[PgConfigKey.CPPFLAGS]),
*["-I", str(cast("Path", pg_config[PgConfigKey.INCLUDEDIR_SERVER]))],
*cast("list[str]", pg_config[PgConfigKey.CFLAGS]),
*cast("list[str]", pg_config[PgConfigKey.CFLAGS_SL]),
*cast("list[str]", pg_config[PgConfigKey.LDFLAGS_EX]),
*cast("list[str]", pg_config[PgConfigKey.LDFLAGS_SL]),
"-shared",
*["-o", str(output_dir / self.shared_library_name)],
str(self.directory / "src" / f"{self}.c"),
]
subprocess_capture(output_dir, cmd, check=True)
def control_file_contents(self) -> str:
with open(self.directory / self.control_file_name, encoding="utf-8") as f:
return f.read()
def files(self, output_dir: Path) -> dict[Path, str]:
files = {
# self.directory / self.control_file_name: f"share/extension/{self.control_file_name}",
self.directory / "sql" / f"{self}--1.0.sql": f"share/extension/{self}--1.0.sql",
self.directory
/ "sql"
/ f"{self}--1.0--1.1.sql": f"share/extension/{self}--1.0--1.1.sql",
}
if self is RemoteExtension.WITH_LIB:
files[output_dir / self.shared_library_name] = f"lib/{self.shared_library_name}"
return files
def package(self, output_dir: Path) -> Path:
tarball = output_dir / self.tarball_name
with tarfile.open(tarball, "x") as tarf:
for file, arcname in self.files(output_dir).items():
tarf.add(file, arcname=arcname)
return tarball
def remove(self, output_dir: Path, pg_version: PgVersion) -> None:
for file in self.files(output_dir).values():
if file.startswith("share/extension"):
file = f"share/postgresql/extension/{os.path.basename(file)}"
if file.startswith("lib"):
file = f"lib/postgresql/{os.path.basename(file)}"
(output_dir / "pg_install" / f"v{pg_version}" / file).unlink()
@pytest.mark.parametrize(
"extension",
(RemoteExtension.SQL_ONLY, RemoteExtension.WITH_LIB),
ids=["sql_only", "with_lib"],
)
def test_remote_extensions(
httpserver: HTTPServer,
neon_env_builder_local: NeonEnvBuilder,
httpserver_listen_address: ListenAddress,
test_output_dir: Path,
base_dir: Path,
pg_version: PgVersion,
pg_config: PgConfig,
extension: RemoteExtension,
):
# Setup a mock nginx S3 gateway which will return our test extension.
(host, port) = httpserver_listen_address
extensions_endpoint = f"http://{host}:{port}/pg-ext-s3-gateway"
build_tag = os.environ.get("BUILD_TAG", "latest")
# We have decided to use the Go naming convention due to Kubernetes.
arch = platform.machine()
match arch:
case "aarch64":
arch = "arm64"
case "x86_64":
arch = "amd64"
case _:
pass
archive_route = f"{build_tag}/{arch}/v{pg_version}/extensions/test_extension.tar.zst"
tarball = test_output_dir / "test_extension.tar"
extension_dir = (
base_dir / "test_runner" / "regress" / "data" / "test_remote_extensions" / "test_extension"
)
# Create tarball
with tarfile.open(tarball, "x") as tarf:
tarf.add(
extension_dir / "sql" / "test_extension--1.0.sql",
arcname="share/extension/test_extension--1.0.sql",
)
tarf.add(
extension_dir / "sql" / "test_extension--1.0--1.1.sql",
arcname="share/extension/test_extension--1.0--1.1.sql",
)
extension.build(pg_config, test_output_dir)
tarball = extension.package(test_output_dir)
def handler(request: Request) -> Response:
log.info(f"request: {request}")
@@ -104,8 +172,19 @@ def test_remote_extensions(
direct_passthrough=True,
)
# We have decided to use the Go naming convention due to Kubernetes.
arch = platform.machine()
match arch:
case "aarch64":
arch = "arm64"
case "x86_64":
arch = "amd64"
case _:
pass
httpserver.expect_request(
f"/pg-ext-s3-gateway/{archive_route}", method="GET"
f"/pg-ext-s3-gateway/{extension.archive_route(build_tag=os.environ.get('BUILD_TAG', 'latest'), arch=arch, pg_version=pg_version)}",
method="GET",
).respond_with_handler(handler)
# Start a compute node with remote_extension spec
@@ -114,21 +193,18 @@ def test_remote_extensions(
env.create_branch("test_remote_extensions")
endpoint = env.endpoints.create("test_remote_extensions")
with open(extension_dir / "test_extension.control", encoding="utf-8") as f:
control_data = f.read()
# mock remote_extensions spec
spec: dict[str, Any] = {
"public_extensions": ["test_extension"],
"public_extensions": [extension],
"custom_extensions": None,
"library_index": {
"test_extension": "test_extension",
extension: extension,
},
"extension_data": {
"test_extension": {
extension: {
"archive_path": "",
"control_data": {
"test_extension.control": control_data,
extension.control_file_name: extension.control_file_contents(),
},
},
},
@@ -141,8 +217,8 @@ def test_remote_extensions(
with endpoint.connect() as conn:
with conn.cursor() as cur:
# Check that appropriate files were downloaded
cur.execute("CREATE EXTENSION test_extension VERSION '1.0'")
cur.execute("SELECT test_extension.motd()")
cur.execute(f"CREATE EXTENSION {extension} VERSION '1.0'")
cur.execute(f"SELECT {extension}.motd()")
httpserver.check()
@@ -153,7 +229,7 @@ def test_remote_extensions(
remote_ext_requests = metrics.query_all(
"compute_ctl_remote_ext_requests_total",
# Check that we properly report the filename in the metrics
{"filename": "test_extension.tar.zst"},
{"filename": extension.compressed_tarball_name},
)
assert len(remote_ext_requests) == 1
for sample in remote_ext_requests:
@@ -162,20 +238,7 @@ def test_remote_extensions(
endpoint.stop()
# Remove the extension files to force a redownload of the extension.
for file in (
"test_extension.control",
"test_extension--1.0.sql",
"test_extension--1.0--1.1.sql",
):
(
test_output_dir
/ "pg_install"
/ f"v{pg_version}"
/ "share"
/ "postgresql"
/ "extension"
/ file
).unlink()
extension.remove(test_output_dir, pg_version)
endpoint.start(remote_ext_config=extensions_endpoint)
@@ -183,8 +246,8 @@ def test_remote_extensions(
with endpoint.connect() as conn:
with conn.cursor() as cur:
# Check that appropriate files were downloaded
cur.execute("ALTER EXTENSION test_extension UPDATE TO '1.1'")
cur.execute("SELECT test_extension.fun_fact()")
cur.execute(f"ALTER EXTENSION {extension} UPDATE TO '1.1'")
cur.execute(f"SELECT {extension}.fun_fact()")
# Check that we properly recorded downloads in the metrics
client = endpoint.http_client()
@@ -193,7 +256,7 @@ def test_remote_extensions(
remote_ext_requests = metrics.query_all(
"compute_ctl_remote_ext_requests_total",
# Check that we properly report the filename in the metrics
{"filename": "test_extension.tar.zst"},
{"filename": extension.compressed_tarball_name},
)
assert len(remote_ext_requests) == 1
for sample in remote_ext_requests:

View File

@@ -8,7 +8,7 @@ from jwcrypto import jwk, jwt
@pytest.mark.asyncio
async def test_object_storage_insert_retrieve_delete(neon_simple_env: NeonEnv):
async def test_endpoint_storage_insert_retrieve_delete(neon_simple_env: NeonEnv):
"""
Inserts, retrieves, and deletes test file using a JWT token
"""
@@ -31,7 +31,7 @@ async def test_object_storage_insert_retrieve_delete(neon_simple_env: NeonEnv):
token.make_signed_token(key)
token = token.serialize()
base_url = env.object_storage.base_url()
base_url = env.endpoint_storage.base_url()
key = f"http://{base_url}/{tenant_id}/{timeline_id}/{endpoint_id}/key"
headers = {"Authorization": f"Bearer {token}"}
log.info(f"cache key url {key}")

View File

@@ -1,9 +1,9 @@
import base64
import json
import re
import time
from enum import Enum
from pathlib import Path
from threading import Event
import psycopg2
import psycopg2.errors
@@ -14,12 +14,16 @@ from fixtures.log_helper import log
from fixtures.neon_fixtures import NeonEnvBuilder, PgBin, PgProtocol, VanillaPostgres
from fixtures.pageserver.http import (
ImportPgdataIdemptencyKey,
PageserverApiException,
)
from fixtures.pg_version import PgVersion
from fixtures.port_distributor import PortDistributor
from fixtures.remote_storage import MockS3Server, RemoteStorageKind
from fixtures.utils import shared_buffers_for_max_cu
from fixtures.utils import (
run_only_on_default_postgres,
shared_buffers_for_max_cu,
skip_in_debug_build,
wait_until,
)
from mypy_boto3_kms import KMSClient
from mypy_boto3_kms.type_defs import EncryptResponseTypeDef
from mypy_boto3_s3 import S3Client
@@ -44,6 +48,25 @@ smoke_params = [
]
def mock_import_bucket(vanilla_pg: VanillaPostgres, path: Path):
"""
Mock the import S3 bucket into a local directory for a provided vanilla PG instance.
"""
assert not vanilla_pg.is_running()
path.mkdir()
# what cplane writes before scheduling fast_import
specpath = path / "spec.json"
specpath.write_text(json.dumps({"branch_id": "somebranch", "project_id": "someproject"}))
# what fast_import writes
vanilla_pg.pgdatadir.rename(path / "pgdata")
statusdir = path / "status"
statusdir.mkdir()
(statusdir / "pgdata").write_text(json.dumps({"done": True}))
(statusdir / "fast_import").write_text(json.dumps({"command": "pgdata", "done": True}))
@skip_in_debug_build("MULTIPLE_RELATION_SEGMENTS has non trivial amount of data")
@pytest.mark.parametrize("shard_count,stripe_size,rel_block_size", smoke_params)
def test_pgdata_import_smoke(
vanilla_pg: VanillaPostgres,
@@ -56,24 +79,29 @@ def test_pgdata_import_smoke(
#
# Setup fake control plane for import progress
#
import_completion_signaled = Event()
def handler(request: Request) -> Response:
log.info(f"control plane request: {request.json}")
log.info(f"control plane /import_complete request: {request.json}")
import_completion_signaled.set()
return Response(json.dumps({}), status=200)
cplane_mgmt_api_server = make_httpserver
cplane_mgmt_api_server.expect_request(re.compile(".*")).respond_with_handler(handler)
cplane_mgmt_api_server.expect_request(
"/storage/api/v1/import_complete", method="PUT"
).respond_with_handler(handler)
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
neon_env_builder.control_plane_hooks_api = (
f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/storage/api/v1/"
)
env = neon_env_builder.init_start()
# The test needs LocalFs support, which is only built in testing mode.
env.pageserver.is_testing_enabled_or_skip()
env.pageserver.patch_config_toml_nonrecursive(
{
"import_pgdata_upcall_api": f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/path/to/mgmt/api"
}
)
env.pageserver.stop()
env.pageserver.start()
@@ -150,17 +178,8 @@ def test_pgdata_import_smoke(
# TODO: actually exercise fast_import here
# TODO: test s3 remote storage
#
importbucket = neon_env_builder.repo_dir / "importbucket"
importbucket.mkdir()
# what cplane writes before scheduling fast_import
specpath = importbucket / "spec.json"
specpath.write_text(json.dumps({"branch_id": "somebranch", "project_id": "someproject"}))
# what fast_import writes
vanilla_pg.pgdatadir.rename(importbucket / "pgdata")
statusdir = importbucket / "status"
statusdir.mkdir()
(statusdir / "pgdata").write_text(json.dumps({"done": True}))
(statusdir / "fast_import").write_text(json.dumps({"command": "pgdata", "done": True}))
importbucket_path = neon_env_builder.repo_dir / "importbucket"
mock_import_bucket(vanilla_pg, importbucket_path)
#
# Do the import
@@ -187,46 +206,17 @@ def test_pgdata_import_smoke(
"new_timeline_id": str(timeline_id),
"import_pgdata": {
"idempotency_key": str(idempotency),
"location": {"LocalFs": {"path": str(importbucket.absolute())}},
"location": {"LocalFs": {"path": str(importbucket_path.absolute())}},
},
},
)
env.neon_cli.mappings_map_branch(import_branch_name, tenant_id, timeline_id)
while True:
locations = env.storage_controller.locate(tenant_id)
active_count = 0
for location in locations:
shard_id = TenantShardId.parse(location["shard_id"])
ps = env.get_pageserver(location["node_id"])
try:
detail = ps.http_client().timeline_detail(shard_id, timeline_id)
state = detail["state"]
log.info(f"shard {shard_id} state: {state}")
if state == "Active":
active_count += 1
except PageserverApiException as e:
if e.status_code == 404:
log.info("not found, import is in progress")
continue
elif e.status_code == 429:
log.info("import is in progress")
continue
else:
raise
def cplane_notified():
assert import_completion_signaled.is_set()
shard_status_file = statusdir / f"shard-{shard_id.shard_index}"
if state == "Active":
shard_status_file_contents = (
shard_status_file.read_text()
) # Active state implies import is done
shard_status = json.loads(shard_status_file_contents)
assert shard_status["done"] is True
if active_count == len(locations):
log.info("all shards are active")
break
time.sleep(1)
# Generous timeout for the MULTIPLE_RELATION_SEGMENTS test variants
wait_until(cplane_notified, timeout=90)
import_duration = time.monotonic() - start
log.info(f"import complete; duration={import_duration:.2f}s")
@@ -343,6 +333,87 @@ def test_pgdata_import_smoke(
br_initdb_endpoint.safe_psql("select * from othertable")
@run_only_on_default_postgres(reason="PG version is irrelevant here")
def test_import_completion_on_restart(
neon_env_builder: NeonEnvBuilder, vanilla_pg: VanillaPostgres, make_httpserver: HTTPServer
):
"""
Validate that the storage controller delivers the import completion notification
eventually even if it was restarted when the import initially completed.
"""
# Set up mock control plane HTTP server to listen for import completions
import_completion_signaled = Event()
def handler(request: Request) -> Response:
log.info(f"control plane /import_complete request: {request.json}")
import_completion_signaled.set()
return Response(json.dumps({}), status=200)
cplane_mgmt_api_server = make_httpserver
cplane_mgmt_api_server.expect_request(
"/storage/api/v1/import_complete", method="PUT"
).respond_with_handler(handler)
# Plug the cplane mock in
neon_env_builder.control_plane_hooks_api = (
f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/storage/api/v1/"
)
# The import will specifiy a local filesystem path mocking remote storage
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.LOCAL_FS)
vanilla_pg.start()
vanilla_pg.stop()
env = neon_env_builder.init_configs()
env.start()
importbucket_path = neon_env_builder.repo_dir / "test_import_completion_bucket"
mock_import_bucket(vanilla_pg, importbucket_path)
tenant_id = TenantId.generate()
timeline_id = TimelineId.generate()
idempotency = ImportPgdataIdemptencyKey.random()
# Pause before sending the notification
failpoint_name = "timeline-import-pre-cplane-notification"
env.storage_controller.configure_failpoints((failpoint_name, "pause"))
env.storage_controller.tenant_create(tenant_id)
env.storage_controller.timeline_create(
tenant_id,
{
"new_timeline_id": str(timeline_id),
"import_pgdata": {
"idempotency_key": str(idempotency),
"location": {"LocalFs": {"path": str(importbucket_path.absolute())}},
},
},
)
def hit_failpoint():
log.info("Checking log for pattern...")
try:
assert env.storage_controller.log_contains(f".*at failpoint {failpoint_name}.*")
except Exception:
log.exception("Failed to find pattern in log")
raise
wait_until(hit_failpoint)
assert not import_completion_signaled.is_set()
# Restart the storage controller before signalling control plane.
# This clears the failpoint and we expect that the import start-up reconciliation
# kicks in and notifies cplane.
env.storage_controller.stop()
env.storage_controller.start()
def cplane_notified():
assert import_completion_signaled.is_set()
wait_until(cplane_notified)
def test_fast_import_with_pageserver_ingest(
test_output_dir,
vanilla_pg: VanillaPostgres,
@@ -372,19 +443,27 @@ def test_fast_import_with_pageserver_ingest(
vanilla_pg.safe_psql("CREATE TABLE foo (a int); INSERT INTO foo SELECT generate_series(1, 10);")
# Setup pageserver and fake cplane for import progress
import_completion_signaled = Event()
def handler(request: Request) -> Response:
log.info(f"control plane request: {request.json}")
log.info(f"control plane /import_complete request: {request.json}")
import_completion_signaled.set()
return Response(json.dumps({}), status=200)
cplane_mgmt_api_server = make_httpserver
cplane_mgmt_api_server.expect_request(re.compile(".*")).respond_with_handler(handler)
cplane_mgmt_api_server.expect_request(
"/storage/api/v1/import_complete", method="PUT"
).respond_with_handler(handler)
neon_env_builder.control_plane_hooks_api = (
f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/storage/api/v1/"
)
neon_env_builder.enable_pageserver_remote_storage(RemoteStorageKind.MOCK_S3)
env = neon_env_builder.init_start()
env.pageserver.patch_config_toml_nonrecursive(
{
"import_pgdata_upcall_api": f"http://{cplane_mgmt_api_server.host}:{cplane_mgmt_api_server.port}/path/to/mgmt/api",
# because import_pgdata code uses this endpoint, not the one in common remote storage config
# TODO: maybe use common remote_storage config in pageserver?
"import_pgdata_aws_endpoint_url": env.s3_mock_server.endpoint(),
@@ -476,42 +555,10 @@ def test_fast_import_with_pageserver_ingest(
conn = PgProtocol(dsn=f"postgresql://cloud_admin@localhost:{pg_port}/neondb")
validate_vanilla_equivalence(conn)
# Poll pageserver statuses in s3
while True:
locations = env.storage_controller.locate(tenant_id)
active_count = 0
for location in locations:
shard_id = TenantShardId.parse(location["shard_id"])
ps = env.get_pageserver(location["node_id"])
try:
detail = ps.http_client().timeline_detail(shard_id, timeline_id)
log.info(f"timeline {tenant_id}/{timeline_id} detail: {detail}")
state = detail["state"]
log.info(f"shard {shard_id} state: {state}")
if state == "Active":
active_count += 1
except PageserverApiException as e:
if e.status_code == 404:
log.info("not found, import is in progress")
continue
elif e.status_code == 429:
log.info("import is in progress")
continue
else:
raise
def cplane_notified():
assert import_completion_signaled.is_set()
if state == "Active":
key = f"{key_prefix}/status/shard-{shard_id.shard_index}"
shard_status_file_contents = (
mock_s3_client.get_object(Bucket=bucket, Key=key)["Body"].read().decode("utf-8")
)
shard_status = json.loads(shard_status_file_contents)
assert shard_status["done"] is True
if active_count == len(locations):
log.info("all shards are active")
break
time.sleep(0.5)
wait_until(cplane_notified, timeout=60)
import_duration = time.monotonic() - start
log.info(f"import complete; duration={import_duration:.2f}s")

Some files were not shown because too many files have changed in this diff Show More