Compare commits


1 Commit

Author SHA1 Message Date
Folke Behrens
ef737e7d7c proxy: add benchmark for custom json logging vs official fmt logger 2025-07-15 19:44:41 +02:00
148 changed files with 1957 additions and 8847 deletions
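The commit itself is small relative to this compare: it adds a Criterion benchmark to the proxy crate pitting a custom JSON log encoder against the off-the-shelf fmt logger (note `criterion` joining the proxy dependency list in the Cargo.lock hunk below). As a rough sketch of what a Criterion comparison of two logging paths can look like — the names, payloads, and encoders here are invented for illustration, not the actual benchmark:

```rust
use criterion::{Criterion, criterion_group, criterion_main};
use std::hint::black_box;

// Two hypothetical encoders for the same log record: a hand-rolled JSON
// writer that appends straight into a reused buffer, and a generic
// format!-based path standing in for an off-the-shelf fmt logger.
fn encode_custom_json(buf: &mut String, level: &str, msg: &str, port: u16) {
    buf.clear();
    buf.push_str("{\"level\":\"");
    buf.push_str(level);
    buf.push_str("\",\"msg\":\"");
    buf.push_str(msg);
    buf.push_str("\",\"port\":");
    buf.push_str(&port.to_string());
    buf.push('}');
}

fn bench_logging(c: &mut Criterion) {
    let mut g = c.benchmark_group("log_encoding");
    g.bench_function("custom_json", |b| {
        let mut buf = String::with_capacity(128);
        b.iter(|| {
            encode_custom_json(&mut buf, "info", "accepted connection", 5432);
            black_box(&buf);
        })
    });
    g.bench_function("fmt", |b| {
        b.iter(|| {
            // Allocates a fresh string through the formatting machinery each time.
            black_box(format!("{} {} port={}", "INFO", "accepted connection", 5432))
        })
    });
    g.finish();
}

criterion_group!(benches, bench_logging);
criterion_main!(benches);
```

Wired up as a `[[bench]]` target with `harness = false` (the same arrangement visible in the removed alloc-metrics Cargo.toml near the end of this compare), such a benchmark runs under `cargo bench`.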

View File

@@ -30,7 +30,6 @@ workspace-members = [
"vm_monitor",
# All of these exist in libs and are not usually built independently.
# Putting workspace hack there adds a bottleneck for cargo builds.
"alloc-metrics",
"compute_api",
"consumption_metrics",
"desim",

View File

@@ -181,8 +181,6 @@ runs:
# Ref https://github.com/neondatabase/neon/issues/4540
# cov_prefix=(scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/coverage run)
cov_prefix=()
# Explicitly set LLVM_PROFILE_FILE to /dev/null to avoid writing *.profraw files
export LLVM_PROFILE_FILE=/dev/null
else
cov_prefix=()
fi

View File

@@ -87,27 +87,22 @@ jobs:
uses: ./.github/workflows/build-build-tools-image.yml
secrets: inherit
lint-yamls:
needs: [ meta, check-permissions, build-build-tools-image ]
lint-openapi-spec:
runs-on: ubuntu-22.04
needs: [ meta, check-permissions ]
# We do need to run this in `.*-rc-pr` because of hotfixes.
if: ${{ contains(fromJSON('["pr", "push-main", "storage-rc-pr", "proxy-rc-pr", "compute-rc-pr"]'), needs.meta.outputs.run-kind) }}
runs-on: [ self-hosted, small ]
container:
image: ${{ needs.build-build-tools-image.outputs.image }}
credentials:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
options: --init
steps:
- name: Harden the runner (Audit all outbound calls)
uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0
with:
egress-policy: audit
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- run: make -C compute manifest-schema-validation
- uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- run: make lint-openapi-spec
check-codestyle-python:
@@ -222,6 +217,28 @@ jobs:
build-tools-image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm
secrets: inherit
validate-compute-manifest:
runs-on: ubuntu-22.04
needs: [ meta, check-permissions ]
# We do need to run this in `.*-rc-pr` because of hotfixes.
if: ${{ contains(fromJSON('["pr", "push-main", "storage-rc-pr", "proxy-rc-pr", "compute-rc-pr"]'), needs.meta.outputs.run-kind) }}
steps:
- name: Harden the runner (Audit all outbound calls)
uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0
with:
egress-policy: audit
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Node.js
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
with:
node-version: '24'
- name: Validate manifest against schema
run: |
make -C compute manifest-schema-validation
build-and-test-locally:
needs: [ meta, build-build-tools-image ]
# We do need to run this in `.*-rc-pr` because of hotfixes.

.gitignore vendored (3 changes)
View File

@@ -29,6 +29,3 @@ docker-compose/docker-compose-parallel.yml
# pgindent typedef lists
*.list
# Node
**/node_modules/

Cargo.lock generated (114 changes)
View File

@@ -61,17 +61,6 @@ dependencies = [
"equator",
]
[[package]]
name = "alloc-metrics"
version = "0.1.0"
dependencies = [
"criterion",
"measured",
"metrics",
"thread_local",
"tikv-jemallocator",
]
[[package]]
name = "allocator-api2"
version = "0.2.16"
@@ -1883,7 +1872,6 @@ dependencies = [
"diesel_derives",
"itoa",
"serde_json",
"uuid",
]
[[package]]
@@ -2545,18 +2533,6 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "getrandom"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
dependencies = [
"cfg-if",
"libc",
"r-efi",
"wasi 0.14.2+wasi-0.2.4",
]
[[package]]
name = "gettid"
version = "0.1.3"
@@ -3630,9 +3606,9 @@ checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104"
[[package]]
name = "lock_api"
version = "0.4.13"
version = "0.4.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765"
checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16"
dependencies = [
"autocfg",
"scopeguard",
@@ -3782,7 +3758,7 @@ dependencies = [
"procfs",
"prometheus",
"rand 0.8.5",
"rand_distr 0.4.3",
"rand_distr",
"twox-hash",
]
@@ -3870,12 +3846,7 @@ checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
name = "neon-shmem"
version = "0.1.0"
dependencies = [
"libc",
"lock_api",
"nix 0.30.1",
"rand 0.9.1",
"rand_distr 0.5.1",
"rustc-hash 2.1.1",
"tempfile",
"thiserror 1.0.69",
"workspace_hack",
@@ -5312,7 +5283,6 @@ name = "proxy"
version = "0.1.0"
dependencies = [
"ahash",
"alloc-metrics",
"anyhow",
"arc-swap",
"assert-json-diff",
@@ -5333,6 +5303,7 @@ dependencies = [
"clashmap",
"compute_api",
"consumption_metrics",
"criterion",
"ecdsa 0.16.9",
"ed25519-dalek",
"env_logger",
@@ -5377,7 +5348,7 @@ dependencies = [
"postgres_backend",
"pq_proto",
"rand 0.8.5",
"rand_distr 0.4.3",
"rand_distr",
"rcgen",
"redis",
"regex",
@@ -5388,7 +5359,7 @@ dependencies = [
"reqwest-tracing",
"rsa",
"rstest",
"rustc-hash 2.1.1",
"rustc-hash 1.1.0",
"rustls 0.23.27",
"rustls-native-certs 0.8.0",
"rustls-pemfile 2.1.1",
@@ -5481,12 +5452,6 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "r-efi"
version = "5.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
[[package]]
name = "rand"
version = "0.7.3"
@@ -5511,16 +5476,6 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "rand"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
dependencies = [
"rand_chacha 0.9.0",
"rand_core 0.9.3",
]
[[package]]
name = "rand_chacha"
version = "0.2.2"
@@ -5541,16 +5496,6 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "rand_chacha"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core 0.9.3",
]
[[package]]
name = "rand_core"
version = "0.5.1"
@@ -5569,15 +5514,6 @@ dependencies = [
"getrandom 0.2.11",
]
[[package]]
name = "rand_core"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
dependencies = [
"getrandom 0.3.3",
]
[[package]]
name = "rand_distr"
version = "0.4.3"
@@ -5588,16 +5524,6 @@ dependencies = [
"rand 0.8.5",
]
[[package]]
name = "rand_distr"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463"
dependencies = [
"num-traits",
"rand 0.9.1",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
@@ -6279,7 +6205,6 @@ dependencies = [
"itertools 0.10.5",
"jsonwebtoken",
"metrics",
"nix 0.30.1",
"once_cell",
"pageserver_api",
"parking_lot 0.12.1",
@@ -7008,7 +6933,6 @@ dependencies = [
"tokio-util",
"tracing",
"utils",
"uuid",
"workspace_hack",
]
@@ -7344,10 +7268,12 @@ dependencies = [
[[package]]
name = "thread_local"
version = "1.1.9"
source = "git+https://github.com/conradludgate/thread_local-rs?branch=no-tls-destructor-get#f9ca3d375745c14a632ae3ffe6a7a646dc8421a0"
version = "1.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152"
dependencies = [
"cfg-if",
"once_cell",
]
[[package]]
@@ -8280,7 +8206,6 @@ dependencies = [
"tracing-error",
"tracing-subscriber",
"tracing-utils",
"uuid",
"walkdir",
]
@@ -8423,15 +8348,6 @@ version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wasi"
version = "0.14.2+wasi-0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3"
dependencies = [
"wit-bindgen-rt",
]
[[package]]
name = "wasite"
version = "0.1.0"
@@ -8789,15 +8705,6 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "wit-bindgen-rt"
version = "0.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
dependencies = [
"bitflags 2.8.0",
]
[[package]]
name = "workspace_hack"
version = "0.1.0"
@@ -8900,6 +8807,7 @@ dependencies = [
"tracing-log",
"tracing-subscriber",
"url",
"uuid",
"zeroize",
"zstd",
"zstd-safe",

View File

@@ -130,7 +130,6 @@ jemalloc_pprof = { version = "0.7", features = ["symbolize", "flamegraph"] }
jsonwebtoken = "9"
lasso = "0.7"
libc = "0.2"
lock_api = "0.4.13"
md5 = "0.7.0"
measured = { version = "0.0.22", features=["lasso"] }
measured-process = { version = "0.0.22" }
@@ -166,7 +165,7 @@ reqwest-middleware = "0.4"
reqwest-retry = "0.7"
routerify = "3"
rpds = "0.13"
rustc-hash = "2.1.1"
rustc-hash = "1.1.0"
rustls = { version = "0.23.16", default-features = false }
rustls-pemfile = "2"
rustls-pki-types = "1.11"
@@ -195,7 +194,6 @@ sync_wrapper = "0.1.2"
tar = "0.4"
test-context = "0.3"
thiserror = "1.0"
thread_local = "1.1.9"
tikv-jemallocator = { version = "0.6", features = ["profiling", "stats", "unprefixed_malloc_on_supported_platforms"] }
tikv-jemalloc-ctl = { version = "0.6", features = ["stats"] }
tokio = { version = "1.43.1", features = ["macros"] }
@@ -254,7 +252,6 @@ azure_storage = { git = "https://github.com/neondatabase/azure-sdk-for-rust.git"
azure_storage_blobs = { git = "https://github.com/neondatabase/azure-sdk-for-rust.git", branch = "neon", default-features = false, features = ["enable_reqwest_rustls"] }
## Local libraries
alloc-metrics = { version = "0.1", path = "./libs/alloc-metrics/" }
compute_api = { version = "0.1", path = "./libs/compute_api/" }
consumption_metrics = { version = "0.1", path = "./libs/consumption_metrics/" }
desim = { version = "0.1", path = "./libs/desim" }
@@ -304,9 +301,6 @@ tonic-build = "0.13.1"
# Needed to get `tokio-postgres-rustls` to depend on our fork.
tokio-postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon" }
# Needed to fix a bug in alloc-metrics
thread_local = { git = "https://github.com/conradludgate/thread_local-rs", branch = "no-tls-destructor-get" }
################# Binary contents sections
[profile.release]

View File

@@ -2,7 +2,7 @@ ROOT_PROJECT_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
# Where to install Postgres, default is ./pg_install, maybe useful for package
# managers.
POSTGRES_INSTALL_DIR ?= $(ROOT_PROJECT_DIR)/pg_install
POSTGRES_INSTALL_DIR ?= $(ROOT_PROJECT_DIR)/pg_install/
# Supported PostgreSQL versions
POSTGRES_VERSIONS = v17 v16 v15 v14
@@ -14,7 +14,7 @@ POSTGRES_VERSIONS = v17 v16 v15 v14
# it is derived from BUILD_TYPE.
# All intermediate build artifacts are stored here.
BUILD_DIR := $(ROOT_PROJECT_DIR)/build
BUILD_DIR := build
ICU_PREFIX_DIR := /usr/local/icu
@@ -212,7 +212,7 @@ neon-pgindent: postgres-v17-pg-bsd-indent neon-pg-ext-v17
FIND_TYPEDEF=$(ROOT_PROJECT_DIR)/vendor/postgres-v17/src/tools/find_typedef \
INDENT=$(BUILD_DIR)/v17/src/tools/pg_bsd_indent/pg_bsd_indent \
PGINDENT_SCRIPT=$(ROOT_PROJECT_DIR)/vendor/postgres-v17/src/tools/pgindent/pgindent \
-C $(BUILD_DIR)/pgxn-v17/neon \
-C $(BUILD_DIR)/neon-v17 \
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile pgindent
@@ -220,15 +220,11 @@ neon-pgindent: postgres-v17-pg-bsd-indent neon-pg-ext-v17
setup-pre-commit-hook:
ln -s -f $(ROOT_PROJECT_DIR)/pre-commit.py .git/hooks/pre-commit
build-tools/node_modules: build-tools/package.json
cd build-tools && $(if $(CI),npm ci,npm install)
touch build-tools/node_modules
.PHONY: lint-openapi-spec
lint-openapi-spec: build-tools/node_modules
lint-openapi-spec:
# operation-2xx-response: pageserver timeline delete returns 404 on success
find . -iname "openapi_spec.y*ml" -exec\
npx --prefix=build-tools/ redocly\
docker run --rm -v ${PWD}:/spec ghcr.io/redocly/cli:1.34.4\
--skip-rule=operation-operationId --skip-rule=operation-summary --extends=minimal\
--skip-rule=no-server-example.com --skip-rule=operation-2xx-response\
lint {} \+

View File

@@ -188,12 +188,6 @@ RUN curl -fsSL 'https://apt.llvm.org/llvm-snapshot.gpg.key' | apt-key add - \
&& bash -c 'for f in /usr/bin/clang*-${LLVM_VERSION} /usr/bin/llvm*-${LLVM_VERSION}; do ln -s "${f}" "${f%-${LLVM_VERSION}}"; done' \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
# Install node
ENV NODE_VERSION=24
RUN curl -fsSL https://deb.nodesource.com/setup_${NODE_VERSION}.x | bash - \
&& apt install -y nodejs \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
# Install docker
RUN curl -fsSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /usr/share/keyrings/docker-archive-keyring.gpg \
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/debian ${DEBIAN_VERSION} stable" > /etc/apt/sources.list.d/docker.list \
@@ -317,14 +311,14 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux
. "$HOME/.cargo/env" && \
cargo --version && rustup --version && \
rustup component add llvm-tools rustfmt clippy && \
cargo install rustfilt --locked --version ${RUSTFILT_VERSION} && \
cargo install cargo-hakari --locked --version ${CARGO_HAKARI_VERSION} && \
cargo install cargo-deny --locked --version ${CARGO_DENY_VERSION} && \
cargo install cargo-hack --locked --version ${CARGO_HACK_VERSION} && \
cargo install cargo-nextest --locked --version ${CARGO_NEXTEST_VERSION} && \
cargo install cargo-chef --locked --version ${CARGO_CHEF_VERSION} && \
cargo install diesel_cli --locked --version ${CARGO_DIESEL_CLI_VERSION} \
--features postgres-bundled --no-default-features && \
cargo install rustfilt --version ${RUSTFILT_VERSION} --locked && \
cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} --locked && \
cargo install cargo-deny --version ${CARGO_DENY_VERSION} --locked && \
cargo install cargo-hack --version ${CARGO_HACK_VERSION} --locked && \
cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} --locked && \
cargo install cargo-chef --version ${CARGO_CHEF_VERSION} --locked && \
cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} --locked \
--features postgres-bundled --no-default-features && \
rm -rf /home/nonroot/.cargo/registry && \
rm -rf /home/nonroot/.cargo/git

File diff suppressed because it is too large

View File

@@ -1,8 +0,0 @@
{
"name": "build-tools",
"private": true,
"devDependencies": {
"@redocly/cli": "1.34.4",
"@sourcemeta/jsonschema": "10.0.0"
}
}

View File

@@ -50,9 +50,9 @@ jsonnetfmt-format:
jsonnetfmt --in-place $(jsonnet_files)
.PHONY: manifest-schema-validation
manifest-schema-validation: ../build-tools/node_modules
npx --prefix=../build-tools/ jsonschema validate -d https://json-schema.org/draft/2020-12/schema manifest.schema.json manifest.yaml
manifest-schema-validation: node_modules
node_modules/.bin/jsonschema validate -d https://json-schema.org/draft/2020-12/schema manifest.schema.json manifest.yaml
../build-tools/node_modules: ../build-tools/package.json
cd ../build-tools && $(if $(CI),npm ci,npm install)
touch ../build-tools/node_modules
node_modules: package.json
npm install
touch node_modules

View File

@@ -170,29 +170,7 @@ RUN case $DEBIAN_VERSION in \
FROM build-deps AS pg-build
ARG PG_VERSION
COPY vendor/postgres-${PG_VERSION:?} postgres
COPY compute/patches/postgres_fdw.patch .
COPY compute/patches/pg_stat_statements_pg14-16.patch .
COPY compute/patches/pg_stat_statements_pg17.patch .
RUN cd postgres && \
# Apply patches to some contrib extensions
# For example, we need to grant EXECUTE on pg_stat_statements_reset() to {privileged_role_name}.
# In vanilla Postgres this function is limited to the superuser role.
# In Neon, the {privileged_role_name} role is not a superuser but replaces superuser in some cases.
# We could add the grant statements to the Postgres repository itself, but that would be hard to
# maintain whenever we pick up a new Postgres version, and we want to limit the changes in our
# Postgres fork, so we do it here.
case "${PG_VERSION}" in \
"v14" | "v15" | "v16") \
patch -p1 < /pg_stat_statements_pg14-16.patch; \
;; \
"v17") \
patch -p1 < /pg_stat_statements_pg17.patch; \
;; \
*) \
# So that we do not forget to migrate patches to the next major version
echo "No contrib patches for this PostgreSQL version" && exit 1;; \
esac && \
patch -p1 < /postgres_fdw.patch && \
export CONFIGURE_CMD="./configure CFLAGS='-O2 -g3 -fsigned-char' --enable-debug --with-openssl --with-uuid=ossp \
--with-icu --with-libxml --with-libxslt --with-lz4" && \
if [ "${PG_VERSION:?}" != "v14" ]; then \
@@ -206,6 +184,8 @@ RUN cd postgres && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/autoinc.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/dblink.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/postgres_fdw.control && \
file=/usr/local/pgsql/share/extension/postgres_fdw--1.0.sql && [ -e $file ] && \
echo 'GRANT USAGE ON FOREIGN DATA WRAPPER postgres_fdw TO neon_superuser;' >> $file && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/bloom.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/earthdistance.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/insert_username.control && \
@@ -215,7 +195,34 @@ RUN cd postgres && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgrowlocks.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgstattuple.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/refint.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/xml2.control
echo 'trusted = true' >> /usr/local/pgsql/share/extension/xml2.control && \
# We need to grant EXECUTE on pg_stat_statements_reset() to neon_superuser.
# In vanilla postgres this function is limited to the superuser role.
# In neon, the neon_superuser role is not a superuser but replaces superuser in some cases.
# We could add the grant statements to the postgres repository itself, but that would be hard to
# maintain whenever we pick up a new postgres version, and we want to limit the changes in our
# postgres fork, so we do it here.
for file in /usr/local/pgsql/share/extension/pg_stat_statements--*.sql; do \
filename=$(basename "$file"); \
# Note that there are no downgrade scripts for pg_stat_statements, so we \
# don't have to modify any downgrade paths or (much) older versions: we only \
# have to make sure every creation of the pg_stat_statements_reset function \
# also adds execute permissions to the neon_superuser.
case $filename in \
pg_stat_statements--1.4.sql) \
# pg_stat_statements_reset is first created with 1.4
echo 'GRANT EXECUTE ON FUNCTION pg_stat_statements_reset() TO neon_superuser;' >> $file; \
;; \
pg_stat_statements--1.6--1.7.sql) \
# Then with the 1.6-1.7 migration it is re-created with a new signature, thus add the permissions back
echo 'GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) TO neon_superuser;' >> $file; \
;; \
pg_stat_statements--1.10--1.11.sql) \
# Then with the 1.10-1.11 migration it is re-created with a new signature again, thus add the permissions back
echo 'GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint, boolean) TO neon_superuser;' >> $file; \
;; \
esac; \
done;
# Set PATH for all the subsequent build steps
ENV PATH="/usr/local/pgsql/bin:$PATH"
@@ -1517,7 +1524,7 @@ WORKDIR /ext-src
COPY compute/patches/pg_duckdb_v031.patch .
COPY compute/patches/duckdb_v120.patch .
# pg_duckdb build requires source dir to be a git repo to get submodules
# allow {privileged_role_name} to execute some functions that in pg_duckdb are available to superuser only:
# allow neon_superuser to execute some functions that in pg_duckdb are available to superuser only:
# - extension management function duckdb.install_extension()
# - access to duckdb.extensions table and its sequence
RUN git clone --depth 1 --branch v0.3.1 https://github.com/duckdb/pg_duckdb.git pg_duckdb-src && \

compute/package.json Normal file (7 changes)
View File

@@ -0,0 +1,7 @@
{
"name": "neon-compute",
"private": true,
"dependencies": {
"@sourcemeta/jsonschema": "9.3.4"
}
}

View File

@@ -1,26 +1,22 @@
diff --git a/sql/anon.sql b/sql/anon.sql
index 0cdc769..5eab1d6 100644
index 0cdc769..b450327 100644
--- a/sql/anon.sql
+++ b/sql/anon.sql
@@ -1141,3 +1141,19 @@ $$
@@ -1141,3 +1141,15 @@ $$
-- TODO : https://en.wikipedia.org/wiki/L-diversity
-- TODO : https://en.wikipedia.org/wiki/T-closeness
+
+-- NEON Patches
+
+GRANT ALL ON SCHEMA anon to neon_superuser;
+GRANT ALL ON ALL TABLES IN SCHEMA anon TO neon_superuser;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT ALL ON SCHEMA anon to %I', privileged_role_name);
+ EXECUTE format('GRANT ALL ON ALL TABLES IN SCHEMA anon TO %I', privileged_role_name);
+
+ IF current_setting('server_version_num')::int >= 150000 THEN
+ EXECUTE format('GRANT SET ON PARAMETER anon.transparent_dynamic_masking TO %I', privileged_role_name);
+ END IF;
+ IF current_setting('server_version_num')::int >= 150000 THEN
+ GRANT SET ON PARAMETER anon.transparent_dynamic_masking TO neon_superuser;
+ END IF;
+END $$;
diff --git a/sql/init.sql b/sql/init.sql
index 7da6553..9b6164b 100644

View File

@@ -21,21 +21,13 @@ index 3235cc8..6b892bc 100644
include Makefile.global
diff --git a/sql/pg_duckdb--0.2.0--0.3.0.sql b/sql/pg_duckdb--0.2.0--0.3.0.sql
index d777d76..3b54396 100644
index d777d76..af60106 100644
--- a/sql/pg_duckdb--0.2.0--0.3.0.sql
+++ b/sql/pg_duckdb--0.2.0--0.3.0.sql
@@ -1056,3 +1056,14 @@ GRANT ALL ON FUNCTION duckdb.cache(TEXT, TEXT) TO PUBLIC;
@@ -1056,3 +1056,6 @@ GRANT ALL ON FUNCTION duckdb.cache(TEXT, TEXT) TO PUBLIC;
GRANT ALL ON FUNCTION duckdb.cache_info() TO PUBLIC;
GRANT ALL ON FUNCTION duckdb.cache_delete(TEXT) TO PUBLIC;
GRANT ALL ON PROCEDURE duckdb.recycle_ddb() TO PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT ALL ON FUNCTION duckdb.install_extension(TEXT) TO %I', privileged_role_name);
+ EXECUTE format('GRANT ALL ON TABLE duckdb.extensions TO %I', privileged_role_name);
+ EXECUTE format('GRANT ALL ON SEQUENCE duckdb.extensions_table_seq TO %I', privileged_role_name);
+END $$;
+GRANT ALL ON FUNCTION duckdb.install_extension(TEXT) TO neon_superuser;
+GRANT ALL ON TABLE duckdb.extensions TO neon_superuser;
+GRANT ALL ON SEQUENCE duckdb.extensions_table_seq TO neon_superuser;

View File

@@ -1,34 +0,0 @@
diff --git a/contrib/pg_stat_statements/pg_stat_statements--1.4.sql b/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
index 58cdf600fce..8be57a996f6 100644
--- a/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
+++ b/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
@@ -46,3 +46,12 @@ GRANT SELECT ON pg_stat_statements TO PUBLIC;
-- Don't want this to be available to non-superusers.
REVOKE ALL ON FUNCTION pg_stat_statements_reset() FROM PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT EXECUTE ON FUNCTION pg_stat_statements_reset() TO %I', privileged_role_name);
+END $$;
diff --git a/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql b/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
index 6fc3fed4c93..256345a8f79 100644
--- a/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
+++ b/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
@@ -20,3 +20,12 @@ LANGUAGE C STRICT PARALLEL SAFE;
-- Don't want this to be available to non-superusers.
REVOKE ALL ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) FROM PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) TO %I', privileged_role_name);
+END $$;

View File

@@ -1,52 +0,0 @@
diff --git a/contrib/pg_stat_statements/pg_stat_statements--1.10--1.11.sql b/contrib/pg_stat_statements/pg_stat_statements--1.10--1.11.sql
index 0bb2c397711..32764db1d8b 100644
--- a/contrib/pg_stat_statements/pg_stat_statements--1.10--1.11.sql
+++ b/contrib/pg_stat_statements/pg_stat_statements--1.10--1.11.sql
@@ -80,3 +80,12 @@ LANGUAGE C STRICT PARALLEL SAFE;
-- Don't want this to be available to non-superusers.
REVOKE ALL ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint, boolean) FROM PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint, boolean) TO %I', privileged_role_name);
+END $$;
\ No newline at end of file
diff --git a/contrib/pg_stat_statements/pg_stat_statements--1.4.sql b/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
index 58cdf600fce..8be57a996f6 100644
--- a/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
+++ b/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
@@ -46,3 +46,12 @@ GRANT SELECT ON pg_stat_statements TO PUBLIC;
-- Don't want this to be available to non-superusers.
REVOKE ALL ON FUNCTION pg_stat_statements_reset() FROM PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT EXECUTE ON FUNCTION pg_stat_statements_reset() TO %I', privileged_role_name);
+END $$;
diff --git a/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql b/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
index 6fc3fed4c93..256345a8f79 100644
--- a/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
+++ b/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
@@ -20,3 +20,12 @@ LANGUAGE C STRICT PARALLEL SAFE;
-- Don't want this to be available to non-superusers.
REVOKE ALL ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) FROM PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) TO %I', privileged_role_name);
+END $$;

View File

@@ -1,17 +0,0 @@
diff --git a/contrib/postgres_fdw/postgres_fdw--1.0.sql b/contrib/postgres_fdw/postgres_fdw--1.0.sql
index a0f0fc1bf45..ee077f2eea6 100644
--- a/contrib/postgres_fdw/postgres_fdw--1.0.sql
+++ b/contrib/postgres_fdw/postgres_fdw--1.0.sql
@@ -16,3 +16,12 @@ LANGUAGE C STRICT;
CREATE FOREIGN DATA WRAPPER postgres_fdw
HANDLER postgres_fdw_handler
VALIDATOR postgres_fdw_validator;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT USAGE ON FOREIGN DATA WRAPPER postgres_fdw TO %I', privileged_role_name);
+END $$;

View File

@@ -87,14 +87,6 @@ struct Cli {
#[arg(short = 'C', long, value_name = "DATABASE_URL")]
pub connstr: String,
#[arg(
long,
default_value = "neon_superuser",
value_name = "PRIVILEGED_ROLE_NAME",
value_parser = Self::parse_privileged_role_name
)]
pub privileged_role_name: String,
#[cfg(target_os = "linux")]
#[arg(long, default_value = "neon-postgres")]
pub cgroup: String,
@@ -157,21 +149,6 @@ impl Cli {
Ok(url)
}
/// For simplicity, we do not escape `privileged_role_name` anywhere in the code.
/// Since it's a system role, which we fully control, that's fine. Still, let's
/// validate it to avoid any surprises.
fn parse_privileged_role_name(value: &str) -> Result<String> {
use regex::Regex;
let pattern = Regex::new(r"^[a-z_]+$").unwrap();
if !pattern.is_match(value) {
bail!("--privileged-role-name can only contain lowercase letters and underscores")
}
Ok(value.to_string())
}
}
fn main() -> Result<()> {
@@ -201,7 +178,6 @@ fn main() -> Result<()> {
ComputeNodeParams {
compute_id: cli.compute_id,
connstr,
privileged_role_name: cli.privileged_role_name.clone(),
pgdata: cli.pgdata.clone(),
pgbin: cli.pgbin.clone(),
pgversion: get_pg_version_string(&cli.pgbin),
@@ -351,49 +327,4 @@ mod test {
])
.expect_err("URL parameters are not allowed");
}
#[test]
fn verify_privileged_role_name() {
// Valid name
let cli = Cli::parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--privileged-role-name",
"my_superuser",
]);
assert_eq!(cli.privileged_role_name, "my_superuser");
// Invalid names
Cli::try_parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--privileged-role-name",
"NeonSuperuser",
])
.expect_err("uppercase letters are not allowed");
Cli::try_parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--privileged-role-name",
"$'neon_superuser",
])
.expect_err("special characters are not allowed");
Cli::try_parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--privileged-role-name",
"",
])
.expect_err("empty name is not allowed");
}
}

View File

@@ -74,20 +74,12 @@ const DEFAULT_INSTALLED_EXTENSIONS_COLLECTION_INTERVAL: u64 = 3600;
/// Static configuration params that don't change after startup. These mostly
/// come from the CLI args, or are derived from them.
#[derive(Clone, Debug)]
pub struct ComputeNodeParams {
/// The ID of the compute
pub compute_id: String,
/// Url type maintains proper escaping
// Url type maintains proper escaping
pub connstr: url::Url,
/// The name of the 'weak' superuser role, which we give to the users.
/// It follows the allow list approach, i.e., we take a standard role
/// and grant it extra permissions with explicit GRANTs here and there,
/// and core patches.
pub privileged_role_name: String,
pub resize_swap_on_bind: bool,
pub set_disk_quota_for_fs: Option<String>,
@@ -1397,7 +1389,6 @@ impl ComputeNode {
self.create_pgdata()?;
config::write_postgres_conf(
pgdata_path,
&self.params,
&pspec.spec,
self.params.internal_http_port,
tls_config,
@@ -1746,7 +1737,6 @@ impl ComputeNode {
}
// Run migrations separately to not hold up cold starts
let params = self.params.clone();
tokio::spawn(async move {
let mut conf = conf.as_ref().clone();
conf.application_name("compute_ctl:migrations");
@@ -1758,7 +1748,7 @@ impl ComputeNode {
eprintln!("connection error: {e}");
}
});
if let Err(e) = handle_migrations(params, &mut client).await {
if let Err(e) = handle_migrations(&mut client).await {
error!("Failed to run migrations: {}", e);
}
}
@@ -1837,7 +1827,6 @@ impl ComputeNode {
let pgdata_path = Path::new(&self.params.pgdata);
config::write_postgres_conf(
pgdata_path,
&self.params,
&spec,
self.params.internal_http_port,
tls_config,
@@ -2450,31 +2439,14 @@ LIMIT 100",
pub fn spawn_lfc_offload_task(self: &Arc<Self>, interval: Duration) {
self.terminate_lfc_offload_task();
let secs = interval.as_secs();
info!("spawning lfc offload worker with {secs}s interval");
let this = self.clone();
info!("spawning LFC offload worker with {secs}s interval");
let handle = spawn(async move {
let mut interval = time::interval(interval);
interval.tick().await; // returns immediately
loop {
interval.tick().await;
let prewarm_state = this.state.lock().unwrap().lfc_prewarm_state.clone();
// Do not offload LFC state if we are currently prewarming or any issue occurred.
// If we'd do that, we might override the LFC state in endpoint storage with some
// incomplete state. Imagine a situation:
// 1. Endpoint started with `autoprewarm: true`
// 2. While prewarming is not completed, we upload the new incomplete state
// 3. Compute gets interrupted and restarts
// 4. We start again and try to prewarm with the state from 2. instead of the previous complete state
if matches!(
prewarm_state,
LfcPrewarmState::Completed
| LfcPrewarmState::NotPrewarmed
| LfcPrewarmState::Skipped
) {
this.offload_lfc_async().await;
}
this.offload_lfc_async().await;
}
});
*self.lfc_offload_task.lock().unwrap() = Some(handle);

View File

@@ -89,7 +89,7 @@ impl ComputeNode {
self.state.lock().unwrap().lfc_offload_state.clone()
}
/// If there is a prewarm request ongoing, return `false`, `true` otherwise.
/// If there is a prewarm request ongoing, return false, true otherwise
pub fn prewarm_lfc(self: &Arc<Self>, from_endpoint: Option<String>) -> bool {
{
let state = &mut self.state.lock().unwrap().lfc_prewarm_state;
@@ -101,25 +101,15 @@ impl ComputeNode {
let cloned = self.clone();
spawn(async move {
let state = match cloned.prewarm_impl(from_endpoint).await {
Ok(true) => LfcPrewarmState::Completed,
Ok(false) => {
info!(
"skipping LFC prewarm because LFC state is not found in endpoint storage"
);
LfcPrewarmState::Skipped
}
Err(err) => {
crate::metrics::LFC_PREWARM_ERRORS.inc();
error!(%err, "could not prewarm LFC");
LfcPrewarmState::Failed {
error: err.to_string(),
}
}
let Err(err) = cloned.prewarm_impl(from_endpoint).await else {
cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Completed;
return;
};
crate::metrics::LFC_PREWARM_ERRORS.inc();
error!(%err, "prewarming lfc");
cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Failed {
error: err.to_string(),
};
cloned.state.lock().unwrap().lfc_prewarm_state = state;
});
true
}
@@ -130,21 +120,15 @@ impl ComputeNode {
EndpointStoragePair::from_spec_and_endpoint(state.pspec.as_ref().unwrap(), from_endpoint)
}
/// Request LFC state from endpoint storage and load corresponding pages into Postgres.
/// Returns a result with `false` if the LFC state is not found in endpoint storage.
async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<bool> {
async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(from_endpoint)?;
info!(%url, "requesting LFC state from endpoint storage");
let request = Client::new().get(&url).bearer_auth(token);
let res = request.send().await.context("querying endpoint storage")?;
let status = res.status();
match status {
StatusCode::OK => (),
StatusCode::NOT_FOUND => {
return Ok(false);
}
_ => bail!("{status} querying endpoint storage"),
if status != StatusCode::OK {
bail!("{status} querying endpoint storage")
}
let mut uncompressed = Vec::new();
@@ -157,8 +141,7 @@ impl ComputeNode {
.await
.context("decoding LFC state")?;
let uncompressed_len = uncompressed.len();
info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}, loading into Postgres");
info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}, loading into postgres");
ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
.await
@@ -166,9 +149,7 @@ impl ComputeNode {
.query_one("select neon.prewarm_local_cache($1)", &[&uncompressed])
.await
.context("loading LFC state into postgres")
.map(|_| ())?;
Ok(true)
.map(|_| ())
}
/// If offload request is ongoing, return false, true otherwise
@@ -196,14 +177,12 @@ impl ComputeNode {
async fn offload_lfc_with_state_update(&self) {
crate::metrics::LFC_OFFLOADS.inc();
let Err(err) = self.offload_lfc_impl().await else {
self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Completed;
return;
};
crate::metrics::LFC_OFFLOAD_ERRORS.inc();
error!(%err, "could not offload LFC state to endpoint storage");
error!(%err, "offloading lfc");
self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Failed {
error: err.to_string(),
};
@@ -211,7 +190,7 @@ impl ComputeNode {
async fn offload_lfc_impl(&self) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?;
info!(%url, "requesting LFC state from Postgres");
info!(%url, "requesting LFC state from postgres");
let mut compressed = Vec::new();
ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
@@ -226,17 +205,13 @@ impl ComputeNode {
.read_to_end(&mut compressed)
.await
.context("compressing LFC state")?;
let compressed_len = compressed.len();
info!(%url, "downloaded LFC state, compressed size {compressed_len}, writing to endpoint storage");
let request = Client::new().put(url).bearer_auth(token).body(compressed);
match request.send().await {
Ok(res) if res.status() == StatusCode::OK => Ok(()),
Ok(res) => bail!(
"Request to endpoint storage failed with status: {}",
res.status()
),
Ok(res) => bail!("Error writing to endpoint storage: {}", res.status()),
Err(err) => Err(err).context("writing to endpoint storage"),
}
}

View File

@@ -9,7 +9,6 @@ use std::path::Path;
use compute_api::responses::TlsConfig;
use compute_api::spec::{ComputeAudit, ComputeMode, ComputeSpec, GenericOption};
use crate::compute::ComputeNodeParams;
use crate::pg_helpers::{
GenericOptionExt, GenericOptionsSearch, PgOptionsSerialize, escape_conf_value,
};
@@ -42,7 +41,6 @@ pub fn line_in_file(path: &Path, line: &str) -> Result<bool> {
/// Create or completely rewrite configuration file specified by `path`
pub fn write_postgres_conf(
pgdata_path: &Path,
params: &ComputeNodeParams,
spec: &ComputeSpec,
extension_server_port: u16,
tls_config: &Option<TlsConfig>,
@@ -56,15 +54,14 @@ pub fn write_postgres_conf(
writeln!(file, "{conf}")?;
}
// Stripe size GUC should be defined prior to connection string
if let Some(stripe_size) = spec.shard_stripe_size {
writeln!(file, "neon.stripe_size={stripe_size}")?;
}
// Add options for connecting to storage
writeln!(file, "# Neon storage settings")?;
if let Some(s) = &spec.pageserver_connstring {
writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(s))?;
}
if let Some(stripe_size) = spec.shard_stripe_size {
writeln!(file, "neon.stripe_size={stripe_size}")?;
}
if !spec.safekeeper_connstrings.is_empty() {
let mut neon_safekeepers_value = String::new();
tracing::info!(
@@ -164,12 +161,6 @@ pub fn write_postgres_conf(
}
}
writeln!(
file,
"neon.privileged_role_name={}",
escape_conf_value(params.privileged_role_name.as_str())
)?;
// If there are any extra options in the 'settings' field, append those
if spec.cluster.settings.is_some() {
writeln!(file, "# Managed by compute_ctl: begin")?;

View File

@@ -613,11 +613,11 @@ components:
- skipped
properties:
status:
description: LFC prewarm status
enum: [not_prewarmed, prewarming, completed, failed, skipped]
description: Lfc prewarm status
enum: [not_prewarmed, prewarming, completed, failed]
type: string
error:
description: LFC prewarm error, if any
description: Lfc prewarm error, if any
type: string
total:
description: Total pages processed
@@ -635,11 +635,11 @@ components:
- status
properties:
status:
description: LFC offload status
description: Lfc offload status
enum: [not_offloaded, offloading, completed, failed]
type: string
error:
description: LFC offload error, if any
description: Lfc offload error, if any
type: string
PromoteState:

View File

@@ -1 +0,0 @@
ALTER ROLE {privileged_role_name} BYPASSRLS;

View File

@@ -0,0 +1 @@
ALTER ROLE neon_superuser BYPASSRLS;

View File

@@ -15,7 +15,7 @@ DO $$
DECLARE
role_name text;
BEGIN
FOR role_name IN SELECT rolname FROM pg_roles WHERE pg_has_role(rolname, '{privileged_role_name}', 'member')
FOR role_name IN SELECT rolname FROM pg_roles WHERE pg_has_role(rolname, 'neon_superuser', 'member')
LOOP
RAISE NOTICE 'EXECUTING ALTER ROLE % INHERIT', quote_ident(role_name);
EXECUTE 'ALTER ROLE ' || quote_ident(role_name) || ' INHERIT';
@@ -23,7 +23,7 @@ BEGIN
FOR role_name IN SELECT rolname FROM pg_roles
WHERE
NOT pg_has_role(rolname, '{privileged_role_name}', 'member') AND NOT starts_with(rolname, 'pg_')
NOT pg_has_role(rolname, 'neon_superuser', 'member') AND NOT starts_with(rolname, 'pg_')
LOOP
RAISE NOTICE 'EXECUTING ALTER ROLE % NOBYPASSRLS', quote_ident(role_name);
EXECUTE 'ALTER ROLE ' || quote_ident(role_name) || ' NOBYPASSRLS';

View File

@@ -1,6 +1,6 @@
DO $$
BEGIN
IF (SELECT setting::numeric >= 160000 FROM pg_settings WHERE name = 'server_version_num') THEN
EXECUTE 'GRANT pg_create_subscription TO {privileged_role_name}';
EXECUTE 'GRANT pg_create_subscription TO neon_superuser';
END IF;
END $$;

View File

@@ -0,0 +1 @@
GRANT pg_monitor TO neon_superuser WITH ADMIN OPTION;

View File

@@ -1 +0,0 @@
GRANT pg_monitor TO {privileged_role_name} WITH ADMIN OPTION;

View File

@@ -1,4 +1,4 @@
-- SKIP: Deemed insufficient for allowing relations created by extensions to be
-- interacted with by {privileged_role_name} without permission issues.
-- interacted with by neon_superuser without permission issues.
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO {privileged_role_name};
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO neon_superuser;

View File

@@ -1,4 +1,4 @@
-- SKIP: Deemed insufficient for allowing relations created by extensions to be
-- interacted with by {privileged_role_name} without permission issues.
-- interacted with by neon_superuser without permission issues.
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO {privileged_role_name};
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO neon_superuser;

View File

@@ -1,3 +1,3 @@
-- SKIP: Moved inline to the handle_grants() functions.
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO {privileged_role_name} WITH GRANT OPTION;
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO neon_superuser WITH GRANT OPTION;

View File

@@ -1,3 +1,3 @@
-- SKIP: Moved inline to the handle_grants() functions.
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO {privileged_role_name} WITH GRANT OPTION;
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO neon_superuser WITH GRANT OPTION;

View File

@@ -1,7 +1,7 @@
DO $$
BEGIN
IF (SELECT setting::numeric >= 160000 FROM pg_settings WHERE name = 'server_version_num') THEN
EXECUTE 'GRANT EXECUTE ON FUNCTION pg_export_snapshot TO {privileged_role_name}';
EXECUTE 'GRANT EXECUTE ON FUNCTION pg_log_standby_snapshot TO {privileged_role_name}';
EXECUTE 'GRANT EXECUTE ON FUNCTION pg_export_snapshot TO neon_superuser';
EXECUTE 'GRANT EXECUTE ON FUNCTION pg_log_standby_snapshot TO neon_superuser';
END IF;
END $$;

View File

@@ -0,0 +1 @@
GRANT EXECUTE ON FUNCTION pg_show_replication_origin_status TO neon_superuser;

View File

@@ -1 +0,0 @@
GRANT EXECUTE ON FUNCTION pg_show_replication_origin_status TO {privileged_role_name};

View File

@@ -0,0 +1 @@
GRANT pg_signal_backend TO neon_superuser WITH ADMIN OPTION;

View File

@@ -1 +0,0 @@
GRANT pg_signal_backend TO {privileged_role_name} WITH ADMIN OPTION;

View File

@@ -9,7 +9,6 @@ use reqwest::StatusCode;
use tokio_postgres::Client;
use tracing::{error, info, instrument};
use crate::compute::ComputeNodeParams;
use crate::config;
use crate::metrics::{CPLANE_REQUESTS_TOTAL, CPlaneRequestRPC, UNKNOWN_HTTP_STATUS};
use crate::migration::MigrationRunner;
@@ -170,7 +169,7 @@ pub async fn handle_neon_extension_upgrade(client: &mut Client) -> Result<()> {
}
#[instrument(skip_all)]
pub async fn handle_migrations(params: ComputeNodeParams, client: &mut Client) -> Result<()> {
pub async fn handle_migrations(client: &mut Client) -> Result<()> {
info!("handle migrations");
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -179,59 +178,26 @@ pub async fn handle_migrations(params: ComputeNodeParams, client: &mut Client) -
// Add new migrations in numerical order.
let migrations = [
&format!(
include_str!("./migrations/0001-add_bypass_rls_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
include_str!("./migrations/0001-neon_superuser_bypass_rls.sql"),
include_str!("./migrations/0002-alter_roles.sql"),
include_str!("./migrations/0003-grant_pg_create_subscription_to_neon_superuser.sql"),
include_str!("./migrations/0004-grant_pg_monitor_to_neon_superuser.sql"),
include_str!("./migrations/0005-grant_all_on_tables_to_neon_superuser.sql"),
include_str!("./migrations/0006-grant_all_on_sequences_to_neon_superuser.sql"),
include_str!(
"./migrations/0007-grant_all_on_tables_to_neon_superuser_with_grant_option.sql"
),
&format!(
include_str!("./migrations/0002-alter_roles.sql"),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!("./migrations/0003-grant_pg_create_subscription_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!("./migrations/0004-grant_pg_monitor_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!("./migrations/0005-grant_all_on_tables_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!("./migrations/0006-grant_all_on_sequences_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!(
"./migrations/0007-grant_all_on_tables_with_grant_option_to_privileged_role.sql"
),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!(
"./migrations/0008-grant_all_on_sequences_with_grant_option_to_privileged_role.sql"
),
privileged_role_name = params.privileged_role_name
include_str!(
"./migrations/0008-grant_all_on_sequences_to_neon_superuser_with_grant_option.sql"
),
include_str!("./migrations/0009-revoke_replication_for_previously_allowed_roles.sql"),
&format!(
include_str!(
"./migrations/0010-grant_snapshot_synchronization_funcs_to_privileged_role.sql"
),
privileged_role_name = params.privileged_role_name
include_str!(
"./migrations/0010-grant_snapshot_synchronization_funcs_to_neon_superuser.sql"
),
&format!(
include_str!(
"./migrations/0011-grant_pg_show_replication_origin_status_to_privileged_role.sql"
),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!("./migrations/0012-grant_pg_signal_backend_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
include_str!(
"./migrations/0011-grant_pg_show_replication_origin_status_to_neon_superuser.sql"
),
include_str!("./migrations/0012-grant_pg_signal_backend_to_neon_superuser.sql"),
];
MigrationRunner::new(client, &migrations)
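Aside on the pattern in this hunk: `include_str!` expands to a string literal at compile time, so `format!` can treat an external SQL file as its format string and substitute named placeholders such as `{privileged_role_name}`. A minimal, self-contained illustration (the file name is hypothetical):

```rust
// Assumes a sibling file `grant.sql` containing, e.g.:
//   GRANT pg_monitor TO {privileged_role_name};
// The file is embedded at compile time; the placeholder is filled at runtime.
fn render_grant(privileged_role_name: &str) -> String {
    format!(
        include_str!("grant.sql"),
        privileged_role_name = privileged_role_name
    )
}
```

One caveat of this templating: any literal braces in the embedded SQL would have to be doubled (`{{`, `}}`), since the file contents are parsed as a format string.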

View File

@@ -13,14 +13,14 @@ use tokio_postgres::Client;
use tokio_postgres::error::SqlState;
use tracing::{Instrument, debug, error, info, info_span, instrument, warn};
use crate::compute::{ComputeNode, ComputeNodeParams, ComputeState};
use crate::compute::{ComputeNode, ComputeState};
use crate::pg_helpers::{
DatabaseExt, Escaping, GenericOptionsSearch, RoleExt, get_existing_dbs_async,
get_existing_roles_async,
};
use crate::spec_apply::ApplySpecPhase::{
CreateAndAlterDatabases, CreateAndAlterRoles, CreateAvailabilityCheck, CreatePgauditExtension,
CreatePgauditlogtofileExtension, CreatePrivilegedRole, CreateSchemaNeon,
CreateAndAlterDatabases, CreateAndAlterRoles, CreateAvailabilityCheck, CreateNeonSuperuser,
CreatePgauditExtension, CreatePgauditlogtofileExtension, CreateSchemaNeon,
DisablePostgresDBPgAudit, DropInvalidDatabases, DropRoles, FinalizeDropLogicalSubscriptions,
HandleNeonExtension, HandleOtherExtensions, RenameAndDeleteDatabases, RenameRoles,
RunInEachDatabase,
@@ -49,7 +49,6 @@ impl ComputeNode {
// Proceed with post-startup configuration. Note, that order of operations is important.
let client = Self::get_maintenance_client(&conf).await?;
let spec = spec.clone();
let params = Arc::new(self.params.clone());
let databases = get_existing_dbs_async(&client).await?;
let roles = get_existing_roles_async(&client)
@@ -158,7 +157,6 @@ impl ComputeNode {
let conf = Arc::new(conf);
let fut = Self::apply_spec_sql_db(
params.clone(),
spec.clone(),
conf,
ctx.clone(),
@@ -187,7 +185,7 @@ impl ComputeNode {
}
for phase in [
CreatePrivilegedRole,
CreateNeonSuperuser,
DropInvalidDatabases,
RenameRoles,
CreateAndAlterRoles,
@@ -197,7 +195,6 @@ impl ComputeNode {
] {
info!("Applying phase {:?}", &phase);
apply_operations(
params.clone(),
spec.clone(),
ctx.clone(),
jwks_roles.clone(),
@@ -246,7 +243,6 @@ impl ComputeNode {
}
let fut = Self::apply_spec_sql_db(
params.clone(),
spec.clone(),
conf,
ctx.clone(),
@@ -297,7 +293,6 @@ impl ComputeNode {
for phase in phases {
debug!("Applying phase {:?}", &phase);
apply_operations(
params.clone(),
spec.clone(),
ctx.clone(),
jwks_roles.clone(),
@@ -318,9 +313,7 @@ impl ComputeNode {
/// May opt to not connect to databases that don't have any scheduled
/// operations. The function is concurrency-controlled with the provided
/// semaphore. The caller has to make sure the semaphore isn't exhausted.
#[allow(clippy::too_many_arguments)] // TODO: needs bigger refactoring
async fn apply_spec_sql_db(
params: Arc<ComputeNodeParams>,
spec: Arc<ComputeSpec>,
conf: Arc<tokio_postgres::Config>,
ctx: Arc<tokio::sync::RwLock<MutableApplyContext>>,
@@ -335,7 +328,6 @@ impl ComputeNode {
for subphase in subphases {
apply_operations(
params.clone(),
spec.clone(),
ctx.clone(),
jwks_roles.clone(),
@@ -475,7 +467,7 @@ pub enum PerDatabasePhase {
#[derive(Clone, Debug)]
pub enum ApplySpecPhase {
CreatePrivilegedRole,
CreateNeonSuperuser,
DropInvalidDatabases,
RenameRoles,
CreateAndAlterRoles,
@@ -518,7 +510,6 @@ pub struct MutableApplyContext {
/// - No timeouts have (yet) been implemented.
/// - The caller is responsible for limiting and/or applying concurrency.
pub async fn apply_operations<'a, Fut, F>(
params: Arc<ComputeNodeParams>,
spec: Arc<ComputeSpec>,
ctx: Arc<RwLock<MutableApplyContext>>,
jwks_roles: Arc<HashSet<String>>,
@@ -536,7 +527,7 @@ where
debug!("Processing phase {:?}", &apply_spec_phase);
let ctx = ctx;
let mut ops = get_operations(&params, &spec, &ctx, &jwks_roles, &apply_spec_phase)
let mut ops = get_operations(&spec, &ctx, &jwks_roles, &apply_spec_phase)
.await?
.peekable();
@@ -597,18 +588,14 @@ where
/// sort/merge/batch execution, but for now this is a nice way to improve
/// batching behavior of the commands.
async fn get_operations<'a>(
params: &'a ComputeNodeParams,
spec: &'a ComputeSpec,
ctx: &'a RwLock<MutableApplyContext>,
jwks_roles: &'a HashSet<String>,
apply_spec_phase: &'a ApplySpecPhase,
) -> Result<Box<dyn Iterator<Item = Operation> + 'a + Send>> {
match apply_spec_phase {
ApplySpecPhase::CreatePrivilegedRole => Ok(Box::new(once(Operation {
query: format!(
include_str!("sql/create_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
),
ApplySpecPhase::CreateNeonSuperuser => Ok(Box::new(once(Operation {
query: include_str!("sql/create_neon_superuser.sql").to_string(),
comment: None,
}))),
ApplySpecPhase::DropInvalidDatabases => {
@@ -710,9 +697,8 @@ async fn get_operations<'a>(
None => {
let query = if !jwks_roles.contains(role.name.as_str()) {
format!(
"CREATE ROLE {} INHERIT CREATEROLE CREATEDB BYPASSRLS REPLICATION IN ROLE {} {}",
"CREATE ROLE {} INHERIT CREATEROLE CREATEDB BYPASSRLS REPLICATION IN ROLE neon_superuser {}",
role.name.pg_quote(),
params.privileged_role_name,
role.to_pg_options(),
)
} else {
@@ -863,9 +849,8 @@ async fn get_operations<'a>(
// ALL PRIVILEGES grants CREATE, CONNECT, and TEMPORARY on the database
// (see https://www.postgresql.org/docs/current/ddl-priv.html)
query: format!(
"GRANT ALL PRIVILEGES ON DATABASE {} TO {}",
db.name.pg_quote(),
params.privileged_role_name
"GRANT ALL PRIVILEGES ON DATABASE {} TO neon_superuser",
db.name.pg_quote()
),
comment: None,
},

View File

@@ -0,0 +1,8 @@
DO $$
BEGIN
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = 'neon_superuser')
THEN
CREATE ROLE neon_superuser CREATEDB CREATEROLE NOLOGIN REPLICATION BYPASSRLS IN ROLE pg_read_all_data, pg_write_all_data;
END IF;
END
$$;

View File

@@ -1,8 +0,0 @@
DO $$
BEGIN
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{privileged_role_name}')
THEN
CREATE ROLE {privileged_role_name} CREATEDB CREATEROLE NOLOGIN REPLICATION BYPASSRLS IN ROLE pg_read_all_data, pg_write_all_data;
END IF;
END
$$;

View File

@@ -8,10 +8,10 @@ code changes locally, but not suitable for running production systems.
## Example: Start with Postgres 16
To create and start a local development environment with Postgres 16, you will need to provide `--pg-version` flag to 2 of the start-up commands.
To create and start a local development environment with Postgres 16, you will need to provide `--pg-version` flag to 3 of the start-up commands.
```shell
cargo neon init
cargo neon init --pg-version 16
cargo neon start
cargo neon tenant create --set-default --pg-version 16
cargo neon endpoint create main --pg-version 16

View File

@@ -631,10 +631,6 @@ struct EndpointCreateCmdArgs {
help = "Allow multiple primary endpoints running on the same branch. Shouldn't be used normally, but useful for tests."
)]
allow_multiple: bool,
/// Only allow changing it on creation
#[clap(long, help = "Name of the privileged role for the endpoint")]
privileged_role_name: Option<String>,
}
#[derive(clap::Args)]
@@ -1484,7 +1480,6 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
args.grpc,
!args.update_catalog,
false,
args.privileged_role_name.clone(),
)?;
}
EndpointCmd::Start(args) => {

View File

@@ -99,7 +99,6 @@ pub struct EndpointConf {
features: Vec<ComputeFeature>,
cluster: Option<Cluster>,
compute_ctl_config: ComputeCtlConfig,
privileged_role_name: Option<String>,
}
//
@@ -200,7 +199,6 @@ impl ComputeControlPlane {
grpc: bool,
skip_pg_catalog_updates: bool,
drop_subscriptions_before_start: bool,
privileged_role_name: Option<String>,
) -> Result<Arc<Endpoint>> {
let pg_port = pg_port.unwrap_or_else(|| self.get_port());
let external_http_port = external_http_port.unwrap_or_else(|| self.get_port() + 1);
@@ -238,7 +236,6 @@ impl ComputeControlPlane {
features: vec![],
cluster: None,
compute_ctl_config: compute_ctl_config.clone(),
privileged_role_name: privileged_role_name.clone(),
});
ep.create_endpoint_dir()?;
@@ -260,7 +257,6 @@ impl ComputeControlPlane {
features: vec![],
cluster: None,
compute_ctl_config,
privileged_role_name,
})?,
)?;
std::fs::write(
@@ -336,9 +332,6 @@ pub struct Endpoint {
/// The compute_ctl config for the endpoint's compute.
compute_ctl_config: ComputeCtlConfig,
/// The name of the privileged role for the endpoint.
privileged_role_name: Option<String>,
}
#[derive(PartialEq, Eq)]
@@ -439,7 +432,6 @@ impl Endpoint {
features: conf.features,
cluster: conf.cluster,
compute_ctl_config: conf.compute_ctl_config,
privileged_role_name: conf.privileged_role_name,
})
}
@@ -878,10 +870,6 @@ impl Endpoint {
cmd.arg("--dev");
}
if let Some(privileged_role_name) = self.privileged_role_name.clone() {
cmd.args(["--privileged-role-name", &privileged_role_name]);
}
let child = cmd.spawn()?;
// set up a scopeguard to kill & wait for the child in case we panic or bail below
let child = scopeguard::guard(child, |mut child| {

View File

@@ -76,12 +76,6 @@ enum Command {
NodeStartDelete {
#[arg(long)]
node_id: NodeId,
/// When `force` is true, skip waiting for shards to prewarm during migration.
/// This can significantly speed up node deletion since prewarming all shards
/// can take considerable time, but may result in slower initial access to
/// migrated shards until they warm up naturally.
#[arg(long)]
force: bool,
},
/// Cancel deletion of the specified pageserver and wait for `timeout`
/// for the operation to be canceled. May be retried.
@@ -958,14 +952,13 @@ async fn main() -> anyhow::Result<()> {
.dispatch::<(), ()>(Method::DELETE, format!("control/v1/node/{node_id}"), None)
.await?;
}
Command::NodeStartDelete { node_id, force } => {
let query = if force {
format!("control/v1/node/{node_id}/delete?force=true")
} else {
format!("control/v1/node/{node_id}/delete")
};
Command::NodeStartDelete { node_id } => {
storcon_client
.dispatch::<(), ()>(Method::PUT, query, None)
.dispatch::<(), ()>(
Method::PUT,
format!("control/v1/node/{node_id}/delete"),
None,
)
.await?;
println!("Delete started for {node_id}");
}

View File

@@ -1,18 +0,0 @@
[package]
name = "alloc-metrics"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
metrics.workspace = true
measured.workspace = true
thread_local.workspace = true
[dev-dependencies]
criterion.workspace = true
tikv-jemallocator.workspace = true
[[bench]]
harness = false
name = "alloc"

View File

@@ -1,110 +0,0 @@
use std::alloc::{GlobalAlloc, Layout, System, handle_alloc_error};
use alloc_metrics::TrackedAllocator;
use criterion::{
AxisScale, BenchmarkGroup, BenchmarkId, Criterion, PlotConfiguration, measurement::Measurement,
};
use measured::FixedCardinalityLabel;
use tikv_jemallocator::Jemalloc;
fn main() {
let mut c = Criterion::default().configure_from_args();
bench(&mut c);
c.final_summary();
}
#[rustfmt::skip]
fn bench(c: &mut Criterion) {
bench_alloc(c.benchmark_group("alloc/system"), &System, &ALLOC_SYSTEM);
bench_alloc(c.benchmark_group("alloc/jemalloc"), &Jemalloc, &ALLOC_JEMALLOC);
bench_dealloc(c.benchmark_group("dealloc/system"), &System, &ALLOC_SYSTEM);
bench_dealloc(c.benchmark_group("dealloc/jemalloc"), &Jemalloc, &ALLOC_JEMALLOC);
}
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
#[label(singleton = "memory_context")]
pub enum MemoryContext {
Root,
Test,
}
static ALLOC_SYSTEM: TrackedAllocator<System, MemoryContext> =
unsafe { TrackedAllocator::new(System, MemoryContext::Root) };
static ALLOC_JEMALLOC: TrackedAllocator<Jemalloc, MemoryContext> =
unsafe { TrackedAllocator::new(Jemalloc, MemoryContext::Root) };
const KB: u64 = 1024;
const SIZES: [u64; 6] = [64, 256, KB, 4 * KB, 16 * KB, KB * KB];
fn bench_alloc<A: GlobalAlloc>(
mut g: BenchmarkGroup<'_, impl Measurement>,
alloc1: &'static A,
alloc2: &'static TrackedAllocator<A, MemoryContext>,
) {
g.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
for size in SIZES {
let layout = Layout::from_size_align(size as usize, 8).unwrap();
g.throughput(criterion::Throughput::Bytes(size));
g.bench_with_input(BenchmarkId::new("default", size), &layout, |b, &layout| {
let bs = criterion::BatchSize::NumBatches(10 + size.ilog2() as u64);
b.iter_batched(|| {}, |()| Alloc::new(alloc1, layout), bs);
});
g.bench_with_input(BenchmarkId::new("tracked", size), &layout, |b, &layout| {
let _scope = alloc2.scope(MemoryContext::Test);
let bs = criterion::BatchSize::NumBatches(10 + size.ilog2() as u64);
b.iter_batched(|| {}, |()| Alloc::new(alloc2, layout), bs);
});
}
}
fn bench_dealloc<A: GlobalAlloc>(
mut g: BenchmarkGroup<'_, impl Measurement>,
alloc1: &'static A,
alloc2: &'static TrackedAllocator<A, MemoryContext>,
) {
g.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
for size in SIZES {
let layout = Layout::from_size_align(size as usize, 8).unwrap();
g.throughput(criterion::Throughput::Bytes(size));
g.bench_with_input(BenchmarkId::new("default", size), &layout, |b, &layout| {
let bs = criterion::BatchSize::NumBatches(10 + size.ilog2() as u64);
b.iter_batched(|| Alloc::new(alloc1, layout), drop, bs);
});
g.bench_with_input(BenchmarkId::new("tracked", size), &layout, |b, &layout| {
let _scope = alloc2.scope(MemoryContext::Test);
let bs = criterion::BatchSize::NumBatches(10 + size.ilog2() as u64);
b.iter_batched(|| Alloc::new(alloc2, layout), drop, bs);
});
}
}
struct Alloc<'a, A: GlobalAlloc> {
alloc: &'a A,
ptr: *mut u8,
layout: Layout,
}
impl<'a, A: GlobalAlloc> Alloc<'a, A> {
fn new(alloc: &'a A, layout: Layout) -> Self {
let ptr = unsafe { alloc.alloc(layout) };
if ptr.is_null() {
handle_alloc_error(layout);
}
// actually make the page resident.
unsafe { ptr.cast::<u8>().write(1) };
Self { alloc, ptr, layout }
}
}
impl<'a, A: GlobalAlloc> Drop for Alloc<'a, A> {
fn drop(&mut self) {
unsafe { self.alloc.dealloc(self.ptr, self.layout) };
}
}
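// (Before its removal in this commit, this harness could presumably be run with
// `cargo bench -p alloc-metrics --bench alloc`; it compares each tracked
// allocator against its untracked baseline across the power-of-two sizes above,
// using the RAII `Alloc` wrapper so deallocation is measured separately.)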

View File

@@ -1,48 +0,0 @@
use std::marker::PhantomData;
use measured::{
FixedCardinalityLabel, LabelGroup, label::StaticLabelSet, metric::MetricFamilyEncoding,
};
use metrics::{CounterPairAssoc, Dec, Inc, MeasuredCounterPairState};
use crate::metric_vec::DenseMetricVec;
pub struct DenseCounterPairVec<
A: CounterPairAssoc<LabelGroupSet = StaticLabelSet<L>>,
L: FixedCardinalityLabel + LabelGroup,
> {
pub vec: DenseMetricVec<MeasuredCounterPairState, L>,
pub _marker: PhantomData<A>,
}
impl<A: CounterPairAssoc<LabelGroupSet = StaticLabelSet<L>>, L: FixedCardinalityLabel + LabelGroup>
DenseCounterPairVec<A, L>
{
pub fn new() -> Self {
Self {
vec: DenseMetricVec::new(),
_marker: PhantomData,
}
}
}
impl<T, A, L> ::measured::metric::group::MetricGroup<T> for DenseCounterPairVec<A, L>
where
T: ::measured::metric::group::Encoding,
::measured::metric::counter::CounterState: ::measured::metric::MetricEncoding<T>,
A: CounterPairAssoc<LabelGroupSet = StaticLabelSet<L>>,
L: FixedCardinalityLabel + LabelGroup,
{
fn collect_group_into(&self, enc: &mut T) -> Result<(), T::Err> {
// write decrement first to avoid a race condition where inc - dec < 0
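// (a scraper computing allocated_bytes - deallocated_bytes for live bytes
// could otherwise observe a transiently negative value)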
T::write_help(enc, A::DEC_NAME, A::DEC_HELP)?;
self.vec
.collect_family_into(A::DEC_NAME, &mut Dec(&mut *enc))?;
T::write_help(enc, A::INC_NAME, A::INC_HELP)?;
self.vec
.collect_family_into(A::INC_NAME, &mut Inc(&mut *enc))?;
Ok(())
}
}

View File

@@ -1,441 +0,0 @@
//! Tagged allocator measurements.
mod counters;
mod metric_vec;
use std::{
alloc::{GlobalAlloc, Layout},
cell::Cell,
marker::PhantomData,
sync::{
OnceLock,
atomic::{AtomicU64, Ordering::Relaxed},
},
};
use measured::{
FixedCardinalityLabel, LabelGroup, MetricGroup,
label::StaticLabelSet,
metric::{MetricEncoding, counter::CounterState, group::Encoding, name::MetricName},
};
use metrics::{CounterPairAssoc, MeasuredCounterPairState};
use thread_local::ThreadLocal;
type AllocCounter<T> = counters::DenseCounterPairVec<AllocPair<T>, T>;
pub struct TrackedAllocator<A, T: 'static + Send + Sync + FixedCardinalityLabel + LabelGroup> {
inner: A,
/// Potentially high-contention fallback used if the thread was not registered.
default_counters: MeasuredCounterPairState,
/// Default tag to use if this thread is not registered.
default_tag: T,
thread: OnceLock<RegisteredThread<T>>,
/// Where thread alloc data is eventually saved, even if threads are shut down.
global: OnceLock<AllocCounter<T>>,
}
impl<A, T> TrackedAllocator<A, T>
where
T: 'static + Send + Sync + FixedCardinalityLabel + LabelGroup,
{
/// # Safety
///
/// [`FixedCardinalityLabel`] must be implemented correctly, fully dense, and must not panic.
pub const unsafe fn new(alloc: A, default: T) -> Self {
TrackedAllocator {
inner: alloc,
default_tag: default,
default_counters: MeasuredCounterPairState {
inc: CounterState {
count: AtomicU64::new(0),
},
dec: CounterState {
count: AtomicU64::new(0),
},
},
thread: OnceLock::new(),
global: OnceLock::new(),
}
}
/// Register the current thread, giving it low-contention allocation counters.
pub fn register_thread(&'static self) {
self.register_thread_inner();
}
pub fn scope(&'static self, tag: T) -> AllocScope<'static, T> {
let cell = self.register_thread_inner();
let last = cell.replace(tag);
AllocScope { cell, last }
}
fn register_thread_inner(&'static self) -> &'static Cell<T> {
let thread = self.thread.get_or_init(|| RegisteredThread {
scope: ThreadLocal::new(),
state: ThreadLocal::new(),
});
thread.state.get_or(|| ThreadState {
counters: AllocCounter::new(),
global: self.global.get_or_init(AllocCounter::new),
});
thread.scope.get_or(|| Cell::new(self.default_tag))
}
}
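// A minimal usage sketch (illustrative, not part of this file): install the
// allocator globally and tag a region of code. `MemoryContext` here is a label
// enum like the one defined in the tests below.
//
//     #[global_allocator]
//     static ALLOC: TrackedAllocator<std::alloc::System, MemoryContext> =
//         unsafe { TrackedAllocator::new(std::alloc::System, MemoryContext::Root) };
//
//     fn handle_request() {
//         ALLOC.register_thread();                       // opt into per-thread counters
//         let _scope = ALLOC.scope(MemoryContext::Test); // attribute to `Test`...
//         let buf = vec![0u8; 4096];                     // ...including this allocation
//     }                                                  // scope drop restores the old tag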
macro_rules! alloc {
($alloc_fn:ident) => {
unsafe fn $alloc_fn(&self, layout: Layout) -> *mut u8 {
let Ok((tagged_layout, tag_offset)) = layout.extend(Layout::new::<T>()) else {
return std::ptr::null_mut();
};
let tagged_layout = tagged_layout.pad_to_align();
// Safety: The layout is not zero-sized.
let ptr = unsafe { self.inner.$alloc_fn(tagged_layout) };
// allocation failed.
if ptr.is_null() {
return ptr;
}
// We are being very careful here to not allocate or panic.
let thread = self.thread.get().map(|s| (s.scope.get(), s.state.get()));
let tag = thread.and_then(|t| t.0).map_or(self.default_tag, Cell::get);
// Allocation successful. Write our tag
// Safety: tag_offset is inbounds of the ptr
unsafe { ptr.add(tag_offset).cast::<T>().write(tag) }
let counters = thread.and_then(|t| t.1).map(|s| &s.counters);
let metric = if let Some(counters) = counters {
counters.vec.get_metric(tag)
} else {
// If this thread's state was never registered, its scope cell was never created either, so the tag must be the default.
&self.default_counters
};
metric.inc.count.fetch_add(layout.size() as u64, Relaxed);
ptr
}
};
}
// We will tag our allocation by adding `T` to the end of the layout.
// This is ok only as long as it does not overflow. If it does, we will
// just fail the allocation by returning null.
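// For example (assuming a one-byte tag enum): extending `Layout { size: 24, align: 8 }`
// with `Layout::new::<T>()` yields size 25, align 8 and tag_offset = 24, and
// pad_to_align() rounds the size up to 32 bytes, a small constant overhead per allocation.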
//
// Safety: we will not unwind during alloc, and we will ensure layouts are handled correctly.
unsafe impl<A, T> GlobalAlloc for TrackedAllocator<A, T>
where
A: GlobalAlloc,
T: 'static + Send + Sync + FixedCardinalityLabel + LabelGroup,
{
alloc!(alloc);
alloc!(alloc_zeroed);
unsafe fn realloc(&self, ptr: *mut u8, layout: Layout, new_size: usize) -> *mut u8 {
// SAFETY: the caller must ensure that the `new_size` does not overflow.
// `layout.align()` comes from a `Layout` and is thus guaranteed to be valid.
let new_layout = unsafe { Layout::from_size_align_unchecked(new_size, layout.align()) };
let Ok((new_tagged_layout, new_tag_offset)) = new_layout.extend(Layout::new::<T>()) else {
return std::ptr::null_mut();
};
let new_tagged_layout = new_tagged_layout.pad_to_align();
let Ok((tagged_layout, tag_offset)) = layout.extend(Layout::new::<T>()) else {
// Safety: This layout clearly did not match what was originally allocated,
// otherwise alloc() would have caught this error and returned null.
unsafe { std::hint::unreachable_unchecked() }
};
let tagged_layout = tagged_layout.pad_to_align();
// get the tag set during alloc
// Safety: tag_offset is inbounds of the ptr
let tag = unsafe { ptr.add(tag_offset).cast::<T>().read() };
// Safety: layout sizes are correct
let new_ptr = unsafe {
self.inner
.realloc(ptr, tagged_layout, new_tagged_layout.size())
};
// allocation failed.
if new_ptr.is_null() {
return new_ptr;
}
// We are being very careful here to not allocate or panic.
let thread = self.thread.get().map(|s| (s.scope.get(), s.state.get()));
let new_tag = thread.and_then(|t| t.0).map_or(self.default_tag, Cell::get);
// Allocation successful. Write our tag
// Safety: new_tag_offset is inbounds of the ptr
unsafe { new_ptr.add(new_tag_offset).cast::<T>().write(new_tag) }
let counters = thread.and_then(|t| t.1).map(|s| &s.counters);
let counters = counters.or_else(|| self.global.get());
let (new_metric, old_metric) = if let Some(counters) = counters {
let new_metric = counters.vec.get_metric(new_tag);
let old_metric = counters.vec.get_metric(tag);
(new_metric, old_metric)
} else {
// no tag was registered at all, therefore both tags must be default.
(&self.default_counters, &self.default_counters)
};
let (inc, dec) = if tag.encode() != new_tag.encode() {
(new_layout.size() as u64, layout.size() as u64)
} else if new_layout.size() > layout.size() {
((new_layout.size() - layout.size()) as u64, 0)
} else {
(0, (layout.size() - new_layout.size()) as u64)
};
new_metric.inc.count.fetch_add(inc, Relaxed);
old_metric.dec.count.fetch_add(dec, Relaxed);
new_ptr
}
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
let Ok((tagged_layout, tag_offset)) = layout.extend(Layout::new::<T>()) else {
// Safety: This layout clearly did not match what was originally allocated,
// otherwise alloc() would have caught this error and returned null.
unsafe { std::hint::unreachable_unchecked() }
};
let tagged_layout = tagged_layout.pad_to_align();
// get the tag set during alloc
// Safety: tag_offset is inbounds of the ptr
let tag = unsafe { ptr.add(tag_offset).cast::<T>().read() };
// Safety: caller upholds contract for us
unsafe { self.inner.dealloc(ptr, tagged_layout) }
// We are being very careful here to not allocate or panic.
let thread = self.thread.get().map(|s| (s.scope.get(), s.state.get()));
let counters = thread.and_then(|t| t.1).map(|s| &s.counters);
let counters = counters.or_else(|| self.global.get());
let metric = if let Some(counters) = counters {
counters.vec.get_metric(tag)
} else {
// If neither thread-local nor global counters exist, no thread ever registered, so no non-default tag could have been written; the tag must be the default.
&self.default_counters
};
metric.dec.count.fetch_add(layout.size() as u64, Relaxed);
}
}
pub struct AllocScope<'a, T: FixedCardinalityLabel> {
cell: &'a Cell<T>,
last: T,
}
impl<'a, T: FixedCardinalityLabel> Drop for AllocScope<'a, T> {
fn drop(&mut self) {
self.cell.set(self.last);
}
}
struct AllocPair<T>(PhantomData<T>);
impl<T: FixedCardinalityLabel + LabelGroup> CounterPairAssoc for AllocPair<T> {
const INC_NAME: &'static MetricName = MetricName::from_str("allocated_bytes");
const DEC_NAME: &'static MetricName = MetricName::from_str("deallocated_bytes");
const INC_HELP: &'static str = "total number of bytes allocated";
const DEC_HELP: &'static str = "total number of bytes deallocated";
type LabelGroupSet = StaticLabelSet<T>;
}
struct RegisteredThread<T: 'static + Send + Sync + FixedCardinalityLabel + LabelGroup> {
/// Current memory context for this thread.
scope: ThreadLocal<Cell<T>>,
/// Per-thread state containing low-contention counters for faster allocations.
state: ThreadLocal<ThreadState<T>>,
}
struct ThreadState<T: 'static + FixedCardinalityLabel + LabelGroup> {
counters: AllocCounter<T>,
global: &'static AllocCounter<T>,
}
// Ensure the counters are measured on thread destruction.
impl<T: 'static + FixedCardinalityLabel + LabelGroup> Drop for ThreadState<T> {
fn drop(&mut self) {
// iterate over all labels
for tag in (0..T::cardinality()).map(T::decode) {
// load and reset the counts in the thread-local counters.
let m = self.counters.vec.get_metric_mut(tag);
let inc = *m.inc.count.get_mut();
let dec = *m.dec.count.get_mut();
// add the counts into the global counters.
let m = self.global.vec.get_metric(tag);
m.inc.count.fetch_add(inc, Relaxed);
m.dec.count.fetch_add(dec, Relaxed);
}
}
}
impl<A, T, Enc> MetricGroup<Enc> for TrackedAllocator<A, T>
where
T: 'static + Send + Sync + FixedCardinalityLabel + LabelGroup,
Enc: Encoding,
CounterState: MetricEncoding<Enc>,
{
fn collect_group_into(&self, enc: &mut Enc) -> Result<(), Enc::Err> {
let global = self.global.get_or_init(AllocCounter::new);
// iterate over all counter threads
for s in self.thread.get().into_iter().flat_map(|s| s.state.iter()) {
// iterate over all labels
for tag in (0..T::cardinality()).map(T::decode) {
sample(global, s.counters.vec.get_metric(tag), tag);
}
}
sample(global, &self.default_counters, self.default_tag);
global.collect_group_into(enc)
}
}
fn sample<T: FixedCardinalityLabel + LabelGroup>(
global: &AllocCounter<T>,
local: &MeasuredCounterPairState,
tag: T,
) {
// load and reset the counts in the thread-local counters.
let inc = local.inc.count.swap(0, Relaxed);
let dec = local.dec.count.swap(0, Relaxed);
// add the counts into the global counters.
let m = global.vec.get_metric(tag);
m.inc.count.fetch_add(inc, Relaxed);
m.dec.count.fetch_add(dec, Relaxed);
}
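// (The swap-then-add above moves counts rather than copying them, so each byte
// is attributed exactly once across successive metric collections.)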
#[cfg(test)]
mod tests {
use std::alloc::{GlobalAlloc, Layout, System};
use measured::{FixedCardinalityLabel, MetricGroup, text::BufferedTextEncoder};
use crate::TrackedAllocator;
#[derive(FixedCardinalityLabel, Clone, Copy, Debug)]
#[label(singleton = "memory_context")]
pub enum MemoryContext {
Root,
Test,
}
#[test]
fn alloc() {
// Safety: `MemoryContext` upholds the safety requirements.
static GLOBAL: TrackedAllocator<System, MemoryContext> =
unsafe { TrackedAllocator::new(System, MemoryContext::Root) };
GLOBAL.register_thread();
let _test = GLOBAL.scope(MemoryContext::Test);
let ptr = unsafe { GLOBAL.alloc(Layout::for_value(&[0_i32])) };
let ptr = unsafe { GLOBAL.realloc(ptr, Layout::for_value(&[0_i32]), 8) };
drop(_test);
let ptr = unsafe { GLOBAL.realloc(ptr, Layout::for_value(&[0_i32, 1_i32]), 4) };
unsafe { GLOBAL.dealloc(ptr, Layout::for_value(&[0_i32])) };
let mut text = BufferedTextEncoder::new();
GLOBAL.collect_group_into(&mut text).unwrap();
let text = String::from_utf8(text.finish().into()).unwrap();
assert_eq!(
text,
r#"# HELP deallocated_bytes total number of bytes deallocated
# TYPE deallocated_bytes counter
deallocated_bytes{memory_context="root"} 4
deallocated_bytes{memory_context="test"} 8
# HELP allocated_bytes total number of bytes allocated
# TYPE allocated_bytes counter
allocated_bytes{memory_context="root"} 4
allocated_bytes{memory_context="test"} 8
"#
);
}
#[test]
fn unregistered_thread() {
// Safety: `MemoryContext` upholds the safety requirements.
static GLOBAL: TrackedAllocator<System, MemoryContext> =
unsafe { TrackedAllocator::new(System, MemoryContext::Root) };
GLOBAL.register_thread();
// unregistered thread
std::thread::spawn(|| {
let ptr = unsafe { GLOBAL.alloc(Layout::for_value(&[0_i32])) };
unsafe { GLOBAL.dealloc(ptr, Layout::for_value(&[0_i32])) };
})
.join()
.unwrap();
let mut text = BufferedTextEncoder::new();
GLOBAL.collect_group_into(&mut text).unwrap();
let text = String::from_utf8(text.finish().into()).unwrap();
assert_eq!(
text,
r#"# HELP deallocated_bytes total number of bytes deallocated
# TYPE deallocated_bytes counter
deallocated_bytes{memory_context="root"} 4
deallocated_bytes{memory_context="test"} 0
# HELP allocated_bytes total number of bytes allocated
# TYPE allocated_bytes counter
allocated_bytes{memory_context="root"} 4
allocated_bytes{memory_context="test"} 0
"#
);
}
#[test]
fn fully_unregistered() {
// Safety: `MemoryContext` upholds the safety requirements.
static GLOBAL: TrackedAllocator<System, MemoryContext> =
unsafe { TrackedAllocator::new(System, MemoryContext::Root) };
let ptr = unsafe { GLOBAL.alloc(Layout::for_value(&[0_i32])) };
unsafe { GLOBAL.dealloc(ptr, Layout::for_value(&[0_i32])) };
let mut text = BufferedTextEncoder::new();
GLOBAL.collect_group_into(&mut text).unwrap();
let text = String::from_utf8(text.finish().into()).unwrap();
assert_eq!(
text,
r#"# HELP deallocated_bytes total number of bytes deallocated
# TYPE deallocated_bytes counter
deallocated_bytes{memory_context="root"} 4
deallocated_bytes{memory_context="test"} 0
# HELP allocated_bytes total number of bytes allocated
# TYPE allocated_bytes counter
allocated_bytes{memory_context="root"} 4
allocated_bytes{memory_context="test"} 0
"#
);
}
}

View File

@@ -1,72 +0,0 @@
//! Dense metric vec
use measured::{
FixedCardinalityLabel, LabelGroup,
label::StaticLabelSet,
metric::{
MetricEncoding, MetricFamilyEncoding, MetricType, group::Encoding, name::MetricNameEncoder,
},
};
pub struct DenseMetricVec<M: MetricType, L: FixedCardinalityLabel + LabelGroup> {
metrics: Box<[M]>,
metadata: M::Metadata,
_label_set: StaticLabelSet<L>,
}
fn new_dense<M: MetricType>(c: usize) -> Box<[M]> {
let mut vec = Vec::with_capacity(c);
vec.resize_with(c, M::default);
vec.into_boxed_slice()
}
impl<M: MetricType, L: FixedCardinalityLabel + LabelGroup> DenseMetricVec<M, L>
where
M::Metadata: Default,
{
/// Create a new metric vec with the given label set and metric metadata
pub fn new() -> Self {
Self::with_metadata(<M::Metadata>::default())
}
}
impl<M: MetricType, L: FixedCardinalityLabel + LabelGroup> DenseMetricVec<M, L> {
/// Create a new metric vec with the given label set and metric metadata
pub fn with_metadata(metadata: M::Metadata) -> Self {
Self {
metrics: new_dense(L::cardinality()),
metadata,
_label_set: StaticLabelSet::new(),
}
}
/// Get the individual metric at the given identifier.
///
/// # Panics
/// Can panic or cause strange behaviour if the label ID comes from a different metric family.
pub fn get_metric(&self, label: L) -> &M {
// Safety: the `FixedCardinalityLabel` contract guarantees that `label.encode()` is within `L::cardinality()`.
unsafe { self.metrics.get_unchecked(label.encode()) }
}
/// Get the individual metric at the given identifier.
///
/// # Panics
/// Can panic or cause strange behaviour if the label ID comes from a different metric family.
pub fn get_metric_mut(&mut self, label: L) -> &mut M {
// Safety: the `FixedCardinalityLabel` contract guarantees that `label.encode()` is within `L::cardinality()`.
unsafe { self.metrics.get_unchecked_mut(label.encode()) }
}
}
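// Illustrative: with a two-variant label such as the `MemoryContext` enum used
// elsewhere in this crate, the vec preallocates exactly `cardinality() == 2`
// metrics up front, and `get_metric` reduces to an unchecked slice index by
// `label.encode()`; no hashing and no locking on the hot path.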
impl<M: MetricEncoding<T>, L: FixedCardinalityLabel + LabelGroup, T: Encoding>
MetricFamilyEncoding<T> for DenseMetricVec<M, L>
{
fn collect_family_into(&self, name: impl MetricNameEncoder, enc: &mut T) -> Result<(), T::Err> {
M::write_type(&name, enc)?;
for (index, value) in self.metrics.iter().enumerate() {
value.collect_into(&self.metadata, L::decode(index), &name, enc)?;
}
Ok(())
}
}

View File

@@ -46,33 +46,16 @@ pub struct ExtensionInstallResponse {
pub version: ExtVersion,
}
/// Status of the LFC prewarm process. The same state machine is reused for
/// both autoprewarm (prewarm after compute/Postgres start using the previously
/// stored LFC state) and explicit prewarming via API.
#[derive(Serialize, Default, Debug, Clone, PartialEq)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum LfcPrewarmState {
/// Default value when compute boots up.
#[default]
NotPrewarmed,
/// Prewarming thread is active and loading pages into LFC.
Prewarming,
/// We found the requested LFC state in the endpoint storage and
/// completed prewarming successfully.
Completed,
/// Unexpected error happened during prewarming. Note that a `404 Not Found`
/// response from the endpoint storage is explicitly excluded here,
/// because it can normally happen on the first compute start,
/// when the LFC state is not yet available.
Failed { error: String },
/// We tried to fetch the corresponding LFC state from the endpoint storage,
/// but received `404 Not Found`. This should normally happen only during the
/// first endpoint start after creation with `autoprewarm: true`.
///
/// During the orchestrated prewarm via API, when a caller explicitly
/// provides the LFC state key to prewarm from, it is the caller's
/// responsibility to handle this status as an error state.
Skipped,
Failed {
error: String,
},
}
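// Illustrative JSON (assuming serde_json), given the `status` tag and
// snake_case renaming above:
//   {"status":"skipped"}
//   {"status":"failed","error":"<reason>"}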
impl Display for LfcPrewarmState {
@@ -81,7 +64,6 @@ impl Display for LfcPrewarmState {
LfcPrewarmState::NotPrewarmed => f.write_str("NotPrewarmed"),
LfcPrewarmState::Prewarming => f.write_str("Prewarming"),
LfcPrewarmState::Completed => f.write_str("Completed"),
LfcPrewarmState::Skipped => f.write_str("Skipped"),
LfcPrewarmState::Failed { error } => write!(f, "Error({error})"),
}
}

View File

@@ -4,14 +4,12 @@
//! a default registry.
#![deny(clippy::undocumented_unsafe_blocks)]
use std::sync::RwLock;
use measured::label::{LabelGroupSet, LabelGroupVisitor, LabelName, NoLabels};
use measured::metric::counter::CounterState;
use measured::metric::gauge::GaugeState;
use measured::metric::group::Encoding;
use measured::metric::name::{MetricName, MetricNameEncoder};
use measured::metric::{MetricEncoding, MetricFamilyEncoding, MetricType};
use measured::metric::{MetricEncoding, MetricFamilyEncoding};
use measured::{FixedCardinalityLabel, LabelGroup, MetricGroup};
use once_cell::sync::Lazy;
use prometheus::Registry;
@@ -118,52 +116,12 @@ pub fn pow2_buckets(start: usize, end: usize) -> Vec<f64> {
.collect()
}
pub struct InfoMetric<L: LabelGroup, M: MetricType = GaugeState> {
label: RwLock<L>,
metric: M,
}
impl<L: LabelGroup> InfoMetric<L> {
pub fn new(label: L) -> Self {
Self::with_metric(label, GaugeState::new(1))
}
}
impl<L: LabelGroup, M: MetricType<Metadata = ()>> InfoMetric<L, M> {
pub fn with_metric(label: L, metric: M) -> Self {
Self {
label: RwLock::new(label),
metric,
}
}
pub fn set_label(&self, label: L) {
*self.label.write().unwrap() = label;
}
}
impl<L, M, E> MetricFamilyEncoding<E> for InfoMetric<L, M>
where
L: LabelGroup,
M: MetricEncoding<E, Metadata = ()>,
E: Encoding,
{
fn collect_family_into(
&self,
name: impl measured::metric::name::MetricNameEncoder,
enc: &mut E,
) -> Result<(), E::Err> {
M::write_type(&name, enc)?;
self.metric
.collect_into(&(), &*self.label.read().unwrap(), name, enc)
}
}
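// Sketch of the intended encoding: an info-style metric is a constant-1 gauge
// whose labels carry the payload, e.g. (hypothetical label values)
//   build_info{revision="abc123",build_tag="v1"} 1
// and `set_label` swaps the label values without touching the gauge.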
pub struct BuildInfo {
pub revision: &'static str,
pub build_tag: &'static str,
}
// todo: allow label group without the set
impl LabelGroup for BuildInfo {
fn visit_values(&self, v: &mut impl LabelGroupVisitor) {
const REVISION: &LabelName = LabelName::from_str("revision");
@@ -173,6 +131,24 @@ impl LabelGroup for BuildInfo {
}
}
impl<T: Encoding> MetricFamilyEncoding<T> for BuildInfo
where
GaugeState: MetricEncoding<T>,
{
fn collect_family_into(
&self,
name: impl measured::metric::name::MetricNameEncoder,
enc: &mut T,
) -> Result<(), T::Err> {
enc.write_help(&name, "Build/version information")?;
GaugeState::write_type(&name, enc)?;
GaugeState {
count: std::sync::atomic::AtomicI64::new(1),
}
.collect_into(&(), self, name, enc)
}
}
#[derive(MetricGroup)]
#[metric(new(build_info: BuildInfo))]
pub struct NeonMetrics {
@@ -189,8 +165,8 @@ pub struct NeonMetrics {
#[derive(MetricGroup)]
#[metric(new(build_info: BuildInfo))]
pub struct LibMetrics {
#[metric(init = InfoMetric::new(build_info))]
build_info: InfoMetric<BuildInfo>,
#[metric(init = build_info)]
build_info: BuildInfo,
#[metric(flatten)]
rusage: Rusage,
@@ -478,7 +454,7 @@ pub trait CounterPairAssoc {
}
pub struct CounterPairVec<A: CounterPairAssoc> {
pub vec: measured::metric::MetricVec<MeasuredCounterPairState, A::LabelGroupSet>,
vec: measured::metric::MetricVec<MeasuredCounterPairState, A::LabelGroupSet>,
}
impl<A: CounterPairAssoc> Default for CounterPairVec<A>
@@ -492,17 +468,6 @@ where
}
}
impl<A: CounterPairAssoc> CounterPairVec<A>
where
A::LabelGroupSet: Default,
{
pub fn dense() -> Self {
Self {
vec: measured::metric::MetricVec::dense(),
}
}
}
impl<A: CounterPairAssoc> CounterPairVec<A> {
pub fn guard(
&self,
@@ -512,31 +477,14 @@ impl<A: CounterPairAssoc> CounterPairVec<A> {
self.vec.get_metric(id).inc.inc();
MeasuredCounterPairGuard { vec: &self.vec, id }
}
#[inline]
pub fn inc(&self, labels: <A::LabelGroupSet as LabelGroupSet>::Group<'_>) {
let id = self.vec.with_labels(labels);
self.vec.get_metric(id).inc.inc();
}
#[inline]
pub fn dec(&self, labels: <A::LabelGroupSet as LabelGroupSet>::Group<'_>) {
let id = self.vec.with_labels(labels);
self.vec.get_metric(id).dec.inc();
}
#[inline]
pub fn inc_by(&self, labels: <A::LabelGroupSet as LabelGroupSet>::Group<'_>, x: u64) {
let id = self.vec.with_labels(labels);
self.vec.get_metric(id).inc.inc_by(x);
}
#[inline]
pub fn dec_by(&self, labels: <A::LabelGroupSet as LabelGroupSet>::Group<'_>, x: u64) {
let id = self.vec.with_labels(labels);
self.vec.get_metric(id).dec.inc_by(x);
}
pub fn remove_metric(
&self,
labels: <A::LabelGroupSet as LabelGroupSet>::Group<'_>,
@@ -581,28 +529,6 @@ pub struct MeasuredCounterPairState {
pub dec: CounterState,
}
impl MeasuredCounterPairState {
#[inline]
pub fn inc(&self) {
self.inc.inc();
}
#[inline]
pub fn dec(&self) {
self.dec.inc();
}
#[inline]
pub fn inc_by(&self, x: u64) {
self.inc.inc_by(x);
}
#[inline]
pub fn dec_by(&self, x: u64) {
self.dec.inc_by(x);
}
}
impl measured::metric::MetricType for MeasuredCounterPairState {
type Metadata = ();
}
@@ -619,9 +545,9 @@ impl<A: CounterPairAssoc> Drop for MeasuredCounterPairGuard<'_, A> {
}
/// [`MetricEncoding`] for [`MeasuredCounterPairState`] that only writes the inc counter to the inner encoder.
pub struct Inc<T>(pub T);
struct Inc<T>(T);
/// [`MetricEncoding`] for [`MeasuredCounterPairState`] that only writes the dec counter to the inner encoder.
pub struct Dec<T>(pub T);
struct Dec<T>(T);
impl<T: Encoding> Encoding for Inc<T> {
type Err = T::Err;

View File

@@ -8,13 +8,6 @@ license.workspace = true
thiserror.workspace = true
nix.workspace=true
workspace_hack = { version = "0.1", path = "../../workspace_hack" }
libc.workspace = true
lock_api.workspace = true
rustc-hash.workspace = true
[target.'cfg(target_os = "macos")'.dependencies]
tempfile = "3.14.0"
[dev-dependencies]
rand = "0.9"
rand_distr = "0.5.1"

View File

@@ -1,583 +0,0 @@
//! Resizable hash table implementation on top of byte-level storage (either a [`ShmemHandle`] or a fixed byte array).
//!
//! This hash table has two major components: the bucket array and the dictionary. Each bucket within the
//! bucket array contains an `Option<(K, V)>` and an index of another bucket. In this way there is both an
//! implicit freelist within the bucket array (`None` buckets point to other `None` entries) and various hash
//! chains within the bucket array (a `Some` bucket will point to other `Some` buckets that had the same hash).
//!
//! Buckets are never moved unless they are within a region that is being shrunk, and so the actual hash-
//! dependent component is done with the dictionary. When a new key is inserted into the map, a position
//! within the dictionary is decided based on its hash, the data is inserted into an empty bucket based
//! off of the freelist, and then the index of said bucket is placed in the dictionary.
//!
//! This map is resizable (if initialized on top of a [`ShmemHandle`]). Both growing and shrinking happen
//! in-place and are at a high level achieved by expanding/reducing the bucket array and rebuilding the
//! dictionary by rehashing all keys.
//!
//! Concurrency is managed very simply: the entire map is guarded by one shared-memory RwLock.
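// A minimal usage sketch against the API below (sizes are arbitrary):
//
//     let map = HashMapInit::<u64, u64>::new_resizeable_named(1024, 4096, "example")
//         .attach_writer();
//     assert_eq!(map.insert(1, 10), Ok(None));      // key was absent
//     assert_eq!(map.get(&1).as_deref().copied(), Some(10));
//     map.grow(2048).unwrap();                      // in-place growth + rehash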
use std::hash::{BuildHasher, Hash};
use std::mem::MaybeUninit;
use crate::shmem::ShmemHandle;
use crate::{shmem, sync::*};
mod core;
pub mod entry;
#[cfg(test)]
mod tests;
use core::{Bucket, CoreHashMap, INVALID_POS};
use entry::{Entry, OccupiedEntry, PrevPos, VacantEntry};
use thiserror::Error;
/// Error type for a hashmap shrink operation.
#[derive(Error, Debug)]
pub enum HashMapShrinkError {
/// There was an error encountered while resizing the memory area.
#[error("shmem resize failed: {0}")]
ResizeError(shmem::Error),
/// Occupied entries in to-be-shrunk space were encountered beginning at the given index.
#[error("occupied entry in deallocated space found at {0}")]
RemainingEntries(usize),
}
/// This represents a hash table that (possibly) lives in shared memory.
/// If a new process is launched with fork(), the child process inherits
/// this struct.
#[must_use]
pub struct HashMapInit<'a, K, V, S = rustc_hash::FxBuildHasher> {
shmem_handle: Option<ShmemHandle>,
shared_ptr: *mut HashMapShared<'a, K, V>,
shared_size: usize,
hasher: S,
num_buckets: u32,
}
/// This is a per-process handle to a hash table that (possibly) lives in shared memory.
/// If a child process is launched with fork(), the child process should
/// get its own HashMapAccess by calling HashMapInit::attach_writer/reader().
///
/// XXX: We're not making use of it at the moment, but this struct could
/// hold process-local information in the future.
pub struct HashMapAccess<'a, K, V, S = rustc_hash::FxBuildHasher> {
shmem_handle: Option<ShmemHandle>,
shared_ptr: *mut HashMapShared<'a, K, V>,
hasher: S,
}
unsafe impl<K: Sync, V: Sync, S> Sync for HashMapAccess<'_, K, V, S> {}
unsafe impl<K: Send, V: Send, S> Send for HashMapAccess<'_, K, V, S> {}
impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
/// Change the 'hasher' used by the hash table.
///
/// NOTE: This must be called right after creating the hash table,
/// before inserting any entries and before calling attach_writer/reader.
/// Otherwise different accessors could be using different hash functions,
/// with confusing results.
pub fn with_hasher<T: BuildHasher>(self, hasher: T) -> HashMapInit<'a, K, V, T> {
HashMapInit {
hasher,
shmem_handle: self.shmem_handle,
shared_ptr: self.shared_ptr,
shared_size: self.shared_size,
num_buckets: self.num_buckets,
}
}
/// Loosely (over)estimate the size needed to store a hash table with `num_buckets` buckets.
pub fn estimate_size(num_buckets: u32) -> usize {
// add some margin to cover alignment etc.
CoreHashMap::<K, V>::estimate_size(num_buckets) + size_of::<HashMapShared<K, V>>() + 1000
}
fn new(
num_buckets: u32,
shmem_handle: Option<ShmemHandle>,
area_ptr: *mut u8,
area_size: usize,
hasher: S,
) -> Self {
let mut ptr: *mut u8 = area_ptr;
let end_ptr: *mut u8 = unsafe { ptr.add(area_size) };
// carve out area for the One Big Lock (TM) and the HashMapShared.
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<libc::pthread_rwlock_t>())) };
let raw_lock_ptr = ptr;
ptr = unsafe { ptr.add(size_of::<libc::pthread_rwlock_t>()) };
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<HashMapShared<K, V>>())) };
let shared_ptr: *mut HashMapShared<K, V> = ptr.cast();
ptr = unsafe { ptr.add(size_of::<HashMapShared<K, V>>()) };
// carve out the buckets
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<core::Bucket<K, V>>())) };
let buckets_ptr = ptr;
ptr = unsafe { ptr.add(size_of::<core::Bucket<K, V>>() * num_buckets as usize) };
// use remaining space for the dictionary
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<u32>())) };
assert!(ptr.addr() < end_ptr.addr());
let dictionary_ptr = ptr;
let dictionary_size = unsafe { end_ptr.byte_offset_from(ptr) / size_of::<u32>() as isize };
assert!(dictionary_size > 0);
let buckets =
unsafe { std::slice::from_raw_parts_mut(buckets_ptr.cast(), num_buckets as usize) };
let dictionary = unsafe {
std::slice::from_raw_parts_mut(dictionary_ptr.cast(), dictionary_size as usize)
};
let hashmap = CoreHashMap::new(buckets, dictionary);
unsafe {
let lock = RwLock::from_raw(PthreadRwLock::new(raw_lock_ptr.cast()), hashmap);
std::ptr::write(shared_ptr, lock);
}
Self {
num_buckets,
shmem_handle,
shared_ptr,
shared_size: area_size,
hasher,
}
}
/// Attach to a hash table for writing.
pub fn attach_writer(self) -> HashMapAccess<'a, K, V, S> {
HashMapAccess {
shmem_handle: self.shmem_handle,
shared_ptr: self.shared_ptr,
hasher: self.hasher,
}
}
/// Initialize a table for reading. Currently identical to [`HashMapInit::attach_writer`].
///
/// This is a holdover from a previous implementation and is being kept around for
/// backwards compatibility reasons.
pub fn attach_reader(self) -> HashMapAccess<'a, K, V, S> {
self.attach_writer()
}
}
/// Hash table data that is actually stored in the shared memory area.
///
/// NOTE: We carve out the parts from a contiguous chunk. Growing and shrinking the hash table
/// relies on the memory layout! The data structures are laid out in the contiguous shared memory
/// area as follows:
///
/// [`libc::pthread_rwlock_t`]
/// [`HashMapShared`]
/// buckets
/// dictionary
///
/// In between the above parts, there can be padding bytes to align the parts correctly.
type HashMapShared<'a, K, V> = RwLock<CoreHashMap<'a, K, V>>;
impl<'a, K, V> HashMapInit<'a, K, V, rustc_hash::FxBuildHasher>
where
K: Clone + Hash + Eq,
{
/// Place the hash table within a user-supplied fixed memory area.
pub fn with_fixed(num_buckets: u32, area: &'a mut [MaybeUninit<u8>]) -> Self {
Self::new(
num_buckets,
None,
area.as_mut_ptr().cast(),
area.len(),
rustc_hash::FxBuildHasher,
)
}
/// Place a new hash map in the given shared memory area
///
/// # Panics
/// Will panic on failure to resize area to expected map size.
pub fn with_shmem(num_buckets: u32, shmem: ShmemHandle) -> Self {
let size = Self::estimate_size(num_buckets);
shmem
.set_size(size)
.expect("could not resize shared memory area");
let ptr = shmem.data_ptr.as_ptr().cast();
Self::new(
num_buckets,
Some(shmem),
ptr,
size,
rustc_hash::FxBuildHasher,
)
}
/// Make a resizable hash map within a new shared memory area with the given name.
pub fn new_resizeable_named(num_buckets: u32, max_buckets: u32, name: &str) -> Self {
let size = Self::estimate_size(num_buckets);
let max_size = Self::estimate_size(max_buckets);
let shmem =
ShmemHandle::new(name, size, max_size).expect("failed to make shared memory area");
let ptr = shmem.data_ptr.as_ptr().cast();
Self::new(
num_buckets,
Some(shmem),
ptr,
size,
rustc_hash::FxBuildHasher,
)
}
/// Make a resizable hash map within a new anonymous shared memory area.
pub fn new_resizeable(num_buckets: u32, max_buckets: u32) -> Self {
use std::sync::atomic::{AtomicUsize, Ordering};
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let val = COUNTER.fetch_add(1, Ordering::Relaxed);
let name = format!("neon_shmem_hmap{val}");
Self::new_resizeable_named(num_buckets, max_buckets, &name)
}
}
impl<'a, K, V, S: BuildHasher> HashMapAccess<'a, K, V, S>
where
K: Clone + Hash + Eq,
{
/// Hash a key using the map's hasher.
#[inline]
fn get_hash_value(&self, key: &K) -> u64 {
self.hasher.hash_one(key)
}
fn entry_with_hash(&self, key: K, hash: u64) -> Entry<'a, '_, K, V> {
let mut map = unsafe { self.shared_ptr.as_ref() }.unwrap().write();
let dict_pos = hash as usize % map.dictionary.len();
let first = map.dictionary[dict_pos];
if first == INVALID_POS {
// no existing entry
return Entry::Vacant(VacantEntry {
map,
key,
dict_pos: dict_pos as u32,
});
}
let mut prev_pos = PrevPos::First(dict_pos as u32);
let mut next = first;
loop {
let bucket = &mut map.buckets[next as usize];
let (bucket_key, _bucket_value) = bucket.inner.as_mut().expect("entry is in use");
if *bucket_key == key {
// found existing entry
return Entry::Occupied(OccupiedEntry {
map,
_key: key,
prev_pos,
bucket_pos: next,
});
}
if bucket.next == INVALID_POS {
// No existing entry
return Entry::Vacant(VacantEntry {
map,
key,
dict_pos: dict_pos as u32,
});
}
prev_pos = PrevPos::Chained(next);
next = bucket.next;
}
}
/// Get a reference to the corresponding value for a key.
pub fn get<'e>(&'e self, key: &K) -> Option<ValueReadGuard<'e, V>> {
let hash = self.get_hash_value(key);
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
RwLockReadGuard::try_map(map, |m| m.get_with_hash(key, hash)).ok()
}
/// Get a reference to the entry containing a key.
///
/// NB: This takes a write lock, as there is no way to know in advance whether
/// the entry will be used for reading or for writing.
pub fn entry(&self, key: K) -> Entry<'a, '_, K, V> {
let hash = self.get_hash_value(&key);
self.entry_with_hash(key, hash)
}
/// Remove a key from the map. Returns the associated value if it existed.
pub fn remove(&self, key: &K) -> Option<V> {
let hash = self.get_hash_value(key);
match self.entry_with_hash(key.clone(), hash) {
Entry::Occupied(e) => Some(e.remove()),
Entry::Vacant(_) => None,
}
}
/// Insert/update a key. Returns the previous associated value if it existed.
///
/// # Errors
/// Will return [`core::FullError`] if there is no more space left in the map.
pub fn insert(&self, key: K, value: V) -> Result<Option<V>, core::FullError> {
let hash = self.get_hash_value(&key);
match self.entry_with_hash(key.clone(), hash) {
Entry::Occupied(mut e) => Ok(Some(e.insert(value))),
Entry::Vacant(e) => {
_ = e.insert(value)?;
Ok(None)
}
}
}
/// Optionally return the entry for a bucket at a given index if it exists.
///
/// Has more overhead than one might expect: it clones the key (because the
/// [`OccupiedEntry`] type owns its key) and also hashes it, so that the hash
/// chain can be repaired if the entry is removed.
pub fn entry_at_bucket(&self, pos: usize) -> Option<OccupiedEntry<'a, '_, K, V>> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
if pos >= map.buckets.len() {
return None;
}
let entry = map.buckets[pos].inner.as_ref();
match entry {
Some((key, _)) => Some(OccupiedEntry {
_key: key.clone(),
bucket_pos: pos as u32,
prev_pos: entry::PrevPos::Unknown(self.get_hash_value(key)),
map,
}),
_ => None,
}
}
/// Returns the number of buckets in the table.
pub fn get_num_buckets(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
map.get_num_buckets()
}
/// Return the key and value stored in the bucket with the given index. This can be used to
/// iterate through the hash map.
// TODO: An Iterator might be nicer. The communicator's clock algorithm needs to
// _slowly_ iterate through all buckets with its clock hand, without holding a lock.
// If we switch to an Iterator, it must not hold the lock.
pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<(K, V)>> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
if pos >= map.buckets.len() {
return None;
}
RwLockReadGuard::try_map(map, |m| m.buckets[pos].inner.as_ref()).ok()
}
/// Returns the index of the bucket a given value corresponds to.
pub fn get_bucket_for_value(&self, val_ptr: *const V) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
let origin = map.buckets.as_ptr();
let idx = (val_ptr as usize - origin as usize) / size_of::<Bucket<K, V>>();
assert!(idx < map.buckets.len());
idx
}
/// Returns the number of occupied buckets in the table.
pub fn get_num_buckets_in_use(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
map.buckets_in_use as usize
}
/// Clears all entries in a table. Does not reset any shrinking operations.
pub fn clear(&self) {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
map.clear();
}
/// Perform an in-place rehash of some region (0..`rehash_buckets`) of the table and reset
/// the `buckets` and `dictionary` slices to be as long as `num_buckets`. Resets the freelist
/// in the process.
fn rehash_dict(
&self,
inner: &mut CoreHashMap<'a, K, V>,
buckets_ptr: *mut core::Bucket<K, V>,
end_ptr: *mut u8,
num_buckets: u32,
rehash_buckets: u32,
) {
inner.free_head = INVALID_POS;
let buckets;
let dictionary;
unsafe {
let buckets_end_ptr = buckets_ptr.add(num_buckets as usize);
let dictionary_ptr: *mut u32 = buckets_end_ptr
.byte_add(buckets_end_ptr.align_offset(align_of::<u32>()))
.cast();
let dictionary_size: usize =
end_ptr.byte_offset_from(buckets_end_ptr) as usize / size_of::<u32>();
buckets = std::slice::from_raw_parts_mut(buckets_ptr, num_buckets as usize);
dictionary = std::slice::from_raw_parts_mut(dictionary_ptr, dictionary_size);
}
for e in dictionary.iter_mut() {
*e = INVALID_POS;
}
for (i, bucket) in buckets.iter_mut().enumerate().take(rehash_buckets as usize) {
if bucket.inner.is_none() {
bucket.next = inner.free_head;
inner.free_head = i as u32;
continue;
}
let hash = self.hasher.hash_one(&bucket.inner.as_ref().unwrap().0);
let pos: usize = (hash % dictionary.len() as u64) as usize;
bucket.next = dictionary[pos];
dictionary[pos] = i as u32;
}
inner.dictionary = dictionary;
inner.buckets = buckets;
}
/// Rehash the map without growing or shrinking.
pub fn shuffle(&self) {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
let num_buckets = map.get_num_buckets() as u32;
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
let end_ptr: *mut u8 = unsafe { self.shared_ptr.byte_add(size_bytes).cast() };
let buckets_ptr = map.buckets.as_mut_ptr();
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets);
}
/// Grow the number of buckets within the table.
///
/// 1. Grows the underlying shared memory area
/// 2. Initializes new buckets and overwrites the current dictionary
/// 3. Rehashes the dictionary
///
/// # Panics
/// Panics if called on a map initialized with [`HashMapInit::with_fixed`].
///
/// # Errors
/// Returns an [`shmem::Error`] if any errors occur resizing the memory region.
pub fn grow(&self, num_buckets: u32) -> Result<(), shmem::Error> {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
let old_num_buckets = map.buckets.len() as u32;
assert!(
num_buckets >= old_num_buckets,
"grow called with a smaller number of buckets"
);
if num_buckets == old_num_buckets {
return Ok(());
}
let shmem_handle = self
.shmem_handle
.as_ref()
.expect("grow called on a fixed-size hash table");
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
shmem_handle.set_size(size_bytes)?;
let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
// Initialize new buckets. The new buckets are linked to the free list.
// NB: This overwrites the dictionary!
let buckets_ptr = map.buckets.as_mut_ptr();
unsafe {
for i in old_num_buckets..num_buckets {
let bucket = buckets_ptr.add(i as usize);
bucket.write(core::Bucket {
next: if i < num_buckets - 1 {
i + 1
} else {
map.free_head
},
inner: None,
});
}
}
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, old_num_buckets);
map.free_head = old_num_buckets;
Ok(())
}
/// Begin a shrink, limiting all new allocations to be in buckets with index below `num_buckets`.
///
/// # Panics
/// Panics if called on a map initialized with [`HashMapInit::with_fixed`] or if `num_buckets` is
/// greater than the number of buckets in the map.
pub fn begin_shrink(&mut self, num_buckets: u32) {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
assert!(
num_buckets <= map.get_num_buckets() as u32,
"shrink called with a larger number of buckets"
);
_ = self
.shmem_handle
.as_ref()
.expect("shrink called on a fixed-size hash table");
map.alloc_limit = num_buckets;
}
/// If a shrink operation is underway, returns the target size of the map. Otherwise, returns None.
pub fn shrink_goal(&self) -> Option<usize> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap().read();
let goal = map.alloc_limit;
if goal == INVALID_POS {
None
} else {
Some(goal as usize)
}
}
/// Complete a shrink after caller has evicted entries, removing the unused buckets and rehashing.
///
/// # Panics
/// The following cases result in a panic:
/// - Calling this function on a map initialized with [`HashMapInit::with_fixed`].
/// - Calling this function on a map when no shrink operation is in progress.
pub fn finish_shrink(&self) -> Result<(), HashMapShrinkError> {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
assert!(
map.alloc_limit != INVALID_POS,
"called finish_shrink when no shrink is in progress"
);
let num_buckets = map.alloc_limit;
if map.get_num_buckets() == num_buckets as usize {
return Ok(());
}
assert!(
map.buckets_in_use <= num_buckets,
"called finish_shrink before enough entries were removed"
);
for i in (num_buckets as usize)..map.buckets.len() {
if map.buckets[i].inner.is_some() {
return Err(HashMapShrinkError::RemainingEntries(i));
}
}
let shmem_handle = self
.shmem_handle
.as_ref()
.expect("shrink called on a fixed-size hash table");
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
if let Err(e) = shmem_handle.set_size(size_bytes) {
return Err(HashMapShrinkError::ResizeError(e));
}
let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
let buckets_ptr = map.buckets.as_mut_ptr();
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets);
map.alloc_limit = INVALID_POS;
Ok(())
}
}
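// Sketch of the shrink protocol (hypothetical sizes), mirroring `do_shrink` in
// the tests:
//
//     writer.begin_shrink(512);                  // cap new allocations below 512
//     for i in 512..writer.get_num_buckets() {
//         if let Some(e) = writer.entry_at_bucket(i) {
//             e.remove();                        // evict everything in the tail
//         }
//     }
//     writer.finish_shrink().unwrap();           // shrink the area and rehash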

View File

@@ -1,174 +0,0 @@
//! Simple hash table with chaining.
use std::hash::Hash;
use std::mem::MaybeUninit;
use crate::hash::entry::*;
/// Invalid position within the map (either within the dictionary or bucket array).
pub(crate) const INVALID_POS: u32 = u32::MAX;
/// Fundamental storage unit within the hash table. Either empty or contains a key-value pair.
/// Always part of a chain of some kind (either a freelist if empty or a hash chain if full).
pub(crate) struct Bucket<K, V> {
/// Index of next bucket in the chain.
pub(crate) next: u32,
/// Key-value pair contained within bucket.
pub(crate) inner: Option<(K, V)>,
}
/// Core hash table implementation.
pub(crate) struct CoreHashMap<'a, K, V> {
/// Dictionary used to map hashes to bucket indices.
pub(crate) dictionary: &'a mut [u32],
/// Buckets containing key-value pairs.
pub(crate) buckets: &'a mut [Bucket<K, V>],
/// Head of the freelist.
pub(crate) free_head: u32,
/// Maximum index of a bucket allowed to be allocated. [`INVALID_POS`] if no limit.
pub(crate) alloc_limit: u32,
/// The number of currently occupied buckets.
pub(crate) buckets_in_use: u32,
}
/// Error for when there are no empty buckets left but one is needed.
#[derive(Debug, PartialEq)]
pub struct FullError;
impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
const FILL_FACTOR: f32 = 0.60;
/// Estimate the size of the data contained within the hash map.
pub fn estimate_size(num_buckets: u32) -> usize {
let mut size = 0;
// buckets
size += size_of::<Bucket<K, V>>() * num_buckets as usize;
// dictionary
size += (f32::ceil((size_of::<u32>() * num_buckets as usize) as f32 / Self::FILL_FACTOR))
as usize;
size
}
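// Worked example: with 1000 buckets, the dictionary is sized at
// ceil(4 * 1000 / 0.60) = 6667 bytes, i.e. ~1666 u32 slots, or roughly 1.67
// dictionary slots per bucket, which keeps the expected hash chains short.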
pub fn new(
buckets: &'a mut [MaybeUninit<Bucket<K, V>>],
dictionary: &'a mut [MaybeUninit<u32>],
) -> Self {
// Initialize the buckets
for i in 0..buckets.len() {
buckets[i].write(Bucket {
next: if i < buckets.len() - 1 {
i as u32 + 1
} else {
INVALID_POS
},
inner: None,
});
}
// Initialize the dictionary
for e in dictionary.iter_mut() {
e.write(INVALID_POS);
}
// TODO: use std::slice::assume_init_mut() once it stabilizes
let buckets =
unsafe { std::slice::from_raw_parts_mut(buckets.as_mut_ptr().cast(), buckets.len()) };
let dictionary = unsafe {
std::slice::from_raw_parts_mut(dictionary.as_mut_ptr().cast(), dictionary.len())
};
Self {
dictionary,
buckets,
free_head: 0,
buckets_in_use: 0,
alloc_limit: INVALID_POS,
}
}
/// Get the value associated with a key (if it exists) given its hash.
pub fn get_with_hash(&self, key: &K, hash: u64) -> Option<&V> {
let mut next = self.dictionary[hash as usize % self.dictionary.len()];
loop {
if next == INVALID_POS {
return None;
}
let bucket = &self.buckets[next as usize];
let (bucket_key, bucket_value) = bucket.inner.as_ref().expect("entry is in use");
if bucket_key == key {
return Some(bucket_value);
}
next = bucket.next;
}
}
/// Get number of buckets in map.
pub fn get_num_buckets(&self) -> usize {
self.buckets.len()
}
/// Clears all entries from the hashmap.
///
/// Does not reset any allocation limits, but does clear any entries beyond them.
pub fn clear(&mut self) {
for i in 0..self.buckets.len() {
self.buckets[i] = Bucket {
next: if i < self.buckets.len() - 1 {
i as u32 + 1
} else {
INVALID_POS
},
inner: None,
}
}
for i in 0..self.dictionary.len() {
self.dictionary[i] = INVALID_POS;
}
self.free_head = 0;
self.buckets_in_use = 0;
}
/// Find the position of an unused bucket via the freelist and initialize it.
pub(crate) fn alloc_bucket(&mut self, key: K, value: V) -> Result<u32, FullError> {
let mut pos = self.free_head;
// Find the first bucket we're *allowed* to use.
let mut prev = PrevPos::First(self.free_head);
while pos != INVALID_POS && pos >= self.alloc_limit {
let bucket = &mut self.buckets[pos as usize];
prev = PrevPos::Chained(pos);
pos = bucket.next;
}
if pos == INVALID_POS {
return Err(FullError);
}
// Repair the freelist.
match prev {
PrevPos::First(_) => {
let next_pos = self.buckets[pos as usize].next;
self.free_head = next_pos;
}
PrevPos::Chained(p) => {
if p != INVALID_POS {
let next_pos = self.buckets[pos as usize].next;
self.buckets[p as usize].next = next_pos;
}
}
_ => unreachable!(),
}
// Initialize the bucket.
let bucket = &mut self.buckets[pos as usize];
self.buckets_in_use += 1;
bucket.next = INVALID_POS;
bucket.inner = Some((key, value));
Ok(pos)
}
}

View File

@@ -1,130 +0,0 @@
//! Equivalent of [`std::collections::hash_map::Entry`] for this hashmap.
use crate::hash::core::{CoreHashMap, FullError, INVALID_POS};
use crate::sync::{RwLockWriteGuard, ValueWriteGuard};
use std::hash::Hash;
use std::mem;
pub enum Entry<'a, 'b, K, V> {
Occupied(OccupiedEntry<'a, 'b, K, V>),
Vacant(VacantEntry<'a, 'b, K, V>),
}
/// Enum representing the previous position within a chain.
#[derive(Clone, Copy)]
pub(crate) enum PrevPos {
/// Starting index within the dictionary.
First(u32),
/// Regular index within the buckets.
Chained(u32),
/// Unknown - e.g. the associated entry was retrieved by index instead of chain.
Unknown(u64),
}
pub struct OccupiedEntry<'a, 'b, K, V> {
/// Mutable reference to the map containing this entry.
pub(crate) map: RwLockWriteGuard<'b, CoreHashMap<'a, K, V>>,
/// The key of the occupied entry
pub(crate) _key: K,
/// The index of the previous entry in the chain.
pub(crate) prev_pos: PrevPos,
/// The position of the bucket in the [`CoreHashMap`] bucket array.
pub(crate) bucket_pos: u32,
}
impl<K, V> OccupiedEntry<'_, '_, K, V> {
pub fn get(&self) -> &V {
&self.map.buckets[self.bucket_pos as usize]
.inner
.as_ref()
.unwrap()
.1
}
pub fn get_mut(&mut self) -> &mut V {
&mut self.map.buckets[self.bucket_pos as usize]
.inner
.as_mut()
.unwrap()
.1
}
/// Inserts a value into the entry, replacing (and returning) the existing value.
pub fn insert(&mut self, value: V) -> V {
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
// This assumes inner is Some, which it must be for an OccupiedEntry
mem::replace(&mut bucket.inner.as_mut().unwrap().1, value)
}
/// Removes the entry from the hash map, returning the value originally stored within it.
///
/// This may require multiple bucket accesses if the entry was obtained by index,
/// since the previous entry in the chain must first be located.
pub fn remove(mut self) -> V {
// If this bucket was queried by index, go ahead and follow its chain from the start.
let prev = if let PrevPos::Unknown(hash) = self.prev_pos {
let dict_idx = hash as usize % self.map.dictionary.len();
let mut prev = PrevPos::First(dict_idx as u32);
let mut curr = self.map.dictionary[dict_idx];
while curr != self.bucket_pos {
assert!(curr != INVALID_POS);
prev = PrevPos::Chained(curr);
curr = self.map.buckets[curr as usize].next;
}
prev
} else {
self.prev_pos
};
// The bucket's `inner` is known to be `Some` for an `OccupiedEntry`.
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
// unlink it from the chain
match prev {
PrevPos::First(dict_pos) => {
self.map.dictionary[dict_pos as usize] = bucket.next;
}
PrevPos::Chained(bucket_pos) => {
self.map.buckets[bucket_pos as usize].next = bucket.next;
}
_ => unreachable!(),
}
// and add it to the freelist
let free = self.map.free_head;
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
let old_value = bucket.inner.take();
bucket.next = free;
self.map.free_head = self.bucket_pos;
self.map.buckets_in_use -= 1;
old_value.unwrap().1
}
}
/// An abstract view into a vacant entry within the map.
pub struct VacantEntry<'a, 'b, K, V> {
/// Mutable reference to the map containing this entry.
pub(crate) map: RwLockWriteGuard<'b, CoreHashMap<'a, K, V>>,
/// The key to be inserted into this entry.
pub(crate) key: K,
/// The position within the dictionary corresponding to the key's hash.
pub(crate) dict_pos: u32,
}
impl<'b, K: Clone + Hash + Eq, V> VacantEntry<'_, 'b, K, V> {
/// Insert a value into the vacant entry, finding and populating an empty bucket in the process.
///
/// # Errors
/// Will return [`FullError`] if there are no unoccupied buckets in the map.
pub fn insert(mut self, value: V) -> Result<ValueWriteGuard<'b, V>, FullError> {
let pos = self.map.alloc_bucket(self.key, value)?;
self.map.buckets[pos as usize].next = self.map.dictionary[self.dict_pos as usize];
self.map.dictionary[self.dict_pos as usize] = pos;
Ok(RwLockWriteGuard::map(self.map, |m| {
&mut m.buckets[pos as usize].inner.as_mut().unwrap().1
}))
}
}

View File

@@ -1,428 +0,0 @@
use std::collections::BTreeMap;
use std::collections::HashSet;
use std::fmt::Debug;
use std::mem::MaybeUninit;
use crate::hash::Entry;
use crate::hash::HashMapAccess;
use crate::hash::HashMapInit;
use crate::hash::core::FullError;
use rand::seq::SliceRandom;
use rand::{Rng, RngCore};
use rand_distr::Zipf;
const TEST_KEY_LEN: usize = 16;
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct TestKey([u8; TEST_KEY_LEN]);
impl From<&TestKey> for u128 {
fn from(val: &TestKey) -> u128 {
u128::from_be_bytes(val.0)
}
}
impl From<u128> for TestKey {
fn from(val: u128) -> TestKey {
TestKey(val.to_be_bytes())
}
}
impl<'a> From<&'a [u8]> for TestKey {
fn from(bytes: &'a [u8]) -> TestKey {
TestKey(bytes.try_into().unwrap())
}
}
fn test_inserts<K: Into<TestKey> + Copy>(keys: &[K]) {
let w = HashMapInit::<TestKey, usize>::new_resizeable_named(100000, 120000, "test_inserts")
.attach_writer();
for (idx, k) in keys.iter().enumerate() {
let res = w.entry((*k).into());
match res {
Entry::Occupied(mut e) => {
e.insert(idx);
}
Entry::Vacant(e) => {
let res = e.insert(idx);
assert!(res.is_ok());
}
};
}
for (idx, k) in keys.iter().enumerate() {
let x = w.get(&(*k).into());
let value = x.as_deref().copied();
assert_eq!(value, Some(idx));
}
}
#[test]
fn dense() {
// This exercises splitting a node with prefix
let keys: &[u128] = &[0, 1, 2, 3, 256];
test_inserts(keys);
// Dense keys
let mut keys: Vec<u128> = (0..10000).collect();
test_inserts(&keys);
// Do the same in random orders
for _ in 1..10 {
keys.shuffle(&mut rand::rng());
test_inserts(&keys);
}
}
#[test]
fn sparse() {
// sparse keys
let mut keys: Vec<TestKey> = Vec::new();
let mut used_keys = HashSet::new();
for _ in 0..10000 {
loop {
let key = rand::random::<u128>();
if used_keys.contains(&key) {
continue;
}
used_keys.insert(key);
keys.push(key.into());
break;
}
}
test_inserts(&keys);
}
#[derive(Clone, Debug)]
struct TestOp(TestKey, Option<usize>);
fn apply_op(
op: &TestOp,
map: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
) {
// apply the change to the shadow tree first
let shadow_existing = if let Some(v) = op.1 {
shadow.insert(op.0, v)
} else {
shadow.remove(&op.0)
};
let entry = map.entry(op.0);
let hash_existing = match op.1 {
Some(new) => match entry {
Entry::Occupied(mut e) => Some(e.insert(new)),
Entry::Vacant(e) => {
_ = e.insert(new).unwrap();
None
}
},
None => match entry {
Entry::Occupied(e) => Some(e.remove()),
Entry::Vacant(_) => None,
},
};
assert_eq!(shadow_existing, hash_existing);
}
fn do_random_ops(
num_ops: usize,
size: u32,
del_prob: f64,
writer: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
rng: &mut rand::rngs::ThreadRng,
) {
for i in 0..num_ops {
let key: TestKey = ((rng.next_u32() % size) as u128).into();
let op = TestOp(
key,
if rng.random_bool(del_prob) {
Some(i)
} else {
None
},
);
apply_op(&op, writer, shadow);
}
}
fn do_deletes(
num_ops: usize,
writer: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
) {
for _ in 0..num_ops {
let (k, _) = shadow.pop_first().unwrap();
writer.remove(&k);
}
}
fn do_shrink(
writer: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
from: u32,
to: u32,
) {
assert!(writer.shrink_goal().is_none());
writer.begin_shrink(to);
assert_eq!(writer.shrink_goal(), Some(to as usize));
for i in to..from {
if let Some(entry) = writer.entry_at_bucket(i as usize) {
shadow.remove(&entry._key);
entry.remove();
}
}
let old_usage = writer.get_num_buckets_in_use();
writer.finish_shrink().unwrap();
assert!(writer.shrink_goal().is_none());
assert_eq!(writer.get_num_buckets_in_use(), old_usage);
}
#[test]
fn random_ops() {
let mut writer =
HashMapInit::<TestKey, usize>::new_resizeable_named(100000, 120000, "test_random")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let distribution = Zipf::new(u128::MAX as f64, 1.1).unwrap();
let mut rng = rand::rng();
for i in 0..100000 {
let key: TestKey = (rng.sample(distribution) as u128).into();
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &mut writer, &mut shadow);
}
}
#[test]
fn test_shuffle() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 1200, "test_shuf")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);
writer.shuffle();
do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);
}
#[test]
fn test_grow() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 2000, "test_grow")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);
let old_usage = writer.get_num_buckets_in_use();
writer.grow(1500).unwrap();
assert_eq!(writer.get_num_buckets_in_use(), old_usage);
assert_eq!(writer.get_num_buckets(), 1500);
do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
}
#[test]
fn test_clear() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_clear")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(2000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
writer.clear();
assert_eq!(writer.get_num_buckets_in_use(), 0);
assert_eq!(writer.get_num_buckets(), 1500);
while let Some((key, _)) = shadow.pop_first() {
assert!(writer.get(&key).is_none());
}
do_random_ops(2000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
for i in 0..(1500 - writer.get_num_buckets_in_use()) {
writer.insert((1500 + i as u128).into(), 0).unwrap();
}
assert_eq!(writer.insert(5000.into(), 0), Err(FullError {}));
writer.clear();
assert!(writer.insert(5000.into(), 0).is_ok());
}
#[test]
fn test_idx_remove() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_clear")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(2000, 1500, 0.25, &mut writer, &mut shadow, &mut rng);
for _ in 0..100 {
let idx = (rng.next_u32() % 1500) as usize;
if let Some(e) = writer.entry_at_bucket(idx) {
shadow.remove(&e._key);
e.remove();
}
}
while let Some((key, val)) = shadow.pop_first() {
assert_eq!(*writer.get(&key).unwrap(), val);
}
}
#[test]
fn test_idx_get() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_clear")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(2000, 1500, 0.25, &mut writer, &mut shadow, &mut rng);
for _ in 0..100 {
let idx = (rng.next_u32() % 1500) as usize;
if let Some(pair) = writer.get_at_bucket(idx) {
{
let v: *const usize = &pair.1;
assert_eq!(writer.get_bucket_for_value(v), idx);
}
{
let v: *const usize = &pair.1;
assert_eq!(writer.get_bucket_for_value(v), idx);
}
}
}
}
#[test]
fn test_shrink() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_shrink")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
do_shrink(&mut writer, &mut shadow, 1500, 1000);
assert_eq!(writer.get_num_buckets(), 1000);
do_deletes(500, &mut writer, &mut shadow);
do_random_ops(10000, 500, 0.75, &mut writer, &mut shadow, &mut rng);
assert!(writer.get_num_buckets_in_use() <= 1000);
}
#[test]
fn test_shrink_grow_seq() {
let mut writer =
HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 20000, "test_grow_seq")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(500, 1000, 0.1, &mut writer, &mut shadow, &mut rng);
eprintln!("Shrinking to 750");
do_shrink(&mut writer, &mut shadow, 1000, 750);
do_random_ops(200, 1000, 0.5, &mut writer, &mut shadow, &mut rng);
eprintln!("Growing to 1500");
writer.grow(1500).unwrap();
do_random_ops(600, 1500, 0.1, &mut writer, &mut shadow, &mut rng);
eprintln!("Shrinking to 200");
while shadow.len() > 100 {
do_deletes(1, &mut writer, &mut shadow);
}
do_shrink(&mut writer, &mut shadow, 1500, 200);
do_random_ops(50, 1500, 0.25, &mut writer, &mut shadow, &mut rng);
eprintln!("Growing to 10k");
writer.grow(10000).unwrap();
do_random_ops(10000, 5000, 0.25, &mut writer, &mut shadow, &mut rng);
}
#[test]
fn test_bucket_ops() {
let writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 1200, "test_bucket_ops")
.attach_writer();
match writer.entry(1.into()) {
Entry::Occupied(mut e) => {
e.insert(2);
}
Entry::Vacant(e) => {
_ = e.insert(2).unwrap();
}
}
assert_eq!(writer.get_num_buckets_in_use(), 1);
assert_eq!(writer.get_num_buckets(), 1000);
assert_eq!(*writer.get(&1.into()).unwrap(), 2);
let pos = match writer.entry(1.into()) {
Entry::Occupied(e) => {
assert_eq!(e._key, 1.into());
e.bucket_pos as usize
}
Entry::Vacant(_) => {
panic!("Insert didn't affect entry");
}
};
assert_eq!(writer.entry_at_bucket(pos).unwrap()._key, 1.into());
assert_eq!(*writer.get_at_bucket(pos).unwrap(), (1.into(), 2));
{
let ptr: *const usize = &*writer.get(&1.into()).unwrap();
assert_eq!(writer.get_bucket_for_value(ptr), pos);
}
writer.remove(&1.into());
assert!(writer.get(&1.into()).is_none());
}
#[test]
fn test_shrink_zero() {
let mut writer =
HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_shrink_zero")
.attach_writer();
writer.begin_shrink(0);
for i in 0..1500 {
writer.entry_at_bucket(i).map(|x| x.remove());
}
writer.finish_shrink().unwrap();
assert_eq!(writer.get_num_buckets_in_use(), 0);
let entry = writer.entry(1.into());
if let Entry::Vacant(v) = entry {
assert!(v.insert(2).is_err());
} else {
panic!("Somehow got non-vacant entry in empty map.")
}
writer.grow(50).unwrap();
let entry = writer.entry(1.into());
if let Entry::Vacant(v) = entry {
assert!(v.insert(2).is_ok());
} else {
panic!("Somehow got non-vacant entry in empty map.")
}
assert_eq!(writer.get_num_buckets_in_use(), 1);
}
#[test]
#[should_panic]
fn test_grow_oom() {
let writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_grow_oom")
.attach_writer();
writer.grow(20000).unwrap();
}
#[test]
#[should_panic]
fn test_shrink_bigger() {
let mut writer =
HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2500, "test_shrink_bigger")
.attach_writer();
writer.begin_shrink(2000);
}
#[test]
#[should_panic]
fn test_shrink_early_finish() {
let writer =
HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2500, "test_shrink_early_finish")
.attach_writer();
writer.finish_shrink().unwrap();
}
#[test]
#[should_panic]
fn test_shrink_fixed_size() {
let mut area = [MaybeUninit::uninit(); 10000];
let init_struct = HashMapInit::<TestKey, usize>::with_fixed(3, &mut area);
let mut writer = init_struct.attach_writer();
writer.begin_shrink(1);
}
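
The shrink tests above all follow the same three-step protocol; condensed into one sketch (same API as exercised by do_shrink, with hypothetical old_size/new_size bucket counts):

    // Condensed shrink protocol, as exercised by do_shrink() above.
    let (old_size, new_size) = (1500u32, 1000u32); // hypothetical bounds
    writer.begin_shrink(new_size);                 // 1. announce the goal
    for bucket in new_size..old_size {
        // 2. evacuate every occupied bucket at or above the new size
        if let Some(entry) = writer.entry_at_bucket(bucket as usize) {
            entry.remove();
        }
    }
    writer.finish_shrink().unwrap();               // 3. commit (errors if no shrink started)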

View File

@@ -1,3 +1,418 @@
pub mod hash;
pub mod shmem;
pub mod sync;
//! Shared memory utilities for neon communicator
use std::num::NonZeroUsize;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use nix::errno::Errno;
use nix::sys::mman::MapFlags;
use nix::sys::mman::ProtFlags;
use nix::sys::mman::mmap as nix_mmap;
use nix::sys::mman::munmap as nix_munmap;
use nix::unistd::ftruncate as nix_ftruncate;
/// ShmemHandle represents a shared memory area that can be shared by processes over fork().
/// Unlike shared memory allocated by Postgres, this area is resizable, up to 'max_size' that's
/// specified at creation.
///
/// The area is backed by an anonymous file created with memfd_create(). The full address space for
/// 'max_size' is reserved up-front with mmap(), but whenever you call [`ShmemHandle::set_size`],
/// the underlying file is resized. Do not access the area beyond the current size. Currently, that
/// will cause the file to be expanded, but we might use mprotect() etc. to enforce that in the
/// future.
pub struct ShmemHandle {
/// memfd file descriptor
fd: OwnedFd,
max_size: usize,
// Pointer to the beginning of the shared memory area. The header is stored there.
shared_ptr: NonNull<SharedStruct>,
// Pointer to the beginning of the user data
pub data_ptr: NonNull<u8>,
}
/// This is stored at the beginning in the shared memory area.
struct SharedStruct {
max_size: usize,
/// Current size of the backing file. The high-order bit is used for the RESIZE_IN_PROGRESS flag
current_size: AtomicUsize,
}
const RESIZE_IN_PROGRESS: usize = 1 << 63;
const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();
/// Error type returned by the ShmemHandle functions.
#[derive(thiserror::Error, Debug)]
#[error("{msg}: {errno}")]
pub struct Error {
pub msg: String,
pub errno: Errno,
}
impl Error {
fn new(msg: &str, errno: Errno) -> Error {
Error {
msg: msg.to_string(),
errno,
}
}
}
impl ShmemHandle {
/// Create a new shared memory area. To communicate between processes, the processes need to be
/// fork()'d after calling this, so that the ShmemHandle is inherited by all processes.
///
/// If the ShmemHandle is dropped, the memory is unmapped from the current process. Other
/// processes can continue using it, however.
pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<ShmemHandle, Error> {
// create the backing anonymous file.
let fd = create_backing_file(name)?;
Self::new_with_fd(fd, initial_size, max_size)
}
fn new_with_fd(
fd: OwnedFd,
initial_size: usize,
max_size: usize,
) -> Result<ShmemHandle, Error> {
// We reserve the high-order bit for the RESIZE_IN_PROGRESS flag, and the actual size
// is a little larger than this because of the SharedStruct header. Make the upper limit
// somewhat smaller than that, because with anything close to that, you'll run out of
// memory anyway.
if max_size >= 1 << 48 {
panic!("max size {max_size} too large");
}
if initial_size > max_size {
panic!("initial size {initial_size} larger than max size {max_size}");
}
// The actual initial / max size is the one given by the caller, plus the size of
// 'SharedStruct'.
let initial_size = HEADER_SIZE + initial_size;
let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();
// Reserve address space for it with mmap
//
// TODO: Use MAP_HUGETLB if possible
let start_ptr = unsafe {
nix_mmap(
None,
max_size,
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
MapFlags::MAP_SHARED,
&fd,
0,
)
}
.map_err(|e| Error::new("mmap failed: {e}", e))?;
// Reserve space for the initial size
enlarge_file(fd.as_fd(), initial_size as u64)?;
// Initialize the header
let shared: NonNull<SharedStruct> = start_ptr.cast();
unsafe {
shared.write(SharedStruct {
max_size: max_size.into(),
current_size: AtomicUsize::new(initial_size),
})
};
// The user data begins after the header
let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };
Ok(ShmemHandle {
fd,
max_size: max_size.into(),
shared_ptr: shared,
data_ptr,
})
}
// return reference to the header
fn shared(&self) -> &SharedStruct {
unsafe { self.shared_ptr.as_ref() }
}
/// Resize the shared memory area. 'new_size' must not be larger than the 'max_size' specified
/// when creating the area.
///
/// This may only be called from one process/thread at a time. Concurrent calls are
/// detected and return an Error.
pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
let new_size = new_size + HEADER_SIZE;
let shared = self.shared();
if new_size > self.max_size {
panic!(
"new size ({} is greater than max size ({})",
new_size, self.max_size
);
}
assert_eq!(self.max_size, shared.max_size);
// Lock the area by setting the bit in 'current_size'
//
// Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
// and the posix_fallocate/ftruncate call is surely a synchronization point anyway. But
// since this is not performance-critical, better safe than sorry .
let mut old_size = shared.current_size.load(Ordering::Acquire);
loop {
if (old_size & RESIZE_IN_PROGRESS) != 0 {
return Err(Error::new(
"concurrent resize detected",
Errno::UnknownErrno,
));
}
match shared.current_size.compare_exchange(
old_size,
new_size,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(x) => old_size = x,
}
}
// Ok, we got the lock.
//
// NB: If anything goes wrong, we *must* clear the bit!
let result = {
use std::cmp::Ordering::{Equal, Greater, Less};
match new_size.cmp(&old_size) {
Less => nix_ftruncate(&self.fd, new_size as i64).map_err(|e| {
Error::new("could not shrink shmem segment, ftruncate failed: {e}", e)
}),
Equal => Ok(()),
Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
}
};
// Unlock
shared.current_size.store(
if result.is_ok() { new_size } else { old_size },
Ordering::Release,
);
result
}
/// Returns the current user-visible size of the shared memory segment.
///
/// NOTE: a concurrent set_size() call can change the size at any time. It is the caller's
/// responsibility not to access the area beyond the current size.
pub fn current_size(&self) -> usize {
let total_current_size =
self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
total_current_size - HEADER_SIZE
}
}
impl Drop for ShmemHandle {
fn drop(&mut self) {
// SAFETY: The pointer was obtained from mmap() with the given size.
// We unmap the entire region.
let _ = unsafe { nix_munmap(self.shared_ptr.cast(), self.max_size) };
// The fd is dropped automatically by OwnedFd.
}
}
/// Create a "backing file" for the shared memory area. On Linux, use memfd_create(), to create an
/// anonymous in-memory file. One macos, fall back to a regular file. That's good enough for
/// development and testing, but in production we want the file to stay in memory.
///
/// disable 'unused_variables' warnings, because in the macos path, 'name' is unused.
#[allow(unused_variables)]
fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
#[cfg(not(target_os = "macos"))]
{
nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
.map_err(|e| Error::new("memfd_create failed: {e}", e))
}
#[cfg(target_os = "macos")]
{
let file = tempfile::tempfile().map_err(|e| {
Error::new(
"could not create temporary file to back shmem area: {e}",
nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
)
})?;
Ok(OwnedFd::from(file))
}
}
fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
// Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
// we don't get a segfault later when trying to actually use it.
#[cfg(not(target_os = "macos"))]
{
nix::fcntl::posix_fallocate(fd, 0, size as i64).map_err(|e| {
Error::new(
"could not grow shmem segment, posix_fallocate failed: {e}",
e,
)
})
}
// As a fallback on macos, which doesn't have posix_fallocate, use plain 'fallocate'
#[cfg(target_os = "macos")]
{
nix::unistd::ftruncate(fd, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, ftruncate failed: {e}", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
use nix::unistd::ForkResult;
use std::ops::Range;
/// Check that all bytes in the given range have the expected value.
fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
for i in range {
let b = unsafe { *(ptr.add(i)) };
assert_eq!(expected, b, "unexpected byte at offset {i}");
}
}
/// Write 'b' to all bytes in the given range
fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
}
// simple single-process test of growing and shrinking
#[test]
fn test_shmem_resize() -> Result<(), Error> {
let max_size = 1024 * 1024;
let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;
assert_eq!(init_struct.current_size(), 0);
// Initial grow
let size1 = 10000;
init_struct.set_size(size1).unwrap();
assert_eq!(init_struct.current_size(), size1);
// Write some data
let data_ptr = init_struct.data_ptr.as_ptr();
write_range(data_ptr, 0xAA, 0..size1);
assert_range(data_ptr, 0xAA, 0..size1);
// Shrink
let size2 = 5000;
init_struct.set_size(size2).unwrap();
assert_eq!(init_struct.current_size(), size2);
// Grow again
let size3 = 20000;
init_struct.set_size(size3).unwrap();
assert_eq!(init_struct.current_size(), size3);
// Try to read it. The area that was shrunk and grown again should read as all zeros now
assert_range(data_ptr, 0xAA, 0..5000);
assert_range(data_ptr, 0, 5000..size1);
// Try to grow beyond max_size
//let size4 = max_size + 1;
//assert!(init_struct.set_size(size4).is_err());
// Dropping init_struct should unmap the memory
drop(init_struct);
Ok(())
}
/// This is used in tests to coordinate between test processes. It's like std::sync::Barrier,
/// but is stored in the shared memory area and works across processes. It's implemented by
/// polling, because e.g. standard rust mutexes are not guaranteed to work across processes.
struct SimpleBarrier {
num_procs: usize,
count: AtomicUsize,
}
impl SimpleBarrier {
unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
unsafe {
*ptr = SimpleBarrier {
num_procs,
count: AtomicUsize::new(0),
}
}
}
pub fn wait(&self) {
let old = self.count.fetch_add(1, Ordering::Relaxed);
let generation = old / self.num_procs;
let mut current = old + 1;
while current < (generation + 1) * self.num_procs {
std::thread::sleep(std::time::Duration::from_millis(10));
current = self.count.load(Ordering::Relaxed);
}
}
}
#[test]
fn test_multi_process() {
// Initialize
let max_size = 1_000_000_000_000;
let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
let ptr = init_struct.data_ptr.as_ptr();
// Store the SimpleBarrier in the first 1k of the area.
init_struct.set_size(10000).unwrap();
let barrier_ptr: *mut SimpleBarrier = unsafe {
ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
.cast()
};
unsafe { SimpleBarrier::init(barrier_ptr, 2) };
let barrier = unsafe { barrier_ptr.as_ref().unwrap() };
// Fork another test process. The code after this runs in both processes concurrently.
let fork_result = unsafe { nix::unistd::fork().unwrap() };
// In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
if fork_result.is_parent() {
write_range(ptr, 0xAA, 1000..2000);
} else {
write_range(ptr, 0xBB, 2000..3000);
}
barrier.wait();
// Verify the contents. (in both processes)
assert_range(ptr, 0xAA, 1000..2000);
assert_range(ptr, 0xBB, 2000..3000);
// Grow, from the child this time
let size = 10_000_000;
if !fork_result.is_parent() {
init_struct.set_size(size).unwrap();
}
barrier.wait();
// make some writes at the end
if fork_result.is_parent() {
write_range(ptr, 0xAA, (size - 10)..size);
} else {
write_range(ptr, 0xBB, (size - 20)..(size - 10));
}
barrier.wait();
// Verify the contents. (This runs in both processes)
assert_range(ptr, 0, (size - 1000)..(size - 20));
assert_range(ptr, 0xBB, (size - 20)..(size - 10));
assert_range(ptr, 0xAA, (size - 10)..size);
if let ForkResult::Parent { child } = fork_result {
nix::sys::wait::waitpid(child, None).unwrap();
}
}
}
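
To summarize the lifecycle the doc comment and tests above describe, a minimal single-process sketch (sizes arbitrary, errors propagated with `?`):

    // Reserve 1 MiB of address space, back 4 KiB of it initially.
    let shmem = ShmemHandle::new("example", 4096, 1024 * 1024)?;
    assert_eq!(shmem.current_size(), 4096);

    // Grow: the backing memfd is enlarged (posix_fallocate on Linux).
    shmem.set_size(64 * 1024)?;

    // User data starts at data_ptr; stay within current_size().
    unsafe { std::ptr::write_bytes(shmem.data_ptr.as_ptr(), 0xAB, 64 * 1024) };

    // Dropping the handle unmaps the area in this process only.
    drop(shmem);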

View File

@@ -1,409 +0,0 @@
//! Dynamically resizable contiguous chunk of shared memory
use std::num::NonZeroUsize;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use nix::errno::Errno;
use nix::sys::mman::MapFlags;
use nix::sys::mman::ProtFlags;
use nix::sys::mman::mmap as nix_mmap;
use nix::sys::mman::munmap as nix_munmap;
use nix::unistd::ftruncate as nix_ftruncate;
/// `ShmemHandle` represents a shared memory area that can be shared by processes over `fork()`.
/// Unlike shared memory allocated by Postgres, this area is resizable, up to `max_size` that's
/// specified at creation.
///
/// The area is backed by an anonymous file created with `memfd_create()`. The full address space for
/// `max_size` is reserved up-front with `mmap()`, but whenever you call [`ShmemHandle::set_size`],
/// the underlying file is resized. Do not access the area beyond the current size. Currently, that
/// will cause the file to be expanded, but we might use `mprotect()` etc. to enforce that in the
/// future.
pub struct ShmemHandle {
/// memfd file descriptor
fd: OwnedFd,
max_size: usize,
// Pointer to the beginning of the shared memory area. The header is stored there.
shared_ptr: NonNull<SharedStruct>,
// Pointer to the beginning of the user data
pub data_ptr: NonNull<u8>,
}
/// This is stored at the beginning in the shared memory area.
struct SharedStruct {
max_size: usize,
/// Current size of the backing file. The high-order bit is used for the [`RESIZE_IN_PROGRESS`] flag.
current_size: AtomicUsize,
}
const RESIZE_IN_PROGRESS: usize = 1 << 63;
const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();
/// Error type returned by the [`ShmemHandle`] functions.
#[derive(thiserror::Error, Debug)]
#[error("{msg}: {errno}")]
pub struct Error {
pub msg: String,
pub errno: Errno,
}
impl Error {
fn new(msg: &str, errno: Errno) -> Self {
Self {
msg: msg.to_string(),
errno,
}
}
}
impl ShmemHandle {
/// Create a new shared memory area. To communicate between processes, the processes need to be
/// `fork()`'d after calling this, so that the `ShmemHandle` is inherited by all processes.
///
/// If the `ShmemHandle` is dropped, the memory is unmapped from the current process. Other
/// processes can continue using it, however.
pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<Self, Error> {
// create the backing anonymous file.
let fd = create_backing_file(name)?;
Self::new_with_fd(fd, initial_size, max_size)
}
fn new_with_fd(fd: OwnedFd, initial_size: usize, max_size: usize) -> Result<Self, Error> {
// We reserve the high-order bit for the `RESIZE_IN_PROGRESS` flag, and the actual size
// is a little larger than this because of the SharedStruct header. Make the upper limit
// somewhat smaller than that, because with anything close to that, you'll run out of
// memory anyway.
assert!(max_size < 1 << 48, "max size {max_size} too large");
assert!(
initial_size <= max_size,
"initial size {initial_size} larger than max size {max_size}"
);
// The actual initial / max size is the one given by the caller, plus the size of
// 'SharedStruct'.
let initial_size = HEADER_SIZE + initial_size;
let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();
// Reserve address space for it with mmap
//
// TODO: Use MAP_HUGETLB if possible
let start_ptr = unsafe {
nix_mmap(
None,
max_size,
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
MapFlags::MAP_SHARED,
&fd,
0,
)
}
.map_err(|e| Error::new("mmap failed", e))?;
// Reserve space for the initial size
enlarge_file(fd.as_fd(), initial_size as u64)?;
// Initialize the header
let shared: NonNull<SharedStruct> = start_ptr.cast();
unsafe {
shared.write(SharedStruct {
max_size: max_size.into(),
current_size: AtomicUsize::new(initial_size),
});
}
// The user data begins after the header
let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };
Ok(Self {
fd,
max_size: max_size.into(),
shared_ptr: shared,
data_ptr,
})
}
// return reference to the header
fn shared(&self) -> &SharedStruct {
unsafe { self.shared_ptr.as_ref() }
}
/// Resize the shared memory area. `new_size` must not be larger than the `max_size` specified
/// when creating the area.
///
/// This may only be called from one process/thread at a time. Concurrent calls are
/// detected and return an [`shmem::Error`](Error).
pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
let new_size = new_size + HEADER_SIZE;
let shared = self.shared();
assert!(
new_size <= self.max_size,
"new size ({new_size}) is greater than max size ({})",
self.max_size
);
assert_eq!(self.max_size, shared.max_size);
// Lock the area by setting the bit in `current_size`
//
// Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
// and the `posix_fallocate`/`ftruncate` call is surely a synchronization point anyway. But
// since this is not performance-critical, better safe than sorry.
let mut old_size = shared.current_size.load(Ordering::Acquire);
loop {
if (old_size & RESIZE_IN_PROGRESS) != 0 {
return Err(Error::new(
"concurrent resize detected",
Errno::UnknownErrno,
));
}
match shared.current_size.compare_exchange(
old_size,
new_size,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(x) => old_size = x,
}
}
// Ok, we got the lock.
//
// NB: If anything goes wrong, we *must* clear the bit!
let result = {
use std::cmp::Ordering::{Equal, Greater, Less};
match new_size.cmp(&old_size) {
Less => nix_ftruncate(&self.fd, new_size as i64)
.map_err(|e| Error::new("could not shrink shmem segment, ftruncate failed", e)),
Equal => Ok(()),
Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
}
};
// Unlock
shared.current_size.store(
if result.is_ok() { new_size } else { old_size },
Ordering::Release,
);
result
}
/// Returns the current user-visible size of the shared memory segment.
///
/// NOTE: a concurrent [`ShmemHandle::set_size()`] call can change the size at any time.
/// It is the caller's responsibility not to access the area beyond the current size.
pub fn current_size(&self) -> usize {
let total_current_size =
self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
total_current_size - HEADER_SIZE
}
}
impl Drop for ShmemHandle {
fn drop(&mut self) {
// SAFETY: The pointer was obtained from mmap() with the given size.
// We unmap the entire region.
let _ = unsafe { nix_munmap(self.shared_ptr.cast(), self.max_size) };
// The fd is dropped automatically by OwnedFd.
}
}
/// Create a "backing file" for the shared memory area. On Linux, use `memfd_create()`, to create an
/// anonymous in-memory file. One macos, fall back to a regular file. That's good enough for
/// development and testing, but in production we want the file to stay in memory.
///
/// Disable unused variables warnings because `name` is unused in the macos path.
#[allow(unused_variables)]
fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
#[cfg(not(target_os = "macos"))]
{
nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
.map_err(|e| Error::new("memfd_create failed", e))
}
#[cfg(target_os = "macos")]
{
let file = tempfile::tempfile().map_err(|e| {
Error::new(
"could not create temporary file to back shmem area",
nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
)
})?;
Ok(OwnedFd::from(file))
}
}
fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
// Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
// we don't get a segfault later when trying to actually use it.
#[cfg(not(target_os = "macos"))]
{
nix::fcntl::posix_fallocate(fd, 0, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, posix_fallocate failed", e))
}
// As a fallback on macos, which doesn't have posix_fallocate, use plain 'fallocate'
#[cfg(target_os = "macos")]
{
nix::unistd::ftruncate(fd, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, ftruncate failed", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
use nix::unistd::ForkResult;
use std::ops::Range;
/// Check that all bytes in the given range have the expected value.
fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
for i in range {
let b = unsafe { *(ptr.add(i)) };
assert_eq!(expected, b, "unexpected byte at offset {i}");
}
}
/// Write 'b' to all bytes in the given range
fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
}
// simple single-process test of growing and shrinking
#[test]
fn test_shmem_resize() -> Result<(), Error> {
let max_size = 1024 * 1024;
let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;
assert_eq!(init_struct.current_size(), 0);
// Initial grow
let size1 = 10000;
init_struct.set_size(size1).unwrap();
assert_eq!(init_struct.current_size(), size1);
// Write some data
let data_ptr = init_struct.data_ptr.as_ptr();
write_range(data_ptr, 0xAA, 0..size1);
assert_range(data_ptr, 0xAA, 0..size1);
// Shrink
let size2 = 5000;
init_struct.set_size(size2).unwrap();
assert_eq!(init_struct.current_size(), size2);
// Grow again
let size3 = 20000;
init_struct.set_size(size3).unwrap();
assert_eq!(init_struct.current_size(), size3);
// Try to read it. The area that was shrunk and grown again should read as all zeros now
assert_range(data_ptr, 0xAA, 0..5000);
assert_range(data_ptr, 0, 5000..size1);
// Try to grow beyond max_size
//let size4 = max_size + 1;
//assert!(init_struct.set_size(size4).is_err());
// Dropping init_struct should unmap the memory
drop(init_struct);
Ok(())
}
/// This is used in tests to coordinate between test processes. It's like `std::sync::Barrier`,
/// but is stored in the shared memory area and works across processes. It's implemented by
/// polling, because e.g. standard rust mutexes are not guaranteed to work across processes.
struct SimpleBarrier {
num_procs: usize,
count: AtomicUsize,
}
impl SimpleBarrier {
unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
unsafe {
*ptr = SimpleBarrier {
num_procs,
count: AtomicUsize::new(0),
}
}
}
pub fn wait(&self) {
let old = self.count.fetch_add(1, Ordering::Relaxed);
let generation = old / self.num_procs;
let mut current = old + 1;
while current < (generation + 1) * self.num_procs {
std::thread::sleep(std::time::Duration::from_millis(10));
current = self.count.load(Ordering::Relaxed);
}
}
}
#[test]
fn test_multi_process() {
// Initialize
let max_size = 1_000_000_000_000;
let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
let ptr = init_struct.data_ptr.as_ptr();
// Store the SimpleBarrier in the first 1k of the area.
init_struct.set_size(10000).unwrap();
let barrier_ptr: *mut SimpleBarrier = unsafe {
ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
.cast()
};
unsafe { SimpleBarrier::init(barrier_ptr, 2) };
let barrier = unsafe { barrier_ptr.as_ref().unwrap() };
// Fork another test process. The code after this runs in both processes concurrently.
let fork_result = unsafe { nix::unistd::fork().unwrap() };
// In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
if fork_result.is_parent() {
write_range(ptr, 0xAA, 1000..2000);
} else {
write_range(ptr, 0xBB, 2000..3000);
}
barrier.wait();
// Verify the contents. (in both processes)
assert_range(ptr, 0xAA, 1000..2000);
assert_range(ptr, 0xBB, 2000..3000);
// Grow, from the child this time
let size = 10_000_000;
if !fork_result.is_parent() {
init_struct.set_size(size).unwrap();
}
barrier.wait();
// make some writes at the end
if fork_result.is_parent() {
write_range(ptr, 0xAA, (size - 10)..size);
} else {
write_range(ptr, 0xBB, (size - 20)..(size - 10));
}
barrier.wait();
// Verify the contents. (This runs in both processes)
assert_range(ptr, 0, (size - 1000)..(size - 20));
assert_range(ptr, 0xBB, (size - 20)..(size - 10));
assert_range(ptr, 0xAA, (size - 10)..size);
if let ForkResult::Parent { child } = fork_result {
nix::sys::wait::waitpid(child, None).unwrap();
}
}
}
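
One detail of set_size above deserves a spelled-out sketch: the high bit of current_size serves as a single-writer busy flag. The variant below holds the bit explicitly while the file operation runs (the shown code installs the new size directly in the CAS); it is a simplified illustration, not the shipped code:

    use std::sync::atomic::{AtomicUsize, Ordering};

    const RESIZE_IN_PROGRESS: usize = 1 << 63;

    /// Returns Err if another resize is already in flight.
    fn resize(current_size: &AtomicUsize, new_size: usize) -> Result<(), &'static str> {
        let mut old = current_size.load(Ordering::Acquire);
        loop {
            if old & RESIZE_IN_PROGRESS != 0 {
                return Err("concurrent resize detected");
            }
            match current_size.compare_exchange(
                old,
                old | RESIZE_IN_PROGRESS,
                Ordering::Acquire,
                Ordering::Relaxed,
            ) {
                Ok(_) => break,
                Err(seen) => old = seen,
            }
        }
        // ... the fallible ftruncate()/posix_fallocate() work happens here ...
        let file_op_succeeded = true; // stand-in for the real result
        // Unlock: publish the new size on success, restore the old one on
        // failure. Either way, the flag bit must not be left set.
        current_size.store(
            if file_op_succeeded { new_size } else { old },
            Ordering::Release,
        );
        Ok(())
    }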

View File

@@ -1,111 +0,0 @@
//! Simple utilities akin to what's in [`std::sync`] but designed to work with shared memory.
use std::mem::MaybeUninit;
use std::ptr::NonNull;
use nix::errno::Errno;
pub type RwLock<T> = lock_api::RwLock<PthreadRwLock, T>;
pub type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, PthreadRwLock, T>;
pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, PthreadRwLock, T>;
pub type ValueReadGuard<'a, T> = lock_api::MappedRwLockReadGuard<'a, PthreadRwLock, T>;
pub type ValueWriteGuard<'a, T> = lock_api::MappedRwLockWriteGuard<'a, PthreadRwLock, T>;
/// Shared memory read-write lock.
pub struct PthreadRwLock(Option<NonNull<libc::pthread_rwlock_t>>);
/// Simple macro that calls a function in the libc namespace and panics if the return value is nonzero.
macro_rules! libc_checked {
($fn_name:ident ( $($arg:expr),* )) => {{
let res = libc::$fn_name($($arg),*);
if res != 0 {
panic!("{} failed with {}", stringify!($fn_name), Errno::from_raw(res));
}
}};
}
impl PthreadRwLock {
/// Creates a new `PthreadRwLock` on top of a pointer to a pthread rwlock.
///
/// # Safety
/// `lock` must be non-null. Every unsafe operation will panic in the event of an error.
pub unsafe fn new(lock: *mut libc::pthread_rwlock_t) -> Self {
unsafe {
let mut attrs = MaybeUninit::uninit();
libc_checked!(pthread_rwlockattr_init(attrs.as_mut_ptr()));
libc_checked!(pthread_rwlockattr_setpshared(
attrs.as_mut_ptr(),
libc::PTHREAD_PROCESS_SHARED
));
libc_checked!(pthread_rwlock_init(lock, attrs.as_mut_ptr()));
// Safety: POSIX specifies that "any function affecting the attributes
// object (including destruction) shall not affect any previously
// initialized read-write locks".
libc_checked!(pthread_rwlockattr_destroy(attrs.as_mut_ptr()));
Self(Some(NonNull::new_unchecked(lock)))
}
}
fn inner(&self) -> NonNull<libc::pthread_rwlock_t> {
match self.0 {
None => {
panic!("PthreadRwLock constructed badly - something likely used RawRwLock::INIT")
}
Some(x) => x,
}
}
}
unsafe impl lock_api::RawRwLock for PthreadRwLock {
type GuardMarker = lock_api::GuardSend;
const INIT: Self = Self(None);
fn try_lock_shared(&self) -> bool {
unsafe {
let res = libc::pthread_rwlock_tryrdlock(self.inner().as_ptr());
match res {
0 => true,
libc::EAGAIN => false,
_ => panic!(
"pthread_rwlock_tryrdlock failed with {}",
Errno::from_raw(res)
),
}
}
}
fn try_lock_exclusive(&self) -> bool {
unsafe {
let res = libc::pthread_rwlock_trywrlock(self.inner().as_ptr());
match res {
0 => true,
libc::EAGAIN => false,
_ => panic!("try_wrlock failed with {}", Errno::from_raw(res)),
}
}
}
fn lock_shared(&self) {
unsafe {
libc_checked!(pthread_rwlock_rdlock(self.inner().as_ptr()));
}
}
fn lock_exclusive(&self) {
unsafe {
libc_checked!(pthread_rwlock_wrlock(self.inner().as_ptr()));
}
}
unsafe fn unlock_exclusive(&self) {
unsafe {
libc_checked!(pthread_rwlock_unlock(self.inner().as_ptr()));
}
}
unsafe fn unlock_shared(&self) {
unsafe {
libc_checked!(pthread_rwlock_unlock(self.inner().as_ptr()));
}
}
}
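
The attribute dance in PthreadRwLock::new is the standard recipe for a rwlock that works across fork(). A freestanding sketch of the same libc calls, with error handling reduced to asserts (hypothetical helper, not part of the patch):

    use std::mem::MaybeUninit;

    // `lock` must point into shared memory that survives fork() for
    // cross-process use, mirroring PthreadRwLock::new above.
    unsafe fn init_shared_rwlock(lock: *mut libc::pthread_rwlock_t) {
        let mut attrs = MaybeUninit::<libc::pthread_rwlockattr_t>::uninit();
        assert_eq!(libc::pthread_rwlockattr_init(attrs.as_mut_ptr()), 0);
        // PTHREAD_PROCESS_SHARED is what makes the lock usable from any
        // process mapping the memory, not just threads of one process.
        assert_eq!(
            libc::pthread_rwlockattr_setpshared(attrs.as_mut_ptr(), libc::PTHREAD_PROCESS_SHARED),
            0
        );
        assert_eq!(libc::pthread_rwlock_init(lock, attrs.as_mut_ptr()), 0);
        // Destroying the attributes does not affect already-initialized locks.
        assert_eq!(libc::pthread_rwlockattr_destroy(attrs.as_mut_ptr()), 0);
    }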

View File

@@ -749,18 +749,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
trace!("got query {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
match e {
err @ QueryError::Shutdown => {
// Notify postgres of the connection shutdown at the libpq
// protocol level. This avoids postgres having to tell apart an idle
// connection from a stale one, which is bug-prone.
let shutdown_error = short_error(&err);
self.write_message_noflush(&BeMessage::ErrorResponse(
&shutdown_error,
Some(err.pg_error_code()),
))?;
return Ok(ProcessMsgResult::Break);
}
QueryError::Shutdown => return Ok(ProcessMsgResult::Break),
QueryError::SimulatedConnectionError => {
return Err(QueryError::SimulatedConnectionError);
}

View File

@@ -47,7 +47,6 @@ tracing-subscriber = { workspace = true, features = ["json", "registry"] }
tracing-utils.workspace = true
rand.workspace = true
scopeguard.workspace = true
uuid.workspace = true
strum.workspace = true
strum_macros.workspace = true
walkdir.workspace = true

View File

@@ -12,8 +12,7 @@ use jsonwebtoken::{
Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode,
};
use pem::Pem;
use serde::{Deserialize, Deserializer, Serialize, de::DeserializeOwned};
use uuid::Uuid;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use crate::id::TenantId;
@@ -26,11 +25,6 @@ pub enum Scope {
/// Provides access to all data for a specific tenant (specified in `struct Claims` below)
// TODO: join these two?
Tenant,
/// Provides access to all data for a specific tenant, but based on endpoint ID. This token scope
/// is only used by compute to fetch the spec for a specific endpoint. The spec contains a Tenant-scoped
/// token authorizing access to all data of a tenant, so the spec-fetch API requires a TenantEndpoint
/// scope token to ensure that untrusted compute nodes can't fetch specs for arbitrary endpoints.
TenantEndpoint,
/// Provides blanket access to all tenants on the pageserver plus pageserver-wide APIs.
/// Should only be used e.g. for status check/tenant creation/list.
PageServerApi,
@@ -57,43 +51,17 @@ pub enum Scope {
ControllerPeer,
}
fn deserialize_empty_string_as_none_uuid<'de, D>(deserializer: D) -> Result<Option<Uuid>, D::Error>
where
D: Deserializer<'de>,
{
let opt = Option::<String>::deserialize(deserializer)?;
match opt.as_deref() {
Some("") => Ok(None),
Some(s) => Uuid::parse_str(s)
.map(Some)
.map_err(serde::de::Error::custom),
None => Ok(None),
}
}
/// JWT payload. See docs/authentication.md for the format
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Claims {
#[serde(default)]
pub tenant_id: Option<TenantId>,
#[serde(
default,
skip_serializing_if = "Option::is_none",
// Neon control plane includes this field as empty in the claims.
// Consider it None in those cases.
deserialize_with = "deserialize_empty_string_as_none_uuid"
)]
pub endpoint_id: Option<Uuid>,
pub scope: Scope,
}
impl Claims {
pub fn new(tenant_id: Option<TenantId>, scope: Scope) -> Self {
Self {
tenant_id,
scope,
endpoint_id: None,
}
Self { tenant_id, scope }
}
}
@@ -244,7 +212,6 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
let expected_claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
scope: Scope::Tenant,
endpoint_id: None,
};
// A test token containing the following payload, signed using TEST_PRIV_KEY_ED25519:
@@ -273,7 +240,6 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
let claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
scope: Scope::Tenant,
endpoint_id: None,
};
let pem = pem::parse(TEST_PRIV_KEY_ED25519).unwrap();
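
The removed deserializer above is a small serde recipe worth preserving as a standalone sketch; the struct and field names here are hypothetical, and only the empty-string-as-None behaviour mirrors the deleted code:

    use serde::{Deserialize, Deserializer};
    use uuid::Uuid;

    #[derive(Deserialize)]
    struct Payload {
        #[serde(default, deserialize_with = "empty_string_as_none")]
        endpoint_id: Option<Uuid>,
    }

    fn empty_string_as_none<'de, D: Deserializer<'de>>(d: D) -> Result<Option<Uuid>, D::Error> {
        let opt = Option::<String>::deserialize(d)?;
        match opt.as_deref() {
            // The control plane sends "" for a missing endpoint; treat it as None.
            Some("") | None => Ok(None),
            Some(s) => Uuid::parse_str(s).map(Some).map_err(serde::de::Error::custom),
        }
    }

    fn main() {
        // Both an empty string and an absent field map to None.
        let p: Payload = serde_json::from_str(r#"{"endpoint_id": ""}"#).unwrap();
        assert!(p.endpoint_id.is_none());
        let p: Payload = serde_json::from_str("{}").unwrap();
        assert!(p.endpoint_id.is_none());
    }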

View File

@@ -431,7 +431,7 @@ pub fn empty_shmem() -> crate::bindings::WalproposerShmemState {
let empty_wal_rate_limiter = crate::bindings::WalRateLimiter {
should_limit: crate::bindings::pg_atomic_uint32 { value: 0 },
sent_bytes: 0,
last_recorded_time_us: crate::bindings::pg_atomic_uint64 { value: 0 },
last_recorded_time_us: 0,
};
crate::bindings::WalproposerShmemState {

View File

@@ -873,22 +873,6 @@ impl Client {
.map_err(Error::ReceiveBody)
}
pub async fn reset_alert_gauges(&self) -> Result<()> {
let uri = format!(
"{}/hadron-internal/reset_alert_gauges",
self.mgmt_api_endpoint
);
self.start_request(Method::POST, uri)
.send()
.await
.map_err(Error::SendRequest)?
.error_from_body()
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
pub async fn wait_lsn(
&self,
tenant_shard_id: TenantShardId,

View File

@@ -20,8 +20,7 @@ pub fn check_permission(claims: &Claims, tenant_id: Option<TenantId>) -> Result<
| Scope::GenerationsApi
| Scope::Infra
| Scope::Scrubber
| Scope::ControllerPeer
| Scope::TenantEndpoint,
| Scope::ControllerPeer,
_,
) => Err(AuthError(
format!(

View File

@@ -2357,7 +2357,6 @@ async fn timeline_compact_handler(
flags,
sub_compaction,
sub_compaction_max_job_size_mb,
gc_compaction_do_metadata_compaction: false,
};
let scheduled = compact_request

View File

@@ -813,7 +813,6 @@ impl Timeline {
let gc_cutoff_lsn_guard = self.get_applied_gc_cutoff_lsn();
let gc_cutoff_planned = {
let gc_info = self.gc_info.read().unwrap();
info!(cutoffs=?gc_info.cutoffs, applied_cutoff=%*gc_cutoff_lsn_guard, "starting find_lsn_for_timestamp");
gc_info.min_cutoff()
};
// Usually the planned cutoff is newer than the cutoff of the last gc run,

View File

@@ -9216,11 +9216,7 @@ mod tests {
let cancel = CancellationToken::new();
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
@@ -9303,11 +9299,7 @@ mod tests {
guard.cutoffs.space = Lsn(0x40);
}
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
@@ -9844,11 +9836,7 @@ mod tests {
let cancel = CancellationToken::new();
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
@@ -9883,11 +9871,7 @@ mod tests {
guard.cutoffs.space = Lsn(0x40);
}
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
@@ -10462,7 +10446,7 @@ mod tests {
&cancel,
CompactOptions {
flags: dryrun_flags,
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -10473,22 +10457,14 @@ mod tests {
verify_result().await;
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
verify_result().await;
// compact again
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
verify_result().await;
@@ -10507,22 +10483,14 @@ mod tests {
guard.cutoffs.space = Lsn(0x38);
}
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
verify_result().await; // no wals between 0x30 and 0x38, so we should obtain the same result
// not increasing the GC horizon and compact again
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
verify_result().await;
@@ -10727,7 +10695,7 @@ mod tests {
&cancel,
CompactOptions {
flags: dryrun_flags,
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -10738,22 +10706,14 @@ mod tests {
verify_result().await;
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
verify_result().await;
// compact again
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
verify_result().await;
@@ -10953,11 +10913,7 @@ mod tests {
let cancel = CancellationToken::new();
branch_tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
@@ -10970,7 +10926,7 @@ mod tests {
&cancel,
CompactOptions {
compact_lsn_range: Some(CompactLsnRange::above(Lsn(0x40))),
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -11638,7 +11594,7 @@ mod tests {
CompactOptions {
flags: EnumSet::new(),
compact_key_range: Some((get_key(0)..get_key(2)).into()),
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -11685,7 +11641,7 @@ mod tests {
CompactOptions {
flags: EnumSet::new(),
compact_key_range: Some((get_key(2)..get_key(4)).into()),
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -11737,7 +11693,7 @@ mod tests {
CompactOptions {
flags: EnumSet::new(),
compact_key_range: Some((get_key(4)..get_key(9)).into()),
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -11788,7 +11744,7 @@ mod tests {
CompactOptions {
flags: EnumSet::new(),
compact_key_range: Some((get_key(9)..get_key(10)).into()),
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -11844,7 +11800,7 @@ mod tests {
CompactOptions {
flags: EnumSet::new(),
compact_key_range: Some((get_key(0)..get_key(10)).into()),
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -12115,7 +12071,7 @@ mod tests {
&cancel,
CompactOptions {
compact_lsn_range: Some(CompactLsnRange::above(Lsn(0x28))),
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -12150,11 +12106,7 @@ mod tests {
// compact again
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
verify_result().await;
@@ -12373,7 +12325,7 @@ mod tests {
CompactOptions {
compact_key_range: Some((get_key(0)..get_key(2)).into()),
compact_lsn_range: Some((Lsn(0x20)..Lsn(0x28)).into()),
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -12419,7 +12371,7 @@ mod tests {
CompactOptions {
compact_key_range: Some((get_key(3)..get_key(8)).into()),
compact_lsn_range: Some((Lsn(0x28)..Lsn(0x40)).into()),
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -12467,7 +12419,7 @@ mod tests {
CompactOptions {
compact_key_range: Some((get_key(0)..get_key(5)).into()),
compact_lsn_range: Some((Lsn(0x20)..Lsn(0x50)).into()),
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -12502,11 +12454,7 @@ mod tests {
// final full compaction
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
verify_result().await;
@@ -12616,7 +12564,7 @@ mod tests {
CompactOptions {
compact_key_range: None,
compact_lsn_range: None,
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)

View File

@@ -939,20 +939,6 @@ pub(crate) struct CompactOptions {
/// Set job size for the GC compaction.
/// This option is only used by GC compaction.
pub sub_compaction_max_job_size_mb: Option<u64>,
/// Only for GC compaction.
/// If set, the compaction will compact the metadata layers. Should only be set to true in unit tests
/// because metadata compaction is not fully supported yet.
pub gc_compaction_do_metadata_compaction: bool,
}
impl CompactOptions {
#[cfg(test)]
pub fn default_for_gc_compaction_unit_tests() -> Self {
Self {
gc_compaction_do_metadata_compaction: true,
..Default::default()
}
}
}
impl std::fmt::Debug for Timeline {
@@ -2199,7 +2185,6 @@ impl Timeline {
compact_lsn_range: None,
sub_compaction: false,
sub_compaction_max_job_size_mb: None,
gc_compaction_do_metadata_compaction: false,
},
ctx,
)

View File

@@ -396,7 +396,6 @@ impl GcCompactionQueue {
}),
compact_lsn_range: None,
sub_compaction_max_job_size_mb: None,
gc_compaction_do_metadata_compaction: false,
},
permit,
);
@@ -513,7 +512,6 @@ impl GcCompactionQueue {
compact_key_range: Some(job.compact_key_range.into()),
compact_lsn_range: Some(job.compact_lsn_range.into()),
sub_compaction_max_job_size_mb: None,
gc_compaction_do_metadata_compaction: false,
};
pending_tasks.push(GcCompactionQueueItem::SubCompactionJob {
options,
@@ -787,8 +785,6 @@ pub(crate) struct GcCompactJob {
/// as specified here. The true range being compacted is `min_lsn/max_lsn` in [`GcCompactionJobDescription`].
/// min_lsn will always be <= the lower bound specified here, and max_lsn will always be >= the upper bound specified here.
pub compact_lsn_range: Range<Lsn>,
/// See [`CompactOptions::gc_compaction_do_metadata_compaction`].
pub do_metadata_compaction: bool,
}
impl GcCompactJob {
@@ -803,7 +799,6 @@ impl GcCompactJob {
.compact_lsn_range
.map(|x| x.into())
.unwrap_or(Lsn::INVALID..Lsn::MAX),
do_metadata_compaction: options.gc_compaction_do_metadata_compaction,
}
}
}
@@ -3179,7 +3174,6 @@ impl Timeline {
dry_run: job.dry_run,
compact_key_range: start..end,
compact_lsn_range: job.compact_lsn_range.start..compact_below_lsn,
do_metadata_compaction: false,
});
current_start = Some(end);
}
@@ -3242,7 +3236,7 @@ impl Timeline {
async fn compact_with_gc_inner(
self: &Arc<Self>,
cancel: &CancellationToken,
mut job: GcCompactJob,
job: GcCompactJob,
ctx: &RequestContext,
yield_for_l0: bool,
) -> Result<CompactionOutcome, CompactionError> {
@@ -3250,28 +3244,6 @@ impl Timeline {
// with legacy compaction tasks in the future. Always ensure the lock order is compaction -> gc.
// Note that we already acquired the compaction lock when the outer `compact` function gets called.
// If the job is not configured to compact the metadata key range, shrink the key range
// to exclude the metadata key range. The check compares the end of the key range against
// the start of the metadata key range. Note that metadata keys cover the entire
// second half of the keyspace, so checking the end of the key range is sufficient.
if !job.do_metadata_compaction
&& job.compact_key_range.end > Key::metadata_key_range().start
{
tracing::info!(
"compaction for metadata key range is not supported yet, overriding compact_key_range from {} to {}",
job.compact_key_range.end,
Key::metadata_key_range().start
);
// Shrink the key range to exclude the metadata key range.
job.compact_key_range.end = Key::metadata_key_range().start;
// Skip the job if the key range completely lies within the metadata key range.
if job.compact_key_range.start >= job.compact_key_range.end {
tracing::info!("compact_key_range is empty, skipping compaction");
return Ok(CompactionOutcome::Done);
}
}
let timer = Instant::now();
let begin_timer = timer;

View File

@@ -184,7 +184,7 @@ pub(super) async fn connection_manager_loop_step(
// If we've not received any updates from the broker for a while, are waiting for WAL
// and have no safekeeper connection or connection candidates, then it might be that
// the broker subscription is wedged. Drop the current subscription and re-subscribe
// the broker subscription is wedged. Drop the currrent subscription and re-subscribe
// with the goal of unblocking it.
_ = broker_reset_interval.tick() => {
let awaiting_lsn = wait_lsn_status.borrow().is_some();
@@ -192,7 +192,7 @@ pub(super) async fn connection_manager_loop_step(
let no_connection = connection_manager_state.wal_connection.is_none();
if awaiting_lsn && no_candidates && no_connection {
tracing::info!("No broker updates received for a while, but waiting for WAL. Re-setting stream ...");
tracing::warn!("No broker updates received for a while, but waiting for WAL. Re-setting stream ...");
broker_subscription = subscribe_for_timeline_updates(broker_client, id, cancel).await?;
}
},

View File

@@ -219,6 +219,10 @@ static char *lfc_path;
static uint64 lfc_generation;
static FileCacheControl *lfc_ctl;
static bool lfc_do_prewarm;
static shmem_startup_hook_type prev_shmem_startup_hook;
#if PG_VERSION_NUM>=150000
static shmem_request_hook_type prev_shmem_request_hook;
#endif
bool lfc_store_prefetch_result;
bool lfc_prewarm_update_ws_estimation;
@@ -338,14 +342,18 @@ lfc_ensure_opened(void)
return true;
}
void
LfcShmemInit(void)
static void
lfc_shmem_startup(void)
{
bool found;
static HASHCTL info;
if (lfc_max_size <= 0)
return;
if (prev_shmem_startup_hook)
{
prev_shmem_startup_hook();
}
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
lfc_ctl = (FileCacheControl *) ShmemInitStruct("lfc", sizeof(FileCacheControl), &found);
if (!found)
@@ -390,16 +398,19 @@ LfcShmemInit(void)
ConditionVariableInit(&lfc_ctl->cv[i]);
}
LWLockRelease(AddinShmemInitLock);
}
void
LfcShmemRequest(void)
static void
lfc_shmem_request(void)
{
if (lfc_max_size > 0)
{
RequestAddinShmemSpace(sizeof(FileCacheControl) + hash_estimate_size(SIZE_MB_TO_CHUNKS(lfc_max_size) + 1, FILE_CACHE_ENRTY_SIZE));
RequestNamedLWLockTranche("lfc_lock", 1);
}
#if PG_VERSION_NUM>=150000
if (prev_shmem_request_hook)
prev_shmem_request_hook();
#endif
RequestAddinShmemSpace(sizeof(FileCacheControl) + hash_estimate_size(SIZE_MB_TO_CHUNKS(lfc_max_size) + 1, FILE_CACHE_ENRTY_SIZE));
RequestNamedLWLockTranche("lfc_lock", 1);
}
static bool
@@ -631,6 +642,18 @@ lfc_init(void)
NULL,
NULL,
NULL);
if (lfc_max_size == 0)
return;
prev_shmem_startup_hook = shmem_startup_hook;
shmem_startup_hook = lfc_shmem_startup;
#if PG_VERSION_NUM>=150000
prev_shmem_request_hook = shmem_request_hook;
shmem_request_hook = lfc_shmem_request;
#else
lfc_shmem_request();
#endif
}
FileCacheState*

View File

@@ -90,7 +90,6 @@ typedef struct
{
char connstring[MAX_SHARDS][MAX_PAGESERVER_CONNSTRING_SIZE];
size_t num_shards;
size_t stripe_size;
} ShardMap;
/*
@@ -111,11 +110,6 @@ typedef struct
* has changed since last access, and to detect and retry copying the value if
* the postmaster changes the value concurrently. (Postmaster doesn't have a
* PGPROC entry and therefore cannot use LWLocks.)
*
 * stripe_size is now also part of ShardMap, although it is defined by a separate GUC.
 * Postgres doesn't provide any mechanism to enforce dependencies between GUCs,
 * so we have to rely on the order of GUC definitions in the config file:
 * "neon.stripe_size" should be defined prior to "neon.pageserver_connstring".
*/
typedef struct
{
@@ -124,6 +118,10 @@ typedef struct
ShardMap shard_map;
} PagestoreShmemState;
#if PG_VERSION_NUM >= 150000
static shmem_request_hook_type prev_shmem_request_hook = NULL;
#endif
static shmem_startup_hook_type prev_shmem_startup_hook;
static PagestoreShmemState *pagestore_shared;
static uint64 pagestore_local_counter = 0;
@@ -236,10 +234,7 @@ ParseShardMap(const char *connstr, ShardMap *result)
p = sep + 1;
}
if (result)
{
result->num_shards = nshards;
result->stripe_size = stripe_size;
}
return true;
}
@@ -300,13 +295,12 @@ AssignPageserverConnstring(const char *newval, void *extra)
* last call, terminates all existing connections to all pageservers.
*/
static void
load_shard_map(shardno_t shard_no, char *connstr_p, shardno_t *num_shards_p, size_t* stripe_size_p)
load_shard_map(shardno_t shard_no, char *connstr_p, shardno_t *num_shards_p)
{
uint64 begin_update_counter;
uint64 end_update_counter;
ShardMap *shard_map = &pagestore_shared->shard_map;
shardno_t num_shards;
size_t stripe_size;
/*
* Postmaster can update the shared memory values concurrently, in which
@@ -321,7 +315,6 @@ load_shard_map(shardno_t shard_no, char *connstr_p, shardno_t *num_shards_p, siz
end_update_counter = pg_atomic_read_u64(&pagestore_shared->end_update_counter);
num_shards = shard_map->num_shards;
stripe_size = shard_map->stripe_size;
if (connstr_p && shard_no < MAX_SHARDS)
strlcpy(connstr_p, shard_map->connstring[shard_no], MAX_PAGESERVER_CONNSTRING_SIZE);
pg_memory_barrier();
@@ -356,8 +349,6 @@ load_shard_map(shardno_t shard_no, char *connstr_p, shardno_t *num_shards_p, siz
if (num_shards_p)
*num_shards_p = num_shards;
if (stripe_size_p)
*stripe_size_p = stripe_size;
}
#define MB (1024*1024)
@@ -366,10 +357,9 @@ shardno_t
get_shard_number(BufferTag *tag)
{
shardno_t n_shards;
size_t stripe_size;
uint32 hash;
load_shard_map(0, NULL, &n_shards, &stripe_size);
load_shard_map(0, NULL, &n_shards);
#if PG_MAJORVERSION_NUM < 16
hash = murmurhash32(tag->rnode.relNode);
@@ -422,7 +412,7 @@ pageserver_connect(shardno_t shard_no, int elevel)
* Note that connstr is used both during connection start, and when we
* log the successful connection.
*/
load_shard_map(shard_no, connstr, NULL, NULL);
load_shard_map(shard_no, connstr, NULL);
switch (shard->state)
{
@@ -1294,12 +1284,18 @@ check_neon_id(char **newval, void **extra, GucSource source)
return **newval == '\0' || HexDecodeString(id, *newval, 16);
}
static Size
PagestoreShmemSize(void)
{
return add_size(sizeof(PagestoreShmemState), NeonPerfCountersShmemSize());
}
void
static bool
PagestoreShmemInit(void)
{
bool found;
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
pagestore_shared = ShmemInitStruct("libpagestore shared state",
sizeof(PagestoreShmemState),
&found);
@@ -1310,12 +1306,44 @@ PagestoreShmemInit(void)
memset(&pagestore_shared->shard_map, 0, sizeof(ShardMap));
AssignPageserverConnstring(page_server_connstring, NULL);
}
NeonPerfCountersShmemInit();
LWLockRelease(AddinShmemInitLock);
return found;
}
void
PagestoreShmemRequest(void)
static void
pagestore_shmem_startup_hook(void)
{
RequestAddinShmemSpace(sizeof(PagestoreShmemState));
if (prev_shmem_startup_hook)
prev_shmem_startup_hook();
PagestoreShmemInit();
}
static void
pagestore_shmem_request(void)
{
#if PG_VERSION_NUM >= 150000
if (prev_shmem_request_hook)
prev_shmem_request_hook();
#endif
RequestAddinShmemSpace(PagestoreShmemSize());
}
static void
pagestore_prepare_shmem(void)
{
#if PG_VERSION_NUM >= 150000
prev_shmem_request_hook = shmem_request_hook;
shmem_request_hook = pagestore_shmem_request;
#else
pagestore_shmem_request();
#endif
prev_shmem_startup_hook = shmem_startup_hook;
shmem_startup_hook = pagestore_shmem_startup_hook;
}
/*
@@ -1324,6 +1352,8 @@ PagestoreShmemRequest(void)
void
pg_init_libpagestore(void)
{
pagestore_prepare_shmem();
DefineCustomStringVariable("neon.pageserver_connstring",
"connection string to the page server",
NULL,
@@ -1474,6 +1504,8 @@ pg_init_libpagestore(void)
0,
NULL, NULL, NULL);
relsize_hash_init();
if (page_server != NULL)
neon_log(ERROR, "libpagestore already loaded");

View File

@@ -22,7 +22,6 @@
#include "replication/slot.h"
#include "replication/walsender.h"
#include "storage/proc.h"
#include "storage/ipc.h"
#include "funcapi.h"
#include "access/htup_details.h"
#include "utils/builtins.h"
@@ -60,15 +59,11 @@ static ExecutorEnd_hook_type prev_ExecutorEnd = NULL;
static void neon_ExecutorStart(QueryDesc *queryDesc, int eflags);
static void neon_ExecutorEnd(QueryDesc *queryDesc);
#if PG_MAJORVERSION_NUM >= 16
static shmem_startup_hook_type prev_shmem_startup_hook;
static void neon_shmem_startup_hook(void);
static void neon_shmem_request_hook(void);
#if PG_MAJORVERSION_NUM >= 15
static shmem_request_hook_type prev_shmem_request_hook = NULL;
#endif
#if PG_MAJORVERSION_NUM >= 17
uint32 WAIT_EVENT_NEON_LFC_MAINTENANCE;
uint32 WAIT_EVENT_NEON_LFC_READ;
@@ -455,44 +450,15 @@ _PG_init(void)
*/
#if PG_VERSION_NUM >= 160000
load_file("$libdir/neon_rmgr", false);
prev_shmem_startup_hook = shmem_startup_hook;
shmem_startup_hook = neon_shmem_startup_hook;
#endif
/* dummy call to a Rust function in the communicator library, to check that it works */
(void) communicator_dummy(123);
/*
* Initializing a pre-loaded Postgres extension happens in three stages:
*
* 1. _PG_init() is called early at postmaster startup. In this stage, no
* shared memory has been allocated yet. Core Postgres GUCs have been
* initialized from the config files, but notably, MaxBackends has not been
* calculated yet. In this stage, we must register any extension GUCs
* and can do other early initialization that doesn't depend on shared
* memory. In this stage we must also register "shmem request" and
* "shmem startup" hooks, to be called in stages 2 and 3.
*
* 2. After MaxBackends has been calculated, the "shmem request" hooks
* are called. The hooks can reserve shared memory by calling
* RequestAddinShmemSpace() and RequestNamedLWLockTranche(). The "shmem
* request hooks" are a new mechanism in Postgres v15. In v14 and
* below, you had to make those Requests in stage 1 already, which
* means they could not depend on MaxBackends. (See hack in
* NeonPerfCountersShmemRequest())
*
* 3. After some more runtime-computed GUCs that affect the amount of
* shared memory needed have been calculated, the "shmem startup" hooks
* are called. In this stage, we allocate any shared memory, LWLocks
* and other shared resources.
*
* Here, in the 'neon' extension, we register just one shmem request hook
* and one startup hook, which call into functions in all the subsystems
* that are part of the extension. On v14, the ShmemRequest functions are
* called in stage 1, and on v15 onwards they are called in stage 2.
*/
/* Stage 1: Define GUCs, and other early initialization */
pg_init_libpagestore();
relsize_hash_init();
lfc_init();
pg_init_walproposer();
init_lwlsncache();
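The three-stage dance laid out in the comment above is the standard pattern for preloaded extensions. As a self-contained illustration of the same pattern (all `my_*` names are hypothetical; the version guard mirrors the one used throughout this diff):

```c
#include "postgres.h"

#include "fmgr.h"
#include "miscadmin.h"
#include "storage/ipc.h"
#include "storage/lwlock.h"
#include "storage/shmem.h"

PG_MODULE_MAGIC;

typedef struct
{
	int			counter;
} MyShmemState;

static MyShmemState *my_state;
static shmem_startup_hook_type prev_shmem_startup_hook = NULL;
#if PG_VERSION_NUM >= 150000
static shmem_request_hook_type prev_shmem_request_hook = NULL;
#endif

/* Stage 2: reserve shared memory. MaxBackends is valid here on v15+. */
static void
my_shmem_request(void)
{
#if PG_VERSION_NUM >= 150000
	if (prev_shmem_request_hook)
		prev_shmem_request_hook();
#endif
	RequestAddinShmemSpace(sizeof(MyShmemState));
}

/* Stage 3: allocate or attach to the memory reserved in stage 2. */
static void
my_shmem_startup(void)
{
	bool		found;

	if (prev_shmem_startup_hook)
		prev_shmem_startup_hook();

	LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
	my_state = ShmemInitStruct("my shared state", sizeof(MyShmemState), &found);
	if (!found)
		my_state->counter = 0;	/* first attach: initialize */
	LWLockRelease(AddinShmemInitLock);
}

/* Stage 1: register GUCs and the hooks for stages 2 and 3. */
void
_PG_init(void)
{
#if PG_VERSION_NUM >= 150000
	prev_shmem_request_hook = shmem_request_hook;
	shmem_request_hook = my_shmem_request;
#else
	my_shmem_request();			/* no request hook before v15 */
#endif
	prev_shmem_startup_hook = shmem_startup_hook;
	shmem_startup_hook = my_shmem_startup;
}
```

Chaining the `prev_*` hook pointers is what lets several extensions share these single global hooks, which is exactly what the per-subsystem hooks restored in this diff do.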
@@ -577,15 +543,6 @@ _PG_init(void)
PGC_POSTMASTER,
0,
NULL, NULL, NULL);
DefineCustomStringVariable(
"neon.privileged_role_name",
"Name of the 'weak' superuser role, which we give to the users",
NULL,
&privileged_role_name,
"neon_superuser",
PGC_POSTMASTER, 0, NULL, NULL, NULL);
/*
* Important: This must happen after other parts of the extension are
* loaded, otherwise any settings to GUCs that were set before the
@@ -595,22 +552,6 @@ _PG_init(void)
ReportSearchPath();
/*
* Register initialization hooks for stage 2. (On v14, there are no "shmem
* request" hooks, so call the ShmemRequest functions immediately.)
*/
#if PG_VERSION_NUM >= 150000
prev_shmem_request_hook = shmem_request_hook;
shmem_request_hook = neon_shmem_request_hook;
#else
neon_shmem_request_hook();
#endif
/* Register hooks for stage 3 */
prev_shmem_startup_hook = shmem_startup_hook;
shmem_startup_hook = neon_shmem_startup_hook;
/* Other misc initialization */
prev_ExecutorStart = ExecutorStart_hook;
ExecutorStart_hook = neon_ExecutorStart;
prev_ExecutorEnd = ExecutorEnd_hook;
@@ -696,34 +637,7 @@ approximate_working_set_size(PG_FUNCTION_ARGS)
PG_RETURN_INT32(dc);
}
/*
* Initialization stage 2: make requests for the amount of shared memory we
* will need.
*
* For a high-level explanation of the initialization process, see _PG_init().
*/
static void
neon_shmem_request_hook(void)
{
#if PG_VERSION_NUM >= 150000
if (prev_shmem_request_hook)
prev_shmem_request_hook();
#endif
LfcShmemRequest();
NeonPerfCountersShmemRequest();
PagestoreShmemRequest();
RelsizeCacheShmemRequest();
WalproposerShmemRequest();
LwLsnCacheShmemRequest();
}
/*
* Initialization stage 3: Initialize shared memory.
*
* For a high-level explanation of the initialization process, see _PG_init().
*/
#if PG_MAJORVERSION_NUM >= 16
static void
neon_shmem_startup_hook(void)
{
@@ -731,15 +645,6 @@ neon_shmem_startup_hook(void)
if (prev_shmem_startup_hook)
prev_shmem_startup_hook();
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
LfcShmemInit();
NeonPerfCountersShmemInit();
PagestoreShmemInit();
RelsizeCacheShmemInit();
WalproposerShmemInit();
LwLsnCacheShmemInit();
#if PG_MAJORVERSION_NUM >= 17
WAIT_EVENT_NEON_LFC_MAINTENANCE = WaitEventExtensionNew("Neon/FileCache_Maintenance");
WAIT_EVENT_NEON_LFC_READ = WaitEventExtensionNew("Neon/FileCache_Read");
@@ -752,9 +657,8 @@ neon_shmem_startup_hook(void)
WAIT_EVENT_NEON_PS_READ = WaitEventExtensionNew("Neon/PS_ReadIO");
WAIT_EVENT_NEON_WAL_DL = WaitEventExtensionNew("Neon/WAL_Download");
#endif
LWLockRelease(AddinShmemInitLock);
}
#endif
/*
* ExecutorStart hook: start up tracking if needed

View File

@@ -16,6 +16,7 @@
extern char *neon_auth_token;
extern char *neon_timeline;
extern char *neon_tenant;
extern char *wal_acceptors_list;
extern int wal_acceptor_reconnect_timeout;
extern int wal_acceptor_connection_timeout;
@@ -70,19 +71,4 @@ extern PGDLLEXPORT void WalProposerSync(int argc, char *argv[]);
extern PGDLLEXPORT void WalProposerMain(Datum main_arg);
extern PGDLLEXPORT void LogicalSlotsMonitorMain(Datum main_arg);
extern void LfcShmemRequest(void);
extern void PagestoreShmemRequest(void);
extern void RelsizeCacheShmemRequest(void);
extern void WalproposerShmemRequest(void);
extern void LwLsnCacheShmemRequest(void);
extern void NeonPerfCountersShmemRequest(void);
extern void LfcShmemInit(void);
extern void PagestoreShmemInit(void);
extern void RelsizeCacheShmemInit(void);
extern void WalproposerShmemInit(void);
extern void LwLsnCacheShmemInit(void);
extern void NeonPerfCountersShmemInit(void);
#endif /* NEON_H */

View File

@@ -13,7 +13,7 @@
* accumulate changes. On subtransaction commit, the top of the stack
* is merged with the table below it.
*
* Support event triggers for {privileged_role_name}
* Support event triggers for neon_superuser
*
* IDENTIFICATION
contrib/neon/neon_ddl_handler.c
@@ -49,7 +49,6 @@
#include "neon_ddl_handler.h"
#include "neon_utils.h"
#include "neon.h"
static ProcessUtility_hook_type PreviousProcessUtilityHook = NULL;
static fmgr_hook_type next_fmgr_hook = NULL;
@@ -542,11 +541,11 @@ NeonXactCallback(XactEvent event, void *arg)
}
static bool
IsPrivilegedRole(const char *role_name)
RoleIsNeonSuperuser(const char *role_name)
{
Assert(role_name);
return strcmp(role_name, privileged_role_name) == 0;
return strcmp(role_name, "neon_superuser") == 0;
}
static void
@@ -579,9 +578,8 @@ HandleCreateDb(CreatedbStmt *stmt)
{
const char *owner_name = defGetString(downer);
if (IsPrivilegedRole(owner_name))
elog(ERROR, "could not create a database with owner %s", privileged_role_name);
if (RoleIsNeonSuperuser(owner_name))
elog(ERROR, "can't create a database with owner neon_superuser");
entry->owner = get_role_oid(owner_name, false);
}
else
@@ -611,9 +609,8 @@ HandleAlterOwner(AlterOwnerStmt *stmt)
memset(entry->old_name, 0, sizeof(entry->old_name));
new_owner = get_rolespec_name(stmt->newowner);
if (IsPrivilegedRole(new_owner))
elog(ERROR, "could not alter owner to %s", privileged_role_name);
if (RoleIsNeonSuperuser(new_owner))
elog(ERROR, "can't alter owner to neon_superuser");
entry->owner = get_role_oid(new_owner, false);
entry->type = Op_Set;
}
@@ -719,8 +716,8 @@ HandleAlterRole(AlterRoleStmt *stmt)
InitRoleTableIfNeeded();
role_name = get_rolespec_name(stmt->role);
if (IsPrivilegedRole(role_name) && !superuser())
elog(ERROR, "could not ALTER %s", privileged_role_name);
if (RoleIsNeonSuperuser(role_name) && !superuser())
elog(ERROR, "can't ALTER neon_superuser");
dpass = NULL;
foreach(option, stmt->options)
@@ -834,7 +831,7 @@ HandleRename(RenameStmt *stmt)
*
* In vanilla only superuser can create Event Triggers.
*
* We allow it for {privileged_role_name} by temporarily switching to superuser. But since
* We allow it for neon_superuser by temporarily switching to superuser. But since
* event triggers can fire in superuser context, we should protect
* superuser from execution of arbitrary user code.
*
@@ -894,7 +891,7 @@ force_noop(FmgrInfo *finfo)
* Also skip executing Event Triggers when GUC neon.event_triggers has been
* set to false. This might be necessary to be able to connect again after a
* LOGIN Event Trigger has been installed that would prevent connections as
* {privileged_role_name}.
* neon_superuser.
*/
static void
neon_fmgr_hook(FmgrHookEventType event, FmgrInfo *flinfo, Datum *private)
@@ -913,24 +910,24 @@ neon_fmgr_hook(FmgrHookEventType event, FmgrInfo *flinfo, Datum *private)
}
/*
* The {privileged_role_name} role can use the GUC neon.event_triggers to disable
* The neon_superuser role can use the GUC neon.event_triggers to disable
* firing Event Trigger.
*
* SET neon.event_triggers TO false;
*
* This only applies to the {privileged_role_name} role though, and only allows
* skipping Event Triggers owned by {privileged_role_name}, which we check by
* proxy of the Event Trigger function being owned by {privileged_role_name}.
* This only applies to the neon_superuser role though, and only allows
* skipping Event Triggers owned by neon_superuser, which we check by
* proxy of the Event Trigger function being owned by neon_superuser.
*
* A role that is created in role {privileged_role_name} should be allowed to also
* A role that is created in role neon_superuser should be allowed to also
* benefit from the neon_event_triggers GUC, and will be considered the
* same as the {privileged_role_name} role.
* same as the neon_superuser role.
*/
if (event == FHET_START
&& !neon_event_triggers
&& is_privileged_role())
&& is_neon_superuser())
{
Oid weak_superuser_oid = get_role_oid(privileged_role_name, false);
Oid neon_superuser_oid = get_role_oid("neon_superuser", false);
/* Find the Function Attributes (owner Oid, security definer) */
const char *fun_owner_name = NULL;
@@ -940,8 +937,8 @@ neon_fmgr_hook(FmgrHookEventType event, FmgrInfo *flinfo, Datum *private)
LookupFuncOwnerSecDef(flinfo->fn_oid, &fun_owner, &fun_is_secdef);
fun_owner_name = GetUserNameFromId(fun_owner, false);
if (IsPrivilegedRole(fun_owner_name)
|| has_privs_of_role(fun_owner, weak_superuser_oid))
if (RoleIsNeonSuperuser(fun_owner_name)
|| has_privs_of_role(fun_owner, neon_superuser_oid))
{
elog(WARNING,
"Skipping Event Trigger: neon.event_triggers is false");
@@ -1152,13 +1149,13 @@ ProcessCreateEventTrigger(
}
/*
* Allow {privileged_role_name} to create Event Trigger, while keeping the
* Allow neon_superuser to create Event Trigger, while keeping the
* ownership of the object.
*
* For that we give superuser membership to the role for the execution of
* the command.
*/
if (IsTransactionState() && is_privileged_role())
if (IsTransactionState() && is_neon_superuser())
{
/* Find the Event Trigger function Oid */
Oid func_oid = LookupFuncName(stmt->funcname, 0, NULL, false);
@@ -1235,7 +1232,7 @@ ProcessCreateEventTrigger(
*
* That way [ ALTER | DROP ] EVENT TRIGGER commands just work.
*/
if (IsTransactionState() && is_privileged_role())
if (IsTransactionState() && is_neon_superuser())
{
if (!current_user_is_super)
{
@@ -1355,17 +1352,19 @@ NeonProcessUtility(
}
/*
* Only {privileged_role_name} is granted the privilege to edit the neon.event_triggers GUC.
* Only neon_superuser is granted the privilege to edit the neon.event_triggers GUC.
*/
static void
neon_event_triggers_assign_hook(bool newval, void *extra)
{
if (IsTransactionState() && !is_privileged_role())
/* MyDatabaseId == InvalidOid || !OidIsValid(GetUserId()) */
if (IsTransactionState() && !is_neon_superuser())
{
ereport(ERROR,
(errcode(ERRCODE_INSUFFICIENT_PRIVILEGE),
errmsg("permission denied to set neon.event_triggers"),
errdetail("Only \"%s\" is allowed to set the GUC", privileged_role_name)));
errdetail("Only \"neon_superuser\" is allowed to set the GUC")));
}
}

View File

@@ -1,6 +1,5 @@
#include "postgres.h"
#include "neon.h"
#include "neon_lwlsncache.h"
#include "miscadmin.h"
@@ -82,6 +81,14 @@ static set_max_lwlsn_hook_type prev_set_max_lwlsn_hook = NULL;
static set_lwlsn_relation_hook_type prev_set_lwlsn_relation_hook = NULL;
static set_lwlsn_db_hook_type prev_set_lwlsn_db_hook = NULL;
static shmem_startup_hook_type prev_shmem_startup_hook;
#if PG_VERSION_NUM >= 150000
static shmem_request_hook_type prev_shmem_request_hook;
#endif
static void shmemrequest(void);
static void shmeminit(void);
static void neon_set_max_lwlsn(XLogRecPtr lsn);
void
@@ -92,6 +99,16 @@ init_lwlsncache(void)
lwlc_register_gucs();
prev_shmem_startup_hook = shmem_startup_hook;
shmem_startup_hook = shmeminit;
#if PG_VERSION_NUM >= 150000
prev_shmem_request_hook = shmem_request_hook;
shmem_request_hook = shmemrequest;
#else
shmemrequest();
#endif
prev_set_lwlsn_block_range_hook = set_lwlsn_block_range_hook;
set_lwlsn_block_range_hook = neon_set_lwlsn_block_range;
prev_set_lwlsn_block_v_hook = set_lwlsn_block_v_hook;
@@ -107,19 +124,20 @@ init_lwlsncache(void)
}
void
LwLsnCacheShmemRequest(void)
{
static void
shmemrequest(void)
{
Size requested_size = sizeof(LwLsnCacheCtl);
requested_size += hash_estimate_size(lwlsn_cache_size, sizeof(LastWrittenLsnCacheEntry));
RequestAddinShmemSpace(requested_size);
#if PG_VERSION_NUM >= 150000
if (prev_shmem_request_hook)
prev_shmem_request_hook();
#endif
}
void
LwLsnCacheShmemInit(void)
{
static void
shmeminit(void)
{
static HASHCTL info;
bool found;
if (lwlsn_cache_size > 0)
@@ -139,6 +157,9 @@ LwLsnCacheShmemInit(void)
}
dlist_init(&LwLsnCache->lastWrittenLsnLRU);
LwLsnCache->maxLastWrittenLsn = GetRedoRecPtr();
if (prev_shmem_startup_hook) {
prev_shmem_startup_hook();
}
}
/*

View File

@@ -17,32 +17,22 @@
#include "storage/shmem.h"
#include "utils/builtins.h"
#include "neon.h"
#include "neon_perf_counters.h"
#include "neon_pgversioncompat.h"
neon_per_backend_counters *neon_per_backend_counters_shared;
void
NeonPerfCountersShmemRequest(void)
Size
NeonPerfCountersShmemSize(void)
{
Size size;
#if PG_MAJORVERSION_NUM < 15
/* Hack: on PG14, MaxBackends is not yet initialized when the NeonPerfCountersShmemRequest function is called.
* Initialize it ourselves, then undo it to prevent an assertion failure.
*/
Assert(MaxBackends == 0); /* not initialized yet */
InitializeMaxBackends();
size = mul_size(NUM_NEON_PERF_COUNTER_SLOTS, sizeof(neon_per_backend_counters));
MaxBackends = 0;
#else
size = mul_size(NUM_NEON_PERF_COUNTER_SLOTS, sizeof(neon_per_backend_counters));
#endif
RequestAddinShmemSpace(size);
Size size = 0;
size = add_size(size, mul_size(NUM_NEON_PERF_COUNTER_SLOTS,
sizeof(neon_per_backend_counters)));
return size;
}
void
NeonPerfCountersShmemInit(void)
{

View File

@@ -10,7 +10,6 @@
*/
#include "postgres.h"
#include "neon.h"
#include "neon_pgversioncompat.h"
#include "pagestore_client.h"
@@ -50,23 +49,32 @@ typedef struct
* algorithm */
} RelSizeHashControl;
static HTAB *relsize_hash;
static LWLockId relsize_lock;
static int relsize_hash_size;
static RelSizeHashControl* relsize_ctl;
static shmem_startup_hook_type prev_shmem_startup_hook = NULL;
#if PG_VERSION_NUM >= 150000
static shmem_request_hook_type prev_shmem_request_hook = NULL;
static void relsize_shmem_request(void);
#endif
/*
* The size of a cache entry is 36 bytes, so this default of 64 * 1024 entries
* takes about 2.3 MB, which seems reasonable.
*/
#define DEFAULT_RELSIZE_HASH_SIZE (64 * 1024)
static HTAB *relsize_hash;
static LWLockId relsize_lock;
static int relsize_hash_size = DEFAULT_RELSIZE_HASH_SIZE;
static RelSizeHashControl* relsize_ctl;
void
RelsizeCacheShmemInit(void)
static void
neon_smgr_shmem_startup(void)
{
static HASHCTL info;
bool found;
if (prev_shmem_startup_hook)
prev_shmem_startup_hook();
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
relsize_ctl = (RelSizeHashControl *) ShmemInitStruct("relsize_hash", sizeof(RelSizeHashControl), &found);
if (!found)
{
@@ -77,6 +85,7 @@ RelsizeCacheShmemInit(void)
relsize_hash_size, relsize_hash_size,
&info,
HASH_ELEM | HASH_BLOBS);
LWLockRelease(AddinShmemInitLock);
relsize_ctl->size = 0;
relsize_ctl->hits = 0;
relsize_ctl->misses = 0;
@@ -233,15 +242,34 @@ relsize_hash_init(void)
PGC_POSTMASTER,
0,
NULL, NULL, NULL);
if (relsize_hash_size > 0)
{
#if PG_VERSION_NUM >= 150000
prev_shmem_request_hook = shmem_request_hook;
shmem_request_hook = relsize_shmem_request;
#else
RequestAddinShmemSpace(hash_estimate_size(relsize_hash_size, sizeof(RelSizeEntry)));
RequestNamedLWLockTranche("neon_relsize", 1);
#endif
prev_shmem_startup_hook = shmem_startup_hook;
shmem_startup_hook = neon_smgr_shmem_startup;
}
}
#if PG_VERSION_NUM >= 150000
/*
* shmem_request hook: request additional shared resources. We'll allocate or
* attach to the shared resources in neon_smgr_shmem_startup().
*/
void
RelsizeCacheShmemRequest(void)
static void
relsize_shmem_request(void)
{
if (prev_shmem_request_hook)
prev_shmem_request_hook();
RequestAddinShmemSpace(sizeof(RelSizeHashControl) + hash_estimate_size(relsize_hash_size, sizeof(RelSizeEntry)));
RequestNamedLWLockTranche("neon_relsize", 1);
}
#endif

View File

@@ -377,16 +377,6 @@ typedef struct PageserverFeedback
} PageserverFeedback;
/* BEGIN_HADRON */
/**
* WAL proposer is the only backend that will update `sent_bytes` and `last_recorded_time_us`.
* Once `sent_bytes` reaches the limit, it puts backpressure on PG backends.
*
* A PG backend checks `should_limit` to see if it should hit backpressure.
* - If yes, it also checks `last_recorded_time_us` to see
* if it's time to push more WALs. This is because the WAL proposer
* only resets `should_limit` to 0 after it is notified about new WALs,
* which might take a while.
*/
typedef struct WalRateLimiter
{
/* If the value is 1, PG backends will hit backpressure. */
@@ -394,7 +384,7 @@ typedef struct WalRateLimiter
/* The number of bytes sent in the current second. */
uint64 sent_bytes;
/* The last recorded time in microsecond. */
pg_atomic_uint64 last_recorded_time_us;
TimestampTz last_recorded_time_us;
} WalRateLimiter;
/* END_HADRON */
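The proposer-side accounting that goes with this struct is shown in the walproposer hunks further down. As a hedged sketch of that per-second window (the real logic is inline in XLogBroadcastWalProposer(); `account_wal_bytes` and its parameters are hypothetical, and the plain non-atomic fields are safe because only the WAL proposer writes them):

```c
#include "utils/timestamp.h"

static void
account_wal_bytes(WalRateLimiter *limiter, uint64 nbytes, uint64 max_wal_bytes)
{
	TimestampTz now = GetCurrentTimestamp();

	if (now - limiter->last_recorded_time_us > USECS_PER_SEC)
	{
		/* A new one-second window: reset counters and lift backpressure. */
		limiter->last_recorded_time_us = now;
		limiter->sent_bytes = 0;
		pg_atomic_exchange_u32(&limiter->should_limit, 0);
	}

	limiter->sent_bytes += nbytes;
	if (limiter->sent_bytes > max_wal_bytes)
		pg_atomic_exchange_u32(&limiter->should_limit, 1);
}
```

Backends only ever read `should_limit`, so the atomic flag is the single point of synchronization between the proposer and the throttled backends.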

View File

@@ -83,8 +83,10 @@ static XLogRecPtr standby_flush_lsn = InvalidXLogRecPtr;
static XLogRecPtr standby_apply_lsn = InvalidXLogRecPtr;
static HotStandbyFeedback agg_hs_feedback;
static void nwp_shmem_startup_hook(void);
static void nwp_register_gucs(void);
static void assign_neon_safekeepers(const char *newval, void *extra);
static void nwp_prepare_shmem(void);
static uint64 backpressure_lag_impl(void);
static uint64 startup_backpressure_wrap(void);
static bool backpressure_throttling_impl(void);
@@ -97,6 +99,11 @@ static TimestampTz walprop_pg_get_current_timestamp(WalProposer *wp);
static void walprop_pg_load_libpqwalreceiver(void);
static process_interrupts_callback_t PrevProcessInterruptsCallback = NULL;
static shmem_startup_hook_type prev_shmem_startup_hook_type;
#if PG_VERSION_NUM >= 150000
static shmem_request_hook_type prev_shmem_request_hook = NULL;
static void walproposer_shmem_request(void);
#endif
static void WalproposerShmemInit_SyncSafekeeper(void);
@@ -186,6 +193,8 @@ pg_init_walproposer(void)
nwp_register_gucs();
nwp_prepare_shmem();
delay_backend_us = &startup_backpressure_wrap;
PrevProcessInterruptsCallback = ProcessInterruptsCallback;
ProcessInterruptsCallback = backpressure_throttling_impl;
@@ -440,20 +449,8 @@ backpressure_lag_impl(void)
}
state = GetWalpropShmemState();
if (state != NULL && !!pg_atomic_read_u32(&state->wal_rate_limiter.should_limit))
if (state != NULL && pg_atomic_read_u32(&state->wal_rate_limiter.should_limit) == 1)
{
TimestampTz now = GetCurrentTimestamp();
struct WalRateLimiter *limiter = &state->wal_rate_limiter;
uint64 last_recorded_time = pg_atomic_read_u64(&limiter->last_recorded_time_us);
if (now - last_recorded_time > USECS_PER_SEC)
{
/*
* More than 1 second has passed since the last recorded time, so it's time to push more WALs.
* If the backends push WALs too fast, the WAL proposer will rate limit them again.
*/
uint32 expected = true;
pg_atomic_compare_exchange_u32(&state->wal_rate_limiter.should_limit, &expected, false);
}
return 1;
}
/* END_HADRON */
@@ -485,11 +482,12 @@ WalproposerShmemSize(void)
return sizeof(WalproposerShmemState);
}
void
static bool
WalproposerShmemInit(void)
{
bool found;
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
walprop_shared = ShmemInitStruct("Walproposer shared state",
sizeof(WalproposerShmemState),
&found);
@@ -504,9 +502,11 @@ WalproposerShmemInit(void)
pg_atomic_init_u64(&walprop_shared->currentClusterSize, 0);
/* BEGIN_HADRON */
pg_atomic_init_u32(&walprop_shared->wal_rate_limiter.should_limit, 0);
pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.last_recorded_time_us, 0);
/* END_HADRON */
}
LWLockRelease(AddinShmemInitLock);
return found;
}
static void
@@ -520,7 +520,6 @@ WalproposerShmemInit_SyncSafekeeper(void)
pg_atomic_init_u64(&walprop_shared->backpressureThrottlingTime, 0);
/* BEGIN_HADRON */
pg_atomic_init_u32(&walprop_shared->wal_rate_limiter.should_limit, 0);
pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.last_recorded_time_us, 0);
/* END_HADRON */
}
@@ -610,15 +609,42 @@ walprop_register_bgworker(void)
/* shmem handling */
static void
nwp_prepare_shmem(void)
{
#if PG_VERSION_NUM >= 150000
prev_shmem_request_hook = shmem_request_hook;
shmem_request_hook = walproposer_shmem_request;
#else
RequestAddinShmemSpace(WalproposerShmemSize());
#endif
prev_shmem_startup_hook_type = shmem_startup_hook;
shmem_startup_hook = nwp_shmem_startup_hook;
}
#if PG_VERSION_NUM >= 150000
/*
* shmem_request hook: request additional shared resources. We'll allocate or
* attach to the shared resources in WalproposerShmemInit().
* attach to the shared resources in nwp_shmem_startup_hook().
*/
void
WalproposerShmemRequest(void)
static void
walproposer_shmem_request(void)
{
if (prev_shmem_request_hook)
prev_shmem_request_hook();
RequestAddinShmemSpace(WalproposerShmemSize());
}
#endif
static void
nwp_shmem_startup_hook(void)
{
if (prev_shmem_startup_hook_type)
prev_shmem_startup_hook_type();
WalproposerShmemInit();
}
WalproposerShmemState *
GetWalpropShmemState(void)
@@ -1525,18 +1551,18 @@ XLogBroadcastWalProposer(WalProposer *wp)
{
uint64 max_wal_bytes = (uint64) databricks_max_wal_mb_per_second * 1024 * 1024;
struct WalRateLimiter *limiter = &state->wal_rate_limiter;
uint64 last_recorded_time = pg_atomic_read_u64(&limiter->last_recorded_time_us);
if (now - last_recorded_time > USECS_PER_SEC)
if (now - limiter->last_recorded_time_us > USECS_PER_SEC)
{
/* Reset the rate limiter */
limiter->last_recorded_time_us = now;
limiter->sent_bytes = 0;
pg_atomic_write_u64(&limiter->last_recorded_time_us, now);
pg_atomic_write_u32(&limiter->should_limit, false);
pg_atomic_exchange_u32(&limiter->should_limit, 0);
}
limiter->sent_bytes += (endptr - startptr);
if (limiter->sent_bytes > max_wal_bytes)
{
pg_atomic_write_u32(&limiter->should_limit, true);
pg_atomic_exchange_u32(&limiter->should_limit, 1);
}
}
/* END_HADRON */

View File

@@ -10,7 +10,6 @@ testing = ["dep:tokio-postgres"]
[dependencies]
ahash.workspace = true
alloc-metrics.workspace = true
anyhow.workspace = true
arc-swap.workspace = true
async-compression.workspace = true
@@ -121,6 +120,7 @@ workspace_hack.workspace = true
[dev-dependencies]
assert-json-diff.workspace = true
camino-tempfile.workspace = true
criterion.workspace = true
fallible-iterator.workspace = true
flate2.workspace = true
tokio-tungstenite.workspace = true
@@ -131,3 +131,8 @@ walkdir.workspace = true
rand_distr = "0.4"
tokio-postgres.workspace = true
tracing-test = "0.2"
[[bench]]
name = "logging"
harness = false

127
proxy/benches/logging.rs Normal file
View File

@@ -0,0 +1,127 @@
use std::io;
use criterion::{Criterion, criterion_group, criterion_main};
use proxy::logging::{Clock, JsonLoggingLayer};
use tracing_subscriber::prelude::*;
struct DevNullWriter;
impl proxy::logging::MakeWriter for DevNullWriter {
fn make_writer(&self) -> impl io::Write {
DevNullWriter
}
}
impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for DevNullWriter {
type Writer = DevNullWriter;
fn make_writer(&'a self) -> Self::Writer {
DevNullWriter
}
}
impl io::Write for DevNullWriter {
#[inline]
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
Ok(criterion::black_box(buf).len())
}
#[inline(always)]
fn flush(&mut self) -> io::Result<()> {
Ok(())
}
}
struct FixedClock;
impl Clock for FixedClock {
fn now(&self) -> chrono::DateTime<chrono::Utc> {
const { chrono::DateTime::from_timestamp_nanos(1747859990_000_000_000).to_utc() }
}
}
pub fn bench_logging(c: &mut Criterion) {
c.bench_function("text fmt current", |b| {
let registry = tracing_subscriber::Registry::default().with(
tracing_subscriber::fmt::layer()
.with_ansi(false)
.with_target(false)
.with_writer(DevNullWriter),
);
tracing::subscriber::with_default(registry, || {
tracing::info_span!("span1", a = 42, b = true, c = "string").in_scope(|| {
tracing::info_span!("span2", a = 42, b = true, c = "string").in_scope(|| {
b.iter(|| {
tracing::error!(a = 42, b = true, c = "string", "message field");
})
});
});
});
});
c.bench_function("text fmt full", |b| {
let registry = tracing_subscriber::Registry::default().with(
tracing_subscriber::fmt::layer()
.with_level(true)
.with_file(true)
.with_line_number(true)
.with_target(true)
.with_thread_ids(true)
.with_writer(DevNullWriter),
);
tracing::subscriber::with_default(registry, || {
tracing::info_span!("span1", a = 42, b = true, c = "string").in_scope(|| {
tracing::info_span!("span2", a = 42, b = true, c = "string").in_scope(|| {
b.iter(|| {
tracing::error!(a = 42, b = true, c = "string", "message field");
})
});
});
});
});
c.bench_function("json fmt", |b| {
let registry = tracing_subscriber::Registry::default().with(
tracing_subscriber::fmt::layer()
.with_level(true)
.with_file(true)
.with_line_number(true)
.with_target(true)
.with_thread_ids(true)
.with_writer(DevNullWriter)
.json(),
);
tracing::subscriber::with_default(registry, || {
tracing::info_span!("span1", a = 42, b = true, c = "string").in_scope(|| {
tracing::info_span!("span2", a = 42, b = true, c = "string").in_scope(|| {
b.iter(|| {
tracing::error!(a = 42, b = true, c = "string", "message field");
})
});
});
});
});
c.bench_function("json custom", |b| {
let registry = tracing_subscriber::Registry::default().with(JsonLoggingLayer::new(
FixedClock,
DevNullWriter,
&["a"],
));
tracing::subscriber::with_default(registry, || {
tracing::info_span!("span1", a = 42, b = true, c = "string").in_scope(|| {
tracing::info_span!("span2", a = 42, b = true, c = "string").in_scope(|| {
b.iter(|| {
tracing::error!(a = 42, b = true, c = "string", "message field");
})
});
});
});
});
}
criterion_group!(benches, bench_logging);
criterion_main!(benches);

View File

@@ -1,22 +1,11 @@
use alloc_metrics::TrackedAllocator;
use proxy::binary::proxy::MemoryContext;
use tikv_jemallocator::Jemalloc;
#[global_allocator]
// Safety: `MemoryContext` upholds the safety requirements.
static GLOBAL: TrackedAllocator<Jemalloc, MemoryContext> =
unsafe { TrackedAllocator::new(Jemalloc, MemoryContext::Root) };
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
#[allow(non_upper_case_globals)]
#[unsafe(export_name = "malloc_conf")]
pub static malloc_conf: &[u8] = b"prof:true,prof_active:true,lg_prof_sample:21\0";
fn main() -> anyhow::Result<()> {
GLOBAL.register_thread();
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.on_thread_start(|| GLOBAL.register_thread())
.build()
.expect("Failed building the Runtime")
.block_on(proxy::binary::proxy::run(&GLOBAL))
#[tokio::main]
async fn main() -> anyhow::Result<()> {
proxy::binary::proxy::run().await
}

View File

@@ -111,7 +111,7 @@ pub async fn run() -> anyhow::Result<()> {
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)), None);
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
// TODO: refactor these to use labels
debug!("Version: {GIT_VERSION}");

View File

@@ -80,7 +80,7 @@ pub async fn run() -> anyhow::Result<()> {
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)), None);
Metrics::install(Arc::new(ThreadPoolMetrics::new(0)));
let args = cli().get_matches();
let destination: String = args

View File

@@ -39,8 +39,7 @@ use crate::config::{
};
use crate::context::parquet::ParquetUploadArgs;
use crate::http::health_server::AppMetrics;
pub use crate::metrics::MemoryContext;
use crate::metrics::{Alloc, Metrics};
use crate::metrics::Metrics;
use crate::rate_limiter::{EndpointRateLimiter, RateBucketInfo, WakeComputeRateLimiter};
use crate::redis::connection_with_credentials_provider::ConnectionWithCredentialsProvider;
use crate::redis::kv_ops::RedisKVClient;
@@ -319,7 +318,7 @@ struct PgSniRouterArgs {
dest: Option<String>,
}
pub async fn run(alloc: &'static Alloc) -> anyhow::Result<()> {
pub async fn run() -> anyhow::Result<()> {
let _logging_guard = crate::logging::init().await?;
let _panic_hook_guard = utils::logging::replace_panic_hook_with_tracing_panic_hook();
let _sentry_guard = init_sentry(Some(GIT_VERSION.into()), &[]);
@@ -341,7 +340,7 @@ pub async fn run(alloc: &'static Alloc) -> anyhow::Result<()> {
};
let args = ProxyCliArgs::parse();
let config = build_config(&args, alloc)?;
let config = build_config(&args)?;
let auth_backend = build_auth_backend(&args)?;
match auth_backend {
@@ -590,12 +589,9 @@ pub async fn run(alloc: &'static Alloc) -> anyhow::Result<()> {
}
/// ProxyConfig is created at proxy startup, and lives forever.
fn build_config(
args: &ProxyCliArgs,
alloc: &'static Alloc,
) -> anyhow::Result<&'static ProxyConfig> {
fn build_config(args: &ProxyCliArgs) -> anyhow::Result<&'static ProxyConfig> {
let thread_pool = ThreadPool::new(args.scram_thread_pool_size);
Metrics::install(thread_pool.metrics.clone(), Some(alloc));
Metrics::install(thread_pool.metrics.clone());
let tls_config = match (&args.tls_key, &args.tls_cert) {
(Some(key_path), Some(cert_path)) => Some(config::configure_tls(

View File

@@ -1,16 +1,17 @@
use std::collections::{HashMap, HashSet, hash_map};
use std::convert::Infallible;
use std::sync::atomic::AtomicU64;
use std::time::Duration;
use async_trait::async_trait;
use clashmap::ClashMap;
use clashmap::mapref::one::Ref;
use rand::{Rng, thread_rng};
use tokio::sync::Mutex;
use tokio::time::Instant;
use tracing::{debug, info};
use crate::config::ProjectInfoCacheOptions;
use crate::control_plane::messages::{ControlPlaneErrorMessage, Reason};
use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::types::{EndpointId, RoleName};
@@ -21,53 +22,52 @@ pub(crate) trait ProjectInfoCache {
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt);
fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt);
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
async fn decrement_active_listeners(&self);
async fn increment_active_listeners(&self);
}
struct Entry<T> {
expires_at: Instant,
created_at: Instant,
value: T,
}
impl<T> Entry<T> {
pub(crate) fn new(value: T, ttl: Duration) -> Self {
pub(crate) fn new(value: T) -> Self {
Self {
expires_at: Instant::now() + ttl,
created_at: Instant::now(),
value,
}
}
pub(crate) fn get(&self) -> Option<&T> {
(!self.is_expired()).then_some(&self.value)
pub(crate) fn get(&self, valid_since: Instant) -> Option<&T> {
(valid_since < self.created_at).then_some(&self.value)
}
}
fn is_expired(&self) -> bool {
self.expires_at <= Instant::now()
impl<T> From<T> for Entry<T> {
fn from(value: T) -> Self {
Self::new(value)
}
}
struct EndpointInfo {
role_controls: HashMap<RoleNameInt, Entry<ControlPlaneResult<RoleAccessControl>>>,
controls: Option<Entry<ControlPlaneResult<EndpointAccessControl>>>,
role_controls: HashMap<RoleNameInt, Entry<RoleAccessControl>>,
controls: Option<Entry<EndpointAccessControl>>,
}
type ControlPlaneResult<T> = Result<T, Box<ControlPlaneErrorMessage>>;
impl EndpointInfo {
pub(crate) fn get_role_secret_with_ttl(
pub(crate) fn get_role_secret(
&self,
role_name: RoleNameInt,
) -> Option<(ControlPlaneResult<RoleAccessControl>, Duration)> {
let entry = self.role_controls.get(&role_name)?;
let ttl = entry.expires_at - Instant::now();
Some((entry.get()?.clone(), ttl))
valid_since: Instant,
) -> Option<RoleAccessControl> {
let controls = self.role_controls.get(&role_name)?;
controls.get(valid_since).cloned()
}
pub(crate) fn get_controls_with_ttl(
&self,
) -> Option<(ControlPlaneResult<EndpointAccessControl>, Duration)> {
let entry = self.controls.as_ref()?;
let ttl = entry.expires_at - Instant::now();
Some((entry.get()?.clone(), ttl))
pub(crate) fn get_controls(&self, valid_since: Instant) -> Option<EndpointAccessControl> {
let controls = self.controls.as_ref()?;
controls.get(valid_since).cloned()
}
pub(crate) fn invalidate_endpoint(&mut self) {
@@ -92,8 +92,11 @@ pub struct ProjectInfoCacheImpl {
project2ep: ClashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
// FIXME(stefan): we need a way to GC the account2ep map.
account2ep: ClashMap<AccountIdInt, HashSet<EndpointIdInt>>,
config: ProjectInfoCacheOptions,
start_time: Instant,
ttl_disabled_since_us: AtomicU64,
active_listeners_lock: Mutex<usize>,
}
#[async_trait]
@@ -149,6 +152,29 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
}
}
}
async fn decrement_active_listeners(&self) {
let mut listeners_guard = self.active_listeners_lock.lock().await;
if *listeners_guard == 0 {
tracing::error!("active_listeners count is already 0, something is broken");
return;
}
*listeners_guard -= 1;
if *listeners_guard == 0 {
self.ttl_disabled_since_us
.store(u64::MAX, std::sync::atomic::Ordering::SeqCst);
}
}
async fn increment_active_listeners(&self) {
let mut listeners_guard = self.active_listeners_lock.lock().await;
*listeners_guard += 1;
if *listeners_guard == 1 {
let new_ttl = (self.start_time.elapsed() + self.config.ttl).as_micros() as u64;
self.ttl_disabled_since_us
.store(new_ttl, std::sync::atomic::Ordering::SeqCst);
}
}
}
impl ProjectInfoCacheImpl {
@@ -158,6 +184,9 @@ impl ProjectInfoCacheImpl {
project2ep: ClashMap::new(),
account2ep: ClashMap::new(),
config,
ttl_disabled_since_us: AtomicU64::new(u64::MAX),
start_time: Instant::now(),
active_listeners_lock: Mutex::new(0),
}
}
@@ -169,28 +198,30 @@ impl ProjectInfoCacheImpl {
self.cache.get(&endpoint_id)
}
pub(crate) fn get_role_secret_with_ttl(
pub(crate) fn get_role_secret(
&self,
endpoint_id: &EndpointId,
role_name: &RoleName,
) -> Option<(ControlPlaneResult<RoleAccessControl>, Duration)> {
) -> Option<RoleAccessControl> {
let valid_since = self.get_cache_times();
let role_name = RoleNameInt::get(role_name)?;
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
endpoint_info.get_role_secret_with_ttl(role_name)
endpoint_info.get_role_secret(role_name, valid_since)
}
pub(crate) fn get_endpoint_access_with_ttl(
pub(crate) fn get_endpoint_access(
&self,
endpoint_id: &EndpointId,
) -> Option<(ControlPlaneResult<EndpointAccessControl>, Duration)> {
) -> Option<EndpointAccessControl> {
let valid_since = self.get_cache_times();
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
endpoint_info.get_controls_with_ttl()
endpoint_info.get_controls(valid_since)
}
pub(crate) fn insert_endpoint_access(
&self,
account_id: Option<AccountIdInt>,
project_id: Option<ProjectIdInt>,
project_id: ProjectIdInt,
endpoint_id: EndpointIdInt,
role_name: RoleNameInt,
controls: EndpointAccessControl,
@@ -199,89 +230,26 @@ impl ProjectInfoCacheImpl {
if let Some(account_id) = account_id {
self.insert_account2endpoint(account_id, endpoint_id);
}
if let Some(project_id) = project_id {
self.insert_project2endpoint(project_id, endpoint_id);
}
self.insert_project2endpoint(project_id, endpoint_id);
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
debug!(
key = &*endpoint_id,
"created a cache entry for endpoint access"
);
let controls = Some(Entry::new(Ok(controls), self.config.ttl));
let role_controls = Entry::new(Ok(role_controls), self.config.ttl);
let controls = Entry::from(controls);
let role_controls = Entry::from(role_controls);
match self.cache.entry(endpoint_id) {
clashmap::Entry::Vacant(e) => {
e.insert(EndpointInfo {
role_controls: HashMap::from_iter([(role_name, role_controls)]),
controls,
controls: Some(controls),
});
}
clashmap::Entry::Occupied(mut e) => {
let ep = e.get_mut();
ep.controls = controls;
if ep.role_controls.len() < self.config.max_roles {
ep.role_controls.insert(role_name, role_controls);
}
}
}
}
pub(crate) fn insert_endpoint_access_err(
&self,
endpoint_id: EndpointIdInt,
role_name: RoleNameInt,
msg: Box<ControlPlaneErrorMessage>,
ttl: Option<Duration>,
) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
debug!(
key = &*endpoint_id,
"created a cache entry for an endpoint access error"
);
let ttl = ttl.unwrap_or(self.config.ttl);
let controls = if msg.get_reason() == Reason::RoleProtected {
// RoleProtected is the only role-specific error that control plane can give us.
// If a given role name does not exist, it still returns a successful response,
// just with an empty secret.
None
} else {
// We can cache all the other errors in EndpointInfo.controls,
// because they don't depend on what role name we pass to control plane.
Some(Entry::new(Err(msg.clone()), ttl))
};
let role_controls = Entry::new(Err(msg), ttl);
match self.cache.entry(endpoint_id) {
clashmap::Entry::Vacant(e) => {
e.insert(EndpointInfo {
role_controls: HashMap::from_iter([(role_name, role_controls)]),
controls,
});
}
clashmap::Entry::Occupied(mut e) => {
let ep = e.get_mut();
if let Some(entry) = &ep.controls
&& !entry.is_expired()
&& entry.value.is_ok()
{
// If we have cached non-expired, non-error controls, keep them.
} else {
ep.controls = controls;
}
ep.controls = Some(controls);
if ep.role_controls.len() < self.config.max_roles {
ep.role_controls.insert(role_name, role_controls);
}
@@ -307,6 +275,27 @@ impl ProjectInfoCacheImpl {
}
}
fn ignore_ttl_since(&self) -> Option<Instant> {
let ttl_disabled_since_us = self
.ttl_disabled_since_us
.load(std::sync::atomic::Ordering::Relaxed);
if ttl_disabled_since_us == u64::MAX {
return None;
}
Some(self.start_time + Duration::from_micros(ttl_disabled_since_us))
}
fn get_cache_times(&self) -> Instant {
let mut valid_since = Instant::now() - self.config.ttl;
if let Some(ignore_ttl_since) = self.ignore_ttl_since() {
// The entry is fine if it is not older than the TTL, or if it was added before we started getting notifications.
valid_since = valid_since.min(ignore_ttl_since);
}
valid_since
}
pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) {
let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else {
return;
@@ -324,7 +313,16 @@ impl ProjectInfoCacheImpl {
return;
};
if role_controls.get().is_expired() {
let created_at = role_controls.get().created_at;
let expire = match self.ignore_ttl_since() {
// if ignoring TTL, we should still try to roll the password if it's old
// and the client gave an incorrect password. There could be some lag on the redis channel.
Some(_) => created_at + self.config.ttl < Instant::now(),
// edge case: redis is down, let's be generous and invalidate the cache immediately.
None => true,
};
if expire {
role_controls.remove();
}
}
@@ -363,11 +361,13 @@ impl ProjectInfoCacheImpl {
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
use crate::control_plane::messages::{Details, EndpointRateLimitConfig, ErrorInfo, Status};
use crate::control_plane::messages::EndpointRateLimitConfig;
use crate::control_plane::{AccessBlockerFlags, AuthSecret};
use crate::scram::ServerSecret;
use std::sync::Arc;
use crate::types::ProjectId;
#[tokio::test]
async fn test_project_info_cache_settings() {
@@ -378,9 +378,9 @@ mod tests {
ttl: Duration::from_secs(1),
gc_interval: Duration::from_secs(600),
});
let project_id: Option<ProjectIdInt> = Some(ProjectIdInt::from(&"project".into()));
let project_id: ProjectId = "project".into();
let endpoint_id: EndpointId = "endpoint".into();
let account_id = None;
let account_id: Option<AccountIdInt> = None;
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
@@ -393,7 +393,7 @@ mod tests {
cache.insert_endpoint_access(
account_id,
project_id,
(&project_id).into(),
(&endpoint_id).into(),
(&user1).into(),
EndpointAccessControl {
@@ -409,7 +409,7 @@ mod tests {
cache.insert_endpoint_access(
account_id,
project_id,
(&project_id).into(),
(&endpoint_id).into(),
(&user2).into(),
EndpointAccessControl {
@@ -423,17 +423,11 @@ mod tests {
},
);
let (cached, ttl) = cache
.get_role_secret_with_ttl(&endpoint_id, &user1)
.unwrap();
assert_eq!(cached.unwrap().secret, secret1);
assert_eq!(ttl, cache.config.ttl);
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
assert_eq!(cached.secret, secret1);
let (cached, ttl) = cache
.get_role_secret_with_ttl(&endpoint_id, &user2)
.unwrap();
assert_eq!(cached.unwrap().secret, secret2);
assert_eq!(ttl, cache.config.ttl);
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
assert_eq!(cached.secret, secret2);
// Shouldn't add more than 2 roles.
let user3: RoleName = "user3".into();
@@ -441,7 +435,7 @@ mod tests {
cache.insert_endpoint_access(
account_id,
project_id,
(&project_id).into(),
(&endpoint_id).into(),
(&user3).into(),
EndpointAccessControl {
@@ -455,144 +449,17 @@ mod tests {
},
);
assert!(
cache
.get_role_secret_with_ttl(&endpoint_id, &user3)
.is_none()
);
assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());
let cached = cache
.get_endpoint_access_with_ttl(&endpoint_id)
.unwrap()
.0
.unwrap();
let cached = cache.get_endpoint_access(&endpoint_id).unwrap();
assert_eq!(cached.allowed_ips, allowed_ips);
tokio::time::advance(Duration::from_secs(2)).await;
let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user1);
let cached = cache.get_role_secret(&endpoint_id, &user1);
assert!(cached.is_none());
let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user2);
let cached = cache.get_role_secret(&endpoint_id, &user2);
assert!(cached.is_none());
let cached = cache.get_endpoint_access_with_ttl(&endpoint_id);
let cached = cache.get_endpoint_access(&endpoint_id);
assert!(cached.is_none());
}
#[tokio::test]
async fn test_caching_project_info_errors() {
let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
size: 10,
max_roles: 10,
ttl: Duration::from_secs(1),
gc_interval: Duration::from_secs(600),
});
let project_id = Some(ProjectIdInt::from(&"project".into()));
let endpoint_id: EndpointId = "endpoint".into();
let account_id = None;
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
let secret = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
let role_msg = Box::new(ControlPlaneErrorMessage {
error: "role is protected and cannot be used for password-based authentication"
.to_owned()
.into_boxed_str(),
http_status_code: http::StatusCode::NOT_FOUND,
status: Some(Status {
code: "PERMISSION_DENIED".to_owned().into_boxed_str(),
message: "role is protected and cannot be used for password-based authentication"
.to_owned()
.into_boxed_str(),
details: Details {
error_info: Some(ErrorInfo {
reason: Reason::RoleProtected,
}),
retry_info: None,
user_facing_message: None,
},
}),
});
let generic_msg = Box::new(ControlPlaneErrorMessage {
error: "oh noes".to_owned().into_boxed_str(),
http_status_code: http::StatusCode::NOT_FOUND,
status: None,
});
let get_role_secret = |endpoint_id, role_name| {
cache
.get_role_secret_with_ttl(endpoint_id, role_name)
.unwrap()
.0
};
let get_endpoint_access =
|endpoint_id| cache.get_endpoint_access_with_ttl(endpoint_id).unwrap().0;
// stores role-specific errors only for get_role_secret
cache.insert_endpoint_access_err(
(&endpoint_id).into(),
(&user1).into(),
role_msg.clone(),
None,
);
assert_eq!(
get_role_secret(&endpoint_id, &user1).unwrap_err().error,
role_msg.error
);
assert!(cache.get_endpoint_access_with_ttl(&endpoint_id).is_none());
// stores non-role specific errors for both get_role_secret and get_endpoint_access
cache.insert_endpoint_access_err(
(&endpoint_id).into(),
(&user1).into(),
generic_msg.clone(),
None,
);
assert_eq!(
get_role_secret(&endpoint_id, &user1).unwrap_err().error,
generic_msg.error
);
assert_eq!(
get_endpoint_access(&endpoint_id).unwrap_err().error,
generic_msg.error
);
// error isn't returned for other roles in the same endpoint
assert!(
cache
.get_role_secret_with_ttl(&endpoint_id, &user2)
.is_none()
);
// success for a role does not overwrite errors for other roles
cache.insert_endpoint_access(
account_id,
project_id,
(&endpoint_id).into(),
(&user2).into(),
EndpointAccessControl {
allowed_ips: Arc::new(vec![]),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
rate_limits: EndpointRateLimitConfig::default(),
},
RoleAccessControl {
secret: secret.clone(),
},
);
assert!(get_role_secret(&endpoint_id, &user1).is_err());
assert!(get_role_secret(&endpoint_id, &user2).is_ok());
// ...but does clear the access control error
assert!(get_endpoint_access(&endpoint_id).is_ok());
// storing an error does not overwrite successful access control response
cache.insert_endpoint_access_err(
(&endpoint_id).into(),
(&user2).into(),
generic_msg.clone(),
None,
);
assert!(get_role_secret(&endpoint_id, &user2).is_err());
assert!(get_endpoint_access(&endpoint_id).is_ok());
}
}

Some files were not shown because too many files have changed in this diff.