Compare commits

..

37 Commits

Author SHA1 Message Date
Conrad Ludgate
d8ddf5c850 cancellation does not need tls 2025-06-20 11:15:35 +01:00
Conrad Ludgate
7c469b30aa add another debug and overflow comment 2025-06-10 22:42:57 -07:00
Conrad Ludgate
a78a52acb5 add in timeout for cancellation 2025-06-10 13:44:30 -07:00
Conrad Ludgate
3370e8cb00 optimise future sizes for cancel maintenance 2025-06-10 13:40:39 -07:00
Conrad Ludgate
f37a558280 move the cancel-on-shutdown handling to the cancel session maintenance task 2025-06-10 13:40:27 -07:00
Conrad Ludgate
744011437a create batch processing struct 2025-06-10 13:37:40 -07:00
Conrad Ludgate
a10d26a083 no explicit remove, only passive ttl 2025-06-10 13:36:11 -07:00
Conrad Ludgate
aece520365 remove replies for store/remove ops 2025-06-10 13:18:15 -07:00
Conrad Ludgate
9017811d61 remove dead code for redis keys 2025-06-10 13:18:15 -07:00
Conrad Ludgate
551a33aa04 use hget instead of hgetall 2025-06-10 13:18:15 -07:00
Conrad Ludgate
95216ae6ec box the connect future and respect the redis retry methods on err 2025-06-10 13:18:15 -07:00
Conrad Ludgate
a3a10d1839 split CancelToken into RawCancelToken for smaller sizes and better typesafety 2025-06-10 13:18:15 -07:00
Mikhail
1b935b1958 endpoint_storage: add ?from_endpoint= to /lfc/prewarm (#12195)
Related: https://github.com/neondatabase/cloud/issues/24225
Add optional from_endpoint parameter to allow prewarming from other
endpoint
2025-06-10 19:25:32 +00:00
a-masterov
3f16ca2c18 Respect limits for projects for the Random Operations test (#12184)
## Problem
The project limits were not respected, resulting in errors.
## Summary of changes
Now limits are checked before running an action, and if the action is
not possible to run, another random action will be run.

---------

Co-authored-by: Peter Bendel <peterbendel@neon.tech>
2025-06-10 15:59:51 +00:00
Conrad Ludgate
67b94c5992 [proxy] per endpoint configuration for rate limits (#12148)
https://github.com/neondatabase/cloud/issues/28333

Adds a new `rate_limit` response type to EndpointAccessControl, uses it
for rate limiting, and adds a generic invalidation for the cache.
2025-06-10 14:26:08 +00:00
Folke Behrens
e38193c530 proxy: Move connect_to_compute back to proxy (#12181)
It's mostly responsible for waking, retrying, and caching. A new, thin
wrapper around compute_once will be PGLB's entry point
2025-06-10 11:23:03 +00:00
Konstantin Knizhnik
21949137ed Return last ring index instead of min_ring_index in prefetch_register_bufferv (#12039)
## Problem

See https://github.com/neondatabase/neon/issues/12018

Now `prefetch_register_bufferv` calculates min_ring_index of all vector
requests.
But because of pump prefetch state or connection failure, previous slots
can be already proceeded and reused.

## Summary of changes

Instead of returning minimal index, this function should return last
slot index.
Actually result of this function is used only in two places. A first
place just fort checking (and this check is redundant because the same
check is done in `prefetch_register_bufferv` itself.
And in the second place where index of filled slot is actually used,
there is just one request.
Sp fortunately this bug can cause only assert failure in debug build.

---------

Co-authored-by: Konstantin Knizhnik <knizhnik@neon.tech>
2025-06-10 10:09:46 +00:00
Trung Dinh
02f94edb60 Remove global static TENANTS (#12169)
## Problem
There is this TODO in code:
https://github.com/neondatabase/neon/blob/main/pageserver/src/tenant/mgr.rs#L300-L302
This is an old TODO by @jcsp.

## Summary of changes
This PR addresses the TODO. Specifically, it removes a global static
`TENANTS`. Instead the `TenantManager` now directly manages the tenant
map. Enhancing abstraction.

Essentially, this PR moves all module-level methods to inside the
implementation of `TenantManager`.
2025-06-10 09:26:40 +00:00
Conrad Ludgate
58327ef74d [proxy] fix sql-over-http password setting (#12177)
## Problem

Looks like our sql-over-http tests get to rely on "trust"
authentication, so the path that made sure the authkeys data was set was
never being hit.

## Summary of changes

Slight refactor to WakeComputeBackends, as well as making sure auth keys
are propagated. Fix tests to ensure passwords are tested.
2025-06-10 08:46:29 +00:00
Dmitrii Kovalkov
73be6bb736 fix(compute): use proper safekeeper in VotesCollectedMset (#12175)
## Problem
`VotesCollectedMset` uses the wrong safekeeper to update truncateLsn.
This led to some failed assert later in the code during running
safekeeper migration tests.
- Relates to https://github.com/neondatabase/neon/issues/11823

## Summary of changes
Use proper safekeeper to update truncateLsn in VotesCollectedMset
2025-06-10 07:16:42 +00:00
Alex Chi Z.
40d7583906 feat(pageserver): use hostname as feature flag resolver property (#12141)
## Problem

part of https://github.com/neondatabase/neon/issues/11813

## Summary of changes

Collect pageserver hostname property so that we can use it in the
PostHog UI. Not sure if this is the best way to do that -- open to
suggestions.

---------

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-06-10 07:10:41 +00:00
Alex Chi Z.
7a68699abb feat(pageserver): support azure time-travel recovery (in an okay way) (#12140)
## Problem

part of https://github.com/neondatabase/neon/issues/7546

Add Azure time travel recovery support. The tricky thing is how Azure
handles deletes in its blob version API. For the following sequence:

```
upload file_1 = a
upload file_1 = b
delete file_1
upload file_1 = c
```

The "delete file_1" won't be stored as a version (as AWS did).
Therefore, we can never rollback to a state where file_1 is temporarily
invisible. If we roll back to the time before file_1 gets created for
the first time, it will be removed correctly.

However, this is fine for pageservers, because (1) having extra files in
the tenant storage is usually fine (2) for things like
timelines/X/index_part-Y.json, it will only be deleted once, so it can
always be recovered to a correct state. Therefore, I don't expect any
issues when this functionality is used on pageserver recovery.

TODO: unit tests for time-travel recovery.

## Summary of changes

Add Azure blob storage time-travel recovery support.

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-06-10 05:32:58 +00:00
Konstantin Knizhnik
f42d44342d Increase statement timeout for test_pageserver_restarts_under_workload test (#12139)
\## Problem

See
https://github.com/neondatabase/neon/issues/12119#issuecomment-2942586090

Page server restarts with interval 1 seconds increases time of vacuum
especially off prefetch is enabled and so cause test failure because of
statement timeout expiration.

## Summary of changes

Increase statement timeout to 360 seconds.

---------

Co-authored-by: Konstantin Knizhnik <knizhnik@neon.tech>
Co-authored-by: Alexander Lakhin <alexander.lakhin@neon.tech>
2025-06-10 05:32:03 +00:00
Konstantin Knizhnik
d759fcb8bd Increase wait LFC prewarm timeout (#12174)
## Problem

See https://github.com/neondatabase/neon/issues/12171

## Summary of changes

Increase LFC prewarm wait timeout to 1 minute

Co-authored-by: Konstantin Knizhnik <knizhnik@neon.tech>
2025-06-09 18:01:30 +00:00
Alex Chi Z.
76f95f06d8 feat(pageserver): add global timeline count metrics (#12159)
## Problem

We are getting tenants with a lot of branches and num of timelines is a
good indicator of pageserver loads. I added this metrics to help us
better plan pageserver capacities.

## Summary of changes

Add `pageserver_timeline_states_count` with two labels: active +
offloaded.

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-06-09 09:57:36 +00:00
Mikhail
7efd4554ab endpoint_storage: allow bypassing s3 write check on startup (#12165)
Related: https://github.com/neondatabase/cloud/issues/27195
2025-06-06 18:08:02 +00:00
Erik Grinaker
3c7235669a pageserver: don't delete parent shard files until split is committed (#12146)
## Problem

If a shard split fails and must roll back, the tenant may hit a cold
start as the parent shard's files have already been removed from local
disk.

External contribution with minor adjustments, see
https://neondb.slack.com/archives/C08TE3203RQ/p1748246398269309.

## Summary of changes

Keep the parent shard's files on local disk until the split has been
committed, such that they are available if the spilt is rolled back. If
all else fails, the files will be removed on the next Pageserver
restart.

This should also be fine in a mixed version:

* New storcon, old Pageserver: the Pageserver will delete the files
during the split, storcon will log an error when the cleanup detach
fails.

* Old storcon, new Pageserver: the Pageserver will leave the parent's
files around until the next Pageserver restart.

The change looks good to me, but shard splits are delicate so I'd like
some extra eyes on this.
2025-06-06 15:55:14 +00:00
Conrad Ludgate
6dd84041a1 refactor and simplify the invalidation notification structure (#12154)
The current cache invalidation messages are far too specific. They
should be more generic since it only ends up triggering a
`GetEndpointAccessControl` message anyway.

Mappings:
* `/allowed_ips_updated`, `/block_public_or_vpc_access_updated`, and
`/allowed_vpc_endpoints_updated_for_projects` ->
`/project_settings_update`.
* `/allowed_vpc_endpoints_updated_for_org` ->
`/account_settings_update`.
* `/password_updated` -> `/role_setting_update`.

I've also introduced `/endpoint_settings_update`.

All message types support singular or multiple entries, which allows us
to simplify things both on our side and on cplane side.

I'm opening a PR to cplane to apply the above mappings, but for now
using the old phrases to allow both to roll out independently.

This change is inspired by my need to add yet another cached entry to
`GetEndpointAccessControl` for
https://github.com/neondatabase/cloud/issues/28333
2025-06-06 12:49:29 +00:00
Arpad Müller
df7e301a54 safekeeper: special error if a timeline has been deleted (#12155)
We might delete timelines on safekeepers before we are deleting them on
pageservers. This should be an exceptional situation, but can occur. As
the first step to improve behaviour here, emit a special error that is
less scary/obscure than "was not found in global map".

It is for example emitted when the pageserver tries to run
`IDENTIFY_SYSTEM` on a timeline that has been deleted on the safekeeper.

Found when analyzing the failure of
`test_scrubber_physical_gc_timeline_deletion` when enabling
`--timelines-onto-safekeepers` on the pytests.

Due to safekeeper restarts, there is no hard guarantee that we will keep
issuing this error, so we need to think of something better if we start
encountering this in staging/prod. But I would say that the introduction
of `--timelines-onto-safekeepers` in the pytests and into staging won't
change much about this: we are already deleting timelines from there. In
`test_scrubber_physical_gc_timeline_deletion`, we'd just be leaking the
timeline before on the safekeepers.

Part of #11712
2025-06-06 11:54:07 +00:00
Mikhail
470c7d5e0e endpoint_storage: default listen port, allow inline config (#12152)
Related: https://github.com/neondatabase/cloud/issues/27195
2025-06-06 11:48:01 +00:00
Conrad Ludgate
4d99b6ff4d [proxy] separate compute connect from compute authentication (#12145)
## Problem

PGLB/Neonkeeper needs to separate the concerns of connecting to compute,
and authenticating to compute.

Additionally, the code within `connect_to_compute` is rather messy,
spending effort on recovering the authentication info after
wake_compute.

## Summary of changes

Split `ConnCfg` into `ConnectInfo` and `AuthInfo`. `wake_compute` only
returns `ConnectInfo` and `AuthInfo` is determined separately from the
`handshake`/`authenticate` process.

Additionally, `ConnectInfo::connect_raw` is in-charge or establishing
the TLS connection, and the `postgres_client::Config::connect_raw` is
configured to use `NoTls` which will force it to skip the TLS
negotiation. This should just work.
2025-06-06 10:29:55 +00:00
Alexander Sarantcev
590301df08 storcon: Introduce deletion tombstones to support flaky node scenario (#12096)
## Problem

Removed nodes can re-add themselves on restart if not properly
tombstoned. We need a mechanism (e.g. soft-delete flag) to prevent this,
especially in cases where the node is unreachable.

More details there: #12036

## Summary of changes

- Introduced `NodeLifecycle` enum to represent node lifecycle states.
- Added a string representation of `NodeLifecycle` to the `nodes` table.
- Implemented node removal using a tombstone mechanism.
- Introduced `/debug/v1/tombstone*` handlers to manage the tombstone
state.
2025-06-06 10:16:55 +00:00
Erik Grinaker
c511786548 pageserver: move spawn_grpc to GrpcPageServiceHandler::spawn (#12147)
Mechanical move, no logic changes.
2025-06-06 10:01:58 +00:00
Alex Chi Z.
fe31baf985 feat(build): add aws cli into the docker image (#12161)
## Problem

Makes it easier to debug AWS permission issues (i.e., storage scrubber)

## Summary of changes

Install awscliv2 into the docker image.

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-06-06 09:38:58 +00:00
Alex Chi Z.
b23e75ebfe test(pageserver): ensure offload cleans up metrics (#12127)
Add a test to ensure timeline metrics are fully cleaned up after
offloading.

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-06-06 06:50:54 +00:00
Arpad Müller
24d7c37e6e neon_local timeline import: create timelines on safekeepers (#12138)
neon_local's timeline import subcommand creates timelines manually, but
doesn't create them on the safekeepers. If a test then tries to open an
endpoint to read from the timeline, it will error in the new world with
`--timelines-onto-safekeepers`.

Therefore, if that flag is enabled, create the timelines on the
safekeepers.

Note that this import functionality is different from the fast import
feature (https://github.com/neondatabase/neon/issues/10188, #11801).

Part of #11670
As well as part of #11712
2025-06-05 18:53:14 +00:00
a-masterov
f64eb0cbaf Remove the Flaky Test computed-columns from postgis v16 (#12132)
## Problem
The `computed_columns` test assumes that computed columns are always
faster than the request itself. However, this is not always the case on
Neon, which can lead to flaky results.
## Summary of changes
The `computed_columns` test is excluded from the PostGIS test for
PostgreSQL v16, accompanied by related patch refactoring.
2025-06-05 15:02:38 +00:00
180 changed files with 3519 additions and 15348 deletions

1
.gitignore vendored
View File

@@ -13,7 +13,6 @@ neon.iml
/.neon
/integration_tests/.neon
compaction-suite-results.*
pgxn/neon/communicator/communicator_bindings.h
# Coverage
*.profraw

394
Cargo.lock generated
View File

@@ -253,17 +253,6 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8ab6b55fe97976e46f91ddbed8d147d966475dc29b2032757ba47e02376fbc3"
[[package]]
name = "atomic_enum"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99e1aca718ea7b89985790c94aad72d77533063fe00bc497bb79a7c2dae6a661"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.100",
]
[[package]]
name = "autocfg"
version = "1.1.0"
@@ -698,40 +687,13 @@ dependencies = [
"tracing",
]
[[package]]
name = "axum"
version = "0.7.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f"
dependencies = [
"async-trait",
"axum-core 0.4.5",
"bytes",
"futures-util",
"http 1.1.0",
"http-body 1.0.0",
"http-body-util",
"itoa",
"matchit 0.7.3",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"rustversion",
"serde",
"sync_wrapper 1.0.1",
"tower 0.5.2",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d6fd624c75e18b3b4c6b9caf42b1afe24437daaee904069137d8bab077be8b8"
dependencies = [
"axum-core 0.5.0",
"axum-core",
"base64 0.22.1",
"bytes",
"form_urlencoded",
@@ -739,10 +701,10 @@ dependencies = [
"http 1.1.0",
"http-body 1.0.0",
"http-body-util",
"hyper 1.6.0",
"hyper 1.4.1",
"hyper-util",
"itoa",
"matchit 0.8.4",
"matchit",
"memchr",
"mime",
"percent-encoding",
@@ -762,26 +724,6 @@ dependencies = [
"tracing",
]
[[package]]
name = "axum-core"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199"
dependencies = [
"async-trait",
"bytes",
"futures-util",
"http 1.1.0",
"http-body 1.0.0",
"http-body-util",
"mime",
"pin-project-lite",
"rustversion",
"sync_wrapper 1.0.1",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-core"
version = "0.5.0"
@@ -808,9 +750,10 @@ version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "460fc6f625a1f7705c6cf62d0d070794e94668988b1c38111baeec177c715f7b"
dependencies = [
"axum 0.8.1",
"axum-core 0.5.0",
"axum",
"axum-core",
"bytes",
"form_urlencoded",
"futures-util",
"headers",
"http 1.1.0",
@@ -819,6 +762,8 @@ dependencies = [
"mime",
"pin-project-lite",
"serde",
"serde_html_form",
"serde_path_to_error",
"tower 0.5.2",
"tower-layer",
"tower-service",
@@ -1144,25 +1089,6 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cbindgen"
version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eadd868a2ce9ca38de7eeafdcec9c7065ef89b42b32f0839278d55f35c54d1ff"
dependencies = [
"clap",
"heck 0.4.1",
"indexmap 2.9.0",
"log",
"proc-macro2",
"quote",
"serde",
"serde_json",
"syn 2.0.100",
"tempfile",
"toml",
]
[[package]]
name = "cc"
version = "1.2.16"
@@ -1289,7 +1215,7 @@ version = "4.5.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab"
dependencies = [
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"syn 2.0.100",
@@ -1347,34 +1273,6 @@ dependencies = [
"unicode-width",
]
[[package]]
name = "communicator"
version = "0.1.0"
dependencies = [
"atomic_enum",
"axum 0.8.1",
"bytes",
"cbindgen",
"clashmap",
"http 1.1.0",
"libc",
"metrics",
"neon-shmem",
"nix 0.30.1",
"pageserver_client_grpc",
"pageserver_page_api",
"prometheus",
"prost 0.13.5",
"thiserror 1.0.69",
"tokio",
"tokio-pipe",
"tonic 0.12.3",
"tracing",
"tracing-subscriber",
"uring-common",
"utils",
]
[[package]]
name = "compute_api"
version = "0.1.0"
@@ -1400,7 +1298,7 @@ dependencies = [
"aws-sdk-kms",
"aws-sdk-s3",
"aws-smithy-types",
"axum 0.8.1",
"axum",
"axum-extra",
"base64 0.13.1",
"bytes",
@@ -1424,7 +1322,6 @@ dependencies = [
"opentelemetry",
"opentelemetry_sdk",
"p256 0.13.2",
"pageserver_page_api",
"postgres",
"postgres_initdb",
"regex",
@@ -1443,7 +1340,6 @@ dependencies = [
"tokio-postgres",
"tokio-stream",
"tokio-util",
"tonic 0.13.1",
"tower 0.5.2",
"tower-http",
"tower-otel",
@@ -1552,6 +1448,7 @@ dependencies = [
"regex",
"reqwest",
"safekeeper_api",
"safekeeper_client",
"scopeguard",
"serde",
"serde_json",
@@ -1701,9 +1598,9 @@ dependencies = [
[[package]]
name = "crossbeam-utils"
version = "0.8.21"
version = "0.8.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345"
[[package]]
name = "crossterm"
@@ -2043,7 +1940,7 @@ checksum = "0892a17df262a24294c382f0d5997571006e7a4348b4327557c4ff1cd4a8bccc"
dependencies = [
"darling",
"either",
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"syn 2.0.100",
@@ -2157,10 +2054,11 @@ name = "endpoint_storage"
version = "0.0.1"
dependencies = [
"anyhow",
"axum 0.8.1",
"axum",
"axum-extra",
"camino",
"camino-tempfile",
"clap",
"futures",
"http-body-util",
"itertools 0.10.5",
@@ -2437,7 +2335,7 @@ dependencies = [
"futures-core",
"futures-sink",
"http-body-util",
"hyper 1.6.0",
"hyper 1.4.1",
"hyper-util",
"pin-project",
"rand 0.8.5",
@@ -2607,18 +2505,6 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "getrandom"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73fea8450eea4bac3940448fb7ae50d91f034f941199fcd9d909a5a07aa455f0"
dependencies = [
"cfg-if",
"libc",
"r-efi",
"wasi 0.14.2+wasi-0.2.4",
]
[[package]]
name = "gettid"
version = "0.1.3"
@@ -2831,12 +2717,6 @@ dependencies = [
"http 1.1.0",
]
[[package]]
name = "heck"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]]
name = "heck"
version = "0.5.0"
@@ -3008,9 +2888,9 @@ dependencies = [
[[package]]
name = "httparse"
version = "1.10.1"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87"
checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904"
[[package]]
name = "httpdate"
@@ -3060,9 +2940,9 @@ dependencies = [
[[package]]
name = "hyper"
version = "1.6.0"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80"
checksum = "50dfd22e0e76d0f662d429a5f80fcaf3855009297eab6a0a9f8543834744ba05"
dependencies = [
"bytes",
"futures-channel",
@@ -3102,7 +2982,7 @@ checksum = "a0bea761b46ae2b24eb4aef630d8d1c398157b6fc29e6350ecf090a0b70c952c"
dependencies = [
"futures-util",
"http 1.1.0",
"hyper 1.6.0",
"hyper 1.4.1",
"hyper-util",
"rustls 0.22.4",
"rustls-pki-types",
@@ -3117,7 +2997,7 @@ version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3203a961e5c83b6f5498933e78b6b263e208c197b63e9c6c53cc82ffd3f63793"
dependencies = [
"hyper 1.6.0",
"hyper 1.4.1",
"hyper-util",
"pin-project-lite",
"tokio",
@@ -3126,20 +3006,20 @@ dependencies = [
[[package]]
name = "hyper-util"
version = "0.1.12"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf9f1e950e0d9d1d3c47184416723cf29c0d1f93bd8cccf37e4beb6b44f31710"
checksum = "cde7055719c54e36e95e8719f95883f22072a48ede39db7fc17a4e1d5281e9b9"
dependencies = [
"bytes",
"futures-channel",
"futures-util",
"http 1.1.0",
"http-body 1.0.0",
"hyper 1.6.0",
"libc",
"hyper 1.4.1",
"pin-project-lite",
"socket2",
"tokio",
"tower 0.4.13",
"tower-service",
"tracing",
]
@@ -3728,12 +3608,6 @@ dependencies = [
"regex-automata 0.1.10",
]
[[package]]
name = "matchit"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
[[package]]
name = "matchit"
version = "0.8.4"
@@ -3779,7 +3653,7 @@ version = "0.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e6777fc80a575f9503d908c8b498782a6c3ee88a06cb416dc3941401e43b94"
dependencies = [
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"syn 2.0.100",
@@ -3841,7 +3715,7 @@ dependencies = [
"procfs",
"prometheus",
"rand 0.8.5",
"rand_distr 0.4.3",
"rand_distr",
"twox-hash",
]
@@ -3930,25 +3804,11 @@ name = "neon-shmem"
version = "0.1.0"
dependencies = [
"nix 0.30.1",
"rand 0.9.1",
"rand_distr 0.5.1",
"spin",
"tempfile",
"thiserror 1.0.69",
"workspace_hack",
]
[[package]]
name = "neonart"
version = "0.1.0"
dependencies = [
"crossbeam-utils",
"rand 0.9.1",
"rand_distr 0.5.1",
"spin",
"tracing",
]
[[package]]
name = "never-say-never"
version = "6.6.666"
@@ -4382,19 +4242,15 @@ version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"axum 0.8.1",
"bytes",
"camino",
"clap",
"futures",
"hdrhistogram",
"http 1.1.0",
"humantime",
"humantime-serde",
"metrics",
"pageserver_api",
"pageserver_client",
"pageserver_client_grpc",
"pageserver_page_api",
"rand 0.8.5",
"reqwest",
@@ -4477,7 +4333,6 @@ dependencies = [
"pageserver_client",
"pageserver_compaction",
"pageserver_page_api",
"peekable",
"pem",
"pin-project-lite",
"postgres-protocol",
@@ -4490,7 +4345,6 @@ dependencies = [
"pprof",
"pq_proto",
"procfs",
"prost 0.13.5",
"rand 0.8.5",
"range-set-blaze",
"regex",
@@ -4590,34 +4444,6 @@ dependencies = [
"workspace_hack",
]
[[package]]
name = "pageserver_client_grpc"
version = "0.1.0"
dependencies = [
"async-trait",
"bytes",
"chrono",
"dashmap 5.5.0",
"futures",
"http 1.1.0",
"hyper 1.6.0",
"hyper-util",
"metrics",
"pageserver_api",
"pageserver_page_api",
"priority-queue",
"rand 0.8.5",
"thiserror 1.0.69",
"tokio",
"tokio-stream",
"tokio-util",
"tonic 0.13.1",
"tower 0.4.13",
"tracing",
"utils",
"uuid",
]
[[package]]
name = "pageserver_compaction"
version = "0.1.0"
@@ -4782,15 +4608,6 @@ dependencies = [
"sha2",
]
[[package]]
name = "peekable"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "225f9651e475709164f871dc2f5724956be59cb9edb055372ffeeab01ec2d20b"
dependencies = [
"smallvec",
]
[[package]]
name = "pem"
version = "3.0.3"
@@ -5202,17 +5019,6 @@ dependencies = [
"elliptic-curve 0.13.8",
]
[[package]]
name = "priority-queue"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef08705fa1589a1a59aa924ad77d14722cb0cd97b67dd5004ed5f4a4873fce8d"
dependencies = [
"autocfg",
"equivalent",
"indexmap 2.9.0",
]
[[package]]
name = "proc-macro2"
version = "1.0.94"
@@ -5291,7 +5097,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22505a5c94da8e3b7c2996394d1c933236c4d743e81a410bcca4e6989fc066a4"
dependencies = [
"bytes",
"heck 0.5.0",
"heck",
"itertools 0.12.1",
"log",
"multimap",
@@ -5312,7 +5118,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c1318b19085f08681016926435853bbf7858f9c082d0999b80550ff5d9abe15"
dependencies = [
"bytes",
"heck 0.5.0",
"heck",
"itertools 0.12.1",
"log",
"multimap",
@@ -5413,7 +5219,7 @@ dependencies = [
"humantime",
"humantime-serde",
"hyper 0.14.30",
"hyper 1.6.0",
"hyper 1.4.1",
"hyper-util",
"indexmap 2.9.0",
"ipnet",
@@ -5437,7 +5243,7 @@ dependencies = [
"postgres_backend",
"pq_proto",
"rand 0.8.5",
"rand_distr 0.4.3",
"rand_distr",
"rcgen",
"redis",
"regex",
@@ -5541,12 +5347,6 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "r-efi"
version = "5.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5"
[[package]]
name = "rand"
version = "0.7.3"
@@ -5571,16 +5371,6 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "rand"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
dependencies = [
"rand_chacha 0.9.0",
"rand_core 0.9.3",
]
[[package]]
name = "rand_chacha"
version = "0.2.2"
@@ -5601,16 +5391,6 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "rand_chacha"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core 0.9.3",
]
[[package]]
name = "rand_core"
version = "0.5.1"
@@ -5629,15 +5409,6 @@ dependencies = [
"getrandom 0.2.11",
]
[[package]]
name = "rand_core"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
dependencies = [
"getrandom 0.3.2",
]
[[package]]
name = "rand_distr"
version = "0.4.3"
@@ -5648,16 +5419,6 @@ dependencies = [
"rand 0.8.5",
]
[[package]]
name = "rand_distr"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463"
dependencies = [
"num-traits",
"rand 0.9.1",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
@@ -5854,7 +5615,7 @@ dependencies = [
"http-body-util",
"http-types",
"humantime-serde",
"hyper 1.6.0",
"hyper 1.4.1",
"itertools 0.10.5",
"metrics",
"once_cell",
@@ -5894,7 +5655,7 @@ dependencies = [
"http 1.1.0",
"http-body 1.0.0",
"http-body-util",
"hyper 1.6.0",
"hyper 1.4.1",
"hyper-rustls 0.26.0",
"hyper-util",
"ipnet",
@@ -5951,7 +5712,7 @@ dependencies = [
"futures",
"getrandom 0.2.11",
"http 1.1.0",
"hyper 1.6.0",
"hyper 1.4.1",
"parking_lot 0.11.2",
"reqwest",
"reqwest-middleware",
@@ -5972,7 +5733,7 @@ dependencies = [
"async-trait",
"getrandom 0.2.11",
"http 1.1.0",
"matchit 0.8.4",
"matchit",
"opentelemetry",
"reqwest",
"reqwest-middleware",
@@ -6664,6 +6425,19 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "serde_html_form"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d2de91cf02bbc07cde38891769ccd5d4f073d22a40683aa4bc7a95781aaa2c4"
dependencies = [
"form_urlencoded",
"indexmap 2.9.0",
"itoa",
"ryu",
"serde",
]
[[package]]
name = "serde_json"
version = "1.0.125"
@@ -6892,12 +6666,12 @@ dependencies = [
[[package]]
name = "socket2"
version = "0.5.9"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f5fd57c80058a56cf5c777ab8a126398ece8e442983605d280a44ce79d0edef"
checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9"
dependencies = [
"libc",
"windows-sys 0.52.0",
"windows-sys 0.48.0",
]
[[package]]
@@ -6905,9 +6679,6 @@ name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
dependencies = [
"lock_api",
]
[[package]]
name = "spinning_top"
@@ -6966,7 +6737,7 @@ dependencies = [
"http-body-util",
"http-utils",
"humantime",
"hyper 1.6.0",
"hyper 1.4.1",
"hyper-util",
"metrics",
"once_cell",
@@ -7147,7 +6918,7 @@ version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be"
dependencies = [
"heck 0.5.0",
"heck",
"proc-macro2",
"quote",
"rustversion",
@@ -7572,16 +7343,6 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "tokio-pipe"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f213a84bffbd61b8fa0ba8a044b4bbe35d471d0b518867181e82bd5c15542784"
dependencies = [
"libc",
"tokio",
]
[[package]]
name = "tokio-postgres"
version = "0.7.10"
@@ -7776,25 +7537,16 @@ version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "877c5b330756d856ffcc4553ab34a5684481ade925ecc54bcd1bf02b1d0d4d52"
dependencies = [
"async-stream",
"async-trait",
"axum 0.7.9",
"base64 0.22.1",
"bytes",
"h2 0.4.4",
"http 1.1.0",
"http-body 1.0.0",
"http-body-util",
"hyper 1.6.0",
"hyper-timeout",
"hyper-util",
"percent-encoding",
"pin-project",
"prost 0.13.5",
"socket2",
"tokio",
"tokio-stream",
"tower 0.4.13",
"tower-layer",
"tower-service",
"tracing",
@@ -7807,15 +7559,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e581ba15a835f4d9ea06c55ab1bd4dce26fc53752c69a04aac00703bfb49ba9"
dependencies = [
"async-trait",
"axum 0.8.1",
"axum",
"base64 0.22.1",
"bytes",
"flate2",
"h2 0.4.4",
"http 1.1.0",
"http-body 1.0.0",
"http-body-util",
"hyper 1.6.0",
"hyper 1.4.1",
"hyper-timeout",
"hyper-util",
"percent-encoding",
@@ -7867,16 +7618,11 @@ checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
dependencies = [
"futures-core",
"futures-util",
"indexmap 1.9.3",
"pin-project",
"pin-project-lite",
"rand 0.8.5",
"slab",
"tokio",
"tokio-util",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
@@ -8360,7 +8106,7 @@ name = "vm_monitor"
version = "0.1.0"
dependencies = [
"anyhow",
"axum 0.8.1",
"axum",
"cgroups-rs",
"clap",
"futures",
@@ -8471,15 +8217,6 @@ version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wasi"
version = "0.14.2+wasi-0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3"
dependencies = [
"wit-bindgen-rt",
]
[[package]]
name = "wasite"
version = "0.1.0"
@@ -8837,15 +8574,6 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "wit-bindgen-rt"
version = "0.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
dependencies = [
"bitflags 2.8.0",
]
[[package]]
name = "workspace_hack"
version = "0.1.0"
@@ -8853,8 +8581,8 @@ dependencies = [
"ahash",
"anstream",
"anyhow",
"axum 0.8.1",
"axum-core 0.5.0",
"axum",
"axum-core",
"base64 0.13.1",
"base64 0.21.7",
"base64ct",
@@ -8887,7 +8615,7 @@ dependencies = [
"hex",
"hmac",
"hyper 0.14.30",
"hyper 1.6.0",
"hyper 1.4.1",
"hyper-util",
"indexmap 2.9.0",
"itertools 0.12.1",

View File

@@ -8,7 +8,6 @@ members = [
"pageserver/compaction",
"pageserver/ctl",
"pageserver/client",
"pageserver/client_grpc",
"pageserver/pagebench",
"pageserver/page_api",
"proxy",
@@ -33,7 +32,6 @@ members = [
"libs/pq_proto",
"libs/tenant_size_model",
"libs/metrics",
"libs/neonart",
"libs/postgres_connection",
"libs/remote_storage",
"libs/tracing-utils",
@@ -46,7 +44,6 @@ members = [
"libs/proxy/postgres-types2",
"libs/proxy/tokio-postgres2",
"endpoint_storage",
"pgxn/neon/communicator",
]
[workspace.package]
@@ -74,7 +71,7 @@ aws-credential-types = "1.2.0"
aws-sigv4 = { version = "1.2", features = ["sign-http"] }
aws-types = "1.3"
axum = { version = "0.8.1", features = ["ws"] }
axum-extra = { version = "0.10.0", features = ["typed-header"] }
axum-extra = { version = "0.10.0", features = ["typed-header", "query"] }
base64 = "0.13.0"
bincode = "1.3"
bindgen = "0.71"
@@ -90,7 +87,6 @@ clap = { version = "4.0", features = ["derive", "env"] }
clashmap = { version = "1.0", features = ["raw-api"] }
comfy-table = "7.1"
const_format = "0.2"
crossbeam-utils = "0.8.21"
crc32c = "0.6"
diatomic-waker = { version = "0.2.3" }
either = "1.8"
@@ -149,7 +145,6 @@ parquet = { version = "53", default-features = false, features = ["zstd"] }
parquet_derive = "53"
pbkdf2 = { version = "0.12.1", features = ["simple", "std"] }
pem = "3.0.3"
peekable = "0.3.0"
pin-project-lite = "0.2"
pprof = { version = "0.14", features = ["criterion", "flamegraph", "frame-pointer", "prost-codec"] }
procfs = "0.16"
@@ -184,7 +179,6 @@ smallvec = "1.11"
smol_str = { version = "0.2.0", features = ["serde"] }
socket2 = "0.5"
spki = "0.7.3"
spin = "0.9.8"
strum = "0.26"
strum_macros = "0.26"
"subtle" = "2.5.0"
@@ -196,15 +190,16 @@ thiserror = "1.0"
tikv-jemallocator = { version = "0.6", features = ["profiling", "stats", "unprefixed_malloc_on_supported_platforms"] }
tikv-jemalloc-ctl = { version = "0.6", features = ["stats"] }
tokio = { version = "1.43.1", features = ["macros"] }
tokio-epoll-uring = { git = "https://github.com/neondatabase/tokio-epoll-uring.git" , branch = "main" }
tokio-io-timeout = "1.2.0"
tokio-postgres-rustls = "0.12.0"
tokio-rustls = { version = "0.26.0", default-features = false, features = ["tls12", "ring"]}
tokio-stream = "0.1"
tokio-tar = "0.3"
tokio-util = { version = "0.7.10", features = ["io", "io-util", "rt"] }
tokio-util = { version = "0.7.10", features = ["io", "rt"] }
toml = "0.8"
toml_edit = "0.22"
tonic = { version = "0.13.1", default-features = false, features = ["channel", "codegen", "gzip", "prost", "router", "server", "tls-ring", "tls-native-roots"] }
tonic = { version = "0.13.1", default-features = false, features = ["channel", "codegen", "prost", "router", "server", "tls-ring", "tls-native-roots"] }
tonic-reflection = { version = "0.13.1", features = ["server"] }
tower = { version = "0.5.2", default-features = false }
tower-http = { version = "0.6.2", features = ["auth", "request-id", "trace"] }
@@ -237,9 +232,6 @@ x509-cert = { version = "0.2.5" }
env_logger = "0.11"
log = "0.4"
tokio-epoll-uring = { git = "https://github.com/neondatabase/tokio-epoll-uring.git" , branch = "main" }
uring-common = { git = "https://github.com/neondatabase/tokio-epoll-uring.git" , branch = "main" }
## Libraries from neondatabase/ git forks, ideally with changes to be upstreamed
postgres = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon" }
postgres-protocol = { git = "https://github.com/neondatabase/rust-postgres.git", branch = "neon" }
@@ -259,12 +251,9 @@ desim = { version = "0.1", path = "./libs/desim" }
endpoint_storage = { version = "0.0.1", path = "./endpoint_storage/" }
http-utils = { version = "0.1", path = "./libs/http-utils/" }
metrics = { version = "0.1", path = "./libs/metrics/" }
neonart = { version = "0.1", path = "./libs/neonart/" }
neon-shmem = { version = "0.1", path = "./libs/neon-shmem/" }
pageserver = { path = "./pageserver" }
pageserver_api = { version = "0.1", path = "./libs/pageserver_api/" }
pageserver_client = { path = "./pageserver/client" }
pageserver_client_grpc = { path = "./pageserver/client_grpc" }
pageserver_compaction = { version = "0.1", path = "./pageserver/compaction/" }
pageserver_page_api = { path = "./pageserver/page_api" }
postgres_backend = { version = "0.1", path = "./libs/postgres_backend/" }
@@ -289,7 +278,6 @@ walproposer = { version = "0.1", path = "./libs/walproposer/" }
workspace_hack = { version = "0.1", path = "./workspace_hack/" }
## Build dependencies
cbindgen = "0.28.0"
criterion = "0.5.1"
rcgen = "0.13"
rstest = "0.18"

View File

@@ -110,6 +110,19 @@ RUN set -e \
# System postgres for use with client libraries (e.g. in storage controller)
postgresql-15 \
openssl \
unzip \
curl \
&& ARCH=$(uname -m) \
&& if [ "$ARCH" = "x86_64" ]; then \
curl "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"; \
elif [ "$ARCH" = "aarch64" ]; then \
curl "https://awscli.amazonaws.com/awscli-exe-linux-aarch64.zip" -o "awscliv2.zip"; \
else \
echo "Unsupported architecture: $ARCH" && exit 1; \
fi \
&& unzip awscliv2.zip \
&& ./aws/install \
&& rm -rf aws awscliv2.zip \
&& rm -f /etc/apt/apt.conf.d/80-retries \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* \
&& useradd -d /data neon \

View File

@@ -18,12 +18,10 @@ ifeq ($(BUILD_TYPE),release)
PG_LDFLAGS = $(LDFLAGS)
# Unfortunately, `--profile=...` is a nightly feature
CARGO_BUILD_FLAGS += --release
NEON_CARGO_ARTIFACT_TARGET_DIR = $(ROOT_PROJECT_DIR)/target/release
else ifeq ($(BUILD_TYPE),debug)
PG_CONFIGURE_OPTS = --enable-debug --with-openssl --enable-cassert --enable-depend
PG_CFLAGS += -O0 -g3 $(CFLAGS)
PG_LDFLAGS = $(LDFLAGS)
NEON_CARGO_ARTIFACT_TARGET_DIR = $(ROOT_PROJECT_DIR)/target/debug
else
$(error Bad build type '$(BUILD_TYPE)', see Makefile for options)
endif
@@ -182,16 +180,11 @@ postgres-check-%: postgres-%
.PHONY: neon-pg-ext-%
neon-pg-ext-%: postgres-%
+@echo "Compiling communicator $*"
$(CARGO_CMD_PREFIX) cargo build -p communicator $(CARGO_BUILD_FLAGS)
+@echo "Compiling neon $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \
LIBCOMMUNICATOR_PATH=$(NEON_CARGO_ARTIFACT_TARGET_DIR) \
-C $(POSTGRES_INSTALL_DIR)/build/neon-$* \
-f $(ROOT_PROJECT_DIR)/pgxn/neon/Makefile install
+@echo "Compiling neon_walredo $*"
mkdir -p $(POSTGRES_INSTALL_DIR)/build/neon-walredo-$*
$(MAKE) PG_CONFIG=$(POSTGRES_INSTALL_DIR)/$*/bin/pg_config COPT='$(COPT)' \

View File

@@ -38,7 +38,6 @@ once_cell.workspace = true
opentelemetry.workspace = true
opentelemetry_sdk.workspace = true
p256 = { version = "0.13", features = ["pem"] }
pageserver_page_api.workspace = true
postgres.workspace = true
regex.workspace = true
reqwest = { workspace = true, features = ["json"] }
@@ -54,7 +53,6 @@ tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
tokio-postgres.workspace = true
tokio-util.workspace = true
tokio-stream.workspace = true
tonic.workspace = true
tower-otel.workspace = true
tracing.workspace = true
tracing-opentelemetry.workspace = true

View File

@@ -1,4 +1,4 @@
use anyhow::{Context, Result, anyhow};
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use compute_api::privilege::Privilege;
use compute_api::responses::{
@@ -15,7 +15,6 @@ use itertools::Itertools;
use nix::sys::signal::{Signal, kill};
use nix::unistd::Pid;
use once_cell::sync::Lazy;
use pageserver_page_api as page_api;
use postgres;
use postgres::NoTls;
use postgres::error::SqlState;
@@ -30,9 +29,7 @@ use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Condvar, Mutex, RwLock};
use std::time::{Duration, Instant};
use std::{env, fs};
use tokio::io::AsyncReadExt;
use tokio::spawn;
use tokio_util::io::StreamReader;
use tracing::{Instrument, debug, error, info, instrument, warn};
use url::Url;
use utils::id::{TenantId, TimelineId};
@@ -372,7 +369,7 @@ impl ComputeNode {
let mut new_state = ComputeState::new();
if let Some(spec) = config.spec {
let pspec = ParsedSpec::try_from(spec).map_err(|msg| anyhow!(msg))?;
let pspec = ParsedSpec::try_from(spec).map_err(|msg| anyhow::anyhow!(msg))?;
new_state.pspec = Some(pspec);
}
@@ -788,7 +785,7 @@ impl ComputeNode {
self.spawn_extension_stats_task();
if pspec.spec.autoprewarm {
self.prewarm_lfc();
self.prewarm_lfc(None);
}
Ok(())
}
@@ -944,74 +941,6 @@ impl ComputeNode {
#[instrument(skip_all, fields(%lsn))]
fn try_get_basebackup(&self, compute_state: &ComputeState, lsn: Lsn) -> Result<()> {
let spec = compute_state.pspec.as_ref().expect("spec must be set");
let shard0_connstr = spec.pageserver_connstr.split(',').next().unwrap();
match Url::parse(shard0_connstr)?.scheme() {
"postgres" | "postgresql" => self.try_get_basebackup_libpq(spec, lsn),
"grpc" => self.try_get_basebackup_grpc(spec, lsn),
scheme => return Err(anyhow!("unknown URL scheme {scheme}")),
}
}
fn try_get_basebackup_grpc(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<()> {
let start_time = Instant::now();
let shard0_connstr = spec
.pageserver_connstr
.split(',')
.next()
.unwrap()
.to_string();
let chunks = tokio::runtime::Handle::current().block_on(async move {
let mut client = page_api::proto::PageServiceClient::connect(shard0_connstr).await?;
let req = page_api::proto::GetBaseBackupRequest {
lsn: lsn.0,
replica: false, // TODO: handle replicas, with LSN 0
};
let mut req = tonic::Request::new(req);
let metadata = req.metadata_mut();
metadata.insert("neon-tenant-id", spec.tenant_id.to_string().parse()?);
metadata.insert("neon-timeline-id", spec.timeline_id.to_string().parse()?);
metadata.insert("neon-shard-id", "0000".to_string().parse()?); // TODO: shard count
if let Some(auth) = spec.storage_auth_token.as_ref() {
metadata.insert("authorization", format!("Bearer {auth}").parse()?);
}
let chunks = client.get_base_backup(req).await?.into_inner();
anyhow::Ok(chunks)
})?;
let pageserver_connect_micros = start_time.elapsed().as_micros() as u64;
// Convert the chunks stream into an AsyncRead
let stream_reader = StreamReader::new(
chunks.map(|chunk| chunk.map(|c| c.chunk).map_err(std::io::Error::other)),
);
// Wrap the AsyncRead into a blocking reader for compatibility with tar::Archive
let reader = tokio_util::io::SyncIoBridge::new(stream_reader);
let mut measured_reader = MeasuredReader::new(reader);
let mut bufreader = std::io::BufReader::new(&mut measured_reader);
// Read the archive directly from the `CopyOutReader`
//
// Set `ignore_zeros` so that unpack() reads all the Copy data and
// doesn't stop at the end-of-archive marker. Otherwise, if the server
// sends an Error after finishing the tarball, we will not notice it.
let mut ar = tar::Archive::new(&mut bufreader);
ar.set_ignore_zeros(true);
ar.unpack(&self.params.pgdata)?;
// Report metrics
let mut state = self.state.lock().unwrap();
state.metrics.pageserver_connect_micros = pageserver_connect_micros;
state.metrics.basebackup_bytes = measured_reader.get_byte_count() as u64;
state.metrics.basebackup_ms = start_time.elapsed().as_millis() as u64;
Ok(())
}
fn try_get_basebackup_libpq(&self, spec: &ParsedSpec, lsn: Lsn) -> Result<()> {
let start_time = Instant::now();
let shard0_connstr = spec.pageserver_connstr.split(',').next().unwrap();
@@ -1027,10 +956,12 @@ impl ComputeNode {
}
config.application_name("compute_ctl");
config.options(&format!(
"-c neon.compute_mode={}",
spec.spec.mode.to_type_str()
));
if let Some(spec) = &compute_state.pspec {
config.options(&format!(
"-c neon.compute_mode={}",
spec.spec.mode.to_type_str()
));
}
// Connect to pageserver
let mut client = config.connect(NoTls)?;
@@ -1104,7 +1035,10 @@ impl ComputeNode {
return result;
}
Err(ref e) if attempts < max_attempts => {
warn!("Failed to get basebackup: {e:?} (attempt {attempts}/{max_attempts})");
warn!(
"Failed to get basebackup: {} (attempt {}/{})",
e, attempts, max_attempts
);
std::thread::sleep(std::time::Duration::from_millis(retry_period_ms as u64));
retry_period_ms *= 1.5;
}
@@ -1982,7 +1916,7 @@ LIMIT 100",
self.params
.remote_ext_base_url
.as_ref()
.ok_or(DownloadError::BadInput(anyhow!(
.ok_or(DownloadError::BadInput(anyhow::anyhow!(
"Remote extensions storage is not configured",
)))?;
@@ -2178,7 +2112,7 @@ LIMIT 100",
let remote_extensions = spec
.remote_extensions
.as_ref()
.ok_or(anyhow!("Remote extensions are not configured"))?;
.ok_or(anyhow::anyhow!("Remote extensions are not configured"))?;
info!("parse shared_preload_libraries from spec.cluster.settings");
let mut libs_vec = Vec::new();

View File

@@ -25,11 +25,16 @@ struct EndpointStoragePair {
}
const KEY: &str = "lfc_state";
impl TryFrom<&crate::compute::ParsedSpec> for EndpointStoragePair {
type Error = anyhow::Error;
fn try_from(pspec: &crate::compute::ParsedSpec) -> Result<Self, Self::Error> {
let Some(ref endpoint_id) = pspec.spec.endpoint_id else {
bail!("pspec.endpoint_id missing")
impl EndpointStoragePair {
/// endpoint_id is set to None while prewarming from other endpoint, see replica promotion
/// If not None, takes precedence over pspec.spec.endpoint_id
fn from_spec_and_endpoint(
pspec: &crate::compute::ParsedSpec,
endpoint_id: Option<String>,
) -> Result<Self> {
let endpoint_id = endpoint_id.as_ref().or(pspec.spec.endpoint_id.as_ref());
let Some(ref endpoint_id) = endpoint_id else {
bail!("pspec.endpoint_id missing, other endpoint_id not provided")
};
let Some(ref base_uri) = pspec.endpoint_storage_addr else {
bail!("pspec.endpoint_storage_addr missing")
@@ -84,7 +89,7 @@ impl ComputeNode {
}
/// Returns false if there is a prewarm request ongoing, true otherwise
pub fn prewarm_lfc(self: &Arc<Self>) -> bool {
pub fn prewarm_lfc(self: &Arc<Self>, from_endpoint: Option<String>) -> bool {
crate::metrics::LFC_PREWARM_REQUESTS.inc();
{
let state = &mut self.state.lock().unwrap().lfc_prewarm_state;
@@ -97,7 +102,7 @@ impl ComputeNode {
let cloned = self.clone();
spawn(async move {
let Err(err) = cloned.prewarm_impl().await else {
let Err(err) = cloned.prewarm_impl(from_endpoint).await else {
cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Completed;
return;
};
@@ -109,13 +114,14 @@ impl ComputeNode {
true
}
fn endpoint_storage_pair(&self) -> Result<EndpointStoragePair> {
/// from_endpoint: None for endpoint managed by this compute_ctl
fn endpoint_storage_pair(&self, from_endpoint: Option<String>) -> Result<EndpointStoragePair> {
let state = self.state.lock().unwrap();
state.pspec.as_ref().unwrap().try_into()
EndpointStoragePair::from_spec_and_endpoint(state.pspec.as_ref().unwrap(), from_endpoint)
}
async fn prewarm_impl(&self) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair()?;
async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(from_endpoint)?;
info!(%url, "requesting LFC state from endpoint storage");
let request = Client::new().get(&url).bearer_auth(token);
@@ -173,7 +179,7 @@ impl ComputeNode {
}
async fn offload_lfc_impl(&self) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair()?;
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?;
info!(%url, "requesting LFC state from postgres");
let mut compressed = Vec::new();

View File

@@ -2,6 +2,7 @@ use crate::compute_prewarm::LfcPrewarmStateWithProgress;
use crate::http::JsonResponse;
use axum::response::{IntoResponse, Response};
use axum::{Json, http::StatusCode};
use axum_extra::extract::OptionalQuery;
use compute_api::responses::LfcOffloadState;
type Compute = axum::extract::State<std::sync::Arc<crate::compute::ComputeNode>>;
@@ -16,8 +17,16 @@ pub(in crate::http) async fn offload_state(compute: Compute) -> Json<LfcOffloadS
Json(compute.lfc_offload_state())
}
pub(in crate::http) async fn prewarm(compute: Compute) -> Response {
if compute.prewarm_lfc() {
#[derive(serde::Deserialize)]
pub struct PrewarmQuery {
pub from_endpoint: String,
}
pub(in crate::http) async fn prewarm(
compute: Compute,
OptionalQuery(query): OptionalQuery<PrewarmQuery>,
) -> Response {
if compute.prewarm_lfc(query.map(|q| q.from_endpoint)) {
StatusCode::ACCEPTED.into_response()
} else {
JsonResponse::error(

View File

@@ -36,6 +36,7 @@ pageserver_api.workspace = true
pageserver_client.workspace = true
postgres_backend.workspace = true
safekeeper_api.workspace = true
safekeeper_client.workspace = true
postgres_connection.workspace = true
storage_broker.workspace = true
http-utils.workspace = true

View File

@@ -18,7 +18,7 @@ use clap::Parser;
use compute_api::requests::ComputeClaimsScope;
use compute_api::spec::ComputeMode;
use control_plane::broker::StorageBroker;
use control_plane::endpoint::{ComputeControlPlane, PageserverProtocol};
use control_plane::endpoint::ComputeControlPlane;
use control_plane::endpoint_storage::{ENDPOINT_STORAGE_DEFAULT_ADDR, EndpointStorage};
use control_plane::local_env;
use control_plane::local_env::{
@@ -45,7 +45,7 @@ use pageserver_api::models::{
use pageserver_api::shard::{DEFAULT_STRIPE_SIZE, ShardCount, ShardStripeSize, TenantShardId};
use postgres_backend::AuthType;
use postgres_connection::parse_host_port;
use safekeeper_api::membership::SafekeeperGeneration;
use safekeeper_api::membership::{SafekeeperGeneration, SafekeeperId};
use safekeeper_api::{
DEFAULT_HTTP_LISTEN_PORT as DEFAULT_SAFEKEEPER_HTTP_PORT,
DEFAULT_PG_LISTEN_PORT as DEFAULT_SAFEKEEPER_PG_PORT,
@@ -664,10 +664,6 @@ struct EndpointStartCmdArgs {
#[clap(short = 't', long, value_parser= humantime::parse_duration, help = "timeout until we fail the command")]
#[arg(default_value = "90s")]
start_timeout: Duration,
/// If enabled, use gRPC (and the communicator) to talk to Pageservers.
#[clap(long)]
grpc: bool,
}
#[derive(clap::Args)]
@@ -686,10 +682,6 @@ struct EndpointReconfigureCmdArgs {
#[clap(long)]
safekeepers: Option<String>,
/// If enabled, use gRPC (and communicator) to talk to Pageservers.
#[clap(long)]
grpc: bool,
}
#[derive(clap::Args)]
@@ -1263,6 +1255,45 @@ async fn handle_timeline(cmd: &TimelineCmd, env: &mut local_env::LocalEnv) -> Re
pageserver
.timeline_import(tenant_id, timeline_id, base, pg_wal, args.pg_version)
.await?;
if env.storage_controller.timelines_onto_safekeepers {
println!("Creating timeline on safekeeper ...");
let timeline_info = pageserver
.timeline_info(
TenantShardId::unsharded(tenant_id),
timeline_id,
pageserver_client::mgmt_api::ForceAwaitLogicalSize::No,
)
.await?;
let default_sk = SafekeeperNode::from_env(env, env.safekeepers.first().unwrap());
let default_host = default_sk
.conf
.listen_addr
.clone()
.unwrap_or_else(|| "localhost".to_string());
let mconf = safekeeper_api::membership::Configuration {
generation: SafekeeperGeneration::new(1),
members: safekeeper_api::membership::MemberSet {
m: vec![SafekeeperId {
host: default_host,
id: default_sk.conf.id,
pg_port: default_sk.conf.pg_port,
}],
},
new_members: None,
};
let pg_version = args.pg_version * 10000;
let req = safekeeper_api::models::TimelineCreateRequest {
tenant_id,
timeline_id,
mconf,
pg_version,
system_id: None,
wal_seg_size: None,
start_lsn: timeline_info.last_record_lsn,
commit_lsn: None,
};
default_sk.create_timeline(&req).await?;
}
env.register_branch_mapping(branch_name.to_string(), tenant_id, timeline_id)?;
println!("Done");
}
@@ -1460,18 +1491,13 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
let (pageservers, stripe_size) = if let Some(pageserver_id) = pageserver_id {
let conf = env.get_pageserver_conf(pageserver_id).unwrap();
// Use gRPC if requested.
let (protocol, host, port) = if args.grpc {
let grpc_addr = conf.listen_grpc_addr.as_ref().expect("bad config");
let (host, port) = parse_host_port(grpc_addr).expect("bad config");
(PageserverProtocol::Grpc, host, port.unwrap_or(51051))
} else {
let (host, port) = parse_host_port(&conf.listen_pg_addr).expect("bad config");
(PageserverProtocol::Libpq, host, port.unwrap_or(5432))
};
// If caller is telling us what pageserver to use, this is not a tenant which is
// fully managed by storage controller, therefore not sharded.
(vec![(protocol, host, port)], DEFAULT_STRIPE_SIZE)
let parsed = parse_host_port(&conf.listen_pg_addr).expect("Bad config");
(
vec![(parsed.0, parsed.1.unwrap_or(5432))],
// If caller is telling us what pageserver to use, this is not a tenant which is
// full managed by storage controller, therefore not sharded.
DEFAULT_STRIPE_SIZE,
)
} else {
// Look up the currently attached location of the tenant, and its striping metadata,
// to pass these on to postgres.
@@ -1490,22 +1516,11 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
.await?;
}
let pageserver = if args.grpc {
(
PageserverProtocol::Grpc,
Host::parse(&shard.listen_grpc_addr.expect("no gRPC addr"))
.expect("bad hostname"),
shard.listen_grpc_port.expect("no gRPC port"),
)
} else {
(
PageserverProtocol::Libpq,
Host::parse(&shard.listen_pg_addr).expect("bad hostname"),
shard.listen_pg_port,
)
};
anyhow::Ok(pageserver)
anyhow::Ok((
Host::parse(&shard.listen_pg_addr)
.expect("Storage controller reported bad hostname"),
shard.listen_pg_port,
))
}),
)
.await?;
@@ -1560,17 +1575,11 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
.get(endpoint_id.as_str())
.with_context(|| format!("postgres endpoint {endpoint_id} is not found"))?;
let pageservers = if let Some(ps_id) = args.endpoint_pageserver_id {
let conf = env.get_pageserver_conf(ps_id)?;
// Use gRPC if requested.
let (protocol, host, port) = if args.grpc {
let grpc_addr = conf.listen_grpc_addr.as_ref().expect("bad config");
let (host, port) = parse_host_port(grpc_addr).expect("bad config");
(PageserverProtocol::Grpc, host, port.unwrap_or(51051))
} else {
let (host, port) = parse_host_port(&conf.listen_pg_addr).expect("bad config");
(PageserverProtocol::Libpq, host, port.unwrap_or(5432))
};
vec![(protocol, host, port)]
let pageserver = PageServerNode::from_env(env, env.get_pageserver_conf(ps_id)?);
vec![(
pageserver.pg_connection_config.host().clone(),
pageserver.pg_connection_config.port(),
)]
} else {
let storage_controller = StorageController::from_env(env);
storage_controller
@@ -1579,20 +1588,11 @@ async fn handle_endpoint(subcmd: &EndpointCmd, env: &local_env::LocalEnv) -> Res
.shards
.into_iter()
.map(|shard| {
if args.grpc {
(
PageserverProtocol::Grpc,
Host::parse(&shard.listen_grpc_addr.expect("no gRPC addr"))
.expect("bad hostname"),
shard.listen_grpc_port.expect("no gRPC port"),
)
} else {
(
PageserverProtocol::Libpq,
Host::parse(&shard.listen_pg_addr).expect("bad hostname"),
shard.listen_pg_port,
)
}
(
Host::parse(&shard.listen_pg_addr)
.expect("Storage controller reported malformed host"),
shard.listen_pg_port,
)
})
.collect::<Vec<_>>()
};

View File

@@ -37,7 +37,6 @@
//! ```
//!
use std::collections::BTreeMap;
use std::fmt::Display;
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpStream};
use std::path::PathBuf;
use std::process::Command;
@@ -75,6 +74,7 @@ use utils::id::{NodeId, TenantId, TimelineId};
use crate::local_env::LocalEnv;
use crate::postgresql_conf::PostgresConf;
use crate::storage_controller::StorageController;
// contents of a endpoint.json file
#[derive(Serialize, Deserialize, PartialEq, Eq, Clone, Debug)]
@@ -331,7 +331,7 @@ pub enum EndpointStatus {
RunningNoPidfile,
}
impl Display for EndpointStatus {
impl std::fmt::Display for EndpointStatus {
fn fmt(&self, writer: &mut std::fmt::Formatter) -> std::fmt::Result {
let s = match self {
Self::Running => "running",
@@ -343,28 +343,6 @@ impl Display for EndpointStatus {
}
}
#[derive(Clone, Copy, Debug)]
pub enum PageserverProtocol {
Libpq,
Grpc,
}
impl PageserverProtocol {
/// Returns the URL scheme for the protocol, used in connstrings.
pub fn scheme(&self) -> &'static str {
match self {
Self::Libpq => "postgresql",
Self::Grpc => "grpc",
}
}
}
impl Display for PageserverProtocol {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.scheme())
}
}
impl Endpoint {
fn from_dir_entry(entry: std::fs::DirEntry, env: &LocalEnv) -> Result<Endpoint> {
if !entry.file_type()?.is_dir() {
@@ -628,10 +606,10 @@ impl Endpoint {
}
}
fn build_pageserver_connstr(pageservers: &[(PageserverProtocol, Host, u16)]) -> String {
fn build_pageserver_connstr(pageservers: &[(Host, u16)]) -> String {
pageservers
.iter()
.map(|(scheme, host, port)| format!("{scheme}://no_user@{host}:{port}"))
.map(|(host, port)| format!("postgresql://no_user@{host}:{port}"))
.collect::<Vec<_>>()
.join(",")
}
@@ -676,7 +654,7 @@ impl Endpoint {
endpoint_storage_addr: String,
safekeepers_generation: Option<SafekeeperGeneration>,
safekeepers: Vec<NodeId>,
pageservers: Vec<(PageserverProtocol, Host, u16)>,
pageservers: Vec<(Host, u16)>,
remote_ext_base_url: Option<&String>,
shard_stripe_size: usize,
create_test_user: bool,
@@ -961,12 +939,10 @@ impl Endpoint {
pub async fn reconfigure(
&self,
pageservers: Vec<(PageserverProtocol, Host, u16)>,
mut pageservers: Vec<(Host, u16)>,
stripe_size: Option<ShardStripeSize>,
safekeepers: Option<Vec<NodeId>>,
) -> Result<()> {
anyhow::ensure!(!pageservers.is_empty(), "no pageservers provided");
let (mut spec, compute_ctl_config) = {
let config_path = self.endpoint_path().join("config.json");
let file = std::fs::File::open(config_path)?;
@@ -978,7 +954,25 @@ impl Endpoint {
let postgresql_conf = self.read_postgresql_conf()?;
spec.cluster.postgresql_conf = Some(postgresql_conf);
// If we weren't given explicit pageservers, query the storage controller
if pageservers.is_empty() {
let storage_controller = StorageController::from_env(&self.env);
let locate_result = storage_controller.tenant_locate(self.tenant_id).await?;
pageservers = locate_result
.shards
.into_iter()
.map(|shard| {
(
Host::parse(&shard.listen_pg_addr)
.expect("Storage controller reported bad hostname"),
shard.listen_pg_port,
)
})
.collect::<Vec<_>>();
}
let pageserver_connstr = Self::build_pageserver_connstr(&pageservers);
assert!(!pageserver_connstr.is_empty());
spec.pageserver_connstring = Some(pageserver_connstr);
if stripe_size.is_some() {
spec.shard_stripe_size = stripe_size.map(|s| s.0 as usize);

View File

@@ -265,14 +265,6 @@ impl PageServerNode {
None => None,
};
let mut grpc_host = None;
let mut grpc_port = None;
if let Some(grpc_addr) = &self.conf.listen_grpc_addr {
let (_, port) = parse_host_port(grpc_addr).expect("Unable to parse listen_grpc_addr");
grpc_host = Some("localhost".to_string());
grpc_port = Some(port.unwrap_or(51051));
}
// Intentionally hand-craft JSON: this acts as an implicit format compat test
// in case the pageserver-side structure is edited, and reflects the real life
// situation: the metadata is written by some other script.
@@ -281,8 +273,6 @@ impl PageServerNode {
serde_json::to_vec(&pageserver_api::config::NodeMetadata {
postgres_host: "localhost".to_string(),
postgres_port: self.pg_connection_config.port(),
grpc_host,
grpc_port,
http_host: "localhost".to_string(),
http_port,
https_port,
@@ -645,4 +635,16 @@ impl PageServerNode {
Ok(())
}
pub async fn timeline_info(
&self,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
force_await_logical_size: mgmt_api::ForceAwaitLogicalSize,
) -> anyhow::Result<TimelineInfo> {
let timeline_info = self
.http_client
.timeline_info(tenant_shard_id, timeline_id, force_await_logical_size)
.await?;
Ok(timeline_info)
}
}

View File

@@ -6,7 +6,6 @@
//! .neon/safekeepers/<safekeeper id>
//! ```
use std::error::Error as _;
use std::future::Future;
use std::io::Write;
use std::path::PathBuf;
use std::time::Duration;
@@ -14,9 +13,9 @@ use std::{io, result};
use anyhow::Context;
use camino::Utf8PathBuf;
use http_utils::error::HttpErrorBody;
use postgres_connection::PgConnectionConfig;
use reqwest::{IntoUrl, Method};
use safekeeper_api::models::TimelineCreateRequest;
use safekeeper_client::mgmt_api;
use thiserror::Error;
use utils::auth::{Claims, Scope};
use utils::id::NodeId;
@@ -35,25 +34,14 @@ pub enum SafekeeperHttpError {
type Result<T> = result::Result<T, SafekeeperHttpError>;
pub(crate) trait ResponseErrorMessageExt: Sized {
fn error_from_body(self) -> impl Future<Output = Result<Self>> + Send;
}
impl ResponseErrorMessageExt for reqwest::Response {
async fn error_from_body(self) -> Result<Self> {
let status = self.status();
if !(status.is_client_error() || status.is_server_error()) {
return Ok(self);
}
// reqwest does not export its error construction utility functions, so let's craft the message ourselves
let url = self.url().to_owned();
Err(SafekeeperHttpError::Response(
match self.json::<HttpErrorBody>().await {
Ok(err_body) => format!("Error: {}", err_body.msg),
Err(_) => format!("Http error ({}) at {}.", status.as_u16(), url),
},
))
fn err_from_client_err(err: mgmt_api::Error) -> SafekeeperHttpError {
use mgmt_api::Error::*;
match err {
ApiError(_, str) => SafekeeperHttpError::Response(str),
Cancelled => SafekeeperHttpError::Response("Cancelled".to_owned()),
ReceiveBody(err) => SafekeeperHttpError::Transport(err),
ReceiveErrorBody(err) => SafekeeperHttpError::Response(err),
Timeout(str) => SafekeeperHttpError::Response(format!("timeout: {str}")),
}
}
@@ -70,9 +58,8 @@ pub struct SafekeeperNode {
pub pg_connection_config: PgConnectionConfig,
pub env: LocalEnv,
pub http_client: reqwest::Client,
pub http_client: mgmt_api::Client,
pub listen_addr: String,
pub http_base_url: String,
}
impl SafekeeperNode {
@@ -82,13 +69,14 @@ impl SafekeeperNode {
} else {
"127.0.0.1".to_string()
};
let jwt = None;
let http_base_url = format!("http://{}:{}", listen_addr, conf.http_port);
SafekeeperNode {
id: conf.id,
conf: conf.clone(),
pg_connection_config: Self::safekeeper_connection_config(&listen_addr, conf.pg_port),
env: env.clone(),
http_client: env.create_http_client(),
http_base_url: format!("http://{}:{}/v1", listen_addr, conf.http_port),
http_client: mgmt_api::Client::new(env.create_http_client(), http_base_url, jwt),
listen_addr,
}
}
@@ -278,20 +266,19 @@ impl SafekeeperNode {
)
}
fn http_request<U: IntoUrl>(&self, method: Method, url: U) -> reqwest::RequestBuilder {
// TODO: authentication
//if self.env.auth_type == AuthType::NeonJWT {
// builder = builder.bearer_auth(&self.env.safekeeper_auth_token)
//}
self.http_client.request(method, url)
pub async fn check_status(&self) -> Result<()> {
self.http_client
.status()
.await
.map_err(err_from_client_err)?;
Ok(())
}
pub async fn check_status(&self) -> Result<()> {
self.http_request(Method::GET, format!("{}/{}", self.http_base_url, "status"))
.send()
.await?
.error_from_body()
.await?;
pub async fn create_timeline(&self, req: &TimelineCreateRequest) -> Result<()> {
self.http_client
.create_timeline(req)
.await
.map_err(err_from_client_err)?;
Ok(())
}
}

View File

@@ -37,11 +37,6 @@ enum Command {
#[arg(long)]
listen_pg_port: u16,
#[arg(long)]
listen_grpc_addr: Option<String>,
#[arg(long)]
listen_grpc_port: Option<u16>,
#[arg(long)]
listen_http_addr: String,
#[arg(long)]
@@ -66,10 +61,16 @@ enum Command {
#[arg(long)]
scheduling: Option<NodeSchedulingPolicy>,
},
// Set a node status as deleted.
NodeDelete {
#[arg(long)]
node_id: NodeId,
},
/// Delete a tombstone of node from the storage controller.
NodeDeleteTombstone {
#[arg(long)]
node_id: NodeId,
},
/// Modify a tenant's policies in the storage controller
TenantPolicy {
#[arg(long)]
@@ -87,6 +88,8 @@ enum Command {
},
/// List nodes known to the storage controller
Nodes {},
/// List soft deleted nodes known to the storage controller
NodeTombstones {},
/// List tenants known to the storage controller
Tenants {
/// If this field is set, it will list the tenants on a specific node
@@ -415,8 +418,6 @@ async fn main() -> anyhow::Result<()> {
node_id,
listen_pg_addr,
listen_pg_port,
listen_grpc_addr,
listen_grpc_port,
listen_http_addr,
listen_http_port,
listen_https_port,
@@ -430,8 +431,6 @@ async fn main() -> anyhow::Result<()> {
node_id,
listen_pg_addr,
listen_pg_port,
listen_grpc_addr,
listen_grpc_port,
listen_http_addr,
listen_http_port,
listen_https_port,
@@ -909,6 +908,39 @@ async fn main() -> anyhow::Result<()> {
.dispatch::<(), ()>(Method::DELETE, format!("control/v1/node/{node_id}"), None)
.await?;
}
Command::NodeDeleteTombstone { node_id } => {
storcon_client
.dispatch::<(), ()>(
Method::DELETE,
format!("debug/v1/tombstone/{node_id}"),
None,
)
.await?;
}
Command::NodeTombstones {} => {
let mut resp = storcon_client
.dispatch::<(), Vec<NodeDescribeResponse>>(
Method::GET,
"debug/v1/tombstone".to_string(),
None,
)
.await?;
resp.sort_by(|a, b| a.listen_http_addr.cmp(&b.listen_http_addr));
let mut table = comfy_table::Table::new();
table.set_header(["Id", "Hostname", "AZ", "Scheduling", "Availability"]);
for node in resp {
table.add_row([
format!("{}", node.id),
node.listen_http_addr,
node.availability_zone_id,
format!("{:?}", node.scheduling),
format!("{:?}", node.availability),
]);
}
println!("{table}");
}
Command::TenantSetTimeBasedEviction {
tenant_id,
period,

View File

@@ -1,9 +1,6 @@
#!/bin/bash
#!/bin/sh
set -ex
cd "$(dirname "$0")"
if [[ ${PG_VERSION} = v17 ]]; then
sed -i '/computed_columns/d' regress/core/tests.mk
fi
patch -p1 <postgis-no-upgrade-test.patch
trap 'echo Cleaning up; patch -R -p1 <postgis-no-upgrade-test.patch' EXIT
patch -p1 <"postgis-common-${PG_VERSION}.patch"
trap 'echo Cleaning up; patch -R -p1 <postgis-common-${PG_VERSION}.patch' EXIT
make installcheck-base

View File

@@ -1,3 +1,19 @@
diff --git a/regress/core/tests.mk b/regress/core/tests.mk
index 3abd7bc..64a9254 100644
--- a/regress/core/tests.mk
+++ b/regress/core/tests.mk
@@ -144,11 +144,6 @@ TESTS_SLOW = \
$(top_srcdir)/regress/core/concave_hull_hard \
$(top_srcdir)/regress/core/knn_recheck
-ifeq ($(shell expr "$(POSTGIS_PGSQL_VERSION)" ">=" 120),1)
- TESTS += \
- $(top_srcdir)/regress/core/computed_columns
-endif
-
ifeq ($(shell expr "$(POSTGIS_GEOS_VERSION)" ">=" 30700),1)
# GEOS-3.7 adds:
# ST_FrechetDistance
diff --git a/regress/runtest.mk b/regress/runtest.mk
index c051f03..010e493 100644
--- a/regress/runtest.mk

View File

@@ -0,0 +1,35 @@
diff --git a/regress/core/tests.mk b/regress/core/tests.mk
index 9e05244..90987df 100644
--- a/regress/core/tests.mk
+++ b/regress/core/tests.mk
@@ -143,8 +143,7 @@ TESTS += \
$(top_srcdir)/regress/core/oriented_envelope \
$(top_srcdir)/regress/core/point_coordinates \
$(top_srcdir)/regress/core/out_geojson \
- $(top_srcdir)/regress/core/wrapx \
- $(top_srcdir)/regress/core/computed_columns
+ $(top_srcdir)/regress/core/wrapx
# Slow slow tests
TESTS_SLOW = \
diff --git a/regress/runtest.mk b/regress/runtest.mk
index 4b95b7e..449d5a2 100644
--- a/regress/runtest.mk
+++ b/regress/runtest.mk
@@ -24,16 +24,6 @@ check-regress:
@POSTGIS_TOP_BUILD_DIR=$(abs_top_builddir) $(PERL) $(top_srcdir)/regress/run_test.pl $(RUNTESTFLAGS) $(RUNTESTFLAGS_INTERNAL) $(TESTS)
- @if echo "$(RUNTESTFLAGS)" | grep -vq -- --upgrade; then \
- echo "Running upgrade test as RUNTESTFLAGS did not contain that"; \
- POSTGIS_TOP_BUILD_DIR=$(abs_top_builddir) $(PERL) $(top_srcdir)/regress/run_test.pl \
- --upgrade \
- $(RUNTESTFLAGS) \
- $(RUNTESTFLAGS_INTERNAL) \
- $(TESTS); \
- else \
- echo "Skipping upgrade test as RUNTESTFLAGS already requested upgrades"; \
- fi
check-long:
$(PERL) $(top_srcdir)/regress/run_test.pl $(RUNTESTFLAGS) $(TESTS) $(TESTS_SLOW)

View File

@@ -125,7 +125,7 @@ index 7a36b65..ad78fc7 100644
DROP SCHEMA tm CASCADE;
+
diff --git a/regress/core/tests.mk b/regress/core/tests.mk
index 3abd7bc..94903c3 100644
index 64a9254..94903c3 100644
--- a/regress/core/tests.mk
+++ b/regress/core/tests.mk
@@ -23,7 +23,6 @@ current_dir := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
@@ -160,18 +160,6 @@ index 3abd7bc..94903c3 100644
$(top_srcdir)/regress/core/wkb \
$(top_srcdir)/regress/core/wkt \
$(top_srcdir)/regress/core/wmsservers \
@@ -144,11 +140,6 @@ TESTS_SLOW = \
$(top_srcdir)/regress/core/concave_hull_hard \
$(top_srcdir)/regress/core/knn_recheck
-ifeq ($(shell expr "$(POSTGIS_PGSQL_VERSION)" ">=" 120),1)
- TESTS += \
- $(top_srcdir)/regress/core/computed_columns
-endif
-
ifeq ($(shell expr "$(POSTGIS_GEOS_VERSION)" ">=" 30700),1)
# GEOS-3.7 adds:
# ST_FrechetDistance
diff --git a/regress/loader/tests.mk b/regress/loader/tests.mk
index 1fc77ac..c3cb9de 100644
--- a/regress/loader/tests.mk

View File

@@ -125,7 +125,7 @@ index 7a36b65..ad78fc7 100644
DROP SCHEMA tm CASCADE;
+
diff --git a/regress/core/tests.mk b/regress/core/tests.mk
index 9e05244..a63a3e1 100644
index 90987df..74fe3f1 100644
--- a/regress/core/tests.mk
+++ b/regress/core/tests.mk
@@ -16,14 +16,13 @@ POSTGIS_PGSQL_VERSION=170
@@ -168,16 +168,6 @@ index 9e05244..a63a3e1 100644
$(top_srcdir)/regress/core/wkb \
$(top_srcdir)/regress/core/wkt \
$(top_srcdir)/regress/core/wmsservers \
@@ -143,8 +139,7 @@ TESTS += \
$(top_srcdir)/regress/core/oriented_envelope \
$(top_srcdir)/regress/core/point_coordinates \
$(top_srcdir)/regress/core/out_geojson \
- $(top_srcdir)/regress/core/wrapx \
- $(top_srcdir)/regress/core/computed_columns
+ $(top_srcdir)/regress/core/wrapx
# Slow slow tests
TESTS_SLOW = \
diff --git a/regress/loader/tests.mk b/regress/loader/tests.mk
index ac4f8ad..4bad4fc 100644
--- a/regress/loader/tests.mk

View File

@@ -10,8 +10,8 @@ psql -d contrib_regression -c "ALTER DATABASE contrib_regression SET TimeZone='U
-c "CREATE EXTENSION postgis_tiger_geocoder CASCADE" \
-c "CREATE EXTENSION postgis_raster SCHEMA public" \
-c "CREATE EXTENSION postgis_sfcgal SCHEMA public"
patch -p1 <postgis-no-upgrade-test.patch
patch -p1 <"postgis-common-${PG_VERSION}.patch"
patch -p1 <"postgis-regular-${PG_VERSION}.patch"
psql -d contrib_regression -f raster_outdb_template.sql
trap 'patch -R -p1 <postgis-no-upgrade-test.patch && patch -R -p1 <"postgis-regular-${PG_VERSION}.patch"' EXIT
trap 'patch -R -p1 <postgis-regular-${PG_VERSION}.patch && patch -R -p1 <"postgis-common-${PG_VERSION}.patch"' EXIT
POSTGIS_REGRESS_DB=contrib_regression RUNTESTFLAGS=--nocreate make installcheck-base

View File

@@ -8,6 +8,7 @@ anyhow.workspace = true
axum-extra.workspace = true
axum.workspace = true
camino.workspace = true
clap.workspace = true
futures.workspace = true
jsonwebtoken.workspace = true
prometheus.workspace = true

View File

@@ -4,6 +4,8 @@
//! for large computes.
mod app;
use anyhow::Context;
use clap::Parser;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use tracing::info;
use utils::logging;
@@ -12,9 +14,26 @@ const fn max_upload_file_limit() -> usize {
100 * 1024 * 1024
}
const fn listen() -> SocketAddr {
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 51243)
}
#[derive(Parser)]
struct Args {
#[arg(exclusive = true)]
config_file: Option<String>,
#[arg(long, default_value = "false", requires = "config")]
/// to allow testing k8s helm chart where we don't have s3 credentials
no_s3_check_on_startup: bool,
#[arg(long, value_name = "FILE")]
/// inline config mode for k8s helm chart
config: Option<String>,
}
#[derive(serde::Deserialize)]
#[serde(tag = "type")]
struct Config {
#[serde(default = "listen")]
listen: std::net::SocketAddr,
pemfile: camino::Utf8PathBuf,
#[serde(flatten)]
@@ -31,13 +50,18 @@ async fn main() -> anyhow::Result<()> {
logging::Output::Stdout,
)?;
let config: String = std::env::args().skip(1).take(1).collect();
if config.is_empty() {
anyhow::bail!("Usage: endpoint_storage config.json")
}
info!("Reading config from {config}");
let config = std::fs::read_to_string(config.clone())?;
let config: Config = serde_json::from_str(&config).context("parsing config")?;
let args = Args::parse();
let config: Config = if let Some(config_path) = args.config_file {
info!("Reading config from {config_path}");
let config = std::fs::read_to_string(config_path)?;
serde_json::from_str(&config).context("parsing config")?
} else if let Some(config) = args.config {
info!("Reading inline config");
serde_json::from_str(&config).context("parsing config")?
} else {
anyhow::bail!("Supply either config file path or --config=inline-config");
};
info!("Reading pemfile from {}", config.pemfile.clone());
let pemfile = std::fs::read(config.pemfile.clone())?;
info!("Loading public key from {}", config.pemfile.clone());
@@ -48,7 +72,9 @@ async fn main() -> anyhow::Result<()> {
let storage = remote_storage::GenericRemoteStorage::from_config(&config.storage_config).await?;
let cancel = tokio_util::sync::CancellationToken::new();
app::check_storage_permissions(&storage, cancel.clone()).await?;
if !args.no_s3_check_on_startup {
app::check_storage_permissions(&storage, cancel.clone()).await?;
}
let proxy = std::sync::Arc::new(endpoint_storage::Storage {
auth,

View File

@@ -4,7 +4,6 @@
//! provide it by calling the compute_ctl's `/compute_ctl` endpoint, or
//! compute_ctl can fetch it by calling the control plane's API.
use std::collections::HashMap;
use std::fmt::Display;
use indexmap::IndexMap;
use regex::Regex;
@@ -320,12 +319,6 @@ impl ComputeMode {
}
}
impl Display for ComputeMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.to_type_str())
}
}
/// Log level for audit logging
#[derive(Clone, Debug, Default, Eq, PartialEq, Deserialize, Serialize)]
pub enum ComputeAudit {

View File

@@ -6,13 +6,8 @@ license.workspace = true
[dependencies]
thiserror.workspace = true
nix.workspace = true
spin.workspace = true
nix.workspace=true
workspace_hack = { version = "0.1", path = "../../workspace_hack" }
[dev-dependencies]
rand = "0.9.1"
rand_distr = "0.5.1"
[target.'cfg(target_os = "macos")'.dependencies]
tempfile = "3.14.0"

View File

@@ -1,366 +0,0 @@
//! Hash table implementation on top of 'shmem'
//!
//! Features required in the long run by the communicator project:
//!
//! [X] Accessible from both Postgres processes and rust threads in the communicator process
//! [X] Low latency
//! [ ] Scalable to lots of concurrent accesses (currently uses a single spinlock)
//! [ ] Resizable
use std::fmt::Debug;
use std::hash::Hash;
use std::mem::MaybeUninit;
use std::ops::Deref;
use crate::shmem::ShmemHandle;
use spin;
mod core;
#[cfg(test)]
mod tests;
use core::CoreHashMap;
pub enum UpdateAction<V> {
Nothing,
Insert(V),
Remove,
}
#[derive(Debug)]
pub struct OutOfMemoryError();
pub struct HashMapInit<'a, K, V> {
// Hash table can be allocated in a fixed memory area, or in a resizeable ShmemHandle.
shmem_handle: Option<ShmemHandle>,
shared_ptr: *mut HashMapShared<'a, K, V>,
}
pub struct HashMapAccess<'a, K, V> {
shmem_handle: Option<ShmemHandle>,
shared_ptr: *mut HashMapShared<'a, K, V>,
}
unsafe impl<'a, K: Sync, V: Sync> Sync for HashMapAccess<'a, K, V> {}
unsafe impl<'a, K: Send, V: Send> Send for HashMapAccess<'a, K, V> {}
impl<'a, K, V> HashMapInit<'a, K, V> {
pub fn attach_writer(self) -> HashMapAccess<'a, K, V> {
HashMapAccess {
shmem_handle: self.shmem_handle,
shared_ptr: self.shared_ptr,
}
}
pub fn attach_reader(self) -> HashMapAccess<'a, K, V> {
// no difference to attach_writer currently
self.attach_writer()
}
}
// This is stored in the shared memory area
struct HashMapShared<'a, K, V> {
inner: spin::RwLock<CoreHashMap<'a, K, V>>,
}
impl<'a, K, V> HashMapInit<'a, K, V>
where
K: Clone + Hash + Eq,
{
pub fn estimate_size(num_buckets: u32) -> usize {
// add some margin to cover alignment etc.
CoreHashMap::<K, V>::estimate_size(num_buckets) + size_of::<HashMapShared<K, V>>() + 1000
}
pub fn init_in_fixed_area(
num_buckets: u32,
area: &'a mut [MaybeUninit<u8>],
) -> HashMapInit<'a, K, V> {
Self::init_common(num_buckets, None, area.as_mut_ptr().cast(), area.len())
}
/// Initialize a new hash map in the given shared memory area
pub fn init_in_shmem(num_buckets: u32, mut shmem: ShmemHandle) -> HashMapInit<'a, K, V> {
let size = Self::estimate_size(num_buckets);
shmem
.set_size(size)
.expect("could not resize shared memory area");
let ptr = unsafe { shmem.data_ptr.as_mut() };
Self::init_common(num_buckets, Some(shmem), ptr, size)
}
fn init_common(
num_buckets: u32,
shmem_handle: Option<ShmemHandle>,
area_ptr: *mut u8,
area_len: usize,
) -> HashMapInit<'a, K, V> {
// carve out HashMapShared from the area. This does not include the hashmap's dictionary
// and buckets.
let mut ptr: *mut u8 = area_ptr;
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<HashMapShared<K, V>>())) };
let shared_ptr: *mut HashMapShared<K, V> = ptr.cast();
ptr = unsafe { ptr.add(size_of::<HashMapShared<K, V>>()) };
// the rest of the space is given to the hash map's dictionary and buckets
let remaining_area = unsafe {
std::slice::from_raw_parts_mut(ptr, area_len - ptr.offset_from(area_ptr) as usize)
};
let hashmap = CoreHashMap::new(num_buckets, remaining_area);
unsafe {
std::ptr::write(
shared_ptr,
HashMapShared {
inner: spin::RwLock::new(hashmap),
},
);
}
HashMapInit {
shmem_handle: shmem_handle,
shared_ptr,
}
}
}
impl<'a, K, V> HashMapAccess<'a, K, V>
where
K: Clone + Hash + Eq,
{
pub fn get<'e>(&'e self, key: &K) -> Option<ValueReadGuard<'e, K, V>> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
let lock_guard = map.inner.read();
match lock_guard.get(key) {
None => None,
Some(val_ref) => {
let val_ptr = std::ptr::from_ref(val_ref);
Some(ValueReadGuard {
_lock_guard: lock_guard,
value: val_ptr,
})
}
}
}
/// Insert a value
pub fn insert(&self, key: &K, value: V) -> Result<bool, OutOfMemoryError> {
let mut success = None;
self.update_with_fn(key, |existing| {
if let Some(_) = existing {
success = Some(false);
UpdateAction::Nothing
} else {
success = Some(true);
UpdateAction::Insert(value)
}
})?;
Ok(success.expect("value_fn not called"))
}
/// Remove value. Returns true if it existed
pub fn remove(&self, key: &K) -> bool {
let mut result = false;
self.update_with_fn(key, |existing| match existing {
Some(_) => {
result = true;
UpdateAction::Remove
}
None => UpdateAction::Nothing,
})
.expect("out of memory while removing");
result
}
/// Update key using the given function. All the other modifying operations are based on this.
pub fn update_with_fn<F>(&self, key: &K, value_fn: F) -> Result<(), OutOfMemoryError>
where
F: FnOnce(Option<&V>) -> UpdateAction<V>,
{
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
let mut lock_guard = map.inner.write();
let old_val = lock_guard.get(key);
let action = value_fn(old_val);
match (old_val, action) {
(_, UpdateAction::Nothing) => {}
(_, UpdateAction::Insert(new_val)) => {
let _ = lock_guard.insert(key, new_val);
}
(None, UpdateAction::Remove) => panic!("Remove action with no old value"),
(Some(_), UpdateAction::Remove) => {
let _ = lock_guard.remove(key);
}
}
Ok(())
}
/// Update key using the given function. All the other modifying operations are based on this.
pub fn update_with_fn_at_bucket<F>(
&self,
pos: usize,
value_fn: F,
) -> Result<(), OutOfMemoryError>
where
F: FnOnce(Option<&V>) -> UpdateAction<V>,
{
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
let mut lock_guard = map.inner.write();
let old_val = lock_guard.get_bucket(pos);
let action = value_fn(old_val.map(|(_k, v)| v));
match (old_val, action) {
(_, UpdateAction::Nothing) => {}
(_, UpdateAction::Insert(_new_val)) => panic!("cannot insert without key"),
(None, UpdateAction::Remove) => panic!("Remove action with no old value"),
(Some((key, _value)), UpdateAction::Remove) => {
let key = key.clone();
let _ = lock_guard.remove(&key);
}
}
Ok(())
}
pub fn get_num_buckets(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
map.inner.read().get_num_buckets()
}
/// Return the key and value stored in bucket with given index. This can be used to
/// iterate through the hash map. (An Iterator might be nicer. The communicator's
/// clock algorithm needs to _slowly_ iterate through all buckets with its clock hand,
/// without holding a lock. If we switch to an Iterator, it must not hold the lock.)
pub fn get_bucket<'e>(&'e self, pos: usize) -> Option<ValueReadGuard<'e, K, V>> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
let lock_guard = map.inner.read();
match lock_guard.get_bucket(pos) {
None => None,
Some((_key, val_ref)) => {
let val_ptr = std::ptr::from_ref(val_ref);
Some(ValueReadGuard {
_lock_guard: lock_guard,
value: val_ptr,
})
}
}
}
// for metrics
pub fn get_num_buckets_in_use(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
map.inner.read().buckets_in_use as usize
}
/// Grow
///
/// 1. grow the underlying shared memory area
/// 2. Initialize new buckets. This overwrites the current dictionary
/// 3. Recalculate the dictionary
pub fn grow(&self, num_buckets: u32) -> Result<(), crate::shmem::Error> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap();
let mut lock_guard = map.inner.write();
let inner = &mut *lock_guard;
let old_num_buckets = inner.buckets.len() as u32;
if num_buckets < old_num_buckets {
panic!("grow called with a smaller number of buckets");
}
if num_buckets == old_num_buckets {
return Ok(());
}
let shmem_handle = self
.shmem_handle
.as_ref()
.expect("grow called on a fixed-size hash table");
let size_bytes = HashMapInit::<K, V>::estimate_size(num_buckets);
shmem_handle.set_size(size_bytes)?;
let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
// Initialize new buckets. The new buckets are linked to the free list. NB: This overwrites
// the dictionary!
let buckets_ptr = inner.buckets.as_mut_ptr();
unsafe {
for i in old_num_buckets..num_buckets {
let bucket_ptr = buckets_ptr.add(i as usize);
bucket_ptr.write(core::Bucket {
hash: 0,
next: if i < num_buckets {
i as u32 + 1
} else {
inner.free_head
},
inner: None,
});
}
}
// Recalculate the dictionary
let buckets;
let dictionary;
unsafe {
let buckets_end_ptr = buckets_ptr.add(num_buckets as usize);
let dictionary_ptr: *mut u32 = buckets_end_ptr
.byte_add(buckets_end_ptr.align_offset(align_of::<u32>()))
.cast();
let dictionary_size: usize =
end_ptr.byte_offset_from(buckets_end_ptr) as usize / size_of::<u32>();
buckets = std::slice::from_raw_parts_mut(buckets_ptr, num_buckets as usize);
dictionary = std::slice::from_raw_parts_mut(dictionary_ptr, dictionary_size);
}
for i in 0..dictionary.len() {
dictionary[i] = core::INVALID_POS;
}
for i in 0..old_num_buckets as usize {
if buckets[i].inner.is_none() {
continue;
}
let pos: usize = (buckets[i].hash % dictionary.len() as u64) as usize;
buckets[i].next = dictionary[pos];
dictionary[pos] = i as u32;
}
// Finally, update the CoreHashMap struct
inner.dictionary = dictionary;
inner.buckets = buckets;
inner.free_head = old_num_buckets;
Ok(())
}
// TODO: Shrinking is a multi-step process that requires co-operation from the caller
//
// 1. The caller must first call begin_shrink(). That forbids allocation of higher-numbered
// buckets.
//
// 2. Next, the caller must evict all entries in higher-numbered buckets.
//
// 3. Finally, call finish_shrink(). This recomputes the dictionary and shrinks the underlying
// shmem area
}
pub struct ValueReadGuard<'a, K, V> {
_lock_guard: spin::RwLockReadGuard<'a, CoreHashMap<'a, K, V>>,
value: *const V,
}
impl<'a, K, V> Deref for ValueReadGuard<'a, K, V> {
type Target = V;
fn deref(&self) -> &Self::Target {
// SAFETY: The `lock_guard` ensures that the underlying map (and thus the value pointed to
// by `value`) remains valid for the lifetime `'a`. The `value` has been obtained from a
// valid reference within the map.
unsafe { &*self.value }
}
}

View File

@@ -1,233 +0,0 @@
//! Simple hash table with chaining
//!
//! # Resizing
//!
use std::hash::{DefaultHasher, Hash, Hasher};
use std::mem::MaybeUninit;
pub(crate) const INVALID_POS: u32 = u32::MAX;
// Bucket
pub(crate) struct Bucket<K, V> {
pub(crate) hash: u64,
pub(crate) next: u32,
pub(crate) inner: Option<(K, V)>,
}
pub(crate) struct CoreHashMap<'a, K, V> {
pub(crate) dictionary: &'a mut [u32],
pub(crate) buckets: &'a mut [Bucket<K, V>],
pub(crate) free_head: u32,
// metrics
pub(crate) buckets_in_use: u32,
}
pub struct FullError();
impl<'a, K, V> CoreHashMap<'a, K, V>
where
K: Clone + Hash + Eq,
{
const FILL_FACTOR: f32 = 0.60;
pub fn estimate_size(num_buckets: u32) -> usize {
let mut size = 0;
// buckets
size += size_of::<Bucket<K, V>>() * num_buckets as usize;
// dictionary
size += (f32::ceil((size_of::<u32>() * num_buckets as usize) as f32 / Self::FILL_FACTOR))
as usize;
size
}
pub fn new(num_buckets: u32, area: &'a mut [u8]) -> CoreHashMap<'a, K, V> {
let len = area.len();
let mut ptr: *mut u8 = area.as_mut_ptr();
let end_ptr: *mut u8 = unsafe { area.as_mut_ptr().add(len) };
// carve out the buckets
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<Bucket<K, V>>())) };
let buckets_ptr = ptr;
ptr = unsafe { ptr.add(size_of::<Bucket<K, V>>() * num_buckets as usize) };
// use remaining space for the dictionary
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<u32>())) };
let dictionary_ptr = ptr;
assert!(ptr.addr() < end_ptr.addr());
let dictionary_size = unsafe { end_ptr.byte_offset_from(ptr) / size_of::<u32>() as isize };
assert!(dictionary_size > 0);
// Initialize the buckets
let buckets = {
let buckets_ptr: *mut MaybeUninit<Bucket<K, V>> = buckets_ptr.cast();
let buckets =
unsafe { std::slice::from_raw_parts_mut(buckets_ptr, num_buckets as usize) };
for i in 0..buckets.len() {
buckets[i].write(Bucket {
hash: 0,
next: if i < buckets.len() - 1 {
i as u32 + 1
} else {
INVALID_POS
},
inner: None,
});
}
// TODO: use std::slice::assume_init_mut() once it stabilizes
unsafe { std::slice::from_raw_parts_mut(buckets_ptr.cast(), num_buckets as usize) }
};
// Initialize the dictionary
let dictionary = {
let dictionary_ptr: *mut MaybeUninit<u32> = dictionary_ptr.cast();
let dictionary =
unsafe { std::slice::from_raw_parts_mut(dictionary_ptr, dictionary_size as usize) };
for i in 0..dictionary.len() {
dictionary[i].write(INVALID_POS);
}
// TODO: use std::slice::assume_init_mut() once it stabilizes
unsafe {
std::slice::from_raw_parts_mut(dictionary_ptr.cast(), dictionary_size as usize)
}
};
CoreHashMap {
dictionary,
buckets,
free_head: 0,
buckets_in_use: 0,
}
}
pub fn get(&self, key: &K) -> Option<&V> {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
let hash = hasher.finish();
let mut next = self.dictionary[hash as usize % self.dictionary.len()];
loop {
if next == INVALID_POS {
return None;
}
let bucket = &self.buckets[next as usize];
let (bucket_key, bucket_value) = bucket.inner.as_ref().expect("entry is in use");
if bucket_key == key {
return Some(&bucket_value);
}
next = bucket.next;
}
}
pub fn insert(&mut self, key: &K, value: V) -> Result<(), FullError> {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
let hash = hasher.finish();
let first = self.dictionary[hash as usize % self.dictionary.len()];
if first == INVALID_POS {
// no existing entry
let pos = self.alloc_bucket(key.clone(), value, hash)?;
if pos == INVALID_POS {
return Err(FullError());
}
self.dictionary[hash as usize % self.dictionary.len()] = pos;
return Ok(());
}
let mut next = first;
loop {
let bucket = &mut self.buckets[next as usize];
let (bucket_key, bucket_value) = bucket.inner.as_mut().expect("entry is in use");
if bucket_key == key {
// found existing entry, update its value
*bucket_value = value;
return Ok(());
}
if bucket.next == INVALID_POS {
// No existing entry found. Append to the chain
let pos = self.alloc_bucket(key.clone(), value, hash)?;
if pos == INVALID_POS {
return Err(FullError());
}
self.buckets[next as usize].next = pos;
return Ok(());
}
next = bucket.next;
}
}
pub fn remove(&mut self, key: &K) -> Result<(), FullError> {
let mut hasher = DefaultHasher::new();
key.hash(&mut hasher);
let hash = hasher.finish();
let mut next = self.dictionary[hash as usize % self.dictionary.len()];
let mut prev_pos: u32 = INVALID_POS;
loop {
if next == INVALID_POS {
// no existing entry
return Ok(());
}
let bucket = &mut self.buckets[next as usize];
let (bucket_key, _) = bucket.inner.as_mut().expect("entry is in use");
if bucket_key == key {
// found existing entry, unlink it from the chain
if prev_pos == INVALID_POS {
self.dictionary[hash as usize % self.dictionary.len()] = bucket.next;
} else {
self.buckets[prev_pos as usize].next = bucket.next;
}
// and add it to the freelist
let bucket = &mut self.buckets[next as usize];
bucket.hash = 0;
bucket.inner = None;
bucket.next = self.free_head;
self.free_head = next;
self.buckets_in_use -= 1;
return Ok(());
}
prev_pos = next;
next = bucket.next;
}
}
pub fn get_num_buckets(&self) -> usize {
self.buckets.len()
}
pub fn get_bucket(&self, pos: usize) -> Option<&(K, V)> {
if pos >= self.buckets.len() {
return None;
}
self.buckets[pos].inner.as_ref()
}
fn alloc_bucket(&mut self, key: K, value: V, hash: u64) -> Result<u32, FullError> {
let pos = self.free_head;
if pos == INVALID_POS {
return Err(FullError());
}
let bucket = &mut self.buckets[pos as usize];
self.free_head = bucket.next;
self.buckets_in_use += 1;
bucket.hash = hash;
bucket.next = INVALID_POS;
bucket.inner = Some((key, value));
return Ok(pos);
}
}

View File

@@ -1,220 +0,0 @@
use std::collections::BTreeMap;
use std::collections::HashSet;
use std::fmt::{Debug, Formatter};
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::hash::HashMapAccess;
use crate::hash::HashMapInit;
use crate::hash::UpdateAction;
use crate::shmem::ShmemHandle;
use rand::seq::SliceRandom;
use rand::{Rng, RngCore};
use rand_distr::Zipf;
const TEST_KEY_LEN: usize = 16;
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct TestKey([u8; TEST_KEY_LEN]);
impl From<&TestKey> for u128 {
fn from(val: &TestKey) -> u128 {
u128::from_be_bytes(val.0)
}
}
impl From<u128> for TestKey {
fn from(val: u128) -> TestKey {
TestKey(val.to_be_bytes())
}
}
impl<'a> From<&'a [u8]> for TestKey {
fn from(bytes: &'a [u8]) -> TestKey {
TestKey(bytes.try_into().unwrap())
}
}
fn test_inserts<K: Into<TestKey> + Copy>(keys: &[K]) {
const MAX_MEM_SIZE: usize = 10000000;
let shmem = ShmemHandle::new("test_inserts", 0, MAX_MEM_SIZE).unwrap();
let init_struct = HashMapInit::<TestKey, usize>::init_in_shmem(100000, shmem);
let w = init_struct.attach_writer();
for (idx, k) in keys.iter().enumerate() {
let res = w.insert(&(*k).into(), idx);
assert!(res.is_ok());
}
for (idx, k) in keys.iter().enumerate() {
let x = w.get(&(*k).into());
let value = x.as_deref().copied();
assert_eq!(value, Some(idx));
}
//eprintln!("stats: {:?}", tree_writer.get_statistics());
}
#[test]
fn dense() {
// This exercises splitting a node with prefix
let keys: &[u128] = &[0, 1, 2, 3, 256];
test_inserts(keys);
// Dense keys
let mut keys: Vec<u128> = (0..10000).collect();
test_inserts(&keys);
// Do the same in random orders
for _ in 1..10 {
keys.shuffle(&mut rand::rng());
test_inserts(&keys);
}
}
#[test]
fn sparse() {
// sparse keys
let mut keys: Vec<TestKey> = Vec::new();
let mut used_keys = HashSet::new();
for _ in 0..10000 {
loop {
let key = rand::random::<u128>();
if used_keys.get(&key).is_some() {
continue;
}
used_keys.insert(key);
keys.push(key.into());
break;
}
}
test_inserts(&keys);
}
struct TestValue(AtomicUsize);
impl TestValue {
fn new(val: usize) -> TestValue {
TestValue(AtomicUsize::new(val))
}
fn load(&self) -> usize {
self.0.load(Ordering::Relaxed)
}
}
impl Clone for TestValue {
fn clone(&self) -> TestValue {
TestValue::new(self.load())
}
}
impl Debug for TestValue {
fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(fmt, "{:?}", self.load())
}
}
#[derive(Clone, Debug)]
struct TestOp(TestKey, Option<usize>);
fn apply_op(
op: &TestOp,
sut: &HashMapAccess<TestKey, TestValue>,
shadow: &mut BTreeMap<TestKey, usize>,
) {
eprintln!("applying op: {op:?}");
// apply the change to the shadow tree first
let shadow_existing = if let Some(v) = op.1 {
shadow.insert(op.0, v)
} else {
shadow.remove(&op.0)
};
// apply to Art tree
sut.update_with_fn(&op.0, |existing| {
assert_eq!(existing.map(TestValue::load), shadow_existing);
match (existing, op.1) {
(None, None) => UpdateAction::Nothing,
(None, Some(new_val)) => UpdateAction::Insert(TestValue::new(new_val)),
(Some(_old_val), None) => UpdateAction::Remove,
(Some(old_val), Some(new_val)) => {
old_val.0.store(new_val, Ordering::Relaxed);
UpdateAction::Nothing
}
}
})
.expect("out of memory");
}
#[test]
fn random_ops() {
const MAX_MEM_SIZE: usize = 10000000;
let shmem = ShmemHandle::new("test_inserts", 0, MAX_MEM_SIZE).unwrap();
let init_struct = HashMapInit::<TestKey, TestValue>::init_in_shmem(100000, shmem);
let writer = init_struct.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let distribution = Zipf::new(u128::MAX as f64, 1.1).unwrap();
let mut rng = rand::rng();
for i in 0..100000 {
let key: TestKey = (rng.sample(distribution) as u128).into();
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &writer, &mut shadow);
if i % 1000 == 0 {
eprintln!("{i} ops processed");
//eprintln!("stats: {:?}", tree_writer.get_statistics());
//test_iter(&tree_writer, &shadow);
}
}
}
#[test]
fn test_grow() {
const MEM_SIZE: usize = 10000000;
let shmem = ShmemHandle::new("test_grow", 0, MEM_SIZE).unwrap();
let init_struct = HashMapInit::<TestKey, TestValue>::init_in_shmem(1000, shmem);
let writer = init_struct.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
for i in 0..10000 {
let key: TestKey = ((rng.next_u32() % 1000) as u128).into();
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &writer, &mut shadow);
if i % 1000 == 0 {
eprintln!("{i} ops processed");
//eprintln!("stats: {:?}", tree_writer.get_statistics());
//test_iter(&tree_writer, &shadow);
}
}
writer.grow(1500).unwrap();
for i in 0..10000 {
let key: TestKey = ((rng.next_u32() % 1500) as u128).into();
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &writer, &mut shadow);
if i % 1000 == 0 {
eprintln!("{i} ops processed");
//eprintln!("stats: {:?}", tree_writer.get_statistics());
//test_iter(&tree_writer, &shadow);
}
}
}

View File

@@ -1,4 +1,418 @@
//! Shared memory utilities for neon communicator
pub mod hash;
pub mod shmem;
use std::num::NonZeroUsize;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use nix::errno::Errno;
use nix::sys::mman::MapFlags;
use nix::sys::mman::ProtFlags;
use nix::sys::mman::mmap as nix_mmap;
use nix::sys::mman::munmap as nix_munmap;
use nix::unistd::ftruncate as nix_ftruncate;
/// ShmemHandle represents a shared memory area that can be shared by processes over fork().
/// Unlike shared memory allocated by Postgres, this area is resizable, up to 'max_size' that's
/// specified at creation.
///
/// The area is backed by an anonymous file created with memfd_create(). The full address space for
/// 'max_size' is reserved up-front with mmap(), but whenever you call [`ShmemHandle::set_size`],
/// the underlying file is resized. Do not access the area beyond the current size. Currently, that
/// will cause the file to be expanded, but we might use mprotect() etc. to enforce that in the
/// future.
pub struct ShmemHandle {
/// memfd file descriptor
fd: OwnedFd,
max_size: usize,
// Pointer to the beginning of the shared memory area. The header is stored there.
shared_ptr: NonNull<SharedStruct>,
// Pointer to the beginning of the user data
pub data_ptr: NonNull<u8>,
}
/// This is stored at the beginning in the shared memory area.
struct SharedStruct {
max_size: usize,
/// Current size of the backing file. The high-order bit is used for the RESIZE_IN_PROGRESS flag
current_size: AtomicUsize,
}
const RESIZE_IN_PROGRESS: usize = 1 << 63;
const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();
/// Error type returned by the ShmemHandle functions.
#[derive(thiserror::Error, Debug)]
#[error("{msg}: {errno}")]
pub struct Error {
pub msg: String,
pub errno: Errno,
}
impl Error {
fn new(msg: &str, errno: Errno) -> Error {
Error {
msg: msg.to_string(),
errno,
}
}
}
impl ShmemHandle {
/// Create a new shared memory area. To communicate between processes, the processes need to be
/// fork()'d after calling this, so that the ShmemHandle is inherited by all processes.
///
/// If the ShmemHandle is dropped, the memory is unmapped from the current process. Other
/// processes can continue using it, however.
pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<ShmemHandle, Error> {
// create the backing anonymous file.
let fd = create_backing_file(name)?;
Self::new_with_fd(fd, initial_size, max_size)
}
fn new_with_fd(
fd: OwnedFd,
initial_size: usize,
max_size: usize,
) -> Result<ShmemHandle, Error> {
// We reserve the high-order bit for the RESIZE_IN_PROGRESS flag, and the actual size
// is a little larger than this because of the SharedStruct header. Make the upper limit
// somewhat smaller than that, because with anything close to that, you'll run out of
// memory anyway.
if max_size >= 1 << 48 {
panic!("max size {} too large", max_size);
}
if initial_size > max_size {
panic!("initial size {initial_size} larger than max size {max_size}");
}
// The actual initial / max size is the one given by the caller, plus the size of
// 'SharedStruct'.
let initial_size = HEADER_SIZE + initial_size;
let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();
// Reserve address space for it with mmap
//
// TODO: Use MAP_HUGETLB if possible
let start_ptr = unsafe {
nix_mmap(
None,
max_size,
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
MapFlags::MAP_SHARED,
&fd,
0,
)
}
.map_err(|e| Error::new("mmap failed: {e}", e))?;
// Reserve space for the initial size
enlarge_file(fd.as_fd(), initial_size as u64)?;
// Initialize the header
let shared: NonNull<SharedStruct> = start_ptr.cast();
unsafe {
shared.write(SharedStruct {
max_size: max_size.into(),
current_size: AtomicUsize::new(initial_size),
})
};
// The user data begins after the header
let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };
Ok(ShmemHandle {
fd,
max_size: max_size.into(),
shared_ptr: shared,
data_ptr,
})
}
// return reference to the header
fn shared(&self) -> &SharedStruct {
unsafe { self.shared_ptr.as_ref() }
}
/// Resize the shared memory area. 'new_size' must not be larger than the 'max_size' specified
/// when creating the area.
///
/// This may only be called from one process/thread concurrently. We detect that case
/// and return an Error.
pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
let new_size = new_size + HEADER_SIZE;
let shared = self.shared();
if new_size > self.max_size {
panic!(
"new size ({} is greater than max size ({})",
new_size, self.max_size
);
}
assert_eq!(self.max_size, shared.max_size);
// Lock the area by setting the bit in 'current_size'
//
// Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
// and the posix_fallocate/ftruncate call is surely a synchronization point anyway. But
// since this is not performance-critical, better safe than sorry .
let mut old_size = shared.current_size.load(Ordering::Acquire);
loop {
if (old_size & RESIZE_IN_PROGRESS) != 0 {
return Err(Error::new(
"concurrent resize detected",
Errno::UnknownErrno,
));
}
match shared.current_size.compare_exchange(
old_size,
new_size,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(x) => old_size = x,
}
}
// Ok, we got the lock.
//
// NB: If anything goes wrong, we *must* clear the bit!
let result = {
use std::cmp::Ordering::{Equal, Greater, Less};
match new_size.cmp(&old_size) {
Less => nix_ftruncate(&self.fd, new_size as i64).map_err(|e| {
Error::new("could not shrink shmem segment, ftruncate failed: {e}", e)
}),
Equal => Ok(()),
Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
}
};
// Unlock
shared.current_size.store(
if result.is_ok() { new_size } else { old_size },
Ordering::Release,
);
result
}
/// Returns the current user-visible size of the shared memory segment.
///
/// NOTE: a concurrent set_size() call can change the size at any time. It is the caller's
/// responsibility not to access the area beyond the current size.
pub fn current_size(&self) -> usize {
let total_current_size =
self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
total_current_size - HEADER_SIZE
}
}
impl Drop for ShmemHandle {
fn drop(&mut self) {
// SAFETY: The pointer was obtained from mmap() with the given size.
// We unmap the entire region.
let _ = unsafe { nix_munmap(self.shared_ptr.cast(), self.max_size) };
// The fd is dropped automatically by OwnedFd.
}
}
/// Create a "backing file" for the shared memory area. On Linux, use memfd_create(), to create an
/// anonymous in-memory file. One macos, fall back to a regular file. That's good enough for
/// development and testing, but in production we want the file to stay in memory.
///
/// disable 'unused_variables' warnings, because in the macos path, 'name' is unused.
#[allow(unused_variables)]
fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
#[cfg(not(target_os = "macos"))]
{
nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
.map_err(|e| Error::new("memfd_create failed: {e}", e))
}
#[cfg(target_os = "macos")]
{
let file = tempfile::tempfile().map_err(|e| {
Error::new(
"could not create temporary file to back shmem area: {e}",
nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
)
})?;
Ok(OwnedFd::from(file))
}
}
fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
// Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
// we don't get a segfault later when trying to actually use it.
#[cfg(not(target_os = "macos"))]
{
nix::fcntl::posix_fallocate(fd, 0, size as i64).map_err(|e| {
Error::new(
"could not grow shmem segment, posix_fallocate failed: {e}",
e,
)
})
}
// As a fallback on macos, which doesn't have posix_fallocate, use plain 'fallocate'
#[cfg(target_os = "macos")]
{
nix::unistd::ftruncate(fd, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, ftruncate failed: {e}", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
use nix::unistd::ForkResult;
use std::ops::Range;
/// check that all bytes in given range have the expected value.
fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
for i in range {
let b = unsafe { *(ptr.add(i)) };
assert_eq!(expected, b, "unexpected byte at offset {}", i);
}
}
/// Write 'b' to all bytes in the given range
fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
}
// simple single-process test of growing and shrinking
#[test]
fn test_shmem_resize() -> Result<(), Error> {
let max_size = 1024 * 1024;
let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;
assert_eq!(init_struct.current_size(), 0);
// Initial grow
let size1 = 10000;
init_struct.set_size(size1).unwrap();
assert_eq!(init_struct.current_size(), size1);
// Write some data
let data_ptr = init_struct.data_ptr.as_ptr();
write_range(data_ptr, 0xAA, 0..size1);
assert_range(data_ptr, 0xAA, 0..size1);
// Shrink
let size2 = 5000;
init_struct.set_size(size2).unwrap();
assert_eq!(init_struct.current_size(), size2);
// Grow again
let size3 = 20000;
init_struct.set_size(size3).unwrap();
assert_eq!(init_struct.current_size(), size3);
// Try to read it. The area that was shrunk and grown again should read as all zeros now
assert_range(data_ptr, 0xAA, 0..5000);
assert_range(data_ptr, 0, 5000..size1);
// Try to grow beyond max_size
//let size4 = max_size + 1;
//assert!(init_struct.set_size(size4).is_err());
// Dropping init_struct should unmap the memory
drop(init_struct);
Ok(())
}
/// This is used in tests to coordinate between test processes. It's like std::sync::Barrier,
/// but is stored in the shared memory area and works across processes. It's implemented by
/// polling, because e.g. standard rust mutexes are not guaranteed to work across processes.
struct SimpleBarrier {
num_procs: usize,
count: AtomicUsize,
}
impl SimpleBarrier {
unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
unsafe {
*ptr = SimpleBarrier {
num_procs,
count: AtomicUsize::new(0),
}
}
}
pub fn wait(&self) {
let old = self.count.fetch_add(1, Ordering::Relaxed);
let generation = old / self.num_procs;
let mut current = old + 1;
while current < (generation + 1) * self.num_procs {
std::thread::sleep(std::time::Duration::from_millis(10));
current = self.count.load(Ordering::Relaxed);
}
}
}
#[test]
fn test_multi_process() {
// Initialize
let max_size = 1_000_000_000_000;
let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
let ptr = init_struct.data_ptr.as_ptr();
// Store the SimpleBarrier in the first 1k of the area.
init_struct.set_size(10000).unwrap();
let barrier_ptr: *mut SimpleBarrier = unsafe {
ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
.cast()
};
unsafe { SimpleBarrier::init(barrier_ptr, 2) };
let barrier = unsafe { barrier_ptr.as_ref().unwrap() };
// Fork another test process. The code after this runs in both processes concurrently.
let fork_result = unsafe { nix::unistd::fork().unwrap() };
// In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
if fork_result.is_parent() {
write_range(ptr, 0xAA, 1000..2000);
} else {
write_range(ptr, 0xBB, 2000..3000);
}
barrier.wait();
// Verify the contents. (in both processes)
assert_range(ptr, 0xAA, 1000..2000);
assert_range(ptr, 0xBB, 2000..3000);
// Grow, from the child this time
let size = 10_000_000;
if !fork_result.is_parent() {
init_struct.set_size(size).unwrap();
}
barrier.wait();
// make some writes at the end
if fork_result.is_parent() {
write_range(ptr, 0xAA, (size - 10)..size);
} else {
write_range(ptr, 0xBB, (size - 20)..(size - 10));
}
barrier.wait();
// Verify the contents. (This runs in both processes)
assert_range(ptr, 0, (size - 1000)..(size - 20));
assert_range(ptr, 0xBB, (size - 20)..(size - 10));
assert_range(ptr, 0xAA, (size - 10)..size);
if let ForkResult::Parent { child } = fork_result {
nix::sys::wait::waitpid(child, None).unwrap();
}
}
}

View File

@@ -1,418 +0,0 @@
//! Dynamically resizable contiguous chunk of shared memory
use std::num::NonZeroUsize;
use std::os::fd::{AsFd, BorrowedFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicUsize, Ordering};
use nix::errno::Errno;
use nix::sys::mman::MapFlags;
use nix::sys::mman::ProtFlags;
use nix::sys::mman::mmap as nix_mmap;
use nix::sys::mman::munmap as nix_munmap;
use nix::unistd::ftruncate as nix_ftruncate;
/// ShmemHandle represents a shared memory area that can be shared by processes over fork().
/// Unlike shared memory allocated by Postgres, this area is resizable, up to 'max_size' that's
/// specified at creation.
///
/// The area is backed by an anonymous file created with memfd_create(). The full address space for
/// 'max_size' is reserved up-front with mmap(), but whenever you call [`ShmemHandle::set_size`],
/// the underlying file is resized. Do not access the area beyond the current size. Currently, that
/// will cause the file to be expanded, but we might use mprotect() etc. to enforce that in the
/// future.
pub struct ShmemHandle {
/// memfd file descriptor
fd: OwnedFd,
max_size: usize,
// Pointer to the beginning of the shared memory area. The header is stored there.
shared_ptr: NonNull<SharedStruct>,
// Pointer to the beginning of the user data
pub data_ptr: NonNull<u8>,
}
/// This is stored at the beginning in the shared memory area.
struct SharedStruct {
max_size: usize,
/// Current size of the backing file. The high-order bit is used for the RESIZE_IN_PROGRESS flag
current_size: AtomicUsize,
}
const RESIZE_IN_PROGRESS: usize = 1 << 63;
const HEADER_SIZE: usize = std::mem::size_of::<SharedStruct>();
/// Error type returned by the ShmemHandle functions.
#[derive(thiserror::Error, Debug)]
#[error("{msg}: {errno}")]
pub struct Error {
pub msg: String,
pub errno: Errno,
}
impl Error {
fn new(msg: &str, errno: Errno) -> Error {
Error {
msg: msg.to_string(),
errno,
}
}
}
impl ShmemHandle {
/// Create a new shared memory area. To communicate between processes, the processes need to be
/// fork()'d after calling this, so that the ShmemHandle is inherited by all processes.
///
/// If the ShmemHandle is dropped, the memory is unmapped from the current process. Other
/// processes can continue using it, however.
pub fn new(name: &str, initial_size: usize, max_size: usize) -> Result<ShmemHandle, Error> {
// create the backing anonymous file.
let fd = create_backing_file(name)?;
Self::new_with_fd(fd, initial_size, max_size)
}
fn new_with_fd(
fd: OwnedFd,
initial_size: usize,
max_size: usize,
) -> Result<ShmemHandle, Error> {
// We reserve the high-order bit for the RESIZE_IN_PROGRESS flag, and the actual size
// is a little larger than this because of the SharedStruct header. Make the upper limit
// somewhat smaller than that, because with anything close to that, you'll run out of
// memory anyway.
if max_size >= 1 << 48 {
panic!("max size {} too large", max_size);
}
if initial_size > max_size {
panic!("initial size {initial_size} larger than max size {max_size}");
}
// The actual initial / max size is the one given by the caller, plus the size of
// 'SharedStruct'.
let initial_size = HEADER_SIZE + initial_size;
let max_size = NonZeroUsize::new(HEADER_SIZE + max_size).unwrap();
// Reserve address space for it with mmap
//
// TODO: Use MAP_HUGETLB if possible
let start_ptr = unsafe {
nix_mmap(
None,
max_size,
ProtFlags::PROT_READ | ProtFlags::PROT_WRITE,
MapFlags::MAP_SHARED,
&fd,
0,
)
}
.map_err(|e| Error::new("mmap failed: {e}", e))?;
// Reserve space for the initial size
enlarge_file(fd.as_fd(), initial_size as u64)?;
// Initialize the header
let shared: NonNull<SharedStruct> = start_ptr.cast();
unsafe {
shared.write(SharedStruct {
max_size: max_size.into(),
current_size: AtomicUsize::new(initial_size),
})
};
// The user data begins after the header
let data_ptr = unsafe { start_ptr.cast().add(HEADER_SIZE) };
Ok(ShmemHandle {
fd,
max_size: max_size.into(),
shared_ptr: shared,
data_ptr,
})
}
// return reference to the header
fn shared(&self) -> &SharedStruct {
unsafe { self.shared_ptr.as_ref() }
}
/// Resize the shared memory area. 'new_size' must not be larger than the 'max_size' specified
/// when creating the area.
///
/// This may only be called from one process/thread concurrently. We detect that case
/// and return an Error.
pub fn set_size(&self, new_size: usize) -> Result<(), Error> {
let new_size = new_size + HEADER_SIZE;
let shared = self.shared();
if new_size > self.max_size {
panic!(
"new size ({} is greater than max size ({})",
new_size, self.max_size
);
}
assert_eq!(self.max_size, shared.max_size);
// Lock the area by setting the bit in 'current_size'
//
// Ordering::Relaxed would probably be sufficient here, as we don't access any other memory
// and the posix_fallocate/ftruncate call is surely a synchronization point anyway. But
// since this is not performance-critical, better safe than sorry .
let mut old_size = shared.current_size.load(Ordering::Acquire);
loop {
if (old_size & RESIZE_IN_PROGRESS) != 0 {
return Err(Error::new(
"concurrent resize detected",
Errno::UnknownErrno,
));
}
match shared.current_size.compare_exchange(
old_size,
new_size,
Ordering::Acquire,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(x) => old_size = x,
}
}
// Ok, we got the lock.
//
// NB: If anything goes wrong, we *must* clear the bit!
let result = {
use std::cmp::Ordering::{Equal, Greater, Less};
match new_size.cmp(&old_size) {
Less => nix_ftruncate(&self.fd, new_size as i64).map_err(|e| {
Error::new("could not shrink shmem segment, ftruncate failed: {e}", e)
}),
Equal => Ok(()),
Greater => enlarge_file(self.fd.as_fd(), new_size as u64),
}
};
// Unlock
shared.current_size.store(
if result.is_ok() { new_size } else { old_size },
Ordering::Release,
);
result
}
/// Returns the current user-visible size of the shared memory segment.
///
/// NOTE: a concurrent set_size() call can change the size at any time. It is the caller's
/// responsibility not to access the area beyond the current size.
pub fn current_size(&self) -> usize {
let total_current_size =
self.shared().current_size.load(Ordering::Relaxed) & !RESIZE_IN_PROGRESS;
total_current_size - HEADER_SIZE
}
}
impl Drop for ShmemHandle {
fn drop(&mut self) {
// SAFETY: The pointer was obtained from mmap() with the given size.
// We unmap the entire region.
let _ = unsafe { nix_munmap(self.shared_ptr.cast(), self.max_size) };
// The fd is dropped automatically by OwnedFd.
}
}
/// Create a "backing file" for the shared memory area. On Linux, use memfd_create(), to create an
/// anonymous in-memory file. One macos, fall back to a regular file. That's good enough for
/// development and testing, but in production we want the file to stay in memory.
///
/// disable 'unused_variables' warnings, because in the macos path, 'name' is unused.
#[allow(unused_variables)]
fn create_backing_file(name: &str) -> Result<OwnedFd, Error> {
#[cfg(not(target_os = "macos"))]
{
nix::sys::memfd::memfd_create(name, nix::sys::memfd::MFdFlags::empty())
.map_err(|e| Error::new("memfd_create failed: {e}", e))
}
#[cfg(target_os = "macos")]
{
let file = tempfile::tempfile().map_err(|e| {
Error::new(
"could not create temporary file to back shmem area: {e}",
nix::errno::Errno::from_raw(e.raw_os_error().unwrap_or(0)),
)
})?;
Ok(OwnedFd::from(file))
}
}
fn enlarge_file(fd: BorrowedFd, size: u64) -> Result<(), Error> {
// Use posix_fallocate() to enlarge the file. It reserves the space correctly, so that
// we don't get a segfault later when trying to actually use it.
#[cfg(not(target_os = "macos"))]
{
nix::fcntl::posix_fallocate(fd, 0, size as i64).map_err(|e| {
Error::new(
"could not grow shmem segment, posix_fallocate failed: {e}",
e,
)
})
}
// As a fallback on macos, which doesn't have posix_fallocate, use plain 'fallocate'
#[cfg(target_os = "macos")]
{
nix::unistd::ftruncate(fd, size as i64)
.map_err(|e| Error::new("could not grow shmem segment, ftruncate failed: {e}", e))
}
}
#[cfg(test)]
mod tests {
use super::*;
use nix::unistd::ForkResult;
use std::ops::Range;
/// check that all bytes in given range have the expected value.
fn assert_range(ptr: *const u8, expected: u8, range: Range<usize>) {
for i in range {
let b = unsafe { *(ptr.add(i)) };
assert_eq!(expected, b, "unexpected byte at offset {}", i);
}
}
/// Write 'b' to all bytes in the given range
fn write_range(ptr: *mut u8, b: u8, range: Range<usize>) {
unsafe { std::ptr::write_bytes(ptr.add(range.start), b, range.end - range.start) };
}
// simple single-process test of growing and shrinking
#[test]
fn test_shmem_resize() -> Result<(), Error> {
let max_size = 1024 * 1024;
let init_struct = ShmemHandle::new("test_shmem_resize", 0, max_size)?;
assert_eq!(init_struct.current_size(), 0);
// Initial grow
let size1 = 10000;
init_struct.set_size(size1).unwrap();
assert_eq!(init_struct.current_size(), size1);
// Write some data
let data_ptr = init_struct.data_ptr.as_ptr();
write_range(data_ptr, 0xAA, 0..size1);
assert_range(data_ptr, 0xAA, 0..size1);
// Shrink
let size2 = 5000;
init_struct.set_size(size2).unwrap();
assert_eq!(init_struct.current_size(), size2);
// Grow again
let size3 = 20000;
init_struct.set_size(size3).unwrap();
assert_eq!(init_struct.current_size(), size3);
// Try to read it. The area that was shrunk and grown again should read as all zeros now
assert_range(data_ptr, 0xAA, 0..5000);
assert_range(data_ptr, 0, 5000..size1);
// Try to grow beyond max_size
//let size4 = max_size + 1;
//assert!(init_struct.set_size(size4).is_err());
// Dropping init_struct should unmap the memory
drop(init_struct);
Ok(())
}
/// This is used in tests to coordinate between test processes. It's like std::sync::Barrier,
/// but is stored in the shared memory area and works across processes. It's implemented by
/// polling, because e.g. standard rust mutexes are not guaranteed to work across processes.
struct SimpleBarrier {
num_procs: usize,
count: AtomicUsize,
}
impl SimpleBarrier {
unsafe fn init(ptr: *mut SimpleBarrier, num_procs: usize) {
unsafe {
*ptr = SimpleBarrier {
num_procs,
count: AtomicUsize::new(0),
}
}
}
pub fn wait(&self) {
let old = self.count.fetch_add(1, Ordering::Relaxed);
let generation = old / self.num_procs;
let mut current = old + 1;
while current < (generation + 1) * self.num_procs {
std::thread::sleep(std::time::Duration::from_millis(10));
current = self.count.load(Ordering::Relaxed);
}
}
}
#[test]
fn test_multi_process() {
// Initialize
let max_size = 1_000_000_000_000;
let init_struct = ShmemHandle::new("test_multi_process", 0, max_size).unwrap();
let ptr = init_struct.data_ptr.as_ptr();
// Store the SimpleBarrier in the first 1k of the area.
init_struct.set_size(10000).unwrap();
let barrier_ptr: *mut SimpleBarrier = unsafe {
ptr.add(ptr.align_offset(std::mem::align_of::<SimpleBarrier>()))
.cast()
};
unsafe { SimpleBarrier::init(barrier_ptr, 2) };
let barrier = unsafe { barrier_ptr.as_ref().unwrap() };
// Fork another test process. The code after this runs in both processes concurrently.
let fork_result = unsafe { nix::unistd::fork().unwrap() };
// In the parent, fill bytes between 1000..2000. In the child, between 2000..3000
if fork_result.is_parent() {
write_range(ptr, 0xAA, 1000..2000);
} else {
write_range(ptr, 0xBB, 2000..3000);
}
barrier.wait();
// Verify the contents. (in both processes)
assert_range(ptr, 0xAA, 1000..2000);
assert_range(ptr, 0xBB, 2000..3000);
// Grow, from the child this time
let size = 10_000_000;
if !fork_result.is_parent() {
init_struct.set_size(size).unwrap();
}
barrier.wait();
// make some writes at the end
if fork_result.is_parent() {
write_range(ptr, 0xAA, (size - 10)..size);
} else {
write_range(ptr, 0xBB, (size - 20)..(size - 10));
}
barrier.wait();
// Verify the contents. (This runs in both processes)
assert_range(ptr, 0, (size - 1000)..(size - 20));
assert_range(ptr, 0xBB, (size - 20)..(size - 10));
assert_range(ptr, 0xAA, (size - 10)..size);
if let ForkResult::Parent { child } = fork_result {
nix::sys::wait::waitpid(child, None).unwrap();
}
}
}

View File

@@ -1,14 +0,0 @@
[package]
name = "neonart"
version = "0.1.0"
edition.workspace = true
license.workspace = true
[dependencies]
crossbeam-utils.workspace = true
spin.workspace = true
tracing.workspace = true
[dev-dependencies]
rand = "0.9.1"
rand_distr = "0.5.1"

View File

@@ -1,594 +0,0 @@
mod lock_and_version;
pub(crate) mod node_ptr;
mod node_ref;
use std::vec::Vec;
use crate::algorithm::lock_and_version::ConcurrentUpdateError;
use crate::algorithm::node_ptr::MAX_PREFIX_LEN;
use crate::algorithm::node_ref::{NewNodeRef, NodeRef, ReadLockedNodeRef, WriteLockedNodeRef};
use crate::allocator::OutOfMemoryError;
use crate::TreeWriteGuard;
use crate::UpdateAction;
use crate::allocator::ArtAllocator;
use crate::epoch::EpochPin;
use crate::{Key, Value};
pub(crate) type RootPtr<V> = node_ptr::NodePtr<V>;
#[derive(Debug)]
pub enum ArtError {
ConcurrentUpdate, // need to retry
OutOfMemory,
}
impl From<ConcurrentUpdateError> for ArtError {
fn from(_: ConcurrentUpdateError) -> ArtError {
ArtError::ConcurrentUpdate
}
}
impl From<OutOfMemoryError> for ArtError {
fn from(_: OutOfMemoryError) -> ArtError {
ArtError::OutOfMemory
}
}
pub fn new_root<V: Value>(
allocator: &impl ArtAllocator<V>,
) -> Result<RootPtr<V>, OutOfMemoryError> {
node_ptr::new_root(allocator)
}
pub(crate) fn search<'e, K: Key, V: Value>(
key: &K,
root: RootPtr<V>,
epoch_pin: &'e EpochPin,
) -> Option<&'e V> {
loop {
let root_ref = NodeRef::from_root_ptr(root);
if let Ok(result) = lookup_recurse(key.as_bytes(), root_ref, None, epoch_pin) {
break result;
}
// retry
}
}
pub(crate) fn iter_next<'e, V: Value>(
key: &[u8],
root: RootPtr<V>,
epoch_pin: &'e EpochPin,
) -> Option<(Vec<u8>, &'e V)> {
loop {
let mut path = Vec::new();
let root_ref = NodeRef::from_root_ptr(root);
match next_recurse(key, &mut path, root_ref, epoch_pin) {
Ok(Some(v)) => {
assert_eq!(path.len(), key.len());
break Some((path, v));
}
Ok(None) => break None,
Err(ConcurrentUpdateError()) => {
// retry
continue;
}
}
}
}
pub(crate) fn update_fn<'e, 'g, K: Key, V: Value, A: ArtAllocator<V>, F>(
key: &K,
value_fn: F,
root: RootPtr<V>,
guard: &'g mut TreeWriteGuard<'e, K, V, A>,
) -> Result<(), OutOfMemoryError>
where
F: FnOnce(Option<&V>) -> UpdateAction<V>,
{
let value_fn_cell = std::cell::Cell::new(Some(value_fn));
loop {
let root_ref = NodeRef::from_root_ptr(root);
let this_value_fn = |arg: Option<&V>| value_fn_cell.take().unwrap()(arg);
let key_bytes = key.as_bytes();
match update_recurse(
key_bytes,
this_value_fn,
root_ref,
None,
None,
guard,
0,
key_bytes,
) {
Ok(()) => break Ok(()),
Err(ArtError::ConcurrentUpdate) => {
continue; // retry
}
Err(ArtError::OutOfMemory) => break Err(OutOfMemoryError()),
}
}
}
// Error means you must retry.
//
// This corresponds to the 'lookupOpt' function in the paper
fn lookup_recurse<'e, V: Value>(
key: &[u8],
node: NodeRef<'e, V>,
parent: Option<ReadLockedNodeRef<V>>,
epoch_pin: &'e EpochPin,
) -> Result<Option<&'e V>, ConcurrentUpdateError> {
let rnode = node.read_lock_or_restart()?;
if let Some(parent) = parent {
parent.read_unlock_or_restart()?;
}
// check if the prefix matches, may increment level
let prefix_len = if let Some(prefix_len) = rnode.prefix_matches(key) {
prefix_len
} else {
rnode.read_unlock_or_restart()?;
return Ok(None);
};
if rnode.is_leaf() {
assert_eq!(key.len(), prefix_len);
let vptr = rnode.get_leaf_value_ptr()?;
// safety: It's OK to return a ref of the pointer because we checked the version
// and the lifetime of 'epoch_pin' enforces that the reference is only accessible
// as long as the epoch is pinned.
let v = unsafe { vptr.as_ref().unwrap() };
return Ok(Some(v));
}
let key = &key[prefix_len..];
// find child (or leaf value)
let next_node = rnode.find_child_or_restart(key[0])?;
match next_node {
None => Ok(None), // key not found
Some(child) => lookup_recurse(&key[1..], child, Some(rnode), epoch_pin),
}
}
fn next_recurse<'e, V: Value>(
min_key: &[u8],
path: &mut Vec<u8>,
node: NodeRef<'e, V>,
epoch_pin: &'e EpochPin,
) -> Result<Option<&'e V>, ConcurrentUpdateError> {
let rnode = node.read_lock_or_restart()?;
let prefix = rnode.get_prefix();
if prefix.len() != 0 {
path.extend_from_slice(prefix);
}
use std::cmp::Ordering;
let comparison = path.as_slice().cmp(&min_key[0..path.len()]);
if comparison == Ordering::Less {
rnode.read_unlock_or_restart()?;
return Ok(None);
}
if rnode.is_leaf() {
assert_eq!(path.len(), min_key.len());
let vptr = rnode.get_leaf_value_ptr()?;
// safety: It's OK to return a ref of the pointer because we checked the version
// and the lifetime of 'epoch_pin' enforces that the reference is only accessible
// as long as the epoch is pinned.
let v = unsafe { vptr.as_ref().unwrap() };
return Ok(Some(v));
}
let mut min_key_byte = match comparison {
Ordering::Less => unreachable!(), // checked this above already
Ordering::Equal => min_key[path.len()],
Ordering::Greater => 0,
};
loop {
match rnode.find_next_child_or_restart(min_key_byte)? {
None => {
return Ok(None);
}
Some((key_byte, child_ref)) => {
let path_len = path.len();
path.push(key_byte);
let result = next_recurse(min_key, path, child_ref, epoch_pin)?;
if result.is_some() {
return Ok(result);
}
if key_byte == u8::MAX {
return Ok(None);
}
path.truncate(path_len);
min_key_byte = key_byte + 1;
}
}
}
}
// This corresponds to the 'insertOpt' function in the paper
pub(crate) fn update_recurse<'e, 'g, K: Key, V: Value, A: ArtAllocator<V>, F>(
key: &[u8],
value_fn: F,
node: NodeRef<'e, V>,
rparent: Option<(ReadLockedNodeRef<V>, u8)>,
rgrandparent: Option<(ReadLockedNodeRef<V>, u8)>,
guard: &'g mut TreeWriteGuard<'e, K, V, A>,
level: usize,
orig_key: &[u8],
) -> Result<(), ArtError>
where
F: FnOnce(Option<&V>) -> UpdateAction<V>,
{
let rnode = node.read_lock_or_restart()?;
let prefix_match_len = rnode.prefix_matches(key);
if prefix_match_len.is_none() {
let (rparent, parent_key) = rparent.expect("direct children of the root have no prefix");
let mut wparent = rparent.upgrade_to_write_lock_or_restart()?;
let mut wnode = rnode.upgrade_to_write_lock_or_restart()?;
match value_fn(None) {
UpdateAction::Nothing => {}
UpdateAction::Insert(new_value) => {
insert_split_prefix(key, new_value, &mut wnode, &mut wparent, parent_key, guard)?;
}
UpdateAction::Remove => {
panic!("unexpected Remove action on insertion");
}
}
wnode.write_unlock();
wparent.write_unlock();
return Ok(());
}
let prefix_match_len = prefix_match_len.unwrap();
let key = &key[prefix_match_len as usize..];
let level = level + prefix_match_len as usize;
if rnode.is_leaf() {
assert_eq!(key.len(), 0);
let (rparent, parent_key) = rparent.expect("root cannot be leaf");
let mut wparent = rparent.upgrade_to_write_lock_or_restart()?;
let mut wnode = rnode.upgrade_to_write_lock_or_restart()?;
// safety: Now that we have acquired the write lock, we have exclusive access to the
// value. XXX: There might be concurrent reads though?
let value_mut = wnode.get_leaf_value_mut();
match value_fn(Some(value_mut)) {
UpdateAction::Nothing => {
wparent.write_unlock();
wnode.write_unlock();
}
UpdateAction::Insert(_) => panic!("cannot insert over existing value"),
UpdateAction::Remove => {
guard.remember_obsolete_node(wnode.as_ptr());
wparent.delete_child(parent_key);
wnode.write_unlock_obsolete();
if let Some(rgrandparent) = rgrandparent {
// FIXME: Ignore concurrency error. It doesn't lead to
// corruption, but it means we might leak something. Until
// another update cleans it up.
let _ = cleanup_parent(wparent, rgrandparent, guard);
}
}
}
return Ok(());
}
let next_node = rnode.find_child_or_restart(key[0])?;
if next_node.is_none() {
if rnode.is_full() {
let (rparent, parent_key) = rparent.expect("root node cannot become full");
let mut wparent = rparent.upgrade_to_write_lock_or_restart()?;
let wnode = rnode.upgrade_to_write_lock_or_restart()?;
match value_fn(None) {
UpdateAction::Nothing => {
wnode.write_unlock();
wparent.write_unlock();
}
UpdateAction::Insert(new_value) => {
insert_and_grow(key, new_value, wnode, &mut wparent, parent_key, guard)?;
wparent.write_unlock();
}
UpdateAction::Remove => {
panic!("unexpected Remove action on insertion");
}
};
} else {
let mut wnode = rnode.upgrade_to_write_lock_or_restart()?;
if let Some((rparent, _)) = rparent {
rparent.read_unlock_or_restart()?;
}
match value_fn(None) {
UpdateAction::Nothing => {}
UpdateAction::Insert(new_value) => {
insert_to_node(&mut wnode, key, new_value, guard)?;
}
UpdateAction::Remove => {
panic!("unexpected Remove action on insertion");
}
};
wnode.write_unlock();
}
return Ok(());
} else {
let next_child = next_node.unwrap(); // checked above it's not None
if let Some((ref rparent, _)) = rparent {
rparent.check_or_restart()?;
}
// recurse to next level
update_recurse(
&key[1..],
value_fn,
next_child,
Some((rnode, key[0])),
rparent,
guard,
level + 1,
orig_key,
)
}
}
#[derive(Clone)]
enum PathElement {
Prefix(Vec<u8>),
KeyByte(u8),
}
impl std::fmt::Debug for PathElement {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
match self {
PathElement::Prefix(prefix) => write!(fmt, "{:?}", prefix),
PathElement::KeyByte(key_byte) => write!(fmt, "{}", key_byte),
}
}
}
pub(crate) fn dump_tree<'e, V: Value + std::fmt::Debug>(
root: RootPtr<V>,
epoch_pin: &'e EpochPin,
dst: &mut dyn std::io::Write,
) {
let root_ref = NodeRef::from_root_ptr(root);
let _ = dump_recurse(&[], root_ref, &epoch_pin, 0, dst);
}
// TODO: return an Err if writeln!() returns error, instead of unwrapping
fn dump_recurse<'e, V: Value + std::fmt::Debug>(
path: &[PathElement],
node: NodeRef<'e, V>,
epoch_pin: &'e EpochPin,
level: usize,
dst: &mut dyn std::io::Write,
) -> Result<(), ConcurrentUpdateError> {
let indent = str::repeat(" ", level);
let rnode = node.read_lock_or_restart()?;
let mut path = Vec::from(path);
let prefix = rnode.get_prefix();
if prefix.len() != 0 {
path.push(PathElement::Prefix(Vec::from(prefix)));
}
if rnode.is_leaf() {
let vptr = rnode.get_leaf_value_ptr()?;
// safety: It's OK to return a ref of the pointer because we checked the version
// and the lifetime of 'epoch_pin' enforces that the reference is only accessible
// as long as the epoch is pinned.
let val = unsafe { vptr.as_ref().unwrap() };
writeln!(dst, "{} {:?}: {:?}", indent, path, val).unwrap();
return Ok(());
}
for key_byte in 0..=u8::MAX {
match rnode.find_child_or_restart(key_byte)? {
None => continue,
Some(child_ref) => {
let rchild = child_ref.read_lock_or_restart()?;
writeln!(
dst,
"{} {:?}, {}: prefix {:?}",
indent,
&path,
key_byte,
rchild.get_prefix()
)
.unwrap();
let mut child_path = path.clone();
child_path.push(PathElement::KeyByte(key_byte));
dump_recurse(&child_path, child_ref, epoch_pin, level + 1, dst)?;
}
}
}
Ok(())
}
///```text
/// [fooba]r -> value
///
/// [foo]b -> [a]r -> value
/// e -> [ls]e -> value
///```
fn insert_split_prefix<'e, K: Key, V: Value, A: ArtAllocator<V>>(
key: &[u8],
value: V,
node: &mut WriteLockedNodeRef<V>,
parent: &mut WriteLockedNodeRef<V>,
parent_key: u8,
guard: &'e TreeWriteGuard<K, V, A>,
) -> Result<(), OutOfMemoryError> {
let old_node = node;
let old_prefix = old_node.get_prefix();
let common_prefix_len = common_prefix(key, old_prefix);
// Allocate a node for the new value.
let new_value_node = allocate_node_for_value(
&key[common_prefix_len + 1..],
value,
guard.tree_writer.allocator,
)?;
// Allocate a new internal node with the common prefix
// FIXME: deallocate 'new_value_node' on OOM
let mut prefix_node =
node_ref::new_internal(&key[..common_prefix_len], guard.tree_writer.allocator)?;
// Add the old node and the new nodes to the new internal node
prefix_node.insert_old_child(old_prefix[common_prefix_len], old_node);
prefix_node.insert_new_child(key[common_prefix_len], new_value_node);
// Modify the prefix of the old child in place
old_node.truncate_prefix(old_prefix.len() - common_prefix_len - 1);
// replace the pointer in the parent
parent.replace_child(parent_key, prefix_node.into_ptr());
Ok(())
}
fn insert_to_node<'e, K: Key, V: Value, A: ArtAllocator<V>>(
wnode: &mut WriteLockedNodeRef<V>,
key: &[u8],
value: V,
guard: &'e TreeWriteGuard<K, V, A>,
) -> Result<(), OutOfMemoryError> {
let value_child = allocate_node_for_value(&key[1..], value, guard.tree_writer.allocator)?;
wnode.insert_child(key[0], value_child.into_ptr());
Ok(())
}
// On entry: 'parent' and 'node' are locked
fn insert_and_grow<'e, 'g, K: Key, V: Value, A: ArtAllocator<V>>(
key: &[u8],
value: V,
wnode: WriteLockedNodeRef<V>,
parent: &mut WriteLockedNodeRef<V>,
parent_key_byte: u8,
guard: &'g mut TreeWriteGuard<'e, K, V, A>,
) -> Result<(), ArtError> {
let mut bigger_node = wnode.grow(guard.tree_writer.allocator)?;
// FIXME: deallocate 'bigger_node' on OOM
let value_child = allocate_node_for_value(&key[1..], value, guard.tree_writer.allocator)?;
bigger_node.insert_new_child(key[0], value_child);
// Replace the pointer in the parent
parent.replace_child(parent_key_byte, bigger_node.into_ptr());
guard.remember_obsolete_node(wnode.as_ptr());
wnode.write_unlock_obsolete();
Ok(())
}
fn cleanup_parent<'e, 'g, K: Key, V: Value, A: ArtAllocator<V>>(
wparent: WriteLockedNodeRef<V>,
rgrandparent: (ReadLockedNodeRef<V>, u8),
guard: &'g mut TreeWriteGuard<'e, K, V, A>,
) -> Result<(), ArtError> {
let (rgrandparent, grandparent_key_byte) = rgrandparent;
// If the parent becomes completely empty after the deletion, remove the parent from the
// grandparent. (This case is possible because we reserve only 8 bytes for the prefix.)
// TODO: not implemented.
// If the parent has only one child, replace the parent with the remaining child. (This is not
// possible if the child's prefix field cannot absorb the parent's)
if wparent.num_children() == 1 {
// Try to lock the remaining child. This can fail if the child is updated
// concurrently.
let (key_byte, remaining_child) = wparent.find_remaining_child();
let mut wremaining_child = remaining_child.write_lock_or_restart()?;
if 1 + wremaining_child.get_prefix().len() + wparent.get_prefix().len() <= MAX_PREFIX_LEN {
let mut wgrandparent = rgrandparent.upgrade_to_write_lock_or_restart()?;
// Ok, we have locked the leaf, the parent, the grandparent, and the parent's only
// remaining leaf. Proceed with the updates.
// Update the prefix on the remaining leaf
wremaining_child.prepend_prefix(wparent.get_prefix(), key_byte);
// Replace the pointer in the grandparent to point directly to the remaining leaf
wgrandparent.replace_child(grandparent_key_byte, wremaining_child.as_ptr());
// Mark the parent as deleted.
guard.remember_obsolete_node(wparent.as_ptr());
wparent.write_unlock_obsolete();
return Ok(());
}
}
// If the parent's children would fit on a smaller node type after the deletion, replace it with
// a smaller node.
if wparent.can_shrink() {
let mut wgrandparent = rgrandparent.upgrade_to_write_lock_or_restart()?;
let smaller_node = wparent.shrink(guard.tree_writer.allocator)?;
// Replace the pointer in the grandparent
wgrandparent.replace_child(grandparent_key_byte, smaller_node.into_ptr());
guard.remember_obsolete_node(wparent.as_ptr());
wparent.write_unlock_obsolete();
return Ok(());
}
// nothing to do
wparent.write_unlock();
Ok(())
}
// Allocate a new leaf node to hold 'value'. If the key is long, we
// may need to allocate new internal nodes to hold it too
fn allocate_node_for_value<'a, V: Value, A: ArtAllocator<V>>(
key: &[u8],
value: V,
allocator: &'a A,
) -> Result<NewNodeRef<'a, V, A>, OutOfMemoryError> {
let mut prefix_off = key.len().saturating_sub(MAX_PREFIX_LEN);
let leaf_node = node_ref::new_leaf(&key[prefix_off..key.len()], value, allocator)?;
let mut node = leaf_node;
while prefix_off > 0 {
// Need another internal node
let remain_prefix = &key[0..prefix_off];
prefix_off = remain_prefix.len().saturating_sub(MAX_PREFIX_LEN + 1);
let mut internal_node = node_ref::new_internal(
&remain_prefix[prefix_off..remain_prefix.len() - 1],
allocator,
)?;
internal_node.insert_new_child(*remain_prefix.last().unwrap(), node);
node = internal_node;
}
Ok(node)
}
fn common_prefix(a: &[u8], b: &[u8]) -> usize {
for i in 0..MAX_PREFIX_LEN {
if a[i] != b[i] {
return i;
}
}
panic!("prefixes are equal");
}

View File

@@ -1,117 +0,0 @@
//! Each node in the tree has contains one atomic word that stores three things:
//!
//! Bit 0: set if the node is "obsolete". An obsolete node has been removed from the tree,
//! but might still be accessed by concurrent readers until the epoch expires.
//! Bit 1: set if the node is currently write-locked. Used as a spinlock.
//! Bits 2-63: Version number, incremented every time the node is modified.
//!
//! AtomicLockAndVersion represents that.
use std::sync::atomic::{AtomicU64, Ordering};
pub(crate) struct ConcurrentUpdateError();
pub(crate) struct AtomicLockAndVersion {
inner: AtomicU64,
}
impl AtomicLockAndVersion {
pub(crate) fn new() -> AtomicLockAndVersion {
AtomicLockAndVersion {
inner: AtomicU64::new(0),
}
}
}
impl AtomicLockAndVersion {
pub(crate) fn read_lock_or_restart(&self) -> Result<u64, ConcurrentUpdateError> {
let version = self.await_node_unlocked();
if is_obsolete(version) {
return Err(ConcurrentUpdateError());
}
Ok(version)
}
pub(crate) fn check_or_restart(&self, version: u64) -> Result<(), ConcurrentUpdateError> {
self.read_unlock_or_restart(version)
}
pub(crate) fn read_unlock_or_restart(&self, version: u64) -> Result<(), ConcurrentUpdateError> {
if self.inner.load(Ordering::Acquire) != version {
return Err(ConcurrentUpdateError());
}
Ok(())
}
pub(crate) fn upgrade_to_write_lock_or_restart(
&self,
version: u64,
) -> Result<(), ConcurrentUpdateError> {
if self
.inner
.compare_exchange(
version,
set_locked_bit(version),
Ordering::Acquire,
Ordering::Relaxed,
)
.is_err()
{
return Err(ConcurrentUpdateError());
}
Ok(())
}
pub(crate) fn write_lock_or_restart(&self) -> Result<(), ConcurrentUpdateError> {
let old = self.inner.load(Ordering::Relaxed);
if is_obsolete(old) || is_locked(old) {
return Err(ConcurrentUpdateError());
}
if self
.inner
.compare_exchange(
old,
set_locked_bit(old),
Ordering::Acquire,
Ordering::Relaxed,
)
.is_err()
{
return Err(ConcurrentUpdateError());
}
Ok(())
}
pub(crate) fn write_unlock(&self) {
// reset locked bit and overflow into version
self.inner.fetch_add(2, Ordering::Release);
}
pub(crate) fn write_unlock_obsolete(&self) {
// set obsolete, reset locked, overflow into version
self.inner.fetch_add(3, Ordering::Release);
}
// Helper functions
fn await_node_unlocked(&self) -> u64 {
let mut version = self.inner.load(Ordering::Acquire);
while is_locked(version) {
// spinlock
std::thread::yield_now();
version = self.inner.load(Ordering::Acquire)
}
version
}
}
fn set_locked_bit(version: u64) -> u64 {
return version + 2;
}
fn is_obsolete(version: u64) -> bool {
return (version & 1) == 1;
}
fn is_locked(version: u64) -> bool {
return (version & 2) == 2;
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,349 +0,0 @@
use std::fmt::Debug;
use std::marker::PhantomData;
use super::node_ptr;
use super::node_ptr::NodePtr;
use crate::EpochPin;
use crate::Value;
use crate::algorithm::lock_and_version::AtomicLockAndVersion;
use crate::algorithm::lock_and_version::ConcurrentUpdateError;
use crate::allocator::ArtAllocator;
use crate::allocator::OutOfMemoryError;
pub struct NodeRef<'e, V> {
ptr: NodePtr<V>,
phantom: PhantomData<&'e EpochPin<'e>>,
}
impl<'e, V> Debug for NodeRef<'e, V> {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(fmt, "{:?}", self.ptr)
}
}
impl<'e, V: Value> NodeRef<'e, V> {
pub(crate) fn from_root_ptr(root_ptr: NodePtr<V>) -> NodeRef<'e, V> {
NodeRef {
ptr: root_ptr,
phantom: PhantomData,
}
}
pub(crate) fn read_lock_or_restart(
&self,
) -> Result<ReadLockedNodeRef<'e, V>, ConcurrentUpdateError> {
let version = self.lockword().read_lock_or_restart()?;
Ok(ReadLockedNodeRef {
ptr: self.ptr,
version,
phantom: self.phantom,
})
}
pub(crate) fn write_lock_or_restart(
&self,
) -> Result<WriteLockedNodeRef<'e, V>, ConcurrentUpdateError> {
self.lockword().write_lock_or_restart()?;
Ok(WriteLockedNodeRef {
ptr: self.ptr,
phantom: self.phantom,
})
}
fn lockword(&self) -> &AtomicLockAndVersion {
self.ptr.lockword()
}
}
/// A reference to a node that has been optimistically read-locked. The functions re-check
/// the version after each read.
pub struct ReadLockedNodeRef<'e, V> {
ptr: NodePtr<V>,
version: u64,
phantom: PhantomData<&'e EpochPin<'e>>,
}
impl<'e, V: Value> ReadLockedNodeRef<'e, V> {
pub(crate) fn is_leaf(&self) -> bool {
self.ptr.is_leaf()
}
pub(crate) fn is_full(&self) -> bool {
self.ptr.is_full()
}
pub(crate) fn get_prefix(&self) -> &[u8] {
self.ptr.get_prefix()
}
/// Note: because we're only holding a read lock, the prefix can change concurrently.
/// You must be prepared to restart, if read_unlock() returns error later.
///
/// Returns the length of the prefix, or None if it's not a match
pub(crate) fn prefix_matches(&self, key: &[u8]) -> Option<usize> {
self.ptr.prefix_matches(key)
}
pub(crate) fn find_child_or_restart(
&self,
key_byte: u8,
) -> Result<Option<NodeRef<'e, V>>, ConcurrentUpdateError> {
let child_or_value = self.ptr.find_child(key_byte);
self.ptr.lockword().check_or_restart(self.version)?;
match child_or_value {
None => Ok(None),
Some(child_ptr) => Ok(Some(NodeRef {
ptr: child_ptr,
phantom: self.phantom,
})),
}
}
pub(crate) fn find_next_child_or_restart(
&self,
min_key_byte: u8,
) -> Result<Option<(u8, NodeRef<'e, V>)>, ConcurrentUpdateError> {
let child_or_value = self.ptr.find_next_child(min_key_byte);
self.ptr.lockword().check_or_restart(self.version)?;
match child_or_value {
None => Ok(None),
Some((k, child_ptr)) => Ok(Some((
k,
NodeRef {
ptr: child_ptr,
phantom: self.phantom,
},
))),
}
}
pub(crate) fn get_leaf_value_ptr(&self) -> Result<*const V, ConcurrentUpdateError> {
let result = self.ptr.get_leaf_value();
self.ptr.lockword().check_or_restart(self.version)?;
// Extend the lifetime.
let result = std::ptr::from_ref(result);
Ok(result)
}
pub(crate) fn upgrade_to_write_lock_or_restart(
self,
) -> Result<WriteLockedNodeRef<'e, V>, ConcurrentUpdateError> {
self.ptr
.lockword()
.upgrade_to_write_lock_or_restart(self.version)?;
Ok(WriteLockedNodeRef {
ptr: self.ptr,
phantom: self.phantom,
})
}
pub(crate) fn read_unlock_or_restart(self) -> Result<(), ConcurrentUpdateError> {
self.ptr.lockword().check_or_restart(self.version)?;
Ok(())
}
pub(crate) fn check_or_restart(&self) -> Result<(), ConcurrentUpdateError> {
self.ptr.lockword().check_or_restart(self.version)?;
Ok(())
}
}
/// A reference to a node that has been optimistically read-locked. The functions re-check
/// the version after each read.
pub struct WriteLockedNodeRef<'e, V> {
ptr: NodePtr<V>,
phantom: PhantomData<&'e EpochPin<'e>>,
}
impl<'e, V: Value> WriteLockedNodeRef<'e, V> {
pub(crate) fn can_shrink(&self) -> bool {
self.ptr.can_shrink()
}
pub(crate) fn num_children(&self) -> usize {
self.ptr.num_children()
}
pub(crate) fn write_unlock(mut self) {
self.ptr.lockword().write_unlock();
self.ptr = NodePtr::null();
}
pub(crate) fn write_unlock_obsolete(mut self) {
self.ptr.lockword().write_unlock_obsolete();
self.ptr = NodePtr::null();
}
pub(crate) fn get_prefix(&self) -> &[u8] {
self.ptr.get_prefix()
}
pub(crate) fn truncate_prefix(&mut self, new_prefix_len: usize) {
self.ptr.truncate_prefix(new_prefix_len)
}
pub(crate) fn prepend_prefix(&mut self, prefix: &[u8], prefix_byte: u8) {
self.ptr.prepend_prefix(prefix, prefix_byte)
}
pub(crate) fn insert_child(&mut self, key_byte: u8, child: NodePtr<V>) {
self.ptr.insert_child(key_byte, child)
}
pub(crate) fn get_leaf_value_mut(&mut self) -> &mut V {
self.ptr.get_leaf_value_mut()
}
pub(crate) fn grow<'a, A>(
&self,
allocator: &'a A,
) -> Result<NewNodeRef<'a, V, A>, OutOfMemoryError>
where
A: ArtAllocator<V>,
{
let new_node = self.ptr.grow(allocator)?;
Ok(NewNodeRef {
ptr: new_node,
allocator,
extra_nodes: Vec::new(),
})
}
pub(crate) fn shrink<'a, A>(
&self,
allocator: &'a A,
) -> Result<NewNodeRef<'a, V, A>, OutOfMemoryError>
where
A: ArtAllocator<V>,
{
let new_node = self.ptr.shrink(allocator)?;
Ok(NewNodeRef {
ptr: new_node,
allocator,
extra_nodes: Vec::new(),
})
}
pub(crate) fn as_ptr(&self) -> NodePtr<V> {
self.ptr
}
pub(crate) fn replace_child(&mut self, key_byte: u8, replacement: NodePtr<V>) {
self.ptr.replace_child(key_byte, replacement);
}
pub(crate) fn delete_child(&mut self, key_byte: u8) {
self.ptr.delete_child(key_byte);
}
pub(crate) fn find_remaining_child(&self) -> (u8, NodeRef<'e, V>) {
assert_eq!(self.num_children(), 1);
let child_or_value = self.ptr.find_next_child(0);
match child_or_value {
None => panic!("could not find only child in node"),
Some((k, child_ptr)) => (
k,
NodeRef {
ptr: child_ptr,
phantom: self.phantom,
},
),
}
}
}
impl<'e, V> Drop for WriteLockedNodeRef<'e, V> {
fn drop(&mut self) {
if !self.ptr.is_null() {
self.ptr.lockword().write_unlock();
}
}
}
pub(crate) struct NewNodeRef<'a, V, A>
where
V: Value,
A: ArtAllocator<V>,
{
ptr: NodePtr<V>,
allocator: &'a A,
extra_nodes: Vec<NodePtr<V>>,
}
impl<'a, V, A> NewNodeRef<'a, V, A>
where
V: Value,
A: ArtAllocator<V>,
{
pub(crate) fn insert_old_child(&mut self, key_byte: u8, child: &WriteLockedNodeRef<V>) {
self.ptr.insert_child(key_byte, child.as_ptr())
}
pub(crate) fn into_ptr(mut self) -> NodePtr<V> {
let ptr = self.ptr;
self.ptr = NodePtr::null();
ptr
}
pub(crate) fn insert_new_child(&mut self, key_byte: u8, child: NewNodeRef<'a, V, A>) {
let child_ptr = child.into_ptr();
self.ptr.insert_child(key_byte, child_ptr);
self.extra_nodes.push(child_ptr);
}
}
impl<'a, V, A> Drop for NewNodeRef<'a, V, A>
where
V: Value,
A: ArtAllocator<V>,
{
/// This drop implementation deallocates the newly allocated node, if into_ptr() was not called.
fn drop(&mut self) {
if !self.ptr.is_null() {
self.ptr.deallocate(self.allocator);
for p in self.extra_nodes.iter() {
p.deallocate(self.allocator);
}
}
}
}
pub(crate) fn new_internal<'a, V, A>(
prefix: &[u8],
allocator: &'a A,
) -> Result<NewNodeRef<'a, V, A>, OutOfMemoryError>
where
V: Value,
A: ArtAllocator<V>,
{
Ok(NewNodeRef {
ptr: node_ptr::new_internal(prefix, allocator)?,
allocator,
extra_nodes: Vec::new(),
})
}
pub(crate) fn new_leaf<'a, V, A>(
prefix: &[u8],
value: V,
allocator: &'a A,
) -> Result<NewNodeRef<'a, V, A>, OutOfMemoryError>
where
V: Value,
A: ArtAllocator<V>,
{
Ok(NewNodeRef {
ptr: node_ptr::new_leaf(prefix, value, allocator)?,
allocator,
extra_nodes: Vec::new(),
})
}

View File

@@ -1,158 +0,0 @@
pub mod block;
mod multislab;
mod slab;
pub mod r#static;
use std::alloc::Layout;
use std::marker::PhantomData;
use std::mem::MaybeUninit;
use std::sync::atomic::Ordering;
use crate::allocator::multislab::MultiSlabAllocator;
use crate::allocator::r#static::alloc_from_slice;
use spin;
use crate::Tree;
pub use crate::algorithm::node_ptr::{
NodeInternal4, NodeInternal16, NodeInternal48, NodeInternal256, NodeLeaf,
};
#[derive(Debug)]
pub struct OutOfMemoryError();
pub trait ArtAllocator<V: crate::Value> {
fn alloc_tree(&self) -> *mut Tree<V>;
fn alloc_node_internal4(&self) -> *mut NodeInternal4<V>;
fn alloc_node_internal16(&self) -> *mut NodeInternal16<V>;
fn alloc_node_internal48(&self) -> *mut NodeInternal48<V>;
fn alloc_node_internal256(&self) -> *mut NodeInternal256<V>;
fn alloc_node_leaf(&self) -> *mut NodeLeaf<V>;
fn dealloc_node_internal4(&self, ptr: *mut NodeInternal4<V>);
fn dealloc_node_internal16(&self, ptr: *mut NodeInternal16<V>);
fn dealloc_node_internal48(&self, ptr: *mut NodeInternal48<V>);
fn dealloc_node_internal256(&self, ptr: *mut NodeInternal256<V>);
fn dealloc_node_leaf(&self, ptr: *mut NodeLeaf<V>);
}
pub struct ArtMultiSlabAllocator<'t, V>
where
V: crate::Value,
{
tree_area: spin::Mutex<Option<&'t mut MaybeUninit<Tree<V>>>>,
pub(crate) inner: MultiSlabAllocator<'t, 5>,
phantom_val: PhantomData<V>,
}
impl<'t, V: crate::Value> ArtMultiSlabAllocator<'t, V> {
const LAYOUTS: [Layout; 5] = [
Layout::new::<NodeInternal4<V>>(),
Layout::new::<NodeInternal16<V>>(),
Layout::new::<NodeInternal48<V>>(),
Layout::new::<NodeInternal256<V>>(),
Layout::new::<NodeLeaf<V>>(),
];
pub fn new(area: &'t mut [MaybeUninit<u8>]) -> &'t mut ArtMultiSlabAllocator<'t, V> {
let (allocator_area, remain) = alloc_from_slice::<ArtMultiSlabAllocator<V>>(area);
let (tree_area, remain) = alloc_from_slice::<Tree<V>>(remain);
let allocator = allocator_area.write(ArtMultiSlabAllocator {
tree_area: spin::Mutex::new(Some(tree_area)),
inner: MultiSlabAllocator::new(remain, &Self::LAYOUTS),
phantom_val: PhantomData,
});
allocator
}
}
impl<'t, V: crate::Value> ArtAllocator<V> for ArtMultiSlabAllocator<'t, V> {
fn alloc_tree(&self) -> *mut Tree<V> {
let mut t = self.tree_area.lock();
if let Some(tree_area) = t.take() {
return tree_area.as_mut_ptr().cast();
}
panic!("cannot allocate more than one tree");
}
fn alloc_node_internal4(&self) -> *mut NodeInternal4<V> {
self.inner.alloc_slab(0).cast()
}
fn alloc_node_internal16(&self) -> *mut NodeInternal16<V> {
self.inner.alloc_slab(1).cast()
}
fn alloc_node_internal48(&self) -> *mut NodeInternal48<V> {
self.inner.alloc_slab(2).cast()
}
fn alloc_node_internal256(&self) -> *mut NodeInternal256<V> {
self.inner.alloc_slab(3).cast()
}
fn alloc_node_leaf(&self) -> *mut NodeLeaf<V> {
self.inner.alloc_slab(4).cast()
}
fn dealloc_node_internal4(&self, ptr: *mut NodeInternal4<V>) {
self.inner.dealloc_slab(0, ptr.cast())
}
fn dealloc_node_internal16(&self, ptr: *mut NodeInternal16<V>) {
self.inner.dealloc_slab(1, ptr.cast())
}
fn dealloc_node_internal48(&self, ptr: *mut NodeInternal48<V>) {
self.inner.dealloc_slab(2, ptr.cast())
}
fn dealloc_node_internal256(&self, ptr: *mut NodeInternal256<V>) {
self.inner.dealloc_slab(3, ptr.cast())
}
fn dealloc_node_leaf(&self, ptr: *mut NodeLeaf<V>) {
self.inner.dealloc_slab(4, ptr.cast())
}
}
impl<'t, V: crate::Value> ArtMultiSlabAllocator<'t, V> {
pub(crate) fn get_statistics(&self) -> ArtMultiSlabStats {
ArtMultiSlabStats {
num_internal4: self.inner.slab_descs[0]
.num_allocated
.load(Ordering::Relaxed),
num_internal16: self.inner.slab_descs[1]
.num_allocated
.load(Ordering::Relaxed),
num_internal48: self.inner.slab_descs[2]
.num_allocated
.load(Ordering::Relaxed),
num_internal256: self.inner.slab_descs[3]
.num_allocated
.load(Ordering::Relaxed),
num_leaf: self.inner.slab_descs[4]
.num_allocated
.load(Ordering::Relaxed),
num_blocks_internal4: self.inner.slab_descs[0].num_blocks.load(Ordering::Relaxed),
num_blocks_internal16: self.inner.slab_descs[1].num_blocks.load(Ordering::Relaxed),
num_blocks_internal48: self.inner.slab_descs[2].num_blocks.load(Ordering::Relaxed),
num_blocks_internal256: self.inner.slab_descs[3].num_blocks.load(Ordering::Relaxed),
num_blocks_leaf: self.inner.slab_descs[4].num_blocks.load(Ordering::Relaxed),
}
}
}
#[derive(Clone, Debug)]
pub struct ArtMultiSlabStats {
pub num_internal4: u64,
pub num_internal16: u64,
pub num_internal48: u64,
pub num_internal256: u64,
pub num_leaf: u64,
pub num_blocks_internal4: u64,
pub num_blocks_internal16: u64,
pub num_blocks_internal48: u64,
pub num_blocks_internal256: u64,
pub num_blocks_leaf: u64,
}

View File

@@ -1,191 +0,0 @@
//! Simple allocator of fixed-size blocks
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicU64, Ordering};
use spin;
pub const BLOCK_SIZE: usize = 16 * 1024;
const INVALID_BLOCK: u64 = u64::MAX;
pub(crate) struct BlockAllocator<'t> {
blocks_ptr: &'t [MaybeUninit<u8>],
num_blocks: u64,
num_initialized: AtomicU64,
freelist_head: spin::Mutex<u64>,
}
struct FreeListBlock {
inner: spin::Mutex<FreeListBlockInner>,
}
struct FreeListBlockInner {
next: u64,
num_free_blocks: u64,
free_blocks: [u64; 100], // FIXME: fill the rest of the block
}
impl<'t> BlockAllocator<'t> {
pub(crate) fn new(area: &'t mut [MaybeUninit<u8>]) -> Self {
// Use all the space for the blocks
let padding = area.as_ptr().align_offset(BLOCK_SIZE);
let remain = &mut area[padding..];
let num_blocks = (remain.len() / BLOCK_SIZE) as u64;
BlockAllocator {
blocks_ptr: remain,
num_blocks,
num_initialized: AtomicU64::new(0),
freelist_head: spin::Mutex::new(INVALID_BLOCK),
}
}
/// safety: you must hold a lock on the pointer to this block, otherwise it might get
/// reused for another kind of block
fn read_freelist_block(&self, blkno: u64) -> &FreeListBlock {
let ptr: *const FreeListBlock = self.get_block_ptr(blkno).cast();
unsafe { ptr.as_ref().unwrap() }
}
fn get_block_ptr(&self, blkno: u64) -> *mut u8 {
assert!(blkno < self.num_blocks);
unsafe {
self.blocks_ptr
.as_ptr()
.byte_offset(blkno as isize * BLOCK_SIZE as isize)
}
.cast_mut()
.cast()
}
#[allow(clippy::mut_from_ref)]
pub(crate) fn alloc_block(&self) -> &mut [MaybeUninit<u8>] {
// FIXME: handle OOM
let blkno = self.alloc_block_internal();
if blkno == INVALID_BLOCK {
panic!("out of memory");
}
let ptr: *mut MaybeUninit<u8> = self.get_block_ptr(blkno).cast();
unsafe { std::slice::from_raw_parts_mut(ptr, BLOCK_SIZE) }
}
fn alloc_block_internal(&self) -> u64 {
// check the free list.
{
let mut freelist_head = self.freelist_head.lock();
if *freelist_head != INVALID_BLOCK {
let freelist_block = self.read_freelist_block(*freelist_head);
// acquire lock on the freelist block before releasing the lock on the parent (i.e. lock coupling)
let mut g = freelist_block.inner.lock();
if g.num_free_blocks > 0 {
g.num_free_blocks -= 1;
let result = g.free_blocks[g.num_free_blocks as usize];
return result;
} else {
// consume the freelist block itself
let result = *freelist_head;
*freelist_head = g.next;
// This freelist block is now unlinked and can be repurposed
drop(g);
return result;
}
}
}
// If there are some blocks left that we've never used, pick next such block
let mut next_uninitialized = self.num_initialized.load(Ordering::Relaxed);
while next_uninitialized < self.num_blocks {
match self.num_initialized.compare_exchange(
next_uninitialized,
next_uninitialized + 1,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => {
return next_uninitialized;
}
Err(old) => {
next_uninitialized = old;
continue;
}
}
}
// out of blocks
return INVALID_BLOCK;
}
// TODO: this is currently unused. The slab allocator never releases blocks
#[allow(dead_code)]
pub(crate) fn release_block(&self, block_ptr: *mut u8) {
let blockno = unsafe { block_ptr.byte_offset_from(self.blocks_ptr) / BLOCK_SIZE as isize };
self.release_block_internal(blockno as u64);
}
fn release_block_internal(&self, blockno: u64) {
let mut freelist_head = self.freelist_head.lock();
if *freelist_head != INVALID_BLOCK {
let freelist_block = self.read_freelist_block(*freelist_head);
// acquire lock on the freelist block before releasing the lock on the parent (i.e. lock coupling)
let mut g = freelist_block.inner.lock();
let num_free_blocks = g.num_free_blocks;
if num_free_blocks < g.free_blocks.len() as u64 {
g.free_blocks[num_free_blocks as usize] = blockno;
g.num_free_blocks += 1;
return;
}
}
// Convert the block into a new freelist block
let block_ptr: *mut FreeListBlock = self.get_block_ptr(blockno).cast();
let init = FreeListBlock {
inner: spin::Mutex::new(FreeListBlockInner {
next: *freelist_head,
num_free_blocks: 0,
free_blocks: [INVALID_BLOCK; 100],
}),
};
unsafe { (*block_ptr) = init };
*freelist_head = blockno;
}
// for debugging
pub(crate) fn get_statistics(&self) -> BlockAllocatorStats {
let mut num_free_blocks = 0;
let mut _prev_lock = None;
let head_lock = self.freelist_head.lock();
let mut next_blk = *head_lock;
let mut _head_lock = Some(head_lock);
while next_blk != INVALID_BLOCK {
let freelist_block = self.read_freelist_block(next_blk);
let lock = freelist_block.inner.lock();
num_free_blocks += lock.num_free_blocks;
next_blk = lock.next;
_prev_lock = Some(lock); // hold the lock until we've read the next block
_head_lock = None;
}
BlockAllocatorStats {
num_blocks: self.num_blocks,
num_initialized: self.num_initialized.load(Ordering::Relaxed),
num_free_blocks,
}
}
}
#[derive(Clone, Debug)]
pub struct BlockAllocatorStats {
pub num_blocks: u64,
pub num_initialized: u64,
pub num_free_blocks: u64,
}

View File

@@ -1,33 +0,0 @@
use std::alloc::Layout;
use std::mem::MaybeUninit;
use crate::allocator::block::BlockAllocator;
use crate::allocator::slab::SlabDesc;
pub struct MultiSlabAllocator<'t, const N: usize> {
pub(crate) block_allocator: BlockAllocator<'t>,
pub(crate) slab_descs: [SlabDesc; N],
}
impl<'t, const N: usize> MultiSlabAllocator<'t, N> {
pub(crate) fn new(
area: &'t mut [MaybeUninit<u8>],
layouts: &[Layout; N],
) -> MultiSlabAllocator<'t, N> {
let block_allocator = BlockAllocator::new(area);
MultiSlabAllocator {
block_allocator,
slab_descs: std::array::from_fn(|i| SlabDesc::new(&layouts[i])),
}
}
pub(crate) fn alloc_slab(&self, slab_idx: usize) -> *mut u8 {
self.slab_descs[slab_idx].alloc_chunk(&self.block_allocator)
}
pub(crate) fn dealloc_slab(&self, slab_idx: usize, ptr: *mut u8) {
self.slab_descs[slab_idx].dealloc_chunk(ptr, &self.block_allocator)
}
}

View File

@@ -1,432 +0,0 @@
//! A slab allocator that carves out fixed-size chunks from larger blocks.
//!
//!
use std::alloc::Layout;
use std::mem::MaybeUninit;
use std::ops::Deref;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use spin;
use super::alloc_from_slice;
use super::block::BlockAllocator;
use crate::allocator::block::BLOCK_SIZE;
pub(crate) struct SlabDesc {
pub(crate) layout: Layout,
block_lists: spin::RwLock<BlockLists>,
pub(crate) num_blocks: AtomicU64,
pub(crate) num_allocated: AtomicU64,
}
// FIXME: Not sure if SlabDesc is really Sync or Send. It probably is when it's empty, but
// 'block_lists' contains pointers when it's not empty. In the current use as part of the
// the art tree, SlabDescs are only moved during initialization.
unsafe impl Sync for SlabDesc {}
unsafe impl Send for SlabDesc {}
#[derive(Default, Debug)]
struct BlockLists {
full_blocks: BlockList,
nonfull_blocks: BlockList,
}
impl BlockLists {
// Unlink a node. It must be in either one of the two lists.
unsafe fn unlink(&mut self, elem: *mut SlabBlockHeader) {
let list = unsafe {
if (*elem).next.is_null() {
if self.full_blocks.tail == elem {
Some(&mut self.full_blocks)
} else {
Some(&mut self.nonfull_blocks)
}
} else if (*elem).prev.is_null() {
if self.full_blocks.head == elem {
Some(&mut self.full_blocks)
} else {
Some(&mut self.nonfull_blocks)
}
} else {
None
}
};
unsafe { unlink_slab_block(list, elem) };
}
}
unsafe fn unlink_slab_block(mut list: Option<&mut BlockList>, elem: *mut SlabBlockHeader) {
unsafe {
if (*elem).next.is_null() {
assert_eq!(list.as_ref().unwrap().tail, elem);
list.as_mut().unwrap().tail = (*elem).prev;
} else {
assert_eq!((*(*elem).next).prev, elem);
(*(*elem).next).prev = (*elem).prev;
}
if (*elem).prev.is_null() {
assert_eq!(list.as_ref().unwrap().head, elem);
list.as_mut().unwrap().head = (*elem).next;
} else {
assert_eq!((*(*elem).prev).next, elem);
(*(*elem).prev).next = (*elem).next;
}
}
}
#[derive(Debug)]
struct BlockList {
head: *mut SlabBlockHeader,
tail: *mut SlabBlockHeader,
}
impl Default for BlockList {
fn default() -> Self {
BlockList {
head: std::ptr::null_mut(),
tail: std::ptr::null_mut(),
}
}
}
impl BlockList {
unsafe fn push_head(&mut self, elem: *mut SlabBlockHeader) {
unsafe {
if self.is_empty() {
self.tail = elem;
(*elem).next = std::ptr::null_mut();
} else {
(*elem).next = self.head;
(*self.head).prev = elem;
}
(*elem).prev = std::ptr::null_mut();
self.head = elem;
}
}
fn is_empty(&self) -> bool {
self.head.is_null()
}
unsafe fn unlink(&mut self, elem: *mut SlabBlockHeader) {
unsafe { unlink_slab_block(Some(self), elem) }
}
#[cfg(test)]
fn dump(&self) {
let mut next = self.head;
while !next.is_null() {
let n = unsafe { next.as_ref() }.unwrap();
eprintln!(
" blk {:?} (free {}/{})",
next,
n.num_free_chunks.load(Ordering::Relaxed),
n.num_chunks
);
next = n.next;
}
}
}
impl SlabDesc {
pub(crate) fn new(layout: &Layout) -> SlabDesc {
SlabDesc {
layout: *layout,
block_lists: spin::RwLock::new(BlockLists::default()),
num_allocated: AtomicU64::new(0),
num_blocks: AtomicU64::new(0),
}
}
}
#[derive(Debug)]
struct SlabBlockHeader {
free_chunks_head: spin::Mutex<*mut FreeChunk>,
num_free_chunks: AtomicU32,
num_chunks: u32, // this is really a constant for a given Layout
// these fields are protected by the lock on the BlockLists
prev: *mut SlabBlockHeader,
next: *mut SlabBlockHeader,
}
struct FreeChunk {
next: *mut FreeChunk,
}
enum ReadOrWriteGuard<'a, T> {
Read(spin::RwLockReadGuard<'a, T>),
Write(spin::RwLockWriteGuard<'a, T>),
}
impl<'a, T> Deref for ReadOrWriteGuard<'a, T> {
type Target = T;
fn deref(&self) -> &<Self as Deref>::Target {
match self {
ReadOrWriteGuard::Read(g) => g.deref(),
ReadOrWriteGuard::Write(g) => g.deref(),
}
}
}
impl SlabDesc {
pub fn alloc_chunk(&self, block_allocator: &BlockAllocator) -> *mut u8 {
// Are there any free chunks?
let mut acquire_write = false;
'outer: loop {
let mut block_lists_guard = if acquire_write {
ReadOrWriteGuard::Write(self.block_lists.write())
} else {
ReadOrWriteGuard::Read(self.block_lists.read())
};
'inner: loop {
let block_ptr = block_lists_guard.nonfull_blocks.head;
if block_ptr.is_null() {
break 'outer;
}
unsafe {
let mut free_chunks_head = (*block_ptr).free_chunks_head.lock();
if !(*free_chunks_head).is_null() {
let result = *free_chunks_head;
(*free_chunks_head) = (*result).next;
let _old = (*block_ptr).num_free_chunks.fetch_sub(1, Ordering::Relaxed);
self.num_allocated.fetch_add(1, Ordering::Relaxed);
return result.cast();
}
}
// The block at the head of the list was full. Grab write lock and retry
match block_lists_guard {
ReadOrWriteGuard::Read(_) => {
acquire_write = true;
continue 'outer;
}
ReadOrWriteGuard::Write(ref mut g) => {
// move the node to the list of full blocks
unsafe {
g.nonfull_blocks.unlink(block_ptr);
g.full_blocks.push_head(block_ptr);
};
continue 'inner;
}
}
}
}
// no free chunks. Allocate a new block (and the chunk from that)
let (new_block, new_chunk) = self.alloc_block_and_chunk(block_allocator);
self.num_blocks.fetch_add(1, Ordering::Relaxed);
// Add the block to the list in the SlabDesc
unsafe {
let mut block_lists_guard = self.block_lists.write();
block_lists_guard.nonfull_blocks.push_head(new_block);
}
self.num_allocated.fetch_add(1, Ordering::Relaxed);
new_chunk
}
pub fn dealloc_chunk(&self, chunk_ptr: *mut u8, _block_allocator: &BlockAllocator) {
// Find the block it belongs to. You can find the block from the address. (And knowing the
// layout, you could calculate the chunk number too.)
let block_ptr: *mut SlabBlockHeader = {
let block_addr = (chunk_ptr.addr() / BLOCK_SIZE) * BLOCK_SIZE;
chunk_ptr.with_addr(block_addr).cast()
};
let chunk_ptr: *mut FreeChunk = chunk_ptr.cast();
// Mark the chunk as free in 'freechunks' list
let num_chunks;
let num_free_chunks;
unsafe {
let mut free_chunks_head = (*block_ptr).free_chunks_head.lock();
(*chunk_ptr).next = *free_chunks_head;
*free_chunks_head = chunk_ptr;
num_free_chunks = (*block_ptr).num_free_chunks.fetch_add(1, Ordering::Relaxed) + 1;
num_chunks = (*block_ptr).num_chunks;
}
if num_free_chunks == 1 {
// If the block was full previously, add it to the nonfull blocks list. Note that
// we're not holding the lock anymore, so it can immediately become full again.
// That's harmless, it will be moved back to the full list again when a call
// to alloc_chunk() sees it.
let mut block_lists = self.block_lists.write();
unsafe {
block_lists.unlink(block_ptr);
block_lists.nonfull_blocks.push_head(block_ptr);
};
} else if num_free_chunks == num_chunks {
// If the block became completely empty, move it to the free list
// TODO
// FIXME: we're still holding the spinlock. It's not exactly safe to return it to
// the free blocks list, is it? Defer it as garbage to wait out concurrent updates?
//block_allocator.release_block()
}
// update stats
self.num_allocated.fetch_sub(1, Ordering::Relaxed);
}
fn alloc_block_and_chunk(
&self,
block_allocator: &BlockAllocator,
) -> (*mut SlabBlockHeader, *mut u8) {
// fixme: handle OOM
let block_slice: &mut [MaybeUninit<u8>] = block_allocator.alloc_block();
let (block_header, remain) = alloc_from_slice::<SlabBlockHeader>(block_slice);
let padding = remain.as_ptr().align_offset(self.layout.align());
let num_chunks = (remain.len() - padding) / self.layout.size();
let first_chunk_ptr: *mut FreeChunk = remain[padding..].as_mut_ptr().cast();
unsafe {
let mut chunk_ptr = first_chunk_ptr;
for _ in 0..num_chunks - 1 {
let next_chunk_ptr = chunk_ptr.byte_add(self.layout.size());
(*chunk_ptr).next = next_chunk_ptr;
chunk_ptr = next_chunk_ptr;
}
(*chunk_ptr).next = std::ptr::null_mut();
let result_chunk = first_chunk_ptr;
let block_header = block_header.write(SlabBlockHeader {
free_chunks_head: spin::Mutex::new((*first_chunk_ptr).next),
prev: std::ptr::null_mut(),
next: std::ptr::null_mut(),
num_chunks: num_chunks as u32,
num_free_chunks: AtomicU32::new(num_chunks as u32 - 1),
});
(block_header, result_chunk.cast())
}
}
#[cfg(test)]
fn dump(&self) {
eprintln!(
"slab dump ({} blocks, {} allocated chunks)",
self.num_blocks.load(Ordering::Relaxed),
self.num_allocated.load(Ordering::Relaxed)
);
let lists = self.block_lists.read();
eprintln!("nonfull blocks:");
lists.nonfull_blocks.dump();
eprintln!("full blocks:");
lists.full_blocks.dump();
}
}
#[cfg(test)]
mod tests {
use super::*;
use rand::Rng;
use rand_distr::Zipf;
struct TestObject {
val: usize,
_dummy: [u8; BLOCK_SIZE / 4],
}
struct TestObjectSlab<'a>(SlabDesc, BlockAllocator<'a>);
impl<'a> TestObjectSlab<'a> {
fn new(block_allocator: BlockAllocator) -> TestObjectSlab {
TestObjectSlab(SlabDesc::new(&Layout::new::<TestObject>()), block_allocator)
}
fn alloc(&self, val: usize) -> *mut TestObject {
let obj: *mut TestObject = self.0.alloc_chunk(&self.1).cast();
unsafe { (*obj).val = val };
obj
}
fn dealloc(&self, obj: *mut TestObject) {
self.0.dealloc_chunk(obj.cast(), &self.1)
}
}
#[test]
fn test_slab_alloc() {
const MEM_SIZE: usize = 100000000;
let mut area = Box::new_uninit_slice(MEM_SIZE);
let block_allocator = BlockAllocator::new(&mut area);
let slab = TestObjectSlab::new(block_allocator);
let mut all: Vec<*mut TestObject> = Vec::new();
for i in 0..11 {
all.push(slab.alloc(i));
}
for i in 0..11 {
assert!(unsafe { (*all[i]).val == i });
}
let distribution = Zipf::new(10 as f64, 1.1).unwrap();
let mut rng = rand::rng();
for _ in 0..100000 {
slab.0.dump();
let idx = (rng.sample(distribution) as usize).into();
let ptr: *mut TestObject = all[idx];
if !ptr.is_null() {
assert_eq!(unsafe { (*ptr).val }, idx);
slab.dealloc(ptr);
all[idx] = std::ptr::null_mut();
} else {
all[idx] = slab.alloc(idx);
}
}
}
fn new_test_blk(i: u32) -> *mut SlabBlockHeader {
Box::into_raw(Box::new(SlabBlockHeader {
free_chunks_head: spin::Mutex::new(std::ptr::null_mut()),
num_free_chunks: AtomicU32::new(0),
num_chunks: i,
prev: std::ptr::null_mut(),
next: std::ptr::null_mut(),
}))
}
#[test]
fn test_block_linked_list() {
// note: these are leaked, but that's OK for tests
let a = new_test_blk(0);
let b = new_test_blk(1);
let mut list = BlockList::default();
assert!(list.is_empty());
unsafe {
list.push_head(a);
assert!(!list.is_empty());
list.unlink(a);
}
assert!(list.is_empty());
unsafe {
list.push_head(b);
list.push_head(a);
assert_eq!(list.head, a);
assert_eq!((*a).next, b);
assert_eq!((*b).prev, a);
assert_eq!(list.tail, b);
list.unlink(a);
list.unlink(b);
assert!(list.is_empty());
}
}
}

View File

@@ -1,44 +0,0 @@
use std::mem::MaybeUninit;
pub fn alloc_from_slice<T>(
area: &mut [MaybeUninit<u8>],
) -> (&mut MaybeUninit<T>, &mut [MaybeUninit<u8>]) {
let layout = std::alloc::Layout::new::<T>();
let area_start = area.as_mut_ptr();
// pad to satisfy alignment requirements
let padding = area_start.align_offset(layout.align());
if padding + layout.size() > area.len() {
panic!("out of memory");
}
let area = &mut area[padding..];
let (result_area, remain) = area.split_at_mut(layout.size());
let result_ptr: *mut MaybeUninit<T> = result_area.as_mut_ptr().cast();
let result = unsafe { result_ptr.as_mut().unwrap() };
(result, remain)
}
pub fn alloc_array_from_slice<T>(
area: &mut [MaybeUninit<u8>],
len: usize,
) -> (&mut [MaybeUninit<T>], &mut [MaybeUninit<u8>]) {
let layout = std::alloc::Layout::new::<T>();
let area_start = area.as_mut_ptr();
// pad to satisfy alignment requirements
let padding = area_start.align_offset(layout.align());
if padding + layout.size() * len > area.len() {
panic!("out of memory");
}
let area = &mut area[padding..];
let (result_area, remain) = area.split_at_mut(layout.size() * len);
let result_ptr: *mut MaybeUninit<T> = result_area.as_mut_ptr().cast();
let result = unsafe { std::slice::from_raw_parts_mut(result_ptr.as_mut().unwrap(), len) };
(result, remain)
}

View File

@@ -1,147 +0,0 @@
//! This is similar to crossbeam_epoch crate, but works in shared memory
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use crossbeam_utils::CachePadded;
use spin;
const NUM_SLOTS: usize = 1000;
/// This is the struct that is stored in shmem
///
/// bit 0: is it pinned or not?
/// rest of the bits are the epoch counter.
pub struct EpochShared {
global_epoch: AtomicU64,
participants: [CachePadded<AtomicU64>; NUM_SLOTS],
broadcast_lock: spin::Mutex<()>,
}
impl EpochShared {
pub fn new() -> EpochShared {
EpochShared {
global_epoch: AtomicU64::new(2),
participants: [const { CachePadded::new(AtomicU64::new(2)) }; NUM_SLOTS],
broadcast_lock: spin::Mutex::new(()),
}
}
pub fn register(&self) -> LocalHandle {
LocalHandle {
global: self,
last_slot: AtomicUsize::new(0), // todo: choose more intelligently
}
}
fn release_pin(&self, slot: usize, _epoch: u64) {
let global_epoch = self.global_epoch.load(Ordering::Relaxed);
self.participants[slot].store(global_epoch, Ordering::Relaxed);
}
fn pin_internal(&self, slot_hint: usize) -> (usize, u64) {
// pick a slot
let mut slot = slot_hint;
let epoch = loop {
let old = self.participants[slot].fetch_or(1, Ordering::Relaxed);
if old & 1 == 0 {
// Got this slot
break old;
}
// the slot was busy by another thread / process. try a different slot
slot += 1;
if slot == NUM_SLOTS {
slot = 0;
}
continue;
};
(slot, epoch)
}
pub(crate) fn advance(&self) -> u64 {
// Advance the global epoch
let old_epoch = self.global_epoch.fetch_add(2, Ordering::Relaxed);
let new_epoch = old_epoch + 2;
// Anyone that release their pin after this will update their slot.
new_epoch
}
pub(crate) fn broadcast(&self) {
let Some(_guard) = self.broadcast_lock.try_lock() else {
return;
};
let epoch = self.global_epoch.load(Ordering::Relaxed);
let old_epoch = epoch.wrapping_sub(2);
// Update all free slots.
for i in 0..NUM_SLOTS {
// TODO: check result, as a sanity check. It should either be the old epoch, or pinned
let _ = self.participants[i].compare_exchange(
old_epoch,
epoch,
Ordering::Relaxed,
Ordering::Relaxed,
);
}
// FIXME: memory fence here, since we used Relaxed?
}
pub(crate) fn get_oldest(&self) -> u64 {
// Read all slots.
let now = self.global_epoch.load(Ordering::Relaxed);
let mut oldest = now;
for i in 0..NUM_SLOTS {
let this_epoch = self.participants[i].load(Ordering::Relaxed);
let delta = now.wrapping_sub(this_epoch);
if delta > u64::MAX / 2 {
// this is very recent
} else {
if delta > now.wrapping_sub(oldest) {
oldest = this_epoch;
}
}
}
oldest
}
pub(crate) fn get_current(&self) -> u64 {
self.global_epoch.load(Ordering::Relaxed)
}
}
pub(crate) struct EpochPin<'e> {
slot: usize,
pub(crate) epoch: u64,
handle: &'e LocalHandle<'e>,
}
impl<'e> Drop for EpochPin<'e> {
fn drop(&mut self) {
self.handle.global.release_pin(self.slot, self.epoch);
}
}
pub struct LocalHandle<'g> {
global: &'g EpochShared,
last_slot: AtomicUsize,
}
impl<'g> LocalHandle<'g> {
pub fn pin(&self) -> EpochPin {
let (slot, epoch) = self
.global
.pin_internal(self.last_slot.load(Ordering::Relaxed));
self.last_slot.store(slot, Ordering::Relaxed);
EpochPin {
handle: self,
epoch,
slot,
}
}
}

View File

@@ -1,587 +0,0 @@
//! Adaptive Radix Tree (ART) implementation, with Optimistic Lock Coupling.
//!
//! The data structure is described in these two papers:
//!
//! [1] Leis, V. & Kemper, Alfons & Neumann, Thomas. (2013).
//! The adaptive radix tree: ARTful indexing for main-memory databases.
//! Proceedings - International Conference on Data Engineering. 38-49. 10.1109/ICDE.2013.6544812.
//! https://db.in.tum.de/~leis/papers/ART.pdf
//!
//! [2] Leis, Viktor & Scheibner, Florian & Kemper, Alfons & Neumann, Thomas. (2016).
//! The ART of practical synchronization.
//! 1-8. 10.1145/2933349.2933352.
//! https://db.in.tum.de/~leis/papers/artsync.pdf
//!
//! [1] describes the base data structure, and [2] describes the Optimistic Lock Coupling that we
//! use.
//!
//! The papers mention a few different variants. We have made the following choices in this
//! implementation:
//!
//! - All keys have the same length
//!
//! - Single-value leaves.
//!
//! - For collapsing inner nodes, we use the Pessimistic approach, where each inner node stores a
//! variable length "prefix", which stores the keys of all the one-way nodes which have been
//! removed. However, similar to the "hybrid" approach described in the paper, each node only has
//! space for a constant-size prefix of 8 bytes. If a node would have a longer prefix, then we
//! create create one-way nodes to store them. (There was no particular reason for this choice,
//! the "hybrid" approach described in the paper might be better.)
//!
//! - For concurrency, we use Optimistic Lock Coupling. The paper [2] also describes another method,
//! ROWEX, which generally performs better when there is contention, but that is not important
//! for use and Optimisic Lock Coupling is simpler to implement.
//!
//! ## Requirements
//!
//! This data structure is currently used for the integrated LFC, relsize and last-written LSN cache
//! in the compute communicator, part of the 'neon' Postgres extension. We have some unique
//! requirements, which is why we had to write our own. Namely:
//!
//! - The data structure has to live in fixed-sized shared memory segment. That rules out any
//! built-in Rust collections and most crates. (Except possibly with the 'allocator_api' rust
//! feature, which still nightly-only experimental as of this writing).
//!
//! - The data structure is accessed from multiple processes. Only one process updates the data
//! structure, but other processes perform reads. That rules out using built-in Rust locking
//! primitives like Mutex and RwLock, and most crates too.
//!
//! - Within the one process with write-access, multiple threads can perform updates concurrently.
//! That rules out using PostgreSQL LWLocks for the locking.
//!
//! The implementation is generic, and doesn't depend on any PostgreSQL specifics, but it has been
//! written with that usage and the above constraints in mind. Some noteworthy assumptions:
//!
//! - Contention is assumed to be rare. In the integrated cache in PostgreSQL, there's higher level
//! locking in the PostgreSQL buffer manager, which ensures that two backends should not try to
//! read / write the same page at the same time. (Prefetching can conflict with actual reads,
//! however.)
//!
//! - The keys in the integrated cache are 17 bytes long.
//!
//! ## Usage
//!
//! Because this is designed to be used as a Postgres shared memory data structure, initialization
//! happens in three stages:
//!
//! 0. A fixed area of shared memory is allocated at postmaster startup.
//!
//! 1. TreeInitStruct::new() is called to initialize it, still in Postmaster process, before any
//! other process or thread is running. It returns a TreeInitStruct, which is inherited by all
//! the processes through fork().
//!
//! 2. One process may have write-access to the struct, by calling
//! [TreeInitStruct::attach_writer]. (That process is the communicator process.)
//!
//! 3. Other processes get read-access to the struct, by calling [TreeInitStruct::attach_reader]
//!
//! "Write access" means that you can insert / update / delete values in the tree.
//!
//! NOTE: The Values stored in the tree are sometimes moved, when a leaf node fills up and a new
//! larger node needs to be allocated. The versioning and epoch-based allocator ensure that the data
//! structure stays consistent, but if the Value has interior mutability, like atomic fields,
//! updates to such fields might be lost if the leaf node is concurrently moved! If that becomes a
//! problem, the version check could be passed up to the caller, so that the caller could detect the
//! lost updates and retry the operation.
//!
//! ## Implementation
//!
//! node_ptr: Provides low-level implementations of the four different node types (eight actually,
//! since there is an Internal and Leaf variant of each)
//!
//! lock_and_version.rs: Provides an abstraction for the combined lock and version counter on each
//! node.
//!
//! node_ref.rs: The code in node_ptr.rs deals with raw pointers. node_ref.rs provides more type-safe
//! abstractions on top.
//!
//! algorithm.rs: Contains the functions to implement lookups and updates in the tree
//!
//! allocator.rs: Provides a facility to allocate memory for the tree nodes. (We must provide our
//! own abstraction for that because we need the data structure to live in a pre-allocated shared
//! memory segment).
//!
//! epoch.rs: The data structure requires that when a node is removed from the tree, it is not
//! immediately deallocated, but stays around for as long as concurrent readers might still have
//! pointers to them. This is enforced by an epoch system. This is similar to
//! e.g. crossbeam_epoch, but we couldn't use that either because it has to work across processes
//! communicating over the shared memory segment.
//!
//! ## See also
//!
//! There are some existing Rust ART implementations out there, but none of them filled all
//! the requirements:
//!
//! - https://github.com/XiangpengHao/congee
//! - https://github.com/declanvk/blart
//!
//! ## TODO
//!
//! - Removing values has not been implemented
mod algorithm;
pub mod allocator;
mod epoch;
use algorithm::RootPtr;
use algorithm::node_ptr::NodePtr;
use std::collections::VecDeque;
use std::fmt::Debug;
use std::marker::PhantomData;
use std::ptr::NonNull;
use std::sync::atomic::{AtomicBool, Ordering};
use crate::epoch::EpochPin;
#[cfg(test)]
mod tests;
use allocator::ArtAllocator;
pub use allocator::ArtMultiSlabAllocator;
pub use allocator::OutOfMemoryError;
/// Fixed-length key type.
///
pub trait Key: Debug {
const KEY_LEN: usize;
fn as_bytes(&self) -> &[u8];
}
/// Values stored in the tree
///
/// Values need to be Cloneable, because when a node "grows", the value is copied to a new node and
/// the old sticks around until all readers that might see the old value are gone.
// fixme obsolete, no longer needs Clone
pub trait Value {}
const MAX_GARBAGE: usize = 1024;
/// The root of the tree, plus other tree-wide data. This is stored in the shared memory.
pub struct Tree<V: Value> {
/// For simplicity, so that we never need to grow or shrink the root, the root node is always an
/// Internal256 node. Also, it never has a prefix (that's actually a bit wasteful, incurring one
/// indirection to every lookup)
root: RootPtr<V>,
writer_attached: AtomicBool,
epoch: epoch::EpochShared,
}
unsafe impl<V: Value + Sync> Sync for Tree<V> {}
unsafe impl<V: Value + Send> Send for Tree<V> {}
struct GarbageQueue<V>(VecDeque<(NodePtr<V>, u64)>);
unsafe impl<V: Value + Sync> Sync for GarbageQueue<V> {}
unsafe impl<V: Value + Send> Send for GarbageQueue<V> {}
impl<V> GarbageQueue<V> {
fn new() -> GarbageQueue<V> {
GarbageQueue(VecDeque::with_capacity(MAX_GARBAGE))
}
fn remember_obsolete_node(&mut self, ptr: NodePtr<V>, epoch: u64) {
self.0.push_front((ptr, epoch));
}
fn next_obsolete(&mut self, cutoff_epoch: u64) -> Option<NodePtr<V>> {
if let Some(back) = self.0.back() {
if back.1 < cutoff_epoch {
return Some(self.0.pop_back().unwrap().0);
}
}
None
}
}
/// Struct created at postmaster startup
pub struct TreeInitStruct<'t, K: Key, V: Value, A: ArtAllocator<V>> {
tree: &'t Tree<V>,
allocator: &'t A,
phantom_key: PhantomData<K>,
}
/// The worker process has a reference to this. The write operations are only safe
/// from the worker process
pub struct TreeWriteAccess<'t, K: Key, V: Value, A: ArtAllocator<V>>
where
K: Key,
V: Value,
{
tree: &'t Tree<V>,
pub allocator: &'t A,
epoch_handle: epoch::LocalHandle<'t>,
phantom_key: PhantomData<K>,
/// Obsolete nodes that cannot be recycled until their epoch expires.
garbage: spin::Mutex<GarbageQueue<V>>,
}
/// The backends have a reference to this. It cannot be used to modify the tree
pub struct TreeReadAccess<'t, K: Key, V: Value>
where
K: Key,
V: Value,
{
tree: &'t Tree<V>,
epoch_handle: epoch::LocalHandle<'t>,
phantom_key: PhantomData<K>,
}
impl<'a, 't: 'a, K: Key, V: Value, A: ArtAllocator<V>> TreeInitStruct<'t, K, V, A> {
pub fn new(allocator: &'t A) -> TreeInitStruct<'t, K, V, A> {
let tree_ptr = allocator.alloc_tree();
let tree_ptr = NonNull::new(tree_ptr).expect("out of memory");
let init = Tree {
root: algorithm::new_root(allocator).expect("out of memory"),
writer_attached: AtomicBool::new(false),
epoch: epoch::EpochShared::new(),
};
unsafe { tree_ptr.write(init) };
TreeInitStruct {
tree: unsafe { tree_ptr.as_ref() },
allocator,
phantom_key: PhantomData,
}
}
pub fn attach_writer(self) -> TreeWriteAccess<'t, K, V, A> {
let previously_attached = self.tree.writer_attached.swap(true, Ordering::Relaxed);
if previously_attached {
panic!("writer already attached");
}
TreeWriteAccess {
tree: self.tree,
allocator: self.allocator,
phantom_key: PhantomData,
epoch_handle: self.tree.epoch.register(),
garbage: spin::Mutex::new(GarbageQueue::new()),
}
}
pub fn attach_reader(self) -> TreeReadAccess<'t, K, V> {
TreeReadAccess {
tree: self.tree,
phantom_key: PhantomData,
epoch_handle: self.tree.epoch.register(),
}
}
}
impl<'t, K: Key, V: Value, A: ArtAllocator<V>> TreeWriteAccess<'t, K, V, A> {
pub fn start_write<'g>(&'t self) -> TreeWriteGuard<'g, K, V, A>
where
't: 'g,
{
TreeWriteGuard {
tree_writer: self,
epoch_pin: self.epoch_handle.pin(),
phantom_key: PhantomData,
created_garbage: false,
}
}
pub fn start_read(&'t self) -> TreeReadGuard<'t, K, V> {
TreeReadGuard {
tree: &self.tree,
epoch_pin: self.epoch_handle.pin(),
phantom_key: PhantomData,
}
}
}
impl<'t, K: Key, V: Value> TreeReadAccess<'t, K, V> {
pub fn start_read(&'t self) -> TreeReadGuard<'t, K, V> {
TreeReadGuard {
tree: &self.tree,
epoch_pin: self.epoch_handle.pin(),
phantom_key: PhantomData,
}
}
}
pub struct TreeReadGuard<'e, K, V>
where
K: Key,
V: Value,
{
tree: &'e Tree<V>,
epoch_pin: EpochPin<'e>,
phantom_key: PhantomData<K>,
}
impl<'e, K: Key, V: Value> TreeReadGuard<'e, K, V> {
pub fn get(&'e self, key: &K) -> Option<&'e V> {
algorithm::search(key, self.tree.root, &self.epoch_pin)
}
}
pub struct TreeWriteGuard<'e, K, V, A>
where
K: Key,
V: Value,
A: ArtAllocator<V>,
{
tree_writer: &'e TreeWriteAccess<'e, K, V, A>,
epoch_pin: EpochPin<'e>,
phantom_key: PhantomData<K>,
created_garbage: bool,
}
pub enum UpdateAction<V> {
Nothing,
Insert(V),
Remove,
}
impl<'e, K: Key, V: Value, A: ArtAllocator<V>> TreeWriteGuard<'e, K, V, A> {
/// Get a value
pub fn get(&'e mut self, key: &K) -> Option<&'e V> {
algorithm::search(key, self.tree_writer.tree.root, &self.epoch_pin)
}
/// Insert a value
pub fn insert(self, key: &K, value: V) -> Result<bool, OutOfMemoryError> {
let mut success = None;
self.update_with_fn(key, |existing| {
if let Some(_) = existing {
success = Some(false);
UpdateAction::Nothing
} else {
success = Some(true);
UpdateAction::Insert(value)
}
})?;
Ok(success.expect("value_fn not called"))
}
/// Remove value. Returns true if it existed
pub fn remove(self, key: &K) -> bool {
let mut result = false;
// FIXME: It's not clear if OOM is expected while removing. It seems
// not nice, but shrinking a node can OOM. Then again, we could opt
// to not shrink a node if we cannot allocate, to live a little longer.
self.update_with_fn(key, |existing| match existing {
Some(_) => {
result = true;
UpdateAction::Remove
}
None => UpdateAction::Nothing,
})
.expect("out of memory while removing");
result
}
/// Try to remove value and return the old value.
pub fn remove_and_return(self, key: &K) -> Option<V>
where
V: Clone,
{
let mut old = None;
self.update_with_fn(key, |existing| {
old = existing.cloned();
UpdateAction::Remove
})
.expect("out of memory while removing");
old
}
/// Update key using the given function. All the other modifying operations are based on this.
///
/// The function is passed a reference to the existing value, if any. If the function
/// returns None, the value is removed from the tree (or if there was no existing value,
/// does nothing). If the function returns Some, the existing value is replaced, of if there
/// was no existing value, it is inserted. FIXME: update comment
pub fn update_with_fn<F>(mut self, key: &K, value_fn: F) -> Result<(), OutOfMemoryError>
where
F: FnOnce(Option<&V>) -> UpdateAction<V>,
{
algorithm::update_fn(key, value_fn, self.tree_writer.tree.root, &mut self)?;
if self.created_garbage {
let _ = self.collect_garbage();
}
Ok(())
}
fn remember_obsolete_node(&mut self, ptr: NodePtr<V>) {
self.tree_writer
.garbage
.lock()
.remember_obsolete_node(ptr, self.epoch_pin.epoch);
self.created_garbage = true;
}
// returns number of nodes recycled
fn collect_garbage(&self) -> usize {
self.tree_writer.tree.epoch.advance();
self.tree_writer.tree.epoch.broadcast();
let cutoff_epoch = self.tree_writer.tree.epoch.get_oldest();
let mut result = 0;
let mut garbage_queue = self.tree_writer.garbage.lock();
while let Some(ptr) = garbage_queue.next_obsolete(cutoff_epoch) {
ptr.deallocate(self.tree_writer.allocator);
result += 1;
}
result
}
}
pub struct TreeIterator<K>
where
K: Key + for<'a> From<&'a [u8]>,
{
done: bool,
pub next_key: Vec<u8>,
max_key: Option<Vec<u8>>,
phantom_key: PhantomData<K>,
}
impl<K> TreeIterator<K>
where
K: Key + for<'a> From<&'a [u8]>,
{
pub fn new_wrapping() -> TreeIterator<K> {
let mut next_key = Vec::new();
next_key.resize(K::KEY_LEN, 0);
TreeIterator {
done: false,
next_key,
max_key: None,
phantom_key: PhantomData,
}
}
pub fn new(range: &std::ops::Range<K>) -> TreeIterator<K> {
let result = TreeIterator {
done: false,
next_key: Vec::from(range.start.as_bytes()),
max_key: Some(Vec::from(range.end.as_bytes())),
phantom_key: PhantomData,
};
assert_eq!(result.next_key.len(), K::KEY_LEN);
assert_eq!(result.max_key.as_ref().unwrap().len(), K::KEY_LEN);
result
}
pub fn next<'g, V>(&mut self, read_guard: &'g TreeReadGuard<'g, K, V>) -> Option<(K, &'g V)>
where
V: Value,
{
if self.done {
return None;
}
let mut wrapped_around = false;
loop {
assert_eq!(self.next_key.len(), K::KEY_LEN);
if let Some((k, v)) = algorithm::iter_next(
&mut self.next_key,
read_guard.tree.root,
&read_guard.epoch_pin,
) {
assert_eq!(k.len(), K::KEY_LEN);
assert_eq!(self.next_key.len(), K::KEY_LEN);
// Check if we reached the end of the range
if let Some(max_key) = &self.max_key {
if k.as_slice() >= max_key.as_slice() {
self.done = true;
break None;
}
}
// increment the key
self.next_key = k.clone();
increment_key(self.next_key.as_mut_slice());
let k = k.as_slice().into();
break Some((k, v));
} else {
if self.max_key.is_some() {
self.done = true;
} else {
// Start from beginning
if !wrapped_around {
for i in 0..K::KEY_LEN {
self.next_key[i] = 0;
}
wrapped_around = true;
continue;
} else {
// The tree is completely empty
// FIXME: perhaps we should remember the starting point instead.
// Currently this will scan some ranges twice.
break None;
}
}
break None;
}
}
}
}
fn increment_key(key: &mut [u8]) -> bool {
for i in (0..key.len()).rev() {
let (byte, overflow) = key[i].overflowing_add(1);
key[i] = byte;
if !overflow {
return false;
}
}
true
}
// Debugging functions
impl<'e, K: Key, V: Value + Debug, A: ArtAllocator<V>> TreeWriteGuard<'e, K, V, A> {
pub fn dump(&mut self, dst: &mut dyn std::io::Write) {
algorithm::dump_tree(self.tree_writer.tree.root, &self.epoch_pin, dst)
}
}
impl<'e, K: Key, V: Value + Debug> TreeReadGuard<'e, K, V> {
pub fn dump(&mut self, dst: &mut dyn std::io::Write) {
algorithm::dump_tree(self.tree.root, &self.epoch_pin, dst)
}
}
impl<'e, K: Key, V: Value> TreeWriteAccess<'e, K, V, ArtMultiSlabAllocator<'e, V>> {
pub fn get_statistics(&self) -> ArtTreeStatistics {
self.allocator.get_statistics();
ArtTreeStatistics {
blocks: self.allocator.inner.block_allocator.get_statistics(),
slabs: self.allocator.get_statistics(),
epoch: self.tree.epoch.get_current(),
oldest_epoch: self.tree.epoch.get_oldest(),
num_garbage: self.garbage.lock().0.len() as u64,
}
}
}
#[derive(Clone, Debug)]
pub struct ArtTreeStatistics {
pub blocks: allocator::block::BlockAllocatorStats,
pub slabs: allocator::ArtMultiSlabStats,
pub epoch: u64,
pub oldest_epoch: u64,
pub num_garbage: u64,
}

View File

@@ -1,243 +0,0 @@
use std::collections::BTreeMap;
use std::collections::HashSet;
use std::fmt::{Debug, Formatter};
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::ArtAllocator;
use crate::ArtMultiSlabAllocator;
use crate::TreeInitStruct;
use crate::TreeIterator;
use crate::TreeWriteAccess;
use crate::UpdateAction;
use crate::{Key, Value};
use rand::Rng;
use rand::seq::SliceRandom;
use rand_distr::Zipf;
const TEST_KEY_LEN: usize = 16;
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct TestKey([u8; TEST_KEY_LEN]);
impl TestKey {
const MIN: TestKey = TestKey([0; TEST_KEY_LEN]);
const MAX: TestKey = TestKey([u8::MAX; TEST_KEY_LEN]);
}
impl Key for TestKey {
const KEY_LEN: usize = TEST_KEY_LEN;
fn as_bytes(&self) -> &[u8] {
&self.0
}
}
impl From<&TestKey> for u128 {
fn from(val: &TestKey) -> u128 {
u128::from_be_bytes(val.0)
}
}
impl From<u128> for TestKey {
fn from(val: u128) -> TestKey {
TestKey(val.to_be_bytes())
}
}
impl<'a> From<&'a [u8]> for TestKey {
fn from(bytes: &'a [u8]) -> TestKey {
TestKey(bytes.try_into().unwrap())
}
}
impl Value for usize {}
fn test_inserts<K: Into<TestKey> + Copy>(keys: &[K]) {
const MEM_SIZE: usize = 10000000;
let mut area = Box::new_uninit_slice(MEM_SIZE);
let allocator = ArtMultiSlabAllocator::new(&mut area);
let init_struct = TreeInitStruct::<TestKey, usize, _>::new(allocator);
let tree_writer = init_struct.attach_writer();
for (idx, k) in keys.iter().enumerate() {
let w = tree_writer.start_write();
let res = w.insert(&(*k).into(), idx);
assert!(res.is_ok());
}
for (idx, k) in keys.iter().enumerate() {
let r = tree_writer.start_read();
let value = r.get(&(*k).into());
assert_eq!(value, Some(idx).as_ref());
}
eprintln!("stats: {:?}", tree_writer.get_statistics());
}
#[test]
fn dense() {
// This exercises splitting a node with prefix
let keys: &[u128] = &[0, 1, 2, 3, 256];
test_inserts(keys);
// Dense keys
let mut keys: Vec<u128> = (0..10000).collect();
test_inserts(&keys);
// Do the same in random orders
for _ in 1..10 {
keys.shuffle(&mut rand::rng());
test_inserts(&keys);
}
}
#[test]
fn sparse() {
// sparse keys
let mut keys: Vec<TestKey> = Vec::new();
let mut used_keys = HashSet::new();
for _ in 0..10000 {
loop {
let key = rand::random::<u128>();
if used_keys.get(&key).is_some() {
continue;
}
used_keys.insert(key);
keys.push(key.into());
break;
}
}
test_inserts(&keys);
}
struct TestValue(AtomicUsize);
impl TestValue {
fn new(val: usize) -> TestValue {
TestValue(AtomicUsize::new(val))
}
fn load(&self) -> usize {
self.0.load(Ordering::Relaxed)
}
}
impl Value for TestValue {}
impl Clone for TestValue {
fn clone(&self) -> TestValue {
TestValue::new(self.load())
}
}
impl Debug for TestValue {
fn fmt(&self, fmt: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
write!(fmt, "{:?}", self.load())
}
}
#[derive(Clone, Debug)]
struct TestOp(TestKey, Option<usize>);
fn apply_op<A: ArtAllocator<TestValue>>(
op: &TestOp,
tree: &TreeWriteAccess<TestKey, TestValue, A>,
shadow: &mut BTreeMap<TestKey, usize>,
) {
eprintln!("applying op: {op:?}");
// apply the change to the shadow tree first
let shadow_existing = if let Some(v) = op.1 {
shadow.insert(op.0, v)
} else {
shadow.remove(&op.0)
};
// apply to Art tree
let w = tree.start_write();
w.update_with_fn(&op.0, |existing| {
assert_eq!(existing.map(TestValue::load), shadow_existing);
match (existing, op.1) {
(None, None) => UpdateAction::Nothing,
(None, Some(new_val)) => UpdateAction::Insert(TestValue::new(new_val)),
(Some(_old_val), None) => UpdateAction::Remove,
(Some(old_val), Some(new_val)) => {
old_val.0.store(new_val, Ordering::Relaxed);
UpdateAction::Nothing
}
}
})
.expect("out of memory");
}
fn test_iter<A: ArtAllocator<TestValue>>(
tree: &TreeWriteAccess<TestKey, TestValue, A>,
shadow: &BTreeMap<TestKey, usize>,
) {
let mut shadow_iter = shadow.iter();
let mut iter = TreeIterator::new(&(TestKey::MIN..TestKey::MAX));
loop {
let shadow_item = shadow_iter.next().map(|(k, v)| (k.clone(), v.clone()));
let r = tree.start_read();
let item = iter.next(&r);
if shadow_item != item.map(|(k, v)| (k, v.load())) {
eprintln!(
"FAIL: iterator returned {:?}, expected {:?}",
item, shadow_item
);
tree.start_read().dump(&mut std::io::stderr());
eprintln!("SHADOW:");
let mut si = shadow.iter();
while let Some(si) = si.next() {
eprintln!("key: {:?}, val: {}", si.0, si.1);
}
panic!(
"FAIL: iterator returned {:?}, expected {:?}",
item, shadow_item
);
}
if item.is_none() {
break;
}
}
}
#[test]
fn random_ops() {
const MEM_SIZE: usize = 10000000;
let mut area = Box::new_uninit_slice(MEM_SIZE);
let allocator = ArtMultiSlabAllocator::new(&mut area);
let init_struct = TreeInitStruct::<TestKey, TestValue, _>::new(allocator);
let tree_writer = init_struct.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let distribution = Zipf::new(u128::MAX as f64, 1.1).unwrap();
let mut rng = rand::rng();
for i in 0..100000 {
let mut key: TestKey = (rng.sample(distribution) as u128).into();
if rng.random_bool(0.10) {
key = TestKey::from(u128::from(&key) | 0xffffffff);
}
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &tree_writer, &mut shadow);
if i % 1000 == 0 {
eprintln!("{i} ops processed");
eprintln!("stats: {:?}", tree_writer.get_statistics());
test_iter(&tree_writer, &shadow);
}
}
}

View File

@@ -34,8 +34,6 @@ pub struct NodeMetadata {
pub postgres_host: String,
#[serde(rename = "port")]
pub postgres_port: u16,
pub grpc_host: Option<String>,
pub grpc_port: Option<u16>,
pub http_host: String,
pub http_port: u16,
pub https_port: Option<u16>,

View File

@@ -14,8 +14,6 @@ fn test_node_metadata_v1_backward_compatibilty() {
NodeMetadata {
postgres_host: "localhost".to_string(),
postgres_port: 23,
grpc_host: None,
grpc_port: None,
http_host: "localhost".to_string(),
http_port: 42,
https_port: None,
@@ -39,35 +37,6 @@ fn test_node_metadata_v2_backward_compatibilty() {
NodeMetadata {
postgres_host: "localhost".to_string(),
postgres_port: 23,
grpc_host: None,
grpc_port: None,
http_host: "localhost".to_string(),
http_port: 42,
https_port: Some(123),
other: HashMap::new(),
}
)
}
#[test]
fn test_node_metadata_v3_backward_compatibilty() {
let v3 = serde_json::to_vec(&serde_json::json!({
"host": "localhost",
"port": 23,
"grpc_host": "localhost",
"grpc_port": 51,
"http_host": "localhost",
"http_port": 42,
"https_port": 123,
}));
assert_eq!(
serde_json::from_slice::<NodeMetadata>(&v3.unwrap()).unwrap(),
NodeMetadata {
postgres_host: "localhost".to_string(),
postgres_port: 23,
grpc_host: Some("localhost".to_string()),
grpc_port: Some(51),
http_host: "localhost".to_string(),
http_port: 42,
https_port: Some(123),

View File

@@ -53,9 +53,6 @@ pub struct NodeRegisterRequest {
pub listen_pg_addr: String,
pub listen_pg_port: u16,
pub listen_grpc_addr: Option<String>,
pub listen_grpc_port: Option<u16>,
pub listen_http_addr: String,
pub listen_http_port: u16,
pub listen_https_port: Option<u16>,
@@ -105,9 +102,6 @@ pub struct TenantLocateResponseShard {
pub listen_pg_addr: String,
pub listen_pg_port: u16,
pub listen_grpc_addr: Option<String>,
pub listen_grpc_port: Option<u16>,
pub listen_http_addr: String,
pub listen_http_port: u16,
pub listen_https_port: Option<u16>,
@@ -158,8 +152,6 @@ pub struct NodeDescribeResponse {
pub listen_pg_addr: String,
pub listen_pg_port: u16,
pub listen_grpc_addr: Option<String>,
pub listen_grpc_port: Option<u16>,
}
#[derive(Serialize, Deserialize, Debug)]
@@ -352,6 +344,35 @@ impl Default for ShardSchedulingPolicy {
}
}
#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug)]
pub enum NodeLifecycle {
Active,
Deleted,
}
impl FromStr for NodeLifecycle {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"active" => Ok(Self::Active),
"deleted" => Ok(Self::Deleted),
_ => Err(anyhow::anyhow!("Unknown node lifecycle '{s}'")),
}
}
}
impl From<NodeLifecycle> for String {
fn from(value: NodeLifecycle) -> String {
use NodeLifecycle::*;
match value {
Active => "active",
Deleted => "deleted",
}
.to_string()
}
}
#[derive(Serialize, Deserialize, Clone, Copy, Eq, PartialEq, Debug)]
pub enum NodeSchedulingPolicy {
Active,

View File

@@ -1,32 +1,19 @@
use std::io;
use tokio::net::TcpStream;
use crate::client::SocketConfig;
use crate::config::{Host, SslMode};
use crate::config::Host;
use crate::tls::MakeTlsConnect;
use crate::{Error, cancel_query_raw, connect_socket};
use crate::{Error, cancel_query_raw, connect_socket, connect_tls};
pub(crate) async fn cancel_query<T>(
config: Option<SocketConfig>,
ssl_mode: SslMode,
mut tls: T,
config: SocketConfig,
tls: T,
process_id: i32,
secret_key: i32,
) -> Result<(), Error>
where
T: MakeTlsConnect<TcpStream>,
{
let config = match config {
Some(config) => config,
None => {
return Err(Error::connect(io::Error::new(
io::ErrorKind::InvalidInput,
"unknown host",
)));
}
};
let hostname = match &config.host {
Host::Tcp(host) => &**host,
};
@@ -42,5 +29,6 @@ where
)
.await?;
cancel_query_raw::cancel_query_raw(socket, ssl_mode, tls, process_id, secret_key).await
let stream = connect_tls::connect_tls(socket, config.ssl_mode, tls).await?;
cancel_query_raw::cancel_query_raw(stream, process_id, secret_key).await
}

View File

@@ -2,23 +2,16 @@ use bytes::BytesMut;
use postgres_protocol2::message::frontend;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use crate::config::SslMode;
use crate::tls::TlsConnect;
use crate::{Error, connect_tls};
use crate::Error;
pub async fn cancel_query_raw<S, T>(
stream: S,
mode: SslMode,
tls: T,
pub async fn cancel_query_raw<S>(
mut stream: S,
process_id: i32,
secret_key: i32,
) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
let mut stream = connect_tls::connect_tls(stream, mode, tls).await?;
let mut buf = BytesMut::new();
frontend::cancel_request(process_id, secret_key, &mut buf);

View File

@@ -3,16 +3,21 @@ use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use crate::client::SocketConfig;
use crate::config::SslMode;
use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::tls::MakeTlsConnect;
use crate::{Error, cancel_query, cancel_query_raw};
/// The capability to request cancellation of in-progress queries on a
/// connection.
#[derive(Clone, Serialize, Deserialize)]
#[derive(Clone)]
pub struct CancelToken {
pub socket_config: Option<SocketConfig>,
pub ssl_mode: SslMode,
pub socket_config: SocketConfig,
pub raw: RawCancelToken,
}
/// The capability to request cancellation of in-progress queries on a
/// connection.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RawCancelToken {
pub process_id: i32,
pub secret_key: i32,
}
@@ -36,28 +41,21 @@ impl CancelToken {
{
cancel_query::cancel_query(
self.socket_config.clone(),
self.ssl_mode,
tls,
self.process_id,
self.secret_key,
)
.await
}
/// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new
/// connection itself.
pub async fn cancel_query_raw<S, T>(&self, stream: S, tls: T) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
T: TlsConnect<S>,
{
cancel_query_raw::cancel_query_raw(
stream,
self.ssl_mode,
tls,
self.process_id,
self.secret_key,
self.raw.process_id,
self.raw.secret_key,
)
.await
}
}
impl RawCancelToken {
/// Like `cancel_query`, but uses a stream which is already connected to the server rather than opening a new
/// connection itself.
pub async fn cancel_query_raw<S>(&self, stream: S) -> Result<(), Error>
where
S: AsyncRead + AsyncWrite + Unpin,
{
cancel_query_raw::cancel_query_raw(stream, self.process_id, self.secret_key).await
}
}

View File

@@ -12,6 +12,7 @@ use postgres_protocol2::message::frontend;
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use crate::cancel_token::RawCancelToken;
use crate::codec::{BackendMessages, FrontendMessage};
use crate::config::{Host, SslMode};
use crate::query::RowStream;
@@ -166,6 +167,7 @@ pub struct SocketConfig {
pub host: Host,
pub port: u16,
pub connect_timeout: Option<Duration>,
pub ssl_mode: SslMode,
}
/// An asynchronous PostgreSQL client.
@@ -177,7 +179,6 @@ pub struct Client {
cached_typeinfo: CachedTypeInfo,
socket_config: SocketConfig,
ssl_mode: SslMode,
process_id: i32,
secret_key: i32,
}
@@ -187,7 +188,6 @@ impl Client {
sender: mpsc::UnboundedSender<FrontendMessage>,
receiver: mpsc::Receiver<BackendMessages>,
socket_config: SocketConfig,
ssl_mode: SslMode,
process_id: i32,
secret_key: i32,
) -> Client {
@@ -205,7 +205,6 @@ impl Client {
cached_typeinfo: Default::default(),
socket_config,
ssl_mode,
process_id,
secret_key,
}
@@ -331,10 +330,11 @@ impl Client {
/// connection associated with this client.
pub fn cancel_token(&self) -> CancelToken {
CancelToken {
socket_config: Some(self.socket_config.clone()),
ssl_mode: self.ssl_mode,
process_id: self.process_id,
secret_key: self.secret_key,
socket_config: self.socket_config.clone(),
raw: RawCancelToken {
process_id: self.process_id,
secret_key: self.secret_key,
},
}
}

View File

@@ -17,7 +17,6 @@ use crate::{Client, Connection, Error};
/// TLS configuration.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[non_exhaustive]
pub enum SslMode {
/// Do not use TLS.
Disable,
@@ -231,7 +230,7 @@ impl Config {
/// Requires the `runtime` Cargo feature (enabled by default).
pub async fn connect<T>(
&self,
tls: T,
tls: &T,
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
where
T: MakeTlsConnect<TcpStream>,

View File

@@ -13,7 +13,7 @@ use crate::tls::{MakeTlsConnect, TlsConnect};
use crate::{Client, Config, Connection, Error, RawConnection};
pub async fn connect<T>(
mut tls: T,
tls: &T,
config: &Config,
) -> Result<(Client, Connection<TcpStream, T::Stream>), Error>
where
@@ -57,6 +57,7 @@ where
host: host.clone(),
port,
connect_timeout: config.connect_timeout,
ssl_mode: config.ssl_mode,
};
let (client_tx, conn_rx) = mpsc::unbounded_channel();
@@ -65,7 +66,6 @@ where
client_tx,
client_rx,
socket_config,
config.ssl_mode,
process_id,
secret_key,
);

View File

@@ -3,7 +3,7 @@
use postgres_protocol2::message::backend::ReadyForQueryBody;
pub use crate::cancel_token::CancelToken;
pub use crate::cancel_token::{CancelToken, RawCancelToken};
pub use crate::client::{Client, SocketConfig};
pub use crate::config::Config;
pub use crate::connect_raw::RawConnection;

View File

@@ -47,7 +47,7 @@ pub trait MakeTlsConnect<S> {
/// Creates a new `TlsConnect`or.
///
/// The domain name is provided for certificate verification and SNI.
fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
fn make_tls_connect(&self, domain: &str) -> Result<Self::TlsConnect, Self::Error>;
}
/// An asynchronous function wrapping a stream in a TLS session.
@@ -85,7 +85,7 @@ impl<S> MakeTlsConnect<S> for NoTls {
type TlsConnect = NoTls;
type Error = NoTlsError;
fn make_tls_connect(&mut self, _: &str) -> Result<NoTls, NoTlsError> {
fn make_tls_connect(&self, _: &str) -> Result<NoTls, NoTlsError> {
Ok(NoTls)
}
}

View File

@@ -10,7 +10,7 @@ use std::sync::Arc;
use std::time::{Duration, SystemTime};
use std::{env, io};
use anyhow::{Context, Result};
use anyhow::{Context, Result, anyhow};
use azure_core::request_options::{IfMatchCondition, MaxResults, Metadata, Range};
use azure_core::{Continuable, HttpClient, RetryOptions, TransportOptions};
use azure_storage::StorageCredentials;
@@ -37,6 +37,7 @@ use crate::metrics::{AttemptOutcome, RequestKind, start_measuring_requests};
use crate::{
ConcurrencyLimiter, Download, DownloadError, DownloadKind, DownloadOpts, Listing, ListingMode,
ListingObject, RemotePath, RemoteStorage, StorageMetadata, TimeTravelError, TimeoutOrCancel,
Version, VersionKind,
};
pub struct AzureBlobStorage {
@@ -405,6 +406,39 @@ impl AzureBlobStorage {
pub fn container_name(&self) -> &str {
&self.container_name
}
async fn list_versions_with_permit(
&self,
_permit: &tokio::sync::SemaphorePermit<'_>,
prefix: Option<&RemotePath>,
mode: ListingMode,
max_keys: Option<NonZeroU32>,
cancel: &CancellationToken,
) -> Result<crate::VersionListing, DownloadError> {
let customize_builder = |mut builder: ListBlobsBuilder| {
builder = builder.include_versions(true);
// We do not return this info back to `VersionListing` yet.
builder = builder.include_deleted(true);
builder
};
let kind = RequestKind::ListVersions;
let mut stream = std::pin::pin!(self.list_streaming_for_fn(
prefix,
mode,
max_keys,
cancel,
kind,
customize_builder
));
let mut combined: crate::VersionListing =
stream.next().await.expect("At least one item required")?;
while let Some(list) = stream.next().await {
let list = list?;
combined.versions.extend(list.versions.into_iter());
}
Ok(combined)
}
}
trait ListingCollector {
@@ -488,27 +522,10 @@ impl RemoteStorage for AzureBlobStorage {
max_keys: Option<NonZeroU32>,
cancel: &CancellationToken,
) -> std::result::Result<crate::VersionListing, DownloadError> {
let customize_builder = |mut builder: ListBlobsBuilder| {
builder = builder.include_versions(true);
builder
};
let kind = RequestKind::ListVersions;
let mut stream = std::pin::pin!(self.list_streaming_for_fn(
prefix,
mode,
max_keys,
cancel,
kind,
customize_builder
));
let mut combined: crate::VersionListing =
stream.next().await.expect("At least one item required")?;
while let Some(list) = stream.next().await {
let list = list?;
combined.versions.extend(list.versions.into_iter());
}
Ok(combined)
let permit = self.permit(kind, cancel).await?;
self.list_versions_with_permit(&permit, prefix, mode, max_keys, cancel)
.await
}
async fn head_object(
@@ -803,14 +820,158 @@ impl RemoteStorage for AzureBlobStorage {
async fn time_travel_recover(
&self,
_prefix: Option<&RemotePath>,
_timestamp: SystemTime,
_done_if_after: SystemTime,
_cancel: &CancellationToken,
prefix: Option<&RemotePath>,
timestamp: SystemTime,
done_if_after: SystemTime,
cancel: &CancellationToken,
) -> Result<(), TimeTravelError> {
// TODO use Azure point in time recovery feature for this
// https://learn.microsoft.com/en-us/azure/storage/blobs/point-in-time-restore-overview
Err(TimeTravelError::Unimplemented)
let msg = "PLEASE NOTE: Azure Blob storage time-travel recovery may not work as expected "
.to_string()
+ "for some specific files. If a file gets deleted but then overwritten and we want to recover "
+ "to the time during the file was not present, this functionality will recover the file. Only "
+ "use the functionality for services that can tolerate this. For example, recovering a state of the "
+ "pageserver tenants.";
tracing::error!("{}", msg);
let kind = RequestKind::TimeTravel;
let permit = self.permit(kind, cancel).await?;
let mode = ListingMode::NoDelimiter;
let version_listing = self
.list_versions_with_permit(&permit, prefix, mode, None, cancel)
.await
.map_err(|err| match err {
DownloadError::Other(e) => TimeTravelError::Other(e),
DownloadError::Cancelled => TimeTravelError::Cancelled,
other => TimeTravelError::Other(other.into()),
})?;
let versions_and_deletes = version_listing.versions;
tracing::info!(
"Built list for time travel with {} versions and deletions",
versions_and_deletes.len()
);
// Work on the list of references instead of the objects directly,
// otherwise we get lifetime errors in the sort_by_key call below.
let mut versions_and_deletes = versions_and_deletes.iter().collect::<Vec<_>>();
versions_and_deletes.sort_by_key(|vd| (&vd.key, &vd.last_modified));
let mut vds_for_key = HashMap::<_, Vec<_>>::new();
for vd in &versions_and_deletes {
let Version { key, .. } = &vd;
let version_id = vd.version_id().map(|v| v.0.as_str());
if version_id == Some("null") {
return Err(TimeTravelError::Other(anyhow!(
"Received ListVersions response for key={key} with version_id='null', \
indicating either disabled versioning, or legacy objects with null version id values"
)));
}
tracing::trace!("Parsing version key={key} kind={:?}", vd.kind);
vds_for_key.entry(key).or_default().push(vd);
}
let warn_threshold = 3;
let max_retries = 10;
let is_permanent = |e: &_| matches!(e, TimeTravelError::Cancelled);
for (key, versions) in vds_for_key {
let last_vd = versions.last().unwrap();
let key = self.relative_path_to_name(key);
if last_vd.last_modified > done_if_after {
tracing::debug!("Key {key} has version later than done_if_after, skipping");
continue;
}
// the version we want to restore to.
let version_to_restore_to =
match versions.binary_search_by_key(&timestamp, |tpl| tpl.last_modified) {
Ok(v) => v,
Err(e) => e,
};
if version_to_restore_to == versions.len() {
tracing::debug!("Key {key} has no changes since timestamp, skipping");
continue;
}
let mut do_delete = false;
if version_to_restore_to == 0 {
// All versions more recent, so the key didn't exist at the specified time point.
tracing::debug!(
"All {} versions more recent for {key}, deleting",
versions.len()
);
do_delete = true;
} else {
match &versions[version_to_restore_to - 1] {
Version {
kind: VersionKind::Version(version_id),
..
} => {
let source_url = format!(
"{}/{}?versionid={}",
self.client
.url()
.map_err(|e| TimeTravelError::Other(anyhow!("{e}")))?,
key,
version_id.0
);
tracing::debug!(
"Promoting old version {} for {key} at {}...",
version_id.0,
source_url
);
backoff::retry(
|| async {
let blob_client = self.client.blob_client(key.clone());
let op = blob_client.copy(Url::from_str(&source_url).unwrap());
tokio::select! {
res = op => res.map_err(|e| TimeTravelError::Other(e.into())),
_ = cancel.cancelled() => Err(TimeTravelError::Cancelled),
}
},
is_permanent,
warn_threshold,
max_retries,
"copying object version for time_travel_recover",
cancel,
)
.await
.ok_or_else(|| TimeTravelError::Cancelled)
.and_then(|x| x)?;
tracing::info!(?version_id, %key, "Copied old version in Azure blob storage");
}
Version {
kind: VersionKind::DeletionMarker,
..
} => {
do_delete = true;
}
}
};
if do_delete {
if matches!(last_vd.kind, VersionKind::DeletionMarker) {
// Key has since been deleted (but there was some history), no need to do anything
tracing::debug!("Key {key} already deleted, skipping.");
} else {
tracing::debug!("Deleting {key}...");
self.delete(&RemotePath::from_string(&key).unwrap(), cancel)
.await
.map_err(|e| {
// delete_oid0 will use TimeoutOrCancel
if TimeoutOrCancel::caused_by_cancel(&e) {
TimeTravelError::Cancelled
} else {
TimeTravelError::Other(e)
}
})?;
}
}
}
Ok(())
}
}

View File

@@ -1022,6 +1022,7 @@ impl RemoteStorage for S3Bucket {
let Version { key, .. } = &vd;
let version_id = vd.version_id().map(|v| v.0.as_str());
if version_id == Some("null") {
// TODO: check the behavior of using the SDK on a non-versioned container
return Err(TimeTravelError::Other(anyhow!(
"Received ListVersions response for key={key} with version_id='null', \
indicating either disabled versioning, or legacy objects with null version id values"

View File

@@ -13,7 +13,7 @@ use utils::pageserver_feedback::PageserverFeedback;
use crate::membership::Configuration;
use crate::{ServerInfo, Term};
#[derive(Debug, Serialize)]
#[derive(Debug, Serialize, Deserialize)]
pub struct SafekeeperStatus {
pub id: NodeId,
}

View File

@@ -51,7 +51,6 @@ pageserver_api.workspace = true
pageserver_client.workspace = true # for ResponseErrorMessageExt TOOD refactor that
pageserver_compaction.workspace = true
pageserver_page_api.workspace = true
peekable.workspace = true
pem.workspace = true
pin-project-lite.workspace = true
postgres_backend.workspace = true
@@ -63,7 +62,6 @@ postgres-types.workspace = true
posthog_client_lite.workspace = true
pprof.workspace = true
pq_proto.workspace = true
prost.workspace = true
rand.workspace = true
range-set-blaze = { version = "0.1.16", features = ["alloc"] }
regex.workspace = true

View File

@@ -1,30 +0,0 @@
[package]
name = "pageserver_client_grpc"
version = "0.1.0"
edition = "2024"
[dependencies]
bytes.workspace = true
futures.workspace = true
http.workspace = true
thiserror.workspace = true
tonic.workspace = true
tracing.workspace = true
tokio = { version = "1.43.1", features = ["full", "macros", "net", "io-util", "rt", "rt-multi-thread"] }
uuid = { version = "1", features = ["v4"] }
tower = { version = "0.4", features = ["timeout", "util"] }
rand = "0.8"
tokio-util = { version = "0.7", features = ["compat"] }
hyper-util = "0.1.9"
hyper = "1.6.0"
metrics.workspace = true
priority-queue = "2.3.1"
async-trait = { version = "0.1" }
tokio-stream = "0.1"
dashmap = "5"
chrono = { version = "0.4", features = ["serde"] }
pageserver_page_api.workspace = true
pageserver_api.workspace = true
utils.workspace = true

View File

@@ -1,296 +0,0 @@
// examples/load_test.rs, generated by AI
use std::collections::{HashMap, HashSet};
use std::sync::{
Arc,
Mutex,
atomic::{AtomicU64, AtomicUsize, Ordering},
};
use std::time::{Duration, Instant};
use tokio::task;
use tokio::time::sleep;
use rand::Rng;
use tonic::Status;
use uuid::Uuid;
// Pull in your ConnectionPool and PooledItemFactory from the pageserver_client_grpc crate.
// Adjust these paths if necessary.
use pageserver_client_grpc::client_cache::ConnectionPool;
use pageserver_client_grpc::client_cache::PooledItemFactory;
// --------------------------------------
// GLOBAL COUNTERS FOR “CREATED” / “DROPPED” MockConnections
// --------------------------------------
static CREATED: AtomicU64 = AtomicU64::new(0);
static DROPPED: AtomicU64 = AtomicU64::new(0);
// --------------------------------------
// MockConnection + Factory
// --------------------------------------
#[derive(Debug)]
pub struct MockConnection {
pub id: u64,
}
impl Clone for MockConnection {
fn clone(&self) -> Self {
// Cloning a MockConnection does NOT count as “creating” a brandnew connection,
// so we do NOT bump CREATED here. We only bump CREATED in the factorys `create()`.
CREATED.fetch_add(1, Ordering::Relaxed);
MockConnection { id: self.id }
}
}
impl Drop for MockConnection {
fn drop(&mut self) {
// When a MockConnection actually gets dropped, bump the counter.
DROPPED.fetch_add(1, Ordering::SeqCst);
}
}
pub struct MockConnectionFactory {
counter: AtomicU64,
}
impl MockConnectionFactory {
pub fn new() -> Self {
MockConnectionFactory {
counter: AtomicU64::new(1),
}
}
}
#[async_trait::async_trait]
impl PooledItemFactory<MockConnection> for MockConnectionFactory {
/// The trait on ConnectionPool expects:
/// async fn create(&self, timeout: Duration)
/// -> Result<Result<MockConnection, Status>, tokio::time::error::Elapsed>;
///
/// On success: Ok(Ok(MockConnection))
/// On a simulated “gRPC” failure: Ok(Err(Status::…))
/// On a transport/factory error: Err(Box<…>)
async fn create(
&self,
_timeout: Duration,
) -> Result<Result<MockConnection, Status>, tokio::time::error::Elapsed> {
// Simulate connection creation immediately succeeding.
CREATED.fetch_add(1, Ordering::SeqCst);
let next_id = self.counter.fetch_add(1, Ordering::Relaxed);
Ok(Ok(MockConnection { id: next_id }))
}
}
// --------------------------------------
// CLIENT WORKER
// --------------------------------------
//
// Each worker repeatedly calls `pool.get_client().await`. When it succeeds, we:
// 1. Lock the shared Mutex<HashMap<u64, Arc<AtomicUsize>>> to fetch/insert an Arc<AtomicUsize> for this conn_id.
// 2. Lock the shared Mutex<HashSet<u64>> to record this conn_id as “seen.”
// 3. Drop both locks, then atomically increment that counter and assert it ≤ max_consumers.
// 4. Sleep 10100 ms to simulate “work.”
// 5. Atomically decrement the counter.
// 6. Call `pooled.finish(Ok(()))` to return to the pool.
async fn client_worker(
pool: Arc<ConnectionPool<MockConnection>>,
usage_map: Arc<Mutex<HashMap<u64, Arc<AtomicUsize>>>>,
seen_set: Arc<Mutex<HashSet<u64>>>,
max_consumers: usize,
worker_id: usize,
) {
for iteration in 0..10 {
match pool.clone().get_client().await {
Ok(pooled) => {
let conn: MockConnection = pooled.channel();
let conn_id = conn.id;
// 1. Fetch or insert the Arc<AtomicUsize> for this conn_id:
let counter_arc: Arc<AtomicUsize> = {
let mut guard = usage_map.lock().unwrap();
guard
.entry(conn_id)
.or_insert_with(|| Arc::new(AtomicUsize::new(0)))
.clone()
// MutexGuard is dropped here
};
// 2. Record this conn_id in the shared HashSet of “seen” IDs:
{
let mut seen_guard = seen_set.lock().unwrap();
seen_guard.insert(conn_id);
// MutexGuard is dropped immediately
}
// 3. Atomically bump the count for this connection ID
let prev = counter_arc.fetch_add(1, Ordering::SeqCst);
let current = prev + 1;
assert!(
current <= max_consumers,
"Connection {} exceeded max_consumers (got {})",
conn_id,
current
);
println!(
"[worker {}][iter {}] got MockConnection id={} ({} concurrent)",
worker_id, iteration, conn_id, current
);
// 4. Simulate some work (10100 ms)
let delay_ms = rand::thread_rng().gen_range(10..100);
sleep(Duration::from_millis(delay_ms)).await;
// 5. Decrement the usage counter
let prev2 = counter_arc.fetch_sub(1, Ordering::SeqCst);
let after = prev2 - 1;
println!(
"[worker {}][iter {}] returning MockConnection id={} (now {} remain)",
worker_id, iteration, conn_id, after
);
// 6. Return to the pool (mark success)
pooled.finish(Ok(())).await;
}
Err(status) => {
eprintln!(
"[worker {}][iter {}] failed to get client: {:?}",
worker_id, iteration, status
);
}
}
// Small random pause before next iteration to spread out load
let pause = rand::thread_rng().gen_range(0..20);
sleep(Duration::from_millis(pause)).await;
}
}
#[tokio::main(flavor = "multi_thread", worker_threads = 8)]
async fn main() {
// --------------------------------------
// 1. Create factory and shared instrumentation
// --------------------------------------
let factory = Arc::new(MockConnectionFactory::new());
// Shared map: connection ID → Arc<AtomicUsize>
let usage_map: Arc<Mutex<HashMap<u64, Arc<AtomicUsize>>>> =
Arc::new(Mutex::new(HashMap::new()));
// Shared set: record each unique connection ID we actually saw
let seen_set: Arc<Mutex<HashSet<u64>>> = Arc::new(Mutex::new(HashSet::new()));
// --------------------------------------
// 2. Pool parameters
// --------------------------------------
let connect_timeout = Duration::from_millis(500);
let connect_backoff = Duration::from_millis(100);
let max_consumers = 100; // test limit
let error_threshold = 2; // mock never fails
let max_idle_duration = Duration::from_secs(2);
let max_total_connections = 3;
let aggregate_metrics = None;
let pool: Arc<ConnectionPool<MockConnection>> = ConnectionPool::new(
factory,
connect_timeout,
connect_backoff,
max_consumers,
error_threshold,
max_idle_duration,
max_total_connections,
aggregate_metrics,
);
// --------------------------------------
// 3. Spawn worker tasks
// --------------------------------------
let num_workers = 10000;
let mut handles = Vec::with_capacity(num_workers);
let start_time = Instant::now();
for worker_id in 0..num_workers {
let pool_clone = Arc::clone(&pool);
let usage_clone = Arc::clone(&usage_map);
let seen_clone = Arc::clone(&seen_set);
let mc = max_consumers;
let handle = task::spawn(async move {
client_worker(pool_clone, usage_clone, seen_clone, mc, worker_id).await;
});
handles.push(handle);
}
// --------------------------------------
// 4. Wait for workers to finish
// --------------------------------------
for handle in handles {
let _ = handle.await;
}
let elapsed = Instant::now().duration_since(start_time);
println!(
"All {} workers completed in {:?}",
num_workers, elapsed
);
// --------------------------------------
// 5. Print the total number of unique connections seen so far
// --------------------------------------
let unique_count = {
let seen_guard = seen_set.lock().unwrap();
seen_guard.len()
};
println!("Total unique connections used by workers: {}", unique_count);
// --------------------------------------
// 6. Sleep so the background sweeper can run (max_idle_duration = 2 s)
// --------------------------------------
sleep(Duration::from_secs(3)).await;
// --------------------------------------
// 7. Shutdown the pool
// --------------------------------------
let shutdown_pool = Arc::clone(&pool);
shutdown_pool.shutdown().await;
println!("Pool.shutdown() returned.");
// --------------------------------------
// 8. Verify that no background task still holds an Arc clone of `pool`.
// If any task is still alive (sweeper/create_connection), strong_count > 1.
// --------------------------------------
sleep(Duration::from_secs(1)).await; // give tasks time to exit
let sc = Arc::strong_count(&pool);
assert!(
sc == 1,
"Pool tasks did not all terminate: Arc::strong_count = {} (expected 1)",
sc
);
println!("Verified: all pool tasks have terminated (strong_count == 1).");
// --------------------------------------
// 9. Verify no MockConnection was leaked:
// CREATED must equal DROPPED.
// --------------------------------------
let created = CREATED.load(Ordering::SeqCst);
let dropped = DROPPED.load(Ordering::SeqCst);
assert!(
created == dropped,
"Leaked connections: created={} but dropped={}",
created,
dropped
);
println!(
"Verified: no connections leaked (created = {}, dropped = {}).",
created, dropped
);
// --------------------------------------
// 10. Because `client_worker` asserted inside that no connection
// ever exceeded `max_consumers`, reaching this point means that check passed.
// --------------------------------------
println!("All per-connection usage stayed within max_consumers = {}.", max_consumers);
println!("Load test complete; exiting cleanly.");
}

View File

@@ -1,160 +0,0 @@
// examples/request_tracker_load_test.rs
use std::{sync::Arc, time::Duration};
use tokio;
use pageserver_client_grpc::request_tracker::RequestTracker;
use pageserver_client_grpc::request_tracker::MockStreamFactory;
use pageserver_client_grpc::request_tracker::StreamReturner;
use pageserver_client_grpc::client_cache::ConnectionPool;
use pageserver_client_grpc::client_cache::PooledItemFactory;
use pageserver_client_grpc::ClientCacheOptions;
use pageserver_client_grpc::PageserverClientAggregateMetrics;
use pageserver_client_grpc::AuthInterceptor;
use pageserver_client_grpc::client_cache::ChannelFactory;
use tonic::{transport::{Channel}, Request};
use rand::prelude::*;
use pageserver_api::key::Key;
use utils::lsn::Lsn;
use utils::id::TenantTimelineId;
use futures::stream::FuturesOrdered;
use futures::StreamExt;
// use chrono
use chrono::Utc;
use pageserver_page_api::{GetPageClass, GetPageResponse};
use pageserver_page_api::proto;
#[derive(Clone)]
struct KeyRange {
timeline: TenantTimelineId,
timeline_lsn: Lsn,
start: i128,
end: i128,
}
impl KeyRange {
fn len(&self) -> i128 {
self.end - self.start
}
}
#[tokio::main]
async fn main() {
// 1) configure the clientpool behavior
let client_cache_options = ClientCacheOptions {
max_delay_ms: 0,
drop_rate: 0.0,
hang_rate: 0.0,
connect_timeout: Duration::from_secs(10),
connect_backoff: Duration::from_millis(200),
max_consumers: 64,
error_threshold: 10,
max_idle_duration: Duration::from_secs(60),
max_total_connections: 12,
};
// 2) metrics collector (we assume Default is implemented)
let metrics = Arc::new(PageserverClientAggregateMetrics::new());
let pool = ConnectionPool::<StreamReturner>::new(
Arc::new(MockStreamFactory::new(
)),
client_cache_options.connect_timeout,
client_cache_options.connect_backoff,
client_cache_options.max_consumers,
client_cache_options.error_threshold,
client_cache_options.max_idle_duration,
client_cache_options.max_total_connections,
Some(Arc::clone(&metrics)),
);
// -----------
// There is no mock for the unary connection pool, so for now just
// don't use this pool
//
let channel_fact : Arc<dyn PooledItemFactory<Channel> + Send + Sync> = Arc::new(ChannelFactory::new(
"".to_string(),
client_cache_options.max_delay_ms,
client_cache_options.drop_rate,
client_cache_options.hang_rate,
));
let unary_pool: Arc<ConnectionPool<Channel>> = ConnectionPool::new(
Arc::clone(&channel_fact),
client_cache_options.connect_timeout,
client_cache_options.connect_backoff,
client_cache_options.max_consumers,
client_cache_options.error_threshold,
client_cache_options.max_idle_duration,
client_cache_options.max_total_connections,
Some(Arc::clone(&metrics)),
);
// -----------
// Dummy auth interceptor. This is not used in this test.
let auth_interceptor = AuthInterceptor::new("dummy_tenant_id",
"dummy_timeline_id",
None);
let mut tracker = RequestTracker::new(
pool,
unary_pool,
auth_interceptor,
);
// 4) fire off 10 000 requests in parallel
let mut handles = FuturesOrdered::new();
for i in 0..500000 {
let mut rng = rand::thread_rng();
let r = 0..=1000000i128;
let key: i128 = rng.gen_range(r.clone());
let key = Key::from_i128(key);
let (rel_tag, block_no) = key
.to_rel_block()
.expect("we filter non-rel-block keys out above");
let req2 = proto::GetPageRequest {
request_id: 0,
request_class: proto::GetPageClass::Normal as i32,
read_lsn: Some(proto::ReadLsn {
request_lsn: if rng.gen_bool(0.5) {
u64::from(Lsn::MAX)
} else {
10000
},
not_modified_since_lsn: 10000,
}),
rel: Some(rel_tag.into()),
block_number: vec![block_no],
};
let req_model = pageserver_page_api::GetPageRequest::try_from(req2.clone());
// RequestTracker is Clone, so we can share it
let mut tr = tracker.clone();
let fut = async move {
let resp = tr.send_getpage_request(req_model.unwrap()).await.unwrap();
// sanitycheck: the mock echo returns the same request_id
assert!(resp.request_id > 0);
};
handles.push_back(fut);
// empty future
let fut = async move {};
fut.await;
}
// print timestamp
println!("Starting 5000000 requests at: {}", chrono::Utc::now());
// 5) wait for them all
for i in 0..500000 {
handles.next().await.expect("Failed to get next handle");
}
// print timestamp
println!("Finished 5000000 requests at: {}", chrono::Utc::now());
println!("✅ All 100000 requests completed successfully");
}

View File

@@ -1,741 +0,0 @@
use std::{
collections::HashMap,
io::{self, Error, ErrorKind},
sync::Arc,
time::{Duration, Instant},
};
use priority_queue::PriorityQueue;
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
net::TcpStream,
sync::{Mutex, OwnedSemaphorePermit, Semaphore},
time::sleep,
};
use tonic::transport::{Channel, Endpoint};
use uuid;
use std::{
pin::Pin,
task::{Context, Poll},
};
use futures::future;
use rand::{Rng, SeedableRng, rngs::StdRng};
use bytes::BytesMut;
use http::Uri;
use hyper_util::rt::TokioIo;
use tower::service_fn;
use tokio_util::sync::CancellationToken;
use async_trait::async_trait;
//
// The "TokioTcp" is flakey TCP network for testing purposes, in order
// to simulate network errors and delays.
//
/// Wraps a `TcpStream`, buffers incoming data, and injects a random delay per fresh read/write.
pub struct TokioTcp {
tcp: TcpStream,
/// Maximum randomized delay in milliseconds
delay_ms: u64,
/// Next deadline instant for delay
deadline: Instant,
/// Internal buffer of previously-read data
buffer: BytesMut,
}
impl TokioTcp {
/// Create a new wrapper with given max delay (ms)
pub fn new(stream: TcpStream, delay_ms: u64) -> Self {
let initial = if delay_ms > 0 {
rand::thread_rng().gen_range(0..delay_ms)
} else {
0
};
let deadline = Instant::now() + Duration::from_millis(initial);
TokioTcp {
tcp: stream,
delay_ms,
deadline,
buffer: BytesMut::new(),
}
}
}
impl AsyncRead for TokioTcp {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
// Safe because TokioTcp is Unpin
let this = self.get_mut();
// 1) Drain any buffered data
if !this.buffer.is_empty() {
let to_copy = this.buffer.len().min(buf.remaining());
buf.put_slice(&this.buffer.split_to(to_copy));
return Poll::Ready(Ok(()));
}
// 2) If we're still before the deadline, schedule a wake and return Pending
let now = Instant::now();
if this.delay_ms > 0 && now < this.deadline {
let waker = cx.waker().clone();
let wait = this.deadline - now;
tokio::spawn(async move {
sleep(wait).await;
waker.wake_by_ref();
});
return Poll::Pending;
}
// 3) Past deadline: compute next random deadline
if this.delay_ms > 0 {
let next_ms = rand::thread_rng().gen_range(0..=this.delay_ms);
this.deadline = Instant::now() + Duration::from_millis(next_ms);
}
// 4) Perform actual read into a temporary buffer
let mut tmp = [0u8; 4096];
let mut rb = ReadBuf::new(&mut tmp);
match Pin::new(&mut this.tcp).poll_read(cx, &mut rb) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(())) => {
let filled = rb.filled();
if filled.is_empty() {
// EOF or zero bytes
Poll::Ready(Ok(()))
} else {
this.buffer.extend_from_slice(filled);
let to_copy = this.buffer.len().min(buf.remaining());
buf.put_slice(&this.buffer.split_to(to_copy));
Poll::Ready(Ok(()))
}
}
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
}
}
}
impl AsyncWrite for TokioTcp {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
data: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
// 1) If before deadline, schedule wake and return Pending
let now = Instant::now();
if this.delay_ms > 0 && now < this.deadline {
let waker = cx.waker().clone();
let wait = this.deadline - now;
tokio::spawn(async move {
sleep(wait).await;
waker.wake_by_ref();
});
return Poll::Pending;
}
// 2) Past deadline: compute next random deadline
if this.delay_ms > 0 {
let next_ms = rand::thread_rng().gen_range(0..=this.delay_ms);
this.deadline = Instant::now() + Duration::from_millis(next_ms);
}
// 3) Actual write
Pin::new(&mut this.tcp).poll_write(cx, data)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
Pin::new(&mut this.tcp).poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
Pin::new(&mut this.tcp).poll_shutdown(cx)
}
}
#[async_trait]
pub trait PooledItemFactory<T>: Send + Sync + 'static {
/// Create a new pooled item.
async fn create(&self, connect_timeout: Duration) -> Result<Result<T, tonic::Status>, tokio::time::error::Elapsed>;
}
pub struct ChannelFactory {
endpoint: String,
max_delay_ms: u64,
drop_rate: f64,
hang_rate: f64,
}
impl ChannelFactory {
pub fn new(
endpoint: String,
max_delay_ms: u64,
drop_rate: f64,
hang_rate: f64,
) -> Self {
ChannelFactory {
endpoint,
max_delay_ms,
drop_rate,
hang_rate,
}
}
}
#[async_trait]
impl PooledItemFactory<Channel> for ChannelFactory {
async fn create(&self, connect_timeout: Duration) -> Result<Result<Channel, tonic::Status>, tokio::time::error::Elapsed> {
let max_delay_ms = self.max_delay_ms;
let drop_rate = self.drop_rate;
let hang_rate = self.hang_rate;
// This is a custom connector that inserts delays and errors, for
// testing purposes. It would normally be disabled by the config.
let connector = service_fn(move |uri: Uri| {
let drop_rate = drop_rate;
let hang_rate = hang_rate;
async move {
let mut rng = StdRng::from_entropy();
// Simulate an indefinite hang
if hang_rate > 0.0 && rng.gen_bool(hang_rate) {
// never completes, to test timeout
return future::pending::<Result<TokioIo<TokioTcp>, std::io::Error>>().await;
}
// Random drop (connect error)
if drop_rate > 0.0 && rng.gen_bool(drop_rate) {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"simulated connect drop",
));
}
// Otherwise perform real TCP connect
let addr = match (uri.host(), uri.port()) {
// host + explicit port
(Some(host), Some(port)) => format!("{}:{}", host, port.as_str()),
// host only (no port)
(Some(host), None) => host.to_string(),
// neither? error out
_ => return Err(Error::new(ErrorKind::InvalidInput, "no host or port")),
};
let tcp = TcpStream::connect(addr).await?;
let tcpwrapper = TokioTcp::new(tcp, max_delay_ms);
Ok(TokioIo::new(tcpwrapper))
}
});
let attempt = tokio::time::timeout(
connect_timeout,
Endpoint::from_shared(self.endpoint.clone())
.expect("invalid endpoint")
.timeout(connect_timeout)
.connect_with_connector(connector),
)
.await;
match attempt {
Ok(Ok(channel)) => {
// Connection succeeded
Ok(Ok(channel))
}
Ok(Err(e)) => {
Ok(Err(tonic::Status::new(
tonic::Code::Unavailable,
format!("Failed to connect: {}", e),
)))
}
Err(e) => {
Err(e)
}
}
}
}
/// A pooled gRPC client with capacity tracking and error handling.
pub struct ConnectionPool<T> {
inner: Mutex<Inner<T>>,
fact: Arc<dyn PooledItemFactory<T> + Send + Sync>,
connect_timeout: Duration,
connect_backoff: Duration,
/// The maximum number of consumers that can use a single connection.
max_consumers: usize,
/// The number of consecutive errors before a connection is removed from the pool.
error_threshold: usize,
/// The maximum duration a connection can be idle before being removed.
max_idle_duration: Duration,
max_total_connections: usize,
channel_semaphore: Arc<Semaphore>,
shutdown_token: CancellationToken,
aggregate_metrics: Option<Arc<crate::PageserverClientAggregateMetrics>>,
}
struct Inner<T> {
entries: HashMap<uuid::Uuid, ConnectionEntry<T>>,
pq: PriorityQueue<uuid::Uuid, usize>,
// This is updated when a connection is dropped, or we fail
// to create a new connection.
last_connect_failure: Option<Instant>,
waiters: usize,
in_progress: usize,
}
struct ConnectionEntry<T> {
channel: T,
active_consumers: usize,
consecutive_errors: usize,
last_used: Instant,
}
/// A client borrowed from the pool.
pub struct PooledClient<T> {
pub channel: T,
pool: Arc<ConnectionPool<T>>,
is_ok: bool,
id: uuid::Uuid,
permit: OwnedSemaphorePermit,
}
impl<T: Clone + Send + 'static> ConnectionPool<T> {
pub fn new(
fact: Arc<dyn PooledItemFactory<T> + Send + Sync>,
connect_timeout: Duration,
connect_backoff: Duration,
max_consumers: usize,
error_threshold: usize,
max_idle_duration: Duration,
max_total_connections: usize,
aggregate_metrics: Option<Arc<crate::PageserverClientAggregateMetrics>>,
) -> Arc<Self> {
let shutdown_token = CancellationToken::new();
let pool = Arc::new(Self {
inner: Mutex::new(Inner::<T> {
entries: HashMap::new(),
pq: PriorityQueue::new(),
last_connect_failure: None,
waiters: 0,
in_progress: 0,
}),
fact: Arc::clone(&fact),
connect_timeout,
connect_backoff,
max_consumers,
error_threshold,
max_idle_duration,
max_total_connections,
channel_semaphore: Arc::new(Semaphore::new(0)),
shutdown_token: shutdown_token.clone(),
aggregate_metrics: aggregate_metrics.clone(),
});
// Cancelable background task to sweep idle connections
let sweeper_token = shutdown_token.clone();
let sweeper_pool = Arc::clone(&pool);
tokio::spawn(async move {
loop {
tokio::select! {
_ = sweeper_token.cancelled() => break,
_ = async {
sweeper_pool.sweep_idle_connections().await;
sleep(Duration::from_secs(5)).await;
} => {}
}
}
});
pool
}
pub async fn shutdown(self: Arc<Self>) {
self.shutdown_token.cancel();
loop {
let all_idle = {
let inner = self.inner.lock().await;
inner.entries.values().all(|e| e.active_consumers == 0)
};
if all_idle {
break;
}
sleep(Duration::from_millis(100)).await;
}
// 4. Remove all entries
let mut inner = self.inner.lock().await;
inner.entries.clear();
}
/// Sweep and remove idle connections safely, burning their permits.
async fn sweep_idle_connections(self: &Arc<Self>) {
let mut ids_to_remove = Vec::new();
let now = Instant::now();
// Remove idle entries. First collect permits for those connections so that
// no consumer will reserve them, then remove them from the pool.
{
let mut inner = self.inner.lock().await;
inner.entries.retain(|id, entry| {
if entry.active_consumers == 0
&& now.duration_since(entry.last_used) > self.max_idle_duration
{
// metric
match self.aggregate_metrics {
Some(ref metrics) => {
metrics
.retry_counters
.with_label_values(&["connection_swept"])
.inc();
}
None => {}
}
ids_to_remove.push(*id);
return false; // remove this entry
}
true
});
// Remove the entries from the priority queue
for id in ids_to_remove {
inner.pq.remove(&id);
}
}
}
// If we have a permit already, get a connection out of the heap
async fn get_conn_with_permit(
self: Arc<Self>,
permit: OwnedSemaphorePermit,
) -> Option<PooledClient<T>> {
let mut inner = self.inner.lock().await;
// Pop the highest-active-consumers connection. There are no connections
// in the heap that have more than max_consumers active consumers.
if let Some((id, _cons)) = inner.pq.pop() {
let entry = inner
.entries
.get_mut(&id)
.expect("pq and entries got out of sync");
let mut active_consumers = entry.active_consumers;
entry.active_consumers += 1;
entry.last_used = Instant::now();
let client = PooledClient::<T> {
channel: entry.channel.clone(),
pool: Arc::clone(&self),
is_ok: true,
id,
permit: permit,
};
// reinsert with updated priority
active_consumers += 1;
if active_consumers < self.max_consumers {
inner.pq.push(id, active_consumers as usize);
}
return Some(client);
} else {
// If there is no connection to take, it is because permits for a connection
// need to drain. This can happen if a connection is removed because it has
// too many errors. It is taken out of the heap/hash table in this case, but
// we can't remove it's permits until now.
//
// Just forget the permit and retry.
permit.forget();
return None;
}
}
pub async fn get_client(self: Arc<Self>) -> Result<PooledClient<T>, tonic::Status> {
// The pool is shutting down. Don't accept new connections.
if self.shutdown_token.is_cancelled() {
return Err(tonic::Status::unavailable("Pool is shutting down"));
}
// A loop is necessary because when a connection is draining, we have to return
// a permit and retry.
loop {
let self_clone = Arc::clone(&self);
let mut semaphore = Arc::clone(&self_clone.channel_semaphore);
match semaphore.try_acquire_owned() {
Ok(permit_) => {
// We got a permit, so check the heap for a connection
// we can use.
let pool_conn = self_clone.get_conn_with_permit(permit_).await;
match pool_conn {
Some(pool_conn_) => {
return Ok(pool_conn_);
}
None => {
// No connection available. Forget the permit and retry.
continue;
}
}
}
Err(_) => {
match self_clone.aggregate_metrics {
Some(ref metrics) => {
metrics
.retry_counters
.with_label_values(&["sema_acquire_failed"])
.inc();
}
None => {}
}
{
//
// This is going to generate enough connections to handle a burst,
// but it may generate up to twice the number of connections needed
// in the worst case. Extra connections will go idle and be cleaned
// up.
//
let mut inner = self_clone.inner.lock().await;
inner.waiters += 1;
if inner.waiters > (inner.in_progress * self_clone.max_consumers) {
if (inner.entries.len() + inner.in_progress) < self_clone.max_total_connections {
let self_clone_spawn = Arc::clone(&self_clone);
tokio::task::spawn(async move {
self_clone_spawn.create_connection().await;
});
inner.in_progress += 1;
}
}
}
// Wait for a connection to become available, either because it
// was created or because a connection was returned to the pool
// by another consumer.
semaphore = Arc::clone(&self_clone.channel_semaphore);
let conn_permit = semaphore.acquire_owned().await.unwrap();
{
let mut inner = self_clone.inner.lock().await;
inner.waiters -= 1;
}
// We got a permit, check the heap for a connection.
let pool_conn = self_clone.get_conn_with_permit(conn_permit).await;
match pool_conn {
Some(pool_conn_) => {
return Ok(pool_conn_);
}
None => {
// No connection was found, forget the permit and retry.
continue;
}
}
}
}
}
}
async fn create_connection(&self) -> () {
// Generate a random backoff to add some jitter so that connections
// don't all retry at the same time.
let mut backoff_delay = Duration::from_millis(
rand::thread_rng().gen_range(0..=self.connect_backoff.as_millis() as u64),
);
loop {
if self.shutdown_token.is_cancelled() {
return;
}
// Back off.
// Loop because failure can occur while we are sleeping, so wait
// until the failure stopped for at least one backoff period. Backoff
// period includes some jitter, so that if multiple connections are
// failing, they don't all retry at the same time.
loop {
if let Some(delay) = {
let inner = self.inner.lock().await;
inner.last_connect_failure.and_then(|at| {
(at.elapsed() < backoff_delay).then(|| backoff_delay - at.elapsed())
})
} {
sleep(delay).await;
} else {
break; // No delay, so we can create a connection
}
}
//
// Create a new connection.
//
// The connect timeout is also the timeout for an individual gRPC request
// on this connection. (Requests made later on this channel will time out
// with the same timeout.)
//
match self.aggregate_metrics {
Some(ref metrics) => {
metrics
.retry_counters
.with_label_values(&["connection_attempt"])
.inc();
}
None => {}
}
let attempt = self.fact
.create(self.connect_timeout)
.await;
match attempt {
// Connection succeeded
Ok(Ok(channel)) => {
{
match self.aggregate_metrics {
Some(ref metrics) => {
metrics
.retry_counters
.with_label_values(&["connection_success"])
.inc();
}
None => {}
}
let mut inner = self.inner.lock().await;
let id = uuid::Uuid::new_v4();
inner.entries.insert(
id,
ConnectionEntry::<T> {
channel: channel.clone(),
active_consumers: 0,
consecutive_errors: 0,
last_used: Instant::now(),
},
);
inner.pq.push(id, 0);
inner.in_progress -= 1;
self.channel_semaphore.add_permits(self.max_consumers);
return;
};
}
// Connection failed, back off and retry
Ok(Err(_)) | Err(_) => {
match self.aggregate_metrics {
Some(ref metrics) => {
metrics
.retry_counters
.with_label_values(&["connect_failed"])
.inc();
}
None => {}
}
let mut inner = self.inner.lock().await;
inner.last_connect_failure = Some(Instant::now());
// Add some jitter so that every connection doesn't retry at once
let jitter = rand::thread_rng().gen_range(0..=backoff_delay.as_millis() as u64);
backoff_delay =
Duration::from_millis(backoff_delay.as_millis() as u64 + jitter);
// Do not backoff longer than one minute
if backoff_delay > Duration::from_secs(60) {
backoff_delay = Duration::from_secs(60);
}
// continue the loop to retry
}
}
}
}
/// Return client to the pool, indicating success or error.
pub async fn return_client(&self, id: uuid::Uuid, success: bool, permit: OwnedSemaphorePermit) {
let mut inner = self.inner.lock().await;
if let Some(entry) = inner.entries.get_mut(&id) {
entry.last_used = Instant::now();
if entry.active_consumers <= 0 {
panic!("A consumer completed when active_consumers was zero!")
}
entry.active_consumers = entry.active_consumers - 1;
if success {
if entry.consecutive_errors < self.error_threshold {
entry.consecutive_errors = 0;
}
} else {
entry.consecutive_errors += 1;
if entry.consecutive_errors == self.error_threshold {
match self.aggregate_metrics {
Some(ref metrics) => {
metrics
.retry_counters
.with_label_values(&["connection_dropped"])
.inc();
}
None => {}
}
}
}
//
// Too many errors on this connection. If there are no active users,
// remove it. Otherwise just wait for active_consumers to go to zero.
// This connection will not be selected for new consumers.
//
let active_consumers = entry.active_consumers;
if entry.consecutive_errors >= self.error_threshold {
// too many errors, remove the connection permanently. Once it drains,
// it will be dropped.
if inner.pq.get_priority(&id).is_some() {
inner.pq.remove(&id);
}
// remove from entries
// check if entry is in inner
if inner.entries.contains_key(&id) {
inner.entries.remove(&id);
}
inner.last_connect_failure = Some(Instant::now());
// The connection has been removed, it's permits will be
// drained because if we look for a connection and it's not there
// we just forget the permit. However, this process can be a little
// bit faster if we just forget permits as the connections are returned.
permit.forget();
} else {
// update its priority in the queue
if inner.pq.get_priority(&id).is_some() {
inner.pq.change_priority(&id, active_consumers);
} else {
// This connection is not in the heap, but it has space
// for more consumers. Put it back in the heap.
if active_consumers < self.max_consumers {
inner.pq.push(id, active_consumers);
}
}
}
}
}
}
impl<T: Clone + Send + 'static> PooledClient<T> {
pub fn channel(&self) -> T {
return self.channel.clone();
}
pub async fn finish(mut self, result: Result<(), tonic::Status>) {
self.is_ok = result.is_ok();
self.pool.return_client(
self.id,
self.is_ok,
self.permit,
).await;
}
}

View File

@@ -1,456 +0,0 @@
//! Pageserver Data API client
//!
//! - Manage connections to pageserver
//! - Send requests to correct shards
//!
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::RwLock;
use std::time::Duration;
use bytes::Bytes;
use futures::{Stream, StreamExt};
use thiserror::Error;
use tonic::metadata::AsciiMetadataValue;
use pageserver_page_api::proto;
use pageserver_page_api::*;
use pageserver_page_api::proto::PageServiceClient;
use utils::shard::ShardIndex;
use std::fmt::Debug;
pub mod client_cache;
pub mod request_tracker;
use tonic::transport::Channel;
use metrics::{IntCounterVec, core::Collector};
use crate::client_cache::{PooledItemFactory};
use tokio::sync::mpsc;
use async_trait::async_trait;
#[derive(Error, Debug)]
pub enum PageserverClientError {
#[error("could not connect to service: {0}")]
ConnectError(#[from] tonic::transport::Error),
#[error("could not perform request: {0}`")]
RequestError(#[from] tonic::Status),
#[error("protocol error: {0}")]
ProtocolError(#[from] ProtocolError),
#[error("could not perform request: {0}`")]
InvalidUri(#[from] http::uri::InvalidUri),
#[error("could not perform request: {0}`")]
Other(String),
}
#[derive(Clone, Debug)]
pub struct PageserverClientAggregateMetrics {
pub request_counters: IntCounterVec,
pub retry_counters: IntCounterVec,
}
impl PageserverClientAggregateMetrics {
pub fn new() -> Self {
let request_counters = IntCounterVec::new(
metrics::core::Opts::new(
"backend_requests_total",
"Number of requests from backends.",
),
&["request_kind"],
)
.unwrap();
let retry_counters = IntCounterVec::new(
metrics::core::Opts::new(
"backend_requests_retries_total",
"Number of retried requests from backends.",
),
&["request_kind"],
)
.unwrap();
Self {
request_counters,
retry_counters,
}
}
pub fn collect(&self) -> Vec<metrics::proto::MetricFamily> {
let mut metrics = Vec::new();
metrics.append(&mut self.request_counters.collect());
metrics.append(&mut self.retry_counters.collect());
metrics
}
}
pub struct PageserverClient {
_tenant_id: String,
_timeline_id: String,
_auth_token: Option<String>,
shard_map: HashMap<ShardIndex, String>,
channels: RwLock<HashMap<ShardIndex, Arc<client_cache::ConnectionPool<Channel>>>>,
auth_interceptor: AuthInterceptor,
client_cache_options: ClientCacheOptions,
aggregate_metrics: Option<Arc<PageserverClientAggregateMetrics>>,
}
#[derive(Clone)]
pub struct ClientCacheOptions {
pub max_consumers: usize,
pub error_threshold: usize,
pub connect_timeout: Duration,
pub connect_backoff: Duration,
pub max_idle_duration: Duration,
pub max_total_connections: usize,
pub max_delay_ms: u64,
pub drop_rate: f64,
pub hang_rate: f64,
}
impl PageserverClient {
/// TODO: this doesn't currently react to changes in the shard map.
pub fn new(
tenant_id: &str,
timeline_id: &str,
auth_token: &Option<String>,
shard_map: HashMap<ShardIndex, String>,
) -> Self {
let options = ClientCacheOptions {
max_consumers: 5000,
error_threshold: 5,
connect_timeout: Duration::from_secs(5),
connect_backoff: Duration::from_secs(1),
max_idle_duration: Duration::from_secs(60),
max_total_connections: 100000,
max_delay_ms: 0,
drop_rate: 0.0,
hang_rate: 0.0,
};
Self::new_with_config(tenant_id, timeline_id, auth_token, shard_map, options, None)
}
pub fn new_with_config(
tenant_id: &str,
timeline_id: &str,
auth_token: &Option<String>,
shard_map: HashMap<ShardIndex, String>,
options: ClientCacheOptions,
metrics: Option<Arc<PageserverClientAggregateMetrics>>,
) -> Self {
Self {
_tenant_id: tenant_id.to_string(),
_timeline_id: timeline_id.to_string(),
_auth_token: auth_token.clone(),
shard_map,
channels: RwLock::new(HashMap::new()),
auth_interceptor: AuthInterceptor::new(tenant_id, timeline_id, auth_token.as_deref()),
client_cache_options: options,
aggregate_metrics: metrics,
}
}
pub async fn process_check_rel_exists_request(
&self,
request: CheckRelExistsRequest,
) -> Result<bool, PageserverClientError> {
// Current sharding model assumes that all metadata is present only at shard 0.
let shard = ShardIndex::unsharded();
let pooled_client = self.get_client(shard).await;
let chan = pooled_client.channel();
let mut client =
PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));
let request = proto::CheckRelExistsRequest::from(request);
let response = client.check_rel_exists(tonic::Request::new(request)).await;
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
return Err(PageserverClientError::RequestError(status));
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
return Ok(resp.get_ref().exists);
}
}
}
pub async fn process_get_rel_size_request(
&self,
request: GetRelSizeRequest,
) -> Result<u32, PageserverClientError> {
// Current sharding model assumes that all metadata is present only at shard 0.
let shard = ShardIndex::unsharded();
let pooled_client = self.get_client(shard).await;
let chan = pooled_client.channel();
let mut client =
PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));
let request = proto::GetRelSizeRequest::from(request);
let response = client.get_rel_size(tonic::Request::new(request)).await;
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
return Err(PageserverClientError::RequestError(status));
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
return Ok(resp.get_ref().num_blocks);
}
}
}
// Request a single batch of pages
//
// TODO: This opens a new gRPC stream for every request, which is extremely inefficient
pub async fn get_page(
&self,
request: GetPageRequest,
) -> Result<Vec<Bytes>, PageserverClientError> {
// FIXME: calculate the shard number correctly
let shard = ShardIndex::unsharded();
let pooled_client = self.get_client(shard).await;
let chan = pooled_client.channel();
let mut client =
PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));
let request = proto::GetPageRequest::from(request);
let request_stream = futures::stream::once(std::future::ready(request));
let mut response_stream = client
.get_pages(tonic::Request::new(request_stream))
.await?
.into_inner();
let Some(response) = response_stream.next().await else {
return Err(PageserverClientError::Other(
"no response received for getpage request".to_string(),
));
};
match self.aggregate_metrics {
Some(ref metrics) => {
metrics
.request_counters
.with_label_values(&["get_page"])
.inc();
}
None => {}
}
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
return Err(PageserverClientError::RequestError(status));
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
let response: GetPageResponse = resp.into();
return Ok(response.page_images.to_vec());
}
}
}
// Open a stream for requesting pages
//
// TODO: This is a pretty low level interface, the caller should not need to be concerned
// with streams. But 'get_page' is currently very naive and inefficient.
pub async fn get_pages(
&self,
requests: impl Stream<Item = proto::GetPageRequest> + Send + 'static,
) -> std::result::Result<
tonic::Response<tonic::codec::Streaming<proto::GetPageResponse>>,
PageserverClientError,
> {
// FIXME: calculate the shard number correctly
let shard = ShardIndex::unsharded();
let pooled_client = self.get_client(shard).await;
let chan = pooled_client.channel();
let mut client =
PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));
let response = client.get_pages(tonic::Request::new(requests)).await;
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
return Err(PageserverClientError::RequestError(status));
}
Ok(resp) => {
return Ok(resp);
}
}
}
/// Process a request to get the size of a database.
pub async fn process_get_dbsize_request(
&self,
request: GetDbSizeRequest,
) -> Result<u64, PageserverClientError> {
// Current sharding model assumes that all metadata is present only at shard 0.
let shard = ShardIndex::unsharded();
let pooled_client = self.get_client(shard).await;
let chan = pooled_client.channel();
let mut client =
PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));
let request = proto::GetDbSizeRequest::from(request);
let response = client.get_db_size(tonic::Request::new(request)).await;
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
return Err(PageserverClientError::RequestError(status));
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
return Ok(resp.get_ref().num_bytes);
}
}
}
/// Process a request to get the size of a database.
pub async fn get_base_backup(
&self,
request: GetBaseBackupRequest,
gzip: bool,
) -> std::result::Result<
tonic::Response<tonic::codec::Streaming<proto::GetBaseBackupResponseChunk>>,
PageserverClientError,
> {
// Current sharding model assumes that all metadata is present only at shard 0.
let shard = ShardIndex::unsharded();
let pooled_client = self.get_client(shard).await;
let chan = pooled_client.channel();
let mut client =
PageServiceClient::with_interceptor(chan, self.auth_interceptor.for_shard(shard));
if gzip {
client = client.accept_compressed(tonic::codec::CompressionEncoding::Gzip);
}
let request = proto::GetBaseBackupRequest::from(request);
let response = client.get_base_backup(tonic::Request::new(request)).await;
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
return Err(PageserverClientError::RequestError(status));
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
return Ok(resp);
}
}
}
/// Get a client for given shard
///
/// Get a client from the pool for this shard, also creating the pool if it doesn't exist.
///
async fn get_client(&self, shard: ShardIndex) -> client_cache::PooledClient<Channel> {
let reused_pool: Option<Arc<client_cache::ConnectionPool<Channel>>> = {
let channels = self.channels.read().unwrap();
channels.get(&shard).cloned()
};
let usable_pool: Arc<client_cache::ConnectionPool<Channel>>;
match reused_pool {
Some(pool) => {
let pooled_client = pool.get_client().await.unwrap();
return pooled_client;
}
None => {
// Create a new pool using client_cache_options
// declare new_pool
let new_pool: Arc<client_cache::ConnectionPool<Channel>>;
let channel_fact = Arc::new(client_cache::ChannelFactory::new(
self.shard_map.get(&shard).unwrap().clone(),
self.client_cache_options.max_delay_ms,
self.client_cache_options.drop_rate,
self.client_cache_options.hang_rate,
));
new_pool = client_cache::ConnectionPool::new(
channel_fact,
self.client_cache_options.connect_timeout,
self.client_cache_options.connect_backoff,
self.client_cache_options.max_consumers,
self.client_cache_options.error_threshold,
self.client_cache_options.max_idle_duration,
self.client_cache_options.max_total_connections,
self.aggregate_metrics.clone(),
);
let mut write_pool = self.channels.write().unwrap();
write_pool.insert(shard, new_pool.clone());
usable_pool = new_pool.clone();
}
}
let pooled_client = usable_pool.get_client().await.unwrap();
return pooled_client;
}
}
/// Inject tenant_id, timeline_id and authentication token to all pageserver requests.
#[derive(Clone)]
pub struct AuthInterceptor {
tenant_id: AsciiMetadataValue,
shard_id: Option<AsciiMetadataValue>,
timeline_id: AsciiMetadataValue,
auth_header: Option<AsciiMetadataValue>, // including "Bearer " prefix
}
impl AuthInterceptor {
pub fn new(tenant_id: &str, timeline_id: &str, auth_token: Option<&str>) -> Self {
Self {
tenant_id: tenant_id.parse().expect("could not parse tenant id"),
shard_id: None,
timeline_id: timeline_id.parse().expect("could not parse timeline id"),
auth_header: auth_token
.map(|t| format!("Bearer {t}"))
.map(|t| t.parse().expect("could not parse auth token")),
}
}
fn for_shard(&self, shard_id: ShardIndex) -> Self {
let mut with_shard = self.clone();
with_shard.shard_id = Some(
shard_id
.to_string()
.parse()
.expect("could not parse shard id"),
);
with_shard
}
}
impl tonic::service::Interceptor for AuthInterceptor {
fn call(&mut self, mut req: tonic::Request<()>) -> Result<tonic::Request<()>, tonic::Status> {
req.metadata_mut()
.insert("neon-tenant-id", self.tenant_id.clone());
if let Some(shard_id) = &self.shard_id {
req.metadata_mut().insert("neon-shard-id", shard_id.clone());
}
req.metadata_mut()
.insert("neon-timeline-id", self.timeline_id.clone());
if let Some(auth_header) = &self.auth_header {
req.metadata_mut()
.insert("authorization", auth_header.clone());
}
Ok(req)
}
}

View File

@@ -1,590 +0,0 @@
//
// API Visible to the spawner, just a function call that is async
//
use std::sync::Arc;
use crate::client_cache;
use pageserver_page_api::GetPageRequest;
use pageserver_page_api::GetPageResponse;
use pageserver_page_api::*;
use pageserver_page_api::proto;
use crate::client_cache::ConnectionPool;
use crate::client_cache::ChannelFactory;
use crate::AuthInterceptor;
use tonic::{transport::{Channel}, Request};
use crate::ClientCacheOptions;
use crate::PageserverClientAggregateMetrics;
use tokio::sync::Mutex;
use std::sync::atomic::{AtomicU64, Ordering};
use utils::shard::ShardIndex;
use tokio_stream::wrappers::ReceiverStream;
use pageserver_page_api::proto::PageServiceClient;
use tonic::{
Status,
Code,
};
use async_trait::async_trait;
use std::time::Duration;
use client_cache::PooledItemFactory;
//use tracing::info;
//
// A mock stream pool that just returns a sending channel, and whenever a GetPageRequest
// comes in on that channel, it randomly sleeps before sending a GetPageResponse
//
#[derive(Clone)]
pub struct StreamReturner {
sender: tokio::sync::mpsc::Sender<proto::GetPageRequest>,
sender_hashmap: Arc<Mutex<std::collections::HashMap<u64, tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, Status>>>>>,
}
pub struct MockStreamFactory {
}
impl MockStreamFactory {
pub fn new() -> Self {
MockStreamFactory {
}
}
}
#[async_trait]
impl PooledItemFactory<StreamReturner> for MockStreamFactory {
async fn create(&self, _connect_timeout: Duration) -> Result<Result<StreamReturner, tonic::Status>, tokio::time::error::Elapsed> {
let (sender, mut receiver) = tokio::sync::mpsc::channel::<proto::GetPageRequest>(1000);
// Create a StreamReturner that will send requests to the receiver channel
let stream_returner = StreamReturner {
sender: sender.clone(),
sender_hashmap: Arc::new(Mutex::new(std::collections::HashMap::new())),
};
let map : Arc<Mutex<std::collections::HashMap<u64, tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, _>>>>>
= Arc::clone(&stream_returner.sender_hashmap);
tokio::spawn(async move {
while let Some(request) = receiver.recv().await {
// Break out of the loop with 1% chance
if rand::random::<f32>() < 0.001 {
break;
}
// Generate a random number between 0 and 100
// Simulate some processing time
let mapclone = Arc::clone(&map);
tokio::spawn(async move {
let sleep_ms = rand::random::<u64>() % 100;
tokio::time::sleep(tokio::time::Duration::from_millis(sleep_ms)).await;
let response = proto::GetPageResponse {
request_id: request.request_id,
..Default::default()
};
// look up stream in hash map
let mut hashmap = mapclone.lock().await;
if let Some(sender) = hashmap.get(&request.request_id) {
// Send the response to the original request sender
if let Err(e) = sender.send(Ok(response.clone())).await {
eprintln!("Failed to send response: {}", e);
}
hashmap.remove(&request.request_id);
} else {
eprintln!("No sender found for request ID: {}", request.request_id);
}
});
}
// Close every sender stream in the hashmap
let hashmap = map.lock().await;
for sender in hashmap.values() {
let error = Status::new(Code::Unknown, "Stream closed");
if let Err(e) = sender.send(Err(error)).await {
eprintln!("Failed to send close response: {}", e);
}
}
});
Ok(Ok(stream_returner))
}
}
pub struct StreamFactory {
connection_pool: Arc<client_cache::ConnectionPool<Channel>>,
auth_interceptor: AuthInterceptor,
shard: ShardIndex,
}
impl StreamFactory {
pub fn new(
connection_pool: Arc<ConnectionPool<Channel>>,
auth_interceptor: AuthInterceptor,
shard: ShardIndex,
) -> Self {
StreamFactory {
connection_pool,
auth_interceptor,
shard,
}
}
}
#[async_trait]
impl PooledItemFactory<StreamReturner> for StreamFactory {
async fn create(&self, _connect_timeout: Duration) ->
Result<Result<StreamReturner, tonic::Status>, tokio::time::error::Elapsed>
{
let pool_clone : Arc<ConnectionPool<Channel>> = Arc::clone(&self.connection_pool);
let pooled_client = pool_clone.get_client().await;
let channel = pooled_client.unwrap().channel();
let mut client =
PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard));
let (sender, receiver) = tokio::sync::mpsc::channel::<proto::GetPageRequest>(1000);
let outbound = ReceiverStream::new(receiver);
let client_resp = client
.get_pages(Request::new(outbound))
.await;
match client_resp {
Err(status) => {
// TODO: Convert this error correctly
Ok(Err(tonic::Status::new(
status.code(),
format!("Failed to connect to pageserver: {}", status.message()),
)))
}
Ok(resp) => {
let stream_returner = StreamReturner {
sender: sender.clone(),
sender_hashmap: Arc::new(Mutex::new(std::collections::HashMap::new())),
};
let map : Arc<Mutex<std::collections::HashMap<u64, tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, _>>>>>
= Arc::clone(&stream_returner.sender_hashmap);
tokio::spawn(async move {
let map_clone = Arc::clone(&map);
let mut inner = resp.into_inner();
loop {
let resp = inner.message().await;
if !resp.is_ok() {
break; // Exit the loop if no more messages
}
let response = resp.unwrap().unwrap();
// look up stream in hash map
let mut hashmap = map_clone.lock().await;
if let Some(sender) = hashmap.get(&response.request_id) {
// Send the response to the original request sender
if let Err(e) = sender.send(Ok(response.clone())).await {
eprintln!("Failed to send response: {}", e);
}
hashmap.remove(&response.request_id);
} else {
eprintln!("No sender found for request ID: {}", response.request_id);
}
}
// Close every sender stream in the hashmap
let hashmap = map_clone.lock().await;
for sender in hashmap.values() {
let error = Status::new(Code::Unknown, "Stream closed");
if let Err(e) = sender.send(Err(error)).await {
eprintln!("Failed to send close response: {}", e);
}
}
});
Ok(Ok(stream_returner))
}
}
}
}
#[derive(Clone)]
pub struct RequestTracker {
cur_id: Arc<AtomicU64>,
stream_pool: Arc<ConnectionPool<StreamReturner>>,
unary_pool: Arc<ConnectionPool<Channel>>,
auth_interceptor: AuthInterceptor,
shard: ShardIndex,
}
impl RequestTracker {
pub fn new(stream_pool: Arc<ConnectionPool<StreamReturner>>,
unary_pool: Arc<ConnectionPool<Channel>>,
auth_interceptor: AuthInterceptor,
shard: ShardIndex,
) -> Self {
let cur_id = Arc::new(AtomicU64::new(0));
RequestTracker {
cur_id: cur_id.clone(),
stream_pool: stream_pool,
unary_pool: unary_pool,
auth_interceptor: auth_interceptor,
shard: shard.clone()
}
}
pub async fn send_process_check_rel_exists_request(
&self,
req: CheckRelExistsRequest,
) -> Result<bool, tonic::Status> {
loop {
let unary_pool = Arc::clone(&self.unary_pool);
let pooled_client = unary_pool.get_client().await.unwrap();
let channel = pooled_client.channel();
let mut ps_client = PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard));
let request = proto::CheckRelExistsRequest::from(req.clone());
let response = ps_client.check_rel_exists(tonic::Request::new(request)).await;
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
continue;
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
return Ok(resp.get_ref().exists);
}
}
}
}
pub async fn send_process_get_rel_size_request(
&self,
req: GetRelSizeRequest,
) -> Result<u32, tonic::Status> {
loop {
// Current sharding model assumes that all metadata is present only at shard 0.
let unary_pool = Arc::clone(&self.unary_pool);
let pooled_client = unary_pool.get_client().await.unwrap();
let channel = pooled_client.channel();
let mut ps_client = PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard));
let request = proto::GetRelSizeRequest::from(req.clone());
let response = ps_client.get_rel_size(tonic::Request::new(request)).await;
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
continue;
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
return Ok(resp.get_ref().num_blocks);
}
}
}
}
pub async fn send_process_get_dbsize_request(
&self,
req: GetDbSizeRequest,
) -> Result<u64, tonic::Status> {
loop {
// Current sharding model assumes that all metadata is present only at shard 0.
let unary_pool = Arc::clone(&self.unary_pool);
let pooled_client = unary_pool.get_client().await.unwrap();let channel = pooled_client.channel();
let mut ps_client = PageServiceClient::with_interceptor(channel, self.auth_interceptor.for_shard(self.shard));
let request = proto::GetDbSizeRequest::from(req.clone());
let response = ps_client.get_db_size(tonic::Request::new(request)).await;
match response {
Err(status) => {
pooled_client.finish(Err(status.clone())).await; // Pass error to finish
continue;
}
Ok(resp) => {
pooled_client.finish(Ok(())).await; // Pass success to finish
return Ok(resp.get_ref().num_bytes);
}
}
}
}
pub async fn send_getpage_request(
&mut self,
req: GetPageRequest,
) -> Result<GetPageResponse, tonic::Status> {
loop {
let mut request = req.clone();
// Increment cur_id
//let request_id = self.cur_id.fetch_add(1, Ordering::SeqCst) + 1;
let request_id = request.request_id;
let response_sender: tokio::sync::mpsc::Sender<Result<proto::GetPageResponse, Status>>;
let mut response_receiver: tokio::sync::mpsc::Receiver<Result<proto::GetPageResponse, Status>>;
(response_sender, response_receiver) = tokio::sync::mpsc::channel(1);
//request.request_id = request_id;
// Get a stream from the stream pool
let pool_clone = Arc::clone(&self.stream_pool);
let sender_stream_pool = pool_clone.get_client().await;
let stream_returner = match sender_stream_pool {
Ok(stream_ret) => stream_ret,
Err(_e) => {
// retry
continue;
}
};
let returner = stream_returner.channel();
let map = returner.sender_hashmap.clone();
// Insert the response sender into the hashmap
{
let mut map_inner = map.lock().await;
map_inner.insert(request_id, response_sender);
}
let sent = returner.sender.send(proto::GetPageRequest::from(request))
.await;
if let Err(_e) = sent {
// Remove the request from the map if sending failed
{
let mut map_inner = map.lock().await;
// remove from hashmap
map_inner.remove(&request_id);
}
stream_returner.finish(Err(Status::new(Code::Unknown,
"Failed to send request"))).await;
continue;
}
let response: Option<Result<proto::GetPageResponse, Status>>;
response = response_receiver.recv().await;
match response {
Some (resp) => {
match resp {
Err(_status) => {
// Handle the case where the response was not received
stream_returner.finish(Err(Status::new(Code::Unknown,
"Failed to receive response"))).await;
continue;
},
Ok(resp) => {
stream_returner.finish(Result::Ok(())).await;
return Ok(resp.clone().into());
}
}
}
None => {
// Handle the case where the response channel was closed
stream_returner.finish(Err(Status::new(Code::Unknown,
"Response channel closed"))).await;
continue;
}
}
}
}
}
struct ShardedRequestTrackerInner {
// Hashmap of shard index to RequestTracker
trackers: std::collections::HashMap<ShardIndex, RequestTracker>,
}
pub struct ShardedRequestTracker {
inner: Arc<Mutex<ShardedRequestTrackerInner>>,
tcp_client_cache_options: ClientCacheOptions,
stream_client_cache_options: ClientCacheOptions,
}
//
// TODO: Functions in the ShardedRequestTracker should be able to timeout and
// cancel a reqeust. The request should return an error if it is cancelled.
//
impl ShardedRequestTracker {
pub fn new() -> Self {
//
// Default configuration for the client. These could be added to a config file
//
let tcp_client_cache_options = ClientCacheOptions {
max_delay_ms: 0,
drop_rate: 0.0,
hang_rate: 0.0,
connect_timeout: Duration::from_secs(1),
connect_backoff: Duration::from_millis(100),
max_consumers: 8, // Streams per connection
error_threshold: 10,
max_idle_duration: Duration::from_secs(5),
max_total_connections: 8,
};
let stream_client_cache_options = ClientCacheOptions {
max_delay_ms: 0,
drop_rate: 0.0,
hang_rate: 0.0,
connect_timeout: Duration::from_secs(1),
connect_backoff: Duration::from_millis(100),
max_consumers: 64, // Requests per stream
error_threshold: 10,
max_idle_duration: Duration::from_secs(5),
max_total_connections: 64, // Total allowable number of streams
};
ShardedRequestTracker {
inner: Arc::new(Mutex::new(ShardedRequestTrackerInner {
trackers: std::collections::HashMap::new(),
})),
tcp_client_cache_options,
stream_client_cache_options,
}
}
pub async fn update_shard_map(&self,
shard_urls: std::collections::HashMap<ShardIndex, String>,
metrics: Option<Arc<PageserverClientAggregateMetrics>>,
tenant_id: String, timeline_id: String, auth_str: Option<&str>) {
let mut trackers = std::collections::HashMap::new();
for (shard, endpoint_url) in shard_urls {
//
// Create a pool of streams for streaming get_page requests
//
let channel_fact : Arc<dyn PooledItemFactory<Channel> + Send + Sync> = Arc::new(ChannelFactory::new(
endpoint_url.clone(),
self.tcp_client_cache_options.max_delay_ms,
self.tcp_client_cache_options.drop_rate,
self.tcp_client_cache_options.hang_rate,
));
let new_pool: Arc<ConnectionPool<Channel>>;
new_pool = ConnectionPool::new(
Arc::clone(&channel_fact),
self.tcp_client_cache_options.connect_timeout,
self.tcp_client_cache_options.connect_backoff,
self.tcp_client_cache_options.max_consumers,
self.tcp_client_cache_options.error_threshold,
self.tcp_client_cache_options.max_idle_duration,
self.tcp_client_cache_options.max_total_connections,
metrics.clone(),
);
let auth_interceptor = AuthInterceptor::new(tenant_id.as_str(),
timeline_id.as_str(),
auth_str);
let stream_pool = ConnectionPool::<StreamReturner>::new(
Arc::new(StreamFactory::new(new_pool.clone(),
auth_interceptor.clone(), ShardIndex::unsharded())),
self.stream_client_cache_options.connect_timeout,
self.stream_client_cache_options.connect_backoff,
self.stream_client_cache_options.max_consumers,
self.stream_client_cache_options.error_threshold,
self.stream_client_cache_options.max_idle_duration,
self.stream_client_cache_options.max_total_connections,
metrics.clone(),
);
//
// Create a client pool for unary requests
//
let unary_pool: Arc<ConnectionPool<Channel>>;
unary_pool = ConnectionPool::new(
Arc::clone(&channel_fact),
self.tcp_client_cache_options.connect_timeout,
self.tcp_client_cache_options.connect_backoff,
self.tcp_client_cache_options.max_consumers,
self.tcp_client_cache_options.error_threshold,
self.tcp_client_cache_options.max_idle_duration,
self.tcp_client_cache_options.max_total_connections,
metrics.clone()
);
//
// Create a new RequestTracker for this shard
//
let new_tracker = RequestTracker::new(stream_pool, unary_pool, auth_interceptor, shard);
trackers.insert(shard, new_tracker);
}
let mut inner = self.inner.lock().await;
inner.trackers = trackers;
}
pub async fn get_page(
&self,
req: GetPageRequest,
) -> Result<GetPageResponse, tonic::Status> {
// Get shard index from the request
let shard_index = ShardIndex::unsharded();
let inner = self.inner.lock().await;
let mut tracker : RequestTracker;
if let Some(t) = inner.trackers.get(&shard_index) {
tracker = t.clone();
} else {
return Err(tonic::Status::not_found(format!("Shard {} not found", shard_index)));
}
drop(inner);
// Call the send_getpage_request method on the tracker
let response = tracker.send_getpage_request(req).await;
match response {
Ok(resp) => Ok(resp),
Err(e) => Err(tonic::Status::unknown(format!("Failed to get page: {}", e))),
}
}
pub async fn process_get_dbsize_request(
&self,
request: GetDbSizeRequest,
) -> Result<u64, tonic::Status> {
let shard_index = ShardIndex::unsharded();
let inner = self.inner.lock().await;
let mut tracker: RequestTracker;
if let Some(t) = inner.trackers.get(&shard_index) {
tracker = t.clone();
} else {
return Err(tonic::Status::not_found(format!("Shard {} not found", shard_index)));
}
drop(inner); // Release the lock before calling send_process_get_dbsize_request
// Call the send_process_get_dbsize_request method on the tracker
let response = tracker.send_process_get_dbsize_request(request).await;
match response {
Ok(resp) => Ok(resp),
Err(e) => Err(e),
}
}
pub async fn process_get_rel_size_request(
&self,
request: GetRelSizeRequest,
) -> Result<u32, tonic::Status> {
let shard_index = ShardIndex::unsharded();
let inner = self.inner.lock().await;
let mut tracker: RequestTracker;
if let Some(t) = inner.trackers.get(&shard_index) {
tracker = t.clone();
} else {
return Err(tonic::Status::not_found(format!("Shard {} not found", shard_index)));
}
drop(inner); // Release the lock before calling send_process_get_rel_size_request
// Call the send_process_get_rel_size_request method on the tracker
let response = tracker.send_process_get_rel_size_request(request).await;
match response {
Ok(resp) => Ok(resp),
Err(e) => Err(e),
}
}
pub async fn process_check_rel_exists_request(
&self,
request: CheckRelExistsRequest,
) -> Result<bool, tonic::Status> {
let shard_index = ShardIndex::unsharded();
let inner = self.inner.lock().await;
let mut tracker: RequestTracker;
if let Some(t) = inner.trackers.get(&shard_index) {
tracker = t.clone();
} else {
return Err(tonic::Status::not_found(format!("Shard {} not found", shard_index)));
}
drop(inner); // Release the lock before calling send_process_check_rel_exists_request
// Call the send_process_check_rel_exists_request method on the tracker
let response = tracker.send_process_check_rel_exists_request(request).await;
match response {
Ok(resp) => Ok(resp),
Err(e) => Err(e),
}
}
}

View File

@@ -104,8 +104,8 @@ message CheckRelExistsResponse {
// Requests a base backup at a given LSN.
message GetBaseBackupRequest {
// The LSN to fetch a base backup at. 0 or absent means the latest LSN known to the Pageserver.
uint64 lsn = 1;
// The LSN to fetch a base backup at.
ReadLsn read_lsn = 1;
// If true, logical replication slots will not be created.
bool replica = 2;
}

View File

@@ -185,25 +185,30 @@ impl From<CheckRelExistsResponse> for proto::CheckRelExistsResponse {
/// Requests a base backup at a given LSN.
#[derive(Clone, Copy, Debug)]
pub struct GetBaseBackupRequest {
/// The LSN to fetch a base backup at. If None, uses the latest LSN known to the Pageserver.
pub lsn: Option<Lsn>,
/// The LSN to fetch a base backup at.
pub read_lsn: ReadLsn,
/// If true, logical replication slots will not be created.
pub replica: bool,
}
impl From<proto::GetBaseBackupRequest> for GetBaseBackupRequest {
fn from(pb: proto::GetBaseBackupRequest) -> Self {
Self {
lsn: (pb.lsn != 0).then_some(Lsn(pb.lsn)),
impl TryFrom<proto::GetBaseBackupRequest> for GetBaseBackupRequest {
type Error = ProtocolError;
fn try_from(pb: proto::GetBaseBackupRequest) -> Result<Self, Self::Error> {
Ok(Self {
read_lsn: pb
.read_lsn
.ok_or(ProtocolError::Missing("read_lsn"))?
.try_into()?,
replica: pb.replica,
}
})
}
}
impl From<GetBaseBackupRequest> for proto::GetBaseBackupRequest {
fn from(request: GetBaseBackupRequest) -> Self {
Self {
lsn: request.lsn.unwrap_or_default().0,
read_lsn: Some(request.read_lsn.into()),
replica: request.replica,
}
}
@@ -482,7 +487,6 @@ impl From<GetPageStatusCode> for i32 {
// Fetches the size of a relation at a given LSN, as # of blocks. Only valid on shard 0, other
// shards will error.
#[derive(Clone)]
pub struct GetRelSizeRequest {
pub read_lsn: ReadLsn,
pub rel: RelTag,

View File

@@ -24,13 +24,9 @@ tracing.workspace = true
tokio.workspace = true
tokio-stream.workspace = true
tokio-util.workspace = true
axum.workspace = true
http.workspace = true
metrics.workspace = true
tonic.workspace = true
pageserver_client.workspace = true
pageserver_client_grpc.workspace = true
pageserver_api.workspace = true
pageserver_page_api.workspace = true
utils = { path = "../../libs/utils/" }

View File

@@ -9,15 +9,12 @@ use anyhow::Context;
use pageserver_api::shard::TenantShardId;
use pageserver_client::mgmt_api::ForceAwaitLogicalSize;
use pageserver_client::page_service::BasebackupRequest;
use pageserver_page_api::{GetBaseBackupRequest, ReadLsn};
use rand::prelude::*;
use tokio::sync::Barrier;
use tokio::task::JoinSet;
use tracing::{info, instrument};
use utils::id::TenantTimelineId;
use utils::lsn::Lsn;
use utils::shard::ShardIndex;
use crate::util::tokio_thread_local_stats::AllThreadLocalStats;
use crate::util::{request_stats, tokio_thread_local_stats};
@@ -25,8 +22,6 @@ use crate::util::{request_stats, tokio_thread_local_stats};
/// basebackup@LatestLSN
#[derive(clap::Parser)]
pub(crate) struct Args {
#[clap(long, default_value = "false")]
grpc: bool,
#[clap(long, default_value = "http://localhost:9898")]
mgmt_api_endpoint: String,
#[clap(long, default_value = "postgres://postgres@localhost:64000")]
@@ -57,7 +52,7 @@ impl LiveStats {
struct Target {
timeline: TenantTimelineId,
lsn_range: Range<Lsn>,
lsn_range: Option<Range<Lsn>>,
}
#[derive(serde::Serialize)]
@@ -110,7 +105,7 @@ async fn main_impl(
anyhow::Ok(Target {
timeline,
// TODO: support lsn_range != latest LSN
lsn_range: info.last_record_lsn..(info.last_record_lsn + 1),
lsn_range: Some(info.last_record_lsn..(info.last_record_lsn + 1)),
})
}
});
@@ -154,27 +149,14 @@ async fn main_impl(
for tl in &timelines {
let (sender, receiver) = tokio::sync::mpsc::channel(1); // TODO: not sure what the implications of this are
work_senders.insert(tl, sender);
let client_task = if args.grpc {
tokio::spawn(client_grpc(
args,
*tl,
Arc::clone(&start_work_barrier),
receiver,
Arc::clone(&all_work_done_barrier),
Arc::clone(&live_stats),
))
} else {
tokio::spawn(client(
args,
*tl,
Arc::clone(&start_work_barrier),
receiver,
Arc::clone(&all_work_done_barrier),
Arc::clone(&live_stats),
))
};
tasks.push(client_task);
tasks.push(tokio::spawn(client(
args,
*tl,
Arc::clone(&start_work_barrier),
receiver,
Arc::clone(&all_work_done_barrier),
Arc::clone(&live_stats),
)));
}
let work_sender = async move {
@@ -183,7 +165,7 @@ async fn main_impl(
let (timeline, work) = {
let mut rng = rand::thread_rng();
let target = all_targets.choose(&mut rng).unwrap();
let lsn = rng.gen_range(target.lsn_range.clone());
let lsn = target.lsn_range.clone().map(|r| rng.gen_range(r));
(
target.timeline,
Work {
@@ -233,7 +215,7 @@ async fn main_impl(
#[derive(Copy, Clone)]
struct Work {
lsn: Lsn,
lsn: Option<Lsn>,
gzip: bool,
}
@@ -258,7 +240,7 @@ async fn client(
.basebackup(&BasebackupRequest {
tenant_id: timeline.tenant_id,
timeline_id: timeline.timeline_id,
lsn: Some(lsn),
lsn,
gzip,
})
.await
@@ -288,71 +270,3 @@ async fn client(
all_work_done_barrier.wait().await;
}
#[instrument(skip_all)]
async fn client_grpc(
args: &'static Args,
timeline: TenantTimelineId,
start_work_barrier: Arc<Barrier>,
mut work: tokio::sync::mpsc::Receiver<Work>,
all_work_done_barrier: Arc<Barrier>,
live_stats: Arc<LiveStats>,
) {
let shard_map = HashMap::from([(
ShardIndex::unsharded(),
args.page_service_connstring.clone(),
)]);
let client = pageserver_client_grpc::PageserverClient::new(
&timeline.tenant_id.to_string(),
&timeline.timeline_id.to_string(),
&None,
shard_map,
);
start_work_barrier.wait().await;
while let Some(Work { lsn, gzip }) = work.recv().await {
let start = Instant::now();
//tokio::time::sleep(std::time::Duration::from_secs(1)).await;
info!("starting get_base_backup");
let mut basebackup_stream = client
.get_base_backup(
GetBaseBackupRequest {
lsn: Some(lsn),
replica: false,
},
gzip,
)
.await
.with_context(|| format!("start basebackup for {timeline}"))
.unwrap()
.into_inner();
info!("starting receive");
use futures::StreamExt;
let mut size = 0;
let mut nchunks = 0;
while let Some(chunk) = basebackup_stream.next().await {
let chunk = chunk
.with_context(|| format!("error during basebackup"))
.unwrap();
size += chunk.chunk.len();
nchunks += 1;
}
info!(
"basebackup size is {} bytes, avg chunk size {} bytes",
size,
size as f32 / nchunks as f32
);
let elapsed = start.elapsed();
live_stats.inc();
STATS.with(|stats| {
stats.borrow().lock().unwrap().observe(elapsed).unwrap();
});
}
all_work_done_barrier.wait().await;
}

View File

@@ -1,11 +1,10 @@
use std::collections::{HashSet, HashMap, VecDeque};
use std::collections::{HashMap, HashSet, VecDeque};
use std::future::Future;
use std::num::NonZeroUsize;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use std::io::Error;
use anyhow::Context;
use async_trait::async_trait;
@@ -24,20 +23,6 @@ use tracing::info;
use utils::id::TenantTimelineId;
use utils::lsn::Lsn;
use tonic::transport::Channel;
use axum::Router;
use axum::body::Body;
use axum::extract::State;
use axum::response::Response;
use http::StatusCode;
use http::header::CONTENT_TYPE;
use metrics;
use metrics::proto::MetricFamily;
use metrics::{Encoder, TextEncoder};
use crate::util::tokio_thread_local_stats::AllThreadLocalStats;
use crate::util::{request_stats, tokio_thread_local_stats};
@@ -50,10 +35,6 @@ enum Protocol {
/// GetPage@LatestLSN, uniformly distributed across the compute-accessible keyspace.
#[derive(clap::Parser)]
pub(crate) struct Args {
#[clap(long, default_value = "false")]
grpc: bool,
#[clap(long, default_value = "false")]
grpc_stream: bool,
#[clap(long, default_value = "http://localhost:9898")]
mgmt_api_endpoint: String,
#[clap(long, default_value = "postgres://postgres@localhost:64000")]
@@ -92,9 +73,6 @@ pub(crate) struct Args {
#[clap(long)]
set_io_mode: Option<pageserver_api::models::virtual_file::IoMode>,
#[clap(long)]
only_relnode: Option<u32>,
/// Queue depth generated in each client.
#[clap(long, default_value = "1")]
queue_depth: NonZeroUsize,
@@ -109,31 +87,10 @@ pub(crate) struct Args {
#[clap(long, default_value = "1")]
batch_size: NonZeroUsize,
#[clap(long)]
only_relnode: Option<u32>,
targets: Option<Vec<TenantTimelineId>>,
#[clap(long, default_value = "100")]
pool_max_consumers: NonZeroUsize,
#[clap(long, default_value = "5")]
pool_error_threshold: NonZeroUsize,
#[clap(long, default_value = "5000")]
pool_connect_timeout: NonZeroUsize,
#[clap(long, default_value = "1000")]
pool_connect_backoff: NonZeroUsize,
#[clap(long, default_value = "60000")]
pool_max_idle_duration: NonZeroUsize,
#[clap(long, default_value = "0")]
max_delay_ms: usize,
#[clap(long, default_value = "0")]
percent_drops: usize,
#[clap(long, default_value = "0")]
percent_hangs: usize,
}
/// State shared by all clients
@@ -190,37 +147,6 @@ pub(crate) fn main(args: Args) -> anyhow::Result<()> {
main_impl(args, thread_local_stats)
})
}
async fn get_metrics(
State(state): State<Arc<pageserver_client_grpc::PageserverClientAggregateMetrics>>,
) -> Response {
let metrics = state.collect();
info!("metrics: {metrics:?}");
// When we call TextEncoder::encode() below, it will immediately return an
// error if a metric family has no metrics, so we need to preemptively
// filter out metric families with no metrics.
let metrics = metrics
.into_iter()
.filter(|m| !m.get_metric().is_empty())
.collect::<Vec<MetricFamily>>();
let encoder = TextEncoder::new();
let mut buffer = vec![];
if let Err(e) = encoder.encode(&metrics, &mut buffer) {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header(CONTENT_TYPE, "application/text")
.body(Body::from(e.to_string()))
.unwrap()
} else {
Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, encoder.format_type())
.body(Body::from(buffer))
.unwrap()
}
}
async fn main_impl(
args: Args,
@@ -228,24 +154,6 @@ async fn main_impl(
) -> anyhow::Result<()> {
let args: &'static Args = Box::leak(Box::new(args));
// Vector of pageserver clients
let client_metrics = Arc::new(pageserver_client_grpc::PageserverClientAggregateMetrics::new());
use axum::routing::get;
let app = Router::new()
.route("/metrics", get(get_metrics))
.with_state(client_metrics.clone());
// TODO: make configurable. Or listen on unix domain socket?
let listener = tokio::net::TcpListener::bind("127.0.0.1:9090")
.await
.unwrap();
tokio::spawn(async {
tracing::info!("metrics listener spawned");
axum::serve(listener, app).await.unwrap()
});
let mgmt_api_client = Arc::new(pageserver_client::mgmt_api::Client::new(
reqwest::Client::new(), // TODO: support ssl_ca_file for https APIs in pagebench.
args.mgmt_api_endpoint.clone(),
@@ -404,7 +312,6 @@ async fn main_impl(
let rps_period = args
.per_client_rate
.map(|rps_limit| Duration::from_secs_f64(1.0 / (rps_limit as f64)));
let make_worker: &dyn Fn(WorkerId) -> Pin<Box<dyn Send + Future<Output = ()>>> = &|worker_id| {
let ss = shared_state.clone();
let cancel = cancel.clone();
@@ -430,7 +337,6 @@ async fn main_impl(
.await
.unwrap(),
),
};
run_worker(args, client, ss, cancel, rps_period, ranges, weights).await
})
@@ -698,7 +604,6 @@ impl Client for LibpqClient {
struct GrpcClient {
req_tx: tokio::sync::mpsc::Sender<proto::GetPageRequest>,
resp_rx: tonic::Streaming<proto::GetPageResponse>,
start_times: Vec<Instant>,
}
impl GrpcClient {
@@ -722,7 +627,6 @@ impl GrpcClient {
Ok(Self {
req_tx,
resp_rx: resp_stream,
start_times: Vec::new(),
})
}
}
@@ -747,7 +651,6 @@ impl Client for GrpcClient {
rel: Some(rel.into()),
block_number: blks,
};
self.start_times.push(Instant::now());
self.req_tx.send(req).await?;
Ok(())
}
@@ -762,4 +665,3 @@ impl Client for GrpcClient {
Ok((resp.request_id, resp.page_image))
}
}

View File

@@ -23,6 +23,7 @@ use pageserver::deletion_queue::DeletionQueue;
use pageserver::disk_usage_eviction_task::{self, launch_disk_usage_global_eviction_task};
use pageserver::feature_resolver::FeatureResolver;
use pageserver::metrics::{STARTUP_DURATION, STARTUP_IS_LOADING};
use pageserver::page_service::GrpcPageServiceHandler;
use pageserver::task_mgr::{
BACKGROUND_RUNTIME, COMPUTE_REQUEST_RUNTIME, MGMT_REQUEST_RUNTIME, WALRECEIVER_RUNTIME,
};
@@ -572,7 +573,8 @@ fn start_pageserver(
tokio::sync::mpsc::unbounded_channel();
let deletion_queue_client = deletion_queue.new_client();
let background_purges = mgr::BackgroundPurges::default();
let tenant_manager = BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(
let tenant_manager = mgr::init(
conf,
background_purges.clone(),
TenantSharedResources {
@@ -583,10 +585,10 @@ fn start_pageserver(
basebackup_prepare_sender,
feature_resolver,
},
order,
shutdown_pageserver.clone(),
))?;
);
let tenant_manager = Arc::new(tenant_manager);
BACKGROUND_RUNTIME.block_on(mgr::init_tenant_mgr(tenant_manager.clone(), order))?;
let basebackup_cache = BasebackupCache::spawn(
BACKGROUND_RUNTIME.handle(),
@@ -814,7 +816,7 @@ fn start_pageserver(
// necessary?
let mut page_service_grpc = None;
if let Some(grpc_listener) = grpc_listener {
page_service_grpc = Some(page_service::spawn_grpc(
page_service_grpc = Some(GrpcPageServiceHandler::spawn(
tenant_manager.clone(),
grpc_auth,
otel_guard.as_ref().map(|g| g.dispatch.clone()),

View File

@@ -195,8 +195,6 @@ impl StorageControllerUpcallApi for StorageControllerUpcallClient {
node_id: conf.id,
listen_pg_addr: m.postgres_host,
listen_pg_port: m.postgres_port,
listen_grpc_addr: m.grpc_host,
listen_grpc_port: m.grpc_port,
listen_http_addr: m.http_host,
listen_http_port: m.http_port,
listen_https_port: m.https_port,

View File

@@ -1,5 +1,6 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use pageserver_api::config::NodeMetadata;
use posthog_client_lite::{
CaptureEvent, FeatureResolverBackgroundLoop, PostHogClientConfig, PostHogEvaluationError,
PostHogFlagFilterPropertyValue,
@@ -86,7 +87,35 @@ impl FeatureResolver {
}
}
}
// TODO: add pageserver URL.
// TODO: move this to a background task so that we don't block startup in case of slow disk
let metadata_path = conf.metadata_path();
match std::fs::read_to_string(&metadata_path) {
Ok(metadata_str) => match serde_json::from_str::<NodeMetadata>(&metadata_str) {
Ok(metadata) => {
properties.insert(
"hostname".to_string(),
PostHogFlagFilterPropertyValue::String(metadata.http_host),
);
if let Some(cplane_region) = metadata.other.get("region_id") {
if let Some(cplane_region) = cplane_region.as_str() {
// This region contains the cell number
properties.insert(
"neon_region".to_string(),
PostHogFlagFilterPropertyValue::String(
cplane_region.to_string(),
),
);
}
}
}
Err(e) => {
tracing::warn!("Failed to parse metadata.json: {}", e);
}
},
Err(e) => {
tracing::warn!("Failed to read metadata.json: {}", e);
}
}
Arc::new(properties)
};
let fake_tenants = {

View File

@@ -1053,6 +1053,15 @@ pub(crate) static TENANT_STATE_METRIC: Lazy<UIntGaugeVec> = Lazy::new(|| {
.expect("Failed to register pageserver_tenant_states_count metric")
});
pub(crate) static TIMELINE_STATE_METRIC: Lazy<UIntGaugeVec> = Lazy::new(|| {
register_uint_gauge_vec!(
"pageserver_timeline_states_count",
"Count of timelines per state",
&["state"]
)
.expect("Failed to register pageserver_timeline_states_count metric")
});
/// A set of broken tenants.
///
/// These are expected to be so rare that a set is fine. Set as in a new timeseries per each broken
@@ -3325,6 +3334,8 @@ impl TimelineMetrics {
&timeline_id,
);
TIMELINE_STATE_METRIC.with_label_values(&["active"]).inc();
TimelineMetrics {
tenant_id,
shard_id,
@@ -3479,6 +3490,8 @@ impl TimelineMetrics {
return;
}
TIMELINE_STATE_METRIC.with_label_values(&["active"]).dec();
let tenant_id = &self.tenant_id;
let timeline_id = &self.timeline_id;
let shard_id = &self.shard_id;

View File

@@ -14,7 +14,7 @@ use std::{io, str};
use anyhow::{Context as _, anyhow, bail};
use async_compression::tokio::write::GzipEncoder;
use bytes::{Buf, BufMut as _, BytesMut};
use bytes::{Buf, BytesMut};
use futures::future::BoxFuture;
use futures::{FutureExt, Stream};
use itertools::Itertools;
@@ -169,99 +169,6 @@ pub fn spawn(
Listener { cancel, task }
}
/// Spawns a gRPC server for the page service.
///
/// TODO: move this onto GrpcPageServiceHandler::spawn().
/// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we
/// need to reimplement the TCP+TLS accept loop ourselves.
pub fn spawn_grpc(
tenant_manager: Arc<TenantManager>,
auth: Option<Arc<SwappableJwtAuth>>,
perf_trace_dispatch: Option<Dispatch>,
get_vectored_concurrent_io: GetVectoredConcurrentIo,
listener: std::net::TcpListener,
) -> anyhow::Result<CancellableTask> {
let cancel = CancellationToken::new();
let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler)
.download_behavior(DownloadBehavior::Download)
.perf_span_dispatch(perf_trace_dispatch)
.detached_child();
let gate = Gate::default();
// Set up the TCP socket. We take a preconfigured TcpListener to bind the
// port early during startup.
let incoming = {
let _runtime = COMPUTE_REQUEST_RUNTIME.enter(); // required by TcpListener::from_std
listener.set_nonblocking(true)?;
tonic::transport::server::TcpIncoming::from(tokio::net::TcpListener::from_std(listener)?)
.with_nodelay(Some(GRPC_TCP_NODELAY))
.with_keepalive(Some(GRPC_TCP_KEEPALIVE_TIME))
};
// Set up the gRPC server.
//
// TODO: consider tuning window sizes.
let mut server = tonic::transport::Server::builder()
.http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL))
.http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT))
.max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS));
// Main page service stack. Uses a mix of Tonic interceptors and Tower layers:
//
// * Interceptors: can inspect and modify the gRPC request. Sync code only, runs before service.
//
// * Layers: allow async code, can run code after the service response. However, only has access
// to the raw HTTP request/response, not the gRPC types.
let page_service_handler = GrpcPageServiceHandler {
tenant_manager,
ctx,
gate_guard: gate.enter().expect("gate was just created"),
get_vectored_concurrent_io,
};
let observability_layer = ObservabilityLayer;
let mut tenant_interceptor = TenantMetadataInterceptor;
let mut auth_interceptor = TenantAuthInterceptor::new(auth);
let page_service = tower::ServiceBuilder::new()
// Create tracing span and record request start time.
.layer(observability_layer)
// Intercept gRPC requests.
.layer(tonic::service::InterceptorLayer::new(move |mut req| {
// Extract tenant metadata.
req = tenant_interceptor.call(req)?;
// Authenticate tenant JWT token.
req = auth_interceptor.call(req)?;
Ok(req)
}))
.service(proto::PageServiceServer::new(page_service_handler));
let server = server.add_service(page_service);
// Reflection service for use with e.g. grpcurl.
let reflection_service = tonic_reflection::server::Builder::configure()
.register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET)
.build_v1()?;
let server = server.add_service(reflection_service);
// Spawn server task.
let task_cancel = cancel.clone();
let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error(
"grpc listener",
async move {
let result = server
.serve_with_incoming_shutdown(incoming, task_cancel.cancelled())
.await;
if result.is_ok() {
// TODO: revisit shutdown logic once page service is implemented.
gate.close().await;
}
result
},
));
Ok(CancellableTask { task, cancel })
}
impl Listener {
pub async fn stop_accepting(self) -> Connections {
self.cancel.cancel();
@@ -3366,6 +3273,101 @@ pub struct GrpcPageServiceHandler {
}
impl GrpcPageServiceHandler {
/// Spawns a gRPC server for the page service.
///
/// TODO: this doesn't support TLS. We need TLS reloading via ReloadingCertificateResolver, so we
/// need to reimplement the TCP+TLS accept loop ourselves.
pub fn spawn(
tenant_manager: Arc<TenantManager>,
auth: Option<Arc<SwappableJwtAuth>>,
perf_trace_dispatch: Option<Dispatch>,
get_vectored_concurrent_io: GetVectoredConcurrentIo,
listener: std::net::TcpListener,
) -> anyhow::Result<CancellableTask> {
let cancel = CancellationToken::new();
let ctx = RequestContextBuilder::new(TaskKind::PageRequestHandler)
.download_behavior(DownloadBehavior::Download)
.perf_span_dispatch(perf_trace_dispatch)
.detached_child();
let gate = Gate::default();
// Set up the TCP socket. We take a preconfigured TcpListener to bind the
// port early during startup.
let incoming = {
let _runtime = COMPUTE_REQUEST_RUNTIME.enter(); // required by TcpListener::from_std
listener.set_nonblocking(true)?;
tonic::transport::server::TcpIncoming::from(tokio::net::TcpListener::from_std(
listener,
)?)
.with_nodelay(Some(GRPC_TCP_NODELAY))
.with_keepalive(Some(GRPC_TCP_KEEPALIVE_TIME))
};
// Set up the gRPC server.
//
// TODO: consider tuning window sizes.
let mut server = tonic::transport::Server::builder()
.http2_keepalive_interval(Some(GRPC_HTTP2_KEEPALIVE_INTERVAL))
.http2_keepalive_timeout(Some(GRPC_HTTP2_KEEPALIVE_TIMEOUT))
.max_concurrent_streams(Some(GRPC_MAX_CONCURRENT_STREAMS));
// Main page service stack. Uses a mix of Tonic interceptors and Tower layers:
//
// * Interceptors: can inspect and modify the gRPC request. Sync code only, runs before service.
//
// * Layers: allow async code, can run code after the service response. However, only has access
// to the raw HTTP request/response, not the gRPC types.
let page_service_handler = GrpcPageServiceHandler {
tenant_manager,
ctx,
gate_guard: gate.enter().expect("gate was just created"),
get_vectored_concurrent_io,
};
let observability_layer = ObservabilityLayer;
let mut tenant_interceptor = TenantMetadataInterceptor;
let mut auth_interceptor = TenantAuthInterceptor::new(auth);
let page_service = tower::ServiceBuilder::new()
// Create tracing span and record request start time.
.layer(observability_layer)
// Intercept gRPC requests.
.layer(tonic::service::InterceptorLayer::new(move |mut req| {
// Extract tenant metadata.
req = tenant_interceptor.call(req)?;
// Authenticate tenant JWT token.
req = auth_interceptor.call(req)?;
Ok(req)
}))
// Run the page service.
.service(proto::PageServiceServer::new(page_service_handler));
let server = server.add_service(page_service);
// Reflection service for use with e.g. grpcurl.
let reflection_service = tonic_reflection::server::Builder::configure()
.register_encoded_file_descriptor_set(proto::FILE_DESCRIPTOR_SET)
.build_v1()?;
let server = server.add_service(reflection_service);
// Spawn server task.
let task_cancel = cancel.clone();
let task = COMPUTE_REQUEST_RUNTIME.spawn(task_mgr::exit_on_panic_or_error(
"grpc listener",
async move {
let result = server
.serve_with_incoming_shutdown(incoming, task_cancel.cancelled())
.await;
if result.is_ok() {
// TODO: revisit shutdown logic once page service is implemented.
gate.close().await;
}
result
},
));
Ok(CancellableTask { task, cancel })
}
/// Errors if the request is executed on a non-zero shard. Only shard 0 has a complete view of
/// relations and their sizes, as well as SLRU segments and similar data.
#[allow(clippy::result_large_err)]
@@ -3606,38 +3608,35 @@ impl proto::PageService for GrpcPageServiceHandler {
if timeline.is_archived() == Some(true) {
return Err(tonic::Status::failed_precondition("timeline is archived"));
}
let req: page_api::GetBaseBackupRequest = req.into_inner().into();
let req: page_api::GetBaseBackupRequest = req.into_inner().try_into()?;
span_record!(lsn=?req.lsn);
span_record!(lsn=%req.read_lsn);
if let Some(lsn) = req.lsn {
let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn();
timeline
.wait_lsn(
lsn,
WaitLsnWaiter::PageService,
WaitLsnTimeout::Default,
&ctx,
)
.await?;
timeline
.check_lsn_is_in_scope(lsn, &latest_gc_cutoff_lsn)
.map_err(|err| {
tonic::Status::invalid_argument(format!("invalid basebackup LSN: {err}"))
})?;
}
let latest_gc_cutoff_lsn = timeline.get_applied_gc_cutoff_lsn();
timeline
.wait_lsn(
req.read_lsn.request_lsn,
WaitLsnWaiter::PageService,
WaitLsnTimeout::Default,
&ctx,
)
.await?;
timeline
.check_lsn_is_in_scope(req.read_lsn.request_lsn, &latest_gc_cutoff_lsn)
.map_err(|err| {
tonic::Status::invalid_argument(format!("invalid basebackup LSN: {err}"))
})?;
// Spawn a task to run the basebackup.
//
// TODO: do we need to support full base backups, for debugging? This also requires passing
// the prev_lsn parameter.
// TODO: do we need to support full base backups, for debugging?
let span = Span::current();
let (mut simplex_read, mut simplex_write) = tokio::io::simplex(CHUNK_SIZE);
let jh = tokio::spawn(async move {
let result = basebackup::send_basebackup_tarball(
&mut simplex_write,
&timeline,
req.lsn,
Some(req.read_lsn.request_lsn),
None,
false,
req.replica,
@@ -3653,21 +3652,20 @@ impl proto::PageService for GrpcPageServiceHandler {
// Emit chunks of size CHUNK_SIZE.
let chunks = async_stream::try_stream! {
let mut chunk = BytesMut::with_capacity(CHUNK_SIZE);
loop {
let mut chunk = BytesMut::with_capacity(CHUNK_SIZE).limit(CHUNK_SIZE);
loop {
let n = simplex_read.read_buf(&mut chunk).await.map_err(|err| {
tonic::Status::internal(format!("failed to read basebackup chunk: {err}"))
})?;
if n == 0 {
break; // full chunk or closed stream
let n = simplex_read.read_buf(&mut chunk).await.map_err(|err| {
tonic::Status::internal(format!("failed to read basebackup chunk: {err}"))
})?;
// If we read 0 bytes, either the chunk is full or the stream is closed.
if n == 0 {
if chunk.is_empty() {
break;
}
yield proto::GetBaseBackupResponseChunk::from(chunk.clone().freeze());
chunk.clear();
}
let chunk = chunk.into_inner().freeze();
if chunk.is_empty() {
break;
}
yield proto::GetBaseBackupResponseChunk::from(chunk);
}
// Wait for the basebackup task to exit and check for errors.
jh.await.map_err(|err| {

View File

@@ -89,7 +89,8 @@ use crate::l0_flush::L0FlushGlobalState;
use crate::metrics::{
BROKEN_TENANTS_SET, CIRCUIT_BREAKERS_BROKEN, CIRCUIT_BREAKERS_UNBROKEN, CONCURRENT_INITDBS,
INITDB_RUN_TIME, INITDB_SEMAPHORE_ACQUISITION_TIME, TENANT, TENANT_OFFLOADED_TIMELINES,
TENANT_STATE_METRIC, TENANT_SYNTHETIC_SIZE_METRIC, remove_tenant_metrics,
TENANT_STATE_METRIC, TENANT_SYNTHETIC_SIZE_METRIC, TIMELINE_STATE_METRIC,
remove_tenant_metrics,
};
use crate::task_mgr::TaskKind;
use crate::tenant::config::LocationMode;
@@ -544,6 +545,28 @@ pub struct OffloadedTimeline {
/// Part of the `OffloadedTimeline` object's lifecycle: this needs to be set before we drop it
pub deleted_from_ancestor: AtomicBool,
_metrics_guard: OffloadedTimelineMetricsGuard,
}
/// Increases the offloaded timeline count metric when created, and decreases when dropped.
struct OffloadedTimelineMetricsGuard;
impl OffloadedTimelineMetricsGuard {
fn new() -> Self {
TIMELINE_STATE_METRIC
.with_label_values(&["offloaded"])
.inc();
Self
}
}
impl Drop for OffloadedTimelineMetricsGuard {
fn drop(&mut self) {
TIMELINE_STATE_METRIC
.with_label_values(&["offloaded"])
.dec();
}
}
impl OffloadedTimeline {
@@ -576,6 +599,8 @@ impl OffloadedTimeline {
delete_progress: timeline.delete_progress.clone(),
deleted_from_ancestor: AtomicBool::new(false),
_metrics_guard: OffloadedTimelineMetricsGuard::new(),
})
}
fn from_manifest(tenant_shard_id: TenantShardId, manifest: &OffloadedTimelineManifest) -> Self {
@@ -595,6 +620,7 @@ impl OffloadedTimeline {
archived_at,
delete_progress: TimelineDeleteProgress::default(),
deleted_from_ancestor: AtomicBool::new(false),
_metrics_guard: OffloadedTimelineMetricsGuard::new(),
}
}
fn manifest(&self) -> OffloadedTimelineManifest {

File diff suppressed because it is too large Load Diff

View File

@@ -1055,8 +1055,8 @@ pub(crate) enum WaitLsnWaiter<'a> {
/// Argument to [`Timeline::shutdown`].
#[derive(Debug, Clone, Copy)]
pub(crate) enum ShutdownMode {
/// Graceful shutdown, may do a lot of I/O as we flush any open layers to disk and then
/// also to remote storage. This method can easily take multiple seconds for a busy timeline.
/// Graceful shutdown, may do a lot of I/O as we flush any open layers to disk. This method can
/// take multiple seconds for a busy timeline.
///
/// While we are flushing, we continue to accept read I/O for LSNs ingested before
/// the call to [`Timeline::shutdown`].

View File

@@ -1,10 +1,10 @@
# pgxs/neon/Makefile
MODULE_big = neon
OBJS = \
$(WIN32RES) \
communicator.o \
communicator_new.o \
extension_server.o \
file_cache.o \
hll.o \
@@ -22,13 +22,11 @@ OBJS = \
walproposer.o \
walproposer_pg.o \
control_plane_connector.o \
walsender_hooks.o \
$(LIBCOMMUNICATOR_PATH)/libcommunicator.a
walsender_hooks.o
PG_CPPFLAGS = -I$(libpq_srcdir)
SHLIB_LINK_INTERNAL = $(libpq)
SHLIB_LINK = -lcurl
SHLIB_LINK += -framework Security -framework CoreFoundation -framework CoreServices -framework IOKit -framework SystemConfiguration
EXTENSION = neon
DATA = \

View File

@@ -1092,13 +1092,15 @@ communicator_prefetch_register_bufferv(BufferTag tag, neon_request_lsns *frlsns,
MyPState->ring_last <= ring_index);
}
/* internal version. Returns the ring index */
/* Internal version. Returns the ring index of the last block (result of this function is used only
* when nblocks==1)
*/
static uint64
prefetch_register_bufferv(BufferTag tag, neon_request_lsns *frlsns,
BlockNumber nblocks, const bits8 *mask,
bool is_prefetch)
{
uint64 min_ring_index;
uint64 last_ring_index;
PrefetchRequest hashkey;
#ifdef USE_ASSERT_CHECKING
bool any_hits = false;
@@ -1122,13 +1124,12 @@ Retry:
MyPState->ring_unused - MyPState->ring_receive;
MyNeonCounters->getpage_prefetches_buffered =
MyPState->n_responses_buffered;
last_ring_index = UINT64_MAX;
min_ring_index = UINT64_MAX;
for (int i = 0; i < nblocks; i++)
{
PrefetchRequest *slot = NULL;
PrfHashEntry *entry = NULL;
uint64 ring_index;
neon_request_lsns *lsns;
if (PointerIsValid(mask) && BITMAP_ISSET(mask, i))
@@ -1152,12 +1153,12 @@ Retry:
if (entry != NULL)
{
slot = entry->slot;
ring_index = slot->my_ring_index;
Assert(slot == GetPrfSlot(ring_index));
last_ring_index = slot->my_ring_index;
Assert(slot == GetPrfSlot(last_ring_index));
Assert(slot->status != PRFS_UNUSED);
Assert(MyPState->ring_last <= ring_index &&
ring_index < MyPState->ring_unused);
Assert(MyPState->ring_last <= last_ring_index &&
last_ring_index < MyPState->ring_unused);
Assert(BufferTagsEqual(&slot->buftag, &hashkey.buftag));
/*
@@ -1169,9 +1170,9 @@ Retry:
if (!neon_prefetch_response_usable(lsns, slot))
{
/* Wait for the old request to finish and discard it */
if (!prefetch_wait_for(ring_index))
if (!prefetch_wait_for(last_ring_index))
goto Retry;
prefetch_set_unused(ring_index);
prefetch_set_unused(last_ring_index);
entry = NULL;
slot = NULL;
pgBufferUsage.prefetch.expired += 1;
@@ -1188,13 +1189,12 @@ Retry:
*/
if (slot->status == PRFS_TAG_REMAINS)
{
prefetch_set_unused(ring_index);
prefetch_set_unused(last_ring_index);
entry = NULL;
slot = NULL;
}
else
{
min_ring_index = Min(min_ring_index, ring_index);
/* The buffered request is good enough, return that index */
if (is_prefetch)
pgBufferUsage.prefetch.duplicates++;
@@ -1283,12 +1283,12 @@ Retry:
* The next buffer pointed to by `ring_unused` is now definitely empty, so
* we can insert the new request to it.
*/
ring_index = MyPState->ring_unused;
last_ring_index = MyPState->ring_unused;
Assert(MyPState->ring_last <= ring_index &&
ring_index <= MyPState->ring_unused);
Assert(MyPState->ring_last <= last_ring_index &&
last_ring_index <= MyPState->ring_unused);
slot = GetPrfSlotNoCheck(ring_index);
slot = GetPrfSlotNoCheck(last_ring_index);
Assert(slot->status == PRFS_UNUSED);
@@ -1298,11 +1298,9 @@ Retry:
*/
slot->buftag = hashkey.buftag;
slot->shard_no = get_shard_number(&tag);
slot->my_ring_index = ring_index;
slot->my_ring_index = last_ring_index;
slot->flags = 0;
min_ring_index = Min(min_ring_index, ring_index);
if (is_prefetch)
MyNeonCounters->getpage_prefetch_requests_total++;
else
@@ -1315,11 +1313,12 @@ Retry:
MyPState->ring_unused - MyPState->ring_receive;
Assert(any_hits);
Assert(last_ring_index != UINT64_MAX);
Assert(GetPrfSlot(min_ring_index)->status == PRFS_REQUESTED ||
GetPrfSlot(min_ring_index)->status == PRFS_RECEIVED);
Assert(MyPState->ring_last <= min_ring_index &&
min_ring_index < MyPState->ring_unused);
Assert(GetPrfSlot(last_ring_index)->status == PRFS_REQUESTED ||
GetPrfSlot(last_ring_index)->status == PRFS_RECEIVED);
Assert(MyPState->ring_last <= last_ring_index &&
last_ring_index < MyPState->ring_unused);
if (flush_every_n_requests > 0 &&
MyPState->ring_unused - MyPState->ring_flush >= flush_every_n_requests)
@@ -1335,7 +1334,7 @@ Retry:
MyPState->ring_flush = MyPState->ring_unused;
}
return min_ring_index;
return last_ring_index;
}
static bool

View File

@@ -1,372 +0,0 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "addr2line"
version = "0.24.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1"
dependencies = [
"gimli",
]
[[package]]
name = "adler2"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627"
[[package]]
name = "backtrace"
version = "0.3.74"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a"
dependencies = [
"addr2line",
"cfg-if",
"libc",
"miniz_oxide",
"object",
"rustc-demangle",
"windows-targets",
]
[[package]]
name = "base64"
version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "bytes"
version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a"
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "communicator"
version = "0.1.0"
dependencies = [
"tonic",
]
[[package]]
name = "fnv"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "futures-core"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
[[package]]
name = "gimli"
version = "0.31.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "http"
version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565"
dependencies = [
"bytes",
"fnv",
"itoa",
]
[[package]]
name = "http-body"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184"
dependencies = [
"bytes",
"http",
]
[[package]]
name = "http-body-util"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
dependencies = [
"bytes",
"futures-core",
"http",
"http-body",
"pin-project-lite",
]
[[package]]
name = "itoa"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
[[package]]
name = "libc"
version = "0.2.171"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6"
[[package]]
name = "memchr"
version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3"
[[package]]
name = "miniz_oxide"
version = "0.8.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff70ce3e48ae43fa075863cef62e8b43b71a4f2382229920e0df362592919430"
dependencies = [
"adler2",
]
[[package]]
name = "object"
version = "0.36.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87"
dependencies = [
"memchr",
]
[[package]]
name = "once_cell"
version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "percent-encoding"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
[[package]]
name = "pin-project"
version = "1.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "pin-project-lite"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b"
[[package]]
name = "proc-macro2"
version = "1.0.94"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d"
dependencies = [
"proc-macro2",
]
[[package]]
name = "rustc-demangle"
version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "syn"
version = "2.0.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "tokio"
version = "1.44.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6b88822cbe49de4185e3a4cbf8321dd487cf5fe0c5c65695fef6346371e9c48"
dependencies = [
"backtrace",
"pin-project-lite",
]
[[package]]
name = "tokio-stream"
version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047"
dependencies = [
"futures-core",
"pin-project-lite",
"tokio",
]
[[package]]
name = "tonic"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85839f0b32fd242bb3209262371d07feda6d780d16ee9d2bc88581b89da1549b"
dependencies = [
"base64",
"bytes",
"http",
"http-body",
"http-body-util",
"percent-encoding",
"pin-project",
"tokio-stream",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tower-layer"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e"
[[package]]
name = "tower-service"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]]
name = "tracing"
version = "0.1.41"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0"
dependencies = [
"pin-project-lite",
"tracing-attributes",
"tracing-core",
]
[[package]]
name = "tracing-attributes"
version = "0.1.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tracing-core"
version = "0.1.33"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c"
dependencies = [
"once_cell",
]
[[package]]
name = "unicode-ident"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
[[package]]
name = "windows-targets"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_gnullvm",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
[[package]]
name = "windows_aarch64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
[[package]]
name = "windows_i686_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
[[package]]
name = "windows_i686_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
[[package]]
name = "windows_x86_64_gnu"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
[[package]]
name = "windows_x86_64_msvc"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"

View File

@@ -1,46 +0,0 @@
[package]
name = "communicator"
version = "0.1.0"
edition = "2024"
[features]
testing = []
[lib]
crate-type = ["staticlib"]
[dependencies]
axum.workspace = true
bytes.workspace = true
clashmap.workspace = true
http.workspace = true
libc.workspace = true
nix.workspace = true
atomic_enum = "0.3.0"
prometheus.workspace = true
prost.workspace = true
tonic = { version = "0.12.0", default-features = false, features=["codegen", "prost", "transport"] }
tokio = { version = "1.43.1", features = ["macros", "net", "io-util", "rt", "rt-multi-thread"] }
tokio-pipe = { version = "0.2.12" }
thiserror.workspace = true
tracing.workspace = true
tracing-subscriber.workspace = true
metrics.workspace = true
uring-common = { workspace = true, features = ["bytes"] }
pageserver_client_grpc.workspace = true
pageserver_page_api.workspace = true
neon-shmem.workspace = true
utils.workspace = true
[build-dependencies]
cbindgen.workspace = true
[target.'cfg(target_os = "macos")']
rustflags = [
"-l", "framework=Security",
"-l", "framework=CoreFoundation",
"-l", "framework=CoreServices",
]

View File

@@ -1,123 +0,0 @@
# Communicator
This package provides the so-called "compute-pageserver communicator",
or just "communicator" in short. It runs in a PostgreSQL server, as
part of the neon extension, and handles the communication with the
pageservers. On the PostgreSQL side, the glue code in pgxn/neon/ uses
the communicator to implement the PostgreSQL Storage Manager (SMGR)
interface.
## Design criteria
- Low latency
- Saturate a 10 Gbit / s network interface without becoming a bottleneck
## Source code view
pgxn/neon/communicator_new.c
Contains the glue that interact with PostgreSQL code and the Rust
communicator code.
pgxn/neon/communicator/src/backend_interface.rs
The entry point for calls from each backend.
pgxn/neon/communicator/src/init.rs
Initialization at server startup
pgxn/neon/communicator/src/worker_process/
Worker process main loop and glue code
At compilation time, pgxn/neon/communicator/ produces a static
library, libcommunicator.a. It is linked to the neon.so extension
library.
The real networking code, which is independent of PostgreSQL, is in
the pageserver/client_grpc crate.
## Process view
The communicator runs in a dedicated background worker process, the
"communicator process". The communicator uses a multi-threaded Tokio
runtime to execute the IO requests. So the communicator process has
multiple threads running. That's unusual for Postgres processes and
care must be taken to make that work.
### Backend <-> worker communication
Each backend has a number of I/O request slots in shared memory. The
slots are statically allocated for each backend, and must not be
accessed by other backends. The worker process reads requests from the
shared memory slots, and writes responses back to the slots.
To submit an IO request, first pick one of your backend's free slots,
and write the details of the IO request in the slot. Finally, update
the 'state' field of the slot to Submitted. That informs the worker
process that it can start processing the request. Once the state has
been set to Submitted, the backend *must not* access the slot anymore,
until the worker process sets its state to 'Completed'. In other
words, each slot is owned by either the backend or the worker process
at all times, and the 'state' field indicates who has ownership at the
moment.
To inform the worker process that a request slot has a pending IO
request, there's a pipe shared by the worker process and all backend
processes. After you have changed the slot's state to Submitted, write
the index of the request slot to the pipe. This wakes up the worker
process.
(Note that the pipe is just used for wakeups, but the worker process
is free to pick up Submitted IO requests even without receiving the
wakeup. As of this writing, it doesn't do that, but it might be useful
in the future to reduce latency even further, for example.)
When the worker process has completed processing the request, it
writes the result back in the request slot. A GetPage request can also
contain a pointer to buffer in the shared buffer cache. In that case,
the worker process writes the resulting page contents directly to the
buffer, and just a result code in the request slot. It then updates
the 'state' field to Completed, which passes the owner ship back to
the originating backend. Finally, it signals the process Latch of the
originating backend, waking it up.
### Differences between PostgreSQL v16, v17 and v18
PostgreSQL v18 introduced the new AIO mechanism. The PostgreSQL AIO
mechanism uses a very similar mechanism as described in the previous
section, for the communication between AIO worker processes and
backends. With our communicator, the AIO worker processes are not
used, but we use the same PgAioHandle request slots as in upstream.
For Neon-specific IO requests like GetDbSize, a neon request slot is
used. But for the actual IO requests, the request slot merely contains
a pointer to the PgAioHandle slot. The worker process updates the
status of that, calls the IO callbacks upon completionetc, just like
the upstream AIO worker processes do.
## Sequence diagram
neon
PostgreSQL extension backend_interface.rs worker_process.rs processor tonic
| . . . .
| smgr_read() . . . .
+-------------> + . . .
. | . . .
. | rcommunicator_ . . .
. | get_page_at_lsn . . .
. +------------------> + . .
| . .
| write request to . . .
| slot . .
| . .
| . .
| submit_request() . .
+-----------------> + .
| | .
| | db_size_request . .
+---------------->.
. TODO
### Compute <-> pageserver protocol
The protocol between Compute and the pageserver is based on gRPC. See `protos/`.

View File

@@ -1,26 +0,0 @@
use std::env;
fn main() -> Result<(), Box<dyn std::error::Error>> {
let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
println!("cargo:rustc-link-lib=framework=Security");
println!("cargo:rustc-link-lib=framework=CoreFoundation");
println!("cargo:rustc-link-lib=framework=CoreServices");
println!("cargo:rustc-link-lib=framework=IOKit");
cbindgen::generate(crate_dir).map_or_else(
|error| match error {
cbindgen::Error::ParseSyntaxError { .. } => {
// This means there was a syntax error in the Rust sources. Don't panic, because
// we want the build to continue and the Rust compiler to hit the error. The
// Rust compiler produces a better error message than cbindgen.
eprintln!("Generating C bindings failed because of a Rust syntax error");
}
e => panic!("Unable to generate C bindings: {:?}", e),
},
|bindings| {
bindings.write_to_file("communicator_bindings.h");
},
);
Ok(())
}

View File

@@ -1,4 +0,0 @@
language = "C"
[enum]
prefix_with_name = true

View File

@@ -1,204 +0,0 @@
//! This module implements a request/response "slot" for submitting requests from backends
//! to the communicator process.
//!
//! NB: The "backend" side of this code runs in Postgres backend processes,
//! which means that it is not safe to use the 'tracing' crate for logging, nor
//! to launch threads or use tokio tasks.
use std::cell::UnsafeCell;
use std::sync::atomic::fence;
use std::sync::atomic::{AtomicI32, Ordering};
use crate::neon_request::{NeonIORequest, NeonIOResult};
use atomic_enum::atomic_enum;
/// One request/response slot. Each backend has its own set of slots that it uses.
///
/// This is the moral equivalent of PgAioHandle for Postgres AIO requests
/// Like PgAioHandle, try to keep this small.
///
/// There is an array of these in shared memory. Therefore, this must be Sized.
///
/// ## Lifecycle of a request
///
/// The slot is always owned by either the backend process or the communicator
/// process, depending on the 'state'. Only the owning process is allowed to
/// read or modify the slot, except for reading the 'state' itself to check who
/// owns it.
///
/// A slot begins in the Idle state, where it is owned by the backend process.
/// To submit a request, the backend process fills the slot with the request
/// data, and changes it to the Submitted state. After changing the state, the
/// slot is owned by the communicator process, and the backend is not allowed
/// to access it until the communicator process marks it as Completed.
///
/// When the communicator process sees that the slot is in Submitted state, it
/// starts to process the request. After processing the request, it stores the
/// result in the slot, and changes the state to Completed. It is now owned by
/// the backend process again, which may now read the result, and reuse the
/// slot for a new request.
///
/// For correctness of the above protocol, we really only need two states:
/// "owned by backend" and "owned by communicator process. But to help with
/// debugging, there are a few more states. When the backend starts to fill in
/// the request details in the slot, it first sets the state from Idle to
/// Filling, and when it's done with that, from Filling to Submitted. In the
/// Filling state, the slot is still owned by the backend. Similarly, when the
/// communicator process starts to process a request, it sets it to Processing
/// state first, but the slot is still owned by the communicator process.
///
/// This struct doesn't handle waking up the communicator process when a request
/// has been submitted or when a response is ready. We only store the 'owner_procno'
/// which can be used for waking up the backend on completion, but the wakeups are
/// performed elsewhere.
pub struct NeonIOHandle {
/// similar to PgAioHandleState
state: AtomicNeonIOHandleState,
/// The owning process's ProcNumber. The worker process uses this to set the process's
/// latch on completion.
///
/// (This could be calculated from num_neon_request_slots_per_backend and the index of
/// this slot in the overall 'neon_requst_slots array')
owner_procno: AtomicI32,
/// SAFETY: This is modified by fill_request(), after it has established ownership
/// of the slot by setting state from Idle to Filling
request: UnsafeCell<NeonIORequest>,
/// valid when state is Completed
///
/// SAFETY: This is modified by RequestProcessingGuard::complete(). There can be
/// only one RequestProcessingGuard outstanding for a slot at a time, because
/// it is returned by start_processing_request() which checks the state, so
/// RequestProcessingGuard has exclusive access to the slot.
result: UnsafeCell<NeonIOResult>,
}
// The protocol described in the "Lifecycle of a request" section above ensures
// the safe access to the fields
unsafe impl Send for NeonIOHandle {}
unsafe impl Sync for NeonIOHandle {}
impl Default for NeonIOHandle {
fn default() -> NeonIOHandle {
NeonIOHandle {
owner_procno: AtomicI32::new(-1),
request: UnsafeCell::new(NeonIORequest::Empty),
result: UnsafeCell::new(NeonIOResult::Empty),
state: AtomicNeonIOHandleState::new(NeonIOHandleState::Idle),
}
}
}
#[atomic_enum]
#[derive(Eq, PartialEq)]
pub enum NeonIOHandleState {
Idle,
/// backend is filling in the request
Filling,
/// Backend has submitted the request to the communicator, but the
/// communicator process has not yet started processing it.
Submitted,
/// Communicator is processing the request
Processing,
/// Communicator has completed the request, and the 'result' field is now
/// valid, but the backend has not read the result yet.
Completed,
}
pub struct RequestProcessingGuard<'a>(&'a NeonIOHandle);
unsafe impl<'a> Send for RequestProcessingGuard<'a> {}
unsafe impl<'a> Sync for RequestProcessingGuard<'a> {}
impl<'a> RequestProcessingGuard<'a> {
pub fn get_request(&self) -> &NeonIORequest {
unsafe { &*self.0.request.get() }
}
pub fn get_owner_procno(&self) -> i32 {
self.0.owner_procno.load(Ordering::Relaxed)
}
pub fn completed(self, result: NeonIOResult) {
unsafe {
*self.0.result.get() = result;
};
// Ok, we have completed the IO. Mark the request as completed. After that,
// we no longer have ownership of the slot, and must not modify it.
let old_state = self
.0
.state
.swap(NeonIOHandleState::Completed, Ordering::Release);
assert!(old_state == NeonIOHandleState::Processing);
}
}
impl NeonIOHandle {
pub fn fill_request(&self, request: &NeonIORequest, proc_number: i32) {
// Verify that the slot is in Idle state previously, and start filling it.
//
// XXX: This step isn't strictly necessary. Assuming the caller didn't screw up
// and try to use a slot that's already in use, we could fill the slot and
// switch it directly from Idle to Submitted state.
if let Err(s) = self.state.compare_exchange(
NeonIOHandleState::Idle,
NeonIOHandleState::Filling,
Ordering::Relaxed,
Ordering::Relaxed,
) {
panic!("unexpected state in request slot: {s:?}");
}
// This fence synchronizes-with store/swap in `communicator_process_main_loop`.
fence(Ordering::Acquire);
self.owner_procno.store(proc_number, Ordering::Relaxed);
unsafe { *self.request.get() = *request }
self.state
.store(NeonIOHandleState::Submitted, Ordering::Release);
}
pub fn try_get_result(&self) -> Option<NeonIOResult> {
// FIXME: ordering?
let state = self.state.load(Ordering::Relaxed);
if state == NeonIOHandleState::Completed {
// This fence synchronizes-with store/swap in `communicator_process_main_loop`.
fence(Ordering::Acquire);
let result = unsafe { *self.result.get() };
self.state.store(NeonIOHandleState::Idle, Ordering::Relaxed);
Some(result)
} else {
None
}
}
pub fn start_processing_request<'a>(&'a self) -> Option<RequestProcessingGuard<'a>> {
// Read the IO request from the slot indicated in the wakeup
//
// XXX: using compare_exchange for this is not strictly necessary, as long as
// the communicator process has _some_ means of tracking which requests it's
// already processing. That could be a flag somewhere in communicator's private
// memory, for example.
if let Err(s) = self.state.compare_exchange(
NeonIOHandleState::Submitted,
NeonIOHandleState::Processing,
Ordering::Relaxed,
Ordering::Relaxed,
) {
// FIXME surprising state. This is unexpected at the moment, but if we
// started to process requests more aggressively, without waiting for the
// read from the pipe, then this could happen
panic!("unexpected state in request slot: {s:?}");
}
fence(Ordering::Acquire);
Some(RequestProcessingGuard(self))
}
}

View File

@@ -1,199 +0,0 @@
//! This code runs in each backend process. That means that launching Rust threads, panicking
//! etc. is forbidden!
use std::os::fd::OwnedFd;
use crate::backend_comms::NeonIOHandle;
use crate::init::CommunicatorInitStruct;
use crate::integrated_cache::{BackendCacheReadOp, IntegratedCacheReadAccess};
use crate::neon_request::CCachedGetPageVResult;
use crate::neon_request::{NeonIORequest, NeonIOResult};
pub struct CommunicatorBackendStruct<'t> {
my_proc_number: i32,
next_neon_request_idx: u32,
my_start_idx: u32, // First request slot that belongs to this backend
my_end_idx: u32, // end + 1 request slot that belongs to this backend
neon_request_slots: &'t [NeonIOHandle],
submission_pipe_write_fd: OwnedFd,
pending_cache_read_op: Option<BackendCacheReadOp<'t>>,
integrated_cache: &'t IntegratedCacheReadAccess<'t>,
}
#[unsafe(no_mangle)]
pub extern "C" fn rcommunicator_backend_init(
cis: Box<CommunicatorInitStruct>,
my_proc_number: i32,
) -> &'static mut CommunicatorBackendStruct<'static> {
let start_idx = my_proc_number as u32 * cis.num_neon_request_slots_per_backend;
let end_idx = start_idx + cis.num_neon_request_slots_per_backend;
let integrated_cache = Box::leak(Box::new(cis.integrated_cache_init_struct.backend_init()));
let bs: &'static mut CommunicatorBackendStruct =
Box::leak(Box::new(CommunicatorBackendStruct {
my_proc_number,
next_neon_request_idx: start_idx,
my_start_idx: start_idx,
my_end_idx: end_idx,
neon_request_slots: cis.neon_request_slots,
submission_pipe_write_fd: cis.submission_pipe_write_fd,
pending_cache_read_op: None,
integrated_cache,
}));
bs
}
/// Start a request. You can poll for its completion and get the result by
/// calling bcomm_poll_dbsize_request_completion(). The communicator will wake
/// us up by setting our process latch, so to wait for the completion, wait on
/// the latch and call bcomm_poll_dbsize_request_completion() every time the
/// latch is set.
///
/// Safety: The C caller must ensure that the references are valid.
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_start_io_request<'t>(
bs: &'t mut CommunicatorBackendStruct,
request: &NeonIORequest,
immediate_result_ptr: &mut NeonIOResult,
) -> i32 {
assert!(bs.pending_cache_read_op.is_none());
// Check if the request can be satisfied from the cache first
if let NeonIORequest::RelSize(req) = request {
if let Some(nblocks) = bs.integrated_cache.get_rel_size(&req.reltag()) {
*immediate_result_ptr = NeonIOResult::RelSize(nblocks);
return -1;
}
}
// Create neon request and submit it
let request_idx = bs.start_neon_request(request);
// Tell the communicator about it
bs.submit_request(request_idx);
return request_idx;
}
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_start_get_page_v_request<'t>(
bs: &'t mut CommunicatorBackendStruct,
request: &NeonIORequest,
immediate_result_ptr: &mut CCachedGetPageVResult,
) -> i32 {
let NeonIORequest::GetPageV(get_pagev_request) = request else {
panic!("invalid request passed to bcomm_start_get_page_v_request()");
};
assert!(matches!(request, NeonIORequest::GetPageV(_)));
assert!(bs.pending_cache_read_op.is_none());
// Check if the request can be satisfied from the cache first
let mut all_cached = true;
let mut read_op = bs.integrated_cache.start_read_op();
for i in 0..get_pagev_request.nblocks {
if let Some(cache_block) = read_op.get_page(
&get_pagev_request.reltag(),
get_pagev_request.block_number + i as u32,
) {
(*immediate_result_ptr).cache_block_numbers[i as usize] = cache_block;
} else {
// not found in cache
all_cached = false;
break;
}
}
if all_cached {
bs.pending_cache_read_op = Some(read_op);
return -1;
}
// Create neon request and submit it
let request_idx = bs.start_neon_request(request);
// Tell the communicator about it
bs.submit_request(request_idx);
request_idx
}
/// Check if a request has completed. Returns:
///
/// -1 if the request is still being processed
/// 0 on success
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_poll_request_completion(
bs: &mut CommunicatorBackendStruct,
request_idx: u32,
result_p: &mut NeonIOResult,
) -> i32 {
match bs.neon_request_slots[request_idx as usize].try_get_result() {
None => -1, // still processing
Some(result) => {
*result_p = result;
0
}
}
}
// LFC functions
/// Finish a local file cache read
///
//
#[unsafe(no_mangle)]
pub extern "C" fn bcomm_finish_cache_read(bs: &mut CommunicatorBackendStruct) -> bool {
if let Some(op) = bs.pending_cache_read_op.take() {
op.finish()
} else {
panic!("bcomm_finish_cache_read() called with no cached read pending");
}
}
impl<'t> CommunicatorBackendStruct<'t> {
/// Send a wakeup to the communicator process
fn submit_request(self: &CommunicatorBackendStruct<'t>, request_idx: i32) {
// wake up communicator by writing the idx to the submission pipe
//
// This can block, if the pipe is full. That should be very rare,
// because the communicator tries hard to drain the pipe to prevent
// that. Also, there's a natural upper bound on how many wakeups can be
// queued up: there is only a limited number of request slots for each
// backend.
//
// If it does block very briefly, that's not too serious.
let idxbuf = request_idx.to_ne_bytes();
let _res = nix::unistd::write(&self.submission_pipe_write_fd, &idxbuf);
// FIXME: check result, return any errors
}
/// Note: there's no guarantee on when the communicator might pick it up. You should ring
/// the doorbell. But it might pick it up immediately.
pub(crate) fn start_neon_request(&mut self, request: &NeonIORequest) -> i32 {
let my_proc_number = self.my_proc_number;
// Grab next free slot
// FIXME: any guarantee that there will be any?
let idx = self.next_neon_request_idx;
let next_idx = idx + 1;
self.next_neon_request_idx = if next_idx == self.my_end_idx {
self.my_start_idx
} else {
next_idx
};
self.neon_request_slots[idx as usize].fill_request(request, my_proc_number);
return idx as i32;
}
}

View File

@@ -1,162 +0,0 @@
//! Implement the "low-level" parts of the file cache.
//!
//! This module just deals with reading and writing the file, and keeping track
//! which blocks in the cache file are in use and which are free. The "high
//! level" parts of tracking which block in the cache file corresponds to which
//! relation block is handled in 'integrated_cache' instead.
//!
//! This module is only used to access the file from the communicator
//! process. The backend processes *also* read the file (and sometimes also
//! write it? ), but the backends use direct C library calls for that.
use std::fs::File;
use std::os::unix::fs::FileExt;
use std::path::Path;
use std::sync::Arc;
use std::sync::Mutex;
use crate::BLCKSZ;
use tokio::task::spawn_blocking;
pub type CacheBlock = u64;
pub const INVALID_CACHE_BLOCK: CacheBlock = u64::MAX;
pub struct FileCache {
file: Arc<File>,
free_list: Mutex<FreeList>,
// metrics
max_blocks_gauge: metrics::IntGauge,
num_free_blocks_gauge: metrics::IntGauge,
}
// TODO: We keep track of all free blocks in this vec. That doesn't really scale.
// Idea: when free_blocks fills up with more than 1024 entries, write them all to
// one block on disk.
struct FreeList {
next_free_block: CacheBlock,
max_blocks: u64,
free_blocks: Vec<CacheBlock>,
}
impl FileCache {
pub fn new(file_cache_path: &Path, mut initial_size: u64) -> Result<FileCache, std::io::Error> {
if initial_size < 100 {
tracing::warn!(
"min size for file cache is 100 blocks, {} requested",
initial_size
);
initial_size = 100;
}
let file = std::fs::OpenOptions::new()
.read(true)
.write(true)
.truncate(true)
.create(true)
.open(file_cache_path)?;
let max_blocks_gauge = metrics::IntGauge::new(
"file_cache_max_blocks",
"Local File Cache size in 8KiB blocks",
)
.unwrap();
let num_free_blocks_gauge = metrics::IntGauge::new(
"file_cache_num_free_blocks",
"Number of free 8KiB blocks in Local File Cache",
)
.unwrap();
tracing::info!("initialized file cache with {} blocks", initial_size);
Ok(FileCache {
file: Arc::new(file),
free_list: Mutex::new(FreeList {
next_free_block: 0,
max_blocks: initial_size,
free_blocks: Vec::new(),
}),
max_blocks_gauge,
num_free_blocks_gauge,
})
}
// File cache management
pub async fn read_block(
&self,
cache_block: CacheBlock,
mut dst: impl uring_common::buf::IoBufMut + Send + Sync,
) -> Result<(), std::io::Error> {
assert!(dst.bytes_total() == BLCKSZ);
let file = self.file.clone();
let dst_ref = unsafe { std::slice::from_raw_parts_mut(dst.stable_mut_ptr(), BLCKSZ) };
spawn_blocking(move || file.read_exact_at(dst_ref, cache_block as u64 * BLCKSZ as u64))
.await??;
Ok(())
}
pub async fn write_block(
&self,
cache_block: CacheBlock,
src: impl uring_common::buf::IoBuf + Send + Sync,
) -> Result<(), std::io::Error> {
assert!(src.bytes_init() == BLCKSZ);
let file = self.file.clone();
let src_ref = unsafe { std::slice::from_raw_parts(src.stable_ptr(), BLCKSZ) };
spawn_blocking(move || file.write_all_at(src_ref, cache_block as u64 * BLCKSZ as u64))
.await??;
Ok(())
}
pub fn alloc_block(&self) -> Option<CacheBlock> {
let mut free_list = self.free_list.lock().unwrap();
if let Some(x) = free_list.free_blocks.pop() {
return Some(x);
}
if free_list.next_free_block < free_list.max_blocks {
let result = free_list.next_free_block;
free_list.next_free_block += 1;
return Some(result);
}
None
}
pub fn dealloc_block(&self, cache_block: CacheBlock) {
let mut free_list = self.free_list.lock().unwrap();
free_list.free_blocks.push(cache_block);
}
}
impl metrics::core::Collector for FileCache {
fn desc(&self) -> Vec<&metrics::core::Desc> {
let mut descs = Vec::new();
descs.append(&mut self.max_blocks_gauge.desc());
descs.append(&mut self.num_free_blocks_gauge.desc());
descs
}
fn collect(&self) -> Vec<metrics::proto::MetricFamily> {
// Update the gauges with fresh values first
{
let free_list = self.free_list.lock().unwrap();
self.max_blocks_gauge.set(free_list.max_blocks as i64);
let total_free_blocks: i64 = free_list.free_blocks.len() as i64
+ (free_list.max_blocks as i64 - free_list.next_free_block as i64);
self.num_free_blocks_gauge.set(total_free_blocks as i64);
}
let mut values = Vec::new();
values.append(&mut self.max_blocks_gauge.collect());
values.append(&mut self.num_free_blocks_gauge.collect());
values
}
}

View File

@@ -1,109 +0,0 @@
//! Global allocator, for tracking memory usage of the Rust parts
//!
//! Postgres is designed to handle allocation failure (ie. malloc() returning NULL) gracefully. It
//! rolls backs the transaction and gives the user an "ERROR: out of memory" error. Rust code
//! however panics if an allocation fails. We don't want that to ever happen, because an unhandled
//! panic leads to Postgres crash and restart. Our strategy is to pre-allocate a large enough chunk
//! of memory for use by the Rust code, so that the allocations never fail.
//!
//! To pick the size for the pre-allocated chunk, we have a metric to track the high watermark
//! memory usage of all the Rust allocations in total.
//!
//! TODO:
//!
//! - Currently we just export the metrics. Actual allocations are still just passed through to
//! the system allocator.
//! - Take padding etc. overhead into account
use std::alloc::{GlobalAlloc, Layout, System};
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use metrics::IntGauge;
struct MyAllocator {
allocations: AtomicU64,
deallocations: AtomicU64,
allocated: AtomicUsize,
high: AtomicUsize,
}
unsafe impl GlobalAlloc for MyAllocator {
unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
self.allocations.fetch_add(1, Ordering::Relaxed);
let mut allocated = self.allocated.fetch_add(layout.size(), Ordering::Relaxed);
allocated += layout.size();
self.high.fetch_max(allocated, Ordering::Relaxed);
unsafe { System.alloc(layout) }
}
unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
self.deallocations.fetch_add(1, Ordering::Relaxed);
self.allocated.fetch_sub(layout.size(), Ordering::Relaxed);
unsafe { System.dealloc(ptr, layout) }
}
}
#[global_allocator]
static GLOBAL: MyAllocator = MyAllocator {
allocations: AtomicU64::new(0),
deallocations: AtomicU64::new(0),
allocated: AtomicUsize::new(0),
high: AtomicUsize::new(0),
};
pub struct MyAllocatorCollector {
allocations: IntGauge,
deallocations: IntGauge,
allocated: IntGauge,
high: IntGauge,
}
impl MyAllocatorCollector {
pub fn new() -> MyAllocatorCollector {
MyAllocatorCollector {
allocations: IntGauge::new("allocations_total", "Number of allocations in Rust code")
.unwrap(),
deallocations: IntGauge::new(
"deallocations_total",
"Number of deallocations in Rust code",
)
.unwrap(),
allocated: IntGauge::new("allocated_total", "Bytes currently allocated").unwrap(),
high: IntGauge::new("allocated_high", "High watermark of allocated bytes").unwrap(),
}
}
}
impl metrics::core::Collector for MyAllocatorCollector {
fn desc(&self) -> Vec<&metrics::core::Desc> {
let mut descs = Vec::new();
descs.append(&mut self.allocations.desc());
descs.append(&mut self.deallocations.desc());
descs.append(&mut self.allocated.desc());
descs.append(&mut self.high.desc());
descs
}
fn collect(&self) -> Vec<metrics::proto::MetricFamily> {
let mut values = Vec::new();
// update the gauges
self.allocations
.set(GLOBAL.allocations.load(Ordering::Relaxed) as i64);
self.deallocations
.set(GLOBAL.allocations.load(Ordering::Relaxed) as i64);
self.allocated
.set(GLOBAL.allocated.load(Ordering::Relaxed) as i64);
self.high.set(GLOBAL.high.load(Ordering::Relaxed) as i64);
values.append(&mut self.allocations.collect());
values.append(&mut self.deallocations.collect());
values.append(&mut self.allocated.collect());
values.append(&mut self.high.collect());
values
}
}

View File

@@ -1,184 +0,0 @@
//! Initialization functions. These are executed in the postmaster process,
//! at different stages of server startup.
//!
//!
//! Communicator initialization steps:
//!
//! 1. At postmaster startup, before shared memory is allocated,
//! rcommunicator_shmem_size() is called to get the amount of
//! shared memory that this module needs.
//!
//! 2. Later, after the shared memory has been allocated,
//! rcommunicator_shmem_init() is called to initialize the shmem
//! area.
//!
//! Per process initialization:
//!
//! When a backend process starts up, it calls rcommunicator_backend_init().
//! In the communicator worker process, other functions are called, see
//! `worker_process` module.
use std::ffi::c_int;
use std::mem;
use std::mem::MaybeUninit;
use std::os::fd::OwnedFd;
use crate::backend_comms::NeonIOHandle;
use crate::integrated_cache::IntegratedCacheInitStruct;
const NUM_NEON_REQUEST_SLOTS_PER_BACKEND: u32 = 5;
/// This struct is created in the postmaster process, and inherited to
/// the communicator process and all backend processes through fork()
#[repr(C)]
pub struct CommunicatorInitStruct {
#[allow(dead_code)]
pub max_procs: u32,
pub submission_pipe_read_fd: OwnedFd,
pub submission_pipe_write_fd: OwnedFd,
// Shared memory data structures
pub num_neon_request_slots_per_backend: u32,
pub neon_request_slots: &'static [NeonIOHandle],
pub integrated_cache_init_struct: IntegratedCacheInitStruct<'static>,
}
impl std::fmt::Debug for CommunicatorInitStruct {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
fmt.debug_struct("CommunicatorInitStruct")
.field("max_procs", &self.max_procs)
.field("submission_pipe_read_fd", &self.submission_pipe_read_fd)
.field("submission_pipe_write_fd", &self.submission_pipe_write_fd)
.field(
"num_neon_request_slots_per_backend",
&self.num_neon_request_slots_per_backend,
)
.field("neon_request_slots length", &self.neon_request_slots.len())
.finish()
}
}
#[unsafe(no_mangle)]
pub extern "C" fn rcommunicator_shmem_size(max_procs: u32) -> u64 {
let mut size = 0;
let num_neon_request_slots = max_procs * NUM_NEON_REQUEST_SLOTS_PER_BACKEND;
size += mem::size_of::<NeonIOHandle>() * num_neon_request_slots as usize;
// For integrated_cache's Allocator. TODO: make this adjustable
size += IntegratedCacheInitStruct::shmem_size(max_procs);
size as u64
}
/// Initialize the shared memory segment. Returns a backend-private
/// struct, which will be inherited by backend processes through fork
#[unsafe(no_mangle)]
pub extern "C" fn rcommunicator_shmem_init(
submission_pipe_read_fd: c_int,
submission_pipe_write_fd: c_int,
max_procs: u32,
shmem_area_ptr: *mut MaybeUninit<u8>,
shmem_area_len: u64,
initial_file_cache_size: u64,
max_file_cache_size: u64,
) -> &'static mut CommunicatorInitStruct {
let shmem_area: &'static mut [MaybeUninit<u8>] =
unsafe { std::slice::from_raw_parts_mut(shmem_area_ptr, shmem_area_len as usize) };
// Carve out the request slots from the shmem area and initialize them
let num_neon_request_slots_per_backend = NUM_NEON_REQUEST_SLOTS_PER_BACKEND as usize;
let num_neon_request_slots = max_procs as usize * num_neon_request_slots_per_backend;
let (neon_request_slots, remaining_area) =
alloc_array_from_slice::<NeonIOHandle>(shmem_area, num_neon_request_slots);
for i in 0..num_neon_request_slots {
neon_request_slots[i].write(NeonIOHandle::default());
}
// 'neon_request_slots' is initialized now. (MaybeUninit::slice_assume_init_mut() is nightly-only
// as of this writing.)
let neon_request_slots = unsafe {
std::mem::transmute::<&mut [MaybeUninit<NeonIOHandle>], &mut [NeonIOHandle]>(
neon_request_slots,
)
};
// Give the rest of the area to the integrated cache
let integrated_cache_init_struct = IntegratedCacheInitStruct::shmem_init(
max_procs,
remaining_area,
initial_file_cache_size,
max_file_cache_size,
);
let (submission_pipe_read_fd, submission_pipe_write_fd) = unsafe {
use std::os::fd::FromRawFd;
(
OwnedFd::from_raw_fd(submission_pipe_read_fd),
OwnedFd::from_raw_fd(submission_pipe_write_fd),
)
};
let cis: &'static mut CommunicatorInitStruct = Box::leak(Box::new(CommunicatorInitStruct {
max_procs,
submission_pipe_read_fd,
submission_pipe_write_fd,
num_neon_request_slots_per_backend: NUM_NEON_REQUEST_SLOTS_PER_BACKEND,
neon_request_slots,
integrated_cache_init_struct,
}));
cis
}
// fixme: currently unused
#[allow(dead_code)]
pub fn alloc_from_slice<T>(
area: &mut [MaybeUninit<u8>],
) -> (&mut MaybeUninit<T>, &mut [MaybeUninit<u8>]) {
let layout = std::alloc::Layout::new::<T>();
let area_start = area.as_mut_ptr();
// pad to satisfy alignment requirements
let padding = area_start.align_offset(layout.align());
if padding + layout.size() > area.len() {
panic!("out of memory");
}
let area = &mut area[padding..];
let (result_area, remain) = area.split_at_mut(layout.size());
let result_ptr: *mut MaybeUninit<T> = result_area.as_mut_ptr().cast();
let result = unsafe { result_ptr.as_mut().unwrap() };
(result, remain)
}
pub fn alloc_array_from_slice<T>(
area: &mut [MaybeUninit<u8>],
len: usize,
) -> (&mut [MaybeUninit<T>], &mut [MaybeUninit<u8>]) {
let layout = std::alloc::Layout::new::<T>();
let area_start = area.as_mut_ptr();
// pad to satisfy alignment requirements
let padding = area_start.align_offset(layout.align());
if padding + layout.size() * len > area.len() {
panic!("out of memory");
}
let area = &mut area[padding..];
let (result_area, remain) = area.split_at_mut(layout.size() * len);
let result_ptr: *mut MaybeUninit<T> = result_area.as_mut_ptr().cast();
let result = unsafe { std::slice::from_raw_parts_mut(result_ptr.as_mut().unwrap(), len) };
(result, remain)
}

View File

@@ -1,803 +0,0 @@
//! Integrated communicator cache
//!
//! It tracks:
//! - Relation sizes and existence
//! - Last-written LSN
//! - Block cache (also known as LFC)
//!
//! TODO: limit the size
//! TODO: concurrency
//!
//! Note: This deals with "relations" which is really just one "relation fork" in Postgres
//! terms. RelFileLocator + ForkNumber is the key.
//
// TODO: Thoughts on eviction:
//
// There are two things we need to track, and evict if we run out of space:
// - blocks in the file cache's file. If the file grows too large, need to evict something.
// Also if the cache is resized
//
// - entries in the cache map. If we run out of memory in the shmem area, need to evict
// something
//
use std::mem::MaybeUninit;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicU64, Ordering};
use utils::lsn::{AtomicLsn, Lsn};
use crate::file_cache::INVALID_CACHE_BLOCK;
use crate::file_cache::{CacheBlock, FileCache};
use pageserver_page_api::RelTag;
use metrics::{IntCounter, IntGauge};
use neon_shmem::hash::HashMapInit;
use neon_shmem::hash::UpdateAction;
use neon_shmem::shmem::ShmemHandle;
// in # of entries
const RELSIZE_CACHE_SIZE: u32 = 64 * 1024;
/// This struct is initialized at postmaster startup, and passed to all the processes via fork().
pub struct IntegratedCacheInitStruct<'t> {
relsize_cache_handle: HashMapInit<'t, RelKey, RelEntry>,
block_map_handle: HashMapInit<'t, BlockKey, BlockEntry>,
}
/// Represents write-access to the integrated cache. This is used by the communicator process.
pub struct IntegratedCacheWriteAccess<'t> {
relsize_cache: neon_shmem::hash::HashMapAccess<'t, RelKey, RelEntry>,
block_map: neon_shmem::hash::HashMapAccess<'t, BlockKey, BlockEntry>,
global_lw_lsn: AtomicU64,
pub(crate) file_cache: Option<FileCache>,
// Fields for eviction
clock_hand: std::sync::Mutex<usize>,
// Metrics
page_evictions_counter: IntCounter,
clock_iterations_counter: IntCounter,
// metrics from the hash map
block_map_num_buckets: IntGauge,
block_map_num_buckets_in_use: IntGauge,
relsize_cache_num_buckets: IntGauge,
relsize_cache_num_buckets_in_use: IntGauge,
}
/// Represents read-only access to the integrated cache. Backend processes have this.
pub struct IntegratedCacheReadAccess<'t> {
relsize_cache: neon_shmem::hash::HashMapAccess<'t, RelKey, RelEntry>,
block_map: neon_shmem::hash::HashMapAccess<'t, BlockKey, BlockEntry>,
}
impl<'t> IntegratedCacheInitStruct<'t> {
/// Return the desired size in bytes of the fixed-size shared memory area to reserve for the
/// integrated cache.
pub fn shmem_size(_max_procs: u32) -> usize {
// The relsize cache is fixed-size. The block map is allocated in a separate resizable
// area.
HashMapInit::<RelKey, RelEntry>::estimate_size(RELSIZE_CACHE_SIZE)
}
/// Initialize the shared memory segment. This runs once in postmaster. Returns a struct which
/// will be inherited by all processes through fork.
pub fn shmem_init(
_max_procs: u32,
shmem_area: &'t mut [MaybeUninit<u8>],
initial_file_cache_size: u64,
max_file_cache_size: u64,
) -> IntegratedCacheInitStruct<'t> {
// Initialize the relsize cache in the fixed-size area
let relsize_cache_handle =
neon_shmem::hash::HashMapInit::init_in_fixed_area(RELSIZE_CACHE_SIZE, shmem_area);
let max_bytes =
HashMapInit::<BlockKey, BlockEntry>::estimate_size(max_file_cache_size as u32);
// Initialize the block map in a separate resizable shared memory area
let shmem_handle = ShmemHandle::new("block mapping", 0, max_bytes).unwrap();
let block_map_handle = neon_shmem::hash::HashMapInit::init_in_shmem(
initial_file_cache_size as u32,
shmem_handle,
);
IntegratedCacheInitStruct {
relsize_cache_handle,
block_map_handle,
}
}
/// Initialize access to the integrated cache for the communicator worker process
pub fn worker_process_init(
self,
lsn: Lsn,
file_cache: Option<FileCache>,
) -> IntegratedCacheWriteAccess<'t> {
let IntegratedCacheInitStruct {
relsize_cache_handle,
block_map_handle,
} = self;
IntegratedCacheWriteAccess {
relsize_cache: relsize_cache_handle.attach_writer(),
block_map: block_map_handle.attach_writer(),
global_lw_lsn: AtomicU64::new(lsn.0),
file_cache,
clock_hand: std::sync::Mutex::new(0),
page_evictions_counter: metrics::IntCounter::new(
"integrated_cache_evictions",
"Page evictions from the Local File Cache",
)
.unwrap(),
clock_iterations_counter: metrics::IntCounter::new(
"clock_iterations",
"Number of times the clock hand has moved",
)
.unwrap(),
block_map_num_buckets: metrics::IntGauge::new(
"block_map_num_buckets",
"Allocated size of the block cache hash map",
)
.unwrap(),
block_map_num_buckets_in_use: metrics::IntGauge::new(
"block_map_num_buckets_in_use",
"Number of buckets in use in the block cache hash map",
)
.unwrap(),
relsize_cache_num_buckets: metrics::IntGauge::new(
"relsize_cache_num_buckets",
"Allocated size of the relsize cache hash map",
)
.unwrap(),
relsize_cache_num_buckets_in_use: metrics::IntGauge::new(
"relsize_cache_num_buckets_in_use",
"Number of buckets in use in the relsize cache hash map",
)
.unwrap(),
}
}
/// Initialize access to the integrated cache for a backend process
pub fn backend_init(self) -> IntegratedCacheReadAccess<'t> {
let IntegratedCacheInitStruct {
relsize_cache_handle,
block_map_handle,
} = self;
IntegratedCacheReadAccess {
relsize_cache: relsize_cache_handle.attach_reader(),
block_map: block_map_handle.attach_reader(),
}
}
}
/// Value stored in the cache mapping hash table.
struct BlockEntry {
lw_lsn: AtomicLsn,
cache_block: AtomicU64,
pinned: AtomicU64,
// 'referenced' bit for the clock algorithm
referenced: AtomicBool,
}
/// Value stored in the relsize cache hash table.
struct RelEntry {
/// cached size of the relation
/// u32::MAX means 'not known' (that's InvalidBlockNumber in Postgres)
nblocks: AtomicU32,
}
impl std::fmt::Debug for RelEntry {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
fmt.debug_struct("Rel")
.field("nblocks", &self.nblocks.load(Ordering::Relaxed))
.finish()
}
}
impl std::fmt::Debug for BlockEntry {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
fmt.debug_struct("Block")
.field("lw_lsn", &self.lw_lsn.load())
.field("cache_block", &self.cache_block.load(Ordering::Relaxed))
.field("pinned", &self.pinned.load(Ordering::Relaxed))
.field("referenced", &self.referenced.load(Ordering::Relaxed))
.finish()
}
}
#[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Hash, Ord)]
struct RelKey(RelTag);
impl From<&RelTag> for RelKey {
fn from(val: &RelTag) -> RelKey {
RelKey(val.clone())
}
}
#[derive(Clone, Debug, PartialEq, PartialOrd, Eq, Hash, Ord)]
struct BlockKey {
rel: RelTag,
block_number: u32,
}
impl From<(&RelTag, u32)> for BlockKey {
fn from(val: (&RelTag, u32)) -> BlockKey {
BlockKey {
rel: val.0.clone(),
block_number: val.1,
}
}
}
/// Return type used in the cache's get_*() functions. 'Found' means that the page, or other
/// information that was enqueried, exists in the cache. '
pub enum CacheResult<V> {
/// The enqueried page or other information existed in the cache.
Found(V),
/// The cache doesn't contain the page (or other enqueried information, like relation size). The
/// Lsn is the 'not_modified_since' LSN that should be used in the request to the pageserver to
/// read the page.
NotFound(Lsn),
}
impl<'t> IntegratedCacheWriteAccess<'t> {
pub fn get_rel_size(&'t self, rel: &RelTag) -> CacheResult<u32> {
if let Some(nblocks) = get_rel_size(&self.relsize_cache, rel) {
CacheResult::Found(nblocks)
} else {
let lsn = Lsn(self.global_lw_lsn.load(Ordering::Relaxed));
CacheResult::NotFound(lsn)
}
}
pub async fn get_page(
&'t self,
rel: &RelTag,
block_number: u32,
dst: impl uring_common::buf::IoBufMut + Send + Sync,
) -> Result<CacheResult<()>, std::io::Error> {
let x = if let Some(block_entry) = self.block_map.get(&BlockKey::from((rel, block_number)))
{
block_entry.referenced.store(true, Ordering::Relaxed);
let cache_block = block_entry.cache_block.load(Ordering::Relaxed);
if cache_block != INVALID_CACHE_BLOCK {
// pin it and release lock
block_entry.pinned.fetch_add(1, Ordering::Relaxed);
(cache_block, DeferredUnpin(block_entry.pinned.as_ptr()))
} else {
return Ok(CacheResult::NotFound(block_entry.lw_lsn.load()));
}
} else {
let lsn = Lsn(self.global_lw_lsn.load(Ordering::Relaxed));
return Ok(CacheResult::NotFound(lsn));
};
let (cache_block, _deferred_pin) = x;
self.file_cache
.as_ref()
.unwrap()
.read_block(cache_block, dst)
.await?;
// unpin the entry (by implicitly dropping deferred_pin)
Ok(CacheResult::Found(()))
}
pub async fn page_is_cached(
&'t self,
rel: &RelTag,
block_number: u32,
) -> Result<CacheResult<()>, std::io::Error> {
if let Some(block_entry) = self.block_map.get(&BlockKey::from((rel, block_number))) {
// This is used for prefetch requests. Treat the probe as an 'access', to keep it
// in cache.
block_entry.referenced.store(true, Ordering::Relaxed);
let cache_block = block_entry.cache_block.load(Ordering::Relaxed);
if cache_block != INVALID_CACHE_BLOCK {
Ok(CacheResult::Found(()))
} else {
Ok(CacheResult::NotFound(block_entry.lw_lsn.load()))
}
} else {
let lsn = Lsn(self.global_lw_lsn.load(Ordering::Relaxed));
Ok(CacheResult::NotFound(lsn))
}
}
/// Does the relation exists? CacheResult::NotFound means that the cache doesn't contain that
/// information, i.e. we don't know if the relation exists or not.
pub fn get_rel_exists(&'t self, rel: &RelTag) -> CacheResult<bool> {
// we don't currently cache negative entries, so if the relation is in the cache, it exists
if let Some(_rel_entry) = self.relsize_cache.get(&RelKey::from(rel)) {
CacheResult::Found(true)
} else {
let lsn = Lsn(self.global_lw_lsn.load(Ordering::Relaxed));
CacheResult::NotFound(lsn)
}
}
pub fn get_db_size(&'t self, _db_oid: u32) -> CacheResult<u64> {
// TODO: it would be nice to cache database sizes too. Getting the database size
// is not a very common operation, but when you do it, it's often interactive, with
// e.g. psql \l+ command, so the user will feel the latency.
// fixme: is this right lsn?
let lsn = Lsn(self.global_lw_lsn.load(Ordering::Relaxed));
CacheResult::NotFound(lsn)
}
pub fn remember_rel_size(&'t self, rel: &RelTag, nblocks: u32) {
let result =
self.relsize_cache
.update_with_fn(&RelKey::from(rel), |existing| match existing {
None => {
tracing::info!("inserting rel entry for {rel:?}, {nblocks} blocks");
UpdateAction::Insert(RelEntry {
nblocks: AtomicU32::new(nblocks),
})
}
Some(e) => {
tracing::info!("updating rel entry for {rel:?}, {nblocks} blocks");
e.nblocks.store(nblocks, Ordering::Relaxed);
UpdateAction::Nothing
}
});
// FIXME: what to do if we run out of memory? Evict other relation entries?
result.expect("out of memory");
}
/// Remember the given page contents in the cache.
pub async fn remember_page(
&'t self,
rel: &RelTag,
block_number: u32,
src: impl uring_common::buf::IoBuf + Send + Sync,
lw_lsn: Lsn,
is_write: bool,
) {
let key = BlockKey::from((rel, block_number));
// FIXME: make this work when file cache is disabled. Or make it mandatory
let file_cache = self.file_cache.as_ref().unwrap();
if is_write {
// there should be no concurrent IOs. If a backend tries to read the page
// at the same time, they may get a torn write. That's the same as with
// regular POSIX filesystem read() and write()
// First check if we have a block in cache already
let mut old_cache_block = None;
let mut found_existing = false;
let res = self.block_map.update_with_fn(&key, |existing| {
if let Some(block_entry) = existing {
found_existing = true;
// Prevent this entry from being evicted
let pin_count = block_entry.pinned.fetch_add(1, Ordering::Relaxed);
if pin_count > 0 {
// this is unexpected, because the caller has obtained the io-in-progress lock,
// so no one else should try to modify the page at the same time.
// XXX: and I think a read should not be happening either, because the postgres
// buffer is held locked. TODO: check these conditions and tidy this up a little. Seems fragile to just panic.
panic!("block entry was unexpectedly pinned");
}
let cache_block = block_entry.cache_block.load(Ordering::Relaxed);
old_cache_block = if cache_block != INVALID_CACHE_BLOCK {
Some(cache_block)
} else {
None
};
}
// if there was no existing entry, we will insert one, but not yet
UpdateAction::Nothing
});
// FIXME: what to do if we run out of memory? Evict other relation entries? Remove
// block entries first?
res.expect("out of memory");
// Allocate a new block if required
let cache_block = old_cache_block.unwrap_or_else(|| {
loop {
if let Some(x) = file_cache.alloc_block() {
break x;
}
if let Some(x) = self.try_evict_one_cache_block() {
break x;
}
}
});
// Write the page to the cache file
file_cache
.write_block(cache_block, src)
.await
.expect("error writing to cache");
// FIXME: handle errors gracefully.
// FIXME: unpin the block entry on error
// Update the block entry
let res = self.block_map.update_with_fn(&key, |existing| {
assert_eq!(found_existing, existing.is_some());
if let Some(block_entry) = existing {
// Update the cache block
let old_blk = block_entry.cache_block.compare_exchange(
INVALID_CACHE_BLOCK,
cache_block,
Ordering::Relaxed,
Ordering::Relaxed,
);
assert!(old_blk == Ok(INVALID_CACHE_BLOCK) || old_blk == Err(cache_block));
block_entry.lw_lsn.store(lw_lsn);
block_entry.referenced.store(true, Ordering::Relaxed);
let pin_count = block_entry.pinned.fetch_sub(1, Ordering::Relaxed);
assert!(pin_count > 0);
UpdateAction::Nothing
} else {
UpdateAction::Insert(BlockEntry {
lw_lsn: AtomicLsn::new(lw_lsn.0),
cache_block: AtomicU64::new(cache_block),
pinned: AtomicU64::new(0),
referenced: AtomicBool::new(true),
})
}
});
// FIXME: what to do if we run out of memory? Evict other relation entries? Remove
// block entries first?
res.expect("out of memory");
} else {
// !is_write
//
// We can assume that it doesn't already exist, because the
// caller is assumed to have already checked it, and holds
// the io-in-progress lock. (The BlockEntry might exist, but no cache block)
// Allocate a new block first
let cache_block = {
loop {
if let Some(x) = file_cache.alloc_block() {
break x;
}
if let Some(x) = self.try_evict_one_cache_block() {
break x;
}
}
};
// Write the page to the cache file
file_cache
.write_block(cache_block, src)
.await
.expect("error writing to cache");
// FIXME: handle errors gracefully.
let res = self.block_map.update_with_fn(&key, |existing| {
if let Some(block_entry) = existing {
// FIXME: could there be concurrent readers?
assert!(block_entry.pinned.load(Ordering::Relaxed) == 0);
let old_cache_block = block_entry.cache_block.swap(cache_block, Ordering::Relaxed);
if old_cache_block != INVALID_CACHE_BLOCK {
panic!("remember_page called in !is_write mode, but page is already cached at blk {}", old_cache_block);
}
UpdateAction::Nothing
} else {
UpdateAction::Insert(BlockEntry {
lw_lsn: AtomicLsn::new(lw_lsn.0),
cache_block: AtomicU64::new(cache_block),
pinned: AtomicU64::new(0),
referenced: AtomicBool::new(true),
})
}
});
// FIXME: what to do if we run out of memory? Evict other relation entries? Remove
// block entries first?
res.expect("out of memory");
}
}
/// Forget information about given relation in the cache. (For DROP TABLE and such)
pub fn forget_rel(&'t self, rel: &RelTag) {
tracing::info!("forgetting rel entry for {rel:?}");
self.relsize_cache.remove(&RelKey::from(rel));
// also forget all cached blocks for the relation
// FIXME
/*
let mut iter = MapIterator::new(&key_range_for_rel_blocks(rel));
let r = self.cache_tree.start_read();
while let Some((k, _v)) = iter.next(&r) {
let w = self.cache_tree.start_write();
let mut evicted_cache_block = None;
let res = w.update_with_fn(&k, |e| {
if let Some(e) = e {
let block_entry = if let MapEntry::Block(e) = e {
e
} else {
panic!("unexpected map entry type for block key");
};
let cache_block = block_entry
.cache_block
.swap(INVALID_CACHE_BLOCK, Ordering::Relaxed);
if cache_block != INVALID_CACHE_BLOCK {
evicted_cache_block = Some(cache_block);
}
UpdateAction::Remove
} else {
UpdateAction::Nothing
}
});
// FIXME: It's pretty surprising to run out of memory while removing. But
// maybe it can happen because of trying to shrink a node?
res.expect("out of memory");
if let Some(evicted_cache_block) = evicted_cache_block {
self.file_cache
.as_ref()
.unwrap()
.dealloc_block(evicted_cache_block);
}
}
*/
}
// Maintenance routines
/// Evict one block from the file cache. This is used when the file cache fills up
/// Returns the evicted block. It's not put to the free list, so it's available for the
/// caller to use immediately.
pub fn try_evict_one_cache_block(&self) -> Option<CacheBlock> {
let mut clock_hand = self.clock_hand.lock().unwrap();
for _ in 0..100 {
self.clock_iterations_counter.inc();
(*clock_hand) += 1;
let mut evict_this = false;
let num_buckets = self.block_map.get_num_buckets();
match self
.block_map
.get_bucket((*clock_hand) % num_buckets)
.as_deref()
{
None => {
// This bucket was unused
}
Some(blk_entry) => {
if !blk_entry.referenced.swap(false, Ordering::Relaxed) {
// Evict this. Maybe.
evict_this = true;
}
}
};
if evict_this {
// grab the write lock
let mut evicted_cache_block = None;
let res =
self.block_map
.update_with_fn_at_bucket(*clock_hand % num_buckets, |old| {
match old {
None => UpdateAction::Nothing,
Some(old) => {
// note: all the accesses to 'pinned' currently happen
// within update_with_fn(), or while holding ValueReadGuard, which protects from concurrent
// updates. Otherwise, another thread could set the 'pinned'
// flag just after we have checked it here.
if old.pinned.load(Ordering::Relaxed) != 0 {
return UpdateAction::Nothing;
}
let _ = self
.global_lw_lsn
.fetch_max(old.lw_lsn.load().0, Ordering::Relaxed);
let cache_block = old
.cache_block
.swap(INVALID_CACHE_BLOCK, Ordering::Relaxed);
if cache_block != INVALID_CACHE_BLOCK {
evicted_cache_block = Some(cache_block);
}
UpdateAction::Remove
}
}
});
// Out of memory should not happen here, as we're only updating existing values,
// not inserting new entries to the map.
res.expect("out of memory");
if evicted_cache_block.is_some() {
self.page_evictions_counter.inc();
return evicted_cache_block;
}
}
}
// Give up if we didn't find anything
None
}
pub fn resize_file_cache(&self, num_blocks: u32) {
let old_num_blocks = self.block_map.get_num_buckets() as u32;
if old_num_blocks < num_blocks {
if let Err(err) = self.block_map.grow(num_blocks) {
tracing::warn!(
"could not grow file cache to {} blocks (old size {}): {}",
num_blocks,
old_num_blocks,
err
);
}
}
}
pub fn dump_map(&self, _dst: &mut dyn std::io::Write) {
//FIXME self.cache_map.start_read().dump(dst);
}
}
impl metrics::core::Collector for IntegratedCacheWriteAccess<'_> {
fn desc(&self) -> Vec<&metrics::core::Desc> {
let mut descs = Vec::new();
descs.append(&mut self.page_evictions_counter.desc());
descs.append(&mut self.clock_iterations_counter.desc());
descs.append(&mut self.block_map_num_buckets.desc());
descs.append(&mut self.block_map_num_buckets_in_use.desc());
descs.append(&mut self.relsize_cache_num_buckets.desc());
descs.append(&mut self.relsize_cache_num_buckets_in_use.desc());
descs
}
fn collect(&self) -> Vec<metrics::proto::MetricFamily> {
// Update gauges
self.block_map_num_buckets
.set(self.block_map.get_num_buckets() as i64);
self.block_map_num_buckets_in_use
.set(self.block_map.get_num_buckets_in_use() as i64);
self.relsize_cache_num_buckets
.set(self.relsize_cache.get_num_buckets() as i64);
self.relsize_cache_num_buckets_in_use
.set(self.relsize_cache.get_num_buckets_in_use() as i64);
let mut values = Vec::new();
values.append(&mut self.page_evictions_counter.collect());
values.append(&mut self.clock_iterations_counter.collect());
values.append(&mut self.block_map_num_buckets.collect());
values.append(&mut self.block_map_num_buckets_in_use.collect());
values.append(&mut self.relsize_cache_num_buckets.collect());
values.append(&mut self.relsize_cache_num_buckets_in_use.collect());
values
}
}
/// Read relation size from the cache.
///
/// This is in a separate function so that it can be shared by
/// IntegratedCacheReadAccess::get_rel_size() and IntegratedCacheWriteAccess::get_rel_size()
fn get_rel_size<'t>(
r: &neon_shmem::hash::HashMapAccess<RelKey, RelEntry>,
rel: &RelTag,
) -> Option<u32> {
if let Some(rel_entry) = r.get(&RelKey::from(rel)) {
let nblocks = rel_entry.nblocks.load(Ordering::Relaxed);
if nblocks != u32::MAX {
Some(nblocks)
} else {
None
}
} else {
None
}
}
/// Accessor for other backends
///
/// This allows backends to read pages from the cache directly, on their own, without making a
/// request to the communicator process.
impl<'t> IntegratedCacheReadAccess<'t> {
pub fn get_rel_size(&'t self, rel: &RelTag) -> Option<u32> {
get_rel_size(&self.relsize_cache, rel)
}
pub fn start_read_op(&'t self) -> BackendCacheReadOp<'t> {
BackendCacheReadOp {
read_guards: Vec::new(),
map_access: self,
}
}
}
pub struct BackendCacheReadOp<'t> {
read_guards: Vec<DeferredUnpin>,
map_access: &'t IntegratedCacheReadAccess<'t>,
}
impl<'e> BackendCacheReadOp<'e> {
/// Initiate a read of the page from the cache.
///
/// This returns the "cache block number", i.e. the block number within the cache file, where
/// the page's contents is stored. To get the page contents, the caller needs to read that block
/// from the cache file. This returns a guard object that you must hold while it performs the
/// read. It's possible that while you are performing the read, the cache block is invalidated.
/// After you have completed the read, call BackendCacheReadResult::finish() to check if the
/// read was in fact valid or not. If it was concurrently invalidated, you need to retry.
pub fn get_page(&mut self, rel: &RelTag, block_number: u32) -> Option<u64> {
if let Some(block_entry) = self
.map_access
.block_map
.get(&BlockKey::from((rel, block_number)))
{
block_entry.referenced.store(true, Ordering::Relaxed);
let cache_block = block_entry.cache_block.load(Ordering::Relaxed);
if cache_block != INVALID_CACHE_BLOCK {
block_entry.pinned.fetch_add(1, Ordering::Relaxed);
self.read_guards
.push(DeferredUnpin(block_entry.pinned.as_ptr()));
Some(cache_block)
} else {
None
}
} else {
None
}
}
pub fn finish(self) -> bool {
// TODO: currently, we hold a pin on the in-memory map, so concurrent invalidations are not
// possible. But if we switch to optimistic locking, this would return 'false' if the
// optimistic locking failed and you need to retry.
true
}
}
/// A hack to decrement an AtomicU64 on drop. This is used to decrement the pin count
/// of a BlockEntry. The safety depends on the fact that the BlockEntry is not evicted
/// or moved while it's pinned.
struct DeferredUnpin(*mut u64);
unsafe impl Sync for DeferredUnpin {}
unsafe impl Send for DeferredUnpin {}
impl Drop for DeferredUnpin {
fn drop(&mut self) {
// unpin it
unsafe {
let pin_ref = AtomicU64::from_ptr(self.0);
pin_ref.fetch_sub(1, Ordering::Relaxed);
}
}
}

View File

@@ -1,27 +0,0 @@
//!
//! Three main parts:
//! - async tokio communicator core, which receives requests and processes them.
//! - Main loop and requests queues, which routes requests from backends to the core
//! - the per-backend glue code, which submits requests
//!
mod backend_comms;
// mark this 'pub', because these functions are called from C code. Otherwise, the compiler
// complains about a bunch of structs and enum variants being unused, because it thinkgs
// the functions that use them are never called. There are some C-callable functions in
// other modules too, but marking this as pub is currently enough to silence the warnings
//
// TODO: perhaps collect *all* the extern "C" functions to one module?
pub mod backend_interface;
mod file_cache;
mod init;
mod integrated_cache;
mod neon_request;
mod worker_process;
mod global_allocator;
// FIXME get this from postgres headers somehow
pub const BLCKSZ: usize = 8192;

View File

@@ -1,346 +0,0 @@
type CLsn = u64;
type COid = u32;
// This conveniently matches PG_IOV_MAX
pub const MAX_GETPAGEV_PAGES: usize = 32;
use pageserver_page_api as page_api;
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub enum NeonIORequest {
Empty,
// Read requests. These are C-friendly variants of the corresponding structs in
// pageserver_page_api.
RelExists(CRelExistsRequest),
RelSize(CRelSizeRequest),
GetPageV(CGetPageVRequest),
PrefetchV(CPrefetchVRequest),
DbSize(CDbSizeRequest),
// Write requests. These are needed to keep the relation size cache and LFC up-to-date.
// They are not sent to the pageserver.
WritePage(CWritePageRequest),
RelExtend(CRelExtendRequest),
RelZeroExtend(CRelZeroExtendRequest),
RelCreate(CRelCreateRequest),
RelTruncate(CRelTruncateRequest),
RelUnlink(CRelUnlinkRequest),
}
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub enum NeonIOResult {
Empty,
RelExists(bool),
RelSize(u32),
/// the result pages are written to the shared memory addresses given in the request
GetPageV,
/// A prefetch request returns as soon as the request has been received by the communicator.
/// It is processed in the background.
PrefetchVLaunched,
DbSize(u64),
// FIXME design compact error codes. Can't easily pass a string or other dynamic data.
// currently, this is 'errno'
Error(i32),
Aborted,
/// used for all write requests
WriteOK,
}
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct CCachedGetPageVResult {
pub cache_block_numbers: [u64; MAX_GETPAGEV_PAGES],
}
/// ShmemBuf represents a buffer in shared memory.
///
/// SAFETY: The pointer must point to an area in shared memory. The functions allow you to liberally
/// get a mutable pointer to the contents; it is the caller's responsibility to ensure that you
/// don't access a buffer that's you're not allowed to. Inappropriate access to the buffer doesn't
/// violate Rust's safety semantics, but it will mess up and crash Postgres.
///
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct ShmemBuf {
// These fields define where the result is written. Must point into a buffer in shared memory!
pub ptr: *mut u8,
}
unsafe impl Send for ShmemBuf {}
unsafe impl Sync for ShmemBuf {}
unsafe impl uring_common::buf::IoBuf for ShmemBuf {
fn stable_ptr(&self) -> *const u8 {
self.ptr
}
fn bytes_init(&self) -> usize {
crate::BLCKSZ
}
fn bytes_total(&self) -> usize {
crate::BLCKSZ
}
}
unsafe impl uring_common::buf::IoBufMut for ShmemBuf {
fn stable_mut_ptr(&mut self) -> *mut u8 {
self.ptr
}
unsafe fn set_init(&mut self, pos: usize) {
if pos > crate::BLCKSZ as usize {
panic!(
"set_init called past end of buffer, pos {}, buffer size {}",
pos,
crate::BLCKSZ
);
}
}
}
impl ShmemBuf {
pub fn as_mut_ptr(&self) -> *mut u8 {
self.ptr
}
}
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct CRelExistsRequest {
pub spc_oid: COid,
pub db_oid: COid,
pub rel_number: u32,
pub fork_number: u8,
}
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct CRelSizeRequest {
pub spc_oid: COid,
pub db_oid: COid,
pub rel_number: u32,
pub fork_number: u8,
}
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct CGetPageVRequest {
pub spc_oid: COid,
pub db_oid: COid,
pub rel_number: u32,
pub fork_number: u8,
pub block_number: u32,
pub nblocks: u8,
// These fields define where the result is written. Must point into a buffer in shared memory!
pub dest: [ShmemBuf; MAX_GETPAGEV_PAGES],
}
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct CPrefetchVRequest {
pub spc_oid: COid,
pub db_oid: COid,
pub rel_number: u32,
pub fork_number: u8,
pub block_number: u32,
pub nblocks: u8,
}
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct CDbSizeRequest {
pub db_oid: COid,
pub request_lsn: CLsn,
}
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct CWritePageRequest {
pub spc_oid: COid,
pub db_oid: COid,
pub rel_number: u32,
pub fork_number: u8,
pub block_number: u32,
pub lsn: CLsn,
// These fields define where the result is written. Must point into a buffer in shared memory!
pub src: ShmemBuf,
}
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct CRelExtendRequest {
pub spc_oid: COid,
pub db_oid: COid,
pub rel_number: u32,
pub fork_number: u8,
pub block_number: u32,
pub lsn: CLsn,
// These fields define page contents. Must point into a buffer in shared memory!
pub src_ptr: usize,
pub src_size: u32,
}
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct CRelZeroExtendRequest {
pub spc_oid: COid,
pub db_oid: COid,
pub rel_number: u32,
pub fork_number: u8,
pub block_number: u32,
pub nblocks: u32,
pub lsn: CLsn,
}
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct CRelCreateRequest {
pub spc_oid: COid,
pub db_oid: COid,
pub rel_number: u32,
pub fork_number: u8,
}
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct CRelTruncateRequest {
pub spc_oid: COid,
pub db_oid: COid,
pub rel_number: u32,
pub fork_number: u8,
pub nblocks: u32,
}
#[repr(C)]
#[derive(Copy, Clone, Debug)]
pub struct CRelUnlinkRequest {
pub spc_oid: COid,
pub db_oid: COid,
pub rel_number: u32,
pub fork_number: u8,
pub block_number: u32,
pub nblocks: u32,
}
impl CRelExistsRequest {
pub fn reltag(&self) -> page_api::RelTag {
page_api::RelTag {
spcnode: self.spc_oid,
dbnode: self.db_oid,
relnode: self.rel_number,
forknum: self.fork_number,
}
}
}
impl CRelSizeRequest {
pub fn reltag(&self) -> page_api::RelTag {
page_api::RelTag {
spcnode: self.spc_oid,
dbnode: self.db_oid,
relnode: self.rel_number,
forknum: self.fork_number,
}
}
}
impl CGetPageVRequest {
pub fn reltag(&self) -> page_api::RelTag {
page_api::RelTag {
spcnode: self.spc_oid,
dbnode: self.db_oid,
relnode: self.rel_number,
forknum: self.fork_number,
}
}
}
impl CPrefetchVRequest {
pub fn reltag(&self) -> page_api::RelTag {
page_api::RelTag {
spcnode: self.spc_oid,
dbnode: self.db_oid,
relnode: self.rel_number,
forknum: self.fork_number,
}
}
}
impl CWritePageRequest {
pub fn reltag(&self) -> page_api::RelTag {
page_api::RelTag {
spcnode: self.spc_oid,
dbnode: self.db_oid,
relnode: self.rel_number,
forknum: self.fork_number,
}
}
}
impl CRelExtendRequest {
pub fn reltag(&self) -> page_api::RelTag {
page_api::RelTag {
spcnode: self.spc_oid,
dbnode: self.db_oid,
relnode: self.rel_number,
forknum: self.fork_number,
}
}
}
impl CRelZeroExtendRequest {
pub fn reltag(&self) -> page_api::RelTag {
page_api::RelTag {
spcnode: self.spc_oid,
dbnode: self.db_oid,
relnode: self.rel_number,
forknum: self.fork_number,
}
}
}
impl CRelCreateRequest {
pub fn reltag(&self) -> page_api::RelTag {
page_api::RelTag {
spcnode: self.spc_oid,
dbnode: self.db_oid,
relnode: self.rel_number,
forknum: self.fork_number,
}
}
}
impl CRelTruncateRequest {
pub fn reltag(&self) -> page_api::RelTag {
page_api::RelTag {
spcnode: self.spc_oid,
dbnode: self.db_oid,
relnode: self.rel_number,
forknum: self.fork_number,
}
}
}
impl CRelUnlinkRequest {
pub fn reltag(&self) -> page_api::RelTag {
page_api::RelTag {
spcnode: self.spc_oid,
dbnode: self.db_oid,
relnode: self.rel_number,
forknum: self.fork_number,
}
}
}

View File

@@ -1,28 +0,0 @@
//! C callbacks to PostgreSQL facilities that the neon extension needs
//! to provide. These are implemented in `neon/pgxn/communicator_new.c`.
//! The function signatures better match!
//!
//! These are called from the communicator threads! Careful what you do, most
//! Postgres functions are not safe to call in that context.
use utils::lsn::Lsn;
unsafe extern "C" {
pub fn notify_proc_unsafe(procno: std::ffi::c_int);
pub fn callback_set_my_latch_unsafe();
pub fn callback_get_request_lsn_unsafe() -> u64;
}
// safe wrappers
pub(super) fn notify_proc(procno: std::ffi::c_int) {
unsafe { notify_proc_unsafe(procno) };
}
pub(super) fn callback_set_my_latch() {
unsafe { callback_set_my_latch_unsafe() };
}
pub(super) fn get_request_lsn() -> Lsn {
Lsn(unsafe { callback_get_request_lsn_unsafe() })
}

View File

@@ -1,84 +0,0 @@
use std::cmp::Eq;
use std::hash::Hash;
use std::sync::Arc;
use tokio::sync::{Mutex, OwnedMutexGuard};
use clashmap::ClashMap;
use clashmap::Entry;
use pageserver_page_api::RelTag;
#[derive(Clone, Eq, Hash, PartialEq)]
pub enum RequestInProgressKey {
Db(u32),
Rel(RelTag),
Block(RelTag, u32),
}
pub type RequestInProgressTable = MutexHashSet<RequestInProgressKey>;
// more primitive locking thingie:
pub struct MutexHashSet<K>
where
K: Clone + Eq + Hash,
{
lock_table: ClashMap<K, Arc<Mutex<()>>>,
}
pub struct MutexHashSetGuard<'a, K>
where
K: Clone + Eq + Hash,
{
pub key: K,
set: &'a MutexHashSet<K>,
mutex: Arc<Mutex<()>>,
_guard: OwnedMutexGuard<()>,
}
impl<'a, K> Drop for MutexHashSetGuard<'a, K>
where
K: Clone + Eq + Hash,
{
fn drop(&mut self) {
let (_old_key, old_val) = self.set.lock_table.remove(&self.key).unwrap();
assert!(Arc::ptr_eq(&old_val, &self.mutex));
// the guard will be dropped as we return
}
}
impl<K> MutexHashSet<K>
where
K: Clone + Eq + Hash,
{
pub fn new() -> MutexHashSet<K> {
MutexHashSet {
lock_table: ClashMap::new(),
}
}
pub async fn lock<'a>(&'a self, key: K) -> MutexHashSetGuard<'a, K> {
let my_mutex = Arc::new(Mutex::new(()));
let my_guard = Arc::clone(&my_mutex).lock_owned().await;
loop {
let lock = match self.lock_table.entry(key.clone()) {
Entry::Occupied(e) => Arc::clone(e.get()),
Entry::Vacant(e) => {
e.insert(Arc::clone(&my_mutex));
break;
}
};
let _ = lock.lock().await;
}
MutexHashSetGuard {
key: key,
set: &self,
mutex: my_mutex,
_guard: my_guard,
}
}
}

View File

@@ -1,229 +0,0 @@
//! Glue code to hook up Rust logging, with the `tracing` crate, to the PostgreSQL log
//!
//! In the Rust threads, the log messages are written to a mpsc Channel, and the Postgres
//! process latch is raised. That wakes up the loop in the main thread. It reads the
//! message from the channel and ereport()s it. This ensures that only one thread, the main
//! thread, calls the PostgreSQL logging routines at any time.
use std::sync::mpsc::sync_channel;
use std::sync::mpsc::{Receiver, SyncSender};
use std::sync::mpsc::{TryRecvError, TrySendError};
use tracing::info;
use tracing::{Event, Level, Metadata, Subscriber};
use tracing_subscriber::filter::LevelFilter;
use tracing_subscriber::fmt::FmtContext;
use tracing_subscriber::fmt::FormatEvent;
use tracing_subscriber::fmt::FormatFields;
use tracing_subscriber::fmt::FormattedFields;
use tracing_subscriber::fmt::MakeWriter;
use tracing_subscriber::fmt::format::Writer;
use tracing_subscriber::registry::LookupSpan;
use crate::worker_process::callbacks::callback_set_my_latch;
pub struct LoggingState {
receiver: Receiver<FormattedEventWithMeta>,
}
/// Called once, at worker process startup. The returned LoggingState is passed back
/// in the subsequent calls to `pump_logging`. It is opaque to the C code.
#[unsafe(no_mangle)]
pub extern "C" fn configure_logging() -> Box<LoggingState> {
let (sender, receiver) = sync_channel(1000);
let maker = Maker { channel: sender };
use tracing_subscriber::prelude::*;
let r = tracing_subscriber::registry();
let r = r.with(
tracing_subscriber::fmt::layer()
.event_format(SimpleFormatter::new())
.with_writer(maker)
// TODO: derive this from log_min_messages?
.with_filter(LevelFilter::from_level(Level::INFO)),
);
r.init();
info!("communicator process logging started");
let state = LoggingState { receiver };
Box::new(state)
}
/// Read one message from the logging queue. This is essentially a wrapper to Receiver,
/// with a C-friendly signature.
///
/// The message is copied into *errbuf, which is a caller-supplied buffer of size `errbuf_len`.
/// If the message doesn't fit in the buffer, it is truncated. It is always NULL-terminated.
///
/// The error level is returned *elevel_p. It's one of the PostgreSQL error levels, see elog.h
#[unsafe(no_mangle)]
pub extern "C" fn pump_logging(
state: &mut LoggingState,
errbuf: *mut u8,
errbuf_len: u32,
elevel_p: &mut i32,
) -> i32 {
let msg = match state.receiver.try_recv() {
Err(TryRecvError::Empty) => return 0,
Err(TryRecvError::Disconnected) => return -1,
Ok(msg) => msg,
};
let src: &[u8] = &msg.message;
let dst = errbuf;
let len = std::cmp::min(src.len(), errbuf_len as usize - 1);
unsafe {
std::ptr::copy_nonoverlapping(src.as_ptr(), dst, len);
*(errbuf.add(len)) = b'\0'; // NULL terminator
}
// XXX: these levels are copied from PostgreSQL's elog.h. Introduce another enum
// to hide these?
*elevel_p = match msg.level {
Level::TRACE => 10, // DEBUG5
Level::DEBUG => 14, // DEBUG1
Level::INFO => 17, // INFO
Level::WARN => 19, // WARNING
Level::ERROR => 21, // ERROR
};
1
}
//---- The following functions can be called from any thread ----
#[derive(Clone)]
struct FormattedEventWithMeta {
message: Vec<u8>,
level: tracing::Level,
}
impl Default for FormattedEventWithMeta {
fn default() -> Self {
FormattedEventWithMeta {
message: Vec::new(),
level: tracing::Level::DEBUG,
}
}
}
struct EventBuilder<'a> {
event: FormattedEventWithMeta,
maker: &'a Maker,
}
impl std::io::Write for EventBuilder<'_> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.event.message.write(buf)
}
fn flush(&mut self) -> std::io::Result<()> {
self.maker.send_event(self.event.clone());
Ok(())
}
}
impl Drop for EventBuilder<'_> {
fn drop(&mut self) {
let maker = self.maker;
let event = std::mem::take(&mut self.event);
maker.send_event(event);
}
}
struct Maker {
channel: SyncSender<FormattedEventWithMeta>,
}
impl<'a> MakeWriter<'a> for Maker {
type Writer = EventBuilder<'a>;
fn make_writer(&'a self) -> Self::Writer {
panic!("not expected to be called when make_writer_for is implemented");
}
fn make_writer_for(&'a self, meta: &Metadata<'_>) -> Self::Writer {
EventBuilder {
event: FormattedEventWithMeta {
message: Vec::new(),
level: *meta.level(),
},
maker: self,
}
}
}
impl Maker {
fn send_event(&self, e: FormattedEventWithMeta) {
match self.channel.try_send(e) {
Ok(()) => {
// notify the main thread
callback_set_my_latch();
}
Err(TrySendError::Disconnected(_)) => {}
Err(TrySendError::Full(_)) => {
// TODO: record that some messages were lost
}
}
}
}
/// Simple formatter implementation for tracing_subscriber, which prints the log
/// spans and message part like the default formatter, but no timestamp or error
/// level. The error level is captured separately by `FormattedEventWithMeta',
/// and when the error is printed by the main thread, with PostgreSQL ereport(),
/// it gets a timestamp at that point. (The timestamp printed will therefore lag
/// behind the timestamp on the event here, if the main thread doesn't process
/// the log message promptly)
struct SimpleFormatter;
impl<S, N> FormatEvent<S, N> for SimpleFormatter
where
S: Subscriber + for<'a> LookupSpan<'a>,
N: for<'a> FormatFields<'a> + 'static,
{
fn format_event(
&self,
ctx: &FmtContext<'_, S, N>,
mut writer: Writer<'_>,
event: &Event<'_>,
) -> std::fmt::Result {
// Format all the spans in the event's span context.
if let Some(scope) = ctx.event_scope() {
for span in scope.from_root() {
write!(writer, "{}", span.name())?;
// `FormattedFields` is a formatted representation of the span's
// fields, which is stored in its extensions by the `fmt` layer's
// `new_span` method. The fields will have been formatted
// by the same field formatter that's provided to the event
// formatter in the `FmtContext`.
let ext = span.extensions();
let fields = &ext
.get::<FormattedFields<N>>()
.expect("will never be `None`");
// Skip formatting the fields if the span had no fields.
if !fields.is_empty() {
write!(writer, "{{{}}}", fields)?;
}
write!(writer, ": ")?;
}
}
// Write fields on the event
ctx.field_format().format_fields(writer.by_ref(), event)?;
writeln!(writer)
}
}
impl SimpleFormatter {
fn new() -> Self {
SimpleFormatter {}
}
}

View File

@@ -1,618 +0,0 @@
use std::collections::HashMap;
use std::os::fd::AsRawFd;
use std::os::fd::OwnedFd;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::backend_comms::NeonIOHandle;
use crate::file_cache::FileCache;
use crate::global_allocator::MyAllocatorCollector;
use crate::init::CommunicatorInitStruct;
use crate::integrated_cache::{CacheResult, IntegratedCacheWriteAccess};
use crate::neon_request::{CGetPageVRequest, CPrefetchVRequest};
use crate::neon_request::{NeonIORequest, NeonIOResult};
use crate::worker_process::in_progress_ios::{RequestInProgressKey, RequestInProgressTable};
use pageserver_client_grpc::request_tracker::ShardedRequestTracker;
use pageserver_page_api as page_api;
use metrics::{IntCounter, IntCounterVec};
use tokio::io::AsyncReadExt;
use tokio_pipe::PipeRead;
use uring_common::buf::IoBuf;
use super::callbacks::{get_request_lsn, notify_proc};
use tracing::{error, info, trace};
use utils::lsn::Lsn;
pub struct CommunicatorWorkerProcessStruct<'a> {
neon_request_slots: &'a [NeonIOHandle],
request_tracker: ShardedRequestTracker,
pub(crate) cache: IntegratedCacheWriteAccess<'a>,
submission_pipe_read_fd: OwnedFd,
next_request_id: AtomicU64,
in_progress_table: RequestInProgressTable,
// Metrics
request_counters: IntCounterVec,
request_rel_exists_counter: IntCounter,
request_rel_size_counter: IntCounter,
request_get_pagev_counter: IntCounter,
request_prefetchv_counter: IntCounter,
request_db_size_counter: IntCounter,
request_write_page_counter: IntCounter,
request_rel_extend_counter: IntCounter,
request_rel_zero_extend_counter: IntCounter,
request_rel_create_counter: IntCounter,
request_rel_truncate_counter: IntCounter,
request_rel_unlink_counter: IntCounter,
getpage_cache_misses_counter: IntCounter,
getpage_cache_hits_counter: IntCounter,
request_nblocks_counters: IntCounterVec,
request_get_pagev_nblocks_counter: IntCounter,
request_prefetchv_nblocks_counter: IntCounter,
request_rel_zero_extend_nblocks_counter: IntCounter,
allocator_metrics: MyAllocatorCollector,
}
pub(super) async fn init(
cis: Box<CommunicatorInitStruct>,
tenant_id: String,
timeline_id: String,
auth_token: Option<String>,
mut shard_map: HashMap<utils::shard::ShardIndex, String>,
initial_file_cache_size: u64,
file_cache_path: Option<PathBuf>,
) -> CommunicatorWorkerProcessStruct<'static> {
info!("Test log message");
let last_lsn = get_request_lsn();
let file_cache = if let Some(path) = file_cache_path {
Some(FileCache::new(&path, initial_file_cache_size).expect("could not create cache file"))
} else {
// FIXME: temporarily for testing, use LFC even if disabled
Some(
FileCache::new(&PathBuf::from("new_filecache"), 1000)
.expect("could not create cache file"),
)
};
// TODO: for now, just hack in the gRPC port number. This needs to be plumbed through.
for connstr in shard_map.values_mut() {
*connstr = connstr.replace(":64000", ":51051");
}
tracing::warn!("mangled connstrings to use gRPC port 51051 shard_map={shard_map:?}");
// Initialize subsystems
let cache = cis
.integrated_cache_init_struct
.worker_process_init(last_lsn, file_cache);
let mut request_tracker = ShardedRequestTracker::new();
request_tracker.update_shard_map(shard_map,
None,
tenant_id,
timeline_id,
auth_token.as_deref()).await;
let request_counters = IntCounterVec::new(
metrics::core::Opts::new(
"backend_requests_total",
"Number of requests from backends.",
),
&["request_kind"],
)
.unwrap();
let request_rel_exists_counter = request_counters.with_label_values(&["rel_exists"]);
let request_rel_size_counter = request_counters.with_label_values(&["rel_size"]);
let request_get_pagev_counter = request_counters.with_label_values(&["get_pagev"]);
let request_prefetchv_counter = request_counters.with_label_values(&["prefetchv"]);
let request_db_size_counter = request_counters.with_label_values(&["db_size"]);
let request_write_page_counter = request_counters.with_label_values(&["write_page"]);
let request_rel_extend_counter = request_counters.with_label_values(&["rel_extend"]);
let request_rel_zero_extend_counter = request_counters.with_label_values(&["rel_zero_extend"]);
let request_rel_create_counter = request_counters.with_label_values(&["rel_create"]);
let request_rel_truncate_counter = request_counters.with_label_values(&["rel_truncate"]);
let request_rel_unlink_counter = request_counters.with_label_values(&["rel_unlink"]);
let getpage_cache_misses_counter = IntCounter::new(
"getpage_cache_misses",
"Number of file cache misses in get_pagev requests.",
)
.unwrap();
let getpage_cache_hits_counter = IntCounter::new(
"getpage_cache_hits",
"Number of file cache hits in get_pagev requests.",
)
.unwrap();
// For the requests that affect multiple blocks, have separate counters for the # of blocks affected
let request_nblocks_counters = IntCounterVec::new(
metrics::core::Opts::new(
"request_nblocks_total",
"Number of blocks in backend requests.",
),
&["request_kind"],
)
.unwrap();
let request_get_pagev_nblocks_counter =
request_nblocks_counters.with_label_values(&["get_pagev"]);
let request_prefetchv_nblocks_counter =
request_nblocks_counters.with_label_values(&["prefetchv"]);
let request_rel_zero_extend_nblocks_counter =
request_nblocks_counters.with_label_values(&["rel_zero_extend"]);
CommunicatorWorkerProcessStruct {
neon_request_slots: cis.neon_request_slots,
request_tracker,
cache,
submission_pipe_read_fd: cis.submission_pipe_read_fd,
next_request_id: AtomicU64::new(1),
in_progress_table: RequestInProgressTable::new(),
// metrics
request_counters,
request_rel_exists_counter,
request_rel_size_counter,
request_get_pagev_counter,
request_prefetchv_counter,
request_db_size_counter,
request_write_page_counter,
request_rel_extend_counter,
request_rel_zero_extend_counter,
request_rel_create_counter,
request_rel_truncate_counter,
request_rel_unlink_counter,
getpage_cache_misses_counter,
getpage_cache_hits_counter,
request_nblocks_counters,
request_get_pagev_nblocks_counter,
request_prefetchv_nblocks_counter,
request_rel_zero_extend_nblocks_counter,
allocator_metrics: MyAllocatorCollector::new(),
}
}
impl<'t> CommunicatorWorkerProcessStruct<'t> {
/// Main loop of the worker process. Receive requests from the backends and process them.
pub(super) async fn run(self: &'static Self) {
let mut idxbuf: [u8; 4] = [0; 4];
let mut submission_pipe_read =
PipeRead::try_from(self.submission_pipe_read_fd.as_raw_fd()).expect("invalid pipe fd");
loop {
// Wait for a backend to ring the doorbell
match submission_pipe_read.read(&mut idxbuf).await {
Ok(4) => {}
Ok(nbytes) => panic!("short read ({nbytes} bytes) on communicator pipe"),
Err(e) => panic!("error reading from communicator pipe: {e}"),
}
let request_idx = u32::from_ne_bytes(idxbuf);
// Read the IO request from the slot indicated in the wakeup
let Some(slot) =
self.neon_request_slots[request_idx as usize].start_processing_request()
else {
// This currently should not happen. But if we have multiple threads picking up
// requests, and without waiting for the notifications, it could.
panic!("no request in slot");
};
// Ok, we have ownership of this request now. We must process
// it now, there's no going back.
//trace!("processing request {request_idx}: {request:?}");
// Spawn a separate task for every request. That's a little excessive for requests that
// can be quickly satisfied from the cache, but we expect that to be rare, because the
// requesting backend would have already checked the cache.
tokio::spawn(async {
let result = self.handle_request(slot.get_request()).await;
let owner_procno = slot.get_owner_procno();
// Ok, we have completed the IO. Mark the request as completed. After that,
// we no longer have ownership of the slot, and must not modify it.
slot.completed(result);
// Notify the backend about the completion. (Note that the backend might see
// the completed status even before this; this is just a wakeup)
notify_proc(owner_procno);
});
}
}
fn request_lsns(&self, not_modified_since_lsn: Lsn) -> page_api::ReadLsn {
page_api::ReadLsn {
request_lsn: get_request_lsn(),
not_modified_since_lsn: Some(not_modified_since_lsn),
}
}
async fn handle_request<'x>(self: &'static Self, req: &'x NeonIORequest) -> NeonIOResult {
match req {
NeonIORequest::Empty => {
error!("unexpected Empty IO request");
NeonIOResult::Error(0)
}
NeonIORequest::RelExists(req) => {
self.request_rel_exists_counter.inc();
let rel = req.reltag();
let _in_progress_guard = self
.in_progress_table
.lock(RequestInProgressKey::Rel(rel.clone()));
let not_modified_since = match self.cache.get_rel_exists(&rel) {
CacheResult::Found(exists) => return NeonIOResult::RelExists(exists),
CacheResult::NotFound(lsn) => lsn,
};
match self
.request_tracker
.process_check_rel_exists_request(page_api::CheckRelExistsRequest {
read_lsn: self.request_lsns(not_modified_since),
rel,
})
.await
{
Ok(exists) => NeonIOResult::RelExists(exists),
Err(err) => {
info!("tonic error: {err:?}");
NeonIOResult::Error(0)
}
}
}
NeonIORequest::RelSize(req) => {
self.request_rel_size_counter.inc();
let rel = req.reltag();
let _in_progress_guard = self
.in_progress_table
.lock(RequestInProgressKey::Rel(rel.clone()));
// Check the cache first
let not_modified_since = match self.cache.get_rel_size(&rel) {
CacheResult::Found(nblocks) => {
tracing::trace!("found relsize for {:?} in cache: {}", rel, nblocks);
return NeonIOResult::RelSize(nblocks);
}
CacheResult::NotFound(lsn) => lsn,
};
let read_lsn = self.request_lsns(not_modified_since);
match self
.request_tracker
.process_get_rel_size_request(page_api::GetRelSizeRequest {
read_lsn,
rel: rel.clone(),
})
.await
{
Ok(nblocks) => {
// update the cache
tracing::info!("updated relsize for {:?} in cache: {}", rel, nblocks);
self.cache.remember_rel_size(&rel, nblocks);
NeonIOResult::RelSize(nblocks)
}
Err(err) => {
info!("tonic error: {err:?}");
NeonIOResult::Error(0)
}
}
}
NeonIORequest::GetPageV(req) => {
self.request_get_pagev_counter.inc();
self.request_get_pagev_nblocks_counter
.inc_by(req.nblocks as u64);
match self.handle_get_pagev_request(req).await {
Ok(()) => NeonIOResult::GetPageV,
Err(errno) => NeonIOResult::Error(errno),
}
}
NeonIORequest::PrefetchV(req) => {
self.request_prefetchv_counter.inc();
self.request_prefetchv_nblocks_counter
.inc_by(req.nblocks as u64);
let req = req.clone();
tokio::spawn(async move { self.handle_prefetchv_request(&req).await });
NeonIOResult::PrefetchVLaunched
}
NeonIORequest::DbSize(req) => {
self.request_db_size_counter.inc();
let _in_progress_guard = self
.in_progress_table
.lock(RequestInProgressKey::Db(req.db_oid));
// Check the cache first
let not_modified_since = match self.cache.get_db_size(req.db_oid) {
CacheResult::Found(db_size) => {
// get_page already copied the block content to the destination
return NeonIOResult::DbSize(db_size);
}
CacheResult::NotFound(lsn) => lsn,
};
match self
.request_tracker
.process_get_dbsize_request(page_api::GetDbSizeRequest {
read_lsn: self.request_lsns(not_modified_since),
db_oid: req.db_oid,
})
.await
{
Ok(db_size) => NeonIOResult::DbSize(db_size),
Err(err) => {
info!("tonic error: {err:?}");
NeonIOResult::Error(0)
}
}
}
// Write requests
NeonIORequest::WritePage(req) => {
self.request_write_page_counter.inc();
// Also store it in the LFC while we still have it
let rel = req.reltag();
let _in_progress_guard = self
.in_progress_table
.lock(RequestInProgressKey::Block(rel.clone(), req.block_number));
self.cache
.remember_page(&rel, req.block_number, req.src, Lsn(req.lsn), true)
.await;
NeonIOResult::WriteOK
}
NeonIORequest::RelExtend(req) => {
self.request_rel_extend_counter.inc();
// TODO: need to grab an io-in-progress lock for this? I guess not
self.cache
.remember_rel_size(&req.reltag(), req.block_number + 1);
NeonIOResult::WriteOK
}
NeonIORequest::RelZeroExtend(req) => {
self.request_rel_zero_extend_counter.inc();
self.request_rel_zero_extend_nblocks_counter
.inc_by(req.nblocks as u64);
// TODO: need to grab an io-in-progress lock for this? I guess not
self.cache
.remember_rel_size(&req.reltag(), req.block_number + req.nblocks);
NeonIOResult::WriteOK
}
NeonIORequest::RelCreate(req) => {
self.request_rel_create_counter.inc();
// TODO: need to grab an io-in-progress lock for this? I guess not
self.cache.remember_rel_size(&req.reltag(), 0);
NeonIOResult::WriteOK
}
NeonIORequest::RelTruncate(req) => {
self.request_rel_truncate_counter.inc();
// TODO: need to grab an io-in-progress lock for this? I guess not
self.cache.remember_rel_size(&req.reltag(), req.nblocks);
NeonIOResult::WriteOK
}
NeonIORequest::RelUnlink(req) => {
self.request_rel_unlink_counter.inc();
// TODO: need to grab an io-in-progress lock for this? I guess not
self.cache.forget_rel(&req.reltag());
NeonIOResult::WriteOK
}
}
}
async fn handle_get_pagev_request(&'t self, req: &CGetPageVRequest) -> Result<(), i32> {
let rel = req.reltag();
// Check the cache first
//
// Note: Because the backends perform a direct lookup in the cache before sending
// the request to the communicator process, we expect the pages to almost never
// be already in cache. It could happen when:
// 1. two backends try to read the same page at the same time, but that should never
// happen because there's higher level locking in the Postgres buffer manager, or
// 2. if a prefetch request finished at the same time as a backend requested the
// page. That's much more likely.
let mut cache_misses = Vec::with_capacity(req.nblocks as usize);
for i in 0..req.nblocks {
let blkno = req.block_number + i as u32;
// note: this is deadlock-safe even though we hold multiple locks at the same time,
// because they're always acquired in the same order.
let in_progress_guard = self
.in_progress_table
.lock(RequestInProgressKey::Block(rel.clone(), blkno))
.await;
let dest = req.dest[i as usize];
let not_modified_since = match self.cache.get_page(&rel, blkno, dest).await {
Ok(CacheResult::Found(_)) => {
// get_page already copied the block content to the destination
trace!("found blk {} in rel {:?} in LFC", blkno, rel);
continue;
}
Ok(CacheResult::NotFound(lsn)) => lsn,
Err(_io_error) => return Err(-1), // FIXME errno?
};
cache_misses.push((blkno, not_modified_since, dest, in_progress_guard));
}
self.getpage_cache_misses_counter
.inc_by(cache_misses.len() as u64);
self.getpage_cache_hits_counter
.inc_by(req.nblocks as u64 - cache_misses.len() as u64);
if cache_misses.is_empty() {
return Ok(());
}
let not_modified_since = cache_misses
.iter()
.map(|(_blkno, lsn, _dest, _guard)| *lsn)
.max()
.unwrap();
// TODO: Use batched protocol
for (blkno, _lsn, dest, _guard) in cache_misses.iter() {
match self
.request_tracker
.get_page(page_api::GetPageRequest {
request_id: self.next_request_id.fetch_add(1, Ordering::Relaxed),
request_class: page_api::GetPageClass::Normal,
read_lsn: self.request_lsns(not_modified_since),
rel: rel.clone(),
block_numbers: vec![*blkno],
})
.await
{
Ok(resp) => {
// Write the received page image directly to the shared memory location
// that the backend requested.
assert!(resp.page_images.len() == 1);
let page_image = resp.page_images[0].clone();
let src: &[u8] = page_image.as_ref();
let len = std::cmp::min(src.len(), dest.bytes_total() as usize);
unsafe {
std::ptr::copy_nonoverlapping(src.as_ptr(), dest.as_mut_ptr(), len);
};
// Also store it in the LFC while we have it
self.cache
.remember_page(&rel, *blkno, page_image, not_modified_since, false)
.await;
}
Err(err) => {
info!("tonic error: {err:?}");
return Err(-1);
}
}
}
Ok(())
}
async fn handle_prefetchv_request(
self: &'static Self,
req: &CPrefetchVRequest,
) -> Result<(), i32> {
let rel = req.reltag();
// Check the cache first
let mut cache_misses = Vec::with_capacity(req.nblocks as usize);
for i in 0..req.nblocks {
let blkno = req.block_number + i as u32;
// note: this is deadlock-safe even though we hold multiple locks at the same time,
// because they're always acquired in the same order.
let in_progress_guard = self
.in_progress_table
.lock(RequestInProgressKey::Block(rel.clone(), blkno))
.await;
let not_modified_since = match self.cache.page_is_cached(&rel, blkno).await {
Ok(CacheResult::Found(_)) => {
trace!("found blk {} in rel {:?} in LFC", blkno, rel);
continue;
}
Ok(CacheResult::NotFound(lsn)) => lsn,
Err(_io_error) => return Err(-1), // FIXME errno?
};
cache_misses.push((blkno, not_modified_since, in_progress_guard));
}
if cache_misses.is_empty() {
return Ok(());
}
let not_modified_since = cache_misses
.iter()
.map(|(_blkno, lsn, _guard)| *lsn)
.max()
.unwrap();
// TODO: spawn separate tasks for these. Use the integrated cache to keep track of the
// in-flight requests
// TODO: Use batched protocol
for (blkno, _lsn, _guard) in cache_misses.iter() {
match self
.request_tracker
.get_page(page_api::GetPageRequest {
request_id: self.next_request_id.fetch_add(1, Ordering::Relaxed),
request_class: page_api::GetPageClass::Prefetch,
read_lsn: self.request_lsns(not_modified_since),
rel: rel.clone(),
block_numbers: vec![*blkno],
})
.await
{
Ok(resp) => {
trace!(
"prefetch completed, remembering blk {} in rel {:?} in LFC",
*blkno, rel
);
assert!(resp.page_images.len() == 1);
let page_image = resp.page_images[0].clone();
self.cache
.remember_page(&rel, *blkno, page_image, not_modified_since, false)
.await;
}
Err(err) => {
info!("tonic error: {err:?}");
return Err(-1);
}
}
}
Ok(())
}
}
impl<'t> metrics::core::Collector for CommunicatorWorkerProcessStruct<'t> {
fn desc(&self) -> Vec<&metrics::core::Desc> {
let mut descs = Vec::new();
descs.append(&mut self.request_counters.desc());
descs.append(&mut self.getpage_cache_misses_counter.desc());
descs.append(&mut self.getpage_cache_hits_counter.desc());
descs.append(&mut self.request_nblocks_counters.desc());
if let Some(file_cache) = &self.cache.file_cache {
descs.append(&mut file_cache.desc());
}
descs.append(&mut self.cache.desc());
descs.append(&mut self.allocator_metrics.desc());
descs
}
fn collect(&self) -> Vec<metrics::proto::MetricFamily> {
let mut values = Vec::new();
values.append(&mut self.request_counters.collect());
values.append(&mut self.getpage_cache_misses_counter.collect());
values.append(&mut self.getpage_cache_hits_counter.collect());
values.append(&mut self.request_nblocks_counters.collect());
if let Some(file_cache) = &self.cache.file_cache {
values.append(&mut file_cache.collect());
}
values.append(&mut self.cache.collect());
values.append(&mut self.allocator_metrics.collect());
values
}
}

View File

@@ -1,83 +0,0 @@
//! Export information about Postgres, the communicator process, file cache etc. as
//! prometheus metrics.
use axum::Router;
use axum::body::Body;
use axum::extract::State;
use axum::response::Response;
use http::StatusCode;
use http::header::CONTENT_TYPE;
use metrics;
use metrics::proto::MetricFamily;
use metrics::{Encoder, TextEncoder};
use std::path::PathBuf;
use tokio::net::UnixListener;
use crate::worker_process::main_loop::CommunicatorWorkerProcessStruct;
impl<'a> CommunicatorWorkerProcessStruct<'a> {
pub(crate) async fn launch_exporter_task(&'static self) {
use axum::routing::get;
let app = Router::new()
.route("/metrics", get(get_metrics))
.route("/dump_cache_map", get(dump_cache_map))
.with_state(self);
// Listen on unix domain socket, in the data directory. That should be unique.
let path = PathBuf::from(".metrics.socket");
let listener = UnixListener::bind(path.clone()).unwrap();
tokio::spawn(async {
tracing::info!("metrics listener spawned");
axum::serve(listener, app).await.unwrap()
});
}
}
async fn dump_cache_map(
State(state): State<&CommunicatorWorkerProcessStruct<'static>>,
) -> Response {
let mut buf: Vec<u8> = Vec::new();
state.cache.dump_map(&mut buf);
Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, "application/text")
.body(Body::from(buf))
.unwrap()
}
/// Expose Prometheus metrics.
async fn get_metrics(State(state): State<&CommunicatorWorkerProcessStruct<'static>>) -> Response {
use metrics::core::Collector;
let metrics = state.collect();
// When we call TextEncoder::encode() below, it will immediately return an
// error if a metric family has no metrics, so we need to preemptively
// filter out metric families with no metrics.
let metrics = metrics
.into_iter()
.filter(|m| !m.get_metric().is_empty())
.collect::<Vec<MetricFamily>>();
let encoder = TextEncoder::new();
let mut buffer = vec![];
if let Err(e) = encoder.encode(&metrics, &mut buffer) {
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header(CONTENT_TYPE, "application/text")
.body(Body::from(e.to_string()))
.unwrap()
} else {
Response::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, encoder.format_type())
.body(Body::from(buffer))
.unwrap()
}
}

View File

@@ -1,14 +0,0 @@
//! This code runs in the communicator worker process. This provides
//! the glue code to:
//!
//! - launch the 'processor',
//! - receive IO requests from backends and pass them to the processor,
//! - write results back to backends.
mod callbacks;
mod logging;
mod main_loop;
mod metrics_exporter;
mod worker_interface;
mod in_progress_ios;

View File

@@ -1,112 +0,0 @@
//! Functions called from the C code in the worker process
use std::collections::HashMap;
use std::ffi::{CStr, c_char};
use std::path::PathBuf;
use tracing::error;
use crate::init::CommunicatorInitStruct;
use crate::worker_process::main_loop;
use crate::worker_process::main_loop::CommunicatorWorkerProcessStruct;
/// Launch the communicator's tokio tasks, which do most of the work.
///
/// The caller has initialized the process as a regular PostgreSQL
/// background worker process. The shared memory segment used to
/// communicate with the backends has been allocated and initialized
/// earlier, at postmaster startup, in rcommunicator_shmem_init().
#[unsafe(no_mangle)]
pub extern "C" fn communicator_worker_process_launch(
cis: Box<CommunicatorInitStruct>,
tenant_id: *const c_char,
timeline_id: *const c_char,
auth_token: *const c_char,
shard_map: *mut *mut c_char,
nshards: u32,
file_cache_path: *const c_char,
initial_file_cache_size: u64,
) -> &'static CommunicatorWorkerProcessStruct<'static> {
// Convert the arguments into more convenient Rust types
let tenant_id = unsafe { CStr::from_ptr(tenant_id) }.to_str().unwrap();
let timeline_id = unsafe { CStr::from_ptr(timeline_id) }.to_str().unwrap();
let auth_token = unsafe { auth_token.as_ref() }.map(|s| s.to_string());
let file_cache_path = {
if file_cache_path.is_null() {
None
} else {
let c_str = unsafe { CStr::from_ptr(file_cache_path) };
Some(PathBuf::from(c_str.to_str().unwrap()))
}
};
let shard_map = parse_shard_map(nshards, shard_map);
// start main loop
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.thread_name("communicator thread")
.build()
.unwrap();
let worker_struct = runtime.block_on(main_loop::init(
cis,
tenant_id.to_string(),
timeline_id.to_string(),
auth_token,
shard_map,
initial_file_cache_size,
file_cache_path,
));
let worker_struct = Box::leak(Box::new(worker_struct));
let main_loop_handle = runtime.spawn(worker_struct.run());
runtime.spawn(async {
let err = main_loop_handle.await.unwrap_err();
error!("error: {err:?}");
});
runtime.block_on(worker_struct.launch_exporter_task());
// keep the runtime running after we exit this function
Box::leak(Box::new(runtime));
worker_struct
}
/// Convert the "shard map" from an array of C strings, indexed by shard no to a rust HashMap
fn parse_shard_map(
nshards: u32,
shard_map: *mut *mut c_char,
) -> HashMap<utils::shard::ShardIndex, String> {
use utils::shard::*;
assert!(nshards <= u8::MAX as u32);
let mut result: HashMap<ShardIndex, String> = HashMap::new();
let mut p = shard_map;
for i in 0..nshards {
let c_str = unsafe { CStr::from_ptr(*p) };
p = unsafe { p.add(1) };
let s = c_str.to_str().unwrap();
let k = if nshards > 1 {
ShardIndex::new(ShardNumber(i as u8), ShardCount(nshards as u8))
} else {
ShardIndex::unsharded()
};
result.insert(k, s.into());
}
result
}
/// Inform the rust code about a configuration change
#[unsafe(no_mangle)]
pub extern "C" fn communicator_worker_config_reload(
proc_handle: &'static CommunicatorWorkerProcessStruct<'static>,
file_cache_size: u64,
) {
proc_handle.cache.resize_file_cache(file_cache_size as u32);
}

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More