Merge remote-tracking branch 'origin/main' into communicator-rewrite

Heikki Linnakangas
2025-07-05 16:59:51 +03:00
141 changed files with 5475 additions and 2033 deletions

View File

@@ -32,162 +32,14 @@ permissions:
contents: read
jobs:
build-pgxn:
if: |
inputs.pg_versions != '[]' || inputs.rebuild_everything ||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-macos') ||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-*') ||
github.ref_name == 'main'
timeout-minutes: 30
runs-on: macos-15
strategy:
matrix:
postgres-version: ${{ inputs.rebuild_everything && fromJSON('["v14", "v15", "v16", "v17"]') || fromJSON(inputs.pg_versions) }}
env:
# Use release build only, to have less debug info around
# Hence keeping target/ (and general cache size) smaller
BUILD_TYPE: release
steps:
- name: Harden the runner (Audit all outbound calls)
uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0
with:
egress-policy: audit
- name: Checkout main repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set pg ${{ matrix.postgres-version }} for caching
id: pg_rev
run: echo pg_rev=$(git rev-parse HEAD:vendor/postgres-${{ matrix.postgres-version }}) | tee -a "${GITHUB_OUTPUT}"
- name: Cache postgres ${{ matrix.postgres-version }} build
id: cache_pg
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
with:
path: pg_install/${{ matrix.postgres-version }}
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-pg-${{ matrix.postgres-version }}-${{ steps.pg_rev.outputs.pg_rev }}-${{ hashFiles('Makefile') }}
- name: Checkout submodule vendor/postgres-${{ matrix.postgres-version }}
if: steps.cache_pg.outputs.cache-hit != 'true'
run: |
git submodule init vendor/postgres-${{ matrix.postgres-version }}
git submodule update --depth 1 --recursive
- name: Install build dependencies
if: steps.cache_pg.outputs.cache-hit != 'true'
run: |
brew install flex bison openssl protobuf icu4c
- name: Set extra env for macOS
if: steps.cache_pg.outputs.cache-hit != 'true'
run: |
echo 'LDFLAGS=-L/usr/local/opt/openssl@3/lib' >> $GITHUB_ENV
echo 'CPPFLAGS=-I/usr/local/opt/openssl@3/include' >> $GITHUB_ENV
- name: Build Postgres ${{ matrix.postgres-version }}
if: steps.cache_pg.outputs.cache-hit != 'true'
run: |
make postgres-${{ matrix.postgres-version }} -j$(sysctl -n hw.ncpu)
- name: Build Neon Pg Ext ${{ matrix.postgres-version }}
if: steps.cache_pg.outputs.cache-hit != 'true'
run: |
make "neon-pg-ext-${{ matrix.postgres-version }}" -j$(sysctl -n hw.ncpu)
- name: Upload "pg_install/${{ matrix.postgres-version }}" artifact
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: pg_install--${{ matrix.postgres-version }}
path: pg_install/${{ matrix.postgres-version }}
# The artifact is supposed to be used by the next job in the same workflow,
# so there's no need to store it for too long.
retention-days: 1
build-walproposer-lib:
if: |
contains(inputs.pg_versions, 'v17') || inputs.rebuild_everything ||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-macos') ||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-*') ||
github.ref_name == 'main'
timeout-minutes: 30
runs-on: macos-15
needs: [build-pgxn]
env:
# Use release build only, to have less debug info around
# Hence keeping target/ (and general cache size) smaller
BUILD_TYPE: release
steps:
- name: Harden the runner (Audit all outbound calls)
uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0
with:
egress-policy: audit
- name: Checkout main repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set pg v17 for caching
id: pg_rev
run: echo pg_rev=$(git rev-parse HEAD:vendor/postgres-v17) | tee -a "${GITHUB_OUTPUT}"
- name: Download "pg_install/v17" artifact
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
with:
name: pg_install--v17
path: pg_install/v17
# `actions/download-artifact` doesn't preserve permissions:
# https://github.com/actions/download-artifact?tab=readme-ov-file#permission-loss
- name: Make pg_install/v*/bin/* executable
run: |
chmod +x pg_install/v*/bin/*
- name: Cache walproposer-lib
id: cache_walproposer_lib
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
with:
path: build/walproposer-lib
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-walproposer_lib-v17-${{ steps.pg_rev.outputs.pg_rev }}-${{ hashFiles('Makefile') }}
- name: Checkout submodule vendor/postgres-v17
if: steps.cache_walproposer_lib.outputs.cache-hit != 'true'
run: |
git submodule init vendor/postgres-v17
git submodule update --depth 1 --recursive
- name: Install build dependencies
if: steps.cache_walproposer_lib.outputs.cache-hit != 'true'
run: |
brew install flex bison openssl protobuf icu4c
- name: Set extra env for macOS
if: steps.cache_walproposer_lib.outputs.cache-hit != 'true'
run: |
echo 'LDFLAGS=-L/usr/local/opt/openssl@3/lib' >> $GITHUB_ENV
echo 'CPPFLAGS=-I/usr/local/opt/openssl@3/include' >> $GITHUB_ENV
- name: Build walproposer-lib (only for v17)
if: steps.cache_walproposer_lib.outputs.cache-hit != 'true'
run:
make walproposer-lib -j$(sysctl -n hw.ncpu) PG_INSTALL_CACHED=1
- name: Upload "build/walproposer-lib" artifact
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
with:
name: build--walproposer-lib
path: build/walproposer-lib
# The artifact is supposed to be used by the next job in the same workflow,
# so there's no need to store it for too long.
retention-days: 1
cargo-build:
make-all:
if: |
inputs.pg_versions != '[]' || inputs.rebuild_rust_code || inputs.rebuild_everything ||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-macos') ||
contains(github.event.pull_request.labels.*.name, 'run-extra-build-*') ||
github.ref_name == 'main'
timeout-minutes: 30
timeout-minutes: 60
runs-on: macos-15
needs: [build-pgxn, build-walproposer-lib]
env:
# Use release build only, to have less debug info around
# Hence keeping target/ (and general cache size) smaller
@@ -203,41 +55,53 @@ jobs:
with:
submodules: true
- name: Download "pg_install/v14" artifact
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
with:
name: pg_install--v14
path: pg_install/v14
- name: Download "pg_install/v15" artifact
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
with:
name: pg_install--v15
path: pg_install/v15
- name: Download "pg_install/v16" artifact
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
with:
name: pg_install--v16
path: pg_install/v16
- name: Download "pg_install/v17" artifact
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
with:
name: pg_install--v17
path: pg_install/v17
- name: Download "build/walproposer-lib" artifact
uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4.3.0
with:
name: build--walproposer-lib
path: build/walproposer-lib
# `actions/download-artifact` doesn't preserve permissions:
# https://github.com/actions/download-artifact?tab=readme-ov-file#permission-loss
- name: Make pg_install/v*/bin/* executable
- name: Install build dependencies
run: |
chmod +x pg_install/v*/bin/*
brew install flex bison openssl protobuf icu4c
- name: Set extra env for macOS
run: |
echo 'LDFLAGS=-L/usr/local/opt/openssl@3/lib' >> $GITHUB_ENV
echo 'CPPFLAGS=-I/usr/local/opt/openssl@3/include' >> $GITHUB_ENV
- name: Restore "pg_install/" cache
id: cache_pg
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
with:
path: pg_install
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ env.BUILD_TYPE }}-pg-install-v14-${{ hashFiles('Makefile', 'postgres.mk', 'vendor/revisions.json') }}
- name: Checkout vendor/postgres submodules
if: steps.cache_pg.outputs.cache-hit != 'true'
run: |
git submodule init
git submodule update --depth 1 --recursive
- name: Build Postgres
if: steps.cache_pg.outputs.cache-hit != 'true'
run: |
make postgres -j$(sysctl -n hw.ncpu)
# This isn't strictly necessary, but it makes the cached and non-cached builds more similar.
# When pg_install is restored from cache, there is no 'build/' directory. By removing it
# in a non-cached build too, we enforce that the rest of the steps don't depend on it,
# so that we notice any build caching bugs earlier.
- name: Remove build artifacts
if: steps.cache_pg.outputs.cache-hit != 'true'
run: |
rm -rf build
# Explicitly update the rust toolchain before running 'make'. The parallel make build can
# invoke 'cargo build' more than once in parallel, for different crates. That's OK, 'cargo'
# does its own locking to prevent concurrent builds from stepping on each other's
# toes. However, it will first try to update the toolchain, and that step is not locked the
# same way. To avoid two toolchain updates running in parallel and stepping on each other's
# toes, ensure that the toolchain is up-to-date beforehand.
- name: Update rust toolchain
run: |
rustup --version &&
rustup update &&
rustup show
- name: Cache cargo deps
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
@@ -249,17 +113,12 @@ jobs:
target
key: v1-${{ runner.os }}-${{ runner.arch }}-cargo-${{ hashFiles('./Cargo.lock') }}-${{ hashFiles('./rust-toolchain.toml') }}-rust
- name: Install build dependencies
run: |
brew install flex bison openssl protobuf icu4c
- name: Set extra env for macOS
run: |
echo 'LDFLAGS=-L/usr/local/opt/openssl@3/lib' >> $GITHUB_ENV
echo 'CPPFLAGS=-I/usr/local/opt/openssl@3/include' >> $GITHUB_ENV
- name: Run cargo build
run: cargo build --all --release -j$(sysctl -n hw.ncpu)
# Build the neon-specific postgres extensions, and all the Rust bits.
#
# Pass PG_INSTALL_CACHED=1 because PostgreSQL was already built and cached
# separately.
- name: Build all
run: PG_INSTALL_CACHED=1 BUILD_TYPE=release make -j$(sysctl -n hw.ncpu) all
- name: Check that no warnings are produced
run: ./run_clippy.sh

1
.gitignore vendored
View File

@@ -6,6 +6,7 @@
/tmp_check_cli
__pycache__/
test_output/
neon_previous/
.vscode
.idea
*.swp

35
Cargo.lock generated
View File

@@ -1213,12 +1213,12 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cbindgen"
version = "0.28.0"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eadd868a2ce9ca38de7eeafdcec9c7065ef89b42b32f0839278d55f35c54d1ff"
checksum = "975982cdb7ad6a142be15bdf84aea7ec6a9e5d4d797c004d43185b24cfe4e684"
dependencies = [
"clap",
"heck 0.4.1",
"heck",
"indexmap 2.9.0",
"log",
"proc-macro2",
@@ -1356,7 +1356,7 @@ version = "4.5.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2c7947ae4cc3d851207c1adb5b5e260ff0cca11446b1d6d1423788e442257ce"
dependencies = [
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"syn",
@@ -1424,7 +1424,7 @@ dependencies = [
[[package]]
name = "communicator"
version = "0.1.0"
version = "0.0.0"
dependencies = [
"atomic_enum",
"axum 0.8.4",
@@ -1449,6 +1449,7 @@ dependencies = [
"tracing-subscriber",
"uring-common",
"utils",
"workspace_hack",
]
[[package]]
@@ -1489,6 +1490,7 @@ dependencies = [
"fail",
"flate2",
"futures",
"hostname-validator",
"http 1.3.1",
"indexmap 2.9.0",
"itertools 0.10.5",
@@ -2169,7 +2171,7 @@ checksum = "139ae9aca7527f85f26dd76483eb38533fd84bd571065da1739656ef71c5ff5b"
dependencies = [
"darling",
"either",
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"syn",
@@ -3001,12 +3003,6 @@ dependencies = [
"http 1.3.1",
]
[[package]]
name = "heck"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]]
name = "heck"
version = "0.5.0"
@@ -3072,6 +3068,12 @@ dependencies = [
"windows-link",
]
[[package]]
name = "hostname-validator"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f558a64ac9af88b5ba400d99b579451af0d39c6d360980045b91aac966d705e2"
[[package]]
name = "http"
version = "0.2.12"
@@ -3993,7 +3995,7 @@ version = "0.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e6777fc80a575f9503d908c8b498782a6c3ee88a06cb416dc3941401e43b94"
dependencies = [
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"syn",
@@ -5584,7 +5586,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4"
dependencies = [
"bytes",
"heck 0.5.0",
"heck",
"itertools 0.12.1",
"log",
"multimap",
@@ -5604,7 +5606,7 @@ version = "0.13.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf"
dependencies = [
"heck 0.5.0",
"heck",
"itertools 0.14.0",
"log",
"multimap",
@@ -7547,7 +7549,7 @@ version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be"
dependencies = [
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"rustversion",
@@ -9449,6 +9451,7 @@ dependencies = [
"num-iter",
"num-rational",
"num-traits",
"once_cell",
"p256 0.13.2",
"parquet",
"prettyplease",

View File

@@ -263,7 +263,6 @@ desim = { version = "0.1", path = "./libs/desim" }
endpoint_storage = { version = "0.0.1", path = "./endpoint_storage/" }
http-utils = { version = "0.1", path = "./libs/http-utils/" }
metrics = { version = "0.1", path = "./libs/metrics/" }
neonart = { version = "0.1", path = "./libs/neonart/" }
neon-shmem = { version = "0.1", path = "./libs/neon-shmem/" }
pageserver = { path = "./pageserver" }
pageserver_api = { version = "0.1", path = "./libs/pageserver_api/" }
@@ -295,7 +294,7 @@ walproposer = { version = "0.1", path = "./libs/walproposer/" }
workspace_hack = { version = "0.1", path = "./workspace_hack/" }
## Build dependencies
cbindgen = "0.28.0"
cbindgen = "0.29.0"
criterion = "0.5.1"
rcgen = "0.13"
rstest = "0.18"

View File

@@ -30,7 +30,18 @@ ARG BASE_IMAGE_SHA=debian:${DEBIAN_FLAVOR}
ARG BASE_IMAGE_SHA=${BASE_IMAGE_SHA/debian:bookworm-slim/debian@$BOOKWORM_SLIM_SHA}
ARG BASE_IMAGE_SHA=${BASE_IMAGE_SHA/debian:bullseye-slim/debian@$BULLSEYE_SLIM_SHA}
# Build Postgres
# Naive way:
#
# 1. COPY . .
# 2. make neon-pg-ext
# 3. cargo build <storage binaries>
#
# But to enable docker to cache intermediate layers, we perform a few preparatory steps:
#
# - Build all postgres versions, depending on just the contents of vendor/
# - Use cargo chef to build all rust dependencies
# 1. Build all postgres versions
FROM $REPOSITORY/$IMAGE:$TAG AS pg-build
WORKDIR /home/nonroot
@@ -38,17 +49,15 @@ COPY --chown=nonroot vendor/postgres-v14 vendor/postgres-v14
COPY --chown=nonroot vendor/postgres-v15 vendor/postgres-v15
COPY --chown=nonroot vendor/postgres-v16 vendor/postgres-v16
COPY --chown=nonroot vendor/postgres-v17 vendor/postgres-v17
COPY --chown=nonroot pgxn pgxn
COPY --chown=nonroot Makefile Makefile
COPY --chown=nonroot postgres.mk postgres.mk
COPY --chown=nonroot scripts/ninstall.sh scripts/ninstall.sh
ENV BUILD_TYPE=release
RUN set -e \
&& mold -run make -j $(nproc) -s neon-pg-ext \
&& tar -C pg_install -czf /home/nonroot/postgres_install.tar.gz .
&& mold -run make -j $(nproc) -s postgres
# Prepare cargo-chef recipe
# 2. Prepare cargo-chef recipe
FROM $REPOSITORY/$IMAGE:$TAG AS plan
WORKDIR /home/nonroot
@@ -56,23 +65,22 @@ COPY --chown=nonroot . .
RUN cargo chef prepare --recipe-path recipe.json
# Build neon binaries
# Main build image
FROM $REPOSITORY/$IMAGE:$TAG AS build
WORKDIR /home/nonroot
ARG GIT_VERSION=local
ARG BUILD_TAG
COPY --from=pg-build /home/nonroot/pg_install/v14/include/postgresql/server pg_install/v14/include/postgresql/server
COPY --from=pg-build /home/nonroot/pg_install/v15/include/postgresql/server pg_install/v15/include/postgresql/server
COPY --from=pg-build /home/nonroot/pg_install/v16/include/postgresql/server pg_install/v16/include/postgresql/server
COPY --from=pg-build /home/nonroot/pg_install/v17/include/postgresql/server pg_install/v17/include/postgresql/server
COPY --from=plan /home/nonroot/recipe.json recipe.json
ARG ADDITIONAL_RUSTFLAGS=""
# 3. Build cargo dependencies. Note that this step doesn't depend on anything else than
# `recipe.json`, so the layer can be reused as long as none of the dependencies change.
COPY --from=plan /home/nonroot/recipe.json recipe.json
RUN set -e \
&& RUSTFLAGS="-Clinker=clang -Clink-arg=-fuse-ld=mold -Clink-arg=-Wl,--no-rosegment -Cforce-frame-pointers=yes ${ADDITIONAL_RUSTFLAGS}" cargo chef cook --locked --release --recipe-path recipe.json
# Perform the main build. We reuse the Postgres build artifacts from the intermediate 'pg-build'
# layer, and the cargo dependencies built in the previous step.
COPY --chown=nonroot --from=pg-build /home/nonroot/pg_install/ pg_install
COPY --chown=nonroot . .
RUN set -e \
@@ -87,10 +95,10 @@ RUN set -e \
--bin endpoint_storage \
--bin neon_local \
--bin storage_scrubber \
--locked --release
--locked --release \
&& mold -run make -j $(nproc) -s neon-pg-ext
# Build final image
#
# Assemble the final image
FROM $BASE_IMAGE_SHA
WORKDIR /data
@@ -130,12 +138,15 @@ COPY --from=build --chown=neon:neon /home/nonroot/target/release/proxy
COPY --from=build --chown=neon:neon /home/nonroot/target/release/endpoint_storage /usr/local/bin
COPY --from=build --chown=neon:neon /home/nonroot/target/release/neon_local /usr/local/bin
COPY --from=build --chown=neon:neon /home/nonroot/target/release/storage_scrubber /usr/local/bin
COPY --from=build /home/nonroot/pg_install/v14 /usr/local/v14/
COPY --from=build /home/nonroot/pg_install/v15 /usr/local/v15/
COPY --from=build /home/nonroot/pg_install/v16 /usr/local/v16/
COPY --from=build /home/nonroot/pg_install/v17 /usr/local/v17/
COPY --from=pg-build /home/nonroot/pg_install/v14 /usr/local/v14/
COPY --from=pg-build /home/nonroot/pg_install/v15 /usr/local/v15/
COPY --from=pg-build /home/nonroot/pg_install/v16 /usr/local/v16/
COPY --from=pg-build /home/nonroot/pg_install/v17 /usr/local/v17/
COPY --from=pg-build /home/nonroot/postgres_install.tar.gz /data/
# Deprecated: Old deployment scripts use this tarball which contains all the Postgres binaries.
# That's obsolete, since all the same files are also present under /usr/local/v*. But to keep the
# old scripts working for now, create the tarball.
RUN tar -C /usr/local -cvzf /data/postgres_install.tar.gz v14 v15 v16 v17
# By default, pageserver uses `.neon/` working directory in WORKDIR, so create one and fill it with the dummy config.
# Now, when `docker run ... pageserver` is run, it can start without errors, yet will have some default dummy values.

View File

@@ -109,7 +109,7 @@ all: neon postgres-install neon-pg-ext
### Neon Rust bits
#
# The 'postgres_ffi' depends on the Postgres headers.
# The 'postgres_ffi' crate depends on the Postgres headers.
.PHONY: neon
neon: postgres-headers-install walproposer-lib cargo-target-dir
+@echo "Compiling Neon"
@@ -122,7 +122,7 @@ cargo-target-dir:
test -e target/CACHEDIR.TAG || echo "$(CACHEDIR_TAG_CONTENTS)" > target/CACHEDIR.TAG
.PHONY: neon-pg-ext-%
neon-pg-ext-%: postgres-install-%
neon-pg-ext-%: postgres-install-% cargo-target-dir
+@echo "Compiling neon-specific Postgres extensions for $*"
mkdir -p $(BUILD_DIR)/pgxn-$*
$(MAKE) PG_CONFIG="$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config" COPT='$(COPT)' \

View File

@@ -1572,6 +1572,7 @@ RUN make -j $(getconf _NPROCESSORS_ONLN) && \
FROM build-deps AS pgaudit-src
ARG PG_VERSION
WORKDIR /ext-src
COPY "compute/patches/pgaudit-parallel_workers-${PG_VERSION}.patch" .
RUN case "${PG_VERSION}" in \
"v14") \
export PGAUDIT_VERSION=1.6.3 \
@@ -1594,7 +1595,8 @@ RUN case "${PG_VERSION}" in \
esac && \
wget https://github.com/pgaudit/pgaudit/archive/refs/tags/${PGAUDIT_VERSION}.tar.gz -O pgaudit.tar.gz && \
echo "${PGAUDIT_CHECKSUM} pgaudit.tar.gz" | sha256sum --check && \
mkdir pgaudit-src && cd pgaudit-src && tar xzf ../pgaudit.tar.gz --strip-components=1 -C .
mkdir pgaudit-src && cd pgaudit-src && tar xzf ../pgaudit.tar.gz --strip-components=1 -C . && \
patch -p1 < "/ext-src/pgaudit-parallel_workers-${PG_VERSION}.patch"
FROM pg-build AS pgaudit-build
COPY --from=pgaudit-src /ext-src/ /ext-src/
@@ -1634,11 +1636,14 @@ RUN make install USE_PGXS=1 -j $(getconf _NPROCESSORS_ONLN)
# compile neon extensions
#
#########################################################################################
FROM pg-build AS neon-ext-build
FROM pg-build-with-cargo AS neon-ext-build
ARG PG_VERSION
COPY pgxn/ pgxn/
RUN make -j $(getconf _NPROCESSORS_ONLN) -C pgxn -s install-compute
USER root
COPY . .
RUN make -j $(getconf _NPROCESSORS_ONLN) -C pgxn -s install-compute \
BUILD_TYPE=release CARGO_BUILD_FLAGS="--locked --release" NEON_CARGO_ARTIFACT_TARGET_DIR="$(pwd)/target/release"
#########################################################################################
#
@@ -1983,7 +1988,7 @@ RUN apt update && \
locales \
lsof \
procps \
rsyslog \
rsyslog-gnutls \
screen \
tcpdump \
$VERSION_INSTALLS && \

View File

@@ -8,6 +8,8 @@
import 'sql_exporter/compute_logical_snapshot_files.libsonnet',
import 'sql_exporter/compute_logical_snapshots_bytes.libsonnet',
import 'sql_exporter/compute_max_connections.libsonnet',
import 'sql_exporter/compute_pg_oldest_frozen_xid_age.libsonnet',
import 'sql_exporter/compute_pg_oldest_mxid_age.libsonnet',
import 'sql_exporter/compute_receive_lsn.libsonnet',
import 'sql_exporter/compute_subscriptions_count.libsonnet',
import 'sql_exporter/connection_counts.libsonnet',

View File

@@ -0,0 +1,13 @@
{
metric_name: 'compute_pg_oldest_frozen_xid_age',
type: 'gauge',
help: 'Age of oldest XIDs that have not been frozen by VACUUM. An indicator of how long it has been since VACUUM last ran.',
key_labels: [
'database_name',
],
value_label: 'metric',
values: [
'frozen_xid_age',
],
query: importstr 'sql_exporter/compute_pg_oldest_frozen_xid_age.sql',
}

View File

@@ -0,0 +1,4 @@
SELECT datname database_name,
age(datfrozenxid) frozen_xid_age
FROM pg_database
ORDER BY frozen_xid_age DESC LIMIT 10;

View File

@@ -0,0 +1,13 @@
{
metric_name: 'compute_pg_oldest_mxid_age',
type: 'gauge',
help: 'Age of oldest MXIDs that have not been replaced by VACUUM. An indicator of how long it has been since VACUUM last ran.',
key_labels: [
'database_name',
],
value_label: 'metric',
values: [
'min_mxid_age',
],
query: importstr 'sql_exporter/compute_pg_oldest_mxid_age.sql',
}

View File

@@ -0,0 +1,4 @@
SELECT datname database_name,
mxid_age(datminmxid) min_mxid_age
FROM pg_database
ORDER BY min_mxid_age DESC LIMIT 10;

View File

@@ -1,8 +1,8 @@
diff --git a/sql/anon.sql b/sql/anon.sql
index 0cdc769..f6cc950 100644
index 0cdc769..b450327 100644
--- a/sql/anon.sql
+++ b/sql/anon.sql
@@ -1141,3 +1141,8 @@ $$
@@ -1141,3 +1141,15 @@ $$
-- TODO : https://en.wikipedia.org/wiki/L-diversity
-- TODO : https://en.wikipedia.org/wiki/T-closeness
@@ -11,6 +11,13 @@ index 0cdc769..f6cc950 100644
+
+GRANT ALL ON SCHEMA anon to neon_superuser;
+GRANT ALL ON ALL TABLES IN SCHEMA anon TO neon_superuser;
+
+DO $$
+BEGIN
+ IF current_setting('server_version_num')::int >= 150000 THEN
+ GRANT SET ON PARAMETER anon.transparent_dynamic_masking TO neon_superuser;
+ END IF;
+END $$;
diff --git a/sql/init.sql b/sql/init.sql
index 7da6553..9b6164b 100644
--- a/sql/init.sql

View File

@@ -0,0 +1,143 @@
commit 7220bb3a3f23fa27207d77562dcc286f9a123313
Author: Tristan Partin <tristan.partin@databricks.com>
Date: 2025-06-23 02:09:31 +0000
Disable logging in parallel workers
When a query uses parallel workers, pgaudit will log the same query for
every parallel worker. This is undesireable since it can result in log
amplification for queries that use parallel workers.
Signed-off-by: Tristan Partin <tristan.partin@databricks.com>
diff --git a/expected/pgaudit.out b/expected/pgaudit.out
index baa8011..a601375 100644
--- a/expected/pgaudit.out
+++ b/expected/pgaudit.out
@@ -2563,6 +2563,37 @@ COMMIT;
NOTICE: AUDIT: SESSION,12,4,MISC,COMMIT,,,COMMIT;,<not logged>
DROP TABLE part_test;
NOTICE: AUDIT: SESSION,13,1,DDL,DROP TABLE,,,DROP TABLE part_test;,<not logged>
+--
+-- Test logging in parallel workers
+SET pgaudit.log = 'read';
+SET pgaudit.log_client = on;
+SET pgaudit.log_level = 'notice';
+-- Force parallel execution for testing
+SET max_parallel_workers_per_gather = 2;
+SET parallel_tuple_cost = 0;
+SET parallel_setup_cost = 0;
+SET min_parallel_table_scan_size = 0;
+SET min_parallel_index_scan_size = 0;
+-- Create table with enough data to trigger parallel execution
+CREATE TABLE parallel_test (id int, data text);
+INSERT INTO parallel_test SELECT generate_series(1, 1000), 'test data';
+SELECT count(*) FROM parallel_test;
+NOTICE: AUDIT: SESSION,14,1,READ,SELECT,,,SELECT count(*) FROM parallel_test;,<not logged>
+ count
+-------
+ 1000
+(1 row)
+
+-- Cleanup parallel test
+DROP TABLE parallel_test;
+RESET max_parallel_workers_per_gather;
+RESET parallel_tuple_cost;
+RESET parallel_setup_cost;
+RESET min_parallel_table_scan_size;
+RESET min_parallel_index_scan_size;
+RESET pgaudit.log;
+RESET pgaudit.log_client;
+RESET pgaudit.log_level;
-- Cleanup
-- Set client_min_messages up to warning to avoid noise
SET client_min_messages = 'warning';
diff --git a/pgaudit.c b/pgaudit.c
index 5e6fd38..ac9ded2 100644
--- a/pgaudit.c
+++ b/pgaudit.c
@@ -11,6 +11,7 @@
#include "postgres.h"
#include "access/htup_details.h"
+#include "access/parallel.h"
#include "access/sysattr.h"
#include "access/xact.h"
#include "access/relation.h"
@@ -1303,7 +1304,7 @@ pgaudit_ExecutorStart_hook(QueryDesc *queryDesc, int eflags)
{
AuditEventStackItem *stackItem = NULL;
- if (!internalStatement)
+ if (!internalStatement && !IsParallelWorker())
{
/* Push the audit even onto the stack */
stackItem = stack_push();
@@ -1384,7 +1385,7 @@ pgaudit_ExecutorCheckPerms_hook(List *rangeTabls, bool abort)
/* Log DML if the audit role is valid or session logging is enabled */
if ((auditOid != InvalidOid || auditLogBitmap != 0) &&
- !IsAbortedTransactionBlockState())
+ !IsAbortedTransactionBlockState() && !IsParallelWorker())
{
/* If auditLogRows is on, wait for rows processed to be set */
if (auditLogRows && auditEventStack != NULL)
@@ -1438,7 +1439,7 @@ pgaudit_ExecutorRun_hook(QueryDesc *queryDesc, ScanDirection direction, uint64 c
else
standard_ExecutorRun(queryDesc, direction, count, execute_once);
- if (auditLogRows && !internalStatement)
+ if (auditLogRows && !internalStatement && !IsParallelWorker())
{
/* Find an item from the stack by the query memory context */
stackItem = stack_find_context(queryDesc->estate->es_query_cxt);
@@ -1458,7 +1459,7 @@ pgaudit_ExecutorEnd_hook(QueryDesc *queryDesc)
AuditEventStackItem *stackItem = NULL;
AuditEventStackItem *auditEventStackFull = NULL;
- if (auditLogRows && !internalStatement)
+ if (auditLogRows && !internalStatement && !IsParallelWorker())
{
/* Find an item from the stack by the query memory context */
stackItem = stack_find_context(queryDesc->estate->es_query_cxt);
diff --git a/sql/pgaudit.sql b/sql/pgaudit.sql
index cc1374a..1870a60 100644
--- a/sql/pgaudit.sql
+++ b/sql/pgaudit.sql
@@ -1612,6 +1612,36 @@ COMMIT;
DROP TABLE part_test;
+--
+-- Test logging in parallel workers
+SET pgaudit.log = 'read';
+SET pgaudit.log_client = on;
+SET pgaudit.log_level = 'notice';
+
+-- Force parallel execution for testing
+SET max_parallel_workers_per_gather = 2;
+SET parallel_tuple_cost = 0;
+SET parallel_setup_cost = 0;
+SET min_parallel_table_scan_size = 0;
+SET min_parallel_index_scan_size = 0;
+
+-- Create table with enough data to trigger parallel execution
+CREATE TABLE parallel_test (id int, data text);
+INSERT INTO parallel_test SELECT generate_series(1, 1000), 'test data';
+
+SELECT count(*) FROM parallel_test;
+
+-- Cleanup parallel test
+DROP TABLE parallel_test;
+RESET max_parallel_workers_per_gather;
+RESET parallel_tuple_cost;
+RESET parallel_setup_cost;
+RESET min_parallel_table_scan_size;
+RESET min_parallel_index_scan_size;
+RESET pgaudit.log;
+RESET pgaudit.log_client;
+RESET pgaudit.log_level;
+
-- Cleanup
-- Set client_min_messages up to warning to avoid noise
SET client_min_messages = 'warning';

View File

@@ -0,0 +1,143 @@
commit 29dc2847f6255541992f18faf8a815dfab79631a
Author: Tristan Partin <tristan.partin@databricks.com>
Date: 2025-06-23 02:09:31 +0000
Disable logging in parallel workers
When a query uses parallel workers, pgaudit will log the same query for
every parallel worker. This is undesirable since it can result in log
amplification for queries that use parallel workers.
Signed-off-by: Tristan Partin <tristan.partin@databricks.com>
diff --git a/expected/pgaudit.out b/expected/pgaudit.out
index b22560b..73f0327 100644
--- a/expected/pgaudit.out
+++ b/expected/pgaudit.out
@@ -2563,6 +2563,37 @@ COMMIT;
NOTICE: AUDIT: SESSION,12,4,MISC,COMMIT,,,COMMIT;,<not logged>
DROP TABLE part_test;
NOTICE: AUDIT: SESSION,13,1,DDL,DROP TABLE,,,DROP TABLE part_test;,<not logged>
+--
+-- Test logging in parallel workers
+SET pgaudit.log = 'read';
+SET pgaudit.log_client = on;
+SET pgaudit.log_level = 'notice';
+-- Force parallel execution for testing
+SET max_parallel_workers_per_gather = 2;
+SET parallel_tuple_cost = 0;
+SET parallel_setup_cost = 0;
+SET min_parallel_table_scan_size = 0;
+SET min_parallel_index_scan_size = 0;
+-- Create table with enough data to trigger parallel execution
+CREATE TABLE parallel_test (id int, data text);
+INSERT INTO parallel_test SELECT generate_series(1, 1000), 'test data';
+SELECT count(*) FROM parallel_test;
+NOTICE: AUDIT: SESSION,14,1,READ,SELECT,,,SELECT count(*) FROM parallel_test;,<not logged>
+ count
+-------
+ 1000
+(1 row)
+
+-- Cleanup parallel test
+DROP TABLE parallel_test;
+RESET max_parallel_workers_per_gather;
+RESET parallel_tuple_cost;
+RESET parallel_setup_cost;
+RESET min_parallel_table_scan_size;
+RESET min_parallel_index_scan_size;
+RESET pgaudit.log;
+RESET pgaudit.log_client;
+RESET pgaudit.log_level;
-- Cleanup
-- Set client_min_messages up to warning to avoid noise
SET client_min_messages = 'warning';
diff --git a/pgaudit.c b/pgaudit.c
index 5e6fd38..ac9ded2 100644
--- a/pgaudit.c
+++ b/pgaudit.c
@@ -11,6 +11,7 @@
#include "postgres.h"
#include "access/htup_details.h"
+#include "access/parallel.h"
#include "access/sysattr.h"
#include "access/xact.h"
#include "access/relation.h"
@@ -1303,7 +1304,7 @@ pgaudit_ExecutorStart_hook(QueryDesc *queryDesc, int eflags)
{
AuditEventStackItem *stackItem = NULL;
- if (!internalStatement)
+ if (!internalStatement && !IsParallelWorker())
{
/* Push the audit even onto the stack */
stackItem = stack_push();
@@ -1384,7 +1385,7 @@ pgaudit_ExecutorCheckPerms_hook(List *rangeTabls, bool abort)
/* Log DML if the audit role is valid or session logging is enabled */
if ((auditOid != InvalidOid || auditLogBitmap != 0) &&
- !IsAbortedTransactionBlockState())
+ !IsAbortedTransactionBlockState() && !IsParallelWorker())
{
/* If auditLogRows is on, wait for rows processed to be set */
if (auditLogRows && auditEventStack != NULL)
@@ -1438,7 +1439,7 @@ pgaudit_ExecutorRun_hook(QueryDesc *queryDesc, ScanDirection direction, uint64 c
else
standard_ExecutorRun(queryDesc, direction, count, execute_once);
- if (auditLogRows && !internalStatement)
+ if (auditLogRows && !internalStatement && !IsParallelWorker())
{
/* Find an item from the stack by the query memory context */
stackItem = stack_find_context(queryDesc->estate->es_query_cxt);
@@ -1458,7 +1459,7 @@ pgaudit_ExecutorEnd_hook(QueryDesc *queryDesc)
AuditEventStackItem *stackItem = NULL;
AuditEventStackItem *auditEventStackFull = NULL;
- if (auditLogRows && !internalStatement)
+ if (auditLogRows && !internalStatement && !IsParallelWorker())
{
/* Find an item from the stack by the query memory context */
stackItem = stack_find_context(queryDesc->estate->es_query_cxt);
diff --git a/sql/pgaudit.sql b/sql/pgaudit.sql
index 8052426..7f0667b 100644
--- a/sql/pgaudit.sql
+++ b/sql/pgaudit.sql
@@ -1612,6 +1612,36 @@ COMMIT;
DROP TABLE part_test;
+--
+-- Test logging in parallel workers
+SET pgaudit.log = 'read';
+SET pgaudit.log_client = on;
+SET pgaudit.log_level = 'notice';
+
+-- Force parallel execution for testing
+SET max_parallel_workers_per_gather = 2;
+SET parallel_tuple_cost = 0;
+SET parallel_setup_cost = 0;
+SET min_parallel_table_scan_size = 0;
+SET min_parallel_index_scan_size = 0;
+
+-- Create table with enough data to trigger parallel execution
+CREATE TABLE parallel_test (id int, data text);
+INSERT INTO parallel_test SELECT generate_series(1, 1000), 'test data';
+
+SELECT count(*) FROM parallel_test;
+
+-- Cleanup parallel test
+DROP TABLE parallel_test;
+RESET max_parallel_workers_per_gather;
+RESET parallel_tuple_cost;
+RESET parallel_setup_cost;
+RESET min_parallel_table_scan_size;
+RESET min_parallel_index_scan_size;
+RESET pgaudit.log;
+RESET pgaudit.log_client;
+RESET pgaudit.log_level;
+
-- Cleanup
-- Set client_min_messages up to warning to avoid noise
SET client_min_messages = 'warning';

View File

@@ -0,0 +1,143 @@
commit cc708dde7ef2af2a8120d757102d2e34c0463a0f
Author: Tristan Partin <tristan.partin@databricks.com>
Date: 2025-06-23 02:09:31 +0000
Disable logging in parallel workers
When a query uses parallel workers, pgaudit will log the same query for
every parallel worker. This is undesirable since it can result in log
amplification for queries that use parallel workers.
Signed-off-by: Tristan Partin <tristan.partin@databricks.com>
diff --git a/expected/pgaudit.out b/expected/pgaudit.out
index 8772054..9b66ac6 100644
--- a/expected/pgaudit.out
+++ b/expected/pgaudit.out
@@ -2556,6 +2556,37 @@ DROP SERVER fdw_server;
NOTICE: AUDIT: SESSION,11,1,DDL,DROP SERVER,,,DROP SERVER fdw_server;,<not logged>
DROP EXTENSION postgres_fdw;
NOTICE: AUDIT: SESSION,12,1,DDL,DROP EXTENSION,,,DROP EXTENSION postgres_fdw;,<not logged>
+--
+-- Test logging in parallel workers
+SET pgaudit.log = 'read';
+SET pgaudit.log_client = on;
+SET pgaudit.log_level = 'notice';
+-- Force parallel execution for testing
+SET max_parallel_workers_per_gather = 2;
+SET parallel_tuple_cost = 0;
+SET parallel_setup_cost = 0;
+SET min_parallel_table_scan_size = 0;
+SET min_parallel_index_scan_size = 0;
+-- Create table with enough data to trigger parallel execution
+CREATE TABLE parallel_test (id int, data text);
+INSERT INTO parallel_test SELECT generate_series(1, 1000), 'test data';
+SELECT count(*) FROM parallel_test;
+NOTICE: AUDIT: SESSION,13,1,READ,SELECT,,,SELECT count(*) FROM parallel_test;,<not logged>
+ count
+-------
+ 1000
+(1 row)
+
+-- Cleanup parallel test
+DROP TABLE parallel_test;
+RESET max_parallel_workers_per_gather;
+RESET parallel_tuple_cost;
+RESET parallel_setup_cost;
+RESET min_parallel_table_scan_size;
+RESET min_parallel_index_scan_size;
+RESET pgaudit.log;
+RESET pgaudit.log_client;
+RESET pgaudit.log_level;
-- Cleanup
-- Set client_min_messages up to warning to avoid noise
SET client_min_messages = 'warning';
diff --git a/pgaudit.c b/pgaudit.c
index 004d1f9..f061164 100644
--- a/pgaudit.c
+++ b/pgaudit.c
@@ -11,6 +11,7 @@
#include "postgres.h"
#include "access/htup_details.h"
+#include "access/parallel.h"
#include "access/sysattr.h"
#include "access/xact.h"
#include "access/relation.h"
@@ -1339,7 +1340,7 @@ pgaudit_ExecutorStart_hook(QueryDesc *queryDesc, int eflags)
{
AuditEventStackItem *stackItem = NULL;
- if (!internalStatement)
+ if (!internalStatement && !IsParallelWorker())
{
/* Push the audit even onto the stack */
stackItem = stack_push();
@@ -1420,7 +1421,7 @@ pgaudit_ExecutorCheckPerms_hook(List *rangeTabls, List *permInfos, bool abort)
/* Log DML if the audit role is valid or session logging is enabled */
if ((auditOid != InvalidOid || auditLogBitmap != 0) &&
- !IsAbortedTransactionBlockState())
+ !IsAbortedTransactionBlockState() && !IsParallelWorker())
{
/* If auditLogRows is on, wait for rows processed to be set */
if (auditLogRows && auditEventStack != NULL)
@@ -1475,7 +1476,7 @@ pgaudit_ExecutorRun_hook(QueryDesc *queryDesc, ScanDirection direction, uint64 c
else
standard_ExecutorRun(queryDesc, direction, count, execute_once);
- if (auditLogRows && !internalStatement)
+ if (auditLogRows && !internalStatement && !IsParallelWorker())
{
/* Find an item from the stack by the query memory context */
stackItem = stack_find_context(queryDesc->estate->es_query_cxt);
@@ -1495,7 +1496,7 @@ pgaudit_ExecutorEnd_hook(QueryDesc *queryDesc)
AuditEventStackItem *stackItem = NULL;
AuditEventStackItem *auditEventStackFull = NULL;
- if (auditLogRows && !internalStatement)
+ if (auditLogRows && !internalStatement && !IsParallelWorker())
{
/* Find an item from the stack by the query memory context */
stackItem = stack_find_context(queryDesc->estate->es_query_cxt);
diff --git a/sql/pgaudit.sql b/sql/pgaudit.sql
index 6aae88b..de6d7fd 100644
--- a/sql/pgaudit.sql
+++ b/sql/pgaudit.sql
@@ -1631,6 +1631,36 @@ DROP USER MAPPING FOR regress_user1 SERVER fdw_server;
DROP SERVER fdw_server;
DROP EXTENSION postgres_fdw;
+--
+-- Test logging in parallel workers
+SET pgaudit.log = 'read';
+SET pgaudit.log_client = on;
+SET pgaudit.log_level = 'notice';
+
+-- Force parallel execution for testing
+SET max_parallel_workers_per_gather = 2;
+SET parallel_tuple_cost = 0;
+SET parallel_setup_cost = 0;
+SET min_parallel_table_scan_size = 0;
+SET min_parallel_index_scan_size = 0;
+
+-- Create table with enough data to trigger parallel execution
+CREATE TABLE parallel_test (id int, data text);
+INSERT INTO parallel_test SELECT generate_series(1, 1000), 'test data';
+
+SELECT count(*) FROM parallel_test;
+
+-- Cleanup parallel test
+DROP TABLE parallel_test;
+RESET max_parallel_workers_per_gather;
+RESET parallel_tuple_cost;
+RESET parallel_setup_cost;
+RESET min_parallel_table_scan_size;
+RESET min_parallel_index_scan_size;
+RESET pgaudit.log;
+RESET pgaudit.log_client;
+RESET pgaudit.log_level;
+
-- Cleanup
-- Set client_min_messages up to warning to avoid noise
SET client_min_messages = 'warning';

View File

@@ -0,0 +1,143 @@
commit 8d02e4c6c5e1e8676251b0717a46054267091cb4
Author: Tristan Partin <tristan.partin@databricks.com>
Date: 2025-06-23 02:09:31 +0000
Disable logging in parallel workers
When a query uses parallel workers, pgaudit will log the same query for
every parallel worker. This is undesirable since it can result in log
amplification for queries that use parallel workers.
Signed-off-by: Tristan Partin <tristan.partin@databricks.com>
diff --git a/expected/pgaudit.out b/expected/pgaudit.out
index d696287..4b1059a 100644
--- a/expected/pgaudit.out
+++ b/expected/pgaudit.out
@@ -2568,6 +2568,37 @@ DROP SERVER fdw_server;
NOTICE: AUDIT: SESSION,11,1,DDL,DROP SERVER,,,DROP SERVER fdw_server,<not logged>
DROP EXTENSION postgres_fdw;
NOTICE: AUDIT: SESSION,12,1,DDL,DROP EXTENSION,,,DROP EXTENSION postgres_fdw,<not logged>
+--
+-- Test logging in parallel workers
+SET pgaudit.log = 'read';
+SET pgaudit.log_client = on;
+SET pgaudit.log_level = 'notice';
+-- Force parallel execution for testing
+SET max_parallel_workers_per_gather = 2;
+SET parallel_tuple_cost = 0;
+SET parallel_setup_cost = 0;
+SET min_parallel_table_scan_size = 0;
+SET min_parallel_index_scan_size = 0;
+-- Create table with enough data to trigger parallel execution
+CREATE TABLE parallel_test (id int, data text);
+INSERT INTO parallel_test SELECT generate_series(1, 1000), 'test data';
+SELECT count(*) FROM parallel_test;
+NOTICE: AUDIT: SESSION,13,1,READ,SELECT,,,SELECT count(*) FROM parallel_test,<not logged>
+ count
+-------
+ 1000
+(1 row)
+
+-- Cleanup parallel test
+DROP TABLE parallel_test;
+RESET max_parallel_workers_per_gather;
+RESET parallel_tuple_cost;
+RESET parallel_setup_cost;
+RESET min_parallel_table_scan_size;
+RESET min_parallel_index_scan_size;
+RESET pgaudit.log;
+RESET pgaudit.log_client;
+RESET pgaudit.log_level;
-- Cleanup
-- Set client_min_messages up to warning to avoid noise
SET client_min_messages = 'warning';
diff --git a/pgaudit.c b/pgaudit.c
index 1764af1..0e48875 100644
--- a/pgaudit.c
+++ b/pgaudit.c
@@ -11,6 +11,7 @@
#include "postgres.h"
#include "access/htup_details.h"
+#include "access/parallel.h"
#include "access/sysattr.h"
#include "access/xact.h"
#include "access/relation.h"
@@ -1406,7 +1407,7 @@ pgaudit_ExecutorStart_hook(QueryDesc *queryDesc, int eflags)
{
AuditEventStackItem *stackItem = NULL;
- if (!internalStatement)
+ if (!internalStatement && !IsParallelWorker())
{
/* Push the audit event onto the stack */
stackItem = stack_push();
@@ -1489,7 +1490,7 @@ pgaudit_ExecutorCheckPerms_hook(List *rangeTabls, List *permInfos, bool abort)
/* Log DML if the audit role is valid or session logging is enabled */
if ((auditOid != InvalidOid || auditLogBitmap != 0) &&
- !IsAbortedTransactionBlockState())
+ !IsAbortedTransactionBlockState() && !IsParallelWorker())
{
/* If auditLogRows is on, wait for rows processed to be set */
if (auditLogRows && auditEventStack != NULL)
@@ -1544,7 +1545,7 @@ pgaudit_ExecutorRun_hook(QueryDesc *queryDesc, ScanDirection direction, uint64 c
else
standard_ExecutorRun(queryDesc, direction, count, execute_once);
- if (auditLogRows && !internalStatement)
+ if (auditLogRows && !internalStatement && !IsParallelWorker())
{
/* Find an item from the stack by the query memory context */
stackItem = stack_find_context(queryDesc->estate->es_query_cxt);
@@ -1564,7 +1565,7 @@ pgaudit_ExecutorEnd_hook(QueryDesc *queryDesc)
AuditEventStackItem *stackItem = NULL;
AuditEventStackItem *auditEventStackFull = NULL;
- if (auditLogRows && !internalStatement)
+ if (auditLogRows && !internalStatement && !IsParallelWorker())
{
/* Find an item from the stack by the query memory context */
stackItem = stack_find_context(queryDesc->estate->es_query_cxt);
diff --git a/sql/pgaudit.sql b/sql/pgaudit.sql
index e161f01..c873098 100644
--- a/sql/pgaudit.sql
+++ b/sql/pgaudit.sql
@@ -1637,6 +1637,36 @@ DROP USER MAPPING FOR regress_user1 SERVER fdw_server;
DROP SERVER fdw_server;
DROP EXTENSION postgres_fdw;
+--
+-- Test logging in parallel workers
+SET pgaudit.log = 'read';
+SET pgaudit.log_client = on;
+SET pgaudit.log_level = 'notice';
+
+-- Force parallel execution for testing
+SET max_parallel_workers_per_gather = 2;
+SET parallel_tuple_cost = 0;
+SET parallel_setup_cost = 0;
+SET min_parallel_table_scan_size = 0;
+SET min_parallel_index_scan_size = 0;
+
+-- Create table with enough data to trigger parallel execution
+CREATE TABLE parallel_test (id int, data text);
+INSERT INTO parallel_test SELECT generate_series(1, 1000), 'test data';
+
+SELECT count(*) FROM parallel_test;
+
+-- Cleanup parallel test
+DROP TABLE parallel_test;
+RESET max_parallel_workers_per_gather;
+RESET parallel_tuple_cost;
+RESET parallel_setup_cost;
+RESET min_parallel_table_scan_size;
+RESET min_parallel_index_scan_size;
+RESET pgaudit.log;
+RESET pgaudit.log_client;
+RESET pgaudit.log_level;
+
-- Cleanup
-- Set client_min_messages up to warning to avoid noise
SET client_min_messages = 'warning';

View File

@@ -27,6 +27,7 @@ fail.workspace = true
flate2.workspace = true
futures.workspace = true
http.workspace = true
hostname-validator = "1.1"
indexmap.workspace = true
itertools.workspace = true
jsonwebtoken.workspace = true

View File

@@ -36,6 +36,8 @@
use std::ffi::OsString;
use std::fs::File;
use std::process::exit;
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use std::sync::mpsc;
use std::thread;
use std::time::Duration;
@@ -190,7 +192,9 @@ fn main() -> Result<()> {
cgroup: cli.cgroup,
#[cfg(target_os = "linux")]
vm_monitor_addr: cli.vm_monitor_addr,
installed_extensions_collection_interval: cli.installed_extensions_collection_interval,
installed_extensions_collection_interval: Arc::new(AtomicU64::new(
cli.installed_extensions_collection_interval,
)),
},
config,
)?;

View File

@@ -26,11 +26,12 @@ use std::os::unix::fs::{PermissionsExt, symlink};
use std::path::Path;
use std::process::{Command, Stdio};
use std::str::FromStr;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::sync::{Arc, Condvar, Mutex, RwLock};
use std::time::{Duration, Instant};
use std::{env, fs};
use tokio::spawn;
use tokio::task::JoinHandle;
use tokio::{spawn, time};
use tracing::{Instrument, debug, error, info, instrument, warn};
use url::Url;
use utils::id::{TenantId, TimelineId};
@@ -71,6 +72,7 @@ pub static BUILD_TAG: Lazy<String> = Lazy::new(|| {
.unwrap_or(BUILD_TAG_DEFAULT)
.to_string()
});
const DEFAULT_INSTALLED_EXTENSIONS_COLLECTION_INTERVAL: u64 = 3600;
/// Static configuration params that don't change after startup. These mostly
/// come from the CLI args, or are derived from them.
@@ -104,9 +106,11 @@ pub struct ComputeNodeParams {
pub remote_ext_base_url: Option<Url>,
/// Interval for installed extensions collection
pub installed_extensions_collection_interval: u64,
pub installed_extensions_collection_interval: Arc<AtomicU64>,
}
type TaskHandle = Mutex<Option<JoinHandle<()>>>;
/// Compute node info shared across several `compute_ctl` threads.
pub struct ComputeNode {
pub params: ComputeNodeParams,
@@ -127,6 +131,10 @@ pub struct ComputeNode {
// key: ext_archive_name, value: started download time, download_completed?
pub ext_download_progress: RwLock<HashMap<String, (DateTime<Utc>, bool)>>,
pub compute_ctl_config: ComputeCtlConfig,
/// Handle to the extension stats collection task
extension_stats_task: TaskHandle,
lfc_offload_task: TaskHandle,
}
// store some metrics about download size that might impact startup time
@@ -392,7 +400,7 @@ fn maybe_cgexec(cmd: &str) -> Command {
struct PostgresHandle {
postgres: std::process::Child,
log_collector: tokio::task::JoinHandle<Result<()>>,
log_collector: JoinHandle<Result<()>>,
}
impl PostgresHandle {
@@ -406,7 +414,7 @@ struct StartVmMonitorResult {
#[cfg(target_os = "linux")]
token: tokio_util::sync::CancellationToken,
#[cfg(target_os = "linux")]
vm_monitor: Option<tokio::task::JoinHandle<Result<()>>>,
vm_monitor: Option<JoinHandle<Result<()>>>,
}
impl ComputeNode {
@@ -456,6 +464,8 @@ impl ComputeNode {
state_changed: Condvar::new(),
ext_download_progress: RwLock::new(HashMap::new()),
compute_ctl_config: config.compute_ctl_config,
extension_stats_task: Mutex::new(None),
lfc_offload_task: Mutex::new(None),
})
}
@@ -543,6 +553,9 @@ impl ComputeNode {
None
};
this.terminate_extension_stats_task();
this.terminate_lfc_offload_task();
// Terminate the vm_monitor so it releases the file watcher on
// /sys/fs/cgroup/neon-postgres.
// Note: the vm-monitor only runs on linux because it requires cgroups.
@@ -779,10 +792,15 @@ impl ComputeNode {
// Configure and start rsyslog for compliance audit logging
match pspec.spec.audit_log_level {
ComputeAudit::Hipaa | ComputeAudit::Extended | ComputeAudit::Full => {
let remote_endpoint =
let remote_tls_endpoint =
std::env::var("AUDIT_LOGGING_TLS_ENDPOINT").unwrap_or("".to_string());
let remote_plain_endpoint =
std::env::var("AUDIT_LOGGING_ENDPOINT").unwrap_or("".to_string());
if remote_endpoint.is_empty() {
anyhow::bail!("AUDIT_LOGGING_ENDPOINT is empty");
if remote_plain_endpoint.is_empty() && remote_tls_endpoint.is_empty() {
anyhow::bail!(
"AUDIT_LOGGING_ENDPOINT and AUDIT_LOGGING_TLS_ENDPOINT are both empty"
);
}
let log_directory_path = Path::new(&self.params.pgdata).join("log");
@@ -798,7 +816,8 @@ impl ComputeNode {
log_directory_path.clone(),
endpoint_id,
project_id,
&remote_endpoint,
&remote_plain_endpoint,
&remote_tls_endpoint,
)?;
// Launch a background task to clean up the audit logs
@@ -865,12 +884,15 @@ impl ComputeNode {
// Log metrics so that we can search for slow operations in logs
info!(?metrics, postmaster_pid = %postmaster_pid, "compute start finished");
// Spawn the extension stats background task
self.spawn_extension_stats_task();
if pspec.spec.autoprewarm {
info!("autoprewarming on startup as requested");
self.prewarm_lfc(None);
}
if let Some(seconds) = pspec.spec.offload_lfc_interval_seconds {
self.spawn_lfc_offload_task(Duration::from_secs(seconds.into()));
};
Ok(())
}
@@ -1693,6 +1715,8 @@ impl ComputeNode {
tls_config = self.compute_ctl_config.tls.clone();
}
self.update_installed_extensions_collection_interval(&spec);
let max_concurrent_connections = self.max_service_connections(compute_state, &spec);
// Merge-apply spec & changes to PostgreSQL state.
@@ -1757,6 +1781,8 @@ impl ComputeNode {
let tls_config = self.tls_config(&spec);
self.update_installed_extensions_collection_interval(&spec);
if let Some(ref pgbouncer_settings) = spec.pgbouncer_settings {
info!("tuning pgbouncer");
@@ -2361,24 +2387,92 @@ LIMIT 100",
}
pub fn spawn_extension_stats_task(&self) {
self.terminate_extension_stats_task();
let conf = self.tokio_conn_conf.clone();
let installed_extensions_collection_interval =
self.params.installed_extensions_collection_interval;
tokio::spawn(async move {
// An initial sleep is added to ensure that two collections don't happen at the same time.
// The first collection happens during compute startup.
tokio::time::sleep(tokio::time::Duration::from_secs(
installed_extensions_collection_interval,
))
.await;
let mut interval = tokio::time::interval(tokio::time::Duration::from_secs(
installed_extensions_collection_interval,
));
let atomic_interval = self.params.installed_extensions_collection_interval.clone();
let mut installed_extensions_collection_interval =
2 * atomic_interval.load(std::sync::atomic::Ordering::SeqCst);
info!(
"[NEON_EXT_SPAWN] Spawning background installed extensions worker with Timeout: {}",
installed_extensions_collection_interval
);
let handle = tokio::spawn(async move {
loop {
interval.tick().await;
info!(
"[NEON_EXT_INT_SLEEP]: Interval: {}",
installed_extensions_collection_interval
);
// Sleep at the start of the loop to ensure that two collections don't happen at the same time.
// The first collection happens during compute startup.
tokio::time::sleep(tokio::time::Duration::from_secs(
installed_extensions_collection_interval,
))
.await;
let _ = installed_extensions(conf.clone()).await;
// Acquire a read lock on the compute spec and then update the interval if necessary
installed_extensions_collection_interval = std::cmp::max(
installed_extensions_collection_interval,
2 * atomic_interval.load(std::sync::atomic::Ordering::SeqCst),
);
}
});
// Store the new task handle
*self.extension_stats_task.lock().unwrap() = Some(handle);
}
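Illustrative sketch (not from the commit; names are hypothetical): the task above is essentially a background loop that re-reads its sleep interval from a shared Arc<AtomicU64> on every iteration, which is what lets update_installed_extensions_collection_interval change the cadence while the loop is already running. Simplified, without the doubling/max logic used above:

use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;

// Re-read the shared interval each iteration, so another thread can adjust it
// at any time without restarting the task.
async fn collect_periodically(interval_secs: Arc<AtomicU64>) {
    loop {
        let secs = interval_secs.load(Ordering::SeqCst);
        tokio::time::sleep(Duration::from_secs(secs)).await;
        // ... one collection pass would run here ...
    }
}

// The JoinHandle returned by tokio::spawn(collect_periodically(shared_interval.clone()))
// is what gets stored so the task can later be aborted.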
fn terminate_extension_stats_task(&self) {
if let Some(h) = self.extension_stats_task.lock().unwrap().take() {
h.abort()
}
}
pub fn spawn_lfc_offload_task(self: &Arc<Self>, interval: Duration) {
self.terminate_lfc_offload_task();
let secs = interval.as_secs();
info!("spawning lfc offload worker with {secs}s interval");
let this = self.clone();
let handle = spawn(async move {
let mut interval = time::interval(interval);
interval.tick().await; // returns immediately
loop {
interval.tick().await;
this.offload_lfc_async().await;
}
});
*self.lfc_offload_task.lock().unwrap() = Some(handle);
}
fn terminate_lfc_offload_task(&self) {
if let Some(h) = self.lfc_offload_task.lock().unwrap().take() {
h.abort()
}
}
fn update_installed_extensions_collection_interval(&self, spec: &ComputeSpec) {
// Update the interval for collecting installed extensions statistics
// If the value is -1, we never suspend so set the value to default collection.
// If the value is 0, it means default, we will just continue to use the default.
if spec.suspend_timeout_seconds == -1 || spec.suspend_timeout_seconds == 0 {
info!(
"[NEON_EXT_INT_UPD] Spec Timeout: {}, New Timeout: {}",
spec.suspend_timeout_seconds, DEFAULT_INSTALLED_EXTENSIONS_COLLECTION_INTERVAL
);
self.params.installed_extensions_collection_interval.store(
DEFAULT_INSTALLED_EXTENSIONS_COLLECTION_INTERVAL,
std::sync::atomic::Ordering::SeqCst,
);
} else {
info!(
"[NEON_EXT_INT_UPD] Spec Timeout: {}",
spec.suspend_timeout_seconds
);
self.params.installed_extensions_collection_interval.store(
spec.suspend_timeout_seconds as u64,
std::sync::atomic::Ordering::SeqCst,
);
}
}
}

View File

@@ -5,6 +5,7 @@ use compute_api::responses::LfcOffloadState;
use compute_api::responses::LfcPrewarmState;
use http::StatusCode;
use reqwest::Client;
use std::mem::replace;
use std::sync::Arc;
use tokio::{io::AsyncReadExt, spawn};
use tracing::{error, info};
@@ -88,17 +89,15 @@ impl ComputeNode {
self.state.lock().unwrap().lfc_offload_state.clone()
}
/// Returns false if there is a prewarm request ongoing, true otherwise
/// If there is a prewarm request ongoing, return false, true otherwise
pub fn prewarm_lfc(self: &Arc<Self>, from_endpoint: Option<String>) -> bool {
crate::metrics::LFC_PREWARM_REQUESTS.inc();
{
let state = &mut self.state.lock().unwrap().lfc_prewarm_state;
if let LfcPrewarmState::Prewarming =
std::mem::replace(state, LfcPrewarmState::Prewarming)
{
if let LfcPrewarmState::Prewarming = replace(state, LfcPrewarmState::Prewarming) {
return false;
}
}
crate::metrics::LFC_PREWARMS.inc();
let cloned = self.clone();
spawn(async move {
@@ -152,32 +151,41 @@ impl ComputeNode {
.map(|_| ())
}
/// Returns false if there is an offload request ongoing, true otherwise
/// If offload request is ongoing, return false, true otherwise
pub fn offload_lfc(self: &Arc<Self>) -> bool {
crate::metrics::LFC_OFFLOAD_REQUESTS.inc();
{
let state = &mut self.state.lock().unwrap().lfc_offload_state;
if let LfcOffloadState::Offloading =
std::mem::replace(state, LfcOffloadState::Offloading)
{
if replace(state, LfcOffloadState::Offloading) == LfcOffloadState::Offloading {
return false;
}
}
let cloned = self.clone();
spawn(async move {
let Err(err) = cloned.offload_lfc_impl().await else {
cloned.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Completed;
return;
};
error!(%err);
cloned.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Failed {
error: err.to_string(),
};
});
spawn(async move { cloned.offload_lfc_with_state_update().await });
true
}
pub async fn offload_lfc_async(self: &Arc<Self>) {
{
let state = &mut self.state.lock().unwrap().lfc_offload_state;
if replace(state, LfcOffloadState::Offloading) == LfcOffloadState::Offloading {
return;
}
}
self.offload_lfc_with_state_update().await
}
async fn offload_lfc_with_state_update(&self) {
crate::metrics::LFC_OFFLOADS.inc();
let Err(err) = self.offload_lfc_impl().await else {
self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Completed;
return;
};
error!(%err);
self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Failed {
error: err.to_string(),
};
}
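For context, a sketch of the in-flight guard that both offload_lfc and offload_lfc_async rely on (simplified enum and function names, not the commit's code): the state is swapped to Offloading under the mutex, and the previous value tells the caller whether another offload is already running.

use std::mem::replace;
use std::sync::Mutex;

#[derive(PartialEq)]
enum OffloadState { Idle, Offloading }

// Returns true if this caller won the right to start an offload,
// false if one is already in flight.
fn try_begin(state: &Mutex<OffloadState>) -> bool {
    let mut guard = state.lock().unwrap();
    replace(&mut *guard, OffloadState::Offloading) != OffloadState::Offloading
}

fn main() {
    let state = Mutex::new(OffloadState::Idle);
    assert!(try_begin(&state));   // first caller starts the offload
    assert!(!try_begin(&state));  // a concurrent caller backs off
}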
async fn offload_lfc_impl(&self) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?;
info!(%url, "requesting LFC state from postgres");

View File

@@ -10,7 +10,13 @@ input(type="imfile" File="{log_directory}/*.log"
startmsg.regex="^[[:digit:]]{{4}}-[[:digit:]]{{2}}-[[:digit:]]{{2}} [[:digit:]]{{2}}:[[:digit:]]{{2}}:[[:digit:]]{{2}}.[[:digit:]]{{3}} GMT,")
# the directory to store rsyslog state files
global(workDirectory="/var/log/rsyslog")
global(
workDirectory="/var/log/rsyslog"
DefaultNetstreamDriverCAFile="/etc/ssl/certs/ca-certificates.crt"
)
# Whether the remote syslog receiver uses tls
set $.remote_syslog_tls = "{remote_syslog_tls}";
# Construct json, endpoint_id and project_id as additional metadata
set $.json_log!endpoint_id = "{endpoint_id}";
@@ -21,5 +27,29 @@ set $.json_log!msg = $msg;
template(name="PgAuditLog" type="string"
string="<%PRI%>1 %TIMESTAMP:::date-rfc3339% %HOSTNAME% - - - - %$.json_log%")
# Forward to remote syslog receiver (@@<hostname>:<port>;format
local5.info @@{remote_endpoint};PgAuditLog
# Forward to remote syslog receiver (over TLS)
if ( $syslogtag == 'pgaudit_log' ) then {{
if ( $.remote_syslog_tls == 'true' ) then {{
action(type="omfwd" target="{remote_syslog_host}" port="{remote_syslog_port}" protocol="tcp"
template="PgAuditLog"
queue.type="linkedList"
queue.size="1000"
action.ResumeRetryCount="10"
StreamDriver="gtls"
StreamDriverMode="1"
StreamDriverAuthMode="x509/name"
StreamDriverPermittedPeers="{remote_syslog_host}"
StreamDriver.CheckExtendedKeyPurpose="on"
StreamDriver.PermitExpiredCerts="off"
)
stop
}} else {{
action(type="omfwd" target="{remote_syslog_host}" port="{remote_syslog_port}" protocol="tcp"
template="PgAuditLog"
queue.type="linkedList"
queue.size="1000"
action.ResumeRetryCount="10"
)
stop
}}
}}

View File

@@ -97,20 +97,18 @@ pub(crate) static PG_TOTAL_DOWNTIME_MS: Lazy<GenericCounter<AtomicU64>> = Lazy::
.expect("failed to define a metric")
});
/// Needed as neon.file_cache_prewarm_batch == 0 doesn't mean we never tried to prewarm.
/// On the other hand, LFC_PREWARMED_PAGES is excessive as we can GET /lfc/prewarm
pub(crate) static LFC_PREWARM_REQUESTS: Lazy<IntCounter> = Lazy::new(|| {
pub(crate) static LFC_PREWARMS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"compute_ctl_lfc_prewarm_requests_total",
"Total number of LFC prewarm requests made by compute_ctl",
"compute_ctl_lfc_prewarms_total",
"Total number of LFC prewarms requested by compute_ctl or autoprewarm option",
)
.expect("failed to define a metric")
});
pub(crate) static LFC_OFFLOAD_REQUESTS: Lazy<IntCounter> = Lazy::new(|| {
pub(crate) static LFC_OFFLOADS: Lazy<IntCounter> = Lazy::new(|| {
register_int_counter!(
"compute_ctl_lfc_offload_requests_total",
"Total number of LFC offload requests made by compute_ctl",
"compute_ctl_lfc_offloads_total",
"Total number of LFC offloads requested by compute_ctl or lfc_offload_period_seconds option",
)
.expect("failed to define a metric")
});
@@ -124,7 +122,7 @@ pub fn collect() -> Vec<MetricFamily> {
metrics.extend(AUDIT_LOG_DIR_SIZE.collect());
metrics.extend(PG_CURR_DOWNTIME_MS.collect());
metrics.extend(PG_TOTAL_DOWNTIME_MS.collect());
metrics.extend(LFC_PREWARM_REQUESTS.collect());
metrics.extend(LFC_OFFLOAD_REQUESTS.collect());
metrics.extend(LFC_PREWARMS.collect());
metrics.extend(LFC_OFFLOADS.collect());
metrics
}

View File

@@ -4,8 +4,10 @@ use std::path::Path;
use std::process::Command;
use std::time::Duration;
use std::{fs::OpenOptions, io::Write};
use url::{Host, Url};
use anyhow::{Context, Result, anyhow};
use hostname_validator;
use tracing::{error, info, instrument, warn};
const POSTGRES_LOGS_CONF_PATH: &str = "/etc/rsyslog.d/postgres_logs.conf";
@@ -82,18 +84,84 @@ fn restart_rsyslog() -> Result<()> {
Ok(())
}
fn parse_audit_syslog_address(
remote_plain_endpoint: &str,
remote_tls_endpoint: &str,
) -> Result<(String, u16, String)> {
let tls;
let remote_endpoint = if !remote_tls_endpoint.is_empty() {
tls = "true".to_string();
remote_tls_endpoint
} else {
tls = "false".to_string();
remote_plain_endpoint
};
// Urlify the remote_endpoint, so parsing can be done with url::Url.
let url_str = format!("http://{remote_endpoint}");
let url = Url::parse(&url_str).map_err(|err| {
anyhow!("Error parsing {remote_endpoint}, expected host:port, got {err:?}")
})?;
let is_valid = url.scheme() == "http"
&& url.path() == "/"
&& url.query().is_none()
&& url.fragment().is_none()
&& url.username() == ""
&& url.password().is_none();
if !is_valid {
return Err(anyhow!(
"Invalid address format {remote_endpoint}, expected host:port"
));
}
let host = match url.host() {
Some(Host::Domain(h)) if hostname_validator::is_valid(h) => h.to_string(),
Some(Host::Ipv4(ip4)) => ip4.to_string(),
Some(Host::Ipv6(ip6)) => ip6.to_string(),
_ => return Err(anyhow!("Invalid host")),
};
let port = url
.port()
.ok_or_else(|| anyhow!("Invalid port in {remote_endpoint}"))?;
Ok((host, port, tls))
}
fn generate_audit_rsyslog_config(
log_directory: String,
endpoint_id: &str,
project_id: &str,
remote_syslog_host: &str,
remote_syslog_port: u16,
remote_syslog_tls: &str,
) -> String {
format!(
include_str!("config_template/compute_audit_rsyslog_template.conf"),
log_directory = log_directory,
endpoint_id = endpoint_id,
project_id = project_id,
remote_syslog_host = remote_syslog_host,
remote_syslog_port = remote_syslog_port,
remote_syslog_tls = remote_syslog_tls
)
}
pub fn configure_audit_rsyslog(
log_directory: String,
endpoint_id: &str,
project_id: &str,
remote_endpoint: &str,
remote_tls_endpoint: &str,
) -> Result<()> {
let config_content: String = format!(
include_str!("config_template/compute_audit_rsyslog_template.conf"),
log_directory = log_directory,
endpoint_id = endpoint_id,
project_id = project_id,
remote_endpoint = remote_endpoint
let (remote_syslog_host, remote_syslog_port, remote_syslog_tls) =
parse_audit_syslog_address(remote_endpoint, remote_tls_endpoint).unwrap();
let config_content = generate_audit_rsyslog_config(
log_directory,
endpoint_id,
project_id,
&remote_syslog_host,
remote_syslog_port,
&remote_syslog_tls,
);
info!("rsyslog config_content: {}", config_content);
@@ -258,6 +326,8 @@ pub fn launch_pgaudit_gc(log_directory: String) {
mod tests {
use crate::rsyslog::PostgresLogsRsyslogConfig;
use super::{generate_audit_rsyslog_config, parse_audit_syslog_address};
#[test]
fn test_postgres_logs_config() {
{
@@ -287,4 +357,146 @@ mod tests {
assert!(res.is_err());
}
}
#[test]
fn test_parse_audit_syslog_address() {
{
// host:port format (plaintext)
let parsed = parse_audit_syslog_address("collector.host.tld:5555", "");
assert!(parsed.is_ok());
assert_eq!(
parsed.unwrap(),
(
String::from("collector.host.tld"),
5555,
String::from("false")
)
);
}
{
// host:port format with ipv4 ip address (plaintext)
let parsed = parse_audit_syslog_address("10.0.0.1:5555", "");
assert!(parsed.is_ok());
assert_eq!(
parsed.unwrap(),
(String::from("10.0.0.1"), 5555, String::from("false"))
);
}
{
// host:port format with ipv6 ip address (plaintext)
let parsed =
parse_audit_syslog_address("[7e60:82ed:cb2e:d617:f904:f395:aaca:e252]:5555", "");
assert_eq!(
parsed.unwrap(),
(
String::from("7e60:82ed:cb2e:d617:f904:f395:aaca:e252"),
5555,
String::from("false")
)
);
}
{
// Only TLS host:port defined
let parsed = parse_audit_syslog_address("", "tls.host.tld:5556");
assert_eq!(
parsed.unwrap(),
(String::from("tls.host.tld"), 5556, String::from("true"))
);
}
{
// tls host should take precedence, when both defined
let parsed = parse_audit_syslog_address("plaintext.host.tld:5555", "tls.host.tld:5556");
assert_eq!(
parsed.unwrap(),
(String::from("tls.host.tld"), 5556, String::from("true"))
);
}
{
// host without port (plaintext)
let parsed = parse_audit_syslog_address("collector.host.tld", "");
assert!(parsed.is_err());
}
{
// port without host
let parsed = parse_audit_syslog_address(":5555", "");
assert!(parsed.is_err());
}
{
// valid host with invalid port
let parsed = parse_audit_syslog_address("collector.host.tld:90001", "");
assert!(parsed.is_err());
}
{
// invalid hostname with valid port
let parsed = parse_audit_syslog_address("-collector.host.tld:5555", "");
assert!(parsed.is_err());
}
{
// parse error
let parsed = parse_audit_syslog_address("collector.host.tld:::5555", "");
assert!(parsed.is_err());
}
}
#[test]
fn test_generate_audit_rsyslog_config() {
{
// plaintext version
let log_directory = "/tmp/log".to_string();
let endpoint_id = "ep-test-endpoint-id";
let project_id = "test-project-id";
let remote_syslog_host = "collector.host.tld";
let remote_syslog_port = 5555;
let remote_syslog_tls = "false";
let conf_str = generate_audit_rsyslog_config(
log_directory,
endpoint_id,
project_id,
remote_syslog_host,
remote_syslog_port,
remote_syslog_tls,
);
assert!(conf_str.contains(r#"set $.remote_syslog_tls = "false";"#));
assert!(conf_str.contains(r#"type="omfwd""#));
assert!(conf_str.contains(r#"target="collector.host.tld""#));
assert!(conf_str.contains(r#"port="5555""#));
assert!(conf_str.contains(r#"StreamDriverPermittedPeers="collector.host.tld""#));
}
{
// TLS version
let log_directory = "/tmp/log".to_string();
let endpoint_id = "ep-test-endpoint-id";
let project_id = "test-project-id";
let remote_syslog_host = "collector.host.tld";
let remote_syslog_port = 5556;
let remote_syslog_tls = "true";
let conf_str = generate_audit_rsyslog_config(
log_directory,
endpoint_id,
project_id,
remote_syslog_host,
remote_syslog_port,
remote_syslog_tls,
);
assert!(conf_str.contains(r#"set $.remote_syslog_tls = "true";"#));
assert!(conf_str.contains(r#"type="omfwd""#));
assert!(conf_str.contains(r#"target="collector.host.tld""#));
assert!(conf_str.contains(r#"port="5556""#));
assert!(conf_str.contains(r#"StreamDriverPermittedPeers="collector.host.tld""#));
}
}
}

View File

@@ -3,7 +3,8 @@
"timestamp": "2021-05-23T18:25:43.511Z",
"operation_uuid": "0f657b36-4b0f-4a2d-9c2e-1dcd615e7d8b",
"suspend_timeout_seconds": 3600,
"cluster": {
"cluster_id": "test-cluster-42",
"name": "Zenith Test",

View File

@@ -31,6 +31,7 @@ mod pg_helpers_tests {
wal_level = logical
hot_standby = on
autoprewarm = off
offload_lfc_interval_seconds = 20
neon.safekeepers = '127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501'
wal_log_hints = on
log_connections = on

View File

@@ -64,7 +64,9 @@ const DEFAULT_PAGESERVER_ID: NodeId = NodeId(1);
const DEFAULT_BRANCH_NAME: &str = "main";
project_git_version!(GIT_VERSION);
#[allow(dead_code)]
const DEFAULT_PG_VERSION: PgMajorVersion = PgMajorVersion::PG17;
const DEFAULT_PG_VERSION_NUM: &str = "17";
const DEFAULT_PAGESERVER_CONTROL_PLANE_API: &str = "http://127.0.0.1:1234/upcall/v1/";
@@ -167,7 +169,7 @@ struct TenantCreateCmdArgs {
#[clap(short = 'c')]
config: Vec<String>,
#[arg(default_value_t = DEFAULT_PG_VERSION)]
#[arg(default_value = DEFAULT_PG_VERSION_NUM)]
#[clap(long, help = "Postgres version to use for the initial timeline")]
pg_version: PgMajorVersion,
@@ -290,7 +292,7 @@ struct TimelineCreateCmdArgs {
#[clap(long, help = "Human-readable alias for the new timeline")]
branch_name: String,
#[arg(default_value_t = DEFAULT_PG_VERSION)]
#[arg(default_value = DEFAULT_PG_VERSION_NUM)]
#[clap(long, help = "Postgres version")]
pg_version: PgMajorVersion,
}
@@ -322,7 +324,7 @@ struct TimelineImportCmdArgs {
#[clap(long, help = "Lsn the basebackup ends at")]
end_lsn: Option<Lsn>,
#[arg(default_value_t = DEFAULT_PG_VERSION)]
#[arg(default_value = DEFAULT_PG_VERSION_NUM)]
#[clap(long, help = "Postgres version of the backup being imported")]
pg_version: PgMajorVersion,
}
@@ -601,7 +603,7 @@ struct EndpointCreateCmdArgs {
)]
config_only: bool,
#[arg(default_value_t = DEFAULT_PG_VERSION)]
#[arg(default_value = DEFAULT_PG_VERSION_NUM)]
#[clap(long, help = "Postgres version")]
pg_version: PgMajorVersion,
@@ -673,6 +675,16 @@ struct EndpointStartCmdArgs {
#[arg(default_value = "90s")]
start_timeout: Duration,
#[clap(
long,
help = "Download LFC cache from endpoint storage on endpoint startup",
default_value = "false"
)]
autoprewarm: bool,
#[clap(long, help = "Upload LFC cache to endpoint storage periodically")]
offload_lfc_interval_seconds: Option<std::num::NonZeroU64>,
#[clap(
long,
help = "Run in development mode, skipping VM-specific operations like process termination",
@@ -1595,22 +1607,24 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
let endpoint_storage_token = env.generate_auth_token(&claims)?;
let endpoint_storage_addr = env.endpoint_storage.listen_addr.to_string();
let args = control_plane::endpoint::EndpointStartArgs {
auth_token,
endpoint_storage_token,
endpoint_storage_addr,
safekeepers_generation,
safekeepers,
pageserver_conninfo,
remote_ext_base_url: remote_ext_base_url.clone(),
shard_stripe_size: stripe_size.0 as usize,
create_test_user: args.create_test_user,
start_timeout: args.start_timeout,
autoprewarm: args.autoprewarm,
offload_lfc_interval_seconds: args.offload_lfc_interval_seconds,
dev: args.dev,
};
println!("Starting existing endpoint {endpoint_id}...");
endpoint
.start(
&auth_token,
endpoint_storage_token,
endpoint_storage_addr,
safekeepers_generation,
safekeepers,
pageserver_conninfo,
remote_ext_base_url.as_ref(),
stripe_size.0 as usize,
args.create_test_user,
args.start_timeout,
args.dev,
)
.await?;
endpoint.start(args).await?;
}
EndpointCmd::Reconfigure(args) => {
let endpoint_id = &args.endpoint_id;

View File

@@ -376,6 +376,22 @@ impl std::fmt::Display for EndpointTerminateMode {
}
}
pub struct EndpointStartArgs {
pub auth_token: Option<String>,
pub endpoint_storage_token: String,
pub endpoint_storage_addr: String,
pub safekeepers_generation: Option<SafekeeperGeneration>,
pub safekeepers: Vec<NodeId>,
pub pageserver_conninfo: PageserverConnectionInfo,
pub remote_ext_base_url: Option<String>,
pub shard_stripe_size: usize,
pub create_test_user: bool,
pub start_timeout: Duration,
pub autoprewarm: bool,
pub offload_lfc_interval_seconds: Option<std::num::NonZeroU64>,
pub dev: bool,
}
impl Endpoint {
fn from_dir_entry(entry: std::fs::DirEntry, env: &LocalEnv) -> Result<Endpoint> {
if !entry.file_type()?.is_dir() {
@@ -672,21 +688,7 @@ impl Endpoint {
})
}
#[allow(clippy::too_many_arguments)]
pub async fn start(
&self,
auth_token: &Option<String>,
endpoint_storage_token: String,
endpoint_storage_addr: String,
safekeepers_generation: Option<SafekeeperGeneration>,
safekeepers: Vec<NodeId>,
pageserver_conninfo: PageserverConnectionInfo,
remote_ext_base_url: Option<&String>,
shard_stripe_size: usize,
create_test_user: bool,
start_timeout: Duration,
dev: bool,
) -> Result<()> {
pub async fn start(&self, args: EndpointStartArgs) -> Result<()> {
if self.status() == EndpointStatus::Running {
anyhow::bail!("The endpoint is already running");
}
@@ -699,7 +701,7 @@ impl Endpoint {
std::fs::remove_dir_all(self.pgdata())?;
}
let safekeeper_connstrings = self.build_safekeepers_connstrs(safekeepers)?;
let safekeeper_connstrings = self.build_safekeepers_connstrs(args.safekeepers)?;
// check for file remote_extensions_spec.json
// if it is present, read it and pass to compute_ctl
@@ -727,7 +729,7 @@ impl Endpoint {
cluster_id: None, // project ID: not used
name: None, // project name: not used
state: None,
roles: if create_test_user {
roles: if args.create_test_user {
vec![Role {
name: PgIdent::from_str("test").unwrap(),
encrypted_password: None,
@@ -736,7 +738,7 @@ impl Endpoint {
} else {
Vec::new()
},
databases: if create_test_user {
databases: if args.create_test_user {
vec![Database {
name: PgIdent::from_str("neondb").unwrap(),
owner: PgIdent::from_str("test").unwrap(),
@@ -757,21 +759,23 @@ impl Endpoint {
branch_id: None,
endpoint_id: Some(self.endpoint_id.clone()),
mode: self.mode,
pageserver_connection_info: Some(pageserver_conninfo),
safekeepers_generation: safekeepers_generation.map(|g| g.into_inner()),
pageserver_connection_info: Some(args.pageserver_conninfo),
safekeepers_generation: args.safekeepers_generation.map(|g| g.into_inner()),
safekeeper_connstrings,
storage_auth_token: auth_token.clone(),
storage_auth_token: args.auth_token.clone(),
remote_extensions,
pgbouncer_settings: None,
shard_stripe_size: Some(shard_stripe_size),
shard_stripe_size: Some(args.shard_stripe_size),
local_proxy_config: None,
reconfigure_concurrency: self.reconfigure_concurrency,
drop_subscriptions_before_start: self.drop_subscriptions_before_start,
audit_log_level: ComputeAudit::Disabled,
logs_export_host: None::<String>,
endpoint_storage_addr: Some(endpoint_storage_addr),
endpoint_storage_token: Some(endpoint_storage_token),
autoprewarm: false,
endpoint_storage_addr: Some(args.endpoint_storage_addr),
endpoint_storage_token: Some(args.endpoint_storage_token),
autoprewarm: args.autoprewarm,
offload_lfc_interval_seconds: args.offload_lfc_interval_seconds,
suspend_timeout_seconds: -1, // Only used in neon_local.
};
// this strange code is needed to support respec() in tests
@@ -782,7 +786,7 @@ impl Endpoint {
debug!("spec.cluster {:?}", spec.cluster);
// fill missing fields again
if create_test_user {
if args.create_test_user {
spec.cluster.roles.push(Role {
name: PgIdent::from_str("test").unwrap(),
encrypted_password: None,
@@ -817,7 +821,7 @@ impl Endpoint {
// Launch compute_ctl
let conn_str = self.connstr("cloud_admin", "postgres");
println!("Starting postgres node at '{conn_str}'");
if create_test_user {
if args.create_test_user {
let conn_str = self.connstr("test", "neondb");
println!("Also at '{conn_str}'");
}
@@ -849,11 +853,11 @@ impl Endpoint {
.stderr(logfile.try_clone()?)
.stdout(logfile);
if let Some(remote_ext_base_url) = remote_ext_base_url {
cmd.args(["--remote-ext-base-url", remote_ext_base_url]);
if let Some(remote_ext_base_url) = args.remote_ext_base_url {
cmd.args(["--remote-ext-base-url", &remote_ext_base_url]);
}
if dev {
if args.dev {
cmd.arg("--dev");
}
@@ -885,10 +889,11 @@ impl Endpoint {
Ok(state) => {
match state.status {
ComputeStatus::Init => {
if Instant::now().duration_since(start_at) > start_timeout {
let timeout = args.start_timeout;
if Instant::now().duration_since(start_at) > timeout {
bail!(
"compute startup timed out {:?}; still in Init state",
start_timeout
timeout
);
}
// keep retrying
@@ -916,9 +921,10 @@ impl Endpoint {
}
}
Err(e) => {
if Instant::now().duration_since(start_at) > start_timeout {
if Instant::now().duration_since(start_at) > args.start_timeout {
return Err(e).context(format!(
"timed out {start_timeout:?} waiting to connect to compute_ctl HTTP",
"timed out {:?} waiting to connect to compute_ctl HTTP",
args.start_timeout
));
}
}

View File

@@ -65,12 +65,27 @@ enum Command {
#[arg(long)]
scheduling: Option<NodeSchedulingPolicy>,
},
// Set a node status as deleted.
/// Exists for backup usage and will be removed in future.
/// Use [`Command::NodeStartDelete`] instead, if possible.
NodeDelete {
#[arg(long)]
node_id: NodeId,
},
/// Start deletion of the specified pageserver.
NodeStartDelete {
#[arg(long)]
node_id: NodeId,
},
/// Cancel deletion of the specified pageserver and wait for `timeout`
/// for the operation to be canceled. May be retried.
NodeCancelDelete {
#[arg(long)]
node_id: NodeId,
#[arg(long)]
timeout: humantime::Duration,
},
/// Delete a tombstone of node from the storage controller.
/// This is used when we want to allow the node to be re-registered.
NodeDeleteTombstone {
#[arg(long)]
node_id: NodeId,
@@ -912,10 +927,43 @@ async fn main() -> anyhow::Result<()> {
.await?;
}
Command::NodeDelete { node_id } => {
eprintln!("Warning: This command is obsolete and will be removed in a future version");
eprintln!("Use `NodeStartDelete` instead, if possible");
storcon_client
.dispatch::<(), ()>(Method::DELETE, format!("control/v1/node/{node_id}"), None)
.await?;
}
Command::NodeStartDelete { node_id } => {
storcon_client
.dispatch::<(), ()>(
Method::PUT,
format!("control/v1/node/{node_id}/delete"),
None,
)
.await?;
println!("Delete started for {node_id}");
}
Command::NodeCancelDelete { node_id, timeout } => {
storcon_client
.dispatch::<(), ()>(
Method::DELETE,
format!("control/v1/node/{node_id}/delete"),
None,
)
.await?;
println!("Waiting for node {node_id} to quiesce on scheduling policy ...");
let final_policy =
wait_for_scheduling_policy(storcon_client, node_id, *timeout, |sched| {
!matches!(sched, NodeSchedulingPolicy::Deleting)
})
.await?;
println!(
"Delete was cancelled for node {node_id}. Schedulling policy is now {final_policy:?}"
);
}
Command::NodeDeleteTombstone { node_id } => {
storcon_client
.dispatch::<(), ()>(

View File

@@ -4,6 +4,7 @@
"timestamp": "2022-10-12T18:00:00.000Z",
"operation_uuid": "0f657b36-4b0f-4a2d-9c2e-1dcd615e7d8c",
"suspend_timeout_seconds": -1,
"cluster": {
"cluster_id": "docker_compose",

View File

@@ -0,0 +1,179 @@
# Storage Feature Flags
In this RFC, we will describe how we will implement per-tenant feature flags.
## PostHog as Feature Flag Service
Before we start, let's talk about how current feature flag services work. PostHog is the feature flag service we are currently using across multiple user-facing components in the company. PostHog has two modes of operation: HTTP evaluation and server-side local evaluation.
Let's assume we have a storage feature flag called gc-compaction and we want to roll it out to scale-tier users with resident size >= 10GB and <= 100GB.
### Define User Profiles
The first step is to synchronize our user profiles to the PostHog service. We can simply assume that each tenant is a user in PostHog. Each user profile has some properties associated with it. In our case, it will be: plan type (free, scale, enterprise, etc); resident size (in bytes); primary pageserver (string); region (string).
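To make this concrete, here is a minimal sketch of the per-tenant property set we would sync; the exact key names are an assumption for illustration, not the payload used in production.
```rust
use serde_json::{Value, json};

/// Hypothetical helper building the PostHog person properties for one tenant.
/// The keys mirror the list above (plan type, resident size, primary pageserver, region).
fn tenant_properties(plan: &str, resident_size: u64, pageserver: &str, region: &str) -> Value {
    json!({
        "plan_type": plan,              // free, scale, enterprise, ...
        "resident_size": resident_size, // in bytes
        "primary_pageserver": pageserver,
        "region": region,
    })
}

fn main() {
    let props = tenant_properties("scale", 50 * 1024 * 1024 * 1024, "pageserver-3", "us-east-2");
    println!("{props}");
}
```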
### Define Feature Flags
We would create a feature flag called gc-compaction in PostHog with 4 variants: disabled, stage-1, stage-2, fully-enabled. We will flip the feature flags from disabled to fully-enabled stage by stage for some percentage of our users.
### Option 1: HTTP Evaluation Mode
When using PostHog's HTTP evaluation mode, the client will make a request to the PostHog service, asking for the value of a feature flag for a specific user.
* Control plane will report the plan type to PostHog each time it attaches a tenant to the storcon or when the user upgrades/downgrades. It calls the PostHog profile API to associate the tenant ID with the plan type. Assuming we have X active tenants and such an attach or plan-change event happens each week, that would be 4X profile update requests per month.
* Pageservers will report the resident size and the primary pageserver to the PostHog service. Assuming we report resident size every 24 hours, that's 30X requests per month.
* Each tenant will request the state of the feature flag every hour; that's 720X requests per month.
* The Rust client would be easy to implement as we only need to call the `/decide` API on PostHog.
Using the HTTP evaluation mode we will issue 754X requests a month.
### Option 2: Local Evaluation Mode
When using PostHog's local evaluation mode, the client (usually the server in a browser/server architecture) polls the feature flag configuration from PostHog every 30s (the default in the Python client). Such a configuration contains data like:
<details>
<summary>Example JSON response from the PostHog local evaluation API</summary>
```json
[
    {
        "id": 1,
        "name": "Beta Feature",
        "key": "person-flag",
        "is_simple_flag": true,
        "active": true,
        "filters": {
            "groups": [
                {
                    "properties": [
                        {
                            "key": "location",
                            "operator": "exact",
                            "value": ["Straße"],
                            "type": "person"
                        }
                    ],
                    "rollout_percentage": 100
                },
                {
                    "properties": [
                        {
                            "key": "star",
                            "operator": "exact",
                            "value": ["ſun"],
                            "type": "person"
                        }
                    ],
                    "rollout_percentage": 100
                }
            ]
        }
    }
]
```
</details>
Note that the API only contains information like "under what condition => rollout percentage". The user is responsible for providing the required properties to the client for local evaluation, and the PostHog service (web UI) cannot know whether a feature is enabled for a tenant until the client uses the `capture` API to report the result back. To control the rollout percentage, the user ID gets mapped to a float number in `[0, 1)` on a consistent hash ring. All values <= the percentage get the feature enabled or set to the desired value.
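As a rough illustration of that rollout check (not the exact hashing scheme PostHog uses), a tenant could be mapped to `[0, 1)` and compared against the percentage like this:
```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Map a (flag key, distinct id) pair onto [0.0, 1.0).
/// PostHog clients use a specific SHA-1 based scheme; the std hasher here is
/// only a stand-in to show the idea of a consistent position on the ring.
fn rollout_position(flag_key: &str, distinct_id: &str) -> f64 {
    let mut hasher = DefaultHasher::new();
    (flag_key, distinct_id).hash(&mut hasher);
    hasher.finish() as f64 / u64::MAX as f64
}

/// A tenant gets the feature if its position is <= the rollout percentage.
fn is_rolled_out(flag_key: &str, tenant_id: &str, rollout_percentage: f64) -> bool {
    rollout_position(flag_key, tenant_id) <= rollout_percentage / 100.0
}

fn main() {
    println!("{}", is_rolled_out("gc-compaction", "some-tenant-id", 20.0));
}
```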
To use the local evaluation mode, the system needs:
* Assume each pageserver polls PostHog for the local evaluation JSON every 5 minutes (instead of the 30s default, which is too frequent). That's 8640Y requests per month, where Y is the number of pageservers. Local evaluation requests cost 10x more than a normal decide request, so that's 86400Y request units to bill.
* Storcon needs to store the plan type in the database and pass that information to the pageserver when attaching the tenant.
* Storcon also needs to update PostHog with the active tenants, for example, when the tenant gets detached/attached. Assuming each active tenant gets detached/attached every week, that would be 4X requests per month.
* We do not need to report the plan type or resident size to PostHog, as all of these are evaluated locally.
* After each local evaluation of a feature flag, we need to call PostHog's capture event API to report the evaluation result (i.e., that the feature is enabled). We can do this only when the flag changes compared with the last cached state in memory. That would be at least 4X (assuming we deploy every week, so the cache gets cleared), with maybe an additional multiplier of 10 assuming we have 10 active features.
In this case, we will issue 86400Y + 40X requests per month.
Assuming X = 1,000,000 and Y = 100:
| | HTTP Evaluation | Local Evaluation |
|---|---|---|
| Latency of propagating the conditions/properties for feature flag | 24 hours | available locally |
| Latency of applying the feature flag | 1 hour | 5 minutes |
| Can properties be reported from different services | Yes | No |
| Do we need to sync billing info etc to pageserver | No | Yes |
| Cost | $75,400 / month | $4,864 / month |
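For reference, the request-unit arithmetic behind the table can be written out as a small sketch (the per-unit pricing is deliberately left out):
```rust
/// Monthly request units for HTTP evaluation:
/// 4X profile updates + 30X resident-size reports + 720X `/decide` calls = 754X.
fn http_evaluation_units(active_tenants: u64) -> u64 {
    754 * active_tenants
}

/// Monthly request units for local evaluation:
/// 8640Y spec polls * 10 (local-evaluation multiplier) + 40X capture events.
fn local_evaluation_units(active_tenants: u64, pageservers: u64) -> u64 {
    86_400 * pageservers + 40 * active_tenants
}

fn main() {
    let (x, y) = (1_000_000u64, 100u64);
    println!("HTTP evaluation:  {} units/month", http_evaluation_units(x)); // 754,000,000
    println!("Local evaluation: {} units/month", local_evaluation_units(x, y)); // 48,640,000
}
```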
# Our Solution
We will use PostHog _only_ as a UI to configure the feature flags. Whether a feature is enabled can only be queried through storcon/pageserver instead of the PostHog UI. (We could report it back to PostHog via `capture_event`, but it costs $$$.) This allows us to ramp up the feature flag functionality quickly at first. At the same time, it also gives us the option to migrate to our own solution once we want more properties and more complex evaluation rules in our system.
* We will create several fake users (tenants) in PostHog that contain all the properties we will use for evaluating a feature flag (i.e., resident size, billing type, pageserver id, etc.)
* We will use PostHog's local evaluation API to poll the configuration of the feature flags and evaluate them locally on each of the pageservers.
* The evaluation result will not be reported back to PostHog.
* Storcon needs to pull some information from cplane database.
* To know whether a feature is currently enabled, we need to call the storcon/pageserver API; and we won't easily be able to tell whether a feature was previously enabled on a tenant: we would need to look at the Grafana logs.
We only need to pay for the 86400Y local evaluation requests (that is, setting X=0 in option 2 => $864/month, and even less if we proxy it through storcon).
## Implementation
* Pageserver: implement a PostHog local evaluation client. The client will be shared across all tenants on the pageserver with a single API: `evaluate(tenant_id, feature_flag, properties) -> json` (a rough sketch follows this list).
* Storcon: if we need plan type as the evaluation condition, pull it from cplane database.
* Storcon/Pageserver: implement an HTTP API `:tenant_id/feature/:feature` to retrieve the current feature flag status.
* Storcon/Pageserver: a loop to update the feature flag spec on both storcon and pageserver. Pageserver loop will only be activated if storcon does not push the specs to the pageserver.
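A rough sketch of the shape of that shared evaluation entry point; the types are hypothetical and only illustrate the `evaluate(tenant_id, feature_flag, properties) -> json` signature from the first bullet, not the actual crate API.
```rust
use std::collections::HashMap;

use serde_json::Value;

/// Hypothetical error type for the local evaluation client.
#[derive(Debug)]
enum EvalError {
    SpecNotLoaded,
    NoGroupMatched,
    UnsupportedOperator(String),
}

/// Hypothetical per-pageserver client holding the locally cached flag spec.
struct FeatureFlagClient {
    /// Flag key -> raw flag definition from the local evaluation endpoint.
    spec: HashMap<String, Value>,
}

impl FeatureFlagClient {
    /// Evaluate one flag for one tenant against locally known properties.
    fn evaluate(
        &self,
        tenant_id: &str,
        feature_flag: &str,
        properties: &HashMap<String, Value>,
    ) -> Result<Value, EvalError> {
        let flag = self.spec.get(feature_flag).ok_or(EvalError::SpecNotLoaded)?;
        // The real implementation matches `properties` against the flag's
        // condition groups and applies the rollout percentage for `tenant_id`;
        // that logic is elided in this sketch.
        let _ = (tenant_id, properties);
        Ok(flag.clone())
    }
}
```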
## Difference from Tenant Config
* Feature flags can be modified by percentage, and the default config for each feature flag can be modified in UI without going through the release process.
* Feature flags are more flexible: they won't be persisted anywhere and will be passed as plain JSON over the wire, so we do not need to handle backward/forward compatibility as in the tenant config.
* The expectation of tenant config is that once we add a flag we cannot remove it (or it will be hard to remove), but feature flags are more flexible.
# Final Implementation
* We added a new crate `posthog_client_lite` that supports local feature evaluations.
* We set up two projects "Storage (staging)" and "Storage (production)" in the PostHog console.
* Each pageserver reports 10 fake tenants to PostHog so that we can get all combinations of regions (and other properties) in the PostHog UI.
* Supported properties: AZ, neon_region, pageserver, tenant_id.
* You may use "Pageserver Feature Flags" dashboard to see the evaluation status.
* The feature flag spec is polled on storcon every 30s (in each region) and storcon will propagate the spec to the pageservers.
* The pageserver housekeeping loop updates the tenant-specific properties (e.g., remote size) for evaluation.
Each tenant has a `feature_resolver` object. After you add a feature flag in the PostHog console, you can retrieve it with:
```rust
// Boolean flag
self
.feature_resolver
.evaluate_boolean("flag")
    .is_ok();
// Multivariate flag
self
.feature_resolver
    .evaluate_multivariate("gc-compaction-strategy")
.ok();
```
The user needs to handle the case where the evaluation result is an error. This can occur in a variety of cases:
* During the pageserver start, the feature flag spec has not been retrieved.
* No condition group is matched.
* The feature flag spec contains an operand/operation not supported by the lite PostHog library.
For boolean flags, the return value is `Result<(), Error>`. `Ok(())` means the flag is evaluated to true. Otherwise,
there is either an error in evaluation or it does not match any groups.
For multivariate flags, the return value is `Result<String, Error>`. `Ok(variant)` indicates the flag is evaluated
to a variant. Otherwise, there is either an error in evaluation or it does not match any groups.
The evaluation logic is documented in the PostHog lite library. It compares the consistent hash of a flag key + tenant_id
with the rollout percentage and determines which tenants a specific feature is rolled out to.
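Continuing the snippet style above (inside the same tenant context), a caller would typically branch on these results and fall back to a safe default; the default variant name below is a placeholder.
```rust
// Boolean flag: Ok(()) means "enabled"; any error is treated as "disabled".
let gc_compaction_enabled = self.feature_resolver.evaluate_boolean("gc-compaction").is_ok();

// Multivariate flag: fall back to a default variant on any evaluation error
// (spec not fetched yet, no group matched, unsupported operator, ...).
let strategy = match self.feature_resolver.evaluate_multivariate("gc-compaction-strategy") {
    Ok(variant) => variant,
    Err(_) => "disabled".to_string(), // placeholder default variant
};
```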
Users can use the feature flag evaluation API to get the flag evaluation result of a specific tenant for debugging purposes.
```
curl "http://localhost:9898/v1/tenant/:tenant_id/feature_flag?flag=:key&as=multivariate/boolean"
```
By default, the storcon pushes the feature flag specs to the pageservers every 30 seconds, which means that a change to a feature flag in the
PostHog UI will propagate to the pageservers within 30 seconds.
# Future Works
* Support dynamic tenant properties like logical size as the evaluation condition.
* Support properties like `plan_type` (needs cplane to pass it down).
* Report feature flag evaluation result back to PostHog (if the cost is okay).
* Fast feature flag evaluation cache on critical paths (e.g., cache a feature flag result in `AtomicBool` and use it on the read path).

View File

@@ -0,0 +1,399 @@
# Compute rolling restart with prewarm
Created on 2025-03-17
Implemented on _TBD_
Author: Alexey Kondratov (@ololobus)
## Summary
This RFC describes an approach to reduce performance degradation due to missing caches after compute node restart, i.e.:
1. Rolling restart of the running instance via 'warm' replica.
2. Auto-prewarm compute caches after unplanned restart or scale-to-zero.
## Motivation
Neon currently implements several features that guarantee high uptime of compute nodes:
1. Storage high-availability (HA), i.e. each tenant shard has a secondary pageserver location, so we can quickly switch over compute to it in case of primary pageserver failure.
2. Fast compute provisioning, i.e. we have a fleet of pre-created empty computes, that are ready to serve workload, so restarting unresponsive compute is very fast.
3. Preemptive NeonVM compute provisioning in case of k8s node unavailability.
This helps us to be well within the uptime SLO of 99.95% most of the time. Problems begin when we go up to multi-TB workloads and 32-64 CU computes.
During restart, the compute loses all caches: LFC, shared buffers, file system cache. Depending on the workload, it can take a lot of time to warm up the caches,
so performance can be degraded and might even be unacceptable for certain workloads. This means that although the current approach works well for small to
medium workloads, we still have to do some additional work to avoid performance degradation after the restart of large instances.
## Non Goals
- Details of the persistence storage for prewarm data are out of scope, there is a separate RFC for that: <https://github.com/neondatabase/neon/pull/9661>.
- Complete compute/Postgres HA setup and flow. Although it was originally in the scope of this RFC, during preliminary research it turned out to be a rabbit hole, so it's worth a separate RFC.
- Low-level implementation details for Postgres replica-to-primary promotion. There are a lot of things to think about and care for: how to start walproposer, [logical replication failover](https://www.postgresql.org/docs/current/logical-replication-failover.html), and so on, but it's worth at least a separate one-pager design document, if not an RFC.
## Impacted components
Postgres, compute_ctl, Control plane, Endpoint storage for unlogged storage of compute files.
For the latter, we will need to implement a uniform abstraction layer on top of S3, ABS, etc., but
S3 is used interchangeably with 'endpoint storage' in the text for simplicity.
## Proposed implementation
### compute_ctl spec changes and auto-prewarm
We are going to extend the current compute spec with the following attributes:
```rust
struct ComputeSpec {
/// [All existing attributes]
...
/// Whether to do auto-prewarm at start or not.
/// Defaults to `false`.
pub lfc_auto_prewarm: bool,
/// Interval in seconds between automatic dumps of
/// LFC state into S3. Default `None`, which means 'off'.
pub lfc_dump_interval_sec: Option<i32>,
}
```
When `lfc_dump_interval_sec` is set to `N`, `compute_ctl` will periodically dump the LFC state
and store it in S3, so that it can be used either for auto-prewarm after restart or by a replica
during the rolling restart. To enable periodic dumping, we should consider the value
`lfc_dump_interval_sec=300` (5 minutes), the same as upstream's `pg_prewarm.autoprewarm_interval`.
When `lfc_auto_prewarm` is set to `true`, `compute_ctl` will start prewarming the LFC upon restart
if a previous state is present in S3.
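A minimal sketch of how `compute_ctl` could drive the periodic dump, assuming a hypothetical `dump_lfc_state_to_s3()` helper (not the actual implementation):
```rust
use std::time::Duration;

/// Hypothetical periodic LFC offload loop driven by `lfc_dump_interval_sec`.
async fn lfc_dump_loop(lfc_dump_interval_sec: Option<i32>) {
    let Some(secs) = lfc_dump_interval_sec else {
        return; // `None` means periodic dumping is off
    };
    let mut ticker = tokio::time::interval(Duration::from_secs(secs as u64));
    loop {
        ticker.tick().await;
        if let Err(err) = dump_lfc_state_to_s3().await {
            // A failed dump is not fatal; log and retry on the next tick.
            tracing::warn!("periodic LFC dump failed: {err:#}");
        }
    }
}

/// Placeholder for "dump LFC state via SQL and upload it to endpoint storage".
async fn dump_lfc_state_to_s3() -> anyhow::Result<()> {
    Ok(())
}
```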
### compute_ctl API
1. `POST /store_lfc_state` -- dump LFC state using Postgres SQL interface and store result in S3.
This has to be a blocking call, i.e. it will return only after the state is stored in S3.
If there is any concurrent request in progress, we should return `429 Too Many Requests`,
and let the caller retry.
2. `GET /dump_lfc_state` -- dump LFC state using Postgres SQL interface and return it as is
in a text format suitable for a future restore/prewarm. This API is not strictly needed in
the end state, but could be useful for faster prototyping of the complete rolling restart flow
with prewarm, as it doesn't require persistent storage for the LFC state.
3. `POST /restore_lfc_state` -- restore/prewarm LFC state with request
```yaml
RestoreLFCStateRequest:
oneOf:
- type: object
required:
- lfc_state
properties:
lfc_state:
type: string
description: Raw LFC content dumped with GET `/dump_lfc_state`
- type: object
required:
- lfc_cache_key
properties:
lfc_cache_key:
type: string
description: |
endpoint_id of the source endpoint on the same branch
to use as a 'donor' for LFC content. Compute will look up
LFC content dump in S3 using this key and do prewarm.
```
where `lfc_state` and `lfc_cache_key` are mutually exclusive.
The actual prewarming will happen asynchronously, so the caller needs to check the
prewarm status using the compute's standard `GET /status` API.
4. `GET /status` -- extend the existing API with the following attributes
```rust
struct ComputeStatusResponse {
// [All existing attributes]
...
pub prewarm_state: PrewarmState
}
/// Compute prewarm state. Will be stored in the shared Compute state
/// in compute_ctl
struct PrewarmState {
pub status: PrewarmStatus,
/// Total number of pages to prewarm
pub pages_total: i64,
/// Number of pages prewarmed so far
pub pages_processed: i64,
/// Optional prewarm error
pub error: Option<String>,
}
pub enum PrewarmStatus {
/// Prewarming was never requested on this compute
Off,
/// Prewarming was requested, but not started yet
Pending,
/// Prewarming is in progress. The caller should follow
/// `PrewarmState::pages_processed` and `PrewarmState::pages_total`.
InProgress,
/// Prewarming has been successfully completed
Completed,
/// Prewarming failed. The caller should look at
/// `PrewarmState::error` for the reason.
Failed,
/// It is intended to be used by auto-prewarm if none of
/// the previous LFC states is available in S3.
/// This is a distinct state from the `Failed` because
/// technically it's not a failure and could happen if
/// the compute was restarted before it dumped anything into S3,
/// or just after the initial rollout of the feature.
Skipped,
}
```
5. `POST /promote` -- this is a **blocking** API call to promote a compute replica into a primary.
This API should be very similar to the existing `POST /configure` API, i.e. accept the
spec (the primary spec, because the compute was originally started as a replica). It's a distinct
API method because the semantics and response codes are different (a sketch of the mapping follows this list):
- If promotion is done successfully, it will return `200 OK`.
- If the compute is already a primary, the call will be a no-op and `compute_ctl`
will return `412 Precondition Failed`.
- If, for some reason, a second request reaches a compute that is in the middle of a promotion,
it will respond with `429 Too Many Requests`.
- If the compute hits any permanent failure during promotion, `500 Internal Server Error`
will be returned.
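A compact sketch of how `compute_ctl` could map the promotion outcome to the response codes above; the outcome enum is hypothetical.
```rust
use http::StatusCode;

/// Hypothetical promotion outcomes tracked by compute_ctl.
enum PromoteOutcome {
    Promoted,          // replica successfully promoted
    AlreadyPrimary,    // compute was already running as a primary
    AlreadyInProgress, // another promotion request is being processed
    Failed,            // permanent failure during promotion
}

/// Map the outcome to the status codes described for `POST /promote`.
fn promote_status(outcome: &PromoteOutcome) -> StatusCode {
    match outcome {
        PromoteOutcome::Promoted => StatusCode::OK,                         // 200
        PromoteOutcome::AlreadyPrimary => StatusCode::PRECONDITION_FAILED,  // 412
        PromoteOutcome::AlreadyInProgress => StatusCode::TOO_MANY_REQUESTS, // 429
        PromoteOutcome::Failed => StatusCode::INTERNAL_SERVER_ERROR,        // 500
    }
}
```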
### Control plane operations
The complete flow is presented as a sequence diagram in the next section; here
we just want to list some important steps that have to be done by the control plane during
the rolling restart via a warm replica, without going into low-level implementation details.
1. Register the 'intent' of the instance restart, but do not interrupt any workload on the
primary yet, and keep accepting new connections. This may require some endpoint state machine
changes, e.g. the introduction of a `pending_restart` state. Being in this state also
**mustn't prevent any operations other than restart**: suspend, live reconfiguration
(e.g. due to a notify-attach call from the storage controller), deletion.
2. Start a new replica compute on the same timeline and start prewarming it. This process
may take quite a while, so the same concurrency considerations as in 1. should be applied
here as well.
3. When the warm replica is ready, the control plane should:
3.1. Terminate the primary compute. Starting from here, **this is a critical section**:
if anything goes wrong, the only option is to start the primary normally and proceed
with auto-prewarm.
3.2. Send a cache invalidation message to all proxies, notifying them that all new connections
should request and wait for the new connection details. At this stage, the proxy also has to
drop any existing connections to the old primary, so that they don't do stale reads.
3.3. Attach the warm replica compute to the primary endpoint inside the control plane metadata
database.
3.4. Promote replica to primary.
3.5. When everything is done, finalize the endpoint state to be just `active`.
### Complete rolling restart flow
```mermaid
sequenceDiagram
autonumber
participant proxy as Neon proxy
participant cplane as Control plane
participant primary as Compute (primary)
box Compute (replica)
participant ctl as compute_ctl
participant pg as Postgres
end
box Endpoint unlogged storage
participant s3proxy as Endpoint storage service
participant s3 as S3/ABS/etc.
end
cplane ->> primary: POST /store_lfc_state
primary -->> cplane: 200 OK
cplane ->> ctl: POST /restore_lfc_state
activate ctl
ctl -->> cplane: 202 Accepted
activate cplane
cplane ->> ctl: GET /status: poll prewarm status
ctl ->> s3proxy: GET /read_file
s3proxy ->> s3: read file
s3 -->> s3proxy: file content
s3proxy -->> ctl: 200 OK: file content
proxy ->> cplane: GET /proxy_wake_compute
cplane -->> proxy: 200 OK: old primary conninfo
ctl ->> pg: prewarm LFC
activate pg
pg -->> ctl: prewarm is completed
deactivate pg
ctl -->> cplane: 200 OK: prewarm is completed
deactivate ctl
deactivate cplane
cplane -->> cplane: reassign replica compute to endpoint,<br>start terminating the old primary compute
activate cplane
cplane ->> proxy: invalidate caches
proxy ->> cplane: GET /proxy_wake_compute
cplane -x primary: POST /terminate
primary -->> cplane: 200 OK
note over primary: old primary<br>compute terminated
cplane ->> ctl: POST /promote
activate ctl
ctl ->> pg: pg_ctl promote
activate pg
pg -->> ctl: done
deactivate pg
ctl -->> cplane: 200 OK
deactivate ctl
cplane -->> cplane: finalize operation
cplane -->> proxy: 200 OK: new primary conninfo
deactivate cplane
```
### Network bandwidth and prewarm speed
It's currently known that pageserver can sustain about 3000 RPS per shard for a few running computes.
Large tenants are usually split into 8 shards, so the final formula may look like this:
```text
8 shards * 3000 RPS * 8 KB =~ 190 MB/s
```
so depending on the LFC size, prewarming will take at least:
- ~5s for 1 GB
- ~50s for 10 GB
- ~9m for 100 GB
- \>1h for 1 TB
In total, one pageserver is normally capped at 30k RPS, so it obviously can't sustain many computes
doing prewarm at the same time. Later, we may need an additional mechanism for computes to throttle
the prewarming requests gracefully.
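The back-of-the-envelope numbers above can be captured in a tiny helper (constants taken from the formula: 8 KB pages, 3000 RPS per shard):
```rust
/// Rough prewarm duration estimate: throughput = shards * per-shard RPS * 8 KB pages.
fn estimated_prewarm_secs(lfc_size_bytes: u64, shards: u64, rps_per_shard: u64) -> u64 {
    let bytes_per_sec = shards * rps_per_shard * 8 * 1024;
    lfc_size_bytes.div_ceil(bytes_per_sec)
}

fn main() {
    // 100 GB of LFC on an 8-shard tenant: roughly 550 seconds (~9 minutes).
    let secs = estimated_prewarm_secs(100 * 1024 * 1024 * 1024, 8, 3000);
    println!("estimated prewarm time: {secs}s");
}
```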
### Reliability, failure modes and corner cases
We consider the following failures while implementing this RFC:
1. Compute got interrupted/crashed/restarted during prewarm. The caller -- control plane -- should
detect that and start prewarm from the beginning.
2. Control plane promotion request timed out or hit network issues. If it never reached the
compute, the control plane should just repeat it. If it did reach the compute, then during
the retry the control plane can hit `429`, as the previous request already triggered the promotion.
In this case, the control plane needs to retry until either `200` (or `412` once the promotion has completed) or a
permanent error `500` is returned.
3. Compute got interrupted/crashed/restarted during promotion. At restart it will ask for
a spec from control plane, and its content should signal compute to start as **primary**,
so it's expected that control plane will continue polling for certain period of time and
will discover that compute is ready to accept connections if restart is fast enough.
4. Any other unexpected failure or timeout during prewarming. This **failure mustn't be fatal**,
the control plane has to report the failure, terminate the replica, and keep the primary running.
5. Any other unexpected failure or timeout during promotion. Unfortunately, at this moment
we already have the primary node stopped, so the only option is to start the primary again
and proceed with auto-prewarm.
6. Any unexpected failure during auto-prewarm. This **failure mustn't be fatal**,
`compute_ctl` has to report the failure, but not crash the compute.
7. Control plane failed to confirm that the old primary has terminated. This can happen, especially
in the future HA setup. In this case, the control plane has to ensure that it has sent VM deletion
and pod termination requests to k8s, so that long-term we do not have two running primaries
on the same timeline.
### Security implications
There are two security implications to consider:
1. Access to the `compute_ctl` API. It has to be accessible from outside the compute, so all
new API methods have to be exposed on the **external** HTTP port and **must** be authenticated
with JWT.
2. Read/write only your own LFC state data in S3. Although it's not really a security concern,
since the LFC state is just a mapping of the blocks present in the LFC at a certain moment in time,
it still has to be highly restricted, so that i) only computes on the same timeline can
read the S3 state; ii) each compute can only write to the path that contains its `endpoint_id`.
Both of these must be validated by the Endpoint storage service using the JWT token provided by `compute_ctl`.
### Unresolved questions
#### Billing, metrics and monitoring
Currently, we only label computes with `endpoint_id` after attaching them to the endpoint.
In this proposal, this means that the temporary replica will remain unlabelled until it's promoted
to primary. We can also hide it from users in the control plane API, but what to do with
billing and monitoring is still unclear.
We can probably mark it as 'billable' and tag it with `project_id`, so it will be billed but
won't interfere in any way with the current primary monitoring.
Another thing to consider is how logs and metrics export will switch to the new compute.
It's expected that OpenTelemetry collector will auto-discover the new compute and start
scraping metrics from it.
#### Auto-prewarm
It's still an open question whether we need auto-prewarm at all. The author's gut-feeling is
that yes, we need it, but maybe not for all workloads, so it could end up exposed as a
user-controllable knob on the endpoint. There are two arguments for that:
1. Auto-prewarm exists in upstream's `pg_prewarm`, _probably for a reason_.
2. There could still be two flows where we cannot perform the rolling restart via the warm
replica: i) any failure or interruption during promotion; ii) wake-up after scale-to-zero.
The latter might be challenged as well, i.e. one can argue that auto-prewarm may and will
compete with the user workload for storage resources. This is correct, but it may also
reduce the time to get a warm LFC and good performance.
#### Low-level details of the replica promotion
There are many things to consider here, but three items just off the top of my head:
1. How to properly start the `walproposer` inside Postgres.
2. What to do with logical replication. Currently, we do not include logical replication slots
inside the basebackup, because nobody advances them on the replica, so they just prevent WAL
deletion. Yet, we do need to have them on the primary after promotion. Starting with Postgres 17,
there is a new feature called
[logical replication failover](https://www.postgresql.org/docs/current/logical-replication-failover.html)
and `synchronized_standby_slots` setting, but we need a plan for the older versions. Should we
request a new basebackup during promotion?
3. How do we guarantee that the replica will receive all the latest WAL from safekeepers? Do we run some
'shallow' version of sync safekeepers without data copying, or just the standard version of
sync safekeepers?
## Alternative implementation
The proposal already assumes one of the alternatives -- not having any persistent storage for
the LFC state. This is possible to implement faster with the proposed API, but it means that
we do not implement auto-prewarm yet.
## Definition of Done
At the end of implementing this RFC we should have two high-level settings that enable:
1. Auto-prewarm of user computes upon restart.
2. Perform primary compute restart via the warm replica promotion.
It also has to be decided what the criteria are for enabling one or both of these flows for
certain clients.

View File

@@ -58,7 +58,7 @@ pub enum LfcPrewarmState {
},
}
#[derive(Serialize, Default, Debug, Clone)]
#[derive(Serialize, Default, Debug, Clone, PartialEq)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum LfcOffloadState {
#[default]

View File

@@ -185,9 +185,18 @@ pub struct ComputeSpec {
/// JWT for authorizing requests to endpoint storage service
pub endpoint_storage_token: Option<String>,
/// Download LFC state from endpoint_storage and pass it to Postgres on startup
#[serde(default)]
/// Download LFC state from endpoint storage and pass it to Postgres on compute startup
pub autoprewarm: bool,
#[serde(default)]
/// Upload LFC state to endpoint storage periodically. Default value (None) means "don't upload"
pub offload_lfc_interval_seconds: Option<std::num::NonZeroU64>,
/// Suspend timeout in seconds.
///
/// We use this value to derive other values, such as the installed extensions metric.
pub suspend_timeout_seconds: i64,
}
/// Feature flag to signal `compute_ctl` to enable certain experimental functionality.

View File

@@ -3,6 +3,7 @@
"timestamp": "2021-05-23T18:25:43.511Z",
"operation_uuid": "0f657b36-4b0f-4a2d-9c2e-1dcd615e7d8b",
"suspend_timeout_seconds": 3600,
"cluster": {
"cluster_id": "test-cluster-42",
@@ -89,6 +90,11 @@
"value": "off",
"vartype": "bool"
},
{
"name": "offload_lfc_interval_seconds",
"value": "20",
"vartype": "integer"
},
{
"name": "neon.safekeepers",
"value": "127.0.0.1:6502,127.0.0.1:6503,127.0.0.1:6501",

View File

@@ -386,6 +386,7 @@ pub enum NodeSchedulingPolicy {
Pause,
PauseForRestart,
Draining,
Deleting,
}
impl FromStr for NodeSchedulingPolicy {
@@ -398,6 +399,7 @@ impl FromStr for NodeSchedulingPolicy {
"pause" => Ok(Self::Pause),
"pause_for_restart" => Ok(Self::PauseForRestart),
"draining" => Ok(Self::Draining),
"deleting" => Ok(Self::Deleting),
_ => Err(anyhow::anyhow!("Unknown scheduling state '{s}'")),
}
}
@@ -412,6 +414,7 @@ impl From<NodeSchedulingPolicy> for String {
Pause => "pause",
PauseForRestart => "pause_for_restart",
Draining => "draining",
Deleting => "deleting",
}
.to_string()
}
@@ -420,6 +423,7 @@ impl From<NodeSchedulingPolicy> for String {
#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug)]
pub enum SkSchedulingPolicy {
Active,
Activating,
Pause,
Decomissioned,
}
@@ -430,6 +434,7 @@ impl FromStr for SkSchedulingPolicy {
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(match s {
"active" => Self::Active,
"activating" => Self::Activating,
"pause" => Self::Pause,
"decomissioned" => Self::Decomissioned,
_ => {
@@ -446,6 +451,7 @@ impl From<SkSchedulingPolicy> for String {
use SkSchedulingPolicy::*;
match value {
Active => "active",
Activating => "activating",
Pause => "pause",
Decomissioned => "decomissioned",
}

View File

@@ -78,7 +78,13 @@ pub fn is_expected_io_error(e: &io::Error) -> bool {
use io::ErrorKind::*;
matches!(
e.kind(),
BrokenPipe | ConnectionRefused | ConnectionAborted | ConnectionReset | TimedOut
HostUnreachable
| NetworkUnreachable
| BrokenPipe
| ConnectionRefused
| ConnectionAborted
| ConnectionReset
| TimedOut,
)
}

View File

@@ -52,7 +52,7 @@ pub(crate) async fn hi(str: &[u8], salt: &[u8], iterations: u32) -> [u8; 32] {
}
// yield every ~250us
// hopefully reduces tail latencies
if i % 1024 == 0 {
if i.is_multiple_of(1024) {
yield_now().await
}
}

View File

@@ -90,7 +90,7 @@ pub struct InnerClient {
}
impl InnerClient {
pub fn start(&mut self) -> Result<PartialQuery, Error> {
pub fn start(&mut self) -> Result<PartialQuery<'_>, Error> {
self.responses.waiting += 1;
Ok(PartialQuery(Some(self)))
}
@@ -227,7 +227,7 @@ impl Client {
&mut self,
statement: &str,
params: I,
) -> Result<RowStream, Error>
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,
@@ -262,7 +262,7 @@ impl Client {
pub(crate) async fn simple_query_raw(
&mut self,
query: &str,
) -> Result<SimpleQueryStream, Error> {
) -> Result<SimpleQueryStream<'_>, Error> {
simple_query::simple_query(self.inner_mut(), query).await
}

View File

@@ -12,7 +12,11 @@ mod private {
/// This trait is "sealed", and cannot be implemented outside of this crate.
pub trait GenericClient: private::Sealed {
/// Like `Client::query_raw_txt`.
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
@@ -22,7 +26,11 @@ pub trait GenericClient: private::Sealed {
impl private::Sealed for Client {}
impl GenericClient for Client {
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,
@@ -35,7 +43,11 @@ impl GenericClient for Client {
impl private::Sealed for Transaction<'_> {}
impl GenericClient for Transaction<'_> {
async fn query_raw_txt<S, I>(&mut self, statement: &str, params: I) -> Result<RowStream, Error>
async fn query_raw_txt<S, I>(
&mut self,
statement: &str,
params: I,
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str> + Sync + Send,
I: IntoIterator<Item = Option<S>> + Sync + Send,

View File

@@ -47,7 +47,7 @@ impl<'a> Transaction<'a> {
&mut self,
statement: &str,
params: I,
) -> Result<RowStream, Error>
) -> Result<RowStream<'_>, Error>
where
S: AsRef<str>,
I: IntoIterator<Item = Option<S>>,

View File

@@ -24,12 +24,28 @@ macro_rules! critical {
if cfg!(debug_assertions) {
panic!($($arg)*);
}
// Increment both metrics
$crate::logging::TRACING_EVENT_COUNT_METRIC.inc_critical();
let backtrace = std::backtrace::Backtrace::capture();
tracing::error!("CRITICAL: {}\n{backtrace}", format!($($arg)*));
}};
}
#[macro_export]
macro_rules! critical_timeline {
($tenant_shard_id:expr, $timeline_id:expr, $($arg:tt)*) => {{
if cfg!(debug_assertions) {
panic!($($arg)*);
}
// Increment both metrics
$crate::logging::TRACING_EVENT_COUNT_METRIC.inc_critical();
$crate::logging::HADRON_CRITICAL_STORAGE_EVENT_COUNT_METRIC.inc(&$tenant_shard_id.to_string(), &$timeline_id.to_string());
let backtrace = std::backtrace::Backtrace::capture();
tracing::error!("CRITICAL: [tenant_shard_id: {}, timeline_id: {}] {}\n{backtrace}",
$tenant_shard_id, $timeline_id, format!($($arg)*));
}};
}
#[derive(EnumString, strum_macros::Display, VariantNames, Eq, PartialEq, Debug, Clone, Copy)]
#[strum(serialize_all = "snake_case")]
pub enum LogFormat {
@@ -61,6 +77,36 @@ pub struct TracingEventCountMetric {
trace: IntCounter,
}
// Begin Hadron: Add a HadronCriticalStorageEventCountMetric metric that is sliced by tenant_id and timeline_id
pub struct HadronCriticalStorageEventCountMetric {
critical: IntCounterVec,
}
pub static HADRON_CRITICAL_STORAGE_EVENT_COUNT_METRIC: Lazy<HadronCriticalStorageEventCountMetric> =
Lazy::new(|| {
let vec = metrics::register_int_counter_vec!(
"hadron_critical_storage_event_count",
"Number of critical storage events, by tenant_id and timeline_id",
&["tenant_shard_id", "timeline_id"]
)
.expect("failed to define metric");
HadronCriticalStorageEventCountMetric::new(vec)
});
impl HadronCriticalStorageEventCountMetric {
fn new(vec: IntCounterVec) -> Self {
Self { critical: vec }
}
// Allow public access from `critical!` macro.
pub fn inc(&self, tenant_shard_id: &str, timeline_id: &str) {
self.critical
.with_label_values(&[tenant_shard_id, timeline_id])
.inc();
}
}
// End Hadron
pub static TRACING_EVENT_COUNT_METRIC: Lazy<TracingEventCountMetric> = Lazy::new(|| {
let vec = metrics::register_int_counter_vec!(
"libmetrics_tracing_event_count",

View File

@@ -28,6 +28,7 @@ use reqwest::Url;
use storage_broker::Uri;
use utils::id::{NodeId, TimelineId};
use utils::logging::{LogFormat, SecretString};
use utils::serde_percent::Percent;
use crate::tenant::storage_layer::inmemory_layer::IndexEntry;
use crate::tenant::{TENANTS_SEGMENT_NAME, TIMELINES_SEGMENT_NAME};
@@ -459,7 +460,16 @@ impl PageServerConf {
metric_collection_endpoint,
metric_collection_bucket,
synthetic_size_calculation_interval,
disk_usage_based_eviction,
disk_usage_based_eviction: Some(disk_usage_based_eviction.unwrap_or(
DiskUsageEvictionTaskConfig {
max_usage_pct: Percent::new(80).unwrap(),
min_avail_bytes: 2_000_000_000,
period: Duration::from_secs(60),
#[cfg(feature = "testing")]
mock_statvfs: None,
eviction_order: Default::default(),
},
)),
test_remote_failures,
ondemand_download_behavior_treat_error_as_warn,
background_task_maximum_delay,
@@ -697,6 +707,8 @@ impl ConfigurableSemaphore {
#[cfg(test)]
mod tests {
use std::time::Duration;
use camino::Utf8PathBuf;
use rstest::rstest;
use utils::id::NodeId;
@@ -798,4 +810,20 @@ mod tests {
PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir)
.expect("parse_and_validate");
}
#[test]
fn test_config_disk_usage_based_eviction_is_valid() {
let input = r#"
control_plane_api = "http://localhost:6666"
"#;
let config_toml = toml_edit::de::from_str::<pageserver_api::config::ConfigToml>(input)
.expect("disk_usage_based_eviction is valid");
let workdir = Utf8PathBuf::from("/nonexistent");
let config = PageServerConf::parse_and_validate(NodeId(0), config_toml, &workdir).unwrap();
let disk_usage_based_eviction = config.disk_usage_based_eviction.unwrap();
assert_eq!(disk_usage_based_eviction.max_usage_pct.get(), 80);
assert_eq!(disk_usage_based_eviction.min_avail_bytes, 2_000_000_000);
assert_eq!(disk_usage_based_eviction.period, Duration::from_secs(60));
assert_eq!(disk_usage_based_eviction.eviction_order, Default::default());
}
}

View File

@@ -99,7 +99,7 @@ pub(super) async fn upload_metrics_bucket(
// Compose object path
let datetime: DateTime<Utc> = SystemTime::now().into();
let ts_prefix = datetime.format("year=%Y/month=%m/day=%d/%H:%M:%SZ");
let ts_prefix = datetime.format("year=%Y/month=%m/day=%d/hour=%H/%H:%M:%SZ");
let path = RemotePath::from_string(&format!("{ts_prefix}_{node_id}.ndjson.gz"))?;
// Set up a gzip writer into a buffer
@@ -109,7 +109,7 @@ pub(super) async fn upload_metrics_bucket(
// Serialize and write into compressed buffer
let started_at = std::time::Instant::now();
for res in serialize_in_chunks(CHUNK_SIZE, metrics, idempotency_keys) {
for res in serialize_in_chunks_ndjson(CHUNK_SIZE, metrics, idempotency_keys) {
let (_chunk, body) = res?;
gzip_writer.write_all(&body).await?;
}
@@ -216,6 +216,86 @@ fn serialize_in_chunks<'a>(
}
}
/// Serializes the input metrics as NDJSON in chunks of chunk_size. Each event
/// is serialized as a separate JSON object on its own line. The provided
/// idempotency keys are injected into the corresponding metric events (reused
/// across different metrics sinks), and must have the same length as input.
fn serialize_in_chunks_ndjson<'a>(
chunk_size: usize,
input: &'a [NewRawMetric],
idempotency_keys: &'a [IdempotencyKey<'a>],
) -> impl ExactSizeIterator<Item = Result<(&'a [NewRawMetric], bytes::Bytes), serde_json::Error>> + 'a
{
use bytes::BufMut;
assert_eq!(input.len(), idempotency_keys.len());
struct Iter<'a> {
inner: std::slice::Chunks<'a, NewRawMetric>,
idempotency_keys: std::slice::Iter<'a, IdempotencyKey<'a>>,
chunk_size: usize,
// write to a BytesMut so that we can cheaply clone the frozen Bytes for retries
buffer: bytes::BytesMut,
// chunk amount of events are reused to produce the serialized document
scratch: Vec<Event<Ids, Name>>,
}
impl<'a> Iterator for Iter<'a> {
type Item = Result<(&'a [NewRawMetric], bytes::Bytes), serde_json::Error>;
fn next(&mut self) -> Option<Self::Item> {
let chunk = self.inner.next()?;
if self.scratch.is_empty() {
// first round: create events with N strings
self.scratch.extend(
chunk
.iter()
.zip(&mut self.idempotency_keys)
.map(|(raw_metric, key)| raw_metric.as_event(key)),
);
} else {
// next rounds: update_in_place to reuse allocations
assert_eq!(self.scratch.len(), self.chunk_size);
itertools::izip!(self.scratch.iter_mut(), chunk, &mut self.idempotency_keys)
.for_each(|(slot, raw_metric, key)| raw_metric.update_in_place(slot, key));
}
// Serialize each event as NDJSON (one JSON object per line)
for event in self.scratch[..chunk.len()].iter() {
let res = serde_json::to_writer((&mut self.buffer).writer(), event);
if let Err(e) = res {
return Some(Err(e));
}
// Add newline after each event to follow NDJSON format
self.buffer.put_u8(b'\n');
}
Some(Ok((chunk, self.buffer.split().freeze())))
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}
impl ExactSizeIterator for Iter<'_> {}
let buffer = bytes::BytesMut::new();
let inner = input.chunks(chunk_size);
let idempotency_keys = idempotency_keys.iter();
let scratch = Vec::new();
Iter {
inner,
idempotency_keys,
chunk_size,
buffer,
scratch,
}
}
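// Illustrative consumption sketch (not part of the patch), mirroring the upload loop
// above: each yielded `Bytes` is one NDJSON document of `chunk.len()` newline-terminated
// JSON objects, frozen from the shared `BytesMut` so it can be cloned cheaply for retries.
//
//     for res in serialize_in_chunks_ndjson(CHUNK_SIZE, metrics, idempotency_keys) {
//         let (_chunk, body) = res?;
//         gzip_writer.write_all(&body).await?;
//     }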
trait RawMetricExt {
fn as_event(&self, key: &IdempotencyKey<'_>) -> Event<Ids, Name>;
fn update_in_place(&self, event: &mut Event<Ids, Name>, key: &IdempotencyKey<'_>);
@@ -479,6 +559,43 @@ mod tests {
}
}
#[test]
fn chunked_serialization_ndjson() {
let examples = metric_samples();
assert!(examples.len() > 1);
let now = Utc::now();
let idempotency_keys = (0..examples.len())
.map(|i| FixedGen::new(now, "1", i as u16).generate())
.collect::<Vec<_>>();
// Parse NDJSON format - each line is a separate JSON object
let parse_ndjson = |body: &[u8]| -> Vec<Event<Ids, Name>> {
let body_str = std::str::from_utf8(body).unwrap();
body_str
.trim_end_matches('\n')
.lines()
.filter(|line| !line.is_empty())
.map(|line| serde_json::from_str::<Event<Ids, Name>>(line).unwrap())
.collect()
};
let correct = serialize_in_chunks_ndjson(examples.len(), &examples, &idempotency_keys)
.map(|res| res.unwrap().1)
.flat_map(|body| parse_ndjson(&body))
.collect::<Vec<_>>();
for chunk_size in 1..examples.len() {
let actual = serialize_in_chunks_ndjson(chunk_size, &examples, &idempotency_keys)
.map(|res| res.unwrap().1)
.flat_map(|body| parse_ndjson(&body))
.collect::<Vec<_>>();
// if these are equal, the multi-chunk version produces the same events as serializing everything in one chunk
assert_eq!(correct, actual);
}
}
#[derive(Clone, Copy)]
struct FixedGen<'a>(chrono::DateTime<chrono::Utc>, &'a str, u16);

View File

@@ -6,12 +6,13 @@ use posthog_client_lite::{
CaptureEvent, FeatureResolverBackgroundLoop, PostHogEvaluationError,
PostHogFlagFilterPropertyValue,
};
use rand::Rng;
use remote_storage::RemoteStorageKind;
use serde_json::json;
use tokio_util::sync::CancellationToken;
use utils::id::TenantId;
use crate::{config::PageServerConf, metrics::FEATURE_FLAG_EVALUATION};
use crate::{config::PageServerConf, metrics::FEATURE_FLAG_EVALUATION, tenant::TenantShard};
const DEFAULT_POSTHOG_REFRESH_INTERVAL: Duration = Duration::from_secs(600);
@@ -138,6 +139,7 @@ impl FeatureResolver {
}
Arc::new(properties)
};
let fake_tenants = {
let mut tenants = Vec::new();
for i in 0..10 {
@@ -147,9 +149,16 @@ impl FeatureResolver {
conf.id,
i
);
let tenant_properties = PerTenantProperties {
remote_size_mb: Some(rand::thread_rng().gen_range(100.0..1000000.00)),
}
.into_posthog_properties();
let properties = Self::collect_properties_inner(
distinct_id.clone(),
Some(&internal_properties),
&tenant_properties,
);
tenants.push(CaptureEvent {
event: "initial_tenant_report".to_string(),
@@ -183,6 +192,7 @@ impl FeatureResolver {
fn collect_properties_inner(
tenant_id: String,
internal_properties: Option<&HashMap<String, PostHogFlagFilterPropertyValue>>,
tenant_properties: &HashMap<String, PostHogFlagFilterPropertyValue>,
) -> HashMap<String, PostHogFlagFilterPropertyValue> {
let mut properties = HashMap::new();
if let Some(internal_properties) = internal_properties {
@@ -194,6 +204,9 @@ impl FeatureResolver {
"tenant_id".to_string(),
PostHogFlagFilterPropertyValue::String(tenant_id),
);
for (key, value) in tenant_properties.iter() {
properties.insert(key.clone(), value.clone());
}
properties
}
@@ -201,8 +214,13 @@ impl FeatureResolver {
pub(crate) fn collect_properties(
&self,
tenant_id: TenantId,
tenant_properties: &HashMap<String, PostHogFlagFilterPropertyValue>,
) -> HashMap<String, PostHogFlagFilterPropertyValue> {
Self::collect_properties_inner(tenant_id.to_string(), self.internal_properties.as_deref())
Self::collect_properties_inner(
tenant_id.to_string(),
self.internal_properties.as_deref(),
tenant_properties,
)
}
/// Evaluate a multivariate feature flag. Currently, we do not support any properties.
@@ -214,6 +232,7 @@ impl FeatureResolver {
&self,
flag_key: &str,
tenant_id: TenantId,
tenant_properties: &HashMap<String, PostHogFlagFilterPropertyValue>,
) -> Result<String, PostHogEvaluationError> {
let force_overrides = self.force_overrides_for_testing.load();
if let Some(value) = force_overrides.get(flag_key) {
@@ -224,7 +243,7 @@ impl FeatureResolver {
let res = inner.feature_store().evaluate_multivariate(
flag_key,
&tenant_id.to_string(),
&self.collect_properties(tenant_id),
&self.collect_properties(tenant_id, tenant_properties),
);
match &res {
Ok(value) => {
@@ -257,6 +276,7 @@ impl FeatureResolver {
&self,
flag_key: &str,
tenant_id: TenantId,
tenant_properties: &HashMap<String, PostHogFlagFilterPropertyValue>,
) -> Result<(), PostHogEvaluationError> {
let force_overrides = self.force_overrides_for_testing.load();
if let Some(value) = force_overrides.get(flag_key) {
@@ -271,7 +291,7 @@ impl FeatureResolver {
let res = inner.feature_store().evaluate_boolean(
flag_key,
&tenant_id.to_string(),
&self.collect_properties(tenant_id),
&self.collect_properties(tenant_id, tenant_properties),
);
match &res {
Ok(()) => {
@@ -317,3 +337,78 @@ impl FeatureResolver {
.store(Arc::new(force_overrides));
}
}
struct PerTenantProperties {
pub remote_size_mb: Option<f64>,
}
impl PerTenantProperties {
pub fn into_posthog_properties(self) -> HashMap<String, PostHogFlagFilterPropertyValue> {
let mut properties = HashMap::new();
if let Some(remote_size_mb) = self.remote_size_mb {
properties.insert(
"tenant_remote_size_mb".to_string(),
PostHogFlagFilterPropertyValue::Number(remote_size_mb),
);
}
properties
}
}
#[derive(Clone)]
pub struct TenantFeatureResolver {
inner: FeatureResolver,
tenant_id: TenantId,
cached_tenant_properties: Arc<ArcSwap<HashMap<String, PostHogFlagFilterPropertyValue>>>,
}
impl TenantFeatureResolver {
pub fn new(inner: FeatureResolver, tenant_id: TenantId) -> Self {
Self {
inner,
tenant_id,
cached_tenant_properties: Arc::new(ArcSwap::new(Arc::new(HashMap::new()))),
}
}
pub fn evaluate_multivariate(&self, flag_key: &str) -> Result<String, PostHogEvaluationError> {
self.inner.evaluate_multivariate(
flag_key,
self.tenant_id,
&self.cached_tenant_properties.load(),
)
}
pub fn evaluate_boolean(&self, flag_key: &str) -> Result<(), PostHogEvaluationError> {
self.inner.evaluate_boolean(
flag_key,
self.tenant_id,
&self.cached_tenant_properties.load(),
)
}
pub fn collect_properties(&self) -> HashMap<String, PostHogFlagFilterPropertyValue> {
self.inner
.collect_properties(self.tenant_id, &self.cached_tenant_properties.load())
}
pub fn is_feature_flag_boolean(&self, flag_key: &str) -> Result<bool, PostHogEvaluationError> {
self.inner.is_feature_flag_boolean(flag_key)
}
pub fn update_cached_tenant_properties(&self, tenant_shard: &TenantShard) {
// Sum the resident physical size across timelines as a proxy for remote size;
// if any timeline reports 0 (metric not yet populated), report the size as unknown.
let mut remote_size_mb = Some(0.0);
for timeline in tenant_shard.list_timelines() {
let size = timeline.metrics.resident_physical_size_get();
if size == 0 {
remote_size_mb = None;
break;
}
if let Some(ref mut remote_size_mb) = remote_size_mb {
*remote_size_mb += size as f64 / 1024.0 / 1024.0;
}
}
self.cached_tenant_properties.store(Arc::new(
PerTenantProperties { remote_size_mb }.into_posthog_properties(),
));
}
}
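// Illustrative flow sketch (not part of the patch; the variable name is hypothetical):
// tenant housekeeping periodically refreshes the cached per-tenant properties, and
// later flag evaluations read them without any extra arguments:
//
//     tenant_shard.feature_resolver.update_cached_tenant_properties(&tenant_shard);
//     let _ = tenant_shard.feature_resolver.evaluate_boolean("image-compaction-boundary");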

View File

@@ -2438,6 +2438,7 @@ async fn timeline_offload_handler(
.map_err(|e| {
match e {
OffloadError::Cancelled => ApiError::ResourceUnavailable("Timeline shutting down".into()),
OffloadError::AlreadyInProgress => ApiError::Conflict("Timeline already being offloaded or deleted".into()),
_ => ApiError::InternalServerError(anyhow!(e))
}
})?;
@@ -3697,23 +3698,25 @@ async fn tenant_evaluate_feature_flag(
let tenant = state
.tenant_manager
.get_attached_tenant_shard(tenant_shard_id)?;
let properties = tenant.feature_resolver.collect_properties(tenant_shard_id.tenant_id);
// TODO: the properties we get here might be stale right after they are collected, but such races are rare
// (the cache is refreshed roughly every 10s) and we don't need to worry about them for now.
let properties = tenant.feature_resolver.collect_properties();
if as_type.as_deref() == Some("boolean") {
let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id);
let result = tenant.feature_resolver.evaluate_boolean(&flag);
let result = result.map(|_| true).map_err(|e| e.to_string());
json_response(StatusCode::OK, json!({ "result": result, "properties": properties }))
} else if as_type.as_deref() == Some("multivariate") {
let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string());
let result = tenant.feature_resolver.evaluate_multivariate(&flag).map_err(|e| e.to_string());
json_response(StatusCode::OK, json!({ "result": result, "properties": properties }))
} else {
// Auto infer the type of the feature flag.
let is_boolean = tenant.feature_resolver.is_feature_flag_boolean(&flag).map_err(|e| ApiError::InternalServerError(anyhow::anyhow!("{e}")))?;
if is_boolean {
let result = tenant.feature_resolver.evaluate_boolean(&flag, tenant_shard_id.tenant_id);
let result = tenant.feature_resolver.evaluate_boolean(&flag);
let result = result.map(|_| true).map_err(|e| e.to_string());
json_response(StatusCode::OK, json!({ "result": result, "properties": properties }))
} else {
let result = tenant.feature_resolver.evaluate_multivariate(&flag, tenant_shard_id.tenant_id).map_err(|e| e.to_string());
let result = tenant.feature_resolver.evaluate_multivariate(&flag).map_err(|e| e.to_string());
json_response(StatusCode::OK, json!({ "result": result, "properties": properties }))
}
}

View File

@@ -86,7 +86,7 @@ use crate::context;
use crate::context::RequestContextBuilder;
use crate::context::{DownloadBehavior, RequestContext};
use crate::deletion_queue::{DeletionQueueClient, DeletionQueueError};
use crate::feature_resolver::FeatureResolver;
use crate::feature_resolver::{FeatureResolver, TenantFeatureResolver};
use crate::l0_flush::L0FlushGlobalState;
use crate::metrics::{
BROKEN_TENANTS_SET, CIRCUIT_BREAKERS_BROKEN, CIRCUIT_BREAKERS_UNBROKEN, CONCURRENT_INITDBS,
@@ -386,7 +386,7 @@ pub struct TenantShard {
l0_flush_global_state: L0FlushGlobalState,
pub(crate) feature_resolver: FeatureResolver,
pub(crate) feature_resolver: TenantFeatureResolver,
}
impl std::fmt::Debug for TenantShard {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -3263,7 +3263,7 @@ impl TenantShard {
};
let gc_compaction_strategy = self
.feature_resolver
.evaluate_multivariate("gc-comapction-strategy", self.tenant_shard_id.tenant_id)
.evaluate_multivariate("gc-comapction-strategy")
.ok();
let span = if let Some(gc_compaction_strategy) = gc_compaction_strategy {
info_span!("gc_compact_timeline", timeline_id = %timeline.timeline_id, strategy = %gc_compaction_strategy)
@@ -3285,6 +3285,7 @@ impl TenantShard {
.or_else(|err| match err {
// Ignore this, we likely raced with unarchival.
OffloadError::NotArchived => Ok(()),
OffloadError::AlreadyInProgress => Ok(()),
err => Err(err),
})?;
}
@@ -3408,6 +3409,9 @@ impl TenantShard {
if let Some(ref walredo_mgr) = self.walredo_mgr {
walredo_mgr.maybe_quiesce(WALREDO_IDLE_TIMEOUT);
}
// Update the feature resolver with the latest tenant-specific data.
self.feature_resolver.update_cached_tenant_properties(self);
}
pub fn timeline_has_no_attached_children(&self, timeline_id: TimelineId) -> bool {
@@ -4490,7 +4494,10 @@ impl TenantShard {
gc_block: Default::default(),
l0_flush_global_state,
basebackup_cache,
feature_resolver,
feature_resolver: TenantFeatureResolver::new(
feature_resolver,
tenant_shard_id.tenant_id,
),
}
}

View File

@@ -182,7 +182,7 @@ impl BatchLayerWriter {
/// An image writer that takes images and produces multiple image layers.
#[must_use]
pub struct SplitImageLayerWriter<'a> {
inner: ImageLayerWriter,
inner: Option<ImageLayerWriter>,
target_layer_size: u64,
lsn: Lsn,
conf: &'static PageServerConf,
@@ -196,7 +196,7 @@ pub struct SplitImageLayerWriter<'a> {
impl<'a> SplitImageLayerWriter<'a> {
#[allow(clippy::too_many_arguments)]
pub async fn new(
pub fn new(
conf: &'static PageServerConf,
timeline_id: TimelineId,
tenant_shard_id: TenantShardId,
@@ -205,22 +205,10 @@ impl<'a> SplitImageLayerWriter<'a> {
target_layer_size: u64,
gate: &'a utils::sync::gate::Gate,
cancel: CancellationToken,
ctx: &RequestContext,
) -> anyhow::Result<Self> {
Ok(Self {
) -> Self {
Self {
target_layer_size,
// XXX make this lazy like in SplitDeltaLayerWriter?
inner: ImageLayerWriter::new(
conf,
timeline_id,
tenant_shard_id,
&(start_key..Key::MAX),
lsn,
gate,
cancel.clone(),
ctx,
)
.await?,
inner: None,
conf,
timeline_id,
tenant_shard_id,
@@ -229,7 +217,7 @@ impl<'a> SplitImageLayerWriter<'a> {
start_key,
gate,
cancel,
})
}
}
pub async fn put_image(
@@ -238,12 +226,31 @@ impl<'a> SplitImageLayerWriter<'a> {
img: Bytes,
ctx: &RequestContext,
) -> Result<(), PutError> {
if self.inner.is_none() {
self.inner = Some(
ImageLayerWriter::new(
self.conf,
self.timeline_id,
self.tenant_shard_id,
&(self.start_key..Key::MAX),
self.lsn,
self.gate,
self.cancel.clone(),
ctx,
)
.await
.map_err(PutError::Other)?,
);
}
let inner = self.inner.as_mut().unwrap();
// The current estimation is an upper bound of the space that the key/image could take
// because we did not consider compression in this estimation. The resulting image layer
// could be smaller than the target size.
let addition_size_estimation = KEY_SIZE as u64 + img.len() as u64;
if self.inner.num_keys() >= 1
&& self.inner.estimated_size() + addition_size_estimation >= self.target_layer_size
if inner.num_keys() >= 1
&& inner.estimated_size() + addition_size_estimation >= self.target_layer_size
{
let next_image_writer = ImageLayerWriter::new(
self.conf,
@@ -257,7 +264,7 @@ impl<'a> SplitImageLayerWriter<'a> {
)
.await
.map_err(PutError::Other)?;
let prev_image_writer = std::mem::replace(&mut self.inner, next_image_writer);
let prev_image_writer = std::mem::replace(inner, next_image_writer);
self.batches.add_unfinished_image_writer(
prev_image_writer,
self.start_key..key,
@@ -265,7 +272,7 @@ impl<'a> SplitImageLayerWriter<'a> {
);
self.start_key = key;
}
self.inner.put_image(key, img, ctx).await
inner.put_image(key, img, ctx).await
}
pub(crate) async fn finish_with_discard_fn<D, F>(
@@ -282,8 +289,10 @@ impl<'a> SplitImageLayerWriter<'a> {
let Self {
mut batches, inner, ..
} = self;
if inner.num_keys() != 0 {
batches.add_unfinished_image_writer(inner, self.start_key..end_key, self.lsn);
if let Some(inner) = inner {
if inner.num_keys() != 0 {
batches.add_unfinished_image_writer(inner, self.start_key..end_key, self.lsn);
}
}
batches.finish_with_discard_fn(tline, ctx, discard_fn).await
}
@@ -498,10 +507,7 @@ mod tests {
4 * 1024 * 1024,
&tline.gate,
tline.cancel.clone(),
&ctx,
)
.await
.unwrap();
);
let mut delta_writer = SplitDeltaLayerWriter::new(
tenant.conf,
@@ -577,10 +583,7 @@ mod tests {
4 * 1024 * 1024,
&tline.gate,
tline.cancel.clone(),
&ctx,
)
.await
.unwrap();
);
let mut delta_writer = SplitDeltaLayerWriter::new(
tenant.conf,
tline.timeline_id,
@@ -676,10 +679,7 @@ mod tests {
4 * 1024,
&tline.gate,
tline.cancel.clone(),
&ctx,
)
.await
.unwrap();
);
let mut delta_writer = SplitDeltaLayerWriter::new(
tenant.conf,

View File

@@ -78,7 +78,7 @@ use utils::rate_limit::RateLimit;
use utils::seqwait::SeqWait;
use utils::simple_rcu::{Rcu, RcuReadGuard};
use utils::sync::gate::{Gate, GateGuard};
use utils::{completion, critical, fs_ext, pausable_failpoint};
use utils::{completion, critical_timeline, fs_ext, pausable_failpoint};
#[cfg(test)]
use wal_decoder::models::value::Value;
use wal_decoder::serialized_batch::{SerializedValueBatch, ValueMeta};
@@ -106,7 +106,7 @@ use crate::context::{
DownloadBehavior, PerfInstrumentFutureExt, RequestContext, RequestContextBuilder,
};
use crate::disk_usage_eviction_task::{DiskUsageEvictionInfo, EvictionCandidate, finite_f32};
use crate::feature_resolver::FeatureResolver;
use crate::feature_resolver::TenantFeatureResolver;
use crate::keyspace::{KeyPartitioning, KeySpace};
use crate::l0_flush::{self, L0FlushGlobalState};
use crate::metrics::{
@@ -202,7 +202,7 @@ pub struct TimelineResources {
pub l0_compaction_trigger: Arc<Notify>,
pub l0_flush_global_state: l0_flush::L0FlushGlobalState,
pub basebackup_cache: Arc<BasebackupCache>,
pub feature_resolver: FeatureResolver,
pub feature_resolver: TenantFeatureResolver,
}
pub struct Timeline {
@@ -450,7 +450,7 @@ pub struct Timeline {
/// A channel to send async requests to prepare a basebackup for the basebackup cache.
basebackup_cache: Arc<BasebackupCache>,
feature_resolver: FeatureResolver,
feature_resolver: TenantFeatureResolver,
}
pub(crate) enum PreviousHeatmap {
@@ -2144,14 +2144,31 @@ impl Timeline {
debug_assert_current_span_has_tenant_and_timeline_id();
// Regardless of whether we're going to try_freeze_and_flush
// or not, stop ingesting any more data.
// cancel walreceiver to stop ingesting more data asap.
//
// Note that we're accepting a race condition here where we may
// do the final flush below, before walreceiver observes the
// cancellation and exits.
// This means we may open a new InMemoryLayer after the final flush below.
// The flush loop is also still running for a short while, so, in theory,
// that new layer could also make its way into the upload queue.
//
// If we waited for the walreceiver to shut down before moving on to the
// flush, this race would be avoided. We don't do that because the
// walreceiver serves reads internally, and those reads may depend on
// layer downloads. Layer downloads are only sensitive to cancellation of
// the entire timeline, so cancelling the walreceiver has no effect on the
// individual get requests. Waiting would therefore cause problems when
// there are many ongoing downloads or S3 is unavailable: detach, deletion,
// etc. would hang, and we couldn't deallocate the timeline's resources.
let walreceiver = self.walreceiver.lock().unwrap().take();
tracing::debug!(
is_some = walreceiver.is_some(),
"Waiting for WalReceiverManager..."
);
if let Some(walreceiver) = walreceiver {
walreceiver.shutdown().await;
walreceiver.cancel().await;
}
// ... and inform any waiters for newer LSNs that there won't be any.
self.last_record_lsn.shutdown();
@@ -4729,7 +4746,7 @@ impl Timeline {
}
// Fetch the next layer to flush, if any.
let (layer, l0_count, frozen_count, frozen_size) = {
let (layer, l0_count, frozen_count, frozen_size, open_layer_size) = {
let layers = self.layers.read(LayerManagerLockHolder::FlushLoop).await;
let Ok(lm) = layers.layer_map() else {
info!("dropping out of flush loop for timeline shutdown");
@@ -4742,8 +4759,13 @@ impl Timeline {
.iter()
.map(|l| l.estimated_in_mem_size())
.sum();
let open_layer_size: u64 = lm
.open_layer
.as_ref()
.map(|l| l.estimated_in_mem_size())
.unwrap_or(0);
let layer = lm.frozen_layers.front().cloned();
(layer, l0_count, frozen_count, frozen_size)
(layer, l0_count, frozen_count, frozen_size, open_layer_size)
// drop 'layers' lock
};
let Some(layer) = layer else {
@@ -4756,7 +4778,7 @@ impl Timeline {
if l0_count >= stall_threshold {
warn!(
"stalling layer flushes for compaction backpressure at {l0_count} \
L0 layers ({frozen_count} frozen layers with {frozen_size} bytes)"
L0 layers ({frozen_count} frozen layers with {frozen_size} bytes, {open_layer_size} bytes in open layer)"
);
let stall_timer = self
.metrics
@@ -4809,7 +4831,7 @@ impl Timeline {
let delay = flush_duration.as_secs_f64();
info!(
"delaying layer flush by {delay:.3}s for compaction backpressure at \
{l0_count} L0 layers ({frozen_count} frozen layers with {frozen_size} bytes)"
{l0_count} L0 layers ({frozen_count} frozen layers with {frozen_size} bytes, {open_layer_size} bytes in open layer)"
);
let _delay_timer = self
.metrics
@@ -5308,6 +5330,7 @@ impl Timeline {
ctx: &RequestContext,
img_range: Range<Key>,
io_concurrency: IoConcurrency,
progress: Option<(usize, usize)>,
) -> Result<ImageLayerCreationOutcome, CreateImageLayersError> {
let mut wrote_keys = false;
@@ -5384,11 +5407,15 @@ impl Timeline {
}
}
let progress_report = progress
.map(|(idx, total)| format!("({idx}/{total}) "))
.unwrap_or_default();
if wrote_keys {
// Normal path: we have written some data into the new image layer for this
// partition, so flush it to disk.
info!(
"produced image layer for rel {}",
"{} produced image layer for rel {}",
progress_report,
ImageLayerName {
key_range: img_range.clone(),
lsn
@@ -5398,7 +5425,12 @@ impl Timeline {
unfinished_image_layer: image_layer_writer,
})
} else {
tracing::debug!("no data in range {}-{}", img_range.start, img_range.end);
tracing::debug!(
"{} no data in range {}-{}",
progress_report,
img_range.start,
img_range.end
);
Ok(ImageLayerCreationOutcome::Empty)
}
}
@@ -5633,7 +5665,8 @@ impl Timeline {
}
}
for partition in partition_parts.iter() {
let total = partition_parts.len();
for (idx, partition) in partition_parts.iter().enumerate() {
if self.cancel.is_cancelled() {
return Err(CreateImageLayersError::Cancelled);
}
@@ -5718,6 +5751,7 @@ impl Timeline {
ctx,
img_range.clone(),
io_concurrency,
Some((idx, total)),
)
.await?
} else {
@@ -6807,7 +6841,11 @@ impl Timeline {
Err(walredo::Error::Cancelled) => return Err(PageReconstructError::Cancelled),
Err(walredo::Error::Other(err)) => {
if fire_critical_error {
critical!("walredo failure during page reconstruction: {err:?}");
critical_timeline!(
self.tenant_shard_id,
self.timeline_id,
"walredo failure during page reconstruction: {err:?}"
);
}
return Err(PageReconstructError::WalRedo(
err.context("reconstruct a page image"),

View File

@@ -9,7 +9,7 @@ use std::ops::{Deref, Range};
use std::sync::Arc;
use std::time::{Duration, Instant};
use super::layer_manager::{LayerManagerLockHolder, LayerManagerReadGuard};
use super::layer_manager::LayerManagerLockHolder;
use super::{
CompactFlags, CompactOptions, CompactionError, CreateImageLayersError, DurationRecorder,
GetVectoredError, ImageLayerCreationMode, LastImageLayerCreationStatus, RecordedDuration,
@@ -36,7 +36,7 @@ use serde::Serialize;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, info_span, trace, warn};
use utils::critical;
use utils::critical_timeline;
use utils::id::TimelineId;
use utils::lsn::Lsn;
use wal_decoder::models::record::NeonWalRecord;
@@ -101,7 +101,11 @@ pub enum GcCompactionQueueItem {
/// Whether the compaction is triggered automatically (determines whether we need to update L2 LSN)
auto: bool,
},
SubCompactionJob(CompactOptions),
SubCompactionJob {
i: usize,
total: usize,
options: CompactOptions,
},
Notify(GcCompactionJobId, Option<Lsn>),
}
@@ -163,7 +167,7 @@ impl GcCompactionQueueItem {
running,
job_id: id.0,
}),
GcCompactionQueueItem::SubCompactionJob(options) => Some(CompactInfoResponse {
GcCompactionQueueItem::SubCompactionJob { options, .. } => Some(CompactInfoResponse {
compact_key_range: options.compact_key_range,
compact_lsn_range: options.compact_lsn_range,
sub_compaction: options.sub_compaction,
@@ -489,7 +493,7 @@ impl GcCompactionQueue {
.map(|job| job.compact_lsn_range.end)
.max()
.unwrap();
for job in jobs {
for (i, job) in jobs.into_iter().enumerate() {
// Unfortunately we need to convert the `GcCompactJob` back to `CompactionOptions`
// until we do further refactors to allow directly call `compact_with_gc`.
let mut flags: EnumSet<CompactFlags> = EnumSet::default();
@@ -507,7 +511,11 @@ impl GcCompactionQueue {
compact_lsn_range: Some(job.compact_lsn_range.into()),
sub_compaction_max_job_size_mb: None,
};
pending_tasks.push(GcCompactionQueueItem::SubCompactionJob(options));
pending_tasks.push(GcCompactionQueueItem::SubCompactionJob {
options,
i,
total: jobs_len,
});
}
if !auto {
@@ -651,7 +659,7 @@ impl GcCompactionQueue {
}
}
}
GcCompactionQueueItem::SubCompactionJob(options) => {
GcCompactionQueueItem::SubCompactionJob { options, i, total } => {
// TODO: error handling, clear the queue if any task fails?
let _gc_guard = match gc_block.start().await {
Ok(guard) => guard,
@@ -663,6 +671,7 @@ impl GcCompactionQueue {
)));
}
};
info!("running gc-compaction subcompaction job {}/{}", i, total);
let res = timeline.compact_with_options(cancel, options, ctx).await;
let compaction_result = match res {
Ok(res) => res,
@@ -1310,7 +1319,7 @@ impl Timeline {
|| cfg!(feature = "testing")
|| self
.feature_resolver
.evaluate_boolean("image-compaction-boundary", self.tenant_shard_id.tenant_id)
.evaluate_boolean("image-compaction-boundary")
.is_ok()
{
let last_repartition_lsn = self.partitioning.read().1;
@@ -1381,7 +1390,11 @@ impl Timeline {
GetVectoredError::MissingKey(_),
) = err
{
critical!("missing key during compaction: {err:?}");
critical_timeline!(
self.tenant_shard_id,
self.timeline_id,
"missing key during compaction: {err:?}"
);
}
})?;
@@ -1409,7 +1422,11 @@ impl Timeline {
// Alert on critical errors that indicate data corruption.
Err(err) if err.is_critical() => {
critical!("could not compact, repartitioning keyspace failed: {err:?}");
critical_timeline!(
self.tenant_shard_id,
self.timeline_id,
"could not compact, repartitioning keyspace failed: {err:?}"
);
}
// Log other errors. No partitioning? This is normal, if the timeline was just created
@@ -1591,13 +1608,15 @@ impl Timeline {
let started = Instant::now();
let mut replace_image_layers = Vec::new();
let total = layers_to_rewrite.len();
for layer in layers_to_rewrite {
for (i, layer) in layers_to_rewrite.into_iter().enumerate() {
if self.cancel.is_cancelled() {
return Err(CompactionError::ShuttingDown);
}
info!(layer=%layer, "rewriting layer after shard split");
info!(layer=%layer, "rewriting layer after shard split: {}/{}", i, total);
let mut image_layer_writer = ImageLayerWriter::new(
self.conf,
self.timeline_id,
@@ -1779,20 +1798,14 @@ impl Timeline {
} = {
let phase1_span = info_span!("compact_level0_phase1");
let ctx = ctx.attached_child();
let mut stats = CompactLevel0Phase1StatsBuilder {
let stats = CompactLevel0Phase1StatsBuilder {
version: Some(2),
tenant_id: Some(self.tenant_shard_id),
timeline_id: Some(self.timeline_id),
..Default::default()
};
let begin = tokio::time::Instant::now();
let phase1_layers_locked = self.layers.read(LayerManagerLockHolder::Compaction).await;
let now = tokio::time::Instant::now();
stats.read_lock_acquisition_micros =
DurationRecorder::Recorded(RecordedDuration(now - begin), now);
self.compact_level0_phase1(
phase1_layers_locked,
stats,
target_file_size,
force_compaction_ignore_threshold,
@@ -1813,16 +1826,19 @@ impl Timeline {
}
/// Level0 files first phase of compaction, explained in the [`Self::compact_legacy`] comment.
async fn compact_level0_phase1<'a>(
self: &'a Arc<Self>,
guard: LayerManagerReadGuard<'a>,
async fn compact_level0_phase1(
self: &Arc<Self>,
mut stats: CompactLevel0Phase1StatsBuilder,
target_file_size: u64,
force_compaction_ignore_threshold: bool,
ctx: &RequestContext,
) -> Result<CompactLevel0Phase1Result, CompactionError> {
stats.read_lock_held_spawn_blocking_startup_micros =
stats.read_lock_acquisition_micros.till_now(); // set by caller
let begin = tokio::time::Instant::now();
let guard = self.layers.read(LayerManagerLockHolder::Compaction).await;
let now = tokio::time::Instant::now();
stats.read_lock_acquisition_micros =
DurationRecorder::Recorded(RecordedDuration(now - begin), now);
let layers = guard.layer_map()?;
let level0_deltas = layers.level0_deltas();
stats.level0_deltas_count = Some(level0_deltas.len());
@@ -1857,6 +1873,12 @@ impl Timeline {
.map(|x| guard.get_from_desc(x))
.collect::<Vec<_>>();
drop_layer_manager_rlock(guard);
// This is the last LSN that we have seen for L0 compaction in the timeline. It might advance
// by the time we finish the compaction, so capture it here.
let l0_last_record_lsn = self.get_last_record_lsn();
// Gather the files to compact in this iteration.
//
// Start with the oldest Level 0 delta file, and collect any other
@@ -1944,9 +1966,7 @@ impl Timeline {
// we don't accidentally use it later in the function.
drop(level0_deltas);
stats.read_lock_held_prerequisites_micros = stats
.read_lock_held_spawn_blocking_startup_micros
.till_now();
stats.compaction_prerequisites_micros = stats.read_lock_acquisition_micros.till_now();
// TODO: replace with streaming k-merge
let all_keys = {
@@ -1968,7 +1988,7 @@ impl Timeline {
all_keys
};
stats.read_lock_held_key_sort_micros = stats.read_lock_held_prerequisites_micros.till_now();
stats.read_lock_held_key_sort_micros = stats.compaction_prerequisites_micros.till_now();
// Determine N largest holes where N is number of compacted layers. The vec is sorted by key range start.
//
@@ -2002,7 +2022,6 @@ impl Timeline {
}
}
let max_holes = deltas_to_compact.len();
let last_record_lsn = self.get_last_record_lsn();
let min_hole_range = (target_file_size / page_cache::PAGE_SZ as u64) as i128;
let min_hole_coverage_size = 3; // TODO: something more flexible?
// min-heap (reserve space for one more element added before eviction)
@@ -2021,8 +2040,12 @@ impl Timeline {
// has not so much sense, because largest holes will corresponds field1/field2 changes.
// But we are mostly interested to eliminate holes which cause generation of excessive image layers.
// That is why it is better to measure size of hole as number of covering image layers.
let coverage_size =
layers.image_coverage(&key_range, last_record_lsn).len();
let coverage_size = {
// TODO: optimize this with copy-on-write layer map.
let guard = self.layers.read(LayerManagerLockHolder::Compaction).await;
let layers = guard.layer_map()?;
layers.image_coverage(&key_range, l0_last_record_lsn).len()
};
if coverage_size >= min_hole_coverage_size {
heap.push(Hole {
key_range,
@@ -2041,7 +2064,6 @@ impl Timeline {
holes
};
stats.read_lock_held_compute_holes_micros = stats.read_lock_held_key_sort_micros.till_now();
drop_layer_manager_rlock(guard);
if self.cancel.is_cancelled() {
return Err(CompactionError::ShuttingDown);
@@ -2382,9 +2404,8 @@ struct CompactLevel0Phase1StatsBuilder {
tenant_id: Option<TenantShardId>,
timeline_id: Option<TimelineId>,
read_lock_acquisition_micros: DurationRecorder,
read_lock_held_spawn_blocking_startup_micros: DurationRecorder,
read_lock_held_key_sort_micros: DurationRecorder,
read_lock_held_prerequisites_micros: DurationRecorder,
compaction_prerequisites_micros: DurationRecorder,
read_lock_held_compute_holes_micros: DurationRecorder,
read_lock_drop_micros: DurationRecorder,
write_layer_files_micros: DurationRecorder,
@@ -2399,9 +2420,8 @@ struct CompactLevel0Phase1Stats {
tenant_id: TenantShardId,
timeline_id: TimelineId,
read_lock_acquisition_micros: RecordedDuration,
read_lock_held_spawn_blocking_startup_micros: RecordedDuration,
read_lock_held_key_sort_micros: RecordedDuration,
read_lock_held_prerequisites_micros: RecordedDuration,
compaction_prerequisites_micros: RecordedDuration,
read_lock_held_compute_holes_micros: RecordedDuration,
read_lock_drop_micros: RecordedDuration,
write_layer_files_micros: RecordedDuration,
@@ -2426,16 +2446,12 @@ impl TryFrom<CompactLevel0Phase1StatsBuilder> for CompactLevel0Phase1Stats {
.read_lock_acquisition_micros
.into_recorded()
.ok_or_else(|| anyhow!("read_lock_acquisition_micros not set"))?,
read_lock_held_spawn_blocking_startup_micros: value
.read_lock_held_spawn_blocking_startup_micros
.into_recorded()
.ok_or_else(|| anyhow!("read_lock_held_spawn_blocking_startup_micros not set"))?,
read_lock_held_key_sort_micros: value
.read_lock_held_key_sort_micros
.into_recorded()
.ok_or_else(|| anyhow!("read_lock_held_key_sort_micros not set"))?,
read_lock_held_prerequisites_micros: value
.read_lock_held_prerequisites_micros
compaction_prerequisites_micros: value
.compaction_prerequisites_micros
.into_recorded()
.ok_or_else(|| anyhow!("read_lock_held_prerequisites_micros not set"))?,
read_lock_held_compute_holes_micros: value
@@ -3503,22 +3519,16 @@ impl Timeline {
// Only create image layers when there are no ancestor branches. TODO: create a covering image layer
// when some condition is met.
let mut image_layer_writer = if !has_data_below {
Some(
SplitImageLayerWriter::new(
self.conf,
self.timeline_id,
self.tenant_shard_id,
job_desc.compaction_key_range.start,
lowest_retain_lsn,
self.get_compaction_target_size(),
&self.gate,
self.cancel.clone(),
ctx,
)
.await
.context("failed to create image layer writer")
.map_err(CompactionError::Other)?,
)
Some(SplitImageLayerWriter::new(
self.conf,
self.timeline_id,
self.tenant_shard_id,
job_desc.compaction_key_range.start,
lowest_retain_lsn,
self.get_compaction_target_size(),
&self.gate,
self.cancel.clone(),
))
} else {
None
};
@@ -4352,6 +4362,7 @@ impl TimelineAdaptor {
ctx,
key_range.clone(),
IoConcurrency::sequential(),
None,
)
.await?;

View File

@@ -182,6 +182,7 @@ pub(crate) async fn generate_tombstone_image_layer(
detached: &Arc<Timeline>,
ancestor: &Arc<Timeline>,
ancestor_lsn: Lsn,
historic_layers_to_copy: &Vec<Layer>,
ctx: &RequestContext,
) -> Result<Option<ResidentLayer>, Error> {
tracing::info!(
@@ -199,6 +200,20 @@ pub(crate) async fn generate_tombstone_image_layer(
let image_lsn = ancestor_lsn;
{
for layer in historic_layers_to_copy {
let desc = layer.layer_desc();
if !desc.is_delta
&& desc.lsn_range.start == image_lsn
&& overlaps_with(&key_range, &desc.key_range)
{
tracing::info!(
layer=%layer, "will copy tombstone from ancestor instead of creating a new one"
);
return Ok(None);
}
}
let layers = detached
.layers
.read(LayerManagerLockHolder::DetachAncestor)
@@ -450,7 +465,8 @@ pub(super) async fn prepare(
Vec::with_capacity(straddling_branchpoint.len() + rest_of_historic.len() + 1);
if let Some(tombstone_layer) =
generate_tombstone_image_layer(detached, &ancestor, ancestor_lsn, ctx).await?
generate_tombstone_image_layer(detached, &ancestor, ancestor_lsn, &rest_of_historic, ctx)
.await?
{
new_layers.push(tombstone_layer.into());
}
@@ -885,7 +901,7 @@ async fn remote_copy(
}
tracing::info!("Deleting orphan layer file to make way for hard linking");
// Delete orphan layer file and try again, to ensure this layer has a well understood source
std::fs::remove_file(adopted_path)
std::fs::remove_file(&adoptee_path)
.map_err(|e| Error::launder(e.into(), Error::Prepare))?;
std::fs::hard_link(adopted_path, &adoptee_path)
.map_err(|e| Error::launder(e.into(), Error::Prepare))?;

View File

@@ -19,6 +19,8 @@ pub(crate) enum OffloadError {
NotArchived,
#[error(transparent)]
RemoteStorage(anyhow::Error),
#[error("Offload or deletion already in progress")]
AlreadyInProgress,
#[error("Unexpected offload error: {0}")]
Other(anyhow::Error),
}
@@ -44,20 +46,26 @@ pub(crate) async fn offload_timeline(
timeline.timeline_id,
TimelineDeleteGuardKind::Offload,
);
if let Err(DeleteTimelineError::HasChildren(children)) = delete_guard_res {
let is_archived = timeline.is_archived();
if is_archived == Some(true) {
tracing::error!("timeline is archived but has non-archived children: {children:?}");
let (timeline, guard) = match delete_guard_res {
Ok(timeline_and_guard) => timeline_and_guard,
Err(DeleteTimelineError::HasChildren(children)) => {
let is_archived = timeline.is_archived();
if is_archived == Some(true) {
tracing::error!("timeline is archived but has non-archived children: {children:?}");
return Err(OffloadError::NotArchived);
}
tracing::info!(
?is_archived,
"timeline is not archived and has unarchived children"
);
return Err(OffloadError::NotArchived);
}
tracing::info!(
?is_archived,
"timeline is not archived and has unarchived children"
);
return Err(OffloadError::NotArchived);
Err(DeleteTimelineError::AlreadyInProgress(_)) => {
tracing::info!("timeline offload or deletion already in progress");
return Err(OffloadError::AlreadyInProgress);
}
Err(e) => return Err(OffloadError::Other(anyhow::anyhow!(e))),
};
let (timeline, guard) =
delete_guard_res.map_err(|e| OffloadError::Other(anyhow::anyhow!(e)))?;
let TimelineOrOffloaded::Timeline(timeline) = timeline else {
tracing::error!("timeline already offloaded, but given timeline object");

View File

@@ -63,7 +63,6 @@ pub struct WalReceiver {
/// All task spawned by [`WalReceiver::start`] and its children are sensitive to this token.
/// It's a child token of [`Timeline`] so that timeline shutdown can cancel WalReceiver tasks early for `freeze_and_flush=true`.
cancel: CancellationToken,
task: tokio::task::JoinHandle<()>,
}
impl WalReceiver {
@@ -80,7 +79,7 @@ impl WalReceiver {
let loop_status = Arc::new(std::sync::RwLock::new(None));
let manager_status = Arc::clone(&loop_status);
let cancel = timeline.cancel.child_token();
let task = WALRECEIVER_RUNTIME.spawn({
let _task = WALRECEIVER_RUNTIME.spawn({
let cancel = cancel.clone();
async move {
debug_assert_current_span_has_tenant_and_timeline_id();
@@ -121,25 +120,14 @@ impl WalReceiver {
Self {
manager_status,
cancel,
task,
}
}
#[instrument(skip_all, level = tracing::Level::DEBUG)]
pub async fn shutdown(self) {
pub async fn cancel(self) {
debug_assert_current_span_has_tenant_and_timeline_id();
debug!("cancelling walreceiver tasks");
self.cancel.cancel();
match self.task.await {
Ok(()) => debug!("Shutdown success"),
Err(je) if je.is_cancelled() => unreachable!("not used"),
Err(je) if je.is_panic() => {
// already logged by panic hook
}
Err(je) => {
error!("shutdown walreceiver task join error: {je}")
}
}
}
pub(crate) fn status(&self) -> Option<ConnectionManagerStatus> {

View File

@@ -100,6 +100,7 @@ pub(super) async fn connection_manager_loop_step(
// with other streams on this client (other connection managers). When
// object goes out of scope, stream finishes in drop() automatically.
let mut broker_subscription = subscribe_for_timeline_updates(broker_client, id, cancel).await?;
let mut broker_reset_interval = tokio::time::interval(tokio::time::Duration::from_secs(30));
debug!("Subscribed for broker timeline updates");
loop {
@@ -156,7 +157,10 @@ pub(super) async fn connection_manager_loop_step(
// Got a new update from the broker
broker_update = broker_subscription.message() /* TODO: review cancellation-safety */ => {
match broker_update {
Ok(Some(broker_update)) => connection_manager_state.register_timeline_update(broker_update),
Ok(Some(broker_update)) => {
broker_reset_interval.reset();
connection_manager_state.register_timeline_update(broker_update);
},
Err(status) => {
match status.code() {
Code::Unknown if status.message().contains("stream closed because of a broken pipe") || status.message().contains("connection reset") || status.message().contains("error reading a body from connection") => {
@@ -178,6 +182,14 @@ pub(super) async fn connection_manager_loop_step(
}
},
_ = broker_reset_interval.tick() => {
if wait_lsn_status.borrow().is_some() {
tracing::warn!("No broker updates received for a while, but waiting for WAL. Re-setting stream ...")
}
broker_subscription = subscribe_for_timeline_updates(broker_client, id, cancel).await?;
},
new_event = async {
// Reminder: this match arm needs to be cancellation-safe.
loop {

View File

@@ -25,7 +25,7 @@ use tokio_postgres::replication::ReplicationStream;
use tokio_postgres::{Client, SimpleQueryMessage, SimpleQueryRow};
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, debug, error, info, trace, warn};
use utils::critical;
use utils::critical_timeline;
use utils::id::NodeId;
use utils::lsn::Lsn;
use utils::pageserver_feedback::PageserverFeedback;
@@ -275,20 +275,12 @@ pub(super) async fn handle_walreceiver_connection(
let copy_stream = replication_client.copy_both_simple(&query).await?;
let mut physical_stream = pin!(ReplicationStream::new(copy_stream));
let walingest_future = WalIngest::new(timeline.as_ref(), startpoint, &ctx);
let walingest_res = select! {
walingest_res = walingest_future => walingest_res,
_ = cancellation.cancelled() => {
// We are doing reads in WalIngest::new, and those can hang as they come from the network.
// Timeline cancellation hits the walreceiver cancellation token before it hits the timeline global one.
debug!("Connection cancelled");
return Err(WalReceiverError::Cancelled);
},
};
let mut walingest = walingest_res.map_err(|e| match e.kind {
crate::walingest::WalIngestErrorKind::Cancelled => WalReceiverError::Cancelled,
_ => WalReceiverError::Other(e.into()),
})?;
let mut walingest = WalIngest::new(timeline.as_ref(), startpoint, &ctx)
.await
.map_err(|e| match e.kind {
crate::walingest::WalIngestErrorKind::Cancelled => WalReceiverError::Cancelled,
_ => WalReceiverError::Other(e.into()),
})?;
let (format, compression) = match protocol {
PostgresClientProtocol::Interpreted {
@@ -368,9 +360,13 @@ pub(super) async fn handle_walreceiver_connection(
match raw_wal_start_lsn.cmp(&expected_wal_start) {
std::cmp::Ordering::Greater => {
let msg = format!(
"Gap in streamed WAL: [{expected_wal_start}, {raw_wal_start_lsn})"
"Gap in streamed WAL: [{expected_wal_start}, {raw_wal_start_lsn}"
);
critical_timeline!(
timeline.tenant_shard_id,
timeline.timeline_id,
"{msg}"
);
critical!("{msg}");
return Err(WalReceiverError::Other(anyhow!(msg)));
}
std::cmp::Ordering::Less => {
@@ -383,7 +379,11 @@ pub(super) async fn handle_walreceiver_connection(
"Received record with next_record_lsn multiple times ({} < {})",
first_rec.next_record_lsn, expected_wal_start
);
critical!("{msg}");
critical_timeline!(
timeline.tenant_shard_id,
timeline.timeline_id,
"{msg}"
);
return Err(WalReceiverError::Other(anyhow!(msg)));
}
}
@@ -452,7 +452,11 @@ pub(super) async fn handle_walreceiver_connection(
// TODO: we can't differentiate cancellation errors with
// anyhow::Error, so just ignore it if we're cancelled.
if !cancellation.is_cancelled() && !timeline.is_stopping() {
critical!("{err:?}")
critical_timeline!(
timeline.tenant_shard_id,
timeline.timeline_id,
"{err:?}"
);
}
})?;

View File

@@ -40,7 +40,7 @@ use tracing::*;
use utils::bin_ser::{DeserializeError, SerializeError};
use utils::lsn::Lsn;
use utils::rate_limit::RateLimit;
use utils::{critical, failpoint_support};
use utils::{critical_timeline, failpoint_support};
use wal_decoder::models::record::NeonWalRecord;
use wal_decoder::models::*;
@@ -418,18 +418,30 @@ impl WalIngest {
// as there has historically been cases where PostgreSQL has cleared spurious VM pages. See:
// https://github.com/neondatabase/neon/pull/10634.
let Some(vm_size) = get_relsize(modification, vm_rel, ctx).await? else {
critical!("clear_vm_bits for unknown VM relation {vm_rel}");
critical_timeline!(
modification.tline.tenant_shard_id,
modification.tline.timeline_id,
"clear_vm_bits for unknown VM relation {vm_rel}"
);
return Ok(());
};
if let Some(blknum) = new_vm_blk {
if blknum >= vm_size {
critical!("new_vm_blk {blknum} not in {vm_rel} of size {vm_size}");
critical_timeline!(
modification.tline.tenant_shard_id,
modification.tline.timeline_id,
"new_vm_blk {blknum} not in {vm_rel} of size {vm_size}"
);
new_vm_blk = None;
}
}
if let Some(blknum) = old_vm_blk {
if blknum >= vm_size {
critical!("old_vm_blk {blknum} not in {vm_rel} of size {vm_size}");
critical_timeline!(
modification.tline.tenant_shard_id,
modification.tline.timeline_id,
"old_vm_blk {blknum} not in {vm_rel} of size {vm_size}"
);
old_vm_blk = None;
}
}

2
pgxn/neon/communicator/.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
# generated file (with cbindgen, see build.rs)
communicator_bindings.h

View File

@@ -1,9 +1,11 @@
[package]
name = "communicator"
version = "0.1.0"
edition = "2024"
license.workspace = true
edition.workspace = true
[features]
# 'testing' feature is currently unused in the communicator, but we accept it for convenience of
# calling build scripts, so that you can pass the same feature to all packages.
testing = []
[lib]
@@ -35,6 +37,7 @@ pageserver_page_api.workspace = true
neon-shmem.workspace = true
utils.workspace = true
workspace_hack = { version = "0.1", path = "../../../workspace_hack" }
[build-dependencies]
cbindgen.workspace = true

View File

@@ -3,20 +3,18 @@ use std::env;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
cbindgen::generate(crate_dir).map_or_else(
|error| match error {
cbindgen::Error::ParseSyntaxError { .. } => {
// This means there was a syntax error in the Rust sources. Don't panic, because
// we want the build to continue and the Rust compiler to hit the error. The
// Rust compiler produces a better error message than cbindgen.
eprintln!("Generating C bindings failed because of a Rust syntax error");
}
e => panic!("Unable to generate C bindings: {e:?}"),
},
|bindings| {
match cbindgen::generate(crate_dir) {
Ok(bindings) => {
bindings.write_to_file("communicator_bindings.h");
},
);
}
Err(cbindgen::Error::ParseSyntaxError { .. }) => {
// This means there was a syntax error in the Rust sources. Don't panic, because
// we want the build to continue and the Rust compiler to hit the error. The
// Rust compiler produces a better error message than cbindgen.
eprintln!("Generating C bindings failed because of a Rust syntax error");
}
Err(err) => panic!("Unable to generate C bindings: {err:?}"),
};
Ok(())
}

View File

@@ -45,6 +45,9 @@
#include "storage/ipc.h"
#endif
/* the rust bindings, generated by cbindgen */
#include "communicator/communicator_bindings.h"
PG_MODULE_MAGIC;
void _PG_init(void);
@@ -93,6 +96,14 @@ static const struct config_enum_entry running_xacts_overflow_policies[] = {
{NULL, 0, false}
};
static const struct config_enum_entry debug_compare_local_modes[] = {
{"none", DEBUG_COMPARE_LOCAL_NONE, false},
{"prefetch", DEBUG_COMPARE_LOCAL_PREFETCH, false},
{"lfc", DEBUG_COMPARE_LOCAL_LFC, false},
{"all", DEBUG_COMPARE_LOCAL_ALL, false},
{NULL, 0, false}
};
/*
* XXX: These private to procarray.c, but we need them here.
*/
@@ -544,6 +555,16 @@ _PG_init(void)
GUC_UNIT_KB,
NULL, NULL, NULL);
DefineCustomEnumVariable(
"neon.debug_compare_local",
"Debug mode for compaing content of pages in prefetch ring/LFC/PS and local disk",
NULL,
&debug_compare_local,
DEBUG_COMPARE_LOCAL_NONE,
debug_compare_local_modes,
PGC_POSTMASTER,
0,
NULL, NULL, NULL);
/*
* Important: This must happen after other parts of the extension are
* loaded, otherwise any settings to GUCs that were set before the

View File

@@ -98,12 +98,14 @@ typedef struct
typedef struct DdlHashTable
{
struct DdlHashTable *prev_table;
size_t subtrans_level;
HTAB *db_table;
HTAB *role_table;
} DdlHashTable;
static DdlHashTable RootTable;
static DdlHashTable *CurrentDdlTable = &RootTable;
static int SubtransLevel; /* current nesting level of subtransactions */
static void
PushKeyValue(JsonbParseState **state, char *key, char *value)
@@ -332,9 +334,25 @@ SendDeltasToControlPlane()
}
}
static void
InitCurrentDdlTableIfNeeded()
{
/* Lazy construction of DdlHashTable chain */
if (SubtransLevel > CurrentDdlTable->subtrans_level)
{
DdlHashTable *new_table = MemoryContextAlloc(CurTransactionContext, sizeof(DdlHashTable));
new_table->prev_table = CurrentDdlTable;
new_table->subtrans_level = SubtransLevel;
new_table->role_table = NULL;
new_table->db_table = NULL;
CurrentDdlTable = new_table;
}
}
static void
InitDbTableIfNeeded()
{
InitCurrentDdlTableIfNeeded();
if (!CurrentDdlTable->db_table)
{
HASHCTL db_ctl = {};
@@ -353,6 +371,7 @@ InitDbTableIfNeeded()
static void
InitRoleTableIfNeeded()
{
InitCurrentDdlTableIfNeeded();
if (!CurrentDdlTable->role_table)
{
HASHCTL role_ctl = {};
@@ -371,19 +390,21 @@ InitRoleTableIfNeeded()
static void
PushTable()
{
DdlHashTable *new_table = MemoryContextAlloc(CurTransactionContext, sizeof(DdlHashTable));
new_table->prev_table = CurrentDdlTable;
new_table->role_table = NULL;
new_table->db_table = NULL;
CurrentDdlTable = new_table;
SubtransLevel += 1;
}
static void
MergeTable()
{
DdlHashTable *old_table = CurrentDdlTable;
DdlHashTable *old_table;
Assert(SubtransLevel >= CurrentDdlTable->subtrans_level);
if (--SubtransLevel >= CurrentDdlTable->subtrans_level)
{
return;
}
old_table = CurrentDdlTable;
CurrentDdlTable = old_table->prev_table;
if (old_table->db_table)
@@ -476,11 +497,15 @@ MergeTable()
static void
PopTable()
{
/*
* Current table gets freed because it is allocated in aborted
* subtransaction's memory context.
*/
CurrentDdlTable = CurrentDdlTable->prev_table;
Assert(SubtransLevel >= CurrentDdlTable->subtrans_level);
if (--SubtransLevel < CurrentDdlTable->subtrans_level)
{
/*
* Current table gets freed because it is allocated in aborted
* subtransaction's memory context.
*/
CurrentDdlTable = CurrentDdlTable->prev_table;
}
}
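/*
 * Illustrative walk-through (sketch, not part of the patch): a subtransaction
 * that performs no role/database DDL only bumps SubtransLevel in PushTable()
 * and decrements it again in MergeTable()/PopTable(); no DdlHashTable is
 * allocated for it. InitCurrentDdlTableIfNeeded() chains a new table in
 * lazily the first time DDL is recorded at a level deeper than
 * CurrentDdlTable->subtrans_level, and that table is merged or discarded
 * only once SubtransLevel drops below its subtrans_level again.
 */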
static void

View File

@@ -177,6 +177,22 @@ extern StringInfoData nm_pack_request(NeonRequest *msg);
extern NeonResponse *nm_unpack_response(StringInfo s);
extern char *nm_to_string(NeonMessage *msg);
/*
* If debug_compare_local>DEBUG_COMPARE_LOCAL_NONE, we pass through all the SMGR API
* calls to md.c, and *also* do the calls to the Page Server. On every
* read, compare the versions we read from local disk and Page Server,
* and Assert that they are identical.
*/
typedef enum
{
DEBUG_COMPARE_LOCAL_NONE, /* normal mode - pages are stored locally only for unlogged relations */
DEBUG_COMPARE_LOCAL_PREFETCH, /* if page is found in prefetch ring, then compare it with local and return */
DEBUG_COMPARE_LOCAL_LFC, /* if page is found in LFC or prefetch ring, then compare it with local and return */
DEBUG_COMPARE_LOCAL_ALL /* always fetch page from PS and compare it with local */
} DebugCompareLocalMode;
extern int debug_compare_local;
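/*
 * Example (sketch): the mode is selected with the "neon.debug_compare_local"
 * GUC defined in neon.c, e.g. in postgresql.conf:
 *
 *   neon.debug_compare_local = 'lfc'
 *
 * The GUC is PGC_POSTMASTER, so it only takes effect at server start.
 */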
/*
* API
*/

View File

@@ -73,21 +73,15 @@
#include "access/xlogrecovery.h"
#endif
/*
* If DEBUG_COMPARE_LOCAL is defined, we pass through all the SMGR API
* calls to md.c, and *also* do the calls to the Page Server. On every
* read, compare the versions we read from local disk and Page Server,
* and Assert that they are identical.
*/
/* #define DEBUG_COMPARE_LOCAL */
#if PG_VERSION_NUM < 160000
typedef PGAlignedBlock PGIOAlignedBlock;
#endif
#ifdef DEBUG_COMPARE_LOCAL
#include "access/nbtree.h"
#include "storage/bufpage.h"
#include "access/xlog_internal.h"
static char *hexdump_page(char *page);
#endif
#define IS_LOCAL_REL(reln) (\
NInfoGetDbOid(InfoFromSMgrRel(reln)) != 0 && \
@@ -105,6 +99,8 @@ typedef enum
UNLOGGED_BUILD_NOT_PERMANENT
} UnloggedBuildPhase;
int debug_compare_local;
static NRelFileInfo unlogged_build_rel_info;
static UnloggedBuildPhase unlogged_build_phase = UNLOGGED_BUILD_NOT_IN_PROGRESS;
@@ -475,9 +471,10 @@ neon_init(void)
old_redo_read_buffer_filter = redo_read_buffer_filter;
redo_read_buffer_filter = neon_redo_read_buffer_filter;
#ifdef DEBUG_COMPARE_LOCAL
mdinit();
#endif
if (debug_compare_local)
{
mdinit();
}
}
/*
@@ -805,13 +802,16 @@ neon_create(SMgrRelation reln, ForkNumber forkNum, bool isRedo)
case RELPERSISTENCE_TEMP:
case RELPERSISTENCE_UNLOGGED:
#ifdef DEBUG_COMPARE_LOCAL
mdcreate(reln, forkNum, forkNum == INIT_FORKNUM || isRedo);
if (forkNum == MAIN_FORKNUM)
mdcreate(reln, INIT_FORKNUM, true);
#else
mdcreate(reln, forkNum, isRedo);
#endif
if (debug_compare_local)
{
mdcreate(reln, forkNum, forkNum == INIT_FORKNUM || isRedo);
if (forkNum == MAIN_FORKNUM)
mdcreate(reln, INIT_FORKNUM, true);
}
else
{
mdcreate(reln, forkNum, isRedo);
}
return;
default:
@@ -863,10 +863,11 @@ neon_create(SMgrRelation reln, ForkNumber forkNum, bool isRedo)
set_cached_relsize(InfoFromSMgrRel(reln), forkNum, 0);
}
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdcreate(reln, forkNum, isRedo);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdcreate(reln, forkNum, isRedo);
}
}
/*
@@ -892,7 +893,7 @@ neon_unlink(NRelFileInfoBackend rinfo, ForkNumber forkNum, bool isRedo)
{
/*
* Might or might not exist locally, depending on whether it's an unlogged
* or permanent relation (or if DEBUG_COMPARE_LOCAL is set). Try to
* or permanent relation (or if debug_compare_local is set). Try to
* unlink, it won't do any harm if the file doesn't exist.
*/
mdunlink(rinfo, forkNum, isRedo);
@@ -995,16 +996,23 @@ neon_extend(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno,
{
// FIXME: this can pass lsn == invalid. Is that ok?
communicator_new_rel_extend(InfoFromSMgrRel(reln), forkNum, blkno, (const void *) buffer, lsn);
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdextend(reln, forkNum, blkno, buffer, skipFsync);
}
}
else
{
set_cached_relsize(InfoFromSMgrRel(reln), forkNum, blkno + 1);
lfc_write(InfoFromSMgrRel(reln), forkNum, blkno, buffer);
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdextend(reln, forkNum, blkno, buffer, skipFsync);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdextend(reln, forkNum, blkno, buffer, skipFsync);
}
/*
* smgr_extend is often called with an all-zeroes page, so
@@ -1081,10 +1089,11 @@ neon_zeroextend(SMgrRelation reln, ForkNumber forkNum, BlockNumber start_block,
relpath(reln->smgr_rlocator, forkNum),
InvalidBlockNumber)));
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdzeroextend(reln, forkNum, blocknum, nblocks, skipFsync);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdzeroextend(reln, forkNum, blocknum, nblocks, skipFsync);
}
/* Don't log any pages if we're not allowed to do so. */
if (!XLogInsertAllowed())
@@ -1319,10 +1328,11 @@ neon_writeback(SMgrRelation reln, ForkNumber forknum,
if (!neon_enable_new_communicator)
communicator_prefetch_pump_state();
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdwriteback(reln, forknum, blocknum, nblocks);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdwriteback(reln, forknum, blocknum, nblocks);
}
}
/*
@@ -1343,7 +1353,6 @@ neon_read_at_lsn(NRelFileInfo rinfo, ForkNumber forkNum, BlockNumber blkno,
communicator_read_at_lsnv(rinfo, forkNum, blkno, &request_lsns, &buffer, 1, NULL);
}
#ifdef DEBUG_COMPARE_LOCAL
static void
compare_with_local(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void* buffer, XLogRecPtr request_lsn)
{
@@ -1425,7 +1434,6 @@ compare_with_local(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, voi
}
}
}
#endif
#if PG_MAJORVERSION_NUM < 17
@@ -1485,22 +1493,28 @@ neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void *buffer
if (communicator_prefetch_lookupv(InfoFromSMgrRel(reln), forkNum, blkno, &request_lsns, 1, &bufferp, &present))
{
/* Prefetch hit */
#ifdef DEBUG_COMPARE_LOCAL
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
#else
return;
#endif
if (debug_compare_local >= DEBUG_COMPARE_LOCAL_PREFETCH)
{
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
}
if (debug_compare_local <= DEBUG_COMPARE_LOCAL_PREFETCH)
{
return;
}
}
/* Try to read from local file cache */
if (lfc_read(InfoFromSMgrRel(reln), forkNum, blkno, buffer))
{
MyNeonCounters->file_cache_hits_total++;
#ifdef DEBUG_COMPARE_LOCAL
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
#else
return;
#endif
if (debug_compare_local >= DEBUG_COMPARE_LOCAL_LFC)
{
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
}
if (debug_compare_local <= DEBUG_COMPARE_LOCAL_LFC)
{
return;
}
}
neon_read_at_lsn(InfoFromSMgrRel(reln), forkNum, blkno, request_lsns, buffer);
@@ -1511,15 +1525,15 @@ neon_read(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void *buffer
communicator_prefetch_pump_state();
}
#ifdef DEBUG_COMPARE_LOCAL
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
#endif
if (debug_compare_local)
{
compare_with_local(reln, forkNum, blkno, buffer, request_lsns.request_lsn);
}
}
#endif /* PG_MAJORVERSION_NUM <= 16 */
#if PG_MAJORVERSION_NUM >= 17
#ifdef DEBUG_COMPARE_LOCAL
static void
compare_with_localv(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, void** buffers, BlockNumber nblocks, neon_request_lsns* request_lsns, bits8* read_pages)
{
@@ -1534,7 +1548,6 @@ compare_with_localv(SMgrRelation reln, ForkNumber forkNum, BlockNumber blkno, vo
}
}
}
#endif
static void
@@ -1593,13 +1606,18 @@ neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
blocknum, request_lsns, nblocks,
buffers, read_pages);
#ifdef DEBUG_COMPARE_LOCAL
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
memset(read_pages, 0, sizeof(read_pages));
#else
if (prefetch_result == nblocks)
if (debug_compare_local >= DEBUG_COMPARE_LOCAL_PREFETCH)
{
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
}
if (debug_compare_local <= DEBUG_COMPARE_LOCAL_PREFETCH && prefetch_result == nblocks)
{
return;
#endif
}
if (debug_compare_local > DEBUG_COMPARE_LOCAL_PREFETCH)
{
memset(read_pages, 0, sizeof(read_pages));
}
/* Try to read from local file cache */
lfc_result = lfc_readv_select(InfoFromSMgrRel(reln), forknum, blocknum, buffers,
@@ -1608,14 +1626,19 @@ neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
if (lfc_result > 0)
MyNeonCounters->file_cache_hits_total += lfc_result;
#ifdef DEBUG_COMPARE_LOCAL
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
memset(read_pages, 0, sizeof(read_pages));
#else
/* Read all blocks from LFC, so we're done */
if (prefetch_result + lfc_result == nblocks)
if (debug_compare_local >= DEBUG_COMPARE_LOCAL_LFC)
{
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
}
if (debug_compare_local <= DEBUG_COMPARE_LOCAL_LFC && prefetch_result + lfc_result == nblocks)
{
/* Read all blocks from LFC, so we're done */
return;
#endif
}
if (debug_compare_local > DEBUG_COMPARE_LOCAL_LFC)
{
memset(read_pages, 0, sizeof(read_pages));
}
communicator_read_at_lsnv(InfoFromSMgrRel(reln), forknum, blocknum, request_lsns,
buffers, nblocks, read_pages);
@@ -1626,14 +1649,14 @@ neon_readv(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum,
communicator_prefetch_pump_state();
}
#ifdef DEBUG_COMPARE_LOCAL
memset(read_pages, 0xFF, sizeof(read_pages));
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
#endif
if (debug_compare_local)
{
memset(read_pages, 0xFF, sizeof(read_pages));
compare_with_localv(reln, forknum, blocknum, buffers, nblocks, request_lsns, read_pages);
}
}
#endif
#ifdef DEBUG_COMPARE_LOCAL
static char *
hexdump_page(char *page)
{
@@ -1652,7 +1675,6 @@ hexdump_page(char *page)
return result.data;
}
#endif
#if PG_MAJORVERSION_NUM < 17
/*
@@ -1674,12 +1696,8 @@ neon_write(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, const vo
switch (reln->smgr_relpersistence)
{
case 0:
#ifndef DEBUG_COMPARE_LOCAL
/* This is a bit tricky. Check if the relation exists locally */
if (mdexists(reln, forknum))
#else
if (mdexists(reln, INIT_FORKNUM))
#endif
if (mdexists(reln, debug_compare_local ? INIT_FORKNUM : forknum))
{
/* It exists locally. Guess it's unlogged then. */
#if PG_MAJORVERSION_NUM >= 17
@@ -1741,14 +1759,17 @@ neon_write(SMgrRelation reln, ForkNumber forknum, BlockNumber blocknum, const vo
communicator_prefetch_pump_state();
}
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
{
#if PG_MAJORVERSION_NUM >= 17
mdwritev(reln, forknum, blocknum, &buffer, 1, skipFsync);
mdwritev(reln, forknum, blocknum, &buffer, 1, skipFsync);
#else
mdwrite(reln, forknum, blocknum, buffer, skipFsync);
mdwrite(reln, forknum, blocknum, buffer, skipFsync);
#endif
#endif
}
}
}
#endif
@@ -1762,12 +1783,8 @@ neon_writev(SMgrRelation reln, ForkNumber forknum, BlockNumber blkno,
switch (reln->smgr_relpersistence)
{
case 0:
#ifndef DEBUG_COMPARE_LOCAL
/* This is a bit tricky. Check if the relation exists locally */
if (mdexists(reln, forknum))
#else
if (mdexists(reln, INIT_FORKNUM))
#endif
if (mdexists(reln, debug_compare_local ? INIT_FORKNUM : forknum))
{
/* It exists locally. Guess it's unlogged then. */
mdwritev(reln, forknum, blkno, buffers, nblocks, skipFsync);
@@ -1817,10 +1834,11 @@ neon_writev(SMgrRelation reln, ForkNumber forknum, BlockNumber blkno,
communicator_prefetch_pump_state();
}
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdwritev(reln, forknum, blkno, buffers, nblocks, skipFsync);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdwritev(reln, forknum, blkno, buffers, nblocks, skipFsync);
}
}
#endif
@@ -1980,10 +1998,11 @@ neon_truncate(SMgrRelation reln, ForkNumber forknum, BlockNumber old_blocks, Blo
neon_set_lwlsn_relation(lsn, InfoFromSMgrRel(reln), forknum);
}
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdtruncate(reln, forknum, old_blocks, nblocks);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdtruncate(reln, forknum, old_blocks, nblocks);
}
}
/*
@@ -2023,10 +2042,11 @@ neon_immedsync(SMgrRelation reln, ForkNumber forknum)
if (!neon_enable_new_communicator)
communicator_prefetch_pump_state();
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdimmedsync(reln, forknum);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdimmedsync(reln, forknum);
}
}
#if PG_MAJORVERSION_NUM >= 17
@@ -2053,10 +2073,11 @@ neon_registersync(SMgrRelation reln, ForkNumber forknum)
neon_log(SmgrTrace, "[NEON_SMGR] registersync noop");
#ifdef DEBUG_COMPARE_LOCAL
if (IS_LOCAL_REL(reln))
mdimmedsync(reln, forknum);
#endif
if (debug_compare_local)
{
if (IS_LOCAL_REL(reln))
mdimmedsync(reln, forknum);
}
}
#endif
@@ -2097,10 +2118,11 @@ neon_start_unlogged_build(SMgrRelation reln)
case RELPERSISTENCE_UNLOGGED:
unlogged_build_rel_info = InfoFromSMgrRel(reln);
unlogged_build_phase = UNLOGGED_BUILD_NOT_PERMANENT;
#ifdef DEBUG_COMPARE_LOCAL
if (!IsParallelWorker())
mdcreate(reln, INIT_FORKNUM, true);
#endif
if (debug_compare_local)
{
if (!IsParallelWorker())
mdcreate(reln, INIT_FORKNUM, true);
}
return;
default:
@@ -2128,11 +2150,7 @@ neon_start_unlogged_build(SMgrRelation reln)
*/
if (!IsParallelWorker())
{
#ifndef DEBUG_COMPARE_LOCAL
mdcreate(reln, MAIN_FORKNUM, false);
#else
mdcreate(reln, INIT_FORKNUM, true);
#endif
mdcreate(reln, debug_compare_local ? INIT_FORKNUM : MAIN_FORKNUM, false);
}
}
@@ -2230,14 +2248,14 @@ neon_end_unlogged_build(SMgrRelation reln)
}
mdclose(reln, forknum);
#ifndef DEBUG_COMPARE_LOCAL
/* use isRedo == true, so that we drop it immediately */
mdunlink(rinfob, forknum, true);
#endif
if (!debug_compare_local)
{
/* use isRedo == true, so that we drop it immediately */
mdunlink(rinfob, forknum, true);
}
}
#ifdef DEBUG_COMPARE_LOCAL
mdunlink(rinfob, INIT_FORKNUM, true);
#endif
if (debug_compare_local)
mdunlink(rinfob, INIT_FORKNUM, true);
}
NRelFileInfoInvalidate(unlogged_build_rel_info);
unlogged_build_phase = UNLOGGED_BUILD_NOT_IN_PROGRESS;

View File

@@ -138,3 +138,62 @@ Now from client you can start a new session:
```sh
PGSSLROOTCERT=./server.crt psql "postgresql://proxy:password@endpoint.local.neon.build:4432/postgres?sslmode=verify-full"
```
## Auth broker setup
Create a postgres instance:
```sh
docker run \
--detach \
--name proxy-postgres \
--env POSTGRES_HOST_AUTH_METHOD=trust \
--env POSTGRES_USER=authenticated \
--env POSTGRES_DB=database \
--publish 5432:5432 \
postgres:17-bookworm
```
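Optionally, verify that the container is ready to accept connections before continuing (`pg_isready` ships with the official Postgres image; the user and database names match the ones created above):
```sh
# Should report "accepting connections" once the instance is up
docker exec proxy-postgres pg_isready -U authenticated -d database
```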
Create a configuration file called `local_proxy.json` in the root of the repo (it is also used by the auth broker to validate JWTs):
```json
{
"jwks": [
{
"id": "1",
"role_names": ["authenticator", "authenticated", "anon"],
"jwks_url": "https://climbing-minnow-11.clerk.accounts.dev/.well-known/jwks.json",
"provider_name": "foo",
"jwt_audience": null
}
]
}
```
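The `jwks_url` above points at an example Clerk development instance; substitute your own provider's JWKS endpoint. As a quick sanity check, the URL should return a JSON document containing a `keys` array:
```sh
# Prints the beginning of the JWKS document; expect {"keys":[...]}
curl -s https://climbing-minnow-11.clerk.accounts.dev/.well-known/jwks.json | head -c 200
```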
Start the local proxy (the `--disable_pg_session_jwt` flag is only compiled in when the `testing` feature is enabled):
```sh
cargo run --bin local_proxy --features testing -- \
    --disable_pg_session_jwt true \
    --http 0.0.0.0:7432
```
Start the auth broker:
```sh
LOGFMT=text OTEL_SDK_DISABLED=true cargo run --bin proxy --features testing -- \
-c server.crt -k server.key \
--is-auth-broker true \
--wss 0.0.0.0:8080 \
--http 0.0.0.0:7002 \
--auth-backend local
```
Create a JWT in your auth provider (e.g. Clerk) and set it in the `NEON_JWT` environment variable.
```sh
export NEON_JWT="..."
```
Run a query against the auth broker:
```sh
curl -k "https://foo.local.neon.build:8080/sql" \
-H "Authorization: Bearer $NEON_JWT" \
-H "neon-connection-string: postgresql://authenticator@foo.local.neon.build/database" \
-d '{"query":"select 1","params":[]}'
```
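On success the broker returns a JSON result set. The exact metadata fields depend on the proxy version, but the rows should contain the selected value; for example (assuming `jq` is installed):
```sh
curl -sk "https://foo.local.neon.build:8080/sql" \
  -H "Authorization: Bearer $NEON_JWT" \
  -H "neon-connection-string: postgresql://authenticator@foo.local.neon.build/database" \
  -d '{"query":"select 1","params":[]}' | jq .rows
# Expected: a single row containing 1, e.g. [ { "?column?": 1 } ]
```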

View File

@@ -164,21 +164,20 @@ async fn authenticate(
})?
.map_err(ConsoleRedirectError::from)?;
if auth_config.ip_allowlist_check_enabled {
if let Some(allowed_ips) = &db_info.allowed_ips {
if !auth::check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips) {
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
}
if auth_config.ip_allowlist_check_enabled
&& let Some(allowed_ips) = &db_info.allowed_ips
&& !auth::check_peer_addr_is_in_list(&ctx.peer_addr(), allowed_ips)
{
return Err(auth::AuthError::ip_address_not_allowed(ctx.peer_addr()));
}
// Check if the access over the public internet is allowed, otherwise block. Note that
// the console redirect is not behind the VPC service endpoint, so we don't need to check
// the VPC endpoint ID.
if let Some(public_access_allowed) = db_info.public_access_allowed {
if !public_access_allowed {
return Err(auth::AuthError::NetworkNotAllowed);
}
if let Some(public_access_allowed) = db_info.public_access_allowed
&& !public_access_allowed
{
return Err(auth::AuthError::NetworkNotAllowed);
}
client.write_message(BeMessage::NoticeResponse("Connecting to database."));

View File

@@ -399,36 +399,36 @@ impl JwkCacheEntryLock {
tracing::debug!(?payload, "JWT signature valid with claims");
if let Some(aud) = expected_audience {
if payload.audience.0.iter().all(|s| s != aud) {
return Err(JwtError::InvalidClaims(
JwtClaimsError::InvalidJwtTokenAudience,
));
}
if let Some(aud) = expected_audience
&& payload.audience.0.iter().all(|s| s != aud)
{
return Err(JwtError::InvalidClaims(
JwtClaimsError::InvalidJwtTokenAudience,
));
}
let now = SystemTime::now();
if let Some(exp) = payload.expiration {
if now >= exp + CLOCK_SKEW_LEEWAY {
return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired(
exp.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
)));
}
if let Some(exp) = payload.expiration
&& now >= exp + CLOCK_SKEW_LEEWAY
{
return Err(JwtError::InvalidClaims(JwtClaimsError::JwtTokenHasExpired(
exp.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
)));
}
if let Some(nbf) = payload.not_before {
if nbf >= now + CLOCK_SKEW_LEEWAY {
return Err(JwtError::InvalidClaims(
JwtClaimsError::JwtTokenNotYetReadyToUse(
nbf.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
),
));
}
if let Some(nbf) = payload.not_before
&& nbf >= now + CLOCK_SKEW_LEEWAY
{
return Err(JwtError::InvalidClaims(
JwtClaimsError::JwtTokenNotYetReadyToUse(
nbf.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
),
));
}
Ok(ComputeCredentialKeys::JwtPayload(payloadb))

View File

@@ -171,7 +171,6 @@ impl ComputeUserInfo {
pub(crate) enum ComputeCredentialKeys {
AuthKeys(AuthKeys),
JwtPayload(Vec<u8>),
None,
}
impl TryFrom<ComputeUserInfoMaybeEndpoint> for ComputeUserInfo {
@@ -346,15 +345,13 @@ impl<'a> Backend<'a, ComputeUserInfoMaybeEndpoint> {
Err(e) => {
// The password could have been changed, so we invalidate the cache.
// We should only invalidate the cache if the TTL might have expired.
if e.is_password_failed() {
#[allow(irrefutable_let_patterns)]
if let ControlPlaneClient::ProxyV1(api) = &*api {
if let Some(ep) = &user_info.endpoint_id {
api.caches
.project_info
.maybe_invalidate_role_secret(ep, &user_info.user);
}
}
if e.is_password_failed()
&& let ControlPlaneClient::ProxyV1(api) = &*api
&& let Some(ep) = &user_info.endpoint_id
{
api.caches
.project_info
.maybe_invalidate_role_secret(ep, &user_info.user);
}
Err(e)

View File

@@ -1,43 +1,37 @@
use std::net::SocketAddr;
use std::pin::pin;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::{Context, bail, ensure};
use anyhow::bail;
use arc_swap::ArcSwapOption;
use camino::{Utf8Path, Utf8PathBuf};
use camino::Utf8PathBuf;
use clap::Parser;
use compute_api::spec::LocalProxySpec;
use futures::future::Either;
use thiserror::Error;
use tokio::net::TcpListener;
use tokio::sync::Notify;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use tracing::{debug, error, info};
use utils::sentry_init::init_sentry;
use utils::{pid_file, project_build_tag, project_git_version};
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::local::{JWKS_ROLE_MAP, LocalBackend};
use crate::auth::backend::local::LocalBackend;
use crate::auth::{self};
use crate::cancellation::CancellationHandler;
use crate::config::{
self, AuthenticationConfig, ComputeConfig, HttpConfig, ProxyConfig, RetryConfig,
refresh_config_loop,
};
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use crate::ext::TaskExt;
use crate::http::health_server::AppMetrics;
use crate::intern::RoleNameInt;
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig, RateBucketInfo};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::cancel_set::CancelSet;
use crate::serverless::{self, GlobalConnPoolOptions};
use crate::tls::client_config::compute_client_config_with_root_certs;
use crate::types::RoleName;
use crate::url::ApiUrl;
project_git_version!(GIT_VERSION);
@@ -82,6 +76,11 @@ struct LocalProxyCliArgs {
/// Path of the local proxy PID file
#[clap(long, default_value = "./local_proxy.pid")]
pid_path: Utf8PathBuf,
/// Disable pg_session_jwt extension installation
/// This is useful for testing the local proxy with vanilla postgres.
#[clap(long, default_value = "false")]
#[cfg(feature = "testing")]
disable_pg_session_jwt: bool,
}
#[derive(clap::Args, Clone, Copy, Debug)]
@@ -282,6 +281,8 @@ fn build_config(args: &LocalProxyCliArgs) -> anyhow::Result<&'static ProxyConfig
wake_compute_retry_config: RetryConfig::parse(RetryConfig::WAKE_COMPUTE_DEFAULT_VALUES)?,
connect_compute_locks,
connect_to_compute: compute_config,
#[cfg(feature = "testing")]
disable_pg_session_jwt: args.disable_pg_session_jwt,
})))
}
@@ -293,132 +294,3 @@ fn build_auth_backend(args: &LocalProxyCliArgs) -> &'static auth::Backend<'stati
Box::leak(Box::new(auth_backend))
}
#[derive(Error, Debug)]
enum RefreshConfigError {
#[error(transparent)]
Read(#[from] std::io::Error),
#[error(transparent)]
Parse(#[from] serde_json::Error),
#[error(transparent)]
Validate(anyhow::Error),
#[error(transparent)]
Tls(anyhow::Error),
}
async fn refresh_config_loop(config: &ProxyConfig, path: Utf8PathBuf, rx: Arc<Notify>) {
let mut init = true;
loop {
rx.notified().await;
match refresh_config_inner(config, &path).await {
Ok(()) => {}
// don't log for file not found errors if this is the first time we are checking
// for computes that don't use local_proxy, this is not an error.
Err(RefreshConfigError::Read(e))
if init && e.kind() == std::io::ErrorKind::NotFound =>
{
debug!(error=?e, ?path, "could not read config file");
}
Err(RefreshConfigError::Tls(e)) => {
error!(error=?e, ?path, "could not read TLS certificates");
}
Err(e) => {
error!(error=?e, ?path, "could not read config file");
}
}
init = false;
}
}
async fn refresh_config_inner(
config: &ProxyConfig,
path: &Utf8Path,
) -> Result<(), RefreshConfigError> {
let bytes = tokio::fs::read(&path).await?;
let data: LocalProxySpec = serde_json::from_slice(&bytes)?;
let mut jwks_set = vec![];
fn parse_jwks_settings(jwks: compute_api::spec::JwksSettings) -> anyhow::Result<JwksSettings> {
let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?;
ensure!(
jwks_url.has_authority()
&& (jwks_url.scheme() == "http" || jwks_url.scheme() == "https"),
"Invalid JWKS url. Must be HTTP",
);
ensure!(
jwks_url.host().is_some_and(|h| h != url::Host::Domain("")),
"Invalid JWKS url. No domain listed",
);
// clear username, password and ports
jwks_url
.set_username("")
.expect("url can be a base and has a valid host and is not a file. should not error");
jwks_url
.set_password(None)
.expect("url can be a base and has a valid host and is not a file. should not error");
// local testing is hard if we need to have a specific restricted port
if cfg!(not(feature = "testing")) {
jwks_url.set_port(None).expect(
"url can be a base and has a valid host and is not a file. should not error",
);
}
// clear query params
jwks_url.set_fragment(None);
jwks_url.query_pairs_mut().clear().finish();
if jwks_url.scheme() != "https" {
// local testing is hard if we need to set up https support.
if cfg!(not(feature = "testing")) {
jwks_url
.set_scheme("https")
.expect("should not error to set the scheme to https if it was http");
} else {
warn!(scheme = jwks_url.scheme(), "JWKS url is not HTTPS");
}
}
Ok(JwksSettings {
id: jwks.id,
jwks_url,
_provider_name: jwks.provider_name,
jwt_audience: jwks.jwt_audience,
role_names: jwks
.role_names
.into_iter()
.map(RoleName::from)
.map(|s| RoleNameInt::from(&s))
.collect(),
})
}
for jwks in data.jwks.into_iter().flatten() {
jwks_set.push(parse_jwks_settings(jwks).map_err(RefreshConfigError::Validate)?);
}
info!("successfully loaded new config");
JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks: jwks_set })));
if let Some(tls_config) = data.tls {
let tls_config = tokio::task::spawn_blocking(move || {
crate::tls::server_config::configure_tls(
tls_config.key_path.as_ref(),
tls_config.cert_path.as_ref(),
None,
false,
)
})
.await
.propagate_task_panic()
.map_err(RefreshConfigError::Tls)?;
config.tls_config.store(Some(Arc::new(tls_config)));
}
Ok(())
}

View File

@@ -4,6 +4,7 @@
//! This allows connecting to pods/services running in the same Kubernetes cluster from
//! the outside. Similar to an ingress controller for HTTPS.
use std::io;
use std::net::SocketAddr;
use std::path::Path;
use std::sync::Arc;
@@ -229,7 +230,6 @@ pub(super) async fn task_main(
.set_nodelay(true)
.context("failed to set socket option")?;
info!(%peer_addr, "serving");
let ctx = RequestContext::new(
session_id,
ConnectionInfo {
@@ -241,6 +241,14 @@ pub(super) async fn task_main(
handle_client(ctx, dest_suffix, tls_config, compute_tls_config, socket).await
}
.unwrap_or_else(|e| {
if let Some(FirstMessage(io_error)) = e.downcast_ref() {
// this is noisy. if we get EOF on the very first message that's likely
// just NLB doing a healthcheck.
if io_error.kind() == io::ErrorKind::UnexpectedEof {
return;
}
}
// Acknowledge that the task has finished with an error.
error!("per-client task finished with an error: {e:#}");
})
@@ -257,12 +265,19 @@ pub(super) async fn task_main(
Ok(())
}
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
struct FirstMessage(io::Error);
async fn ssl_handshake<S: AsyncRead + AsyncWrite + Unpin>(
ctx: &RequestContext,
raw_stream: S,
tls_config: Arc<rustls::ServerConfig>,
) -> anyhow::Result<TlsStream<S>> {
let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream)).await?;
let (mut stream, msg) = PqStream::parse_startup(Stream::from_raw(raw_stream))
.await
.map_err(FirstMessage)?;
match msg {
FeStartupPacket::SslRequest { direct: None } => {
let raw = stream.accept_tls().await?;

View File

@@ -10,11 +10,15 @@ use std::time::Duration;
use anyhow::Context;
use anyhow::{bail, ensure};
use arc_swap::ArcSwapOption;
#[cfg(any(test, feature = "testing"))]
use camino::Utf8PathBuf;
use futures::future::Either;
use itertools::{Itertools, Position};
use rand::{Rng, thread_rng};
use remote_storage::RemoteStorageConfig;
use tokio::net::TcpListener;
#[cfg(any(test, feature = "testing"))]
use tokio::sync::Notify;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, error, info, warn};
@@ -22,9 +26,13 @@ use utils::sentry_init::init_sentry;
use utils::{project_build_tag, project_git_version};
use crate::auth::backend::jwt::JwkCache;
#[cfg(any(test, feature = "testing"))]
use crate::auth::backend::local::LocalBackend;
use crate::auth::backend::{ConsoleRedirectBackend, MaybeOwned};
use crate::batch::BatchQueue;
use crate::cancellation::{CancellationHandler, CancellationProcessor};
#[cfg(any(test, feature = "testing"))]
use crate::config::refresh_config_loop;
use crate::config::{
self, AuthenticationConfig, CacheOptions, ComputeConfig, HttpConfig, ProjectInfoCacheOptions,
ProxyConfig, ProxyProtocolV2, remote_storage_from_toml,
@@ -60,6 +68,9 @@ enum AuthBackendType {
#[cfg(any(test, feature = "testing"))]
Postgres,
#[cfg(any(test, feature = "testing"))]
Local,
}
/// Neon proxy/router
@@ -74,6 +85,10 @@ struct ProxyCliArgs {
proxy: SocketAddr,
#[clap(value_enum, long, default_value_t = AuthBackendType::ConsoleRedirect)]
auth_backend: AuthBackendType,
/// Path of the local proxy config file (used for local-file auth backend)
#[clap(long, default_value = "./local_proxy.json")]
#[cfg(any(test, feature = "testing"))]
config_path: Utf8PathBuf,
/// listen for management callback connection on ip:port
#[clap(short, long, default_value = "127.0.0.1:7000")]
mgmt: SocketAddr,
@@ -226,6 +241,14 @@ struct ProxyCliArgs {
#[clap(flatten)]
pg_sni_router: PgSniRouterArgs,
/// if this is not local proxy, this toggles whether we accept Postgres REST requests
#[clap(long, default_value_t = false, value_parser = clap::builder::BoolishValueParser::new(), action = clap::ArgAction::Set)]
is_rest_broker: bool,
/// cache for `db_schema_cache` introspection (use `size=0` to disable)
#[clap(long, default_value = "size=1000,ttl=1h")]
db_schema_cache: String,
}
#[derive(clap::Args, Clone, Copy, Debug)]
@@ -386,6 +409,8 @@ pub async fn run() -> anyhow::Result<()> {
64,
));
#[cfg(any(test, feature = "testing"))]
let refresh_config_notify = Arc::new(Notify::new());
// client facing tasks. these will exit on error or on cancellation
// cancellation returns Ok(())
let mut client_tasks = JoinSet::new();
@@ -412,6 +437,17 @@ pub async fn run() -> anyhow::Result<()> {
endpoint_rate_limiter.clone(),
));
}
// if auth backend is local, we need to load the config file
#[cfg(any(test, feature = "testing"))]
if let auth::Backend::Local(_) = &auth_backend {
refresh_config_notify.notify_one();
tokio::spawn(refresh_config_loop(
config,
args.config_path,
refresh_config_notify.clone(),
));
}
}
Either::Right(auth_backend) => {
if let Some(proxy_listener) = proxy_listener {
@@ -462,7 +498,13 @@ pub async fn run() -> anyhow::Result<()> {
// maintenance tasks. these never return unless there's an error
let mut maintenance_tasks = JoinSet::new();
maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), || {}));
maintenance_tasks.spawn(crate::signals::handle(cancellation_token.clone(), {
move || {
#[cfg(any(test, feature = "testing"))]
refresh_config_notify.notify_one();
}
}));
maintenance_tasks.spawn(http::health_server::task_main(
http_listener,
AppMetrics {
@@ -478,54 +520,51 @@ pub async fn run() -> anyhow::Result<()> {
maintenance_tasks.spawn(usage_metrics::task_main(metrics_config));
}
#[cfg_attr(not(any(test, feature = "testing")), expect(irrefutable_let_patterns))]
if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend {
if let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api {
if let Some(client) = redis_client {
// project info cache and invalidation of that cache.
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
if let Either::Left(auth::Backend::ControlPlane(api, ())) = &auth_backend
&& let crate::control_plane::client::ControlPlaneClient::ProxyV1(api) = &**api
&& let Some(client) = redis_client
{
// project info cache and invalidation of that cache.
let cache = api.caches.project_info.clone();
maintenance_tasks.spawn(notifications::task_main(client.clone(), cache.clone()));
maintenance_tasks.spawn(async move { cache.clone().gc_worker().await });
// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
// This prevents immediate exit and pod restart,
// which can cause hammering of the redis in case of connection issues.
// cancellation key management
let mut redis_kv_client = RedisKVClient::new(client.clone());
for attempt in (0..3).with_position() {
match redis_kv_client.try_connect().await {
Ok(()) => {
info!("Connected to Redis KV client");
cancellation_handler.init_tx(BatchQueue::new(CancellationProcessor {
client: redis_kv_client,
batch_size: args.cancellation_batch_size,
}));
// Try to connect to Redis 3 times with 1 + (0..0.1) second interval.
// This prevents immediate exit and pod restart,
// which can cause hammering of the redis in case of connection issues.
// cancellation key management
let mut redis_kv_client = RedisKVClient::new(client.clone());
for attempt in (0..3).with_position() {
match redis_kv_client.try_connect().await {
Ok(()) => {
info!("Connected to Redis KV client");
cancellation_handler.init_tx(BatchQueue::new(CancellationProcessor {
client: redis_kv_client,
batch_size: args.cancellation_batch_size,
}));
break;
}
Err(e) => {
error!("Failed to connect to Redis KV client: {e}");
if matches!(attempt, Position::Last(_)) {
bail!(
"Failed to connect to Redis KV client after {} attempts",
attempt.into_inner()
);
}
let jitter = thread_rng().gen_range(0..100);
tokio::time::sleep(Duration::from_millis(1000 + jitter)).await;
}
}
break;
}
Err(e) => {
error!("Failed to connect to Redis KV client: {e}");
if matches!(attempt, Position::Last(_)) {
bail!(
"Failed to connect to Redis KV client after {} attempts",
attempt.into_inner()
);
}
let jitter = thread_rng().gen_range(0..100);
tokio::time::sleep(Duration::from_millis(1000 + jitter)).await;
}
// listen for notifications of new projects/endpoints/branches
let cache = api.caches.endpoints_cache.clone();
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(client, cancellation_token.clone()).await }
.instrument(span),
);
}
}
// listen for notifications of new projects/endpoints/branches
let cache = api.caches.endpoints_cache.clone();
let span = tracing::info_span!("endpoints_cache");
maintenance_tasks.spawn(
async move { cache.do_read(client, cancellation_token.clone()).await }.instrument(span),
);
}
let maintenance = loop {
@@ -653,6 +692,8 @@ fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
wake_compute_retry_config: config::RetryConfig::parse(&args.wake_compute_retry)?,
connect_compute_locks,
connect_to_compute: compute_config,
#[cfg(feature = "testing")]
disable_pg_session_jwt: false,
};
let config = Box::leak(Box::new(config));
@@ -806,6 +847,19 @@ fn build_auth_backend(
Ok(Either::Right(config))
}
#[cfg(any(test, feature = "testing"))]
AuthBackendType::Local => {
let postgres: SocketAddr = "127.0.0.1:7432".parse()?;
let compute_ctl: ApiUrl = "http://127.0.0.1:3081/".parse()?;
let auth_backend = crate::auth::Backend::Local(
crate::auth::backend::MaybeOwned::Owned(LocalBackend::new(postgres, compute_ctl)),
);
let config = Box::leak(Box::new(auth_backend));
Ok(Either::Left(config))
}
}
}

View File

@@ -64,6 +64,13 @@ impl Pipeline {
let responses = self.replies;
let batch_size = self.inner.len();
if !client.credentials_refreshed() {
tracing::debug!(
"Redis credentials are not refreshed. Sleeping for 5 seconds before retrying..."
);
tokio::time::sleep(Duration::from_secs(5)).await;
}
match client.query(&self.inner).await {
// for each reply, we expect that many values.
Ok(Value::Array(values)) if values.len() == responses => {
@@ -127,6 +134,14 @@ impl QueueProcessing for CancellationProcessor {
}
async fn apply(&mut self, batch: Vec<Self::Req>) -> Vec<Self::Res> {
if !self.client.credentials_refreshed() {
// this will cause a timeout for cancellation operations
tracing::debug!(
"Redis credentials are not refreshed. Sleeping for 5 seconds before retrying..."
);
tokio::time::sleep(Duration::from_secs(5)).await;
}
let mut pipeline = Pipeline::with_capacity(batch.len());
let batch_size = batch.len();

View File

@@ -165,7 +165,7 @@ impl AuthInfo {
ComputeCredentialKeys::AuthKeys(AuthKeys::ScramSha256(auth_keys)) => {
Some(Auth::Scram(Box::new(auth_keys)))
}
ComputeCredentialKeys::JwtPayload(_) | ComputeCredentialKeys::None => None,
ComputeCredentialKeys::JwtPayload(_) => None,
},
server_params: StartupMessageParams::default(),
skip_db_user: false,

View File

@@ -4,17 +4,26 @@ use std::time::Duration;
use anyhow::{Context, Ok, bail, ensure};
use arc_swap::ArcSwapOption;
use camino::{Utf8Path, Utf8PathBuf};
use clap::ValueEnum;
use compute_api::spec::LocalProxySpec;
use remote_storage::RemoteStorageConfig;
use thiserror::Error;
use tokio::sync::Notify;
use tracing::{debug, error, info, warn};
use crate::auth::backend::jwt::JwkCache;
use crate::auth::backend::local::JWKS_ROLE_MAP;
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::messages::{EndpointJwksResponse, JwksSettings};
use crate::ext::TaskExt;
use crate::intern::RoleNameInt;
use crate::rate_limiter::{RateBucketInfo, RateLimitAlgorithm, RateLimiterConfig};
use crate::scram::threadpool::ThreadPool;
use crate::serverless::GlobalConnPoolOptions;
use crate::serverless::cancel_set::CancelSet;
pub use crate::tls::server_config::{TlsConfig, configure_tls};
use crate::types::Host;
use crate::types::{Host, RoleName};
pub struct ProxyConfig {
pub tls_config: ArcSwapOption<TlsConfig>,
@@ -26,6 +35,8 @@ pub struct ProxyConfig {
pub wake_compute_retry_config: RetryConfig,
pub connect_compute_locks: ApiLocks<Host>,
pub connect_to_compute: ComputeConfig,
#[cfg(feature = "testing")]
pub disable_pg_session_jwt: bool,
}
pub struct ComputeConfig {
@@ -409,6 +420,135 @@ impl FromStr for ConcurrencyLockOptions {
}
}
#[derive(Error, Debug)]
pub(crate) enum RefreshConfigError {
#[error(transparent)]
Read(#[from] std::io::Error),
#[error(transparent)]
Parse(#[from] serde_json::Error),
#[error(transparent)]
Validate(anyhow::Error),
#[error(transparent)]
Tls(anyhow::Error),
}
pub(crate) async fn refresh_config_loop(config: &ProxyConfig, path: Utf8PathBuf, rx: Arc<Notify>) {
let mut init = true;
loop {
rx.notified().await;
match refresh_config_inner(config, &path).await {
std::result::Result::Ok(()) => {}
// don't log for file not found errors if this is the first time we are checking
// for computes that don't use local_proxy, this is not an error.
Err(RefreshConfigError::Read(e))
if init && e.kind() == std::io::ErrorKind::NotFound =>
{
debug!(error=?e, ?path, "could not read config file");
}
Err(RefreshConfigError::Tls(e)) => {
error!(error=?e, ?path, "could not read TLS certificates");
}
Err(e) => {
error!(error=?e, ?path, "could not read config file");
}
}
init = false;
}
}
pub(crate) async fn refresh_config_inner(
config: &ProxyConfig,
path: &Utf8Path,
) -> Result<(), RefreshConfigError> {
let bytes = tokio::fs::read(&path).await?;
let data: LocalProxySpec = serde_json::from_slice(&bytes)?;
let mut jwks_set = vec![];
fn parse_jwks_settings(jwks: compute_api::spec::JwksSettings) -> anyhow::Result<JwksSettings> {
let mut jwks_url = url::Url::from_str(&jwks.jwks_url).context("parsing JWKS url")?;
ensure!(
jwks_url.has_authority()
&& (jwks_url.scheme() == "http" || jwks_url.scheme() == "https"),
"Invalid JWKS url. Must be HTTP",
);
ensure!(
jwks_url.host().is_some_and(|h| h != url::Host::Domain("")),
"Invalid JWKS url. No domain listed",
);
// clear username, password and ports
jwks_url
.set_username("")
.expect("url can be a base and has a valid host and is not a file. should not error");
jwks_url
.set_password(None)
.expect("url can be a base and has a valid host and is not a file. should not error");
// local testing is hard if we need to have a specific restricted port
if cfg!(not(feature = "testing")) {
jwks_url.set_port(None).expect(
"url can be a base and has a valid host and is not a file. should not error",
);
}
// clear query params
jwks_url.set_fragment(None);
jwks_url.query_pairs_mut().clear().finish();
if jwks_url.scheme() != "https" {
// local testing is hard if we need to set up https support.
if cfg!(not(feature = "testing")) {
jwks_url
.set_scheme("https")
.expect("should not error to set the scheme to https if it was http");
} else {
warn!(scheme = jwks_url.scheme(), "JWKS url is not HTTPS");
}
}
Ok(JwksSettings {
id: jwks.id,
jwks_url,
_provider_name: jwks.provider_name,
jwt_audience: jwks.jwt_audience,
role_names: jwks
.role_names
.into_iter()
.map(RoleName::from)
.map(|s| RoleNameInt::from(&s))
.collect(),
})
}
for jwks in data.jwks.into_iter().flatten() {
jwks_set.push(parse_jwks_settings(jwks).map_err(RefreshConfigError::Validate)?);
}
info!("successfully loaded new config");
JWKS_ROLE_MAP.store(Some(Arc::new(EndpointJwksResponse { jwks: jwks_set })));
if let Some(tls_config) = data.tls {
let tls_config = tokio::task::spawn_blocking(move || {
crate::tls::server_config::configure_tls(
tls_config.key_path.as_ref(),
tls_config.cert_path.as_ref(),
None,
false,
)
})
.await
.propagate_task_panic()
.map_err(RefreshConfigError::Tls)?;
config.tls_config.store(Some(Arc::new(tls_config)));
}
std::result::Result::Ok(())
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -209,11 +209,9 @@ impl RequestContext {
if let Some(options_str) = options.get("options") {
// If not found directly, try to extract it from the options string
for option in options_str.split_whitespace() {
if option.starts_with("neon_query_id:") {
if let Some(value) = option.strip_prefix("neon_query_id:") {
this.set_testodrome_id(value.into());
break;
}
if let Some(value) = option.strip_prefix("neon_query_id:") {
this.set_testodrome_id(value.into());
break;
}
}
}

View File

@@ -250,10 +250,8 @@ impl NeonControlPlaneClient {
info!(duration = ?start.elapsed(), "received http response");
let body = parse_body::<WakeCompute>(response.status(), response.bytes().await?)?;
// Unfortunately, ownership won't let us use `Option::ok_or` here.
let (host, port) = match parse_host_port(&body.address) {
None => return Err(WakeComputeError::BadComputeAddress(body.address)),
Some(x) => x,
let Some((host, port)) = parse_host_port(&body.address) else {
return Err(WakeComputeError::BadComputeAddress(body.address));
};
let host_addr = IpAddr::from_str(host).ok();

View File

@@ -213,7 +213,12 @@ impl<K: Hash + Eq + Clone> ApiLocks<K> {
self.metrics
.semaphore_acquire_seconds
.observe(now.elapsed().as_secs_f64());
debug!("acquired permit {:?}", now.elapsed().as_secs_f64());
if permit.is_ok() {
debug!(elapsed = ?now.elapsed(), "acquired permit");
} else {
debug!(elapsed = ?now.elapsed(), "timed out acquiring permit");
}
Ok(WakeComputePermit { permit: permit? })
}

View File

@@ -78,16 +78,6 @@ pub(crate) trait ReportableError: fmt::Display + Send + 'static {
fn get_error_kind(&self) -> ErrorKind;
}
impl ReportableError for postgres_client::error::Error {
fn get_error_kind(&self) -> ErrorKind {
if self.as_db_error().is_some() {
ErrorKind::Postgres
} else {
ErrorKind::Compute
}
}
}
/// Flattens `Result<Result<T>>` into `Result<T>`.
pub fn flatten_err<T>(r: Result<anyhow::Result<T>, JoinError>) -> anyhow::Result<T> {
r.context("join error").and_then(|x| x)

View File

@@ -52,7 +52,7 @@ pub async fn init() -> anyhow::Result<LoggingGuard> {
StderrWriter {
stderr: std::io::stderr(),
},
&["request_id", "session_id", "conn_id"],
&["conn_id", "ep", "query_id", "request_id", "session_id"],
))
} else {
None
@@ -271,18 +271,18 @@ where
});
// In case logging fails we generate a simpler JSON object.
if let Err(err) = res {
if let Ok(mut line) = serde_json::to_vec(&serde_json::json!( {
if let Err(err) = res
&& let Ok(mut line) = serde_json::to_vec(&serde_json::json!( {
"timestamp": now.to_rfc3339_opts(chrono::SecondsFormat::Micros, true),
"level": "ERROR",
"message": format_args!("cannot log event: {err:?}"),
"fields": {
"event": format_args!("{event:?}"),
},
})) {
line.push(b'\n');
self.writer.make_writer().write_all(&line).ok();
}
}))
{
line.push(b'\n');
self.writer.make_writer().write_all(&line).ok();
}
}
@@ -583,10 +583,11 @@ impl EventFormatter {
THREAD_ID.with(|tid| serializer.serialize_entry("thread_id", tid))?;
// TODO: tls cache? name could change
if let Some(thread_name) = std::thread::current().name() {
if !thread_name.is_empty() && thread_name != "tokio-runtime-worker" {
serializer.serialize_entry("thread_name", thread_name)?;
}
if let Some(thread_name) = std::thread::current().name()
&& !thread_name.is_empty()
&& thread_name != "tokio-runtime-worker"
{
serializer.serialize_entry("thread_name", thread_name)?;
}
if let Some(task_id) = tokio::task::try_id() {
@@ -596,10 +597,10 @@ impl EventFormatter {
serializer.serialize_entry("target", meta.target())?;
// Skip adding module if it's the same as target.
if let Some(module) = meta.module_path() {
if module != meta.target() {
serializer.serialize_entry("module", module)?;
}
if let Some(module) = meta.module_path()
&& module != meta.target()
{
serializer.serialize_entry("module", module)?;
}
if let Some(file) = meta.file() {

View File

@@ -236,13 +236,6 @@ pub enum Bool {
False,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "outcome")]
pub enum Outcome {
Success,
Failed,
}
#[derive(FixedCardinalityLabel, Copy, Clone)]
#[label(singleton = "outcome")]
pub enum CacheOutcome {

View File

@@ -90,27 +90,27 @@ where
// TODO: 1 info log, with a enum label for close direction.
// Early termination checks from compute to client.
if let TransferState::Done(_) = compute_to_client {
if let TransferState::Running(buf) = &client_to_compute {
info!("Compute is done, terminate client");
// Initiate shutdown
client_to_compute = TransferState::ShuttingDown(buf.amt);
client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, client, compute)
.map_err(ErrorSource::from_client)?;
}
if let TransferState::Done(_) = compute_to_client
&& let TransferState::Running(buf) = &client_to_compute
{
info!("Compute is done, terminate client");
// Initiate shutdown
client_to_compute = TransferState::ShuttingDown(buf.amt);
client_to_compute_result =
transfer_one_direction(cx, &mut client_to_compute, client, compute)
.map_err(ErrorSource::from_client)?;
}
// Early termination checks from client to compute.
if let TransferState::Done(_) = client_to_compute {
if let TransferState::Running(buf) = &compute_to_client {
info!("Client is done, terminate compute");
// Initiate shutdown
compute_to_client = TransferState::ShuttingDown(buf.amt);
compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, compute, client)
.map_err(ErrorSource::from_compute)?;
}
if let TransferState::Done(_) = client_to_compute
&& let TransferState::Running(buf) = &compute_to_client
{
info!("Client is done, terminate compute");
// Initiate shutdown
compute_to_client = TransferState::ShuttingDown(buf.amt);
compute_to_client_result =
transfer_one_direction(cx, &mut compute_to_client, compute, client)
.map_err(ErrorSource::from_compute)?;
}
// It is not a problem if ready! returns early ... (comment remains the same)

View File

@@ -39,7 +39,11 @@ impl<K: Hash + Eq> LeakyBucketRateLimiter<K> {
let config = config.map_or(self.default_config, Into::into);
if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
if self
.access_count
.fetch_add(1, Ordering::AcqRel)
.is_multiple_of(2048)
{
self.do_gc(now);
}

View File

@@ -211,7 +211,11 @@ impl<K: Hash + Eq, R: Rng, S: BuildHasher + Clone> BucketRateLimiter<K, R, S> {
// worst case memory usage is about:
// = 2 * 2048 * 64 * (48B + 72B)
// = 30MB
if self.access_count.fetch_add(1, Ordering::AcqRel) % 2048 == 0 {
if self
.access_count
.fetch_add(1, Ordering::AcqRel)
.is_multiple_of(2048)
{
self.do_gc();
}

View File

@@ -1,79 +0,0 @@
use core::net::IpAddr;
use std::sync::Arc;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::pqproto::CancelKeyData;
pub trait CancellationPublisherMut: Send + Sync + 'static {
#[allow(async_fn_in_trait)]
async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()>;
}
pub trait CancellationPublisher: Send + Sync + 'static {
#[allow(async_fn_in_trait)]
async fn try_publish(
&self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()>;
}
impl CancellationPublisher for () {
async fn try_publish(
&self,
_cancel_key_data: CancelKeyData,
_session_id: Uuid,
_peer_addr: IpAddr,
) -> anyhow::Result<()> {
Ok(())
}
}
impl<P: CancellationPublisher> CancellationPublisherMut for P {
async fn try_publish(
&mut self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()> {
<P as CancellationPublisher>::try_publish(self, cancel_key_data, session_id, peer_addr)
.await
}
}
impl<P: CancellationPublisher> CancellationPublisher for Option<P> {
async fn try_publish(
&self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()> {
if let Some(p) = self {
p.try_publish(cancel_key_data, session_id, peer_addr).await
} else {
Ok(())
}
}
}
impl<P: CancellationPublisherMut> CancellationPublisher for Arc<Mutex<P>> {
async fn try_publish(
&self,
cancel_key_data: CancelKeyData,
session_id: Uuid,
peer_addr: IpAddr,
) -> anyhow::Result<()> {
self.lock()
.await
.try_publish(cancel_key_data, session_id, peer_addr)
.await
}
}

View File

@@ -1,11 +1,12 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use futures::FutureExt;
use redis::aio::{ConnectionLike, MultiplexedConnection};
use redis::{ConnectionInfo, IntoConnectionInfo, RedisConnectionInfo, RedisResult};
use tokio::task::JoinHandle;
use tracing::{debug, error, info, warn};
use tokio::task::AbortHandle;
use tracing::{error, info, warn};
use super::elasticache::CredentialsProvider;
@@ -31,8 +32,9 @@ pub struct ConnectionWithCredentialsProvider {
credentials: Credentials,
// TODO: with more load on the connection, we should consider using a connection pool
con: Option<MultiplexedConnection>,
refresh_token_task: Option<JoinHandle<()>>,
refresh_token_task: Option<AbortHandle>,
mutex: tokio::sync::Mutex<()>,
credentials_refreshed: Arc<AtomicBool>,
}
impl Clone for ConnectionWithCredentialsProvider {
@@ -42,6 +44,7 @@ impl Clone for ConnectionWithCredentialsProvider {
con: None,
refresh_token_task: None,
mutex: tokio::sync::Mutex::new(()),
credentials_refreshed: Arc::new(AtomicBool::new(false)),
}
}
}
@@ -65,6 +68,7 @@ impl ConnectionWithCredentialsProvider {
con: None,
refresh_token_task: None,
mutex: tokio::sync::Mutex::new(()),
credentials_refreshed: Arc::new(AtomicBool::new(false)),
}
}
@@ -78,6 +82,7 @@ impl ConnectionWithCredentialsProvider {
con: None,
refresh_token_task: None,
mutex: tokio::sync::Mutex::new(()),
credentials_refreshed: Arc::new(AtomicBool::new(true)),
}
}
@@ -85,6 +90,10 @@ impl ConnectionWithCredentialsProvider {
redis::cmd("PING").query_async(con).await
}
pub(crate) fn credentials_refreshed(&self) -> bool {
self.credentials_refreshed.load(Ordering::Relaxed)
}
pub(crate) async fn connect(&mut self) -> anyhow::Result<()> {
let _guard = self.mutex.lock().await;
if let Some(con) = self.con.as_mut() {
@@ -112,13 +121,13 @@ impl ConnectionWithCredentialsProvider {
if let Credentials::Dynamic(credentials_provider, _) = &self.credentials {
let credentials_provider = credentials_provider.clone();
let con2 = con.clone();
let f = tokio::spawn(async move {
Self::keep_connection(con2, credentials_provider)
.await
.inspect_err(|e| debug!("keep_connection failed: {e}"))
.ok();
});
self.refresh_token_task = Some(f);
let credentials_refreshed = self.credentials_refreshed.clone();
let f = tokio::spawn(Self::keep_connection(
con2,
credentials_provider,
credentials_refreshed,
));
self.refresh_token_task = Some(f.abort_handle());
}
match Self::ping(&mut con).await {
Ok(()) => {
@@ -153,6 +162,7 @@ impl ConnectionWithCredentialsProvider {
async fn get_client(&self) -> anyhow::Result<redis::Client> {
let client = redis::Client::open(self.get_connection_info().await?)?;
self.credentials_refreshed.store(true, Ordering::Relaxed);
Ok(client)
}
@@ -168,16 +178,19 @@ impl ConnectionWithCredentialsProvider {
async fn keep_connection(
mut con: MultiplexedConnection,
credentials_provider: Arc<CredentialsProvider>,
) -> anyhow::Result<()> {
credentials_refreshed: Arc<AtomicBool>,
) -> ! {
loop {
// The connection lives for 12h, for the sanity check we refresh it every hour.
tokio::time::sleep(Duration::from_secs(60 * 60)).await;
match Self::refresh_token(&mut con, credentials_provider.clone()).await {
Ok(()) => {
info!("Token refreshed");
credentials_refreshed.store(true, Ordering::Relaxed);
}
Err(e) => {
error!("Error during token refresh: {e:?}");
credentials_refreshed.store(false, Ordering::Relaxed);
}
}
}
@@ -231,7 +244,7 @@ impl ConnectionLike for ConnectionWithCredentialsProvider {
&'a mut self,
cmd: &'a redis::Cmd,
) -> redis::RedisFuture<'a, redis::Value> {
(async move { self.send_packed_command(cmd).await }).boxed()
self.send_packed_command(cmd).boxed()
}
fn req_packed_commands<'a>(
@@ -240,10 +253,10 @@ impl ConnectionLike for ConnectionWithCredentialsProvider {
offset: usize,
count: usize,
) -> redis::RedisFuture<'a, Vec<redis::Value>> {
(async move { self.send_packed_commands(cmd, offset, count).await }).boxed()
self.send_packed_commands(cmd, offset, count).boxed()
}
fn get_db(&self) -> i64 {
0
self.con.as_ref().map_or(0, |c| c.get_db())
}
}

View File

@@ -40,6 +40,10 @@ impl RedisKVClient {
.inspect_err(|e| tracing::error!("failed to connect to redis: {e}"))
}
pub(crate) fn credentials_refreshed(&self) -> bool {
self.client.credentials_refreshed()
}
pub(crate) async fn query<T: FromRedisValue>(
&mut self,
q: &impl Queryable,
@@ -49,7 +53,7 @@ impl RedisKVClient {
Err(e) => e,
};
tracing::error!("failed to run query: {e}");
tracing::debug!("failed to run query: {e}");
match e.retry_method() {
redis::RetryMethod::Reconnect => {
tracing::info!("Redis client is disconnected. Reconnecting...");

View File

@@ -1,4 +1,3 @@
pub mod cancellation_publisher;
pub mod connection_with_credentials_provider;
pub mod elasticache;
pub mod keys;

View File

@@ -54,9 +54,7 @@ impl<T: std::fmt::Display> ChannelBinding<T> {
"eSws".into()
}
Self::Required(mode) => {
use std::io::Write;
let mut cbind_input = vec![];
write!(&mut cbind_input, "p={mode},,",).unwrap();
let mut cbind_input = format!("p={mode},,",).into_bytes();
cbind_input.extend_from_slice(get_cbind_data(mode)?);
BASE64_STANDARD.encode(&cbind_input).into()
}

View File

@@ -107,7 +107,7 @@ pub(crate) async fn exchange(
secret: &ServerSecret,
password: &[u8],
) -> sasl::Result<sasl::Outcome<super::ScramKey>> {
let salt = BASE64_STANDARD.decode(&secret.salt_base64)?;
let salt = BASE64_STANDARD.decode(&*secret.salt_base64)?;
let client_key = derive_client_key(pool, endpoint, password, &salt, secret.iterations).await;
if secret.is_password_invalid(&client_key).into() {

View File

@@ -87,13 +87,20 @@ impl<'a> ClientFirstMessage<'a> {
salt_base64: &str,
iterations: u32,
) -> OwnedServerFirstMessage {
use std::fmt::Write;
let mut message = String::with_capacity(128);
message.push_str("r=");
let mut message = String::new();
write!(&mut message, "r={}", self.nonce).unwrap();
// write combined nonce
let combined_nonce_start = message.len();
message.push_str(self.nonce);
BASE64_STANDARD.encode_string(nonce, &mut message);
let combined_nonce = 2..message.len();
write!(&mut message, ",s={salt_base64},i={iterations}").unwrap();
let combined_nonce = combined_nonce_start..message.len();
// write salt and iterations
message.push_str(",s=");
message.push_str(salt_base64);
message.push_str(",i=");
message.push_str(itoa::Buffer::new().format(iterations));
// This design guarantees that it's impossible to create a
// server-first-message without receiving a client-first-message

View File

@@ -14,7 +14,7 @@ pub(crate) struct ServerSecret {
/// Number of iterations for `PBKDF2` function.
pub(crate) iterations: u32,
/// Salt used to hash user's password.
pub(crate) salt_base64: String,
pub(crate) salt_base64: Box<str>,
/// Hashed `ClientKey`.
pub(crate) stored_key: ScramKey,
/// Used by client to verify server's signature.
@@ -35,7 +35,7 @@ impl ServerSecret {
let secret = ServerSecret {
iterations: iterations.parse().ok()?,
salt_base64: salt.to_owned(),
salt_base64: salt.into(),
stored_key: base64_decode_array(stored_key)?.into(),
server_key: base64_decode_array(server_key)?.into(),
doomed: false,
@@ -58,7 +58,7 @@ impl ServerSecret {
// iteration count 1 for our generated passwords going forward.
// PG16 users can set iteration count=1 already today.
iterations: 1,
salt_base64: BASE64_STANDARD.encode(nonce),
salt_base64: BASE64_STANDARD.encode(nonce).into_boxed_str(),
stored_key: ScramKey::default(),
server_key: ScramKey::default(),
doomed: true,
@@ -88,7 +88,7 @@ mod tests {
let parsed = ServerSecret::parse(&secret).unwrap();
assert_eq!(parsed.iterations, iterations);
assert_eq!(parsed.salt_base64, salt);
assert_eq!(&*parsed.salt_base64, salt);
assert_eq!(BASE64_STANDARD.encode(parsed.stored_key), stored_key);
assert_eq!(BASE64_STANDARD.encode(parsed.server_key), server_key);

View File

@@ -137,7 +137,7 @@ impl Future for JobSpec {
let state = state.as_mut().expect("should be set on thread startup");
state.tick = state.tick.wrapping_add(1);
if state.tick % SKETCH_RESET_INTERVAL == 0 {
if state.tick.is_multiple_of(SKETCH_RESET_INTERVAL) {
state.countmin.reset();
}

View File

@@ -115,7 +115,8 @@ impl PoolingBackend {
match &self.auth_backend {
crate::auth::Backend::ControlPlane(console, ()) => {
self.config
let keys = self
.config
.authentication_config
.jwks_cache
.check_jwt(
@@ -129,7 +130,7 @@ impl PoolingBackend {
Ok(ComputeCredentials {
info: user_info.clone(),
keys: crate::auth::backend::ComputeCredentialKeys::None,
keys,
})
}
crate::auth::Backend::Local(_) => {
@@ -256,6 +257,7 @@ impl PoolingBackend {
&self,
ctx: &RequestContext,
conn_info: ConnInfo,
disable_pg_session_jwt: bool,
) -> Result<Client<postgres_client::Client>, HttpConnError> {
if let Some(client) = self.local_pool.get(ctx, &conn_info)? {
return Ok(client);
@@ -277,7 +279,7 @@ impl PoolingBackend {
.expect("semaphore should never be closed");
// check again for race
if !self.local_pool.initialized(&conn_info) {
if !self.local_pool.initialized(&conn_info) && !disable_pg_session_jwt {
local_backend
.compute_ctl
.install_extension(&ExtensionInstallRequest {
@@ -313,14 +315,16 @@ impl PoolingBackend {
.to_postgres_client_config();
config
.user(&conn_info.user_info.user)
.dbname(&conn_info.dbname)
.set_param(
.dbname(&conn_info.dbname);
if !disable_pg_session_jwt {
config.set_param(
"options",
&format!(
"-c pg_session_jwt.jwk={}",
serde_json::to_string(&jwk).expect("serializing jwk to json should not fail")
),
);
}
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (client, connection) = config.connect(&postgres_client::NoTls).await?;
@@ -345,7 +349,9 @@ impl PoolingBackend {
debug!("setting up backend session state");
// initiates the auth session
if let Err(e) = client.batch_execute("select auth.init();").await {
if !disable_pg_session_jwt
&& let Err(e) = client.batch_execute("select auth.init();").await
{
discard.discard();
return Err(e.into());
}
@@ -404,7 +410,15 @@ impl ReportableError for HttpConnError {
fn get_error_kind(&self) -> ErrorKind {
match self {
HttpConnError::ConnectionClosedAbruptly(_) => ErrorKind::Compute,
HttpConnError::PostgresConnectionError(p) => p.get_error_kind(),
HttpConnError::PostgresConnectionError(p) => {
if p.as_db_error().is_some() {
// postgres rejected the connection
ErrorKind::Postgres
} else {
// couldn't even reach postgres
ErrorKind::Compute
}
}
HttpConnError::LocalProxyConnectionError(_) => ErrorKind::Compute,
HttpConnError::ComputeCtl(_) => ErrorKind::Service,
HttpConnError::JwtPayloadError(_) => ErrorKind::User,

View File

@@ -148,11 +148,10 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_client(db_user.clone(), conn_id) {
if let Some(pool) = pool.clone().upgrade()
&& pool.write().remove_client(db_user.clone(), conn_id) {
info!("closed connection removed");
}
}
Poll::Ready(())
}).await;

View File

@@ -1,5 +1,93 @@
use http::StatusCode;
use http::header::HeaderName;
use crate::auth::ComputeUserInfoParseError;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::http::ReadBodyError;
pub trait HttpCodeError {
fn get_http_status_code(&self) -> StatusCode;
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum ConnInfoError {
#[error("invalid header: {0}")]
InvalidHeader(&'static HeaderName),
#[error("invalid connection string: {0}")]
UrlParseError(#[from] url::ParseError),
#[error("incorrect scheme")]
IncorrectScheme,
#[error("missing database name")]
MissingDbName,
#[error("invalid database name")]
InvalidDbName,
#[error("missing username")]
MissingUsername,
#[error("invalid username: {0}")]
InvalidUsername(#[from] std::string::FromUtf8Error),
#[error("missing authentication credentials: {0}")]
MissingCredentials(Credentials),
#[error("missing hostname")]
MissingHostname,
#[error("invalid hostname: {0}")]
InvalidEndpoint(#[from] ComputeUserInfoParseError),
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum Credentials {
#[error("required password")]
Password,
#[error("required authorization bearer token in JWT format")]
BearerJwt,
}
impl ReportableError for ConnInfoError {
fn get_error_kind(&self) -> ErrorKind {
ErrorKind::User
}
}
impl UserFacingError for ConnInfoError {
fn to_string_client(&self) -> String {
self.to_string()
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum ReadPayloadError {
#[error("could not read the HTTP request body: {0}")]
Read(#[from] hyper::Error),
#[error("request is too large (max is {limit} bytes)")]
BodyTooLarge { limit: usize },
#[error("could not parse the HTTP request body: {0}")]
Parse(#[from] serde_json::Error),
}
impl From<ReadBodyError<hyper::Error>> for ReadPayloadError {
fn from(value: ReadBodyError<hyper::Error>) -> Self {
match value {
ReadBodyError::BodyTooLarge { limit } => Self::BodyTooLarge { limit },
ReadBodyError::Read(e) => Self::Read(e),
}
}
}
impl ReportableError for ReadPayloadError {
fn get_error_kind(&self) -> ErrorKind {
match self {
ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
ReadPayloadError::BodyTooLarge { .. } => ErrorKind::User,
ReadPayloadError::Parse(_) => ErrorKind::User,
}
}
}
impl HttpCodeError for ReadPayloadError {
fn get_http_status_code(&self) -> StatusCode {
match self {
ReadPayloadError::Read(_) => StatusCode::BAD_REQUEST,
ReadPayloadError::BodyTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
ReadPayloadError::Parse(_) => StatusCode::BAD_REQUEST,
}
}
}

View File

@@ -2,6 +2,8 @@ use std::collections::VecDeque;
use std::sync::atomic::{self, AtomicUsize};
use std::sync::{Arc, Weak};
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use hyper::client::conn::http2;
use hyper_util::rt::{TokioExecutor, TokioIo};
use parking_lot::RwLock;
@@ -21,8 +23,9 @@ use crate::protocol2::ConnectionInfoExtra;
use crate::types::EndpointCacheKey;
use crate::usage_metrics::{Ids, MetricCounter, USAGE_METRICS};
pub(crate) type Send = http2::SendRequest<hyper::body::Incoming>;
pub(crate) type Connect = http2::Connection<TokioIo<AsyncRW>, hyper::body::Incoming, TokioExecutor>;
pub(crate) type Send = http2::SendRequest<BoxBody<Bytes, hyper::Error>>;
pub(crate) type Connect =
http2::Connection<TokioIo<AsyncRW>, BoxBody<Bytes, hyper::Error>, TokioExecutor>;
#[derive(Clone)]
pub(crate) struct ClientDataHttp();
@@ -237,10 +240,10 @@ pub(crate) fn poll_http2_client(
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.write().remove_conn(conn_id) {
info!("closed connection removed");
}
if let Some(pool) = pool.clone().upgrade()
&& pool.write().remove_conn(conn_id)
{
info!("closed connection removed");
}
}
.instrument(span),

View File

@@ -3,11 +3,42 @@
use anyhow::Context;
use bytes::Bytes;
use http::{Response, StatusCode};
use http::header::AUTHORIZATION;
use http::{HeaderMap, HeaderName, HeaderValue, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use http_utils::error::ApiError;
use serde::Serialize;
use url::Url;
use uuid::Uuid;
use super::conn_pool::{AuthData, ConnInfoWithAuth};
use super::conn_pool_lib::ConnInfo;
use super::error::{ConnInfoError, Credentials};
use crate::auth::backend::ComputeUserInfo;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::metrics::{Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::NeonOptions;
use crate::types::{DbName, EndpointId, RoleName};
// Common header names used across serverless modules
pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");
pub(super) static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
pub(super) static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
pub(super) static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
pub(super) static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
pub(super) static TXN_ISOLATION_LEVEL: HeaderName =
HeaderName::from_static("neon-batch-isolation-level");
pub(super) static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
pub(super) static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
let mut uuid = [0; uuid::fmt::Hyphenated::LENGTH];
HeaderValue::from_str(id.as_hyphenated().encode_lower(&mut uuid[..]))
.expect("uuid hyphenated format should be all valid header characters")
}
/// Like [`ApiError::into_response`]
pub(crate) fn api_error_into_response(this: ApiError) -> Response<BoxBody<Bytes, hyper::Error>> {
@@ -107,3 +138,136 @@ pub(crate) fn json_response<T: Serialize>(
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
}
pub(crate) fn get_conn_info(
config: &'static AuthenticationConfig,
ctx: &RequestContext,
connection_string: Option<&str>,
headers: &HeaderMap,
) -> Result<ConnInfoWithAuth, ConnInfoError> {
let connection_url = match connection_string {
Some(connection_string) => Url::parse(connection_string)?,
None => {
let connection_string = headers
.get(&CONN_STRING)
.ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
Url::parse(connection_string)?
}
};
let protocol = connection_url.scheme();
if protocol != "postgres" && protocol != "postgresql" {
return Err(ConnInfoError::IncorrectScheme);
}
let mut url_path = connection_url
.path_segments()
.ok_or(ConnInfoError::MissingDbName)?;
let dbname: DbName =
urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into();
ctx.set_dbname(dbname.clone());
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
if username.is_empty() {
return Err(ConnInfoError::MissingUsername);
}
ctx.set_user(username.clone());
// TODO: make sure this is right in the context of rest broker
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
if !config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
}
let auth = auth
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
AuthData::Jwt(
auth.strip_prefix("Bearer ")
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
.into(),
)
} else if let Some(pass) = connection_url.password() {
// wrong credentials provided
if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
}
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
})
} else if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
} else {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
};
let endpoint: EndpointId = match connection_url.host() {
Some(url::Host::Domain(hostname)) => hostname
.split_once('.')
.map_or(hostname, |(prefix, _)| prefix)
.into(),
Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
return Err(ConnInfoError::MissingHostname);
}
};
ctx.set_endpoint_id(endpoint.clone());
let pairs = connection_url.query_pairs();
let mut options = Option::None;
let mut params = StartupMessageParams::default();
params.insert("user", &username);
params.insert("database", &dbname);
for (key, value) in pairs {
params.insert(&key, &value);
if key == "options" {
options = Some(NeonOptions::parse_options_raw(&value));
}
}
// check the URL that was used, for metrics
{
let host_endpoint = headers
// get the host header
.get("host")
// extract the domain
.and_then(|h| {
let (host, _port) = h.to_str().ok()?.split_once(':')?;
Some(host)
})
// get the endpoint prefix
.map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix));
let kind = if host_endpoint == Some(&*endpoint) {
SniKind::Sni
} else {
SniKind::NoSni
};
let protocol = ctx.protocol();
Metrics::get()
.proxy
.accepted_connections_by_sni
.inc(SniGroup { protocol, kind });
}
ctx.set_user_agent(
headers
.get(hyper::header::USER_AGENT)
.and_then(|h| h.to_str().ok())
.map(Into::into),
);
let user_info = ComputeUserInfo {
endpoint,
user: username,
options: options.unwrap_or_default(),
};
let conn_info = ConnInfo { user_info, dbname };
Ok(ConnInfoWithAuth { conn_info, auth })
}
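As a worked example of the endpoint extraction above (the hostname is hypothetical), the endpoint id is simply the hostname prefix before the first dot:
// hypothetical hostname; the prefix rule mirrors the `split_once('.')` call above
let hostname = "ep-example-123456.us-east-2.aws.neon.tech";
let endpoint = hostname.split_once('.').map_or(hostname, |(prefix, _)| prefix);
assert_eq!(endpoint, "ep-example-123456");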

View File

@@ -70,6 +70,34 @@ pub(crate) enum JsonConversionError {
ParseJsonError(#[from] serde_json::Error),
#[error("unbalanced array")]
UnbalancedArray,
#[error("unbalanced quoted string")]
UnbalancedString,
}
enum OutputMode {
Array(Vec<Value>),
Object(Map<String, Value>),
}
impl OutputMode {
fn key(&mut self, key: &str) -> &mut Value {
match self {
OutputMode::Array(values) => push_entry(values, Value::Null),
OutputMode::Object(map) => map.entry(key.to_string()).or_insert(Value::Null),
}
}
fn finish(self) -> Value {
match self {
OutputMode::Array(values) => Value::Array(values),
OutputMode::Object(map) => Value::Object(map),
}
}
}
fn push_entry<T>(arr: &mut Vec<T>, t: T) -> &mut T {
arr.push(t);
arr.last_mut().expect("a value was just inserted")
}
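A minimal sketch of how the two modes behave: `key` hands out a slot to write into, and `finish` yields either a JSON object or a JSON array (in array mode the key is ignored):
// sketch; assumes serde_json::{Map, Value} are in scope as in this module
let mut obj = OutputMode::Object(Map::new());
*obj.key("id") = Value::from(1);
assert_eq!(obj.finish(), serde_json::json!({"id": 1}));

let mut arr = OutputMode::Array(Vec::new());
*arr.key("id") = Value::from(1); // the key is dropped in array mode
assert_eq!(arr.finish(), serde_json::json!([1]));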
//
@@ -77,182 +105,277 @@ pub(crate) enum JsonConversionError {
//
pub(crate) fn pg_text_row_to_json(
row: &Row,
columns: &[Type],
raw_output: bool,
array_mode: bool,
) -> Result<Value, JsonConversionError> {
let iter = row
.columns()
.iter()
.zip(columns)
.enumerate()
.map(|(i, (column, typ))| {
let name = column.name();
let pg_value = row.as_text(i).map_err(JsonConversionError::AsTextError)?;
let json_value = if raw_output {
match pg_value {
Some(v) => Value::String(v.to_string()),
None => Value::Null,
}
} else {
pg_text_to_json(pg_value, typ)?
};
Ok((name.to_string(), json_value))
});
if array_mode {
// drop keys and aggregate into array
let arr = iter
.map(|r| r.map(|(_key, val)| val))
.collect::<Result<Vec<Value>, JsonConversionError>>()?;
Ok(Value::Array(arr))
let mut entries = if array_mode {
OutputMode::Array(Vec::with_capacity(row.columns().len()))
} else {
let obj = iter.collect::<Result<Map<String, Value>, JsonConversionError>>()?;
Ok(Value::Object(obj))
OutputMode::Object(Map::with_capacity(row.columns().len()))
};
for (i, column) in row.columns().iter().enumerate() {
let pg_value = row.as_text(i).map_err(JsonConversionError::AsTextError)?;
let value = entries.key(column.name());
match pg_value {
Some(v) if raw_output => *value = Value::String(v.to_string()),
Some(v) => pg_text_to_json(value, v, column.type_())?,
None => *value = Value::Null,
}
}
Ok(entries.finish())
}
//
// Convert postgres text-encoded value to JSON value
//
fn pg_text_to_json(pg_value: Option<&str>, pg_type: &Type) -> Result<Value, JsonConversionError> {
if let Some(val) = pg_value {
if let Kind::Array(elem_type) = pg_type.kind() {
return pg_array_parse(val, elem_type);
}
fn pg_text_to_json(
output: &mut Value,
val: &str,
pg_type: &Type,
) -> Result<(), JsonConversionError> {
if let Kind::Array(elem_type) = pg_type.kind() {
// todo: we should fetch this from postgres.
let delimiter = ',';
match *pg_type {
Type::BOOL => Ok(Value::Bool(val == "t")),
Type::INT2 | Type::INT4 => {
let val = val.parse::<i32>()?;
Ok(Value::Number(serde_json::Number::from(val)))
}
Type::FLOAT4 | Type::FLOAT8 => {
let fval = val.parse::<f64>()?;
let num = serde_json::Number::from_f64(fval);
if let Some(num) = num {
Ok(Value::Number(num))
} else {
// Pass NaN, Inf, -Inf as strings.
// JS JSON.stringify() converts them to null, but we
// want to preserve them, so we pass them as strings.
Ok(Value::String(val.to_string()))
}
}
Type::JSON | Type::JSONB => Ok(serde_json::from_str(val)?),
_ => Ok(Value::String(val.to_string())),
}
} else {
Ok(Value::Null)
}
}
//
// Parse postgres array into JSON array.
//
// This is a bit involved because we need to handle nested arrays and quoted
// values. Unlike postgres we don't check that all nested arrays have the same
// dimensions, we just return them as is.
//
fn pg_array_parse(pg_array: &str, elem_type: &Type) -> Result<Value, JsonConversionError> {
pg_array_parse_inner(pg_array, elem_type, false).map(|(v, _)| v)
}
fn pg_array_parse_inner(
pg_array: &str,
elem_type: &Type,
nested: bool,
) -> Result<(Value, usize), JsonConversionError> {
let mut pg_array_chr = pg_array.char_indices();
let mut level = 0;
let mut quote = false;
let mut entries: Vec<Value> = Vec::new();
let mut entry = String::new();
// skip bounds decoration
if let Some('[') = pg_array.chars().next() {
for (_, c) in pg_array_chr.by_ref() {
if c == '=' {
break;
}
}
let mut array = vec![];
pg_array_parse(&mut array, val, elem_type, delimiter)?;
*output = Value::Array(array);
return Ok(());
}
fn push_checked(
entry: &mut String,
entries: &mut Vec<Value>,
elem_type: &Type,
) -> Result<(), JsonConversionError> {
if !entry.is_empty() {
// While in a usual postgres response we get nulls as None and everything else
// as Some(&str), in arrays NULL arrives as the unquoted string 'NULL' (a
// string with the value 'NULL' is represented as '"NULL"'). So catch NULLs
// here, while we still have the quoting info, and convert them to None.
if entry == "NULL" {
entries.push(pg_text_to_json(None, elem_type)?);
match *pg_type {
Type::BOOL => *output = Value::Bool(val == "t"),
Type::INT2 | Type::INT4 => {
let val = val.parse::<i32>()?;
*output = Value::Number(serde_json::Number::from(val));
}
Type::FLOAT4 | Type::FLOAT8 => {
let fval = val.parse::<f64>()?;
let num = serde_json::Number::from_f64(fval);
if let Some(num) = num {
*output = Value::Number(num);
} else {
entries.push(pg_text_to_json(Some(entry), elem_type)?);
// Pass NaN, Inf, -Inf as strings.
// JS JSON.stringify() converts them to null, but we
// want to preserve them, so we pass them as strings.
*output = Value::String(val.to_string());
}
entry.clear();
}
Ok(())
Type::JSON | Type::JSONB => *output = serde_json::from_str(val)?,
_ => *output = Value::String(val.to_string()),
}
while let Some((mut i, mut c)) = pg_array_chr.next() {
let mut escaped = false;
Ok(())
}
if c == '\\' {
escaped = true;
let Some(x) = pg_array_chr.next() else {
return Err(JsonConversionError::UnbalancedArray);
};
(i, c) = x;
}
match c {
'{' if !quote => {
level += 1;
if level > 1 {
let (res, off) = pg_array_parse_inner(&pg_array[i..], elem_type, true)?;
entries.push(res);
for _ in 0..off - 1 {
pg_array_chr.next();
}
}
}
'}' if !quote => {
level -= 1;
if level == 0 {
push_checked(&mut entry, &mut entries, elem_type)?;
if nested {
return Ok((Value::Array(entries), i));
}
}
}
'"' if !escaped => {
if quote {
// end of quoted string, so push it manually without any checks
// for emptiness or nulls
entries.push(pg_text_to_json(Some(&entry), elem_type)?);
entry.clear();
}
quote = !quote;
}
',' if !quote => {
push_checked(&mut entry, &mut entries, elem_type)?;
}
_ => {
entry.push(c);
}
}
/// Parse postgres array into JSON array.
///
/// This is a bit involved because we need to handle nested arrays and quoted
/// values. Unlike postgres we don't check that all nested arrays have the same
/// dimensions, we just return them as is.
///
/// <https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO>
///
/// The external text representation of an array value consists of items that are interpreted
/// according to the I/O conversion rules for the array's element type, plus decoration that
/// indicates the array structure. The decoration consists of curly braces (`{` and `}`) around
/// the array value plus delimiter characters between adjacent items. The delimiter character
/// is usually a comma (,) but can be something else: it is determined by the typdelim setting
/// for the array's element type. Among the standard data types provided in the PostgreSQL
/// distribution, all use a comma, except for type box, which uses a semicolon (;).
///
/// In a multidimensional array, each dimension (row, plane, cube, etc.)
/// gets its own level of curly braces, and delimiters must be written between adjacent
/// curly-braced entities of the same level.
fn pg_array_parse(
elements: &mut Vec<Value>,
mut pg_array: &str,
elem: &Type,
delim: char,
) -> Result<(), JsonConversionError> {
// skip bounds decoration, e.g.:
// `[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}`
// technically these are significant, but we have no way to represent them in json.
if let Some('[') = pg_array.chars().next() {
let Some((_bounds, array)) = pg_array.split_once('=') else {
return Err(JsonConversionError::UnbalancedArray);
};
pg_array = array;
}
if level != 0 {
// whitespace might precede a `{`.
let pg_array = pg_array.trim_start();
let rest = pg_array_parse_inner(elements, pg_array, elem, delim)?;
if !rest.is_empty() {
return Err(JsonConversionError::UnbalancedArray);
}
Ok((Value::Array(entries), 0))
Ok(())
}
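For instance, mirroring the tests further down, a flat int4 array and a two-dimensional bool array parse as follows:
// sketch; assumes Value, Type and serde_json::json are in scope as in the tests below
let mut out = vec![];
pg_array_parse(&mut out, "{1,2,3}", &Type::INT4, ',').unwrap();
assert_eq!(Value::Array(out), serde_json::json!([1, 2, 3]));

let mut nested = vec![];
pg_array_parse(&mut nested, "{{t,f},{f,t}}", &Type::BOOL, ',').unwrap();
assert_eq!(Value::Array(nested), serde_json::json!([[true, false], [false, true]]));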
/// reads a single array from the `pg_array` string and pushes each values to `elements`.
/// returns the rest of the `pg_array` string that was not read.
fn pg_array_parse_inner<'a>(
elements: &mut Vec<Value>,
mut pg_array: &'a str,
elem: &Type,
delim: char,
) -> Result<&'a str, JsonConversionError> {
// array should have a `{` prefix.
pg_array = pg_array
.strip_prefix('{')
.ok_or(JsonConversionError::UnbalancedArray)?;
let mut q = String::new();
loop {
let value = push_entry(elements, Value::Null);
pg_array = pg_array_parse_item(value, &mut q, pg_array, elem, delim)?;
// check for separator.
if let Some(next) = pg_array.strip_prefix(delim) {
// next item.
pg_array = next;
} else {
break;
}
}
let Some(next) = pg_array.strip_prefix('}') else {
// missing `}` terminator.
return Err(JsonConversionError::UnbalancedArray);
};
// whitespace might follow a `}`.
Ok(next.trim_start())
}
/// reads a single item from the `pg_array` string.
/// returns the rest of the `pg_array` string that was not read.
///
/// `quoted` is a scratch buffer used to reuse an allocation; its contents on return are unspecified.
fn pg_array_parse_item<'a>(
output: &mut Value,
quoted: &mut String,
mut pg_array: &'a str,
elem: &Type,
delim: char,
) -> Result<&'a str, JsonConversionError> {
// We are trying to parse an array item.
// This could be a nested array, if this is a multi-dimensional array.
// This could be a quoted string representing `elem`.
// This could be an unquoted string representing `elem`.
// whitespace might precede an item.
pg_array = pg_array.trim_start();
if pg_array.starts_with('{') {
// nested array.
let mut nested = vec![];
pg_array = pg_array_parse_inner(&mut nested, pg_array, elem, delim)?;
*output = Value::Array(nested);
return Ok(pg_array);
}
if let Some(mut pg_array) = pg_array.strip_prefix('"') {
// the parsed string is un-escaped and written into quoted.
pg_array = pg_array_parse_quoted(quoted, pg_array)?;
// we have un-escaped the string, parse it as pgtext.
pg_text_to_json(output, quoted, elem)?;
return Ok(pg_array);
}
// we need to parse an item. read until we find a delimiter or `}`.
let index = pg_array
.find([delim, '}'])
.ok_or(JsonConversionError::UnbalancedArray)?;
let item;
(item, pg_array) = pg_array.split_at(index);
// item might have trailing whitespace that we need to ignore.
let item = item.trim_end();
// we might have an item string:
// check for null
if item == "NULL" {
*output = Value::Null;
} else {
pg_text_to_json(output, item, elem)?;
}
Ok(pg_array)
}
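One subtlety worth an example: an unquoted NULL item maps to JSON null, while a quoted "NULL" stays a string, which is exactly the distinction the quoting rules below rely on:
// sketch; TEXT elements with the default ',' delimiter
let mut out = vec![];
pg_array_parse(&mut out, r#"{NULL,"NULL"}"#, &Type::TEXT, ',').unwrap();
assert_eq!(Value::Array(out), serde_json::json!([null, "NULL"]));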
/// reads a single quoted item from the `pg_array` string.
///
/// Returns the rest of the `pg_array` string that was not read.
/// The output is written into `quoted`.
///
/// The pg_array string must contain a closing `"`, but the opening `"`
/// must already have been stripped from the input. The closing `"` is consumed.
fn pg_array_parse_quoted<'a>(
quoted: &mut String,
mut pg_array: &'a str,
) -> Result<&'a str, JsonConversionError> {
// The array output routine will put double quotes around element values if they are empty strings,
// contain curly braces, delimiter characters, double quotes, backslashes, or white space,
// or match the word `NULL`. Double quotes and backslashes embedded in element values will be backslash-escaped.
// For numeric data types it is safe to assume that double quotes will never appear,
// but for textual data types one should be prepared to cope with either the presence or absence of quotes.
quoted.clear();
// We write to quoted in chunks terminated by an escape character.
// E.g. if we have the input `foo\"bar"`, we write `foo`, then `"`, then finally `bar`.
loop {
// we need to parse a chunk. read until we find a '\\' or `"`.
let i = pg_array
.find(['\\', '"'])
.ok_or(JsonConversionError::UnbalancedString)?;
let chunk: &str;
(chunk, pg_array) = pg_array
.split_at_checked(i)
.expect("i is guaranteed to be in-bounds of pg_array");
// push the chunk.
quoted.push_str(chunk);
// consume the chunk_end character.
let chunk_end: char;
(chunk_end, pg_array) =
split_first_char(pg_array).expect("pg_array should start with either '\\\\' or '\"'");
// finished.
if chunk_end == '"' {
// whitespace might follow the '"'.
pg_array = pg_array.trim_start();
break Ok(pg_array);
}
// consume the escaped character.
let escaped: char;
(escaped, pg_array) =
split_first_char(pg_array).ok_or(JsonConversionError::UnbalancedString)?;
quoted.push(escaped);
}
}
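A hedged example of the chunked un-escaping described above; the input is arbitrary, and the opening `"` has already been stripped by the caller:
let mut buf = String::new();
let rest = pg_array_parse_quoted(&mut buf, r#"foo\"bar",next}"#).unwrap();
assert_eq!(buf, r#"foo"bar"#);
assert_eq!(rest, ",next}");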
fn split_first_char(s: &str) -> Option<(char, &str)> {
let mut chars = s.chars();
let c = chars.next()?;
Some((c, chars.as_str()))
}
#[cfg(test)]
@@ -316,37 +439,33 @@ mod tests {
);
}
fn pg_text_to_json(val: &str, pg_type: &Type) -> Value {
let mut v = Value::Null;
super::pg_text_to_json(&mut v, val, pg_type).unwrap();
v
}
fn pg_array_parse(pg_array: &str, pg_type: &Type) -> Value {
let mut array = vec![];
super::pg_array_parse(&mut array, pg_array, pg_type, ',').unwrap();
Value::Array(array)
}
#[test]
fn test_atomic_types_parse() {
assert_eq!(pg_text_to_json("foo", &Type::TEXT), json!("foo"));
assert_eq!(pg_text_to_json("42", &Type::INT4), json!(42));
assert_eq!(pg_text_to_json("42", &Type::INT2), json!(42));
assert_eq!(pg_text_to_json("42", &Type::INT8), json!("42"));
assert_eq!(pg_text_to_json("42.42", &Type::FLOAT8), json!(42.42));
assert_eq!(pg_text_to_json("42.42", &Type::FLOAT4), json!(42.42));
assert_eq!(pg_text_to_json("NaN", &Type::FLOAT4), json!("NaN"));
assert_eq!(
pg_text_to_json(Some("foo"), &Type::TEXT).unwrap(),
json!("foo")
);
assert_eq!(pg_text_to_json(None, &Type::TEXT).unwrap(), json!(null));
assert_eq!(pg_text_to_json(Some("42"), &Type::INT4).unwrap(), json!(42));
assert_eq!(pg_text_to_json(Some("42"), &Type::INT2).unwrap(), json!(42));
assert_eq!(
pg_text_to_json(Some("42"), &Type::INT8).unwrap(),
json!("42")
);
assert_eq!(
pg_text_to_json(Some("42.42"), &Type::FLOAT8).unwrap(),
json!(42.42)
);
assert_eq!(
pg_text_to_json(Some("42.42"), &Type::FLOAT4).unwrap(),
json!(42.42)
);
assert_eq!(
pg_text_to_json(Some("NaN"), &Type::FLOAT4).unwrap(),
json!("NaN")
);
assert_eq!(
pg_text_to_json(Some("Infinity"), &Type::FLOAT4).unwrap(),
pg_text_to_json("Infinity", &Type::FLOAT4),
json!("Infinity")
);
assert_eq!(
pg_text_to_json(Some("-Infinity"), &Type::FLOAT4).unwrap(),
pg_text_to_json("-Infinity", &Type::FLOAT4),
json!("-Infinity")
);
@@ -355,10 +474,9 @@ mod tests {
.unwrap();
assert_eq!(
pg_text_to_json(
Some(r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#),
r#"{"s":"str","n":42,"f":4.2,"a":[null,3,"a"]}"#,
&Type::JSONB
)
.unwrap(),
),
json
);
}
@@ -366,7 +484,7 @@ mod tests {
#[test]
fn test_pg_array_parse_text() {
fn pt(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::TEXT).unwrap()
pg_array_parse(pg_arr, &Type::TEXT)
}
assert_eq!(
pt(r#"{"aa\"\\\,a",cha,"bbbb"}"#),
@@ -389,7 +507,7 @@ mod tests {
#[test]
fn test_pg_array_parse_bool() {
fn pb(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::BOOL).unwrap()
pg_array_parse(pg_arr, &Type::BOOL)
}
assert_eq!(pb(r#"{t,f,t}"#), json!([true, false, true]));
assert_eq!(pb(r#"{{t,f,t}}"#), json!([[true, false, true]]));
@@ -406,7 +524,7 @@ mod tests {
#[test]
fn test_pg_array_parse_numbers() {
fn pn(pg_arr: &str, ty: &Type) -> Value {
pg_array_parse(pg_arr, ty).unwrap()
pg_array_parse(pg_arr, ty)
}
assert_eq!(pn(r#"{1,2,3}"#, &Type::INT4), json!([1, 2, 3]));
assert_eq!(pn(r#"{1,2,3}"#, &Type::INT2), json!([1, 2, 3]));
@@ -434,7 +552,7 @@ mod tests {
#[test]
fn test_pg_array_with_decoration() {
fn p(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::INT2).unwrap()
pg_array_parse(pg_arr, &Type::INT2)
}
assert_eq!(
p(r#"[1:1][-2:-1][3:5]={{{1,2,3},{4,5,6}}}"#),
@@ -445,7 +563,7 @@ mod tests {
#[test]
fn test_pg_array_parse_json() {
fn pt(pg_arr: &str) -> Value {
pg_array_parse(pg_arr, &Type::JSONB).unwrap()
pg_array_parse(pg_arr, &Type::JSONB)
}
assert_eq!(pt(r#"{"{}"}"#), json!([{}]));
assert_eq!(

View File

@@ -249,11 +249,10 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
}
// remove from connection pool
if let Some(pool) = pool.clone().upgrade() {
if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
if let Some(pool) = pool.clone().upgrade()
&& pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
info!("closed connection removed");
}
}
Poll::Ready(())
}).await;

View File

@@ -29,13 +29,13 @@ use futures::future::{Either, select};
use http::{Method, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty};
use http_util::{NEON_REQUEST_ID, uuid_to_header_value};
use http_utils::error::ApiError;
use hyper::body::Incoming;
use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder;
use rand::SeedableRng;
use rand::rngs::StdRng;
use sql_over_http::{NEON_REQUEST_ID, uuid_to_header_value};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, TcpStream};
use tokio::time::timeout;

View File

@@ -11,7 +11,7 @@ use http_body_util::{BodyExt, Full};
use http_utils::error::ApiError;
use hyper::body::Incoming;
use hyper::http::{HeaderName, HeaderValue};
use hyper::{HeaderMap, Request, Response, StatusCode, header};
use hyper::{Request, Response, StatusCode, header};
use indexmap::IndexMap;
use postgres_client::error::{DbError, ErrorPosition, SqlState};
use postgres_client::{
@@ -22,28 +22,25 @@ use serde_json::Value;
use serde_json::value::RawValue;
use tokio::time::{self, Instant};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info};
use tracing::{Level, debug, error, info};
use typed_json::json;
use url::Url;
use uuid::Uuid;
use super::backend::{LocalProxyConnError, PoolingBackend};
use super::conn_pool::{AuthData, ConnInfoWithAuth};
use super::conn_pool::AuthData;
use super::conn_pool_lib::{self, ConnInfo};
use super::error::HttpCodeError;
use super::http_util::json_response;
use super::error::{ConnInfoError, HttpCodeError, ReadPayloadError};
use super::http_util::{
ALLOW_POOL, ARRAY_MODE, CONN_STRING, NEON_REQUEST_ID, RAW_TEXT_OUTPUT, TXN_DEFERRABLE,
TXN_ISOLATION_LEVEL, TXN_READ_ONLY, get_conn_info, json_response, uuid_to_header_value,
};
use super::json::{JsonConversionError, json_to_pg_text, pg_text_row_to_json};
use crate::auth::backend::{ComputeCredentialKeys, ComputeUserInfo};
use crate::auth::{ComputeUserInfoParseError, endpoint_sni};
use crate::config::{AuthenticationConfig, HttpConfig, ProxyConfig, TlsConfig};
use crate::auth::backend::ComputeCredentialKeys;
use crate::config::{HttpConfig, ProxyConfig};
use crate::context::RequestContext;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::http::{ReadBodyError, read_body_with_limit};
use crate::metrics::{HttpDirection, Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::NeonOptions;
use crate::http::read_body_with_limit;
use crate::metrics::{HttpDirection, Metrics};
use crate::serverless::backend::HttpConnError;
use crate::types::{DbName, RoleName};
use crate::usage_metrics::{MetricCounter, MetricCounterRecorder};
use crate::util::run_until_cancelled;
@@ -70,16 +67,6 @@ enum Payload {
Batch(BatchQueryData),
}
pub(super) static NEON_REQUEST_ID: HeaderName = HeaderName::from_static("neon-request-id");
static CONN_STRING: HeaderName = HeaderName::from_static("neon-connection-string");
static RAW_TEXT_OUTPUT: HeaderName = HeaderName::from_static("neon-raw-text-output");
static ARRAY_MODE: HeaderName = HeaderName::from_static("neon-array-mode");
static ALLOW_POOL: HeaderName = HeaderName::from_static("neon-pool-opt-in");
static TXN_ISOLATION_LEVEL: HeaderName = HeaderName::from_static("neon-batch-isolation-level");
static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
static HEADER_VALUE_TRUE: HeaderValue = HeaderValue::from_static("true");
fn bytes_to_pg_text<'de, D>(deserializer: D) -> Result<Vec<Option<String>>, D::Error>
@@ -91,188 +78,6 @@ where
Ok(json_to_pg_text(json))
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum ConnInfoError {
#[error("invalid header: {0}")]
InvalidHeader(&'static HeaderName),
#[error("invalid connection string: {0}")]
UrlParseError(#[from] url::ParseError),
#[error("incorrect scheme")]
IncorrectScheme,
#[error("missing database name")]
MissingDbName,
#[error("invalid database name")]
InvalidDbName,
#[error("missing username")]
MissingUsername,
#[error("invalid username: {0}")]
InvalidUsername(#[from] std::string::FromUtf8Error),
#[error("missing authentication credentials: {0}")]
MissingCredentials(Credentials),
#[error("missing hostname")]
MissingHostname,
#[error("invalid hostname: {0}")]
InvalidEndpoint(#[from] ComputeUserInfoParseError),
#[error("malformed endpoint")]
MalformedEndpoint,
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum Credentials {
#[error("required password")]
Password,
#[error("required authorization bearer token in JWT format")]
BearerJwt,
}
impl ReportableError for ConnInfoError {
fn get_error_kind(&self) -> ErrorKind {
ErrorKind::User
}
}
impl UserFacingError for ConnInfoError {
fn to_string_client(&self) -> String {
self.to_string()
}
}
fn get_conn_info(
config: &'static AuthenticationConfig,
ctx: &RequestContext,
headers: &HeaderMap,
tls: Option<&TlsConfig>,
) -> Result<ConnInfoWithAuth, ConnInfoError> {
let connection_string = headers
.get(&CONN_STRING)
.ok_or(ConnInfoError::InvalidHeader(&CONN_STRING))?
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&CONN_STRING))?;
let connection_url = Url::parse(connection_string)?;
let protocol = connection_url.scheme();
if protocol != "postgres" && protocol != "postgresql" {
return Err(ConnInfoError::IncorrectScheme);
}
let mut url_path = connection_url
.path_segments()
.ok_or(ConnInfoError::MissingDbName)?;
let dbname: DbName =
urlencoding::decode(url_path.next().ok_or(ConnInfoError::InvalidDbName)?)?.into();
ctx.set_dbname(dbname.clone());
let username = RoleName::from(urlencoding::decode(connection_url.username())?);
if username.is_empty() {
return Err(ConnInfoError::MissingUsername);
}
ctx.set_user(username.clone());
let auth = if let Some(auth) = headers.get(&AUTHORIZATION) {
if !config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
}
let auth = auth
.to_str()
.map_err(|_| ConnInfoError::InvalidHeader(&AUTHORIZATION))?;
AuthData::Jwt(
auth.strip_prefix("Bearer ")
.ok_or(ConnInfoError::MissingCredentials(Credentials::BearerJwt))?
.into(),
)
} else if let Some(pass) = connection_url.password() {
// wrong credentials provided
if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
}
AuthData::Password(match urlencoding::decode_binary(pass.as_bytes()) {
std::borrow::Cow::Borrowed(b) => b.into(),
std::borrow::Cow::Owned(b) => b.into(),
})
} else if config.accept_jwts {
return Err(ConnInfoError::MissingCredentials(Credentials::BearerJwt));
} else {
return Err(ConnInfoError::MissingCredentials(Credentials::Password));
};
let endpoint = match connection_url.host() {
Some(url::Host::Domain(hostname)) => {
if let Some(tls) = tls {
endpoint_sni(hostname, &tls.common_names).ok_or(ConnInfoError::MalformedEndpoint)?
} else {
hostname
.split_once('.')
.map_or(hostname, |(prefix, _)| prefix)
.into()
}
}
Some(url::Host::Ipv4(_) | url::Host::Ipv6(_)) | None => {
return Err(ConnInfoError::MissingHostname);
}
};
ctx.set_endpoint_id(endpoint.clone());
let pairs = connection_url.query_pairs();
let mut options = Option::None;
let mut params = StartupMessageParams::default();
params.insert("user", &username);
params.insert("database", &dbname);
for (key, value) in pairs {
params.insert(&key, &value);
if key == "options" {
options = Some(NeonOptions::parse_options_raw(&value));
}
}
// check the URL that was used, for metrics
{
let host_endpoint = headers
// get the host header
.get("host")
// extract the domain
.and_then(|h| {
let (host, _port) = h.to_str().ok()?.split_once(':')?;
Some(host)
})
// get the endpoint prefix
.map(|h| h.split_once('.').map_or(h, |(prefix, _)| prefix));
let kind = if host_endpoint == Some(&*endpoint) {
SniKind::Sni
} else {
SniKind::NoSni
};
let protocol = ctx.protocol();
Metrics::get()
.proxy
.accepted_connections_by_sni
.inc(SniGroup { protocol, kind });
}
ctx.set_user_agent(
headers
.get(hyper::header::USER_AGENT)
.and_then(|h| h.to_str().ok())
.map(Into::into),
);
let user_info = ComputeUserInfo {
endpoint,
user: username,
options: options.unwrap_or_default(),
};
let conn_info = ConnInfo { user_info, dbname };
Ok(ConnInfoWithAuth { conn_info, auth })
}
pub(crate) async fn handle(
config: &'static ProxyConfig,
ctx: RequestContext,
@@ -390,12 +195,35 @@ pub(crate) async fn handle(
let line = get(db_error, |db| db.line().map(|l| l.to_string()));
let routine = get(db_error, |db| db.routine());
tracing::info!(
kind=error_kind.to_metric_label(),
error=%e,
msg=message,
"forwarding error to user"
);
match &e {
SqlOverHttpError::Postgres(e)
if e.as_db_error().is_some() && error_kind == ErrorKind::User =>
{
// this error contains too much info, and it's not an error we care about.
if tracing::enabled!(Level::DEBUG) {
tracing::debug!(
kind=error_kind.to_metric_label(),
error=%e,
msg=message,
"forwarding error to user"
);
} else {
tracing::info!(
kind = error_kind.to_metric_label(),
error = "bad query",
"forwarding error to user"
);
}
}
_ => {
tracing::info!(
kind=error_kind.to_metric_label(),
error=%e,
msg=message,
"forwarding error to user"
);
}
}
json_response(
e.get_http_status_code(),
@@ -460,7 +288,15 @@ impl ReportableError for SqlOverHttpError {
SqlOverHttpError::ConnInfo(e) => e.get_error_kind(),
SqlOverHttpError::ResponseTooLarge(_) => ErrorKind::User,
SqlOverHttpError::InvalidIsolationLevel => ErrorKind::User,
SqlOverHttpError::Postgres(p) => p.get_error_kind(),
// customer initiated SQL errors.
SqlOverHttpError::Postgres(p) => {
if p.as_db_error().is_some() {
ErrorKind::User
} else {
ErrorKind::Compute
}
}
// proxy initiated SQL errors.
SqlOverHttpError::InternalPostgres(p) => {
if p.as_db_error().is_some() {
ErrorKind::Service
@@ -468,6 +304,7 @@ impl ReportableError for SqlOverHttpError {
ErrorKind::Compute
}
}
// postgres returned a bad row format that we couldn't parse.
SqlOverHttpError::JsonConversion(_) => ErrorKind::Postgres,
SqlOverHttpError::Cancelled(c) => c.get_error_kind(),
}
@@ -509,45 +346,6 @@ impl HttpCodeError for SqlOverHttpError {
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum ReadPayloadError {
#[error("could not read the HTTP request body: {0}")]
Read(#[from] hyper::Error),
#[error("request is too large (max is {limit} bytes)")]
BodyTooLarge { limit: usize },
#[error("could not parse the HTTP request body: {0}")]
Parse(#[from] serde_json::Error),
}
impl From<ReadBodyError<hyper::Error>> for ReadPayloadError {
fn from(value: ReadBodyError<hyper::Error>) -> Self {
match value {
ReadBodyError::BodyTooLarge { limit } => Self::BodyTooLarge { limit },
ReadBodyError::Read(e) => Self::Read(e),
}
}
}
impl ReportableError for ReadPayloadError {
fn get_error_kind(&self) -> ErrorKind {
match self {
ReadPayloadError::Read(_) => ErrorKind::ClientDisconnect,
ReadPayloadError::BodyTooLarge { .. } => ErrorKind::User,
ReadPayloadError::Parse(_) => ErrorKind::User,
}
}
}
impl HttpCodeError for ReadPayloadError {
fn get_http_status_code(&self) -> StatusCode {
match self {
ReadPayloadError::Read(_) => StatusCode::BAD_REQUEST,
ReadPayloadError::BodyTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE,
ReadPayloadError::Parse(_) => StatusCode::BAD_REQUEST,
}
}
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum SqlOverHttpCancel {
#[error("query was cancelled")]
@@ -638,14 +436,7 @@ async fn handle_inner(
"handling interactive connection from client"
);
let conn_info = get_conn_info(
&config.authentication_config,
ctx,
request.headers(),
// todo: race condition?
// we're unlikely to change the common names.
config.tls_config.load().as_deref(),
)?;
let conn_info = get_conn_info(&config.authentication_config, ctx, None, request.headers())?;
info!(
user = conn_info.conn_info.user_info.user.as_str(),
"credentials"
@@ -731,9 +522,17 @@ async fn handle_db_inner(
ComputeCredentialKeys::JwtPayload(payload)
if backend.auth_backend.is_local_proxy() =>
{
let mut client = backend.connect_to_local_postgres(ctx, conn_info).await?;
let (cli_inner, _dsc) = client.client_inner();
cli_inner.set_jwt_session(&payload).await?;
#[cfg(feature = "testing")]
let disable_pg_session_jwt = config.disable_pg_session_jwt;
#[cfg(not(feature = "testing"))]
let disable_pg_session_jwt = false;
let mut client = backend
.connect_to_local_postgres(ctx, conn_info, disable_pg_session_jwt)
.await?;
if !disable_pg_session_jwt {
let (cli_inner, _dsc) = client.client_inner();
cli_inner.set_jwt_session(&payload).await?;
}
Client::Local(client)
}
_ => {
@@ -832,12 +631,6 @@ static HEADERS_TO_FORWARD: &[&HeaderName] = &[
&TXN_DEFERRABLE,
];
pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
let mut uuid = [0; uuid::fmt::Hyphenated::LENGTH];
HeaderValue::from_str(id.as_hyphenated().encode_lower(&mut uuid[..]))
.expect("uuid hyphenated format should be all valid header characters")
}
async fn handle_auth_broker_inner(
ctx: &RequestContext,
request: Request<Incoming>,
@@ -867,7 +660,7 @@ async fn handle_auth_broker_inner(
req = req.header(&NEON_REQUEST_ID, uuid_to_header_value(ctx.session_id()));
let req = req
.body(body)
.body(body.map_err(|e| e).boxed()) //TODO: is there a potential for a regression here?
.expect("all headers and params received via hyper should be valid for request");
// todo: map body to count egress
@@ -1103,7 +896,6 @@ async fn query_to_json<T: GenericClient>(
let columns_len = row_stream.statement.columns().len();
let mut fields = Vec::with_capacity(columns_len);
let mut types = Vec::with_capacity(columns_len);
for c in row_stream.statement.columns() {
fields.push(json!({
@@ -1115,8 +907,6 @@ async fn query_to_json<T: GenericClient>(
"dataTypeModifier": c.type_modifier(),
"format": "text",
}));
types.push(c.type_().clone());
}
let raw_output = parsed_headers.raw_output;
@@ -1138,7 +928,7 @@ async fn query_to_json<T: GenericClient>(
));
}
let row = pg_text_row_to_json(&row, &types, raw_output, array_mode)?;
let row = pg_text_row_to_json(&row, raw_output, array_mode)?;
rows.push(row);
// assumption: parsing pg text and converting to json takes CPU time.

Some files were not shown because too many files have changed in this diff.