Compare commits

..

2 Commits

Author SHA1 Message Date
Folke Behrens
dd99ad6dc7 proxy: queue UNLINK command when conn terminates 2025-07-15 17:43:48 +02:00
Folke Behrens
11a804a3ac Split enqueuing and driving state machine, reduce noise 2025-07-15 17:42:37 +02:00
102 changed files with 1289 additions and 5072 deletions

View File

@@ -181,8 +181,6 @@ runs:
# Ref https://github.com/neondatabase/neon/issues/4540
# cov_prefix=(scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/coverage run)
cov_prefix=()
# Explicitly set LLVM_PROFILE_FILE to /dev/null to avoid writing *.profraw files
export LLVM_PROFILE_FILE=/dev/null
else
cov_prefix=()
fi

View File

@@ -87,27 +87,22 @@ jobs:
uses: ./.github/workflows/build-build-tools-image.yml
secrets: inherit
lint-yamls:
needs: [ meta, check-permissions, build-build-tools-image ]
lint-openapi-spec:
runs-on: ubuntu-22.04
needs: [ meta, check-permissions ]
# We do need to run this in `.*-rc-pr` because of hotfixes.
if: ${{ contains(fromJSON('["pr", "push-main", "storage-rc-pr", "proxy-rc-pr", "compute-rc-pr"]'), needs.meta.outputs.run-kind) }}
runs-on: [ self-hosted, small ]
container:
image: ${{ needs.build-build-tools-image.outputs.image }}
credentials:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
options: --init
steps:
- name: Harden the runner (Audit all outbound calls)
uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0
with:
egress-policy: audit
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- run: make -C compute manifest-schema-validation
- uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- run: make lint-openapi-spec
check-codestyle-python:
@@ -222,6 +217,28 @@ jobs:
build-tools-image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm
secrets: inherit
validate-compute-manifest:
runs-on: ubuntu-22.04
needs: [ meta, check-permissions ]
# We do need to run this in `.*-rc-pr` because of hotfixes.
if: ${{ contains(fromJSON('["pr", "push-main", "storage-rc-pr", "proxy-rc-pr", "compute-rc-pr"]'), needs.meta.outputs.run-kind) }}
steps:
- name: Harden the runner (Audit all outbound calls)
uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0
with:
egress-policy: audit
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Node.js
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
with:
node-version: '24'
- name: Validate manifest against schema
run: |
make -C compute manifest-schema-validation
build-and-test-locally:
needs: [ meta, build-build-tools-image ]
# We do need to run this in `.*-rc-pr` because of hotfixes.

3
.gitignore vendored
View File

@@ -29,6 +29,3 @@ docker-compose/docker-compose-parallel.yml
# pgindent typedef lists
*.list
# Node
**/node_modules/

4
Cargo.lock generated
View File

@@ -6204,7 +6204,6 @@ dependencies = [
"itertools 0.10.5",
"jsonwebtoken",
"metrics",
"nix 0.30.1",
"once_cell",
"pageserver_api",
"parking_lot 0.12.1",
@@ -6212,7 +6211,6 @@ dependencies = [
"postgres-protocol",
"postgres_backend",
"postgres_ffi",
"postgres_ffi_types",
"postgres_versioninfo",
"pprof",
"pq_proto",
@@ -6257,7 +6255,7 @@ dependencies = [
"anyhow",
"const_format",
"pageserver_api",
"postgres_ffi_types",
"postgres_ffi",
"postgres_versioninfo",
"pq_proto",
"serde",

View File

@@ -2,7 +2,7 @@ ROOT_PROJECT_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
# Where to install Postgres, default is ./pg_install, maybe useful for package
# managers.
POSTGRES_INSTALL_DIR ?= $(ROOT_PROJECT_DIR)/pg_install
POSTGRES_INSTALL_DIR ?= $(ROOT_PROJECT_DIR)/pg_install/
# Supported PostgreSQL versions
POSTGRES_VERSIONS = v17 v16 v15 v14
@@ -14,7 +14,7 @@ POSTGRES_VERSIONS = v17 v16 v15 v14
# it is derived from BUILD_TYPE.
# All intermediate build artifacts are stored here.
BUILD_DIR := $(ROOT_PROJECT_DIR)/build
BUILD_DIR := build
ICU_PREFIX_DIR := /usr/local/icu
@@ -212,7 +212,7 @@ neon-pgindent: postgres-v17-pg-bsd-indent neon-pg-ext-v17
FIND_TYPEDEF=$(ROOT_PROJECT_DIR)/vendor/postgres-v17/src/tools/find_typedef \
INDENT=$(BUILD_DIR)/v17/src/tools/pg_bsd_indent/pg_bsd_indent \
PGINDENT_SCRIPT=$(ROOT_PROJECT_DIR)/vendor/postgres-v17/src/tools/pgindent/pgindent \
-C $(BUILD_DIR)/pgxn-v17/neon \
-C $(BUILD_DIR)/neon-v17 \
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile pgindent
@@ -220,15 +220,11 @@ neon-pgindent: postgres-v17-pg-bsd-indent neon-pg-ext-v17
setup-pre-commit-hook:
ln -s -f $(ROOT_PROJECT_DIR)/pre-commit.py .git/hooks/pre-commit
build-tools/node_modules: build-tools/package.json
cd build-tools && $(if $(CI),npm ci,npm install)
touch build-tools/node_modules
.PHONY: lint-openapi-spec
lint-openapi-spec: build-tools/node_modules
lint-openapi-spec:
# operation-2xx-response: pageserver timeline delete returns 404 on success
find . -iname "openapi_spec.y*ml" -exec\
npx --prefix=build-tools/ redocly\
docker run --rm -v ${PWD}:/spec ghcr.io/redocly/cli:1.34.4\
--skip-rule=operation-operationId --skip-rule=operation-summary --extends=minimal\
--skip-rule=no-server-example.com --skip-rule=operation-2xx-response\
lint {} \+

View File

@@ -188,12 +188,6 @@ RUN curl -fsSL 'https://apt.llvm.org/llvm-snapshot.gpg.key' | apt-key add - \
&& bash -c 'for f in /usr/bin/clang*-${LLVM_VERSION} /usr/bin/llvm*-${LLVM_VERSION}; do ln -s "${f}" "${f%-${LLVM_VERSION}}"; done' \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
# Install node
ENV NODE_VERSION=24
RUN curl -fsSL https://deb.nodesource.com/setup_${NODE_VERSION}.x | bash - \
&& apt install -y nodejs \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
# Install docker
RUN curl -fsSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /usr/share/keyrings/docker-archive-keyring.gpg \
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/debian ${DEBIAN_VERSION} stable" > /etc/apt/sources.list.d/docker.list \
@@ -317,14 +311,14 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux
. "$HOME/.cargo/env" && \
cargo --version && rustup --version && \
rustup component add llvm-tools rustfmt clippy && \
cargo install rustfilt --locked --version ${RUSTFILT_VERSION} && \
cargo install cargo-hakari --locked --version ${CARGO_HAKARI_VERSION} && \
cargo install cargo-deny --locked --version ${CARGO_DENY_VERSION} && \
cargo install cargo-hack --locked --version ${CARGO_HACK_VERSION} && \
cargo install cargo-nextest --locked --version ${CARGO_NEXTEST_VERSION} && \
cargo install cargo-chef --locked --version ${CARGO_CHEF_VERSION} && \
cargo install diesel_cli --locked --version ${CARGO_DIESEL_CLI_VERSION} \
--features postgres-bundled --no-default-features && \
cargo install rustfilt --version ${RUSTFILT_VERSION} --locked && \
cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} --locked && \
cargo install cargo-deny --version ${CARGO_DENY_VERSION} --locked && \
cargo install cargo-hack --version ${CARGO_HACK_VERSION} --locked && \
cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} --locked && \
cargo install cargo-chef --version ${CARGO_CHEF_VERSION} --locked && \
cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} --locked \
--features postgres-bundled --no-default-features && \
rm -rf /home/nonroot/.cargo/registry && \
rm -rf /home/nonroot/.cargo/git

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +0,0 @@
{
"name": "build-tools",
"private": true,
"devDependencies": {
"@redocly/cli": "1.34.4",
"@sourcemeta/jsonschema": "10.0.0"
}
}

View File

@@ -50,9 +50,9 @@ jsonnetfmt-format:
jsonnetfmt --in-place $(jsonnet_files)
.PHONY: manifest-schema-validation
manifest-schema-validation: ../build-tools/node_modules
npx --prefix=../build-tools/ jsonschema validate -d https://json-schema.org/draft/2020-12/schema manifest.schema.json manifest.yaml
manifest-schema-validation: node_modules
node_modules/.bin/jsonschema validate -d https://json-schema.org/draft/2020-12/schema manifest.schema.json manifest.yaml
../build-tools/node_modules: ../build-tools/package.json
cd ../build-tools && $(if $(CI),npm ci,npm install)
touch ../build-tools/node_modules
node_modules: package.json
npm install
touch node_modules

View File

@@ -170,29 +170,7 @@ RUN case $DEBIAN_VERSION in \
FROM build-deps AS pg-build
ARG PG_VERSION
COPY vendor/postgres-${PG_VERSION:?} postgres
COPY compute/patches/postgres_fdw.patch .
COPY compute/patches/pg_stat_statements_pg14-16.patch .
COPY compute/patches/pg_stat_statements_pg17.patch .
RUN cd postgres && \
# Apply patches to some contrib extensions
# For example, we need to grant EXECUTE on pg_stat_statements_reset() to {privileged_role_name}.
# In vanilla Postgres this function is limited to Postgres role superuser.
# In Neon we have {privileged_role_name} role that is not a superuser but replaces superuser in some cases.
# We could add the additional GRANT statements to the Postgres repository, but that would be hard to maintain
# whenever we need to pick up a new Postgres version, and we want to limit the changes in our Postgres fork,
# so we do it here.
case "${PG_VERSION}" in \
"v14" | "v15" | "v16") \
patch -p1 < /pg_stat_statements_pg14-16.patch; \
;; \
"v17") \
patch -p1 < /pg_stat_statements_pg17.patch; \
;; \
*) \
# Fail here so that we do not forget to migrate the patches to the next major version
echo "No contrib patches for this PostgreSQL version" && exit 1;; \
esac && \
patch -p1 < /postgres_fdw.patch && \
export CONFIGURE_CMD="./configure CFLAGS='-O2 -g3 -fsigned-char' --enable-debug --with-openssl --with-uuid=ossp \
--with-icu --with-libxml --with-libxslt --with-lz4" && \
if [ "${PG_VERSION:?}" != "v14" ]; then \
@@ -206,6 +184,8 @@ RUN cd postgres && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/autoinc.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/dblink.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/postgres_fdw.control && \
file=/usr/local/pgsql/share/extension/postgres_fdw--1.0.sql && [ -e $file ] && \
echo 'GRANT USAGE ON FOREIGN DATA WRAPPER postgres_fdw TO neon_superuser;' >> $file && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/bloom.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/earthdistance.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/insert_username.control && \
@@ -215,7 +195,34 @@ RUN cd postgres && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgrowlocks.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgstattuple.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/refint.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/xml2.control
echo 'trusted = true' >> /usr/local/pgsql/share/extension/xml2.control && \
# We need to grant EXECUTE on pg_stat_statements_reset() to neon_superuser.
# In vanilla postgres this function is limited to Postgres role superuser.
# In neon we have neon_superuser role that is not a superuser but replaces superuser in some cases.
# We could add the additional GRANT statements to the postgres repository, but that would be hard to maintain
# whenever we need to pick up a new postgres version, and we want to limit the changes in our postgres fork,
# so we do it here.
for file in /usr/local/pgsql/share/extension/pg_stat_statements--*.sql; do \
filename=$(basename "$file"); \
# Note that there are no downgrade scripts for pg_stat_statements, so we \
# don't have to modify any downgrade paths or (much) older versions: we only \
# have to make sure every creation of the pg_stat_statements_reset function \
# also adds execute permissions to the neon_superuser.
case $filename in \
pg_stat_statements--1.4.sql) \
# pg_stat_statements_reset is first created with 1.4
echo 'GRANT EXECUTE ON FUNCTION pg_stat_statements_reset() TO neon_superuser;' >> $file; \
;; \
pg_stat_statements--1.6--1.7.sql) \
# Then with the 1.6-1.7 migration it is re-created with a new signature, thus add the permissions back
echo 'GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) TO neon_superuser;' >> $file; \
;; \
pg_stat_statements--1.10--1.11.sql) \
# Then with the 1.10-1.11 migration it is re-created with a new signature again, thus add the permissions back
echo 'GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint, boolean) TO neon_superuser;' >> $file; \
;; \
esac; \
done;
# Set PATH for all the subsequent build steps
ENV PATH="/usr/local/pgsql/bin:$PATH"
@@ -1517,7 +1524,7 @@ WORKDIR /ext-src
COPY compute/patches/pg_duckdb_v031.patch .
COPY compute/patches/duckdb_v120.patch .
# pg_duckdb build requires source dir to be a git repo to get submodules
# allow {privileged_role_name} to execute some functions that in pg_duckdb are available to superuser only:
# allow neon_superuser to execute some functions that in pg_duckdb are available to superuser only:
# - extension management function duckdb.install_extension()
# - access to duckdb.extensions table and its sequence
RUN git clone --depth 1 --branch v0.3.1 https://github.com/duckdb/pg_duckdb.git pg_duckdb-src && \

7
compute/package.json Normal file
View File

@@ -0,0 +1,7 @@
{
"name": "neon-compute",
"private": true,
"dependencies": {
"@sourcemeta/jsonschema": "9.3.4"
}
}

View File

@@ -1,26 +1,22 @@
diff --git a/sql/anon.sql b/sql/anon.sql
index 0cdc769..5eab1d6 100644
index 0cdc769..b450327 100644
--- a/sql/anon.sql
+++ b/sql/anon.sql
@@ -1141,3 +1141,19 @@ $$
@@ -1141,3 +1141,15 @@ $$
-- TODO : https://en.wikipedia.org/wiki/L-diversity
-- TODO : https://en.wikipedia.org/wiki/T-closeness
+
+-- NEON Patches
+
+GRANT ALL ON SCHEMA anon to neon_superuser;
+GRANT ALL ON ALL TABLES IN SCHEMA anon TO neon_superuser;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT ALL ON SCHEMA anon to %I', privileged_role_name);
+ EXECUTE format('GRANT ALL ON ALL TABLES IN SCHEMA anon TO %I', privileged_role_name);
+
+ IF current_setting('server_version_num')::int >= 150000 THEN
+ EXECUTE format('GRANT SET ON PARAMETER anon.transparent_dynamic_masking TO %I', privileged_role_name);
+ END IF;
+ IF current_setting('server_version_num')::int >= 150000 THEN
+ GRANT SET ON PARAMETER anon.transparent_dynamic_masking TO neon_superuser;
+ END IF;
+END $$;
diff --git a/sql/init.sql b/sql/init.sql
index 7da6553..9b6164b 100644

View File

@@ -21,21 +21,13 @@ index 3235cc8..6b892bc 100644
include Makefile.global
diff --git a/sql/pg_duckdb--0.2.0--0.3.0.sql b/sql/pg_duckdb--0.2.0--0.3.0.sql
index d777d76..3b54396 100644
index d777d76..af60106 100644
--- a/sql/pg_duckdb--0.2.0--0.3.0.sql
+++ b/sql/pg_duckdb--0.2.0--0.3.0.sql
@@ -1056,3 +1056,14 @@ GRANT ALL ON FUNCTION duckdb.cache(TEXT, TEXT) TO PUBLIC;
@@ -1056,3 +1056,6 @@ GRANT ALL ON FUNCTION duckdb.cache(TEXT, TEXT) TO PUBLIC;
GRANT ALL ON FUNCTION duckdb.cache_info() TO PUBLIC;
GRANT ALL ON FUNCTION duckdb.cache_delete(TEXT) TO PUBLIC;
GRANT ALL ON PROCEDURE duckdb.recycle_ddb() TO PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT ALL ON FUNCTION duckdb.install_extension(TEXT) TO %I', privileged_role_name);
+ EXECUTE format('GRANT ALL ON TABLE duckdb.extensions TO %I', privileged_role_name);
+ EXECUTE format('GRANT ALL ON SEQUENCE duckdb.extensions_table_seq TO %I', privileged_role_name);
+END $$;
+GRANT ALL ON FUNCTION duckdb.install_extension(TEXT) TO neon_superuser;
+GRANT ALL ON TABLE duckdb.extensions TO neon_superuser;
+GRANT ALL ON SEQUENCE duckdb.extensions_table_seq TO neon_superuser;

View File

@@ -1,34 +0,0 @@
diff --git a/contrib/pg_stat_statements/pg_stat_statements--1.4.sql b/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
index 58cdf600fce..8be57a996f6 100644
--- a/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
+++ b/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
@@ -46,3 +46,12 @@ GRANT SELECT ON pg_stat_statements TO PUBLIC;
-- Don't want this to be available to non-superusers.
REVOKE ALL ON FUNCTION pg_stat_statements_reset() FROM PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT EXECUTE ON FUNCTION pg_stat_statements_reset() TO %I', privileged_role_name);
+END $$;
diff --git a/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql b/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
index 6fc3fed4c93..256345a8f79 100644
--- a/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
+++ b/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
@@ -20,3 +20,12 @@ LANGUAGE C STRICT PARALLEL SAFE;
-- Don't want this to be available to non-superusers.
REVOKE ALL ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) FROM PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) TO %I', privileged_role_name);
+END $$;

View File

@@ -1,52 +0,0 @@
diff --git a/contrib/pg_stat_statements/pg_stat_statements--1.10--1.11.sql b/contrib/pg_stat_statements/pg_stat_statements--1.10--1.11.sql
index 0bb2c397711..32764db1d8b 100644
--- a/contrib/pg_stat_statements/pg_stat_statements--1.10--1.11.sql
+++ b/contrib/pg_stat_statements/pg_stat_statements--1.10--1.11.sql
@@ -80,3 +80,12 @@ LANGUAGE C STRICT PARALLEL SAFE;
-- Don't want this to be available to non-superusers.
REVOKE ALL ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint, boolean) FROM PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint, boolean) TO %I', privileged_role_name);
+END $$;
\ No newline at end of file
diff --git a/contrib/pg_stat_statements/pg_stat_statements--1.4.sql b/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
index 58cdf600fce..8be57a996f6 100644
--- a/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
+++ b/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
@@ -46,3 +46,12 @@ GRANT SELECT ON pg_stat_statements TO PUBLIC;
-- Don't want this to be available to non-superusers.
REVOKE ALL ON FUNCTION pg_stat_statements_reset() FROM PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT EXECUTE ON FUNCTION pg_stat_statements_reset() TO %I', privileged_role_name);
+END $$;
diff --git a/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql b/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
index 6fc3fed4c93..256345a8f79 100644
--- a/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
+++ b/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
@@ -20,3 +20,12 @@ LANGUAGE C STRICT PARALLEL SAFE;
-- Don't want this to be available to non-superusers.
REVOKE ALL ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) FROM PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) TO %I', privileged_role_name);
+END $$;

View File

@@ -1,17 +0,0 @@
diff --git a/contrib/postgres_fdw/postgres_fdw--1.0.sql b/contrib/postgres_fdw/postgres_fdw--1.0.sql
index a0f0fc1bf45..ee077f2eea6 100644
--- a/contrib/postgres_fdw/postgres_fdw--1.0.sql
+++ b/contrib/postgres_fdw/postgres_fdw--1.0.sql
@@ -16,3 +16,12 @@ LANGUAGE C STRICT;
CREATE FOREIGN DATA WRAPPER postgres_fdw
HANDLER postgres_fdw_handler
VALIDATOR postgres_fdw_validator;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT USAGE ON FOREIGN DATA WRAPPER postgres_fdw TO %I', privileged_role_name);
+END $$;

View File

@@ -87,14 +87,6 @@ struct Cli {
#[arg(short = 'C', long, value_name = "DATABASE_URL")]
pub connstr: String,
#[arg(
long,
default_value = "neon_superuser",
value_name = "PRIVILEGED_ROLE_NAME",
value_parser = Self::parse_privileged_role_name
)]
pub privileged_role_name: String,
#[cfg(target_os = "linux")]
#[arg(long, default_value = "neon-postgres")]
pub cgroup: String,
@@ -157,21 +149,6 @@ impl Cli {
Ok(url)
}
/// For simplicity, we do not escape `privileged_role_name` anywhere in the code.
/// Since it's a system role, which we fully control, that's fine. Still, let's
/// validate it to avoid any surprises.
///
/// Returns the value as an owned `String` when it consists solely of one or
/// more lowercase ASCII letters and underscores; otherwise returns an error
/// describing the allowed characters.
fn parse_privileged_role_name(value: &str) -> Result<String> {
use regex::Regex;
// Anchored on both ends, so the whole value must match; `+` also rejects
// the empty string. The pattern is a fixed literal, so `unwrap` only guards
// against a typo in the regex itself.
let pattern = Regex::new(r"^[a-z_]+$").unwrap();
if !pattern.is_match(value) {
bail!("--privileged-role-name can only contain lowercase letters and underscores")
}
Ok(value.to_string())
}
}
fn main() -> Result<()> {
@@ -201,7 +178,6 @@ fn main() -> Result<()> {
ComputeNodeParams {
compute_id: cli.compute_id,
connstr,
privileged_role_name: cli.privileged_role_name.clone(),
pgdata: cli.pgdata.clone(),
pgbin: cli.pgbin.clone(),
pgversion: get_pg_version_string(&cli.pgbin),
@@ -351,49 +327,4 @@ mod test {
])
.expect_err("URL parameters are not allowed");
}
/// Checks that `--privileged-role-name` accepts a lowercase/underscore name
/// and that CLI parsing fails for uppercase, special-character, and empty values.
#[test]
fn verify_privileged_role_name() {
// Valid name: lowercase letters and an underscore are accepted as-is
let cli = Cli::parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--privileged-role-name",
"my_superuser",
]);
assert_eq!(cli.privileged_role_name, "my_superuser");
// Invalid names: each of the following must make argument parsing fail
// Mixed-case name
Cli::try_parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--privileged-role-name",
"NeonSuperuser",
])
.expect_err("uppercase letters are not allowed");
// Name containing `$` and `'`
Cli::try_parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--privileged-role-name",
"$'neon_superuser",
])
.expect_err("special characters are not allowed");
// Empty name
Cli::try_parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--privileged-role-name",
"",
])
.expect_err("empty name is not allowed");
}
}

View File

@@ -74,20 +74,12 @@ const DEFAULT_INSTALLED_EXTENSIONS_COLLECTION_INTERVAL: u64 = 3600;
/// Static configuration params that don't change after startup. These mostly
/// come from the CLI args, or are derived from them.
#[derive(Clone, Debug)]
pub struct ComputeNodeParams {
/// The ID of the compute
pub compute_id: String,
/// Url type maintains proper escaping
// Url type maintains proper escaping
pub connstr: url::Url,
/// The name of the 'weak' superuser role, which we give to the users.
/// It follows the allow list approach, i.e., we take a standard role
/// and grant it extra permissions with explicit GRANTs here and there,
/// and core patches.
pub privileged_role_name: String,
pub resize_swap_on_bind: bool,
pub set_disk_quota_for_fs: Option<String>,
@@ -1397,7 +1389,6 @@ impl ComputeNode {
self.create_pgdata()?;
config::write_postgres_conf(
pgdata_path,
&self.params,
&pspec.spec,
self.params.internal_http_port,
tls_config,
@@ -1746,7 +1737,6 @@ impl ComputeNode {
}
// Run migrations separately to not hold up cold starts
let params = self.params.clone();
tokio::spawn(async move {
let mut conf = conf.as_ref().clone();
conf.application_name("compute_ctl:migrations");
@@ -1758,7 +1748,7 @@ impl ComputeNode {
eprintln!("connection error: {e}");
}
});
if let Err(e) = handle_migrations(params, &mut client).await {
if let Err(e) = handle_migrations(&mut client).await {
error!("Failed to run migrations: {}", e);
}
}
@@ -1837,7 +1827,6 @@ impl ComputeNode {
let pgdata_path = Path::new(&self.params.pgdata);
config::write_postgres_conf(
pgdata_path,
&self.params,
&spec,
self.params.internal_http_port,
tls_config,

View File

@@ -9,7 +9,6 @@ use std::path::Path;
use compute_api::responses::TlsConfig;
use compute_api::spec::{ComputeAudit, ComputeMode, ComputeSpec, GenericOption};
use crate::compute::ComputeNodeParams;
use crate::pg_helpers::{
GenericOptionExt, GenericOptionsSearch, PgOptionsSerialize, escape_conf_value,
};
@@ -42,7 +41,6 @@ pub fn line_in_file(path: &Path, line: &str) -> Result<bool> {
/// Create or completely rewrite configuration file specified by `path`
pub fn write_postgres_conf(
pgdata_path: &Path,
params: &ComputeNodeParams,
spec: &ComputeSpec,
extension_server_port: u16,
tls_config: &Option<TlsConfig>,
@@ -163,12 +161,6 @@ pub fn write_postgres_conf(
}
}
writeln!(
file,
"neon.privileged_role_name={}",
escape_conf_value(params.privileged_role_name.as_str())
)?;
// If there are any extra options in the 'settings' field, append those
if spec.cluster.settings.is_some() {
writeln!(file, "# Managed by compute_ctl: begin")?;

View File

@@ -1 +0,0 @@
ALTER ROLE {privileged_role_name} BYPASSRLS;

View File

@@ -0,0 +1 @@
ALTER ROLE neon_superuser BYPASSRLS;

View File

@@ -15,7 +15,7 @@ DO $$
DECLARE
role_name text;
BEGIN
FOR role_name IN SELECT rolname FROM pg_roles WHERE pg_has_role(rolname, '{privileged_role_name}', 'member')
FOR role_name IN SELECT rolname FROM pg_roles WHERE pg_has_role(rolname, 'neon_superuser', 'member')
LOOP
RAISE NOTICE 'EXECUTING ALTER ROLE % INHERIT', quote_ident(role_name);
EXECUTE 'ALTER ROLE ' || quote_ident(role_name) || ' INHERIT';
@@ -23,7 +23,7 @@ BEGIN
FOR role_name IN SELECT rolname FROM pg_roles
WHERE
NOT pg_has_role(rolname, '{privileged_role_name}', 'member') AND NOT starts_with(rolname, 'pg_')
NOT pg_has_role(rolname, 'neon_superuser', 'member') AND NOT starts_with(rolname, 'pg_')
LOOP
RAISE NOTICE 'EXECUTING ALTER ROLE % NOBYPASSRLS', quote_ident(role_name);
EXECUTE 'ALTER ROLE ' || quote_ident(role_name) || ' NOBYPASSRLS';

View File

@@ -1,6 +1,6 @@
DO $$
BEGIN
IF (SELECT setting::numeric >= 160000 FROM pg_settings WHERE name = 'server_version_num') THEN
EXECUTE 'GRANT pg_create_subscription TO {privileged_role_name}';
EXECUTE 'GRANT pg_create_subscription TO neon_superuser';
END IF;
END $$;

View File

@@ -0,0 +1 @@
GRANT pg_monitor TO neon_superuser WITH ADMIN OPTION;

View File

@@ -1 +0,0 @@
GRANT pg_monitor TO {privileged_role_name} WITH ADMIN OPTION;

View File

@@ -1,4 +1,4 @@
-- SKIP: Deemed insufficient for allowing relations created by extensions to be
-- interacted with by {privileged_role_name} without permission issues.
-- interacted with by neon_superuser without permission issues.
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO {privileged_role_name};
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO neon_superuser;

View File

@@ -1,4 +1,4 @@
-- SKIP: Deemed insufficient for allowing relations created by extensions to be
-- interacted with by {privileged_role_name} without permission issues.
-- interacted with by neon_superuser without permission issues.
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO {privileged_role_name};
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO neon_superuser;

View File

@@ -1,3 +1,3 @@
-- SKIP: Moved inline to the handle_grants() functions.
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO {privileged_role_name} WITH GRANT OPTION;
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO neon_superuser WITH GRANT OPTION;

View File

@@ -1,3 +1,3 @@
-- SKIP: Moved inline to the handle_grants() functions.
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO {privileged_role_name} WITH GRANT OPTION;
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO neon_superuser WITH GRANT OPTION;

View File

@@ -1,7 +1,7 @@
DO $$
BEGIN
IF (SELECT setting::numeric >= 160000 FROM pg_settings WHERE name = 'server_version_num') THEN
EXECUTE 'GRANT EXECUTE ON FUNCTION pg_export_snapshot TO {privileged_role_name}';
EXECUTE 'GRANT EXECUTE ON FUNCTION pg_log_standby_snapshot TO {privileged_role_name}';
EXECUTE 'GRANT EXECUTE ON FUNCTION pg_export_snapshot TO neon_superuser';
EXECUTE 'GRANT EXECUTE ON FUNCTION pg_log_standby_snapshot TO neon_superuser';
END IF;
END $$;

View File

@@ -0,0 +1 @@
GRANT EXECUTE ON FUNCTION pg_show_replication_origin_status TO neon_superuser;

View File

@@ -1 +0,0 @@
GRANT EXECUTE ON FUNCTION pg_show_replication_origin_status TO {privileged_role_name};

View File

@@ -0,0 +1 @@
GRANT pg_signal_backend TO neon_superuser WITH ADMIN OPTION;

View File

@@ -1 +0,0 @@
GRANT pg_signal_backend TO {privileged_role_name} WITH ADMIN OPTION;

View File

@@ -9,7 +9,6 @@ use reqwest::StatusCode;
use tokio_postgres::Client;
use tracing::{error, info, instrument};
use crate::compute::ComputeNodeParams;
use crate::config;
use crate::metrics::{CPLANE_REQUESTS_TOTAL, CPlaneRequestRPC, UNKNOWN_HTTP_STATUS};
use crate::migration::MigrationRunner;
@@ -170,7 +169,7 @@ pub async fn handle_neon_extension_upgrade(client: &mut Client) -> Result<()> {
}
#[instrument(skip_all)]
pub async fn handle_migrations(params: ComputeNodeParams, client: &mut Client) -> Result<()> {
pub async fn handle_migrations(client: &mut Client) -> Result<()> {
info!("handle migrations");
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -179,59 +178,26 @@ pub async fn handle_migrations(params: ComputeNodeParams, client: &mut Client) -
// Add new migrations in numerical order.
let migrations = [
&format!(
include_str!("./migrations/0001-add_bypass_rls_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
include_str!("./migrations/0001-neon_superuser_bypass_rls.sql"),
include_str!("./migrations/0002-alter_roles.sql"),
include_str!("./migrations/0003-grant_pg_create_subscription_to_neon_superuser.sql"),
include_str!("./migrations/0004-grant_pg_monitor_to_neon_superuser.sql"),
include_str!("./migrations/0005-grant_all_on_tables_to_neon_superuser.sql"),
include_str!("./migrations/0006-grant_all_on_sequences_to_neon_superuser.sql"),
include_str!(
"./migrations/0007-grant_all_on_tables_to_neon_superuser_with_grant_option.sql"
),
&format!(
include_str!("./migrations/0002-alter_roles.sql"),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!("./migrations/0003-grant_pg_create_subscription_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!("./migrations/0004-grant_pg_monitor_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!("./migrations/0005-grant_all_on_tables_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!("./migrations/0006-grant_all_on_sequences_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!(
"./migrations/0007-grant_all_on_tables_with_grant_option_to_privileged_role.sql"
),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!(
"./migrations/0008-grant_all_on_sequences_with_grant_option_to_privileged_role.sql"
),
privileged_role_name = params.privileged_role_name
include_str!(
"./migrations/0008-grant_all_on_sequences_to_neon_superuser_with_grant_option.sql"
),
include_str!("./migrations/0009-revoke_replication_for_previously_allowed_roles.sql"),
&format!(
include_str!(
"./migrations/0010-grant_snapshot_synchronization_funcs_to_privileged_role.sql"
),
privileged_role_name = params.privileged_role_name
include_str!(
"./migrations/0010-grant_snapshot_synchronization_funcs_to_neon_superuser.sql"
),
&format!(
include_str!(
"./migrations/0011-grant_pg_show_replication_origin_status_to_privileged_role.sql"
),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!("./migrations/0012-grant_pg_signal_backend_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
include_str!(
"./migrations/0011-grant_pg_show_replication_origin_status_to_neon_superuser.sql"
),
include_str!("./migrations/0012-grant_pg_signal_backend_to_neon_superuser.sql"),
];
MigrationRunner::new(client, &migrations)

View File

@@ -13,14 +13,14 @@ use tokio_postgres::Client;
use tokio_postgres::error::SqlState;
use tracing::{Instrument, debug, error, info, info_span, instrument, warn};
use crate::compute::{ComputeNode, ComputeNodeParams, ComputeState};
use crate::compute::{ComputeNode, ComputeState};
use crate::pg_helpers::{
DatabaseExt, Escaping, GenericOptionsSearch, RoleExt, get_existing_dbs_async,
get_existing_roles_async,
};
use crate::spec_apply::ApplySpecPhase::{
CreateAndAlterDatabases, CreateAndAlterRoles, CreateAvailabilityCheck, CreatePgauditExtension,
CreatePgauditlogtofileExtension, CreatePrivilegedRole, CreateSchemaNeon,
CreateAndAlterDatabases, CreateAndAlterRoles, CreateAvailabilityCheck, CreateNeonSuperuser,
CreatePgauditExtension, CreatePgauditlogtofileExtension, CreateSchemaNeon,
DisablePostgresDBPgAudit, DropInvalidDatabases, DropRoles, FinalizeDropLogicalSubscriptions,
HandleNeonExtension, HandleOtherExtensions, RenameAndDeleteDatabases, RenameRoles,
RunInEachDatabase,
@@ -49,7 +49,6 @@ impl ComputeNode {
// Proceed with post-startup configuration. Note, that order of operations is important.
let client = Self::get_maintenance_client(&conf).await?;
let spec = spec.clone();
let params = Arc::new(self.params.clone());
let databases = get_existing_dbs_async(&client).await?;
let roles = get_existing_roles_async(&client)
@@ -158,7 +157,6 @@ impl ComputeNode {
let conf = Arc::new(conf);
let fut = Self::apply_spec_sql_db(
params.clone(),
spec.clone(),
conf,
ctx.clone(),
@@ -187,7 +185,7 @@ impl ComputeNode {
}
for phase in [
CreatePrivilegedRole,
CreateNeonSuperuser,
DropInvalidDatabases,
RenameRoles,
CreateAndAlterRoles,
@@ -197,7 +195,6 @@ impl ComputeNode {
] {
info!("Applying phase {:?}", &phase);
apply_operations(
params.clone(),
spec.clone(),
ctx.clone(),
jwks_roles.clone(),
@@ -246,7 +243,6 @@ impl ComputeNode {
}
let fut = Self::apply_spec_sql_db(
params.clone(),
spec.clone(),
conf,
ctx.clone(),
@@ -297,7 +293,6 @@ impl ComputeNode {
for phase in phases {
debug!("Applying phase {:?}", &phase);
apply_operations(
params.clone(),
spec.clone(),
ctx.clone(),
jwks_roles.clone(),
@@ -318,9 +313,7 @@ impl ComputeNode {
/// May opt to not connect to databases that don't have any scheduled
/// operations. The function is concurrency-controlled with the provided
/// semaphore. The caller has to make sure the semaphore isn't exhausted.
#[allow(clippy::too_many_arguments)] // TODO: needs bigger refactoring
async fn apply_spec_sql_db(
params: Arc<ComputeNodeParams>,
spec: Arc<ComputeSpec>,
conf: Arc<tokio_postgres::Config>,
ctx: Arc<tokio::sync::RwLock<MutableApplyContext>>,
@@ -335,7 +328,6 @@ impl ComputeNode {
for subphase in subphases {
apply_operations(
params.clone(),
spec.clone(),
ctx.clone(),
jwks_roles.clone(),
@@ -475,7 +467,7 @@ pub enum PerDatabasePhase {
#[derive(Clone, Debug)]
pub enum ApplySpecPhase {
CreatePrivilegedRole,
CreateNeonSuperuser,
DropInvalidDatabases,
RenameRoles,
CreateAndAlterRoles,
@@ -518,7 +510,6 @@ pub struct MutableApplyContext {
/// - No timeouts have (yet) been implemented.
/// - The caller is responsible for limiting and/or applying concurrency.
pub async fn apply_operations<'a, Fut, F>(
params: Arc<ComputeNodeParams>,
spec: Arc<ComputeSpec>,
ctx: Arc<RwLock<MutableApplyContext>>,
jwks_roles: Arc<HashSet<String>>,
@@ -536,7 +527,7 @@ where
debug!("Processing phase {:?}", &apply_spec_phase);
let ctx = ctx;
let mut ops = get_operations(&params, &spec, &ctx, &jwks_roles, &apply_spec_phase)
let mut ops = get_operations(&spec, &ctx, &jwks_roles, &apply_spec_phase)
.await?
.peekable();
@@ -597,18 +588,14 @@ where
/// sort/merge/batch execution, but for now this is a nice way to improve
/// batching behavior of the commands.
async fn get_operations<'a>(
params: &'a ComputeNodeParams,
spec: &'a ComputeSpec,
ctx: &'a RwLock<MutableApplyContext>,
jwks_roles: &'a HashSet<String>,
apply_spec_phase: &'a ApplySpecPhase,
) -> Result<Box<dyn Iterator<Item = Operation> + 'a + Send>> {
match apply_spec_phase {
ApplySpecPhase::CreatePrivilegedRole => Ok(Box::new(once(Operation {
query: format!(
include_str!("sql/create_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
),
ApplySpecPhase::CreateNeonSuperuser => Ok(Box::new(once(Operation {
query: include_str!("sql/create_neon_superuser.sql").to_string(),
comment: None,
}))),
ApplySpecPhase::DropInvalidDatabases => {
@@ -710,9 +697,8 @@ async fn get_operations<'a>(
None => {
let query = if !jwks_roles.contains(role.name.as_str()) {
format!(
"CREATE ROLE {} INHERIT CREATEROLE CREATEDB BYPASSRLS REPLICATION IN ROLE {} {}",
"CREATE ROLE {} INHERIT CREATEROLE CREATEDB BYPASSRLS REPLICATION IN ROLE neon_superuser {}",
role.name.pg_quote(),
params.privileged_role_name,
role.to_pg_options(),
)
} else {
@@ -863,9 +849,8 @@ async fn get_operations<'a>(
// ALL PRIVILEGES grants CREATE, CONNECT, and TEMPORARY on the database
// (see https://www.postgresql.org/docs/current/ddl-priv.html)
query: format!(
"GRANT ALL PRIVILEGES ON DATABASE {} TO {}",
db.name.pg_quote(),
params.privileged_role_name
"GRANT ALL PRIVILEGES ON DATABASE {} TO neon_superuser",
db.name.pg_quote()
),
comment: None,
},

View File

@@ -0,0 +1,8 @@
DO $$
BEGIN
    -- Create the neon_superuser role only when it does not exist yet, so the
    -- statement can be re-run safely.
    IF NOT EXISTS (
        SELECT 1
        FROM pg_catalog.pg_roles
        WHERE rolname = 'neon_superuser'
    ) THEN
        CREATE ROLE neon_superuser
            CREATEDB CREATEROLE NOLOGIN REPLICATION BYPASSRLS
            IN ROLE pg_read_all_data, pg_write_all_data;
    END IF;
END
$$;

View File

@@ -1,8 +0,0 @@
DO $$
BEGIN
    -- Create the privileged role only when it does not exist yet, so the
    -- statement can be re-run safely. The role name is substituted by the caller.
    IF NOT EXISTS (
        SELECT 1
        FROM pg_catalog.pg_roles
        WHERE rolname = '{privileged_role_name}'
    ) THEN
        CREATE ROLE {privileged_role_name}
            CREATEDB CREATEROLE NOLOGIN REPLICATION BYPASSRLS
            IN ROLE pg_read_all_data, pg_write_all_data;
    END IF;
END
$$;

View File

@@ -631,10 +631,6 @@ struct EndpointCreateCmdArgs {
help = "Allow multiple primary endpoints running on the same branch. Shouldn't be used normally, but useful for tests."
)]
allow_multiple: bool,
/// Only allow changing it on creation
#[clap(long, help = "Name of the privileged role for the endpoint")]
privileged_role_name: Option<String>,
}
#[derive(clap::Args)]
@@ -1484,7 +1480,6 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
args.grpc,
!args.update_catalog,
false,
args.privileged_role_name.clone(),
)?;
}
EndpointCmd::Start(args) => {

View File

@@ -99,7 +99,6 @@ pub struct EndpointConf {
features: Vec<ComputeFeature>,
cluster: Option<Cluster>,
compute_ctl_config: ComputeCtlConfig,
privileged_role_name: Option<String>,
}
//
@@ -200,7 +199,6 @@ impl ComputeControlPlane {
grpc: bool,
skip_pg_catalog_updates: bool,
drop_subscriptions_before_start: bool,
privileged_role_name: Option<String>,
) -> Result<Arc<Endpoint>> {
let pg_port = pg_port.unwrap_or_else(|| self.get_port());
let external_http_port = external_http_port.unwrap_or_else(|| self.get_port() + 1);
@@ -238,7 +236,6 @@ impl ComputeControlPlane {
features: vec![],
cluster: None,
compute_ctl_config: compute_ctl_config.clone(),
privileged_role_name: privileged_role_name.clone(),
});
ep.create_endpoint_dir()?;
@@ -260,7 +257,6 @@ impl ComputeControlPlane {
features: vec![],
cluster: None,
compute_ctl_config,
privileged_role_name,
})?,
)?;
std::fs::write(
@@ -336,9 +332,6 @@ pub struct Endpoint {
/// The compute_ctl config for the endpoint's compute.
compute_ctl_config: ComputeCtlConfig,
/// The name of the privileged role for the endpoint.
privileged_role_name: Option<String>,
}
#[derive(PartialEq, Eq)]
@@ -439,7 +432,6 @@ impl Endpoint {
features: conf.features,
cluster: conf.cluster,
compute_ctl_config: conf.compute_ctl_config,
privileged_role_name: conf.privileged_role_name,
})
}
@@ -878,10 +870,6 @@ impl Endpoint {
cmd.arg("--dev");
}
if let Some(privileged_role_name) = self.privileged_role_name.clone() {
cmd.args(["--privileged-role-name", &privileged_role_name]);
}
let child = cmd.spawn()?;
// set up a scopeguard to kill & wait for the child in case we panic or bail below
let child = scopeguard::guard(child, |mut child| {

View File

@@ -1 +1,418 @@
pub mod shmem;
//! Shared memory utilities for neon communicator
use std::num::NonZeroUsize;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use nix::errno::Errno;
use nix::sys::mman::MapFlags;
use nix::sys::mman::ProtFlags;
use nix::sys::mman::mmap as nix_mmap;
use nix::sys::mman::munmap as nix_munmap;
use nix::unistd::ftruncate as nix_ftruncate;
/// ShmemHandle represents a shared memory area that can be shared by processes over fork().
/// Unlike shared memory allocated by Postgres, this area is resizable, up to 'max_size' that's
/// specified at creation.
///
/// The area is backed by an anonymous file created with memfd_create(). The full address space for
/// 'max_size' is reserved up-front with mmap(), but whenever you call [`ShmemHandle::set_size`],
/// the underlying file is resized. Do not access the area beyond the current size. Currently, that
/// will cause the file to be expanded, but we might use mprotect() etc. to enforce that in the
/// future.
///
/// NOTE(review): touching a shared file mapping past end-of-file normally raises SIGBUS instead of
/// growing the file; the "will cause the file to be expanded" claim above should be confirmed.
pub struct ShmemHandle {
/// File descriptor of the backing file (memfd, or a temporary file on macOS)
fd: OwnedFd,
// Total size of the mapping in bytes. This includes the SharedStruct header, so it is larger
// than the 'max_size' the caller passed at creation.
max_size: usize,
// Pointer to the beginning of the shared memory area. The header is stored there.
shared_ptr: NonNull<SharedStruct>,
// Pointer to the beginning of the user data, i.e. HEADER_SIZE bytes past 'shared_ptr'.
pub data_ptr: NonNull<u8>,
}
/// This is stored at the beginning in the shared memory area, so every process that maps the area
/// sees the same header.
struct SharedStruct {
/// Total size of the mapping in bytes, header included. set_size() asserts that it matches the
/// handle's own copy.
max_size: usize,
/// Current size of the backing file, header included. The high-order bit is used for the
/// RESIZE_IN_PROGRESS flag
current_size: AtomicUsize,
}
// Flag bit reserved in the top bit of 'SharedStruct::current_size'. current_size() masks it out
// before reporting the size.
// NOTE(review): '1 << 63' assumes a 64-bit usize; confirm no 32-bit target builds this crate.
const RESIZE_IN_PROGRESS: usize = 1 << 63;
// Size in bytes of the SharedStruct header that precedes the user data.
const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();
/// Error type returned by the ShmemHandle functions.
///
/// It is displayed as the message, a colon, and then the errno.
#[derive(thiserror::Error, Debug)]
#[error("{msg}: {errno}")]
pub struct Error {
/// Description of the operation that failed.
pub msg: String,
/// The errno reported for the failure. set_size() uses Errno::UnknownErrno when it detects a
/// concurrent resize, because no system call failed in that case.
pub errno: Errno,
}
impl Error {
fn new(msg: &str, errno: Errno) -> Error {
Error {
msg: msg.to_string(),
errno,
}
}
}
impl ShmemHandle {
/// Create a new shared memory area. To communicate between processes, the processes need to be
/// fork()'d after calling this, so that the ShmemHandle is inherited by all processes.
///
/// If the ShmemHandle is dropped, the memory is unmapped from the current process. Other
/// processes can continue using it, however.
pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<ShmemHandle, Error> {
// create the backing anonymous file.
let fd = create_backing_file(name)?;
Self::new_with_fd(fd, initial_size, max_size)
}
fn new_with_fd(
fd: OwnedFd,
initial_size: usize,
max_size: usize,
) -> Result<ShmemHandle, Error> {
// We reserve the high-order bit for the RESIZE_IN_PROGRESS flag, and the actual size
// is a little larger than this because of the SharedStruct header. Make the upper limit
// somewhat smaller than that, because with anything close to that, you'll run out of
// memory anyway.
if max_size >= 1 << 48 {
panic!("max size {max_size} too large");
}
if initial_size > max_size {
panic!("initial size {initial_size} larger than max size {max_size}");
}
// The actual initial / max size is the one given by the caller, plus the size of
// 'SharedStruct'.
let initial_size = HEADER_SIZE + initial_size;
let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();
// Reserve address space for it with mmap
//
// TODO: Use MAP_HUGETLB if possible
let start_ptr = unsafe {
nix_mmap(
None,
max_size,
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
MapFlags::MAP_SHARED,
&fd,
0,
)
}
.map_err(|e| Error::new("mmap failed: {e}", e))?;
// Reserve space for the initial size
enlarge_file(fd.as_fd(), initial_size as u64)?;
// Initialize the header
let shared: NonNull<SharedStruct> = start_ptr.cast();
unsafe {
shared.write(SharedStruct {
max_size: max_size.into(),
current_size: AtomicUsize::new(initial_size),
})
};
// The user data begins after the header
let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };
Ok(ShmemHandle {
fd,
max_size: max_size.into(),
shared_ptr: shared,
data_ptr,
})
}
// return reference to the header
fn shared(&self) -> &SharedStruct {
unsafe { self.shared_ptr.as_ref() }
}
/// Resize the shared memory area. 'new_size' must not be larger than the 'max_size' specified
/// when creating the area.
///
/// This may only be called from one process/thread concurrently. We detect that case
/// and return an Error.
pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
let new_size = new_size + HEADER_SIZE;
let shared = self.shared();
if new_size > self.max_size {
panic!(
"new size ({} is greater than max size ({})",
new_size, self.max_size
);
}
assert_eq!(self.max_size, shared.max_size);
// Lock the area by setting the bit in 'current_size'
//
// Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
// and the posix_fallocate/ftruncate call is surely a synchronization point anyway. But
// since this is not performance-critical, better safe than sorry .
let mut old_size = shared.current_size.load(Ordering::Acquire);
loop {
if (old_size & RESIZE_IN_PROGRESS) != 0 {
return Err(Error::new(
"concurrent resize detected",
Errno::UnknownErrno,
));
}
match shared.current_size.compare_exchange(
old_size,
new_size,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(x) => old_size = x,
}
}
// Ok, we got the lock.
//
// NB: If anything goes wrong, we *must* clear the bit!
let result = {
use std::cmp::Ordering::{Equal, Greater, Less};
match new_size.cmp(&old_size) {
Less => nix_ftruncate(&self.fd, new_size as i64).map_err(|e| {
Error::new("could not shrink shmem segment, ftruncate failed: {e}", e)
}),
Equal => Ok(()),
Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
}
};
// Unlock
shared.current_size.store(
if result.is_ok() { new_size } else { old_size },
Ordering::Release,
);
result
}
/// Returns the current user-visible size of the shared memory segment.
///
/// NOTE: a concurrent set_size() call can change the size at any time. It is the caller's
/// responsibility not to access the area beyond the current size.
pub fn current_size(&self) -> usize {
let total_current_size =
self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
total_current_size - HEADER_SIZE
}
}
impl Drop for ShmemHandle {
    /// Unmap the whole region from this process. A failure of munmap() is ignored.
    fn drop(&mut self) {
        let region_start = self.shared_ptr.cast();
        let region_len = self.max_size;
        // SAFETY: The pointer was obtained from mmap() with the given size.
        // We unmap the entire region.
        let _unmap_result = unsafe { nix_munmap(region_start, region_len) };
        // The fd is dropped automatically by OwnedFd.
    }
}
/// Create a "backing file" for the shared memory area. On Linux, use memfd_create(), to create an
/// anonymous in-memory file. One macos, fall back to a regular file. That's good enough for
/// development and testing, but in production we want the file to stay in memory.
///
/// disable 'unused_variables' warnings, because in the macos path, 'name' is unused.
#[allow(unused_variables)]
fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
#[cfg(not(target_os = "macos"))]
{
nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
.map_err(|e| Error::new("memfd_create failed: {e}", e))
}
#[cfg(target_os = "macos")]
{
let file = tempfile::tempfile().map_err(|e| {
Error::new(
"could not create temporary file to back shmem area: {e}",
nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
)
})?;
Ok(OwnedFd::from(file))
}
}
fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
// Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
// we don't get a segfault later when trying to actually use it.
#[cfg(not(target_os = "macos"))]
{
nix::fcntl::posix_fallocate(fd, 0, size as i64).map_err(|e| {
Error::new(
"could not grow shmem segment, posix_fallocate failed: {e}",
e,
)
})
}
// As a fallback on macos, which doesn't have posix_fallocate, use plain 'fallocate'
#[cfg(target_os = "macos")]
{
nix::unistd::ftruncate(fd, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, ftruncate failed: {e}", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
use nix::unistd::ForkResult;
use std::ops::Range;
/// Check that all bytes in the given range (offsets relative to 'ptr') have the expected value.
/// Panics with the offending offset on the first mismatch.
fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
for i in range {
let b = unsafe { *(ptr.add(i)) };
assert_eq!(expected, b, "unexpected byte at offset {i}");
}
}
/// Write 'b' to all bytes in the given range (offsets relative to 'ptr')
fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
}
// simple single-process test of growing and shrinking
#[test]
fn test_shmem_resize() -> Result<(), Error> {
let max_size = 1024 * 1024;
let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;
assert_eq!(init_struct.current_size(), 0);
// Initial grow
let size1 = 10000;
init_struct.set_size(size1).unwrap();
assert_eq!(init_struct.current_size(), size1);
// Write some data
let data_ptr = init_struct.data_ptr.as_ptr();
write_range(data_ptr, 0xAA, 0..size1);
assert_range(data_ptr, 0xAA, 0..size1);
// Shrink
let size2 = 5000;
init_struct.set_size(size2).unwrap();
assert_eq!(init_struct.current_size(), size2);
// Grow again
let size3 = 20000;
init_struct.set_size(size3).unwrap();
assert_eq!(init_struct.current_size(), size3);
// Try to read it. The area that was shrunk and grown again should read as all zeros now
assert_range(data_ptr, 0xAA, 0..5000);
assert_range(data_ptr, 0, 5000..size1);
// Try to grow beyond max_size
// Disabled: set_size() panics instead of returning Err when the new size exceeds max_size,
// so the check below cannot pass as written.
//let size4 = max_size + 1;
//assert!(init_struct.set_size(size4).is_err());
// Dropping init_struct should unmap the memory
drop(init_struct);
Ok(())
}
/// This is used in tests to coordinate between test processes. It's like std::sync::Barrier,
/// but is stored in the shared memory area and works across processes. It's implemented by
/// polling, because e.g. standard rust mutexes are not guaranteed to work across processes.
struct SimpleBarrier {
// Number of participants that must arrive before a wait() round completes
num_procs: usize,
// Total number of wait() calls made so far, across all rounds
count: AtomicUsize,
}
impl SimpleBarrier {
// Initialize the barrier in place at 'ptr' for 'num_procs' participants.
// NOTE(review): plain assignment through a raw pointer treats the destination as an existing
// value; std::ptr::write is the usual way to initialize fresh memory. Confirm this is intended.
unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
unsafe {
*ptr = SimpleBarrier {
num_procs,
count: AtomicUsize::new(0),
}
}
}
// Block until 'num_procs' participants have called wait() in the current round.
pub fn wait(&self) {
// Register our arrival. The previous count tells us which round ("generation") we are in.
let old = self.count.fetch_add(1, Ordering::Relaxed);
let generation = old / self.num_procs;
let mut current = old + 1;
// Poll until the count reaches the end of our generation.
while current < (generation + 1) * self.num_procs {
std::thread::sleep(std::time::Duration::from_millis(10));
current = self.count.load(Ordering::Relaxed);
}
}
}
// Two-process test: a parent and a forked child write to and resize the same area.
// NOTE(review): the child keeps running this function after fork() and returns to the caller;
// the parent only unwraps the waitpid() result and does not inspect the child's exit status, so
// a failed assertion in the child may go unnoticed. Confirm and consider checking the status.
#[test]
fn test_multi_process() {
// Initialize
let max_size = 1_000_000_000_000;
let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
let ptr = init_struct.data_ptr.as_ptr();
// Grow the area to 10000 bytes. The SimpleBarrier is stored at the start of it, below
// offset 1000 where the test writes begin.
init_struct.set_size(10000).unwrap();
let barrier_ptr: *mut SimpleBarrier = unsafe {
ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
.cast()
};
unsafe { SimpleBarrier::init(barrier_ptr, 2) };
let barrier = unsafe { barrier_ptr.as_ref().unwrap() };
// Fork another test process. The code after this runs in both processes concurrently.
let fork_result = unsafe { nix::unistd::fork().unwrap() };
// In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
if fork_result.is_parent() {
write_range(ptr, 0xAA, 1000..2000);
} else {
write_range(ptr, 0xBB, 2000..3000);
}
barrier.wait();
// Verify the contents. (in both processes)
assert_range(ptr, 0xAA, 1000..2000);
assert_range(ptr, 0xBB, 2000..3000);
// Grow, from the child this time
let size = 10_000_000;
if !fork_result.is_parent() {
init_struct.set_size(size).unwrap();
}
barrier.wait();
// make some writes at the end
if fork_result.is_parent() {
write_range(ptr, 0xAA, (size - 10)..size);
} else {
write_range(ptr, 0xBB, (size - 20)..(size - 10));
}
barrier.wait();
// Verify the contents. (This runs in both processes)
assert_range(ptr, 0, (size - 1000)..(size - 20));
assert_range(ptr, 0xBB, (size - 20)..(size - 10));
assert_range(ptr, 0xAA, (size - 10)..size);
// Only the parent reaps the child.
if let ForkResult::Parent { child } = fork_result {
nix::sys::wait::waitpid(child, None).unwrap();
}
}
}

View File

@@ -1,409 +0,0 @@
//! Dynamically resizable contiguous chunk of shared memory
use std::num::NonZeroUsize;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use nix::errno::Errno;
use nix::sys::mman::MapFlags;
use nix::sys::mman::ProtFlags;
use nix::sys::mman::mmap as nix_mmap;
use nix::sys::mman::munmap as nix_munmap;
use nix::unistd::ftruncate as nix_ftruncate;
/// `ShmemHandle` represents a shared memory area that can be shared by processes over `fork()`.
/// Unlike shared memory allocated by Postgres, this area is resizable, up to `max_size` that's
/// specified at creation.
///
/// The area is backed by an anonymous file created with `memfd_create()`. The full address space for
/// `max_size` is reserved up-front with `mmap()`, but whenever you call [`ShmemHandle::set_size`],
/// the underlying file is resized. Do not access the area beyond the current size. Currently, that
/// will cause the file to be expanded, but we might use `mprotect()` etc. to enforce that in the
/// future.
///
/// NOTE(review): touching a shared file mapping past end-of-file normally raises SIGBUS instead of
/// growing the file; the "will cause the file to be expanded" claim above should be confirmed.
pub struct ShmemHandle {
/// File descriptor of the backing file (memfd, or a temporary file on macOS)
fd: OwnedFd,
// Total size of the mapping in bytes. This includes the `SharedStruct` header, so it is larger
// than the `max_size` the caller passed at creation.
max_size: usize,
// Pointer to the beginning of the shared memory area. The header is stored there.
shared_ptr: NonNull<SharedStruct>,
// Pointer to the beginning of the user data, i.e. `HEADER_SIZE` bytes past `shared_ptr`.
pub data_ptr: NonNull<u8>,
}
/// This is stored at the beginning in the shared memory area, so every process that maps the area
/// sees the same header.
struct SharedStruct {
/// Total size of the mapping in bytes, header included. `set_size()` asserts that it matches the
/// handle's own copy.
max_size: usize,
/// Current size of the backing file, header included. The high-order bit is used for the
/// [`RESIZE_IN_PROGRESS`] flag.
current_size: AtomicUsize,
}
// Flag bit reserved in the top bit of `SharedStruct::current_size`. `current_size()` masks it out
// before reporting the size.
// NOTE(review): `1 << 63` assumes a 64-bit usize; confirm no 32-bit target builds this crate.
const RESIZE_IN_PROGRESS: usize = 1 << 63;
// Size in bytes of the `SharedStruct` header that precedes the user data.
const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();
/// Error type returned by the [`ShmemHandle`] functions.
///
/// It is displayed as the message, a colon, and then the errno.
#[derive(thiserror::Error, Debug)]
#[error("{msg}: {errno}")]
pub struct Error {
/// Description of the operation that failed.
pub msg: String,
/// The errno reported for the failure. `set_size()` uses `Errno::UnknownErrno` when it detects a
/// concurrent resize, because no system call failed in that case.
pub errno: Errno,
}
impl Error {
    /// Pair a description of the failed operation with the errno it produced.
    fn new(msg: &str, errno: Errno) -> Self {
        let description = msg.to_owned();
        Self {
            msg: description,
            errno,
        }
    }
}
impl ShmemHandle {
/// Create a new shared memory area. To communicate between processes, the processes need to be
/// `fork()`'d after calling this, so that the `ShmemHandle` is inherited by all processes.
///
/// If the `ShmemHandle` is dropped, the memory is unmapped from the current process. Other
/// processes can continue using it, however.
pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<Self, Error> {
// create the backing anonymous file.
let fd = create_backing_file(name)?;
Self::new_with_fd(fd, initial_size, max_size)
}
fn new_with_fd(fd: OwnedFd, initial_size: usize, max_size: usize) -> Result<Self, Error> {
// We reserve the high-order bit for the `RESIZE_IN_PROGRESS` flag, and the actual size
// is a little larger than this because of the SharedStruct header. Make the upper limit
// somewhat smaller than that, because with anything close to that, you'll run out of
// memory anyway.
assert!(max_size < 1 << 48, "max size {max_size} too large");
assert!(
initial_size <= max_size,
"initial size {initial_size} larger than max size {max_size}"
);
// The actual initial / max size is the one given by the caller, plus the size of
// 'SharedStruct'.
let initial_size = HEADER_SIZE + initial_size;
let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();
// Reserve address space for it with mmap
//
// TODO: Use MAP_HUGETLB if possible
let start_ptr = unsafe {
nix_mmap(
None,
max_size,
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
MapFlags::MAP_SHARED,
&fd,
0,
)
}
.map_err(|e| Error::new("mmap failed", e))?;
// Reserve space for the initial size
enlarge_file(fd.as_fd(), initial_size as u64)?;
// Initialize the header
let shared: NonNull<SharedStruct> = start_ptr.cast();
unsafe {
shared.write(SharedStruct {
max_size: max_size.into(),
current_size: AtomicUsize::new(initial_size),
});
}
// The user data begins after the header
let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };
Ok(Self {
fd,
max_size: max_size.into(),
shared_ptr: shared,
data_ptr,
})
}
// return reference to the header
fn shared(&self) -> &SharedStruct {
unsafe { self.shared_ptr.as_ref() }
}
/// Resize the shared memory area. `new_size` must not be larger than the `max_size` specified
/// when creating the area.
///
/// This may only be called from one process/thread concurrently. We detect that case
/// and return an [`shmem::Error`](Error).
pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
let new_size = new_size + HEADER_SIZE;
let shared = self.shared();
assert!(
new_size <= self.max_size,
"new size ({new_size}) is greater than max size ({})",
self.max_size
);
assert_eq!(self.max_size, shared.max_size);
// Lock the area by setting the bit in `current_size`
//
// Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
// and the `posix_fallocate`/`ftruncate` call is surely a synchronization point anyway. But
// since this is not performance-critical, better safe than sorry.
let mut old_size = shared.current_size.load(Ordering::Acquire);
loop {
if (old_size & RESIZE_IN_PROGRESS) != 0 {
return Err(Error::new(
"concurrent resize detected",
Errno::UnknownErrno,
));
}
match shared.current_size.compare_exchange(
old_size,
new_size,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(x) => old_size = x,
}
}
// Ok, we got the lock.
//
// NB: If anything goes wrong, we *must* clear the bit!
let result = {
use std::cmp::Ordering::{Equal, Greater, Less};
match new_size.cmp(&old_size) {
Less => nix_ftruncate(&self.fd, new_size as i64)
.map_err(|e| Error::new("could not shrink shmem segment, ftruncate failed", e)),
Equal => Ok(()),
Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
}
};
// Unlock
shared.current_size.store(
if result.is_ok() { new_size } else { old_size },
Ordering::Release,
);
result
}
/// Returns the current user-visible size of the shared memory segment.
///
/// NOTE: a concurrent [`ShmemHandle::set_size()`] call can change the size at any time.
/// It is the caller's responsibility not to access the area beyond the current size.
pub fn current_size(&self) -> usize {
let total_current_size =
self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
total_current_size - HEADER_SIZE
}
}
impl Drop for ShmemHandle {
    /// Unmap the whole region from this process. A failure of `munmap()` is ignored.
    fn drop(&mut self) {
        let mapping_base = self.shared_ptr.cast();
        let mapping_len = self.max_size;
        // SAFETY: The pointer was obtained from mmap() with the given size.
        // We unmap the entire region.
        let _ignored = unsafe { nix_munmap(mapping_base, mapping_len) };
        // The fd is dropped automatically by OwnedFd.
    }
}
/// Create a "backing file" for the shared memory area. On Linux, use `memfd_create()`, to create an
/// anonymous in-memory file. One macos, fall back to a regular file. That's good enough for
/// development and testing, but in production we want the file to stay in memory.
///
/// Disable unused variables warnings because `name` is unused in the macos path.
#[allow(unused_variables)]
fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
#[cfg(not(target_os = "macos"))]
{
nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
.map_err(|e| Error::new("memfd_create failed", e))
}
#[cfg(target_os = "macos")]
{
let file = tempfile::tempfile().map_err(|e| {
Error::new(
"could not create temporary file to back shmem area",
nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
)
})?;
Ok(OwnedFd::from(file))
}
}
fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
// Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
// we don't get a segfault later when trying to actually use it.
#[cfg(not(target_os = "macos"))]
{
nix::fcntl::posix_fallocate(fd, 0, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, posix_fallocate failed", e))
}
// As a fallback on macos, which doesn't have posix_fallocate, use plain 'fallocate'
#[cfg(target_os = "macos")]
{
nix::unistd::ftruncate(fd, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, ftruncate failed", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
use nix::unistd::ForkResult;
use std::ops::Range;
/// Check that all bytes in the given range (offsets relative to `ptr`) have the expected value.
/// Panics with the offending offset on the first mismatch.
fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
for i in range {
let b = unsafe { *(ptr.add(i)) };
assert_eq!(expected, b, "unexpected byte at offset {i}");
}
}
/// Write `b` to all bytes in the given range (offsets relative to `ptr`)
fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
}
// simple single-process test of growing and shrinking
#[test]
fn test_shmem_resize() -> Result<(), Error> {
let max_size = 1024 * 1024;
let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;
assert_eq!(init_struct.current_size(), 0);
// Initial grow
let size1 = 10000;
init_struct.set_size(size1).unwrap();
assert_eq!(init_struct.current_size(), size1);
// Write some data
let data_ptr = init_struct.data_ptr.as_ptr();
write_range(data_ptr, 0xAA, 0..size1);
assert_range(data_ptr, 0xAA, 0..size1);
// Shrink
let size2 = 5000;
init_struct.set_size(size2).unwrap();
assert_eq!(init_struct.current_size(), size2);
// Grow again
let size3 = 20000;
init_struct.set_size(size3).unwrap();
assert_eq!(init_struct.current_size(), size3);
// Try to read it. The area that was shrunk and grown again should read as all zeros now
assert_range(data_ptr, 0xAA, 0..5000);
assert_range(data_ptr, 0, 5000..size1);
// Try to grow beyond max_size
// Disabled: set_size() panics instead of returning Err when the new size exceeds max_size,
// so the check below cannot pass as written.
//let size4 = max_size + 1;
//assert!(init_struct.set_size(size4).is_err());
// Dropping init_struct should unmap the memory
drop(init_struct);
Ok(())
}
/// This is used in tests to coordinate between test processes. It's like `std::sync::Barrier`,
/// but is stored in the shared memory area and works across processes. It's implemented by
/// polling, because e.g. standard rust mutexes are not guaranteed to work across processes.
struct SimpleBarrier {
// Number of participants that must arrive before a wait() round completes
num_procs: usize,
// Total number of wait() calls made so far, across all rounds
count: AtomicUsize,
}
impl SimpleBarrier {
// Initialize the barrier in place at `ptr` for `num_procs` participants.
// NOTE(review): plain assignment through a raw pointer treats the destination as an existing
// value; std::ptr::write is the usual way to initialize fresh memory. Confirm this is intended.
unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
unsafe {
*ptr = SimpleBarrier {
num_procs,
count: AtomicUsize::new(0),
}
}
}
// Block until `num_procs` participants have called wait() in the current round.
pub fn wait(&self) {
// Register our arrival. The previous count tells us which round ("generation") we are in.
let old = self.count.fetch_add(1, Ordering::Relaxed);
let generation = old / self.num_procs;
let mut current = old + 1;
// Poll until the count reaches the end of our generation.
while current < (generation + 1) * self.num_procs {
std::thread::sleep(std::time::Duration::from_millis(10));
current = self.count.load(Ordering::Relaxed);
}
}
}
// Two-process test: a parent and a forked child write to and resize the same area.
// NOTE(review): the child keeps running this function after fork() and returns to the caller;
// the parent only unwraps the waitpid() result and does not inspect the child's exit status, so
// a failed assertion in the child may go unnoticed. Confirm and consider checking the status.
#[test]
fn test_multi_process() {
// Initialize
let max_size = 1_000_000_000_000;
let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
let ptr = init_struct.data_ptr.as_ptr();
// Grow the area to 10000 bytes. The SimpleBarrier is stored at the start of it, below
// offset 1000 where the test writes begin.
init_struct.set_size(10000).unwrap();
let barrier_ptr: *mut SimpleBarrier = unsafe {
ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
.cast()
};
unsafe { SimpleBarrier::init(barrier_ptr, 2) };
let barrier = unsafe { barrier_ptr.as_ref().unwrap() };
// Fork another test process. The code after this runs in both processes concurrently.
let fork_result = unsafe { nix::unistd::fork().unwrap() };
// In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
if fork_result.is_parent() {
write_range(ptr, 0xAA, 1000..2000);
} else {
write_range(ptr, 0xBB, 2000..3000);
}
barrier.wait();
// Verify the contents. (in both processes)
assert_range(ptr, 0xAA, 1000..2000);
assert_range(ptr, 0xBB, 2000..3000);
// Grow, from the child this time
let size = 10_000_000;
if !fork_result.is_parent() {
init_struct.set_size(size).unwrap();
}
barrier.wait();
// make some writes at the end
if fork_result.is_parent() {
write_range(ptr, 0xAA, (size - 10)..size);
} else {
write_range(ptr, 0xBB, (size - 20)..(size - 10));
}
barrier.wait();
// Verify the contents. (This runs in both processes)
assert_range(ptr, 0, (size - 1000)..(size - 20));
assert_range(ptr, 0xBB, (size - 20)..(size - 10));
assert_range(ptr, 0xAA, (size - 10)..size);
// Only the parent reaps the child.
if let ForkResult::Parent { child } = fork_result {
nix::sys::wait::waitpid(child, None).unwrap();
}
}
}

View File

@@ -110,6 +110,7 @@ fn main() -> anyhow::Result<()> {
.allowlist_type("XLogRecPtr")
.allowlist_type("XLogSegNo")
.allowlist_type("TimeLineID")
.allowlist_type("TimestampTz")
.allowlist_type("MultiXactId")
.allowlist_type("MultiXactOffset")
.allowlist_type("MultiXactStatus")

View File

@@ -227,7 +227,8 @@ pub mod walrecord;
// Export some widely used datatypes that are unlikely to change across Postgres versions
pub use v14::bindings::{
BlockNumber, CheckPoint, ControlFileData, MultiXactId, OffsetNumber, Oid, PageHeaderData,
RepOriginId, TimeLineID, TransactionId, XLogRecPtr, XLogRecord, XLogSegNo, uint32, uint64,
RepOriginId, TimeLineID, TimestampTz, TransactionId, XLogRecPtr, XLogRecord, XLogSegNo, uint32,
uint64,
};
// Likewise for these, although the assumption that these don't change is a little more iffy.
pub use v14::bindings::{MultiXactOffset, MultiXactStatus};

View File

@@ -4,14 +4,13 @@
//! TODO: Generate separate types for each supported PG version
use bytes::{Buf, Bytes};
use postgres_ffi_types::TimestampTz;
use serde::{Deserialize, Serialize};
use utils::bin_ser::DeserializeError;
use utils::lsn::Lsn;
use crate::{
BLCKSZ, BlockNumber, MultiXactId, MultiXactOffset, MultiXactStatus, Oid, PgMajorVersion,
RepOriginId, TransactionId, XLOG_SIZE_OF_XLOG_RECORD, XLogRecord, pg_constants,
RepOriginId, TimestampTz, TransactionId, XLOG_SIZE_OF_XLOG_RECORD, XLogRecord, pg_constants,
};
#[repr(C)]
@@ -864,8 +863,7 @@ pub mod v17 {
XlHeapDelete, XlHeapInsert, XlHeapLock, XlHeapMultiInsert, XlHeapUpdate, XlParameterChange,
rm_neon,
};
pub use crate::TimeLineID;
pub use postgres_ffi_types::TimestampTz;
pub use crate::{TimeLineID, TimestampTz};
#[repr(C)]
#[derive(Debug)]

View File

@@ -9,11 +9,10 @@
use super::super::waldecoder::WalStreamDecoder;
use super::bindings::{
CheckPoint, ControlFileData, DBState_DB_SHUTDOWNED, FullTransactionId, TimeLineID,
CheckPoint, ControlFileData, DBState_DB_SHUTDOWNED, FullTransactionId, TimeLineID, TimestampTz,
XLogLongPageHeaderData, XLogPageHeaderData, XLogRecPtr, XLogRecord, XLogSegNo, XLOG_PAGE_MAGIC,
MY_PGVERSION
};
use postgres_ffi_types::TimestampTz;
use super::wal_generator::LogicalMessageGenerator;
use crate::pg_constants;
use crate::PG_TLI;

View File

@@ -11,4 +11,3 @@ pub mod forknum;
pub type Oid = u32;
pub type RepOriginId = u16;
pub type TimestampTz = i64;

View File

@@ -9,7 +9,7 @@ anyhow.workspace = true
const_format.workspace = true
serde.workspace = true
serde_json.workspace = true
postgres_ffi_types.workspace = true
postgres_ffi.workspace = true
postgres_versioninfo.workspace = true
pq_proto.workspace = true
tokio.workspace = true

View File

@@ -3,7 +3,7 @@
use std::net::SocketAddr;
use pageserver_api::shard::ShardIdentity;
use postgres_ffi_types::TimestampTz;
use postgres_ffi::TimestampTz;
use postgres_versioninfo::PgVersionId;
use serde::{Deserialize, Serialize};
use tokio::time::Instant;

View File

@@ -2,8 +2,7 @@
use bytes::Bytes;
use postgres_ffi::walrecord::{MultiXactMember, describe_postgres_wal_record};
use postgres_ffi::{MultiXactId, MultiXactOffset, TransactionId};
use postgres_ffi_types::TimestampTz;
use postgres_ffi::{MultiXactId, MultiXactOffset, TimestampTz, TransactionId};
use serde::{Deserialize, Serialize};
use utils::bin_ser::DeserializeError;

View File

@@ -431,7 +431,7 @@ pub fn empty_shmem() -> crate::bindings::WalproposerShmemState {
let empty_wal_rate_limiter = crate::bindings::WalRateLimiter {
should_limit: crate::bindings::pg_atomic_uint32 { value: 0 },
sent_bytes: 0,
last_recorded_time_us: crate::bindings::pg_atomic_uint64 { value: 0 },
last_recorded_time_us: 0,
};
crate::bindings::WalproposerShmemState {

View File

@@ -2357,7 +2357,6 @@ async fn timeline_compact_handler(
flags,
sub_compaction,
sub_compaction_max_job_size_mb,
gc_compaction_do_metadata_compaction: false,
};
let scheduled = compact_request

View File

@@ -25,9 +25,9 @@ use pageserver_api::keyspace::{KeySpaceRandomAccum, SparseKeySpace};
use pageserver_api::models::RelSizeMigration;
use pageserver_api::reltag::{BlockNumber, RelTag, SlruKind};
use pageserver_api::shard::ShardIdentity;
use postgres_ffi::{BLCKSZ, PgMajorVersion, TransactionId};
use postgres_ffi::{BLCKSZ, PgMajorVersion, TimestampTz, TransactionId};
use postgres_ffi_types::forknum::{FSM_FORKNUM, VISIBILITYMAP_FORKNUM};
use postgres_ffi_types::{Oid, RepOriginId, TimestampTz};
use postgres_ffi_types::{Oid, RepOriginId};
use serde::{Deserialize, Serialize};
use strum::IntoEnumIterator;
use tokio_util::sync::CancellationToken;

View File

@@ -1908,20 +1908,16 @@ impl TenantShard {
.map_err(LoadLocalTimelineError::ResumeDeletion)?;
}
// Upload the tenant manifest.
//
// This is uploaded unconditionally on every attach. This prevents races where a stale,
// still-alive tenant may modify a past manifest, and a future tenant loads it after this
// tenant has acted on it. Uploading a new manifest effectively hands over ownership of the
// manifest state. See: <https://databricks.atlassian.net/browse/LKB-165>.
// Stash the preloaded tenant manifest, and upload a new manifest if changed.
//
// NB: this must happen after the tenant is fully populated above. In particular the
// offloaded timelines, which are included in the manifest.
assert!(
self.remote_tenant_manifest.lock().await.is_none(),
"tenant manifest set before attach"
);
self.maybe_upload_tenant_manifest().await?; // always uploads, remote_tenant_manifest is None
{
let mut guard = self.remote_tenant_manifest.lock().await;
assert!(guard.is_none(), "tenant manifest set before preload"); // first populated here
*guard = preload.tenant_manifest;
}
self.maybe_upload_tenant_manifest().await?;
// The local filesystem contents are a cache of what's in the remote IndexPart;
// IndexPart is the source of truth.
@@ -9220,11 +9216,7 @@ mod tests {
let cancel = CancellationToken::new();
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
@@ -9307,11 +9299,7 @@ mod tests {
guard.cutoffs.space = Lsn(0x40);
}
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
@@ -9848,11 +9836,7 @@ mod tests {
let cancel = CancellationToken::new();
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
@@ -9887,11 +9871,7 @@ mod tests {
guard.cutoffs.space = Lsn(0x40);
}
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
@@ -10466,7 +10446,7 @@ mod tests {
&cancel,
CompactOptions {
flags: dryrun_flags,
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -10477,22 +10457,14 @@ mod tests {
verify_result().await;
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
verify_result().await;
// compact again
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
verify_result().await;
@@ -10511,22 +10483,14 @@ mod tests {
guard.cutoffs.space = Lsn(0x38);
}
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
verify_result().await; // no wals between 0x30 and 0x38, so we should obtain the same result
// not increasing the GC horizon and compact again
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
verify_result().await;
@@ -10731,7 +10695,7 @@ mod tests {
&cancel,
CompactOptions {
flags: dryrun_flags,
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -10742,22 +10706,14 @@ mod tests {
verify_result().await;
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
verify_result().await;
// compact again
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
verify_result().await;
@@ -10957,11 +10913,7 @@ mod tests {
let cancel = CancellationToken::new();
branch_tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
@@ -10974,7 +10926,7 @@ mod tests {
&cancel,
CompactOptions {
compact_lsn_range: Some(CompactLsnRange::above(Lsn(0x40))),
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -11642,7 +11594,7 @@ mod tests {
CompactOptions {
flags: EnumSet::new(),
compact_key_range: Some((get_key(0)..get_key(2)).into()),
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -11689,7 +11641,7 @@ mod tests {
CompactOptions {
flags: EnumSet::new(),
compact_key_range: Some((get_key(2)..get_key(4)).into()),
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -11741,7 +11693,7 @@ mod tests {
CompactOptions {
flags: EnumSet::new(),
compact_key_range: Some((get_key(4)..get_key(9)).into()),
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -11792,7 +11744,7 @@ mod tests {
CompactOptions {
flags: EnumSet::new(),
compact_key_range: Some((get_key(9)..get_key(10)).into()),
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -11848,7 +11800,7 @@ mod tests {
CompactOptions {
flags: EnumSet::new(),
compact_key_range: Some((get_key(0)..get_key(10)).into()),
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -12119,7 +12071,7 @@ mod tests {
&cancel,
CompactOptions {
compact_lsn_range: Some(CompactLsnRange::above(Lsn(0x28))),
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -12154,11 +12106,7 @@ mod tests {
// compact again
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
verify_result().await;
@@ -12377,7 +12325,7 @@ mod tests {
CompactOptions {
compact_key_range: Some((get_key(0)..get_key(2)).into()),
compact_lsn_range: Some((Lsn(0x20)..Lsn(0x28)).into()),
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -12423,7 +12371,7 @@ mod tests {
CompactOptions {
compact_key_range: Some((get_key(3)..get_key(8)).into()),
compact_lsn_range: Some((Lsn(0x28)..Lsn(0x40)).into()),
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -12471,7 +12419,7 @@ mod tests {
CompactOptions {
compact_key_range: Some((get_key(0)..get_key(5)).into()),
compact_lsn_range: Some((Lsn(0x20)..Lsn(0x50)).into()),
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)
@@ -12506,11 +12454,7 @@ mod tests {
// final full compaction
tline
.compact_with_gc(
&cancel,
CompactOptions::default_for_gc_compaction_unit_tests(),
&ctx,
)
.compact_with_gc(&cancel, CompactOptions::default(), &ctx)
.await
.unwrap();
verify_result().await;
@@ -12620,7 +12564,7 @@ mod tests {
CompactOptions {
compact_key_range: None,
compact_lsn_range: None,
..CompactOptions::default_for_gc_compaction_unit_tests()
..Default::default()
},
&ctx,
)

View File

@@ -939,20 +939,6 @@ pub(crate) struct CompactOptions {
/// Set job size for the GC compaction.
/// This option is only used by GC compaction.
pub sub_compaction_max_job_size_mb: Option<u64>,
/// Only for GC compaction.
/// If set, the compaction will compact the metadata layers. Should be only set to true in unit tests
/// because metadata compaction is not fully supported yet.
pub gc_compaction_do_metadata_compaction: bool,
}
impl CompactOptions {
#[cfg(test)]
pub fn default_for_gc_compaction_unit_tests() -> Self {
Self {
gc_compaction_do_metadata_compaction: true,
..Default::default()
}
}
}
impl std::fmt::Debug for Timeline {
@@ -2199,7 +2185,6 @@ impl Timeline {
compact_lsn_range: None,
sub_compaction: false,
sub_compaction_max_job_size_mb: None,
gc_compaction_do_metadata_compaction: false,
},
ctx,
)

View File

@@ -396,7 +396,6 @@ impl GcCompactionQueue {
}),
compact_lsn_range: None,
sub_compaction_max_job_size_mb: None,
gc_compaction_do_metadata_compaction: false,
},
permit,
);
@@ -513,7 +512,6 @@ impl GcCompactionQueue {
compact_key_range: Some(job.compact_key_range.into()),
compact_lsn_range: Some(job.compact_lsn_range.into()),
sub_compaction_max_job_size_mb: None,
gc_compaction_do_metadata_compaction: false,
};
pending_tasks.push(GcCompactionQueueItem::SubCompactionJob {
options,
@@ -787,8 +785,6 @@ pub(crate) struct GcCompactJob {
/// as specified here. The true range being compacted is `min_lsn/max_lsn` in [`GcCompactionJobDescription`].
/// min_lsn will always <= the lower bound specified here, and max_lsn will always >= the upper bound specified here.
pub compact_lsn_range: Range<Lsn>,
/// See [`CompactOptions::gc_compaction_do_metadata_compaction`].
pub do_metadata_compaction: bool,
}
impl GcCompactJob {
@@ -803,7 +799,6 @@ impl GcCompactJob {
.compact_lsn_range
.map(|x| x.into())
.unwrap_or(Lsn::INVALID..Lsn::MAX),
do_metadata_compaction: options.gc_compaction_do_metadata_compaction,
}
}
}
@@ -3179,7 +3174,6 @@ impl Timeline {
dry_run: job.dry_run,
compact_key_range: start..end,
compact_lsn_range: job.compact_lsn_range.start..compact_below_lsn,
do_metadata_compaction: false,
});
current_start = Some(end);
}
@@ -3242,7 +3236,7 @@ impl Timeline {
async fn compact_with_gc_inner(
self: &Arc<Self>,
cancel: &CancellationToken,
mut job: GcCompactJob,
job: GcCompactJob,
ctx: &RequestContext,
yield_for_l0: bool,
) -> Result<CompactionOutcome, CompactionError> {
@@ -3250,28 +3244,6 @@ impl Timeline {
// with legacy compaction tasks in the future. Always ensure the lock order is compaction -> gc.
// Note that we already acquired the compaction lock when the outer `compact` function gets called.
// If the job is not configured to compact the metadata key range, shrink the key range
// to exclude the metadata key range. The check is done by checking if the end of the key range
// is larger than the start of the metadata key range. Note that metadata keys cover the entire
// second half of the keyspace, so it's enough to only check the end of the key range.
if !job.do_metadata_compaction
&& job.compact_key_range.end > Key::metadata_key_range().start
{
tracing::info!(
"compaction for metadata key range is not supported yet, overriding compact_key_range from {} to {}",
job.compact_key_range.end,
Key::metadata_key_range().start
);
// Shrink the key range to exclude the metadata key range.
job.compact_key_range.end = Key::metadata_key_range().start;
// Skip the job if the key range completely lies within the metadata key range.
if job.compact_key_range.start >= job.compact_key_range.end {
tracing::info!("compact_key_range is empty, skipping compaction");
return Ok(CompactionOutcome::Done);
}
}
let timer = Instant::now();
let begin_timer = timer;

View File

@@ -184,7 +184,7 @@ pub(super) async fn connection_manager_loop_step(
// If we've not received any updates from the broker from a while, are waiting for WAL
// and have no safekeeper connection or connection candidates, then it might be that
// the broker subscription is wedged. Drop the current subscription and re-subscribe
// the broker subscription is wedged. Drop the current subscription and re-subscribe
// with the goal of unblocking it.
_ = broker_reset_interval.tick() => {
let awaiting_lsn = wait_lsn_status.borrow().is_some();
@@ -192,7 +192,7 @@ pub(super) async fn connection_manager_loop_step(
let no_connection = connection_manager_state.wal_connection.is_none();
if awaiting_lsn && no_candidates && no_connection {
tracing::info!("No broker updates received for a while, but waiting for WAL. Re-setting stream ...");
tracing::warn!("No broker updates received for a while, but waiting for WAL. Re-setting stream ...");
broker_subscription = subscribe_for_timeline_updates(broker_client, id, cancel).await?;
}
},

View File

@@ -32,10 +32,9 @@ use pageserver_api::reltag::{BlockNumber, RelTag, SlruKind};
use pageserver_api::shard::ShardIdentity;
use postgres_ffi::walrecord::*;
use postgres_ffi::{
PgMajorVersion, TransactionId, dispatch_pgversion, enum_pgversion, enum_pgversion_dispatch,
fsm_logical_to_physical, pg_constants,
PgMajorVersion, TimestampTz, TransactionId, dispatch_pgversion, enum_pgversion,
enum_pgversion_dispatch, fsm_logical_to_physical, pg_constants,
};
use postgres_ffi_types::TimestampTz;
use postgres_ffi_types::forknum::{FSM_FORKNUM, INIT_FORKNUM, MAIN_FORKNUM, VISIBILITYMAP_FORKNUM};
use tracing::*;
use utils::bin_ser::{DeserializeError, SerializeError};

View File

@@ -543,15 +543,6 @@ _PG_init(void)
PGC_POSTMASTER,
0,
NULL, NULL, NULL);
DefineCustomStringVariable(
"neon.privileged_role_name",
"Name of the 'weak' superuser role, which we give to the users",
NULL,
&privileged_role_name,
"neon_superuser",
PGC_POSTMASTER, 0, NULL, NULL, NULL);
/*
* Important: This must happen after other parts of the extension are
* loaded, otherwise any settings to GUCs that were set before the

View File

@@ -16,6 +16,7 @@
extern char *neon_auth_token;
extern char *neon_timeline;
extern char *neon_tenant;
extern char *wal_acceptors_list;
extern int wal_acceptor_reconnect_timeout;
extern int wal_acceptor_connection_timeout;

View File

@@ -13,7 +13,7 @@
* accumulate changes. On subtransaction commit, the top of the stack
* is merged with the table below it.
*
* Support event triggers for {privileged_role_name}
* Support event triggers for neon_superuser
*
* IDENTIFICATION
* contrib/neon/neon_dll_handler.c
@@ -49,7 +49,6 @@
#include "neon_ddl_handler.h"
#include "neon_utils.h"
#include "neon.h"
static ProcessUtility_hook_type PreviousProcessUtilityHook = NULL;
static fmgr_hook_type next_fmgr_hook = NULL;
@@ -542,11 +541,11 @@ NeonXactCallback(XactEvent event, void *arg)
}
static bool
IsPrivilegedRole(const char *role_name)
RoleIsNeonSuperuser(const char *role_name)
{
Assert(role_name);
return strcmp(role_name, privileged_role_name) == 0;
return strcmp(role_name, "neon_superuser") == 0;
}
static void
@@ -579,9 +578,8 @@ HandleCreateDb(CreatedbStmt *stmt)
{
const char *owner_name = defGetString(downer);
if (IsPrivilegedRole(owner_name))
elog(ERROR, "could not create a database with owner %s", privileged_role_name);
if (RoleIsNeonSuperuser(owner_name))
elog(ERROR, "can't create a database with owner neon_superuser");
entry->owner = get_role_oid(owner_name, false);
}
else
@@ -611,9 +609,8 @@ HandleAlterOwner(AlterOwnerStmt *stmt)
memset(entry->old_name, 0, sizeof(entry->old_name));
new_owner = get_rolespec_name(stmt->newowner);
if (IsPrivilegedRole(new_owner))
elog(ERROR, "could not alter owner to %s", privileged_role_name);
if (RoleIsNeonSuperuser(new_owner))
elog(ERROR, "can't alter owner to neon_superuser");
entry->owner = get_role_oid(new_owner, false);
entry->type = Op_Set;
}
@@ -719,8 +716,8 @@ HandleAlterRole(AlterRoleStmt *stmt)
InitRoleTableIfNeeded();
role_name = get_rolespec_name(stmt->role);
if (IsPrivilegedRole(role_name) && !superuser())
elog(ERROR, "could not ALTER %s", privileged_role_name);
if (RoleIsNeonSuperuser(role_name) && !superuser())
elog(ERROR, "can't ALTER neon_superuser");
dpass = NULL;
foreach(option, stmt->options)
@@ -834,7 +831,7 @@ HandleRename(RenameStmt *stmt)
*
* In vanilla only superuser can create Event Triggers.
*
* We allow it for {privileged_role_name} by temporary switching to superuser. But as
* We allow it for neon_superuser by temporary switching to superuser. But as
* far as event trigger can fire in superuser context we should protect
* superuser from execution of arbitrary user's code.
*
@@ -894,7 +891,7 @@ force_noop(FmgrInfo *finfo)
* Also skip executing Event Triggers when GUC neon.event_triggers has been
* set to false. This might be necessary to be able to connect again after a
* LOGIN Event Trigger has been installed that would prevent connections as
* {privileged_role_name}.
* neon_superuser.
*/
static void
neon_fmgr_hook(FmgrHookEventType event, FmgrInfo *flinfo, Datum *private)
@@ -913,24 +910,24 @@ neon_fmgr_hook(FmgrHookEventType event, FmgrInfo *flinfo, Datum *private)
}
/*
* The {privileged_role_name} role can use the GUC neon.event_triggers to disable
* The neon_superuser role can use the GUC neon.event_triggers to disable
* firing Event Trigger.
*
* SET neon.event_triggers TO false;
*
* This only applies to the {privileged_role_name} role though, and only allows
* skipping Event Triggers owned by {privileged_role_name}, which we check by
* proxy of the Event Trigger function being owned by {privileged_role_name}.
* This only applies to the neon_superuser role though, and only allows
* skipping Event Triggers owned by neon_superuser, which we check by
* proxy of the Event Trigger function being owned by neon_superuser.
*
* A role that is created in role {privileged_role_name} should be allowed to also
* A role that is created in role neon_superuser should be allowed to also
* benefit from the neon_event_triggers GUC, and will be considered the
* same as the {privileged_role_name} role.
* same as the neon_superuser role.
*/
if (event == FHET_START
&& !neon_event_triggers
&& is_privileged_role())
&& is_neon_superuser())
{
Oid weak_superuser_oid = get_role_oid(privileged_role_name, false);
Oid neon_superuser_oid = get_role_oid("neon_superuser", false);
/* Find the Function Attributes (owner Oid, security definer) */
const char *fun_owner_name = NULL;
@@ -940,8 +937,8 @@ neon_fmgr_hook(FmgrHookEventType event, FmgrInfo *flinfo, Datum *private)
LookupFuncOwnerSecDef(flinfo->fn_oid, &fun_owner, &fun_is_secdef);
fun_owner_name = GetUserNameFromId(fun_owner, false);
if (IsPrivilegedRole(fun_owner_name)
|| has_privs_of_role(fun_owner, weak_superuser_oid))
if (RoleIsNeonSuperuser(fun_owner_name)
|| has_privs_of_role(fun_owner, neon_superuser_oid))
{
elog(WARNING,
"Skipping Event Trigger: neon.event_triggers is false");
@@ -1152,13 +1149,13 @@ ProcessCreateEventTrigger(
}
/*
* Allow {privileged_role_name} to create Event Trigger, while keeping the
* Allow neon_superuser to create Event Trigger, while keeping the
* ownership of the object.
*
* For that we give superuser membership to the role for the execution of
* the command.
*/
if (IsTransactionState() && is_privileged_role())
if (IsTransactionState() && is_neon_superuser())
{
/* Find the Event Trigger function Oid */
Oid func_oid = LookupFuncName(stmt->funcname, 0, NULL, false);
@@ -1235,7 +1232,7 @@ ProcessCreateEventTrigger(
*
* That way [ ALTER | DROP ] EVENT TRIGGER commands just work.
*/
if (IsTransactionState() && is_privileged_role())
if (IsTransactionState() && is_neon_superuser())
{
if (!current_user_is_super)
{
@@ -1355,17 +1352,19 @@ NeonProcessUtility(
}
/*
* Only {privileged_role_name} is granted privilege to edit neon.event_triggers GUC.
* Only neon_superuser is granted privilege to edit neon.event_triggers GUC.
*/
static void
neon_event_triggers_assign_hook(bool newval, void *extra)
{
if (IsTransactionState() && !is_privileged_role())
/* MyDatabaseId == InvalidOid || !OidIsValid(GetUserId()) */
if (IsTransactionState() && !is_neon_superuser())
{
ereport(ERROR,
(errcode(ERRCODE_INSUFFICIENT_PRIVILEGE),
errmsg("permission denied to set neon.event_triggers"),
errdetail("Only \"%s\" is allowed to set the GUC", privileged_role_name)));
errdetail("Only \"neon_superuser\" is allowed to set the GUC")));
}
}

View File

@@ -377,16 +377,6 @@ typedef struct PageserverFeedback
} PageserverFeedback;
/* BEGIN_HADRON */
/**
* WAL proposer is the only backend that will update `sent_bytes` and `last_recorded_time_us`.
* Once the `sent_bytes` reaches the limit, it puts backpressure on PG backends.
*
* A PG backend checks `should_limit` to see if it should hit backpressure.
* - If yes, it also checks the `last_recorded_time_us` to see
* if it's time to push more WALs. This is because the WAL proposer
* only resets `should_limit` to 0 after it is notified about new WALs
* which might take a while.
*/
typedef struct WalRateLimiter
{
/* If the value is 1, PG backends will hit backpressure. */
@@ -394,7 +384,7 @@ typedef struct WalRateLimiter
/* The number of bytes sent in the current second. */
uint64 sent_bytes;
/* The last recorded time in microsecond. */
pg_atomic_uint64 last_recorded_time_us;
TimestampTz last_recorded_time_us;
} WalRateLimiter;
/* END_HADRON */

View File

@@ -449,20 +449,8 @@ backpressure_lag_impl(void)
}
state = GetWalpropShmemState();
if (state != NULL && !!pg_atomic_read_u32(&state->wal_rate_limiter.should_limit))
if (state != NULL && pg_atomic_read_u32(&state->wal_rate_limiter.should_limit) == 1)
{
TimestampTz now = GetCurrentTimestamp();
struct WalRateLimiter *limiter = &state->wal_rate_limiter;
uint64 last_recorded_time = pg_atomic_read_u64(&limiter->last_recorded_time_us);
if (now - last_recorded_time > USECS_PER_SEC)
{
/*
* More than 1 second has passed since the last recorded time, so it's time to push more WALs.
* If the backends are pushing WALs too fast, the wal proposer will rate limit them again.
*/
uint32 expected = true;
pg_atomic_compare_exchange_u32(&state->wal_rate_limiter.should_limit, &expected, false);
}
return 1;
}
/* END_HADRON */
@@ -514,7 +502,6 @@ WalproposerShmemInit(void)
pg_atomic_init_u64(&walprop_shared->currentClusterSize, 0);
/* BEGIN_HADRON */
pg_atomic_init_u32(&walprop_shared->wal_rate_limiter.should_limit, 0);
pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.last_recorded_time_us, 0);
/* END_HADRON */
}
LWLockRelease(AddinShmemInitLock);
@@ -533,7 +520,6 @@ WalproposerShmemInit_SyncSafekeeper(void)
pg_atomic_init_u64(&walprop_shared->backpressureThrottlingTime, 0);
/* BEGIN_HADRON */
pg_atomic_init_u32(&walprop_shared->wal_rate_limiter.should_limit, 0);
pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.last_recorded_time_us, 0);
/* END_HADRON */
}
@@ -1565,18 +1551,18 @@ XLogBroadcastWalProposer(WalProposer *wp)
{
uint64 max_wal_bytes = (uint64) databricks_max_wal_mb_per_second * 1024 * 1024;
struct WalRateLimiter *limiter = &state->wal_rate_limiter;
uint64 last_recorded_time = pg_atomic_read_u64(&limiter->last_recorded_time_us);
if (now - last_recorded_time > USECS_PER_SEC)
if (now - limiter->last_recorded_time_us > USECS_PER_SEC)
{
/* Reset the rate limiter */
limiter->last_recorded_time_us = now;
limiter->sent_bytes = 0;
pg_atomic_write_u64(&limiter->last_recorded_time_us, now);
pg_atomic_write_u32(&limiter->should_limit, false);
pg_atomic_exchange_u32(&limiter->should_limit, 0);
}
limiter->sent_bytes += (endptr - startptr);
if (limiter->sent_bytes > max_wal_bytes)
{
pg_atomic_write_u32(&limiter->should_limit, true);
pg_atomic_exchange_u32(&limiter->should_limit, 1);
}
}
/* END_HADRON */

View File

@@ -63,16 +63,19 @@ impl<P: QueueProcessing> BatchQueue<P> {
}
}
pub fn enqueue(&self, req: P::Req) -> (u64, oneshot::Receiver<ProcResult<P>>) {
self.inner.lock_propagate_poison().register_job(req)
}
/// Perform a single request-response process, this may be batched internally.
///
/// This function is not cancel safe.
pub async fn call<R>(
&self,
req: P::Req,
id: u64,
mut rx: oneshot::Receiver<ProcResult<P>>,
cancelled: impl Future<Output = R>,
) -> Result<P::Res, BatchQueueError<P::Err, R>> {
let (id, mut rx) = self.inner.lock_propagate_poison().register_job(req);
let mut cancelled = pin!(cancelled);
let resp: Option<Result<P::Res, P::Err>> = loop {
// try become the leader, or try wait for success.

View File

@@ -1,11 +1,13 @@
use std::collections::{HashMap, HashSet, hash_map};
use std::convert::Infallible;
use std::sync::atomic::AtomicU64;
use std::time::Duration;
use async_trait::async_trait;
use clashmap::ClashMap;
use clashmap::mapref::one::Ref;
use rand::{Rng, thread_rng};
use tokio::sync::Mutex;
use tokio::time::Instant;
use tracing::{debug, info};
@@ -20,23 +22,31 @@ pub(crate) trait ProjectInfoCache {
fn invalidate_endpoint_access_for_project(&self, project_id: ProjectIdInt);
fn invalidate_endpoint_access_for_org(&self, account_id: AccountIdInt);
fn invalidate_role_secret_for_project(&self, project_id: ProjectIdInt, role_name: RoleNameInt);
async fn decrement_active_listeners(&self);
async fn increment_active_listeners(&self);
}
struct Entry<T> {
expires_at: Instant,
created_at: Instant,
value: T,
}
impl<T> Entry<T> {
pub(crate) fn new(value: T, ttl: Duration) -> Self {
pub(crate) fn new(value: T) -> Self {
Self {
expires_at: Instant::now() + ttl,
created_at: Instant::now(),
value,
}
}
pub(crate) fn get(&self) -> Option<&T> {
(self.expires_at > Instant::now()).then_some(&self.value)
pub(crate) fn get(&self, valid_since: Instant) -> Option<&T> {
(valid_since < self.created_at).then_some(&self.value)
}
}
impl<T> From<T> for Entry<T> {
fn from(value: T) -> Self {
Self::new(value)
}
}
@@ -46,12 +56,18 @@ struct EndpointInfo {
}
impl EndpointInfo {
pub(crate) fn get_role_secret(&self, role_name: RoleNameInt) -> Option<RoleAccessControl> {
self.role_controls.get(&role_name)?.get().cloned()
pub(crate) fn get_role_secret(
&self,
role_name: RoleNameInt,
valid_since: Instant,
) -> Option<RoleAccessControl> {
let controls = self.role_controls.get(&role_name)?;
controls.get(valid_since).cloned()
}
pub(crate) fn get_controls(&self) -> Option<EndpointAccessControl> {
self.controls.as_ref()?.get().cloned()
pub(crate) fn get_controls(&self, valid_since: Instant) -> Option<EndpointAccessControl> {
let controls = self.controls.as_ref()?;
controls.get(valid_since).cloned()
}
pub(crate) fn invalidate_endpoint(&mut self) {
@@ -76,8 +92,11 @@ pub struct ProjectInfoCacheImpl {
project2ep: ClashMap<ProjectIdInt, HashSet<EndpointIdInt>>,
// FIXME(stefan): we need a way to GC the account2ep map.
account2ep: ClashMap<AccountIdInt, HashSet<EndpointIdInt>>,
config: ProjectInfoCacheOptions,
start_time: Instant,
ttl_disabled_since_us: AtomicU64,
active_listeners_lock: Mutex<usize>,
}
#[async_trait]
@@ -133,6 +152,29 @@ impl ProjectInfoCache for ProjectInfoCacheImpl {
}
}
}
async fn decrement_active_listeners(&self) {
let mut listeners_guard = self.active_listeners_lock.lock().await;
if *listeners_guard == 0 {
tracing::error!("active_listeners count is already 0, something is broken");
return;
}
*listeners_guard -= 1;
if *listeners_guard == 0 {
self.ttl_disabled_since_us
.store(u64::MAX, std::sync::atomic::Ordering::SeqCst);
}
}
async fn increment_active_listeners(&self) {
let mut listeners_guard = self.active_listeners_lock.lock().await;
*listeners_guard += 1;
if *listeners_guard == 1 {
let new_ttl = (self.start_time.elapsed() + self.config.ttl).as_micros() as u64;
self.ttl_disabled_since_us
.store(new_ttl, std::sync::atomic::Ordering::SeqCst);
}
}
}
impl ProjectInfoCacheImpl {
@@ -142,6 +184,9 @@ impl ProjectInfoCacheImpl {
project2ep: ClashMap::new(),
account2ep: ClashMap::new(),
config,
ttl_disabled_since_us: AtomicU64::new(u64::MAX),
start_time: Instant::now(),
active_listeners_lock: Mutex::new(0),
}
}
@@ -158,17 +203,19 @@ impl ProjectInfoCacheImpl {
endpoint_id: &EndpointId,
role_name: &RoleName,
) -> Option<RoleAccessControl> {
let valid_since = self.get_cache_times();
let role_name = RoleNameInt::get(role_name)?;
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
endpoint_info.get_role_secret(role_name)
endpoint_info.get_role_secret(role_name, valid_since)
}
pub(crate) fn get_endpoint_access(
&self,
endpoint_id: &EndpointId,
) -> Option<EndpointAccessControl> {
let valid_since = self.get_cache_times();
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
endpoint_info.get_controls()
endpoint_info.get_controls(valid_since)
}
pub(crate) fn insert_endpoint_access(
@@ -190,8 +237,8 @@ impl ProjectInfoCacheImpl {
return;
}
let controls = Entry::new(controls, self.config.ttl);
let role_controls = Entry::new(role_controls, self.config.ttl);
let controls = Entry::from(controls);
let role_controls = Entry::from(role_controls);
match self.cache.entry(endpoint_id) {
clashmap::Entry::Vacant(e) => {
@@ -228,6 +275,27 @@ impl ProjectInfoCacheImpl {
}
}
/// Returns the instant stored in `ttl_disabled_since_us`, if any.
///
/// The atomic holds microseconds relative to `start_time`; `u64::MAX` is the
/// sentinel for "nothing stored" and yields `None`.
fn ignore_ttl_since(&self) -> Option<Instant> {
    match self
        .ttl_disabled_since_us
        .load(std::sync::atomic::Ordering::Relaxed)
    {
        u64::MAX => None,
        micros => Some(self.start_time + Duration::from_micros(micros)),
    }
}
/// Computes the `valid_since` cutoff that the cache lookups pass down to the
/// individual entries.
///
/// The cutoff is `now - ttl`. If `ignore_ttl_since()` reports an instant, the
/// earlier of the two is returned.
fn get_cache_times(&self) -> Instant {
    // `Instant - Duration` panics when the result is not representable, e.g. when
    // the TTL is longer than the monotonic clock has been running. Fall back to
    // the cache's own start time instead of panicking.
    // NOTE(review): assumes no entry is created before `start_time` — confirm.
    let ttl_cutoff = Instant::now()
        .checked_sub(self.config.ttl)
        .unwrap_or(self.start_time);
    match self.ignore_ttl_since() {
        // We are fine if entry is not older than ttl or was added before we are getting notifications.
        Some(ignore_ttl_since) => ttl_cutoff.min(ignore_ttl_since),
        None => ttl_cutoff,
    }
}
pub fn maybe_invalidate_role_secret(&self, endpoint_id: &EndpointId, role_name: &RoleName) {
let Some(endpoint_id) = EndpointIdInt::get(endpoint_id) else {
return;
@@ -245,7 +313,16 @@ impl ProjectInfoCacheImpl {
return;
};
if role_controls.get().expires_at <= Instant::now() {
let created_at = role_controls.get().created_at;
let expire = match self.ignore_ttl_since() {
// if ignoring TTL, we should still try and roll the password if it's old
// and the client gave an incorrect password. There could be some lag on the redis channel.
Some(_) => created_at + self.config.ttl < Instant::now(),
// edge case: redis is down, let's be generous and invalidate the cache immediately.
None => true,
};
if expire {
role_controls.remove();
}
}

View File

@@ -6,6 +6,7 @@ use std::time::Duration;
use futures::FutureExt;
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use metrics::MeasuredCounterPairGuard;
use postgres_client::RawCancelToken;
use postgres_client::tls::MakeTlsConnect;
use redis::{Cmd, FromRedisValue, SetExpiry, SetOptions, Value};
@@ -23,7 +24,9 @@ use crate::context::RequestContext;
use crate::control_plane::ControlPlaneApi;
use crate::error::ReportableError;
use crate::ext::LockExt;
use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind};
use crate::metrics::{
CancelChannelSizeGauge, CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind,
};
use crate::pqproto::CancelKeyData;
use crate::rate_limiter::LeakyBucketRateLimiter;
use crate::redis::keys::KeyPrefix;
@@ -52,6 +55,28 @@ pub enum CancelKeyOp {
GetOld {
key: CancelKeyData,
},
Delete {
key: CancelKeyData,
},
}
impl CancelKeyOp {
    /// Maps each operation variant to the `RedisMsgKind` label used for metrics.
    fn redis_msg_kind(&self) -> RedisMsgKind {
        match self {
            Self::Store { .. } => RedisMsgKind::Set,
            Self::Refresh { .. } => RedisMsgKind::Expire,
            Self::Get { .. } => RedisMsgKind::Get,
            Self::GetOld { .. } => RedisMsgKind::HGet,
            Self::Delete { .. } => RedisMsgKind::Unlink,
        }
    }

    /// Creates a guard on the `cancel_channel_size` proxy metric, labelled with
    /// this operation's message kind.
    fn metric_guard(&self) -> MeasuredCounterPairGuard<'static, CancelChannelSizeGauge> {
        let channel_size = &Metrics::get().proxy.cancel_channel_size;
        let kind = self.redis_msg_kind();
        channel_size.guard(kind)
    }
}
#[derive(thiserror::Error, Debug, Clone)]
@@ -107,6 +132,10 @@ impl Pipeline {
self.inner.add_command(cmd);
self.replies += 1;
}
/// Queues `cmd` on the inner pipeline and marks it with `.ignore()`.
///
/// The `replies` counter is left untouched by this method.
fn add_command_ignore_reply(&mut self, cmd: Cmd) {
self.inner.add_command(cmd).ignore();
}
}
impl CancelKeyOp {
@@ -132,6 +161,10 @@ impl CancelKeyOp {
let key = KeyPrefix::Cancel(*key).build_redis_key();
pipe.add_command(Cmd::get(key));
}
CancelKeyOp::Delete { key } => {
let key = KeyPrefix::Cancel(*key).build_redis_key();
pipe.add_command_ignore_reply(Cmd::unlink(key));
}
}
}
}
@@ -268,14 +301,11 @@ impl CancellationHandler {
return Err(CancelError::InternalError);
};
let guard = Metrics::get()
.proxy
.cancel_channel_size
.guard(RedisMsgKind::Get);
let op = CancelKeyOp::Get { key };
let (id, rx) = tx.enqueue((op.metric_guard(), op));
let result = timeout(
TIMEOUT,
tx.call((guard, op), std::future::pending::<Infallible>()),
tx.call(id, rx, std::future::pending::<Infallible>()),
)
.await
.map_err(|_| {
@@ -293,14 +323,11 @@ impl CancellationHandler {
&& let Some(errcode) = err.code()
&& errcode == "WRONGTYPE"
{
let guard = Metrics::get()
.proxy
.cancel_channel_size
.guard(RedisMsgKind::HGet);
let op = CancelKeyOp::GetOld { key };
let (id, rx) = tx.enqueue((op.metric_guard(), op));
timeout(
TIMEOUT,
tx.call((guard, op), std::future::pending::<Infallible>()),
tx.call(id, rx, std::future::pending::<Infallible>()),
)
.await
.map_err(|_| {
@@ -482,51 +509,59 @@ impl Session {
let mut cancel = pin!(cancel);
#[derive(Copy, Clone, PartialEq, Eq)]
enum State {
Set,
Refresh,
Delete,
}
let mut state = State::Set;
loop {
let guard_op = match state {
let op = match state {
State::Set => {
let guard = Metrics::get()
.proxy
.cancel_channel_size
.guard(RedisMsgKind::Set);
let op = CancelKeyOp::Store {
key: self.key,
value: closure_json.clone(),
expire: CANCEL_KEY_TTL,
};
tracing::debug!(
src=%self.key,
dest=?cancel_closure.cancel_token,
"registering cancellation key"
);
(guard, op)
CancelKeyOp::Store {
key: self.key,
value: closure_json.clone(),
expire: CANCEL_KEY_TTL,
}
}
State::Refresh => {
let guard = Metrics::get()
.proxy
.cancel_channel_size
.guard(RedisMsgKind::Expire);
let op = CancelKeyOp::Refresh {
key: self.key,
expire: CANCEL_KEY_TTL,
};
tracing::debug!(
src=%self.key,
dest=?cancel_closure.cancel_token,
"refreshing cancellation key"
);
(guard, op)
CancelKeyOp::Refresh {
key: self.key,
expire: CANCEL_KEY_TTL,
}
}
State::Delete => {
tracing::debug!(
src=%self.key,
dest=?cancel_closure.cancel_token,
"deleting cancellation key"
);
CancelKeyOp::Delete { key: self.key }
}
};
match tx.call(guard_op, cancel.as_mut()).await {
let (id, rx) = tx.enqueue((op.metric_guard(), op));
if state == State::Delete {
// The key deletion is just best effort. We enqueue the command,
// but don't drive the queue and wait for a response.
break;
}
match tx.call(id, rx, cancel.as_mut()).await {
// SET returns OK
Ok(Value::Okay) => {
tracing::debug!(
@@ -561,7 +596,10 @@ impl Session {
continue;
}
Err(BatchQueueError::Cancelled(Err(_cancelled))) => break,
Err(BatchQueueError::Cancelled(Err(_cancelled))) => {
state = State::Delete;
continue;
}
}
// wait before continuing. break immediately if cancelled.
@@ -569,7 +607,7 @@ impl Session {
.await
.is_err()
{
break;
state = State::Delete;
}
}

View File

@@ -1,10 +1,12 @@
use std::cell::RefCell;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::{env, io};
use chrono::{DateTime, Utc};
use opentelemetry::trace::TraceContextExt;
use serde::ser::{SerializeMap, Serializer};
use tracing::subscriber::Interest;
use tracing::{Event, Metadata, Span, Subscriber, callsite, span};
use tracing_opentelemetry::OpenTelemetrySpanExt;
@@ -14,9 +16,7 @@ use tracing_subscriber::fmt::time::SystemTime;
use tracing_subscriber::fmt::{FormatEvent, FormatFields};
use tracing_subscriber::layer::{Context, Layer};
use tracing_subscriber::prelude::*;
use tracing_subscriber::registry::LookupSpan;
use crate::metrics::Metrics;
use tracing_subscriber::registry::{LookupSpan, SpanRef};
/// Initialize logging and OpenTelemetry tracing and exporter.
///
@@ -210,9 +210,6 @@ struct JsonLoggingLayer<C: Clock, W: MakeWriter> {
/// tracks which fields of each **event** are duplicates
skipped_field_indices: CallsiteMap<SkippedFieldIndices>,
/// tracks callsite names to an ID.
callsite_name_ids: papaya::HashMap<&'static str, u32, ahash::RandomState>,
span_info: CallsiteMap<CallsiteSpanInfo>,
/// Fields we want to keep track of in a separate json object.
@@ -225,7 +222,6 @@ impl<C: Clock, W: MakeWriter> JsonLoggingLayer<C, W> {
clock,
skipped_field_indices: CallsiteMap::default(),
span_info: CallsiteMap::default(),
callsite_name_ids: papaya::HashMap::default(),
writer,
extract_fields,
}
@@ -236,7 +232,7 @@ impl<C: Clock, W: MakeWriter> JsonLoggingLayer<C, W> {
self.span_info
.pin()
.get_or_insert_with(metadata.callsite(), || {
CallsiteSpanInfo::new(&self.callsite_name_ids, metadata, self.extract_fields)
CallsiteSpanInfo::new(metadata, self.extract_fields)
})
.clone()
}
@@ -253,7 +249,7 @@ where
// early, before OTel machinery, and add as event extension.
let now = self.clock.now();
EVENT_FORMATTER.with(|f| {
let res: io::Result<()> = EVENT_FORMATTER.with(|f| {
let mut borrow = f.try_borrow_mut();
let formatter = match borrow.as_deref_mut() {
Ok(formatter) => formatter,
@@ -263,19 +259,31 @@ where
Err(_) => &mut EventFormatter::new(),
};
formatter.reset();
formatter.format(
now,
event,
&ctx,
&self.skipped_field_indices,
self.extract_fields,
);
let mut writer = self.writer.make_writer();
if writer.write_all(formatter.buffer()).is_err() {
Metrics::get().proxy.logging_errors_count.inc();
}
)?;
self.writer.make_writer().write_all(formatter.buffer())
});
// In case logging fails we generate a simpler JSON object.
if let Err(err) = res
&& let Ok(mut line) = serde_json::to_vec(&serde_json::json!( {
"timestamp": now.to_rfc3339_opts(chrono::SecondsFormat::Micros, true),
"level": "ERROR",
"message": format_args!("cannot log event: {err:?}"),
"fields": {
"event": format_args!("{event:?}"),
},
}))
{
line.push(b'\n');
self.writer.make_writer().write_all(&line).ok();
}
}
/// Registers a SpanFields instance as span extension.
@@ -348,11 +356,10 @@ struct CallsiteSpanInfo {
}
impl CallsiteSpanInfo {
fn new(
callsite_name_ids: &papaya::HashMap<&'static str, u32, ahash::RandomState>,
metadata: &'static Metadata<'static>,
extract_fields: &[&'static str],
) -> Self {
fn new(metadata: &'static Metadata<'static>, extract_fields: &[&'static str]) -> Self {
// Start at 1 to reserve 0 for default.
static COUNTER: AtomicU32 = AtomicU32::new(1);
let names: Vec<&'static str> = metadata.fields().iter().map(|f| f.name()).collect();
// get all the indices of span fields we want to focus
@@ -365,18 +372,8 @@ impl CallsiteSpanInfo {
// normalized_name is unique for each callsite, but it is not
// unified across separate proxy instances.
// todo: can we do better here?
let cid = *callsite_name_ids
.pin()
.update_or_insert(metadata.name(), |&cid| cid + 1, 0);
// we hope that most span names are unique, in which case this will always be 0
let normalized_name = if cid == 0 {
metadata.name().into()
} else {
// if the span name is not unique, add the numeric ID to span name to distinguish it.
// sadly this is non-determinstic, across restarts but we should fix it by disambiguating re-used span names instead.
format!("{}#{cid}", metadata.name()).into()
};
let cid = COUNTER.fetch_add(1, Ordering::Relaxed);
let normalized_name = format!("{}#{cid}", metadata.name()).into();
Self {
extract,
@@ -385,24 +382,9 @@ impl CallsiteSpanInfo {
}
}
#[derive(Clone)]
struct RawValue(Box<[u8]>);
impl RawValue {
fn new(v: impl json::ValueEncoder) -> Self {
Self(json::value_to_vec!(|val| v.encode(val)).into_boxed_slice())
}
}
impl json::ValueEncoder for &RawValue {
fn encode(self, v: json::ValueSer<'_>) {
v.write_raw_json(&self.0);
}
}
/// Stores span field values recorded during the spans lifetime.
struct SpanFields {
values: [Option<RawValue>; MAX_TRACING_FIELDS],
values: [serde_json::Value; MAX_TRACING_FIELDS],
/// cached span info so we can avoid extra hashmap lookups in the hot path.
span_info: CallsiteSpanInfo,
@@ -412,7 +394,7 @@ impl SpanFields {
fn new(span_info: CallsiteSpanInfo) -> Self {
Self {
span_info,
values: [const { None }; MAX_TRACING_FIELDS],
values: [const { serde_json::Value::Null }; MAX_TRACING_FIELDS],
}
}
}
@@ -420,55 +402,55 @@ impl SpanFields {
impl tracing::field::Visit for SpanFields {
#[inline]
fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
self.values[field.index()] = Some(RawValue::new(value));
self.values[field.index()] = serde_json::Value::from(value);
}
#[inline]
fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
self.values[field.index()] = Some(RawValue::new(value));
self.values[field.index()] = serde_json::Value::from(value);
}
#[inline]
fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
self.values[field.index()] = Some(RawValue::new(value));
self.values[field.index()] = serde_json::Value::from(value);
}
#[inline]
fn record_i128(&mut self, field: &tracing::field::Field, value: i128) {
if let Ok(value) = i64::try_from(value) {
self.values[field.index()] = Some(RawValue::new(value));
self.values[field.index()] = serde_json::Value::from(value);
} else {
self.values[field.index()] = Some(RawValue::new(format_args!("{value}")));
self.values[field.index()] = serde_json::Value::from(format!("{value}"));
}
}
#[inline]
fn record_u128(&mut self, field: &tracing::field::Field, value: u128) {
if let Ok(value) = u64::try_from(value) {
self.values[field.index()] = Some(RawValue::new(value));
self.values[field.index()] = serde_json::Value::from(value);
} else {
self.values[field.index()] = Some(RawValue::new(format_args!("{value}")));
self.values[field.index()] = serde_json::Value::from(format!("{value}"));
}
}
#[inline]
fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
self.values[field.index()] = Some(RawValue::new(value));
self.values[field.index()] = serde_json::Value::from(value);
}
#[inline]
fn record_bytes(&mut self, field: &tracing::field::Field, value: &[u8]) {
self.values[field.index()] = Some(RawValue::new(value));
self.values[field.index()] = serde_json::Value::from(value);
}
#[inline]
fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
self.values[field.index()] = Some(RawValue::new(value));
self.values[field.index()] = serde_json::Value::from(value);
}
#[inline]
fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
self.values[field.index()] = Some(RawValue::new(format_args!("{value:?}")));
self.values[field.index()] = serde_json::Value::from(format!("{value:?}"));
}
#[inline]
@@ -477,7 +459,7 @@ impl tracing::field::Visit for SpanFields {
field: &tracing::field::Field,
value: &(dyn std::error::Error + 'static),
) {
self.values[field.index()] = Some(RawValue::new(format_args!("{value}")));
self.values[field.index()] = serde_json::Value::from(format!("{value}"));
}
}
@@ -526,6 +508,11 @@ impl EventFormatter {
&self.logline_buffer
}
/// Empties `logline_buffer`; the logging layer calls this before formatting an event.
#[inline]
fn reset(&mut self) {
self.logline_buffer.clear();
}
fn format<S>(
&mut self,
now: DateTime<Utc>,
@@ -533,7 +520,8 @@ impl EventFormatter {
ctx: &Context<'_, S>,
skipped_field_indices: &CallsiteMap<SkippedFieldIndices>,
extract_fields: &'static [&'static str],
) where
) -> io::Result<()>
where
S: Subscriber + for<'a> LookupSpan<'a>,
{
let timestamp = now.to_rfc3339_opts(chrono::SecondsFormat::Micros, true);
@@ -548,99 +536,78 @@ impl EventFormatter {
.copied()
.unwrap_or_default();
self.logline_buffer.clear();
let serializer = json::ValueSer::new(&mut self.logline_buffer);
json::value_as_object!(|serializer| {
let mut serialize = || {
let mut serializer = serde_json::Serializer::new(&mut self.logline_buffer);
let mut serializer = serializer.serialize_map(None)?;
// Timestamp comes first, so raw lines can be sorted by timestamp.
serializer.entry("timestamp", &*timestamp);
serializer.serialize_entry("timestamp", &timestamp)?;
// Level next.
serializer.entry("level", meta.level().as_str());
serializer.serialize_entry("level", &meta.level().as_str())?;
// Message next.
serializer.serialize_key("message")?;
let mut message_extractor =
MessageFieldExtractor::new(serializer.key("message"), skipped_field_indices);
MessageFieldExtractor::new(serializer, skipped_field_indices);
event.record(&mut message_extractor);
message_extractor.finish();
let mut serializer = message_extractor.into_serializer()?;
// Direct message fields.
{
let mut message_skipper = MessageFieldSkipper::new(
serializer.key("fields").object(),
skipped_field_indices,
);
event.record(&mut message_skipper);
// rollback if no fields are present.
if message_skipper.present {
message_skipper.serializer.finish();
}
let mut fields_present = FieldsPresent(false, skipped_field_indices);
event.record(&mut fields_present);
if fields_present.0 {
serializer.serialize_entry(
"fields",
&SerializableEventFields(event, skipped_field_indices),
)?;
}
let mut extracted = ExtractedSpanFields::new(extract_fields);
let spans = serializer.key("spans");
json::value_as_object!(|spans| {
let parent_spans = ctx
let spans = SerializableSpans {
// collect all spans from parent to root.
spans: ctx
.event_span(event)
.map_or(vec![], |parent| parent.scope().collect());
for span in parent_spans.iter().rev() {
let ext = span.extensions();
// all spans should have this extension.
let Some(fields) = ext.get() else { continue };
extracted.layer_span(fields);
let SpanFields { values, span_info } = fields;
let span_fields = spans.key(&*span_info.normalized_name);
json::value_as_object!(|span_fields| {
for (field, value) in std::iter::zip(span.metadata().fields(), values) {
if let Some(value) = value {
span_fields.entry(field.name(), value);
}
}
});
}
});
.map_or(vec![], |parent| parent.scope().collect()),
extracted: ExtractedSpanFields::new(extract_fields),
};
serializer.serialize_entry("spans", &spans)?;
// TODO: thread-local cache?
let pid = std::process::id();
// Skip adding pid 1 to reduce noise for services running in containers.
if pid != 1 {
serializer.entry("process_id", pid);
serializer.serialize_entry("process_id", &pid)?;
}
THREAD_ID.with(|tid| serializer.entry("thread_id", tid));
THREAD_ID.with(|tid| serializer.serialize_entry("thread_id", tid))?;
// TODO: tls cache? name could change
if let Some(thread_name) = std::thread::current().name()
&& !thread_name.is_empty()
&& thread_name != "tokio-runtime-worker"
{
serializer.entry("thread_name", thread_name);
serializer.serialize_entry("thread_name", thread_name)?;
}
if let Some(task_id) = tokio::task::try_id() {
serializer.entry("task_id", format_args!("{task_id}"));
serializer.serialize_entry("task_id", &format_args!("{task_id}"))?;
}
serializer.entry("target", meta.target());
serializer.serialize_entry("target", meta.target())?;
// Skip adding module if it's the same as target.
if let Some(module) = meta.module_path()
&& module != meta.target()
{
serializer.entry("module", module);
serializer.serialize_entry("module", module)?;
}
if let Some(file) = meta.file() {
if let Some(line) = meta.line() {
serializer.entry("src", format_args!("{file}:{line}"));
serializer.serialize_entry("src", &format_args!("{file}:{line}"))?;
} else {
serializer.entry("src", file);
serializer.serialize_entry("src", file)?;
}
}
@@ -649,104 +616,124 @@ impl EventFormatter {
let otel_spanref = otel_context.span();
let span_context = otel_spanref.span_context();
if span_context.is_valid() {
serializer.entry("trace_id", format_args!("{}", span_context.trace_id()));
serializer.serialize_entry(
"trace_id",
&format_args!("{}", span_context.trace_id()),
)?;
}
}
if extracted.has_values() {
if spans.extracted.has_values() {
// TODO: add fields from event, too?
let extract = serializer.key("extract");
json::value_as_object!(|extract| {
for (key, value) in std::iter::zip(extracted.names, extracted.values) {
if let Some(value) = value {
extract.entry(*key, &value);
}
}
});
serializer.serialize_entry("extract", &spans.extracted)?;
}
});
serializer.end()
};
serialize().map_err(io::Error::other)?;
self.logline_buffer.push(b'\n');
Ok(())
}
}
/// Extracts the message field that's mixed with other fields.
struct MessageFieldExtractor<'buf> {
serializer: Option<json::ValueSer<'buf>>,
struct MessageFieldExtractor<S: serde::ser::SerializeMap> {
serializer: S,
skipped_field_indices: SkippedFieldIndices,
state: Option<Result<(), S::Error>>,
}
impl<'buf> MessageFieldExtractor<'buf> {
impl<S: serde::ser::SerializeMap> MessageFieldExtractor<S> {
#[inline]
fn new(serializer: json::ValueSer<'buf>, skipped_field_indices: SkippedFieldIndices) -> Self {
fn new(serializer: S, skipped_field_indices: SkippedFieldIndices) -> Self {
Self {
serializer: Some(serializer),
serializer,
skipped_field_indices,
state: None,
}
}
#[inline]
fn finish(self) {
if let Some(ser) = self.serializer {
ser.value("");
/// Consumes the extractor and returns the map serializer.
///
/// Returns the stored error if serializing the message value failed. If no
/// message value was written at all (`state` is `None`), an empty string is
/// serialized as the value first.
fn into_serializer(mut self) -> Result<S, S::Error> {
match self.state {
// A message value was written successfully; nothing left to do.
Some(Ok(())) => {}
Some(Err(err)) => return Err(err),
// No value was recorded: write an empty string in its place.
None => self.serializer.serialize_value("")?,
}
Ok(self.serializer)
}
#[inline]
fn record_field(&mut self, field: &tracing::field::Field, v: impl json::ValueEncoder) {
if field.name() == MESSAGE_FIELD
fn accept_field(&self, field: &tracing::field::Field) -> bool {
self.state.is_none()
&& field.name() == MESSAGE_FIELD
&& !self.skipped_field_indices.contains(field.index())
&& let Some(ser) = self.serializer.take()
{
ser.value(v);
}
}
}
impl tracing::field::Visit for MessageFieldExtractor<'_> {
impl<S: serde::ser::SerializeMap> tracing::field::Visit for MessageFieldExtractor<S> {
#[inline]
fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
self.record_field(field, value);
if self.accept_field(field) {
self.state = Some(self.serializer.serialize_value(&value));
}
}
#[inline]
fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
self.record_field(field, value);
if self.accept_field(field) {
self.state = Some(self.serializer.serialize_value(&value));
}
}
#[inline]
fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
self.record_field(field, value);
if self.accept_field(field) {
self.state = Some(self.serializer.serialize_value(&value));
}
}
#[inline]
fn record_i128(&mut self, field: &tracing::field::Field, value: i128) {
self.record_field(field, value);
if self.accept_field(field) {
self.state = Some(self.serializer.serialize_value(&value));
}
}
#[inline]
fn record_u128(&mut self, field: &tracing::field::Field, value: u128) {
self.record_field(field, value);
if self.accept_field(field) {
self.state = Some(self.serializer.serialize_value(&value));
}
}
#[inline]
fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
self.record_field(field, value);
if self.accept_field(field) {
self.state = Some(self.serializer.serialize_value(&value));
}
}
#[inline]
fn record_bytes(&mut self, field: &tracing::field::Field, value: &[u8]) {
self.record_field(field, format_args!("{value:x?}"));
if self.accept_field(field) {
self.state = Some(self.serializer.serialize_value(&format_args!("{value:x?}")));
}
}
#[inline]
fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
self.record_field(field, value);
if self.accept_field(field) {
self.state = Some(self.serializer.serialize_value(&value));
}
}
#[inline]
fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
self.record_field(field, format_args!("{value:?}"));
if self.accept_field(field) {
self.state = Some(self.serializer.serialize_value(&format_args!("{value:?}")));
}
}
#[inline]
@@ -755,83 +742,147 @@ impl tracing::field::Visit for MessageFieldExtractor<'_> {
field: &tracing::field::Field,
value: &(dyn std::error::Error + 'static),
) {
self.record_field(field, format_args!("{value}"));
if self.accept_field(field) {
self.state = Some(self.serializer.serialize_value(&format_args!("{value}")));
}
}
}
/// Detects whether an event carries at least one field that belongs in the
/// `fields` JSON subobject. If it does not, the subobject can be skipped.
// This is entirely optional and only cosmetic, though it may help a bit during
// log parsing in dashboards when no field is emitted with an empty object.
struct FieldsPresent(pub bool, SkippedFieldIndices);

// Only `record_debug` is implemented. Some recording paths (error, bytes) carry
// overhead in general, but the value is ignored entirely here, so the compiler
// is assumed to drop that work.
impl tracing::field::Visit for FieldsPresent {
    /// Sets the flag when `field` is not skipped, is not the message field and
    /// is not a `log.`-prefixed field.
    #[inline]
    fn record_debug(&mut self, field: &tracing::field::Field, _: &dyn std::fmt::Debug) {
        let skipped = self.1.contains(field.index());
        let name = field.name();
        if skipped || name == MESSAGE_FIELD || name.starts_with("log.") {
            return;
        }
        self.0 = true;
    }
}
/// Serializes the fields supplied directly with a log event as a map, using
/// `MessageFieldSkipper` to decide which fields are written.
struct SerializableEventFields<'a, 'event>(&'a tracing::Event<'event>, SkippedFieldIndices);

impl serde::ser::Serialize for SerializableEventFields<'_, '_> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        use serde::ser::SerializeMap;

        let Self(event, skipped_field_indices) = self;
        let map = serializer.serialize_map(None)?;
        let mut visitor = MessageFieldSkipper::new(map, *skipped_field_indices);
        event.record(&mut visitor);
        // Surfaces any error captured while the event fields were being recorded.
        visitor.into_serializer()?.end()
    }
}
/// A tracing field visitor that skips the message field.
struct MessageFieldSkipper<'buf> {
serializer: json::ObjectSer<'buf>,
struct MessageFieldSkipper<S: serde::ser::SerializeMap> {
serializer: S,
skipped_field_indices: SkippedFieldIndices,
present: bool,
state: Result<(), S::Error>,
}
impl<'buf> MessageFieldSkipper<'buf> {
impl<S: serde::ser::SerializeMap> MessageFieldSkipper<S> {
#[inline]
fn new(serializer: json::ObjectSer<'buf>, skipped_field_indices: SkippedFieldIndices) -> Self {
fn new(serializer: S, skipped_field_indices: SkippedFieldIndices) -> Self {
Self {
serializer,
skipped_field_indices,
present: false,
state: Ok(()),
}
}
#[inline]
fn record_field(&mut self, field: &tracing::field::Field, v: impl json::ValueEncoder) {
if field.name() != MESSAGE_FIELD
fn accept_field(&self, field: &tracing::field::Field) -> bool {
self.state.is_ok()
&& field.name() != MESSAGE_FIELD
&& !field.name().starts_with("log.")
&& !self.skipped_field_indices.contains(field.index())
{
self.serializer.entry(field.name(), v);
self.present |= true;
}
}
/// Consumes the visitor and returns the map serializer, or the error stored in
/// `state` if serializing one of the fields failed.
#[inline]
fn into_serializer(self) -> Result<S, S::Error> {
self.state?;
Ok(self.serializer)
}
}
impl tracing::field::Visit for MessageFieldSkipper<'_> {
impl<S: serde::ser::SerializeMap> tracing::field::Visit for MessageFieldSkipper<S> {
#[inline]
fn record_f64(&mut self, field: &tracing::field::Field, value: f64) {
self.record_field(field, value);
if self.accept_field(field) {
self.state = self.serializer.serialize_entry(field.name(), &value);
}
}
#[inline]
fn record_i64(&mut self, field: &tracing::field::Field, value: i64) {
self.record_field(field, value);
if self.accept_field(field) {
self.state = self.serializer.serialize_entry(field.name(), &value);
}
}
#[inline]
fn record_u64(&mut self, field: &tracing::field::Field, value: u64) {
self.record_field(field, value);
if self.accept_field(field) {
self.state = self.serializer.serialize_entry(field.name(), &value);
}
}
#[inline]
fn record_i128(&mut self, field: &tracing::field::Field, value: i128) {
self.record_field(field, value);
if self.accept_field(field) {
self.state = self.serializer.serialize_entry(field.name(), &value);
}
}
#[inline]
fn record_u128(&mut self, field: &tracing::field::Field, value: u128) {
self.record_field(field, value);
if self.accept_field(field) {
self.state = self.serializer.serialize_entry(field.name(), &value);
}
}
#[inline]
fn record_bool(&mut self, field: &tracing::field::Field, value: bool) {
self.record_field(field, value);
if self.accept_field(field) {
self.state = self.serializer.serialize_entry(field.name(), &value);
}
}
#[inline]
fn record_bytes(&mut self, field: &tracing::field::Field, value: &[u8]) {
self.record_field(field, format_args!("{value:x?}"));
if self.accept_field(field) {
self.state = self
.serializer
.serialize_entry(field.name(), &format_args!("{value:x?}"));
}
}
#[inline]
fn record_str(&mut self, field: &tracing::field::Field, value: &str) {
self.record_field(field, value);
if self.accept_field(field) {
self.state = self.serializer.serialize_entry(field.name(), &value);
}
}
#[inline]
fn record_debug(&mut self, field: &tracing::field::Field, value: &dyn std::fmt::Debug) {
self.record_field(field, format_args!("{value:?}"));
if self.accept_field(field) {
self.state = self
.serializer
.serialize_entry(field.name(), &format_args!("{value:?}"));
}
}
#[inline]
@@ -840,40 +891,131 @@ impl tracing::field::Visit for MessageFieldSkipper<'_> {
field: &tracing::field::Field,
value: &(dyn std::error::Error + 'static),
) {
self.record_field(field, format_args!("{value}"));
if self.accept_field(field) {
    // Write a complete `name: value` entry, like every other `record_*` method of
    // this visitor. A bare `serialize_value` has no preceding key and would
    // produce a malformed map.
    self.state = self
        .serializer
        .serialize_entry(field.name(), &format_args!("{value}"));
}
}
}
/// Serializes the span stack from root to leaf (parent of event) as object
/// with the span names as keys. To prevent collision we append a numberic value
/// to the name. Also, collects any span fields we're interested in. Last one
/// wins.
struct SerializableSpans<'ctx, S>
where
S: for<'lookup> LookupSpan<'lookup>,
{
spans: Vec<SpanRef<'ctx, S>>,
extracted: ExtractedSpanFields,
}
impl<S> serde::ser::Serialize for SerializableSpans<'_, S>
where
S: for<'lookup> LookupSpan<'lookup>,
{
fn serialize<Ser>(&self, serializer: Ser) -> Result<Ser::Ok, Ser::Error>
where
Ser: serde::ser::Serializer,
{
let mut serializer = serializer.serialize_map(None)?;
for span in self.spans.iter().rev() {
let ext = span.extensions();
// all spans should have this extension.
let Some(fields) = ext.get() else { continue };
self.extracted.layer_span(fields);
let SpanFields { values, span_info } = fields;
serializer.serialize_entry(
&*span_info.normalized_name,
&SerializableSpanFields {
fields: span.metadata().fields(),
values,
},
)?;
}
serializer.end()
}
}
/// Serializes the recorded fields of a single span as an object, leaving out
/// every field whose stored value is JSON `null`.
struct SerializableSpanFields<'span> {
    /// Field definitions taken from the span's metadata.
    fields: &'span tracing::field::FieldSet,
    /// Recorded values, paired with `fields` by position.
    values: &'span [serde_json::Value; MAX_TRACING_FIELDS],
}

impl serde::ser::Serialize for SerializableSpanFields<'_> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::ser::Serializer,
    {
        let mut map = serializer.serialize_map(None)?;
        let recorded =
            std::iter::zip(self.fields, self.values).filter(|(_, value)| !value.is_null());
        for (field, value) in recorded {
            map.serialize_entry(field.name(), value)?;
        }
        map.end()
    }
}
struct ExtractedSpanFields {
names: &'static [&'static str],
values: Vec<Option<RawValue>>,
values: RefCell<Vec<serde_json::Value>>,
}
impl ExtractedSpanFields {
fn new(names: &'static [&'static str]) -> Self {
ExtractedSpanFields {
names,
values: vec![None; names.len()],
values: RefCell::new(vec![serde_json::Value::Null; names.len()]),
}
}
fn layer_span(&mut self, fields: &SpanFields) {
fn layer_span(&self, fields: &SpanFields) {
let mut v = self.values.borrow_mut();
let SpanFields { values, span_info } = fields;
// extract the fields
for (i, &j) in span_info.extract.iter().enumerate() {
let Some(Some(value)) = values.get(j) else {
continue;
};
let Some(value) = values.get(j) else { continue };
// TODO: replace clone with reference, if possible.
self.values[i] = Some(value.clone());
if !value.is_null() {
// TODO: replace clone with reference, if possible.
v[i] = value.clone();
}
}
}
#[inline]
fn has_values(&self) -> bool {
self.values.iter().any(|v| v.is_some())
self.values.borrow().iter().any(|v| !v.is_null())
}
}
/// Serializes the extracted fields as an object keyed by the configured names;
/// names whose value is still JSON `null` are left out.
impl serde::ser::Serialize for ExtractedSpanFields {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::ser::Serializer,
    {
        let mut map = serializer.serialize_map(None)?;
        let extracted = self.values.borrow();
        let present = std::iter::zip(self.names, extracted.iter())
            .filter(|(_, value)| !value.is_null());
        for (name, value) in present {
            map.serialize_entry(name, value)?;
        }
        map.end()
    }
}
@@ -928,7 +1070,6 @@ mod tests {
clock: clock.clone(),
skipped_field_indices: papaya::HashMap::default(),
span_info: papaya::HashMap::default(),
callsite_name_ids: papaya::HashMap::default(),
writer: buffer.clone(),
extract_fields: &["x"],
};
@@ -937,16 +1078,14 @@ mod tests {
tracing::subscriber::with_default(registry, || {
info_span!("some_span", x = 24).in_scope(|| {
info_span!("some_other_span", y = 30).in_scope(|| {
info_span!("some_span", x = 40, x = 41, x = 42).in_scope(|| {
tracing::error!(
a = 1,
a = 2,
a = 3,
message = "explicit message field",
"implicit message field"
);
});
info_span!("some_span", x = 40, x = 41, x = 42).in_scope(|| {
tracing::error!(
a = 1,
a = 2,
a = 3,
message = "explicit message field",
"implicit message field"
);
});
});
});
@@ -965,15 +1104,12 @@ mod tests {
"a": 3,
},
"spans": {
"some_span":{
"some_span#1":{
"x": 24,
},
"some_other_span": {
"y": 30,
},
"some_span#1": {
"some_span#2": {
"x": 42,
},
}
},
"extract": {
"x": 42,

View File

@@ -112,9 +112,6 @@ pub struct ProxyMetrics {
/// Number of bytes sent/received between all clients and backends.
pub io_bytes: CounterVec<StaticLabelSet<Direction>>,
/// Number of IO errors while logging.
pub logging_errors_count: Counter,
/// Number of errors by a given classification.
pub errors_total: CounterVec<StaticLabelSet<crate::error::ErrorKind>>,
@@ -381,6 +378,7 @@ pub enum RedisMsgKind {
Get,
Expire,
HGet,
Unlink,
}
#[derive(Default, Clone)]

View File

@@ -265,7 +265,10 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
return Ok(());
}
let mut conn = match try_connect(&redis).await {
Ok(conn) => conn,
Ok(conn) => {
handler.cache.increment_active_listeners().await;
conn
}
Err(e) => {
tracing::error!(
"failed to connect to redis: {e}, will try to reconnect in {RECONNECT_TIMEOUT:#?}"
@@ -284,9 +287,11 @@ async fn handle_messages<C: ProjectInfoCache + Send + Sync + 'static>(
}
}
if cancellation_token.is_cancelled() {
handler.cache.decrement_active_listeners().await;
return Ok(());
}
}
handler.cache.decrement_active_listeners().await;
}
}

View File

@@ -58,7 +58,6 @@ metrics.workspace = true
pem.workspace = true
postgres_backend.workspace = true
postgres_ffi.workspace = true
postgres_ffi_types.workspace = true
postgres_versioninfo.workspace = true
pq_proto.workspace = true
remote_storage.workspace = true
@@ -72,7 +71,6 @@ http-utils.workspace = true
utils.workspace = true
wal_decoder.workspace = true
env_logger.workspace = true
nix.workspace = true
workspace_hack.workspace = true

View File

@@ -17,9 +17,8 @@ use http_utils::tls_certs::ReloadingCertificateResolver;
use metrics::set_build_info_metric;
use remote_storage::RemoteStorageConfig;
use safekeeper::defaults::{
DEFAULT_CONTROL_FILE_SAVE_INTERVAL, DEFAULT_EVICTION_MIN_RESIDENT,
DEFAULT_GLOBAL_DISK_CHECK_INTERVAL, DEFAULT_HEARTBEAT_TIMEOUT, DEFAULT_HTTP_LISTEN_ADDR,
DEFAULT_MAX_GLOBAL_DISK_USAGE_RATIO, DEFAULT_MAX_OFFLOADER_LAG_BYTES,
DEFAULT_CONTROL_FILE_SAVE_INTERVAL, DEFAULT_EVICTION_MIN_RESIDENT, DEFAULT_HEARTBEAT_TIMEOUT,
DEFAULT_HTTP_LISTEN_ADDR, DEFAULT_MAX_OFFLOADER_LAG_BYTES,
DEFAULT_MAX_REELECT_OFFLOADER_LAG_BYTES, DEFAULT_MAX_TIMELINE_DISK_USAGE_BYTES,
DEFAULT_PARTIAL_BACKUP_CONCURRENCY, DEFAULT_PARTIAL_BACKUP_TIMEOUT, DEFAULT_PG_LISTEN_ADDR,
DEFAULT_SSL_CERT_FILE, DEFAULT_SSL_CERT_RELOAD_PERIOD, DEFAULT_SSL_KEY_FILE,
@@ -43,12 +42,6 @@ use utils::metrics_collector::{METRICS_COLLECTION_INTERVAL, METRICS_COLLECTOR};
use utils::sentry_init::init_sentry;
use utils::{pid_file, project_build_tag, project_git_version, tcp_listener};
use safekeeper::hadron::{
GLOBAL_DISK_LIMIT_EXCEEDED, get_filesystem_capacity, get_filesystem_usage,
};
use safekeeper::metrics::GLOBAL_DISK_UTIL_CHECK_SECONDS;
use std::sync::atomic::Ordering;
#[global_allocator]
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
@@ -263,15 +256,6 @@ struct Args {
/* BEGIN_HADRON */
#[arg(long)]
enable_pull_timeline_on_startup: bool,
/// How often to scan entire data-dir for total disk usage
#[arg(long, value_parser=humantime::parse_duration, default_value = DEFAULT_GLOBAL_DISK_CHECK_INTERVAL)]
global_disk_check_interval: Duration,
/// The portion of the filesystem capacity that can be used by all timelines.
/// A circuit breaker will trip and reject all WAL writes if the total usage
/// exceeds this ratio.
/// Set to 0 to disable the global disk usage limit.
#[arg(long, default_value_t = DEFAULT_MAX_GLOBAL_DISK_USAGE_RATIO)]
max_global_disk_usage_ratio: f64,
/* END_HADRON */
}
@@ -460,8 +444,6 @@ async fn main() -> anyhow::Result<()> {
advertise_pg_addr_tenant_only: None,
enable_pull_timeline_on_startup: args.enable_pull_timeline_on_startup,
hcc_base_url: None,
global_disk_check_interval: args.global_disk_check_interval,
max_global_disk_usage_ratio: args.max_global_disk_usage_ratio,
/* END_HADRON */
});
@@ -636,49 +618,6 @@ async fn start_safekeeper(conf: Arc<SafeKeeperConf>) -> Result<()> {
.map(|res| ("Timeline map housekeeping".to_owned(), res));
tasks_handles.push(Box::pin(timeline_housekeeping_handle));
/* BEGIN_HADRON */
// Spawn global disk usage watcher task, if a global disk usage limit is specified.
let interval = conf.global_disk_check_interval;
let data_dir = conf.workdir.clone();
// Use the safekeeper data directory to compute filesystem capacity. This only runs once on startup, so
// there is little point to continue if we can't have the proper protections in place.
let fs_capacity_bytes = get_filesystem_capacity(data_dir.as_std_path())
.expect("Failed to get filesystem capacity for data directory");
let limit: u64 = (conf.max_global_disk_usage_ratio * fs_capacity_bytes as f64) as u64;
if limit > 0 {
let disk_usage_watch_handle = BACKGROUND_RUNTIME
.handle()
.spawn(async move {
// Use Tokio interval to preserve fixed cadence between filesystem utilization checks
let mut ticker = tokio::time::interval(interval);
ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
loop {
ticker.tick().await;
let data_dir_clone = data_dir.clone();
let check_start = Instant::now();
let usage = tokio::task::spawn_blocking(move || {
get_filesystem_usage(data_dir_clone.as_std_path())
})
.await
.unwrap_or(0);
let elapsed = check_start.elapsed().as_secs_f64();
GLOBAL_DISK_UTIL_CHECK_SECONDS.observe(elapsed);
if usage > limit {
warn!(
"Global disk usage exceeded limit. Usage: {} bytes, limit: {} bytes",
usage, limit
);
}
GLOBAL_DISK_LIMIT_EXCEEDED.store(usage > limit, Ordering::Relaxed);
}
})
.map(|res| ("Global disk usage watcher".to_string(), res));
tasks_handles.push(Box::pin(disk_usage_watch_handle));
}
/* END_HADRON */
if let Some(pg_listener_tenant_only) = pg_listener_tenant_only {
let wal_service_handle = current_thread_rt
.as_ref()

View File

@@ -1,17 +1,12 @@
use once_cell::sync::Lazy;
use pem::Pem;
use safekeeper_api::models::PullTimelineRequest;
use std::{
collections::HashMap, env::VarError, net::IpAddr, sync::Arc, sync::atomic::AtomicBool,
time::Duration,
};
use std::{collections::HashMap, env::VarError, net::IpAddr, sync::Arc, time::Duration};
use tokio::time::sleep;
use tokio_util::sync::CancellationToken;
use url::Url;
use utils::{backoff, critical_timeline, id::TenantTimelineId, ip_address};
use anyhow::{Result, anyhow};
use utils::{backoff, id::TenantTimelineId, ip_address};
use anyhow::Result;
use pageserver_api::controller_api::{
AvailabilityZone, NodeRegisterRequest, SafekeeperTimeline, SafekeeperTimelinesResponse,
};
@@ -351,70 +346,6 @@ pub async fn hcc_pull_timelines(
Ok(())
}
/// true if the last background scan found total usage > limit
pub static GLOBAL_DISK_LIMIT_EXCEEDED: Lazy<AtomicBool> = Lazy::new(|| AtomicBool::new(false));
/// Returns filesystem usage in bytes for the filesystem containing the given path.
///
/// Usage is computed as `(total blocks - blocks available) * block size` from `statvfs`.
/// If `statvfs` fails, the error is reported through `critical_timeline!` and `0` is
/// returned, so callers cannot tell a failed read apart from an empty filesystem.
// Need to suppress the clippy::unnecessary_cast warning because the casts on the block count and the
// block size are required on macOS (they are 32-bit integers on macOS, apparently).
#[allow(clippy::unnecessary_cast)]
pub fn get_filesystem_usage(path: &std::path::Path) -> u64 {
// Allow overriding disk usage via failpoint for tests
fail::fail_point!("sk-global-disk-usage", |val| {
// val is Option<String>; parse payload if present. A missing or unparsable
// payload yields 0.
val.and_then(|s| s.parse::<u64>().ok()).unwrap_or(0)
});
// Call statvfs(3) for filesystem usage
use nix::sys::statvfs::statvfs;
match statvfs(path) {
Ok(stat) => {
// fragment size (f_frsize) if non-zero else block size (f_bsize)
let frsize = stat.fragment_size();
let blocksz = if frsize > 0 {
frsize
} else {
stat.block_size()
};
// used blocks = total blocks - blocks available to unprivileged users;
// saturating_sub keeps the result at 0 rather than underflowing.
let used_blocks = stat.blocks().saturating_sub(stat.blocks_available());
used_blocks as u64 * blocksz as u64
}
Err(e) => {
// The global disk usage watcher isn't associated with a tenant or timeline, so we just
// pass placeholder tenant and timeline IDs to the critical_timeline!() macro.
let placeholder_ttid = TenantTimelineId::empty();
critical_timeline!(
placeholder_ttid.tenant_id,
placeholder_ttid.timeline_id,
"Global disk usage watcher failed to read filesystem usage: {:?}",
e
);
// Fall back to 0 on error.
0
}
}
}
/// Returns the total capacity, in bytes, of the filesystem containing the given path.
///
/// Capacity is computed as `total blocks * block size` from `statvfs`.
/// Returns an error if the `statvfs` call fails.
// The casts below are flagged as unnecessary by clippy on some platforms, hence the allow.
#[allow(clippy::unnecessary_cast)]
pub fn get_filesystem_capacity(path: &std::path::Path) -> Result<u64> {
// Call statvfs(3) for filesystem stats
use nix::sys::statvfs::statvfs;
match statvfs(path) {
Ok(stat) => {
// fragment size (f_frsize) if non-zero else block size (f_bsize)
let frsize = stat.fragment_size();
let blocksz = if frsize > 0 {
frsize
} else {
stat.block_size()
};
Ok(stat.blocks() as u64 * blocksz as u64)
}
Err(e) => Err(anyhow!("Failed to read filesystem capacity: {:?}", e)),
}
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -33,13 +33,11 @@ use utils::id::{TenantId, TenantTimelineId, TimelineId};
use utils::lsn::Lsn;
use crate::debug_dump::TimelineDigestRequest;
use crate::hadron::{get_filesystem_capacity, get_filesystem_usage};
use crate::safekeeper::TermLsn;
use crate::timelines_global_map::DeleteOrExclude;
use crate::{
GlobalTimelines, SafeKeeperConf, copy_timeline, debug_dump, patch_control_file, pull_timeline,
};
use serde_json::json;
/// Healthcheck handler.
async fn status_handler(request: Request<Body>) -> Result<Response<Body>, ApiError> {
@@ -129,21 +127,6 @@ async fn utilization_handler(request: Request<Body>) -> Result<Response<Body>, A
json_response(StatusCode::OK, utilization)
}
/// Returns filesystem capacity and current utilization for the safekeeper data directory.
///
/// Responds with a JSON object containing `data_dir`, `capacity_bytes` and `usage_bytes`.
/// A failure to read the capacity is mapped to `ApiError::InternalServerError`.
async fn filesystem_usage_handler(request: Request<Body>) -> Result<Response<Body>, ApiError> {
check_permission(&request, None)?;
let conf = get_conf(&request);
// The data directory is the configured workdir.
let path = conf.workdir.as_std_path();
let capacity = get_filesystem_capacity(path).map_err(ApiError::InternalServerError)?;
// Unlike the capacity lookup, the usage lookup returns a plain u64 and cannot
// fail at this call site.
let usage = get_filesystem_usage(path);
let resp = json!({
"data_dir": path,
"capacity_bytes": capacity,
"usage_bytes": usage,
});
json_response(StatusCode::OK, resp)
}
/// List all (not deleted) timelines.
/// Note: it is possible to do the same with debug_dump.
async fn timeline_list_handler(request: Request<Body>) -> Result<Response<Body>, ApiError> {
@@ -747,11 +730,6 @@ pub fn make_router(
})
})
.get("/v1/utilization", |r| request_span(r, utilization_handler))
/* BEGIN_HADRON */
.get("/v1/debug/filesystem_usage", |r| {
request_span(r, filesystem_usage_handler)
})
/* END_HADRON */
.delete("/v1/tenant/:tenant_id", |r| {
request_span(r, tenant_delete_handler)
})

View File

@@ -50,7 +50,6 @@ pub mod wal_storage;
pub mod test_utils;
mod timelines_global_map;
use std::sync::Arc;
pub use timelines_global_map::GlobalTimelines;
@@ -84,10 +83,6 @@ pub mod defaults {
pub const DEFAULT_SSL_KEY_FILE: &str = "server.key";
pub const DEFAULT_SSL_CERT_FILE: &str = "server.crt";
pub const DEFAULT_SSL_CERT_RELOAD_PERIOD: &str = "60s";
// Global disk watcher defaults
pub const DEFAULT_GLOBAL_DISK_CHECK_INTERVAL: &str = "60s";
pub const DEFAULT_MAX_GLOBAL_DISK_USAGE_RATIO: f64 = 0.0;
}
#[derive(Debug, Clone)]
@@ -121,10 +116,6 @@ pub struct SafeKeeperConf {
/* BEGIN_HADRON */
pub max_reelect_offloader_lag_bytes: u64,
pub max_timeline_disk_usage_bytes: u64,
/// How often to check the working directory's filesystem for total disk usage.
pub global_disk_check_interval: Duration,
/// The portion of the filesystem capacity that can be used by all timelines.
pub max_global_disk_usage_ratio: f64,
/* END_HADRON */
pub backup_parallel_jobs: usize,
pub wal_backup_enabled: bool,
@@ -182,8 +173,6 @@ impl SafeKeeperConf {
/* BEGIN_HADRON */
max_reelect_offloader_lag_bytes: defaults::DEFAULT_MAX_REELECT_OFFLOADER_LAG_BYTES,
max_timeline_disk_usage_bytes: defaults::DEFAULT_MAX_TIMELINE_DISK_USAGE_BYTES,
global_disk_check_interval: Duration::from_secs(60),
max_global_disk_usage_ratio: defaults::DEFAULT_MAX_GLOBAL_DISK_USAGE_RATIO,
/* END_HADRON */
current_thread_runtime: false,
walsenders_keep_horizon: false,
@@ -246,13 +235,10 @@ pub static WAL_BACKUP_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
.expect("Failed to create WAL backup runtime")
});
/// Hadron: Dedicated runtime for infrequent background tasks.
pub static BACKGROUND_RUNTIME: Lazy<Runtime> = Lazy::new(|| {
tokio::runtime::Builder::new_multi_thread()
.thread_name("Hadron background worker")
// One worker thread is enough, as most of the actual tasks run on blocking threads,
// which have their own thread pool.
.worker_threads(1)
.thread_name("background worker")
.worker_threads(1) // there is only one task now (ssl certificate reloading), having more threads doesn't make sense
.enable_all()
.build()
.expect("Failed to create background runtime")

View File

@@ -963,17 +963,3 @@ async fn collect_timeline_metrics(global_timelines: Arc<GlobalTimelines>) -> Vec
}
res
}
/* BEGIN_HADRON */
// Metrics reporting the time spent to perform each safekeeper filesystem utilization check.
pub static GLOBAL_DISK_UTIL_CHECK_SECONDS: Lazy<Histogram> = Lazy::new(|| {
// Buckets from 1ms up to 10s
let buckets = vec![0.001, 0.01, 0.1, 0.5, 1.0, 2.0, 5.0, 10.0];
register_histogram!(
"safekeeper_global_disk_utilization_check_seconds",
"Seconds spent to perform each safekeeper filesystem utilization check",
buckets
)
.expect("Failed to register safekeeper_global_disk_utilization_check_seconds histogram")
});
/* END_HADRON */

View File

@@ -12,8 +12,7 @@ use futures::FutureExt;
use itertools::Itertools;
use parking_lot::Mutex;
use postgres_backend::{CopyStreamHandlerEnd, PostgresBackend, PostgresBackendReader, QueryError};
use postgres_ffi::{MAX_SEND_SIZE, PgMajorVersion, get_current_timestamp};
use postgres_ffi_types::TimestampTz;
use postgres_ffi::{MAX_SEND_SIZE, PgMajorVersion, TimestampTz, get_current_timestamp};
use pq_proto::{BeMessage, WalSndKeepAlive, XLogDataBody};
use safekeeper_api::Term;
use safekeeper_api::models::{

View File

@@ -29,8 +29,6 @@ use utils::sync::gate::Gate;
use crate::metrics::{
FullTimelineInfo, MISC_OPERATION_SECONDS, WAL_STORAGE_LIMIT_ERRORS, WalStorageMetrics,
};
use crate::hadron::GLOBAL_DISK_LIMIT_EXCEEDED;
use crate::rate_limit::RateLimiter;
use crate::receive_wal::WalReceivers;
use crate::safekeeper::{AcceptorProposerMessage, ProposerAcceptorMessage, SafeKeeper, TermLsn};
@@ -1083,11 +1081,6 @@ impl WalResidentTimeline {
);
}
}
if GLOBAL_DISK_LIMIT_EXCEEDED.load(Ordering::Relaxed) {
bail!("Global disk usage exceeded limit");
}
Ok(())
}
// END HADRON

View File

@@ -195,8 +195,6 @@ pub fn run_server(os: NodeOs, disk: Arc<SafekeeperDisk>) -> Result<()> {
enable_pull_timeline_on_startup: false,
advertise_pg_addr_tenant_only: None,
hcc_base_url: None,
global_disk_check_interval: Duration::from_secs(10),
max_global_disk_usage_ratio: 0.0,
/* END_HADRON */
};

View File

@@ -735,13 +735,15 @@ async fn handle_tenant_timeline_passthrough(
);
// Find the node that holds shard zero
let (node, tenant_shard_id, consistent) = if tenant_or_shard_id.is_unsharded() {
let (node, tenant_shard_id) = if tenant_or_shard_id.is_unsharded() {
service
.tenant_shard0_node(tenant_or_shard_id.tenant_id)
.await?
} else {
let (node, consistent) = service.tenant_shard_node(tenant_or_shard_id).await?;
(node, tenant_or_shard_id, consistent)
(
service.tenant_shard_node(tenant_or_shard_id).await?,
tenant_or_shard_id,
)
};
// Callers will always pass an unsharded tenant ID. Before proxying, we must
@@ -786,12 +788,16 @@ async fn handle_tenant_timeline_passthrough(
}
// Transform 404 into 503 if we raced with a migration
if resp.status() == reqwest::StatusCode::NOT_FOUND && !consistent {
// Rather than retry here, send the client a 503 to prompt a retry: this matches
// the pageserver's use of 503, and all clients calling this API should retry on 503.
return Err(ApiError::ResourceUnavailable(
format!("Pageserver {node} returned 404 due to ongoing migration, retry later").into(),
));
if resp.status() == reqwest::StatusCode::NOT_FOUND {
// Look up node again: if we migrated it will be different
let new_node = service.tenant_shard_node(tenant_shard_id).await?;
if new_node.get_id() != node.get_id() {
// Rather than retry here, send the client a 503 to prompt a retry: this matches
// the pageserver's use of 503, and all clients calling this API should retry on 503.
return Err(ApiError::ResourceUnavailable(
format!("Pageserver {node} returned 404, was migrated to {new_node}").into(),
));
}
}
// We have a reqwest::Response, would like a http::Response
@@ -2591,17 +2597,6 @@ pub fn make_router(
)
},
)
// Tenant timeline mark_invisible passthrough to shard zero
.put(
"/v1/tenant/:tenant_id/timeline/:timeline_id/mark_invisible",
|r| {
tenant_service_handler(
r,
handle_tenant_timeline_passthrough,
RequestName("v1_tenant_timeline_mark_invisible_passthrough"),
)
},
)
// Tenant detail GET passthrough to shard zero:
.get("/v1/tenant/:tenant_id", |r| {
tenant_service_handler(
@@ -2620,6 +2615,17 @@ pub fn make_router(
RequestName("v1_tenant_passthrough"),
)
})
// Tenant timeline mark_invisible passthrough to shard zero
.put(
"/v1/tenant/:tenant_id/timeline/:timeline_id/mark_invisible",
|r| {
tenant_service_handler(
r,
handle_tenant_timeline_passthrough,
RequestName("v1_tenant_timeline_mark_invisible_passthrough"),
)
},
)
}
#[cfg(test)]

View File

@@ -207,27 +207,6 @@ enum ShardGenerationValidity {
},
}
/// We collect the state of attachments for some operations to determine if the operation
/// needs to be retried when it fails.
struct TenantShardAttachState {
/// The targets of the operation.
///
/// Tenant shard ID, node ID, node, is intent node observed primary.
targets: Vec<(TenantShardId, NodeId, Node, bool)>,
/// The targets grouped by node ID.
by_node_id: HashMap<NodeId, (TenantShardId, Node, bool)>,
}
impl TenantShardAttachState {
fn for_api_call(&self) -> Vec<(TenantShardId, Node)> {
self.targets
.iter()
.map(|(tenant_shard_id, _, node, _)| (*tenant_shard_id, node.clone()))
.collect()
}
}
pub const RECONCILER_CONCURRENCY_DEFAULT: usize = 128;
pub const PRIORITY_RECONCILER_CONCURRENCY_DEFAULT: usize = 256;
pub const SAFEKEEPER_RECONCILER_CONCURRENCY_DEFAULT: usize = 32;
@@ -4773,86 +4752,6 @@ impl Service {
Ok(())
}
/// Returns true when the shard's observed state agrees with the intent to have it
/// attached on `intent_node_id`: the observed locations contain an entry for that node,
/// the entry has a known `conf`, and that conf's mode is `AttachedSingle` or
/// `AttachedMulti`. Returns false in every other case (no observed location, no conf,
/// or any other mode).
fn is_observed_consistent_with_intent(
&self,
shard: &TenantShard,
intent_node_id: NodeId,
) -> bool {
if let Some(location) = shard.observed.locations.get(&intent_node_id)
&& let Some(ref conf) = location.conf
&& (conf.mode == LocationConfigMode::AttachedSingle
|| conf.mode == LocationConfigMode::AttachedMulti)
{
true
} else {
false
}
}
/// Collects, for every shard in the tenant's shard range that has an intended attached
/// node, the shard ID, the node, and whether the observed state is consistent with that
/// intent. Shards without an intended attached node are skipped.
///
/// The result holds the same data twice: as an ordered list (`targets`) and keyed by
/// node ID (`by_node_id`). As written, this function always returns `Ok`.
fn collect_tenant_shards(
&self,
tenant_id: TenantId,
) -> Result<TenantShardAttachState, ApiError> {
// The read lock is held for the whole function; nodes are cloned out of it.
let locked = self.inner.read().unwrap();
let mut targets = Vec::new();
let mut by_node_id = HashMap::new();
// If the request got an unsharded tenant id, then apply
// the operation to all shards. Otherwise, apply it to a specific shard.
let shards_range = TenantShardId::tenant_range(tenant_id);
for (tenant_shard_id, shard) in locked.tenants.range(shards_range) {
if let Some(node_id) = shard.intent.get_attached() {
let node = locked
.nodes
.get(node_id)
.expect("Pageservers may not be deleted while referenced");
let consistent = self.is_observed_consistent_with_intent(shard, *node_id);
targets.push((*tenant_shard_id, *node_id, node.clone(), consistent));
// NOTE(review): the map is keyed by node ID, so if two shards of this tenant
// are attached to the same node, the later shard's entry replaces the earlier one.
by_node_id.insert(*node_id, (*tenant_shard_id, node.clone(), consistent));
}
}
Ok(TenantShardAttachState {
targets,
by_node_id,
})
}
/// Converts per-node API results into either the list of successful `(Node, T)` pairs
/// or the first error encountered.
///
/// A 404 from a node whose recorded attach state is inconsistent with intent is turned
/// into `ApiError::ResourceUnavailable` (so the client retries); any other error is
/// passed through via `passthrough_api_error`. Processing stops at the first error.
fn process_result_and_passthrough_errors<T>(
&self,
results: Vec<(Node, Result<T, mgmt_api::Error>)>,
attach_state: TenantShardAttachState,
) -> Result<Vec<(Node, T)>, ApiError> {
let mut processed_results: Vec<(Node, T)> = Vec::with_capacity(results.len());
// Debug-only sanity check: one result is expected per target.
debug_assert_eq!(results.len(), attach_state.targets.len());
for (node, res) in results {
// `None` when the node is not present in the attach state; only an explicit
// `Some(false)` triggers the retryable path below.
let is_consistent = attach_state
.by_node_id
.get(&node.get_id())
.map(|(_, _, consistent)| *consistent);
match res {
Ok(res) => processed_results.push((node, res)),
Err(mgmt_api::Error::ApiError(StatusCode::NOT_FOUND, _))
if is_consistent == Some(false) =>
{
// This is expected if the attach is not finished yet. Return 503 so that the client can retry.
return Err(ApiError::ResourceUnavailable(
format!(
"Timeline is not attached to the pageserver {} yet, please retry",
node.get_id()
)
.into(),
));
}
Err(e) => return Err(passthrough_api_error(&node, e)),
}
}
Ok(processed_results)
}
pub(crate) async fn tenant_timeline_lsn_lease(
&self,
tenant_id: TenantId,
@@ -4866,11 +4765,49 @@ impl Service {
)
.await;
let attach_state = self.collect_tenant_shards(tenant_id)?;
let mut retry_if_not_attached = false;
let targets = {
let locked = self.inner.read().unwrap();
let mut targets = Vec::new();
let results = self
// If the request got an unsharded tenant id, then apply
// the operation to all shards. Otherwise, apply it to a specific shard.
let shards_range = TenantShardId::tenant_range(tenant_id);
for (tenant_shard_id, shard) in locked.tenants.range(shards_range) {
if let Some(node_id) = shard.intent.get_attached() {
let node = locked
.nodes
.get(node_id)
.expect("Pageservers may not be deleted while referenced");
targets.push((*tenant_shard_id, node.clone()));
if let Some(location) = shard.observed.locations.get(node_id) {
if let Some(ref conf) = location.conf {
if conf.mode != LocationConfigMode::AttachedSingle
&& conf.mode != LocationConfigMode::AttachedMulti
{
// If the shard is attached as secondary, we need to retry if 404.
retry_if_not_attached = true;
}
// If the shard is attached as primary, we should succeed.
} else {
// Location conf is not available yet, retry if 404.
retry_if_not_attached = true;
}
} else {
// The shard is not attached to the intended pageserver yet, retry if 404.
retry_if_not_attached = true;
}
}
}
targets
};
let res = self
.tenant_for_shards_api(
attach_state.for_api_call(),
targets,
|tenant_shard_id, client| async move {
client
.timeline_lease_lsn(tenant_shard_id, timeline_id, lsn)
@@ -4883,13 +4820,31 @@ impl Service {
)
.await;
let leases = self.process_result_and_passthrough_errors(results, attach_state)?;
let mut valid_until = None;
for (_, lease) in leases {
if let Some(ref mut valid_until) = valid_until {
*valid_until = std::cmp::min(*valid_until, lease.valid_until);
} else {
valid_until = Some(lease.valid_until);
for (node, r) in res {
match r {
Ok(lease) => {
if let Some(ref mut valid_until) = valid_until {
*valid_until = std::cmp::min(*valid_until, lease.valid_until);
} else {
valid_until = Some(lease.valid_until);
}
}
Err(mgmt_api::Error::ApiError(StatusCode::NOT_FOUND, _))
if retry_if_not_attached =>
{
// This is expected if the attach is not finished yet. Return 503 so that the client can retry.
return Err(ApiError::ResourceUnavailable(
format!(
"Timeline is not attached to the pageserver {} yet, please retry",
node.get_id()
)
.into(),
));
}
Err(e) => {
return Err(passthrough_api_error(&node, e));
}
}
}
Ok(LsnLease {
@@ -5312,12 +5267,10 @@ impl Service {
status_code
}
/// When you know the TenantId but not a specific shard, and would like to get the node holding shard 0.
///
/// Returns the node, tenant shard id, and whether it is consistent with the observed state.
pub(crate) async fn tenant_shard0_node(
&self,
tenant_id: TenantId,
) -> Result<(Node, TenantShardId, bool), ApiError> {
) -> Result<(Node, TenantShardId), ApiError> {
let tenant_shard_id = {
let locked = self.inner.read().unwrap();
let Some((tenant_shard_id, _shard)) = locked
@@ -5335,17 +5288,15 @@ impl Service {
self.tenant_shard_node(tenant_shard_id)
.await
.map(|(node, consistent)| (node, tenant_shard_id, consistent))
.map(|node| (node, tenant_shard_id))
}
/// When you need to send an HTTP request to the pageserver that holds a shard of a tenant, this
/// function looks up and returns node. If the shard isn't found, returns Err(ApiError::NotFound)
///
/// Returns the intent node and whether it is consistent with the observed state.
pub(crate) async fn tenant_shard_node(
&self,
tenant_shard_id: TenantShardId,
) -> Result<(Node, bool), ApiError> {
) -> Result<Node, ApiError> {
// Look up in-memory state and maybe use the node from there.
{
let locked = self.inner.read().unwrap();
@@ -5375,8 +5326,7 @@ impl Service {
"Shard refers to nonexistent node"
)));
};
let consistent = self.is_observed_consistent_with_intent(shard, *intent_node_id);
return Ok((node.clone(), consistent));
return Ok(node.clone());
}
};
@@ -5410,8 +5360,8 @@ impl Service {
"Shard refers to nonexistent node"
)));
};
// As a reconciliation is in flight, we do not have the observed state yet, and therefore we assume it is always inconsistent.
Ok((node.clone(), false))
Ok(node.clone())
}
pub(crate) fn tenant_locate(

View File

@@ -1272,9 +1272,7 @@ impl TenantShard {
}
/// Return true if the optimization was really applied: it will not be applied if the optimization's
/// sequence is behind this tenant shard's or if the intent state proposed by the optimization
/// is not compatible with the current intent state. The latter may happen when the background
/// reconcile loop runs concurrently with HTTP-driven optimisations.
/// sequence is behind this tenant shard's
pub(crate) fn apply_optimization(
&mut self,
scheduler: &mut Scheduler,
@@ -1284,15 +1282,6 @@ impl TenantShard {
return false;
}
if !self.validate_optimization(&optimization) {
tracing::info!(
"Skipping optimization for {} because it does not match current intent: {:?}",
self.tenant_shard_id,
optimization,
);
return false;
}
metrics::METRICS_REGISTRY
.metrics_group
.storage_controller_schedule_optimization
@@ -1333,34 +1322,6 @@ impl TenantShard {
true
}
/// Check that the desired modifications to the intent state are compatible with
/// the current intent state.
///
/// Returns true when the optimization's action can be applied to `self.intent` as it
/// currently stands, false otherwise. This only inspects state; it does not modify it.
fn validate_optimization(&self, optimization: &ScheduleOptimization) -> bool {
match optimization.action {
ScheduleOptimizationAction::MigrateAttachment(MigrateAttachment {
old_attached_node_id,
new_attached_node_id,
}) => {
// The shard must still be attached to the old node, and the new node must
// currently be one of its secondaries.
self.intent.attached == Some(old_attached_node_id)
&& self.intent.secondary.contains(&new_attached_node_id)
}
ScheduleOptimizationAction::ReplaceSecondary(ReplaceSecondary {
old_node_id: _,
new_node_id,
}) => {
// It's legal to remove a secondary that is not present in the intent state,
// so the old node is not checked; the new node must not already be a secondary.
!self.intent.secondary.contains(&new_node_id)
}
ScheduleOptimizationAction::CreateSecondary(new_node_id) => {
// The new node must not already be a secondary.
!self.intent.secondary.contains(&new_node_id)
}
ScheduleOptimizationAction::RemoveSecondary(_) => {
// It's legal to remove a secondary that is not present in the intent state
true
}
}
}
/// When a shard has several secondary locations, we need to pick one in situations where
/// we promote one of them to an attached location:
/// - When draining a node for restart

View File

@@ -503,7 +503,6 @@ class NeonLocalCli(AbstractNeonCli):
pageserver_id: int | None = None,
allow_multiple=False,
update_catalog: bool = False,
privileged_role_name: str | None = None,
) -> subprocess.CompletedProcess[str]:
args = [
"endpoint",
@@ -535,8 +534,6 @@ class NeonLocalCli(AbstractNeonCli):
args.extend(["--allow-multiple"])
if update_catalog:
args.extend(["--update-catalog"])
if privileged_role_name is not None:
args.extend(["--privileged-role-name", privileged_role_name])
res = self.raw_cli(args)
res.check_returncode()

View File

@@ -4324,7 +4324,6 @@ class Endpoint(PgProtocol, LogUtils):
pageserver_id: int | None = None,
allow_multiple: bool = False,
update_catalog: bool = False,
privileged_role_name: str | None = None,
) -> Self:
"""
Create a new Postgres endpoint.
@@ -4352,7 +4351,6 @@ class Endpoint(PgProtocol, LogUtils):
pageserver_id=pageserver_id,
allow_multiple=allow_multiple,
update_catalog=update_catalog,
privileged_role_name=privileged_role_name,
)
path = Path("endpoints") / self.endpoint_id / "pgdata"
self.pgdata_dir = self.env.repo_dir / path
@@ -4802,7 +4800,6 @@ class EndpointFactory:
config_lines: list[str] | None = None,
pageserver_id: int | None = None,
update_catalog: bool = False,
privileged_role_name: str | None = None,
) -> Endpoint:
ep = Endpoint(
self.env,
@@ -4826,7 +4823,6 @@ class EndpointFactory:
config_lines=config_lines,
pageserver_id=pageserver_id,
update_catalog=update_catalog,
privileged_role_name=privileged_role_name,
)
def stop_all(self, fail_on_error=True) -> Self:

View File

@@ -73,11 +73,6 @@ def test_sharding_autosplit(neon_env_builder: NeonEnvBuilder, pg_bin: PgBin):
".*Local notification hook failed.*",
".*Marking shard.*for notification retry.*",
".*Failed to notify compute.*",
# As an optimization, the storage controller kicks the downloads on the secondary
# after the shard split. However, secondaries are created async, so it's possible
# that the intent state was modified, but the actual secondary hasn't been created,
# which results in an error.
".*Error calling secondary download after shard split.*",
]
)

View File

@@ -103,90 +103,3 @@ def test_neon_superuser(neon_simple_env: NeonEnv, pg_version: PgVersion):
query = "DROP SUBSCRIPTION sub CASCADE"
log.info(f"Dropping subscription: {query}")
cur.execute(query)
def test_privileged_role_override(neon_simple_env: NeonEnv, pg_version: PgVersion):
"""
Test that we can override the privileged role for an endpoint and when we do it,
everything is correctly bootstrapped inside Postgres and we don't have neon_superuser
role in the database.
"""
PRIVILEGED_ROLE_NAME = "my_superuser"
env = neon_simple_env
env.create_branch("test_privileged_role_override")
ep = env.endpoints.create(
"test_privileged_role_override",
privileged_role_name=PRIVILEGED_ROLE_NAME,
update_catalog=True,
)
ep.start()
ep.wait_for_migrations()
member_roles = [
"pg_read_all_data",
"pg_write_all_data",
"pg_monitor",
"pg_signal_backend",
]
non_member_roles = [
"pg_execute_server_program",
"pg_read_server_files",
"pg_write_server_files",
]
role_attributes = {
"rolsuper": False,
"rolinherit": True,
"rolcreaterole": True,
"rolcreatedb": True,
"rolcanlogin": False,
"rolreplication": True,
"rolconnlimit": -1,
"rolbypassrls": True,
}
if pg_version >= PgVersion.V15:
non_member_roles.append("pg_checkpoint")
if pg_version >= PgVersion.V16:
member_roles.append("pg_create_subscription")
non_member_roles.append("pg_use_reserved_connections")
with ep.cursor() as cur:
cur.execute(f"SELECT rolname FROM pg_roles WHERE rolname = '{PRIVILEGED_ROLE_NAME}'")
assert cur.fetchall()[0][0] == PRIVILEGED_ROLE_NAME
cur.execute("SELECT rolname FROM pg_roles WHERE rolname = 'neon_superuser'")
assert len(cur.fetchall()) == 0
cur.execute("SHOW neon.privileged_role_name")
assert cur.fetchall()[0][0] == PRIVILEGED_ROLE_NAME
# check PRIVILEGED_ROLE_NAME role is created
cur.execute(f"select * from pg_roles where rolname = '{PRIVILEGED_ROLE_NAME}'")
assert cur.fetchone() is not None
# check PRIVILEGED_ROLE_NAME role has the correct member roles
for role in member_roles:
cur.execute(f"SELECT pg_has_role('{PRIVILEGED_ROLE_NAME}', '{role}', 'member')")
assert cur.fetchone() == (True,), (
f"Role {role} should be a member of {PRIVILEGED_ROLE_NAME}"
)
for role in non_member_roles:
cur.execute(f"SELECT pg_has_role('{PRIVILEGED_ROLE_NAME}', '{role}', 'member')")
assert cur.fetchone() == (False,), (
f"Role {role} should not be a member of {PRIVILEGED_ROLE_NAME}"
)
# check PRIVILEGED_ROLE_NAME role has the correct role attributes
for attr, val in role_attributes.items():
cur.execute(f"SELECT {attr} FROM pg_roles WHERE rolname = '{PRIVILEGED_ROLE_NAME}'")
curr_val = cur.fetchone()
assert curr_val == (val,), (
f"Role attribute {attr} should be {val} instead of {curr_val}"
)

View File

@@ -246,9 +246,9 @@ def test_total_size_limit(neon_env_builder: NeonEnvBuilder):
system_memory = psutil.virtual_memory().total
# The smallest total size limit we can configure is 1/1024th of the system memory (e.g. 256MB on
# a system with 256GB of RAM). We will then write enough data to violate this limit.
max_dirty_data = 256 * 1024 * 1024
# The smallest total size limit we can configure is 1/1024th of the system memory (e.g. 128MB on
# a system with 128GB of RAM). We will then write enough data to violate this limit.
max_dirty_data = 128 * 1024 * 1024
ephemeral_bytes_per_memory_kb = (max_dirty_data * 1024) // system_memory
assert ephemeral_bytes_per_memory_kb > 0
@@ -272,7 +272,7 @@ def test_total_size_limit(neon_env_builder: NeonEnvBuilder):
timeline_count = 10
# This is about 2MiB of data per timeline
entries_per_timeline = 200_000
entries_per_timeline = 100_000
last_flush_lsns = asyncio.run(workload(env, tenant_conf, timeline_count, entries_per_timeline))
wait_until_pageserver_is_caught_up(env, last_flush_lsns)

View File

@@ -2788,8 +2788,7 @@ def test_timeline_disk_usage_limit(neon_env_builder: NeonEnvBuilder):
# Wait for the error message to appear in the compute log
def error_logged():
if endpoint.log_contains("WAL storage utilization exceeds configured limit") is None:
raise Exception("Expected error message not found in compute log yet")
return endpoint.log_contains("WAL storage utilization exceeds configured limit") is not None
wait_until(error_logged)
log.info("Found expected error message in compute log, resuming.")
@@ -2823,87 +2822,3 @@ def test_timeline_disk_usage_limit(neon_env_builder: NeonEnvBuilder):
cur.execute("select count(*) from t")
# 2000 rows from first insert + 1000 from last insert
assert cur.fetchone() == (3000,)
def test_global_disk_usage_limit(neon_env_builder: NeonEnvBuilder):
    """
    Similar to `test_timeline_disk_usage_limit`, but test that the global disk usage circuit breaker
    also works as expected. The test scenario:
    1. Create a timeline and endpoint.
    2. Mock high disk usage via failpoint.
    3. Write data to the timeline while the mocked disk usage exceeds the limit.
    4. Verify that the writes hang and the expected error message appears in the compute log.
    5. Mock low disk usage via failpoint.
    6. Verify that the hanging writes unblock and we can continue to write as normal.
    """
    neon_env_builder.num_safekeepers = 1
    remote_storage_kind = s3_storage()
    neon_env_builder.enable_safekeeper_remote_storage(remote_storage_kind)
    env = neon_env_builder.init_start()

    env.create_branch("test_global_disk_usage_limit")
    endpoint = env.endpoints.create_start("test_global_disk_usage_limit")

    with closing(endpoint.connect()) as conn:
        with conn.cursor() as cur:
            cur.execute("create table t2(key int, value text)")

    # Restart every safekeeper with the global disk usage check enabled: a 1s check
    # interval and a maximum usage ratio of 0.8.
    for sk in env.safekeepers:
        sk.stop().start(
            extra_opts=["--global-disk-check-interval=1s", "--max-global-disk-usage-ratio=0.8"]
        )

    # Set the failpoint to have the disk usage check return u64::MAX, which definitely exceeds the practical
    # limits in the test environment.
    for sk in env.safekeepers:
        sk.http_client().configure_failpoints(
            [("sk-global-disk-usage", "return(18446744073709551615)")]
        )

    # Wait until the global disk usage limit watcher trips the circuit breaker.
    def error_logged_in_sk():
        # Raises (so that `wait_until` keeps polling) until every safekeeper has logged the message.
        for sk in env.safekeepers:
            if sk.log_contains("Global disk usage exceeded limit") is None:
                raise Exception("Expected error message not found in safekeeper log yet")

    wait_until(error_logged_in_sk)

    def run_hanging_insert_global():
        # Runs on a background thread; expected to block while the circuit breaker is tripped.
        with closing(endpoint.connect()) as bg_conn:
            with bg_conn.cursor() as bg_cur:
                # This should generate more than 1KiB of WAL
                bg_cur.execute("insert into t2 select generate_series(1,2000), 'payload'")

    bg_thread_global = threading.Thread(target=run_hanging_insert_global)
    bg_thread_global.start()

    def error_logged_in_compute():
        if endpoint.log_contains("Global disk usage exceeded limit") is None:
            raise Exception("Expected error message not found in compute log yet")

    wait_until(error_logged_in_compute)
    log.info("Found the expected error message in compute log, resuming.")

    # Give the insert a moment; it must still be blocked while the mocked usage is high.
    time.sleep(2)
    assert bg_thread_global.is_alive(), "Global hanging insert unblocked prematurely!"

    # Make the disk usage check always return 0 through the failpoint to simulate the disk pressure easing.
    # The SKs should resume accepting WAL writes without restarting.
    for sk in env.safekeepers:
        sk.http_client().configure_failpoints([("sk-global-disk-usage", "return(0)")])

    # `join` returns silently on timeout, so liveness has to be checked explicitly afterwards.
    bg_thread_global.join(timeout=120)
    assert not bg_thread_global.is_alive(), (
        "Hanging global insert did not complete after disk usage dropped below the limit"
    )
    log.info("Global hanging insert unblocked.")

    # Verify that we can continue to write as normal and we don't have obvious data corruption
    # following the recovery.
    with closing(endpoint.connect()) as conn:
        with conn.cursor() as cur:
            cur.execute("insert into t2 select generate_series(2001,3000), 'payload'")

    with closing(endpoint.connect()) as conn:
        with conn.cursor() as cur:
            cur.execute("select count(*) from t2")
            # 2000 rows from the previously hanging insert + 1000 from the last insert
            assert cur.fetchone() == (3000,)

Some files were not shown because too many files have changed in this diff Show More