Compare commits

..

11 Commits

Author SHA1 Message Date
Victor Polevoy
e191daa152 Add a test for generating pprof from collapsed 2025-08-21 12:40:57 +02:00
Victor Polevoy
7572ffc725 Document the continuous profiling of computes 2025-08-21 11:53:34 +02:00
Victor Polevoy
0aacbc2583 Install bpfcc tools in the compute node docker image.
This provides the binaries and libraries required for the continuous
profiling at runtime.
2025-08-21 11:16:00 +02:00
Victor Polevoy
cb9874dc4e After switching to ext4 img, symlinking is possible 2025-08-20 16:24:07 +02:00
Victor Polevoy
891c1fe512 Fix the kernel headers enabling 2025-08-02 14:03:25 +02:00
Victor Polevoy
9cae494555 TRY TRY TRY (pink) 2025-07-31 13:02:45 +02:00
Victor Polevoy
c3e6d360b5 Clarify the reason for postgres comment 2025-07-18 12:52:57 +02:00
Victor Polevoy
9e69e24a52 Use the supplied kernel headers and modules 2025-07-18 12:52:45 +02:00
Victor Polevoy
8dbf5a8c5b Add annotation to skip on CI and macos 2025-07-15 12:30:34 +02:00
Victor Polevoy
463429af97 build-tools 2025-07-11 16:20:04 +02:00
Victor Polevoy
d8b0c0834e Implement HTTP endpoint for compute profiling.
Exposes an endpoint "/profile/cpu" for profiling the postgres
processes (currently spawned and the new ones) using "perf".

Adds the corresponding python test to test the added endpoint
and confirm the output expected is the profiling data in the
expected format.

Add "perf" binary to the sudo list.

Fix python poetry ruff

Address the clippy lints

Document the code

Format python code

Address code review

Prettify

Embed profile_pb2.py and small code/test fixes.

Make the code slightly better.

1. Makes optional the sampling_frequency parameter for profiling.
2. Avoids using unsafe code when killing a child.

Better code, better tests

More tests

Separate start and stop of profiling

Correctly check for the exceptions

Address clippy lint

Final fixes.

1. Allows the perf to be found in $PATH instead of having the path
hardcoded.
2. Changes the path to perf in the sudoers file so that the compute
can run it properly.
3. Changes the way perf is invoked, now it is with sudo and the path
from $PATH.
4. Removes the authentication requirement from the /profile/cpu/
endpoint.

hakari thing

Python fixes

Fix python formatting

More python fixes

Update poetry lock

Fix ruff

Address the review comments

Fix the tests

Try fixing the flaky test for pg17?

Try fixing the flaky test for pg17?

PYTHON

Fix the tests

Remove the PROGRESS parameter

Remove unused

Increase the timeout due to concurrency

Increase the timeout to 60

Increase the profiling window timeout

Try this

Lets see the error

Just log all the errors

Add perf into the build environment

uijdfghjdf

Update tempfile to 3.20

Snapshot

Use bbc-profile

Update tempfile to 3.20

Provide bpfcc-tools in debian

Properly respond with status

Python check

Fix build-tools dockerfile

Add path probation for the bcc profile

Try err printing

Refactor

Add bpfcc-tools to the final image

Add error context

sudo not found?

Print more errors for verbosity

Remove procfs and use libproc

Update hakari

Debug sudo in CI

Rebase and adjust hakari

remove leftover

Add archiving support

Correct the paths to the perf binary

Try hardcoded sudo path

Add sudo into build-tools dockerfile

Minor cleanup

Print out the sudoers file from github

Stop the tests earlier

Add the sudoers entry for nonroot, install kmod for modprobe for bcc-profile

Try hacking the kernel headers for bcc-profile

Redeclare the kernel version argument

Try using the kernel of the runner

Try another way

Check bpfcc-tools
2025-07-11 12:53:48 +02:00
286 changed files with 6216 additions and 24151 deletions

View File

@@ -27,4 +27,4 @@
!storage_controller/
!vendor/postgres-*/
!workspace_hack/
!build-tools/patches
!build_tools/patches

View File

@@ -31,7 +31,6 @@ config-variables:
- NEON_PROD_AWS_ACCOUNT_ID
- PGREGRESS_PG16_PROJECT_ID
- PGREGRESS_PG17_PROJECT_ID
- PREWARM_PGBENCH_SIZE
- REMOTE_STORAGE_AZURE_CONTAINER
- REMOTE_STORAGE_AZURE_REGION
- SLACK_CICD_CHANNEL_ID

View File

@@ -176,13 +176,7 @@ runs:
fi
if [[ $BUILD_TYPE == "debug" && $RUNNER_ARCH == 'X64' ]]; then
# We don't use code coverage for regression tests (the step is disabled),
# so there's no need to collect it.
# Ref https://github.com/neondatabase/neon/issues/4540
# cov_prefix=(scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/coverage run)
cov_prefix=()
# Explicitly set LLVM_PROFILE_FILE to /dev/null to avoid writing *.profraw files
export LLVM_PROFILE_FILE=/dev/null
cov_prefix=(scripts/coverage "--profraw-prefix=$GITHUB_JOB" --dir=/tmp/coverage run)
else
cov_prefix=()
fi

View File

@@ -150,7 +150,7 @@ jobs:
secretKey: ${{ secrets.HETZNER_CACHE_SECRET_KEY }}
use-fallback: false
path: pg_install/v14
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ inputs.build-type }}-pg-${{ steps.pg_v14_rev.outputs.pg_rev }}-bookworm-${{ hashFiles('Makefile', 'build-tools/Dockerfile') }}
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ inputs.build-type }}-pg-${{ steps.pg_v14_rev.outputs.pg_rev }}-bookworm-${{ hashFiles('Makefile', 'build-tools.Dockerfile') }}
- name: Cache postgres v15 build
id: cache_pg_15
@@ -162,7 +162,7 @@ jobs:
secretKey: ${{ secrets.HETZNER_CACHE_SECRET_KEY }}
use-fallback: false
path: pg_install/v15
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ inputs.build-type }}-pg-${{ steps.pg_v15_rev.outputs.pg_rev }}-bookworm-${{ hashFiles('Makefile', 'build-tools/Dockerfile') }}
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ inputs.build-type }}-pg-${{ steps.pg_v15_rev.outputs.pg_rev }}-bookworm-${{ hashFiles('Makefile', 'build-tools.Dockerfile') }}
- name: Cache postgres v16 build
id: cache_pg_16
@@ -174,7 +174,7 @@ jobs:
secretKey: ${{ secrets.HETZNER_CACHE_SECRET_KEY }}
use-fallback: false
path: pg_install/v16
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ inputs.build-type }}-pg-${{ steps.pg_v16_rev.outputs.pg_rev }}-bookworm-${{ hashFiles('Makefile', 'build-tools/Dockerfile') }}
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ inputs.build-type }}-pg-${{ steps.pg_v16_rev.outputs.pg_rev }}-bookworm-${{ hashFiles('Makefile', 'build-tools.Dockerfile') }}
- name: Cache postgres v17 build
id: cache_pg_17
@@ -186,7 +186,7 @@ jobs:
secretKey: ${{ secrets.HETZNER_CACHE_SECRET_KEY }}
use-fallback: false
path: pg_install/v17
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ inputs.build-type }}-pg-${{ steps.pg_v17_rev.outputs.pg_rev }}-bookworm-${{ hashFiles('Makefile', 'build-tools/Dockerfile') }}
key: v1-${{ runner.os }}-${{ runner.arch }}-${{ inputs.build-type }}-pg-${{ steps.pg_v17_rev.outputs.pg_rev }}-bookworm-${{ hashFiles('Makefile', 'build-tools.Dockerfile') }}
- name: Build all
# Note: the Makefile picks up BUILD_TYPE and CARGO_PROFILE from the env variables

View File

@@ -219,7 +219,6 @@ jobs:
--ignore test_runner/performance/test_cumulative_statistics_persistence.py
--ignore test_runner/performance/test_perf_many_relations.py
--ignore test_runner/performance/test_perf_oltp_large_tenant.py
--ignore test_runner/performance/test_lfc_prewarm.py
env:
BENCHMARK_CONNSTR: ${{ steps.create-neon-project.outputs.dsn }}
VIP_VAP_ACCESS_TOKEN: "${{ secrets.VIP_VAP_ACCESS_TOKEN }}"
@@ -411,77 +410,6 @@ jobs:
env:
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
prewarm-test:
if: ${{ github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null }}
permissions:
contents: write
statuses: write
id-token: write # aws-actions/configure-aws-credentials
env:
PGBENCH_SIZE: ${{ vars.PREWARM_PGBENCH_SIZE }}
POSTGRES_DISTRIB_DIR: /tmp/neon/pg_install
DEFAULT_PG_VERSION: 17
TEST_OUTPUT: /tmp/test_output
BUILD_TYPE: remote
SAVE_PERF_REPORT: ${{ github.event.inputs.save_perf_report || ( github.ref_name == 'main' ) }}
PLATFORM: "neon-staging"
runs-on: [ self-hosted, us-east-2, x64 ]
container:
image: ghcr.io/neondatabase/build-tools:pinned-bookworm
credentials:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
options: --init
steps:
- name: Harden the runner (Audit all outbound calls)
uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0
with:
egress-policy: audit
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@e3dd6a429d7300a6a4c196c26e071d42e0343502 # v4.0.2
with:
aws-region: eu-central-1
role-to-assume: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
role-duration-seconds: 18000 # 5 hours
- name: Download Neon artifact
uses: ./.github/actions/download
with:
name: neon-${{ runner.os }}-${{ runner.arch }}-release-artifact
path: /tmp/neon/
prefix: latest
aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
- name: Run prewarm benchmark
uses: ./.github/actions/run-python-test-set
with:
build_type: ${{ env.BUILD_TYPE }}
test_selection: performance/test_lfc_prewarm.py
run_in_parallel: false
save_perf_report: ${{ env.SAVE_PERF_REPORT }}
extra_params: -m remote_cluster --timeout 5400
pg_version: ${{ env.DEFAULT_PG_VERSION }}
aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
env:
VIP_VAP_ACCESS_TOKEN: "${{ secrets.VIP_VAP_ACCESS_TOKEN }}"
PERF_TEST_RESULT_CONNSTR: "${{ secrets.PERF_TEST_RESULT_CONNSTR }}"
NEON_API_KEY: ${{ secrets.NEON_STAGING_API_KEY }}
- name: Create Allure report
id: create-allure-report
if: ${{ !cancelled() }}
uses: ./.github/actions/allure-report-generate
with:
store-test-results-into-db: true
aws-oidc-role-arn: ${{ vars.DEV_AWS_OIDC_ROLE_ARN }}
env:
REGRESS_TEST_RESULT_CONNSTR_NEW: ${{ secrets.REGRESS_TEST_RESULT_CONNSTR_NEW }}
generate-matrices:
if: ${{ github.event.inputs.run_only_pgvector_tests == 'false' || github.event.inputs.run_only_pgvector_tests == null }}
# Create matrices for the benchmarking jobs, so we run benchmarks on rds only once a week (on Saturday)

View File

@@ -72,7 +72,7 @@ jobs:
ARCHS: ${{ inputs.archs || '["x64","arm64"]' }}
DEBIANS: ${{ inputs.debians || '["bullseye","bookworm"]' }}
IMAGE_TAG: |
${{ hashFiles('build-tools/Dockerfile',
${{ hashFiles('build-tools.Dockerfile',
'.github/workflows/build-build-tools-image.yml') }}
run: |
echo "archs=${ARCHS}" | tee -a ${GITHUB_OUTPUT}
@@ -144,7 +144,7 @@ jobs:
- uses: docker/build-push-action@471d1dc4e07e5cdedd4c2171150001c434f0b7a4 # v6.15.0
with:
file: build-tools/Dockerfile
file: build-tools.Dockerfile
context: .
provenance: false
push: true

View File

@@ -87,27 +87,22 @@ jobs:
uses: ./.github/workflows/build-build-tools-image.yml
secrets: inherit
lint-yamls:
needs: [ meta, check-permissions, build-build-tools-image ]
lint-openapi-spec:
runs-on: ubuntu-22.04
needs: [ meta, check-permissions ]
# We do need to run this in `.*-rc-pr` because of hotfixes.
if: ${{ contains(fromJSON('["pr", "push-main", "storage-rc-pr", "proxy-rc-pr", "compute-rc-pr"]'), needs.meta.outputs.run-kind) }}
runs-on: [ self-hosted, small ]
container:
image: ${{ needs.build-build-tools-image.outputs.image }}
credentials:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
options: --init
steps:
- name: Harden the runner (Audit all outbound calls)
uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0
with:
egress-policy: audit
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- run: make -C compute manifest-schema-validation
- uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- run: make lint-openapi-spec
check-codestyle-python:
@@ -222,6 +217,28 @@ jobs:
build-tools-image: ${{ needs.build-build-tools-image.outputs.image }}-bookworm
secrets: inherit
validate-compute-manifest:
runs-on: ubuntu-22.04
needs: [ meta, check-permissions ]
# We do need to run this in `.*-rc-pr` because of hotfixes.
if: ${{ contains(fromJSON('["pr", "push-main", "storage-rc-pr", "proxy-rc-pr", "compute-rc-pr"]'), needs.meta.outputs.run-kind) }}
steps:
- name: Harden the runner (Audit all outbound calls)
uses: step-security/harden-runner@4d991eb9b905ef189e4c376166672c3f2f230481 # v2.11.0
with:
egress-policy: audit
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Node.js
uses: actions/setup-node@49933ea5288caeca8642d1e84afbd3f7d6820020 # v4.4.0
with:
node-version: '24'
- name: Validate manifest against schema
run: |
make -C compute manifest-schema-validation
build-and-test-locally:
needs: [ meta, build-build-tools-image ]
# We do need to run this in `.*-rc-pr` because of hotfixes.

4
.gitignore vendored
View File

@@ -15,7 +15,6 @@ neon.iml
/.neon
/integration_tests/.neon
compaction-suite-results.*
pgxn/neon/communicator/communicator_bindings.h
docker-compose/docker-compose-parallel.yml
# Coverage
@@ -30,6 +29,3 @@ docker-compose/docker-compose-parallel.yml
# pgindent typedef lists
*.list
# Node
**/node_modules/

8
.gitmodules vendored
View File

@@ -1,16 +1,16 @@
[submodule "vendor/postgres-v14"]
path = vendor/postgres-v14
url = ../postgres.git
url = https://github.com/neondatabase/postgres.git
branch = REL_14_STABLE_neon
[submodule "vendor/postgres-v15"]
path = vendor/postgres-v15
url = ../postgres.git
url = https://github.com/neondatabase/postgres.git
branch = REL_15_STABLE_neon
[submodule "vendor/postgres-v16"]
path = vendor/postgres-v16
url = ../postgres.git
url = https://github.com/neondatabase/postgres.git
branch = REL_16_STABLE_neon
[submodule "vendor/postgres-v17"]
path = vendor/postgres-v17
url = ../postgres.git
url = https://github.com/neondatabase/postgres.git
branch = REL_17_STABLE_neon

385
Cargo.lock generated
View File

@@ -253,17 +253,6 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8ab6b55fe97976e46f91ddbed8d147d966475dc29b2032757ba47e02376fbc3"
[[package]]
name = "atomic_enum"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99e1aca718ea7b89985790c94aad72d77533063fe00bc497bb79a7c2dae6a661"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "autocfg"
version = "1.1.0"
@@ -739,7 +728,7 @@ dependencies = [
"http 1.1.0",
"http-body 1.0.0",
"http-body-util",
"hyper 1.6.0",
"hyper 1.4.1",
"hyper-util",
"itoa",
"matchit 0.8.4",
@@ -1020,6 +1009,24 @@ dependencies = [
"serde",
]
[[package]]
name = "bindgen"
version = "0.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f"
dependencies = [
"bitflags 2.8.0",
"cexpr",
"clang-sys",
"itertools 0.12.1",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.100",
]
[[package]]
name = "bindgen"
version = "0.71.1"
@@ -1346,31 +1353,10 @@ dependencies = [
[[package]]
name = "communicator"
version = "0.0.0"
version = "0.1.0"
dependencies = [
"atomic_enum",
"axum 0.8.1",
"bytes",
"cbindgen",
"clashmap",
"http 1.1.0",
"libc",
"metrics",
"neon-shmem",
"nix 0.30.1",
"pageserver_api",
"pageserver_client_grpc",
"pageserver_page_api",
"prometheus",
"prost 0.13.5",
"thiserror 1.0.69",
"tokio",
"tokio-pipe",
"tonic 0.12.3",
"tracing",
"tracing-subscriber",
"uring-common",
"utils",
"workspace_hack",
]
@@ -1415,8 +1401,10 @@ dependencies = [
"hostname-validator",
"http 1.1.0",
"indexmap 2.9.0",
"inferno 0.12.0",
"itertools 0.10.5",
"jsonwebtoken",
"libproc",
"metrics",
"nix 0.30.1",
"notify",
@@ -1430,6 +1418,8 @@ dependencies = [
"postgres-types",
"postgres_initdb",
"postgres_versioninfo",
"pprof 0.15.0",
"prost 0.12.6",
"regex",
"remote_storage",
"reqwest",
@@ -1441,6 +1431,7 @@ dependencies = [
"serde_with",
"signal-hook",
"tar",
"tempfile",
"thiserror 1.0.69",
"tokio",
"tokio-postgres",
@@ -1705,9 +1696,9 @@ dependencies = [
[[package]]
name = "crossbeam-utils"
version = "0.8.21"
version = "0.8.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345"
[[package]]
name = "crossterm"
@@ -1951,7 +1942,6 @@ dependencies = [
"diesel_derives",
"itoa",
"serde_json",
"uuid",
]
[[package]]
@@ -2278,12 +2268,12 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
[[package]]
name = "errno"
version = "0.3.8"
version = "0.3.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245"
checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad"
dependencies = [
"libc",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
@@ -2443,7 +2433,7 @@ dependencies = [
"futures-core",
"futures-sink",
"http-body-util",
"hyper 1.6.0",
"hyper 1.4.1",
"hyper-util",
"pin-project",
"rand 0.8.5",
@@ -2882,6 +2872,15 @@ dependencies = [
"digest",
]
[[package]]
name = "home"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "hostname"
version = "0.4.0"
@@ -2991,7 +2990,7 @@ dependencies = [
"jsonwebtoken",
"metrics",
"once_cell",
"pprof",
"pprof 0.14.0",
"regex",
"routerify",
"rustls 0.23.27",
@@ -3014,9 +3013,9 @@ dependencies = [
[[package]]
name = "httparse"
version = "1.10.1"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904"
[[package]]
name = "httpdate"
@@ -3066,9 +3065,9 @@ dependencies = [
[[package]]
name = "hyper"
version = "1.6.0"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80"
checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05"
dependencies = [
"bytes",
"futures-channel",
@@ -3108,7 +3107,7 @@ checksum = "a0bea761b46ae2b24eb4aef630d8d1c398157b6fc29e6350ecf090a0b70c952c"
dependencies = [
"futures-util",
"http 1.1.0",
"hyper 1.6.0",
"hyper 1.4.1",
"hyper-util",
"rustls 0.22.4",
"rustls-pki-types",
@@ -3123,7 +3122,7 @@ version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3203a961e5c83b6f5498933e78b6b263e208c197b63e9c6c53cc82ffd3f63793"
dependencies = [
"hyper 1.6.0",
"hyper 1.4.1",
"hyper-util",
"pin-project-lite",
"tokio",
@@ -3132,21 +3131,20 @@ dependencies = [
[[package]]
name = "hyper-util"
version = "0.1.14"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc2fdfdbff08affe55bb779f33b053aa1fe5dd5b54c257343c17edfa55711bdb"
checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9"
dependencies = [
"bytes",
"futures-channel",
"futures-core",
"futures-util",
"http 1.1.0",
"http-body 1.0.0",
"hyper 1.6.0",
"libc",
"hyper 1.4.1",
"pin-project-lite",
"socket2",
"tokio",
"tower 0.4.13",
"tower-service",
"tracing",
]
@@ -3654,7 +3652,7 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
dependencies = [
"spin",
"spin 0.9.8",
]
[[package]]
@@ -3679,6 +3677,17 @@ version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
[[package]]
name = "libproc"
version = "0.14.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e78a09b56be5adbcad5aa1197371688dc6bb249a26da3bca2011ee2fb987ebfb"
dependencies = [
"bindgen 0.70.1",
"errno",
"libc",
]
[[package]]
name = "linux-raw-sys"
version = "0.4.14"
@@ -3691,6 +3700,12 @@ version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0b5399f6804fbab912acbd8878ed3532d506b7c951b8f9f164ef90fef39e3f4"
[[package]]
name = "linux-raw-sys"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12"
[[package]]
name = "litemap"
version = "0.7.4"
@@ -3699,9 +3714,9 @@ checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104"
[[package]]
name = "lock_api"
version = "0.4.13"
version = "0.4.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765"
checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16"
dependencies = [
"autocfg",
"scopeguard",
@@ -3857,7 +3872,7 @@ dependencies = [
"procfs",
"prometheus",
"rand 0.8.5",
"rand_distr 0.4.3",
"rand_distr",
"twox-hash",
]
@@ -3945,28 +3960,12 @@ checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
name = "neon-shmem"
version = "0.1.0"
dependencies = [
"libc",
"lock_api",
"nix 0.30.1",
"rand 0.9.1",
"rand_distr 0.5.1",
"rustc-hash 2.1.1",
"tempfile",
"thiserror 1.0.69",
"workspace_hack",
]
[[package]]
name = "neonart"
version = "0.1.0"
dependencies = [
"crossbeam-utils",
"rand 0.9.1",
"rand_distr 0.5.1",
"spin",
"tracing",
]
[[package]]
name = "never-say-never"
version = "6.6.666"
@@ -4400,21 +4399,17 @@ version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"axum 0.8.1",
"bytes",
"camino",
"clap",
"futures",
"hdrhistogram",
"http 1.1.0",
"humantime",
"humantime-serde",
"metrics",
"pageserver_api",
"pageserver_client",
"pageserver_client_grpc",
"pageserver_page_api",
"pprof",
"rand 0.8.5",
"reqwest",
"serde",
@@ -4499,7 +4494,6 @@ dependencies = [
"pageserver_client",
"pageserver_compaction",
"pageserver_page_api",
"peekable",
"pem",
"pin-project-lite",
"postgres-protocol",
@@ -4510,10 +4504,9 @@ dependencies = [
"postgres_ffi_types",
"postgres_initdb",
"posthog_client_lite",
"pprof",
"pprof 0.14.0",
"pq_proto",
"procfs",
"prost 0.13.5",
"rand 0.8.5",
"range-set-blaze",
"regex",
@@ -4808,15 +4801,6 @@ dependencies = [
"sha2",
]
[[package]]
name = "peekable"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "225f9651e475709164f871dc2f5724956be59cb9edb055372ffeeab01ec2d20b"
dependencies = [
"smallvec",
]
[[package]]
name = "pem"
version = "3.0.3"
@@ -5095,7 +5079,7 @@ name = "postgres_ffi"
version = "0.1.0"
dependencies = [
"anyhow",
"bindgen",
"bindgen 0.71.1",
"bytes",
"crc32c",
"criterion",
@@ -5106,7 +5090,7 @@ dependencies = [
"postgres",
"postgres_ffi_types",
"postgres_versioninfo",
"pprof",
"pprof 0.14.0",
"regex",
"serde",
"thiserror 1.0.69",
@@ -5196,6 +5180,30 @@ dependencies = [
"thiserror 1.0.69",
]
[[package]]
name = "pprof"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38a01da47675efa7673b032bf8efd8214f1917d89685e07e395ab125ea42b187"
dependencies = [
"aligned-vec",
"backtrace",
"cfg-if",
"findshlibs",
"inferno 0.11.21",
"libc",
"log",
"nix 0.26.4",
"once_cell",
"protobuf",
"protobuf-codegen",
"smallvec",
"spin 0.10.0",
"symbolic-demangle",
"tempfile",
"thiserror 2.0.11",
]
[[package]]
name = "pprof_util"
version = "0.7.0"
@@ -5271,7 +5279,7 @@ dependencies = [
"hex",
"lazy_static",
"procfs-core",
"rustix",
"rustix 0.38.41",
]
[[package]]
@@ -5407,6 +5415,57 @@ dependencies = [
"prost 0.13.5",
]
[[package]]
name = "protobuf"
version = "3.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d65a1d4ddae7d8b5de68153b48f6aa3bba8cb002b243dbdbc55a5afbc98f99f4"
dependencies = [
"once_cell",
"protobuf-support",
"thiserror 1.0.69",
]
[[package]]
name = "protobuf-codegen"
version = "3.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d3976825c0014bbd2f3b34f0001876604fe87e0c86cd8fa54251530f1544ace"
dependencies = [
"anyhow",
"once_cell",
"protobuf",
"protobuf-parse",
"regex",
"tempfile",
"thiserror 1.0.69",
]
[[package]]
name = "protobuf-parse"
version = "3.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b4aeaa1f2460f1d348eeaeed86aea999ce98c1bded6f089ff8514c9d9dbdc973"
dependencies = [
"anyhow",
"indexmap 2.9.0",
"log",
"protobuf",
"protobuf-support",
"tempfile",
"thiserror 1.0.69",
"which",
]
[[package]]
name = "protobuf-support"
version = "3.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e36c2f31e0a47f9280fb347ef5e461ffcd2c52dd520d8e216b52f93b0b0d7d6"
dependencies = [
"thiserror 1.0.69",
]
[[package]]
name = "proxy"
version = "0.1.0"
@@ -5419,7 +5478,6 @@ dependencies = [
"async-trait",
"atomic-take",
"aws-config",
"aws-credential-types",
"aws-sdk-iam",
"aws-sigv4",
"base64 0.22.1",
@@ -5451,7 +5509,7 @@ dependencies = [
"humantime",
"humantime-serde",
"hyper 0.14.30",
"hyper 1.6.0",
"hyper 1.4.1",
"hyper-util",
"indexmap 2.9.0",
"ipnet",
@@ -5459,7 +5517,6 @@ dependencies = [
"itoa",
"jose-jwa",
"jose-jwk",
"json",
"lasso",
"measured",
"metrics",
@@ -5476,7 +5533,7 @@ dependencies = [
"postgres_backend",
"pq_proto",
"rand 0.8.5",
"rand_distr 0.4.3",
"rand_distr",
"rcgen",
"redis",
"regex",
@@ -5487,7 +5544,7 @@ dependencies = [
"reqwest-tracing",
"rsa",
"rstest",
"rustc-hash 2.1.1",
"rustc-hash 1.1.0",
"rustls 0.23.27",
"rustls-native-certs 0.8.0",
"rustls-pemfile 2.1.1",
@@ -5610,16 +5667,6 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "rand"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
dependencies = [
"rand_chacha 0.9.0",
"rand_core 0.9.3",
]
[[package]]
name = "rand_chacha"
version = "0.2.2"
@@ -5640,16 +5687,6 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "rand_chacha"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core 0.9.3",
]
[[package]]
name = "rand_core"
version = "0.5.1"
@@ -5668,15 +5705,6 @@ dependencies = [
"getrandom 0.2.11",
]
[[package]]
name = "rand_core"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
dependencies = [
"getrandom 0.3.3",
]
[[package]]
name = "rand_distr"
version = "0.4.3"
@@ -5687,16 +5715,6 @@ dependencies = [
"rand 0.8.5",
]
[[package]]
name = "rand_distr"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463"
dependencies = [
"num-traits",
"rand 0.9.1",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
@@ -5895,7 +5913,7 @@ dependencies = [
"http-body-util",
"http-types",
"humantime-serde",
"hyper 1.6.0",
"hyper 1.4.1",
"itertools 0.10.5",
"metrics",
"once_cell",
@@ -5935,7 +5953,7 @@ dependencies = [
"http 1.1.0",
"http-body 1.0.0",
"http-body-util",
"hyper 1.6.0",
"hyper 1.4.1",
"hyper-rustls 0.26.0",
"hyper-util",
"ipnet",
@@ -5992,7 +6010,7 @@ dependencies = [
"futures",
"getrandom 0.2.11",
"http 1.1.0",
"hyper 1.6.0",
"hyper 1.4.1",
"parking_lot 0.11.2",
"reqwest",
"reqwest-middleware",
@@ -6204,6 +6222,19 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "rustix"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266"
dependencies = [
"bitflags 2.8.0",
"errno",
"libc",
"linux-raw-sys 0.9.4",
"windows-sys 0.59.0",
]
[[package]]
name = "rustls"
version = "0.21.12"
@@ -6378,7 +6409,6 @@ dependencies = [
"itertools 0.10.5",
"jsonwebtoken",
"metrics",
"nix 0.30.1",
"once_cell",
"pageserver_api",
"parking_lot 0.12.1",
@@ -6386,9 +6416,8 @@ dependencies = [
"postgres-protocol",
"postgres_backend",
"postgres_ffi",
"postgres_ffi_types",
"postgres_versioninfo",
"pprof",
"pprof 0.14.0",
"pq_proto",
"rand 0.8.5",
"regex",
@@ -6431,7 +6460,7 @@ dependencies = [
"anyhow",
"const_format",
"pageserver_api",
"postgres_ffi_types",
"postgres_ffi",
"postgres_versioninfo",
"pq_proto",
"serde",
@@ -6963,12 +6992,12 @@ dependencies = [
[[package]]
name = "socket2"
version = "0.5.10"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678"
checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9"
dependencies = [
"libc",
"windows-sys 0.52.0",
"windows-sys 0.48.0",
]
[[package]]
@@ -6980,6 +7009,15 @@ dependencies = [
"lock_api",
]
[[package]]
name = "spin"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d5fe4ccb98d9c292d56fec89a5e07da7fc4cf0dc11e156b41793132775d3e591"
dependencies = [
"lock_api",
]
[[package]]
name = "spinning_top"
version = "0.3.0"
@@ -7037,7 +7075,7 @@ dependencies = [
"http-body-util",
"http-utils",
"humantime",
"hyper 1.6.0",
"hyper 1.4.1",
"hyper-util",
"metrics",
"once_cell",
@@ -7110,7 +7148,6 @@ dependencies = [
"tokio-util",
"tracing",
"utils",
"uuid",
"workspace_hack",
]
@@ -7174,7 +7211,6 @@ dependencies = [
"pageserver_api",
"pageserver_client",
"reqwest",
"safekeeper_api",
"serde_json",
"storage_controller_client",
"tokio",
@@ -7341,14 +7377,14 @@ dependencies = [
[[package]]
name = "tempfile"
version = "3.14.0"
version = "3.20.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c"
checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1"
dependencies = [
"cfg-if",
"fastrand 2.2.0",
"getrandom 0.3.3",
"once_cell",
"rustix",
"rustix 1.0.7",
"windows-sys 0.59.0",
]
@@ -7648,16 +7684,6 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "tokio-pipe"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f213a84bffbd61b8fa0ba8a044b4bbe35d471d0b518867181e82bd5c15542784"
dependencies = [
"libc",
"tokio",
]
[[package]]
name = "tokio-postgres"
version = "0.7.10"
@@ -7754,7 +7780,6 @@ dependencies = [
"futures-core",
"pin-project-lite",
"tokio",
"tokio-util",
]
[[package]]
@@ -7862,7 +7887,7 @@ dependencies = [
"http 1.1.0",
"http-body 1.0.0",
"http-body-util",
"hyper 1.6.0",
"hyper 1.4.1",
"hyper-timeout",
"hyper-util",
"percent-encoding",
@@ -7892,7 +7917,7 @@ dependencies = [
"http 1.1.0",
"http-body 1.0.0",
"http-body-util",
"hyper 1.6.0",
"hyper 1.4.1",
"hyper-timeout",
"hyper-util",
"percent-encoding",
@@ -8386,7 +8411,7 @@ dependencies = [
"pem",
"pin-project-lite",
"postgres_connection",
"pprof",
"pprof 0.14.0",
"pq_proto",
"rand 0.8.5",
"regex",
@@ -8408,7 +8433,6 @@ dependencies = [
"tracing-error",
"tracing-subscriber",
"tracing-utils",
"uuid",
"walkdir",
]
@@ -8494,7 +8518,7 @@ dependencies = [
"pageserver_api",
"postgres_ffi",
"postgres_ffi_types",
"pprof",
"pprof 0.14.0",
"prost 0.13.5",
"remote_storage",
"serde",
@@ -8524,7 +8548,7 @@ name = "walproposer"
version = "0.1.0"
dependencies = [
"anyhow",
"bindgen",
"bindgen 0.71.1",
"postgres_ffi",
"utils",
]
@@ -8689,6 +8713,18 @@ dependencies = [
"rustls-pki-types",
]
[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
"either",
"home",
"once_cell",
"rustix 0.38.41",
]
[[package]]
name = "whoami"
version = "1.5.1"
@@ -8941,9 +8977,11 @@ dependencies = [
"camino",
"cc",
"chrono",
"clang-sys",
"clap",
"clap_builder",
"const-oid",
"criterion",
"crypto-bigint 0.5.5",
"der 0.7.8",
"deranged",
@@ -8968,8 +9006,9 @@ dependencies = [
"hex",
"hmac",
"hyper 0.14.30",
"hyper 1.6.0",
"hyper 1.4.1",
"hyper-util",
"indexmap 1.9.3",
"indexmap 2.9.0",
"itertools 0.12.1",
"lazy_static",
@@ -8986,7 +9025,6 @@ dependencies = [
"num-iter",
"num-rational",
"num-traits",
"once_cell",
"p256 0.13.2",
"parquet",
"prettyplease",
@@ -9008,6 +9046,7 @@ dependencies = [
"sha2",
"signature 2.2.0",
"smallvec",
"spin 0.9.8",
"spki 0.7.3",
"stable_deref_trait",
"subtle",
@@ -9022,12 +9061,14 @@ dependencies = [
"tokio-stream",
"tokio-util",
"toml_edit",
"tonic 0.12.3",
"tower 0.5.2",
"tracing",
"tracing-core",
"tracing-log",
"tracing-subscriber",
"url",
"uuid",
"zeroize",
"zstd",
"zstd-safe",

View File

@@ -35,7 +35,6 @@ members = [
"libs/pq_proto",
"libs/tenant_size_model",
"libs/metrics",
"libs/neonart",
"libs/postgres_connection",
"libs/remote_storage",
"libs/tracing-utils",
@@ -93,7 +92,6 @@ clap = { version = "4.0", features = ["derive", "env"] }
clashmap = { version = "1.0", features = ["raw-api"] }
comfy-table = "7.1"
const_format = "0.2"
crossbeam-utils = "0.8.21"
crc32c = "0.6"
diatomic-waker = { version = "0.2.3" }
either = "1.8"
@@ -132,7 +130,7 @@ jemalloc_pprof = { version = "0.7", features = ["symbolize", "flamegraph"] }
jsonwebtoken = "9"
lasso = "0.7"
libc = "0.2"
lock_api = "0.4.13"
libproc = "0.14"
md5 = "0.7.0"
measured = { version = "0.0.22", features=["lasso"] }
measured-process = { version = "0.0.22" }
@@ -153,7 +151,6 @@ parquet = { version = "53", default-features = false, features = ["zstd"] }
parquet_derive = "53"
pbkdf2 = { version = "0.12.1", features = ["simple", "std"] }
pem = "3.0.3"
peekable = "0.3.0"
pin-project-lite = "0.2"
pprof = { version = "0.14", features = ["criterion", "flamegraph", "frame-pointer", "prost-codec"] }
procfs = "0.16"
@@ -169,7 +166,7 @@ reqwest-middleware = "0.4"
reqwest-retry = "0.7"
routerify = "3"
rpds = "0.13"
rustc-hash = "2.1.1"
rustc-hash = "1.1.0"
rustls = { version = "0.23.16", default-features = false }
rustls-pemfile = "2"
rustls-pki-types = "1.11"
@@ -190,7 +187,6 @@ smallvec = "1.11"
smol_str = { version = "0.2.0", features = ["serde"] }
socket2 = "0.5"
spki = "0.7.3"
spin = "0.9.8"
strum = "0.26"
strum_macros = "0.26"
"subtle" = "2.5.0"
@@ -202,10 +198,11 @@ thiserror = "1.0"
tikv-jemallocator = { version = "0.6", features = ["profiling", "stats", "unprefixed_malloc_on_supported_platforms"] }
tikv-jemalloc-ctl = { version = "0.6", features = ["stats"] }
tokio = { version = "1.43.1", features = ["macros"] }
tokio-epoll-uring = { git = "https://github.com/neondatabase/tokio-epoll-uring.git" , branch = "main" }
tokio-io-timeout = "1.2.0"
tokio-postgres-rustls = "0.12.0"
tokio-rustls = { version = "0.26.0", default-features = false, features = ["tls12", "ring"]}
tokio-stream = { version = "0.1", features = ["sync"] }
tokio-stream = "0.1"
tokio-tar = "0.3"
tokio-util = { version = "0.7.10", features = ["io", "io-util", "rt"] }
toml = "0.8"
@@ -243,9 +240,6 @@ x509-cert = { version = "0.2.5" }
env_logger = "0.11"
log = "0.4"
tokio-epoll-uring = { git = "https://github.com/neondatabase/tokio-epoll-uring.git" , branch = "main" }
uring-common = { git = "https://github.com/neondatabase/tokio-epoll-uring.git" , branch = "main" }
## Libraries from neondatabase/ git forks, ideally with changes to be upstreamed
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon" }
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon" }
@@ -285,6 +279,7 @@ safekeeper_api = { version = "0.1", path = "./libs/safekeeper_api" }
safekeeper_client = { path = "./safekeeper/client" }
storage_broker = { version = "0.1", path = "./storage_broker/" } # Note: main broker code is inside the binary crate, so linking with the library shouldn't be heavy.
storage_controller_client = { path = "./storage_controller/client" }
tempfile = "3"
tenant_size_model = { version = "0.1", path = "./libs/tenant_size_model/" }
tracing-utils = { version = "0.1", path = "./libs/tracing-utils/" }
utils = { version = "0.1", path = "./libs/utils/" }

View File

@@ -109,6 +109,8 @@ RUN set -e \
libreadline-dev \
libseccomp-dev \
ca-certificates \
bpfcc-tools \
sudo \
openssl \
unzip \
curl \

View File

@@ -2,7 +2,7 @@ ROOT_PROJECT_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
# Where to install Postgres, default is ./pg_install, maybe useful for package
# managers.
POSTGRES_INSTALL_DIR ?= $(ROOT_PROJECT_DIR)/pg_install
POSTGRES_INSTALL_DIR ?= $(ROOT_PROJECT_DIR)/pg_install/
# Supported PostgreSQL versions
POSTGRES_VERSIONS = v17 v16 v15 v14
@@ -14,7 +14,7 @@ POSTGRES_VERSIONS = v17 v16 v15 v14
# it is derived from BUILD_TYPE.
# All intermediate build artifacts are stored here.
BUILD_DIR := $(ROOT_PROJECT_DIR)/build
BUILD_DIR := build
ICU_PREFIX_DIR := /usr/local/icu
@@ -212,7 +212,7 @@ neon-pgindent: postgres-v17-pg-bsd-indent neon-pg-ext-v17
FIND_TYPEDEF=$(ROOT_PROJECT_DIR)/vendor/postgres-v17/src/tools/find_typedef \
INDENT=$(BUILD_DIR)/v17/src/tools/pg_bsd_indent/pg_bsd_indent \
PGINDENT_SCRIPT=$(ROOT_PROJECT_DIR)/vendor/postgres-v17/src/tools/pgindent/pgindent \
-C $(BUILD_DIR)/pgxn-v17/neon \
-C $(BUILD_DIR)/neon-v17 \
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile pgindent
@@ -220,15 +220,11 @@ neon-pgindent: postgres-v17-pg-bsd-indent neon-pg-ext-v17
setup-pre-commit-hook:
ln -s -f $(ROOT_PROJECT_DIR)/pre-commit.py .git/hooks/pre-commit
build-tools/node_modules: build-tools/package.json
cd build-tools && $(if $(CI),npm ci,npm install)
touch build-tools/node_modules
.PHONY: lint-openapi-spec
lint-openapi-spec: build-tools/node_modules
lint-openapi-spec:
# operation-2xx-response: pageserver timeline delete returns 404 on success
find . -iname "openapi_spec.y*ml" -exec\
npx --prefix=build-tools/ redocly\
docker run --rm -v ${PWD}:/spec ghcr.io/redocly/cli:1.34.4\
--skip-rule=operation-operationId --skip-rule=operation-summary --extends=minimal\
--skip-rule=no-server-example.com --skip-rule=operation-2xx-response\
lint {} \+

View File

@@ -35,7 +35,7 @@ RUN echo 'Acquire::Retries "5";' > /etc/apt/apt.conf.d/80-retries && \
echo -e "retry_connrefused=on\ntimeout=15\ntries=5\nretry-on-host-error=on\n" > /root/.wgetrc && \
echo -e "--retry-connrefused\n--connect-timeout 15\n--retry 5\n--max-time 300\n" > /root/.curlrc
COPY build-tools/patches/pgcopydbv017.patch /pgcopydbv017.patch
COPY build_tools/patches/pgcopydbv017.patch /pgcopydbv017.patch
RUN if [ "${DEBIAN_VERSION}" = "bookworm" ]; then \
set -e && \
@@ -61,6 +61,9 @@ RUN if [ "${DEBIAN_VERSION}" = "bookworm" ]; then \
libpq5 \
libpq-dev \
libzstd-dev \
linux-perf \
bpfcc-tools \
linux-headers-$(case "$(uname -m)" in x86_64) echo amd64;; aarch64) echo arm64;; esac) \
postgresql-16 \
postgresql-server-dev-16 \
postgresql-common \
@@ -105,15 +108,21 @@ RUN echo 'Acquire::Retries "5";' > /etc/apt/apt.conf.d/80-retries && \
#
# 'gdb' is included so that we get backtraces of core dumps produced in
# regression tests
RUN set -e \
RUN set -ex \
&& KERNEL_VERSION="$(uname -r | cut -d'-' -f1 | sed 's/\.0$//')" \
&& echo KERNEL_VERSION=${KERNEL_VERSION} >> /etc/environment \
&& KERNEL_ARCH=$(uname -m | awk '{ if ($1 ~ /^(x86_64|i[3-6]86)$/) print "x86"; else if ($1 ~ /^(aarch64|arm.*)$/) print "aarch"; else print $1 }') \
&& echo KERNEL_ARCH=${KERNEL_ARCH} >> /etc/environment \
&& apt update \
&& apt install -y \
autoconf \
automake \
bc \
bison \
build-essential \
ca-certificates \
cmake \
cpio \
curl \
flex \
gdb \
@@ -122,8 +131,10 @@ RUN set -e \
gzip \
jq \
jsonnet \
kmod \
libcurl4-openssl-dev \
libbz2-dev \
libelf-dev \
libffi-dev \
liblzma-dev \
libncurses5-dev \
@@ -137,6 +148,11 @@ RUN set -e \
libxml2-dev \
libxmlsec1-dev \
libxxhash-dev \
linux-perf \
bpfcc-tools \
libbpfcc \
libbpfcc-dev \
linux-headers-$(case "$(uname -m)" in x86_64) echo amd64;; aarch64) echo arm64;; esac) \
lsof \
make \
netcat-openbsd \
@@ -144,6 +160,8 @@ RUN set -e \
openssh-client \
parallel \
pkg-config \
rsync \
sudo \
unzip \
wget \
xz-utils \
@@ -188,12 +206,6 @@ RUN curl -fsSL 'https://apt.llvm.org/llvm-snapshot.gpg.key' | apt-key add - \
&& bash -c 'for f in /usr/bin/clang*-${LLVM_VERSION} /usr/bin/llvm*-${LLVM_VERSION}; do ln -s "${f}" "${f%-${LLVM_VERSION}}"; done' \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
# Install node
ENV NODE_VERSION=24
RUN curl -fsSL https://deb.nodesource.com/setup_${NODE_VERSION}.x | bash - \
&& apt install -y nodejs \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
# Install docker
RUN curl -fsSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /usr/share/keyrings/docker-archive-keyring.gpg \
&& echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/debian ${DEBIAN_VERSION} stable" > /etc/apt/sources.list.d/docker.list \
@@ -204,6 +216,8 @@ RUN curl -fsSL https://download.docker.com/linux/ubuntu/gpg | gpg --dearmor -o /
# Configure sudo & docker
RUN usermod -aG sudo nonroot && \
echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers && \
mkdir -p /etc/sudoers.d && \
echo 'nonroot ALL=(ALL) NOPASSWD:ALL' > /etc/sudoers.d/nonroot && \
usermod -aG docker nonroot
# AWS CLI
@@ -317,14 +331,14 @@ RUN curl -sSO https://static.rust-lang.org/rustup/dist/$(uname -m)-unknown-linux
. "$HOME/.cargo/env" && \
cargo --version && rustup --version && \
rustup component add llvm-tools rustfmt clippy && \
cargo install rustfilt --locked --version ${RUSTFILT_VERSION} && \
cargo install cargo-hakari --locked --version ${CARGO_HAKARI_VERSION} && \
cargo install cargo-deny --locked --version ${CARGO_DENY_VERSION} && \
cargo install cargo-hack --locked --version ${CARGO_HACK_VERSION} && \
cargo install cargo-nextest --locked --version ${CARGO_NEXTEST_VERSION} && \
cargo install cargo-chef --locked --version ${CARGO_CHEF_VERSION} && \
cargo install diesel_cli --locked --version ${CARGO_DIESEL_CLI_VERSION} \
--features postgres-bundled --no-default-features && \
cargo install rustfilt --version ${RUSTFILT_VERSION} --locked && \
cargo install cargo-hakari --version ${CARGO_HAKARI_VERSION} --locked && \
cargo install cargo-deny --version ${CARGO_DENY_VERSION} --locked && \
cargo install cargo-hack --version ${CARGO_HACK_VERSION} --locked && \
cargo install cargo-nextest --version ${CARGO_NEXTEST_VERSION} --locked && \
cargo install cargo-chef --version ${CARGO_CHEF_VERSION} --locked && \
cargo install diesel_cli --version ${CARGO_DIESEL_CLI_VERSION} --locked \
--features postgres-bundled --no-default-features && \
rm -rf /home/nonroot/.cargo/registry && \
rm -rf /home/nonroot/.cargo/git

File diff suppressed because it is too large Load Diff

View File

@@ -1,8 +0,0 @@
{
"name": "build-tools",
"private": true,
"devDependencies": {
"@redocly/cli": "1.34.4",
"@sourcemeta/jsonschema": "10.0.0"
}
}

View File

@@ -50,9 +50,9 @@ jsonnetfmt-format:
jsonnetfmt --in-place $(jsonnet_files)
.PHONY: manifest-schema-validation
manifest-schema-validation: ../build-tools/node_modules
npx --prefix=../build-tools/ jsonschema validate -d https://json-schema.org/draft/2020-12/schema manifest.schema.json manifest.yaml
manifest-schema-validation: node_modules
node_modules/.bin/jsonschema validate -d https://json-schema.org/draft/2020-12/schema manifest.schema.json manifest.yaml
../build-tools/node_modules: ../build-tools/package.json
cd ../build-tools && $(if $(CI),npm ci,npm install)
touch ../build-tools/node_modules
node_modules: package.json
npm install
touch node_modules

View File

@@ -9,7 +9,7 @@
#
# build-tools: This contains Rust compiler toolchain and other tools needed at compile
# time. This is also used for the storage builds. This image is defined in
# build-tools/Dockerfile.
# build-tools.Dockerfile.
#
# build-deps: Contains C compiler, other build tools, and compile-time dependencies
# needed to compile PostgreSQL and most extensions. (Some extensions need
@@ -115,7 +115,7 @@ ARG EXTENSIONS=all
FROM $BASE_IMAGE_SHA AS build-deps
ARG DEBIAN_VERSION
# Keep in sync with build-tools/Dockerfile
# Keep in sync with build-tools.Dockerfile
ENV PROTOC_VERSION=25.1
# Use strict mode for bash to catch errors early
@@ -149,6 +149,9 @@ RUN case $DEBIAN_VERSION in \
ninja-build git autoconf automake libtool build-essential bison flex libreadline-dev \
zlib1g-dev libxml2-dev libcurl4-openssl-dev libossp-uuid-dev wget ca-certificates pkg-config libssl-dev \
libicu-dev libxslt1-dev liblz4-dev libzstd-dev zstd curl unzip g++ \
bpfcc-tools \
libbpfcc \
libbpfcc-dev \
libclang-dev \
jsonnet \
$VERSION_INSTALLS \
@@ -170,29 +173,7 @@ RUN case $DEBIAN_VERSION in \
FROM build-deps AS pg-build
ARG PG_VERSION
COPY vendor/postgres-${PG_VERSION:?} postgres
COPY compute/patches/postgres_fdw.patch .
COPY compute/patches/pg_stat_statements_pg14-16.patch .
COPY compute/patches/pg_stat_statements_pg17.patch .
RUN cd postgres && \
# Apply patches to some contrib extensions
# For example, we need to grant EXECUTE on pg_stat_statements_reset() to {privileged_role_name}.
# In vanilla Postgres this function is limited to Postgres role superuser.
# In Neon we have {privileged_role_name} role that is not a superuser but replaces superuser in some cases.
# We could add the additional grant statements to the Postgres repository but it would be hard to maintain,
# whenever we need to pick up a new Postgres version and we want to limit the changes in our Postgres fork,
# so we do it here.
case "${PG_VERSION}" in \
"v14" | "v15" | "v16") \
patch -p1 < /pg_stat_statements_pg14-16.patch; \
;; \
"v17") \
patch -p1 < /pg_stat_statements_pg17.patch; \
;; \
*) \
# To do not forget to migrate patches to the next major version
echo "No contrib patches for this PostgreSQL version" && exit 1;; \
esac && \
patch -p1 < /postgres_fdw.patch && \
export CONFIGURE_CMD="./configure CFLAGS='-O2 -g3 -fsigned-char' --enable-debug --with-openssl --with-uuid=ossp \
--with-icu --with-libxml --with-libxslt --with-lz4" && \
if [ "${PG_VERSION:?}" != "v14" ]; then \
@@ -206,6 +187,8 @@ RUN cd postgres && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/autoinc.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/dblink.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/postgres_fdw.control && \
file=/usr/local/pgsql/share/extension/postgres_fdw--1.0.sql && [ -e $file ] && \
echo 'GRANT USAGE ON FOREIGN DATA WRAPPER postgres_fdw TO neon_superuser;' >> $file && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/bloom.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/earthdistance.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/insert_username.control && \
@@ -215,7 +198,34 @@ RUN cd postgres && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgrowlocks.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/pgstattuple.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/refint.control && \
echo 'trusted = true' >> /usr/local/pgsql/share/extension/xml2.control
echo 'trusted = true' >> /usr/local/pgsql/share/extension/xml2.control && \
# We need to grant EXECUTE on pg_stat_statements_reset() to neon_superuser.
# In vanilla postgres this function is limited to Postgres role superuser.
# In neon we have neon_superuser role that is not a superuser but replaces superuser in some cases.
# We could add the additional grant statements to the postgres repository but it would be hard to maintain,
# whenever we need to pick up a new postgres version and we want to limit the changes in our postgres fork,
# so we do it here.
for file in /usr/local/pgsql/share/extension/pg_stat_statements--*.sql; do \
filename=$(basename "$file"); \
# Note that there are no downgrade scripts for pg_stat_statements, so we \
# don't have to modify any downgrade paths or (much) older versions: we only \
# have to make sure every creation of the pg_stat_statements_reset function \
# also adds execute permissions to the neon_superuser.
case $filename in \
pg_stat_statements--1.4.sql) \
# pg_stat_statements_reset is first created with 1.4
echo 'GRANT EXECUTE ON FUNCTION pg_stat_statements_reset() TO neon_superuser;' >> $file; \
;; \
pg_stat_statements--1.6--1.7.sql) \
# Then with the 1.6-1.7 migration it is re-created with a new signature, thus add the permissions back
echo 'GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) TO neon_superuser;' >> $file; \
;; \
pg_stat_statements--1.10--1.11.sql) \
# Then with the 1.10-1.11 migration it is re-created with a new signature again, thus add the permissions back
echo 'GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint, boolean) TO neon_superuser;' >> $file; \
;; \
esac; \
done;
# Set PATH for all the subsequent build steps
ENV PATH="/usr/local/pgsql/bin:$PATH"
@@ -1517,7 +1527,7 @@ WORKDIR /ext-src
COPY compute/patches/pg_duckdb_v031.patch .
COPY compute/patches/duckdb_v120.patch .
# pg_duckdb build requires source dir to be a git repo to get submodules
# allow {privileged_role_name} to execute some functions that in pg_duckdb are available to superuser only:
# allow neon_superuser to execute some functions that in pg_duckdb are available to superuser only:
# - extension management function duckdb.install_extension()
# - access to duckdb.extensions table and its sequence
RUN git clone --depth 1 --branch v0.3.1 https://github.com/duckdb/pg_duckdb.git pg_duckdb-src && \
@@ -1783,7 +1793,7 @@ RUN set -e \
#########################################################################################
FROM build-deps AS exporters
ARG TARGETARCH
# Keep sql_exporter version same as in build-tools/Dockerfile and
# Keep sql_exporter version same as in build-tools.Dockerfile and
# test_runner/regress/test_compute_metrics.py
# See comment on the top of the file regading `echo`, `-e` and `\n`
RUN if [ "$TARGETARCH" = "amd64" ]; then\
@@ -1981,6 +1991,10 @@ RUN apt update && \
locales \
lsof \
procps \
bpfcc-tools \
libbpfcc \
libbpfcc-dev \
libclang-dev \
rsyslog-gnutls \
screen \
tcpdump \

7
compute/package.json Normal file
View File

@@ -0,0 +1,7 @@
{
"name": "neon-compute",
"private": true,
"dependencies": {
"@sourcemeta/jsonschema": "9.3.4"
}
}

View File

@@ -1,26 +1,22 @@
diff --git a/sql/anon.sql b/sql/anon.sql
index 0cdc769..5eab1d6 100644
index 0cdc769..b450327 100644
--- a/sql/anon.sql
+++ b/sql/anon.sql
@@ -1141,3 +1141,19 @@ $$
@@ -1141,3 +1141,15 @@ $$
-- TODO : https://en.wikipedia.org/wiki/L-diversity
-- TODO : https://en.wikipedia.org/wiki/T-closeness
+
+-- NEON Patches
+
+GRANT ALL ON SCHEMA anon to neon_superuser;
+GRANT ALL ON ALL TABLES IN SCHEMA anon TO neon_superuser;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT ALL ON SCHEMA anon to %I', privileged_role_name);
+ EXECUTE format('GRANT ALL ON ALL TABLES IN SCHEMA anon TO %I', privileged_role_name);
+
+ IF current_setting('server_version_num')::int >= 150000 THEN
+ EXECUTE format('GRANT SET ON PARAMETER anon.transparent_dynamic_masking TO %I', privileged_role_name);
+ END IF;
+ IF current_setting('server_version_num')::int >= 150000 THEN
+ GRANT SET ON PARAMETER anon.transparent_dynamic_masking TO neon_superuser;
+ END IF;
+END $$;
diff --git a/sql/init.sql b/sql/init.sql
index 7da6553..9b6164b 100644

View File

@@ -21,21 +21,13 @@ index 3235cc8..6b892bc 100644
include Makefile.global
diff --git a/sql/pg_duckdb--0.2.0--0.3.0.sql b/sql/pg_duckdb--0.2.0--0.3.0.sql
index d777d76..3b54396 100644
index d777d76..af60106 100644
--- a/sql/pg_duckdb--0.2.0--0.3.0.sql
+++ b/sql/pg_duckdb--0.2.0--0.3.0.sql
@@ -1056,3 +1056,14 @@ GRANT ALL ON FUNCTION duckdb.cache(TEXT, TEXT) TO PUBLIC;
@@ -1056,3 +1056,6 @@ GRANT ALL ON FUNCTION duckdb.cache(TEXT, TEXT) TO PUBLIC;
GRANT ALL ON FUNCTION duckdb.cache_info() TO PUBLIC;
GRANT ALL ON FUNCTION duckdb.cache_delete(TEXT) TO PUBLIC;
GRANT ALL ON PROCEDURE duckdb.recycle_ddb() TO PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT ALL ON FUNCTION duckdb.install_extension(TEXT) TO %I', privileged_role_name);
+ EXECUTE format('GRANT ALL ON TABLE duckdb.extensions TO %I', privileged_role_name);
+ EXECUTE format('GRANT ALL ON SEQUENCE duckdb.extensions_table_seq TO %I', privileged_role_name);
+END $$;
+GRANT ALL ON FUNCTION duckdb.install_extension(TEXT) TO neon_superuser;
+GRANT ALL ON TABLE duckdb.extensions TO neon_superuser;
+GRANT ALL ON SEQUENCE duckdb.extensions_table_seq TO neon_superuser;

View File

@@ -1,34 +0,0 @@
diff --git a/contrib/pg_stat_statements/pg_stat_statements--1.4.sql b/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
index 58cdf600fce..8be57a996f6 100644
--- a/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
+++ b/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
@@ -46,3 +46,12 @@ GRANT SELECT ON pg_stat_statements TO PUBLIC;
-- Don't want this to be available to non-superusers.
REVOKE ALL ON FUNCTION pg_stat_statements_reset() FROM PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT EXECUTE ON FUNCTION pg_stat_statements_reset() TO %I', privileged_role_name);
+END $$;
diff --git a/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql b/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
index 6fc3fed4c93..256345a8f79 100644
--- a/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
+++ b/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
@@ -20,3 +20,12 @@ LANGUAGE C STRICT PARALLEL SAFE;
-- Don't want this to be available to non-superusers.
REVOKE ALL ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) FROM PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) TO %I', privileged_role_name);
+END $$;

View File

@@ -1,52 +0,0 @@
diff --git a/contrib/pg_stat_statements/pg_stat_statements--1.10--1.11.sql b/contrib/pg_stat_statements/pg_stat_statements--1.10--1.11.sql
index 0bb2c397711..32764db1d8b 100644
--- a/contrib/pg_stat_statements/pg_stat_statements--1.10--1.11.sql
+++ b/contrib/pg_stat_statements/pg_stat_statements--1.10--1.11.sql
@@ -80,3 +80,12 @@ LANGUAGE C STRICT PARALLEL SAFE;
-- Don't want this to be available to non-superusers.
REVOKE ALL ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint, boolean) FROM PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint, boolean) TO %I', privileged_role_name);
+END $$;
\ No newline at end of file
diff --git a/contrib/pg_stat_statements/pg_stat_statements--1.4.sql b/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
index 58cdf600fce..8be57a996f6 100644
--- a/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
+++ b/contrib/pg_stat_statements/pg_stat_statements--1.4.sql
@@ -46,3 +46,12 @@ GRANT SELECT ON pg_stat_statements TO PUBLIC;
-- Don't want this to be available to non-superusers.
REVOKE ALL ON FUNCTION pg_stat_statements_reset() FROM PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT EXECUTE ON FUNCTION pg_stat_statements_reset() TO %I', privileged_role_name);
+END $$;
diff --git a/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql b/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
index 6fc3fed4c93..256345a8f79 100644
--- a/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
+++ b/contrib/pg_stat_statements/pg_stat_statements--1.6--1.7.sql
@@ -20,3 +20,12 @@ LANGUAGE C STRICT PARALLEL SAFE;
-- Don't want this to be available to non-superusers.
REVOKE ALL ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) FROM PUBLIC;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT EXECUTE ON FUNCTION pg_stat_statements_reset(Oid, Oid, bigint) TO %I', privileged_role_name);
+END $$;

View File

@@ -1,17 +0,0 @@
diff --git a/contrib/postgres_fdw/postgres_fdw--1.0.sql b/contrib/postgres_fdw/postgres_fdw--1.0.sql
index a0f0fc1bf45..ee077f2eea6 100644
--- a/contrib/postgres_fdw/postgres_fdw--1.0.sql
+++ b/contrib/postgres_fdw/postgres_fdw--1.0.sql
@@ -16,3 +16,12 @@ LANGUAGE C STRICT;
CREATE FOREIGN DATA WRAPPER postgres_fdw
HANDLER postgres_fdw_handler
VALIDATOR postgres_fdw_validator;
+
+DO $$
+DECLARE
+ privileged_role_name text;
+BEGIN
+ privileged_role_name := current_setting('neon.privileged_role_name');
+
+ EXECUTE format('GRANT USAGE ON FOREIGN DATA WRAPPER postgres_fdw TO %I', privileged_role_name);
+END $$;

View File

@@ -39,6 +39,14 @@ commands:
user: nobody
sysvInitAction: respawn
shell: '/bin/sql_exporter -config.file=/etc/sql_exporter_autoscaling.yml -web.listen-address=:9499'
- name: enable-kernel-modules
user: root
sysvInitAction: sysinit
shell: mkdir -p /lib/ && ln -s /neonvm/tools/lib/modules /lib/
- name: enable-bpfs
user: root
sysvInitAction: sysinit
shell: mkdir -p /sys/kernel/debug && mount -t debugfs debugfs /sys/kernel/debug && mount -t bpf bpf /sys/fs/bpf && chmod 755 /sys/fs/bpf
# Rsyslog by default creates a unix socket under /dev/log . That's where Postgres sends logs also.
# We run syslog with postgres user so it can't create /dev/log. Instead we configure rsyslog to
# use a different path for the socket. The symlink actually points to our custom path.
@@ -65,7 +73,7 @@ files:
# regardless of hostname (ALL)
#
# Also allow it to shut down the VM. The fast_import job does that when it's finished.
postgres ALL=(root) NOPASSWD: /neonvm/bin/resize-swap, /neonvm/bin/set-disk-quota, /neonvm/bin/poweroff, /usr/sbin/rsyslogd
postgres ALL=(root) NOPASSWD: /neonvm/bin/resize-swap, /neonvm/bin/set-disk-quota, /neonvm/bin/poweroff, /usr/sbin/rsyslogd, /neonvm/tools/bin/perf, /usr/sbin/profile-bpfcc
- filename: cgconfig.conf
content: |
# Configuration for cgroups in VM compute nodes
@@ -152,6 +160,8 @@ merge: |
RUN set -e \
&& chmod 0644 /etc/cgconfig.conf
ENV PERF_BINARY_PATH=/neonvm/tools/bin/perf
COPY compute_rsyslog.conf /etc/compute_rsyslog.conf
RUN chmod 0666 /etc/compute_rsyslog.conf

View File

@@ -39,6 +39,14 @@ commands:
user: nobody
sysvInitAction: respawn
shell: '/bin/sql_exporter -config.file=/etc/sql_exporter_autoscaling.yml -web.listen-address=:9499'
- name: enable-kernel-modules
user: root
sysvInitAction: sysinit
shell: mkdir -p /lib/ && ln -s /neonvm/tools/lib/modules /lib/
- name: enable-bpfs
user: root
sysvInitAction: sysinit
shell: mkdir -p /sys/kernel/debug && mount -t debugfs debugfs /sys/kernel/debug && mount -t bpf bpf /sys/fs/bpf && chmod 755 /sys/fs/bpf
# Rsyslog by default creates a unix socket under /dev/log . That's where Postgres sends logs also.
# We run syslog with postgres user so it can't create /dev/log. Instead we configure rsyslog to
# use a different path for the socket. The symlink actually points to our custom path.
@@ -65,7 +73,7 @@ files:
# regardless of hostname (ALL)
#
# Also allow it to shut down the VM. The fast_import job does that when it's finished.
postgres ALL=(root) NOPASSWD: /neonvm/bin/resize-swap, /neonvm/bin/set-disk-quota, /neonvm/bin/poweroff, /usr/sbin/rsyslogd
postgres ALL=(root) NOPASSWD: /neonvm/bin/resize-swap, /neonvm/bin/set-disk-quota, /neonvm/bin/poweroff, /usr/sbin/rsyslogd, /neonvm/tools/bin/perf, /usr/sbin/profile-bpfcc
- filename: cgconfig.conf
content: |
# Configuration for cgroups in VM compute nodes
@@ -148,6 +156,8 @@ merge: |
RUN set -e \
&& chmod 0644 /etc/cgconfig.conf
ENV PERF_BINARY_PATH=/neonvm/tools/bin/perf
COPY compute_rsyslog.conf /etc/compute_rsyslog.conf
RUN chmod 0666 /etc/compute_rsyslog.conf
RUN mkdir /var/log/rsyslog && chown -R postgres /var/log/rsyslog

View File

@@ -31,6 +31,7 @@ hostname-validator = "1.1"
indexmap.workspace = true
itertools.workspace = true
jsonwebtoken.workspace = true
libproc.workspace = true
metrics.workspace = true
nix.workspace = true
notify.workspace = true
@@ -49,6 +50,7 @@ serde_with.workspace = true
serde_json.workspace = true
signal-hook.workspace = true
tar.workspace = true
tempfile.workspace = true
tower.workspace = true
tower-http.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
@@ -78,3 +80,10 @@ zstd = "0.13"
bytes = "1.0"
rust-ini = "0.20.0"
rlimit = "0.10.1"
inferno = { version = "0.12", default-features = false, features = [
"multithreaded",
"nameattr",
] }
pprof = { version = "0.15", features = ["protobuf-codec", "flamegraph"] }
prost = "0.12"

View File

@@ -87,14 +87,6 @@ struct Cli {
#[arg(short = 'C', long, value_name = "DATABASE_URL")]
pub connstr: String,
#[arg(
long,
default_value = "neon_superuser",
value_name = "PRIVILEGED_ROLE_NAME",
value_parser = Self::parse_privileged_role_name
)]
pub privileged_role_name: String,
#[cfg(target_os = "linux")]
#[arg(long, default_value = "neon-postgres")]
pub cgroup: String,
@@ -157,21 +149,6 @@ impl Cli {
Ok(url)
}
/// For simplicity, we do not escape `privileged_role_name` anywhere in the code.
/// Since it's a system role, which we fully control, that's fine. Still, let's
/// validate it to avoid any surprises.
fn parse_privileged_role_name(value: &str) -> Result<String> {
use regex::Regex;
let pattern = Regex::new(r"^[a-z_]+$").unwrap();
if !pattern.is_match(value) {
bail!("--privileged-role-name can only contain lowercase letters and underscores")
}
Ok(value.to_string())
}
}
fn main() -> Result<()> {
@@ -201,7 +178,6 @@ fn main() -> Result<()> {
ComputeNodeParams {
compute_id: cli.compute_id,
connstr,
privileged_role_name: cli.privileged_role_name.clone(),
pgdata: cli.pgdata.clone(),
pgbin: cli.pgbin.clone(),
pgversion: get_pg_version_string(&cli.pgbin),
@@ -351,49 +327,4 @@ mod test {
])
.expect_err("URL parameters are not allowed");
}
#[test]
fn verify_privileged_role_name() {
// Valid name
let cli = Cli::parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--privileged-role-name",
"my_superuser",
]);
assert_eq!(cli.privileged_role_name, "my_superuser");
// Invalid names
Cli::try_parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--privileged-role-name",
"NeonSuperuser",
])
.expect_err("uppercase letters are not allowed");
Cli::try_parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--privileged-role-name",
"$'neon_superuser",
])
.expect_err("special characters are not allowed");
Cli::try_parse_from([
"compute_ctl",
"--pgdata=test",
"--connstr=test",
"--compute-id=test",
"--privileged-role-name",
"",
])
.expect_err("empty name is not allowed");
}
}

View File

@@ -6,8 +6,7 @@ use compute_api::responses::{
LfcPrewarmState, PromoteState, TlsConfig,
};
use compute_api::spec::{
ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PageserverConnectionInfo,
PageserverProtocol, PageserverShardConnectionInfo, PageserverShardInfo, PgIdent,
ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, ExtVersion, PageserverProtocol, PgIdent,
};
use futures::StreamExt;
use futures::future::join_all;
@@ -75,20 +74,12 @@ const DEFAULT_INSTALLED_EXTENSIONS_COLLECTION_INTERVAL: u64 = 3600;
/// Static configuration params that don't change after startup. These mostly
/// come from the CLI args, or are derived from them.
#[derive(Clone, Debug)]
pub struct ComputeNodeParams {
/// The ID of the compute
pub compute_id: String,
/// Url type maintains proper escaping
// Url type maintains proper escaping
pub connstr: url::Url,
/// The name of the 'weak' superuser role, which we give to the users.
/// It follows the allow list approach, i.e., we take a standard role
/// and grant it extra permissions with explicit GRANTs here and there,
/// and core patches.
pub privileged_role_name: String,
pub resize_swap_on_bind: bool,
pub set_disk_quota_for_fs: Option<String>,
@@ -234,7 +225,7 @@ pub struct ParsedSpec {
pub spec: ComputeSpec,
pub tenant_id: TenantId,
pub timeline_id: TimelineId,
pub pageserver_conninfo: PageserverConnectionInfo,
pub pageserver_connstr: String,
pub safekeeper_connstrings: Vec<String>,
pub storage_auth_token: Option<String>,
/// k8s dns name and port
@@ -281,114 +272,26 @@ impl ParsedSpec {
}
}
/// Extract PageserverConnectionInfo from a comma-separated list of libpq connection strings.
///
/// This is used for backwards-compatilibity, to parse the legacye `pageserver_connstr`
/// field in the compute spec, or the 'neon.pageserver_connstring' GUC. Nowadays, the
/// 'pageserver_connection_info' field should be used instead.
fn extract_pageserver_conninfo_from_connstr(
connstr: &str,
stripe_size: Option<u32>,
) -> Result<PageserverConnectionInfo, anyhow::Error> {
let shard_infos: Vec<_> = connstr
.split(',')
.map(|connstr| PageserverShardInfo {
pageservers: vec![PageserverShardConnectionInfo {
id: None,
libpq_url: Some(connstr.to_string()),
grpc_url: None,
}],
})
.collect();
match shard_infos.len() {
0 => anyhow::bail!("empty connection string"),
1 => {
// We assume that if there's only connection string, it means "unsharded",
// rather than a sharded system with just a single shard. The latter is
// possible in principle, but we never do it.
let shard_count = ShardCount::unsharded();
let only_shard = shard_infos.first().unwrap().clone();
let shards = vec![(ShardIndex::unsharded(), only_shard)];
Ok(PageserverConnectionInfo {
shard_count,
stripe_size: None,
shards: shards.into_iter().collect(),
prefer_protocol: PageserverProtocol::Libpq,
})
}
n => {
if stripe_size.is_none() {
anyhow::bail!("{n} shards but no stripe_size");
}
let shard_count = ShardCount(n.try_into()?);
let shards = shard_infos
.into_iter()
.enumerate()
.map(|(idx, shard_info)| {
(
ShardIndex {
shard_count,
shard_number: ShardNumber(
idx.try_into().expect("shard number fits in u8"),
),
},
shard_info,
)
})
.collect();
Ok(PageserverConnectionInfo {
shard_count,
stripe_size,
shards,
prefer_protocol: PageserverProtocol::Libpq,
})
}
}
}
impl TryFrom<ComputeSpec> for ParsedSpec {
type Error = anyhow::Error;
fn try_from(spec: ComputeSpec) -> Result<Self, anyhow::Error> {
type Error = String;
fn try_from(spec: ComputeSpec) -> Result<Self, String> {
// Extract the options from the spec file that are needed to connect to
// the storage system.
//
// In compute specs generated by old control plane versions, the spec file might
// be missing the `pageserver_connection_info` field. In that case, we need to dig
// the pageserver connection info from the `pageserver_connstr` field instead, or
// if that's missing too, from the GUC in the cluster.settings field.
let mut pageserver_conninfo = spec.pageserver_connection_info.clone();
if pageserver_conninfo.is_none() {
if let Some(pageserver_connstr_field) = &spec.pageserver_connstring {
pageserver_conninfo = Some(extract_pageserver_conninfo_from_connstr(
pageserver_connstr_field,
spec.shard_stripe_size,
)?);
}
}
if pageserver_conninfo.is_none() {
if let Some(guc) = spec.cluster.settings.find("neon.pageserver_connstring") {
let stripe_size = if let Some(guc) = spec.cluster.settings.find("neon.stripe_size")
{
Some(u32::from_str(&guc)?)
} else {
None
};
pageserver_conninfo =
Some(extract_pageserver_conninfo_from_connstr(&guc, stripe_size)?);
}
}
let pageserver_conninfo = pageserver_conninfo.ok_or(anyhow::anyhow!(
"pageserver connection information should be provided"
))?;
// Similarly for safekeeper connection strings
// For backwards-compatibility, the top-level fields in the spec file
// may be empty. In that case, we need to dig them from the GUCs in the
// cluster.settings field.
let pageserver_connstr = spec
.pageserver_connstring
.clone()
.or_else(|| spec.cluster.settings.find("neon.pageserver_connstring"))
.ok_or("pageserver connstr should be provided")?;
let safekeeper_connstrings = if spec.safekeeper_connstrings.is_empty() {
if matches!(spec.mode, ComputeMode::Primary) {
spec.cluster
.settings
.find("neon.safekeepers")
.ok_or(anyhow::anyhow!("safekeeper connstrings should be provided"))?
.ok_or("safekeeper connstrings should be provided")?
.split(',')
.map(|str| str.to_string())
.collect()
@@ -403,22 +306,22 @@ impl TryFrom<ComputeSpec> for ParsedSpec {
let tenant_id: TenantId = if let Some(tenant_id) = spec.tenant_id {
tenant_id
} else {
let guc = spec
.cluster
spec.cluster
.settings
.find("neon.tenant_id")
.ok_or(anyhow::anyhow!("tenant id should be provided"))?;
TenantId::from_str(&guc).context("invalid tenant id")?
.ok_or("tenant id should be provided")
.map(|s| TenantId::from_str(&s))?
.or(Err("invalid tenant id"))?
};
let timeline_id: TimelineId = if let Some(timeline_id) = spec.timeline_id {
timeline_id
} else {
let guc = spec
.cluster
spec.cluster
.settings
.find("neon.timeline_id")
.ok_or(anyhow::anyhow!("timeline id should be provided"))?;
TimelineId::from_str(&guc).context(anyhow::anyhow!("invalid timeline id"))?
.ok_or("timeline id should be provided")
.map(|s| TimelineId::from_str(&s))?
.or(Err("invalid timeline id"))?
};
let endpoint_storage_addr: Option<String> = spec
@@ -432,7 +335,7 @@ impl TryFrom<ComputeSpec> for ParsedSpec {
let res = ParsedSpec {
spec,
pageserver_conninfo,
pageserver_connstr,
safekeeper_connstrings,
storage_auth_token,
tenant_id,
@@ -442,7 +345,7 @@ impl TryFrom<ComputeSpec> for ParsedSpec {
};
// Now check validity of the parsed specification
res.validate().map_err(anyhow::Error::msg)?;
res.validate()?;
Ok(res)
}
}
@@ -468,7 +371,9 @@ fn maybe_cgexec(cmd: &str) -> Command {
}
}
struct PostgresHandle {
/// A handle to the Postgres process that is running in the compute
/// node.
pub struct PostgresHandle {
postgres: std::process::Child,
log_collector: JoinHandle<Result<()>>,
}
@@ -1129,13 +1034,13 @@ impl ComputeNode {
fn try_get_basebackup(&self, compute_state: &ComputeState, lsn: Lsn) -> Result<()> {
let spec = compute_state.pspec.as_ref().expect("spec must be set");
let shard0_connstr = spec.pageserver_connstr.split(',').next().unwrap();
let started = Instant::now();
let (connected, size) = match spec.pageserver_conninfo.prefer_protocol {
PageserverProtocol::Grpc => self.try_get_basebackup_grpc(spec, lsn)?,
PageserverProtocol::Libpq => self.try_get_basebackup_libpq(spec, lsn)?,
};
self.fix_zenith_signal_neon_signal()?;
let (connected, size) = match PageserverProtocol::from_connstring(shard0_connstr)? {
PageserverProtocol::Libpq => self.try_get_basebackup_libpq(spec, lsn)?,
PageserverProtocol::Grpc => self.try_get_basebackup_grpc(spec, lsn)?,
};
let mut state = self.state.lock().unwrap();
state.metrics.pageserver_connect_micros =
@@ -1146,56 +1051,26 @@ impl ComputeNode {
Ok(())
}
/// Move the Zenith signal file to Neon signal file location.
/// This makes Compute compatible with older PageServers that don't yet
/// know about the Zenith->Neon rename.
fn fix_zenith_signal_neon_signal(&self) -> Result<()> {
let datadir = Path::new(&self.params.pgdata);
let neonsig = datadir.join("neon.signal");
if neonsig.is_file() {
return Ok(());
}
let zenithsig = datadir.join("zenith.signal");
if zenithsig.is_file() {
fs::copy(zenithsig, neonsig)?;
}
Ok(())
}
/// Fetches a basebackup via gRPC. The connstring must use grpc://. Returns the timestamp when
/// the connection was established, and the (compressed) size of the basebackup.
fn try_get_basebackup_grpc(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<(Instant, usize)> {
let shard0_index = ShardIndex {
shard_number: ShardNumber(0),
shard_count: spec.pageserver_conninfo.shard_count,
let shard0_connstr = spec
.pageserver_connstr
.split(',')
.next()
.unwrap()
.to_string();
let shard_index = match spec.pageserver_connstr.split(',').count() as u8 {
0 | 1 => ShardIndex::unsharded(),
count => ShardIndex::new(ShardNumber(0), ShardCount(count)),
};
let shard0 = spec
.pageserver_conninfo
.shards
.get(&shard0_index)
.ok_or_else(|| {
anyhow::anyhow!("shard connection info missing for shard {}", shard0_index)
})?;
let pageserver = shard0
.pageservers
.first()
.expect("must have at least one pageserver");
let shard0_url = pageserver
.grpc_url
.clone()
.expect("no grpc_url for shard 0");
let (reader, connected) = tokio::runtime::Handle::current().block_on(async move {
let mut client = page_api::Client::connect(
shard0_url,
shard0_connstr,
spec.tenant_id,
spec.timeline_id,
shard0_index,
shard_index,
spec.storage_auth_token.clone(),
None, // NB: base backups use payload compression
)
@@ -1227,26 +1102,8 @@ impl ComputeNode {
/// Fetches a basebackup via libpq. The connstring must use postgresql://. Returns the timestamp
/// when the connection was established, and the (compressed) size of the basebackup.
fn try_get_basebackup_libpq(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<(Instant, usize)> {
let shard0_index = ShardIndex {
shard_number: ShardNumber(0),
shard_count: spec.pageserver_conninfo.shard_count,
};
let shard0 = spec
.pageserver_conninfo
.shards
.get(&shard0_index)
.ok_or_else(|| {
anyhow::anyhow!("shard connection info missing for shard {}", shard0_index)
})?;
let pageserver = shard0
.pageservers
.first()
.expect("must have at least one pageserver");
let shard0_connstr = pageserver
.libpq_url
.clone()
.expect("no libpq_url for shard 0");
let mut config = postgres::Config::from_str(&shard0_connstr)?;
let shard0_connstr = spec.pageserver_connstr.split(',').next().unwrap();
let mut config = postgres::Config::from_str(shard0_connstr)?;
// Use the storage auth token from the config file, if given.
// Note: this overrides any password set in the connection string.
@@ -1332,7 +1189,10 @@ impl ComputeNode {
return result;
}
Err(ref e) if attempts < max_attempts => {
warn!("Failed to get basebackup: {e:?} (attempt {attempts}/{max_attempts})");
warn!(
"Failed to get basebackup: {} (attempt {}/{})",
e, attempts, max_attempts
);
std::thread::sleep(std::time::Duration::from_millis(retry_period_ms as u64));
retry_period_ms *= 1.5;
}
@@ -1405,7 +1265,9 @@ impl ComputeNode {
// In case of error, log and fail the check, but don't crash.
// We're playing it safe because these errors could be transient
// and we don't yet retry.
// and we don't yet retry. Also being careful here allows us to
// be backwards compatible with safekeepers that don't have the
// TIMELINE_STATUS API yet.
if responses.len() < quorum {
error!(
"failed sync safekeepers check {:?} {:?} {:?}",
@@ -1508,7 +1370,6 @@ impl ComputeNode {
self.create_pgdata()?;
config::write_postgres_conf(
pgdata_path,
&self.params,
&pspec.spec,
self.params.internal_http_port,
tls_config,
@@ -1540,8 +1401,16 @@ impl ComputeNode {
}
};
self.get_basebackup(compute_state, lsn)
.with_context(|| format!("failed to get basebackup@{lsn}"))?;
info!(
"getting basebackup@{} from pageserver {}",
lsn, &pspec.pageserver_connstr
);
self.get_basebackup(compute_state, lsn).with_context(|| {
format!(
"failed to get basebackup@{} from pageserver {}",
lsn, &pspec.pageserver_connstr
)
})?;
// Update pg_hba.conf received with basebackup.
update_pg_hba(pgdata_path)?;
@@ -1849,7 +1718,6 @@ impl ComputeNode {
}
// Run migrations separately to not hold up cold starts
let params = self.params.clone();
tokio::spawn(async move {
let mut conf = conf.as_ref().clone();
conf.application_name("compute_ctl:migrations");
@@ -1861,7 +1729,7 @@ impl ComputeNode {
eprintln!("connection error: {e}");
}
});
if let Err(e) = handle_migrations(params, &mut client).await {
if let Err(e) = handle_migrations(&mut client).await {
error!("Failed to run migrations: {}", e);
}
}
@@ -1940,7 +1808,6 @@ impl ComputeNode {
let pgdata_path = Path::new(&self.params.pgdata);
config::write_postgres_conf(
pgdata_path,
&self.params,
&spec,
self.params.internal_http_port,
tls_config,
@@ -2484,22 +2351,22 @@ LIMIT 100",
/// The operation will time out after a specified duration.
pub fn wait_timeout_while_pageserver_connstr_unchanged(&self, duration: Duration) {
let state = self.state.lock().unwrap();
let old_pageserver_conninfo = state
let old_pageserver_connstr = state
.pspec
.as_ref()
.expect("spec must be set")
.pageserver_conninfo
.pageserver_connstr
.clone();
let mut unchanged = true;
let _ = self
.state_changed
.wait_timeout_while(state, duration, |s| {
let pageserver_conninfo = &s
let pageserver_connstr = &s
.pspec
.as_ref()
.expect("spec must be set")
.pageserver_conninfo;
unchanged = pageserver_conninfo == &old_pageserver_conninfo;
.pageserver_connstr;
unchanged = pageserver_connstr == &old_pageserver_connstr;
unchanged
})
.unwrap();
@@ -2553,31 +2420,14 @@ LIMIT 100",
pub fn spawn_lfc_offload_task(self: &Arc<Self>, interval: Duration) {
self.terminate_lfc_offload_task();
let secs = interval.as_secs();
info!("spawning lfc offload worker with {secs}s interval");
let this = self.clone();
info!("spawning LFC offload worker with {secs}s interval");
let handle = spawn(async move {
let mut interval = time::interval(interval);
interval.tick().await; // returns immediately
loop {
interval.tick().await;
let prewarm_state = this.state.lock().unwrap().lfc_prewarm_state.clone();
// Do not offload LFC state if we are currently prewarming or any issue occurred.
// If we'd do that, we might override the LFC state in endpoint storage with some
// incomplete state. Imagine a situation:
// 1. Endpoint started with `autoprewarm: true`
// 2. While prewarming is not completed, we upload the new incomplete state
// 3. Compute gets interrupted and restarts
// 4. We start again and try to prewarm with the state from 2. instead of the previous complete state
if matches!(
prewarm_state,
LfcPrewarmState::Completed
| LfcPrewarmState::NotPrewarmed
| LfcPrewarmState::Skipped
) {
this.offload_lfc_async().await;
}
this.offload_lfc_async().await;
}
});
*self.lfc_offload_task.lock().unwrap() = Some(handle);
@@ -2616,7 +2466,7 @@ pub async fn installed_extensions(conf: tokio_postgres::Config) -> Result<()> {
serde_json::to_string(&extensions).expect("failed to serialize extensions list")
);
}
Err(err) => error!("could not get installed extensions: {err}"),
Err(err) => error!("could not get installed extensions: {err:?}"),
}
Ok(())
}
@@ -2729,10 +2579,7 @@ mod tests {
match ParsedSpec::try_from(spec.clone()) {
Ok(_p) => panic!("Failed to detect duplicate entry"),
Err(e) => assert!(
e.to_string()
.starts_with("duplicate entry in safekeeper_connstrings:")
),
Err(e) => assert!(e.starts_with("duplicate entry in safekeeper_connstrings:")),
};
}
}

View File

@@ -89,7 +89,7 @@ impl ComputeNode {
self.state.lock().unwrap().lfc_offload_state.clone()
}
/// If there is a prewarm request ongoing, return `false`, `true` otherwise.
/// If there is a prewarm request ongoing, return false, true otherwise
pub fn prewarm_lfc(self: &Arc<Self>, from_endpoint: Option<String>) -> bool {
{
let state = &mut self.state.lock().unwrap().lfc_prewarm_state;
@@ -101,25 +101,15 @@ impl ComputeNode {
let cloned = self.clone();
spawn(async move {
let state = match cloned.prewarm_impl(from_endpoint).await {
Ok(true) => LfcPrewarmState::Completed,
Ok(false) => {
info!(
"skipping LFC prewarm because LFC state is not found in endpoint storage"
);
LfcPrewarmState::Skipped
}
Err(err) => {
crate::metrics::LFC_PREWARM_ERRORS.inc();
error!(%err, "could not prewarm LFC");
LfcPrewarmState::Failed {
error: err.to_string(),
}
}
let Err(err) = cloned.prewarm_impl(from_endpoint).await else {
cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Completed;
return;
};
crate::metrics::LFC_PREWARM_ERRORS.inc();
error!(%err, "prewarming lfc");
cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Failed {
error: err.to_string(),
};
cloned.state.lock().unwrap().lfc_prewarm_state = state;
});
true
}
@@ -130,21 +120,15 @@ impl ComputeNode {
EndpointStoragePair::from_spec_and_endpoint(state.pspec.as_ref().unwrap(), from_endpoint)
}
/// Request LFC state from endpoint storage and load corresponding pages into Postgres.
/// Returns a result with `false` if the LFC state is not found in endpoint storage.
async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<bool> {
async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(from_endpoint)?;
info!(%url, "requesting LFC state from endpoint storage");
let request = Client::new().get(&url).bearer_auth(token);
let res = request.send().await.context("querying endpoint storage")?;
let status = res.status();
match status {
StatusCode::OK => (),
StatusCode::NOT_FOUND => {
return Ok(false);
}
_ => bail!("{status} querying endpoint storage"),
if status != StatusCode::OK {
bail!("{status} querying endpoint storage")
}
let mut uncompressed = Vec::new();
@@ -157,8 +141,7 @@ impl ComputeNode {
.await
.context("decoding LFC state")?;
let uncompressed_len = uncompressed.len();
info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}, loading into Postgres");
info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}, loading into postgres");
ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
.await
@@ -166,9 +149,7 @@ impl ComputeNode {
.query_one("select neon.prewarm_local_cache($1)", &[&uncompressed])
.await
.context("loading LFC state into postgres")
.map(|_| ())?;
Ok(true)
.map(|_| ())
}
/// If offload request is ongoing, return false, true otherwise
@@ -196,14 +177,12 @@ impl ComputeNode {
async fn offload_lfc_with_state_update(&self) {
crate::metrics::LFC_OFFLOADS.inc();
let Err(err) = self.offload_lfc_impl().await else {
self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Completed;
return;
};
crate::metrics::LFC_OFFLOAD_ERRORS.inc();
error!(%err, "could not offload LFC state to endpoint storage");
error!(%err, "offloading lfc");
self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Failed {
error: err.to_string(),
};
@@ -211,7 +190,7 @@ impl ComputeNode {
async fn offload_lfc_impl(&self) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?;
info!(%url, "requesting LFC state from Postgres");
info!(%url, "requesting LFC state from postgres");
let mut compressed = Vec::new();
ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
@@ -226,17 +205,13 @@ impl ComputeNode {
.read_to_end(&mut compressed)
.await
.context("compressing LFC state")?;
let compressed_len = compressed.len();
info!(%url, "downloaded LFC state, compressed size {compressed_len}, writing to endpoint storage");
let request = Client::new().put(url).bearer_auth(token).body(compressed);
match request.send().await {
Ok(res) if res.status() == StatusCode::OK => Ok(()),
Ok(res) => bail!(
"Request to endpoint storage failed with status: {}",
res.status()
),
Ok(res) => bail!("Error writing to endpoint storage: {}", res.status()),
Err(err) => Err(err).context("writing to endpoint storage"),
}
}

View File

@@ -9,14 +9,11 @@ use std::path::Path;
use compute_api::responses::TlsConfig;
use compute_api::spec::{ComputeAudit, ComputeMode, ComputeSpec, GenericOption};
use crate::compute::ComputeNodeParams;
use crate::pg_helpers::{
GenericOptionExt, GenericOptionsSearch, PgOptionsSerialize, escape_conf_value,
};
use crate::tls::{self, SERVER_CRT, SERVER_KEY};
use utils::shard::{ShardIndex, ShardNumber};
/// Check that `line` is inside a text file and put it there if it is not.
/// Create file if it doesn't exist.
pub fn line_in_file(path: &Path, line: &str) -> Result<bool> {
@@ -44,7 +41,6 @@ pub fn line_in_file(path: &Path, line: &str) -> Result<bool> {
/// Create or completely rewrite configuration file specified by `path`
pub fn write_postgres_conf(
pgdata_path: &Path,
params: &ComputeNodeParams,
spec: &ComputeSpec,
extension_server_port: u16,
tls_config: &Option<TlsConfig>,
@@ -60,99 +56,12 @@ pub fn write_postgres_conf(
// Add options for connecting to storage
writeln!(file, "# Neon storage settings")?;
writeln!(file)?;
if let Some(conninfo) = &spec.pageserver_connection_info {
// Stripe size GUC should be defined prior to connection string
if let Some(stripe_size) = conninfo.stripe_size {
writeln!(
file,
"# from compute spec's pageserver_conninfo.stripe_size field"
)?;
writeln!(file, "neon.stripe_size={stripe_size}")?;
}
let mut libpq_urls: Option<Vec<String>> = Some(Vec::new());
let mut grpc_urls: Option<Vec<String>> = Some(Vec::new());
let num_shards = if conninfo.shard_count.0 == 0 {
1 // unsharded, treat it as a single shard
} else {
conninfo.shard_count.0
};
for shard_number in 0..num_shards {
let shard_index = ShardIndex {
shard_number: ShardNumber(shard_number),
shard_count: conninfo.shard_count,
};
let info = conninfo.shards.get(&shard_index).ok_or_else(|| {
anyhow::anyhow!(
"shard {shard_index} missing from pageserver_connection_info shard map"
)
})?;
let first_pageserver = info
.pageservers
.first()
.expect("must have at least one pageserver");
// Add the libpq URL to the array, or if the URL is missing, reset the array
// forgetting any previous entries. All servers must have a libpq URL, or none
// at all.
if let Some(url) = &first_pageserver.libpq_url {
if let Some(ref mut urls) = libpq_urls {
urls.push(url.clone());
}
} else {
libpq_urls = None
}
// Similarly for gRPC URLs
if let Some(url) = &first_pageserver.grpc_url {
if let Some(ref mut urls) = grpc_urls {
urls.push(url.clone());
}
} else {
grpc_urls = None
}
}
if let Some(libpq_urls) = libpq_urls {
writeln!(
file,
"# derived from compute spec's pageserver_conninfo field"
)?;
writeln!(
file,
"neon.pageserver_connstring={}",
escape_conf_value(&libpq_urls.join(","))
)?;
} else {
writeln!(file, "# no neon.pageserver_connstring")?;
}
if let Some(grpc_urls) = grpc_urls {
writeln!(
file,
"# derived from compute spec's pageserver_conninfo field"
)?;
writeln!(
file,
"neon.pageserver_grpc_urls={}",
escape_conf_value(&grpc_urls.join(","))
)?;
} else {
writeln!(file, "# no neon.pageserver_grpc_urls")?;
}
} else {
// Stripe size GUC should be defined prior to connection string
if let Some(stripe_size) = spec.shard_stripe_size {
writeln!(file, "# from compute spec's shard_stripe_size field")?;
writeln!(file, "neon.stripe_size={stripe_size}")?;
}
if let Some(s) = &spec.pageserver_connstring {
writeln!(file, "# from compute spec's pageserver_connstring field")?;
writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(s))?;
}
if let Some(s) = &spec.pageserver_connstring {
writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(s))?;
}
if let Some(stripe_size) = spec.shard_stripe_size {
writeln!(file, "neon.stripe_size={stripe_size}")?;
}
if !spec.safekeeper_connstrings.is_empty() {
let mut neon_safekeepers_value = String::new();
tracing::info!(
@@ -252,12 +161,6 @@ pub fn write_postgres_conf(
}
}
writeln!(
file,
"neon.privileged_role_name={}",
escape_conf_value(params.privileged_role_name.as_str())
)?;
// If there are any extra options in the 'settings' field, append those
if spec.cluster.settings.is_some() {
writeln!(file, "# Managed by compute_ctl: begin")?;

View File

@@ -613,11 +613,11 @@ components:
- skipped
properties:
status:
description: LFC prewarm status
enum: [not_prewarmed, prewarming, completed, failed, skipped]
description: Lfc prewarm status
enum: [not_prewarmed, prewarming, completed, failed]
type: string
error:
description: LFC prewarm error, if any
description: Lfc prewarm error, if any
type: string
total:
description: Total pages processed
@@ -635,11 +635,11 @@ components:
- status
properties:
status:
description: LFC offload status
description: Lfc offload status
enum: [not_offloaded, offloading, completed, failed]
type: string
error:
description: LFC offload error, if any
description: Lfc offload error, if any
type: string
PromoteState:

View File

@@ -15,6 +15,7 @@ pub(in crate::http) mod lfc;
pub(in crate::http) mod metrics;
pub(in crate::http) mod metrics_json;
pub(in crate::http) mod promote;
pub(in crate::http) mod profile;
pub(in crate::http) mod status;
pub(in crate::http) mod terminate;

View File

@@ -0,0 +1,217 @@
//! Contains the route for profiling the compute.
//!
//! Profiling the compute means generating a pprof profile of the
//! postgres processes.
//!
//! The profiling is done using the `perf` tool, which is expected to be
//! available somewhere in `$PATH`.
use std::sync::atomic::Ordering;
use axum::Json;
use axum::response::IntoResponse;
use http::StatusCode;
use nix::unistd::Pid;
use once_cell::sync::Lazy;
use tokio::sync::Mutex;
use crate::http::JsonResponse;
static CANCEL_CHANNEL: Lazy<Mutex<Option<tokio::sync::broadcast::Sender<()>>>> =
Lazy::new(|| Mutex::new(None));
fn default_sampling_frequency() -> u16 {
100
}
fn default_timeout_seconds() -> u8 {
5
}
fn deserialize_sampling_frequency<'de, D>(deserializer: D) -> Result<u16, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::Deserialize;
const MIN_SAMPLING_FREQUENCY: u16 = 1;
const MAX_SAMPLING_FREQUENCY: u16 = 1000;
let value = u16::deserialize(deserializer)?;
if !(MIN_SAMPLING_FREQUENCY..=MAX_SAMPLING_FREQUENCY).contains(&value) {
return Err(serde::de::Error::custom(format!(
"sampling_frequency must be between {MIN_SAMPLING_FREQUENCY} and {MAX_SAMPLING_FREQUENCY}, got {value}"
)));
}
Ok(value)
}
fn deserialize_profiling_timeout<'de, D>(deserializer: D) -> Result<u8, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::Deserialize;
const MIN_TIMEOUT_SECONDS: u8 = 1;
const MAX_TIMEOUT_SECONDS: u8 = 60;
let value = u8::deserialize(deserializer)?;
if !(MIN_TIMEOUT_SECONDS..=MAX_TIMEOUT_SECONDS).contains(&value) {
return Err(serde::de::Error::custom(format!(
"timeout_seconds must be between {MIN_TIMEOUT_SECONDS} and {MAX_TIMEOUT_SECONDS}, got {value}"
)));
}
Ok(value)
}
/// Request parameters for profiling the compute.
#[derive(Debug, Clone, serde::Deserialize)]
pub(in crate::http) struct ProfileRequest {
/// The profiling tool to use, currently only `perf` is supported.
profiler: crate::profiling::ProfileGenerator,
#[serde(default = "default_sampling_frequency")]
#[serde(deserialize_with = "deserialize_sampling_frequency")]
sampling_frequency: u16,
#[serde(default = "default_timeout_seconds")]
#[serde(deserialize_with = "deserialize_profiling_timeout")]
timeout_seconds: u8,
#[serde(default)]
archive: bool,
}
/// The HTTP request handler for reporting the profiling status of
/// the compute.
pub(in crate::http) async fn profile_status() -> impl IntoResponse {
tracing::info!("Profile status request received.");
let cancel_channel = CANCEL_CHANNEL.lock().await;
if let Some(tx) = cancel_channel.as_ref() {
if tx.receiver_count() > 0 {
return JsonResponse::create_response(
StatusCode::OK,
"Profiling is currently in progress.",
);
}
}
JsonResponse::create_response(StatusCode::NO_CONTENT, "Profiling is not in progress.")
}
/// The HTTP request handler for stopping profiling the compute.
pub(in crate::http) async fn profile_stop() -> impl IntoResponse {
tracing::info!("Profile stop request received.");
match CANCEL_CHANNEL.lock().await.take() {
Some(tx) => {
if tx.send(()).is_err() {
tracing::error!("Failed to send cancellation signal.");
return JsonResponse::create_response(
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to send cancellation signal",
);
}
JsonResponse::create_response(StatusCode::OK, "Profiling stopped successfully.")
}
None => JsonResponse::create_response(
StatusCode::PRECONDITION_FAILED,
"Profiling is not in progress, there is nothing to stop.",
),
}
}
/// The HTTP request handler for starting profiling the compute.
pub(in crate::http) async fn profile_start(
Json(request): Json<ProfileRequest>,
) -> impl IntoResponse {
tracing::info!("Profile start request received: {request:?}");
let tx = tokio::sync::broadcast::Sender::<()>::new(1);
{
let mut cancel_channel = CANCEL_CHANNEL.lock().await;
if cancel_channel.is_some() {
return JsonResponse::create_response(
StatusCode::CONFLICT,
"Profiling is already in progress.",
);
}
*cancel_channel = Some(tx.clone());
}
tracing::info!("Profiling will start with parameters: {request:?}");
let pg_pid = Pid::from_raw(crate::compute::PG_PID.load(Ordering::SeqCst) as _);
let run_with_sudo = !cfg!(feature = "testing");
let options = crate::profiling::ProfileGenerationOptions {
profiler: request.profiler,
run_with_sudo,
pids: [pg_pid].into_iter().collect(),
follow_forks: true,
sampling_frequency: request.sampling_frequency as u32,
blocklist_symbols: vec![
"libc".to_owned(),
"libgcc".to_owned(),
"pthread".to_owned(),
"vdso".to_owned(),
],
archive: request.archive,
};
let options = crate::profiling::ProfileGenerationTaskOptions {
options,
timeout: std::time::Duration::from_secs(request.timeout_seconds as u64),
should_stop: Some(tx),
};
let pprof_data = crate::profiling::generate_pprof_profile(options).await;
if CANCEL_CHANNEL.lock().await.take().is_none() {
tracing::error!("Profiling was cancelled from another request.");
return JsonResponse::create_response(
StatusCode::NO_CONTENT,
"Profiling was cancelled from another request.",
);
}
let pprof_data = match pprof_data {
Ok(data) => data,
Err(e) => {
tracing::error!(error = ?e, "failed to generate pprof data");
return JsonResponse::create_response(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to generate pprof data: {e:?}"),
);
}
};
tracing::info!("Profiling has completed successfully.");
let mut headers = http::HeaderMap::new();
if request.archive {
headers.insert(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static("application/gzip"),
);
headers.insert(
http::header::CONTENT_DISPOSITION,
http::HeaderValue::from_static("attachment; filename=\"profile.pb.gz\""),
);
} else {
headers.insert(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static("application/octet-stream"),
);
headers.insert(
http::header::CONTENT_DISPOSITION,
http::HeaderValue::from_static("attachment; filename=\"profile.pb\""),
);
}
(headers, pprof_data.0).into_response()
}

View File

@@ -27,6 +27,7 @@ use super::{
},
};
use crate::compute::ComputeNode;
use crate::http::routes::profile;
/// `compute_ctl` has two servers: internal and external. The internal server
/// binds to the loopback interface and handles communication from clients on
@@ -81,8 +82,14 @@ impl From<&Server> for Router<Arc<ComputeNode>> {
Server::External {
config, compute_id, ..
} => {
let unauthenticated_router =
Router::<Arc<ComputeNode>>::new().route("/metrics", get(metrics::get_metrics));
let unauthenticated_router = Router::<Arc<ComputeNode>>::new()
.route("/metrics", get(metrics::get_metrics))
.route(
"/profile/cpu",
get(profile::profile_status)
.post(profile::profile_start)
.delete(profile::profile_stop),
);
let authenticated_router = Router::<Arc<ComputeNode>>::new()
.route("/lfc/prewarm", get(lfc::prewarm_state).post(lfc::prewarm))

View File

@@ -2,7 +2,6 @@ use std::collections::HashMap;
use anyhow::Result;
use compute_api::responses::{InstalledExtension, InstalledExtensions};
use tokio_postgres::error::Error as PostgresError;
use tokio_postgres::{Client, Config, NoTls};
use crate::metrics::INSTALLED_EXTENSIONS;
@@ -11,7 +10,7 @@ use crate::metrics::INSTALLED_EXTENSIONS;
/// and to make database listing query here more explicit.
///
/// Limit the number of databases to 500 to avoid excessive load.
async fn list_dbs(client: &mut Client) -> Result<Vec<String>, PostgresError> {
async fn list_dbs(client: &mut Client) -> Result<Vec<String>> {
// `pg_database.datconnlimit = -2` means that the database is in the
// invalid state
let databases = client
@@ -38,9 +37,7 @@ async fn list_dbs(client: &mut Client) -> Result<Vec<String>, PostgresError> {
/// Same extension can be installed in multiple databases with different versions,
/// so we report a separate metric (number of databases where it is installed)
/// for each extension version.
pub async fn get_installed_extensions(
mut conf: Config,
) -> Result<InstalledExtensions, PostgresError> {
pub async fn get_installed_extensions(mut conf: Config) -> Result<InstalledExtensions> {
conf.application_name("compute_ctl:get_installed_extensions");
let databases: Vec<String> = {
let (mut client, connection) = conf.connect(NoTls).await?;

View File

@@ -24,6 +24,7 @@ pub mod monitor;
pub mod params;
pub mod pg_helpers;
pub mod pgbouncer;
pub mod profiling;
pub mod rsyslog;
pub mod spec;
mod spec_apply;

View File

@@ -4,13 +4,14 @@ use std::thread;
use std::time::{Duration, SystemTime};
use anyhow::{Result, bail};
use compute_api::spec::{ComputeMode, PageserverConnectionInfo, PageserverProtocol};
use compute_api::spec::{ComputeMode, PageserverProtocol};
use itertools::Itertools as _;
use pageserver_page_api as page_api;
use postgres::{NoTls, SimpleQueryMessage};
use tracing::{info, warn};
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
use utils::shard::TenantShardId;
use utils::shard::{ShardCount, ShardNumber, TenantShardId};
use crate::compute::ComputeNode;
@@ -77,16 +78,17 @@ fn acquire_lsn_lease_with_retry(
loop {
// Note: List of pageservers is dynamic, need to re-read configs before each attempt.
let (conninfo, auth) = {
let (connstrings, auth) = {
let state = compute.state.lock().unwrap();
let spec = state.pspec.as_ref().expect("spec must be set");
(
spec.pageserver_conninfo.clone(),
spec.pageserver_connstr.clone(),
spec.storage_auth_token.clone(),
)
};
let result = try_acquire_lsn_lease(conninfo, auth.as_deref(), tenant_id, timeline_id, lsn);
let result =
try_acquire_lsn_lease(&connstrings, auth.as_deref(), tenant_id, timeline_id, lsn);
match result {
Ok(Some(res)) => {
return Ok(res);
@@ -110,44 +112,35 @@ fn acquire_lsn_lease_with_retry(
/// Tries to acquire LSN leases on all Pageserver shards.
fn try_acquire_lsn_lease(
conninfo: PageserverConnectionInfo,
connstrings: &str,
auth: Option<&str>,
tenant_id: TenantId,
timeline_id: TimelineId,
lsn: Lsn,
) -> Result<Option<SystemTime>> {
let connstrings = connstrings.split(',').collect_vec();
let shard_count = connstrings.len();
let mut leases = Vec::new();
for (shard_index, shard) in conninfo.shards.into_iter() {
let tenant_shard_id = TenantShardId {
tenant_id,
shard_number: shard_index.shard_number,
shard_count: shard_index.shard_count,
for (shard_number, &connstring) in connstrings.iter().enumerate() {
let tenant_shard_id = match shard_count {
0 | 1 => TenantShardId::unsharded(tenant_id),
shard_count => TenantShardId {
tenant_id,
shard_number: ShardNumber(shard_number as u8),
shard_count: ShardCount::new(shard_count as u8),
},
};
// XXX: If there are more than pageserver for the one shard, do we need to get a
// leas on all of them? Currently, that's what we assume, but this is hypothetical
// as of this writing, as we never pass the info for more than one pageserver per
// shard.
for pageserver in shard.pageservers {
let lease = match conninfo.prefer_protocol {
PageserverProtocol::Grpc => acquire_lsn_lease_grpc(
&pageserver.grpc_url.unwrap(),
auth,
tenant_shard_id,
timeline_id,
lsn,
)?,
PageserverProtocol::Libpq => acquire_lsn_lease_libpq(
&pageserver.libpq_url.unwrap(),
auth,
tenant_shard_id,
timeline_id,
lsn,
)?,
};
leases.push(lease);
}
let lease = match PageserverProtocol::from_connstring(connstring)? {
PageserverProtocol::Libpq => {
acquire_lsn_lease_libpq(connstring, auth, tenant_shard_id, timeline_id, lsn)?
}
PageserverProtocol::Grpc => {
acquire_lsn_lease_grpc(connstring, auth, tenant_shard_id, timeline_id, lsn)?
}
};
leases.push(lease);
}
Ok(leases.into_iter().min().flatten())

View File

@@ -1 +0,0 @@
ALTER ROLE {privileged_role_name} BYPASSRLS;

View File

@@ -0,0 +1 @@
ALTER ROLE neon_superuser BYPASSRLS;

View File

@@ -15,7 +15,7 @@ DO $$
DECLARE
role_name text;
BEGIN
FOR role_name IN SELECT rolname FROM pg_roles WHERE pg_has_role(rolname, '{privileged_role_name}', 'member')
FOR role_name IN SELECT rolname FROM pg_roles WHERE pg_has_role(rolname, 'neon_superuser', 'member')
LOOP
RAISE NOTICE 'EXECUTING ALTER ROLE % INHERIT', quote_ident(role_name);
EXECUTE 'ALTER ROLE ' || quote_ident(role_name) || ' INHERIT';
@@ -23,7 +23,7 @@ BEGIN
FOR role_name IN SELECT rolname FROM pg_roles
WHERE
NOT pg_has_role(rolname, '{privileged_role_name}', 'member') AND NOT starts_with(rolname, 'pg_')
NOT pg_has_role(rolname, 'neon_superuser', 'member') AND NOT starts_with(rolname, 'pg_')
LOOP
RAISE NOTICE 'EXECUTING ALTER ROLE % NOBYPASSRLS', quote_ident(role_name);
EXECUTE 'ALTER ROLE ' || quote_ident(role_name) || ' NOBYPASSRLS';

View File

@@ -1,6 +1,6 @@
DO $$
BEGIN
IF (SELECT setting::numeric >= 160000 FROM pg_settings WHERE name = 'server_version_num') THEN
EXECUTE 'GRANT pg_create_subscription TO {privileged_role_name}';
EXECUTE 'GRANT pg_create_subscription TO neon_superuser';
END IF;
END $$;

View File

@@ -0,0 +1 @@
GRANT pg_monitor TO neon_superuser WITH ADMIN OPTION;

View File

@@ -1 +0,0 @@
GRANT pg_monitor TO {privileged_role_name} WITH ADMIN OPTION;

View File

@@ -1,4 +1,4 @@
-- SKIP: Deemed insufficient for allowing relations created by extensions to be
-- interacted with by {privileged_role_name} without permission issues.
-- interacted with by neon_superuser without permission issues.
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO {privileged_role_name};
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO neon_superuser;

View File

@@ -1,4 +1,4 @@
-- SKIP: Deemed insufficient for allowing relations created by extensions to be
-- interacted with by {privileged_role_name} without permission issues.
-- interacted with by neon_superuser without permission issues.
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO {privileged_role_name};
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO neon_superuser;

View File

@@ -1,3 +1,3 @@
-- SKIP: Moved inline to the handle_grants() functions.
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO {privileged_role_name} WITH GRANT OPTION;
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON TABLES TO neon_superuser WITH GRANT OPTION;

View File

@@ -1,3 +1,3 @@
-- SKIP: Moved inline to the handle_grants() functions.
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO {privileged_role_name} WITH GRANT OPTION;
ALTER DEFAULT PRIVILEGES IN SCHEMA public GRANT ALL ON SEQUENCES TO neon_superuser WITH GRANT OPTION;

View File

@@ -1,7 +1,7 @@
DO $$
BEGIN
IF (SELECT setting::numeric >= 160000 FROM pg_settings WHERE name = 'server_version_num') THEN
EXECUTE 'GRANT EXECUTE ON FUNCTION pg_export_snapshot TO {privileged_role_name}';
EXECUTE 'GRANT EXECUTE ON FUNCTION pg_log_standby_snapshot TO {privileged_role_name}';
EXECUTE 'GRANT EXECUTE ON FUNCTION pg_export_snapshot TO neon_superuser';
EXECUTE 'GRANT EXECUTE ON FUNCTION pg_log_standby_snapshot TO neon_superuser';
END IF;
END $$;

View File

@@ -0,0 +1 @@
GRANT EXECUTE ON FUNCTION pg_show_replication_origin_status TO neon_superuser;

View File

@@ -1 +0,0 @@
GRANT EXECUTE ON FUNCTION pg_show_replication_origin_status TO {privileged_role_name};

View File

@@ -0,0 +1 @@
GRANT pg_signal_backend TO neon_superuser WITH ADMIN OPTION;

View File

@@ -1 +0,0 @@
GRANT pg_signal_backend TO {privileged_role_name} WITH ADMIN OPTION;

File diff suppressed because it is too large Load Diff

View File

@@ -9,7 +9,6 @@ use reqwest::StatusCode;
use tokio_postgres::Client;
use tracing::{error, info, instrument};
use crate::compute::ComputeNodeParams;
use crate::config;
use crate::metrics::{CPLANE_REQUESTS_TOTAL, CPlaneRequestRPC, UNKNOWN_HTTP_STATUS};
use crate::migration::MigrationRunner;
@@ -170,7 +169,7 @@ pub async fn handle_neon_extension_upgrade(client: &mut Client) -> Result<()> {
}
#[instrument(skip_all)]
pub async fn handle_migrations(params: ComputeNodeParams, client: &mut Client) -> Result<()> {
pub async fn handle_migrations(client: &mut Client) -> Result<()> {
info!("handle migrations");
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
@@ -179,59 +178,26 @@ pub async fn handle_migrations(params: ComputeNodeParams, client: &mut Client) -
// Add new migrations in numerical order.
let migrations = [
&format!(
include_str!("./migrations/0001-add_bypass_rls_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
include_str!("./migrations/0001-neon_superuser_bypass_rls.sql"),
include_str!("./migrations/0002-alter_roles.sql"),
include_str!("./migrations/0003-grant_pg_create_subscription_to_neon_superuser.sql"),
include_str!("./migrations/0004-grant_pg_monitor_to_neon_superuser.sql"),
include_str!("./migrations/0005-grant_all_on_tables_to_neon_superuser.sql"),
include_str!("./migrations/0006-grant_all_on_sequences_to_neon_superuser.sql"),
include_str!(
"./migrations/0007-grant_all_on_tables_to_neon_superuser_with_grant_option.sql"
),
&format!(
include_str!("./migrations/0002-alter_roles.sql"),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!("./migrations/0003-grant_pg_create_subscription_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!("./migrations/0004-grant_pg_monitor_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!("./migrations/0005-grant_all_on_tables_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!("./migrations/0006-grant_all_on_sequences_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!(
"./migrations/0007-grant_all_on_tables_with_grant_option_to_privileged_role.sql"
),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!(
"./migrations/0008-grant_all_on_sequences_with_grant_option_to_privileged_role.sql"
),
privileged_role_name = params.privileged_role_name
include_str!(
"./migrations/0008-grant_all_on_sequences_to_neon_superuser_with_grant_option.sql"
),
include_str!("./migrations/0009-revoke_replication_for_previously_allowed_roles.sql"),
&format!(
include_str!(
"./migrations/0010-grant_snapshot_synchronization_funcs_to_privileged_role.sql"
),
privileged_role_name = params.privileged_role_name
include_str!(
"./migrations/0010-grant_snapshot_synchronization_funcs_to_neon_superuser.sql"
),
&format!(
include_str!(
"./migrations/0011-grant_pg_show_replication_origin_status_to_privileged_role.sql"
),
privileged_role_name = params.privileged_role_name
),
&format!(
include_str!("./migrations/0012-grant_pg_signal_backend_to_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
include_str!(
"./migrations/0011-grant_pg_show_replication_origin_status_to_neon_superuser.sql"
),
include_str!("./migrations/0012-grant_pg_signal_backend_to_neon_superuser.sql"),
];
MigrationRunner::new(client, &migrations)

View File

@@ -13,14 +13,14 @@ use tokio_postgres::Client;
use tokio_postgres::error::SqlState;
use tracing::{Instrument, debug, error, info, info_span, instrument, warn};
use crate::compute::{ComputeNode, ComputeNodeParams, ComputeState};
use crate::compute::{ComputeNode, ComputeState};
use crate::pg_helpers::{
DatabaseExt, Escaping, GenericOptionsSearch, RoleExt, get_existing_dbs_async,
get_existing_roles_async,
};
use crate::spec_apply::ApplySpecPhase::{
CreateAndAlterDatabases, CreateAndAlterRoles, CreateAvailabilityCheck, CreatePgauditExtension,
CreatePgauditlogtofileExtension, CreatePrivilegedRole, CreateSchemaNeon,
CreateAndAlterDatabases, CreateAndAlterRoles, CreateAvailabilityCheck, CreateNeonSuperuser,
CreatePgauditExtension, CreatePgauditlogtofileExtension, CreateSchemaNeon,
DisablePostgresDBPgAudit, DropInvalidDatabases, DropRoles, FinalizeDropLogicalSubscriptions,
HandleNeonExtension, HandleOtherExtensions, RenameAndDeleteDatabases, RenameRoles,
RunInEachDatabase,
@@ -49,7 +49,6 @@ impl ComputeNode {
// Proceed with post-startup configuration. Note, that order of operations is important.
let client = Self::get_maintenance_client(&conf).await?;
let spec = spec.clone();
let params = Arc::new(self.params.clone());
let databases = get_existing_dbs_async(&client).await?;
let roles = get_existing_roles_async(&client)
@@ -158,7 +157,6 @@ impl ComputeNode {
let conf = Arc::new(conf);
let fut = Self::apply_spec_sql_db(
params.clone(),
spec.clone(),
conf,
ctx.clone(),
@@ -187,7 +185,7 @@ impl ComputeNode {
}
for phase in [
CreatePrivilegedRole,
CreateNeonSuperuser,
DropInvalidDatabases,
RenameRoles,
CreateAndAlterRoles,
@@ -197,7 +195,6 @@ impl ComputeNode {
] {
info!("Applying phase {:?}", &phase);
apply_operations(
params.clone(),
spec.clone(),
ctx.clone(),
jwks_roles.clone(),
@@ -246,7 +243,6 @@ impl ComputeNode {
}
let fut = Self::apply_spec_sql_db(
params.clone(),
spec.clone(),
conf,
ctx.clone(),
@@ -297,7 +293,6 @@ impl ComputeNode {
for phase in phases {
debug!("Applying phase {:?}", &phase);
apply_operations(
params.clone(),
spec.clone(),
ctx.clone(),
jwks_roles.clone(),
@@ -318,9 +313,7 @@ impl ComputeNode {
/// May opt to not connect to databases that don't have any scheduled
/// operations. The function is concurrency-controlled with the provided
/// semaphore. The caller has to make sure the semaphore isn't exhausted.
#[allow(clippy::too_many_arguments)] // TODO: needs bigger refactoring
async fn apply_spec_sql_db(
params: Arc<ComputeNodeParams>,
spec: Arc<ComputeSpec>,
conf: Arc<tokio_postgres::Config>,
ctx: Arc<tokio::sync::RwLock<MutableApplyContext>>,
@@ -335,7 +328,6 @@ impl ComputeNode {
for subphase in subphases {
apply_operations(
params.clone(),
spec.clone(),
ctx.clone(),
jwks_roles.clone(),
@@ -475,7 +467,7 @@ pub enum PerDatabasePhase {
#[derive(Clone, Debug)]
pub enum ApplySpecPhase {
CreatePrivilegedRole,
CreateNeonSuperuser,
DropInvalidDatabases,
RenameRoles,
CreateAndAlterRoles,
@@ -518,7 +510,6 @@ pub struct MutableApplyContext {
/// - No timeouts have (yet) been implemented.
/// - The caller is responsible for limiting and/or applying concurrency.
pub async fn apply_operations<'a, Fut, F>(
params: Arc<ComputeNodeParams>,
spec: Arc<ComputeSpec>,
ctx: Arc<RwLock<MutableApplyContext>>,
jwks_roles: Arc<HashSet<String>>,
@@ -536,7 +527,7 @@ where
debug!("Processing phase {:?}", &apply_spec_phase);
let ctx = ctx;
let mut ops = get_operations(&params, &spec, &ctx, &jwks_roles, &apply_spec_phase)
let mut ops = get_operations(&spec, &ctx, &jwks_roles, &apply_spec_phase)
.await?
.peekable();
@@ -597,18 +588,14 @@ where
/// sort/merge/batch execution, but for now this is a nice way to improve
/// batching behavior of the commands.
async fn get_operations<'a>(
params: &'a ComputeNodeParams,
spec: &'a ComputeSpec,
ctx: &'a RwLock<MutableApplyContext>,
jwks_roles: &'a HashSet<String>,
apply_spec_phase: &'a ApplySpecPhase,
) -> Result<Box<dyn Iterator<Item = Operation> + 'a + Send>> {
match apply_spec_phase {
ApplySpecPhase::CreatePrivilegedRole => Ok(Box::new(once(Operation {
query: format!(
include_str!("sql/create_privileged_role.sql"),
privileged_role_name = params.privileged_role_name
),
ApplySpecPhase::CreateNeonSuperuser => Ok(Box::new(once(Operation {
query: include_str!("sql/create_neon_superuser.sql").to_string(),
comment: None,
}))),
ApplySpecPhase::DropInvalidDatabases => {
@@ -710,9 +697,8 @@ async fn get_operations<'a>(
None => {
let query = if !jwks_roles.contains(role.name.as_str()) {
format!(
"CREATE ROLE {} INHERIT CREATEROLE CREATEDB BYPASSRLS REPLICATION IN ROLE {} {}",
"CREATE ROLE {} INHERIT CREATEROLE CREATEDB BYPASSRLS REPLICATION IN ROLE neon_superuser {}",
role.name.pg_quote(),
params.privileged_role_name,
role.to_pg_options(),
)
} else {
@@ -863,9 +849,8 @@ async fn get_operations<'a>(
// ALL PRIVILEGES grants CREATE, CONNECT, and TEMPORARY on the database
// (see https://www.postgresql.org/docs/current/ddl-priv.html)
query: format!(
"GRANT ALL PRIVILEGES ON DATABASE {} TO {}",
db.name.pg_quote(),
params.privileged_role_name
"GRANT ALL PRIVILEGES ON DATABASE {} TO neon_superuser",
db.name.pg_quote()
),
comment: None,
},

View File

@@ -0,0 +1,8 @@
DO $$
BEGIN
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = 'neon_superuser')
THEN
CREATE ROLE neon_superuser CREATEDB CREATEROLE NOLOGIN REPLICATION BYPASSRLS IN ROLE pg_read_all_data, pg_write_all_data;
END IF;
END
$$;

View File

@@ -1,8 +0,0 @@
DO $$
BEGIN
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{privileged_role_name}')
THEN
CREATE ROLE {privileged_role_name} CREATEDB CREATEROLE NOLOGIN REPLICATION BYPASSRLS IN ROLE pg_read_all_data, pg_write_all_data;
END IF;
END
$$;

View File

@@ -8,10 +8,10 @@ code changes locally, but not suitable for running production systems.
## Example: Start with Postgres 16
To create and start a local development environment with Postgres 16, you will need to provide `--pg-version` flag to 2 of the start-up commands.
To create and start a local development environment with Postgres 16, you will need to provide `--pg-version` flag to 3 of the start-up commands.
```shell
cargo neon init
cargo neon init --pg-version 16
cargo neon start
cargo neon tenant create --set-default --pg-version 16
cargo neon endpoint create main --pg-version 16

View File

@@ -16,14 +16,9 @@ use std::time::Duration;
use anyhow::{Context, Result, anyhow, bail};
use clap::Parser;
use compute_api::requests::ComputeClaimsScope;
use compute_api::spec::{
ComputeMode, PageserverConnectionInfo, PageserverProtocol, PageserverShardInfo,
};
use compute_api::spec::{ComputeMode, PageserverProtocol};
use control_plane::broker::StorageBroker;
use control_plane::endpoint::{ComputeControlPlane, EndpointTerminateMode};
use control_plane::endpoint::{
pageserver_conf_to_shard_conn_info, tenant_locate_response_to_conn_info,
};
use control_plane::endpoint_storage::{ENDPOINT_STORAGE_DEFAULT_ADDR, EndpointStorage};
use control_plane::local_env;
use control_plane::local_env::{
@@ -49,6 +44,7 @@ use pageserver_api::models::{
};
use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardCount, ShardStripeSize, TenantShardId};
use postgres_backend::AuthType;
use postgres_connection::parse_host_port;
use safekeeper_api::membership::{SafekeeperGeneration, SafekeeperId};
use safekeeper_api::{
DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT,
@@ -56,11 +52,11 @@ use safekeeper_api::{
};
use storage_broker::DEFAULT_LISTEN_ADDR as DEFAULT_BROKER_ADDR;
use tokio::task::JoinSet;
use url::Host;
use utils::auth::{Claims, Scope};
use utils::id::{NodeId, TenantId, TenantTimelineId, TimelineId};
use utils::lsn::Lsn;
use utils::project_git_version;
use utils::shard::ShardIndex;
// Default id of a safekeeper node, if not specified on the command line.
const DEFAULT_SAFEKEEPER_ID: NodeId = NodeId(1);
@@ -635,10 +631,6 @@ struct EndpointCreateCmdArgs {
help = "Allow multiple primary endpoints running on the same branch. Shouldn't be used normally, but useful for tests."
)]
allow_multiple: bool,
/// Only allow changing it on creation
#[clap(long, help = "Name of the privileged role for the endpoint")]
privileged_role_name: Option<String>,
}
#[derive(clap::Args)]
@@ -1488,7 +1480,6 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
args.grpc,
!args.update_catalog,
false,
args.privileged_role_name.clone(),
)?;
}
EndpointCmd::Start(args) => {
@@ -1525,56 +1516,62 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
)?;
}
let prefer_protocol = if endpoint.grpc {
PageserverProtocol::Grpc
} else {
PageserverProtocol::Libpq
};
let mut pageserver_conninfo = if let Some(ps_id) = pageserver_id {
let conf = env.get_pageserver_conf(ps_id).unwrap();
let ps_conninfo = pageserver_conf_to_shard_conn_info(conf)?;
let shard_info = PageserverShardInfo {
pageservers: vec![ps_conninfo],
let (pageservers, stripe_size) = if let Some(pageserver_id) = pageserver_id {
let conf = env.get_pageserver_conf(pageserver_id).unwrap();
// Use gRPC if requested.
let pageserver = if endpoint.grpc {
let grpc_addr = conf.listen_grpc_addr.as_ref().expect("bad config");
let (host, port) = parse_host_port(grpc_addr)?;
let port = port.unwrap_or(DEFAULT_PAGESERVER_GRPC_PORT);
(PageserverProtocol::Grpc, host, port)
} else {
let (host, port) = parse_host_port(&conf.listen_pg_addr)?;
let port = port.unwrap_or(5432);
(PageserverProtocol::Libpq, host, port)
};
// If caller is telling us what pageserver to use, this is not a tenant which is
// fully managed by storage controller, therefore not sharded.
let shards: HashMap<_, _> = vec![(ShardIndex::unsharded(), shard_info)]
.into_iter()
.collect();
PageserverConnectionInfo {
shard_count: ShardCount(0),
stripe_size: None,
shards,
prefer_protocol,
}
(vec![pageserver], DEFAULT_STRIPE_SIZE)
} else {
// Look up the currently attached location of the tenant, and its striping metadata,
// to pass these on to postgres.
let storage_controller = StorageController::from_env(env);
let locate_result = storage_controller.tenant_locate(endpoint.tenant_id).await?;
assert!(!locate_result.shards.is_empty());
// Initialize LSN leases for static computes.
if let ComputeMode::Static(lsn) = endpoint.mode {
futures::future::try_join_all(locate_result.shards.iter().map(
|shard| async move {
let pageservers = futures::future::try_join_all(
locate_result.shards.into_iter().map(|shard| async move {
if let ComputeMode::Static(lsn) = endpoint.mode {
// Initialize LSN leases for static computes.
let conf = env.get_pageserver_conf(shard.node_id).unwrap();
let pageserver = PageServerNode::from_env(env, conf);
pageserver
.http_client
.timeline_init_lsn_lease(shard.shard_id, endpoint.timeline_id, lsn)
.await
},
))
.await?;
}
.await?;
}
tenant_locate_response_to_conn_info(&locate_result)?
let pageserver = if endpoint.grpc {
(
PageserverProtocol::Grpc,
Host::parse(&shard.listen_grpc_addr.expect("no gRPC address"))?,
shard.listen_grpc_port.expect("no gRPC port"),
)
} else {
(
PageserverProtocol::Libpq,
Host::parse(&shard.listen_pg_addr)?,
shard.listen_pg_port,
)
};
anyhow::Ok(pageserver)
}),
)
.await?;
let stripe_size = locate_result.shard_params.stripe_size;
(pageservers, stripe_size)
};
pageserver_conninfo.prefer_protocol = prefer_protocol;
assert!(!pageservers.is_empty());
let ps_conf = env.get_pageserver_conf(DEFAULT_PAGESERVER_ID)?;
let auth_token = if matches!(ps_conf.pg_auth_type, AuthType::NeonJWT) {
@@ -1604,8 +1601,9 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
endpoint_storage_addr,
safekeepers_generation,
safekeepers,
pageserver_conninfo,
pageservers,
remote_ext_base_url: remote_ext_base_url.clone(),
shard_stripe_size: stripe_size.0 as usize,
create_test_user: args.create_test_user,
start_timeout: args.start_timeout,
autoprewarm: args.autoprewarm,
@@ -1622,45 +1620,51 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
.endpoints
.get(endpoint_id.as_str())
.with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?;
let prefer_protocol = if endpoint.grpc {
PageserverProtocol::Grpc
} else {
PageserverProtocol::Libpq
};
let mut pageserver_conninfo = if let Some(ps_id) = args.endpoint_pageserver_id {
let pageservers = if let Some(ps_id) = args.endpoint_pageserver_id {
let conf = env.get_pageserver_conf(ps_id)?;
let ps_conninfo = pageserver_conf_to_shard_conn_info(conf)?;
let shard_info = PageserverShardInfo {
pageservers: vec![ps_conninfo],
// Use gRPC if requested.
let pageserver = if endpoint.grpc {
let grpc_addr = conf.listen_grpc_addr.as_ref().expect("bad config");
let (host, port) = parse_host_port(grpc_addr)?;
let port = port.unwrap_or(DEFAULT_PAGESERVER_GRPC_PORT);
(PageserverProtocol::Grpc, host, port)
} else {
let (host, port) = parse_host_port(&conf.listen_pg_addr)?;
let port = port.unwrap_or(5432);
(PageserverProtocol::Libpq, host, port)
};
// If caller is telling us what pageserver to use, this is not a tenant which is
// fully managed by storage controller, therefore not sharded.
let shards: HashMap<_, _> = vec![(ShardIndex::unsharded(), shard_info)]
.into_iter()
.collect();
PageserverConnectionInfo {
shard_count: ShardCount::unsharded(),
stripe_size: None,
shards,
prefer_protocol,
}
vec![pageserver]
} else {
// Look up the currently attached location of the tenant, and its striping metadata,
// to pass these on to postgres.
let storage_controller = StorageController::from_env(env);
let locate_result = storage_controller.tenant_locate(endpoint.tenant_id).await?;
tenant_locate_response_to_conn_info(&locate_result)?
storage_controller
.tenant_locate(endpoint.tenant_id)
.await?
.shards
.into_iter()
.map(|shard| {
// Use gRPC if requested.
if endpoint.grpc {
(
PageserverProtocol::Grpc,
Host::parse(&shard.listen_grpc_addr.expect("no gRPC address"))
.expect("bad hostname"),
shard.listen_grpc_port.expect("no gRPC port"),
)
} else {
(
PageserverProtocol::Libpq,
Host::parse(&shard.listen_pg_addr).expect("bad hostname"),
shard.listen_pg_port,
)
}
})
.collect::<Vec<_>>()
};
pageserver_conninfo.prefer_protocol = prefer_protocol;
// If --safekeepers argument is given, use only the listed
// safekeeper nodes; otherwise all from the env.
let safekeepers = parse_safekeepers(&args.safekeepers)?;
endpoint
.reconfigure(Some(&pageserver_conninfo), safekeepers, None)
.reconfigure(Some(pageservers), None, safekeepers, None)
.await?;
}
EndpointCmd::Stop(args) => {

View File

@@ -36,7 +36,7 @@ impl StorageBroker {
pub async fn start(&self, retry_timeout: &Duration) -> anyhow::Result<()> {
let broker = &self.env.broker;
println!("Starting neon broker at {}", broker.client_url());
print!("Starting neon broker at {}", broker.client_url());
let mut args = Vec::new();

View File

@@ -32,12 +32,11 @@
//! config.json - passed to `compute_ctl`
//! pgdata/
//! postgresql.conf - copy of postgresql.conf created by `compute_ctl`
//! neon.signal
//! zenith.signal - copy of neon.signal, for backward compatibility
//! zenith.signal
//! <other PostgreSQL files>
//! ```
//!
use std::collections::{BTreeMap, HashMap};
use std::collections::BTreeMap;
use std::fmt::Display;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
use std::path::PathBuf;
@@ -58,17 +57,14 @@ use compute_api::responses::{
};
use compute_api::spec::{
Cluster, ComputeAudit, ComputeFeature, ComputeMode, ComputeSpec, Database, PageserverProtocol,
PageserverShardInfo, PgIdent, RemoteExtSpec, Role,
PgIdent, RemoteExtSpec, Role,
};
// re-export these, because they're used in the reconfigure() function
pub use compute_api::spec::{PageserverConnectionInfo, PageserverShardConnectionInfo};
use jsonwebtoken::jwk::{
AlgorithmParameters, CommonParameters, EllipticCurve, Jwk, JwkSet, KeyAlgorithm, KeyOperations,
OctetKeyPairParameters, OctetKeyPairType, PublicKeyUse,
};
use nix::sys::signal::{Signal, kill};
use pageserver_api::shard::ShardStripeSize;
use pem::Pem;
use reqwest::header::CONTENT_TYPE;
use safekeeper_api::PgMajorVersion;
@@ -78,11 +74,8 @@ use sha2::{Digest, Sha256};
use spki::der::Decode;
use spki::{SubjectPublicKeyInfo, SubjectPublicKeyInfoRef};
use tracing::debug;
use url::Host;
use utils::id::{NodeId, TenantId, TimelineId};
use utils::shard::{ShardIndex, ShardNumber};
use pageserver_api::config::DEFAULT_GRPC_LISTEN_PORT as DEFAULT_PAGESERVER_GRPC_PORT;
use postgres_connection::parse_host_port;
use crate::local_env::LocalEnv;
use crate::postgresql_conf::PostgresConf;
@@ -105,7 +98,6 @@ pub struct EndpointConf {
features: Vec<ComputeFeature>,
cluster: Option<Cluster>,
compute_ctl_config: ComputeCtlConfig,
privileged_role_name: Option<String>,
}
//
@@ -206,7 +198,6 @@ impl ComputeControlPlane {
grpc: bool,
skip_pg_catalog_updates: bool,
drop_subscriptions_before_start: bool,
privileged_role_name: Option<String>,
) -> Result<Arc<Endpoint>> {
let pg_port = pg_port.unwrap_or_else(|| self.get_port());
let external_http_port = external_http_port.unwrap_or_else(|| self.get_port() + 1);
@@ -244,7 +235,6 @@ impl ComputeControlPlane {
features: vec![],
cluster: None,
compute_ctl_config: compute_ctl_config.clone(),
privileged_role_name: privileged_role_name.clone(),
});
ep.create_endpoint_dir()?;
@@ -266,7 +256,6 @@ impl ComputeControlPlane {
features: vec![],
cluster: None,
compute_ctl_config,
privileged_role_name,
})?,
)?;
std::fs::write(
@@ -342,9 +331,6 @@ pub struct Endpoint {
/// The compute_ctl config for the endpoint's compute.
compute_ctl_config: ComputeCtlConfig,
/// The name of the privileged role for the endpoint.
privileged_role_name: Option<String>,
}
#[derive(PartialEq, Eq)]
@@ -393,8 +379,9 @@ pub struct EndpointStartArgs {
pub endpoint_storage_addr: String,
pub safekeepers_generation: Option<SafekeeperGeneration>,
pub safekeepers: Vec<NodeId>,
pub pageserver_conninfo: PageserverConnectionInfo,
pub pageservers: Vec<(PageserverProtocol, Host, u16)>,
pub remote_ext_base_url: Option<String>,
pub shard_stripe_size: usize,
pub create_test_user: bool,
pub start_timeout: Duration,
pub autoprewarm: bool,
@@ -444,7 +431,6 @@ impl Endpoint {
features: conf.features,
cluster: conf.cluster,
compute_ctl_config: conf.compute_ctl_config,
privileged_role_name: conf.privileged_role_name,
})
}
@@ -477,7 +463,7 @@ impl Endpoint {
conf.append("max_connections", "100");
conf.append("wal_level", "logical");
// wal_sender_timeout is the maximum time to wait for WAL replication.
// It also defines how often the walreceiver will send a feedback message to the wal sender.
// It also defines how often the walreciever will send a feedback message to the wal sender.
conf.append("wal_sender_timeout", "5s");
conf.append("listen_addresses", &self.pg_address.ip().to_string());
conf.append("port", &self.pg_address.port().to_string());
@@ -667,6 +653,14 @@ impl Endpoint {
}
}
fn build_pageserver_connstr(pageservers: &[(PageserverProtocol, Host, u16)]) -> String {
pageservers
.iter()
.map(|(scheme, host, port)| format!("{scheme}://no_user@{host}:{port}"))
.collect::<Vec<_>>()
.join(",")
}
/// Map safekeepers ids to the actual connection strings.
fn build_safekeepers_connstrs(&self, sk_ids: Vec<NodeId>) -> Result<Vec<String>> {
let mut safekeeper_connstrings = Vec::new();
@@ -712,6 +706,9 @@ impl Endpoint {
std::fs::remove_dir_all(self.pgdata())?;
}
let pageserver_connstring = Self::build_pageserver_connstr(&args.pageservers);
assert!(!pageserver_connstring.is_empty());
let safekeeper_connstrings = self.build_safekeepers_connstrs(args.safekeepers)?;
// check for file remote_extensions_spec.json
@@ -726,46 +723,6 @@ impl Endpoint {
remote_extensions = None;
};
// For the sake of backwards-compatibility, also fill in 'pageserver_connstring'
//
// XXX: I believe this is not really needed, except to make
// test_forward_compatibility happy.
//
// Use a closure so that we can conviniently return None in the middle of the
// loop.
let pageserver_connstring = (|| {
let num_shards = if args.pageserver_conninfo.shard_count.is_unsharded() {
1
} else {
args.pageserver_conninfo.shard_count.0
};
let mut connstrings = Vec::new();
for shard_no in 0..num_shards {
let shard_index = ShardIndex {
shard_count: args.pageserver_conninfo.shard_count,
shard_number: ShardNumber(shard_no),
};
let shard = args
.pageserver_conninfo
.shards
.get(&shard_index)
.expect(&format!(
"shard {} not found in pageserver_connection_info",
shard_index
));
let pageserver = shard
.pageservers
.first()
.expect("must have at least one pageserver");
if let Some(libpq_url) = &pageserver.libpq_url {
connstrings.push(libpq_url.clone());
} else {
return None;
}
}
Some(connstrings.join(","))
})();
// Create config file
let config = {
let mut spec = ComputeSpec {
@@ -810,14 +767,13 @@ impl Endpoint {
branch_id: None,
endpoint_id: Some(self.endpoint_id.clone()),
mode: self.mode,
pageserver_connection_info: Some(args.pageserver_conninfo.clone()),
pageserver_connstring,
pageserver_connstring: Some(pageserver_connstring),
safekeepers_generation: args.safekeepers_generation.map(|g| g.into_inner()),
safekeeper_connstrings,
storage_auth_token: args.auth_token.clone(),
remote_extensions,
pgbouncer_settings: None,
shard_stripe_size: args.pageserver_conninfo.stripe_size, // redundant with pageserver_connection_info.stripe_size
shard_stripe_size: Some(args.shard_stripe_size),
local_proxy_config: None,
reconfigure_concurrency: self.reconfigure_concurrency,
drop_subscriptions_before_start: self.drop_subscriptions_before_start,
@@ -913,10 +869,6 @@ impl Endpoint {
cmd.arg("--dev");
}
if let Some(privileged_role_name) = self.privileged_role_name.clone() {
cmd.args(["--privileged-role-name", &privileged_role_name]);
}
let child = cmd.spawn()?;
// set up a scopeguard to kill & wait for the child in case we panic or bail below
let child = scopeguard::guard(child, |mut child| {
@@ -1029,7 +981,8 @@ impl Endpoint {
pub async fn reconfigure(
&self,
pageserver_conninfo: Option<&PageserverConnectionInfo>,
pageservers: Option<Vec<(PageserverProtocol, Host, u16)>>,
stripe_size: Option<ShardStripeSize>,
safekeepers: Option<Vec<NodeId>>,
safekeeper_generation: Option<SafekeeperGeneration>,
) -> Result<()> {
@@ -1044,15 +997,15 @@ impl Endpoint {
let postgresql_conf = self.read_postgresql_conf()?;
spec.cluster.postgresql_conf = Some(postgresql_conf);
if let Some(pageserver_conninfo) = pageserver_conninfo {
// If pageservers are provided, we need to ensure that they are not empty.
// This is a requirement for the compute_ctl configuration.
anyhow::ensure!(
!pageserver_conninfo.shards.is_empty(),
"no pageservers provided"
);
spec.pageserver_connection_info = Some(pageserver_conninfo.clone());
spec.shard_stripe_size = pageserver_conninfo.stripe_size;
// If pageservers are not specified, don't change them.
if let Some(pageservers) = pageservers {
anyhow::ensure!(!pageservers.is_empty(), "no pageservers provided");
let pageserver_connstr = Self::build_pageserver_connstr(&pageservers);
spec.pageserver_connstring = Some(pageserver_connstr);
if stripe_size.is_some() {
spec.shard_stripe_size = stripe_size.map(|s| s.0 as usize);
}
}
// If safekeepers are not specified, don't change them.
@@ -1101,9 +1054,11 @@ impl Endpoint {
pub async fn reconfigure_pageservers(
&self,
pageservers: &PageserverConnectionInfo,
pageservers: Vec<(PageserverProtocol, Host, u16)>,
stripe_size: Option<ShardStripeSize>,
) -> Result<()> {
self.reconfigure(Some(pageservers), None, None).await
self.reconfigure(Some(pageservers), stripe_size, None, None)
.await
}
pub async fn reconfigure_safekeepers(
@@ -1111,7 +1066,7 @@ impl Endpoint {
safekeepers: Vec<NodeId>,
generation: SafekeeperGeneration,
) -> Result<()> {
self.reconfigure(None, Some(safekeepers), Some(generation))
self.reconfigure(None, None, Some(safekeepers), Some(generation))
.await
}
@@ -1167,68 +1122,3 @@ impl Endpoint {
)
}
}
pub fn pageserver_conf_to_shard_conn_info(
conf: &crate::local_env::PageServerConf,
) -> Result<PageserverShardConnectionInfo> {
let libpq_url = {
let (host, port) = parse_host_port(&conf.listen_pg_addr)?;
let port = port.unwrap_or(5432);
Some(format!("postgres://no_user@{host}:{port}"))
};
let grpc_url = if let Some(grpc_addr) = &conf.listen_grpc_addr {
let (host, port) = parse_host_port(grpc_addr)?;
let port = port.unwrap_or(DEFAULT_PAGESERVER_GRPC_PORT);
Some(format!("grpc://no_user@{host}:{port}"))
} else {
None
};
Ok(PageserverShardConnectionInfo {
id: Some(conf.id.to_string()),
libpq_url,
grpc_url,
})
}
pub fn tenant_locate_response_to_conn_info(
response: &pageserver_api::controller_api::TenantLocateResponse,
) -> Result<PageserverConnectionInfo> {
let mut shards = HashMap::new();
for shard in response.shards.iter() {
tracing::info!("parsing {}", shard.listen_pg_addr);
let libpq_url = {
let host = &shard.listen_pg_addr;
let port = shard.listen_pg_port;
Some(format!("postgres://no_user@{host}:{port}"))
};
let grpc_url = if let Some(grpc_addr) = &shard.listen_grpc_addr {
let host = grpc_addr;
let port = shard.listen_grpc_port.expect("no gRPC port");
Some(format!("grpc://no_user@{host}:{port}"))
} else {
None
};
let shard_info = PageserverShardInfo {
pageservers: vec![PageserverShardConnectionInfo {
id: Some(shard.node_id.to_string()),
libpq_url,
grpc_url,
}],
};
shards.insert(shard.shard_id.to_index(), shard_info);
}
let stripe_size = if response.shard_params.count.is_unsharded() {
None
} else {
Some(response.shard_params.stripe_size.0)
};
Ok(PageserverConnectionInfo {
shard_count: response.shard_params.count,
stripe_size,
shards,
prefer_protocol: PageserverProtocol::default(),
})
}

View File

@@ -217,9 +217,6 @@ pub struct NeonStorageControllerConf {
pub posthog_config: Option<PostHogConfig>,
pub kick_secondary_downloads: Option<bool>,
#[serde(with = "humantime_serde")]
pub shard_split_request_timeout: Option<Duration>,
}
impl NeonStorageControllerConf {
@@ -253,7 +250,6 @@ impl Default for NeonStorageControllerConf {
timeline_safekeeper_count: None,
posthog_config: None,
kick_secondary_downloads: None,
shard_split_request_timeout: None,
}
}
}

View File

@@ -303,7 +303,7 @@ impl PageServerNode {
async fn start_node(&self, retry_timeout: &Duration) -> anyhow::Result<()> {
// TODO: using a thread here because start_process() is not async but we need to call check_status()
let datadir = self.repo_path();
println!(
print!(
"Starting pageserver node {} at '{}' in {:?}, retrying for {:?}",
self.conf.id,
self.pg_connection_config.raw_address(),

View File

@@ -127,7 +127,7 @@ impl SafekeeperNode {
extra_opts: &[String],
retry_timeout: &Duration,
) -> anyhow::Result<()> {
println!(
print!(
"Starting safekeeper at '{}' in '{}', retrying for {:?}",
self.pg_connection_config.raw_address(),
self.datadir_path().display(),

View File

@@ -648,13 +648,6 @@ impl StorageController {
args.push(format!("--timeline-safekeeper-count={sk_cnt}"));
}
if let Some(duration) = self.config.shard_split_request_timeout {
args.push(format!(
"--shard-split-request-timeout={}",
humantime::Duration::from(duration)
));
}
let mut envs = vec![
("LD_LIBRARY_PATH".to_owned(), pg_lib_dir.to_string()),
("DYLD_LIBRARY_PATH".to_owned(), pg_lib_dir.to_string()),
@@ -667,7 +660,7 @@ impl StorageController {
));
}
println!("Starting storage controller at {scheme}://{host}:{listen_port}");
println!("Starting storage controller");
background_process::start_process(
COMMAND,

View File

@@ -14,7 +14,6 @@ humantime.workspace = true
pageserver_api.workspace = true
pageserver_client.workspace = true
reqwest.workspace = true
safekeeper_api.workspace=true
serde_json = { workspace = true, features = ["raw_value"] }
storage_controller_client.workspace = true
tokio.workspace = true

View File

@@ -11,7 +11,7 @@ use pageserver_api::controller_api::{
PlacementPolicy, SafekeeperDescribeResponse, SafekeeperSchedulingPolicyRequest,
ShardSchedulingPolicy, ShardsPreferredAzsRequest, ShardsPreferredAzsResponse,
SkSchedulingPolicy, TenantCreateRequest, TenantDescribeResponse, TenantPolicyRequest,
TenantShardMigrateRequest, TenantShardMigrateResponse, TimelineSafekeeperMigrateRequest,
TenantShardMigrateRequest, TenantShardMigrateResponse,
};
use pageserver_api::models::{
EvictionPolicy, EvictionPolicyLayerAccessThreshold, ShardParameters, TenantConfig,
@@ -21,7 +21,6 @@ use pageserver_api::models::{
use pageserver_api::shard::{ShardStripeSize, TenantShardId};
use pageserver_client::mgmt_api::{self};
use reqwest::{Certificate, Method, StatusCode, Url};
use safekeeper_api::models::TimelineLocateResponse;
use storage_controller_client::control_api::Client;
use utils::id::{NodeId, TenantId, TimelineId};
@@ -76,12 +75,6 @@ enum Command {
NodeStartDelete {
#[arg(long)]
node_id: NodeId,
/// When `force` is true, skip waiting for shards to prewarm during migration.
/// This can significantly speed up node deletion since prewarming all shards
/// can take considerable time, but may result in slower initial access to
/// migrated shards until they warm up naturally.
#[arg(long)]
force: bool,
},
/// Cancel deletion of the specified pageserver and wait for `timeout`
/// for the operation to be canceled. May be retried.
@@ -286,23 +279,6 @@ enum Command {
#[arg(long)]
concurrency: Option<usize>,
},
/// Locate safekeepers for a timeline from the storcon DB.
TimelineLocate {
#[arg(long)]
tenant_id: TenantId,
#[arg(long)]
timeline_id: TimelineId,
},
/// Migrate a timeline to a new set of safekeepers
TimelineSafekeeperMigrate {
#[arg(long)]
tenant_id: TenantId,
#[arg(long)]
timeline_id: TimelineId,
/// Example: --new-sk-set 1,2,3
#[arg(long, required = true, value_delimiter = ',')]
new_sk_set: Vec<NodeId>,
},
}
#[derive(Parser)]
@@ -482,7 +458,6 @@ async fn main() -> anyhow::Result<()> {
listen_http_port,
listen_https_port,
availability_zone_id: AvailabilityZone(availability_zone_id),
node_ip_addr: None,
}),
)
.await?;
@@ -958,14 +933,13 @@ async fn main() -> anyhow::Result<()> {
.dispatch::<(), ()>(Method::DELETE, format!("control/v1/node/{node_id}"), None)
.await?;
}
Command::NodeStartDelete { node_id, force } => {
let query = if force {
format!("control/v1/node/{node_id}/delete?force=true")
} else {
format!("control/v1/node/{node_id}/delete")
};
Command::NodeStartDelete { node_id } => {
storcon_client
.dispatch::<(), ()>(Method::PUT, query, None)
.dispatch::<(), ()>(
Method::PUT,
format!("control/v1/node/{node_id}/delete"),
None,
)
.await?;
println!("Delete started for {node_id}");
}
@@ -1350,7 +1324,7 @@ async fn main() -> anyhow::Result<()> {
concurrency,
} => {
let mut path = format!(
"v1/tenant/{tenant_shard_id}/timeline/{timeline_id}/download_heatmap_layers",
"/v1/tenant/{tenant_shard_id}/timeline/{timeline_id}/download_heatmap_layers",
);
if let Some(c) = concurrency {
@@ -1361,41 +1335,6 @@ async fn main() -> anyhow::Result<()> {
.dispatch::<(), ()>(Method::POST, path, None)
.await?;
}
Command::TimelineLocate {
tenant_id,
timeline_id,
} => {
let path = format!("debug/v1/tenant/{tenant_id}/timeline/{timeline_id}/locate");
let resp = storcon_client
.dispatch::<(), TimelineLocateResponse>(Method::GET, path, None)
.await?;
let sk_set = resp.sk_set.iter().map(|id| id.0 as i64).collect::<Vec<_>>();
let new_sk_set = resp
.new_sk_set
.as_ref()
.map(|ids| ids.iter().map(|id| id.0 as i64).collect::<Vec<_>>());
println!("generation = {}", resp.generation);
println!("sk_set = {sk_set:?}");
println!("new_sk_set = {new_sk_set:?}");
}
Command::TimelineSafekeeperMigrate {
tenant_id,
timeline_id,
new_sk_set,
} => {
let path = format!("v1/tenant/{tenant_id}/timeline/{timeline_id}/safekeeper_migrate");
storcon_client
.dispatch::<_, ()>(
Method::POST,
path,
Some(TimelineSafekeeperMigrateRequest { new_sk_set }),
)
.await?;
}
}
Ok(())

View File

@@ -0,0 +1,58 @@
# Continuous Crofiling (Compute)
The continuous profiling of the compute node is performed by `perf` or `bcc-tools`, the latter is preferred.
The executables profiled are all the postgres-related ones only, excluding the actual compute code (Rust). This can be done as well but
was not the main goal.
## Tools
The aforementioned tools are available within the same Docker image as
the compute node itself, but the corresponding dependencies linux the
linux kernel headers and the linux kernel itself are not and can't be
for obvious reasons. To solve that, as we run the compute nodes as a
virtual machine (qemu), we need to deliver these dependencies to it.
This is done by the `autoscaling` part, which builds and deploys the
kernel headers, needed modules, and the `perf` binary into an ext4-fs
disk image, which is later attached to the VM and is symlinked to be
made available for the compute node.
## Output
The output of the profiling is always a binary file in the same format
of `pprof`. It can, however, be archived by `gzip` additionally, if the
corresponding argument is provided in the JSON request.
## REST API
### Test profiling
One can test the profiling after connecting to the VM and running:
```sh
curl -X POST -H "Content-Type: application/json" http://localhost:3080/profile/cpu -d '{"profiler": {"BccProfile": null}, "sampling_frequency": 99, "timeout_seconds": 5, "archive": false}' -v --output profile.pb
```
This uses the `Bcc` profiler and does not archive the output. The
profiling data will be saved into the `profile.pb` file locally.
**Only one profiling session can be run at a time.**
To check the profiling status (to see whether it is already running or
not), one can perform the `GET` request:
```sh
curl http://localhost:3080/profile/cpu -v
```
The profiling can be stopped by performing the `DELETE` request:
```sh
curl -X DELETE http://localhost:3080/profile/cpu -v
```
## Supported profiling data
For now, only the CPU profiling is done and ther is no heap profiling.
Also, only the postgres-related executables are tracked, the compute
(Rust) part itself **is not tracked**.

View File

@@ -129,10 +129,9 @@ segment to bootstrap the WAL writing, but it doesn't contain the checkpoint reco
changes in xlog.c, to allow starting the compute node without reading the last checkpoint record
from WAL.
This includes code to read the `neon.signal` (also `zenith.signal`) file, which tells the startup
code the LSN to start at. When the `neon.signal` file is present, the startup uses that LSN
instead of the last checkpoint's LSN. The system is known to be consistent at that LSN, without
any WAL redo.
This includes code to read the `zenith.signal` file, which tells the startup code the LSN to start
at. When the `zenith.signal` file is present, the startup uses that LSN instead of the last
checkpoint's LSN. The system is known to be consistent at that LSN, without any WAL redo.
### How to get rid of the patch

View File

@@ -75,7 +75,7 @@ CLI examples:
* AWS S3 : `env AWS_ACCESS_KEY_ID='SOMEKEYAAAAASADSAH*#' AWS_SECRET_ACCESS_KEY='SOMEsEcReTsd292v' ${PAGESERVER_BIN} -c "remote_storage={bucket_name='some-sample-bucket',bucket_region='eu-north-1', prefix_in_bucket='/test_prefix/'}"`
For Amazon AWS S3, a key id and secret access key could be located in `~/.aws/credentials` if awscli was ever configured to work with the desired bucket, on the AWS Settings page for a certain user. Also note, that the bucket names does not contain any protocols when used on AWS.
For local S3 installations, refer to their documentation for name format and credentials.
For local S3 installations, refer to the their documentation for name format and credentials.
Similar to other pageserver settings, toml config file can be used to configure either of the storages as backup targets.
Required sections are:

View File

@@ -46,33 +46,16 @@ pub struct ExtensionInstallResponse {
pub version: ExtVersion,
}
/// Status of the LFC prewarm process. The same state machine is reused for
/// both autoprewarm (prewarm after compute/Postgres start using the previously
/// stored LFC state) and explicit prewarming via API.
#[derive(Serialize, Default, Debug, Clone, PartialEq)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum LfcPrewarmState {
/// Default value when compute boots up.
#[default]
NotPrewarmed,
/// Prewarming thread is active and loading pages into LFC.
Prewarming,
/// We found requested LFC state in the endpoint storage and
/// completed prewarming successfully.
Completed,
/// Unexpected error happened during prewarming. Note, `Not Found 404`
/// response from the endpoint storage is explicitly excluded here
/// because it can normally happen on the first compute start,
/// since LFC state is not available yet.
Failed { error: String },
/// We tried to fetch the corresponding LFC state from the endpoint storage,
/// but received `Not Found 404`. This should normally happen only during the
/// first endpoint start after creation with `autoprewarm: true`.
///
/// During the orchestrated prewarm via API, when a caller explicitly
/// provides the LFC state key to prewarm from, it's the caller responsibility
/// to handle this status as an error state in this case.
Skipped,
Failed {
error: String,
},
}
impl Display for LfcPrewarmState {
@@ -81,7 +64,6 @@ impl Display for LfcPrewarmState {
LfcPrewarmState::NotPrewarmed => f.write_str("NotPrewarmed"),
LfcPrewarmState::Prewarming => f.write_str("Prewarming"),
LfcPrewarmState::Completed => f.write_str("Completed"),
LfcPrewarmState::Skipped => f.write_str("Skipped"),
LfcPrewarmState::Failed { error } => write!(f, "Error({error})"),
}
}

View File

@@ -14,7 +14,6 @@ use serde::{Deserialize, Serialize};
use url::Url;
use utils::id::{TenantId, TimelineId};
use utils::lsn::Lsn;
use utils::shard::{ShardCount, ShardIndex};
use crate::responses::TlsConfig;
@@ -106,17 +105,6 @@ pub struct ComputeSpec {
// updated to fill these fields, we can make these non optional.
pub tenant_id: Option<TenantId>,
pub timeline_id: Option<TimelineId>,
/// Pageserver information can be passed in three different ways:
/// 1. Here in `pageserver_connection_info`
/// 2. In the `pageserver_connstring` field.
/// 3. in `cluster.settings`.
///
/// The goal is to use method 1. everywhere. But for backwards-compatibility with old
/// versions of the control plane, `compute_ctl` will check 2. and 3. if the
/// `pageserver_connection_info` field is missing.
pub pageserver_connection_info: Option<PageserverConnectionInfo>,
pub pageserver_connstring: Option<String>,
// More neon ids that we expose to the compute_ctl
@@ -153,7 +141,7 @@ pub struct ComputeSpec {
// Stripe size for pageserver sharding, in pages
#[serde(default)]
pub shard_stripe_size: Option<u32>,
pub shard_stripe_size: Option<usize>,
/// Local Proxy configuration used for JWT authentication
#[serde(default)]
@@ -226,32 +214,6 @@ pub enum ComputeFeature {
UnknownFeature,
}
#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
pub struct PageserverConnectionInfo {
/// NB: 0 for unsharded tenants, 1 for sharded tenants with 1 shard, following storage
pub shard_count: ShardCount,
/// INVARIANT: null if shard_count is 0, otherwise non-null and immutable
pub stripe_size: Option<u32>,
pub shards: HashMap<ShardIndex, PageserverShardInfo>,
#[serde(default)]
pub prefer_protocol: PageserverProtocol,
}
#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
pub struct PageserverShardInfo {
pub pageservers: Vec<PageserverShardConnectionInfo>,
}
#[derive(Clone, Debug, Deserialize, Serialize, Eq, PartialEq)]
pub struct PageserverShardConnectionInfo {
pub id: Option<String>,
pub libpq_url: Option<String>,
pub grpc_url: Option<String>,
}
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct RemoteExtSpec {
pub public_extensions: Option<Vec<String>>,
@@ -369,12 +331,6 @@ impl ComputeMode {
}
}
impl Display for ComputeMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.to_type_str())
}
}
/// Log level for audit logging
#[derive(Clone, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
pub enum ComputeAudit {
@@ -485,15 +441,13 @@ pub struct JwksSettings {
pub jwt_audience: Option<String>,
}
/// Protocol used to connect to a Pageserver.
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize, PartialEq, Eq)]
/// Protocol used to connect to a Pageserver. Parsed from the connstring scheme.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum PageserverProtocol {
/// The original protocol based on libpq and COPY. Uses postgresql:// or postgres:// scheme.
#[default]
#[serde(rename = "libpq")]
Libpq,
/// A newer, gRPC-based protocol. Uses grpc:// scheme.
#[serde(rename = "grpc")]
Grpc,
}

View File

@@ -4,14 +4,12 @@
//! a default registry.
#![deny(clippy::undocumented_unsafe_blocks)]
use std::sync::RwLock;
use measured::label::{LabelGroupSet, LabelGroupVisitor, LabelName, NoLabels};
use measured::metric::counter::CounterState;
use measured::metric::gauge::GaugeState;
use measured::metric::group::Encoding;
use measured::metric::name::{MetricName, MetricNameEncoder};
use measured::metric::{MetricEncoding, MetricFamilyEncoding, MetricType};
use measured::metric::{MetricEncoding, MetricFamilyEncoding};
use measured::{FixedCardinalityLabel, LabelGroup, MetricGroup};
use once_cell::sync::Lazy;
use prometheus::Registry;
@@ -118,52 +116,12 @@ pub fn pow2_buckets(start: usize, end: usize) -> Vec<f64> {
.collect()
}
pub struct InfoMetric<L: LabelGroup, M: MetricType = GaugeState> {
label: RwLock<L>,
metric: M,
}
impl<L: LabelGroup> InfoMetric<L> {
pub fn new(label: L) -> Self {
Self::with_metric(label, GaugeState::new(1))
}
}
impl<L: LabelGroup, M: MetricType<Metadata = ()>> InfoMetric<L, M> {
pub fn with_metric(label: L, metric: M) -> Self {
Self {
label: RwLock::new(label),
metric,
}
}
pub fn set_label(&self, label: L) {
*self.label.write().unwrap() = label;
}
}
impl<L, M, E> MetricFamilyEncoding<E> for InfoMetric<L, M>
where
L: LabelGroup,
M: MetricEncoding<E, Metadata = ()>,
E: Encoding,
{
fn collect_family_into(
&self,
name: impl measured::metric::name::MetricNameEncoder,
enc: &mut E,
) -> Result<(), E::Err> {
M::write_type(&name, enc)?;
self.metric
.collect_into(&(), &*self.label.read().unwrap(), name, enc)
}
}
pub struct BuildInfo {
pub revision: &'static str,
pub build_tag: &'static str,
}
// todo: allow label group without the set
impl LabelGroup for BuildInfo {
fn visit_values(&self, v: &mut impl LabelGroupVisitor) {
const REVISION: &LabelName = LabelName::from_str("revision");
@@ -173,6 +131,24 @@ impl LabelGroup for BuildInfo {
}
}
impl<T: Encoding> MetricFamilyEncoding<T> for BuildInfo
where
GaugeState: MetricEncoding<T>,
{
fn collect_family_into(
&self,
name: impl measured::metric::name::MetricNameEncoder,
enc: &mut T,
) -> Result<(), T::Err> {
enc.write_help(&name, "Build/version information")?;
GaugeState::write_type(&name, enc)?;
GaugeState {
count: std::sync::atomic::AtomicI64::new(1),
}
.collect_into(&(), self, name, enc)
}
}
#[derive(MetricGroup)]
#[metric(new(build_info: BuildInfo))]
pub struct NeonMetrics {
@@ -189,8 +165,8 @@ pub struct NeonMetrics {
#[derive(MetricGroup)]
#[metric(new(build_info: BuildInfo))]
pub struct LibMetrics {
#[metric(init = InfoMetric::new(build_info))]
build_info: InfoMetric<BuildInfo>,
#[metric(init = build_info)]
build_info: BuildInfo,
#[metric(flatten)]
rusage: Rusage,

View File

@@ -6,15 +6,8 @@ license.workspace = true
[dependencies]
thiserror.workspace = true
nix.workspace=true
nix.workspace = true
workspace_hack = { version = "0.1", path = "../../workspace_hack" }
libc.workspace = true
lock_api.workspace = true
rustc-hash.workspace = true
[target.'cfg(target_os = "macos")'.dependencies]
tempfile = "3.14.0"
[dev-dependencies]
rand = "0.9"
rand_distr = "0.5.1"
tempfile = "3.20.0"

View File

@@ -1,330 +0,0 @@
use criterion::{BatchSize, BenchmarkId, Criterion, criterion_group, criterion_main};
use neon_shmem::hash::HashMapAccess;
use neon_shmem::hash::HashMapInit;
use neon_shmem::hash::entry::Entry;
use rand::distr::{Distribution, StandardUniform};
use rand::prelude::*;
use std::default::Default;
use std::hash::BuildHasher;
// Taken from bindings to C code
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
#[repr(C)]
pub struct FileCacheKey {
pub _spc_id: u32,
pub _db_id: u32,
pub _rel_number: u32,
pub _fork_num: u32,
pub _block_num: u32,
}
impl Distribution<FileCacheKey> for StandardUniform {
// questionable, but doesn't need to be good randomness
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> FileCacheKey {
FileCacheKey {
_spc_id: rng.random(),
_db_id: rng.random(),
_rel_number: rng.random(),
_fork_num: rng.random(),
_block_num: rng.random(),
}
}
}
#[derive(Clone, Debug)]
#[repr(C)]
pub struct FileCacheEntry {
pub _offset: u32,
pub _access_count: u32,
pub _prev: *mut FileCacheEntry,
pub _next: *mut FileCacheEntry,
pub _state: [u32; 8],
}
impl FileCacheEntry {
fn dummy() -> Self {
Self {
_offset: 0,
_access_count: 0,
_prev: std::ptr::null_mut(),
_next: std::ptr::null_mut(),
_state: [0; 8],
}
}
}
// Utilities for applying operations.
#[derive(Clone, Debug)]
struct TestOp<K, V>(K, Option<V>);
fn apply_op<K: Clone + std::hash::Hash + Eq, V, S: std::hash::BuildHasher>(
op: TestOp<K, V>,
map: &mut HashMapAccess<K, V, S>,
) {
let entry = map.entry(op.0);
match op.1 {
Some(new) => match entry {
Entry::Occupied(mut e) => Some(e.insert(new)),
Entry::Vacant(e) => {
_ = e.insert(new).unwrap();
None
}
},
None => match entry {
Entry::Occupied(e) => Some(e.remove()),
Entry::Vacant(_) => None,
},
};
}
// Hash utilities
struct SeaRandomState {
k1: u64,
k2: u64,
k3: u64,
k4: u64,
}
impl std::hash::BuildHasher for SeaRandomState {
type Hasher = seahash::SeaHasher;
fn build_hasher(&self) -> Self::Hasher {
seahash::SeaHasher::with_seeds(self.k1, self.k2, self.k3, self.k4)
}
}
impl SeaRandomState {
fn new() -> Self {
let mut rng = rand::rng();
Self {
k1: rng.random(),
k2: rng.random(),
k3: rng.random(),
k4: rng.random(),
}
}
}
fn small_benchs(c: &mut Criterion) {
let mut group = c.benchmark_group("Small maps");
group.sample_size(10);
group.bench_function("small_rehash", |b| {
let ideal_filled = 4_000_000;
let size = 5_000_000;
let mut writer = HashMapInit::new_resizeable(size, size * 2).attach_writer();
let mut rng = rand::rng();
while writer.get_num_buckets_in_use() < ideal_filled as usize {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
apply_op(TestOp(key, Some(val)), &mut writer);
}
b.iter(|| writer.shuffle());
});
group.bench_function("small_rehash_xxhash", |b| {
let ideal_filled = 4_000_000;
let size = 5_000_000;
let mut writer = HashMapInit::new_resizeable(size, size * 2)
.with_hasher(twox_hash::xxhash64::RandomState::default())
.attach_writer();
let mut rng = rand::rng();
while writer.get_num_buckets_in_use() < ideal_filled as usize {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
apply_op(TestOp(key, Some(val)), &mut writer);
}
b.iter(|| writer.shuffle());
});
group.bench_function("small_rehash_ahash", |b| {
let ideal_filled = 4_000_000;
let size = 5_000_000;
let mut writer = HashMapInit::new_resizeable(size, size * 2)
.with_hasher(ahash::RandomState::default())
.attach_writer();
let mut rng = rand::rng();
while writer.get_num_buckets_in_use() < ideal_filled as usize {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
apply_op(TestOp(key, Some(val)), &mut writer);
}
b.iter(|| writer.shuffle());
});
group.bench_function("small_rehash_seahash", |b| {
let ideal_filled = 4_000_000;
let size = 5_000_000;
let mut writer = HashMapInit::new_resizeable(size, size * 2)
.with_hasher(SeaRandomState::new())
.attach_writer();
let mut rng = rand::rng();
while writer.get_num_buckets_in_use() < ideal_filled as usize {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
apply_op(TestOp(key, Some(val)), &mut writer);
}
b.iter(|| writer.shuffle());
});
group.finish();
}
fn real_benchs(c: &mut Criterion) {
let mut group = c.benchmark_group("Realistic workloads");
group.sample_size(10);
group.bench_function("real_bulk_insert", |b| {
let size = 125_000_000;
let ideal_filled = 100_000_000;
let mut rng = rand::rng();
b.iter_batched(
|| HashMapInit::new_resizeable(size, size * 2).attach_writer(),
|writer| {
for _ in 0..ideal_filled {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
let entry = writer.entry(key);
std::hint::black_box(match entry {
Entry::Occupied(mut e) => {
e.insert(val);
}
Entry::Vacant(e) => {
_ = e.insert(val).unwrap();
}
})
}
},
BatchSize::SmallInput,
)
});
group.bench_function("real_rehash", |b| {
let size = 125_000_000;
let ideal_filled = 100_000_000;
let mut writer = HashMapInit::new_resizeable(size, size).attach_writer();
let mut rng = rand::rng();
while writer.get_num_buckets_in_use() < ideal_filled {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
apply_op(TestOp(key, Some(val)), &mut writer);
}
b.iter(|| writer.shuffle());
});
group.bench_function("real_rehash_hashbrown", |b| {
let size = 125_000_000;
let ideal_filled = 100_000_000;
let mut writer = hashbrown::raw::RawTable::new();
let mut rng = rand::rng();
let hasher = rustc_hash::FxBuildHasher::default();
unsafe {
writer
.resize(
size,
|(k, _)| hasher.hash_one(&k),
hashbrown::raw::Fallibility::Infallible,
)
.unwrap();
}
while writer.len() < ideal_filled as usize {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
writer.insert(hasher.hash_one(&key), (key, val), |(k, _)| {
hasher.hash_one(&k)
});
}
b.iter(|| unsafe {
writer.table.rehash_in_place(
&|table, index| {
hasher.hash_one(
&table
.bucket::<(FileCacheKey, FileCacheEntry)>(index)
.as_ref()
.0,
)
},
std::mem::size_of::<(FileCacheKey, FileCacheEntry)>(),
if std::mem::needs_drop::<(FileCacheKey, FileCacheEntry)>() {
Some(|ptr| std::ptr::drop_in_place(ptr as *mut (FileCacheKey, FileCacheEntry)))
} else {
None
},
)
});
});
for elems in [2, 4, 8, 16, 32, 64, 96, 112] {
group.bench_with_input(
BenchmarkId::new("real_rehash_varied", elems),
&elems,
|b, &size| {
let ideal_filled = size * 1_000_000;
let size = 125_000_000;
let mut writer = HashMapInit::new_resizeable(size, size).attach_writer();
let mut rng = rand::rng();
while writer.get_num_buckets_in_use() < ideal_filled as usize {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
apply_op(TestOp(key, Some(val)), &mut writer);
}
b.iter(|| writer.shuffle());
},
);
group.bench_with_input(
BenchmarkId::new("real_rehash_varied_hashbrown", elems),
&elems,
|b, &size| {
let ideal_filled = size * 1_000_000;
let size = 125_000_000;
let mut writer = hashbrown::raw::RawTable::new();
let mut rng = rand::rng();
let hasher = rustc_hash::FxBuildHasher::default();
unsafe {
writer
.resize(
size,
|(k, _)| hasher.hash_one(&k),
hashbrown::raw::Fallibility::Infallible,
)
.unwrap();
}
while writer.len() < ideal_filled as usize {
let key: FileCacheKey = rng.random();
let val = FileCacheEntry::dummy();
writer.insert(hasher.hash_one(&key), (key, val), |(k, _)| {
hasher.hash_one(&k)
});
}
b.iter(|| unsafe {
writer.table.rehash_in_place(
&|table, index| {
hasher.hash_one(
&table
.bucket::<(FileCacheKey, FileCacheEntry)>(index)
.as_ref()
.0,
)
},
std::mem::size_of::<(FileCacheKey, FileCacheEntry)>(),
if std::mem::needs_drop::<(FileCacheKey, FileCacheEntry)>() {
Some(|ptr| {
std::ptr::drop_in_place(ptr as *mut (FileCacheKey, FileCacheEntry))
})
} else {
None
},
)
});
},
);
}
group.finish();
}
criterion_group!(benches, small_benchs, real_benchs);
criterion_main!(benches);

View File

@@ -1,614 +0,0 @@
//! Resizable hash table implementation on top of byte-level storage (either a [`ShmemHandle`] or a fixed byte array).
//!
//! This hash table has two major components: the bucket array and the dictionary. Each bucket within the
//! bucket array contains a `Option<(K, V)>` and an index of another bucket. In this way there is both an
//! implicit freelist within the bucket array (`None` buckets point to other `None` entries) and various hash
//! chains within the bucket array (a Some bucket will point to other Some buckets that had the same hash).
//!
//! Buckets are never moved unless they are within a region that is being shrunk, and so the actual hash-
//! dependent component is done with the dictionary. When a new key is inserted into the map, a position
//! within the dictionary is decided based on its hash, the data is inserted into an empty bucket based
//! off of the freelist, and then the index of said bucket is placed in the dictionary.
//!
//! This map is resizable (if initialized on top of a [`ShmemHandle`]). Both growing and shrinking happen
//! in-place and are at a high level achieved by expanding/reducing the bucket array and rebuilding the
//! dictionary by rehashing all keys.
//!
//! Concurrency is managed very simply: the entire map is guarded by one shared-memory RwLock.
use std::fmt::Debug;
use std::hash::{BuildHasher, Hash};
use std::mem::MaybeUninit;
use crate::shmem::ShmemHandle;
use crate::{shmem, sync::*};
mod core;
pub mod entry;
#[cfg(test)]
mod tests;
use core::{Bucket, CoreHashMap, INVALID_POS};
use entry::{Entry, OccupiedEntry, PrevPos, VacantEntry};
use thiserror::Error;
/// Error type for a hashmap shrink operation.
#[derive(Error, Debug)]
pub enum HashMapShrinkError {
/// There was an error encountered while resizing the memory area.
#[error("shmem resize failed: {0}")]
ResizeError(shmem::Error),
/// Occupied entries in to-be-shrunk space were encountered beginning at the given index.
#[error("occupied entry in deallocated space found at {0}")]
RemainingEntries(usize),
}
/// This represents a hash table that (possibly) lives in shared memory.
/// If a new process is launched with fork(), the child process inherits
/// this struct.
#[must_use]
pub struct HashMapInit<'a, K, V, S = rustc_hash::FxBuildHasher> {
shmem_handle: Option<ShmemHandle>,
shared_ptr: *mut HashMapShared<'a, K, V>,
shared_size: usize,
hasher: S,
num_buckets: u32,
}
impl<'a, K, V, S> Debug for HashMapInit<'a, K, V, S>
where
K: Debug,
V: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("HashMapInit")
.field("shmem_handle", &self.shmem_handle)
.field("shared_ptr", &self.shared_ptr)
.field("shared_size", &self.shared_size)
// .field("hasher", &self.hasher)
.field("num_buckets", &self.num_buckets)
.finish()
}
}
/// This is a per-process handle to a hash table that (possibly) lives in shared memory.
/// If a child process is launched with fork(), the child process should
/// get its own HashMapAccess by calling HashMapInit::attach_writer/reader().
///
/// XXX: We're not making use of it at the moment, but this struct could
/// hold process-local information in the future.
pub struct HashMapAccess<'a, K, V, S = rustc_hash::FxBuildHasher> {
shmem_handle: Option<ShmemHandle>,
shared_ptr: *mut HashMapShared<'a, K, V>,
hasher: S,
}
unsafe impl<K: Sync, V: Sync, S> Sync for HashMapAccess<'_, K, V, S> {}
unsafe impl<K: Send, V: Send, S> Send for HashMapAccess<'_, K, V, S> {}
impl<'a, K, V, S> Debug for HashMapAccess<'a, K, V, S>
where
K: Debug,
V: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("HashMapAccess")
.field("shmem_handle", &self.shmem_handle)
.field("shared_ptr", &self.shared_ptr)
// .field("hasher", &self.hasher)
.finish()
}
}
impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
/// Change the 'hasher' used by the hash table.
///
/// NOTE: This must be called right after creating the hash table,
/// before inserting any entries and before calling attach_writer/reader.
/// Otherwise different accessors could be using different hash function,
/// with confusing results.
pub fn with_hasher<T: BuildHasher>(self, hasher: T) -> HashMapInit<'a, K, V, T> {
HashMapInit {
hasher,
shmem_handle: self.shmem_handle,
shared_ptr: self.shared_ptr,
shared_size: self.shared_size,
num_buckets: self.num_buckets,
}
}
/// Loosely (over)estimate the size needed to store a hash table with `num_buckets` buckets.
pub fn estimate_size(num_buckets: u32) -> usize {
// add some margin to cover alignment etc.
CoreHashMap::<K, V>::estimate_size(num_buckets) + size_of::<HashMapShared<K, V>>() + 1000
}
fn new(
num_buckets: u32,
shmem_handle: Option<ShmemHandle>,
area_ptr: *mut u8,
area_size: usize,
hasher: S,
) -> Self {
let mut ptr: *mut u8 = area_ptr;
let end_ptr: *mut u8 = unsafe { ptr.add(area_size) };
// carve out area for the One Big Lock (TM) and the HashMapShared.
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<libc::pthread_rwlock_t>())) };
let raw_lock_ptr = ptr;
ptr = unsafe { ptr.add(size_of::<libc::pthread_rwlock_t>()) };
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<HashMapShared<K, V>>())) };
let shared_ptr: *mut HashMapShared<K, V> = ptr.cast();
ptr = unsafe { ptr.add(size_of::<HashMapShared<K, V>>()) };
// carve out the buckets
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<core::Bucket<K, V>>())) };
let buckets_ptr = ptr;
ptr = unsafe { ptr.add(size_of::<core::Bucket<K, V>>() * num_buckets as usize) };
// use remaining space for the dictionary
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<u32>())) };
assert!(ptr.addr() < end_ptr.addr());
let dictionary_ptr = ptr;
let dictionary_size = unsafe { end_ptr.byte_offset_from(ptr) / size_of::<u32>() as isize };
assert!(dictionary_size > 0);
let buckets =
unsafe { std::slice::from_raw_parts_mut(buckets_ptr.cast(), num_buckets as usize) };
let dictionary = unsafe {
std::slice::from_raw_parts_mut(dictionary_ptr.cast(), dictionary_size as usize)
};
let hashmap = CoreHashMap::new(buckets, dictionary);
unsafe {
let lock = RwLock::from_raw(PthreadRwLock::new(raw_lock_ptr.cast()), hashmap);
std::ptr::write(shared_ptr, lock);
}
Self {
num_buckets,
shmem_handle,
shared_ptr,
shared_size: area_size,
hasher,
}
}
/// Attach to a hash table for writing.
pub fn attach_writer(self) -> HashMapAccess<'a, K, V, S> {
HashMapAccess {
shmem_handle: self.shmem_handle,
shared_ptr: self.shared_ptr,
hasher: self.hasher,
}
}
/// Initialize a table for reading. Currently identical to [`HashMapInit::attach_writer`].
///
/// This is a holdover from a previous implementation and is being kept around for
/// backwards compatibility reasons.
pub fn attach_reader(self) -> HashMapAccess<'a, K, V, S> {
self.attach_writer()
}
}
/// Hash table data that is actually stored in the shared memory area.
///
/// NOTE: We carve out the parts from a contiguous chunk. Growing and shrinking the hash table
/// relies on the memory layout! The data structures are laid out in the contiguous shared memory
/// area as follows:
///
/// [`libc::pthread_rwlock_t`]
/// [`HashMapShared`]
/// buckets
/// dictionary
///
/// In between the above parts, there can be padding bytes to align the parts correctly.
type HashMapShared<'a, K, V> = RwLock<CoreHashMap<'a, K, V>>;
impl<'a, K, V> HashMapInit<'a, K, V, rustc_hash::FxBuildHasher>
where
K: Clone + Hash + Eq,
{
/// Place the hash table within a user-supplied fixed memory area.
pub fn with_fixed(num_buckets: u32, area: &'a mut [MaybeUninit<u8>]) -> Self {
Self::new(
num_buckets,
None,
area.as_mut_ptr().cast(),
area.len(),
rustc_hash::FxBuildHasher,
)
}
/// Place a new hash map in the given shared memory area
///
/// # Panics
/// Will panic on failure to resize area to expected map size.
pub fn with_shmem(num_buckets: u32, shmem: ShmemHandle) -> Self {
let size = Self::estimate_size(num_buckets);
shmem
.set_size(size)
.expect("could not resize shared memory area");
let ptr = shmem.data_ptr.as_ptr().cast();
Self::new(
num_buckets,
Some(shmem),
ptr,
size,
rustc_hash::FxBuildHasher,
)
}
/// Make a resizable hash map within a new shared memory area with the given name.
pub fn new_resizeable_named(num_buckets: u32, max_buckets: u32, name: &str) -> Self {
let size = Self::estimate_size(num_buckets);
let max_size = Self::estimate_size(max_buckets);
let shmem =
ShmemHandle::new(name, size, max_size).expect("failed to make shared memory area");
let ptr = shmem.data_ptr.as_ptr().cast();
Self::new(
num_buckets,
Some(shmem),
ptr,
size,
rustc_hash::FxBuildHasher,
)
}
/// Make a resizable hash map within a new anonymous shared memory area.
pub fn new_resizeable(num_buckets: u32, max_buckets: u32) -> Self {
use std::sync::atomic::{AtomicUsize, Ordering};
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let val = COUNTER.fetch_add(1, Ordering::Relaxed);
let name = format!("neon_shmem_hmap{val}");
Self::new_resizeable_named(num_buckets, max_buckets, &name)
}
}
impl<'a, K, V, S: BuildHasher> HashMapAccess<'a, K, V, S>
where
K: Clone + Hash + Eq,
{
/// Hash a key using the map's hasher.
#[inline]
fn get_hash_value(&self, key: &K) -> u64 {
self.hasher.hash_one(key)
}
fn entry_with_hash(&self, key: K, hash: u64) -> Entry<'a, '_, K, V> {
let mut map = unsafe { self.shared_ptr.as_ref() }.unwrap().write();
let dict_pos = hash as usize % map.dictionary.len();
let first = map.dictionary[dict_pos];
if first == INVALID_POS {
// no existing entry
return Entry::Vacant(VacantEntry {
map,
key,
dict_pos: dict_pos as u32,
});
}
let mut prev_pos = PrevPos::First(dict_pos as u32);
let mut next = first;
loop {
let bucket = &mut map.buckets[next as usize];
let (bucket_key, _bucket_value) = bucket.inner.as_mut().expect("entry is in use");
if *bucket_key == key {
// found existing entry
return Entry::Occupied(OccupiedEntry {
map,
_key: key,
prev_pos,
bucket_pos: next,
});
}
if bucket.next == INVALID_POS {
// No existing entry
return Entry::Vacant(VacantEntry {
map,
key,
dict_pos: dict_pos as u32,
});
}
prev_pos = PrevPos::Chained(next);
next = bucket.next;
}
}
/// Get a reference to the corresponding value for a key.
pub fn get<'e>(&'e self, key: &K) -> Option<ValueReadGuard<'e, V>> {
let hash = self.get_hash_value(key);
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
RwLockReadGuard::try_map(map, |m| m.get_with_hash(key, hash)).ok()
}
/// Get a reference to the entry containing a key.
///
/// NB: This takes a write lock as there's no way to distinguish whether the intention
/// is to use the entry for reading or for writing in advance.
pub fn entry(&self, key: K) -> Entry<'a, '_, K, V> {
let hash = self.get_hash_value(&key);
self.entry_with_hash(key, hash)
}
/// Remove a key given its hash. Returns the associated value if it existed.
pub fn remove(&self, key: &K) -> Option<V> {
let hash = self.get_hash_value(key);
match self.entry_with_hash(key.clone(), hash) {
Entry::Occupied(e) => Some(e.remove()),
Entry::Vacant(_) => None,
}
}
/// Insert/update a key. Returns the previous associated value if it existed.
///
/// # Errors
/// Will return [`core::FullError`] if there is no more space left in the map.
pub fn insert(&self, key: K, value: V) -> Result<Option<V>, core::FullError> {
let hash = self.get_hash_value(&key);
match self.entry_with_hash(key.clone(), hash) {
Entry::Occupied(mut e) => Ok(Some(e.insert(value))),
Entry::Vacant(e) => {
_ = e.insert(value)?;
Ok(None)
}
}
}
/// Optionally return the entry for a bucket at a given index if it exists.
///
/// Has more overhead than one would intuitively expect: performs both a clone of the key
/// due to the [`OccupiedEntry`] type owning the key and also a hash of the key in order
/// to enable repairing the hash chain if the entry is removed.
pub fn entry_at_bucket(&self, pos: usize) -> Option<OccupiedEntry<'a, '_, K, V>> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
if pos >= map.buckets.len() {
return None;
}
let entry = map.buckets[pos].inner.as_ref();
match entry {
Some((key, _)) => Some(OccupiedEntry {
_key: key.clone(),
bucket_pos: pos as u32,
prev_pos: entry::PrevPos::Unknown(self.get_hash_value(key)),
map,
}),
_ => None,
}
}
/// Returns the number of buckets in the table.
pub fn get_num_buckets(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
map.get_num_buckets()
}
/// Return the key and value stored in bucket with given index. This can be used to
/// iterate through the hash map.
// TODO: An Iterator might be nicer. The communicator's clock algorithm needs to
// _slowly_ iterate through all buckets with its clock hand, without holding a lock.
// If we switch to an Iterator, it must not hold the lock.
pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<(K, V)>> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
if pos >= map.buckets.len() {
return None;
}
RwLockReadGuard::try_map(map, |m| m.buckets[pos].inner.as_ref()).ok()
}
/// Returns the index of the bucket a given value corresponds to.
pub fn get_bucket_for_value(&self, val_ptr: *const V) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
let origin = map.buckets.as_ptr();
let idx = (val_ptr as usize - origin as usize) / size_of::<Bucket<K, V>>();
assert!(idx < map.buckets.len());
idx
}
/// Returns the number of occupied buckets in the table.
pub fn get_num_buckets_in_use(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
map.buckets_in_use as usize
}
/// Clears all entries in a table. Does not reset any shrinking operations.
pub fn clear(&self) {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
map.clear();
}
/// Perform an in-place rehash of some region (0..`rehash_buckets`) of the table and reset
/// the `buckets` and `dictionary` slices to be as long as `num_buckets`. Resets the freelist
/// in the process.
fn rehash_dict(
&self,
inner: &mut CoreHashMap<'a, K, V>,
buckets_ptr: *mut core::Bucket<K, V>,
end_ptr: *mut u8,
num_buckets: u32,
rehash_buckets: u32,
) {
inner.free_head = INVALID_POS;
let buckets;
let dictionary;
unsafe {
let buckets_end_ptr = buckets_ptr.add(num_buckets as usize);
let dictionary_ptr: *mut u32 = buckets_end_ptr
.byte_add(buckets_end_ptr.align_offset(align_of::<u32>()))
.cast();
let dictionary_size: usize =
end_ptr.byte_offset_from(buckets_end_ptr) as usize / size_of::<u32>();
buckets = std::slice::from_raw_parts_mut(buckets_ptr, num_buckets as usize);
dictionary = std::slice::from_raw_parts_mut(dictionary_ptr, dictionary_size);
}
for e in dictionary.iter_mut() {
*e = INVALID_POS;
}
for (i, bucket) in buckets.iter_mut().enumerate().take(rehash_buckets as usize) {
if bucket.inner.is_none() {
bucket.next = inner.free_head;
inner.free_head = i as u32;
continue;
}
let hash = self.hasher.hash_one(&bucket.inner.as_ref().unwrap().0);
let pos: usize = (hash % dictionary.len() as u64) as usize;
bucket.next = dictionary[pos];
dictionary[pos] = i as u32;
}
inner.dictionary = dictionary;
inner.buckets = buckets;
}
/// Rehash the map without growing or shrinking.
pub fn shuffle(&self) {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
let num_buckets = map.get_num_buckets() as u32;
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
let end_ptr: *mut u8 = unsafe { self.shared_ptr.byte_add(size_bytes).cast() };
let buckets_ptr = map.buckets.as_mut_ptr();
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets);
}
/// Grow the number of buckets within the table.
///
/// 1. Grows the underlying shared memory area
/// 2. Initializes new buckets and overwrites the current dictionary
/// 3. Rehashes the dictionary
///
/// # Panics
/// Panics if called on a map initialized with [`HashMapInit::with_fixed`].
///
/// # Errors
/// Returns an [`shmem::Error`] if any errors occur resizing the memory region.
pub fn grow(&self, num_buckets: u32) -> Result<(), shmem::Error> {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
let old_num_buckets = map.buckets.len() as u32;
assert!(
num_buckets >= old_num_buckets,
"grow called with a smaller number of buckets"
);
if num_buckets == old_num_buckets {
return Ok(());
}
let shmem_handle = self
.shmem_handle
.as_ref()
.expect("grow called on a fixed-size hash table");
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
shmem_handle.set_size(size_bytes)?;
let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
// Initialize new buckets. The new buckets are linked to the free list.
// NB: This overwrites the dictionary!
let buckets_ptr = map.buckets.as_mut_ptr();
unsafe {
for i in old_num_buckets..num_buckets {
let bucket = buckets_ptr.add(i as usize);
bucket.write(core::Bucket {
next: if i < num_buckets - 1 {
i + 1
} else {
map.free_head
},
inner: None,
});
}
}
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, old_num_buckets);
map.free_head = old_num_buckets;
Ok(())
}
/// Begin a shrink, limiting all new allocations to be in buckets with index below `num_buckets`.
///
/// # Panics
/// Panics if called on a map initialized with [`HashMapInit::with_fixed`] or if `num_buckets` is
/// greater than the number of buckets in the map.
pub fn begin_shrink(&self, num_buckets: u32) {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
assert!(
num_buckets <= map.get_num_buckets() as u32,
"shrink called with a larger number of buckets"
);
_ = self
.shmem_handle
.as_ref()
.expect("shrink called on a fixed-size hash table");
map.alloc_limit = num_buckets;
}
/// If a shrink operation is underway, returns the target size of the map. Otherwise, returns None.
pub fn shrink_goal(&self) -> Option<usize> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap().read();
let goal = map.alloc_limit;
if goal == INVALID_POS {
None
} else {
Some(goal as usize)
}
}
/// Complete a shrink after caller has evicted entries, removing the unused buckets and rehashing.
///
/// # Panics
/// The following cases result in a panic:
/// - Calling this function on a map initialized with [`HashMapInit::with_fixed`].
/// - Calling this function on a map when no shrink operation is in progress.
pub fn finish_shrink(&self) -> Result<(), HashMapShrinkError> {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
assert!(
map.alloc_limit != INVALID_POS,
"called finish_shrink when no shrink is in progress"
);
let num_buckets = map.alloc_limit;
if map.get_num_buckets() == num_buckets as usize {
return Ok(());
}
assert!(
map.buckets_in_use <= num_buckets,
"called finish_shrink before enough entries were removed"
);
for i in (num_buckets as usize)..map.buckets.len() {
if map.buckets[i].inner.is_some() {
return Err(HashMapShrinkError::RemainingEntries(i));
}
}
let shmem_handle = self
.shmem_handle
.as_ref()
.expect("shrink called on a fixed-size hash table");
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
if let Err(e) = shmem_handle.set_size(size_bytes) {
return Err(HashMapShrinkError::ResizeError(e));
}
let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
let buckets_ptr = map.buckets.as_mut_ptr();
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets);
map.alloc_limit = INVALID_POS;
Ok(())
}
}

View File

@@ -1,204 +0,0 @@
//! Simple hash table with chaining.
use std::fmt::Debug;
use std::hash::Hash;
use std::mem::MaybeUninit;
use crate::hash::entry::*;
/// Invalid position within the map (either within the dictionary or bucket array).
pub(crate) const INVALID_POS: u32 = u32::MAX;
/// Fundamental storage unit within the hash table. Either empty or contains a key-value pair.
/// Always part of a chain of some kind (either a freelist if empty or a hash chain if full).
pub(crate) struct Bucket<K, V> {
/// Index of next bucket in the chain.
pub(crate) next: u32,
/// Key-value pair contained within bucket.
pub(crate) inner: Option<(K, V)>,
}
impl<K, V> Debug for Bucket<K, V>
where
K: Debug,
V: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Bucket")
.field("next", &self.next)
.field("inner", &self.inner)
.finish()
}
}
/// Core hash table implementation.
pub(crate) struct CoreHashMap<'a, K, V> {
/// Dictionary used to map hashes to bucket indices.
pub(crate) dictionary: &'a mut [u32],
/// Buckets containing key-value pairs.
pub(crate) buckets: &'a mut [Bucket<K, V>],
/// Head of the freelist.
pub(crate) free_head: u32,
/// Maximum index of a bucket allowed to be allocated. [`INVALID_POS`] if no limit.
pub(crate) alloc_limit: u32,
/// The number of currently occupied buckets.
pub(crate) buckets_in_use: u32,
}
impl<'a, K, V> Debug for CoreHashMap<'a, K, V>
where
K: Debug,
V: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CoreHashMap")
.field("dictionary", &self.dictionary)
.field("buckets", &self.buckets)
.field("free_head", &self.free_head)
.field("alloc_limit", &self.alloc_limit)
.field("buckets_in_use", &self.buckets_in_use)
.finish()
}
}
/// Error for when there are no empty buckets left but one is needed.
#[derive(Debug, PartialEq)]
pub struct FullError;
impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
const FILL_FACTOR: f32 = 0.60;
/// Estimate the size of data contained within the the hash map.
pub fn estimate_size(num_buckets: u32) -> usize {
let mut size = 0;
// buckets
size += size_of::<Bucket<K, V>>() * num_buckets as usize;
// dictionary
size += (f32::ceil((size_of::<u32>() * num_buckets as usize) as f32 / Self::FILL_FACTOR))
as usize;
size
}
pub fn new(
buckets: &'a mut [MaybeUninit<Bucket<K, V>>],
dictionary: &'a mut [MaybeUninit<u32>],
) -> Self {
// Initialize the buckets
for i in 0..buckets.len() {
buckets[i].write(Bucket {
next: if i < buckets.len() - 1 {
i as u32 + 1
} else {
INVALID_POS
},
inner: None,
});
}
// Initialize the dictionary
for e in dictionary.iter_mut() {
e.write(INVALID_POS);
}
// TODO: use std::slice::assume_init_mut() once it stabilizes
let buckets =
unsafe { std::slice::from_raw_parts_mut(buckets.as_mut_ptr().cast(), buckets.len()) };
let dictionary = unsafe {
std::slice::from_raw_parts_mut(dictionary.as_mut_ptr().cast(), dictionary.len())
};
Self {
dictionary,
buckets,
free_head: 0,
buckets_in_use: 0,
alloc_limit: INVALID_POS,
}
}
/// Get the value associated with a key (if it exists) given its hash.
pub fn get_with_hash(&self, key: &K, hash: u64) -> Option<&V> {
let mut next = self.dictionary[hash as usize % self.dictionary.len()];
loop {
if next == INVALID_POS {
return None;
}
let bucket = &self.buckets[next as usize];
let (bucket_key, bucket_value) = bucket.inner.as_ref().expect("entry is in use");
if bucket_key == key {
return Some(bucket_value);
}
next = bucket.next;
}
}
/// Get number of buckets in map.
pub fn get_num_buckets(&self) -> usize {
self.buckets.len()
}
/// Clears all entries from the hashmap.
///
/// Does not reset any allocation limits, but does clear any entries beyond them.
pub fn clear(&mut self) {
for i in 0..self.buckets.len() {
self.buckets[i] = Bucket {
next: if i < self.buckets.len() - 1 {
i as u32 + 1
} else {
INVALID_POS
},
inner: None,
}
}
for i in 0..self.dictionary.len() {
self.dictionary[i] = INVALID_POS;
}
self.free_head = 0;
self.buckets_in_use = 0;
}
/// Find the position of an unused bucket via the freelist and initialize it.
pub(crate) fn alloc_bucket(&mut self, key: K, value: V) -> Result<u32, FullError> {
let mut pos = self.free_head;
// Find the first bucket we're *allowed* to use.
let mut prev = PrevPos::First(self.free_head);
while pos != INVALID_POS && pos >= self.alloc_limit {
let bucket = &mut self.buckets[pos as usize];
prev = PrevPos::Chained(pos);
pos = bucket.next;
}
if pos == INVALID_POS {
return Err(FullError);
}
// Repair the freelist.
match prev {
PrevPos::First(_) => {
let next_pos = self.buckets[pos as usize].next;
self.free_head = next_pos;
}
PrevPos::Chained(p) => {
if p != INVALID_POS {
let next_pos = self.buckets[pos as usize].next;
self.buckets[p as usize].next = next_pos;
}
}
_ => unreachable!(),
}
// Initialize the bucket.
let bucket = &mut self.buckets[pos as usize];
self.buckets_in_use += 1;
bucket.next = INVALID_POS;
bucket.inner = Some((key, value));
Ok(pos)
}
}

View File

@@ -1,130 +0,0 @@
//! Equivalent of [`std::collections::hash_map::Entry`] for this hashmap.
use crate::hash::core::{CoreHashMap, FullError, INVALID_POS};
use crate::sync::{RwLockWriteGuard, ValueWriteGuard};
use std::hash::Hash;
use std::mem;
pub enum Entry<'a, 'b, K, V> {
Occupied(OccupiedEntry<'a, 'b, K, V>),
Vacant(VacantEntry<'a, 'b, K, V>),
}
/// Enum representing the previous position within a chain.
#[derive(Clone, Copy)]
pub(crate) enum PrevPos {
/// Starting index within the dictionary.
First(u32),
/// Regular index within the buckets.
Chained(u32),
/// Unknown - e.g. the associated entry was retrieved by index instead of chain.
Unknown(u64),
}
pub struct OccupiedEntry<'a, 'b, K, V> {
/// Mutable reference to the map containing this entry.
pub(crate) map: RwLockWriteGuard<'b, CoreHashMap<'a, K, V>>,
/// The key of the occupied entry
pub(crate) _key: K,
/// The index of the previous entry in the chain.
pub(crate) prev_pos: PrevPos,
/// The position of the bucket in the [`CoreHashMap`] bucket array.
pub(crate) bucket_pos: u32,
}
impl<K, V> OccupiedEntry<'_, '_, K, V> {
pub fn get(&self) -> &V {
&self.map.buckets[self.bucket_pos as usize]
.inner
.as_ref()
.unwrap()
.1
}
pub fn get_mut(&mut self) -> &mut V {
&mut self.map.buckets[self.bucket_pos as usize]
.inner
.as_mut()
.unwrap()
.1
}
/// Inserts a value into the entry, replacing (and returning) the existing value.
pub fn insert(&mut self, value: V) -> V {
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
// This assumes inner is Some, which it must be for an OccupiedEntry
mem::replace(&mut bucket.inner.as_mut().unwrap().1, value)
}
/// Removes the entry from the hash map, returning the value originally stored within it.
///
/// This may result in multiple bucket accesses if the entry was obtained by index as the
/// previous chain entry needs to be discovered in this case.
pub fn remove(mut self) -> V {
// If this bucket was queried by index, go ahead and follow its chain from the start.
let prev = if let PrevPos::Unknown(hash) = self.prev_pos {
let dict_idx = hash as usize % self.map.dictionary.len();
let mut prev = PrevPos::First(dict_idx as u32);
let mut curr = self.map.dictionary[dict_idx];
while curr != self.bucket_pos {
assert!(curr != INVALID_POS);
prev = PrevPos::Chained(curr);
curr = self.map.buckets[curr as usize].next;
}
prev
} else {
self.prev_pos
};
// CoreHashMap::remove returns Option<(K, V)>. We know it's Some for an OccupiedEntry.
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
// unlink it from the chain
match prev {
PrevPos::First(dict_pos) => {
self.map.dictionary[dict_pos as usize] = bucket.next;
}
PrevPos::Chained(bucket_pos) => {
self.map.buckets[bucket_pos as usize].next = bucket.next;
}
_ => unreachable!(),
}
// and add it to the freelist
let free = self.map.free_head;
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
let old_value = bucket.inner.take();
bucket.next = free;
self.map.free_head = self.bucket_pos;
self.map.buckets_in_use -= 1;
old_value.unwrap().1
}
}
/// An abstract view into a vacant entry within the map.
pub struct VacantEntry<'a, 'b, K, V> {
/// Mutable reference to the map containing this entry.
pub(crate) map: RwLockWriteGuard<'b, CoreHashMap<'a, K, V>>,
/// The key to be inserted into this entry.
pub(crate) key: K,
/// The position within the dictionary corresponding to the key's hash.
pub(crate) dict_pos: u32,
}
impl<'b, K: Clone + Hash + Eq, V> VacantEntry<'_, 'b, K, V> {
/// Insert a value into the vacant entry, finding and populating an empty bucket in the process.
///
/// # Errors
/// Will return [`FullError`] if there are no unoccupied buckets in the map.
pub fn insert(mut self, value: V) -> Result<ValueWriteGuard<'b, V>, FullError> {
let pos = self.map.alloc_bucket(self.key, value)?;
self.map.buckets[pos as usize].next = self.map.dictionary[self.dict_pos as usize];
self.map.dictionary[self.dict_pos as usize] = pos;
Ok(RwLockWriteGuard::map(self.map, |m| {
&mut m.buckets[pos as usize].inner.as_mut().unwrap().1
}))
}
}

View File

@@ -1,428 +0,0 @@
use std::collections::BTreeMap;
use std::collections::HashSet;
use std::fmt::Debug;
use std::mem::MaybeUninit;
use crate::hash::Entry;
use crate::hash::HashMapAccess;
use crate::hash::HashMapInit;
use crate::hash::core::FullError;
use rand::seq::SliceRandom;
use rand::{Rng, RngCore};
use rand_distr::Zipf;
const TEST_KEY_LEN: usize = 16;
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct TestKey([u8; TEST_KEY_LEN]);
impl From<&TestKey> for u128 {
fn from(val: &TestKey) -> u128 {
u128::from_be_bytes(val.0)
}
}
impl From<u128> for TestKey {
fn from(val: u128) -> TestKey {
TestKey(val.to_be_bytes())
}
}
impl<'a> From<&'a [u8]> for TestKey {
fn from(bytes: &'a [u8]) -> TestKey {
TestKey(bytes.try_into().unwrap())
}
}
fn test_inserts<K: Into<TestKey> + Copy>(keys: &[K]) {
let w = HashMapInit::<TestKey, usize>::new_resizeable_named(100000, 120000, "test_inserts")
.attach_writer();
for (idx, k) in keys.iter().enumerate() {
let res = w.entry((*k).into());
match res {
Entry::Occupied(mut e) => {
e.insert(idx);
}
Entry::Vacant(e) => {
let res = e.insert(idx);
assert!(res.is_ok());
}
};
}
for (idx, k) in keys.iter().enumerate() {
let x = w.get(&(*k).into());
let value = x.as_deref().copied();
assert_eq!(value, Some(idx));
}
}
#[test]
fn dense() {
// This exercises splitting a node with prefix
let keys: &[u128] = &[0, 1, 2, 3, 256];
test_inserts(keys);
// Dense keys
let mut keys: Vec<u128> = (0..10000).collect();
test_inserts(&keys);
// Do the same in random orders
for _ in 1..10 {
keys.shuffle(&mut rand::rng());
test_inserts(&keys);
}
}
#[test]
fn sparse() {
// sparse keys
let mut keys: Vec<TestKey> = Vec::new();
let mut used_keys = HashSet::new();
for _ in 0..10000 {
loop {
let key = rand::random::<u128>();
if used_keys.contains(&key) {
continue;
}
used_keys.insert(key);
keys.push(key.into());
break;
}
}
test_inserts(&keys);
}
#[derive(Clone, Debug)]
struct TestOp(TestKey, Option<usize>);
fn apply_op(
op: &TestOp,
map: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
) {
// apply the change to the shadow tree first
let shadow_existing = if let Some(v) = op.1 {
shadow.insert(op.0, v)
} else {
shadow.remove(&op.0)
};
let entry = map.entry(op.0);
let hash_existing = match op.1 {
Some(new) => match entry {
Entry::Occupied(mut e) => Some(e.insert(new)),
Entry::Vacant(e) => {
_ = e.insert(new).unwrap();
None
}
},
None => match entry {
Entry::Occupied(e) => Some(e.remove()),
Entry::Vacant(_) => None,
},
};
assert_eq!(shadow_existing, hash_existing);
}
fn do_random_ops(
num_ops: usize,
size: u32,
del_prob: f64,
writer: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
rng: &mut rand::rngs::ThreadRng,
) {
for i in 0..num_ops {
let key: TestKey = ((rng.next_u32() % size) as u128).into();
let op = TestOp(
key,
if rng.random_bool(del_prob) {
Some(i)
} else {
None
},
);
apply_op(&op, writer, shadow);
}
}
fn do_deletes(
num_ops: usize,
writer: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
) {
for _ in 0..num_ops {
let (k, _) = shadow.pop_first().unwrap();
writer.remove(&k);
}
}
fn do_shrink(
writer: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
from: u32,
to: u32,
) {
assert!(writer.shrink_goal().is_none());
writer.begin_shrink(to);
assert_eq!(writer.shrink_goal(), Some(to as usize));
for i in to..from {
if let Some(entry) = writer.entry_at_bucket(i as usize) {
shadow.remove(&entry._key);
entry.remove();
}
}
let old_usage = writer.get_num_buckets_in_use();
writer.finish_shrink().unwrap();
assert!(writer.shrink_goal().is_none());
assert_eq!(writer.get_num_buckets_in_use(), old_usage);
}
#[test]
fn random_ops() {
let mut writer =
HashMapInit::<TestKey, usize>::new_resizeable_named(100000, 120000, "test_random")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let distribution = Zipf::new(u128::MAX as f64, 1.1).unwrap();
let mut rng = rand::rng();
for i in 0..100000 {
let key: TestKey = (rng.sample(distribution) as u128).into();
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &mut writer, &mut shadow);
}
}
#[test]
fn test_shuffle() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 1200, "test_shuf")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);
writer.shuffle();
do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);
}
#[test]
fn test_grow() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 2000, "test_grow")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);
let old_usage = writer.get_num_buckets_in_use();
writer.grow(1500).unwrap();
assert_eq!(writer.get_num_buckets_in_use(), old_usage);
assert_eq!(writer.get_num_buckets(), 1500);
do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
}
#[test]
fn test_clear() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_clear")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(2000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
writer.clear();
assert_eq!(writer.get_num_buckets_in_use(), 0);
assert_eq!(writer.get_num_buckets(), 1500);
while let Some((key, _)) = shadow.pop_first() {
assert!(writer.get(&key).is_none());
}
do_random_ops(2000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
for i in 0..(1500 - writer.get_num_buckets_in_use()) {
writer.insert((1500 + i as u128).into(), 0).unwrap();
}
assert_eq!(writer.insert(5000.into(), 0), Err(FullError {}));
writer.clear();
assert!(writer.insert(5000.into(), 0).is_ok());
}
#[test]
fn test_idx_remove() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_clear")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(2000, 1500, 0.25, &mut writer, &mut shadow, &mut rng);
for _ in 0..100 {
let idx = (rng.next_u32() % 1500) as usize;
if let Some(e) = writer.entry_at_bucket(idx) {
shadow.remove(&e._key);
e.remove();
}
}
while let Some((key, val)) = shadow.pop_first() {
assert_eq!(*writer.get(&key).unwrap(), val);
}
}
#[test]
fn test_idx_get() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_clear")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(2000, 1500, 0.25, &mut writer, &mut shadow, &mut rng);
for _ in 0..100 {
let idx = (rng.next_u32() % 1500) as usize;
if let Some(pair) = writer.get_at_bucket(idx) {
{
let v: *const usize = &pair.1;
assert_eq!(writer.get_bucket_for_value(v), idx);
}
{
let v: *const usize = &pair.1;
assert_eq!(writer.get_bucket_for_value(v), idx);
}
}
}
}
#[test]
fn test_shrink() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_shrink")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
do_shrink(&mut writer, &mut shadow, 1500, 1000);
assert_eq!(writer.get_num_buckets(), 1000);
do_deletes(500, &mut writer, &mut shadow);
do_random_ops(10000, 500, 0.75, &mut writer, &mut shadow, &mut rng);
assert!(writer.get_num_buckets_in_use() <= 1000);
}
#[test]
fn test_shrink_grow_seq() {
let mut writer =
HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 20000, "test_grow_seq")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(500, 1000, 0.1, &mut writer, &mut shadow, &mut rng);
eprintln!("Shrinking to 750");
do_shrink(&mut writer, &mut shadow, 1000, 750);
do_random_ops(200, 1000, 0.5, &mut writer, &mut shadow, &mut rng);
eprintln!("Growing to 1500");
writer.grow(1500).unwrap();
do_random_ops(600, 1500, 0.1, &mut writer, &mut shadow, &mut rng);
eprintln!("Shrinking to 200");
while shadow.len() > 100 {
do_deletes(1, &mut writer, &mut shadow);
}
do_shrink(&mut writer, &mut shadow, 1500, 200);
do_random_ops(50, 1500, 0.25, &mut writer, &mut shadow, &mut rng);
eprintln!("Growing to 10k");
writer.grow(10000).unwrap();
do_random_ops(10000, 5000, 0.25, &mut writer, &mut shadow, &mut rng);
}
#[test]
fn test_bucket_ops() {
let writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 1200, "test_bucket_ops")
.attach_writer();
match writer.entry(1.into()) {
Entry::Occupied(mut e) => {
e.insert(2);
}
Entry::Vacant(e) => {
_ = e.insert(2).unwrap();
}
}
assert_eq!(writer.get_num_buckets_in_use(), 1);
assert_eq!(writer.get_num_buckets(), 1000);
assert_eq!(*writer.get(&1.into()).unwrap(), 2);
let pos = match writer.entry(1.into()) {
Entry::Occupied(e) => {
assert_eq!(e._key, 1.into());
e.bucket_pos as usize
}
Entry::Vacant(_) => {
panic!("Insert didn't affect entry");
}
};
assert_eq!(writer.entry_at_bucket(pos).unwrap()._key, 1.into());
assert_eq!(*writer.get_at_bucket(pos).unwrap(), (1.into(), 2));
{
let ptr: *const usize = &*writer.get(&1.into()).unwrap();
assert_eq!(writer.get_bucket_for_value(ptr), pos);
}
writer.remove(&1.into());
assert!(writer.get(&1.into()).is_none());
}
#[test]
fn test_shrink_zero() {
let mut writer =
HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_shrink_zero")
.attach_writer();
writer.begin_shrink(0);
for i in 0..1500 {
writer.entry_at_bucket(i).map(|x| x.remove());
}
writer.finish_shrink().unwrap();
assert_eq!(writer.get_num_buckets_in_use(), 0);
let entry = writer.entry(1.into());
if let Entry::Vacant(v) = entry {
assert!(v.insert(2).is_err());
} else {
panic!("Somehow got non-vacant entry in empty map.")
}
writer.grow(50).unwrap();
let entry = writer.entry(1.into());
if let Entry::Vacant(v) = entry {
assert!(v.insert(2).is_ok());
} else {
panic!("Somehow got non-vacant entry in empty map.")
}
assert_eq!(writer.get_num_buckets_in_use(), 1);
}
#[test]
#[should_panic]
fn test_grow_oom() {
let writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_grow_oom")
.attach_writer();
writer.grow(20000).unwrap();
}
#[test]
#[should_panic]
fn test_shrink_bigger() {
let mut writer =
HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2500, "test_shrink_bigger")
.attach_writer();
writer.begin_shrink(2000);
}
#[test]
#[should_panic]
fn test_shrink_early_finish() {
let writer =
HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2500, "test_shrink_early_finish")
.attach_writer();
writer.finish_shrink().unwrap();
}
#[test]
#[should_panic]
fn test_shrink_fixed_size() {
let mut area = [MaybeUninit::uninit(); 10000];
let init_struct = HashMapInit::<TestKey, usize>::with_fixed(3, &mut area);
let mut writer = init_struct.attach_writer();
writer.begin_shrink(1);
}

View File

@@ -1,3 +1,418 @@
pub mod hash;
pub mod shmem;
pub mod sync;
//! Shared memory utilities for neon communicator
use std::num::NonZeroUsize;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use nix::errno::Errno;
use nix::sys::mman::MapFlags;
use nix::sys::mman::ProtFlags;
use nix::sys::mman::mmap as nix_mmap;
use nix::sys::mman::munmap as nix_munmap;
use nix::unistd::ftruncate as nix_ftruncate;
/// ShmemHandle represents a shared memory area that can be shared by processes over fork().
/// Unlike shared memory allocated by Postgres, this area is resizable, up to 'max_size' that's
/// specified at creation.
///
/// The area is backed by an anonymous file created with memfd_create(). The full address space for
/// 'max_size' is reserved up-front with mmap(), but whenever you call [`ShmemHandle::set_size`],
/// the underlying file is resized. Do not access the area beyond the current size. Currently, that
/// will cause the file to be expanded, but we might use mprotect() etc. to enforce that in the
/// future.
pub struct ShmemHandle {
/// memfd file descriptor
fd: OwnedFd,
max_size: usize,
// Pointer to the beginning of the shared memory area. The header is stored there.
shared_ptr: NonNull<SharedStruct>,
// Pointer to the beginning of the user data
pub data_ptr: NonNull<u8>,
}
/// This is stored at the beginning in the shared memory area.
struct SharedStruct {
max_size: usize,
/// Current size of the backing file. The high-order bit is used for the RESIZE_IN_PROGRESS flag
current_size: AtomicUsize,
}
const RESIZE_IN_PROGRESS: usize = 1 << 63;
const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();
/// Error type returned by the ShmemHandle functions.
#[derive(thiserror::Error, Debug)]
#[error("{msg}: {errno}")]
pub struct Error {
pub msg: String,
pub errno: Errno,
}
impl Error {
fn new(msg: &str, errno: Errno) -> Error {
Error {
msg: msg.to_string(),
errno,
}
}
}
impl ShmemHandle {
/// Create a new shared memory area. To communicate between processes, the processes need to be
/// fork()'d after calling this, so that the ShmemHandle is inherited by all processes.
///
/// If the ShmemHandle is dropped, the memory is unmapped from the current process. Other
/// processes can continue using it, however.
pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<ShmemHandle, Error> {
// create the backing anonymous file.
let fd = create_backing_file(name)?;
Self::new_with_fd(fd, initial_size, max_size)
}
fn new_with_fd(
fd: OwnedFd,
initial_size: usize,
max_size: usize,
) -> Result<ShmemHandle, Error> {
// We reserve the high-order bit for the RESIZE_IN_PROGRESS flag, and the actual size
// is a little larger than this because of the SharedStruct header. Make the upper limit
// somewhat smaller than that, because with anything close to that, you'll run out of
// memory anyway.
if max_size >= 1 << 48 {
panic!("max size {max_size} too large");
}
if initial_size > max_size {
panic!("initial size {initial_size} larger than max size {max_size}");
}
// The actual initial / max size is the one given by the caller, plus the size of
// 'SharedStruct'.
let initial_size = HEADER_SIZE + initial_size;
let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();
// Reserve address space for it with mmap
//
// TODO: Use MAP_HUGETLB if possible
let start_ptr = unsafe {
nix_mmap(
None,
max_size,
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
MapFlags::MAP_SHARED,
&fd,
0,
)
}
.map_err(|e| Error::new("mmap failed: {e}", e))?;
// Reserve space for the initial size
enlarge_file(fd.as_fd(), initial_size as u64)?;
// Initialize the header
let shared: NonNull<SharedStruct> = start_ptr.cast();
unsafe {
shared.write(SharedStruct {
max_size: max_size.into(),
current_size: AtomicUsize::new(initial_size),
})
};
// The user data begins after the header
let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };
Ok(ShmemHandle {
fd,
max_size: max_size.into(),
shared_ptr: shared,
data_ptr,
})
}
// return reference to the header
fn shared(&self) -> &SharedStruct {
unsafe { self.shared_ptr.as_ref() }
}
/// Resize the shared memory area. 'new_size' must not be larger than the 'max_size' specified
/// when creating the area.
///
/// This may only be called from one process/thread concurrently. We detect that case
/// and return an Error.
pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
let new_size = new_size + HEADER_SIZE;
let shared = self.shared();
if new_size > self.max_size {
panic!(
"new size ({} is greater than max size ({})",
new_size, self.max_size
);
}
assert_eq!(self.max_size, shared.max_size);
// Lock the area by setting the bit in 'current_size'
//
// Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
// and the posix_fallocate/ftruncate call is surely a synchronization point anyway. But
// since this is not performance-critical, better safe than sorry .
let mut old_size = shared.current_size.load(Ordering::Acquire);
loop {
if (old_size & RESIZE_IN_PROGRESS) != 0 {
return Err(Error::new(
"concurrent resize detected",
Errno::UnknownErrno,
));
}
match shared.current_size.compare_exchange(
old_size,
new_size,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(x) => old_size = x,
}
}
// Ok, we got the lock.
//
// NB: If anything goes wrong, we *must* clear the bit!
let result = {
use std::cmp::Ordering::{Equal, Greater, Less};
match new_size.cmp(&old_size) {
Less => nix_ftruncate(&self.fd, new_size as i64).map_err(|e| {
Error::new("could not shrink shmem segment, ftruncate failed: {e}", e)
}),
Equal => Ok(()),
Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
}
};
// Unlock
shared.current_size.store(
if result.is_ok() { new_size } else { old_size },
Ordering::Release,
);
result
}
/// Returns the current user-visible size of the shared memory segment.
///
/// NOTE: a concurrent set_size() call can change the size at any time. It is the caller's
/// responsibility not to access the area beyond the current size.
pub fn current_size(&self) -> usize {
let total_current_size =
self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
total_current_size - HEADER_SIZE
}
}
impl Drop for ShmemHandle {
fn drop(&mut self) {
// SAFETY: The pointer was obtained from mmap() with the given size.
// We unmap the entire region.
let _ = unsafe { nix_munmap(self.shared_ptr.cast(), self.max_size) };
// The fd is dropped automatically by OwnedFd.
}
}
/// Create a "backing file" for the shared memory area. On Linux, use memfd_create(), to create an
/// anonymous in-memory file. One macos, fall back to a regular file. That's good enough for
/// development and testing, but in production we want the file to stay in memory.
///
/// disable 'unused_variables' warnings, because in the macos path, 'name' is unused.
#[allow(unused_variables)]
fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
#[cfg(not(target_os = "macos"))]
{
nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
.map_err(|e| Error::new("memfd_create failed: {e}", e))
}
#[cfg(target_os = "macos")]
{
let file = tempfile::tempfile().map_err(|e| {
Error::new(
"could not create temporary file to back shmem area: {e}",
nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
)
})?;
Ok(OwnedFd::from(file))
}
}
fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
// Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
// we don't get a segfault later when trying to actually use it.
#[cfg(not(target_os = "macos"))]
{
nix::fcntl::posix_fallocate(fd, 0, size as i64).map_err(|e| {
Error::new(
"could not grow shmem segment, posix_fallocate failed: {e}",
e,
)
})
}
// As a fallback on macos, which doesn't have posix_fallocate, use plain 'fallocate'
#[cfg(target_os = "macos")]
{
nix::unistd::ftruncate(fd, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, ftruncate failed: {e}", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
use nix::unistd::ForkResult;
use std::ops::Range;
/// check that all bytes in given range have the expected value.
fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
for i in range {
let b = unsafe { *(ptr.add(i)) };
assert_eq!(expected, b, "unexpected byte at offset {i}");
}
}
/// Write 'b' to all bytes in the given range
fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
}
// simple single-process test of growing and shrinking
#[test]
fn test_shmem_resize() -> Result<(), Error> {
let max_size = 1024 * 1024;
let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;
assert_eq!(init_struct.current_size(), 0);
// Initial grow
let size1 = 10000;
init_struct.set_size(size1).unwrap();
assert_eq!(init_struct.current_size(), size1);
// Write some data
let data_ptr = init_struct.data_ptr.as_ptr();
write_range(data_ptr, 0xAA, 0..size1);
assert_range(data_ptr, 0xAA, 0..size1);
// Shrink
let size2 = 5000;
init_struct.set_size(size2).unwrap();
assert_eq!(init_struct.current_size(), size2);
// Grow again
let size3 = 20000;
init_struct.set_size(size3).unwrap();
assert_eq!(init_struct.current_size(), size3);
// Try to read it. The area that was shrunk and grown again should read as all zeros now
assert_range(data_ptr, 0xAA, 0..5000);
assert_range(data_ptr, 0, 5000..size1);
// Try to grow beyond max_size
//let size4 = max_size + 1;
//assert!(init_struct.set_size(size4).is_err());
// Dropping init_struct should unmap the memory
drop(init_struct);
Ok(())
}
/// This is used in tests to coordinate between test processes. It's like std::sync::Barrier,
/// but is stored in the shared memory area and works across processes. It's implemented by
/// polling, because e.g. standard rust mutexes are not guaranteed to work across processes.
struct SimpleBarrier {
num_procs: usize,
count: AtomicUsize,
}
impl SimpleBarrier {
unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
unsafe {
*ptr = SimpleBarrier {
num_procs,
count: AtomicUsize::new(0),
}
}
}
pub fn wait(&self) {
let old = self.count.fetch_add(1, Ordering::Relaxed);
let generation = old / self.num_procs;
let mut current = old + 1;
while current < (generation + 1) * self.num_procs {
std::thread::sleep(std::time::Duration::from_millis(10));
current = self.count.load(Ordering::Relaxed);
}
}
}
#[test]
fn test_multi_process() {
// Initialize
let max_size = 1_000_000_000_000;
let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
let ptr = init_struct.data_ptr.as_ptr();
// Store the SimpleBarrier in the first 1k of the area.
init_struct.set_size(10000).unwrap();
let barrier_ptr: *mut SimpleBarrier = unsafe {
ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
.cast()
};
unsafe { SimpleBarrier::init(barrier_ptr, 2) };
let barrier = unsafe { barrier_ptr.as_ref().unwrap() };
// Fork another test process. The code after this runs in both processes concurrently.
let fork_result = unsafe { nix::unistd::fork().unwrap() };
// In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
if fork_result.is_parent() {
write_range(ptr, 0xAA, 1000..2000);
} else {
write_range(ptr, 0xBB, 2000..3000);
}
barrier.wait();
// Verify the contents. (in both processes)
assert_range(ptr, 0xAA, 1000..2000);
assert_range(ptr, 0xBB, 2000..3000);
// Grow, from the child this time
let size = 10_000_000;
if !fork_result.is_parent() {
init_struct.set_size(size).unwrap();
}
barrier.wait();
// make some writes at the end
if fork_result.is_parent() {
write_range(ptr, 0xAA, (size - 10)..size);
} else {
write_range(ptr, 0xBB, (size - 20)..(size - 10));
}
barrier.wait();
// Verify the contents. (This runs in both processes)
assert_range(ptr, 0, (size - 1000)..(size - 20));
assert_range(ptr, 0xBB, (size - 20)..(size - 10));
assert_range(ptr, 0xAA, (size - 10)..size);
if let ForkResult::Parent { child } = fork_result {
nix::sys::wait::waitpid(child, None).unwrap();
}
}
}

View File

@@ -1,411 +0,0 @@
//! Dynamically resizable contiguous chunk of shared memory
use std::num::NonZeroUsize;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use nix::errno::Errno;
use nix::sys::mman::MapFlags;
use nix::sys::mman::ProtFlags;
use nix::sys::mman::mmap as nix_mmap;
use nix::sys::mman::munmap as nix_munmap;
use nix::unistd::ftruncate as nix_ftruncate;
/// `ShmemHandle` represents a shared memory area that can be shared by processes over `fork()`.
/// Unlike shared memory allocated by Postgres, this area is resizable, up to `max_size` that's
/// specified at creation.
///
/// The area is backed by an anonymous file created with `memfd_create()`. The full address space for
/// `max_size` is reserved up-front with `mmap()`, but whenever you call [`ShmemHandle::set_size`],
/// the underlying file is resized. Do not access the area beyond the current size. Currently, that
/// will cause the file to be expanded, but we might use `mprotect()` etc. to enforce that in the
/// future.
#[derive(Debug)]
pub struct ShmemHandle {
/// memfd file descriptor
fd: OwnedFd,
max_size: usize,
// Pointer to the beginning of the shared memory area. The header is stored there.
shared_ptr: NonNull<SharedStruct>,
// Pointer to the beginning of the user data
pub data_ptr: NonNull<u8>,
}
/// This is stored at the beginning in the shared memory area.
#[derive(Debug)]
struct SharedStruct {
max_size: usize,
/// Current size of the backing file. The high-order bit is used for the [`RESIZE_IN_PROGRESS`] flag.
current_size: AtomicUsize,
}
const RESIZE_IN_PROGRESS: usize = 1 << 63;
const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();
/// Error type returned by the [`ShmemHandle`] functions.
#[derive(thiserror::Error, Debug)]
#[error("{msg}: {errno}")]
pub struct Error {
pub msg: String,
pub errno: Errno,
}
impl Error {
fn new(msg: &str, errno: Errno) -> Self {
Self {
msg: msg.to_string(),
errno,
}
}
}
impl ShmemHandle {
/// Create a new shared memory area. To communicate between processes, the processes need to be
/// `fork()`'d after calling this, so that the `ShmemHandle` is inherited by all processes.
///
/// If the `ShmemHandle` is dropped, the memory is unmapped from the current process. Other
/// processes can continue using it, however.
pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<Self, Error> {
// create the backing anonymous file.
let fd = create_backing_file(name)?;
Self::new_with_fd(fd, initial_size, max_size)
}
fn new_with_fd(fd: OwnedFd, initial_size: usize, max_size: usize) -> Result<Self, Error> {
// We reserve the high-order bit for the `RESIZE_IN_PROGRESS` flag, and the actual size
// is a little larger than this because of the SharedStruct header. Make the upper limit
// somewhat smaller than that, because with anything close to that, you'll run out of
// memory anyway.
assert!(max_size < 1 << 48, "max size {max_size} too large");
assert!(
initial_size <= max_size,
"initial size {initial_size} larger than max size {max_size}"
);
// The actual initial / max size is the one given by the caller, plus the size of
// 'SharedStruct'.
let initial_size = HEADER_SIZE + initial_size;
let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();
// Reserve address space for it with mmap
//
// TODO: Use MAP_HUGETLB if possible
let start_ptr = unsafe {
nix_mmap(
None,
max_size,
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
MapFlags::MAP_SHARED,
&fd,
0,
)
}
.map_err(|e| Error::new("mmap failed", e))?;
// Reserve space for the initial size
enlarge_file(fd.as_fd(), initial_size as u64)?;
// Initialize the header
let shared: NonNull<SharedStruct> = start_ptr.cast();
unsafe {
shared.write(SharedStruct {
max_size: max_size.into(),
current_size: AtomicUsize::new(initial_size),
});
}
// The user data begins after the header
let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };
Ok(Self {
fd,
max_size: max_size.into(),
shared_ptr: shared,
data_ptr,
})
}
// return reference to the header
fn shared(&self) -> &SharedStruct {
unsafe { self.shared_ptr.as_ref() }
}
/// Resize the shared memory area. `new_size` must not be larger than the `max_size` specified
/// when creating the area.
///
/// This may only be called from one process/thread concurrently. We detect that case
/// and return an [`shmem::Error`](Error).
pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
let new_size = new_size + HEADER_SIZE;
let shared = self.shared();
assert!(
new_size <= self.max_size,
"new size ({new_size}) is greater than max size ({})",
self.max_size
);
assert_eq!(self.max_size, shared.max_size);
// Lock the area by setting the bit in `current_size`
//
// Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
// and the `posix_fallocate`/`ftruncate` call is surely a synchronization point anyway. But
// since this is not performance-critical, better safe than sorry.
let mut old_size = shared.current_size.load(Ordering::Acquire);
loop {
if (old_size & RESIZE_IN_PROGRESS) != 0 {
return Err(Error::new(
"concurrent resize detected",
Errno::UnknownErrno,
));
}
match shared.current_size.compare_exchange(
old_size,
new_size,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(x) => old_size = x,
}
}
// Ok, we got the lock.
//
// NB: If anything goes wrong, we *must* clear the bit!
let result = {
use std::cmp::Ordering::{Equal, Greater, Less};
match new_size.cmp(&old_size) {
Less => nix_ftruncate(&self.fd, new_size as i64)
.map_err(|e| Error::new("could not shrink shmem segment, ftruncate failed", e)),
Equal => Ok(()),
Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
}
};
// Unlock
shared.current_size.store(
if result.is_ok() { new_size } else { old_size },
Ordering::Release,
);
result
}
/// Returns the current user-visible size of the shared memory segment.
///
/// NOTE: a concurrent [`ShmemHandle::set_size()`] call can change the size at any time.
/// It is the caller's responsibility not to access the area beyond the current size.
pub fn current_size(&self) -> usize {
let total_current_size =
self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
total_current_size - HEADER_SIZE
}
}
impl Drop for ShmemHandle {
fn drop(&mut self) {
// SAFETY: The pointer was obtained from mmap() with the given size.
// We unmap the entire region.
let _ = unsafe { nix_munmap(self.shared_ptr.cast(), self.max_size) };
// The fd is dropped automatically by OwnedFd.
}
}
/// Create a "backing file" for the shared memory area. On Linux, use `memfd_create()`, to create an
/// anonymous in-memory file. One macos, fall back to a regular file. That's good enough for
/// development and testing, but in production we want the file to stay in memory.
///
/// Disable unused variables warnings because `name` is unused in the macos path.
#[allow(unused_variables)]
fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
#[cfg(not(target_os = "macos"))]
{
nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
.map_err(|e| Error::new("memfd_create failed", e))
}
#[cfg(target_os = "macos")]
{
let file = tempfile::tempfile().map_err(|e| {
Error::new(
"could not create temporary file to back shmem area",
nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
)
})?;
Ok(OwnedFd::from(file))
}
}
fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
// Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
// we don't get a segfault later when trying to actually use it.
#[cfg(not(target_os = "macos"))]
{
nix::fcntl::posix_fallocate(fd, 0, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, posix_fallocate failed", e))
}
// As a fallback on macos, which doesn't have posix_fallocate, use plain 'fallocate'
#[cfg(target_os = "macos")]
{
nix::unistd::ftruncate(fd, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, ftruncate failed", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
use nix::unistd::ForkResult;
use std::ops::Range;
/// check that all bytes in given range have the expected value.
fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
for i in range {
let b = unsafe { *(ptr.add(i)) };
assert_eq!(expected, b, "unexpected byte at offset {i}");
}
}
/// Write 'b' to all bytes in the given range
fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
}
// simple single-process test of growing and shrinking
#[test]
fn test_shmem_resize() -> Result<(), Error> {
let max_size = 1024 * 1024;
let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;
assert_eq!(init_struct.current_size(), 0);
// Initial grow
let size1 = 10000;
init_struct.set_size(size1).unwrap();
assert_eq!(init_struct.current_size(), size1);
// Write some data
let data_ptr = init_struct.data_ptr.as_ptr();
write_range(data_ptr, 0xAA, 0..size1);
assert_range(data_ptr, 0xAA, 0..size1);
// Shrink
let size2 = 5000;
init_struct.set_size(size2).unwrap();
assert_eq!(init_struct.current_size(), size2);
// Grow again
let size3 = 20000;
init_struct.set_size(size3).unwrap();
assert_eq!(init_struct.current_size(), size3);
// Try to read it. The area that was shrunk and grown again should read as all zeros now
assert_range(data_ptr, 0xAA, 0..5000);
assert_range(data_ptr, 0, 5000..size1);
// Try to grow beyond max_size
//let size4 = max_size + 1;
//assert!(init_struct.set_size(size4).is_err());
// Dropping init_struct should unmap the memory
drop(init_struct);
Ok(())
}
/// This is used in tests to coordinate between test processes. It's like `std::sync::Barrier`,
/// but is stored in the shared memory area and works across processes. It's implemented by
/// polling, because e.g. standard rust mutexes are not guaranteed to work across processes.
struct SimpleBarrier {
num_procs: usize,
count: AtomicUsize,
}
impl SimpleBarrier {
unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
unsafe {
*ptr = SimpleBarrier {
num_procs,
count: AtomicUsize::new(0),
}
}
}
pub fn wait(&self) {
let old = self.count.fetch_add(1, Ordering::Relaxed);
let generation = old / self.num_procs;
let mut current = old + 1;
while current < (generation + 1) * self.num_procs {
std::thread::sleep(std::time::Duration::from_millis(10));
current = self.count.load(Ordering::Relaxed);
}
}
}
#[test]
fn test_multi_process() {
// Initialize
let max_size = 1_000_000_000_000;
let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
let ptr = init_struct.data_ptr.as_ptr();
// Store the SimpleBarrier in the first 1k of the area.
init_struct.set_size(10000).unwrap();
let barrier_ptr: *mut SimpleBarrier = unsafe {
ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
.cast()
};
unsafe { SimpleBarrier::init(barrier_ptr, 2) };
let barrier = unsafe { barrier_ptr.as_ref().unwrap() };
// Fork another test process. The code after this runs in both processes concurrently.
let fork_result = unsafe { nix::unistd::fork().unwrap() };
// In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
if fork_result.is_parent() {
write_range(ptr, 0xAA, 1000..2000);
} else {
write_range(ptr, 0xBB, 2000..3000);
}
barrier.wait();
// Verify the contents. (in both processes)
assert_range(ptr, 0xAA, 1000..2000);
assert_range(ptr, 0xBB, 2000..3000);
// Grow, from the child this time
let size = 10_000_000;
if !fork_result.is_parent() {
init_struct.set_size(size).unwrap();
}
barrier.wait();
// make some writes at the end
if fork_result.is_parent() {
write_range(ptr, 0xAA, (size - 10)..size);
} else {
write_range(ptr, 0xBB, (size - 20)..(size - 10));
}
barrier.wait();
// Verify the contents. (This runs in both processes)
assert_range(ptr, 0, (size - 1000)..(size - 20));
assert_range(ptr, 0xBB, (size - 20)..(size - 10));
assert_range(ptr, 0xAA, (size - 10)..size);
if let ForkResult::Parent { child } = fork_result {
nix::sys::wait::waitpid(child, None).unwrap();
}
}
}

View File

@@ -1,111 +0,0 @@
//! Simple utilities akin to what's in [`std::sync`] but designed to work with shared memory.
use std::mem::MaybeUninit;
use std::ptr::NonNull;
use nix::errno::Errno;
pub type RwLock<T> = lock_api::RwLock<PthreadRwLock, T>;
pub type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, PthreadRwLock, T>;
pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, PthreadRwLock, T>;
pub type ValueReadGuard<'a, T> = lock_api::MappedRwLockReadGuard<'a, PthreadRwLock, T>;
pub type ValueWriteGuard<'a, T> = lock_api::MappedRwLockWriteGuard<'a, PthreadRwLock, T>;
/// Shared memory read-write lock.
pub struct PthreadRwLock(Option<NonNull<libc::pthread_rwlock_t>>);
/// Simple macro that calls a function in the libc namespace and panics if return value is nonzero.
macro_rules! libc_checked {
($fn_name:ident ( $($arg:expr),* )) => {{
let res = libc::$fn_name($($arg),*);
if res != 0 {
panic!("{} failed with {}", stringify!($fn_name), Errno::from_raw(res));
}
}};
}
impl PthreadRwLock {
/// Creates a new `PthreadRwLock` on top of a pointer to a pthread rwlock.
///
/// # Safety
/// `lock` must be non-null. Every unsafe operation will panic in the event of an error.
pub unsafe fn new(lock: *mut libc::pthread_rwlock_t) -> Self {
unsafe {
let mut attrs = MaybeUninit::uninit();
libc_checked!(pthread_rwlockattr_init(attrs.as_mut_ptr()));
libc_checked!(pthread_rwlockattr_setpshared(
attrs.as_mut_ptr(),
libc::PTHREAD_PROCESS_SHARED
));
libc_checked!(pthread_rwlock_init(lock, attrs.as_mut_ptr()));
// Safety: POSIX specifies that "any function affecting the attributes
// object (including destruction) shall not affect any previously
// initialized read-write locks".
libc_checked!(pthread_rwlockattr_destroy(attrs.as_mut_ptr()));
Self(Some(NonNull::new_unchecked(lock)))
}
}
fn inner(&self) -> NonNull<libc::pthread_rwlock_t> {
match self.0 {
None => {
panic!("PthreadRwLock constructed badly - something likely used RawRwLock::INIT")
}
Some(x) => x,
}
}
}
unsafe impl lock_api::RawRwLock for PthreadRwLock {
type GuardMarker = lock_api::GuardSend;
const INIT: Self = Self(None);
fn try_lock_shared(&self) -> bool {
unsafe {
let res = libc::pthread_rwlock_tryrdlock(self.inner().as_ptr());
match res {
0 => true,
libc::EAGAIN => false,
_ => panic!(
"pthread_rwlock_tryrdlock failed with {}",
Errno::from_raw(res)
),
}
}
}
fn try_lock_exclusive(&self) -> bool {
unsafe {
let res = libc::pthread_rwlock_trywrlock(self.inner().as_ptr());
match res {
0 => true,
libc::EAGAIN => false,
_ => panic!("try_wrlock failed with {}", Errno::from_raw(res)),
}
}
}
fn lock_shared(&self) {
unsafe {
libc_checked!(pthread_rwlock_rdlock(self.inner().as_ptr()));
}
}
fn lock_exclusive(&self) {
unsafe {
libc_checked!(pthread_rwlock_wrlock(self.inner().as_ptr()));
}
}
unsafe fn unlock_exclusive(&self) {
unsafe {
libc_checked!(pthread_rwlock_unlock(self.inner().as_ptr()));
}
}
unsafe fn unlock_shared(&self) {
unsafe {
libc_checked!(pthread_rwlock_unlock(self.inner().as_ptr()));
}
}
}

View File

@@ -1,14 +0,0 @@
[package]
name = "neonart"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
crossbeam-utils.workspace = true
spin.workspace = true
tracing.workspace = true
[dev-dependencies]
rand = "0.9.1"
rand_distr = "0.5.1"

View File

@@ -1,599 +0,0 @@
mod lock_and_version;
pub(crate) mod node_ptr;
mod node_ref;
use std::vec::Vec;
use crate::algorithm::lock_and_version::ConcurrentUpdateError;
use crate::algorithm::node_ptr::MAX_PREFIX_LEN;
use crate::algorithm::node_ref::{NewNodeRef, NodeRef, ReadLockedNodeRef, WriteLockedNodeRef};
use crate::allocator::OutOfMemoryError;
use crate::TreeWriteGuard;
use crate::UpdateAction;
use crate::allocator::ArtAllocator;
use crate::epoch::EpochPin;
use crate::{Key, Value};
pub(crate) type RootPtr<V> = node_ptr::NodePtr<V>;
#[derive(Debug)]
pub enum ArtError {
ConcurrentUpdate, // need to retry
OutOfMemory,
}
impl From<ConcurrentUpdateError> for ArtError {
fn from(_: ConcurrentUpdateError) -> ArtError {
ArtError::ConcurrentUpdate
}
}
impl From<OutOfMemoryError> for ArtError {
fn from(_: OutOfMemoryError) -> ArtError {
ArtError::OutOfMemory
}
}
pub fn new_root<V: Value>(
allocator: &impl ArtAllocator<V>,
) -> Result<RootPtr<V>, OutOfMemoryError> {
node_ptr::new_root(allocator)
}
pub(crate) fn search<'e, K: Key, V: Value>(
key: &K,
root: RootPtr<V>,
epoch_pin: &'e EpochPin,
) -> Option<&'e V> {
loop {
let root_ref = NodeRef::from_root_ptr(root);
if let Ok(result) = lookup_recurse(key.as_bytes(), root_ref, None, epoch_pin) {
break result;
}
// retry
}
}
pub(crate) fn iter_next<'e, V: Value>(
key: &[u8],
root: RootPtr<V>,
epoch_pin: &'e EpochPin,
) -> Option<(Vec<u8>, &'e V)> {
loop {
let mut path = Vec::new();
let root_ref = NodeRef::from_root_ptr(root);
match next_recurse(key, &mut path, root_ref, epoch_pin) {
Ok(Some(v)) => {
assert_eq!(path.len(), key.len());
break Some((path, v));
}
Ok(None) => break None,
Err(ConcurrentUpdateError()) => {
// retry
continue;
}
}
}
}
pub(crate) fn update_fn<'e, 'g, K: Key, V: Value, A: ArtAllocator<V>, F>(
key: &K,
value_fn: F,
root: RootPtr<V>,
guard: &'g mut TreeWriteGuard<'e, K, V, A>,
) -> Result<(), OutOfMemoryError>
where
F: FnOnce(Option<&V>) -> UpdateAction<V>,
{
let value_fn_cell = std::cell::Cell::new(Some(value_fn));
loop {
let root_ref = NodeRef::from_root_ptr(root);
let this_value_fn = |arg: Option<&V>| value_fn_cell.take().unwrap()(arg);
let key_bytes = key.as_bytes();
match update_recurse(
key_bytes,
this_value_fn,
root_ref,
None,
None,
guard,
0,
key_bytes,
) {
Ok(()) => break Ok(()),
Err(ArtError::ConcurrentUpdate) => {
continue; // retry
}
Err(ArtError::OutOfMemory) => break Err(OutOfMemoryError()),
}
}
}
// Error means you must retry.
//
// This corresponds to the 'lookupOpt' function in the paper
#[allow(clippy::only_used_in_recursion)]
fn lookup_recurse<'e, V: Value>(
key: &[u8],
node: NodeRef<'e, V>,
parent: Option<ReadLockedNodeRef<V>>,
epoch_pin: &'e EpochPin,
) -> Result<Option<&'e V>, ConcurrentUpdateError> {
let rnode = node.read_lock_or_restart()?;
if let Some(parent) = parent {
parent.read_unlock_or_restart()?;
}
// check if the prefix matches, may increment level
let prefix_len = if let Some(prefix_len) = rnode.prefix_matches(key) {
prefix_len
} else {
rnode.read_unlock_or_restart()?;
return Ok(None);
};
if rnode.is_leaf() {
assert_eq!(key.len(), prefix_len);
let vptr = rnode.get_leaf_value_ptr()?;
// safety: It's OK to return a ref of the pointer because we checked the version
// and the lifetime of 'epoch_pin' enforces that the reference is only accessible
// as long as the epoch is pinned.
let v = unsafe { vptr.as_ref().unwrap() };
return Ok(Some(v));
}
let key = &key[prefix_len..];
// find child (or leaf value)
let next_node = rnode.find_child_or_restart(key[0])?;
match next_node {
None => Ok(None), // key not found
Some(child) => lookup_recurse(&key[1..], child, Some(rnode), epoch_pin),
}
}
#[allow(clippy::only_used_in_recursion)]
fn next_recurse<'e, V: Value>(
min_key: &[u8],
path: &mut Vec<u8>,
node: NodeRef<'e, V>,
epoch_pin: &'e EpochPin,
) -> Result<Option<&'e V>, ConcurrentUpdateError> {
let rnode = node.read_lock_or_restart()?;
let prefix = rnode.get_prefix();
if !prefix.is_empty() {
path.extend_from_slice(prefix);
}
use std::cmp::Ordering;
let comparison = path.as_slice().cmp(&min_key[0..path.len()]);
if comparison == Ordering::Less {
rnode.read_unlock_or_restart()?;
return Ok(None);
}
if rnode.is_leaf() {
assert_eq!(path.len(), min_key.len());
let vptr = rnode.get_leaf_value_ptr()?;
// safety: It's OK to return a ref of the pointer because we checked the version
// and the lifetime of 'epoch_pin' enforces that the reference is only accessible
// as long as the epoch is pinned.
let v = unsafe { vptr.as_ref().unwrap() };
return Ok(Some(v));
}
let mut min_key_byte = match comparison {
Ordering::Less => unreachable!(), // checked this above already
Ordering::Equal => min_key[path.len()],
Ordering::Greater => 0,
};
loop {
match rnode.find_next_child_or_restart(min_key_byte)? {
None => {
return Ok(None);
}
Some((key_byte, child_ref)) => {
let path_len = path.len();
path.push(key_byte);
let result = next_recurse(min_key, path, child_ref, epoch_pin)?;
if result.is_some() {
return Ok(result);
}
if key_byte == u8::MAX {
return Ok(None);
}
path.truncate(path_len);
min_key_byte = key_byte + 1;
}
}
}
}
// This corresponds to the 'insertOpt' function in the paper
#[allow(clippy::only_used_in_recursion)]
#[allow(clippy::too_many_arguments)]
pub(crate) fn update_recurse<'e, K: Key, V: Value, A: ArtAllocator<V>, F>(
key: &[u8],
value_fn: F,
node: NodeRef<'e, V>,
rparent: Option<(ReadLockedNodeRef<V>, u8)>,
rgrandparent: Option<(ReadLockedNodeRef<V>, u8)>,
guard: &'_ mut TreeWriteGuard<'e, K, V, A>,
level: usize,
orig_key: &[u8],
) -> Result<(), ArtError>
where
F: FnOnce(Option<&V>) -> UpdateAction<V>,
{
let rnode = node.read_lock_or_restart()?;
let prefix_match_len = rnode.prefix_matches(key);
if prefix_match_len.is_none() {
let (rparent, parent_key) = rparent.expect("direct children of the root have no prefix");
let mut wparent = rparent.upgrade_to_write_lock_or_restart()?;
let mut wnode = rnode.upgrade_to_write_lock_or_restart()?;
match value_fn(None) {
UpdateAction::Nothing => {}
UpdateAction::Insert(new_value) => {
insert_split_prefix(key, new_value, &mut wnode, &mut wparent, parent_key, guard)?;
}
UpdateAction::Remove => {
panic!("unexpected Remove action on insertion");
}
}
wnode.write_unlock();
wparent.write_unlock();
return Ok(());
}
let prefix_match_len = prefix_match_len.unwrap();
let key = &key[prefix_match_len..];
let level = level + prefix_match_len;
if rnode.is_leaf() {
assert_eq!(key.len(), 0);
let (rparent, parent_key) = rparent.expect("root cannot be leaf");
let mut wparent = rparent.upgrade_to_write_lock_or_restart()?;
let mut wnode = rnode.upgrade_to_write_lock_or_restart()?;
// safety: Now that we have acquired the write lock, we have exclusive access to the
// value. XXX: There might be concurrent reads though?
let value_mut = wnode.get_leaf_value_mut();
match value_fn(Some(value_mut)) {
UpdateAction::Nothing => {
wparent.write_unlock();
wnode.write_unlock();
}
UpdateAction::Insert(_) => panic!("cannot insert over existing value"),
UpdateAction::Remove => {
guard.remember_obsolete_node(wnode.as_ptr());
wparent.delete_child(parent_key);
wnode.write_unlock_obsolete();
if let Some(rgrandparent) = rgrandparent {
// FIXME: Ignore concurrency error. It doesn't lead to
// corruption, but it means we might leak something. Until
// another update cleans it up.
let _ = cleanup_parent(wparent, rgrandparent, guard);
}
}
}
return Ok(());
}
let next_node = rnode.find_child_or_restart(key[0])?;
if next_node.is_none() {
if rnode.is_full() {
let (rparent, parent_key) = rparent.expect("root node cannot become full");
let mut wparent = rparent.upgrade_to_write_lock_or_restart()?;
let wnode = rnode.upgrade_to_write_lock_or_restart()?;
match value_fn(None) {
UpdateAction::Nothing => {
wnode.write_unlock();
wparent.write_unlock();
}
UpdateAction::Insert(new_value) => {
insert_and_grow(key, new_value, wnode, &mut wparent, parent_key, guard)?;
wparent.write_unlock();
}
UpdateAction::Remove => {
panic!("unexpected Remove action on insertion");
}
};
} else {
let mut wnode = rnode.upgrade_to_write_lock_or_restart()?;
if let Some((rparent, _)) = rparent {
rparent.read_unlock_or_restart()?;
}
match value_fn(None) {
UpdateAction::Nothing => {}
UpdateAction::Insert(new_value) => {
insert_to_node(&mut wnode, key, new_value, guard)?;
}
UpdateAction::Remove => {
panic!("unexpected Remove action on insertion");
}
};
wnode.write_unlock();
}
Ok(())
} else {
let next_child = next_node.unwrap(); // checked above it's not None
if let Some((ref rparent, _)) = rparent {
rparent.check_or_restart()?;
}
// recurse to next level
update_recurse(
&key[1..],
value_fn,
next_child,
Some((rnode, key[0])),
rparent,
guard,
level + 1,
orig_key,
)
}
}
#[derive(Clone)]
enum PathElement {
Prefix(Vec<u8>),
KeyByte(u8),
}
impl std::fmt::Debug for PathElement {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
match self {
PathElement::Prefix(prefix) => write!(fmt, "{prefix:?}"),
PathElement::KeyByte(key_byte) => write!(fmt, "{key_byte}"),
}
}
}
pub(crate) fn dump_tree<V: Value + std::fmt::Debug>(
root: RootPtr<V>,
epoch_pin: &'_ EpochPin,
dst: &mut dyn std::io::Write,
) {
let root_ref = NodeRef::from_root_ptr(root);
let _ = dump_recurse(&[], root_ref, epoch_pin, 0, dst);
}
// TODO: return an Err if writeln!() returns error, instead of unwrapping
#[allow(clippy::only_used_in_recursion)]
fn dump_recurse<'e, V: Value + std::fmt::Debug>(
path: &[PathElement],
node: NodeRef<'e, V>,
epoch_pin: &'e EpochPin,
level: usize,
dst: &mut dyn std::io::Write,
) -> Result<(), ConcurrentUpdateError> {
let indent = str::repeat(" ", level);
let rnode = node.read_lock_or_restart()?;
let mut path = Vec::from(path);
let prefix = rnode.get_prefix();
if !prefix.is_empty() {
path.push(PathElement::Prefix(Vec::from(prefix)));
}
if rnode.is_leaf() {
let vptr = rnode.get_leaf_value_ptr()?;
// safety: It's OK to return a ref of the pointer because we checked the version
// and the lifetime of 'epoch_pin' enforces that the reference is only accessible
// as long as the epoch is pinned.
let val = unsafe { vptr.as_ref().unwrap() };
writeln!(dst, "{indent} {path:?}: {val:?}").unwrap();
return Ok(());
}
for key_byte in 0..=u8::MAX {
match rnode.find_child_or_restart(key_byte)? {
None => continue,
Some(child_ref) => {
let rchild = child_ref.read_lock_or_restart()?;
writeln!(
dst,
"{} {:?}, {}: prefix {:?}",
indent,
&path,
key_byte,
rchild.get_prefix()
)
.unwrap();
let mut child_path = path.clone();
child_path.push(PathElement::KeyByte(key_byte));
dump_recurse(&child_path, child_ref, epoch_pin, level + 1, dst)?;
}
}
}
Ok(())
}
///```text
/// [fooba]r -> value
///
/// [foo]b -> [a]r -> value
/// e -> [ls]e -> value
///```
fn insert_split_prefix<K: Key, V: Value, A: ArtAllocator<V>>(
key: &[u8],
value: V,
node: &mut WriteLockedNodeRef<V>,
parent: &mut WriteLockedNodeRef<V>,
parent_key: u8,
guard: &'_ TreeWriteGuard<K, V, A>,
) -> Result<(), OutOfMemoryError> {
let old_node = node;
let old_prefix = old_node.get_prefix();
let common_prefix_len = common_prefix(key, old_prefix);
// Allocate a node for the new value.
let new_value_node = allocate_node_for_value(
&key[common_prefix_len + 1..],
value,
guard.tree_writer.allocator,
)?;
// Allocate a new internal node with the common prefix
// FIXME: deallocate 'new_value_node' on OOM
let mut prefix_node =
node_ref::new_internal(&key[..common_prefix_len], guard.tree_writer.allocator)?;
// Add the old node and the new nodes to the new internal node
prefix_node.insert_old_child(old_prefix[common_prefix_len], old_node);
prefix_node.insert_new_child(key[common_prefix_len], new_value_node);
// Modify the prefix of the old child in place
old_node.truncate_prefix(old_prefix.len() - common_prefix_len - 1);
// replace the pointer in the parent
parent.replace_child(parent_key, prefix_node.into_ptr());
Ok(())
}
fn insert_to_node<K: Key, V: Value, A: ArtAllocator<V>>(
wnode: &mut WriteLockedNodeRef<V>,
key: &[u8],
value: V,
guard: &'_ TreeWriteGuard<K, V, A>,
) -> Result<(), OutOfMemoryError> {
let value_child = allocate_node_for_value(&key[1..], value, guard.tree_writer.allocator)?;
wnode.insert_child(key[0], value_child.into_ptr());
Ok(())
}
// On entry: 'parent' and 'node' are locked
fn insert_and_grow<'e, 'g, K: Key, V: Value, A: ArtAllocator<V>>(
key: &[u8],
value: V,
wnode: WriteLockedNodeRef<V>,
parent: &mut WriteLockedNodeRef<V>,
parent_key_byte: u8,
guard: &'g mut TreeWriteGuard<'e, K, V, A>,
) -> Result<(), ArtError> {
let mut bigger_node = wnode.grow(guard.tree_writer.allocator)?;
// FIXME: deallocate 'bigger_node' on OOM
let value_child = allocate_node_for_value(&key[1..], value, guard.tree_writer.allocator)?;
bigger_node.insert_new_child(key[0], value_child);
// Replace the pointer in the parent
parent.replace_child(parent_key_byte, bigger_node.into_ptr());
guard.remember_obsolete_node(wnode.as_ptr());
wnode.write_unlock_obsolete();
Ok(())
}
fn cleanup_parent<'e, 'g, K: Key, V: Value, A: ArtAllocator<V>>(
wparent: WriteLockedNodeRef<V>,
rgrandparent: (ReadLockedNodeRef<V>, u8),
guard: &'g mut TreeWriteGuard<'e, K, V, A>,
) -> Result<(), ArtError> {
let (rgrandparent, grandparent_key_byte) = rgrandparent;
// If the parent becomes completely empty after the deletion, remove the parent from the
// grandparent. (This case is possible because we reserve only 8 bytes for the prefix.)
// TODO: not implemented.
// If the parent has only one child, replace the parent with the remaining child. (This is not
// possible if the child's prefix field cannot absorb the parent's)
if wparent.num_children() == 1 {
// Try to lock the remaining child. This can fail if the child is updated
// concurrently.
let (key_byte, remaining_child) = wparent.find_remaining_child();
let mut wremaining_child = remaining_child.write_lock_or_restart()?;
if 1 + wremaining_child.get_prefix().len() + wparent.get_prefix().len() <= MAX_PREFIX_LEN {
let mut wgrandparent = rgrandparent.upgrade_to_write_lock_or_restart()?;
// Ok, we have locked the leaf, the parent, the grandparent, and the parent's only
// remaining leaf. Proceed with the updates.
// Update the prefix on the remaining leaf
wremaining_child.prepend_prefix(wparent.get_prefix(), key_byte);
// Replace the pointer in the grandparent to point directly to the remaining leaf
wgrandparent.replace_child(grandparent_key_byte, wremaining_child.as_ptr());
// Mark the parent as deleted.
guard.remember_obsolete_node(wparent.as_ptr());
wparent.write_unlock_obsolete();
return Ok(());
}
}
// If the parent's children would fit on a smaller node type after the deletion, replace it with
// a smaller node.
if wparent.can_shrink() {
let mut wgrandparent = rgrandparent.upgrade_to_write_lock_or_restart()?;
let smaller_node = wparent.shrink(guard.tree_writer.allocator)?;
// Replace the pointer in the grandparent
wgrandparent.replace_child(grandparent_key_byte, smaller_node.into_ptr());
guard.remember_obsolete_node(wparent.as_ptr());
wparent.write_unlock_obsolete();
return Ok(());
}
// nothing to do
wparent.write_unlock();
Ok(())
}
// Allocate a new leaf node to hold 'value'. If the key is long, we
// may need to allocate new internal nodes to hold it too
fn allocate_node_for_value<'a, V: Value, A: ArtAllocator<V>>(
key: &[u8],
value: V,
allocator: &'a A,
) -> Result<NewNodeRef<'a, V, A>, OutOfMemoryError> {
let mut prefix_off = key.len().saturating_sub(MAX_PREFIX_LEN);
let leaf_node = node_ref::new_leaf(&key[prefix_off..key.len()], value, allocator)?;
let mut node = leaf_node;
while prefix_off > 0 {
// Need another internal node
let remain_prefix = &key[0..prefix_off];
prefix_off = remain_prefix.len().saturating_sub(MAX_PREFIX_LEN + 1);
let mut internal_node = node_ref::new_internal(
&remain_prefix[prefix_off..remain_prefix.len() - 1],
allocator,
)?;
internal_node.insert_new_child(*remain_prefix.last().unwrap(), node);
node = internal_node;
}
Ok(node)
}
fn common_prefix(a: &[u8], b: &[u8]) -> usize {
for i in 0..MAX_PREFIX_LEN {
if a[i] != b[i] {
return i;
}
}
panic!("prefixes are equal");
}

View File

@@ -1,117 +0,0 @@
//! Each node in the tree has contains one atomic word that stores three things:
//!
//! Bit 0: set if the node is "obsolete". An obsolete node has been removed from the tree,
//! but might still be accessed by concurrent readers until the epoch expires.
//! Bit 1: set if the node is currently write-locked. Used as a spinlock.
//! Bits 2-63: Version number, incremented every time the node is modified.
//!
//! AtomicLockAndVersion represents that.
use std::sync::atomic::{AtomicU64, Ordering};
pub(crate) struct ConcurrentUpdateError();
pub(crate) struct AtomicLockAndVersion {
inner: AtomicU64,
}
impl AtomicLockAndVersion {
pub(crate) fn new() -> AtomicLockAndVersion {
AtomicLockAndVersion {
inner: AtomicU64::new(0),
}
}
}
impl AtomicLockAndVersion {
pub(crate) fn read_lock_or_restart(&self) -> Result<u64, ConcurrentUpdateError> {
let version = self.await_node_unlocked();
if is_obsolete(version) {
return Err(ConcurrentUpdateError());
}
Ok(version)
}
pub(crate) fn check_or_restart(&self, version: u64) -> Result<(), ConcurrentUpdateError> {
self.read_unlock_or_restart(version)
}
pub(crate) fn read_unlock_or_restart(&self, version: u64) -> Result<(), ConcurrentUpdateError> {
if self.inner.load(Ordering::Acquire) != version {
return Err(ConcurrentUpdateError());
}
Ok(())
}
pub(crate) fn upgrade_to_write_lock_or_restart(
&self,
version: u64,
) -> Result<(), ConcurrentUpdateError> {
if self
.inner
.compare_exchange(
version,
set_locked_bit(version),
Ordering::Acquire,
Ordering::Relaxed,
)
.is_err()
{
return Err(ConcurrentUpdateError());
}
Ok(())
}
pub(crate) fn write_lock_or_restart(&self) -> Result<(), ConcurrentUpdateError> {
let old = self.inner.load(Ordering::Relaxed);
if is_obsolete(old) || is_locked(old) {
return Err(ConcurrentUpdateError());
}
if self
.inner
.compare_exchange(
old,
set_locked_bit(old),
Ordering::Acquire,
Ordering::Relaxed,
)
.is_err()
{
return Err(ConcurrentUpdateError());
}
Ok(())
}
pub(crate) fn write_unlock(&self) {
// reset locked bit and overflow into version
self.inner.fetch_add(2, Ordering::Release);
}
pub(crate) fn write_unlock_obsolete(&self) {
// set obsolete, reset locked, overflow into version
self.inner.fetch_add(3, Ordering::Release);
}
// Helper functions
fn await_node_unlocked(&self) -> u64 {
let mut version = self.inner.load(Ordering::Acquire);
while is_locked(version) {
// spinlock
std::thread::yield_now();
version = self.inner.load(Ordering::Acquire)
}
version
}
}
fn set_locked_bit(version: u64) -> u64 {
version + 2
}
fn is_obsolete(version: u64) -> bool {
(version & 1) == 1
}
fn is_locked(version: u64) -> bool {
(version & 2) == 2
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,349 +0,0 @@
use std::fmt::Debug;
use std::marker::PhantomData;
use super::node_ptr;
use super::node_ptr::NodePtr;
use crate::EpochPin;
use crate::Value;
use crate::algorithm::lock_and_version::AtomicLockAndVersion;
use crate::algorithm::lock_and_version::ConcurrentUpdateError;
use crate::allocator::ArtAllocator;
use crate::allocator::OutOfMemoryError;
pub struct NodeRef<'e, V> {
ptr: NodePtr<V>,
phantom: PhantomData<&'e EpochPin<'e>>,
}
impl<'e, V> Debug for NodeRef<'e, V> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(fmt, "{:?}", self.ptr)
}
}
impl<'e, V: Value> NodeRef<'e, V> {
pub(crate) fn from_root_ptr(root_ptr: NodePtr<V>) -> NodeRef<'e, V> {
NodeRef {
ptr: root_ptr,
phantom: PhantomData,
}
}
pub(crate) fn read_lock_or_restart(
&self,
) -> Result<ReadLockedNodeRef<'e, V>, ConcurrentUpdateError> {
let version = self.lockword().read_lock_or_restart()?;
Ok(ReadLockedNodeRef {
ptr: self.ptr,
version,
phantom: self.phantom,
})
}
pub(crate) fn write_lock_or_restart(
&self,
) -> Result<WriteLockedNodeRef<'e, V>, ConcurrentUpdateError> {
self.lockword().write_lock_or_restart()?;
Ok(WriteLockedNodeRef {
ptr: self.ptr,
phantom: self.phantom,
})
}
fn lockword(&self) -> &AtomicLockAndVersion {
self.ptr.lockword()
}
}
/// A reference to a node that has been optimistically read-locked. The functions re-check
/// the version after each read.
pub struct ReadLockedNodeRef<'e, V> {
ptr: NodePtr<V>,
version: u64,
phantom: PhantomData<&'e EpochPin<'e>>,
}
impl<'e, V: Value> ReadLockedNodeRef<'e, V> {
pub(crate) fn is_leaf(&self) -> bool {
self.ptr.is_leaf()
}
pub(crate) fn is_full(&self) -> bool {
self.ptr.is_full()
}
pub(crate) fn get_prefix(&self) -> &[u8] {
self.ptr.get_prefix()
}
/// Note: because we're only holding a read lock, the prefix can change concurrently.
/// You must be prepared to restart, if read_unlock() returns error later.
///
/// Returns the length of the prefix, or None if it's not a match
pub(crate) fn prefix_matches(&self, key: &[u8]) -> Option<usize> {
self.ptr.prefix_matches(key)
}
pub(crate) fn find_child_or_restart(
&self,
key_byte: u8,
) -> Result<Option<NodeRef<'e, V>>, ConcurrentUpdateError> {
let child_or_value = self.ptr.find_child(key_byte);
self.ptr.lockword().check_or_restart(self.version)?;
match child_or_value {
None => Ok(None),
Some(child_ptr) => Ok(Some(NodeRef {
ptr: child_ptr,
phantom: self.phantom,
})),
}
}
pub(crate) fn find_next_child_or_restart(
&self,
min_key_byte: u8,
) -> Result<Option<(u8, NodeRef<'e, V>)>, ConcurrentUpdateError> {
let child_or_value = self.ptr.find_next_child(min_key_byte);
self.ptr.lockword().check_or_restart(self.version)?;
match child_or_value {
None => Ok(None),
Some((k, child_ptr)) => Ok(Some((
k,
NodeRef {
ptr: child_ptr,
phantom: self.phantom,
},
))),
}
}
pub(crate) fn get_leaf_value_ptr(&self) -> Result<*const V, ConcurrentUpdateError> {
let result = self.ptr.get_leaf_value();
self.ptr.lockword().check_or_restart(self.version)?;
// Extend the lifetime.
let result = std::ptr::from_ref(result);
Ok(result)
}
pub(crate) fn upgrade_to_write_lock_or_restart(
self,
) -> Result<WriteLockedNodeRef<'e, V>, ConcurrentUpdateError> {
self.ptr
.lockword()
.upgrade_to_write_lock_or_restart(self.version)?;
Ok(WriteLockedNodeRef {
ptr: self.ptr,
phantom: self.phantom,
})
}
pub(crate) fn read_unlock_or_restart(self) -> Result<(), ConcurrentUpdateError> {
self.ptr.lockword().check_or_restart(self.version)?;
Ok(())
}
pub(crate) fn check_or_restart(&self) -> Result<(), ConcurrentUpdateError> {
self.ptr.lockword().check_or_restart(self.version)?;
Ok(())
}
}
/// A reference to a node that has been optimistically read-locked. The functions re-check
/// the version after each read.
pub struct WriteLockedNodeRef<'e, V> {
ptr: NodePtr<V>,
phantom: PhantomData<&'e EpochPin<'e>>,
}
impl<'e, V: Value> WriteLockedNodeRef<'e, V> {
pub(crate) fn can_shrink(&self) -> bool {
self.ptr.can_shrink()
}
pub(crate) fn num_children(&self) -> usize {
self.ptr.num_children()
}
pub(crate) fn write_unlock(mut self) {
self.ptr.lockword().write_unlock();
self.ptr = NodePtr::null();
}
pub(crate) fn write_unlock_obsolete(mut self) {
self.ptr.lockword().write_unlock_obsolete();
self.ptr = NodePtr::null();
}
pub(crate) fn get_prefix(&self) -> &[u8] {
self.ptr.get_prefix()
}
pub(crate) fn truncate_prefix(&mut self, new_prefix_len: usize) {
self.ptr.truncate_prefix(new_prefix_len)
}
pub(crate) fn prepend_prefix(&mut self, prefix: &[u8], prefix_byte: u8) {
self.ptr.prepend_prefix(prefix, prefix_byte)
}
pub(crate) fn insert_child(&mut self, key_byte: u8, child: NodePtr<V>) {
self.ptr.insert_child(key_byte, child)
}
pub(crate) fn get_leaf_value_mut(&mut self) -> &mut V {
self.ptr.get_leaf_value_mut()
}
pub(crate) fn grow<'a, A>(
&self,
allocator: &'a A,
) -> Result<NewNodeRef<'a, V, A>, OutOfMemoryError>
where
A: ArtAllocator<V>,
{
let new_node = self.ptr.grow(allocator)?;
Ok(NewNodeRef {
ptr: new_node,
allocator,
extra_nodes: Vec::new(),
})
}
pub(crate) fn shrink<'a, A>(
&self,
allocator: &'a A,
) -> Result<NewNodeRef<'a, V, A>, OutOfMemoryError>
where
A: ArtAllocator<V>,
{
let new_node = self.ptr.shrink(allocator)?;
Ok(NewNodeRef {
ptr: new_node,
allocator,
extra_nodes: Vec::new(),
})
}
pub(crate) fn as_ptr(&self) -> NodePtr<V> {
self.ptr
}
pub(crate) fn replace_child(&mut self, key_byte: u8, replacement: NodePtr<V>) {
self.ptr.replace_child(key_byte, replacement);
}
pub(crate) fn delete_child(&mut self, key_byte: u8) {
self.ptr.delete_child(key_byte);
}
pub(crate) fn find_remaining_child(&self) -> (u8, NodeRef<'e, V>) {
assert_eq!(self.num_children(), 1);
let child_or_value = self.ptr.find_next_child(0);
match child_or_value {
None => panic!("could not find only child in node"),
Some((k, child_ptr)) => (
k,
NodeRef {
ptr: child_ptr,
phantom: self.phantom,
},
),
}
}
}
impl<'e, V> Drop for WriteLockedNodeRef<'e, V> {
fn drop(&mut self) {
if !self.ptr.is_null() {
self.ptr.lockword().write_unlock();
}
}
}
pub(crate) struct NewNodeRef<'a, V, A>
where
V: Value,
A: ArtAllocator<V>,
{
ptr: NodePtr<V>,
allocator: &'a A,
extra_nodes: Vec<NodePtr<V>>,
}
impl<'a, V, A> NewNodeRef<'a, V, A>
where
V: Value,
A: ArtAllocator<V>,
{
pub(crate) fn insert_old_child(&mut self, key_byte: u8, child: &WriteLockedNodeRef<V>) {
self.ptr.insert_child(key_byte, child.as_ptr())
}
pub(crate) fn into_ptr(mut self) -> NodePtr<V> {
let ptr = self.ptr;
self.ptr = NodePtr::null();
ptr
}
pub(crate) fn insert_new_child(&mut self, key_byte: u8, child: NewNodeRef<'a, V, A>) {
let child_ptr = child.into_ptr();
self.ptr.insert_child(key_byte, child_ptr);
self.extra_nodes.push(child_ptr);
}
}
impl<'a, V, A> Drop for NewNodeRef<'a, V, A>
where
V: Value,
A: ArtAllocator<V>,
{
/// This drop implementation deallocates the newly allocated node, if into_ptr() was not called.
fn drop(&mut self) {
if !self.ptr.is_null() {
self.ptr.deallocate(self.allocator);
for p in self.extra_nodes.iter() {
p.deallocate(self.allocator);
}
}
}
}
pub(crate) fn new_internal<'a, V, A>(
prefix: &[u8],
allocator: &'a A,
) -> Result<NewNodeRef<'a, V, A>, OutOfMemoryError>
where
V: Value,
A: ArtAllocator<V>,
{
Ok(NewNodeRef {
ptr: node_ptr::new_internal(prefix, allocator)?,
allocator,
extra_nodes: Vec::new(),
})
}
pub(crate) fn new_leaf<'a, V, A>(
prefix: &[u8],
value: V,
allocator: &'a A,
) -> Result<NewNodeRef<'a, V, A>, OutOfMemoryError>
where
V: Value,
A: ArtAllocator<V>,
{
Ok(NewNodeRef {
ptr: node_ptr::new_leaf(prefix, value, allocator)?,
allocator,
extra_nodes: Vec::new(),
})
}

View File

@@ -1,156 +0,0 @@
pub mod block;
mod multislab;
mod slab;
pub mod r#static;
use std::alloc::Layout;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::sync::atomic::Ordering;
use crate::allocator::multislab::MultiSlabAllocator;
use crate::allocator::r#static::alloc_from_slice;
use spin;
use crate::Tree;
pub use crate::algorithm::node_ptr::{
NodeInternal4, NodeInternal16, NodeInternal48, NodeInternal256, NodeLeaf,
};
#[derive(Debug)]
pub struct OutOfMemoryError();
pub trait ArtAllocator<V: crate::Value> {
fn alloc_tree(&self) -> *mut Tree<V>;
fn alloc_node_internal4(&self) -> *mut NodeInternal4<V>;
fn alloc_node_internal16(&self) -> *mut NodeInternal16<V>;
fn alloc_node_internal48(&self) -> *mut NodeInternal48<V>;
fn alloc_node_internal256(&self) -> *mut NodeInternal256<V>;
fn alloc_node_leaf(&self) -> *mut NodeLeaf<V>;
fn dealloc_node_internal4(&self, ptr: *mut NodeInternal4<V>);
fn dealloc_node_internal16(&self, ptr: *mut NodeInternal16<V>);
fn dealloc_node_internal48(&self, ptr: *mut NodeInternal48<V>);
fn dealloc_node_internal256(&self, ptr: *mut NodeInternal256<V>);
fn dealloc_node_leaf(&self, ptr: *mut NodeLeaf<V>);
}
pub struct ArtMultiSlabAllocator<'t, V>
where
V: crate::Value,
{
tree_area: spin::Mutex<Option<&'t mut MaybeUninit<Tree<V>>>>,
pub(crate) inner: MultiSlabAllocator<'t, 5>,
phantom_val: PhantomData<V>,
}
impl<'t, V: crate::Value> ArtMultiSlabAllocator<'t, V> {
const LAYOUTS: [Layout; 5] = [
Layout::new::<NodeInternal4<V>>(),
Layout::new::<NodeInternal16<V>>(),
Layout::new::<NodeInternal48<V>>(),
Layout::new::<NodeInternal256<V>>(),
Layout::new::<NodeLeaf<V>>(),
];
pub fn new(area: &'t mut [MaybeUninit<u8>]) -> &'t mut ArtMultiSlabAllocator<'t, V> {
let (allocator_area, remain) = alloc_from_slice::<ArtMultiSlabAllocator<V>>(area);
let (tree_area, remain) = alloc_from_slice::<Tree<V>>(remain);
allocator_area.write(ArtMultiSlabAllocator {
tree_area: spin::Mutex::new(Some(tree_area)),
inner: MultiSlabAllocator::new(remain, &Self::LAYOUTS),
phantom_val: PhantomData,
})
}
}
impl<'t, V: crate::Value> ArtAllocator<V> for ArtMultiSlabAllocator<'t, V> {
fn alloc_tree(&self) -> *mut Tree<V> {
let mut t = self.tree_area.lock();
if let Some(tree_area) = t.take() {
return tree_area.as_mut_ptr().cast();
}
panic!("cannot allocate more than one tree");
}
fn alloc_node_internal4(&self) -> *mut NodeInternal4<V> {
self.inner.alloc_slab(0).cast()
}
fn alloc_node_internal16(&self) -> *mut NodeInternal16<V> {
self.inner.alloc_slab(1).cast()
}
fn alloc_node_internal48(&self) -> *mut NodeInternal48<V> {
self.inner.alloc_slab(2).cast()
}
fn alloc_node_internal256(&self) -> *mut NodeInternal256<V> {
self.inner.alloc_slab(3).cast()
}
fn alloc_node_leaf(&self) -> *mut NodeLeaf<V> {
self.inner.alloc_slab(4).cast()
}
fn dealloc_node_internal4(&self, ptr: *mut NodeInternal4<V>) {
self.inner.dealloc_slab(0, ptr.cast())
}
fn dealloc_node_internal16(&self, ptr: *mut NodeInternal16<V>) {
self.inner.dealloc_slab(1, ptr.cast())
}
fn dealloc_node_internal48(&self, ptr: *mut NodeInternal48<V>) {
self.inner.dealloc_slab(2, ptr.cast())
}
fn dealloc_node_internal256(&self, ptr: *mut NodeInternal256<V>) {
self.inner.dealloc_slab(3, ptr.cast())
}
fn dealloc_node_leaf(&self, ptr: *mut NodeLeaf<V>) {
self.inner.dealloc_slab(4, ptr.cast())
}
}
impl<'t, V: crate::Value> ArtMultiSlabAllocator<'t, V> {
pub(crate) fn get_statistics(&self) -> ArtMultiSlabStats {
ArtMultiSlabStats {
num_internal4: self.inner.slab_descs[0]
.num_allocated
.load(Ordering::Relaxed),
num_internal16: self.inner.slab_descs[1]
.num_allocated
.load(Ordering::Relaxed),
num_internal48: self.inner.slab_descs[2]
.num_allocated
.load(Ordering::Relaxed),
num_internal256: self.inner.slab_descs[3]
.num_allocated
.load(Ordering::Relaxed),
num_leaf: self.inner.slab_descs[4]
.num_allocated
.load(Ordering::Relaxed),
num_blocks_internal4: self.inner.slab_descs[0].num_blocks.load(Ordering::Relaxed),
num_blocks_internal16: self.inner.slab_descs[1].num_blocks.load(Ordering::Relaxed),
num_blocks_internal48: self.inner.slab_descs[2].num_blocks.load(Ordering::Relaxed),
num_blocks_internal256: self.inner.slab_descs[3].num_blocks.load(Ordering::Relaxed),
num_blocks_leaf: self.inner.slab_descs[4].num_blocks.load(Ordering::Relaxed),
}
}
}
#[derive(Clone, Debug)]
pub struct ArtMultiSlabStats {
pub num_internal4: u64,
pub num_internal16: u64,
pub num_internal48: u64,
pub num_internal256: u64,
pub num_leaf: u64,
pub num_blocks_internal4: u64,
pub num_blocks_internal16: u64,
pub num_blocks_internal48: u64,
pub num_blocks_internal256: u64,
pub num_blocks_leaf: u64,
}

Some files were not shown because too many files have changed in this diff Show More