mirror of
https://github.com/neondatabase/neon.git
synced 2026-05-18 13:40:37 +00:00
Compare commits
19 Commits
erik/alway
...
conrad/ref
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
26be13067c | ||
|
|
791b5d736b | ||
|
|
96bcfba79e | ||
|
|
8e95455aef | ||
|
|
f3ef60d236 | ||
|
|
8f627ea0ab | ||
|
|
6a353c33e3 | ||
|
|
64d0008389 | ||
|
|
53a05e8ccb | ||
|
|
62c0152e6b | ||
|
|
7fef4435c1 | ||
|
|
43fd5b218b | ||
|
|
29ee273d78 | ||
|
|
8b0f2efa57 | ||
|
|
b309cbc6e9 | ||
|
|
f0c0733a64 | ||
|
|
8862e7c4bf | ||
|
|
b7fc5a2fe0 | ||
|
|
4559ba79b6 |
108
Cargo.lock
generated
108
Cargo.lock
generated
@@ -1872,6 +1872,7 @@ dependencies = [
|
||||
"diesel_derives",
|
||||
"itoa",
|
||||
"serde_json",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2533,6 +2534,18 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"r-efi",
|
||||
"wasi 0.14.2+wasi-0.2.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gettid"
|
||||
version = "0.1.3"
|
||||
@@ -3606,9 +3619,9 @@ checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104"
|
||||
|
||||
[[package]]
|
||||
name = "lock_api"
|
||||
version = "0.4.10"
|
||||
version = "0.4.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16"
|
||||
checksum = "96936507f153605bddfcda068dd804796c84324ed2510809e5b2a624c81da765"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"scopeguard",
|
||||
@@ -3758,7 +3771,7 @@ dependencies = [
|
||||
"procfs",
|
||||
"prometheus",
|
||||
"rand 0.8.5",
|
||||
"rand_distr",
|
||||
"rand_distr 0.4.3",
|
||||
"twox-hash",
|
||||
]
|
||||
|
||||
@@ -3846,7 +3859,12 @@ checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
|
||||
name = "neon-shmem"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"lock_api",
|
||||
"nix 0.30.1",
|
||||
"rand 0.9.1",
|
||||
"rand_distr 0.5.1",
|
||||
"rustc-hash 2.1.1",
|
||||
"tempfile",
|
||||
"thiserror 1.0.69",
|
||||
"workspace_hack",
|
||||
@@ -5347,7 +5365,7 @@ dependencies = [
|
||||
"postgres_backend",
|
||||
"pq_proto",
|
||||
"rand 0.8.5",
|
||||
"rand_distr",
|
||||
"rand_distr 0.4.3",
|
||||
"rcgen",
|
||||
"redis",
|
||||
"regex",
|
||||
@@ -5358,7 +5376,7 @@ dependencies = [
|
||||
"reqwest-tracing",
|
||||
"rsa",
|
||||
"rstest",
|
||||
"rustc-hash 1.1.0",
|
||||
"rustc-hash 2.1.1",
|
||||
"rustls 0.23.27",
|
||||
"rustls-native-certs 0.8.0",
|
||||
"rustls-pemfile 2.1.1",
|
||||
@@ -5388,6 +5406,7 @@ dependencies = [
|
||||
"tracing-test",
|
||||
"tracing-utils",
|
||||
"try-lock",
|
||||
"type-safe-id",
|
||||
"typed-json",
|
||||
"url",
|
||||
"urlencoding",
|
||||
@@ -5451,6 +5470,12 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "r-efi"
|
||||
version = "5.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f"
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.7.3"
|
||||
@@ -5475,6 +5500,16 @@ dependencies = [
|
||||
"rand_core 0.6.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.9.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
|
||||
dependencies = [
|
||||
"rand_chacha 0.9.0",
|
||||
"rand_core 0.9.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.2.2"
|
||||
@@ -5495,6 +5530,16 @@ dependencies = [
|
||||
"rand_core 0.6.4",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
|
||||
dependencies = [
|
||||
"ppv-lite86",
|
||||
"rand_core 0.9.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.5.1"
|
||||
@@ -5513,6 +5558,15 @@ dependencies = [
|
||||
"getrandom 0.2.11",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.9.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
|
||||
dependencies = [
|
||||
"getrandom 0.3.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_distr"
|
||||
version = "0.4.3"
|
||||
@@ -5523,6 +5577,16 @@ dependencies = [
|
||||
"rand 0.8.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_distr"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
"rand 0.9.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_hc"
|
||||
version = "0.2.0"
|
||||
@@ -6933,6 +6997,7 @@ dependencies = [
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"utils",
|
||||
"uuid",
|
||||
"workspace_hack",
|
||||
]
|
||||
|
||||
@@ -8023,6 +8088,19 @@ dependencies = [
|
||||
"static_assertions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "type-safe-id"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dd9267f90719e0433aae095640b294ff36ccbf89649ecb9ee34464ec504be157"
|
||||
dependencies = [
|
||||
"arrayvec",
|
||||
"rand 0.9.1",
|
||||
"serde",
|
||||
"thiserror 2.0.11",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typed-json"
|
||||
version = "0.1.1"
|
||||
@@ -8206,6 +8284,7 @@ dependencies = [
|
||||
"tracing-error",
|
||||
"tracing-subscriber",
|
||||
"tracing-utils",
|
||||
"uuid",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
@@ -8348,6 +8427,15 @@ version = "0.11.0+wasi-snapshot-preview1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.14.2+wasi-0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3"
|
||||
dependencies = [
|
||||
"wit-bindgen-rt",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasite"
|
||||
version = "0.1.0"
|
||||
@@ -8705,6 +8793,15 @@ dependencies = [
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wit-bindgen-rt"
|
||||
version = "0.39.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1"
|
||||
dependencies = [
|
||||
"bitflags 2.8.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "workspace_hack"
|
||||
version = "0.1.0"
|
||||
@@ -8807,7 +8904,6 @@ dependencies = [
|
||||
"tracing-log",
|
||||
"tracing-subscriber",
|
||||
"url",
|
||||
"uuid",
|
||||
"zeroize",
|
||||
"zstd",
|
||||
"zstd-safe",
|
||||
|
||||
@@ -130,6 +130,7 @@ jemalloc_pprof = { version = "0.7", features = ["symbolize", "flamegraph"] }
|
||||
jsonwebtoken = "9"
|
||||
lasso = "0.7"
|
||||
libc = "0.2"
|
||||
lock_api = "0.4.13"
|
||||
md5 = "0.7.0"
|
||||
measured = { version = "0.0.22", features=["lasso"] }
|
||||
measured-process = { version = "0.0.22" }
|
||||
@@ -165,7 +166,7 @@ reqwest-middleware = "0.4"
|
||||
reqwest-retry = "0.7"
|
||||
routerify = "3"
|
||||
rpds = "0.13"
|
||||
rustc-hash = "1.1.0"
|
||||
rustc-hash = "2.1.1"
|
||||
rustls = { version = "0.23.16", default-features = false }
|
||||
rustls-pemfile = "2"
|
||||
rustls-pki-types = "1.11"
|
||||
|
||||
@@ -2450,14 +2450,31 @@ LIMIT 100",
|
||||
pub fn spawn_lfc_offload_task(self: &Arc<Self>, interval: Duration) {
|
||||
self.terminate_lfc_offload_task();
|
||||
let secs = interval.as_secs();
|
||||
info!("spawning lfc offload worker with {secs}s interval");
|
||||
let this = self.clone();
|
||||
|
||||
info!("spawning LFC offload worker with {secs}s interval");
|
||||
let handle = spawn(async move {
|
||||
let mut interval = time::interval(interval);
|
||||
interval.tick().await; // returns immediately
|
||||
loop {
|
||||
interval.tick().await;
|
||||
this.offload_lfc_async().await;
|
||||
|
||||
let prewarm_state = this.state.lock().unwrap().lfc_prewarm_state.clone();
|
||||
// Do not offload LFC state if we are currently prewarming or if any issue occurred.
|
||||
// If we did that, we might overwrite the LFC state in endpoint storage with some
|
||||
// incomplete state. Imagine a situation:
|
||||
// 1. Endpoint started with `autoprewarm: true`
|
||||
// 2. While prewarming is not completed, we upload the new incomplete state
|
||||
// 3. Compute gets interrupted and restarts
|
||||
// 4. We start again and try to prewarm with the state from 2. instead of the previous complete state
|
||||
if matches!(
|
||||
prewarm_state,
|
||||
LfcPrewarmState::Completed
|
||||
| LfcPrewarmState::NotPrewarmed
|
||||
| LfcPrewarmState::Skipped
|
||||
) {
|
||||
this.offload_lfc_async().await;
|
||||
}
|
||||
}
|
||||
});
|
||||
*self.lfc_offload_task.lock().unwrap() = Some(handle);
|
||||
|
||||
@@ -89,7 +89,7 @@ impl ComputeNode {
|
||||
self.state.lock().unwrap().lfc_offload_state.clone()
|
||||
}
|
||||
|
||||
/// If there is a prewarm request ongoing, return false, true otherwise
|
||||
/// If there is a prewarm request ongoing, return `false`, `true` otherwise.
|
||||
pub fn prewarm_lfc(self: &Arc<Self>, from_endpoint: Option<String>) -> bool {
|
||||
{
|
||||
let state = &mut self.state.lock().unwrap().lfc_prewarm_state;
|
||||
@@ -101,15 +101,25 @@ impl ComputeNode {
|
||||
|
||||
let cloned = self.clone();
|
||||
spawn(async move {
|
||||
let Err(err) = cloned.prewarm_impl(from_endpoint).await else {
|
||||
cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Completed;
|
||||
return;
|
||||
};
|
||||
crate::metrics::LFC_PREWARM_ERRORS.inc();
|
||||
error!(%err, "prewarming lfc");
|
||||
cloned.state.lock().unwrap().lfc_prewarm_state = LfcPrewarmState::Failed {
|
||||
error: err.to_string(),
|
||||
let state = match cloned.prewarm_impl(from_endpoint).await {
|
||||
Ok(true) => LfcPrewarmState::Completed,
|
||||
Ok(false) => {
|
||||
info!(
|
||||
"skipping LFC prewarm because LFC state is not found in endpoint storage"
|
||||
);
|
||||
LfcPrewarmState::Skipped
|
||||
}
|
||||
Err(err) => {
|
||||
crate::metrics::LFC_PREWARM_ERRORS.inc();
|
||||
error!(%err, "could not prewarm LFC");
|
||||
|
||||
LfcPrewarmState::Failed {
|
||||
error: err.to_string(),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
cloned.state.lock().unwrap().lfc_prewarm_state = state;
|
||||
});
|
||||
true
|
||||
}
|
||||
@@ -120,15 +130,21 @@ impl ComputeNode {
|
||||
EndpointStoragePair::from_spec_and_endpoint(state.pspec.as_ref().unwrap(), from_endpoint)
|
||||
}
|
||||
|
||||
async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<()> {
|
||||
/// Request LFC state from endpoint storage and load corresponding pages into Postgres.
|
||||
/// Returns a result with `false` if the LFC state is not found in endpoint storage.
|
||||
async fn prewarm_impl(&self, from_endpoint: Option<String>) -> Result<bool> {
|
||||
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(from_endpoint)?;
|
||||
info!(%url, "requesting LFC state from endpoint storage");
|
||||
|
||||
info!(%url, "requesting LFC state from endpoint storage");
|
||||
let request = Client::new().get(&url).bearer_auth(token);
|
||||
let res = request.send().await.context("querying endpoint storage")?;
|
||||
let status = res.status();
|
||||
if status != StatusCode::OK {
|
||||
bail!("{status} querying endpoint storage")
|
||||
match status {
|
||||
StatusCode::OK => (),
|
||||
StatusCode::NOT_FOUND => {
|
||||
return Ok(false);
|
||||
}
|
||||
_ => bail!("{status} querying endpoint storage"),
|
||||
}
|
||||
|
||||
let mut uncompressed = Vec::new();
|
||||
@@ -141,7 +157,8 @@ impl ComputeNode {
|
||||
.await
|
||||
.context("decoding LFC state")?;
|
||||
let uncompressed_len = uncompressed.len();
|
||||
info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}, loading into postgres");
|
||||
|
||||
info!(%url, "downloaded LFC state, uncompressed size {uncompressed_len}, loading into Postgres");
|
||||
|
||||
ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
|
||||
.await
|
||||
@@ -149,7 +166,9 @@ impl ComputeNode {
|
||||
.query_one("select neon.prewarm_local_cache($1)", &[&uncompressed])
|
||||
.await
|
||||
.context("loading LFC state into postgres")
|
||||
.map(|_| ())
|
||||
.map(|_| ())?;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
/// If there is an offload request ongoing, return `false`, `true` otherwise.
|
||||
@@ -177,12 +196,14 @@ impl ComputeNode {
|
||||
|
||||
async fn offload_lfc_with_state_update(&self) {
|
||||
crate::metrics::LFC_OFFLOADS.inc();
|
||||
|
||||
let Err(err) = self.offload_lfc_impl().await else {
|
||||
self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Completed;
|
||||
return;
|
||||
};
|
||||
|
||||
crate::metrics::LFC_OFFLOAD_ERRORS.inc();
|
||||
error!(%err, "offloading lfc");
|
||||
error!(%err, "could not offload LFC state to endpoint storage");
|
||||
self.state.lock().unwrap().lfc_offload_state = LfcOffloadState::Failed {
|
||||
error: err.to_string(),
|
||||
};
|
||||
@@ -190,7 +211,7 @@ impl ComputeNode {
|
||||
|
||||
async fn offload_lfc_impl(&self) -> Result<()> {
|
||||
let EndpointStoragePair { url, token } = self.endpoint_storage_pair(None)?;
|
||||
info!(%url, "requesting LFC state from postgres");
|
||||
info!(%url, "requesting LFC state from Postgres");
|
||||
|
||||
let mut compressed = Vec::new();
|
||||
ComputeNode::get_maintenance_client(&self.tokio_conn_conf)
|
||||
@@ -205,13 +226,17 @@ impl ComputeNode {
|
||||
.read_to_end(&mut compressed)
|
||||
.await
|
||||
.context("compressing LFC state")?;
|
||||
|
||||
let compressed_len = compressed.len();
|
||||
info!(%url, "downloaded LFC state, compressed size {compressed_len}, writing to endpoint storage");
|
||||
|
||||
let request = Client::new().put(url).bearer_auth(token).body(compressed);
|
||||
match request.send().await {
|
||||
Ok(res) if res.status() == StatusCode::OK => Ok(()),
|
||||
Ok(res) => bail!("Error writing to endpoint storage: {}", res.status()),
|
||||
Ok(res) => bail!(
|
||||
"Request to endpoint storage failed with status: {}",
|
||||
res.status()
|
||||
),
|
||||
Err(err) => Err(err).context("writing to endpoint storage"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,14 +56,15 @@ pub fn write_postgres_conf(
|
||||
writeln!(file, "{conf}")?;
|
||||
}
|
||||
|
||||
// Stripe size GUC should be defined prior to connection string
|
||||
if let Some(stripe_size) = spec.shard_stripe_size {
|
||||
writeln!(file, "neon.stripe_size={stripe_size}")?;
|
||||
}
|
||||
// Add options for connecting to storage
|
||||
writeln!(file, "# Neon storage settings")?;
|
||||
if let Some(s) = &spec.pageserver_connstring {
|
||||
writeln!(file, "neon.pageserver_connstring={}", escape_conf_value(s))?;
|
||||
}
|
||||
if let Some(stripe_size) = spec.shard_stripe_size {
|
||||
writeln!(file, "neon.stripe_size={stripe_size}")?;
|
||||
}
|
||||
if !spec.safekeeper_connstrings.is_empty() {
|
||||
let mut neon_safekeepers_value = String::new();
|
||||
tracing::info!(
|
||||
|
||||
@@ -613,11 +613,11 @@ components:
|
||||
- skipped
|
||||
properties:
|
||||
status:
|
||||
description: Lfc prewarm status
|
||||
enum: [not_prewarmed, prewarming, completed, failed]
|
||||
description: LFC prewarm status
|
||||
enum: [not_prewarmed, prewarming, completed, failed, skipped]
|
||||
type: string
|
||||
error:
|
||||
description: Lfc prewarm error, if any
|
||||
description: LFC prewarm error, if any
|
||||
type: string
|
||||
total:
|
||||
description: Total pages processed
|
||||
@@ -635,11 +635,11 @@ components:
|
||||
- status
|
||||
properties:
|
||||
status:
|
||||
description: Lfc offload status
|
||||
description: LFC offload status
|
||||
enum: [not_offloaded, offloading, completed, failed]
|
||||
type: string
|
||||
error:
|
||||
description: Lfc offload error, if any
|
||||
description: LFC offload error, if any
|
||||
type: string
|
||||
|
||||
PromoteState:
|
||||
|
||||
@@ -8,10 +8,10 @@ code changes locally, but not suitable for running production systems.
|
||||
|
||||
## Example: Start with Postgres 16
|
||||
|
||||
To create and start a local development environment with Postgres 16, you will need to provide `--pg-version` flag to 3 of the start-up commands.
|
||||
To create and start a local development environment with Postgres 16, you will need to provide the `--pg-version` flag to 2 of the start-up commands.
|
||||
|
||||
```shell
|
||||
cargo neon init --pg-version 16
|
||||
cargo neon init
|
||||
cargo neon start
|
||||
cargo neon tenant create --set-default --pg-version 16
|
||||
cargo neon endpoint create main --pg-version 16
|
||||
|
||||
@@ -76,6 +76,12 @@ enum Command {
|
||||
NodeStartDelete {
|
||||
#[arg(long)]
|
||||
node_id: NodeId,
|
||||
/// When `force` is true, skip waiting for shards to prewarm during migration.
|
||||
/// This can significantly speed up node deletion since prewarming all shards
|
||||
/// can take considerable time, but may result in slower initial access to
|
||||
/// migrated shards until they warm up naturally.
|
||||
#[arg(long)]
|
||||
force: bool,
|
||||
},
|
||||
/// Cancel deletion of the specified pageserver and wait for `timeout`
|
||||
/// for the operation to be canceled. May be retried.
|
||||
@@ -952,13 +958,14 @@ async fn main() -> anyhow::Result<()> {
|
||||
.dispatch::<(), ()>(Method::DELETE, format!("control/v1/node/{node_id}"), None)
|
||||
.await?;
|
||||
}
|
||||
Command::NodeStartDelete { node_id } => {
|
||||
Command::NodeStartDelete { node_id, force } => {
|
||||
let query = if force {
|
||||
format!("control/v1/node/{node_id}/delete?force=true")
|
||||
} else {
|
||||
format!("control/v1/node/{node_id}/delete")
|
||||
};
|
||||
storcon_client
|
||||
.dispatch::<(), ()>(
|
||||
Method::PUT,
|
||||
format!("control/v1/node/{node_id}/delete"),
|
||||
None,
|
||||
)
|
||||
.dispatch::<(), ()>(Method::PUT, query, None)
|
||||
.await?;
|
||||
println!("Delete started for {node_id}");
|
||||
}
|
||||
|
||||
@@ -46,16 +46,33 @@ pub struct ExtensionInstallResponse {
|
||||
pub version: ExtVersion,
|
||||
}
|
||||
|
||||
/// Status of the LFC prewarm process. The same state machine is reused for
|
||||
/// both autoprewarm (prewarm after compute/Postgres start using the previously
|
||||
/// stored LFC state) and explicit prewarming via API.
|
||||
#[derive(Serialize, Default, Debug, Clone, PartialEq)]
|
||||
#[serde(tag = "status", rename_all = "snake_case")]
|
||||
pub enum LfcPrewarmState {
|
||||
/// Default value when compute boots up.
|
||||
#[default]
|
||||
NotPrewarmed,
|
||||
/// Prewarming thread is active and loading pages into LFC.
|
||||
Prewarming,
|
||||
/// We found requested LFC state in the endpoint storage and
|
||||
/// completed prewarming successfully.
|
||||
Completed,
|
||||
Failed {
|
||||
error: String,
|
||||
},
|
||||
/// An unexpected error happened during prewarming. Note that a `Not Found 404`
|
||||
/// response from the endpoint storage is explicitly excluded here
|
||||
/// because it can normally happen on the first compute start,
|
||||
/// since LFC state is not available yet.
|
||||
Failed { error: String },
|
||||
/// We tried to fetch the corresponding LFC state from the endpoint storage,
|
||||
/// but received `Not Found 404`. This should normally happen only during the
|
||||
/// first endpoint start after creation with `autoprewarm: true`.
|
||||
///
|
||||
/// During the orchestrated prewarm via API, when a caller explicitly
|
||||
/// provides the LFC state key to prewarm from, it's the caller's responsibility
|
||||
/// to handle this status as an error state in this case.
|
||||
Skipped,
|
||||
}
|
||||
|
||||
impl Display for LfcPrewarmState {
|
||||
@@ -64,6 +81,7 @@ impl Display for LfcPrewarmState {
|
||||
LfcPrewarmState::NotPrewarmed => f.write_str("NotPrewarmed"),
|
||||
LfcPrewarmState::Prewarming => f.write_str("Prewarming"),
|
||||
LfcPrewarmState::Completed => f.write_str("Completed"),
|
||||
LfcPrewarmState::Skipped => f.write_str("Skipped"),
|
||||
LfcPrewarmState::Failed { error } => write!(f, "Error({error})"),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,12 +4,14 @@
|
||||
//! a default registry.
|
||||
#![deny(clippy::undocumented_unsafe_blocks)]
|
||||
|
||||
use std::sync::RwLock;
|
||||
|
||||
use measured::label::{LabelGroupSet, LabelGroupVisitor, LabelName, NoLabels};
|
||||
use measured::metric::counter::CounterState;
|
||||
use measured::metric::gauge::GaugeState;
|
||||
use measured::metric::group::Encoding;
|
||||
use measured::metric::name::{MetricName, MetricNameEncoder};
|
||||
use measured::metric::{MetricEncoding, MetricFamilyEncoding};
|
||||
use measured::metric::{MetricEncoding, MetricFamilyEncoding, MetricType};
|
||||
use measured::{FixedCardinalityLabel, LabelGroup, MetricGroup};
|
||||
use once_cell::sync::Lazy;
|
||||
use prometheus::Registry;
|
||||
@@ -116,12 +118,52 @@ pub fn pow2_buckets(start: usize, end: usize) -> Vec<f64> {
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub struct InfoMetric<L: LabelGroup, M: MetricType = GaugeState> {
|
||||
label: RwLock<L>,
|
||||
metric: M,
|
||||
}
|
||||
|
||||
impl<L: LabelGroup> InfoMetric<L> {
|
||||
pub fn new(label: L) -> Self {
|
||||
Self::with_metric(label, GaugeState::new(1))
|
||||
}
|
||||
}
|
||||
|
||||
impl<L: LabelGroup, M: MetricType<Metadata = ()>> InfoMetric<L, M> {
|
||||
pub fn with_metric(label: L, metric: M) -> Self {
|
||||
Self {
|
||||
label: RwLock::new(label),
|
||||
metric,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_label(&self, label: L) {
|
||||
*self.label.write().unwrap() = label;
|
||||
}
|
||||
}
|
||||
|
||||
impl<L, M, E> MetricFamilyEncoding<E> for InfoMetric<L, M>
|
||||
where
|
||||
L: LabelGroup,
|
||||
M: MetricEncoding<E, Metadata = ()>,
|
||||
E: Encoding,
|
||||
{
|
||||
fn collect_family_into(
|
||||
&self,
|
||||
name: impl measured::metric::name::MetricNameEncoder,
|
||||
enc: &mut E,
|
||||
) -> Result<(), E::Err> {
|
||||
M::write_type(&name, enc)?;
|
||||
self.metric
|
||||
.collect_into(&(), &*self.label.read().unwrap(), name, enc)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct BuildInfo {
|
||||
pub revision: &'static str,
|
||||
pub build_tag: &'static str,
|
||||
}
|
||||
|
||||
// todo: allow label group without the set
|
||||
impl LabelGroup for BuildInfo {
|
||||
fn visit_values(&self, v: &mut impl LabelGroupVisitor) {
|
||||
const REVISION: &LabelName = LabelName::from_str("revision");
|
||||
@@ -131,24 +173,6 @@ impl LabelGroup for BuildInfo {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Encoding> MetricFamilyEncoding<T> for BuildInfo
|
||||
where
|
||||
GaugeState: MetricEncoding<T>,
|
||||
{
|
||||
fn collect_family_into(
|
||||
&self,
|
||||
name: impl measured::metric::name::MetricNameEncoder,
|
||||
enc: &mut T,
|
||||
) -> Result<(), T::Err> {
|
||||
enc.write_help(&name, "Build/version information")?;
|
||||
GaugeState::write_type(&name, enc)?;
|
||||
GaugeState {
|
||||
count: std::sync::atomic::AtomicI64::new(1),
|
||||
}
|
||||
.collect_into(&(), self, name, enc)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(MetricGroup)]
|
||||
#[metric(new(build_info: BuildInfo))]
|
||||
pub struct NeonMetrics {
|
||||
@@ -165,8 +189,8 @@ pub struct NeonMetrics {
|
||||
#[derive(MetricGroup)]
|
||||
#[metric(new(build_info: BuildInfo))]
|
||||
pub struct LibMetrics {
|
||||
#[metric(init = build_info)]
|
||||
build_info: BuildInfo,
|
||||
#[metric(init = InfoMetric::new(build_info))]
|
||||
build_info: InfoMetric<BuildInfo>,
|
||||
|
||||
#[metric(flatten)]
|
||||
rusage: Rusage,
|
||||
|
||||
@@ -8,6 +8,13 @@ license.workspace = true
|
||||
thiserror.workspace = true
|
||||
nix.workspace=true
|
||||
workspace_hack = { version = "0.1", path = "../../workspace_hack" }
|
||||
libc.workspace = true
|
||||
lock_api.workspace = true
|
||||
rustc-hash.workspace = true
|
||||
|
||||
[target.'cfg(target_os = "macos")'.dependencies]
|
||||
tempfile = "3.14.0"
|
||||
|
||||
[dev-dependencies]
|
||||
rand = "0.9"
|
||||
rand_distr = "0.5.1"
|
||||
|
||||
583
libs/neon-shmem/src/hash.rs
Normal file
583
libs/neon-shmem/src/hash.rs
Normal file
@@ -0,0 +1,583 @@
|
||||
//! Resizable hash table implementation on top of byte-level storage (either a [`ShmemHandle`] or a fixed byte array).
|
||||
//!
|
||||
//! This hash table has two major components: the bucket array and the dictionary. Each bucket within the
|
||||
//! bucket array contains an `Option<(K, V)>` and an index of another bucket. In this way there is both an
|
||||
//! implicit freelist within the bucket array (`None` buckets point to other `None` entries) and various hash
|
||||
//! chains within the bucket array (a `Some` bucket will point to other `Some` buckets that had the same hash).
|
||||
//!
|
||||
//! Buckets are never moved unless they are within a region that is being shrunk, and so the actual hash-
|
||||
//! dependent component is done with the dictionary. When a new key is inserted into the map, a position
|
||||
//! within the dictionary is decided based on its hash, the data is inserted into an empty bucket based
|
||||
//! off of the freelist, and then the index of said bucket is placed in the dictionary.
|
||||
//!
|
||||
//! This map is resizable (if initialized on top of a [`ShmemHandle`]). Both growing and shrinking happen
|
||||
//! in-place and are at a high level achieved by expanding/reducing the bucket array and rebuilding the
|
||||
//! dictionary by rehashing all keys.
|
||||
//!
|
||||
//! Concurrency is managed very simply: the entire map is guarded by one shared-memory RwLock.
|
||||
|
||||
use std::hash::{BuildHasher, Hash};
|
||||
use std::mem::MaybeUninit;
|
||||
|
||||
use crate::shmem::ShmemHandle;
|
||||
use crate::{shmem, sync::*};
|
||||
|
||||
mod core;
|
||||
pub mod entry;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use core::{Bucket, CoreHashMap, INVALID_POS};
|
||||
use entry::{Entry, OccupiedEntry, PrevPos, VacantEntry};
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// Error type for a hashmap shrink operation.
|
||||
#[derive(Error, Debug)]
|
||||
pub enum HashMapShrinkError {
|
||||
/// There was an error encountered while resizing the memory area.
|
||||
#[error("shmem resize failed: {0}")]
|
||||
ResizeError(shmem::Error),
|
||||
/// Occupied entries in to-be-shrunk space were encountered beginning at the given index.
|
||||
#[error("occupied entry in deallocated space found at {0}")]
|
||||
RemainingEntries(usize),
|
||||
}
|
||||
|
||||
/// This represents a hash table that (possibly) lives in shared memory.
|
||||
/// If a new process is launched with fork(), the child process inherits
|
||||
/// this struct.
|
||||
#[must_use]
|
||||
pub struct HashMapInit<'a, K, V, S = rustc_hash::FxBuildHasher> {
|
||||
shmem_handle: Option<ShmemHandle>,
|
||||
shared_ptr: *mut HashMapShared<'a, K, V>,
|
||||
shared_size: usize,
|
||||
hasher: S,
|
||||
num_buckets: u32,
|
||||
}
|
||||
|
||||
/// This is a per-process handle to a hash table that (possibly) lives in shared memory.
|
||||
/// If a child process is launched with fork(), the child process should
|
||||
/// get its own HashMapAccess by calling HashMapInit::attach_writer/reader().
|
||||
///
|
||||
/// XXX: We're not making use of it at the moment, but this struct could
|
||||
/// hold process-local information in the future.
|
||||
pub struct HashMapAccess<'a, K, V, S = rustc_hash::FxBuildHasher> {
|
||||
shmem_handle: Option<ShmemHandle>,
|
||||
shared_ptr: *mut HashMapShared<'a, K, V>,
|
||||
hasher: S,
|
||||
}
|
||||
|
||||
unsafe impl<K: Sync, V: Sync, S> Sync for HashMapAccess<'_, K, V, S> {}
|
||||
unsafe impl<K: Send, V: Send, S> Send for HashMapAccess<'_, K, V, S> {}
|
||||
|
||||
impl<'a, K: Clone + Hash + Eq, V, S> HashMapInit<'a, K, V, S> {
|
||||
/// Change the 'hasher' used by the hash table.
|
||||
///
|
||||
/// NOTE: This must be called right after creating the hash table,
|
||||
/// before inserting any entries and before calling attach_writer/reader.
|
||||
/// Otherwise different accessors could be using different hash functions,
|
||||
/// with confusing results.
|
||||
pub fn with_hasher<T: BuildHasher>(self, hasher: T) -> HashMapInit<'a, K, V, T> {
|
||||
HashMapInit {
|
||||
hasher,
|
||||
shmem_handle: self.shmem_handle,
|
||||
shared_ptr: self.shared_ptr,
|
||||
shared_size: self.shared_size,
|
||||
num_buckets: self.num_buckets,
|
||||
}
|
||||
}
|
||||
|
||||
/// Loosely (over)estimate the size needed to store a hash table with `num_buckets` buckets.
|
||||
pub fn estimate_size(num_buckets: u32) -> usize {
|
||||
// add some margin to cover alignment etc.
|
||||
CoreHashMap::<K, V>::estimate_size(num_buckets) + size_of::<HashMapShared<K, V>>() + 1000
|
||||
}
|
||||
|
||||
fn new(
|
||||
num_buckets: u32,
|
||||
shmem_handle: Option<ShmemHandle>,
|
||||
area_ptr: *mut u8,
|
||||
area_size: usize,
|
||||
hasher: S,
|
||||
) -> Self {
|
||||
let mut ptr: *mut u8 = area_ptr;
|
||||
let end_ptr: *mut u8 = unsafe { ptr.add(area_size) };
|
||||
|
||||
// carve out area for the One Big Lock (TM) and the HashMapShared.
|
||||
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<libc::pthread_rwlock_t>())) };
|
||||
let raw_lock_ptr = ptr;
|
||||
ptr = unsafe { ptr.add(size_of::<libc::pthread_rwlock_t>()) };
|
||||
ptr = unsafe { ptr.add(ptr.align_offset(align_of::<HashMapShared<K, V>>())) };
|
||||
let shared_ptr: *mut HashMapShared<K, V> = ptr.cast();
|
||||
ptr = unsafe { ptr.add(size_of::<HashMapShared<K, V>>()) };
|
||||
|
||||
// carve out the buckets
|
||||
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<core::Bucket<K, V>>())) };
|
||||
let buckets_ptr = ptr;
|
||||
ptr = unsafe { ptr.add(size_of::<core::Bucket<K, V>>() * num_buckets as usize) };
|
||||
|
||||
// use remaining space for the dictionary
|
||||
ptr = unsafe { ptr.byte_add(ptr.align_offset(align_of::<u32>())) };
|
||||
assert!(ptr.addr() < end_ptr.addr());
|
||||
let dictionary_ptr = ptr;
|
||||
let dictionary_size = unsafe { end_ptr.byte_offset_from(ptr) / size_of::<u32>() as isize };
|
||||
assert!(dictionary_size > 0);
|
||||
|
||||
let buckets =
|
||||
unsafe { std::slice::from_raw_parts_mut(buckets_ptr.cast(), num_buckets as usize) };
|
||||
let dictionary = unsafe {
|
||||
std::slice::from_raw_parts_mut(dictionary_ptr.cast(), dictionary_size as usize)
|
||||
};
|
||||
|
||||
let hashmap = CoreHashMap::new(buckets, dictionary);
|
||||
unsafe {
|
||||
let lock = RwLock::from_raw(PthreadRwLock::new(raw_lock_ptr.cast()), hashmap);
|
||||
std::ptr::write(shared_ptr, lock);
|
||||
}
|
||||
|
||||
Self {
|
||||
num_buckets,
|
||||
shmem_handle,
|
||||
shared_ptr,
|
||||
shared_size: area_size,
|
||||
hasher,
|
||||
}
|
||||
}
|
||||
|
||||
/// Attach to a hash table for writing.
|
||||
pub fn attach_writer(self) -> HashMapAccess<'a, K, V, S> {
|
||||
HashMapAccess {
|
||||
shmem_handle: self.shmem_handle,
|
||||
shared_ptr: self.shared_ptr,
|
||||
hasher: self.hasher,
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize a table for reading. Currently identical to [`HashMapInit::attach_writer`].
|
||||
///
|
||||
/// This is a holdover from a previous implementation and is being kept around for
|
||||
/// backwards compatibility reasons.
|
||||
pub fn attach_reader(self) -> HashMapAccess<'a, K, V, S> {
|
||||
self.attach_writer()
|
||||
}
|
||||
}
|
||||
|
||||
/// Hash table data that is actually stored in the shared memory area.
|
||||
///
|
||||
/// NOTE: We carve out the parts from a contiguous chunk. Growing and shrinking the hash table
|
||||
/// relies on the memory layout! The data structures are laid out in the contiguous shared memory
|
||||
/// area as follows:
|
||||
///
|
||||
/// [`libc::pthread_rwlock_t`]
|
||||
/// [`HashMapShared`]
|
||||
/// buckets
|
||||
/// dictionary
|
||||
///
|
||||
/// In between the above parts, there can be padding bytes to align the parts correctly.
|
||||
type HashMapShared<'a, K, V> = RwLock<CoreHashMap<'a, K, V>>;
|
||||
|
||||
impl<'a, K, V> HashMapInit<'a, K, V, rustc_hash::FxBuildHasher>
|
||||
where
|
||||
K: Clone + Hash + Eq,
|
||||
{
|
||||
/// Place the hash table within a user-supplied fixed memory area.
|
||||
pub fn with_fixed(num_buckets: u32, area: &'a mut [MaybeUninit<u8>]) -> Self {
|
||||
Self::new(
|
||||
num_buckets,
|
||||
None,
|
||||
area.as_mut_ptr().cast(),
|
||||
area.len(),
|
||||
rustc_hash::FxBuildHasher,
|
||||
)
|
||||
}
|
||||
|
||||
/// Place a new hash map in the given shared memory area
|
||||
///
|
||||
/// # Panics
|
||||
/// Will panic on failure to resize area to expected map size.
|
||||
pub fn with_shmem(num_buckets: u32, shmem: ShmemHandle) -> Self {
|
||||
let size = Self::estimate_size(num_buckets);
|
||||
shmem
|
||||
.set_size(size)
|
||||
.expect("could not resize shared memory area");
|
||||
let ptr = shmem.data_ptr.as_ptr().cast();
|
||||
Self::new(
|
||||
num_buckets,
|
||||
Some(shmem),
|
||||
ptr,
|
||||
size,
|
||||
rustc_hash::FxBuildHasher,
|
||||
)
|
||||
}
|
||||
|
||||
/// Make a resizable hash map within a new shared memory area with the given name.
|
||||
pub fn new_resizeable_named(num_buckets: u32, max_buckets: u32, name: &str) -> Self {
|
||||
let size = Self::estimate_size(num_buckets);
|
||||
let max_size = Self::estimate_size(max_buckets);
|
||||
let shmem =
|
||||
ShmemHandle::new(name, size, max_size).expect("failed to make shared memory area");
|
||||
let ptr = shmem.data_ptr.as_ptr().cast();
|
||||
|
||||
Self::new(
|
||||
num_buckets,
|
||||
Some(shmem),
|
||||
ptr,
|
||||
size,
|
||||
rustc_hash::FxBuildHasher,
|
||||
)
|
||||
}
|
||||
|
||||
/// Make a resizable hash map within a new anonymous shared memory area.
|
||||
pub fn new_resizeable(num_buckets: u32, max_buckets: u32) -> Self {
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
static COUNTER: AtomicUsize = AtomicUsize::new(0);
|
||||
let val = COUNTER.fetch_add(1, Ordering::Relaxed);
|
||||
let name = format!("neon_shmem_hmap{val}");
|
||||
Self::new_resizeable_named(num_buckets, max_buckets, &name)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, K, V, S: BuildHasher> HashMapAccess<'a, K, V, S>
|
||||
where
|
||||
K: Clone + Hash + Eq,
|
||||
{
|
||||
/// Hash a key using the map's hasher.
|
||||
#[inline]
|
||||
fn get_hash_value(&self, key: &K) -> u64 {
|
||||
self.hasher.hash_one(key)
|
||||
}
|
||||
|
||||
fn entry_with_hash(&self, key: K, hash: u64) -> Entry<'a, '_, K, V> {
|
||||
let mut map = unsafe { self.shared_ptr.as_ref() }.unwrap().write();
|
||||
let dict_pos = hash as usize % map.dictionary.len();
|
||||
let first = map.dictionary[dict_pos];
|
||||
if first == INVALID_POS {
|
||||
// no existing entry
|
||||
return Entry::Vacant(VacantEntry {
|
||||
map,
|
||||
key,
|
||||
dict_pos: dict_pos as u32,
|
||||
});
|
||||
}
|
||||
|
||||
let mut prev_pos = PrevPos::First(dict_pos as u32);
|
||||
let mut next = first;
|
||||
loop {
|
||||
let bucket = &mut map.buckets[next as usize];
|
||||
let (bucket_key, _bucket_value) = bucket.inner.as_mut().expect("entry is in use");
|
||||
if *bucket_key == key {
|
||||
// found existing entry
|
||||
return Entry::Occupied(OccupiedEntry {
|
||||
map,
|
||||
_key: key,
|
||||
prev_pos,
|
||||
bucket_pos: next,
|
||||
});
|
||||
}
|
||||
|
||||
if bucket.next == INVALID_POS {
|
||||
// No existing entry
|
||||
return Entry::Vacant(VacantEntry {
|
||||
map,
|
||||
key,
|
||||
dict_pos: dict_pos as u32,
|
||||
});
|
||||
}
|
||||
prev_pos = PrevPos::Chained(next);
|
||||
next = bucket.next;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a reference to the corresponding value for a key.
|
||||
pub fn get<'e>(&'e self, key: &K) -> Option<ValueReadGuard<'e, V>> {
|
||||
let hash = self.get_hash_value(key);
|
||||
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
|
||||
RwLockReadGuard::try_map(map, |m| m.get_with_hash(key, hash)).ok()
|
||||
}
|
||||
|
||||
/// Get a reference to the entry containing a key.
|
||||
///
|
||||
/// NB: THis takes a write lock as there's no way to distinguish whether the intention
|
||||
/// is to use the entry for reading or for writing in advance.
|
||||
pub fn entry(&self, key: K) -> Entry<'a, '_, K, V> {
|
||||
let hash = self.get_hash_value(&key);
|
||||
self.entry_with_hash(key, hash)
|
||||
}
|
||||
|
||||
/// Remove a key given its hash. Returns the associated value if it existed.
|
||||
pub fn remove(&self, key: &K) -> Option<V> {
|
||||
let hash = self.get_hash_value(key);
|
||||
match self.entry_with_hash(key.clone(), hash) {
|
||||
Entry::Occupied(e) => Some(e.remove()),
|
||||
Entry::Vacant(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert/update a key. Returns the previous associated value if it existed.
|
||||
///
|
||||
/// # Errors
|
||||
/// Will return [`core::FullError`] if there is no more space left in the map.
|
||||
pub fn insert(&self, key: K, value: V) -> Result<Option<V>, core::FullError> {
|
||||
let hash = self.get_hash_value(&key);
|
||||
match self.entry_with_hash(key.clone(), hash) {
|
||||
Entry::Occupied(mut e) => Ok(Some(e.insert(value))),
|
||||
Entry::Vacant(e) => {
|
||||
_ = e.insert(value)?;
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Optionally return the entry for a bucket at a given index if it exists.
|
||||
///
|
||||
/// Has more overhead than one would intuitively expect: performs both a clone of the key
|
||||
/// due to the [`OccupiedEntry`] type owning the key and also a hash of the key in order
|
||||
/// to enable repairing the hash chain if the entry is removed.
|
||||
pub fn entry_at_bucket(&self, pos: usize) -> Option<OccupiedEntry<'a, '_, K, V>> {
|
||||
let map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
|
||||
if pos >= map.buckets.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let entry = map.buckets[pos].inner.as_ref();
|
||||
match entry {
|
||||
Some((key, _)) => Some(OccupiedEntry {
|
||||
_key: key.clone(),
|
||||
bucket_pos: pos as u32,
|
||||
prev_pos: entry::PrevPos::Unknown(self.get_hash_value(key)),
|
||||
map,
|
||||
}),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of buckets in the table.
|
||||
pub fn get_num_buckets(&self) -> usize {
|
||||
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
|
||||
map.get_num_buckets()
|
||||
}
|
||||
|
||||
/// Return the key and value stored in bucket with given index. This can be used to
|
||||
/// iterate through the hash map.
|
||||
// TODO: An Iterator might be nicer. The communicator's clock algorithm needs to
|
||||
// _slowly_ iterate through all buckets with its clock hand, without holding a lock.
|
||||
// If we switch to an Iterator, it must not hold the lock.
|
||||
pub fn get_at_bucket(&self, pos: usize) -> Option<ValueReadGuard<(K, V)>> {
|
||||
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
|
||||
if pos >= map.buckets.len() {
|
||||
return None;
|
||||
}
|
||||
RwLockReadGuard::try_map(map, |m| m.buckets[pos].inner.as_ref()).ok()
|
||||
}
|
||||
|
||||
/// Returns the index of the bucket a given value corresponds to.
|
||||
pub fn get_bucket_for_value(&self, val_ptr: *const V) -> usize {
|
||||
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
|
||||
|
||||
let origin = map.buckets.as_ptr();
|
||||
let idx = (val_ptr as usize - origin as usize) / size_of::<Bucket<K, V>>();
|
||||
assert!(idx < map.buckets.len());
|
||||
|
||||
idx
|
||||
}
|
||||
|
||||
/// Returns the number of occupied buckets in the table.
|
||||
pub fn get_num_buckets_in_use(&self) -> usize {
|
||||
let map = unsafe { self.shared_ptr.as_ref() }.unwrap().read();
|
||||
map.buckets_in_use as usize
|
||||
}
|
||||
|
||||
/// Clears all entries in a table. Does not reset any shrinking operations.
|
||||
pub fn clear(&self) {
|
||||
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
|
||||
map.clear();
|
||||
}
|
||||
|
||||
/// Perform an in-place rehash of some region (0..`rehash_buckets`) of the table and reset
|
||||
/// the `buckets` and `dictionary` slices to be as long as `num_buckets`. Resets the freelist
|
||||
/// in the process.
|
||||
fn rehash_dict(
|
||||
&self,
|
||||
inner: &mut CoreHashMap<'a, K, V>,
|
||||
buckets_ptr: *mut core::Bucket<K, V>,
|
||||
end_ptr: *mut u8,
|
||||
num_buckets: u32,
|
||||
rehash_buckets: u32,
|
||||
) {
|
||||
inner.free_head = INVALID_POS;
|
||||
|
||||
let buckets;
|
||||
let dictionary;
|
||||
unsafe {
|
||||
let buckets_end_ptr = buckets_ptr.add(num_buckets as usize);
|
||||
let dictionary_ptr: *mut u32 = buckets_end_ptr
|
||||
.byte_add(buckets_end_ptr.align_offset(align_of::<u32>()))
|
||||
.cast();
|
||||
let dictionary_size: usize =
|
||||
end_ptr.byte_offset_from(buckets_end_ptr) as usize / size_of::<u32>();
|
||||
|
||||
buckets = std::slice::from_raw_parts_mut(buckets_ptr, num_buckets as usize);
|
||||
dictionary = std::slice::from_raw_parts_mut(dictionary_ptr, dictionary_size);
|
||||
}
|
||||
for e in dictionary.iter_mut() {
|
||||
*e = INVALID_POS;
|
||||
}
|
||||
|
||||
for (i, bucket) in buckets.iter_mut().enumerate().take(rehash_buckets as usize) {
|
||||
if bucket.inner.is_none() {
|
||||
bucket.next = inner.free_head;
|
||||
inner.free_head = i as u32;
|
||||
continue;
|
||||
}
|
||||
|
||||
let hash = self.hasher.hash_one(&bucket.inner.as_ref().unwrap().0);
|
||||
let pos: usize = (hash % dictionary.len() as u64) as usize;
|
||||
bucket.next = dictionary[pos];
|
||||
dictionary[pos] = i as u32;
|
||||
}
|
||||
|
||||
inner.dictionary = dictionary;
|
||||
inner.buckets = buckets;
|
||||
}
|
||||
|
||||
/// Rehash the map without growing or shrinking.
|
||||
pub fn shuffle(&self) {
|
||||
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
|
||||
let num_buckets = map.get_num_buckets() as u32;
|
||||
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
|
||||
let end_ptr: *mut u8 = unsafe { self.shared_ptr.byte_add(size_bytes).cast() };
|
||||
let buckets_ptr = map.buckets.as_mut_ptr();
|
||||
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets);
|
||||
}
|
||||
|
||||
/// Grow the number of buckets within the table.
|
||||
///
|
||||
/// 1. Grows the underlying shared memory area
|
||||
/// 2. Initializes new buckets and overwrites the current dictionary
|
||||
/// 3. Rehashes the dictionary
|
||||
///
|
||||
/// # Panics
|
||||
/// Panics if called on a map initialized with [`HashMapInit::with_fixed`].
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an [`shmem::Error`] if any errors occur resizing the memory region.
|
||||
pub fn grow(&self, num_buckets: u32) -> Result<(), shmem::Error> {
|
||||
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
|
||||
let old_num_buckets = map.buckets.len() as u32;
|
||||
|
||||
assert!(
|
||||
num_buckets >= old_num_buckets,
|
||||
"grow called with a smaller number of buckets"
|
||||
);
|
||||
if num_buckets == old_num_buckets {
|
||||
return Ok(());
|
||||
}
|
||||
let shmem_handle = self
|
||||
.shmem_handle
|
||||
.as_ref()
|
||||
.expect("grow called on a fixed-size hash table");
|
||||
|
||||
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
|
||||
shmem_handle.set_size(size_bytes)?;
|
||||
let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
|
||||
|
||||
// Initialize new buckets. The new buckets are linked to the free list.
|
||||
// NB: This overwrites the dictionary!
|
||||
let buckets_ptr = map.buckets.as_mut_ptr();
|
||||
unsafe {
|
||||
for i in old_num_buckets..num_buckets {
|
||||
let bucket = buckets_ptr.add(i as usize);
|
||||
bucket.write(core::Bucket {
|
||||
next: if i < num_buckets - 1 {
|
||||
i + 1
|
||||
} else {
|
||||
map.free_head
|
||||
},
|
||||
inner: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, old_num_buckets);
|
||||
map.free_head = old_num_buckets;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Begin a shrink, limiting all new allocations to be in buckets with index below `num_buckets`.
|
||||
///
|
||||
/// # Panics
|
||||
/// Panics if called on a map initialized with [`HashMapInit::with_fixed`] or if `num_buckets` is
|
||||
/// greater than the number of buckets in the map.
|
||||
pub fn begin_shrink(&mut self, num_buckets: u32) {
|
||||
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
|
||||
assert!(
|
||||
num_buckets <= map.get_num_buckets() as u32,
|
||||
"shrink called with a larger number of buckets"
|
||||
);
|
||||
_ = self
|
||||
.shmem_handle
|
||||
.as_ref()
|
||||
.expect("shrink called on a fixed-size hash table");
|
||||
map.alloc_limit = num_buckets;
|
||||
}
|
||||
|
||||
/// If a shrink operation is underway, returns the target size of the map. Otherwise, returns None.
|
||||
pub fn shrink_goal(&self) -> Option<usize> {
|
||||
let map = unsafe { self.shared_ptr.as_mut() }.unwrap().read();
|
||||
let goal = map.alloc_limit;
|
||||
if goal == INVALID_POS {
|
||||
None
|
||||
} else {
|
||||
Some(goal as usize)
|
||||
}
|
||||
}
|
||||
|
||||
/// Complete a shrink after caller has evicted entries, removing the unused buckets and rehashing.
|
||||
///
|
||||
/// # Panics
|
||||
/// The following cases result in a panic:
|
||||
/// - Calling this function on a map initialized with [`HashMapInit::with_fixed`].
|
||||
/// - Calling this function on a map when no shrink operation is in progress.
|
||||
pub fn finish_shrink(&self) -> Result<(), HashMapShrinkError> {
|
||||
let mut map = unsafe { self.shared_ptr.as_mut() }.unwrap().write();
|
||||
assert!(
|
||||
map.alloc_limit != INVALID_POS,
|
||||
"called finish_shrink when no shrink is in progress"
|
||||
);
|
||||
|
||||
let num_buckets = map.alloc_limit;
|
||||
|
||||
if map.get_num_buckets() == num_buckets as usize {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
assert!(
|
||||
map.buckets_in_use <= num_buckets,
|
||||
"called finish_shrink before enough entries were removed"
|
||||
);
|
||||
|
||||
for i in (num_buckets as usize)..map.buckets.len() {
|
||||
if map.buckets[i].inner.is_some() {
|
||||
return Err(HashMapShrinkError::RemainingEntries(i));
|
||||
}
|
||||
}
|
||||
|
||||
let shmem_handle = self
|
||||
.shmem_handle
|
||||
.as_ref()
|
||||
.expect("shrink called on a fixed-size hash table");
|
||||
|
||||
let size_bytes = HashMapInit::<K, V, S>::estimate_size(num_buckets);
|
||||
if let Err(e) = shmem_handle.set_size(size_bytes) {
|
||||
return Err(HashMapShrinkError::ResizeError(e));
|
||||
}
|
||||
let end_ptr: *mut u8 = unsafe { shmem_handle.data_ptr.as_ptr().add(size_bytes) };
|
||||
let buckets_ptr = map.buckets.as_mut_ptr();
|
||||
self.rehash_dict(&mut map, buckets_ptr, end_ptr, num_buckets, num_buckets);
|
||||
map.alloc_limit = INVALID_POS;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
174
libs/neon-shmem/src/hash/core.rs
Normal file
174
libs/neon-shmem/src/hash/core.rs
Normal file
@@ -0,0 +1,174 @@
|
||||
//! Simple hash table with chaining.
|
||||
|
||||
use std::hash::Hash;
|
||||
use std::mem::MaybeUninit;
|
||||
|
||||
use crate::hash::entry::*;
|
||||
|
||||
/// Sentinel meaning "no position": used both for dictionary slots and for bucket chain
/// links that do not point at any bucket.
pub(crate) const INVALID_POS: u32 = u32::MAX;
|
||||
|
||||
/// The basic storage cell of the hash table: either empty or holding one key-value pair.
///
/// A bucket is always a member of some chain: the free list while it is empty, or a hash
/// chain while it is occupied.
pub(crate) struct Bucket<K, V> {
    /// Index of the bucket that follows this one in its chain.
    pub(crate) next: u32,
    /// The stored key-value pair, or `None` while the bucket is empty.
    pub(crate) inner: Option<(K, V)>,
}
|
||||
|
||||
/// Core hash table implementation.
|
||||
pub(crate) struct CoreHashMap<'a, K, V> {
|
||||
/// Dictionary used to map hashes to bucket indices.
|
||||
pub(crate) dictionary: &'a mut [u32],
|
||||
/// Buckets containing key-value pairs.
|
||||
pub(crate) buckets: &'a mut [Bucket<K, V>],
|
||||
/// Head of the freelist.
|
||||
pub(crate) free_head: u32,
|
||||
/// Maximum index of a bucket allowed to be allocated. [`INVALID_POS`] if no limit.
|
||||
pub(crate) alloc_limit: u32,
|
||||
/// The number of currently occupied buckets.
|
||||
pub(crate) buckets_in_use: u32,
|
||||
}
|
||||
|
||||
/// Returned when a bucket has to be allocated but no usable empty bucket is left.
#[derive(Debug, PartialEq)]
pub struct FullError;
|
||||
|
||||
impl<'a, K: Clone + Hash + Eq, V> CoreHashMap<'a, K, V> {
|
||||
const FILL_FACTOR: f32 = 0.60;
|
||||
|
||||
/// Estimate the size of data contained within the the hash map.
|
||||
pub fn estimate_size(num_buckets: u32) -> usize {
|
||||
let mut size = 0;
|
||||
|
||||
// buckets
|
||||
size += size_of::<Bucket<K, V>>() * num_buckets as usize;
|
||||
|
||||
// dictionary
|
||||
size += (f32::ceil((size_of::<u32>() * num_buckets as usize) as f32 / Self::FILL_FACTOR))
|
||||
as usize;
|
||||
|
||||
size
|
||||
}
|
||||
|
||||
pub fn new(
|
||||
buckets: &'a mut [MaybeUninit<Bucket<K, V>>],
|
||||
dictionary: &'a mut [MaybeUninit<u32>],
|
||||
) -> Self {
|
||||
// Initialize the buckets
|
||||
for i in 0..buckets.len() {
|
||||
buckets[i].write(Bucket {
|
||||
next: if i < buckets.len() - 1 {
|
||||
i as u32 + 1
|
||||
} else {
|
||||
INVALID_POS
|
||||
},
|
||||
inner: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Initialize the dictionary
|
||||
for e in dictionary.iter_mut() {
|
||||
e.write(INVALID_POS);
|
||||
}
|
||||
|
||||
// TODO: use std::slice::assume_init_mut() once it stabilizes
|
||||
let buckets =
|
||||
unsafe { std::slice::from_raw_parts_mut(buckets.as_mut_ptr().cast(), buckets.len()) };
|
||||
let dictionary = unsafe {
|
||||
std::slice::from_raw_parts_mut(dictionary.as_mut_ptr().cast(), dictionary.len())
|
||||
};
|
||||
|
||||
Self {
|
||||
dictionary,
|
||||
buckets,
|
||||
free_head: 0,
|
||||
buckets_in_use: 0,
|
||||
alloc_limit: INVALID_POS,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the value associated with a key (if it exists) given its hash.
|
||||
pub fn get_with_hash(&self, key: &K, hash: u64) -> Option<&V> {
|
||||
let mut next = self.dictionary[hash as usize % self.dictionary.len()];
|
||||
loop {
|
||||
if next == INVALID_POS {
|
||||
return None;
|
||||
}
|
||||
|
||||
let bucket = &self.buckets[next as usize];
|
||||
let (bucket_key, bucket_value) = bucket.inner.as_ref().expect("entry is in use");
|
||||
if bucket_key == key {
|
||||
return Some(bucket_value);
|
||||
}
|
||||
next = bucket.next;
|
||||
}
|
||||
}
|
||||
|
||||
/// Get number of buckets in map.
|
||||
pub fn get_num_buckets(&self) -> usize {
|
||||
self.buckets.len()
|
||||
}
|
||||
|
||||
/// Clears all entries from the hashmap.
|
||||
///
|
||||
/// Does not reset any allocation limits, but does clear any entries beyond them.
|
||||
pub fn clear(&mut self) {
|
||||
for i in 0..self.buckets.len() {
|
||||
self.buckets[i] = Bucket {
|
||||
next: if i < self.buckets.len() - 1 {
|
||||
i as u32 + 1
|
||||
} else {
|
||||
INVALID_POS
|
||||
},
|
||||
inner: None,
|
||||
}
|
||||
}
|
||||
for i in 0..self.dictionary.len() {
|
||||
self.dictionary[i] = INVALID_POS;
|
||||
}
|
||||
|
||||
self.free_head = 0;
|
||||
self.buckets_in_use = 0;
|
||||
}
|
||||
|
||||
/// Find the position of an unused bucket via the freelist and initialize it.
|
||||
pub(crate) fn alloc_bucket(&mut self, key: K, value: V) -> Result<u32, FullError> {
|
||||
let mut pos = self.free_head;
|
||||
|
||||
// Find the first bucket we're *allowed* to use.
|
||||
let mut prev = PrevPos::First(self.free_head);
|
||||
while pos != INVALID_POS && pos >= self.alloc_limit {
|
||||
let bucket = &mut self.buckets[pos as usize];
|
||||
prev = PrevPos::Chained(pos);
|
||||
pos = bucket.next;
|
||||
}
|
||||
if pos == INVALID_POS {
|
||||
return Err(FullError);
|
||||
}
|
||||
|
||||
// Repair the freelist.
|
||||
match prev {
|
||||
PrevPos::First(_) => {
|
||||
let next_pos = self.buckets[pos as usize].next;
|
||||
self.free_head = next_pos;
|
||||
}
|
||||
PrevPos::Chained(p) => {
|
||||
if p != INVALID_POS {
|
||||
let next_pos = self.buckets[pos as usize].next;
|
||||
self.buckets[p as usize].next = next_pos;
|
||||
}
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
||||
// Initialize the bucket.
|
||||
let bucket = &mut self.buckets[pos as usize];
|
||||
self.buckets_in_use += 1;
|
||||
bucket.next = INVALID_POS;
|
||||
bucket.inner = Some((key, value));
|
||||
|
||||
Ok(pos)
|
||||
}
|
||||
}
|
||||
130
libs/neon-shmem/src/hash/entry.rs
Normal file
130
libs/neon-shmem/src/hash/entry.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
//! Equivalent of [`std::collections::hash_map::Entry`] for this hashmap.
|
||||
|
||||
use crate::hash::core::{CoreHashMap, FullError, INVALID_POS};
|
||||
use crate::sync::{RwLockWriteGuard, ValueWriteGuard};
|
||||
|
||||
use std::hash::Hash;
|
||||
use std::mem;
|
||||
|
||||
pub enum Entry<'a, 'b, K, V> {
|
||||
Occupied(OccupiedEntry<'a, 'b, K, V>),
|
||||
Vacant(VacantEntry<'a, 'b, K, V>),
|
||||
}
|
||||
|
||||
/// Describes the link that precedes an entry within its chain.
#[derive(Clone, Copy)]
pub(crate) enum PrevPos {
    /// The entry is the head of its chain; the payload is the index of the dictionary slot
    /// that points at it.
    First(u32),
    /// The entry follows another bucket; the payload is that bucket's index.
    Chained(u32),
    /// The predecessor is not known, e.g. because the entry was retrieved by bucket index
    /// rather than by walking its chain. The payload is the hash of the entry's key.
    Unknown(u64),
}
|
||||
|
||||
pub struct OccupiedEntry<'a, 'b, K, V> {
|
||||
/// Mutable reference to the map containing this entry.
|
||||
pub(crate) map: RwLockWriteGuard<'b, CoreHashMap<'a, K, V>>,
|
||||
/// The key of the occupied entry
|
||||
pub(crate) _key: K,
|
||||
/// The index of the previous entry in the chain.
|
||||
pub(crate) prev_pos: PrevPos,
|
||||
/// The position of the bucket in the [`CoreHashMap`] bucket array.
|
||||
pub(crate) bucket_pos: u32,
|
||||
}
|
||||
|
||||
impl<K, V> OccupiedEntry<'_, '_, K, V> {
|
||||
pub fn get(&self) -> &V {
|
||||
&self.map.buckets[self.bucket_pos as usize]
|
||||
.inner
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.1
|
||||
}
|
||||
|
||||
pub fn get_mut(&mut self) -> &mut V {
|
||||
&mut self.map.buckets[self.bucket_pos as usize]
|
||||
.inner
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.1
|
||||
}
|
||||
|
||||
/// Inserts a value into the entry, replacing (and returning) the existing value.
|
||||
pub fn insert(&mut self, value: V) -> V {
|
||||
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
|
||||
// This assumes inner is Some, which it must be for an OccupiedEntry
|
||||
mem::replace(&mut bucket.inner.as_mut().unwrap().1, value)
|
||||
}
|
||||
|
||||
/// Removes the entry from the hash map, returning the value originally stored within it.
|
||||
///
|
||||
/// This may result in multiple bucket accesses if the entry was obtained by index as the
|
||||
/// previous chain entry needs to be discovered in this case.
|
||||
pub fn remove(mut self) -> V {
|
||||
// If this bucket was queried by index, go ahead and follow its chain from the start.
|
||||
let prev = if let PrevPos::Unknown(hash) = self.prev_pos {
|
||||
let dict_idx = hash as usize % self.map.dictionary.len();
|
||||
let mut prev = PrevPos::First(dict_idx as u32);
|
||||
let mut curr = self.map.dictionary[dict_idx];
|
||||
while curr != self.bucket_pos {
|
||||
assert!(curr != INVALID_POS);
|
||||
prev = PrevPos::Chained(curr);
|
||||
curr = self.map.buckets[curr as usize].next;
|
||||
}
|
||||
prev
|
||||
} else {
|
||||
self.prev_pos
|
||||
};
|
||||
|
||||
// CoreHashMap::remove returns Option<(K, V)>. We know it's Some for an OccupiedEntry.
|
||||
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
|
||||
|
||||
// unlink it from the chain
|
||||
match prev {
|
||||
PrevPos::First(dict_pos) => {
|
||||
self.map.dictionary[dict_pos as usize] = bucket.next;
|
||||
}
|
||||
PrevPos::Chained(bucket_pos) => {
|
||||
self.map.buckets[bucket_pos as usize].next = bucket.next;
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
|
||||
// and add it to the freelist
|
||||
let free = self.map.free_head;
|
||||
let bucket = &mut self.map.buckets[self.bucket_pos as usize];
|
||||
let old_value = bucket.inner.take();
|
||||
bucket.next = free;
|
||||
self.map.free_head = self.bucket_pos;
|
||||
self.map.buckets_in_use -= 1;
|
||||
|
||||
old_value.unwrap().1
|
||||
}
|
||||
}
|
||||
|
||||
/// An abstract view into a vacant entry within the map.
|
||||
pub struct VacantEntry<'a, 'b, K, V> {
|
||||
/// Mutable reference to the map containing this entry.
|
||||
pub(crate) map: RwLockWriteGuard<'b, CoreHashMap<'a, K, V>>,
|
||||
/// The key to be inserted into this entry.
|
||||
pub(crate) key: K,
|
||||
/// The position within the dictionary corresponding to the key's hash.
|
||||
pub(crate) dict_pos: u32,
|
||||
}
|
||||
|
||||
impl<'b, K: Clone + Hash + Eq, V> VacantEntry<'_, 'b, K, V> {
|
||||
/// Insert a value into the vacant entry, finding and populating an empty bucket in the process.
|
||||
///
|
||||
/// # Errors
|
||||
/// Will return [`FullError`] if there are no unoccupied buckets in the map.
|
||||
pub fn insert(mut self, value: V) -> Result<ValueWriteGuard<'b, V>, FullError> {
|
||||
let pos = self.map.alloc_bucket(self.key, value)?;
|
||||
self.map.buckets[pos as usize].next = self.map.dictionary[self.dict_pos as usize];
|
||||
self.map.dictionary[self.dict_pos as usize] = pos;
|
||||
|
||||
Ok(RwLockWriteGuard::map(self.map, |m| {
|
||||
&mut m.buckets[pos as usize].inner.as_mut().unwrap().1
|
||||
}))
|
||||
}
|
||||
}
|
||||
428
libs/neon-shmem/src/hash/tests.rs
Normal file
428
libs/neon-shmem/src/hash/tests.rs
Normal file
@@ -0,0 +1,428 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::collections::HashSet;
|
||||
use std::fmt::Debug;
|
||||
use std::mem::MaybeUninit;
|
||||
|
||||
use crate::hash::Entry;
|
||||
use crate::hash::HashMapAccess;
|
||||
use crate::hash::HashMapInit;
|
||||
use crate::hash::core::FullError;
|
||||
|
||||
use rand::seq::SliceRandom;
|
||||
use rand::{Rng, RngCore};
|
||||
use rand_distr::Zipf;
|
||||
|
||||
/// Length in bytes of a [`TestKey`].
const TEST_KEY_LEN: usize = 16;

/// Fixed-size byte-string key used by the tests.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct TestKey([u8; TEST_KEY_LEN]);

/// Interpret the key bytes as a big-endian integer.
impl From<&TestKey> for u128 {
    fn from(val: &TestKey) -> u128 {
        let TestKey(bytes) = val;
        u128::from_be_bytes(*bytes)
    }
}

/// Build a key from the big-endian bytes of an integer.
impl From<u128> for TestKey {
    fn from(val: u128) -> TestKey {
        Self(val.to_be_bytes())
    }
}

/// Build a key from a byte slice. Panics unless the slice is exactly
/// [`TEST_KEY_LEN`] bytes long.
impl<'a> From<&'a [u8]> for TestKey {
    fn from(bytes: &'a [u8]) -> TestKey {
        let array = <[u8; TEST_KEY_LEN]>::try_from(bytes).unwrap();
        Self(array)
    }
}
|
||||
|
||||
fn test_inserts<K: Into<TestKey> + Copy>(keys: &[K]) {
|
||||
let w = HashMapInit::<TestKey, usize>::new_resizeable_named(100000, 120000, "test_inserts")
|
||||
.attach_writer();
|
||||
|
||||
for (idx, k) in keys.iter().enumerate() {
|
||||
let res = w.entry((*k).into());
|
||||
match res {
|
||||
Entry::Occupied(mut e) => {
|
||||
e.insert(idx);
|
||||
}
|
||||
Entry::Vacant(e) => {
|
||||
let res = e.insert(idx);
|
||||
assert!(res.is_ok());
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
for (idx, k) in keys.iter().enumerate() {
|
||||
let x = w.get(&(*k).into());
|
||||
let value = x.as_deref().copied();
|
||||
assert_eq!(value, Some(idx));
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert densely packed integer keys, in order and in several random orders.
#[test]
fn dense() {
    // A handful of small keys plus one key outside the run
    let keys: &[u128] = &[0, 1, 2, 3, 256];
    test_inserts(keys);

    // Dense keys
    let mut keys: Vec<u128> = (0..10000).collect();
    test_inserts(&keys);

    // Do the same in random orders
    let mut rounds = 1..10;
    while rounds.next().is_some() {
        keys.shuffle(&mut rand::rng());
        test_inserts(&keys);
    }
}
|
||||
|
||||
/// Insert 10000 distinct random 128-bit keys.
#[test]
fn sparse() {
    // sparse keys
    let mut keys: Vec<TestKey> = Vec::new();
    let mut used_keys = HashSet::new();
    while keys.len() < 10000 {
        let candidate = rand::random::<u128>();
        // `insert` returns false for a value seen before; draw again in that case.
        if used_keys.insert(candidate) {
            keys.push(candidate.into());
        }
    }
    test_inserts(&keys);
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct TestOp(TestKey, Option<usize>);
|
||||
|
||||
fn apply_op(
|
||||
op: &TestOp,
|
||||
map: &mut HashMapAccess<TestKey, usize>,
|
||||
shadow: &mut BTreeMap<TestKey, usize>,
|
||||
) {
|
||||
// apply the change to the shadow tree first
|
||||
let shadow_existing = if let Some(v) = op.1 {
|
||||
shadow.insert(op.0, v)
|
||||
} else {
|
||||
shadow.remove(&op.0)
|
||||
};
|
||||
|
||||
let entry = map.entry(op.0);
|
||||
let hash_existing = match op.1 {
|
||||
Some(new) => match entry {
|
||||
Entry::Occupied(mut e) => Some(e.insert(new)),
|
||||
Entry::Vacant(e) => {
|
||||
_ = e.insert(new).unwrap();
|
||||
None
|
||||
}
|
||||
},
|
||||
None => match entry {
|
||||
Entry::Occupied(e) => Some(e.remove()),
|
||||
Entry::Vacant(_) => None,
|
||||
},
|
||||
};
|
||||
|
||||
assert_eq!(shadow_existing, hash_existing);
|
||||
}
|
||||
|
||||
fn do_random_ops(
|
||||
num_ops: usize,
|
||||
size: u32,
|
||||
del_prob: f64,
|
||||
writer: &mut HashMapAccess<TestKey, usize>,
|
||||
shadow: &mut BTreeMap<TestKey, usize>,
|
||||
rng: &mut rand::rngs::ThreadRng,
|
||||
) {
|
||||
for i in 0..num_ops {
|
||||
let key: TestKey = ((rng.next_u32() % size) as u128).into();
|
||||
let op = TestOp(
|
||||
key,
|
||||
if rng.random_bool(del_prob) {
|
||||
Some(i)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
);
|
||||
apply_op(&op, writer, shadow);
|
||||
}
|
||||
}
|
||||
|
||||
fn do_deletes(
|
||||
num_ops: usize,
|
||||
writer: &mut HashMapAccess<TestKey, usize>,
|
||||
shadow: &mut BTreeMap<TestKey, usize>,
|
||||
) {
|
||||
for _ in 0..num_ops {
|
||||
let (k, _) = shadow.pop_first().unwrap();
|
||||
writer.remove(&k);
|
||||
}
|
||||
}
|
||||
|
||||
fn do_shrink(
|
||||
writer: &mut HashMapAccess<TestKey, usize>,
|
||||
shadow: &mut BTreeMap<TestKey, usize>,
|
||||
from: u32,
|
||||
to: u32,
|
||||
) {
|
||||
assert!(writer.shrink_goal().is_none());
|
||||
writer.begin_shrink(to);
|
||||
assert_eq!(writer.shrink_goal(), Some(to as usize));
|
||||
for i in to..from {
|
||||
if let Some(entry) = writer.entry_at_bucket(i as usize) {
|
||||
shadow.remove(&entry._key);
|
||||
entry.remove();
|
||||
}
|
||||
}
|
||||
let old_usage = writer.get_num_buckets_in_use();
|
||||
writer.finish_shrink().unwrap();
|
||||
assert!(writer.shrink_goal().is_none());
|
||||
assert_eq!(writer.get_num_buckets_in_use(), old_usage);
|
||||
}
|
||||
|
||||
/// Apply 100000 random operations with Zipf-distributed keys and compare against a shadow
/// `BTreeMap` after each one.
#[test]
fn random_ops() {
    let init = HashMapInit::<TestKey, usize>::new_resizeable_named(100000, 120000, "test_random");
    let mut writer = init.attach_writer();
    let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();

    let distribution = Zipf::new(u128::MAX as f64, 1.1).unwrap();
    let mut rng = rand::rng();
    for i in 0..100000 {
        let key: TestKey = (rng.sample(distribution) as u128).into();
        // Three out of four operations store a value; the rest remove the key.
        let new_value = rng.random_bool(0.75).then_some(i);

        apply_op(&TestOp(key, new_value), &mut writer, &mut shadow);
    }
}
|
||||
|
||||
/// Run random operations, rehash in place, then run more random operations; the map must
/// keep agreeing with the shadow map throughout.
#[test]
fn test_shuffle() {
    let init = HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 1200, "test_shuf");
    let mut writer = init.attach_writer();
    let mut shadow: std::collections::BTreeMap<TestKey, usize> = BTreeMap::new();
    let mut rng = rand::rng();

    do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);
    writer.shuffle();
    do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);
}
|
||||
|
||||
/// Grows a populated map from 1000 to 1500 buckets and checks that usage is unchanged, the new
/// bucket count is reported, and further random operations still work.
#[test]
fn test_grow() {
    let init = HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 2000, "test_grow");
    let mut writer = init.attach_writer();
    let mut shadow = BTreeMap::<TestKey, usize>::new();
    let mut rng = rand::rng();

    do_random_ops(10000, 1000, 0.75, &mut writer, &mut shadow, &mut rng);

    let in_use_before = writer.get_num_buckets_in_use();
    writer.grow(1500).unwrap();
    assert_eq!(writer.get_num_buckets_in_use(), in_use_before);
    assert_eq!(writer.get_num_buckets(), 1500);

    do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
}
|
||||
|
||||
/// Checks that `clear()` empties the map without changing its bucket count, and that a map
/// filled to capacity accepts inserts again once cleared.
#[test]
fn test_clear() {
    let init = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_clear");
    let mut writer = init.attach_writer();
    let mut shadow = BTreeMap::<TestKey, usize>::new();
    let mut rng = rand::rng();

    do_random_ops(2000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);
    writer.clear();
    assert_eq!(writer.get_num_buckets_in_use(), 0);
    assert_eq!(writer.get_num_buckets(), 1500);
    // Nothing the shadow knew about may still be reachable; this also drains the shadow.
    while let Some((stale_key, _)) = shadow.pop_first() {
        assert!(writer.get(&stale_key).is_none());
    }

    do_random_ops(2000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);

    // Fill every remaining bucket with keys outside the range used by the random ops.
    let free_buckets = 1500 - writer.get_num_buckets_in_use();
    for i in 0..free_buckets {
        writer.insert((1500 + i as u128).into(), 0).unwrap();
    }
    assert_eq!(writer.insert(5000.into(), 0), Err(FullError {}));

    writer.clear();
    assert!(writer.insert(5000.into(), 0).is_ok());
}
|
||||
|
||||
/// Removes entries by bucket index and checks that every key still in the shadow map is
/// present in the map with the expected value.
#[test]
fn test_idx_remove() {
    let init = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_clear");
    let mut writer = init.attach_writer();
    let mut shadow = BTreeMap::<TestKey, usize>::new();
    let mut rng = rand::rng();

    do_random_ops(2000, 1500, 0.25, &mut writer, &mut shadow, &mut rng);

    // Probe 100 random buckets and drop whatever occupies them.
    for _ in 0..100 {
        let bucket = (rng.next_u32() % 1500) as usize;
        if let Some(occupant) = writer.entry_at_bucket(bucket) {
            shadow.remove(&occupant._key);
            occupant.remove();
        }
    }

    // Walk the survivors in ascending key order.
    for (key, expected) in shadow {
        assert_eq!(*writer.get(&key).unwrap(), expected);
    }
}
|
||||
|
||||
/// Checks that `get_bucket_for_value` maps a value pointer obtained through `get_at_bucket`
/// back to the bucket index it was read from.
#[test]
fn test_idx_get() {
    let init = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_clear");
    let mut writer = init.attach_writer();
    let mut shadow = BTreeMap::<TestKey, usize>::new();
    let mut rng = rand::rng();

    do_random_ops(2000, 1500, 0.25, &mut writer, &mut shadow, &mut rng);

    for _ in 0..100 {
        let idx = (rng.next_u32() % 1500) as usize;
        if let Some(pair) = writer.get_at_bucket(idx) {
            // The reverse lookup is done twice per occupied bucket.
            for _ in 0..2 {
                let value_ptr: *const usize = &pair.1;
                assert_eq!(writer.get_bucket_for_value(value_ptr), idx);
            }
        }
    }
}
|
||||
|
||||
/// Shrinks a populated map from 1500 to 1000 buckets, then keeps using it and checks that
/// usage stays within the new size.
#[test]
fn test_shrink() {
    let init = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_shrink");
    let mut writer = init.attach_writer();
    let mut shadow = BTreeMap::<TestKey, usize>::new();
    let mut rng = rand::rng();

    do_random_ops(10000, 1500, 0.75, &mut writer, &mut shadow, &mut rng);

    do_shrink(&mut writer, &mut shadow, 1500, 1000);
    assert_eq!(writer.get_num_buckets(), 1000);

    // Remove 500 entries, then continue with keys drawn from a smaller range (0..500).
    do_deletes(500, &mut writer, &mut shadow);
    do_random_ops(10000, 500, 0.75, &mut writer, &mut shadow, &mut rng);
    assert!(writer.get_num_buckets_in_use() <= 1000);
}
|
||||
|
||||
/// Alternates shrinking and growing (1000 -> 750 -> 1500 -> 200 -> 10000 buckets) with random
/// operations between each resize.
#[test]
fn test_shrink_grow_seq() {
    let init = HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 20000, "test_grow_seq");
    let mut writer = init.attach_writer();
    let mut shadow = BTreeMap::<TestKey, usize>::new();
    let mut rng = rand::rng();

    do_random_ops(500, 1000, 0.1, &mut writer, &mut shadow, &mut rng);

    eprintln!("Shrinking to 750");
    do_shrink(&mut writer, &mut shadow, 1000, 750);
    do_random_ops(200, 1000, 0.5, &mut writer, &mut shadow, &mut rng);

    eprintln!("Growing to 1500");
    writer.grow(1500).unwrap();
    do_random_ops(600, 1500, 0.1, &mut writer, &mut shadow, &mut rng);

    eprintln!("Shrinking to 200");
    // Trim the contents down to at most 100 entries before the large shrink.
    while shadow.len() > 100 {
        do_deletes(1, &mut writer, &mut shadow);
    }
    do_shrink(&mut writer, &mut shadow, 1500, 200);
    do_random_ops(50, 1500, 0.25, &mut writer, &mut shadow, &mut rng);

    eprintln!("Growing to 10k");
    writer.grow(10000).unwrap();
    do_random_ops(10000, 5000, 0.25, &mut writer, &mut shadow, &mut rng);
}
|
||||
|
||||
/// Exercises the entry API and the bucket-index accessors on a single key: insert through
/// `entry`, look it up by key, by bucket position and by value pointer, then remove it.
#[test]
fn test_bucket_ops() {
    let init = HashMapInit::<TestKey, usize>::new_resizeable_named(1000, 1200, "test_bucket_ops");
    let writer = init.attach_writer();

    // Store 1 -> 2 regardless of whether the slot was already taken.
    match writer.entry(1.into()) {
        Entry::Vacant(vacant) => {
            _ = vacant.insert(2).unwrap();
        }
        Entry::Occupied(mut occupied) => {
            occupied.insert(2);
        }
    }
    assert_eq!(writer.get_num_buckets_in_use(), 1);
    assert_eq!(writer.get_num_buckets(), 1000);
    assert_eq!(*writer.get(&1.into()).unwrap(), 2);

    // The key must now be occupied; remember which bucket it landed in.
    let pos = match writer.entry(1.into()) {
        Entry::Vacant(_) => {
            panic!("Insert didn't affect entry");
        }
        Entry::Occupied(occupied) => {
            assert_eq!(occupied._key, 1.into());
            occupied.bucket_pos as usize
        }
    };
    assert_eq!(writer.entry_at_bucket(pos).unwrap()._key, 1.into());
    assert_eq!(*writer.get_at_bucket(pos).unwrap(), (1.into(), 2));
    {
        // The pointer is only used for the reverse lookup, never dereferenced here.
        let ptr: *const usize = &*writer.get(&1.into()).unwrap();
        assert_eq!(writer.get_bucket_for_value(ptr), pos);
    }

    writer.remove(&1.into());
    assert!(writer.get(&1.into()).is_none());
}
|
||||
|
||||
/// Shrinks the map to zero buckets, checks that inserts then fail, and that they succeed
/// again after growing back to 50 buckets.
#[test]
fn test_shrink_zero() {
    let init = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_shrink_zero");
    let mut writer = init.attach_writer();

    writer.begin_shrink(0);
    for bucket in 0..1500 {
        if let Some(occupant) = writer.entry_at_bucket(bucket) {
            occupant.remove();
        }
    }
    writer.finish_shrink().unwrap();
    assert_eq!(writer.get_num_buckets_in_use(), 0);

    // With zero buckets the key is vacant, and inserting into it must fail.
    match writer.entry(1.into()) {
        Entry::Vacant(slot) => assert!(slot.insert(2).is_err()),
        _ => panic!("Somehow got non-vacant entry in empty map."),
    }

    writer.grow(50).unwrap();

    // After growing there is room again.
    match writer.entry(1.into()) {
        Entry::Vacant(slot) => assert!(slot.insert(2).is_ok()),
        _ => panic!("Somehow got non-vacant entry in empty map."),
    }
    assert_eq!(writer.get_num_buckets_in_use(), 1);
}
|
||||
|
||||
/// Growing to 20000 buckets on a map created with a limit of 2000 must panic (via `unwrap`).
#[test]
#[should_panic]
fn test_grow_oom() {
    let init = HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2000, "test_grow_oom");
    let writer = init.attach_writer();
    writer.grow(20000).unwrap();
}
|
||||
|
||||
/// Asking to "shrink" a 1500-bucket map to 2000 buckets must panic.
#[test]
#[should_panic]
fn test_shrink_bigger() {
    let init =
        HashMapInit::<TestKey, usize>::new_resizeable_named(1500, 2500, "test_shrink_bigger");
    let mut writer = init.attach_writer();
    writer.begin_shrink(2000);
}
|
||||
|
||||
/// Calling `finish_shrink()` without a preceding `begin_shrink()` must panic.
#[test]
#[should_panic]
fn test_shrink_early_finish() {
    let init = HashMapInit::<TestKey, usize>::new_resizeable_named(
        1500,
        2500,
        "test_shrink_early_finish",
    );
    let writer = init.attach_writer();
    writer.finish_shrink().unwrap();
}
|
||||
|
||||
/// A map built with `with_fixed` over a caller-provided buffer must panic on `begin_shrink`.
#[test]
#[should_panic]
fn test_shrink_fixed_size() {
    // Backing storage for the fixed-size map; left uninitialised for the map to set up.
    let mut area = [MaybeUninit::uninit(); 10000];
    let mut writer = HashMapInit::<TestKey, usize>::with_fixed(3, &mut area).attach_writer();
    writer.begin_shrink(1);
}
|
||||
@@ -1 +1,3 @@
|
||||
pub mod hash;
|
||||
pub mod shmem;
|
||||
pub mod sync;
|
||||
|
||||
111
libs/neon-shmem/src/sync.rs
Normal file
111
libs/neon-shmem/src/sync.rs
Normal file
@@ -0,0 +1,111 @@
|
||||
//! Simple utilities akin to what's in [`std::sync`] but designed to work with shared memory.
|
||||
|
||||
use std::mem::MaybeUninit;
|
||||
use std::ptr::NonNull;
|
||||
|
||||
use nix::errno::Errno;
|
||||
|
||||
pub type RwLock<T> = lock_api::RwLock<PthreadRwLock, T>;
|
||||
pub type RwLockReadGuard<'a, T> = lock_api::RwLockReadGuard<'a, PthreadRwLock, T>;
|
||||
pub type RwLockWriteGuard<'a, T> = lock_api::RwLockWriteGuard<'a, PthreadRwLock, T>;
|
||||
pub type ValueReadGuard<'a, T> = lock_api::MappedRwLockReadGuard<'a, PthreadRwLock, T>;
|
||||
pub type ValueWriteGuard<'a, T> = lock_api::MappedRwLockWriteGuard<'a, PthreadRwLock, T>;
|
||||
|
||||
/// Shared memory read-write lock.
|
||||
///
/// Holds the raw pointer handed to [`PthreadRwLock::new`]. The inner `Option` is `None` only for
/// the `RawRwLock::INIT` placeholder, which panics as soon as any lock operation is attempted.
pub struct PthreadRwLock(Option<NonNull<libc::pthread_rwlock_t>>);
|
||||
|
||||
/// Invokes `libc::<fn_name>(<args>)` and panics, naming the function and the decoded errno,
/// unless the call returns zero. Must be expanded inside an `unsafe` context.
macro_rules! libc_checked {
    ($fn_name:ident ( $($arg:expr),* )) => {{
        match libc::$fn_name($($arg),*) {
            0 => {}
            code => panic!("{} failed with {}", stringify!($fn_name), Errno::from_raw(code)),
        }
    }};
}
|
||||
|
||||
impl PthreadRwLock {
|
||||
/// Creates a new `PthreadRwLock` on top of a pointer to a pthread rwlock.
|
||||
///
|
||||
/// # Safety
|
||||
/// `lock` must be non-null. Every unsafe operation will panic in the event of an error.
|
||||
pub unsafe fn new(lock: *mut libc::pthread_rwlock_t) -> Self {
|
||||
unsafe {
|
||||
let mut attrs = MaybeUninit::uninit();
|
||||
libc_checked!(pthread_rwlockattr_init(attrs.as_mut_ptr()));
|
||||
libc_checked!(pthread_rwlockattr_setpshared(
|
||||
attrs.as_mut_ptr(),
|
||||
libc::PTHREAD_PROCESS_SHARED
|
||||
));
|
||||
libc_checked!(pthread_rwlock_init(lock, attrs.as_mut_ptr()));
|
||||
// Safety: POSIX specifies that "any function affecting the attributes
|
||||
// object (including destruction) shall not affect any previously
|
||||
// initialized read-write locks".
|
||||
libc_checked!(pthread_rwlockattr_destroy(attrs.as_mut_ptr()));
|
||||
Self(Some(NonNull::new_unchecked(lock)))
|
||||
}
|
||||
}
|
||||
|
||||
fn inner(&self) -> NonNull<libc::pthread_rwlock_t> {
|
||||
match self.0 {
|
||||
None => {
|
||||
panic!("PthreadRwLock constructed badly - something likely used RawRwLock::INIT")
|
||||
}
|
||||
Some(x) => x,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl lock_api::RawRwLock for PthreadRwLock {
|
||||
type GuardMarker = lock_api::GuardSend;
|
||||
const INIT: Self = Self(None);
|
||||
|
||||
fn try_lock_shared(&self) -> bool {
|
||||
unsafe {
|
||||
let res = libc::pthread_rwlock_tryrdlock(self.inner().as_ptr());
|
||||
match res {
|
||||
0 => true,
|
||||
libc::EAGAIN => false,
|
||||
_ => panic!(
|
||||
"pthread_rwlock_tryrdlock failed with {}",
|
||||
Errno::from_raw(res)
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn try_lock_exclusive(&self) -> bool {
|
||||
unsafe {
|
||||
let res = libc::pthread_rwlock_trywrlock(self.inner().as_ptr());
|
||||
match res {
|
||||
0 => true,
|
||||
libc::EAGAIN => false,
|
||||
_ => panic!("try_wrlock failed with {}", Errno::from_raw(res)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn lock_shared(&self) {
|
||||
unsafe {
|
||||
libc_checked!(pthread_rwlock_rdlock(self.inner().as_ptr()));
|
||||
}
|
||||
}
|
||||
|
||||
fn lock_exclusive(&self) {
|
||||
unsafe {
|
||||
libc_checked!(pthread_rwlock_wrlock(self.inner().as_ptr()));
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn unlock_exclusive(&self) {
|
||||
unsafe {
|
||||
libc_checked!(pthread_rwlock_unlock(self.inner().as_ptr()));
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn unlock_shared(&self) {
|
||||
unsafe {
|
||||
libc_checked!(pthread_rwlock_unlock(self.inner().as_ptr()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -749,7 +749,18 @@ impl<IO: AsyncRead + AsyncWrite + Unpin> PostgresBackend<IO> {
|
||||
trace!("got query {query_string:?}");
|
||||
if let Err(e) = handler.process_query(self, query_string).await {
|
||||
match e {
|
||||
QueryError::Shutdown => return Ok(ProcessMsgResult::Break),
|
||||
err @ QueryError::Shutdown => {
|
||||
// Notify postgres of the connection shutdown at the libpq
|
||||
// protocol level. This avoids postgres having to tell apart
|
||||
// an idle connection and a stale one, which is bug-prone.
|
||||
let shutdown_error = short_error(&err);
|
||||
self.write_message_noflush(&BeMessage::ErrorResponse(
|
||||
&shutdown_error,
|
||||
Some(err.pg_error_code()),
|
||||
))?;
|
||||
|
||||
return Ok(ProcessMsgResult::Break);
|
||||
}
|
||||
QueryError::SimulatedConnectionError => {
|
||||
return Err(QueryError::SimulatedConnectionError);
|
||||
}
|
||||
|
||||
@@ -47,6 +47,7 @@ tracing-subscriber = { workspace = true, features = ["json", "registry"] }
|
||||
tracing-utils.workspace = true
|
||||
rand.workspace = true
|
||||
scopeguard.workspace = true
|
||||
uuid.workspace = true
|
||||
strum.workspace = true
|
||||
strum_macros.workspace = true
|
||||
walkdir.workspace = true
|
||||
|
||||
@@ -12,7 +12,8 @@ use jsonwebtoken::{
|
||||
Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation, decode, encode,
|
||||
};
|
||||
use pem::Pem;
|
||||
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
||||
use serde::{Deserialize, Deserializer, Serialize, de::DeserializeOwned};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::id::TenantId;
|
||||
|
||||
@@ -25,6 +26,11 @@ pub enum Scope {
|
||||
/// Provides access to all data for a specific tenant (specified in `struct Claims` below)
|
||||
// TODO: join these two?
|
||||
Tenant,
|
||||
/// Provides access to all data for a specific tenant, but based on endpoint ID. This token scope
|
||||
/// is only used by compute to fetch the spec for a specific endpoint. The spec contains a Tenant-scoped
|
||||
/// token authorizing access to all data of a tenant, so the spec-fetch API requires a TenantEndpoint
|
||||
/// scope token to ensure that untrusted compute nodes can't fetch spec for arbitrary endpoints.
|
||||
TenantEndpoint,
|
||||
/// Provides blanket access to all tenants on the pageserver plus pageserver-wide APIs.
|
||||
/// Should only be used e.g. for status check/tenant creation/list.
|
||||
PageServerApi,
|
||||
@@ -51,17 +57,43 @@ pub enum Scope {
|
||||
ControllerPeer,
|
||||
}
|
||||
|
||||
fn deserialize_empty_string_as_none_uuid<'de, D>(deserializer: D) -> Result<Option<Uuid>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let opt = Option::<String>::deserialize(deserializer)?;
|
||||
match opt.as_deref() {
|
||||
Some("") => Ok(None),
|
||||
Some(s) => Uuid::parse_str(s)
|
||||
.map(Some)
|
||||
.map_err(serde::de::Error::custom),
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// JWT payload. See docs/authentication.md for the format
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
|
||||
pub struct Claims {
|
||||
#[serde(default)]
|
||||
pub tenant_id: Option<TenantId>,
|
||||
#[serde(
|
||||
default,
|
||||
skip_serializing_if = "Option::is_none",
|
||||
// Neon control plane includes this field as empty in the claims.
|
||||
// Consider it None in those cases.
|
||||
deserialize_with = "deserialize_empty_string_as_none_uuid"
|
||||
)]
|
||||
pub endpoint_id: Option<Uuid>,
|
||||
pub scope: Scope,
|
||||
}
|
||||
|
||||
impl Claims {
|
||||
pub fn new(tenant_id: Option<TenantId>, scope: Scope) -> Self {
|
||||
Self { tenant_id, scope }
|
||||
Self {
|
||||
tenant_id,
|
||||
scope,
|
||||
endpoint_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -212,6 +244,7 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
|
||||
let expected_claims = Claims {
|
||||
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
|
||||
scope: Scope::Tenant,
|
||||
endpoint_id: None,
|
||||
};
|
||||
|
||||
// A test token containing the following payload, signed using TEST_PRIV_KEY_ED25519:
|
||||
@@ -240,6 +273,7 @@ MC4CAQAwBQYDK2VwBCIEID/Drmc1AA6U/znNRWpF3zEGegOATQxfkdWxitcOMsIH
|
||||
let claims = Claims {
|
||||
tenant_id: Some(TenantId::from_str("3d1f7595b468230304e0b73cecbcb081").unwrap()),
|
||||
scope: Scope::Tenant,
|
||||
endpoint_id: None,
|
||||
};
|
||||
|
||||
let pem = pem::parse(TEST_PRIV_KEY_ED25519).unwrap();
|
||||
|
||||
@@ -873,6 +873,22 @@ impl Client {
|
||||
.map_err(Error::ReceiveBody)
|
||||
}
|
||||
|
||||
pub async fn reset_alert_gauges(&self) -> Result<()> {
|
||||
let uri = format!(
|
||||
"{}/hadron-internal/reset_alert_gauges",
|
||||
self.mgmt_api_endpoint
|
||||
);
|
||||
self.start_request(Method::POST, uri)
|
||||
.send()
|
||||
.await
|
||||
.map_err(Error::SendRequest)?
|
||||
.error_from_body()
|
||||
.await?
|
||||
.json()
|
||||
.await
|
||||
.map_err(Error::ReceiveBody)
|
||||
}
|
||||
|
||||
pub async fn wait_lsn(
|
||||
&self,
|
||||
tenant_shard_id: TenantShardId,
|
||||
|
||||
@@ -20,7 +20,8 @@ pub fn check_permission(claims: &Claims, tenant_id: Option<TenantId>) -> Result<
|
||||
| Scope::GenerationsApi
|
||||
| Scope::Infra
|
||||
| Scope::Scrubber
|
||||
| Scope::ControllerPeer,
|
||||
| Scope::ControllerPeer
|
||||
| Scope::TenantEndpoint,
|
||||
_,
|
||||
) => Err(AuthError(
|
||||
format!(
|
||||
|
||||
@@ -813,6 +813,7 @@ impl Timeline {
|
||||
let gc_cutoff_lsn_guard = self.get_applied_gc_cutoff_lsn();
|
||||
let gc_cutoff_planned = {
|
||||
let gc_info = self.gc_info.read().unwrap();
|
||||
info!(cutoffs=?gc_info.cutoffs, applied_cutoff=%*gc_cutoff_lsn_guard, "starting find_lsn_for_timestamp");
|
||||
gc_info.min_cutoff()
|
||||
};
|
||||
// Usually the planned cutoff is newer than the cutoff of the last gc run,
|
||||
|
||||
@@ -1908,20 +1908,16 @@ impl TenantShard {
|
||||
.map_err(LoadLocalTimelineError::ResumeDeletion)?;
|
||||
}
|
||||
|
||||
// Upload the tenant manifest.
|
||||
//
|
||||
// This is uploaded unconditionally on every attach. This prevents races where a stale,
|
||||
// still-alive tenant may modify a past manifest, and a future tenant loads it after this
|
||||
// tenant has acted on it. Uploading a new manifest effectively hands over ownership of the
|
||||
// manifest state. See: <https://databricks.atlassian.net/browse/LKB-165>.
|
||||
// Stash the preloaded tenant manifest, and upload a new manifest if changed.
|
||||
//
|
||||
// NB: this must happen after the tenant is fully populated above. In particular the
|
||||
// offloaded timelines, which are included in the manifest.
|
||||
assert!(
|
||||
self.remote_tenant_manifest.lock().await.is_none(),
|
||||
"tenant manifest set before attach"
|
||||
);
|
||||
self.maybe_upload_tenant_manifest().await?; // always uploads, remote_tenant_manifest is None
|
||||
{
|
||||
let mut guard = self.remote_tenant_manifest.lock().await;
|
||||
assert!(guard.is_none(), "tenant manifest set before preload"); // first populated here
|
||||
*guard = preload.tenant_manifest;
|
||||
}
|
||||
self.maybe_upload_tenant_manifest().await?;
|
||||
|
||||
// The local filesystem contents are a cache of what's in the remote IndexPart;
|
||||
// IndexPart is the source of truth.
|
||||
|
||||
@@ -219,10 +219,6 @@ static char *lfc_path;
|
||||
static uint64 lfc_generation;
|
||||
static FileCacheControl *lfc_ctl;
|
||||
static bool lfc_do_prewarm;
|
||||
static shmem_startup_hook_type prev_shmem_startup_hook;
|
||||
#if PG_VERSION_NUM>=150000
|
||||
static shmem_request_hook_type prev_shmem_request_hook;
|
||||
#endif
|
||||
|
||||
bool lfc_store_prefetch_result;
|
||||
bool lfc_prewarm_update_ws_estimation;
|
||||
@@ -342,18 +338,14 @@ lfc_ensure_opened(void)
|
||||
return true;
|
||||
}
|
||||
|
||||
static void
|
||||
lfc_shmem_startup(void)
|
||||
void
|
||||
LfcShmemInit(void)
|
||||
{
|
||||
bool found;
|
||||
static HASHCTL info;
|
||||
|
||||
if (prev_shmem_startup_hook)
|
||||
{
|
||||
prev_shmem_startup_hook();
|
||||
}
|
||||
|
||||
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
|
||||
if (lfc_max_size <= 0)
|
||||
return;
|
||||
|
||||
lfc_ctl = (FileCacheControl *) ShmemInitStruct("lfc", sizeof(FileCacheControl), &found);
|
||||
if (!found)
|
||||
@@ -398,19 +390,16 @@ lfc_shmem_startup(void)
|
||||
ConditionVariableInit(&lfc_ctl->cv[i]);
|
||||
|
||||
}
|
||||
LWLockRelease(AddinShmemInitLock);
|
||||
}
|
||||
|
||||
static void
|
||||
lfc_shmem_request(void)
|
||||
void
|
||||
LfcShmemRequest(void)
|
||||
{
|
||||
#if PG_VERSION_NUM>=150000
|
||||
if (prev_shmem_request_hook)
|
||||
prev_shmem_request_hook();
|
||||
#endif
|
||||
|
||||
RequestAddinShmemSpace(sizeof(FileCacheControl) + hash_estimate_size(SIZE_MB_TO_CHUNKS(lfc_max_size) + 1, FILE_CACHE_ENRTY_SIZE));
|
||||
RequestNamedLWLockTranche("lfc_lock", 1);
|
||||
if (lfc_max_size > 0)
|
||||
{
|
||||
RequestAddinShmemSpace(sizeof(FileCacheControl) + hash_estimate_size(SIZE_MB_TO_CHUNKS(lfc_max_size) + 1, FILE_CACHE_ENRTY_SIZE));
|
||||
RequestNamedLWLockTranche("lfc_lock", 1);
|
||||
}
|
||||
}
|
||||
|
||||
static bool
|
||||
@@ -642,18 +631,6 @@ lfc_init(void)
|
||||
NULL,
|
||||
NULL,
|
||||
NULL);
|
||||
|
||||
if (lfc_max_size == 0)
|
||||
return;
|
||||
|
||||
prev_shmem_startup_hook = shmem_startup_hook;
|
||||
shmem_startup_hook = lfc_shmem_startup;
|
||||
#if PG_VERSION_NUM>=150000
|
||||
prev_shmem_request_hook = shmem_request_hook;
|
||||
shmem_request_hook = lfc_shmem_request;
|
||||
#else
|
||||
lfc_shmem_request();
|
||||
#endif
|
||||
}
|
||||
|
||||
FileCacheState*
|
||||
|
||||
@@ -90,6 +90,7 @@ typedef struct
|
||||
{
|
||||
char connstring[MAX_SHARDS][MAX_PAGESERVER_CONNSTRING_SIZE];
|
||||
size_t num_shards;
|
||||
size_t stripe_size;
|
||||
} ShardMap;
|
||||
|
||||
/*
|
||||
@@ -110,6 +111,11 @@ typedef struct
|
||||
* has changed since last access, and to detect and retry copying the value if
|
||||
* the postmaster changes the value concurrently. (Postmaster doesn't have a
|
||||
* PGPROC entry and therefore cannot use LWLocks.)
|
||||
*
|
||||
* stripe_size is now also part of ShardMap, although it is defined by separate GUC.
|
||||
* Postgres doesn't provide any mechanism to enforce dependencies between GUCs,
|
||||
* so we have to rely on the order of GUC definitions in the config file:
|
||||
* "neon.stripe_size" should be defined prior to "neon.pageserver_connstring"
|
||||
*/
|
||||
typedef struct
|
||||
{
|
||||
@@ -118,10 +124,6 @@ typedef struct
|
||||
ShardMap shard_map;
|
||||
} PagestoreShmemState;
|
||||
|
||||
#if PG_VERSION_NUM >= 150000
|
||||
static shmem_request_hook_type prev_shmem_request_hook = NULL;
|
||||
#endif
|
||||
static shmem_startup_hook_type prev_shmem_startup_hook;
|
||||
static PagestoreShmemState *pagestore_shared;
|
||||
static uint64 pagestore_local_counter = 0;
|
||||
|
||||
@@ -234,7 +236,10 @@ ParseShardMap(const char *connstr, ShardMap *result)
|
||||
p = sep + 1;
|
||||
}
|
||||
if (result)
|
||||
{
|
||||
result->num_shards = nshards;
|
||||
result->stripe_size = stripe_size;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -295,12 +300,13 @@ AssignPageserverConnstring(const char *newval, void *extra)
|
||||
* last call, terminates all existing connections to all pageservers.
|
||||
*/
|
||||
static void
|
||||
load_shard_map(shardno_t shard_no, char *connstr_p, shardno_t *num_shards_p)
|
||||
load_shard_map(shardno_t shard_no, char *connstr_p, shardno_t *num_shards_p, size_t* stripe_size_p)
|
||||
{
|
||||
uint64 begin_update_counter;
|
||||
uint64 end_update_counter;
|
||||
ShardMap *shard_map = &pagestore_shared->shard_map;
|
||||
shardno_t num_shards;
|
||||
size_t stripe_size;
|
||||
|
||||
/*
|
||||
* Postmaster can update the shared memory values concurrently, in which
|
||||
@@ -315,6 +321,7 @@ load_shard_map(shardno_t shard_no, char *connstr_p, shardno_t *num_shards_p)
|
||||
end_update_counter = pg_atomic_read_u64(&pagestore_shared->end_update_counter);
|
||||
|
||||
num_shards = shard_map->num_shards;
|
||||
stripe_size = shard_map->stripe_size;
|
||||
if (connstr_p && shard_no < MAX_SHARDS)
|
||||
strlcpy(connstr_p, shard_map->connstring[shard_no], MAX_PAGESERVER_CONNSTRING_SIZE);
|
||||
pg_memory_barrier();
|
||||
@@ -349,6 +356,8 @@ load_shard_map(shardno_t shard_no, char *connstr_p, shardno_t *num_shards_p)
|
||||
|
||||
if (num_shards_p)
|
||||
*num_shards_p = num_shards;
|
||||
if (stripe_size_p)
|
||||
*stripe_size_p = stripe_size;
|
||||
}
|
||||
|
||||
#define MB (1024*1024)
|
||||
@@ -357,9 +366,10 @@ shardno_t
|
||||
get_shard_number(BufferTag *tag)
|
||||
{
|
||||
shardno_t n_shards;
|
||||
size_t stripe_size;
|
||||
uint32 hash;
|
||||
|
||||
load_shard_map(0, NULL, &n_shards);
|
||||
load_shard_map(0, NULL, &n_shards, &stripe_size);
|
||||
|
||||
#if PG_MAJORVERSION_NUM < 16
|
||||
hash = murmurhash32(tag->rnode.relNode);
|
||||
@@ -412,7 +422,7 @@ pageserver_connect(shardno_t shard_no, int elevel)
|
||||
* Note that connstr is used both during connection start, and when we
|
||||
* log the successful connection.
|
||||
*/
|
||||
load_shard_map(shard_no, connstr, NULL);
|
||||
load_shard_map(shard_no, connstr, NULL, NULL);
|
||||
|
||||
switch (shard->state)
|
||||
{
|
||||
@@ -1284,18 +1294,12 @@ check_neon_id(char **newval, void **extra, GucSource source)
|
||||
return **newval == '\0' || HexDecodeString(id, *newval, 16);
|
||||
}
|
||||
|
||||
static Size
|
||||
PagestoreShmemSize(void)
|
||||
{
|
||||
return add_size(sizeof(PagestoreShmemState), NeonPerfCountersShmemSize());
|
||||
}
|
||||
|
||||
static bool
|
||||
void
|
||||
PagestoreShmemInit(void)
|
||||
{
|
||||
bool found;
|
||||
|
||||
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
|
||||
pagestore_shared = ShmemInitStruct("libpagestore shared state",
|
||||
sizeof(PagestoreShmemState),
|
||||
&found);
|
||||
@@ -1306,44 +1310,12 @@ PagestoreShmemInit(void)
|
||||
memset(&pagestore_shared->shard_map, 0, sizeof(ShardMap));
|
||||
AssignPageserverConnstring(page_server_connstring, NULL);
|
||||
}
|
||||
|
||||
NeonPerfCountersShmemInit();
|
||||
|
||||
LWLockRelease(AddinShmemInitLock);
|
||||
return found;
|
||||
}
|
||||
|
||||
static void
|
||||
pagestore_shmem_startup_hook(void)
|
||||
void
|
||||
PagestoreShmemRequest(void)
|
||||
{
|
||||
if (prev_shmem_startup_hook)
|
||||
prev_shmem_startup_hook();
|
||||
|
||||
PagestoreShmemInit();
|
||||
}
|
||||
|
||||
static void
|
||||
pagestore_shmem_request(void)
|
||||
{
|
||||
#if PG_VERSION_NUM >= 150000
|
||||
if (prev_shmem_request_hook)
|
||||
prev_shmem_request_hook();
|
||||
#endif
|
||||
|
||||
RequestAddinShmemSpace(PagestoreShmemSize());
|
||||
}
|
||||
|
||||
static void
|
||||
pagestore_prepare_shmem(void)
|
||||
{
|
||||
#if PG_VERSION_NUM >= 150000
|
||||
prev_shmem_request_hook = shmem_request_hook;
|
||||
shmem_request_hook = pagestore_shmem_request;
|
||||
#else
|
||||
pagestore_shmem_request();
|
||||
#endif
|
||||
prev_shmem_startup_hook = shmem_startup_hook;
|
||||
shmem_startup_hook = pagestore_shmem_startup_hook;
|
||||
RequestAddinShmemSpace(sizeof(PagestoreShmemState));
|
||||
}
|
||||
|
||||
/*
|
||||
@@ -1352,8 +1324,6 @@ pagestore_prepare_shmem(void)
|
||||
void
|
||||
pg_init_libpagestore(void)
|
||||
{
|
||||
pagestore_prepare_shmem();
|
||||
|
||||
DefineCustomStringVariable("neon.pageserver_connstring",
|
||||
"connection string to the page server",
|
||||
NULL,
|
||||
@@ -1504,8 +1474,6 @@ pg_init_libpagestore(void)
|
||||
0,
|
||||
NULL, NULL, NULL);
|
||||
|
||||
relsize_hash_init();
|
||||
|
||||
if (page_server != NULL)
|
||||
neon_log(ERROR, "libpagestore already loaded");
|
||||
|
||||
|
||||
101
pgxn/neon/neon.c
101
pgxn/neon/neon.c
@@ -22,6 +22,7 @@
|
||||
#include "replication/slot.h"
|
||||
#include "replication/walsender.h"
|
||||
#include "storage/proc.h"
|
||||
#include "storage/ipc.h"
|
||||
#include "funcapi.h"
|
||||
#include "access/htup_details.h"
|
||||
#include "utils/builtins.h"
|
||||
@@ -59,11 +60,15 @@ static ExecutorEnd_hook_type prev_ExecutorEnd = NULL;
|
||||
static void neon_ExecutorStart(QueryDesc *queryDesc, int eflags);
|
||||
static void neon_ExecutorEnd(QueryDesc *queryDesc);
|
||||
|
||||
#if PG_MAJORVERSION_NUM >= 16
|
||||
static shmem_startup_hook_type prev_shmem_startup_hook;
|
||||
|
||||
static void neon_shmem_startup_hook(void);
|
||||
static void neon_shmem_request_hook(void);
|
||||
|
||||
#if PG_MAJORVERSION_NUM >= 15
|
||||
static shmem_request_hook_type prev_shmem_request_hook = NULL;
|
||||
#endif
|
||||
|
||||
|
||||
#if PG_MAJORVERSION_NUM >= 17
|
||||
uint32 WAIT_EVENT_NEON_LFC_MAINTENANCE;
|
||||
uint32 WAIT_EVENT_NEON_LFC_READ;
|
||||
@@ -450,15 +455,44 @@ _PG_init(void)
|
||||
*/
|
||||
#if PG_VERSION_NUM >= 160000
|
||||
load_file("$libdir/neon_rmgr", false);
|
||||
|
||||
prev_shmem_startup_hook = shmem_startup_hook;
|
||||
shmem_startup_hook = neon_shmem_startup_hook;
|
||||
#endif
|
||||
|
||||
/* dummy call to a Rust function in the communicator library, to check that it works */
|
||||
(void) communicator_dummy(123);
|
||||
|
||||
/*
|
||||
* Initializing a pre-loaded Postgres extension happens in three stages:
|
||||
*
|
||||
* 1. _PG_init() is called early at postmaster startup. In this stage, no
|
||||
* shared memory has been allocated yet. Core Postgres GUCs have been
|
||||
* initialized from the config files, but notably, MaxBackends has not
|
||||
* been calculated yet. In this stage, we must register any extension GUCs
|
||||
* and can do other early initialization that doesn't depend on shared
|
||||
* memory. In this stage we must also register "shmem request" and
|
||||
* "shmem startup" hooks, to be called in stages 2 and 3.
|
||||
*
|
||||
* 2. After MaxBackends has been calculated, the "shmem request" hooks
|
||||
* are called. The hooks can reserve shared memory by calling
|
||||
* RequestAddinShmemSpace and RequestNamedLWLockTranche(). The "shmem
|
||||
* request hooks" are a new mechanism in Postgres v15. In v14 and
|
||||
* below, you had to make those Requests in stage 1 already, which
|
||||
* means they could not depend on MaxBackends. (See hack in
|
||||
* NeonPerfCountersShmemRequest())
|
||||
*
|
||||
* 3. After some more runtime-computed GUCs that affect the amount of
|
||||
* shared memory needed have been calculated, the "shmem startup" hooks
|
||||
* are called. In this stage, we allocate any shared memory, LWLocks
|
||||
* and other shared resources.
|
||||
*
|
||||
* Here, in the 'neon' extension, we register just one shmem request hook
|
||||
* and one startup hook, which call into functions in all the subsystems
|
||||
* that are part of the extension. On v14, the ShmemRequest functions are
|
||||
* called in stage 1, and on v15 onwards they are called in stage 2.
|
||||
*/
|
||||
|
||||
/* Stage 1: Define GUCs, and other early initialization */
|
||||
pg_init_libpagestore();
|
||||
relsize_hash_init();
|
||||
lfc_init();
|
||||
pg_init_walproposer();
|
||||
init_lwlsncache();
|
||||
@@ -561,6 +595,22 @@ _PG_init(void)
|
||||
|
||||
ReportSearchPath();
|
||||
|
||||
/*
|
||||
* Register initialization hooks for stage 2. (On v14, there's no "shmem
|
||||
* request" hooks, so call the ShmemRequest functions immediately.)
|
||||
*/
|
||||
#if PG_VERSION_NUM >= 150000
|
||||
prev_shmem_request_hook = shmem_request_hook;
|
||||
shmem_request_hook = neon_shmem_request_hook;
|
||||
#else
|
||||
neon_shmem_request_hook();
|
||||
#endif
|
||||
|
||||
/* Register hooks for stage 3 */
|
||||
prev_shmem_startup_hook = shmem_startup_hook;
|
||||
shmem_startup_hook = neon_shmem_startup_hook;
|
||||
|
||||
/* Other misc initialization */
|
||||
prev_ExecutorStart = ExecutorStart_hook;
|
||||
ExecutorStart_hook = neon_ExecutorStart;
|
||||
prev_ExecutorEnd = ExecutorEnd_hook;
|
||||
@@ -646,7 +696,34 @@ approximate_working_set_size(PG_FUNCTION_ARGS)
|
||||
PG_RETURN_INT32(dc);
|
||||
}
|
||||
|
||||
#if PG_MAJORVERSION_NUM >= 16
|
||||
/*
|
||||
* Initialization stage 2: make requests for the amount of shared memory we
|
||||
* will need.
|
||||
*
|
||||
* For a high-level explanation of the initialization process, see _PG_init().
|
||||
*/
|
||||
static void
|
||||
neon_shmem_request_hook(void)
|
||||
{
|
||||
#if PG_VERSION_NUM >= 150000
|
||||
if (prev_shmem_request_hook)
|
||||
prev_shmem_request_hook();
|
||||
#endif
|
||||
|
||||
LfcShmemRequest();
|
||||
NeonPerfCountersShmemRequest();
|
||||
PagestoreShmemRequest();
|
||||
RelsizeCacheShmemRequest();
|
||||
WalproposerShmemRequest();
|
||||
LwLsnCacheShmemRequest();
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* Initialization stage 3: Initialize shared memory.
|
||||
*
|
||||
* For a high-level explanation of the initialization process, see _PG_init().
|
||||
*/
|
||||
static void
|
||||
neon_shmem_startup_hook(void)
|
||||
{
|
||||
@@ -654,6 +731,15 @@ neon_shmem_startup_hook(void)
|
||||
if (prev_shmem_startup_hook)
|
||||
prev_shmem_startup_hook();
|
||||
|
||||
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
|
||||
|
||||
LfcShmemInit();
|
||||
NeonPerfCountersShmemInit();
|
||||
PagestoreShmemInit();
|
||||
RelsizeCacheShmemInit();
|
||||
WalproposerShmemInit();
|
||||
LwLsnCacheShmemInit();
|
||||
|
||||
#if PG_MAJORVERSION_NUM >= 17
|
||||
WAIT_EVENT_NEON_LFC_MAINTENANCE = WaitEventExtensionNew("Neon/FileCache_Maintenance");
|
||||
WAIT_EVENT_NEON_LFC_READ = WaitEventExtensionNew("Neon/FileCache_Read");
|
||||
@@ -666,8 +752,9 @@ neon_shmem_startup_hook(void)
|
||||
WAIT_EVENT_NEON_PS_READ = WaitEventExtensionNew("Neon/PS_ReadIO");
|
||||
WAIT_EVENT_NEON_WAL_DL = WaitEventExtensionNew("Neon/WAL_Download");
|
||||
#endif
|
||||
|
||||
LWLockRelease(AddinShmemInitLock);
|
||||
}
|
||||
#endif
|
||||
|
||||
/*
|
||||
* ExecutorStart hook: start up tracking if needed
|
||||
|
||||
@@ -70,4 +70,19 @@ extern PGDLLEXPORT void WalProposerSync(int argc, char *argv[]);
|
||||
extern PGDLLEXPORT void WalProposerMain(Datum main_arg);
|
||||
extern PGDLLEXPORT void LogicalSlotsMonitorMain(Datum main_arg);
|
||||
|
||||
extern void LfcShmemRequest(void);
|
||||
extern void PagestoreShmemRequest(void);
|
||||
extern void RelsizeCacheShmemRequest(void);
|
||||
extern void WalproposerShmemRequest(void);
|
||||
extern void LwLsnCacheShmemRequest(void);
|
||||
extern void NeonPerfCountersShmemRequest(void);
|
||||
|
||||
extern void LfcShmemInit(void);
|
||||
extern void PagestoreShmemInit(void);
|
||||
extern void RelsizeCacheShmemInit(void);
|
||||
extern void WalproposerShmemInit(void);
|
||||
extern void LwLsnCacheShmemInit(void);
|
||||
extern void NeonPerfCountersShmemInit(void);
|
||||
|
||||
|
||||
#endif /* NEON_H */
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "postgres.h"
|
||||
|
||||
#include "neon.h"
|
||||
#include "neon_lwlsncache.h"
|
||||
|
||||
#include "miscadmin.h"
|
||||
@@ -81,14 +82,6 @@ static set_max_lwlsn_hook_type prev_set_max_lwlsn_hook = NULL;
|
||||
static set_lwlsn_relation_hook_type prev_set_lwlsn_relation_hook = NULL;
|
||||
static set_lwlsn_db_hook_type prev_set_lwlsn_db_hook = NULL;
|
||||
|
||||
static shmem_startup_hook_type prev_shmem_startup_hook;
|
||||
|
||||
#if PG_VERSION_NUM >= 150000
|
||||
static shmem_request_hook_type prev_shmem_request_hook;
|
||||
#endif
|
||||
|
||||
static void shmemrequest(void);
|
||||
static void shmeminit(void);
|
||||
static void neon_set_max_lwlsn(XLogRecPtr lsn);
|
||||
|
||||
void
|
||||
@@ -99,16 +92,6 @@ init_lwlsncache(void)
|
||||
|
||||
lwlc_register_gucs();
|
||||
|
||||
prev_shmem_startup_hook = shmem_startup_hook;
|
||||
shmem_startup_hook = shmeminit;
|
||||
|
||||
#if PG_VERSION_NUM >= 150000
|
||||
prev_shmem_request_hook = shmem_request_hook;
|
||||
shmem_request_hook = shmemrequest;
|
||||
#else
|
||||
shmemrequest();
|
||||
#endif
|
||||
|
||||
prev_set_lwlsn_block_range_hook = set_lwlsn_block_range_hook;
|
||||
set_lwlsn_block_range_hook = neon_set_lwlsn_block_range;
|
||||
prev_set_lwlsn_block_v_hook = set_lwlsn_block_v_hook;
|
||||
@@ -124,20 +107,19 @@ init_lwlsncache(void)
|
||||
}
|
||||
|
||||
|
||||
static void shmemrequest(void) {
|
||||
void
|
||||
LwLsnCacheShmemRequest(void)
|
||||
{
|
||||
Size requested_size = sizeof(LwLsnCacheCtl);
|
||||
|
||||
|
||||
requested_size += hash_estimate_size(lwlsn_cache_size, sizeof(LastWrittenLsnCacheEntry));
|
||||
|
||||
RequestAddinShmemSpace(requested_size);
|
||||
|
||||
#if PG_VERSION_NUM >= 150000
|
||||
if (prev_shmem_request_hook)
|
||||
prev_shmem_request_hook();
|
||||
#endif
|
||||
}
|
||||
|
||||
static void shmeminit(void) {
|
||||
void
|
||||
LwLsnCacheShmemInit(void)
|
||||
{
|
||||
static HASHCTL info;
|
||||
bool found;
|
||||
if (lwlsn_cache_size > 0)
|
||||
@@ -157,9 +139,6 @@ static void shmeminit(void) {
|
||||
}
|
||||
dlist_init(&LwLsnCache->lastWrittenLsnLRU);
|
||||
LwLsnCache->maxLastWrittenLsn = GetRedoRecPtr();
|
||||
if (prev_shmem_startup_hook) {
|
||||
prev_shmem_startup_hook();
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
@@ -17,22 +17,32 @@
|
||||
#include "storage/shmem.h"
|
||||
#include "utils/builtins.h"
|
||||
|
||||
#include "neon.h"
|
||||
#include "neon_perf_counters.h"
|
||||
#include "neon_pgversioncompat.h"
|
||||
|
||||
neon_per_backend_counters *neon_per_backend_counters_shared;
|
||||
|
||||
Size
|
||||
NeonPerfCountersShmemSize(void)
|
||||
void
|
||||
NeonPerfCountersShmemRequest(void)
|
||||
{
|
||||
Size size = 0;
|
||||
|
||||
size = add_size(size, mul_size(NUM_NEON_PERF_COUNTER_SLOTS,
|
||||
sizeof(neon_per_backend_counters)));
|
||||
|
||||
return size;
|
||||
Size size;
|
||||
#if PG_MAJORVERSION_NUM < 15
|
||||
/* Hack: in PG14 MaxBackends is not initialized at the time of calling NeonPerfCountersShmemRequest function.
|
||||
* Do it ourselves and then undo to prevent assertion failure
|
||||
*/
|
||||
Assert(MaxBackends == 0); /* not initialized yet */
|
||||
InitializeMaxBackends();
|
||||
size = mul_size(NUM_NEON_PERF_COUNTER_SLOTS, sizeof(neon_per_backend_counters));
|
||||
MaxBackends = 0;
|
||||
#else
|
||||
size = mul_size(NUM_NEON_PERF_COUNTER_SLOTS, sizeof(neon_per_backend_counters));
|
||||
#endif
|
||||
RequestAddinShmemSpace(size);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void
|
||||
NeonPerfCountersShmemInit(void)
|
||||
{
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
*/
|
||||
#include "postgres.h"
|
||||
|
||||
#include "neon.h"
|
||||
#include "neon_pgversioncompat.h"
|
||||
|
||||
#include "pagestore_client.h"
|
||||
@@ -49,32 +50,23 @@ typedef struct
|
||||
* algorithm */
|
||||
} RelSizeHashControl;
|
||||
|
||||
static HTAB *relsize_hash;
|
||||
static LWLockId relsize_lock;
|
||||
static int relsize_hash_size;
|
||||
static RelSizeHashControl* relsize_ctl;
|
||||
static shmem_startup_hook_type prev_shmem_startup_hook = NULL;
|
||||
#if PG_VERSION_NUM >= 150000
|
||||
static shmem_request_hook_type prev_shmem_request_hook = NULL;
|
||||
static void relsize_shmem_request(void);
|
||||
#endif
|
||||
|
||||
/*
|
||||
* Size of a cache entry is 36 bytes. So this default will take about 2.3 MB,
|
||||
* which seems reasonable.
|
||||
*/
|
||||
#define DEFAULT_RELSIZE_HASH_SIZE (64 * 1024)
|
||||
|
||||
static void
|
||||
neon_smgr_shmem_startup(void)
|
||||
static HTAB *relsize_hash;
|
||||
static LWLockId relsize_lock;
|
||||
static int relsize_hash_size = DEFAULT_RELSIZE_HASH_SIZE;
|
||||
static RelSizeHashControl* relsize_ctl;
|
||||
|
||||
void
|
||||
RelsizeCacheShmemInit(void)
|
||||
{
|
||||
static HASHCTL info;
|
||||
bool found;
|
||||
|
||||
if (prev_shmem_startup_hook)
|
||||
prev_shmem_startup_hook();
|
||||
|
||||
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
|
||||
relsize_ctl = (RelSizeHashControl *) ShmemInitStruct("relsize_hash", sizeof(RelSizeHashControl), &found);
|
||||
if (!found)
|
||||
{
|
||||
@@ -85,7 +77,6 @@ neon_smgr_shmem_startup(void)
|
||||
relsize_hash_size, relsize_hash_size,
|
||||
&info,
|
||||
HASH_ELEM | HASH_BLOBS);
|
||||
LWLockRelease(AddinShmemInitLock);
|
||||
relsize_ctl->size = 0;
|
||||
relsize_ctl->hits = 0;
|
||||
relsize_ctl->misses = 0;
|
||||
@@ -242,34 +233,15 @@ relsize_hash_init(void)
|
||||
PGC_POSTMASTER,
|
||||
0,
|
||||
NULL, NULL, NULL);
|
||||
|
||||
if (relsize_hash_size > 0)
|
||||
{
|
||||
#if PG_VERSION_NUM >= 150000
|
||||
prev_shmem_request_hook = shmem_request_hook;
|
||||
shmem_request_hook = relsize_shmem_request;
|
||||
#else
|
||||
RequestAddinShmemSpace(hash_estimate_size(relsize_hash_size, sizeof(RelSizeEntry)));
|
||||
RequestNamedLWLockTranche("neon_relsize", 1);
|
||||
#endif
|
||||
|
||||
prev_shmem_startup_hook = shmem_startup_hook;
|
||||
shmem_startup_hook = neon_smgr_shmem_startup;
|
||||
}
|
||||
}
|
||||
|
||||
#if PG_VERSION_NUM >= 150000
|
||||
/*
|
||||
* shmem_request hook: request additional shared resources. We'll allocate or
|
||||
* attach to the shared resources in neon_smgr_shmem_startup().
|
||||
*/
|
||||
static void
|
||||
relsize_shmem_request(void)
|
||||
void
|
||||
RelsizeCacheShmemRequest(void)
|
||||
{
|
||||
if (prev_shmem_request_hook)
|
||||
prev_shmem_request_hook();
|
||||
|
||||
RequestAddinShmemSpace(sizeof(RelSizeHashControl) + hash_estimate_size(relsize_hash_size, sizeof(RelSizeEntry)));
|
||||
RequestNamedLWLockTranche("neon_relsize", 1);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -83,10 +83,8 @@ static XLogRecPtr standby_flush_lsn = InvalidXLogRecPtr;
|
||||
static XLogRecPtr standby_apply_lsn = InvalidXLogRecPtr;
|
||||
static HotStandbyFeedback agg_hs_feedback;
|
||||
|
||||
static void nwp_shmem_startup_hook(void);
|
||||
static void nwp_register_gucs(void);
|
||||
static void assign_neon_safekeepers(const char *newval, void *extra);
|
||||
static void nwp_prepare_shmem(void);
|
||||
static uint64 backpressure_lag_impl(void);
|
||||
static uint64 startup_backpressure_wrap(void);
|
||||
static bool backpressure_throttling_impl(void);
|
||||
@@ -99,11 +97,6 @@ static TimestampTz walprop_pg_get_current_timestamp(WalProposer *wp);
|
||||
static void walprop_pg_load_libpqwalreceiver(void);
|
||||
|
||||
static process_interrupts_callback_t PrevProcessInterruptsCallback = NULL;
|
||||
static shmem_startup_hook_type prev_shmem_startup_hook_type;
|
||||
#if PG_VERSION_NUM >= 150000
|
||||
static shmem_request_hook_type prev_shmem_request_hook = NULL;
|
||||
static void walproposer_shmem_request(void);
|
||||
#endif
|
||||
static void WalproposerShmemInit_SyncSafekeeper(void);
|
||||
|
||||
|
||||
@@ -193,8 +186,6 @@ pg_init_walproposer(void)
|
||||
|
||||
nwp_register_gucs();
|
||||
|
||||
nwp_prepare_shmem();
|
||||
|
||||
delay_backend_us = &startup_backpressure_wrap;
|
||||
PrevProcessInterruptsCallback = ProcessInterruptsCallback;
|
||||
ProcessInterruptsCallback = backpressure_throttling_impl;
|
||||
@@ -494,12 +485,11 @@ WalproposerShmemSize(void)
|
||||
return sizeof(WalproposerShmemState);
|
||||
}
|
||||
|
||||
static bool
|
||||
void
|
||||
WalproposerShmemInit(void)
|
||||
{
|
||||
bool found;
|
||||
|
||||
LWLockAcquire(AddinShmemInitLock, LW_EXCLUSIVE);
|
||||
walprop_shared = ShmemInitStruct("Walproposer shared state",
|
||||
sizeof(WalproposerShmemState),
|
||||
&found);
|
||||
@@ -517,9 +507,6 @@ WalproposerShmemInit(void)
|
||||
pg_atomic_init_u64(&walprop_shared->wal_rate_limiter.last_recorded_time_us, 0);
|
||||
/* END_HADRON */
|
||||
}
|
||||
LWLockRelease(AddinShmemInitLock);
|
||||
|
||||
return found;
|
||||
}
|
||||
|
||||
static void
|
||||
@@ -623,42 +610,15 @@ walprop_register_bgworker(void)
|
||||
|
||||
/* shmem handling */
|
||||
|
||||
static void
|
||||
nwp_prepare_shmem(void)
|
||||
{
|
||||
#if PG_VERSION_NUM >= 150000
|
||||
prev_shmem_request_hook = shmem_request_hook;
|
||||
shmem_request_hook = walproposer_shmem_request;
|
||||
#else
|
||||
RequestAddinShmemSpace(WalproposerShmemSize());
|
||||
#endif
|
||||
prev_shmem_startup_hook_type = shmem_startup_hook;
|
||||
shmem_startup_hook = nwp_shmem_startup_hook;
|
||||
}
|
||||
|
||||
#if PG_VERSION_NUM >= 150000
|
||||
/*
|
||||
* shmem_request hook: request additional shared resources. We'll allocate or
|
||||
* attach to the shared resources in nwp_shmem_startup_hook().
|
||||
* attach to the shared resources in WalproposerShmemInit().
|
||||
*/
|
||||
static void
|
||||
walproposer_shmem_request(void)
|
||||
void
|
||||
WalproposerShmemRequest(void)
|
||||
{
|
||||
if (prev_shmem_request_hook)
|
||||
prev_shmem_request_hook();
|
||||
|
||||
RequestAddinShmemSpace(WalproposerShmemSize());
|
||||
}
|
||||
#endif
|
||||
|
||||
static void
|
||||
nwp_shmem_startup_hook(void)
|
||||
{
|
||||
if (prev_shmem_startup_hook_type)
|
||||
prev_shmem_startup_hook_type();
|
||||
|
||||
WalproposerShmemInit();
|
||||
}
|
||||
|
||||
WalproposerShmemState *
|
||||
GetWalpropShmemState(void)
|
||||
|
||||
@@ -98,6 +98,7 @@ tracing-log.workspace = true
|
||||
tracing-opentelemetry.workspace = true
|
||||
try-lock.workspace = true
|
||||
typed-json.workspace = true
|
||||
type-safe-id = { version = "0.3.3", features = ["serde"] }
|
||||
url.workspace = true
|
||||
urlencoding.workspace = true
|
||||
utils.workspace = true
|
||||
|
||||
@@ -26,6 +26,7 @@ use utils::project_git_version;
|
||||
use utils::sentry_init::init_sentry;
|
||||
|
||||
use crate::context::RequestContext;
|
||||
use crate::id::{ClientConnId, RequestId};
|
||||
use crate::metrics::{Metrics, ThreadPoolMetrics};
|
||||
use crate::pglb::TlsRequired;
|
||||
use crate::pqproto::FeStartupPacket;
|
||||
@@ -219,7 +220,8 @@ pub(super) async fn task_main(
|
||||
{
|
||||
let (socket, peer_addr) = accept_result?;
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let conn_id = ClientConnId::new();
|
||||
let session_id = RequestId::from_uuid(conn_id.uuid());
|
||||
let tls_config = Arc::clone(&tls_config);
|
||||
let dest_suffix = Arc::clone(&dest_suffix);
|
||||
let compute_tls_config = compute_tls_config.clone();
|
||||
@@ -231,6 +233,7 @@ pub(super) async fn task_main(
|
||||
.context("failed to set socket option")?;
|
||||
|
||||
let ctx = RequestContext::new(
|
||||
conn_id,
|
||||
session_id,
|
||||
ConnectionInfo {
|
||||
addr: peer_addr,
|
||||
@@ -252,7 +255,7 @@ pub(super) async fn task_main(
|
||||
// Acknowledge that the task has finished with an error.
|
||||
error!("per-client task finished with an error: {e:#}");
|
||||
})
|
||||
.instrument(tracing::info_span!("handle_client", ?session_id)),
|
||||
.instrument(tracing::info_span!("handle_client", %session_id)),
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
283
proxy/src/cache/project_info.rs
vendored
283
proxy/src/cache/project_info.rs
vendored
@@ -10,6 +10,7 @@ use tokio::time::Instant;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::config::ProjectInfoCacheOptions;
|
||||
use crate::control_plane::messages::{ControlPlaneErrorMessage, Reason};
|
||||
use crate::control_plane::{EndpointAccessControl, RoleAccessControl};
|
||||
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt, RoleNameInt};
|
||||
use crate::types::{EndpointId, RoleName};
|
||||
@@ -36,22 +37,37 @@ impl<T> Entry<T> {
|
||||
}
|
||||
|
||||
pub(crate) fn get(&self) -> Option<&T> {
|
||||
(self.expires_at > Instant::now()).then_some(&self.value)
|
||||
(!self.is_expired()).then_some(&self.value)
|
||||
}
|
||||
|
||||
fn is_expired(&self) -> bool {
|
||||
self.expires_at <= Instant::now()
|
||||
}
|
||||
}
|
||||
|
||||
struct EndpointInfo {
|
||||
role_controls: HashMap<RoleNameInt, Entry<RoleAccessControl>>,
|
||||
controls: Option<Entry<EndpointAccessControl>>,
|
||||
role_controls: HashMap<RoleNameInt, Entry<ControlPlaneResult<RoleAccessControl>>>,
|
||||
controls: Option<Entry<ControlPlaneResult<EndpointAccessControl>>>,
|
||||
}
|
||||
|
||||
type ControlPlaneResult<T> = Result<T, Box<ControlPlaneErrorMessage>>;
|
||||
|
||||
impl EndpointInfo {
|
||||
pub(crate) fn get_role_secret(&self, role_name: RoleNameInt) -> Option<RoleAccessControl> {
|
||||
self.role_controls.get(&role_name)?.get().cloned()
|
||||
pub(crate) fn get_role_secret_with_ttl(
|
||||
&self,
|
||||
role_name: RoleNameInt,
|
||||
) -> Option<(ControlPlaneResult<RoleAccessControl>, Duration)> {
|
||||
let entry = self.role_controls.get(&role_name)?;
|
||||
let ttl = entry.expires_at - Instant::now();
|
||||
Some((entry.get()?.clone(), ttl))
|
||||
}
|
||||
|
||||
pub(crate) fn get_controls(&self) -> Option<EndpointAccessControl> {
|
||||
self.controls.as_ref()?.get().cloned()
|
||||
pub(crate) fn get_controls_with_ttl(
|
||||
&self,
|
||||
) -> Option<(ControlPlaneResult<EndpointAccessControl>, Duration)> {
|
||||
let entry = self.controls.as_ref()?;
|
||||
let ttl = entry.expires_at - Instant::now();
|
||||
Some((entry.get()?.clone(), ttl))
|
||||
}
|
||||
|
||||
pub(crate) fn invalidate_endpoint(&mut self) {
|
||||
@@ -153,28 +169,28 @@ impl ProjectInfoCacheImpl {
|
||||
self.cache.get(&endpoint_id)
|
||||
}
|
||||
|
||||
pub(crate) fn get_role_secret(
|
||||
pub(crate) fn get_role_secret_with_ttl(
|
||||
&self,
|
||||
endpoint_id: &EndpointId,
|
||||
role_name: &RoleName,
|
||||
) -> Option<RoleAccessControl> {
|
||||
) -> Option<(ControlPlaneResult<RoleAccessControl>, Duration)> {
|
||||
let role_name = RoleNameInt::get(role_name)?;
|
||||
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
|
||||
endpoint_info.get_role_secret(role_name)
|
||||
endpoint_info.get_role_secret_with_ttl(role_name)
|
||||
}
|
||||
|
||||
pub(crate) fn get_endpoint_access(
|
||||
pub(crate) fn get_endpoint_access_with_ttl(
|
||||
&self,
|
||||
endpoint_id: &EndpointId,
|
||||
) -> Option<EndpointAccessControl> {
|
||||
) -> Option<(ControlPlaneResult<EndpointAccessControl>, Duration)> {
|
||||
let endpoint_info = self.get_endpoint_cache(endpoint_id)?;
|
||||
endpoint_info.get_controls()
|
||||
endpoint_info.get_controls_with_ttl()
|
||||
}
|
||||
|
||||
pub(crate) fn insert_endpoint_access(
|
||||
&self,
|
||||
account_id: Option<AccountIdInt>,
|
||||
project_id: ProjectIdInt,
|
||||
project_id: Option<ProjectIdInt>,
|
||||
endpoint_id: EndpointIdInt,
|
||||
role_name: RoleNameInt,
|
||||
controls: EndpointAccessControl,
|
||||
@@ -183,26 +199,89 @@ impl ProjectInfoCacheImpl {
|
||||
if let Some(account_id) = account_id {
|
||||
self.insert_account2endpoint(account_id, endpoint_id);
|
||||
}
|
||||
self.insert_project2endpoint(project_id, endpoint_id);
|
||||
if let Some(project_id) = project_id {
|
||||
self.insert_project2endpoint(project_id, endpoint_id);
|
||||
}
|
||||
|
||||
if self.cache.len() >= self.config.size {
|
||||
// If there are too many entries, wait until the next gc cycle.
|
||||
return;
|
||||
}
|
||||
|
||||
let controls = Entry::new(controls, self.config.ttl);
|
||||
let role_controls = Entry::new(role_controls, self.config.ttl);
|
||||
debug!(
|
||||
key = &*endpoint_id,
|
||||
"created a cache entry for endpoint access"
|
||||
);
|
||||
|
||||
let controls = Some(Entry::new(Ok(controls), self.config.ttl));
|
||||
let role_controls = Entry::new(Ok(role_controls), self.config.ttl);
|
||||
|
||||
match self.cache.entry(endpoint_id) {
|
||||
clashmap::Entry::Vacant(e) => {
|
||||
e.insert(EndpointInfo {
|
||||
role_controls: HashMap::from_iter([(role_name, role_controls)]),
|
||||
controls: Some(controls),
|
||||
controls,
|
||||
});
|
||||
}
|
||||
clashmap::Entry::Occupied(mut e) => {
|
||||
let ep = e.get_mut();
|
||||
ep.controls = Some(controls);
|
||||
ep.controls = controls;
|
||||
if ep.role_controls.len() < self.config.max_roles {
|
||||
ep.role_controls.insert(role_name, role_controls);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn insert_endpoint_access_err(
|
||||
&self,
|
||||
endpoint_id: EndpointIdInt,
|
||||
role_name: RoleNameInt,
|
||||
msg: Box<ControlPlaneErrorMessage>,
|
||||
ttl: Option<Duration>,
|
||||
) {
|
||||
if self.cache.len() >= self.config.size {
|
||||
// If there are too many entries, wait until the next gc cycle.
|
||||
return;
|
||||
}
|
||||
|
||||
debug!(
|
||||
key = &*endpoint_id,
|
||||
"created a cache entry for an endpoint access error"
|
||||
);
|
||||
|
||||
let ttl = ttl.unwrap_or(self.config.ttl);
|
||||
|
||||
let controls = if msg.get_reason() == Reason::RoleProtected {
|
||||
// RoleProtected is the only role-specific error that control plane can give us.
|
||||
// If a given role name does not exist, it still returns a successful response,
|
||||
// just with an empty secret.
|
||||
None
|
||||
} else {
|
||||
// We can cache all the other errors in EndpointInfo.controls,
|
||||
// because they don't depend on what role name we pass to control plane.
|
||||
Some(Entry::new(Err(msg.clone()), ttl))
|
||||
};
|
||||
|
||||
let role_controls = Entry::new(Err(msg), ttl);
|
||||
|
||||
match self.cache.entry(endpoint_id) {
|
||||
clashmap::Entry::Vacant(e) => {
|
||||
e.insert(EndpointInfo {
|
||||
role_controls: HashMap::from_iter([(role_name, role_controls)]),
|
||||
controls,
|
||||
});
|
||||
}
|
||||
clashmap::Entry::Occupied(mut e) => {
|
||||
let ep = e.get_mut();
|
||||
if let Some(entry) = &ep.controls
|
||||
&& !entry.is_expired()
|
||||
&& entry.value.is_ok()
|
||||
{
|
||||
// If we have cached non-expired, non-error controls, keep them.
|
||||
} else {
|
||||
ep.controls = controls;
|
||||
}
|
||||
if ep.role_controls.len() < self.config.max_roles {
|
||||
ep.role_controls.insert(role_name, role_controls);
|
||||
}
|
||||
@@ -245,7 +324,7 @@ impl ProjectInfoCacheImpl {
|
||||
return;
|
||||
};
|
||||
|
||||
if role_controls.get().expires_at <= Instant::now() {
|
||||
if role_controls.get().is_expired() {
|
||||
role_controls.remove();
|
||||
}
|
||||
}
|
||||
@@ -287,10 +366,9 @@ mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::*;
|
||||
use crate::control_plane::messages::EndpointRateLimitConfig;
|
||||
use crate::control_plane::messages::{Details, EndpointRateLimitConfig, ErrorInfo, Status};
|
||||
use crate::control_plane::{AccessBlockerFlags, AuthSecret};
|
||||
use crate::scram::ServerSecret;
|
||||
use crate::types::ProjectId;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_project_info_cache_settings() {
|
||||
@@ -301,9 +379,9 @@ mod tests {
|
||||
ttl: Duration::from_secs(1),
|
||||
gc_interval: Duration::from_secs(600),
|
||||
});
|
||||
let project_id: ProjectId = "project".into();
|
||||
let project_id: Option<ProjectIdInt> = Some(ProjectIdInt::from(&"project".into()));
|
||||
let endpoint_id: EndpointId = "endpoint".into();
|
||||
let account_id: Option<AccountIdInt> = None;
|
||||
let account_id = None;
|
||||
|
||||
let user1: RoleName = "user1".into();
|
||||
let user2: RoleName = "user2".into();
|
||||
@@ -316,7 +394,7 @@ mod tests {
|
||||
|
||||
cache.insert_endpoint_access(
|
||||
account_id,
|
||||
(&project_id).into(),
|
||||
project_id,
|
||||
(&endpoint_id).into(),
|
||||
(&user1).into(),
|
||||
EndpointAccessControl {
|
||||
@@ -332,7 +410,7 @@ mod tests {
|
||||
|
||||
cache.insert_endpoint_access(
|
||||
account_id,
|
||||
(&project_id).into(),
|
||||
project_id,
|
||||
(&endpoint_id).into(),
|
||||
(&user2).into(),
|
||||
EndpointAccessControl {
|
||||
@@ -346,11 +424,17 @@ mod tests {
|
||||
},
|
||||
);
|
||||
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
|
||||
assert_eq!(cached.secret, secret1);
|
||||
let (cached, ttl) = cache
|
||||
.get_role_secret_with_ttl(&endpoint_id, &user1)
|
||||
.unwrap();
|
||||
assert_eq!(cached.unwrap().secret, secret1);
|
||||
assert_eq!(ttl, cache.config.ttl);
|
||||
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user2).unwrap();
|
||||
assert_eq!(cached.secret, secret2);
|
||||
let (cached, ttl) = cache
|
||||
.get_role_secret_with_ttl(&endpoint_id, &user2)
|
||||
.unwrap();
|
||||
assert_eq!(cached.unwrap().secret, secret2);
|
||||
assert_eq!(ttl, cache.config.ttl);
|
||||
|
||||
// Shouldn't add more than 2 roles.
|
||||
let user3: RoleName = "user3".into();
|
||||
@@ -358,7 +442,7 @@ mod tests {
|
||||
|
||||
cache.insert_endpoint_access(
|
||||
account_id,
|
||||
(&project_id).into(),
|
||||
project_id,
|
||||
(&endpoint_id).into(),
|
||||
(&user3).into(),
|
||||
EndpointAccessControl {
|
||||
@@ -372,17 +456,144 @@ mod tests {
|
||||
},
|
||||
);
|
||||
|
||||
assert!(cache.get_role_secret(&endpoint_id, &user3).is_none());
|
||||
assert!(
|
||||
cache
|
||||
.get_role_secret_with_ttl(&endpoint_id, &user3)
|
||||
.is_none()
|
||||
);
|
||||
|
||||
let cached = cache.get_endpoint_access(&endpoint_id).unwrap();
|
||||
let cached = cache
|
||||
.get_endpoint_access_with_ttl(&endpoint_id)
|
||||
.unwrap()
|
||||
.0
|
||||
.unwrap();
|
||||
assert_eq!(cached.allowed_ips, allowed_ips);
|
||||
|
||||
tokio::time::advance(Duration::from_secs(2)).await;
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user1);
|
||||
let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user1);
|
||||
assert!(cached.is_none());
|
||||
let cached = cache.get_role_secret(&endpoint_id, &user2);
|
||||
let cached = cache.get_role_secret_with_ttl(&endpoint_id, &user2);
|
||||
assert!(cached.is_none());
|
||||
let cached = cache.get_endpoint_access(&endpoint_id);
|
||||
let cached = cache.get_endpoint_access_with_ttl(&endpoint_id);
|
||||
assert!(cached.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_caching_project_info_errors() {
|
||||
let cache = ProjectInfoCacheImpl::new(ProjectInfoCacheOptions {
|
||||
size: 10,
|
||||
max_roles: 10,
|
||||
ttl: Duration::from_secs(1),
|
||||
gc_interval: Duration::from_secs(600),
|
||||
});
|
||||
let project_id = Some(ProjectIdInt::from(&"project".into()));
|
||||
let endpoint_id: EndpointId = "endpoint".into();
|
||||
let account_id = None;
|
||||
|
||||
let user1: RoleName = "user1".into();
|
||||
let user2: RoleName = "user2".into();
|
||||
let secret = Some(AuthSecret::Scram(ServerSecret::mock([1; 32])));
|
||||
|
||||
let role_msg = Box::new(ControlPlaneErrorMessage {
|
||||
error: "role is protected and cannot be used for password-based authentication"
|
||||
.to_owned()
|
||||
.into_boxed_str(),
|
||||
http_status_code: http::StatusCode::NOT_FOUND,
|
||||
status: Some(Status {
|
||||
code: "PERMISSION_DENIED".to_owned().into_boxed_str(),
|
||||
message: "role is protected and cannot be used for password-based authentication"
|
||||
.to_owned()
|
||||
.into_boxed_str(),
|
||||
details: Details {
|
||||
error_info: Some(ErrorInfo {
|
||||
reason: Reason::RoleProtected,
|
||||
}),
|
||||
retry_info: None,
|
||||
user_facing_message: None,
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
let generic_msg = Box::new(ControlPlaneErrorMessage {
|
||||
error: "oh noes".to_owned().into_boxed_str(),
|
||||
http_status_code: http::StatusCode::NOT_FOUND,
|
||||
status: None,
|
||||
});
|
||||
|
||||
let get_role_secret = |endpoint_id, role_name| {
|
||||
cache
|
||||
.get_role_secret_with_ttl(endpoint_id, role_name)
|
||||
.unwrap()
|
||||
.0
|
||||
};
|
||||
let get_endpoint_access =
|
||||
|endpoint_id| cache.get_endpoint_access_with_ttl(endpoint_id).unwrap().0;
|
||||
|
||||
// stores role-specific errors only for get_role_secret
|
||||
cache.insert_endpoint_access_err(
|
||||
(&endpoint_id).into(),
|
||||
(&user1).into(),
|
||||
role_msg.clone(),
|
||||
None,
|
||||
);
|
||||
assert_eq!(
|
||||
get_role_secret(&endpoint_id, &user1).unwrap_err().error,
|
||||
role_msg.error
|
||||
);
|
||||
assert!(cache.get_endpoint_access_with_ttl(&endpoint_id).is_none());
|
||||
|
||||
// stores non-role specific errors for both get_role_secret and get_endpoint_access
|
||||
cache.insert_endpoint_access_err(
|
||||
(&endpoint_id).into(),
|
||||
(&user1).into(),
|
||||
generic_msg.clone(),
|
||||
None,
|
||||
);
|
||||
assert_eq!(
|
||||
get_role_secret(&endpoint_id, &user1).unwrap_err().error,
|
||||
generic_msg.error
|
||||
);
|
||||
assert_eq!(
|
||||
get_endpoint_access(&endpoint_id).unwrap_err().error,
|
||||
generic_msg.error
|
||||
);
|
||||
|
||||
// error isn't returned for other roles in the same endpoint
|
||||
assert!(
|
||||
cache
|
||||
.get_role_secret_with_ttl(&endpoint_id, &user2)
|
||||
.is_none()
|
||||
);
|
||||
|
||||
// success for a role does not overwrite errors for other roles
|
||||
cache.insert_endpoint_access(
|
||||
account_id,
|
||||
project_id,
|
||||
(&endpoint_id).into(),
|
||||
(&user2).into(),
|
||||
EndpointAccessControl {
|
||||
allowed_ips: Arc::new(vec![]),
|
||||
allowed_vpce: Arc::new(vec![]),
|
||||
flags: AccessBlockerFlags::default(),
|
||||
rate_limits: EndpointRateLimitConfig::default(),
|
||||
},
|
||||
RoleAccessControl {
|
||||
secret: secret.clone(),
|
||||
},
|
||||
);
|
||||
assert!(get_role_secret(&endpoint_id, &user1).is_err());
|
||||
assert!(get_role_secret(&endpoint_id, &user2).is_ok());
|
||||
// ...but does clear the access control error
|
||||
assert!(get_endpoint_access(&endpoint_id).is_ok());
|
||||
|
||||
// storing an error does not overwrite successful access control response
|
||||
cache.insert_endpoint_access_err(
|
||||
(&endpoint_id).into(),
|
||||
(&user2).into(),
|
||||
generic_msg.clone(),
|
||||
None,
|
||||
);
|
||||
assert!(get_role_secret(&endpoint_id, &user2).is_err());
|
||||
assert!(get_endpoint_access(&endpoint_id).is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ use crate::context::RequestContext;
|
||||
use crate::control_plane::ControlPlaneApi;
|
||||
use crate::error::ReportableError;
|
||||
use crate::ext::LockExt;
|
||||
use crate::id::RequestId;
|
||||
use crate::metrics::{CancelChannelSizeGuard, CancellationRequest, Metrics, RedisMsgKind};
|
||||
use crate::pqproto::CancelKeyData;
|
||||
use crate::rate_limiter::LeakyBucketRateLimiter;
|
||||
@@ -32,8 +33,11 @@ use crate::util::run_until;
|
||||
|
||||
type IpSubnetKey = IpNet;
|
||||
|
||||
const CANCEL_KEY_TTL: Duration = Duration::from_secs(600);
|
||||
const CANCEL_KEY_REFRESH: Duration = Duration::from_secs(570);
|
||||
/// The initial period and TTL are shorter, to clear keys of short-lived connections faster.
|
||||
const CANCEL_KEY_INITIAL_PERIOD: Duration = Duration::from_secs(60);
|
||||
const CANCEL_KEY_REFRESH_PERIOD: Duration = Duration::from_secs(10 * 60);
|
||||
/// `CANCEL_KEY_TTL_SLACK` is added to the periods to determine the actual TTL.
|
||||
const CANCEL_KEY_TTL_SLACK: Duration = Duration::from_secs(30);
|
||||
|
||||
// Message types for sending through mpsc channel
|
||||
pub enum CancelKeyOp {
|
||||
@@ -54,6 +58,24 @@ pub enum CancelKeyOp {
|
||||
},
|
||||
}
|
||||
|
||||
impl CancelKeyOp {
|
||||
const fn redis_msg_kind(&self) -> RedisMsgKind {
|
||||
match self {
|
||||
CancelKeyOp::Store { .. } => RedisMsgKind::Set,
|
||||
CancelKeyOp::Refresh { .. } => RedisMsgKind::Expire,
|
||||
CancelKeyOp::Get { .. } => RedisMsgKind::Get,
|
||||
CancelKeyOp::GetOld { .. } => RedisMsgKind::HGet,
|
||||
}
|
||||
}
|
||||
|
||||
fn cancel_channel_metric_guard(&self) -> CancelChannelSizeGuard<'static> {
|
||||
Metrics::get()
|
||||
.proxy
|
||||
.cancel_channel_size
|
||||
.guard(self.redis_msg_kind())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(thiserror::Error, Debug, Clone)]
|
||||
pub enum PipelineError {
|
||||
#[error("could not send cmd to redis: {0}")]
|
||||
@@ -465,7 +487,7 @@ impl Session {
|
||||
/// This is not cancel safe
|
||||
pub(crate) async fn maintain_cancel_key(
|
||||
&self,
|
||||
session_id: uuid::Uuid,
|
||||
session_id: RequestId,
|
||||
cancel: tokio::sync::oneshot::Receiver<Infallible>,
|
||||
cancel_closure: &CancelClosure,
|
||||
compute_config: &ComputeConfig,
|
||||
@@ -483,50 +505,49 @@ impl Session {
|
||||
let mut cancel = pin!(cancel);
|
||||
|
||||
enum State {
|
||||
Set,
|
||||
Init,
|
||||
Refresh,
|
||||
}
|
||||
let mut state = State::Set;
|
||||
|
||||
let mut state = State::Init;
|
||||
loop {
|
||||
let guard_op = match state {
|
||||
State::Set => {
|
||||
let guard = Metrics::get()
|
||||
.proxy
|
||||
.cancel_channel_size
|
||||
.guard(RedisMsgKind::Set);
|
||||
let op = CancelKeyOp::Store {
|
||||
key: self.key,
|
||||
value: closure_json.clone(),
|
||||
expire: CANCEL_KEY_TTL,
|
||||
};
|
||||
let (op, mut wait_interval) = match state {
|
||||
State::Init => {
|
||||
tracing::debug!(
|
||||
src=%self.key,
|
||||
dest=?cancel_closure.cancel_token,
|
||||
"registering cancellation key"
|
||||
);
|
||||
(guard, op)
|
||||
(
|
||||
CancelKeyOp::Store {
|
||||
key: self.key,
|
||||
value: closure_json.clone(),
|
||||
expire: CANCEL_KEY_INITIAL_PERIOD + CANCEL_KEY_TTL_SLACK,
|
||||
},
|
||||
CANCEL_KEY_INITIAL_PERIOD,
|
||||
)
|
||||
}
|
||||
|
||||
State::Refresh => {
|
||||
let guard = Metrics::get()
|
||||
.proxy
|
||||
.cancel_channel_size
|
||||
.guard(RedisMsgKind::Expire);
|
||||
let op = CancelKeyOp::Refresh {
|
||||
key: self.key,
|
||||
expire: CANCEL_KEY_TTL,
|
||||
};
|
||||
tracing::debug!(
|
||||
src=%self.key,
|
||||
dest=?cancel_closure.cancel_token,
|
||||
"refreshing cancellation key"
|
||||
);
|
||||
(guard, op)
|
||||
(
|
||||
CancelKeyOp::Refresh {
|
||||
key: self.key,
|
||||
expire: CANCEL_KEY_REFRESH_PERIOD + CANCEL_KEY_TTL_SLACK,
|
||||
},
|
||||
CANCEL_KEY_REFRESH_PERIOD,
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
match tx.call(guard_op, cancel.as_mut()).await {
|
||||
match tx
|
||||
.call((op.cancel_channel_metric_guard(), op), cancel.as_mut())
|
||||
.await
|
||||
{
|
||||
// SET returns OK
|
||||
Ok(Value::Okay) => {
|
||||
tracing::debug!(
|
||||
@@ -549,23 +570,23 @@ impl Session {
|
||||
Ok(_) => {
|
||||
// Any other response likely means the key expired.
|
||||
tracing::warn!(src=%self.key, "refreshing cancellation key failed");
|
||||
// Re-enter the SET loop to repush full data.
|
||||
state = State::Set;
|
||||
// Re-enter the SET loop quickly to repush full data.
|
||||
state = State::Init;
|
||||
wait_interval = Duration::ZERO;
|
||||
}
|
||||
|
||||
// retry after a short delay.
|
||||
Err(BatchQueueError::Result(error)) => {
|
||||
tracing::warn!(?error, "error refreshing cancellation key");
|
||||
// Small delay to prevent busy loop with high cpu and logging.
|
||||
tokio::time::sleep(Duration::from_millis(10)).await;
|
||||
continue;
|
||||
wait_interval = Duration::from_millis(10);
|
||||
}
|
||||
|
||||
Err(BatchQueueError::Cancelled(Err(_cancelled))) => break,
|
||||
}
|
||||
|
||||
// wait before continuing. break immediately if cancelled.
|
||||
if run_until(tokio::time::sleep(CANCEL_KEY_REFRESH), cancel.as_mut())
|
||||
if run_until(tokio::time::sleep(wait_interval), cancel.as_mut())
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
@@ -579,7 +600,7 @@ impl Session {
|
||||
.await
|
||||
{
|
||||
tracing::warn!(
|
||||
?session_id,
|
||||
%session_id,
|
||||
?err,
|
||||
"could not cancel the query in the database"
|
||||
);
|
||||
|
||||
@@ -25,6 +25,7 @@ use crate::control_plane::client::ApiLockError;
|
||||
use crate::control_plane::errors::WakeComputeError;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::id::ComputeConnId;
|
||||
use crate::metrics::{Metrics, NumDbConnectionsGuard};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::neon_option;
|
||||
@@ -356,6 +357,7 @@ pub struct PostgresSettings {
|
||||
}
|
||||
|
||||
pub struct ComputeConnection {
|
||||
pub compute_conn_id: ComputeConnId,
|
||||
/// Socket connected to a compute node.
|
||||
pub stream: MaybeTlsStream<tokio::net::TcpStream, RustlsStream>,
|
||||
/// Labels for proxy's metrics.
|
||||
@@ -373,6 +375,7 @@ impl ConnectInfo {
|
||||
ctx: &RequestContext,
|
||||
aux: &MetricsAuxInfo,
|
||||
config: &ComputeConfig,
|
||||
compute_conn_id: ComputeConnId,
|
||||
) -> Result<ComputeConnection, ConnectionError> {
|
||||
let pause = ctx.latency_timer_pause(crate::metrics::Waiting::Compute);
|
||||
let (socket_addr, stream) = self.connect_raw(config).await?;
|
||||
@@ -382,6 +385,7 @@ impl ConnectInfo {
|
||||
|
||||
// TODO: lots of useful info but maybe we can move it elsewhere (eg traces?)
|
||||
info!(
|
||||
%compute_conn_id,
|
||||
cold_start_info = ctx.cold_start_info().as_str(),
|
||||
"connected to compute node at {} ({socket_addr}) sslmode={:?}, latency={}, query_id={}",
|
||||
self.host,
|
||||
@@ -391,6 +395,7 @@ impl ConnectInfo {
|
||||
);
|
||||
|
||||
let connection = ComputeConnection {
|
||||
compute_conn_id,
|
||||
stream,
|
||||
socket_addr,
|
||||
hostname: self.host.clone(),
|
||||
|
||||
@@ -10,7 +10,8 @@ use crate::cancellation::CancellationHandler;
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::ReportableError;
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::id::{ClientConnId, ComputeConnId, RequestId};
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard, Protocol};
|
||||
use crate::pglb::ClientRequestError;
|
||||
use crate::pglb::handshake::{HandshakeData, handshake};
|
||||
use crate::pglb::passthrough::ProxyPassthrough;
|
||||
@@ -42,12 +43,10 @@ pub async fn task_main(
|
||||
{
|
||||
let (socket, peer_addr) = accept_result?;
|
||||
|
||||
let conn_gauge = Metrics::get()
|
||||
.proxy
|
||||
.client_connections
|
||||
.guard(crate::metrics::Protocol::Tcp);
|
||||
let conn_gauge = Metrics::get().proxy.client_connections.guard(Protocol::Tcp);
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let conn_id = ClientConnId::new();
|
||||
let session_id = RequestId::from_uuid(conn_id.uuid());
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
let cancellations = cancellations.clone();
|
||||
|
||||
@@ -90,7 +89,7 @@ pub async fn task_main(
|
||||
}
|
||||
}
|
||||
|
||||
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
|
||||
let ctx = RequestContext::new(conn_id, session_id, conn_info, Protocol::Tcp);
|
||||
|
||||
let res = handle_client(
|
||||
config,
|
||||
@@ -120,13 +119,13 @@ pub async fn task_main(
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
error!(
|
||||
?session_id,
|
||||
%session_id,
|
||||
"per-client task finished with an IO error from the client: {e:#}"
|
||||
);
|
||||
}
|
||||
Err(ErrorSource::Compute(e)) => {
|
||||
error!(
|
||||
?session_id,
|
||||
%session_id,
|
||||
"per-client task finished with an IO error from the compute: {e:#}"
|
||||
);
|
||||
}
|
||||
@@ -214,10 +213,14 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
};
|
||||
auth_info.set_startup_params(¶ms, true);
|
||||
|
||||
// for TCP/WS, we have client_id=session_id=compute_id for now.
|
||||
let compute_conn_id = ComputeConnId::from_uuid(ctx.session_id().uuid());
|
||||
|
||||
let mut node = connect_to_compute(
|
||||
ctx,
|
||||
&TcpMechanism {
|
||||
locks: &config.connect_compute_locks,
|
||||
compute_conn_id,
|
||||
},
|
||||
&node_info,
|
||||
config.wake_compute_retry_config,
|
||||
@@ -250,6 +253,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
});
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
compute_conn_id: node.compute_conn_id,
|
||||
|
||||
client: stream,
|
||||
compute: node.stream,
|
||||
|
||||
|
||||
@@ -9,11 +9,11 @@ use tokio::sync::mpsc;
|
||||
use tracing::field::display;
|
||||
use tracing::{Span, error, info_span};
|
||||
use try_lock::TryLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
use self::parquet::RequestData;
|
||||
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
|
||||
use crate::error::ErrorKind;
|
||||
use crate::id::{ClientConnId, RequestId};
|
||||
use crate::intern::{BranchIdInt, ProjectIdInt};
|
||||
use crate::metrics::{LatencyAccumulated, LatencyTimer, Metrics, Protocol, Waiting};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
@@ -40,7 +40,7 @@ pub struct RequestContext(
|
||||
|
||||
struct RequestContextInner {
|
||||
pub(crate) conn_info: ConnectionInfo,
|
||||
pub(crate) session_id: Uuid,
|
||||
pub(crate) session_id: RequestId,
|
||||
pub(crate) protocol: Protocol,
|
||||
first_packet: chrono::DateTime<Utc>,
|
||||
pub(crate) span: Span,
|
||||
@@ -116,12 +116,18 @@ impl Clone for RequestContext {
|
||||
}
|
||||
|
||||
impl RequestContext {
|
||||
pub fn new(session_id: Uuid, conn_info: ConnectionInfo, protocol: Protocol) -> Self {
|
||||
pub fn new(
|
||||
conn_id: ClientConnId,
|
||||
session_id: RequestId,
|
||||
conn_info: ConnectionInfo,
|
||||
protocol: Protocol,
|
||||
) -> Self {
|
||||
// TODO: be careful with long lived spans
|
||||
let span = info_span!(
|
||||
"connect_request",
|
||||
%protocol,
|
||||
?session_id,
|
||||
%session_id,
|
||||
%conn_id,
|
||||
%conn_info,
|
||||
ep = tracing::field::Empty,
|
||||
role = tracing::field::Empty,
|
||||
@@ -164,7 +170,13 @@ impl RequestContext {
|
||||
let ip = IpAddr::from([127, 0, 0, 1]);
|
||||
let addr = SocketAddr::new(ip, 5432);
|
||||
let conn_info = ConnectionInfo { addr, extra: None };
|
||||
RequestContext::new(Uuid::now_v7(), conn_info, Protocol::Tcp)
|
||||
let uuid = uuid::Uuid::now_v7();
|
||||
RequestContext::new(
|
||||
ClientConnId::from_uuid(uuid),
|
||||
RequestId::from_uuid(uuid),
|
||||
conn_info,
|
||||
Protocol::Tcp,
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn console_application_name(&self) -> String {
|
||||
@@ -311,7 +323,7 @@ impl RequestContext {
|
||||
self.0.try_lock().expect("should not deadlock").span.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn session_id(&self) -> Uuid {
|
||||
pub(crate) fn session_id(&self) -> RequestId {
|
||||
self.0.try_lock().expect("should not deadlock").session_id
|
||||
}
|
||||
|
||||
|
||||
@@ -124,7 +124,7 @@ impl serde::Serialize for Options<'_> {
|
||||
impl From<&RequestContextInner> for RequestData {
|
||||
fn from(value: &RequestContextInner) -> Self {
|
||||
Self {
|
||||
session_id: value.session_id,
|
||||
session_id: value.session_id.uuid(),
|
||||
peer_addr: value.conn_info.addr.ip().to_string(),
|
||||
timestamp: value.first_packet.naive_utc(),
|
||||
username: value.user.as_deref().map(String::from),
|
||||
|
||||
@@ -68,6 +68,66 @@ impl NeonControlPlaneClient {
|
||||
self.endpoint.url().as_str()
|
||||
}
|
||||
|
||||
async fn get_and_cache_auth_info<T>(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
endpoint: &EndpointId,
|
||||
role: &RoleName,
|
||||
cache_key: &EndpointId,
|
||||
extract: impl FnOnce(&EndpointAccessControl, &RoleAccessControl) -> T,
|
||||
) -> Result<T, GetAuthInfoError> {
|
||||
match self.do_get_auth_req(ctx, endpoint, role).await {
|
||||
Ok(auth_info) => {
|
||||
let control = EndpointAccessControl {
|
||||
allowed_ips: Arc::new(auth_info.allowed_ips),
|
||||
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
|
||||
flags: auth_info.access_blocker_flags,
|
||||
rate_limits: auth_info.rate_limits,
|
||||
};
|
||||
let role_control = RoleAccessControl {
|
||||
secret: auth_info.secret,
|
||||
};
|
||||
let res = extract(&control, &role_control);
|
||||
|
||||
self.caches.project_info.insert_endpoint_access(
|
||||
auth_info.account_id,
|
||||
auth_info.project_id,
|
||||
cache_key.into(),
|
||||
role.into(),
|
||||
control,
|
||||
role_control,
|
||||
);
|
||||
|
||||
if let Some(project_id) = auth_info.project_id {
|
||||
ctx.set_project_id(project_id);
|
||||
}
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
Err(err) => match err {
|
||||
GetAuthInfoError::ApiError(ControlPlaneError::Message(ref msg)) => {
|
||||
let retry_info = msg.status.as_ref().and_then(|s| s.details.retry_info);
|
||||
|
||||
// If we can retry this error, do not cache it,
|
||||
// unless we were given a retry delay.
|
||||
if msg.could_retry() && retry_info.is_none() {
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
self.caches.project_info.insert_endpoint_access_err(
|
||||
cache_key.into(),
|
||||
role.into(),
|
||||
msg.clone(),
|
||||
retry_info.map(|r| Duration::from_millis(r.retry_delay_ms)),
|
||||
);
|
||||
|
||||
Err(err)
|
||||
}
|
||||
err => Err(err),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async fn do_get_auth_req(
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
@@ -284,43 +344,34 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
ctx: &RequestContext,
|
||||
endpoint: &EndpointId,
|
||||
role: &RoleName,
|
||||
) -> Result<RoleAccessControl, crate::control_plane::errors::GetAuthInfoError> {
|
||||
let normalized_ep = &endpoint.normalize();
|
||||
if let Some(secret) = self
|
||||
) -> Result<RoleAccessControl, GetAuthInfoError> {
|
||||
let key = endpoint.normalize();
|
||||
|
||||
if let Some((role_control, ttl)) = self
|
||||
.caches
|
||||
.project_info
|
||||
.get_role_secret(normalized_ep, role)
|
||||
.get_role_secret_with_ttl(&key, role)
|
||||
{
|
||||
return Ok(secret);
|
||||
return match role_control {
|
||||
Err(mut msg) => {
|
||||
info!(key = &*key, "found cached get_role_access_control error");
|
||||
|
||||
// if retry_delay_ms is set change it to the remaining TTL
|
||||
replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64);
|
||||
|
||||
Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg)))
|
||||
}
|
||||
Ok(role_control) => {
|
||||
debug!(key = &*key, "found cached role access control");
|
||||
Ok(role_control)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;
|
||||
|
||||
let control = EndpointAccessControl {
|
||||
allowed_ips: Arc::new(auth_info.allowed_ips),
|
||||
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
|
||||
flags: auth_info.access_blocker_flags,
|
||||
rate_limits: auth_info.rate_limits,
|
||||
};
|
||||
let role_control = RoleAccessControl {
|
||||
secret: auth_info.secret,
|
||||
};
|
||||
|
||||
if let Some(project_id) = auth_info.project_id {
|
||||
let normalized_ep_int = normalized_ep.into();
|
||||
|
||||
self.caches.project_info.insert_endpoint_access(
|
||||
auth_info.account_id,
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
role.into(),
|
||||
control,
|
||||
role_control.clone(),
|
||||
);
|
||||
ctx.set_project_id(project_id);
|
||||
}
|
||||
|
||||
Ok(role_control)
|
||||
self.get_and_cache_auth_info(ctx, endpoint, role, &key, |_, role_control| {
|
||||
role_control.clone()
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
@@ -330,38 +381,30 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
endpoint: &EndpointId,
|
||||
role: &RoleName,
|
||||
) -> Result<EndpointAccessControl, GetAuthInfoError> {
|
||||
let normalized_ep = &endpoint.normalize();
|
||||
if let Some(control) = self.caches.project_info.get_endpoint_access(normalized_ep) {
|
||||
return Ok(control);
|
||||
let key = endpoint.normalize();
|
||||
|
||||
if let Some((control, ttl)) = self.caches.project_info.get_endpoint_access_with_ttl(&key) {
|
||||
return match control {
|
||||
Err(mut msg) => {
|
||||
info!(
|
||||
key = &*key,
|
||||
"found cached get_endpoint_access_control error"
|
||||
);
|
||||
|
||||
// if retry_delay_ms is set change it to the remaining TTL
|
||||
replace_retry_delay_ms(&mut msg, |_| ttl.as_millis() as u64);
|
||||
|
||||
Err(GetAuthInfoError::ApiError(ControlPlaneError::Message(msg)))
|
||||
}
|
||||
Ok(control) => {
|
||||
debug!(key = &*key, "found cached endpoint access control");
|
||||
Ok(control)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
let auth_info = self.do_get_auth_req(ctx, endpoint, role).await?;
|
||||
|
||||
let control = EndpointAccessControl {
|
||||
allowed_ips: Arc::new(auth_info.allowed_ips),
|
||||
allowed_vpce: Arc::new(auth_info.allowed_vpc_endpoint_ids),
|
||||
flags: auth_info.access_blocker_flags,
|
||||
rate_limits: auth_info.rate_limits,
|
||||
};
|
||||
let role_control = RoleAccessControl {
|
||||
secret: auth_info.secret,
|
||||
};
|
||||
|
||||
if let Some(project_id) = auth_info.project_id {
|
||||
let normalized_ep_int = normalized_ep.into();
|
||||
|
||||
self.caches.project_info.insert_endpoint_access(
|
||||
auth_info.account_id,
|
||||
project_id,
|
||||
normalized_ep_int,
|
||||
role.into(),
|
||||
control.clone(),
|
||||
role_control,
|
||||
);
|
||||
ctx.set_project_id(project_id);
|
||||
}
|
||||
|
||||
Ok(control)
|
||||
self.get_and_cache_auth_info(ctx, endpoint, role, &key, |control, _| control.clone())
|
||||
.await
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip_all)]
|
||||
@@ -390,13 +433,9 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
info!(key = &*key, "found cached wake_compute error");
|
||||
|
||||
// if retry_delay_ms is set, reduce it by the amount of time it spent in cache
|
||||
if let Some(status) = &mut msg.status {
|
||||
if let Some(retry_info) = &mut status.details.retry_info {
|
||||
retry_info.retry_delay_ms = retry_info
|
||||
.retry_delay_ms
|
||||
.saturating_sub(created_at.elapsed().as_millis() as u64)
|
||||
}
|
||||
}
|
||||
replace_retry_delay_ms(&mut msg, |delay| {
|
||||
delay.saturating_sub(created_at.elapsed().as_millis() as u64)
|
||||
});
|
||||
|
||||
Err(WakeComputeError::ControlPlane(ControlPlaneError::Message(
|
||||
msg,
|
||||
@@ -478,6 +517,14 @@ impl super::ControlPlaneApi for NeonControlPlaneClient {
|
||||
}
|
||||
}
|
||||
|
||||
fn replace_retry_delay_ms(msg: &mut ControlPlaneErrorMessage, f: impl FnOnce(u64) -> u64) {
|
||||
if let Some(status) = &mut msg.status
|
||||
&& let Some(retry_info) = &mut status.details.retry_info
|
||||
{
|
||||
retry_info.retry_delay_ms = f(retry_info.retry_delay_ms);
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse http response body, taking status code into account.
|
||||
fn parse_body<T: for<'a> serde::Deserialize<'a>>(
|
||||
status: StatusCode,
|
||||
|
||||
@@ -52,7 +52,7 @@ impl ReportableError for ControlPlaneError {
|
||||
| Reason::EndpointNotFound
|
||||
| Reason::EndpointDisabled
|
||||
| Reason::BranchNotFound
|
||||
| Reason::InvalidEphemeralEndpointOptions => ErrorKind::User,
|
||||
| Reason::WrongLsnOrTimestamp => ErrorKind::User,
|
||||
|
||||
Reason::RateLimitExceeded => ErrorKind::ServiceRateLimit,
|
||||
|
||||
|
||||
@@ -107,7 +107,7 @@ pub(crate) struct ErrorInfo {
|
||||
// Schema could also have `metadata` field, but it's not structured. Skip it for now.
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Deserialize, Default)]
|
||||
#[derive(Clone, Copy, Debug, Deserialize, Default, PartialEq, Eq)]
|
||||
pub(crate) enum Reason {
|
||||
/// RoleProtected indicates that the role is protected and the attempted operation is not permitted on protected roles.
|
||||
#[serde(rename = "ROLE_PROTECTED")]
|
||||
@@ -133,9 +133,9 @@ pub(crate) enum Reason {
|
||||
/// or that the subject doesn't have enough permissions to access the requested branch.
|
||||
#[serde(rename = "BRANCH_NOT_FOUND")]
|
||||
BranchNotFound,
|
||||
/// InvalidEphemeralEndpointOptions indicates that the specified LSN or timestamp are wrong.
|
||||
#[serde(rename = "INVALID_EPHEMERAL_OPTIONS")]
|
||||
InvalidEphemeralEndpointOptions,
|
||||
/// WrongLsnOrTimestamp indicates that the specified LSN or timestamp is wrong.
|
||||
#[serde(rename = "WRONG_LSN_OR_TIMESTAMP")]
|
||||
WrongLsnOrTimestamp,
|
||||
/// RateLimitExceeded indicates that the rate limit for the operation has been exceeded.
|
||||
#[serde(rename = "RATE_LIMIT_EXCEEDED")]
|
||||
RateLimitExceeded,
|
||||
@@ -205,7 +205,7 @@ impl Reason {
|
||||
| Reason::EndpointNotFound
|
||||
| Reason::EndpointDisabled
|
||||
| Reason::BranchNotFound
|
||||
| Reason::InvalidEphemeralEndpointOptions => false,
|
||||
| Reason::WrongLsnOrTimestamp => false,
|
||||
// we were asked to go away
|
||||
Reason::RateLimitExceeded
|
||||
| Reason::NonDefaultBranchComputeTimeExceeded
|
||||
@@ -257,19 +257,19 @@ pub(crate) struct GetEndpointAccessControl {
|
||||
pub(crate) rate_limits: EndpointRateLimitConfig,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Deserialize, Default)]
|
||||
#[derive(Copy, Clone, Deserialize, Default, Debug)]
|
||||
pub struct EndpointRateLimitConfig {
|
||||
pub connection_attempts: ConnectionAttemptsLimit,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Deserialize, Default)]
|
||||
#[derive(Copy, Clone, Deserialize, Default, Debug)]
|
||||
pub struct ConnectionAttemptsLimit {
|
||||
pub tcp: Option<LeakyBucketSetting>,
|
||||
pub ws: Option<LeakyBucketSetting>,
|
||||
pub http: Option<LeakyBucketSetting>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Deserialize)]
|
||||
#[derive(Copy, Clone, Deserialize, Debug)]
|
||||
pub struct LeakyBucketSetting {
|
||||
pub rps: f64,
|
||||
pub burst: f64,
|
||||
|
||||
@@ -20,6 +20,7 @@ use crate::cache::{Cached, TimedLru};
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::messages::{ControlPlaneErrorMessage, MetricsAuxInfo};
|
||||
use crate::id::ComputeConnId;
|
||||
use crate::intern::{AccountIdInt, EndpointIdInt, ProjectIdInt};
|
||||
use crate::protocol2::ConnectionInfoExtra;
|
||||
use crate::rate_limiter::{EndpointRateLimiter, LeakyBucketConfig};
|
||||
@@ -77,12 +78,15 @@ impl NodeInfo {
|
||||
&self,
|
||||
ctx: &RequestContext,
|
||||
config: &ComputeConfig,
|
||||
compute_conn_id: ComputeConnId,
|
||||
) -> Result<compute::ComputeConnection, compute::ConnectionError> {
|
||||
self.conn_info.connect(ctx, &self.aux, config).await
|
||||
self.conn_info
|
||||
.connect(ctx, &self.aux, config, compute_conn_id)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Default)]
|
||||
#[derive(Copy, Clone, Default, Debug)]
|
||||
pub(crate) struct AccessBlockerFlags {
|
||||
pub public_access_blocked: bool,
|
||||
pub vpc_access_blocked: bool,
|
||||
@@ -92,12 +96,12 @@ pub(crate) type NodeInfoCache =
|
||||
TimedLru<EndpointCacheKey, Result<NodeInfo, Box<ControlPlaneErrorMessage>>>;
|
||||
pub(crate) type CachedNodeInfo = Cached<&'static NodeInfoCache, NodeInfo>;
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RoleAccessControl {
|
||||
pub secret: Option<AuthSecret>,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct EndpointAccessControl {
|
||||
pub allowed_ips: Arc<Vec<IpPattern>>,
|
||||
pub allowed_vpce: Arc<Vec<String>>,
|
||||
|
||||
33
proxy/src/id.rs
Normal file
33
proxy/src/id.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
//! Various ID types used by proxy.
|
||||
|
||||
use type_safe_id::{StaticType, TypeSafeId};
|
||||
|
||||
/// The ID used for the client connection
/// (a [`TypeSafeId`] tagged with the [`ClientConn`] marker).
|
||||
pub type ClientConnId = TypeSafeId<ClientConn>;
|
||||
|
||||
/// Marker type that tags [`ClientConnId`]; it carries no data of its own.
#[derive(Copy, Clone, Default, Hash, PartialEq, Eq)]
|
||||
pub struct ClientConn;
|
||||
|
||||
impl StaticType for ClientConn {
|
||||
// This is visible by customers, so we use 'neon' here instead of 'client'.
// NOTE(review): `TYPE` is presumably the prefix rendered in the ID's string form —
// confirm against the `type_safe_id` crate documentation.
|
||||
const TYPE: &'static str = "neon_conn";
|
||||
}
|
||||
|
||||
/// The ID used for the compute connection
/// (a [`TypeSafeId`] tagged with the [`ComputeConn`] marker).
|
||||
pub type ComputeConnId = TypeSafeId<ComputeConn>;
|
||||
|
||||
/// Marker type that tags [`ComputeConnId`]; it carries no data of its own.
#[derive(Copy, Clone, Default, Hash, PartialEq, Eq)]
|
||||
pub struct ComputeConn;
|
||||
|
||||
impl StaticType for ComputeConn {
|
||||
// NOTE(review): `TYPE` is presumably the prefix rendered in the ID's string form —
// confirm against the `type_safe_id` crate documentation.
const TYPE: &'static str = "compute_conn";
|
||||
}
|
||||
|
||||
/// The ID used for the request to authenticate
/// (a [`TypeSafeId`] tagged with the [`Request`] marker).
|
||||
pub type RequestId = TypeSafeId<Request>;
|
||||
/// Marker type that tags [`RequestId`]; it carries no data of its own.
#[derive(Copy, Clone, Default, Hash, PartialEq, Eq)]
|
||||
pub struct Request;
|
||||
|
||||
impl StaticType for Request {
|
||||
// NOTE(review): `TYPE` is presumably the prefix rendered in the ID's string form —
// confirm against the `type_safe_id` crate documentation.
const TYPE: &'static str = "request";
|
||||
}
|
||||
@@ -91,6 +91,7 @@ mod control_plane;
|
||||
mod error;
|
||||
mod ext;
|
||||
mod http;
|
||||
mod id;
|
||||
mod intern;
|
||||
mod jemalloc;
|
||||
mod logging;
|
||||
|
||||
@@ -17,7 +17,8 @@ use crate::cancellation::{self, CancellationHandler};
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2, TlsConfig};
|
||||
use crate::context::RequestContext;
|
||||
use crate::error::{ReportableError, UserFacingError};
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard};
|
||||
use crate::id::{ClientConnId, RequestId};
|
||||
use crate::metrics::{Metrics, NumClientConnectionsGuard, Protocol};
|
||||
pub use crate::pglb::copy_bidirectional::ErrorSource;
|
||||
use crate::pglb::handshake::{HandshakeData, HandshakeError, handshake};
|
||||
use crate::pglb::passthrough::ProxyPassthrough;
|
||||
@@ -65,12 +66,11 @@ pub async fn task_main(
|
||||
{
|
||||
let (socket, peer_addr) = accept_result?;
|
||||
|
||||
let conn_gauge = Metrics::get()
|
||||
.proxy
|
||||
.client_connections
|
||||
.guard(crate::metrics::Protocol::Tcp);
|
||||
let conn_gauge = Metrics::get().proxy.client_connections.guard(Protocol::Tcp);
|
||||
|
||||
let conn_id = ClientConnId::new();
|
||||
let session_id = RequestId::from_uuid(conn_id.uuid());
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let cancellation_handler = Arc::clone(&cancellation_handler);
|
||||
let cancellations = cancellations.clone();
|
||||
|
||||
@@ -114,7 +114,7 @@ pub async fn task_main(
|
||||
}
|
||||
}
|
||||
|
||||
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Tcp);
|
||||
let ctx = RequestContext::new(conn_id, session_id, conn_info, Protocol::Tcp);
|
||||
|
||||
let res = handle_connection(
|
||||
config,
|
||||
@@ -142,17 +142,22 @@ pub async fn task_main(
|
||||
Ok(Some(p)) => {
|
||||
ctx.set_success();
|
||||
let _disconnect = ctx.log_connect();
|
||||
let compute_conn_id = p.compute_conn_id;
|
||||
match p.proxy_pass().await {
|
||||
Ok(()) => {}
|
||||
Err(ErrorSource::Client(e)) => {
|
||||
warn!(
|
||||
?session_id,
|
||||
%conn_id,
|
||||
%session_id,
|
||||
%compute_conn_id,
|
||||
"per-client task finished with an IO error from the client: {e:#}"
|
||||
);
|
||||
}
|
||||
Err(ErrorSource::Compute(e)) => {
|
||||
error!(
|
||||
?session_id,
|
||||
%conn_id,
|
||||
%session_id,
|
||||
%compute_conn_id,
|
||||
"per-client task finished with an IO error from the compute: {e:#}"
|
||||
);
|
||||
}
|
||||
@@ -318,6 +323,8 @@ pub(crate) async fn handle_connection<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
};
|
||||
|
||||
Ok(Some(ProxyPassthrough {
|
||||
compute_conn_id: node.compute_conn_id,
|
||||
|
||||
client,
|
||||
compute: node.stream,
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ use utils::measured_stream::MeasuredStream;
|
||||
use super::copy_bidirectional::ErrorSource;
|
||||
use crate::compute::MaybeRustlsStream;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::id::ComputeConnId;
|
||||
use crate::metrics::{
|
||||
Direction, Metrics, NumClientConnectionsGuard, NumConnectionRequestsGuard,
|
||||
NumDbConnectionsGuard,
|
||||
@@ -65,6 +66,8 @@ pub(crate) async fn proxy_pass(
|
||||
}
|
||||
|
||||
pub(crate) struct ProxyPassthrough<S> {
|
||||
pub(crate) compute_conn_id: ComputeConnId,
|
||||
|
||||
pub(crate) client: Stream<S>,
|
||||
pub(crate) compute: MaybeRustlsStream,
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ use crate::control_plane::errors::WakeComputeError;
|
||||
use crate::control_plane::locks::ApiLocks;
|
||||
use crate::control_plane::{self, NodeInfo};
|
||||
use crate::error::ReportableError;
|
||||
use crate::id::ComputeConnId;
|
||||
use crate::metrics::{
|
||||
ConnectOutcome, ConnectionFailureKind, Metrics, RetriesMetricGroup, RetryType,
|
||||
};
|
||||
@@ -51,6 +52,7 @@ pub(crate) trait ConnectMechanism {
|
||||
pub(crate) struct TcpMechanism {
|
||||
/// connect_to_compute concurrency lock
|
||||
pub(crate) locks: &'static ApiLocks<Host>,
|
||||
pub(crate) compute_conn_id: ComputeConnId,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -70,7 +72,7 @@ impl ConnectMechanism for TcpMechanism {
|
||||
config: &ComputeConfig,
|
||||
) -> Result<ComputeConnection, Self::Error> {
|
||||
let permit = self.locks.get_permit(&node_info.conn_info.host).await?;
|
||||
permit.release_result(node_info.connect(ctx, config).await)
|
||||
permit.release_result(node_info.connect(ctx, config, self.compute_conn_id).await)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ use crate::compute::ComputeConnection;
|
||||
use crate::config::ProxyConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::client::ControlPlaneClient;
|
||||
use crate::id::ComputeConnId;
|
||||
pub use crate::pglb::copy_bidirectional::{ErrorSource, copy_bidirectional_client_compute};
|
||||
use crate::pglb::{ClientMode, ClientRequestError};
|
||||
use crate::pqproto::{BeMessage, CancelKeyData, StartupMessageParams};
|
||||
@@ -94,6 +95,8 @@ pub(crate) async fn handle_client<S: AsyncRead + AsyncWrite + Unpin + Send>(
|
||||
let mut attempt = 0;
|
||||
let connect = TcpMechanism {
|
||||
locks: &config.connect_compute_locks,
|
||||
// for TCP/WS, we have client_id=session_id=compute_id for now.
|
||||
compute_conn_id: ComputeConnId::from_uuid(ctx.session_id().uuid()),
|
||||
};
|
||||
let backend = auth::Backend::ControlPlane(cplane, creds.info);
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ use crate::control_plane::client::ApiLockError;
|
||||
use crate::control_plane::errors::{GetAuthInfoError, WakeComputeError};
|
||||
use crate::control_plane::locks::ApiLocks;
|
||||
use crate::error::{ErrorKind, ReportableError, UserFacingError};
|
||||
use crate::id::{ComputeConnId, RequestId};
|
||||
use crate::intern::EndpointIdInt;
|
||||
use crate::proxy::connect_compute::ConnectMechanism;
|
||||
use crate::proxy::retry::{CouldRetry, ShouldRetryWakeCompute};
|
||||
@@ -161,7 +162,7 @@ impl PoolingBackend {
|
||||
#[tracing::instrument(skip_all, fields(
|
||||
pid = tracing::field::Empty,
|
||||
compute_id = tracing::field::Empty,
|
||||
conn_id = tracing::field::Empty,
|
||||
compute_conn_id = tracing::field::Empty,
|
||||
))]
|
||||
pub(crate) async fn connect_to_compute(
|
||||
&self,
|
||||
@@ -181,14 +182,14 @@ impl PoolingBackend {
|
||||
if let Some(client) = maybe_client {
|
||||
return Ok(client);
|
||||
}
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||
let compute_conn_id = ComputeConnId::new();
|
||||
tracing::Span::current().record("compute_conn_id", display(compute_conn_id));
|
||||
info!(%compute_conn_id, "pool: opening a new connection '{conn_info}'");
|
||||
let backend = self.auth_backend.as_ref().map(|()| keys.info);
|
||||
crate::proxy::connect_compute::connect_to_compute(
|
||||
ctx,
|
||||
&TokioMechanism {
|
||||
conn_id,
|
||||
compute_conn_id,
|
||||
conn_info,
|
||||
pool: self.pool.clone(),
|
||||
locks: &self.config.connect_compute_locks,
|
||||
@@ -204,7 +205,7 @@ impl PoolingBackend {
|
||||
// Wake up the destination if needed
|
||||
#[tracing::instrument(skip_all, fields(
|
||||
compute_id = tracing::field::Empty,
|
||||
conn_id = tracing::field::Empty,
|
||||
compute_conn_id = tracing::field::Empty,
|
||||
))]
|
||||
pub(crate) async fn connect_to_local_proxy(
|
||||
&self,
|
||||
@@ -216,9 +217,9 @@ impl PoolingBackend {
|
||||
return Ok(client);
|
||||
}
|
||||
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
debug!(%conn_id, "pool: opening a new connection '{conn_info}'");
|
||||
let compute_conn_id = ComputeConnId::new();
|
||||
tracing::Span::current().record("compute_conn_id", display(compute_conn_id));
|
||||
debug!(%compute_conn_id, "pool: opening a new connection '{conn_info}'");
|
||||
let backend = self.auth_backend.as_ref().map(|()| ComputeUserInfo {
|
||||
user: conn_info.user_info.user.clone(),
|
||||
endpoint: EndpointId::from(format!(
|
||||
@@ -230,7 +231,7 @@ impl PoolingBackend {
|
||||
crate::proxy::connect_compute::connect_to_compute(
|
||||
ctx,
|
||||
&HyperMechanism {
|
||||
conn_id,
|
||||
compute_conn_id,
|
||||
conn_info,
|
||||
pool: self.http_conn_pool.clone(),
|
||||
locks: &self.config.connect_compute_locks,
|
||||
@@ -251,7 +252,7 @@ impl PoolingBackend {
|
||||
/// Panics if called with a non-local_proxy backend.
|
||||
#[tracing::instrument(skip_all, fields(
|
||||
pid = tracing::field::Empty,
|
||||
conn_id = tracing::field::Empty,
|
||||
compute_conn_id = tracing::field::Empty,
|
||||
))]
|
||||
pub(crate) async fn connect_to_local_postgres(
|
||||
&self,
|
||||
@@ -303,9 +304,9 @@ impl PoolingBackend {
|
||||
}
|
||||
}
|
||||
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
tracing::Span::current().record("conn_id", display(conn_id));
|
||||
info!(%conn_id, "local_pool: opening a new connection '{conn_info}'");
|
||||
let compute_conn_id = ComputeConnId::new();
|
||||
tracing::Span::current().record("compute_conn_id", display(compute_conn_id));
|
||||
info!(%compute_conn_id, "local_pool: opening a new connection '{conn_info}'");
|
||||
|
||||
let (key, jwk) = create_random_jwk();
|
||||
|
||||
@@ -340,7 +341,7 @@ impl PoolingBackend {
|
||||
client,
|
||||
connection,
|
||||
key,
|
||||
conn_id,
|
||||
compute_conn_id,
|
||||
local_backend.node_info.aux.clone(),
|
||||
);
|
||||
|
||||
@@ -378,7 +379,7 @@ fn create_random_jwk() -> (SigningKey, jose_jwk::Key) {
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum HttpConnError {
|
||||
#[error("pooled connection closed at inconsistent state")]
|
||||
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<uuid::Uuid>),
|
||||
ConnectionClosedAbruptly(#[from] tokio::sync::watch::error::SendError<RequestId>),
|
||||
#[error("could not connect to postgres in compute")]
|
||||
PostgresConnectionError(#[from] postgres_client::Error),
|
||||
#[error("could not connect to local-proxy in compute")]
|
||||
@@ -509,7 +510,7 @@ impl ShouldRetryWakeCompute for LocalProxyConnError {
|
||||
struct TokioMechanism {
|
||||
pool: Arc<GlobalConnPool<postgres_client::Client, EndpointConnPool<postgres_client::Client>>>,
|
||||
conn_info: ConnInfo,
|
||||
conn_id: uuid::Uuid,
|
||||
compute_conn_id: ComputeConnId,
|
||||
keys: ComputeCredentialKeys,
|
||||
|
||||
/// connect_to_compute concurrency lock
|
||||
@@ -561,7 +562,7 @@ impl ConnectMechanism for TokioMechanism {
|
||||
self.conn_info.clone(),
|
||||
client,
|
||||
connection,
|
||||
self.conn_id,
|
||||
self.compute_conn_id,
|
||||
node_info.aux.clone(),
|
||||
))
|
||||
}
|
||||
@@ -570,7 +571,7 @@ impl ConnectMechanism for TokioMechanism {
|
||||
struct HyperMechanism {
|
||||
pool: Arc<GlobalConnPool<Send, HttpConnPool<Send>>>,
|
||||
conn_info: ConnInfo,
|
||||
conn_id: uuid::Uuid,
|
||||
compute_conn_id: ComputeConnId,
|
||||
|
||||
/// connect_to_compute concurrency lock
|
||||
locks: &'static ApiLocks<Host>,
|
||||
@@ -620,7 +621,7 @@ impl ConnectMechanism for HyperMechanism {
|
||||
&self.conn_info,
|
||||
client,
|
||||
connection,
|
||||
self.conn_id,
|
||||
self.compute_conn_id,
|
||||
node_info.aux.clone(),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -10,7 +10,8 @@ use rand::{Rng, thread_rng};
|
||||
use rustc_hash::FxHasher;
|
||||
use tokio::time::Instant;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::id::ClientConnId;
|
||||
|
||||
type Hasher = BuildHasherDefault<FxHasher>;
|
||||
|
||||
@@ -21,7 +22,7 @@ pub struct CancelSet {
|
||||
}
|
||||
|
||||
pub(crate) struct CancelShard {
|
||||
tokens: IndexMap<uuid::Uuid, (Instant, CancellationToken), Hasher>,
|
||||
tokens: IndexMap<ClientConnId, (Instant, CancellationToken), Hasher>,
|
||||
}
|
||||
|
||||
impl CancelSet {
|
||||
@@ -53,7 +54,7 @@ impl CancelSet {
|
||||
.and_then(|len| self.shards[rng % len].lock().take(rng / len))
|
||||
}
|
||||
|
||||
pub(crate) fn insert(&self, id: uuid::Uuid, token: CancellationToken) -> CancelGuard<'_> {
|
||||
pub(crate) fn insert(&self, id: ClientConnId, token: CancellationToken) -> CancelGuard<'_> {
|
||||
let shard = NonZeroUsize::new(self.shards.len()).map(|len| {
|
||||
let hash = self.hasher.hash_one(id) as usize;
|
||||
let shard = &self.shards[hash % len];
|
||||
@@ -77,18 +78,18 @@ impl CancelShard {
|
||||
})
|
||||
}
|
||||
|
||||
fn remove(&mut self, id: uuid::Uuid) {
|
||||
fn remove(&mut self, id: ClientConnId) {
|
||||
self.tokens.swap_remove(&id);
|
||||
}
|
||||
|
||||
fn insert(&mut self, id: uuid::Uuid, token: CancellationToken) {
|
||||
fn insert(&mut self, id: ClientConnId, token: CancellationToken) {
|
||||
self.tokens.insert(id, (Instant::now(), token));
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct CancelGuard<'a> {
|
||||
shard: Option<&'a Mutex<CancelShard>>,
|
||||
id: Uuid,
|
||||
id: ClientConnId,
|
||||
}
|
||||
|
||||
impl Drop for CancelGuard<'_> {
|
||||
|
||||
@@ -26,6 +26,7 @@ use super::conn_pool_lib::{
|
||||
use crate::config::ComputeConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::messages::MetricsAuxInfo;
|
||||
use crate::id::{ComputeConnId, RequestId};
|
||||
use crate::metrics::Metrics;
|
||||
|
||||
type TlsStream = <ComputeConfig as MakeTlsConnect<TcpStream>>::Stream;
|
||||
@@ -62,14 +63,14 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
|
||||
conn_info: ConnInfo,
|
||||
client: C,
|
||||
mut connection: postgres_client::Connection<TcpStream, TlsStream>,
|
||||
conn_id: uuid::Uuid,
|
||||
compute_conn_id: ComputeConnId,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> Client<C> {
|
||||
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
|
||||
let mut session_id = ctx.session_id();
|
||||
let (tx, mut rx) = tokio::sync::watch::channel(session_id);
|
||||
|
||||
let span = info_span!(parent: None, "connection", %conn_id);
|
||||
let span = info_span!(parent: None, "connection", %compute_conn_id);
|
||||
let cold_start_info = ctx.cold_start_info();
|
||||
span.in_scope(|| {
|
||||
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
|
||||
@@ -117,7 +118,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
|
||||
if let Some(pool) = pool.clone().upgrade() {
|
||||
// remove client from pool - should close the connection if it's idle.
|
||||
// does nothing if the client is currently checked-out and in-use
|
||||
if pool.write().remove_client(db_user.clone(), conn_id) {
|
||||
if pool.write().remove_client(db_user.clone(), compute_conn_id) {
|
||||
info!("idle connection removed");
|
||||
}
|
||||
}
|
||||
@@ -149,7 +150,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
|
||||
|
||||
// remove from connection pool
|
||||
if let Some(pool) = pool.clone().upgrade()
|
||||
&& pool.write().remove_client(db_user.clone(), conn_id) {
|
||||
&& pool.write().remove_client(db_user.clone(), compute_conn_id) {
|
||||
info!("closed connection removed");
|
||||
}
|
||||
|
||||
@@ -161,7 +162,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
|
||||
let inner = ClientInnerCommon {
|
||||
inner: client,
|
||||
aux,
|
||||
conn_id,
|
||||
compute_conn_id,
|
||||
data: ClientDataEnum::Remote(ClientDataRemote {
|
||||
session: tx,
|
||||
cancel,
|
||||
@@ -173,12 +174,12 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ClientDataRemote {
|
||||
session: tokio::sync::watch::Sender<uuid::Uuid>,
|
||||
session: tokio::sync::watch::Sender<RequestId>,
|
||||
cancel: CancellationToken,
|
||||
}
|
||||
|
||||
impl ClientDataRemote {
|
||||
pub fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
|
||||
pub fn session(&mut self) -> &mut tokio::sync::watch::Sender<RequestId> {
|
||||
&mut self.session
|
||||
}
|
||||
|
||||
@@ -192,6 +193,7 @@ mod tests {
|
||||
use std::sync::atomic::AtomicBool;
|
||||
|
||||
use super::*;
|
||||
use crate::id::ComputeConnId;
|
||||
use crate::proxy::NeonOptions;
|
||||
use crate::serverless::cancel_set::CancelSet;
|
||||
use crate::types::{BranchId, EndpointId, ProjectId};
|
||||
@@ -225,9 +227,9 @@ mod tests {
|
||||
compute_id: "compute".into(),
|
||||
cold_start_info: crate::control_plane::messages::ColdStartInfo::Warm,
|
||||
},
|
||||
conn_id: uuid::Uuid::new_v4(),
|
||||
compute_conn_id: ComputeConnId::new(),
|
||||
data: ClientDataEnum::Remote(ClientDataRemote {
|
||||
session: tokio::sync::watch::Sender::new(uuid::Uuid::new_v4()),
|
||||
session: tokio::sync::watch::Sender::new(RequestId::new()),
|
||||
cancel: CancellationToken::new(),
|
||||
}),
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ use super::local_conn_pool::ClientDataLocal;
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
|
||||
use crate::id::ComputeConnId;
|
||||
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
|
||||
use crate::protocol2::ConnectionInfoExtra;
|
||||
use crate::types::{DbName, EndpointCacheKey, RoleName};
|
||||
@@ -58,7 +59,7 @@ pub(crate) enum ClientDataEnum {
|
||||
pub(crate) struct ClientInnerCommon<C: ClientInnerExt> {
|
||||
pub(crate) inner: C,
|
||||
pub(crate) aux: MetricsAuxInfo,
|
||||
pub(crate) conn_id: uuid::Uuid,
|
||||
pub(crate) compute_conn_id: ComputeConnId,
|
||||
pub(crate) data: ClientDataEnum, // custom client data like session, key, jti
|
||||
}
|
||||
|
||||
@@ -77,8 +78,8 @@ impl<C: ClientInnerExt> Drop for ClientInnerCommon<C> {
|
||||
}
|
||||
|
||||
impl<C: ClientInnerExt> ClientInnerCommon<C> {
|
||||
pub(crate) fn get_conn_id(&self) -> uuid::Uuid {
|
||||
self.conn_id
|
||||
pub(crate) fn get_conn_id(&self) -> ComputeConnId {
|
||||
self.compute_conn_id
|
||||
}
|
||||
|
||||
pub(crate) fn get_data(&mut self) -> &mut ClientDataEnum {
|
||||
@@ -144,7 +145,7 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
pub(crate) fn remove_client(
|
||||
&mut self,
|
||||
db_user: (DbName, RoleName),
|
||||
conn_id: uuid::Uuid,
|
||||
conn_id: ComputeConnId,
|
||||
) -> bool {
|
||||
let Self {
|
||||
pools,
|
||||
@@ -189,7 +190,7 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
}
|
||||
|
||||
pub(crate) fn put(pool: &RwLock<Self>, conn_info: &ConnInfo, client: ClientInnerCommon<C>) {
|
||||
let conn_id = client.get_conn_id();
|
||||
let compute_conn_id = client.get_conn_id();
|
||||
let (max_conn, conn_count, pool_name) = {
|
||||
let pool = pool.read();
|
||||
(
|
||||
@@ -201,12 +202,12 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
};
|
||||
|
||||
if client.inner.is_closed() {
|
||||
info!(%conn_id, "{}: throwing away connection '{conn_info}' because connection is closed", pool_name);
|
||||
info!(%compute_conn_id, "{}: throwing away connection '{conn_info}' because connection is closed", pool_name);
|
||||
return;
|
||||
}
|
||||
|
||||
if conn_count >= max_conn {
|
||||
info!(%conn_id, "{}: throwing away connection '{conn_info}' because pool is full", pool_name);
|
||||
info!(%compute_conn_id, "{}: throwing away connection '{conn_info}' because pool is full", pool_name);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -241,9 +242,9 @@ impl<C: ClientInnerExt> EndpointConnPool<C> {
|
||||
|
||||
// do logging outside of the mutex
|
||||
if returned {
|
||||
debug!(%conn_id, "{pool_name}: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
|
||||
debug!(%compute_conn_id, "{pool_name}: returning connection '{conn_info}' back to the pool, total_conns={total_conns}, for this (db, user)={per_db_size}");
|
||||
} else {
|
||||
info!(%conn_id, "{pool_name}: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
|
||||
info!(%compute_conn_id, "{pool_name}: throwing away connection '{conn_info}' because pool is full, total_conns={total_conns}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ use super::conn_pool_lib::{
|
||||
};
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
|
||||
use crate::id::ComputeConnId;
|
||||
use crate::metrics::{HttpEndpointPoolsGuard, Metrics};
|
||||
use crate::protocol2::ConnectionInfoExtra;
|
||||
use crate::types::EndpointCacheKey;
|
||||
@@ -65,7 +66,7 @@ impl<C: ClientInnerExt + Clone> HttpConnPool<C> {
|
||||
}
|
||||
}
|
||||
|
||||
fn remove_conn(&mut self, conn_id: uuid::Uuid) -> bool {
|
||||
fn remove_conn(&mut self, conn_id: ComputeConnId) -> bool {
|
||||
let Self {
|
||||
conns,
|
||||
global_connections_count,
|
||||
@@ -73,7 +74,7 @@ impl<C: ClientInnerExt + Clone> HttpConnPool<C> {
|
||||
} = self;
|
||||
|
||||
let old_len = conns.len();
|
||||
conns.retain(|entry| entry.conn.conn_id != conn_id);
|
||||
conns.retain(|entry| entry.conn.compute_conn_id != conn_id);
|
||||
let new_len = conns.len();
|
||||
let removed = old_len - new_len;
|
||||
if removed > 0 {
|
||||
@@ -135,7 +136,10 @@ impl<C: ClientInnerExt + Clone> GlobalConnPool<C, HttpConnPool<C>> {
|
||||
return result;
|
||||
};
|
||||
|
||||
tracing::Span::current().record("conn_id", tracing::field::display(client.conn.conn_id));
|
||||
tracing::Span::current().record(
|
||||
"conn_id",
|
||||
tracing::field::display(client.conn.compute_conn_id),
|
||||
);
|
||||
debug!(
|
||||
cold_start_info = ColdStartInfo::HttpPoolHit.as_str(),
|
||||
"pool: reusing connection '{conn_info}'"
|
||||
@@ -194,13 +198,13 @@ pub(crate) fn poll_http2_client(
|
||||
conn_info: &ConnInfo,
|
||||
client: Send,
|
||||
connection: Connect,
|
||||
conn_id: uuid::Uuid,
|
||||
compute_conn_id: ComputeConnId,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> Client<Send> {
|
||||
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
|
||||
let session_id = ctx.session_id();
|
||||
|
||||
let span = info_span!(parent: None, "connection", %conn_id);
|
||||
let span = info_span!(parent: None, "connection", %compute_conn_id);
|
||||
let cold_start_info = ctx.cold_start_info();
|
||||
span.in_scope(|| {
|
||||
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
|
||||
@@ -212,7 +216,7 @@ pub(crate) fn poll_http2_client(
|
||||
let client = ClientInnerCommon {
|
||||
inner: client.clone(),
|
||||
aux: aux.clone(),
|
||||
conn_id,
|
||||
compute_conn_id,
|
||||
data: ClientDataEnum::Http(ClientDataHttp()),
|
||||
};
|
||||
pool.write().conns.push_back(ConnPoolEntry {
|
||||
@@ -241,7 +245,7 @@ pub(crate) fn poll_http2_client(
|
||||
|
||||
// remove from connection pool
|
||||
if let Some(pool) = pool.clone().upgrade()
|
||||
&& pool.write().remove_conn(conn_id)
|
||||
&& pool.write().remove_conn(compute_conn_id)
|
||||
{
|
||||
info!("closed connection removed");
|
||||
}
|
||||
@@ -252,7 +256,7 @@ pub(crate) fn poll_http2_client(
|
||||
let client = ClientInnerCommon {
|
||||
inner: client,
|
||||
aux,
|
||||
conn_id,
|
||||
compute_conn_id,
|
||||
data: ClientDataEnum::Http(ClientDataHttp()),
|
||||
};
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ use http_body_util::{BodyExt, Full};
|
||||
use http_utils::error::ApiError;
|
||||
use serde::Serialize;
|
||||
use url::Url;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::conn_pool::{AuthData, ConnInfoWithAuth};
|
||||
use super::conn_pool_lib::ConnInfo;
|
||||
@@ -18,6 +17,7 @@ use super::error::{ConnInfoError, Credentials};
|
||||
use crate::auth::backend::ComputeUserInfo;
|
||||
use crate::config::AuthenticationConfig;
|
||||
use crate::context::RequestContext;
|
||||
use crate::id::RequestId;
|
||||
use crate::metrics::{Metrics, SniGroup, SniKind};
|
||||
use crate::pqproto::StartupMessageParams;
|
||||
use crate::proxy::NeonOptions;
|
||||
@@ -34,9 +34,8 @@ pub(super) static TXN_ISOLATION_LEVEL: HeaderName =
|
||||
pub(super) static TXN_READ_ONLY: HeaderName = HeaderName::from_static("neon-batch-read-only");
|
||||
pub(super) static TXN_DEFERRABLE: HeaderName = HeaderName::from_static("neon-batch-deferrable");
|
||||
|
||||
pub(crate) fn uuid_to_header_value(id: Uuid) -> HeaderValue {
|
||||
let mut uuid = [0; uuid::fmt::Hyphenated::LENGTH];
|
||||
HeaderValue::from_str(id.as_hyphenated().encode_lower(&mut uuid[..]))
|
||||
pub(crate) fn uuid_to_header_value(id: RequestId) -> HeaderValue {
|
||||
HeaderValue::from_maybe_shared(Bytes::from(id.to_string().into_bytes()))
|
||||
.expect("uuid hyphenated format should be all valid header characters")
|
||||
}
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ use super::conn_pool_lib::{
|
||||
use super::sql_over_http::SqlOverHttpError;
|
||||
use crate::context::RequestContext;
|
||||
use crate::control_plane::messages::{ColdStartInfo, MetricsAuxInfo};
|
||||
use crate::id::{ComputeConnId, RequestId};
|
||||
use crate::metrics::Metrics;
|
||||
|
||||
pub(crate) const EXT_NAME: &str = "pg_session_jwt";
|
||||
@@ -48,14 +49,14 @@ pub(crate) const EXT_SCHEMA: &str = "auth";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct ClientDataLocal {
|
||||
session: tokio::sync::watch::Sender<uuid::Uuid>,
|
||||
session: tokio::sync::watch::Sender<RequestId>,
|
||||
cancel: CancellationToken,
|
||||
key: SigningKey,
|
||||
jti: u64,
|
||||
}
|
||||
|
||||
impl ClientDataLocal {
|
||||
pub fn session(&mut self) -> &mut tokio::sync::watch::Sender<uuid::Uuid> {
|
||||
pub fn session(&mut self) -> &mut tokio::sync::watch::Sender<RequestId> {
|
||||
&mut self.session
|
||||
}
|
||||
|
||||
@@ -167,14 +168,14 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
|
||||
client: C,
|
||||
mut connection: postgres_client::Connection<TcpStream, NoTlsStream>,
|
||||
key: SigningKey,
|
||||
conn_id: uuid::Uuid,
|
||||
compute_conn_id: ComputeConnId,
|
||||
aux: MetricsAuxInfo,
|
||||
) -> Client<C> {
|
||||
let conn_gauge = Metrics::get().proxy.db_connections.guard(ctx.protocol());
|
||||
let mut session_id = ctx.session_id();
|
||||
let (tx, mut rx) = tokio::sync::watch::channel(session_id);
|
||||
|
||||
let span = info_span!(parent: None, "connection", %conn_id);
|
||||
let span = info_span!(parent: None, "connection", %compute_conn_id);
|
||||
let cold_start_info = ctx.cold_start_info();
|
||||
span.in_scope(|| {
|
||||
info!(cold_start_info = cold_start_info.as_str(), %conn_info, %session_id, "new connection");
|
||||
@@ -218,7 +219,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
|
||||
if let Some(pool) = pool.clone().upgrade() {
|
||||
// remove client from pool - should close the connection if it's idle.
|
||||
// does nothing if the client is currently checked-out and in-use
|
||||
if pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
|
||||
if pool.global_pool.write().remove_client(db_user.clone(), compute_conn_id) {
|
||||
info!("idle connection removed");
|
||||
}
|
||||
}
|
||||
@@ -250,7 +251,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
|
||||
|
||||
// remove from connection pool
|
||||
if let Some(pool) = pool.clone().upgrade()
|
||||
&& pool.global_pool.write().remove_client(db_user.clone(), conn_id) {
|
||||
&& pool.global_pool.write().remove_client(db_user.clone(), compute_conn_id) {
|
||||
info!("closed connection removed");
|
||||
}
|
||||
|
||||
@@ -263,7 +264,7 @@ pub(crate) fn poll_client<C: ClientInnerExt>(
|
||||
let inner = ClientInnerCommon {
|
||||
inner: client,
|
||||
aux,
|
||||
conn_id,
|
||||
compute_conn_id,
|
||||
data: ClientDataEnum::Local(ClientDataLocal {
|
||||
session: tx,
|
||||
cancel,
|
||||
|
||||
@@ -16,16 +16,16 @@ mod websocket;
|
||||
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::pin::{Pin, pin};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Context;
|
||||
use arc_swap::ArcSwapOption;
|
||||
use async_trait::async_trait;
|
||||
use atomic_take::AtomicTake;
|
||||
use bytes::Bytes;
|
||||
pub use conn_pool_lib::GlobalConnPoolOptions;
|
||||
use futures::TryFutureExt;
|
||||
use futures::future::{Either, select};
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use http::{Method, Response, StatusCode};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Empty};
|
||||
@@ -48,7 +48,8 @@ use crate::cancellation::CancellationHandler;
|
||||
use crate::config::{ProxyConfig, ProxyProtocolV2};
|
||||
use crate::context::RequestContext;
|
||||
use crate::ext::TaskExt;
|
||||
use crate::metrics::Metrics;
|
||||
use crate::id::{ClientConnId, RequestId};
|
||||
use crate::metrics::{Metrics, Protocol};
|
||||
use crate::protocol2::{ConnectHeader, ConnectionInfo, read_proxy_protocol};
|
||||
use crate::rate_limiter::EndpointRateLimiter;
|
||||
use crate::serverless::backend::PoolingBackend;
|
||||
@@ -131,13 +132,12 @@ pub async fn task_main(
|
||||
tracing::error!("could not set nodelay: {e}");
|
||||
continue;
|
||||
}
|
||||
let conn_id = uuid::Uuid::new_v4();
|
||||
let http_conn_span = tracing::info_span!("http_conn", ?conn_id);
|
||||
let conn_id = ClientConnId::new();
|
||||
|
||||
let n_connections = Metrics::get()
|
||||
.proxy
|
||||
.client_connections
|
||||
.sample(crate::metrics::Protocol::Http);
|
||||
.sample(Protocol::Http);
|
||||
tracing::trace!(?n_connections, threshold = ?config.http_config.client_conn_threshold, "check");
|
||||
if n_connections > config.http_config.client_conn_threshold {
|
||||
tracing::trace!("attempting to cancel a random connection");
|
||||
@@ -154,46 +154,41 @@ pub async fn task_main(
|
||||
let cancellation_handler = cancellation_handler.clone();
|
||||
let endpoint_rate_limiter = endpoint_rate_limiter.clone();
|
||||
let cancellations = cancellations.clone();
|
||||
connections.spawn(
|
||||
async move {
|
||||
let conn_token2 = conn_token.clone();
|
||||
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
|
||||
connections.spawn(async move {
|
||||
let conn_token2 = conn_token.clone();
|
||||
let _cancel_guard = config.http_config.cancel_set.insert(conn_id, conn_token2);
|
||||
|
||||
let session_id = uuid::Uuid::new_v4();
|
||||
let _gauge = Metrics::get()
|
||||
.proxy
|
||||
.client_connections
|
||||
.guard(Protocol::Http);
|
||||
|
||||
let _gauge = Metrics::get()
|
||||
.proxy
|
||||
.client_connections
|
||||
.guard(crate::metrics::Protocol::Http);
|
||||
let startup_result = Box::pin(connection_startup(
|
||||
config,
|
||||
tls_acceptor,
|
||||
conn_id,
|
||||
conn,
|
||||
peer_addr,
|
||||
))
|
||||
.await;
|
||||
let Some((conn, conn_info)) = startup_result else {
|
||||
return;
|
||||
};
|
||||
|
||||
let startup_result = Box::pin(connection_startup(
|
||||
config,
|
||||
tls_acceptor,
|
||||
session_id,
|
||||
conn,
|
||||
peer_addr,
|
||||
))
|
||||
.await;
|
||||
let Some((conn, conn_info)) = startup_result else {
|
||||
return;
|
||||
};
|
||||
|
||||
Box::pin(connection_handler(
|
||||
config,
|
||||
backend,
|
||||
connections2,
|
||||
cancellations,
|
||||
cancellation_handler,
|
||||
endpoint_rate_limiter,
|
||||
conn_token,
|
||||
conn,
|
||||
conn_info,
|
||||
session_id,
|
||||
))
|
||||
.await;
|
||||
}
|
||||
.instrument(http_conn_span),
|
||||
);
|
||||
Box::pin(connection_handler(
|
||||
config,
|
||||
backend,
|
||||
connections2,
|
||||
cancellations,
|
||||
cancellation_handler,
|
||||
endpoint_rate_limiter,
|
||||
conn_token,
|
||||
conn,
|
||||
conn_info,
|
||||
conn_id,
|
||||
))
|
||||
.await;
|
||||
});
|
||||
}
|
||||
|
||||
connections.wait().await;
|
||||
@@ -230,7 +225,7 @@ impl MaybeTlsAcceptor for &'static ArcSwapOption<crate::config::TlsConfig> {
|
||||
async fn connection_startup(
|
||||
config: &ProxyConfig,
|
||||
tls_acceptor: Arc<dyn MaybeTlsAcceptor>,
|
||||
session_id: uuid::Uuid,
|
||||
conn_id: ClientConnId,
|
||||
conn: TcpStream,
|
||||
peer_addr: SocketAddr,
|
||||
) -> Option<(AsyncRW, ConnectionInfo)> {
|
||||
@@ -265,12 +260,12 @@ async fn connection_startup(
|
||||
IpAddr::V4(ip) => ip.is_private(),
|
||||
IpAddr::V6(_) => false,
|
||||
};
|
||||
info!(?session_id, %conn_info, "accepted new TCP connection");
|
||||
info!(%conn_id, %conn_info, "accepted new TCP connection");
|
||||
|
||||
// try upgrade to TLS, but with a timeout.
|
||||
let conn = match timeout(config.handshake_timeout, tls_acceptor.accept(conn)).await {
|
||||
Ok(Ok(conn)) => {
|
||||
info!(?session_id, %conn_info, "accepted new TLS connection");
|
||||
info!(%conn_id, %conn_info, "accepted new TLS connection");
|
||||
conn
|
||||
}
|
||||
// The handshake failed
|
||||
@@ -278,7 +273,7 @@ async fn connection_startup(
|
||||
if !has_private_peer_addr {
|
||||
Metrics::get().proxy.tls_handshake_failures.inc();
|
||||
}
|
||||
warn!(?session_id, %conn_info, "failed to accept TLS connection: {e:?}");
|
||||
warn!(%conn_id, %conn_info, "failed to accept TLS connection: {e:?}");
|
||||
return None;
|
||||
}
|
||||
// The handshake timed out
|
||||
@@ -286,7 +281,7 @@ async fn connection_startup(
|
||||
if !has_private_peer_addr {
|
||||
Metrics::get().proxy.tls_handshake_failures.inc();
|
||||
}
|
||||
warn!(?session_id, %conn_info, "failed to accept TLS connection: {e:?}");
|
||||
warn!(%conn_id, %conn_info, "failed to accept TLS connection: {e:?}");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
@@ -309,10 +304,8 @@ async fn connection_handler(
|
||||
cancellation_token: CancellationToken,
|
||||
conn: AsyncRW,
|
||||
conn_info: ConnectionInfo,
|
||||
session_id: uuid::Uuid,
|
||||
conn_id: ClientConnId,
|
||||
) {
|
||||
let session_id = AtomicTake::new(session_id);
|
||||
|
||||
// Cancel all current inflight HTTP requests if the HTTP connection is closed.
|
||||
let http_cancellation_token = CancellationToken::new();
|
||||
let _cancel_connection = http_cancellation_token.clone().drop_guard();
|
||||
@@ -322,20 +315,6 @@ async fn connection_handler(
|
||||
let conn = server.serve_connection_with_upgrades(
|
||||
hyper_util::rt::TokioIo::new(conn),
|
||||
hyper::service::service_fn(move |req: hyper::Request<Incoming>| {
|
||||
// First HTTP request shares the same session ID
|
||||
let mut session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4);
|
||||
|
||||
if matches!(backend.auth_backend, crate::auth::Backend::Local(_)) {
|
||||
// take session_id from request, if given.
|
||||
if let Some(id) = req
|
||||
.headers()
|
||||
.get(&NEON_REQUEST_ID)
|
||||
.and_then(|id| uuid::Uuid::try_parse_ascii(id.as_bytes()).ok())
|
||||
{
|
||||
session_id = id;
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel the current inflight HTTP request if the request stream is closed.
|
||||
// This is slightly different to `_cancel_connection` in that
|
||||
// h2 can cancel individual requests with a `RST_STREAM`.
|
||||
@@ -352,7 +331,7 @@ async fn connection_handler(
|
||||
backend.clone(),
|
||||
connections.clone(),
|
||||
cancellation_handler.clone(),
|
||||
session_id,
|
||||
conn_id,
|
||||
conn_info2.clone(),
|
||||
http_request_token,
|
||||
endpoint_rate_limiter.clone(),
|
||||
@@ -362,15 +341,8 @@ async fn connection_handler(
|
||||
.map_ok_or_else(api_error_into_response, |r| r),
|
||||
);
|
||||
async move {
|
||||
let mut res = handler.await;
|
||||
let res = handler.await;
|
||||
cancel_request.disarm();
|
||||
|
||||
// add the session ID to the response
|
||||
if let Ok(resp) = &mut res {
|
||||
resp.headers_mut()
|
||||
.append(&NEON_REQUEST_ID, uuid_to_header_value(session_id));
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
}),
|
||||
@@ -392,6 +364,44 @@ async fn connection_handler(
|
||||
}
|
||||
}
|
||||
|
||||
fn get_request_id(backend: &PoolingBackend, req: &hyper::Request<Incoming>) -> RequestId {
|
||||
if matches!(backend.auth_backend, crate::auth::Backend::Local(_)) {
|
||||
// take session_id from request, if given.
|
||||
|
||||
if let Some(id) = req
|
||||
.headers()
|
||||
.get(&NEON_REQUEST_ID)
|
||||
.and_then(|id| uuid::Uuid::try_parse_ascii(id.as_bytes()).ok())
|
||||
{
|
||||
return RequestId::from_uuid(id);
|
||||
}
|
||||
|
||||
if let Some(id) = req
|
||||
.headers()
|
||||
.get(&NEON_REQUEST_ID)
|
||||
.and_then(|id| id.to_str().ok())
|
||||
.and_then(|id| RequestId::from_str(id).ok())
|
||||
{
|
||||
return id;
|
||||
}
|
||||
}
|
||||
|
||||
RequestId::new()
|
||||
}
|
||||
|
||||
fn set_request_id<T, E>(
|
||||
mut res: Result<hyper::Response<T>, E>,
|
||||
session_id: RequestId,
|
||||
) -> Result<hyper::Response<T>, E> {
|
||||
// add the session ID to the response
|
||||
if let Ok(resp) = &mut res {
|
||||
resp.headers_mut()
|
||||
.append(&NEON_REQUEST_ID, uuid_to_header_value(session_id));
|
||||
}
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn request_handler(
|
||||
mut request: hyper::Request<Incoming>,
|
||||
@@ -399,7 +409,7 @@ async fn request_handler(
|
||||
backend: Arc<PoolingBackend>,
|
||||
ws_connections: TaskTracker,
|
||||
cancellation_handler: Arc<CancellationHandler>,
|
||||
session_id: uuid::Uuid,
|
||||
conn_id: ClientConnId,
|
||||
conn_info: ConnectionInfo,
|
||||
// used to cancel in-flight HTTP requests. not used to cancel websockets
|
||||
http_cancellation_token: CancellationToken,
|
||||
@@ -417,7 +427,8 @@ async fn request_handler(
|
||||
if config.http_config.accept_websockets
|
||||
&& framed_websockets::upgrade::is_upgrade_request(&request)
|
||||
{
|
||||
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Ws);
|
||||
let session_id = RequestId::from_uuid(conn_id.uuid());
|
||||
let ctx = RequestContext::new(conn_id, session_id, conn_info, Protocol::Ws);
|
||||
|
||||
ctx.set_user_agent(
|
||||
request
|
||||
@@ -457,7 +468,8 @@ async fn request_handler(
|
||||
// Return the response so the spawned future can continue.
|
||||
Ok(response.map(|b| b.map_err(|x| match x {}).boxed()))
|
||||
} else if request.uri().path() == "/sql" && *request.method() == Method::POST {
|
||||
let ctx = RequestContext::new(session_id, conn_info, crate::metrics::Protocol::Http);
|
||||
let session_id = get_request_id(&backend, &request);
|
||||
let ctx = RequestContext::new(conn_id, session_id, conn_info, Protocol::Http);
|
||||
let span = ctx.span();
|
||||
|
||||
let testodrome_id = request
|
||||
@@ -473,6 +485,7 @@ async fn request_handler(
|
||||
|
||||
sql_over_http::handle(config, ctx, request, backend, http_cancellation_token)
|
||||
.instrument(span)
|
||||
.map(|res| set_request_id(res, session_id))
|
||||
.await
|
||||
} else if request.uri().path() == "/sql" && *request.method() == Method::OPTIONS {
|
||||
Response::builder()
|
||||
|
||||
@@ -21,7 +21,8 @@ pub fn check_permission(claims: &Claims, tenant_id: Option<TenantId>) -> Result<
|
||||
| Scope::GenerationsApi
|
||||
| Scope::Infra
|
||||
| Scope::Scrubber
|
||||
| Scope::ControllerPeer,
|
||||
| Scope::ControllerPeer
|
||||
| Scope::TenantEndpoint,
|
||||
_,
|
||||
) => Err(AuthError(
|
||||
format!(
|
||||
|
||||
@@ -52,6 +52,7 @@ tokio-rustls.workspace = true
|
||||
tokio-util.workspace = true
|
||||
tokio.workspace = true
|
||||
tracing.workspace = true
|
||||
uuid.workspace = true
|
||||
measured.workspace = true
|
||||
rustls.workspace = true
|
||||
scopeguard.workspace = true
|
||||
@@ -63,6 +64,7 @@ tokio-postgres-rustls.workspace = true
|
||||
diesel = { version = "2.2.6", features = [
|
||||
"serde_json",
|
||||
"chrono",
|
||||
"uuid",
|
||||
] }
|
||||
diesel-async = { version = "0.5.2", features = ["postgres", "bb8", "async-connection-wrapper"] }
|
||||
diesel_migrations = { version = "2.2.0" }
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
DROP TABLE hadron_safekeepers;
|
||||
DROP TABLE hadron_timeline_safekeepers;
|
||||
@@ -0,0 +1,17 @@
|
||||
-- hadron_safekeepers keeps track of all Safe Keeper nodes that exist in the system.
|
||||
-- Upon startup, each Safe Keeper reaches out to the hadron cluster coordinator to register its node ID and listen addresses.
|
||||
|
||||
CREATE TABLE hadron_safekeepers (
|
||||
sk_node_id BIGINT PRIMARY KEY NOT NULL,
|
||||
listen_http_addr VARCHAR NOT NULL,
|
||||
listen_http_port INTEGER NOT NULL,
|
||||
listen_pg_addr VARCHAR NOT NULL,
|
||||
listen_pg_port INTEGER NOT NULL
|
||||
);
|
||||
|
||||
CREATE TABLE hadron_timeline_safekeepers (
|
||||
timeline_id VARCHAR NOT NULL,
|
||||
sk_node_id BIGINT NOT NULL,
|
||||
legacy_endpoint_id UUID DEFAULT NULL,
|
||||
PRIMARY KEY(timeline_id, sk_node_id)
|
||||
);
|
||||
@@ -1,4 +1,5 @@
|
||||
use utils::auth::{AuthError, Claims, Scope};
|
||||
use uuid::Uuid;
|
||||
|
||||
pub fn check_permission(claims: &Claims, required_scope: Scope) -> Result<(), AuthError> {
|
||||
if claims.scope != required_scope {
|
||||
@@ -7,3 +8,14 @@ pub fn check_permission(claims: &Claims, required_scope: Scope) -> Result<(), Au
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn check_endpoint_permission(claims: &Claims, endpoint_id: Uuid) -> Result<(), AuthError> {
|
||||
if claims.scope != Scope::TenantEndpoint {
|
||||
return Err(AuthError("Scope mismatch. Permission denied".into()));
|
||||
}
|
||||
if claims.endpoint_id != Some(endpoint_id) {
|
||||
return Err(AuthError("Endpoint id mismatch. Permission denied".into()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -810,6 +810,7 @@ impl ComputeHook {
|
||||
let send_locked = tokio::select! {
|
||||
guard = send_lock.lock_owned() => {guard},
|
||||
_ = cancel.cancelled() => {
|
||||
tracing::info!("Notification cancelled while waiting for lock");
|
||||
return Err(NotifyError::ShuttingDown)
|
||||
}
|
||||
};
|
||||
@@ -851,11 +852,32 @@ impl ComputeHook {
|
||||
let notify_url = compute_hook_url.as_ref().unwrap();
|
||||
self.do_notify(notify_url, &request, cancel).await
|
||||
} else {
|
||||
self.do_notify_local::<M>(&request).await.map_err(|e| {
|
||||
match self.do_notify_local::<M>(&request).await.map_err(|e| {
|
||||
// This path is for testing only, so munge the error into our prod-style error type.
|
||||
tracing::error!("neon_local notification hook failed: {e}");
|
||||
NotifyError::Fatal(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
})
|
||||
if e.to_string().contains("refresh-configuration-pending") {
|
||||
// If the error message mentions "refresh-configuration-pending", it means the compute node
|
||||
// rejected our notification request because it is already trying to reconfigure itself. We
|
||||
// can proceed with the rest of the reconciliation process as the compute node already
|
||||
// discovers the need to reconfigure and will eventually update its configuration once
|
||||
// we update the pageserver mappings. In fact, it is important that we continue with
|
||||
// reconciliation to make sure we update the pageserver mappings to unblock the compute node.
|
||||
tracing::info!("neon_local notification hook failed: {e}");
|
||||
tracing::info!("Notification failed likely due to compute node self-reconfiguration, will retry.");
|
||||
Ok(())
|
||||
} else {
|
||||
tracing::error!("neon_local notification hook failed: {e}");
|
||||
Err(NotifyError::Fatal(StatusCode::INTERNAL_SERVER_ERROR))
|
||||
}
|
||||
}) {
|
||||
// Compute node accepted the notification request. Ok to proceed.
|
||||
Ok(_) => Ok(()),
|
||||
// Compute node rejected our request but it is already self-reconfiguring. Ok to proceed.
|
||||
Err(Ok(_)) => Ok(()),
|
||||
// Fail the reconciliation attempt in all other cases. Recall that this whole code path involving
|
||||
// neon_local is for testing only. In production we always retry failed reconciliations so we
|
||||
// don't have any deadends here.
|
||||
Err(Err(e)) => Err(e),
|
||||
}
|
||||
};
|
||||
|
||||
match result {
|
||||
|
||||
44
storage_controller/src/hadron_utils.rs
Normal file
44
storage_controller/src/hadron_utils.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use rand::Rng;
|
||||
use utils::shard::TenantShardId;
|
||||
|
||||
static CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!@#$%^&*()";
|
||||
|
||||
/// Generate a random string of `length` that can be used as a password. The generated string
|
||||
/// contains alphanumeric characters and special characters (!@#$%^&*())
|
||||
pub fn generate_random_password(length: usize) -> String {
|
||||
let mut rng = rand::thread_rng();
|
||||
(0..length)
|
||||
.map(|_| {
|
||||
let idx = rng.gen_range(0..CHARSET.len());
|
||||
CHARSET[idx] as char
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub(crate) struct TenantShardSizeMap {
|
||||
#[expect(dead_code)]
|
||||
pub map: BTreeMap<TenantShardId, u64>,
|
||||
}
|
||||
|
||||
impl TenantShardSizeMap {
|
||||
pub fn new(map: BTreeMap<TenantShardId, u64>) -> Self {
|
||||
Self { map }
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_generate_random_password() {
|
||||
let pwd1 = generate_random_password(10);
|
||||
assert_eq!(pwd1.len(), 10);
|
||||
let pwd2 = generate_random_password(10);
|
||||
assert_ne!(pwd1, pwd2);
|
||||
assert!(pwd1.chars().all(|c| CHARSET.contains(&(c as u8))));
|
||||
assert!(pwd2.chars().all(|c| CHARSET.contains(&(c as u8))));
|
||||
}
|
||||
}
|
||||
@@ -48,7 +48,10 @@ use crate::metrics::{
|
||||
};
|
||||
use crate::persistence::SafekeeperUpsert;
|
||||
use crate::reconciler::ReconcileError;
|
||||
use crate::service::{LeadershipStatus, RECONCILE_TIMEOUT, STARTUP_RECONCILE_TIMEOUT, Service};
|
||||
use crate::service::{
|
||||
LeadershipStatus, RECONCILE_TIMEOUT, STARTUP_RECONCILE_TIMEOUT, Service,
|
||||
TenantMutationLocations,
|
||||
};
|
||||
|
||||
/// State available to HTTP request handlers
|
||||
pub struct HttpState {
|
||||
@@ -734,77 +737,104 @@ async fn handle_tenant_timeline_passthrough(
|
||||
path
|
||||
);
|
||||
|
||||
// Find the node that holds shard zero
|
||||
let (node, tenant_shard_id, consistent) = if tenant_or_shard_id.is_unsharded() {
|
||||
service
|
||||
let tenant_shard_id = if tenant_or_shard_id.is_unsharded() {
|
||||
// If the request contains only tenant ID, find the node that holds shard zero
|
||||
let (_, shard_id) = service
|
||||
.tenant_shard0_node(tenant_or_shard_id.tenant_id)
|
||||
.await?
|
||||
.await?;
|
||||
shard_id
|
||||
} else {
|
||||
let (node, consistent) = service.tenant_shard_node(tenant_or_shard_id).await?;
|
||||
(node, tenant_or_shard_id, consistent)
|
||||
tenant_or_shard_id
|
||||
};
|
||||
|
||||
// Callers will always pass an unsharded tenant ID. Before proxying, we must
|
||||
// rewrite this to a shard-aware shard zero ID.
|
||||
let path = format!("{path}");
|
||||
let tenant_str = tenant_or_shard_id.tenant_id.to_string();
|
||||
let tenant_shard_str = format!("{tenant_shard_id}");
|
||||
let path = path.replace(&tenant_str, &tenant_shard_str);
|
||||
let service_inner = service.clone();
|
||||
|
||||
let latency = &METRICS_REGISTRY
|
||||
.metrics_group
|
||||
.storage_controller_passthrough_request_latency;
|
||||
service.tenant_shard_remote_mutation(tenant_shard_id, |locations| async move {
|
||||
let TenantMutationLocations(locations) = locations;
|
||||
if locations.is_empty() {
|
||||
return Err(ApiError::NotFound(anyhow::anyhow!("Tenant {} not found", tenant_or_shard_id.tenant_id).into()));
|
||||
}
|
||||
|
||||
let path_label = path_without_ids(&path)
|
||||
.split('/')
|
||||
.filter(|token| !token.is_empty())
|
||||
.collect::<Vec<_>>()
|
||||
.join("_");
|
||||
let labels = PageserverRequestLabelGroup {
|
||||
pageserver_id: &node.get_id().to_string(),
|
||||
path: &path_label,
|
||||
method: crate::metrics::Method::Get,
|
||||
};
|
||||
let (tenant_or_shard_id, locations) = locations.into_iter().next().unwrap();
|
||||
let node = locations.latest.node;
|
||||
|
||||
let _timer = latency.start_timer(labels.clone());
|
||||
// Callers will always pass an unsharded tenant ID. Before proxying, we must
|
||||
// rewrite this to a shard-aware shard zero ID.
|
||||
let path = format!("{path}");
|
||||
let tenant_str = tenant_or_shard_id.tenant_id.to_string();
|
||||
let tenant_shard_str = format!("{tenant_shard_id}");
|
||||
let path = path.replace(&tenant_str, &tenant_shard_str);
|
||||
|
||||
let client = mgmt_api::Client::new(
|
||||
service.get_http_client().clone(),
|
||||
node.base_url(),
|
||||
service.get_config().pageserver_jwt_token.as_deref(),
|
||||
);
|
||||
let resp = client.op_raw(method, path).await.map_err(|e|
|
||||
// We return 503 here because if we can't successfully send a request to the pageserver,
|
||||
// either we aren't available or the pageserver is unavailable.
|
||||
ApiError::ResourceUnavailable(format!("Error sending pageserver API request to {node}: {e}").into()))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let error_counter = &METRICS_REGISTRY
|
||||
let latency = &METRICS_REGISTRY
|
||||
.metrics_group
|
||||
.storage_controller_passthrough_request_error;
|
||||
error_counter.inc(labels);
|
||||
}
|
||||
.storage_controller_passthrough_request_latency;
|
||||
|
||||
// Transform 404 into 503 if we raced with a migration
|
||||
if resp.status() == reqwest::StatusCode::NOT_FOUND && !consistent {
|
||||
// Rather than retry here, send the client a 503 to prompt a retry: this matches
|
||||
// the pageserver's use of 503, and all clients calling this API should retry on 503.
|
||||
return Err(ApiError::ResourceUnavailable(
|
||||
format!("Pageserver {node} returned 404 due to ongoing migration, retry later").into(),
|
||||
));
|
||||
}
|
||||
let path_label = path_without_ids(&path)
|
||||
.split('/')
|
||||
.filter(|token| !token.is_empty())
|
||||
.collect::<Vec<_>>()
|
||||
.join("_");
|
||||
let labels = PageserverRequestLabelGroup {
|
||||
pageserver_id: &node.get_id().to_string(),
|
||||
path: &path_label,
|
||||
method: crate::metrics::Method::Get,
|
||||
};
|
||||
|
||||
// We have a reqwest::Response, would like a http::Response
|
||||
let mut builder = hyper::Response::builder().status(map_reqwest_hyper_status(resp.status())?);
|
||||
for (k, v) in resp.headers() {
|
||||
builder = builder.header(k.as_str(), v.as_bytes());
|
||||
}
|
||||
let _timer = latency.start_timer(labels.clone());
|
||||
|
||||
let response = builder
|
||||
.body(Body::wrap_stream(resp.bytes_stream()))
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||
let client = mgmt_api::Client::new(
|
||||
service_inner.get_http_client().clone(),
|
||||
node.base_url(),
|
||||
service_inner.get_config().pageserver_jwt_token.as_deref(),
|
||||
);
|
||||
let resp = client.op_raw(method, path).await.map_err(|e|
|
||||
// We return 503 here because if we can't successfully send a request to the pageserver,
|
||||
// either we aren't available or the pageserver is unavailable.
|
||||
ApiError::ResourceUnavailable(format!("Error sending pageserver API request to {node}: {e}").into()))?;
|
||||
|
||||
Ok(response)
|
||||
if !resp.status().is_success() {
|
||||
let error_counter = &METRICS_REGISTRY
|
||||
.metrics_group
|
||||
.storage_controller_passthrough_request_error;
|
||||
error_counter.inc(labels);
|
||||
}
|
||||
let resp_staus = resp.status();
|
||||
|
||||
// We have a reqwest::Response, would like a http::Response
|
||||
let mut builder = hyper::Response::builder().status(map_reqwest_hyper_status(resp_staus)?);
|
||||
for (k, v) in resp.headers() {
|
||||
builder = builder.header(k.as_str(), v.as_bytes());
|
||||
}
|
||||
let resp_bytes = resp
|
||||
.bytes()
|
||||
.await
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||
// Inspect 404 errors: at this point, we know that the tenant exists, but the pageserver we route
|
||||
// the request to might not yet be ready. Therefore, if it is a _tenant_ not found error, we can
|
||||
// convert it into a 503. TODO: we should make this part of the check in `tenant_shard_remote_mutation`.
|
||||
// However, `tenant_shard_remote_mutation` currently cannot inspect the HTTP error response body,
|
||||
// so we have to do it here instead.
|
||||
if resp_staus == reqwest::StatusCode::NOT_FOUND {
|
||||
let resp_str = std::str::from_utf8(&resp_bytes)
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||
// We only handle "tenant not found" errors; other 404s like timeline not found should
|
||||
// be forwarded as-is.
|
||||
if Service::is_tenant_not_found_error(resp_str, tenant_or_shard_id.tenant_id) {
|
||||
// Rather than retry here, send the client a 503 to prompt a retry: this matches
|
||||
// the pageserver's use of 503, and all clients calling this API should retry on 503.
|
||||
return Err(ApiError::ResourceUnavailable(
|
||||
format!(
|
||||
"Pageserver {node} returned tenant 404 due to ongoing migration, retry later"
|
||||
)
|
||||
.into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
let response = builder
|
||||
.body(Body::from(resp_bytes))
|
||||
.map_err(|e| ApiError::InternalServerError(e.into()))?;
|
||||
Ok(response)
|
||||
}).await?
|
||||
}
|
||||
|
||||
async fn handle_tenant_locate(
|
||||
@@ -1085,9 +1115,10 @@ async fn handle_node_delete(req: Request<Body>) -> Result<Response<Body>, ApiErr
|
||||
|
||||
let state = get_state(&req);
|
||||
let node_id: NodeId = parse_request_param(&req, "node_id")?;
|
||||
let force: bool = parse_query_param(&req, "force")?.unwrap_or(false);
|
||||
json_response(
|
||||
StatusCode::OK,
|
||||
state.service.start_node_delete(node_id).await?,
|
||||
state.service.start_node_delete(node_id, force).await?,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ extern crate hyper0 as hyper;
|
||||
mod auth;
|
||||
mod background_node_operations;
|
||||
mod compute_hook;
|
||||
pub mod hadron_utils;
|
||||
mod heartbeater;
|
||||
pub mod http;
|
||||
mod id_lock_map;
|
||||
|
||||
@@ -76,8 +76,8 @@ pub(crate) struct StorageControllerMetricGroup {
|
||||
/// How many shards would like to reconcile but were blocked by concurrency limits
|
||||
pub(crate) storage_controller_pending_reconciles: measured::Gauge,
|
||||
|
||||
/// How many shards are keep-failing and will be ignored when considering to run optimizations
|
||||
pub(crate) storage_controller_keep_failing_reconciles: measured::Gauge,
|
||||
/// How many shards are stuck and will be ignored when considering to run optimizations
|
||||
pub(crate) storage_controller_stuck_reconciles: measured::Gauge,
|
||||
|
||||
/// HTTP request status counters for handled requests
|
||||
pub(crate) storage_controller_http_request_status:
|
||||
@@ -151,6 +151,29 @@ pub(crate) struct StorageControllerMetricGroup {
|
||||
/// Indicator of completed safekeeper reconciles, broken down by safekeeper.
|
||||
pub(crate) storage_controller_safekeeper_reconciles_complete:
|
||||
measured::CounterVec<SafekeeperReconcilerLabelGroupSet>,
|
||||
|
||||
/* BEGIN HADRON */
|
||||
/// Hadron `config_watcher` reconciliation runs completed, broken down by success/failure.
|
||||
pub(crate) storage_controller_config_watcher_complete:
|
||||
measured::CounterVec<ConfigWatcherCompleteLabelGroupSet>,
|
||||
|
||||
/// Hadron long waits for node state changes during drain and fill.
|
||||
pub(crate) storage_controller_drain_and_fill_long_waits: measured::Counter,
|
||||
|
||||
/// Set to 1 if we detect any page server pods with pending node pool rotation annotations.
|
||||
/// Requires manual reset after oncall investigation.
|
||||
pub(crate) storage_controller_ps_node_pool_rotation_pending: measured::Gauge,
|
||||
|
||||
/// Hadron storage scrubber status.
|
||||
pub(crate) storage_controller_storage_scrub_status:
|
||||
measured::CounterVec<StorageScrubberLabelGroupSet>,
|
||||
|
||||
/// Desired number of pageservers managed by the storage controller
|
||||
pub(crate) storage_controller_num_pageservers_desired: measured::Gauge,
|
||||
|
||||
/// Desired number of safekeepers managed by the storage controller
|
||||
pub(crate) storage_controller_num_safekeeper_desired: measured::Gauge,
|
||||
/* END HADRON */
|
||||
}
|
||||
|
||||
impl StorageControllerMetrics {
|
||||
@@ -173,6 +196,10 @@ impl Default for StorageControllerMetrics {
|
||||
.storage_controller_reconcile_complete
|
||||
.init_all_dense();
|
||||
|
||||
metrics_group
|
||||
.storage_controller_config_watcher_complete
|
||||
.init_all_dense();
|
||||
|
||||
Self {
|
||||
metrics_group,
|
||||
encoder: Mutex::new(measured::text::BufferedTextEncoder::new()),
|
||||
@@ -262,11 +289,48 @@ pub(crate) struct ReconcileLongRunningLabelGroup<'a> {
|
||||
pub(crate) sequence: &'a str,
|
||||
}
|
||||
|
||||
#[derive(measured::LabelGroup, Clone)]
|
||||
#[label(set = StorageScrubberLabelGroupSet)]
|
||||
pub(crate) struct StorageScrubberLabelGroup<'a> {
|
||||
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
|
||||
pub(crate) tenant_id: &'a str,
|
||||
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
|
||||
pub(crate) shard_number: &'a str,
|
||||
#[label(dynamic_with = lasso::ThreadedRodeo, default)]
|
||||
pub(crate) timeline_id: &'a str,
|
||||
pub(crate) outcome: StorageScrubberOutcome,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Clone, Copy)]
|
||||
pub(crate) enum StorageScrubberOutcome {
|
||||
PSOk,
|
||||
PSWarning,
|
||||
PSError,
|
||||
PSOrphan,
|
||||
SKOk,
|
||||
SKError,
|
||||
}
|
||||
|
||||
#[derive(measured::LabelGroup)]
|
||||
#[label(set = ConfigWatcherCompleteLabelGroupSet)]
|
||||
pub(crate) struct ConfigWatcherCompleteLabelGroup {
|
||||
// Reuse the ReconcileOutcome from the SC's reconciliation metrics.
|
||||
pub(crate) status: ReconcileOutcome,
|
||||
}
|
||||
|
||||
#[derive(FixedCardinalityLabel, Clone, Copy)]
|
||||
pub(crate) enum ReconcileOutcome {
|
||||
// Successfully reconciled everything.
|
||||
#[label(rename = "ok")]
|
||||
Success,
|
||||
// Used by tenant-shard reconciler only. Reconciled pageserver state successfully,
|
||||
// but failed to deliver the compute notification. This error is typically transient
|
||||
// but if its occurrence keeps increasing, it should be investigated.
|
||||
#[label(rename = "ok_no_notify")]
|
||||
SuccessNoNotify,
|
||||
// We failed to reconcile some state and the reconciliation will be retried.
|
||||
Error,
|
||||
// Reconciliation was cancelled.
|
||||
Cancel,
|
||||
}
|
||||
|
||||
|
||||
@@ -51,6 +51,39 @@ pub(crate) struct Node {
|
||||
cancel: CancellationToken,
|
||||
}
|
||||
|
||||
/// Width of the NodeId range reserved for each node pool: pool `p` owns the IDs
/// `p * ONE_MILLION ..= p * ONE_MILLION + 999_999`.
#[allow(dead_code)]
const ONE_MILLION: i64 = 1000000;

/// Converts a pool ID to a large number that can be used to assign unique IDs to pods in StatefulSets.
///
/// For example, if pool_id is 1, then the pods have NodeIds 1000000, 1000001, 1000002, etc.
/// If pool_id is None, then the pods have NodeIds 0, 1, 2, etc.
#[allow(dead_code)]
pub fn transform_pool_id(pool_id: Option<i32>) -> i64 {
    // `i64::from` is a lossless widening, and `i32::MAX * ONE_MILLION` still fits in an
    // `i64`, so this multiplication cannot overflow for any input.
    pool_id.map_or(0, |id| i64::from(id) * ONE_MILLION)
}

/// Recovers the pool ID from a NodeId produced by adding a per-pool offset to
/// [`transform_pool_id`].
///
/// The division truncates toward zero, so every NodeId in `0..ONE_MILLION` maps to pool 0.
/// The quotient is narrowed with `as`, which wraps silently if it does not fit in an `i32`.
#[allow(dead_code)]
pub fn get_pool_id_from_node_id(node_id: i64) -> i32 {
    (node_id / ONE_MILLION) as i32
}
|
||||
|
||||
/// Example pod name: page-server-0-1, safe-keeper-1-0
|
||||
#[allow(dead_code)]
|
||||
pub fn get_node_id_from_pod_name(pod_name: &str) -> anyhow::Result<NodeId> {
|
||||
let parts: Vec<&str> = pod_name.split('-').collect();
|
||||
if parts.len() != 4 {
|
||||
return Err(anyhow::anyhow!("Invalid pod name: {}", pod_name));
|
||||
}
|
||||
let pool_id = parts[2].parse::<i32>()?;
|
||||
let node_offset = parts[3].parse::<i64>()?;
|
||||
let node_id = transform_pool_id(Some(pool_id)) + node_offset;
|
||||
|
||||
Ok(NodeId(node_id as u64))
|
||||
}
|
||||
|
||||
/// When updating [`Node::availability`] we use this type to indicate to the caller
|
||||
/// whether/how they changed it.
|
||||
pub(crate) enum AvailabilityTransition {
|
||||
@@ -403,3 +436,25 @@ impl std::fmt::Debug for Node {
|
||||
write!(f, "{} ({})", self.id, self.listen_http_addr)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use utils::id::NodeId;
|
||||
|
||||
use crate::node::get_node_id_from_pod_name;
|
||||
|
||||
#[test]
|
||||
fn test_get_node_id_from_pod_name() {
|
||||
let pod_name = "page-server-3-12";
|
||||
let node_id = get_node_id_from_pod_name(pod_name).unwrap();
|
||||
assert_eq!(node_id, NodeId(3000012));
|
||||
|
||||
let pod_name = "safe-keeper-1-0";
|
||||
let node_id = get_node_id_from_pod_name(pod_name).unwrap();
|
||||
assert_eq!(node_id, NodeId(1000000));
|
||||
|
||||
let pod_name = "invalid-pod-name";
|
||||
let result = get_node_id_from_pod_name(pod_name);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,6 +14,8 @@ use reqwest::StatusCode;
|
||||
use utils::id::{NodeId, TenantId, TimelineId};
|
||||
use utils::lsn::Lsn;
|
||||
|
||||
use crate::hadron_utils::TenantShardSizeMap;
|
||||
|
||||
/// Thin wrapper around [`pageserver_client::mgmt_api::Client`]. It allows the storage
|
||||
/// controller to collect metrics in a non-intrusive manner.
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -86,6 +88,31 @@ impl PageserverClient {
|
||||
)
|
||||
}
|
||||
|
||||
#[expect(dead_code)]
|
||||
pub(crate) async fn tenant_timeline_compact(
|
||||
&self,
|
||||
tenant_shard_id: TenantShardId,
|
||||
timeline_id: TimelineId,
|
||||
force_image_layer_creation: bool,
|
||||
wait_until_done: bool,
|
||||
) -> Result<()> {
|
||||
measured_request!(
|
||||
"tenant_timeline_compact",
|
||||
crate::metrics::Method::Put,
|
||||
&self.node_id_label,
|
||||
self.inner
|
||||
.tenant_timeline_compact(
|
||||
tenant_shard_id,
|
||||
timeline_id,
|
||||
force_image_layer_creation,
|
||||
true,
|
||||
false,
|
||||
wait_until_done,
|
||||
)
|
||||
.await
|
||||
)
|
||||
}
|
||||
|
||||
/* BEGIN_HADRON */
|
||||
pub(crate) async fn tenant_timeline_describe(
|
||||
&self,
|
||||
@@ -101,6 +128,17 @@ impl PageserverClient {
|
||||
.await
|
||||
)
|
||||
}
|
||||
|
||||
#[expect(dead_code)]
|
||||
pub(crate) async fn list_tenant_visible_size(&self) -> Result<TenantShardSizeMap> {
|
||||
measured_request!(
|
||||
"list_tenant_visible_size",
|
||||
crate::metrics::Method::Get,
|
||||
&self.node_id_label,
|
||||
self.inner.list_tenant_visible_size().await
|
||||
)
|
||||
.map(TenantShardSizeMap::new)
|
||||
}
|
||||
/* END_HADRON */
|
||||
|
||||
pub(crate) async fn tenant_scan_remote_storage(
|
||||
@@ -365,6 +403,16 @@ impl PageserverClient {
|
||||
)
|
||||
}
|
||||
|
||||
#[expect(dead_code)]
|
||||
pub(crate) async fn reset_alert_gauges(&self) -> Result<()> {
|
||||
measured_request!(
|
||||
"reset_alert_gauges",
|
||||
crate::metrics::Method::Post,
|
||||
&self.node_id_label,
|
||||
self.inner.reset_alert_gauges().await
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) async fn wait_lsn(
|
||||
&self,
|
||||
tenant_shard_id: TenantShardId,
|
||||
|
||||
@@ -862,11 +862,11 @@ impl Reconciler {
|
||||
Some(conf) if conf.conf.as_ref() == Some(&wanted_conf) => {
|
||||
if refreshed {
|
||||
tracing::info!(
|
||||
node_id=%node.get_id(), "Observed configuration correct after refresh. Notifying compute.");
|
||||
node_id=%node.get_id(), "[Attached] Observed configuration correct after refresh. Notifying compute.");
|
||||
self.compute_notify().await?;
|
||||
} else {
|
||||
// Nothing to do
|
||||
tracing::info!(node_id=%node.get_id(), "Observed configuration already correct.");
|
||||
tracing::info!(node_id=%node.get_id(), "[Attached] Observed configuration already correct.");
|
||||
}
|
||||
}
|
||||
observed => {
|
||||
@@ -945,17 +945,17 @@ impl Reconciler {
|
||||
match self.observed.locations.get(&node.get_id()) {
|
||||
Some(conf) if conf.conf.as_ref() == Some(&wanted_conf) => {
|
||||
// Nothing to do
|
||||
tracing::info!(node_id=%node.get_id(), "Observed configuration already correct.")
|
||||
tracing::info!(node_id=%node.get_id(), "[Secondary] Observed configuration already correct.")
|
||||
}
|
||||
_ => {
|
||||
// Only try and configure secondary locations on nodes that are available. This
|
||||
// allows the reconciler to "succeed" while some secondaries are offline (e.g. after
|
||||
// a node failure, where the failed node will have a secondary intent)
|
||||
if node.is_available() {
|
||||
tracing::info!(node_id=%node.get_id(), "Observed configuration requires update.");
|
||||
tracing::info!(node_id=%node.get_id(), "[Secondary] Observed configuration requires update.");
|
||||
changes.push((node.clone(), wanted_conf))
|
||||
} else {
|
||||
tracing::info!(node_id=%node.get_id(), "Skipping configuration as secondary, node is unavailable");
|
||||
tracing::info!(node_id=%node.get_id(), "[Secondary] Skipping configuration as secondary, node is unavailable");
|
||||
self.observed
|
||||
.locations
|
||||
.insert(node.get_id(), ObservedStateLocation { conf: None });
|
||||
@@ -1066,6 +1066,9 @@ impl Reconciler {
|
||||
}
|
||||
result
|
||||
} else {
|
||||
tracing::info!(
|
||||
"Compute notification is skipped because the tenant shard does not have an attached (primary) location"
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,24 @@ diesel::table! {
|
||||
}
|
||||
}
|
||||
|
||||
diesel::table! {
|
||||
hadron_safekeepers (sk_node_id) {
|
||||
sk_node_id -> Int8,
|
||||
listen_http_addr -> Varchar,
|
||||
listen_http_port -> Int4,
|
||||
listen_pg_addr -> Varchar,
|
||||
listen_pg_port -> Int4,
|
||||
}
|
||||
}
|
||||
|
||||
diesel::table! {
|
||||
hadron_timeline_safekeepers (timeline_id, sk_node_id) {
|
||||
timeline_id -> Varchar,
|
||||
sk_node_id -> Int8,
|
||||
legacy_endpoint_id -> Nullable<Uuid>,
|
||||
}
|
||||
}
|
||||
|
||||
diesel::table! {
|
||||
metadata_health (tenant_id, shard_number, shard_count) {
|
||||
tenant_id -> Varchar,
|
||||
@@ -105,6 +123,8 @@ diesel::table! {
|
||||
|
||||
diesel::allow_tables_to_appear_in_same_query!(
|
||||
controllers,
|
||||
hadron_safekeepers,
|
||||
hadron_timeline_safekeepers,
|
||||
metadata_health,
|
||||
nodes,
|
||||
safekeeper_timeline_pending_ops,
|
||||
|
||||
@@ -207,34 +207,13 @@ enum ShardGenerationValidity {
|
||||
},
|
||||
}
|
||||
|
||||
/// We collect the state of attachments for some operations to determine if the operation
|
||||
/// needs to be retried when it fails.
|
||||
struct TenantShardAttachState {
|
||||
/// The targets of the operation.
|
||||
///
|
||||
/// Tenant shard ID, node ID, node, is intent node observed primary.
|
||||
targets: Vec<(TenantShardId, NodeId, Node, bool)>,
|
||||
|
||||
/// The targets grouped by node ID.
|
||||
by_node_id: HashMap<NodeId, (TenantShardId, Node, bool)>,
|
||||
}
|
||||
|
||||
impl TenantShardAttachState {
|
||||
fn for_api_call(&self) -> Vec<(TenantShardId, Node)> {
|
||||
self.targets
|
||||
.iter()
|
||||
.map(|(tenant_shard_id, _, node, _)| (*tenant_shard_id, node.clone()))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
pub const RECONCILER_CONCURRENCY_DEFAULT: usize = 128;
|
||||
pub const PRIORITY_RECONCILER_CONCURRENCY_DEFAULT: usize = 256;
|
||||
pub const SAFEKEEPER_RECONCILER_CONCURRENCY_DEFAULT: usize = 32;
|
||||
|
||||
// Number of consecutive reconciliation errors, occured for one shard,
|
||||
// Number of consecutive reconciliations that have occurred for one shard,
|
||||
// after which the shard is ignored when considering to run optimizations.
|
||||
const MAX_CONSECUTIVE_RECONCILIATION_ERRORS: usize = 5;
|
||||
const MAX_CONSECUTIVE_RECONCILES: usize = 10;
|
||||
|
||||
// Depth of the channel used to enqueue shards for reconciliation when they can't do it immediately.
|
||||
// This channel is finite-size to avoid using excessive memory if we get into a state where reconciles are finishing more slowly
|
||||
@@ -719,47 +698,70 @@ pub(crate) enum ReconcileResultRequest {
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MutationLocation {
|
||||
node: Node,
|
||||
generation: Generation,
|
||||
pub(crate) struct MutationLocation {
|
||||
pub(crate) node: Node,
|
||||
pub(crate) generation: Generation,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ShardMutationLocations {
|
||||
latest: MutationLocation,
|
||||
other: Vec<MutationLocation>,
|
||||
pub(crate) struct ShardMutationLocations {
|
||||
pub(crate) latest: MutationLocation,
|
||||
pub(crate) other: Vec<MutationLocation>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
struct TenantMutationLocations(BTreeMap<TenantShardId, ShardMutationLocations>);
|
||||
pub(crate) struct TenantMutationLocations(pub BTreeMap<TenantShardId, ShardMutationLocations>);
|
||||
|
||||
struct ReconcileAllResult {
|
||||
spawned_reconciles: usize,
|
||||
keep_failing_reconciles: usize,
|
||||
stuck_reconciles: usize,
|
||||
has_delayed_reconciles: bool,
|
||||
}
|
||||
|
||||
impl ReconcileAllResult {
|
||||
fn new(
|
||||
spawned_reconciles: usize,
|
||||
keep_failing_reconciles: usize,
|
||||
stuck_reconciles: usize,
|
||||
has_delayed_reconciles: bool,
|
||||
) -> Self {
|
||||
assert!(
|
||||
spawned_reconciles >= keep_failing_reconciles,
|
||||
"It is impossible to have more keep-failing reconciles than spawned reconciles"
|
||||
spawned_reconciles >= stuck_reconciles,
|
||||
"It is impossible to have less spawned reconciles than stuck reconciles"
|
||||
);
|
||||
Self {
|
||||
spawned_reconciles,
|
||||
keep_failing_reconciles,
|
||||
stuck_reconciles,
|
||||
has_delayed_reconciles,
|
||||
}
|
||||
}
|
||||
|
||||
/// We can run optimizations only if we don't have any delayed reconciles and
|
||||
/// all spawned reconciles are also keep-failing reconciles.
|
||||
/// all spawned reconciles are also stuck reconciles.
|
||||
fn can_run_optimizations(&self) -> bool {
|
||||
!self.has_delayed_reconciles && self.spawned_reconciles == self.keep_failing_reconciles
|
||||
!self.has_delayed_reconciles && self.spawned_reconciles == self.stuck_reconciles
|
||||
}
|
||||
}
|
||||
|
||||
enum TenantIdOrShardId {
|
||||
TenantId(TenantId),
|
||||
TenantShardId(TenantShardId),
|
||||
}
|
||||
|
||||
impl TenantIdOrShardId {
|
||||
fn tenant_id(&self) -> TenantId {
|
||||
match self {
|
||||
TenantIdOrShardId::TenantId(tenant_id) => *tenant_id,
|
||||
TenantIdOrShardId::TenantShardId(tenant_shard_id) => tenant_shard_id.tenant_id,
|
||||
}
|
||||
}
|
||||
|
||||
fn matches(&self, tenant_shard_id: &TenantShardId) -> bool {
|
||||
match self {
|
||||
TenantIdOrShardId::TenantId(tenant_id) => tenant_shard_id.tenant_id == *tenant_id,
|
||||
TenantIdOrShardId::TenantShardId(this_tenant_shard_id) => {
|
||||
this_tenant_shard_id == tenant_shard_id
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1503,7 +1505,6 @@ impl Service {
|
||||
|
||||
match result.result {
|
||||
Ok(()) => {
|
||||
tenant.consecutive_errors_count = 0;
|
||||
tenant.apply_observed_deltas(deltas);
|
||||
tenant.waiter.advance(result.sequence);
|
||||
}
|
||||
@@ -1522,8 +1523,6 @@ impl Service {
|
||||
}
|
||||
}
|
||||
|
||||
tenant.consecutive_errors_count = tenant.consecutive_errors_count.saturating_add(1);
|
||||
|
||||
// Ordering: populate last_error before advancing error_seq,
|
||||
// so that waiters will see the correct error after waiting.
|
||||
tenant.set_last_error(result.sequence, e);
|
||||
@@ -1535,6 +1534,8 @@ impl Service {
|
||||
}
|
||||
}
|
||||
|
||||
tenant.consecutive_reconciles_count = tenant.consecutive_reconciles_count.saturating_add(1);
|
||||
|
||||
// If we just finished detaching all shards for a tenant, it might be time to drop it from memory.
|
||||
if tenant.policy == PlacementPolicy::Detached {
|
||||
// We may only drop a tenant from memory while holding the exclusive lock on the tenant ID: this protects us
|
||||
@@ -4773,72 +4774,24 @@ impl Service {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_observed_consistent_with_intent(
|
||||
&self,
|
||||
shard: &TenantShard,
|
||||
intent_node_id: NodeId,
|
||||
) -> bool {
|
||||
if let Some(location) = shard.observed.locations.get(&intent_node_id)
|
||||
&& let Some(ref conf) = location.conf
|
||||
&& (conf.mode == LocationConfigMode::AttachedSingle
|
||||
|| conf.mode == LocationConfigMode::AttachedMulti)
|
||||
{
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_tenant_shards(
|
||||
&self,
|
||||
tenant_id: TenantId,
|
||||
) -> Result<TenantShardAttachState, ApiError> {
|
||||
let locked = self.inner.read().unwrap();
|
||||
let mut targets = Vec::new();
|
||||
let mut by_node_id = HashMap::new();
|
||||
|
||||
// If the request got an unsharded tenant id, then apply
|
||||
// the operation to all shards. Otherwise, apply it to a specific shard.
|
||||
let shards_range = TenantShardId::tenant_range(tenant_id);
|
||||
|
||||
for (tenant_shard_id, shard) in locked.tenants.range(shards_range) {
|
||||
if let Some(node_id) = shard.intent.get_attached() {
|
||||
let node = locked
|
||||
.nodes
|
||||
.get(node_id)
|
||||
.expect("Pageservers may not be deleted while referenced");
|
||||
|
||||
let consistent = self.is_observed_consistent_with_intent(shard, *node_id);
|
||||
|
||||
targets.push((*tenant_shard_id, *node_id, node.clone(), consistent));
|
||||
by_node_id.insert(*node_id, (*tenant_shard_id, node.clone(), consistent));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(TenantShardAttachState {
|
||||
targets,
|
||||
by_node_id,
|
||||
})
|
||||
pub(crate) fn is_tenant_not_found_error(body: &str, tenant_id: TenantId) -> bool {
|
||||
body.contains(&format!("tenant {tenant_id}"))
|
||||
}
|
||||
|
||||
fn process_result_and_passthrough_errors<T>(
|
||||
&self,
|
||||
tenant_id: TenantId,
|
||||
results: Vec<(Node, Result<T, mgmt_api::Error>)>,
|
||||
attach_state: TenantShardAttachState,
|
||||
) -> Result<Vec<(Node, T)>, ApiError> {
|
||||
let mut processed_results: Vec<(Node, T)> = Vec::with_capacity(results.len());
|
||||
debug_assert_eq!(results.len(), attach_state.targets.len());
|
||||
for (node, res) in results {
|
||||
let is_consistent = attach_state
|
||||
.by_node_id
|
||||
.get(&node.get_id())
|
||||
.map(|(_, _, consistent)| *consistent);
|
||||
match res {
|
||||
Ok(res) => processed_results.push((node, res)),
|
||||
Err(mgmt_api::Error::ApiError(StatusCode::NOT_FOUND, _))
|
||||
if is_consistent == Some(false) =>
|
||||
Err(mgmt_api::Error::ApiError(StatusCode::NOT_FOUND, body))
|
||||
if Self::is_tenant_not_found_error(&body, tenant_id) =>
|
||||
{
|
||||
// This is expected if the attach is not finished yet. Return 503 so that the client can retry.
|
||||
// If there's a tenant not found, we are still in the process of attaching the tenant.
|
||||
// Return 503 so that the client can retry.
|
||||
return Err(ApiError::ResourceUnavailable(
|
||||
format!(
|
||||
"Timeline is not attached to the pageserver {} yet, please retry",
|
||||
@@ -4866,35 +4819,48 @@ impl Service {
|
||||
)
|
||||
.await;
|
||||
|
||||
let attach_state = self.collect_tenant_shards(tenant_id)?;
|
||||
|
||||
let results = self
|
||||
.tenant_for_shards_api(
|
||||
attach_state.for_api_call(),
|
||||
|tenant_shard_id, client| async move {
|
||||
client
|
||||
.timeline_lease_lsn(tenant_shard_id, timeline_id, lsn)
|
||||
.await
|
||||
},
|
||||
1,
|
||||
1,
|
||||
SHORT_RECONCILE_TIMEOUT,
|
||||
&self.cancel,
|
||||
)
|
||||
.await;
|
||||
|
||||
let leases = self.process_result_and_passthrough_errors(results, attach_state)?;
|
||||
let mut valid_until = None;
|
||||
for (_, lease) in leases {
|
||||
if let Some(ref mut valid_until) = valid_until {
|
||||
*valid_until = std::cmp::min(*valid_until, lease.valid_until);
|
||||
} else {
|
||||
valid_until = Some(lease.valid_until);
|
||||
self.tenant_remote_mutation(tenant_id, |locations| async move {
|
||||
if locations.0.is_empty() {
|
||||
return Err(ApiError::NotFound(
|
||||
anyhow::anyhow!("Tenant not found").into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
Ok(LsnLease {
|
||||
valid_until: valid_until.unwrap_or_else(SystemTime::now),
|
||||
|
||||
let results = self
|
||||
.tenant_for_shards_api(
|
||||
locations
|
||||
.0
|
||||
.iter()
|
||||
.map(|(tenant_shard_id, ShardMutationLocations { latest, .. })| {
|
||||
(*tenant_shard_id, latest.node.clone())
|
||||
})
|
||||
.collect(),
|
||||
|tenant_shard_id, client| async move {
|
||||
client
|
||||
.timeline_lease_lsn(tenant_shard_id, timeline_id, lsn)
|
||||
.await
|
||||
},
|
||||
1,
|
||||
1,
|
||||
SHORT_RECONCILE_TIMEOUT,
|
||||
&self.cancel,
|
||||
)
|
||||
.await;
|
||||
|
||||
let leases = self.process_result_and_passthrough_errors(tenant_id, results)?;
|
||||
let mut valid_until = None;
|
||||
for (_, lease) in leases {
|
||||
if let Some(ref mut valid_until) = valid_until {
|
||||
*valid_until = std::cmp::min(*valid_until, lease.valid_until);
|
||||
} else {
|
||||
valid_until = Some(lease.valid_until);
|
||||
}
|
||||
}
|
||||
Ok(LsnLease {
|
||||
valid_until: valid_until.unwrap_or_else(SystemTime::now),
|
||||
})
|
||||
})
|
||||
.await?
|
||||
}
|
||||
|
||||
pub(crate) async fn tenant_timeline_download_heatmap_layers(
|
||||
@@ -5041,11 +5007,37 @@ impl Service {
|
||||
/// - Looks up the shards and the nodes where they were most recently attached
|
||||
/// - Guarantees that after the inner function returns, the shards' generations haven't moved on: this
|
||||
/// ensures that the remote operation acted on the most recent generation, and is therefore durable.
|
||||
async fn tenant_remote_mutation<R, O, F>(
|
||||
pub(crate) async fn tenant_remote_mutation<R, O, F>(
|
||||
&self,
|
||||
tenant_id: TenantId,
|
||||
op: O,
|
||||
) -> Result<R, ApiError>
|
||||
where
|
||||
O: FnOnce(TenantMutationLocations) -> F,
|
||||
F: std::future::Future<Output = R>,
|
||||
{
|
||||
self.tenant_remote_mutation_inner(TenantIdOrShardId::TenantId(tenant_id), op)
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn tenant_shard_remote_mutation<R, O, F>(
|
||||
&self,
|
||||
tenant_shard_id: TenantShardId,
|
||||
op: O,
|
||||
) -> Result<R, ApiError>
|
||||
where
|
||||
O: FnOnce(TenantMutationLocations) -> F,
|
||||
F: std::future::Future<Output = R>,
|
||||
{
|
||||
self.tenant_remote_mutation_inner(TenantIdOrShardId::TenantShardId(tenant_shard_id), op)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn tenant_remote_mutation_inner<R, O, F>(
|
||||
&self,
|
||||
tenant_id_or_shard_id: TenantIdOrShardId,
|
||||
op: O,
|
||||
) -> Result<R, ApiError>
|
||||
where
|
||||
O: FnOnce(TenantMutationLocations) -> F,
|
||||
F: std::future::Future<Output = R>,
|
||||
@@ -5057,7 +5049,13 @@ impl Service {
|
||||
// run concurrently with reconciliations, and it is not guaranteed that the node we find here
|
||||
// will still be the latest when we're done: we will check generations again at the end of
|
||||
// this function to handle that.
|
||||
let generations = self.persistence.tenant_generations(tenant_id).await?;
|
||||
let generations = self
|
||||
.persistence
|
||||
.tenant_generations(tenant_id_or_shard_id.tenant_id())
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter(|i| tenant_id_or_shard_id.matches(&i.tenant_shard_id))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if generations
|
||||
.iter()
|
||||
@@ -5071,9 +5069,14 @@ impl Service {
|
||||
// One or more shards has not been attached to a pageserver. Check if this is because it's configured
|
||||
// to be detached (409: caller should give up), or because it's meant to be attached but isn't yet (503: caller should retry)
|
||||
let locked = self.inner.read().unwrap();
|
||||
for (shard_id, shard) in
|
||||
locked.tenants.range(TenantShardId::tenant_range(tenant_id))
|
||||
{
|
||||
let tenant_shards = locked
|
||||
.tenants
|
||||
.range(TenantShardId::tenant_range(
|
||||
tenant_id_or_shard_id.tenant_id(),
|
||||
))
|
||||
.filter(|(shard_id, _)| tenant_id_or_shard_id.matches(shard_id))
|
||||
.collect::<Vec<_>>();
|
||||
for (shard_id, shard) in tenant_shards {
|
||||
match shard.policy {
|
||||
PlacementPolicy::Attached(_) => {
|
||||
// This shard is meant to be attached: the caller is not wrong to try and
|
||||
@@ -5183,7 +5186,14 @@ impl Service {
|
||||
// Post-check: are all the generations of all the shards the same as they were initially? This proves that
|
||||
// our remote operation executed on the latest generation and is therefore persistent.
|
||||
{
|
||||
let latest_generations = self.persistence.tenant_generations(tenant_id).await?;
|
||||
let latest_generations = self
|
||||
.persistence
|
||||
.tenant_generations(tenant_id_or_shard_id.tenant_id())
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter(|i| tenant_id_or_shard_id.matches(&i.tenant_shard_id))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if latest_generations
|
||||
.into_iter()
|
||||
.map(
|
||||
@@ -5317,7 +5327,7 @@ impl Service {
|
||||
pub(crate) async fn tenant_shard0_node(
|
||||
&self,
|
||||
tenant_id: TenantId,
|
||||
) -> Result<(Node, TenantShardId, bool), ApiError> {
|
||||
) -> Result<(Node, TenantShardId), ApiError> {
|
||||
let tenant_shard_id = {
|
||||
let locked = self.inner.read().unwrap();
|
||||
let Some((tenant_shard_id, _shard)) = locked
|
||||
@@ -5335,7 +5345,7 @@ impl Service {
|
||||
|
||||
self.tenant_shard_node(tenant_shard_id)
|
||||
.await
|
||||
.map(|(node, consistent)| (node, tenant_shard_id, consistent))
|
||||
.map(|node| (node, tenant_shard_id))
|
||||
}
|
||||
|
||||
/// When you need to send an HTTP request to the pageserver that holds a shard of a tenant, this
|
||||
@@ -5345,7 +5355,7 @@ impl Service {
|
||||
pub(crate) async fn tenant_shard_node(
|
||||
&self,
|
||||
tenant_shard_id: TenantShardId,
|
||||
) -> Result<(Node, bool), ApiError> {
|
||||
) -> Result<Node, ApiError> {
|
||||
// Look up in-memory state and maybe use the node from there.
|
||||
{
|
||||
let locked = self.inner.read().unwrap();
|
||||
@@ -5375,8 +5385,7 @@ impl Service {
|
||||
"Shard refers to nonexistent node"
|
||||
)));
|
||||
};
|
||||
let consistent = self.is_observed_consistent_with_intent(shard, *intent_node_id);
|
||||
return Ok((node.clone(), consistent));
|
||||
return Ok(node.clone());
|
||||
}
|
||||
};
|
||||
|
||||
@@ -5411,7 +5420,7 @@ impl Service {
|
||||
)));
|
||||
};
|
||||
// As a reconciliation is in flight, we do not have the observed state yet, and therefore we assume it is always inconsistent.
|
||||
Ok((node.clone(), false))
|
||||
Ok(node.clone())
|
||||
}
|
||||
|
||||
pub(crate) fn tenant_locate(
|
||||
@@ -7385,6 +7394,7 @@ impl Service {
|
||||
self: &Arc<Self>,
|
||||
node_id: NodeId,
|
||||
policy_on_start: NodeSchedulingPolicy,
|
||||
force: bool,
|
||||
cancel: CancellationToken,
|
||||
) -> Result<(), OperationError> {
|
||||
let reconciler_config = ReconcilerConfigBuilder::new(ReconcilerPriority::Normal).build();
|
||||
@@ -7392,23 +7402,27 @@ impl Service {
|
||||
let mut waiters: Vec<ReconcilerWaiter> = Vec::new();
|
||||
let mut tid_iter = create_shared_shard_iterator(self.clone());
|
||||
|
||||
let reset_node_policy_on_cancel = || async {
|
||||
match self
|
||||
.node_configure(node_id, None, Some(policy_on_start))
|
||||
.await
|
||||
{
|
||||
Ok(()) => OperationError::Cancelled,
|
||||
Err(err) => {
|
||||
OperationError::FinalizeError(
|
||||
format!(
|
||||
"Failed to finalise delete cancel of {} by setting scheduling policy to {}: {}",
|
||||
node_id, String::from(policy_on_start), err
|
||||
)
|
||||
.into(),
|
||||
)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
while !tid_iter.finished() {
|
||||
if cancel.is_cancelled() {
|
||||
match self
|
||||
.node_configure(node_id, None, Some(policy_on_start))
|
||||
.await
|
||||
{
|
||||
Ok(()) => return Err(OperationError::Cancelled),
|
||||
Err(err) => {
|
||||
return Err(OperationError::FinalizeError(
|
||||
format!(
|
||||
"Failed to finalise delete cancel of {} by setting scheduling policy to {}: {}",
|
||||
node_id, String::from(policy_on_start), err
|
||||
)
|
||||
.into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
return Err(reset_node_policy_on_cancel().await);
|
||||
}
|
||||
|
||||
operation_utils::validate_node_state(
|
||||
@@ -7477,8 +7491,18 @@ impl Service {
|
||||
nodes,
|
||||
reconciler_config,
|
||||
);
|
||||
if let Some(some) = waiter {
|
||||
waiters.push(some);
|
||||
|
||||
if force {
|
||||
// Here we remove an existing observed location for the node we're removing, and it will
|
||||
// not be re-added by a reconciler's completion because we filter out removed nodes in
|
||||
// process_result.
|
||||
//
|
||||
// Note that we update the shard's observed state _after_ calling maybe_configured_reconcile_shard:
|
||||
// that means any reconciles we spawned will know about the node we're deleting,
|
||||
// enabling them to do live migrations if it's still online.
|
||||
tenant_shard.observed.locations.remove(&node_id);
|
||||
} else if let Some(waiter) = waiter {
|
||||
waiters.push(waiter);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -7492,21 +7516,7 @@ impl Service {
|
||||
|
||||
while !waiters.is_empty() {
|
||||
if cancel.is_cancelled() {
|
||||
match self
|
||||
.node_configure(node_id, None, Some(policy_on_start))
|
||||
.await
|
||||
{
|
||||
Ok(()) => return Err(OperationError::Cancelled),
|
||||
Err(err) => {
|
||||
return Err(OperationError::FinalizeError(
|
||||
format!(
|
||||
"Failed to finalise drain cancel of {} by setting scheduling policy to {}: {}",
|
||||
node_id, String::from(policy_on_start), err
|
||||
)
|
||||
.into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
return Err(reset_node_policy_on_cancel().await);
|
||||
}
|
||||
|
||||
tracing::info!("Awaiting {} pending delete reconciliations", waiters.len());
|
||||
@@ -7516,6 +7526,12 @@ impl Service {
|
||||
.await;
|
||||
}
|
||||
|
||||
let pf = pausable_failpoint!("delete-node-after-reconciles-spawned", &cancel);
|
||||
if pf.is_err() {
|
||||
// An error from pausable_failpoint indicates the cancel token was triggered.
|
||||
return Err(reset_node_policy_on_cancel().await);
|
||||
}
|
||||
|
||||
self.persistence
|
||||
.set_tombstone(node_id)
|
||||
.await
|
||||
@@ -8111,6 +8127,7 @@ impl Service {
|
||||
pub(crate) async fn start_node_delete(
|
||||
self: &Arc<Self>,
|
||||
node_id: NodeId,
|
||||
force: bool,
|
||||
) -> Result<(), ApiError> {
|
||||
let (ongoing_op, node_policy, schedulable_nodes_count) = {
|
||||
let locked = self.inner.read().unwrap();
|
||||
@@ -8180,7 +8197,7 @@ impl Service {
|
||||
|
||||
tracing::info!("Delete background operation starting");
|
||||
let res = service
|
||||
.delete_node(node_id, policy_on_start, cancel)
|
||||
.delete_node(node_id, policy_on_start, force, cancel)
|
||||
.await;
|
||||
match res {
|
||||
Ok(()) => {
|
||||
@@ -8632,7 +8649,7 @@ impl Service {
|
||||
// This function is an efficient place to update lazy statistics, since we are walking
|
||||
// all tenants.
|
||||
let mut pending_reconciles = 0;
|
||||
let mut keep_failing_reconciles = 0;
|
||||
let mut stuck_reconciles = 0;
|
||||
let mut az_violations = 0;
|
||||
|
||||
// If we find any tenants to drop from memory, stash them to offload after
|
||||
@@ -8668,30 +8685,32 @@ impl Service {
|
||||
|
||||
// Eventual consistency: if an earlier reconcile job failed, and the shard is still
|
||||
// dirty, spawn another one
|
||||
let consecutive_errors_count = shard.consecutive_errors_count;
|
||||
if self
|
||||
.maybe_reconcile_shard(shard, &pageservers, ReconcilerPriority::Normal)
|
||||
.is_some()
|
||||
{
|
||||
spawned_reconciles += 1;
|
||||
|
||||
// Count shards that are keep-failing. We still want to reconcile them
|
||||
// to avoid a situation where a shard is stuck.
|
||||
// But we don't want to consider them when deciding to run optimizations.
|
||||
if consecutive_errors_count >= MAX_CONSECUTIVE_RECONCILIATION_ERRORS {
|
||||
if shard.consecutive_reconciles_count >= MAX_CONSECUTIVE_RECONCILES {
|
||||
// Count shards that are stuck, but we still want to reconcile them.
|
||||
// We don't want to consider them when deciding to run optimizations.
|
||||
tracing::warn!(
|
||||
tenant_id=%shard.tenant_shard_id.tenant_id,
|
||||
shard_id=%shard.tenant_shard_id.shard_slug(),
|
||||
"Shard reconciliation is keep-failing: {} errors",
|
||||
consecutive_errors_count
|
||||
"Shard reconciliation is stuck: {} consecutive launches",
|
||||
shard.consecutive_reconciles_count
|
||||
);
|
||||
keep_failing_reconciles += 1;
|
||||
stuck_reconciles += 1;
|
||||
}
|
||||
} else {
|
||||
if shard.delayed_reconcile {
|
||||
// Shard wanted to reconcile but for some reason couldn't.
|
||||
pending_reconciles += 1;
|
||||
}
|
||||
} else if shard.delayed_reconcile {
|
||||
// Shard wanted to reconcile but for some reason couldn't.
|
||||
pending_reconciles += 1;
|
||||
}
|
||||
|
||||
// Reset the counter when we don't need to launch a reconcile.
|
||||
shard.consecutive_reconciles_count = 0;
|
||||
}
|
||||
// If this tenant is detached, try dropping it from memory. This is usually done
|
||||
// proactively in [`Self::process_results`], but we do it here to handle the edge
|
||||
// case where a reconcile completes while someone else is holding an op lock for the tenant.
|
||||
@@ -8727,14 +8746,10 @@ impl Service {
|
||||
|
||||
metrics::METRICS_REGISTRY
|
||||
.metrics_group
|
||||
.storage_controller_keep_failing_reconciles
|
||||
.set(keep_failing_reconciles as i64);
|
||||
.storage_controller_stuck_reconciles
|
||||
.set(stuck_reconciles as i64);
|
||||
|
||||
ReconcileAllResult::new(
|
||||
spawned_reconciles,
|
||||
keep_failing_reconciles,
|
||||
has_delayed_reconciles,
|
||||
)
|
||||
ReconcileAllResult::new(spawned_reconciles, stuck_reconciles, has_delayed_reconciles)
|
||||
}
|
||||
|
||||
/// `optimize` in this context means identifying shards which have valid scheduled locations, but
|
||||
|
||||
@@ -131,14 +131,16 @@ pub(crate) struct TenantShard {
|
||||
#[serde(serialize_with = "read_last_error")]
|
||||
pub(crate) last_error: std::sync::Arc<std::sync::Mutex<Option<Arc<ReconcileError>>>>,
|
||||
|
||||
/// Number of consecutive reconciliation errors that have occurred for this shard.
|
||||
/// Number of consecutive [`crate::service::Service::reconcile_all`] iterations that have
|
||||
/// scheduled a reconciliation for this shard.
|
||||
///
|
||||
/// When this count reaches MAX_CONSECUTIVE_RECONCILIATION_ERRORS, the tenant shard
|
||||
/// will be counted as keep-failing in `reconcile_all` calculations. This will lead to
|
||||
/// allowing optimizations to run even with some failing shards.
|
||||
/// If this reaches `MAX_CONSECUTIVE_RECONCILES`, the shard is considered "stuck" and will be
|
||||
/// ignored when deciding whether optimizations can run. This includes both successful and failed
|
||||
/// reconciliations.
|
||||
///
|
||||
/// The counter is reset to 0 after a successful reconciliation.
|
||||
pub(crate) consecutive_errors_count: usize,
|
||||
/// Incremented in [`crate::service::Service::process_result`], and reset to 0 when
|
||||
/// [`crate::service::Service::reconcile_all`] determines no reconciliation is needed for this shard.
|
||||
pub(crate) consecutive_reconciles_count: usize,
|
||||
|
||||
/// If we have a pending compute notification that for some reason we weren't able to send,
|
||||
/// set this to true. If this is set, calls to [`Self::get_reconcile_needed`] will return Yes
|
||||
@@ -603,7 +605,7 @@ impl TenantShard {
|
||||
waiter: Arc::new(SeqWait::new(Sequence(0))),
|
||||
error_waiter: Arc::new(SeqWait::new(Sequence(0))),
|
||||
last_error: Arc::default(),
|
||||
consecutive_errors_count: 0,
|
||||
consecutive_reconciles_count: 0,
|
||||
pending_compute_notification: false,
|
||||
scheduling_policy: ShardSchedulingPolicy::default(),
|
||||
preferred_node: None,
|
||||
@@ -1609,7 +1611,13 @@ impl TenantShard {
|
||||
|
||||
// Update result counter
|
||||
let outcome_label = match &result {
|
||||
Ok(_) => ReconcileOutcome::Success,
|
||||
Ok(_) => {
|
||||
if reconciler.compute_notify_failure {
|
||||
ReconcileOutcome::SuccessNoNotify
|
||||
} else {
|
||||
ReconcileOutcome::Success
|
||||
}
|
||||
}
|
||||
Err(ReconcileError::Cancel) => ReconcileOutcome::Cancel,
|
||||
Err(_) => ReconcileOutcome::Error,
|
||||
};
|
||||
@@ -1908,7 +1916,7 @@ impl TenantShard {
|
||||
waiter: Arc::new(SeqWait::new(Sequence::initial())),
|
||||
error_waiter: Arc::new(SeqWait::new(Sequence::initial())),
|
||||
last_error: Arc::default(),
|
||||
consecutive_errors_count: 0,
|
||||
consecutive_reconciles_count: 0,
|
||||
pending_compute_notification: false,
|
||||
delayed_reconcile: false,
|
||||
scheduling_policy: serde_json::from_str(&tsp.scheduling_policy).unwrap(),
|
||||
|
||||
@@ -2119,11 +2119,14 @@ class NeonStorageController(MetricsGetter, LogUtils):
|
||||
headers=self.headers(TokenScope.ADMIN),
|
||||
)
|
||||
|
||||
def node_delete(self, node_id):
|
||||
def node_delete(self, node_id, force: bool = False):
|
||||
log.info(f"node_delete({node_id})")
|
||||
query = f"{self.api}/control/v1/node/{node_id}/delete"
|
||||
if force:
|
||||
query += "?force=true"
|
||||
self.request(
|
||||
"PUT",
|
||||
f"{self.api}/control/v1/node/{node_id}/delete",
|
||||
query,
|
||||
headers=self.headers(TokenScope.ADMIN),
|
||||
)
|
||||
|
||||
|
||||
@@ -847,7 +847,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
|
||||
return res_json
|
||||
|
||||
def timeline_lsn_lease(
|
||||
self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, lsn: Lsn
|
||||
self, tenant_id: TenantId | TenantShardId, timeline_id: TimelineId, lsn: Lsn, **kwargs
|
||||
):
|
||||
data = {
|
||||
"lsn": str(lsn),
|
||||
@@ -857,6 +857,7 @@ class PageserverHttpClient(requests.Session, MetricsGetter):
|
||||
res = self.post(
|
||||
f"http://localhost:{self.port}/v1/tenant/{tenant_id}/timeline/{timeline_id}/lsn_lease",
|
||||
json=data,
|
||||
**kwargs,
|
||||
)
|
||||
self.verbose_error(res)
|
||||
res_json = res.json()
|
||||
|
||||
@@ -187,19 +187,21 @@ def test_create_snapshot(
|
||||
env.pageserver.stop()
|
||||
env.storage_controller.stop()
|
||||
|
||||
# Directory `compatibility_snapshot_dir` is uploaded to S3 in a workflow, keep the name in sync with it
|
||||
compatibility_snapshot_dir = (
|
||||
# Directory `new_compatibility_snapshot_dir` is uploaded to S3 in a workflow, keep the name in sync with it
|
||||
new_compatibility_snapshot_dir = (
|
||||
top_output_dir / f"compatibility_snapshot_pg{pg_version.v_prefixed}"
|
||||
)
|
||||
if compatibility_snapshot_dir.exists():
|
||||
shutil.rmtree(compatibility_snapshot_dir)
|
||||
if new_compatibility_snapshot_dir.exists():
|
||||
shutil.rmtree(new_compatibility_snapshot_dir)
|
||||
|
||||
shutil.copytree(
|
||||
test_output_dir,
|
||||
compatibility_snapshot_dir,
|
||||
new_compatibility_snapshot_dir,
|
||||
ignore=shutil.ignore_patterns("pg_dynshmem"),
|
||||
)
|
||||
|
||||
log.info(f"Copied new compatibility snapshot dir to: {new_compatibility_snapshot_dir}")
|
||||
|
||||
|
||||
# check_neon_works does recovery from WAL => the compatibility snapshot's WAL is old => will log this warning
|
||||
ingest_lag_log_line = ".*ingesting record with timestamp lagging more than wait_lsn_timeout.*"
|
||||
@@ -218,6 +220,7 @@ def test_backward_compatibility(
|
||||
"""
|
||||
Test that the new binaries can read old data
|
||||
"""
|
||||
log.info(f"Using snapshot dir at {compatibility_snapshot_dir}")
|
||||
neon_env_builder.num_safekeepers = 3
|
||||
env = neon_env_builder.from_repo_dir(compatibility_snapshot_dir / "repo")
|
||||
env.pageserver.allowed_errors.append(ingest_lag_log_line)
|
||||
@@ -242,7 +245,6 @@ def test_forward_compatibility(
|
||||
test_output_dir: Path,
|
||||
top_output_dir: Path,
|
||||
pg_version: PgVersion,
|
||||
compatibility_snapshot_dir: Path,
|
||||
compute_reconfigure_listener: ComputeReconfigure,
|
||||
):
|
||||
"""
|
||||
@@ -266,8 +268,14 @@ def test_forward_compatibility(
|
||||
neon_env_builder.neon_binpath = neon_env_builder.compatibility_neon_binpath
|
||||
neon_env_builder.pg_distrib_dir = neon_env_builder.compatibility_pg_distrib_dir
|
||||
|
||||
# Note that we are testing with new data, so we should use `new_compatibility_snapshot_dir`, which is created by test_create_snapshot.
|
||||
new_compatibility_snapshot_dir = (
|
||||
top_output_dir / f"compatibility_snapshot_pg{pg_version.v_prefixed}"
|
||||
)
|
||||
|
||||
log.info(f"Using snapshot dir at {new_compatibility_snapshot_dir}")
|
||||
env = neon_env_builder.from_repo_dir(
|
||||
compatibility_snapshot_dir / "repo",
|
||||
new_compatibility_snapshot_dir / "repo",
|
||||
)
|
||||
# there may be an arbitrary number of unrelated tests run between create_snapshot and here
|
||||
env.pageserver.allowed_errors.append(ingest_lag_log_line)
|
||||
@@ -296,7 +304,7 @@ def test_forward_compatibility(
|
||||
check_neon_works(
|
||||
env,
|
||||
test_output_dir=test_output_dir,
|
||||
sql_dump_path=compatibility_snapshot_dir / "dump.sql",
|
||||
sql_dump_path=new_compatibility_snapshot_dir / "dump.sql",
|
||||
repo_dir=env.repo_dir,
|
||||
)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from typing import TYPE_CHECKING
|
||||
import fixtures.utils
|
||||
import pytest
|
||||
from fixtures.auth_tokens import TokenScope
|
||||
from fixtures.common_types import TenantId, TenantShardId, TimelineId
|
||||
from fixtures.common_types import Lsn, TenantId, TenantShardId, TimelineId
|
||||
from fixtures.log_helper import log
|
||||
from fixtures.neon_fixtures import (
|
||||
DEFAULT_AZ_ID,
|
||||
@@ -47,6 +47,7 @@ from fixtures.utils import (
|
||||
wait_until,
|
||||
)
|
||||
from fixtures.workload import Workload
|
||||
from requests.adapters import HTTPAdapter
|
||||
from urllib3 import Retry
|
||||
from werkzeug.wrappers.response import Response
|
||||
|
||||
@@ -72,6 +73,12 @@ def get_node_shard_counts(env: NeonEnv, tenant_ids):
|
||||
return counts
|
||||
|
||||
|
||||
class DeletionAPIKind(Enum):
|
||||
OLD = "old"
|
||||
FORCE = "force"
|
||||
GRACEFUL = "graceful"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(**fixtures.utils.allpairs_versions())
|
||||
def test_storage_controller_smoke(
|
||||
neon_env_builder: NeonEnvBuilder, compute_reconfigure_listener: ComputeReconfigure, combination
|
||||
@@ -990,7 +997,7 @@ def test_storage_controller_compute_hook_retry(
|
||||
|
||||
|
||||
@run_only_on_default_postgres("postgres behavior is not relevant")
|
||||
def test_storage_controller_compute_hook_keep_failing(
|
||||
def test_storage_controller_compute_hook_stuck_reconciles(
|
||||
httpserver: HTTPServer,
|
||||
neon_env_builder: NeonEnvBuilder,
|
||||
httpserver_listen_address: ListenAddress,
|
||||
@@ -1040,7 +1047,7 @@ def test_storage_controller_compute_hook_keep_failing(
|
||||
env.storage_controller.allowed_errors.append(NOTIFY_BLOCKED_LOG)
|
||||
env.storage_controller.allowed_errors.extend(NOTIFY_FAILURE_LOGS)
|
||||
env.storage_controller.allowed_errors.append(".*Keeping extra secondaries.*")
|
||||
env.storage_controller.allowed_errors.append(".*Shard reconciliation is keep-failing.*")
|
||||
env.storage_controller.allowed_errors.append(".*Shard reconciliation is stuck.*")
|
||||
env.storage_controller.node_configure(banned_tenant_ps.id, {"availability": "Offline"})
|
||||
|
||||
# Migrate all allowed tenant shards to the first alive pageserver
|
||||
@@ -1055,7 +1062,7 @@ def test_storage_controller_compute_hook_keep_failing(
|
||||
|
||||
# Make some reconcile_all calls to trigger optimizations
|
||||
# RECONCILE_COUNT must be greater than storcon's MAX_CONSECUTIVE_RECONCILES
|
||||
RECONCILE_COUNT = 12
|
||||
RECONCILE_COUNT = 20
|
||||
for i in range(RECONCILE_COUNT):
|
||||
try:
|
||||
n = env.storage_controller.reconcile_all()
|
||||
@@ -1068,6 +1075,8 @@ def test_storage_controller_compute_hook_keep_failing(
|
||||
assert banned_descr["shards"][0]["is_pending_compute_notification"] is True
|
||||
time.sleep(2)
|
||||
|
||||
env.storage_controller.assert_log_contains(".*Shard reconciliation is stuck.*")
|
||||
|
||||
# Check that the allowed tenant shards are optimized due to affinity rules
|
||||
locations = alive_pageservers[0].http_client().tenant_list_locations()["tenant_shards"]
|
||||
not_optimized_shard_count = 0
|
||||
@@ -2572,9 +2581,11 @@ def test_background_operation_cancellation(neon_env_builder: NeonEnvBuilder):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("while_offline", [True, False])
|
||||
@pytest.mark.parametrize("deletion_api", [DeletionAPIKind.OLD, DeletionAPIKind.FORCE])
|
||||
def test_storage_controller_node_deletion(
|
||||
neon_env_builder: NeonEnvBuilder,
|
||||
while_offline: bool,
|
||||
deletion_api: DeletionAPIKind,
|
||||
):
|
||||
"""
|
||||
Test that deleting a node works & properly reschedules everything that was on the node.
|
||||
@@ -2598,6 +2609,8 @@ def test_storage_controller_node_deletion(
|
||||
assert env.storage_controller.reconcile_all() == 0
|
||||
|
||||
victim = env.pageservers[-1]
|
||||
if deletion_api == DeletionAPIKind.FORCE and not while_offline:
|
||||
victim.allowed_errors.append(".*request was dropped before completing.*")
|
||||
|
||||
# The procedure a human would follow is:
|
||||
# 1. Mark pageserver scheduling=pause
|
||||
@@ -2621,7 +2634,12 @@ def test_storage_controller_node_deletion(
|
||||
wait_until(assert_shards_migrated)
|
||||
|
||||
log.info(f"Deleting pageserver {victim.id}")
|
||||
env.storage_controller.node_delete_old(victim.id)
|
||||
if deletion_api == DeletionAPIKind.FORCE:
|
||||
env.storage_controller.node_delete(victim.id, force=True)
|
||||
elif deletion_api == DeletionAPIKind.OLD:
|
||||
env.storage_controller.node_delete_old(victim.id)
|
||||
else:
|
||||
raise AssertionError(f"Invalid deletion API: {deletion_api}")
|
||||
|
||||
if not while_offline:
|
||||
|
||||
@@ -2634,7 +2652,15 @@ def test_storage_controller_node_deletion(
|
||||
wait_until(assert_victim_evacuated)
|
||||
|
||||
# The node should be gone from the list API
|
||||
assert victim.id not in [n["id"] for n in env.storage_controller.node_list()]
|
||||
def assert_node_is_gone():
|
||||
assert victim.id not in [n["id"] for n in env.storage_controller.node_list()]
|
||||
|
||||
if deletion_api == DeletionAPIKind.FORCE:
|
||||
wait_until(assert_node_is_gone)
|
||||
elif deletion_api == DeletionAPIKind.OLD:
|
||||
assert_node_is_gone()
|
||||
else:
|
||||
raise AssertionError(f"Invalid deletion API: {deletion_api}")
|
||||
|
||||
# No tenants should refer to the node in their intent
|
||||
for tenant_id in tenant_ids:
|
||||
@@ -2656,7 +2682,11 @@ def test_storage_controller_node_deletion(
|
||||
env.storage_controller.consistency_check()
|
||||
|
||||
|
||||
def test_storage_controller_node_delete_cancellation(neon_env_builder: NeonEnvBuilder):
|
||||
@pytest.mark.parametrize("deletion_api", [DeletionAPIKind.FORCE, DeletionAPIKind.GRACEFUL])
|
||||
def test_storage_controller_node_delete_cancellation(
|
||||
neon_env_builder: NeonEnvBuilder,
|
||||
deletion_api: DeletionAPIKind,
|
||||
):
|
||||
neon_env_builder.num_pageservers = 3
|
||||
neon_env_builder.num_azs = 3
|
||||
env = neon_env_builder.init_configs()
|
||||
@@ -2680,12 +2710,16 @@ def test_storage_controller_node_delete_cancellation(neon_env_builder: NeonEnvBu
|
||||
assert len(nodes) == 3
|
||||
|
||||
env.storage_controller.configure_failpoints(("sleepy-delete-loop", "return(10000)"))
|
||||
env.storage_controller.configure_failpoints(("delete-node-after-reconciles-spawned", "pause"))
|
||||
|
||||
ps_id_to_delete = env.pageservers[0].id
|
||||
|
||||
env.storage_controller.warm_up_all_secondaries()
|
||||
|
||||
assert deletion_api in [DeletionAPIKind.FORCE, DeletionAPIKind.GRACEFUL]
|
||||
force = deletion_api == DeletionAPIKind.FORCE
|
||||
env.storage_controller.retryable_node_operation(
|
||||
lambda ps_id: env.storage_controller.node_delete(ps_id),
|
||||
lambda ps_id: env.storage_controller.node_delete(ps_id, force),
|
||||
ps_id_to_delete,
|
||||
max_attempts=3,
|
||||
backoff=2,
|
||||
@@ -2701,6 +2735,8 @@ def test_storage_controller_node_delete_cancellation(neon_env_builder: NeonEnvBu
|
||||
|
||||
env.storage_controller.cancel_node_delete(ps_id_to_delete)
|
||||
|
||||
env.storage_controller.configure_failpoints(("delete-node-after-reconciles-spawned", "off"))
|
||||
|
||||
env.storage_controller.poll_node_status(
|
||||
ps_id_to_delete,
|
||||
PageserverAvailability.ACTIVE,
|
||||
@@ -3252,7 +3288,10 @@ def test_storage_controller_ps_restarted_during_drain(neon_env_builder: NeonEnvB
|
||||
wait_until(reconfigure_node_again)
|
||||
|
||||
|
||||
def test_ps_unavailable_after_delete(neon_env_builder: NeonEnvBuilder):
|
||||
@pytest.mark.parametrize("deletion_api", [DeletionAPIKind.OLD, DeletionAPIKind.FORCE])
|
||||
def test_ps_unavailable_after_delete(
|
||||
neon_env_builder: NeonEnvBuilder, deletion_api: DeletionAPIKind
|
||||
):
|
||||
neon_env_builder.num_pageservers = 3
|
||||
|
||||
env = neon_env_builder.init_start()
|
||||
@@ -3265,10 +3304,16 @@ def test_ps_unavailable_after_delete(neon_env_builder: NeonEnvBuilder):
|
||||
assert_nodes_count(3)
|
||||
|
||||
ps = env.pageservers[0]
|
||||
env.storage_controller.node_delete_old(ps.id)
|
||||
|
||||
# After deletion, the node count must be reduced
|
||||
assert_nodes_count(2)
|
||||
if deletion_api == DeletionAPIKind.FORCE:
|
||||
ps.allowed_errors.append(".*request was dropped before completing.*")
|
||||
env.storage_controller.node_delete(ps.id, force=True)
|
||||
wait_until(lambda: assert_nodes_count(2))
|
||||
elif deletion_api == DeletionAPIKind.OLD:
|
||||
env.storage_controller.node_delete_old(ps.id)
|
||||
assert_nodes_count(2)
|
||||
else:
|
||||
raise AssertionError(f"Invalid deletion API: {deletion_api}")
|
||||
|
||||
# Running pageserver CLI init in a separate thread
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||
@@ -4814,3 +4859,103 @@ def test_storage_controller_migrate_with_pageserver_restart(
|
||||
"shards": [{"node_id": int(secondary.id), "shard_number": 0}],
|
||||
"preferred_az": DEFAULT_AZ_ID,
|
||||
}
|
||||
|
||||
|
||||
@run_only_on_default_postgres("PG version is not important for this test")
|
||||
def test_storage_controller_forward_404(neon_env_builder: NeonEnvBuilder):
|
||||
"""
|
||||
Ensures that the storage controller correctly forwards 404s and converts some of them
|
||||
into 503s before forwarding to the client.
|
||||
"""
|
||||
neon_env_builder.num_pageservers = 2
|
||||
neon_env_builder.num_azs = 2
|
||||
|
||||
env = neon_env_builder.init_start()
|
||||
env.storage_controller.allowed_errors.append(".*Reconcile error.*")
|
||||
env.storage_controller.allowed_errors.append(".*Timed out.*")
|
||||
|
||||
env.storage_controller.tenant_policy_update(env.initial_tenant, {"placement": {"Attached": 1}})
|
||||
env.storage_controller.reconcile_until_idle()
|
||||
|
||||
# 404s on tenants and timelines are forwarded as-is when reconciler is not running.
|
||||
|
||||
# Access a non-existing timeline -> 404
|
||||
with pytest.raises(PageserverApiException) as e:
|
||||
env.storage_controller.pageserver_api().timeline_detail(
|
||||
env.initial_tenant, TimelineId.generate()
|
||||
)
|
||||
assert e.value.status_code == 404
|
||||
with pytest.raises(PageserverApiException) as e:
|
||||
env.storage_controller.pageserver_api().timeline_lsn_lease(
|
||||
env.initial_tenant, TimelineId.generate(), Lsn(0)
|
||||
)
|
||||
assert e.value.status_code == 404
|
||||
|
||||
# Access a non-existing tenant when reconciler is not running -> 404
|
||||
with pytest.raises(PageserverApiException) as e:
|
||||
env.storage_controller.pageserver_api().timeline_detail(
|
||||
TenantId.generate(), env.initial_timeline
|
||||
)
|
||||
assert e.value.status_code == 404
|
||||
with pytest.raises(PageserverApiException) as e:
|
||||
env.storage_controller.pageserver_api().timeline_lsn_lease(
|
||||
TenantId.generate(), env.initial_timeline, Lsn(0)
|
||||
)
|
||||
assert e.value.status_code == 404
|
||||
|
||||
# Normal requests should succeed
|
||||
detail = env.storage_controller.pageserver_api().timeline_detail(
|
||||
env.initial_tenant, env.initial_timeline
|
||||
)
|
||||
last_record_lsn = Lsn(detail["last_record_lsn"])
|
||||
env.storage_controller.pageserver_api().timeline_lsn_lease(
|
||||
env.initial_tenant, env.initial_timeline, last_record_lsn
|
||||
)
|
||||
|
||||
# Get into a situation where the intent state is not the same as the observed state.
|
||||
describe = env.storage_controller.tenant_describe(env.initial_tenant)["shards"][0]
|
||||
current_primary = describe["node_attached"]
|
||||
current_secondary = describe["node_secondary"][0]
|
||||
assert current_primary != current_secondary
|
||||
|
||||
# Pause the reconciler so that the generation number won't be updated.
|
||||
env.storage_controller.configure_failpoints(
|
||||
("reconciler-live-migrate-post-generation-inc", "pause")
|
||||
)
|
||||
|
||||
# Do the migration in another thread; the request will be dropped as we don't wait.
|
||||
shard_zero = TenantShardId(env.initial_tenant, 0, 0)
|
||||
concurrent.futures.ThreadPoolExecutor(max_workers=1).submit(
|
||||
env.storage_controller.tenant_shard_migrate,
|
||||
shard_zero,
|
||||
current_secondary,
|
||||
StorageControllerMigrationConfig(override_scheduler=True),
|
||||
)
|
||||
# Not the best way to do this, we should wait until the migration gets started.
|
||||
time.sleep(1)
|
||||
placement = env.storage_controller.get_tenants_placement()[str(shard_zero)]
|
||||
assert placement["observed"] != placement["intent"]
|
||||
assert placement["observed"]["attached"] == current_primary
|
||||
assert placement["intent"]["attached"] == current_secondary
|
||||
|
||||
# Now we issue requests that would cause 404 again
|
||||
retry_strategy = Retry(total=0)
|
||||
adapter = HTTPAdapter(max_retries=retry_strategy)
|
||||
|
||||
no_retry_api = env.storage_controller.pageserver_api()
|
||||
no_retry_api.mount("http://", adapter)
|
||||
no_retry_api.mount("https://", adapter)
|
||||
|
||||
# As intent state != observed state, tenant not found error should return 503,
|
||||
# so that the client can retry once we've successfully migrated.
|
||||
with pytest.raises(PageserverApiException) as e:
|
||||
no_retry_api.timeline_detail(env.initial_tenant, TimelineId.generate())
|
||||
assert e.value.status_code == 503, f"unexpected status code and error: {e.value}"
|
||||
with pytest.raises(PageserverApiException) as e:
|
||||
no_retry_api.timeline_lsn_lease(env.initial_tenant, TimelineId.generate(), Lsn(0))
|
||||
assert e.value.status_code == 503, f"unexpected status code and error: {e.value}"
|
||||
|
||||
# Unblock reconcile operations
|
||||
env.storage_controller.configure_failpoints(
|
||||
("reconciler-live-migrate-post-generation-inc", "off")
|
||||
)
|
||||
|
||||
2
vendor/postgres-v14
vendored
2
vendor/postgres-v14
vendored
Submodule vendor/postgres-v14 updated: ac3c460e01...47304b9215
2
vendor/postgres-v15
vendored
2
vendor/postgres-v15
vendored
Submodule vendor/postgres-v15 updated: 24313bf8f3...cef72d5308
2
vendor/postgres-v16
vendored
2
vendor/postgres-v16
vendored
Submodule vendor/postgres-v16 updated: 51194dc5ce...e9db1ff5a6
2
vendor/postgres-v17
vendored
2
vendor/postgres-v17
vendored
Submodule vendor/postgres-v17 updated: eac5279cd1...a50d80c750
8
vendor/revisions.json
vendored
8
vendor/revisions.json
vendored
@@ -1,18 +1,18 @@
|
||||
{
|
||||
"v17": [
|
||||
"17.5",
|
||||
"eac5279cd147d4086e0eb242198aae2f4b766d7b"
|
||||
"a50d80c7507e8ae9fc37bf1869051cf2d51370ab"
|
||||
],
|
||||
"v16": [
|
||||
"16.9",
|
||||
"51194dc5ce2e3523068d8607852e6c3125a17e58"
|
||||
"e9db1ff5a6f3ca18f626ba3d62ab475e6c688a96"
|
||||
],
|
||||
"v15": [
|
||||
"15.13",
|
||||
"24313bf8f3de722968a2fdf764de7ef77ed64f06"
|
||||
"cef72d5308ddce3795a9043fcd94f8849f7f4800"
|
||||
],
|
||||
"v14": [
|
||||
"14.18",
|
||||
"ac3c460e01a31f11fb52fd8d8e88e60f0e1069b4"
|
||||
"47304b921555b3f33eb3b49daada3078e774cfd7"
|
||||
]
|
||||
}
|
||||
|
||||
@@ -107,7 +107,6 @@ tracing-core = { version = "0.1" }
|
||||
tracing-log = { version = "0.2" }
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
|
||||
url = { version = "2", features = ["serde"] }
|
||||
uuid = { version = "1", features = ["serde", "v4", "v7"] }
|
||||
zeroize = { version = "1", features = ["derive", "serde"] }
|
||||
zstd = { version = "0.13" }
|
||||
zstd-safe = { version = "7", default-features = false, features = ["arrays", "legacy", "std", "zdict_builder"] }
|
||||
|
||||
Reference in New Issue
Block a user