Compare commits

..

8 Commits

Author SHA1 Message Date
Conrad Ludgate
0e952cee8e fix deps 2025-07-11 14:46:56 +01:00
Conrad Ludgate
6d7ab67401 fix lints 2025-07-11 14:18:30 +01:00
Conrad Ludgate
3ff8afa32c update measured properly 2025-07-11 13:59:21 +01:00
Conrad Ludgate
682a54fa9e replace lasso completely 2025-07-11 13:59:21 +01:00
Conrad Ludgate
d5c17559ce bump 2025-07-11 13:59:21 +01:00
Conrad Ludgate
e8ccb4a4d1 fmt 2025-07-11 13:58:53 +01:00
Conrad Ludgate
bdd68bb069 bump 2025-07-11 13:58:53 +01:00
Conrad Ludgate
783260b88a switch to paracord 2025-07-11 13:57:59 +01:00
240 changed files with 3249 additions and 11676 deletions

View File

@@ -27,4 +27,4 @@
!storage_controller/
!vendor/postgres-*/
!workspace_hack/
!build-tools/patches
!build_tools/patches

View File

@@ -31,7 +31,6 @@ config-variables:
- NEON_PROD_AWS_ACCOUNT_ID
- PGREGRESS_PG16_PROJECT_ID
- PGREGRESS_PG17_PROJECT_ID
- PREWARM_PGBENCH_SIZE
- REMOTE_STORAGE_AZURE_CONTAINER
- REMOTE_STORAGE_AZURE_REGION
- SLACK_CICD_CHANNEL_ID

View File

@@ -176,13 +176,7 @@ runs:
fi
if [[ $BUILD_TYPE == "debug" && $RUNNER_ARCH == 'X64' ]]; then
# We don't use code coverage for regression tests (the step is disabled),
# so there's no need to collect it.
# Ref https://github.com/neondatabase/neon/issues/4540
# cov_prefix=(scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/coverage run)
cov_prefix=()
# Explicitly set LLVM_PROFILE_FILE to /dev/null to avoid writing *.profraw files
export LLVM_PROFILE_FILE=/dev/null
cov_prefix=(scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/coverage run)
else
cov_prefix=()
fi

View File

@@ -150,7 +150,7 @@ jobs:
secretKey: ${{ secrets.HETZNER_CACHE_SECRET_KEY }}
use-fallback: false
path: pg_install/v14
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ inputs.build-type }}-pg-${{ steps.pg_v14_rev.outputs.pg_rev }}-bookworm-${{ hashFiles('Makefile', 'build-tools/Dockerfile') }}
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ inputs.build-type }}-pg-${{ steps.pg_v14_rev.outputs.pg_rev }}-bookworm-${{ hashFiles('Makefile', 'build-tools.Dockerfile') }}
- name: Cache postgres v15 build
id: cache_pg_15
@@ -162,7 +162,7 @@ jobs:
secretKey: ${{ secrets.HETZNER_CACHE_SECRET_KEY }}
use-fallback: false
path: pg_install/v15
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ inputs.build-type }}-pg-${{ steps.pg_v15_rev.outputs.pg_rev }}-bookworm-${{ hashFiles('Makefile', 'build-tools/Dockerfile') }}
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ inputs.build-type }}-pg-${{ steps.pg_v15_rev.outputs.pg_rev }}-bookworm-${{ hashFiles('Makefile', 'build-tools.Dockerfile') }}
- name: Cache postgres v16 build
id: cache_pg_16
@@ -174,7 +174,7 @@ jobs:
secretKey: ${{ secrets.HETZNER_CACHE_SECRET_KEY }}
use-fallback: false
path: pg_install/v16
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ inputs.build-type }}-pg-${{ steps.pg_v16_rev.outputs.pg_rev }}-bookworm-${{ hashFiles('Makefile', 'build-tools/Dockerfile') }}
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ inputs.build-type }}-pg-${{ steps.pg_v16_rev.outputs.pg_rev }}-bookworm-${{ hashFiles('Makefile', 'build-tools.Dockerfile') }}
- name: Cache postgres v17 build
id: cache_pg_17
@@ -186,7 +186,7 @@ jobs:
secretKey: ${{ secrets.HETZNER_CACHE_SECRET_KEY }}
use-fallback: false
path: pg_install/v17
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ inputs.build-type }}-pg-${{ steps.pg_v17_rev.outputs.pg_rev }}-bookworm-${{ hashFiles('Makefile', 'build-tools/Dockerfile') }}
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ inputs.build-type }}-pg-${{ steps.pg_v17_rev.outputs.pg_rev }}-bookworm-${{ hashFiles('Makefile', 'build-tools.Dockerfile') }}
- name: Build all
# Note: the Makefile picks up BUILD_TYPE and CARGO_PROFILE from the env variables

View File

@@ -219,7 +219,6 @@ jobs:
--ignore test_runner/performance/test_cumulative_statistics_persistence.py
--ignore test_runner/performance/test_perf_many_relations.py
--ignore test_runner/performance/test_perf_oltp_large_tenant.py
--ignore test_runner/performance/test_lfc_prewarm.py
env:
BENCHMARK_CONNSTR: ${{ steps.create-neon-project.outputs.dsn }}
VIP_VAP_ACCESS_TOKEN: "${{ secrets.VIP_VAP_ACCESS_TOKEN }}"
@@ -411,77 +410,6 @@ jobs:
env:
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
prewarm-test:
if: ${{ github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null }}
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
env:
PGBENCH_SIZE: ${{ vars.PREWARM_PGBENCH_SIZE }}
POSTGRES_DISTRIB_DIR: /tmp/neon/pg_install
DEFAULT_PG_VERSION: 17
TEST_OUTPUT: /tmp/test_output
BUILD_TYPE: remote
SAVE_PERF_REPORT: ${{ github.event.inputs.save_perf_report || ( github.ref_name == 'main' ) }}
PLATFORM: "neon-staging"
runs-on: [ self-hosted, us-east-2, x64 ]
container:
image: ghcr.io/neondatabase/build-tools:pinned-bookworm
credentials:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
options: --init
steps:
- name: Harden the runner (Audit all outbound calls)
uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0
with:
egress-policy: audit
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 18000 # 5 hours
- name: Download Neon artifact
uses: ./.github/actions/download
with:
name: neon-${{ runner.os }}-${{ runner.arch }}-release-artifact
path: /tmp/neon/
prefix: latest
aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
- name: Run prewarm benchmark
uses: ./.github/actions/run-python-test-set
with:
build_type: ${{ env.BUILD_TYPE }}
test_selection: performance/test_lfc_prewarm.py
run_in_parallel: false
save_perf_report: ${{ env.SAVE_PERF_REPORT }}
extra_params: -m remote_cluster --timeout 5400
pg_version: ${{ env.DEFAULT_PG_VERSION }}
aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
env:
VIP_VAP_ACCESS_TOKEN: "${{ secrets.VIP_VAP_ACCESS_TOKEN }}"
PERF_TEST_RESULT_CONNSTR: "${{ secrets.PERF_TEST_RESULT_CONNSTR }}"
NEON_API_KEY: ${{ secrets.NEON_STAGING_API_KEY }}
- name: Create Allure report
id: create-allure-report
if: ${{ !cancelled() }}
uses: ./.github/actions/allure-report-generate
with:
store-test-results-into-db: true
aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
env:
REGRESS_TEST_RESULT_CONNSTR_NEW: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}
generate-matrices:
if: ${{ github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null }}
# Create matrices for the benchmarking jobs, so we run benchmarks on rds only once a week (on Saturday)

View File

@@ -72,7 +72,7 @@ jobs:
ARCHS: ${{ inputs.archs || '["x64","arm64"]' }}
DEBIANS: ${{ inputs.debians || '["bullseye","bookworm"]' }}
IMAGE_TAG: |
${{ hashFiles('build-tools/Dockerfile',
${{ hashFiles('build-tools.Dockerfile',
'.github/workflows/build-build-tools-image.yml') }}
run: |
echo "archs=${ARCHS}" | tee -a ${GITHUB_OUTPUT}
@@ -144,7 +144,7 @@ jobs:
- uses: docker/build-push-action@471d1dc4e07e5cdedd4c2171150001c434f0b7a4 # v6.15.0
with:
file: build-tools/Dockerfile
file: build-tools.Dockerfile
context: .
provenance: false
push: true

View File

@@ -87,27 +87,22 @@ jobs:
uses: ./.github/workflows/build-build-tools-image.yml
secrets: inherit
lint-yamls:
needs: [ meta, check-permissions, build-build-tools-image ]
lint-openapi-spec:
runs-on: ubuntu-22.04
needs: [ meta, check-permissions ]
# We do need to run this in `.*-rc-pr` because of hotfixes.
if: ${{ contains(fromJSON('["pr", "push-main", "storage-rc-pr", "proxy-rc-pr", "compute-rc-pr"]'), needs.meta.outputs.run-kind) }}
runs-on: [ self-hosted, small ]
container:
image: ${{ needs.build-build-tools-image.outputs.image }}
credentials:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
options: --init
steps:
- name: Harden the runner (Audit all outbound calls)
uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0
with:
egress-policy: audit
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- run: make -C compute manifest-schema-validation
- uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- run: make lint-openapi-spec
check-codestyle-python:
@@ -222,6 +217,28 @@ jobs:
build-tools-image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm
secrets: inherit
validate-compute-manifest:
runs-on: ubuntu-22.04
needs: [ meta, check-permissions ]
# We do need to run this in `.*-rc-pr` because of hotfixes.
if: ${{ contains(fromJSON('["pr", "push-main", "storage-rc-pr", "proxy-rc-pr", "compute-rc-pr"]'), needs.meta.outputs.run-kind) }}
steps:
- name: Harden the runner (Audit all outbound calls)
uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0
with:
egress-policy: audit
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Node.js
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
with:
node-version: '24'
- name: Validate manifest against schema
run: |
make -C compute manifest-schema-validation
build-and-test-locally:
needs: [ meta, build-build-tools-image ]
# We do need to run this in `.*-rc-pr` because of hotfixes.

3
.gitignore vendored
View File

@@ -29,6 +29,3 @@ docker-compose/docker-compose-parallel.yml
# pgindent typedef lists
*.list
# Node
**/node_modules/

8
.gitmodules vendored
View File

@@ -1,16 +1,16 @@
[submodule "vendor/postgres-v14"]
path = vendor/postgres-v14
url = ../postgres.git
url = https://github.com/neondatabase/postgres.git
branch = REL_14_STABLE_neon
[submodule "vendor/postgres-v15"]
path = vendor/postgres-v15
url = ../postgres.git
url = https://github.com/neondatabase/postgres.git
branch = REL_15_STABLE_neon
[submodule "vendor/postgres-v16"]
path = vendor/postgres-v16
url = ../postgres.git
url = https://github.com/neondatabase/postgres.git
branch = REL_16_STABLE_neon
[submodule "vendor/postgres-v17"]
path = vendor/postgres-v17
url = ../postgres.git
url = https://github.com/neondatabase/postgres.git
branch = REL_17_STABLE_neon

228
Cargo.lock generated
View File

@@ -1009,6 +1009,12 @@ dependencies = [
"generic-array",
]
[[package]]
name = "boxcar"
version = "0.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26c4925bc979b677330a8c7fe7a8c94af2dbb4a2d37b4a20a80d884400f46baa"
[[package]]
name = "bstr"
version = "1.5.0"
@@ -1242,14 +1248,14 @@ checksum = "afb84c814227b90d6895e01398aee0d8033c00e7466aca416fb6a8e0eb19d8a7"
[[package]]
name = "clashmap"
version = "1.0.0"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93bd59c81e2bd87a775ae2de75f070f7e2bfe97363a6ad652f46824564c23e4d"
checksum = "e8a055b1f1bf558eae4959f6dd77cf2d7d50ae1483928e60ef21ca5a24fd4321"
dependencies = [
"crossbeam-utils",
"hashbrown 0.15.2",
"lock_api",
"parking_lot_core 0.9.8",
"parking_lot_core 0.9.10",
"polonius-the-crab",
"replace_with",
]
@@ -1753,19 +1759,6 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "dashmap"
version = "5.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6943ae99c34386c84a470c499d3414f66502a41340aa895406e0d2e4a207b91d"
dependencies = [
"cfg-if",
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core 0.9.8",
]
[[package]]
name = "dashmap"
version = "6.1.0"
@@ -1777,7 +1770,7 @@ dependencies = [
"hashbrown 0.14.5",
"lock_api",
"once_cell",
"parking_lot_core 0.9.8",
"parking_lot_core 0.9.10",
]
[[package]]
@@ -1872,7 +1865,6 @@ dependencies = [
"diesel_derives",
"itoa",
"serde_json",
"uuid",
]
[[package]]
@@ -2344,6 +2336,12 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foldhash"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "form_urlencoded"
version = "1.2.1"
@@ -2534,18 +2532,6 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "getrandom"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
dependencies = [
"cfg-if",
"libc",
"r-efi",
"wasi 0.14.2+wasi-0.2.4",
]
[[package]]
name = "gettid"
version = "0.1.3"
@@ -2595,7 +2581,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "842dc78579ce01e6a1576ad896edc92fca002dd60c9c3746b7fc2bec6fb429d0"
dependencies = [
"cfg-if",
"dashmap 6.1.0",
"dashmap",
"futures-sink",
"futures-timer",
"futures-util",
@@ -3301,7 +3287,7 @@ dependencies = [
"clap",
"crossbeam-channel",
"crossbeam-utils",
"dashmap 6.1.0",
"dashmap",
"env_logger",
"indexmap 2.9.0",
"itoa",
@@ -3558,16 +3544,6 @@ dependencies = [
"libc",
]
[[package]]
name = "lasso"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4644821e1c3d7a560fe13d842d13f587c07348a1a05d3a797152d41c90c56df2"
dependencies = [
"dashmap 5.5.0",
"hashbrown 0.13.2",
]
[[package]]
name = "lazy_static"
version = "1.5.0"
@@ -3619,9 +3595,9 @@ checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104"
[[package]]
name = "lock_api"
version = "0.4.13"
version = "0.4.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765"
checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16"
dependencies = [
"autocfg",
"scopeguard",
@@ -3687,17 +3663,17 @@ checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
[[package]]
name = "measured"
version = "0.0.22"
version = "0.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3051f3a030d55d680cdef6ca50e80abd1182f8da29f2344a7c9cb575721138f0"
checksum = "d22ae866c28b9c59afaeb488ad6e3bd148570cf5a2bacf6c4845def5b9a03470"
dependencies = [
"bytes",
"crossbeam-utils",
"hashbrown 0.14.5",
"itoa",
"lasso",
"measured-derive",
"memchr",
"paracord",
"parking_lot 0.12.1",
"rustc-hash 1.1.0",
"ryu",
@@ -3705,9 +3681,9 @@ dependencies = [
[[package]]
name = "measured-derive"
version = "0.0.22"
version = "0.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e6777fc80a575f9503d908c8b498782a6c3ee88a06cb416dc3941401e43b94"
checksum = "16d734ed9dbca22d87d56b54d220f254ce921cb5cce97d4a960075af0131d076"
dependencies = [
"heck",
"proc-macro2",
@@ -3717,15 +3693,26 @@ dependencies = [
[[package]]
name = "measured-process"
version = "0.0.22"
version = "0.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c4b80445aeb08e832d87bf1830049a924cdc1d6b7ef40b6b9b365bff17bf8ec"
checksum = "f71a318d2b9edcded1e5e0ccf3c9e1e1614217d7f07933631c771daa717743aa"
dependencies = [
"libc",
"measured",
"procfs",
]
[[package]]
name = "measured-tokio"
version = "0.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e79a936051c484268e2d71a5dd01219096c62c8ff09afad95f07c14a3b0c580"
dependencies = [
"itoa",
"measured",
"tokio",
]
[[package]]
name = "memchr"
version = "2.6.4"
@@ -3771,7 +3758,7 @@ dependencies = [
"procfs",
"prometheus",
"rand 0.8.5",
"rand_distr 0.4.3",
"rand_distr",
"twox-hash",
]
@@ -3859,12 +3846,7 @@ checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
name = "neon-shmem"
version = "0.1.0"
dependencies = [
"libc",
"lock_api",
"nix 0.30.1",
"rand 0.9.1",
"rand_distr 0.5.1",
"rustc-hash 2.1.1",
"tempfile",
"thiserror 1.0.69",
"workspace_hack",
@@ -4314,7 +4296,6 @@ dependencies = [
"pageserver_client",
"pageserver_client_grpc",
"pageserver_page_api",
"pprof",
"rand 0.8.5",
"reqwest",
"serde",
@@ -4590,6 +4571,21 @@ dependencies = [
"seize",
]
[[package]]
name = "paracord"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e41e924113f9a05eecd561c6e695648f243a03b3ad8d9b3eff689342b95023b"
dependencies = [
"boxcar",
"clashmap",
"foldhash",
"hashbrown 0.15.2",
"serde",
"sync_wrapper 1.0.1",
"typed-arena",
]
[[package]]
name = "parking"
version = "2.1.1"
@@ -4614,7 +4610,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f"
dependencies = [
"lock_api",
"parking_lot_core 0.9.8",
"parking_lot_core 0.9.10",
]
[[package]]
@@ -4633,15 +4629,15 @@ dependencies = [
[[package]]
name = "parking_lot_core"
version = "0.9.8"
version = "0.9.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447"
checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8"
dependencies = [
"cfg-if",
"libc",
"redox_syscall 0.3.5",
"redox_syscall 0.5.10",
"smallvec",
"windows-targets 0.48.0",
"windows-targets 0.52.6",
]
[[package]]
@@ -5308,7 +5304,6 @@ dependencies = [
"async-trait",
"atomic-take",
"aws-config",
"aws-credential-types",
"aws-sdk-iam",
"aws-sigv4",
"base64 0.22.1",
@@ -5348,14 +5343,14 @@ dependencies = [
"itoa",
"jose-jwa",
"jose-jwk",
"json",
"lasso",
"measured",
"measured-tokio",
"metrics",
"once_cell",
"opentelemetry",
"p256 0.13.2",
"papaya",
"paracord",
"parking_lot 0.12.1",
"parquet",
"parquet_derive",
@@ -5365,7 +5360,7 @@ dependencies = [
"postgres_backend",
"pq_proto",
"rand 0.8.5",
"rand_distr 0.4.3",
"rand_distr",
"rcgen",
"redis",
"regex",
@@ -5376,7 +5371,7 @@ dependencies = [
"reqwest-tracing",
"rsa",
"rstest",
"rustc-hash 2.1.1",
"rustc-hash 1.1.0",
"rustls 0.23.27",
"rustls-native-certs 0.8.0",
"rustls-pemfile 2.1.1",
@@ -5406,7 +5401,6 @@ dependencies = [
"tracing-test",
"tracing-utils",
"try-lock",
"type-safe-id",
"typed-json",
"url",
"urlencoding",
@@ -5470,12 +5464,6 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "r-efi"
version = "5.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
[[package]]
name = "rand"
version = "0.7.3"
@@ -5500,16 +5488,6 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "rand"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
dependencies = [
"rand_chacha 0.9.0",
"rand_core 0.9.3",
]
[[package]]
name = "rand_chacha"
version = "0.2.2"
@@ -5530,16 +5508,6 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "rand_chacha"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core 0.9.3",
]
[[package]]
name = "rand_core"
version = "0.5.1"
@@ -5558,15 +5526,6 @@ dependencies = [
"getrandom 0.2.11",
]
[[package]]
name = "rand_core"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
dependencies = [
"getrandom 0.3.3",
]
[[package]]
name = "rand_distr"
version = "0.4.3"
@@ -5577,16 +5536,6 @@ dependencies = [
"rand 0.8.5",
]
[[package]]
name = "rand_distr"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463"
dependencies = [
"num-traits",
"rand 0.9.1",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
@@ -5704,6 +5653,15 @@ dependencies = [
"bitflags 1.3.2",
]
[[package]]
name = "redox_syscall"
version = "0.5.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b8c0c260b63a8219631167be35e6a988e9554dbd323f8bd08439c8ed1302bd1"
dependencies = [
"bitflags 2.8.0",
]
[[package]]
name = "regex"
version = "1.10.2"
@@ -6268,7 +6226,6 @@ dependencies = [
"itertools 0.10.5",
"jsonwebtoken",
"metrics",
"nix 0.30.1",
"once_cell",
"pageserver_api",
"parking_lot 0.12.1",
@@ -6276,7 +6233,6 @@ dependencies = [
"postgres-protocol",
"postgres_backend",
"postgres_ffi",
"postgres_ffi_types",
"postgres_versioninfo",
"pprof",
"pq_proto",
@@ -6321,7 +6277,7 @@ dependencies = [
"anyhow",
"const_format",
"pageserver_api",
"postgres_ffi_types",
"postgres_ffi",
"postgres_versioninfo",
"pq_proto",
"serde",
@@ -6966,12 +6922,12 @@ dependencies = [
"hyper 0.14.30",
"itertools 0.10.5",
"json-structural-diff",
"lasso",
"measured",
"metrics",
"once_cell",
"pageserver_api",
"pageserver_client",
"paracord",
"postgres_connection",
"posthog_client_lite",
"rand 0.8.5",
@@ -6997,7 +6953,6 @@ dependencies = [
"tokio-util",
"tracing",
"utils",
"uuid",
"workspace_hack",
]
@@ -7631,7 +7586,6 @@ dependencies = [
"futures-core",
"pin-project-lite",
"tokio",
"tokio-util",
]
[[package]]
@@ -8089,17 +8043,10 @@ dependencies = [
]
[[package]]
name = "type-safe-id"
version = "0.3.3"
name = "typed-arena"
version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd9267f90719e0433aae095640b294ff36ccbf89649ecb9ee34464ec504be157"
dependencies = [
"arrayvec",
"rand 0.9.1",
"serde",
"thiserror 2.0.11",
"uuid",
]
checksum = "6af6ae20167a9ece4bcb41af5b80f8a1f1df981f6391189ce00fd257af04126a"
[[package]]
name = "typed-json"
@@ -8284,7 +8231,6 @@ dependencies = [
"tracing-error",
"tracing-subscriber",
"tracing-utils",
"uuid",
"walkdir",
]
@@ -8427,15 +8373,6 @@ version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wasi"
version = "0.14.2+wasi-0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3"
dependencies = [
"wit-bindgen-rt",
]
[[package]]
name = "wasite"
version = "0.1.0"
@@ -8793,15 +8730,6 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "wit-bindgen-rt"
version = "0.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
dependencies = [
"bitflags 2.8.0",
]
[[package]]
name = "workspace_hack"
version = "0.1.0"
@@ -8864,6 +8792,7 @@ dependencies = [
"num-traits",
"once_cell",
"p256 0.13.2",
"paracord",
"parquet",
"prettyplease",
"proc-macro2",
@@ -8904,6 +8833,7 @@ dependencies = [
"tracing-log",
"tracing-subscriber",
"url",
"uuid",
"zeroize",
"zstd",
"zstd-safe",

View File

@@ -128,12 +128,11 @@ itertools = "0.10"
itoa = "1.0.11"
jemalloc_pprof = { version = "0.7", features = ["symbolize", "flamegraph"] }
jsonwebtoken = "9"
lasso = "0.7"
libc = "0.2"
lock_api = "0.4.13"
md5 = "0.7.0"
measured = { version = "0.0.22", features=["lasso"] }
measured-process = { version = "0.0.22" }
measured = { version = "0.0.23", features = ["paracord"] }
measured-process = { version = "0.0.23" }
measured-tokio = { version = "0.0.23" }
memoffset = "0.9"
nix = { version = "0.30.1", features = ["dir", "fs", "mman", "process", "socket", "signal", "poll"] }
# Do not update to >= 7.0.0, at least. The update will have a significant impact
@@ -146,6 +145,7 @@ opentelemetry = "0.27"
opentelemetry_sdk = "0.27"
opentelemetry-otlp = { version = "0.27", default-features = false, features = ["http-proto", "trace", "http", "reqwest-client"] }
opentelemetry-semantic-conventions = "0.27"
paracord = { version = "0.1.0", features = ["serde"] }
parking_lot = "0.12"
parquet = { version = "53", default-features = false, features = ["zstd"] }
parquet_derive = "53"
@@ -166,7 +166,7 @@ reqwest-middleware = "0.4"
reqwest-retry = "0.7"
routerify = "3"
rpds = "0.13"
rustc-hash = "2.1.1"
rustc-hash = "1.1.0"
rustls = { version = "0.23.16", default-features = false }
rustls-pemfile = "2"
rustls-pki-types = "1.11"
@@ -202,7 +202,7 @@ tokio-epoll-uring = { git = "https://github.com/neondatabase/tokio-epoll-uring.g
tokio-io-timeout = "1.2.0"
tokio-postgres-rustls = "0.12.0"
tokio-rustls = { version = "0.26.0", default-features = false, features = ["tls12", "ring"]}
tokio-stream = { version = "0.1", features = ["sync"] }
tokio-stream = "0.1"
tokio-tar = "0.3"
tokio-util = { version = "0.7.10", features = ["io", "io-util", "rt"] }
toml = "0.8"

View File

@@ -2,7 +2,7 @@ ROOT_PROJECT_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
# Where to install Postgres, default is ./pg_install, maybe useful for package
# managers.
POSTGRES_INSTALL_DIR ?= $(ROOT_PROJECT_DIR)/pg_install
POSTGRES_INSTALL_DIR ?= $(ROOT_PROJECT_DIR)/pg_install/
# Supported PostgreSQL versions
POSTGRES_VERSIONS = v17 v16 v15 v14
@@ -14,7 +14,7 @@ POSTGRES_VERSIONS = v17 v16 v15 v14
# it is derived from BUILD_TYPE.
# All intermediate build artifacts are stored here.
BUILD_DIR := $(ROOT_PROJECT_DIR)/build
BUILD_DIR := build
ICU_PREFIX_DIR := /usr/local/icu
@@ -212,7 +212,7 @@ neon-pgindent: postgres-v17-pg-bsd-indent neon-pg-ext-v17
FIND_TYPEDEF=$(ROOT_PROJECT_DIR)/vendor/postgres-v17/src/tools/find_typedef \
INDENT=$(BUILD_DIR)/v17/src/tools/pg_bsd_indent/pg_bsd_indent \
PGINDENT_SCRIPT=$(ROOT_PROJECT_DIR)/vendor/postgres-v17/src/tools/pgindent/pgindent \
-C $(BUILD_DIR)/pgxn-v17/neon \
-C $(BUILD_DIR)/neon-v17 \
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile pgindent
@@ -220,15 +220,11 @@ neon-pgindent: postgres-v17-pg-bsd-indent neon-pg-ext-v17
setup-pre-commit-hook:
ln -s -f $(ROOT_PROJECT_DIR)/pre-commit.py .git/hooks/pre-commit
build-tools/node_modules: build-tools/package.json
cd build-tools && $(if $(CI),npm ci,npm install)
touch build-tools/node_modules
.PHONY: lint-openapi-spec
lint-openapi-spec: build-tools/node_modules
lint-openapi-spec:
# operation-2xx-response: pageserver timeline delete returns 404 on success
find . -iname "openapi_spec.y*ml" -exec\
npx --prefix=build-tools/ redocly\
docker run --rm -v ${PWD}:/spec ghcr.io/redocly/cli:1.34.4\
--skip-rule=operation-operationId --skip-rule=operation-summary --extends=minimal\
--skip-rule=no-server-example.com --skip-rule=operation-2xx-response\
lint {} \+

View File

@@ -35,7 +35,7 @@ RUN echo 'Acquire::Retries "5";' > /etc/apt/apt.conf.d/80-retries && \
echo -e "retry_connrefused=on\ntimeout=15\ntries=5\nretry-on-host-error=on\n" > /root/.wgetrc && \
echo -e "--retry-connrefused\n--connect-timeout 15\n--retry 5\n--max-time 300\n" > /root/.curlrc
COPY build-tools/patches/pgcopydbv017.patch /pgcopydbv017.patch
COPY build_tools/patches/pgcopydbv017.patch /pgcopydbv017.patch
RUN if [ "${DEBIAN_VERSION}" = "bookworm" ]; then \
set -e && \
@@ -188,12 +188,6 @@ RUN curl -fsSL 'https://apt.llvm.org/llvm-snapshot.gpg.key' | apt-key add - \
&& bash -c 'for f in /usr/bin/clang*-${LLVM_VERSION} /usr/bin/llvm*-${LLVM_VERSION}; do ln -s "${f}" "${f%-${LLVM_VERSION}}"; done' \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
# Install node
ENV NODE_VERSION=24
RUN curl -fsSL https://deb.nodesource.com/setup_${NODE_VERSION}.x | bash - \
&& apt install -y nodejs \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
# Install docker
RUN curl -fsSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /usr/share/keyrings/docker-archive-keyring.gpg \
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/debian ${DEBIAN_VERSION} stable" > /etc/apt/sources.list.d/docker.list \
@@ -317,14 +311,14 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux
. "$HOME/.cargo/env" && \
cargo --version && rustup --version && \
rustup component add llvm-tools rustfmt clippy && \
cargo install rustfilt --locked --version ${RUSTFILT_VERSION} && \
cargo install cargo-hakari --locked --version ${CARGO_HAKARI_VERSION} && \
cargo install cargo-deny --locked --version ${CARGO_DENY_VERSION} && \
cargo install cargo-hack --locked --version ${CARGO_HACK_VERSION} && \
cargo install cargo-nextest --locked --version ${CARGO_NEXTEST_VERSION} && \
cargo install cargo-chef --locked --version ${CARGO_CHEF_VERSION} && \
cargo install diesel_cli --locked --version ${CARGO_DIESEL_CLI_VERSION} \
--features postgres-bundled --no-default-features && \
cargo install rustfilt --version ${RUSTFILT_VERSION} --locked && \
cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} --locked && \
cargo install cargo-deny --version ${CARGO_DENY_VERSION} --locked && \
cargo install cargo-hack --version ${CARGO_HACK_VERSION} --locked && \
cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} --locked && \
cargo install cargo-chef --version ${CARGO_CHEF_VERSION} --locked && \
cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} --locked \
--features postgres-bundled --no-default-features && \
rm -rf /home/nonroot/.cargo/registry && \
rm -rf /home/nonroot/.cargo/git

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +0,0 @@
{
"name": "build-tools",
"private": true,
"devDependencies": {
"@redocly/cli": "1.34.4",
"@sourcemeta/jsonschema": "10.0.0"
}
}

View File

@@ -50,9 +50,9 @@ jsonnetfmt-format:
jsonnetfmt --in-place $(jsonnet_files)
.PHONY: manifest-schema-validation
manifest-schema-validation: ../build-tools/node_modules
npx --prefix=../build-tools/ jsonschema validate -d https://json-schema.org/draft/2020-12/schema manifest.schema.json manifest.yaml
manifest-schema-validation: node_modules
node_modules/.bin/jsonschema validate -d https://json-schema.org/draft/2020-12/schema manifest.schema.json manifest.yaml
../build-tools/node_modules: ../build-tools/package.json
cd ../build-tools && $(if $(CI),npm ci,npm install)
touch ../build-tools/node_modules
node_modules: package.json
npm install
touch node_modules

View File

@@ -9,7 +9,7 @@
#
# build-tools: This contains Rust compiler toolchain and other tools needed at compile
# time. This is also used for the storage builds. This image is defined in
# build-tools/Dockerfile.
# build-tools.Dockerfile.
#
# build-deps: Contains C compiler, other build tools, and compile-time dependencies
# needed to compile PostgreSQL and most extensions. (Some extensions need
@@ -115,7 +115,7 @@ ARG EXTENSIONS=all
FROM $BASE_IMAGE_SHA AS build-deps
ARG DEBIAN_VERSION
# Keep in sync with build-tools/Dockerfile
# Keep in sync with build-tools.Dockerfile
ENV PROTOC_VERSION=25.1
# Use strict mode for bash to catch errors early
@@ -170,29 +170,7 @@ RUN case $DEBIAN_VERSION in \
FROM build-deps AS pg-build
ARG PG_VERSION
COPY vendor/postgres-${PG_VERSION:?} postgres
COPY compute/patches/postgres_fdw.patch .
COPY compute/patches/pg_stat_statements_pg14-16.patch .
COPY compute/patches/pg_stat_statements_pg17.patch .
RUN cd postgres && \
# Apply patches to some contrib extensions
# For example, we need to grant EXECUTE on pg_stat_statements_reset() to {privileged_role_name}.
# In vanilla Postgres this function is limited to Postgres role superuser.
# In Neon we have {privileged_role_name} role that is not a superuser but replaces superuser in some cases.
# We could add the additional grant statements to the Postgres repository but it would be hard to maintain,
# whenever we need to pick up a new Postgres version and we want to limit the changes in our Postgres fork,
# so we do it here.
case "${PG_VERSION}" in \
"v14" | "v15" | "v16") \
patch -p1 < /pg_stat_statements_pg14-16.patch; \
;; \
"v17") \
patch -p1 < /pg_stat_statements_pg17.patch; \
;; \
*) \
# To do not forget to migrate patches to the next major version
echo "No contrib patches for this PostgreSQL version" && exit 1;; \
esac && \
patch -p1 < /postgres_fdw.patch && \
export CONFIGURE_CMD="./configure CFLAGS='-O2 -g3 -fsigned-char' --enable-debug --with-openssl --with-uuid=ossp \
--with-icu --with-libxml --with-libxslt --with-lz4" && \
if [ "${PG_VERSION:?}" != "v14" ]; then \
@@ -206,6 +184,8 @@ RUN cd postgres && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/autoinc.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/dblink.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/postgres_fdw.control && \
file=/usr/local/pgsql/share/extension/postgres_fdw--1.0.sql && [ -e $file ] && \
echo 'GRANT USAGE ON FOREIGN DATA WRAPPER postgres_fdw TO neon_superuser;' >> $file && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/bloom.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/earthdistance.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/insert_username.control && \
@@ -215,7 +195,34 @@ RUN cd postgres && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgrowlocks.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgstattuple.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/refint.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/xml2.control
echo 'trusted = true' >> /usr/local/pgsql/share/extension/xml2.control && \
# We need to grant EXECUTE on pg_stat_statements_reset() to neon_superuser.
# In vanilla postgres this function is limited to Postgres role superuser.
# In neon we have neon_superuser role that is not a superuser but replaces superuser in some cases.
# We could add the additional grant statements to the postgres repository but it would be hard to maintain,
# whenever we need to pick up a new postgres version and we want to limit the changes in our postgres fork,
# so we do it here.
for file in /usr/local/pgsql/share/extension/pg_stat_statements--*.sql; do \
filename=$(basename "$file"); \
# Note that there are no downgrade scripts for pg_stat_statements, so we \
# don't have to modify any downgrade paths or (much) older versions: we only \
# have to make sure every creation of the pg_stat_statements_reset function \
# also adds execute permissions to the neon_superuser.
case $filename in \
pg_stat_statements--1.4.sql) \
# pg_stat_statements_reset is first created with 1.4
echo 'GRANT EXECUTE ON FUNCTION pg_stat_statements_reset() TO neon_superuser;' >> $file; \
;; \
pg_stat_statements--1.6--1.7.sql) \
# Then with the 1.6-1.7 migration it is re-created with a new signature, thus add the permissions back
echo 'GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) TO neon_superuser;' >> $file; \
;; \
pg_stat_statements--1.10--1.11.sql) \
# Then with the 1.10-1.11 migration it is re-created with a new signature again, thus add the permissions back
echo 'GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint, boolean) TO neon_superuser;' >> $file; \
;; \
esac; \
done;
# Set PATH for all the subsequent build steps
ENV PATH="/usr/local/pgsql/bin:$PATH"
@@ -1517,7 +1524,7 @@ WORKDIR /ext-src
COPY compute/patches/pg_duckdb_v031.patch .
COPY compute/patches/duckdb_v120.patch .
# pg_duckdb build requires source dir to be a git repo to get submodules
# allow {privileged_role_name} to execute some functions that in pg_duckdb are available to superuser only:
# allow neon_superuser to execute some functions that in pg_duckdb are available to superuser only:
# - extension management function duckdb.install_extension()
# - access to duckdb.extensions table and its sequence
RUN git clone --depth 1 --branch v0.3.1 https://github.com/duckdb/pg_duckdb.git pg_duckdb-src && \
@@ -1783,7 +1790,7 @@ RUN set -e \
#########################################################################################
FROM build-deps AS exporters
ARG TARGETARCH
# Keep sql_exporter version same as in build-tools/Dockerfile and
# Keep sql_exporter version same as in build-tools.Dockerfile and
# test_runner/regress/test_compute_metrics.py
# See comment on the top of the file regading `echo`, `-e` and `\n`
RUN if [ "$TARGETARCH" = "amd64" ]; then\

7
compute/package.json Normal file
View File

@@ -0,0 +1,7 @@
{
"name": "neon-compute",
"private": true,
"dependencies": {
"@sourcemeta/jsonschema": "9.3.4"
}
}

View File

@@ -1,26 +1,22 @@
diff --git a/sql/anon.sql b/sql/anon.sql
index 0cdc769..5eab1d6 100644
index 0cdc769..b450327 100644
--- a/sql/anon.sql
+++ b/sql/anon.sql
@@ -1141,3 +1141,19 @@ $$
@@ -1141,3 +1141,15 @@ $$
-- TODO : https://en.wikipedia.org/wiki/L-diversity
-- TODO : https://en.wikipedia.org/wiki/T-closeness
+
+-- NEON Patches
+
+GRANT ALL ON SCHEMA anon to neon_superuser;
+GRANT ALL ON ALL TABLES IN SCHEMA anon TO neon_superuser;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT ALL ON SCHEMA anon to %I', privileged_role_name);
+ EXECUTE format('GRANT ALL ON ALL TABLES IN SCHEMA anon TO %I', privileged_role_name);
+
+ IF current_setting('server_version_num')::int >= 150000 THEN
+ EXECUTE format('GRANT SET ON PARAMETER anon.transparent_dynamic_masking TO %I', privileged_role_name);
+ END IF;
+ IF current_setting('server_version_num')::int >= 150000 THEN
+ GRANT SET ON PARAMETER anon.transparent_dynamic_masking TO neon_superuser;
+ END IF;
+END $$;
diff --git a/sql/init.sql b/sql/init.sql
index 7da6553..9b6164b 100644

View File

@@ -21,21 +21,13 @@ index 3235cc8..6b892bc 100644
include Makefile.global
diff --git a/sql/pg_duckdb--0.2.0--0.3.0.sql b/sql/pg_duckdb--0.2.0--0.3.0.sql
index d777d76..3b54396 100644
index d777d76..af60106 100644
--- a/sql/pg_duckdb--0.2.0--0.3.0.sql
+++ b/sql/pg_duckdb--0.2.0--0.3.0.sql
@@ -1056,3 +1056,14 @@ GRANT ALL ON FUNCTION duckdb.cache(TEXT, TEXT) TO PUBLIC;
@@ -1056,3 +1056,6 @@ GRANT ALL ON FUNCTION duckdb.cache(TEXT, TEXT) TO PUBLIC;
GRANT ALL ON FUNCTION duckdb.cache_info() TO PUBLIC;
GRANT ALL ON FUNCTION duckdb.cache_delete(TEXT) TO PUBLIC;
GRANT ALL ON PROCEDURE duckdb.recycle_ddb() TO PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT ALL ON FUNCTION duckdb.install_extension(TEXT) TO %I', privileged_role_name);
+ EXECUTE format('GRANT ALL ON TABLE duckdb.extensions TO %I', privileged_role_name);
+ EXECUTE format('GRANT ALL ON SEQUENCE duckdb.extensions_table_seq TO %I', privileged_role_name);
+END $$;
+GRANT ALL ON FUNCTION duckdb.install_extension(TEXT) TO neon_superuser;
+GRANT ALL ON TABLE duckdb.extensions TO neon_superuser;
+GRANT ALL ON SEQUENCE duckdb.extensions_table_seq TO neon_superuser;

View File

@@ -1,34 +0,0 @@
diff --git a/contrib/pg_stat_statements/pg_stat_statements--1.4.sql b/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
index 58cdf600fce..8be57a996f6 100644
--- a/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
+++ b/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
@@ -46,3 +46,12 @@ GRANT SELECT ON pg_stat_statements TO PUBLIC;
-- Don't want this to be available to non-superusers.
REVOKE ALL ON FUNCTION pg_stat_statements_reset() FROM PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT EXECUTE ON FUNCTION pg_stat_statements_reset() TO %I', privileged_role_name);
+END $$;
diff --git a/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql b/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
index 6fc3fed4c93..256345a8f79 100644
--- a/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
+++ b/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
@@ -20,3 +20,12 @@ LANGUAGE C STRICT PARALLEL SAFE;
-- Don't want this to be available to non-superusers.
REVOKE ALL ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) FROM PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) TO %I', privileged_role_name);
+END $$;

View File

@@ -1,52 +0,0 @@
diff --git a/contrib/pg_stat_statements/pg_stat_statements--1.10--1.11.sql b/contrib/pg_stat_statements/pg_stat_statements--1.10--1.11.sql
index 0bb2c397711..32764db1d8b 100644
--- a/contrib/pg_stat_statements/pg_stat_statements--1.10--1.11.sql
+++ b/contrib/pg_stat_statements/pg_stat_statements--1.10--1.11.sql
@@ -80,3 +80,12 @@ LANGUAGE C STRICT PARALLEL SAFE;
-- Don't want this to be available to non-superusers.
REVOKE ALL ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint, boolean) FROM PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint, boolean) TO %I', privileged_role_name);
+END $$;
\ No newline at end of file
diff --git a/contrib/pg_stat_statements/pg_stat_statements--1.4.sql b/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
index 58cdf600fce..8be57a996f6 100644
--- a/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
+++ b/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
@@ -46,3 +46,12 @@ GRANT SELECT ON pg_stat_statements TO PUBLIC;
-- Don't want this to be available to non-superusers.
REVOKE ALL ON FUNCTION pg_stat_statements_reset() FROM PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT EXECUTE ON FUNCTION pg_stat_statements_reset() TO %I', privileged_role_name);
+END $$;
diff --git a/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql b/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
index 6fc3fed4c93..256345a8f79 100644
--- a/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
+++ b/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
@@ -20,3 +20,12 @@ LANGUAGE C STRICT PARALLEL SAFE;
-- Don't want this to be available to non-superusers.
REVOKE ALL ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) FROM PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) TO %I', privileged_role_name);
+END $$;

View File

@@ -1,17 +0,0 @@
diff --git a/contrib/postgres_fdw/postgres_fdw--1.0.sql b/contrib/postgres_fdw/postgres_fdw--1.0.sql
index a0f0fc1bf45..ee077f2eea6 100644
--- a/contrib/postgres_fdw/postgres_fdw--1.0.sql
+++ b/contrib/postgres_fdw/postgres_fdw--1.0.sql
@@ -16,3 +16,12 @@ LANGUAGE C STRICT;
CREATE FOREIGN DATA WRAPPER postgres_fdw
HANDLER postgres_fdw_handler
VALIDATOR postgres_fdw_validator;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT USAGE ON FOREIGN DATA WRAPPER postgres_fdw TO %I', privileged_role_name);
+END $$;

View File

@@ -87,14 +87,6 @@ struct Cli {
#[arg(short = 'C', long, value_name = "DATABASE_URL")]
pub connstr: String,
#[arg(
long,
default_value = "neon_superuser",
value_name = "PRIVILEGED_ROLE_NAME",
value_parser = Self::parse_privileged_role_name
)]
pub privileged_role_name: String,
#[cfg(target_os = "linux")]
#[arg(long, default_value = "neon-postgres")]
pub cgroup: String,
@@ -157,21 +149,6 @@ impl Cli {
Ok(url)
}
/// For simplicity, we do not escape `privileged_role_name` anywhere in the code.
/// Since it's a system role, which we fully control, that's fine. Still, let's
/// validate it to avoid any surprises.
fn parse_privileged_role_name(value: &str) -> Result<String> {
use regex::Regex;
let pattern = Regex::new(r"^[a-z_]+$").unwrap();
if !pattern.is_match(value) {
bail!("--privileged-role-name can only contain lowercase letters and underscores")
}
Ok(value.to_string())
}
}
fn main() -> Result<()> {
@@ -201,7 +178,6 @@ fn main() -> Result<()> {
ComputeNodeParams {
compute_id: cli.compute_id,
connstr,
privileged_role_name: cli.privileged_role_name.clone(),
pgdata: cli.pgdata.clone(),
pgbin: cli.pgbin.clone(),
pgversion: get_pg_version_string(&cli.pgbin),
@@ -351,49 +327,4 @@ mod test {
])
.expect_err("URL parameters are not allowed");
}
#[test]
fn verify_privileged_role_name() {
// Valid name
let cli = Cli::parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--privileged-role-name",
"my_superuser",
]);
assert_eq!(cli.privileged_role_name, "my_superuser");
// Invalid names
Cli::try_parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--privileged-role-name",
"NeonSuperuser",
])
.expect_err("uppercase letters are not allowed");
Cli::try_parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--privileged-role-name",
"$'neon_superuser",
])
.expect_err("special characters are not allowed");
Cli::try_parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--privileged-role-name",
"",
])
.expect_err("empty name is not allowed");
}
}

View File

@@ -74,20 +74,12 @@ const DEFAULT_INSTALLED_EXTENSIONS_COLLECTION_INTERVAL: u64 = 3600;
/// Static configuration params that don't change after startup. These mostly
/// come from the CLI args, or are derived from them.
#[derive(Clone, Debug)]
pub struct ComputeNodeParams {
/// The ID of the compute
pub compute_id: String,
/// Url type maintains proper escaping
// Url type maintains proper escaping
pub connstr: url::Url,
/// The name of the 'weak' superuser role, which we give to the users.
/// It follows the allow list approach, i.e., we take a standard role
/// and grant it extra permissions with explicit GRANTs here and there,
/// and core patches.
pub privileged_role_name: String,
pub resize_swap_on_bind: bool,
pub set_disk_quota_for_fs: Option<String>,
@@ -1048,8 +1040,6 @@ impl ComputeNode {
PageserverProtocol::Grpc => self.try_get_basebackup_grpc(spec, lsn)?,
};
self.fix_zenith_signal_neon_signal()?;
let mut state = self.state.lock().unwrap();
state.metrics.pageserver_connect_micros =
connected.duration_since(started).as_micros() as u64;
@@ -1059,27 +1049,6 @@ impl ComputeNode {
Ok(())
}
/// Move the Zenith signal file to Neon signal file location.
/// This makes Compute compatible with older PageServers that don't yet
/// know about the Zenith->Neon rename.
fn fix_zenith_signal_neon_signal(&self) -> Result<()> {
let datadir = Path::new(&self.params.pgdata);
let neonsig = datadir.join("neon.signal");
if neonsig.is_file() {
return Ok(());
}
let zenithsig = datadir.join("zenith.signal");
if zenithsig.is_file() {
fs::copy(zenithsig, neonsig)?;
}
Ok(())
}
/// Fetches a basebackup via gRPC. The connstring must use grpc://. Returns the timestamp when
/// the connection was established, and the (compressed) size of the basebackup.
fn try_get_basebackup_grpc(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<(Instant, usize)> {
@@ -1294,7 +1263,9 @@ impl ComputeNode {
// In case of error, log and fail the check, but don't crash.
// We're playing it safe because these errors could be transient
// and we don't yet retry.
// and we don't yet retry. Also being careful here allows us to
// be backwards compatible with safekeepers that don't have the
// TIMELINE_STATUS API yet.
if responses.len() < quorum {
error!(
"failed sync safekeepers check {:?} {:?} {:?}",
@@ -1397,7 +1368,6 @@ impl ComputeNode {
self.create_pgdata()?;
config::write_postgres_conf(
pgdata_path,
&self.params,
&pspec.spec,
self.params.internal_http_port,
tls_config,
@@ -1746,7 +1716,6 @@ impl ComputeNode {
}
// Run migrations separately to not hold up cold starts
let params = self.params.clone();
tokio::spawn(async move {
let mut conf = conf.as_ref().clone();
conf.application_name("compute_ctl:migrations");
@@ -1758,7 +1727,7 @@ impl ComputeNode {
eprintln!("connection error: {e}");
}
});
if let Err(e) = handle_migrations(params, &mut client).await {
if let Err(e) = handle_migrations(&mut client).await {
error!("Failed to run migrations: {}", e);
}
}
@@ -1837,7 +1806,6 @@ impl ComputeNode {
let pgdata_path = Path::new(&self.params.pgdata);
config::write_postgres_conf(
pgdata_path,
&self.params,
&spec,
self.params.internal_http_port,
tls_config,
@@ -2450,31 +2418,14 @@ LIMIT 100",
pub fn spawn_lfc_offload_task(self: &Arc<Self>, interval: Duration) {
self.terminate_lfc_offload_task();
let secs = interval.as_secs();
info!("spawning lfc offload worker with {secs}s interval");
let this = self.clone();
info!("spawning LFC offload worker with {secs}s interval");
let handle = spawn(async move {
let mut interval = time::interval(interval);
interval.tick().await; // returns immediately
loop {
interval.tick().await;
let prewarm_state = this.state.lock().unwrap().lfc_prewarm_state.clone();
// Do not offload LFC state if we are currently prewarming or any issue occurred.
// If we'd do that, we might override the LFC state in endpoint storage with some
// incomplete state. Imagine a situation:
// 1. Endpoint started with `autoprewarm: true`
// 2. While prewarming is not completed, we upload the new incomplete state
// 3. Compute gets interrupted and restarts
// 4. We start again and try to prewarm with the state from 2. instead of the previous complete state
if matches!(
prewarm_state,
LfcPrewarmState::Completed
| LfcPrewarmState::NotPrewarmed
| LfcPrewarmState::Skipped
) {
this.offload_lfc_async().await;
}
this.offload_lfc_async().await;
}
});
*self.lfc_offload_task.lock().unwrap() = Some(handle);
@@ -2513,7 +2464,7 @@ pub async fn installed_extensions(conf: tokio_postgres::Config) -> Result<()> {
serde_json::to_string(&extensions).expect("failed to serialize extensions list")
);
}
Err(err) => error!("could not get installed extensions: {err}"),
Err(err) => error!("could not get installed extensions: {err:?}"),
}
Ok(())
}

View File

@@ -89,7 +89,7 @@ impl ComputeNode {
self.state.lock().unwrap().lfc_offload_state.clone()
}
/// If there is a prewarm request ongoing, return `false`, `true` otherwise.
/// If there is a prewarm request ongoing, return false, true otherwise
pub fn prewarm_lfc(self: &Arc<Self>, from_endpoint: Option<String>) -> bool {
{
let state = &mut self.state.lock().unwrap().lfc_prewarm_state;
@@ -101,25 +101,15 @@ impl ComputeNode {
let cloned = self.clone();
spawn(async move {
let state = match cloned.prewarm_impl(from_endpoint).await {
Ok(true) => LfcPrewarmState::Completed,
Ok(false) => {
info!(
"skipping LFC prewarm because LFC state is not found in endpoint storage"
);
LfcPrewarmState::Skipped
}
Err(err) => {
crate::metrics::LFC_PREWARM_ERRORS.inc();
error!(%err, "could not prewarm LFC");
LfcPrewarmState::Failed {
error: err.to_string(),
}
}
let Err(err) = cloned.prewarm_impl(from_endpoint).await else {
cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Completed;
return;
};
crate::metrics::LFC_PREWARM_ERRORS.inc();
error!(%err, "prewarming lfc");
cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Failed {
error: err.to_string(),
};
cloned.state.lock().unwrap().lfc_prewarm_state = state;
});
true
}
@@ -130,21 +120,15 @@ impl ComputeNode {
EndpointStoragePair::from_spec_and_endpoint(state.pspec.as_ref().unwrap(), from_endpoint)
}
/// Request LFC state from endpoint storage and load corresponding pages into Postgres.
/// Returns a result with `false` if the LFC state is not found in endpoint storage.
async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<bool> {
async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(from_endpoint)?;
info!(%url, "requesting LFC state from endpoint storage");
let request = Client::new().get(&url).bearer_auth(token);
let res = request.send().await.context("querying endpoint storage")?;
let status = res.status();
match status {
StatusCode::OK => (),
StatusCode::NOT_FOUND => {
return Ok(false);
}
_ => bail!("{status} querying endpoint storage"),
if status != StatusCode::OK {
bail!("{status} querying endpoint storage")
}
let mut uncompressed = Vec::new();
@@ -157,8 +141,7 @@ impl ComputeNode {
.await
.context("decoding LFC state")?;
let uncompressed_len = uncompressed.len();
info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}, loading into Postgres");
info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}, loading into postgres");
ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
.await
@@ -166,9 +149,7 @@ impl ComputeNode {
.query_one("select neon.prewarm_local_cache($1)", &[&uncompressed])
.await
.context("loading LFC state into postgres")
.map(|_| ())?;
Ok(true)
.map(|_| ())
}
/// If offload request is ongoing, return false, true otherwise
@@ -196,14 +177,12 @@ impl ComputeNode {
async fn offload_lfc_with_state_update(&self) {
crate::metrics::LFC_OFFLOADS.inc();
let Err(err) = self.offload_lfc_impl().await else {
self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Completed;
return;
};
crate::metrics::LFC_OFFLOAD_ERRORS.inc();
error!(%err, "could not offload LFC state to endpoint storage");
error!(%err, "offloading lfc");
self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Failed {
error: err.to_string(),
};
@@ -211,7 +190,7 @@ impl ComputeNode {
async fn offload_lfc_impl(&self) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?;
info!(%url, "requesting LFC state from Postgres");
info!(%url, "requesting LFC state from postgres");
let mut compressed = Vec::new();
ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
@@ -226,17 +205,13 @@ impl ComputeNode {
.read_to_end(&mut compressed)
.await
.context("compressing LFC state")?;
let compressed_len = compressed.len();
info!(%url, "downloaded LFC state, compressed size {compressed_len}, writing to endpoint storage");
let request = Client::new().put(url).bearer_auth(token).body(compressed);
match request.send().await {
Ok(res) if res.status() == StatusCode::OK => Ok(()),
Ok(res) => bail!(
"Request to endpoint storage failed with status: {}",
res.status()
),
Ok(res) => bail!("Error writing to endpoint storage: {}", res.status()),
Err(err) => Err(err).context("writing to endpoint storage"),
}
}

View File

@@ -9,7 +9,6 @@ use std::path::Path;
use compute_api::responses::TlsConfig;
use compute_api::spec::{ComputeAudit, ComputeMode, ComputeSpec, GenericOption};
use crate::compute::ComputeNodeParams;
use crate::pg_helpers::{
GenericOptionExt, GenericOptionsSearch, PgOptionsSerialize, escape_conf_value,
};
@@ -42,7 +41,6 @@ pub fn line_in_file(path: &Path, line: &str) -> Result<bool> {
/// Create or completely rewrite configuration file specified by `path`
pub fn write_postgres_conf(
pgdata_path: &Path,
params: &ComputeNodeParams,
spec: &ComputeSpec,
extension_server_port: u16,
tls_config: &Option<TlsConfig>,
@@ -56,15 +54,14 @@ pub fn write_postgres_conf(
writeln!(file, "{conf}")?;
}
// Stripe size GUC should be defined prior to connection string
if let Some(stripe_size) = spec.shard_stripe_size {
writeln!(file, "neon.stripe_size={stripe_size}")?;
}
// Add options for connecting to storage
writeln!(file, "# Neon storage settings")?;
if let Some(s) = &spec.pageserver_connstring {
writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(s))?;
}
if let Some(stripe_size) = spec.shard_stripe_size {
writeln!(file, "neon.stripe_size={stripe_size}")?;
}
if !spec.safekeeper_connstrings.is_empty() {
let mut neon_safekeepers_value = String::new();
tracing::info!(
@@ -164,12 +161,6 @@ pub fn write_postgres_conf(
}
}
writeln!(
file,
"neon.privileged_role_name={}",
escape_conf_value(params.privileged_role_name.as_str())
)?;
// If there are any extra options in the 'settings' field, append those
if spec.cluster.settings.is_some() {
writeln!(file, "# Managed by compute_ctl: begin")?;

View File

@@ -613,11 +613,11 @@ components:
- skipped
properties:
status:
description: LFC prewarm status
enum: [not_prewarmed, prewarming, completed, failed, skipped]
description: Lfc prewarm status
enum: [not_prewarmed, prewarming, completed, failed]
type: string
error:
description: LFC prewarm error, if any
description: Lfc prewarm error, if any
type: string
total:
description: Total pages processed
@@ -635,11 +635,11 @@ components:
- status
properties:
status:
description: LFC offload status
description: Lfc offload status
enum: [not_offloaded, offloading, completed, failed]
type: string
error:
description: LFC offload error, if any
description: Lfc offload error, if any
type: string
PromoteState:

View File

@@ -2,7 +2,6 @@ use std::collections::HashMap;
use anyhow::Result;
use compute_api::responses::{InstalledExtension, InstalledExtensions};
use tokio_postgres::error::Error as PostgresError;
use tokio_postgres::{Client, Config, NoTls};
use crate::metrics::INSTALLED_EXTENSIONS;
@@ -11,7 +10,7 @@ use crate::metrics::INSTALLED_EXTENSIONS;
/// and to make database listing query here more explicit.
///
/// Limit the number of databases to 500 to avoid excessive load.
async fn list_dbs(client: &mut Client) -> Result<Vec<String>, PostgresError> {
async fn list_dbs(client: &mut Client) -> Result<Vec<String>> {
// `pg_database.datconnlimit = -2` means that the database is in the
// invalid state
let databases = client
@@ -38,9 +37,7 @@ async fn list_dbs(client: &mut Client) -> Result<Vec<String>, PostgresError> {
/// Same extension can be installed in multiple databases with different versions,
/// so we report a separate metric (number of databases where it is installed)
/// for each extension version.
pub async fn get_installed_extensions(
mut conf: Config,
) -> Result<InstalledExtensions, PostgresError> {
pub async fn get_installed_extensions(mut conf: Config) -> Result<InstalledExtensions> {
conf.application_name("compute_ctl:get_installed_extensions");
let databases: Vec<String> = {
let (mut client, connection) = conf.connect(NoTls).await?;

View File

@@ -1 +0,0 @@
ALTER ROLE {privileged_role_name} BYPASSRLS;

View File

@@ -0,0 +1 @@
ALTER ROLE neon_superuser BYPASSRLS;

View File

@@ -15,7 +15,7 @@ DO $$
DECLARE
role_name text;
BEGIN
FOR role_name IN SELECT rolname FROM pg_roles WHERE pg_has_role(rolname, '{privileged_role_name}', 'member')
FOR role_name IN SELECT rolname FROM pg_roles WHERE pg_has_role(rolname, 'neon_superuser', 'member')
LOOP
RAISE NOTICE 'EXECUTING ALTER ROLE % INHERIT', quote_ident(role_name);
EXECUTE 'ALTER ROLE ' || quote_ident(role_name) || ' INHERIT';
@@ -23,7 +23,7 @@ BEGIN
FOR role_name IN SELECT rolname FROM pg_roles
WHERE
NOT pg_has_role(rolname, '{privileged_role_name}', 'member') AND NOT starts_with(rolname, 'pg_')
NOT pg_has_role(rolname, 'neon_superuser', 'member') AND NOT starts_with(rolname, 'pg_')
LOOP
RAISE NOTICE 'EXECUTING ALTER ROLE % NOBYPASSRLS', quote_ident(role_name);
EXECUTE 'ALTER ROLE ' || quote_ident(role_name) || ' NOBYPASSRLS';

View File

@@ -1,6 +1,6 @@
DO $$
BEGIN
IF (SELECT setting::numeric >= 160000 FROM pg_settings WHERE name = 'server_version_num') THEN
EXECUTE 'GRANT pg_create_subscription TO {privileged_role_name}';
EXECUTE 'GRANT pg_create_subscription TO neon_superuser';
END IF;
END $$;

View File

@@ -0,0 +1 @@
GRANT pg_monitor TO neon_superuser WITH ADMIN OPTION;

View File

@@ -1 +0,0 @@
GRANT pg_monitor TO {privileged_role_name} WITH ADMIN OPTION;

View File

@@ -1,4 +1,4 @@
-- SKIP: Deemed insufficient for allowing relations created by extensions to be
-- interacted with by {privileged_role_name} without permission issues.
-- interacted with by neon_superuser without permission issues.
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO {privileged_role_name};
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO neon_superuser;

View File

@@ -1,4 +1,4 @@
-- SKIP: Deemed insufficient for allowing relations created by extensions to be
-- interacted with by {privileged_role_name} without permission issues.
-- interacted with by neon_superuser without permission issues.
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO {privileged_role_name};
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO neon_superuser;

View File

@@ -1,3 +1,3 @@
-- SKIP: Moved inline to the handle_grants() functions.
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO {privileged_role_name} WITH GRANT OPTION;
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO neon_superuser WITH GRANT OPTION;

View File

@@ -1,3 +1,3 @@
-- SKIP: Moved inline to the handle_grants() functions.
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO {privileged_role_name} WITH GRANT OPTION;
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO neon_superuser WITH GRANT OPTION;

View File

@@ -1,7 +1,7 @@
DO $$
BEGIN
IF (SELECT setting::numeric >= 160000 FROM pg_settings WHERE name = 'server_version_num') THEN
EXECUTE 'GRANT EXECUTE ON FUNCTION pg_export_snapshot TO {privileged_role_name}';
EXECUTE 'GRANT EXECUTE ON FUNCTION pg_log_standby_snapshot TO {privileged_role_name}';
EXECUTE 'GRANT EXECUTE ON FUNCTION pg_export_snapshot TO neon_superuser';
EXECUTE 'GRANT EXECUTE ON FUNCTION pg_log_standby_snapshot TO neon_superuser';
END IF;
END $$;

View File

@@ -0,0 +1 @@
GRANT EXECUTE ON FUNCTION pg_show_replication_origin_status TO neon_superuser;

View File

@@ -1 +0,0 @@
GRANT EXECUTE ON FUNCTION pg_show_replication_origin_status TO {privileged_role_name};

View File

@@ -0,0 +1 @@
GRANT pg_signal_backend TO neon_superuser WITH ADMIN OPTION;

View File

@@ -1 +0,0 @@
GRANT pg_signal_backend TO {privileged_role_name} WITH ADMIN OPTION;

View File

@@ -9,7 +9,6 @@ use reqwest::StatusCode;
use tokio_postgres::Client;
use tracing::{error, info, instrument};
use crate::compute::ComputeNodeParams;
use crate::config;
use crate::metrics::{CPLANE_REQUESTS_TOTAL, CPlaneRequestRPC, UNKNOWN_HTTP_STATUS};
use crate::migration::MigrationRunner;
@@ -170,7 +169,7 @@ pub async fn handle_neon_extension_upgrade(client: &mut Client) -> Result<()> {
}
#[instrument(skip_all)]
pub async fn handle_migrations(params: ComputeNodeParams, client: &mut Client) -> Result<()> {
pub async fn handle_migrations(client: &mut Client) -> Result<()> {
info!("handle migrations");
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -179,59 +178,26 @@ pub async fn handle_migrations(params: ComputeNodeParams, client: &mut Client) -
// Add new migrations in numerical order.
let migrations = [
&format!(
include_str!("./migrations/0001-add_bypass_rls_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
include_str!("./migrations/0001-neon_superuser_bypass_rls.sql"),
include_str!("./migrations/0002-alter_roles.sql"),
include_str!("./migrations/0003-grant_pg_create_subscription_to_neon_superuser.sql"),
include_str!("./migrations/0004-grant_pg_monitor_to_neon_superuser.sql"),
include_str!("./migrations/0005-grant_all_on_tables_to_neon_superuser.sql"),
include_str!("./migrations/0006-grant_all_on_sequences_to_neon_superuser.sql"),
include_str!(
"./migrations/0007-grant_all_on_tables_to_neon_superuser_with_grant_option.sql"
),
&format!(
include_str!("./migrations/0002-alter_roles.sql"),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!("./migrations/0003-grant_pg_create_subscription_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!("./migrations/0004-grant_pg_monitor_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!("./migrations/0005-grant_all_on_tables_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!("./migrations/0006-grant_all_on_sequences_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!(
"./migrations/0007-grant_all_on_tables_with_grant_option_to_privileged_role.sql"
),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!(
"./migrations/0008-grant_all_on_sequences_with_grant_option_to_privileged_role.sql"
),
privileged_role_name = params.privileged_role_name
include_str!(
"./migrations/0008-grant_all_on_sequences_to_neon_superuser_with_grant_option.sql"
),
include_str!("./migrations/0009-revoke_replication_for_previously_allowed_roles.sql"),
&format!(
include_str!(
"./migrations/0010-grant_snapshot_synchronization_funcs_to_privileged_role.sql"
),
privileged_role_name = params.privileged_role_name
include_str!(
"./migrations/0010-grant_snapshot_synchronization_funcs_to_neon_superuser.sql"
),
&format!(
include_str!(
"./migrations/0011-grant_pg_show_replication_origin_status_to_privileged_role.sql"
),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!("./migrations/0012-grant_pg_signal_backend_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
include_str!(
"./migrations/0011-grant_pg_show_replication_origin_status_to_neon_superuser.sql"
),
include_str!("./migrations/0012-grant_pg_signal_backend_to_neon_superuser.sql"),
];
MigrationRunner::new(client, &migrations)

View File

@@ -13,14 +13,14 @@ use tokio_postgres::Client;
use tokio_postgres::error::SqlState;
use tracing::{Instrument, debug, error, info, info_span, instrument, warn};
use crate::compute::{ComputeNode, ComputeNodeParams, ComputeState};
use crate::compute::{ComputeNode, ComputeState};
use crate::pg_helpers::{
DatabaseExt, Escaping, GenericOptionsSearch, RoleExt, get_existing_dbs_async,
get_existing_roles_async,
};
use crate::spec_apply::ApplySpecPhase::{
CreateAndAlterDatabases, CreateAndAlterRoles, CreateAvailabilityCheck, CreatePgauditExtension,
CreatePgauditlogtofileExtension, CreatePrivilegedRole, CreateSchemaNeon,
CreateAndAlterDatabases, CreateAndAlterRoles, CreateAvailabilityCheck, CreateNeonSuperuser,
CreatePgauditExtension, CreatePgauditlogtofileExtension, CreateSchemaNeon,
DisablePostgresDBPgAudit, DropInvalidDatabases, DropRoles, FinalizeDropLogicalSubscriptions,
HandleNeonExtension, HandleOtherExtensions, RenameAndDeleteDatabases, RenameRoles,
RunInEachDatabase,
@@ -49,7 +49,6 @@ impl ComputeNode {
// Proceed with post-startup configuration. Note, that order of operations is important.
let client = Self::get_maintenance_client(&conf).await?;
let spec = spec.clone();
let params = Arc::new(self.params.clone());
let databases = get_existing_dbs_async(&client).await?;
let roles = get_existing_roles_async(&client)
@@ -158,7 +157,6 @@ impl ComputeNode {
let conf = Arc::new(conf);
let fut = Self::apply_spec_sql_db(
params.clone(),
spec.clone(),
conf,
ctx.clone(),
@@ -187,7 +185,7 @@ impl ComputeNode {
}
for phase in [
CreatePrivilegedRole,
CreateNeonSuperuser,
DropInvalidDatabases,
RenameRoles,
CreateAndAlterRoles,
@@ -197,7 +195,6 @@ impl ComputeNode {
] {
info!("Applying phase {:?}", &phase);
apply_operations(
params.clone(),
spec.clone(),
ctx.clone(),
jwks_roles.clone(),
@@ -246,7 +243,6 @@ impl ComputeNode {
}
let fut = Self::apply_spec_sql_db(
params.clone(),
spec.clone(),
conf,
ctx.clone(),
@@ -297,7 +293,6 @@ impl ComputeNode {
for phase in phases {
debug!("Applying phase {:?}", &phase);
apply_operations(
params.clone(),
spec.clone(),
ctx.clone(),
jwks_roles.clone(),
@@ -318,9 +313,7 @@ impl ComputeNode {
/// May opt to not connect to databases that don't have any scheduled
/// operations. The function is concurrency-controlled with the provided
/// semaphore. The caller has to make sure the semaphore isn't exhausted.
#[allow(clippy::too_many_arguments)] // TODO: needs bigger refactoring
async fn apply_spec_sql_db(
params: Arc<ComputeNodeParams>,
spec: Arc<ComputeSpec>,
conf: Arc<tokio_postgres::Config>,
ctx: Arc<tokio::sync::RwLock<MutableApplyContext>>,
@@ -335,7 +328,6 @@ impl ComputeNode {
for subphase in subphases {
apply_operations(
params.clone(),
spec.clone(),
ctx.clone(),
jwks_roles.clone(),
@@ -475,7 +467,7 @@ pub enum PerDatabasePhase {
#[derive(Clone, Debug)]
pub enum ApplySpecPhase {
CreatePrivilegedRole,
CreateNeonSuperuser,
DropInvalidDatabases,
RenameRoles,
CreateAndAlterRoles,
@@ -518,7 +510,6 @@ pub struct MutableApplyContext {
/// - No timeouts have (yet) been implemented.
/// - The caller is responsible for limiting and/or applying concurrency.
pub async fn apply_operations<'a, Fut, F>(
params: Arc<ComputeNodeParams>,
spec: Arc<ComputeSpec>,
ctx: Arc<RwLock<MutableApplyContext>>,
jwks_roles: Arc<HashSet<String>>,
@@ -536,7 +527,7 @@ where
debug!("Processing phase {:?}", &apply_spec_phase);
let ctx = ctx;
let mut ops = get_operations(&params, &spec, &ctx, &jwks_roles, &apply_spec_phase)
let mut ops = get_operations(&spec, &ctx, &jwks_roles, &apply_spec_phase)
.await?
.peekable();
@@ -597,18 +588,14 @@ where
/// sort/merge/batch execution, but for now this is a nice way to improve
/// batching behavior of the commands.
async fn get_operations<'a>(
params: &'a ComputeNodeParams,
spec: &'a ComputeSpec,
ctx: &'a RwLock<MutableApplyContext>,
jwks_roles: &'a HashSet<String>,
apply_spec_phase: &'a ApplySpecPhase,
) -> Result<Box<dyn Iterator<Item = Operation> + 'a + Send>> {
match apply_spec_phase {
ApplySpecPhase::CreatePrivilegedRole => Ok(Box::new(once(Operation {
query: format!(
include_str!("sql/create_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
),
ApplySpecPhase::CreateNeonSuperuser => Ok(Box::new(once(Operation {
query: include_str!("sql/create_neon_superuser.sql").to_string(),
comment: None,
}))),
ApplySpecPhase::DropInvalidDatabases => {
@@ -710,9 +697,8 @@ async fn get_operations<'a>(
None => {
let query = if !jwks_roles.contains(role.name.as_str()) {
format!(
"CREATE ROLE {} INHERIT CREATEROLE CREATEDB BYPASSRLS REPLICATION IN ROLE {} {}",
"CREATE ROLE {} INHERIT CREATEROLE CREATEDB BYPASSRLS REPLICATION IN ROLE neon_superuser {}",
role.name.pg_quote(),
params.privileged_role_name,
role.to_pg_options(),
)
} else {
@@ -863,9 +849,8 @@ async fn get_operations<'a>(
// ALL PRIVILEGES grants CREATE, CONNECT, and TEMPORARY on the database
// (see https://www.postgresql.org/docs/current/ddl-priv.html)
query: format!(
"GRANT ALL PRIVILEGES ON DATABASE {} TO {}",
db.name.pg_quote(),
params.privileged_role_name
"GRANT ALL PRIVILEGES ON DATABASE {} TO neon_superuser",
db.name.pg_quote()
),
comment: None,
},

View File

@@ -0,0 +1,8 @@
DO $$
BEGIN
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = 'neon_superuser')
THEN
CREATE ROLE neon_superuser CREATEDB CREATEROLE NOLOGIN REPLICATION BYPASSRLS IN ROLE pg_read_all_data, pg_write_all_data;
END IF;
END
$$;

View File

@@ -1,8 +0,0 @@
DO $$
BEGIN
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{privileged_role_name}')
THEN
CREATE ROLE {privileged_role_name} CREATEDB CREATEROLE NOLOGIN REPLICATION BYPASSRLS IN ROLE pg_read_all_data, pg_write_all_data;
END IF;
END
$$;

View File

@@ -8,10 +8,10 @@ code changes locally, but not suitable for running production systems.
## Example: Start with Postgres 16
To create and start a local development environment with Postgres 16, you will need to provide `--pg-version` flag to 2 of the start-up commands.
To create and start a local development environment with Postgres 16, you will need to provide `--pg-version` flag to 3 of the start-up commands.
```shell
cargo neon init
cargo neon init --pg-version 16
cargo neon start
cargo neon tenant create --set-default --pg-version 16
cargo neon endpoint create main --pg-version 16

View File

@@ -631,10 +631,6 @@ struct EndpointCreateCmdArgs {
help = "Allow multiple primary endpoints running on the same branch. Shouldn't be used normally, but useful for tests."
)]
allow_multiple: bool,
/// Only allow changing it on creation
#[clap(long, help = "Name of the privileged role for the endpoint")]
privileged_role_name: Option<String>,
}
#[derive(clap::Args)]
@@ -1484,7 +1480,6 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
args.grpc,
!args.update_catalog,
false,
args.privileged_role_name.clone(),
)?;
}
EndpointCmd::Start(args) => {

View File

@@ -32,8 +32,7 @@
//! config.json - passed to `compute_ctl`
//! pgdata/
//! postgresql.conf - copy of postgresql.conf created by `compute_ctl`
//! neon.signal
//! zenith.signal - copy of neon.signal, for backward compatibility
//! zenith.signal
//! <other PostgreSQL files>
//! ```
//!
@@ -99,7 +98,6 @@ pub struct EndpointConf {
features: Vec<ComputeFeature>,
cluster: Option<Cluster>,
compute_ctl_config: ComputeCtlConfig,
privileged_role_name: Option<String>,
}
//
@@ -200,7 +198,6 @@ impl ComputeControlPlane {
grpc: bool,
skip_pg_catalog_updates: bool,
drop_subscriptions_before_start: bool,
privileged_role_name: Option<String>,
) -> Result<Arc<Endpoint>> {
let pg_port = pg_port.unwrap_or_else(|| self.get_port());
let external_http_port = external_http_port.unwrap_or_else(|| self.get_port() + 1);
@@ -238,7 +235,6 @@ impl ComputeControlPlane {
features: vec![],
cluster: None,
compute_ctl_config: compute_ctl_config.clone(),
privileged_role_name: privileged_role_name.clone(),
});
ep.create_endpoint_dir()?;
@@ -260,7 +256,6 @@ impl ComputeControlPlane {
features: vec![],
cluster: None,
compute_ctl_config,
privileged_role_name,
})?,
)?;
std::fs::write(
@@ -336,9 +331,6 @@ pub struct Endpoint {
/// The compute_ctl config for the endpoint's compute.
compute_ctl_config: ComputeCtlConfig,
/// The name of the privileged role for the endpoint.
privileged_role_name: Option<String>,
}
#[derive(PartialEq, Eq)]
@@ -439,7 +431,6 @@ impl Endpoint {
features: conf.features,
cluster: conf.cluster,
compute_ctl_config: conf.compute_ctl_config,
privileged_role_name: conf.privileged_role_name,
})
}
@@ -472,7 +463,7 @@ impl Endpoint {
conf.append("max_connections", "100");
conf.append("wal_level", "logical");
// wal_sender_timeout is the maximum time to wait for WAL replication.
// It also defines how often the walreceiver will send a feedback message to the wal sender.
// It also defines how often the walreciever will send a feedback message to the wal sender.
conf.append("wal_sender_timeout", "5s");
conf.append("listen_addresses", &self.pg_address.ip().to_string());
conf.append("port", &self.pg_address.port().to_string());
@@ -878,10 +869,6 @@ impl Endpoint {
cmd.arg("--dev");
}
if let Some(privileged_role_name) = self.privileged_role_name.clone() {
cmd.args(["--privileged-role-name", &privileged_role_name]);
}
let child = cmd.spawn()?;
// set up a scopeguard to kill & wait for the child in case we panic or bail below
let child = scopeguard::guard(child, |mut child| {

View File

@@ -217,9 +217,6 @@ pub struct NeonStorageControllerConf {
pub posthog_config: Option<PostHogConfig>,
pub kick_secondary_downloads: Option<bool>,
#[serde(with = "humantime_serde")]
pub shard_split_request_timeout: Option<Duration>,
}
impl NeonStorageControllerConf {
@@ -253,7 +250,6 @@ impl Default for NeonStorageControllerConf {
timeline_safekeeper_count: None,
posthog_config: None,
kick_secondary_downloads: None,
shard_split_request_timeout: None,
}
}
}

View File

@@ -648,13 +648,6 @@ impl StorageController {
args.push(format!("--timeline-safekeeper-count={sk_cnt}"));
}
if let Some(duration) = self.config.shard_split_request_timeout {
args.push(format!(
"--shard-split-request-timeout={}",
humantime::Duration::from(duration)
));
}
let mut envs = vec![
("LD_LIBRARY_PATH".to_owned(), pg_lib_dir.to_string()),
("DYLD_LIBRARY_PATH".to_owned(), pg_lib_dir.to_string()),

View File

@@ -76,12 +76,6 @@ enum Command {
NodeStartDelete {
#[arg(long)]
node_id: NodeId,
/// When `force` is true, skip waiting for shards to prewarm during migration.
/// This can significantly speed up node deletion since prewarming all shards
/// can take considerable time, but may result in slower initial access to
/// migrated shards until they warm up naturally.
#[arg(long)]
force: bool,
},
/// Cancel deletion of the specified pageserver and wait for `timeout`
/// for the operation to be canceled. May be retried.
@@ -482,7 +476,6 @@ async fn main() -> anyhow::Result<()> {
listen_http_port,
listen_https_port,
availability_zone_id: AvailabilityZone(availability_zone_id),
node_ip_addr: None,
}),
)
.await?;
@@ -958,14 +951,13 @@ async fn main() -> anyhow::Result<()> {
.dispatch::<(), ()>(Method::DELETE, format!("control/v1/node/{node_id}"), None)
.await?;
}
Command::NodeStartDelete { node_id, force } => {
let query = if force {
format!("control/v1/node/{node_id}/delete?force=true")
} else {
format!("control/v1/node/{node_id}/delete")
};
Command::NodeStartDelete { node_id } => {
storcon_client
.dispatch::<(), ()>(Method::PUT, query, None)
.dispatch::<(), ()>(
Method::PUT,
format!("control/v1/node/{node_id}/delete"),
None,
)
.await?;
println!("Delete started for {node_id}");
}

View File

@@ -52,6 +52,7 @@ exceptions = [
# Zlib license has some restrictions if we decide to change sth
{ allow = ["Zlib"], name = "const_format_proc_macros", version = "*" },
{ allow = ["Zlib"], name = "const_format", version = "*" },
{ allow = ["Zlib"], name = "foldhash", version = "*" },
]
[licenses.private]

View File

@@ -129,10 +129,9 @@ segment to bootstrap the WAL writing, but it doesn't contain the checkpoint reco
changes in xlog.c, to allow starting the compute node without reading the last checkpoint record
from WAL.
This includes code to read the `neon.signal` (also `zenith.signal`) file, which tells the startup
code the LSN to start at. When the `neon.signal` file is present, the startup uses that LSN
instead of the last checkpoint's LSN. The system is known to be consistent at that LSN, without
any WAL redo.
This includes code to read the `zenith.signal` file, which tells the startup code the LSN to start
at. When the `zenith.signal` file is present, the startup uses that LSN instead of the last
checkpoint's LSN. The system is known to be consistent at that LSN, without any WAL redo.
### How to get rid of the patch

View File

@@ -75,7 +75,7 @@ CLI examples:
* AWS S3 : `env AWS_ACCESS_KEY_ID='SOMEKEYAAAAASADSAH*#' AWS_SECRET_ACCESS_KEY='SOMEsEcReTsd292v' ${PAGESERVER_BIN} -c "remote_storage={bucket_name='some-sample-bucket',bucket_region='eu-north-1', prefix_in_bucket='/test_prefix/'}"`
For Amazon AWS S3, a key id and secret access key could be located in `~/.aws/credentials` if awscli was ever configured to work with the desired bucket, on the AWS Settings page for a certain user. Also note, that the bucket names does not contain any protocols when used on AWS.
For local S3 installations, refer to their documentation for name format and credentials.
For local S3 installations, refer to the their documentation for name format and credentials.
Similar to other pageserver settings, toml config file can be used to configure either of the storages as backup targets.
Required sections are:

View File

@@ -46,33 +46,16 @@ pub struct ExtensionInstallResponse {
pub version: ExtVersion,
}
/// Status of the LFC prewarm process. The same state machine is reused for
/// both autoprewarm (prewarm after compute/Postgres start using the previously
/// stored LFC state) and explicit prewarming via API.
#[derive(Serialize, Default, Debug, Clone, PartialEq)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum LfcPrewarmState {
/// Default value when compute boots up.
#[default]
NotPrewarmed,
/// Prewarming thread is active and loading pages into LFC.
Prewarming,
/// We found requested LFC state in the endpoint storage and
/// completed prewarming successfully.
Completed,
/// Unexpected error happened during prewarming. Note, `Not Found 404`
/// response from the endpoint storage is explicitly excluded here
/// because it can normally happen on the first compute start,
/// since LFC state is not available yet.
Failed { error: String },
/// We tried to fetch the corresponding LFC state from the endpoint storage,
/// but received `Not Found 404`. This should normally happen only during the
/// first endpoint start after creation with `autoprewarm: true`.
///
/// During the orchestrated prewarm via API, when a caller explicitly
/// provides the LFC state key to prewarm from, it's the caller responsibility
/// to handle this status as an error state in this case.
Skipped,
Failed {
error: String,
},
}
impl Display for LfcPrewarmState {
@@ -81,7 +64,6 @@ impl Display for LfcPrewarmState {
LfcPrewarmState::NotPrewarmed => f.write_str("NotPrewarmed"),
LfcPrewarmState::Prewarming => f.write_str("Prewarming"),
LfcPrewarmState::Completed => f.write_str("Completed"),
LfcPrewarmState::Skipped => f.write_str("Skipped"),
LfcPrewarmState::Failed { error } => write!(f, "Error({error})"),
}
}

View File

@@ -4,14 +4,12 @@
//! a default registry.
#![deny(clippy::undocumented_unsafe_blocks)]
use std::sync::RwLock;
use measured::label::{LabelGroupSet, LabelGroupVisitor, LabelName, NoLabels};
use measured::metric::counter::CounterState;
use measured::metric::gauge::GaugeState;
use measured::metric::group::Encoding;
use measured::metric::name::{MetricName, MetricNameEncoder};
use measured::metric::{MetricEncoding, MetricFamilyEncoding, MetricType};
use measured::metric::{MetricEncoding, MetricFamilyEncoding};
use measured::{FixedCardinalityLabel, LabelGroup, MetricGroup};
use once_cell::sync::Lazy;
use prometheus::Registry;
@@ -118,52 +116,12 @@ pub fn pow2_buckets(start: usize, end: usize) -> Vec<f64> {
.collect()
}
pub struct InfoMetric<L: LabelGroup, M: MetricType = GaugeState> {
label: RwLock<L>,
metric: M,
}
impl<L: LabelGroup> InfoMetric<L> {
pub fn new(label: L) -> Self {
Self::with_metric(label, GaugeState::new(1))
}
}
impl<L: LabelGroup, M: MetricType<Metadata = ()>> InfoMetric<L, M> {
pub fn with_metric(label: L, metric: M) -> Self {
Self {
label: RwLock::new(label),
metric,
}
}
pub fn set_label(&self, label: L) {
*self.label.write().unwrap() = label;
}
}
impl<L, M, E> MetricFamilyEncoding<E> for InfoMetric<L, M>
where
L: LabelGroup,
M: MetricEncoding<E, Metadata = ()>,
E: Encoding,
{
fn collect_family_into(
&self,
name: impl measured::metric::name::MetricNameEncoder,
enc: &mut E,
) -> Result<(), E::Err> {
M::write_type(&name, enc)?;
self.metric
.collect_into(&(), &*self.label.read().unwrap(), name, enc)
}
}
pub struct BuildInfo {
pub revision: &'static str,
pub build_tag: &'static str,
}
// todo: allow label group without the set
impl LabelGroup for BuildInfo {
fn visit_values(&self, v: &mut impl LabelGroupVisitor) {
const REVISION: &LabelName = LabelName::from_str("revision");
@@ -173,6 +131,24 @@ impl LabelGroup for BuildInfo {
}
}
impl<T: Encoding> MetricFamilyEncoding<T> for BuildInfo
where
GaugeState: MetricEncoding<T>,
{
fn collect_family_into(
&self,
name: impl measured::metric::name::MetricNameEncoder,
enc: &mut T,
) -> Result<(), T::Err> {
enc.write_help(&name, "Build/version information")?;
GaugeState::write_type(&name, enc)?;
GaugeState {
count: std::sync::atomic::AtomicI64::new(1),
}
.collect_into(&(), self, name, enc)
}
}
#[derive(MetricGroup)]
#[metric(new(build_info: BuildInfo))]
pub struct NeonMetrics {
@@ -189,8 +165,8 @@ pub struct NeonMetrics {
#[derive(MetricGroup)]
#[metric(new(build_info: BuildInfo))]
pub struct LibMetrics {
#[metric(init = InfoMetric::new(build_info))]
build_info: InfoMetric<BuildInfo>,
#[metric(init = build_info)]
build_info: BuildInfo,
#[metric(flatten)]
rusage: Rusage,

View File

@@ -8,13 +8,6 @@ license.workspace = true
thiserror.workspace = true
nix.workspace=true
workspace_hack = { version = "0.1", path = "../../workspace_hack" }
libc.workspace = true
lock_api.workspace = true
rustc-hash.workspace = true
[target.'cfg(target_os = "macos")'.dependencies]
tempfile = "3.14.0"
[dev-dependencies]
rand = "0.9"
rand_distr = "0.5.1"

View File

@@ -1,583 +0,0 @@
//! Resizable hash table implementation on top of byte-level storage (either a [`ShmemHandle`] or a fixed byte array).
//!
//! This hash table has two major components: the bucket array and the dictionary. Each bucket within the
//! bucket array contains a `Option<(K, V)>` and an index of another bucket. In this way there is both an
//! implicit freelist within the bucket array (`None` buckets point to other `None` entries) and various hash
//! chains within the bucket array (a Some bucket will point to other Some buckets that had the same hash).
//!
//! Buckets are never moved unless they are within a region that is being shrunk, and so the actual hash-
//! dependent component is done with the dictionary. When a new key is inserted into the map, a position
//! within the dictionary is decided based on its hash, the data is inserted into an empty bucket based
//! off of the freelist, and then the index of said bucket is placed in the dictionary.
//!
//! This map is resizable (if initialized on top of a [`ShmemHandle`]). Both growing and shrinking happen
//! in-place and are at a high level achieved by expanding/reducing the bucket array and rebuilding the
//! dictionary by rehashing all keys.
//!
//! Concurrency is managed very simply: the entire map is guarded by one shared-memory RwLock.
use std::hash::{BuildHasher, Hash};
use std::mem::MaybeUninit;
use crate::shmem::ShmemHandle;
use crate::{shmem, sync::*};
mod core;
pub mod entry;
#[cfg(test)]
mod tests;
use core::{Bucket, CoreHashMap, INVALID_POS};
use entry::{Entry, OccupiedEntry, PrevPos, VacantEntry};
use thiserror::Error;
/// Error type for a hashmap shrink operation.
#[derive(Error, Debug)]
pub enum HashMapShrinkError {
/// There was an error encountered while resizing the memory area.
#[error("shmem resize failed: {0}")]
ResizeError(shmem::Error),
/// Occupied entries in to-be-shrunk space were encountered beginning at the given index.
#[error("occupied entry in deallocated space found at {0}")]
RemainingEntries(usize),
}
/// This represents a hash table that (possibly) lives in shared memory.
/// If a new process is launched with fork(), the child process inherits
/// this struct.
#[must_use]
pub struct HashMapInit<'a, K, V, S = rustc_hash::FxBuildHasher> {
shmem_handle: Option<ShmemHandle>,
shared_ptr: *mut HashMapShared<'a, K, V>,
shared_size: usize,
hasher: S,
num_buckets: u32,
}
/// This is a per-process handle to a hash table that (possibly) lives in shared memory.
/// If a child process is launched with fork(), the child process should
/// get its own HashMapAccess by calling HashMapInit::attach_writer/reader().
///
/// XXX: We're not making use of it at the moment, but this struct could
/// hold process-local information in the future.
pub struct HashMapAccess<'a, K, V, S = rustc_hash::FxBuildHasher> {
shmem_handle: Option<ShmemHandle>,
shared_ptr: *mut HashMapShared<'a, K, V>,
hasher: S,
}
unsafe impl<K: Sync, V: Sync, S> Sync for HashMapAccess<'_, K, V, S> {}
unsafe impl<K: Send, V: Send, S> Send for HashMapAccess<'_, K, V, S> {}
impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
/// Change the 'hasher' used by the hash table.
///
/// NOTE: This must be called right after creating the hash table,
/// before inserting any entries and before calling attach_writer/reader.
/// Otherwise different accessors could be using different hash function,
/// with confusing results.
pub fn with_hasher<T: BuildHasher>(self, hasher: T) -> HashMapInit<'a, K, V, T> {
HashMapInit {
hasher,
shmem_handle: self.shmem_handle,
shared_ptr: self.shared_ptr,
shared_size: self.shared_size,
num_buckets: self.num_buckets,
}
}
/// Loosely (over)estimate the size needed to store a hash table with `num_buckets` buckets.
pub fn estimate_size(num_buckets: u32) -> usize {
// add some margin to cover alignment etc.
CoreHashMap::<K, V>::estimate_size(num_buckets) + size_of::<HashMapShared<K, V>>() + 1000
}
fn new(
num_buckets: u32,
shmem_handle: Option<ShmemHandle>,
area_ptr: *mut u8,
area_size: usize,
hasher: S,
) -> Self {
let mut ptr: *mut u8 = area_ptr;
let end_ptr: *mut u8 = unsafe { ptr.add(area_size) };
// carve out area for the One Big Lock (TM) and the HashMapShared.
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<libc::pthread_rwlock_t>())) };
let raw_lock_ptr = ptr;
ptr = unsafe { ptr.add(size_of::<libc::pthread_rwlock_t>()) };
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<HashMapShared<K, V>>())) };
let shared_ptr: *mut HashMapShared<K, V> = ptr.cast();
ptr = unsafe { ptr.add(size_of::<HashMapShared<K, V>>()) };
// carve out the buckets
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<core::Bucket<K, V>>())) };
let buckets_ptr = ptr;
ptr = unsafe { ptr.add(size_of::<core::Bucket<K, V>>() * num_buckets as usize) };
// use remaining space for the dictionary
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<u32>())) };
assert!(ptr.addr() < end_ptr.addr());
let dictionary_ptr = ptr;
let dictionary_size = unsafe { end_ptr.byte_offset_from(ptr) / size_of::<u32>() as isize };
assert!(dictionary_size > 0);
let buckets =
unsafe { std::slice::from_raw_parts_mut(buckets_ptr.cast(), num_buckets as usize) };
let dictionary = unsafe {
std::slice::from_raw_parts_mut(dictionary_ptr.cast(), dictionary_size as usize)
};
let hashmap = CoreHashMap::new(buckets, dictionary);
unsafe {
let lock = RwLock::from_raw(PthreadRwLock::new(raw_lock_ptr.cast()), hashmap);
std::ptr::write(shared_ptr, lock);
}
Self {
num_buckets,
shmem_handle,
shared_ptr,
shared_size: area_size,
hasher,
}
}
/// Attach to a hash table for writing.
pub fn attach_writer(self) -> HashMapAccess<'a, K, V, S> {
HashMapAccess {
shmem_handle: self.shmem_handle,
shared_ptr: self.shared_ptr,
hasher: self.hasher,
}
}
/// Initialize a table for reading. Currently identical to [`HashMapInit::attach_writer`].
///
/// This is a holdover from a previous implementation and is being kept around for
/// backwards compatibility reasons.
pub fn attach_reader(self) -> HashMapAccess<'a, K, V, S> {
self.attach_writer()
}
}
/// Hash table data that is actually stored in the shared memory area.
///
/// NOTE: We carve out the parts from a contiguous chunk. Growing and shrinking the hash table
/// relies on the memory layout! The data structures are laid out in the contiguous shared memory
/// area as follows:
///
/// [`libc::pthread_rwlock_t`]
/// [`HashMapShared`]
/// buckets
/// dictionary
///
/// In between the above parts, there can be padding bytes to align the parts correctly.
type HashMapShared<'a, K, V> = RwLock<CoreHashMap<'a, K, V>>;
impl<'a, K, V> HashMapInit<'a, K, V, rustc_hash::FxBuildHasher>
where
K: Clone + Hash + Eq,
{
/// Place the hash table within a user-supplied fixed memory area.
pub fn with_fixed(num_buckets: u32, area: &'a mut [MaybeUninit<u8>]) -> Self {
Self::new(
num_buckets,
None,
area.as_mut_ptr().cast(),
area.len(),
rustc_hash::FxBuildHasher,
)
}
/// Place a new hash map in the given shared memory area
///
/// # Panics
/// Will panic on failure to resize area to expected map size.
pub fn with_shmem(num_buckets: u32, shmem: ShmemHandle) -> Self {
let size = Self::estimate_size(num_buckets);
shmem
.set_size(size)
.expect("could not resize shared memory area");
let ptr = shmem.data_ptr.as_ptr().cast();
Self::new(
num_buckets,
Some(shmem),
ptr,
size,
rustc_hash::FxBuildHasher,
)
}
/// Make a resizable hash map within a new shared memory area with the given name.
pub fn new_resizeable_named(num_buckets: u32, max_buckets: u32, name: &str) -> Self {
let size = Self::estimate_size(num_buckets);
let max_size = Self::estimate_size(max_buckets);
let shmem =
ShmemHandle::new(name, size, max_size).expect("failed to make shared memory area");
let ptr = shmem.data_ptr.as_ptr().cast();
Self::new(
num_buckets,
Some(shmem),
ptr,
size,
rustc_hash::FxBuildHasher,
)
}
/// Make a resizable hash map within a new anonymous shared memory area.
pub fn new_resizeable(num_buckets: u32, max_buckets: u32) -> Self {
use std::sync::atomic::{AtomicUsize, Ordering};
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let val = COUNTER.fetch_add(1, Ordering::Relaxed);
let name = format!("neon_shmem_hmap{val}");
Self::new_resizeable_named(num_buckets, max_buckets, &name)
}
}
impl<'a, K, V, S: BuildHasher> HashMapAccess<'a, K, V, S>
where
K: Clone + Hash + Eq,
{
/// Hash a key using the map's hasher.
#[inline]
fn get_hash_value(&self, key: &K) -> u64 {
self.hasher.hash_one(key)
}
fn entry_with_hash(&self, key: K, hash: u64) -> Entry<'a, '_, K, V> {
let mut map = unsafe { self.shared_ptr.as_ref() }.unwrap().write();
let dict_pos = hash as usize % map.dictionary.len();
let first = map.dictionary[dict_pos];
if first == INVALID_POS {
// no existing entry
return Entry::Vacant(VacantEntry {
map,
key,
dict_pos: dict_pos as u32,
});
}
let mut prev_pos = PrevPos::First(dict_pos as u32);
let mut next = first;
loop {
let bucket = &mut map.buckets[next as usize];
let (bucket_key, _bucket_value) = bucket.inner.as_mut().expect("entry is in use");
if *bucket_key == key {
// found existing entry
return Entry::Occupied(OccupiedEntry {
map,
_key: key,
prev_pos,
bucket_pos: next,
});
}
if bucket.next == INVALID_POS {
// No existing entry
return Entry::Vacant(VacantEntry {
map,
key,
dict_pos: dict_pos as u32,
});
}
prev_pos = PrevPos::Chained(next);
next = bucket.next;
}
}
/// Get a reference to the corresponding value for a key.
pub fn get<'e>(&'e self, key: &K) -> Option<ValueReadGuard<'e, V>> {
let hash = self.get_hash_value(key);
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
RwLockReadGuard::try_map(map, |m| m.get_with_hash(key, hash)).ok()
}
/// Get a reference to the entry containing a key.
///
/// NB: THis takes a write lock as there's no way to distinguish whether the intention
/// is to use the entry for reading or for writing in advance.
pub fn entry(&self, key: K) -> Entry<'a, '_, K, V> {
let hash = self.get_hash_value(&key);
self.entry_with_hash(key, hash)
}
/// Remove a key given its hash. Returns the associated value if it existed.
pub fn remove(&self, key: &K) -> Option<V> {
let hash = self.get_hash_value(key);
match self.entry_with_hash(key.clone(), hash) {
Entry::Occupied(e) => Some(e.remove()),
Entry::Vacant(_) => None,
}
}
/// Insert/update a key. Returns the previous associated value if it existed.
///
/// # Errors
/// Will return [`core::FullError`] if there is no more space left in the map.
pub fn insert(&self, key: K, value: V) -> Result<Option<V>, core::FullError> {
let hash = self.get_hash_value(&key);
match self.entry_with_hash(key.clone(), hash) {
Entry::Occupied(mut e) => Ok(Some(e.insert(value))),
Entry::Vacant(e) => {
_ = e.insert(value)?;
Ok(None)
}
}
}
/// Optionally return the entry for a bucket at a given index if it exists.
///
/// Has more overhead than one would intuitively expect: performs both a clone of the key
/// due to the [`OccupiedEntry`] type owning the key and also a hash of the key in order
/// to enable repairing the hash chain if the entry is removed.
pub fn entry_at_bucket(&self, pos: usize) -> Option<OccupiedEntry<'a, '_, K, V>> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
if pos >= map.buckets.len() {
return None;
}
let entry = map.buckets[pos].inner.as_ref();
match entry {
Some((key, _)) => Some(OccupiedEntry {
_key: key.clone(),
bucket_pos: pos as u32,
prev_pos: entry::PrevPos::Unknown(self.get_hash_value(key)),
map,
}),
_ => None,
}
}
/// Returns the number of buckets in the table.
pub fn get_num_buckets(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
map.get_num_buckets()
}
/// Return the key and value stored in bucket with given index. This can be used to
/// iterate through the hash map.
// TODO: An Iterator might be nicer. The communicator's clock algorithm needs to
// _slowly_ iterate through all buckets with its clock hand, without holding a lock.
// If we switch to an Iterator, it must not hold the lock.
pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<(K, V)>> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
if pos >= map.buckets.len() {
return None;
}
RwLockReadGuard::try_map(map, |m| m.buckets[pos].inner.as_ref()).ok()
}
/// Returns the index of the bucket a given value corresponds to.
pub fn get_bucket_for_value(&self, val_ptr: *const V) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
let origin = map.buckets.as_ptr();
let idx = (val_ptr as usize - origin as usize) / size_of::<Bucket<K, V>>();
assert!(idx < map.buckets.len());
idx
}
/// Returns the number of occupied buckets in the table.
pub fn get_num_buckets_in_use(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
map.buckets_in_use as usize
}
/// Clears all entries in a table. Does not reset any shrinking operations.
pub fn clear(&self) {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
map.clear();
}
/// Perform an in-place rehash of some region (0..`rehash_buckets`) of the table and reset
/// the `buckets` and `dictionary` slices to be as long as `num_buckets`. Resets the freelist
/// in the process.
fn rehash_dict(
&self,
inner: &mut CoreHashMap<'a, K, V>,
buckets_ptr: *mut core::Bucket<K, V>,
end_ptr: *mut u8,
num_buckets: u32,
rehash_buckets: u32,
) {
inner.free_head = INVALID_POS;
let buckets;
let dictionary;
unsafe {
let buckets_end_ptr = buckets_ptr.add(num_buckets as usize);
let dictionary_ptr: *mut u32 = buckets_end_ptr
.byte_add(buckets_end_ptr.align_offset(align_of::<u32>()))
.cast();
let dictionary_size: usize =
end_ptr.byte_offset_from(buckets_end_ptr) as usize / size_of::<u32>();
buckets = std::slice::from_raw_parts_mut(buckets_ptr, num_buckets as usize);
dictionary = std::slice::from_raw_parts_mut(dictionary_ptr, dictionary_size);
}
for e in dictionary.iter_mut() {
*e = INVALID_POS;
}
for (i, bucket) in buckets.iter_mut().enumerate().take(rehash_buckets as usize) {
if bucket.inner.is_none() {
bucket.next = inner.free_head;
inner.free_head = i as u32;
continue;
}
let hash = self.hasher.hash_one(&bucket.inner.as_ref().unwrap().0);
let pos: usize = (hash % dictionary.len() as u64) as usize;
bucket.next = dictionary[pos];
dictionary[pos] = i as u32;
}
inner.dictionary = dictionary;
inner.buckets = buckets;
}
/// Rehash the map without growing or shrinking.
pub fn shuffle(&self) {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
let num_buckets = map.get_num_buckets() as u32;
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
let end_ptr: *mut u8 = unsafe { self.shared_ptr.byte_add(size_bytes).cast() };
let buckets_ptr = map.buckets.as_mut_ptr();
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets);
}
/// Grow the number of buckets within the table.
///
/// 1. Grows the underlying shared memory area
/// 2. Initializes new buckets and overwrites the current dictionary
/// 3. Rehashes the dictionary
///
/// # Panics
/// Panics if called on a map initialized with [`HashMapInit::with_fixed`].
///
/// # Errors
/// Returns an [`shmem::Error`] if any errors occur resizing the memory region.
pub fn grow(&self, num_buckets: u32) -> Result<(), shmem::Error> {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
let old_num_buckets = map.buckets.len() as u32;
assert!(
num_buckets >= old_num_buckets,
"grow called with a smaller number of buckets"
);
if num_buckets == old_num_buckets {
return Ok(());
}
let shmem_handle = self
.shmem_handle
.as_ref()
.expect("grow called on a fixed-size hash table");
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
shmem_handle.set_size(size_bytes)?;
let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
// Initialize new buckets. The new buckets are linked to the free list.
// NB: This overwrites the dictionary!
let buckets_ptr = map.buckets.as_mut_ptr();
unsafe {
for i in old_num_buckets..num_buckets {
let bucket = buckets_ptr.add(i as usize);
bucket.write(core::Bucket {
next: if i < num_buckets - 1 {
i + 1
} else {
map.free_head
},
inner: None,
});
}
}
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, old_num_buckets);
map.free_head = old_num_buckets;
Ok(())
}
/// Begin a shrink, limiting all new allocations to be in buckets with index below `num_buckets`.
///
/// # Panics
/// Panics if called on a map initialized with [`HashMapInit::with_fixed`] or if `num_buckets` is
/// greater than the number of buckets in the map.
pub fn begin_shrink(&mut self, num_buckets: u32) {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
assert!(
num_buckets <= map.get_num_buckets() as u32,
"shrink called with a larger number of buckets"
);
_ = self
.shmem_handle
.as_ref()
.expect("shrink called on a fixed-size hash table");
map.alloc_limit = num_buckets;
}
/// If a shrink operation is underway, returns the target size of the map. Otherwise, returns None.
pub fn shrink_goal(&self) -> Option<usize> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap().read();
let goal = map.alloc_limit;
if goal == INVALID_POS {
None
} else {
Some(goal as usize)
}
}
/// Complete a shrink after caller has evicted entries, removing the unused buckets and rehashing.
///
/// # Panics
/// The following cases result in a panic:
/// - Calling this function on a map initialized with [`HashMapInit::with_fixed`].
/// - Calling this function on a map when no shrink operation is in progress.
pub fn finish_shrink(&self) -> Result<(), HashMapShrinkError> {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
assert!(
map.alloc_limit != INVALID_POS,
"called finish_shrink when no shrink is in progress"
);
let num_buckets = map.alloc_limit;
if map.get_num_buckets() == num_buckets as usize {
return Ok(());
}
assert!(
map.buckets_in_use <= num_buckets,
"called finish_shrink before enough entries were removed"
);
for i in (num_buckets as usize)..map.buckets.len() {
if map.buckets[i].inner.is_some() {
return Err(HashMapShrinkError::RemainingEntries(i));
}
}
let shmem_handle = self
.shmem_handle
.as_ref()
.expect("shrink called on a fixed-size hash table");
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
if let Err(e) = shmem_handle.set_size(size_bytes) {
return Err(HashMapShrinkError::ResizeError(e));
}
let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
let buckets_ptr = map.buckets.as_mut_ptr();
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets);
map.alloc_limit = INVALID_POS;
Ok(())
}
}

View File

@@ -1,174 +0,0 @@
//! Simple hash table with chaining.
use std::hash::Hash;
use std::mem::MaybeUninit;
use crate::hash::entry::*;
/// Invalid position within the map (either within the dictionary or bucket array).
pub(crate) const INVALID_POS: u32 = u32::MAX;
/// Fundamental storage unit within the hash table. Either empty or contains a key-value pair.
/// Always part of a chain of some kind (either a freelist if empty or a hash chain if full).
pub(crate) struct Bucket<K, V> {
/// Index of next bucket in the chain.
pub(crate) next: u32,
/// Key-value pair contained within bucket.
pub(crate) inner: Option<(K, V)>,
}
/// Core hash table implementation.
pub(crate) struct CoreHashMap<'a, K, V> {
/// Dictionary used to map hashes to bucket indices.
pub(crate) dictionary: &'a mut [u32],
/// Buckets containing key-value pairs.
pub(crate) buckets: &'a mut [Bucket<K, V>],
/// Head of the freelist.
pub(crate) free_head: u32,
/// Maximum index of a bucket allowed to be allocated. [`INVALID_POS`] if no limit.
pub(crate) alloc_limit: u32,
/// The number of currently occupied buckets.
pub(crate) buckets_in_use: u32,
}
/// Error for when there are no empty buckets left but one is needed.
#[derive(Debug, PartialEq)]
pub struct FullError;
impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
const FILL_FACTOR: f32 = 0.60;
/// Estimate the size of data contained within the the hash map.
pub fn estimate_size(num_buckets: u32) -> usize {
let mut size = 0;
// buckets
size += size_of::<Bucket<K, V>>() * num_buckets as usize;
// dictionary
size += (f32::ceil((size_of::<u32>() * num_buckets as usize) as f32 / Self::FILL_FACTOR))
as usize;
size
}
pub fn new(
buckets: &'a mut [MaybeUninit<Bucket<K, V>>],
dictionary: &'a mut [MaybeUninit<u32>],
) -> Self {
// Initialize the buckets
for i in 0..buckets.len() {
buckets[i].write(Bucket {
next: if i < buckets.len() - 1 {
i as u32 + 1
} else {
INVALID_POS
},
inner: None,
});
}
// Initialize the dictionary
for e in dictionary.iter_mut() {
e.write(INVALID_POS);
}
// TODO: use std::slice::assume_init_mut() once it stabilizes
let buckets =
unsafe { std::slice::from_raw_parts_mut(buckets.as_mut_ptr().cast(), buckets.len()) };
let dictionary = unsafe {
std::slice::from_raw_parts_mut(dictionary.as_mut_ptr().cast(), dictionary.len())
};
Self {
dictionary,
buckets,
free_head: 0,
buckets_in_use: 0,
alloc_limit: INVALID_POS,
}
}
/// Get the value associated with a key (if it exists) given its hash.
pub fn get_with_hash(&self, key: &K, hash: u64) -> Option<&V> {
let mut next = self.dictionary[hash as usize % self.dictionary.len()];
loop {
if next == INVALID_POS {
return None;
}
let bucket = &self.buckets[next as usize];
let (bucket_key, bucket_value) = bucket.inner.as_ref().expect("entry is in use");
if bucket_key == key {
return Some(bucket_value);
}
next = bucket.next;
}
}
/// Get number of buckets in map.
pub fn get_num_buckets(&self) -> usize {
self.buckets.len()
}
/// Clears all entries from the hashmap.
///
/// Does not reset any allocation limits, but does clear any entries beyond them.
pub fn clear(&mut self) {
for i in 0..self.buckets.len() {
self.buckets[i] = Bucket {
next: if i < self.buckets.len() - 1 {
i as u32 + 1
} else {
INVALID_POS
},
inner: None,
}
}
for i in 0..self.dictionary.len() {
self.dictionary[i] = INVALID_POS;
}
self.free_head = 0;
self.buckets_in_use = 0;
}
/// Find the position of an unused bucket via the freelist and initialize it.
pub(crate) fn alloc_bucket(&mut self, key: K, value: V) -> Result<u32, FullError> {
let mut pos = self.free_head;
// Find the first bucket we're *allowed* to use.
let mut prev = PrevPos::First(self.free_head);
while pos != INVALID_POS && pos >= self.alloc_limit {
let bucket = &mut self.buckets[pos as usize];
prev = PrevPos::Chained(pos);
pos = bucket.next;
}
if pos == INVALID_POS {
return Err(FullError);
}
// Repair the freelist.
match prev {
PrevPos::First(_) => {
let next_pos = self.buckets[pos as usize].next;
self.free_head = next_pos;
}
PrevPos::Chained(p) => {
if p != INVALID_POS {
let next_pos = self.buckets[pos as usize].next;
self.buckets[p as usize].next = next_pos;
}
}
_ => unreachable!(),
}
// Initialize the bucket.
let bucket = &mut self.buckets[pos as usize];
self.buckets_in_use += 1;
bucket.next = INVALID_POS;
bucket.inner = Some((key, value));
Ok(pos)
}
}

View File

@@ -1,130 +0,0 @@
//! Equivalent of [`std::collections::hash_map::Entry`] for this hashmap.
use crate::hash::core::{CoreHashMap, FullError, INVALID_POS};
use crate::sync::{RwLockWriteGuard, ValueWriteGuard};
use std::hash::Hash;
use std::mem;
pub enum Entry<'a, 'b, K, V> {
Occupied(OccupiedEntry<'a, 'b, K, V>),
Vacant(VacantEntry<'a, 'b, K, V>),
}
/// Enum representing the previous position within a chain.
#[derive(Clone, Copy)]
pub(crate) enum PrevPos {
/// Starting index within the dictionary.
First(u32),
/// Regular index within the buckets.
Chained(u32),
/// Unknown - e.g. the associated entry was retrieved by index instead of chain.
Unknown(u64),
}
pub struct OccupiedEntry<'a, 'b, K, V> {
/// Mutable reference to the map containing this entry.
pub(crate) map: RwLockWriteGuard<'b, CoreHashMap<'a, K, V>>,
/// The key of the occupied entry
pub(crate) _key: K,
/// The index of the previous entry in the chain.
pub(crate) prev_pos: PrevPos,
/// The position of the bucket in the [`CoreHashMap`] bucket array.
pub(crate) bucket_pos: u32,
}
impl<K, V> OccupiedEntry<'_, '_, K, V> {
pub fn get(&self) -> &V {
&self.map.buckets[self.bucket_pos as usize]
.inner
.as_ref()
.unwrap()
.1
}
pub fn get_mut(&mut self) -> &mut V {
&mut self.map.buckets[self.bucket_pos as usize]
.inner
.as_mut()
.unwrap()
.1
}
/// Inserts a value into the entry, replacing (and returning) the existing value.
pub fn insert(&mut self, value: V) -> V {
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
// This assumes inner is Some, which it must be for an OccupiedEntry
mem::replace(&mut bucket.inner.as_mut().unwrap().1, value)
}
/// Removes the entry from the hash map, returning the value originally stored within it.
///
/// This may result in multiple bucket accesses if the entry was obtained by index as the
/// previous chain entry needs to be discovered in this case.
pub fn remove(mut self) -> V {
// If this bucket was queried by index, go ahead and follow its chain from the start.
let prev = if let PrevPos::Unknown(hash) = self.prev_pos {
let dict_idx = hash as usize % self.map.dictionary.len();
let mut prev = PrevPos::First(dict_idx as u32);
let mut curr = self.map.dictionary[dict_idx];
while curr != self.bucket_pos {
assert!(curr != INVALID_POS);
prev = PrevPos::Chained(curr);
curr = self.map.buckets[curr as usize].next;
}
prev
} else {
self.prev_pos
};
// CoreHashMap::remove returns Option<(K, V)>. We know it's Some for an OccupiedEntry.
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
// unlink it from the chain
match prev {
PrevPos::First(dict_pos) => {
self.map.dictionary[dict_pos as usize] = bucket.next;
}
PrevPos::Chained(bucket_pos) => {
self.map.buckets[bucket_pos as usize].next = bucket.next;
}
_ => unreachable!(),
}
// and add it to the freelist
let free = self.map.free_head;
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
let old_value = bucket.inner.take();
bucket.next = free;
self.map.free_head = self.bucket_pos;
self.map.buckets_in_use -= 1;
old_value.unwrap().1
}
}
/// An abstract view into a vacant entry within the map.
pub struct VacantEntry<'a, 'b, K, V> {
/// Mutable reference to the map containing this entry.
pub(crate) map: RwLockWriteGuard<'b, CoreHashMap<'a, K, V>>,
/// The key to be inserted into this entry.
pub(crate) key: K,
/// The position within the dictionary corresponding to the key's hash.
pub(crate) dict_pos: u32,
}
impl<'b, K: Clone + Hash + Eq, V> VacantEntry<'_, 'b, K, V> {
/// Insert a value into the vacant entry, finding and populating an empty bucket in the process.
///
/// # Errors
/// Will return [`FullError`] if there are no unoccupied buckets in the map.
pub fn insert(mut self, value: V) -> Result<ValueWriteGuard<'b, V>, FullError> {
let pos = self.map.alloc_bucket(self.key, value)?;
self.map.buckets[pos as usize].next = self.map.dictionary[self.dict_pos as usize];
self.map.dictionary[self.dict_pos as usize] = pos;
Ok(RwLockWriteGuard::map(self.map, |m| {
&mut m.buckets[pos as usize].inner.as_mut().unwrap().1
}))
}
}

View File

@@ -1,428 +0,0 @@
use std::collections::BTreeMap;
use std::collections::HashSet;
use std::fmt::Debug;
use std::mem::MaybeUninit;
use crate::hash::Entry;
use crate::hash::HashMapAccess;
use crate::hash::HashMapInit;
use crate::hash::core::FullError;
use rand::seq::SliceRandom;
use rand::{Rng, RngCore};
use rand_distr::Zipf;
const TEST_KEY_LEN: usize = 16;
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct TestKey([u8; TEST_KEY_LEN]);
impl From<&TestKey> for u128 {
fn from(val: &TestKey) -> u128 {
u128::from_be_bytes(val.0)
}
}
impl From<u128> for TestKey {
fn from(val: u128) -> TestKey {
TestKey(val.to_be_bytes())
}
}
impl<'a> From<&'a [u8]> for TestKey {
fn from(bytes: &'a [u8]) -> TestKey {
TestKey(bytes.try_into().unwrap())
}
}
fn test_inserts<K: Into<TestKey> + Copy>(keys: &[K]) {
let w = HashMapInit::<TestKey, usize>::new_resizeable_named(100000, 120000, "test_inserts")
.attach_writer();
for (idx, k) in keys.iter().enumerate() {
let res = w.entry((*k).into());
match res {
Entry::Occupied(mut e) => {
e.insert(idx);
}
Entry::Vacant(e) => {
let res = e.insert(idx);
assert!(res.is_ok());
}
};
}
for (idx, k) in keys.iter().enumerate() {
let x = w.get(&(*k).into());
let value = x.as_deref().copied();
assert_eq!(value, Some(idx));
}
}
#[test]
fn dense() {
// This exercises splitting a node with prefix
let keys: &[u128] = &[0, 1, 2, 3, 256];
test_inserts(keys);
// Dense keys
let mut keys: Vec<u128> = (0..10000).collect();
test_inserts(&keys);
// Do the same in random orders
for _ in 1..10 {
keys.shuffle(&mut rand::rng());
test_inserts(&keys);
}
}
#[test]
fn sparse() {
// sparse keys
let mut keys: Vec<TestKey> = Vec::new();
let mut used_keys = HashSet::new();
for _ in 0..10000 {
loop {
let key = rand::random::<u128>();
if used_keys.contains(&key) {
continue;
}
used_keys.insert(key);
keys.push(key.into());
break;
}
}
test_inserts(&keys);
}
#[derive(Clone, Debug)]
struct TestOp(TestKey, Option<usize>);
fn apply_op(
op: &TestOp,
map: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
) {
// apply the change to the shadow tree first
let shadow_existing = if let Some(v) = op.1 {
shadow.insert(op.0, v)
} else {
shadow.remove(&op.0)
};
let entry = map.entry(op.0);
let hash_existing = match op.1 {
Some(new) => match entry {
Entry::Occupied(mut e) => Some(e.insert(new)),
Entry::Vacant(e) => {
_ = e.insert(new).unwrap();
None
}
},
None => match entry {
Entry::Occupied(e) => Some(e.remove()),
Entry::Vacant(_) => None,
},
};
assert_eq!(shadow_existing, hash_existing);
}
fn do_random_ops(
num_ops: usize,
size: u32,
del_prob: f64,
writer: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
rng: &mut rand::rngs::ThreadRng,
) {
for i in 0..num_ops {
let key: TestKey = ((rng.next_u32() % size) as u128).into();
let op = TestOp(
key,
if rng.random_bool(del_prob) {
Some(i)
} else {
None
},
);
apply_op(&op, writer, shadow);
}
}
fn do_deletes(
num_ops: usize,
writer: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
) {
for _ in 0..num_ops {
let (k, _) = shadow.pop_first().unwrap();
writer.remove(&k);
}
}
fn do_shrink(
writer: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
from: u32,
to: u32,
) {
assert!(writer.shrink_goal().is_none());
writer.begin_shrink(to);
assert_eq!(writer.shrink_goal(), Some(to as usize));
for i in to..from {
if let Some(entry) = writer.entry_at_bucket(i as usize) {
shadow.remove(&entry._key);
entry.remove();
}
}
let old_usage = writer.get_num_buckets_in_use();
writer.finish_shrink().unwrap();
assert!(writer.shrink_goal().is_none());
assert_eq!(writer.get_num_buckets_in_use(), old_usage);
}
#[test]
fn random_ops() {
let mut writer =
HashMapInit::<TestKey, usize>::new_resizeable_named(100000, 120000, "test_random")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let distribution = Zipf::new(u128::MAX as f64, 1.1).unwrap();
let mut rng = rand::rng();
for i in 0..100000 {
let key: TestKey = (rng.sample(distribution) as u128).into();
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &mut writer, &mut shadow);
}
}
#[test]
fn test_shuffle() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 1200, "test_shuf")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);
writer.shuffle();
do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);
}
#[test]
fn test_grow() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 2000, "test_grow")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);
let old_usage = writer.get_num_buckets_in_use();
writer.grow(1500).unwrap();
assert_eq!(writer.get_num_buckets_in_use(), old_usage);
assert_eq!(writer.get_num_buckets(), 1500);
do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
}
#[test]
fn test_clear() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_clear")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(2000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
writer.clear();
assert_eq!(writer.get_num_buckets_in_use(), 0);
assert_eq!(writer.get_num_buckets(), 1500);
while let Some((key, _)) = shadow.pop_first() {
assert!(writer.get(&key).is_none());
}
do_random_ops(2000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
for i in 0..(1500 - writer.get_num_buckets_in_use()) {
writer.insert((1500 + i as u128).into(), 0).unwrap();
}
assert_eq!(writer.insert(5000.into(), 0), Err(FullError {}));
writer.clear();
assert!(writer.insert(5000.into(), 0).is_ok());
}
#[test]
fn test_idx_remove() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_clear")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(2000, 1500, 0.25, &mut writer, &mut shadow, &mut rng);
for _ in 0..100 {
let idx = (rng.next_u32() % 1500) as usize;
if let Some(e) = writer.entry_at_bucket(idx) {
shadow.remove(&e._key);
e.remove();
}
}
while let Some((key, val)) = shadow.pop_first() {
assert_eq!(*writer.get(&key).unwrap(), val);
}
}
#[test]
fn test_idx_get() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_clear")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(2000, 1500, 0.25, &mut writer, &mut shadow, &mut rng);
for _ in 0..100 {
let idx = (rng.next_u32() % 1500) as usize;
if let Some(pair) = writer.get_at_bucket(idx) {
{
let v: *const usize = &pair.1;
assert_eq!(writer.get_bucket_for_value(v), idx);
}
{
let v: *const usize = &pair.1;
assert_eq!(writer.get_bucket_for_value(v), idx);
}
}
}
}
#[test]
fn test_shrink() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_shrink")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
do_shrink(&mut writer, &mut shadow, 1500, 1000);
assert_eq!(writer.get_num_buckets(), 1000);
do_deletes(500, &mut writer, &mut shadow);
do_random_ops(10000, 500, 0.75, &mut writer, &mut shadow, &mut rng);
assert!(writer.get_num_buckets_in_use() <= 1000);
}
#[test]
fn test_shrink_grow_seq() {
let mut writer =
HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 20000, "test_grow_seq")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(500, 1000, 0.1, &mut writer, &mut shadow, &mut rng);
eprintln!("Shrinking to 750");
do_shrink(&mut writer, &mut shadow, 1000, 750);
do_random_ops(200, 1000, 0.5, &mut writer, &mut shadow, &mut rng);
eprintln!("Growing to 1500");
writer.grow(1500).unwrap();
do_random_ops(600, 1500, 0.1, &mut writer, &mut shadow, &mut rng);
eprintln!("Shrinking to 200");
while shadow.len() > 100 {
do_deletes(1, &mut writer, &mut shadow);
}
do_shrink(&mut writer, &mut shadow, 1500, 200);
do_random_ops(50, 1500, 0.25, &mut writer, &mut shadow, &mut rng);
eprintln!("Growing to 10k");
writer.grow(10000).unwrap();
do_random_ops(10000, 5000, 0.25, &mut writer, &mut shadow, &mut rng);
}
#[test]
fn test_bucket_ops() {
let writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 1200, "test_bucket_ops")
.attach_writer();
match writer.entry(1.into()) {
Entry::Occupied(mut e) => {
e.insert(2);
}
Entry::Vacant(e) => {
_ = e.insert(2).unwrap();
}
}
assert_eq!(writer.get_num_buckets_in_use(), 1);
assert_eq!(writer.get_num_buckets(), 1000);
assert_eq!(*writer.get(&1.into()).unwrap(), 2);
let pos = match writer.entry(1.into()) {
Entry::Occupied(e) => {
assert_eq!(e._key, 1.into());
e.bucket_pos as usize
}
Entry::Vacant(_) => {
panic!("Insert didn't affect entry");
}
};
assert_eq!(writer.entry_at_bucket(pos).unwrap()._key, 1.into());
assert_eq!(*writer.get_at_bucket(pos).unwrap(), (1.into(), 2));
{
let ptr: *const usize = &*writer.get(&1.into()).unwrap();
assert_eq!(writer.get_bucket_for_value(ptr), pos);
}
writer.remove(&1.into());
assert!(writer.get(&1.into()).is_none());
}
#[test]
fn test_shrink_zero() {
let mut writer =
HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_shrink_zero")
.attach_writer();
writer.begin_shrink(0);
for i in 0..1500 {
writer.entry_at_bucket(i).map(|x| x.remove());
}
writer.finish_shrink().unwrap();
assert_eq!(writer.get_num_buckets_in_use(), 0);
let entry = writer.entry(1.into());
if let Entry::Vacant(v) = entry {
assert!(v.insert(2).is_err());
} else {
panic!("Somehow got non-vacant entry in empty map.")
}
writer.grow(50).unwrap();
let entry = writer.entry(1.into());
if let Entry::Vacant(v) = entry {
assert!(v.insert(2).is_ok());
} else {
panic!("Somehow got non-vacant entry in empty map.")
}
assert_eq!(writer.get_num_buckets_in_use(), 1);
}
#[test]
#[should_panic]
fn test_grow_oom() {
let writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_grow_oom")
.attach_writer();
writer.grow(20000).unwrap();
}
#[test]
#[should_panic]
fn test_shrink_bigger() {
let mut writer =
HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2500, "test_shrink_bigger")
.attach_writer();
writer.begin_shrink(2000);
}
#[test]
#[should_panic]
fn test_shrink_early_finish() {
let writer =
HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2500, "test_shrink_early_finish")
.attach_writer();
writer.finish_shrink().unwrap();
}
#[test]
#[should_panic]
fn test_shrink_fixed_size() {
let mut area = [MaybeUninit::uninit(); 10000];
let init_struct = HashMapInit::<TestKey, usize>::with_fixed(3, &mut area);
let mut writer = init_struct.attach_writer();
writer.begin_shrink(1);
}

View File

@@ -1,3 +1,418 @@
pub mod hash;
pub mod shmem;
pub mod sync;
//! Shared memory utilities for neon communicator
use std::num::NonZeroUsize;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use nix::errno::Errno;
use nix::sys::mman::MapFlags;
use nix::sys::mman::ProtFlags;
use nix::sys::mman::mmap as nix_mmap;
use nix::sys::mman::munmap as nix_munmap;
use nix::unistd::ftruncate as nix_ftruncate;
/// ShmemHandle represents a shared memory area that can be shared by processes over fork().
/// Unlike shared memory allocated by Postgres, this area is resizable, up to 'max_size' that's
/// specified at creation.
///
/// The area is backed by an anonymous file created with memfd_create(). The full address space for
/// 'max_size' is reserved up-front with mmap(), but whenever you call [`ShmemHandle::set_size`],
/// the underlying file is resized. Do not access the area beyond the current size. Currently, that
/// will cause the file to be expanded, but we might use mprotect() etc. to enforce that in the
/// future.
pub struct ShmemHandle {
/// memfd file descriptor
fd: OwnedFd,
max_size: usize,
// Pointer to the beginning of the shared memory area. The header is stored there.
shared_ptr: NonNull<SharedStruct>,
// Pointer to the beginning of the user data
pub data_ptr: NonNull<u8>,
}
/// This is stored at the beginning in the shared memory area.
struct SharedStruct {
max_size: usize,
/// Current size of the backing file. The high-order bit is used for the RESIZE_IN_PROGRESS flag
current_size: AtomicUsize,
}
const RESIZE_IN_PROGRESS: usize = 1 << 63;
const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();
/// Error type returned by the ShmemHandle functions.
#[derive(thiserror::Error, Debug)]
#[error("{msg}: {errno}")]
pub struct Error {
pub msg: String,
pub errno: Errno,
}
impl Error {
fn new(msg: &str, errno: Errno) -> Error {
Error {
msg: msg.to_string(),
errno,
}
}
}
impl ShmemHandle {
/// Create a new shared memory area. To communicate between processes, the processes need to be
/// fork()'d after calling this, so that the ShmemHandle is inherited by all processes.
///
/// If the ShmemHandle is dropped, the memory is unmapped from the current process. Other
/// processes can continue using it, however.
pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<ShmemHandle, Error> {
// create the backing anonymous file.
let fd = create_backing_file(name)?;
Self::new_with_fd(fd, initial_size, max_size)
}
fn new_with_fd(
fd: OwnedFd,
initial_size: usize,
max_size: usize,
) -> Result<ShmemHandle, Error> {
// We reserve the high-order bit for the RESIZE_IN_PROGRESS flag, and the actual size
// is a little larger than this because of the SharedStruct header. Make the upper limit
// somewhat smaller than that, because with anything close to that, you'll run out of
// memory anyway.
if max_size >= 1 << 48 {
panic!("max size {max_size} too large");
}
if initial_size > max_size {
panic!("initial size {initial_size} larger than max size {max_size}");
}
// The actual initial / max size is the one given by the caller, plus the size of
// 'SharedStruct'.
let initial_size = HEADER_SIZE + initial_size;
let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();
// Reserve address space for it with mmap
//
// TODO: Use MAP_HUGETLB if possible
let start_ptr = unsafe {
nix_mmap(
None,
max_size,
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
MapFlags::MAP_SHARED,
&fd,
0,
)
}
.map_err(|e| Error::new("mmap failed: {e}", e))?;
// Reserve space for the initial size
enlarge_file(fd.as_fd(), initial_size as u64)?;
// Initialize the header
let shared: NonNull<SharedStruct> = start_ptr.cast();
unsafe {
shared.write(SharedStruct {
max_size: max_size.into(),
current_size: AtomicUsize::new(initial_size),
})
};
// The user data begins after the header
let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };
Ok(ShmemHandle {
fd,
max_size: max_size.into(),
shared_ptr: shared,
data_ptr,
})
}
// return reference to the header
fn shared(&self) -> &SharedStruct {
unsafe { self.shared_ptr.as_ref() }
}
/// Resize the shared memory area. 'new_size' must not be larger than the 'max_size' specified
/// when creating the area.
///
/// This may only be called from one process/thread concurrently. We detect that case
/// and return an Error.
pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
let new_size = new_size + HEADER_SIZE;
let shared = self.shared();
if new_size > self.max_size {
panic!(
"new size ({} is greater than max size ({})",
new_size, self.max_size
);
}
assert_eq!(self.max_size, shared.max_size);
// Lock the area by setting the bit in 'current_size'
//
// Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
// and the posix_fallocate/ftruncate call is surely a synchronization point anyway. But
// since this is not performance-critical, better safe than sorry .
let mut old_size = shared.current_size.load(Ordering::Acquire);
loop {
if (old_size & RESIZE_IN_PROGRESS) != 0 {
return Err(Error::new(
"concurrent resize detected",
Errno::UnknownErrno,
));
}
match shared.current_size.compare_exchange(
old_size,
new_size,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(x) => old_size = x,
}
}
// Ok, we got the lock.
//
// NB: If anything goes wrong, we *must* clear the bit!
let result = {
use std::cmp::Ordering::{Equal, Greater, Less};
match new_size.cmp(&old_size) {
Less => nix_ftruncate(&self.fd, new_size as i64).map_err(|e| {
Error::new("could not shrink shmem segment, ftruncate failed: {e}", e)
}),
Equal => Ok(()),
Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
}
};
// Unlock
shared.current_size.store(
if result.is_ok() { new_size } else { old_size },
Ordering::Release,
);
result
}
/// Returns the current user-visible size of the shared memory segment.
///
/// NOTE: a concurrent set_size() call can change the size at any time. It is the caller's
/// responsibility not to access the area beyond the current size.
pub fn current_size(&self) -> usize {
let total_current_size =
self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
total_current_size - HEADER_SIZE
}
}
impl Drop for ShmemHandle {
fn drop(&mut self) {
// SAFETY: The pointer was obtained from mmap() with the given size.
// We unmap the entire region.
let _ = unsafe { nix_munmap(self.shared_ptr.cast(), self.max_size) };
// The fd is dropped automatically by OwnedFd.
}
}
/// Create a "backing file" for the shared memory area. On Linux, use memfd_create(), to create an
/// anonymous in-memory file. One macos, fall back to a regular file. That's good enough for
/// development and testing, but in production we want the file to stay in memory.
///
/// disable 'unused_variables' warnings, because in the macos path, 'name' is unused.
#[allow(unused_variables)]
fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
#[cfg(not(target_os = "macos"))]
{
nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
.map_err(|e| Error::new("memfd_create failed: {e}", e))
}
#[cfg(target_os = "macos")]
{
let file = tempfile::tempfile().map_err(|e| {
Error::new(
"could not create temporary file to back shmem area: {e}",
nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
)
})?;
Ok(OwnedFd::from(file))
}
}
fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
// Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
// we don't get a segfault later when trying to actually use it.
#[cfg(not(target_os = "macos"))]
{
nix::fcntl::posix_fallocate(fd, 0, size as i64).map_err(|e| {
Error::new(
"could not grow shmem segment, posix_fallocate failed: {e}",
e,
)
})
}
// As a fallback on macos, which doesn't have posix_fallocate, use plain 'fallocate'
#[cfg(target_os = "macos")]
{
nix::unistd::ftruncate(fd, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, ftruncate failed: {e}", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
use nix::unistd::ForkResult;
use std::ops::Range;
/// check that all bytes in given range have the expected value.
fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
for i in range {
let b = unsafe { *(ptr.add(i)) };
assert_eq!(expected, b, "unexpected byte at offset {i}");
}
}
/// Write 'b' to all bytes in the given range
fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
}
// simple single-process test of growing and shrinking
#[test]
fn test_shmem_resize() -> Result<(), Error> {
let max_size = 1024 * 1024;
let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;
assert_eq!(init_struct.current_size(), 0);
// Initial grow
let size1 = 10000;
init_struct.set_size(size1).unwrap();
assert_eq!(init_struct.current_size(), size1);
// Write some data
let data_ptr = init_struct.data_ptr.as_ptr();
write_range(data_ptr, 0xAA, 0..size1);
assert_range(data_ptr, 0xAA, 0..size1);
// Shrink
let size2 = 5000;
init_struct.set_size(size2).unwrap();
assert_eq!(init_struct.current_size(), size2);
// Grow again
let size3 = 20000;
init_struct.set_size(size3).unwrap();
assert_eq!(init_struct.current_size(), size3);
// Try to read it. The area that was shrunk and grown again should read as all zeros now
assert_range(data_ptr, 0xAA, 0..5000);
assert_range(data_ptr, 0, 5000..size1);
// Try to grow beyond max_size
//let size4 = max_size + 1;
//assert!(init_struct.set_size(size4).is_err());
// Dropping init_struct should unmap the memory
drop(init_struct);
Ok(())
}
/// This is used in tests to coordinate between test processes. It's like std::sync::Barrier,
/// but is stored in the shared memory area and works across processes. It's implemented by
/// polling, because e.g. standard rust mutexes are not guaranteed to work across processes.
struct SimpleBarrier {
num_procs: usize,
count: AtomicUsize,
}
impl SimpleBarrier {
unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
unsafe {
*ptr = SimpleBarrier {
num_procs,
count: AtomicUsize::new(0),
}
}
}
pub fn wait(&self) {
let old = self.count.fetch_add(1, Ordering::Relaxed);
let generation = old / self.num_procs;
let mut current = old + 1;
while current < (generation + 1) * self.num_procs {
std::thread::sleep(std::time::Duration::from_millis(10));
current = self.count.load(Ordering::Relaxed);
}
}
}
#[test]
fn test_multi_process() {
// Initialize
let max_size = 1_000_000_000_000;
let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
let ptr = init_struct.data_ptr.as_ptr();
// Store the SimpleBarrier in the first 1k of the area.
init_struct.set_size(10000).unwrap();
let barrier_ptr: *mut SimpleBarrier = unsafe {
ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
.cast()
};
unsafe { SimpleBarrier::init(barrier_ptr, 2) };
let barrier = unsafe { barrier_ptr.as_ref().unwrap() };
// Fork another test process. The code after this runs in both processes concurrently.
let fork_result = unsafe { nix::unistd::fork().unwrap() };
// In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
if fork_result.is_parent() {
write_range(ptr, 0xAA, 1000..2000);
} else {
write_range(ptr, 0xBB, 2000..3000);
}
barrier.wait();
// Verify the contents. (in both processes)
assert_range(ptr, 0xAA, 1000..2000);
assert_range(ptr, 0xBB, 2000..3000);
// Grow, from the child this time
let size = 10_000_000;
if !fork_result.is_parent() {
init_struct.set_size(size).unwrap();
}
barrier.wait();
// make some writes at the end
if fork_result.is_parent() {
write_range(ptr, 0xAA, (size - 10)..size);
} else {
write_range(ptr, 0xBB, (size - 20)..(size - 10));
}
barrier.wait();
// Verify the contents. (This runs in both processes)
assert_range(ptr, 0, (size - 1000)..(size - 20));
assert_range(ptr, 0xBB, (size - 20)..(size - 10));
assert_range(ptr, 0xAA, (size - 10)..size);
if let ForkResult::Parent { child } = fork_result {
nix::sys::wait::waitpid(child, None).unwrap();
}
}
}

View File

@@ -1,409 +0,0 @@
//! Dynamically resizable contiguous chunk of shared memory
use std::num::NonZeroUsize;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use nix::errno::Errno;
use nix::sys::mman::MapFlags;
use nix::sys::mman::ProtFlags;
use nix::sys::mman::mmap as nix_mmap;
use nix::sys::mman::munmap as nix_munmap;
use nix::unistd::ftruncate as nix_ftruncate;
/// `ShmemHandle` represents a shared memory area that can be shared by processes over `fork()`.
/// Unlike shared memory allocated by Postgres, this area is resizable, up to `max_size` that's
/// specified at creation.
///
/// The area is backed by an anonymous file created with `memfd_create()`. The full address space for
/// `max_size` is reserved up-front with `mmap()`, but whenever you call [`ShmemHandle::set_size`],
/// the underlying file is resized. Do not access the area beyond the current size. Currently, that
/// will cause the file to be expanded, but we might use `mprotect()` etc. to enforce that in the
/// future.
pub struct ShmemHandle {
/// memfd file descriptor
fd: OwnedFd,
max_size: usize,
// Pointer to the beginning of the shared memory area. The header is stored there.
shared_ptr: NonNull<SharedStruct>,
// Pointer to the beginning of the user data
pub data_ptr: NonNull<u8>,
}
/// This is stored at the beginning in the shared memory area.
struct SharedStruct {
max_size: usize,
/// Current size of the backing file. The high-order bit is used for the [`RESIZE_IN_PROGRESS`] flag.
current_size: AtomicUsize,
}
const RESIZE_IN_PROGRESS: usize = 1 << 63;
const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();
/// Error type returned by the [`ShmemHandle`] functions.
#[derive(thiserror::Error, Debug)]
#[error("{msg}: {errno}")]
pub struct Error {
pub msg: String,
pub errno: Errno,
}
impl Error {
fn new(msg: &str, errno: Errno) -> Self {
Self {
msg: msg.to_string(),
errno,
}
}
}
impl ShmemHandle {
/// Create a new shared memory area. To communicate between processes, the processes need to be
/// `fork()`'d after calling this, so that the `ShmemHandle` is inherited by all processes.
///
/// If the `ShmemHandle` is dropped, the memory is unmapped from the current process. Other
/// processes can continue using it, however.
pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<Self, Error> {
// create the backing anonymous file.
let fd = create_backing_file(name)?;
Self::new_with_fd(fd, initial_size, max_size)
}
fn new_with_fd(fd: OwnedFd, initial_size: usize, max_size: usize) -> Result<Self, Error> {
// We reserve the high-order bit for the `RESIZE_IN_PROGRESS` flag, and the actual size
// is a little larger than this because of the SharedStruct header. Make the upper limit
// somewhat smaller than that, because with anything close to that, you'll run out of
// memory anyway.
assert!(max_size < 1 << 48, "max size {max_size} too large");
assert!(
initial_size <= max_size,
"initial size {initial_size} larger than max size {max_size}"
);
// The actual initial / max size is the one given by the caller, plus the size of
// 'SharedStruct'.
let initial_size = HEADER_SIZE + initial_size;
let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();
// Reserve address space for it with mmap
//
// TODO: Use MAP_HUGETLB if possible
let start_ptr = unsafe {
nix_mmap(
None,
max_size,
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
MapFlags::MAP_SHARED,
&fd,
0,
)
}
.map_err(|e| Error::new("mmap failed", e))?;
// Reserve space for the initial size
enlarge_file(fd.as_fd(), initial_size as u64)?;
// Initialize the header
let shared: NonNull<SharedStruct> = start_ptr.cast();
unsafe {
shared.write(SharedStruct {
max_size: max_size.into(),
current_size: AtomicUsize::new(initial_size),
});
}
// The user data begins after the header
let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };
Ok(Self {
fd,
max_size: max_size.into(),
shared_ptr: shared,
data_ptr,
})
}
// return reference to the header
fn shared(&self) -> &SharedStruct {
unsafe { self.shared_ptr.as_ref() }
}
/// Resize the shared memory area. `new_size` must not be larger than the `max_size` specified
/// when creating the area.
///
/// This may only be called from one process/thread concurrently. We detect that case
/// and return an [`shmem::Error`](Error).
pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
let new_size = new_size + HEADER_SIZE;
let shared = self.shared();
assert!(
new_size <= self.max_size,
"new size ({new_size}) is greater than max size ({})",
self.max_size
);
assert_eq!(self.max_size, shared.max_size);
// Lock the area by setting the bit in `current_size`
//
// Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
// and the `posix_fallocate`/`ftruncate` call is surely a synchronization point anyway. But
// since this is not performance-critical, better safe than sorry.
let mut old_size = shared.current_size.load(Ordering::Acquire);
loop {
if (old_size & RESIZE_IN_PROGRESS) != 0 {
return Err(Error::new(
"concurrent resize detected",
Errno::UnknownErrno,
));
}
match shared.current_size.compare_exchange(
old_size,
new_size,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(x) => old_size = x,
}
}
// Ok, we got the lock.
//
// NB: If anything goes wrong, we *must* clear the bit!
let result = {
use std::cmp::Ordering::{Equal, Greater, Less};
match new_size.cmp(&old_size) {
Less => nix_ftruncate(&self.fd, new_size as i64)
.map_err(|e| Error::new("could not shrink shmem segment, ftruncate failed", e)),
Equal => Ok(()),
Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
}
};
// Unlock
shared.current_size.store(
if result.is_ok() { new_size } else { old_size },
Ordering::Release,
);
result
}
/// Returns the current user-visible size of the shared memory segment.
///
/// NOTE: a concurrent [`ShmemHandle::set_size()`] call can change the size at any time.
/// It is the caller's responsibility not to access the area beyond the current size.
pub fn current_size(&self) -> usize {
let total_current_size =
self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
total_current_size - HEADER_SIZE
}
}
impl Drop for ShmemHandle {
fn drop(&mut self) {
// SAFETY: The pointer was obtained from mmap() with the given size.
// We unmap the entire region.
let _ = unsafe { nix_munmap(self.shared_ptr.cast(), self.max_size) };
// The fd is dropped automatically by OwnedFd.
}
}
/// Create a "backing file" for the shared memory area. On Linux, use `memfd_create()`, to create an
/// anonymous in-memory file. One macos, fall back to a regular file. That's good enough for
/// development and testing, but in production we want the file to stay in memory.
///
/// Disable unused variables warnings because `name` is unused in the macos path.
#[allow(unused_variables)]
fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
#[cfg(not(target_os = "macos"))]
{
nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
.map_err(|e| Error::new("memfd_create failed", e))
}
#[cfg(target_os = "macos")]
{
let file = tempfile::tempfile().map_err(|e| {
Error::new(
"could not create temporary file to back shmem area",
nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
)
})?;
Ok(OwnedFd::from(file))
}
}
fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
// Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
// we don't get a segfault later when trying to actually use it.
#[cfg(not(target_os = "macos"))]
{
nix::fcntl::posix_fallocate(fd, 0, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, posix_fallocate failed", e))
}
// As a fallback on macos, which doesn't have posix_fallocate, use plain 'fallocate'
#[cfg(target_os = "macos")]
{
nix::unistd::ftruncate(fd, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, ftruncate failed", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
use nix::unistd::ForkResult;
use std::ops::Range;
/// check that all bytes in given range have the expected value.
fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
for i in range {
let b = unsafe { *(ptr.add(i)) };
assert_eq!(expected, b, "unexpected byte at offset {i}");
}
}
/// Write 'b' to all bytes in the given range
fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
}
// simple single-process test of growing and shrinking
#[test]
fn test_shmem_resize() -> Result<(), Error> {
let max_size = 1024 * 1024;
let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;
assert_eq!(init_struct.current_size(), 0);
// Initial grow
let size1 = 10000;
init_struct.set_size(size1).unwrap();
assert_eq!(init_struct.current_size(), size1);
// Write some data
let data_ptr = init_struct.data_ptr.as_ptr();
write_range(data_ptr, 0xAA, 0..size1);
assert_range(data_ptr, 0xAA, 0..size1);
// Shrink
let size2 = 5000;
init_struct.set_size(size2).unwrap();
assert_eq!(init_struct.current_size(), size2);
// Grow again
let size3 = 20000;
init_struct.set_size(size3).unwrap();
assert_eq!(init_struct.current_size(), size3);
// Try to read it. The area that was shrunk and grown again should read as all zeros now
assert_range(data_ptr, 0xAA, 0..5000);
assert_range(data_ptr, 0, 5000..size1);
// Try to grow beyond max_size
//let size4 = max_size + 1;
//assert!(init_struct.set_size(size4).is_err());
// Dropping init_struct should unmap the memory
drop(init_struct);
Ok(())
}
/// This is used in tests to coordinate between test processes. It's like `std::sync::Barrier`,
/// but is stored in the shared memory area and works across processes. It's implemented by
/// polling, because e.g. standard rust mutexes are not guaranteed to work across processes.
struct SimpleBarrier {
num_procs: usize,
count: AtomicUsize,
}
impl SimpleBarrier {
unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
unsafe {
*ptr = SimpleBarrier {
num_procs,
count: AtomicUsize::new(0),
}
}
}
pub fn wait(&self) {
let old = self.count.fetch_add(1, Ordering::Relaxed);
let generation = old / self.num_procs;
let mut current = old + 1;
while current < (generation + 1) * self.num_procs {
std::thread::sleep(std::time::Duration::from_millis(10));
current = self.count.load(Ordering::Relaxed);
}
}
}
#[test]
fn test_multi_process() {
// Initialize
let max_size = 1_000_000_000_000;
let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
let ptr = init_struct.data_ptr.as_ptr();
// Store the SimpleBarrier in the first 1k of the area.
init_struct.set_size(10000).unwrap();
let barrier_ptr: *mut SimpleBarrier = unsafe {
ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
.cast()
};
unsafe { SimpleBarrier::init(barrier_ptr, 2) };
let barrier = unsafe { barrier_ptr.as_ref().unwrap() };
// Fork another test process. The code after this runs in both processes concurrently.
let fork_result = unsafe { nix::unistd::fork().unwrap() };
// In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
if fork_result.is_parent() {
write_range(ptr, 0xAA, 1000..2000);
} else {
write_range(ptr, 0xBB, 2000..3000);
}
barrier.wait();
// Verify the contents. (in both processes)
assert_range(ptr, 0xAA, 1000..2000);
assert_range(ptr, 0xBB, 2000..3000);
// Grow, from the child this time
let size = 10_000_000;
if !fork_result.is_parent() {
init_struct.set_size(size).unwrap();
}
barrier.wait();
// make some writes at the end
if fork_result.is_parent() {
write_range(ptr, 0xAA, (size - 10)..size);
} else {
write_range(ptr, 0xBB, (size - 20)..(size - 10));
}
barrier.wait();
// Verify the contents. (This runs in both processes)
assert_range(ptr, 0, (size - 1000)..(size - 20));
assert_range(ptr, 0xBB, (size - 20)..(size - 10));
assert_range(ptr, 0xAA, (size - 10)..size);
if let ForkResult::Parent { child } = fork_result {
nix::sys::wait::waitpid(child, None).unwrap();
}
}
}

View File

@@ -1,111 +0,0 @@
//! Simple utilities akin to what's in [`std::sync`] but designed to work with shared memory.
use std::mem::MaybeUninit;
use std::ptr::NonNull;
use nix::errno::Errno;
pub type RwLock<T> = lock_api::RwLock<PthreadRwLock, T>;
pub type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, PthreadRwLock, T>;
pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, PthreadRwLock, T>;
pub type ValueReadGuard<'a, T> = lock_api::MappedRwLockReadGuard<'a, PthreadRwLock, T>;
pub type ValueWriteGuard<'a, T> = lock_api::MappedRwLockWriteGuard<'a, PthreadRwLock, T>;
/// Shared memory read-write lock.
pub struct PthreadRwLock(Option<NonNull<libc::pthread_rwlock_t>>);
/// Simple macro that calls a function in the libc namespace and panics if return value is nonzero.
macro_rules! libc_checked {
($fn_name:ident ( $($arg:expr),* )) => {{
let res = libc::$fn_name($($arg),*);
if res != 0 {
panic!("{} failed with {}", stringify!($fn_name), Errno::from_raw(res));
}
}};
}
impl PthreadRwLock {
/// Creates a new `PthreadRwLock` on top of a pointer to a pthread rwlock.
///
/// # Safety
/// `lock` must be non-null. Every unsafe operation will panic in the event of an error.
pub unsafe fn new(lock: *mut libc::pthread_rwlock_t) -> Self {
unsafe {
let mut attrs = MaybeUninit::uninit();
libc_checked!(pthread_rwlockattr_init(attrs.as_mut_ptr()));
libc_checked!(pthread_rwlockattr_setpshared(
attrs.as_mut_ptr(),
libc::PTHREAD_PROCESS_SHARED
));
libc_checked!(pthread_rwlock_init(lock, attrs.as_mut_ptr()));
// Safety: POSIX specifies that "any function affecting the attributes
// object (including destruction) shall not affect any previously
// initialized read-write locks".
libc_checked!(pthread_rwlockattr_destroy(attrs.as_mut_ptr()));
Self(Some(NonNull::new_unchecked(lock)))
}
}
fn inner(&self) -> NonNull<libc::pthread_rwlock_t> {
match self.0 {
None => {
panic!("PthreadRwLock constructed badly - something likely used RawRwLock::INIT")
}
Some(x) => x,
}
}
}
unsafe impl lock_api::RawRwLock for PthreadRwLock {
type GuardMarker = lock_api::GuardSend;
const INIT: Self = Self(None);
fn try_lock_shared(&self) -> bool {
unsafe {
let res = libc::pthread_rwlock_tryrdlock(self.inner().as_ptr());
match res {
0 => true,
libc::EAGAIN => false,
_ => panic!(
"pthread_rwlock_tryrdlock failed with {}",
Errno::from_raw(res)
),
}
}
}
fn try_lock_exclusive(&self) -> bool {
unsafe {
let res = libc::pthread_rwlock_trywrlock(self.inner().as_ptr());
match res {
0 => true,
libc::EAGAIN => false,
_ => panic!("try_wrlock failed with {}", Errno::from_raw(res)),
}
}
}
fn lock_shared(&self) {
unsafe {
libc_checked!(pthread_rwlock_rdlock(self.inner().as_ptr()));
}
}
fn lock_exclusive(&self) {
unsafe {
libc_checked!(pthread_rwlock_wrlock(self.inner().as_ptr()));
}
}
unsafe fn unlock_exclusive(&self) {
unsafe {
libc_checked!(pthread_rwlock_unlock(self.inner().as_ptr()));
}
}
unsafe fn unlock_shared(&self) {
unsafe {
libc_checked!(pthread_rwlock_unlock(self.inner().as_ptr()));
}
}
}

View File

@@ -1,6 +1,5 @@
use std::collections::{HashMap, HashSet};
use std::fmt::Display;
use std::net::IpAddr;
use std::str::FromStr;
use std::time::{Duration, Instant};
@@ -61,11 +60,6 @@ pub struct NodeRegisterRequest {
pub listen_https_port: Option<u16>,
pub availability_zone_id: AvailabilityZone,
// Reachable IP address of the PS/SK registering, if known.
// Hadron Cluster Coordiantor will update the DNS record of the registering node
// with this IP address.
pub node_ip_addr: Option<IpAddr>,
}
#[derive(Serialize, Deserialize)]
@@ -551,39 +545,6 @@ pub struct SafekeeperDescribeResponse {
pub scheduling_policy: SkSchedulingPolicy,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct TimelineSafekeeperPeer {
pub node_id: NodeId,
pub listen_http_addr: String,
pub http_port: i32,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct SCSafekeeperTimeline {
// SC does not know the tenant id.
pub timeline_id: TimelineId,
pub peers: Vec<NodeId>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct SCSafekeeperTimelinesResponse {
pub timelines: Vec<SCSafekeeperTimeline>,
pub safekeeper_peers: Vec<TimelineSafekeeperPeer>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct SafekeeperTimeline {
pub tenant_id: TenantId,
pub timeline_id: TimelineId,
pub peers: Vec<NodeId>,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct SafekeeperTimelinesResponse {
pub timelines: Vec<SafekeeperTimeline>,
pub safekeeper_peers: Vec<TimelineSafekeeperPeer>,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct SafekeeperSchedulingPolicyRequest {
pub scheduling_policy: SkSchedulingPolicy,

View File

@@ -749,18 +749,7 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
trace!("got query {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
match e {
err @ QueryError::Shutdown => {
// Notify postgres of the connection shutdown at the libpq
// protocol level. This avoids postgres having to tell apart
// from an idle connection and a stale one, which is bug prone.
let shutdown_error = short_error(&err);
self.write_message_noflush(&BeMessage::ErrorResponse(
&shutdown_error,
Some(err.pg_error_code()),
))?;
return Ok(ProcessMsgResult::Break);
}
QueryError::Shutdown => return Ok(ProcessMsgResult::Break),
QueryError::SimulatedConnectionError => {
return Err(QueryError::SimulatedConnectionError);
}

View File

@@ -110,6 +110,7 @@ fn main() -> anyhow::Result<()> {
.allowlist_type("XLogRecPtr")
.allowlist_type("XLogSegNo")
.allowlist_type("TimeLineID")
.allowlist_type("TimestampTz")
.allowlist_type("MultiXactId")
.allowlist_type("MultiXactOffset")
.allowlist_type("MultiXactStatus")

View File

@@ -227,7 +227,8 @@ pub mod walrecord;
// Export some widely used datatypes that are unlikely to change across Postgres versions
pub use v14::bindings::{
BlockNumber, CheckPoint, ControlFileData, MultiXactId, OffsetNumber, Oid, PageHeaderData,
RepOriginId, TimeLineID, TransactionId, XLogRecPtr, XLogRecord, XLogSegNo, uint32, uint64,
RepOriginId, TimeLineID, TimestampTz, TransactionId, XLogRecPtr, XLogRecord, XLogSegNo, uint32,
uint64,
};
// Likewise for these, although the assumption that these don't change is a little more iffy.
pub use v14::bindings::{MultiXactOffset, MultiXactStatus};

View File

@@ -4,14 +4,13 @@
//! TODO: Generate separate types for each supported PG version
use bytes::{Buf, Bytes};
use postgres_ffi_types::TimestampTz;
use serde::{Deserialize, Serialize};
use utils::bin_ser::DeserializeError;
use utils::lsn::Lsn;
use crate::{
BLCKSZ, BlockNumber, MultiXactId, MultiXactOffset, MultiXactStatus, Oid, PgMajorVersion,
RepOriginId, TransactionId, XLOG_SIZE_OF_XLOG_RECORD, XLogRecord, pg_constants,
RepOriginId, TimestampTz, TransactionId, XLOG_SIZE_OF_XLOG_RECORD, XLogRecord, pg_constants,
};
#[repr(C)]
@@ -864,8 +863,7 @@ pub mod v17 {
XlHeapDelete, XlHeapInsert, XlHeapLock, XlHeapMultiInsert, XlHeapUpdate, XlParameterChange,
rm_neon,
};
pub use crate::TimeLineID;
pub use postgres_ffi_types::TimestampTz;
pub use crate::{TimeLineID, TimestampTz};
#[repr(C)]
#[derive(Debug)]

View File

@@ -9,11 +9,10 @@
use super::super::waldecoder::WalStreamDecoder;
use super::bindings::{
CheckPoint, ControlFileData, DBState_DB_SHUTDOWNED, FullTransactionId, TimeLineID,
CheckPoint, ControlFileData, DBState_DB_SHUTDOWNED, FullTransactionId, TimeLineID, TimestampTz,
XLogLongPageHeaderData, XLogPageHeaderData, XLogRecPtr, XLogRecord, XLogSegNo, XLOG_PAGE_MAGIC,
MY_PGVERSION
};
use postgres_ffi_types::TimestampTz;
use super::wal_generator::LogicalMessageGenerator;
use crate::pg_constants;
use crate::PG_TLI;

View File

@@ -11,4 +11,3 @@ pub mod forknum;
pub type Oid = u32;
pub type RepOriginId = u16;
pub type TimestampTz = i64;

View File

@@ -31,7 +31,6 @@ pub struct UnreliableWrapper {
/* BEGIN_HADRON */
// This the probability of failure for each operation, ranged from [0, 100].
// The probability is default to 100, which means that all operations will fail.
// Storage will fail by probability up to attempts_to_fail times.
attempt_failure_probability: u64,
/* END_HADRON */
}

View File

@@ -9,7 +9,7 @@ anyhow.workspace = true
const_format.workspace = true
serde.workspace = true
serde_json.workspace = true
postgres_ffi_types.workspace = true
postgres_ffi.workspace = true
postgres_versioninfo.workspace = true
pq_proto.workspace = true
tokio.workspace = true

View File

@@ -3,7 +3,7 @@
use std::net::SocketAddr;
use pageserver_api::shard::ShardIdentity;
use postgres_ffi_types::TimestampTz;
use postgres_ffi::TimestampTz;
use postgres_versioninfo::PgVersionId;
use serde::{Deserialize, Serialize};
use tokio::time::Instant;

View File

@@ -47,7 +47,6 @@ tracing-subscriber = { workspace = true, features = ["json", "registry"] }
tracing-utils.workspace = true
rand.workspace = true
scopeguard.workspace = true
uuid.workspace = true
strum.workspace = true
strum_macros.workspace = true
walkdir.workspace = true

View File

@@ -12,8 +12,7 @@ use jsonwebtoken::{
Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode,
};
use pem::Pem;
use serde::{Deserialize, Deserializer, Serialize, de::DeserializeOwned};
use uuid::Uuid;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use crate::id::TenantId;
@@ -26,11 +25,6 @@ pub enum Scope {
/// Provides access to all data for a specific tenant (specified in `struct Claims` below)
// TODO: join these two?
Tenant,
/// Provides access to all data for a specific tenant, but based on endpoint ID. This token scope
/// is only used by compute to fetch the spec for a specific endpoint. The spec contains a Tenant-scoped
/// token authorizing access to all data of a tenant, so the spec-fetch API requires a TenantEndpoint
/// scope token to ensure that untrusted compute nodes can't fetch spec for arbitrary endpoints.
TenantEndpoint,
/// Provides blanket access to all tenants on the pageserver plus pageserver-wide APIs.
/// Should only be used e.g. for status check/tenant creation/list.
PageServerApi,
@@ -57,43 +51,17 @@ pub enum Scope {
ControllerPeer,
}
fn deserialize_empty_string_as_none_uuid<'de, D>(deserializer: D) -> Result<Option<Uuid>, D::Error>
where
D: Deserializer<'de>,
{
let opt = Option::<String>::deserialize(deserializer)?;
match opt.as_deref() {
Some("") => Ok(None),
Some(s) => Uuid::parse_str(s)
.map(Some)
.map_err(serde::de::Error::custom),
None => Ok(None),
}
}
/// JWT payload. See docs/authentication.md for the format
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Claims {
#[serde(default)]
pub tenant_id: Option<TenantId>,
#[serde(
default,
skip_serializing_if = "Option::is_none",
// Neon control plane includes this field as empty in the claims.
// Consider it None in those cases.
deserialize_with = "deserialize_empty_string_as_none_uuid"
)]
pub endpoint_id: Option<Uuid>,
pub scope: Scope,
}
impl Claims {
pub fn new(tenant_id: Option<TenantId>, scope: Scope) -> Self {
Self {
tenant_id,
scope,
endpoint_id: None,
}
Self { tenant_id, scope }
}
}
@@ -244,7 +212,6 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
let expected_claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
scope: Scope::Tenant,
endpoint_id: None,
};
// A test token containing the following payload, signed using TEST_PRIV_KEY_ED25519:
@@ -273,7 +240,6 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
let claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
scope: Scope::Tenant,
endpoint_id: None,
};
let pem = pem::parse(TEST_PRIV_KEY_ED25519).unwrap();

View File

@@ -47,7 +47,6 @@ where
/* BEGIN_HADRON */
pub enum DeploymentMode {
Local,
Dev,
Staging,
Prod,
@@ -65,7 +64,7 @@ pub fn get_deployment_mode() -> Option<DeploymentMode> {
}
},
Err(_) => {
// tracing::error!("DEPLOYMENT_MODE not set");
tracing::error!("DEPLOYMENT_MODE not set");
None
}
}

View File

@@ -1,73 +0,0 @@
use std::env::{VarError, var};
use std::error::Error;
use std::net::IpAddr;
use std::str::FromStr;
/// Name of the environment variable containing the reachable IP address of the node. If set, the IP address contained in this
/// environment variable is used as the reachable IP address of the pageserver or safekeeper node during node registration.
/// In a Kubernetes environment, this environment variable should be set by Kubernetes to the Pod IP (specified in the Pod
/// template).
pub const HADRON_NODE_IP_ADDRESS: &str = "HADRON_NODE_IP_ADDRESS";
/// Read the reachable IP address of this page server from env var HADRON_NODE_IP_ADDRESS.
/// In Kubernetes this environment variable is set to the Pod IP (specified in the Pod template).
pub fn read_node_ip_addr_from_env() -> Result<Option<IpAddr>, Box<dyn Error>> {
match var(HADRON_NODE_IP_ADDRESS) {
Ok(v) => {
if let Ok(addr) = IpAddr::from_str(&v) {
Ok(Some(addr))
} else {
Err(format!("Invalid IP address string: {v}. Cannot be parsed as either an IPv4 or an IPv6 address.").into())
}
}
Err(VarError::NotPresent) => Ok(None),
Err(e) => Err(e.into()),
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::env;
use std::net::{Ipv4Addr, Ipv6Addr};
#[test]
fn test_read_node_ip_addr_from_env() {
// SAFETY: test code
unsafe {
// Test with a valid IPv4 address
env::set_var(HADRON_NODE_IP_ADDRESS, "192.168.1.1");
let result = read_node_ip_addr_from_env().unwrap();
assert_eq!(result, Some(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))));
// Test with a valid IPv6 address
env::set_var(
HADRON_NODE_IP_ADDRESS,
"2001:0db8:85a3:0000:0000:8a2e:0370:7334",
);
}
let result = read_node_ip_addr_from_env().unwrap();
assert_eq!(
result,
Some(IpAddr::V6(
Ipv6Addr::from_str("2001:0db8:85a3:0000:0000:8a2e:0370:7334").unwrap()
))
);
// Test with an invalid IP address
// SAFETY: test code
unsafe {
env::set_var(HADRON_NODE_IP_ADDRESS, "invalid_ip");
}
let result = read_node_ip_addr_from_env();
assert!(result.is_err());
// Test with no environment variable set
// SAFETY: test code
unsafe {
env::remove_var(HADRON_NODE_IP_ADDRESS);
}
let result = read_node_ip_addr_from_env().unwrap();
assert_eq!(result, None);
}
}

View File

@@ -26,9 +26,6 @@ pub mod auth;
// utility functions and helper traits for unified unique id generation/serialization etc.
pub mod id;
// utility functions to obtain reachable IP addresses in PS/SK nodes.
pub mod ip_address;
pub mod shard;
mod hex;

View File

@@ -1,5 +1,4 @@
use std::future::Future;
use std::pin::Pin;
use std::str::FromStr;
use std::time::Duration;
@@ -8,7 +7,7 @@ use metrics::{IntCounter, IntCounterVec};
use once_cell::sync::Lazy;
use strum_macros::{EnumString, VariantNames};
use tokio::time::Instant;
use tracing::{info, warn};
use tracing::info;
/// Logs a critical error, similarly to `tracing::error!`. This will:
///
@@ -378,11 +377,10 @@ impl std::fmt::Debug for SecretString {
///
/// TODO: consider upgrading this to a warning, but currently it fires too often.
#[inline]
pub async fn log_slow<O>(
name: &str,
threshold: Duration,
f: Pin<&mut impl Future<Output = O>>,
) -> O {
pub async fn log_slow<F, O>(name: &str, threshold: Duration, f: std::pin::Pin<&mut F>) -> O
where
F: Future<Output = O>,
{
monitor_slow_future(
threshold,
threshold, // period = threshold
@@ -396,42 +394,16 @@ pub async fn log_slow<O>(
if !is_slow {
return;
}
let elapsed = elapsed_total.as_secs_f64();
if ready {
info!("slow {name} completed after {elapsed:.3}s");
info!(
"slow {name} completed after {:.3}s",
elapsed_total.as_secs_f64()
);
} else {
info!("slow {name} still running after {elapsed:.3}s");
}
},
)
.await
}
/// Logs a periodic warning if a future is slow to complete.
#[inline]
pub async fn warn_slow<O>(
name: &str,
threshold: Duration,
f: Pin<&mut impl Future<Output = O>>,
) -> O {
monitor_slow_future(
threshold,
threshold, // period = threshold
f,
|MonitorSlowFutureCallback {
ready,
is_slow,
elapsed_total,
elapsed_since_last_callback: _,
}| {
if !is_slow {
return;
}
let elapsed = elapsed_total.as_secs_f64();
if ready {
warn!("slow {name} completed after {elapsed:.3}s");
} else {
warn!("slow {name} still running after {elapsed:.3}s");
info!(
"slow {name} still running after {:.3}s",
elapsed_total.as_secs_f64()
);
}
},
)
@@ -444,7 +416,7 @@ pub async fn warn_slow<O>(
pub async fn monitor_slow_future<F, O>(
threshold: Duration,
period: Duration,
mut fut: Pin<&mut F>,
mut fut: std::pin::Pin<&mut F>,
mut cb: impl FnMut(MonitorSlowFutureCallback),
) -> O
where

View File

@@ -2,8 +2,7 @@
use bytes::Bytes;
use postgres_ffi::walrecord::{MultiXactMember, describe_postgres_wal_record};
use postgres_ffi::{MultiXactId, MultiXactOffset, TransactionId};
use postgres_ffi_types::TimestampTz;
use postgres_ffi::{MultiXactId, MultiXactOffset, TimestampTz, TransactionId};
use serde::{Deserialize, Serialize};
use utils::bin_ser::DeserializeError;

View File

@@ -431,7 +431,7 @@ pub fn empty_shmem() -> crate::bindings::WalproposerShmemState {
let empty_wal_rate_limiter = crate::bindings::WalRateLimiter {
should_limit: crate::bindings::pg_atomic_uint32 { value: 0 },
sent_bytes: 0,
last_recorded_time_us: crate::bindings::pg_atomic_uint64 { value: 0 },
last_recorded_time_us: 0,
};
crate::bindings::WalproposerShmemState {

View File

@@ -873,22 +873,6 @@ impl Client {
.map_err(Error::ReceiveBody)
}
pub async fn reset_alert_gauges(&self) -> Result<()> {
let uri = format!(
"{}/hadron-internal/reset_alert_gauges",
self.mgmt_api_endpoint
);
self.start_request(Method::POST, uri)
.send()
.await
.map_err(Error::SendRequest)?
.error_from_body()
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
pub async fn wait_lsn(
&self,
tenant_shard_id: TenantShardId,

View File

@@ -1,16 +1,13 @@
use std::collections::HashMap;
use std::num::NonZero;
use std::pin::pin;
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::anyhow;
use arc_swap::ArcSwap;
use futures::stream::FuturesUnordered;
use futures::{FutureExt as _, StreamExt as _};
use tonic::codec::CompressionEncoding;
use tracing::{debug, instrument};
use utils::logging::warn_slow;
use tracing::instrument;
use crate::pool::{ChannelPool, ClientGuard, ClientPool, StreamGuard, StreamPool};
use crate::retry::Retry;
@@ -24,40 +21,28 @@ use utils::shard::{ShardCount, ShardIndex, ShardNumber};
/// Max number of concurrent clients per channel (i.e. TCP connection). New channels will be spun up
/// when full.
///
/// Normal requests are small, and we don't pipeline them, so we can afford a large number of
/// streams per connection.
///
/// TODO: tune all of these constants, and consider making them configurable.
const MAX_CLIENTS_PER_CHANNEL: NonZero<usize> = NonZero::new(64).unwrap();
/// TODO: consider separate limits for unary and streaming clients, so we don't fill up channels
/// with only streams.
const MAX_CLIENTS_PER_CHANNEL: NonZero<usize> = NonZero::new(16).unwrap();
/// Max number of concurrent bulk GetPage streams per channel (i.e. TCP connection). These use a
/// dedicated channel pool with a lower client limit, to avoid TCP-level head-of-line blocking and
/// transmission delays. This also concentrates large window sizes on a smaller set of
/// streams/connections, presumably reducing memory use.
const MAX_BULK_CLIENTS_PER_CHANNEL: NonZero<usize> = NonZero::new(16).unwrap();
/// Max number of concurrent unary request clients per shard.
const MAX_UNARY_CLIENTS: NonZero<usize> = NonZero::new(64).unwrap();
/// The batch size threshold at which a GetPage request will use the bulk stream pool.
///
/// The gRPC initial window size is 64 KB. Each page is 8 KB, so let's avoid increasing the window
/// size for the normal stream pool, and route requests for >= 5 pages (>32 KB) to the bulk pool.
const BULK_THRESHOLD_BATCH_SIZE: usize = 5;
/// Max number of concurrent GetPage streams per shard. The max number of concurrent GetPage
/// requests is given by `MAX_STREAMS * MAX_STREAM_QUEUE_DEPTH`.
const MAX_STREAMS: NonZero<usize> = NonZero::new(64).unwrap();
/// The overall request call timeout, including retries and pool acquisition.
/// TODO: should we retry forever? Should the caller decide?
const CALL_TIMEOUT: Duration = Duration::from_secs(60);
/// Max number of pipelined requests per stream.
const MAX_STREAM_QUEUE_DEPTH: NonZero<usize> = NonZero::new(2).unwrap();
/// The per-request (retry attempt) timeout, including any lazy connection establishment.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(10);
/// Max number of concurrent bulk GetPage streams per shard, used e.g. for prefetches. Because these
/// are more throughput-oriented, we have a smaller limit but higher queue depth.
const MAX_BULK_STREAMS: NonZero<usize> = NonZero::new(16).unwrap();
/// The initial request retry backoff duration. The first retry does not back off.
/// TODO: use a different backoff for ResourceExhausted (rate limiting)? Needs server support.
const BASE_BACKOFF: Duration = Duration::from_millis(5);
/// The maximum request retry backoff duration.
const MAX_BACKOFF: Duration = Duration::from_secs(5);
/// Threshold and interval for warning about slow operation.
const SLOW_THRESHOLD: Duration = Duration::from_secs(3);
/// Max number of pipelined requests per bulk stream. These are more throughput-oriented and thus
/// get a larger queue depth.
const MAX_BULK_STREAM_QUEUE_DEPTH: NonZero<usize> = NonZero::new(4).unwrap();
/// A rich Pageserver gRPC client for a single tenant timeline. This client is more capable than the
/// basic `page_api::Client` gRPC client, and supports:
@@ -65,19 +50,10 @@ const SLOW_THRESHOLD: Duration = Duration::from_secs(3);
/// * Sharded tenants across multiple Pageservers.
/// * Pooling of connections, clients, and streams for efficient resource use.
/// * Concurrent use by many callers.
/// * Internal handling of GetPage bidirectional streams.
/// * Internal handling of GetPage bidirectional streams, with pipelining and error handling.
/// * Automatic retries.
/// * Observability.
///
/// The client has dedicated connection/client/stream pools per shard, for resource reuse. These
/// pools are unbounded: we allow scaling out as many concurrent streams as needed to serve all
/// concurrent callers, which mostly eliminates head-of-line blocking. Idle streams are fairly
/// cheap: the server task currently uses 26 KB of memory, so we can comfortably fit 100,000
/// concurrent idle streams (2.5 GB memory). The worst case degenerates to the old libpq case with
/// one stream per backend, but without the TCP connection overhead. In the common case we expect
/// significantly lower stream counts due to stream sharing, driven e.g. by idle backends, LFC hits,
/// read coalescing, sharding (backends typically only talk to one shard at a time), etc.
///
/// TODO: this client does not support base backups or LSN leases, as these are only used by
/// compute_ctl. Consider adding this, but LSN leases need concurrent requests on all shards.
pub struct PageserverClient {
@@ -91,6 +67,8 @@ pub struct PageserverClient {
compression: Option<CompressionEncoding>,
/// The shards for this tenant.
shards: ArcSwap<Shards>,
/// The retry configuration.
retry: Retry,
}
impl PageserverClient {
@@ -116,6 +94,7 @@ impl PageserverClient {
auth_token,
compression,
shards: ArcSwap::new(Arc::new(shards)),
retry: Retry,
})
}
@@ -163,15 +142,13 @@ impl PageserverClient {
&self,
req: page_api::CheckRelExistsRequest,
) -> tonic::Result<page_api::CheckRelExistsResponse> {
debug!("sending request: {req:?}");
let resp = Self::with_retries(CALL_TIMEOUT, async |_| {
// Relation metadata is only available on shard 0.
let mut client = self.shards.load_full().get_zero().client().await?;
Self::with_timeout(REQUEST_TIMEOUT, client.check_rel_exists(req)).await
})
.await?;
debug!("received response: {resp:?}");
Ok(resp)
self.retry
.with(async |_| {
// Relation metadata is only available on shard 0.
let mut client = self.shards.load_full().get_zero().client().await?;
client.check_rel_exists(req).await
})
.await
}
/// Returns the total size of a database, as # of bytes.
@@ -180,15 +157,13 @@ impl PageserverClient {
&self,
req: page_api::GetDbSizeRequest,
) -> tonic::Result<page_api::GetDbSizeResponse> {
debug!("sending request: {req:?}");
let resp = Self::with_retries(CALL_TIMEOUT, async |_| {
// Relation metadata is only available on shard 0.
let mut client = self.shards.load_full().get_zero().client().await?;
Self::with_timeout(REQUEST_TIMEOUT, client.get_db_size(req)).await
})
.await?;
debug!("received response: {resp:?}");
Ok(resp)
self.retry
.with(async |_| {
// Relation metadata is only available on shard 0.
let mut client = self.shards.load_full().get_zero().client().await?;
client.get_db_size(req).await
})
.await
}
/// Fetches pages. The `request_id` must be unique across all in-flight requests, and the
@@ -218,8 +193,6 @@ impl PageserverClient {
return Err(tonic::Status::invalid_argument("request attempt must be 0"));
}
debug!("sending request: {req:?}");
// The shards may change while we're fetching pages. We execute the request using a stable
// view of the shards (especially important for requests that span shards), but retry the
// top-level (pre-split) request to pick up shard changes. This can lead to unnecessary
@@ -228,16 +201,13 @@ impl PageserverClient {
//
// TODO: the gRPC server and client doesn't yet properly support shard splits. Revisit this
// once we figure out how to handle these.
let resp = Self::with_retries(CALL_TIMEOUT, async |attempt| {
let mut req = req.clone();
req.request_id.attempt = attempt as u32;
let shards = self.shards.load_full();
Self::with_timeout(REQUEST_TIMEOUT, Self::get_page_with_shards(req, &shards)).await
})
.await?;
debug!("received response: {resp:?}");
Ok(resp)
self.retry
.with(async |attempt| {
let mut req = req.clone();
req.request_id.attempt = attempt as u32;
Self::get_page_with_shards(req, &self.shards.load_full()).await
})
.await
}
/// Fetches pages using the given shards. This uses a stable view of the shards, regardless of
@@ -276,7 +246,7 @@ impl PageserverClient {
req: page_api::GetPageRequest,
shard: &Shard,
) -> tonic::Result<page_api::GetPageResponse> {
let mut stream = shard.stream(Self::is_bulk(&req)).await?;
let stream = shard.stream(req.request_class.is_bulk()).await;
let resp = stream.send(req.clone()).await?;
// Convert per-request errors into a tonic::Status.
@@ -320,15 +290,13 @@ impl PageserverClient {
&self,
req: page_api::GetRelSizeRequest,
) -> tonic::Result<page_api::GetRelSizeResponse> {
debug!("sending request: {req:?}");
let resp = Self::with_retries(CALL_TIMEOUT, async |_| {
// Relation metadata is only available on shard 0.
let mut client = self.shards.load_full().get_zero().client().await?;
Self::with_timeout(REQUEST_TIMEOUT, client.get_rel_size(req)).await
})
.await?;
debug!("received response: {resp:?}");
Ok(resp)
self.retry
.with(async |_| {
// Relation metadata is only available on shard 0.
let mut client = self.shards.load_full().get_zero().client().await?;
client.get_rel_size(req).await
})
.await
}
/// Fetches an SLRU segment.
@@ -337,50 +305,13 @@ impl PageserverClient {
&self,
req: page_api::GetSlruSegmentRequest,
) -> tonic::Result<page_api::GetSlruSegmentResponse> {
debug!("sending request: {req:?}");
let resp = Self::with_retries(CALL_TIMEOUT, async |_| {
// SLRU segments are only available on shard 0.
let mut client = self.shards.load_full().get_zero().client().await?;
Self::with_timeout(REQUEST_TIMEOUT, client.get_slru_segment(req)).await
})
.await?;
debug!("received response: {resp:?}");
Ok(resp)
}
/// Runs the given async closure with retries up to the given timeout. Only certain gRPC status
/// codes are retried, see [`Retry::should_retry`]. Returns `DeadlineExceeded` on timeout.
async fn with_retries<T, F, O>(timeout: Duration, f: F) -> tonic::Result<T>
where
F: FnMut(usize) -> O, // pass attempt number, starting at 0
O: Future<Output = tonic::Result<T>>,
{
Retry {
timeout: Some(timeout),
base_backoff: BASE_BACKOFF,
max_backoff: MAX_BACKOFF,
}
.with(f)
.await
}
/// Runs the given future with a timeout. Returns `DeadlineExceeded` on timeout.
async fn with_timeout<T>(
timeout: Duration,
f: impl Future<Output = tonic::Result<T>>,
) -> tonic::Result<T> {
let started = Instant::now();
tokio::time::timeout(timeout, f).await.map_err(|_| {
tonic::Status::deadline_exceeded(format!(
"request timed out after {:.3}s",
started.elapsed().as_secs_f64()
))
})?
}
/// Returns true if the request is considered a bulk request and should use the bulk pool.
fn is_bulk(req: &page_api::GetPageRequest) -> bool {
req.block_numbers.len() >= BULK_THRESHOLD_BATCH_SIZE
self.retry
.with(async |_| {
// SLRU segments are only available on shard 0.
let mut client = self.shards.load_full().get_zero().client().await?;
client.get_slru_segment(req).await
})
.await
}
}
@@ -509,23 +440,15 @@ impl Shards {
}
}
/// A single shard. Has dedicated resource pools with the following structure:
/// A single shard. Uses dedicated resource pools with the following structure:
///
/// * Channel pool: MAX_CLIENTS_PER_CHANNEL.
/// * Client pool: unbounded.
/// * Stream pool: unbounded.
/// * Bulk channel pool: MAX_BULK_CLIENTS_PER_CHANNEL.
/// * Channel pool: unbounded.
/// * Unary client pool: MAX_UNARY_CLIENTS.
/// * Stream client pool: unbounded.
/// * Stream pool: MAX_STREAMS and MAX_STREAM_QUEUE_DEPTH.
/// * Bulk channel pool: unbounded.
/// * Bulk client pool: unbounded.
/// * Bulk stream pool: unbounded.
///
/// We use a separate bulk channel pool with a lower concurrency limit for large batch requests.
/// This avoids TCP-level head-of-line blocking, and also concentrates large window sizes on a
/// smaller set of streams/connections, which presumably reduces memory use. Neither of these pools
/// are bounded, nor do they pipeline requests, so the latency characteristics should be mostly
/// similar (except for TCP transmission time).
///
/// TODO: since we never use bounded pools, we could consider removing the pool limiters. However,
/// the code is fairly trivial, so we may as well keep them around for now in case we need them.
/// * Bulk stream pool: MAX_BULK_STREAMS and MAX_BULK_STREAM_QUEUE_DEPTH.
struct Shard {
/// The shard ID.
id: ShardIndex,
@@ -533,7 +456,7 @@ struct Shard {
client_pool: Arc<ClientPool>,
/// GetPage stream pool.
stream_pool: Arc<StreamPool>,
/// GetPage stream pool for bulk requests.
/// GetPage stream pool for bulk requests, e.g. prefetches.
bulk_stream_pool: Arc<StreamPool>,
}
@@ -547,30 +470,50 @@ impl Shard {
auth_token: Option<String>,
compression: Option<CompressionEncoding>,
) -> anyhow::Result<Self> {
// Shard pools for unary requests and non-bulk GetPage requests.
// Common channel pool for unary and stream requests. Bounded by client/stream pools.
let channel_pool = ChannelPool::new(url.clone(), MAX_CLIENTS_PER_CHANNEL)?;
// Client pool for unary requests.
let client_pool = ClientPool::new(
ChannelPool::new(url.clone(), MAX_CLIENTS_PER_CHANNEL)?,
channel_pool.clone(),
tenant_id,
timeline_id,
shard_id,
auth_token.clone(),
compression,
None, // unbounded
Some(MAX_UNARY_CLIENTS),
);
let stream_pool = StreamPool::new(client_pool.clone(), None); // unbounded
// Bulk GetPage stream pool for large batches (prefetches, sequential scans, vacuum, etc.).
// GetPage stream pool. Uses a dedicated client pool to avoid starving out unary clients,
// but shares a channel pool with it (as it's unbounded).
let stream_pool = StreamPool::new(
ClientPool::new(
channel_pool.clone(),
tenant_id,
timeline_id,
shard_id,
auth_token.clone(),
compression,
None, // unbounded, limited by stream pool
),
Some(MAX_STREAMS),
MAX_STREAM_QUEUE_DEPTH,
);
// Bulk GetPage stream pool, e.g. for prefetches. Uses dedicated channel/client/stream pools
// to avoid head-of-line blocking of latency-sensitive requests.
let bulk_stream_pool = StreamPool::new(
ClientPool::new(
ChannelPool::new(url, MAX_BULK_CLIENTS_PER_CHANNEL)?,
ChannelPool::new(url, MAX_CLIENTS_PER_CHANNEL)?,
tenant_id,
timeline_id,
shard_id,
auth_token,
compression,
None, // unbounded,
None, // unbounded, limited by stream pool
),
None, // unbounded
Some(MAX_BULK_STREAMS),
MAX_BULK_STREAM_QUEUE_DEPTH,
);
Ok(Self {
@@ -582,23 +525,19 @@ impl Shard {
}
/// Returns a pooled client for this shard.
#[instrument(skip_all)]
async fn client(&self) -> tonic::Result<ClientGuard> {
warn_slow(
"client pool acquisition",
SLOW_THRESHOLD,
pin!(self.client_pool.get()),
)
.await
self.client_pool
.get()
.await
.map_err(|err| tonic::Status::internal(format!("failed to get client: {err}")))
}
/// Returns a pooled stream for this shard. If `bulk` is `true`, uses the dedicated bulk pool.
#[instrument(skip_all, fields(bulk))]
async fn stream(&self, bulk: bool) -> tonic::Result<StreamGuard> {
let pool = match bulk {
false => &self.stream_pool,
true => &self.bulk_stream_pool,
};
warn_slow("stream pool acquisition", SLOW_THRESHOLD, pin!(pool.get())).await
/// Returns a pooled stream for this shard. If `bulk` is `true`, uses the dedicated bulk stream
/// pool (e.g. for prefetches).
async fn stream(&self, bulk: bool) -> StreamGuard {
match bulk {
false => self.stream_pool.get().await,
true => self.bulk_stream_pool.get().await,
}
}
}

View File

@@ -9,36 +9,19 @@
//!
//! * ChannelPool: manages gRPC channels (TCP connections) to a single Pageserver. Multiple clients
//! can acquire and use the same channel concurrently (via HTTP/2 stream multiplexing), up to a
//! per-channel client limit. Channels are closed immediately when empty, and indirectly rely on
//! client/stream idle timeouts.
//! per-channel client limit. Channels may be closed when they are no longer used by any clients.
//!
//! * ClientPool: manages gRPC clients for a single tenant shard. Each client acquires a (shared)
//! channel from the ChannelPool for the client's lifetime. A client can only be acquired by a
//! single caller at a time, and is returned to the pool when dropped. Idle clients are removed
//! from the pool after a while to free up resources.
//! single caller at a time, and is returned to the pool when dropped. Idle clients may be removed
//! from the pool after some time, to free up the channel.
//!
//! * StreamPool: manages bidirectional gRPC GetPage streams. Each stream acquires a client from the
//! ClientPool for the stream's lifetime. A stream can only be acquired by a single caller at a
//! time, and is returned to the pool when dropped. Idle streams are removed from the pool after
//! a while to free up resources.
//!
//! The stream only supports sending a single, synchronous request at a time, and does not support
//! pipelining multiple requests from different callers onto the same stream -- instead, we scale
//! out concurrent streams to improve throughput. There are many reasons for this design choice:
//!
//! * It (mostly) eliminates head-of-line blocking. A single stream is processed sequentially by
//! a single server task, which may block e.g. on layer downloads, LSN waits, etc.
//!
//! * Cancellation becomes trivial, by closing the stream. Otherwise, if a caller goes away
//! (e.g. because of a timeout), the request would still be processed by the server and block
//! requests behind it in the stream. It might even block its own timeout retry.
//!
//! * Stream scheduling becomes significantly simpler and cheaper.
//!
//! * Individual callers can still use client-side batching for pipelining.
//!
//! * Idle streams are cheap. Benchmarks show that an idle GetPage stream takes up about 26 KB
//! per stream (2.5 GB for 100,000 streams), so we can afford to scale out.
//! ClientPool for the stream's lifetime. Internal streams are not exposed to callers; instead, it
//! returns a guard that can be used to send a single request, to properly enforce queue depth and
//! route responses. Internally, the pool will reuse or spin up a suitable stream for the request,
//! possibly pipelining multiple requests from multiple callers on the same stream (up to some
//! queue depth). Idle streams may be removed from the pool after a while to free up the client.
//!
//! Each channel corresponds to one TCP connection. Each client unary request and each stream
//! corresponds to one HTTP/2 stream and server task.
@@ -46,31 +29,33 @@
//! TODO: error handling (including custom error types).
//! TODO: observability.
use std::collections::BTreeMap;
use std::collections::{BTreeMap, HashMap};
use std::num::NonZero;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, Weak};
use std::time::{Duration, Instant};
use futures::{Stream, StreamExt as _};
use tokio::sync::{OwnedSemaphorePermit, Semaphore, watch};
use tokio_stream::wrappers::WatchStream;
use futures::StreamExt as _;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::sync::{OwnedSemaphorePermit, Semaphore, mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use tonic::codec::CompressionEncoding;
use tonic::transport::{Channel, Endpoint};
use tracing::{error, warn};
use pageserver_page_api as page_api;
use utils::id::{TenantId, TimelineId};
use utils::shard::ShardIndex;
/// Reap clients/streams that have been idle for this long. Channels are reaped immediately when
/// empty, and indirectly rely on the client/stream idle timeouts.
/// Reap channels/clients/streams that have been idle for this long.
///
/// A stream's client will be reaped after 2x the idle threshold (first stream the client), but
/// that's okay -- if the stream closes abruptly (e.g. due to timeout or cancellation), we want to
/// keep its client around in the pool for a while.
/// TODO: this is per-pool. For nested pools, it can take up to 3x as long for a TCP connection to
/// be reaped. First, we must wait for an idle stream to be reaped, which marks its client as idle.
/// Then, we must wait for the idle client to be reaped, which marks its channel as idle. Then, we
/// must wait for the idle channel to be reaped. Is that a problem? Maybe not, we just have to
/// account for it when setting the reap threshold. Alternatively, we can immediately reap empty
/// channels, and/or stream pool clients.
const REAP_IDLE_THRESHOLD: Duration = match cfg!(any(test, feature = "testing")) {
false => Duration::from_secs(180),
true => Duration::from_secs(1), // exercise reaping in tests
@@ -98,6 +83,8 @@ pub struct ChannelPool {
max_clients_per_channel: NonZero<usize>,
/// Open channels.
channels: Mutex<BTreeMap<ChannelID, ChannelEntry>>,
/// Reaps idle channels.
idle_reaper: Reaper,
/// Channel ID generator.
next_channel_id: AtomicUsize,
}
@@ -109,6 +96,9 @@ struct ChannelEntry {
channel: Channel,
/// Number of clients using this channel.
clients: usize,
/// The channel has been idle (no clients) since this time. None if channel is in use.
/// INVARIANT: Some if clients == 0, otherwise None.
idle_since: Option<Instant>,
}
impl ChannelPool {
@@ -118,12 +108,15 @@ impl ChannelPool {
E: TryInto<Endpoint> + Send + Sync + 'static,
<E as TryInto<Endpoint>>::Error: std::error::Error + Send + Sync,
{
Ok(Arc::new(Self {
let pool = Arc::new(Self {
endpoint: endpoint.try_into()?,
max_clients_per_channel,
channels: Mutex::default(),
idle_reaper: Reaper::new(REAP_IDLE_THRESHOLD, REAP_IDLE_INTERVAL),
next_channel_id: AtomicUsize::default(),
}))
});
pool.idle_reaper.spawn(&pool);
Ok(pool)
}
/// Acquires a gRPC channel for a client. Multiple clients may acquire the same channel.
@@ -144,17 +137,22 @@ impl ChannelPool {
let mut channels = self.channels.lock().unwrap();
// Try to find an existing channel with available capacity. We check entries in BTreeMap
// order, to fill up the lower-ordered channels first. The client/stream pools also prefer
// clients with lower-ordered channel IDs first. This will cluster clients in lower-ordered
// order, to fill up the lower-ordered channels first. The ClientPool also prefers clients
// with lower-ordered channel IDs first. This will cluster clients in lower-ordered
// channels, and free up higher-ordered channels such that they can be reaped.
for (&id, entry) in channels.iter_mut() {
assert!(
entry.clients <= self.max_clients_per_channel.get(),
"channel overflow"
);
assert_ne!(entry.clients, 0, "empty channel not reaped");
assert_eq!(
entry.idle_since.is_some(),
entry.clients == 0,
"incorrect channel idle state"
);
if entry.clients < self.max_clients_per_channel.get() {
entry.clients += 1;
entry.idle_since = None;
return ChannelGuard {
pool: Arc::downgrade(self),
id,
@@ -171,6 +169,7 @@ impl ChannelPool {
let entry = ChannelEntry {
channel: channel.clone(),
clients: 1, // account for the guard below
idle_since: None,
};
channels.insert(id, entry);
@@ -182,6 +181,20 @@ impl ChannelPool {
}
}
impl Reapable for ChannelPool {
/// Reaps channels that have been idle since before the cutoff.
fn reap_idle(&self, cutoff: Instant) {
self.channels.lock().unwrap().retain(|_, entry| {
let Some(idle_since) = entry.idle_since else {
assert_ne!(entry.clients, 0, "empty channel not marked idle");
return true;
};
assert_eq!(entry.clients, 0, "idle channel has clients");
idle_since >= cutoff
})
}
}
/// Tracks a channel acquired from the pool. The owned inner channel can be obtained with `take()`,
/// since the gRPC client requires an owned `Channel`.
pub struct ChannelGuard {
@@ -198,7 +211,7 @@ impl ChannelGuard {
}
}
/// Returns the channel to the pool. The channel is closed when empty.
/// Returns the channel to the pool.
impl Drop for ChannelGuard {
fn drop(&mut self) {
let Some(pool) = self.pool.upgrade() else {
@@ -207,12 +220,11 @@ impl Drop for ChannelGuard {
let mut channels = pool.channels.lock().unwrap();
let entry = channels.get_mut(&self.id).expect("unknown channel");
assert!(entry.idle_since.is_none(), "active channel marked idle");
assert!(entry.clients > 0, "channel underflow");
entry.clients -= 1;
// Reap empty channels immediately.
if entry.clients == 0 {
channels.remove(&self.id);
entry.idle_since = Some(Instant::now()); // mark channel as idle
}
}
}
@@ -241,7 +253,8 @@ pub struct ClientPool {
///
/// The first client in the map will be acquired next. The map is sorted by client ID, which in
/// turn is sorted by its channel ID, such that we prefer acquiring idle clients from
/// lower-ordered channels. This allows us to free up and reap higher-ordered channels.
/// lower-ordered channels. This allows us to free up and reap higher-numbered channels as idle
/// clients are reaped.
idle: Mutex<BTreeMap<ClientID, ClientEntry>>,
/// Reaps idle clients.
idle_reaper: Reaper,
@@ -297,7 +310,7 @@ impl ClientPool {
/// This is moderately performance-sensitive. It is called for every unary request, but these
/// establish a new gRPC stream per request so they're already expensive. GetPage requests use
/// the `StreamPool` instead.
pub async fn get(self: &Arc<Self>) -> tonic::Result<ClientGuard> {
pub async fn get(self: &Arc<Self>) -> anyhow::Result<ClientGuard> {
// Acquire a permit if the pool is bounded.
let mut permit = None;
if let Some(limiter) = self.limiter.clone() {
@@ -315,7 +328,7 @@ impl ClientPool {
});
}
// Construct a new client.
// Slow path: construct a new client.
let mut channel_guard = self.channel_pool.get();
let client = page_api::Client::new(
channel_guard.take(),
@@ -324,8 +337,7 @@ impl ClientPool {
self.shard_id,
self.auth_token.clone(),
self.compression,
)
.map_err(|err| tonic::Status::internal(format!("failed to create client: {err}")))?;
)?;
Ok(ClientGuard {
pool: Arc::downgrade(self),
@@ -395,187 +407,287 @@ impl Drop for ClientGuard {
/// A pool of bidirectional gRPC streams. Currently only used for GetPage streams. Each stream
/// acquires a client from the inner `ClientPool` for the stream's lifetime.
///
/// Individual streams only send a single request at a time, and do not pipeline multiple callers
/// onto the same stream. Instead, we scale out the number of concurrent streams. This is primarily
/// to eliminate head-of-line blocking. See the module documentation for more details.
/// Individual streams are not exposed to callers -- instead, the returned guard can be used to send
/// a single request and await the response. Internally, requests are multiplexed across streams and
/// channels. This allows proper queue depth enforcement and response routing.
///
/// TODO: consider making this generic over request and response types; not currently needed.
pub struct StreamPool {
/// The client pool to acquire clients from. Must be unbounded.
client_pool: Arc<ClientPool>,
/// Idle pooled streams. Acquired streams are removed from here and returned on drop.
/// All pooled streams.
///
/// The first stream in the map will be acquired next. The map is sorted by stream ID, which is
/// equivalent to the client ID and in turn sorted by its channel ID. This way we prefer
/// acquiring idle streams from lower-ordered channels, which allows us to free up and reap
/// higher-ordered channels.
idle: Mutex<BTreeMap<StreamID, StreamEntry>>,
/// Limits the max number of concurrent streams. None if the pool is unbounded.
/// Incoming requests will be sent over an existing stream with available capacity. If all
/// streams are full, a new one is spun up and added to the pool (up to `max_streams`). Each
/// stream has an associated Tokio task that processes requests and responses.
streams: Mutex<HashMap<StreamID, StreamEntry>>,
/// The max number of concurrent streams, or None if unbounded.
max_streams: Option<NonZero<usize>>,
/// The max number of concurrent requests per stream.
max_queue_depth: NonZero<usize>,
/// Limits the max number of concurrent requests, given by `max_streams * max_queue_depth`.
/// None if the pool is unbounded.
limiter: Option<Arc<Semaphore>>,
/// Reaps idle streams.
idle_reaper: Reaper,
/// Stream ID generator.
next_stream_id: AtomicUsize,
}
/// The stream ID. Reuses the inner client ID.
type StreamID = ClientID;
type StreamID = usize;
type RequestSender = Sender<(page_api::GetPageRequest, ResponseSender)>;
type RequestReceiver = Receiver<(page_api::GetPageRequest, ResponseSender)>;
type ResponseSender = oneshot::Sender<tonic::Result<page_api::GetPageResponse>>;
/// A pooled stream.
struct StreamEntry {
/// The bidirectional stream.
stream: BiStream,
/// The time when this stream was last used, i.e. when it was put back into `StreamPool::idle`.
idle_since: Instant,
}
/// A bidirectional GetPage stream and its client. Can send requests and receive responses.
struct BiStream {
/// The owning client. Holds onto the channel slot while the stream is alive.
client: ClientGuard,
/// Stream for sending requests. Uses a watch channel, so it can only send a single request at a
/// time, and the caller must await the response before sending another request. This is
/// enforced by `StreamGuard::send`.
sender: watch::Sender<page_api::GetPageRequest>,
/// Stream for receiving responses.
receiver: Pin<Box<dyn Stream<Item = tonic::Result<page_api::GetPageResponse>> + Send>>,
/// Sends caller requests to the stream task. The stream task exits when this is dropped.
sender: RequestSender,
/// Number of in-flight requests on this stream.
queue_depth: usize,
/// The time when this stream went idle (queue_depth == 0).
/// INVARIANT: Some if queue_depth == 0, otherwise None.
idle_since: Option<Instant>,
}
impl StreamPool {
/// Creates a new stream pool, using the given client pool. It will use up to `max_streams`
/// concurrent streams.
/// Creates a new stream pool, using the given client pool. It will send up to `max_queue_depth`
/// concurrent requests on each stream, and use up to `max_streams` concurrent streams.
///
/// The client pool must be unbounded. The stream pool will enforce its own limits, and because
/// streams are long-lived they can cause persistent starvation if they exhaust the client pool.
/// The stream pool should generally have its own dedicated client pool (but it can share a
/// channel pool with others since these are always unbounded).
pub fn new(client_pool: Arc<ClientPool>, max_streams: Option<NonZero<usize>>) -> Arc<Self> {
pub fn new(
client_pool: Arc<ClientPool>,
max_streams: Option<NonZero<usize>>,
max_queue_depth: NonZero<usize>,
) -> Arc<Self> {
assert!(client_pool.limiter.is_none(), "bounded client pool");
let pool = Arc::new(Self {
client_pool,
idle: Mutex::default(),
limiter: max_streams.map(|max_streams| Arc::new(Semaphore::new(max_streams.get()))),
streams: Mutex::default(),
limiter: max_streams.map(|max_streams| {
Arc::new(Semaphore::new(max_streams.get() * max_queue_depth.get()))
}),
max_streams,
max_queue_depth,
idle_reaper: Reaper::new(REAP_IDLE_THRESHOLD, REAP_IDLE_INTERVAL),
next_stream_id: AtomicUsize::default(),
});
pool.idle_reaper.spawn(&pool);
pool
}
/// Acquires an available stream from the pool, or spins up a new stream if all streams are
/// full. Returns a guard that can be used to send requests and await the responses. Blocks if
/// the pool is full.
/// Acquires an available stream from the pool, or spins up a new stream async if all streams
/// are full. Returns a guard that can be used to send a single request on the stream and await
/// the response, with queue depth quota already acquired. Blocks if the pool is at capacity
/// (i.e. `CLIENT_LIMIT * STREAM_QUEUE_DEPTH` requests in flight).
///
/// This is very performance-sensitive, as it is on the GetPage hot path.
///
/// TODO: is a `Mutex<BTreeMap>` performant enough? Will it become too contended? We can't
/// trivially use e.g. DashMap or sharding, because we want to pop lower-ordered streams first
/// to free up higher-ordered channels.
pub async fn get(self: &Arc<Self>) -> tonic::Result<StreamGuard> {
/// TODO: this must do something more sophisticated for performance. We want:
///
/// * Cheap, concurrent access in the common case where we can use a pooled stream.
/// * Quick acquisition of pooled streams with available capacity.
/// * Prefer streams that belong to lower-numbered channels, to reap idle channels.
/// * Prefer filling up existing streams' queue depth before spinning up new streams.
/// * Don't hold a lock while spinning up new streams.
/// * Allow concurrent clients to join onto streams while they're spun up.
/// * Allow spinning up multiple streams concurrently, but don't overshoot limits.
///
/// For now, we just do something simple but inefficient (linear scan under mutex).
pub async fn get(self: &Arc<Self>) -> StreamGuard {
// Acquire a permit if the pool is bounded.
let mut permit = None;
if let Some(limiter) = self.limiter.clone() {
permit = Some(limiter.acquire_owned().await.expect("never closed"));
}
let mut streams = self.streams.lock().unwrap();
// Fast path: acquire an idle stream from the pool.
if let Some((_, entry)) = self.idle.lock().unwrap().pop_first() {
return Ok(StreamGuard {
pool: Arc::downgrade(self),
stream: Some(entry.stream),
can_reuse: true,
permit,
});
// Look for a pooled stream with available capacity.
for (&id, entry) in streams.iter_mut() {
assert!(
entry.queue_depth <= self.max_queue_depth.get(),
"stream queue overflow"
);
assert_eq!(
entry.idle_since.is_some(),
entry.queue_depth == 0,
"incorrect stream idle state"
);
if entry.queue_depth < self.max_queue_depth.get() {
entry.queue_depth += 1;
entry.idle_since = None;
return StreamGuard {
pool: Arc::downgrade(self),
id,
sender: entry.sender.clone(),
permit,
};
}
}
// Spin up a new stream. Uses a watch channel to send a single request at a time, since
// `StreamGuard::send` enforces this anyway and it avoids unnecessary channel overhead.
let mut client = self.client_pool.get().await?;
// No available stream, spin up a new one. We install the stream entry in the pool first and
// return the guard, while spinning up the stream task async. This allows other callers to
// join onto this stream and also create additional streams concurrently if this fills up.
let id = self.next_stream_id.fetch_add(1, Ordering::Relaxed);
let (req_tx, req_rx) = mpsc::channel(self.max_queue_depth.get());
let entry = StreamEntry {
sender: req_tx.clone(),
queue_depth: 1, // reserve quota for this caller
idle_since: None,
};
streams.insert(id, entry);
let (req_tx, req_rx) = watch::channel(page_api::GetPageRequest::default());
let req_stream = WatchStream::from_changes(req_rx);
let resp_stream = client.get_pages(req_stream).await?;
if let Some(max_streams) = self.max_streams {
assert!(streams.len() <= max_streams.get(), "stream overflow");
};
Ok(StreamGuard {
let client_pool = self.client_pool.clone();
let pool = Arc::downgrade(self);
tokio::spawn(async move {
if let Err(err) = Self::run_stream(client_pool, req_rx).await {
error!("stream failed: {err}");
}
// Remove stream from pool on exit. Weak reference to avoid holding the pool alive.
if let Some(pool) = pool.upgrade() {
let entry = pool.streams.lock().unwrap().remove(&id);
assert!(entry.is_some(), "unknown stream ID: {id}");
}
});
StreamGuard {
pool: Arc::downgrade(self),
stream: Some(BiStream {
client,
sender: req_tx,
receiver: Box::pin(resp_stream),
}),
can_reuse: true,
id,
sender: req_tx,
permit,
})
}
}
/// Runs a stream task. This acquires a client from the `ClientPool` and establishes a
/// bidirectional GetPage stream, then forwards requests and responses between callers and the
/// stream. It does not track or enforce queue depths -- that's done by `get()` since it must be
/// atomic with pool stream acquisition.
///
/// The task exits when the request channel is closed, or on a stream error. The caller is
/// responsible for removing the stream from the pool on exit.
async fn run_stream(
client_pool: Arc<ClientPool>,
mut caller_rx: RequestReceiver,
) -> anyhow::Result<()> {
// Acquire a client from the pool and create a stream.
let mut client = client_pool.get().await?;
// NB: use an unbounded channel such that the stream send never blocks. Otherwise, we could
// theoretically deadlock if both the client and server block on sends (since we're not
// reading responses while sending). This is unlikely to happen due to gRPC/TCP buffers and
// low queue depths, but it was seen to happen with the libpq protocol so better safe than
// sorry. It should never buffer more than the queue depth anyway, but using an unbounded
// channel guarantees that it will never block.
let (req_tx, req_rx) = mpsc::unbounded_channel();
let req_stream = tokio_stream::wrappers::UnboundedReceiverStream::new(req_rx);
let mut resp_stream = client.get_pages(req_stream).await?;
// Track caller response channels by request ID. If the task returns early, these response
// channels will be dropped and the waiting callers will receive an error.
//
// NB: this will leak entries if the server doesn't respond to a request (by request ID).
// It shouldn't happen, and if it does it will often hold onto queue depth quota anyway and
// block further use. But we could consider reaping closed channels after some time.
let mut callers = HashMap::new();
// Process requests and responses.
loop {
tokio::select! {
// Receive requests from callers and send them to the stream.
req = caller_rx.recv() => {
// Shut down if request channel is closed.
let Some((req, resp_tx)) = req else {
return Ok(());
};
// Store the response channel by request ID.
if callers.contains_key(&req.request_id) {
// Error on request ID duplicates. Ignore callers that went away.
_ = resp_tx.send(Err(tonic::Status::invalid_argument(
format!("duplicate request ID: {}", req.request_id),
)));
continue;
}
callers.insert(req.request_id, resp_tx);
// Send the request on the stream. Bail out if the stream is closed.
req_tx.send(req).map_err(|_| {
tonic::Status::unavailable("stream closed")
})?;
}
// Receive responses from the stream and send them to callers.
resp = resp_stream.next() => {
// Shut down if the stream is closed, and bail out on stream errors.
let Some(resp) = resp.transpose()? else {
return Ok(())
};
// Send the response to the caller. Ignore errors if the caller went away.
let Some(resp_tx) = callers.remove(&resp.request_id) else {
warn!("received response for unknown request ID: {}", resp.request_id);
continue;
};
_ = resp_tx.send(Ok(resp));
}
}
}
}
}
impl Reapable for StreamPool {
/// Reaps streams that have been idle since before the cutoff.
fn reap_idle(&self, cutoff: Instant) {
self.idle
.lock()
.unwrap()
.retain(|_, entry| entry.idle_since >= cutoff);
self.streams.lock().unwrap().retain(|_, entry| {
let Some(idle_since) = entry.idle_since else {
assert_ne!(entry.queue_depth, 0, "empty stream not marked idle");
return true;
};
assert_eq!(entry.queue_depth, 0, "idle stream has requests");
idle_since >= cutoff
});
}
}
/// A stream acquired from the pool. Returned to the pool when dropped, unless there are still
/// in-flight requests on the stream, or the stream failed.
/// A pooled stream reference. Can be used to send a single request, to properly enforce queue
/// depth. Queue depth is already reserved and will be returned on drop.
pub struct StreamGuard {
pool: Weak<StreamPool>,
stream: Option<BiStream>, // Some until dropped
can_reuse: bool, // returned to pool if true
id: StreamID,
sender: RequestSender,
permit: Option<OwnedSemaphorePermit>, // None if pool is unbounded
}
impl StreamGuard {
/// Sends a request on the stream and awaits the response. If the future is dropped before it
/// resolves (e.g. due to a timeout or cancellation), the stream will be closed to cancel the
/// request and is not returned to the pool. The same is true if the stream errors, in which
/// case the caller can't send further requests on the stream.
/// Sends a request on the stream and awaits the response. Consumes the guard, since it's only
/// valid for a single request (to enforce queue depth). This also drops the guard on return and
/// returns the queue depth quota to the pool.
///
/// We only support sending a single request at a time, to eliminate head-of-line blocking. See
/// module documentation for details.
/// The `GetPageRequest::request_id` must be unique across in-flight requests.
///
/// NB: errors are often returned as `GetPageResponse::status_code` instead of `tonic::Status`
/// to avoid tearing down the stream for per-request errors. Callers must check this.
pub async fn send(
&mut self,
self,
req: page_api::GetPageRequest,
) -> tonic::Result<page_api::GetPageResponse> {
let req_id = req.request_id;
let stream = self.stream.as_mut().expect("not dropped");
let (resp_tx, resp_rx) = oneshot::channel();
// Mark the stream as not reusable while the request is in flight. We can't return the
// stream to the pool until we receive the response, to avoid head-of-line blocking and
// stale responses. Failed streams can't be reused either.
if !self.can_reuse {
return Err(tonic::Status::internal("stream can't be reused"));
}
self.can_reuse = false;
// Send the request and receive the response.
//
// NB: this uses a watch channel, so it's unsafe to change this code to pipeline requests.
stream
.sender
.send(req)
self.sender
.send((req, resp_tx))
.await
.map_err(|_| tonic::Status::unavailable("stream closed"))?;
let resp = stream
.receiver
.next()
resp_rx
.await
.ok_or_else(|| tonic::Status::unavailable("stream closed"))??;
if resp.request_id != req_id {
return Err(tonic::Status::internal(format!(
"response ID {} does not match request ID {}",
resp.request_id, req_id
)));
}
// Success, mark the stream as reusable.
self.can_reuse = true;
Ok(resp)
.map_err(|_| tonic::Status::unavailable("stream closed"))?
}
}
@@ -585,21 +697,26 @@ impl Drop for StreamGuard {
return; // pool was dropped
};
// If the stream isn't reusable, it can't be returned to the pool.
if !self.can_reuse {
return;
// Release the queue depth reservation on drop. This can prematurely decrement it if dropped
// before the response is received, but that's okay.
//
// TODO: actually, it's probably not okay. Queue depth release should be moved into the
// stream task, such that it continues to account for the queue depth slot until the server
// responds. Otherwise, if a slow request times out and keeps blocking the stream, the
// server will keep waiting on it and we can pile on subsequent requests (including the
// timeout retry) in the same stream and get blocked. But we may also want to avoid blocking
// requests on e.g. LSN waits and layer downloads, instead returning early to free up the
// stream. Or just scale out streams with a queue depth of 1 to sidestep all head-of-line
// blocking. TBD.
let mut streams = pool.streams.lock().unwrap();
let entry = streams.get_mut(&self.id).expect("unknown stream");
assert!(entry.idle_since.is_none(), "active stream marked idle");
assert!(entry.queue_depth > 0, "stream queue underflow");
entry.queue_depth -= 1;
if entry.queue_depth == 0 {
entry.idle_since = Some(Instant::now()); // mark stream as idle
}
// Place the idle stream back into the pool.
let entry = StreamEntry {
stream: self.stream.take().expect("dropped once"),
idle_since: Instant::now(),
};
pool.idle
.lock()
.unwrap()
.insert(entry.stream.client.id, entry);
_ = self.permit; // returned on drop, referenced for visibility
}
}

View File

@@ -1,6 +1,5 @@
use std::time::Duration;
use futures::future::pending;
use tokio::time::Instant;
use tracing::{error, info, warn};
@@ -9,54 +8,60 @@ use utils::backoff::exponential_backoff_duration;
/// A retry handler for Pageserver gRPC requests.
///
/// This is used instead of backoff::retry for better control and observability.
pub struct Retry {
/// Timeout across all retry attempts. If None, retries forever.
pub timeout: Option<Duration>,
/// The initial backoff duration. The first retry does not use a backoff.
pub base_backoff: Duration,
/// The maximum backoff duration.
pub max_backoff: Duration,
}
pub struct Retry;
impl Retry {
/// Runs the given async closure with timeouts and retries (exponential backoff). Logs errors,
/// using the current tracing span for context.
/// The per-request timeout.
// TODO: tune these, and/or make them configurable. Should we retry forever?
const REQUEST_TIMEOUT: Duration = Duration::from_secs(10);
/// The total timeout across all attempts
const TOTAL_TIMEOUT: Duration = Duration::from_secs(60);
/// The initial backoff duration.
const BASE_BACKOFF: Duration = Duration::from_millis(10);
/// The maximum backoff duration.
const MAX_BACKOFF: Duration = Duration::from_secs(10);
/// If true, log successful requests. For debugging.
const LOG_SUCCESS: bool = false;
/// Runs the given async closure with timeouts and retries (exponential backoff), passing the
/// attempt number starting at 0. Logs errors, using the current tracing span for context.
///
/// Only certain gRPC status codes are retried, see [`Self::should_retry`].
/// Only certain gRPC status codes are retried, see [`Self::should_retry`]. For default
/// timeouts, see [`Self::REQUEST_TIMEOUT`] and [`Self::TOTAL_TIMEOUT`].
pub async fn with<T, F, O>(&self, mut f: F) -> tonic::Result<T>
where
F: FnMut(usize) -> O, // pass attempt number, starting at 0
F: FnMut(usize) -> O, // takes attempt number, starting at 0
O: Future<Output = tonic::Result<T>>,
{
let started = Instant::now();
let deadline = self.timeout.map(|timeout| started + timeout);
let deadline = started + Self::TOTAL_TIMEOUT;
let mut last_error = None;
let mut retries = 0;
loop {
// Set up a future to wait for the backoff, if any, and run the closure.
// Set up a future to wait for the backoff (if any) and run the request with a timeout.
let backoff_and_try = async {
// NB: sleep() always sleeps 1ms, even when given a 0 argument. See:
// https://github.com/tokio-rs/tokio/issues/6866
if let Some(backoff) = self.backoff_duration(retries) {
if let Some(backoff) = Self::backoff_duration(retries) {
tokio::time::sleep(backoff).await;
}
f(retries).await
let request_started = Instant::now();
tokio::time::timeout(Self::REQUEST_TIMEOUT, f(retries))
.await
.map_err(|_| {
tonic::Status::deadline_exceeded(format!(
"request timed out after {:.3}s",
request_started.elapsed().as_secs_f64()
))
})?
};
// Set up a future for the timeout, if any.
let timeout = async {
match deadline {
Some(deadline) => tokio::time::sleep_until(deadline).await,
None => pending().await,
}
};
// Wait for the backoff and request, or bail out if the timeout is exceeded.
// Wait for the backoff and request, or bail out if the total timeout is exceeded.
let result = tokio::select! {
result = backoff_and_try => result,
_ = timeout => {
_ = tokio::time::sleep_until(deadline) => {
let last_error = last_error.unwrap_or_else(|| {
tonic::Status::deadline_exceeded(format!(
"request timed out after {:.3}s",
@@ -74,7 +79,7 @@ impl Retry {
match result {
// Success, return the result.
Ok(result) => {
if retries > 0 {
if retries > 0 || Self::LOG_SUCCESS {
info!(
"request succeeded after {retries} retries in {:.3}s",
started.elapsed().as_secs_f64(),
@@ -107,13 +112,12 @@ impl Retry {
}
}
/// Returns the backoff duration for the given retry attempt, or None for no backoff. The first
/// attempt and first retry never backs off, so this returns None for 0 and 1 retries.
fn backoff_duration(&self, retries: usize) -> Option<Duration> {
/// Returns the backoff duration for the given retry attempt, or None for no backoff.
fn backoff_duration(retry: usize) -> Option<Duration> {
let backoff = exponential_backoff_duration(
(retries as u32).saturating_sub(1), // first retry does not back off
self.base_backoff.as_secs_f64(),
self.max_backoff.as_secs_f64(),
retry as u32,
Self::BASE_BACKOFF.as_secs_f64(),
Self::MAX_BACKOFF.as_secs_f64(),
);
(!backoff.is_zero()).then_some(backoff)
}

Some files were not shown because too many files have changed in this diff Show More