Compare commits

19 Commits

Author SHA1 Message Date
Conrad Ludgate
26be13067c [proxy] refactor logging ID system 2025-07-18 22:21:48 +01:00
Paul Banks
791b5d736b Fixes #10441: control_plane README incorrect neon init args (#12646)
## Problem

As reported in #10441, the `control_plane/README.md` incorrectly
stated that `--pg-version` should be passed to the `cargo neon
init` command. This is not the case and causes an invalid argument
error.

## Summary of changes

Fix the README

## Test Plan

I verified that the steps in the README now work locally. I connected to
the started postgres endpoint and executed some basic metadata queries.
2025-07-18 17:09:20 +00:00
Krzysztof Szafrański
96bcfba79e [proxy] Cache GetEndpointAccessControl errors (#12571)
Related to https://github.com/neondatabase/cloud/issues/19353
2025-07-18 10:17:58 +00:00
Shockingly Good
8e95455aef Update the postgres submodules (#12636)
Synchronises the main branch's postgres submodules with the
`neondatabase/postgres` repository state.
2025-07-18 08:21:22 +00:00
Alex Chi Z.
f3ef60d236 fix(storcon): use unified interface to handle 404 lsn lease (#12650)
## Problem

Close LKB-270. This is part of our series of efforts to make sure
lsn_lease API prompts clients to retry. Follow up of
https://github.com/neondatabase/neon/pull/12631.

Slack thread w/ Vlad:
https://databricks.slack.com/archives/C09254R641L/p1752677940697529

## Summary of changes

- Use `tenant_remote_mutation` API for LSN leases. Makes it consistent
with new APIs added to storcon.
- For 404, we now always retry because we know the tenant is
to-be-attached and will eventually reach a point that we can find that
tenant on the intent pageserver.
- Using the `tenant_remote_mutation` API also protects us from the case
where the intent pageserver changes during the lease request. The
wrapper function returns a 503 error if that happens.

---------

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-07-18 04:40:35 +00:00
HaoyuHuang
8f627ea0ab A few more SC changes (#12649)
## Problem

## Summary of changes
2025-07-17 23:17:01 +00:00
Arpad Müller
6a353c33e3 print more timestamps in find_lsn_for_timestamp (#12641)
Observability of `find_lsn_for_timestamp` is lacking, as well as how and
when we update gc space and time cutoffs. Log them.
2025-07-17 22:13:21 +00:00
Folke Behrens
64d0008389 proxy: Shorten the initial TTL of cancel keys (#12647)
## Problem

A high rate of short-lived connections means that there are a lot of
cancel keys in Redis with TTL=10min that could be avoided by having a
much shorter initial TTL.

## Summary of changes

* Introduce an initial TTL of 1min used with the SET command.
* Fix: don't delay repushing cancel data when expired.
* Prepare for exponentially increasing TTLs.
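
The TTL scheme listed above could be sketched roughly as follows. This is a hypothetical illustration, not the proxy's code: `INITIAL_TTL`, `MAX_TTL`, and `ttl_for_refresh` are made-up names, and doubling on each refresh is an assumption (the commit only says it prepares for exponentially increasing TTLs). The 1 min and 10 min values come from the commit message.

```rust
use std::time::Duration;

/// Initial TTL used with the first SET (1 min, per the commit message).
const INITIAL_TTL: Duration = Duration::from_secs(60);
/// Upper bound; 10 min is the previous fixed TTL named in the commit message.
const MAX_TTL: Duration = Duration::from_secs(600);

/// Hypothetical helper: TTL for the n-th refresh of a cancel key (0 = initial SET).
/// Assumes the TTL doubles on every refresh until it reaches `MAX_TTL`.
fn ttl_for_refresh(n: u32) -> Duration {
    // 2^n overflows u32 for n >= 32; saturate instead of panicking.
    let factor = 2u32.checked_pow(n).unwrap_or(u32::MAX);
    INITIAL_TTL.saturating_mul(factor).min(MAX_TTL)
}
```

With this schedule a connection that closes within the first minute leaves behind a key that expires after 60 seconds, while long-lived connections converge on the old 10 min TTL after a few refreshes.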

## Alternatives

A best-effort UNLINK command on connection termination would clean up
cancel keys right away. This needs a bigger refactor due to how batching
is handled.
2025-07-17 21:52:20 +00:00
Alexey Kondratov
53a05e8ccb fix(compute_ctl): Only offload LFC state if no prewarming is in progress (#12645)
## Problem

We currently offload LFC state unconditionally, which can cause
problems. Imagine a situation:
1. Endpoint started with `autoprewarm: true`.
2. While prewarming is not completed, we upload the new incomplete
state.
3. Compute gets interrupted and restarts.
4. We start again and try to prewarm with the state from 2. instead of
the previous complete state.

During the orchestrated prewarming, it's probably not a big issue, but
it's still better not to interfere with the prewarm process.

## Summary of changes

Do not offload LFC state if we are currently prewarming or any issue
occurred. While at it, also introduce a `Skipped` LFC prewarm status,
which is used when the corresponding LFC state is not present in the
endpoint storage. It's primarily needed to distinguish the first compute
start for a particular endpoint, as it's completely valid not to have
LFC state yet.
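
The offload guard described above can be sketched in isolation. The enum below is a simplified copy of the states named in this change; `should_offload` is a hypothetical helper name used only for illustration (the actual change inlines a `matches!` check in the offload loop).

```rust
/// Simplified copy of the LFC prewarm states named in this change.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
enum LfcPrewarmState {
    NotPrewarmed,
    Prewarming,
    Completed,
    Failed { error: String },
    Skipped,
}

/// Hypothetical helper: offloading is allowed only when no prewarm is in
/// progress and the last prewarm did not fail, so an incomplete LFC state
/// never overwrites a complete one in endpoint storage.
fn should_offload(state: &LfcPrewarmState) -> bool {
    matches!(
        state,
        LfcPrewarmState::Completed | LfcPrewarmState::NotPrewarmed | LfcPrewarmState::Skipped
    )
}
```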
2025-07-17 21:43:43 +00:00
Vlad Lazar
62c0152e6b pageserver: shut down compute connections at libpq level (#12642)
## Problem

Previously, if a get page failure was caused by timeline shutdown, the
pageserver would attempt to tear down the connection gracefully:
`shutdown(SHUT_WR)` followed by `close()`.

This triggers a code path on the compute where it has to distinguish
between an idle connection and a closed one. That code is bug-prone, so
we can just side-step the issue by shutting down the connection via a
libpq error message.

This surfaced as instability in test_shard_resolve_during_split_abort.
It's a new test, but the issue existed for ages.

## Summary of Changes

Send a libpq error message instead of doing graceful TCP connection
shutdown.

Closes LKB-648
2025-07-17 21:03:55 +00:00
Konstantin Knizhnik
7fef4435c1 Store stripe_size in shared memory (#12560)
## Problem

See https://databricks.slack.com/archives/C09254R641L/p1752004515032899

A stripe_size GUC update may reach different backends at different
times and so cause inconsistency with connection strings (shard map).

## Summary of changes

Postmaster should store stripe_size in shared memory, as it does the
connection strings.
It should also be enforced that the stripe size is defined prior to the
connection strings in postgresql.conf
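
The ordering requirement can be illustrated with a small sketch that emits the two settings in the required order. `storage_settings` is a hypothetical function, and the single-quote wrapping is a simplification: the real config writer uses `escape_conf_value`, whose behavior is not reproduced here.

```rust
use std::fmt::Write;

/// Hypothetical helper: render the storage settings so that
/// `neon.stripe_size` always precedes `neon.pageserver_connstring`.
fn storage_settings(stripe_size: Option<u32>, pageserver_connstring: Option<&str>) -> String {
    let mut out = String::new();
    // Stripe size first, so it is already known when the connection string is processed.
    if let Some(size) = stripe_size {
        writeln!(out, "neon.stripe_size={size}").unwrap();
    }
    // Assumption: naive quoting, no escaping of embedded quotes.
    if let Some(connstring) = pageserver_connstring {
        writeln!(out, "neon.pageserver_connstring='{connstring}'").unwrap();
    }
    out
}
```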

---------

Co-authored-by: Konstantin Knizhnik <knizhnik@neon.tech>
Co-authored-by: Kosntantin Knizhnik <konstantin.knizhnik@databricks.com>
2025-07-17 20:32:34 +00:00
Konstantin Knizhnik
43fd5b218b Refactor shmem initialization in Neon extension (#12630)
## Problem

Initializing shared memory in an extension is complex and non-portable.
In the neon extension, this boilerplate code is duplicated in several files.

## Summary of changes

Perform all initialization in one place - neon.c
All other modules provide *ShmemRequest() and *ShmemInit() functions
which are called from neon.c

---------

Co-authored-by: Kosntantin Knizhnik <konstantin.knizhnik@databricks.com>
Co-authored-by: Heikki Linnakangas <heikki@neon.tech>
2025-07-17 20:20:38 +00:00
Alex Chi Z.
29ee273d78 fix(storcon): correctly converts 404 for tenant passthrough requests (#12631)
## Problem

Follow up of https://github.com/neondatabase/neon/pull/12620

Discussions:
https://databricks.slack.com/archives/C09254R641L/p1752677940697529

Both the original code and the code after the patch above convert 404s
to 503s regardless of the type of 404. We should only do that for tenant
not found errors. For other 404s, like timeline not found, we should not
prompt clients to retry.

## Summary of changes

- Inspect the response body to figure out the type of 404. If it's a
tenant not found error, return 503.
- Otherwise, fallthrough and return 404 as-is.
- Add `tenant_shard_remote_mutation` that manipulates a single shard.
- Use `Service::tenant_shard_remote_mutation` for tenant shard
passthrough requests. This protects us from another race, where the
attach state changes within the request. (This patch mainly addresses
the case where the tenant is "not yet attached".)
- TODO: lease API is still using the old code path. We should refactor
it to use `tenant_remote_mutation`.
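
The status mapping from the first two bullets can be sketched as below. Everything here is hypothetical: `classify_404` and `passthrough_status` are illustrative names, and the substring check is a placeholder, since the commit message does not say how the response body is matched.

```rust
/// Kind of 404 returned by the pageserver.
#[derive(Debug, PartialEq)]
enum NotFoundKind {
    Tenant,
    Other,
}

/// Placeholder classification (assumption): treat a body containing
/// "tenant not found" as a tenant-level 404, anything else as another 404.
fn classify_404(body: &str) -> NotFoundKind {
    if body.to_ascii_lowercase().contains("tenant not found") {
        NotFoundKind::Tenant
    } else {
        NotFoundKind::Other
    }
}

/// Map an upstream status to the status returned to the client:
/// tenant-not-found becomes 503 (retryable), everything else passes through.
fn passthrough_status(upstream_status: u16, body: &str) -> u16 {
    match (upstream_status, classify_404(body)) {
        (404, NotFoundKind::Tenant) => 503,
        (status, _) => status,
    }
}
```

The point of the split is that a 503 prompts clients to retry, which is only appropriate when the tenant is expected to become available; a missing timeline stays a plain 404.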

---------

Signed-off-by: Alex Chi Z <chi@neon.tech>
2025-07-17 19:42:48 +00:00
Conrad Ludgate
8b0f2efa57 experiment with an InfoMetrics metric family (#12612)
Putting this in the neon codebase for now, to experiment. Can be lifted
into measured at a later date.

This metric family is like a MetricVec, but it only supports 1 label
being set at a time. It is useful for reporting info, rather than
reporting metrics.
https://www.robustperception.io/exposing-the-software-version-to-prometheus/
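
As a rough illustration of the idea, independent of the `measured` crate, an info metric is a constant gauge of 1 whose single label set can be swapped at runtime. The `InfoMetricSketch` type and its methods below are hypothetical and exist only for this sketch; label values are not escaped.

```rust
use std::sync::RwLock;

/// Minimal stand-in for an info metric: the value is always 1 and the
/// information lives in the one label set that is currently active.
struct InfoMetricSketch {
    labels: RwLock<Vec<(String, String)>>,
}

impl InfoMetricSketch {
    fn new(labels: Vec<(String, String)>) -> Self {
        Self { labels: RwLock::new(labels) }
    }

    /// Replace the active label set; the previous one is no longer reported.
    fn set_labels(&self, labels: Vec<(String, String)>) {
        *self.labels.write().unwrap() = labels;
    }

    /// Render in the Prometheus text exposition format (no value escaping).
    fn render(&self, name: &str) -> String {
        let labels = self.labels.read().unwrap();
        let pairs: Vec<String> = labels
            .iter()
            .map(|(k, v)| format!("{k}=\"{v}\""))
            .collect();
        format!("# TYPE {name} gauge\n{name}{{{}}} 1\n", pairs.join(","))
    }
}
```

Unlike a regular metric vector, only one series is ever exported, so changing the label does not leave a stale series behind.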
2025-07-17 17:58:47 +00:00
quantumish
b309cbc6e9 Add resizable hashmap and RwLock implementations to neon-shmem (#12596)
Second PR for the hashmap behind the updated LFC implementation ([see
first here](https://github.com/neondatabase/neon/pull/12595)). This only
adds the raw code for the hashmap/lock implementations and doesn't plug
it into the crate (that's dependent on the previous PR and should
probably be done when the full integration into the new communicator is
merged alongside `communicator-rewrite` changes?).

Some high level details: the communicator codebase expects to be able to
store references to entries within this hashmap for arbitrary periods of
time and so the hashmap cannot be allowed to move them during a rehash.
As a result, this implementation has a slightly unusual structure where
key-value pairs (and hash chains) are allocated in a separate region
with a freelist. The core hashmap structure is then an array of
"dictionary entries" that are just indexes into this region of key-value
pairs.
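
The layout described above can be sketched as a tiny in-process map. This is a simplified illustration under several assumptions, not the neon-shmem implementation: `TinyMap` is a made-up name, it uses heap `Vec`s instead of shared memory, has a fixed capacity, and omits removal, resizing, and locking.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

const INVALID: u32 = u32::MAX;

struct Bucket<K, V> {
    /// Next bucket in the same hash chain, or next free bucket if `inner` is `None`.
    next: u32,
    inner: Option<(K, V)>,
}

struct TinyMap<K, V> {
    /// Hash slot -> index of the first bucket in that slot's chain.
    dictionary: Vec<u32>,
    /// Key-value storage; an entry keeps its index for as long as it lives.
    buckets: Vec<Bucket<K, V>>,
    free_head: u32,
}

impl<K: Hash + Eq, V> TinyMap<K, V> {
    fn new(num_buckets: u32) -> Self {
        // Initially every bucket is free and linked into the freelist.
        let buckets = (0..num_buckets)
            .map(|i| Bucket {
                next: if i + 1 < num_buckets { i + 1 } else { INVALID },
                inner: None,
            })
            .collect();
        Self {
            dictionary: vec![INVALID; num_buckets.max(1) as usize],
            buckets,
            free_head: if num_buckets > 0 { 0 } else { INVALID },
        }
    }

    fn slot(&self, key: &K) -> usize {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        (hasher.finish() % self.dictionary.len() as u64) as usize
    }

    /// Returns the index of the bucket holding `key`, if present.
    fn find(&self, key: &K) -> Option<u32> {
        let mut pos = self.dictionary[self.slot(key)];
        while pos != INVALID {
            let bucket = &self.buckets[pos as usize];
            if matches!(&bucket.inner, Some((k, _)) if k == key) {
                return Some(pos);
            }
            pos = bucket.next;
        }
        None
    }

    /// Inserts without moving any existing bucket. Returns the new bucket's
    /// index, or `None` if the key already exists or no free bucket is left.
    fn insert(&mut self, key: K, value: V) -> Option<u32> {
        if self.find(&key).is_some() || self.free_head == INVALID {
            return None;
        }
        let pos = self.free_head;
        let slot = self.slot(&key);
        let bucket = &mut self.buckets[pos as usize];
        self.free_head = bucket.next; // pop the bucket off the freelist
        bucket.next = self.dictionary[slot]; // link it in front of the chain
        bucket.inner = Some((key, value));
        self.dictionary[slot] = pos;
        Some(pos)
    }

    fn get(&self, key: &K) -> Option<&V> {
        self.find(key)
            .and_then(|pos| self.buckets[pos as usize].inner.as_ref())
            .map(|(_, v)| v)
    }
}
```

Because only the dictionary is hash-dependent, a rehash in this design would rebuild `dictionary` and the chain links while every key-value pair stays at its bucket index, which is what lets callers hold on to entries across a resize.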

Concurrency support is very naive at the moment with the entire map
guarded by one big `RwLock` (which is implemented on top of a
`pthread_rwlock_t` since Rust doesn't guarantee that a
`std::sync::RwLock` is safe to use in shared memory). This (along with a
lot of other things) is being changed on the
`quantumish/lfc-resizable-map` branch.
2025-07-17 17:40:53 +00:00
Aleksandr Sarantsev
f0c0733a64 storcon: Ignore stuck reconciles when considering optimizations (#12589)
## Problem

The `keep_failing_reconciles` counter was introduced in #12391, but
there is a special case:

> if a reconciliation loop claims to have succeeded, but maybe_reconcile
still thinks the tenant is in need of reconciliation, then that's a
probable bug and we should activate a similar backoff to prevent
flapping.

This PR redefines "flapping" to include not just repeated failures, but
also consecutive reconciliations of any kind (success or failure).

## Summary of Changes

- Replace `keep_failing_reconciles` with a new `stuck_reconciles` metric
- Replace `MAX_CONSECUTIVE_RECONCILIATION_ERRORS` with
`MAX_CONSECUTIVE_RECONCILES`, and increase it from 5 to 10
- Increment the consecutive reconciles counter for all reconciles, not
just failures
- Reset the counter in `reconcile_all` when no reconcile is needed for a
shard
- Improve and fix the related test
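
The counter behaviour in the list above can be sketched as follows. `ShardReconcileTracker` and its method names are hypothetical, and treating "at or above the threshold" as stuck is an assumption; only `MAX_CONSECUTIVE_RECONCILES` and its value of 10 come from the commit message.

```rust
/// Threshold named in the commit message (raised from 5 to 10).
const MAX_CONSECUTIVE_RECONCILES: usize = 10;

/// Hypothetical per-shard tracker for consecutive reconciles.
#[derive(Default)]
struct ShardReconcileTracker {
    consecutive_reconciles: usize,
}

impl ShardReconcileTracker {
    /// Called after every reconcile, whether it succeeded or failed.
    fn on_reconcile_finished(&mut self) {
        self.consecutive_reconciles = self.consecutive_reconciles.saturating_add(1);
    }

    /// Called when the background pass finds that the shard needs no reconcile.
    fn on_no_reconcile_needed(&mut self) {
        self.consecutive_reconciles = 0;
    }

    /// Assumption: a shard counts as stuck once it reaches the threshold.
    fn is_stuck(&self) -> bool {
        self.consecutive_reconciles >= MAX_CONSECUTIVE_RECONCILES
    }
}
```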

---------

Co-authored-by: Aleksandr Sarantsev <aleksandr.sarantsev@databricks.com>
2025-07-17 14:52:57 +00:00
Vlad Lazar
8862e7c4bf tests: use new snapshot in test_forward_compat (#12637)
## Problem

The forward compatibility test is erroneously
using the downloaded (old) compatibility data. This test is meant to
test that old binaries can work with **new** data. Using the old
compatibility data renders this test useless.

## Summary of changes

Use new snapshot in test_forward_compat

Closes LKB-666

Co-authored-by: William Huang <william.huang@databricks.com>
2025-07-17 13:20:40 +00:00
HaoyuHuang
b7fc5a2fe0 A few SC changes (#12615)
## Summary of changes
A bunch of no-op changes.

---------

Co-authored-by: Vlad Lazar <vlad@neon.tech>
2025-07-17 13:14:36 +00:00
Aleksandr Sarantsev
4559ba79b6 Introduce force flag for new deletion API (#12588)
## Problem

The force deletion API should behave like the graceful deletion API - it
needs to support cancellation, persistence, and be non-blocking.

## Summary of Changes

- Added a `force` flag to the `NodeStartDelete` command.
- Passed the `force` flag through the `start_node_delete` handler in the
storage controller.
- Handled the `force` flag in the `delete_node` function.
- Set the tombstone after removing the node from memory.
- Minor cleanup, like adding a `get_error_on_cancel` closure.
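
On the client side, the `force` flag ends up as a query parameter on the delete endpoint. A minimal sketch, assuming a numeric node ID (the helper name `node_delete_path` and the `u64` type are illustrative; the path format is the one used by this change):

```rust
/// Hypothetical helper: build the storage controller path for starting a
/// node deletion, appending `?force=true` only when requested.
fn node_delete_path(node_id: u64, force: bool) -> String {
    if force {
        format!("control/v1/node/{node_id}/delete?force=true")
    } else {
        format!("control/v1/node/{node_id}/delete")
    }
}
```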

---------

Co-authored-by: Aleksandr Sarantsev <aleksandr.sarantsev@databricks.com>
2025-07-17 11:51:31 +00:00
84 changed files with 3448 additions and 915 deletions

108
Cargo.lock generated
View File

@@ -1872,6 +1872,7 @@ dependencies = [
"diesel_derives",
"itoa",
"serde_json",
"uuid",
]
[[package]]
@@ -2533,6 +2534,18 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "getrandom"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
dependencies = [
"cfg-if",
"libc",
"r-efi",
"wasi 0.14.2+wasi-0.2.4",
]
[[package]]
name = "gettid"
version = "0.1.3"
@@ -3606,9 +3619,9 @@ checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104"
[[package]]
name = "lock_api"
version = "0.4.10"
version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16"
checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765"
dependencies = [
"autocfg",
"scopeguard",
@@ -3758,7 +3771,7 @@ dependencies = [
"procfs",
"prometheus",
"rand 0.8.5",
"rand_distr",
"rand_distr 0.4.3",
"twox-hash",
]
@@ -3846,7 +3859,12 @@ checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
name = "neon-shmem"
version = "0.1.0"
dependencies = [
"libc",
"lock_api",
"nix 0.30.1",
"rand 0.9.1",
"rand_distr 0.5.1",
"rustc-hash 2.1.1",
"tempfile",
"thiserror 1.0.69",
"workspace_hack",
@@ -5347,7 +5365,7 @@ dependencies = [
"postgres_backend",
"pq_proto",
"rand 0.8.5",
"rand_distr",
"rand_distr 0.4.3",
"rcgen",
"redis",
"regex",
@@ -5358,7 +5376,7 @@ dependencies = [
"reqwest-tracing",
"rsa",
"rstest",
"rustc-hash 1.1.0",
"rustc-hash 2.1.1",
"rustls 0.23.27",
"rustls-native-certs 0.8.0",
"rustls-pemfile 2.1.1",
@@ -5388,6 +5406,7 @@ dependencies = [
"tracing-test",
"tracing-utils",
"try-lock",
"type-safe-id",
"typed-json",
"url",
"urlencoding",
@@ -5451,6 +5470,12 @@ dependencies = [
"proc-macro2",
]
[[package]]
name = "r-efi"
version = "5.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
[[package]]
name = "rand"
version = "0.7.3"
@@ -5475,6 +5500,16 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "rand"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
dependencies = [
"rand_chacha 0.9.0",
"rand_core 0.9.3",
]
[[package]]
name = "rand_chacha"
version = "0.2.2"
@@ -5495,6 +5530,16 @@ dependencies = [
"rand_core 0.6.4",
]
[[package]]
name = "rand_chacha"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core 0.9.3",
]
[[package]]
name = "rand_core"
version = "0.5.1"
@@ -5513,6 +5558,15 @@ dependencies = [
"getrandom 0.2.11",
]
[[package]]
name = "rand_core"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
dependencies = [
"getrandom 0.3.3",
]
[[package]]
name = "rand_distr"
version = "0.4.3"
@@ -5523,6 +5577,16 @@ dependencies = [
"rand 0.8.5",
]
[[package]]
name = "rand_distr"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463"
dependencies = [
"num-traits",
"rand 0.9.1",
]
[[package]]
name = "rand_hc"
version = "0.2.0"
@@ -6933,6 +6997,7 @@ dependencies = [
"tokio-util",
"tracing",
"utils",
"uuid",
"workspace_hack",
]
@@ -8023,6 +8088,19 @@ dependencies = [
"static_assertions",
]
[[package]]
name = "type-safe-id"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd9267f90719e0433aae095640b294ff36ccbf89649ecb9ee34464ec504be157"
dependencies = [
"arrayvec",
"rand 0.9.1",
"serde",
"thiserror 2.0.11",
"uuid",
]
[[package]]
name = "typed-json"
version = "0.1.1"
@@ -8206,6 +8284,7 @@ dependencies = [
"tracing-error",
"tracing-subscriber",
"tracing-utils",
"uuid",
"walkdir",
]
@@ -8348,6 +8427,15 @@ version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "wasi"
version = "0.14.2+wasi-0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3"
dependencies = [
"wit-bindgen-rt",
]
[[package]]
name = "wasite"
version = "0.1.0"
@@ -8705,6 +8793,15 @@ dependencies = [
"windows-sys 0.48.0",
]
[[package]]
name = "wit-bindgen-rt"
version = "0.39.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
dependencies = [
"bitflags 2.8.0",
]
[[package]]
name = "workspace_hack"
version = "0.1.0"
@@ -8807,7 +8904,6 @@ dependencies = [
"tracing-log",
"tracing-subscriber",
"url",
"uuid",
"zeroize",
"zstd",
"zstd-safe",

View File

@@ -130,6 +130,7 @@ jemalloc_pprof = { version = "0.7", features = ["symbolize", "flamegraph"] }
jsonwebtoken = "9"
lasso = "0.7"
libc = "0.2"
lock_api = "0.4.13"
md5 = "0.7.0"
measured = { version = "0.0.22", features=["lasso"] }
measured-process = { version = "0.0.22" }
@@ -165,7 +166,7 @@ reqwest-middleware = "0.4"
reqwest-retry = "0.7"
routerify = "3"
rpds = "0.13"
rustc-hash = "1.1.0"
rustc-hash = "2.1.1"
rustls = { version = "0.23.16", default-features = false }
rustls-pemfile = "2"
rustls-pki-types = "1.11"

View File

@@ -2450,14 +2450,31 @@ LIMIT 100",
pub fn spawn_lfc_offload_task(self: &Arc<Self>, interval: Duration) {
self.terminate_lfc_offload_task();
let secs = interval.as_secs();
info!("spawning lfc offload worker with {secs}s interval");
let this = self.clone();
info!("spawning LFC offload worker with {secs}s interval");
let handle = spawn(async move {
let mut interval = time::interval(interval);
interval.tick().await; // returns immediately
loop {
interval.tick().await;
this.offload_lfc_async().await;
let prewarm_state = this.state.lock().unwrap().lfc_prewarm_state.clone();
// Do not offload LFC state if we are currently prewarming or any issue occurred.
// If we'd do that, we might override the LFC state in endpoint storage with some
// incomplete state. Imagine a situation:
// 1. Endpoint started with `autoprewarm: true`
// 2. While prewarming is not completed, we upload the new incomplete state
// 3. Compute gets interrupted and restarts
// 4. We start again and try to prewarm with the state from 2. instead of the previous complete state
if matches!(
prewarm_state,
LfcPrewarmState::Completed
| LfcPrewarmState::NotPrewarmed
| LfcPrewarmState::Skipped
) {
this.offload_lfc_async().await;
}
}
});
*self.lfc_offload_task.lock().unwrap() = Some(handle);

View File

@@ -89,7 +89,7 @@ impl ComputeNode {
self.state.lock().unwrap().lfc_offload_state.clone()
}
/// If there is a prewarm request ongoing, return false, true otherwise
/// If there is a prewarm request ongoing, return `false`, `true` otherwise.
pub fn prewarm_lfc(self: &Arc<Self>, from_endpoint: Option<String>) -> bool {
{
let state = &mut self.state.lock().unwrap().lfc_prewarm_state;
@@ -101,15 +101,25 @@ impl ComputeNode {
let cloned = self.clone();
spawn(async move {
let Err(err) = cloned.prewarm_impl(from_endpoint).await else {
cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Completed;
return;
};
crate::metrics::LFC_PREWARM_ERRORS.inc();
error!(%err, "prewarming lfc");
cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Failed {
error: err.to_string(),
let state = match cloned.prewarm_impl(from_endpoint).await {
Ok(true) => LfcPrewarmState::Completed,
Ok(false) => {
info!(
"skipping LFC prewarm because LFC state is not found in endpoint storage"
);
LfcPrewarmState::Skipped
}
Err(err) => {
crate::metrics::LFC_PREWARM_ERRORS.inc();
error!(%err, "could not prewarm LFC");
LfcPrewarmState::Failed {
error: err.to_string(),
}
}
};
cloned.state.lock().unwrap().lfc_prewarm_state = state;
});
true
}
@@ -120,15 +130,21 @@ impl ComputeNode {
EndpointStoragePair::from_spec_and_endpoint(state.pspec.as_ref().unwrap(), from_endpoint)
}
async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<()> {
/// Request LFC state from endpoint storage and load corresponding pages into Postgres.
/// Returns a result with `false` if the LFC state is not found in endpoint storage.
async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<bool> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(from_endpoint)?;
info!(%url, "requesting LFC state from endpoint storage");
info!(%url, "requesting LFC state from endpoint storage");
let request = Client::new().get(&url).bearer_auth(token);
let res = request.send().await.context("querying endpoint storage")?;
let status = res.status();
if status != StatusCode::OK {
bail!("{status} querying endpoint storage")
match status {
StatusCode::OK => (),
StatusCode::NOT_FOUND => {
return Ok(false);
}
_ => bail!("{status} querying endpoint storage"),
}
let mut uncompressed = Vec::new();
@@ -141,7 +157,8 @@ impl ComputeNode {
.await
.context("decoding LFC state")?;
let uncompressed_len = uncompressed.len();
info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}, loading into postgres");
info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}, loading into Postgres");
ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
.await
@@ -149,7 +166,9 @@ impl ComputeNode {
.query_one("select neon.prewarm_local_cache($1)", &[&uncompressed])
.await
.context("loading LFC state into postgres")
.map(|_| ())
.map(|_| ())?;
Ok(true)
}
/// If offload request is ongoing, return false, true otherwise
@@ -177,12 +196,14 @@ impl ComputeNode {
async fn offload_lfc_with_state_update(&self) {
crate::metrics::LFC_OFFLOADS.inc();
let Err(err) = self.offload_lfc_impl().await else {
self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Completed;
return;
};
crate::metrics::LFC_OFFLOAD_ERRORS.inc();
error!(%err, "offloading lfc");
error!(%err, "could not offload LFC state to endpoint storage");
self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Failed {
error: err.to_string(),
};
@@ -190,7 +211,7 @@ impl ComputeNode {
async fn offload_lfc_impl(&self) -> Result<()> {
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?;
info!(%url, "requesting LFC state from postgres");
info!(%url, "requesting LFC state from Postgres");
let mut compressed = Vec::new();
ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
@@ -205,13 +226,17 @@ impl ComputeNode {
.read_to_end(&mut compressed)
.await
.context("compressing LFC state")?;
let compressed_len = compressed.len();
info!(%url, "downloaded LFC state, compressed size {compressed_len}, writing to endpoint storage");
let request = Client::new().put(url).bearer_auth(token).body(compressed);
match request.send().await {
Ok(res) if res.status() == StatusCode::OK => Ok(()),
Ok(res) => bail!("Error writing to endpoint storage: {}", res.status()),
Ok(res) => bail!(
"Request to endpoint storage failed with status: {}",
res.status()
),
Err(err) => Err(err).context("writing to endpoint storage"),
}
}

View File

@@ -56,14 +56,15 @@ pub fn write_postgres_conf(
writeln!(file, "{conf}")?;
}
// Stripe size GUC should be defined prior to connection string
if let Some(stripe_size) = spec.shard_stripe_size {
writeln!(file, "neon.stripe_size={stripe_size}")?;
}
// Add options for connecting to storage
writeln!(file, "# Neon storage settings")?;
if let Some(s) = &spec.pageserver_connstring {
writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(s))?;
}
if let Some(stripe_size) = spec.shard_stripe_size {
writeln!(file, "neon.stripe_size={stripe_size}")?;
}
if !spec.safekeeper_connstrings.is_empty() {
let mut neon_safekeepers_value = String::new();
tracing::info!(

View File

@@ -613,11 +613,11 @@ components:
- skipped
properties:
status:
description: Lfc prewarm status
enum: [not_prewarmed, prewarming, completed, failed]
description: LFC prewarm status
enum: [not_prewarmed, prewarming, completed, failed, skipped]
type: string
error:
description: Lfc prewarm error, if any
description: LFC prewarm error, if any
type: string
total:
description: Total pages processed
@@ -635,11 +635,11 @@ components:
- status
properties:
status:
description: Lfc offload status
description: LFC offload status
enum: [not_offloaded, offloading, completed, failed]
type: string
error:
description: Lfc offload error, if any
description: LFC offload error, if any
type: string
PromoteState:

View File

@@ -8,10 +8,10 @@ code changes locally, but not suitable for running production systems.
## Example: Start with Postgres 16
To create and start a local development environment with Postgres 16, you will need to provide `--pg-version` flag to 3 of the start-up commands.
To create and start a local development environment with Postgres 16, you will need to provide `--pg-version` flag to 2 of the start-up commands.
```shell
cargo neon init --pg-version 16
cargo neon init
cargo neon start
cargo neon tenant create --set-default --pg-version 16
cargo neon endpoint create main --pg-version 16

View File

@@ -76,6 +76,12 @@ enum Command {
NodeStartDelete {
#[arg(long)]
node_id: NodeId,
/// When `force` is true, skip waiting for shards to prewarm during migration.
/// This can significantly speed up node deletion since prewarming all shards
/// can take considerable time, but may result in slower initial access to
/// migrated shards until they warm up naturally.
#[arg(long)]
force: bool,
},
/// Cancel deletion of the specified pageserver and wait for `timeout`
/// for the operation to be canceled. May be retried.
@@ -952,13 +958,14 @@ async fn main() -> anyhow::Result<()> {
.dispatch::<(), ()>(Method::DELETE, format!("control/v1/node/{node_id}"), None)
.await?;
}
Command::NodeStartDelete { node_id } => {
Command::NodeStartDelete { node_id, force } => {
let query = if force {
format!("control/v1/node/{node_id}/delete?force=true")
} else {
format!("control/v1/node/{node_id}/delete")
};
storcon_client
.dispatch::<(), ()>(
Method::PUT,
format!("control/v1/node/{node_id}/delete"),
None,
)
.dispatch::<(), ()>(Method::PUT, query, None)
.await?;
println!("Delete started for {node_id}");
}

View File

@@ -46,16 +46,33 @@ pub struct ExtensionInstallResponse {
pub version: ExtVersion,
}
/// Status of the LFC prewarm process. The same state machine is reused for
/// both autoprewarm (prewarm after compute/Postgres start using the previously
/// stored LFC state) and explicit prewarming via API.
#[derive(Serialize, Default, Debug, Clone, PartialEq)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum LfcPrewarmState {
/// Default value when compute boots up.
#[default]
NotPrewarmed,
/// Prewarming thread is active and loading pages into LFC.
Prewarming,
/// We found requested LFC state in the endpoint storage and
/// completed prewarming successfully.
Completed,
Failed {
error: String,
},
/// Unexpected error happened during prewarming. Note, `Not Found 404`
/// response from the endpoint storage is explicitly excluded here
/// because it can normally happen on the first compute start,
/// since LFC state is not available yet.
Failed { error: String },
/// We tried to fetch the corresponding LFC state from the endpoint storage,
/// but received `Not Found 404`. This should normally happen only during the
/// first endpoint start after creation with `autoprewarm: true`.
///
/// During the orchestrated prewarm via API, when a caller explicitly
/// provides the LFC state key to prewarm from, it's the caller responsibility
/// to handle this status as an error state in this case.
Skipped,
}
impl Display for LfcPrewarmState {
@@ -64,6 +81,7 @@ impl Display for LfcPrewarmState {
LfcPrewarmState::NotPrewarmed => f.write_str("NotPrewarmed"),
LfcPrewarmState::Prewarming => f.write_str("Prewarming"),
LfcPrewarmState::Completed => f.write_str("Completed"),
LfcPrewarmState::Skipped => f.write_str("Skipped"),
LfcPrewarmState::Failed { error } => write!(f, "Error({error})"),
}
}

View File

@@ -4,12 +4,14 @@
//! a default registry.
#![deny(clippy::undocumented_unsafe_blocks)]
use std::sync::RwLock;
use measured::label::{LabelGroupSet, LabelGroupVisitor, LabelName, NoLabels};
use measured::metric::counter::CounterState;
use measured::metric::gauge::GaugeState;
use measured::metric::group::Encoding;
use measured::metric::name::{MetricName, MetricNameEncoder};
use measured::metric::{MetricEncoding, MetricFamilyEncoding};
use measured::metric::{MetricEncoding, MetricFamilyEncoding, MetricType};
use measured::{FixedCardinalityLabel, LabelGroup, MetricGroup};
use once_cell::sync::Lazy;
use prometheus::Registry;
@@ -116,12 +118,52 @@ pub fn pow2_buckets(start: usize, end: usize) -> Vec<f64> {
.collect()
}
pub struct InfoMetric<L: LabelGroup, M: MetricType = GaugeState> {
label: RwLock<L>,
metric: M,
}
impl<L: LabelGroup> InfoMetric<L> {
pub fn new(label: L) -> Self {
Self::with_metric(label, GaugeState::new(1))
}
}
impl<L: LabelGroup, M: MetricType<Metadata = ()>> InfoMetric<L, M> {
pub fn with_metric(label: L, metric: M) -> Self {
Self {
label: RwLock::new(label),
metric,
}
}
pub fn set_label(&self, label: L) {
*self.label.write().unwrap() = label;
}
}
impl<L, M, E> MetricFamilyEncoding<E> for InfoMetric<L, M>
where
L: LabelGroup,
M: MetricEncoding<E, Metadata = ()>,
E: Encoding,
{
fn collect_family_into(
&self,
name: impl measured::metric::name::MetricNameEncoder,
enc: &mut E,
) -> Result<(), E::Err> {
M::write_type(&name, enc)?;
self.metric
.collect_into(&(), &*self.label.read().unwrap(), name, enc)
}
}
pub struct BuildInfo {
pub revision: &'static str,
pub build_tag: &'static str,
}
// todo: allow label group without the set
impl LabelGroup for BuildInfo {
fn visit_values(&self, v: &mut impl LabelGroupVisitor) {
const REVISION: &LabelName = LabelName::from_str("revision");
@@ -131,24 +173,6 @@ impl LabelGroup for BuildInfo {
}
}
impl<T: Encoding> MetricFamilyEncoding<T> for BuildInfo
where
GaugeState: MetricEncoding<T>,
{
fn collect_family_into(
&self,
name: impl measured::metric::name::MetricNameEncoder,
enc: &mut T,
) -> Result<(), T::Err> {
enc.write_help(&name, "Build/version information")?;
GaugeState::write_type(&name, enc)?;
GaugeState {
count: std::sync::atomic::AtomicI64::new(1),
}
.collect_into(&(), self, name, enc)
}
}
#[derive(MetricGroup)]
#[metric(new(build_info: BuildInfo))]
pub struct NeonMetrics {
@@ -165,8 +189,8 @@ pub struct NeonMetrics {
#[derive(MetricGroup)]
#[metric(new(build_info: BuildInfo))]
pub struct LibMetrics {
#[metric(init = build_info)]
build_info: BuildInfo,
#[metric(init = InfoMetric::new(build_info))]
build_info: InfoMetric<BuildInfo>,
#[metric(flatten)]
rusage: Rusage,

View File

@@ -8,6 +8,13 @@ license.workspace = true
thiserror.workspace = true
nix.workspace=true
workspace_hack = { version = "0.1", path = "../../workspace_hack" }
libc.workspace = true
lock_api.workspace = true
rustc-hash.workspace = true
[target.'cfg(target_os = "macos")'.dependencies]
tempfile = "3.14.0"
[dev-dependencies]
rand = "0.9"
rand_distr = "0.5.1"

583
libs/neon-shmem/src/hash.rs Normal file
View File

@@ -0,0 +1,583 @@
//! Resizable hash table implementation on top of byte-level storage (either a [`ShmemHandle`] or a fixed byte array).
//!
//! This hash table has two major components: the bucket array and the dictionary. Each bucket within the
//! bucket array contains a `Option<(K, V)>` and an index of another bucket. In this way there is both an
//! implicit freelist within the bucket array (`None` buckets point to other `None` entries) and various hash
//! chains within the bucket array (a Some bucket will point to other Some buckets that had the same hash).
//!
//! Buckets are never moved unless they are within a region that is being shrunk, and so the actual hash-
//! dependent component is done with the dictionary. When a new key is inserted into the map, a position
//! within the dictionary is decided based on its hash, the data is inserted into an empty bucket based
//! off of the freelist, and then the index of said bucket is placed in the dictionary.
//!
//! This map is resizable (if initialized on top of a [`ShmemHandle`]). Both growing and shrinking happen
//! in-place and are at a high level achieved by expanding/reducing the bucket array and rebuilding the
//! dictionary by rehashing all keys.
//!
//! Concurrency is managed very simply: the entire map is guarded by one shared-memory RwLock.
use std::hash::{BuildHasher, Hash};
use std::mem::MaybeUninit;
use crate::shmem::ShmemHandle;
use crate::{shmem, sync::*};
mod core;
pub mod entry;
#[cfg(test)]
mod tests;
use core::{Bucket, CoreHashMap, INVALID_POS};
use entry::{Entry, OccupiedEntry, PrevPos, VacantEntry};
use thiserror::Error;
/// Error type for a hashmap shrink operation.
#[derive(Error, Debug)]
pub enum HashMapShrinkError {
/// There was an error encountered while resizing the memory area.
#[error("shmem resize failed: {0}")]
ResizeError(shmem::Error),
/// Occupied entries in to-be-shrunk space were encountered beginning at the given index.
#[error("occupied entry in deallocated space found at {0}")]
RemainingEntries(usize),
}
/// This represents a hash table that (possibly) lives in shared memory.
/// If a new process is launched with fork(), the child process inherits
/// this struct.
#[must_use]
pub struct HashMapInit<'a, K, V, S = rustc_hash::FxBuildHasher> {
shmem_handle: Option<ShmemHandle>,
shared_ptr: *mut HashMapShared<'a, K, V>,
shared_size: usize,
hasher: S,
num_buckets: u32,
}
/// This is a per-process handle to a hash table that (possibly) lives in shared memory.
/// If a child process is launched with fork(), the child process should
/// get its own HashMapAccess by calling HashMapInit::attach_writer/reader().
///
/// XXX: We're not making use of it at the moment, but this struct could
/// hold process-local information in the future.
pub struct HashMapAccess<'a, K, V, S = rustc_hash::FxBuildHasher> {
shmem_handle: Option<ShmemHandle>,
shared_ptr: *mut HashMapShared<'a, K, V>,
hasher: S,
}
unsafe impl<K: Sync, V: Sync, S> Sync for HashMapAccess<'_, K, V, S> {}
unsafe impl<K: Send, V: Send, S> Send for HashMapAccess<'_, K, V, S> {}
impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
/// Change the 'hasher' used by the hash table.
///
/// NOTE: This must be called right after creating the hash table,
/// before inserting any entries and before calling attach_writer/reader.
/// Otherwise different accessors could be using different hash functions,
/// with confusing results.
pub fn with_hasher<T: BuildHasher>(self, hasher: T) -> HashMapInit<'a, K, V, T> {
HashMapInit {
hasher,
shmem_handle: self.shmem_handle,
shared_ptr: self.shared_ptr,
shared_size: self.shared_size,
num_buckets: self.num_buckets,
}
}
/// Loosely (over)estimate the size needed to store a hash table with `num_buckets` buckets.
pub fn estimate_size(num_buckets: u32) -> usize {
// add some margin to cover alignment etc.
CoreHashMap::<K, V>::estimate_size(num_buckets) + size_of::<HashMapShared<K, V>>() + 1000
}
fn new(
num_buckets: u32,
shmem_handle: Option<ShmemHandle>,
area_ptr: *mut u8,
area_size: usize,
hasher: S,
) -> Self {
let mut ptr: *mut u8 = area_ptr;
let end_ptr: *mut u8 = unsafe { ptr.add(area_size) };
// carve out area for the One Big Lock (TM) and the HashMapShared.
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<libc::pthread_rwlock_t>())) };
let raw_lock_ptr = ptr;
ptr = unsafe { ptr.add(size_of::<libc::pthread_rwlock_t>()) };
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<HashMapShared<K, V>>())) };
let shared_ptr: *mut HashMapShared<K, V> = ptr.cast();
ptr = unsafe { ptr.add(size_of::<HashMapShared<K, V>>()) };
// carve out the buckets
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<core::Bucket<K, V>>())) };
let buckets_ptr = ptr;
ptr = unsafe { ptr.add(size_of::<core::Bucket<K, V>>() * num_buckets as usize) };
// use remaining space for the dictionary
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<u32>())) };
assert!(ptr.addr() < end_ptr.addr());
let dictionary_ptr = ptr;
let dictionary_size = unsafe { end_ptr.byte_offset_from(ptr) / size_of::<u32>() as isize };
assert!(dictionary_size > 0);
let buckets =
unsafe { std::slice::from_raw_parts_mut(buckets_ptr.cast(), num_buckets as usize) };
let dictionary = unsafe {
std::slice::from_raw_parts_mut(dictionary_ptr.cast(), dictionary_size as usize)
};
let hashmap = CoreHashMap::new(buckets, dictionary);
unsafe {
let lock = RwLock::from_raw(PthreadRwLock::new(raw_lock_ptr.cast()), hashmap);
std::ptr::write(shared_ptr, lock);
}
Self {
num_buckets,
shmem_handle,
shared_ptr,
shared_size: area_size,
hasher,
}
}
/// Attach to a hash table for writing.
pub fn attach_writer(self) -> HashMapAccess<'a, K, V, S> {
HashMapAccess {
shmem_handle: self.shmem_handle,
shared_ptr: self.shared_ptr,
hasher: self.hasher,
}
}
/// Attach to a hash table for reading. Currently identical to [`HashMapInit::attach_writer`].
///
/// This is a holdover from a previous implementation and is being kept around for
/// backwards compatibility reasons.
pub fn attach_reader(self) -> HashMapAccess<'a, K, V, S> {
self.attach_writer()
}
}
/// Hash table data that is actually stored in the shared memory area.
///
/// NOTE: We carve out the parts from a contiguous chunk. Growing and shrinking the hash table
/// relies on the memory layout! The data structures are laid out in the contiguous shared memory
/// area as follows:
///
/// [`libc::pthread_rwlock_t`]
/// [`HashMapShared`]
/// buckets
/// dictionary
///
/// In between the above parts, there can be padding bytes to align the parts correctly.
type HashMapShared<'a, K, V> = RwLock<CoreHashMap<'a, K, V>>;
impl<'a, K, V> HashMapInit<'a, K, V, rustc_hash::FxBuildHasher>
where
K: Clone + Hash + Eq,
{
/// Place the hash table within a user-supplied fixed memory area.
pub fn with_fixed(num_buckets: u32, area: &'a mut [MaybeUninit<u8>]) -> Self {
Self::new(
num_buckets,
None,
area.as_mut_ptr().cast(),
area.len(),
rustc_hash::FxBuildHasher,
)
}
/// Place a new hash map in the given shared memory area
///
/// # Panics
/// Will panic on failure to resize area to expected map size.
pub fn with_shmem(num_buckets: u32, shmem: ShmemHandle) -> Self {
let size = Self::estimate_size(num_buckets);
shmem
.set_size(size)
.expect("could not resize shared memory area");
let ptr = shmem.data_ptr.as_ptr().cast();
Self::new(
num_buckets,
Some(shmem),
ptr,
size,
rustc_hash::FxBuildHasher,
)
}
/// Make a resizable hash map within a new shared memory area with the given name.
pub fn new_resizeable_named(num_buckets: u32, max_buckets: u32, name: &str) -> Self {
let size = Self::estimate_size(num_buckets);
let max_size = Self::estimate_size(max_buckets);
let shmem =
ShmemHandle::new(name, size, max_size).expect("failed to make shared memory area");
let ptr = shmem.data_ptr.as_ptr().cast();
Self::new(
num_buckets,
Some(shmem),
ptr,
size,
rustc_hash::FxBuildHasher,
)
}
/// Make a resizable hash map within a new anonymous shared memory area.
pub fn new_resizeable(num_buckets: u32, max_buckets: u32) -> Self {
use std::sync::atomic::{AtomicUsize, Ordering};
static COUNTER: AtomicUsize = AtomicUsize::new(0);
let val = COUNTER.fetch_add(1, Ordering::Relaxed);
let name = format!("neon_shmem_hmap{val}");
Self::new_resizeable_named(num_buckets, max_buckets, &name)
}
}
impl<'a, K, V, S: BuildHasher> HashMapAccess<'a, K, V, S>
where
K: Clone + Hash + Eq,
{
/// Hash a key using the map's hasher.
#[inline]
fn get_hash_value(&self, key: &K) -> u64 {
self.hasher.hash_one(key)
}
fn entry_with_hash(&self, key: K, hash: u64) -> Entry<'a, '_, K, V> {
let mut map = unsafe { self.shared_ptr.as_ref() }.unwrap().write();
let dict_pos = hash as usize % map.dictionary.len();
let first = map.dictionary[dict_pos];
if first == INVALID_POS {
// no existing entry
return Entry::Vacant(VacantEntry {
map,
key,
dict_pos: dict_pos as u32,
});
}
let mut prev_pos = PrevPos::First(dict_pos as u32);
let mut next = first;
loop {
let bucket = &mut map.buckets[next as usize];
let (bucket_key, _bucket_value) = bucket.inner.as_mut().expect("entry is in use");
if *bucket_key == key {
// found existing entry
return Entry::Occupied(OccupiedEntry {
map,
_key: key,
prev_pos,
bucket_pos: next,
});
}
if bucket.next == INVALID_POS {
// No existing entry
return Entry::Vacant(VacantEntry {
map,
key,
dict_pos: dict_pos as u32,
});
}
prev_pos = PrevPos::Chained(next);
next = bucket.next;
}
}
/// Get a reference to the corresponding value for a key.
pub fn get<'e>(&'e self, key: &K) -> Option<ValueReadGuard<'e, V>> {
let hash = self.get_hash_value(key);
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
RwLockReadGuard::try_map(map, |m| m.get_with_hash(key, hash)).ok()
}
/// Get a reference to the entry containing a key.
///
/// NB: This takes a write lock, as there is no way to know in advance whether the
/// entry will be used for reading or for writing.
pub fn entry(&self, key: K) -> Entry<'a, '_, K, V> {
let hash = self.get_hash_value(&key);
self.entry_with_hash(key, hash)
}
/// Remove a key. Returns the associated value if it existed.
pub fn remove(&self, key: &K) -> Option<V> {
let hash = self.get_hash_value(key);
match self.entry_with_hash(key.clone(), hash) {
Entry::Occupied(e) => Some(e.remove()),
Entry::Vacant(_) => None,
}
}
/// Insert/update a key. Returns the previous associated value if it existed.
///
/// # Errors
/// Will return [`core::FullError`] if there is no more space left in the map.
pub fn insert(&self, key: K, value: V) -> Result<Option<V>, core::FullError> {
let hash = self.get_hash_value(&key);
match self.entry_with_hash(key.clone(), hash) {
Entry::Occupied(mut e) => Ok(Some(e.insert(value))),
Entry::Vacant(e) => {
_ = e.insert(value)?;
Ok(None)
}
}
}
/// Optionally return the entry for a bucket at a given index if it exists.
///
/// Has more overhead than one would intuitively expect: performs both a clone of the key
/// due to the [`OccupiedEntry`] type owning the key and also a hash of the key in order
/// to enable repairing the hash chain if the entry is removed.
pub fn entry_at_bucket(&self, pos: usize) -> Option<OccupiedEntry<'a, '_, K, V>> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
if pos >= map.buckets.len() {
return None;
}
let entry = map.buckets[pos].inner.as_ref();
match entry {
Some((key, _)) => Some(OccupiedEntry {
_key: key.clone(),
bucket_pos: pos as u32,
prev_pos: entry::PrevPos::Unknown(self.get_hash_value(key)),
map,
}),
_ => None,
}
}
/// Returns the number of buckets in the table.
pub fn get_num_buckets(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
map.get_num_buckets()
}
/// Return the key and value stored in the bucket with the given index. This can be used to
/// iterate through the hash map.
// TODO: An Iterator might be nicer. The communicator's clock algorithm needs to
// _slowly_ iterate through all buckets with its clock hand, without holding a lock.
// If we switch to an Iterator, it must not hold the lock.
pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<(K, V)>> {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
if pos >= map.buckets.len() {
return None;
}
RwLockReadGuard::try_map(map, |m| m.buckets[pos].inner.as_ref()).ok()
}
/// Returns the index of the bucket a given value corresponds to.
pub fn get_bucket_for_value(&self, val_ptr: *const V) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
let origin = map.buckets.as_ptr();
let idx = (val_ptr as usize - origin as usize) / size_of::<Bucket<K, V>>();
assert!(idx < map.buckets.len());
idx
}
/// Returns the number of occupied buckets in the table.
pub fn get_num_buckets_in_use(&self) -> usize {
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
map.buckets_in_use as usize
}
/// Clears all entries in a table. Does not reset any shrinking operations.
pub fn clear(&self) {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
map.clear();
}
/// Perform an in-place rehash of some region (0..`rehash_buckets`) of the table and reset
/// the `buckets` and `dictionary` slices to be as long as `num_buckets`. Resets the freelist
/// in the process.
fn rehash_dict(
&self,
inner: &mut CoreHashMap<'a, K, V>,
buckets_ptr: *mut core::Bucket<K, V>,
end_ptr: *mut u8,
num_buckets: u32,
rehash_buckets: u32,
) {
inner.free_head = INVALID_POS;
let buckets;
let dictionary;
unsafe {
let buckets_end_ptr = buckets_ptr.add(num_buckets as usize);
let dictionary_ptr: *mut u32 = buckets_end_ptr
.byte_add(buckets_end_ptr.align_offset(align_of::<u32>()))
.cast();
let dictionary_size: usize =
end_ptr.byte_offset_from(buckets_end_ptr) as usize / size_of::<u32>();
buckets = std::slice::from_raw_parts_mut(buckets_ptr, num_buckets as usize);
dictionary = std::slice::from_raw_parts_mut(dictionary_ptr, dictionary_size);
}
for e in dictionary.iter_mut() {
*e = INVALID_POS;
}
for (i, bucket) in buckets.iter_mut().enumerate().take(rehash_buckets as usize) {
if bucket.inner.is_none() {
bucket.next = inner.free_head;
inner.free_head = i as u32;
continue;
}
let hash = self.hasher.hash_one(&bucket.inner.as_ref().unwrap().0);
let pos: usize = (hash % dictionary.len() as u64) as usize;
bucket.next = dictionary[pos];
dictionary[pos] = i as u32;
}
inner.dictionary = dictionary;
inner.buckets = buckets;
}
/// Rehash the map without growing or shrinking.
pub fn shuffle(&self) {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
let num_buckets = map.get_num_buckets() as u32;
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
let end_ptr: *mut u8 = unsafe { self.shared_ptr.byte_add(size_bytes).cast() };
let buckets_ptr = map.buckets.as_mut_ptr();
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets);
}
/// Grow the number of buckets within the table.
///
/// 1. Grows the underlying shared memory area
/// 2. Initializes new buckets and overwrites the current dictionary
/// 3. Rehashes the dictionary
///
/// # Panics
/// Panics if called on a map initialized with [`HashMapInit::with_fixed`].
///
/// # Errors
/// Returns an [`shmem::Error`] if any errors occur resizing the memory region.
pub fn grow(&self, num_buckets: u32) -> Result<(), shmem::Error> {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
let old_num_buckets = map.buckets.len() as u32;
assert!(
num_buckets >= old_num_buckets,
"grow called with a smaller number of buckets"
);
if num_buckets == old_num_buckets {
return Ok(());
}
let shmem_handle = self
.shmem_handle
.as_ref()
.expect("grow called on a fixed-size hash table");
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
shmem_handle.set_size(size_bytes)?;
let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
// Initialize new buckets. The new buckets are linked to the free list.
// NB: This overwrites the dictionary!
let buckets_ptr = map.buckets.as_mut_ptr();
unsafe {
for i in old_num_buckets..num_buckets {
let bucket = buckets_ptr.add(i as usize);
bucket.write(core::Bucket {
next: if i < num_buckets - 1 {
i + 1
} else {
map.free_head
},
inner: None,
});
}
}
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, old_num_buckets);
map.free_head = old_num_buckets;
Ok(())
}
/// Begin a shrink, limiting all new allocations to be in buckets with index below `num_buckets`.
///
/// # Panics
/// Panics if called on a map initialized with [`HashMapInit::with_fixed`] or if `num_buckets` is
/// greater than the number of buckets in the map.
pub fn begin_shrink(&mut self, num_buckets: u32) {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
assert!(
num_buckets <= map.get_num_buckets() as u32,
"shrink called with a larger number of buckets"
);
_ = self
.shmem_handle
.as_ref()
.expect("shrink called on a fixed-size hash table");
map.alloc_limit = num_buckets;
}
/// If a shrink operation is underway, returns the target size of the map. Otherwise, returns None.
pub fn shrink_goal(&self) -> Option<usize> {
let map = unsafe { self.shared_ptr.as_mut() }.unwrap().read();
let goal = map.alloc_limit;
if goal == INVALID_POS {
None
} else {
Some(goal as usize)
}
}
/// Complete a shrink after caller has evicted entries, removing the unused buckets and rehashing.
///
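/// # Panics
/// The following cases result in a panic:
/// - Calling this function on a map initialized with [`HashMapInit::with_fixed`].
/// - Calling this function on a map when no shrink operation is in progress.
/// - Calling this function while more buckets are in use than the shrink target allows.
///
/// # Errors
/// Returns [`HashMapShrinkError::RemainingEntries`] if an occupied bucket remains at or above the
/// target size, or [`HashMapShrinkError::ResizeError`] if resizing the memory area fails.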
pub fn finish_shrink(&self) -> Result<(), HashMapShrinkError> {
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
assert!(
map.alloc_limit != INVALID_POS,
"called finish_shrink when no shrink is in progress"
);
let num_buckets = map.alloc_limit;
if map.get_num_buckets() == num_buckets as usize {
return Ok(());
}
assert!(
map.buckets_in_use <= num_buckets,
"called finish_shrink before enough entries were removed"
);
for i in (num_buckets as usize)..map.buckets.len() {
if map.buckets[i].inner.is_some() {
return Err(HashMapShrinkError::RemainingEntries(i));
}
}
let shmem_handle = self
.shmem_handle
.as_ref()
.expect("shrink called on a fixed-size hash table");
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
if let Err(e) = shmem_handle.set_size(size_bytes) {
return Err(HashMapShrinkError::ResizeError(e));
}
let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
let buckets_ptr = map.buckets.as_mut_ptr();
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets);
map.alloc_limit = INVALID_POS;
Ok(())
}
}
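
Review note on the memory layout used by `HashMapInit::new` above: the area is carved into a header, the bucket array, and a dictionary that takes whatever space is left. The following is a minimal sketch of that offset arithmetic only, not the code from this change. It works on plain byte offsets instead of raw pointers, folds the lock and `HashMapShared` into a single "header", and assumes the area itself starts at a suitably aligned address. The names `align_up`, `Carved`, and `carve` are hypothetical.

```rust
/// Round `offset` up to the next multiple of `align` (which must be a power of two).
fn align_up(offset: usize, align: usize) -> usize {
    assert!(align.is_power_of_two());
    (offset + align - 1) & !(align - 1)
}

/// Byte offsets of each part within the area (hypothetical helper type).
#[derive(Debug, PartialEq)]
struct Carved {
    header_off: usize,
    buckets_off: usize,
    dict_off: usize,
    /// Number of `u32` dictionary slots that fit in the leftover space.
    dict_len: usize,
}

/// Lay out header, buckets, and dictionary back to back, padding for alignment.
/// Returns `None` if no room is left for at least one dictionary slot.
fn carve(
    area_size: usize,
    header_size: usize,
    header_align: usize,
    bucket_size: usize,
    bucket_align: usize,
    num_buckets: usize,
) -> Option<Carved> {
    // Assumption: offset 0 of the area is already aligned for the header.
    let header_off = align_up(0, header_align);
    // Buckets follow the header, padded up to the bucket alignment.
    let buckets_off = align_up(header_off + header_size, bucket_align);
    let buckets_end = buckets_off + bucket_size * num_buckets;
    // The dictionary is an array of u32 and takes all remaining space.
    let dict_off = align_up(buckets_end, std::mem::align_of::<u32>());
    if dict_off >= area_size {
        return None;
    }
    let dict_len = (area_size - dict_off) / std::mem::size_of::<u32>();
    if dict_len == 0 {
        return None;
    }
    Some(Carved {
        header_off,
        buckets_off,
        dict_off,
        dict_len,
    })
}
```

Because the dictionary length is derived from the leftover space, growing or shrinking the area changes the dictionary size even when the bucket count is unchanged, which is presumably why every resize path in this file ends with a rehash.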


@@ -0,0 +1,174 @@
//! Simple hash table with chaining.
use std::hash::Hash;
use std::mem::MaybeUninit;
use crate::hash::entry::*;
/// Invalid position within the map (either within the dictionary or bucket array).
pub(crate) const INVALID_POS: u32 = u32::MAX;
/// Fundamental storage unit within the hash table. Either empty or contains a key-value pair.
/// Always part of a chain of some kind (either a freelist if empty or a hash chain if full).
pub(crate) struct Bucket<K, V> {
/// Index of next bucket in the chain.
pub(crate) next: u32,
/// Key-value pair contained within bucket.
pub(crate) inner: Option<(K, V)>,
}
/// Core hash table implementation.
pub(crate) struct CoreHashMap<'a, K, V> {
/// Dictionary used to map hashes to bucket indices.
pub(crate) dictionary: &'a mut [u32],
/// Buckets containing key-value pairs.
pub(crate) buckets: &'a mut [Bucket<K, V>],
/// Head of the freelist.
pub(crate) free_head: u32,
/// Maximum index of a bucket allowed to be allocated. [`INVALID_POS`] if no limit.
pub(crate) alloc_limit: u32,
/// The number of currently occupied buckets.
pub(crate) buckets_in_use: u32,
}
/// Error for when there are no empty buckets left but one is needed.
#[derive(Debug, PartialEq)]
pub struct FullError;
impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
const FILL_FACTOR: f32 = 0.60;
/// Estimate the size of data contained within the hash map.
pub fn estimate_size(num_buckets: u32) -> usize {
let mut size = 0;
// buckets
size += size_of::<Bucket<K, V>>() * num_buckets as usize;
// dictionary
size += (f32::ceil((size_of::<u32>() * num_buckets as usize) as f32 / Self::FILL_FACTOR))
as usize;
size
}
pub fn new(
buckets: &'a mut [MaybeUninit<Bucket<K, V>>],
dictionary: &'a mut [MaybeUninit<u32>],
) -> Self {
// Initialize the buckets
for i in 0..buckets.len() {
buckets[i].write(Bucket {
next: if i < buckets.len() - 1 {
i as u32 + 1
} else {
INVALID_POS
},
inner: None,
});
}
// Initialize the dictionary
for e in dictionary.iter_mut() {
e.write(INVALID_POS);
}
// TODO: use std::slice::assume_init_mut() once it stabilizes
let buckets =
unsafe { std::slice::from_raw_parts_mut(buckets.as_mut_ptr().cast(), buckets.len()) };
let dictionary = unsafe {
std::slice::from_raw_parts_mut(dictionary.as_mut_ptr().cast(), dictionary.len())
};
Self {
dictionary,
buckets,
free_head: 0,
buckets_in_use: 0,
alloc_limit: INVALID_POS,
}
}
/// Get the value associated with a key (if it exists) given its hash.
pub fn get_with_hash(&self, key: &K, hash: u64) -> Option<&V> {
let mut next = self.dictionary[hash as usize % self.dictionary.len()];
loop {
if next == INVALID_POS {
return None;
}
let bucket = &self.buckets[next as usize];
let (bucket_key, bucket_value) = bucket.inner.as_ref().expect("entry is in use");
if bucket_key == key {
return Some(bucket_value);
}
next = bucket.next;
}
}
/// Get number of buckets in map.
pub fn get_num_buckets(&self) -> usize {
self.buckets.len()
}
/// Clears all entries from the hashmap.
///
/// Does not reset any allocation limits, but does clear any entries beyond them.
pub fn clear(&mut self) {
for i in 0..self.buckets.len() {
self.buckets[i] = Bucket {
next: if i < self.buckets.len() - 1 {
i as u32 + 1
} else {
INVALID_POS
},
inner: None,
}
}
for i in 0..self.dictionary.len() {
self.dictionary[i] = INVALID_POS;
}
self.free_head = 0;
self.buckets_in_use = 0;
}
/// Find the position of an unused bucket via the freelist and initialize it.
pub(crate) fn alloc_bucket(&mut self, key: K, value: V) -> Result<u32, FullError> {
let mut pos = self.free_head;
// Find the first bucket we're *allowed* to use.
let mut prev = PrevPos::First(self.free_head);
while pos != INVALID_POS && pos >= self.alloc_limit {
let bucket = &mut self.buckets[pos as usize];
prev = PrevPos::Chained(pos);
pos = bucket.next;
}
if pos == INVALID_POS {
return Err(FullError);
}
// Repair the freelist.
match prev {
PrevPos::First(_) => {
let next_pos = self.buckets[pos as usize].next;
self.free_head = next_pos;
}
PrevPos::Chained(p) => {
if p != INVALID_POS {
let next_pos = self.buckets[pos as usize].next;
self.buckets[p as usize].next = next_pos;
}
}
_ => unreachable!(),
}
// Initialize the bucket.
let bucket = &mut self.buckets[pos as usize];
self.buckets_in_use += 1;
bucket.next = INVALID_POS;
bucket.inner = Some((key, value));
Ok(pos)
}
}
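
Review note on the data structure: the interplay between the dictionary, the hash chains, and the freelist is easier to follow in a stripped-down model. The sketch below is an illustration under simplifying assumptions and not the code from this change: it uses `Vec` instead of shared memory, `u64` keys and values, the key itself as the hash, no lock, and no allocation limit. `ToyMap` and its methods are hypothetical names.

```rust
/// Sentinel meaning "no bucket", mirroring `INVALID_POS` in the change.
const INVALID_POS: u32 = u32::MAX;

struct Bucket {
    /// Next bucket in the freelist (if empty) or in the hash chain (if occupied).
    next: u32,
    inner: Option<(u64, u64)>,
}

struct ToyMap {
    dictionary: Vec<u32>,
    buckets: Vec<Bucket>,
    free_head: u32,
}

impl ToyMap {
    fn new(num_buckets: usize, dict_len: usize) -> Self {
        // Initially every bucket is empty and linked into the freelist in index order.
        let buckets = (0..num_buckets)
            .map(|i| Bucket {
                next: if i + 1 < num_buckets { i as u32 + 1 } else { INVALID_POS },
                inner: None,
            })
            .collect();
        ToyMap {
            dictionary: vec![INVALID_POS; dict_len],
            buckets,
            free_head: if num_buckets == 0 { INVALID_POS } else { 0 },
        }
    }

    /// Dictionary slot for a key (identity hash, for illustration only).
    fn slot(&self, key: u64) -> usize {
        (key % self.dictionary.len() as u64) as usize
    }

    fn get(&self, key: u64) -> Option<u64> {
        let mut pos = self.dictionary[self.slot(key)];
        while pos != INVALID_POS {
            let bucket = &self.buckets[pos as usize];
            match bucket.inner {
                Some((k, v)) if k == key => return Some(v),
                _ => pos = bucket.next,
            }
        }
        None
    }

    /// Insert or update. `Err(())` means the freelist is empty (the map is full).
    fn insert(&mut self, key: u64, value: u64) -> Result<Option<u64>, ()> {
        let slot = self.slot(key);
        // Walk the chain first: an existing key is updated in place.
        let mut pos = self.dictionary[slot];
        while pos != INVALID_POS {
            let bucket = &mut self.buckets[pos as usize];
            if let Some((k, v)) = bucket.inner.as_mut() {
                if *k == key {
                    return Ok(Some(std::mem::replace(v, value)));
                }
            }
            pos = bucket.next;
        }
        // Pop a bucket off the freelist and push it onto the front of the chain.
        let pos = self.free_head;
        if pos == INVALID_POS {
            return Err(());
        }
        self.free_head = self.buckets[pos as usize].next;
        self.buckets[pos as usize] = Bucket {
            next: self.dictionary[slot],
            inner: Some((key, value)),
        };
        self.dictionary[slot] = pos;
        Ok(None)
    }

    fn remove(&mut self, key: u64) -> Option<u64> {
        let slot = self.slot(key);
        let mut prev: Option<u32> = None; // None: the link lives in the dictionary
        let mut pos = self.dictionary[slot];
        while pos != INVALID_POS {
            let next = self.buckets[pos as usize].next;
            if matches!(&self.buckets[pos as usize].inner, Some((k, _)) if *k == key) {
                // Unlink from the chain, then push the bucket onto the freelist.
                match prev {
                    None => self.dictionary[slot] = next,
                    Some(p) => self.buckets[p as usize].next = next,
                }
                let old = self.buckets[pos as usize].inner.take();
                self.buckets[pos as usize].next = self.free_head;
                self.free_head = pos;
                return old.map(|(_, v)| v);
            }
            prev = Some(pos);
            pos = next;
        }
        None
    }
}
```

In this model, as in the change, buckets never move: only the `next` links and the dictionary slots are rewritten, so a bucket index stays valid for as long as the entry lives.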


@@ -0,0 +1,130 @@
//! Equivalent of [`std::collections::hash_map::Entry`] for this hashmap.
use crate::hash::core::{CoreHashMap, FullError, INVALID_POS};
use crate::sync::{RwLockWriteGuard, ValueWriteGuard};
use std::hash::Hash;
use std::mem;
pub enum Entry<'a, 'b, K, V> {
Occupied(OccupiedEntry<'a, 'b, K, V>),
Vacant(VacantEntry<'a, 'b, K, V>),
}
/// Enum representing the previous position within a chain.
#[derive(Clone, Copy)]
pub(crate) enum PrevPos {
/// Starting index within the dictionary.
First(u32),
/// Regular index within the buckets.
Chained(u32),
/// Unknown - e.g. the associated entry was retrieved by index instead of chain.
Unknown(u64),
}
pub struct OccupiedEntry<'a, 'b, K, V> {
/// Mutable reference to the map containing this entry.
pub(crate) map: RwLockWriteGuard<'b, CoreHashMap<'a, K, V>>,
/// The key of the occupied entry
pub(crate) _key: K,
/// The index of the previous entry in the chain.
pub(crate) prev_pos: PrevPos,
/// The position of the bucket in the [`CoreHashMap`] bucket array.
pub(crate) bucket_pos: u32,
}
impl<K, V> OccupiedEntry<'_, '_, K, V> {
pub fn get(&self) -> &V {
&self.map.buckets[self.bucket_pos as usize]
.inner
.as_ref()
.unwrap()
.1
}
pub fn get_mut(&mut self) -> &mut V {
&mut self.map.buckets[self.bucket_pos as usize]
.inner
.as_mut()
.unwrap()
.1
}
/// Inserts a value into the entry, replacing (and returning) the existing value.
pub fn insert(&mut self, value: V) -> V {
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
// This assumes inner is Some, which it must be for an OccupiedEntry
mem::replace(&mut bucket.inner.as_mut().unwrap().1, value)
}
/// Removes the entry from the hash map, returning the value originally stored within it.
///
/// This may result in multiple bucket accesses if the entry was obtained by index as the
/// previous chain entry needs to be discovered in this case.
pub fn remove(mut self) -> V {
// If this bucket was queried by index, go ahead and follow its chain from the start.
let prev = if let PrevPos::Unknown(hash) = self.prev_pos {
let dict_idx = hash as usize % self.map.dictionary.len();
let mut prev = PrevPos::First(dict_idx as u32);
let mut curr = self.map.dictionary[dict_idx];
while curr != self.bucket_pos {
assert!(curr != INVALID_POS);
prev = PrevPos::Chained(curr);
curr = self.map.buckets[curr as usize].next;
}
prev
} else {
self.prev_pos
};
// The bucket is known to be occupied, since this is an OccupiedEntry.
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
// unlink it from the chain
match prev {
PrevPos::First(dict_pos) => {
self.map.dictionary[dict_pos as usize] = bucket.next;
}
PrevPos::Chained(bucket_pos) => {
self.map.buckets[bucket_pos as usize].next = bucket.next;
}
_ => unreachable!(),
}
// and add it to the freelist
let free = self.map.free_head;
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
let old_value = bucket.inner.take();
bucket.next = free;
self.map.free_head = self.bucket_pos;
self.map.buckets_in_use -= 1;
old_value.unwrap().1
}
}
/// An abstract view into a vacant entry within the map.
pub struct VacantEntry<'a, 'b, K, V> {
/// Mutable reference to the map containing this entry.
pub(crate) map: RwLockWriteGuard<'b, CoreHashMap<'a, K, V>>,
/// The key to be inserted into this entry.
pub(crate) key: K,
/// The position within the dictionary corresponding to the key's hash.
pub(crate) dict_pos: u32,
}
impl<'b, K: Clone + Hash + Eq, V> VacantEntry<'_, 'b, K, V> {
/// Insert a value into the vacant entry, finding and populating an empty bucket in the process.
///
/// # Errors
/// Will return [`FullError`] if there are no unoccupied buckets in the map.
pub fn insert(mut self, value: V) -> Result<ValueWriteGuard<'b, V>, FullError> {
let pos = self.map.alloc_bucket(self.key, value)?;
self.map.buckets[pos as usize].next = self.map.dictionary[self.dict_pos as usize];
self.map.dictionary[self.dict_pos as usize] = pos;
Ok(RwLockWriteGuard::map(self.map, |m| {
&mut m.buckets[pos as usize].inner.as_mut().unwrap().1
}))
}
}
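
Review note on `OccupiedEntry::remove`: when an entry was obtained by bucket index, its predecessor in the hash chain is unknown and has to be rediscovered by walking the chain from the dictionary slot. The sketch below isolates that walk. It is an illustration, not the code from this change: the chain links are passed as a plain `next` slice, and `Prev`, `find_prev`, and `unlink` are hypothetical names. Unlike the change, which asserts, this version returns `None` when the target is not on the chain.

```rust
/// Sentinel meaning "no bucket", mirroring `INVALID_POS` in the change.
const INVALID_POS: u32 = u32::MAX;

/// Where the link pointing at a bucket lives.
#[derive(Debug, PartialEq, Clone, Copy)]
enum Prev {
    /// The link is the dictionary slot with this index.
    Dict(usize),
    /// The link is the `next` field of this bucket.
    Bucket(u32),
}

/// Walk the chain selected by `hash` until `target` is reached and report
/// which link points at it. `next[i]` is the chain link of bucket `i`.
fn find_prev(dictionary: &[u32], next: &[u32], hash: u64, target: u32) -> Option<Prev> {
    if target == INVALID_POS {
        return None;
    }
    let slot = (hash % dictionary.len() as u64) as usize;
    let mut prev = Prev::Dict(slot);
    let mut curr = dictionary[slot];
    while curr != target {
        if curr == INVALID_POS {
            // Reached the end of the chain without finding the target.
            return None;
        }
        prev = Prev::Bucket(curr);
        curr = next[curr as usize];
    }
    Some(prev)
}

/// Unlink `target` from its chain by redirecting the link that points at it.
/// Returns `false` if the target is not on the chain selected by `hash`.
fn unlink(dictionary: &mut [u32], next: &mut [u32], hash: u64, target: u32) -> bool {
    match find_prev(dictionary, next, hash, target) {
        Some(Prev::Dict(slot)) => {
            dictionary[slot] = next[target as usize];
            true
        }
        Some(Prev::Bucket(p)) => {
            next[p as usize] = next[target as usize];
            true
        }
        None => false,
    }
}
```

The cost of this walk is proportional to the chain length, which matches the doc comment's warning that removal by index may touch multiple buckets.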


@@ -0,0 +1,428 @@
use std::collections::BTreeMap;
use std::collections::HashSet;
use std::fmt::Debug;
use std::mem::MaybeUninit;
use crate::hash::Entry;
use crate::hash::HashMapAccess;
use crate::hash::HashMapInit;
use crate::hash::core::FullError;
use rand::seq::SliceRandom;
use rand::{Rng, RngCore};
use rand_distr::Zipf;
const TEST_KEY_LEN: usize = 16;
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct TestKey([u8; TEST_KEY_LEN]);
impl From<&TestKey> for u128 {
fn from(val: &TestKey) -> u128 {
u128::from_be_bytes(val.0)
}
}
impl From<u128> for TestKey {
fn from(val: u128) -> TestKey {
TestKey(val.to_be_bytes())
}
}
impl<'a> From<&'a [u8]> for TestKey {
fn from(bytes: &'a [u8]) -> TestKey {
TestKey(bytes.try_into().unwrap())
}
}
fn test_inserts<K: Into<TestKey> + Copy>(keys: &[K]) {
let w = HashMapInit::<TestKey, usize>::new_resizeable_named(100000, 120000, "test_inserts")
.attach_writer();
for (idx, k) in keys.iter().enumerate() {
let res = w.entry((*k).into());
match res {
Entry::Occupied(mut e) => {
e.insert(idx);
}
Entry::Vacant(e) => {
let res = e.insert(idx);
assert!(res.is_ok());
}
};
}
for (idx, k) in keys.iter().enumerate() {
let x = w.get(&(*k).into());
let value = x.as_deref().copied();
assert_eq!(value, Some(idx));
}
}
#[test]
fn dense() {
// A handful of small keys plus one outlier
let keys: &[u128] = &[0, 1, 2, 3, 256];
test_inserts(keys);
// Dense keys
let mut keys: Vec<u128> = (0..10000).collect();
test_inserts(&keys);
// Do the same in random orders
for _ in 1..10 {
keys.shuffle(&mut rand::rng());
test_inserts(&keys);
}
}
#[test]
fn sparse() {
// sparse keys
let mut keys: Vec<TestKey> = Vec::new();
let mut used_keys = HashSet::new();
for _ in 0..10000 {
loop {
let key = rand::random::<u128>();
if used_keys.contains(&key) {
continue;
}
used_keys.insert(key);
keys.push(key.into());
break;
}
}
test_inserts(&keys);
}
#[derive(Clone, Debug)]
struct TestOp(TestKey, Option<usize>);
fn apply_op(
op: &TestOp,
map: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
) {
// apply the change to the shadow tree first
let shadow_existing = if let Some(v) = op.1 {
shadow.insert(op.0, v)
} else {
shadow.remove(&op.0)
};
let entry = map.entry(op.0);
let hash_existing = match op.1 {
Some(new) => match entry {
Entry::Occupied(mut e) => Some(e.insert(new)),
Entry::Vacant(e) => {
_ = e.insert(new).unwrap();
None
}
},
None => match entry {
Entry::Occupied(e) => Some(e.remove()),
Entry::Vacant(_) => None,
},
};
assert_eq!(shadow_existing, hash_existing);
}
fn do_random_ops(
num_ops: usize,
size: u32,
insert_prob: f64,
writer: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
rng: &mut rand::rngs::ThreadRng,
) {
for i in 0..num_ops {
let key: TestKey = ((rng.next_u32() % size) as u128).into();
let op = TestOp(
key,
// `true` yields an insert/update (`Some`), `false` a removal (`None`).
if rng.random_bool(insert_prob) {
Some(i)
} else {
None
},
);
apply_op(&op, writer, shadow);
}
}
fn do_deletes(
num_ops: usize,
writer: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
) {
for _ in 0..num_ops {
let (k, _) = shadow.pop_first().unwrap();
writer.remove(&k);
}
}
fn do_shrink(
writer: &mut HashMapAccess<TestKey, usize>,
shadow: &mut BTreeMap<TestKey, usize>,
from: u32,
to: u32,
) {
assert!(writer.shrink_goal().is_none());
writer.begin_shrink(to);
assert_eq!(writer.shrink_goal(), Some(to as usize));
for i in to..from {
if let Some(entry) = writer.entry_at_bucket(i as usize) {
shadow.remove(&entry._key);
entry.remove();
}
}
let old_usage = writer.get_num_buckets_in_use();
writer.finish_shrink().unwrap();
assert!(writer.shrink_goal().is_none());
assert_eq!(writer.get_num_buckets_in_use(), old_usage);
}
#[test]
fn random_ops() {
let mut writer =
HashMapInit::<TestKey, usize>::new_resizeable_named(100000, 120000, "test_random")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let distribution = Zipf::new(u128::MAX as f64, 1.1).unwrap();
let mut rng = rand::rng();
for i in 0..100000 {
let key: TestKey = (rng.sample(distribution) as u128).into();
let op = TestOp(key, if rng.random_bool(0.75) { Some(i) } else { None });
apply_op(&op, &mut writer, &mut shadow);
}
}
#[test]
fn test_shuffle() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 1200, "test_shuf")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);
writer.shuffle();
do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);
}
#[test]
fn test_grow() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 2000, "test_grow")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);
let old_usage = writer.get_num_buckets_in_use();
writer.grow(1500).unwrap();
assert_eq!(writer.get_num_buckets_in_use(), old_usage);
assert_eq!(writer.get_num_buckets(), 1500);
do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
}
#[test]
fn test_clear() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_clear")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(2000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
writer.clear();
assert_eq!(writer.get_num_buckets_in_use(), 0);
assert_eq!(writer.get_num_buckets(), 1500);
while let Some((key, _)) = shadow.pop_first() {
assert!(writer.get(&key).is_none());
}
do_random_ops(2000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
for i in 0..(1500 - writer.get_num_buckets_in_use()) {
writer.insert((1500 + i as u128).into(), 0).unwrap();
}
assert_eq!(writer.insert(5000.into(), 0), Err(FullError {}));
writer.clear();
assert!(writer.insert(5000.into(), 0).is_ok());
}
#[test]
fn test_idx_remove() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_idx_remove")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(2000, 1500, 0.25, &mut writer, &mut shadow, &mut rng);
for _ in 0..100 {
let idx = (rng.next_u32() % 1500) as usize;
if let Some(e) = writer.entry_at_bucket(idx) {
shadow.remove(&e._key);
e.remove();
}
}
while let Some((key, val)) = shadow.pop_first() {
assert_eq!(*writer.get(&key).unwrap(), val);
}
}
#[test]
fn test_idx_get() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_idx_get")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(2000, 1500, 0.25, &mut writer, &mut shadow, &mut rng);
for _ in 0..100 {
let idx = (rng.next_u32() % 1500) as usize;
if let Some(pair) = writer.get_at_bucket(idx) {
{
let v: *const usize = &pair.1;
assert_eq!(writer.get_bucket_for_value(v), idx);
}
{
let v: *const usize = &pair.1;
assert_eq!(writer.get_bucket_for_value(v), idx);
}
}
}
}
#[test]
fn test_shrink() {
let mut writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_shrink")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
do_shrink(&mut writer, &mut shadow, 1500, 1000);
assert_eq!(writer.get_num_buckets(), 1000);
do_deletes(500, &mut writer, &mut shadow);
do_random_ops(10000, 500, 0.75, &mut writer, &mut shadow, &mut rng);
assert!(writer.get_num_buckets_in_use() <= 1000);
}
#[test]
fn test_shrink_grow_seq() {
let mut writer =
HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 20000, "test_shrink_grow_seq")
.attach_writer();
let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
let mut rng = rand::rng();
do_random_ops(500, 1000, 0.1, &mut writer, &mut shadow, &mut rng);
eprintln!("Shrinking to 750");
do_shrink(&mut writer, &mut shadow, 1000, 750);
do_random_ops(200, 1000, 0.5, &mut writer, &mut shadow, &mut rng);
eprintln!("Growing to 1500");
writer.grow(1500).unwrap();
do_random_ops(600, 1500, 0.1, &mut writer, &mut shadow, &mut rng);
eprintln!("Shrinking to 200");
while shadow.len() > 100 {
do_deletes(1, &mut writer, &mut shadow);
}
do_shrink(&mut writer, &mut shadow, 1500, 200);
do_random_ops(50, 1500, 0.25, &mut writer, &mut shadow, &mut rng);
eprintln!("Growing to 10k");
writer.grow(10000).unwrap();
do_random_ops(10000, 5000, 0.25, &mut writer, &mut shadow, &mut rng);
}
#[test]
fn test_bucket_ops() {
let writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 1200, "test_bucket_ops")
.attach_writer();
match writer.entry(1.into()) {
Entry::Occupied(mut e) => {
e.insert(2);
}
Entry::Vacant(e) => {
_ = e.insert(2).unwrap();
}
}
assert_eq!(writer.get_num_buckets_in_use(), 1);
assert_eq!(writer.get_num_buckets(), 1000);
assert_eq!(*writer.get(&1.into()).unwrap(), 2);
let pos = match writer.entry(1.into()) {
Entry::Occupied(e) => {
assert_eq!(e._key, 1.into());
e.bucket_pos as usize
}
Entry::Vacant(_) => {
panic!("Insert didn't affect entry");
}
};
assert_eq!(writer.entry_at_bucket(pos).unwrap()._key, 1.into());
assert_eq!(*writer.get_at_bucket(pos).unwrap(), (1.into(), 2));
{
let ptr: *const usize = &*writer.get(&1.into()).unwrap();
assert_eq!(writer.get_bucket_for_value(ptr), pos);
}
writer.remove(&1.into());
assert!(writer.get(&1.into()).is_none());
}
#[test]
fn test_shrink_zero() {
let mut writer =
HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_shrink_zero")
.attach_writer();
writer.begin_shrink(0);
for i in 0..1500 {
writer.entry_at_bucket(i).map(|x| x.remove());
}
writer.finish_shrink().unwrap();
assert_eq!(writer.get_num_buckets_in_use(), 0);
let entry = writer.entry(1.into());
if let Entry::Vacant(v) = entry {
assert!(v.insert(2).is_err());
} else {
panic!("Somehow got non-vacant entry in empty map.")
}
writer.grow(50).unwrap();
let entry = writer.entry(1.into());
if let Entry::Vacant(v) = entry {
assert!(v.insert(2).is_ok());
} else {
panic!("Somehow got non-vacant entry in empty map.")
}
assert_eq!(writer.get_num_buckets_in_use(), 1);
}
#[test]
#[should_panic]
fn test_grow_oom() {
let writer = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_grow_oom")
.attach_writer();
writer.grow(20000).unwrap();
}
#[test]
#[should_panic]
fn test_shrink_bigger() {
let mut writer =
HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2500, "test_shrink_bigger")
.attach_writer();
writer.begin_shrink(2000);
}
#[test]
#[should_panic]
fn test_shrink_early_finish() {
let writer =
HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2500, "test_shrink_early_finish")
.attach_writer();
writer.finish_shrink().unwrap();
}
#[test]
#[should_panic]
fn test_shrink_fixed_size() {
let mut area = [MaybeUninit::uninit(); 10000];
let init_struct = HashMapInit::<TestKey, usize>::with_fixed(3, &mut area);
let mut writer = init_struct.attach_writer();
writer.begin_shrink(1);
}


@@ -1 +1,3 @@
pub mod hash;
pub mod shmem;
pub mod sync;

111
libs/neon-shmem/src/sync.rs Normal file

@@ -0,0 +1,111 @@
//! Simple utilities akin to what's in [`std::sync`] but designed to work with shared memory.
use std::mem::MaybeUninit;
use std::ptr::NonNull;
use nix::errno::Errno;
pub type RwLock<T> = lock_api::RwLock<PthreadRwLock, T>;
pub type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, PthreadRwLock, T>;
pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, PthreadRwLock, T>;
pub type ValueReadGuard<'a, T> = lock_api::MappedRwLockReadGuard<'a, PthreadRwLock, T>;
pub type ValueWriteGuard<'a, T> = lock_api::MappedRwLockWriteGuard<'a, PthreadRwLock, T>;
/// Shared memory read-write lock.
pub struct PthreadRwLock(Option<NonNull<libc::pthread_rwlock_t>>);
/// Simple macro that calls a function in the libc namespace and panics if return value is nonzero.
macro_rules! libc_checked {
($fn_name:ident ( $($arg:expr),* )) => {{
let res = libc::$fn_name($($arg),*);
if res != 0 {
panic!("{} failed with {}", stringify!($fn_name), Errno::from_raw(res));
}
}};
}
impl PthreadRwLock {
/// Creates a new `PthreadRwLock` on top of a pointer to a pthread rwlock.
///
/// # Safety
/// `lock` must be non-null and point to memory valid for a `pthread_rwlock_t`
/// for as long as the returned value is in use. Any pthread call that returns
/// an error panics.
pub unsafe fn new(lock: *mut libc::pthread_rwlock_t) -> Self {
unsafe {
let mut attrs = MaybeUninit::uninit();
libc_checked!(pthread_rwlockattr_init(attrs.as_mut_ptr()));
libc_checked!(pthread_rwlockattr_setpshared(
attrs.as_mut_ptr(),
libc::PTHREAD_PROCESS_SHARED
));
libc_checked!(pthread_rwlock_init(lock, attrs.as_mut_ptr()));
// Safety: POSIX specifies that "any function affecting the attributes
// object (including destruction) shall not affect any previously
// initialized read-write locks".
libc_checked!(pthread_rwlockattr_destroy(attrs.as_mut_ptr()));
Self(Some(NonNull::new_unchecked(lock)))
}
}
fn inner(&self) -> NonNull<libc::pthread_rwlock_t> {
match self.0 {
None => {
panic!("PthreadRwLock constructed badly - something likely used RawRwLock::INIT")
}
Some(x) => x,
}
}
}
unsafe impl lock_api::RawRwLock for PthreadRwLock {
type GuardMarker = lock_api::GuardSend;
const INIT: Self = Self(None);
fn try_lock_shared(&self) -> bool {
unsafe {
let res = libc::pthread_rwlock_tryrdlock(self.inner().as_ptr());
match res {
0 => true,
// EBUSY: a writer holds the lock; EAGAIN: the read-lock limit was hit.
libc::EBUSY | libc::EAGAIN => false,
_ => panic!(
"pthread_rwlock_tryrdlock failed with {}",
Errno::from_raw(res)
),
}
}
}
fn try_lock_exclusive(&self) -> bool {
unsafe {
let res = libc::pthread_rwlock_trywrlock(self.inner().as_ptr());
match res {
0 => true,
// POSIX reports an already-held lock as EBUSY.
libc::EBUSY => false,
_ => panic!("pthread_rwlock_trywrlock failed with {}", Errno::from_raw(res)),
}
}
}
fn lock_shared(&self) {
unsafe {
libc_checked!(pthread_rwlock_rdlock(self.inner().as_ptr()));
}
}
fn lock_exclusive(&self) {
unsafe {
libc_checked!(pthread_rwlock_wrlock(self.inner().as_ptr()));
}
}
unsafe fn unlock_exclusive(&self) {
unsafe {
libc_checked!(pthread_rwlock_unlock(self.inner().as_ptr()));
}
}
unsafe fn unlock_shared(&self) {
unsafe {
libc_checked!(pthread_rwlock_unlock(self.inner().as_ptr()));
}
}
}
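The two `try_lock_*` methods in `sync.rs` turn a pthread return code into a `bool`. A minimal, dependency-free sketch of that mapping follows. It assumes the Linux errno values (`EBUSY` = 16, `EAGAIN` = 11), hard-coded only so the snippet builds without the `libc` crate; `classify_trylock` and `TryLockOutcome` are hypothetical names, not part of `neon-shmem`. POSIX lists `EBUSY` as the code for a lock that is already held, and `EAGAIN` only for `pthread_rwlock_tryrdlock` when the reader limit is reached.

```rust
/// Linux errno values, hard-coded for this sketch (assumption).
/// Real code should use `libc::EBUSY` and `libc::EAGAIN` instead.
const EBUSY: i32 = 16;
const EAGAIN: i32 = 11;

/// Hypothetical three-way result of a pthread try-lock call.
#[derive(Debug, PartialEq, Eq)]
enum TryLockOutcome {
    /// Return code 0: the lock was taken.
    Acquired,
    /// The lock is held elsewhere (or the reader limit was hit); not an error.
    WouldBlock,
    /// Any other code is a genuine failure; the wrapper in the diff panics here.
    Failed(i32),
}

/// Classify the return value of `pthread_rwlock_tryrdlock` / `pthread_rwlock_trywrlock`.
fn classify_trylock(res: i32) -> TryLockOutcome {
    match res {
        0 => TryLockOutcome::Acquired,
        EBUSY | EAGAIN => TryLockOutcome::WouldBlock,
        other => TryLockOutcome::Failed(other),
    }
}
```

Keeping the "would block" codes separate from real failures lets a `try_lock` return `false` for ordinary contention and reserve the panic for misuse such as an uninitialized lock.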


@@ -749,7 +749,18 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
trace!("got query {query_string:?}");
if let Err(e) = handler.process_query(self, query_string).await {
match e {
QueryError::Shutdown => return Ok(ProcessMsgResult::Break),
err @ QueryError::Shutdown => {
// Notify postgres of the connection shutdown at the libpq
// protocol level. This spares postgres from having to tell an
// idle connection apart from a stale one, which is bug-prone.
let shutdown_error = short_error(&err);
self.write_message_noflush(&BeMessage::ErrorResponse(
&shutdown_error,
Some(err.pg_error_code()),
))?;
return Ok(ProcessMsgResult::Break);
}
QueryError::SimulatedConnectionError => {
return Err(QueryError::SimulatedConnectionError);
}


@@ -47,6 +47,7 @@ tracing-subscriber = { workspace = true, features = ["json", "registry"] }
tracing-utils.workspace = true
rand.workspace = true
scopeguard.workspace = true
uuid.workspace = true
strum.workspace = true
strum_macros.workspace = true
walkdir.workspace = true


@@ -12,7 +12,8 @@ use jsonwebtoken::{
Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode,
};
use pem::Pem;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use serde::{Deserialize, Deserializer, Serialize, de::DeserializeOwned};
use uuid::Uuid;
use crate::id::TenantId;
@@ -25,6 +26,11 @@ pub enum Scope {
/// Provides access to all data for a specific tenant (specified in `struct Claims` below)
// TODO: join these two?
Tenant,
/// Provides access to all data for a specific tenant, but based on endpoint ID. This token scope
/// is only used by compute to fetch the spec for a specific endpoint. The spec contains a Tenant-scoped
/// token authorizing access to all data of a tenant, so the spec-fetch API requires a TenantEndpoint
/// scope token to ensure that untrusted compute nodes can't fetch spec for arbitrary endpoints.
TenantEndpoint,
/// Provides blanket access to all tenants on the pageserver plus pageserver-wide APIs.
/// Should only be used e.g. for status check/tenant creation/list.
PageServerApi,
@@ -51,17 +57,43 @@ pub enum Scope {
ControllerPeer,
}
fn deserialize_empty_string_as_none_uuid<'de, D>(deserializer: D) -> Result<Option<Uuid>, D::Error>
where
D: Deserializer<'de>,
{
let opt = Option::<String>::deserialize(deserializer)?;
match opt.as_deref() {
Some("") => Ok(None),
Some(s) => Uuid::parse_str(s)
.map(Some)
.map_err(serde::de::Error::custom),
None => Ok(None),
}
}
/// JWT payload. See docs/authentication.md for the format
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct Claims {
#[serde(default)]
pub tenant_id: Option<TenantId>,
#[serde(
default,
skip_serializing_if = "Option::is_none",
// Neon control plane includes this field as empty in the claims.
// Consider it None in those cases.
deserialize_with = "deserialize_empty_string_as_none_uuid"
)]
pub endpoint_id: Option<Uuid>,
pub scope: Scope,
}
impl Claims {
pub fn new(tenant_id: Option<TenantId>, scope: Scope) -> Self {
Self { tenant_id, scope }
Self {
tenant_id,
scope,
endpoint_id: None,
}
}
}
@@ -212,6 +244,7 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
let expected_claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
scope: Scope::Tenant,
endpoint_id: None,
};
// A test token containing the following payload, signed using TEST_PRIV_KEY_ED25519:
@@ -240,6 +273,7 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
let claims = Claims {
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
scope: Scope::Tenant,
endpoint_id: None,
};
let pem = pem::parse(TEST_PRIV_KEY_ED25519).unwrap();
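The `deserialize_empty_string_as_none_uuid` helper in this file distinguishes three inputs: a missing value, an empty string, and a non-empty string that must parse. The sketch below shows the same three-way match using only the standard library. It is an illustration under assumptions: `parse_id` is a hypothetical stand-in for `Uuid::parse_str` that parses a hex string into a `u128`, and `empty_string_as_none` is not a function from the codebase.

```rust
use std::num::ParseIntError;

/// Hypothetical stand-in for `Uuid::parse_str`: parse a hex string into a `u128`.
fn parse_id(s: &str) -> Result<u128, ParseIntError> {
    u128::from_str_radix(s, 16)
}

/// Treat both a missing value and an empty string as `None`;
/// any other string must parse, otherwise the error is propagated.
fn empty_string_as_none(raw: Option<&str>) -> Result<Option<u128>, ParseIntError> {
    match raw {
        None | Some("") => Ok(None),
        Some(s) => parse_id(s).map(Some),
    }
}
```

Note that a malformed non-empty value is still an error rather than `None`, so only the specific "field present but empty" case is tolerated.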


@@ -873,6 +873,22 @@ impl Client {
.map_err(Error::ReceiveBody)
}
pub async fn reset_alert_gauges(&self) -> Result<()> {
let uri = format!(
"{}/hadron-internal/reset_alert_gauges",
self.mgmt_api_endpoint
);
self.start_request(Method::POST, uri)
.send()
.await
.map_err(Error::SendRequest)?
.error_from_body()
.await?
.json()
.await
.map_err(Error::ReceiveBody)
}
pub async fn wait_lsn(
&self,
tenant_shard_id: TenantShardId,


@@ -20,7 +20,8 @@ pub fn check_permission(claims: &Claims, tenant_id: Option<TenantId>) -> Result<
| Scope::GenerationsApi
| Scope::Infra
| Scope::Scrubber
| Scope::ControllerPeer,
| Scope::ControllerPeer
| Scope::TenantEndpoint,
_,
) => Err(AuthError(
format!(


@@ -813,6 +813,7 @@ impl Timeline {
let gc_cutoff_lsn_guard = self.get_applied_gc_cutoff_lsn();
let gc_cutoff_planned = {
let gc_info = self.gc_info.read().unwrap();
info!(cutoffs=?gc_info.cutoffs, applied_cutoff=%*gc_cutoff_lsn_guard, "starting find_lsn_for_timestamp");
gc_info.min_cutoff()
};
// Usually the planned cutoff is newer than the cutoff of the last gc run,


@@ -1908,20 +1908,16 @@ impl TenantShard {
.map_err(LoadLocalTimelineError::ResumeDeletion)?;
}
// Upload the tenant manifest.
//
// This is uploaded unconditionally on every attach. This prevents races where a stale,
// still-alive tenant may modify a past manifest, and a future tenant loads it after this
// tenant has acted on it. Uploading a new manifest effectively hands over ownership of the
// manifest state. See: <https://databricks.atlassian.net/browse/LKB-165>.
// Stash the preloaded tenant manifest, and upload a new manifest if changed.
//
// NB: this must happen after the tenant is fully populated above. In particular the
// offloaded timelines, which are included in the manifest.
assert!(
self.remote_tenant_manifest.lock().await.is_none(),
"tenant manifest set before attach"
);
self.maybe_upload_tenant_manifest().await?; // always uploads, remote_tenant_manifest is None
{
let mut guard = self.remote_tenant_manifest.lock().await;
assert!(guard.is_none(), "tenant manifest set before preload"); // first populated here
*guard = preload.tenant_manifest;
}
self.maybe_upload_tenant_manifest().await?;
// The local filesystem contents are a cache of what's in the remote IndexPart;
// IndexPart is the source of truth.


@@ -219,10 +219,6 @@ static char *lfc_path;
static uint64 lfc_generation;
static FileCacheControl *lfc_ctl;
static bool lfc_do_prewarm;
static shmem_startup_hook_type prev_shmem_startup_hook;
#if PG_VERSION_NUM>=150000
static shmem_request_hook_type prev_shmem_request_hook;
#endif
bool lfc_store_prefetch_result;
bool lfc_prewarm_update_ws_estimation;
@@ -342,18 +338,14 @@ lfc_ensure_opened(void)
return true;
}
static void
lfc_shmem_startup(void)
void
LfcShmemInit(void)
{
bool found;
static HASHCTL info;
if (prev_shmem_startup_hook)
{
prev_shmem_startup_hook();
}
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
if (lfc_max_size <= 0)
return;
lfc_ctl = (FileCacheControl *) ShmemInitStruct("lfc", sizeof(FileCacheControl), &found);
if (!found)
@@ -398,19 +390,16 @@ lfc_shmem_startup(void)
ConditionVariableInit(&lfc_ctl->cv[i]);
}
LWLockRelease(AddinShmemInitLock);
}
static void
lfc_shmem_request(void)
void
LfcShmemRequest(void)
{
#if PG_VERSION_NUM>=150000
if (prev_shmem_request_hook)
prev_shmem_request_hook();
#endif
RequestAddinShmemSpace(sizeof(FileCacheControl) + hash_estimate_size(SIZE_MB_TO_CHUNKS(lfc_max_size) + 1, FILE_CACHE_ENRTY_SIZE));
RequestNamedLWLockTranche("lfc_lock", 1);
if (lfc_max_size > 0)
{
RequestAddinShmemSpace(sizeof(FileCacheControl) + hash_estimate_size(SIZE_MB_TO_CHUNKS(lfc_max_size) + 1, FILE_CACHE_ENRTY_SIZE));
RequestNamedLWLockTranche("lfc_lock", 1);
}
}
static bool
@@ -642,18 +631,6 @@ lfc_init(void)
NULL,
NULL,
NULL);
if (lfc_max_size == 0)
return;
prev_shmem_startup_hook = shmem_startup_hook;
shmem_startup_hook = lfc_shmem_startup;
#if PG_VERSION_NUM>=150000
prev_shmem_request_hook = shmem_request_hook;
shmem_request_hook = lfc_shmem_request;
#else
lfc_shmem_request();
#endif
}
FileCacheState*


@@ -90,6 +90,7 @@ typedef struct
{
char connstring[MAX_SHARDS][MAX_PAGESERVER_CONNSTRING_SIZE];
size_t num_shards;
size_t stripe_size;
} ShardMap;
/*
@@ -110,6 +111,11 @@ typedef struct
* has changed since last access, and to detect and retry copying the value if
* the postmaster changes the value concurrently. (Postmaster doesn't have a
* PGPROC entry and therefore cannot use LWLocks.)
*
* stripe_size is now also part of ShardMap, although it is defined by a separate GUC.
* Postgres doesn't provide any mechanism to enforce dependencies between GUCs,
* so we have to rely on the order of GUC definitions in the config file:
* "neon.stripe_size" should be defined prior to "neon.pageserver_connstring".
*/
typedef struct
{
@@ -118,10 +124,6 @@ typedef struct
ShardMap shard_map;
} PagestoreShmemState;
#if PG_VERSION_NUM >= 150000
static shmem_request_hook_type prev_shmem_request_hook = NULL;
#endif
static shmem_startup_hook_type prev_shmem_startup_hook;
static PagestoreShmemState *pagestore_shared;
static uint64 pagestore_local_counter = 0;
@@ -234,7 +236,10 @@ ParseShardMap(const char *connstr, ShardMap *result)
p = sep + 1;
}
if (result)
{
result->num_shards = nshards;
result->stripe_size = stripe_size;
}
return true;
}
@@ -295,12 +300,13 @@ AssignPageserverConnstring(const char *newval, void *extra)
* last call, terminates all existing connections to all pageservers.
*/
static void
load_shard_map(shardno_t shard_no, char *connstr_p, shardno_t *num_shards_p)
load_shard_map(shardno_t shard_no, char *connstr_p, shardno_t *num_shards_p, size_t* stripe_size_p)
{
uint64 begin_update_counter;
uint64 end_update_counter;
ShardMap *shard_map = &pagestore_shared->shard_map;
shardno_t num_shards;
size_t stripe_size;
/*
* Postmaster can update the shared memory values concurrently, in which
@@ -315,6 +321,7 @@ load_shard_map(shardno_t shard_no, char *connstr_p, shardno_t *num_shards_p)
end_update_counter = pg_atomic_read_u64(&pagestore_shared->end_update_counter);
num_shards = shard_map->num_shards;
stripe_size = shard_map->stripe_size;
if (connstr_p && shard_no < MAX_SHARDS)
strlcpy(connstr_p, shard_map->connstring[shard_no], MAX_PAGESERVER_CONNSTRING_SIZE);
pg_memory_barrier();
@@ -349,6 +356,8 @@ load_shard_map(shardno_t shard_no, char *connstr_p, shardno_t *num_shards_p)
if (num_shards_p)
*num_shards_p = num_shards;
if (stripe_size_p)
*stripe_size_p = stripe_size;
}
#define MB (1024*1024)
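`load_shard_map()` copies `num_shards` and `stripe_size` without a lock and uses the begin/end update counters to detect a concurrent change by the postmaster. The following is a rough model of that idea, not a translation of the C code: the writer side (`update`) and the exact retry condition are assumptions, since neither is visible in this diff, and `SharedShardMap` is a hypothetical type. The data fields are atomics here so the model stays free of data races in Rust.

```rust
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};

/// Hypothetical model of the shared shard map and its two update counters.
struct SharedShardMap {
    begin_update_counter: AtomicU64,
    end_update_counter: AtomicU64,
    num_shards: AtomicUsize,
    stripe_size: AtomicUsize,
}

impl SharedShardMap {
    fn new() -> Self {
        Self {
            begin_update_counter: AtomicU64::new(0),
            end_update_counter: AtomicU64::new(0),
            num_shards: AtomicUsize::new(0),
            stripe_size: AtomicUsize::new(0),
        }
    }

    /// Writer side (assumed): bump `begin`, store the fields, then bump `end`.
    /// While an update is in flight, `begin != end`.
    fn update(&self, num_shards: usize, stripe_size: usize) {
        self.begin_update_counter.fetch_add(1, Ordering::SeqCst);
        self.num_shards.store(num_shards, Ordering::SeqCst);
        self.stripe_size.store(stripe_size, Ordering::SeqCst);
        self.end_update_counter.fetch_add(1, Ordering::SeqCst);
    }

    /// Reader side: copy the fields, then accept the copy only if no update
    /// was in flight and neither counter moved while copying.
    fn load(&self) -> (usize, usize) {
        loop {
            let begin = self.begin_update_counter.load(Ordering::SeqCst);
            let end = self.end_update_counter.load(Ordering::SeqCst);
            let num_shards = self.num_shards.load(Ordering::SeqCst);
            let stripe_size = self.stripe_size.load(Ordering::SeqCst);
            if begin == end
                && begin == self.begin_update_counter.load(Ordering::SeqCst)
                && end == self.end_update_counter.load(Ordering::SeqCst)
            {
                return (num_shards, stripe_size);
            }
            std::hint::spin_loop();
        }
    }
}
```

Reading both fields inside the same validated window is what makes it safe to return `stripe_size` alongside `num_shards`: a reader never sees the shard count from one update paired with the stripe size from another.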
@@ -357,9 +366,10 @@ shardno_t
get_shard_number(BufferTag *tag)
{
shardno_t n_shards;
size_t stripe_size;
uint32 hash;
load_shard_map(0, NULL, &n_shards);
load_shard_map(0, NULL, &n_shards, &stripe_size);
#if PG_MAJORVERSION_NUM < 16
hash = murmurhash32(tag->rnode.relNode);
@@ -412,7 +422,7 @@ pageserver_connect(shardno_t shard_no, int elevel)
* Note that connstr is used both during connection start, and when we
* log the successful connection.
*/
load_shard_map(shard_no, connstr, NULL);
load_shard_map(shard_no, connstr, NULL, NULL);
switch (shard->state)
{
@@ -1284,18 +1294,12 @@ check_neon_id(char **newval, void **extra, GucSource source)
return **newval == '\0' || HexDecodeString(id, *newval, 16);
}
static Size
PagestoreShmemSize(void)
{
return add_size(sizeof(PagestoreShmemState), NeonPerfCountersShmemSize());
}
static bool
void
PagestoreShmemInit(void)
{
bool found;
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
pagestore_shared = ShmemInitStruct("libpagestore shared state",
sizeof(PagestoreShmemState),
&found);
@@ -1306,44 +1310,12 @@ PagestoreShmemInit(void)
memset(&pagestore_shared->shard_map, 0, sizeof(ShardMap));
AssignPageserverConnstring(page_server_connstring, NULL);
}
NeonPerfCountersShmemInit();
LWLockRelease(AddinShmemInitLock);
return found;
}
static void
pagestore_shmem_startup_hook(void)
void
PagestoreShmemRequest(void)
{
if (prev_shmem_startup_hook)
prev_shmem_startup_hook();
PagestoreShmemInit();
}
static void
pagestore_shmem_request(void)
{
#if PG_VERSION_NUM >= 150000
if (prev_shmem_request_hook)
prev_shmem_request_hook();
#endif
RequestAddinShmemSpace(PagestoreShmemSize());
}
static void
pagestore_prepare_shmem(void)
{
#if PG_VERSION_NUM >= 150000
prev_shmem_request_hook = shmem_request_hook;
shmem_request_hook = pagestore_shmem_request;
#else
pagestore_shmem_request();
#endif
prev_shmem_startup_hook = shmem_startup_hook;
shmem_startup_hook = pagestore_shmem_startup_hook;
RequestAddinShmemSpace(sizeof(PagestoreShmemState));
}
/*
@@ -1352,8 +1324,6 @@ pagestore_prepare_shmem(void)
void
pg_init_libpagestore(void)
{
pagestore_prepare_shmem();
DefineCustomStringVariable("neon.pageserver_connstring",
"connection string to the page server",
NULL,
@@ -1504,8 +1474,6 @@ pg_init_libpagestore(void)
0,
NULL, NULL, NULL);
relsize_hash_init();
if (page_server != NULL)
neon_log(ERROR, "libpagestore already loaded");


@@ -22,6 +22,7 @@
#include "replication/slot.h"
#include "replication/walsender.h"
#include "storage/proc.h"
#include "storage/ipc.h"
#include "funcapi.h"
#include "access/htup_details.h"
#include "utils/builtins.h"
@@ -59,11 +60,15 @@ static ExecutorEnd_hook_type prev_ExecutorEnd = NULL;
static void neon_ExecutorStart(QueryDesc *queryDesc, int eflags);
static void neon_ExecutorEnd(QueryDesc *queryDesc);
#if PG_MAJORVERSION_NUM >= 16
static shmem_startup_hook_type prev_shmem_startup_hook;
static void neon_shmem_startup_hook(void);
static void neon_shmem_request_hook(void);
#if PG_MAJORVERSION_NUM >= 15
static shmem_request_hook_type prev_shmem_request_hook = NULL;
#endif
#if PG_MAJORVERSION_NUM >= 17
uint32 WAIT_EVENT_NEON_LFC_MAINTENANCE;
uint32 WAIT_EVENT_NEON_LFC_READ;
@@ -450,15 +455,44 @@ _PG_init(void)
*/
#if PG_VERSION_NUM >= 160000
load_file("$libdir/neon_rmgr", false);
prev_shmem_startup_hook = shmem_startup_hook;
shmem_startup_hook = neon_shmem_startup_hook;
#endif
/* dummy call to a Rust function in the communicator library, to check that it works */
(void) communicator_dummy(123);
/*
* Initializing a pre-loaded Postgres extension happens in three stages:
*
* 1. _PG_init() is called early at postmaster startup. In this stage, no
* shared memory has been allocated yet. Core Postgres GUCs have been
* initialized from the config files, but notably, MaxBackends has not
* been calculated yet. In this stage, we must register any extension GUCs
* and can do other early initialization that doesn't depend on shared
* memory. In this stage we must also register "shmem request" and
* "shmem startup" hooks, to be called in stages 2 and 3.
*
* 2. After MaxBackends has been calculated, the "shmem request" hooks
* are called. The hooks can reserve shared memory by calling
* RequestAddinShmemSpace() and RequestNamedLWLockTranche(). The "shmem
* request hooks" are a new mechanism in Postgres v15. In v14 and
* below, you had to make those Requests in stage 1 already, which
* means they could not depend on MaxBackends. (See hack in
* NeonPerfCountersShmemRequest())
*
* 3. After some more runtime-computed GUCs that affect the amount of
* shared memory needed have been calculated, the "shmem startup" hooks
* are called. In this stage, we allocate any shared memory, LWLocks
* and other shared resources.
*
* Here, in the 'neon' extension, we register just one shmem request hook
* and one startup hook, which call into functions in all the subsystems
* that are part of the extension. On v14, the ShmemRequest functions are
* called in stage 1, and on v15 onwards they are called in stage 2.
*/
/* Stage 1: Define GUCs, and other early initialization */
pg_init_libpagestore();
relsize_hash_init();
lfc_init();
pg_init_walproposer();
init_lwlsncache();
@@ -561,6 +595,22 @@ _PG_init(void)
ReportSearchPath();
/*
* Register initialization hooks for stage 2. (On v14, there are no "shmem
* request" hooks, so call the ShmemRequest functions immediately.)
*/
#if PG_VERSION_NUM >= 150000
prev_shmem_request_hook = shmem_request_hook;
shmem_request_hook = neon_shmem_request_hook;
#else
neon_shmem_request_hook();
#endif
/* Register hooks for stage 3 */
prev_shmem_startup_hook = shmem_startup_hook;
shmem_startup_hook = neon_shmem_startup_hook;
/* Other misc initialization */
prev_ExecutorStart = ExecutorStart_hook;
ExecutorStart_hook = neon_ExecutorStart;
prev_ExecutorEnd = ExecutorEnd_hook;
@@ -646,7 +696,34 @@ approximate_working_set_size(PG_FUNCTION_ARGS)
PG_RETURN_INT32(dc);
}
#if PG_MAJORVERSION_NUM >= 16
/*
* Initialization stage 2: make requests for the amount of shared memory we
* will need.
*
* For a high-level explanation of the initialization process, see _PG_init().
*/
static void
neon_shmem_request_hook(void)
{
#if PG_VERSION_NUM >= 150000
if (prev_shmem_request_hook)
prev_shmem_request_hook();
#endif
LfcShmemRequest();
NeonPerfCountersShmemRequest();
PagestoreShmemRequest();
RelsizeCacheShmemRequest();
WalproposerShmemRequest();
LwLsnCacheShmemRequest();
}
/*
* Initialization stage 3: Initialize shared memory.
*
* For a high-level explanation of the initialization process, see _PG_init().
*/
static void
neon_shmem_startup_hook(void)
{
@@ -654,6 +731,15 @@ neon_shmem_startup_hook(void)
if (prev_shmem_startup_hook)
prev_shmem_startup_hook();
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
LfcShmemInit();
NeonPerfCountersShmemInit();
PagestoreShmemInit();
RelsizeCacheShmemInit();
WalproposerShmemInit();
LwLsnCacheShmemInit();
#if PG_MAJORVERSION_NUM >= 17
WAIT_EVENT_NEON_LFC_MAINTENANCE = WaitEventExtensionNew("Neon/FileCache_Maintenance");
WAIT_EVENT_NEON_LFC_READ = WaitEventExtensionNew("Neon/FileCache_Read");
@@ -666,8 +752,9 @@ neon_shmem_startup_hook(void)
WAIT_EVENT_NEON_PS_READ = WaitEventExtensionNew("Neon/PS_ReadIO");
WAIT_EVENT_NEON_WAL_DL = WaitEventExtensionNew("Neon/WAL_Download");
#endif
LWLockRelease(AddinShmemInitLock);
}
#endif
/*
* ExecutorStart hook: start up tracking if needed
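The staged initialization described in the `_PG_init()` comment (one request hook and one startup hook, each calling the previously installed hook and then fanning out to every subsystem) can be modeled outside Postgres. The sketch below is a plain-Rust model under assumptions, not Postgres code: `Postmaster`, `Subsystem`, `Extension`, and `other_extension_request` are hypothetical names, and byte counts stand in for `RequestAddinShmemSpace()` calls.

```rust
/// Hypothetical stand-in for the postmaster state the hooks act on.
#[derive(Default)]
struct Postmaster {
    /// Bytes reserved during stage 2 (models RequestAddinShmemSpace).
    requested_bytes: usize,
    /// Subsystems initialized during stage 3, in call order.
    initialized: Vec<&'static str>,
}

type Hook = fn(&mut Postmaster);

/// One subsystem of the extension, e.g. the file cache or the relsize cache.
struct Subsystem {
    name: &'static str,
    shmem_bytes: usize,
}

/// Models the extension: one saved "previous" hook per stage plus its subsystems.
struct Extension {
    prev_request_hook: Option<Hook>,
    prev_startup_hook: Option<Hook>,
    subsystems: Vec<Subsystem>,
}

impl Extension {
    /// Stage 2: chain to the previous hook, then let each subsystem reserve memory.
    fn shmem_request_hook(&self, pm: &mut Postmaster) {
        if let Some(prev) = self.prev_request_hook {
            prev(pm);
        }
        for s in &self.subsystems {
            pm.requested_bytes += s.shmem_bytes;
        }
    }

    /// Stage 3: chain to the previous hook, then initialize each subsystem in order.
    fn shmem_startup_hook(&self, pm: &mut Postmaster) {
        if let Some(prev) = self.prev_startup_hook {
            prev(pm);
        }
        for s in &self.subsystems {
            pm.initialized.push(s.name);
        }
    }
}

/// A request hook installed earlier by some other extension (hypothetical).
fn other_extension_request(pm: &mut Postmaster) {
    pm.requested_bytes += 64;
}
```

With a single pair of hooks, the previous-hook chaining lives in one place and the subsystem call order is explicit, instead of depending on the order in which each subsystem happened to install its own hook.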


@@ -70,4 +70,19 @@ extern PGDLLEXPORT void WalProposerSync(int argc, char *argv[]);
extern PGDLLEXPORT void WalProposerMain(Datum main_arg);
extern PGDLLEXPORT void LogicalSlotsMonitorMain(Datum main_arg);
extern void LfcShmemRequest(void);
extern void PagestoreShmemRequest(void);
extern void RelsizeCacheShmemRequest(void);
extern void WalproposerShmemRequest(void);
extern void LwLsnCacheShmemRequest(void);
extern void NeonPerfCountersShmemRequest(void);
extern void LfcShmemInit(void);
extern void PagestoreShmemInit(void);
extern void RelsizeCacheShmemInit(void);
extern void WalproposerShmemInit(void);
extern void LwLsnCacheShmemInit(void);
extern void NeonPerfCountersShmemInit(void);
#endif /* NEON_H */


@@ -1,5 +1,6 @@
#include "postgres.h"
#include "neon.h"
#include "neon_lwlsncache.h"
#include "miscadmin.h"
@@ -81,14 +82,6 @@ static set_max_lwlsn_hook_type prev_set_max_lwlsn_hook = NULL;
static set_lwlsn_relation_hook_type prev_set_lwlsn_relation_hook = NULL;
static set_lwlsn_db_hook_type prev_set_lwlsn_db_hook = NULL;
static shmem_startup_hook_type prev_shmem_startup_hook;
#if PG_VERSION_NUM >= 150000
static shmem_request_hook_type prev_shmem_request_hook;
#endif
static void shmemrequest(void);
static void shmeminit(void);
static void neon_set_max_lwlsn(XLogRecPtr lsn);
void
@@ -99,16 +92,6 @@ init_lwlsncache(void)
lwlc_register_gucs();
prev_shmem_startup_hook = shmem_startup_hook;
shmem_startup_hook = shmeminit;
#if PG_VERSION_NUM >= 150000
prev_shmem_request_hook = shmem_request_hook;
shmem_request_hook = shmemrequest;
#else
shmemrequest();
#endif
prev_set_lwlsn_block_range_hook = set_lwlsn_block_range_hook;
set_lwlsn_block_range_hook = neon_set_lwlsn_block_range;
prev_set_lwlsn_block_v_hook = set_lwlsn_block_v_hook;
@@ -124,20 +107,19 @@ init_lwlsncache(void)
}
static void shmemrequest(void) {
void
LwLsnCacheShmemRequest(void)
{
Size requested_size = sizeof(LwLsnCacheCtl);
requested_size += hash_estimate_size(lwlsn_cache_size, sizeof(LastWrittenLsnCacheEntry));
RequestAddinShmemSpace(requested_size);
#if PG_VERSION_NUM >= 150000
if (prev_shmem_request_hook)
prev_shmem_request_hook();
#endif
}
static void shmeminit(void) {
void
LwLsnCacheShmemInit(void)
{
static HASHCTL info;
bool found;
if (lwlsn_cache_size > 0)
@@ -157,9 +139,6 @@ static void shmeminit(void) {
}
dlist_init(&LwLsnCache->lastWrittenLsnLRU);
LwLsnCache->maxLastWrittenLsn = GetRedoRecPtr();
if (prev_shmem_startup_hook) {
prev_shmem_startup_hook();
}
}
/*


@@ -17,22 +17,32 @@
#include "storage/shmem.h"
#include "utils/builtins.h"
#include "neon.h"
#include "neon_perf_counters.h"
#include "neon_pgversioncompat.h"
neon_per_backend_counters *neon_per_backend_counters_shared;
Size
NeonPerfCountersShmemSize(void)
void
NeonPerfCountersShmemRequest(void)
{
Size size = 0;
size = add_size(size, mul_size(NUM_NEON_PERF_COUNTER_SLOTS,
sizeof(neon_per_backend_counters)));
return size;
Size size;
#if PG_MAJORVERSION_NUM < 15
/*
* Hack: in PG14, MaxBackends is not yet initialized when
* NeonPerfCountersShmemRequest() is called. Initialize it ourselves, then
* reset it to prevent an assertion failure.
*/
Assert(MaxBackends == 0); /* not initialized yet */
InitializeMaxBackends();
size = mul_size(NUM_NEON_PERF_COUNTER_SLOTS, sizeof(neon_per_backend_counters));
MaxBackends = 0;
#else
size = mul_size(NUM_NEON_PERF_COUNTER_SLOTS, sizeof(neon_per_backend_counters));
#endif
RequestAddinShmemSpace(size);
}
void
NeonPerfCountersShmemInit(void)
{


@@ -10,6 +10,7 @@
*/
#include "postgres.h"
#include "neon.h"
#include "neon_pgversioncompat.h"
#include "pagestore_client.h"
@@ -49,32 +50,23 @@ typedef struct
* algorithm */
} RelSizeHashControl;
static HTAB *relsize_hash;
static LWLockId relsize_lock;
static int relsize_hash_size;
static RelSizeHashControl* relsize_ctl;
static shmem_startup_hook_type prev_shmem_startup_hook = NULL;
#if PG_VERSION_NUM >= 150000
static shmem_request_hook_type prev_shmem_request_hook = NULL;
static void relsize_shmem_request(void);
#endif
/*
* Size of a cache entry is 36 bytes. So this default will take about 2.3 MB,
* which seems reasonable.
*/
#define DEFAULT_RELSIZE_HASH_SIZE (64 * 1024)
static void
neon_smgr_shmem_startup(void)
static HTAB *relsize_hash;
static LWLockId relsize_lock;
static int relsize_hash_size = DEFAULT_RELSIZE_HASH_SIZE;
static RelSizeHashControl* relsize_ctl;
void
RelsizeCacheShmemInit(void)
{
static HASHCTL info;
bool found;
if (prev_shmem_startup_hook)
prev_shmem_startup_hook();
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
relsize_ctl = (RelSizeHashControl *) ShmemInitStruct("relsize_hash", sizeof(RelSizeHashControl), &found);
if (!found)
{
@@ -85,7 +77,6 @@ neon_smgr_shmem_startup(void)
relsize_hash_size, relsize_hash_size,
&info,
HASH_ELEM | HASH_BLOBS);
LWLockRelease(AddinShmemInitLock);
relsize_ctl->size = 0;
relsize_ctl->hits = 0;
relsize_ctl->misses = 0;
@@ -242,34 +233,15 @@ relsize_hash_init(void)
PGC_POSTMASTER,
0,
NULL, NULL, NULL);
if (relsize_hash_size > 0)
{
#if PG_VERSION_NUM >= 150000
prev_shmem_request_hook = shmem_request_hook;
shmem_request_hook = relsize_shmem_request;
#else
RequestAddinShmemSpace(hash_estimate_size(relsize_hash_size, sizeof(RelSizeEntry)));
RequestNamedLWLockTranche("neon_relsize", 1);
#endif
prev_shmem_startup_hook = shmem_startup_hook;
shmem_startup_hook = neon_smgr_shmem_startup;
}
}
#if PG_VERSION_NUM >= 150000
/*
* shmem_request hook: request additional shared resources. We'll allocate or
* attach to the shared resources in neon_smgr_shmem_startup().
*/
static void
relsize_shmem_request(void)
void
RelsizeCacheShmemRequest(void)
{
if (prev_shmem_request_hook)
prev_shmem_request_hook();
RequestAddinShmemSpace(sizeof(RelSizeHashControl) + hash_estimate_size(relsize_hash_size, sizeof(RelSizeEntry)));
RequestNamedLWLockTranche("neon_relsize", 1);
}
#endif


@@ -83,10 +83,8 @@ static XLogRecPtr standby_flush_lsn = InvalidXLogRecPtr;
static XLogRecPtr standby_apply_lsn = InvalidXLogRecPtr;
static HotStandbyFeedback agg_hs_feedback;
static void nwp_shmem_startup_hook(void);
static void nwp_register_gucs(void);
static void assign_neon_safekeepers(const char *newval, void *extra);
static void nwp_prepare_shmem(void);
static uint64 backpressure_lag_impl(void);
static uint64 startup_backpressure_wrap(void);
static bool backpressure_throttling_impl(void);
@@ -99,11 +97,6 @@ static TimestampTz walprop_pg_get_current_timestamp(WalProposer *wp);
static void walprop_pg_load_libpqwalreceiver(void);
static process_interrupts_callback_t PrevProcessInterruptsCallback = NULL;
static shmem_startup_hook_type prev_shmem_startup_hook_type;
#if PG_VERSION_NUM >= 150000
static shmem_request_hook_type prev_shmem_request_hook = NULL;
static void walproposer_shmem_request(void);
#endif
static void WalproposerShmemInit_SyncSafekeeper(void);
@@ -193,8 +186,6 @@ pg_init_walproposer(void)
nwp_register_gucs();
nwp_prepare_shmem();
delay_backend_us = &startup_backpressure_wrap;
PrevProcessInterruptsCallback = ProcessInterruptsCallback;
ProcessInterruptsCallback = backpressure_throttling_impl;
@@ -494,12 +485,11 @@ WalproposerShmemSize(void)
return sizeof(WalproposerShmemState);
}
static bool
void
WalproposerShmemInit(void)
{
bool found;
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
walprop_shared = ShmemInitStruct("Walproposer shared state",
sizeof(WalproposerShmemState),
&found);
@@ -517,9 +507,6 @@ WalproposerShmemInit(void)
pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.last_recorded_time_us, 0);
/* END_HADRON */
}
LWLockRelease(AddinShmemInitLock);
return found;
}
static void
@@ -623,42 +610,15 @@ walprop_register_bgworker(void)
/* shmem handling */
static void
nwp_prepare_shmem(void)
{
#if PG_VERSION_NUM >= 150000
prev_shmem_request_hook = shmem_request_hook;
shmem_request_hook = walproposer_shmem_request;
#else
RequestAddinShmemSpace(WalproposerShmemSize());
#endif
prev_shmem_startup_hook_type = shmem_startup_hook;
shmem_startup_hook = nwp_shmem_startup_hook;
}
#if PG_VERSION_NUM >= 150000
/*
* shmem_request hook: request additional shared resources. We'll allocate or
* attach to the shared resources in nwp_shmem_startup_hook().
* attach to the shared resources in WalproposerShmemInit().
*/
static void
walproposer_shmem_request(void)
void
WalproposerShmemRequest(void)
{
if (prev_shmem_request_hook)
prev_shmem_request_hook();
RequestAddinShmemSpace(WalproposerShmemSize());
}
#endif
static void
nwp_shmem_startup_hook(void)
{
if (prev_shmem_startup_hook_type)
prev_shmem_startup_hook_type();
WalproposerShmemInit();
}
WalproposerShmemState *
GetWalpropShmemState(void)

@@ -98,6 +98,7 @@ tracing-log.workspace = true
tracing-opentelemetry.workspace = true
try-lock.workspace = true
typed-json.workspace = true
type-safe-id = { version = "0.3.3", features = ["serde"] }
url.workspace = true
urlencoding.workspace = true
utils.workspace = true

@@ -26,6 +26,7 @@ use utils::project_git_version;
use utils::sentry_init::init_sentry;
use crate::context::RequestContext;
use crate::id::{ClientConnId, RequestId};
use crate::metrics::{Metrics, ThreadPoolMetrics};
use crate::pglb::TlsRequired;
use crate::pqproto::FeStartupPacket;
@@ -219,7 +220,8 @@ pub(super) async fn task_main(
{
let (socket, peer_addr) = accept_result?;
let session_id = uuid::Uuid::new_v4();
let conn_id = ClientConnId::new();
let session_id = RequestId::from_uuid(conn_id.uuid());
let tls_config = Arc::clone(&tls_config);
let dest_suffix = Arc::clone(&dest_suffix);
let compute_tls_config = compute_tls_config.clone();
@@ -231,6 +233,7 @@ pub(super) async fn task_main(
.context("failed to set socket option")?;
let ctx = RequestContext::new(
conn_id,
session_id,
ConnectionInfo {
addr: peer_addr,
@@ -252,7 +255,7 @@ pub(super) async fn task_main(
// Acknowledge that the task has finished with an error.
error!("per-client task finished with an error: {e:#}");
})
.instrument(tracing::info_span!("handle_client", ?session_id)),
.instrument(tracing::info_span!("handle_client", %session_id)),
);
}

@@ -10,6 +10,7 @@ use tokio::time::Instant;
use tracing::{debug, info};
use crate::config::ProjectInfoCacheOptions;
use crate::control_plane::messages::{ControlPlaneErrorMessage, Reason};
use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
use crate::types::{EndpointId, RoleName};
@@ -36,22 +37,37 @@ impl<T> Entry<T> {
}
pub(crate) fn get(&self) -> Option<&T> {
(self.expires_at > Instant::now()).then_some(&self.value)
(!self.is_expired()).then_some(&self.value)
}
fn is_expired(&self) -> bool {
self.expires_at <= Instant::now()
}
}
struct EndpointInfo {
role_controls: HashMap<RoleNameInt, Entry<RoleAccessControl>>,
controls: Option<Entry<EndpointAccessControl>>,
role_controls: HashMap<RoleNameInt, Entry<ControlPlaneResult<RoleAccessControl>>>,
controls: Option<Entry<ControlPlaneResult<EndpointAccessControl>>>,
}
type ControlPlaneResult<T> = Result<T, Box<ControlPlaneErrorMessage>>;
impl EndpointInfo {
pub(crate) fn get_role_secret(&self, role_name: RoleNameInt) -> Option<RoleAccessControl> {
self.role_controls.get(&role_name)?.get().cloned()
pub(crate) fn get_role_secret_with_ttl(
&self,
role_name: RoleNameInt,
) -> Option<(ControlPlaneResult<RoleAccessControl>, Duration)> {
let entry = self.role_controls.get(&role_name)?;
let ttl = entry.expires_at - Instant::now();
Some((entry.get()?.clone(), ttl))
}
pub(crate) fn get_controls(&self) -> Option<EndpointAccessControl> {
self.controls.as_ref()?.get().cloned()
pub(crate) fn get_controls_with_ttl(
&self,
) -> Option<(ControlPlaneResult<EndpointAccessControl>, Duration)> {
let entry = self.controls.as_ref()?;
let ttl = entry.expires_at - Instant::now();
Some((entry.get()?.clone(), ttl))
}
pub(crate) fn invalidate_endpoint(&mut self) {
@@ -153,28 +169,28 @@ impl ProjectInfoCacheImpl {
self.cache.get(&endpoint_id)
}
pub(crate) fn get_role_secret(
pub(crate) fn get_role_secret_with_ttl(
&self,
endpoint_id: &EndpointId,
role_name: &RoleName,
) -> Option<RoleAccessControl> {
) -> Option<(ControlPlaneResult<RoleAccessControl>, Duration)> {
let role_name = RoleNameInt::get(role_name)?;
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
endpoint_info.get_role_secret(role_name)
endpoint_info.get_role_secret_with_ttl(role_name)
}
pub(crate) fn get_endpoint_access(
pub(crate) fn get_endpoint_access_with_ttl(
&self,
endpoint_id: &EndpointId,
) -> Option<EndpointAccessControl> {
) -> Option<(ControlPlaneResult<EndpointAccessControl>, Duration)> {
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
endpoint_info.get_controls()
endpoint_info.get_controls_with_ttl()
}
pub(crate) fn insert_endpoint_access(
&self,
account_id: Option<AccountIdInt>,
project_id: ProjectIdInt,
project_id: Option<ProjectIdInt>,
endpoint_id: EndpointIdInt,
role_name: RoleNameInt,
controls: EndpointAccessControl,
@@ -183,26 +199,89 @@ impl ProjectInfoCacheImpl {
if let Some(account_id) = account_id {
self.insert_account2endpoint(account_id, endpoint_id);
}
self.insert_project2endpoint(project_id, endpoint_id);
if let Some(project_id) = project_id {
self.insert_project2endpoint(project_id, endpoint_id);
}
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
let controls = Entry::new(controls, self.config.ttl);
let role_controls = Entry::new(role_controls, self.config.ttl);
debug!(
key = &*endpoint_id,
"created a cache entry for endpoint access"
);
let controls = Some(Entry::new(Ok(controls), self.config.ttl));
let role_controls = Entry::new(Ok(role_controls), self.config.ttl);
match self.cache.entry(endpoint_id) {
clashmap::Entry::Vacant(e) => {
e.insert(EndpointInfo {
role_controls: HashMap::from_iter([(role_name, role_controls)]),
controls: Some(controls),
controls,
});
}
clashmap::Entry::Occupied(mut e) => {
let ep = e.get_mut();
ep.controls = Some(controls);
ep.controls = controls;
if ep.role_controls.len() < self.config.max_roles {
ep.role_controls.insert(role_name, role_controls);
}
}
}
}
pub(crate) fn insert_endpoint_access_err(
&self,
endpoint_id: EndpointIdInt,
role_name: RoleNameInt,
msg: Box<ControlPlaneErrorMessage>,
ttl: Option<Duration>,
) {
if self.cache.len() >= self.config.size {
// If there are too many entries, wait until the next gc cycle.
return;
}
debug!(
key = &*endpoint_id,
"created a cache entry for an endpoint access error"
);
let ttl = ttl.unwrap_or(self.config.ttl);
let controls = if msg.get_reason() == Reason::RoleProtected {
// RoleProtected is the only role-specific error that control plane can give us.
// If a given role name does not exist, it still returns a successful response,
// just with an empty secret.
None
} else {
// We can cache all the other errors in EndpointInfo.controls,
// because they don't depend on what role name we pass to control plane.
Some(Entry::new(Err(msg.clone()), ttl))
};
let role_controls = Entry::new(Err(msg), ttl);
match self.cache.entry(endpoint_id) {
clashmap::Entry::Vacant(e) => {
e.insert(EndpointInfo {
role_controls: HashMap::from_iter([(role_name, role_controls)]),
controls,
});
}
clashmap::Entry::Occupied(mut e) => {
let ep = e.get_mut();
if let Some(entry) = &ep.controls
&& !entry.is_expired()
&& entry.value.is_ok()
{
// If we have cached non-expired, non-error controls, keep them.
} else {
ep.controls = controls;
}
if ep.role_controls.len() < self.config.max_roles {
ep.role_controls.insert(role_name, role_controls);
}
@@ -245,7 +324,7 @@ impl ProjectInfoCacheImpl {
return;
};
if role_controls.get().expires_at <= Instant::now() {
if role_controls.get().is_expired() {
role_controls.remove();
}
}
@@ -287,10 +366,9 @@ mod tests {
use std::sync::Arc;
use super::*;
use crate::control_plane::messages::EndpointRateLimitConfig;
use crate::control_plane::messages::{Details, EndpointRateLimitConfig, ErrorInfo, Status};
use crate::control_plane::{AccessBlockerFlags, AuthSecret};
use crate::scram::ServerSecret;
use crate::types::ProjectId;
#[tokio::test]
async fn test_project_info_cache_settings() {
@@ -301,9 +379,9 @@ mod tests {
ttl: Duration::from_secs(1),
gc_interval: Duration::from_secs(600),
});
let project_id: ProjectId = "project".into();
let project_id: Option<ProjectIdInt> = Some(ProjectIdInt::from(&"project".into()));
let endpoint_id: EndpointId = "endpoint".into();
let account_id: Option<AccountIdInt> = None;
let account_id = None;
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
@@ -316,7 +394,7 @@ mod tests {
cache.insert_endpoint_access(
account_id,
(&project_id).into(),
project_id,
(&endpoint_id).into(),
(&user1).into(),
EndpointAccessControl {
@@ -332,7 +410,7 @@ mod tests {
cache.insert_endpoint_access(
account_id,
(&project_id).into(),
project_id,
(&endpoint_id).into(),
(&user2).into(),
EndpointAccessControl {
@@ -346,11 +424,17 @@ mod tests {
},
);
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
assert_eq!(cached.secret, secret1);
let (cached, ttl) = cache
.get_role_secret_with_ttl(&endpoint_id, &user1)
.unwrap();
assert_eq!(cached.unwrap().secret, secret1);
assert_eq!(ttl, cache.config.ttl);
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
assert_eq!(cached.secret, secret2);
let (cached, ttl) = cache
.get_role_secret_with_ttl(&endpoint_id, &user2)
.unwrap();
assert_eq!(cached.unwrap().secret, secret2);
assert_eq!(ttl, cache.config.ttl);
// Shouldn't add more than 2 roles.
let user3: RoleName = "user3".into();
@@ -358,7 +442,7 @@ mod tests {
cache.insert_endpoint_access(
account_id,
(&project_id).into(),
project_id,
(&endpoint_id).into(),
(&user3).into(),
EndpointAccessControl {
@@ -372,17 +456,144 @@ mod tests {
},
);
assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());
assert!(
cache
.get_role_secret_with_ttl(&endpoint_id, &user3)
.is_none()
);
let cached = cache.get_endpoint_access(&endpoint_id).unwrap();
let cached = cache
.get_endpoint_access_with_ttl(&endpoint_id)
.unwrap()
.0
.unwrap();
assert_eq!(cached.allowed_ips, allowed_ips);
tokio::time::advance(Duration::from_secs(2)).await;
let cached = cache.get_role_secret(&endpoint_id, &user1);
let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user1);
assert!(cached.is_none());
let cached = cache.get_role_secret(&endpoint_id, &user2);
let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user2);
assert!(cached.is_none());
let cached = cache.get_endpoint_access(&endpoint_id);
let cached = cache.get_endpoint_access_with_ttl(&endpoint_id);
assert!(cached.is_none());
}
#[tokio::test]
async fn test_caching_project_info_errors() {
let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
size: 10,
max_roles: 10,
ttl: Duration::from_secs(1),
gc_interval: Duration::from_secs(600),
});
let project_id = Some(ProjectIdInt::from(&"project".into()));
let endpoint_id: EndpointId = "endpoint".into();
let account_id = None;
let user1: RoleName = "user1".into();
let user2: RoleName = "user2".into();
let secret = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
let role_msg = Box::new(ControlPlaneErrorMessage {
error: "role is protected and cannot be used for password-based authentication"
.to_owned()
.into_boxed_str(),
http_status_code: http::StatusCode::NOT_FOUND,
status: Some(Status {
code: "PERMISSION_DENIED".to_owned().into_boxed_str(),
message: "role is protected and cannot be used for password-based authentication"
.to_owned()
.into_boxed_str(),
details: Details {
error_info: Some(ErrorInfo {
reason: Reason::RoleProtected,
}),
retry_info: None,
user_facing_message: None,
},
}),
});
let generic_msg = Box::new(ControlPlaneErrorMessage {
error: "oh noes".to_owned().into_boxed_str(),
http_status_code: http::StatusCode::NOT_FOUND,
status: None,
});
let get_role_secret = |endpoint_id, role_name| {
cache
.get_role_secret_with_ttl(endpoint_id, role_name)
.unwrap()
.0
};
let get_endpoint_access =
|endpoint_id| cache.get_endpoint_access_with_ttl(endpoint_id).unwrap().0;
// stores role-specific errors only for get_role_secret
cache.insert_endpoint_access_err(
(&endpoint_id).into(),
(&user1).into(),
role_msg.clone(),
None,
);
assert_eq!(
get_role_secret(&endpoint_id, &user1).unwrap_err().error,
role_msg.error
);
assert!(cache.get_endpoint_access_with_ttl(&endpoint_id).is_none());
// stores non-role specific errors for both get_role_secret and get_endpoint_access
cache.insert_endpoint_access_err(
(&endpoint_id).into(),
(&user1).into(),
generic_msg.clone(),
None,
);
assert_eq!(
get_role_secret(&endpoint_id, &user1).unwrap_err().error,
generic_msg.error
);
assert_eq!(
get_endpoint_access(&endpoint_id).unwrap_err().error,
generic_msg.error
);
// error isn't returned for other roles in the same endpoint
assert!(
cache
.get_role_secret_with_ttl(&endpoint_id, &user2)
.is_none()
);
// success for a role does not overwrite errors for other roles
cache.insert_endpoint_access(
account_id,
project_id,
(&endpoint_id).into(),
(&user2).into(),
EndpointAccessControl {
allowed_ips: Arc::new(vec![]),
allowed_vpce: Arc::new(vec![]),
flags: AccessBlockerFlags::default(),
rate_limits: EndpointRateLimitConfig::default(),
},
RoleAccessControl {
secret: secret.clone(),
},
);
assert!(get_role_secret(&endpoint_id, &user1).is_err());
assert!(get_role_secret(&endpoint_id, &user2).is_ok());
// ...but does clear the access control error
assert!(get_endpoint_access(&endpoint_id).is_ok());
// storing an error does not overwrite successful access control response
cache.insert_endpoint_access_err(
(&endpoint_id).into(),
(&user2).into(),
generic_msg.clone(),
None,
);
assert!(get_role_secret(&endpoint_id, &user2).is_err());
assert!(get_endpoint_access(&endpoint_id).is_ok());
}
}
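The cache change above stores `Result` values with an expiry and hands back the remaining TTL on lookup. Below is a minimal std-only sketch of that idea, not the proxy's actual code. The names `CachedResult`, `insert_err` and the `role_specific` flag are hypothetical stand-ins for `ControlPlaneResult`, `insert_endpoint_access_err` and the `Reason::RoleProtected` check. Errors are plain strings, the clock is passed in explicitly so the behaviour is deterministic, and the `max_roles` and cache-size limits are left out.

```rust
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// A cached value together with the instant at which it stops being valid.
#[derive(Clone, Debug)]
struct Entry<T> {
    value: T,
    expires_at: Instant,
}

impl<T> Entry<T> {
    fn new(value: T, now: Instant, ttl: Duration) -> Self {
        Self { value, expires_at: now + ttl }
    }

    fn is_expired(&self, now: Instant) -> bool {
        self.expires_at <= now
    }

    /// Returns the value and its remaining TTL, or `None` once expired.
    fn get_with_ttl(&self, now: Instant) -> Option<(&T, Duration)> {
        if self.is_expired(now) {
            return None;
        }
        Some((&self.value, self.expires_at - now))
    }
}

/// Hypothetical stand-in for `ControlPlaneResult<T>`; errors are plain strings.
type CachedResult<T> = Result<T, String>;

#[derive(Default)]
struct EndpointInfo {
    roles: HashMap<String, Entry<CachedResult<String>>>,
    controls: Option<Entry<CachedResult<Vec<String>>>>,
}

impl EndpointInfo {
    /// Records an error for `role`. A role-specific error is never stored at
    /// endpoint level, and a live successful endpoint entry is never replaced.
    fn insert_err(&mut self, role: &str, msg: &str, role_specific: bool, now: Instant, ttl: Duration) {
        let keep_existing =
            matches!(&self.controls, Some(e) if !e.is_expired(now) && e.value.is_ok());
        if !keep_existing {
            self.controls = if role_specific {
                None
            } else {
                Some(Entry::new(Err(msg.to_owned()), now, ttl))
            };
        }
        self.roles
            .insert(role.to_owned(), Entry::new(Err(msg.to_owned()), now, ttl));
    }
}
```

Passing `now` as a parameter is a choice made for this sketch only; the diff reads `Instant::now()` directly and relies on `tokio::time::advance` in its tests.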

@@ -23,6 +23,7 @@ use crate::context::RequestContext;
use crate::control_plane::ControlPlaneApi;
use crate::error::ReportableError;
use crate::ext::LockExt;
use crate::id::RequestId;
use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind};
use crate::pqproto::CancelKeyData;
use crate::rate_limiter::LeakyBucketRateLimiter;
@@ -32,8 +33,11 @@ use crate::util::run_until;
type IpSubnetKey = IpNet;
const CANCEL_KEY_TTL: Duration = Duration::from_secs(600);
const CANCEL_KEY_REFRESH: Duration = Duration::from_secs(570);
/// The initial period and TTL are shorter so that keys of short-lived connections are cleared faster.
const CANCEL_KEY_INITIAL_PERIOD: Duration = Duration::from_secs(60);
const CANCEL_KEY_REFRESH_PERIOD: Duration = Duration::from_secs(10 * 60);
/// `CANCEL_KEY_TTL_SLACK` is added to the periods to determine the actual TTL.
const CANCEL_KEY_TTL_SLACK: Duration = Duration::from_secs(30);
// Message types for sending through mpsc channel
pub enum CancelKeyOp {
@@ -54,6 +58,24 @@ pub enum CancelKeyOp {
},
}
impl CancelKeyOp {
const fn redis_msg_kind(&self) -> RedisMsgKind {
match self {
CancelKeyOp::Store { .. } => RedisMsgKind::Set,
CancelKeyOp::Refresh { .. } => RedisMsgKind::Expire,
CancelKeyOp::Get { .. } => RedisMsgKind::Get,
CancelKeyOp::GetOld { .. } => RedisMsgKind::HGet,
}
}
fn cancel_channel_metric_guard(&self) -> CancelChannelSizeGuard<'static> {
Metrics::get()
.proxy
.cancel_channel_size
.guard(self.redis_msg_kind())
}
}
#[derive(thiserror::Error, Debug, Clone)]
pub enum PipelineError {
#[error("could not send cmd to redis: {0}")]
@@ -465,7 +487,7 @@ impl Session {
/// This is not cancel safe
pub(crate) async fn maintain_cancel_key(
&self,
session_id: uuid::Uuid,
session_id: RequestId,
cancel: tokio::sync::oneshot::Receiver<Infallible>,
cancel_closure: &CancelClosure,
compute_config: &ComputeConfig,
@@ -483,50 +505,49 @@ impl Session {
let mut cancel = pin!(cancel);
enum State {
Set,
Init,
Refresh,
}
let mut state = State::Set;
let mut state = State::Init;
loop {
let guard_op = match state {
State::Set => {
let guard = Metrics::get()
.proxy
.cancel_channel_size
.guard(RedisMsgKind::Set);
let op = CancelKeyOp::Store {
key: self.key,
value: closure_json.clone(),
expire: CANCEL_KEY_TTL,
};
let (op, mut wait_interval) = match state {
State::Init => {
tracing::debug!(
src=%self.key,
dest=?cancel_closure.cancel_token,
"registering cancellation key"
);
(guard, op)
(
CancelKeyOp::Store {
key: self.key,
value: closure_json.clone(),
expire: CANCEL_KEY_INITIAL_PERIOD + CANCEL_KEY_TTL_SLACK,
},
CANCEL_KEY_INITIAL_PERIOD,
)
}
State::Refresh => {
let guard = Metrics::get()
.proxy
.cancel_channel_size
.guard(RedisMsgKind::Expire);
let op = CancelKeyOp::Refresh {
key: self.key,
expire: CANCEL_KEY_TTL,
};
tracing::debug!(
src=%self.key,
dest=?cancel_closure.cancel_token,
"refreshing cancellation key"
);
(guard, op)
(
CancelKeyOp::Refresh {
key: self.key,
expire: CANCEL_KEY_REFRESH_PERIOD + CANCEL_KEY_TTL_SLACK,
},
CANCEL_KEY_REFRESH_PERIOD,
)
}
};
match tx.call(guard_op, cancel.as_mut()).await {
match tx
.call((op.cancel_channel_metric_guard(), op), cancel.as_mut())
.await
{
// SET returns OK
Ok(Value::Okay) => {
tracing::debug!(
@@ -549,23 +570,23 @@ impl Session {
Ok(_) => {
// Any other response likely means the key expired.
tracing::warn!(src=%self.key, "refreshing cancellation key failed");
// Re-enter the SET loop to repush full data.
state = State::Set;
// Re-enter the SET loop quickly to repush full data.
state = State::Init;
wait_interval = Duration::ZERO;
}
// retry immediately.
Err(BatchQueueError::Result(error)) => {
tracing::warn!(?error, "error refreshing cancellation key");
// Small delay to prevent busy loop with high cpu and logging.
tokio::time::sleep(Duration::from_millis(10)).await;
continue;
wait_interval = Duration::from_millis(10);
}
Err(BatchQueueError::Cancelled(Err(_cancelled))) => break,
}
// wait before continuing. break immediately if cancelled.
if run_until(tokio::time::sleep(CANCEL_KEY_REFRESH), cancel.as_mut())
if run_until(tokio::time::sleep(wait_interval), cancel.as_mut())
.await
.is_err()
{
@@ -579,7 +600,7 @@ impl Session {
.await
{
tracing::warn!(
?session_id,
%session_id,
?err,
"could not cancel the query in the database"
);
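The refresh loop above derives each key TTL from the wait period plus a fixed slack. A rough std-only sketch of that schedule follows. The `Outcome` enum and the `plan` and `step` functions are hypothetical names introduced for illustration. The transition to `Refresh` after a successful reply is an assumption, because that part of the loop falls outside the hunks shown.

```rust
use std::time::Duration;

const INITIAL_PERIOD: Duration = Duration::from_secs(60);
const REFRESH_PERIOD: Duration = Duration::from_secs(10 * 60);
const TTL_SLACK: Duration = Duration::from_secs(30);

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum State {
    Init,
    Refresh,
}

/// Hypothetical summary of what the Redis call returned.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Outcome {
    Stored,
    Refreshed,
    KeyMissing,
    RedisError,
}

/// Returns (TTL sent with the command, default wait before the next command).
fn plan(state: State) -> (Duration, Duration) {
    let period = match state {
        State::Init => INITIAL_PERIOD,
        State::Refresh => REFRESH_PERIOD,
    };
    (period + TTL_SLACK, period)
}

/// Returns the next state and how long to sleep before acting on it.
fn step(state: State, outcome: Outcome) -> (State, Duration) {
    let (_, wait) = plan(state);
    match outcome {
        // Assumed: a successful reply moves the loop into the refresh cycle.
        Outcome::Stored | Outcome::Refreshed => (State::Refresh, wait),
        // The key is gone: store the full value again without waiting.
        Outcome::KeyMissing => (State::Init, Duration::ZERO),
        // Short pause so a failing Redis does not cause a busy loop.
        Outcome::RedisError => (state, Duration::from_millis(10)),
    }
}
```

With these constants the first key lives for 90 seconds and is refreshed after 60, while later refreshes set a TTL of 630 seconds every 600 seconds, so the key always outlives the next refresh by the slack.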

@@ -25,6 +25,7 @@ use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::error::{ReportableError, UserFacingError};
use crate::id::ComputeConnId;
use crate::metrics::{Metrics, NumDbConnectionsGuard};
use crate::pqproto::StartupMessageParams;
use crate::proxy::neon_option;
@@ -356,6 +357,7 @@ pub struct PostgresSettings {
}
pub struct ComputeConnection {
pub compute_conn_id: ComputeConnId,
/// Socket connected to a compute node.
pub stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
/// Labels for proxy's metrics.
@@ -373,6 +375,7 @@ impl ConnectInfo {
ctx: &RequestContext,
aux: &MetricsAuxInfo,
config: &ComputeConfig,
compute_conn_id: ComputeConnId,
) -> Result<ComputeConnection, ConnectionError> {
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
let (socket_addr, stream) = self.connect_raw(config).await?;
@@ -382,6 +385,7 @@ impl ConnectInfo {
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
info!(
%compute_conn_id,
cold_start_info = ctx.cold_start_info().as_str(),
"connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
self.host,
@@ -391,6 +395,7 @@ impl ConnectInfo {
);
let connection = ComputeConnection {
compute_conn_id,
stream,
socket_addr,
hostname: self.host.clone(),

@@ -10,7 +10,8 @@ use crate::cancellation::CancellationHandler;
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestContext;
use crate::error::ReportableError;
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::id::{ClientConnId, ComputeConnId, RequestId};
use crate::metrics::{Metrics, NumClientConnectionsGuard, Protocol};
use crate::pglb::ClientRequestError;
use crate::pglb::handshake::{HandshakeData, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
@@ -42,12 +43,10 @@ pub async fn task_main(
{
let (socket, peer_addr) = accept_result?;
let conn_gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Tcp);
let conn_gauge = Metrics::get().proxy.client_connections.guard(Protocol::Tcp);
let session_id = uuid::Uuid::new_v4();
let conn_id = ClientConnId::new();
let session_id = RequestId::from_uuid(conn_id.uuid());
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
@@ -90,7 +89,7 @@ pub async fn task_main(
}
}
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
let ctx = RequestContext::new(conn_id, session_id, conn_info, Protocol::Tcp);
let res = handle_client(
config,
@@ -120,13 +119,13 @@ pub async fn task_main(
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
error!(
?session_id,
%session_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
%session_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
@@ -214,10 +213,14 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
};
auth_info.set_startup_params(&params, true);
// for TCP/WS, we have client_id=session_id=compute_id for now.
let compute_conn_id = ComputeConnId::from_uuid(ctx.session_id().uuid());
let mut node = connect_to_compute(
ctx,
&TcpMechanism {
locks: &config.connect_compute_locks,
compute_conn_id,
},
&node_info,
config.wake_compute_retry_config,
@@ -250,6 +253,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
});
Ok(Some(ProxyPassthrough {
compute_conn_id: node.compute_conn_id,
client: stream,
compute: node.stream,

@@ -9,11 +9,11 @@ use tokio::sync::mpsc;
use tracing::field::display;
use tracing::{Span, error, info_span};
use try_lock::TryLock;
use uuid::Uuid;
use self::parquet::RequestData;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::error::ErrorKind;
use crate::id::{ClientConnId, RequestId};
use crate::intern::{BranchIdInt, ProjectIdInt};
use crate::metrics::{LatencyAccumulated, LatencyTimer, Metrics, Protocol, Waiting};
use crate::pqproto::StartupMessageParams;
@@ -40,7 +40,7 @@ pub struct RequestContext(
struct RequestContextInner {
pub(crate) conn_info: ConnectionInfo,
pub(crate) session_id: Uuid,
pub(crate) session_id: RequestId,
pub(crate) protocol: Protocol,
first_packet: chrono::DateTime<Utc>,
pub(crate) span: Span,
@@ -116,12 +116,18 @@ impl Clone for RequestContext {
}
impl RequestContext {
pub fn new(session_id: Uuid, conn_info: ConnectionInfo, protocol: Protocol) -> Self {
pub fn new(
conn_id: ClientConnId,
session_id: RequestId,
conn_info: ConnectionInfo,
protocol: Protocol,
) -> Self {
// TODO: be careful with long lived spans
let span = info_span!(
"connect_request",
%protocol,
?session_id,
%session_id,
%conn_id,
%conn_info,
ep = tracing::field::Empty,
role = tracing::field::Empty,
@@ -164,7 +170,13 @@ impl RequestContext {
let ip = IpAddr::from([127, 0, 0, 1]);
let addr = SocketAddr::new(ip, 5432);
let conn_info = ConnectionInfo { addr, extra: None };
RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp)
let uuid = uuid::Uuid::now_v7();
RequestContext::new(
ClientConnId::from_uuid(uuid),
RequestId::from_uuid(uuid),
conn_info,
Protocol::Tcp,
)
}
pub(crate) fn console_application_name(&self) -> String {
@@ -311,7 +323,7 @@ impl RequestContext {
self.0.try_lock().expect("should not deadlock").span.clone()
}
pub(crate) fn session_id(&self) -> Uuid {
pub(crate) fn session_id(&self) -> RequestId {
self.0.try_lock().expect("should not deadlock").session_id
}
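The change above replaces a bare `Uuid` with distinct `ClientConnId` and `RequestId` types, so the two cannot be swapped by accident and can be logged with `%` (Display). The contents of `crate::id` are not part of this diff, so the following is only a std-only sketch of the general pattern. It does not use the `type-safe-id` crate's API. The `Id`, `IdKind`, `from_raw` and `raw` names, the `u128` payload and the prefix strings are all assumptions made for illustration.

```rust
use std::fmt;
use std::marker::PhantomData;

/// Hypothetical marker trait: each kind of ID carries its own display prefix.
trait IdKind {
    const PREFIX: &'static str;
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct ClientConn;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Request;

impl IdKind for ClientConn {
    const PREFIX: &'static str = "client_conn";
}
impl IdKind for Request {
    const PREFIX: &'static str = "request";
}

/// A 128-bit identifier tagged with its kind at the type level.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Id<K: IdKind> {
    raw: u128,
    _kind: PhantomData<K>,
}

impl<K: IdKind> Id<K> {
    fn from_raw(raw: u128) -> Self {
        Self { raw, _kind: PhantomData }
    }

    fn raw(self) -> u128 {
        self.raw
    }
}

impl<K: IdKind> fmt::Display for Id<K> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Prefix plus zero-padded hex keeps IDs of different kinds visually distinct.
        write!(f, "{}_{:032x}", K::PREFIX, self.raw)
    }
}

type ClientConnId = Id<ClientConn>;
type RequestId = Id<Request>;
```

As in the diff, one ID can be derived from another explicitly, for example `RequestId::from_raw(conn_id.raw())`, while passing a `ClientConnId` where a `RequestId` is expected fails to compile.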

@@ -124,7 +124,7 @@ impl serde::Serialize for Options<'_> {
impl From<&RequestContextInner> for RequestData {
fn from(value: &RequestContextInner) -> Self {
Self {
session_id: value.session_id,
session_id: value.session_id.uuid(),
peer_addr: value.conn_info.addr.ip().to_string(),
timestamp: value.first_packet.naive_utc(),
username: value.user.as_deref().map(String::from),

@@ -68,6 +68,66 @@ impl NeonControlPlaneClient {
self.endpoint.url().as_str()
}
async fn get_and_cache_auth_info<T>(
&self,
ctx: &RequestContext,
endpoint: &EndpointId,
role: &RoleName,
cache_key: &EndpointId,
extract: impl FnOnce(&EndpointAccessControl, &RoleAccessControl) -> T,
) -> Result<T, GetAuthInfoError> {
match self.do_get_auth_req(ctx, endpoint, role).await {
Ok(auth_info) => {
let control = EndpointAccessControl {
allowed_ips: Arc::new(auth_info.allowed_ips),
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
flags: auth_info.access_blocker_flags,
rate_limits: auth_info.rate_limits,
};
let role_control = RoleAccessControl {
secret: auth_info.secret,
};
let res = extract(&control, &role_control);
self.caches.project_info.insert_endpoint_access(
auth_info.account_id,
auth_info.project_id,
cache_key.into(),
role.into(),
control,
role_control,
);
if let Some(project_id) = auth_info.project_id {
ctx.set_project_id(project_id);
}
Ok(res)
}
Err(err) => match err {
GetAuthInfoError::ApiError(ControlPlaneError::Message(ref msg)) => {
let retry_info = msg.status.as_ref().and_then(|s| s.details.retry_info);
// If we can retry this error, do not cache it,
// unless we were given a retry delay.
if msg.could_retry() && retry_info.is_none() {
return Err(err);
}
self.caches.project_info.insert_endpoint_access_err(
cache_key.into(),
role.into(),
msg.clone(),
retry_info.map(|r| Duration::from_millis(r.retry_delay_ms)),
);
Err(err)
}
err => Err(err),
},
}
}
async fn do_get_auth_req(
&self,
ctx: &RequestContext,
@@ -284,43 +344,34 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
ctx: &RequestContext,
endpoint: &EndpointId,
role: &RoleName,
) -> Result<RoleAccessControl, crate::control_plane::errors::GetAuthInfoError> {
let normalized_ep = &endpoint.normalize();
if let Some(secret) = self
) -> Result<RoleAccessControl, GetAuthInfoError> {
let key = endpoint.normalize();
if let Some((role_control, ttl)) = self
.caches
.project_info
.get_role_secret(normalized_ep, role)
.get_role_secret_with_ttl(&key, role)
{
return Ok(secret);
return match role_control {
Err(mut msg) => {
info!(key = &*key, "found cached get_role_access_control error");
// if retry_delay_ms is set change it to the remaining TTL
replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64);
Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg)))
}
Ok(role_control) => {
debug!(key = &*key, "found cached role access control");
Ok(role_control)
}
};
}
let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;
let control = EndpointAccessControl {
allowed_ips: Arc::new(auth_info.allowed_ips),
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
flags: auth_info.access_blocker_flags,
rate_limits: auth_info.rate_limits,
};
let role_control = RoleAccessControl {
secret: auth_info.secret,
};
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
self.caches.project_info.insert_endpoint_access(
auth_info.account_id,
project_id,
normalized_ep_int,
role.into(),
control,
role_control.clone(),
);
ctx.set_project_id(project_id);
}
Ok(role_control)
self.get_and_cache_auth_info(ctx, endpoint, role, &key, |_, role_control| {
role_control.clone()
})
.await
}
#[tracing::instrument(skip_all)]
@@ -330,38 +381,30 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
endpoint: &EndpointId,
role: &RoleName,
) -> Result<EndpointAccessControl, GetAuthInfoError> {
let normalized_ep = &endpoint.normalize();
if let Some(control) = self.caches.project_info.get_endpoint_access(normalized_ep) {
return Ok(control);
let key = endpoint.normalize();
if let Some((control, ttl)) = self.caches.project_info.get_endpoint_access_with_ttl(&key) {
return match control {
Err(mut msg) => {
info!(
key = &*key,
"found cached get_endpoint_access_control error"
);
// if retry_delay_ms is set change it to the remaining TTL
replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64);
Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg)))
}
Ok(control) => {
debug!(key = &*key, "found cached endpoint access control");
Ok(control)
}
};
}
let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;
let control = EndpointAccessControl {
allowed_ips: Arc::new(auth_info.allowed_ips),
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
flags: auth_info.access_blocker_flags,
rate_limits: auth_info.rate_limits,
};
let role_control = RoleAccessControl {
secret: auth_info.secret,
};
if let Some(project_id) = auth_info.project_id {
let normalized_ep_int = normalized_ep.into();
self.caches.project_info.insert_endpoint_access(
auth_info.account_id,
project_id,
normalized_ep_int,
role.into(),
control.clone(),
role_control,
);
ctx.set_project_id(project_id);
}
Ok(control)
self.get_and_cache_auth_info(ctx, endpoint, role, &key, |control, _| control.clone())
.await
}
#[tracing::instrument(skip_all)]
@@ -390,13 +433,9 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
info!(key = &*key, "found cached wake_compute error");
// if retry_delay_ms is set, reduce it by the amount of time it spent in cache
if let Some(status) = &mut msg.status {
if let Some(retry_info) = &mut status.details.retry_info {
retry_info.retry_delay_ms = retry_info
.retry_delay_ms
.saturating_sub(created_at.elapsed().as_millis() as u64)
}
}
replace_retry_delay_ms(&mut msg, |delay| {
delay.saturating_sub(created_at.elapsed().as_millis() as u64)
});
Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
msg,
@@ -478,6 +517,14 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
}
}
fn replace_retry_delay_ms(msg: &mut ControlPlaneErrorMessage, f: impl FnOnce(u64) -> u64) {
if let Some(status) = &mut msg.status
&& let Some(retry_info) = &mut status.details.retry_info
{
retry_info.retry_delay_ms = f(retry_info.retry_delay_ms);
}
}
/// Parse http response body, taking status code into account.
fn parse_body<T: for<'a> serde::Deserialize<'a>>(
status: StatusCode,
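Two rules in the client change above are easy to miss: when a control plane error is cached at all, and how `retry_delay_ms` is rewritten when a cached error is served. The sketch below restates both with simplified types. `ErrorMessage`, its `retryable` flag (standing in for `could_retry()`) and `error_cache_ttl` are hypothetical, and the real message nests `retry_info` under `status.details`.

```rust
use std::time::Duration;

#[derive(Clone, Copy, Debug, PartialEq)]
struct RetryInfo {
    retry_delay_ms: u64,
}

/// Simplified, hypothetical stand-in for `ControlPlaneErrorMessage`.
#[derive(Clone, Debug, Default)]
struct ErrorMessage {
    /// Stand-in for the result of `could_retry()`.
    retryable: bool,
    retry_info: Option<RetryInfo>,
}

/// Rewrites the retry delay in place; a message without retry info is left untouched.
fn replace_retry_delay_ms(msg: &mut ErrorMessage, f: impl FnOnce(u64) -> u64) {
    if let Some(info) = &mut msg.retry_info {
        info.retry_delay_ms = f(info.retry_delay_ms);
    }
}

/// Returns the TTL to cache the error with, or `None` if it should not be cached.
fn error_cache_ttl(msg: &ErrorMessage, default_ttl: Duration) -> Option<Duration> {
    match msg.retry_info {
        // Retryable with no suggested delay: let the next request try again.
        None if msg.retryable => None,
        // Not retryable: cache for the configured TTL.
        None => Some(default_ttl),
        // A suggested delay is used as the TTL, whether or not the error is retryable.
        Some(info) => Some(Duration::from_millis(info.retry_delay_ms)),
    }
}
```

The closure parameter covers both call sites in the diff: the auth-info paths ignore the old delay and substitute the remaining TTL, while the `wake_compute` path subtracts the time spent in the cache with `saturating_sub`.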

@@ -52,7 +52,7 @@ impl ReportableError for ControlPlaneError {
| Reason::EndpointNotFound
| Reason::EndpointDisabled
| Reason::BranchNotFound
| Reason::InvalidEphemeralEndpointOptions => ErrorKind::User,
| Reason::WrongLsnOrTimestamp => ErrorKind::User,
Reason::RateLimitExceeded => ErrorKind::ServiceRateLimit,

@@ -107,7 +107,7 @@ pub(crate) struct ErrorInfo {
// Schema could also have `metadata` field, but it's not structured. Skip it for now.
}
#[derive(Clone, Copy, Debug, Deserialize, Default)]
#[derive(Clone, Copy, Debug, Deserialize, Default, PartialEq, Eq)]
pub(crate) enum Reason {
/// RoleProtected indicates that the role is protected and the attempted operation is not permitted on protected roles.
#[serde(rename = "ROLE_PROTECTED")]
@@ -133,9 +133,9 @@ pub(crate) enum Reason {
/// or that the subject doesn't have enough permissions to access the requested branch.
#[serde(rename = "BRANCH_NOT_FOUND")]
BranchNotFound,
/// InvalidEphemeralEndpointOptions indicates that the specified LSN or timestamp are wrong.
#[serde(rename = "INVALID_EPHEMERAL_OPTIONS")]
InvalidEphemeralEndpointOptions,
/// WrongLsnOrTimestamp indicates that the specified LSN or timestamp are wrong.
#[serde(rename = "WRONG_LSN_OR_TIMESTAMP")]
WrongLsnOrTimestamp,
/// RateLimitExceeded indicates that the rate limit for the operation has been exceeded.
#[serde(rename = "RATE_LIMIT_EXCEEDED")]
RateLimitExceeded,
@@ -205,7 +205,7 @@ impl Reason {
| Reason::EndpointNotFound
| Reason::EndpointDisabled
| Reason::BranchNotFound
| Reason::InvalidEphemeralEndpointOptions => false,
| Reason::WrongLsnOrTimestamp => false,
// we were asked to go away
Reason::RateLimitExceeded
| Reason::NonDefaultBranchComputeTimeExceeded
@@ -257,19 +257,19 @@ pub(crate) struct GetEndpointAccessControl {
pub(crate) rate_limits: EndpointRateLimitConfig,
}
#[derive(Copy, Clone, Deserialize, Default)]
#[derive(Copy, Clone, Deserialize, Default, Debug)]
pub struct EndpointRateLimitConfig {
pub connection_attempts: ConnectionAttemptsLimit,
}
#[derive(Copy, Clone, Deserialize, Default)]
#[derive(Copy, Clone, Deserialize, Default, Debug)]
pub struct ConnectionAttemptsLimit {
pub tcp: Option<LeakyBucketSetting>,
pub ws: Option<LeakyBucketSetting>,
pub http: Option<LeakyBucketSetting>,
}
#[derive(Copy, Clone, Deserialize)]
#[derive(Copy, Clone, Deserialize, Debug)]
pub struct LeakyBucketSetting {
pub rps: f64,
pub burst: f64,

@@ -20,6 +20,7 @@ use crate::cache::{Cached, TimedLru};
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
use crate::id::ComputeConnId;
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt};
use crate::protocol2::ConnectionInfoExtra;
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig};
@@ -77,12 +78,15 @@ impl NodeInfo {
&self,
ctx: &RequestContext,
config: &ComputeConfig,
compute_conn_id: ComputeConnId,
) -> Result<compute::ComputeConnection, compute::ConnectionError> {
self.conn_info.connect(ctx, &self.aux, config).await
self.conn_info
.connect(ctx, &self.aux, config, compute_conn_id)
.await
}
}
#[derive(Copy, Clone, Default)]
#[derive(Copy, Clone, Default, Debug)]
pub(crate) struct AccessBlockerFlags {
pub public_access_blocked: bool,
pub vpc_access_blocked: bool,
@@ -92,12 +96,12 @@ pub(crate) type NodeInfoCache =
TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ControlPlaneErrorMessage>>>;
pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct RoleAccessControl {
pub secret: Option<AuthSecret>,
}
#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct EndpointAccessControl {
pub allowed_ips: Arc<Vec<IpPattern>>,
pub allowed_vpce: Arc<Vec<String>>,

33
proxy/src/id.rs Normal file
@@ -0,0 +1,33 @@
//! Various ID types used by proxy.
use type_safe_id::{StaticType, TypeSafeId};
/// The ID used for the client connection
pub type ClientConnId = TypeSafeId<ClientConn>;
#[derive(Copy, Clone, Default, Hash, PartialEq, Eq)]
pub struct ClientConn;
impl StaticType for ClientConn {
// This is visible by customers, so we use 'neon' here instead of 'client'.
const TYPE: &'static str = "neon_conn";
}
/// The ID used for the compute connection
pub type ComputeConnId = TypeSafeId<ComputeConn>;
#[derive(Copy, Clone, Default, Hash, PartialEq, Eq)]
pub struct ComputeConn;
impl StaticType for ComputeConn {
const TYPE: &'static str = "compute_conn";
}
/// The ID used for the request to authenticate
pub type RequestId = TypeSafeId<Request>;
#[derive(Copy, Clone, Default, Hash, PartialEq, Eq)]
pub struct Request;
impl StaticType for Request {
const TYPE: &'static str = "request";
}
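
The new `proxy/src/id.rs` gives each kind of ID its own type, so a compute connection ID cannot be passed where a client connection ID is expected. The sketch below illustrates the same pattern with only the standard library. `StaticPrefix`, `TypedId`, `from_raw`, and `raw` are hypothetical names, a `u128` stands in for the UUID, and the `prefix_hex` rendering is an assumption made for this sketch rather than the encoding the `type_safe_id` crate actually produces.

```rust
use std::fmt;
use std::marker::PhantomData;

/// Hypothetical stand-in for a trait that ties a marker type to a prefix.
trait StaticPrefix {
    const PREFIX: &'static str;
}

/// An ID whose type parameter records what it identifies.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
struct TypedId<T> {
    raw: u128,
    _tag: PhantomData<T>,
}

impl<T> TypedId<T> {
    fn from_raw(raw: u128) -> Self {
        Self { raw, _tag: PhantomData }
    }

    fn raw(self) -> u128 {
        self.raw
    }
}

impl<T: StaticPrefix> fmt::Display for TypedId<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Assumed rendering for this sketch: "<prefix>_<32 hex digits>".
        write!(f, "{}_{:032x}", T::PREFIX, self.raw)
    }
}

#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
struct ClientConn;
impl StaticPrefix for ClientConn {
    const PREFIX: &'static str = "neon_conn";
}

#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
struct Request;
impl StaticPrefix for Request {
    const PREFIX: &'static str = "request";
}

type ClientConnId = TypedId<ClientConn>;
type RequestId = TypedId<Request>;
```

Converting between ID kinds has to be explicit, for example `RequestId::from_raw(conn_id.raw())`, which mirrors how the diff derives a `RequestId` from a `ClientConnId` via `from_uuid(conn_id.uuid())` when the two are meant to share one underlying value.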

@@ -91,6 +91,7 @@ mod control_plane;
mod error;
mod ext;
mod http;
mod id;
mod intern;
mod jemalloc;
mod logging;

@@ -17,7 +17,8 @@ use crate::cancellation::{self, CancellationHandler};
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
use crate::context::RequestContext;
use crate::error::{ReportableError, UserFacingError};
use crate::metrics::{Metrics, NumClientConnectionsGuard};
use crate::id::{ClientConnId, RequestId};
use crate::metrics::{Metrics, NumClientConnectionsGuard, Protocol};
pub use crate::pglb::copy_bidirectional::ErrorSource;
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
use crate::pglb::passthrough::ProxyPassthrough;
@@ -65,12 +66,11 @@ pub async fn task_main(
{
let (socket, peer_addr) = accept_result?;
let conn_gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Tcp);
let conn_gauge = Metrics::get().proxy.client_connections.guard(Protocol::Tcp);
let conn_id = ClientConnId::new();
let session_id = RequestId::from_uuid(conn_id.uuid());
let session_id = uuid::Uuid::new_v4();
let cancellation_handler = Arc::clone(&cancellation_handler);
let cancellations = cancellations.clone();
@@ -114,7 +114,7 @@ pub async fn task_main(
}
}
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
let ctx = RequestContext::new(conn_id, session_id, conn_info, Protocol::Tcp);
let res = handle_connection(
config,
@@ -142,17 +142,22 @@ pub async fn task_main(
Ok(Some(p)) => {
ctx.set_success();
let _disconnect = ctx.log_connect();
let compute_conn_id = p.compute_conn_id;
match p.proxy_pass().await {
Ok(()) => {}
Err(ErrorSource::Client(e)) => {
warn!(
?session_id,
%conn_id,
%session_id,
%compute_conn_id,
"per-client task finished with an IO error from the client: {e:#}"
);
}
Err(ErrorSource::Compute(e)) => {
error!(
?session_id,
%conn_id,
%session_id,
%compute_conn_id,
"per-client task finished with an IO error from the compute: {e:#}"
);
}
@@ -318,6 +323,8 @@ pub(crate) async fn handle_connection<S: AsyncRead + AsyncWrite + Unpin + Send>(
};
Ok(Some(ProxyPassthrough {
compute_conn_id: node.compute_conn_id,
client,
compute: node.stream,

@@ -8,6 +8,7 @@ use utils::measured_stream::MeasuredStream;
use super::copy_bidirectional::ErrorSource;
use crate::compute::MaybeRustlsStream;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::id::ComputeConnId;
use crate::metrics::{
Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard,
NumDbConnectionsGuard,
@@ -65,6 +66,8 @@ pub(crate) async fn proxy_pass(
}
pub(crate) struct ProxyPassthrough<S> {
pub(crate) compute_conn_id: ComputeConnId,
pub(crate) client: Stream<S>,
pub(crate) compute: MaybeRustlsStream,

@@ -9,6 +9,7 @@ use crate::control_plane::errors::WakeComputeError;
use crate::control_plane::locks::ApiLocks;
use crate::control_plane::{self, NodeInfo};
use crate::error::ReportableError;
use crate::id::ComputeConnId;
use crate::metrics::{
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
};
@@ -51,6 +52,7 @@ pub(crate) trait ConnectMechanism {
pub(crate) struct TcpMechanism {
/// connect_to_compute concurrency lock
pub(crate) locks: &'static ApiLocks<Host>,
pub(crate) compute_conn_id: ComputeConnId,
}
#[async_trait]
@@ -70,7 +72,7 @@ impl ConnectMechanism for TcpMechanism {
config: &ComputeConfig,
) -> Result<ComputeConnection, Self::Error> {
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
permit.release_result(node_info.connect(ctx, config).await)
permit.release_result(node_info.connect(ctx, config, self.compute_conn_id).await)
}
}

@@ -24,6 +24,7 @@ use crate::compute::ComputeConnection;
use crate::config::ProxyConfig;
use crate::context::RequestContext;
use crate::control_plane::client::ControlPlaneClient;
use crate::id::ComputeConnId;
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
use crate::pglb::{ClientMode, ClientRequestError};
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
@@ -94,6 +95,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
let mut attempt = 0;
let connect = TcpMechanism {
locks: &config.connect_compute_locks,
// for TCP/WS, we have client_id=session_id=compute_id for now.
compute_conn_id: ComputeConnId::from_uuid(ctx.session_id().uuid()),
};
let backend = auth::Backend::ControlPlane(cplane, creds.info);

@@ -33,6 +33,7 @@ use crate::control_plane::client::ApiLockError;
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
use crate::control_plane::locks::ApiLocks;
use crate::error::{ErrorKind, ReportableError, UserFacingError};
use crate::id::{ComputeConnId, RequestId};
use crate::intern::EndpointIdInt;
use crate::proxy::connect_compute::ConnectMechanism;
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
@@ -161,7 +162,7 @@ impl PoolingBackend {
#[tracing::instrument(skip_all, fields(
pid = tracing::field::Empty,
compute_id = tracing::field::Empty,
conn_id = tracing::field::Empty,
compute_conn_id = tracing::field::Empty,
))]
pub(crate) async fn connect_to_compute(
&self,
@@ -181,14 +182,14 @@ impl PoolingBackend {
if let Some(client) = maybe_client {
return Ok(client);
}
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
let compute_conn_id = ComputeConnId::new();
tracing::Span::current().record("compute_conn_id", display(compute_conn_id));
info!(%compute_conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| keys.info);
crate::proxy::connect_compute::connect_to_compute(
ctx,
&TokioMechanism {
conn_id,
compute_conn_id,
conn_info,
pool: self.pool.clone(),
locks: &self.config.connect_compute_locks,
@@ -204,7 +205,7 @@ impl PoolingBackend {
// Wake up the destination if needed
#[tracing::instrument(skip_all, fields(
compute_id = tracing::field::Empty,
conn_id = tracing::field::Empty,
compute_conn_id = tracing::field::Empty,
))]
pub(crate) async fn connect_to_local_proxy(
&self,
@@ -216,9 +217,9 @@ impl PoolingBackend {
return Ok(client);
}
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
debug!(%conn_id, "pool: opening a new connection '{conn_info}'");
let compute_conn_id = ComputeConnId::new();
tracing::Span::current().record("compute_conn_id", display(compute_conn_id));
debug!(%compute_conn_id, "pool: opening a new connection '{conn_info}'");
let backend = self.auth_backend.as_ref().map(|()| ComputeUserInfo {
user: conn_info.user_info.user.clone(),
endpoint: EndpointId::from(format!(
@@ -230,7 +231,7 @@ impl PoolingBackend {
crate::proxy::connect_compute::connect_to_compute(
ctx,
&HyperMechanism {
conn_id,
compute_conn_id,
conn_info,
pool: self.http_conn_pool.clone(),
locks: &self.config.connect_compute_locks,
@@ -251,7 +252,7 @@ impl PoolingBackend {
/// Panics if called with a non-local_proxy backend.
#[tracing::instrument(skip_all, fields(
pid = tracing::field::Empty,
conn_id = tracing::field::Empty,
compute_conn_id = tracing::field::Empty,
))]
pub(crate) async fn connect_to_local_postgres(
&self,
@@ -303,9 +304,9 @@ impl PoolingBackend {
}
}
let conn_id = uuid::Uuid::new_v4();
tracing::Span::current().record("conn_id", display(conn_id));
info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
let compute_conn_id = ComputeConnId::new();
tracing::Span::current().record("compute_conn_id", display(compute_conn_id));
info!(%compute_conn_id, "local_pool: opening a new connection '{conn_info}'");
let (key, jwk) = create_random_jwk();
@@ -340,7 +341,7 @@ impl PoolingBackend {
client,
connection,
key,
conn_id,
compute_conn_id,
local_backend.node_info.aux.clone(),
);
@@ -378,7 +379,7 @@ fn create_random_jwk() -> (SigningKey, jose_jwk::Key) {
#[derive(Debug, thiserror::Error)]
pub(crate) enum HttpConnError {
#[error("pooled connection closed at inconsistent state")]
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<RequestId>),
#[error("could not connect to postgres in compute")]
PostgresConnectionError(#[from] postgres_client::Error),
#[error("could not connect to local-proxy in compute")]
@@ -509,7 +510,7 @@ impl ShouldRetryWakeCompute for LocalProxyConnError {
struct TokioMechanism {
pool: Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
compute_conn_id: ComputeConnId,
keys: ComputeCredentialKeys,
/// connect_to_compute concurrency lock
@@ -561,7 +562,7 @@ impl ConnectMechanism for TokioMechanism {
self.conn_info.clone(),
client,
connection,
self.conn_id,
self.compute_conn_id,
node_info.aux.clone(),
))
}
@@ -570,7 +571,7 @@ impl ConnectMechanism for TokioMechanism {
struct HyperMechanism {
pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
conn_info: ConnInfo,
conn_id: uuid::Uuid,
compute_conn_id: ComputeConnId,
/// connect_to_compute concurrency lock
locks: &'static ApiLocks<Host>,
@@ -620,7 +621,7 @@ impl ConnectMechanism for HyperMechanism {
&self.conn_info,
client,
connection,
self.conn_id,
self.compute_conn_id,
node_info.aux.clone(),
))
}

@@ -10,7 +10,8 @@ use rand::{Rng, thread_rng};
use rustc_hash::FxHasher;
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use uuid::Uuid;
use crate::id::ClientConnId;
type Hasher = BuildHasherDefault<FxHasher>;
@@ -21,7 +22,7 @@ pub struct CancelSet {
}
pub(crate) struct CancelShard {
tokens: IndexMap<uuid::Uuid, (Instant, CancellationToken), Hasher>,
tokens: IndexMap<ClientConnId, (Instant, CancellationToken), Hasher>,
}
impl CancelSet {
@@ -53,7 +54,7 @@ impl CancelSet {
.and_then(|len| self.shards[rng % len].lock().take(rng / len))
}
pub(crate) fn insert(&self, id: uuid::Uuid, token: CancellationToken) -> CancelGuard<'_> {
pub(crate) fn insert(&self, id: ClientConnId, token: CancellationToken) -> CancelGuard<'_> {
let shard = NonZeroUsize::new(self.shards.len()).map(|len| {
let hash = self.hasher.hash_one(id) as usize;
let shard = &self.shards[hash % len];
@@ -77,18 +78,18 @@ impl CancelShard {
})
}
fn remove(&mut self, id: uuid::Uuid) {
fn remove(&mut self, id: ClientConnId) {
self.tokens.swap_remove(&id);
}
fn insert(&mut self, id: uuid::Uuid, token: CancellationToken) {
fn insert(&mut self, id: ClientConnId, token: CancellationToken) {
self.tokens.insert(id, (Instant::now(), token));
}
}
pub(crate) struct CancelGuard<'a> {
shard: Option<&'a Mutex<CancelShard>>,
id: Uuid,
id: ClientConnId,
}
impl Drop for CancelGuard<'_> {
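
The cancel set above picks a shard by hashing the connection ID and hands back a guard that removes the entry when dropped. A rough, standard-library-only sketch of that shape follows. `ShardedSet`, `Guard`, and the `&'static str` payload are hypothetical; the real code keys on `ClientConnId`, stores a timestamp with a `CancellationToken`, and tolerates an empty shard list, none of which is reproduced here.

```rust
use std::collections::hash_map::RandomState;
use std::collections::HashMap;
use std::hash::BuildHasher;
use std::sync::Mutex;

type Shard = Mutex<HashMap<u128, &'static str>>;

/// Hypothetical sharded map: each ID lives in exactly one shard.
struct ShardedSet {
    hasher: RandomState,
    shards: Vec<Shard>,
}

/// Removes its entry from the owning shard when dropped.
struct Guard<'a> {
    shard: &'a Shard,
    id: u128,
}

impl ShardedSet {
    fn new(shard_count: usize) -> Self {
        // Assumption for this sketch: at least one shard is required.
        assert!(shard_count > 0);
        Self {
            hasher: RandomState::new(),
            shards: (0..shard_count).map(|_| Mutex::new(HashMap::new())).collect(),
        }
    }

    fn insert(&self, id: u128, value: &'static str) -> Guard<'_> {
        // The hash of the ID decides which shard (and which lock) is used.
        let index = self.hasher.hash_one(id) as usize % self.shards.len();
        let shard = &self.shards[index];
        shard.lock().unwrap().insert(id, value);
        Guard { shard, id }
    }

    fn len(&self) -> usize {
        self.shards.iter().map(|s| s.lock().unwrap().len()).sum()
    }
}

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        self.shard.lock().unwrap().remove(&self.id);
    }
}
```

Because the guard holds a reference to its shard, cleanup only takes that one shard's lock, which is the main reason to shard in the first place.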

@@ -26,6 +26,7 @@ use super::conn_pool_lib::{
use crate::config::ComputeConfig;
use crate::context::RequestContext;
use crate::control_plane::messages::MetricsAuxInfo;
use crate::id::{ComputeConnId, RequestId};
use crate::metrics::Metrics;
type TlsStream = <ComputeConfig as MakeTlsConnect<TcpStream>>::Stream;
@@ -62,14 +63,14 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
conn_info: ConnInfo,
client: C,
mut connection: postgres_client::Connection<TcpStream, TlsStream>,
conn_id: uuid::Uuid,
compute_conn_id: ComputeConnId,
aux: MetricsAuxInfo,
) -> Client<C> {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let mut session_id = ctx.session_id();
let (tx, mut rx) = tokio::sync::watch::channel(session_id);
let span = info_span!(parent: None, "connection", %conn_id);
let span = info_span!(parent: None, "connection", %compute_conn_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
@@ -117,7 +118,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
if let Some(pool) = pool.clone().upgrade() {
// remove client from pool - should close the connection if it's idle.
// does nothing if the client is currently checked-out and in-use
if pool.write().remove_client(db_user.clone(), conn_id) {
if pool.write().remove_client(db_user.clone(), compute_conn_id) {
info!("idle connection removed");
}
}
@@ -149,7 +150,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
// remove from connection pool
if let Some(pool) = pool.clone().upgrade()
&& pool.write().remove_client(db_user.clone(), conn_id) {
&& pool.write().remove_client(db_user.clone(), compute_conn_id) {
info!("closed connection removed");
}
@@ -161,7 +162,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
let inner = ClientInnerCommon {
inner: client,
aux,
conn_id,
compute_conn_id,
data: ClientDataEnum::Remote(ClientDataRemote {
session: tx,
cancel,
@@ -173,12 +174,12 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
#[derive(Clone)]
pub(crate) struct ClientDataRemote {
session: tokio::sync::watch::Sender<uuid::Uuid>,
session: tokio::sync::watch::Sender<RequestId>,
cancel: CancellationToken,
}
impl ClientDataRemote {
pub fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
pub fn session(&mut self) -> &mut tokio::sync::watch::Sender<RequestId> {
&mut self.session
}
@@ -192,6 +193,7 @@ mod tests {
use std::sync::atomic::AtomicBool;
use super::*;
use crate::id::ComputeConnId;
use crate::proxy::NeonOptions;
use crate::serverless::cancel_set::CancelSet;
use crate::types::{BranchId, EndpointId, ProjectId};
@@ -225,9 +227,9 @@ mod tests {
compute_id: "compute".into(),
cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
},
conn_id: uuid::Uuid::new_v4(),
compute_conn_id: ComputeConnId::new(),
data: ClientDataEnum::Remote(ClientDataRemote {
session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
session: tokio::sync::watch::Sender::new(RequestId::new()),
cancel: CancellationToken::new(),
}),
}

@@ -19,6 +19,7 @@ use super::local_conn_pool::ClientDataLocal;
use crate::auth::backend::ComputeUserInfo;
use crate::context::RequestContext;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::id::ComputeConnId;
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::protocol2::ConnectionInfoExtra;
use crate::types::{DbName, EndpointCacheKey, RoleName};
@@ -58,7 +59,7 @@ pub(crate) enum ClientDataEnum {
pub(crate) struct ClientInnerCommon<C: ClientInnerExt> {
pub(crate) inner: C,
pub(crate) aux: MetricsAuxInfo,
pub(crate) conn_id: uuid::Uuid,
pub(crate) compute_conn_id: ComputeConnId,
pub(crate) data: ClientDataEnum, // custom client data like session, key, jti
}
@@ -77,8 +78,8 @@ impl<C: ClientInnerExt> Drop for ClientInnerCommon<C> {
}
impl<C: ClientInnerExt> ClientInnerCommon<C> {
pub(crate) fn get_conn_id(&self) -> uuid::Uuid {
self.conn_id
pub(crate) fn get_conn_id(&self) -> ComputeConnId {
self.compute_conn_id
}
pub(crate) fn get_data(&mut self) -> &mut ClientDataEnum {
@@ -144,7 +145,7 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
pub(crate) fn remove_client(
&mut self,
db_user: (DbName, RoleName),
conn_id: uuid::Uuid,
conn_id: ComputeConnId,
) -> bool {
let Self {
pools,
@@ -189,7 +190,7 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
}
pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInnerCommon<C>) {
let conn_id = client.get_conn_id();
let compute_conn_id = client.get_conn_id();
let (max_conn, conn_count, pool_name) = {
let pool = pool.read();
(
@@ -201,12 +202,12 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
};
if client.inner.is_closed() {
info!(%conn_id, "{}: throwing away connection '{conn_info}' because connection is closed", pool_name);
info!(%compute_conn_id, "{}: throwing away connection '{conn_info}' because connection is closed", pool_name);
return;
}
if conn_count >= max_conn {
info!(%conn_id, "{}: throwing away connection '{conn_info}' because pool is full", pool_name);
info!(%compute_conn_id, "{}: throwing away connection '{conn_info}' because pool is full", pool_name);
return;
}
@@ -241,9 +242,9 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
// do logging outside of the mutex
if returned {
debug!(%conn_id, "{pool_name}: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
debug!(%compute_conn_id, "{pool_name}: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
} else {
info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
info!(%compute_conn_id, "{pool_name}: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
}
}
}

@@ -18,6 +18,7 @@ use super::conn_pool_lib::{
};
use crate::context::RequestContext;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::id::ComputeConnId;
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
use crate::protocol2::ConnectionInfoExtra;
use crate::types::EndpointCacheKey;
@@ -65,7 +66,7 @@ impl<C: ClientInnerExt + Clone> HttpConnPool<C> {
}
}
fn remove_conn(&mut self, conn_id: uuid::Uuid) -> bool {
fn remove_conn(&mut self, conn_id: ComputeConnId) -> bool {
let Self {
conns,
global_connections_count,
@@ -73,7 +74,7 @@ impl<C: ClientInnerExt + Clone> HttpConnPool<C> {
} = self;
let old_len = conns.len();
conns.retain(|entry| entry.conn.conn_id != conn_id);
conns.retain(|entry| entry.conn.compute_conn_id != conn_id);
let new_len = conns.len();
let removed = old_len - new_len;
if removed > 0 {
@@ -135,7 +136,10 @@ impl<C: ClientInnerExt + Clone> GlobalConnPool<C, HttpConnPool<C>> {
return result;
};
tracing::Span::current().record("conn_id", tracing::field::display(client.conn.conn_id));
tracing::Span::current().record(
"conn_id",
tracing::field::display(client.conn.compute_conn_id),
);
debug!(
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
"pool: reusing connection '{conn_info}'"
@@ -194,13 +198,13 @@ pub(crate) fn poll_http2_client(
conn_info: &ConnInfo,
client: Send,
connection: Connect,
conn_id: uuid::Uuid,
compute_conn_id: ComputeConnId,
aux: MetricsAuxInfo,
) -> Client<Send> {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let session_id = ctx.session_id();
let span = info_span!(parent: None, "connection", %conn_id);
let span = info_span!(parent: None, "connection", %compute_conn_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
@@ -212,7 +216,7 @@ pub(crate) fn poll_http2_client(
let client = ClientInnerCommon {
inner: client.clone(),
aux: aux.clone(),
conn_id,
compute_conn_id,
data: ClientDataEnum::Http(ClientDataHttp()),
};
pool.write().conns.push_back(ConnPoolEntry {
@@ -241,7 +245,7 @@ pub(crate) fn poll_http2_client(
// remove from connection pool
if let Some(pool) = pool.clone().upgrade()
&& pool.write().remove_conn(conn_id)
&& pool.write().remove_conn(compute_conn_id)
{
info!("closed connection removed");
}
@@ -252,7 +256,7 @@ pub(crate) fn poll_http2_client(
let client = ClientInnerCommon {
inner: client,
aux,
conn_id,
compute_conn_id,
data: ClientDataEnum::Http(ClientDataHttp()),
};

@@ -10,7 +10,6 @@ use http_body_util::{BodyExt, Full};
use http_utils::error::ApiError;
use serde::Serialize;
use url::Url;
use uuid::Uuid;
use super::conn_pool::{AuthData, ConnInfoWithAuth};
use super::conn_pool_lib::ConnInfo;
@@ -18,6 +17,7 @@ use super::error::{ConnInfoError, Credentials};
use crate::auth::backend::ComputeUserInfo;
use crate::config::AuthenticationConfig;
use crate::context::RequestContext;
use crate::id::RequestId;
use crate::metrics::{Metrics, SniGroup, SniKind};
use crate::pqproto::StartupMessageParams;
use crate::proxy::NeonOptions;
@@ -34,9 +34,8 @@ pub(super) static TXN_ISOLATION_LEVEL: HeaderName =
pub(super) static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
pub(super) static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
let mut uuid = [0; uuid::fmt::Hyphenated::LENGTH];
HeaderValue::from_str(id.as_hyphenated().encode_lower(&mut uuid[..]))
pub(crate) fn uuid_to_header_value(id: RequestId) -> HeaderValue {
HeaderValue::from_maybe_shared(Bytes::from(id.to_string().into_bytes()))
.expect("uuid hyphenated format should be all valid header characters")
}

@@ -40,6 +40,7 @@ use super::conn_pool_lib::{
use super::sql_over_http::SqlOverHttpError;
use crate::context::RequestContext;
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
use crate::id::{ComputeConnId, RequestId};
use crate::metrics::Metrics;
pub(crate) const EXT_NAME: &str = "pg_session_jwt";
@@ -48,14 +49,14 @@ pub(crate) const EXT_SCHEMA: &str = "auth";
#[derive(Clone)]
pub(crate) struct ClientDataLocal {
session: tokio::sync::watch::Sender<uuid::Uuid>,
session: tokio::sync::watch::Sender<RequestId>,
cancel: CancellationToken,
key: SigningKey,
jti: u64,
}
impl ClientDataLocal {
pub fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
pub fn session(&mut self) -> &mut tokio::sync::watch::Sender<RequestId> {
&mut self.session
}
@@ -167,14 +168,14 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
client: C,
mut connection: postgres_client::Connection<TcpStream, NoTlsStream>,
key: SigningKey,
conn_id: uuid::Uuid,
compute_conn_id: ComputeConnId,
aux: MetricsAuxInfo,
) -> Client<C> {
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
let mut session_id = ctx.session_id();
let (tx, mut rx) = tokio::sync::watch::channel(session_id);
let span = info_span!(parent: None, "connection", %conn_id);
let span = info_span!(parent: None, "connection", %compute_conn_id);
let cold_start_info = ctx.cold_start_info();
span.in_scope(|| {
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
@@ -218,7 +219,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
if let Some(pool) = pool.clone().upgrade() {
// remove client from pool - should close the connection if it's idle.
// does nothing if the client is currently checked-out and in-use
if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
if pool.global_pool.write().remove_client(db_user.clone(), compute_conn_id) {
info!("idle connection removed");
}
}
@@ -250,7 +251,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
// remove from connection pool
if let Some(pool) = pool.clone().upgrade()
&& pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
&& pool.global_pool.write().remove_client(db_user.clone(), compute_conn_id) {
info!("closed connection removed");
}
@@ -263,7 +264,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
let inner = ClientInnerCommon {
inner: client,
aux,
conn_id,
compute_conn_id,
data: ClientDataEnum::Local(ClientDataLocal {
session: tx,
cancel,

@@ -16,16 +16,16 @@ mod websocket;
use std::net::{IpAddr, SocketAddr};
use std::pin::{Pin, pin};
use std::str::FromStr;
use std::sync::Arc;
use anyhow::Context;
use arc_swap::ArcSwapOption;
use async_trait::async_trait;
use atomic_take::AtomicTake;
use bytes::Bytes;
pub use conn_pool_lib::GlobalConnPoolOptions;
use futures::TryFutureExt;
use futures::future::{Either, select};
use futures::{FutureExt, TryFutureExt};
use http::{Method, Response, StatusCode};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty};
@@ -48,7 +48,8 @@ use crate::cancellation::CancellationHandler;
use crate::config::{ProxyConfig, ProxyProtocolV2};
use crate::context::RequestContext;
use crate::ext::TaskExt;
use crate::metrics::Metrics;
use crate::id::{ClientConnId, RequestId};
use crate::metrics::{Metrics, Protocol};
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
@@ -131,13 +132,12 @@ pub async fn task_main(
tracing::error!("could not set nodelay: {e}");
continue;
}
let conn_id = uuid::Uuid::new_v4();
let http_conn_span = tracing::info_span!("http_conn", ?conn_id);
let conn_id = ClientConnId::new();
let n_connections = Metrics::get()
.proxy
.client_connections
.sample(crate::metrics::Protocol::Http);
.sample(Protocol::Http);
tracing::trace!(?n_connections, threshold = ?config.http_config.client_conn_threshold, "check");
if n_connections > config.http_config.client_conn_threshold {
tracing::trace!("attempting to cancel a random connection");
@@ -154,46 +154,41 @@ pub async fn task_main(
let cancellation_handler = cancellation_handler.clone();
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
let cancellations = cancellations.clone();
connections.spawn(
async move {
let conn_token2 = conn_token.clone();
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
connections.spawn(async move {
let conn_token2 = conn_token.clone();
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
let session_id = uuid::Uuid::new_v4();
let _gauge = Metrics::get()
.proxy
.client_connections
.guard(Protocol::Http);
let _gauge = Metrics::get()
.proxy
.client_connections
.guard(crate::metrics::Protocol::Http);
let startup_result = Box::pin(connection_startup(
config,
tls_acceptor,
conn_id,
conn,
peer_addr,
))
.await;
let Some((conn, conn_info)) = startup_result else {
return;
};
let startup_result = Box::pin(connection_startup(
config,
tls_acceptor,
session_id,
conn,
peer_addr,
))
.await;
let Some((conn, conn_info)) = startup_result else {
return;
};
Box::pin(connection_handler(
config,
backend,
connections2,
cancellations,
cancellation_handler,
endpoint_rate_limiter,
conn_token,
conn,
conn_info,
session_id,
))
.await;
}
.instrument(http_conn_span),
);
Box::pin(connection_handler(
config,
backend,
connections2,
cancellations,
cancellation_handler,
endpoint_rate_limiter,
conn_token,
conn,
conn_info,
conn_id,
))
.await;
});
}
connections.wait().await;
@@ -230,7 +225,7 @@ impl MaybeTlsAcceptor for &'static ArcSwapOption<crate::config::TlsConfig> {
async fn connection_startup(
config: &ProxyConfig,
tls_acceptor: Arc<dyn MaybeTlsAcceptor>,
session_id: uuid::Uuid,
conn_id: ClientConnId,
conn: TcpStream,
peer_addr: SocketAddr,
) -> Option<(AsyncRW, ConnectionInfo)> {
@@ -265,12 +260,12 @@ async fn connection_startup(
IpAddr::V4(ip) => ip.is_private(),
IpAddr::V6(_) => false,
};
info!(?session_id, %conn_info, "accepted new TCP connection");
info!(%conn_id, %conn_info, "accepted new TCP connection");
// try upgrade to TLS, but with a timeout.
let conn = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await {
Ok(Ok(conn)) => {
info!(?session_id, %conn_info, "accepted new TLS connection");
info!(%conn_id, %conn_info, "accepted new TLS connection");
conn
}
// The handshake failed
@@ -278,7 +273,7 @@ async fn connection_startup(
if !has_private_peer_addr {
Metrics::get().proxy.tls_handshake_failures.inc();
}
warn!(?session_id, %conn_info, "failed to accept TLS connection: {e:?}");
warn!(%conn_id, %conn_info, "failed to accept TLS connection: {e:?}");
return None;
}
// The handshake timed out
@@ -286,7 +281,7 @@ async fn connection_startup(
if !has_private_peer_addr {
Metrics::get().proxy.tls_handshake_failures.inc();
}
warn!(?session_id, %conn_info, "failed to accept TLS connection: {e:?}");
warn!(%conn_id, %conn_info, "failed to accept TLS connection: {e:?}");
return None;
}
};
@@ -309,10 +304,8 @@ async fn connection_handler(
cancellation_token: CancellationToken,
conn: AsyncRW,
conn_info: ConnectionInfo,
session_id: uuid::Uuid,
conn_id: ClientConnId,
) {
let session_id = AtomicTake::new(session_id);
// Cancel all current inflight HTTP requests if the HTTP connection is closed.
let http_cancellation_token = CancellationToken::new();
let _cancel_connection = http_cancellation_token.clone().drop_guard();
@@ -322,20 +315,6 @@ async fn connection_handler(
let conn = server.serve_connection_with_upgrades(
hyper_util::rt::TokioIo::new(conn),
hyper::service::service_fn(move |req: hyper::Request<Incoming>| {
// First HTTP request shares the same session ID
let mut session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4);
if matches!(backend.auth_backend, crate::auth::Backend::Local(_)) {
// take session_id from request, if given.
if let Some(id) = req
.headers()
.get(&NEON_REQUEST_ID)
.and_then(|id| uuid::Uuid::try_parse_ascii(id.as_bytes()).ok())
{
session_id = id;
}
}
// Cancel the current inflight HTTP request if the request stream is closed.
// This is slightly different to `_cancel_connection` in that
// h2 can cancel individual requests with a `RST_STREAM`.
@@ -352,7 +331,7 @@ async fn connection_handler(
backend.clone(),
connections.clone(),
cancellation_handler.clone(),
session_id,
conn_id,
conn_info2.clone(),
http_request_token,
endpoint_rate_limiter.clone(),
@@ -362,15 +341,8 @@ async fn connection_handler(
.map_ok_or_else(api_error_into_response, |r| r),
);
async move {
let mut res = handler.await;
let res = handler.await;
cancel_request.disarm();
// add the session ID to the response
if let Ok(resp) = &mut res {
resp.headers_mut()
.append(&NEON_REQUEST_ID, uuid_to_header_value(session_id));
}
res
}
}),
@@ -392,6 +364,44 @@ async fn connection_handler(
}
}
fn get_request_id(backend: &PoolingBackend, req: &hyper::Request<Incoming>) -> RequestId {
if matches!(backend.auth_backend, crate::auth::Backend::Local(_)) {
// take session_id from request, if given.
if let Some(id) = req
.headers()
.get(&NEON_REQUEST_ID)
.and_then(|id| uuid::Uuid::try_parse_ascii(id.as_bytes()).ok())
{
return RequestId::from_uuid(id);
}
if let Some(id) = req
.headers()
.get(&NEON_REQUEST_ID)
.and_then(|id| id.to_str().ok())
.and_then(|id| RequestId::from_str(id).ok())
{
return id;
}
}
RequestId::new()
}
fn set_request_id<T, E>(
mut res: Result<hyper::Response<T>, E>,
session_id: RequestId,
) -> Result<hyper::Response<T>, E> {
// add the session ID to the response
if let Ok(resp) = &mut res {
resp.headers_mut()
.append(&NEON_REQUEST_ID, uuid_to_header_value(session_id));
}
res
}
#[allow(clippy::too_many_arguments)]
async fn request_handler(
mut request: hyper::Request<Incoming>,
@@ -399,7 +409,7 @@ async fn request_handler(
backend: Arc<PoolingBackend>,
ws_connections: TaskTracker,
cancellation_handler: Arc<CancellationHandler>,
session_id: uuid::Uuid,
conn_id: ClientConnId,
conn_info: ConnectionInfo,
// used to cancel in-flight HTTP requests. not used to cancel websockets
http_cancellation_token: CancellationToken,
@@ -417,7 +427,8 @@ async fn request_handler(
if config.http_config.accept_websockets
&& framed_websockets::upgrade::is_upgrade_request(&request)
{
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Ws);
let session_id = RequestId::from_uuid(conn_id.uuid());
let ctx = RequestContext::new(conn_id, session_id, conn_info, Protocol::Ws);
ctx.set_user_agent(
request
@@ -457,7 +468,8 @@ async fn request_handler(
// Return the response so the spawned future can continue.
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http);
let session_id = get_request_id(&backend, &request);
let ctx = RequestContext::new(conn_id, session_id, conn_info, Protocol::Http);
let span = ctx.span();
let testodrome_id = request
@@ -473,6 +485,7 @@ async fn request_handler(
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
.instrument(span)
.map(|res| set_request_id(res, session_id))
.await
} else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
Response::builder()

View File

@@ -21,7 +21,8 @@ pub fn check_permission(claims: &Claims, tenant_id: Option<TenantId>) -> Result<
| Scope::GenerationsApi
| Scope::Infra
| Scope::Scrubber
| Scope::ControllerPeer,
| Scope::ControllerPeer
| Scope::TenantEndpoint,
_,
) => Err(AuthError(
format!(

View File

@@ -52,6 +52,7 @@ tokio-rustls.workspace = true
tokio-util.workspace = true
tokio.workspace = true
tracing.workspace = true
uuid.workspace = true
measured.workspace = true
rustls.workspace = true
scopeguard.workspace = true
@@ -63,6 +64,7 @@ tokio-postgres-rustls.workspace = true
diesel = { version = "2.2.6", features = [
"serde_json",
"chrono",
"uuid",
] }
diesel-async = { version = "0.5.2", features = ["postgres", "bb8", "async-connection-wrapper"] }
diesel_migrations = { version = "2.2.0" }

View File

@@ -0,0 +1,2 @@
DROP TABLE hadron_safekeepers;
DROP TABLE hadron_timeline_safekeepers;

View File

@@ -0,0 +1,17 @@
-- hadron_safekeepers keeps track of all Safe Keeper nodes that exist in the system.
-- Upon startup, each Safe Keeper reaches out to the hadron cluster coordinator to register its node ID and listen addresses.
CREATE TABLE hadron_safekeepers (
sk_node_id BIGINT PRIMARY KEY NOT NULL,
listen_http_addr VARCHAR NOT NULL,
listen_http_port INTEGER NOT NULL,
listen_pg_addr VARCHAR NOT NULL,
listen_pg_port INTEGER NOT NULL
);
CREATE TABLE hadron_timeline_safekeepers (
timeline_id VARCHAR NOT NULL,
sk_node_id BIGINT NOT NULL,
legacy_endpoint_id UUID DEFAULT NULL,
PRIMARY KEY(timeline_id, sk_node_id)
);

View File

@@ -1,4 +1,5 @@
use utils::auth::{AuthError, Claims, Scope};
use uuid::Uuid;
pub fn check_permission(claims: &Claims, required_scope: Scope) -> Result<(), AuthError> {
if claims.scope != required_scope {
@@ -7,3 +8,14 @@ pub fn check_permission(claims: &Claims, required_scope: Scope) -> Result<(), Au
Ok(())
}
#[allow(dead_code)]
pub fn check_endpoint_permission(claims: &Claims, endpoint_id: Uuid) -> Result<(), AuthError> {
if claims.scope != Scope::TenantEndpoint {
return Err(AuthError("Scope mismatch. Permission denied".into()));
}
if claims.endpoint_id != Some(endpoint_id) {
return Err(AuthError("Endpoint id mismatch. Permission denied".into()));
}
Ok(())
}
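
The two-step check in `check_endpoint_permission` (scope first, then endpoint ID) can be sketched in isolation as below. This is a minimal, std-only illustration: `Claims`, `Scope`, and `AuthError` here are simplified, hypothetical stand-ins for the real types in `utils::auth`, and a `u128` replaces `uuid::Uuid`.

```rust
// Hypothetical, simplified stand-ins for `utils::auth::{Claims, Scope, AuthError}`.
#[derive(Debug, PartialEq, Clone, Copy)]
enum Scope {
    Tenant,
    TenantEndpoint,
}

#[derive(Debug, PartialEq)]
struct AuthError(String);

struct Claims {
    scope: Scope,
    // The real field holds an `Option<Uuid>`; a `u128` keeps this sketch std-only.
    endpoint_id: Option<u128>,
}

// Mirrors the order of checks in the diff: reject on scope mismatch first,
// then reject if the token is bound to a different (or no) endpoint.
fn check_endpoint_permission(claims: &Claims, endpoint_id: u128) -> Result<(), AuthError> {
    if claims.scope != Scope::TenantEndpoint {
        return Err(AuthError("Scope mismatch. Permission denied".into()));
    }
    if claims.endpoint_id != Some(endpoint_id) {
        return Err(AuthError("Endpoint id mismatch. Permission denied".into()));
    }
    Ok(())
}
```

Note that a token with the right scope but no `endpoint_id` claim is rejected as well, because `None != Some(endpoint_id)`.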

View File

@@ -810,6 +810,7 @@ impl ComputeHook {
let send_locked = tokio::select! {
guard = send_lock.lock_owned() => {guard},
_ = cancel.cancelled() => {
tracing::info!("Notification cancelled while waiting for lock");
return Err(NotifyError::ShuttingDown)
}
};
@@ -851,11 +852,32 @@ impl ComputeHook {
let notify_url = compute_hook_url.as_ref().unwrap();
self.do_notify(notify_url, &request, cancel).await
} else {
self.do_notify_local::<M>(&request).await.map_err(|e| {
match self.do_notify_local::<M>(&request).await.map_err(|e| {
// This path is for testing only, so munge the error into our prod-style error type.
tracing::error!("neon_local notification hook failed: {e}");
NotifyError::Fatal(StatusCode::INTERNAL_SERVER_ERROR)
})
if e.to_string().contains("refresh-configuration-pending") {
// If the error message mentions "refresh-configuration-pending", it means the compute node
// rejected our notification request because it is already trying to reconfigure itself. We
// can proceed with the rest of the reconciliation process as the compute node has already
// discovered the need to reconfigure and will eventually update its configuration once
// we update the pageserver mappings. In fact, it is important that we continue with
// reconciliation to make sure we update the pageserver mappings to unblock the compute node.
tracing::info!("neon_local notification hook failed: {e}");
tracing::info!("Notification failed likely due to compute node self-reconfiguration, will retry.");
Ok(())
} else {
tracing::error!("neon_local notification hook failed: {e}");
Err(NotifyError::Fatal(StatusCode::INTERNAL_SERVER_ERROR))
}
}) {
// Compute node accepted the notification request. Ok to proceed.
Ok(_) => Ok(()),
// Compute node rejected our request but it is already self-reconfiguring. Ok to proceed.
Err(Ok(_)) => Ok(()),
// Fail the reconciliation attempt in all other cases. Recall that this whole code path involving
// neon_local is for testing only. In production we always retry failed reconciliations so we
// don't have any dead ends here.
Err(Err(e)) => Err(e),
}
};
match result {
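
The `Ok(_)` / `Err(Ok(_))` / `Err(Err(e))` arms above arise because the `map_err` closure itself returns a `Result`, so the error type of the mapped value becomes a nested `Result`. A minimal, std-only sketch of that pattern follows; `NotifyError` and the `String` error are hypothetical simplifications, not the real types from the storage controller.

```rust
#[derive(Debug, PartialEq)]
enum NotifyError {
    // The real variant carries an HTTP `StatusCode`; a `u16` keeps this std-only.
    Fatal(u16),
}

// Hypothetical stand-in for the result of `do_notify_local`: an error is a plain string.
fn classify_local_notify(outcome: Result<(), String>) -> Result<(), NotifyError> {
    // `map_err` with a closure returning `Result` yields `Result<(), Result<(), NotifyError>>`.
    let mapped: Result<(), Result<(), NotifyError>> = outcome.map_err(|e| {
        if e.contains("refresh-configuration-pending") {
            // The compute is already reconfiguring itself: treat as success.
            Ok(())
        } else {
            Err(NotifyError::Fatal(500))
        }
    });
    match mapped {
        Ok(()) => Ok(()),      // notification accepted
        Err(Ok(())) => Ok(()), // rejected, but tolerated
        Err(Err(e)) => Err(e), // genuine failure
    }
}

// The same behaviour expressed with `Result::or_else`, which avoids the nested type.
fn classify_with_or_else(outcome: Result<(), String>) -> Result<(), NotifyError> {
    outcome.or_else(|e| {
        if e.contains("refresh-configuration-pending") {
            Ok(())
        } else {
            Err(NotifyError::Fatal(500))
        }
    })
}
```

Both functions return the same value for every input; the `or_else` form is shown only as a possible simplification, not as what the change does.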

View File

@@ -0,0 +1,44 @@
use std::collections::BTreeMap;
use rand::Rng;
use utils::shard::TenantShardId;
static CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!@#$%^&*()";
/// Generate a random string of `length` that can be used as a password. The generated string
/// contains alphanumeric characters and special characters (!@#$%^&*())
pub fn generate_random_password(length: usize) -> String {
let mut rng = rand::thread_rng();
(0..length)
.map(|_| {
let idx = rng.gen_range(0..CHARSET.len());
CHARSET[idx] as char
})
.collect()
}
pub(crate) struct TenantShardSizeMap {
#[expect(dead_code)]
pub map: BTreeMap<TenantShardId, u64>,
}
impl TenantShardSizeMap {
pub fn new(map: BTreeMap<TenantShardId, u64>) -> Self {
Self { map }
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_generate_random_password() {
let pwd1 = generate_random_password(10);
assert_eq!(pwd1.len(), 10);
let pwd2 = generate_random_password(10);
assert_ne!(pwd1, pwd2);
assert!(pwd1.chars().all(|c| CHARSET.contains(&(c as u8))));
assert!(pwd2.chars().all(|c| CHARSET.contains(&(c as u8))));
}
}

View File

@@ -48,7 +48,10 @@ use crate::metrics::{
};
use crate::persistence::SafekeeperUpsert;
use crate::reconciler::ReconcileError;
use crate::service::{LeadershipStatus, RECONCILE_TIMEOUT, STARTUP_RECONCILE_TIMEOUT, Service};
use crate::service::{
LeadershipStatus, RECONCILE_TIMEOUT, STARTUP_RECONCILE_TIMEOUT, Service,
TenantMutationLocations,
};
/// State available to HTTP request handlers
pub struct HttpState {
@@ -734,77 +737,104 @@ async fn handle_tenant_timeline_passthrough(
path
);
// Find the node that holds shard zero
let (node, tenant_shard_id, consistent) = if tenant_or_shard_id.is_unsharded() {
service
let tenant_shard_id = if tenant_or_shard_id.is_unsharded() {
// If the request contains only tenant ID, find the node that holds shard zero
let (_, shard_id) = service
.tenant_shard0_node(tenant_or_shard_id.tenant_id)
.await?
.await?;
shard_id
} else {
let (node, consistent) = service.tenant_shard_node(tenant_or_shard_id).await?;
(node, tenant_or_shard_id, consistent)
tenant_or_shard_id
};
// Callers will always pass an unsharded tenant ID. Before proxying, we must
// rewrite this to a shard-aware shard zero ID.
let path = format!("{path}");
let tenant_str = tenant_or_shard_id.tenant_id.to_string();
let tenant_shard_str = format!("{tenant_shard_id}");
let path = path.replace(&tenant_str, &tenant_shard_str);
let service_inner = service.clone();
let latency = &METRICS_REGISTRY
.metrics_group
.storage_controller_passthrough_request_latency;
service.tenant_shard_remote_mutation(tenant_shard_id, |locations| async move {
let TenantMutationLocations(locations) = locations;
if locations.is_empty() {
return Err(ApiError::NotFound(anyhow::anyhow!("Tenant {} not found", tenant_or_shard_id.tenant_id).into()));
}
let path_label = path_without_ids(&path)
.split('/')
.filter(|token| !token.is_empty())
.collect::<Vec<_>>()
.join("_");
let labels = PageserverRequestLabelGroup {
pageserver_id: &node.get_id().to_string(),
path: &path_label,
method: crate::metrics::Method::Get,
};
let (tenant_or_shard_id, locations) = locations.into_iter().next().unwrap();
let node = locations.latest.node;
let _timer = latency.start_timer(labels.clone());
// Callers will always pass an unsharded tenant ID. Before proxying, we must
// rewrite this to a shard-aware shard zero ID.
let path = format!("{path}");
let tenant_str = tenant_or_shard_id.tenant_id.to_string();
let tenant_shard_str = format!("{tenant_shard_id}");
let path = path.replace(&tenant_str, &tenant_shard_str);
let client = mgmt_api::Client::new(
service.get_http_client().clone(),
node.base_url(),
service.get_config().pageserver_jwt_token.as_deref(),
);
let resp = client.op_raw(method, path).await.map_err(|e|
// We return 503 here because if we can't successfully send a request to the pageserver,
// either we aren't available or the pageserver is unavailable.
ApiError::ResourceUnavailable(format!("Error sending pageserver API request to {node}: {e}").into()))?;
if !resp.status().is_success() {
let error_counter = &METRICS_REGISTRY
let latency = &METRICS_REGISTRY
.metrics_group
.storage_controller_passthrough_request_error;
error_counter.inc(labels);
}
.storage_controller_passthrough_request_latency;
// Transform 404 into 503 if we raced with a migration
if resp.status() == reqwest::StatusCode::NOT_FOUND && !consistent {
// Rather than retry here, send the client a 503 to prompt a retry: this matches
// the pageserver's use of 503, and all clients calling this API should retry on 503.
return Err(ApiError::ResourceUnavailable(
format!("Pageserver {node} returned 404 due to ongoing migration, retry later").into(),
));
}
let path_label = path_without_ids(&path)
.split('/')
.filter(|token| !token.is_empty())
.collect::<Vec<_>>()
.join("_");
let labels = PageserverRequestLabelGroup {
pageserver_id: &node.get_id().to_string(),
path: &path_label,
method: crate::metrics::Method::Get,
};
// We have a reqwest::Response, would like a http::Response
let mut builder = hyper::Response::builder().status(map_reqwest_hyper_status(resp.status())?);
for (k, v) in resp.headers() {
builder = builder.header(k.as_str(), v.as_bytes());
}
let _timer = latency.start_timer(labels.clone());
let response = builder
.body(Body::wrap_stream(resp.bytes_stream()))
.map_err(|e| ApiError::InternalServerError(e.into()))?;
let client = mgmt_api::Client::new(
service_inner.get_http_client().clone(),
node.base_url(),
service_inner.get_config().pageserver_jwt_token.as_deref(),
);
let resp = client.op_raw(method, path).await.map_err(|e|
// We return 503 here because if we can't successfully send a request to the pageserver,
// either we aren't available or the pageserver is unavailable.
ApiError::ResourceUnavailable(format!("Error sending pageserver API request to {node}: {e}").into()))?;
Ok(response)
if !resp.status().is_success() {
let error_counter = &METRICS_REGISTRY
.metrics_group
.storage_controller_passthrough_request_error;
error_counter.inc(labels);
}
let resp_staus = resp.status();
// We have a reqwest::Response, would like a http::Response
let mut builder = hyper::Response::builder().status(map_reqwest_hyper_status(resp_staus)?);
for (k, v) in resp.headers() {
builder = builder.header(k.as_str(), v.as_bytes());
}
let resp_bytes = resp
.bytes()
.await
.map_err(|e| ApiError::InternalServerError(e.into()))?;
// Inspect 404 errors: at this point, we know that the tenant exists, but the pageserver we route
// the request to might not yet be ready. Therefore, if it is a _tenant_ not found error, we can
// convert it into a 503. TODO: we should make this part of the check in `tenant_shard_remote_mutation`.
// However, `tenant_shard_remote_mutation` currently cannot inspect the HTTP error response body,
// so we have to do it here instead.
if resp_staus == reqwest::StatusCode::NOT_FOUND {
let resp_str = std::str::from_utf8(&resp_bytes)
.map_err(|e| ApiError::InternalServerError(e.into()))?;
// We only handle "tenant not found" errors; other 404s like timeline not found should
// be forwarded as-is.
if Service::is_tenant_not_found_error(resp_str, tenant_or_shard_id.tenant_id) {
// Rather than retry here, send the client a 503 to prompt a retry: this matches
// the pageserver's use of 503, and all clients calling this API should retry on 503.
return Err(ApiError::ResourceUnavailable(
format!(
"Pageserver {node} returned tenant 404 due to ongoing migration, retry later"
)
.into(),
));
}
}
let response = builder
.body(Body::from(resp_bytes))
.map_err(|e| ApiError::InternalServerError(e.into()))?;
Ok(response)
}).await?
}
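
The 404 handling in the passthrough handler reduces to a small decision: only a 404 whose body names the tenant is turned into a retryable 503, and every other status is forwarded unchanged. The sketch below is a simplified, std-only illustration; `passthrough_status` is a hypothetical helper, the status codes are plain `u16` values, and the example error bodies are assumptions about what a pageserver might return.

```rust
// Same substring test as `Service::is_tenant_not_found_error` in the diff,
// with the tenant ID passed as a string for simplicity.
fn is_tenant_not_found_error(body: &str, tenant_id: &str) -> bool {
    body.contains(&format!("tenant {tenant_id}"))
}

// Hypothetical helper: decide which status the storage controller should
// return to the client for a given pageserver response.
fn passthrough_status(status: u16, body: &str, tenant_id: &str) -> u16 {
    if status == 404 && is_tenant_not_found_error(body, tenant_id) {
        // The tenant exists but is not attached on this pageserver yet: ask the client to retry.
        503
    } else {
        // Everything else, including other 404s such as a missing timeline, is forwarded as-is.
        status
    }
}
```

Because this is a substring match, any 404 body that happens to contain `tenant <id>` would be treated as retryable, so the behaviour depends on the wording of the pageserver's error messages.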
async fn handle_tenant_locate(
@@ -1085,9 +1115,10 @@ async fn handle_node_delete(req: Request<Body>) -> Result<Response<Body>, ApiErr
let state = get_state(&req);
let node_id: NodeId = parse_request_param(&req, "node_id")?;
let force: bool = parse_query_param(&req, "force")?.unwrap_or(false);
json_response(
StatusCode::OK,
state.service.start_node_delete(node_id).await?,
state.service.start_node_delete(node_id, force).await?,
)
}

View File

@@ -6,6 +6,7 @@ extern crate hyper0 as hyper;
mod auth;
mod background_node_operations;
mod compute_hook;
pub mod hadron_utils;
mod heartbeater;
pub mod http;
mod id_lock_map;

View File

@@ -76,8 +76,8 @@ pub(crate) struct StorageControllerMetricGroup {
/// How many shards would like to reconcile but were blocked by concurrency limits
pub(crate) storage_controller_pending_reconciles: measured::Gauge,
/// How many shards are keep-failing and will be ignored when considering to run optimizations
pub(crate) storage_controller_keep_failing_reconciles: measured::Gauge,
/// How many shards are stuck and will be ignored when considering to run optimizations
pub(crate) storage_controller_stuck_reconciles: measured::Gauge,
/// HTTP request status counters for handled requests
pub(crate) storage_controller_http_request_status:
@@ -151,6 +151,29 @@ pub(crate) struct StorageControllerMetricGroup {
/// Indicator of completed safekeeper reconciles, broken down by safekeeper.
pub(crate) storage_controller_safekeeper_reconciles_complete:
measured::CounterVec<SafekeeperReconcilerLabelGroupSet>,
/* BEGIN HADRON */
/// Hadron `config_watcher` reconciliation runs completed, broken down by success/failure.
pub(crate) storage_controller_config_watcher_complete:
measured::CounterVec<ConfigWatcherCompleteLabelGroupSet>,
/// Hadron long waits for node state changes during drain and fill.
pub(crate) storage_controller_drain_and_fill_long_waits: measured::Counter,
/// Set to 1 if we detect any page server pods with pending node pool rotation annotations.
/// Requires manual reset after oncall investigation.
pub(crate) storage_controller_ps_node_pool_rotation_pending: measured::Gauge,
/// Hadron storage scrubber status.
pub(crate) storage_controller_storage_scrub_status:
measured::CounterVec<StorageScrubberLabelGroupSet>,
/// Desired number of pageservers managed by the storage controller
pub(crate) storage_controller_num_pageservers_desired: measured::Gauge,
/// Desired number of safekeepers managed by the storage controller
pub(crate) storage_controller_num_safekeeper_desired: measured::Gauge,
/* END HADRON */
}
impl StorageControllerMetrics {
@@ -173,6 +196,10 @@ impl Default for StorageControllerMetrics {
.storage_controller_reconcile_complete
.init_all_dense();
metrics_group
.storage_controller_config_watcher_complete
.init_all_dense();
Self {
metrics_group,
encoder: Mutex::new(measured::text::BufferedTextEncoder::new()),
@@ -262,11 +289,48 @@ pub(crate) struct ReconcileLongRunningLabelGroup<'a> {
pub(crate) sequence: &'a str,
}
#[derive(measured::LabelGroup, Clone)]
#[label(set = StorageScrubberLabelGroupSet)]
pub(crate) struct StorageScrubberLabelGroup<'a> {
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
pub(crate) tenant_id: &'a str,
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
pub(crate) shard_number: &'a str,
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
pub(crate) timeline_id: &'a str,
pub(crate) outcome: StorageScrubberOutcome,
}
#[derive(FixedCardinalityLabel, Clone, Copy)]
pub(crate) enum StorageScrubberOutcome {
PSOk,
PSWarning,
PSError,
PSOrphan,
SKOk,
SKError,
}
#[derive(measured::LabelGroup)]
#[label(set = ConfigWatcherCompleteLabelGroupSet)]
pub(crate) struct ConfigWatcherCompleteLabelGroup {
// Reuse the ReconcileOutcome from the SC's reconciliation metrics.
pub(crate) status: ReconcileOutcome,
}
#[derive(FixedCardinalityLabel, Clone, Copy)]
pub(crate) enum ReconcileOutcome {
// Successfully reconciled everything.
#[label(rename = "ok")]
Success,
// Used by tenant-shard reconciler only. Reconciled pageserver state successfully,
// but failed to deliver the compute notification. This error is typically transient
// but if its occurrence keeps increasing, it should be investigated.
#[label(rename = "ok_no_notify")]
SuccessNoNotify,
// We failed to reconcile some state and the reconciliation will be retried.
Error,
// Reconciliation was cancelled.
Cancel,
}

View File

@@ -51,6 +51,39 @@ pub(crate) struct Node {
cancel: CancellationToken,
}
#[allow(dead_code)]
const ONE_MILLION: i64 = 1000000;
/// Converts a pool ID to a large number that can be used to assign unique IDs to pods in StatefulSets.
/// For example, if pool_id is 1, then the pods have NodeIds 1000000, 1000001, 1000002, etc.
/// If pool_id is None, then the pods have NodeIds 0, 1, 2, etc.
#[allow(dead_code)]
pub fn transform_pool_id(pool_id: Option<i32>) -> i64 {
match pool_id {
Some(id) => (id as i64) * ONE_MILLION,
None => 0,
}
}
#[allow(dead_code)]
pub fn get_pool_id_from_node_id(node_id: i64) -> i32 {
(node_id / ONE_MILLION) as i32
}
/// Example pod name: page-server-0-1, safe-keeper-1-0
#[allow(dead_code)]
pub fn get_node_id_from_pod_name(pod_name: &str) -> anyhow::Result<NodeId> {
let parts: Vec<&str> = pod_name.split('-').collect();
if parts.len() != 4 {
return Err(anyhow::anyhow!("Invalid pod name: {}", pod_name));
}
let pool_id = parts[2].parse::<i32>()?;
let node_offset = parts[3].parse::<i64>()?;
let node_id = transform_pool_id(Some(pool_id)) + node_offset;
Ok(NodeId(node_id as u64))
}
/// When updating [`Node::availability`] we use this type to indicate to the caller
/// whether/how they changed it.
pub(crate) enum AvailabilityTransition {
@@ -403,3 +436,25 @@ impl std::fmt::Debug for Node {
write!(f, "{} ({})", self.id, self.listen_http_addr)
}
}
#[cfg(test)]
mod tests {
use utils::id::NodeId;
use crate::node::get_node_id_from_pod_name;
#[test]
fn test_get_node_id_from_pod_name() {
let pod_name = "page-server-3-12";
let node_id = get_node_id_from_pod_name(pod_name).unwrap();
assert_eq!(node_id, NodeId(3000012));
let pod_name = "safe-keeper-1-0";
let node_id = get_node_id_from_pod_name(pod_name).unwrap();
assert_eq!(node_id, NodeId(1000000));
let pod_name = "invalid-pod-name";
let result = get_node_id_from_pod_name(pod_name);
assert!(result.is_err());
}
}
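
The pod-name to node-ID arithmetic above can be worked through with a std-only sketch. The function names and the `String` error type below are simplifications for illustration (the diff uses `anyhow` and `NodeId`); the arithmetic follows the diff: `node_id = pool_id * 1_000_000 + offset`.

```rust
const ONE_MILLION: i64 = 1_000_000;

// Parses names of the form `<word>-<word>-<pool_id>-<offset>`, e.g. `page-server-3-12`.
fn node_id_from_pod_name(pod_name: &str) -> Result<u64, String> {
    let parts: Vec<&str> = pod_name.split('-').collect();
    if parts.len() != 4 {
        return Err(format!("Invalid pod name: {pod_name}"));
    }
    let pool_id = parts[2]
        .parse::<i32>()
        .map_err(|e| format!("bad pool id: {e}"))?;
    let offset = parts[3]
        .parse::<i64>()
        .map_err(|e| format!("bad node offset: {e}"))?;
    Ok((pool_id as i64 * ONE_MILLION + offset) as u64)
}

// Inverse direction: recover the pool ID by integer division.
fn pool_id_from_node_id(node_id: i64) -> i32 {
    (node_id / ONE_MILLION) as i32
}
```

For `page-server-3-12` this gives `3 * 1_000_000 + 12 = 3_000_012`, and dividing by one million recovers pool `3`. The scheme assumes each pool has fewer than one million pods: an offset of `1_000_000` in pool `0` would produce the same ID as offset `0` in pool `1`.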

View File

@@ -14,6 +14,8 @@ use reqwest::StatusCode;
use utils::id::{NodeId, TenantId, TimelineId};
use utils::lsn::Lsn;
use crate::hadron_utils::TenantShardSizeMap;
/// Thin wrapper around [`pageserver_client::mgmt_api::Client`]. It allows the storage
/// controller to collect metrics in a non-intrusive manner.
#[derive(Debug, Clone)]
@@ -86,6 +88,31 @@ impl PageserverClient {
)
}
#[expect(dead_code)]
pub(crate) async fn tenant_timeline_compact(
&self,
tenant_shard_id: TenantShardId,
timeline_id: TimelineId,
force_image_layer_creation: bool,
wait_until_done: bool,
) -> Result<()> {
measured_request!(
"tenant_timeline_compact",
crate::metrics::Method::Put,
&self.node_id_label,
self.inner
.tenant_timeline_compact(
tenant_shard_id,
timeline_id,
force_image_layer_creation,
true,
false,
wait_until_done,
)
.await
)
}
/* BEGIN_HADRON */
pub(crate) async fn tenant_timeline_describe(
&self,
@@ -101,6 +128,17 @@ impl PageserverClient {
.await
)
}
#[expect(dead_code)]
pub(crate) async fn list_tenant_visible_size(&self) -> Result<TenantShardSizeMap> {
measured_request!(
"list_tenant_visible_size",
crate::metrics::Method::Get,
&self.node_id_label,
self.inner.list_tenant_visible_size().await
)
.map(TenantShardSizeMap::new)
}
/* END_HADRON */
pub(crate) async fn tenant_scan_remote_storage(
@@ -365,6 +403,16 @@ impl PageserverClient {
)
}
#[expect(dead_code)]
pub(crate) async fn reset_alert_gauges(&self) -> Result<()> {
measured_request!(
"reset_alert_gauges",
crate::metrics::Method::Post,
&self.node_id_label,
self.inner.reset_alert_gauges().await
)
}
pub(crate) async fn wait_lsn(
&self,
tenant_shard_id: TenantShardId,

View File

@@ -862,11 +862,11 @@ impl Reconciler {
Some(conf) if conf.conf.as_ref() == Some(&wanted_conf) => {
if refreshed {
tracing::info!(
node_id=%node.get_id(), "Observed configuration correct after refresh. Notifying compute.");
node_id=%node.get_id(), "[Attached] Observed configuration correct after refresh. Notifying compute.");
self.compute_notify().await?;
} else {
// Nothing to do
tracing::info!(node_id=%node.get_id(), "Observed configuration already correct.");
tracing::info!(node_id=%node.get_id(), "[Attached] Observed configuration already correct.");
}
}
observed => {
@@ -945,17 +945,17 @@ impl Reconciler {
match self.observed.locations.get(&node.get_id()) {
Some(conf) if conf.conf.as_ref() == Some(&wanted_conf) => {
// Nothing to do
tracing::info!(node_id=%node.get_id(), "Observed configuration already correct.")
tracing::info!(node_id=%node.get_id(), "[Secondary] Observed configuration already correct.")
}
_ => {
// Only try and configure secondary locations on nodes that are available. This
// allows the reconciler to "succeed" while some secondaries are offline (e.g. after
// a node failure, where the failed node will have a secondary intent)
if node.is_available() {
tracing::info!(node_id=%node.get_id(), "Observed configuration requires update.");
tracing::info!(node_id=%node.get_id(), "[Secondary] Observed configuration requires update.");
changes.push((node.clone(), wanted_conf))
} else {
tracing::info!(node_id=%node.get_id(), "Skipping configuration as secondary, node is unavailable");
tracing::info!(node_id=%node.get_id(), "[Secondary] Skipping configuration as secondary, node is unavailable");
self.observed
.locations
.insert(node.get_id(), ObservedStateLocation { conf: None });
@@ -1066,6 +1066,9 @@ impl Reconciler {
}
result
} else {
tracing::info!(
"Compute notification is skipped because the tenant shard does not have an attached (primary) location"
);
Ok(())
}
}

View File

@@ -13,6 +13,24 @@ diesel::table! {
}
}
diesel::table! {
hadron_safekeepers (sk_node_id) {
sk_node_id -> Int8,
listen_http_addr -> Varchar,
listen_http_port -> Int4,
listen_pg_addr -> Varchar,
listen_pg_port -> Int4,
}
}
diesel::table! {
hadron_timeline_safekeepers (timeline_id, sk_node_id) {
timeline_id -> Varchar,
sk_node_id -> Int8,
legacy_endpoint_id -> Nullable<Uuid>,
}
}
diesel::table! {
metadata_health (tenant_id, shard_number, shard_count) {
tenant_id -> Varchar,
@@ -105,6 +123,8 @@ diesel::table! {
diesel::allow_tables_to_appear_in_same_query!(
controllers,
hadron_safekeepers,
hadron_timeline_safekeepers,
metadata_health,
nodes,
safekeeper_timeline_pending_ops,

View File

@@ -207,34 +207,13 @@ enum ShardGenerationValidity {
},
}
/// We collect the state of attachments for some operations to determine if the operation
/// needs to be retried when it fails.
struct TenantShardAttachState {
/// The targets of the operation.
///
/// Tenant shard ID, node ID, node, is intent node observed primary.
targets: Vec<(TenantShardId, NodeId, Node, bool)>,
/// The targets grouped by node ID.
by_node_id: HashMap<NodeId, (TenantShardId, Node, bool)>,
}
impl TenantShardAttachState {
fn for_api_call(&self) -> Vec<(TenantShardId, Node)> {
self.targets
.iter()
.map(|(tenant_shard_id, _, node, _)| (*tenant_shard_id, node.clone()))
.collect()
}
}
pub const RECONCILER_CONCURRENCY_DEFAULT: usize = 128;
pub const PRIORITY_RECONCILER_CONCURRENCY_DEFAULT: usize = 256;
pub const SAFEKEEPER_RECONCILER_CONCURRENCY_DEFAULT: usize = 32;
// Number of consecutive reconciliation errors that have occurred for one shard,
// Number of consecutive reconciliations that have occurred for one shard,
// after which the shard is ignored when considering to run optimizations.
const MAX_CONSECUTIVE_RECONCILIATION_ERRORS: usize = 5;
const MAX_CONSECUTIVE_RECONCILES: usize = 10;
// Depth of the channel used to enqueue shards for reconciliation when they can't do it immediately.
// This channel is finite-size to avoid using excessive memory if we get into a state where reconciles are finishing more slowly
@@ -719,47 +698,70 @@ pub(crate) enum ReconcileResultRequest {
}
#[derive(Clone)]
struct MutationLocation {
node: Node,
generation: Generation,
pub(crate) struct MutationLocation {
pub(crate) node: Node,
pub(crate) generation: Generation,
}
#[derive(Clone)]
struct ShardMutationLocations {
latest: MutationLocation,
other: Vec<MutationLocation>,
pub(crate) struct ShardMutationLocations {
pub(crate) latest: MutationLocation,
pub(crate) other: Vec<MutationLocation>,
}
#[derive(Default, Clone)]
struct TenantMutationLocations(BTreeMap<TenantShardId, ShardMutationLocations>);
pub(crate) struct TenantMutationLocations(pub BTreeMap<TenantShardId, ShardMutationLocations>);
struct ReconcileAllResult {
spawned_reconciles: usize,
keep_failing_reconciles: usize,
stuck_reconciles: usize,
has_delayed_reconciles: bool,
}
impl ReconcileAllResult {
fn new(
spawned_reconciles: usize,
keep_failing_reconciles: usize,
stuck_reconciles: usize,
has_delayed_reconciles: bool,
) -> Self {
assert!(
spawned_reconciles >= keep_failing_reconciles,
"It is impossible to have more keep-failing reconciles than spawned reconciles"
spawned_reconciles >= stuck_reconciles,
"It is impossible to have fewer spawned reconciles than stuck reconciles"
);
Self {
spawned_reconciles,
keep_failing_reconciles,
stuck_reconciles,
has_delayed_reconciles,
}
}
/// We can run optimizations only if we don't have any delayed reconciles and
/// all spawned reconciles are also keep-failing reconciles.
/// all spawned reconciles are also stuck reconciles.
fn can_run_optimizations(&self) -> bool {
!self.has_delayed_reconciles && self.spawned_reconciles == self.keep_failing_reconciles
!self.has_delayed_reconciles && self.spawned_reconciles == self.stuck_reconciles
}
}
enum TenantIdOrShardId {
TenantId(TenantId),
TenantShardId(TenantShardId),
}
impl TenantIdOrShardId {
fn tenant_id(&self) -> TenantId {
match self {
TenantIdOrShardId::TenantId(tenant_id) => *tenant_id,
TenantIdOrShardId::TenantShardId(tenant_shard_id) => tenant_shard_id.tenant_id,
}
}
fn matches(&self, tenant_shard_id: &TenantShardId) -> bool {
match self {
TenantIdOrShardId::TenantId(tenant_id) => tenant_shard_id.tenant_id == *tenant_id,
TenantIdOrShardId::TenantShardId(this_tenant_shard_id) => {
this_tenant_shard_id == tenant_shard_id
}
}
}
}
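
How `TenantIdOrShardId::matches` could be used to select shards is sketched below. This is a std-only illustration under stated assumptions: `TenantId` and `TenantShardId` are reduced to hypothetical minimal structs, and `select_shards` is an invented helper, not a function from the storage controller.

```rust
// Hypothetical minimal stand-ins for `utils::id::TenantId` and `utils::shard::TenantShardId`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct TenantId(u64);

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct TenantShardId {
    tenant_id: TenantId,
    shard_number: u8,
}

enum TenantIdOrShardId {
    TenantId(TenantId),
    TenantShardId(TenantShardId),
}

impl TenantIdOrShardId {
    // A bare tenant ID matches every shard of that tenant; a shard ID matches only itself.
    fn matches(&self, tenant_shard_id: &TenantShardId) -> bool {
        match self {
            TenantIdOrShardId::TenantId(tenant_id) => tenant_shard_id.tenant_id == *tenant_id,
            TenantIdOrShardId::TenantShardId(this) => this == tenant_shard_id,
        }
    }
}

// Hypothetical helper: keep only the shards selected by the filter.
fn select_shards(filter: &TenantIdOrShardId, all: &[TenantShardId]) -> Vec<TenantShardId> {
    all.iter().copied().filter(|id| filter.matches(id)).collect()
}
```

With this shape, a single code path can serve both tenant-wide operations and operations addressed to one shard.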
@@ -1503,7 +1505,6 @@ impl Service {
match result.result {
Ok(()) => {
tenant.consecutive_errors_count = 0;
tenant.apply_observed_deltas(deltas);
tenant.waiter.advance(result.sequence);
}
@@ -1522,8 +1523,6 @@ impl Service {
}
}
tenant.consecutive_errors_count = tenant.consecutive_errors_count.saturating_add(1);
// Ordering: populate last_error before advancing error_seq,
// so that waiters will see the correct error after waiting.
tenant.set_last_error(result.sequence, e);
@@ -1535,6 +1534,8 @@ impl Service {
}
}
tenant.consecutive_reconciles_count = tenant.consecutive_reconciles_count.saturating_add(1);
// If we just finished detaching all shards for a tenant, it might be time to drop it from memory.
if tenant.policy == PlacementPolicy::Detached {
// We may only drop a tenant from memory while holding the exclusive lock on the tenant ID: this protects us
@@ -4773,72 +4774,24 @@ impl Service {
Ok(())
}
fn is_observed_consistent_with_intent(
&self,
shard: &TenantShard,
intent_node_id: NodeId,
) -> bool {
if let Some(location) = shard.observed.locations.get(&intent_node_id)
&& let Some(ref conf) = location.conf
&& (conf.mode == LocationConfigMode::AttachedSingle
|| conf.mode == LocationConfigMode::AttachedMulti)
{
true
} else {
false
}
}
fn collect_tenant_shards(
&self,
tenant_id: TenantId,
) -> Result<TenantShardAttachState, ApiError> {
let locked = self.inner.read().unwrap();
let mut targets = Vec::new();
let mut by_node_id = HashMap::new();
// If the request got an unsharded tenant id, then apply
// the operation to all shards. Otherwise, apply it to a specific shard.
let shards_range = TenantShardId::tenant_range(tenant_id);
for (tenant_shard_id, shard) in locked.tenants.range(shards_range) {
if let Some(node_id) = shard.intent.get_attached() {
let node = locked
.nodes
.get(node_id)
.expect("Pageservers may not be deleted while referenced");
let consistent = self.is_observed_consistent_with_intent(shard, *node_id);
targets.push((*tenant_shard_id, *node_id, node.clone(), consistent));
by_node_id.insert(*node_id, (*tenant_shard_id, node.clone(), consistent));
}
}
Ok(TenantShardAttachState {
targets,
by_node_id,
})
pub(crate) fn is_tenant_not_found_error(body: &str, tenant_id: TenantId) -> bool {
body.contains(&format!("tenant {tenant_id}"))
}
fn process_result_and_passthrough_errors<T>(
&self,
tenant_id: TenantId,
results: Vec<(Node, Result<T, mgmt_api::Error>)>,
attach_state: TenantShardAttachState,
) -> Result<Vec<(Node, T)>, ApiError> {
let mut processed_results: Vec<(Node, T)> = Vec::with_capacity(results.len());
debug_assert_eq!(results.len(), attach_state.targets.len());
for (node, res) in results {
let is_consistent = attach_state
.by_node_id
.get(&node.get_id())
.map(|(_, _, consistent)| *consistent);
match res {
Ok(res) => processed_results.push((node, res)),
Err(mgmt_api::Error::ApiError(StatusCode::NOT_FOUND, _))
if is_consistent == Some(false) =>
Err(mgmt_api::Error::ApiError(StatusCode::NOT_FOUND, body))
if Self::is_tenant_not_found_error(&body, tenant_id) =>
{
// This is expected if the attach is not finished yet. Return 503 so that the client can retry.
// If the tenant is not found, we are still in the process of attaching the tenant.
// Return 503 so that the client can retry.
return Err(ApiError::ResourceUnavailable(
format!(
"Timeline is not attached to the pageserver {} yet, please retry",
@@ -4866,35 +4819,48 @@ impl Service {
)
.await;
let attach_state = self.collect_tenant_shards(tenant_id)?;
let results = self
.tenant_for_shards_api(
attach_state.for_api_call(),
|tenant_shard_id, client| async move {
client
.timeline_lease_lsn(tenant_shard_id, timeline_id, lsn)
.await
},
1,
1,
SHORT_RECONCILE_TIMEOUT,
&self.cancel,
)
.await;
let leases = self.process_result_and_passthrough_errors(results, attach_state)?;
let mut valid_until = None;
for (_, lease) in leases {
if let Some(ref mut valid_until) = valid_until {
*valid_until = std::cmp::min(*valid_until, lease.valid_until);
} else {
valid_until = Some(lease.valid_until);
self.tenant_remote_mutation(tenant_id, |locations| async move {
if locations.0.is_empty() {
return Err(ApiError::NotFound(
anyhow::anyhow!("Tenant not found").into(),
));
}
}
Ok(LsnLease {
valid_until: valid_until.unwrap_or_else(SystemTime::now),
let results = self
.tenant_for_shards_api(
locations
.0
.iter()
.map(|(tenant_shard_id, ShardMutationLocations { latest, .. })| {
(*tenant_shard_id, latest.node.clone())
})
.collect(),
|tenant_shard_id, client| async move {
client
.timeline_lease_lsn(tenant_shard_id, timeline_id, lsn)
.await
},
1,
1,
SHORT_RECONCILE_TIMEOUT,
&self.cancel,
)
.await;
let leases = self.process_result_and_passthrough_errors(tenant_id, results)?;
let mut valid_until = None;
for (_, lease) in leases {
if let Some(ref mut valid_until) = valid_until {
*valid_until = std::cmp::min(*valid_until, lease.valid_until);
} else {
valid_until = Some(lease.valid_until);
}
}
Ok(LsnLease {
valid_until: valid_until.unwrap_or_else(SystemTime::now),
})
})
.await?
}
pub(crate) async fn tenant_timeline_download_heatmap_layers(
@@ -5041,11 +5007,37 @@ impl Service {
/// - Looks up the shards and the nodes where they were most recently attached
/// - Guarantees that after the inner function returns, the shards' generations haven't moved on: this
/// ensures that the remote operation acted on the most recent generation, and is therefore durable.
async fn tenant_remote_mutation<R, O, F>(
pub(crate) async fn tenant_remote_mutation<R, O, F>(
&self,
tenant_id: TenantId,
op: O,
) -> Result<R, ApiError>
where
O: FnOnce(TenantMutationLocations) -> F,
F: std::future::Future<Output = R>,
{
self.tenant_remote_mutation_inner(TenantIdOrShardId::TenantId(tenant_id), op)
.await
}
pub(crate) async fn tenant_shard_remote_mutation<R, O, F>(
&self,
tenant_shard_id: TenantShardId,
op: O,
) -> Result<R, ApiError>
where
O: FnOnce(TenantMutationLocations) -> F,
F: std::future::Future<Output = R>,
{
self.tenant_remote_mutation_inner(TenantIdOrShardId::TenantShardId(tenant_shard_id), op)
.await
}
async fn tenant_remote_mutation_inner<R, O, F>(
&self,
tenant_id_or_shard_id: TenantIdOrShardId,
op: O,
) -> Result<R, ApiError>
where
O: FnOnce(TenantMutationLocations) -> F,
F: std::future::Future<Output = R>,
@@ -5057,7 +5049,13 @@ impl Service {
// run concurrently with reconciliations, and it is not guaranteed that the node we find here
// will still be the latest when we're done: we will check generations again at the end of
// this function to handle that.
let generations = self.persistence.tenant_generations(tenant_id).await?;
let generations = self
.persistence
.tenant_generations(tenant_id_or_shard_id.tenant_id())
.await?
.into_iter()
.filter(|i| tenant_id_or_shard_id.matches(&i.tenant_shard_id))
.collect::<Vec<_>>();
if generations
.iter()
@@ -5071,9 +5069,14 @@ impl Service {
// One or more shards has not been attached to a pageserver. Check if this is because it's configured
// to be detached (409: caller should give up), or because it's meant to be attached but isn't yet (503: caller should retry)
let locked = self.inner.read().unwrap();
for (shard_id, shard) in
locked.tenants.range(TenantShardId::tenant_range(tenant_id))
{
let tenant_shards = locked
.tenants
.range(TenantShardId::tenant_range(
tenant_id_or_shard_id.tenant_id(),
))
.filter(|(shard_id, _)| tenant_id_or_shard_id.matches(shard_id))
.collect::<Vec<_>>();
for (shard_id, shard) in tenant_shards {
match shard.policy {
PlacementPolicy::Attached(_) => {
// This shard is meant to be attached: the caller is not wrong to try and
@@ -5183,7 +5186,14 @@ impl Service {
// Post-check: are all the generations of all the shards the same as they were initially? This proves that
// our remote operation executed on the latest generation and is therefore persistent.
{
let latest_generations = self.persistence.tenant_generations(tenant_id).await?;
let latest_generations = self
.persistence
.tenant_generations(tenant_id_or_shard_id.tenant_id())
.await?
.into_iter()
.filter(|i| tenant_id_or_shard_id.matches(&i.tenant_shard_id))
.collect::<Vec<_>>();
if latest_generations
.into_iter()
.map(
@@ -5317,7 +5327,7 @@ impl Service {
pub(crate) async fn tenant_shard0_node(
&self,
tenant_id: TenantId,
) -> Result<(Node, TenantShardId, bool), ApiError> {
) -> Result<(Node, TenantShardId), ApiError> {
let tenant_shard_id = {
let locked = self.inner.read().unwrap();
let Some((tenant_shard_id, _shard)) = locked
@@ -5335,7 +5345,7 @@ impl Service {
self.tenant_shard_node(tenant_shard_id)
.await
.map(|(node, consistent)| (node, tenant_shard_id, consistent))
.map(|node| (node, tenant_shard_id))
}
/// When you need to send an HTTP request to the pageserver that holds a shard of a tenant, this
@@ -5345,7 +5355,7 @@ impl Service {
pub(crate) async fn tenant_shard_node(
&self,
tenant_shard_id: TenantShardId,
) -> Result<(Node, bool), ApiError> {
) -> Result<Node, ApiError> {
// Look up in-memory state and maybe use the node from there.
{
let locked = self.inner.read().unwrap();
@@ -5375,8 +5385,7 @@ impl Service {
"Shard refers to nonexistent node"
)));
};
let consistent = self.is_observed_consistent_with_intent(shard, *intent_node_id);
return Ok((node.clone(), consistent));
return Ok(node.clone());
}
};
@@ -5411,7 +5420,7 @@ impl Service {
)));
};
// As a reconciliation is in flight, we do not have the observed state yet, and therefore we assume it is always inconsistent.
Ok((node.clone(), false))
Ok(node.clone())
}
pub(crate) fn tenant_locate(
@@ -7385,6 +7394,7 @@ impl Service {
self: &Arc<Self>,
node_id: NodeId,
policy_on_start: NodeSchedulingPolicy,
force: bool,
cancel: CancellationToken,
) -> Result<(), OperationError> {
let reconciler_config = ReconcilerConfigBuilder::new(ReconcilerPriority::Normal).build();
@@ -7392,23 +7402,27 @@ impl Service {
let mut waiters: Vec<ReconcilerWaiter> = Vec::new();
let mut tid_iter = create_shared_shard_iterator(self.clone());
let reset_node_policy_on_cancel = || async {
match self
.node_configure(node_id, None, Some(policy_on_start))
.await
{
Ok(()) => OperationError::Cancelled,
Err(err) => {
OperationError::FinalizeError(
format!(
"Failed to finalise delete cancel of {} by setting scheduling policy to {}: {}",
node_id, String::from(policy_on_start), err
)
.into(),
)
}
}
};
while !tid_iter.finished() {
if cancel.is_cancelled() {
match self
.node_configure(node_id, None, Some(policy_on_start))
.await
{
Ok(()) => return Err(OperationError::Cancelled),
Err(err) => {
return Err(OperationError::FinalizeError(
format!(
"Failed to finalise delete cancel of {} by setting scheduling policy to {}: {}",
node_id, String::from(policy_on_start), err
)
.into(),
));
}
}
return Err(reset_node_policy_on_cancel().await);
}
operation_utils::validate_node_state(
@@ -7477,8 +7491,18 @@ impl Service {
nodes,
reconciler_config,
);
if let Some(some) = waiter {
waiters.push(some);
if force {
// Here we remove an existing observed location for the node we're removing, and it will
// not be re-added by a reconciler's completion because we filter out removed nodes in
// process_result.
//
// Note that we update the shard's observed state _after_ calling maybe_configured_reconcile_shard:
// that means any reconciles we spawned will know about the node we're deleting,
// enabling them to do live migrations if it's still online.
tenant_shard.observed.locations.remove(&node_id);
} else if let Some(waiter) = waiter {
waiters.push(waiter);
}
}
}
@@ -7492,21 +7516,7 @@ impl Service {
while !waiters.is_empty() {
if cancel.is_cancelled() {
match self
.node_configure(node_id, None, Some(policy_on_start))
.await
{
Ok(()) => return Err(OperationError::Cancelled),
Err(err) => {
return Err(OperationError::FinalizeError(
format!(
"Failed to finalise drain cancel of {} by setting scheduling policy to {}: {}",
node_id, String::from(policy_on_start), err
)
.into(),
));
}
}
return Err(reset_node_policy_on_cancel().await);
}
tracing::info!("Awaiting {} pending delete reconciliations", waiters.len());
@@ -7516,6 +7526,12 @@ impl Service {
.await;
}
let pf = pausable_failpoint!("delete-node-after-reconciles-spawned", &cancel);
if pf.is_err() {
// An error from pausable_failpoint indicates the cancel token was triggered.
return Err(reset_node_policy_on_cancel().await);
}
self.persistence
.set_tombstone(node_id)
.await
@@ -8111,6 +8127,7 @@ impl Service {
pub(crate) async fn start_node_delete(
self: &Arc<Self>,
node_id: NodeId,
force: bool,
) -> Result<(), ApiError> {
let (ongoing_op, node_policy, schedulable_nodes_count) = {
let locked = self.inner.read().unwrap();
@@ -8180,7 +8197,7 @@ impl Service {
tracing::info!("Delete background operation starting");
let res = service
.delete_node(node_id, policy_on_start, cancel)
.delete_node(node_id, policy_on_start, force, cancel)
.await;
match res {
Ok(()) => {
@@ -8632,7 +8649,7 @@ impl Service {
// This function is an efficient place to update lazy statistics, since we are walking
// all tenants.
let mut pending_reconciles = 0;
let mut keep_failing_reconciles = 0;
let mut stuck_reconciles = 0;
let mut az_violations = 0;
// If we find any tenants to drop from memory, stash them to offload after
@@ -8668,30 +8685,32 @@ impl Service {
// Eventual consistency: if an earlier reconcile job failed, and the shard is still
// dirty, spawn another one
let consecutive_errors_count = shard.consecutive_errors_count;
if self
.maybe_reconcile_shard(shard, &pageservers, ReconcilerPriority::Normal)
.is_some()
{
spawned_reconciles += 1;
// Count shards that are keep-failing. We still want to reconcile them
// to avoid a situation where a shard is stuck.
// But we don't want to consider them when deciding to run optimizations.
if consecutive_errors_count >= MAX_CONSECUTIVE_RECONCILIATION_ERRORS {
if shard.consecutive_reconciles_count >= MAX_CONSECUTIVE_RECONCILES {
// Count shards that are stuck, but we still want to reconcile them.
// We don't want to consider them when deciding to run optimizations.
tracing::warn!(
tenant_id=%shard.tenant_shard_id.tenant_id,
shard_id=%shard.tenant_shard_id.shard_slug(),
"Shard reconciliation is keep-failing: {} errors",
consecutive_errors_count
"Shard reconciliation is stuck: {} consecutive launches",
shard.consecutive_reconciles_count
);
keep_failing_reconciles += 1;
stuck_reconciles += 1;
}
} else {
if shard.delayed_reconcile {
// Shard wanted to reconcile but for some reason couldn't.
pending_reconciles += 1;
}
} else if shard.delayed_reconcile {
// Shard wanted to reconcile but for some reason couldn't.
pending_reconciles += 1;
}
// Reset the counter when we don't need to launch a reconcile.
shard.consecutive_reconciles_count = 0;
}
// If this tenant is detached, try dropping it from memory. This is usually done
// proactively in [`Self::process_results`], but we do it here to handle the edge
// case where a reconcile completes while someone else is holding an op lock for the tenant.
@@ -8727,14 +8746,10 @@ impl Service {
metrics::METRICS_REGISTRY
.metrics_group
.storage_controller_keep_failing_reconciles
.set(keep_failing_reconciles as i64);
.storage_controller_stuck_reconciles
.set(stuck_reconciles as i64);
ReconcileAllResult::new(
spawned_reconciles,
keep_failing_reconciles,
has_delayed_reconciles,
)
ReconcileAllResult::new(spawned_reconciles, stuck_reconciles, has_delayed_reconciles)
}
/// `optimize` in this context means identifying shards which have valid scheduled locations, but


@@ -131,14 +131,16 @@ pub(crate) struct TenantShard {
#[serde(serialize_with = "read_last_error")]
pub(crate) last_error: std::sync::Arc<std::sync::Mutex<Option<Arc<ReconcileError>>>>,
/// Number of consecutive reconciliation errors that have occurred for this shard.
/// Number of consecutive [`crate::service::Service::reconcile_all`] iterations that have
/// scheduled a reconciliation for this shard.
///
/// When this count reaches MAX_CONSECUTIVE_RECONCILIATION_ERRORS, the tenant shard
/// will be counted as keep-failing in `reconcile_all` calculations. This will lead to
/// allowing optimizations to run even with some failing shards.
/// If this reaches `MAX_CONSECUTIVE_RECONCILES`, the shard is considered "stuck" and will be
/// ignored when deciding whether optimizations can run. This includes both successful and failed
/// reconciliations.
///
/// The counter is reset to 0 after a successful reconciliation.
pub(crate) consecutive_errors_count: usize,
/// Incremented in [`crate::service::Service::process_result`], and reset to 0 when
/// [`crate::service::Service::reconcile_all`] determines no reconciliation is needed for this shard.
pub(crate) consecutive_reconciles_count: usize,
/// If we have a pending compute notification that for some reason we weren't able to send,
/// set this to true. If this is set, calls to [`Self::get_reconcile_needed`] will return Yes
@@ -603,7 +605,7 @@ impl TenantShard {
waiter: Arc::new(SeqWait::new(Sequence(0))),
error_waiter: Arc::new(SeqWait::new(Sequence(0))),
last_error: Arc::default(),
consecutive_errors_count: 0,
consecutive_reconciles_count: 0,
pending_compute_notification: false,
scheduling_policy: ShardSchedulingPolicy::default(),
preferred_node: None,
@@ -1609,7 +1611,13 @@ impl TenantShard {
// Update result counter
let outcome_label = match &result {
Ok(_) => ReconcileOutcome::Success,
Ok(_) => {
if reconciler.compute_notify_failure {
ReconcileOutcome::SuccessNoNotify
} else {
ReconcileOutcome::Success
}
}
Err(ReconcileError::Cancel) => ReconcileOutcome::Cancel,
Err(_) => ReconcileOutcome::Error,
};
@@ -1908,7 +1916,7 @@ impl TenantShard {
waiter: Arc::new(SeqWait::new(Sequence::initial())),
error_waiter: Arc::new(SeqWait::new(Sequence::initial())),
last_error: Arc::default(),
consecutive_errors_count: 0,
consecutive_reconciles_count: 0,
pending_compute_notification: false,
delayed_reconcile: false,
scheduling_policy: serde_json::from_str(&tsp.scheduling_policy).unwrap(),


@@ -2119,11 +2119,14 @@ class NeonStorageController(MetricsGetter, LogUtils):
headers=self.headers(TokenScope.ADMIN),
)
def node_delete(self, node_id):
def node_delete(self, node_id, force: bool = False):
log.info(f"node_delete({node_id})")
query = f"{self.api}/control/v1/node/{node_id}/delete"
if force:
query += "?force=true"
self.request(
"PUT",
f"{self.api}/control/v1/node/{node_id}/delete",
query,
headers=self.headers(TokenScope.ADMIN),
)


@@ -847,7 +847,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
return res_json
def timeline_lsn_lease(
self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, lsn: Lsn
self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, lsn: Lsn, **kwargs
):
data = {
"lsn": str(lsn),
@@ -857,6 +857,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
res = self.post(
f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/lsn_lease",
json=data,
**kwargs,
)
self.verbose_error(res)
res_json = res.json()


@@ -187,19 +187,21 @@ def test_create_snapshot(
env.pageserver.stop()
env.storage_controller.stop()
# Directory `compatibility_snapshot_dir` is uploaded to S3 in a workflow, keep the name in sync with it
compatibility_snapshot_dir = (
# Directory `new_compatibility_snapshot_dir` is uploaded to S3 in a workflow, keep the name in sync with it
new_compatibility_snapshot_dir = (
top_output_dir / f"compatibility_snapshot_pg{pg_version.v_prefixed}"
)
if compatibility_snapshot_dir.exists():
shutil.rmtree(compatibility_snapshot_dir)
if new_compatibility_snapshot_dir.exists():
shutil.rmtree(new_compatibility_snapshot_dir)
shutil.copytree(
test_output_dir,
compatibility_snapshot_dir,
new_compatibility_snapshot_dir,
ignore=shutil.ignore_patterns("pg_dynshmem"),
)
log.info(f"Copied new compatibility snapshot dir to: {new_compatibility_snapshot_dir}")
# check_neon_works does recovery from WAL => the compatibility snapshot's WAL is old => will log this warning
ingest_lag_log_line = ".*ingesting record with timestamp lagging more than wait_lsn_timeout.*"
@@ -218,6 +220,7 @@ def test_backward_compatibility(
"""
Test that the new binaries can read old data
"""
log.info(f"Using snapshot dir at {compatibility_snapshot_dir}")
neon_env_builder.num_safekeepers = 3
env = neon_env_builder.from_repo_dir(compatibility_snapshot_dir / "repo")
env.pageserver.allowed_errors.append(ingest_lag_log_line)
@@ -242,7 +245,6 @@ def test_forward_compatibility(
test_output_dir: Path,
top_output_dir: Path,
pg_version: PgVersion,
compatibility_snapshot_dir: Path,
compute_reconfigure_listener: ComputeReconfigure,
):
"""
@@ -266,8 +268,14 @@ def test_forward_compatibility(
neon_env_builder.neon_binpath = neon_env_builder.compatibility_neon_binpath
neon_env_builder.pg_distrib_dir = neon_env_builder.compatibility_pg_distrib_dir
# Note that we are testing with new data, so we should use `new_compatibility_snapshot_dir`, which is created by test_create_snapshot.
new_compatibility_snapshot_dir = (
top_output_dir / f"compatibility_snapshot_pg{pg_version.v_prefixed}"
)
log.info(f"Using snapshot dir at {new_compatibility_snapshot_dir}")
env = neon_env_builder.from_repo_dir(
compatibility_snapshot_dir / "repo",
new_compatibility_snapshot_dir / "repo",
)
# there may be an arbitrary number of unrelated tests run between create_snapshot and here
env.pageserver.allowed_errors.append(ingest_lag_log_line)
@@ -296,7 +304,7 @@ def test_forward_compatibility(
check_neon_works(
env,
test_output_dir=test_output_dir,
sql_dump_path=compatibility_snapshot_dir / "dump.sql",
sql_dump_path=new_compatibility_snapshot_dir / "dump.sql",
repo_dir=env.repo_dir,
)


@@ -12,7 +12,7 @@ from typing import TYPE_CHECKING
import fixtures.utils
import pytest
from fixtures.auth_tokens import TokenScope
from fixtures.common_types import TenantId, TenantShardId, TimelineId
from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId
from fixtures.log_helper import log
from fixtures.neon_fixtures import (
DEFAULT_AZ_ID,
@@ -47,6 +47,7 @@ from fixtures.utils import (
wait_until,
)
from fixtures.workload import Workload
from requests.adapters import HTTPAdapter
from urllib3 import Retry
from werkzeug.wrappers.response import Response
@@ -72,6 +73,12 @@ def get_node_shard_counts(env: NeonEnv, tenant_ids):
return counts
class DeletionAPIKind(Enum):
OLD = "old"
FORCE = "force"
GRACEFUL = "graceful"
@pytest.mark.parametrize(**fixtures.utils.allpairs_versions())
def test_storage_controller_smoke(
neon_env_builder: NeonEnvBuilder, compute_reconfigure_listener: ComputeReconfigure, combination
@@ -990,7 +997,7 @@ def test_storage_controller_compute_hook_retry(
@run_only_on_default_postgres("postgres behavior is not relevant")
def test_storage_controller_compute_hook_keep_failing(
def test_storage_controller_compute_hook_stuck_reconciles(
httpserver: HTTPServer,
neon_env_builder: NeonEnvBuilder,
httpserver_listen_address: ListenAddress,
@@ -1040,7 +1047,7 @@ def test_storage_controller_compute_hook_keep_failing(
env.storage_controller.allowed_errors.append(NOTIFY_BLOCKED_LOG)
env.storage_controller.allowed_errors.extend(NOTIFY_FAILURE_LOGS)
env.storage_controller.allowed_errors.append(".*Keeping extra secondaries.*")
env.storage_controller.allowed_errors.append(".*Shard reconciliation is keep-failing.*")
env.storage_controller.allowed_errors.append(".*Shard reconciliation is stuck.*")
env.storage_controller.node_configure(banned_tenant_ps.id, {"availability": "Offline"})
# Migrate all allowed tenant shards to the first alive pageserver
@@ -1055,7 +1062,7 @@ def test_storage_controller_compute_hook_keep_failing(
# Make some reconcile_all calls to trigger optimizations
# RECONCILE_COUNT must be greater than storcon's MAX_CONSECUTIVE_RECONCILES
RECONCILE_COUNT = 12
RECONCILE_COUNT = 20
for i in range(RECONCILE_COUNT):
try:
n = env.storage_controller.reconcile_all()
@@ -1068,6 +1075,8 @@ def test_storage_controller_compute_hook_keep_failing(
assert banned_descr["shards"][0]["is_pending_compute_notification"] is True
time.sleep(2)
env.storage_controller.assert_log_contains(".*Shard reconciliation is stuck.*")
# Check that the allowed tenant shards are optimized due to affinity rules
locations = alive_pageservers[0].http_client().tenant_list_locations()["tenant_shards"]
not_optimized_shard_count = 0
@@ -2572,9 +2581,11 @@ def test_background_operation_cancellation(neon_env_builder: NeonEnvBuilder):
@pytest.mark.parametrize("while_offline", [True, False])
@pytest.mark.parametrize("deletion_api", [DeletionAPIKind.OLD, DeletionAPIKind.FORCE])
def test_storage_controller_node_deletion(
neon_env_builder: NeonEnvBuilder,
while_offline: bool,
deletion_api: DeletionAPIKind,
):
"""
Test that deleting a node works & properly reschedules everything that was on the node.
@@ -2598,6 +2609,8 @@ def test_storage_controller_node_deletion(
assert env.storage_controller.reconcile_all() == 0
victim = env.pageservers[-1]
if deletion_api == DeletionAPIKind.FORCE and not while_offline:
victim.allowed_errors.append(".*request was dropped before completing.*")
# The procedure a human would follow is:
# 1. Mark pageserver scheduling=pause
@@ -2621,7 +2634,12 @@ def test_storage_controller_node_deletion(
wait_until(assert_shards_migrated)
log.info(f"Deleting pageserver {victim.id}")
env.storage_controller.node_delete_old(victim.id)
if deletion_api == DeletionAPIKind.FORCE:
env.storage_controller.node_delete(victim.id, force=True)
elif deletion_api == DeletionAPIKind.OLD:
env.storage_controller.node_delete_old(victim.id)
else:
raise AssertionError(f"Invalid deletion API: {deletion_api}")
if not while_offline:
@@ -2634,7 +2652,15 @@ def test_storage_controller_node_deletion(
wait_until(assert_victim_evacuated)
# The node should be gone from the list API
assert victim.id not in [n["id"] for n in env.storage_controller.node_list()]
def assert_node_is_gone():
assert victim.id not in [n["id"] for n in env.storage_controller.node_list()]
if deletion_api == DeletionAPIKind.FORCE:
wait_until(assert_node_is_gone)
elif deletion_api == DeletionAPIKind.OLD:
assert_node_is_gone()
else:
raise AssertionError(f"Invalid deletion API: {deletion_api}")
# No tenants should refer to the node in their intent
for tenant_id in tenant_ids:
@@ -2656,7 +2682,11 @@ def test_storage_controller_node_deletion(
env.storage_controller.consistency_check()
def test_storage_controller_node_delete_cancellation(neon_env_builder: NeonEnvBuilder):
@pytest.mark.parametrize("deletion_api", [DeletionAPIKind.FORCE, DeletionAPIKind.GRACEFUL])
def test_storage_controller_node_delete_cancellation(
neon_env_builder: NeonEnvBuilder,
deletion_api: DeletionAPIKind,
):
neon_env_builder.num_pageservers = 3
neon_env_builder.num_azs = 3
env = neon_env_builder.init_configs()
@@ -2680,12 +2710,16 @@ def test_storage_controller_node_delete_cancellation(neon_env_builder: NeonEnvBu
assert len(nodes) == 3
env.storage_controller.configure_failpoints(("sleepy-delete-loop", "return(10000)"))
env.storage_controller.configure_failpoints(("delete-node-after-reconciles-spawned", "pause"))
ps_id_to_delete = env.pageservers[0].id
env.storage_controller.warm_up_all_secondaries()
assert deletion_api in [DeletionAPIKind.FORCE, DeletionAPIKind.GRACEFUL]
force = deletion_api == DeletionAPIKind.FORCE
env.storage_controller.retryable_node_operation(
lambda ps_id: env.storage_controller.node_delete(ps_id),
lambda ps_id: env.storage_controller.node_delete(ps_id, force),
ps_id_to_delete,
max_attempts=3,
backoff=2,
@@ -2701,6 +2735,8 @@ def test_storage_controller_node_delete_cancellation(neon_env_builder: NeonEnvBu
env.storage_controller.cancel_node_delete(ps_id_to_delete)
env.storage_controller.configure_failpoints(("delete-node-after-reconciles-spawned", "off"))
env.storage_controller.poll_node_status(
ps_id_to_delete,
PageserverAvailability.ACTIVE,
@@ -3252,7 +3288,10 @@ def test_storage_controller_ps_restarted_during_drain(neon_env_builder: NeonEnvB
wait_until(reconfigure_node_again)
def test_ps_unavailable_after_delete(neon_env_builder: NeonEnvBuilder):
@pytest.mark.parametrize("deletion_api", [DeletionAPIKind.OLD, DeletionAPIKind.FORCE])
def test_ps_unavailable_after_delete(
neon_env_builder: NeonEnvBuilder, deletion_api: DeletionAPIKind
):
neon_env_builder.num_pageservers = 3
env = neon_env_builder.init_start()
@@ -3265,10 +3304,16 @@ def test_ps_unavailable_after_delete(neon_env_builder: NeonEnvBuilder):
assert_nodes_count(3)
ps = env.pageservers[0]
env.storage_controller.node_delete_old(ps.id)
# After deletion, the node count must be reduced
assert_nodes_count(2)
if deletion_api == DeletionAPIKind.FORCE:
ps.allowed_errors.append(".*request was dropped before completing.*")
env.storage_controller.node_delete(ps.id, force=True)
wait_until(lambda: assert_nodes_count(2))
elif deletion_api == DeletionAPIKind.OLD:
env.storage_controller.node_delete_old(ps.id)
assert_nodes_count(2)
else:
raise AssertionError(f"Invalid deletion API: {deletion_api}")
# Running pageserver CLI init in a separate thread
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
@@ -4814,3 +4859,103 @@ def test_storage_controller_migrate_with_pageserver_restart(
"shards": [{"node_id": int(secondary.id), "shard_number": 0}],
"preferred_az": DEFAULT_AZ_ID,
}
@run_only_on_default_postgres("PG version is not important for this test")
def test_storage_controller_forward_404(neon_env_builder: NeonEnvBuilder):
"""
Ensures that the storage controller correctly forwards 404s and converts some of them
into 503s before forwarding to the client.
"""
neon_env_builder.num_pageservers = 2
neon_env_builder.num_azs = 2
env = neon_env_builder.init_start()
env.storage_controller.allowed_errors.append(".*Reconcile error.*")
env.storage_controller.allowed_errors.append(".*Timed out.*")
env.storage_controller.tenant_policy_update(env.initial_tenant, {"placement": {"Attached": 1}})
env.storage_controller.reconcile_until_idle()
# 404s on tenants and timelines are forwarded as-is when the reconciler is not running.
# Access a non-existing timeline -> 404
with pytest.raises(PageserverApiException) as e:
env.storage_controller.pageserver_api().timeline_detail(
env.initial_tenant, TimelineId.generate()
)
assert e.value.status_code == 404
with pytest.raises(PageserverApiException) as e:
env.storage_controller.pageserver_api().timeline_lsn_lease(
env.initial_tenant, TimelineId.generate(), Lsn(0)
)
assert e.value.status_code == 404
# Access a non-existing tenant when the reconciler is not running -> 404
with pytest.raises(PageserverApiException) as e:
env.storage_controller.pageserver_api().timeline_detail(
TenantId.generate(), env.initial_timeline
)
assert e.value.status_code == 404
with pytest.raises(PageserverApiException) as e:
env.storage_controller.pageserver_api().timeline_lsn_lease(
TenantId.generate(), env.initial_timeline, Lsn(0)
)
assert e.value.status_code == 404
# Normal requests should succeed
detail = env.storage_controller.pageserver_api().timeline_detail(
env.initial_tenant, env.initial_timeline
)
last_record_lsn = Lsn(detail["last_record_lsn"])
env.storage_controller.pageserver_api().timeline_lsn_lease(
env.initial_tenant, env.initial_timeline, last_record_lsn
)
# Get into a situation where the intent state is not the same as the observed state.
describe = env.storage_controller.tenant_describe(env.initial_tenant)["shards"][0]
current_primary = describe["node_attached"]
current_secondary = describe["node_secondary"][0]
assert current_primary != current_secondary
# Pause the reconciler so that the generation number won't be updated.
env.storage_controller.configure_failpoints(
("reconciler-live-migrate-post-generation-inc", "pause")
)
# Do the migration in another thread; the request will be dropped as we don't wait.
shard_zero = TenantShardId(env.initial_tenant, 0, 0)
concurrent.futures.ThreadPoolExecutor(max_workers=1).submit(
env.storage_controller.tenant_shard_migrate,
shard_zero,
current_secondary,
StorageControllerMigrationConfig(override_scheduler=True),
)
# Not the best way to do this; ideally we would wait until the migration has started.
time.sleep(1)
placement = env.storage_controller.get_tenants_placement()[str(shard_zero)]
assert placement["observed"] != placement["intent"]
assert placement["observed"]["attached"] == current_primary
assert placement["intent"]["attached"] == current_secondary
# Now we issue requests that would cause 404 again
retry_strategy = Retry(total=0)
adapter = HTTPAdapter(max_retries=retry_strategy)
no_retry_api = env.storage_controller.pageserver_api()
no_retry_api.mount("http://", adapter)
no_retry_api.mount("https://", adapter)
# As intent state != observed state, tenant not found error should return 503,
# so that the client can retry once we've successfully migrated.
with pytest.raises(PageserverApiException) as e:
no_retry_api.timeline_detail(env.initial_tenant, TimelineId.generate())
assert e.value.status_code == 503, f"unexpected status code and error: {e.value}"
with pytest.raises(PageserverApiException) as e:
no_retry_api.timeline_lsn_lease(env.initial_tenant, TimelineId.generate(), Lsn(0))
assert e.value.status_code == 503, f"unexpected status code and error: {e.value}"
# Unblock reconcile operations
env.storage_controller.configure_failpoints(
("reconciler-live-migrate-post-generation-inc", "off")
)


@@ -1,18 +1,18 @@
{
"v17": [
"17.5",
"eac5279cd147d4086e0eb242198aae2f4b766d7b"
"a50d80c7507e8ae9fc37bf1869051cf2d51370ab"
],
"v16": [
"16.9",
"51194dc5ce2e3523068d8607852e6c3125a17e58"
"e9db1ff5a6f3ca18f626ba3d62ab475e6c688a96"
],
"v15": [
"15.13",
"24313bf8f3de722968a2fdf764de7ef77ed64f06"
"cef72d5308ddce3795a9043fcd94f8849f7f4800"
],
"v14": [
"14.18",
"ac3c460e01a31f11fb52fd8d8e88e60f0e1069b4"
"47304b921555b3f33eb3b49daada3078e774cfd7"
]
}


@@ -107,7 +107,6 @@ tracing-core = { version = "0.1" }
tracing-log = { version = "0.2" }
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
url = { version = "2", features = ["serde"] }
uuid = { version = "1", features = ["serde", "v4", "v7"] }
zeroize = { version = "1", features = ["derive", "serde"] }
zstd = { version = "0.13" }
zstd-safe = { version = "7", default-features = false, features = ["arrays", "legacy", "std", "zdict_builder"] }