Compare commits

..

6 Commits

Author SHA1 Message Date
Konstantin Knizhnik
4fed569517 Add description of test_logical_replication_ondemand_download test 2025-07-07 18:26:04 +03:00
Konstantin Knizhnik
aa0986de85 Reduce table size in on-demand WAL download test 2025-07-04 22:45:55 +03:00
Konstantin Knizhnik
a2b4ca23a0 Make ruff happy 2025-07-04 22:42:44 +03:00
Konstantin Knizhnik
b40d367f4b Reduce table size in on-demand WAL download test 2025-07-04 22:18:53 +03:00
Konstantin Knizhnik
1e34b8e16a Add test for pg_wal size 2025-07-04 14:15:57 +03:00
Konstantin Knizhnik
c8cab8803b Do not invalidate obsolete slot if on demand wal download is supported 2025-07-03 15:25:31 +03:00
135 changed files with 1867 additions and 4386 deletions

View File

@@ -33,7 +33,6 @@ workspace-members = [
"compute_api",
"consumption_metrics",
"desim",
"json",
"metrics",
"pageserver_api",
"postgres_backend",

View File

@@ -7,7 +7,6 @@ self-hosted-runner:
- small-metal
- small-arm64
- unit-perf
- unit-perf-aws-arm
- us-east-2
config-variables:
- AWS_ECR_REGION

View File

@@ -32,14 +32,162 @@ permissions:
contents: read
jobs:
make-all:
build-pgxn:
if: |
inputs.pg_versions != '[]' || inputs.rebuild_everything ||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-macos') ||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-*') ||
github.ref_name == 'main'
timeout-minutes: 30
runs-on: macos-15
strategy:
matrix:
postgres-version: ${{ inputs.rebuild_everything && fromJSON('["v14", "v15", "v16", "v17"]') || fromJSON(inputs.pg_versions) }}
env:
# Use release build only, to have less debug info around
# Hence keeping target/ (and general cache size) smaller
BUILD_TYPE: release
steps:
- name: Harden the runner (Audit all outbound calls)
uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0
with:
egress-policy: audit
- name: Checkout main repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set pg ${{ matrix.postgres-version }} for caching
id: pg_rev
run: echo pg_rev=$(git rev-parse HEAD:vendor/postgres-${{ matrix.postgres-version }}) | tee -a "${GITHUB_OUTPUT}"
- name: Cache postgres ${{ matrix.postgres-version }} build
id: cache_pg
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
with:
path: pg_install/${{ matrix.postgres-version }}
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-pg-${{ matrix.postgres-version }}-${{ steps.pg_rev.outputs.pg_rev }}-${{ hashFiles('Makefile') }}
- name: Checkout submodule vendor/postgres-${{ matrix.postgres-version }}
if: steps.cache_pg.outputs.cache-hit != 'true'
run: |
git submodule init vendor/postgres-${{ matrix.postgres-version }}
git submodule update --depth 1 --recursive
- name: Install build dependencies
if: steps.cache_pg.outputs.cache-hit != 'true'
run: |
brew install flex bison openssl protobuf icu4c
- name: Set extra env for macOS
if: steps.cache_pg.outputs.cache-hit != 'true'
run: |
echo 'LDFLAGS=-L/usr/local/opt/openssl@3/lib' >> $GITHUB_ENV
echo 'CPPFLAGS=-I/usr/local/opt/openssl@3/include' >> $GITHUB_ENV
- name: Build Postgres ${{ matrix.postgres-version }}
if: steps.cache_pg.outputs.cache-hit != 'true'
run: |
make postgres-${{ matrix.postgres-version }} -j$(sysctl -n hw.ncpu)
- name: Build Neon Pg Ext ${{ matrix.postgres-version }}
if: steps.cache_pg.outputs.cache-hit != 'true'
run: |
make "neon-pg-ext-${{ matrix.postgres-version }}" -j$(sysctl -n hw.ncpu)
- name: Upload "pg_install/${{ matrix.postgres-version }}" artifact
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: pg_install--${{ matrix.postgres-version }}
path: pg_install/${{ matrix.postgres-version }}
# The artifact is supposed to be used by the next job in the same workflow,
# so there's no need to store it for too long.
retention-days: 1
build-walproposer-lib:
if: |
contains(inputs.pg_versions, 'v17') || inputs.rebuild_everything ||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-macos') ||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-*') ||
github.ref_name == 'main'
timeout-minutes: 30
runs-on: macos-15
needs: [build-pgxn]
env:
# Use release build only, to have less debug info around
# Hence keeping target/ (and general cache size) smaller
BUILD_TYPE: release
steps:
- name: Harden the runner (Audit all outbound calls)
uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0
with:
egress-policy: audit
- name: Checkout main repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set pg v17 for caching
id: pg_rev
run: echo pg_rev=$(git rev-parse HEAD:vendor/postgres-v17) | tee -a "${GITHUB_OUTPUT}"
- name: Download "pg_install/v17" artifact
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
with:
name: pg_install--v17
path: pg_install/v17
# `actions/download-artifact` doesn't preserve permissions:
# https://github.com/actions/download-artifact?tab=readme-ov-file#permission-loss
- name: Make pg_install/v*/bin/* executable
run: |
chmod +x pg_install/v*/bin/*
- name: Cache walproposer-lib
id: cache_walproposer_lib
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
with:
path: build/walproposer-lib
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-walproposer_lib-v17-${{ steps.pg_rev.outputs.pg_rev }}-${{ hashFiles('Makefile') }}
- name: Checkout submodule vendor/postgres-v17
if: steps.cache_walproposer_lib.outputs.cache-hit != 'true'
run: |
git submodule init vendor/postgres-v17
git submodule update --depth 1 --recursive
- name: Install build dependencies
if: steps.cache_walproposer_lib.outputs.cache-hit != 'true'
run: |
brew install flex bison openssl protobuf icu4c
- name: Set extra env for macOS
if: steps.cache_walproposer_lib.outputs.cache-hit != 'true'
run: |
echo 'LDFLAGS=-L/usr/local/opt/openssl@3/lib' >> $GITHUB_ENV
echo 'CPPFLAGS=-I/usr/local/opt/openssl@3/include' >> $GITHUB_ENV
- name: Build walproposer-lib (only for v17)
if: steps.cache_walproposer_lib.outputs.cache-hit != 'true'
run:
make walproposer-lib -j$(sysctl -n hw.ncpu) PG_INSTALL_CACHED=1
- name: Upload "build/walproposer-lib" artifact
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: build--walproposer-lib
path: build/walproposer-lib
# The artifact is supposed to be used by the next job in the same workflow,
# so there's no need to store it for too long.
retention-days: 1
cargo-build:
if: |
inputs.pg_versions != '[]' || inputs.rebuild_rust_code || inputs.rebuild_everything ||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-macos') ||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-*') ||
github.ref_name == 'main'
timeout-minutes: 60
timeout-minutes: 30
runs-on: macos-15
needs: [build-pgxn, build-walproposer-lib]
env:
# Use release build only, to have less debug info around
# Hence keeping target/ (and general cache size) smaller
@@ -55,53 +203,41 @@ jobs:
with:
submodules: true
- name: Install build dependencies
run: |
brew install flex bison openssl protobuf icu4c
- name: Set extra env for macOS
run: |
echo 'LDFLAGS=-L/usr/local/opt/openssl@3/lib' >> $GITHUB_ENV
echo 'CPPFLAGS=-I/usr/local/opt/openssl@3/include' >> $GITHUB_ENV
- name: Restore "pg_install/" cache
id: cache_pg
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
- name: Download "pg_install/v14" artifact
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
with:
path: pg_install
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-pg-install-v14-${{ hashFiles('Makefile', 'postgres.mk', 'vendor/revisions.json') }}
name: pg_install--v14
path: pg_install/v14
- name: Checkout vendor/postgres submodules
if: steps.cache_pg.outputs.cache-hit != 'true'
run: |
git submodule init
git submodule update --depth 1 --recursive
- name: Download "pg_install/v15" artifact
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
with:
name: pg_install--v15
path: pg_install/v15
- name: Build Postgres
if: steps.cache_pg.outputs.cache-hit != 'true'
run: |
make postgres -j$(sysctl -n hw.ncpu)
- name: Download "pg_install/v16" artifact
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
with:
name: pg_install--v16
path: pg_install/v16
# This isn't strictly necessary, but it makes the cached and non-cached builds more similar.
# When pg_install is restored from cache, there is no 'build/' directory. By removing it
# in a non-cached build too, we enforce that the rest of the steps don't depend on it,
# so that we notice any build caching bugs earlier.
- name: Remove build artifacts
if: steps.cache_pg.outputs.cache-hit != 'true'
run: |
rm -rf build
- name: Download "pg_install/v17" artifact
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
with:
name: pg_install--v17
path: pg_install/v17
# Explicitly update the rust toolchain before running 'make'. The parallel make build can
# invoke 'cargo build' more than once in parallel, for different crates. That's OK, 'cargo'
# does its own locking to prevent concurrent builds from stepping on each other's
# toes. However, it will first try to update the toolchain, and that step is not locked the
# same way. To avoid two toolchain updates running in parallel and stepping on each other's
# toes, ensure that the toolchain is up-to-date beforehand.
- name: Update rust toolchain
- name: Download "build/walproposer-lib" artifact
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
with:
name: build--walproposer-lib
path: build/walproposer-lib
# `actions/download-artifact` doesn't preserve permissions:
# https://github.com/actions/download-artifact?tab=readme-ov-file#permission-loss
- name: Make pg_install/v*/bin/* executable
run: |
rustup --version &&
rustup update &&
rustup show
chmod +x pg_install/v*/bin/*
- name: Cache cargo deps
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
@@ -113,12 +249,17 @@ jobs:
target
key: v1-${{ runner.os }}-${{ runner.arch }}-cargo-${{ hashFiles('./Cargo.lock') }}-${{ hashFiles('./rust-toolchain.toml') }}-rust
# Build the neon-specific postgres extensions, and all the Rust bits.
#
# Pass PG_INSTALL_CACHED=1 because PostgreSQL was already built and cached
# separately.
- name: Build all
run: PG_INSTALL_CACHED=1 BUILD_TYPE=release make -j$(sysctl -n hw.ncpu) all
- name: Install build dependencies
run: |
brew install flex bison openssl protobuf icu4c
- name: Set extra env for macOS
run: |
echo 'LDFLAGS=-L/usr/local/opt/openssl@3/lib' >> $GITHUB_ENV
echo 'CPPFLAGS=-I/usr/local/opt/openssl@3/include' >> $GITHUB_ENV
- name: Run cargo build
run: cargo build --all --release -j$(sysctl -n hw.ncpu)
- name: Check that no warnings are produced
run: ./run_clippy.sh

View File

@@ -306,14 +306,14 @@ jobs:
statuses: write
contents: write
pull-requests: write
runs-on: [ self-hosted, unit-perf-aws-arm ]
runs-on: [ self-hosted, unit-perf ]
container:
image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm
credentials:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
# for changed limits, see comments on `options:` earlier in this file
options: --init --shm-size=512mb --ulimit memlock=67108864:67108864 --ulimit nofile=65536:65536 --security-opt seccomp=unconfined
options: --init --shm-size=512mb --ulimit memlock=67108864:67108864
strategy:
fail-fast: false
matrix:

View File

@@ -1,4 +1,4 @@
name: Periodic pagebench performance test on unit-perf-aws-arm runners
name: Periodic pagebench performance test on unit-perf hetzner runner
on:
schedule:
@@ -40,7 +40,7 @@ jobs:
statuses: write
contents: write
pull-requests: write
runs-on: [ self-hosted, unit-perf-aws-arm ]
runs-on: [ self-hosted, unit-perf ]
container:
image: ghcr.io/neondatabase/build-tools:pinned-bookworm
credentials:

View File

@@ -1,4 +1,4 @@
name: Periodic proxy performance test on unit-perf-aws-arm runners
name: Periodic proxy performance test on unit-perf hetzner runner
on:
push: # TODO: remove after testing
@@ -32,7 +32,7 @@ jobs:
statuses: write
contents: write
pull-requests: write
runs-on: [self-hosted, unit-perf-aws-arm]
runs-on: [self-hosted, unit-perf]
timeout-minutes: 60 # 1h timeout
container:
image: ghcr.io/neondatabase/build-tools:pinned-bookworm

40
Cargo.lock generated
View File

@@ -1083,25 +1083,6 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cbindgen"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "975982cdb7ad6a142be15bdf84aea7ec6a9e5d4d797c004d43185b24cfe4e684"
dependencies = [
"clap",
"heck",
"indexmap 2.9.0",
"log",
"proc-macro2",
"quote",
"serde",
"serde_json",
"syn 2.0.100",
"tempfile",
"toml",
]
[[package]]
name = "cc"
version = "1.2.16"
@@ -1286,15 +1267,6 @@ dependencies = [
"unicode-width",
]
[[package]]
name = "communicator"
version = "0.1.0"
dependencies = [
"cbindgen",
"neon-shmem",
"workspace_hack",
]
[[package]]
name = "compute_api"
version = "0.1.0"
@@ -3489,15 +3461,6 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "json"
version = "0.1.0"
dependencies = [
"futures",
"itoa",
"ryu",
]
[[package]]
name = "json-structural-diff"
version = "0.2.0"
@@ -8702,10 +8665,8 @@ dependencies = [
"fail",
"form_urlencoded",
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-util",
"generic-array",
"getrandom 0.2.11",
@@ -8732,7 +8693,6 @@ dependencies = [
"num-iter",
"num-rational",
"num-traits",
"once_cell",
"p256 0.13.2",
"parquet",
"prettyplease",

View File

@@ -42,12 +42,10 @@ members = [
"libs/walproposer",
"libs/wal_decoder",
"libs/postgres_initdb",
"libs/proxy/json",
"libs/proxy/postgres-protocol2",
"libs/proxy/postgres-types2",
"libs/proxy/tokio-postgres2",
"endpoint_storage",
"pgxn/neon/communicator",
]
[workspace.package]
@@ -257,7 +255,6 @@ desim = { version = "0.1", path = "./libs/desim" }
endpoint_storage = { version = "0.0.1", path = "./endpoint_storage/" }
http-utils = { version = "0.1", path = "./libs/http-utils/" }
metrics = { version = "0.1", path = "./libs/metrics/" }
neon-shmem = { version = "0.1", path = "./libs/neon-shmem/" }
pageserver = { path = "./pageserver" }
pageserver_api = { version = "0.1", path = "./libs/pageserver_api/" }
pageserver_client = { path = "./pageserver/client" }
@@ -287,7 +284,6 @@ walproposer = { version = "0.1", path = "./libs/walproposer/" }
workspace_hack = { version = "0.1", path = "./workspace_hack/" }
## Build dependencies
cbindgen = "0.29.0"
criterion = "0.5.1"
rcgen = "0.13"
rstest = "0.18"

View File

@@ -30,18 +30,7 @@ ARG BASE_IMAGE_SHA=debian:${DEBIAN_FLAVOR}
ARG BASE_IMAGE_SHA=${BASE_IMAGE_SHA/debian:bookworm-slim/debian@$BOOKWORM_SLIM_SHA}
ARG BASE_IMAGE_SHA=${BASE_IMAGE_SHA/debian:bullseye-slim/debian@$BULLSEYE_SLIM_SHA}
# Naive way:
#
# 1. COPY . .
# 1. make neon-pg-ext
# 2. cargo build <storage binaries>
#
# But to enable docker to cache intermediate layers, we perform a few preparatory steps:
#
# - Build all postgres versions, depending on just the contents of vendor/
# - Use cargo chef to build all rust dependencies
# 1. Build all postgres versions
# Build Postgres
FROM $REPOSITORY/$IMAGE:$TAG AS pg-build
WORKDIR /home/nonroot
@@ -49,15 +38,17 @@ COPY --chown=nonroot vendor/postgres-v14 vendor/postgres-v14
COPY --chown=nonroot vendor/postgres-v15 vendor/postgres-v15
COPY --chown=nonroot vendor/postgres-v16 vendor/postgres-v16
COPY --chown=nonroot vendor/postgres-v17 vendor/postgres-v17
COPY --chown=nonroot pgxn pgxn
COPY --chown=nonroot Makefile Makefile
COPY --chown=nonroot postgres.mk postgres.mk
COPY --chown=nonroot scripts/ninstall.sh scripts/ninstall.sh
ENV BUILD_TYPE=release
RUN set -e \
&& mold -run make -j $(nproc) -s postgres
&& mold -run make -j $(nproc) -s neon-pg-ext \
&& tar -C pg_install -czf /home/nonroot/postgres_install.tar.gz .
# 2. Prepare cargo-chef recipe
# Prepare cargo-chef recipe
FROM $REPOSITORY/$IMAGE:$TAG AS plan
WORKDIR /home/nonroot
@@ -65,22 +56,23 @@ COPY --chown=nonroot . .
RUN cargo chef prepare --recipe-path recipe.json
# Main build image
# Build neon binaries
FROM $REPOSITORY/$IMAGE:$TAG AS build
WORKDIR /home/nonroot
ARG GIT_VERSION=local
ARG BUILD_TAG
COPY --from=pg-build /home/nonroot/pg_install/v14/include/postgresql/server pg_install/v14/include/postgresql/server
COPY --from=pg-build /home/nonroot/pg_install/v15/include/postgresql/server pg_install/v15/include/postgresql/server
COPY --from=pg-build /home/nonroot/pg_install/v16/include/postgresql/server pg_install/v16/include/postgresql/server
COPY --from=pg-build /home/nonroot/pg_install/v17/include/postgresql/server pg_install/v17/include/postgresql/server
COPY --from=plan /home/nonroot/recipe.json recipe.json
ARG ADDITIONAL_RUSTFLAGS=""
# 3. Build cargo dependencies. Note that this step doesn't depend on anything else than
# `recipe.json`, so the layer can be reused as long as none of the dependencies change.
COPY --from=plan /home/nonroot/recipe.json recipe.json
RUN set -e \
&& RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo chef cook --locked --release --recipe-path recipe.json
# Perform the main build. We reuse the Postgres build artifacts from the intermediate 'pg-build'
# layer, and the cargo dependencies built in the previous step.
COPY --chown=nonroot --from=pg-build /home/nonroot/pg_install/ pg_install
COPY --chown=nonroot . .
RUN set -e \
@@ -95,10 +87,10 @@ RUN set -e \
--bin endpoint_storage \
--bin neon_local \
--bin storage_scrubber \
--locked --release \
&& mold -run make -j $(nproc) -s neon-pg-ext
--locked --release
# Assemble the final image
# Build final image
#
FROM $BASE_IMAGE_SHA
WORKDIR /data
@@ -138,15 +130,12 @@ COPY --from=build --chown=neon:neon /home/nonroot/target/release/proxy
COPY --from=build --chown=neon:neon /home/nonroot/target/release/endpoint_storage /usr/local/bin
COPY --from=build --chown=neon:neon /home/nonroot/target/release/neon_local /usr/local/bin
COPY --from=build --chown=neon:neon /home/nonroot/target/release/storage_scrubber /usr/local/bin
COPY --from=build /home/nonroot/pg_install/v14 /usr/local/v14/
COPY --from=build /home/nonroot/pg_install/v15 /usr/local/v15/
COPY --from=build /home/nonroot/pg_install/v16 /usr/local/v16/
COPY --from=build /home/nonroot/pg_install/v17 /usr/local/v17/
# Deprecated: Old deployment scripts use this tarball which contains all the Postgres binaries.
# That's obsolete, since all the same files are also present under /usr/local/v*. But to keep the
# old scripts working for now, create the tarball.
RUN tar -C /usr/local -cvzf /data/postgres_install.tar.gz v14 v15 v16 v17
COPY --from=pg-build /home/nonroot/pg_install/v14 /usr/local/v14/
COPY --from=pg-build /home/nonroot/pg_install/v15 /usr/local/v15/
COPY --from=pg-build /home/nonroot/pg_install/v16 /usr/local/v16/
COPY --from=pg-build /home/nonroot/pg_install/v17 /usr/local/v17/
COPY --from=pg-build /home/nonroot/postgres_install.tar.gz /data/
# By default, pageserver uses `.neon/` working directory in WORKDIR, so create one and fill it with the dummy config.
# Now, when `docker run ... pageserver` is run, it can start without errors, yet will have some default dummy values.

View File

@@ -30,18 +30,11 @@ ifeq ($(BUILD_TYPE),release)
PG_CFLAGS += -O2 -g3 $(CFLAGS)
PG_LDFLAGS = $(LDFLAGS)
CARGO_PROFILE ?= --profile=release
# NEON_CARGO_ARTIFACT_TARGET_DIR is the directory where `cargo build` places
# the final build artifacts. There is unfortunately no easy way of changing
# it to a fully predictable path, nor to extract the path with a simple
# command. See https://github.com/rust-lang/cargo/issues/9661 and
# https://github.com/rust-lang/cargo/issues/6790.
NEON_CARGO_ARTIFACT_TARGET_DIR = $(ROOT_PROJECT_DIR)/target/release
else ifeq ($(BUILD_TYPE),debug)
PG_CONFIGURE_OPTS = --enable-debug --with-openssl --enable-cassert --enable-depend
PG_CFLAGS += -O0 -g3 $(CFLAGS)
PG_LDFLAGS = $(LDFLAGS)
CARGO_PROFILE ?= --profile=dev
NEON_CARGO_ARTIFACT_TARGET_DIR = $(ROOT_PROJECT_DIR)/target/debug
else
$(error Bad build type '$(BUILD_TYPE)', see Makefile for options)
endif
@@ -109,7 +102,7 @@ all: neon postgres-install neon-pg-ext
### Neon Rust bits
#
# The 'postgres_ffi' crate depends on the Postgres headers.
# The 'postgres_ffi' depends on the Postgres headers.
.PHONY: neon
neon: postgres-headers-install walproposer-lib cargo-target-dir
+@echo "Compiling Neon"
@@ -122,13 +115,10 @@ cargo-target-dir:
test -e target/CACHEDIR.TAG || echo "$(CACHEDIR_TAG_CONTENTS)" > target/CACHEDIR.TAG
.PHONY: neon-pg-ext-%
neon-pg-ext-%: postgres-install-% cargo-target-dir
neon-pg-ext-%: postgres-install-%
+@echo "Compiling neon-specific Postgres extensions for $*"
mkdir -p $(BUILD_DIR)/pgxn-$*
$(MAKE) PG_CONFIG="$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config" COPT='$(COPT)' \
NEON_CARGO_ARTIFACT_TARGET_DIR="$(NEON_CARGO_ARTIFACT_TARGET_DIR)" \
CARGO_BUILD_FLAGS="$(CARGO_BUILD_FLAGS)" \
CARGO_PROFILE="$(CARGO_PROFILE)" \
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \
-C $(BUILD_DIR)/pgxn-$*\
-f $(ROOT_PROJECT_DIR)/pgxn/Makefile install

View File

@@ -1636,14 +1636,11 @@ RUN make install USE_PGXS=1 -j $(getconf _NPROCESSORS_ONLN)
# compile neon extensions
#
#########################################################################################
FROM pg-build-with-cargo AS neon-ext-build
FROM pg-build AS neon-ext-build
ARG PG_VERSION
USER root
COPY . .
RUN make -j $(getconf _NPROCESSORS_ONLN) -C pgxn -s install-compute \
BUILD_TYPE=release CARGO_BUILD_FLAGS="--locked --release" NEON_CARGO_ARTIFACT_TARGET_DIR="$(pwd)/target/release"
COPY pgxn/ pgxn/
RUN make -j $(getconf _NPROCESSORS_ONLN) -C pgxn -s install-compute
#########################################################################################
#

View File

@@ -29,8 +29,7 @@ use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, Condvar, Mutex, RwLock};
use std::time::{Duration, Instant};
use std::{env, fs};
use tokio::task::JoinHandle;
use tokio::{spawn, time};
use tokio::spawn;
use tracing::{Instrument, debug, error, info, instrument, warn};
use url::Url;
use utils::id::{TenantId, TimelineId};
@@ -108,8 +107,6 @@ pub struct ComputeNodeParams {
pub installed_extensions_collection_interval: Arc<AtomicU64>,
}
type TaskHandle = Mutex<Option<JoinHandle<()>>>;
/// Compute node info shared across several `compute_ctl` threads.
pub struct ComputeNode {
pub params: ComputeNodeParams,
@@ -132,8 +129,7 @@ pub struct ComputeNode {
pub compute_ctl_config: ComputeCtlConfig,
/// Handle to the extension stats collection task
extension_stats_task: TaskHandle,
lfc_offload_task: TaskHandle,
extension_stats_task: Mutex<Option<tokio::task::JoinHandle<()>>>,
}
// store some metrics about download size that might impact startup time
@@ -372,7 +368,7 @@ fn maybe_cgexec(cmd: &str) -> Command {
struct PostgresHandle {
postgres: std::process::Child,
log_collector: JoinHandle<Result<()>>,
log_collector: tokio::task::JoinHandle<Result<()>>,
}
impl PostgresHandle {
@@ -386,7 +382,7 @@ struct StartVmMonitorResult {
#[cfg(target_os = "linux")]
token: tokio_util::sync::CancellationToken,
#[cfg(target_os = "linux")]
vm_monitor: Option<JoinHandle<Result<()>>>,
vm_monitor: Option<tokio::task::JoinHandle<Result<()>>>,
}
impl ComputeNode {
@@ -437,7 +433,6 @@ impl ComputeNode {
ext_download_progress: RwLock::new(HashMap::new()),
compute_ctl_config: config.compute_ctl_config,
extension_stats_task: Mutex::new(None),
lfc_offload_task: Mutex::new(None),
})
}
@@ -525,8 +520,8 @@ impl ComputeNode {
None
};
// Terminate the extension stats collection task
this.terminate_extension_stats_task();
this.terminate_lfc_offload_task();
// Terminate the vm_monitor so it releases the file watcher on
// /sys/fs/cgroup/neon-postgres.
@@ -856,15 +851,12 @@ impl ComputeNode {
// Log metrics so that we can search for slow operations in logs
info!(?metrics, postmaster_pid = %postmaster_pid, "compute start finished");
// Spawn the extension stats background task
self.spawn_extension_stats_task();
if pspec.spec.autoprewarm {
info!("autoprewarming on startup as requested");
self.prewarm_lfc(None);
}
if let Some(seconds) = pspec.spec.offload_lfc_interval_seconds {
self.spawn_lfc_offload_task(Duration::from_secs(seconds.into()));
};
Ok(())
}
@@ -2365,7 +2357,10 @@ LIMIT 100",
}
pub fn spawn_extension_stats_task(&self) {
self.terminate_extension_stats_task();
// Cancel any existing task
if let Some(handle) = self.extension_stats_task.lock().unwrap().take() {
handle.abort();
}
let conf = self.tokio_conn_conf.clone();
let atomic_interval = self.params.installed_extensions_collection_interval.clone();
@@ -2376,23 +2371,24 @@ LIMIT 100",
installed_extensions_collection_interval
);
let handle = tokio::spawn(async move {
// An initial sleep is added to ensure that two collections don't happen at the same time.
// The first collection happens during compute startup.
tokio::time::sleep(tokio::time::Duration::from_secs(
installed_extensions_collection_interval,
))
.await;
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(
installed_extensions_collection_interval,
));
loop {
info!(
"[NEON_EXT_INT_SLEEP]: Interval: {}",
installed_extensions_collection_interval
);
// Sleep at the start of the loop to ensure that two collections don't happen at the same time.
// The first collection happens during compute startup.
tokio::time::sleep(tokio::time::Duration::from_secs(
installed_extensions_collection_interval,
))
.await;
interval.tick().await;
let _ = installed_extensions(conf.clone()).await;
// Acquire a read lock on the compute spec and then update the interval if necessary
installed_extensions_collection_interval = std::cmp::max(
interval = tokio::time::interval(tokio::time::Duration::from_secs(std::cmp::max(
installed_extensions_collection_interval,
2 * atomic_interval.load(std::sync::atomic::Ordering::SeqCst),
);
)));
installed_extensions_collection_interval = interval.period().as_secs();
}
});
@@ -2401,30 +2397,8 @@ LIMIT 100",
}
fn terminate_extension_stats_task(&self) {
if let Some(h) = self.extension_stats_task.lock().unwrap().take() {
h.abort()
}
}
pub fn spawn_lfc_offload_task(self: &Arc<Self>, interval: Duration) {
self.terminate_lfc_offload_task();
let secs = interval.as_secs();
info!("spawning lfc offload worker with {secs}s interval");
let this = self.clone();
let handle = spawn(async move {
let mut interval = time::interval(interval);
interval.tick().await; // returns immediately
loop {
interval.tick().await;
this.offload_lfc_async().await;
}
});
*self.lfc_offload_task.lock().unwrap() = Some(handle);
}
fn terminate_lfc_offload_task(&self) {
if let Some(h) = self.lfc_offload_task.lock().unwrap().take() {
h.abort()
if let Some(handle) = self.extension_stats_task.lock().unwrap().take() {
handle.abort();
}
}

View File

@@ -5,7 +5,6 @@ use compute_api::responses::LfcOffloadState;
use compute_api::responses::LfcPrewarmState;
use http::StatusCode;
use reqwest::Client;
use std::mem::replace;
use std::sync::Arc;
use tokio::{io::AsyncReadExt, spawn};
use tracing::{error, info};
@@ -89,15 +88,17 @@ impl ComputeNode {
self.state.lock().unwrap().lfc_offload_state.clone()
}
/// If there is a prewarm request ongoing, return false, true otherwise
/// Returns false if there is a prewarm request ongoing, true otherwise
pub fn prewarm_lfc(self: &Arc<Self>, from_endpoint: Option<String>) -> bool {
crate::metrics::LFC_PREWARM_REQUESTS.inc();
{
let state = &mut self.state.lock().unwrap().lfc_prewarm_state;
if let LfcPrewarmState::Prewarming = replace(state, LfcPrewarmState::Prewarming) {
if let LfcPrewarmState::Prewarming =
std::mem::replace(state, LfcPrewarmState::Prewarming)
{
return false;
}
}
crate::metrics::LFC_PREWARMS.inc();
let cloned = self.clone();
spawn(async move {
@@ -151,39 +152,30 @@ impl ComputeNode {
.map(|_| ())
}
/// If offload request is ongoing, return false, true otherwise
/// Returns false if there is an offload request ongoing, true otherwise
pub fn offload_lfc(self: &Arc<Self>) -> bool {
crate::metrics::LFC_OFFLOAD_REQUESTS.inc();
{
let state = &mut self.state.lock().unwrap().lfc_offload_state;
if replace(state, LfcOffloadState::Offloading) == LfcOffloadState::Offloading {
if let LfcOffloadState::Offloading =
std::mem::replace(state, LfcOffloadState::Offloading)
{
return false;
}
}
let cloned = self.clone();
spawn(async move { cloned.offload_lfc_with_state_update().await });
true
}
pub async fn offload_lfc_async(self: &Arc<Self>) {
{
let state = &mut self.state.lock().unwrap().lfc_offload_state;
if replace(state, LfcOffloadState::Offloading) == LfcOffloadState::Offloading {
spawn(async move {
let Err(err) = cloned.offload_lfc_impl().await else {
cloned.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Completed;
return;
}
}
self.offload_lfc_with_state_update().await
}
async fn offload_lfc_with_state_update(&self) {
crate::metrics::LFC_OFFLOADS.inc();
let Err(err) = self.offload_lfc_impl().await else {
self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Completed;
return;
};
error!(%err);
self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Failed {
error: err.to_string(),
};
};
error!(%err);
cloned.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Failed {
error: err.to_string(),
};
});
true
}
async fn offload_lfc_impl(&self) -> Result<()> {

View File

@@ -97,18 +97,20 @@ pub(crate) static PG_TOTAL_DOWNTIME_MS: Lazy<GenericCounter<AtomicU64>> = Lazy::
.expect("failed to define a metric")
});
pub(crate) static LFC_PREWARMS: Lazy<IntCounter> = Lazy::new(|| {
/// Needed as neon.file_cache_prewarm_batch == 0 doesn't mean we never tried to prewarm.
/// On the other hand, LFC_PREWARMED_PAGES is excessive as we can GET /lfc/prewarm
pub(crate) static LFC_PREWARM_REQUESTS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"compute_ctl_lfc_prewarms_total",
"Total number of LFC prewarms requested by compute_ctl or autoprewarm option",
"compute_ctl_lfc_prewarm_requests_total",
"Total number of LFC prewarm requests made by compute_ctl",
)
.expect("failed to define a metric")
});
pub(crate) static LFC_OFFLOADS: Lazy<IntCounter> = Lazy::new(|| {
pub(crate) static LFC_OFFLOAD_REQUESTS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"compute_ctl_lfc_offloads_total",
"Total number of LFC offloads requested by compute_ctl or lfc_offload_period_seconds option",
"compute_ctl_lfc_offload_requests_total",
"Total number of LFC offload requests made by compute_ctl",
)
.expect("failed to define a metric")
});
@@ -122,7 +124,7 @@ pub fn collect() -> Vec<MetricFamily> {
metrics.extend(AUDIT_LOG_DIR_SIZE.collect());
metrics.extend(PG_CURR_DOWNTIME_MS.collect());
metrics.extend(PG_TOTAL_DOWNTIME_MS.collect());
metrics.extend(LFC_PREWARMS.collect());
metrics.extend(LFC_OFFLOADS.collect());
metrics.extend(LFC_PREWARM_REQUESTS.collect());
metrics.extend(LFC_OFFLOAD_REQUESTS.collect());
metrics
}

View File

@@ -31,7 +31,6 @@ mod pg_helpers_tests {
wal_level = logical
hot_standby = on
autoprewarm = off
offload_lfc_interval_seconds = 20
neon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501'
wal_log_hints = on
log_connections = on

View File

@@ -64,9 +64,7 @@ const DEFAULT_PAGESERVER_ID: NodeId = NodeId(1);
const DEFAULT_BRANCH_NAME: &str = "main";
project_git_version!(GIT_VERSION);
#[allow(dead_code)]
const DEFAULT_PG_VERSION: PgMajorVersion = PgMajorVersion::PG17;
const DEFAULT_PG_VERSION_NUM: &str = "17";
const DEFAULT_PAGESERVER_CONTROL_PLANE_API: &str = "http://127.0.0.1:1234/upcall/v1/";
@@ -169,7 +167,7 @@ struct TenantCreateCmdArgs {
#[clap(short = 'c')]
config: Vec<String>,
#[arg(default_value = DEFAULT_PG_VERSION_NUM)]
#[arg(default_value_t = DEFAULT_PG_VERSION)]
#[clap(long, help = "Postgres version to use for the initial timeline")]
pg_version: PgMajorVersion,
@@ -292,7 +290,7 @@ struct TimelineCreateCmdArgs {
#[clap(long, help = "Human-readable alias for the new timeline")]
branch_name: String,
#[arg(default_value = DEFAULT_PG_VERSION_NUM)]
#[arg(default_value_t = DEFAULT_PG_VERSION)]
#[clap(long, help = "Postgres version")]
pg_version: PgMajorVersion,
}
@@ -324,7 +322,7 @@ struct TimelineImportCmdArgs {
#[clap(long, help = "Lsn the basebackup ends at")]
end_lsn: Option<Lsn>,
#[arg(default_value = DEFAULT_PG_VERSION_NUM)]
#[arg(default_value_t = DEFAULT_PG_VERSION)]
#[clap(long, help = "Postgres version of the backup being imported")]
pg_version: PgMajorVersion,
}
@@ -603,7 +601,7 @@ struct EndpointCreateCmdArgs {
)]
config_only: bool,
#[arg(default_value = DEFAULT_PG_VERSION_NUM)]
#[arg(default_value_t = DEFAULT_PG_VERSION)]
#[clap(long, help = "Postgres version")]
pg_version: PgMajorVersion,
@@ -675,16 +673,6 @@ struct EndpointStartCmdArgs {
#[arg(default_value = "90s")]
start_timeout: Duration,
#[clap(
long,
help = "Download LFC cache from endpoint storage on endpoint startup",
default_value = "false"
)]
autoprewarm: bool,
#[clap(long, help = "Upload LFC cache to endpoint storage periodically")]
offload_lfc_interval_seconds: Option<std::num::NonZeroU64>,
#[clap(
long,
help = "Run in development mode, skipping VM-specific operations like process termination",
@@ -1595,24 +1583,22 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
let endpoint_storage_token = env.generate_auth_token(&claims)?;
let endpoint_storage_addr = env.endpoint_storage.listen_addr.to_string();
let args = control_plane::endpoint::EndpointStartArgs {
auth_token,
endpoint_storage_token,
endpoint_storage_addr,
safekeepers_generation,
safekeepers,
pageservers,
remote_ext_base_url: remote_ext_base_url.clone(),
shard_stripe_size: stripe_size.0 as usize,
create_test_user: args.create_test_user,
start_timeout: args.start_timeout,
autoprewarm: args.autoprewarm,
offload_lfc_interval_seconds: args.offload_lfc_interval_seconds,
dev: args.dev,
};
println!("Starting existing endpoint {endpoint_id}...");
endpoint.start(args).await?;
endpoint
.start(
&auth_token,
endpoint_storage_token,
endpoint_storage_addr,
safekeepers_generation,
safekeepers,
pageservers,
remote_ext_base_url.as_ref(),
stripe_size.0 as usize,
args.create_test_user,
args.start_timeout,
args.dev,
)
.await?;
}
EndpointCmd::Reconfigure(args) => {
let endpoint_id = &args.endpoint_id;

View File

@@ -373,22 +373,6 @@ impl std::fmt::Display for EndpointTerminateMode {
}
}
/// Arguments for `Endpoint::start`, bundled into one struct.
pub struct EndpointStartArgs {
/// Copied into the compute spec as `storage_auth_token`.
pub auth_token: Option<String>,
/// Stored in the compute spec as `endpoint_storage_token`.
pub endpoint_storage_token: String,
/// Stored in the compute spec as `endpoint_storage_addr`.
pub endpoint_storage_addr: String,
/// Stored in the compute spec as `safekeepers_generation` (after `into_inner()`).
pub safekeepers_generation: Option<SafekeeperGeneration>,
/// Safekeeper node ids; turned into connection strings via `build_safekeepers_connstrs`.
pub safekeepers: Vec<NodeId>,
/// Pageserver endpoints; turned into a connection string via `build_pageserver_connstr`.
pub pageservers: Vec<(PageserverProtocol, Host, u16)>,
/// When set, passed to the launched command as `--remote-ext-base-url`.
pub remote_ext_base_url: Option<String>,
/// Stored in the compute spec as `shard_stripe_size`.
pub shard_stripe_size: usize,
/// When true, the spec gets a `test` role and a `neondb` database owned by it.
pub create_test_user: bool,
/// How long `start` keeps retrying before bailing out with a timeout error.
pub start_timeout: Duration,
/// Stored in the compute spec as `autoprewarm`
/// (CLI help: download LFC cache from endpoint storage on endpoint startup).
pub autoprewarm: bool,
/// Stored in the compute spec as `offload_lfc_interval_seconds`
/// (CLI help: upload LFC cache to endpoint storage periodically).
pub offload_lfc_interval_seconds: Option<std::num::NonZeroU64>,
/// When true, `--dev` is appended to the launched command.
pub dev: bool,
}
impl Endpoint {
fn from_dir_entry(entry: std::fs::DirEntry, env: &LocalEnv) -> Result<Endpoint> {
if !entry.file_type()?.is_dir() {
@@ -693,7 +677,21 @@ impl Endpoint {
})
}
pub async fn start(&self, args: EndpointStartArgs) -> Result<()> {
#[allow(clippy::too_many_arguments)]
pub async fn start(
&self,
auth_token: &Option<String>,
endpoint_storage_token: String,
endpoint_storage_addr: String,
safekeepers_generation: Option<SafekeeperGeneration>,
safekeepers: Vec<NodeId>,
pageservers: Vec<(PageserverProtocol, Host, u16)>,
remote_ext_base_url: Option<&String>,
shard_stripe_size: usize,
create_test_user: bool,
start_timeout: Duration,
dev: bool,
) -> Result<()> {
if self.status() == EndpointStatus::Running {
anyhow::bail!("The endpoint is already running");
}
@@ -706,10 +704,10 @@ impl Endpoint {
std::fs::remove_dir_all(self.pgdata())?;
}
let pageserver_connstring = Self::build_pageserver_connstr(&args.pageservers);
let pageserver_connstring = Self::build_pageserver_connstr(&pageservers);
assert!(!pageserver_connstring.is_empty());
let safekeeper_connstrings = self.build_safekeepers_connstrs(args.safekeepers)?;
let safekeeper_connstrings = self.build_safekeepers_connstrs(safekeepers)?;
// check for file remote_extensions_spec.json
// if it is present, read it and pass to compute_ctl
@@ -737,7 +735,7 @@ impl Endpoint {
cluster_id: None, // project ID: not used
name: None, // project name: not used
state: None,
roles: if args.create_test_user {
roles: if create_test_user {
vec![Role {
name: PgIdent::from_str("test").unwrap(),
encrypted_password: None,
@@ -746,7 +744,7 @@ impl Endpoint {
} else {
Vec::new()
},
databases: if args.create_test_user {
databases: if create_test_user {
vec![Database {
name: PgIdent::from_str("neondb").unwrap(),
owner: PgIdent::from_str("test").unwrap(),
@@ -768,21 +766,20 @@ impl Endpoint {
endpoint_id: Some(self.endpoint_id.clone()),
mode: self.mode,
pageserver_connstring: Some(pageserver_connstring),
safekeepers_generation: args.safekeepers_generation.map(|g| g.into_inner()),
safekeepers_generation: safekeepers_generation.map(|g| g.into_inner()),
safekeeper_connstrings,
storage_auth_token: args.auth_token.clone(),
storage_auth_token: auth_token.clone(),
remote_extensions,
pgbouncer_settings: None,
shard_stripe_size: Some(args.shard_stripe_size),
shard_stripe_size: Some(shard_stripe_size),
local_proxy_config: None,
reconfigure_concurrency: self.reconfigure_concurrency,
drop_subscriptions_before_start: self.drop_subscriptions_before_start,
audit_log_level: ComputeAudit::Disabled,
logs_export_host: None::<String>,
endpoint_storage_addr: Some(args.endpoint_storage_addr),
endpoint_storage_token: Some(args.endpoint_storage_token),
autoprewarm: args.autoprewarm,
offload_lfc_interval_seconds: args.offload_lfc_interval_seconds,
endpoint_storage_addr: Some(endpoint_storage_addr),
endpoint_storage_token: Some(endpoint_storage_token),
autoprewarm: false,
suspend_timeout_seconds: -1, // Only used in neon_local.
};
@@ -794,7 +791,7 @@ impl Endpoint {
debug!("spec.cluster {:?}", spec.cluster);
// fill missing fields again
if args.create_test_user {
if create_test_user {
spec.cluster.roles.push(Role {
name: PgIdent::from_str("test").unwrap(),
encrypted_password: None,
@@ -829,7 +826,7 @@ impl Endpoint {
// Launch compute_ctl
let conn_str = self.connstr("cloud_admin", "postgres");
println!("Starting postgres node at '{conn_str}'");
if args.create_test_user {
if create_test_user {
let conn_str = self.connstr("test", "neondb");
println!("Also at '{conn_str}'");
}
@@ -861,11 +858,11 @@ impl Endpoint {
.stderr(logfile.try_clone()?)
.stdout(logfile);
if let Some(remote_ext_base_url) = args.remote_ext_base_url {
cmd.args(["--remote-ext-base-url", &remote_ext_base_url]);
if let Some(remote_ext_base_url) = remote_ext_base_url {
cmd.args(["--remote-ext-base-url", remote_ext_base_url]);
}
if args.dev {
if dev {
cmd.arg("--dev");
}
@@ -897,11 +894,10 @@ impl Endpoint {
Ok(state) => {
match state.status {
ComputeStatus::Init => {
let timeout = args.start_timeout;
if Instant::now().duration_since(start_at) > timeout {
if Instant::now().duration_since(start_at) > start_timeout {
bail!(
"compute startup timed out {:?}; still in Init state",
timeout
start_timeout
);
}
// keep retrying
@@ -929,10 +925,9 @@ impl Endpoint {
}
}
Err(e) => {
if Instant::now().duration_since(start_at) > args.start_timeout {
if Instant::now().duration_since(start_at) > start_timeout {
return Err(e).context(format!(
"timed out {:?} waiting to connect to compute_ctl HTTP",
args.start_timeout
"timed out {start_timeout:?} waiting to connect to compute_ctl HTTP",
));
}
}

View File

@@ -65,27 +65,12 @@ enum Command {
#[arg(long)]
scheduling: Option<NodeSchedulingPolicy>,
},
/// Exists for backup usage and will be removed in future.
/// Use [`Command::NodeStartDelete`] instead, if possible.
// Set a node status as deleted.
NodeDelete {
#[arg(long)]
node_id: NodeId,
},
/// Start deletion of the specified pageserver.
NodeStartDelete {
#[arg(long)]
node_id: NodeId,
},
/// Cancel deletion of the specified pageserver and wait for `timeout`
/// for the operation to be canceled. May be retried.
NodeCancelDelete {
#[arg(long)]
node_id: NodeId,
#[arg(long)]
timeout: humantime::Duration,
},
/// Delete a tombstone of node from the storage controller.
/// This is used when we want to allow the node to be re-registered.
NodeDeleteTombstone {
#[arg(long)]
node_id: NodeId,
@@ -927,43 +912,10 @@ async fn main() -> anyhow::Result<()> {
.await?;
}
Command::NodeDelete { node_id } => {
eprintln!("Warning: This command is obsolete and will be removed in a future version");
eprintln!("Use `NodeStartDelete` instead, if possible");
storcon_client
.dispatch::<(), ()>(Method::DELETE, format!("control/v1/node/{node_id}"), None)
.await?;
}
Command::NodeStartDelete { node_id } => {
storcon_client
.dispatch::<(), ()>(
Method::PUT,
format!("control/v1/node/{node_id}/delete"),
None,
)
.await?;
println!("Delete started for {node_id}");
}
Command::NodeCancelDelete { node_id, timeout } => {
storcon_client
.dispatch::<(), ()>(
Method::DELETE,
format!("control/v1/node/{node_id}/delete"),
None,
)
.await?;
println!("Waiting for node {node_id} to quiesce on scheduling policy ...");
let final_policy =
wait_for_scheduling_policy(storcon_client, node_id, *timeout, |sched| {
!matches!(sched, NodeSchedulingPolicy::Deleting)
})
.await?;
println!(
"Delete was cancelled for node {node_id}. Schedulling policy is now {final_policy:?}"
);
}
Command::NodeDeleteTombstone { node_id } => {
storcon_client
.dispatch::<(), ()>(

View File

@@ -20,7 +20,7 @@ In our case consensus leader is compute (walproposer), and we don't want to wake
up all computes for the change. Neither we want to fully reimplement the leader
logic second time outside compute. Because of that the proposed algorithm relies
for issuing configurations on the external fault tolerant (distributed) strongly
consistent storage with simple API: CAS (compare-and-swap) on the single key.
consisent storage with simple API: CAS (compare-and-swap) on the single key.
Properly configured postgres suits this.
In the system consensus is implemented at the timeline level, so algorithm below
@@ -34,7 +34,7 @@ A configuration is
```
struct Configuration {
generation: SafekeeperGeneration, // a number uniquely identifying configuration
generation: Generation, // a number uniquely identifying configuration
sk_set: Vec<NodeId>, // current safekeeper set
new_sk_set: Optional<Vec<NodeId>>,
}
@@ -81,11 +81,11 @@ configuration generation in them is less than its current one. Namely, it
refuses to vote, to truncate WAL in `handle_elected` and to accept WAL. In
response it sends its current configuration generation to let walproposer know.
Safekeeper gets `PUT /v1/tenants/{tenant_id}/timelines/{timeline_id}/membership`
accepting `Configuration`. Safekeeper switches to the given conf if it is higher than its
Safekeeper gets `PUT /v1/tenants/{tenant_id}/timelines/{timeline_id}/configuration`
accepting `Configuration`. Safekeeper switches to the given conf it is higher than its
current one and ignores it otherwise. In any case it replies with
```
struct TimelineMembershipSwitchResponse {
struct ConfigurationSwitchResponse {
conf: Configuration,
term: Term,
last_log_term: Term,
@@ -108,7 +108,7 @@ establishes this configuration as its own and moves to voting.
It should stop talking to safekeepers not listed in the configuration at this
point, though it is not unsafe to continue doing so.
To be elected it must receive votes from both majorities if `new_sk_set` is present.
To be elected it must receive votes from both majorites if `new_sk_set` is present.
Similarly, to commit WAL it must receive flush acknowledge from both majorities.
If walproposer hears from safekeeper configuration higher than his own (i.e.
@@ -130,7 +130,7 @@ storage are reachable.
1) Fetch current timeline configuration from the configuration storage.
2) If it is already joint one and `new_set` is different from `desired_set`
refuse to change. However, assign join conf to (in memory) var
`joint_conf` and proceed to step 4 to finish the ongoing change.
`join_conf` and proceed to step 4 to finish the ongoing change.
3) Else, create joint `joint_conf: Configuration`: increment current conf number
`n` and put `desired_set` to `new_sk_set`. Persist it in the configuration
storage by doing CAS on the current generation: change happens only if
@@ -161,11 +161,11 @@ storage are reachable.
because `pull_timeline` already includes it and plus additionally would be
broadcast by compute. More importantly, we may proceed to the next step
only when `<last_log_term, flush_lsn>` on the majority of the new set reached
`sync_position`. Similarly, on the happy path no waiting is needed because
`sync_position`. Similarly, on the happy path no waiting is not needed because
`pull_timeline` already includes it. However, we should double
check to be safe. For example, timeline could have been created earlier e.g.
manually or after try-to-migrate, abort, try-to-migrate-again sequence.
7) Create `new_conf: Configuration` incrementing `joint_conf` generation and having new
7) Create `new_conf: Configuration` incrementing `join_conf` generation and having new
safekeeper set as `sk_set` and None `new_sk_set`. Write it to configuration
storage under one more CAS.
8) Call `PUT` `configuration` on safekeepers from the new set,
@@ -178,12 +178,12 @@ spec of it.
Description above focuses on safety. To make the flow practical and live, here a few more
considerations.
1) It makes sense to ping new set to ensure we are migrating to live node(s) before
1) It makes sense to ping new set to ensure it we are migrating to live node(s) before
step 3.
2) If e.g. accidentally wrong new sk set has been specified, before CAS in step `6` is completed
it is safe to rollback to the old conf with one more CAS.
3) On step 4 timeline might be already created on members of the new set for various reasons;
the simplest is the procedure restart. There are more complicated scenarios like mentioned
the simplest is the procedure restart. There are more complicated scenarious like mentioned
in step 5. Deleting and re-doing `pull_timeline` is generally unsafe without involving
generations, so seems simpler to treat existing timeline as success. However, this also
has a disadvantage: you might imagine a surpassingly unlikely schedule where condition in
@@ -192,7 +192,7 @@ considerations.
4) In the end timeline should be locally deleted on the safekeeper(s) which are
in the old set but not in the new one, unless they are unreachable. To be
safe this also should be done under generation number (deletion proceeds only if
current configuration is <= than one in request and safekeeper is not member of it).
current configuration is <= than one in request and safekeeper is not memeber of it).
5) If current conf fetched on step 1 is already not joint and members equal to `desired_set`,
jump to step 7, using it as `new_conf`.
@@ -261,14 +261,14 @@ Timeline (branch) creation in cplane should call storage_controller POST
Response should be augmented with `safekeepers_generation` and `safekeepers`
fields like described in `/notify-safekeepers` above. Initially (currently)
these fields may be absent; in this case cplane chooses safekeepers on its own
like it currently does. The call should be retried until it succeeds.
like it currently does. The call should be retried until succeeds.
Timeline deletion and tenant deletion in cplane should call appropriate
storage_controller endpoints like it currently does for sharded tenants. The
calls should be retried until they succeed.
When compute receives safekeeper list from control plane it needs to know the
generation to check whether it should be updated (note that compute may get
When compute receives safekeepers list from control plane it needs to know the
generation to checked whether it should be updated (note that compute may get
safekeeper list from either cplane or safekeepers). Currently `neon.safekeepers`
GUC is just a comma-separated list of `host:port`. Let's prefix it with
`g#<generation>:` to this end, so it will look like
@@ -305,8 +305,8 @@ enum MigrationRequest {
```
`FinishPending` requests to run the procedure to ensure state is clean: current
configuration is not joint and the majority of safekeepers are aware of it, but do
not attempt to migrate anywhere. If the current configuration fetched on step 1 is
configuration is not joint and majority of safekeepers are aware of it, but do
not attempt to migrate anywhere. If current configuration fetched on step 1 is
not joint it jumps to step 7. It should be run at startup for all timelines (but
similarly, in the first version it is ok to trigger it manually).
@@ -315,7 +315,7 @@ similarly, in the first version it is ok to trigger it manually).
`safekeepers` table mirroring current `nodes` should be added, except that for
`scheduling_policy`: it is enough to have at least in the beginning only 3
fields: 1) `active` 2) `paused` (initially means only not assign new tlis there
3) `decommissioned` (node is removed).
3) `decomissioned` (node is removed).
`timelines` table:
```
@@ -326,10 +326,9 @@ table! {
tenant_id -> Varchar,
start_lsn -> pg_lsn,
generation -> Int4,
sk_set -> Array<Int8>, // list of safekeeper ids
sk_set -> Array<Int4>, // list of safekeeper ids
new_sk_set -> Nullable<Array<Int8>>, // list of safekeeper ids, null if not joint conf
cplane_notified_generation -> Int4,
sk_set_notified_generation -> Int4, // the generation a quorum of sk_set knows about
deleted_at -> Nullable<Timestamptz>,
}
}
@@ -339,23 +338,13 @@ table! {
might also want to add ancestor_timeline_id to preserve the hierarchy, but for
this RFC it is not needed.
`cplane_notified_generation` and `sk_set_notified_generation` fields are used to
track the last stage of the algorithm, when we need to notify safekeeper set and cplane
with the final configuration after it's already committed to DB.
The timeline is up-to-date (no migration in progress) if `new_sk_set` is null and
`*_notified_generation` fields are up to date with `generation`.
It's possible to replace `*_notified_generation` with one boolean field `migration_completed`,
but for better observability it's nice to have them separately.
#### API
Node management is similar to pageserver:
1) POST `/control/v1/safekeeper` inserts safekeeper.
2) GET `/control/v1/safekeeper` lists safekeepers.
3) GET `/control/v1/safekeeper/:node_id` gets safekeeper.
4) PUT `/control/v1/safekeeper/:node_id/scheduling_policy` changes status to e.g.
1) POST `/control/v1/safekeepers` inserts safekeeper.
2) GET `/control/v1/safekeepers` lists safekeepers.
3) GET `/control/v1/safekeepers/:node_id` gets safekeeper.
4) PUT `/control/v1/safekepers/:node_id/status` changes status to e.g.
`offline` or `decomissioned`. Initially it is simpler not to schedule any
migrations here.
@@ -379,8 +368,8 @@ Migration API: the first version is the simplest and the most imperative:
all timelines from one safekeeper to another. It accepts json
```
{
"src_sk": NodeId,
"dst_sk": NodeId,
"src_sk": u32,
"dst_sk": u32,
"limit": Optional<u32>,
}
```
@@ -390,15 +379,12 @@ Returns list of scheduled requests.
2) PUT `/control/v1/tenant/:tenant_id/timeline/:timeline_id/safekeeper_migrate` schedules `MigrationRequest`
to move single timeline to given set of safekeepers:
```
struct TimelineSafekeeperMigrateRequest {
"new_sk_set": Vec<NodeId>,
{
"desired_set": Vec<u32>,
}
```
In the first version the handler migrates the timeline to `new_sk_set` synchronously.
Should be retried until success.
In the future we might change it to asynchronous API and return scheduled request.
Returns scheduled request.
Similar call should be added for the tenant.
@@ -448,9 +434,6 @@ table! {
}
```
We load all pending ops from the table on startup into the memory.
The table is needed only to preserve the state between restarts.
`op_type` can be `include` (seed from peers and ensure generation is up to
date), `exclude` (remove locally) and `delete`. Field is actually not strictly
needed as it can be computed from current configuration, but gives more explicit
@@ -491,7 +474,7 @@ actions must be idempotent. Now, a tricky point here is timeline start LSN. For
the initial (tenant creation) call cplane doesn't know it. However, setting
start_lsn on safekeepers during creation is a good thing -- it provides a
guarantee that walproposer can always find a common point in WAL histories of
safekeeper and its own, and so absence of it would be a clear sign of
safekeeper and its own, and so absense of it would be a clear sign of
corruption. The following sequence works:
1) Create timeline (or observe that it exists) on pageserver,
figuring out last_record_lsn in response.
@@ -514,9 +497,11 @@ corruption. The following sequence works:
retries the call until 200 response.
There is a small question how request handler (timeline creation in this
case) would interact with per sk reconciler. In the current implementation
we first persist the request in the DB, and then send an in-memory request
to each safekeeper reconciler to process it.
case) would interact with per sk reconciler. As always I prefer to do the
simplest possible thing and here it seems to be just waking it up so it
re-reads the db for work to do. Passing work in memory is faster, but
that shouldn't matter, and path to scan db for work will exist anyway,
simpler to reuse it.
For pg version / wal segment size: while we may persist them in `timelines`
table, it is not necessary as initial creation at step 3 can take them from
@@ -524,40 +509,30 @@ pageserver or cplane creation call and later pull_timeline will carry them
around.
Timeline migration.
1) CAS to the db to create joint conf. Since this moment the migration is considered to be
"in progress". We can detect all "in-progress" migrations looking into the database.
2) Do steps 4-6 from the algorithm, including `pull_timeline` onto `new_sk_set`, update membership
configuration on all safekeepers, notify cplane, etc. All operations are idempotent,
so we don't need to persist anything in the database at this stage. If any errors occur,
it's safe to retry or abort the migration.
3) Once it becomes possible per alg description above, get out of joint conf
with another CAS. Also should insert `exclude` entries into `safekeeper_timeline_pending_ops`
in the same DB transaction. Adding `exclude` entries atomically is necessary because after
CAS we don't have the list of excluded safekeepers in the `timelines` table anymore, but we
need to have them persisted somewhere in case the migration is interrupted right after the CAS.
4) Finish the migration. The final membership configuration is committed to the DB at this stage.
So, the migration can not be aborted anymore. But it can still be retried if the migration fails
past stage 3. To finish the migration we need to send the new membership configuration to
a new quorum of safekeepers, notify cplane with the new safekeeper list and schedule the `exclude`
requests to in-memory queue for safekeeper reconciler. If the algorithm is retried, it's
possible that we have already committed `exclude` requests to DB, but didn't send them to
the in-memory queue. In this case we need to read them from `safekeeper_timeline_pending_ops`
because it's the only place where they are persistent. The fields `sk_set_notified_generation`
and `cplane_notified_generation` are updated after each step. The migration is considered
fully completed when they match the `generation` field.
In practice, we can report "success" after stage 3 and do the "finish" step in per-timeline
reconciler (if we implement it). But it's wise to at least try to finish them synchronously,
so the timeline is always in a "good state" and doesn't require an old quorum to commit
WAL after the migration reported "success".
1) CAS to the db to create joint conf, and in the same transaction create
`safekeeper_timeline_pending_ops` `include` entries to initialize new members
as well as deliver this conf to current ones; poke per sk reconcilers to work
on it. Also any conf change should also poke cplane notifier task(s).
2) Once it becomes possible per alg description above, get out of joint conf
with another CAS. Task should get wakeups from per sk reconcilers because
conf switch is required for advancement; however retries should be sleep
based as well as LSN advancement might be needed, though in happy path
it isn't. To see whether further transition is possible on wakeup migration
executor polls safekeepers per the algorithm. CAS creating new conf with only
new members should again insert entries to `safekeeper_timeline_pending_ops`
to switch them there, as well as `exclude` rows to remove timeline from
old members.
Timeline deletion: just set `deleted_at` on the timeline row and insert
`safekeeper_timeline_pending_ops` entries in the same xact, the rest is done by
per sk reconcilers.
When node is removed (set to `decommissioned`), `safekeeper_timeline_pending_ops`
When node is removed (set to `decomissioned`), `safekeeper_timeline_pending_ops`
for it must be cleared in the same transaction.
One more task pool should infinitely retry notifying control plane about changed
safekeeper sets (trying making `cplane_notified_generation` equal `generation`).
#### Dealing with multiple instances of storage_controller
Operations described above executed concurrently might create some errors but do
@@ -566,7 +541,7 @@ of storage_controller it is fine to have it temporarily, e.g. during redeploy.
To harden against some controller instance creating some work in
`safekeeper_timeline_pending_ops` and then disappearing without anyone picking up
the job per sk reconcilers apart from explicit wakeups should scan for work
the job per sk reconcilers apart from explicit wakups should scan for work
periodically. It is possible to remove that though if all db updates are
protected with leadership token/term -- then such scans are needed only after
leadership is acquired.
@@ -588,7 +563,7 @@ There should be following layers of tests:
safekeeper communication and pull_timeline need to be mocked and main switch
procedure wrapped to as a node (thread) in simulation tests, using these
mocks. Test would inject migrations like it currently injects
safekeeper/walproposer restarts. Main assert is the same -- committed WAL must
safekeeper/walproposer restars. Main assert is the same -- committed WAL must
not be lost.
3) Since simulation testing injects at relatively high level points (not
@@ -638,7 +613,7 @@ Let's have the following implementation bits for gradual rollout:
`notify-safekeepers`.
Then the rollout for a region would be:
- Current situation: safekeepers are chosen by control_plane.
- Current situation: safekeepers are choosen by control_plane.
- We manually migrate some timelines, test moving them around.
- Then we enable `--set-safekeepers` so that all new timelines
are on storage controller.

View File

@@ -58,7 +58,7 @@ pub enum LfcPrewarmState {
},
}
#[derive(Serialize, Default, Debug, Clone, PartialEq)]
#[derive(Serialize, Default, Debug, Clone)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum LfcOffloadState {
#[default]

View File

@@ -181,14 +181,10 @@ pub struct ComputeSpec {
/// JWT for authorizing requests to endpoint storage service
pub endpoint_storage_token: Option<String>,
/// Download LFC state from endpoint_storage and pass it to Postgres on startup
#[serde(default)]
/// Download LFC state from endpoint storage and pass it to Postgres on compute startup
pub autoprewarm: bool,
#[serde(default)]
/// Upload LFC state to endpoint storage periodically. Default value (None) means "don't upload"
pub offload_lfc_interval_seconds: Option<std::num::NonZeroU64>,
/// Suspend timeout in seconds.
///
/// We use this value to derive other values, such as the installed extensions metric.

View File

@@ -90,11 +90,6 @@
"value": "off",
"vartype": "bool"
},
{
"name": "offload_lfc_interval_seconds",
"value": "20",
"vartype": "integer"
},
{
"name": "neon.safekeepers",
"value": "127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501",

View File

@@ -386,7 +386,6 @@ pub enum NodeSchedulingPolicy {
Pause,
PauseForRestart,
Draining,
Deleting,
}
impl FromStr for NodeSchedulingPolicy {
@@ -399,7 +398,6 @@ impl FromStr for NodeSchedulingPolicy {
"pause" => Ok(Self::Pause),
"pause_for_restart" => Ok(Self::PauseForRestart),
"draining" => Ok(Self::Draining),
"deleting" => Ok(Self::Deleting),
_ => Err(anyhow::anyhow!("Unknown scheduling state '{s}'")),
}
}
@@ -414,7 +412,6 @@ impl From<NodeSchedulingPolicy> for String {
Pause => "pause",
PauseForRestart => "pause_for_restart",
Draining => "draining",
Deleting => "deleting",
}
.to_string()
}
@@ -423,7 +420,6 @@ impl From<NodeSchedulingPolicy> for String {
#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug)]
pub enum SkSchedulingPolicy {
Active,
Activating,
Pause,
Decomissioned,
}
@@ -434,7 +430,6 @@ impl FromStr for SkSchedulingPolicy {
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(match s {
"active" => Self::Active,
"activating" => Self::Activating,
"pause" => Self::Pause,
"decomissioned" => Self::Decomissioned,
_ => {
@@ -451,7 +446,6 @@ impl From<SkSchedulingPolicy> for String {
use SkSchedulingPolicy::*;
match value {
Active => "active",
Activating => "activating",
Pause => "pause",
Decomissioned => "decomissioned",
}

View File

@@ -78,13 +78,7 @@ pub fn is_expected_io_error(e: &io::Error) -> bool {
use io::ErrorKind::*;
matches!(
e.kind(),
HostUnreachable
| NetworkUnreachable
| BrokenPipe
| ConnectionRefused
| ConnectionAborted
| ConnectionReset
| TimedOut,
BrokenPipe | ConnectionRefused | ConnectionAborted | ConnectionReset | TimedOut
)
}

View File

@@ -1,12 +0,0 @@
[package]
name = "json"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
ryu = "1"
itoa = "1"
[dev-dependencies]
futures = "0.3"

View File

@@ -1,412 +0,0 @@
//! A JSON serialization lib, designed for more flexibility than `serde_json` offers.
//!
//! Features:
//!
//! ## Dynamic construction
//!
//! Sometimes you have dynamic values you want to serialize, that are not already in a serde-aware model like a struct or a Vec etc.
//! To achieve this with serde, you need to implement a lot of different traits on a lot of different new-types.
//! Because of this, it's often easier to give-in and pull all the data into a serde-aware model (`serde_json::Value` or some intermediate struct),
//! but that is often not very efficient.
//!
//! This crate allows full control over the JSON encoding without needing to implement any extra traits. Just call the
//! relevant functions, and it will guarantee a correctly encoded JSON value.
//!
//! ## Async construction
//!
//! Similar to the above, sometimes the values arrive asynchronously. Often collecting those values in memory
//! is more expensive than writing them as JSON, since the overheads of `Vec` and `String` are much higher, however
//! there are exceptions.
//!
//! Serializing to JSON all in one go is also more CPU intensive and can cause lag spikes,
//! whereas serializing values incrementally spreads out the CPU load and reduces lag.
//!
//! ## Examples
//!
//! To represent the following JSON as a compact string
//!
//! ```json
//! {
//! "results": {
//! "rows": [
//! {
//! "id": 1,
//! "value": null
//! },
//! {
//! "id": 2,
//! "value": "hello"
//! }
//! ]
//! }
//! }
//! ```
//!
//! We can use the following code:
//!
//! ```
//! // create the outer object
//! let s = json::value_to_string!(|v| json::value_as_object!(|v| {
//! // create an entry with key "results" and start an object value associated with it.
//! let results = v.key("results");
//! json::value_as_object!(|results| {
// create an entry with key "rows" and start a list value associated with it.
//! let rows = results.key("rows");
//! json::value_as_list!(|rows| {
//! // create a list entry and start an object value associated with it.
//! let row = rows.entry();
//! json::value_as_object!(|row| {
//! // add entry "id": 1
//! row.entry("id", 1);
//! // add entry "value": null
//! row.entry("value", json::Null);
//! });
//!
//! // create a list entry and start an object value associated with it.
//! let row = rows.entry();
//! json::value_as_object!(|row| {
//! // add entry "id": 2
//! row.entry("id", 2);
//! // add entry "value": "hello"
//! row.entry("value", "hello");
//! });
//! });
//! });
//! }));
//!
//! assert_eq!(s, r#"{"results":{"rows":[{"id":1,"value":null},{"id":2,"value":"hello"}]}}"#);
//! ```
mod macros;
mod str;
mod value;
pub use value::{Null, ValueEncoder};
#[must_use]
/// Serialize a single json value.
///
/// If a `ValueSer` is dropped without being finished, the buffer is truncated back to
/// `start`, discarding whatever was written on behalf of this value.
pub struct ValueSer<'buf> {
/// Output buffer that the encoded JSON bytes are appended to.
buf: &'buf mut Vec<u8>,
/// Buffer length to truncate back to when this value is rolled back.
start: usize,
}
impl<'buf> ValueSer<'buf> {
    /// Create a new json value serializer that appends to `buf`.
    ///
    /// Bytes already present in `buf` are left untouched: if the serializer is dropped
    /// without being finished, the buffer is truncated back to the length it had when
    /// this function was called.
    pub fn new(buf: &'buf mut Vec<u8>) -> Self {
        // Remember the current length rather than 0, so that a rollback only removes
        // what this value wrote and never clobbers pre-existing content.
        let start = buf.len();
        Self { buf, start }
    }

    /// Borrow the underlying buffer
    pub fn as_buffer(&self) -> &[u8] {
        self.buf
    }

    /// Write a single value using its [`ValueEncoder`] implementation, consuming the serializer.
    #[inline]
    pub fn value(self, e: impl ValueEncoder) {
        e.encode(self);
    }

    /// Write raw bytes to the buf. This must be already JSON encoded.
    #[inline]
    pub fn write_raw_json(self, data: &[u8]) {
        self.buf.extend_from_slice(data);
        self.finish();
    }

    /// Start a new object serializer.
    #[inline]
    pub fn object(self) -> ObjectSer<'buf> {
        ObjectSer::new(self)
    }

    /// Start a new list serializer.
    #[inline]
    pub fn list(self) -> ListSer<'buf> {
        ListSer::new(self)
    }

    /// Finish the value ser.
    #[inline]
    fn finish(self) {
        // don't trigger the drop handler which triggers a rollback.
        // this won't cause memory leaks because `ValueSer` owns no allocations.
        std::mem::forget(self);
    }
}
impl Drop for ValueSer<'_> {
/// Roll back: discard everything in the buffer past `start`.
///
/// `finish` avoids this by forgetting the serializer instead of dropping it.
fn drop(&mut self) {
self.buf.truncate(self.start);
}
}
#[must_use]
/// Serialize a json object.
pub struct ObjectSer<'buf> {
/// Serializer for the value that this object is being written as.
value: ValueSer<'buf>,
/// Buffer length immediately after the opening `{`; used to tell whether any entry exists yet.
start: usize,
}
impl<'buf> ObjectSer<'buf> {
    /// Begin serializing an object into `value` by emitting the opening `{`.
    #[inline]
    pub fn new(value: ValueSer<'buf>) -> Self {
        value.buf.push(b'{');
        Self {
            start: value.buf.len(),
            value,
        }
    }

    /// Borrow the underlying buffer.
    pub fn as_buffer(&self) -> &[u8] {
        self.value.as_buffer()
    }

    /// Write `key` as the key of the next entry, returning a [`ValueSer`] for the associated value.
    #[inline]
    pub fn key(&mut self, key: impl KeyEncoder) -> ValueSer<'_> {
        key.write_key(self)
    }

    /// Append a complete entry (key-value pair) to the object.
    #[inline]
    pub fn entry(&mut self, key: impl KeyEncoder, val: impl ValueEncoder) {
        let slot = self.key(key);
        slot.value(val);
    }

    /// Emit the separator (if needed), the key via `f`, and the `:`.
    ///
    /// The returned [`ValueSer`] removes all of it again if it is dropped unfinished.
    #[inline]
    fn entry_inner(&mut self, f: impl FnOnce(&mut Vec<u8>)) -> ValueSer<'_> {
        let buf = &mut *self.value.buf;
        // Record the position before the separator so that, if the value is rolled back,
        // the separator and the key are removed along with it.
        let start = buf.len();
        // Anything past the opening `{` means an earlier entry exists and needs a `,`.
        if start > self.start {
            buf.push(b',');
        }
        // key, then the key/value separator.
        f(buf);
        buf.push(b':');
        ValueSer { buf, start }
    }

    /// Discard everything written for this object, including its opening `{`.
    #[inline]
    pub fn rollback(self) -> ValueSer<'buf> {
        // Only rewind to just before the `{` rather than fully resetting the value.
        // This ensures any `,` before this value is not clobbered.
        let before_brace = self.start - 1;
        self.value.buf.truncate(before_brace);
        self.value
    }

    /// Close the object with `}` and mark the underlying value as finished.
    #[inline]
    pub fn finish(self) {
        let Self { value, .. } = self;
        value.buf.push(b'}');
        value.finish();
    }
}
/// Write a key into a json object.
pub trait KeyEncoder {
/// Writes `self` as the next key of `obj` and returns the [`ValueSer`] for the value associated with it.
fn write_key<'a>(self, obj: &'a mut ObjectSer) -> ValueSer<'a>;
}
#[must_use]
/// Serialize a json list.
pub struct ListSer<'buf> {
/// Serializer for the value that this list is being written as.
value: ValueSer<'buf>,
/// Buffer length immediately after the opening `[`; used to tell whether any entry exists yet.
start: usize,
}
impl<'buf> ListSer<'buf> {
    /// Begin serializing a list into `value` by emitting the opening `[`.
    #[inline]
    pub fn new(value: ValueSer<'buf>) -> Self {
        value.buf.push(b'[');
        Self {
            start: value.buf.len(),
            value,
        }
    }

    /// Borrow the underlying buffer.
    pub fn as_buffer(&self) -> &[u8] {
        self.value.as_buffer()
    }

    /// Append a value to the list.
    #[inline]
    pub fn push(&mut self, val: impl ValueEncoder) {
        let slot = self.entry();
        slot.value(val);
    }

    /// Start a new value entry in this list.
    ///
    /// The returned [`ValueSer`] removes the entry (and its separator) again if it is
    /// dropped unfinished.
    #[inline]
    pub fn entry(&mut self) -> ValueSer<'_> {
        let buf = &mut *self.value.buf;
        // Record the position before the separator so that, if the value is rolled back,
        // the separator is removed along with it.
        let start = buf.len();
        // Anything past the opening `[` means an earlier entry exists and needs a `,`.
        if start > self.start {
            buf.push(b',');
        }
        ValueSer { buf, start }
    }

    /// Reset the buffer back to before this list was started.
    #[inline]
    pub fn rollback(self) -> ValueSer<'buf> {
        // Only rewind to just before the `[` rather than fully resetting the value.
        // This ensures any `,` before this value is not clobbered.
        let before_bracket = self.start - 1;
        self.value.buf.truncate(before_bracket);
        self.value
    }

    /// Close the list with `]` and mark the underlying value as finished.
    #[inline]
    pub fn finish(self) {
        let Self { value, .. } = self;
        value.buf.push(b']');
        value.finish();
    }
}
#[cfg(test)]
mod tests {
use crate::{Null, ValueSer};
// Entries written through `ObjectSer::entry` are comma-separated inside `{}`.
#[test]
fn object() {
let mut buf = vec![];
let mut object = ValueSer::new(&mut buf).object();
object.entry("foo", "bar");
object.entry("baz", Null);
object.finish();
assert_eq!(buf, br#"{"foo":"bar","baz":null}"#);
}
// Values written through `ListSer::entry` are comma-separated inside `[]`.
#[test]
fn list() {
let mut buf = vec![];
let mut list = ValueSer::new(&mut buf).list();
list.entry().value("bar");
list.entry().value(Null);
list.finish();
assert_eq!(buf, br#"["bar",null]"#);
}
// Same output as `object`, but driven through the `value_to_string!`/`value_as_object!` macros.
#[test]
fn object_macro() {
let res = crate::value_to_string!(|obj| {
crate::value_as_object!(|obj| {
obj.entry("foo", "bar");
obj.entry("baz", Null);
})
});
assert_eq!(res, r#"{"foo":"bar","baz":null}"#);
}
// Same output as `list`, but driven through the `value_to_string!`/`value_as_list!` macros.
#[test]
fn list_macro() {
let res = crate::value_to_string!(|list| {
crate::value_as_list!(|list| {
list.entry().value("bar");
list.entry().value(Null);
})
});
assert_eq!(res, r#"["bar",null]"#);
}
// Breaking out of a `value_as_list!` body drops the nested serializer, which removes the
// partially written nested list together with its leading separator.
#[test]
fn rollback_on_drop() {
let res = crate::value_to_string!(|list| {
crate::value_as_list!(|list| {
list.entry().value("bar");
'cancel: {
let nested_list = list.entry();
crate::value_as_list!(|nested_list| {
nested_list.entry().value(1);
assert_eq!(nested_list.as_buffer(), br#"["bar",[1"#);
if true {
break 'cancel;
}
})
}
assert_eq!(list.as_buffer(), br#"["bar""#);
list.entry().value(Null);
})
});
assert_eq!(res, r#"["bar",null]"#);
}
// `ObjectSer::rollback` discards the nested object but keeps the key, so a different
// value can still be written for that key.
#[test]
fn rollback_object() {
let res = crate::value_to_string!(|obj| {
crate::value_as_object!(|obj| {
let entry = obj.key("1");
entry.value(1_i32);
let entry = obj.key("2");
let entry = {
let mut nested_obj = entry.object();
nested_obj.entry("foo", "bar");
nested_obj.rollback()
};
entry.value(2_i32);
})
});
assert_eq!(res, r#"{"1":1,"2":2}"#);
}
// `ListSer::rollback` discards the nested list but keeps the separator, so a different
// value can still be written in that position.
#[test]
fn rollback_list() {
let res = crate::value_to_string!(|list| {
crate::value_as_list!(|list| {
let entry = list.entry();
entry.value(1_i32);
let entry = list.entry();
let entry = {
let mut nested_list = entry.list();
nested_list.push("foo");
nested_list.rollback()
};
entry.value(2_i32);
})
});
assert_eq!(res, r#"[1,2]"#);
}
// Keys produced via `format_args!` and string values are both escaped: quotes in the key
// become `\"` and the newline in the value becomes `\n`.
#[test]
fn string_escaping() {
let mut buf = vec![];
let mut object = ValueSer::new(&mut buf).object();
let key = "hello";
let value = "\n world";
object.entry(format_args!("{key:?}"), value);
object.finish();
assert_eq!(buf, br#"{"\"hello\"":"\n world"}"#);
}
}

View File

@@ -1,86 +0,0 @@
//! # Examples
//!
//! ```
//! use futures::{StreamExt, TryStream, TryStreamExt};
//!
//! async fn stream_to_json_list<S, T, E>(mut s: S) -> Result<String, E>
//! where
//! S: TryStream<Ok = T, Error = E> + Unpin,
//! T: json::ValueEncoder
//! {
//! Ok(json::value_to_string!(|val| json::value_as_list!(|val| {
//! // note how we can use `.await` and `?` in here.
//! while let Some(value) = s.try_next().await? {
//! val.push(value);
//! }
//! })))
//! }
//!
//! let stream = futures::stream::iter([1, 2, 3]).map(Ok::<i32, ()>);
//! let json_string = futures::executor::block_on(stream_to_json_list(stream)).unwrap();
//! assert_eq!(json_string, "[1,2,3]");
//! ```
/// Builds a JSON value into a freshly allocated byte vector.
///
/// `$val` is bound to a [`ValueSer`](crate::ValueSer) over that vector while `$body` runs;
/// `$body` must evaluate to `()`.
///
/// Implemented as a macro to preserve all control flow.
#[macro_export]
macro_rules! value_to_vec {
    (|$val:ident| $body:expr) => {{
        let mut buf = ::std::vec::Vec::new();
        let $val = $crate::ValueSer::new(&mut buf);
        let _: () = $body;
        buf
    }};
}
/// Builds a JSON value into a freshly allocated `String`.
///
/// Delegates to [`value_to_vec!`](crate::value_to_vec) and converts the bytes to UTF-8,
/// panicking if they are not valid UTF-8.
///
/// Implemented as a macro to preserve all control flow.
#[macro_export]
macro_rules! value_to_string {
    (|$val:ident| $body:expr) => {{
        let bytes = $crate::value_to_vec!(|$val| $body);
        ::std::string::String::from_utf8(bytes).expect("json should be valid utf8")
    }};
}
/// A helper that ensures the [`ObjectSer::finish`](crate::ObjectSer::finish) method is called on completion.
///
/// Consumes `$val` and assigns it as an [`ObjectSer`](crate::ObjectSer) serializer.
/// The serializer is only 'finished' if the body completes.
/// The serializer is rolled back if `break`/`return` escapes the body.
///
/// Evaluates to the value of `$body`.
///
/// Implemented as a macro to preserve all control flow.
#[macro_export]
macro_rules! value_as_object {
(|$val:ident| $body:expr) => {{
let mut obj = $crate::ObjectSer::new($val);
// rebind `$val` so the body sees a `&mut ObjectSer` under the caller's chosen name.
let $val = &mut obj;
let res = $body;
// only reached if the body ran to completion.
obj.finish();
res
}};
}
/// A helper that ensures the [`ListSer::finish`](crate::ListSer::finish) method is called on completion.
///
/// Consumes `$val` and assigns it as a [`ListSer`](crate::ListSer) serializer.
/// The serializer is only 'finished' if the body completes.
/// The serializer is rolled back if `break`/`return` escapes the body.
///
/// Evaluates to the value of `$body`.
///
/// Implemented as a macro to preserve all control flow.
#[macro_export]
macro_rules! value_as_list {
(|$val:ident| $body:expr) => {{
let mut list = $crate::ListSer::new($val);
// rebind `$val` so the body sees a `&mut ListSer` under the caller's chosen name.
let $val = &mut list;
let res = $body;
// only reached if the body ran to completion.
list.finish();
res
}};
}

View File

@@ -1,166 +0,0 @@
//! Helpers for serializing escaped strings.
//!
//! ## License
//!
//! <https://github.com/serde-rs/json/blob/c1826ebcccb1a520389c6b78ad3da15db279220d/src/ser.rs#L1514-L1552>
//! <https://github.com/serde-rs/json/blob/c1826ebcccb1a520389c6b78ad3da15db279220d/src/ser.rs#L2081-L2157>
//! Licensed by David Tolnay under MIT or Apache-2.0.
//!
//! With modifications by Conrad Ludgate on behalf of Databricks.
use std::fmt::{self, Write};
/// Represents a character escape code in a type-safe manner.
///
/// There is no variant for the solidus (`/`): it is left commented out below.
pub enum CharEscape {
/// An escaped quote `"`
Quote,
/// An escaped reverse solidus `\`
ReverseSolidus,
// /// An escaped solidus `/`
// Solidus,
/// An escaped backspace character (usually escaped as `\b`)
Backspace,
/// An escaped form feed character (usually escaped as `\f`)
FormFeed,
/// An escaped line feed character (usually escaped as `\n`)
LineFeed,
/// An escaped carriage return character (usually escaped as `\r`)
CarriageReturn,
/// An escaped tab character (usually escaped as `\t`)
Tab,
/// An escaped ASCII plane control character (usually escaped as
/// `\u00XX` where `XX` are two hex characters)
AsciiControl(u8),
}
impl CharEscape {
    /// Maps an escape marker (one of the `BB`, `TT`, ... constants) to the escape it stands for.
    ///
    /// `byte` is the original input byte; it is only carried along for the `\u00XX` form.
    /// Any other `escape` value hits `unreachable!()`.
    #[inline]
    fn from_escape_table(escape: u8, byte: u8) -> CharEscape {
        match escape {
            self::QU => Self::Quote,
            self::BS => Self::ReverseSolidus,
            self::BB => Self::Backspace,
            self::FF => Self::FormFeed,
            self::NN => Self::LineFeed,
            self::RR => Self::CarriageReturn,
            self::TT => Self::Tab,
            self::UU => Self::AsciiControl(byte),
            _ => unreachable!(),
        }
    }
}
/// Appends `value` to `writer` as a double-quoted, escaped JSON string.
pub(crate) fn format_escaped_str(writer: &mut Vec<u8>, value: &str) {
    // room for the contents plus both quotes.
    writer.reserve(value.len() + 2);

    writer.push(b'"');
    // escapes are written as they are found; the unescaped tail is handed back to us.
    let tail = format_escaped_str_contents(writer, value);
    writer.extend_from_slice(tail);
    writer.push(b'"');
}
/// Appends the formatted output of `args` to `writer` as a double-quoted, escaped JSON string.
///
/// Panics if formatting reports an error.
pub(crate) fn format_escaped_fmt(writer: &mut Vec<u8>, args: fmt::Arguments) {
    writer.push(b'"');
    // every fragment produced by the formatter is escaped on its way into `writer`.
    let mut escaper = Collect { buf: writer };
    escaper
        .write_fmt(args)
        .expect("formatting should not error");
    writer.push(b'"');
}
/// `fmt::Write` adapter that escapes everything written to it as JSON string contents.
struct Collect<'buf> {
    buf: &'buf mut Vec<u8>,
}

impl fmt::Write for Collect<'_> {
    /// Escapes `s` and appends it to the buffer; never fails.
    fn write_str(&mut self, s: &str) -> fmt::Result {
        let tail = format_escaped_str_contents(self.buf, s);
        self.buf.extend_from_slice(tail);
        Ok(())
    }
}
// Writes `value` up to and including its last escape sequence into `writer`,
// and returns the unescaped suffix that the caller still has to write.
fn format_escaped_str_contents<'a>(writer: &mut Vec<u8>, value: &'a str) -> &'a [u8] {
    let bytes = value.as_bytes();
    // index of the first byte that has not been written yet.
    let mut pending = 0;

    while let Some(offset) = bytes[pending..]
        .iter()
        .position(|&b| ESCAPE[usize::from(b)] != 0)
    {
        let at = pending + offset;
        let byte = bytes[at];

        // flush the run of plain bytes before the one that needs escaping.
        writer.extend_from_slice(&bytes[pending..at]);

        let marker = ESCAPE[usize::from(byte)];
        write_char_escape(writer, CharEscape::from_escape_table(marker, byte));

        pending = at + 1;
    }

    &bytes[pending..]
}
// Escape markers stored in `ESCAPE`. Each marker is the character that follows the
// backslash in the escaped form; the trailing comment names the byte(s) it applies to.
const BB: u8 = b'b'; // \x08
const TT: u8 = b't'; // \x09
const NN: u8 = b'n'; // \x0A
const FF: u8 = b'f'; // \x0C
const RR: u8 = b'r'; // \x0D
const QU: u8 = b'"'; // \x22
const BS: u8 = b'\\'; // \x5C
const UU: u8 = b'u'; // \x00...\x1F except the ones above
// Marker for "no escaping needed".
const __: u8 = 0;
// Lookup table of escape sequences. A value of b'x' at index i means that byte
// i is escaped as "\x" in JSON. A value of 0 means that byte i is not escaped.
// Each row covers 16 consecutive bytes: the row comment is the high nibble,
// the column header is the low nibble.
static ESCAPE: [u8; 256] = [
// 0 1 2 3 4 5 6 7 8 9 A B C D E F
UU, UU, UU, UU, UU, UU, UU, UU, BB, TT, NN, UU, FF, RR, UU, UU, // 0
UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, UU, // 1
__, __, QU, __, __, __, __, __, __, __, __, __, __, __, __, __, // 2
__, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 3
__, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 4
__, __, __, __, __, __, __, __, __, __, __, __, BS, __, __, __, // 5
__, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 6
__, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 7
__, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 8
__, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // 9
__, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // A
__, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // B
__, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // C
__, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // D
__, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // E
__, __, __, __, __, __, __, __, __, __, __, __, __, __, __, __, // F
];
/// Appends the textual escape sequence for `char_escape` to `writer`.
fn write_char_escape(writer: &mut Vec<u8>, char_escape: CharEscape) {
    let short_form: &[u8; 2] = match char_escape {
        // control characters without a short form are written as `\u00XX`.
        CharEscape::AsciiControl(byte) => {
            const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";
            let high = HEX_DIGITS[usize::from(byte >> 4)];
            let low = HEX_DIGITS[usize::from(byte & 0xF)];
            writer.extend_from_slice(&[b'\\', b'u', b'0', b'0', high, low]);
            return;
        }
        CharEscape::Quote => b"\\\"",
        CharEscape::ReverseSolidus => b"\\\\",
        CharEscape::Backspace => b"\\b",
        CharEscape::FormFeed => b"\\f",
        CharEscape::LineFeed => b"\\n",
        CharEscape::CarriageReturn => b"\\r",
        CharEscape::Tab => b"\\t",
    };
    writer.extend_from_slice(short_form);
}

View File

@@ -1,168 +0,0 @@
use core::fmt;
use std::collections::{BTreeMap, HashMap};
use crate::str::{format_escaped_fmt, format_escaped_str};
use crate::{KeyEncoder, ObjectSer, ValueSer, value_as_list, value_as_object};
/// Write a value to the underlying json representation.
pub trait ValueEncoder {
/// Encodes `self` into `v`, consuming the serializer.
fn encode(self, v: ValueSer<'_>);
}
/// Appends the `itoa` text form of the integer `x` to `b`.
pub(crate) fn write_int(x: impl itoa::Integer, b: &mut Vec<u8>) {
    let mut scratch = itoa::Buffer::new();
    let digits = scratch.format(x);
    b.extend_from_slice(digits.as_bytes());
}
/// Appends the `ryu` text form of the float `x` to `b`.
pub(crate) fn write_float(x: impl ryu::Float, b: &mut Vec<u8>) {
    let mut scratch = ryu::Buffer::new();
    let text = scratch.format(x);
    b.extend_from_slice(text.as_bytes());
}
/// References to `Copy` values are encoded by copying the value out and encoding that.
impl<T: Copy + ValueEncoder> ValueEncoder for &T {
    #[inline]
    fn encode(self, v: ValueSer<'_>) {
        let copied: T = *self;
        <T as ValueEncoder>::encode(copied, v);
    }
}
/// Strings are encoded as quoted, escaped JSON strings.
impl ValueEncoder for &str {
    #[inline]
    fn encode(self, v: ValueSer<'_>) {
        let out = &mut *v.buf;
        format_escaped_str(out, self);
        v.finish();
    }
}
/// Format arguments are encoded as a quoted, escaped JSON string of their formatted output.
impl ValueEncoder for fmt::Arguments<'_> {
    #[inline]
    fn encode(self, v: ValueSer<'_>) {
        match self.as_str() {
            // `as_str` returned the text directly, so escape it as a plain string.
            Some(literal) => format_escaped_str(v.buf, literal),
            None => format_escaped_fmt(v.buf, self),
        }
        v.finish();
    }
}
// Implements `ValueEncoder` for every listed integer type by delegating to `write_int`.
macro_rules! int {
    [$($int:ty),*] => {
        $(
            impl ValueEncoder for $int {
                #[inline]
                fn encode(self, v: ValueSer<'_>) {
                    write_int(self, v.buf);
                    v.finish();
                }
            }
        )*
    };
}

int![u8, i8, u16, i16, u32, i32, u64, i64, usize, isize, u128, i128];
// Implements `ValueEncoder` for every listed float type.
//
// JSON has no representation for NaN or the infinities, and `ryu::Buffer::format` would
// render them as `NaN`, `inf` and `-inf`, which is not valid JSON. Non-finite values are
// therefore encoded as `null` (the same choice `serde_json` makes); finite values go
// through `write_float` unchanged.
macro_rules! float {
    [$($t:ty),*] => {
        $(
            impl ValueEncoder for $t {
                #[inline]
                fn encode(self, v: ValueSer<'_>) {
                    if self.is_finite() {
                        write_float(self, v.buf);
                    } else {
                        v.buf.extend_from_slice(b"null");
                    }
                    v.finish();
                }
            }
        )*
    };
}

float![f32, f64];
/// Booleans are encoded as the JSON literals `true` and `false`.
impl ValueEncoder for bool {
    #[inline]
    fn encode(self, v: ValueSer<'_>) {
        let literal: &[u8] = if self { b"true" } else { b"false" };
        v.write_raw_json(literal);
    }
}
/// `Some(value)` is encoded as `value`; `None` is encoded as JSON `null`.
impl<T: ValueEncoder> ValueEncoder for Option<T> {
    #[inline]
    fn encode(self, v: ValueSer<'_>) {
        if let Some(inner) = self {
            inner.encode(v);
        } else {
            Null.encode(v);
        }
    }
}
/// String keys are written as quoted, escaped JSON strings.
impl KeyEncoder for &str {
    #[inline]
    fn write_key<'a>(self, obj: &'a mut ObjectSer) -> ValueSer<'a> {
        obj.entry_inner(|buf| format_escaped_str(buf, self))
    }
}
/// Format arguments are written as a quoted, escaped JSON string key of their formatted output.
impl KeyEncoder for fmt::Arguments<'_> {
    #[inline]
    fn write_key<'a>(self, obj: &'a mut ObjectSer) -> ValueSer<'a> {
        match self.as_str() {
            // `as_str` returned the text directly, so escape it as a plain string.
            Some(literal) => obj.entry_inner(|buf| format_escaped_str(buf, literal)),
            None => obj.entry_inner(|buf| format_escaped_fmt(buf, self)),
        }
    }
}
/// Represents the JSON null value.
pub struct Null;
impl ValueEncoder for Null {
/// Writes the raw literal `null`.
#[inline]
fn encode(self, v: ValueSer<'_>) {
v.write_raw_json(b"null");
}
}
/// A `Vec` is encoded as a JSON list of its elements, in order.
impl<T: ValueEncoder> ValueEncoder for Vec<T> {
    #[inline]
    fn encode(self, v: ValueSer<'_>) {
        value_as_list!(|v| {
            self.into_iter().for_each(|item| v.push(item));
        });
    }
}
/// A slice of `Copy` values is encoded as a JSON list of its elements, in order.
impl<T: Copy + ValueEncoder> ValueEncoder for &[T] {
    #[inline]
    fn encode(self, v: ValueSer<'_>) {
        value_as_list!(|v| {
            self.iter().for_each(|item| v.push(item));
        });
    }
}
/// A `HashMap` is encoded as a JSON object, with entries in the map's iteration order.
impl<K: KeyEncoder, V: ValueEncoder, S> ValueEncoder for HashMap<K, V, S> {
    #[inline]
    fn encode(self, o: ValueSer<'_>) {
        value_as_object!(|o| {
            self.into_iter().for_each(|(key, val)| o.entry(key, val));
        });
    }
}
/// A `BTreeMap` is encoded as a JSON object, with entries in the map's iteration order.
impl<K: KeyEncoder, V: ValueEncoder> ValueEncoder for BTreeMap<K, V> {
    #[inline]
    fn encode(self, o: ValueSer<'_>) {
        value_as_object!(|o| {
            self.into_iter().for_each(|(key, val)| o.entry(key, val));
        });
    }
}

View File

@@ -52,7 +52,7 @@ pub(crate) async fn hi(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
}
// yield every ~250us
// hopefully reduces tail latencies
if i.is_multiple_of(1024) {
if i % 1024 == 0 {
yield_now().await
}
}

View File

@@ -90,7 +90,7 @@ pub struct InnerClient {
}
impl InnerClient {
pub fn start(&mut self) -> Result<PartialQuery<'_>, Error> {
pub fn start(&mut self) -> Result<PartialQuery, Error> {
self.responses.waiting += 1;
Ok(PartialQuery(Some(self)))
}
@@ -227,7 +227,7 @@ impl Client {
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
) -> Result<RowStream, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
@@ -262,7 +262,7 @@ impl Client {
pub(crate) async fn simple_query_raw(
&mut self,
query: &str,
) -> Result<SimpleQueryStream<'_>, Error> {
) -> Result<SimpleQueryStream, Error> {
simple_query::simple_query(self.inner_mut(), query).await
}

View File

@@ -12,11 +12,7 @@ mod private {
/// This trait is "sealed", and cannot be implemented outside of this crate.
pub trait GenericClient: private::Sealed {
/// Like `Client::query_raw_txt`.
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
@@ -26,11 +22,7 @@ pub trait GenericClient: private::Sealed {
impl private::Sealed for Client {}
impl GenericClient for Client {
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
@@ -43,11 +35,7 @@ impl GenericClient for Client {
impl private::Sealed for Transaction<'_> {}
impl GenericClient for Transaction<'_> {
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,

View File

@@ -47,7 +47,7 @@ impl<'a> Transaction<'a> {
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
) -> Result<RowStream, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,

View File

@@ -221,7 +221,7 @@ pub struct TimelineMembershipSwitchRequest {
pub struct TimelineMembershipSwitchResponse {
pub previous_conf: Configuration,
pub current_conf: Configuration,
pub last_log_term: Term,
pub term: Term,
pub flush_lsn: Lsn,
}

View File

@@ -24,28 +24,12 @@ macro_rules! critical {
if cfg!(debug_assertions) {
panic!($($arg)*);
}
// Increment both metrics
$crate::logging::TRACING_EVENT_COUNT_METRIC.inc_critical();
let backtrace = std::backtrace::Backtrace::capture();
tracing::error!("CRITICAL: {}\n{backtrace}", format!($($arg)*));
}};
}
#[macro_export]
macro_rules! critical_timeline {
($tenant_shard_id:expr, $timeline_id:expr, $($arg:tt)*) => {{
if cfg!(debug_assertions) {
panic!($($arg)*);
}
// Increment both metrics
$crate::logging::TRACING_EVENT_COUNT_METRIC.inc_critical();
$crate::logging::HADRON_CRITICAL_STORAGE_EVENT_COUNT_METRIC.inc(&$tenant_shard_id.to_string(), &$timeline_id.to_string());
let backtrace = std::backtrace::Backtrace::capture();
tracing::error!("CRITICAL: [tenant_shard_id: {}, timeline_id: {}] {}\n{backtrace}",
$tenant_shard_id, $timeline_id, format!($($arg)*));
}};
}
#[derive(EnumString, strum_macros::Display, VariantNames, Eq, PartialEq, Debug, Clone, Copy)]
#[strum(serialize_all = "snake_case")]
pub enum LogFormat {
@@ -77,36 +61,6 @@ pub struct TracingEventCountMetric {
trace: IntCounter,
}
// Begin Hadron: Add a HadronCriticalStorageEventCountMetric metric that is sliced by tenant_id and timeline_id
pub struct HadronCriticalStorageEventCountMetric {
critical: IntCounterVec,
}
pub static HADRON_CRITICAL_STORAGE_EVENT_COUNT_METRIC: Lazy<HadronCriticalStorageEventCountMetric> =
Lazy::new(|| {
let vec = metrics::register_int_counter_vec!(
"hadron_critical_storage_event_count",
"Number of critical storage events, by tenant_id and timeline_id",
&["tenant_shard_id", "timeline_id"]
)
.expect("failed to define metric");
HadronCriticalStorageEventCountMetric::new(vec)
});
impl HadronCriticalStorageEventCountMetric {
fn new(vec: IntCounterVec) -> Self {
Self { critical: vec }
}
// Allow public access from `critical!` macro.
pub fn inc(&self, tenant_shard_id: &str, timeline_id: &str) {
self.critical
.with_label_values(&[tenant_shard_id, timeline_id])
.inc();
}
}
// End Hadron
pub static TRACING_EVENT_COUNT_METRIC: Lazy<TracingEventCountMetric> = Lazy::new(|| {
let vec = metrics::register_int_counter_vec!(
"libmetrics_tracing_event_count",

View File

@@ -28,7 +28,6 @@ use reqwest::Url;
use storage_broker::Uri;
use utils::id::{NodeId, TimelineId};
use utils::logging::{LogFormat, SecretString};
use utils::serde_percent::Percent;
use crate::tenant::storage_layer::inmemory_layer::IndexEntry;
use crate::tenant::{TENANTS_SEGMENT_NAME, TIMELINES_SEGMENT_NAME};
@@ -460,16 +459,7 @@ impl PageServerConf {
metric_collection_endpoint,
metric_collection_bucket,
synthetic_size_calculation_interval,
disk_usage_based_eviction: Some(disk_usage_based_eviction.unwrap_or(
DiskUsageEvictionTaskConfig {
max_usage_pct: Percent::new(80).unwrap(),
min_avail_bytes: 2_000_000_000,
period: Duration::from_secs(60),
#[cfg(feature = "testing")]
mock_statvfs: None,
eviction_order: Default::default(),
},
)),
disk_usage_based_eviction,
test_remote_failures,
ondemand_download_behavior_treat_error_as_warn,
background_task_maximum_delay,
@@ -707,8 +697,6 @@ impl ConfigurableSemaphore {
#[cfg(test)]
mod tests {
use std::time::Duration;
use camino::Utf8PathBuf;
use rstest::rstest;
use utils::id::NodeId;
@@ -810,20 +798,4 @@ mod tests {
PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir)
.expect("parse_and_validate");
}
#[test]
fn test_config_disk_usage_based_eviction_is_valid() {
let input = r#"
control_plane_api = "http://localhost:6666"
"#;
let config_toml = toml_edit::de::from_str::<pageserver_api::config::ConfigToml>(input)
.expect("disk_usage_based_eviction is valid");
let workdir = Utf8PathBuf::from("/nonexistent");
let config = PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir).unwrap();
let disk_usage_based_eviction = config.disk_usage_based_eviction.unwrap();
assert_eq!(disk_usage_based_eviction.max_usage_pct.get(), 80);
assert_eq!(disk_usage_based_eviction.min_avail_bytes, 2_000_000_000);
assert_eq!(disk_usage_based_eviction.period, Duration::from_secs(60));
assert_eq!(disk_usage_based_eviction.eviction_order, Default::default());
}
}

View File

@@ -50,7 +50,6 @@ use tokio::io::{AsyncRead, AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _, Bu
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use tonic::service::Interceptor as _;
use tonic::transport::server::TcpConnectInfo;
use tracing::*;
use utils::auth::{Claims, Scope, SwappableJwtAuth};
use utils::id::{TenantId, TenantTimelineId, TimelineId};
@@ -3686,15 +3685,8 @@ impl proto::PageService for GrpcPageServiceHandler {
yield match result {
Ok(resp) => resp,
// Convert per-request errors to GetPageResponses as appropriate, or terminate
// the stream with a tonic::Status. Log the error regardless, since
// ObservabilityLayer can't automatically log stream errors.
Err(status) => {
// TODO: it would be nice if we could propagate the get_page() fields here.
span.in_scope(|| {
warn!("request failed with {:?}: {}", status.code(), status.message());
});
page_api::GetPageResponse::try_from_status(status, req_id)?.into()
}
// the stream with a tonic::Status.
Err(err) => page_api::GetPageResponse::try_from_status(err, req_id)?.into(),
}
}
};
@@ -3832,85 +3824,40 @@ impl<S: tonic::server::NamedService> tonic::server::NamedService for Observabili
const NAME: &'static str = S::NAME; // propagate inner service name
}
impl<S, Req, Resp> tower::Service<http::Request<Req>> for ObservabilityLayerService<S>
impl<S, B> tower::Service<http::Request<B>> for ObservabilityLayerService<S>
where
S: tower::Service<http::Request<Req>, Response = http::Response<Resp>> + Send,
S: tower::Service<http::Request<B>>,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn call(&mut self, mut req: http::Request<Req>) -> Self::Future {
fn call(&mut self, mut req: http::Request<B>) -> Self::Future {
// Record the request start time as a request extension.
//
// TODO: we should start a timer here instead, but it currently requires a timeline handle
// and SmgrQueryType, which we don't have yet. Refactor it to provide it later.
req.extensions_mut().insert(ReceivedAt(Instant::now()));
// Extract the peer address and gRPC method.
let peer = req
.extensions()
.get::<TcpConnectInfo>()
.and_then(|info| info.remote_addr())
.map(|addr| addr.to_string())
.unwrap_or_default();
let method = req
.uri()
.path()
.split('/')
.nth(2)
.unwrap_or(req.uri().path())
.to_string();
// Create a basic tracing span.
// Create a basic tracing span. Enter the span for the current thread (to use it for inner
// sync code like interceptors), and instrument the future (to use it for inner async code
// like the page service itself).
//
// Enter the span for the current thread and instrument the future. It is not sufficient to
// only instrument the future, since it only takes effect after the future is returned and
// polled, not when the inner service is called below (e.g. during interceptor execution).
// The instrument() call below is not sufficient. It only affects the returned future, and
// only takes effect when the caller polls it. Any sync code executed when we call
// self.inner.call() below (such as interceptors) runs outside of the returned future, and
// is not affected by it. We therefore have to enter the span on the current thread too.
let span = info_span!(
"grpc:pageservice",
// These will be populated by TenantMetadataInterceptor.
// Set by TenantMetadataInterceptor.
tenant_id = field::Empty,
timeline_id = field::Empty,
shard_id = field::Empty,
// NB: empty fields must be listed first above. Otherwise, the field names will be
// clobbered when the empty fields are populated. They will be output last regardless.
%peer,
%method,
);
let _guard = span.enter();
// Construct a future for calling the inner service, but don't await it. This avoids having
// to clone the inner service into the future below.
let call = self.inner.call(req);
async move {
// Await the inner service call.
let result = call.await;
// Log gRPC error statuses. This won't include request info from handler spans, but it
// will catch all errors (even those emitted before handler spans are constructed). Only
// unary request errors are logged here, not streaming response errors.
if let Ok(ref resp) = result
&& let Some(status) = tonic::Status::from_header_map(resp.headers())
&& status.code() != tonic::Code::Ok
{
// TODO: it would be nice if we could propagate the handler span's request fields
// here. This could e.g. be done by attaching the request fields to
// tonic::Status::metadata via a proc macro.
warn!(
"request failed with {:?}: {}",
status.code(),
status.message()
);
}
result
}
.instrument(span.clone())
.boxed()
Box::pin(self.inner.call(req).instrument(span.clone()))
}
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {

View File

@@ -78,7 +78,7 @@ use utils::rate_limit::RateLimit;
use utils::seqwait::SeqWait;
use utils::simple_rcu::{Rcu, RcuReadGuard};
use utils::sync::gate::{Gate, GateGuard};
use utils::{completion, critical_timeline, fs_ext, pausable_failpoint};
use utils::{completion, critical, fs_ext, pausable_failpoint};
#[cfg(test)]
use wal_decoder::models::value::Value;
use wal_decoder::serialized_batch::{SerializedValueBatch, ValueMeta};
@@ -2144,31 +2144,14 @@ impl Timeline {
debug_assert_current_span_has_tenant_and_timeline_id();
// Regardless of whether we're going to try_freeze_and_flush
// cancel walreceiver to stop ingesting more data asap.
//
// Note that we're accepting a race condition here where we may
// do the final flush below, before walreceiver observes the
// cancellation and exits.
// This means we may open a new InMemoryLayer after the final flush below.
// Flush loop is also still running for a short while, so, in theory, it
// could also make its way into the upload queue.
//
// If we wait for the shutdown of the walreceiver before moving on to the
// flush, then that would be avoided. But we don't do it because the
// walreceiver entertains reads internally, which means that it possibly
// depends on the download of layers. Layer download is only sensitive to
// the cancellation of the entire timeline, so cancelling the walreceiver
// will have no effect on the individual get requests.
// This would cause problems when there is a lot of ongoing downloads or
// there is S3 unavailabilities, i.e. detach, deletion, etc would hang,
// and we can't deallocate resources of the timeline, etc.
// or not, stop ingesting any more data.
let walreceiver = self.walreceiver.lock().unwrap().take();
tracing::debug!(
is_some = walreceiver.is_some(),
"Waiting for WalReceiverManager..."
);
if let Some(walreceiver) = walreceiver {
walreceiver.cancel().await;
walreceiver.shutdown().await;
}
// ... and inform any waiters for newer LSNs that there won't be any.
self.last_record_lsn.shutdown();
@@ -4746,7 +4729,7 @@ impl Timeline {
}
// Fetch the next layer to flush, if any.
let (layer, l0_count, frozen_count, frozen_size, open_layer_size) = {
let (layer, l0_count, frozen_count, frozen_size) = {
let layers = self.layers.read(LayerManagerLockHolder::FlushLoop).await;
let Ok(lm) = layers.layer_map() else {
info!("dropping out of flush loop for timeline shutdown");
@@ -4759,13 +4742,8 @@ impl Timeline {
.iter()
.map(|l| l.estimated_in_mem_size())
.sum();
let open_layer_size: u64 = lm
.open_layer
.as_ref()
.map(|l| l.estimated_in_mem_size())
.unwrap_or(0);
let layer = lm.frozen_layers.front().cloned();
(layer, l0_count, frozen_count, frozen_size, open_layer_size)
(layer, l0_count, frozen_count, frozen_size)
// drop 'layers' lock
};
let Some(layer) = layer else {
@@ -4778,7 +4756,7 @@ impl Timeline {
if l0_count >= stall_threshold {
warn!(
"stalling layer flushes for compaction backpressure at {l0_count} \
L0 layers ({frozen_count} frozen layers with {frozen_size} bytes, {open_layer_size} bytes in open layer)"
L0 layers ({frozen_count} frozen layers with {frozen_size} bytes)"
);
let stall_timer = self
.metrics
@@ -4831,7 +4809,7 @@ impl Timeline {
let delay = flush_duration.as_secs_f64();
info!(
"delaying layer flush by {delay:.3}s for compaction backpressure at \
{l0_count} L0 layers ({frozen_count} frozen layers with {frozen_size} bytes, {open_layer_size} bytes in open layer)"
{l0_count} L0 layers ({frozen_count} frozen layers with {frozen_size} bytes)"
);
let _delay_timer = self
.metrics
@@ -6841,11 +6819,7 @@ impl Timeline {
Err(walredo::Error::Cancelled) => return Err(PageReconstructError::Cancelled),
Err(walredo::Error::Other(err)) => {
if fire_critical_error {
critical_timeline!(
self.tenant_shard_id,
self.timeline_id,
"walredo failure during page reconstruction: {err:?}"
);
critical!("walredo failure during page reconstruction: {err:?}");
}
return Err(PageReconstructError::WalRedo(
err.context("reconstruct a page image"),

View File

@@ -36,7 +36,7 @@ use serde::Serialize;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, info_span, trace, warn};
use utils::critical_timeline;
use utils::critical;
use utils::id::TimelineId;
use utils::lsn::Lsn;
use wal_decoder::models::record::NeonWalRecord;
@@ -1390,11 +1390,7 @@ impl Timeline {
GetVectoredError::MissingKey(_),
) = err
{
critical_timeline!(
self.tenant_shard_id,
self.timeline_id,
"missing key during compaction: {err:?}"
);
critical!("missing key during compaction: {err:?}");
}
})?;
@@ -1422,11 +1418,7 @@ impl Timeline {
// Alert on critical errors that indicate data corruption.
Err(err) if err.is_critical() => {
critical_timeline!(
self.tenant_shard_id,
self.timeline_id,
"could not compact, repartitioning keyspace failed: {err:?}"
);
critical!("could not compact, repartitioning keyspace failed: {err:?}");
}
// Log other errors. No partitioning? This is normal, if the timeline was just created

View File

@@ -182,7 +182,6 @@ pub(crate) async fn generate_tombstone_image_layer(
detached: &Arc<Timeline>,
ancestor: &Arc<Timeline>,
ancestor_lsn: Lsn,
historic_layers_to_copy: &Vec<Layer>,
ctx: &RequestContext,
) -> Result<Option<ResidentLayer>, Error> {
tracing::info!(
@@ -200,20 +199,6 @@ pub(crate) async fn generate_tombstone_image_layer(
let image_lsn = ancestor_lsn;
{
for layer in historic_layers_to_copy {
let desc = layer.layer_desc();
if !desc.is_delta
&& desc.lsn_range.start == image_lsn
&& overlaps_with(&key_range, &desc.key_range)
{
tracing::info!(
layer=%layer, "will copy tombstone from ancestor instead of creating a new one"
);
return Ok(None);
}
}
let layers = detached
.layers
.read(LayerManagerLockHolder::DetachAncestor)
@@ -465,8 +450,7 @@ pub(super) async fn prepare(
Vec::with_capacity(straddling_branchpoint.len() + rest_of_historic.len() + 1);
if let Some(tombstone_layer) =
generate_tombstone_image_layer(detached, &ancestor, ancestor_lsn, &rest_of_historic, ctx)
.await?
generate_tombstone_image_layer(detached, &ancestor, ancestor_lsn, ctx).await?
{
new_layers.push(tombstone_layer.into());
}

View File

@@ -63,6 +63,7 @@ pub struct WalReceiver {
/// All task spawned by [`WalReceiver::start`] and its children are sensitive to this token.
/// It's a child token of [`Timeline`] so that timeline shutdown can cancel WalReceiver tasks early for `freeze_and_flush=true`.
cancel: CancellationToken,
task: tokio::task::JoinHandle<()>,
}
impl WalReceiver {
@@ -79,7 +80,7 @@ impl WalReceiver {
let loop_status = Arc::new(std::sync::RwLock::new(None));
let manager_status = Arc::clone(&loop_status);
let cancel = timeline.cancel.child_token();
let _task = WALRECEIVER_RUNTIME.spawn({
let task = WALRECEIVER_RUNTIME.spawn({
let cancel = cancel.clone();
async move {
debug_assert_current_span_has_tenant_and_timeline_id();
@@ -120,14 +121,25 @@ impl WalReceiver {
Self {
manager_status,
cancel,
task,
}
}
#[instrument(skip_all, level = tracing::Level::DEBUG)]
pub async fn cancel(self) {
pub async fn shutdown(self) {
debug_assert_current_span_has_tenant_and_timeline_id();
debug!("cancelling walreceiver tasks");
self.cancel.cancel();
match self.task.await {
Ok(()) => debug!("Shutdown success"),
Err(je) if je.is_cancelled() => unreachable!("not used"),
Err(je) if je.is_panic() => {
// already logged by panic hook
}
Err(je) => {
error!("shutdown walreceiver task join error: {je}")
}
}
}
pub(crate) fn status(&self) -> Option<ConnectionManagerStatus> {

View File

@@ -100,7 +100,6 @@ pub(super) async fn connection_manager_loop_step(
// with other streams on this client (other connection managers). When
// object goes out of scope, stream finishes in drop() automatically.
let mut broker_subscription = subscribe_for_timeline_updates(broker_client, id, cancel).await?;
let mut broker_reset_interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
debug!("Subscribed for broker timeline updates");
loop {
@@ -157,10 +156,7 @@ pub(super) async fn connection_manager_loop_step(
// Got a new update from the broker
broker_update = broker_subscription.message() /* TODO: review cancellation-safety */ => {
match broker_update {
Ok(Some(broker_update)) => {
broker_reset_interval.reset();
connection_manager_state.register_timeline_update(broker_update);
},
Ok(Some(broker_update)) => connection_manager_state.register_timeline_update(broker_update),
Err(status) => {
match status.code() {
Code::Unknown if status.message().contains("stream closed because of a broken pipe") || status.message().contains("connection reset") || status.message().contains("error reading a body from connection") => {
@@ -182,14 +178,6 @@ pub(super) async fn connection_manager_loop_step(
}
},
_ = broker_reset_interval.tick() => {
if wait_lsn_status.borrow().is_some() {
tracing::warn!("No broker updates received for a while, but waiting for WAL. Re-setting stream ...")
}
broker_subscription = subscribe_for_timeline_updates(broker_client, id, cancel).await?;
},
new_event = async {
// Reminder: this match arm needs to be cancellation-safe.
loop {

View File

@@ -25,7 +25,7 @@ use tokio_postgres::replication::ReplicationStream;
use tokio_postgres::{Client, SimpleQueryMessage, SimpleQueryRow};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, trace, warn};
use utils::critical_timeline;
use utils::critical;
use utils::id::NodeId;
use utils::lsn::Lsn;
use utils::pageserver_feedback::PageserverFeedback;
@@ -275,12 +275,20 @@ pub(super) async fn handle_walreceiver_connection(
let copy_stream = replication_client.copy_both_simple(&query).await?;
let mut physical_stream = pin!(ReplicationStream::new(copy_stream));
let mut walingest = WalIngest::new(timeline.as_ref(), startpoint, &ctx)
.await
.map_err(|e| match e.kind {
crate::walingest::WalIngestErrorKind::Cancelled => WalReceiverError::Cancelled,
_ => WalReceiverError::Other(e.into()),
})?;
let walingest_future = WalIngest::new(timeline.as_ref(), startpoint, &ctx);
let walingest_res = select! {
walingest_res = walingest_future => walingest_res,
_ = cancellation.cancelled() => {
// We are doing reads in WalIngest::new, and those can hang as they come from the network.
// Timeline cancellation hits the walreceiver cancellation token before it hits the timeline global one.
debug!("Connection cancelled");
return Err(WalReceiverError::Cancelled);
},
};
let mut walingest = walingest_res.map_err(|e| match e.kind {
crate::walingest::WalIngestErrorKind::Cancelled => WalReceiverError::Cancelled,
_ => WalReceiverError::Other(e.into()),
})?;
let (format, compression) = match protocol {
PostgresClientProtocol::Interpreted {
@@ -360,13 +368,9 @@ pub(super) async fn handle_walreceiver_connection(
match raw_wal_start_lsn.cmp(&expected_wal_start) {
std::cmp::Ordering::Greater => {
let msg = format!(
"Gap in streamed WAL: [{expected_wal_start}, {raw_wal_start_lsn}"
);
critical_timeline!(
timeline.tenant_shard_id,
timeline.timeline_id,
"{msg}"
"Gap in streamed WAL: [{expected_wal_start}, {raw_wal_start_lsn})"
);
critical!("{msg}");
return Err(WalReceiverError::Other(anyhow!(msg)));
}
std::cmp::Ordering::Less => {
@@ -379,11 +383,7 @@ pub(super) async fn handle_walreceiver_connection(
"Received record with next_record_lsn multiple times ({} < {})",
first_rec.next_record_lsn, expected_wal_start
);
critical_timeline!(
timeline.tenant_shard_id,
timeline.timeline_id,
"{msg}"
);
critical!("{msg}");
return Err(WalReceiverError::Other(anyhow!(msg)));
}
}
@@ -452,11 +452,7 @@ pub(super) async fn handle_walreceiver_connection(
// TODO: we can't differentiate cancellation errors with
// anyhow::Error, so just ignore it if we're cancelled.
if !cancellation.is_cancelled() && !timeline.is_stopping() {
critical_timeline!(
timeline.tenant_shard_id,
timeline.timeline_id,
"{err:?}"
);
critical!("{err:?}")
}
})?;

View File

@@ -40,7 +40,7 @@ use tracing::*;
use utils::bin_ser::{DeserializeError, SerializeError};
use utils::lsn::Lsn;
use utils::rate_limit::RateLimit;
use utils::{critical_timeline, failpoint_support};
use utils::{critical, failpoint_support};
use wal_decoder::models::record::NeonWalRecord;
use wal_decoder::models::*;
@@ -418,30 +418,18 @@ impl WalIngest {
// as there has historically been cases where PostgreSQL has cleared spurious VM pages. See:
// https://github.com/neondatabase/neon/pull/10634.
let Some(vm_size) = get_relsize(modification, vm_rel, ctx).await? else {
critical_timeline!(
modification.tline.tenant_shard_id,
modification.tline.timeline_id,
"clear_vm_bits for unknown VM relation {vm_rel}"
);
critical!("clear_vm_bits for unknown VM relation {vm_rel}");
return Ok(());
};
if let Some(blknum) = new_vm_blk {
if blknum >= vm_size {
critical_timeline!(
modification.tline.tenant_shard_id,
modification.tline.timeline_id,
"new_vm_blk {blknum} not in {vm_rel} of size {vm_size}"
);
critical!("new_vm_blk {blknum} not in {vm_rel} of size {vm_size}");
new_vm_blk = None;
}
}
if let Some(blknum) = old_vm_blk {
if blknum >= vm_size {
critical_timeline!(
modification.tline.tenant_shard_id,
modification.tline.timeline_id,
"old_vm_blk {blknum} not in {vm_rel} of size {vm_size}"
);
critical!("old_vm_blk {blknum} not in {vm_rel} of size {vm_size}");
old_vm_blk = None;
}
}

View File

@@ -22,8 +22,7 @@ OBJS = \
walproposer.o \
walproposer_pg.o \
neon_ddl_handler.o \
walsender_hooks.o \
$(NEON_CARGO_ARTIFACT_TARGET_DIR)/libcommunicator.a
walsender_hooks.o
PG_CPPFLAGS = -I$(libpq_srcdir)
SHLIB_LINK_INTERNAL = $(libpq)
@@ -55,17 +54,6 @@ WALPROP_OBJS = \
neon_utils.o \
walproposer_compat.o
# libcommunicator.a is built by cargo from the Rust sources under communicator/
# subdirectory. `cargo build` also generates communicator_bindings.h.
neon.o: communicator/communicator_bindings.h
$(NEON_CARGO_ARTIFACT_TARGET_DIR)/libcommunicator.a communicator/communicator_bindings.h &:
(cd $(srcdir)/communicator && cargo build $(CARGO_BUILD_FLAGS) $(CARGO_PROFILE))
# Force `cargo build` every time. Some of the Rust sources might have
# changed.
.PHONY: $(NEON_CARGO_ARTIFACT_TARGET_DIR)/libcommunicator.a communicator/communicator_bindings.h
.PHONY: walproposer-lib
walproposer-lib: CPPFLAGS += -DWALPROPOSER_LIB
walproposer-lib: libwalproposer.a;

View File

@@ -1,2 +0,0 @@
# generated file (with cbindgen, see build.rs)
communicator_bindings.h

View File

@@ -1,20 +0,0 @@
[package]
name = "communicator"
version = "0.1.0"
license.workspace = true
edition.workspace = true
[lib]
crate-type = ["staticlib"]
[features]
# 'testing' feature is currently unused in the communicator, but we accept it for convenience of
# calling build scripts, so that you can pass the same feature to all packages.
testing = []
[dependencies]
neon-shmem.workspace = true
workspace_hack = { version = "0.1", path = "../../../workspace_hack" }
[build-dependencies]
cbindgen.workspace = true

View File

@@ -1,8 +0,0 @@
This package will evolve into a "compute-pageserver communicator"
process and machinery. For now, it's just a dummy that doesn't do
anything interesting, but it allows us to test the compilation and
linking of Rust code into the Postgres extensions.
At compilation time, pgxn/neon/communicator/ produces a static
library, libcommunicator.a. It is linked to the neon.so extension
library.

View File

@@ -1,20 +0,0 @@
use std::env;
/// Build script entry point: runs cbindgen over this crate and writes the
/// resulting C header to `communicator_bindings.h`.
///
/// # Errors
/// Returns an error if the `CARGO_MANIFEST_DIR` environment variable cannot be
/// read.
///
/// # Panics
/// Panics if cbindgen fails for any reason other than a Rust syntax error.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Propagate a missing/invalid variable through the declared `Result`
    // instead of unwrapping: the function already returns a boxed error.
    let crate_dir = env::var("CARGO_MANIFEST_DIR")?;

    match cbindgen::generate(crate_dir) {
        Ok(bindings) => {
            bindings.write_to_file("communicator_bindings.h");
        }
        Err(cbindgen::Error::ParseSyntaxError { .. }) => {
            // This means there was a syntax error in the Rust sources. Don't panic, because
            // we want the build to continue and the Rust compiler to hit the error. The
            // Rust compiler produces a better error message than cbindgen.
            eprintln!("Generating C bindings failed because of a Rust syntax error");
        }
        Err(err) => panic!("Unable to generate C bindings: {err:?}"),
    };

    Ok(())
}

View File

@@ -1,4 +0,0 @@
language = "C"
[enum]
prefix_with_name = true

View File

@@ -1,6 +0,0 @@
/// Dummy function, just to test linking Rust functions into the C
/// extension.
///
/// Returns `arg + 1`, wrapping around to `0` when `arg` is `u32::MAX`.
/// Wrapping arithmetic is used on purpose: a plain `+` would panic on
/// overflow in builds with overflow checks enabled, and this function is
/// called from C, so the result must not depend on the build profile.
#[unsafe(no_mangle)]
pub extern "C" fn communicator_dummy(arg: u32) -> u32 {
    arg.wrapping_add(1)
}

View File

@@ -43,9 +43,6 @@
#include "storage/ipc.h"
#endif
/* the rust bindings, generated by cbindgen */
#include "communicator/communicator_bindings.h"
PG_MODULE_MAGIC;
void _PG_init(void);
@@ -90,14 +87,6 @@ static const struct config_enum_entry running_xacts_overflow_policies[] = {
{NULL, 0, false}
};
static const struct config_enum_entry debug_compare_local_modes[] = {
{"none", DEBUG_COMPARE_LOCAL_NONE, false},
{"prefetch", DEBUG_COMPARE_LOCAL_PREFETCH, false},
{"lfc", DEBUG_COMPARE_LOCAL_LFC, false},
{"all", DEBUG_COMPARE_LOCAL_ALL, false},
{NULL, 0, false}
};
/*
* XXX: These private to procarray.c, but we need them here.
*/
@@ -455,9 +444,6 @@ _PG_init(void)
shmem_startup_hook = neon_shmem_startup_hook;
#endif
/* dummy call to a Rust function in the communicator library, to check that it works */
(void) communicator_dummy(123);
pg_init_libpagestore();
lfc_init();
pg_init_walproposer();
@@ -533,16 +519,6 @@ _PG_init(void)
GUC_UNIT_KB,
NULL, NULL, NULL);
DefineCustomEnumVariable(
"neon.debug_compare_local",
"Debug mode for compaing content of pages in prefetch ring/LFC/PS and local disk",
NULL,
&debug_compare_local,
DEBUG_COMPARE_LOCAL_NONE,
debug_compare_local_modes,
PGC_POSTMASTER,
0,
NULL, NULL, NULL);
/*
* Important: This must happen after other parts of the extension are
* loaded, otherwise any settings to GUCs that were set before the

View File

@@ -98,14 +98,12 @@ typedef struct
typedef struct DdlHashTable
{
struct DdlHashTable *prev_table;
size_t subtrans_level;
HTAB *db_table;
HTAB *role_table;
} DdlHashTable;
static DdlHashTable RootTable;
static DdlHashTable *CurrentDdlTable = &RootTable;
static int SubtransLevel; /* current nesting level of subtransactions */
static void
PushKeyValue(JsonbParseState **state, char *key, char *value)
@@ -334,25 +332,9 @@ SendDeltasToControlPlane()
}
}
/*
 * Make sure CurrentDdlTable corresponds to the current subtransaction
 * nesting level.
 *
 * If SubtransLevel is deeper than the level recorded in CurrentDdlTable, a
 * new DdlHashTable is allocated in CurTransactionContext, linked to the
 * previous table through prev_table, and made current. Its db_table and
 * role_table start out as NULL. Otherwise this function does nothing.
 */
static void
InitCurrentDdlTableIfNeeded()
{
/* Lazy construction of DdlHashTable chain */
if (SubtransLevel > CurrentDdlTable->subtrans_level)
{
DdlHashTable *new_table = MemoryContextAlloc(CurTransactionContext, sizeof(DdlHashTable));
new_table->prev_table = CurrentDdlTable;
new_table->subtrans_level = SubtransLevel;
new_table->role_table = NULL;
new_table->db_table = NULL;
CurrentDdlTable = new_table;
}
}
static void
InitDbTableIfNeeded()
{
InitCurrentDdlTableIfNeeded();
if (!CurrentDdlTable->db_table)
{
HASHCTL db_ctl = {};
@@ -371,7 +353,6 @@ InitDbTableIfNeeded()
static void
InitRoleTableIfNeeded()
{
InitCurrentDdlTableIfNeeded();
if (!CurrentDdlTable->role_table)
{
HASHCTL role_ctl = {};
@@ -390,21 +371,19 @@ InitRoleTableIfNeeded()
static void
PushTable()
{
SubtransLevel += 1;
DdlHashTable *new_table = MemoryContextAlloc(CurTransactionContext, sizeof(DdlHashTable));
new_table->prev_table = CurrentDdlTable;
new_table->role_table = NULL;
new_table->db_table = NULL;
CurrentDdlTable = new_table;
}
static void
MergeTable()
{
DdlHashTable *old_table;
DdlHashTable *old_table = CurrentDdlTable;
Assert(SubtransLevel >= CurrentDdlTable->subtrans_level);
if (--SubtransLevel >= CurrentDdlTable->subtrans_level)
{
return;
}
old_table = CurrentDdlTable;
CurrentDdlTable = old_table->prev_table;
if (old_table->db_table)
@@ -497,15 +476,11 @@ MergeTable()
static void
PopTable()
{
Assert(SubtransLevel >= CurrentDdlTable->subtrans_level);
if (--SubtransLevel < CurrentDdlTable->subtrans_level)
{
/*
* Current table gets freed because it is allocated in aborted
* subtransaction's memory context.
*/
CurrentDdlTable = CurrentDdlTable->prev_table;
}
/*
* Current table gets freed because it is allocated in aborted
* subtransaction's memory context.
*/
CurrentDdlTable = CurrentDdlTable->prev_table;
}
static void

View File

@@ -177,22 +177,6 @@ extern StringInfoData nm_pack_request(NeonRequest *msg);
extern NeonResponse *nm_unpack_response(StringInfo s);
extern char *nm_to_string(NeonMessage *msg);
/*
* If debug_compare_local>DEBUG_COMPARE_LOCAL_NONE, we pass through all the SMGR API
* calls to md.c, and *also* do the calls to the Page Server. On every
* read, compare the versions we read from local disk and Page Server,
* and Assert that they are identical.
*/
/*
 * Modes for debug_compare_local. The declaration order matters: code that
 * reads debug_compare_local compares it against these values with >= and <=.
 */
typedef enum
{
DEBUG_COMPARE_LOCAL_NONE, /* normal mode - pages are stored locally only for unlogged relations */
DEBUG_COMPARE_LOCAL_PREFETCH, /* if page is found in prefetch ring, then compare it with local and return */
DEBUG_COMPARE_LOCAL_LFC, /* if page is found in LFC or prefetch ring, then compare it with local and return */
DEBUG_COMPARE_LOCAL_ALL /* always fetch page from PS and compare it with local */
} DebugCompareLocalMode;
extern int debug_compare_local;
/*
* API
*/

View File

@@ -76,11 +76,21 @@
typedef PGAlignedBlock PGIOAlignedBlock;
#endif
/*
* If DEBUG_COMPARE_LOCAL is defined, we pass through all the SMGR API
* calls to md.c, and *also* do the calls to the Page Server. On every
* read, compare the versions we read from local disk and Page Server,
* and Assert that they are identical.
*/
/* #define DEBUG_COMPARE_LOCAL */
#ifdef DEBUG_COMPARE_LOCAL
#include "access/nbtree.h"
#include "storage/bufpage.h"
#include "access/xlog_internal.h"
static char *hexdump_page(char *page);
#endif
#define IS_LOCAL_REL(reln) (\
NInfoGetDbOid(InfoFromSMgrRel(reln)) != 0 && \
@@ -98,8 +108,6 @@ typedef enum
UNLOGGED_BUILD_NOT_PERMANENT
} UnloggedBuildPhase;
int debug_compare_local;
static NRelFileInfo unlogged_build_rel_info;
static UnloggedBuildPhase unlogged_build_phase = UNLOGGED_BUILD_NOT_IN_PROGRESS;
@@ -470,10 +478,9 @@ neon_init(void)
old_redo_read_buffer_filter = redo_read_buffer_filter;
redo_read_buffer_filter = neon_redo_read_buffer_filter;
if (debug_compare_local)
{
mdinit();
}
#ifdef DEBUG_COMPARE_LOCAL
mdinit();
#endif
}
/*
@@ -796,16 +803,13 @@ neon_create(SMgrRelation reln, ForkNumber forkNum, bool isRedo)
case RELPERSISTENCE_TEMP:
case RELPERSISTENCE_UNLOGGED:
if (debug_compare_local)
{
mdcreate(reln, forkNum, forkNum == INIT_FORKNUM || isRedo);
if (forkNum == MAIN_FORKNUM)
mdcreate(reln, INIT_FORKNUM, true);
}
else
{
mdcreate(reln, forkNum, isRedo);
}
#ifdef DEBUG_COMPARE_LOCAL
mdcreate(reln, forkNum, forkNum == INIT_FORKNUM || isRedo);
if (forkNum == MAIN_FORKNUM)
mdcreate(reln, INIT_FORKNUM, true);
#else
mdcreate(reln, forkNum, isRedo);
#endif
return;
default:
@@ -844,11 +848,10 @@ neon_create(SMgrRelation reln, ForkNumber forkNum, bool isRedo)
else
set_cached_relsize(InfoFromSMgrRel(reln), forkNum, 0);
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdcreate(reln, forkNum, isRedo);
}
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdcreate(reln, forkNum, isRedo);
#endif
}
/*
@@ -874,7 +877,7 @@ neon_unlink(NRelFileInfoBackend rinfo, ForkNumber forkNum, bool isRedo)
{
/*
* Might or might not exist locally, depending on whether it's an unlogged
* or permanent relation (or if debug_compare_local is set). Try to
* or permanent relation (or if DEBUG_COMPARE_LOCAL is set). Try to
* unlink, it won't do any harm if the file doesn't exist.
*/
mdunlink(rinfo, forkNum, isRedo);
@@ -970,11 +973,10 @@ neon_extend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno,
lfc_write(InfoFromSMgrRel(reln), forkNum, blkno, buffer);
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdextend(reln, forkNum, blkno, buffer, skipFsync);
}
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdextend(reln, forkNum, blkno, buffer, skipFsync);
#endif
/*
* smgr_extend is often called with an all-zeroes page, so
@@ -1049,11 +1051,10 @@ neon_zeroextend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blocknum,
relpath(reln->smgr_rlocator, forkNum),
InvalidBlockNumber)));
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdzeroextend(reln, forkNum, blocknum, nblocks, skipFsync);
}
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdzeroextend(reln, forkNum, blocknum, nblocks, skipFsync);
#endif
/* Don't log any pages if we're not allowed to do so. */
if (!XLogInsertAllowed())
@@ -1264,11 +1265,10 @@ neon_writeback(SMgrRelation reln, ForkNumber forknum,
communicator_prefetch_pump_state();
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdwriteback(reln, forknum, blocknum, nblocks);
}
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdwriteback(reln, forknum, blocknum, nblocks);
#endif
}
/*
@@ -1282,6 +1282,7 @@ neon_read_at_lsn(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
communicator_read_at_lsnv(rinfo, forkNum, blkno, &request_lsns, &buffer, 1, NULL);
}
#ifdef DEBUG_COMPARE_LOCAL
static void
compare_with_local(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void* buffer, XLogRecPtr request_lsn)
{
@@ -1363,6 +1364,7 @@ compare_with_local(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, voi
}
}
}
#endif
#if PG_MAJORVERSION_NUM < 17
@@ -1415,28 +1417,22 @@ neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void *buffer
if (communicator_prefetch_lookupv(InfoFromSMgrRel(reln), forkNum, blkno, &request_lsns, 1, &bufferp, &present))
{
/* Prefetch hit */
if (debug_compare_local >= DEBUG_COMPARE_LOCAL_PREFETCH)
{
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
}
if (debug_compare_local <= DEBUG_COMPARE_LOCAL_PREFETCH)
{
return;
}
#ifdef DEBUG_COMPARE_LOCAL
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
#else
return;
#endif
}
/* Try to read from local file cache */
if (lfc_read(InfoFromSMgrRel(reln), forkNum, blkno, buffer))
{
MyNeonCounters->file_cache_hits_total++;
if (debug_compare_local >= DEBUG_COMPARE_LOCAL_LFC)
{
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
}
if (debug_compare_local <= DEBUG_COMPARE_LOCAL_LFC)
{
return;
}
#ifdef DEBUG_COMPARE_LOCAL
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
#else
return;
#endif
}
neon_read_at_lsn(InfoFromSMgrRel(reln), forkNum, blkno, request_lsns, buffer);
@@ -1446,15 +1442,15 @@ neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void *buffer
*/
communicator_prefetch_pump_state();
if (debug_compare_local)
{
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
}
#ifdef DEBUG_COMPARE_LOCAL
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
#endif
}
#endif /* PG_MAJORVERSION_NUM <= 16 */
#if PG_MAJORVERSION_NUM >= 17
#ifdef DEBUG_COMPARE_LOCAL
static void
compare_with_localv(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void** buffers, BlockNumber nblocks, neon_request_lsns* request_lsns, bits8* read_pages)
{
@@ -1469,6 +1465,7 @@ compare_with_localv(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, vo
}
}
}
#endif
static void
@@ -1519,19 +1516,13 @@ neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
blocknum, request_lsns, nblocks,
buffers, read_pages);
if (debug_compare_local >= DEBUG_COMPARE_LOCAL_PREFETCH)
{
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
}
if (debug_compare_local <= DEBUG_COMPARE_LOCAL_PREFETCH && prefetch_result == nblocks)
{
#ifdef DEBUG_COMPARE_LOCAL
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
memset(read_pages, 0, sizeof(read_pages));
#else
if (prefetch_result == nblocks)
return;
}
if (debug_compare_local > DEBUG_COMPARE_LOCAL_PREFETCH)
{
memset(read_pages, 0, sizeof(read_pages));
}
#endif
/* Try to read from local file cache */
lfc_result = lfc_readv_select(InfoFromSMgrRel(reln), forknum, blocknum, buffers,
@@ -1540,19 +1531,14 @@ neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
if (lfc_result > 0)
MyNeonCounters->file_cache_hits_total += lfc_result;
if (debug_compare_local >= DEBUG_COMPARE_LOCAL_LFC)
{
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
}
if (debug_compare_local <= DEBUG_COMPARE_LOCAL_LFC && prefetch_result + lfc_result == nblocks)
{
/* Read all blocks from LFC, so we're done */
#ifdef DEBUG_COMPARE_LOCAL
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
memset(read_pages, 0, sizeof(read_pages));
#else
/* Read all blocks from LFC, so we're done */
if (prefetch_result + lfc_result == nblocks)
return;
}
if (debug_compare_local > DEBUG_COMPARE_LOCAL_LFC)
{
memset(read_pages, 0, sizeof(read_pages));
}
#endif
communicator_read_at_lsnv(InfoFromSMgrRel(reln), forknum, blocknum, request_lsns,
buffers, nblocks, read_pages);
@@ -1562,14 +1548,14 @@ neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
*/
communicator_prefetch_pump_state();
if (debug_compare_local)
{
memset(read_pages, 0xFF, sizeof(read_pages));
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
}
#ifdef DEBUG_COMPARE_LOCAL
memset(read_pages, 0xFF, sizeof(read_pages));
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
#endif
}
#endif
#ifdef DEBUG_COMPARE_LOCAL
static char *
hexdump_page(char *page)
{
@@ -1588,6 +1574,7 @@ hexdump_page(char *page)
return result.data;
}
#endif
#if PG_MAJORVERSION_NUM < 17
/*
@@ -1609,8 +1596,12 @@ neon_write(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, const vo
switch (reln->smgr_relpersistence)
{
case 0:
#ifndef DEBUG_COMPARE_LOCAL
/* This is a bit tricky. Check if the relation exists locally */
if (mdexists(reln, debug_compare_local ? INIT_FORKNUM : forknum))
if (mdexists(reln, forknum))
#else
if (mdexists(reln, INIT_FORKNUM))
#endif
{
/* It exists locally. Guess it's unlogged then. */
#if PG_MAJORVERSION_NUM >= 17
@@ -1665,17 +1656,14 @@ neon_write(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, const vo
communicator_prefetch_pump_state();
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
{
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
#if PG_MAJORVERSION_NUM >= 17
mdwritev(reln, forknum, blocknum, &buffer, 1, skipFsync);
mdwritev(reln, forknum, blocknum, &buffer, 1, skipFsync);
#else
mdwrite(reln, forknum, blocknum, buffer, skipFsync);
mdwrite(reln, forknum, blocknum, buffer, skipFsync);
#endif
}
}
#endif
}
#endif
@@ -1689,8 +1677,12 @@ neon_writev(SMgrRelation reln, ForkNumber forknum, BlockNumber blkno,
switch (reln->smgr_relpersistence)
{
case 0:
#ifndef DEBUG_COMPARE_LOCAL
/* This is a bit tricky. Check if the relation exists locally */
if (mdexists(reln, debug_compare_local ? INIT_FORKNUM : forknum))
if (mdexists(reln, forknum))
#else
if (mdexists(reln, INIT_FORKNUM))
#endif
{
/* It exists locally. Guess it's unlogged then. */
mdwritev(reln, forknum, blkno, buffers, nblocks, skipFsync);
@@ -1728,11 +1720,10 @@ neon_writev(SMgrRelation reln, ForkNumber forknum, BlockNumber blkno,
communicator_prefetch_pump_state();
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdwritev(reln, forknum, blkno, buffers, nblocks, skipFsync);
}
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdwritev(reln, forknum, blkno, buffers, nblocks, skipFsync);
#endif
}
#endif
@@ -1871,11 +1862,10 @@ neon_truncate(SMgrRelation reln, ForkNumber forknum, BlockNumber old_blocks, Blo
*/
neon_set_lwlsn_relation(lsn, InfoFromSMgrRel(reln), forknum);
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdtruncate(reln, forknum, old_blocks, nblocks);
}
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdtruncate(reln, forknum, old_blocks, nblocks);
#endif
}
/*
@@ -1914,11 +1904,10 @@ neon_immedsync(SMgrRelation reln, ForkNumber forknum)
communicator_prefetch_pump_state();
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdimmedsync(reln, forknum);
}
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdimmedsync(reln, forknum);
#endif
}
#if PG_MAJORVERSION_NUM >= 17
@@ -1945,11 +1934,10 @@ neon_registersync(SMgrRelation reln, ForkNumber forknum)
neon_log(SmgrTrace, "[NEON_SMGR] registersync noop");
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdimmedsync(reln, forknum);
}
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdimmedsync(reln, forknum);
#endif
}
#endif
@@ -1990,11 +1978,10 @@ neon_start_unlogged_build(SMgrRelation reln)
case RELPERSISTENCE_UNLOGGED:
unlogged_build_rel_info = InfoFromSMgrRel(reln);
unlogged_build_phase = UNLOGGED_BUILD_NOT_PERMANENT;
if (debug_compare_local)
{
if (!IsParallelWorker())
mdcreate(reln, INIT_FORKNUM, true);
}
#ifdef DEBUG_COMPARE_LOCAL
if (!IsParallelWorker())
mdcreate(reln, INIT_FORKNUM, true);
#endif
return;
default:
@@ -2022,7 +2009,11 @@ neon_start_unlogged_build(SMgrRelation reln)
*/
if (!IsParallelWorker())
{
mdcreate(reln, debug_compare_local ? INIT_FORKNUM : MAIN_FORKNUM, false);
#ifndef DEBUG_COMPARE_LOCAL
mdcreate(reln, MAIN_FORKNUM, false);
#else
mdcreate(reln, INIT_FORKNUM, true);
#endif
}
}
@@ -2116,14 +2107,14 @@ neon_end_unlogged_build(SMgrRelation reln)
lfc_invalidate(InfoFromNInfoB(rinfob), forknum, nblocks);
mdclose(reln, forknum);
if (!debug_compare_local)
{
/* use isRedo == true, so that we drop it immediately */
mdunlink(rinfob, forknum, true);
}
#ifndef DEBUG_COMPARE_LOCAL
/* use isRedo == true, so that we drop it immediately */
mdunlink(rinfob, forknum, true);
#endif
}
if (debug_compare_local)
mdunlink(rinfob, INIT_FORKNUM, true);
#ifdef DEBUG_COMPARE_LOCAL
mdunlink(rinfob, INIT_FORKNUM, true);
#endif
}
NRelFileInfoInvalidate(unlogged_build_rel_info);
unlogged_build_phase = UNLOGGED_BUILD_NOT_IN_PROGRESS;

View File

@@ -138,62 +138,3 @@ Now from client you can start a new session:
```sh
PGSSLROOTCERT=./server.crt psql "postgresql://proxy:password@endpoint.local.neon.build:4432/postgres?sslmode=verify-full"
```
## auth broker setup:
Create a postgres instance:
```sh
docker run \
--detach \
--name proxy-postgres \
--env POSTGRES_HOST_AUTH_METHOD=trust \
--env POSTGRES_USER=authenticated \
--env POSTGRES_DB=database \
--publish 5432:5432 \
postgres:17-bookworm
```
Create a configuration file called `local_proxy.json` in the root of the repo (used also by the auth broker to validate JWTs)
```json
{
"jwks": [
{
"id": "1",
"role_names": ["authenticator", "authenticated", "anon"],
"jwks_url": "https://climbing-minnow-11.clerk.accounts.dev/.well-known/jwks.json",
"provider_name": "foo",
"jwt_audience": null
}
]
}
```
Start the local proxy:
```sh
cargo run --bin local_proxy -- \
--disable_pg_session_jwt true \
--http 0.0.0.0:7432
```
Start the auth broker:
```sh
LOGFMT=text OTEL_SDK_DISABLED=true cargo run --bin proxy --features testing -- \
-c server.crt -k server.key \
--is-auth-broker true \
--wss 0.0.0.0:8080 \
--http 0.0.0.0:7002 \
--auth-backend local
```
Create a JWT in your auth provider (e.g. Clerk) and set it in the `NEON_JWT` environment variable.
```sh
export NEON_JWT="..."
```
Run a query against the auth broker:
```sh
curl -k "https://foo.local.neon.build:8080/sql" \
-H "Authorization: Bearer $NEON_JWT" \
-H "neon-connection-string: postgresql://authenticator@foo.local.neon.build/database" \
-d '{"query":"select 1","params":[]}'
```

View File

@@ -164,20 +164,21 @@ async fn authenticate(
})?
.map_err(ConsoleRedirectError::from)?;
if auth_config.ip_allowlist_check_enabled
&& let Some(allowed_ips) = &db_info.allowed_ips
&& !auth::check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips)
{
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
if auth_config.ip_allowlist_check_enabled {
if let Some(allowed_ips) = &db_info.allowed_ips {
if !auth::check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
}
}
// Check if the access over the public internet is allowed, otherwise block. Note that
// the console redirect is not behind the VPC service endpoint, so we don't need to check
// the VPC endpoint ID.
if let Some(public_access_allowed) = db_info.public_access_allowed
&& !public_access_allowed
{
return Err(auth::AuthError::NetworkNotAllowed);
if let Some(public_access_allowed) = db_info.public_access_allowed {
if !public_access_allowed {
return Err(auth::AuthError::NetworkNotAllowed);
}
}
client.write_message(BeMessage::NoticeResponse("Connecting to database."));

View File

@@ -399,36 +399,36 @@ impl JwkCacheEntryLock {
tracing::debug!(?payload, "JWT signature valid with claims");
if let Some(aud) = expected_audience
&& payload.audience.0.iter().all(|s| s != aud)
{
return Err(JwtError::InvalidClaims(
JwtClaimsError::InvalidJwtTokenAudience,
));
if let Some(aud) = expected_audience {
if payload.audience.0.iter().all(|s| s != aud) {
return Err(JwtError::InvalidClaims(
JwtClaimsError::InvalidJwtTokenAudience,
));
}
}
let now = SystemTime::now();
if let Some(exp) = payload.expiration
&& now >= exp + CLOCK_SKEW_LEEWAY
{
return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired(
exp.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
)));
}
if let Some(nbf) = payload.not_before
&& nbf >= now + CLOCK_SKEW_LEEWAY
{
return Err(JwtError::InvalidClaims(
JwtClaimsError::JwtTokenNotYetReadyToUse(
nbf.duration_since(SystemTime::UNIX_EPOCH)
if let Some(exp) = payload.expiration {
if now >= exp + CLOCK_SKEW_LEEWAY {
return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired(
exp.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
),
));
)));
}
}
if let Some(nbf) = payload.not_before {
if nbf >= now + CLOCK_SKEW_LEEWAY {
return Err(JwtError::InvalidClaims(
JwtClaimsError::JwtTokenNotYetReadyToUse(
nbf.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
),
));
}
}
Ok(ComputeCredentialKeys::JwtPayload(payloadb))

View File

@@ -171,6 +171,7 @@ impl ComputeUserInfo {
pub(crate) enum ComputeCredentialKeys {
AuthKeys(AuthKeys),
JwtPayload(Vec<u8>),
None,
}
impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
@@ -345,13 +346,15 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
Err(e) => {
// The password could have been changed, so we invalidate the cache.
// We should only invalidate the cache if the TTL might have expired.
if e.is_password_failed()
&& let ControlPlaneClient::ProxyV1(api) = &*api
&& let Some(ep) = &user_info.endpoint_id
{
api.caches
.project_info
.maybe_invalidate_role_secret(ep, &user_info.user);
if e.is_password_failed() {
#[allow(irrefutable_let_patterns)]
if let ControlPlaneClient::ProxyV1(api) = &*api {
if let Some(ep) = &user_info.endpoint_id {
api.caches
.project_info
.maybe_invalidate_role_secret(ep, &user_info.user);
}
}
}
Err(e)

View File

@@ -1,37 +1,43 @@
use std::net::SocketAddr;
use std::pin::pin;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::bail;
use anyhow::{Context, bail, ensure};
use arc_swap::ArcSwapOption;
use camino::Utf8PathBuf;
use camino::{Utf8Path, Utf8PathBuf};
use clap::Parser;
use compute_api::spec::LocalProxySpec;
use futures::future::Either;
use thiserror::Error;
use tokio::net::TcpListener;
use tokio::sync::Notify;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info};
use tracing::{debug, error, info, warn};
use utils::sentry_init::init_sentry;
use utils::{pid_file, project_build_tag, project_git_version};
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::local::LocalBackend;
use crate::auth::backend::local::{JWKS_ROLE_MAP, LocalBackend};
use crate::auth::{self};
use crate::cancellation::CancellationHandler;
use crate::config::{
self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
refresh_config_loop,
};
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use crate::ext::TaskExt;
use crate::http::health_server::AppMetrics;
use crate::intern::RoleNameInt;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::cancel_set::CancelSet;
use crate::serverless::{self, GlobalConnPoolOptions};
use crate::tls::client_config::compute_client_config_with_root_certs;
use crate::types::RoleName;
use crate::url::ApiUrl;
project_git_version!(GIT_VERSION);
@@ -76,11 +82,6 @@ struct LocalProxyCliArgs {
/// Path of the local proxy PID file
#[clap(long, default_value = "./local_proxy.pid")]
pid_path: Utf8PathBuf,
/// Disable pg_session_jwt extension installation
/// This is useful for testing the local proxy with vanilla postgres.
#[clap(long, default_value = "false")]
#[cfg(feature = "testing")]
disable_pg_session_jwt: bool,
}
#[derive(clap::Args, Clone, Copy, Debug)]
@@ -281,8 +282,6 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
connect_compute_locks,
connect_to_compute: compute_config,
#[cfg(feature = "testing")]
disable_pg_session_jwt: args.disable_pg_session_jwt,
})))
}
@@ -294,3 +293,132 @@ fn build_auth_backend(args: &LocalProxyCliArgs) -> &'static auth::Backend<'stati
Box::leak(Box::new(auth_backend))
}
/// Reasons a reload of the local proxy config file can fail.
///
/// Every variant is `#[error(transparent)]`, so the message shown is that of the wrapped error.
#[derive(Error, Debug)]
enum RefreshConfigError {
/// The config file could not be read (I/O error, including "file not found").
#[error(transparent)]
Read(#[from] std::io::Error),
/// The file contents could not be deserialized as JSON.
#[error(transparent)]
Parse(#[from] serde_json::Error),
/// A JWKS entry in the config failed validation.
#[error(transparent)]
Validate(anyhow::Error),
/// The TLS server configuration could not be built from the configured key/cert paths.
#[error(transparent)]
Tls(anyhow::Error),
}
/// Reloads the local proxy config from `path` each time `rx` is notified.
///
/// This loop never terminates. A failed reload is only logged; the loop then goes back to
/// waiting for the next notification.
async fn refresh_config_loop(config: &ProxyConfig, path: Utf8PathBuf, rx: Arc<Notify>) {
    // Stays true only until the first reload attempt has finished.
    let mut first_attempt = true;
    loop {
        rx.notified().await;

        if let Err(failure) = refresh_config_inner(config, &path).await {
            match failure {
                // A missing file on the very first check is expected for computes that
                // don't use local_proxy, so it is only reported at debug level.
                RefreshConfigError::Read(e)
                    if first_attempt && e.kind() == std::io::ErrorKind::NotFound =>
                {
                    debug!(error=?e, ?path, "could not read config file");
                }
                RefreshConfigError::Tls(e) => {
                    error!(error=?e, ?path, "could not read TLS certificates");
                }
                e => {
                    error!(error=?e, ?path, "could not read config file");
                }
            }
        }

        first_attempt = false;
    }
}
/// Performs a single reload of the local proxy config found at `path`.
///
/// The file is read and deserialized as a [`LocalProxySpec`]. Every JWKS entry is validated and
/// normalized, and the resulting set replaces the contents of `JWKS_ROLE_MAP`. If the spec has a
/// `tls` section, a TLS server config is built on a blocking thread and stored in
/// `config.tls_config`.
///
/// Note that the JWKS set is published before the TLS config is built, so a TLS failure still
/// leaves the new JWKS settings in effect.
///
/// # Errors
///
/// * [`RefreshConfigError::Read`] if the file cannot be read.
/// * [`RefreshConfigError::Parse`] if it is not valid JSON for the spec.
/// * [`RefreshConfigError::Validate`] if any JWKS entry is rejected.
/// * [`RefreshConfigError::Tls`] if the TLS config cannot be built.
async fn refresh_config_inner(
    config: &ProxyConfig,
    path: &Utf8Path,
) -> Result<(), RefreshConfigError> {
    /// Turns one JWKS entry of the spec into the proxy's own `JwksSettings`, rejecting
    /// unusable URLs and stripping credentials, query and fragment from accepted ones.
    fn parse_jwks_settings(jwks: compute_api::spec::JwksSettings) -> anyhow::Result<JwksSettings> {
        let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?;

        let scheme_is_http = matches!(jwks_url.scheme(), "http" | "https");
        ensure!(
            jwks_url.has_authority() && scheme_is_http,
            "Invalid JWKS url. Must be HTTP",
        );

        let has_domain = jwks_url.host().is_some_and(|h| h != url::Host::Domain(""));
        ensure!(has_domain, "Invalid JWKS url. No domain listed");

        // Drop any credentials embedded in the URL.
        jwks_url
            .set_username("")
            .expect("url can be a base and has a valid host and is not a file. should not error");
        jwks_url
            .set_password(None)
            .expect("url can be a base and has a valid host and is not a file. should not error");

        // Drop an explicit port too, except in testing builds: local testing is hard if we
        // need to have a specific restricted port.
        if !cfg!(feature = "testing") {
            jwks_url.set_port(None).expect(
                "url can be a base and has a valid host and is not a file. should not error",
            );
        }

        // Drop the fragment and all query parameters.
        jwks_url.set_fragment(None);
        jwks_url.query_pairs_mut().clear().finish();

        if jwks_url.scheme() != "https" {
            // Local testing is hard if we need to set up https support, so testing builds
            // only warn; all other builds upgrade the scheme.
            if cfg!(feature = "testing") {
                warn!(scheme = jwks_url.scheme(), "JWKS url is not HTTPS");
            } else {
                jwks_url
                    .set_scheme("https")
                    .expect("should not error to set the scheme to https if it was http");
            }
        }

        let role_names = jwks
            .role_names
            .into_iter()
            .map(|name| RoleNameInt::from(&RoleName::from(name)))
            .collect();

        Ok(JwksSettings {
            id: jwks.id,
            jwks_url,
            _provider_name: jwks.provider_name,
            jwt_audience: jwks.jwt_audience,
            role_names,
        })
    }

    let raw = tokio::fs::read(&path).await?;
    let spec: LocalProxySpec = serde_json::from_slice(&raw)?;

    // Stops at the first entry that fails validation.
    let jwks = spec
        .jwks
        .into_iter()
        .flatten()
        .map(|entry| parse_jwks_settings(entry).map_err(RefreshConfigError::Validate))
        .collect::<Result<Vec<_>, _>>()?;

    info!("successfully loaded new config");
    JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks })));

    if let Some(tls) = spec.tls {
        // `configure_tls` is run via `spawn_blocking` rather than directly on the async task.
        let server_tls = tokio::task::spawn_blocking(move || {
            crate::tls::server_config::configure_tls(
                tls.key_path.as_ref(),
                tls.cert_path.as_ref(),
                None,
                false,
            )
        })
        .await
        .propagate_task_panic()
        .map_err(RefreshConfigError::Tls)?;
        config.tls_config.store(Some(Arc::new(server_tls)));
    }

    Ok(())
}

View File

@@ -4,7 +4,6 @@
//! This allows connecting to pods/services running in the same Kubernetes cluster from
//! the outside. Similar to an ingress controller for HTTPS.
use std::io;
use std::net::SocketAddr;
use std::path::Path;
use std::sync::Arc;
@@ -230,6 +229,7 @@ pub(super) async fn task_main(
.set_nodelay(true)
.context("failed to set socket option")?;
info!(%peer_addr, "serving");
let ctx = RequestContext::new(
session_id,
ConnectionInfo {
@@ -241,14 +241,6 @@ pub(super) async fn task_main(
handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
}
.unwrap_or_else(|e| {
if let Some(FirstMessage(io_error)) = e.downcast_ref() {
// this is noisy. if we get EOF on the very first message that's likely
// just NLB doing a healthcheck.
if io_error.kind() == io::ErrorKind::UnexpectedEof {
return;
}
}
// Acknowledge that the task has finished with an error.
error!("per-client task finished with an error: {e:#}");
})
@@ -265,19 +257,12 @@ pub(super) async fn task_main(
Ok(())
}
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
struct FirstMessage(io::Error);
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
ctx: &RequestContext,
raw_stream: S,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<TlsStream<S>> {
let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream))
.await
.map_err(FirstMessage)?;
let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream)).await?;
match msg {
FeStartupPacket::SslRequest { direct: None } => {
let raw = stream.accept_tls().await?;

View File

@@ -10,15 +10,11 @@ use std::time::Duration;
use anyhow::Context;
use anyhow::{bail, ensure};
use arc_swap::ArcSwapOption;
#[cfg(any(test, feature = "testing"))]
use camino::Utf8PathBuf;
use futures::future::Either;
use itertools::{Itertools, Position};
use rand::{Rng, thread_rng};
use remote_storage::RemoteStorageConfig;
use tokio::net::TcpListener;
#[cfg(any(test, feature = "testing"))]
use tokio::sync::Notify;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, error, info, warn};
@@ -26,13 +22,9 @@ use utils::sentry_init::init_sentry;
use utils::{project_build_tag, project_git_version};
use crate::auth::backend::jwt::JwkCache;
#[cfg(any(test, feature = "testing"))]
use crate::auth::backend::local::LocalBackend;
use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned};
use crate::batch::BatchQueue;
use crate::cancellation::{CancellationHandler, CancellationProcessor};
#[cfg(any(test, feature = "testing"))]
use crate::config::refresh_config_loop;
use crate::config::{
self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions,
ProxyConfig, ProxyProtocolV2, remote_storage_from_toml,
@@ -68,9 +60,6 @@ enum AuthBackendType {
#[cfg(any(test, feature = "testing"))]
Postgres,
#[cfg(any(test, feature = "testing"))]
Local,
}
/// Neon proxy/router
@@ -85,10 +74,6 @@ struct ProxyCliArgs {
proxy: SocketAddr,
#[clap(value_enum, long, default_value_t = AuthBackendType::ConsoleRedirect)]
auth_backend: AuthBackendType,
/// Path of the local proxy config file (used for local-file auth backend)
#[clap(long, default_value = "./local_proxy.json")]
#[cfg(any(test, feature = "testing"))]
config_path: Utf8PathBuf,
/// listen for management callback connection on ip:port
#[clap(short, long, default_value = "127.0.0.1:7000")]
mgmt: SocketAddr,
@@ -241,14 +226,6 @@ struct ProxyCliArgs {
#[clap(flatten)]
pg_sni_router: PgSniRouterArgs,
/// if this is not local proxy, this toggles whether we accept Postgres REST requests
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
is_rest_broker: bool,
/// cache for `db_schema_cache` introspection (use `size=0` to disable)
#[clap(long, default_value = "size=1000,ttl=1h")]
db_schema_cache: String,
}
#[derive(clap::Args, Clone, Copy, Debug)]
@@ -409,8 +386,6 @@ pub async fn run() -> anyhow::Result<()> {
64,
));
#[cfg(any(test, feature = "testing"))]
let refresh_config_notify = Arc::new(Notify::new());
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
@@ -437,17 +412,6 @@ pub async fn run() -> anyhow::Result<()> {
endpoint_rate_limiter.clone(),
));
}
// if auth backend is local, we need to load the config file
#[cfg(any(test, feature = "testing"))]
if let auth::Backend::Local(_) = &auth_backend {
refresh_config_notify.notify_one();
tokio::spawn(refresh_config_loop(
config,
args.config_path,
refresh_config_notify.clone(),
));
}
}
Either::Right(auth_backend) => {
if let Some(proxy_listener) = proxy_listener {
@@ -498,13 +462,7 @@ pub async fn run() -> anyhow::Result<()> {
// maintenance tasks. these never return unless there's an error
let mut maintenance_tasks = JoinSet::new();
maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), {
move || {
#[cfg(any(test, feature = "testing"))]
refresh_config_notify.notify_one();
}
}));
maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), || {}));
maintenance_tasks.spawn(http::health_server::task_main(
http_listener,
AppMetrics {
@@ -520,51 +478,54 @@ pub async fn run() -> anyhow::Result<()> {
maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
}
if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend
&& let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api
&& let Some(client) = redis_client
{
// project info cache and invalidation of that cache.
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
#[cfg_attr(not(any(test, feature = "testing")), expect(irrefutable_let_patterns))]
if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend {
if let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api {
if let Some(client) = redis_client {
// project info cache and invalidation of that cache.
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
// This prevents immediate exit and pod restart,
// which can cause hammering of the redis in case of connection issues.
// cancellation key management
let mut redis_kv_client = RedisKVClient::new(client.clone());
for attempt in (0..3).with_position() {
match redis_kv_client.try_connect().await {
Ok(()) => {
info!("Connected to Redis KV client");
cancellation_handler.init_tx(BatchQueue::new(CancellationProcessor {
client: redis_kv_client,
batch_size: args.cancellation_batch_size,
}));
// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
// This prevents immediate exit and pod restart,
// which can cause hammering of the redis in case of connection issues.
// cancellation key management
let mut redis_kv_client = RedisKVClient::new(client.clone());
for attempt in (0..3).with_position() {
match redis_kv_client.try_connect().await {
Ok(()) => {
info!("Connected to Redis KV client");
cancellation_handler.init_tx(BatchQueue::new(CancellationProcessor {
client: redis_kv_client,
batch_size: args.cancellation_batch_size,
}));
break;
}
Err(e) => {
error!("Failed to connect to Redis KV client: {e}");
if matches!(attempt, Position::Last(_)) {
bail!(
"Failed to connect to Redis KV client after {} attempts",
attempt.into_inner()
);
break;
}
Err(e) => {
error!("Failed to connect to Redis KV client: {e}");
if matches!(attempt, Position::Last(_)) {
bail!(
"Failed to connect to Redis KV client after {} attempts",
attempt.into_inner()
);
}
let jitter = thread_rng().gen_range(0..100);
tokio::time::sleep(Duration::from_millis(1000 + jitter)).await;
}
}
let jitter = thread_rng().gen_range(0..100);
tokio::time::sleep(Duration::from_millis(1000 + jitter)).await;
}
// listen for notifications of new projects/endpoints/branches
let cache = api.caches.endpoints_cache.clone();
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(client, cancellation_token.clone()).await }
.instrument(span),
);
}
}
// listen for notifications of new projects/endpoints/branches
let cache = api.caches.endpoints_cache.clone();
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(client, cancellation_token.clone()).await }.instrument(span),
);
}
let maintenance = loop {
@@ -692,8 +653,6 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute: compute_config,
#[cfg(feature = "testing")]
disable_pg_session_jwt: false,
};
let config = Box::leak(Box::new(config));
@@ -847,19 +806,6 @@ fn build_auth_backend(
Ok(Either::Right(config))
}
#[cfg(any(test, feature = "testing"))]
AuthBackendType::Local => {
let postgres: SocketAddr = "127.0.0.1:7432".parse()?;
let compute_ctl: ApiUrl = "http://127.0.0.1:3081/".parse()?;
let auth_backend = crate::auth::Backend::Local(
crate::auth::backend::MaybeOwned::Owned(LocalBackend::new(postgres, compute_ctl)),
);
let config = Box::leak(Box::new(auth_backend));
Ok(Either::Left(config))
}
}
}

View File

@@ -165,7 +165,7 @@ impl AuthInfo {
ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
Some(Auth::Scram(Box::new(auth_keys)))
}
ComputeCredentialKeys::JwtPayload(_) => None,
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None,
},
server_params: StartupMessageParams::default(),
skip_db_user: false,

View File

@@ -4,26 +4,17 @@ use std::time::Duration;
use anyhow::{Context, Ok, bail, ensure};
use arc_swap::ArcSwapOption;
use camino::{Utf8Path, Utf8PathBuf};
use clap::ValueEnum;
use compute_api::spec::LocalProxySpec;
use remote_storage::RemoteStorageConfig;
use thiserror::Error;
use tokio::sync::Notify;
use tracing::{debug, error, info, warn};
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::local::JWKS_ROLE_MAP;
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use crate::ext::TaskExt;
use crate::intern::RoleNameInt;
use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::GlobalConnPoolOptions;
use crate::serverless::cancel_set::CancelSet;
pub use crate::tls::server_config::{TlsConfig, configure_tls};
use crate::types::{Host, RoleName};
use crate::types::Host;
pub struct ProxyConfig {
pub tls_config: ArcSwapOption<TlsConfig>,
@@ -35,8 +26,6 @@ pub struct ProxyConfig {
pub wake_compute_retry_config: RetryConfig,
pub connect_compute_locks: ApiLocks<Host>,
pub connect_to_compute: ComputeConfig,
#[cfg(feature = "testing")]
pub disable_pg_session_jwt: bool,
}
pub struct ComputeConfig {
@@ -420,135 +409,6 @@ impl FromStr for ConcurrencyLockOptions {
}
}
/// Reasons a reload of the local proxy config file can fail.
///
/// Every variant is `#[error(transparent)]`, so the message shown is that of the wrapped error.
#[derive(Error, Debug)]
pub(crate) enum RefreshConfigError {
/// The config file could not be read (I/O error, including "file not found").
#[error(transparent)]
Read(#[from] std::io::Error),
/// The file contents could not be deserialized as JSON.
#[error(transparent)]
Parse(#[from] serde_json::Error),
/// A JWKS entry in the config failed validation.
#[error(transparent)]
Validate(anyhow::Error),
/// The TLS server configuration could not be built from the configured key/cert paths.
#[error(transparent)]
Tls(anyhow::Error),
}
/// Reloads the local proxy config from `path` each time `rx` is notified.
///
/// This loop never terminates. A failed reload is only logged; the loop then goes back to
/// waiting for the next notification.
pub(crate) async fn refresh_config_loop(config: &ProxyConfig, path: Utf8PathBuf, rx: Arc<Notify>) {
    // Stays true only until the first reload attempt has finished.
    let mut first_attempt = true;
    loop {
        rx.notified().await;

        if let Err(failure) = refresh_config_inner(config, &path).await {
            match failure {
                // A missing file on the very first check is expected for computes that
                // don't use local_proxy, so it is only reported at debug level.
                RefreshConfigError::Read(e)
                    if first_attempt && e.kind() == std::io::ErrorKind::NotFound =>
                {
                    debug!(error=?e, ?path, "could not read config file");
                }
                RefreshConfigError::Tls(e) => {
                    error!(error=?e, ?path, "could not read TLS certificates");
                }
                e => {
                    error!(error=?e, ?path, "could not read config file");
                }
            }
        }

        first_attempt = false;
    }
}
/// Performs a single reload of the local proxy config found at `path`.
///
/// The file is read and deserialized as a [`LocalProxySpec`]. Every JWKS entry is validated and
/// normalized, and the resulting set replaces the contents of `JWKS_ROLE_MAP`. If the spec has a
/// `tls` section, a TLS server config is built on a blocking thread and stored in
/// `config.tls_config`.
///
/// Note that the JWKS set is published before the TLS config is built, so a TLS failure still
/// leaves the new JWKS settings in effect.
///
/// # Errors
///
/// * [`RefreshConfigError::Read`] if the file cannot be read.
/// * [`RefreshConfigError::Parse`] if it is not valid JSON for the spec.
/// * [`RefreshConfigError::Validate`] if any JWKS entry is rejected.
/// * [`RefreshConfigError::Tls`] if the TLS config cannot be built.
pub(crate) async fn refresh_config_inner(
    config: &ProxyConfig,
    path: &Utf8Path,
) -> Result<(), RefreshConfigError> {
    /// Turns one JWKS entry of the spec into the proxy's own `JwksSettings`, rejecting
    /// unusable URLs and stripping credentials, query and fragment from accepted ones.
    fn parse_jwks_settings(jwks: compute_api::spec::JwksSettings) -> anyhow::Result<JwksSettings> {
        let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?;

        let scheme_is_http = matches!(jwks_url.scheme(), "http" | "https");
        ensure!(
            jwks_url.has_authority() && scheme_is_http,
            "Invalid JWKS url. Must be HTTP",
        );

        let has_domain = jwks_url.host().is_some_and(|h| h != url::Host::Domain(""));
        ensure!(has_domain, "Invalid JWKS url. No domain listed");

        // Drop any credentials embedded in the URL.
        jwks_url
            .set_username("")
            .expect("url can be a base and has a valid host and is not a file. should not error");
        jwks_url
            .set_password(None)
            .expect("url can be a base and has a valid host and is not a file. should not error");

        // Drop an explicit port too, except in testing builds: local testing is hard if we
        // need to have a specific restricted port.
        if !cfg!(feature = "testing") {
            jwks_url.set_port(None).expect(
                "url can be a base and has a valid host and is not a file. should not error",
            );
        }

        // Drop the fragment and all query parameters.
        jwks_url.set_fragment(None);
        jwks_url.query_pairs_mut().clear().finish();

        if jwks_url.scheme() != "https" {
            // Local testing is hard if we need to set up https support, so testing builds
            // only warn; all other builds upgrade the scheme.
            if cfg!(feature = "testing") {
                warn!(scheme = jwks_url.scheme(), "JWKS url is not HTTPS");
            } else {
                jwks_url
                    .set_scheme("https")
                    .expect("should not error to set the scheme to https if it was http");
            }
        }

        let role_names = jwks
            .role_names
            .into_iter()
            .map(|name| RoleNameInt::from(&RoleName::from(name)))
            .collect();

        Ok(JwksSettings {
            id: jwks.id,
            jwks_url,
            _provider_name: jwks.provider_name,
            jwt_audience: jwks.jwt_audience,
            role_names,
        })
    }

    let raw = tokio::fs::read(&path).await?;
    let spec: LocalProxySpec = serde_json::from_slice(&raw)?;

    // Stops at the first entry that fails validation.
    let jwks = spec
        .jwks
        .into_iter()
        .flatten()
        .map(|entry| parse_jwks_settings(entry).map_err(RefreshConfigError::Validate))
        .collect::<Result<Vec<_>, _>>()?;

    info!("successfully loaded new config");
    JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks })));

    if let Some(tls) = spec.tls {
        // `configure_tls` is run via `spawn_blocking` rather than directly on the async task.
        let server_tls = tokio::task::spawn_blocking(move || {
            crate::tls::server_config::configure_tls(
                tls.key_path.as_ref(),
                tls.cert_path.as_ref(),
                None,
                false,
            )
        })
        .await
        .propagate_task_panic()
        .map_err(RefreshConfigError::Tls)?;
        config.tls_config.store(Some(Arc::new(server_tls)));
    }

    // Spelled out in full, as in the original body, so that it names the standard
    // library's `Ok` rather than any other `Ok` in scope.
    std::result::Result::Ok(())
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -209,9 +209,11 @@ impl RequestContext {
if let Some(options_str) = options.get("options") {
// If not found directly, try to extract it from the options string
for option in options_str.split_whitespace() {
if let Some(value) = option.strip_prefix("neon_query_id:") {
this.set_testodrome_id(value.into());
break;
if option.starts_with("neon_query_id:") {
if let Some(value) = option.strip_prefix("neon_query_id:") {
this.set_testodrome_id(value.into());
break;
}
}
}
}

View File

@@ -250,8 +250,10 @@ impl NeonControlPlaneClient {
info!(duration = ?start.elapsed(), "received http response");
let body = parse_body::<WakeCompute>(response.status(), response.bytes().await?)?;
let Some((host, port)) = parse_host_port(&body.address) else {
return Err(WakeComputeError::BadComputeAddress(body.address));
// Unfortunately, ownership won't let us use `Option::ok_or` here.
let (host, port) = match parse_host_port(&body.address) {
None => return Err(WakeComputeError::BadComputeAddress(body.address)),
Some(x) => x,
};
let host_addr = IpAddr::from_str(host).ok();

View File

@@ -213,12 +213,7 @@ impl<K: Hash + Eq + Clone> ApiLocks<K> {
self.metrics
.semaphore_acquire_seconds
.observe(now.elapsed().as_secs_f64());
if permit.is_ok() {
debug!(elapsed = ?now.elapsed(), "acquired permit");
} else {
debug!(elapsed = ?now.elapsed(), "timed out acquiring permit");
}
debug!("acquired permit {:?}", now.elapsed().as_secs_f64());
Ok(WakeComputePermit { permit: permit? })
}

View File

@@ -52,7 +52,7 @@ pub async fn init() -> anyhow::Result<LoggingGuard> {
StderrWriter {
stderr: std::io::stderr(),
},
&["conn_id", "ep", "query_id", "request_id", "session_id"],
&["request_id", "session_id", "conn_id"],
))
} else {
None
@@ -271,18 +271,18 @@ where
});
// In case logging fails we generate a simpler JSON object.
if let Err(err) = res
&& let Ok(mut line) = serde_json::to_vec(&serde_json::json!( {
if let Err(err) = res {
if let Ok(mut line) = serde_json::to_vec(&serde_json::json!( {
"timestamp": now.to_rfc3339_opts(chrono::SecondsFormat::Micros, true),
"level": "ERROR",
"message": format_args!("cannot log event: {err:?}"),
"fields": {
"event": format_args!("{event:?}"),
},
}))
{
line.push(b'\n');
self.writer.make_writer().write_all(&line).ok();
})) {
line.push(b'\n');
self.writer.make_writer().write_all(&line).ok();
}
}
}
@@ -583,11 +583,10 @@ impl EventFormatter {
THREAD_ID.with(|tid| serializer.serialize_entry("thread_id", tid))?;
// TODO: tls cache? name could change
if let Some(thread_name) = std::thread::current().name()
&& !thread_name.is_empty()
&& thread_name != "tokio-runtime-worker"
{
serializer.serialize_entry("thread_name", thread_name)?;
if let Some(thread_name) = std::thread::current().name() {
if !thread_name.is_empty() && thread_name != "tokio-runtime-worker" {
serializer.serialize_entry("thread_name", thread_name)?;
}
}
if let Some(task_id) = tokio::task::try_id() {
@@ -597,10 +596,10 @@ impl EventFormatter {
serializer.serialize_entry("target", meta.target())?;
// Skip adding module if it's the same as target.
if let Some(module) = meta.module_path()
&& module != meta.target()
{
serializer.serialize_entry("module", module)?;
if let Some(module) = meta.module_path() {
if module != meta.target() {
serializer.serialize_entry("module", module)?;
}
}
if let Some(file) = meta.file() {

View File

@@ -236,6 +236,13 @@ pub enum Bool {
False,
}
/// Two-valued metric label: success or failure.
///
/// The `singleton = "outcome"` attribute presumably makes this usable as a stand-alone
/// label named `outcome` — confirm against the `FixedCardinalityLabel` derive.
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "outcome")]
pub enum Outcome {
Success,
Failed,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "outcome")]
pub enum CacheOutcome {

View File

@@ -90,27 +90,27 @@ where
// TODO: 1 info log, with a enum label for close direction.
// Early termination checks from compute to client.
if let TransferState::Done(_) = compute_to_client
&& let TransferState::Running(buf) = &client_to_compute
{
info!("Compute is done, terminate client");
// Initiate shutdown
client_to_compute = TransferState::ShuttingDown(buf.amt);
client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, client, compute)
.map_err(ErrorSource::from_client)?;
if let TransferState::Done(_) = compute_to_client {
if let TransferState::Running(buf) = &client_to_compute {
info!("Compute is done, terminate client");
// Initiate shutdown
client_to_compute = TransferState::ShuttingDown(buf.amt);
client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, client, compute)
.map_err(ErrorSource::from_client)?;
}
}
// Early termination checks from client to compute.
if let TransferState::Done(_) = client_to_compute
&& let TransferState::Running(buf) = &compute_to_client
{
info!("Client is done, terminate compute");
// Initiate shutdown
compute_to_client = TransferState::ShuttingDown(buf.amt);
compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, compute, client)
.map_err(ErrorSource::from_compute)?;
if let TransferState::Done(_) = client_to_compute {
if let TransferState::Running(buf) = &compute_to_client {
info!("Client is done, terminate compute");
// Initiate shutdown
compute_to_client = TransferState::ShuttingDown(buf.amt);
compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, compute, client)
.map_err(ErrorSource::from_compute)?;
}
}
// It is not a problem if ready! returns early ... (comment remains the same)

View File

@@ -39,11 +39,7 @@ impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
let config = config.map_or(self.default_config, Into::into);
if self
.access_count
.fetch_add(1, Ordering::AcqRel)
.is_multiple_of(2048)
{
if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
self.do_gc(now);
}

View File

@@ -211,11 +211,7 @@ impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
// worst case memory usage is about:
// = 2 * 2048 * 64 * (48B + 72B)
// = 30MB
if self
.access_count
.fetch_add(1, Ordering::AcqRel)
.is_multiple_of(2048)
{
if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
self.do_gc();
}

View File

@@ -0,0 +1,79 @@
use core::net::IpAddr;
use std::sync::Arc;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::pqproto::CancelKeyData;
/// A cancellation publisher that needs exclusive (`&mut self`) access to publish.
///
/// Every [`CancellationPublisher`] automatically implements this trait through a blanket impl.
pub trait CancellationPublisherMut: Send + Sync + 'static {
// Silences the `async_fn_in_trait` lint, which warns that users of a public trait cannot
// put auto-trait bounds (such as `Send`) on the future returned by an `async fn`.
#[allow(async_fn_in_trait)]
/// Attempts to publish `cancel_key_data` together with the originating `session_id` and
/// `peer_addr`; returns an error if publishing fails.
async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()>;
}
/// A cancellation publisher that can publish through a shared (`&self`) reference.
pub trait CancellationPublisher: Send + Sync + 'static {
// Silences the `async_fn_in_trait` lint, which warns that users of a public trait cannot
// put auto-trait bounds (such as `Send`) on the future returned by an `async fn`.
#[allow(async_fn_in_trait)]
/// Attempts to publish `cancel_key_data` together with the originating `session_id` and
/// `peer_addr`; returns an error if publishing fails.
async fn try_publish(
&self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()>;
}
/// No-op publisher: `()` accepts every publish request and does nothing with it.
impl CancellationPublisher for () {
async fn try_publish(
&self,
_cancel_key_data: CancelKeyData,
_session_id: Uuid,
_peer_addr: IpAddr,
) -> anyhow::Result<()> {
// Nothing to publish to; always succeeds.
Ok(())
}
}
/// Blanket adapter: anything that can publish through `&self` can also publish through
/// `&mut self`.
impl<P: CancellationPublisher> CancellationPublisherMut for P {
async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()> {
// Fully qualified so the call resolves to the shared-reference trait's method rather
// than to this method, which has the same name.
<P as CancellationPublisher>::try_publish(self, cancel_key_data, session_id, peer_addr)
.await
}
}
/// An optional publisher: forwards to the wrapped publisher when there is one, and succeeds
/// without doing anything when there is none.
impl<P: CancellationPublisher> CancellationPublisher for Option<P> {
    async fn try_publish(
        &self,
        cancel_key_data: CancelKeyData,
        session_id: Uuid,
        peer_addr: IpAddr,
    ) -> anyhow::Result<()> {
        match self {
            Some(inner) => {
                inner
                    .try_publish(cancel_key_data, session_id, peer_addr)
                    .await
            }
            // No publisher configured: treat the request as trivially successful.
            None => Ok(()),
        }
    }
}
/// Lets an exclusive-access publisher be used through a shared handle by locking the mutex
/// around each publish call.
impl<P: CancellationPublisherMut> CancellationPublisher for Arc<Mutex<P>> {
    async fn try_publish(
        &self,
        cancel_key_data: CancelKeyData,
        session_id: Uuid,
        peer_addr: IpAddr,
    ) -> anyhow::Result<()> {
        // The lock stays held until the inner publish call has completed.
        let mut publisher = self.lock().await;
        publisher
            .try_publish(cancel_key_data, session_id, peer_addr)
            .await
    }
}

View File

@@ -1,12 +1,11 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, atomic::AtomicBool, atomic::Ordering};
use std::time::Duration;
use futures::FutureExt;
use redis::aio::{ConnectionLike, MultiplexedConnection};
use redis::{ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult};
use tokio::task::AbortHandle;
use tracing::{error, info, warn};
use tokio::task::JoinHandle;
use tracing::{debug, error, info, warn};
use super::elasticache::CredentialsProvider;
@@ -32,7 +31,7 @@ pub struct ConnectionWithCredentialsProvider {
credentials: Credentials,
// TODO: with more load on the connection, we should consider using a connection pool
con: Option<MultiplexedConnection>,
refresh_token_task: Option<AbortHandle>,
refresh_token_task: Option<JoinHandle<()>>,
mutex: tokio::sync::Mutex<()>,
credentials_refreshed: Arc<AtomicBool>,
}
@@ -122,12 +121,16 @@ impl ConnectionWithCredentialsProvider {
let credentials_provider = credentials_provider.clone();
let con2 = con.clone();
let credentials_refreshed = self.credentials_refreshed.clone();
let f = tokio::spawn(Self::keep_connection(
con2,
credentials_provider,
credentials_refreshed,
));
self.refresh_token_task = Some(f.abort_handle());
let f = tokio::spawn(async move {
let result = Self::keep_connection(con2, credentials_provider).await;
if let Err(e) = result {
credentials_refreshed.store(false, Ordering::Release);
debug!("keep_connection failed: {e}");
} else {
credentials_refreshed.store(true, Ordering::Release);
}
});
self.refresh_token_task = Some(f);
}
match Self::ping(&mut con).await {
Ok(()) => {
@@ -162,7 +165,6 @@ impl ConnectionWithCredentialsProvider {
async fn get_client(&self) -> anyhow::Result<redis::Client> {
let client = redis::Client::open(self.get_connection_info().await?)?;
self.credentials_refreshed.store(true, Ordering::Relaxed);
Ok(client)
}
@@ -178,19 +180,16 @@ impl ConnectionWithCredentialsProvider {
async fn keep_connection(
mut con: MultiplexedConnection,
credentials_provider: Arc<CredentialsProvider>,
credentials_refreshed: Arc<AtomicBool>,
) -> ! {
) -> anyhow::Result<()> {
loop {
// The connection lives for 12h, for the sanity check we refresh it every hour.
tokio::time::sleep(Duration::from_secs(60 * 60)).await;
match Self::refresh_token(&mut con, credentials_provider.clone()).await {
Ok(()) => {
info!("Token refreshed");
credentials_refreshed.store(true, Ordering::Relaxed);
}
Err(e) => {
error!("Error during token refresh: {e:?}");
credentials_refreshed.store(false, Ordering::Relaxed);
}
}
}
@@ -244,7 +243,7 @@ impl ConnectionLike for ConnectionWithCredentialsProvider {
&'a mut self,
cmd: &'a redis::Cmd,
) -> redis::RedisFuture<'a, redis::Value> {
self.send_packed_command(cmd).boxed()
(async move { self.send_packed_command(cmd).await }).boxed()
}
fn req_packed_commands<'a>(
@@ -253,10 +252,10 @@ impl ConnectionLike for ConnectionWithCredentialsProvider {
offset: usize,
count: usize,
) -> redis::RedisFuture<'a, Vec<redis::Value>> {
self.send_packed_commands(cmd, offset, count).boxed()
(async move { self.send_packed_commands(cmd, offset, count).await }).boxed()
}
fn get_db(&self) -> i64 {
self.con.as_ref().map_or(0, |c| c.get_db())
0
}
}

View File

@@ -1,3 +1,4 @@
pub mod cancellation_publisher;
pub mod connection_with_credentials_provider;
pub mod elasticache;
pub mod keys;

View File

@@ -54,7 +54,9 @@ impl<T: std::fmt::Display> ChannelBinding<T> {
"eSws".into()
}
Self::Required(mode) => {
let mut cbind_input = format!("p={mode},,",).into_bytes();
use std::io::Write;
let mut cbind_input = vec![];
write!(&mut cbind_input, "p={mode},,",).unwrap();
cbind_input.extend_from_slice(get_cbind_data(mode)?);
BASE64_STANDARD.encode(&cbind_input).into()
}

View File

@@ -107,7 +107,7 @@ pub(crate) async fn exchange(
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
let salt = BASE64_STANDARD.decode(&secret.salt_base64)?;
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
if secret.is_password_invalid(&client_key).into() {

View File

@@ -87,20 +87,13 @@ impl<'a> ClientFirstMessage<'a> {
salt_base64: &str,
iterations: u32,
) -> OwnedServerFirstMessage {
let mut message = String::with_capacity(128);
message.push_str("r=");
use std::fmt::Write;
// write combined nonce
let combined_nonce_start = message.len();
message.push_str(self.nonce);
let mut message = String::new();
write!(&mut message, "r={}", self.nonce).unwrap();
BASE64_STANDARD.encode_string(nonce, &mut message);
let combined_nonce = combined_nonce_start..message.len();
// write salt and iterations
message.push_str(",s=");
message.push_str(salt_base64);
message.push_str(",i=");
message.push_str(itoa::Buffer::new().format(iterations));
let combined_nonce = 2..message.len();
write!(&mut message, ",s={salt_base64},i={iterations}").unwrap();
// This design guarantees that it's impossible to create a
// server-first-message without receiving a client-first-message

View File

@@ -14,7 +14,7 @@ pub(crate) struct ServerSecret {
/// Number of iterations for `PBKDF2` function.
pub(crate) iterations: u32,
/// Salt used to hash user's password.
pub(crate) salt_base64: Box<str>,
pub(crate) salt_base64: String,
/// Hashed `ClientKey`.
pub(crate) stored_key: ScramKey,
/// Used by client to verify server's signature.
@@ -35,7 +35,7 @@ impl ServerSecret {
let secret = ServerSecret {
iterations: iterations.parse().ok()?,
salt_base64: salt.into(),
salt_base64: salt.to_owned(),
stored_key: base64_decode_array(stored_key)?.into(),
server_key: base64_decode_array(server_key)?.into(),
doomed: false,
@@ -58,7 +58,7 @@ impl ServerSecret {
// iteration count 1 for our generated passwords going forward.
// PG16 users can set iteration count=1 already today.
iterations: 1,
salt_base64: BASE64_STANDARD.encode(nonce).into_boxed_str(),
salt_base64: BASE64_STANDARD.encode(nonce),
stored_key: ScramKey::default(),
server_key: ScramKey::default(),
doomed: true,
@@ -88,7 +88,7 @@ mod tests {
let parsed = ServerSecret::parse(&secret).unwrap();
assert_eq!(parsed.iterations, iterations);
assert_eq!(&*parsed.salt_base64, salt);
assert_eq!(parsed.salt_base64, salt);
assert_eq!(BASE64_STANDARD.encode(parsed.stored_key), stored_key);
assert_eq!(BASE64_STANDARD.encode(parsed.server_key), server_key);

View File

@@ -137,7 +137,7 @@ impl Future for JobSpec {
let state = state.as_mut().expect("should be set on thread startup");
state.tick = state.tick.wrapping_add(1);
if state.tick.is_multiple_of(SKETCH_RESET_INTERVAL) {
if state.tick % SKETCH_RESET_INTERVAL == 0 {
state.countmin.reset();
}

View File

@@ -115,8 +115,7 @@ impl PoolingBackend {
match &self.auth_backend {
crate::auth::Backend::ControlPlane(console, ()) => {
let keys = self
.config
self.config
.authentication_config
.jwks_cache
.check_jwt(
@@ -130,7 +129,7 @@ impl PoolingBackend {
Ok(ComputeCredentials {
info: user_info.clone(),
keys,
keys: crate::auth::backend::ComputeCredentialKeys::None,
})
}
crate::auth::Backend::Local(_) => {
@@ -257,7 +256,6 @@ impl PoolingBackend {
&self,
ctx: &RequestContext,
conn_info: ConnInfo,
disable_pg_session_jwt: bool,
) -> Result<Client<postgres_client::Client>, HttpConnError> {
if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
return Ok(client);
@@ -279,7 +277,7 @@ impl PoolingBackend {
.expect("semaphore should never be closed");
// check again for race
if !self.local_pool.initialized(&conn_info) && !disable_pg_session_jwt {
if !self.local_pool.initialized(&conn_info) {
local_backend
.compute_ctl
.install_extension(&ExtensionInstallRequest {
@@ -315,16 +313,14 @@ impl PoolingBackend {
.to_postgres_client_config();
config
.user(&conn_info.user_info.user)
.dbname(&conn_info.dbname);
if !disable_pg_session_jwt {
config.set_param(
.dbname(&conn_info.dbname)
.set_param(
"options",
&format!(
"-c pg_session_jwt.jwk={}",
serde_json::to_string(&jwk).expect("serializing jwk to json should not fail")
),
);
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = config.connect(&postgres_client::NoTls).await?;
@@ -349,9 +345,7 @@ impl PoolingBackend {
debug!("setting up backend session state");
// initiates the auth session
if !disable_pg_session_jwt
&& let Err(e) = client.batch_execute("select auth.init();").await
{
if let Err(e) = client.batch_execute("select auth.init();").await {
discard.discard();
return Err(e.into());
}

View File

@@ -148,10 +148,11 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade()
&& pool.write().remove_client(db_user.clone(), conn_id) {
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_client(db_user.clone(), conn_id) {
info!("closed connection removed");
}
}
Poll::Ready(())
}).await;

View File

@@ -1,93 +1,5 @@
use http::StatusCode;
use http::header::HeaderName;
use crate::auth::ComputeUserInfoParseError;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::http::ReadBodyError;
/// Implemented by error types that can report an HTTP status code.
pub trait HttpCodeError {
/// Returns the HTTP status code associated with this error.
fn get_http_status_code(&self) -> StatusCode;
}
/// Errors raised while deriving connection info from a request's connection
/// string and headers.
#[derive(Debug, thiserror::Error)]
pub(crate) enum ConnInfoError {
/// The named header was absent or could not be read as a string.
#[error("invalid header: {0}")]
InvalidHeader(&'static HeaderName),
/// The connection string is not a parseable URL.
#[error("invalid connection string: {0}")]
UrlParseError(#[from] url::ParseError),
/// The URL scheme is neither `postgres` nor `postgresql`.
#[error("incorrect scheme")]
IncorrectScheme,
/// The URL has no path segments to take a database name from.
#[error("missing database name")]
MissingDbName,
/// The first path segment (the database name) could not be obtained.
#[error("invalid database name")]
InvalidDbName,
/// The URL's username component is empty.
#[error("missing username")]
MissingUsername,
/// A percent-decoded URL component was not valid UTF-8.
// NOTE(review): because of `#[from]`, a `?` on database-name decoding also lands
// here and is reported as an invalid username; confirm that is intended.
#[error("invalid username: {0}")]
InvalidUsername(#[from] std::string::FromUtf8Error),
/// The request lacks the kind of credentials that is expected; the payload
/// names that kind.
#[error("missing authentication credentials: {0}")]
MissingCredentials(Credentials),
/// The URL host is absent or is an IP address rather than a domain name.
#[error("missing hostname")]
MissingHostname,
/// Wraps a `ComputeUserInfoParseError`.
#[error("invalid hostname: {0}")]
InvalidEndpoint(#[from] ComputeUserInfoParseError),
}
/// The kind of credentials a request was expected to carry; used as the
/// payload of `ConnInfoError::MissingCredentials`.
#[derive(Debug, thiserror::Error)]
pub(crate) enum Credentials {
/// A password was expected.
#[error("required password")]
Password,
/// A bearer token in JWT format was expected.
#[error("required authorization bearer token in JWT format")]
BearerJwt,
}
impl ReportableError for ConnInfoError {
/// Every connection-info error is classified as `ErrorKind::User`.
fn get_error_kind(&self) -> ErrorKind {
ErrorKind::User
}
}
impl UserFacingError for ConnInfoError {
/// The client-facing text is the error's `Display` output, unchanged.
fn to_string_client(&self) -> String {
self.to_string()
}
}
/// Errors raised while reading and parsing an HTTP request body.
#[derive(Debug, thiserror::Error)]
pub(crate) enum ReadPayloadError {
/// The body could not be read from the connection.
#[error("could not read the HTTP request body: {0}")]
Read(#[from] hyper::Error),
/// The body exceeded `limit` bytes.
#[error("request is too large (max is {limit} bytes)")]
BodyTooLarge { limit: usize },
/// The body was read but is not valid JSON for the expected payload.
#[error("could not parse the HTTP request body: {0}")]
Parse(#[from] serde_json::Error),
}
impl From<ReadBodyError<hyper::Error>> for ReadPayloadError {
fn from(value: ReadBodyError<hyper::Error>) -> Self {
match value {
ReadBodyError::BodyTooLarge { limit } => Self::BodyTooLarge { limit },
ReadBodyError::Read(e) => Self::Read(e),
}
}
}
impl ReportableError for ReadPayloadError {
    /// A failed read counts as a client disconnect; an oversized or
    /// unparseable body is a user error.
    fn get_error_kind(&self) -> ErrorKind {
        match self {
            Self::Read(_) => ErrorKind::ClientDisconnect,
            Self::BodyTooLarge { .. } | Self::Parse(_) => ErrorKind::User,
        }
    }
}
impl HttpCodeError for ReadPayloadError {
    /// Oversized bodies map to `413 Payload Too Large`; every other failure
    /// maps to `400 Bad Request`.
    fn get_http_status_code(&self) -> StatusCode {
        match self {
            Self::BodyTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
            Self::Read(_) | Self::Parse(_) => StatusCode::BAD_REQUEST,
        }
    }
}

View File

@@ -2,8 +2,6 @@ use std::collections::VecDeque;
use std::sync::atomic::{self, AtomicUsize};
use std::sync::{Arc, Weak};
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use hyper::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
@@ -23,9 +21,8 @@ use crate::protocol2::ConnectionInfoExtra;
use crate::types::EndpointCacheKey;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
pub(crate) type Send = http2::SendRequest<BoxBody<Bytes, hyper::Error>>;
pub(crate) type Connect =
http2::Connection<TokioIo<AsyncRW>, BoxBody<Bytes, hyper::Error>, TokioExecutor>;
pub(crate) type Send = http2::SendRequest<hyper::body::Incoming>;
pub(crate) type Connect = http2::Connection<TokioIo<AsyncRW>, hyper::body::Incoming, TokioExecutor>;
#[derive(Clone)]
pub(crate) struct ClientDataHttp();
@@ -240,10 +237,10 @@ pub(crate) fn poll_http2_client(
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade()
&& pool.write().remove_conn(conn_id)
{
info!("closed connection removed");
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_conn(conn_id) {
info!("closed connection removed");
}
}
}
.instrument(span),

View File

@@ -3,42 +3,11 @@
use anyhow::Context;
use bytes::Bytes;
use http::header::AUTHORIZATION;
use http::{HeaderMap, HeaderName, HeaderValue, Response, StatusCode};
use http::{Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use http_utils::error::ApiError;
use serde::Serialize;
use url::Url;
use uuid::Uuid;
use super::conn_pool::{AuthData, ConnInfoWithAuth};
use super::conn_pool_lib::ConnInfo;
use super::error::{ConnInfoError, Credentials};
use crate::auth::backend::ComputeUserInfo;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::metrics::{Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::NeonOptions;
use crate::types::{DbName, EndpointId, RoleName};
// Common header names used across serverless modules
pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");
pub(super) static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
pub(super) static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
pub(super) static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
pub(super) static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
pub(super) static TXN_ISOLATION_LEVEL: HeaderName =
HeaderName::from_static("neon-batch-isolation-level");
pub(super) static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
pub(super) static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
/// Formats `id` as a lowercase hyphenated UUID and wraps it in a [`HeaderValue`].
///
/// The text is written into a stack buffer, so no intermediate `String` is built.
pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
    let mut buf = [0u8; uuid::fmt::Hyphenated::LENGTH];
    let text: &str = id.as_hyphenated().encode_lower(&mut buf);
    HeaderValue::from_str(text)
        .expect("uuid hyphenated format should be all valid header characters")
}
/// Like [`ApiError::into_response`]
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper::Error>> {
@@ -138,136 +107,3 @@ pub(crate) fn json_response<T: Serialize>(
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
}
/// Derives a `ConnInfoWithAuth` for a request from its connection URL and headers.
///
/// The URL is taken from `connection_string` when given, otherwise from the
/// `neon-connection-string` header. As side effects, the database name, role,
/// endpoint id and user agent are recorded on `ctx`, and the
/// `accepted_connections_by_sni` metric is incremented.
///
/// # Errors
///
/// Returns a `ConnInfoError` when the header or URL cannot be parsed, the scheme
/// is not `postgres`/`postgresql`, the database name, username or hostname is
/// missing, or the supplied credentials do not match `config.accept_jwts`.
pub(crate) fn get_conn_info(
config: &'static AuthenticationConfig,
ctx: &RequestContext,
connection_string: Option<&str>,
headers: &HeaderMap,
) -> Result<ConnInfoWithAuth, ConnInfoError> {
// An explicit connection string takes precedence over the header.
let connection_url = match connection_string {
Some(connection_string) => Url::parse(connection_string)?,
None => {
let connection_string = headers
.get(&CONN_STRING)
.ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
Url::parse(connection_string)?
}
};
let protocol = connection_url.scheme();
if protocol != "postgres" && protocol != "postgresql" {
return Err(ConnInfoError::IncorrectScheme);
}
let mut url_path = connection_url
.path_segments()
.ok_or(ConnInfoError::MissingDbName)?;
// The database name is the percent-decoded first path segment.
let dbname: DbName =
urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into();
ctx.set_dbname(dbname.clone());
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
if username.is_empty() {
return Err(ConnInfoError::MissingUsername);
}
ctx.set_user(username.clone());
// TODO: make sure this is right in the context of rest broker
// An `Authorization` header is checked first and must carry a `Bearer ` token;
// a URL password is only accepted when JWTs are not accepted.
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
if !config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
}
let auth = auth
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
AuthData::Jwt(
auth.strip_prefix("Bearer ")
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
.into(),
)
} else if let Some(pass) = connection_url.password() {
// wrong credentials provided
if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
}
// The password is percent-decoded as raw bytes (no UTF-8 requirement).
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
})
} else if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
} else {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
};
// The endpoint id is the host name up to its first `.` (or the whole name
// when it contains no dot); IP-address hosts are rejected.
let endpoint: EndpointId = match connection_url.host() {
Some(url::Host::Domain(hostname)) => hostname
.split_once('.')
.map_or(hostname, |(prefix, _)| prefix)
.into(),
Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
return Err(ConnInfoError::MissingHostname);
}
};
ctx.set_endpoint_id(endpoint.clone());
let pairs = connection_url.query_pairs();
let mut options = Option::None;
// NOTE(review): `params` is filled in below but never read afterwards in
// this function; confirm whether it is still needed.
let mut params = StartupMessageParams::default();
params.insert("user", &username);
params.insert("database", &dbname);
for (key, value) in pairs {
params.insert(&key, &value);
if key == "options" {
options = Some(NeonOptions::parse_options_raw(&value));
}
}
// check the URL that was used, for metrics
{
let host_endpoint = headers
// get the host header
.get("host")
// extract the domain
// NOTE(review): `split_once(':')?` yields `None` when the Host header has
// no `:port` suffix, so such requests are counted as `NoSni`; confirm
// that is intended.
.and_then(|h| {
let (host, _port) = h.to_str().ok()?.split_once(':')?;
Some(host)
})
// get the endpoint prefix
.map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix));
let kind = if host_endpoint == Some(&*endpoint) {
SniKind::Sni
} else {
SniKind::NoSni
};
let protocol = ctx.protocol();
Metrics::get()
.proxy
.accepted_connections_by_sni
.inc(SniGroup { protocol, kind });
}
// A missing or non-string User-Agent header is recorded as `None`.
ctx.set_user_agent(
headers
.get(hyper::header::USER_AGENT)
.and_then(|h| h.to_str().ok())
.map(Into::into),
);
let user_info = ComputeUserInfo {
endpoint,
user: username,
options: options.unwrap_or_default(),
};
let conn_info = ConnInfo { user_info, dbname };
Ok(ConnInfoWithAuth { conn_info, auth })
}

View File

@@ -1,112 +1,60 @@
use postgres_client::Row;
use postgres_client::types::{Kind, Type};
use serde::Deserialize;
use serde::de::{Deserializer, IgnoredAny, Visitor};
use serde_json::value::RawValue;
use serde_json::{Map, Value};
//
// Convert json non-string types to strings, so that they can be passed to Postgres
// as parameters.
//
pub(crate) fn json_to_pg_text(json: Vec<Box<RawValue>>) -> Vec<Option<String>> {
json.into_iter()
.map(|raw| {
match raw.get().as_bytes() {
// special handling for null.
b"null" => None,
// remove the escape characters from the string.
[b'"', ..] => {
Some(String::deserialize(&*raw).expect("json should be a valid string"))
}
[b'[', ..] => {
let mut output = String::with_capacity(raw.get().len());
raw.deserialize_seq(PgArrayVisitor(&raw, &mut output))
.expect("json should be a valid");
Some(output)
}
// write all other values out directly
_ => Some(<Box<str>>::from(raw).into()),
}
})
.collect()
pub(crate) fn json_to_pg_text(json: Vec<Value>) -> Vec<Option<String>> {
json.iter().map(json_value_to_pg_text).collect()
}
struct PgArrayVisitor<'de, 'a>(&'de RawValue, &'a mut String);
fn json_value_to_pg_text(value: &Value) -> Option<String> {
match value {
// special care for nulls
Value::Null => None,
impl PgArrayVisitor<'_, '_> {
#[inline]
#[allow(clippy::unnecessary_wraps)]
fn raw<E>(self) -> Result<(), E> {
self.1.push_str(self.0.get());
Ok(())
// convert to text with escaping
v @ (Value::Bool(_) | Value::Number(_) | Value::Object(_)) => Some(v.to_string()),
// avoid escaping here, as we pass this as a parameter
Value::String(s) => Some(s.clone()),
// special care for arrays
Value::Array(_) => json_array_to_pg_array(value),
}
}
impl<'de> Visitor<'de> for PgArrayVisitor<'de, '_> {
type Value = ();
//
// Serialize a JSON array to a Postgres array. Contrary to the strings in the params
// in the array we need to escape the strings. Postgres is okay with arrays of form
// '{1,"2",3}'::int[], so we don't check that array holds values of the same type, leaving
// it for Postgres to check.
//
// Example of the same escaping in node-postgres: packages/pg/lib/utils.js
//
fn json_array_to_pg_array(value: &Value) -> Option<String> {
match value {
// special care for nulls
Value::Null => None,
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("any valid JSON value")
}
// convert to text with escaping
// here string needs to be escaped, as it is part of the array
v @ (Value::Bool(_) | Value::Number(_) | Value::String(_)) => Some(v.to_string()),
v @ Value::Object(_) => json_array_to_pg_array(&Value::String(v.to_string())),
// special care for nulls
fn visit_none<E>(self) -> Result<Self::Value, E> {
self.1.push_str("NULL");
Ok(())
}
fn visit_unit<E>(self) -> Result<Self::Value, E> {
self.1.push_str("NULL");
Ok(())
}
// recurse into array
Value::Array(arr) => {
let vals = arr
.iter()
.map(json_array_to_pg_array)
.map(|v| v.unwrap_or_else(|| "NULL".to_string()))
.collect::<Vec<_>>()
.join(",");
// convert to text with escaping
fn visit_bool<E>(self, _: bool) -> Result<Self::Value, E> {
self.raw()
}
fn visit_i64<E>(self, _: i64) -> Result<Self::Value, E> {
self.raw()
}
fn visit_u64<E>(self, _: u64) -> Result<Self::Value, E> {
self.raw()
}
fn visit_i128<E>(self, _: i128) -> Result<Self::Value, E> {
self.raw()
}
fn visit_u128<E>(self, _: u128) -> Result<Self::Value, E> {
self.raw()
}
fn visit_f64<E>(self, _: f64) -> Result<Self::Value, E> {
self.raw()
}
fn visit_str<E>(self, _: &str) -> Result<Self::Value, E> {
self.raw()
}
// an object needs re-escaping
fn visit_map<A: serde::de::MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
while map.next_entry::<IgnoredAny, IgnoredAny>()?.is_some() {}
let s = serde_json::to_string(self.0.get()).expect("a string should be valid json");
self.1.push_str(&s);
Ok(())
}
// write an array
fn visit_seq<A: serde::de::SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
self.1.push('{');
let mut comma = false;
while let Some(val) = seq.next_element::<&'de RawValue>()? {
if comma {
self.1.push(',');
}
comma = true;
val.deserialize_any(PgArrayVisitor(val, self.1))
.expect("all json values are valid");
Some(format!("{{{vals}}}"))
}
self.1.push('}');
Ok(())
}
}
@@ -436,14 +384,6 @@ mod tests {
use super::*;
fn json_to_pg_text(json: Vec<serde_json::Value>) -> Vec<Option<String>> {
let json = json
.into_iter()
.map(|value| serde_json::from_str(&value.to_string()).unwrap())
.collect();
super::json_to_pg_text(json)
}
#[test]
fn test_atomic_types_to_pg_params() {
let json = vec![Value::Bool(true), Value::Bool(false)];

View File

@@ -249,10 +249,11 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade()
&& pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
if let Some(pool) = pool.clone().upgrade() {
if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
info!("closed connection removed");
}
}
Poll::Ready(())
}).await;

View File

@@ -29,13 +29,13 @@ use futures::future::{Either, select};
use http::{Method, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty};
use http_util::{NEON_REQUEST_ID, uuid_to_header_value};
use http_utils::error::ApiError;
use hyper::body::Incoming;
use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder;
use rand::SeedableRng;
use rand::rngs::StdRng;
use sql_over_http::{NEON_REQUEST_ID, uuid_to_header_value};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio::time::timeout;

View File

@@ -11,35 +11,39 @@ use http_body_util::{BodyExt, Full};
use http_utils::error::ApiError;
use hyper::body::Incoming;
use hyper::http::{HeaderName, HeaderValue};
use hyper::{Request, Response, StatusCode, header};
use hyper::{HeaderMap, Request, Response, StatusCode, header};
use indexmap::IndexMap;
use postgres_client::error::{DbError, ErrorPosition, SqlState};
use postgres_client::{
GenericClient, IsolationLevel, NoTls, ReadyForQueryStatus, RowStream, Transaction,
};
use serde::Serialize;
use serde_json::Value;
use serde_json::value::RawValue;
use tokio::time::{self, Instant};
use tokio_util::sync::CancellationToken;
use tracing::{Level, debug, error, info};
use typed_json::json;
use url::Url;
use uuid::Uuid;
use super::backend::{LocalProxyConnError, PoolingBackend};
use super::conn_pool::AuthData;
use super::conn_pool::{AuthData, ConnInfoWithAuth};
use super::conn_pool_lib::{self, ConnInfo};
use super::error::{ConnInfoError, HttpCodeError, ReadPayloadError};
use super::http_util::{
ALLOW_POOL, ARRAY_MODE, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_DEFERRABLE,
TXN_ISOLATION_LEVEL, TXN_READ_ONLY, get_conn_info, json_response, uuid_to_header_value,
};
use super::error::HttpCodeError;
use super::http_util::json_response;
use super::json::{JsonConversionError, json_to_pg_text, pg_text_row_to_json};
use crate::auth::backend::ComputeCredentialKeys;
use crate::config::{HttpConfig, ProxyConfig};
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::{ComputeUserInfoParseError, endpoint_sni};
use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig};
use crate::context::RequestContext;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::http::read_body_with_limit;
use crate::metrics::{HttpDirection, Metrics};
use crate::http::{ReadBodyError, read_body_with_limit};
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::NeonOptions;
use crate::serverless::backend::HttpConnError;
use crate::types::{DbName, RoleName};
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
use crate::util::run_until_cancelled;
@@ -47,8 +51,9 @@ use crate::util::run_until_cancelled;
#[serde(rename_all = "camelCase")]
struct QueryData {
query: String,
#[serde(deserialize_with = "bytes_to_pg_text")]
#[serde(default)]
params: Vec<Box<RawValue>>,
params: Vec<Option<String>>,
#[serde(default)]
array_mode: Option<bool>,
}
@@ -58,13 +63,216 @@ struct BatchQueryData {
queries: Vec<QueryData>,
}
/// A request body: either one query or a batch of queries.
///
/// The enum is `untagged`, so serde tries the variants in declaration order
/// and keeps the first one that deserializes.
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum Payload {
/// A single `QueryData` object.
Single(QueryData),
/// A `BatchQueryData` object holding several queries.
Batch(BatchQueryData),
}
// HTTP header names used by this module. Only `NEON_REQUEST_ID` is visible to
// the parent module; the others are private to this file.
pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");
// Header that carries the Postgres connection URL.
static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
// The literal header value `true`.
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
where
D: serde::de::Deserializer<'de>,
{
// TODO: consider avoiding the allocation here.
let json: Vec<Value> = serde::de::Deserialize::deserialize(deserializer)?;
Ok(json_to_pg_text(json))
}
/// Errors raised by `get_conn_info` while deriving connection info from a
/// request's connection string and headers.
#[derive(Debug, thiserror::Error)]
pub(crate) enum ConnInfoError {
/// The named header was absent or could not be read as a string.
#[error("invalid header: {0}")]
InvalidHeader(&'static HeaderName),
/// The connection string is not a parseable URL.
#[error("invalid connection string: {0}")]
UrlParseError(#[from] url::ParseError),
/// The URL scheme is neither `postgres` nor `postgresql`.
#[error("incorrect scheme")]
IncorrectScheme,
/// The URL has no path segments to take a database name from.
#[error("missing database name")]
MissingDbName,
/// The first path segment (the database name) could not be obtained.
#[error("invalid database name")]
InvalidDbName,
/// The URL's username component is empty.
#[error("missing username")]
MissingUsername,
/// A percent-decoded URL component was not valid UTF-8.
// NOTE(review): because of `#[from]`, a `?` on database-name decoding also lands
// here and is reported as an invalid username; confirm that is intended.
#[error("invalid username: {0}")]
InvalidUsername(#[from] std::string::FromUtf8Error),
/// The request lacks the kind of credentials that is expected; the payload
/// names that kind.
#[error("missing authentication credentials: {0}")]
MissingCredentials(Credentials),
/// The URL host is absent or is an IP address rather than a domain name.
#[error("missing hostname")]
MissingHostname,
/// Wraps a `ComputeUserInfoParseError`.
#[error("invalid hostname: {0}")]
InvalidEndpoint(#[from] ComputeUserInfoParseError),
/// `endpoint_sni` produced no endpoint for the host name.
#[error("malformed endpoint")]
MalformedEndpoint,
}
/// The kind of credentials a request was expected to carry; used as the
/// payload of `ConnInfoError::MissingCredentials`.
#[derive(Debug, thiserror::Error)]
pub(crate) enum Credentials {
/// A password was expected.
#[error("required password")]
Password,
/// A bearer token in JWT format was expected.
#[error("required authorization bearer token in JWT format")]
BearerJwt,
}
impl ReportableError for ConnInfoError {
/// Every connection-info error is classified as `ErrorKind::User`.
fn get_error_kind(&self) -> ErrorKind {
ErrorKind::User
}
}
impl UserFacingError for ConnInfoError {
/// The client-facing text is the error's `Display` output, unchanged.
fn to_string_client(&self) -> String {
self.to_string()
}
}
/// Derives a `ConnInfoWithAuth` for a request from its headers.
///
/// The connection URL is read from the `neon-connection-string` header. As side
/// effects, the database name, role, endpoint id and user agent are recorded on
/// `ctx`, and the `accepted_connections_by_sni` metric is incremented.
///
/// When `tls` is provided, the endpoint is resolved from the host name with
/// `endpoint_sni` and the configured common names; otherwise it is the host
/// name up to its first `.`.
///
/// # Errors
///
/// Returns a `ConnInfoError` when the header or URL cannot be parsed, the scheme
/// is not `postgres`/`postgresql`, the database name, username or hostname is
/// missing, the endpoint cannot be resolved, or the supplied credentials do not
/// match `config.accept_jwts`.
fn get_conn_info(
config: &'static AuthenticationConfig,
ctx: &RequestContext,
headers: &HeaderMap,
tls: Option<&TlsConfig>,
) -> Result<ConnInfoWithAuth, ConnInfoError> {
// A header that is absent or not readable as a string is reported the same way.
let connection_string = headers
.get(&CONN_STRING)
.ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
let connection_url = Url::parse(connection_string)?;
let protocol = connection_url.scheme();
if protocol != "postgres" && protocol != "postgresql" {
return Err(ConnInfoError::IncorrectScheme);
}
let mut url_path = connection_url
.path_segments()
.ok_or(ConnInfoError::MissingDbName)?;
// The database name is the percent-decoded first path segment.
let dbname: DbName =
urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into();
ctx.set_dbname(dbname.clone());
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
if username.is_empty() {
return Err(ConnInfoError::MissingUsername);
}
ctx.set_user(username.clone());
// An `Authorization` header is checked first and must carry a `Bearer ` token;
// a URL password is only accepted when JWTs are not accepted.
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
if !config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
}
let auth = auth
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
AuthData::Jwt(
auth.strip_prefix("Bearer ")
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
.into(),
)
} else if let Some(pass) = connection_url.password() {
// wrong credentials provided
if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
}
// The password is percent-decoded as raw bytes (no UTF-8 requirement).
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
})
} else if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
} else {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
};
// IP-address hosts are rejected; only domain names can name an endpoint.
let endpoint = match connection_url.host() {
Some(url::Host::Domain(hostname)) => {
if let Some(tls) = tls {
endpoint_sni(hostname, &tls.common_names).ok_or(ConnInfoError::MalformedEndpoint)?
} else {
// Without TLS config, fall back to the text before the first `.`
// (or the whole host name when it contains no dot).
hostname
.split_once('.')
.map_or(hostname, |(prefix, _)| prefix)
.into()
}
}
Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
return Err(ConnInfoError::MissingHostname);
}
};
ctx.set_endpoint_id(endpoint.clone());
let pairs = connection_url.query_pairs();
let mut options = Option::None;
// NOTE(review): `params` is filled in below but never read afterwards in
// this function; confirm whether it is still needed.
let mut params = StartupMessageParams::default();
params.insert("user", &username);
params.insert("database", &dbname);
for (key, value) in pairs {
params.insert(&key, &value);
if key == "options" {
options = Some(NeonOptions::parse_options_raw(&value));
}
}
// check the URL that was used, for metrics
{
let host_endpoint = headers
// get the host header
.get("host")
// extract the domain
// NOTE(review): `split_once(':')?` yields `None` when the Host header has
// no `:port` suffix, so such requests are counted as `NoSni`; confirm
// that is intended.
.and_then(|h| {
let (host, _port) = h.to_str().ok()?.split_once(':')?;
Some(host)
})
// get the endpoint prefix
.map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix));
let kind = if host_endpoint == Some(&*endpoint) {
SniKind::Sni
} else {
SniKind::NoSni
};
let protocol = ctx.protocol();
Metrics::get()
.proxy
.accepted_connections_by_sni
.inc(SniGroup { protocol, kind });
}
// A missing or non-string User-Agent header is recorded as `None`.
ctx.set_user_agent(
headers
.get(hyper::header::USER_AGENT)
.and_then(|h| h.to_str().ok())
.map(Into::into),
);
let user_info = ComputeUserInfo {
endpoint,
user: username,
options: options.unwrap_or_default(),
};
let conn_info = ConnInfo { user_info, dbname };
Ok(ConnInfoWithAuth { conn_info, auth })
}
pub(crate) async fn handle(
config: &'static ProxyConfig,
ctx: RequestContext,
@@ -333,6 +541,45 @@ impl HttpCodeError for SqlOverHttpError {
}
}
/// Errors raised while reading and parsing an HTTP request body.
#[derive(Debug, thiserror::Error)]
pub(crate) enum ReadPayloadError {
/// The body could not be read from the connection.
#[error("could not read the HTTP request body: {0}")]
Read(#[from] hyper::Error),
/// The body exceeded `limit` bytes.
#[error("request is too large (max is {limit} bytes)")]
BodyTooLarge { limit: usize },
/// The body was read but is not valid JSON for the expected payload.
#[error("could not parse the HTTP request body: {0}")]
Parse(#[from] serde_json::Error),
}
impl From<ReadBodyError<hyper::Error>> for ReadPayloadError {
fn from(value: ReadBodyError<hyper::Error>) -> Self {
match value {
ReadBodyError::BodyTooLarge { limit } => Self::BodyTooLarge { limit },
ReadBodyError::Read(e) => Self::Read(e),
}
}
}
impl ReportableError for ReadPayloadError {
    /// Classifies the error: a failed body read is reported as a client disconnect,
    /// while an oversized or unparsable body is reported as a user error.
    fn get_error_kind(&self) -> ErrorKind {
        match self {
            Self::Read(_) => ErrorKind::ClientDisconnect,
            Self::BodyTooLarge { .. } | Self::Parse(_) => ErrorKind::User,
        }
    }
}
impl HttpCodeError for ReadPayloadError {
    /// Picks the HTTP status for the error: 413 for an oversized body,
    /// 400 for a body that could not be read or parsed.
    fn get_http_status_code(&self) -> StatusCode {
        match self {
            Self::BodyTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
            Self::Read(_) | Self::Parse(_) => StatusCode::BAD_REQUEST,
        }
    }
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum SqlOverHttpCancel {
#[error("query was cancelled")]
@@ -423,7 +670,14 @@ async fn handle_inner(
"handling interactive connection from client"
);
let conn_info = get_conn_info(&config.authentication_config, ctx, None, request.headers())?;
let conn_info = get_conn_info(
&config.authentication_config,
ctx,
request.headers(),
// todo: race condition?
// we're unlikely to change the common names.
config.tls_config.load().as_deref(),
)?;
info!(
user = conn_info.conn_info.user_info.user.as_str(),
"credentials"
@@ -486,14 +740,7 @@ async fn handle_db_inner(
.observe(HttpDirection::Request, body.len() as f64);
debug!(length = body.len(), "request payload read");
// try unbatched, then try batched.
let payload = if let Ok(batch) = serde_json::from_slice(&body) {
Payload::Batch(batch)
} else {
Payload::Single(serde_json::from_slice(&body)?)
};
let payload: Payload = serde_json::from_slice(&body)?;
Ok::<Payload, ReadPayloadError>(payload) // Adjust error type accordingly
}
.map_err(SqlOverHttpError::from),
@@ -516,17 +763,9 @@ async fn handle_db_inner(
ComputeCredentialKeys::JwtPayload(payload)
if backend.auth_backend.is_local_proxy() =>
{
#[cfg(feature = "testing")]
let disable_pg_session_jwt = config.disable_pg_session_jwt;
#[cfg(not(feature = "testing"))]
let disable_pg_session_jwt = false;
let mut client = backend
.connect_to_local_postgres(ctx, conn_info, disable_pg_session_jwt)
.await?;
if !disable_pg_session_jwt {
let (cli_inner, _dsc) = client.client_inner();
cli_inner.set_jwt_session(&payload).await?;
}
let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?;
let (cli_inner, _dsc) = client.client_inner();
cli_inner.set_jwt_session(&payload).await?;
Client::Local(client)
}
_ => {
@@ -625,6 +864,12 @@ static HEADERS_TO_FORWARD: &[&HeaderName] = &[
&TXN_DEFERRABLE,
];
/// Formats `id` as a lowercase hyphenated UUID and wraps it in a [`HeaderValue`].
///
/// # Panics
///
/// Panics if the formatted text is rejected by `HeaderValue::from_str`.
pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
    // Stack buffer sized for the hyphenated text form; no heap allocation for formatting.
    let mut buf = [0; uuid::fmt::Hyphenated::LENGTH];
    let text: &str = id.as_hyphenated().encode_lower(&mut buf);
    HeaderValue::from_str(text)
        .expect("uuid hyphenated format should be all valid header characters")
}
async fn handle_auth_broker_inner(
ctx: &RequestContext,
request: Request<Incoming>,
@@ -654,7 +899,7 @@ async fn handle_auth_broker_inner(
req = req.header(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id()));
let req = req
.body(body.map_err(|e| e).boxed()) //TODO: is there a potential for a regression here?
.body(body)
.expect("all headers and params received via hyper should be valid for request");
// todo: map body to count egress
@@ -881,8 +1126,7 @@ async fn query_to_json<T: GenericClient>(
) -> Result<(ReadyForQueryStatus, impl Serialize + use<T>), SqlOverHttpError> {
let query_start = Instant::now();
let query_params = json_to_pg_text(data.params);
let query_params = data.params;
let mut row_stream = client
.query_raw_txt(&data.query, query_params)
.await
@@ -1028,38 +1272,55 @@ mod tests {
#[test]
fn test_payload() {
let payload = "{\"query\":\"SELECT * FROM users WHERE name = ?\",\"params\":[\"test\"],\"arrayMode\":true}";
let QueryData {
query,
params,
array_mode,
} = serde_json::from_str(payload).unwrap();
let deserialized_payload: Payload = serde_json::from_str(payload).unwrap();
assert_eq!(query, "SELECT * FROM users WHERE name = ?");
assert_eq!(params[0].get(), "\"test\"");
assert!(array_mode.unwrap());
match deserialized_payload {
Payload::Single(QueryData {
query,
params,
array_mode,
}) => {
assert_eq!(query, "SELECT * FROM users WHERE name = ?");
assert_eq!(params, vec![Some(String::from("test"))]);
assert!(array_mode.unwrap());
}
Payload::Batch(_) => {
panic!("deserialization failed: case with single query, one param, and array mode")
}
}
let payload = "{\"queries\":[{\"query\":\"SELECT * FROM users0 WHERE name = ?\",\"params\":[\"test0\"], \"arrayMode\":false},{\"query\":\"SELECT * FROM users1 WHERE name = ?\",\"params\":[\"test1\"],\"arrayMode\":true}]}";
let BatchQueryData { queries } = serde_json::from_str(payload).unwrap();
let deserialized_payload: Payload = serde_json::from_str(payload).unwrap();
assert_eq!(queries.len(), 2);
for (i, query) in queries.into_iter().enumerate() {
assert_eq!(
query.query,
format!("SELECT * FROM users{i} WHERE name = ?")
);
assert_eq!(query.params[0].get(), &format!("\"test{i}\""));
assert_eq!(query.array_mode.unwrap(), i > 0);
match deserialized_payload {
Payload::Batch(BatchQueryData { queries }) => {
assert_eq!(queries.len(), 2);
for (i, query) in queries.into_iter().enumerate() {
assert_eq!(
query.query,
format!("SELECT * FROM users{i} WHERE name = ?")
);
assert_eq!(query.params, vec![Some(format!("test{i}"))]);
assert_eq!(query.array_mode.unwrap(), i > 0);
}
}
Payload::Single(_) => panic!("deserialization failed: case with multiple queries"),
}
let payload = "{\"query\":\"SELECT 1\"}";
let QueryData {
query,
params,
array_mode,
} = serde_json::from_str(payload).unwrap();
let deserialized_payload: Payload = serde_json::from_str(payload).unwrap();
assert_eq!(query, "SELECT 1");
assert!(params.is_empty());
assert!(array_mode.is_none());
match deserialized_payload {
Payload::Single(QueryData {
query,
params,
array_mode,
}) => {
assert_eq!(query, "SELECT 1");
assert_eq!(params, vec![]);
assert!(array_mode.is_none());
}
Payload::Batch(_) => panic!("deserialization failed: case with only one query"),
}
}
}

View File

@@ -199,27 +199,27 @@ impl<S: AsyncWrite + Unpin> PqStream<S> {
let probe_msg;
let mut msg = &*msg;
if let Some(ctx) = ctx
&& ctx.get_testodrome_id().is_some()
{
let tag = match error_kind {
ErrorKind::User => "client",
ErrorKind::ClientDisconnect => "client",
ErrorKind::RateLimit => "proxy",
ErrorKind::ServiceRateLimit => "proxy",
ErrorKind::Quota => "proxy",
ErrorKind::Service => "proxy",
ErrorKind::ControlPlane => "controlplane",
ErrorKind::Postgres => "other",
ErrorKind::Compute => "compute",
};
probe_msg = typed_json::json!({
"tag": tag,
"msg": msg,
"cold_start_info": ctx.cold_start_info(),
})
.to_string();
msg = &probe_msg;
if let Some(ctx) = ctx {
if ctx.get_testodrome_id().is_some() {
let tag = match error_kind {
ErrorKind::User => "client",
ErrorKind::ClientDisconnect => "client",
ErrorKind::RateLimit => "proxy",
ErrorKind::ServiceRateLimit => "proxy",
ErrorKind::Quota => "proxy",
ErrorKind::Service => "proxy",
ErrorKind::ControlPlane => "controlplane",
ErrorKind::Postgres => "other",
ErrorKind::Compute => "compute",
};
probe_msg = typed_json::json!({
"tag": tag,
"msg": msg,
"cold_start_info": ctx.cold_start_info(),
})
.to_string();
msg = &probe_msg;
}
}
// TODO: either preserve the error code from postgres, or assign error codes to proxy errors.

View File

@@ -18,10 +18,9 @@ use metrics::set_build_info_metric;
use remote_storage::RemoteStorageConfig;
use safekeeper::defaults::{
DEFAULT_CONTROL_FILE_SAVE_INTERVAL, DEFAULT_EVICTION_MIN_RESIDENT, DEFAULT_HEARTBEAT_TIMEOUT,
DEFAULT_HTTP_LISTEN_ADDR, DEFAULT_MAX_OFFLOADER_LAG_BYTES,
DEFAULT_MAX_REELECT_OFFLOADER_LAG_BYTES, DEFAULT_MAX_TIMELINE_DISK_USAGE_BYTES,
DEFAULT_PARTIAL_BACKUP_CONCURRENCY, DEFAULT_PARTIAL_BACKUP_TIMEOUT, DEFAULT_PG_LISTEN_ADDR,
DEFAULT_SSL_CERT_FILE, DEFAULT_SSL_CERT_RELOAD_PERIOD, DEFAULT_SSL_KEY_FILE,
DEFAULT_HTTP_LISTEN_ADDR, DEFAULT_MAX_OFFLOADER_LAG_BYTES, DEFAULT_PARTIAL_BACKUP_CONCURRENCY,
DEFAULT_PARTIAL_BACKUP_TIMEOUT, DEFAULT_PG_LISTEN_ADDR, DEFAULT_SSL_CERT_FILE,
DEFAULT_SSL_CERT_RELOAD_PERIOD, DEFAULT_SSL_KEY_FILE,
};
use safekeeper::wal_backup::WalBackup;
use safekeeper::{
@@ -139,15 +138,6 @@ struct Args {
/// Safekeeper won't be elected for WAL offloading if it is lagging for more than this value in bytes
#[arg(long, default_value_t = DEFAULT_MAX_OFFLOADER_LAG_BYTES)]
max_offloader_lag: u64,
/* BEGIN_HADRON */
/// Safekeeper will re-elect a new offloader if the current backup lagging for more than this value in bytes
#[arg(long, default_value_t = DEFAULT_MAX_REELECT_OFFLOADER_LAG_BYTES)]
max_reelect_offloader_lag_bytes: u64,
/// Safekeeper will stop accepting new WALs if the timeline disk usage exceeds this value in bytes.
/// Setting this value to 0 disables the limit.
#[arg(long, default_value_t = DEFAULT_MAX_TIMELINE_DISK_USAGE_BYTES)]
max_timeline_disk_usage_bytes: u64,
/* END_HADRON */
/// Number of max parallel WAL segments to be offloaded to remote storage.
#[arg(long, default_value = "5")]
wal_backup_parallel_jobs: usize,
@@ -401,10 +391,6 @@ async fn main() -> anyhow::Result<()> {
peer_recovery_enabled: args.peer_recovery,
remote_storage: args.remote_storage,
max_offloader_lag_bytes: args.max_offloader_lag,
/* BEGIN_HADRON */
max_reelect_offloader_lag_bytes: args.max_reelect_offloader_lag_bytes,
max_timeline_disk_usage_bytes: args.max_timeline_disk_usage_bytes,
/* END_HADRON */
wal_backup_enabled: !args.disable_wal_backup,
backup_parallel_jobs: args.wal_backup_parallel_jobs,
pg_auth,

View File

@@ -17,7 +17,6 @@ use utils::crashsafe::durable_rename;
use crate::control_file_upgrade::{downgrade_v10_to_v9, upgrade_control_file};
use crate::metrics::PERSIST_CONTROL_FILE_SECONDS;
use crate::metrics::WAL_DISK_IO_ERRORS;
use crate::state::{EvictionState, TimelinePersistentState};
pub const SK_MAGIC: u32 = 0xcafeceefu32;
@@ -193,14 +192,11 @@ impl TimelinePersistentState {
impl Storage for FileStorage {
/// Persists state durably to the underlying storage.
async fn persist(&mut self, s: &TimelinePersistentState) -> Result<()> {
// start timer for metrics
let _timer = PERSIST_CONTROL_FILE_SECONDS.start_timer();
// write data to safekeeper.control.partial
let control_partial_path = self.timeline_dir.join(CONTROL_FILE_NAME_PARTIAL);
let mut control_partial = File::create(&control_partial_path).await.with_context(|| {
/* BEGIN_HADRON */
WAL_DISK_IO_ERRORS.inc();
/*END_HADRON */
format!(
"failed to create partial control file at: {}",
&control_partial_path
@@ -210,24 +206,14 @@ impl Storage for FileStorage {
let buf: Vec<u8> = s.write_to_buf()?;
control_partial.write_all(&buf).await.with_context(|| {
/* BEGIN_HADRON */
WAL_DISK_IO_ERRORS.inc();
/*END_HADRON */
format!("failed to write safekeeper state into control file at: {control_partial_path}")
})?;
control_partial.flush().await.with_context(|| {
/* BEGIN_HADRON */
WAL_DISK_IO_ERRORS.inc();
/*END_HADRON */
format!("failed to flush safekeeper state into control file at: {control_partial_path}")
})?;
let control_path = self.timeline_dir.join(CONTROL_FILE_NAME);
durable_rename(&control_partial_path, &control_path, !self.no_sync)
.await
/* BEGIN_HADRON */
.inspect_err(|_| WAL_DISK_IO_ERRORS.inc())?;
/* END_HADRON */
durable_rename(&control_partial_path, &control_path, !self.no_sync).await?;
// update internal state
self.state = s.clone();

View File

@@ -61,13 +61,6 @@ pub mod defaults {
pub const DEFAULT_HEARTBEAT_TIMEOUT: &str = "5000ms";
pub const DEFAULT_MAX_OFFLOADER_LAG_BYTES: u64 = 128 * (1 << 20);
/* BEGIN_HADRON */
// Backup lag threshold, in bytes, at which the safekeeper re-elects the WAL offloader (leader).
// The default of 0 disables re-election.
pub const DEFAULT_MAX_REELECT_OFFLOADER_LAG_BYTES: u64 = 0;
// Maximum amount of WAL disk space, in bytes, that a single timeline may use on this safekeeper
// before the safekeeper begins to reject WALs. The default of 0 disables the limit.
pub const DEFAULT_MAX_TIMELINE_DISK_USAGE_BYTES: u64 = 0;
/* END_HADRON */
pub const DEFAULT_PARTIAL_BACKUP_TIMEOUT: &str = "15m";
pub const DEFAULT_CONTROL_FILE_SAVE_INTERVAL: &str = "300s";
pub const DEFAULT_PARTIAL_BACKUP_CONCURRENCY: &str = "5";
@@ -106,10 +99,6 @@ pub struct SafeKeeperConf {
pub peer_recovery_enabled: bool,
pub remote_storage: Option<RemoteStorageConfig>,
pub max_offloader_lag_bytes: u64,
/* BEGIN_HADRON */
pub max_reelect_offloader_lag_bytes: u64,
pub max_timeline_disk_usage_bytes: u64,
/* END_HADRON */
pub backup_parallel_jobs: usize,
pub wal_backup_enabled: bool,
pub pg_auth: Option<Arc<JwtAuth>>,
@@ -162,10 +151,6 @@ impl SafeKeeperConf {
sk_auth_token: None,
heartbeat_timeout: Duration::new(5, 0),
max_offloader_lag_bytes: defaults::DEFAULT_MAX_OFFLOADER_LAG_BYTES,
/* BEGIN_HADRON */
max_reelect_offloader_lag_bytes: defaults::DEFAULT_MAX_REELECT_OFFLOADER_LAG_BYTES,
max_timeline_disk_usage_bytes: defaults::DEFAULT_MAX_TIMELINE_DISK_USAGE_BYTES,
/* END_HADRON */
current_thread_runtime: false,
walsenders_keep_horizon: false,
partial_backup_timeout: Duration::from_secs(0),

View File

@@ -58,25 +58,6 @@ pub static FLUSH_WAL_SECONDS: Lazy<Histogram> = Lazy::new(|| {
)
.expect("Failed to register safekeeper_flush_wal_seconds histogram")
});
/* BEGIN_HADRON */
/// Integer counter `safekeeper_wal_disk_io_errors`: number of disk I/O errors hit while
/// creating and flushing WALs and control files.
///
/// The counter is registered the first time the static is accessed; the `expect` panics
/// if that registration fails.
pub static WAL_DISK_IO_ERRORS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"safekeeper_wal_disk_io_errors",
"Number of disk I/O errors when creating and flushing WALs and control files"
)
.expect("Failed to register safekeeper_wal_disk_io_errors counter")
});
/// Integer counter `safekeeper_wal_storage_limit_errors`: number of errors raised because a
/// timeline's WAL storage utilization exceeded the configured limit.
///
/// The counter is registered the first time the static is accessed; the `expect` panics
/// if that registration fails.
pub static WAL_STORAGE_LIMIT_ERRORS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"safekeeper_wal_storage_limit_errors",
// `concat!` only splits the long help text across two source lines.
concat!(
"Number of errors due to timeline WAL storage utilization exceeding configured limit. ",
"An increase in this metric indicates issues backing up or removing WALs."
)
)
.expect("Failed to register safekeeper_wal_storage_limit_errors counter")
});
/* END_HADRON */
pub static PERSIST_CONTROL_FILE_SECONDS: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"safekeeper_persist_control_file_seconds",
@@ -157,15 +138,6 @@ pub static BACKUP_ERRORS: Lazy<IntCounter> = Lazy::new(|| {
)
.expect("Failed to register safekeeper_backup_errors_total counter")
});
/* BEGIN_HADRON */
/// Integer counter `safekeeper_backup_reelect_leader_total`: number of times the backup
/// leader was re-elected.
///
/// The counter is registered the first time the static is accessed; the `expect` panics
/// if that registration fails.
pub static BACKUP_REELECT_LEADER_COUNT: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"safekeeper_backup_reelect_leader_total",
"Number of times the backup leader was reelected"
)
.expect("Failed to register safekeeper_backup_reelect_leader_total counter")
});
/* END_HADRON */
pub static BROKER_PUSH_ALL_UPDATES_SECONDS: Lazy<Histogram> = Lazy::new(|| {
register_histogram!(
"safekeeper_broker_push_update_seconds",

View File

@@ -16,7 +16,7 @@ use tokio::sync::mpsc::error::SendError;
use tokio::task::JoinHandle;
use tokio::time::MissedTickBehavior;
use tracing::{Instrument, error, info, info_span};
use utils::critical_timeline;
use utils::critical;
use utils::lsn::Lsn;
use utils::postgres_client::{Compression, InterpretedFormat};
use wal_decoder::models::{InterpretedWalRecord, InterpretedWalRecords};
@@ -268,8 +268,6 @@ impl InterpretedWalReader {
let (shard_notification_tx, shard_notification_rx) = tokio::sync::mpsc::unbounded_channel();
let ttid = wal_stream.ttid;
let reader = InterpretedWalReader {
wal_stream,
shard_senders: HashMap::from([(
@@ -302,11 +300,7 @@ impl InterpretedWalReader {
.inspect_err(|err| match err {
// TODO: we may want to differentiate these errors further.
InterpretedWalReaderError::Decode(_) => {
critical_timeline!(
ttid.tenant_id,
ttid.timeline_id,
"failed to read WAL record: {err:?}"
);
critical!("failed to decode WAL record: {err:?}");
}
err => error!("failed to read WAL record: {err}"),
})
@@ -369,14 +363,9 @@ impl InterpretedWalReader {
metric.dec();
}
let ttid = self.wal_stream.ttid;
match self.run_impl(start_pos).await {
Err(err @ InterpretedWalReaderError::Decode(_)) => {
critical_timeline!(
ttid.tenant_id,
ttid.timeline_id,
"failed to decode WAL record: {err:?}"
);
critical!("failed to decode WAL record: {err:?}");
}
Err(err) => error!("failed to read WAL record: {err}"),
Ok(()) => info!("interpreted wal reader exiting"),

View File

@@ -26,9 +26,7 @@ use utils::id::{NodeId, TenantId, TenantTimelineId};
use utils::lsn::Lsn;
use utils::sync::gate::Gate;
use crate::metrics::{
FullTimelineInfo, MISC_OPERATION_SECONDS, WAL_STORAGE_LIMIT_ERRORS, WalStorageMetrics,
};
use crate::metrics::{FullTimelineInfo, MISC_OPERATION_SECONDS, WalStorageMetrics};
use crate::rate_limit::RateLimiter;
use crate::receive_wal::WalReceivers;
use crate::safekeeper::{AcceptorProposerMessage, ProposerAcceptorMessage, SafeKeeper, TermLsn};
@@ -197,7 +195,7 @@ impl StateSK {
Ok(TimelineMembershipSwitchResponse {
previous_conf: result.previous_conf,
current_conf: result.current_conf,
last_log_term: self.state().acceptor_state.term,
term: self.state().acceptor_state.term,
flush_lsn: self.flush_lsn(),
})
}
@@ -1052,39 +1050,6 @@ impl WalResidentTimeline {
Ok(ss)
}
// BEGIN HADRON
/// Checks whether the disk usage of this timeline's WAL segment files exceeds the
/// configured limit (`max_timeline_disk_usage_bytes`).
///
/// A limit of 0 disables the check.
///
/// # Errors
///
/// Increments `WAL_STORAGE_LIMIT_ERRORS` and returns an error when the estimated usage
/// is above the limit.
fn hadron_check_disk_usage(
    &self,
    shared_state_locked: &mut WriteGuardSharedState<'_>,
) -> Result<()> {
    let max_timeline_disk_usage_bytes = self.conf.max_timeline_disk_usage_bytes;
    if max_timeline_disk_usage_bytes == 0 {
        return Ok(());
    }

    // The disk usage is calculated based on the number of segments between `last_removed_segno`
    // and the current flush LSN segment number. `last_removed_segno` is advanced after
    // unneeded WAL files are physically removed from disk (see `update_wal_removal_end()`
    // in `timeline_manager.rs`).
    let last_removed_segno = self.last_removed_segno.load(Ordering::Relaxed);
    let flush_lsn = shared_state_locked.sk.flush_lsn();
    let wal_seg_size = shared_state_locked.sk.state().server.wal_seg_size as u64;
    let current_segno = flush_lsn.segment_number(wal_seg_size as usize);

    // Saturating arithmetic: a plain `-` / `*` would panic in debug builds and wrap in
    // release builds if `last_removed_segno` were ever ahead of the flush segment or the
    // product overflowed, producing a meaningless usage figure.
    let segno_count = current_segno.saturating_sub(last_removed_segno);
    let disk_usage_bytes = segno_count.saturating_mul(wal_seg_size);

    if disk_usage_bytes > max_timeline_disk_usage_bytes {
        WAL_STORAGE_LIMIT_ERRORS.inc();
        bail!(
            "WAL storage utilization exceeds configured limit of {} bytes: current disk usage: {} bytes",
            max_timeline_disk_usage_bytes,
            disk_usage_bytes
        );
    }
    Ok(())
}
// END HADRON
/// Pass arrived message to the safekeeper.
pub async fn process_msg(
&self,
@@ -1097,13 +1062,6 @@ impl WalResidentTimeline {
let mut rmsg: Option<AcceptorProposerMessage>;
{
let mut shared_state = self.write_shared_state().await;
// BEGIN HADRON
// Errors from the `hadron_check_disk_usage()` function fail the process_msg() function, which
// gets propagated upward and terminates the entire WalAcceptor. This will cause postgres to
// disconnect from the safekeeper and reestablish another connection. Postgres will keep retrying
// safekeeper connections every second until it can successfully propose WAL to the SK again.
self.hadron_check_disk_usage(&mut shared_state)?;
// END HADRON
rmsg = shared_state.sk.safekeeper().process_msg(msg).await?;
// if this is AppendResponse, fill in proper hot standby feedback.

View File

@@ -26,9 +26,7 @@ use utils::id::{NodeId, TenantTimelineId};
use utils::lsn::Lsn;
use utils::{backoff, pausable_failpoint};
use crate::metrics::{
BACKED_UP_SEGMENTS, BACKUP_ERRORS, BACKUP_REELECT_LEADER_COUNT, WAL_BACKUP_TASKS,
};
use crate::metrics::{BACKED_UP_SEGMENTS, BACKUP_ERRORS, WAL_BACKUP_TASKS};
use crate::timeline::WalResidentTimeline;
use crate::timeline_manager::{Manager, StateSnapshot};
use crate::{SafeKeeperConf, WAL_BACKUP_RUNTIME};
@@ -72,9 +70,8 @@ pub(crate) async fn update_task(
need_backup: bool,
state: &StateSnapshot,
) {
/* BEGIN_HADRON */
let (offloader, election_dbg_str) = hadron_determine_offloader(mgr, state);
/* END_HADRON */
let (offloader, election_dbg_str) =
determine_offloader(&state.peers, state.backup_lsn, mgr.tli.ttid, &mgr.conf);
let elected_me = Some(mgr.conf.my_id) == offloader;
let should_task_run = need_backup && elected_me;
@@ -130,70 +127,6 @@ async fn shut_down_task(entry: &mut Option<WalBackupTaskHandle>) {
}
}
/* BEGIN_HADRON */
// Wraps the neon `determine_offloader` with a lag-based re-election step: if the backup lag
// (commit_lsn - backup_lsn) reaches the configured threshold, the currently elected offloader
// is dropped from the candidate list and the election is run again. This mitigates the issue
// below and also helps distribute the load across SKs.
//
// Background: we observe that the offloader can fail to upload a segment due to race conditions
// between XLOG SWITCH and PG starting to stream WALs. The wal_backup task then continuously
// fails to upload a full segment while the segment remains partial on disk. As a consequence,
// commit_lsn moves forward on all SKs but backup_lsn stays the same, and eventually all SKs
// run out of disk space. See go/sk-ood-xlog-switch for more details.
//
// Each SK makes the decision locally, but they are aware of each other's commit and backup lsns:
//   1. `determine_offloader` picks a SK, say SK-1.
//   2. Each SK checks whether commit_lsn - backup_lsn has reached the threshold; if so, it
//      removes SK-1 from the candidates and calls `determine_offloader` again.
//   3. SK-1 steps down and all SKs elect the same new leader.
//   4. After the backup has caught up, the leader becomes SK-1 again.
fn hadron_determine_offloader(mgr: &Manager, state: &StateSnapshot) -> (Option<NodeId>, String) {
    let conf = &mgr.conf;
    let lag_threshold = conf.max_reelect_offloader_lag_bytes;

    let (elected, dbg_str, caughtup_peers_count) =
        determine_offloader(&state.peers, state.backup_lsn, mgr.tli.ttid, conf);

    // Keep the first election result when nobody was elected, when there is at most one
    // caught-up peer, or when re-election is disabled (threshold of 0).
    let Some(lagging_sk_id) = elected else {
        return (elected, dbg_str);
    };
    if caughtup_peers_count <= 1 || lag_threshold == 0 {
        return (elected, dbg_str);
    }

    let Some(lag) = state.commit_lsn.checked_sub(state.backup_lsn) else {
        info!("Backup lag is None. Skipping re-election.");
        return (elected, dbg_str);
    };
    let backup_lag = lag.0;
    if backup_lag < lag_threshold {
        return (elected, dbg_str);
    }

    info!(
        "Electing a new leader: Backup lag is too high backup lsn lag {} threshold {}: {}",
        backup_lag, lag_threshold, dbg_str
    );
    BACKUP_REELECT_LEADER_COUNT.inc();

    // Re-run the election without the current offloader, whose lag is too high.
    let remaining_peers: Vec<_> = state
        .peers
        .iter()
        .filter(|peer| peer.sk_id != lagging_sk_id)
        .cloned()
        .collect();
    let (reelected, reelection_dbg_str, _) =
        determine_offloader(&remaining_peers, state.backup_lsn, mgr.tli.ttid, conf);
    (reelected, reelection_dbg_str)
}
/* END_HADRON */
/// The goal is to ensure that normally only one safekeepers offloads. However,
/// it is fine (and inevitable, as s3 doesn't provide CAS) that for some short
/// time we have several ones as they PUT the same files. Also,
@@ -208,13 +141,13 @@ fn determine_offloader(
wal_backup_lsn: Lsn,
ttid: TenantTimelineId,
conf: &SafeKeeperConf,
) -> (Option<NodeId>, String, usize) {
) -> (Option<NodeId>, String) {
// TODO: remove this once we fill newly joined safekeepers since backup_lsn.
let capable_peers = alive_peers
.iter()
.filter(|p| p.local_start_lsn <= wal_backup_lsn);
match capable_peers.clone().map(|p| p.commit_lsn).max() {
None => (None, "no connected peers to elect from".to_string(), 0),
None => (None, "no connected peers to elect from".to_string()),
Some(max_commit_lsn) => {
let threshold = max_commit_lsn
.checked_sub(conf.max_offloader_lag_bytes)
@@ -242,7 +175,6 @@ fn determine_offloader(
capable_peers_dbg,
caughtup_peers.len()
),
caughtup_peers.len(),
)
}
}
@@ -414,8 +346,6 @@ async fn backup_lsn_range(
anyhow::bail!("parallel_jobs must be >= 1");
}
pausable_failpoint!("backup-lsn-range-pausable");
let remote_timeline_path = &timeline.remote_path;
let start_lsn = *backup_lsn;
let segments = get_segments(start_lsn, end_lsn, wal_seg_size);

View File

@@ -1,16 +1,16 @@
use std::pin::Pin;
use std::task::{Context, Poll};
use crate::send_wal::EndWatch;
use crate::timeline::WalResidentTimeline;
use crate::wal_storage::WalReader;
use bytes::Bytes;
use futures::stream::BoxStream;
use futures::{Stream, StreamExt};
use safekeeper_api::Term;
use utils::id::TenantTimelineId;
use utils::lsn::Lsn;
use crate::send_wal::EndWatch;
use crate::timeline::WalResidentTimeline;
use crate::wal_storage::WalReader;
#[derive(PartialEq, Eq, Debug)]
pub(crate) struct WalBytes {
/// Raw PG WAL
@@ -37,8 +37,6 @@ struct PositionedWalReader {
pub(crate) struct StreamingWalReader {
stream: BoxStream<'static, WalOrReset>,
start_changed_tx: tokio::sync::watch::Sender<Lsn>,
// HADRON: Added TenantTimelineId for instrumentation purposes.
pub(crate) ttid: TenantTimelineId,
}
pub(crate) enum WalOrReset {
@@ -65,7 +63,6 @@ impl StreamingWalReader {
buffer_size: usize,
) -> Self {
let (start_changed_tx, start_changed_rx) = tokio::sync::watch::channel(start);
let ttid = tli.ttid;
let state = WalReaderStreamState {
tli,
@@ -110,7 +107,6 @@ impl StreamingWalReader {
Self {
stream,
start_changed_tx,
ttid,
}
}

View File

@@ -31,8 +31,7 @@ use utils::id::TenantTimelineId;
use utils::lsn::Lsn;
use crate::metrics::{
REMOVED_WAL_SEGMENTS, WAL_DISK_IO_ERRORS, WAL_STORAGE_OPERATION_SECONDS, WalStorageMetrics,
time_io_closure,
REMOVED_WAL_SEGMENTS, WAL_STORAGE_OPERATION_SECONDS, WalStorageMetrics, time_io_closure,
};
use crate::state::TimelinePersistentState;
use crate::wal_backup::{WalBackup, read_object, remote_timeline_path};
@@ -294,12 +293,9 @@ impl PhysicalStorage {
// half initialized segment, first bake it under tmp filename and
// then rename.
let tmp_path = self.timeline_dir.join("waltmp");
let file: File = File::create(&tmp_path).await.with_context(|| {
/* BEGIN_HADRON */
WAL_DISK_IO_ERRORS.inc();
/* END_HADRON */
format!("Failed to open tmp wal file {:?}", &tmp_path)
})?;
let file = File::create(&tmp_path)
.await
.with_context(|| format!("Failed to open tmp wal file {:?}", &tmp_path))?;
fail::fail_point!("sk-zero-segment", |_| {
info!("sk-zero-segment failpoint hit");
@@ -386,11 +382,7 @@ impl PhysicalStorage {
let flushed = self
.write_in_segment(segno, xlogoff, &buf[..bytes_write])
.await
/* BEGIN_HADRON */
.inspect_err(|_| WAL_DISK_IO_ERRORS.inc())?;
/* END_HADRON */
.await?;
self.write_lsn += bytes_write as u64;
if flushed {
self.flush_lsn = self.write_lsn;
@@ -499,11 +491,7 @@ impl Storage for PhysicalStorage {
}
if let Some(unflushed_file) = self.file.take() {
self.fdatasync_file(&unflushed_file)
.await
/* BEGIN_HADRON */
.inspect_err(|_| WAL_DISK_IO_ERRORS.inc())?;
/* END_HADRON */
self.fdatasync_file(&unflushed_file).await?;
self.file = Some(unflushed_file);
} else {
// We have unflushed data (write_lsn != flush_lsn), but no file. This

View File

@@ -159,10 +159,6 @@ pub fn run_server(os: NodeOs, disk: Arc<SafekeeperDisk>) -> Result<()> {
heartbeat_timeout: Duration::from_secs(0),
remote_storage: None,
max_offloader_lag_bytes: 0,
/* BEGIN_HADRON */
max_reelect_offloader_lag_bytes: 0,
max_timeline_disk_usage_bytes: 0,
/* END_HADRON */
wal_backup_enabled: false,
listen_pg_addr_tenant_only: None,
advertise_pg_addr: None,

Some files were not shown because too many files have changed in this diff Show More